aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar jesopo2018-09-28 16:51:36 +0100
committerGravatar jesopo2018-09-28 16:51:36 +0100
commita8bf3c93007503ec411d0d23fef021f386127fa5 (patch)
tree4aa2c0d367606bd4c5e9ff3a0dda75a59267beb3 /src
parentFix typo in database_backup.py, 'ocation' -> 'location' (diff)
signature
Remove cyclical references to IRCBot
Diffstat (limited to 'src')
-rw-r--r--src/Config.py14
-rw-r--r--src/IRCBot.py73
-rw-r--r--src/IRCLineHandler.py11
-rw-r--r--src/ModuleManager.py33
-rw-r--r--src/Timer.py39
-rw-r--r--src/Timers.py73
6 files changed, 126 insertions, 117 deletions
diff --git a/src/Config.py b/src/Config.py
index bf597b78..b5d27ea9 100644
--- a/src/Config.py
+++ b/src/Config.py
@@ -3,10 +3,20 @@ import configparser, os
class Config(object):
def __init__(self, location):
self.location = location
+ self._config = {}
+ self.load()
- def load_config(self):
+ def load(self):
if os.path.isfile(self.location):
with open(self.location) as config_file:
parser = configparser.ConfigParser()
parser.read_string(config_file.read())
- return dict(parser["bot"].items())
+ self._config = dict(parser["bot"].items())
+
+ def __getitem__(self, key):
+ return self._config[key]
+ def get(self, key, default=None):
+ return self._config.get(key, default)
+ def __contains__(self, key):
+ return key in self.config
+
diff --git a/src/IRCBot.py b/src/IRCBot.py
index 0e1d628f..596a9e66 100644
--- a/src/IRCBot.py
+++ b/src/IRCBot.py
@@ -1,24 +1,27 @@
import os, select, sys, threading, time, traceback, uuid
from . import EventManager, Exports, IRCLineHandler, IRCServer, Logging
-from . import ModuleManager, Timer
+from . import ModuleManager
+
class Bot(object):
- def __init__(self):
+ def __init__(self, args, config, database, events, exports, line_handler,
+ log, modules, timers):
+ self.args = args
+ self.config = config
+ self.database = database
+ self._events = events
+ self._exports = exports
+ self.line_handler = line_handler
+ self.log = log
+ self.modules = modules
+ self.timers = timers
+
+ events.on("timer.reconnect").hook(self.reconnect)
self.start_time = time.time()
self.lock = threading.Lock()
- self.args = None
- self.database = None
- self.config = None
self.servers = {}
self.running = True
self.poll = select.epoll()
- self.timers = []
-
- self._events = None
- self._exports = None
- self.modules = None
- self.log = None
- self.line_handler = None
def add_server(self, server_id, connect=True):
(_, alias, hostname, port, password, ipv4, tls, nickname,
@@ -43,46 +46,7 @@ class Bot(object):
self.servers[server.fileno()] = server
self.poll.register(server.fileno(), select.EPOLLOUT)
return True
- def setup_timers(self, event):
- for setting, value in self.find_settings("timer-%"):
- id = setting.split("timer-", 1)[1]
- self.add_timer(value["event-name"], value["delay"], value[
- "next-due"], id, **value["kwargs"])
- def timer_setting(self, timer):
- self.set_setting("timer-%s" % timer.id, {
- "event-name": timer.event_name, "delay": timer.delay,
- "next-due": timer.next_due, "kwargs": timer.kwargs})
- def timer_setting_remove(self, timer):
- self.timers.remove(timer)
- self.del_setting("timer-%s" % timer.id)
- def add_timer(self, event_name, delay, next_due=None, id=None, persist=True,
- **kwargs):
- id = id or uuid.uuid4().hex
- timer = Timer.Timer(id, self, self._events, event_name, delay,
- next_due, **kwargs)
- if id:
- timer.id = id
- elif persist:
- self.timer_setting(timer)
- self.timers.append(timer)
- def next_timer(self):
- next = None
- for timer in self.timers:
- time_left = timer.time_left()
- if next == None or time_left < next:
- next = time_left
- if next == None:
- return None
- if next < 0:
- return 0
- return next
- def call_timers(self):
- for timer in self.timers[:]:
- if timer.due():
- timer.call()
- if timer.done():
- self.timer_setting_remove(timer)
def next_send(self):
next = None
for server in self.servers.values():
@@ -110,7 +74,7 @@ class Bot(object):
def get_poll_timeout(self):
timeouts = []
- timeouts.append(self.next_timer())
+ timeouts.append(self.timers.next())
timeouts.append(self.next_send())
timeouts.append(self.next_ping())
timeouts.append(self.next_read_timeout())
@@ -154,7 +118,8 @@ class Bot(object):
while self.running:
self.lock.acquire()
events = self.poll.poll(self.get_poll_timeout())
- self.call_timers()
+ self.timers.call()
+
for fd, event in events:
if fd in self.servers:
server = self.servers[fd]
@@ -185,7 +150,7 @@ class Bot(object):
self.disconnect(server)
reconnect_delay = self.config.get("reconnect-delay", 10)
- self.add_timer("reconnect", reconnect_delay, None, None, False,
+ self.timers.add("reconnect", reconnect_delay,
server_id=server.id)
print("disconnected from %s, reconnecting in %d seconds" % (
diff --git a/src/IRCLineHandler.py b/src/IRCLineHandler.py
index 7d2336dc..805daa20 100644
--- a/src/IRCLineHandler.py
+++ b/src/IRCLineHandler.py
@@ -14,9 +14,9 @@ CAPABILITIES = {"multi-prefix", "chghost", "invite-notify", "account-tag",
"batch", "draft/labeled-response"}
class LineHandler(object):
- def __init__(self, bot, events):
- self.bot = bot
+ def __init__(self, events, timers):
self.events = events
+ self.timers = timers
events.on("raw.PING").hook(self.ping)
events.on("raw.001").hook(self.handle_001, default_event=True)
@@ -570,10 +570,9 @@ class LineHandler(object):
# we need a registered nickname for this channel
def handle_477(self, event):
channel_name = Utils.irc_lower(event["server"], event["args"][1])
- if channel_name in event["server"].attempted_join:
- self.bot.add_timer("rejoin", 5,
- channel_name=event["args"][1],
- key=event["server"].attempted_join[channel_name],
+ if channel_name in event["server"]:
+ key = event["server"].attempted_join[channel_name]
+ self.timers.add("rejoin", 5, channel_name=channe_name, key=key,
server_id=event["server"].id)
# someone's been kicked from a channel
diff --git a/src/ModuleManager.py b/src/ModuleManager.py
index 7787d9f3..c9aa650a 100644
--- a/src/ModuleManager.py
+++ b/src/ModuleManager.py
@@ -27,13 +27,16 @@ class BaseModule(object):
self.exports = exports
class ModuleManager(object):
- def __init__(self, bot, events, exports, directory):
- self.bot = bot
+ def __init__(self, events, exports, config, log, directory):
self.events = events
self.exports = exports
+ self.config = config
+ self.log = log
self.directory = directory
+
self.modules = {}
self.waiting_requirement = {}
+
def list_modules(self):
return sorted(glob.glob(os.path.join(self.directory, "*.py")))
@@ -47,7 +50,7 @@ class ModuleManager(object):
def _get_magic(self, obj, magic, default):
return getattr(obj, magic) if hasattr(obj, magic) else default
- def _load_module(self, name):
+ def _load_module(self, bot, name):
path = self._module_path(name)
with io.open(path, mode="r", encoding="utf8") as module_file:
@@ -61,8 +64,7 @@ class ModuleManager(object):
raise ModuleNotLoadedWarning("module ignored")
elif line_split[0] == "#--require-config" and len(
line_split) > 1:
- if not line_split[1].lower() in self.bot.config or not self.bot.config[
- line_split[1].lower()]:
+ if not self.config.get(line_split[1].lower(), None):
# nope, required config option not present.
raise ModuleNotLoadedWarning(
"required config not present")
@@ -88,8 +90,7 @@ class ModuleManager(object):
context = str(uuid.uuid4())
context_events = self.events.new_context(context)
context_exports = self.exports.new_context(context)
- module_object = module.Module(self.bot, context_events,
- context_exports)
+ module_object = module.Module(bot, context_events, context_exports)
if not hasattr(module_object, "_name"):
module_object._name = name.title()
@@ -109,29 +110,29 @@ class ModuleManager(object):
"attempted to be used twice")
return module_object
- def load_module(self, name):
+ def load_module(self, bot, name):
try:
- module = self._load_module(name)
+ module = self._load_module(bot, name)
except ModuleWarning as warning:
- self.bot.log.error("Module '%s' not loaded", [name])
+ self.log.error("Module '%s' not loaded", [name])
raise
except Exception as e:
- self.bot.log.error("Failed to load module \"%s\": %s",
+ self.log.error("Failed to load module \"%s\": %s",
[name, str(e)])
raise
self.modules[module._import_name] = module
if name in self.waiting_requirement:
for requirement_name in self.waiting_requirement:
- self.load_module(requirement_name)
- self.bot.log.info("Module '%s' loaded", [name])
+ self.load_module(bot, requirement_name)
+ self.log.info("Module '%s' loaded", [name])
- def load_modules(self, whitelist=[], blacklist=[]):
+ def load_modules(self, bot, whitelist=[], blacklist=[]):
for path in self.list_modules():
name = self._module_name(path)
if name in whitelist or (not whitelist and not name in blacklist):
try:
- self.load_module(name)
+ self.load_module(bot, name)
except ModuleWarning:
pass
@@ -151,5 +152,5 @@ class ModuleManager(object):
references -= 1 # 'del module' removes one reference
references -= 1 # one of the refs is from getrefcount
- self.bot.log.info("Module '%s' unloaded (%d reference%s)",
+ self.log.info("Module '%s' unloaded (%d reference%s)",
[name, references, "" if references == 1 else "s"])
diff --git a/src/Timer.py b/src/Timer.py
deleted file mode 100644
index 7ac83630..00000000
--- a/src/Timer.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import time, uuid
-
-class Timer(object):
- def __init__(self, id, bot, events, event_name, delay,
- next_due=None, **kwargs):
- self.id = id
- self.bot = bot
- self.events = events
- self.event_name = event_name
- self.delay = delay
- if next_due:
- self.next_due = next_due
- else:
- self.set_next_due()
- self.kwargs = kwargs
- self._done = False
- self.call_count = 0
-
- def set_next_due(self):
- self.next_due = time.time()+self.delay
-
- def due(self):
- return self.time_left() <= 0
-
- def time_left(self):
- return self.next_due-time.time()
-
- def call(self):
- self._done = True
- self.call_count +=1
- self.events.on("timer").on(self.event_name).call(
- timer=self, **self.kwargs)
-
- def redo(self):
- self._done = False
- self.set_next_due()
-
- def done(self):
- return self._done
diff --git a/src/Timers.py b/src/Timers.py
new file mode 100644
index 00000000..d3a297df
--- /dev/null
+++ b/src/Timers.py
@@ -0,0 +1,73 @@
+import time, uuid
+
+class Timer(object):
+ def __init__(self, id, name, delay, next_due, kwargs):
+ self.id = id
+ self.name = name
+ self.delay = delay
+ if next_due:
+ self.next_due = next_due
+ else:
+ self.set_next_due()
+ self.kwargs = kwargs
+ self._done = False
+
+ def set_next_due(self):
+ self.next_due = time.time()+self.delay
+ def due(self):
+ return self.time_left() <= 0
+ def time_left(self):
+ return self.next_due-time.time()
+
+ def redo(self):
+ self._done = False
+ self.set_next_due()
+ def finish():
+ self._done = True
+ def done(self):
+ return self._done
+
+class Timers(object):
+ def __init__(self, events, log):
+ self.events = events
+ self.log = log
+ self.timers = []
+
+ def setup(self, timers):
+ for name, timer in timers:
+ id = name.split("timer-", 1)[1]
+ self._add(timer["name"], timer["delay"], timer[
+ "next-due"], id, False, timer["kwargs"])
+
+ def _persist(self, timer):
+ self.set_setting("timer-%s" % timer.id, {
+ "name": timer.name, "delay": timer.delay,
+ "next-due": timer.next_due, "kwargs": timer.kwargs})
+ def _remove(self, timer):
+ self.timers.remove(timer)
+ self.del_setting("timer-%s" % timer.id)
+
+ def add(self, name, delay, next_due=None, **kwargs):
+ self._add(name, delay, next_due, None, False, kwargs)
+ def add_persistent(self, name, delay, next_due=None, **kwargs):
+ self._add(name, delay, next_due, None, True, kwargs)
+ def _add(self, name, delay, next_due, id, persist, kwargs):
+ id = id or uuid.uuid4().hex
+ timer = Timer(id, name, delay, next_due, kwargs)
+ if persist:
+ self._persist(timer)
+ self.timers.append(timer)
+
+ def next(self):
+ times = filter(None, [timer.time_left() for timer in self.timers])
+ if not times:
+ return None
+ return max(min(times), 0)
+
+ def call(self):
+ for timer in self.timers[:]:
+ if timer.due():
+ timer.finish()
+ self.events.on("timer.%s" % timer.name, timer=timer)
+ if timer.done():
+ self._remove(timer)