diff options
| author | 2019-06-14 12:01:55 +0100 | |
|---|---|---|
| committer | 2019-06-14 12:01:55 +0100 | |
| commit | f05fc209b015e9d883566fc8cc4141dc9eff4db3 (patch) | |
| tree | 599c4bb8f31e22af304c3b8f20977738b9b55649 | |
| parent | Catch `yield`s in command callbacks for e.g. permission checks (diff) | |
| signature | ||
Add a way to __or__ `utils.Check`s, as a "if one of these passes" thing
| -rw-r--r-- | modules/check_mode.py | 2 | ||||
| -rw-r--r-- | modules/commands/__init__.py | 22 | ||||
| -rw-r--r-- | modules/permissions/__init__.py | 16 | ||||
| -rw-r--r-- | src/utils/__init__.py | 9 |
4 files changed, 37 insertions, 12 deletions
diff --git a/modules/check_mode.py b/modules/check_mode.py index add90cc8..f3d3edb5 100644 --- a/modules/check_mode.py +++ b/modules/check_mode.py @@ -34,4 +34,4 @@ class Module(ModuleManager.BaseModule): @utils.hook("check.command.channel-mode") def check_command(self, event): - return self._check_command(event, event["check_args"][0]) + return self._check_command(event, event["request_args"][0]) diff --git a/modules/commands/__init__.py b/modules/commands/__init__.py index 3d948c68..3b4dc74e 100644 --- a/modules/commands/__init__.py +++ b/modules/commands/__init__.py @@ -130,12 +130,16 @@ class Module(ModuleManager.BaseModule): return hook, args_split - def _check(self, context, kwargs, request=None, **extra_kwargs): + def _check(self, context, kwargs, requests=[]): event_hook = self.events.on(context).on("command") - if not request == None: - event_hook = event_hook.on(request) - returns = event_hook.call_unsafe(**kwargs, **extra_kwargs) + returns = [] + if requests: + for request, request_args in requests: + returns.append(event_hook.on(request).call_unsafe_for_result( + **kwargs, request_args=request_args)) + else: + returns = event_hook.call_unsafe(**kwargs) hard_fail = False force_success = False @@ -174,10 +178,16 @@ class Module(ModuleManager.BaseModule): break if next_success: + multi_check = None if isinstance(next_return, utils.Check): + multi_check = next_return.to_multi() + elif isinstance(next_return, utils.MultiCheck): + multi_check = next_return + + if multi_check: check_success, check_message = self._check("check", - check_kwargs, next_return.request, - check_args=next_return.args) + check_kwargs, multi_check.requests) + if not check_success: return False, check_message else: diff --git a/modules/permissions/__init__.py b/modules/permissions/__init__.py index 3e2ce3f1..269f2a0f 100644 --- a/modules/permissions/__init__.py +++ b/modules/permissions/__init__.py @@ -191,14 +191,10 @@ class Module(ModuleManager.BaseModule): event["stdout"].write("Reset password for '%s'" % target.nickname) - @utils.hook("preprocess.command") - def preprocess_command(self, event): + def _check_command(self, event, permission, authenticated): if event["user"].admin_master: return utils.consts.PERMISSION_FORCE_SUCCESS - permission = event["hook"].get_kwarg("permission", None) - authenticated = event["hook"].kwargs.get("authenticated", False) - identity_mechanism = event["server"].get_setting("identity-mechanism", "internal") identified_account = None @@ -231,6 +227,16 @@ class Module(ModuleManager.BaseModule): else: return utils.consts.PERMISSION_FORCE_SUCCESS + @utils.hook("preprocess.command") + def preprocess_command(self, event): + permission = event["hook"].get_kwarg("permission", None) + authenticated = event["hook"].kwargs.get("authenticated", False) + return self._check_command(event, permission, authenticated) + + @utils.hook("check.command.permission") + def check_command(self, event): + return self._check_command(event, event["request_args"][0], False) + @utils.hook("received.command.mypermissions", authenticated=True) def my_permissions(self, event): """ diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 869690ff..c0acee4d 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -185,10 +185,19 @@ def export(setting: str, value: typing.Any): return module return _export_func +class MultiCheck(object): + def __init__(self, + requests: typing.List[typing.Tuple[str, typing.List[str]]]): + self.requests = requests class Check(object): def __init__(self, request: str, *args: typing.List[str]): self.request = request self.args = args + def to_multi(self): + return MultiCheck([(self.request, self.args)]) + def __or__(self, other: "Check"): + return MultiCheck([(self.request, self.args), + (other.request, other.args)]) TOP_10_CALLABLE = typing.Callable[[typing.Any], typing.Any] def top_10(items: typing.Dict[typing.Any, typing.Any], |
