aboutsummaryrefslogtreecommitdiff
path: root/src/IRCServer.py
diff options
context:
space:
mode:
authorGravatar jesopo2019-02-11 10:00:41 +0000
committerGravatar jesopo2019-02-11 10:00:41 +0000
commit9b44b6cd139f4e47bd57513a3af663cd2eecb571 (patch)
tree26ba2b81b2f684acc718c5558dfdda2237d816a5 /src/IRCServer.py
parentWe don't need to send `writebuffer.empty` event any more (src/IRCServer.py) (diff)
signature
Shift socket.socket related logic to IRCSocket.py
Diffstat (limited to 'src/IRCServer.py')
-rw-r--r--src/IRCServer.py185
1 files changed, 27 insertions, 158 deletions
diff --git a/src/IRCServer.py b/src/IRCServer.py
index 2907b8eb..bc800ed4 100644
--- a/src/IRCServer.py
+++ b/src/IRCServer.py
@@ -1,10 +1,7 @@
-import collections, datetime, socket, ssl, sys, time, typing
-from src import EventManager, IRCBot, IRCChannel, IRCChannels, IRCLine
-from src import IRCObject, IRCUser, utils
-THROTTLE_LINES = 4
-THROTTLE_SECONDS = 1
-UNTHROTTLED_MAX_LINES = 10
+import collections, datetime, sys, time, typing
+from src import EventManager, IRCBot, IRCChannel, IRCChannels, IRCLine
+from src import IRCObject, IRCSocket, IRCUser, utils
READ_TIMEOUT_SECONDS = 120
PING_INTERVAL_SECONDS = 30
@@ -37,16 +34,6 @@ class Server(IRCObject.Object):
self.batches = {} # type: typing.Dict[str, utils.irc.IRCParsedLine]
self.cap_started = False
- self.write_buffer = b""
- self.queued_lines = [] # type: typing.List[IRCLine.Line]
- self.buffered_lines = [] # type: typing.List[IRCLine.Line]
- self._write_throttling = False
- self.read_buffer = b""
- self.recent_sends = [] # type: typing.List[float]
- self.cached_fileno = None # type: typing.Optional[int]
- self.bytes_written = 0
- self.bytes_read = 0
-
self.users = {} # type: typing.Dict[str, IRCUser.User]
self.new_users = set([]) #type: typing.Set[IRCUser.User]
self.channels = IRCChannels.Channels(self, self.bot, self.events)
@@ -88,40 +75,27 @@ class Server(IRCObject.Object):
return "%s:%s%s" % (self.connection_params.hostname,
"+" if self.connection_params.tls else "",
self.connection_params.port)
+
def fileno(self) -> int:
- return self.cached_fileno or self.socket.fileno()
+ return self.socket.fileno()
def hostmask(self):
return "%s!%s@%s" % (self.nickname, self.username, self.hostname)
- def tls_wrap(self):
- client_certificate = self.bot.config.get("tls-certificate", None)
- client_key = self.bot.config.get("tls-key", None)
- verify = self.get_setting("ssl-verify", True)
-
- server_hostname = None
- if not utils.is_ip(self.connection_params.hostname):
- server_hostname = self.connection_params.hostname
-
- self.socket = utils.security.ssl_wrap(self.socket,
- cert=client_certificate, key=client_key,
- verify=verify, hostname=server_hostname)
-
def connect(self):
- ipv4 = self.connection_params.ipv4
- family = socket.AF_INET if ipv4 else socket.AF_INET6
- self.socket = socket.socket(family, socket.SOCK_STREAM)
-
- self.socket.settimeout(5.0)
-
- if self.connection_params.bindhost:
- self.socket.bind((self.connection_params.bindhost, 0))
- if self.connection_params.tls:
- self.tls_wrap()
-
- self.socket.connect((self.connection_params.hostname,
- self.connection_params.port))
- self.cached_fileno = self.socket.fileno()
+ self.socket = IRCSocket.Socket(
+ self.bot.log,
+ self.get_setting("encoding", "utf8"),
+ self.get_setting("fallback-encoding", "iso-8859-1"),
+ self.connection_params.hostname,
+ self.connection_params.port,
+ self.connection_params.ipv4,
+ self.connection_params.bindhost,
+ self.connection_params.tls,
+ tls_verify=self.get_setting("ssl-verify", True),
+ cert=self.bot.config.get("tls-certificate", None),
+ key=self.bot.config.get("tls-key", None))
+ self.socket.connect()
if self.connection_params.password:
self.send_pass(self.connection_params.password)
@@ -135,16 +109,9 @@ class Server(IRCObject.Object):
self.send_user(username, realname)
self.send_nick(nickname)
self.connected = True
+
def disconnect(self):
- self.connected = False
- try:
- self.socket.shutdown(socket.SHUT_RDWR)
- except:
- pass
- try:
- self.socket.close()
- except:
- pass
+ self.socket.disconnect()
def set_setting(self, setting: str, value: typing.Any):
self.bot.database.server_settings.set(self.id, setting,
@@ -252,46 +219,6 @@ class Server(IRCObject.Object):
if not len(user.channels):
self.remove_user(user)
self.new_users.clear()
- def read(self) -> typing.Optional[typing.List[str]]:
- data = b""
- try:
- data = self.socket.recv(4096)
- except (ConnectionResetError, socket.timeout, OSError):
- self.disconnect()
- return None
- if not data:
- self.disconnect()
- return None
- self.bytes_read += len(data)
- data = self.read_buffer+data
- self.read_buffer = b""
-
- data_lines = [line.strip(b"\r") for line in data.split(b"\n")]
- if data_lines[-1]:
- self.read_buffer = data_lines[-1]
- self.bot.log.trace("recevied and buffered non-complete line: %s",
- [data_lines[-1]])
-
- data_lines.pop(-1)
- decoded_lines = []
-
- for line in data_lines:
- encoding = self.get_setting("encoding", "utf8")
- try:
- decoded_line = line.decode(encoding)
- except:
- self.bot.log.trace("can't decode line with '%s', falling back",
- [encoding])
- try:
- decoded_line = line.decode(self.get_setting(
- "fallback-encoding", "latin-1"))
- except:
- continue
- decoded_lines.append(decoded_line)
-
- self.last_read = time.monotonic()
- self.ping_sent = False
- return decoded_lines
def until_next_ping(self) -> typing.Optional[float]:
if self.ping_sent:
@@ -307,6 +234,8 @@ class Server(IRCObject.Object):
def read_timed_out(self) -> bool:
return self.until_read_timeout == 0
+ def read(self) -> typing.Optional[typing.List[str]]:
+ return self.socket.read()
def send(self, line: str):
results = self.events.on("preprocess.send").call_unsafe(
server=self, line=line)
@@ -314,75 +243,15 @@ class Server(IRCObject.Object):
if result:
line = result
break
-
line_stripped = line.split("\n", 1)[0].strip("\r")
line_obj = IRCLine.Line(self, datetime.datetime.utcnow(), line_stripped)
- self.queued_lines.append(line_obj)
-
+ self.socket.send(line_obj)
return line_obj
-
def _send(self):
- if not len(self.write_buffer):
- throttle_space = self.throttle_space()
- to_buffer = self.queued_lines[:throttle_space]
- self.queued_lines = self.queued_lines[throttle_space:]
- for line in to_buffer:
- decoded_data = line.decoded_data()
- self.bot.log.debug("%s (raw send) | %s",
- [str(self), decoded_data])
- self.events.on("raw.send").call_unsafe(
- server=self, line=decoded_data)
-
- self.write_buffer += line.data()
- self.buffered_lines.append(line)
-
- bytes_written_i = self.socket.send(self.write_buffer)
- bytes_written = self.write_buffer[:bytes_written_i]
- lines_sent = bytes_written.count(b"\r\n")
- for i in range(lines_sent):
- self.buffered_lines.pop(0).sent()
-
- self.write_buffer = self.write_buffer[bytes_written_i:]
-
- self.bytes_written += bytes_written_i
-
- now = time.monotonic()
- self.recent_sends.append(now)
- self.last_send = now
- def waiting_send(self) -> bool:
- return bool(len(self.write_buffer)) or bool(len(self.queued_lines))
-
- def throttle_done(self) -> bool:
- return self.send_throttle_timeout() == 0
-
- def throttle_prune(self):
- now = time.monotonic()
- popped = 0
- for i, recent_send in enumerate(self.recent_sends[:]):
- time_since = now-recent_send
- if time_since >= THROTTLE_SECONDS:
- self.recent_sends.pop(i-popped)
- popped += 1
-
- def throttle_space(self) -> int:
- if not self._write_throttling:
- return UNTHROTTLED_MAX_LINES
- return max(0, THROTTLE_LINES-len(self.recent_sends))
-
- def send_throttle_timeout(self) -> float:
- if len(self.write_buffer) or not self._write_throttling:
- return 0
-
- self.throttle_prune()
- if self.throttle_space() > 0:
- return 0
-
- time_left = self.recent_sends[0]+THROTTLE_SECONDS
- time_left = time_left-time.monotonic()
- return time_left
-
- def set_write_throttling(self, is_on: bool):
- self._write_throttling = is_on
+ lines = self.socket._send()
+ for line in lines:
+ self.bot.log.debug("%s (raw send) | %s", [str(self), line])
+ self.events.on("raw.send").call_unsafe(server=self, line=line)
def send_user(self, username: str, realname: str) -> IRCLine.Line:
return self.send("USER %s 0 * :%s" % (username, realname))