feat: add backoff service and some message reactions

This commit is contained in:
cătălin 2024-12-19 18:13:38 +01:00
commit 3bc4e19de1
No known key found for this signature in database
11 changed files with 394 additions and 24 deletions

View file

@ -52,7 +52,7 @@ extend-select = [
"W", "C90", "I", "N", "UP", "S", "BLE", "B", "A", "COM", "C4", "DTZ", "T10", "EM", "ISC", "T20", "PT", "RSE", "RET", "W", "C90", "I", "N", "UP", "S", "BLE", "B", "A", "COM", "C4", "DTZ", "T10", "EM", "ISC", "T20", "PT", "RSE", "RET",
"SIM", "PTH", "ERA", "PGH", "PL", "RUF", "FURB", "PERF" "SIM", "PTH", "ERA", "PGH", "PL", "RUF", "FURB", "PERF"
] ]
extend-ignore = ["S101", "ISC002", "COM812", "ISC001"] extend-ignore = ["S101", "ISC002", "COM812", "ISC001", "EM101", "EM102"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
asyncio_mode = "auto" asyncio_mode = "auto"

View file

@ -30,7 +30,7 @@ def httpx_status_error_handler(_: Request, exc: httpx.HTTPStatusError):
) )
async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # noqa: F821 async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # type: ignore[name-defined] # noqa: F821
"""Hook function that will be invoked after each exception.""" """Hook function that will be invoked after each exception."""
state = scope["app"].state state = scope["app"].state
if not hasattr(state, "error_count"): if not hasattr(state, "error_count"):

View file

@ -52,13 +52,20 @@ async def get_index(user: User, gbs: ChatbotSettingsGetterSvc) -> Template:
@put("/api/v1/bot") @put("/api/v1/bot")
async def manage_bot( async def manage_bot(
user: User, data: ManageBotDTO, gbs: ChatbotSettingsGetterSvc, bm: BotsManager user: User,
data: ManageBotDTO,
gbs: ChatbotSettingsGetterSvc,
sbs: ChatbotSettingsStorerSvc,
bm: BotsManager,
) -> Response: ) -> Response:
chatbot_settings = await gbs.run(user=user) chatbot_settings = await gbs.run(user=user)
if not chatbot_settings:
await sbs.run(user=user, bot_settings=ChatbotSettings())
chatbot_settings = await gbs.run(user=user)
if data.command == "start": if data.command == "start":
if not data.channel_name: if not data.channel_name:
return Response({"message": "Channel name is required"}, status_code=400) return Response({"message": "Channel name is required"}, status_code=400)
bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings) bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings) # type: ignore[arg-type]
if user.user in bm.bots: if user.user in bm.bots:
await bm.run_user_bot(user) await bm.run_user_bot(user)
return Response({"message": "Bot started"}) return Response({"message": "Bot started"})
@ -78,8 +85,11 @@ async def get_bot_status(user: User, bm: BotsManager) -> dict:
@get("/api/v1/bot/settings") @get("/api/v1/bot/settings")
async def get_bot_settings( async def get_bot_settings(
user: User, gbs: ChatbotSettingsGetterSvc user: User, gbs: ChatbotSettingsGetterSvc
) -> ChatbotSettings: ) -> ChatbotSettings | dict:
return await gbs.run(user=user) cbs = await gbs.run(user=user)
if not cbs:
return {"status": "Not found"}
return cbs
@put("/api/v1/bot/settings") @put("/api/v1/bot/settings")

View file

@ -1,4 +1,7 @@
import asyncio import asyncio
import random
from collections.abc import Callable
from enum import StrEnum
from loguru import logger from loguru import logger
from twitchio import Channel from twitchio import Channel
@ -8,6 +11,7 @@ from src.huesoporro.actions.store_quote import StoreQuoteAction
from src.huesoporro.infra.db import Database from src.huesoporro.infra.db import Database
from src.huesoporro.libs.db import Database as MarkovDB from src.huesoporro.libs.db import Database as MarkovDB
from src.huesoporro.models import ChatbotSettings, User from src.huesoporro.models import ChatbotSettings, User
from src.huesoporro.svc.backoff_service import BackoffService
from src.huesoporro.svc.generate import SentenceGeneratorSvc from src.huesoporro.svc.generate import SentenceGeneratorSvc
from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc
from src.huesoporro.svc.hello import HelloGeneratorSvc from src.huesoporro.svc.hello import HelloGeneratorSvc
@ -75,6 +79,7 @@ class Bot(commands.Bot):
@commands.command(aliases=["q", "quote"]) @commands.command(aliases=["q", "quote"])
async def get_random_quote(self, ctx: commands.Context): async def get_random_quote(self, ctx: commands.Context):
quote = await self.get_random_quote_svc.run(channel_name=self.channel) quote = await self.get_random_quote_svc.run(channel_name=self.channel)
if quote:
await ctx.send(f"«{quote[0]}» - {quote[1]}") await ctx.send(f"«{quote[0]}» - {quote[1]}")
def get_channel_conn(self) -> Channel: def get_channel_conn(self) -> Channel:
@ -82,6 +87,7 @@ class Bot(commands.Bot):
async def send_quote(self): async def send_quote(self):
quote = await self.get_random_quote_svc.run(channel_name=self.channel) quote = await self.get_random_quote_svc.run(channel_name=self.channel)
if quote:
channel = self.get_channel_conn() channel = self.get_channel_conn()
logger.info(f"Sending random quote {quote[0]}") logger.info(f"Sending random quote {quote[0]}")
await channel.send(f"«{quote[0]}» - {quote[1]}") await channel.send(f"«{quote[0]}» - {quote[1]}")
@ -108,14 +114,15 @@ class Bot(commands.Bot):
self.generation_routine.cancel() self.generation_routine.cancel()
class SaveMessagesCog(commands.Cog): class SaveMessagesCog2(commands.Cog):
def __init__(self, bot): def __init__(self, bot):
self.bot = bot self.bot = bot
self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel)) self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel))
self.hello_svc = HelloGeneratorSvc()
self.backoff_svc = BackoffService()
@commands.Cog.event() @commands.Cog.event()
async def event_message(self, message): async def event_message(self, message):
# An event inside a cog!
content = message.content content = message.content
if content.startswith("!"): if content.startswith("!"):
return return
@ -125,6 +132,159 @@ class SaveMessagesCog(commands.Cog):
await self.store_svc.run(content) await self.store_svc.run(content)
if message.content in ["hola", "HOLA", "hiii", "ayo"]:
hello_message = self.hello_svc.run(message.author.name)
await message.channel.send(hello_message)
return
if message.content == "Yes":
await message.channel.send("Indeed")
return
if message.content.startswith("WHAT"):
await message.channel.send("WHAT Ramon")
return
laughs_messages = [
"om",
"KEK",
"LuL",
"LUL",
"OMEGALUL",
"kek",
"keking",
"KEKW",
"OMEGADANCEBUTFAST",
]
if message.content in laughs_messages:
await message.channel.send(random.choice(laughs_messages)) # noqa: S311
return
class MessageType(StrEnum):
COMMAND = "COMMAND"
HELLO = "HELLO"
YES = "YES"
WHAT = "WHAT"
LAUGH = "LAUGH"
OTHER = "OTHER"
class MessageHandler:
"""Handles different types of messages with their corresponding responses"""
def __init__(self, channel_send_func: Callable):
self.hello_patterns = ["hola", "HOLA", "hiii", "ayo"]
self.laugh_patterns = [
"om",
"KEK",
"LuL",
"LUL",
"OMEGALUL",
"kek",
"keking",
"KEKW",
"OMEGADANCEBUTFAST",
]
self.send = channel_send_func
def get_message_type(self, content: str) -> MessageType:
"""Determines the type of message based on its content"""
if content.startswith("!"):
return MessageType.COMMAND
if content in self.hello_patterns:
return MessageType.HELLO
if content == "Yes":
return MessageType.YES
if content.startswith("WHAT"):
return MessageType.WHAT
if content in self.laugh_patterns:
return MessageType.LAUGH
return MessageType.OTHER
async def handle_hello(self, author_name: str, hello_svc) -> str:
"""Handles hello messages"""
return hello_svc.run(author_name)
async def handle_laugh(self) -> str:
"""Handles laugh messages"""
return random.choice(self.laugh_patterns) # noqa: S311
class SaveMessagesCog(commands.Cog):
def __init__(self, bot):
self.bot = bot
self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel))
self.hello_svc = HelloGeneratorSvc()
self.backoff_svc = BackoffService()
self.message_handler = MessageHandler(self._send_message)
# Register a separate send function for each message type
self.send_functions = {
MessageType.HELLO: self._create_typed_send("hello"),
MessageType.YES: self._create_typed_send("yes"),
MessageType.WHAT: self._create_typed_send("what"),
MessageType.LAUGH: self._create_typed_send("laugh"),
}
# Register each send function with its own backoff
for func in self.send_functions.values():
self.backoff_svc.add_callable(func, backoff_seconds=10)
def _create_typed_send(self, type_name: str):
"""Creates a send function for a specific message type"""
async def typed_send(content: str):
if hasattr(self, "current_message"):
await self.current_message.channel.send(content)
# Set a unique name for the function to ensure it's treated as distinct
typed_send.__name__ = f"send_{type_name}"
return typed_send
async def _send_message(self, content: str):
"""Generic send message function (for non-backoff uses)"""
if hasattr(self, "current_message"):
await self.current_message.channel.send(content)
@commands.Cog.event()
async def event_message(self, message):
"""Main message event handler"""
if not message.author:
return
# Store reference to current message for send functions
self.current_message = message
# Store the message content
await self.store_svc.run(message.content)
# Determine message type and handle accordingly
msg_type = self.message_handler.get_message_type(message.content)
response = None
match msg_type:
case MessageType.COMMAND:
return
case MessageType.HELLO:
response = await self.message_handler.handle_hello(
message.author.name, self.hello_svc
)
case MessageType.YES:
response = "Indeed"
case MessageType.WHAT:
response = "WHAT Ramon"
case MessageType.LAUGH:
response = await self.message_handler.handle_laugh()
case MessageType.OTHER:
return
if response and msg_type in self.send_functions:
# Use the type-specific send function
await self.backoff_svc.call_async(self.send_functions[msg_type], response)
class BotsManager: class BotsManager:
def __init__(self): def __init__(self):

