aboutsummaryrefslogtreecommitdiff
path: root/src/IRCSocket.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/IRCSocket.py')
-rw-r--r--src/IRCSocket.py177
1 files changed, 177 insertions, 0 deletions
diff --git a/src/IRCSocket.py b/src/IRCSocket.py
new file mode 100644
index 00000000..e9379ddd
--- /dev/null
+++ b/src/IRCSocket.py
@@ -0,0 +1,177 @@
+import datetime, socket, ssl, time, typing
+from src import IRCLine, Logging, IRCObject, utils
+
+THROTTLE_LINES = 4
+THROTTLE_SECONDS = 1
+UNTHROTTLED_MAX_LINES = 10
+
+class Socket(IRCObject.Object):
+ def __init__(self, log: Logging.Log, encoding: str, fallback_encoding: str,
+ hostname: str, port: int, ipv4: bool, bindhost: str, tls: bool,
+ tls_verify: bool=True, cert: str=None, key: str=None):
+ self.log = log
+
+ self._encoding = encoding
+ self._fallback_encoding = fallback_encoding
+ self._hostname = hostname
+ self._port = port
+ self._ipv4 = ipv4
+ self._bindhost = bindhost
+
+ self._tls = tls
+ self._tls_verify = tls_verify
+ self._cert = cert
+ self._key = key
+
+ self._write_buffer = b""
+ self._queued_lines = [] # type: typing.List[IRCLine.Line]
+ self._buffered_lines = [] # type: typing.List[IRCLine.Line]
+ self._write_throttling = False
+ self._read_buffer = b""
+ self._recent_sends = [] # type: typing.List[float]
+ self.cached_fileno = None # type: typing.Optional[int]
+ self.bytes_written = 0
+ self.bytes_read = 0
+
+ def fileno(self) -> int:
+ return self.cached_fileno or self._socket.fileno()
+
+ def _tls_wrap(self):
+ server_hostname = None
+ if not utils.is_ip(self._hostname):
+ server_hostname = self._hostname
+
+ self._socket = utils.security.ssl_wrap(self._socket,
+ cert=self._cert, key=self._key, verify=self._tls_verify,
+ hostname=server_hostname)
+
+ def connect(self):
+ family = socket.AF_INET if self._ipv4 else socket.AF_INET6
+ self._socket = socket.socket(family, socket.SOCK_STREAM)
+
+ self._socket.settimeout(5.0)
+
+ if self._bindhost:
+ self._socket.bind((self._bindhost, 0))
+ if self._tls:
+ self._tls_wrap()
+
+ self._socket.connect((self._hostname, self._port))
+ self.cached_fileno = self._socket.fileno()
+
+ def disconnect(self):
+ self.connected = False
+ try:
+ self._socket.shutdown(socket.SHUT_RDWR)
+ except:
+ pass
+ try:
+ self._socket.close()
+ except:
+ pass
+
+ def read(self) -> typing.Optional[typing.List[str]]:
+ data = b""
+ try:
+ data = self._socket.recv(4096)
+ except (ConnectionResetError, socket.timeout, OSError):
+ self.disconnect()
+ return None
+ if not data:
+ self.disconnect()
+ return None
+ self.bytes_read += len(data)
+ data = self._read_buffer+data
+ self._read_buffer = b""
+
+ data_lines = [line.strip(b"\r") for line in data.split(b"\n")]
+ if data_lines[-1]:
+ self._read_buffer = data_lines[-1]
+ self.log.trace("recevied and buffered non-complete line: %s",
+ [data_lines[-1]])
+
+ data_lines.pop(-1)
+ decoded_lines = []
+
+ for line in data_lines:
+ try:
+ decoded_line = line.decode(self._encoding)
+ except:
+ self.log.trace("can't decode line with '%s', falling back",
+ [self._encoding])
+ try:
+ decoded_line = line.decode(self._fallback_encoding)
+ except:
+ continue
+ decoded_lines.append(decoded_line)
+
+ self.last_read = time.monotonic()
+ self.ping_sent = False
+ return decoded_lines
+
+ def send(self, line: IRCLine.Line):
+ self._queued_lines.append(line)
+
+ def _send(self) -> typing.List[str]:
+ decoded_sent = []
+ if not len(self._write_buffer):
+ throttle_space = self.throttle_space()
+ to_buffer = self._queued_lines[:throttle_space]
+ self._queued_lines = self._queued_lines[throttle_space:]
+ for line in to_buffer:
+ decoded_data = line.decoded_data()
+ decoded_sent.append(decoded_data)
+
+ self._write_buffer += line.data()
+ self._buffered_lines.append(line)
+
+ bytes_written_i = self._socket.send(self._write_buffer)
+ bytes_written = self._write_buffer[:bytes_written_i]
+ lines_sent = bytes_written.count(b"\r\n")
+ for i in range(lines_sent):
+ self._buffered_lines.pop(0).sent()
+
+ self._write_buffer = self._write_buffer[bytes_written_i:]
+
+ self.bytes_written += bytes_written_i
+
+ now = time.monotonic()
+ self._recent_sends.append(now)
+ self.last_send = now
+
+ return decoded_sent
+
+ def waiting_send(self) -> bool:
+ return bool(len(self._write_buffer)) or bool(len(self._queued_lines))
+
+ def throttle_done(self) -> bool:
+ return self.send_throttle_timeout() == 0
+
+ def throttle_prune(self):
+ now = time.monotonic()
+ popped = 0
+ for i, recent_send in enumerate(self._recent_sends[:]):
+ time_since = now-recent_send
+ if time_since >= THROTTLE_SECONDS:
+ self._recent_sends.pop(i-popped)
+ popped += 1
+
+ def throttle_space(self) -> int:
+ if not self._write_throttling:
+ return UNTHROTTLED_MAX_LINES
+ return max(0, THROTTLE_LINES-len(self._recent_sends))
+
+ def send_throttle_timeout(self) -> float:
+ if len(self._write_buffer) or not self._write_throttling:
+ return 0
+
+ self.throttle_prune()
+ if self.throttle_space() > 0:
+ return 0
+
+ time_left = self._recent_sends[0]+THROTTLE_SECONDS
+ time_left = time_left-time.monotonic()
+ return time_left
+
+ def set_write_throttling(self, is_on: bool):
+ self._write_throttling = is_on