aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Control.py109
-rw-r--r--src/IRCBot.py60
-rw-r--r--src/IRCSocket.py2
-rw-r--r--src/LockFile.py38
-rw-r--r--src/Logging.py28
-rw-r--r--src/PollSource.py12
-rw-r--r--src/utils/__init__.py6
-rw-r--r--src/utils/parse.py5
8 files changed, 241 insertions, 19 deletions
diff --git a/src/Control.py b/src/Control.py
new file mode 100644
index 00000000..5b59bfec
--- /dev/null
+++ b/src/Control.py
@@ -0,0 +1,109 @@
+import json, os, socket, typing
+from src import IRCBot, Logging, PollSource
+
+class ControlClient(object):
+ def __init__(self, sock: socket.socket):
+ self._socket = sock
+ self._read_buffer = b""
+ self._write_buffer = b""
+ self.version = -1
+ self.log_level = None # type: typing.Optional[int]
+
+ def fileno(self) -> int:
+ return self._socket.fileno()
+
+ def read_lines(self) -> typing.List[str]:
+ try:
+ data = self._socket.recv(2048)
+ except:
+ data = b""
+ if not data:
+ return None
+ lines = (self._read_buffer+data).split(b"\n")
+ lines = [line.strip(b"\r") for line in lines]
+ self._read_buffer = lines.pop(-1)
+ return [line.decode("utf8") for line in lines]
+
+ def write_line(self, line: str):
+ self._socket.send(("%s\n" % line).encode("utf8"))
+
+ def disconnect(self):
+ try:
+ self._socket.shutdown(socket.SHUT_RDWR)
+ except:
+ pass
+ try:
+ self._socket.close()
+ except:
+ pass
+
+
+class Control(PollSource.PollSource):
+ def __init__(self, bot: IRCBot.Bot, database_location: str):
+ self._bot = bot
+ self._bot.log.hook(self._on_log)
+
+ self._socket_location = "%s.sock" % database_location
+ self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ self._clients = {}
+
+ def _on_log(self, levelno: int, line: str):
+ for client in self._clients.values():
+ if not client.log_level == None and client.log_level <= levelno:
+ self._send_action(client, "log", line)
+
+ def bind(self):
+ if os.path.exists(self._socket_location):
+ os.remove(self._socket_location)
+ self._socket.bind(self._socket_location)
+ self._socket.listen(1)
+
+ def get_readables(self) -> typing.List[int]:
+ return [self._socket.fileno()]+list(self._clients.keys())
+
+ def is_readable(self, fileno: int):
+ if fileno == self._socket.fileno():
+ client, address = self._socket.accept()
+ self._clients[client.fileno()] = ControlClient(client)
+ self._bot.log.debug("New control socket connected")
+ elif fileno in self._clients:
+ client = self._clients[fileno]
+ lines = client.read_lines()
+ if lines == None:
+ client.disconnect()
+ del self._clients[fileno]
+ else:
+ for line in lines:
+ response = self._parse_line(client, line)
+
+ def _parse_line(self, client: ControlClient, line: str):
+ id, _, command = line.partition(" ")
+ command, _, data = command.partition(" ")
+ if not id or not command:
+ client.disconnect()
+ return
+
+ command = command.lower()
+ response_action = "ack"
+ response_data = None
+
+ keepalive = True
+
+ if command == "version":
+ client.version = int(data)
+ elif command == "log":
+ client.log_level = Logging.LEVELS[data.lower()]
+ elif command == "rehash":
+ self._bot.log.info("Reloading config file")
+ self._bot.config.load()
+ self._bot.log.info("Reloaded config file")
+ keepalive = False
+
+ self._send_action(client, response_action, response_data, id)
+ if not keepalive:
+ client.disconnect()
+
+ def _send_action(self, client: ControlClient, action: str, data: str,
+ id: int=None):
+ client.write_line(
+ json.dumps({"action": action, "data": data, "id": id}))
diff --git a/src/IRCBot.py b/src/IRCBot.py
index df0b0540..8eb1873e 100644
--- a/src/IRCBot.py
+++ b/src/IRCBot.py
@@ -1,9 +1,10 @@
import enum, queue, os, queue, select, socket, sys, threading, time, traceback
import typing, uuid
from src import EventManager, Exports, IRCServer, Logging, ModuleManager
-from src import PollHook, Socket, Timers, utils
+from src import PollHook, PollSource, Socket, Timers, utils
-VERSION = "v1.12.0-rc1"
+with open("VERSION", "r") as version_file:
+ VERSION = "v%s" % version_file.read().strip()
SOURCE = "https://git.io/bitbot"
URL = "https://bitbot.dev"
@@ -68,10 +69,7 @@ class Bot(object):
self._read_thread = None
self._write_thread = None
- self._poll_timeouts = [] # typing.List[PollHook]
- self._poll_timeouts.append(self._timers)
- self._poll_timeouts.append(self.cache)
-
+ self._poll_timeouts = [] # typing.List[PollHook.PollHook]
self._poll_timeouts.append(ListLambdaPollHook(
lambda: self.servers.values(),
lambda server: server.until_read_timeout()))
@@ -83,6 +81,13 @@ class Bot(object):
self._poll_timeouts.append(ListLambdaPollHook(
lambda: self.servers.values(), self._throttle_timeout))
+ self._poll_sources = [] # typing.List[PollSource.PollSource]
+
+ def add_poll_hook(self, hook: PollHook.PollHook):
+ self._poll_timeouts.append(hook)
+ def add_poll_source(self, source: PollSource.PollSource):
+ self._poll_sources.append(source)
+
def _throttle_timeout(self, server: IRCServer.Server):
if server.socket.waiting_throttled_send():
return server.socket.send_throttle_timeout()
@@ -276,11 +281,6 @@ class Bot(object):
def _event_loop(self):
while self.running or not self._event_queue.empty():
- if not self.servers and self._event_queue.empty():
- self._kill()
- self.log.warn("No servers, exiting")
- break
-
try:
item = self._event_queue.get(block=True,
timeout=self.get_poll_timeout())
@@ -316,16 +316,24 @@ class Bot(object):
def _write_loop(self):
while self.running:
+ poll_sources = {}
with self._write_condition:
- writeable = False
+ fds = []
for fd, server in self.servers.items():
if server.socket.waiting_immediate_send():
- self._write_poll.register(fd, select.POLLOUT)
- writeable = True
+ fds.append(fd)
+
+ for poll_source in self._poll_sources:
+ for fileno in poll_source.get_writables():
+ poll_sources[fileno] = poll_source
+ fds.append(fileno)
- if not writeable:
+ if not fds:
self._write_condition.wait()
continue
+ else:
+ for fd in fds:
+ self._write_poll.register(fd, select.POLLOUT)
events = self._write_poll.poll()
@@ -344,9 +352,25 @@ class Bot(object):
event_item = TriggerEvent(TriggerEventType.Action,
self._post_send_factory(server, lines))
self._event_queue.put(event_item)
+ elif fd in poll_sources:
+ poll_sources[fd].is_writeable(fd)
def _read_loop(self):
+ poll_sources = {}
while self.running:
+ new_poll_sources = {}
+ for poll_source in self._poll_sources:
+ for fileno in poll_source.get_readables():
+ new_poll_sources[fileno] = poll_source
+ for fileno in new_poll_sources:
+ if not fileno in poll_sources:
+ poll_sources[fileno] = new_poll_sources[fileno]
+ self._read_poll.register(fileno, select.POLLIN)
+ for fileno in list(poll_sources.keys()):
+ if not fileno in new_poll_sources:
+ del poll_sources[fileno]
+ self._read_poll.unregister(fileno)
+
events = self._read_poll.poll()
for fd, event in events:
@@ -355,6 +379,9 @@ class Bot(object):
with self._rtrigger_lock:
self._rtrigger_server.recv(1024)
self._rtriggered = False
+ elif fd in poll_sources:
+ poll_sources[fd].is_readable(fd)
+ self.trigger_write()
else:
if not fd in self.servers:
self._read_poll.unregister(fd)
@@ -376,7 +403,8 @@ class Bot(object):
def _check(self):
for poll_timeout in self._poll_timeouts:
- poll_timeout.call()
+ if poll_timeout.next() == 0:
+ poll_timeout.call()
throttle_filled = False
for server in list(self.servers.values()):
diff --git a/src/IRCSocket.py b/src/IRCSocket.py
index 98091aff..7fbae39d 100644
--- a/src/IRCSocket.py
+++ b/src/IRCSocket.py
@@ -44,7 +44,7 @@ class Socket(IRCObject.Object):
self.last_send = None # type: typing.Optional[float]
self.connected_ip = None # type: typing.Optional[str]
- self.conncect_time: float = -1
+ self.connect_time: float = -1
def fileno(self) -> int:
return self.cached_fileno or self._socket.fileno()
diff --git a/src/LockFile.py b/src/LockFile.py
new file mode 100644
index 00000000..3436eddf
--- /dev/null
+++ b/src/LockFile.py
@@ -0,0 +1,38 @@
+import datetime, os
+from src import PollHook, utils
+
+EXPIRATION = 60 # 1 minute
+
+class LockFile(PollHook.PollHook):
+ def __init__(self, database_location: str):
+ self._lock_location = "%s.lock" % database_location
+ self._next_lock = None
+
+ def available(self):
+ now = utils.datetime_utcnow()
+ if os.path.exists(self._lock_location):
+ with open(self._lock_location, "r") as lock_file:
+ timestamp_str = lock_file.read().strip().split(" ", 1)[0]
+
+ timestamp = utils.iso8601_parse(timestamp_str)
+
+ if (now-timestamp).total_seconds() < EXPIRATION:
+ return False
+
+ return True
+
+ def lock(self):
+ with open(self._lock_location, "w") as lock_file:
+ last_lock = utils.datetime_utcnow()
+ lock_file.write("%s" % utils.iso8601_format(last_lock))
+ self._next_lock = last_lock+datetime.timedelta(
+ seconds=EXPIRATION/2)
+
+ def next(self):
+ return max(0, (self._next_lock-utils.datetime_utcnow()).total_seconds())
+ def call(self):
+ self.lock()
+
+ def unlock(self):
+ if os.path.isfile(self._lock_location):
+ os.remove(self._lock_location)
diff --git a/src/Logging.py b/src/Logging.py
index 2de6de3d..b40ec698 100644
--- a/src/Logging.py
+++ b/src/Logging.py
@@ -15,8 +15,18 @@ class BitBotFormatter(logging.Formatter):
datetime_obj = datetime.datetime.fromtimestamp(record.created)
return utils.iso8601_format(datetime_obj, milliseconds=True)
+class HookedHandler(logging.StreamHandler):
+ def __init__(self, func: typing.Callable[[int, str], None]):
+ logging.StreamHandler.__init__(self)
+ self._func = func
+
+ def emit(self, record):
+ self._func(record.levelno, self.format(record))
+
class Log(object):
def __init__(self, to_file: bool, level: str, location: str):
+ self._hooks = []
+
logging.addLevelName(LEVELS["trace"], "TRACE")
self.logger = logging.getLogger(__name__)
@@ -33,6 +43,11 @@ class Log(object):
stdout_handler.setFormatter(formatter)
self.logger.addHandler(stdout_handler)
+ hook_handler = HookedHandler(self._on_log)
+ hook_handler.setLevel(LEVELS["debug"])
+ hook_handler.setFormatter(formatter)
+ self.logger.addHandler(hook_handler)
+
if to_file:
trace_path = os.path.join(location, "trace.log")
trace_handler = logging.handlers.TimedRotatingFileHandler(
@@ -41,12 +56,25 @@ class Log(object):
trace_handler.setFormatter(formatter)
self.logger.addHandler(trace_handler)
+ info_path = os.path.join(location, "info.log")
+ info_handler = logging.handlers.TimedRotatingFileHandler(
+ info_path, when="midnight", backupCount=0)
+ info_handler.setLevel(LEVELS["info"])
+ info_handler.setFormatter(formatter)
+ self.logger.addHandler(info_handler)
+
warn_path = os.path.join(location, "warn.log")
warn_handler = logging.FileHandler(warn_path)
warn_handler.setLevel(LEVELS["warn"])
warn_handler.setFormatter(formatter)
self.logger.addHandler(warn_handler)
+ def hook(self, func: typing.Callable[[int, str], None]):
+ self._hooks.append(func)
+ def _on_log(self, levelno, line):
+ for func in self._hooks:
+ func(levelno, line)
+
def trace(self, message: str, params: typing.List=None, **kwargs):
self._log(message, params, LEVELS["trace"], kwargs)
def debug(self, message: str, params: typing.List=None, **kwargs):
diff --git a/src/PollSource.py b/src/PollSource.py
new file mode 100644
index 00000000..b549b24e
--- /dev/null
+++ b/src/PollSource.py
@@ -0,0 +1,12 @@
+import typing
+
+class PollSource(object):
+ def get_readables(self) -> typing.List[int]:
+ return []
+ def get_writables(self) -> typing.List[int]:
+ return []
+
+ def is_readable(self, fileno: int):
+ pass
+ def is_writable(self, fileno: int):
+ pass
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
index eaf3cc03..d36112a5 100644
--- a/src/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -15,6 +15,9 @@ ISO8601_FORMAT_TZ = "%z"
DATETIME_HUMAN = "%Y/%m/%d %H:%M:%S"
+def datetime_utcnow() -> datetime.datetime:
+ return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
+
def iso8601_format(dt: datetime.datetime, milliseconds: bool=False) -> str:
dt_format = dt.strftime(ISO8601_FORMAT_DT)
tz_format = dt.strftime(ISO8601_FORMAT_TZ)
@@ -25,8 +28,7 @@ def iso8601_format(dt: datetime.datetime, milliseconds: bool=False) -> str:
return "%s%s%s" % (dt_format, ms_format, tz_format)
def iso8601_format_now(milliseconds: bool=False) -> str:
- now = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
- return iso8601_format(now, milliseconds=milliseconds)
+ return iso8601_format(datetime_utcnow(), milliseconds=milliseconds)
def iso8601_parse(s: str, microseconds: bool=False) -> datetime.datetime:
fmt = ISO8601_PARSE_MICROSECONDS if microseconds else ISO8601_PARSE
return datetime.datetime.strptime(s, fmt)
diff --git a/src/utils/parse.py b/src/utils/parse.py
index 65d0552b..b53b9595 100644
--- a/src/utils/parse.py
+++ b/src/utils/parse.py
@@ -1,4 +1,5 @@
import io, typing
+from src import utils
COMMENT_TYPES = ["#", "//"]
def hashflags(filename: str
@@ -79,3 +80,7 @@ def try_int(s: str) -> typing.Optional[int]:
return int(s)
except ValueError:
return None
+
+def line_normalise(s: str) -> str:
+ lines = list(filter(None, [line.strip() for line in s.split("\n")]))
+ return " ".join(line.replace(" ", " ") for line in lines)