tests: add base tests

This commit is contained in:
cătălin 2024-12-18 18:27:46 +01:00
commit 9893d36be3
No known key found for this signature in database
23 changed files with 353 additions and 206 deletions

View file

@ -12,7 +12,7 @@ class StoreQuoteAction(BaseModel):
async def run(
self, user: User, channel: str, quote: str, author: str, username: str
) -> str:
if not await self.is_mod_svc.run(user=user, username=username):
if not await self.is_mod_svc.run(user=user, username=username, channel=channel):
return f"{username} is not a mod and cannot add quotes. Only moderators can add quotes. Sorry!"
await self.quote_storer_svc.run(channel, quote, author)
return f"«{quote}» added by {author}."

View file

@ -30,7 +30,7 @@ def httpx_status_error_handler(_: Request, exc: httpx.HTTPStatusError):
)
async def after_exception_handler(exc: Exception, scope: "Scope") -> None:
async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # noqa: F821
"""Hook function that will be invoked after each exception."""
state = scope["app"].state
if not hasattr(state, "error_count"):

View file

@ -83,3 +83,6 @@ def create_app():
"sbs": Provide(store_chatbot_settings_svc),
},
)
app = create_app()

View file

@ -44,12 +44,15 @@ async def get_tts_permalink(access_token: str) -> Template:
)
async def get_index(user: User, gbs: ChatbotSettingsGetterSvc) -> Template:
chatbot_settings = await gbs.run(user=user)
return Template(template_name="index.html", context=chatbot_settings.model_dump() if chatbot_settings else {})
return Template(
template_name="index.html",
context=chatbot_settings.model_dump() if chatbot_settings else {},
)
@put("/api/v1/bot")
async def manage_bot(
user: User, data: ManageBotDTO, gbs: ChatbotSettingsGetterSvc, bm: BotsManager
user: User, data: ManageBotDTO, gbs: ChatbotSettingsGetterSvc, bm: BotsManager
) -> Response:
chatbot_settings = await gbs.run(user=user)
if data.command == "start":
@ -62,6 +65,7 @@ async def manage_bot(
if data.command == "stop" and user.user in bm.bots:
await bm.stop_user_bot(user)
return Response({"message": "Bot stopped"})
return Response({"message": "Invalid command"}, status_code=400)
@get("/api/v1/bot")
@ -73,14 +77,14 @@ async def get_bot_status(user: User, bm: BotsManager) -> dict:
@get("/api/v1/bot/settings")
async def get_bot_settings(
user: User, gbs: ChatbotSettingsGetterSvc
user: User, gbs: ChatbotSettingsGetterSvc
) -> ChatbotSettings:
return await gbs.run(user=user)
@put("/api/v1/bot/settings")
async def save_bot_settings(
user: User, data: ChatbotSettings, sbs: ChatbotSettingsStorerSvc
user: User, data: ChatbotSettings, sbs: ChatbotSettingsStorerSvc
) -> dict:
await sbs.run(user=user, bot_settings=data)
return {"status": "ok"}

View file

@ -31,7 +31,7 @@ class Bot(commands.Bot):
)
self.get_random_quote_svc = RandomQuoteGetterSvc(db=db)
self.cbs = chatbot_settings
self.quote_routine = routines.routine(
seconds=chatbot_settings.automatic_quote_timer, wait_first=True
)(self.send_quote)
@ -43,16 +43,20 @@ class Bot(commands.Bot):
logger.info(f"Logged in as {self.nick}")
logger.info(f"User id is {self.user_id}")
@commands.command()
async def hello(self, ctx: commands.Context, user: User | None = None):
username = user.name if user else ctx.author.name
@commands.command(aliases=["h"])
async def hello(self, ctx: commands.Context, username: str | None = None):
username = username or ctx.author.name
await ctx.send(self.hello_svc.run(username))
@commands.command(aliases=["g"])
async def generate(self, ctx: commands.Context, *, words: str | None = None):
sentence = await self.generate_svc.run(words)
if sentence:
await ctx.send(sentence)
if not sentence:
logger.warning(
f"Could not generate sentence for {words or 'no words provided'}"
)
return
await ctx.send(sentence)
@commands.command(aliases=["qadd"])
async def add_quote(self, ctx: commands.Context, *, quote: str):
@ -91,9 +95,12 @@ class Bot(commands.Bot):
await channel.send(sentence)
def start_routines(self):
logger.info("Starting routines")
self.quote_routine.start(stop_on_error=False)
self.generation_routine.start(stop_on_error=False)
if self.cbs.automatic_quote_timer > 0:
logger.info("Starting quote routine")
self.quote_routine.start(stop_on_error=False)
if self.cbs.automatic_generation_timer > 0:
logger.info("Starting generation routine")
self.generation_routine.start(stop_on_error=False)
def stop_routines(self):
logger.info("Stopping routines")

