From e891d6fc1d1912c03c488846d486272fd010f3bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C4=83t=C4=83lin?= Date: Thu, 6 Mar 2025 18:10:35 +0100 Subject: [PATCH] feat: add retry capabilities to the bot --- pyproject.toml | 1 + src/huesoporro/bot.py | 60 +++++++++++++++++++++++++++++++++++++------ uv.lock | 11 ++++++++ 3 files changed, 64 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9a2888f..3a5c884 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "redis>=5.2.1", "pytz>=2024.2", "discord-py>=2.4.0", + "tenacity>=9.0.0", ] [project.scripts] diff --git a/src/huesoporro/bot.py b/src/huesoporro/bot.py index c33cb47..c470b1f 100644 --- a/src/huesoporro/bot.py +++ b/src/huesoporro/bot.py @@ -5,6 +5,12 @@ from enum import StrEnum from typing import ClassVar from loguru import logger +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from twitchio import Channel from twitchio.ext import commands, routines @@ -64,6 +70,8 @@ class Bot(commands.Bot): ) return await ctx.send(sentence) + # Wait for the specified time + await asyncio.sleep(60) @commands.command(aliases=["qadd"]) async def add_quote(self, ctx: commands.Context, *, quote: str): @@ -134,7 +142,7 @@ class HelloMessagesCog(commands.Cog): if message.content in self.hello_patterns: hello = self.hello_svc.run(message.author.name) if hello: - await message.channel_name.send(hello) + await message.channel.send(hello) class MessageType(StrEnum): @@ -181,12 +189,10 @@ class MessageHandler: class SaveMessagesCog(commands.Cog): - def __init__(self, bot): + def __init__(self, bot: Bot): self.bot = bot - self.store_svc = SentenceStorerSvc(db=MarkovDatabase(channel=bot.channel_name)) - self.generate_svc = SentenceGeneratorSvc( - db=MarkovDatabase(channel=bot.channel_name) - ) + self.store_svc = SentenceStorerSvc(db=MarkovDatabase(channel=bot.channel)) + self.generate_svc = SentenceGeneratorSvc(db=MarkovDatabase(channel=bot.channel)) self.backoff_svc = BackoffService() self.message_handler = MessageHandler(self._send_message) @@ -204,7 +210,7 @@ class SaveMessagesCog(commands.Cog): async def typed_send(content: str): if hasattr(self, "current_message"): - await self.current_message.channel_name.send(content) + await self.current_message.channel.send(content) # Set a unique name for the function to ensure it's treated as distinct typed_send.__name__ = f"send_{type_name}" @@ -213,7 +219,7 @@ class SaveMessagesCog(commands.Cog): async def _send_message(self, content: str): """Generic send message function (for non-backoff uses)""" if hasattr(self, "current_message"): - await self.current_message.channel_name.send(content) + await self.current_message.channel.send(content) @commands.Cog.event() async def event_message(self, message): @@ -266,6 +272,43 @@ class BotsManager: self.bots[user.username] = bot async def run_user_bot(self, user: User): + if user.username not in self.bots: + return None + + logger.info(f"Starting bot for {user.username}") + + bot = self.bots[user.username] + + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=2, min=2, max=60), + retry=retry_if_exception_type((ConnectionError, TimeoutError, OSError)), + ) + async def start_bot_with_retry(): + await bot.start() + + task = asyncio.create_task(start_bot_with_retry()) + + def on_bot_done(future): + try: + if future.cancelled(): + logger.warning(f"Bot for {user.username} was cancelled") + elif future.exception(): + logger.error( + f"Bot for {user.username} failed: {future.exception()}" + ) + else: + logger.info(f"Bot for {user.username} stopped normally") + except Exception as e: # noqa: BLE001 + logger.error(f"Error in bot completion callback: {e}") + + task.add_done_callback(on_bot_done) + + bot.start_routines() + + return task + + async def run_user_bot2(self, user: User): if user.username not in self.bots: return @@ -275,6 +318,7 @@ class BotsManager: task.add_done_callback( lambda x: logger.info(f"Bot for {user.username} stopped") ) + bot.start_routines() async def stop_user_bot(self, user: User): diff --git a/uv.lock b/uv.lock index f4b6cb4..b349f89 100644 --- a/uv.lock +++ b/uv.lock @@ -534,6 +534,7 @@ dependencies = [ { name = "pyjwt" }, { name = "pytz" }, { name = "redis" }, + { name = "tenacity" }, { name = "twitchio" }, ] @@ -568,6 +569,7 @@ requires-dist = [ { name = "pyjwt", specifier = ">=2.10.1" }, { name = "pytz", specifier = ">=2024.2" }, { name = "redis", specifier = ">=5.2.1" }, + { name = "tenacity", specifier = ">=9.0.0" }, { name = "twitchio", specifier = ">=2.10.0" }, ] @@ -1404,6 +1406,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, ] +[[package]] +name = "tenacity" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/94/91fccdb4b8110642462e653d5dcb27e7b674742ad68efd146367da7bdb10/tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b", size = 47421 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/cb/b86984bed139586d01532a587464b5805f12e397594f19f931c4c2fbfa61/tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539", size = 28169 }, +] + [[package]] name = "tomli" version = "2.2.1"