diff options
Diffstat (limited to 'src/Database.py')
| -rw-r--r-- | src/Database.py | 40 |
1 files changed, 20 insertions, 20 deletions
diff --git a/src/Database.py b/src/Database.py index 683de5a6..6c79bb1d 100644 --- a/src/Database.py +++ b/src/Database.py @@ -1,7 +1,8 @@ -import json, os, sqlite3, threading, time, typing +import json, os, threading, time, typing, urllib.parse from src import Logging, utils -sqlite3.register_converter("BOOLEAN", lambda v: bool(int(v))) +from .DatabaseEngines import DatabaseEngine, DatabaseEngineCursor +from .DatabaseEngines import SQLite3Engine class Table(object): def __init__(self, database): @@ -297,14 +298,21 @@ class UserChannelSettings(Table): [user_id, channel_id, setting.lower()]) class Database(object): - def __init__(self, log: "Logging.Log", location: str): + _engine: DatabaseEngine + + def __init__(self, log: "Logging.Log", database: str): + db_parts = urllib.parse.urlparse(database) + + if db_parts.scheme == "sqlite3": + self._engine = SQLite3Engine() + else: + raise ValueError("Unknown database engine '%s'" % db_parts.scheme) + self._engine.config(hostname=db_parts.hostname, port=db_parts.port, + path=db_parts.path, username=db_parts.username, + password=db_parts.password) + self._engine.connect() + self.log = log - self.location = location - self.database = sqlite3.connect(self.location, - check_same_thread=False, isolation_level=None, - detect_types=sqlite3.PARSE_DECLTYPES) - self.database.execute("PRAGMA foreign_keys = ON") - self._cursor = None self._lock = threading.Lock() self.make_servers_table() @@ -325,13 +333,8 @@ class Database(object): self.user_settings = UserSettings(self) self.user_channel_settings = UserChannelSettings(self) - def cursor(self): - if self._cursor == None: - self._cursor = self.database.cursor() - return self._cursor - def _execute_fetch(self, query: str, - fetch_func: typing.Callable[[sqlite3.Cursor], typing.Any], + fetch_func: typing.Callable[[DatabaseEngineCursor], typing.Any], params: typing.List=[]): if not utils.is_main_thread(): raise RuntimeError("Can't access Database outside of main thread") @@ -339,7 +342,7 @@ class Database(object): printable_query = " ".join(query.split()) start = time.monotonic() - cursor = self.cursor() + cursor = self._engine.cursor() with self._lock: cursor.execute(query, params) value = fetch_func(cursor) @@ -360,10 +363,7 @@ class Database(object): return self._execute_fetch(query, lambda cursor: None, params) def has_table(self, table_name: str): - result = self.execute_fetchone("""SELECT COUNT(*) FROM - sqlite_master WHERE type='table' AND name=?""", - [table_name]) - return result[0] == 1 + return self._engine.has_table(table_name) def make_servers_table(self): if not self.has_table("servers"): |