View file

@ -0,0 +1,111 @@
import asyncio
import time
from collections.abc import Callable
from pydantic import BaseModel
class CallableInfo(BaseModel):
backoff_seconds: int
last_call: float | None = None
is_async: bool
class BackoffService(BaseModel):
"""Use this service to implement a backoff strategy on random callables.
The callable will be called the first time without delay but every subsequent
call may be hold off for a given time
Examples:
>>> def callable(x): print(f"foo {x}")
>>> backoff_service = BackoffService()
>>> backoff_service.add_callable(callable, backoff_time=3)
>>> backoff_service.call(callable, "bar") # prints "foo bar"
>>> backoff_service.call(callable, "baz") # prints nothing
>>> # wait 3 seconds before calling callable again
>>> backoff_service.call(callable, "qux") # prints "foo qux"
"""
callables: dict[Callable, CallableInfo] = {}
def add_callable(self, func: Callable, backoff_seconds: int):
"""Adds a callable to the local mapper with its backoff configuration.
Args:
func: The function to be registered
backoff_seconds: The number of seconds to wait between successive calls
"""
self.callables[func] = CallableInfo(
backoff_seconds=backoff_seconds, is_async=self._is_async(func)
)
@staticmethod
def _is_async(func: Callable) -> bool:
"""Checks if the callable is async"""
return asyncio.iscoroutinefunction(func)
def _can_call(self, func: Callable) -> bool:
"""Determines if enough time has passed since the last call"""
if func not in self.callables:
raise ValueError(f"Function {func} not registered with backoff service")
func_info = self.callables[func]
last_call = func_info.last_call
if last_call is None:
return True
elapsed = time.time() - last_call
return elapsed >= func_info.backoff_seconds
def call(self, func: Callable, *args, **kwargs):
"""Calls the callable with arguments and returns its result if it isn't held off
Args:
func: The function to call
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
Optional[Any]: The result of the function call if executed, None if held off
"""
if func not in self.callables:
raise ValueError(f"Function {func} not registered with backoff service")
if self.callables[func].is_async:
raise ValueError(
"Cannot call async function with .call(), use .call_async() instead"
)
if not self._can_call(func):
return None
result = func(*args, **kwargs)
self.callables[func].last_call = time.time()
return result
async def call_async(self, func: Callable, *args, **kwargs):
"""Same as .call(...) but for async functions
Args:
func: The async function to call
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
Optional[Any]: The result of the async function call if executed, None if held off
"""
if func not in self.callables:
raise ValueError(f"Function {func} not registered with backoff service")
if not self.callables[func].is_async:
raise ValueError(
"Cannot call sync function with .call_async(), use .call() instead"
)
if not self._can_call(func):
return None
result = await func(*args, **kwargs)
self.callables[func].last_call = time.time()
return result

