diff options
Diffstat (limited to 'src')
27 files changed, 275 insertions, 170 deletions
diff --git a/src/Exports.py b/src/Exports.py index a9fe3175..64c10d69 100644 --- a/src/Exports.py +++ b/src/Exports.py @@ -35,7 +35,7 @@ class Exports(object): return self._exports.get(setting, []) + sum([ exports.get(setting, []) for exports in self._context_exports.values()], []) - def get_one(self, setting: str, default: typing.Any=None + def get(self, setting: str, default: typing.Any=None ) -> typing.Optional[typing.Any]: values = self.get_all(setting) return values[0] if values else default @@ -60,8 +60,8 @@ class ExportsContext(object): self._parent._context_add(self.context, setting, value) def get_all(self, setting: str) -> typing.List[typing.Any]: return self._parent.get_all(setting) - def get_one(self, setting: str, default: typing.Any=None + def get(self, setting: str, default: typing.Any=None ) -> typing.Optional[typing.Any]: - return self._parent.get_one(setting, default) + return self._parent.get(setting, default) def find(self, setting_prefix: str) -> typing.List[typing.Any]: return self._parent.find(setting_prefix) diff --git a/src/IRCBot.py b/src/IRCBot.py index f5b66b23..b2c2861f 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -204,8 +204,12 @@ class Bot(object): try: server.connect() except Exception as e: - self.log.warn("Failed to connect to %s: %s", - [str(server), str(e)]) + ip = "" + if server.socket.connected_ip is not None: + ip = f" ({server.socket.connected_ip})" + + self.log.warn("Failed to connect to %s%s: %s", + [str(server), ip, str(e)]) self.log.debug("Connection failure reason:", exc_info=True) return False self.servers[server.fileno()] = server diff --git a/src/IRCChannel.py b/src/IRCChannel.py index 028a51b7..4e8aed9d 100644 --- a/src/IRCChannel.py +++ b/src/IRCChannel.py @@ -210,6 +210,8 @@ class Channel(IRCObject.Object): def send_message(self, text: str, tags: dict={}): return self.server.send_message(self.name, text, tags=tags) + def send_action(self, text: str, tags: dict={}): + return self.server.send_action(self.name, text, tags=tags) def send_notice(self, text: str, tags: dict={}): return self.server.send_notice(self.name, text, tags=tags) def send_tagmsg(self, tags: dict): @@ -260,7 +262,10 @@ class Channel(IRCObject.Object): return True return False - def has_mode(self, user: IRCUser.User, mode: str) -> bool: + def has_mode(self, mode: str) -> bool: + return mode in self.modes + + def has_umode(self, user: IRCUser.User, mode: str) -> bool: return user in self.modes.get(mode, []) def get_user_modes(self, user: IRCUser.User) -> typing.Set: diff --git a/src/IRCLine.py b/src/IRCLine.py index 62d9e8b7..1dda449b 100644 --- a/src/IRCLine.py +++ b/src/IRCLine.py @@ -1,8 +1,7 @@ -import datetime, typing, uuid +import codecs, datetime, typing, uuid from src import EventManager, IRCObject, utils -# this should be 510 (RFC1459, 512 with \r\n) but a server BitBot uses is broken -LINE_MAX = 470 +LINE_MAX = 510 class IRCArgs(object): def __init__(self, args: typing.List[str]): @@ -125,42 +124,43 @@ class ParsedLine(object): return tags, " ".join(pieces).replace("\r", "") def format(self) -> str: tags, line = self._format() - line, _ = self._newline_truncate(line) if tags: return "%s %s" % (tags, line) else: return line - def _newline_truncate(self, line: str) -> typing.Tuple[str, str]: - line, sep, overflow = line.partition("\n") - return (line, overflow) - def _line_max(self, hostmask: str, margin: int) -> int: - return LINE_MAX-len((":%s " % hostmask).encode("utf8"))-margin - def truncate(self, hostmask: str, margin: int=0) -> typing.Tuple[str, str]: - valid_bytes = b"" - valid_index = -1 - - line_max = self._line_max(hostmask, margin) +class SendableLine(ParsedLine): + def __init__(self, command: str, args: typing.List[str], + margin: int=0, tags: typing.Dict[str, str]=None): + ParsedLine.__init__(self, command, args, None, tags) + self._margin = margin - tags_formatted, line_formatted = self._format() - for i, char in enumerate(line_formatted): - encoded_char = char.encode("utf8") - if (len(valid_bytes)+len(encoded_char) > line_max - or encoded_char == b"\n"): - break - else: - valid_bytes += encoded_char - valid_index = i - valid_index += 1 + def push_last(self, arg: str, extra_margin: int=0, + human_trunc: bool=False) -> typing.Optional[str]: + last_arg = self.args[-1] + tags, line = self._format() + n = len(line.encode("utf8")) # get length of current line + n += self._margin # margin used for :hostmask + if " " in arg and not " " in last_arg: + n += 1 # +1 for colon on new arg + n += extra_margin # used for things like (more ...) - valid = line_formatted[:valid_index] - if tags_formatted: - valid = "%s %s" % (tags_formatted, valid) - overflow = line_formatted[valid_index:] - if overflow and overflow[0] == "\n": - overflow = overflow[1:] + overflow: typing.Optional[str] = None - return valid, overflow + if (n+len(arg.encode("utf8"))) > LINE_MAX: + for i, char in enumerate(codecs.iterencode(arg, "utf8")): + n += len(char) + if n > LINE_MAX: + arg, overflow = arg[:i], arg[i:] + if human_trunc and not overflow[0] == " ": + new_arg, sep, new_overflow = arg.rpartition(" ") + if sep: + arg = new_arg + overflow = new_overflow+overflow + break + if arg: + self.args[-1] = last_arg+arg + return overflow def parse_line(line: str) -> ParsedLine: tags = {} # type: typing.Dict[str, typing.Any] @@ -220,7 +220,7 @@ class SentLine(IRCObject.Object): return self._for_wire() def _for_wire(self) -> str: - return self.parsed_line.truncate(self._hostmask)[0] + return str(self.parsed_line) def for_wire(self) -> bytes: return b"%s\r\n" % self._for_wire().encode("utf8") diff --git a/src/IRCServer.py b/src/IRCServer.py index 1454fedb..bb22adab 100644 --- a/src/IRCServer.py +++ b/src/IRCServer.py @@ -80,6 +80,11 @@ class Server(IRCObject.Object): def hostmask(self): return "%s!%s@%s" % (self.nickname, self.username, self.hostname) + def new_line(self, command: str, args: typing.List[str]=None, + tags: typing.Dict[str, str]=None) -> IRCLine.SendableLine: + return IRCLine.SendableLine(command, args or [], + len((":%s " % self.hostmask()).encode("utf8")), tags) + def connect(self): self.socket = IRCSocket.Socket( self.bot.log, @@ -385,6 +390,10 @@ class Server(IRCObject.Object): def send_message(self, target: str, message: str, tags: dict={} ) -> typing.Optional[IRCLine.SentLine]: return self.send(self._line("PRIVMSG", [target, message], tags=tags)) + def send_action(self, target: str, message: str, tags: dict={} + ) -> typing.Optional[IRCLine.SentLine]: + return self.send(self._line("PRIVMSG", + [target, f"\x01ACTION {message}\x01"], tags=tags)) def send_notice(self, target: str, message: str, tags: dict={} ) -> typing.Optional[IRCLine.SentLine]: diff --git a/src/IRCSocket.py b/src/IRCSocket.py index 7fbae39d..507c6471 100644 --- a/src/IRCSocket.py +++ b/src/IRCSocket.py @@ -69,11 +69,12 @@ class Socket(IRCObject.Object): 5.0) self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + self.connected_ip = self._socket.getpeername()[0] + if self._tls: self._tls_wrap() self.connect_time = time.time() - self.connected_ip = self._socket.getpeername()[0] self.cached_fileno = self._socket.fileno() self.connected = True diff --git a/src/IRCUser.py b/src/IRCUser.py index 2e141794..6db124f0 100644 --- a/src/IRCUser.py +++ b/src/IRCUser.py @@ -77,6 +77,8 @@ class User(IRCObject.Object): def send_message(self, message: str, tags: dict={}): self.server.send_message(self.nickname, message, tags=tags) + def send_action(self, message: str, tags: dict={}): + self.server.send_action(self.nickname, message, tags=tags) def send_notice(self, text: str, tags: dict={}): self.server.send_notice(self.nickname, text, tags=tags) def send_ctcp_response(self, command: str, args: str): diff --git a/src/core_modules/admin.py b/src/core_modules/admin.py index f87927b1..5008952d 100644 --- a/src/core_modules/admin.py +++ b/src/core_modules/admin.py @@ -32,9 +32,17 @@ class Module(ModuleManager.BaseModule): @utils.kwarg("permission", "part") @utils.kwarg("require_mode", "high") @utils.kwarg("require_access", "high,part") - @utils.spec("!r~channel") + @utils.spec("!-privateonly !<channel>word") + @utils.spec("!-channelonly ?<channel>word") def part(self, event): - event["server"].send_part(event["spec"][0].name) + event["server"].send_part(event["spec"][0] or event["target"].name) + + @utils.hook("received.command.join") + @utils.kwarg("help", "Join a given channel") + @utils.kwarg("permission", "join") + @utils.spec("!<channel>word") + def join(self, event): + event["server"].send_join(event["spec"][0]) def _id_from_alias(self, alias): return self.bot.database.servers.get_by_alias(alias) diff --git a/src/core_modules/aliases.py b/src/core_modules/aliases.py index b6ca8961..b5ed86f1 100644 --- a/src/core_modules/aliases.py +++ b/src/core_modules/aliases.py @@ -3,6 +3,8 @@ from src import EventManager, ModuleManager, utils SETTING_PREFIX = "command-alias-" +class VariableKeyError(KeyError): + pass class Module(ModuleManager.BaseModule): def _arg_replace(self, s, args_split, kwargs): vars = {} @@ -11,7 +13,11 @@ class Module(ModuleManager.BaseModule): vars["%d-" % i] = " ".join(args_split[i:]) vars["-"] = " ".join(args_split) vars.update(kwargs) - return utils.parse.format_token_replace(s, vars) + + not_found, new_s = utils.parse.format_token_replace(s, vars) + if not_found: + raise VariableKeyError(f"not found: {not_found!r}") + return new_s def _get_alias(self, server, target, command): setting = "%s%s" % (SETTING_PREFIX, command) @@ -45,9 +51,14 @@ class Module(ModuleManager.BaseModule): if event["command"].args: given_args = event["command"].args.split(" ") - event["command"].command = alias - event["command"].args = self._arg_replace(alias_args, given_args, - event["kwargs"]) + try: + event["command"].args = self._arg_replace(alias_args, + given_args, event["kwargs"]) + except VariableKeyError: + pass + else: + event["command"].command = alias + @utils.hook("received.command.alias", permission="alias") diff --git a/src/core_modules/banmask.py b/src/core_modules/banmask.py new file mode 100644 index 00000000..bc33d4a9 --- /dev/null +++ b/src/core_modules/banmask.py @@ -0,0 +1,24 @@ +from src import ModuleManager, utils + +SETTING = utils.Setting("ban-format", + "Set ban format " + "(${n} = nick, ${u} = username, ${h} = hostname, ${a} = account", + example="*!${u}@${h}") + +@utils.export("channelset", SETTING) +@utils.export("serverset", SETTING) +class Module(ModuleManager.BaseModule): + def _format_hostmask(self, user, s): + vars = {} + vars["n"] = vars["nickname"] = user.nickname + vars["u"] = vars["username"] = user.username + vars["h"] = vars["hostname"] = user.hostname + vars["a"] = vars["account"] = user.account or "" + missing, out = utils.parse.format_token_replace(s, vars) + return out + @utils.export("ban-mask") + def banmask(self, server, channel, user): + format = channel.get_setting("ban-format", + server.get_setting("ban-format", "*!${u}@${h}")) + return self._format_hostmask(user, format) + diff --git a/src/core_modules/channel_access.py b/src/core_modules/channel_access.py index 662599e0..a8c371a8 100644 --- a/src/core_modules/channel_access.py +++ b/src/core_modules/channel_access.py @@ -21,7 +21,7 @@ class Module(ModuleManager.BaseModule): user_access = target.get_user_setting(user.get_id(), "access", []) - identified = self.exports.get_one("is-identified")(user) + identified = self.exports.get("is-identified")(user) matched = list(set(required_access)&set(user_access)) return ("*" in user_access or matched) and identified diff --git a/src/core_modules/command_spec/__init__.py b/src/core_modules/command_spec/__init__.py index 34fd5aef..6345d342 100644 --- a/src/core_modules/command_spec/__init__.py +++ b/src/core_modules/command_spec/__init__.py @@ -48,7 +48,7 @@ class Module(ModuleManager.BaseModule): if argument_type.type in types.TYPES: func = types.TYPES[argument_type.type] else: - func = self.exports.get_one( + func = self.exports.get( "command-spec.%s" % argument_type.type) if func: @@ -142,7 +142,7 @@ class Module(ModuleManager.BaseModule): usages = [ utils.parse.argument_spec_human(s, context) for s in specs] command = "%s%s" % (event["command_prefix"], event["command"]) - usages = ["%s%s" % (command, u) for u in usages] + usages = ["%s %s" % (command, u) for u in usages] error_out = "%s (Usage: %s)" % (overall_error, " | ".join(usages)) diff --git a/src/core_modules/commands/__init__.py b/src/core_modules/commands/__init__.py index a5c5faa6..3a0e74cd 100644 --- a/src/core_modules/commands/__init__.py +++ b/src/core_modules/commands/__init__.py @@ -8,8 +8,8 @@ COMMAND_METHOD = "command-method" COMMAND_METHODS = ["PRIVMSG", "NOTICE"] STR_MORE = " (more...)" +STR_CONTINUED = "(...continued) " STR_MORE_LEN = len(STR_MORE.encode("utf8")) -STR_CONTINUED = "(...continued)" WORD_BOUNDARIES = [" "] NON_ALPHANUMERIC = [char for char in string.printable if not char.isalnum()] @@ -73,12 +73,12 @@ class Module(ModuleManager.BaseModule): self.bot.get_setting(COMMAND_METHOD, default))).upper() def _find_command_hook(self, server, target, is_channel, command, user, - args): + command_prefix, args): if not self.has_command(command): command_event = CommandEvent(command, args) self.events.on("get.command").call(command=command_event, server=server, target=target, is_channel=is_channel, user=user, - kwargs={}) + command_prefix=command_prefix, kwargs={}) command = command_event.command args = command_event.args @@ -237,29 +237,25 @@ class Module(ModuleManager.BaseModule): color = utils.consts.RED line_str = obj.pop() + prefix = "" if obj.prefix: - line_str = "[%s] %s" % ( - utils.irc.color(obj.prefix, color), line_str) + prefix = "[%s] " % utils.irc.color(obj.prefix, color) + if obj._overflowed: + prefix = "%s%s" % (prefix, STR_CONTINUED) method = self._command_method(server, target, is_channel) if not method in ["PRIVMSG", "NOTICE"]: raise ValueError("Unknown command-method '%s'" % method) - line = IRCLine.ParsedLine(method, [target_str, line_str], - tags=tags) - valid, trunc = line.truncate(server.hostmask(), - margin=STR_MORE_LEN) + line = server.new_line(method, [target_str, prefix], tags=tags) + + overflow = line.push_last(line_str, human_trunc=True, + extra_margin=STR_MORE_LEN) + if overflow: + line.push_last(STR_MORE) + obj.insert(overflow) + obj._overflowed = True - if trunc: - if not trunc[0] in WORD_BOUNDARIES: - for boundary in WORD_BOUNDARIES: - left, *right = valid.rsplit(boundary, 1) - if right: - valid = left - trunc = right[0]+trunc - obj.insert("%s %s" % (STR_CONTINUED, trunc)) - valid = valid+STR_MORE - line = IRCLine.parse_line(valid) if obj._assured: line.assure() server.send(line) @@ -309,7 +305,7 @@ class Module(ModuleManager.BaseModule): try: hook, command, args_split = self._find_command_hook( event["server"], event["channel"], True, command, - event["user"], args) + event["user"], command_prefix, args) except BadContextException: event["channel"].send_message( "%s: That command is not valid in a channel" % @@ -371,7 +367,7 @@ class Module(ModuleManager.BaseModule): try: hook, command, args_split = self._find_command_hook( event["server"], event["user"], False, command, - event["user"], args) + event["user"], "", args) except BadContextException: event["user"].send_message( "That command is not valid in a PM") diff --git a/src/core_modules/commands/outs.py b/src/core_modules/commands/outs.py index e82ceefd..c6b489ae 100644 --- a/src/core_modules/commands/outs.py +++ b/src/core_modules/commands/outs.py @@ -6,6 +6,13 @@ class StdOut(object): self.prefix = prefix self._lines = [] self._assured = False + self._overflowed = False + + def copy_from(self, other): + self.prefix = other.prefix + self._lines = other._lines + self._assured = other._assured + self._overflowed = other._overflowed def assure(self): self._assured = True diff --git a/src/core_modules/line_handler/__init__.py b/src/core_modules/line_handler/__init__.py index a77dc451..f23a98ed 100644 --- a/src/core_modules/line_handler/__init__.py +++ b/src/core_modules/line_handler/__init__.py @@ -205,6 +205,10 @@ class Module(ModuleManager.BaseModule): @utils.hook("raw.received.chghost") def chghost(self, event): user.chghost(self.events, event) + # RPL_VISIBLEHOST, telling us what our hostname (and sometimes username) is + @utils.hook("raw.received.396") + def handle_396(self, event): + core.handle_396(event) # IRCv3 SETNAME, to change a user's realname @utils.hook("raw.received.setname") diff --git a/src/core_modules/line_handler/core.py b/src/core_modules/line_handler/core.py index a7075613..aa862a03 100644 --- a/src/core_modules/line_handler/core.py +++ b/src/core_modules/line_handler/core.py @@ -165,3 +165,9 @@ def handle_433(event): _nick_in_use(event["server"]) def handle_437(event): _nick_in_use(event["server"]) + +def handle_396(event): + username, sep, hostname = event["line"].args[1].rpartition("@") + event["server"].hostname = hostname + if sep: + event["server"].username = username diff --git a/src/core_modules/line_handler/message.py b/src/core_modules/line_handler/message.py index 87b91c9e..b3e64be1 100644 --- a/src/core_modules/line_handler/message.py +++ b/src/core_modules/line_handler/message.py @@ -64,7 +64,7 @@ def message(events, event): action = False - if message: + if not message == None: ctcp_message = utils.irc.parse_ctcp(message) if ctcp_message: @@ -97,7 +97,7 @@ def message(events, event): hook = events.on(direction).on(event_type).on(context) buffer_line = None - if message: + if not message == None: buffer_line = IRCBuffer.BufferLine(user.nickname, message, action, event["line"].tags, from_self, event["line"].command) diff --git a/src/core_modules/mode_lists.py b/src/core_modules/mode_lists.py index f560ffa0..37918608 100644 --- a/src/core_modules/mode_lists.py +++ b/src/core_modules/mode_lists.py @@ -24,7 +24,7 @@ class Module(ModuleManager.BaseModule): @utils.hook("received.348") def on_348(self, event): mode = self._excepts(event["server"]) - self._mode_list_mask(event, mode, event["line"].args[3]) + self._mode_list_mask(event, mode, event["line"].args[2]) @utils.hook("received.349") def on_349(self, event): self._mode_list_end(event, self._excepts(event["server"])) @@ -35,7 +35,7 @@ class Module(ModuleManager.BaseModule): @utils.hook("received.346") def on_346(self, event): mode = self._invex(event["server"]) - self._mode_list_mask(event, mode, event["line"].args[3]) + self._mode_list_mask(event, mode, event["line"].args[2]) @utils.hook("received.347") def on_347(self, event): self._mode_list_end(event, self._invex(event["server"])) diff --git a/src/core_modules/more.py b/src/core_modules/more.py index 52849938..bc76fb6b 100644 --- a/src/core_modules/more.py +++ b/src/core_modules/more.py @@ -20,4 +20,4 @@ class Module(ModuleManager.BaseModule): def more(self, event): last_stdout = event["target"]._last_stdout if last_stdout and last_stdout.has_text(): - event["stdout"].write_lines(last_stdout.get_all()) + event["stdout"].copy_from(last_stdout) diff --git a/src/core_modules/perform.py b/src/core_modules/perform.py index 832cab54..06f8cd89 100644 --- a/src/core_modules/perform.py +++ b/src/core_modules/perform.py @@ -24,53 +24,39 @@ class Module(ModuleManager.BaseModule): self._execute(event["server"], commands, NICK=event["server"].nickname, CHAN=event["channel"].name) - def _perform(self, target, args_split): - subcommand = args_split[0].lower() + def _perform(self, target, spec): + subcommand = spec[0] current_perform = target.get_setting("perform", []) if subcommand == "list": return "Configured commands: %s" % ", ".join(current_perform) message = None if subcommand == "add": - if not len(args_split) > 1: - raise utils.EventError("Please provide a raw command to add") - current_perform.append(" ".join(args_split[1:])) + current_perform.append(spec[1]) message = "Added command" elif subcommand == "remove": - if not len(args_split) > 1: - raise utils.EventError("Please provide an index to remove") - if not args_split[1].isdigit(): - raise utils.EventError("Please provide a number") - index = int(args_split[1]) + index = spec[1] if not index < len(current_perform): raise utils.EventError("Index out of bounds") - current_perform.pop(index) - message = "Removed command" - else: - raise utils.EventError("Unknown subcommand '%s'" % subcommand) + command = current_perform.pop(index) + message = "Removed command %d (%s)" % (index, command) target.set_setting("perform", current_perform) return message - @utils.hook("received.command.perform", min_args=1) - @utils.kwarg("min_args", 1) + @utils.hook("received.command.perform", permission="perform", + help="Edit on-connect command configuration") + @utils.hook("received.command.cperform", permission="perform", + help="Edit channel on-join command configuration", channel_only=True) @utils.kwarg("help", "Edit on-connect command configuration") - @utils.kwarg("usage", "list") - @utils.kwarg("usage", "add <raw command>") - @utils.kwarg("usage", "remove <index>") - @utils.kwarg("permission", "perform") + @utils.spec("!'list") + @utils.spec("!'add !<command>string") + @utils.spec("!'remove !<index>int") def perform(self, event): - event["stdout"].write(self._perform(event["server"], - event["args_split"])) + if event["command"] == "perform": + target = event["server"] + elif event["command"] == "cperform": + target = event["target"] - @utils.hook("received.command.cperform", min_args=1) - @utils.kwarg("min_args", 1) - @utils.kwarg("channel_only", True) - @utils.kwarg("help", "Edit channel on-join command configuration") - @utils.kwarg("usage", "list") - @utils.kwarg("usage", "add <raw command>") - @utils.kwarg("usage", "remove <index>") - @utils.kwarg("permission", "cperform") - def cperform(self, event): - event["stdout"].write(self._perform(event["target"], - event["args_split"])) + out = self._perform(target, event["spec"]) + event["stdout"].write("%s: %s" % (event["user"].nickname, out)) diff --git a/src/utils/datetime/format.py b/src/utils/datetime/format.py index 889f5906..4b56e7b3 100644 --- a/src/utils/datetime/format.py +++ b/src/utils/datetime/format.py @@ -89,7 +89,7 @@ def to_pretty_time(total_seconds: int, max_units: int=UNIT_MINIMUM, if hours and len(out) < max_units: out.append("%dh" % hours) if minutes and len(out) < max_units: - out.append("%dmi" % minutes) + out.append("%dm" % minutes) if seconds and len(out) < max_units: out.append("%ds" % seconds) diff --git a/src/utils/http.py b/src/utils/http.py index 9f25b315..045d9641 100644 --- a/src/utils/http.py +++ b/src/utils/http.py @@ -3,6 +3,7 @@ import typing, urllib.error, urllib.parse, uuid import json as _json import bs4, netifaces, requests, tornado.httpclient from src import IRCBot, utils +from requests_toolbelt.adapters import source REGEX_URL = re.compile("https?://\S+", re.I) @@ -69,6 +70,7 @@ class Request(object): json_body: bool = False allow_redirects: bool = True + check_hostname: bool = False check_content_type: bool = True fallback_encoding: typing.Optional[str] = None content_type: typing.Optional[str] = None @@ -77,6 +79,8 @@ class Request(object): timeout: int=5 + bindhost: typing.Optional[str] = None + def validate(self): self.id = self.id or str(uuid.uuid4()) self.set_url(self.url) @@ -169,24 +173,61 @@ def request(request_obj: typing.Union[str, Request], **kwargs) -> Response: request_obj = Request(request_obj, **kwargs) return _request(request_obj) +class HostNameInvalidError(ValueError): + pass +class TooManyRedirectionsError(Exception): + pass + def _request(request_obj: Request) -> Response: request_obj.validate() + + def _assert_allowed(url: str): + hostname = urllib.parse.urlparse(url).hostname + if hostname is None or not host_permitted(hostname): + raise HostNameInvalidError( + f"hostname {hostname} is not permitted") + def _wrap() -> Response: headers = request_obj.get_headers() - response = requests.request( - request_obj.method, - request_obj.url, - headers=headers, - params=request_obj.get_params, - data=request_obj.get_body(), - allow_redirects=request_obj.allow_redirects, - stream=True, - cookies=request_obj.cookies - ) - response_content = response.raw.read(RESPONSE_MAX, - decode_content=True) - if not response.raw.read(1) == b"": - raise ValueError("Response too large") + + redirect = 0 + current_url = request_obj.url + session = requests.Session() + if not request_obj.bindhost is None: + new_source = source.SourceAddressAdapter(request_obj.bindhost) + session.mount('http://', new_source) + session.mount('https://', new_source) + + while True: + if request_obj.check_hostname: + _assert_allowed(current_url) + + response = session.request( + request_obj.method, + current_url, + headers=headers, + params=request_obj.get_params, + data=request_obj.get_body(), + allow_redirects=False, + stream=True, + cookies=request_obj.cookies + ) + + if response.status_code in [301, 302]: + redirect += 1 + if redirect == 5: + raise TooManyRedirectionsError(f"{redirect} redirects") + else: + current_url = response.headers["location"] + continue + + response_content = response.raw.read(RESPONSE_MAX, + decode_content=True) + if not response.raw.read(1) == b"": + raise ValueError("Response too large") + break + + session.close() headers = utils.CaseInsensitiveDict(dict(response.headers)) our_response = Response(response.status_code, response_content, diff --git a/src/utils/irc.py b/src/utils/irc.py index 400d5352..f93f61ed 100644 --- a/src/utils/irc.py +++ b/src/utils/irc.py @@ -75,48 +75,30 @@ FORMAT_TOKENS = [ FORMAT_STRIP = [ "\x08" # backspace ] -def _format_tokens(s: str) -> typing.List[str]: - is_color = False - foreground: typing.List[str] = [] - background: typing.List[str] = [] - is_background = False - matches = [] # type: typing.List[str] - - for i, char in enumerate(s): - last_char = i == len(s)-1 - if is_color: - current_color = background if is_background else foreground - color_finished = True - - if char == "," and not is_background: - is_background = True - color_finished = False - elif char.isdigit() and len(current_color) < 2: - current_color.append(char) - color_finished = len(current_color) == 2 and is_background - - if color_finished or last_char: - color = "".join(foreground) - if background: - color += "".join([","]+background) +def _format_tokens(s: str) -> typing.List[str]: + tokens: typing.List[str] = [] - matches.append("\x03%s" % color) - is_color = False - foreground = [] - background = [] - is_background = False + s_copy = list(s) + while s_copy: + token = s_copy.pop(0) + if token == "\x03": + for i in range(2): + if s_copy and s_copy[0].isdigit(): + token += s_copy.pop(0) + if (len(s_copy) > 1 and + s_copy[0] == "," and + s_copy[1].isdigit()): + token += s_copy.pop(0) + token += s_copy.pop(0) + if s_copy and s_copy[0].isdigit(): + token += s_copy.pop(0) - if char == consts.COLOR: - if is_color: - matches.append(char) - else: - is_color = True - elif char in FORMAT_TOKENS: - matches.append(char) - elif char in FORMAT_STRIP: - matches.append(char) - return matches + tokens.append(token) + elif (token in FORMAT_TOKENS or + token in FORMAT_STRIP): + tokens.append(token) + return tokens def _color_match(code: typing.Optional[str], foreground: bool) -> str: if not code: diff --git a/src/utils/parse/__init__.py b/src/utils/parse/__init__.py index 262edf4a..36da8ffb 100644 --- a/src/utils/parse/__init__.py +++ b/src/utils/parse/__init__.py @@ -131,7 +131,7 @@ def format_tokens(s: str, sigil: str="$" return tokens def format_token_replace(s: str, vars: typing.Dict[str, str], - sigil: str="$") -> str: + sigil: str="$") -> typing.Tuple[typing.List[str], str]: vars = vars.copy() vars.update({sigil: sigil}) @@ -140,7 +140,10 @@ def format_token_replace(s: str, vars: typing.Dict[str, str], tokens.sort(key=lambda x: x[0]) tokens.reverse() + not_found: typing.List[str] = [] for start, end, token in tokens: if token in vars: s = s[:start] + vars[token] + s[end+1:] - return s + else: + not_found += token + return not_found, s diff --git a/src/utils/parse/sed.py b/src/utils/parse/sed.py index 76b9e567..8a1c895d 100644 --- a/src/utils/parse/sed.py +++ b/src/utils/parse/sed.py @@ -76,6 +76,6 @@ def parse(sed_s: str) -> typing.Optional[Sed]: return SedMatch(type, re.compile(pattern, flags)) return None -def sed(sed_obj: Sed, s: str) -> typing.Tuple[str, typing.Optional[str]]: +def sed(sed_obj: Sed, s: str) -> typing.Optional[str]: out = sed_obj.match(s) - return sed_obj.type, out + return out diff --git a/src/utils/parse/spec.py b/src/utils/parse/spec.py index 7d7ffd3a..2a682c3a 100644 --- a/src/utils/parse/spec.py +++ b/src/utils/parse/spec.py @@ -15,6 +15,7 @@ class SpecArgumentContext(enum.IntFlag): class SpecArgumentType(object): context = SpecArgumentContext.ALL + _modifier: typing.Optional[str] def __init__(self, type_name: str, name: typing.Optional[str], modifier: typing.Optional[str], exported: typing.Optional[str]): @@ -24,7 +25,7 @@ class SpecArgumentType(object): self.exported = exported def _set_modifier(self, modifier: typing.Optional[str]): - pass + self._modifier = modifier def name(self) -> typing.Optional[str]: return self._name @@ -102,6 +103,18 @@ class SpecArgumentTypeDate(SpecArgumentType): return date_human(args[0]), 1 return None, 1 +class SpecArgumentTypeFlag(SpecArgumentType): + def _str(self): + pref = "-" if len(self._modifier) == 1 else "--" + return f"{pref}{self._modifier}" + def name(self): + return self._str() + def simple(self, args): + print("flag _str()", self._str()) + if args and args[0] == self._str(): + return True, 1 + return None, 1 + class SpecArgumentPrivateType(SpecArgumentType): context = SpecArgumentContext.PRIVATE @@ -115,7 +128,8 @@ SPEC_ARGUMENT_TYPES = { "int": SpecArgumentTypeInt, "date": SpecArgumentTypeDate, "duration": SpecArgumentTypeDuration, - "pattern": SpecArgumentTypePattern + "pattern": SpecArgumentTypePattern, + "flag": SpecArgumentTypeFlag } class SpecArgument(object): diff --git a/src/utils/settings.py b/src/utils/settings.py index 97fe885b..53ee0ebd 100644 --- a/src/utils/settings.py +++ b/src/utils/settings.py @@ -49,15 +49,17 @@ class IntSetting(Setting): class IntRangeSetting(IntSetting): example: typing.Optional[str] = None - def __init__(self, n_min: int, n_max: int, name: str, help: str=None, - example: str=None): + def __init__(self, n_min: int, n_max: typing.Optional[int], name: str, + help: str=None, example: str=None): self._n_min = n_min self._n_max = n_max Setting.__init__(self, name, help, example) def parse(self, value: str) -> typing.Any: out = IntSetting.parse(self, value) - if not out == None and self._n_min <= out <= self._n_max: + if (not out == None and + self._n_min <= out and + (self._n_max == None or out <= self._n_max)): return out return None |
