diff options
author | Arjun Satarkar <me@arjunsatarkar.net> | 2024-03-12 18:45:03 +0000 |
---|---|---|
committer | Arjun Satarkar <me@arjunsatarkar.net> | 2024-03-12 18:45:03 +0000 |
commit | 228565b70ba3d1015ee85459c508f72ab363be0f (patch) | |
tree | 1d6e7b21500a6e073414d1fa0278e96c719c27ff /markov | |
parent | 529a23180fbcd9ca09ef55a9de2d4faa646aae55 (diff) | |
download | aps-cogs-228565b70ba3d1015ee85459c508f72ab363be0f.tar aps-cogs-228565b70ba3d1015ee85459c508f72ab363be0f.tar.gz aps-cogs-228565b70ba3d1015ee85459c508f72ab363be0f.zip |
markov: reduce code duplication, add command to nuke guild data
This also fixes bugs related to inconsistent handling of per-member
and whole-guild generation.
Diffstat (limited to 'markov')
-rw-r--r-- | markov/errors.py | 12 | ||||
-rw-r--r-- | markov/markov.py | 291 |
2 files changed, 169 insertions, 134 deletions
diff --git a/markov/errors.py b/markov/errors.py index 048aa5b..1af7de8 100644 --- a/markov/errors.py +++ b/markov/errors.py @@ -1,2 +1,14 @@ class MarkovGenerationError(Exception): pass + + +class NoTotalCompletionCountError(MarkovGenerationError): + pass + + +class NoNextTokenError(MarkovGenerationError): + pass + + +class InvalidCompletionCountError(MarkovGenerationError): + pass diff --git a/markov/markov.py b/markov/markov.py index fd7a501..19822a1 100644 --- a/markov/markov.py +++ b/markov/markov.py @@ -40,6 +40,7 @@ class Markov(commands.Cog): ) as setup_script_file: async with aiosqlite.connect(self.db_path) as db: await db.executescript(setup_script_file.read()) + await db.commit() @commands.Cog.listener() async def on_message_without_command(self, message): @@ -354,6 +355,36 @@ class Markov(commands.Cog): await self.exclusion_list(ctx, ExclusionType.IGNORE) @markov.command() + @commands.admin_or_permissions(manage_guild=True) + async def delete_guild_data(self, ctx, confirmation: str | None): + if confirmation != "YES_DELETE_IT_ALL": + await ctx.reply( + "This will delete **all** markov data for this Discord server." + " Rerun this as `markov delete_guild_data YES_DELETE_IT_ALL`" + " if you are sure." + ) + return + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + "DELETE FROM guild_total_completion_count WHERE guild_id = ?;", + (self.uint_to_bytes(ctx.guild.id),), + ) + await db.execute( + "DELETE FROM guild_pairs WHERE guild_id = ?;", + (self.uint_to_bytes(ctx.guild.id),), + ) + await db.execute( + "DELETE FROM member_total_completion_count WHERE guild_id = ?;", + (self.uint_to_bytes(ctx.guild.id),), + ) + await db.execute( + "DELETE FROM member_pairs WHERE guild_id = ?;", + (self.uint_to_bytes(ctx.guild.id),), + ) + await db.commit() + await ctx.reply("All markov data for this guild has been deleted.") + + @markov.command() async def generate(self, ctx, member: discord.Member | None): if not await self.config.guild(ctx.guild).use_messages(): await ctx.reply("Not enabled in this guild.") @@ -363,145 +394,137 @@ class Markov(commands.Cog): await ctx.reply("That member has opted out of markov generation.") return + async def get_total_completion_count( + db: aiosqlite.Connection, + guild_id: int, + member_id: int | None, + first_token: str, + ): + if not member_id: + row = await ( + await db.execute( + "SELECT total_completion_count FROM guild_total_completion_count" + " WHERE guild_id = ? AND first_token = ?;", + (self.uint_to_bytes(guild_id), first_token), + ) + ).fetchone() + else: + row = await ( + await db.execute( + "SELECT total_completion_count FROM member_total_completion_count" + " WHERE guild_id = ? AND member_id = ? AND first_token = ?;", + ( + self.uint_to_bytes(guild_id), + self.uint_to_bytes(member_id), + first_token, + ), + ) + ).fetchone() + return row[0] if row else None + + async def get_possible_next_token( + db: aiosqlite.Connection, + guild_id: int, + member_id: int | None, + first_token: str, + offset: int, + ): + if not member_id: + row = await ( + await db.execute( + "SELECT second_token, frequency FROM guild_pairs" + " WHERE guild_id = ? AND first_token = ?" + " ORDER BY frequency DESC LIMIT 1 OFFSET ?;", + (self.uint_to_bytes(guild_id), first_token, offset), + ) + ).fetchone() + else: + row = await ( + await db.execute( + "SELECT second_token, frequency FROM member_pairs" + " WHERE guild_id = ? AND member_id = ? AND first_token = ?" + " ORDER BY frequency DESC LIMIT 1 OFFSET ?;", + ( + self.uint_to_bytes(guild_id), + self.uint_to_bytes(member_id), + first_token, + offset, + ), + ) + ).fetchone() + if not row: + return None, None + next_token, frequency = row + return next_token, frequency + # NOTE: if changing PUNCTUATION, also change the regex in process_message() with the corresponding note PUNCTUATION = ".,!?/" - if member is None: - result = "" - token = "" - async with aiosqlite.connect(self.db_path) as db: - while True: - row = await ( - await db.execute( - "SELECT total_completion_count FROM guild_total_completion_count" - " WHERE guild_id = ? AND first_token = ?;", - (self.uint_to_bytes(ctx.guild.id), token), - ) - ).fetchone() - if row is None: - if token == "": - await ctx.reply("Error: no data for this guild yet!") - return - raise MarkovGenerationError( - "Table guild_total_completion_count had no row for token" - f" {repr(token)} for guild {ctx.guild.id} - this should never happen!" - ) - completion_count = row[0] - - for i in range(MAX_TOKEN_GENERATION_ITERATIONS): - row = await ( - await db.execute( - "SELECT second_token, frequency FROM guild_pairs" - " WHERE guild_id = ? AND first_token = ?" - " ORDER BY frequency DESC LIMIT 1 OFFSET ?;", - (self.uint_to_bytes(ctx.guild.id), token, i), - ) - ).fetchone() - if row is None: - raise MarkovGenerationError( - "There was no completion in guild_pairs for token" - f" {repr(token)} for guild {ctx.guild.id} on iteration {i}" - " - this should never happen!" - ) - next_token, frequency = row - - if random.randint(1, completion_count) <= frequency: - if next_token == "/": - result = result[:-1] + next_token - elif next_token in PUNCTUATION: - result = result[:-1] + next_token + " " - else: - result += next_token + " " - token = next_token - break - - completion_count -= frequency - if completion_count <= 0: - raise MarkovGenerationError( - "Sum of all frequencies in guild_pairs for token" - f" {repr(token)} in guild {ctx.guild.id} added up" - " to more than completion_count or we failed to" - " choose a completion despite trying all of them" - " This should never happen!" - ) - else: - token = "" - + member_id = member.id if member else None + result = "" + token = "" + async with aiosqlite.connect(self.db_path) as db: + while True: + completion_count = await get_total_completion_count( + db, ctx.guild.id, member_id, token + ) + if completion_count is None: if token == "": - break - await ctx.send(result, allowed_mentions=discord.AllowedMentions.none()) - else: - result = "" - token = "" - async with aiosqlite.connect(self.db_path) as db: - while True: - row = await ( - await db.execute( - "SELECT total_completion_count FROM member_total_completion_count" - " WHERE guild_id = ? AND member_id = ? AND first_token = ?;", - ( - self.uint_to_bytes(ctx.guild.id), - self.uint_to_bytes(member.id), - token, - ), + await ctx.reply( + f"Error: no data for this {'member' if member else 'guild'} yet!" ) - ).fetchone() - if row is None: - if token == "": - await ctx.reply("Error: no data for this member yet!") - return - raise MarkovGenerationError( - "Table member_total_completion_count had no row for token" - f" {repr(token)} for guild {ctx.guild.id} member {member.id}" - " - this should never happen!" + return + raise NoTotalCompletionCountError( + repr( + { + "guild_id": ctx.guild.id, + "member_id": member_id, + "token": token, + } ) - completion_count = row[0] - - next_token = None - for i in range(MAX_TOKEN_GENERATION_ITERATIONS): - row = await ( - await db.execute( - "SELECT second_token, frequency FROM member_pairs" - " WHERE guild_id = ? AND member_id = ? AND first_token = ?" - " ORDER BY frequency DESC LIMIT 1 OFFSET ?;", - ( - self.uint_to_bytes(ctx.guild.id), - self.uint_to_bytes(member.id), - token, - i, - ), + ) + next_token = None + for i in range(MAX_TOKEN_GENERATION_ITERATIONS): + next_token, frequency = await get_possible_next_token( + db, ctx.guild.id, member_id, token, i + ) + if next_token is None: + raise NoNextTokenError( + repr( + { + "guild_id": ctx.guild.id, + "member_id": member_id, + "token": token, + "offset": i, + } ) - ).fetchone() - if row is None: - raise MarkovGenerationError( - "There was no completion in guild_pairs for token" - f" {repr(token)} for guild {ctx.guild.id} member {member.id}" - f" on iteration {i} - this should never happen!" - ) - next_token, frequency = row - - if random.randint(1, completion_count) <= frequency: - if next_token in PUNCTUATION: - result = result[:-1] + next_token + " " - else: - result += next_token + " " - token = next_token - break - - completion_count -= frequency - if completion_count <= 0: - raise MarkovGenerationError( - "Sum of all frequencies in guild_pairs for token" - f" {repr(token)} for guild {ctx.guild.id} member" - f" {member.id} added up to more than completion_count" - " or we failed to choose a completion despite trying" - " all of them. This should never happen!" - ) - else: - # If we went through MAX_TOKEN_GENERATION_ITERATIONS completions - # without selecting any, then just select the last one we considered - # (round off the probability, effectively). + ) + if random.randint(1, completion_count) <= frequency: + if next_token == "/": + result = result[:-1] + next_token + elif next_token in PUNCTUATION: + result = result[:-1] + next_token + " " + else: + result += next_token + " " token = next_token - - if token == "": break - await ctx.send(result, allowed_mentions=discord.AllowedMentions.none()) + + completion_count -= frequency + if completion_count <= 0: + raise InvalidCompletionCountError( + repr( + { + "guild_id": ctx.guild.id, + "member_id": member_id, + "token": token, + "offset": i, + } + ) + ) + else: + # If we went through MAX_TOKEN_GENERATION_ITERATIONS completions + # without selecting any, then just select the last one we considered + # (round off the probability, effectively) + token = next_token + if token == "": + break + await ctx.send(result, allowed_mentions=discord.AllowedMentions.none()) |