feat: add backoff service and some message reactions
This commit is contained in:
parent
2cad170eb3
commit
3bc4e19de1
11 changed files with 394 additions and 24 deletions
|
|
@ -52,7 +52,7 @@ extend-select = [
|
||||||
"W", "C90", "I", "N", "UP", "S", "BLE", "B", "A", "COM", "C4", "DTZ", "T10", "EM", "ISC", "T20", "PT", "RSE", "RET",
|
"W", "C90", "I", "N", "UP", "S", "BLE", "B", "A", "COM", "C4", "DTZ", "T10", "EM", "ISC", "T20", "PT", "RSE", "RET",
|
||||||
"SIM", "PTH", "ERA", "PGH", "PL", "RUF", "FURB", "PERF"
|
"SIM", "PTH", "ERA", "PGH", "PL", "RUF", "FURB", "PERF"
|
||||||
]
|
]
|
||||||
extend-ignore = ["S101", "ISC002", "COM812", "ISC001"]
|
extend-ignore = ["S101", "ISC002", "COM812", "ISC001", "EM101", "EM102"]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ def httpx_status_error_handler(_: Request, exc: httpx.HTTPStatusError):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # noqa: F821
|
async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # type: ignore[name-defined] # noqa: F821
|
||||||
"""Hook function that will be invoked after each exception."""
|
"""Hook function that will be invoked after each exception."""
|
||||||
state = scope["app"].state
|
state = scope["app"].state
|
||||||
if not hasattr(state, "error_count"):
|
if not hasattr(state, "error_count"):
|
||||||
|
|
|
||||||
|
|
@ -52,13 +52,20 @@ async def get_index(user: User, gbs: ChatbotSettingsGetterSvc) -> Template:
|
||||||
|
|
||||||
@put("/api/v1/bot")
|
@put("/api/v1/bot")
|
||||||
async def manage_bot(
|
async def manage_bot(
|
||||||
user: User, data: ManageBotDTO, gbs: ChatbotSettingsGetterSvc, bm: BotsManager
|
user: User,
|
||||||
|
data: ManageBotDTO,
|
||||||
|
gbs: ChatbotSettingsGetterSvc,
|
||||||
|
sbs: ChatbotSettingsStorerSvc,
|
||||||
|
bm: BotsManager,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
chatbot_settings = await gbs.run(user=user)
|
chatbot_settings = await gbs.run(user=user)
|
||||||
|
if not chatbot_settings:
|
||||||
|
await sbs.run(user=user, bot_settings=ChatbotSettings())
|
||||||
|
chatbot_settings = await gbs.run(user=user)
|
||||||
if data.command == "start":
|
if data.command == "start":
|
||||||
if not data.channel_name:
|
if not data.channel_name:
|
||||||
return Response({"message": "Channel name is required"}, status_code=400)
|
return Response({"message": "Channel name is required"}, status_code=400)
|
||||||
bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings)
|
bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings) # type: ignore[arg-type]
|
||||||
if user.user in bm.bots:
|
if user.user in bm.bots:
|
||||||
await bm.run_user_bot(user)
|
await bm.run_user_bot(user)
|
||||||
return Response({"message": "Bot started"})
|
return Response({"message": "Bot started"})
|
||||||
|
|
@ -78,8 +85,11 @@ async def get_bot_status(user: User, bm: BotsManager) -> dict:
|
||||||
@get("/api/v1/bot/settings")
|
@get("/api/v1/bot/settings")
|
||||||
async def get_bot_settings(
|
async def get_bot_settings(
|
||||||
user: User, gbs: ChatbotSettingsGetterSvc
|
user: User, gbs: ChatbotSettingsGetterSvc
|
||||||
) -> ChatbotSettings:
|
) -> ChatbotSettings | dict:
|
||||||
return await gbs.run(user=user)
|
cbs = await gbs.run(user=user)
|
||||||
|
if not cbs:
|
||||||
|
return {"status": "Not found"}
|
||||||
|
return cbs
|
||||||
|
|
||||||
|
|
||||||
@put("/api/v1/bot/settings")
|
@put("/api/v1/bot/settings")
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import random
|
||||||
|
from collections.abc import Callable
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from twitchio import Channel
|
from twitchio import Channel
|
||||||
|
|
@ -8,6 +11,7 @@ from src.huesoporro.actions.store_quote import StoreQuoteAction
|
||||||
from src.huesoporro.infra.db import Database
|
from src.huesoporro.infra.db import Database
|
||||||
from src.huesoporro.libs.db import Database as MarkovDB
|
from src.huesoporro.libs.db import Database as MarkovDB
|
||||||
from src.huesoporro.models import ChatbotSettings, User
|
from src.huesoporro.models import ChatbotSettings, User
|
||||||
|
from src.huesoporro.svc.backoff_service import BackoffService
|
||||||
from src.huesoporro.svc.generate import SentenceGeneratorSvc
|
from src.huesoporro.svc.generate import SentenceGeneratorSvc
|
||||||
from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc
|
from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc
|
||||||
from src.huesoporro.svc.hello import HelloGeneratorSvc
|
from src.huesoporro.svc.hello import HelloGeneratorSvc
|
||||||
|
|
@ -75,6 +79,7 @@ class Bot(commands.Bot):
|
||||||
@commands.command(aliases=["q", "quote"])
|
@commands.command(aliases=["q", "quote"])
|
||||||
async def get_random_quote(self, ctx: commands.Context):
|
async def get_random_quote(self, ctx: commands.Context):
|
||||||
quote = await self.get_random_quote_svc.run(channel_name=self.channel)
|
quote = await self.get_random_quote_svc.run(channel_name=self.channel)
|
||||||
|
if quote:
|
||||||
await ctx.send(f"«{quote[0]}» - {quote[1]}")
|
await ctx.send(f"«{quote[0]}» - {quote[1]}")
|
||||||
|
|
||||||
def get_channel_conn(self) -> Channel:
|
def get_channel_conn(self) -> Channel:
|
||||||
|
|
@ -82,6 +87,7 @@ class Bot(commands.Bot):
|
||||||
|
|
||||||
async def send_quote(self):
|
async def send_quote(self):
|
||||||
quote = await self.get_random_quote_svc.run(channel_name=self.channel)
|
quote = await self.get_random_quote_svc.run(channel_name=self.channel)
|
||||||
|
if quote:
|
||||||
channel = self.get_channel_conn()
|
channel = self.get_channel_conn()
|
||||||
logger.info(f"Sending random quote {quote[0]}")
|
logger.info(f"Sending random quote {quote[0]}")
|
||||||
await channel.send(f"«{quote[0]}» - {quote[1]}")
|
await channel.send(f"«{quote[0]}» - {quote[1]}")
|
||||||
|
|
@ -108,14 +114,15 @@ class Bot(commands.Bot):
|
||||||
self.generation_routine.cancel()
|
self.generation_routine.cancel()
|
||||||
|
|
||||||
|
|
||||||
class SaveMessagesCog(commands.Cog):
|
class SaveMessagesCog2(commands.Cog):
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel))
|
self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel))
|
||||||
|
self.hello_svc = HelloGeneratorSvc()
|
||||||
|
self.backoff_svc = BackoffService()
|
||||||
|
|
||||||
@commands.Cog.event()
|
@commands.Cog.event()
|
||||||
async def event_message(self, message):
|
async def event_message(self, message):
|
||||||
# An event inside a cog!
|
|
||||||
content = message.content
|
content = message.content
|
||||||
if content.startswith("!"):
|
if content.startswith("!"):
|
||||||
return
|
return
|
||||||
|
|
@ -125,6 +132,159 @@ class SaveMessagesCog(commands.Cog):
|
||||||
|
|
||||||
await self.store_svc.run(content)
|
await self.store_svc.run(content)
|
||||||
|
|
||||||
|
if message.content in ["hola", "HOLA", "hiii", "ayo"]:
|
||||||
|
hello_message = self.hello_svc.run(message.author.name)
|
||||||
|
await message.channel.send(hello_message)
|
||||||
|
return
|
||||||
|
|
||||||
|
if message.content == "Yes":
|
||||||
|
await message.channel.send("Indeed")
|
||||||
|
return
|
||||||
|
|
||||||
|
if message.content.startswith("WHAT"):
|
||||||
|
await message.channel.send("WHAT Ramon")
|
||||||
|
return
|
||||||
|
|
||||||
|
laughs_messages = [
|
||||||
|
"om",
|
||||||
|
"KEK",
|
||||||
|
"LuL",
|
||||||
|
"LUL",
|
||||||
|
"OMEGALUL",
|
||||||
|
"kek",
|
||||||
|
"keking",
|
||||||
|
"KEKW",
|
||||||
|
"OMEGADANCEBUTFAST",
|
||||||
|
]
|
||||||
|
|
||||||
|
if message.content in laughs_messages:
|
||||||
|
await message.channel.send(random.choice(laughs_messages)) # noqa: S311
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType(StrEnum):
|
||||||
|
COMMAND = "COMMAND"
|
||||||
|
HELLO = "HELLO"
|
||||||
|
YES = "YES"
|
||||||
|
WHAT = "WHAT"
|
||||||
|
LAUGH = "LAUGH"
|
||||||
|
OTHER = "OTHER"
|
||||||
|
|
||||||
|
|
||||||
|
class MessageHandler:
|
||||||
|
"""Handles different types of messages with their corresponding responses"""
|
||||||
|
|
||||||
|
def __init__(self, channel_send_func: Callable):
|
||||||
|
self.hello_patterns = ["hola", "HOLA", "hiii", "ayo"]
|
||||||
|
self.laugh_patterns = [
|
||||||
|
"om",
|
||||||
|
"KEK",
|
||||||
|
"LuL",
|
||||||
|
"LUL",
|
||||||
|
"OMEGALUL",
|
||||||
|
"kek",
|
||||||
|
"keking",
|
||||||
|
"KEKW",
|
||||||
|
"OMEGADANCEBUTFAST",
|
||||||
|
]
|
||||||
|
self.send = channel_send_func
|
||||||
|
|
||||||
|
def get_message_type(self, content: str) -> MessageType:
|
||||||
|
"""Determines the type of message based on its content"""
|
||||||
|
if content.startswith("!"):
|
||||||
|
return MessageType.COMMAND
|
||||||
|
if content in self.hello_patterns:
|
||||||
|
return MessageType.HELLO
|
||||||
|
if content == "Yes":
|
||||||
|
return MessageType.YES
|
||||||
|
if content.startswith("WHAT"):
|
||||||
|
return MessageType.WHAT
|
||||||
|
if content in self.laugh_patterns:
|
||||||
|
return MessageType.LAUGH
|
||||||
|
return MessageType.OTHER
|
||||||
|
|
||||||
|
async def handle_hello(self, author_name: str, hello_svc) -> str:
|
||||||
|
"""Handles hello messages"""
|
||||||
|
return hello_svc.run(author_name)
|
||||||
|
|
||||||
|
async def handle_laugh(self) -> str:
|
||||||
|
"""Handles laugh messages"""
|
||||||
|
return random.choice(self.laugh_patterns) # noqa: S311
|
||||||
|
|
||||||
|
|
||||||
|
class SaveMessagesCog(commands.Cog):
|
||||||
|
def __init__(self, bot):
|
||||||
|
self.bot = bot
|
||||||
|
self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel))
|
||||||
|
self.hello_svc = HelloGeneratorSvc()
|
||||||
|
self.backoff_svc = BackoffService()
|
||||||
|
self.message_handler = MessageHandler(self._send_message)
|
||||||
|
|
||||||
|
# Register a separate send function for each message type
|
||||||
|
self.send_functions = {
|
||||||
|
MessageType.HELLO: self._create_typed_send("hello"),
|
||||||
|
MessageType.YES: self._create_typed_send("yes"),
|
||||||
|
MessageType.WHAT: self._create_typed_send("what"),
|
||||||
|
MessageType.LAUGH: self._create_typed_send("laugh"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Register each send function with its own backoff
|
||||||
|
for func in self.send_functions.values():
|
||||||
|
self.backoff_svc.add_callable(func, backoff_seconds=10)
|
||||||
|
|
||||||
|
def _create_typed_send(self, type_name: str):
|
||||||
|
"""Creates a send function for a specific message type"""
|
||||||
|
|
||||||
|
async def typed_send(content: str):
|
||||||
|
if hasattr(self, "current_message"):
|
||||||
|
await self.current_message.channel.send(content)
|
||||||
|
|
||||||
|
# Set a unique name for the function to ensure it's treated as distinct
|
||||||
|
typed_send.__name__ = f"send_{type_name}"
|
||||||
|
return typed_send
|
||||||
|
|
||||||
|
async def _send_message(self, content: str):
|
||||||
|
"""Generic send message function (for non-backoff uses)"""
|
||||||
|
if hasattr(self, "current_message"):
|
||||||
|
await self.current_message.channel.send(content)
|
||||||
|
|
||||||
|
@commands.Cog.event()
|
||||||
|
async def event_message(self, message):
|
||||||
|
"""Main message event handler"""
|
||||||
|
if not message.author:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Store reference to current message for send functions
|
||||||
|
self.current_message = message
|
||||||
|
|
||||||
|
# Store the message content
|
||||||
|
await self.store_svc.run(message.content)
|
||||||
|
|
||||||
|
# Determine message type and handle accordingly
|
||||||
|
msg_type = self.message_handler.get_message_type(message.content)
|
||||||
|
|
||||||
|
response = None
|
||||||
|
|
||||||
|
match msg_type:
|
||||||
|
case MessageType.COMMAND:
|
||||||
|
return
|
||||||
|
case MessageType.HELLO:
|
||||||
|
response = await self.message_handler.handle_hello(
|
||||||
|
message.author.name, self.hello_svc
|
||||||
|
)
|
||||||
|
case MessageType.YES:
|
||||||
|
response = "Indeed"
|
||||||
|
case MessageType.WHAT:
|
||||||
|
response = "WHAT Ramon"
|
||||||
|
case MessageType.LAUGH:
|
||||||
|
response = await self.message_handler.handle_laugh()
|
||||||
|
case MessageType.OTHER:
|
||||||
|
return
|
||||||
|
|
||||||
|
if response and msg_type in self.send_functions:
|
||||||
|
# Use the type-specific send function
|
||||||
|
await self.backoff_svc.call_async(self.send_functions[msg_type], response)
|
||||||
|
|
||||||
|
|
||||||
class BotsManager:
|
class BotsManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
||||||
111
src/huesoporro/svc/backoff_service.py
Normal file
111
src/huesoporro/svc/backoff_service.py
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class CallableInfo(BaseModel):
|
||||||
|
backoff_seconds: int
|
||||||
|
last_call: float | None = None
|
||||||
|
is_async: bool
|
||||||
|
|
||||||
|
|
||||||
|
class BackoffService(BaseModel):
|
||||||
|
"""Use this service to implement a backoff strategy on random callables.
|
||||||
|
The callable will be called the first time without delay but every subsequent
|
||||||
|
call may be hold off for a given time
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> def callable(x): print(f"foo {x}")
|
||||||
|
>>> backoff_service = BackoffService()
|
||||||
|
>>> backoff_service.add_callable(callable, backoff_time=3)
|
||||||
|
>>> backoff_service.call(callable, "bar") # prints "foo bar"
|
||||||
|
>>> backoff_service.call(callable, "baz") # prints nothing
|
||||||
|
>>> # wait 3 seconds before calling callable again
|
||||||
|
>>> backoff_service.call(callable, "qux") # prints "foo qux"
|
||||||
|
"""
|
||||||
|
|
||||||
|
callables: dict[Callable, CallableInfo] = {}
|
||||||
|
|
||||||
|
def add_callable(self, func: Callable, backoff_seconds: int):
|
||||||
|
"""Adds a callable to the local mapper with its backoff configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to be registered
|
||||||
|
backoff_seconds: The number of seconds to wait between successive calls
|
||||||
|
"""
|
||||||
|
self.callables[func] = CallableInfo(
|
||||||
|
backoff_seconds=backoff_seconds, is_async=self._is_async(func)
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_async(func: Callable) -> bool:
|
||||||
|
"""Checks if the callable is async"""
|
||||||
|
return asyncio.iscoroutinefunction(func)
|
||||||
|
|
||||||
|
def _can_call(self, func: Callable) -> bool:
|
||||||
|
"""Determines if enough time has passed since the last call"""
|
||||||
|
if func not in self.callables:
|
||||||
|
raise ValueError(f"Function {func} not registered with backoff service")
|
||||||
|
|
||||||
|
func_info = self.callables[func]
|
||||||
|
last_call = func_info.last_call
|
||||||
|
|
||||||
|
if last_call is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
elapsed = time.time() - last_call
|
||||||
|
return elapsed >= func_info.backoff_seconds
|
||||||
|
|
||||||
|
def call(self, func: Callable, *args, **kwargs):
|
||||||
|
"""Calls the callable with arguments and returns its result if it isn't held off
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to call
|
||||||
|
*args: Positional arguments for the function
|
||||||
|
**kwargs: Keyword arguments for the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Any]: The result of the function call if executed, None if held off
|
||||||
|
"""
|
||||||
|
if func not in self.callables:
|
||||||
|
raise ValueError(f"Function {func} not registered with backoff service")
|
||||||
|
|
||||||
|
if self.callables[func].is_async:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot call async function with .call(), use .call_async() instead"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._can_call(func):
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
self.callables[func].last_call = time.time()
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def call_async(self, func: Callable, *args, **kwargs):
|
||||||
|
"""Same as .call(...) but for async functions
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The async function to call
|
||||||
|
*args: Positional arguments for the function
|
||||||
|
**kwargs: Keyword arguments for the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Any]: The result of the async function call if executed, None if held off
|
||||||
|
"""
|
||||||
|
if func not in self.callables:
|
||||||
|
raise ValueError(f"Function {func} not registered with backoff service")
|
||||||
|
|
||||||
|
if not self.callables[func].is_async:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot call sync function with .call_async(), use .call() instead"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._can_call(func):
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
self.callables[func].last_call = time.time()
|
||||||
|
return result
|
||||||
|
|
@ -133,11 +133,11 @@ class SentenceGeneratorSvc(BaseModel):
|
||||||
self,
|
self,
|
||||||
sentence: str | None = None,
|
sentence: str | None = None,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
if sentence:
|
split_sentence = tokenize(sentence) if sentence else None
|
||||||
sentence = tokenize(sentence)
|
|
||||||
logger.info(f"Generating sentence from {sentence}")
|
logger.info(f"Generating sentence from {split_sentence}")
|
||||||
sentence, success = self.generate(sentence)
|
generated_sentence, success = self.generate(split_sentence)
|
||||||
logger.info(f"Generated sentence: {sentence}")
|
logger.info(f"Generated sentence: {generated_sentence}")
|
||||||
if not success:
|
if not success:
|
||||||
return None
|
return None
|
||||||
return sentence
|
return generated_sentence
|
||||||
|
|
|
||||||
|
|
@ -6,5 +6,5 @@ from src.huesoporro.infra.db import Database
|
||||||
class RandomQuoteGetterSvc(BaseModel):
|
class RandomQuoteGetterSvc(BaseModel):
|
||||||
db: Database
|
db: Database
|
||||||
|
|
||||||
async def run(self, channel_name: str) -> tuple[str, str]:
|
async def run(self, channel_name: str) -> tuple[str, str] | None:
|
||||||
return await self.db.get_random_quote(channel_name=channel_name)
|
return await self.db.get_random_quote(channel_name=channel_name)
|
||||||
|
|
|
||||||
|
|
@ -11,8 +11,10 @@ class HelloGeneratorSvc(BaseModel):
|
||||||
"Hi",
|
"Hi",
|
||||||
"Bon día",
|
"Bon día",
|
||||||
"Hola mi tremendo elemento",
|
"Hola mi tremendo elemento",
|
||||||
|
"HOLA",
|
||||||
|
"hiii",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def run(self, username: str):
|
def run(self, username: str):
|
||||||
return f"{random.choice(self.hellos)} {username}" # noqa: S311
|
return f"{random.choice(self.hellos)} @{username}" # noqa: S311
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from caribou.migrate import load_migrations
|
||||||
from src.huesoporro.infra.db import Database
|
from src.huesoporro.infra.db import Database
|
||||||
from src.huesoporro.models import ChatbotSettings, TwitchAuth, User
|
from src.huesoporro.models import ChatbotSettings, TwitchAuth, User
|
||||||
from src.huesoporro.settings import Settings
|
from src.huesoporro.settings import Settings
|
||||||
|
from src.huesoporro.svc.backoff_service import BackoffService
|
||||||
from src.huesoporro.svc.is_mod import IsModSvc
|
from src.huesoporro.svc.is_mod import IsModSvc
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -16,7 +17,8 @@ def user() -> User:
|
||||||
user="huesoporro",
|
user="huesoporro",
|
||||||
expires_at=1671234567.0,
|
expires_at=1671234567.0,
|
||||||
twitch_auth=TwitchAuth(
|
twitch_auth=TwitchAuth(
|
||||||
access_token="test_access_token", refresh_token="test_refresh_token"
|
access_token="test_access_token", # noqa: S106
|
||||||
|
refresh_token="test_refresh_token", # noqa: S106
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -27,8 +29,8 @@ def s(tmp_path: Path, user: User) -> Settings:
|
||||||
static_files_path=tmp_path / "static_files",
|
static_files_path=tmp_path / "static_files",
|
||||||
db_filepath=tmp_path / "huesoporro.db",
|
db_filepath=tmp_path / "huesoporro.db",
|
||||||
twitch_client_id="test_client_id",
|
twitch_client_id="test_client_id",
|
||||||
twitch_client_secret="test_client_secret", # type: ignore[arg-type]
|
twitch_client_secret="test_client_secret", # type: ignore[arg-type] # noqa: S106
|
||||||
jwt_secret="test_jwt_secret", # type: ignore[arg-type]
|
jwt_secret="test_jwt_secret", # type: ignore[arg-type] # noqa: S106
|
||||||
allowed_users=[user.user],
|
allowed_users=[user.user],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -54,3 +56,27 @@ async def chatbot_settings(db: Database, user) -> ChatbotSettings:
|
||||||
cbs = ChatbotSettings(mods=[user.user, "allowed_user"])
|
cbs = ChatbotSettings(mods=[user.user, "allowed_user"])
|
||||||
await db.save_chatbot_settings(user=user, chatbot_settings=cbs)
|
await db.save_chatbot_settings(user=user, chatbot_settings=cbs)
|
||||||
return cbs
|
return cbs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def backoff_callable():
|
||||||
|
def foo():
|
||||||
|
return "foo"
|
||||||
|
|
||||||
|
return foo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def async_backoff_callable():
|
||||||
|
async def async_foo():
|
||||||
|
return "async foo"
|
||||||
|
|
||||||
|
return async_foo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def backoff_svc(backoff_callable, async_backoff_callable):
|
||||||
|
backoff_svc = BackoffService()
|
||||||
|
backoff_svc.add_callable(backoff_callable, 3)
|
||||||
|
backoff_svc.add_callable(async_backoff_callable, 3)
|
||||||
|
return backoff_svc
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,6 @@
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from src.huesoporro.models import ChatbotSettings, User
|
from src.huesoporro.models import ChatbotSettings, User
|
||||||
|
|
@ -33,3 +36,61 @@ async def test_is_mod_svc_returns_false_for_user_not_in_modlist(
|
||||||
):
|
):
|
||||||
is_mod = await is_mod_svc.run(user=user, username="TestUser2", channel=user.user)
|
is_mod = await is_mod_svc.run(user=user, username="TestUser2", channel=user.user)
|
||||||
assert not is_mod
|
assert not is_mod
|
||||||
|
|
||||||
|
|
||||||
|
async def test_backoff_svc_returns_for_first_attempt(
|
||||||
|
backoff_svc, backoff_callable, async_backoff_callable
|
||||||
|
):
|
||||||
|
assert backoff_svc.call(backoff_callable) == "foo"
|
||||||
|
|
||||||
|
assert await backoff_svc.call_async(async_backoff_callable) == "async foo"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_backoff_svc_returns_none_for_second_attempt(
|
||||||
|
backoff_svc, backoff_callable, async_backoff_callable
|
||||||
|
):
|
||||||
|
assert backoff_svc.call(backoff_callable) == "foo"
|
||||||
|
assert backoff_svc.call(backoff_callable) is None
|
||||||
|
|
||||||
|
assert await backoff_svc.call_async(async_backoff_callable) == "async foo"
|
||||||
|
assert await backoff_svc.call_async(async_backoff_callable) is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_backoff_svc_returns_for_second_attempt_after_delay(
|
||||||
|
backoff_svc, backoff_callable, async_backoff_callable
|
||||||
|
):
|
||||||
|
assert backoff_svc.call(backoff_callable) == "foo"
|
||||||
|
assert backoff_svc.call(backoff_callable) is None
|
||||||
|
time.sleep(3)
|
||||||
|
assert backoff_svc.call(backoff_callable) == "foo"
|
||||||
|
|
||||||
|
assert await backoff_svc.call_async(async_backoff_callable) == "async foo"
|
||||||
|
assert await backoff_svc.call_async(async_backoff_callable) is None
|
||||||
|
await asyncio.sleep(3)
|
||||||
|
assert await backoff_svc.call_async(async_backoff_callable) == "async foo"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_backoff_svc_raises_value_error_for_unknown_callable(backoff_svc):
|
||||||
|
with pytest.raises(ValueError, match="not registered with backoff service"):
|
||||||
|
backoff_svc.call(lambda: "foo")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_backoff_svc_raises_value_error_for_unknown_async_callable(backoff_svc):
|
||||||
|
with pytest.raises(ValueError, match="not registered with backoff service"):
|
||||||
|
await backoff_svc.call_async(lambda: "foo")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_backoff_svc_raises_value_error_for_async_called_from_sync(
|
||||||
|
backoff_svc, backoff_callable
|
||||||
|
):
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="Cannot call sync function with .call_async()"
|
||||||
|
):
|
||||||
|
await backoff_svc.call_async(backoff_callable)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_backoff_svc_raises_value_error_for_sync_called_from_async(
|
||||||
|
backoff_svc, async_backoff_callable
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="Cannot call async function with .call()"):
|
||||||
|
backoff_svc.call(async_backoff_callable)
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -460,7 +460,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "huesoporro"
|
name = "huesoporro"
|
||||||
version = "0.2.2"
|
version = "0.2.3"
|
||||||
source = { virtual = "." }
|
source = { virtual = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "aiosqlite" },
|
{ name = "aiosqlite" },
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue