aboutsummaryrefslogtreecommitdiff
path: root/markov
diff options
context:
space:
mode:
authorArjun Satarkar <me@arjunsatarkar.net>2024-03-12 18:45:03 +0000
committerArjun Satarkar <me@arjunsatarkar.net>2024-03-12 18:45:03 +0000
commit228565b70ba3d1015ee85459c508f72ab363be0f (patch)
tree1d6e7b21500a6e073414d1fa0278e96c719c27ff /markov
parent529a23180fbcd9ca09ef55a9de2d4faa646aae55 (diff)
downloadaps-cogs-228565b70ba3d1015ee85459c508f72ab363be0f.tar
aps-cogs-228565b70ba3d1015ee85459c508f72ab363be0f.tar.gz
aps-cogs-228565b70ba3d1015ee85459c508f72ab363be0f.zip
markov: reduce code duplication, add command to nuke guild data
This also fixes bugs related to inconsistent handling of per-member and whole-guild generation.
Diffstat (limited to 'markov')
-rw-r--r--markov/errors.py12
-rw-r--r--markov/markov.py291
2 files changed, 169 insertions, 134 deletions
diff --git a/markov/errors.py b/markov/errors.py
index 048aa5b..1af7de8 100644
--- a/markov/errors.py
+++ b/markov/errors.py
@@ -1,2 +1,14 @@
class MarkovGenerationError(Exception):
pass
+
+
+class NoTotalCompletionCountError(MarkovGenerationError):
+ pass
+
+
+class NoNextTokenError(MarkovGenerationError):
+ pass
+
+
+class InvalidCompletionCountError(MarkovGenerationError):
+ pass
diff --git a/markov/markov.py b/markov/markov.py
index fd7a501..19822a1 100644
--- a/markov/markov.py
+++ b/markov/markov.py
@@ -40,6 +40,7 @@ class Markov(commands.Cog):
) as setup_script_file:
async with aiosqlite.connect(self.db_path) as db:
await db.executescript(setup_script_file.read())
+ await db.commit()
@commands.Cog.listener()
async def on_message_without_command(self, message):
@@ -354,6 +355,36 @@ class Markov(commands.Cog):
await self.exclusion_list(ctx, ExclusionType.IGNORE)
@markov.command()
+ @commands.admin_or_permissions(manage_guild=True)
+ async def delete_guild_data(self, ctx, confirmation: str | None):
+ if confirmation != "YES_DELETE_IT_ALL":
+ await ctx.reply(
+ "This will delete **all** markov data for this Discord server."
+ " Rerun this as `markov delete_guild_data YES_DELETE_IT_ALL`"
+ " if you are sure."
+ )
+ return
+ async with aiosqlite.connect(self.db_path) as db:
+ await db.execute(
+ "DELETE FROM guild_total_completion_count WHERE guild_id = ?;",
+ (self.uint_to_bytes(ctx.guild.id),),
+ )
+ await db.execute(
+ "DELETE FROM guild_pairs WHERE guild_id = ?;",
+ (self.uint_to_bytes(ctx.guild.id),),
+ )
+ await db.execute(
+ "DELETE FROM member_total_completion_count WHERE guild_id = ?;",
+ (self.uint_to_bytes(ctx.guild.id),),
+ )
+ await db.execute(
+ "DELETE FROM member_pairs WHERE guild_id = ?;",
+ (self.uint_to_bytes(ctx.guild.id),),
+ )
+ await db.commit()
+ await ctx.reply("All markov data for this guild has been deleted.")
+
+ @markov.command()
async def generate(self, ctx, member: discord.Member | None):
if not await self.config.guild(ctx.guild).use_messages():
await ctx.reply("Not enabled in this guild.")
@@ -363,145 +394,137 @@ class Markov(commands.Cog):
await ctx.reply("That member has opted out of markov generation.")
return
+ async def get_total_completion_count(
+ db: aiosqlite.Connection,
+ guild_id: int,
+ member_id: int | None,
+ first_token: str,
+ ):
+ if not member_id:
+ row = await (
+ await db.execute(
+ "SELECT total_completion_count FROM guild_total_completion_count"
+ " WHERE guild_id = ? AND first_token = ?;",
+ (self.uint_to_bytes(guild_id), first_token),
+ )
+ ).fetchone()
+ else:
+ row = await (
+ await db.execute(
+ "SELECT total_completion_count FROM member_total_completion_count"
+ " WHERE guild_id = ? AND member_id = ? AND first_token = ?;",
+ (
+ self.uint_to_bytes(guild_id),
+ self.uint_to_bytes(member_id),
+ first_token,
+ ),
+ )
+ ).fetchone()
+ return row[0] if row else None
+
+ async def get_possible_next_token(
+ db: aiosqlite.Connection,
+ guild_id: int,
+ member_id: int | None,
+ first_token: str,
+ offset: int,
+ ):
+ if not member_id:
+ row = await (
+ await db.execute(
+ "SELECT second_token, frequency FROM guild_pairs"
+ " WHERE guild_id = ? AND first_token = ?"
+ " ORDER BY frequency DESC LIMIT 1 OFFSET ?;",
+ (self.uint_to_bytes(guild_id), first_token, offset),
+ )
+ ).fetchone()
+ else:
+ row = await (
+ await db.execute(
+ "SELECT second_token, frequency FROM member_pairs"
+ " WHERE guild_id = ? AND member_id = ? AND first_token = ?"
+ " ORDER BY frequency DESC LIMIT 1 OFFSET ?;",
+ (
+ self.uint_to_bytes(guild_id),
+ self.uint_to_bytes(member_id),
+ first_token,
+ offset,
+ ),
+ )
+ ).fetchone()
+ if not row:
+ return None, None
+ next_token, frequency = row
+ return next_token, frequency
+
# NOTE: if changing PUNCTUATION, also change the regex in process_message() with the corresponding note
PUNCTUATION = ".,!?/"
- if member is None:
- result = ""
- token = ""
- async with aiosqlite.connect(self.db_path) as db:
- while True:
- row = await (
- await db.execute(
- "SELECT total_completion_count FROM guild_total_completion_count"
- " WHERE guild_id = ? AND first_token = ?;",
- (self.uint_to_bytes(ctx.guild.id), token),
- )
- ).fetchone()
- if row is None:
- if token == "":
- await ctx.reply("Error: no data for this guild yet!")
- return
- raise MarkovGenerationError(
- "Table guild_total_completion_count had no row for token"
- f" {repr(token)} for guild {ctx.guild.id} - this should never happen!"
- )
- completion_count = row[0]
-
- for i in range(MAX_TOKEN_GENERATION_ITERATIONS):
- row = await (
- await db.execute(
- "SELECT second_token, frequency FROM guild_pairs"
- " WHERE guild_id = ? AND first_token = ?"
- " ORDER BY frequency DESC LIMIT 1 OFFSET ?;",
- (self.uint_to_bytes(ctx.guild.id), token, i),
- )
- ).fetchone()
- if row is None:
- raise MarkovGenerationError(
- "There was no completion in guild_pairs for token"
- f" {repr(token)} for guild {ctx.guild.id} on iteration {i}"
- " - this should never happen!"
- )
- next_token, frequency = row
-
- if random.randint(1, completion_count) <= frequency:
- if next_token == "/":
- result = result[:-1] + next_token
- elif next_token in PUNCTUATION:
- result = result[:-1] + next_token + " "
- else:
- result += next_token + " "
- token = next_token
- break
-
- completion_count -= frequency
- if completion_count <= 0:
- raise MarkovGenerationError(
- "Sum of all frequencies in guild_pairs for token"
- f" {repr(token)} in guild {ctx.guild.id} added up"
- " to more than completion_count or we failed to"
- " choose a completion despite trying all of them"
- " This should never happen!"
- )
- else:
- token = ""
-
+ member_id = member.id if member else None
+ result = ""
+ token = ""
+ async with aiosqlite.connect(self.db_path) as db:
+ while True:
+ completion_count = await get_total_completion_count(
+ db, ctx.guild.id, member_id, token
+ )
+ if completion_count is None:
if token == "":
- break
- await ctx.send(result, allowed_mentions=discord.AllowedMentions.none())
- else:
- result = ""
- token = ""
- async with aiosqlite.connect(self.db_path) as db:
- while True:
- row = await (
- await db.execute(
- "SELECT total_completion_count FROM member_total_completion_count"
- " WHERE guild_id = ? AND member_id = ? AND first_token = ?;",
- (
- self.uint_to_bytes(ctx.guild.id),
- self.uint_to_bytes(member.id),
- token,
- ),
+ await ctx.reply(
+ f"Error: no data for this {'member' if member else 'guild'} yet!"
)
- ).fetchone()
- if row is None:
- if token == "":
- await ctx.reply("Error: no data for this member yet!")
- return
- raise MarkovGenerationError(
- "Table member_total_completion_count had no row for token"
- f" {repr(token)} for guild {ctx.guild.id} member {member.id}"
- " - this should never happen!"
+ return
+ raise NoTotalCompletionCountError(
+ repr(
+ {
+ "guild_id": ctx.guild.id,
+ "member_id": member_id,
+ "token": token,
+ }
)
- completion_count = row[0]
-
- next_token = None
- for i in range(MAX_TOKEN_GENERATION_ITERATIONS):
- row = await (
- await db.execute(
- "SELECT second_token, frequency FROM member_pairs"
- " WHERE guild_id = ? AND member_id = ? AND first_token = ?"
- " ORDER BY frequency DESC LIMIT 1 OFFSET ?;",
- (
- self.uint_to_bytes(ctx.guild.id),
- self.uint_to_bytes(member.id),
- token,
- i,
- ),
+ )
+ next_token = None
+ for i in range(MAX_TOKEN_GENERATION_ITERATIONS):
+ next_token, frequency = await get_possible_next_token(
+ db, ctx.guild.id, member_id, token, i
+ )
+ if next_token is None:
+ raise NoNextTokenError(
+ repr(
+ {
+ "guild_id": ctx.guild.id,
+ "member_id": member_id,
+ "token": token,
+ "offset": i,
+ }
)
- ).fetchone()
- if row is None:
- raise MarkovGenerationError(
- "There was no completion in guild_pairs for token"
- f" {repr(token)} for guild {ctx.guild.id} member {member.id}"
- f" on iteration {i} - this should never happen!"
- )
- next_token, frequency = row
-
- if random.randint(1, completion_count) <= frequency:
- if next_token in PUNCTUATION:
- result = result[:-1] + next_token + " "
- else:
- result += next_token + " "
- token = next_token
- break
-
- completion_count -= frequency
- if completion_count <= 0:
- raise MarkovGenerationError(
- "Sum of all frequencies in guild_pairs for token"
- f" {repr(token)} for guild {ctx.guild.id} member"
- f" {member.id} added up to more than completion_count"
- " or we failed to choose a completion despite trying"
- " all of them. This should never happen!"
- )
- else:
- # If we went through MAX_TOKEN_GENERATION_ITERATIONS completions
- # without selecting any, then just select the last one we considered
- # (round off the probability, effectively).
+ )
+ if random.randint(1, completion_count) <= frequency:
+ if next_token == "/":
+ result = result[:-1] + next_token
+ elif next_token in PUNCTUATION:
+ result = result[:-1] + next_token + " "
+ else:
+ result += next_token + " "
token = next_token
-
- if token == "":
break
- await ctx.send(result, allowed_mentions=discord.AllowedMentions.none())
+
+ completion_count -= frequency
+ if completion_count <= 0:
+ raise InvalidCompletionCountError(
+ repr(
+ {
+ "guild_id": ctx.guild.id,
+ "member_id": member_id,
+ "token": token,
+ "offset": i,
+ }
+ )
+ )
+ else:
+ # If we went through MAX_TOKEN_GENERATION_ITERATIONS completions
+ # without selecting any, then just select the last one we considered
+ # (round off the probability, effectively)
+ token = next_token
+ if token == "":
+ break
+ await ctx.send(result, allowed_mentions=discord.AllowedMentions.none())