View file

@ -133,11 +133,11 @@ class SentenceGeneratorSvc(BaseModel):
self, self,
sentence: str | None = None, sentence: str | None = None,
) -> str | None: ) -> str | None:
if sentence: split_sentence = tokenize(sentence) if sentence else None
sentence = tokenize(sentence)
logger.info(f"Generating sentence from {sentence}") logger.info(f"Generating sentence from {split_sentence}")
sentence, success = self.generate(sentence) generated_sentence, success = self.generate(split_sentence)
logger.info(f"Generated sentence: {sentence}") logger.info(f"Generated sentence: {generated_sentence}")
if not success: if not success:
return None return None
return sentence return generated_sentence

View file

@ -6,5 +6,5 @@ from src.huesoporro.infra.db import Database
class RandomQuoteGetterSvc(BaseModel): class RandomQuoteGetterSvc(BaseModel):
db: Database db: Database
async def run(self, channel_name: str) -> tuple[str, str]: async def run(self, channel_name: str) -> tuple[str, str] | None:
return await self.db.get_random_quote(channel_name=channel_name) return await self.db.get_random_quote(channel_name=channel_name)

View file

@ -11,8 +11,10 @@ class HelloGeneratorSvc(BaseModel):
"Hi", "Hi",
"Bon día", "Bon día",
"Hola mi tremendo elemento", "Hola mi tremendo elemento",
"HOLA",
"hiii",
] ]
) )
def run(self, username: str): def run(self, username: str):
return f"{random.choice(self.hellos)} {username}" # noqa: S311 return f"{random.choice(self.hellos)} @{username}" # noqa: S311

View file

@ -7,6 +7,7 @@ from caribou.migrate import load_migrations
from src.huesoporro.infra.db import Database from src.huesoporro.infra.db import Database
from src.huesoporro.models import ChatbotSettings, TwitchAuth, User from src.huesoporro.models import ChatbotSettings, TwitchAuth, User
from src.huesoporro.settings import Settings from src.huesoporro.settings import Settings
from src.huesoporro.svc.backoff_service import BackoffService
from src.huesoporro.svc.is_mod import IsModSvc from src.huesoporro.svc.is_mod import IsModSvc
@ -16,7 +17,8 @@ def user() -> User:
user="huesoporro", user="huesoporro",
expires_at=1671234567.0, expires_at=1671234567.0,
twitch_auth=TwitchAuth( twitch_auth=TwitchAuth(
access_token="test_access_token", refresh_token="test_refresh_token" access_token="test_access_token", # noqa: S106
refresh_token="test_refresh_token", # noqa: S106
), ),
) )
@ -27,8 +29,8 @@ def s(tmp_path: Path, user: User) -> Settings:
static_files_path=tmp_path / "static_files", static_files_path=tmp_path / "static_files",
db_filepath=tmp_path / "huesoporro.db", db_filepath=tmp_path / "huesoporro.db",
twitch_client_id="test_client_id", twitch_client_id="test_client_id",
twitch_client_secret="test_client_secret", # type: ignore[arg-type] twitch_client_secret="test_client_secret", # type: ignore[arg-type] # noqa: S106
jwt_secret="test_jwt_secret", # type: ignore[arg-type] jwt_secret="test_jwt_secret", # type: ignore[arg-type] # noqa: S106
allowed_users=[user.user], allowed_users=[user.user],
) )
@ -54,3 +56,27 @@ async def chatbot_settings(db: Database, user) -> ChatbotSettings:
cbs = ChatbotSettings(mods=[user.user, "allowed_user"]) cbs = ChatbotSettings(mods=[user.user, "allowed_user"])
await db.save_chatbot_settings(user=user, chatbot_settings=cbs) await db.save_chatbot_settings(user=user, chatbot_settings=cbs)
return cbs return cbs
@pytest.fixture
def backoff_callable():
def foo():
return "foo"
return foo
@pytest.fixture
def async_backoff_callable():
async def async_foo():
return "async foo"
return async_foo
@pytest.fixture
async def backoff_svc(backoff_callable, async_backoff_callable):
backoff_svc = BackoffService()
backoff_svc.add_callable(backoff_callable, 3)
backoff_svc.add_callable(async_backoff_callable, 3)
return backoff_svc

