aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/check_mode.py2
-rw-r--r--modules/commands/__init__.py22
-rw-r--r--modules/permissions/__init__.py16
-rw-r--r--src/utils/__init__.py9
4 files changed, 37 insertions, 12 deletions
diff --git a/modules/check_mode.py b/modules/check_mode.py
index add90cc8..f3d3edb5 100644
--- a/modules/check_mode.py
+++ b/modules/check_mode.py
@@ -34,4 +34,4 @@ class Module(ModuleManager.BaseModule):
@utils.hook("check.command.channel-mode")
def check_command(self, event):
- return self._check_command(event, event["check_args"][0])
+ return self._check_command(event, event["request_args"][0])
diff --git a/modules/commands/__init__.py b/modules/commands/__init__.py
index 3d948c68..3b4dc74e 100644
--- a/modules/commands/__init__.py
+++ b/modules/commands/__init__.py
@@ -130,12 +130,16 @@ class Module(ModuleManager.BaseModule):
return hook, args_split
- def _check(self, context, kwargs, request=None, **extra_kwargs):
+ def _check(self, context, kwargs, requests=[]):
event_hook = self.events.on(context).on("command")
- if not request == None:
- event_hook = event_hook.on(request)
- returns = event_hook.call_unsafe(**kwargs, **extra_kwargs)
+ returns = []
+ if requests:
+ for request, request_args in requests:
+ returns.append(event_hook.on(request).call_unsafe_for_result(
+ **kwargs, request_args=request_args))
+ else:
+ returns = event_hook.call_unsafe(**kwargs)
hard_fail = False
force_success = False
@@ -174,10 +178,16 @@ class Module(ModuleManager.BaseModule):
break
if next_success:
+ multi_check = None
if isinstance(next_return, utils.Check):
+ multi_check = next_return.to_multi()
+ elif isinstance(next_return, utils.MultiCheck):
+ multi_check = next_return
+
+ if multi_check:
check_success, check_message = self._check("check",
- check_kwargs, next_return.request,
- check_args=next_return.args)
+ check_kwargs, multi_check.requests)
+
if not check_success:
return False, check_message
else:
diff --git a/modules/permissions/__init__.py b/modules/permissions/__init__.py
index 3e2ce3f1..269f2a0f 100644
--- a/modules/permissions/__init__.py
+++ b/modules/permissions/__init__.py
@@ -191,14 +191,10 @@ class Module(ModuleManager.BaseModule):
event["stdout"].write("Reset password for '%s'" %
target.nickname)
- @utils.hook("preprocess.command")
- def preprocess_command(self, event):
+ def _check_command(self, event, permission, authenticated):
if event["user"].admin_master:
return utils.consts.PERMISSION_FORCE_SUCCESS
- permission = event["hook"].get_kwarg("permission", None)
- authenticated = event["hook"].kwargs.get("authenticated", False)
-
identity_mechanism = event["server"].get_setting("identity-mechanism",
"internal")
identified_account = None
@@ -231,6 +227,16 @@ class Module(ModuleManager.BaseModule):
else:
return utils.consts.PERMISSION_FORCE_SUCCESS
+ @utils.hook("preprocess.command")
+ def preprocess_command(self, event):
+ permission = event["hook"].get_kwarg("permission", None)
+ authenticated = event["hook"].kwargs.get("authenticated", False)
+ return self._check_command(event, permission, authenticated)
+
+ @utils.hook("check.command.permission")
+ def check_command(self, event):
+ return self._check_command(event, event["request_args"][0], False)
+
@utils.hook("received.command.mypermissions", authenticated=True)
def my_permissions(self, event):
"""
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
index 869690ff..c0acee4d 100644
--- a/src/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -185,10 +185,19 @@ def export(setting: str, value: typing.Any):
return module
return _export_func
+class MultiCheck(object):
+ def __init__(self,
+ requests: typing.List[typing.Tuple[str, typing.List[str]]]):
+ self.requests = requests
class Check(object):
def __init__(self, request: str, *args: typing.List[str]):
self.request = request
self.args = args
+ def to_multi(self):
+ return MultiCheck([(self.request, self.args)])
+ def __or__(self, other: "Check"):
+ return MultiCheck([(self.request, self.args),
+ (other.request, other.args)])
TOP_10_CALLABLE = typing.Callable[[typing.Any], typing.Any]
def top_10(items: typing.Dict[typing.Any, typing.Any],