feat: add backoff service and some message reactions
This commit is contained in:
parent
2cad170eb3
commit
3bc4e19de1
11 changed files with 394 additions and 24 deletions
|
|
@ -30,7 +30,7 @@ def httpx_status_error_handler(_: Request, exc: httpx.HTTPStatusError):
|
|||
)
|
||||
|
||||
|
||||
async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # noqa: F821
|
||||
async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # type: ignore[name-defined] # noqa: F821
|
||||
"""Hook function that will be invoked after each exception."""
|
||||
state = scope["app"].state
|
||||
if not hasattr(state, "error_count"):
|
||||
|
|
|
|||
|
|
@ -52,13 +52,20 @@ async def get_index(user: User, gbs: ChatbotSettingsGetterSvc) -> Template:
|
|||
|
||||
@put("/api/v1/bot")
|
||||
async def manage_bot(
|
||||
user: User, data: ManageBotDTO, gbs: ChatbotSettingsGetterSvc, bm: BotsManager
|
||||
user: User,
|
||||
data: ManageBotDTO,
|
||||
gbs: ChatbotSettingsGetterSvc,
|
||||
sbs: ChatbotSettingsStorerSvc,
|
||||
bm: BotsManager,
|
||||
) -> Response:
|
||||
chatbot_settings = await gbs.run(user=user)
|
||||
if not chatbot_settings:
|
||||
await sbs.run(user=user, bot_settings=ChatbotSettings())
|
||||
chatbot_settings = await gbs.run(user=user)
|
||||
if data.command == "start":
|
||||
if not data.channel_name:
|
||||
return Response({"message": "Channel name is required"}, status_code=400)
|
||||
bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings)
|
||||
bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings) # type: ignore[arg-type]
|
||||
if user.user in bm.bots:
|
||||
await bm.run_user_bot(user)
|
||||
return Response({"message": "Bot started"})
|
||||
|
|
@ -78,8 +85,11 @@ async def get_bot_status(user: User, bm: BotsManager) -> dict:
|
|||
@get("/api/v1/bot/settings")
|
||||
async def get_bot_settings(
|
||||
user: User, gbs: ChatbotSettingsGetterSvc
|
||||
) -> ChatbotSettings:
|
||||
return await gbs.run(user=user)
|
||||
) -> ChatbotSettings | dict:
|
||||
cbs = await gbs.run(user=user)
|
||||
if not cbs:
|
||||
return {"status": "Not found"}
|
||||
return cbs
|
||||
|
||||
|
||||
@put("/api/v1/bot/settings")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
import asyncio
|
||||
import random
|
||||
from collections.abc import Callable
|
||||
from enum import StrEnum
|
||||
|
||||
from loguru import logger
|
||||
from twitchio import Channel
|
||||
|
|
@ -8,6 +11,7 @@ from src.huesoporro.actions.store_quote import StoreQuoteAction
|
|||
from src.huesoporro.infra.db import Database
|
||||
from src.huesoporro.libs.db import Database as MarkovDB
|
||||
from src.huesoporro.models import ChatbotSettings, User
|
||||
from src.huesoporro.svc.backoff_service import BackoffService
|
||||
from src.huesoporro.svc.generate import SentenceGeneratorSvc
|
||||
from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc
|
||||
from src.huesoporro.svc.hello import HelloGeneratorSvc
|
||||
|
|
@ -75,16 +79,18 @@ class Bot(commands.Bot):
|
|||
@commands.command(aliases=["q", "quote"])
|
||||
async def get_random_quote(self, ctx: commands.Context):
|
||||
quote = await self.get_random_quote_svc.run(channel_name=self.channel)
|
||||
await ctx.send(f"«{quote[0]}» - {quote[1]}")
|
||||
if quote:
|
||||
await ctx.send(f"«{quote[0]}» - {quote[1]}")
|
||||
|
||||
def get_channel_conn(self) -> Channel:
|
||||
return Channel(name=self.channel, websocket=self._connection)
|
||||
|
||||
async def send_quote(self):
|
||||
quote = await self.get_random_quote_svc.run(channel_name=self.channel)
|
||||
channel = self.get_channel_conn()
|
||||
logger.info(f"Sending random quote {quote[0]}")
|
||||
await channel.send(f"«{quote[0]}» - {quote[1]}")
|
||||
if quote:
|
||||
channel = self.get_channel_conn()
|
||||
logger.info(f"Sending random quote {quote[0]}")
|
||||
await channel.send(f"«{quote[0]}» - {quote[1]}")
|
||||
|
||||
async def send_generation(self):
|
||||
sentence = await self.generate_svc.run()
|
||||
|
|
@ -108,14 +114,15 @@ class Bot(commands.Bot):
|
|||
self.generation_routine.cancel()
|
||||
|
||||
|
||||
class SaveMessagesCog(commands.Cog):
|
||||
class SaveMessagesCog2(commands.Cog):
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel))
|
||||
self.hello_svc = HelloGeneratorSvc()
|
||||
self.backoff_svc = BackoffService()
|
||||
|
||||
@commands.Cog.event()
|
||||
async def event_message(self, message):
|
||||
# An event inside a cog!
|
||||
content = message.content
|
||||
if content.startswith("!"):
|
||||
return
|
||||
|
|
@ -125,6 +132,159 @@ class SaveMessagesCog(commands.Cog):
|
|||
|
||||
await self.store_svc.run(content)
|
||||
|
||||
if message.content in ["hola", "HOLA", "hiii", "ayo"]:
|
||||
hello_message = self.hello_svc.run(message.author.name)
|
||||
await message.channel.send(hello_message)
|
||||
return
|
||||
|
||||
if message.content == "Yes":
|
||||
await message.channel.send("Indeed")
|
||||
return
|
||||
|
||||
if message.content.startswith("WHAT"):
|
||||
await message.channel.send("WHAT Ramon")
|
||||
return
|
||||
|
||||
laughs_messages = [
|
||||
"om",
|
||||
"KEK",
|
||||
"LuL",
|
||||
"LUL",
|
||||
"OMEGALUL",
|
||||
"kek",
|
||||
"keking",
|
||||
"KEKW",
|
||||
"OMEGADANCEBUTFAST",
|
||||
]
|
||||
|
||||
if message.content in laughs_messages:
|
||||
await message.channel.send(random.choice(laughs_messages)) # noqa: S311
|
||||
return
|
||||
|
||||
|
||||
class MessageType(StrEnum):
|
||||
COMMAND = "COMMAND"
|
||||
HELLO = "HELLO"
|
||||
YES = "YES"
|
||||
WHAT = "WHAT"
|
||||
LAUGH = "LAUGH"
|
||||
OTHER = "OTHER"
|
||||
|
||||
|
||||
class MessageHandler:
|
||||
"""Handles different types of messages with their corresponding responses"""
|
||||
|
||||
def __init__(self, channel_send_func: Callable):
|
||||
self.hello_patterns = ["hola", "HOLA", "hiii", "ayo"]
|
||||
self.laugh_patterns = [
|
||||
"om",
|
||||
"KEK",
|
||||
"LuL",
|
||||
"LUL",
|
||||
"OMEGALUL",
|
||||
"kek",
|
||||
"keking",
|
||||
"KEKW",
|
||||
"OMEGADANCEBUTFAST",
|
||||
]
|
||||
self.send = channel_send_func
|
||||
|
||||
def get_message_type(self, content: str) -> MessageType:
|
||||
"""Determines the type of message based on its content"""
|
||||
if content.startswith("!"):
|
||||
return MessageType.COMMAND
|
||||
if content in self.hello_patterns:
|
||||
return MessageType.HELLO
|
||||
if content == "Yes":
|
||||
return MessageType.YES
|
||||
if content.startswith("WHAT"):
|
||||
return MessageType.WHAT
|
||||
if content in self.laugh_patterns:
|
||||
return MessageType.LAUGH
|
||||
return MessageType.OTHER
|
||||
|
||||
async def handle_hello(self, author_name: str, hello_svc) -> str:
|
||||
"""Handles hello messages"""
|
||||
return hello_svc.run(author_name)
|
||||
|
||||
async def handle_laugh(self) -> str:
|
||||
"""Handles laugh messages"""
|
||||
return random.choice(self.laugh_patterns) # noqa: S311
|
||||
|
||||
|
||||
class SaveMessagesCog(commands.Cog):
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel))
|
||||
self.hello_svc = HelloGeneratorSvc()
|
||||
self.backoff_svc = BackoffService()
|
||||
self.message_handler = MessageHandler(self._send_message)
|
||||
|
||||
# Register a separate send function for each message type
|
||||
self.send_functions = {
|
||||
MessageType.HELLO: self._create_typed_send("hello"),
|
||||
MessageType.YES: self._create_typed_send("yes"),
|
||||
MessageType.WHAT: self._create_typed_send("what"),
|
||||
MessageType.LAUGH: self._create_typed_send("laugh"),
|
||||
}
|
||||
|
||||
# Register each send function with its own backoff
|
||||
for func in self.send_functions.values():
|
||||
self.backoff_svc.add_callable(func, backoff_seconds=10)
|
||||
|
||||
def _create_typed_send(self, type_name: str):
|
||||
"""Creates a send function for a specific message type"""
|
||||
|
||||
async def typed_send(content: str):
|
||||
if hasattr(self, "current_message"):
|
||||
await self.current_message.channel.send(content)
|
||||
|
||||
# Set a unique name for the function to ensure it's treated as distinct
|
||||
typed_send.__name__ = f"send_{type_name}"
|
||||
return typed_send
|
||||
|
||||
async def _send_message(self, content: str):
|
||||
"""Generic send message function (for non-backoff uses)"""
|
||||
if hasattr(self, "current_message"):
|
||||
await self.current_message.channel.send(content)
|
||||
|
||||
@commands.Cog.event()
|
||||
async def event_message(self, message):
|
||||
"""Main message event handler"""
|
||||
if not message.author:
|
||||
return
|
||||
|
||||
# Store reference to current message for send functions
|
||||
self.current_message = message
|
||||
|
||||
# Store the message content
|
||||
await self.store_svc.run(message.content)
|
||||
|
||||
# Determine message type and handle accordingly
|
||||
msg_type = self.message_handler.get_message_type(message.content)
|
||||
|
||||
response = None
|
||||
|
||||
match msg_type:
|
||||
case MessageType.COMMAND:
|
||||
return
|
||||
case MessageType.HELLO:
|
||||
response = await self.message_handler.handle_hello(
|
||||
message.author.name, self.hello_svc
|
||||
)
|
||||
case MessageType.YES:
|
||||
response = "Indeed"
|
||||
case MessageType.WHAT:
|
||||
response = "WHAT Ramon"
|
||||
case MessageType.LAUGH:
|
||||
response = await self.message_handler.handle_laugh()
|
||||
case MessageType.OTHER:
|
||||
return
|
||||
|
||||
if response and msg_type in self.send_functions:
|
||||
# Use the type-specific send function
|
||||
await self.backoff_svc.call_async(self.send_functions[msg_type], response)
|
||||
|
||||
|
||||
class BotsManager:
|
||||
def __init__(self):
|
||||
|
|
|
|||
111
src/huesoporro/svc/backoff_service.py
Normal file
111
src/huesoporro/svc/backoff_service.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
import asyncio
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CallableInfo(BaseModel):
|
||||
backoff_seconds: int
|
||||
last_call: float | None = None
|
||||
is_async: bool
|
||||
|
||||
|
||||
class BackoffService(BaseModel):
|
||||
"""Use this service to implement a backoff strategy on random callables.
|
||||
The callable will be called the first time without delay but every subsequent
|
||||
call may be hold off for a given time
|
||||
|
||||
Examples:
|
||||
>>> def callable(x): print(f"foo {x}")
|
||||
>>> backoff_service = BackoffService()
|
||||
>>> backoff_service.add_callable(callable, backoff_time=3)
|
||||
>>> backoff_service.call(callable, "bar") # prints "foo bar"
|
||||
>>> backoff_service.call(callable, "baz") # prints nothing
|
||||
>>> # wait 3 seconds before calling callable again
|
||||
>>> backoff_service.call(callable, "qux") # prints "foo qux"
|
||||
"""
|
||||
|
||||
callables: dict[Callable, CallableInfo] = {}
|
||||
|
||||
def add_callable(self, func: Callable, backoff_seconds: int):
|
||||
"""Adds a callable to the local mapper with its backoff configuration.
|
||||
|
||||
Args:
|
||||
func: The function to be registered
|
||||
backoff_seconds: The number of seconds to wait between successive calls
|
||||
"""
|
||||
self.callables[func] = CallableInfo(
|
||||
backoff_seconds=backoff_seconds, is_async=self._is_async(func)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_async(func: Callable) -> bool:
|
||||
"""Checks if the callable is async"""
|
||||
return asyncio.iscoroutinefunction(func)
|
||||
|
||||
def _can_call(self, func: Callable) -> bool:
|
||||
"""Determines if enough time has passed since the last call"""
|
||||
if func not in self.callables:
|
||||
raise ValueError(f"Function {func} not registered with backoff service")
|
||||
|
||||
func_info = self.callables[func]
|
||||
last_call = func_info.last_call
|
||||
|
||||
if last_call is None:
|
||||
return True
|
||||
|
||||
elapsed = time.time() - last_call
|
||||
return elapsed >= func_info.backoff_seconds
|
||||
|
||||
def call(self, func: Callable, *args, **kwargs):
|
||||
"""Calls the callable with arguments and returns its result if it isn't held off
|
||||
|
||||
Args:
|
||||
func: The function to call
|
||||
*args: Positional arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
Optional[Any]: The result of the function call if executed, None if held off
|
||||
"""
|
||||
if func not in self.callables:
|
||||
raise ValueError(f"Function {func} not registered with backoff service")
|
||||
|
||||
if self.callables[func].is_async:
|
||||
raise ValueError(
|
||||
"Cannot call async function with .call(), use .call_async() instead"
|
||||
)
|
||||
|
||||
if not self._can_call(func):
|
||||
return None
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
self.callables[func].last_call = time.time()
|
||||
return result
|
||||
|
||||
async def call_async(self, func: Callable, *args, **kwargs):
|
||||
"""Same as .call(...) but for async functions
|
||||
|
||||
Args:
|
||||
func: The async function to call
|
||||
*args: Positional arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
Optional[Any]: The result of the async function call if executed, None if held off
|
||||
"""
|
||||
if func not in self.callables:
|
||||
raise ValueError(f"Function {func} not registered with backoff service")
|
||||
|
||||
if not self.callables[func].is_async:
|
||||
raise ValueError(
|
||||
"Cannot call sync function with .call_async(), use .call() instead"
|
||||
)
|
||||
|
||||
if not self._can_call(func):
|
||||
return None
|
||||
|
||||
result = await func(*args, **kwargs)
|
||||
self.callables[func].last_call = time.time()
|
||||
return result
|
||||
|
|
@ -133,11 +133,11 @@ class SentenceGeneratorSvc(BaseModel):
|
|||
self,
|
||||
sentence: str | None = None,
|
||||
) -> str | None:
|
||||
if sentence:
|
||||
sentence = tokenize(sentence)
|
||||
logger.info(f"Generating sentence from {sentence}")
|
||||
sentence, success = self.generate(sentence)
|
||||
logger.info(f"Generated sentence: {sentence}")
|
||||
split_sentence = tokenize(sentence) if sentence else None
|
||||
|
||||
logger.info(f"Generating sentence from {split_sentence}")
|
||||
generated_sentence, success = self.generate(split_sentence)
|
||||
logger.info(f"Generated sentence: {generated_sentence}")
|
||||
if not success:
|
||||
return None
|
||||
return sentence
|
||||
return generated_sentence
|
||||
|
|
|
|||
|
|
@ -6,5 +6,5 @@ from src.huesoporro.infra.db import Database
|
|||
class RandomQuoteGetterSvc(BaseModel):
|
||||
db: Database
|
||||
|
||||
async def run(self, channel_name: str) -> tuple[str, str]:
|
||||
async def run(self, channel_name: str) -> tuple[str, str] | None:
|
||||
return await self.db.get_random_quote(channel_name=channel_name)
|
||||
|
|
|
|||
|
|
@ -11,8 +11,10 @@ class HelloGeneratorSvc(BaseModel):
|
|||
"Hi",
|
||||
"Bon día",
|
||||
"Hola mi tremendo elemento",
|
||||
"HOLA",
|
||||
"hiii",
|
||||
]
|
||||
)
|
||||
|
||||
def run(self, username: str):
|
||||
return f"{random.choice(self.hellos)} {username}" # noqa: S311
|
||||
return f"{random.choice(self.hellos)} @{username}" # noqa: S311
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue