diff options
62 files changed, 661 insertions, 316 deletions
diff --git a/.travis.yml b/.travis.yml index ec70dc6b..f5413483 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,7 +5,6 @@ python: - "3.7" - "3.8" - "3.8-dev" - - "nightly" install: - pip3 install mypy -r requirements.txt script: diff --git a/CHANGELOG.md b/CHANGELOG.md index ea835bea..b94f3784 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,31 @@ +# 2020-08-26 - BitBot v1.20.0 + +Added: +- `ban-enforce.py` to kick people that match a new ban mask +- MLOCK-like functionality in `channel_op.py` +- Channels can be opted out of inactivity pruning +- Optional to disable youtubeifying `!np` output +- Allow RSS requests to bind to specific source addresses +- `!words` in PM +- `yourls.py` for yourl url shortners +- handle `RPL_VISIBLEHOST` + +Changed: +- `ERR_LINKCHANNEL` (470) now removes the initial channel from autojoin +- (IRCv3) `+draft/typing` was ratified +- We're no longer supporting a specific broken charybdis install, for line length calculations +- Much better line truncation +- Handle HTTP redirects ourselves, to avoid redirects on to forbidden hosts + +Fixed: +- All bot.conf paths should have ~/{DATA} expanded +- `host-meta` URL for fediverse accounts should be optional - fallback to default webfinger +- Message filter `m/` criterias should operate on formatting-stripped lines +- Quote `!grab`s were wiping the user's category +- Quote `!quotedel`s were looking at the wrong categories +- `!words` was squashing results in to a dict, losing days +- INVEX and EXCEPT lists were looking at the wrong index for masks + # 2020-02-29 - BitBot v1.19.0 ("Command Specs Spark Joy") Added: @@ -1 +1 @@ -1.19.0 +1.20.0 @@ -49,11 +49,15 @@ if args.version: config = Config.Config("bot", args.config) config.load() -DATA_DIR = os.path.expanduser(config.get("data-directory", "~/.bitbot")) -LOG_DIR = config.get("log-directory", "{DATA}/logs/").format(DATA=DATA_DIR) -DATABASE = config.get("database", "sqlite3:{DATA}/bot.db").format(DATA=DATA_DIR) -LOCK_FILE = config.get("lock-file", "{DATA}/bot.lock").format(DATA=DATA_DIR) -SOCK_FILE = config.get("sock-file", "{DATA}/bot.sock").format(DATA=DATA_DIR) +DATA_DIR = "" +def _expand(s: str): + return os.path.expanduser(s).format(DATA=DATA_DIR) + +DATA_DIR = _expand(config.get("data-directory", "~/.bitbot")) +LOG_DIR = _expand(config.get("log-directory", "{DATA}/logs/")) +DATABASE = _expand(config.get("database", "sqlite3:{DATA}/bot.db")) +LOCK_FILE = _expand(config.get("lock-file", "{DATA}/bot.lock")) +SOCK_FILE = _expand(config.get("sock-file", "{DATA}/bot.sock")) if not os.path.isdir(LOG_DIR): os.mkdir(LOG_DIR) @@ -111,7 +115,8 @@ extra_modules = [os.path.join(directory, "modules")] if args.external: extra_modules.append(os.path.abspath(args.external)) if "external-modules" in config: - extra_modules.append(os.path.abspath(config["external-modules"])) + conf_extra = os.path.abspath(config["external-modules"]) + extra_modules.append(_expand(conf_extra)) modules = ModuleManager.ModuleManager(events, exports, timers, config, log, core_modules, extra_modules) diff --git a/docs/help/config.md b/docs/help/config.md index e11d14ec..4db2b78a 100644 --- a/docs/help/config.md +++ b/docs/help/config.md @@ -2,8 +2,8 @@ * Move `docs/bot.conf.example` to `~/.bitbot/bot.conf` and fill in the config options you care about. Ones blank or removed will disable relevant functionality. * Run `./bitbotd -a` to add a server. -* Run `./bitbotctl command master-password` to get the master admin password (needed to add regular admin accounts) * Run `./bitbotd` to start the bot. +* Run `./bitbotctl command master-password` to get the master admin password (needed to add regular admin accounts) * Join `#bitbot` on a server with the bot (or invite it to another channel) * `/msg <bot> register <password here>` to register your nickname with the bot * (use `/msg <bot> identify <password>` to log in in the future) diff --git a/modules/alias_variables.py b/modules/alias_variables.py index 6f2499a3..b618a7c6 100644 --- a/modules/alias_variables.py +++ b/modules/alias_variables.py @@ -5,7 +5,11 @@ class Module(ModuleManager.BaseModule): @utils.hook("get.command") @utils.kwarg("priority", EventManager.PRIORITY_HIGH) def get_command(self, event): + event["kwargs"]["CTRIGGER"] = event["command_prefix"] + + event["kwargs"]["BNICK"] = event["server"].nickname event["kwargs"]["NICK"] = event["user"].nickname + if event["is_channel"]: event["kwargs"]["CHAN"] = event["target"].name random_user = random.choice(list(event["target"].users)) diff --git a/modules/channel_save.py b/modules/autojoin.py index d5b69638..bbd08dfb 100644 --- a/modules/channel_save.py +++ b/modules/autojoin.py @@ -25,6 +25,9 @@ class Module(ModuleManager.BaseModule): if channel_name in channels: channels.remove(channel_name) server.set_setting("autojoin", channels) + return True + else: + return False @utils.hook("self.part") def on_part(self, event): @@ -33,3 +36,14 @@ class Module(ModuleManager.BaseModule): @utils.hook("self.kick") def on_kick(self, event): self._remove_channel(event["server"], event["channel"].name) + + @utils.hook("raw.received.470") + def on_linkchannel(self, event): + initial = event["line"].args[1] + initial_lower = event["server"].irc_lower(initial) + linked = event["line"].args[2] + + if self._remove_channel(event["server"], initial_lower): + self.log.warn(f"{str(event['server'])} " + f"channel {initial} linked to {linked} " + "- removed from autojoin") diff --git a/modules/badwords.py b/modules/badwords.py index d23a4137..db5ed24c 100644 --- a/modules/badwords.py +++ b/modules/badwords.py @@ -80,9 +80,8 @@ class Module(ModuleManager.BaseModule): ban = True if ban: - event["channel"].send_ban("*!%s@%s" % ( - event["user"].username, - event["user"].realname)) + event["channel"].send_ban(self.exports.get("ban-mask")( + event["server"], event["channel"], event["user"])) if kick: event["channel"].send_kick(event["user"].nickname, "You said a badword!") diff --git a/modules/ban_enforce.py b/modules/ban_enforce.py new file mode 100644 index 00000000..4e40bb84 --- /dev/null +++ b/modules/ban_enforce.py @@ -0,0 +1,27 @@ +from src import ModuleManager, utils + +REASON = "User is banned from this channel" + +@utils.export("channelset", utils.BoolSetting("ban-enforce", + "Whether or not to parse new bans and kick who they affect")) +class Module(ModuleManager.BaseModule): + @utils.hook("received.mode.channel") + def on_mode(self, event): + if event["channel"].get_setting("ban-enforce", False): + bans = [] + kicks = set([]) + for mode, arg in event["modes"]: + if mode[0] == "+" and mode[1] == "b": + bans.append(arg) + + if bans: + umasks = {u.hostmask(): u for u in event["channel"].users} + for ban in bans: + mask = utils.irc.hostmask_parse(ban) + matches = list(utils.irc.hostmask_match_many( + umasks.keys(), mask)) + for match in matches: + kicks.add(umasks[match]) + if kicks: + nicks = [u.nickname for u in kicks] + event["channel"].send_kicks(sorted(nicks), REASON) diff --git a/modules/channel_op.py b/modules/channel_op.py index e7ab7bc9..d37e27b4 100644 --- a/modules/channel_op.py +++ b/modules/channel_op.py @@ -26,6 +26,28 @@ class TargetType(enum.Enum): MASK = 2 ACCOUNT = 3 +def _mlock(s): + modes, *args = s.split(" ") + + adds = "" + removes = "" + add = True + for c in modes: + if c == "+": + add = True + elif c == "-": + add = False + elif add: + adds += c + else: + removes += c + + return ( + "" if not adds else f"+{''.join(adds)}", + "" if not removes else f"-{''.join(removes)}", + args + ) + KICK_REASON_SETTING = utils.Setting("default-kick-reason", "Set the default kick reason", example="have a nice trip") @@ -33,10 +55,6 @@ BAN_FORMATTING = "${n} = nick, ${u} = username, ${h} = hostname, ${a} = account" @utils.export("channelset", utils.Setting("ban-format", "Set ban format (%s)" % BAN_FORMATTING, example="*!${u}@${h}")) -@utils.export("channelset", utils.Setting("ban-format-account", - "Set ban format for users with accounts (%s)" % BAN_FORMATTING, - example="~a:${a}")) - @utils.export("serverset", utils.OptionsSetting( list(QUIET_METHODS.keys()), "quiet-method", "Set this server's method of muting users")) @@ -47,6 +65,10 @@ BAN_FORMATTING = "${n} = nick, ${u} = username, ${h} = hostname, ${a} = account" @utils.export("botset", KICK_REASON_SETTING) @utils.export("serverset", KICK_REASON_SETTING) @utils.export("channelset", KICK_REASON_SETTING) + +@utils.export("channelset", utils.FunctionSetting(_mlock, "mlock", + "Set which modes are locked on and off for the current channel", + example="+mnt-z")) class Module(ModuleManager.BaseModule): _name = "ChanOp" @@ -75,23 +97,6 @@ class Module(ModuleManager.BaseModule): reason = reason or self._kick_reason(server, channel) channel.send_kicks(nicknames, reason) - def _format_hostmask(self, user, s): - vars = {} - vars["n"] = vars["nickname"] = user.nickname - vars["u"] = vars["username"] = user.username - vars["h"] = vars["hostname"] = user.hostname - vars["a"] = vars["account"] = user.account or "" - return utils.parse.format_token_replace(s, vars) - def _get_hostmask(self, channel, user): - if not user.account == None: - account_format = channel.get_setting("ban-format-account", None) - if not account_format == None: - return self._format_hostmask(user, account_format) - - format = channel.get_setting("ban-format", "*!${u}@${h}") - return self._format_hostmask(user, format) - - @utils.hook("received.command.topic") @utils.kwarg("require_mode", "o") @utils.kwarg("require_access", "low,topic") @@ -209,7 +214,8 @@ class Module(ModuleManager.BaseModule): elif flag == "V" and identified: modes.append(("v", user.nickname)) elif flag == "b": - modes.append(("b", self._get_hostmask(channel, user))) + mask = self.exports.get("ban-mask")(server, channel, user) + modes.append(("b", mask)) kick_reason = "User is banned from this channel" new_modes = [] @@ -394,15 +400,16 @@ class Module(ModuleManager.BaseModule): elif spec[2][0] in ["user", "cuser"]: users = [spec[2][1]] elif spec[2][0] == "word": - masks = [spec[2][1]] + args = [spec[2][1]] target_type, mode, prefix = self._find_mode(type, server) if users: if target_type == TargetType.MASK: - args = [self._get_hostmask(spec[0], u) for u in users] + mask_f = self.exports.get("ban-mask") + args = [mask_f(server, spec[0], u) for u in users] elif target_type == TargetType.NICKNAME: args = [ - u.nickname for u in users if not spec[0].has_mode(u, mode)] + u.nickname for u in users if not spec[0].has_umode(u, mode)] elif target_type == TargetType.ACCOUNT: args = [u.account for u in users if not u.account == None] @@ -428,7 +435,9 @@ class Module(ModuleManager.BaseModule): users = args = [] if event["spec"][1][0] == "user": - masks = [self._get_hostmask(event["spec"][0], event["spec"][1][1])] + mask_f = self.exports.get("ban-mask") + masks = [ + mask_f(event["server"], event["spec"][0], event["spec"][1][1])] elif event["spec"][1][0] == "word": masks = self._list_query_event(event["spec"][0], event["spec"][1][1], mode, prefix) @@ -457,7 +466,7 @@ class Module(ModuleManager.BaseModule): _, mode, _ = self._find_mode( event["hook"].get_kwarg("type"), event["server"]) valid_nicks = [ - u.nickname for u in users if event["spec"][0].has_mode(u, mode)] + u.nickname for u in users if event["spec"][0].has_umode(u, mode)] if valid_nicks: event["spec"][0].send_modes([(mode, a) for a in valid_nicks], False) @@ -489,3 +498,43 @@ class Module(ModuleManager.BaseModule): if modes: event["spec"][0].send_modes(modes, False) + + @utils.hook("received.324") + @utils.hook("received.mode.channel") + def on_modes(self, event): + mlock = event["channel"].get_setting("mlock", None) + if mlock: + changes_adds = "" + changes_removes = "" + changes_args = [] + adds, removes, args = mlock + args = args.copy() # cached settings objects are mutable + + for mode in adds[1:]: + if not event["channel"].has_mode(mode): + if (mode in event["server"].channel_list_modes or + mode in event["server"].channel_parametered_modes or + mode in event["server"].channel_setting_modes): + changes_adds += mode + changes_args.append(args.pop(0)) + elif mode in event["server"].channel_modes: + changes_adds += mode + + for mode in removes[1:]: + if event["channel"].has_mode(mode): + if (mode in event["server"].channel_list_modes or + mode in event["server"].channel_parametered_modes): + changes_removes += mode + changes_args.append(args.pop(0)) + elif (mode in event["server"].channel_setting_modes or + mode in event["server"].channel_modes): + changes_removes += mode + + out = "" + if changes_adds: + out += f"+{changes_adds}" + if changes_removes: + out += f"-{changes_removes}" + + if out: + event["channel"].send_mode(out, changes_args) diff --git a/modules/dnsbl/__init__.py b/modules/dnsbl/__init__.py index 495cbd8c..2b3daf35 100644 --- a/modules/dnsbl/__init__.py +++ b/modules/dnsbl/__init__.py @@ -14,11 +14,11 @@ class Module(ModuleManager.BaseModule): lists = [] for i, arg in reversed(list(enumerate(args))): if arg[0] == "@": - hostname = args.pop(i) + hostname = args.pop(i)[1:] if hostname in default_lists: lists.insert(0, default_lists[hostname]) else: - lists.insert(0, lists.DNSBL(hostname)) + lists.insert(0, _lists.DNSBL(hostname)) lists = lists or list(default_lists.values()) diff --git a/modules/dnsbl/lists.py b/modules/dnsbl/lists.py index ce2b6404..b84628ea 100644 --- a/modules/dnsbl/lists.py +++ b/modules/dnsbl/lists.py @@ -6,7 +6,7 @@ class DNSBL(object): self.hostname = hostname def process(self, result: str): - return "unknown" + return result class ZenSpamhaus(DNSBL): hostname = "zen.spamhaus.org" @@ -40,10 +40,18 @@ class DroneBL(DNSBL): elif result in ["12", "13", "15", "16"]: return "exploits" +class AbuseAtCBL(DNSBL): + hostname = "cbl.abuseat.org" + def process(self, result): + result = result.rsplit(".", 1)[1] + if result == "2": + return "abuse" + DEFAULT_LISTS = [ ZenSpamhaus(), EFNetRBL(), - DroneBL() + DroneBL(), + AbuseAtCBL() ] def default_lists(): diff --git a/modules/fediverse/__init__.py b/modules/fediverse/__init__.py index c1279eb7..aefd966e 100644 --- a/modules/fediverse/__init__.py +++ b/modules/fediverse/__init__.py @@ -95,7 +95,7 @@ class Module(ModuleManager.BaseModule): note = note["object"] cw, author, content, url = ap_utils.parse_note(actor, note, type) - shorturl = self.exports.get_one("shorturl")(event["server"], url, + shorturl = self.exports.get("shorturl")(event["server"], url, context=event["target"]) if cw: diff --git a/modules/fediverse/ap_server.py b/modules/fediverse/ap_server.py index e1b5cece..1fb98760 100644 --- a/modules/fediverse/ap_server.py +++ b/modules/fediverse/ap_server.py @@ -12,7 +12,7 @@ class Server(object): self.username = username self.instance = instance - url_for = self.exports.get_one("url-for") + url_for = self.exports.get("url-for") key_id = self._ap_keyid_url(url_for) private_key = ap_security.PrivateKey(self.bot.config["tls-key"], key_id) @@ -61,7 +61,7 @@ class Server(object): def _toot(self, activity_id): content, timestamp = self.bot.get_setting( "ap-activity-%s" % activity_id) - url_for = self.exports.get_one("url-for") + url_for = self.exports.get("url-for") self_id = self._ap_self_url(url_for) activity_url = self._ap_activity_url(url_for, activity_id) diff --git a/modules/fediverse/ap_utils.py b/modules/fediverse/ap_utils.py index 0d44d523..9174648c 100644 --- a/modules/fediverse/ap_utils.py +++ b/modules/fediverse/ap_utils.py @@ -39,15 +39,17 @@ class FindActorException(Exception): pass def find_actor(username, instance): - hostmeta = HOSTMETA_TEMPLATE % instance - hostmeta_request = utils.http.Request(HOSTMETA_TEMPLATE % instance) + hostmeta_url = HOSTMETA_TEMPLATE % instance + hostmeta_request = utils.http.Request(hostmeta_url) try: hostmeta = utils.http.request(hostmeta_request) except: - raise FindActorException("Failed to get host-meta for %s" % instance) + # failed to GET hostmeta; this is an optional step for servers that do + # not host their webfinger at the usual URL (see WEBFINGER_TEMPLATE) + hostmeta = None webfinger_url = None - if hostmeta.code == 200: + if hostmeta and hostmeta.code == 200: for item in hostmeta.soup().find_all("link"): if item["rel"] and item["rel"][0] == "lrdd": webfinger_url = item["template"] @@ -60,8 +62,9 @@ def find_actor(username, instance): try: webfinger = activity_request(webfinger_url, type=JRD_TYPE) - except: - raise FindActorException("Failed to get webfinger for %s" % instance) + except Exception as e: + raise FindActorException("Failed to get webfinger for %s: %s" % + (instance, str(e))) actor_url = None if webfinger.code == 200: diff --git a/modules/git_webhooks/__init__.py b/modules/git_webhooks/__init__.py index b87f0665..44d9b5dc 100644 --- a/modules/git_webhooks/__init__.py +++ b/modules/git_webhooks/__init__.py @@ -138,7 +138,7 @@ class Module(ModuleManager.BaseModule): if url: if channel.get_setting("git-shorten-urls", False): - url = self.exports.get_one("shorturl")(server, url, + url = self.exports.get("shorturl")(server, url, context=channel) or url output = "%s - %s" % (output, url) diff --git a/modules/git_webhooks/github.py b/modules/git_webhooks/github.py index 21b672ce..8bf61114 100644 --- a/modules/git_webhooks/github.py +++ b/modules/git_webhooks/github.py @@ -441,7 +441,7 @@ class GitHub(object): url = "" if data["check_run"]["details_url"]: url = data["check_run"]["details_url"] - url = " - %s" % self.exports.get_one("shorturl-any")(url) + url = " - %s" % self.exports.get("shorturl-any")(url) duration = "" if data["check_run"]["completed_at"]: diff --git a/modules/ids.py b/modules/ids.py index 1738c549..2e056892 100644 --- a/modules/ids.py +++ b/modules/ids.py @@ -16,7 +16,7 @@ class Module(ModuleManager.BaseModule): @utils.kwarg("help", "Show what I think your account name is") def account(self, event): event["stdout"].write("%s: %s" % (event["user"].nickname, - self.exports.get_one("account-name")(event["user"]))) + self.exports.get("account-name")(event["user"]))) @utils.hook("received.command.channelid", channel_only=True) def channel_id(self, event): diff --git a/modules/inactive_channels.py b/modules/inactive_channels.py index 5959d432..78f724c7 100644 --- a/modules/inactive_channels.py +++ b/modules/inactive_channels.py @@ -1,21 +1,23 @@ import datetime from src import ModuleManager, utils -PRUNE_TIMEDELTA = datetime.timedelta(weeks=2) +PRUNE_TIMEDELTA = datetime.timedelta(weeks=4) -SETTING_NAME = "inactive-channels" -SETTING = utils.BoolSetting(SETTING_NAME, - "Whether or not to leave inactive channels after 2 weeks") +SETTING_NAME = "inactive-prune" +SETTING = utils.IntRangeSetting(0, None, SETTING_NAME, + "Amount of days of inactivity before we leave a channel") -MODE_SETTING_NAME = "inactive-channel-modes" +MODE_SETTING_NAME = "inactive-prune-modes" MODE_SETTING = utils.BoolSetting(MODE_SETTING_NAME, "Whether or not we will leave inactive channels that we have a mode in") @utils.export("botset", SETTING) @utils.export("serverset", SETTING) -@utils.export("channelset", SETTING) @utils.export("serverset", MODE_SETTING) @utils.export("channelset", MODE_SETTING) + +@utils.export("channelset", utils.BoolSetting(SETTING_NAME, + "Whether or not to leave this channel when it is inactive")) class Module(ModuleManager.BaseModule): def _get_timestamp(self, channel): return channel.get_setting("last-message", None) @@ -35,29 +37,27 @@ class Module(ModuleManager.BaseModule): def hourly(self, event): parts = [] now = utils.datetime.utcnow() - botwide_setting = self.bot.get_setting(SETTING_NAME, False) + botwide_days = self.bot.get_setting(SETTING_NAME, None) botwide_mode_setting = self.bot.get_setting(MODE_SETTING_NAME, False) for server in self.bot.servers.values(): - serverwide_setting = server.get_setting( - SETTING_NAME, botwide_setting) - if not serverwide_setting: + serverwide_days = server.get_setting(SETTING_NAME, botwide_days) + if serverwide_days == None: continue mode_setting = server.get_setting( MODE_SETTING_NAME, botwide_mode_setting) - our_user = server.get_user(server.nickname) + for channel in server.channels: - if not channel.get_setting(SETTING_NAME, serverwide_setting): - continue - if not mode_setting and channel.get_user_modes(our_user): + if (not channel.get_setting(SETTING_NAME, True) or + not mode_setting and channel.get_user_modes(our_user)): continue timestamp = self._get_timestamp(channel) if timestamp: dt = utils.datetime.parse.iso8601(timestamp) - if (now-dt) >= PRUNE_TIMEDELTA: + if (now-dt).days >= serverwide_days: parts.append([server, channel]) for server, channel in parts: @@ -66,6 +66,7 @@ class Module(ModuleManager.BaseModule): channel.send_part("Channel inactive") self._del_timestamp(channel) + @utils.hook("send.message.channel") @utils.hook("received.message.channel") def channel_message(self, event): self._set_timestamp(event["channel"]) diff --git a/modules/ircv3_editmsg.py b/modules/ircv3_editmsg.py index a7b5b2b0..15163053 100644 --- a/modules/ircv3_editmsg.py +++ b/modules/ircv3_editmsg.py @@ -18,7 +18,7 @@ class Module(ModuleManager.BaseModule): timestamp, line.message) line = "- %s" % minimal - self.exports.get_one("format")("delete", event["server"], line, + self.exports.get("format")("delete", event["server"], line, event["target_str"], minimal=minimal, channel=channel, user=event["user"]) diff --git a/modules/ircv3_typing.py b/modules/ircv3_typing.py index 3c8c9c6d..05de5162 100644 --- a/modules/ircv3_typing.py +++ b/modules/ircv3_typing.py @@ -5,7 +5,10 @@ CAP = utils.irc.Capability("message-tags", "draft/message-tags-0.2") class Module(ModuleManager.BaseModule): def _tagmsg(self, target, state): return IRCLine.ParsedLine("TAGMSG", [target], - tags={"+draft/typing": state}) + tags={ + "+typing": state, + "+draft/typing": state + }) def _has_tags(self, server): return server.has_capability(CAP) diff --git a/modules/lastfm.py b/modules/lastfm.py index 14e9128a..6931e0c0 100644 --- a/modules/lastfm.py +++ b/modules/lastfm.py @@ -7,8 +7,13 @@ from src import ModuleManager, utils URL_SCROBBLER = "http://ws.audioscrobbler.com/2.0/" +SETTING_YT = utils.BoolSetting("lastfm-youtube", + "Whether or not to search last.fm now-playing results on youtube") + @utils.export("set", utils.Setting("lastfm", "Set last.fm username", example="jesopo")) +@utils.export("botset", SETTING_YT) +@utils.export("serverset", SETTING_YT) class Module(ModuleManager.BaseModule): _name = "last.fm" @@ -59,11 +64,13 @@ class Module(ModuleManager.BaseModule): time_language = "is listening to" if np else "last listened to" - yt_url = self.exports.get_one("search-youtube")( - "%s - %s" % (artist, track_name)) yt_url_str = "" - if yt_url: - yt_url_str = " - %s" % yt_url + if event["server"].get_setting("lastfm-youtube", + self.bot.get_setting("lastfm-youtube", False)): + yt_url = self.exports.get("search-youtube")( + "%s - %s" % (artist, track_name)) + if yt_url: + yt_url_str = " - %s" % yt_url info_page = utils.http.request(URL_SCROBBLER, get_params={ "method": "track.getInfo", "artist": artist, diff --git a/modules/message_filter.py b/modules/message_filter.py index a6191a66..f773ad36 100644 --- a/modules/message_filter.py +++ b/modules/message_filter.py @@ -31,20 +31,23 @@ class Module(ModuleManager.BaseModule): filters = self._get_filters(event["server"], target) for filter in filters: sed = utils.parse.sed.parse(filter) - type, out = utils.parse.sed.sed(sed, message) - if type == "m" and out: - self.log.info("Message matched filter, dropping: %s" - % event["line"].format()) - event["line"].invalidate() - return - elif type == "s": + if sed.type == "m": + out = utils.parse.sed.sed(sed, message_plain) + if out: + self.log.info("Message matched filter, dropping: %s" + % event["line"].format()) + event["line"].invalidate() + return + elif sed.type == "s": + out = utils.parse.sed.sed(sed, message) message = out if not message == original_message: event["line"].args[1] = message - @utils.hook("received.command.cfilter", channel_only=True) + @utils.hook("received.command.cfilter", channel_only=True, + require_access="high,filter", require_mode="o") @utils.hook("received.command.filter") @utils.hook("received.command.bfilter") @utils.kwarg("help", "Add a message filter for the current channel") diff --git a/modules/quotes.py b/modules/quotes.py index a843bdfb..e323cc9b 100644 --- a/modules/quotes.py +++ b/modules/quotes.py @@ -63,22 +63,30 @@ class Module(ModuleManager.BaseModule): category) found_target = None + found_quote = None if not remove_quote == None: - remove_quote_lower = remove_quote.lower() + remove_quote_lower = remove_quote.lower().strip() for nickname, time_added, quote, target in quotes[:]: - if quote.lower() == remove_quote_lower: - quotes.remove([nickname, time_added, quote]) + if remove_quote_lower in quote.lower().strip(): found_target = target + found_quote = [nickname, time_added, quote] message = "Removed quote from '%s'" break else: if quotes: - quote = quotes.pop(-1) - found_target = quote[-1] + nickname, time_added, quote, target = quotes.pop(-1) + + found_target = target + found_quote = [nickname, time_added, quote] message = "Removed last '%s' quote" if not message == None: - self._set_quotes(found_target, category, quotes) + target_quotes = self._get_quotes(found_target, category) + target_quotes.remove(found_quote) + self._set_quotes(found_target, category, target_quotes) + + _, _, quote = found_quote + message = f"{message} ({quote})" event["stdout"].write(message % category) else: event["stderr"].write("Quote not found") @@ -89,6 +97,10 @@ class Module(ModuleManager.BaseModule): @utils.kwarg("usage", "<category> [= <search>]") def quote(self, event): category, search = self.category_and_quote(event["args"]) + if event["server"].has_user(category): + category = event["server"].get_user_nickname( + event["server"].get_user(category).get_id()) + quotes = event["server"].get_setting("quotes-%s" % category, []) if event["is_channel"]: quotes += self._get_quotes(event["target"], category) @@ -122,28 +134,26 @@ class Module(ModuleManager.BaseModule): raise utils.EventError( "Please provide a number between 1 and 3") - target = event["args_split"][0] - lines = event["target"].buffer.find_many_from(target, line_count) + target_user = event["args_split"][0] + lines = event["target"].buffer.find_many_from(target_user, line_count) if lines: lines.reverse() target = event["server"] if event["target"].get_setting("channel-quotes", False): target = event["target"] - quotes = self._get_quotes(target, target) - lines_str = [] for line in lines: lines_str.append(line.format()) text = " ".join(lines_str) - quotes.append([event["user"].name, int(time.time()), text]) - quote_category = line.sender if event["server"].has_user(quote_category): - account = event["server"].get_user_nickname( + quote_category = event["server"].get_user_nickname( event["server"].get_user(quote_category).get_id()) + quotes = self._get_quotes(target, quote_category) + quotes.append([event["user"].name, int(time.time()), text]) self._set_quotes(target, quote_category, quotes) event["stdout"].write("Quote added") diff --git a/modules/rss.py b/modules/rss.py index 07916861..9ba23db4 100644 --- a/modules/rss.py +++ b/modules/rss.py @@ -7,10 +7,15 @@ import feedparser RSS_INTERVAL = 60 # 1 minute +SETTING_BIND = utils.Setting("rss-bindhost", + "Which local address to bind to for RSS requests", example="127.0.0.1") + @utils.export("botset", utils.IntSetting("rss-interval", "Interval (in seconds) between RSS polls", example="120")) @utils.export("channelset", utils.BoolSetting("rss-shorten", "Whether or not to shorten RSS urls")) +@utils.export("serverset", SETTING_BIND) +@utils.export("channelset", SETTING_BIND) class Module(ModuleManager.BaseModule): _name = "RSS" def on_load(self): @@ -27,7 +32,7 @@ class Module(ModuleManager.BaseModule): link = entry.get("link", None) if shorten: try: - link = self.exports.get_one("shorturl")(server, link) + link = self.exports.get("shorturl")(server, link) except: pass link = " - %s" % link if link else "" @@ -49,26 +54,36 @@ class Module(ModuleManager.BaseModule): if server and channel_name in server.channels: channel = server.channels.get(channel_name) for url in urls: - if not url in hooks: - hooks[url] = [] - hooks[url].append((server, channel)) + bindhost = channel.get_setting("rss-bindhost", + server.get_setting("rss-bindhost", None)) + + if url.startswith("www."): + url = url.replace("www.", "", 1) + + key = (url, bindhost) + if not key in hooks: + hooks[key] = [] + + hooks[key].append((server, channel)) if not hooks: return requests = [] - for url in hooks.keys(): - requests.append(utils.http.Request(url, id=url)) + for url, bindhost in hooks.keys(): + requests.append(utils.http.Request(url, id=f"{url} {bindhost}", + bindhost=bindhost)) pages = utils.http.request_many(requests) - for url, channels in hooks.items(): - if not url in pages: + for (url, bindhost), channels in hooks.items(): + key = f"{url} {bindhost}" + if not key in pages: # async url get failed continue try: - data = pages[url].decode() + data = pages[key].decode() except Exception as e: self.log.error("Failed to decode rss URL %s", [url], exc_info=True) diff --git a/modules/seen.py b/modules/seen.py index b19dce61..7eb037d8 100644 --- a/modules/seen.py +++ b/modules/seen.py @@ -36,7 +36,7 @@ class Module(ModuleManager.BaseModule): since = utils.datetime.format.to_pretty_since( time.time()-seen_seconds, max_units=2) event["stdout"].write("%s was last seen %s ago%s" % ( - event["args_split"][0], since, seen_info or "")) + user.nickname, since, seen_info or "")) else: event["stderr"].write("I have never seen %s before." % ( - event["args_split"][0])) + user.nickname)) diff --git a/modules/shorturl.py b/modules/shorturl.py index afc1b50c..345d2183 100644 --- a/modules/shorturl.py +++ b/modules/shorturl.py @@ -16,36 +16,48 @@ class Module(ModuleManager.BaseModule): self.exports.add("botset", setting) def _shorturl_options_factory(self): - shorteners = self.exports.find("shorturl-s-") - return [s.replace("shorturl-s-", "", 1) for s in shorteners] + shorteners = set(self.exports.find("shorturl-s-")) + shorteners.update(self.exports.find("shorturl-x-")) + return sorted(s.split("-", 2)[-1] for s in shorteners) def _get_shortener(self, name): - return self.exports.get_one("shorturl-s-%s" % name, None) - def _call_shortener(self, shortener_name, url): - shortener = self._get_shortener(shortener_name) + extended = self.exports.get("shorturl-x-%s" % name, None) + if not extended == None: + return True, extended + return False, self.exports.get("shorturl-s-%s" % name, None) + def _call_shortener(self, server, context, shortener_name, url): + extended, shortener = self._get_shortener(shortener_name) if shortener == None: return None - short_url = shortener(url) + + if extended: + short_url = shortener(server, context, url) + else: + short_url = shortener(url) + if short_url == None: return None return short_url @utils.export("shorturl-any") def _shorturl_any(self, url): - return self._call_shortener("bitly", url) or url + return self._call_shortener(server, None, "bitly", url) or url @utils.export("shorturl") def _shorturl(self, server, url, context=None): shortener_name = None if context: shortener_name = context.get_setting("url-shortener", - server.get_setting("url-shortener", "bitly")) + server.get_setting("url-shortener", + self.bot.get_setting("url-shortener", "bitly"))) else: - shortener_name = server.get_setting("url-shortener", "bitly") + shortener_name = server.get_setting("url-shortener", + self.bot.get_setting("url-shortener", "bitly")) if shortener_name == None: return url - return self._call_shortener(shortener_name, url) or url + return self._call_shortener( + server, context, shortener_name, url) or url @utils.export("shorturl-s-bitly") def _bitly(self, url): diff --git a/modules/title.py b/modules/title.py index df05546e..eba18566 100644 --- a/modules/title.py +++ b/modules/title.py @@ -27,7 +27,7 @@ class Module(ModuleManager.BaseModule): for title_word in RE_WORDSPLIT.split(title): if len(title_word) > 1 or title_word.isalpha(): title_word = title_word.lower() - title_words.append(title_word.strip("'\"<>()")) + title_words.append(title_word.strip("'\"<>(),:")) if title_words: present = 0 @@ -45,13 +45,9 @@ class Module(ModuleManager.BaseModule): if not urllib.parse.urlparse(url).scheme: url = "http://%s" % url - hostname = urllib.parse.urlparse(url).hostname - if not utils.http.host_permitted(hostname): - self.log.warn("Attempted to get forbidden host: %s", [url]) - return -1, None - + request = utils.http.Request(url, check_hostname=True) try: - page = utils.http.request(url) + page = utils.http.request(request) except Exception as e: self.log.error("failed to get URL title for %s: %s", [url, str(e)]) return -1, None @@ -71,7 +67,7 @@ class Module(ModuleManager.BaseModule): return -2, title if channel.get_setting("title-shorten", False): - short_url = self.exports.get_one("shorturl")(server, url, + short_url = self.exports.get("shorturl")(server, url, context=channel) return page.code, "%s - %s" % (title, short_url) return page.code, title diff --git a/modules/tweets/format.py b/modules/tweets/format.py index 540e7638..9648dc51 100644 --- a/modules/tweets/format.py +++ b/modules/tweets/format.py @@ -23,7 +23,7 @@ def _tweet(exports, server, tweet, from_url): short_url = "" if not from_url: - short_url = exports.get_one("shorturl")(server, tweet_link) + short_url = exports.get("shorturl")(server, tweet_link) short_url = " - %s" % short_url if short_url else "" created_at = _timestamp(tweet.created_at) diff --git a/modules/user_time.py b/modules/user_time.py index 762a9440..8d0701e6 100644 --- a/modules/user_time.py +++ b/modules/user_time.py @@ -34,7 +34,7 @@ class Module(ModuleManager.BaseModule): location["timezone"]) if query: - location = self.exports.get_one("get-location")(query) + location = self.exports.get("get-location")(query) if location: return (LocationType.NAME, location["name"], location["timezone"]) diff --git a/modules/weather.py b/modules/weather.py index 1677f990..bd6b5e8b 100644 --- a/modules/weather.py +++ b/modules/weather.py @@ -45,7 +45,7 @@ class Module(ModuleManager.BaseModule): if location == None and query: - location_info = self.exports.get_one("get-location")(query) + location_info = self.exports.get("get-location")(query) if not location_info == None: location = [location_info["lat"], location_info["lon"], location_info.get("name", None)] @@ -74,8 +74,8 @@ class Module(ModuleManager.BaseModule): # wind speed is in metres per second - 3.6* for KMh wind_speed = 3.6*page["wind"]["speed"] - wind_speed_k = "%sKMh" % round(wind_speed, 1) - wind_speed_m = "%sMPh" % round(0.6214*wind_speed, 1) + wind_speed_k = "%skm/h" % round(wind_speed, 1) + wind_speed_m = "%smi/h" % round(0.6214*wind_speed, 1) if not nickname == None: location_str = "(%s) %s" % (nickname, location_str) diff --git a/modules/words.py b/modules/words.py index 88dbc6cc..59b9ef92 100644 --- a/modules/words.py +++ b/modules/words.py @@ -66,7 +66,7 @@ class Module(ModuleManager.BaseModule): if event["channel"].get_setting("word-tracking-registered", event["server"].get_setting("word-tracking-registered", False)): - if not self.exports.get_one("is-identified")(event["user"]): + if not self.exports.get("is-identified")(event["user"]): return if user.get_setting("first-words", None) == None: @@ -105,19 +105,24 @@ class Module(ModuleManager.BaseModule): self._channel_message(event["server"].get_user( event["server"].nickname), event) - @utils.hook("received.command.words", channel_only=True) + @utils.hook("received.command.words") @utils.kwarg("help", "See how many words you or the given nickname have used") - @utils.spec("!-channelonly ?<nickname>ouser") + @utils.spec("?<nickname>ouser") def words(self, event): - target_user = event["spec"][0] or event["user"] + if event["spec"][0] and event["is_channel"]: + target_user = event["spec"][0] + else: + target_user = event["user"] - words = dict(self._user_all(target_user)) - this_channel = words.get(event["target"].id, 0) + word_items = self._user_all(target_user) - total = 0 - for channel_id in words: - total += words[channel_id] + words = {} + for channel_id, count in word_items: + if not channel_id in words: + words[channel_id] = 0 + words[channel_id] += count + total = sum(words.values()) since = "" first_words = target_user.get_setting("first-words", None) @@ -125,9 +130,14 @@ class Module(ModuleManager.BaseModule): since = " since %s" % utils.datetime.format.date_human( utils.datetime.timestamp(first_words)) - event["stdout"].write("%s has used %d words (%d in %s)%s" % ( - target_user.nickname, total, this_channel, event["target"].name, - since)) + if event["is_channel"]: + this_channel = words.get(event["target"].id, 0) + event["stdout"].write("%s has used %d words (%d in %s)%s" % ( + target_user.nickname, total, this_channel, event["target"].name, + since)) + else: + event["stdout"].write("%s has used %d words%s" % ( + target_user.nickname, total, since)) @utils.hook("received.command.trackword") @utils.kwarg("help", "Start tracking a word") @@ -193,7 +203,9 @@ class Module(ModuleManager.BaseModule): user_words = {} for user_id, word_count in words: _, nickname = self.bot.database.users.by_id(user_id) - user_words[nickname] = word_count + if not nickname in user_words: + user_words[nickname] = 0 + user_words[nickname] += word_count top_10 = utils.top_10(user_words, convert_key=lambda nickname: self._get_nickname( diff --git a/modules/yourls.py b/modules/yourls.py new file mode 100644 index 00000000..3d0fa6e1 --- /dev/null +++ b/modules/yourls.py @@ -0,0 +1,34 @@ +import urllib.parse +from src import ModuleManager, utils + +def _parse(s): + parsed = urllib.parse.urlparse(s) + return urllib.parse.urljoin(s, parsed.path), parsed.query + +SETTING = utils.FunctionSetting(_parse, "yourls", + "Set YOURLS server (and token) to use for URL shortening", + example="https://bitbot.dev/yourls-api.php?1002a612b4", + format=utils.sensitive_format) + +@utils.export("botset", SETTING) +@utils.export("serverset", SETTING) +@utils.export("channelset", SETTING) +class Module(ModuleManager.BaseModule): + @utils.export("shorturl-x-yourls") + def _shorturl(self, server, context, url): + setting = server.get_setting("yourls", + self.bot.get_setting("yourls", None)) + if context: + setting = context.get_setting("yourls", setting) + + if not setting == None: + shortener_url, token = setting + + page = utils.http.request(shortener_url, get_params={ + "signature": token, + "action": "shorturl", + "url": url, + "format": "json"}).json() + if page: + return page["shorturl"] + return None diff --git a/modules/youtube.py b/modules/youtube.py index f5710199..7699fdd6 100644 --- a/modules/youtube.py +++ b/modules/youtube.py @@ -24,6 +24,8 @@ ARROW_DOWN = "↓" "Turn safe search off/on")) class Module(ModuleManager.BaseModule): def get_video_page(self, video_id): + self.log.debug("youtube API request: " + "videos.list [contentDetails,snippet,statistics]") return utils.http.request(URL_YOUTUBEVIDEO, get_params={ "part": "contentDetails,snippet,statistics", "id": video_id, "key": self.bot.config["google-api-key"]}).json() @@ -76,7 +78,10 @@ class Module(ModuleManager.BaseModule): return None def get_playlist_page(self, playlist_id): - return utils.http.request(URL_YOUTUBEPLAYLIST, get_params={ + self.log.debug("youtube API request: " + "playlists.list [contentDetails,snippet]") + + return utils.http.request(URL_YOUTUBEPLAYLIST, get_params={ "part": "contentDetails,snippet", "id": playlist_id, "key": self.bot.config["google-api-key"]}).json() def playlist_details(self, playlist_id): @@ -109,6 +114,8 @@ class Module(ModuleManager.BaseModule): def _search_youtube(self, query): video_id = "" + self.log.debug("youtube API request: search.list (A) [snippet]") + search_page = utils.http.request(URL_YOUTUBESEARCH, get_params={"q": query, "part": "snippet", "maxResults": "1", "type": "video", @@ -144,6 +151,9 @@ class Module(ModuleManager.BaseModule): if not url: safe_setting = event["target"].get_setting("youtube-safesearch", True) safe = "moderate" if safe_setting else "none" + + self.log.debug("youtube API request: search.list (B) [snippet]") + search_page = utils.http.request(URL_YOUTUBESEARCH, get_params={"q": search, "part": "snippet", "maxResults": "1", "type": "video", "key": self.bot.config["google-api-key"], diff --git a/requirements.txt b/requirements.txt index 093bb8cc..94ee9290 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ beautifulsoup4 ==4.8.0 cryptography ==2.7 -dataclasses ==0.6 +dataclasses ==0.6;python_version<'3.7' dnspython ==1.16.0 feedparser ==5.2.1 html5lib ==1.0.1 @@ -15,3 +15,4 @@ scrypt ==0.8.13 suds-jurko ==0.6 tornado ==6.0.3 tweepy ==3.8.0 +requests-toolbelt ==0.9.1 diff --git a/src/Exports.py b/src/Exports.py index a9fe3175..64c10d69 100644 --- a/src/Exports.py +++ b/src/Exports.py @@ -35,7 +35,7 @@ class Exports(object): return self._exports.get(setting, []) + sum([ exports.get(setting, []) for exports in self._context_exports.values()], []) - def get_one(self, setting: str, default: typing.Any=None + def get(self, setting: str, default: typing.Any=None ) -> typing.Optional[typing.Any]: values = self.get_all(setting) return values[0] if values else default @@ -60,8 +60,8 @@ class ExportsContext(object): self._parent._context_add(self.context, setting, value) def get_all(self, setting: str) -> typing.List[typing.Any]: return self._parent.get_all(setting) - def get_one(self, setting: str, default: typing.Any=None + def get(self, setting: str, default: typing.Any=None ) -> typing.Optional[typing.Any]: - return self._parent.get_one(setting, default) + return self._parent.get(setting, default) def find(self, setting_prefix: str) -> typing.List[typing.Any]: return self._parent.find(setting_prefix) diff --git a/src/IRCBot.py b/src/IRCBot.py index f5b66b23..b2c2861f 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -204,8 +204,12 @@ class Bot(object): try: server.connect() except Exception as e: - self.log.warn("Failed to connect to %s: %s", - [str(server), str(e)]) + ip = "" + if server.socket.connected_ip is not None: + ip = f" ({server.socket.connected_ip})" + + self.log.warn("Failed to connect to %s%s: %s", + [str(server), ip, str(e)]) self.log.debug("Connection failure reason:", exc_info=True) return False self.servers[server.fileno()] = server diff --git a/src/IRCChannel.py b/src/IRCChannel.py index 028a51b7..4e8aed9d 100644 --- a/src/IRCChannel.py +++ b/src/IRCChannel.py @@ -210,6 +210,8 @@ class Channel(IRCObject.Object): def send_message(self, text: str, tags: dict={}): return self.server.send_message(self.name, text, tags=tags) + def send_action(self, text: str, tags: dict={}): + return self.server.send_action(self.name, text, tags=tags) def send_notice(self, text: str, tags: dict={}): return self.server.send_notice(self.name, text, tags=tags) def send_tagmsg(self, tags: dict): @@ -260,7 +262,10 @@ class Channel(IRCObject.Object): return True return False - def has_mode(self, user: IRCUser.User, mode: str) -> bool: + def has_mode(self, mode: str) -> bool: + return mode in self.modes + + def has_umode(self, user: IRCUser.User, mode: str) -> bool: return user in self.modes.get(mode, []) def get_user_modes(self, user: IRCUser.User) -> typing.Set: diff --git a/src/IRCLine.py b/src/IRCLine.py index 62d9e8b7..1dda449b 100644 --- a/src/IRCLine.py +++ b/src/IRCLine.py @@ -1,8 +1,7 @@ -import datetime, typing, uuid +import codecs, datetime, typing, uuid from src import EventManager, IRCObject, utils -# this should be 510 (RFC1459, 512 with \r\n) but a server BitBot uses is broken -LINE_MAX = 470 +LINE_MAX = 510 class IRCArgs(object): def __init__(self, args: typing.List[str]): @@ -125,42 +124,43 @@ class ParsedLine(object): return tags, " ".join(pieces).replace("\r", "") def format(self) -> str: tags, line = self._format() - line, _ = self._newline_truncate(line) if tags: return "%s %s" % (tags, line) else: return line - def _newline_truncate(self, line: str) -> typing.Tuple[str, str]: - line, sep, overflow = line.partition("\n") - return (line, overflow) - def _line_max(self, hostmask: str, margin: int) -> int: - return LINE_MAX-len((":%s " % hostmask).encode("utf8"))-margin - def truncate(self, hostmask: str, margin: int=0) -> typing.Tuple[str, str]: - valid_bytes = b"" - valid_index = -1 - - line_max = self._line_max(hostmask, margin) +class SendableLine(ParsedLine): + def __init__(self, command: str, args: typing.List[str], + margin: int=0, tags: typing.Dict[str, str]=None): + ParsedLine.__init__(self, command, args, None, tags) + self._margin = margin - tags_formatted, line_formatted = self._format() - for i, char in enumerate(line_formatted): - encoded_char = char.encode("utf8") - if (len(valid_bytes)+len(encoded_char) > line_max - or encoded_char == b"\n"): - break - else: - valid_bytes += encoded_char - valid_index = i - valid_index += 1 + def push_last(self, arg: str, extra_margin: int=0, + human_trunc: bool=False) -> typing.Optional[str]: + last_arg = self.args[-1] + tags, line = self._format() + n = len(line.encode("utf8")) # get length of current line + n += self._margin # margin used for :hostmask + if " " in arg and not " " in last_arg: + n += 1 # +1 for colon on new arg + n += extra_margin # used for things like (more ...) - valid = line_formatted[:valid_index] - if tags_formatted: - valid = "%s %s" % (tags_formatted, valid) - overflow = line_formatted[valid_index:] - if overflow and overflow[0] == "\n": - overflow = overflow[1:] + overflow: typing.Optional[str] = None - return valid, overflow + if (n+len(arg.encode("utf8"))) > LINE_MAX: + for i, char in enumerate(codecs.iterencode(arg, "utf8")): + n += len(char) + if n > LINE_MAX: + arg, overflow = arg[:i], arg[i:] + if human_trunc and not overflow[0] == " ": + new_arg, sep, new_overflow = arg.rpartition(" ") + if sep: + arg = new_arg + overflow = new_overflow+overflow + break + if arg: + self.args[-1] = last_arg+arg + return overflow def parse_line(line: str) -> ParsedLine: tags = {} # type: typing.Dict[str, typing.Any] @@ -220,7 +220,7 @@ class SentLine(IRCObject.Object): return self._for_wire() def _for_wire(self) -> str: - return self.parsed_line.truncate(self._hostmask)[0] + return str(self.parsed_line) def for_wire(self) -> bytes: return b"%s\r\n" % self._for_wire().encode("utf8") diff --git a/src/IRCServer.py b/src/IRCServer.py index 1454fedb..bb22adab 100644 --- a/src/IRCServer.py +++ b/src/IRCServer.py @@ -80,6 +80,11 @@ class Server(IRCObject.Object): def hostmask(self): return "%s!%s@%s" % (self.nickname, self.username, self.hostname) + def new_line(self, command: str, args: typing.List[str]=None, + tags: typing.Dict[str, str]=None) -> IRCLine.SendableLine: + return IRCLine.SendableLine(command, args or [], + len((":%s " % self.hostmask()).encode("utf8")), tags) + def connect(self): self.socket = IRCSocket.Socket( self.bot.log, @@ -385,6 +390,10 @@ class Server(IRCObject.Object): def send_message(self, target: str, message: str, tags: dict={} ) -> typing.Optional[IRCLine.SentLine]: return self.send(self._line("PRIVMSG", [target, message], tags=tags)) + def send_action(self, target: str, message: str, tags: dict={} + ) -> typing.Optional[IRCLine.SentLine]: + return self.send(self._line("PRIVMSG", + [target, f"\x01ACTION {message}\x01"], tags=tags)) def send_notice(self, target: str, message: str, tags: dict={} ) -> typing.Optional[IRCLine.SentLine]: diff --git a/src/IRCSocket.py b/src/IRCSocket.py index 7fbae39d..507c6471 100644 --- a/src/IRCSocket.py +++ b/src/IRCSocket.py @@ -69,11 +69,12 @@ class Socket(IRCObject.Object): 5.0) self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + self.connected_ip = self._socket.getpeername()[0] + if self._tls: self._tls_wrap() self.connect_time = time.time() - self.connected_ip = self._socket.getpeername()[0] self.cached_fileno = self._socket.fileno() self.connected = True diff --git a/src/IRCUser.py b/src/IRCUser.py index 2e141794..6db124f0 100644 --- a/src/IRCUser.py +++ b/src/IRCUser.py @@ -77,6 +77,8 @@ class User(IRCObject.Object): def send_message(self, message: str, tags: dict={}): self.server.send_message(self.nickname, message, tags=tags) + def send_action(self, message: str, tags: dict={}): + self.server.send_action(self.nickname, message, tags=tags) def send_notice(self, text: str, tags: dict={}): self.server.send_notice(self.nickname, text, tags=tags) def send_ctcp_response(self, command: str, args: str): diff --git a/src/core_modules/admin.py b/src/core_modules/admin.py index f87927b1..5008952d 100644 --- a/src/core_modules/admin.py +++ b/src/core_modules/admin.py @@ -32,9 +32,17 @@ class Module(ModuleManager.BaseModule): @utils.kwarg("permission", "part") @utils.kwarg("require_mode", "high") @utils.kwarg("require_access", "high,part") - @utils.spec("!r~channel") + @utils.spec("!-privateonly !<channel>word") + @utils.spec("!-channelonly ?<channel>word") def part(self, event): - event["server"].send_part(event["spec"][0].name) + event["server"].send_part(event["spec"][0] or event["target"].name) + + @utils.hook("received.command.join") + @utils.kwarg("help", "Join a given channel") + @utils.kwarg("permission", "join") + @utils.spec("!<channel>word") + def join(self, event): + event["server"].send_join(event["spec"][0]) def _id_from_alias(self, alias): return self.bot.database.servers.get_by_alias(alias) diff --git a/src/core_modules/aliases.py b/src/core_modules/aliases.py index b6ca8961..b5ed86f1 100644 --- a/src/core_modules/aliases.py +++ b/src/core_modules/aliases.py @@ -3,6 +3,8 @@ from src import EventManager, ModuleManager, utils SETTING_PREFIX = "command-alias-" +class VariableKeyError(KeyError): + pass class Module(ModuleManager.BaseModule): def _arg_replace(self, s, args_split, kwargs): vars = {} @@ -11,7 +13,11 @@ class Module(ModuleManager.BaseModule): vars["%d-" % i] = " ".join(args_split[i:]) vars["-"] = " ".join(args_split) vars.update(kwargs) - return utils.parse.format_token_replace(s, vars) + + not_found, new_s = utils.parse.format_token_replace(s, vars) + if not_found: + raise VariableKeyError(f"not found: {not_found!r}") + return new_s def _get_alias(self, server, target, command): setting = "%s%s" % (SETTING_PREFIX, command) @@ -45,9 +51,14 @@ class Module(ModuleManager.BaseModule): if event["command"].args: given_args = event["command"].args.split(" ") - event["command"].command = alias - event["command"].args = self._arg_replace(alias_args, given_args, - event["kwargs"]) + try: + event["command"].args = self._arg_replace(alias_args, + given_args, event["kwargs"]) + except VariableKeyError: + pass + else: + event["command"].command = alias + @utils.hook("received.command.alias", permission="alias") diff --git a/src/core_modules/banmask.py b/src/core_modules/banmask.py new file mode 100644 index 00000000..bc33d4a9 --- /dev/null +++ b/src/core_modules/banmask.py @@ -0,0 +1,24 @@ +from src import ModuleManager, utils + +SETTING = utils.Setting("ban-format", + "Set ban format " + "(${n} = nick, ${u} = username, ${h} = hostname, ${a} = account", + example="*!${u}@${h}") + +@utils.export("channelset", SETTING) +@utils.export("serverset", SETTING) +class Module(ModuleManager.BaseModule): + def _format_hostmask(self, user, s): + vars = {} + vars["n"] = vars["nickname"] = user.nickname + vars["u"] = vars["username"] = user.username + vars["h"] = vars["hostname"] = user.hostname + vars["a"] = vars["account"] = user.account or "" + missing, out = utils.parse.format_token_replace(s, vars) + return out + @utils.export("ban-mask") + def banmask(self, server, channel, user): + format = channel.get_setting("ban-format", + server.get_setting("ban-format", "*!${u}@${h}")) + return self._format_hostmask(user, format) + diff --git a/src/core_modules/channel_access.py b/src/core_modules/channel_access.py index 662599e0..a8c371a8 100644 --- a/src/core_modules/channel_access.py +++ b/src/core_modules/channel_access.py @@ -21,7 +21,7 @@ class Module(ModuleManager.BaseModule): user_access = target.get_user_setting(user.get_id(), "access", []) - identified = self.exports.get_one("is-identified")(user) + identified = self.exports.get("is-identified")(user) matched = list(set(required_access)&set(user_access)) return ("*" in user_access or matched) and identified diff --git a/src/core_modules/command_spec/__init__.py b/src/core_modules/command_spec/__init__.py index 34fd5aef..6345d342 100644 --- a/src/core_modules/command_spec/__init__.py +++ b/src/core_modules/command_spec/__init__.py @@ -48,7 +48,7 @@ class Module(ModuleManager.BaseModule): if argument_type.type in types.TYPES: func = types.TYPES[argument_type.type] else: - func = self.exports.get_one( + func = self.exports.get( "command-spec.%s" % argument_type.type) if func: @@ -142,7 +142,7 @@ class Module(ModuleManager.BaseModule): usages = [ utils.parse.argument_spec_human(s, context) for s in specs] command = "%s%s" % (event["command_prefix"], event["command"]) - usages = ["%s%s" % (command, u) for u in usages] + usages = ["%s %s" % (command, u) for u in usages] error_out = "%s (Usage: %s)" % (overall_error, " | ".join(usages)) diff --git a/src/core_modules/commands/__init__.py b/src/core_modules/commands/__init__.py index a5c5faa6..3a0e74cd 100644 --- a/src/core_modules/commands/__init__.py +++ b/src/core_modules/commands/__init__.py @@ -8,8 +8,8 @@ COMMAND_METHOD = "command-method" COMMAND_METHODS = ["PRIVMSG", "NOTICE"] STR_MORE = " (more...)" +STR_CONTINUED = "(...continued) " STR_MORE_LEN = len(STR_MORE.encode("utf8")) -STR_CONTINUED = "(...continued)" WORD_BOUNDARIES = [" "] NON_ALPHANUMERIC = [char for char in string.printable if not char.isalnum()] @@ -73,12 +73,12 @@ class Module(ModuleManager.BaseModule): self.bot.get_setting(COMMAND_METHOD, default))).upper() def _find_command_hook(self, server, target, is_channel, command, user, - args): + command_prefix, args): if not self.has_command(command): command_event = CommandEvent(command, args) self.events.on("get.command").call(command=command_event, server=server, target=target, is_channel=is_channel, user=user, - kwargs={}) + command_prefix=command_prefix, kwargs={}) command = command_event.command args = command_event.args @@ -237,29 +237,25 @@ class Module(ModuleManager.BaseModule): color = utils.consts.RED line_str = obj.pop() + prefix = "" if obj.prefix: - line_str = "[%s] %s" % ( - utils.irc.color(obj.prefix, color), line_str) + prefix = "[%s] " % utils.irc.color(obj.prefix, color) + if obj._overflowed: + prefix = "%s%s" % (prefix, STR_CONTINUED) method = self._command_method(server, target, is_channel) if not method in ["PRIVMSG", "NOTICE"]: raise ValueError("Unknown command-method '%s'" % method) - line = IRCLine.ParsedLine(method, [target_str, line_str], - tags=tags) - valid, trunc = line.truncate(server.hostmask(), - margin=STR_MORE_LEN) + line = server.new_line(method, [target_str, prefix], tags=tags) + + overflow = line.push_last(line_str, human_trunc=True, + extra_margin=STR_MORE_LEN) + if overflow: + line.push_last(STR_MORE) + obj.insert(overflow) + obj._overflowed = True - if trunc: - if not trunc[0] in WORD_BOUNDARIES: - for boundary in WORD_BOUNDARIES: - left, *right = valid.rsplit(boundary, 1) - if right: - valid = left - trunc = right[0]+trunc - obj.insert("%s %s" % (STR_CONTINUED, trunc)) - valid = valid+STR_MORE - line = IRCLine.parse_line(valid) if obj._assured: line.assure() server.send(line) @@ -309,7 +305,7 @@ class Module(ModuleManager.BaseModule): try: hook, command, args_split = self._find_command_hook( event["server"], event["channel"], True, command, - event["user"], args) + event["user"], command_prefix, args) except BadContextException: event["channel"].send_message( "%s: That command is not valid in a channel" % @@ -371,7 +367,7 @@ class Module(ModuleManager.BaseModule): try: hook, command, args_split = self._find_command_hook( event["server"], event["user"], False, command, - event["user"], args) + event["user"], "", args) except BadContextException: event["user"].send_message( "That command is not valid in a PM") diff --git a/src/core_modules/commands/outs.py b/src/core_modules/commands/outs.py index e82ceefd..c6b489ae 100644 --- a/src/core_modules/commands/outs.py +++ b/src/core_modules/commands/outs.py @@ -6,6 +6,13 @@ class StdOut(object): self.prefix = prefix self._lines = [] self._assured = False + self._overflowed = False + + def copy_from(self, other): + self.prefix = other.prefix + self._lines = other._lines + self._assured = other._assured + self._overflowed = other._overflowed def assure(self): self._assured = True diff --git a/src/core_modules/line_handler/__init__.py b/src/core_modules/line_handler/__init__.py index a77dc451..f23a98ed 100644 --- a/src/core_modules/line_handler/__init__.py +++ b/src/core_modules/line_handler/__init__.py @@ -205,6 +205,10 @@ class Module(ModuleManager.BaseModule): @utils.hook("raw.received.chghost") def chghost(self, event): user.chghost(self.events, event) + # RPL_VISIBLEHOST, telling us what our hostname (and sometimes username) is + @utils.hook("raw.received.396") + def handle_396(self, event): + core.handle_396(event) # IRCv3 SETNAME, to change a user's realname @utils.hook("raw.received.setname") diff --git a/src/core_modules/line_handler/core.py b/src/core_modules/line_handler/core.py index a7075613..aa862a03 100644 --- a/src/core_modules/line_handler/core.py +++ b/src/core_modules/line_handler/core.py @@ -165,3 +165,9 @@ def handle_433(event): _nick_in_use(event["server"]) def handle_437(event): _nick_in_use(event["server"]) + +def handle_396(event): + username, sep, hostname = event["line"].args[1].rpartition("@") + event["server"].hostname = hostname + if sep: + event["server"].username = username diff --git a/src/core_modules/line_handler/message.py b/src/core_modules/line_handler/message.py index 87b91c9e..b3e64be1 100644 --- a/src/core_modules/line_handler/message.py +++ b/src/core_modules/line_handler/message.py @@ -64,7 +64,7 @@ def message(events, event): action = False - if message: + if not message == None: ctcp_message = utils.irc.parse_ctcp(message) if ctcp_message: @@ -97,7 +97,7 @@ def message(events, event): hook = events.on(direction).on(event_type).on(context) buffer_line = None - if message: + if not message == None: buffer_line = IRCBuffer.BufferLine(user.nickname, message, action, event["line"].tags, from_self, event["line"].command) diff --git a/src/core_modules/mode_lists.py b/src/core_modules/mode_lists.py index f560ffa0..37918608 100644 --- a/src/core_modules/mode_lists.py +++ b/src/core_modules/mode_lists.py @@ -24,7 +24,7 @@ class Module(ModuleManager.BaseModule): @utils.hook("received.348") def on_348(self, event): mode = self._excepts(event["server"]) - self._mode_list_mask(event, mode, event["line"].args[3]) + self._mode_list_mask(event, mode, event["line"].args[2]) @utils.hook("received.349") def on_349(self, event): self._mode_list_end(event, self._excepts(event["server"])) @@ -35,7 +35,7 @@ class Module(ModuleManager.BaseModule): @utils.hook("received.346") def on_346(self, event): mode = self._invex(event["server"]) - self._mode_list_mask(event, mode, event["line"].args[3]) + self._mode_list_mask(event, mode, event["line"].args[2]) @utils.hook("received.347") def on_347(self, event): self._mode_list_end(event, self._invex(event["server"])) diff --git a/src/core_modules/more.py b/src/core_modules/more.py index 52849938..bc76fb6b 100644 --- a/src/core_modules/more.py +++ b/src/core_modules/more.py @@ -20,4 +20,4 @@ class Module(ModuleManager.BaseModule): def more(self, event): last_stdout = event["target"]._last_stdout if last_stdout and last_stdout.has_text(): - event["stdout"].write_lines(last_stdout.get_all()) + event["stdout"].copy_from(last_stdout) diff --git a/src/core_modules/perform.py b/src/core_modules/perform.py index 832cab54..06f8cd89 100644 --- a/src/core_modules/perform.py +++ b/src/core_modules/perform.py @@ -24,53 +24,39 @@ class Module(ModuleManager.BaseModule): self._execute(event["server"], commands, NICK=event["server"].nickname, CHAN=event["channel"].name) - def _perform(self, target, args_split): - subcommand = args_split[0].lower() + def _perform(self, target, spec): + subcommand = spec[0] current_perform = target.get_setting("perform", []) if subcommand == "list": return "Configured commands: %s" % ", ".join(current_perform) message = None if subcommand == "add": - if not len(args_split) > 1: - raise utils.EventError("Please provide a raw command to add") - current_perform.append(" ".join(args_split[1:])) + current_perform.append(spec[1]) message = "Added command" elif subcommand == "remove": - if not len(args_split) > 1: - raise utils.EventError("Please provide an index to remove") - if not args_split[1].isdigit(): - raise utils.EventError("Please provide a number") - index = int(args_split[1]) + index = spec[1] if not index < len(current_perform): raise utils.EventError("Index out of bounds") - current_perform.pop(index) - message = "Removed command" - else: - raise utils.EventError("Unknown subcommand '%s'" % subcommand) + command = current_perform.pop(index) + message = "Removed command %d (%s)" % (index, command) target.set_setting("perform", current_perform) return message - @utils.hook("received.command.perform", min_args=1) - @utils.kwarg("min_args", 1) + @utils.hook("received.command.perform", permission="perform", + help="Edit on-connect command configuration") + @utils.hook("received.command.cperform", permission="perform", + help="Edit channel on-join command configuration", channel_only=True) @utils.kwarg("help", "Edit on-connect command configuration") - @utils.kwarg("usage", "list") - @utils.kwarg("usage", "add <raw command>") - @utils.kwarg("usage", "remove <index>") - @utils.kwarg("permission", "perform") + @utils.spec("!'list") + @utils.spec("!'add !<command>string") + @utils.spec("!'remove !<index>int") def perform(self, event): - event["stdout"].write(self._perform(event["server"], - event["args_split"])) + if event["command"] == "perform": + target = event["server"] + elif event["command"] == "cperform": + target = event["target"] - @utils.hook("received.command.cperform", min_args=1) - @utils.kwarg("min_args", 1) - @utils.kwarg("channel_only", True) - @utils.kwarg("help", "Edit channel on-join command configuration") - @utils.kwarg("usage", "list") - @utils.kwarg("usage", "add <raw command>") - @utils.kwarg("usage", "remove <index>") - @utils.kwarg("permission", "cperform") - def cperform(self, event): - event["stdout"].write(self._perform(event["target"], - event["args_split"])) + out = self._perform(target, event["spec"]) + event["stdout"].write("%s: %s" % (event["user"].nickname, out)) diff --git a/src/utils/datetime/format.py b/src/utils/datetime/format.py index 889f5906..4b56e7b3 100644 --- a/src/utils/datetime/format.py +++ b/src/utils/datetime/format.py @@ -89,7 +89,7 @@ def to_pretty_time(total_seconds: int, max_units: int=UNIT_MINIMUM, if hours and len(out) < max_units: out.append("%dh" % hours) if minutes and len(out) < max_units: - out.append("%dmi" % minutes) + out.append("%dm" % minutes) if seconds and len(out) < max_units: out.append("%ds" % seconds) diff --git a/src/utils/http.py b/src/utils/http.py index 9f25b315..045d9641 100644 --- a/src/utils/http.py +++ b/src/utils/http.py @@ -3,6 +3,7 @@ import typing, urllib.error, urllib.parse, uuid import json as _json import bs4, netifaces, requests, tornado.httpclient from src import IRCBot, utils +from requests_toolbelt.adapters import source REGEX_URL = re.compile("https?://\S+", re.I) @@ -69,6 +70,7 @@ class Request(object): json_body: bool = False allow_redirects: bool = True + check_hostname: bool = False check_content_type: bool = True fallback_encoding: typing.Optional[str] = None content_type: typing.Optional[str] = None @@ -77,6 +79,8 @@ class Request(object): timeout: int=5 + bindhost: typing.Optional[str] = None + def validate(self): self.id = self.id or str(uuid.uuid4()) self.set_url(self.url) @@ -169,24 +173,61 @@ def request(request_obj: typing.Union[str, Request], **kwargs) -> Response: request_obj = Request(request_obj, **kwargs) return _request(request_obj) +class HostNameInvalidError(ValueError): + pass +class TooManyRedirectionsError(Exception): + pass + def _request(request_obj: Request) -> Response: request_obj.validate() + + def _assert_allowed(url: str): + hostname = urllib.parse.urlparse(url).hostname + if hostname is None or not host_permitted(hostname): + raise HostNameInvalidError( + f"hostname {hostname} is not permitted") + def _wrap() -> Response: headers = request_obj.get_headers() - response = requests.request( - request_obj.method, - request_obj.url, - headers=headers, - params=request_obj.get_params, - data=request_obj.get_body(), - allow_redirects=request_obj.allow_redirects, - stream=True, - cookies=request_obj.cookies - ) - response_content = response.raw.read(RESPONSE_MAX, - decode_content=True) - if not response.raw.read(1) == b"": - raise ValueError("Response too large") + + redirect = 0 + current_url = request_obj.url + session = requests.Session() + if not request_obj.bindhost is None: + new_source = source.SourceAddressAdapter(request_obj.bindhost) + session.mount('http://', new_source) + session.mount('https://', new_source) + + while True: + if request_obj.check_hostname: + _assert_allowed(current_url) + + response = session.request( + request_obj.method, + current_url, + headers=headers, + params=request_obj.get_params, + data=request_obj.get_body(), + allow_redirects=False, + stream=True, + cookies=request_obj.cookies + ) + + if response.status_code in [301, 302]: + redirect += 1 + if redirect == 5: + raise TooManyRedirectionsError(f"{redirect} redirects") + else: + current_url = response.headers["location"] + continue + + response_content = response.raw.read(RESPONSE_MAX, + decode_content=True) + if not response.raw.read(1) == b"": + raise ValueError("Response too large") + break + + session.close() headers = utils.CaseInsensitiveDict(dict(response.headers)) our_response = Response(response.status_code, response_content, diff --git a/src/utils/irc.py b/src/utils/irc.py index 400d5352..f93f61ed 100644 --- a/src/utils/irc.py +++ b/src/utils/irc.py @@ -75,48 +75,30 @@ FORMAT_TOKENS = [ FORMAT_STRIP = [ "\x08" # backspace ] -def _format_tokens(s: str) -> typing.List[str]: - is_color = False - foreground: typing.List[str] = [] - background: typing.List[str] = [] - is_background = False - matches = [] # type: typing.List[str] - - for i, char in enumerate(s): - last_char = i == len(s)-1 - if is_color: - current_color = background if is_background else foreground - color_finished = True - - if char == "," and not is_background: - is_background = True - color_finished = False - elif char.isdigit() and len(current_color) < 2: - current_color.append(char) - color_finished = len(current_color) == 2 and is_background - - if color_finished or last_char: - color = "".join(foreground) - if background: - color += "".join([","]+background) +def _format_tokens(s: str) -> typing.List[str]: + tokens: typing.List[str] = [] - matches.append("\x03%s" % color) - is_color = False - foreground = [] - background = [] - is_background = False + s_copy = list(s) + while s_copy: + token = s_copy.pop(0) + if token == "\x03": + for i in range(2): + if s_copy and s_copy[0].isdigit(): + token += s_copy.pop(0) + if (len(s_copy) > 1 and + s_copy[0] == "," and + s_copy[1].isdigit()): + token += s_copy.pop(0) + token += s_copy.pop(0) + if s_copy and s_copy[0].isdigit(): + token += s_copy.pop(0) - if char == consts.COLOR: - if is_color: - matches.append(char) - else: - is_color = True - elif char in FORMAT_TOKENS: - matches.append(char) - elif char in FORMAT_STRIP: - matches.append(char) - return matches + tokens.append(token) + elif (token in FORMAT_TOKENS or + token in FORMAT_STRIP): + tokens.append(token) + return tokens def _color_match(code: typing.Optional[str], foreground: bool) -> str: if not code: diff --git a/src/utils/parse/__init__.py b/src/utils/parse/__init__.py index 262edf4a..36da8ffb 100644 --- a/src/utils/parse/__init__.py +++ b/src/utils/parse/__init__.py @@ -131,7 +131,7 @@ def format_tokens(s: str, sigil: str="$" return tokens def format_token_replace(s: str, vars: typing.Dict[str, str], - sigil: str="$") -> str: + sigil: str="$") -> typing.Tuple[typing.List[str], str]: vars = vars.copy() vars.update({sigil: sigil}) @@ -140,7 +140,10 @@ def format_token_replace(s: str, vars: typing.Dict[str, str], tokens.sort(key=lambda x: x[0]) tokens.reverse() + not_found: typing.List[str] = [] for start, end, token in tokens: if token in vars: s = s[:start] + vars[token] + s[end+1:] - return s + else: + not_found += token + return not_found, s diff --git a/src/utils/parse/sed.py b/src/utils/parse/sed.py index 76b9e567..8a1c895d 100644 --- a/src/utils/parse/sed.py +++ b/src/utils/parse/sed.py @@ -76,6 +76,6 @@ def parse(sed_s: str) -> typing.Optional[Sed]: return SedMatch(type, re.compile(pattern, flags)) return None -def sed(sed_obj: Sed, s: str) -> typing.Tuple[str, typing.Optional[str]]: +def sed(sed_obj: Sed, s: str) -> typing.Optional[str]: out = sed_obj.match(s) - return sed_obj.type, out + return out diff --git a/src/utils/parse/spec.py b/src/utils/parse/spec.py index 7d7ffd3a..2a682c3a 100644 --- a/src/utils/parse/spec.py +++ b/src/utils/parse/spec.py @@ -15,6 +15,7 @@ class SpecArgumentContext(enum.IntFlag): class SpecArgumentType(object): context = SpecArgumentContext.ALL + _modifier: typing.Optional[str] def __init__(self, type_name: str, name: typing.Optional[str], modifier: typing.Optional[str], exported: typing.Optional[str]): @@ -24,7 +25,7 @@ class SpecArgumentType(object): self.exported = exported def _set_modifier(self, modifier: typing.Optional[str]): - pass + self._modifier = modifier def name(self) -> typing.Optional[str]: return self._name @@ -102,6 +103,18 @@ class SpecArgumentTypeDate(SpecArgumentType): return date_human(args[0]), 1 return None, 1 +class SpecArgumentTypeFlag(SpecArgumentType): + def _str(self): + pref = "-" if len(self._modifier) == 1 else "--" + return f"{pref}{self._modifier}" + def name(self): + return self._str() + def simple(self, args): + print("flag _str()", self._str()) + if args and args[0] == self._str(): + return True, 1 + return None, 1 + class SpecArgumentPrivateType(SpecArgumentType): context = SpecArgumentContext.PRIVATE @@ -115,7 +128,8 @@ SPEC_ARGUMENT_TYPES = { "int": SpecArgumentTypeInt, "date": SpecArgumentTypeDate, "duration": SpecArgumentTypeDuration, - "pattern": SpecArgumentTypePattern + "pattern": SpecArgumentTypePattern, + "flag": SpecArgumentTypeFlag } class SpecArgument(object): diff --git a/src/utils/settings.py b/src/utils/settings.py index 97fe885b..53ee0ebd 100644 --- a/src/utils/settings.py +++ b/src/utils/settings.py @@ -49,15 +49,17 @@ class IntSetting(Setting): class IntRangeSetting(IntSetting): example: typing.Optional[str] = None - def __init__(self, n_min: int, n_max: int, name: str, help: str=None, - example: str=None): + def __init__(self, n_min: int, n_max: typing.Optional[int], name: str, + help: str=None, example: str=None): self._n_min = n_min self._n_max = n_max Setting.__init__(self, name, help, example) def parse(self, value: str) -> typing.Any: out = IntSetting.parse(self, value) - if not out == None and self._n_min <= out <= self._n_max: + if (not out == None and + self._n_min <= out and + (self._n_max == None or out <= self._n_max)): return out return None |
