aboutsummaryrefslogtreecommitdiff
path: root/src/IRCBot.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/IRCBot.py')
-rw-r--r--src/IRCBot.py251
1 files changed, 150 insertions, 101 deletions
diff --git a/src/IRCBot.py b/src/IRCBot.py
index 9d47b33e..de2042fb 100644
--- a/src/IRCBot.py
+++ b/src/IRCBot.py
@@ -1,4 +1,5 @@
-import enum, queue, os, select, socket, threading, time, traceback, typing, uuid
+import enum, queue, os, queue, select, socket, threading, time, traceback
+import typing, uuid
from src import EventManager, Exports, IRCServer, Logging, ModuleManager
from src import Socket, utils
@@ -24,37 +25,66 @@ class Bot(object):
self._timers = timers
self.start_time = time.time()
- self.lock = threading.Lock()
self.running = True
- self.poll = select.poll()
-
self.servers = {}
- self.other_sockets = {}
- self._trigger_server, self._trigger_client = socket.socketpair()
- self.add_socket(Socket.Socket(self._trigger_server, lambda _, s: None))
- self._trigger_functions = []
+ self._event_queue = queue.Queue()
+
+ self._read_poll = select.poll()
+ self._write_poll = select.poll()
+
+ self._rtrigger_server, self._rtrigger_client = socket.socketpair()
+ self._read_poll.register(self._rtrigger_server.fileno(), select.POLLIN)
+
+ self._rtrigger_lock = threading.Lock()
+ self._rtriggered = False
+ self._write_condition = threading.Condition()
+
+ self._read_thread = None
+ self._write_thread = None
+
self._events.on("timer.reconnect").hook(self._timed_reconnect)
+ def _trigger_both(self):
+ self.trigger_read()
+ self.trigger_write()
+ def trigger_read(self):
+ with self._rtrigger_lock:
+ if not self._rtriggered:
+ self._rtriggered = True
+ self._rtrigger_client.send(b"TRIGGER")
+ def trigger_write(self):
+ with self._write_condition:
+ self._write_condition.notify()
+
def trigger(self,
- func: typing.Optional[typing.Callable[[], typing.Any]]=None
- ) -> typing.Any:
+ func: typing.Optional[typing.Callable[[], typing.Any]]=None,
+ trigger_threads=True) -> typing.Any:
func = func or (lambda: None)
if utils.is_main_thread():
returned = func()
- self._trigger_client.send(b"TRIGGER")
+ if trigger_threads:
+ self._trigger_both()
return returned
- self.lock.acquire()
-
func_queue = queue.Queue(1) # type: queue.Queue[str]
- self._trigger_functions.append([func, func_queue])
- self.lock.release()
- self._trigger_client.send(b"TRIGGER")
+ def _action():
+ try:
+ returned = func()
+ type = TriggerResult.Return
+ except Exception as e:
+ returned = e
+ type = TriggerResult.Exception
+ func_queue.put([type, returned])
+ self._event_queue.put(_action)
type, returned = func_queue.get(block=True)
+
+ if trigger_threads:
+ self._trigger_both()
+
if type == TriggerResult.Exception:
raise returned
elif type == TriggerResult.Return:
@@ -95,14 +125,6 @@ class Bot(object):
return new_server
- def add_socket(self, sock: socket.socket):
- self.other_sockets[sock.fileno()] = sock
- self.poll.register(sock.fileno(), select.POLLIN)
-
- def remove_socket(self, sock: socket.socket):
- del self.other_sockets[sock.fileno()]
- self.poll.unregister(sock.fileno())
-
def get_server_by_id(self, id: int) -> typing.Optional[IRCServer.Server]:
for server in self.servers.values():
if server.id == id:
@@ -123,14 +145,14 @@ class Bot(object):
[str(server), str(e)])
return False
self.servers[server.fileno()] = server
- self.poll.register(server.fileno(), select.POLLOUT)
+ self._read_poll.register(server.fileno(), select.POLLIN)
return True
def next_send(self) -> typing.Optional[float]:
next = None
for server in self.servers.values():
timeout = server.socket.send_throttle_timeout()
- if (server.socket.waiting_send() and
+ if (server.socket.waiting_throttled_send() and
(next == None or timeout < next)):
next = timeout
return next
@@ -163,20 +185,9 @@ class Bot(object):
min_secs = min([timeout for timeout in timeouts if not timeout == None])
return min_secs*1000 # return milliseconds
- def register_read(self, server: IRCServer.Server):
- self.poll.modify(server.fileno(), select.POLLIN)
- def register_write(self, server: IRCServer.Server):
- self.poll.modify(server.fileno(), select.POLLOUT)
- def register_both(self, server: IRCServer.Server):
- self.poll.modify(server.fileno(),
- select.POLLIN|select.POLLOUT)
-
def disconnect(self, server: IRCServer.Server):
- try:
- self.poll.unregister(server.fileno())
- except FileNotFoundError:
- pass
del self.servers[server.fileno()]
+ self._trigger_both()
def _timed_reconnect(self, event: EventManager.Event):
if not self.reconnect(event["server_id"],
@@ -209,85 +220,123 @@ class Bot(object):
def del_setting(self, setting: str):
self.database.bot_settings.delete(setting)
+ def _daemon_thread(self, target: typing.Callable[[], None]):
+ thread = threading.Thread(target=target)
+ thread.daemon = True
+ thread.start()
+ return thread
+
def run(self):
+ self._read_thread = self._daemon_thread(self._read_loop)
+ self._write_thread = self._daemon_thread(self._write_loop)
+ self._event_loop()
+
+ def _event_loop(self):
+ while self.running:
+ item = self._event_queue.get(block=True, timeout=None)
+ item()
+
+ def _post_send_factory(self, server, lines):
+ return lambda: server._post_send(lines)
+ def _post_read_factory(self, server, lines):
+ return lambda: server._post_read(lines)
+
+ def _write_loop(self):
while self.running:
if not self.servers:
break
- events = self.poll.poll(self.get_poll_timeout())
- self.lock.acquire()
- self._timers.call()
- self.cache.expire()
+ with self._write_condition:
+ writeable = False
+ for fd, server in self.servers.items():
+ if server.socket.waiting_immediate_send():
+ self._write_poll.register(fd, select.POLLOUT)
+ writeable = True
- for func, func_queue in self._trigger_functions:
- try:
- returned = func()
- type = TriggerResult.Return
- except Exception as e:
- returned = e
- type = TriggerResult.Exception
- func_queue.put([type, returned])
- self._trigger_functions.clear()
+ if not writeable:
+ self._write_condition.wait()
+ continue
- for fd, event in events:
- sock = None
- irc = False
- if fd in self.servers:
- sock = self.servers[fd]
- irc = True
- elif fd in self.other_sockets:
- sock = self.other_sockets[fd]
+ events = self._write_poll.poll()
- if sock:
- if event & select.POLLIN:
- data = sock.read()
- if data == None:
- sock.disconnect()
- continue
+ for fd, event in events:
+ if event & select.POLLOUT:
+ self._write_poll.unregister(fd)
+ if fd in self.servers:
+ server = self.servers[fd]
- for piece in data:
- sock.parse_data(piece)
- elif event & select.POLLOUT:
try:
- sock._send()
+ lines = server._send()
except:
self.log.error("Failed to write to %s",
- [str(sock)])
+ [str(server)])
raise
+ self._event_queue.put(self._post_send_factory(server,
+ lines))
+
+ def _read_loop(self):
+ while self.running:
+ if not self.servers:
+ self.running = False
+ self._event_queue.put(lambda: None)
+ break
- if sock.fileno() in self.servers:
- self.register_read(sock)
+ events = self._read_poll.poll(self.get_poll_timeout())
+
+ self.trigger(self._check, False)
+
+ for fd, event in events:
+ if fd == self._rtrigger_server.fileno():
+ # throw away data from trigger socket
+ self._rtrigger_server.recv(1024)
+ with self._rtrigger_lock:
+ self._rtriggered = False
+ else:
+ if not fd in self.servers:
+ self._read_poll.unregister(fd)
+ continue
+
+ server = self.servers[fd]
+ if event & select.POLLIN:
+ lines = server.read()
+ if lines == None:
+ server.disconnect()
+ continue
+
+ self.trigger(self._post_read_factory(server, lines),
+ False)
elif event & select.POLLHUP:
- self.log.warn("Recieved POLLHUP for %s", [str(sock)])
- sock.disconnect()
+ self.log.warn("Recieved POLLHUP for %s", [str(server)])
+ server.disconnect()
+
+ def _check(self):
+ self._timers.call()
+ self.cache.expire()
- for server in list(self.servers.values()):
- if server.read_timed_out():
- self.log.warn("Pinged out from %s", [str(server)])
- server.disconnect()
- elif server.ping_due() and not server.ping_sent:
- server.send_ping()
- server.ping_sent = True
- if not server.socket.connected:
- self._events.on("server.disconnect").call(server=server)
- self.disconnect(server)
+ throttle_filled = False
+ for server in list(self.servers.values()):
+ if server.read_timed_out():
+ self.log.warn("Pinged out from %s", [str(server)])
+ server.disconnect()
+ elif server.ping_due() and not server.ping_sent:
+ server.send_ping()
+ server.ping_sent = True
- if not self.get_server_by_id(server.id):
- reconnect_delay = self.config.get("reconnect-delay", 10)
- self._timers.add("reconnect", reconnect_delay,
- server_id=server.id)
- self.log.warn(
- "Disconnected from %s, reconnecting in %d seconds",
- [str(server), reconnect_delay])
- elif server.socket.waiting_immediate_send() or (
- server.socket.waiting_send() and
- server.socket.throttle_done()):
- self.register_both(server)
+ if not server.socket.connected:
+ self._events.on("server.disconnect").call(server=server)
+ self.disconnect(server)
- for sock in list(self.other_sockets.values()):
- if not sock.connected:
- self.remove_socket(sock)
- elif sock.waiting_send():
- self.register_both(sock)
+ if not self.get_server_by_id(server.id):
+ reconnect_delay = self.config.get("reconnect-delay", 10)
+ self._timers.add("reconnect", reconnect_delay,
+ server_id=server.id)
+ self.log.warn(
+ "Disconnected from %s, reconnecting in %d seconds",
+ [str(server), reconnect_delay])
+ elif (server.socket.waiting_throttled_send() and
+ server.socket.throttle_done()):
+ server.socket._fill_throttle()
+ throttle_filled = True
- self.lock.release()
+ if throttle_filled:
+ self.trigger_write()