View file

@ -26,7 +26,7 @@ class TwitchAuthenticator(BaseModel):
headers={"Accept": "application/json"},
)
if auto_refresh and response.status_code == 401:
if auto_refresh and response.status_code == 401: # noqa: PLR2004
return await self.refresh_token(response.json()["refresh_token"])
response.raise_for_status()

View file

@ -2,11 +2,11 @@ import datetime
from contextlib import asynccontextmanager
import aiosqlite
from loguru import logger
from pydantic import BaseModel, Field
from src.huesoporro.models import ChatbotSettings, User
from src.huesoporro.settings import Settings
from loguru import logger
class Database(BaseModel):
@ -27,7 +27,7 @@ class Database(BaseModel):
async def save_user(self, user: User, auto_commit=True):
async with self.get_client(auto_commit=auto_commit) as db:
async with db.execute(
"SELECT * FROM users WHERE user = ?", (user.user,)
"SELECT * FROM users WHERE user = ?", (user.user,)
) as cursor:
result = await cursor.fetchone()
if result:
@ -62,7 +62,7 @@ class Database(BaseModel):
)
async def save_chatbot_settings(
self, user: User, chatbot_settings: ChatbotSettings, auto_commit: bool = True
self, user: User, chatbot_settings: ChatbotSettings, auto_commit: bool = True
):
async with self.get_client(auto_commit=auto_commit) as db:
current_settings = await self.get_chatbot_settings(user)
@ -109,7 +109,7 @@ class Database(BaseModel):
async with self.get_client() as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"SELECT * FROM settings WHERE user_id = ?", (user.user,)
"SELECT * FROM settings WHERE user_id = ?", (user.user,)
) as cursor:
result = await cursor.fetchone()
if not result:
@ -124,11 +124,12 @@ class Database(BaseModel):
)
await db.commit()
async def get_random_quote(self, channel_name: str):
async with self.get_client() as db:
async with db.execute(
"SELECT quote, author FROM quotes WHERE channel = ? ORDER BY RANDOM() LIMIT 1",
(channel_name,),
) as cursor:
result = await cursor.fetchone()
return result
async def get_random_quote(self, channel_name: str) -> tuple[str, str] | None:
async with (
self.get_client() as db,
db.execute(
"SELECT quote, author FROM quotes WHERE channel = ? ORDER BY RANDOM() LIMIT 1",
(channel_name,),
) as cursor,
):
return await cursor.fetchone()

View file

