diff options
author | Arjun Satarkar <me@arjunsatarkar.net> | 2024-03-19 18:12:44 +0000 |
---|---|---|
committer | Arjun Satarkar <me@arjunsatarkar.net> | 2024-03-19 18:12:44 +0000 |
commit | 92a0d3eb4fd919ff9ed33b81d379ccb77af43026 (patch) | |
tree | c2fc790a99456783d0ca0f8e5ccf2c1ebaace3ae | |
parent | 314597b069c4fcd0f7d6c4f7796b0048809e0c8a (diff) | |
download | aps-cogs-92a0d3eb4fd919ff9ed33b81d379ccb77af43026.tar aps-cogs-92a0d3eb4fd919ff9ed33b81d379ccb77af43026.tar.gz aps-cogs-92a0d3eb4fd919ff9ed33b81d379ccb77af43026.zip |
markov: support brackets, separate out append_token logic
-rw-r--r-- | markov/info.json | 2 | ||||
-rw-r--r-- | markov/markov.py | 26 |
2 files changed, 17 insertions, 11 deletions
diff --git a/markov/info.json b/markov/info.json index 22ddb18..6002815 100644 --- a/markov/info.json +++ b/markov/info.json @@ -2,5 +2,5 @@ "author": ["Arjun Satarkar"], "description": "Use Markov chains to mimic users or the server as a whole.", "short": "Markov chains based on message content.", - "requirements": ["aiosqlite"] + "requirements": ["aiosqlite", "more-itertools"] } diff --git a/markov/markov.py b/markov/markov.py index b750519..e206800 100644 --- a/markov/markov.py +++ b/markov/markov.py @@ -86,13 +86,13 @@ class Markov(commands.Cog): # Extract words, punctuation, custom emoji, and mentions as # individual tokens, then add a sentinel (empty string) on either end. - # NOTE: if changing the punctuation in the regex, also change PUNCTUATION in generate() + # NOTE: if changing the punctuation in the regex, also change PUNCTUATION in append_token() tokens = ( [""] + [ token for token in re.findall( - r"[\w']+|[\.,!?\/;]|<a?:\w+:\d+>|<#\d+>|<@!?\d+>", content + r"[\w']+|[\.,!?\/;\(\)]|<a?:\w+:\d+>|<#\d+>|<@!?\d+>", content ) if len(token) <= MAX_TOKEN_LENGTH ] @@ -385,6 +385,19 @@ class Markov(commands.Cog): await db.commit() await ctx.reply("All markov data for this guild has been deleted.") + def append_token(self, text, token): + # NOTE: if changing PUNCTUATION, also change the regex in process_message() with the corresponding note + PUNCTUATION = r".,!?/;()" + if token == "/": + text = text[:-1] + token + elif token == "(": + text += token + elif token in PUNCTUATION: + text = text[:-1] + token + " " + else: + text += token + " " + return text + @markov.command() async def generate(self, ctx, member: discord.Member | None): if not await self.config.guild(ctx.guild).use_messages(): @@ -458,8 +471,6 @@ class Markov(commands.Cog): next_token, frequency = row return next_token, frequency - # NOTE: if changing PUNCTUATION, also change the regex in process_message() with the corresponding note - PUNCTUATION = r".,!?/;" member_id = member.id if member else None result = "" token = "" @@ -483,12 +494,7 @@ class Markov(commands.Cog): if next_token is None: raise NoNextTokenError(ctx.guild.id, member_id, token, i) if random.randint(1, completion_count) <= frequency: - if next_token == "/": - result = result[:-1] + next_token - elif next_token in PUNCTUATION: - result = result[:-1] + next_token + " " - else: - result += next_token + " " + result = self.append_token(result, next_token) token = next_token break |