diff options
| author | 2019-06-06 17:05:44 +0100 | |
|---|---|---|
| committer | 2019-06-06 17:05:44 +0100 | |
| commit | a1ebe8035e9e9d13c28db258eeb983c6ee0f7202 (patch) | |
| tree | 171153e8b02dcaabff98a48a4a98b761eb1682d7 | |
| parent | Make `params` arg for logging functions optional (diff) | |
| signature | ||
Split read/write/process in to 3 different threads
| -rw-r--r-- | src/IRCBot.py | 210 | ||||
| -rw-r--r-- | src/IRCServer.py | 18 | ||||
| -rw-r--r-- | src/IRCSocket.py | 16 |
3 files changed, 139 insertions, 105 deletions
diff --git a/src/IRCBot.py b/src/IRCBot.py index f5407e59..3d638fe1 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -1,4 +1,5 @@ -import enum, queue, os, select, socket, threading, time, traceback, typing, uuid +import enum, queue, os, queue, select, socket, threading, time, traceback +import typing, uuid from src import EventManager, Exports, IRCServer, Logging, ModuleManager from src import Socket, utils @@ -24,18 +25,31 @@ class Bot(object): self._timers = timers self.start_time = time.time() - self.lock = threading.Lock() self.running = True - self.poll = select.epoll() - self.servers = {} - self.other_sockets = {} - self._trigger_server, self._trigger_client = socket.socketpair() - self.add_socket(Socket.Socket(self._trigger_server, lambda _, s: None)) + + self._event_queue = queue.Queue() + + self._read_poll = select.epoll() + self._write_poll = select.epoll() + + self._rtrigger_server, self._rtrigger_client = socket.socketpair() + self._read_poll.register(self._rtrigger_server.fileno(), select.EPOLLIN) + + self._wtrigger_server, self._wtrigger_client = socket.socketpair() + self._write_poll.register(self._wtrigger_server.fileno(), + select.EPOLLIN) + + self._read_thread = None + self._write_thread = None self._trigger_functions = [] self._events.on("timer.reconnect").hook(self._timed_reconnect) + def _thread_trigger(self): + self._rtrigger_client.send(b"TRIGGER") + self._wtrigger_client.send(b"TRIGGER") + def trigger(self, func: typing.Optional[typing.Callable[[], typing.Any]]=None ) -> typing.Any: @@ -43,18 +57,25 @@ class Bot(object): if utils.is_main_thread(): returned = func() - self._trigger_client.send(b"TRIGGER") + self._thread_trigger() return returned - self.lock.acquire() - func_queue = queue.Queue(1) # type: queue.Queue[str] - self._trigger_functions.append([func, func_queue]) - self.lock.release() - self._trigger_client.send(b"TRIGGER") + def _action(): + try: + returned = func() + type = TriggerResult.Return + except Exception as e: + returned = e + type = TriggerResult.Exception + func_queue.put([type, returned]) + self._event_queue.put(_action) type, returned = func_queue.get(block=True) + + self._thread_trigger() + if type == TriggerResult.Exception: raise returned elif type == TriggerResult.Return: @@ -95,14 +116,6 @@ class Bot(object): return new_server - def add_socket(self, sock: socket.socket): - self.other_sockets[sock.fileno()] = sock - self.poll.register(sock.fileno(), select.EPOLLIN) - - def remove_socket(self, sock: socket.socket): - del self.other_sockets[sock.fileno()] - self.poll.unregister(sock.fileno()) - def get_server_by_id(self, id: int) -> typing.Optional[IRCServer.Server]: for server in self.servers.values(): if server.id == id: @@ -123,14 +136,14 @@ class Bot(object): [str(server), str(e)]) return False self.servers[server.fileno()] = server - self.poll.register(server.fileno(), select.EPOLLOUT) + self._read_poll.register(server.fileno(), select.EPOLLIN) return True def next_send(self) -> typing.Optional[float]: next = None for server in self.servers.values(): timeout = server.socket.send_throttle_timeout() - if (server.socket.waiting_send() and + if (server.socket.waiting_throttled_send() and (next == None or timeout < next)): next = timeout return next @@ -162,17 +175,13 @@ class Bot(object): timeouts.append(self.cache.next_expiration()) return min([timeout for timeout in timeouts if not timeout == None]) - def register_read(self, server: IRCServer.Server): - self.poll.modify(server.fileno(), select.EPOLLIN) - def register_write(self, server: IRCServer.Server): - self.poll.modify(server.fileno(), select.EPOLLOUT) - def register_both(self, server: IRCServer.Server): - self.poll.modify(server.fileno(), - select.EPOLLIN|select.EPOLLOUT) - def disconnect(self, server: IRCServer.Server): try: - self.poll.unregister(server.fileno()) + self._read_poll.unregister(server.fileno()) + except FileNotFoundError: + pass + try: + self._write_poll.unregister(server.fileno()) except FileNotFoundError: pass del self.servers[server.fileno()] @@ -208,13 +217,54 @@ class Bot(object): def del_setting(self, setting: str): self.database.bot_settings.delete(setting) + def _daemon_thread(self, target: typing.Callable[[], None]): + thread = threading.Thread(target=target) + thread.daemon = True + thread.start() + return thread + def run(self): + self._read_thread = self._daemon_thread(self._read_loop) + self._write_thread = self._daemon_thread(self._write_loop) + self._event_loop() + + def _event_loop(self): + while self.running: + item = self._event_queue.get(block=True, timeout=None) + item() + + def _write_loop(self): while self.running: if not self.servers: break - events = self.poll.poll(self.get_poll_timeout()) - self.lock.acquire() + for fd, server in self.servers.items(): + if server.socket.waiting_immediate_send(): + self._write_poll.register(fd, select.EPOLLOUT) + + events = self._write_poll.poll() + for fd, event in events: + if fd == self._wtrigger_server.fileno(): + # throw away data from trigger socket + self._wtrigger_server.recv(1024) + elif event & select.EPOLLOUT: + self._write_poll.unregister(fd) + server = self.servers[fd] + + try: + lines = server.socket._send() + except: + self.log.error("Failed to write to %s", [str(server)]) + raise + self._event_queue.put(lambda: server._post_send(lines)) + + def _read_loop(self): + while self.running: + if not self.servers: + self.running = False + break + + events = self._read_poll.poll(self.get_poll_timeout()) self._timers.call() self.cache.expire() @@ -229,64 +279,48 @@ class Bot(object): self._trigger_functions.clear() for fd, event in events: - sock = None - irc = False - if fd in self.servers: - sock = self.servers[fd] - irc = True - elif fd in self.other_sockets: - sock = self.other_sockets[fd] - - if sock: + if fd == self._rtrigger_server.fileno(): + # throw away data from trigger socket + self._rtrigger_server.recv(1024) + else: + server = self.servers[fd] if event & select.EPOLLIN: - data = sock.read() - if data == None: - sock.disconnect() + lines = server.read() + if lines == None: + server.disconnect() continue - for piece in data: - sock.parse_data(piece) - elif event & select.EPOLLOUT: - try: - sock._send() - except: - self.log.error("Failed to write to %s", - [str(sock)]) - raise - - if sock.fileno() in self.servers: - self.register_read(sock) - elif event & select.EPULLHUP: - self.log.warn("Recieved EPOLLHUP for %s", [str(sock)]) - sock.disconnect() + self._event_queue.put(lambda: server._post_read(lines)) + elif event & select.EPOLLHUP: + self.log.warn("Recieved EPOLLHUP for %s", [str(server)]) + server.disconnect() - for server in list(self.servers.values()): - if server.read_timed_out(): - self.log.warn("Pinged out from %s", [str(server)]) - server.disconnect() - elif server.ping_due() and not server.ping_sent: - server.send_ping() - server.ping_sent = True - if not server.socket.connected: - self._events.on("server.disconnect").call(server=server) - self.disconnect(server) + self.trigger(self._check_servers) - if not self.get_server_by_id(server.id): - reconnect_delay = self.config.get("reconnect-delay", 10) - self._timers.add("reconnect", reconnect_delay, - server_id=server.id) - self.log.warn( - "Disconnected from %s, reconnecting in %d seconds", - [str(server), reconnect_delay]) - elif server.socket.waiting_immediate_send() or ( - server.socket.waiting_send() and - server.socket.throttle_done()): - self.register_both(server) + def _check_servers(self): + throttle_filled = False + for server in list(self.servers.values()): + if server.read_timed_out(): + self.log.warn("Pinged out from %s", [str(server)]) + server.disconnect() + elif server.ping_due() and not server.ping_sent: + server.send_ping() + server.ping_sent = True + if not server.socket.connected: + self._events.on("server.disconnect").call(server=server) + self.disconnect(server) - for sock in list(self.other_sockets.values()): - if not sock.connected: - self.remove_socket(sock) - elif sock.waiting_send(): - self.register_both(sock) + if not self.get_server_by_id(server.id): + reconnect_delay = self.config.get("reconnect-delay", 10) + self._timers.add("reconnect", reconnect_delay, + server_id=server.id) + self.log.warn( + "Disconnected from %s, reconnecting in %d seconds", + [str(server), reconnect_delay]) + elif (server.socket.waiting_throttled_send() and + server.socket.throttle_done()): + server.socket._fill_throttle() + throttle_filled = True - self.lock.release() + if throttle_filled: + self._wtrigger_client.send(b"TRIGGER") diff --git a/src/IRCServer.py b/src/IRCServer.py index a8bd12ea..e037eca7 100644 --- a/src/IRCServer.py +++ b/src/IRCServer.py @@ -207,14 +207,12 @@ class Server(IRCObject.Object): return utils.irc.hostmask_match(self.irc_lower(hostmask), self.irc_lower(pattern)) - def parse_data(self, line: str): - if not line: - return - - self.bot.log.debug("%s (raw recv) | %s", [str(self), line]) - self.events.on("raw.received").call_unsafe(server=self, - line=utils.irc.parse_line(line)) - self.check_users() + def _post_read(self, lines: typing.List[str]): + for line in lines: + self.bot.log.debug("%s (raw recv) | %s", [str(self), line]) + self.events.on("raw.received").call_unsafe(server=self, + line=utils.irc.parse_line(line)) + self.check_users() def check_users(self): for user in self.new_users: if not len(user.channels): @@ -261,8 +259,8 @@ class Server(IRCObject.Object): def send_raw(self, line: str): return self.send(utils.irc.parse_line(line)) - def _send(self): - lines = self.socket._send() + + def _post_send(self, lines: typing.List[IRCLine.SentLine]): for line in lines: self.bot.log.debug("%s (raw send) | %s", [ str(self), line.parsed_line.format()]) diff --git a/src/IRCSocket.py b/src/IRCSocket.py index 3b3395e8..0223a7a9 100644 --- a/src/IRCSocket.py +++ b/src/IRCSocket.py @@ -129,12 +129,8 @@ class Socket(IRCObject.Object): else: self._queued_lines.append(line) - def _send(self) -> typing.List[IRCLine.SentLine]: - if not self._write_buffer and self._throttle_when_empty: - self._throttle_when_empty = False - self._write_throttling = True - self._recent_sends.clear() + def _fill_throttle(self): throttle_space = self.throttle_space() if throttle_space: to_buffer = self._queued_lines[:throttle_space] @@ -142,6 +138,12 @@ class Socket(IRCObject.Object): for line in to_buffer: self._immediate_buffer(line) + def _send(self) -> typing.List[IRCLine.SentLine]: + if not self._write_buffer and self._throttle_when_empty: + self._throttle_when_empty = False + self._write_throttling = True + self._recent_sends.clear() + bytes_written_i = self._socket.send(self._write_buffer) bytes_written = self._write_buffer[:bytes_written_i] @@ -165,8 +167,8 @@ class Socket(IRCObject.Object): def clear_send_buffer(self): self._queued_lines.clear() - def waiting_send(self) -> bool: - return bool(len(self._write_buffer)) or bool(len(self._queued_lines)) + def waiting_throttled_send(self) -> bool: + return bool(len(self._queued_lines)) def waiting_immediate_send(self) -> bool: return bool(len(self._write_buffer)) |
