aboutsummaryrefslogtreecommitdiff
path: root/modules/markov.py
blob: c5f9ff555d9a64df608d24b02f5344e4ce3c4fe5 (about) (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import random
from src import ModuleManager, utils

# Error text written/raised by the markov commands when the target channel's
# "markov" setting is not enabled.
NO_MARKOV = "Markov chains not enabled in this channel"

class Module(ModuleManager.BaseModule):
    """Learn per-channel word-trigram Markov chains and generate sentences.

    Each row of the ``markov`` table counts how often ``third_word`` followed
    the pair (``first_word``, ``second_word``) in a channel. A NULL
    ``first_word``/``second_word`` marks the start of a message and a NULL
    ``third_word`` marks its end.
    """

    def on_load(self):
        """Create the ``markov`` table if it does not exist yet."""
        if not self.bot.database.has_table("markov"):
            # third_word must be part of the key: keyed only on
            # (channel_id, first_word, second_word), every new continuation
            # of a word pair replaced the previous one, so a pair could only
            # ever lead to a single word.
            # NOTE(review): tables created before this change keep the old
            # key and need a manual migration to benefit from it.
            self.bot.database.execute("""CREATE TABLE markov
                (channel_id INTEGER, first_word TEXT, second_word TEXT,
                third_word TEXT, frequency INT,
                FOREIGN KEY (channel_id) REFERENCES channels(channel_id),
                PRIMARY KEY (channel_id, first_word, second_word,
                third_word))""")

    @utils.hook("received.message.channel")
    def channel_message(self, event):
        """Feed every message in a markov-enabled channel into its chain."""
        if event["channel"].get_setting("markov", False):
            self._create(event["channel"].id, event["message_split"])

    @utils.hook("received.command.markovlog")
    @utils.kwarg("min_args", 1)
    @utils.kwarg("channel_only", True)
    @utils.kwarg("permission", "markovlog")
    @utils.kwarg("help", "Load a message-only newline-delimited log in to this "
        "channel's markov chain")
    @utils.kwarg("usage", "<url>")
    def load_log(self, event):
        """Fetch a log from the given URL and learn from each of its lines.

        Raises:
            utils.EventError: markov is not enabled in this channel.
        """
        if not event["target"].get_setting("markov", False):
            raise utils.EventError(NO_MARKOV)

        page = utils.http.request(event["args_split"][0])
        if page.code != 200:
            event["stderr"].write("Failed to load log (%d)" % page.code)
            return

        channel_id = event["target"].id
        for line in page.data.decode("utf8").split("\n"):
            # drop the "\r" left over from CRLF line endings
            self._create(channel_id, line.strip("\r").split(" "))
        event["stdout"].write("Log imported")

    def _create(self, channel_id, words):
        """Record the word trigrams of one message for a channel.

        Messages with fewer than three non-empty words are ignored. Besides
        the inner trigrams, two start markers (NULL, NULL, w0) and
        (NULL, w0, w1) and one end marker (w[-2], w[-1], NULL) are stored.

        :param channel_id: id of the channel the message belongs to
        :param words: the message split on spaces; empty entries are dropped
            and the rest are lower-cased
        """
        words = [word.lower() for word in words if word]
        if len(words) < 3:
            return

        inserts = [[None, None, words[0]], [None, words[0], words[1]]]
        inserts.extend(words[i:i+3] for i in range(len(words)-2))
        inserts.append([words[-2], words[-1], None])

        for insert in inserts:
            self._increment(channel_id, insert)

    def _increment(self, channel_id, trigram):
        """Add one to the stored frequency of ``trigram``, creating it at 1.

        :param channel_id: id of the channel the trigram belongs to
        :param trigram: [first_word, second_word, third_word]; any of them
            may be None (start/end markers)
        """
        # "IS" rather than "=": "col = NULL" is never true in SQL, so start
        # and end markers were never found and always restarted at 1.
        existing = self.bot.database.execute_fetchone("""SELECT rowid
            FROM markov WHERE channel_id=? AND first_word IS ? AND
            second_word IS ? AND third_word IS ?""", [channel_id]+trigram)
        if existing:
            # Update in place: SQLite treats NULLs in a key as distinct, so
            # INSERT OR REPLACE would add a duplicate marker row instead of
            # replacing the existing one.
            self.bot.database.execute(
                "UPDATE markov SET frequency=frequency+1 WHERE rowid=?",
                [existing[0]])
        else:
            # OR REPLACE keeps the previous behaviour on tables that still
            # use the old (channel_id, first_word, second_word) key.
            self.bot.database.execute(
                "INSERT OR REPLACE INTO markov VALUES (?, ?, ?, ?, ?)",
                [channel_id]+trigram+[1])

    def _choose(self, words):
        """Pick one word from non-empty (word, frequency) rows, weighted by
        frequency."""
        words, frequencies = list(zip(*words))
        return random.choices(words, weights=frequencies, k=1)[0]

    @utils.hook("received.command.markov")
    @utils.kwarg("channel_only", True)
    @utils.kwarg("help", "Generate a markov chain for the current channel")
    def markov(self, event):
        """Generate a sentence from the current channel's chain."""
        self._markov_for(event["target"], event["stdout"], event["stderr"])

    @utils.hook("received.command.markovfor")
    @utils.kwarg("min_args", 1)
    @utils.kwarg("permission", "markovfor")
    @utils.kwarg("help", "Generate a markov chain for a given channel")
    @utils.kwarg("usage", "<channel>")
    def markov_for(self, event):
        """Generate a sentence from the chain of a named channel."""
        if event["args_split"][0] in event["server"].channels:
            channel = event["server"].channels.get(event["args_split"][0])
            self._markov_for(channel, event["stdout"], event["stderr"])
        else:
            event["stderr"].write("Unknown channel")

    def _markov_for(self, channel, stdout, stderr):
        """Write a generated sentence for ``channel`` to ``stdout``, or an
        explanation to ``stderr`` when one cannot be produced."""
        if not channel.get_setting("markov", False):
            stderr.write(NO_MARKOV)
            return

        out = self._generate(channel.id)
        if out is not None:
            stdout.write(out)
        else:
            stderr.write("Failed to generate markov chain")

    def _generate(self, channel_id):
        """Walk the channel's chain from a random start to build a sentence.

        :param channel_id: id of the channel whose chain is used
        :returns: the generated words joined by single spaces, or None when
            no start words are stored for the channel
        """
        first_words = self.bot.database.execute_fetchall("""SELECT third_word,
            frequency FROM markov WHERE channel_id=? AND first_word IS NULL AND
            second_word IS NULL AND third_word IS NOT NULL""", [channel_id])
        if not first_words:
            return None
        first_word = self._choose(first_words)

        second_words = self.bot.database.execute_fetchall("""SELECT third_word,
            frequency FROM markov WHERE channel_id=? AND first_word IS NULL AND
            second_word=? AND third_word IS NOT NULL""",
            [channel_id, first_word])
        if not second_words:
            return None
        second_word = self._choose(second_words)

        words = [first_word, second_word]
        # at most 30 more words after the opening pair
        for _ in range(30):
            third_words = self.bot.database.execute_fetchall("""SELECT
                third_word, frequency FROM markov WHERE channel_id=? AND
                first_word=? AND second_word=?""", [channel_id]+words[-2:])
            # _choose cannot handle an empty result (it unpacks zip(*rows)),
            # so a pair with no stored continuation ends the sentence
            if not third_words:
                break

            third_word = self._choose(third_words)
            if third_word is None:
                # NULL third_word is the end-of-message marker
                break
            words.append(third_word)

        return " ".join(words)