aboutsummaryrefslogtreecommitdiff
path: root/src/Database.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/Database.py')
-rw-r--r--src/Database.py40
1 files changed, 20 insertions, 20 deletions
diff --git a/src/Database.py b/src/Database.py
index 683de5a6..6c79bb1d 100644
--- a/src/Database.py
+++ b/src/Database.py
@@ -1,7 +1,8 @@
-import json, os, sqlite3, threading, time, typing
+import json, os, threading, time, typing, urllib.parse
from src import Logging, utils
-sqlite3.register_converter("BOOLEAN", lambda v: bool(int(v)))
+from .DatabaseEngines import DatabaseEngine, DatabaseEngineCursor
+from .DatabaseEngines import SQLite3Engine
class Table(object):
def __init__(self, database):
@@ -297,14 +298,21 @@ class UserChannelSettings(Table):
[user_id, channel_id, setting.lower()])
class Database(object):
- def __init__(self, log: "Logging.Log", location: str):
+ _engine: DatabaseEngine
+
+ def __init__(self, log: "Logging.Log", database: str):
+ db_parts = urllib.parse.urlparse(database)
+
+ if db_parts.scheme == "sqlite3":
+ self._engine = SQLite3Engine()
+ else:
+ raise ValueError("Unknown database engine '%s'" % db_parts.scheme)
+ self._engine.config(hostname=db_parts.hostname, port=db_parts.port,
+ path=db_parts.path, username=db_parts.username,
+ password=db_parts.password)
+ self._engine.connect()
+
self.log = log
- self.location = location
- self.database = sqlite3.connect(self.location,
- check_same_thread=False, isolation_level=None,
- detect_types=sqlite3.PARSE_DECLTYPES)
- self.database.execute("PRAGMA foreign_keys = ON")
- self._cursor = None
self._lock = threading.Lock()
self.make_servers_table()
@@ -325,13 +333,8 @@ class Database(object):
self.user_settings = UserSettings(self)
self.user_channel_settings = UserChannelSettings(self)
- def cursor(self):
- if self._cursor == None:
- self._cursor = self.database.cursor()
- return self._cursor
-
def _execute_fetch(self, query: str,
- fetch_func: typing.Callable[[sqlite3.Cursor], typing.Any],
+ fetch_func: typing.Callable[[DatabaseEngineCursor], typing.Any],
params: typing.List=[]):
if not utils.is_main_thread():
raise RuntimeError("Can't access Database outside of main thread")
@@ -339,7 +342,7 @@ class Database(object):
printable_query = " ".join(query.split())
start = time.monotonic()
- cursor = self.cursor()
+ cursor = self._engine.cursor()
with self._lock:
cursor.execute(query, params)
value = fetch_func(cursor)
@@ -360,10 +363,7 @@ class Database(object):
return self._execute_fetch(query, lambda cursor: None, params)
def has_table(self, table_name: str):
- result = self.execute_fetchone("""SELECT COUNT(*) FROM
- sqlite_master WHERE type='table' AND name=?""",
- [table_name])
- return result[0] == 1
+ return self._engine.has_table(table_name)
def make_servers_table(self):
if not self.has_table("servers"):