aboutsummaryrefslogtreecommitdiff
path: root/http2irc.py
diff options
context:
space:
mode:
authorGravatar JustAnotherArchivist2021-10-09 01:50:26 +0000
committerGravatar JustAnotherArchivist2021-10-09 01:50:26 +0000
commitbdb396caff34816e79b79e4b4484744bc4fc5aca (patch)
treea5244de1d0c9083315a132e5fc6156ffd66fea9f /http2irc.py
parentFix crash due to missing time import (diff)
signature
Merge changes from irclog
Port to ircstates/irctokens, more capabilities, IRC family config, fix various small bugs
Diffstat (limited to 'http2irc.py')
-rw-r--r--http2irc.py320
1 files changed, 223 insertions, 97 deletions
diff --git a/http2irc.py b/http2irc.py
index 36022b1..c71e082 100644
--- a/http2irc.py
+++ b/http2irc.py
@@ -3,13 +3,17 @@ import aiohttp.web
import asyncio
import base64
import collections
-import concurrent.futures
+import functools
import importlib.util
import inspect
+import ircstates
+import irctokens
import itertools
+import json
import logging
import os.path
import signal
+import socket
import ssl
import string
import sys
@@ -53,14 +57,20 @@ async def wait_cancel_pending(aws, paws = None, **kwargs):
if paws is None:
paws = set()
tasks = aws | paws
+ logger.debug(f'waiting for {tasks!r}')
done, pending = await asyncio.wait(tasks, **kwargs)
+ logger.debug(f'done waiting for {tasks!r}; cancelling pending non-persistent tasks: {pending!r}')
for task in pending:
if task not in paws:
+ logger.debug(f'cancelling {task!r}')
task.cancel()
+ logger.debug(f'awaiting cancellation of {task!r}')
try:
await task
except asyncio.CancelledError:
pass
+ logger.debug(f'done cancelling {task!r}')
+ logger.debug(f'done wait_cancel_pending {tasks!r}')
return done, pending
@@ -92,7 +102,7 @@ class Config(dict):
except (ValueError, AssertionError) as e:
raise InvalidConfig('Invalid log format: parsing failed') from e
if 'irc' in obj:
- if any(x not in ('host', 'port', 'ssl', 'nick', 'real', 'certfile', 'certkeyfile') for x in obj['irc']):
+ if any(x not in ('host', 'port', 'ssl', 'family', 'nick', 'real', 'certfile', 'certkeyfile') for x in obj['irc']):
raise InvalidConfig('Unknown key found in irc section')
if 'host' in obj['irc'] and not isinstance(obj['irc']['host'], str): #TODO: Check whether it's a valid hostname
raise InvalidConfig('Invalid IRC host')
@@ -100,6 +110,10 @@ class Config(dict):
raise InvalidConfig('Invalid IRC port')
if 'ssl' in obj['irc'] and obj['irc']['ssl'] not in ('yes', 'no', 'insecure'):
raise InvalidConfig(f'Invalid IRC SSL setting: {obj["irc"]["ssl"]!r}')
+ if 'family' in obj['irc']:
+ if obj['irc']['family'] not in ('inet', 'INET', 'inet6', 'INET6'):
+ raise InvalidConfig('Invalid IRC family')
+ obj['irc']['family'] = getattr(socket, f'AF_{obj["irc"]["family"].upper()}')
if 'nick' in obj['irc'] and not isinstance(obj['irc']['nick'], str): #TODO: Check whether it's a valid nickname
raise InvalidConfig('Invalid IRC nick')
if len(IRCClientProtocol.nick_command(obj['irc']['nick'])) > 510:
@@ -192,7 +206,12 @@ class Config(dict):
raise InvalidConfig(f'Invalid map {key!r} overlongmode: unsupported value')
# Default values
- finalObj = {'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {name} {message}'}, 'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None}, 'web': {'host': '127.0.0.1', 'port': 8080}, 'maps': {}}
+ finalObj = {
+ 'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {name} {message}'},
+ 'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'family': 0, 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None},
+ 'web': {'host': '127.0.0.1', 'port': 8080},
+ 'maps': {}
+ }
# Fill in default values for the maps
for key, map_ in obj['maps'].items():
@@ -253,7 +272,7 @@ class Config(dict):
class MessageQueue:
- # An object holding onto the messages received from nodeping
+ # An object holding onto the messages received over HTTP for sending to IRC
# This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
# Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that.
# Differences to asyncio.Queue include:
@@ -310,12 +329,14 @@ class MessageQueue:
class IRCClientProtocol(asyncio.Protocol):
logger = logging.getLogger('http2irc.IRCClientProtocol')
- def __init__(self, messageQueue, connectionClosedEvent, loop, config, channels):
- self.messageQueue = messageQueue
+ def __init__(self, http2ircMessageQueue, connectionClosedEvent, loop, config, channels):
+ self.http2ircMessageQueue = http2ircMessageQueue
self.connectionClosedEvent = connectionClosedEvent
self.loop = loop
self.config = config
self.lastRecvTime = None
+ self.lastSentTime = None # float timestamp or None; the latter disables the send rate limit
+ self.sendQueue = asyncio.Queue()
self.buffer = b''
self.connected = False
self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
@@ -323,7 +344,14 @@ class IRCClientProtocol(asyncio.Protocol):
self.pongReceivedEvent = asyncio.Event()
self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile'])
self.authenticated = False
- self.usermask = None
+ self.server = ircstates.Server(self.config['irc']['host'])
+ self.capReqsPending = set() # Capabilities requested from the server but not yet ACKd or NAKd
+ self.caps = set() # Capabilities acknowledged by the server
+ self.whoxQueue = collections.deque() # Names of channels that were joined successfully but for which no WHO (WHOX) query was sent yet
+ self.whoxChannel = None # Name of channel for which a WHO query is currently running
+ self.whoxReply = [] # List of (nickname, account) tuples from the currently running WHO query
+ self.whoxStartTime = None
+ self.userChannels = collections.defaultdict(set) # List of which channels a user is known to be in; nickname:str -> {channel:str, ...}
@staticmethod
def nick_command(nick: str):
@@ -334,17 +362,16 @@ class IRCClientProtocol(asyncio.Protocol):
nickb = nick.encode('utf-8')
return b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + real.encode('utf-8')
- def _maybe_set_usermask(self, usermask):
- if b'@' in usermask and b'!' in usermask.split(b'@')[0] and all(x not in usermask for x in (b' ', b'*', b'#', b'&')):
- self.usermask = usermask
- self.logger.debug(f'Usermask is now {usermask!r}')
-
def connection_made(self, transport):
self.logger.info('IRC connected')
self.transport = transport
self.connected = True
+ caps = [b'multi-prefix', b'userhost-in-names', b'away-notify', b'account-notify', b'extended-join']
if self.sasl:
- self.send(b'CAP REQ :sasl')
+ caps.append(b'sasl')
+ for cap in caps:
+ self.capReqsPending.add(cap.decode('ascii'))
+ self.send(b'CAP REQ :' + cap)
self.send(self.nick_command(self.config['irc']['nick']))
self.send(self.user_command(self.config['irc']['nick'], self.config['irc']['real']))
@@ -393,15 +420,41 @@ class IRCClientProtocol(asyncio.Protocol):
self._send_join_part(b'JOIN', channelsToJoin)
def send(self, data):
- self.logger.debug(f'Send: {data!r}')
+ self.logger.debug(f'Queueing for send: {data!r}')
if len(data) > 510:
raise RuntimeError(f'IRC message too long ({len(data)} > 510): {data!r}')
+ self.sendQueue.put_nowait(data)
+
+ def _direct_send(self, data):
+ self.logger.debug(f'Send: {data!r}')
+ time_ = time.time()
self.transport.write(data + b'\r\n')
+ return time_
+
+ async def send_queue(self):
+ while True:
+ self.logger.debug('Trying to get data from send queue')
+ t = asyncio.create_task(self.sendQueue.get())
+ done, pending = await wait_cancel_pending({t, asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED)
+ if self.connectionClosedEvent.is_set():
+ break
+ assert t in done, f'{t!r} is not in {done!r}'
+ data = t.result()
+ self.logger.debug(f'Got {data!r} from send queue')
+ now = time.time()
+ if self.lastSentTime is not None and now - self.lastSentTime < 1:
+ self.logger.debug(f'Rate limited')
+ await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = self.lastSentTime + 1 - now)
+ if self.connectionClosedEvent.is_set():
+ break
+ time_ = self._direct_send(data)
+ if self.lastSentTime is not None:
+ self.lastSentTime = time_
async def _get_message(self):
- self.logger.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}')
- messageFuture = asyncio.create_task(self.messageQueue.get())
- done, pending = await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, paws = {messageFuture}, return_when = concurrent.futures.FIRST_COMPLETED)
+ self.logger.debug(f'Message queue {id(self.http2ircMessageQueue)} length: {self.http2ircMessageQueue.qsize()}')
+ messageFuture = asyncio.create_task(self.http2ircMessageQueue.get())
+ done, pending = await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, paws = {messageFuture}, return_when = asyncio.FIRST_COMPLETED)
if self.connectionClosedEvent.is_set():
if messageFuture in pending:
self.logger.debug('Cancelling messageFuture')
@@ -413,11 +466,16 @@ class IRCClientProtocol(asyncio.Protocol):
pass
else:
# messageFuture is already done but we're stopping, so put the result back onto the queue
- self.messageQueue.putleft_nowait(messageFuture.result())
- return None, None
+ self.http2ircMessageQueue.putleft_nowait(messageFuture.result())
+ return None, None, None
assert messageFuture in done, 'Invalid state: messageFuture not in done futures'
return messageFuture.result()
+ def _self_usermask_length(self):
+ if not self.server.nickname or not self.server.username or not self.server.hostname:
+ return 100
+ return len(self.server.nickname) + len(self.server.username) + len(self.server.hostname)
+
async def send_messages(self):
while self.connected:
self.logger.debug(f'Trying to get a message')
@@ -427,7 +485,7 @@ class IRCClientProtocol(asyncio.Protocol):
break
channelB = channel.encode('utf-8')
messageB = message.encode('utf-8')
- usermaskPrefixLength = 1 + (len(self.usermask) if self.usermask else 100) + 1
+ usermaskPrefixLength = 1 + self._self_usermask_length() + 1
if usermaskPrefixLength + len(b'PRIVMSG ' + channelB + b' :' + messageB) > 510:
# Message too long, need to split or truncate. First try to split on spaces, then on codepoints. Ideally, would use graphemes between, but that's too complicated.
self.logger.debug(f'Message too long, overlongmode = {overlongmode}')
@@ -466,20 +524,19 @@ class IRCClientProtocol(asyncio.Protocol):
messageB = message.encode('utf-8')
if overlongmode == 'split':
for msg in reversed(messages):
- self.messageQueue.putleft_nowait((channel, msg, overlongmode))
+ self.http2ircMessageQueue.putleft_nowait((channel, msg, overlongmode))
elif overlongmode == 'truncate':
- self.messageQueue.putleft_nowait((channel, messages[0] + '…', overlongmode))
+ self.http2ircMessageQueue.putleft_nowait((channel, messages[0] + '…', overlongmode))
else:
self.logger.info(f'Sending {message!r} to {channel!r}')
self.unconfirmedMessages.append((channel, message, overlongmode))
self.send(b'PRIVMSG ' + channelB + b' :' + messageB)
- await asyncio.sleep(1) # Rate limit
async def confirm_messages(self):
while self.connected:
- await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = concurrent.futures.FIRST_COMPLETED, timeout = 60) # Confirm once per minute
+ await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED, timeout = 60) # Confirm once per minute
if not self.connected: # Disconnected while sleeping, can't confirm unconfirmed messages, requeue them directly
- self.messageQueue.putleft_nowait(*self.unconfirmedMessages)
+ self.http2ircMessageQueue.putleft_nowait(*self.unconfirmedMessages)
self.unconfirmedMessages = []
break
if not self.unconfirmedMessages:
@@ -488,18 +545,19 @@ class IRCClientProtocol(asyncio.Protocol):
self.logger.debug('Trying to confirm message delivery')
self.pongReceivedEvent.clear()
self.send(b'PING :42')
- await wait_cancel_pending({asyncio.create_task(self.pongReceivedEvent.wait())}, return_when = concurrent.futures.FIRST_COMPLETED, timeout = 5)
+ await wait_cancel_pending({asyncio.create_task(self.pongReceivedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED, timeout = 5)
self.logger.debug(f'Message delivery successful: {self.pongReceivedEvent.is_set()}')
if not self.pongReceivedEvent.is_set():
# No PONG received in five seconds, assume connection's dead
self.logger.warning(f'Message delivery confirmation failed, putting {len(self.unconfirmedMessages)} messages back into the queue')
- self.messageQueue.putleft_nowait(*self.unconfirmedMessages)
+ self.http2ircMessageQueue.putleft_nowait(*self.unconfirmedMessages)
self.transport.close()
self.unconfirmedMessages = []
def data_received(self, data):
+ time_ = time.time()
self.logger.debug(f'Data received: {data!r}')
- self.lastRecvTime = time.time()
+ self.lastRecvTime = time_
# If there's any data left in the buffer, prepend it to the data. Split on CRLF.
# Then, process all messages except the last one (since data might not end on a CRLF) and keep the remainder in the buffer.
# If data does end with CRLF, all messages will have been processed and the buffer will be empty again.
@@ -507,104 +565,146 @@ class IRCClientProtocol(asyncio.Protocol):
data = self.buffer + data
messages = data.split(b'\r\n')
for message in messages[:-1]:
- self.message_received(message)
+ lines = self.server.recv(message + b'\r\n')
+ assert len(lines) == 1, f'recv did not return exactly one line: {message!r} -> {lines!r}'
+ self.message_received(time_, message, lines[0])
+ self.server.parse_tokens(lines[0])
self.buffer = messages[-1]
- def message_received(self, message):
- self.logger.debug(f'Message received: {message!r}')
- rawMessage = message
- if message.startswith(b':') and b' ' in message:
- # Prefixed message, extract command + parameters (the prefix cannot contain a space)
- message = message.split(b' ', 1)[1]
+ def message_received(self, time_, message, line):
+ self.logger.debug(f'Message received at {time_}: {message!r}')
+
+ maybeTriggerWhox = False
# PING/PONG
- if message.startswith(b'PING '):
- self.send(b'PONG ' + message[5:])
- elif message.startswith(b'PONG '):
+ if line.command == 'PING':
+ self._direct_send(irctokens.build('PONG', line.params).format().encode('utf-8'))
+ elif line.command == 'PONG':
self.pongReceivedEvent.set()
- # SASL
- elif message.startswith(b'CAP ') and self.sasl:
- if message[message.find(b' ', 4) + 1:] == b'ACK :sasl':
- self.send(b'AUTHENTICATE EXTERNAL')
- else:
- self.logger.error(f'Received unexpected CAP reply {message!r}, terminating connection')
- self.transport.close()
- elif message == b'AUTHENTICATE +':
+ # IRCv3 and SASL
+ elif line.command == 'CAP':
+ if line.params[1] == 'ACK':
+ for cap in line.params[2].split(' '):
+ self.logger.debug(f'CAP ACK: {cap}')
+ self.caps.add(cap)
+ if cap == 'sasl' and self.sasl:
+ self.send(b'AUTHENTICATE EXTERNAL')
+ else:
+ self.capReqsPending.remove(cap)
+ elif line.params[1] == 'NAK':
+ self.logger.warning(f'Failed to activate CAP(s): {line.params[2]}')
+ for cap in line.params[2].split(' '):
+ self.capReqsPending.remove(cap)
+ if len(self.capReqsPending) == 0:
+ self.send(b'CAP END')
+ elif line.command == 'AUTHENTICATE' and line.params == ['+']:
self.send(b'AUTHENTICATE +')
- elif message.startswith(b'900 '): # "You are now logged in", includes the usermask
- words = message.split(b' ')
- if len(words) >= 3 and b'!' in words[2] and b'@' in words[2]:
- if b'!~' not in words[2]:
- # At least Charybdis seems to always return the user without a tilde, even if identd failed. Assume no identd and account for that extra tilde.
- words[2] = words[2].replace(b'!', b'!~', 1)
- self._maybe_set_usermask(words[2])
- elif message.startswith(b'903 '): # SASL auth successful
+ elif line.command == ircstates.numerics.RPL_SASLSUCCESS:
self.authenticated = True
- self.send(b'CAP END')
- elif any(message.startswith(x) for x in (b'902 ', b'904 ', b'905 ', b'906 ', b'908 ')):
+ self.capReqsPending.remove('sasl')
+ if len(self.capReqsPending) == 0:
+ self.send(b'CAP END')
+ elif line.command in ('902', ircstates.numerics.ERR_SASLFAIL, ircstates.numerics.ERR_SASLTOOLONG, ircstates.numerics.ERR_SASLABORTED, ircstates.numerics.RPL_SASLMECHS):
self.logger.error('SASL error, terminating connection')
self.transport.close()
# NICK errors
- elif any(message.startswith(x) for x in (b'431 ', b'432 ', b'433 ', b'436 ')):
+ elif line.command in ('431', ircstates.numerics.ERR_ERRONEUSNICKNAME, ircstates.numerics.ERR_NICKNAMEINUSE, '436'):
self.logger.error(f'Failed to set nickname: {message!r}, terminating connection')
self.transport.close()
# USER errors
- elif any(message.startswith(x) for x in (b'461 ', b'462 ')):
+ elif line.command in ('461', '462'):
self.logger.error(f'Failed to register: {message!r}, terminating connection')
self.transport.close()
# JOIN errors
- elif any(message.startswith(x) for x in (b'405 ', b'471 ', b'473 ', b'474 ', b'475 ')):
+ elif line.command in (
+ ircstates.numerics.ERR_TOOMANYCHANNELS,
+ ircstates.numerics.ERR_CHANNELISFULL,
+ ircstates.numerics.ERR_INVITEONLYCHAN,
+ ircstates.numerics.ERR_BANNEDFROMCHAN,
+ ircstates.numerics.ERR_BADCHANNELKEY,
+ ):
self.logger.error(f'Failed to join channel: {message!r}, terminating connection')
self.transport.close()
# PART errors
- elif message.startswith(b'442 '):
+ elif line.command == '442':
self.logger.error(f'Failed to part channel: {message!r}')
# JOIN/PART errors
- elif message.startswith(b'403 '):
+ elif line.command == ircstates.numerics.ERR_NOSUCHCHANNEL:
self.logger.error(f'Failed to join or part channel: {message!r}')
# PRIVMSG errors
- elif any(message.startswith(x) for x in (b'401 ', b'404 ', b'407 ', b'411 ', b'412 ', b'413 ', b'414 ')):
+ elif line.command in (ircstates.numerics.ERR_NOSUCHNICK, '404', '407', '411', '412', '413', '414'):
self.logger.error(f'Failed to send message: {message!r}')
# Connection registration reply
- elif message.startswith(b'001 '):
+ elif line.command == ircstates.numerics.RPL_WELCOME:
self.logger.info('IRC connection registered')
if self.sasl and not self.authenticated:
self.logger.error('IRC connection registered but not authenticated, terminating connection')
self.transport.close()
return
+ self.lastSentTime = time.time()
self._send_join_part(b'JOIN', self.channels)
asyncio.create_task(self.send_messages())
asyncio.create_task(self.confirm_messages())
- # JOIN success
- elif message.startswith(b'JOIN ') and not self.usermask:
- # If this is my own join message, it should contain the usermask in the prefix
- if rawMessage.startswith(b':' + self.config['irc']['nick'].encode('utf-8') + b'!') and b' ' in rawMessage:
- usermask = rawMessage.split(b' ', 1)[0][1:]
- self._maybe_set_usermask(usermask)
+ # Bot getting KICKed
+ elif line.command == 'KICK' and line.source and self.server.casefold(line.params[1]) == self.server.casefold(self.server.nickname):
+ self.logger.warning(f'Got kicked from {line.params[0]}')
+ kickedChannel = self.server.casefold(line.params[0])
+ for channel in self.channels:
+ if self.server.casefold(channel) == kickedChannel:
+ self.channels.remove(channel)
+ break
+
+ # WHOX on successful JOIN if supported to fetch account information
+ elif line.command == 'JOIN' and self.server.isupport.whox and line.source and self.server.casefold(line.hostmask.nickname) == self.server.casefold(self.server.nickname):
+ self.whoxQueue.extend(line.params[0].split(','))
+ maybeTriggerWhox = True
+
+ # WHOX response
+ elif line.command == ircstates.numerics.RPL_WHOSPCRPL and line.params[1] == '042':
+ self.whoxReply.append({'nick': line.params[4], 'hostmask': f'{line.params[4]}!{line.params[2]}@{line.params[3]}', 'account': line.params[5] if line.params[5] != '0' else None})
+
+ # End of WHOX response
+ elif line.command == ircstates.numerics.RPL_ENDOFWHO:
+ # Patch ircstates account info; ircstates does not parse the WHOX reply itself.
+ for entry in self.whoxReply:
+ if entry['account']:
+ self.server.users[self.server.casefold(entry['nick'])].account = entry['account']
+ self.whoxChannel = None
+ self.whoxReply = []
+ self.whoxStartTime = None
+ maybeTriggerWhox = True
+
+ # General fatal ERROR
+ elif line.command == 'ERROR':
+ self.logger.error(f'Server sent ERROR: {message!r}')
+ self.transport.close()
+
+ # Send next WHOX if appropriate
+ if maybeTriggerWhox and self.whoxChannel is None and self.whoxQueue:
+ self.whoxChannel = self.whoxQueue.popleft()
+ self.whoxReply = []
+ self.whoxStartTime = time.time() # Note, may not be the actual start time due to rate limiting
+ self.send(b'WHO ' + self.whoxChannel.encode('utf-8') + b' c%tuhna,042')
- # Services host change
- elif message.startswith(b'396 '):
- words = message.split(b' ')
- if len(words) >= 3:
- # Sanity check inspired by irssi src/irc/core/irc-servers.c
- if not any(x in words[2] for x in (b'*', b'?', b'!', b'#', b'&', b' ')) and not any(words[2].startswith(x) for x in (b'@', b':', b'-')) and words[2][-1:] != b'-':
- if b'@' in words[2]: # user@host
- self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' + words[2])
- else: # host (get user from previous mask or settings)
- if self.usermask:
- user = self.usermask.split(b'@')[0].split(b'!')[1]
- else:
- user = b'~' + self.config['irc']['nick'].encode('utf-8')
- self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' + user + b'@' + words[2])
+ async def quit(self):
+ # The server acknowledges a QUIT by sending an ERROR and closing the connection. The latter triggers connection_lost, so just wait for the closure event.
+ self.logger.info('Quitting')
+ self.lastSentTime = 1.67e34 * math.pi * 1e7 # Disable sending any further messages in send_queue
+ self._direct_send(b'QUIT :Bye')
+ await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = 10)
+ if not self.connectionClosedEvent.is_set():
+ self.logger.error('Quitting cleanly did not work, closing connection forcefully')
+ # Event will be set implicitly in connection_lost.
+ self.transport.close()
def connection_lost(self, exc):
self.logger.info('IRC connection lost')
@@ -615,8 +715,8 @@ class IRCClientProtocol(asyncio.Protocol):
class IRCClient:
logger = logging.getLogger('http2irc.IRCClient')
- def __init__(self, messageQueue, config):
- self.messageQueue = messageQueue
+ def __init__(self, http2ircMessageQueue, config):
+ self.http2ircMessageQueue = http2ircMessageQueue
self.config = config
self.channels = {map_['ircchannel'] for map_ in config['maps'].values()}
@@ -647,17 +747,43 @@ class IRCClient:
while True:
connectionClosedEvent.clear()
try:
- self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context())
+ self.logger.debug('Creating IRC connection')
+ t = asyncio.create_task(loop.create_connection(
+ protocol_factory = lambda: IRCClientProtocol(self.http2ircMessageQueue, connectionClosedEvent, loop, self.config, self.channels),
+ host = self.config['irc']['host'],
+ port = self.config['irc']['port'],
+ ssl = self._get_ssl_context(),
+ family = self.config['irc']['family'],
+ ))
+ # No automatic cancellation of t because it's handled manually below.
+ done, _ = await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, paws = {t}, return_when = asyncio.FIRST_COMPLETED, timeout = 30)
+ if t not in done:
+ t.cancel()
+ await t # Raises the CancelledError
+ self._transport, self._protocol = t.result()
+ self.logger.debug('Starting send queue processing')
+ sendTask = asyncio.create_task(self._protocol.send_queue()) # Quits automatically on connectionClosedEvent
+ self.logger.debug('Waiting for connection closure or SIGINT')
try:
- await wait_cancel_pending({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = concurrent.futures.FIRST_COMPLETED)
+ await wait_cancel_pending({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = asyncio.FIRST_COMPLETED)
finally:
- self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C?
+ self.logger.debug(f'Got connection closed {connectionClosedEvent.is_set()} / SIGINT {sigintEvent.is_set()}')
+ if not connectionClosedEvent.is_set():
+ self.logger.debug('Quitting connection')
+ await self._protocol.quit()
+ if not sendTask.done():
+ sendTask.cancel()
+ try:
+ await sendTask
+ except asyncio.CancelledError:
+ pass
self._transport = None
self._protocol = None
- except (ConnectionRefusedError, asyncio.TimeoutError) as e:
- self.logger.error(str(e))
+ except (ConnectionError, ssl.SSLError, asyncio.TimeoutError, asyncio.CancelledError) as e:
+ self.logger.error(f'{type(e).__module__}.{type(e).__name__}: {e!s}')
await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, timeout = 5)
if sigintEvent.is_set():
+ self.logger.debug('Got SIGINT, breaking IRC loop')
break
@property
@@ -668,8 +794,8 @@ class IRCClient:
class WebServer:
logger = logging.getLogger('http2irc.WebServer')
- def __init__(self, messageQueue, ircClient, config):
- self.messageQueue = messageQueue
+ def __init__(self, http2ircMessageQueue, ircClient, config):
+ self.http2ircMessageQueue = http2ircMessageQueue
self.ircClient = ircClient
self.config = config
@@ -697,7 +823,7 @@ class WebServer:
await runner.setup()
site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port'])
await site.start()
- await wait_cancel_pending({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = concurrent.futures.FIRST_COMPLETED)
+ await wait_cancel_pending({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = asyncio.FIRST_COMPLETED)
await runner.cleanup()
if stopEvent.is_set():
break
@@ -735,7 +861,7 @@ class WebServer:
self.logger.debug(f'Processing request {id(request)} using default processor')
message = await self._default_process(request)
self.logger.info(f'Accepted request {id(request)}, putting message {message!r} for {channel} into message queue')
- self.messageQueue.put_nowait((channel, message, overlongmode))
+ self.http2ircMessageQueue.put_nowait((channel, message, overlongmode))
raise aiohttp.web.HTTPOk()
async def _default_process(self, request):
@@ -777,10 +903,10 @@ async def main():
loop = asyncio.get_running_loop()
- messageQueue = MessageQueue()
+ http2ircMessageQueue = MessageQueue()
- irc = IRCClient(messageQueue, config)
- webserver = WebServer(messageQueue, irc, config)
+ irc = IRCClient(http2ircMessageQueue, config)
+ webserver = WebServer(http2ircMessageQueue, irc, config)
sigintEvent = asyncio.Event()
def sigint_callback():