@ -266,7 +266,7 @@ class Database:
);
""")
self.add_execute_queue(
f'INSERT INTO MarkovGrammar{first_char}{second_char} SELECT * FROM MarkovGrammar{first_char} WHERE word2 LIKE "{second_char}%";',
f'INSERT INTO MarkovGrammar{first_char}{second_char} SELECT * FROM MarkovGrammar{first_char} WHERE word2 LIKE "{second_char}%";', # noqa: S608
)
self.add_execute_queue(
f'DELETE FROM MarkovGrammar{first_char} WHERE word2 LIKE "{second_char}%";', # noqa: S608

View file

@ -1,9 +1,7 @@
import uvicorn
from src.huesoporro.api.main import create_app
from src.huesoporro.settings import Settings
if __name__ == "__main__":
settings = Settings.get()
app = create_app()
uvicorn.run(app, host=settings.host, port=settings.port)
uvicorn.run("src.huesoporro.api.main:app", host=settings.host, port=settings.port)

View file

@ -1,7 +1,7 @@
from typing import Self
import jwt
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, Field, field_validator
from src.huesoporro.settings import Settings
@ -36,10 +36,12 @@ class User(BaseModel):
class ChatbotSettings(BaseModel):
automatic_generation_timer: int = 300
automatic_quote_timer: int = 500
mods: list[str] | None = None
mods: list[str] = Field(default_factory=list)
@property
def mods_as_string(self):
if not self.mods:
return ""
return ",".join(self.mods)
@field_validator("mods", mode="before")

View file

@ -27,7 +27,7 @@ class Settings(BaseSettings):
default_factory=lambda: ["channel:bot", "chat:edit", "chat:read"]
)
allowed_users: list[str] | str = Field(default_factory=lambda: ["huesoporro"])
server_hostname: HttpUrl = "http://localhost:8000"
server_hostname: HttpUrl = "http://localhost:8000" # type: ignore[assignment]
@staticmethod
@lru_cache(maxsize=1)

View file

@ -15,18 +15,6 @@ class SentenceGeneratorSvc(BaseModel):
sentence_separator: str = " "
model_config = ConfigDict(arbitrary_types_allowed=True)
def is_mod(self, username: str, channel: str) -> bool:
"""True if the user is a moderator.
Args:
username (str): The name of the user to check
channel (str): The name of the channel
Returns:
bool: True if the user is a moderator.
"""
return username in self.s.mods or username == channel
@staticmethod
def get_sentence_length(sentences: list[list[str]]) -> int:
"""Given a list of tokens representing a sentence, return the number of words in there.
@ -63,10 +51,6 @@ class SentenceGeneratorSvc(BaseModel):
# e.g. when the first sentence has fewer words than self.min_sentence_length.
sentences: list[list | list[str]] = [[]]
# Check for commands or recursion, eg: !generate !generate
if len(params) > 0:
return "You can't make me do commands, you madman!", False
# Get the starting key and starting sentence.
# If there is more than 1 param, get the last 2 as the key.
# Note that self.key_length is fixed to 2 in this implementation
@ -148,11 +132,12 @@ class SentenceGeneratorSvc(BaseModel):
async def run(
self,
sentence: str | None = None,
) -> str|None:
) -> str | None:
if sentence:
sentence = tokenize(sentence)
logger.info(f"Generating sentence from {sentence}")
sentence, success = self.generate(sentence)
logger.info(f"Generated sentence: {sentence}")
if success:
return sentence
if not success:
return None
return sentence

View file

@ -4,7 +4,15 @@ from pydantic import BaseModel, Field
class HelloGeneratorSvc(BaseModel):
hellos: list[str] = Field(default_factory=lambda: ["Hola", "Ayo", "Hi", "Bon día"])
hellos: list[str] = Field(
default_factory=lambda: [
"Hola",
"Ayo",
"Hi",
"Bon día",
"Hola mi tremendo elemento",
]
)
def run(self, username: str):
return f"{random.choice(self.hellos)} {username}"
return f"{random.choice(self.hellos)} {username}" # noqa: S311

View file

@ -7,6 +7,14 @@ from src.huesoporro.models import User
class IsModSvc(BaseModel):
db: Database
async def run(self, user: User, username: str) -> bool:
async def run(self, user: User, username: str, channel: str) -> bool:
"""A user given username is a mod if they're the same as the current channel or if they're in the modlist
available in a user's settings"""
if channel == username:
return True
chatbot_settings = await self.db.get_chatbot_settings(user=user)
if not chatbot_settings:
return False
return username in chatbot_settings.mods

View file

@ -23,6 +23,7 @@ class SentenceStorerSvc(BaseModel):
import nltk
nltk.download("punkt")
nltk.download("punkt_tab")
logger.debug("Downloaded required punkt resource.")
sentences = sent_tokenize(stripped_message)

View file

@ -7,9 +7,7 @@ from src.huesoporro.models import ChatbotSettings, User
class ChatbotSettingsStorerSvc(BaseModel):
db: Database
async def run(
self, user: User, bot_settings: ChatbotSettings
) -> dict[str, str | int | None] | None:
async def run(self, user: User, bot_settings: ChatbotSettings):
return await self.db.save_chatbot_settings(
user=user, chatbot_settings=bot_settings
)