aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar jesopo2018-10-30 14:58:48 +0000
committerGravatar jesopo2018-10-30 14:58:48 +0000
commite07553c3627b80f20cdc81a35030bf0540924db8 (patch)
tree0a81640b280e007cbe5d2cb956681068ab80c58e
parentDon't needlessly search a youtube URL before getting the information for it's (diff)
signature
Add type/return hints throughout src/ and, in doing so, fix some cyclical
references.
-rw-r--r--modules/karma.py2
-rw-r--r--modules/line_handler.py3
-rw-r--r--modules/sed.py13
-rw-r--r--src/Cache.py18
-rw-r--r--src/Config.py10
-rw-r--r--src/Database.py101
-rw-r--r--src/EventManager.py207
-rw-r--r--src/Exports.py34
-rw-r--r--src/IRCBot.py46
-rw-r--r--src/IRCBuffer.py29
-rw-r--r--src/IRCChannel.py83
-rw-r--r--src/IRCServer.py137
-rw-r--r--src/IRCUser.py44
-rw-r--r--src/Logging.py18
-rw-r--r--src/ModuleManager.py49
-rw-r--r--src/Socket.py26
-rw-r--r--src/Timers.py62
-rw-r--r--src/utils/__init__.py123
-rw-r--r--src/utils/consts.py2
-rw-r--r--src/utils/http.py11
-rw-r--r--src/utils/irc.py46
-rw-r--r--src/utils/parse.py57
22 files changed, 605 insertions, 516 deletions
diff --git a/modules/karma.py b/modules/karma.py
index 441b061e..d3bab796 100644
--- a/modules/karma.py
+++ b/modules/karma.py
@@ -27,7 +27,7 @@ class Module(ModuleManager.BaseModule):
if not event["user"].last_karma or (time.time()-event["user"
].last_karma) >= KARMA_DELAY_SECONDS:
target = match.group(1).strip()
- if utils.irc.lower(event["server"], target
+ if utils.irc.lower(event["server"].case_mapping, target
) == event["user"].name:
if verbose:
self.events.on("send.stderr").call(
diff --git a/modules/line_handler.py b/modules/line_handler.py
index da0d5d07..3ee8304b 100644
--- a/modules/line_handler.py
+++ b/modules/line_handler.py
@@ -542,7 +542,8 @@ class Module(ModuleManager.BaseModule):
# we need a registered nickname for this channel
@utils.hook("raw.477", default_event=True)
def handle_477(self, event):
- channel_name = utils.irc.lower(event["server"], event["args"][1])
+ channel_name = utils.irc.lower(event["server"].case_mapping,
+ event["args"][1])
if channel_name in event["server"]:
key = event["server"].attempted_join[channel_name]
self.timers.add("rejoin", 5, channel_name=channe_name, key=key,
diff --git a/modules/sed.py b/modules/sed.py
index 6de7b4af..7b98ab70 100644
--- a/modules/sed.py
+++ b/modules/sed.py
@@ -11,12 +11,16 @@ REGEX_SED = re.compile("^s/")
"help": "Disable/Enable sed only looking at the messages sent by the user",
"validate": utils.bool_or_none})
class Module(ModuleManager.BaseModule):
+ def _closest_setting(self, event, setting, default):
+ return event["channel"].get_setting(setting,
+ event["server"].get_setting(setting, default))
+
@utils.hook("received.message.channel")
def channel_message(self, event):
sed_split = re.split(REGEX_SPLIT, event["message"], 3)
if event["message"].startswith("s/") and len(sed_split) > 2:
- if event["action"] or not utils.get_closest_setting(
- event, "sed", False):
+ if event["action"] or not self._closest_setting(event, "sed",
+ False):
return
regex_flags = 0
@@ -48,9 +52,8 @@ class Module(ModuleManager.BaseModule):
return
replace = sed_split[2].replace("\\/", "/")
- for_user = event["user"].nickname if utils.get_closest_setting(
- event, "sed-sender-only", False
- ) else None
+ for_user = event["user"].nickname if self._closest_setting(event,
+ "sed-sender-only", False) else None
line = event["channel"].buffer.find(pattern, from_self=False,
for_user=for_user, not_pattern=REGEX_SED)
if line:
diff --git a/src/Cache.py b/src/Cache.py
index 2de55afa..46b39bda 100644
--- a/src/Cache.py
+++ b/src/Cache.py
@@ -1,21 +1,21 @@
-import time, uuid
+import time, typing, uuid
class Cache(object):
def __init__(self):
self._items = {}
self._item_to_id = {}
- def cache(self, item):
+ def cache(self, item: typing.Any) -> str:
return self._cache(item, None)
- def temporary_cache(self, item, timeout):
+ def temporary_cache(self, item: typing.Any, timeout: float)-> str:
return self._cache(item, timeout)
- def _cache(self, item, timeout):
+ def _cache(self, item: typing.Any, timeout: float) -> str:
id = str(uuid.uuid4())
self._items[id] = [item, time.monotonic()+timeout]
self._item_to_id[item] = id
return id
- def next_expiration(self):
+ def next_expiration(self) -> float:
expirations = [self._items[id][1] for id in self._items]
expirations = list(filter(None, expirations))
if not expirations:
@@ -35,17 +35,17 @@ class Cache(object):
del self._items[id]
del self._item_to_id[item]
- def has_item(self, item):
+ def has_item(self, item: typing.Any) -> bool:
return item in self._item_to_id
- def get(self, id):
+ def get(self, id: str) -> typing.Any:
item, expiration = self._items[id]
return item
- def get_expiration(self, item):
+ def get_expiration(self, item: typing.Any) -> float:
id = self._item_to_id[item]
item, expiration = self._items[id]
return expiration
- def until_expiration(self, item):
+ def until_expiration(self, item: typing.Any) -> float:
expiration = self.get_expiration(item)
return expiration-time.monotonic()
diff --git a/src/Config.py b/src/Config.py
index 611b5b7b..dacb14dd 100644
--- a/src/Config.py
+++ b/src/Config.py
@@ -1,7 +1,7 @@
-import configparser, os
+import configparser, os, typing
class Config(object):
- def __init__(self, location):
+ def __init__(self, location: str):
self.location = location
self._config = {}
self.load()
@@ -13,10 +13,10 @@ class Config(object):
parser.read_string(config_file.read())
self._config = dict(parser["bot"].items())
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> typing.Any:
return self._config[key]
- def get(self, key, default=None):
+ def get(self, key: str, default: typing.Any=None) -> typing.Any:
return self._config.get(key, default)
- def __contains__(self, key):
+ def __contains__(self, key: str) -> bool:
return key in self._config
diff --git a/src/Database.py b/src/Database.py
index c1d07869..c3d48cb6 100644
--- a/src/Database.py
+++ b/src/Database.py
@@ -1,12 +1,14 @@
-import json, os, sqlite3, threading, time
+import json, os, sqlite3, threading, time, typing
+from src import Logging
class Table(object):
def __init__(self, database):
self.database = database
class Servers(Table):
- def add(self, alias, hostname, port, password, ipv4, tls, bindhost,
- nickname, username=None, realname=None):
+ def add(self, alias: str, hostname: str, port: int, password: str,
+ ipv4: bool, tls: bool, bindhost: str,
+ nickname: str, username: str=None, realname: str=None):
username = username or nickname
realname = realname or nickname
self.database.execute(
@@ -18,7 +20,7 @@ class Servers(Table):
def get_all(self):
return self.database.execute_fetchall(
"SELECT server_id, alias FROM servers")
- def get(self, id):
+ def get(self, id: int):
return self.database.execute_fetchone(
"""SELECT server_id, alias, hostname, port, password, ipv4,
tls, bindhost, nickname, username, realname FROM servers WHERE
@@ -26,46 +28,46 @@ class Servers(Table):
[id])
class Channels(Table):
- def add(self, server_id, name):
+ def add(self, server_id: int, name: str):
self.database.execute("""INSERT OR IGNORE INTO channels
(server_id, name) VALUES (?, ?)""",
[server_id, name.lower()])
- def delete(self, channel_id):
+ def delete(self, channel_id: int):
self.database.execute("DELETE FROM channels WHERE channel_id=?",
[channel_id])
- def get_id(self, server_id, name):
+ def get_id(self, server_id: int, name: str):
value = self.database.execute_fetchone("""SELECT channel_id FROM
channels WHERE server_id=? AND name=?""",
[server_id, name.lower()])
return value if value == None else value[0]
class Users(Table):
- def add(self, server_id, nickname):
+ def add(self, server_id: int, nickname: str):
self.database.execute("""INSERT OR IGNORE INTO users
(server_id, nickname) VALUES (?, ?)""",
[server_id, nickname.lower()])
- def delete(self, user_id):
+ def delete(self, user_id: int):
self.database.execute("DELETE FROM users WHERE user_id=?",
[user_id])
- def get_id(self, server_id, nickname):
+ def get_id(self, server_id: int, nickname: str):
value = self.database.execute_fetchone("""SELECT user_id FROM
users WHERE server_id=? and nickname=?""",
[server_id, nickname.lower()])
return value if value == None else value[0]
class BotSettings(Table):
- def set(self, setting, value):
+ def set(self, setting: str, value: typing.Any):
self.database.execute(
"INSERT OR REPLACE INTO bot_settings VALUES (?, ?)",
[setting.lower(), json.dumps(value)])
- def get(self, setting, default=None):
+ def get(self, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone(
"SELECT value FROM bot_settings WHERE setting=?",
[setting.lower()])
if value:
return json.loads(value[0])
return default
- def find(self, pattern, default=[]):
+ def find(self, pattern: str, default: typing.Any=[]):
values = self.database.execute_fetchall(
"SELECT setting, value FROM bot_settings WHERE setting LIKE ?",
[pattern.lower()])
@@ -74,19 +76,19 @@ class BotSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find_prefix(self, prefix, default=[]):
+ def find_prefix(self, prefix: str, default: typing.Any=[]):
return self.find("%s%%" % prefix, default)
- def delete(self, setting):
+ def delete(self, setting: str):
self.database.execute(
"DELETE FROM bot_settings WHERE setting=?",
[setting.lower()])
class ServerSettings(Table):
- def set(self, server_id, setting, value):
+ def set(self, server_id: int, setting: str, value: typing.Any):
self.database.execute(
"INSERT OR REPLACE INTO server_settings VALUES (?, ?, ?)",
[server_id, setting.lower(), json.dumps(value)])
- def get(self, server_id, setting, default=None):
+ def get(self, server_id: int, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone(
"""SELECT value FROM server_settings WHERE
server_id=? AND setting=?""",
@@ -94,7 +96,7 @@ class ServerSettings(Table):
if value:
return json.loads(value[0])
return default
- def find(self, server_id, pattern, default=[]):
+ def find(self, server_id: int, pattern: str, default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT setting, value FROM server_settings WHERE
server_id=? AND setting LIKE ?""",
@@ -104,26 +106,26 @@ class ServerSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find_prefix(self, server_id, prefix, default=[]):
+ def find_prefix(self, server_id: int, prefix: str, default: typing.Any=[]):
return self.find_server_settings(server_id, "%s%%" % prefix, default)
- def delete(self, server_id, setting):
+ def delete(self, server_id: int, setting: str):
self.database.execute(
"DELETE FROM server_settings WHERE server_id=? AND setting=?",
[server_id, setting.lower()])
class ChannelSettings(Table):
- def set(self, channel_id, setting, value):
+ def set(self, channel_id: int, setting: str, value: typing.Any):
self.database.execute(
"INSERT OR REPLACE INTO channel_settings VALUES (?, ?, ?)",
[channel_id, setting.lower(), json.dumps(value)])
- def get(self, channel_id, setting, default=None):
+ def get(self, channel_id: int, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone(
"""SELECT value FROM channel_settings WHERE
channel_id=? AND setting=?""", [channel_id, setting.lower()])
if value:
return json.loads(value[0])
return default
- def find(self, channel_id, pattern, default=[]):
+ def find(self, channel_id: int, pattern: str, default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT setting, value FROM channel_settings WHERE
channel_id=? setting LIKE '?'""", [channel_id, pattern.lower()])
@@ -132,15 +134,15 @@ class ChannelSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find_prefix(self, channel_id, prefix, default=[]):
+ def find_prefix(self, channel_id: int, prefix: str, default: typing.Any=[]):
return self.find_channel_settings(channel_id, "%s%%" % prefix,
default)
- def delete(self, channel_id, setting):
+ def delete(self, channel_id: int, setting: str):
self.database.execute(
"""DELETE FROM channel_settings WHERE channel_id=?
AND setting=?""", [channel_id, setting.lower()])
- def find_by_setting(self, setting, default=[]):
+ def find_by_setting(self, setting: str, default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT channels.server_id, channels.name,
channel_settings.value FROM channel_settings
@@ -154,18 +156,19 @@ class ChannelSettings(Table):
return default
class UserSettings(Table):
- def set(self, user_id, setting, value):
+ def set(self, user_id: int, setting: str, value: typing.Any):
self.database.execute(
"INSERT OR REPLACE INTO user_settings VALUES (?, ?, ?)",
[user_id, setting.lower(), json.dumps(value)])
- def get(self, user_id, setting, default=None):
+ def get(self, user_id: int, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone(
"""SELECT value FROM user_settings WHERE
user_id=? and setting=?""", [user_id, setting.lower()])
if value:
return json.loads(value[0])
return default
- def find_all_by_setting(self, server_id, setting, default=[]):
+ def find_all_by_setting(self, server_id: int, setting: str,
+ default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT users.nickname, user_settings.value FROM
user_settings INNER JOIN users ON
@@ -177,7 +180,7 @@ class UserSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find(self, user_id, pattern, default=[]):
+ def find(self, user_id: int, pattern: str, default: typing.Any=[]):
values = self.database.execute(
"""SELECT setting, value FROM user_settings WHERE
user_id=? AND setting LIKE '?'""", [user_id, pattern.lower()])
@@ -186,20 +189,22 @@ class UserSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find_prefix(self, user_id, prefix, default=[]):
+ def find_prefix(self, user_id: int, prefix: str, default: typing.Any=[]):
return self.find_user_settings(user_id, "%s%%" % prefix, default)
- def delete(self, user_id, setting):
+ def delete(self, user_id: int, setting: str):
self.database.execute(
"""DELETE FROM user_settings WHERE
user_id=? AND setting=?""", [user_id, setting.lower()])
class UserChannelSettings(Table):
- def set(self, user_id, channel_id, setting, value):
+ def set(self, user_id: int, channel_id: int, setting: str,
+ value: typing.Any):
self.database.execute(
"""INSERT OR REPLACE INTO user_channel_settings VALUES
(?, ?, ?, ?)""",
[user_id, channel_id, setting.lower(), json.dumps(value)])
- def get(self, user_id, channel_id, setting, default=None):
+ def get(self, user_id: int, channel_id: int, setting: str,
+ default: typing.Any=None):
value = self.database.execute_fetchone(
"""SELECT value FROM user_channel_settings WHERE
user_id=? AND channel_id=? AND setting=?""",
@@ -207,7 +212,8 @@ class UserChannelSettings(Table):
if value:
return json.loads(value[0])
return default
- def find(self, user_id, channel_id, pattern, default=[]):
+ def find(self, user_id: int, channel_id: int, pattern: str,
+ default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT setting, value FROM user_channel_settings WHERE
user_id=? AND channel_id=? AND setting LIKE '?'""",
@@ -217,10 +223,12 @@ class UserChannelSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find_prefix(self, user_id, channel_id, prefix, default=[]):
+ def find_prefix(self, user_id: int, channel_id: int, prefix: str,
+ default: typing.Any=[]):
return self.find_user_settings(user_id, channel_id, "%s%%" % prefix,
default)
- def find_by_setting(self, user_id, setting, default=[]):
+ def find_by_setting(self, user_id: int, setting: str,
+ default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT channels.name, user_channel_settings.value FROM
user_channel_settings INNER JOIN channels ON
@@ -232,7 +240,8 @@ class UserChannelSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find_all_by_setting(self, server_id, setting, default=[]):
+ def find_all_by_setting(self, server_id: int, setting: str,
+ default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT channels.name, users.nickname,
user_channel_settings.value FROM
@@ -246,14 +255,14 @@ class UserChannelSettings(Table):
values[i] = value[0], value[1], json.loads(value[2])
return values
return default
- def delete(self, user_id, channel_id, setting):
+ def delete(self, user_id: int, channel_id: int, setting: str):
self.database.execute(
"""DELETE FROM user_channel_settings WHERE
user_id=? AND channel_id=? AND setting=?""",
[user_id, channel_id, setting.lower()])
class Database(object):
- def __init__(self, log, location):
+ def __init__(self, log: "Logging.Log", location: str):
self.log = log
self.location = location
self.database = sqlite3.connect(self.location,
@@ -284,7 +293,9 @@ class Database(object):
self._cursor = self.database.cursor()
return self._cursor
- def _execute_fetch(self, query, fetch_func, params=[]):
+ def _execute_fetch(self, query: str,
+ fetch_func: typing.Callable[[sqlite3.Cursor], typing.Any],
+ params: typing.List=[]):
printable_query = " ".join(query.split())
self.log.trace("executing query: \"%s\" (params: %s)",
[printable_query, params])
@@ -299,16 +310,16 @@ class Database(object):
self.log.trace("executed in %fms", [total_milliseconds])
return value
- def execute_fetchall(self, query, params=[]):
+ def execute_fetchall(self, query: str, params: typing.List=[]):
return self._execute_fetch(query,
lambda cursor: cursor.fetchall(), params)
- def execute_fetchone(self, query, params=[]):
+ def execute_fetchone(self, query: str, params: typing.List=[]):
return self._execute_fetch(query,
lambda cursor: cursor.fetchone(), params)
- def execute(self, query, params=[]):
+ def execute(self, query: str, params: typing.List=[]):
return self._execute_fetch(query, lambda cursor: None, params)
- def has_table(self, table_name):
+ def has_table(self, table_name: str):
result = self.execute_fetchone("""SELECT COUNT(*) FROM
sqlite_master WHERE type='table' AND name=?""",
[table_name])
diff --git a/src/EventManager.py b/src/EventManager.py
index 15115ad8..bdaa3b7a 100644
--- a/src/EventManager.py
+++ b/src/EventManager.py
@@ -1,5 +1,5 @@
-import itertools, time, traceback
-from src import utils
+import itertools, time, traceback, typing
+from src import Logging, utils
PRIORITY_URGENT = 0
PRIORITY_HIGH = 1
@@ -11,94 +11,39 @@ DEFAULT_PRIORITY = PRIORITY_MEDIUM
DEFAULT_EVENT_DELIMITER = "."
DEFAULT_MULTI_DELIMITER = "|"
+CALLBACK_TYPE = typing.Callable[["Event"], typing.Any]
+
class Event(object):
- def __init__(self, name, **kwargs):
+ def __init__(self, name: str, **kwargs):
self.name = name
self.kwargs = kwargs
self.eaten = False
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> typing.Any:
return self.kwargs[key]
- def get(self, key, default=None):
+ def get(self, key: str, default=None) -> typing.Any:
return self.kwargs.get(key, default)
- def __contains__(self, key):
+ def __contains__(self, key: str) -> bool:
return key in self.kwargs
def eat(self):
self.eaten = True
class EventCallback(object):
- def __init__(self, function, priority, kwargs):
+ def __init__(self, function: CALLBACK_TYPE, priority: int, kwargs: dict):
self.function = function
self.priority = priority
self.kwargs = kwargs
- self.docstring = utils.parse_docstring(function.__doc__)
+ self.docstring = utils.parse.docstring(function.__doc__)
- def call(self, event):
+ def call(self, event: Event) -> typing.Any:
return self.function(event)
- def get_kwarg(self, name, default=None):
+ def get_kwarg(self, name: str, default=None) -> typing.Any:
item = self.kwargs.get(name, default)
return item or self.docstring.items.get(name, default)
-class MultipleEventHook(object):
- def __init__(self):
- self._event_hooks = set([])
- def _add(self, event_hook):
- self._event_hooks.add(event_hook)
-
- def hook(self, function, **kwargs):
- for event_hook in self._event_hooks:
- event_hook.hook(function, **kwargs)
-
- def call_limited(self, maximum, **kwargs):
- returns = []
- for event_hook in self._event_hooks:
- returns.append(event_hook.call_limited(maximum, **kwargs))
- return returns
- def call(self, **kwargs):
- returns = []
- for event_hook in self._event_hooks:
- returns.append(event_hook.call(**kwargs))
- return returns
-
-class EventHookContext(object):
- def __init__(self, parent, context):
- self._parent = parent
- self.context = context
- def hook(self, function, priority=DEFAULT_PRIORITY, replay=False,
- **kwargs):
- return self._parent._context_hook(self.context, function, priority,
- replay, kwargs)
- def unhook(self, callback):
- self._parent.unhook(callback)
-
- def on(self, subevent, *extra_subevents,
- delimiter=DEFAULT_EVENT_DELIMITER):
- return self._parent._context_on(self.context, subevent,
- extra_subevents, delimiter)
-
- def call_for_result(self, default=None, **kwargs):
- return self._parent.call_for_result(default, **kwargs)
- def assure_call(self, **kwargs):
- self._parent.assure_call(**kwargs)
- def call(self, **kwargs):
- return self._parent.call(**kwargs)
- def call_limited(self, maximum, **kwargs):
- return self._parent.call_limited(maximum, **kwargs)
-
- def call_unsafe_for_result(self, default=None, **kwargs):
- return self._parent.call_unsafe_for_result(default, **kwargs)
- def call_unsafe(self, **kwargs):
- return self._parent.call_unsafe(**kwargs)
- def call_unsafe_limited(self, maximum, **kwargs):
- return self._parent.call_unsafe_limited(maximum, **kwargs)
-
- def get_hooks(self):
- return self._parent.get_hooks()
- def get_children(self):
- return self._parent.get_children()
-
class EventHook(object):
- def __init__(self, log, name=None, parent=None):
+ def __init__(self, log: Logging.Log, name: str = None,
+ parent: "EventHook" = None):
self.log = log
self.name = name
self.parent = parent
@@ -107,10 +52,10 @@ class EventHook(object):
self._stored_events = []
self._context_hooks = {}
- def _make_event(self, kwargs):
+ def _make_event(self, kwargs: dict) -> Event:
return Event(self._get_path(), **kwargs)
- def _get_path(self):
+ def _get_path(self) -> str:
path = []
parent = self
while not parent == None and not parent.name == None:
@@ -118,15 +63,17 @@ class EventHook(object):
parent = parent.parent
return DEFAULT_EVENT_DELIMITER.join(path[::-1])
- def new_context(self, context):
+ def new_context(self, context: str) -> "EventHookContext":
return EventHookContext(self, context)
- def hook(self, function, priority=DEFAULT_PRIORITY, replay=False,
- **kwargs):
+ def hook(self, function: CALLBACK_TYPE, priority: int = DEFAULT_PRIORITY,
+ replay: bool = False, **kwargs) -> EventCallback:
return self._hook(function, None, priority, replay, kwargs)
- def _context_hook(self, context, function, priority, replay, kwargs):
+ def _context_hook(self, context: str, function: CALLBACK_TYPE,
+ priority: int, replay: bool, kwargs: dict) -> EventCallback:
return self._hook(function, context, priority, replay, kwargs)
- def _hook(self, function, context, priority, replay, kwargs):
+ def _hook(self, function: CALLBACK_TYPE, context: str, priority: int,
+ replay: bool, kwargs: dict) -> EventCallback:
callback = EventCallback(function, priority, kwargs)
if context == None:
@@ -142,7 +89,7 @@ class EventHook(object):
self._stored_events = None
return callback
- def unhook(self, callback):
+ def unhook(self, callback: "EventHook"):
if callback in self._hooks:
self._hooks.remove(callback)
@@ -155,7 +102,8 @@ class EventHook(object):
for context in empty:
del self._context_hooks[context]
- def _make_multiple_hook(self, source, context, events):
+ def _make_multiple_hook(self, source: "EventHook", context: str,
+ events: typing.List[str]) -> "MultipleEventHook":
multiple_event_hook = MultipleEventHook()
for event in events:
event_hook = source.get_child(event)
@@ -164,13 +112,15 @@ class EventHook(object):
multiple_event_hook._add(event_hook)
return multiple_event_hook
- def on(self, subevent, *extra_subevents,
- delimiter=DEFAULT_EVENT_DELIMITER):
+ def on(self, subevent: str, *extra_subevents,
+ delimiter: int = DEFAULT_EVENT_DELIMITER) -> "EventHook":
return self._on(subevent, extra_subevents, None, delimiter)
- def _context_on(self, context, subevent, extra_subevents,
- delimiter=DEFAULT_EVENT_DELIMITER):
+ def _context_on(self, context: str, subevent: str,
+ extra_subevents: typing.List[str],
+ delimiter: str = DEFAULT_EVENT_DELIMITER) -> "EventHook":
return self._on(subevent, extra_subevents, context, delimiter)
- def _on(self, subevent, extra_subevents, context, delimiter):
+ def _on(self, subevent: str, extra_subevents: typing.List[str],
+ context: str, delimiter: str) -> "EventHook":
if delimiter in subevent:
event_chain = subevent.split(delimiter)
event_obj = self
@@ -193,26 +143,28 @@ class EventHook(object):
child = child.new_context(context)
return child
- def call_for_result(self, default=None, **kwargs):
+ def call_for_result(self, default=None, **kwargs) -> typing.Any:
return (self.call_limited(1, **kwargs) or [default])[0]
def assure_call(self, **kwargs):
if not self._stored_events == None:
self._stored_events.append(kwargs)
else:
self._call(kwargs, True, None)
- def call(self, **kwargs):
+ def call(self, **kwargs) -> typing.List[typing.Any]:
return self._call(kwargs, True, None)
- def call_limited(self, maximum, **kwargs):
+ def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
return self._call(kwargs, True, None)
- def call_unsafe_for_result(self, default=None, **kwargs):
+ def call_unsafe_for_result(self, default=None, **kwargs) -> typing.Any:
return (self.call_unsafe_limited(1, **kwargs) or [default])[0]
- def call_unsafe(self, **kwargs):
+ def call_unsafe(self, **kwargs) -> typing.List[typing.Any]:
return self._call(kwargs, False, None)
- def call_unsafe_limited(self, maximum, **kwargs):
+ def call_unsafe_limited(self, maximum: int, **kwargs
+ ) -> typing.List[typing.Any]:
return self._call(kwargs, False, maximum)
- def _call(self, kwargs, safe, maximum):
+ def _call(self, kwargs: dict, safe: bool, maximum: int
+ ) -> typing.List[typing.Any]:
event_path = self._get_path()
self.log.trace("calling event: \"%s\" (params: %s)",
[event_path, kwargs])
@@ -240,13 +192,13 @@ class EventHook(object):
return returns
- def get_child(self, child_name):
+ def get_child(self, child_name: str) -> "EventHook":
child_name_lower = child_name.lower()
if not child_name_lower in self._children:
self._children[child_name_lower] = EventHook(self.log,
child_name_lower, self)
return self._children[child_name_lower]
- def remove_child(self, child_name):
+ def remove_child(self, child_name: str):
child_name_lower = child_name.lower()
if child_name_lower in self._children:
del self._children[child_name_lower]
@@ -256,11 +208,11 @@ class EventHook(object):
self.parent.remove_child(self.name)
self.parent.check_purge()
- def remove_context(self, context):
+ def remove_context(self, context: str):
del self._context_hooks[context]
- def has_context(self, context):
+ def has_context(self, context: str) -> bool:
return context in self._context_hooks
- def purge_context(self, context):
+ def purge_context(self, context: str):
if self.has_context(context):
self.remove_context(context)
@@ -268,10 +220,69 @@ class EventHook(object):
child = self.get_child(child_name)
child.purge_context(context)
- def get_hooks(self):
+ def get_hooks(self) -> typing.List[EventCallback]:
return sorted(self._hooks + sum(self._context_hooks.values(), []),
key=lambda e: e.priority)
- def get_children(self):
+ def get_children(self) -> typing.List["EventHook"]:
return list(self._children.keys())
- def is_empty(self):
+ def is_empty(self) -> bool:
return len(self.get_hooks() + self.get_children()) == 0
+
+class MultipleEventHook(object):
+ def __init__(self):
+ self._event_hooks = set([])
+ def _add(self, event_hook: EventHook):
+ self._event_hooks.add(event_hook)
+
+ def hook(self, function: CALLBACK_TYPE, **kwargs):
+ for event_hook in self._event_hooks:
+ event_hook.hook(function, **kwargs)
+
+ def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
+ returns = []
+ for event_hook in self._event_hooks:
+ returns.append(event_hook.call_limited(maximum, **kwargs))
+ return returns
+ def call(self, **kwargs) -> typing.List[typing.Any]:
+ returns = []
+ for event_hook in self._event_hooks:
+ returns.append(event_hook.call(**kwargs))
+ return returns
+
+class EventHookContext(object):
+ def __init__(self, parent, context):
+ self._parent = parent
+ self.context = context
+ def hook(self, function: CALLBACK_TYPE, priority: int = DEFAULT_PRIORITY,
+ replay: bool = False, **kwargs) -> EventCallback:
+ return self._parent._context_hook(self.context, function, priority,
+ replay, kwargs)
+ def unhook(self, callback: EventCallback):
+ self._parent.unhook(callback)
+
+ def on(self, subevent: str, *extra_subevents,
+ delimiter: str = DEFAULT_EVENT_DELIMITER) -> EventHook:
+ return self._parent._context_on(self.context, subevent,
+ extra_subevents, delimiter)
+
+ def call_for_result(self, default=None, **kwargs) -> typing.Any:
+ return self._parent.call_for_result(default, **kwargs)
+ def assure_call(self, **kwargs):
+ self._parent.assure_call(**kwargs)
+ def call(self, **kwargs) -> typing.List[typing.Any]:
+ return self._parent.call(**kwargs)
+ def call_limited(self, maximum: int, **kwargs) -> typing.List[typing.Any]:
+ return self._parent.call_limited(maximum, **kwargs)
+
+ def call_unsafe_for_result(self, default=None, **kwargs) -> typing.Any:
+ return self._parent.call_unsafe_for_result(default, **kwargs)
+ def call_unsafe(self, **kwargs) -> typing.List[typing.Any]:
+ return self._parent.call_unsafe(**kwargs)
+ def call_unsafe_limited(self, maximum: int, **kwargs
+ ) -> typing.List[typing.Any]:
+ return self._parent.call_unsafe_limited(maximum, **kwargs)
+
+ def get_hooks(self) -> typing.List[EventCallback]:
+ return self._parent.get_hooks()
+ def get_children(self) -> typing.List[EventHook]:
+ return self._parent.get_children()
diff --git a/src/Exports.py b/src/Exports.py
index 8baca50d..68b25933 100644
--- a/src/Exports.py
+++ b/src/Exports.py
@@ -1,28 +1,18 @@
-
-
-class ExportsContext(object):
- def __init__(self, parent, context):
- self._parent = parent
- self.context = context
-
- def add(self, setting, value):
- self._parent._context_add(self.context, setting, value)
- def get_all(self, setting):
- return self._parent.get_all(setting)
+import typing
class Exports(object):
def __init__(self):
self._exports = {}
self._context_exports = {}
- def new_context(self, context):
+ def new_context(self, context: str) -> "ExportsContext":
return ExportsContext(self, context)
- def add(self, setting, value):
+ def add(self, setting: str, value: typing.Any):
self._add(None, setting, value)
- def _context_add(self, context, setting, value):
+ def _context_add(self, context: str, setting: str, value: typing.Any):
self._add(context, setting, value)
- def _add(self, context, setting, value):
+ def _add(self, context: str, setting: str, value: typing.Any):
if context == None:
if not setting in self_exports:
self._exports[setting] = []
@@ -34,11 +24,21 @@ class Exports(object):
self._context_exports[context][setting] = []
self._context_exports[context][setting].append(value)
- def get_all(self, setting):
+ def get_all(self, setting: str) -> typing.List[typing.Any]:
return self._exports.get(setting, []) + sum([
exports.get(setting, []) for exports in
self._context_exports.values()], [])
- def purge_context(self, context):
+ def purge_context(self, context: str):
if context in self._context_exports:
del self._context_exports[context]
+
+class ExportsContext(object):
+ def __init__(self, parent: Exports, context: str):
+ self._parent = parent
+ self.context = context
+
+ def add(self, setting: str, value: typing.Any):
+ self._parent._context_add(self.context, setting, value)
+ def get_all(self, setting: str) -> typing.List[typing.Any]:
+ return self._parent.get_all(setting)
diff --git a/src/IRCBot.py b/src/IRCBot.py
index 657fb135..61f21114 100644
--- a/src/IRCBot.py
+++ b/src/IRCBot.py
@@ -1,4 +1,4 @@
-import os, select, socket, sys, threading, time, traceback, uuid
+import os, select, socket, sys, threading, time, traceback, typing, uuid
from src import EventManager, Exports, IRCServer, Logging, ModuleManager
from src import Socket, utils
@@ -28,14 +28,15 @@ class Bot(object):
self._trigger_functions = []
- def trigger(self, func=None):
+ def trigger(self, func: typing.Callable[[], typing.Any]=None):
self.lock.acquire()
if func:
self._trigger_functions.append(func)
self._trigger_client.send(b"TRIGGER")
self.lock.release()
- def add_server(self, server_id, connect=True):
+ def add_server(self, server_id: int, connect: bool = True
+ ) -> typing.Optional[IRCServer.Server]:
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
username, realname) = self.database.servers.get(server_id)
@@ -49,20 +50,20 @@ class Bot(object):
self.connect(new_server)
return new_server
- def add_socket(self, sock):
+ def add_socket(self, sock: socket.socket):
self.other_sockets[sock.fileno()] = sock
self.poll.register(sock.fileno(), select.EPOLLIN)
- def remove_socket(self, sock):
+ def remove_socket(self, sock: socket.socket):
del self.other_sockets[sock.fileno()]
self.poll.unregister(sock.fileno())
- def get_server(self, id):
+ def get_server(self, id: int) -> typing.Optional[IRCServer.Server]:
for server in self.servers.values():
if server.id == id:
return server
- def connect(self, server):
+ def connect(self, server: IRCServer.Server) -> bool:
try:
server.connect()
except:
@@ -73,7 +74,7 @@ class Bot(object):
self.poll.register(server.fileno(), select.EPOLLOUT)
return True
- def next_send(self):
+ def next_send(self) -> typing.Optional[float]:
next = None
for server in self.servers.values():
timeout = server.send_throttle_timeout()
@@ -81,7 +82,7 @@ class Bot(object):
next = timeout
return next
- def next_ping(self):
+ def next_ping(self) -> typing.Optional[float]:
timeouts = []
for server in self.servers.values():
timeout = server.until_next_ping()
@@ -90,7 +91,8 @@ class Bot(object):
if not timeouts:
return None
return min(timeouts)
- def next_read_timeout(self):
+
+ def next_read_timeout(self) -> typing.Optional[float]:
timeouts = []
for server in self.servers.values():
timeouts.append(server.until_read_timeout())
@@ -98,7 +100,7 @@ class Bot(object):
return None
return min(timeouts)
- def get_poll_timeout(self):
+ def get_poll_timeout(self) -> float:
timeouts = []
timeouts.append(self._timers.next())
timeouts.append(self.next_send())
@@ -107,15 +109,15 @@ class Bot(object):
timeouts.append(self.cache.next_expiration())
return min([timeout for timeout in timeouts if not timeout == None])
- def register_read(self, server):
+ def register_read(self, server: IRCServer.Server):
self.poll.modify(server.fileno(), select.EPOLLIN)
- def register_write(self, server):
+ def register_write(self, server: IRCServer.Server):
self.poll.modify(server.fileno(), select.EPOLLOUT)
- def register_both(self, server):
+ def register_both(self, server: IRCServer.Server):
self.poll.modify(server.fileno(),
select.EPOLLIN|select.EPOLLOUT)
- def disconnect(self, server):
+ def disconnect(self, server: IRCServer.Server):
try:
self.poll.unregister(server.fileno())
except FileNotFoundError:
@@ -123,23 +125,25 @@ class Bot(object):
del self.servers[server.fileno()]
@utils.hook("timer.reconnect")
- def reconnect(self, event):
+ def reconnect(self, event: EventManager.Event):
server = self.add_server(event["server_id"], False)
if self.connect(server):
self.servers[server.fileno()] = server
else:
event["timer"].redo()
- def set_setting(self, setting, value):
+ def set_setting(self, setting: str, value: typing.Any):
self.database.bot_settings.set(setting, value)
- def get_setting(self, setting, default=None):
+ def get_setting(self, setting: str, default: typing.Any=None) -> typing.Any:
return self.database.bot_settings.get(setting, default)
- def find_settings(self, pattern, default=[]):
+ def find_settings(self, pattern: str, default: typing.Any=[]
+ ) -> typing.List[typing.Any]:
return self.database.bot_settings.find(pattern, default)
- def find_settings_prefix(self, prefix, default=[]):
+ def find_settings_prefix(self, prefix: str, default: typing.Any=[]
+ ) -> typing.List[typing.Any]:
return self.database.bot_settings.find_prefix(
prefix, default)
- def del_setting(self, setting):
+ def del_setting(self, setting: str):
self.database.bot_settings.delete(setting)
def run(self):
diff --git a/src/IRCBuffer.py b/src/IRCBuffer.py
index 06465749..24fde7bf 100644
--- a/src/IRCBuffer.py
+++ b/src/IRCBuffer.py
@@ -1,8 +1,9 @@
-import re
-from src import utils
+import re, typing
+from src import IRCBot, utils
class BufferLine(object):
- def __init__(self, sender, message, action, tags, from_self, method):
+ def __init__(self, sender: str, message: str, action: bool, tags: dict,
+ from_self: bool, method: str):
self.sender = sender
self.message = message
self.action = action
@@ -11,35 +12,39 @@ class BufferLine(object):
self.method = method
class Buffer(object):
- def __init__(self, bot, server):
+ def __init__(self, bot: "IRCBot.Bot", server: "IRCServer.Server"):
self.bot = bot
self.server = server
self.lines = []
self.max_lines = 64
self._skip_next = False
- def _add_message(self, sender, message, action, tags, from_self, method):
+ def _add_message(self, sender: str, message: str, action: bool, tags: dict,
+ from_self: bool, method: str):
if not self._skip_next:
line = BufferLine(sender, message, action, tags, from_self, method)
self.lines.insert(0, line)
if len(self.lines) > self.max_lines:
self.lines.pop()
self._skip_next = False
- def add_message(self, sender, message, action, tags, from_self=False):
+ def add_message(self, sender: str, message: str, action: bool, tags: dict,
+ from_self: bool=False):
self._add_message(sender, message, action, tags, from_self, "PRIVMSG")
- def add_notice(self, sender, message, tags, from_self=False):
+ def add_notice(self, sender: str, message: str, tags: dict,
+ from_self: bool=False):
self._add_message(sender, message, False, tags, from_self, "NOTICE")
- def get(self, index=0, **kwargs):
+ def get(self, index: int=0, **kwargs) -> typing.Optional[BufferLine]:
from_self = kwargs.get("from_self", True)
for line in self.lines:
if line.from_self and not from_self:
continue
return line
- def find(self, pattern, **kwargs):
+ def find(self, pattern: typing.Union[str, typing.Pattern[str]], **kwargs
+ ) -> typing.Optional[BufferLine]:
from_self = kwargs.get("from_self", True)
for_user = kwargs.get("for_user", "")
- for_user = utils.irc.lower(self.server, for_user
+ for_user = utils.irc.lower(self.server.case_mapping, for_user
) if for_user else None
not_pattern = kwargs.get("not_pattern", None)
for line in self.lines:
@@ -48,8 +53,8 @@ class Buffer(object):
elif re.search(pattern, line.message):
if not_pattern and re.search(not_pattern, line.message):
continue
- if for_user and not utils.irc.lower(self.server, line.sender
- ) == for_user:
+ if for_user and not utils.irc.lower(self.server.case_mapping,
+ line.sender) == for_user:
continue
return line
def skip_next(self):
diff --git a/src/IRCChannel.py b/src/IRCChannel.py
index ffe34fa9..7d2c38fa 100644
--- a/src/IRCChannel.py
+++ b/src/IRCChannel.py
@@ -1,9 +1,10 @@
-import uuid
-from src import IRCBuffer, IRCObject, utils
+import typing, uuid
+from src import IRCBot, IRCBuffer, IRCObject, IRCServer, IRCUser, utils
class Channel(IRCObject.Object):
- def __init__(self, name, id, server, bot):
- self.name = utils.irc.lower(server, name)
+ def __init__(self, name: str, id, server: "IRCServer.Server",
+ bot: "IRCBot.Bot"):
+ self.name = utils.irc.lower(server.case_mapping, name)
self.id = id
self.server = server
self.bot = bot
@@ -18,23 +19,24 @@ class Channel(IRCObject.Object):
self.created_timestamp = None
self.buffer = IRCBuffer.Buffer(bot, server)
- def __repr__(self):
+ def __repr__(self) -> str:
return "IRCChannel.Channel(%s|%s)" % (self.server.name, self.name)
- def __str__(self):
+ def __str__(self) -> str:
return self.name
- def set_topic(self, topic):
+ def set_topic(self, topic: str):
self.topic = topic
- def set_topic_setter(self, nickname, username=None, hostname=None):
+ def set_topic_setter(self, nickname: str, username: str=None,
+ hostname: str=None):
self.topic_setter_nickname = nickname
self.topic_setter_username = username
self.topic_setter_hostname = hostname
- def set_topic_time(self, unix_timestamp):
+ def set_topic_time(self, unix_timestamp: int):
self.topic_time = unix_timestamp
- def add_user(self, user):
+ def add_user(self, user: IRCUser.User):
self.users.add(user)
- def remove_user(self, user):
+ def remove_user(self, user: IRCUser.User):
self.users.remove(user)
for mode in list(self.modes.keys()):
if mode in self.server.prefix_modes and user in self.modes[mode]:
@@ -43,10 +45,10 @@ class Channel(IRCObject.Object):
del self.modes[mode]
if user in self.user_modes:
del self.user_modes[user]
- def has_user(self, user):
+ def has_user(self, user: IRCUser.User) -> bool:
return user in self.users
- def add_mode(self, mode, arg=None):
+ def add_mode(self, mode: str, arg: str=None):
if not mode in self.modes:
self.modes[mode] = set([])
if arg:
@@ -59,7 +61,7 @@ class Channel(IRCObject.Object):
self.user_modes[user].add(mode)
else:
self.modes[mode].add(arg.lower())
- def remove_mode(self, mode, arg=None):
+ def remove_mode(self, mode: str, arg: str=None):
if not arg:
del self.modes[mode]
else:
@@ -76,63 +78,70 @@ class Channel(IRCObject.Object):
self.modes[mode].discard(arg.lower())
if not len(self.modes[mode]):
del self.modes[mode]
- def change_mode(self, remove, mode, arg=None):
+ def change_mode(self, remove: bool, mode: str, arg: str=None):
if remove:
self.remove_mode(mode, arg)
else:
self.add_mode(mode, arg)
- def set_setting(self, setting, value):
+ def set_setting(self, setting: str, value: typing.Any):
self.bot.database.channel_settings.set(self.id, setting, value)
- def get_setting(self, setting, default=None):
+ def get_setting(self, setting: str, default: typing.Any=None
+ ) -> typing.Any:
return self.bot.database.channel_settings.get(self.id, setting,
default)
- def find_settings(self, pattern, default=[]):
+ def find_settings(self, pattern: str, default: typing.Any=[]
+ ) -> typing.List[typing.Any]:
return self.bot.database.channel_settings.find(self.id, pattern,
default)
- def find_settings_prefix(self, prefix, default=[]):
+ def find_settings_prefix(self, prefix: str, default: typing.Any=[]
+ ) -> typing.List[typing.Any]:
return self.bot.database.channel_settings.find_prefix(self.id,
prefix, default)
- def del_setting(self, setting):
+ def del_setting(self, setting: str):
self.bot.database.channel_settings.delete(self.id, setting)
- def set_user_setting(self, user_id, setting, value):
+ def set_user_setting(self, user_id: int, setting: str, value: typing.Any):
self.bot.database.user_channel_settings.set(user_id, self.id,
setting, value)
- def get_user_setting(self, user_id, setting, default=None):
+ def get_user_setting(self, user_id: int, setting: str,
+ default: typing.Any=None) -> typing.Any:
return self.bot.database.user_channel_settings.get(user_id,
self.id, setting, default)
- def find_user_settings(self, user_i, pattern, default=[]):
+ def find_user_settings(self, user_id: int, pattern: str,
+ default: typing.Any=[]) -> typing.List[typing.Any]:
return self.bot.database.user_channel_settings.find(user_id,
self.id, pattern, default)
- def find_user_settings_prefix(self, user_id, prefix, default=[]):
+ def find_user_settings_prefix(self, user_id: int, prefix: str,
+ default: typing.Any=[]) -> typing.List[typing.Any]:
return self.bot.database.user_channel_settings.find_prefix(
user_id, self.id, prefix, default)
- def del_user_setting(self, user_id, setting):
+ def del_user_setting(self, user_id: int, setting: str):
self.bot.database.user_channel_settings.delete(user_id, self.id,
setting)
- def find_all_by_setting(self, setting, default=[]):
+ def find_all_by_setting(self, setting: str, default: typing.Any=[]
+ ) -> typing.List[typing.Any]:
return self.bot.database.user_channel_settings.find_all_by_setting(
self.id, setting, default)
- def send_message(self, text, prefix=None, tags={}):
+ def send_message(self, text: str, prefix: str=None, tags: dict={}):
self.server.send_message(self.name, text, prefix=prefix, tags=tags)
- def send_notice(self, text, prefix=None, tags={}):
+ def send_notice(self, text: str, prefix: str=None, tags: dict={}):
self.server.send_notice(self.name, text, prefix=prefix, tags=tags)
- def send_mode(self, mode=None, target=None):
+ def send_mode(self, mode: str=None, target: str=None):
self.server.send_mode(self.name, mode, target)
- def send_kick(self, target, reason=None):
+ def send_kick(self, target: str, reason: str=None):
self.server.send_kick(self.name, target, reason)
- def send_ban(self, hostmask):
+ def send_ban(self, hostmask: str):
self.server.send_mode(self.name, "+b", hostmask)
- def send_unban(self, hostmask):
+ def send_unban(self, hostmask: str):
self.server.send_mode(self.name, "-b", hostmask)
- def send_topic(self, topic):
+ def send_topic(self, topic: str):
self.server.send_topic(self.name, topic)
- def send_part(self, reason=None):
+ def send_part(self, reason: str=None):
self.server.send_part(self.name, reason)
- def mode_or_above(self, user, mode):
+ def mode_or_above(self, user: IRCUser.User, mode: str) -> bool:
mode_orders = list(self.server.prefix_modes)
mode_index = mode_orders.index(mode)
for mode in mode_orders[:mode_index+1]:
@@ -140,8 +149,8 @@ class Channel(IRCObject.Object):
return True
return False
- def has_mode(self, user, mode):
+ def has_mode(self, user: IRCUser.User, mode: str) -> bool:
return user in self.modes.get(mode, [])
- def get_user_status(self, user):
+ def get_user_status(self, user: IRCUser.User) -> typing.Set:
return self.user_modes.get(user, [])
diff --git a/src/IRCServer.py b/src/IRCServer.py
index e92cd291..16ad839a 100644
--- a/src/IRCServer.py
+++ b/src/IRCServer.py
@@ -1,5 +1,5 @@
-import collections, socket, ssl, sys, time
-from src import IRCChannel, IRCObject, IRCUser, utils
+import collections, socket, ssl, sys, time, typing
+from src import EventManager, IRCBot, IRCChannel, IRCObject, IRCUser, utils
THROTTLE_LINES = 4
THROTTLE_SECONDS = 1
@@ -7,8 +7,12 @@ READ_TIMEOUT_SECONDS = 120
PING_INTERVAL_SECONDS = 30
class Server(IRCObject.Object):
- def __init__(self, bot, events, id, alias, hostname, port, password,
- ipv4, tls, bindhost, nickname, username, realname):
+ def __init__(self,
+ bot: "IRCBot.Bot",
+ events: EventManager.EventHook,
+ id: int, alias: str, hostname: str, port: int, password: str,
+ ipv4: bool, tls: bool, bindhost: str,
+ nickname: str, username: str, realname: str):
self.connected = False
self.bot = bot
self.events = events
@@ -121,77 +125,80 @@ class Server(IRCObject.Object):
except:
pass
- def set_setting(self, setting, value):
+ def set_setting(self, setting: str, value: typing.Any):
self.bot.database.server_settings.set(self.id, setting,
value)
- def get_setting(self, setting, default=None):
+ def get_setting(self, setting: str, default: typing.Any=None):
return self.bot.database.server_settings.get(self.id,
setting, default)
- def find_settings(self, pattern, default=[]):
+ def find_settings(self, pattern: str, default: typing.Any=[]):
return self.bot.database.server_settings.find(self.id,
pattern, default)
- def find_settings_prefix(self, prefix, default=[]):
+ def find_settings_prefix(self, prefix: str, default: typing.Any=[]):
return self.bot.database.server_settings.find_prefix(
self.id, prefix, default)
- def del_setting(self, setting):
+ def del_setting(self, setting: str):
self.bot.database.server_settings.delete(self.id, setting)
- def get_user_setting(self, nickname, setting, default=None):
+ def get_user_setting(self, nickname: str, setting: str,
+ default: typing.Any=None):
user_id = self.get_user_id(nickname)
return self.bot.database.user_settings.get(user_id, setting, default)
- def set_user_setting(self, nickname, setting, value):
+ def set_user_setting(self, nickname: str, setting: str, value: typing.Any):
user_id = self.get_user_id(nickname)
self.bot.database.user_settings.set(user_id, setting, value)
- def get_all_user_settings(self, setting, default=[]):
+ def get_all_user_settings(self, setting: str, default: typing.Any=[]):
return self.bot.database.user_settings.find_all_by_setting(
self.id, setting, default)
- def find_all_user_channel_settings(self, setting, default=[]):
+ def find_all_user_channel_settings(self, setting: str,
+ default: typing.Any=[]):
return self.bot.database.user_channel_settings.find_all_by_setting(
self.id, setting, default)
- def set_own_nickname(self, nickname):
+ def set_own_nickname(self, nickname: str):
self.nickname = nickname
- self.nickname_lower = utils.irc.lower(self, nickname)
- def is_own_nickname(self, nickname):
+ self.nickname_lower = utils.irc.lower(self.case_mapping, nickname)
+ def is_own_nickname(self, nickname: str):
return utils.irc.equals(self, nickname, self.nickname)
- def add_own_mode(self, mode, arg=None):
+ def add_own_mode(self, mode: str, arg: str=None):
self.own_modes[mode] = arg
- def remove_own_mode(self, mode):
+ def remove_own_mode(self, mode: str):
del self.own_modes[mode]
- def change_own_mode(self, remove, mode, arg=None):
+ def change_own_mode(self, remove: bool, mode: str, arg: str=None):
if remove:
self.remove_own_mode(mode)
else:
self.add_own_mode(mode, arg)
- def has_user(self, nickname):
- return utils.irc.lower(self, nickname) in self.users
- def get_user(self, nickname, create=True):
+ def has_user(self, nickname: str):
+ return utils.irc.lower(self.case_mapping, nickname) in self.users
+ def get_user(self, nickname: str, create: bool=True):
if not self.has_user(nickname) and create:
user_id = self.get_user_id(nickname)
new_user = IRCUser.User(nickname, user_id, self, self.bot)
self.events.on("new.user").call(user=new_user, server=self)
self.users[new_user.nickname_lower] = new_user
self.new_users.add(new_user)
- return self.users.get(utils.irc.lower(self, nickname), None)
- def get_user_id(self, nickname):
+ return self.users.get(utils.irc.lower(self.case_mapping, nickname),
+ None)
+ def get_user_id(self, nickname: str):
self.bot.database.users.add(self.id, nickname)
return self.bot.database.users.get_id(self.id, nickname)
- def remove_user(self, user):
+ def remove_user(self, user: IRCUser.User):
del self.users[user.nickname_lower]
for channel in user.channels:
channel.remove_user(user)
- def change_user_nickname(self, old_nickname, new_nickname):
- user = self.users.pop(utils.irc.lower(self, old_nickname))
+ def change_user_nickname(self, old_nickname: str, new_nickname: str):
+ user = self.users.pop(utils.irc.lower(self.case_mapping, old_nickname))
user._id = self.get_user_id(new_nickname)
- self.users[utils.irc.lower(self, new_nickname)] = user
- def has_channel(self, channel_name):
+ self.users[utils.irc.lower(self.case_mapping, new_nickname)] = user
+ def has_channel(self, channel_name: str):
return channel_name[0] in self.channel_types and utils.irc.lower(
- self, channel_name) in self.channels
- def get_channel(self, channel_name):
+ self.case_mapping, channel_name) in self.channels
+ def get_channel(self, channel_name: str):
if not self.has_channel(channel_name):
channel_id = self.get_channel_id(channel_name)
new_channel = IRCChannel.Channel(channel_name, channel_id,
@@ -199,15 +206,15 @@ class Server(IRCObject.Object):
self.events.on("new.channel").call(channel=new_channel,
server=self)
self.channels[new_channel.name] = new_channel
- return self.channels[utils.irc.lower(self, channel_name)]
- def get_channel_id(self, channel_name):
+ return self.channels[utils.irc.lower(self.case_mapping, channel_name)]
+ def get_channel_id(self, channel_name: str):
self.bot.database.channels.add(self.id, channel_name)
return self.bot.database.channels.get_id(self.id, channel_name)
- def remove_channel(self, channel):
+ def remove_channel(self, channel: IRCChannel.Channel):
for user in channel.users:
user.part_channel(channel)
del self.channels[channel.name]
- def parse_data(self, line):
+ def parse_data(self, line: str):
if not line:
return
self.events.on("raw").call_unsafe(server=self, line=line)
@@ -271,7 +278,7 @@ class Server(IRCObject.Object):
def read_timed_out(self):
return self.until_read_timeout == 0
- def send(self, data):
+ def send(self, data: str):
returned = self.events.on("preprocess.send").call_unsafe_for_result(
server=self, line=data)
line = returned or data
@@ -314,16 +321,16 @@ class Server(IRCObject.Object):
time_left = time_left-now
return time_left
- def send_user(self, username, realname):
+ def send_user(self, username: str, realname: str):
self.send("USER %s 0 * :%s" % (username, realname))
- def send_nick(self, nickname):
+ def send_nick(self, nickname: str):
self.send("NICK %s" % nickname)
def send_capibility_ls(self):
self.send("CAP LS 302")
- def queue_capability(self, capability):
+ def queue_capability(self, capability: str):
self._capability_queue.add(capability)
- def queue_capabilities(self, capabilities):
+ def queue_capabilities(self, capabilities: typing.List[str]):
self._capability_queue.update(capabilities)
def send_capability_queue(self):
if self.has_capability_queue():
@@ -332,46 +339,46 @@ class Server(IRCObject.Object):
self.send_capability_request(capabilities)
def has_capability_queue(self):
return bool(len(self._capability_queue))
- def send_capability_request(self, capability):
+ def send_capability_request(self, capability: str):
self.send("CAP REQ :%s" % capability)
def send_capability_end(self):
self.send("CAP END")
- def send_authenticate(self, text):
+ def send_authenticate(self, text: str):
self.send("AUTHENTICATE %s" % text)
def send_starttls(self):
self.send("STARTTLS")
def waiting_for_capabilities(self):
return bool(len(self._capabilities_waiting))
- def wait_for_capability(self, capability):
+ def wait_for_capability(self, capability: str):
self._capabilities_waiting.add(capability)
- def capability_done(self, capability):
+ def capability_done(self, capability: str):
self._capabilities_waiting.remove(capability)
if not self._capabilities_waiting:
self.send_capability_end()
- def send_pass(self, password):
+ def send_pass(self, password: str):
self.send("PASS %s" % password)
- def send_ping(self, nonce="hello"):
+ def send_ping(self, nonce: str="hello"):
self.send("PING :%s" % nonce)
- def send_pong(self, nonce="hello"):
+ def send_pong(self, nonce: str="hello"):
self.send("PONG :%s" % nonce)
- def try_rejoin(self, event):
+ def try_rejoin(self, event: EventManager.Event):
if event["server_id"] == self.id and event["channel_name"
] in self.attempted_join:
self.send_join(event["channel_name"], event["key"])
- def send_join(self, channel_name, key=None):
+ def send_join(self, channel_name: str, key: str=None):
self.send("JOIN %s%s" % (channel_name,
"" if key == None else " %s" % key))
- def send_part(self, channel_name, reason=None):
+ def send_part(self, channel_name: str, reason: str=None):
self.send("PART %s%s" % (channel_name,
"" if reason == None else " %s" % reason))
- def send_quit(self, reason="Leaving"):
+ def send_quit(self, reason: str="Leaving"):
self.send("QUIT :%s" % reason)
- def _tag_str(self, tags):
+ def _tag_str(self, tags: dict):
tag_str = ""
for tag, value in tags.items():
if tag_str:
@@ -383,7 +390,8 @@ class Server(IRCObject.Object):
tag_str = "@%s " % tag_str
return tag_str
- def send_message(self, target, message, prefix=None, tags={}):
+ def send_message(self, target: str, message: str, prefix: str=None,
+ tags: dict={}):
full_message = message if not prefix else prefix+message
self.send("%sPRIVMSG %s :%s" % (self._tag_str(tags), target,
full_message))
@@ -408,7 +416,8 @@ class Server(IRCObject.Object):
message=full_message, message_split=full_message_split,
user=user, action=action, server=self)
- def send_notice(self, target, message, prefix=None, tags={}):
+ def send_notice(self, target: str, message: str, prefix: str=None,
+ tags: dict={}):
full_message = message if not prefix else prefix+message
self.send("%sNOTICE %s :%s" % (self._tag_str(tags), target,
full_message))
@@ -419,31 +428,31 @@ class Server(IRCObject.Object):
self.get_user(target).buffer.add_notice(None, message, tags,
True)
- def send_mode(self, target, mode=None, args=None):
+ def send_mode(self, target: str, mode: str=None, args: str=None):
self.send("MODE %s%s%s" % (target, "" if mode == None else " %s" % mode,
"" if args == None else " %s" % args))
- def send_topic(self, channel_name, topic):
+ def send_topic(self, channel_name: str, topic: str):
self.send("TOPIC %s :%s" % (channel_name, topic))
- def send_kick(self, channel_name, target, reason=None):
+ def send_kick(self, channel_name: str, target: str, reason: str=None):
self.send("KICK %s %s%s" % (channel_name, target,
"" if reason == None else " :%s" % reason))
- def send_names(self, channel_name):
+ def send_names(self, channel_name: str):
self.send("NAMES %s" % channel_name)
- def send_list(self, search_for=None):
+ def send_list(self, search_for: str=None):
self.send(
"LIST%s" % "" if search_for == None else " %s" % search_for)
- def send_invite(self, target, channel_name):
+ def send_invite(self, target: str, channel_name: str):
self.send("INVITE %s %s" % (target, channel_name))
- def send_whois(self, target):
+ def send_whois(self, target: str):
self.send("WHOIS %s" % target)
- def send_whowas(self, target, amount=None, server=None):
+ def send_whowas(self, target: str, amount: int=None, server: str=None):
self.send("WHOWAS %s%s%s" % (target,
"" if amount == None else " %s" % amount,
"" if server == None else " :%s" % server))
- def send_who(self, filter=None):
+ def send_who(self, filter: str=None):
self.send("WHO%s" % ("" if filter == None else " %s" % filter))
- def send_whox(self, mask, filter, fields, label=None):
+ def send_whox(self, mask: str, filter: str, fields: str, label: str=None):
self.send("WHO %s %s%%%s%s" % (mask, filter, fields,
","+label if label else ""))
diff --git a/src/IRCUser.py b/src/IRCUser.py
index c0bbb862..edead0d2 100644
--- a/src/IRCUser.py
+++ b/src/IRCUser.py
@@ -1,8 +1,9 @@
-import uuid
-from src import IRCBuffer, IRCObject, utils
+import typing, uuid
+from src import IRCBot, IRCChannel, IRCBuffer, IRCObject, IRCServer, utils
class User(IRCObject.Object):
- def __init__(self, nickname, id, server, bot):
+ def __init__(self, nickname: str, id: int, server: "IRCServer.Server",
+ bot: "IRCBot.Bot"):
self.server = server
self.set_nickname(nickname)
self._id = id
@@ -20,46 +21,51 @@ class User(IRCObject.Object):
self.away = False
self.buffer = IRCBuffer.Buffer(bot, server)
- def __repr__(self):
+ def __repr__(self) -> str:
return "IRCUser.User(%s|%s)" % (self.server.name, self.name)
- def __str__(self):
+ def __str__(self) -> str:
return self.nickname
- def get_id(self):
+ def get_id(self)-> int:
return (self.identified_account_id_override or
self.identified_account_id or self._id)
- def get_identified_account(self):
+ def get_identified_account(self) -> str:
return (self.identified_account_override or self.identified_account)
- def set_nickname(self, nickname):
+ def set_nickname(self, nickname: str):
self.nickname = nickname
- self.nickname_lower = utils.irc.lower(self.server, nickname)
+ self.nickname_lower = utils.irc.lower(self.server.case_mapping,
+ nickname)
self.name = self.nickname_lower
- def join_channel(self, channel):
+ def join_channel(self, channel: "IRCChannel.Channel"):
self.channels.add(channel)
- def part_channel(self, channel):
+ def part_channel(self, channel: "IRCChannel.Channel"):
self.channels.remove(channel)
- def set_setting(self, setting, value):
+
+ def set_setting(self, setting: str, value: typing.Any):
self.bot.database.user_settings.set(self.get_id(), setting, value)
- def get_setting(self, setting, default=None):
+ def get_setting(self, setting: str, default: typing.Any=None) -> typing.Any:
return self.bot.database.user_settings.get(self.get_id(), setting,
default)
- def find_settings(self, pattern, default=[]):
+ def find_settings(self, pattern: str, default: typing.Any=[]
+ ) -> typing.List[typing.Any]:
return self.bot.database.user_settings.find(self.get_id(), pattern,
default)
- def find_settings_prefix(self, prefix, default=[]):
+ def find_settings_prefix(self, prefix: str, default: typing.Any=[]
+ ) -> typing.List[typing.Any]:
return self.bot.database.user_settings.find_prefix(self.get_id(),
prefix, default)
def del_setting(self, setting):
self.bot.database.user_settings.delete(self.get_id(), setting)
- def get_channel_settings_per_setting(self, setting, default=[]):
+ def get_channel_settings_per_setting(self, setting: str,
+ default: typing.Any=[]) -> typing.List[typing.Any]:
return self.bot.database.user_channel_settings.find_by_setting(
self.get_id(), setting, default)
- def send_message(self, message, prefix=None, tags={}):
+ def send_message(self, message: str, prefix: str=None, tags: dict={}):
self.server.send_message(self.nickname, message, prefix=prefix,
tags=tags)
- def send_notice(self, text, prefix=None, tags={}):
+ def send_notice(self, text: str, prefix: str=None, tags: dict={}):
self.server.send_notice(self.nickname, text, prefix=prefix, tags=tags)
- def send_ctcp_response(self, command, args):
+ def send_ctcp_response(self, command: str, args: str):
self.send_notice("\x01%s %s\x01" % (command, args))
diff --git a/src/Logging.py b/src/Logging.py
index 6ea6efe8..d5c42e56 100644
--- a/src/Logging.py
+++ b/src/Logging.py
@@ -1,4 +1,4 @@
-import logging, logging.handlers, os, sys, time
+import logging, logging.handlers, os, sys, time, typing
LEVELS = {
"trace": logging.DEBUG-1,
@@ -23,7 +23,7 @@ class BitBotFormatter(logging.Formatter):
return s
class Log(object):
- def __init__(self, level, location):
+ def __init__(self, level: str, location: str):
logging.addLevelName(LEVELS["trace"], "TRACE")
self.logger = logging.getLogger(__name__)
@@ -49,17 +49,17 @@ class Log(object):
file_handler.setFormatter(formatter)
self.logger.addHandler(file_handler)
- def trace(self, message, params, **kwargs):
+ def trace(self, message: str, params: typing.List, **kwargs):
self._log(message, params, LEVELS["trace"], kwargs)
- def debug(self, message, params, **kwargs):
+ def debug(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.DEBUG, kwargs)
- def info(self, message, params, **kwargs):
+ def info(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.INFO, kwargs)
- def warn(self, message, params, **kwargs):
+ def warn(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.WARN, kwargs)
- def error(self, message, params, **kwargs):
+ def error(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.ERROR, kwargs)
- def critical(self, message, params, **kwargs):
+ def critical(self, message: str, params: typing.List, **kwargs):
self._log(message, params, logging.CRITICAL, kwargs)
- def _log(self, message, params, level, kwargs):
+ def _log(self, message: str, params: typing.List, level: int, kwargs: dict):
self.logger.log(level, message, *params, **kwargs)
diff --git a/src/ModuleManager.py b/src/ModuleManager.py
index 8f182c8d..f8ddd077 100644
--- a/src/ModuleManager.py
+++ b/src/ModuleManager.py
@@ -1,8 +1,5 @@
-import gc, glob, imp, io, inspect, os, sys, uuid
-from . import utils
-
-BITBOT_HOOKS_MAGIC = "__bitbot_hooks"
-BITBOT_EXPORTS_MAGIC = "__bitbot_exports"
+import gc, glob, imp, io, inspect, os, sys, typing, uuid
+from src import Config, EventManager, Exports, IRCBot, Logging, Timers, utils
class ModuleException(Exception):
pass
@@ -22,7 +19,11 @@ class ModuleNotLoadedWarning(ModuleWarning):
pass
class BaseModule(object):
- def __init__(self, bot, events, exports, timers):
+ def __init__(self,
+ bot: "IRCBot.Bot",
+ events: EventManager.EventHook,
+ exports: Exports.Exports,
+ timers: Timers.Timers):
self.bot = bot
self.events = events
self.exports = exports
@@ -32,7 +33,13 @@ class BaseModule(object):
pass
class ModuleManager(object):
- def __init__(self, events, exports, timers, config, log, directory):
+ def __init__(self,
+ events: EventManager.EventHook,
+ exports: Exports.Exports,
+ timers: Timers.Timers,
+ config: Config.Config,
+ log: Logging.Log,
+ directory: str):
self.events = events
self.exports = exports
self.config = config
@@ -43,23 +50,24 @@ class ModuleManager(object):
self.modules = {}
self.waiting_requirement = {}
- def list_modules(self):
+ def list_modules(self) -> typing.List[str]:
return sorted(glob.glob(os.path.join(self.directory, "*.py")))
- def _module_name(self, path):
+ def _module_name(self, path: str) -> str:
return os.path.basename(path).rsplit(".py", 1)[0].lower()
- def _module_path(self, name):
+ def _module_path(self, name: str) -> str:
return os.path.join(self.directory, "%s.py" % name)
- def _import_name(self, name):
+ def _import_name(self, name: str) -> str:
return "bitbot_%s" % name
- def _get_magic(self, obj, magic, default):
+ def _get_magic(self, obj: typing.Any, magic: str, default: typing.Any
+ ) -> typing.Any:
return getattr(obj, magic) if hasattr(obj, magic) else default
- def _load_module(self, bot, name):
+ def _load_module(self, bot: "IRCBot.Bot", name: str):
path = self._module_path(name)
- for hashflag, value in utils.get_hashflags(path):
+ for hashflag, value in utils.parse.hashflags(path):
if hashflag == "ignore":
# nope, ignore this module.
raise ModuleNotLoadedWarning("module ignored")
@@ -97,10 +105,12 @@ class ModuleManager(object):
module_object._name = name.title()
for attribute_name in dir(module_object):
attribute = getattr(module_object, attribute_name)
- for hook in self._get_magic(attribute, BITBOT_HOOKS_MAGIC, []):
+ for hook in self._get_magic(attribute,
+ utils.consts.BITBOT_HOOKS_MAGIC, []):
context_events.on(hook["event"]).hook(attribute,
**hook["kwargs"])
- for export in self._get_magic(module_object, BITBOT_EXPORTS_MAGIC, []):
+ for export in self._get_magic(module_object,
+ utils.consts.BITBOT_EXPORTS_MAGIC, []):
context_exports.add(export["setting"], export["value"])
module_object._context = context
@@ -111,7 +121,7 @@ class ModuleManager(object):
"attempted to be used twice")
return module_object
- def load_module(self, bot, name):
+ def load_module(self, bot: "IRCBot.Bot", name: str):
try:
module = self._load_module(bot, name)
except ModuleWarning as warning:
@@ -128,7 +138,8 @@ class ModuleManager(object):
self.load_module(bot, requirement_name)
self.log.info("Module '%s' loaded", [name])
- def load_modules(self, bot, whitelist=[], blacklist=[]):
+ def load_modules(self, bot: "IRCBot.Bot", whitelist: typing.List[str]=[],
+ blacklist: typing.List[str]=[]):
for path in self.list_modules():
name = self._module_name(path)
if name in whitelist or (not whitelist and not name in blacklist):
@@ -137,7 +148,7 @@ class ModuleManager(object):
except ModuleWarning:
pass
- def unload_module(self, name):
+ def unload_module(self, name: str):
if not name in self.modules:
raise ModuleNotFoundException()
module = self.modules[name]
diff --git a/src/Socket.py b/src/Socket.py
index f405b0a9..474336d0 100644
--- a/src/Socket.py
+++ b/src/Socket.py
@@ -1,7 +1,9 @@
-
+import socket, typing
class Socket(object):
- def __init__(self, socket, on_read, encoding="utf8"):
+ def __init__(self, socket: socket.socket,
+ on_read: typing.Callable[["Socket", str], None],
+ encoding: str="utf8"):
self.socket = socket
self._on_read = on_read
self.encoding = encoding
@@ -12,18 +14,18 @@ class Socket(object):
self.length = None
self.connected = True
- def fileno(self):
+ def fileno(self) -> int:
return self.socket.fileno()
def disconnect(self):
self.connected = False
- def _decode(self, s):
- return s.decode(self.encoding) if self.encoding else s
- def _encode(self, s):
- return s.encode(self.encoding) if self.encoding else s
+ def _decode(self, s: bytes) -> str:
+ return s.decode(self.encoding)
+ def _encode(self, s: str) -> bytes:
+ return s.encode(self.encoding)
- def read(self):
+ def read(self) -> typing.Optional[typing.List[str]]:
data = self.socket.recv(1024)
if not data:
return None
@@ -35,17 +37,17 @@ class Socket(object):
if data_split[-1]:
self._read_buffer = data_split.pop(-1)
return [self._decode(data) for data in data_split]
- return [data.decode(self.encoding)]
+ return [self._decode(data)]
- def parse_data(self, data):
+ def parse_data(self, data: str):
self._on_read(self, data)
- def send(self, data):
+ def send(self, data: str):
self._write_buffer += self._encode(data)
def _send(self):
self._write_buffer = self._write_buffer[self.socket.send(
self._write_buffer):]
- def waiting_send(self):
+ def waiting_send(self) -> bool:
return bool(len(self._write_buffer))
diff --git a/src/Timers.py b/src/Timers.py
index e49e7edc..c3336d87 100644
--- a/src/Timers.py
+++ b/src/Timers.py
@@ -1,7 +1,9 @@
-import time, uuid
+import time, typing, uuid
+from src import Database, EventManager, Logging
class Timer(object):
- def __init__(self, id, context, name, delay, next_due, kwargs):
+ def __init__(self, id: int, context: str, name: str, delay: float,
+ next_due: float, kwargs: dict):
self.id = id
self.context = context
self.name = name
@@ -15,9 +17,9 @@ class Timer(object):
def set_next_due(self):
self.next_due = time.time()+self.delay
- def due(self):
+ def due(self) -> bool:
return self.time_left() <= 0
- def time_left(self):
+ def time_left(self) -> float:
return self.next_due-time.time()
def redo(self):
@@ -25,42 +27,33 @@ class Timer(object):
self.set_next_due()
def finish(self):
self._done = True
- def done(self):
+ def done(self) -> bool:
return self._done
-class TimersContext(object):
- def __init__(self, parent, context):
- self._parent = parent
- self.context = context
- def add(self, name, delay, next_due=None, **kwargs):
- self._parent._add(self.context, name, delay, next_due, None, False,
- kwargs)
- def add_persistent(self, name, delay, next_due=None, **kwargs):
- self._parent._add(None, name, delay, next_due, None, True,
- kwargs)
-
class Timers(object):
- def __init__(self, database, events, log):
+ def __init__(self, database: Database.Database,
+ events: EventManager.EventHook,
+ log: Logging.Log):
self.database = database
self.events = events
self.log = log
self.timers = []
self.context_timers = {}
- def new_context(self, context):
+ def new_context(self, context: str) -> "TimersContext":
return TimersContext(self, context)
- def setup(self, timers):
+ def setup(self, timers: typing.List[typing.Tuple[str, dict]]):
for name, timer in timers:
id = name.split("timer-", 1)[1]
self._add(timer["name"], None, timer["delay"], timer[
"next-due"], id, False, timer["kwargs"])
- def _persist(self, timer):
+ def _persist(self, timer: Timer):
self.database.bot_settings.set("timer-%s" % timer.id, {
"name": timer.name, "delay": timer.delay,
"next-due": timer.next_due, "kwargs": timer.kwargs})
- def _remove(self, timer):
+ def _remove(self, timer: Timer):
if timer.context:
self.context_timers[timer.context].remove(timer)
if not self.context_timers[timer.context]:
@@ -69,11 +62,13 @@ class Timers(object):
self.timers.remove(timer)
self.database.bot_settings.delete("timer-%s" % timer.id)
- def add(self, name, delay, next_due=None, **kwargs):
+ def add(self, name: str, delay: float, next_due: float=None, **kwargs):
self._add(None, name, delay, next_due, None, False, kwargs)
- def add_persistent(self, name, delay, next_due=None, **kwargs):
+ def add_persistent(self, name: str, delay: float, next_due: float=None,
+ **kwargs):
self._add(None, name, delay, next_due, None, True, kwargs)
- def _add(self, context, name, delay, next_due, id, persist, kwargs):
+ def _add(self, context: str, name: str, delay: float, next_due: float,
+ id: str, persist: bool, kwargs: dict):
id = id or uuid.uuid4().hex
timer = Timer(id, context, name, delay, next_due, kwargs)
if persist:
@@ -86,13 +81,13 @@ class Timers(object):
else:
self.timers.append(timer)
- def next(self):
+ def next(self) -> float:
times = filter(None, [timer.time_left() for timer in self.get_timers()])
if not times:
return None
return max(min(times), 0)
- def get_timers(self):
+ def get_timers(self) -> typing.List[Timer]:
return self.timers + sum(self.context_timers.values(), [])
def call(self):
@@ -104,6 +99,19 @@ class Timers(object):
if timer.done():
self._remove(timer)
- def purge_context(self, context):
+ def purge_context(self, context: str):
if context in self.context_timers:
del self.context_timers[context]
+
+class TimersContext(object):
+ def __init__(self, parent: Timers, context: str):
+ self._parent = parent
+ self.context = context
+ def add(self, name: str, delay: float, next_due: float=None,
+ **kwargs):
+ self._parent._add(self.context, name, delay, next_due, None, False,
+ kwargs)
+ def add_persistent(self, name: str, delay: float, next_due: float=None,
+ **kwargs):
+ self._parent._add(None, name, delay, next_due, None, True,
+ kwargs)
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
index 87671108..52de64ba 100644
--- a/src/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -1,6 +1,5 @@
-import decimal, io, re
-from src import ModuleManager
-from . import irc, http
+import decimal, io, re, typing
+from src.utils import consts, irc, http, parse
TIME_SECOND = 1
TIME_MINUTE = TIME_SECOND*60
@@ -8,7 +7,7 @@ TIME_HOUR = TIME_MINUTE*60
TIME_DAY = TIME_HOUR*24
TIME_WEEK = TIME_DAY*7
-def time_unit(seconds):
+def time_unit(seconds: int) -> typing.Tuple[int, str]:
since = None
unit = None
if seconds >= TIME_WEEK:
@@ -29,7 +28,7 @@ def time_unit(seconds):
since = int(since)
if since > 1:
unit = "%ss" % unit # pluralise the unit
- return [since, unit]
+ return (since, unit)
REGEX_PRETTYTIME = re.compile("\d+[wdhms]", re.I)
@@ -38,7 +37,7 @@ SECONDS_HOURS = SECONDS_MINUTES*60
SECONDS_DAYS = SECONDS_HOURS*24
SECONDS_WEEKS = SECONDS_DAYS*7
-def from_pretty_time(pretty_time):
+def from_pretty_time(pretty_time: str) -> typing.Optional[int]:
seconds = 0
for match in re.findall(REGEX_PRETTYTIME, pretty_time):
number, unit = int(match[:-1]), match[-1].lower()
@@ -54,12 +53,14 @@ def from_pretty_time(pretty_time):
if seconds > 0:
return seconds
+UNIT_MINIMUM = 6
UNIT_SECOND = 5
UNIT_MINUTE = 4
UNIT_HOUR = 3
UNIT_DAY = 2
UNIT_WEEK = 1
-def to_pretty_time(total_seconds, minimum_unit=UNIT_SECOND, max_units=6):
+def to_pretty_time(total_seconds: int, minimum_unit: int=UNIT_SECOND,
+ max_units: int=UNIT_MINIMUM) -> str:
minutes, seconds = divmod(total_seconds, 60)
hours, minutes = divmod(minutes, 60)
days, hours = divmod(hours, 24)
@@ -84,7 +85,7 @@ def to_pretty_time(total_seconds, minimum_unit=UNIT_SECOND, max_units=6):
units += 1
return out
-def parse_number(s):
+def parse_number(s: str) -> str:
try:
decimal.Decimal(s)
return s
@@ -110,28 +111,18 @@ def parse_number(s):
IS_TRUE = ["true", "yes", "on", "y"]
IS_FALSE = ["false", "no", "off", "n"]
-def bool_or_none(s):
+def bool_or_none(s: str) -> typing.Optional[bool]:
s = s.lower()
if s in IS_TRUE:
return True
elif s in IS_FALSE:
return False
-def int_or_none(s):
+def int_or_none(s: str) -> typing.Optional[int]:
stripped_s = s.lstrip("0")
if stripped_s.isdigit():
return int(stripped_s)
-def get_closest_setting(event, setting, default=None):
- server = event["server"]
- if "channel" in event:
- closest = event["channel"]
- elif "target" in event and "is_channel" in event and event["is_channel"]:
- closest = event["target"]
- else:
- closest = event["user"]
- return closest.get_setting(setting, server.get_setting(setting, default))
-
-def prevent_highlight(nickname):
+def prevent_highlight(nickname: str) -> str:
return nickname[0]+"\u200c"+nickname[1:]
class EventError(Exception):
@@ -139,86 +130,40 @@ class EventError(Exception):
class EventsResultsError(EventError):
def __init__(self):
EventError.__init__(self, "Failed to load results")
+class EventsNotEnoughArgsError(EventError):
+ def __init__(self, n):
+ EventError.__init__(self, "Not enough arguments (minimum %d)" % n)
+class EventsUsageError(EventError):
+ def __init__(self, usage):
+ EventError.__init__(self, "Not enough arguments, usage: %s" % usage)
-def _set_get_append(obj, setting, item):
+def _set_get_append(obj: typing.Any, setting: str, item: typing.Any):
if not hasattr(obj, setting):
setattr(obj, setting, [])
getattr(obj, setting).append(item)
-def hook(event, **kwargs):
+def hook(event: str, **kwargs):
def _hook_func(func):
- _set_get_append(func, ModuleManager.BITBOT_HOOKS_MAGIC,
+ _set_get_append(func, consts.BITBOT_HOOKS_MAGIC,
{"event": event, "kwargs": kwargs})
return func
return _hook_func
-def export(setting, value):
+def export(setting: str, value: typing.Any):
def _export_func(module):
- _set_get_append(module, ModuleManager.BITBOT_EXPORTS_MAGIC,
+ _set_get_append(module, consts.BITBOT_EXPORTS_MAGIC,
{"setting": setting, "value": value})
return module
return _export_func
-COMMENT_TYPES = ["#", "//"]
-def get_hashflags(filename):
- hashflags = {}
- with io.open(filename, mode="r", encoding="utf8") as f:
- for line in f:
- line = line.strip("\n")
- found = False
- for comment_type in COMMENT_TYPES:
- if line.startswith(comment_type):
- line = line.replace(comment_type, "", 1).lstrip()
- found = True
- break
-
- if not found:
- break
- elif line.startswith("--"):
- hashflag, sep, value = line[2:].partition(" ")
- hashflags[hashflag] = value if sep else None
- return hashflags.items()
-
-class Docstring(object):
- def __init__(self, description, items, var_items):
- self.description = description
- self.items = items
- self.var_items = var_items
-
-def parse_docstring(s):
- description = ""
- last_item = None
- items = {}
- var_items = {}
- if s:
- for line in s.split("\n"):
- line = line.strip()
-
- if line:
- if line[0] == ":":
- key, _, value = line[1:].partition(": ")
- last_item = key
-
- if key in var_items:
- var_items[key].append(value)
- elif key in items:
- var_items[key] = [items.pop(key), value]
- else:
- items[key] = value
- else:
- if last_item:
- items[last_item] += " %s" % line
- else:
- if description:
- description += " "
- description += line
- return Docstring(description, items, var_items)
-
-def top_10(items, convert_key=lambda x: x, value_format=lambda x: x):
- top_10 = sorted(items.keys())
- top_10 = sorted(top_10, key=items.get, reverse=True)[:10]
+TOP_10_CALLABLE = typing.Callable[[typing.Any], typing.Any]
+def top_10(items: typing.List[typing.Any],
+ convert_key: TOP_10_CALLABLE=lambda x: x,
+ value_format: TOP_10_CALLABLE=lambda x: x):
+ top_10 = sorted(items.keys())
+ top_10 = sorted(top_10, key=items.get, reverse=True)[:10]
- top_10_items = []
- for key in top_10:
- top_10_items.append("%s (%s)" % (convert_key(key),
- value_format(items[key])))
+ top_10_items = []
+ for key in top_10:
+ top_10_items.append("%s (%s)" % (convert_key(key),
+ value_format(items[key])))
- return top_10_items
+ return top_10_items
diff --git a/src/utils/consts.py b/src/utils/consts.py
new file mode 100644
index 00000000..d2816509
--- /dev/null
+++ b/src/utils/consts.py
@@ -0,0 +1,2 @@
+BITBOT_HOOKS_MAGIC = "__bitbot_hooks"
+BITBOT_EXPORTS_MAGIC = "__bitbot_exports"
diff --git a/src/utils/http.py b/src/utils/http.py
index ddf88b2b..b949e9ff 100644
--- a/src/utils/http.py
+++ b/src/utils/http.py
@@ -1,4 +1,4 @@
-import re, signal, traceback, urllib.error, urllib.parse
+import re, signal, traceback, typing, urllib.error, urllib.parse
import json as _json
import bs4, requests
@@ -18,9 +18,10 @@ class HTTPParsingException(HTTPException):
def throw_timeout():
raise HTTPTimeoutException()
-def get_url(url, method="GET", get_params={}, post_data=None, headers={},
- json_data=None, code=False, json=False, soup=False, parser="lxml",
- fallback_encoding="utf8"):
+def get_url(url: str, method: str="GET", get_params: dict={},
+ post_data: typing.Any=None, headers: dict={},
+ json_data: typing.Any=None, code: bool=False, json: bool=False,
+ soup: bool=False, parser: str="lxml", fallback_encoding: str="utf8"):
if not urllib.parse.urlparse(url).scheme:
url = "http://%s" % url
@@ -66,6 +67,6 @@ def get_url(url, method="GET", get_params={}, post_data=None, headers={},
else:
return data
-def strip_html(s):
+def strip_html(s: str) -> str:
return bs4.BeautifulSoup(s, "lxml").get_text()
diff --git a/src/utils/irc.py b/src/utils/irc.py
index 792de7f3..3e7f8b76 100644
--- a/src/utils/irc.py
+++ b/src/utils/irc.py
@@ -1,4 +1,4 @@
-import string, re
+import string, re, typing
ASCII_UPPER = string.ascii_uppercase
ASCII_LOWER = string.ascii_lowercase
@@ -7,32 +7,36 @@ STRICT_RFC1459_LOWER = ASCII_LOWER+r'|{}'
RFC1459_UPPER = STRICT_RFC1459_UPPER+"^"
RFC1459_LOWER = STRICT_RFC1459_LOWER+"~"
-def remove_colon(s):
+def remove_colon(s: str) -> str:
if s.startswith(":"):
s = s[1:]
return s
+MULTI_REPLACE_ITERABLE = typing.Iterable[str]
# case mapping lowercase/uppcase logic
-def _multi_replace(s, chars1, chars2):
+def _multi_replace(s: str,
+ chars1: typing.Iterable[str],
+ chars2: typing.Iterable[str]) -> str:
for char1, char2 in zip(chars1, chars2):
s = s.replace(char1, char2)
return s
-def lower(server, s):
- if server.case_mapping == "ascii":
+def lower(case_mapping: str, s: str) -> str:
+ if case_mapping == "ascii":
return _multi_replace(s, ASCII_UPPER, ASCII_LOWER)
- elif server.case_mapping == "rfc1459":
+ elif case_mapping == "rfc1459":
return _multi_replace(s, RFC1459_UPPER, RFC1459_LOWER)
- elif server.case_mapping == "strict-rfc1459":
+ elif case_mapping == "strict-rfc1459":
return _multi_replace(s, STRICT_RFC1459_UPPER, STRICT_RFC1459_LOWER)
else:
- raise ValueError("unknown casemapping '%s'" % server.case_mapping)
+ raise ValueError("unknown casemapping '%s'" % case_mapping)
# compare a string while respecting case mapping
-def equals(server, s1, s2):
- return lower(server, s1) == lower(server, s2)
+def equals(case_mapping: str, s1: str, s2: str) -> bool:
+ return lower(case_mapping, s1) == lower(case_mapping, s2)
class IRCHostmask(object):
- def __init__(self, nickname, username, hostname, hostmask):
+ def __init__(self, nickname: str, username: str, hostname: str,
+ hostmask: str):
self.nickname = nickname
self.username = username
self.hostname = hostname
@@ -42,24 +46,24 @@ class IRCHostmask(object):
def __str__(self):
return self.hostmask
-def seperate_hostmask(hostmask):
+def seperate_hostmask(hostmask: str) -> IRCHostmask:
hostmask = remove_colon(hostmask)
nickname, _, username = hostmask.partition("!")
username, _, hostname = username.partition("@")
return IRCHostmask(nickname, username, hostname, hostmask)
-
class IRCLine(object):
- def __init__(self, tags, prefix, command, args, arbitrary, last, server):
+ def __init__(self, tags: dict, prefix: str, command: str,
+ args: typing.List[str], arbitrary: typing.Optional[str],
+ last: str):
self.tags = tags
self.prefix = prefix
self.command = command
self.args = args
self.arbitrary = arbitrary
self.last = last
- self.server = server
-def parse_line(server, line):
+def parse_line(line: str) -> IRCLine:
tags = {}
prefix = None
command = None
@@ -81,7 +85,7 @@ def parse_line(server, line):
args = line.split(" ")
last = arbitrary or args[-1]
- return IRCLine(tags, prefix, command, args, arbitrary, last, server)
+ return IRCLine(tags, prefix, command, args, arbitrary, last)
COLOR_WHITE, COLOR_BLACK, COLOR_BLUE, COLOR_GREEN = 0, 1, 2, 3
COLOR_RED, COLOR_BROWN, COLOR_PURPLE, COLOR_ORANGE = 4, 5, 6, 7
@@ -94,20 +98,20 @@ FONT_BOLD, FONT_ITALIC, FONT_UNDERLINE, FONT_INVERT = ("\x02", "\x1D",
FONT_COLOR, FONT_RESET = "\x03", "\x0F"
REGEX_COLOR = re.compile("%s\d\d(?:,\d\d)?" % FONT_COLOR)
-def color(s, foreground, background=None):
+def color(s: str, foreground: str, background: str=None) -> str:
foreground = str(foreground).zfill(2)
if background:
background = str(background).zfill(2)
return "%s%s%s%s%s" % (FONT_COLOR, foreground,
"" if not background else ",%s" % background, s, FONT_COLOR)
-def bold(s):
+def bold(s: str) -> str:
return "%s%s%s" % (FONT_BOLD, s, FONT_BOLD)
-def underline(s):
+def underline(s: str) -> str:
return "%s%s%s" % (FONT_UNDERLINE, s, FONT_UNDERLINE)
-def strip_font(s):
+def strip_font(s: str) -> str:
s = s.replace(FONT_BOLD, "")
s = s.replace(FONT_ITALIC, "")
s = REGEX_COLOR.sub("", s)
diff --git a/src/utils/parse.py b/src/utils/parse.py
new file mode 100644
index 00000000..03a585c2
--- /dev/null
+++ b/src/utils/parse.py
@@ -0,0 +1,57 @@
+import io, typing
+
+COMMENT_TYPES = ["#", "//"]
+def hashflags(filename: str) -> typing.List[typing.Tuple[str, str]]:
+ hashflags = {}
+ with io.open(filename, mode="r", encoding="utf8") as f:
+ for line in f:
+ line = line.strip("\n")
+ found = False
+ for comment_type in COMMENT_TYPES:
+ if line.startswith(comment_type):
+ line = line.replace(comment_type, "", 1).lstrip()
+ found = True
+ break
+
+ if not found:
+ break
+ elif line.startswith("--"):
+ hashflag, sep, value = line[2:].partition(" ")
+ hashflags[hashflag] = value if sep else None
+ return list(hashflags.items())
+
+class Docstring(object):
+ def __init__(self, description: str, items: dict, var_items: dict):
+ self.description = description
+ self.items = items
+ self.var_items = var_items
+
+def docstring(s: str) -> Docstring:
+ description = ""
+ last_item = None
+ items = {}
+ var_items = {}
+ if s:
+ for line in s.split("\n"):
+ line = line.strip()
+
+ if line:
+ if line[0] == ":":
+ key, _, value = line[1:].partition(": ")
+ last_item = key
+
+ if key in var_items:
+ var_items[key].append(value)
+ elif key in items:
+ var_items[key] = [items.pop(key), value]
+ else:
+ items[key] = value
+ else:
+ if last_item:
+ items[last_item] += " %s" % line
+ else:
+ if description:
+ description += " "
+ description += line
+ return Docstring(description, items, var_items)
+