diff options
| author | 2019-02-11 10:00:41 +0000 | |
|---|---|---|
| committer | 2019-02-11 10:00:41 +0000 | |
| commit | 9b44b6cd139f4e47bd57513a3af663cd2eecb571 (patch) | |
| tree | 26ba2b81b2f684acc718c5558dfdda2237d816a5 /src/IRCServer.py | |
| parent | We don't need to send `writebuffer.empty` event any more (src/IRCServer.py) (diff) | |
| signature | ||
Shift socket.socket related logic to IRCSocket.py
Diffstat (limited to 'src/IRCServer.py')
| -rw-r--r-- | src/IRCServer.py | 185 |
1 files changed, 27 insertions, 158 deletions
diff --git a/src/IRCServer.py b/src/IRCServer.py index 2907b8eb..bc800ed4 100644 --- a/src/IRCServer.py +++ b/src/IRCServer.py @@ -1,10 +1,7 @@ -import collections, datetime, socket, ssl, sys, time, typing -from src import EventManager, IRCBot, IRCChannel, IRCChannels, IRCLine -from src import IRCObject, IRCUser, utils -THROTTLE_LINES = 4 -THROTTLE_SECONDS = 1 -UNTHROTTLED_MAX_LINES = 10 +import collections, datetime, sys, time, typing +from src import EventManager, IRCBot, IRCChannel, IRCChannels, IRCLine +from src import IRCObject, IRCSocket, IRCUser, utils READ_TIMEOUT_SECONDS = 120 PING_INTERVAL_SECONDS = 30 @@ -37,16 +34,6 @@ class Server(IRCObject.Object): self.batches = {} # type: typing.Dict[str, utils.irc.IRCParsedLine] self.cap_started = False - self.write_buffer = b"" - self.queued_lines = [] # type: typing.List[IRCLine.Line] - self.buffered_lines = [] # type: typing.List[IRCLine.Line] - self._write_throttling = False - self.read_buffer = b"" - self.recent_sends = [] # type: typing.List[float] - self.cached_fileno = None # type: typing.Optional[int] - self.bytes_written = 0 - self.bytes_read = 0 - self.users = {} # type: typing.Dict[str, IRCUser.User] self.new_users = set([]) #type: typing.Set[IRCUser.User] self.channels = IRCChannels.Channels(self, self.bot, self.events) @@ -88,40 +75,27 @@ class Server(IRCObject.Object): return "%s:%s%s" % (self.connection_params.hostname, "+" if self.connection_params.tls else "", self.connection_params.port) + def fileno(self) -> int: - return self.cached_fileno or self.socket.fileno() + return self.socket.fileno() def hostmask(self): return "%s!%s@%s" % (self.nickname, self.username, self.hostname) - def tls_wrap(self): - client_certificate = self.bot.config.get("tls-certificate", None) - client_key = self.bot.config.get("tls-key", None) - verify = self.get_setting("ssl-verify", True) - - server_hostname = None - if not utils.is_ip(self.connection_params.hostname): - server_hostname = self.connection_params.hostname - - self.socket = utils.security.ssl_wrap(self.socket, - cert=client_certificate, key=client_key, - verify=verify, hostname=server_hostname) - def connect(self): - ipv4 = self.connection_params.ipv4 - family = socket.AF_INET if ipv4 else socket.AF_INET6 - self.socket = socket.socket(family, socket.SOCK_STREAM) - - self.socket.settimeout(5.0) - - if self.connection_params.bindhost: - self.socket.bind((self.connection_params.bindhost, 0)) - if self.connection_params.tls: - self.tls_wrap() - - self.socket.connect((self.connection_params.hostname, - self.connection_params.port)) - self.cached_fileno = self.socket.fileno() + self.socket = IRCSocket.Socket( + self.bot.log, + self.get_setting("encoding", "utf8"), + self.get_setting("fallback-encoding", "iso-8859-1"), + self.connection_params.hostname, + self.connection_params.port, + self.connection_params.ipv4, + self.connection_params.bindhost, + self.connection_params.tls, + tls_verify=self.get_setting("ssl-verify", True), + cert=self.bot.config.get("tls-certificate", None), + key=self.bot.config.get("tls-key", None)) + self.socket.connect() if self.connection_params.password: self.send_pass(self.connection_params.password) @@ -135,16 +109,9 @@ class Server(IRCObject.Object): self.send_user(username, realname) self.send_nick(nickname) self.connected = True + def disconnect(self): - self.connected = False - try: - self.socket.shutdown(socket.SHUT_RDWR) - except: - pass - try: - self.socket.close() - except: - pass + self.socket.disconnect() def set_setting(self, setting: str, value: typing.Any): self.bot.database.server_settings.set(self.id, setting, @@ -252,46 +219,6 @@ class Server(IRCObject.Object): if not len(user.channels): self.remove_user(user) self.new_users.clear() - def read(self) -> typing.Optional[typing.List[str]]: - data = b"" - try: - data = self.socket.recv(4096) - except (ConnectionResetError, socket.timeout, OSError): - self.disconnect() - return None - if not data: - self.disconnect() - return None - self.bytes_read += len(data) - data = self.read_buffer+data - self.read_buffer = b"" - - data_lines = [line.strip(b"\r") for line in data.split(b"\n")] - if data_lines[-1]: - self.read_buffer = data_lines[-1] - self.bot.log.trace("recevied and buffered non-complete line: %s", - [data_lines[-1]]) - - data_lines.pop(-1) - decoded_lines = [] - - for line in data_lines: - encoding = self.get_setting("encoding", "utf8") - try: - decoded_line = line.decode(encoding) - except: - self.bot.log.trace("can't decode line with '%s', falling back", - [encoding]) - try: - decoded_line = line.decode(self.get_setting( - "fallback-encoding", "latin-1")) - except: - continue - decoded_lines.append(decoded_line) - - self.last_read = time.monotonic() - self.ping_sent = False - return decoded_lines def until_next_ping(self) -> typing.Optional[float]: if self.ping_sent: @@ -307,6 +234,8 @@ class Server(IRCObject.Object): def read_timed_out(self) -> bool: return self.until_read_timeout == 0 + def read(self) -> typing.Optional[typing.List[str]]: + return self.socket.read() def send(self, line: str): results = self.events.on("preprocess.send").call_unsafe( server=self, line=line) @@ -314,75 +243,15 @@ class Server(IRCObject.Object): if result: line = result break - line_stripped = line.split("\n", 1)[0].strip("\r") line_obj = IRCLine.Line(self, datetime.datetime.utcnow(), line_stripped) - self.queued_lines.append(line_obj) - + self.socket.send(line_obj) return line_obj - def _send(self): - if not len(self.write_buffer): - throttle_space = self.throttle_space() - to_buffer = self.queued_lines[:throttle_space] - self.queued_lines = self.queued_lines[throttle_space:] - for line in to_buffer: - decoded_data = line.decoded_data() - self.bot.log.debug("%s (raw send) | %s", - [str(self), decoded_data]) - self.events.on("raw.send").call_unsafe( - server=self, line=decoded_data) - - self.write_buffer += line.data() - self.buffered_lines.append(line) - - bytes_written_i = self.socket.send(self.write_buffer) - bytes_written = self.write_buffer[:bytes_written_i] - lines_sent = bytes_written.count(b"\r\n") - for i in range(lines_sent): - self.buffered_lines.pop(0).sent() - - self.write_buffer = self.write_buffer[bytes_written_i:] - - self.bytes_written += bytes_written_i - - now = time.monotonic() - self.recent_sends.append(now) - self.last_send = now - def waiting_send(self) -> bool: - return bool(len(self.write_buffer)) or bool(len(self.queued_lines)) - - def throttle_done(self) -> bool: - return self.send_throttle_timeout() == 0 - - def throttle_prune(self): - now = time.monotonic() - popped = 0 - for i, recent_send in enumerate(self.recent_sends[:]): - time_since = now-recent_send - if time_since >= THROTTLE_SECONDS: - self.recent_sends.pop(i-popped) - popped += 1 - - def throttle_space(self) -> int: - if not self._write_throttling: - return UNTHROTTLED_MAX_LINES - return max(0, THROTTLE_LINES-len(self.recent_sends)) - - def send_throttle_timeout(self) -> float: - if len(self.write_buffer) or not self._write_throttling: - return 0 - - self.throttle_prune() - if self.throttle_space() > 0: - return 0 - - time_left = self.recent_sends[0]+THROTTLE_SECONDS - time_left = time_left-time.monotonic() - return time_left - - def set_write_throttling(self, is_on: bool): - self._write_throttling = is_on + lines = self.socket._send() + for line in lines: + self.bot.log.debug("%s (raw send) | %s", [str(self), line]) + self.events.on("raw.send").call_unsafe(server=self, line=line) def send_user(self, username: str, realname: str) -> IRCLine.Line: return self.send("USER %s 0 * :%s" % (username, realname)) |
