aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar jesopo2019-06-06 17:05:44 +0100
committerGravatar jesopo2019-06-06 17:05:44 +0100
commita1ebe8035e9e9d13c28db258eeb983c6ee0f7202 (patch)
tree171153e8b02dcaabff98a48a4a98b761eb1682d7
parentMake `params` arg for logging functions optional (diff)
signature
Split read/write/process in to 3 different threads
-rw-r--r--src/IRCBot.py210
-rw-r--r--src/IRCServer.py18
-rw-r--r--src/IRCSocket.py16
3 files changed, 139 insertions, 105 deletions
diff --git a/src/IRCBot.py b/src/IRCBot.py
index f5407e59..3d638fe1 100644
--- a/src/IRCBot.py
+++ b/src/IRCBot.py
@@ -1,4 +1,5 @@
-import enum, queue, os, select, socket, threading, time, traceback, typing, uuid
+import enum, queue, os, queue, select, socket, threading, time, traceback
+import typing, uuid
from src import EventManager, Exports, IRCServer, Logging, ModuleManager
from src import Socket, utils
@@ -24,18 +25,31 @@ class Bot(object):
self._timers = timers
self.start_time = time.time()
- self.lock = threading.Lock()
self.running = True
- self.poll = select.epoll()
-
self.servers = {}
- self.other_sockets = {}
- self._trigger_server, self._trigger_client = socket.socketpair()
- self.add_socket(Socket.Socket(self._trigger_server, lambda _, s: None))
+
+ self._event_queue = queue.Queue()
+
+ self._read_poll = select.epoll()
+ self._write_poll = select.epoll()
+
+ self._rtrigger_server, self._rtrigger_client = socket.socketpair()
+ self._read_poll.register(self._rtrigger_server.fileno(), select.EPOLLIN)
+
+ self._wtrigger_server, self._wtrigger_client = socket.socketpair()
+ self._write_poll.register(self._wtrigger_server.fileno(),
+ select.EPOLLIN)
+
+ self._read_thread = None
+ self._write_thread = None
self._trigger_functions = []
self._events.on("timer.reconnect").hook(self._timed_reconnect)
+ def _thread_trigger(self):
+ self._rtrigger_client.send(b"TRIGGER")
+ self._wtrigger_client.send(b"TRIGGER")
+
def trigger(self,
func: typing.Optional[typing.Callable[[], typing.Any]]=None
) -> typing.Any:
@@ -43,18 +57,25 @@ class Bot(object):
if utils.is_main_thread():
returned = func()
- self._trigger_client.send(b"TRIGGER")
+ self._thread_trigger()
return returned
- self.lock.acquire()
-
func_queue = queue.Queue(1) # type: queue.Queue[str]
- self._trigger_functions.append([func, func_queue])
- self.lock.release()
- self._trigger_client.send(b"TRIGGER")
+ def _action():
+ try:
+ returned = func()
+ type = TriggerResult.Return
+ except Exception as e:
+ returned = e
+ type = TriggerResult.Exception
+ func_queue.put([type, returned])
+ self._event_queue.put(_action)
type, returned = func_queue.get(block=True)
+
+ self._thread_trigger()
+
if type == TriggerResult.Exception:
raise returned
elif type == TriggerResult.Return:
@@ -95,14 +116,6 @@ class Bot(object):
return new_server
- def add_socket(self, sock: socket.socket):
- self.other_sockets[sock.fileno()] = sock
- self.poll.register(sock.fileno(), select.EPOLLIN)
-
- def remove_socket(self, sock: socket.socket):
- del self.other_sockets[sock.fileno()]
- self.poll.unregister(sock.fileno())
-
def get_server_by_id(self, id: int) -> typing.Optional[IRCServer.Server]:
for server in self.servers.values():
if server.id == id:
@@ -123,14 +136,14 @@ class Bot(object):
[str(server), str(e)])
return False
self.servers[server.fileno()] = server
- self.poll.register(server.fileno(), select.EPOLLOUT)
+ self._read_poll.register(server.fileno(), select.EPOLLIN)
return True
def next_send(self) -> typing.Optional[float]:
next = None
for server in self.servers.values():
timeout = server.socket.send_throttle_timeout()
- if (server.socket.waiting_send() and
+ if (server.socket.waiting_throttled_send() and
(next == None or timeout < next)):
next = timeout
return next
@@ -162,17 +175,13 @@ class Bot(object):
timeouts.append(self.cache.next_expiration())
return min([timeout for timeout in timeouts if not timeout == None])
- def register_read(self, server: IRCServer.Server):
- self.poll.modify(server.fileno(), select.EPOLLIN)
- def register_write(self, server: IRCServer.Server):
- self.poll.modify(server.fileno(), select.EPOLLOUT)
- def register_both(self, server: IRCServer.Server):
- self.poll.modify(server.fileno(),
- select.EPOLLIN|select.EPOLLOUT)
-
def disconnect(self, server: IRCServer.Server):
try:
- self.poll.unregister(server.fileno())
+ self._read_poll.unregister(server.fileno())
+ except FileNotFoundError:
+ pass
+ try:
+ self._write_poll.unregister(server.fileno())
except FileNotFoundError:
pass
del self.servers[server.fileno()]
@@ -208,13 +217,54 @@ class Bot(object):
def del_setting(self, setting: str):
self.database.bot_settings.delete(setting)
+ def _daemon_thread(self, target: typing.Callable[[], None]):
+ thread = threading.Thread(target=target)
+ thread.daemon = True
+ thread.start()
+ return thread
+
def run(self):
+ self._read_thread = self._daemon_thread(self._read_loop)
+ self._write_thread = self._daemon_thread(self._write_loop)
+ self._event_loop()
+
+ def _event_loop(self):
+ while self.running:
+ item = self._event_queue.get(block=True, timeout=None)
+ item()
+
+ def _write_loop(self):
while self.running:
if not self.servers:
break
- events = self.poll.poll(self.get_poll_timeout())
- self.lock.acquire()
+ for fd, server in self.servers.items():
+ if server.socket.waiting_immediate_send():
+ self._write_poll.register(fd, select.EPOLLOUT)
+
+ events = self._write_poll.poll()
+ for fd, event in events:
+ if fd == self._wtrigger_server.fileno():
+ # throw away data from trigger socket
+ self._wtrigger_server.recv(1024)
+ elif event & select.EPOLLOUT:
+ self._write_poll.unregister(fd)
+ server = self.servers[fd]
+
+ try:
+ lines = server.socket._send()
+ except:
+ self.log.error("Failed to write to %s", [str(server)])
+ raise
+ self._event_queue.put(lambda: server._post_send(lines))
+
+ def _read_loop(self):
+ while self.running:
+ if not self.servers:
+ self.running = False
+ break
+
+ events = self._read_poll.poll(self.get_poll_timeout())
self._timers.call()
self.cache.expire()
@@ -229,64 +279,48 @@ class Bot(object):
self._trigger_functions.clear()
for fd, event in events:
- sock = None
- irc = False
- if fd in self.servers:
- sock = self.servers[fd]
- irc = True
- elif fd in self.other_sockets:
- sock = self.other_sockets[fd]
-
- if sock:
+ if fd == self._rtrigger_server.fileno():
+ # throw away data from trigger socket
+ self._rtrigger_server.recv(1024)
+ else:
+ server = self.servers[fd]
if event & select.EPOLLIN:
- data = sock.read()
- if data == None:
- sock.disconnect()
+ lines = server.read()
+ if lines == None:
+ server.disconnect()
continue
- for piece in data:
- sock.parse_data(piece)
- elif event & select.EPOLLOUT:
- try:
- sock._send()
- except:
- self.log.error("Failed to write to %s",
- [str(sock)])
- raise
-
- if sock.fileno() in self.servers:
- self.register_read(sock)
- elif event & select.EPULLHUP:
- self.log.warn("Recieved EPOLLHUP for %s", [str(sock)])
- sock.disconnect()
+ self._event_queue.put(lambda: server._post_read(lines))
+ elif event & select.EPOLLHUP:
+ self.log.warn("Recieved EPOLLHUP for %s", [str(server)])
+ server.disconnect()
- for server in list(self.servers.values()):
- if server.read_timed_out():
- self.log.warn("Pinged out from %s", [str(server)])
- server.disconnect()
- elif server.ping_due() and not server.ping_sent:
- server.send_ping()
- server.ping_sent = True
- if not server.socket.connected:
- self._events.on("server.disconnect").call(server=server)
- self.disconnect(server)
+ self.trigger(self._check_servers)
- if not self.get_server_by_id(server.id):
- reconnect_delay = self.config.get("reconnect-delay", 10)
- self._timers.add("reconnect", reconnect_delay,
- server_id=server.id)
- self.log.warn(
- "Disconnected from %s, reconnecting in %d seconds",
- [str(server), reconnect_delay])
- elif server.socket.waiting_immediate_send() or (
- server.socket.waiting_send() and
- server.socket.throttle_done()):
- self.register_both(server)
+ def _check_servers(self):
+ throttle_filled = False
+ for server in list(self.servers.values()):
+ if server.read_timed_out():
+ self.log.warn("Pinged out from %s", [str(server)])
+ server.disconnect()
+ elif server.ping_due() and not server.ping_sent:
+ server.send_ping()
+ server.ping_sent = True
+ if not server.socket.connected:
+ self._events.on("server.disconnect").call(server=server)
+ self.disconnect(server)
- for sock in list(self.other_sockets.values()):
- if not sock.connected:
- self.remove_socket(sock)
- elif sock.waiting_send():
- self.register_both(sock)
+ if not self.get_server_by_id(server.id):
+ reconnect_delay = self.config.get("reconnect-delay", 10)
+ self._timers.add("reconnect", reconnect_delay,
+ server_id=server.id)
+ self.log.warn(
+ "Disconnected from %s, reconnecting in %d seconds",
+ [str(server), reconnect_delay])
+ elif (server.socket.waiting_throttled_send() and
+ server.socket.throttle_done()):
+ server.socket._fill_throttle()
+ throttle_filled = True
- self.lock.release()
+ if throttle_filled:
+ self._wtrigger_client.send(b"TRIGGER")
diff --git a/src/IRCServer.py b/src/IRCServer.py
index a8bd12ea..e037eca7 100644
--- a/src/IRCServer.py
+++ b/src/IRCServer.py
@@ -207,14 +207,12 @@ class Server(IRCObject.Object):
return utils.irc.hostmask_match(self.irc_lower(hostmask),
self.irc_lower(pattern))
- def parse_data(self, line: str):
- if not line:
- return
-
- self.bot.log.debug("%s (raw recv) | %s", [str(self), line])
- self.events.on("raw.received").call_unsafe(server=self,
- line=utils.irc.parse_line(line))
- self.check_users()
+ def _post_read(self, lines: typing.List[str]):
+ for line in lines:
+ self.bot.log.debug("%s (raw recv) | %s", [str(self), line])
+ self.events.on("raw.received").call_unsafe(server=self,
+ line=utils.irc.parse_line(line))
+ self.check_users()
def check_users(self):
for user in self.new_users:
if not len(user.channels):
@@ -261,8 +259,8 @@ class Server(IRCObject.Object):
def send_raw(self, line: str):
return self.send(utils.irc.parse_line(line))
- def _send(self):
- lines = self.socket._send()
+
+ def _post_send(self, lines: typing.List[IRCLine.SentLine]):
for line in lines:
self.bot.log.debug("%s (raw send) | %s", [
str(self), line.parsed_line.format()])
diff --git a/src/IRCSocket.py b/src/IRCSocket.py
index 3b3395e8..0223a7a9 100644
--- a/src/IRCSocket.py
+++ b/src/IRCSocket.py
@@ -129,12 +129,8 @@ class Socket(IRCObject.Object):
else:
self._queued_lines.append(line)
- def _send(self) -> typing.List[IRCLine.SentLine]:
- if not self._write_buffer and self._throttle_when_empty:
- self._throttle_when_empty = False
- self._write_throttling = True
- self._recent_sends.clear()
+ def _fill_throttle(self):
throttle_space = self.throttle_space()
if throttle_space:
to_buffer = self._queued_lines[:throttle_space]
@@ -142,6 +138,12 @@ class Socket(IRCObject.Object):
for line in to_buffer:
self._immediate_buffer(line)
+ def _send(self) -> typing.List[IRCLine.SentLine]:
+ if not self._write_buffer and self._throttle_when_empty:
+ self._throttle_when_empty = False
+ self._write_throttling = True
+ self._recent_sends.clear()
+
bytes_written_i = self._socket.send(self._write_buffer)
bytes_written = self._write_buffer[:bytes_written_i]
@@ -165,8 +167,8 @@ class Socket(IRCObject.Object):
def clear_send_buffer(self):
self._queued_lines.clear()
- def waiting_send(self) -> bool:
- return bool(len(self._write_buffer)) or bool(len(self._queued_lines))
+ def waiting_throttled_send(self) -> bool:
+ return bool(len(self._queued_lines))
def waiting_immediate_send(self) -> bool:
return bool(len(self._write_buffer))