aboutsummaryrefslogtreecommitdiff
path: root/http2irc.py
diff options
context:
space:
mode:
Diffstat (limited to 'http2irc.py')
-rw-r--r--http2irc.py375
1 files changed, 375 insertions, 0 deletions
diff --git a/http2irc.py b/http2irc.py
new file mode 100644
index 0000000..dd786b4
--- /dev/null
+++ b/http2irc.py
@@ -0,0 +1,375 @@
+import aiohttp
+import aiohttp.web
+import asyncio
+import base64
+import collections
+import concurrent.futures
+import json
+import logging
+import signal
+import ssl
+import sys
+import toml
+import types
+
+
+logging.basicConfig(level = logging.DEBUG, format = '{asctime} {levelname} {message}', style = '{')
+
+
+SSL_CONTEXTS = {'yes': True, 'no': False, 'insecure': ssl.SSLContext()}
+
+
+class InvalidConfig(Exception):
+ '''Error in configuration file'''
+
+
+def _mapping_to_namespace(d):
+ '''Converts a mapping (e.g. dict) to a types.SimpleNamespace, recursively'''
+ return types.SimpleNamespace(**{key: _mapping_to_namespace(value) if isinstance(value, collections.abc.Mapping) else value for key, value in d.items()})
+
+
+class Config:
+ def __init__(self, filename):
+ self._filename = filename
+ # Set below:
+ self.irc = None
+ self.web = None
+ self.maps = None
+
+ with open(self._filename, 'r') as fp:
+ obj = toml.load(fp)
+
+ logging.info(repr(obj))
+
+ # Sanity checks
+ if any(x not in ('irc', 'web', 'maps') for x in obj.keys()):
+ raise InvalidConfig('Unknown sections found in base object')
+ if any(not isinstance(x, collections.abc.Mapping) for x in obj.values()):
+ raise InvalidConfig('Invalid section type(s), expected objects/dicts')
+ if 'irc' in obj:
+ if any(x not in ('host', 'port', 'ssl', 'nick', 'real') for x in obj['irc']):
+ raise InvalidConfig('Unknown key found in irc section')
+ if 'host' in obj['irc'] and not isinstance(obj['irc']['host'], str): #TODO: Check whether it's a valid hostname
+ raise InvalidConfig('Invalid IRC host')
+ if 'port' in obj['irc'] and (not isinstance(obj['irc']['port'], int) or not 1 <= obj['irc']['port'] <= 65535):
+ raise InvalidConfig('Invalid IRC port')
+ if 'ssl' in obj['irc'] and obj['irc']['ssl'] not in ('yes', 'no', 'insecure'):
+ raise InvalidConfig(f'Invalid IRC SSL setting: {obj["irc"]["ssl"]!r}')
+ if 'nick' in obj['irc'] and not isinstance(obj['irc']['nick'], str): #TODO: Check whether it's a valid nickname
+ raise InvalidConfig('Invalid IRC nick')
+ if 'real' in obj['irc'] and not isinstance(obj['irc']['real'], str):
+ raise InvalidConfig('Invalid IRC realname')
+ if 'web' in obj:
+ if any(x not in ('host', 'port') for x in obj['web']):
+ raise InvalidConfig('Unknown key found in web section')
+ if 'host' in obj['web'] and not isinstance(obj['web']['host'], str): #TODO: Check whether it's a valid hostname (must resolve I guess?)
+ raise InvalidConfig('Invalid web hostname')
+ if 'port' in obj['web'] and (not isinstance(obj['web']['port'], int) or not 1 <= obj['web']['port'] <= 65535):
+ raise InvalidConfig('Invalid web port')
+ if 'maps' in obj:
+ for key, map_ in obj['maps'].items():
+ # Ensure that the key is a valid Python identifier since it will be set as an attribute in the namespace.
+ #TODO: Support for fancier identifiers (PEP 3131)?
+ if not isinstance(key, str) or not key or key.strip('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_') != '' or key[0].strip('0123456789') == '':
+ raise InvalidConfig(f'Invalid map key {key!r}')
+ if not isinstance(map_, collections.abc.Mapping):
+ raise InvalidConfig(f'Invalid map for {key!r}')
+ if any(x not in ('webpath', 'ircchannel', 'auth') for x in map_):
+ raise InvalidConfig(f'Unknown key(s) found in map {key!r}')
+ #TODO: Check values
+
+ # Default values
+ self._obj = {'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.'}, 'web': {'host': '127.0.0.1', 'port': 8080}, 'maps': {}}
+
+ # Fill in default values for the maps
+ for key, map_ in obj['maps'].items():
+ if 'webpath' not in map_:
+ map_['webpath'] = f'/{key}'
+ if 'ircchannel' not in map_:
+ map_['ircchannel'] = f'#{key}'
+ if 'auth' not in map_:
+ map_['auth'] = False
+
+ # Merge in what was read from the config file and convert to SimpleNamespace
+ for key in ('irc', 'web', 'maps'):
+ if key in obj:
+ self._obj[key].update(obj[key])
+ setattr(self, key, _mapping_to_namespace(self._obj[key]))
+
+ def __repr__(self):
+ return f'Config(irc={self.irc!r}, web={self.web!r}, maps={self.maps!r})'
+
+ def reread(self):
+ return Config(self._filename)
+
+
+class MessageQueue:
+ # An object holding onto the messages received from nodeping
+ # This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
+ # Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that.
+ # Differences to asyncio.Queue include:
+ # - No maxsize
+ # - No put coroutine (not necessary since the queue can never be full)
+ # - Only one concurrent getter
+ # - putleft_nowait to put to the front of the queue (so that the IRC client can put a message back when delivery fails)
+
+ def __init__(self):
+ self._getter = None # None | asyncio.Future
+ self._queue = collections.deque()
+
+ async def get(self):
+ if self._getter is not None:
+ raise RuntimeError('Cannot get concurrently')
+ if len(self._queue) == 0:
+ self._getter = asyncio.get_running_loop().create_future()
+ logging.debug('Awaiting getter')
+ try:
+ await self._getter
+ except asyncio.CancelledError:
+ logging.debug('Cancelled getter')
+ self._getter = None
+ raise
+ logging.debug('Awaited getter')
+ self._getter = None
+ # For testing the cancellation/putting back onto the queue
+ #logging.debug('Delaying message queue get')
+ #await asyncio.sleep(3)
+ #logging.debug('Done delaying')
+ return self.get_nowait()
+
+ def get_nowait(self):
+ if len(self._queue) == 0:
+ raise asyncio.QueueEmpty
+ return self._queue.popleft()
+
+ def put_nowait(self, item):
+ self._queue.append(item)
+ if self._getter is not None:
+ self._getter.set_result(None)
+
+ def putleft_nowait(self, item):
+ self._queue.appendleft(item)
+ if self._getter is not None:
+ self._getter.set_result(None)
+
+ def qsize(self):
+ return len(self._queue)
+
+
+class IRCClientProtocol(asyncio.Protocol):
+ def __init__(self, messageQueue, connectionClosedEvent, loop, config, channels):
+ logging.debug(f'Protocol init {id(self)}: {messageQueue} {id(messageQueue)}, {connectionClosedEvent}, {loop}')
+ self.messageQueue = messageQueue
+ self.connectionClosedEvent = connectionClosedEvent
+ self.loop = loop
+ self.config = config
+ self.buffer = b''
+ self.connected = False
+ self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
+
+ def connection_made(self, transport):
+ logging.info('Connected')
+ self.transport = transport
+ self.connected = True
+ nickb = self.config.irc.nick.encode('utf-8')
+ self.send(b'NICK ' + nickb)
+ self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.config.irc.real.encode('utf-8'))
+ self.send(b'JOIN ' + ','.join(self.channels).encode('utf-8')) #TODO: Split if too long
+ asyncio.create_task(self.send_messages())
+
+ def update_channels(self, channels: set):
+ channelsToPart = self.channels - channels
+ channelsToJoin = channels - self.channels
+ self.channels = channels
+
+ if self.connected:
+ if channelsToPart:
+ #TODO: Split if too long
+ self.send(b'PART ' + ','.join(channelsToPart).encode('utf-8'))
+ if channelsToJoin:
+ self.send(b'JOIN ' + ','.join(channelsToJoin).encode('utf-8'))
+
+ def send(self, data):
+ logging.info(f'Send: {data!r}')
+ self.transport.write(data + b'\r\n')
+
+ async def _get_message(self):
+ logging.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}')
+ messageFuture = asyncio.create_task(self.messageQueue.get())
+ done, pending = await asyncio.wait((messageFuture, self.connectionClosedEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED)
+ if self.connectionClosedEvent.is_set():
+ if messageFuture in pending:
+ logging.debug('Cancelling messageFuture')
+ messageFuture.cancel()
+ try:
+ await messageFuture
+ except asyncio.CancelledError:
+ logging.debug('Cancelled messageFuture')
+ pass
+ else:
+ # messageFuture is already done but we're stopping, so put the result back onto the queue
+ self.messageQueue.putleft_nowait(messageFuture.result())
+ return None, None
+ assert messageFuture in done, 'Invalid state: messageFuture not in done futures'
+ return messageFuture.result()
+
+ async def send_messages(self):
+ while self.connected:
+ logging.debug(f'{id(self)}: trying to get a message')
+ channel, message = await self._get_message()
+ logging.debug(f'{id(self)}: got message: {message!r}')
+ if message is None:
+ break
+ self.send(b'PRIVMSG ' + channel.encode('utf-8') + b' :' + message.encode('utf-8'))
+ #TODO self.messageQueue.putleft_nowait if delivery fails
+ await asyncio.sleep(1) # Rate limit
+
+ def data_received(self, data):
+ logging.debug(f'Data received: {data!r}')
+ # Split received data on CRLF. If there's any data left in the buffer, prepend it to the first message and process that.
+ # Then, process all messages except the last one (since data might not end on a CRLF) and keep the remainder in the buffer.
+ # If data does end with CRLF, all messages will have been processed and the buffer will be empty again.
+ messages = data.split(b'\r\n')
+ if self.buffer:
+ self.message_received(self.buffer + messages[0])
+ messages = messages[1:]
+ for message in messages[:-1]:
+ self.message_received(message)
+ self.buffer = messages[-1]
+
+ def message_received(self, message):
+ logging.info(f'Message received: {message!r}')
+ if message.startswith(b'PING '):
+ self.send(b'PONG ' + message[5:])
+
+ def connection_lost(self, exc):
+ logging.info('The server closed the connection')
+ self.connected = False
+ self.connectionClosedEvent.set()
+
+
+class IRCClient:
+ def __init__(self, messageQueue, config):
+ self.messageQueue = messageQueue
+ self.config = config
+ self.channels = {map_.ircchannel for map_ in config.maps.__dict__.values()}
+
+ self._transport = None
+ self._protocol = None
+
+ def update_config(self, config):
+ needReconnect = (self.config.irc.host, self.config.irc.port, self.config.irc.ssl) != (config.irc.host, config.irc.port, config.irc.ssl)
+ self.config = config
+ if self._transport: # if currently connected:
+ if needReconnect:
+ self._transport.close()
+ else:
+ self.channels = {map_.ircchannel for map_ in config.maps.__dict__.values()}
+ self._protocol.update_channels(self.channels)
+
+ async def run(self, loop, sigintEvent):
+ connectionClosedEvent = asyncio.Event()
+ while True:
+ connectionClosedEvent.clear()
+ try:
+ self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config.irc.host, self.config.irc.port, ssl = SSL_CONTEXTS[self.config.irc.ssl])
+ try:
+ await asyncio.wait((connectionClosedEvent.wait(), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED)
+ finally:
+ self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C?
+ except (ConnectionRefusedError, asyncio.TimeoutError) as e:
+ logging.error(str(e))
+ await asyncio.wait((asyncio.sleep(5), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED)
+ if sigintEvent.is_set():
+ break
+
+
+class WebServer:
+ def __init__(self, messageQueue, config):
+ self.messageQueue = messageQueue
+ self.config = config
+
+ self._paths = {} # '/path' => ('#channel', auth) where auth is either False (no authentication) or the HTTP header value for basic auth
+
+ self._app = aiohttp.web.Application()
+ self._app.add_routes([aiohttp.web.post('/{path:.+}', self.post)])
+
+ self.update_config(config)
+
+ def update_config(self, config):
+ self._paths = {map_.webpath: (map_.ircchannel, f'Basic {base64.b64encode(map_.auth.encode("utf-8")).decode("utf-8")}' if map_.auth else False) for map_ in config.maps.__dict__.values()}
+ needRebind = (self.config.web.host, self.config.web.port) != (config.web.host, config.web.port)
+ self.config = config
+ if needRebind:
+ #TODO
+ logging.error('Webserver host or port changes while running are currently not supported')
+
+ async def run(self, stopEvent):
+ runner = aiohttp.web.AppRunner(self._app)
+ await runner.setup()
+ site = aiohttp.web.TCPSite(runner, self.config.web.host, self.config.web.port)
+ await site.start()
+ await stopEvent.wait()
+ await runner.cleanup()
+
+ async def post(self, request):
+ logging.info(f'Received request for {request.path!r} with data {await request.read()!r}')
+ try:
+ channel, auth = self._paths[request.path]
+ except KeyError:
+ raise aiohttp.web.HTTPNotFound()
+ if auth:
+ authHeader = request.headers.get('Authorization')
+ if not authHeader or authHeader != auth:
+ raise aiohttp.web.HTTPForbidden()
+ try:
+ data = await request.json()
+ except (aiohttp.ContentTypeError, json.JSONDecodeError) as e:
+ logging.error(f'Invalid data received: {await request.read()!r}')
+ raise aiohttp.web.HTTPBadRequest()
+ if 'message' not in data:
+ logging.error(f'Message missing: {await request.read()!r}')
+ raise aiohttp.web.HTTPBadRequest()
+ if '\r' in data['message'] or '\n' in data['message']:
+ logging.error(f'Linebreaks in message: {await request.read()!r}')
+ raise aiohttp.web.HTTPBadRequest()
+ logging.debug(f'Putting message {data["message"]!r} for {channel} into message queue')
+ self.messageQueue.put_nowait((channel, data['message']))
+ raise aiohttp.web.HTTPOk()
+
+
+async def main():
+ if len(sys.argv) != 2:
+ print('Usage: web2irc.py CONFIGFILE', file = sys.stderr)
+ sys.exit(1)
+ configFile = sys.argv[1]
+ config = Config(configFile)
+
+ loop = asyncio.get_running_loop()
+
+ messageQueue = MessageQueue()
+
+ irc = IRCClient(messageQueue, config)
+ webserver = WebServer(messageQueue, config)
+
+ sigintEvent = asyncio.Event()
+ def sigint_callback():
+ logging.info('Got SIGINT')
+ nonlocal sigintEvent
+ sigintEvent.set()
+ loop.add_signal_handler(signal.SIGINT, sigint_callback)
+
+ def sigusr1_callback():
+ logging.info('Got SIGUSR1, reloading config')
+ nonlocal config, irc, webserver
+ newConfig = config.reread()
+ config = newConfig
+ irc.update_config(config)
+ webserver.update_config(config)
+ loop.add_signal_handler(signal.SIGUSR1, sigusr1_callback)
+
+ await asyncio.gather(irc.run(loop, sigintEvent), webserver.run(sigintEvent))
+
+
+if __name__ == '__main__':
+ asyncio.run(main())