From 3bc4e19de1cb5b30377f0da2c799645093dfc317 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C4=83t=C4=83lin?= Date: Thu, 19 Dec 2024 18:13:38 +0100 Subject: [PATCH] feat: add backoff service and some message reactions --- pyproject.toml | 2 +- src/huesoporro/api/errors.py | 2 +- src/huesoporro/api/routes/api.py | 18 ++- src/huesoporro/bot.py | 172 ++++++++++++++++++++++++- src/huesoporro/svc/backoff_service.py | 111 ++++++++++++++++ src/huesoporro/svc/generate.py | 12 +- src/huesoporro/svc/get_random_quote.py | 2 +- src/huesoporro/svc/hello.py | 4 +- tests/conftest.py | 32 ++++- tests/test_svc.py | 61 +++++++++ uv.lock | 2 +- 11 files changed, 394 insertions(+), 24 deletions(-) create mode 100644 src/huesoporro/svc/backoff_service.py diff --git a/pyproject.toml b/pyproject.toml index 9d19c3b..9333e7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ extend-select = [ "W", "C90", "I", "N", "UP", "S", "BLE", "B", "A", "COM", "C4", "DTZ", "T10", "EM", "ISC", "T20", "PT", "RSE", "RET", "SIM", "PTH", "ERA", "PGH", "PL", "RUF", "FURB", "PERF" ] -extend-ignore = ["S101", "ISC002", "COM812", "ISC001"] +extend-ignore = ["S101", "ISC002", "COM812", "ISC001", "EM101", "EM102"] [tool.pytest.ini_options] asyncio_mode = "auto" diff --git a/src/huesoporro/api/errors.py b/src/huesoporro/api/errors.py index e88a6f0..19426df 100644 --- a/src/huesoporro/api/errors.py +++ b/src/huesoporro/api/errors.py @@ -30,7 +30,7 @@ def httpx_status_error_handler(_: Request, exc: httpx.HTTPStatusError): ) -async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # noqa: F821 +async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # type: ignore[name-defined] # noqa: F821 """Hook function that will be invoked after each exception.""" state = scope["app"].state if not hasattr(state, "error_count"): diff --git a/src/huesoporro/api/routes/api.py b/src/huesoporro/api/routes/api.py index dff54c2..d427cde 100644 --- a/src/huesoporro/api/routes/api.py +++ b/src/huesoporro/api/routes/api.py @@ -52,13 +52,20 @@ async def get_index(user: User, gbs: ChatbotSettingsGetterSvc) -> Template: @put("/api/v1/bot") async def manage_bot( - user: User, data: ManageBotDTO, gbs: ChatbotSettingsGetterSvc, bm: BotsManager + user: User, + data: ManageBotDTO, + gbs: ChatbotSettingsGetterSvc, + sbs: ChatbotSettingsStorerSvc, + bm: BotsManager, ) -> Response: chatbot_settings = await gbs.run(user=user) + if not chatbot_settings: + await sbs.run(user=user, bot_settings=ChatbotSettings()) + chatbot_settings = await gbs.run(user=user) if data.command == "start": if not data.channel_name: return Response({"message": "Channel name is required"}, status_code=400) - bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings) + bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings) # type: ignore[arg-type] if user.user in bm.bots: await bm.run_user_bot(user) return Response({"message": "Bot started"}) @@ -78,8 +85,11 @@ async def get_bot_status(user: User, bm: BotsManager) -> dict: @get("/api/v1/bot/settings") async def get_bot_settings( user: User, gbs: ChatbotSettingsGetterSvc -) -> ChatbotSettings: - return await gbs.run(user=user) +) -> ChatbotSettings | dict: + cbs = await gbs.run(user=user) + if not cbs: + return {"status": "Not found"} + return cbs @put("/api/v1/bot/settings") diff --git a/src/huesoporro/bot.py b/src/huesoporro/bot.py index 82151c0..fae650e 100644 --- a/src/huesoporro/bot.py +++ b/src/huesoporro/bot.py @@ -1,4 +1,7 @@ import asyncio +import random +from collections.abc import Callable +from enum import StrEnum from loguru import logger from twitchio import Channel @@ -8,6 +11,7 @@ from src.huesoporro.actions.store_quote import StoreQuoteAction from src.huesoporro.infra.db import Database from src.huesoporro.libs.db import Database as MarkovDB from src.huesoporro.models import ChatbotSettings, User +from src.huesoporro.svc.backoff_service import BackoffService from src.huesoporro.svc.generate import SentenceGeneratorSvc from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc from src.huesoporro.svc.hello import HelloGeneratorSvc @@ -75,16 +79,18 @@ class Bot(commands.Bot): @commands.command(aliases=["q", "quote"]) async def get_random_quote(self, ctx: commands.Context): quote = await self.get_random_quote_svc.run(channel_name=self.channel) - await ctx.send(f"«{quote[0]}» - {quote[1]}") + if quote: + await ctx.send(f"«{quote[0]}» - {quote[1]}") def get_channel_conn(self) -> Channel: return Channel(name=self.channel, websocket=self._connection) async def send_quote(self): quote = await self.get_random_quote_svc.run(channel_name=self.channel) - channel = self.get_channel_conn() - logger.info(f"Sending random quote {quote[0]}") - await channel.send(f"«{quote[0]}» - {quote[1]}") + if quote: + channel = self.get_channel_conn() + logger.info(f"Sending random quote {quote[0]}") + await channel.send(f"«{quote[0]}» - {quote[1]}") async def send_generation(self): sentence = await self.generate_svc.run() @@ -108,14 +114,15 @@ class Bot(commands.Bot): self.generation_routine.cancel() -class SaveMessagesCog(commands.Cog): +class SaveMessagesCog2(commands.Cog): def __init__(self, bot): self.bot = bot self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel)) + self.hello_svc = HelloGeneratorSvc() + self.backoff_svc = BackoffService() @commands.Cog.event() async def event_message(self, message): - # An event inside a cog! content = message.content if content.startswith("!"): return @@ -125,6 +132,159 @@ class SaveMessagesCog(commands.Cog): await self.store_svc.run(content) + if message.content in ["hola", "HOLA", "hiii", "ayo"]: + hello_message = self.hello_svc.run(message.author.name) + await message.channel.send(hello_message) + return + + if message.content == "Yes": + await message.channel.send("Indeed") + return + + if message.content.startswith("WHAT"): + await message.channel.send("WHAT Ramon") + return + + laughs_messages = [ + "om", + "KEK", + "LuL", + "LUL", + "OMEGALUL", + "kek", + "keking", + "KEKW", + "OMEGADANCEBUTFAST", + ] + + if message.content in laughs_messages: + await message.channel.send(random.choice(laughs_messages)) # noqa: S311 + return + + +class MessageType(StrEnum): + COMMAND = "COMMAND" + HELLO = "HELLO" + YES = "YES" + WHAT = "WHAT" + LAUGH = "LAUGH" + OTHER = "OTHER" + + +class MessageHandler: + """Handles different types of messages with their corresponding responses""" + + def __init__(self, channel_send_func: Callable): + self.hello_patterns = ["hola", "HOLA", "hiii", "ayo"] + self.laugh_patterns = [ + "om", + "KEK", + "LuL", + "LUL", + "OMEGALUL", + "kek", + "keking", + "KEKW", + "OMEGADANCEBUTFAST", + ] + self.send = channel_send_func + + def get_message_type(self, content: str) -> MessageType: + """Determines the type of message based on its content""" + if content.startswith("!"): + return MessageType.COMMAND + if content in self.hello_patterns: + return MessageType.HELLO + if content == "Yes": + return MessageType.YES + if content.startswith("WHAT"): + return MessageType.WHAT + if content in self.laugh_patterns: + return MessageType.LAUGH + return MessageType.OTHER + + async def handle_hello(self, author_name: str, hello_svc) -> str: + """Handles hello messages""" + return hello_svc.run(author_name) + + async def handle_laugh(self) -> str: + """Handles laugh messages""" + return random.choice(self.laugh_patterns) # noqa: S311 + + +class SaveMessagesCog(commands.Cog): + def __init__(self, bot): + self.bot = bot + self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel)) + self.hello_svc = HelloGeneratorSvc() + self.backoff_svc = BackoffService() + self.message_handler = MessageHandler(self._send_message) + + # Register a separate send function for each message type + self.send_functions = { + MessageType.HELLO: self._create_typed_send("hello"), + MessageType.YES: self._create_typed_send("yes"), + MessageType.WHAT: self._create_typed_send("what"), + MessageType.LAUGH: self._create_typed_send("laugh"), + } + + # Register each send function with its own backoff + for func in self.send_functions.values(): + self.backoff_svc.add_callable(func, backoff_seconds=10) + + def _create_typed_send(self, type_name: str): + """Creates a send function for a specific message type""" + + async def typed_send(content: str): + if hasattr(self, "current_message"): + await self.current_message.channel.send(content) + + # Set a unique name for the function to ensure it's treated as distinct + typed_send.__name__ = f"send_{type_name}" + return typed_send + + async def _send_message(self, content: str): + """Generic send message function (for non-backoff uses)""" + if hasattr(self, "current_message"): + await self.current_message.channel.send(content) + + @commands.Cog.event() + async def event_message(self, message): + """Main message event handler""" + if not message.author: + return + + # Store reference to current message for send functions + self.current_message = message + + # Store the message content + await self.store_svc.run(message.content) + + # Determine message type and handle accordingly + msg_type = self.message_handler.get_message_type(message.content) + + response = None + + match msg_type: + case MessageType.COMMAND: + return + case MessageType.HELLO: + response = await self.message_handler.handle_hello( + message.author.name, self.hello_svc + ) + case MessageType.YES: + response = "Indeed" + case MessageType.WHAT: + response = "WHAT Ramon" + case MessageType.LAUGH: + response = await self.message_handler.handle_laugh() + case MessageType.OTHER: + return + + if response and msg_type in self.send_functions: + # Use the type-specific send function + await self.backoff_svc.call_async(self.send_functions[msg_type], response) + class BotsManager: def __init__(self): diff --git a/src/huesoporro/svc/backoff_service.py b/src/huesoporro/svc/backoff_service.py new file mode 100644 index 0000000..08123d1 --- /dev/null +++ b/src/huesoporro/svc/backoff_service.py @@ -0,0 +1,111 @@ +import asyncio +import time +from collections.abc import Callable + +from pydantic import BaseModel + + +class CallableInfo(BaseModel): + backoff_seconds: int + last_call: float | None = None + is_async: bool + + +class BackoffService(BaseModel): + """Use this service to implement a backoff strategy on random callables. + The callable will be called the first time without delay but every subsequent + call may be hold off for a given time + + Examples: + >>> def callable(x): print(f"foo {x}") + >>> backoff_service = BackoffService() + >>> backoff_service.add_callable(callable, backoff_time=3) + >>> backoff_service.call(callable, "bar") # prints "foo bar" + >>> backoff_service.call(callable, "baz") # prints nothing + >>> # wait 3 seconds before calling callable again + >>> backoff_service.call(callable, "qux") # prints "foo qux" + """ + + callables: dict[Callable, CallableInfo] = {} + + def add_callable(self, func: Callable, backoff_seconds: int): + """Adds a callable to the local mapper with its backoff configuration. + + Args: + func: The function to be registered + backoff_seconds: The number of seconds to wait between successive calls + """ + self.callables[func] = CallableInfo( + backoff_seconds=backoff_seconds, is_async=self._is_async(func) + ) + + @staticmethod + def _is_async(func: Callable) -> bool: + """Checks if the callable is async""" + return asyncio.iscoroutinefunction(func) + + def _can_call(self, func: Callable) -> bool: + """Determines if enough time has passed since the last call""" + if func not in self.callables: + raise ValueError(f"Function {func} not registered with backoff service") + + func_info = self.callables[func] + last_call = func_info.last_call + + if last_call is None: + return True + + elapsed = time.time() - last_call + return elapsed >= func_info.backoff_seconds + + def call(self, func: Callable, *args, **kwargs): + """Calls the callable with arguments and returns its result if it isn't held off + + Args: + func: The function to call + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Optional[Any]: The result of the function call if executed, None if held off + """ + if func not in self.callables: + raise ValueError(f"Function {func} not registered with backoff service") + + if self.callables[func].is_async: + raise ValueError( + "Cannot call async function with .call(), use .call_async() instead" + ) + + if not self._can_call(func): + return None + + result = func(*args, **kwargs) + self.callables[func].last_call = time.time() + return result + + async def call_async(self, func: Callable, *args, **kwargs): + """Same as .call(...) but for async functions + + Args: + func: The async function to call + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Optional[Any]: The result of the async function call if executed, None if held off + """ + if func not in self.callables: + raise ValueError(f"Function {func} not registered with backoff service") + + if not self.callables[func].is_async: + raise ValueError( + "Cannot call sync function with .call_async(), use .call() instead" + ) + + if not self._can_call(func): + return None + + result = await func(*args, **kwargs) + self.callables[func].last_call = time.time() + return result diff --git a/src/huesoporro/svc/generate.py b/src/huesoporro/svc/generate.py index 0f5058d..5bef7fe 100644 --- a/src/huesoporro/svc/generate.py +++ b/src/huesoporro/svc/generate.py @@ -133,11 +133,11 @@ class SentenceGeneratorSvc(BaseModel): self, sentence: str | None = None, ) -> str | None: - if sentence: - sentence = tokenize(sentence) - logger.info(f"Generating sentence from {sentence}") - sentence, success = self.generate(sentence) - logger.info(f"Generated sentence: {sentence}") + split_sentence = tokenize(sentence) if sentence else None + + logger.info(f"Generating sentence from {split_sentence}") + generated_sentence, success = self.generate(split_sentence) + logger.info(f"Generated sentence: {generated_sentence}") if not success: return None - return sentence + return generated_sentence diff --git a/src/huesoporro/svc/get_random_quote.py b/src/huesoporro/svc/get_random_quote.py index 4371d4e..f6b4f6c 100644 --- a/src/huesoporro/svc/get_random_quote.py +++ b/src/huesoporro/svc/get_random_quote.py @@ -6,5 +6,5 @@ from src.huesoporro.infra.db import Database class RandomQuoteGetterSvc(BaseModel): db: Database - async def run(self, channel_name: str) -> tuple[str, str]: + async def run(self, channel_name: str) -> tuple[str, str] | None: return await self.db.get_random_quote(channel_name=channel_name) diff --git a/src/huesoporro/svc/hello.py b/src/huesoporro/svc/hello.py index 19dd38f..5be2180 100644 --- a/src/huesoporro/svc/hello.py +++ b/src/huesoporro/svc/hello.py @@ -11,8 +11,10 @@ class HelloGeneratorSvc(BaseModel): "Hi", "Bon día", "Hola mi tremendo elemento", + "HOLA", + "hiii", ] ) def run(self, username: str): - return f"{random.choice(self.hellos)} {username}" # noqa: S311 + return f"{random.choice(self.hellos)} @{username}" # noqa: S311 diff --git a/tests/conftest.py b/tests/conftest.py index b01710c..19218d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from caribou.migrate import load_migrations from src.huesoporro.infra.db import Database from src.huesoporro.models import ChatbotSettings, TwitchAuth, User from src.huesoporro.settings import Settings +from src.huesoporro.svc.backoff_service import BackoffService from src.huesoporro.svc.is_mod import IsModSvc @@ -16,7 +17,8 @@ def user() -> User: user="huesoporro", expires_at=1671234567.0, twitch_auth=TwitchAuth( - access_token="test_access_token", refresh_token="test_refresh_token" + access_token="test_access_token", # noqa: S106 + refresh_token="test_refresh_token", # noqa: S106 ), ) @@ -27,8 +29,8 @@ def s(tmp_path: Path, user: User) -> Settings: static_files_path=tmp_path / "static_files", db_filepath=tmp_path / "huesoporro.db", twitch_client_id="test_client_id", - twitch_client_secret="test_client_secret", # type: ignore[arg-type] - jwt_secret="test_jwt_secret", # type: ignore[arg-type] + twitch_client_secret="test_client_secret", # type: ignore[arg-type] # noqa: S106 + jwt_secret="test_jwt_secret", # type: ignore[arg-type] # noqa: S106 allowed_users=[user.user], ) @@ -54,3 +56,27 @@ async def chatbot_settings(db: Database, user) -> ChatbotSettings: cbs = ChatbotSettings(mods=[user.user, "allowed_user"]) await db.save_chatbot_settings(user=user, chatbot_settings=cbs) return cbs + + +@pytest.fixture +def backoff_callable(): + def foo(): + return "foo" + + return foo + + +@pytest.fixture +def async_backoff_callable(): + async def async_foo(): + return "async foo" + + return async_foo + + +@pytest.fixture +async def backoff_svc(backoff_callable, async_backoff_callable): + backoff_svc = BackoffService() + backoff_svc.add_callable(backoff_callable, 3) + backoff_svc.add_callable(async_backoff_callable, 3) + return backoff_svc diff --git a/tests/test_svc.py b/tests/test_svc.py index 5a5aacd..c88e1c9 100644 --- a/tests/test_svc.py +++ b/tests/test_svc.py @@ -1,3 +1,6 @@ +import asyncio +import time + import pytest from src.huesoporro.models import ChatbotSettings, User @@ -33,3 +36,61 @@ async def test_is_mod_svc_returns_false_for_user_not_in_modlist( ): is_mod = await is_mod_svc.run(user=user, username="TestUser2", channel=user.user) assert not is_mod + + +async def test_backoff_svc_returns_for_first_attempt( + backoff_svc, backoff_callable, async_backoff_callable +): + assert backoff_svc.call(backoff_callable) == "foo" + + assert await backoff_svc.call_async(async_backoff_callable) == "async foo" + + +async def test_backoff_svc_returns_none_for_second_attempt( + backoff_svc, backoff_callable, async_backoff_callable +): + assert backoff_svc.call(backoff_callable) == "foo" + assert backoff_svc.call(backoff_callable) is None + + assert await backoff_svc.call_async(async_backoff_callable) == "async foo" + assert await backoff_svc.call_async(async_backoff_callable) is None + + +async def test_backoff_svc_returns_for_second_attempt_after_delay( + backoff_svc, backoff_callable, async_backoff_callable +): + assert backoff_svc.call(backoff_callable) == "foo" + assert backoff_svc.call(backoff_callable) is None + time.sleep(3) + assert backoff_svc.call(backoff_callable) == "foo" + + assert await backoff_svc.call_async(async_backoff_callable) == "async foo" + assert await backoff_svc.call_async(async_backoff_callable) is None + await asyncio.sleep(3) + assert await backoff_svc.call_async(async_backoff_callable) == "async foo" + + +async def test_backoff_svc_raises_value_error_for_unknown_callable(backoff_svc): + with pytest.raises(ValueError, match="not registered with backoff service"): + backoff_svc.call(lambda: "foo") + + +async def test_backoff_svc_raises_value_error_for_unknown_async_callable(backoff_svc): + with pytest.raises(ValueError, match="not registered with backoff service"): + await backoff_svc.call_async(lambda: "foo") + + +async def test_backoff_svc_raises_value_error_for_async_called_from_sync( + backoff_svc, backoff_callable +): + with pytest.raises( + ValueError, match="Cannot call sync function with .call_async()" + ): + await backoff_svc.call_async(backoff_callable) + + +async def test_backoff_svc_raises_value_error_for_sync_called_from_async( + backoff_svc, async_backoff_callable +): + with pytest.raises(ValueError, match="Cannot call async function with .call()"): + backoff_svc.call(async_backoff_callable) diff --git a/uv.lock b/uv.lock index 618cfee..458a1a4 100644 --- a/uv.lock +++ b/uv.lock @@ -460,7 +460,7 @@ wheels = [ [[package]] name = "huesoporro" -version = "0.2.2" +version = "0.2.3" source = { virtual = "." } dependencies = [ { name = "aiosqlite" },