diff options
| -rw-r--r-- | modules/rest_api.py | 79 | ||||
| -rw-r--r-- | modules/stats.py | 10 |
2 files changed, 51 insertions, 38 deletions
diff --git a/modules/rest_api.py b/modules/rest_api.py index 5fe7cc3b..c01d8280 100644 --- a/modules/rest_api.py +++ b/modules/rest_api.py @@ -6,48 +6,61 @@ _bot = None _events = None class Handler(http.server.BaseHTTPRequestHandler): timeout = 10 - def do_GET(self): - _bot.lock.acquire() - try: - parsed = urllib.parse.urlparse(self.path) - query = parsed.query - get_params = urllib.parse.parse_qs(query) + def _handle(self, method, path, params): + _, _, endpoint = path[1:].partition("/") + endpoint, _, args = endpoint.partition("/") + args = list(filter(None, args.split("/"))) - _, _, endpoint = parsed.path[1:].partition("/") - endpoint, _, args = endpoint.partition("/") - args = list(filter(None, args.split("/"))) + response = "" + code = 404 - response = "" - code = 404 + hooks = _events.on("api").on(method).on(endpoint).get_hooks() + if hooks: + hook = hooks[0] + authenticated = hook.get_kwarg("authenticated", True) + key = params.get("key", None) + if authenticated and (not key or not _bot.get_setting( + "api-key-%s" % key, False)): + code = 401 + else: + if path.startswith("/api/"): + response = _events.on("api").on(method).on(endpoint + ).call_for_result(params=params, path=args) - hooks = _events.on("api").on(endpoint).get_hooks() - if hooks: - hook = hooks[0] - authenticated = hook.get_kwarg("authenticated", True) - key = get_params.get("key", None) - if authenticated and ( - not key or - not _bot.get_setting("api-key-%s" % key[0], False)): - code = 401 - else: - if parsed.path.startswith("/api/"): - response = _events.on("api").on(endpoint - ).call_for_result(params=get_params, path=args) + if response: + response = json.dumps(response, sort_keys=True, + indent=4) + code = 200 - if response: - response = json.dumps(response, sort_keys=True, - indent=4) - code = 200 - - self.send_response(code) - self.send_header("Content-type", "application/json") - self.end_headers() - self.wfile.write(response.encode("utf8")) + self.send_response(code) + self.send_header("Content-type", "application/json") + self.end_headers() + self.wfile.write(response.encode("utf8")) + def _safe_handle(self, method, path, params): + _bot.lock.acquire() + try: + self._handle(method, path, params) except: pass finally: _bot.lock.release() + def _decode_params(self, s): + params = urllib.parse.parse_qs(s) + return dict([(k, v[0]) for k, v in params.items()]) + + def do_GET(self): + parsed = urllib.parse.urlparse(self.path) + get_params = self._decode_params(parsed.query) + self._handle("get", parsed.path, get_params) + + def do_POST(self): + parsed = urllib.parse.urlparse(self.path) + content_length = int(self.headers.get("content-length", 0)) + post_body = self.rfile.read(content_length) + post_params = self._decode_params(post_body) + self._handle("post", parsed.path, post_params) + @utils.export("botset", {"setting": "rest-api", "help": "Enable/disable REST API", "validate": utils.bool_or_none}) diff --git a/modules/stats.py b/modules/stats.py index 1913c1c9..8e93c575 100644 --- a/modules/stats.py +++ b/modules/stats.py @@ -11,7 +11,7 @@ class Module(ModuleManager.BaseModule): :help: Show my uptime """ event["stdout"].write("Uptime: %s" % self._uptime()) - @utils.hook("api.uptime") + @utils.hook("api.get.uptime") def uptime_api(self, event): return self._uptime() @@ -43,7 +43,7 @@ class Module(ModuleManager.BaseModule): event["stdout"].write(response) - @utils.hook("api.stats") + @utils.hook("api.get.stats") def stats_api(self, event): networks, channels, users = self._stats() return {"networks": networks, "channels": channels, "users": users} @@ -59,7 +59,7 @@ class Module(ModuleManager.BaseModule): "users": len(server.users) } - @utils.hook("api.servers") + @utils.hook("api.get.servers") def servers_api(self, event): if event["path"]: server_id = event["path"][0] @@ -84,7 +84,7 @@ class Module(ModuleManager.BaseModule): "topic-set-at": channel.topic_time, "topic-set-by": channel.topic_setter_nickname } - @utils.hook("api.channels") + @utils.hook("api.get.channels") def channels_api(self, event): if event["path"]: server_id = event["path"][0] @@ -108,6 +108,6 @@ class Module(ModuleManager.BaseModule): channel) return channels - @utils.hook("api.modules") + @utils.hook("api.get.modules") def modules_api(self, event): return list(self.bot.modules.modules.keys()) |
