diff options
| -rwxr-xr-x | bitbotd | 50 | ||||
| -rw-r--r-- | docs/bot.conf.example | 12 | ||||
| -rw-r--r-- | src/Database.py | 40 | ||||
| -rw-r--r-- | src/DatabaseEngines.py | 64 |
4 files changed, 112 insertions, 54 deletions
@@ -25,16 +25,6 @@ arg_parser.add_argument("--version", "-v", action="store_true") arg_parser.add_argument("--config", "-c", help="Location of config file", default=os.path.join(default_data, "bot.conf")) -arg_parser.add_argument("--data-dir", "-x", - help="Location of data files (database, lock, socket)", - default=default_data) - -arg_parser.add_argument("--database", "-d", - help="Location of the sqlite3 database file") - -arg_parser.add_argument("--log-dir", "-l", - help="Location of the log directory") - arg_parser.add_argument("--add-server", "-a", help="Add a new server", action="store_true") @@ -56,36 +46,30 @@ if args.version: print("BitBot %s" % IRCBot.VERSION) sys.exit(0) -if not os.path.isdir(args.data_dir): - os.mkdir(args.data_dir) +config = Config.Config(args.config) +config.load() -database_location = None -lock_location = None -sock_locaiton = None -log_directory = None -if not args.database == None: - database_location = args.database - lock_location = "%s.lock" % args.database - sock_location = "%s.sock" % args.database -else: - database_location = os.path.join(args.data_dir, "bot.db") - lock_location = os.path.join(args.data_dir, "bot.lock") - sock_location = os.path.join(args.data_dir, "bot.sock") +DATA_DIR = os.path.expanduser(config["data-directory"]) +LOG_DIR = config.get("log-directory", "{DATA}/logs/").format(DATA=DATA_DIR) +DATABASE = config.get("database", "sqlite3:{DATA}/bot.db").format(DATA=DATA_DIR) +LOCK_FILE = config.get("lock-file", "{DATA}/bot.lock").format(DATA=DATA_DIR) +SOCK_FILE = config.get("sock-file", "{DATA}/bot.sock").format(DATA=DATA_DIR) -log_directory = args.log_dir or os.path.join(args.data_dir, "logs") -if not os.path.isdir(log_directory): - os.mkdir(log_directory) +if not os.path.isdir(DATA_DIR): + os.mkdir(DATA_DIR) +if not os.path.isdir(LOG_DIR): + os.mkdir(LOG_DIR) log_level = args.log_level if not log_level: log_level = "debug" if args.verbose else "warn" -log = Logging.Log(not args.no_logging, log_level, log_directory) +log = Logging.Log(not args.no_logging, log_level, LOG_DIR) log.info("Starting BitBot %s (Python v%s, db %s)", - [IRCBot.VERSION, platform.python_version(), database_location]) + [IRCBot.VERSION, platform.python_version(), DATABASE]) -lock_file = LockFile.LockFile(lock_location) +lock_file = LockFile.LockFile(LOCK_FILE) if not lock_file.available(): log.critical("Database is locked. Is BitBot already running?") sys.exit(utils.consts.Exit.LOCKED) @@ -93,7 +77,7 @@ if not lock_file.available(): atexit.register(lock_file.unlock) lock_file.lock() -database = Database.Database(log, database_location) +database = Database.Database(log, DATABASE) if args.remove_server: alias = args.remove_server @@ -117,8 +101,6 @@ if args.add_server: sys.exit(0) cache = Cache.Cache() -config = Config.Config(args.config) -config.load() events = EventManager.EventRoot(log).wrap() exports = Exports.Exports() timers = Timers.Timers(database, events, log) @@ -139,7 +121,7 @@ bot.add_poll_hook(cache) bot.add_poll_hook(lock_file) bot.add_poll_hook(timers) -control = Control.Control(bot, sock_location) +control = Control.Control(bot, SOCK_FILE) control.bind() bot.add_poll_source(control) diff --git a/docs/bot.conf.example b/docs/bot.conf.example index f22d4fc4..6c7808dc 100644 --- a/docs/bot.conf.example +++ b/docs/bot.conf.example @@ -2,6 +2,18 @@ # will be disabled. [bot] + +# configuration related to where/how bitbot accesses files and databases. +# commented out values are the default values. {DATA} is replaced with data-directory + +#data-directory = ~/.bitbot +#log-directory = /var/log/bitbot/ +#lock-file = {DATA}/bot.lock +#sock-file = {DATA}/bot.sock + +# database - currently only supports sqlite3 +#database = sqlite3:{DATA}/bot.db + # client-side tls key/cert for IRC connections tls-key = tls-certificate = diff --git a/src/Database.py b/src/Database.py index 683de5a6..6c79bb1d 100644 --- a/src/Database.py +++ b/src/Database.py @@ -1,7 +1,8 @@ -import json, os, sqlite3, threading, time, typing +import json, os, threading, time, typing, urllib.parse from src import Logging, utils -sqlite3.register_converter("BOOLEAN", lambda v: bool(int(v))) +from .DatabaseEngines import DatabaseEngine, DatabaseEngineCursor +from .DatabaseEngines import SQLite3Engine class Table(object): def __init__(self, database): @@ -297,14 +298,21 @@ class UserChannelSettings(Table): [user_id, channel_id, setting.lower()]) class Database(object): - def __init__(self, log: "Logging.Log", location: str): + _engine: DatabaseEngine + + def __init__(self, log: "Logging.Log", database: str): + db_parts = urllib.parse.urlparse(database) + + if db_parts.scheme == "sqlite3": + self._engine = SQLite3Engine() + else: + raise ValueError("Unknown database engine '%s'" % db_parts.scheme) + self._engine.config(hostname=db_parts.hostname, port=db_parts.port, + path=db_parts.path, username=db_parts.username, + password=db_parts.password) + self._engine.connect() + self.log = log - self.location = location - self.database = sqlite3.connect(self.location, - check_same_thread=False, isolation_level=None, - detect_types=sqlite3.PARSE_DECLTYPES) - self.database.execute("PRAGMA foreign_keys = ON") - self._cursor = None self._lock = threading.Lock() self.make_servers_table() @@ -325,13 +333,8 @@ class Database(object): self.user_settings = UserSettings(self) self.user_channel_settings = UserChannelSettings(self) - def cursor(self): - if self._cursor == None: - self._cursor = self.database.cursor() - return self._cursor - def _execute_fetch(self, query: str, - fetch_func: typing.Callable[[sqlite3.Cursor], typing.Any], + fetch_func: typing.Callable[[DatabaseEngineCursor], typing.Any], params: typing.List=[]): if not utils.is_main_thread(): raise RuntimeError("Can't access Database outside of main thread") @@ -339,7 +342,7 @@ class Database(object): printable_query = " ".join(query.split()) start = time.monotonic() - cursor = self.cursor() + cursor = self._engine.cursor() with self._lock: cursor.execute(query, params) value = fetch_func(cursor) @@ -360,10 +363,7 @@ class Database(object): return self._execute_fetch(query, lambda cursor: None, params) def has_table(self, table_name: str): - result = self.execute_fetchone("""SELECT COUNT(*) FROM - sqlite_master WHERE type='table' AND name=?""", - [table_name]) - return result[0] == 1 + return self._engine.has_table(table_name) def make_servers_table(self): if not self.has_table("servers"): diff --git a/src/DatabaseEngines.py b/src/DatabaseEngines.py new file mode 100644 index 00000000..165e64d5 --- /dev/null +++ b/src/DatabaseEngines.py @@ -0,0 +1,64 @@ +import dataclasses, typing +import sqlite3 + +class DatabaseEngineCursor(object): + def execute(self, query: str, args: typing.List[str]): + pass + def fetchone(self) -> typing.Any: + pass + def fetchall(self) -> typing.List[typing.Any]: + pass + +class DatabaseEngine(object): + def config(self, hostname: str=None, port: int=None, path: str=None, + username: str=None, password: str=None): + self.hostname = hostname + self.port = port + self.path = path + self.username = username + self.password = password + + def database_name(self): + return self.path + def connect(self): + pass + def cursor(self) -> DatabaseEngineCursor: + pass + def has_table(self, name: str): + pass + + def execute(self, query: str, args: typing.List[str]): + pass + def fetchone(self, query: str, args: typing.List[str]): + pass + def fetchall(self, query: str, args: typing.List[str]): + pass + +class SQLite3Cursor(DatabaseEngineCursor): + def __init__(self, cursor: sqlite3.Cursor): + self._cursor = cursor + def execute(self, query: str, args: typing.List[str]): + self._cursor.execute(query, args) + def fetchone(self): + return self._cursor.fetchone() + def fetchall(self): + return self._cursor.fetchall() +class SQLite3Engine(DatabaseEngine): + _connection: sqlite3.Connection + + def connect(self): + sqlite3.register_converter("BOOLEAN", lambda v: bool(int(v))) + self._connection = sqlite3.connect(self.path, + check_same_thread=False, isolation_level=None, + detect_types=sqlite3.PARSE_DECLTYPES) + self._connection.execute("PRAGMA foreign_keys = ON") + + def has_table(self, name: str): + cursor = self.cursor() + cursor.execute( + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", + [name]) + return cursor.fetchone()[0] == 1 + + def cursor(self): + return SQLite3Cursor(self._connection.cursor()) |
