aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbitbotd50
-rw-r--r--docs/bot.conf.example12
-rw-r--r--src/Database.py40
-rw-r--r--src/DatabaseEngines.py64
4 files changed, 112 insertions, 54 deletions
diff --git a/bitbotd b/bitbotd
index dd032983..c5c3692f 100755
--- a/bitbotd
+++ b/bitbotd
@@ -25,16 +25,6 @@ arg_parser.add_argument("--version", "-v", action="store_true")
arg_parser.add_argument("--config", "-c", help="Location of config file",
default=os.path.join(default_data, "bot.conf"))
-arg_parser.add_argument("--data-dir", "-x",
- help="Location of data files (database, lock, socket)",
- default=default_data)
-
-arg_parser.add_argument("--database", "-d",
- help="Location of the sqlite3 database file")
-
-arg_parser.add_argument("--log-dir", "-l",
- help="Location of the log directory")
-
arg_parser.add_argument("--add-server", "-a",
help="Add a new server", action="store_true")
@@ -56,36 +46,30 @@ if args.version:
print("BitBot %s" % IRCBot.VERSION)
sys.exit(0)
-if not os.path.isdir(args.data_dir):
- os.mkdir(args.data_dir)
+config = Config.Config(args.config)
+config.load()
-database_location = None
-lock_location = None
-sock_locaiton = None
-log_directory = None
-if not args.database == None:
- database_location = args.database
- lock_location = "%s.lock" % args.database
- sock_location = "%s.sock" % args.database
-else:
- database_location = os.path.join(args.data_dir, "bot.db")
- lock_location = os.path.join(args.data_dir, "bot.lock")
- sock_location = os.path.join(args.data_dir, "bot.sock")
+DATA_DIR = os.path.expanduser(config["data-directory"])
+LOG_DIR = config.get("log-directory", "{DATA}/logs/").format(DATA=DATA_DIR)
+DATABASE = config.get("database", "sqlite3:{DATA}/bot.db").format(DATA=DATA_DIR)
+LOCK_FILE = config.get("lock-file", "{DATA}/bot.lock").format(DATA=DATA_DIR)
+SOCK_FILE = config.get("sock-file", "{DATA}/bot.sock").format(DATA=DATA_DIR)
-log_directory = args.log_dir or os.path.join(args.data_dir, "logs")
-if not os.path.isdir(log_directory):
- os.mkdir(log_directory)
+if not os.path.isdir(DATA_DIR):
+ os.mkdir(DATA_DIR)
+if not os.path.isdir(LOG_DIR):
+ os.mkdir(LOG_DIR)
log_level = args.log_level
if not log_level:
log_level = "debug" if args.verbose else "warn"
-log = Logging.Log(not args.no_logging, log_level, log_directory)
+log = Logging.Log(not args.no_logging, log_level, LOG_DIR)
log.info("Starting BitBot %s (Python v%s, db %s)",
- [IRCBot.VERSION, platform.python_version(), database_location])
+ [IRCBot.VERSION, platform.python_version(), DATABASE])
-lock_file = LockFile.LockFile(lock_location)
+lock_file = LockFile.LockFile(LOCK_FILE)
if not lock_file.available():
log.critical("Database is locked. Is BitBot already running?")
sys.exit(utils.consts.Exit.LOCKED)
@@ -93,7 +77,7 @@ if not lock_file.available():
atexit.register(lock_file.unlock)
lock_file.lock()
-database = Database.Database(log, database_location)
+database = Database.Database(log, DATABASE)
if args.remove_server:
alias = args.remove_server
@@ -117,8 +101,6 @@ if args.add_server:
sys.exit(0)
cache = Cache.Cache()
-config = Config.Config(args.config)
-config.load()
events = EventManager.EventRoot(log).wrap()
exports = Exports.Exports()
timers = Timers.Timers(database, events, log)
@@ -139,7 +121,7 @@ bot.add_poll_hook(cache)
bot.add_poll_hook(lock_file)
bot.add_poll_hook(timers)
-control = Control.Control(bot, sock_location)
+control = Control.Control(bot, SOCK_FILE)
control.bind()
bot.add_poll_source(control)
diff --git a/docs/bot.conf.example b/docs/bot.conf.example
index f22d4fc4..6c7808dc 100644
--- a/docs/bot.conf.example
+++ b/docs/bot.conf.example
@@ -2,6 +2,18 @@
# will be disabled.
[bot]
+
+# configuration related to where/how bitbot accesses files and databases.
+# commented out values are the default values. {DATA} is replaced with data-directory
+
+#data-directory = ~/.bitbot
+#log-directory = /var/log/bitbot/
+#lock-file = {DATA}/bot.lock
+#sock-file = {DATA}/bot.sock
+
+# database - currently only supports sqlite3
+#database = sqlite3:{DATA}/bot.db
+
# client-side tls key/cert for IRC connections
tls-key =
tls-certificate =
diff --git a/src/Database.py b/src/Database.py
index 683de5a6..6c79bb1d 100644
--- a/src/Database.py
+++ b/src/Database.py
@@ -1,7 +1,8 @@
-import json, os, sqlite3, threading, time, typing
+import json, os, threading, time, typing, urllib.parse
from src import Logging, utils
-sqlite3.register_converter("BOOLEAN", lambda v: bool(int(v)))
+from .DatabaseEngines import DatabaseEngine, DatabaseEngineCursor
+from .DatabaseEngines import SQLite3Engine
class Table(object):
def __init__(self, database):
@@ -297,14 +298,21 @@ class UserChannelSettings(Table):
[user_id, channel_id, setting.lower()])
class Database(object):
- def __init__(self, log: "Logging.Log", location: str):
+ _engine: DatabaseEngine
+
+ def __init__(self, log: "Logging.Log", database: str):
+ db_parts = urllib.parse.urlparse(database)
+
+ if db_parts.scheme == "sqlite3":
+ self._engine = SQLite3Engine()
+ else:
+ raise ValueError("Unknown database engine '%s'" % db_parts.scheme)
+ self._engine.config(hostname=db_parts.hostname, port=db_parts.port,
+ path=db_parts.path, username=db_parts.username,
+ password=db_parts.password)
+ self._engine.connect()
+
self.log = log
- self.location = location
- self.database = sqlite3.connect(self.location,
- check_same_thread=False, isolation_level=None,
- detect_types=sqlite3.PARSE_DECLTYPES)
- self.database.execute("PRAGMA foreign_keys = ON")
- self._cursor = None
self._lock = threading.Lock()
self.make_servers_table()
@@ -325,13 +333,8 @@ class Database(object):
self.user_settings = UserSettings(self)
self.user_channel_settings = UserChannelSettings(self)
- def cursor(self):
- if self._cursor == None:
- self._cursor = self.database.cursor()
- return self._cursor
-
def _execute_fetch(self, query: str,
- fetch_func: typing.Callable[[sqlite3.Cursor], typing.Any],
+ fetch_func: typing.Callable[[DatabaseEngineCursor], typing.Any],
params: typing.List=[]):
if not utils.is_main_thread():
raise RuntimeError("Can't access Database outside of main thread")
@@ -339,7 +342,7 @@ class Database(object):
printable_query = " ".join(query.split())
start = time.monotonic()
- cursor = self.cursor()
+ cursor = self._engine.cursor()
with self._lock:
cursor.execute(query, params)
value = fetch_func(cursor)
@@ -360,10 +363,7 @@ class Database(object):
return self._execute_fetch(query, lambda cursor: None, params)
def has_table(self, table_name: str):
- result = self.execute_fetchone("""SELECT COUNT(*) FROM
- sqlite_master WHERE type='table' AND name=?""",
- [table_name])
- return result[0] == 1
+ return self._engine.has_table(table_name)
def make_servers_table(self):
if not self.has_table("servers"):
diff --git a/src/DatabaseEngines.py b/src/DatabaseEngines.py
new file mode 100644
index 00000000..165e64d5
--- /dev/null
+++ b/src/DatabaseEngines.py
@@ -0,0 +1,64 @@
+import dataclasses, typing
+import sqlite3
+
+class DatabaseEngineCursor(object):
+ def execute(self, query: str, args: typing.List[str]):
+ pass
+ def fetchone(self) -> typing.Any:
+ pass
+ def fetchall(self) -> typing.List[typing.Any]:
+ pass
+
+class DatabaseEngine(object):
+ def config(self, hostname: str=None, port: int=None, path: str=None,
+ username: str=None, password: str=None):
+ self.hostname = hostname
+ self.port = port
+ self.path = path
+ self.username = username
+ self.password = password
+
+ def database_name(self):
+ return self.path
+ def connect(self):
+ pass
+ def cursor(self) -> DatabaseEngineCursor:
+ pass
+ def has_table(self, name: str):
+ pass
+
+ def execute(self, query: str, args: typing.List[str]):
+ pass
+ def fetchone(self, query: str, args: typing.List[str]):
+ pass
+ def fetchall(self, query: str, args: typing.List[str]):
+ pass
+
+class SQLite3Cursor(DatabaseEngineCursor):
+ def __init__(self, cursor: sqlite3.Cursor):
+ self._cursor = cursor
+ def execute(self, query: str, args: typing.List[str]):
+ self._cursor.execute(query, args)
+ def fetchone(self):
+ return self._cursor.fetchone()
+ def fetchall(self):
+ return self._cursor.fetchall()
+class SQLite3Engine(DatabaseEngine):
+ _connection: sqlite3.Connection
+
+ def connect(self):
+ sqlite3.register_converter("BOOLEAN", lambda v: bool(int(v)))
+ self._connection = sqlite3.connect(self.path,
+ check_same_thread=False, isolation_level=None,
+ detect_types=sqlite3.PARSE_DECLTYPES)
+ self._connection.execute("PRAGMA foreign_keys = ON")
+
+ def has_table(self, name: str):
+ cursor = self.cursor()
+ cursor.execute(
+ "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?",
+ [name])
+ return cursor.fetchone()[0] == 1
+
+ def cursor(self):
+ return SQLite3Cursor(self._connection.cursor())