diff options
| -rw-r--r-- | modules/coins.py | 4 | ||||
| -rw-r--r-- | modules/database_backup.py | 4 | ||||
| -rw-r--r-- | modules/in.py | 7 | ||||
| -rw-r--r-- | src/Config.py | 14 | ||||
| -rw-r--r-- | src/IRCBot.py | 73 | ||||
| -rw-r--r-- | src/IRCLineHandler.py | 11 | ||||
| -rw-r--r-- | src/ModuleManager.py | 33 | ||||
| -rw-r--r-- | src/Timer.py | 39 | ||||
| -rw-r--r-- | src/Timers.py | 73 | ||||
| -rwxr-xr-x | start.py | 29 |
10 files changed, 145 insertions, 142 deletions
diff --git a/modules/coins.py b/modules/coins.py index 5ec1c728..86c334e2 100644 --- a/modules/coins.py +++ b/modules/coins.py @@ -33,8 +33,8 @@ class Module(object): until_next_hour = 60-now.second until_next_hour += ((60-(now.minute+1))*60) - bot.add_timer("coin-interest", INTEREST_INTERVAL, persist=False, - next_due=time.time()+until_next_hour) + bot.timers.add("coin-interest", INTEREST_INTERVAL, + time.time()+until_next_hour) @Utils.hook("received.command.coins") def coins(self, event): diff --git a/modules/database_backup.py b/modules/database_backup.py index e4a4ceb4..ab8dc194 100644 --- a/modules/database_backup.py +++ b/modules/database_backup.py @@ -11,8 +11,8 @@ class Module(object): until_next_hour = 60-now.second until_next_hour += ((60-(now.minute+1))*60) - bot.add_timer("database-backup", BACKUP_INTERVAL, persist=False, - next_due=time.time()+until_next_hour) + bot.timers.add("database-backup", BACKUP_INTERVAL, + time.time()+until_next_hour) @Utils.hook("timer.database-backup") def backup(self, event): diff --git a/modules/in.py b/modules/in.py index 36b84f28..7fb7d7ea 100644 --- a/modules/in.py +++ b/modules/in.py @@ -16,10 +16,9 @@ class Module(ModuleManager.BaseModule): if seconds <= SECONDS_MAX: due_time = int(time.time())+seconds - self.bot.add_timer("in", seconds, - target=event["target"].name, due_time=due_time, - server_id=event["server"].id, nickname=event["user"].nickname, - message=message) + self.bot.timers.add_persistent("in", seconds, due_time=due_time, + target=event["target"].name, server_id=event["server"].id, + nickname=event["user"].nickname, message=message) event["stdout"].write("Saved") else: event["stderr"].write( diff --git a/src/Config.py b/src/Config.py index bf597b78..b5d27ea9 100644 --- a/src/Config.py +++ b/src/Config.py @@ -3,10 +3,20 @@ import configparser, os class Config(object): def __init__(self, location): self.location = location + self._config = {} + self.load() - def load_config(self): + def load(self): if os.path.isfile(self.location): with open(self.location) as config_file: parser = configparser.ConfigParser() parser.read_string(config_file.read()) - return dict(parser["bot"].items()) + self._config = dict(parser["bot"].items()) + + def __getitem__(self, key): + return self._config[key] + def get(self, key, default=None): + return self._config.get(key, default) + def __contains__(self, key): + return key in self.config + diff --git a/src/IRCBot.py b/src/IRCBot.py index 0e1d628f..596a9e66 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -1,24 +1,27 @@ import os, select, sys, threading, time, traceback, uuid from . import EventManager, Exports, IRCLineHandler, IRCServer, Logging -from . import ModuleManager, Timer +from . import ModuleManager + class Bot(object): - def __init__(self): + def __init__(self, args, config, database, events, exports, line_handler, + log, modules, timers): + self.args = args + self.config = config + self.database = database + self._events = events + self._exports = exports + self.line_handler = line_handler + self.log = log + self.modules = modules + self.timers = timers + + events.on("timer.reconnect").hook(self.reconnect) self.start_time = time.time() self.lock = threading.Lock() - self.args = None - self.database = None - self.config = None self.servers = {} self.running = True self.poll = select.epoll() - self.timers = [] - - self._events = None - self._exports = None - self.modules = None - self.log = None - self.line_handler = None def add_server(self, server_id, connect=True): (_, alias, hostname, port, password, ipv4, tls, nickname, @@ -43,46 +46,7 @@ class Bot(object): self.servers[server.fileno()] = server self.poll.register(server.fileno(), select.EPOLLOUT) return True - def setup_timers(self, event): - for setting, value in self.find_settings("timer-%"): - id = setting.split("timer-", 1)[1] - self.add_timer(value["event-name"], value["delay"], value[ - "next-due"], id, **value["kwargs"]) - def timer_setting(self, timer): - self.set_setting("timer-%s" % timer.id, { - "event-name": timer.event_name, "delay": timer.delay, - "next-due": timer.next_due, "kwargs": timer.kwargs}) - def timer_setting_remove(self, timer): - self.timers.remove(timer) - self.del_setting("timer-%s" % timer.id) - def add_timer(self, event_name, delay, next_due=None, id=None, persist=True, - **kwargs): - id = id or uuid.uuid4().hex - timer = Timer.Timer(id, self, self._events, event_name, delay, - next_due, **kwargs) - if id: - timer.id = id - elif persist: - self.timer_setting(timer) - self.timers.append(timer) - def next_timer(self): - next = None - for timer in self.timers: - time_left = timer.time_left() - if next == None or time_left < next: - next = time_left - if next == None: - return None - if next < 0: - return 0 - return next - def call_timers(self): - for timer in self.timers[:]: - if timer.due(): - timer.call() - if timer.done(): - self.timer_setting_remove(timer) def next_send(self): next = None for server in self.servers.values(): @@ -110,7 +74,7 @@ class Bot(object): def get_poll_timeout(self): timeouts = [] - timeouts.append(self.next_timer()) + timeouts.append(self.timers.next()) timeouts.append(self.next_send()) timeouts.append(self.next_ping()) timeouts.append(self.next_read_timeout()) @@ -154,7 +118,8 @@ class Bot(object): while self.running: self.lock.acquire() events = self.poll.poll(self.get_poll_timeout()) - self.call_timers() + self.timers.call() + for fd, event in events: if fd in self.servers: server = self.servers[fd] @@ -185,7 +150,7 @@ class Bot(object): self.disconnect(server) reconnect_delay = self.config.get("reconnect-delay", 10) - self.add_timer("reconnect", reconnect_delay, None, None, False, + self.timers.add("reconnect", reconnect_delay, server_id=server.id) print("disconnected from %s, reconnecting in %d seconds" % ( diff --git a/src/IRCLineHandler.py b/src/IRCLineHandler.py index 7d2336dc..805daa20 100644 --- a/src/IRCLineHandler.py +++ b/src/IRCLineHandler.py @@ -14,9 +14,9 @@ CAPABILITIES = {"multi-prefix", "chghost", "invite-notify", "account-tag", "batch", "draft/labeled-response"} class LineHandler(object): - def __init__(self, bot, events): - self.bot = bot + def __init__(self, events, timers): self.events = events + self.timers = timers events.on("raw.PING").hook(self.ping) events.on("raw.001").hook(self.handle_001, default_event=True) @@ -570,10 +570,9 @@ class LineHandler(object): # we need a registered nickname for this channel def handle_477(self, event): channel_name = Utils.irc_lower(event["server"], event["args"][1]) - if channel_name in event["server"].attempted_join: - self.bot.add_timer("rejoin", 5, - channel_name=event["args"][1], - key=event["server"].attempted_join[channel_name], + if channel_name in event["server"]: + key = event["server"].attempted_join[channel_name] + self.timers.add("rejoin", 5, channel_name=channe_name, key=key, server_id=event["server"].id) # someone's been kicked from a channel diff --git a/src/ModuleManager.py b/src/ModuleManager.py index 7787d9f3..c9aa650a 100644 --- a/src/ModuleManager.py +++ b/src/ModuleManager.py @@ -27,13 +27,16 @@ class BaseModule(object): self.exports = exports class ModuleManager(object): - def __init__(self, bot, events, exports, directory): - self.bot = bot + def __init__(self, events, exports, config, log, directory): self.events = events self.exports = exports + self.config = config + self.log = log self.directory = directory + self.modules = {} self.waiting_requirement = {} + def list_modules(self): return sorted(glob.glob(os.path.join(self.directory, "*.py"))) @@ -47,7 +50,7 @@ class ModuleManager(object): def _get_magic(self, obj, magic, default): return getattr(obj, magic) if hasattr(obj, magic) else default - def _load_module(self, name): + def _load_module(self, bot, name): path = self._module_path(name) with io.open(path, mode="r", encoding="utf8") as module_file: @@ -61,8 +64,7 @@ class ModuleManager(object): raise ModuleNotLoadedWarning("module ignored") elif line_split[0] == "#--require-config" and len( line_split) > 1: - if not line_split[1].lower() in self.bot.config or not self.bot.config[ - line_split[1].lower()]: + if not self.config.get(line_split[1].lower(), None): # nope, required config option not present. raise ModuleNotLoadedWarning( "required config not present") @@ -88,8 +90,7 @@ class ModuleManager(object): context = str(uuid.uuid4()) context_events = self.events.new_context(context) context_exports = self.exports.new_context(context) - module_object = module.Module(self.bot, context_events, - context_exports) + module_object = module.Module(bot, context_events, context_exports) if not hasattr(module_object, "_name"): module_object._name = name.title() @@ -109,29 +110,29 @@ class ModuleManager(object): "attempted to be used twice") return module_object - def load_module(self, name): + def load_module(self, bot, name): try: - module = self._load_module(name) + module = self._load_module(bot, name) except ModuleWarning as warning: - self.bot.log.error("Module '%s' not loaded", [name]) + self.log.error("Module '%s' not loaded", [name]) raise except Exception as e: - self.bot.log.error("Failed to load module \"%s\": %s", + self.log.error("Failed to load module \"%s\": %s", [name, str(e)]) raise self.modules[module._import_name] = module if name in self.waiting_requirement: for requirement_name in self.waiting_requirement: - self.load_module(requirement_name) - self.bot.log.info("Module '%s' loaded", [name]) + self.load_module(bot, requirement_name) + self.log.info("Module '%s' loaded", [name]) - def load_modules(self, whitelist=[], blacklist=[]): + def load_modules(self, bot, whitelist=[], blacklist=[]): for path in self.list_modules(): name = self._module_name(path) if name in whitelist or (not whitelist and not name in blacklist): try: - self.load_module(name) + self.load_module(bot, name) except ModuleWarning: pass @@ -151,5 +152,5 @@ class ModuleManager(object): references -= 1 # 'del module' removes one reference references -= 1 # one of the refs is from getrefcount - self.bot.log.info("Module '%s' unloaded (%d reference%s)", + self.log.info("Module '%s' unloaded (%d reference%s)", [name, references, "" if references == 1 else "s"]) diff --git a/src/Timer.py b/src/Timer.py deleted file mode 100644 index 7ac83630..00000000 --- a/src/Timer.py +++ /dev/null @@ -1,39 +0,0 @@ -import time, uuid - -class Timer(object): - def __init__(self, id, bot, events, event_name, delay, - next_due=None, **kwargs): - self.id = id - self.bot = bot - self.events = events - self.event_name = event_name - self.delay = delay - if next_due: - self.next_due = next_due - else: - self.set_next_due() - self.kwargs = kwargs - self._done = False - self.call_count = 0 - - def set_next_due(self): - self.next_due = time.time()+self.delay - - def due(self): - return self.time_left() <= 0 - - def time_left(self): - return self.next_due-time.time() - - def call(self): - self._done = True - self.call_count +=1 - self.events.on("timer").on(self.event_name).call( - timer=self, **self.kwargs) - - def redo(self): - self._done = False - self.set_next_due() - - def done(self): - return self._done diff --git a/src/Timers.py b/src/Timers.py new file mode 100644 index 00000000..d3a297df --- /dev/null +++ b/src/Timers.py @@ -0,0 +1,73 @@ +import time, uuid + +class Timer(object): + def __init__(self, id, name, delay, next_due, kwargs): + self.id = id + self.name = name + self.delay = delay + if next_due: + self.next_due = next_due + else: + self.set_next_due() + self.kwargs = kwargs + self._done = False + + def set_next_due(self): + self.next_due = time.time()+self.delay + def due(self): + return self.time_left() <= 0 + def time_left(self): + return self.next_due-time.time() + + def redo(self): + self._done = False + self.set_next_due() + def finish(): + self._done = True + def done(self): + return self._done + +class Timers(object): + def __init__(self, events, log): + self.events = events + self.log = log + self.timers = [] + + def setup(self, timers): + for name, timer in timers: + id = name.split("timer-", 1)[1] + self._add(timer["name"], timer["delay"], timer[ + "next-due"], id, False, timer["kwargs"]) + + def _persist(self, timer): + self.set_setting("timer-%s" % timer.id, { + "name": timer.name, "delay": timer.delay, + "next-due": timer.next_due, "kwargs": timer.kwargs}) + def _remove(self, timer): + self.timers.remove(timer) + self.del_setting("timer-%s" % timer.id) + + def add(self, name, delay, next_due=None, **kwargs): + self._add(name, delay, next_due, None, False, kwargs) + def add_persistent(self, name, delay, next_due=None, **kwargs): + self._add(name, delay, next_due, None, True, kwargs) + def _add(self, name, delay, next_due, id, persist, kwargs): + id = id or uuid.uuid4().hex + timer = Timer(id, name, delay, next_due, kwargs) + if persist: + self._persist(timer) + self.timers.append(timer) + + def next(self): + times = filter(None, [timer.time_left() for timer in self.timers]) + if not times: + return None + return max(min(times), 0) + + def call(self): + for timer in self.timers[:]: + if timer.due(): + timer.finish() + self.events.on("timer.%s" % timer.name, timer=timer) + if timer.done(): + self._remove(timer) @@ -2,7 +2,7 @@ import argparse, os, sys, time from src import Config, Database, EventManager, Exports, IRCBot -from src import IRCLineHandler, Logging, ModuleManager +from src import IRCLineHandler, Logging, ModuleManager, Timers def bool_input(s): result = input("%s (Y/n): " % s) @@ -29,31 +29,23 @@ arg_parser.add_argument("--verbose", "-v", action="store_true") args = arg_parser.parse_args() + log = Logging.Log(args.log) -config = Config.Config(args.config).load_config() +config = Config.Config(args.config) database = Database.Database(log, args.database) events = events = EventManager.EventHook(log) exports = exports = Exports.Exports() - -bot = IRCBot.Bot() - -bot.modules = modules = ModuleManager.ModuleManager(bot, events, exports, +timers = Timers.Timers(events, log) +line_handler = IRCLineHandler.LineHandler(events, timers) +modules = modules = ModuleManager.ModuleManager(events, exports, config, log, os.path.join(directory, "modules")) -bot.line_handler = IRCLineHandler.LineHandler(bot, events) -bot.log = log -bot.config = config -bot.database = database -bot._events = events -bot._exports = exports -bot.args = args - -bot._events.on("timer.reconnect").hook(bot.reconnect) -bot._events.on("boot.done").hook(bot.setup_timers) +bot = IRCBot.Bot(args, config, database, events, exports, line_handler, log, + modules, timers) whitelist = bot.get_setting("module-whitelist", []) blacklist = bot.get_setting("module-blacklist", []) -bot.modules.load_modules(whitelist=whitelist, blacklist=blacklist) +modules.load_modules(bot, whitelist=whitelist, blacklist=blacklist) servers = [] for server_id, alias in bot.database.servers.get_all(): @@ -62,6 +54,9 @@ for server_id, alias in bot.database.servers.get_all(): servers.append(server) if len(servers): bot._events.on("boot.done").call() + + bot.timers.setup(bot.find_settings_prefix("timer-")) + for server in servers: if not bot.connect(server): sys.stderr.write("failed to connect to '%s', exiting\r\n" % ( |
