aboutsummaryrefslogtreecommitdiff
path: root/src/Database.py
diff options
context:
space:
mode:
authorGravatar jesopo2018-10-30 14:58:48 +0000
committerGravatar jesopo2018-10-30 14:58:48 +0000
commite07553c3627b80f20cdc81a35030bf0540924db8 (patch)
tree0a81640b280e007cbe5d2cb956681068ab80c58e /src/Database.py
parentDon't needlessly search a youtube URL before getting the information for it's (diff)
signature
Add type/return hints throughout src/ and, in doing so, fix some cyclical
references.
Diffstat (limited to 'src/Database.py')
-rw-r--r--src/Database.py101
1 files changed, 56 insertions, 45 deletions
diff --git a/src/Database.py b/src/Database.py
index c1d07869..c3d48cb6 100644
--- a/src/Database.py
+++ b/src/Database.py
@@ -1,12 +1,14 @@
-import json, os, sqlite3, threading, time
+import json, os, sqlite3, threading, time, typing
+from src import Logging
class Table(object):
def __init__(self, database):
self.database = database
class Servers(Table):
- def add(self, alias, hostname, port, password, ipv4, tls, bindhost,
- nickname, username=None, realname=None):
+ def add(self, alias: str, hostname: str, port: int, password: str,
+ ipv4: bool, tls: bool, bindhost: str,
+ nickname: str, username: str=None, realname: str=None):
username = username or nickname
realname = realname or nickname
self.database.execute(
@@ -18,7 +20,7 @@ class Servers(Table):
def get_all(self):
return self.database.execute_fetchall(
"SELECT server_id, alias FROM servers")
- def get(self, id):
+ def get(self, id: int):
return self.database.execute_fetchone(
"""SELECT server_id, alias, hostname, port, password, ipv4,
tls, bindhost, nickname, username, realname FROM servers WHERE
@@ -26,46 +28,46 @@ class Servers(Table):
[id])
class Channels(Table):
- def add(self, server_id, name):
+ def add(self, server_id: int, name: str):
self.database.execute("""INSERT OR IGNORE INTO channels
(server_id, name) VALUES (?, ?)""",
[server_id, name.lower()])
- def delete(self, channel_id):
+ def delete(self, channel_id: int):
self.database.execute("DELETE FROM channels WHERE channel_id=?",
[channel_id])
- def get_id(self, server_id, name):
+ def get_id(self, server_id: int, name: str):
value = self.database.execute_fetchone("""SELECT channel_id FROM
channels WHERE server_id=? AND name=?""",
[server_id, name.lower()])
return value if value == None else value[0]
class Users(Table):
- def add(self, server_id, nickname):
+ def add(self, server_id: int, nickname: str):
self.database.execute("""INSERT OR IGNORE INTO users
(server_id, nickname) VALUES (?, ?)""",
[server_id, nickname.lower()])
- def delete(self, user_id):
+ def delete(self, user_id: int):
self.database.execute("DELETE FROM users WHERE user_id=?",
[user_id])
- def get_id(self, server_id, nickname):
+ def get_id(self, server_id: int, nickname: str):
value = self.database.execute_fetchone("""SELECT user_id FROM
users WHERE server_id=? and nickname=?""",
[server_id, nickname.lower()])
return value if value == None else value[0]
class BotSettings(Table):
- def set(self, setting, value):
+ def set(self, setting: str, value: typing.Any):
self.database.execute(
"INSERT OR REPLACE INTO bot_settings VALUES (?, ?)",
[setting.lower(), json.dumps(value)])
- def get(self, setting, default=None):
+ def get(self, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone(
"SELECT value FROM bot_settings WHERE setting=?",
[setting.lower()])
if value:
return json.loads(value[0])
return default
- def find(self, pattern, default=[]):
+ def find(self, pattern: str, default: typing.Any=[]):
values = self.database.execute_fetchall(
"SELECT setting, value FROM bot_settings WHERE setting LIKE ?",
[pattern.lower()])
@@ -74,19 +76,19 @@ class BotSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find_prefix(self, prefix, default=[]):
+ def find_prefix(self, prefix: str, default: typing.Any=[]):
return self.find("%s%%" % prefix, default)
- def delete(self, setting):
+ def delete(self, setting: str):
self.database.execute(
"DELETE FROM bot_settings WHERE setting=?",
[setting.lower()])
class ServerSettings(Table):
- def set(self, server_id, setting, value):
+ def set(self, server_id: int, setting: str, value: typing.Any):
self.database.execute(
"INSERT OR REPLACE INTO server_settings VALUES (?, ?, ?)",
[server_id, setting.lower(), json.dumps(value)])
- def get(self, server_id, setting, default=None):
+ def get(self, server_id: int, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone(
"""SELECT value FROM server_settings WHERE
server_id=? AND setting=?""",
@@ -94,7 +96,7 @@ class ServerSettings(Table):
if value:
return json.loads(value[0])
return default
- def find(self, server_id, pattern, default=[]):
+ def find(self, server_id: int, pattern: str, default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT setting, value FROM server_settings WHERE
server_id=? AND setting LIKE ?""",
@@ -104,26 +106,26 @@ class ServerSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find_prefix(self, server_id, prefix, default=[]):
+ def find_prefix(self, server_id: int, prefix: str, default: typing.Any=[]):
return self.find_server_settings(server_id, "%s%%" % prefix, default)
- def delete(self, server_id, setting):
+ def delete(self, server_id: int, setting: str):
self.database.execute(
"DELETE FROM server_settings WHERE server_id=? AND setting=?",
[server_id, setting.lower()])
class ChannelSettings(Table):
- def set(self, channel_id, setting, value):
+ def set(self, channel_id: int, setting: str, value: typing.Any):
self.database.execute(
"INSERT OR REPLACE INTO channel_settings VALUES (?, ?, ?)",
[channel_id, setting.lower(), json.dumps(value)])
- def get(self, channel_id, setting, default=None):
+ def get(self, channel_id: int, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone(
"""SELECT value FROM channel_settings WHERE
channel_id=? AND setting=?""", [channel_id, setting.lower()])
if value:
return json.loads(value[0])
return default
- def find(self, channel_id, pattern, default=[]):
+ def find(self, channel_id: int, pattern: str, default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT setting, value FROM channel_settings WHERE
channel_id=? setting LIKE '?'""", [channel_id, pattern.lower()])
@@ -132,15 +134,15 @@ class ChannelSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find_prefix(self, channel_id, prefix, default=[]):
+ def find_prefix(self, channel_id: int, prefix: str, default: typing.Any=[]):
return self.find_channel_settings(channel_id, "%s%%" % prefix,
default)
- def delete(self, channel_id, setting):
+ def delete(self, channel_id: int, setting: str):
self.database.execute(
"""DELETE FROM channel_settings WHERE channel_id=?
AND setting=?""", [channel_id, setting.lower()])
- def find_by_setting(self, setting, default=[]):
+ def find_by_setting(self, setting: str, default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT channels.server_id, channels.name,
channel_settings.value FROM channel_settings
@@ -154,18 +156,19 @@ class ChannelSettings(Table):
return default
class UserSettings(Table):
- def set(self, user_id, setting, value):
+ def set(self, user_id: int, setting: str, value: typing.Any):
self.database.execute(
"INSERT OR REPLACE INTO user_settings VALUES (?, ?, ?)",
[user_id, setting.lower(), json.dumps(value)])
- def get(self, user_id, setting, default=None):
+ def get(self, user_id: int, setting: str, default: typing.Any=None):
value = self.database.execute_fetchone(
"""SELECT value FROM user_settings WHERE
user_id=? and setting=?""", [user_id, setting.lower()])
if value:
return json.loads(value[0])
return default
- def find_all_by_setting(self, server_id, setting, default=[]):
+ def find_all_by_setting(self, server_id: int, setting: str,
+ default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT users.nickname, user_settings.value FROM
user_settings INNER JOIN users ON
@@ -177,7 +180,7 @@ class UserSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find(self, user_id, pattern, default=[]):
+ def find(self, user_id: int, pattern: str, default: typing.Any=[]):
values = self.database.execute(
"""SELECT setting, value FROM user_settings WHERE
user_id=? AND setting LIKE '?'""", [user_id, pattern.lower()])
@@ -186,20 +189,22 @@ class UserSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find_prefix(self, user_id, prefix, default=[]):
+ def find_prefix(self, user_id: int, prefix: str, default: typing.Any=[]):
return self.find_user_settings(user_id, "%s%%" % prefix, default)
- def delete(self, user_id, setting):
+ def delete(self, user_id: int, setting: str):
self.database.execute(
"""DELETE FROM user_settings WHERE
user_id=? AND setting=?""", [user_id, setting.lower()])
class UserChannelSettings(Table):
- def set(self, user_id, channel_id, setting, value):
+ def set(self, user_id: int, channel_id: int, setting: str,
+ value: typing.Any):
self.database.execute(
"""INSERT OR REPLACE INTO user_channel_settings VALUES
(?, ?, ?, ?)""",
[user_id, channel_id, setting.lower(), json.dumps(value)])
- def get(self, user_id, channel_id, setting, default=None):
+ def get(self, user_id: int, channel_id: int, setting: str,
+ default: typing.Any=None):
value = self.database.execute_fetchone(
"""SELECT value FROM user_channel_settings WHERE
user_id=? AND channel_id=? AND setting=?""",
@@ -207,7 +212,8 @@ class UserChannelSettings(Table):
if value:
return json.loads(value[0])
return default
- def find(self, user_id, channel_id, pattern, default=[]):
+ def find(self, user_id: int, channel_id: int, pattern: str,
+ default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT setting, value FROM user_channel_settings WHERE
user_id=? AND channel_id=? AND setting LIKE '?'""",
@@ -217,10 +223,12 @@ class UserChannelSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find_prefix(self, user_id, channel_id, prefix, default=[]):
+ def find_prefix(self, user_id: int, channel_id: int, prefix: str,
+ default: typing.Any=[]):
return self.find_user_settings(user_id, channel_id, "%s%%" % prefix,
default)
- def find_by_setting(self, user_id, setting, default=[]):
+ def find_by_setting(self, user_id: int, setting: str,
+ default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT channels.name, user_channel_settings.value FROM
user_channel_settings INNER JOIN channels ON
@@ -232,7 +240,8 @@ class UserChannelSettings(Table):
values[i] = value[0], json.loads(value[1])
return values
return default
- def find_all_by_setting(self, server_id, setting, default=[]):
+ def find_all_by_setting(self, server_id: int, setting: str,
+ default: typing.Any=[]):
values = self.database.execute_fetchall(
"""SELECT channels.name, users.nickname,
user_channel_settings.value FROM
@@ -246,14 +255,14 @@ class UserChannelSettings(Table):
values[i] = value[0], value[1], json.loads(value[2])
return values
return default
- def delete(self, user_id, channel_id, setting):
+ def delete(self, user_id: int, channel_id: int, setting: str):
self.database.execute(
"""DELETE FROM user_channel_settings WHERE
user_id=? AND channel_id=? AND setting=?""",
[user_id, channel_id, setting.lower()])
class Database(object):
- def __init__(self, log, location):
+ def __init__(self, log: "Logging.Log", location: str):
self.log = log
self.location = location
self.database = sqlite3.connect(self.location,
@@ -284,7 +293,9 @@ class Database(object):
self._cursor = self.database.cursor()
return self._cursor
- def _execute_fetch(self, query, fetch_func, params=[]):
+ def _execute_fetch(self, query: str,
+ fetch_func: typing.Callable[[sqlite3.Cursor], typing.Any],
+ params: typing.List=[]):
printable_query = " ".join(query.split())
self.log.trace("executing query: \"%s\" (params: %s)",
[printable_query, params])
@@ -299,16 +310,16 @@ class Database(object):
self.log.trace("executed in %fms", [total_milliseconds])
return value
- def execute_fetchall(self, query, params=[]):
+ def execute_fetchall(self, query: str, params: typing.List=[]):
return self._execute_fetch(query,
lambda cursor: cursor.fetchall(), params)
- def execute_fetchone(self, query, params=[]):
+ def execute_fetchone(self, query: str, params: typing.List=[]):
return self._execute_fetch(query,
lambda cursor: cursor.fetchone(), params)
- def execute(self, query, params=[]):
+ def execute(self, query: str, params: typing.List=[]):
return self._execute_fetch(query, lambda cursor: None, params)
- def has_table(self, table_name):
+ def has_table(self, table_name: str):
result = self.execute_fetchone("""SELECT COUNT(*) FROM
sqlite_master WHERE type='table' AND name=?""",
[table_name])