aboutsummaryrefslogtreecommitdiff
path: root/markov
diff options
context:
space:
mode:
authorArjun Satarkar <me@arjunsatarkar.net>2024-03-20 03:07:54 +0000
committerArjun Satarkar <me@arjunsatarkar.net>2024-03-20 03:07:54 +0000
commit16aeaa50db79e3c17194128a93d22370bf5daa33 (patch)
treecbcef052ca3afb49274513446514438fa1ed1604 /markov
parent92a0d3eb4fd919ff9ed33b81d379ccb77af43026 (diff)
downloadaps-cogs-16aeaa50db79e3c17194128a93d22370bf5daa33.tar
aps-cogs-16aeaa50db79e3c17194128a93d22370bf5daa33.tar.gz
aps-cogs-16aeaa50db79e3c17194128a93d22370bf5daa33.zip
markov: avoid float division in uint_to_bytes
It seems as though the float division could cause an issue where the actual value should be an integer but the float representation is slightly higher than that, causing the addition of an extra unnecessary byte. Although this works fine as long as the same function is used, it could cause problems if seemingly inconsequential details change. I switch to using divmod and checking for the presence of a remainder. This should ensure use of the exact minimum required number of bytes in the blob.
Diffstat (limited to 'markov')
-rw-r--r--markov/markov.py31
1 files changed, 17 insertions, 14 deletions
diff --git a/markov/markov.py b/markov/markov.py
index e206800..73ac08e 100644
--- a/markov/markov.py
+++ b/markov/markov.py
@@ -151,13 +151,29 @@ class Markov(commands.Cog):
def uint_to_bytes(self, x: int):
if x < 0:
raise ValueError(f"x must be non-negative (got {x})")
- return x.to_bytes(math.ceil(x.bit_length() / 8), byteorder="big", signed=False)
+ byte_length, remainder = divmod(x.bit_length(), 8)
+ if remainder:
+ byte_length += 1
+ return x.to_bytes(byte_length, byteorder="big", signed=False)
def get_base_channel(self, channel_or_thread):
if isinstance(channel_or_thread, discord.Thread):
return channel_or_thread.parent
return channel_or_thread
+ def append_token(self, text, token):
+ # NOTE: if changing PUNCTUATION, also change the regex in process_message() with the corresponding note
+ PUNCTUATION = r".,!?/;()"
+ if token == "/":
+ text = text[:-1] + token
+ elif token == "(":
+ text += token
+ elif token in PUNCTUATION:
+ text = text[:-1] + token + " "
+ else:
+ text += token + " "
+ return text
+
@commands.group()
async def markov(self, _ctx):
"""
@@ -385,19 +401,6 @@ class Markov(commands.Cog):
await db.commit()
await ctx.reply("All markov data for this guild has been deleted.")
- def append_token(self, text, token):
- # NOTE: if changing PUNCTUATION, also change the regex in process_message() with the corresponding note
- PUNCTUATION = r".,!?/;()"
- if token == "/":
- text = text[:-1] + token
- elif token == "(":
- text += token
- elif token in PUNCTUATION:
- text = text[:-1] + token + " "
- else:
- text += token + " "
- return text
-
@markov.command()
async def generate(self, ctx, member: discord.Member | None):
if not await self.config.guild(ctx.guild).use_messages():