aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar jesopo2018-10-06 15:37:05 +0100
committerGravatar jesopo2018-10-06 15:37:05 +0100
commit0794a5173ae2a45aad4f91f76a0fd0a9fe2939c4 (patch)
tree6652a20f47bb0113d9a70b392533a608664b12a4 /src
parentAttempt to register servers for read/write when sending github hook notices (diff)
signature
Add a way to track non-IRC sockets within the main epoll loop; use this for a
unix domain control socket!
Diffstat (limited to 'src')
-rw-r--r--src/ControlSocket.py36
-rw-r--r--src/IRCBot.py53
-rw-r--r--src/IRCServer.py15
-rw-r--r--src/Socket.py51
4 files changed, 139 insertions, 16 deletions
diff --git a/src/ControlSocket.py b/src/ControlSocket.py
new file mode 100644
index 00000000..ead2624b
--- /dev/null
+++ b/src/ControlSocket.py
@@ -0,0 +1,36 @@
+import os, socket
+from src import Socket
+
+class ControlSocket(object):
+ def __init__(self, bot):
+ self.bot = bot
+
+ location = bot.config["control-socket"]
+ if os.path.exists(location):
+ os.unlink(location)
+ self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ self.socket.bind(location)
+ self.socket.listen()
+ self.connected = True
+
+ def fileno(self):
+ return self.socket.fileno()
+ def waiting_send(self):
+ return False
+ def _send(self):
+ pass
+ def read(self):
+ client, addr = self.socket.accept()
+ self.bot.add_socket(Socket.Socket(client, self.on_read))
+ return []
+ def parse_data(self, data):
+ command = data.split(" ", 1)[0].upper()
+ if command == "TRIGGER":
+ pass
+ else:
+ raise ValueError("unknown control socket command: '%s'" %
+ command)
+
+ def on_read(self, sock, data):
+ data = data.strip("\r\n")
+ print(data)
diff --git a/src/IRCBot.py b/src/IRCBot.py
index e12e4d9e..29c45b9b 100644
--- a/src/IRCBot.py
+++ b/src/IRCBot.py
@@ -1,5 +1,6 @@
import os, select, sys, threading, time, traceback, uuid
-from . import EventManager, Exports, IRCServer, Logging, ModuleManager
+from src import ControlSocket, EventManager, Exports, IRCServer, Logging
+from src import ModuleManager, utils
class Bot(object):
def __init__(self, directory, args, cache, config, database, events,
@@ -15,13 +16,16 @@ class Bot(object):
self.modules = modules
self.timers = timers
- events.on("timer.reconnect").hook(self.reconnect)
self.start_time = time.time()
self.lock = threading.Lock()
- self.servers = {}
self.running = True
self.poll = select.epoll()
+ self.servers = {}
+ self.other_sockets = {}
+ self.control_socket = ControlSocket.ControlSocket(self)
+ self.add_socket(self.control_socket)
+
def add_server(self, server_id, connect=True):
(_, alias, hostname, port, password, ipv4, tls, bindhost, nickname,
username, realname) = self.database.servers.get(server_id)
@@ -36,6 +40,14 @@ class Bot(object):
self.connect(new_server)
return new_server
+ def add_socket(self, sock):
+ self.other_sockets[sock.fileno()] = sock
+ self.poll.register(sock.fileno(), select.EPOLLIN)
+
+ def remove_socket(self, sock):
+ del self.other_sockets[sock.fileno()]
+ self.poll.unregister(sock.fileno())
+
def get_server(self, id):
for server in self.servers.values():
if server.id == id:
@@ -101,6 +113,7 @@ class Bot(object):
pass
del self.servers[server.fileno()]
+ @utils.hook("timer.reconnect")
def reconnect(self, event):
server = self.add_server(event["server_id"], False)
if self.connect(server):
@@ -128,19 +141,30 @@ class Bot(object):
self.cache.expire()
for fd, event in events:
+ sock = None
+ irc = False
if fd in self.servers:
- server = self.servers[fd]
+ sock = self.servers[fd]
+ irc = True
+ elif fd in self.other_sockets:
+ sock = self.other_sockets[fd]
+
+ if sock:
if event & select.EPOLLIN:
- lines = server.read()
- for line in lines:
- self.log.debug("%s (raw) | %s", [str(server), line])
- server.parse_line(line)
+ data = sock.read()
+ if data == None:
+ sock.disconnect()
+ for piece in data:
+ if irc:
+ self.log.debug("%s (raw) | %s",
+ [str(sock), data])
+ sock.parse_data(piece)
elif event & select.EPOLLOUT:
- server._send()
- self.register_read(server)
+ sock._send()
+ self.register_read(sock)
elif event & select.EPULLHUP:
print("hangup")
- server.disconnect()
+ sock.disconnect()
for server in list(self.servers.values()):
if server.read_timed_out():
@@ -160,4 +184,11 @@ class Bot(object):
str(server), reconnect_delay))
elif server.waiting_send() and server.throttle_done():
self.register_both(server)
+
+ for sock in list(self.other_sockets.values()):
+ if not sock.connected:
+ self.remove_socket(sock)
+ elif sock.waiting_send():
+ self.register_both(sock)
+
self.lock.release()
diff --git a/src/IRCServer.py b/src/IRCServer.py
index 35050ff3..c389efed 100644
--- a/src/IRCServer.py
+++ b/src/IRCServer.py
@@ -199,7 +199,7 @@ class Server(IRCObject.Object):
for user in channel.users:
user.part_channel(channel)
del self.channels[channel.name]
- def parse_line(self, line):
+ def parse_data(self, line):
if not line:
return
self.events.on("raw").call(server=self, line=line)
@@ -212,16 +212,22 @@ class Server(IRCObject.Object):
def read(self):
data = b""
try:
- data = self.read_buffer + self.socket.recv(4096)
+ data = self.socket.recv(4096)
except (ConnectionResetError, socket.timeout):
self.disconnect()
- return []
+ return None
+ if not data:
+ self.disconnect()
+ return None
+ data = self.read_buffer+data
self.read_buffer = b""
+
data_lines = [line.strip(b"\r") for line in data.split(b"\n")]
if data_lines[-1]:
self.read_buffer = data_lines[-1]
data_lines.pop(-1)
decoded_lines = []
+
for line in data_lines:
try:
line = line.decode(self.get_setting(
@@ -233,8 +239,7 @@ class Server(IRCObject.Object):
except:
continue
decoded_lines.append(line)
- if not decoded_lines:
- self.disconnect()
+
self.last_read = time.monotonic()
self.ping_sent = False
return decoded_lines
diff --git a/src/Socket.py b/src/Socket.py
new file mode 100644
index 00000000..f405b0a9
--- /dev/null
+++ b/src/Socket.py
@@ -0,0 +1,51 @@
+
+
+class Socket(object):
+ def __init__(self, socket, on_read, encoding="utf8"):
+ self.socket = socket
+ self._on_read = on_read
+ self.encoding = encoding
+
+ self._write_buffer = b""
+ self._read_buffer = b""
+ self.delimiter = None
+ self.length = None
+ self.connected = True
+
+ def fileno(self):
+ return self.socket.fileno()
+
+ def disconnect(self):
+ self.connected = False
+
+ def _decode(self, s):
+ return s.decode(self.encoding) if self.encoding else s
+ def _encode(self, s):
+ return s.encode(self.encoding) if self.encoding else s
+
+ def read(self):
+ data = self.socket.recv(1024)
+ if not data:
+ return None
+
+ data = self._read_buffer+data
+ self._read_buffer = b""
+ if not self.delimiter == None:
+ data_split = data.split(delimiter)
+ if data_split[-1]:
+ self._read_buffer = data_split.pop(-1)
+ return [self._decode(data) for data in data_split]
+ return [data.decode(self.encoding)]
+
+ def parse_data(self, data):
+ self._on_read(self, data)
+
+ def send(self, data):
+ self._write_buffer += self._encode(data)
+
+ def _send(self):
+ self._write_buffer = self._write_buffer[self.socket.send(
+ self._write_buffer):]
+
+ def waiting_send(self):
+ return bool(len(self._write_buffer))