From d627ed49e22ba9c57e50cd6de748a9f26fc03720 Mon Sep 17 00:00:00 2001 From: jesopo Date: Mon, 25 Feb 2019 10:36:17 +0000 Subject: Pull "is main thread" logic out to utils, force Database to be accessed on main thread --- src/Database.py | 5 ++++- src/IRCBot.py | 3 ++- src/utils/__init__.py | 5 ++++- 3 files changed, 10 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/Database.py b/src/Database.py index 1ba9bac3..c498b4a8 100644 --- a/src/Database.py +++ b/src/Database.py @@ -1,5 +1,5 @@ import json, os, sqlite3, threading, time, typing -from src import Logging +from src import Logging, utils sqlite3.register_converter("BOOLEAN", lambda v: bool(int(v))) @@ -309,6 +309,9 @@ class Database(object): def _execute_fetch(self, query: str, fetch_func: typing.Callable[[sqlite3.Cursor], typing.Any], params: typing.List=[]): + if not utils.is_main_thread(): + raise RuntimeError("Can't access Database outside of main thread") + printable_query = " ".join(query.split()) start = time.monotonic() diff --git a/src/IRCBot.py b/src/IRCBot.py index 2733db1f..c5809d42 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -38,7 +38,8 @@ class Bot(object): func: typing.Optional[typing.Callable[[], typing.Any]]=None ) -> typing.Any: func = func or (lambda: None) - if threading.current_thread() is threading.main_thread(): + + if utils.is_main_thread(): returned = func() self._trigger_client.send(b"TRIGGER") return returned diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 609b0eaa..44ba4587 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,4 +1,4 @@ -import datetime, decimal, enum, io, ipaddress, re, typing +import datetime, decimal, enum, io, ipaddress, re, threading, typing from src.utils import cli, consts, irc, http, parse, security class Direction(enum.Enum): @@ -199,3 +199,6 @@ def is_ip(s: str) -> bool: except ValueError: return False return True + +def is_main_thread() -> bool: + return threading.current_thread() is threading.main_thread() -- cgit v1.3.1-10-gc9f91