diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Control.py | 109 | ||||
| -rw-r--r-- | src/IRCBot.py | 60 | ||||
| -rw-r--r-- | src/IRCSocket.py | 2 | ||||
| -rw-r--r-- | src/LockFile.py | 38 | ||||
| -rw-r--r-- | src/Logging.py | 28 | ||||
| -rw-r--r-- | src/PollSource.py | 12 | ||||
| -rw-r--r-- | src/utils/__init__.py | 6 | ||||
| -rw-r--r-- | src/utils/parse.py | 5 |
8 files changed, 241 insertions, 19 deletions
diff --git a/src/Control.py b/src/Control.py new file mode 100644 index 00000000..5b59bfec --- /dev/null +++ b/src/Control.py @@ -0,0 +1,109 @@ +import json, os, socket, typing +from src import IRCBot, Logging, PollSource + +class ControlClient(object): + def __init__(self, sock: socket.socket): + self._socket = sock + self._read_buffer = b"" + self._write_buffer = b"" + self.version = -1 + self.log_level = None # type: typing.Optional[int] + + def fileno(self) -> int: + return self._socket.fileno() + + def read_lines(self) -> typing.List[str]: + try: + data = self._socket.recv(2048) + except: + data = b"" + if not data: + return None + lines = (self._read_buffer+data).split(b"\n") + lines = [line.strip(b"\r") for line in lines] + self._read_buffer = lines.pop(-1) + return [line.decode("utf8") for line in lines] + + def write_line(self, line: str): + self._socket.send(("%s\n" % line).encode("utf8")) + + def disconnect(self): + try: + self._socket.shutdown(socket.SHUT_RDWR) + except: + pass + try: + self._socket.close() + except: + pass + + +class Control(PollSource.PollSource): + def __init__(self, bot: IRCBot.Bot, database_location: str): + self._bot = bot + self._bot.log.hook(self._on_log) + + self._socket_location = "%s.sock" % database_location + self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self._clients = {} + + def _on_log(self, levelno: int, line: str): + for client in self._clients.values(): + if not client.log_level == None and client.log_level <= levelno: + self._send_action(client, "log", line) + + def bind(self): + if os.path.exists(self._socket_location): + os.remove(self._socket_location) + self._socket.bind(self._socket_location) + self._socket.listen(1) + + def get_readables(self) -> typing.List[int]: + return [self._socket.fileno()]+list(self._clients.keys()) + + def is_readable(self, fileno: int): + if fileno == self._socket.fileno(): + client, address = self._socket.accept() + self._clients[client.fileno()] = ControlClient(client) + self._bot.log.debug("New control socket connected") + elif fileno in self._clients: + client = self._clients[fileno] + lines = client.read_lines() + if lines == None: + client.disconnect() + del self._clients[fileno] + else: + for line in lines: + response = self._parse_line(client, line) + + def _parse_line(self, client: ControlClient, line: str): + id, _, command = line.partition(" ") + command, _, data = command.partition(" ") + if not id or not command: + client.disconnect() + return + + command = command.lower() + response_action = "ack" + response_data = None + + keepalive = True + + if command == "version": + client.version = int(data) + elif command == "log": + client.log_level = Logging.LEVELS[data.lower()] + elif command == "rehash": + self._bot.log.info("Reloading config file") + self._bot.config.load() + self._bot.log.info("Reloaded config file") + keepalive = False + + self._send_action(client, response_action, response_data, id) + if not keepalive: + client.disconnect() + + def _send_action(self, client: ControlClient, action: str, data: str, + id: int=None): + client.write_line( + json.dumps({"action": action, "data": data, "id": id})) diff --git a/src/IRCBot.py b/src/IRCBot.py index df0b0540..8eb1873e 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -1,9 +1,10 @@ import enum, queue, os, queue, select, socket, sys, threading, time, traceback import typing, uuid from src import EventManager, Exports, IRCServer, Logging, ModuleManager -from src import PollHook, Socket, Timers, utils +from src import PollHook, PollSource, Socket, Timers, utils -VERSION = "v1.12.0-rc1" +with open("VERSION", "r") as version_file: + VERSION = "v%s" % version_file.read().strip() SOURCE = "https://git.io/bitbot" URL = "https://bitbot.dev" @@ -68,10 +69,7 @@ class Bot(object): self._read_thread = None self._write_thread = None - self._poll_timeouts = [] # typing.List[PollHook] - self._poll_timeouts.append(self._timers) - self._poll_timeouts.append(self.cache) - + self._poll_timeouts = [] # typing.List[PollHook.PollHook] self._poll_timeouts.append(ListLambdaPollHook( lambda: self.servers.values(), lambda server: server.until_read_timeout())) @@ -83,6 +81,13 @@ class Bot(object): self._poll_timeouts.append(ListLambdaPollHook( lambda: self.servers.values(), self._throttle_timeout)) + self._poll_sources = [] # typing.List[PollSource.PollSource] + + def add_poll_hook(self, hook: PollHook.PollHook): + self._poll_timeouts.append(hook) + def add_poll_source(self, source: PollSource.PollSource): + self._poll_sources.append(source) + def _throttle_timeout(self, server: IRCServer.Server): if server.socket.waiting_throttled_send(): return server.socket.send_throttle_timeout() @@ -276,11 +281,6 @@ class Bot(object): def _event_loop(self): while self.running or not self._event_queue.empty(): - if not self.servers and self._event_queue.empty(): - self._kill() - self.log.warn("No servers, exiting") - break - try: item = self._event_queue.get(block=True, timeout=self.get_poll_timeout()) @@ -316,16 +316,24 @@ class Bot(object): def _write_loop(self): while self.running: + poll_sources = {} with self._write_condition: - writeable = False + fds = [] for fd, server in self.servers.items(): if server.socket.waiting_immediate_send(): - self._write_poll.register(fd, select.POLLOUT) - writeable = True + fds.append(fd) + + for poll_source in self._poll_sources: + for fileno in poll_source.get_writables(): + poll_sources[fileno] = poll_source + fds.append(fileno) - if not writeable: + if not fds: self._write_condition.wait() continue + else: + for fd in fds: + self._write_poll.register(fd, select.POLLOUT) events = self._write_poll.poll() @@ -344,9 +352,25 @@ class Bot(object): event_item = TriggerEvent(TriggerEventType.Action, self._post_send_factory(server, lines)) self._event_queue.put(event_item) + elif fd in poll_sources: + poll_sources[fd].is_writeable(fd) def _read_loop(self): + poll_sources = {} while self.running: + new_poll_sources = {} + for poll_source in self._poll_sources: + for fileno in poll_source.get_readables(): + new_poll_sources[fileno] = poll_source + for fileno in new_poll_sources: + if not fileno in poll_sources: + poll_sources[fileno] = new_poll_sources[fileno] + self._read_poll.register(fileno, select.POLLIN) + for fileno in list(poll_sources.keys()): + if not fileno in new_poll_sources: + del poll_sources[fileno] + self._read_poll.unregister(fileno) + events = self._read_poll.poll() for fd, event in events: @@ -355,6 +379,9 @@ class Bot(object): with self._rtrigger_lock: self._rtrigger_server.recv(1024) self._rtriggered = False + elif fd in poll_sources: + poll_sources[fd].is_readable(fd) + self.trigger_write() else: if not fd in self.servers: self._read_poll.unregister(fd) @@ -376,7 +403,8 @@ class Bot(object): def _check(self): for poll_timeout in self._poll_timeouts: - poll_timeout.call() + if poll_timeout.next() == 0: + poll_timeout.call() throttle_filled = False for server in list(self.servers.values()): diff --git a/src/IRCSocket.py b/src/IRCSocket.py index 98091aff..7fbae39d 100644 --- a/src/IRCSocket.py +++ b/src/IRCSocket.py @@ -44,7 +44,7 @@ class Socket(IRCObject.Object): self.last_send = None # type: typing.Optional[float] self.connected_ip = None # type: typing.Optional[str] - self.conncect_time: float = -1 + self.connect_time: float = -1 def fileno(self) -> int: return self.cached_fileno or self._socket.fileno() diff --git a/src/LockFile.py b/src/LockFile.py new file mode 100644 index 00000000..3436eddf --- /dev/null +++ b/src/LockFile.py @@ -0,0 +1,38 @@ +import datetime, os +from src import PollHook, utils + +EXPIRATION = 60 # 1 minute + +class LockFile(PollHook.PollHook): + def __init__(self, database_location: str): + self._lock_location = "%s.lock" % database_location + self._next_lock = None + + def available(self): + now = utils.datetime_utcnow() + if os.path.exists(self._lock_location): + with open(self._lock_location, "r") as lock_file: + timestamp_str = lock_file.read().strip().split(" ", 1)[0] + + timestamp = utils.iso8601_parse(timestamp_str) + + if (now-timestamp).total_seconds() < EXPIRATION: + return False + + return True + + def lock(self): + with open(self._lock_location, "w") as lock_file: + last_lock = utils.datetime_utcnow() + lock_file.write("%s" % utils.iso8601_format(last_lock)) + self._next_lock = last_lock+datetime.timedelta( + seconds=EXPIRATION/2) + + def next(self): + return max(0, (self._next_lock-utils.datetime_utcnow()).total_seconds()) + def call(self): + self.lock() + + def unlock(self): + if os.path.isfile(self._lock_location): + os.remove(self._lock_location) diff --git a/src/Logging.py b/src/Logging.py index 2de6de3d..b40ec698 100644 --- a/src/Logging.py +++ b/src/Logging.py @@ -15,8 +15,18 @@ class BitBotFormatter(logging.Formatter): datetime_obj = datetime.datetime.fromtimestamp(record.created) return utils.iso8601_format(datetime_obj, milliseconds=True) +class HookedHandler(logging.StreamHandler): + def __init__(self, func: typing.Callable[[int, str], None]): + logging.StreamHandler.__init__(self) + self._func = func + + def emit(self, record): + self._func(record.levelno, self.format(record)) + class Log(object): def __init__(self, to_file: bool, level: str, location: str): + self._hooks = [] + logging.addLevelName(LEVELS["trace"], "TRACE") self.logger = logging.getLogger(__name__) @@ -33,6 +43,11 @@ class Log(object): stdout_handler.setFormatter(formatter) self.logger.addHandler(stdout_handler) + hook_handler = HookedHandler(self._on_log) + hook_handler.setLevel(LEVELS["debug"]) + hook_handler.setFormatter(formatter) + self.logger.addHandler(hook_handler) + if to_file: trace_path = os.path.join(location, "trace.log") trace_handler = logging.handlers.TimedRotatingFileHandler( @@ -41,12 +56,25 @@ class Log(object): trace_handler.setFormatter(formatter) self.logger.addHandler(trace_handler) + info_path = os.path.join(location, "info.log") + info_handler = logging.handlers.TimedRotatingFileHandler( + info_path, when="midnight", backupCount=0) + info_handler.setLevel(LEVELS["info"]) + info_handler.setFormatter(formatter) + self.logger.addHandler(info_handler) + warn_path = os.path.join(location, "warn.log") warn_handler = logging.FileHandler(warn_path) warn_handler.setLevel(LEVELS["warn"]) warn_handler.setFormatter(formatter) self.logger.addHandler(warn_handler) + def hook(self, func: typing.Callable[[int, str], None]): + self._hooks.append(func) + def _on_log(self, levelno, line): + for func in self._hooks: + func(levelno, line) + def trace(self, message: str, params: typing.List=None, **kwargs): self._log(message, params, LEVELS["trace"], kwargs) def debug(self, message: str, params: typing.List=None, **kwargs): diff --git a/src/PollSource.py b/src/PollSource.py new file mode 100644 index 00000000..b549b24e --- /dev/null +++ b/src/PollSource.py @@ -0,0 +1,12 @@ +import typing + +class PollSource(object): + def get_readables(self) -> typing.List[int]: + return [] + def get_writables(self) -> typing.List[int]: + return [] + + def is_readable(self, fileno: int): + pass + def is_writable(self, fileno: int): + pass diff --git a/src/utils/__init__.py b/src/utils/__init__.py index eaf3cc03..d36112a5 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -15,6 +15,9 @@ ISO8601_FORMAT_TZ = "%z" DATETIME_HUMAN = "%Y/%m/%d %H:%M:%S" +def datetime_utcnow() -> datetime.datetime: + return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) + def iso8601_format(dt: datetime.datetime, milliseconds: bool=False) -> str: dt_format = dt.strftime(ISO8601_FORMAT_DT) tz_format = dt.strftime(ISO8601_FORMAT_TZ) @@ -25,8 +28,7 @@ def iso8601_format(dt: datetime.datetime, milliseconds: bool=False) -> str: return "%s%s%s" % (dt_format, ms_format, tz_format) def iso8601_format_now(milliseconds: bool=False) -> str: - now = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) - return iso8601_format(now, milliseconds=milliseconds) + return iso8601_format(datetime_utcnow(), milliseconds=milliseconds) def iso8601_parse(s: str, microseconds: bool=False) -> datetime.datetime: fmt = ISO8601_PARSE_MICROSECONDS if microseconds else ISO8601_PARSE return datetime.datetime.strptime(s, fmt) diff --git a/src/utils/parse.py b/src/utils/parse.py index 65d0552b..b53b9595 100644 --- a/src/utils/parse.py +++ b/src/utils/parse.py @@ -1,4 +1,5 @@ import io, typing +from src import utils COMMENT_TYPES = ["#", "//"] def hashflags(filename: str @@ -79,3 +80,7 @@ def try_int(s: str) -> typing.Optional[int]: return int(s) except ValueError: return None + +def line_normalise(s: str) -> str: + lines = list(filter(None, [line.strip() for line in s.split("\n")])) + return " ".join(line.replace(" ", " ") for line in lines) |
