diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/IRCBot.py | 27 | ||||
| -rw-r--r-- | src/IRCLine.py | 15 | ||||
| -rw-r--r-- | src/IRCServer.py | 25 | ||||
| -rw-r--r-- | src/IRCSocket.py | 17 | ||||
| -rw-r--r-- | src/ModuleManager.py | 30 | ||||
| -rw-r--r-- | src/Timers.py | 3 | ||||
| -rw-r--r-- | src/utils/http.py | 7 | ||||
| -rw-r--r-- | src/utils/irc/__init__.py | 42 |
8 files changed, 103 insertions, 63 deletions
diff --git a/src/IRCBot.py b/src/IRCBot.py index 67a85c42..f5407e59 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -2,7 +2,7 @@ import enum, queue, os, select, socket, threading, time, traceback, typing, uuid from src import EventManager, Exports, IRCServer, Logging, ModuleManager from src import Socket, utils -VERSION = "v1.7.1" +VERSION = "v1.8.0" SOURCE = "https://git.io/bitbot" class TriggerResult(enum.Enum): @@ -62,8 +62,18 @@ class Bot(object): def load_modules(self, safe: bool=False ) -> typing.Tuple[typing.List[str], typing.List[str]]: - whitelist = self.get_setting("module-whitelist", []) - blacklist = self.get_setting("module-blacklist", []) + db_blacklist = set(self.get_setting("module-blacklist", [])) + db_whitelist = set(self.get_setting("module-whitelist", [])) + + conf_blacklist = self.config.get("module-blacklist", "").split(",") + conf_whitelist = self.config.get("module-whitelist", "").split(",") + + conf_blacklist = set(filter(None, conf_blacklist)) + conf_whitelist = set(filter(None, conf_whitelist)) + + blacklist = db_blacklist|conf_blacklist + whitelist = db_whitelist|conf_whitelist + return self.modules.load_modules(self, whitelist=whitelist, blacklist=blacklist, safe=safe) @@ -237,7 +247,13 @@ class Bot(object): for piece in data: sock.parse_data(piece) elif event & select.EPOLLOUT: - sock._send() + try: + sock._send() + except: + self.log.error("Failed to write to %s", + [str(sock)]) + raise + if sock.fileno() in self.servers: self.register_read(sock) elif event & select.EPULLHUP: @@ -262,7 +278,8 @@ class Bot(object): self.log.warn( "Disconnected from %s, reconnecting in %d seconds", [str(server), reconnect_delay]) - elif (server.socket.waiting_send() and + elif server.socket.waiting_immediate_send() or ( + server.socket.waiting_send() and server.socket.throttle_done()): self.register_both(server) diff --git a/src/IRCLine.py b/src/IRCLine.py index d388b463..a612c0ad 100644 --- a/src/IRCLine.py +++ b/src/IRCLine.py @@ -39,14 +39,19 @@ class Hostmask(object): class ParsedLine(object): def __init__(self, command: str, args: typing.List[str], - prefix: Hostmask=None, - tags: typing.Dict[str, str]={}): + source: Hostmask=None, + tags: typing.Dict[str, str]=None): self.command = command self._args = args self.args = IRCArgs(args) - self.prefix = prefix + self.source = source self.tags = {} if tags == None else tags + def __repr__(self): + return "ParsedLine(%s)" % self.__str__() + def __str__(self): + return self.format() + def _tag_str(self, tags: typing.Dict[str, str]) -> str: tag_pieces = [] for tag, value in tags.items(): @@ -64,8 +69,8 @@ class ParsedLine(object): if self.tags: pieces.append(self._tag_str(self.tags)) - if self.prefix: - pieces.append(str(self.prefix)) + if self.source: + pieces.append(str(self.source)) pieces.append(self.command.upper()) diff --git a/src/IRCServer.py b/src/IRCServer.py index 3a4ff182..e6c08ee6 100644 --- a/src/IRCServer.py +++ b/src/IRCServer.py @@ -17,6 +17,7 @@ class Server(IRCObject.Object): self.id = id self.alias = alias self.connection_params = connection_params + self.connected = False self.name = None # type: typing.Optional[str] self.version = None # type: typing.Optional[str] @@ -30,7 +31,7 @@ class Server(IRCObject.Object): self._capabilities_waiting = set([]) # type: typing.Set[str] self.agreed_capabilities = set([]) # type: typing.Set[str] self.server_capabilities = {} # type: typing.Dict[str, str] - self.batches = {} # type: typing.Dict[str, IRCLine.ParsedLine] + self.batches = {} # type: typing.Dict[str, utils.irc.IRCBatch] self.cap_started = False self.users = {} # type: typing.Dict[str, IRCUser.User] @@ -239,9 +240,6 @@ class Server(IRCObject.Object): if lines: self.ping_sent = False - now = datetime.datetime.utcnow() - self.set_setting("last-read", - utils.iso8601_format(now, milliseconds=True)) return lines def send(self, line_parsed: IRCLine.ParsedLine): @@ -250,6 +248,8 @@ class Server(IRCObject.Object): self.events.on("preprocess.send").on(line_parsed.command ).call_unsafe(server=self, line=line_parsed) + self.events.on("preprocess.send").call_unsafe(server=self, + line=line_parsed) line = line_parsed.format() line_obj = IRCLine.SentLine(datetime.datetime.utcnow(), self.hostmask(), @@ -285,9 +285,12 @@ class Server(IRCObject.Object): def send_authenticate(self, text: str) -> IRCLine.SentLine: return self.send(utils.irc.protocol.authenticate(text)) def has_capability(self, capability: utils.irc.Capability) -> bool: - return bool(capability.available(self.agreed_capabilities)) + return bool(self.available_capability(capability)) def has_capability_str(self, capability: str) -> bool: return capability in self.agreed_capabilities + def available_capability(self, capability: utils.irc.Capability + ) -> typing.Optional[str]: + return capability.available(self.agreed_capabilities) def waiting_for_capabilities(self) -> bool: return bool(len(self._capabilities_waiting)) @@ -360,15 +363,3 @@ class Server(IRCObject.Object): def send_whox(self, mask: str, filter: str, fields: str, label: str=None ) -> IRCLine.SentLine: return self.send(utils.irc.protocol.whox(mask, filter, fields, label)) - - def make_batch(self, identifier: str, batch_type: str, - tags: typing.Dict[str, str]={}) -> utils.irc.IRCSendBatch: - return utils.irc.IRCSendBatch(identifier, batch_type, tags) - def send_batch(self, batch: utils.irc.IRCSendBatch) -> IRCLine.SentLine: - self.send(utils.irc.protocol.batch_start(batch.id, batch.type, - batch.tags)) - - for line in batch.lines: - self.send(line) - - return self.send(utils.irc.protocol.batch_end(batch.id)) diff --git a/src/IRCSocket.py b/src/IRCSocket.py index 82c736ee..3a38e8e8 100644 --- a/src/IRCSocket.py +++ b/src/IRCSocket.py @@ -27,13 +27,15 @@ class Socket(IRCObject.Object): self._write_buffer = b"" self._queued_lines = [] # type: typing.List[IRCLine.SentLine] self._buffered_lines = [] # type: typing.List[IRCLine.SentLine] - self._write_throttling = False self._read_buffer = b"" self._recent_sends = [] # type: typing.List[float] self.cached_fileno = None # type: typing.Optional[int] self.bytes_written = 0 self.bytes_read = 0 + self._write_throttling = False + self._throttle_when_empty = False + self.last_read = time.monotonic() self.last_send = None # type: typing.Optional[float] @@ -58,6 +60,7 @@ class Socket(IRCObject.Object): bindhost = (self._bindhost, 0) self._socket = socket.create_connection((self._hostname, self._port), 5.0, bindhost) + self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) if self._tls: self._tls_wrap() @@ -121,6 +124,12 @@ class Socket(IRCObject.Object): def _send(self) -> typing.List[IRCLine.ParsedLine]: sent_lines = [] + + if not self._write_buffer and self._throttle_when_empty: + self._throttle_when_empty = False + self._write_throttling = True + self._recent_sends.clear() + throttle_space = self.throttle_space() if throttle_space: to_buffer = self._queued_lines[:throttle_space] @@ -152,6 +161,8 @@ class Socket(IRCObject.Object): def waiting_send(self) -> bool: return bool(len(self._write_buffer)) or bool(len(self._queued_lines)) + def waiting_immediate_send(self) -> bool: + return bool(len(self._write_buffer)) def throttle_done(self) -> bool: return self.send_throttle_timeout() == 0 @@ -182,5 +193,5 @@ class Socket(IRCObject.Object): time_left = time_left-time.monotonic() return time_left - def set_write_throttling(self, is_on: bool): - self._write_throttling = is_on + def enable_write_throttle(self): + self._throttle_when_empty = True diff --git a/src/ModuleManager.py b/src/ModuleManager.py index 62dd6fa6..5f1bfa61 100644 --- a/src/ModuleManager.py +++ b/src/ModuleManager.py @@ -19,9 +19,13 @@ class ModuleNotLoadedWarning(ModuleWarning): pass class ModuleDependencyNotFulfilled(ModuleException): - def __init__(self, message, dependency): - ModuleException.__init__(self, message) + def __init__(self, module, dependency): + ModuleException.__init__(self, "Dependency for %s not fulfilled: %s" + % (module, dependency)) + self.module = module self.dependency = dependency +class ModuleCircularDependency(ModuleException): + pass class ModuleType(enum.Enum): FILE = 0 @@ -123,6 +127,8 @@ class ModuleManager(object): if os.path.isdir(path): type = ModuleType.DIRECTORY path = os.path.join(path, "__init__.py") + else: + path = "%s.py" % path return self.define_module(type, path) @@ -155,9 +161,8 @@ class ModuleManager(object): dependencies = definition.get_dependencies() for dependency in dependencies: if not dependency in self.modules: - raise ModuleDependencyNotFulfilled( - "Dependency for %s not fulfilled: %s" % - (definition.name, dependency) ,dependency) + raise ModuleDependencyNotFulfilled(definition.name, + dependency) for hashflag, value in definition.hashflags: if hashflag == "ignore": @@ -235,12 +240,19 @@ class ModuleManager(object): definition_dependencies = { d.name: d.get_dependencies() for d in definitions} + for name, deps in definition_dependencies.items(): + for dep in deps: + if not dep in definition_dependencies: + # unknown dependency! + raise ModuleDependencyNotFulfilled(name, dep) + while definition_dependencies: changed = False to_remove = [] for name, dependencies in definition_dependencies.items(): if not dependencies: + changed = True # pop things with no unfufilled dependencies to_remove.append(name) for name in to_remove: @@ -248,19 +260,23 @@ class ModuleManager(object): del definition_dependencies[name] for deps in definition_dependencies.values(): if name in deps: - # fulfill dependencies for things we just popped changed = True + # fulfill dependencies for things we just popped deps.remove(name) if not changed: for name, deps in definition_dependencies.items(): for dep_name in deps: if name in definition_dependencies[dep_name]: - self.log.warn("Circular dependencies: %s<->%s", + self.log.warn( + "Circular dependencies detected: %s<->%s", [name, dep_name]) + changed = True # snap a circular dependence deps.remove(dep_name) definition_dependencies[dep_name].remove(name) + if not changed: + raise ModuleCircularDependency() return [definition_names[name] for name in definitions_ordered] diff --git a/src/Timers.py b/src/Timers.py index 53865181..88cceeed 100644 --- a/src/Timers.py +++ b/src/Timers.py @@ -83,7 +83,8 @@ class Timers(object): self.timers.append(timer) def next(self) -> typing.Optional[float]: - times = filter(None, [timer.time_left() for timer in self.get_timers()]) + times = list(filter(None, + [timer.time_left() for timer in self.get_timers()])) if not times: return None return max(min(times), 0) diff --git a/src/utils/http.py b/src/utils/http.py index 0488b0af..2d7c512a 100644 --- a/src/utils/http.py +++ b/src/utils/http.py @@ -64,18 +64,17 @@ def request(url: str, method: str="GET", get_params: dict={}, signal.signal(signal.SIGALRM, signal.SIG_IGN) response_headers = utils.CaseInsensitiveDict(dict(response.headers)) - + data = response_content.decode(response.encoding or fallback_encoding) content_type = response.headers["Content-Type"].split(";", 1)[0] + if soup: if content_type in SOUP_CONTENT_TYPES: - soup = bs4.BeautifulSoup(response_content, parser) + soup = bs4.BeautifulSoup(data, parser) return Response(response.status_code, soup, response_headers) else: raise HTTPWrongContentTypeException( "Tried to soup non-html/non-xml data") - - data = response_content.decode(response.encoding or fallback_encoding) if json and data: try: return Response(response.status_code, _json.loads(data), diff --git a/src/utils/irc/__init__.py b/src/utils/irc/__init__.py index 14bbb5b0..9577bcc1 100644 --- a/src/utils/irc/__init__.py +++ b/src/utils/irc/__init__.py @@ -45,7 +45,7 @@ def message_tag_unescape(s): def parse_line(line: str) -> IRCLine.ParsedLine: tags = {} # type: typing.Dict[str, typing.Any] - prefix = None # type: typing.Optional[IRCLine.Hostmask] + source = None # type: typing.Optional[IRCLine.Hostmask] command = None if line[0] == "@": @@ -65,8 +65,8 @@ def parse_line(line: str) -> IRCLine.ParsedLine: trailing = trailing_split if line[0] == ":": - prefix_str, line = line[1:].split(" ", 1) - prefix = seperate_hostmask(prefix_str) + source_str, line = line[1:].split(" ", 1) + source = seperate_hostmask(source_str) command, sep, line = line.partition(" ") args = [] # type: typing.List[str] @@ -77,7 +77,7 @@ def parse_line(line: str) -> IRCLine.ParsedLine: if not trailing == None: args.append(typing.cast(str, trailing)) - return IRCLine.ParsedLine(command, args, prefix, tags) + return IRCLine.ParsedLine(command, args, source, tags) REGEX_COLOR = re.compile("%s(?:(\d{1,2})(?:,(\d{1,2}))?)?" % utils.consts.COLOR) @@ -258,28 +258,25 @@ def parse_ctcp(s: str) -> typing.Optional[CTCPMessage]: return None class IRCBatch(object): - def __init__(self, identifier: str, batch_type: str, tags: - typing.Dict[str, str]={}): + def __init__(self, identifier: str, batch_type: str, args: typing.List[str], + tags: typing.Dict[str, str]={}): self.id = identifier self.type = batch_type + self.args = args self.tags = tags - self.lines = [] # type: typing.List[IRCLine.ParsedLine] -class IRCRecvBatch(IRCBatch): - pass -class IRCSendBatch(IRCBatch): - def _add_line(self, line: IRCLine.ParsedLine): - line.tags["batch"] = self.id - self.lines.append(line) - def message(self, target: str, message: str, tags: dict={}): - self._add_line(utils.irc.protocol.message(target, message, tags)) - def notice(self, target: str, message: str, tags: dict={}): - self._add_line(utils.irc.protocol.notice(target, message, tags)) + self._lines = [] # type: typing.List[IRCLine.ParsedLine] + def add_line(self, line: IRCLine.ParsedLine): + self._lines.append(line) + def get_lines(self) -> typing.List[IRCLine.ParsedLine]: + return self._lines class Capability(object): - def __init__(self, name, draft_name=None): + def __init__(self, name: typing.Optional[str], draft_name: str=None): self._caps = set([name, draft_name]) - self._on_ack_callbacks = [] - def available(self, capabilities: typing.Iterable[str]) -> str: + self._on_ack_callbacks = [ + ] # type: typing.List[typing.Callable[[], None]] + def available(self, capabilities: typing.Iterable[str] + ) -> typing.Optional[str]: match = list(set(capabilities)&self._caps) return match[0] if match else None @@ -295,11 +292,14 @@ class Capability(object): pass class MessageTag(object): - def __init__(self, name: str, draft_name: str=None): + def __init__(self, name: typing.Optional[str], draft_name: str=None): self._names = set([name, draft_name]) def get_value(self, tags: typing.Dict[str, str]) -> typing.Optional[str]: key = list(set(tags.keys())&self._names) return tags[key[0]] if key else None + def match(self, s: str) -> typing.Optional[str]: + key = list(set([s])&self._names) + return key[0] if key else None def hostmask_match(hostmask: str, pattern: str) -> bool: return fnmatch.fnmatchcase(hostmask, pattern) |
