tests: add base tests
This commit is contained in:
parent
4c534de47b
commit
9893d36be3
23 changed files with 353 additions and 206 deletions
|
|
@ -12,7 +12,7 @@ class StoreQuoteAction(BaseModel):
|
|||
async def run(
|
||||
self, user: User, channel: str, quote: str, author: str, username: str
|
||||
) -> str:
|
||||
if not await self.is_mod_svc.run(user=user, username=username):
|
||||
if not await self.is_mod_svc.run(user=user, username=username, channel=channel):
|
||||
return f"{username} is not a mod and cannot add quotes. Only moderators can add quotes. Sorry!"
|
||||
await self.quote_storer_svc.run(channel, quote, author)
|
||||
return f"«{quote}» added by {author}."
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ def httpx_status_error_handler(_: Request, exc: httpx.HTTPStatusError):
|
|||
)
|
||||
|
||||
|
||||
async def after_exception_handler(exc: Exception, scope: "Scope") -> None:
|
||||
async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # noqa: F821
|
||||
"""Hook function that will be invoked after each exception."""
|
||||
state = scope["app"].state
|
||||
if not hasattr(state, "error_count"):
|
||||
|
|
|
|||
|
|
@ -83,3 +83,6 @@ def create_app():
|
|||
"sbs": Provide(store_chatbot_settings_svc),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
|
|
|||
|
|
@ -44,12 +44,15 @@ async def get_tts_permalink(access_token: str) -> Template:
|
|||
)
|
||||
async def get_index(user: User, gbs: ChatbotSettingsGetterSvc) -> Template:
|
||||
chatbot_settings = await gbs.run(user=user)
|
||||
return Template(template_name="index.html", context=chatbot_settings.model_dump() if chatbot_settings else {})
|
||||
return Template(
|
||||
template_name="index.html",
|
||||
context=chatbot_settings.model_dump() if chatbot_settings else {},
|
||||
)
|
||||
|
||||
|
||||
@put("/api/v1/bot")
|
||||
async def manage_bot(
|
||||
user: User, data: ManageBotDTO, gbs: ChatbotSettingsGetterSvc, bm: BotsManager
|
||||
user: User, data: ManageBotDTO, gbs: ChatbotSettingsGetterSvc, bm: BotsManager
|
||||
) -> Response:
|
||||
chatbot_settings = await gbs.run(user=user)
|
||||
if data.command == "start":
|
||||
|
|
@ -62,6 +65,7 @@ async def manage_bot(
|
|||
if data.command == "stop" and user.user in bm.bots:
|
||||
await bm.stop_user_bot(user)
|
||||
return Response({"message": "Bot stopped"})
|
||||
return Response({"message": "Invalid command"}, status_code=400)
|
||||
|
||||
|
||||
@get("/api/v1/bot")
|
||||
|
|
@ -73,14 +77,14 @@ async def get_bot_status(user: User, bm: BotsManager) -> dict:
|
|||
|
||||
@get("/api/v1/bot/settings")
|
||||
async def get_bot_settings(
|
||||
user: User, gbs: ChatbotSettingsGetterSvc
|
||||
user: User, gbs: ChatbotSettingsGetterSvc
|
||||
) -> ChatbotSettings:
|
||||
return await gbs.run(user=user)
|
||||
|
||||
|
||||
@put("/api/v1/bot/settings")
|
||||
async def save_bot_settings(
|
||||
user: User, data: ChatbotSettings, sbs: ChatbotSettingsStorerSvc
|
||||
user: User, data: ChatbotSettings, sbs: ChatbotSettingsStorerSvc
|
||||
) -> dict:
|
||||
await sbs.run(user=user, bot_settings=data)
|
||||
return {"status": "ok"}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class Bot(commands.Bot):
|
|||
)
|
||||
|
||||
self.get_random_quote_svc = RandomQuoteGetterSvc(db=db)
|
||||
|
||||
self.cbs = chatbot_settings
|
||||
self.quote_routine = routines.routine(
|
||||
seconds=chatbot_settings.automatic_quote_timer, wait_first=True
|
||||
)(self.send_quote)
|
||||
|
|
@ -43,16 +43,20 @@ class Bot(commands.Bot):
|
|||
logger.info(f"Logged in as {self.nick}")
|
||||
logger.info(f"User id is {self.user_id}")
|
||||
|
||||
@commands.command()
|
||||
async def hello(self, ctx: commands.Context, user: User | None = None):
|
||||
username = user.name if user else ctx.author.name
|
||||
@commands.command(aliases=["h"])
|
||||
async def hello(self, ctx: commands.Context, username: str | None = None):
|
||||
username = username or ctx.author.name
|
||||
await ctx.send(self.hello_svc.run(username))
|
||||
|
||||
@commands.command(aliases=["g"])
|
||||
async def generate(self, ctx: commands.Context, *, words: str | None = None):
|
||||
sentence = await self.generate_svc.run(words)
|
||||
if sentence:
|
||||
await ctx.send(sentence)
|
||||
if not sentence:
|
||||
logger.warning(
|
||||
f"Could not generate sentence for {words or 'no words provided'}"
|
||||
)
|
||||
return
|
||||
await ctx.send(sentence)
|
||||
|
||||
@commands.command(aliases=["qadd"])
|
||||
async def add_quote(self, ctx: commands.Context, *, quote: str):
|
||||
|
|
@ -91,9 +95,12 @@ class Bot(commands.Bot):
|
|||
await channel.send(sentence)
|
||||
|
||||
def start_routines(self):
|
||||
logger.info("Starting routines")
|
||||
self.quote_routine.start(stop_on_error=False)
|
||||
self.generation_routine.start(stop_on_error=False)
|
||||
if self.cbs.automatic_quote_timer > 0:
|
||||
logger.info("Starting quote routine")
|
||||
self.quote_routine.start(stop_on_error=False)
|
||||
if self.cbs.automatic_generation_timer > 0:
|
||||
logger.info("Starting generation routine")
|
||||
self.generation_routine.start(stop_on_error=False)
|
||||
|
||||
def stop_routines(self):
|
||||
logger.info("Stopping routines")
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class TwitchAuthenticator(BaseModel):
|
|||
headers={"Accept": "application/json"},
|
||||
)
|
||||
|
||||
if auto_refresh and response.status_code == 401:
|
||||
if auto_refresh and response.status_code == 401: # noqa: PLR2004
|
||||
return await self.refresh_token(response.json()["refresh_token"])
|
||||
|
||||
response.raise_for_status()
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import datetime
|
|||
from contextlib import asynccontextmanager
|
||||
|
||||
import aiosqlite
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.huesoporro.models import ChatbotSettings, User
|
||||
from src.huesoporro.settings import Settings
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class Database(BaseModel):
|
||||
|
|
@ -27,7 +27,7 @@ class Database(BaseModel):
|
|||
async def save_user(self, user: User, auto_commit=True):
|
||||
async with self.get_client(auto_commit=auto_commit) as db:
|
||||
async with db.execute(
|
||||
"SELECT * FROM users WHERE user = ?", (user.user,)
|
||||
"SELECT * FROM users WHERE user = ?", (user.user,)
|
||||
) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
if result:
|
||||
|
|
@ -62,7 +62,7 @@ class Database(BaseModel):
|
|||
)
|
||||
|
||||
async def save_chatbot_settings(
|
||||
self, user: User, chatbot_settings: ChatbotSettings, auto_commit: bool = True
|
||||
self, user: User, chatbot_settings: ChatbotSettings, auto_commit: bool = True
|
||||
):
|
||||
async with self.get_client(auto_commit=auto_commit) as db:
|
||||
current_settings = await self.get_chatbot_settings(user)
|
||||
|
|
@ -109,7 +109,7 @@ class Database(BaseModel):
|
|||
async with self.get_client() as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT * FROM settings WHERE user_id = ?", (user.user,)
|
||||
"SELECT * FROM settings WHERE user_id = ?", (user.user,)
|
||||
) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
if not result:
|
||||
|
|
@ -124,11 +124,12 @@ class Database(BaseModel):
|
|||
)
|
||||
await db.commit()
|
||||
|
||||
async def get_random_quote(self, channel_name: str):
|
||||
async with self.get_client() as db:
|
||||
async with db.execute(
|
||||
"SELECT quote, author FROM quotes WHERE channel = ? ORDER BY RANDOM() LIMIT 1",
|
||||
(channel_name,),
|
||||
) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
return result
|
||||
async def get_random_quote(self, channel_name: str) -> tuple[str, str] | None:
|
||||
async with (
|
||||
self.get_client() as db,
|
||||
db.execute(
|
||||
"SELECT quote, author FROM quotes WHERE channel = ? ORDER BY RANDOM() LIMIT 1",
|
||||
(channel_name,),
|
||||
) as cursor,
|
||||
):
|
||||
return await cursor.fetchone()
|
||||
|
|
|
|||
|
|
@ -266,7 +266,7 @@ class Database:
|
|||
);
|
||||
""")
|
||||
self.add_execute_queue(
|
||||
f'INSERT INTO MarkovGrammar{first_char}{second_char} SELECT * FROM MarkovGrammar{first_char} WHERE word2 LIKE "{second_char}%";',
|
||||
f'INSERT INTO MarkovGrammar{first_char}{second_char} SELECT * FROM MarkovGrammar{first_char} WHERE word2 LIKE "{second_char}%";', # noqa: S608
|
||||
)
|
||||
self.add_execute_queue(
|
||||
f'DELETE FROM MarkovGrammar{first_char} WHERE word2 LIKE "{second_char}%";', # noqa: S608
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
import uvicorn
|
||||
|
||||
from src.huesoporro.api.main import create_app
|
||||
from src.huesoporro.settings import Settings
|
||||
|
||||
if __name__ == "__main__":
|
||||
settings = Settings.get()
|
||||
app = create_app()
|
||||
uvicorn.run(app, host=settings.host, port=settings.port)
|
||||
uvicorn.run("src.huesoporro.api.main:app", host=settings.host, port=settings.port)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Self
|
||||
|
||||
import jwt
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from src.huesoporro.settings import Settings
|
||||
|
||||
|
|
@ -36,10 +36,12 @@ class User(BaseModel):
|
|||
class ChatbotSettings(BaseModel):
|
||||
automatic_generation_timer: int = 300
|
||||
automatic_quote_timer: int = 500
|
||||
mods: list[str] | None = None
|
||||
mods: list[str] = Field(default_factory=list)
|
||||
|
||||
@property
|
||||
def mods_as_string(self):
|
||||
if not self.mods:
|
||||
return ""
|
||||
return ",".join(self.mods)
|
||||
|
||||
@field_validator("mods", mode="before")
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class Settings(BaseSettings):
|
|||
default_factory=lambda: ["channel:bot", "chat:edit", "chat:read"]
|
||||
)
|
||||
allowed_users: list[str] | str = Field(default_factory=lambda: ["huesoporro"])
|
||||
server_hostname: HttpUrl = "http://localhost:8000"
|
||||
server_hostname: HttpUrl = "http://localhost:8000" # type: ignore[assignment]
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=1)
|
||||
|
|
|
|||
|
|
@ -15,18 +15,6 @@ class SentenceGeneratorSvc(BaseModel):
|
|||
sentence_separator: str = " "
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def is_mod(self, username: str, channel: str) -> bool:
|
||||
"""True if the user is a moderator.
|
||||
|
||||
Args:
|
||||
username (str): The name of the user to check
|
||||
channel (str): The name of the channel
|
||||
|
||||
Returns:
|
||||
bool: True if the user is a moderator.
|
||||
"""
|
||||
return username in self.s.mods or username == channel
|
||||
|
||||
@staticmethod
|
||||
def get_sentence_length(sentences: list[list[str]]) -> int:
|
||||
"""Given a list of tokens representing a sentence, return the number of words in there.
|
||||
|
|
@ -63,10 +51,6 @@ class SentenceGeneratorSvc(BaseModel):
|
|||
# e.g. when the first sentence has fewer words than self.min_sentence_length.
|
||||
sentences: list[list | list[str]] = [[]]
|
||||
|
||||
# Check for commands or recursion, eg: !generate !generate
|
||||
if len(params) > 0:
|
||||
return "You can't make me do commands, you madman!", False
|
||||
|
||||
# Get the starting key and starting sentence.
|
||||
# If there is more than 1 param, get the last 2 as the key.
|
||||
# Note that self.key_length is fixed to 2 in this implementation
|
||||
|
|
@ -148,11 +132,12 @@ class SentenceGeneratorSvc(BaseModel):
|
|||
async def run(
|
||||
self,
|
||||
sentence: str | None = None,
|
||||
) -> str|None:
|
||||
) -> str | None:
|
||||
if sentence:
|
||||
sentence = tokenize(sentence)
|
||||
logger.info(f"Generating sentence from {sentence}")
|
||||
sentence, success = self.generate(sentence)
|
||||
logger.info(f"Generated sentence: {sentence}")
|
||||
if success:
|
||||
return sentence
|
||||
if not success:
|
||||
return None
|
||||
return sentence
|
||||
|
|
|
|||
|
|
@ -4,7 +4,15 @@ from pydantic import BaseModel, Field
|
|||
|
||||
|
||||
class HelloGeneratorSvc(BaseModel):
|
||||
hellos: list[str] = Field(default_factory=lambda: ["Hola", "Ayo", "Hi", "Bon día"])
|
||||
hellos: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"Hola",
|
||||
"Ayo",
|
||||
"Hi",
|
||||
"Bon día",
|
||||
"Hola mi tremendo elemento",
|
||||
]
|
||||
)
|
||||
|
||||
def run(self, username: str):
|
||||
return f"{random.choice(self.hellos)} {username}"
|
||||
return f"{random.choice(self.hellos)} {username}" # noqa: S311
|
||||
|
|
|
|||
|
|
@ -7,6 +7,14 @@ from src.huesoporro.models import User
|
|||
class IsModSvc(BaseModel):
|
||||
db: Database
|
||||
|
||||
async def run(self, user: User, username: str) -> bool:
|
||||
async def run(self, user: User, username: str, channel: str) -> bool:
|
||||
"""A user given username is a mod if they're the same as the current channel or if they're in the modlist
|
||||
available in a user's settings"""
|
||||
|
||||
if channel == username:
|
||||
return True
|
||||
|
||||
chatbot_settings = await self.db.get_chatbot_settings(user=user)
|
||||
if not chatbot_settings:
|
||||
return False
|
||||
return username in chatbot_settings.mods
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ class SentenceStorerSvc(BaseModel):
|
|||
import nltk
|
||||
|
||||
nltk.download("punkt")
|
||||
nltk.download("punkt_tab")
|
||||
logger.debug("Downloaded required punkt resource.")
|
||||
sentences = sent_tokenize(stripped_message)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,9 +7,7 @@ from src.huesoporro.models import ChatbotSettings, User
|
|||
class ChatbotSettingsStorerSvc(BaseModel):
|
||||
db: Database
|
||||
|
||||
async def run(
|
||||
self, user: User, bot_settings: ChatbotSettings
|
||||
) -> dict[str, str | int | None] | None:
|
||||
async def run(self, user: User, bot_settings: ChatbotSettings):
|
||||
return await self.db.save_chatbot_settings(
|
||||
user=user, chatbot_settings=bot_settings
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue