import aiohttp
import aiohttp.web
import argparse
import asyncio
import base64
import collections
import concurrent.futures
import hmac
import json
import logging
import signal
import ssl


logging.basicConfig(level = logging.DEBUG, format = '{asctime} {levelname} {message}', style = '{')


class MessageQueue:
	# An object holding onto the messages received from nodeping
	# This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
	# Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that.
	# Differences to asyncio.Queue include:
	# - No maxsize
	# - No put coroutine (not necessary since the queue can never be full)
	# - Only one concurrent getter
	# - putleft_nowait to put to the front of the queue (so that the IRC client can put a message back when delivery fails)

	def __init__(self):
		self._getter = None # None | asyncio.Future; set only while a get() call is waiting for an item
		self._queue = collections.deque()

	async def get(self):
		'''Remove and return the first item, waiting until one is available.

		Only one coroutine may wait at a time; a concurrent call raises RuntimeError.
		If the waiting call is cancelled, no item is consumed.'''
		if self._getter is not None:
			raise RuntimeError('Cannot get concurrently')
		# Loop rather than `if` so that a wake-up with an empty queue (e.g. an item taken via get_nowait in between) waits again instead of raising.
		while not self._queue:
			self._getter = asyncio.get_running_loop().create_future()
			logging.debug('Awaiting getter')
			try:
				await self._getter
			except asyncio.CancelledError:
				logging.debug('Cancelled getter')
				raise
			finally:
				self._getter = None
			logging.debug('Awaited getter')
		return self._queue.popleft()

	def get_nowait(self):
		'''Remove and return the first item; raise asyncio.QueueEmpty if there is none.'''
		if not self._queue:
			raise asyncio.QueueEmpty
		return self._queue.popleft()

	def _wake_getter(self):
		# The getter future may already be resolved if several items arrive before the waiting get() resumes;
		# resolving it a second time would raise InvalidStateError in the caller (e.g. the HTTP handler).
		if self._getter is not None and not self._getter.done():
			self._getter.set_result(None)

	def put_nowait(self, item):
		'''Append an item to the back of the queue and wake the waiting getter, if any.'''
		self._queue.append(item)
		self._wake_getter()

	def putleft_nowait(self, item):
		'''Put an item at the front of the queue (used to return an undelivered message) and wake the waiting getter, if any.'''
		self._queue.appendleft(item)
		self._wake_getter()

	def qsize(self):
		'''Return the number of queued items.'''
		return len(self._queue)


class IRCClientProtocol(asyncio.Protocol):
	'''IRC client connection that registers, joins one channel, and relays queued messages to it as PRIVMSGs.

	stopEvent is set when the connection is lost; the sender loop watches it to stop taking messages off the queue.'''

	def __init__(self, messageQueue, stopEvent, loop, nick, real, channel):
		logging.debug(f'Protocol init {id(self)}: {messageQueue} {id(messageQueue)}, {stopEvent}, {loop}')
		self.messageQueue = messageQueue
		self.stopEvent = stopEvent
		self.loop = loop
		self.nick = nick
		self.real = real
		self.channel = channel
		self.channelb = channel.encode('utf-8')
		self.buffer = b'' # Trailing partial line (no CRLF received yet) from the previous data_received call
		self.connected = False
		self._sendTask = None # Strong reference to the send_messages task; the event loop only keeps weak references to tasks

	def send(self, data):
		'''Write one IRC line (bytes, without line terminator) to the server.'''
		logging.info(f'Send: {data!r}')
		self.transport.write(data + b'\r\n')

	def connection_made(self, transport):
		logging.info('Connected')
		self.transport = transport
		self.connected = True
		nickb = self.nick.encode('utf-8')
		self.send(b'NICK ' + nickb)
		self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.real.encode('utf-8'))
		self.send(b'JOIN ' + self.channelb)
		self._sendTask = asyncio.create_task(self.send_messages())

	async def _get_message(self):
		'''Wait for the next queued message or for the stop event, whichever comes first.

		Returns the message string, or None if the connection is stopping. A message that was dequeued
		just as the stop event fired is put back at the front of the queue so it is not lost.'''
		logging.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}')
		# asyncio.wait only accepts tasks/futures (bare coroutines are rejected on Python 3.11+).
		messageTask = asyncio.create_task(self.messageQueue.get())
		stopTask = asyncio.create_task(self.stopEvent.wait())
		try:
			done, _ = await asyncio.wait((messageTask, stopTask), return_when = asyncio.FIRST_COMPLETED)
		except asyncio.CancelledError:
			# We are being cancelled ourselves: don't leave a getter running that could swallow a message.
			messageTask.cancel()
			raise
		finally:
			# The event waiter is not needed past this point; cancelling it avoids leaking one task per message.
			stopTask.cancel()
		# Check the task as well as the event: run_irc may clear the event again before we get here.
		if stopTask in done or self.stopEvent.is_set():
			if not messageTask.done():
				logging.debug('Cancelling messageFuture')
				messageTask.cancel()
				try:
					await messageTask
				except asyncio.CancelledError:
					logging.debug('Cancelled messageFuture')
					return None
			# messageTask finished but we're stopping, so put the result back onto the queue
			self.messageQueue.putleft_nowait(messageTask.result())
			return None
		return messageTask.result()

	async def send_messages(self):
		'''Relay queued messages to the channel, at most one per second, until the connection stops.'''
		while self.connected:
			logging.debug(f'{id(self)}: trying to get a message')
			message = await self._get_message()
			logging.debug(f'{id(self)}: got message: {message!r}')
			if message is None:
				break
			if not self.connected:
				# The connection dropped while we were waiting; keep the message for the next connection.
				self.messageQueue.putleft_nowait(message)
				break
			self.send(b'PRIVMSG ' + self.channelb + b' :' + message.encode('utf-8'))
			#TODO self.messageQueue.putleft_nowait if delivery fails
			await asyncio.sleep(1) # Rate limit

	def data_received(self, data):
		logging.debug(f'Data received: {data!r}')
		# Prepend the partial line left over from the previous call, then split on CRLF.
		# Every element except the last is a complete line. The last is whatever follows the final CRLF
		# (b'' if data ended with CRLF); it stays in the buffer until its terminator arrives.
		*messages, self.buffer = (self.buffer + data).split(b'\r\n')
		for message in messages:
			self.message_received(message)

	def message_received(self, message):
		'''Handle one complete IRC line; only PING is acted upon (answered with PONG).'''
		logging.info(f'Message received: {message!r}')
		if message.startswith(b'PING '):
			self.send(b'PONG ' + message[5:])

	def connection_lost(self, exc):
		logging.info('The server closed the connection')
		self.connected = False
		self.stopEvent.set()


class WebServer:
	'''HTTP server accepting NodePing webhook POSTs on /nodeping and putting their `message` onto the message queue.

	auth is "user:pass" for HTTP Basic auth, or a falsy value to disable the check.'''

	def __init__(self, messageQueue, host, port, auth):
		self.messageQueue = messageQueue
		self.host = host
		self.port = port
		self.auth = auth
		self.authHeader = None
		if auth:
			self.authHeader = f'Basic {base64.b64encode(auth.encode("utf-8")).decode("utf-8")}'
		self._app = aiohttp.web.Application()
		self._app.add_routes([aiohttp.web.post('/nodeping', self.nodeping_post)])

	async def run(self, stopEvent):
		'''Serve until stopEvent is set, then shut the server down.'''
		runner = aiohttp.web.AppRunner(self._app)
		await runner.setup()
		try:
			site = aiohttp.web.TCPSite(runner, self.host, self.port)
			await site.start()
			await stopEvent.wait()
		finally:
			# Also release the runner if binding fails or the task is cancelled.
			await runner.cleanup()

	async def nodeping_post(self, request):
		'''Validate a webhook request and enqueue its message.

		Responds 403 on an auth mismatch and 400 unless the body is a JSON object whose "message" is a
		non-empty string without CR, LF or NUL (which would break or inject IRC lines).'''
		body = await request.read()
		logging.info(f'Received request with data: {body!r}')
		if self.auth:
			authHeader = request.headers.get('Authorization', '')
			# Constant-time comparison so the credentials can't be probed via response timing.
			if not hmac.compare_digest(authHeader.encode('utf-8', 'replace'), self.authHeader.encode('utf-8')):
				raise aiohttp.web.HTTPForbidden()
		try:
			data = await request.json()
		except (aiohttp.ContentTypeError, ValueError): # ValueError covers json.JSONDecodeError and UnicodeDecodeError
			logging.error(f'Received invalid data: {body!r}')
			raise aiohttp.web.HTTPBadRequest()
		message = data.get('message') if isinstance(data, dict) else None
		if not isinstance(message, str) or not message or any(c in message for c in '\r\n\0'):
			logging.error(f'Received invalid data: {body!r}')
			raise aiohttp.web.HTTPBadRequest()
		logging.debug(f'Putting to message queue {id(self.messageQueue)}')
		self.messageQueue.put_nowait(message)
		return aiohttp.web.Response(text = '200: OK')


async def run_irc(loop, messageQueue, sigintEvent, host, port, ssl, nick, real, channel):
	'''Keep an IRC connection alive, reconnecting after 5 seconds whenever it fails or drops, until sigintEvent is set.

	ssl is passed straight to loop.create_connection (True, False, or an ssl.SSLContext).'''
	stopEvent = asyncio.Event()
	# One long-lived waiter for SIGINT, shared by all the asyncio.wait calls below.
	sigintTask = asyncio.create_task(sigintEvent.wait())
	try:
		while not sigintEvent.is_set():
			stopEvent.clear()
			try:
				transport, _ = await loop.create_connection(
					lambda: IRCClientProtocol(messageQueue, stopEvent, loop, nick = nick, real = real, channel = channel),
					host, port, ssl = ssl)
			except (OSError, asyncio.TimeoutError) as e: # OSError includes refused connections, DNS and TLS failures
				logging.error(f'IRC connection failed: {e!r}')
			else:
				stopTask = asyncio.create_task(stopEvent.wait())
				try:
					await asyncio.wait((stopTask, sigintTask), return_when = asyncio.FIRST_COMPLETED)
				finally:
					stopTask.cancel()
					transport.close()
			if sigintEvent.is_set():
				break
			# Back off before reconnecting (also after a server-side disconnect) so we don't hammer the server.
			await asyncio.wait((sigintTask,), timeout = 5)
	finally:
		sigintTask.cancel()


async def run_webserver(loop, messageQueue, sigintEvent, host, port, auth):
	'''Run the webhook server until sigintEvent is set.'''
	server = WebServer(messageQueue, host, port, auth)
	await server.run(sigintEvent)


def parse_args():
	parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter)
	parser.add_argument('--irchost', type = str, help = 'IRC server hostname', default = 'irc.hackint.org')
	parser.add_argument('--ircport', type = int, help = 'IRC server port', default = 6697)
	parser.add_argument('--ircssl', choices = ['yes', 'no', 'insecure'], help = 'enable, disable, or use insecure SSL/TLS', default = 'yes')
	parser.add_argument('--ircnick', help = 'IRC nickname', default = 'npbot')
	parser.add_argument('--ircreal', help = 'IRC realname', default = 'I am a bot.')
	parser.add_argument('--ircchannel', help = 'IRC channel to join and post messages', default = '#nodeping')
	parser.add_argument('--webhost', type = str, help = 'web server host to bind to', default = '127.0.0.1')
	parser.add_argument('--webport', type = int, help = 'web server port to bind to', default = 8080)
	parser.add_argument('--webauth', type = str, help = 'basic auth data (user:pass); omit to disable the check', default = None)
	return parser.parse_args()


async def main():
	args = parse_args()
	# Build the 'insecure' context lazily; the name must not shadow the ssl module (doing so caused UnboundLocalError).
	if args.ircssl == 'yes':
		sslContext = True
	elif args.ircssl == 'no':
		sslContext = False
	else:
		sslContext = ssl.create_default_context()
		sslContext.check_hostname = False # must be disabled before verify_mode can be set to CERT_NONE
		sslContext.verify_mode = ssl.CERT_NONE
	loop = asyncio.get_running_loop()
	messageQueue = MessageQueue()
	sigintEvent = asyncio.Event()

	def sigint_callback():
		logging.info('Got SIGINT')
		sigintEvent.set()

	loop.add_signal_handler(signal.SIGINT, sigint_callback)
	irc = run_irc(loop, messageQueue, sigintEvent, host = args.irchost, port = args.ircport, ssl = sslContext, nick = args.ircnick, real = args.ircreal, channel = args.ircchannel)
	webserver = run_webserver(loop, messageQueue, sigintEvent, host = args.webhost, port = args.webport, auth = args.webauth)
	await asyncio.gather(irc, webserver)


if __name__ == '__main__':
	asyncio.run(main())