diff options
Diffstat (limited to 'src/Database.py')
| -rw-r--r-- | src/Database.py | 101 |
1 files changed, 56 insertions, 45 deletions
diff --git a/src/Database.py b/src/Database.py index c1d07869..c3d48cb6 100644 --- a/src/Database.py +++ b/src/Database.py @@ -1,12 +1,14 @@ -import json, os, sqlite3, threading, time +import json, os, sqlite3, threading, time, typing +from src import Logging class Table(object): def __init__(self, database): self.database = database class Servers(Table): - def add(self, alias, hostname, port, password, ipv4, tls, bindhost, - nickname, username=None, realname=None): + def add(self, alias: str, hostname: str, port: int, password: str, + ipv4: bool, tls: bool, bindhost: str, + nickname: str, username: str=None, realname: str=None): username = username or nickname realname = realname or nickname self.database.execute( @@ -18,7 +20,7 @@ class Servers(Table): def get_all(self): return self.database.execute_fetchall( "SELECT server_id, alias FROM servers") - def get(self, id): + def get(self, id: int): return self.database.execute_fetchone( """SELECT server_id, alias, hostname, port, password, ipv4, tls, bindhost, nickname, username, realname FROM servers WHERE @@ -26,46 +28,46 @@ class Servers(Table): [id]) class Channels(Table): - def add(self, server_id, name): + def add(self, server_id: int, name: str): self.database.execute("""INSERT OR IGNORE INTO channels (server_id, name) VALUES (?, ?)""", [server_id, name.lower()]) - def delete(self, channel_id): + def delete(self, channel_id: int): self.database.execute("DELETE FROM channels WHERE channel_id=?", [channel_id]) - def get_id(self, server_id, name): + def get_id(self, server_id: int, name: str): value = self.database.execute_fetchone("""SELECT channel_id FROM channels WHERE server_id=? AND name=?""", [server_id, name.lower()]) return value if value == None else value[0] class Users(Table): - def add(self, server_id, nickname): + def add(self, server_id: int, nickname: str): self.database.execute("""INSERT OR IGNORE INTO users (server_id, nickname) VALUES (?, ?)""", [server_id, nickname.lower()]) - def delete(self, user_id): + def delete(self, user_id: int): self.database.execute("DELETE FROM users WHERE user_id=?", [user_id]) - def get_id(self, server_id, nickname): + def get_id(self, server_id: int, nickname: str): value = self.database.execute_fetchone("""SELECT user_id FROM users WHERE server_id=? and nickname=?""", [server_id, nickname.lower()]) return value if value == None else value[0] class BotSettings(Table): - def set(self, setting, value): + def set(self, setting: str, value: typing.Any): self.database.execute( "INSERT OR REPLACE INTO bot_settings VALUES (?, ?)", [setting.lower(), json.dumps(value)]) - def get(self, setting, default=None): + def get(self, setting: str, default: typing.Any=None): value = self.database.execute_fetchone( "SELECT value FROM bot_settings WHERE setting=?", [setting.lower()]) if value: return json.loads(value[0]) return default - def find(self, pattern, default=[]): + def find(self, pattern: str, default: typing.Any=[]): values = self.database.execute_fetchall( "SELECT setting, value FROM bot_settings WHERE setting LIKE ?", [pattern.lower()]) @@ -74,19 +76,19 @@ class BotSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find_prefix(self, prefix, default=[]): + def find_prefix(self, prefix: str, default: typing.Any=[]): return self.find("%s%%" % prefix, default) - def delete(self, setting): + def delete(self, setting: str): self.database.execute( "DELETE FROM bot_settings WHERE setting=?", [setting.lower()]) class ServerSettings(Table): - def set(self, server_id, setting, value): + def set(self, server_id: int, setting: str, value: typing.Any): self.database.execute( "INSERT OR REPLACE INTO server_settings VALUES (?, ?, ?)", [server_id, setting.lower(), json.dumps(value)]) - def get(self, server_id, setting, default=None): + def get(self, server_id: int, setting: str, default: typing.Any=None): value = self.database.execute_fetchone( """SELECT value FROM server_settings WHERE server_id=? AND setting=?""", @@ -94,7 +96,7 @@ class ServerSettings(Table): if value: return json.loads(value[0]) return default - def find(self, server_id, pattern, default=[]): + def find(self, server_id: int, pattern: str, default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT setting, value FROM server_settings WHERE server_id=? AND setting LIKE ?""", @@ -104,26 +106,26 @@ class ServerSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find_prefix(self, server_id, prefix, default=[]): + def find_prefix(self, server_id: int, prefix: str, default: typing.Any=[]): return self.find_server_settings(server_id, "%s%%" % prefix, default) - def delete(self, server_id, setting): + def delete(self, server_id: int, setting: str): self.database.execute( "DELETE FROM server_settings WHERE server_id=? AND setting=?", [server_id, setting.lower()]) class ChannelSettings(Table): - def set(self, channel_id, setting, value): + def set(self, channel_id: int, setting: str, value: typing.Any): self.database.execute( "INSERT OR REPLACE INTO channel_settings VALUES (?, ?, ?)", [channel_id, setting.lower(), json.dumps(value)]) - def get(self, channel_id, setting, default=None): + def get(self, channel_id: int, setting: str, default: typing.Any=None): value = self.database.execute_fetchone( """SELECT value FROM channel_settings WHERE channel_id=? AND setting=?""", [channel_id, setting.lower()]) if value: return json.loads(value[0]) return default - def find(self, channel_id, pattern, default=[]): + def find(self, channel_id: int, pattern: str, default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT setting, value FROM channel_settings WHERE channel_id=? setting LIKE '?'""", [channel_id, pattern.lower()]) @@ -132,15 +134,15 @@ class ChannelSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find_prefix(self, channel_id, prefix, default=[]): + def find_prefix(self, channel_id: int, prefix: str, default: typing.Any=[]): return self.find_channel_settings(channel_id, "%s%%" % prefix, default) - def delete(self, channel_id, setting): + def delete(self, channel_id: int, setting: str): self.database.execute( """DELETE FROM channel_settings WHERE channel_id=? AND setting=?""", [channel_id, setting.lower()]) - def find_by_setting(self, setting, default=[]): + def find_by_setting(self, setting: str, default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT channels.server_id, channels.name, channel_settings.value FROM channel_settings @@ -154,18 +156,19 @@ class ChannelSettings(Table): return default class UserSettings(Table): - def set(self, user_id, setting, value): + def set(self, user_id: int, setting: str, value: typing.Any): self.database.execute( "INSERT OR REPLACE INTO user_settings VALUES (?, ?, ?)", [user_id, setting.lower(), json.dumps(value)]) - def get(self, user_id, setting, default=None): + def get(self, user_id: int, setting: str, default: typing.Any=None): value = self.database.execute_fetchone( """SELECT value FROM user_settings WHERE user_id=? and setting=?""", [user_id, setting.lower()]) if value: return json.loads(value[0]) return default - def find_all_by_setting(self, server_id, setting, default=[]): + def find_all_by_setting(self, server_id: int, setting: str, + default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT users.nickname, user_settings.value FROM user_settings INNER JOIN users ON @@ -177,7 +180,7 @@ class UserSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find(self, user_id, pattern, default=[]): + def find(self, user_id: int, pattern: str, default: typing.Any=[]): values = self.database.execute( """SELECT setting, value FROM user_settings WHERE user_id=? AND setting LIKE '?'""", [user_id, pattern.lower()]) @@ -186,20 +189,22 @@ class UserSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find_prefix(self, user_id, prefix, default=[]): + def find_prefix(self, user_id: int, prefix: str, default: typing.Any=[]): return self.find_user_settings(user_id, "%s%%" % prefix, default) - def delete(self, user_id, setting): + def delete(self, user_id: int, setting: str): self.database.execute( """DELETE FROM user_settings WHERE user_id=? AND setting=?""", [user_id, setting.lower()]) class UserChannelSettings(Table): - def set(self, user_id, channel_id, setting, value): + def set(self, user_id: int, channel_id: int, setting: str, + value: typing.Any): self.database.execute( """INSERT OR REPLACE INTO user_channel_settings VALUES (?, ?, ?, ?)""", [user_id, channel_id, setting.lower(), json.dumps(value)]) - def get(self, user_id, channel_id, setting, default=None): + def get(self, user_id: int, channel_id: int, setting: str, + default: typing.Any=None): value = self.database.execute_fetchone( """SELECT value FROM user_channel_settings WHERE user_id=? AND channel_id=? AND setting=?""", @@ -207,7 +212,8 @@ class UserChannelSettings(Table): if value: return json.loads(value[0]) return default - def find(self, user_id, channel_id, pattern, default=[]): + def find(self, user_id: int, channel_id: int, pattern: str, + default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT setting, value FROM user_channel_settings WHERE user_id=? AND channel_id=? AND setting LIKE '?'""", @@ -217,10 +223,12 @@ class UserChannelSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find_prefix(self, user_id, channel_id, prefix, default=[]): + def find_prefix(self, user_id: int, channel_id: int, prefix: str, + default: typing.Any=[]): return self.find_user_settings(user_id, channel_id, "%s%%" % prefix, default) - def find_by_setting(self, user_id, setting, default=[]): + def find_by_setting(self, user_id: int, setting: str, + default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT channels.name, user_channel_settings.value FROM user_channel_settings INNER JOIN channels ON @@ -232,7 +240,8 @@ class UserChannelSettings(Table): values[i] = value[0], json.loads(value[1]) return values return default - def find_all_by_setting(self, server_id, setting, default=[]): + def find_all_by_setting(self, server_id: int, setting: str, + default: typing.Any=[]): values = self.database.execute_fetchall( """SELECT channels.name, users.nickname, user_channel_settings.value FROM @@ -246,14 +255,14 @@ class UserChannelSettings(Table): values[i] = value[0], value[1], json.loads(value[2]) return values return default - def delete(self, user_id, channel_id, setting): + def delete(self, user_id: int, channel_id: int, setting: str): self.database.execute( """DELETE FROM user_channel_settings WHERE user_id=? AND channel_id=? AND setting=?""", [user_id, channel_id, setting.lower()]) class Database(object): - def __init__(self, log, location): + def __init__(self, log: "Logging.Log", location: str): self.log = log self.location = location self.database = sqlite3.connect(self.location, @@ -284,7 +293,9 @@ class Database(object): self._cursor = self.database.cursor() return self._cursor - def _execute_fetch(self, query, fetch_func, params=[]): + def _execute_fetch(self, query: str, + fetch_func: typing.Callable[[sqlite3.Cursor], typing.Any], + params: typing.List=[]): printable_query = " ".join(query.split()) self.log.trace("executing query: \"%s\" (params: %s)", [printable_query, params]) @@ -299,16 +310,16 @@ class Database(object): self.log.trace("executed in %fms", [total_milliseconds]) return value - def execute_fetchall(self, query, params=[]): + def execute_fetchall(self, query: str, params: typing.List=[]): return self._execute_fetch(query, lambda cursor: cursor.fetchall(), params) - def execute_fetchone(self, query, params=[]): + def execute_fetchone(self, query: str, params: typing.List=[]): return self._execute_fetch(query, lambda cursor: cursor.fetchone(), params) - def execute(self, query, params=[]): + def execute(self, query: str, params: typing.List=[]): return self._execute_fetch(query, lambda cursor: None, params) - def has_table(self, table_name): + def has_table(self, table_name: str): result = self.execute_fetchone("""SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?""", [table_name]) |