View file

@ -1,3 +1,6 @@
import asyncio
import time
import pytest import pytest
from src.huesoporro.models import ChatbotSettings, User from src.huesoporro.models import ChatbotSettings, User
@ -33,3 +36,61 @@ async def test_is_mod_svc_returns_false_for_user_not_in_modlist(
): ):
is_mod = await is_mod_svc.run(user=user, username="TestUser2", channel=user.user) is_mod = await is_mod_svc.run(user=user, username="TestUser2", channel=user.user)
assert not is_mod assert not is_mod
async def test_backoff_svc_returns_for_first_attempt(
backoff_svc, backoff_callable, async_backoff_callable
):
assert backoff_svc.call(backoff_callable) == "foo"
assert await backoff_svc.call_async(async_backoff_callable) == "async foo"
async def test_backoff_svc_returns_none_for_second_attempt(
backoff_svc, backoff_callable, async_backoff_callable
):
assert backoff_svc.call(backoff_callable) == "foo"
assert backoff_svc.call(backoff_callable) is None
assert await backoff_svc.call_async(async_backoff_callable) == "async foo"
assert await backoff_svc.call_async(async_backoff_callable) is None
async def test_backoff_svc_returns_for_second_attempt_after_delay(
backoff_svc, backoff_callable, async_backoff_callable
):
assert backoff_svc.call(backoff_callable) == "foo"
assert backoff_svc.call(backoff_callable) is None
time.sleep(3)
assert backoff_svc.call(backoff_callable) == "foo"
assert await backoff_svc.call_async(async_backoff_callable) == "async foo"
assert await backoff_svc.call_async(async_backoff_callable) is None
await asyncio.sleep(3)
assert await backoff_svc.call_async(async_backoff_callable) == "async foo"
async def test_backoff_svc_raises_value_error_for_unknown_callable(backoff_svc):
with pytest.raises(ValueError, match="not registered with backoff service"):
backoff_svc.call(lambda: "foo")
async def test_backoff_svc_raises_value_error_for_unknown_async_callable(backoff_svc):
with pytest.raises(ValueError, match="not registered with backoff service"):
await backoff_svc.call_async(lambda: "foo")
async def test_backoff_svc_raises_value_error_for_async_called_from_sync(
backoff_svc, backoff_callable
):
with pytest.raises(
ValueError, match="Cannot call sync function with .call_async()"
):
await backoff_svc.call_async(backoff_callable)
async def test_backoff_svc_raises_value_error_for_sync_called_from_async(
backoff_svc, async_backoff_callable
):
with pytest.raises(ValueError, match="Cannot call async function with .call()"):
backoff_svc.call(async_backoff_callable)

2
uv.lock generated
View file

@ -460,7 +460,7 @@ wheels = [
[[package]] [[package]]
name = "huesoporro" name = "huesoporro"
version = "0.2.2" version = "0.2.3"
source = { virtual = "." } source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "aiosqlite" }, { name = "aiosqlite" },