Compare commits

...

11 commits

24 changed files with 2149 additions and 1054 deletions

View file

@ -2,7 +2,7 @@ files: src|tests
exclude: ^$ exclude: ^$
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0 rev: v6.0.0
hooks: hooks:
- id: trailing-whitespace - id: trailing-whitespace
args: [ --markdown-linebreak-ext=md ] args: [ --markdown-linebreak-ext=md ]

View file

@ -1 +1 @@
3.11 3.13

View file

@ -2,6 +2,41 @@
All notable changes to this project will be documented in this file. All notable changes to this project will be documented in this file.
## [0.3.6] - 2025-06-06
### 🚀 Features
- Add `Quote.is_active` field
## [0.3.5] - 2025-05-27
### 🚀 Features
- Implement remaining repo methods for chatbot and quote
- Add ANO_PREFIX bot response
### ⚙️ Miscellaneous Tasks
- Add renovate lockFileMaintenance
## [0.3.3] - 2025-03-06
### 🐛 Bug Fixes
- Have uvicorn use Settings.port and Settings.host
## [0.3.2] - 2025-03-06
### ⚙️ Miscellaneous Tasks
- Update charts to v0.3.2
## [0.3.1] - 2025-03-06
### 🐛 Bug Fixes
- Correctly execute `make serve`
## [0.3.0] - 2025-03-06 ## [0.3.0] - 2025-03-06
### 🚀 Features ### 🚀 Features

View file

@ -16,7 +16,7 @@ ENV PYTHONPATH="$APP_PATH"
ENV PATH="$APP_HOME/.local/bin:$PATH" ENV PATH="$APP_HOME/.local/bin:$PATH"
# hadolint ignore=DL3001,DL3008,DL3018 # hadolint ignore=DL3001,DL3008,DL3018
RUN apk add --no-cache make python3~=3.12 \ RUN apk add --no-cache make python3~=3.13 curl git \
&& adduser -S -u "$USERID" -h "$APP_HOME" "$USERNAME" \ && adduser -S -u "$USERID" -h "$APP_HOME" "$USERNAME" \
&& mkdir -p "$APP_PATH" \ && mkdir -p "$APP_PATH" \
&& chown -R "$USERID:$GROUPID" "$APP_PATH" && chown -R "$USERID:$GROUPID" "$APP_PATH"
@ -40,6 +40,7 @@ FROM base AS serve
CMD ["make", "serve"] CMD ["make", "serve"]
FROM base AS migrate FROM base AS migrate
CMD ["make", "migrate"] CMD ["make", "migrate"]

View file

@ -11,6 +11,9 @@ fmt--mypy:
fmt--add-noqa: fmt--add-noqa:
uvx ruff check --add-noqa . uvx ruff check --add-noqa .
fmt--autoupdate:
uvx pre-commit autoupdate
.PHONY: tests .PHONY: tests
tests: tests:
@ -19,7 +22,7 @@ tests:
uv run coverage xml uv run coverage xml
serve: serve:
uv run uvicorn src.apps.httpapi.litestar.main:app uv run python src/apps/httpapi/litestar/main.py
build: build:
docker build . -t git.roboces.dev/catalin/$(PROJECT_NAME):$(PROJECT_TAG) --target $(PROJECT_TARGET) docker build . -t git.roboces.dev/catalin/$(PROJECT_NAME):$(PROJECT_TAG) --target $(PROJECT_TARGET)

View file

@ -1,6 +1,6 @@
apiVersion: v2 apiVersion: v2
appVersion: 0.3.0 appVersion: 0.3.7
description: A Helm chart for Kubernetes description: A Helm chart for Kubernetes
name: huesoporro name: huesoporro
type: application type: application
version: 0.3.0 version: 0.3.7

View file

@ -8,7 +8,7 @@ fullnameOverride: ''
image: image:
pullPolicy: Always pullPolicy: Always
repository: git.roboces.dev/catalin/huesoporro repository: git.roboces.dev/catalin/huesoporro
tag: 0.3.0 tag: 0.3.7
imagePullSecrets: [] imagePullSecrets: []
ingress: ingress:
annotations: {} annotations: {}

View file

@ -19,10 +19,10 @@
"flake-compat": { "flake-compat": {
"flake": false, "flake": false,
"locked": { "locked": {
"lastModified": 1733328505, "lastModified": 1747046372,
"owner": "edolstra", "owner": "edolstra",
"repo": "flake-compat", "repo": "flake-compat",
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -46,10 +46,31 @@
"type": "github" "type": "github"
} }
}, },
"git-hooks": {
"inputs": {
"flake-compat": "flake-compat",
"gitignore": "gitignore",
"nixpkgs": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1747372754,
"owner": "cachix",
"repo": "git-hooks.nix",
"rev": "80479b6ec16fefd9c1db3ea13aeb038c60530f46",
"type": "github"
},
"original": {
"owner": "cachix",
"repo": "git-hooks.nix",
"type": "github"
}
},
"gitignore": { "gitignore": {
"inputs": { "inputs": {
"nixpkgs": [ "nixpkgs": [
"pre-commit-hooks", "git-hooks",
"nixpkgs" "nixpkgs"
] ]
}, },
@ -83,7 +104,7 @@
}, },
"nixpkgs-python": { "nixpkgs-python": {
"inputs": { "inputs": {
"flake-compat": "flake-compat", "flake-compat": "flake-compat_2",
"nixpkgs": [ "nixpkgs": [
"nixpkgs" "nixpkgs"
] ]
@ -101,33 +122,15 @@
"type": "github" "type": "github"
} }
}, },
"pre-commit-hooks": {
"inputs": {
"flake-compat": "flake-compat_2",
"gitignore": "gitignore",
"nixpkgs": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1737465171,
"owner": "cachix",
"repo": "pre-commit-hooks.nix",
"rev": "9364dc02281ce2d37a1f55b6e51f7c0f65a75f17",
"type": "github"
},
"original": {
"owner": "cachix",
"repo": "pre-commit-hooks.nix",
"type": "github"
}
},
"root": { "root": {
"inputs": { "inputs": {
"devenv": "devenv", "devenv": "devenv",
"git-hooks": "git-hooks",
"nixpkgs": "nixpkgs", "nixpkgs": "nixpkgs",
"nixpkgs-python": "nixpkgs-python", "nixpkgs-python": "nixpkgs-python",
"pre-commit-hooks": "pre-commit-hooks" "pre-commit-hooks": [
"git-hooks"
]
} }
} }
}, },

View file

@ -0,0 +1,21 @@
"""
This module contains a Caribou migration.
Migration Name: active_quotes
Migration Version: 20250606143836
"""
def upgrade(connection):
# add `is_active` column to the `quotes` table
sql = """
ALTER TABLE quotes
ADD COLUMN is_active BOOLEAN DEFAULT TRUE;
"""
connection.execute(sql)
connection.commit()
def downgrade(connection):
# add your downgrade step here
pass

View file

@ -1,6 +1,6 @@
[project] [project]
name = "huesoporro" name = "huesoporro"
version = "0.3.0" version = "0.3.7"
description = "Misc Twitch bot" description = "Misc Twitch bot"
readme = "README.md" readme = "README.md"
authors = [ authors = [
@ -19,11 +19,13 @@ dependencies = [
"caribou>=0.4.1", "caribou>=0.4.1",
"aiosqlite>=0.20.0", "aiosqlite>=0.20.0",
"pyjwt>=2.10.1", "pyjwt>=2.10.1",
"twitchio>=2.10.0", "twitchio==2.10.0",
"redis>=5.2.1", "redis>=5.2.1",
"pytz>=2024.2", "pytz>=2024.2",
"discord-py>=2.4.0", "discord-py>=2.4.0",
"tenacity>=9.0.0", "tenacity>=9.0.0",
"uvicorn>=0.34.0",
"sniffio>=1.3.1",
] ]
[project.scripts] [project.scripts]

6
renovate.json5 Normal file
View file

@ -0,0 +1,6 @@
{
$schema: "https://docs.renovatebot.com/renovate-schema.json",
lockFileMaintenance: {
enabled: true,
},
}

View file

@ -13,7 +13,7 @@ app = Typer()
@app.command() @app.command()
def import_vod_cc(channel_name: str, youtube_url: str, db_path: Path | None = None): def import_vod(channel_name: str, youtube_url: str, db_path: Path | None = None):
logger.info(f"Importing VOD closed captions for {channel_name} from {youtube_url}") logger.info(f"Importing VOD closed captions for {channel_name} from {youtube_url}")
s = Settings.get(db_filepath=db_path) s = Settings.get(db_filepath=db_path)
import_from_vod_action = ImportFromVODAction( import_from_vod_action = ImportFromVODAction(

View file

@ -7,6 +7,7 @@ from huesoporro.actions.chatbot.create_or_update_chatbot import (
from huesoporro.actions.chatbot.get_chatbot_by_user_id import GetChatbotByUserIdAction from huesoporro.actions.chatbot.get_chatbot_by_user_id import GetChatbotByUserIdAction
from huesoporro.actions.users.authenticate_user import AuthenticateUserAction from huesoporro.actions.users.authenticate_user import AuthenticateUserAction
from huesoporro.actions.users.get_user_by_jwt import GetUserByJWTAction from huesoporro.actions.users.get_user_by_jwt import GetUserByJWTAction
from huesoporro.bot import BotsManager
from huesoporro.infra.authenticator import TwitchAuthenticator from huesoporro.infra.authenticator import TwitchAuthenticator
from huesoporro.infra.repos import ChatbotRepo, UserRepo from huesoporro.infra.repos import ChatbotRepo, UserRepo
from huesoporro.libs.db import MarkovDatabase from huesoporro.libs.db import MarkovDatabase
@ -28,29 +29,29 @@ from huesoporro.svc.users_svcs import (
) )
def get_settings() -> Settings: async def get_settings() -> Settings:
return Settings.get() return Settings.get()
def get_authenticator(s: Settings) -> TwitchAuthenticator: async def get_authenticator(s: Settings) -> TwitchAuthenticator:
return TwitchAuthenticator(s=s) return TwitchAuthenticator(s=s)
def get_chatbot_repo(s: Settings): async def get_chatbot_repo(s: Settings):
return ChatbotRepo(s=s) return ChatbotRepo(s=s)
def get_get_chatbot_by_user_id_svc(chatbot_repo: ChatbotRepo): async def get_get_chatbot_by_user_id_svc(chatbot_repo: ChatbotRepo):
return GetChatbotByUserIdSvc(repo=chatbot_repo) return GetChatbotByUserIdSvc(repo=chatbot_repo)
def get_get_tokens_by_auth_code_svc( async def get_get_tokens_by_auth_code_svc(
twitch_authenticator: TwitchAuthenticator, s: Settings twitch_authenticator: TwitchAuthenticator, s: Settings
): ):
return GetTwitchAuthByAuthCodeSvc(s=s, authenticator=twitch_authenticator) return GetTwitchAuthByAuthCodeSvc(s=s, authenticator=twitch_authenticator)
def get_create_chatbot_svc(chatbot_repo: ChatbotRepo): async def get_create_chatbot_svc(chatbot_repo: ChatbotRepo):
return CreateChatbotSvc(repo=chatbot_repo) return CreateChatbotSvc(repo=chatbot_repo)
@ -58,19 +59,19 @@ async def get_user_repo(s: Settings):
return UserRepo(s=s) return UserRepo(s=s)
def get_create_user_svc(user_repo: UserRepo): async def get_create_user_svc(user_repo: UserRepo):
return CreateUserSvc(user_repo=user_repo) return CreateUserSvc(user_repo=user_repo)
def get_update_user_svc(user_repo: UserRepo): async def get_update_user_svc(user_repo: UserRepo):
return UpdateUserSvc(user_repo=user_repo) return UpdateUserSvc(user_repo=user_repo)
def get_refresh_token_svc(twitch_authenticator: TwitchAuthenticator): async def get_refresh_token_svc(twitch_authenticator: TwitchAuthenticator):
return RefreshTokenSvc(twitch_authenticator=twitch_authenticator) return RefreshTokenSvc(twitch_authenticator=twitch_authenticator)
def get_is_valid_token_svc(twitch_authenticator: TwitchAuthenticator): async def get_is_valid_token_svc(twitch_authenticator: TwitchAuthenticator):
return IsValidTokenSvc(authenticator=twitch_authenticator) return IsValidTokenSvc(authenticator=twitch_authenticator)
@ -118,11 +119,11 @@ async def get_sentences_storer_svc(db: MarkovDatabase):
return SentenceStorerSvc(db=db) return SentenceStorerSvc(db=db)
def get_update_chatbot_svc(chatbot_repo: ChatbotRepo): async def get_update_chatbot_svc(chatbot_repo: ChatbotRepo):
return UpdateChatbotSvc(repo=chatbot_repo) return UpdateChatbotSvc(repo=chatbot_repo)
def get_create_or_update_chatbot_action( async def get_create_or_update_chatbot_action(
create_chatbot_svc: CreateChatbotSvc, create_chatbot_svc: CreateChatbotSvc,
update_chatbot_svc: UpdateChatbotSvc, update_chatbot_svc: UpdateChatbotSvc,
get_chatbot_by_user_id_svc: GetChatbotByUserIdSvc, get_chatbot_by_user_id_svc: GetChatbotByUserIdSvc,
@ -134,7 +135,7 @@ def get_create_or_update_chatbot_action(
) )
def get_get_chatbot_by_user_id_action( async def get_get_chatbot_by_user_id_action(
get_chatbot_by_user_id_svc: GetChatbotByUserIdSvc, get_chatbot_by_user_id_svc: GetChatbotByUserIdSvc,
): ):
return GetChatbotByUserIdAction( return GetChatbotByUserIdAction(
@ -158,6 +159,10 @@ async def get_authenticate_action(
) )
async def get_bot_manager(s: Settings):
return BotsManager(s=s)
async def chatbot( async def chatbot(
get_chatbot_by_user_id_action: GetChatbotByUserIdAction, get_chatbot_by_user_id_action: GetChatbotByUserIdAction,
create_or_update_chatbot_action: CreateOrUpdateChatbotAction, create_or_update_chatbot_action: CreateOrUpdateChatbotAction,

View file

@ -1,4 +1,5 @@
import httpx import httpx
import uvicorn
from litestar import Litestar, get from litestar import Litestar, get
from litestar.contrib.jinja import JinjaTemplateEngine from litestar.contrib.jinja import JinjaTemplateEngine
from litestar.di import Provide from litestar.di import Provide
@ -10,6 +11,7 @@ from apps.httpapi.litestar.dependencies import (
authenticate, authenticate,
get_authenticate_action, get_authenticate_action,
get_authenticator, get_authenticator,
get_bot_manager,
get_chatbot_repo, get_chatbot_repo,
get_create_chatbot_svc, get_create_chatbot_svc,
get_create_or_update_chatbot_action, get_create_or_update_chatbot_action,
@ -42,12 +44,11 @@ from apps.httpapi.litestar.routes.api import (
save_bot_settings, save_bot_settings,
) )
from apps.httpapi.litestar.routes.auth import get_code, login from apps.httpapi.litestar.routes.auth import get_code, login
from huesoporro.bot import BotsManager
from huesoporro.settings import Settings from huesoporro.settings import Settings
@get("/healthz") @get("/healthz")
def get_health() -> dict: async def get_health() -> dict:
return {"status": "ok"} return {"status": "ok"}
@ -88,7 +89,7 @@ def create_app():
"s": Provide(get_settings, use_cache=True), "s": Provide(get_settings, use_cache=True),
"a": Provide(get_authenticator, use_cache=True), "a": Provide(get_authenticator, use_cache=True),
"user": Provide(authenticate), "user": Provide(authenticate),
"bm": Provide(BotsManager, use_cache=True), "bm": Provide(get_bot_manager, use_cache=True),
"sss": Provide(get_sentences_storer_svc), "sss": Provide(get_sentences_storer_svc),
"twitch_authenticator": Provide(get_authenticator), "twitch_authenticator": Provide(get_authenticator),
"authenticate_action": Provide(get_authenticate_action), "authenticate_action": Provide(get_authenticate_action),
@ -113,3 +114,9 @@ def create_app():
app = create_app() app = create_app()
if __name__ == "__main__":
s = Settings.get()
config = uvicorn.Config("main:app", host=s.host, port=s.port, log_level="info")
server = uvicorn.Server(config)
server.run()

View file

@ -12,8 +12,14 @@ class CreateQuoteAction(BaseModel):
create_quote_svc: CreateQuoteSvc create_quote_svc: CreateQuoteSvc
is_mod_svc: IsModSvc is_mod_svc: IsModSvc
async def run( async def run( # noqa: PLR0913
self, user: User, channel: str, quote: str, author: str, username: str self,
user: User,
channel: str,
quote: str,
author: str,
username: str,
is_active: bool = True,
) -> Quote | None: ) -> Quote | None:
if not await self.is_mod_svc.run(user=user, username=username, channel=channel): if not await self.is_mod_svc.run(user=user, username=username, channel=channel):
return None return None
@ -23,6 +29,7 @@ class CreateQuoteAction(BaseModel):
author=author, author=author,
channel_name=channel, channel_name=channel,
created_at=datetime.datetime.now(datetime.UTC), created_at=datetime.datetime.now(datetime.UTC),
is_active=is_active,
last_updated_at=datetime.datetime.now(datetime.UTC), last_updated_at=datetime.datetime.now(datetime.UTC),
) )
return await self.create_quote_svc.run(new_quote) return await self.create_quote_svc.run(new_quote)

View file

@ -11,7 +11,7 @@ from tenacity import (
stop_after_attempt, stop_after_attempt,
wait_exponential, wait_exponential,
) )
from twitchio import Channel from twitchio import Channel, Message
from twitchio.ext import commands, routines from twitchio.ext import commands, routines
from huesoporro.actions.quotes.create_quote_action import CreateQuoteAction from huesoporro.actions.quotes.create_quote_action import CreateQuoteAction
@ -151,6 +151,7 @@ class MessageType(StrEnum):
YES = "YES" YES = "YES"
WHAT = "WHAT" WHAT = "WHAT"
LAUGH = "LAUGH" LAUGH = "LAUGH"
ANO_SUFFIX = "ANO_SUFFIX"
OTHER = "OTHER" OTHER = "OTHER"
@ -168,6 +169,14 @@ class MessageHandler:
"keking", "keking",
"KEKW", "KEKW",
"OMEGADANCEBUTFAST", "OMEGADANCEBUTFAST",
"xdd",
"xdding",
]
self.ano_suffix_reply_patterns = [
"me la agarras con la mano. venga, tira",
"me la agarras con la mano, espabila",
"me la agarras con la mano y te falta calle",
"vegetasmile",
] ]
self.send = channel_send_func self.send = channel_send_func
@ -175,18 +184,22 @@ class MessageHandler:
"""Determines the type of message based on its content""" """Determines the type of message based on its content"""
if content.startswith("!"): if content.startswith("!"):
return MessageType.COMMAND return MessageType.COMMAND
if content == "Yes": if content in ["Yes", "yes"]:
return MessageType.YES return MessageType.YES
if content.startswith("WHAT"): if content.startswith("WHAT"):
return MessageType.WHAT return MessageType.WHAT
if content.endswith("ano") and len(content) > 3: # noqa: PLR2004
return MessageType.ANO_SUFFIX
if content in self.laugh_patterns: if content in self.laugh_patterns:
return MessageType.LAUGH return MessageType.LAUGH
return MessageType.OTHER return MessageType.OTHER
async def handle_laugh(self) -> str: def handle_laugh(self) -> str:
"""Handles laugh messages"""
return random.choice(self.laugh_patterns) # noqa: S311 return random.choice(self.laugh_patterns) # noqa: S311
def handle_ano_suffix(self) -> str:
return random.choice(self.ano_suffix_reply_patterns) # noqa: S311
class SaveMessagesCog(commands.Cog): class SaveMessagesCog(commands.Cog):
def __init__(self, bot: Bot): def __init__(self, bot: Bot):
@ -200,6 +213,7 @@ class SaveMessagesCog(commands.Cog):
MessageType.YES: self._create_typed_send("yes"), MessageType.YES: self._create_typed_send("yes"),
MessageType.WHAT: self._create_typed_send("what"), MessageType.WHAT: self._create_typed_send("what"),
MessageType.LAUGH: self._create_typed_send("laugh"), MessageType.LAUGH: self._create_typed_send("laugh"),
MessageType.ANO_SUFFIX: self._create_typed_send("ano_suffix"),
} }
for func in self.send_functions.values(): for func in self.send_functions.values():
@ -221,19 +235,53 @@ class SaveMessagesCog(commands.Cog):
if hasattr(self, "current_message"): if hasattr(self, "current_message"):
await self.current_message.channel.send(content) await self.current_message.channel.send(content)
def is_bot_mention(self, tok: str) -> bool:
return tok.lower() == str(self.bot.nick).lower()
async def _handle_bot_mention(self, message: Message) -> str | None:
content = (message.content or "").strip()
if not content:
return None
tokens = content.split()
contains_mention = any(self.is_bot_mention(t) for t in tokens)
if not contains_mention:
return None
# Find the first non-mention token as seed
non_mention_tokens = (
t.strip(".,!?;:") for t in tokens if not self.is_bot_mention(t)
)
seed = next((t for t in non_mention_tokens if t), None)
if not seed:
return None
sentence = await self.generate_svc.run(seed)
if not sentence:
return None
await message.channel.send(f"@{message.author.name} {sentence}")
return sentence
@commands.Cog.event() @commands.Cog.event()
async def event_message(self, message): async def event_message(self, message):
"""Main message event handler""" """Main message event handler"""
if not message.author: if not message.author:
return return
# Store reference to current message for send functions
self.current_message = message self.current_message = message
# Store the message content
await self.store_svc.run(message.content) await self.store_svc.run(message.content)
# Determine message type and handle accordingly # If the message contains a mention to this bot, reply by generating
# a sentence from the first word that is not the bot username itself.
if await self._handle_bot_mention(message):
# If the bot actually replies with something, it should not try to send
# any other type of reply
return
msg_type = self.message_handler.get_message_type(message.content) msg_type = self.message_handler.get_message_type(message.content)
response = None response = None
@ -246,12 +294,15 @@ class SaveMessagesCog(commands.Cog):
case MessageType.WHAT: case MessageType.WHAT:
response = "WHAT Ramon" response = "WHAT Ramon"
case MessageType.LAUGH: case MessageType.LAUGH:
response = await self.message_handler.handle_laugh() response = self.message_handler.handle_laugh()
case MessageType.ANO_SUFFIX:
response = (
f"@{message.author.name} {self.message_handler.handle_ano_suffix()}"
)
case MessageType.OTHER: case MessageType.OTHER:
return return
if response and msg_type in self.send_functions: if response and msg_type in self.send_functions:
# Use the type-specific send function
await self.backoff_svc.call_async(self.send_functions[msg_type], response) await self.backoff_svc.call_async(self.send_functions[msg_type], response)

View file

@ -42,9 +42,7 @@ class IRepo(BaseModel, ABC, Generic[T]):
pass # pragma: no cover pass # pragma: no cover
@abstractmethod @abstractmethod
async def list( async def list(self, offset: int = 0, limit: int = 10, auto_commit=True) -> list[T]:
self, obj: T, offset: int = 0, limit: int = 10, auto_commit=True
) -> list[T]:
pass # pragma: no cover pass # pragma: no cover
@ -64,7 +62,9 @@ class UserRepo(IRepo[User]):
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
await db.execute( await db.execute(
""" """
SELECT * FROM users WHERE id = ? SELECT *
FROM users
WHERE id = ?
""", """,
(obj_id.hex,), (obj_id.hex,),
) as cursor, ) as cursor,
@ -81,8 +81,7 @@ class UserRepo(IRepo[User]):
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
await db.execute( await db.execute(
"""INSERT INTO users (id, username, external_auth, created_at, last_updated_at) """INSERT INTO users (id, username, external_auth, created_at, last_updated_at)
VALUES (?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?) RETURNING *
RETURNING *
""", """,
( (
obj.id.hex, obj.id.hex,
@ -108,8 +107,7 @@ class UserRepo(IRepo[User]):
SET username = ?, SET username = ?,
external_auth = ?, external_auth = ?,
last_updated_at = ? last_updated_at = ?
WHERE id = ? WHERE id = ? RETURNING *
RETURNING *
""", """,
( (
obj.username, obj.username,
@ -126,7 +124,9 @@ class UserRepo(IRepo[User]):
async with self.get_client(auto_commit=auto_commit) as db: async with self.get_client(auto_commit=auto_commit) as db:
await db.execute( await db.execute(
""" """
DELETE FROM users WHERE id = ? DELETE
FROM users
WHERE id = ?
""", """,
(obj.id.hex,), (obj.id.hex,),
) )
@ -136,7 +136,9 @@ class UserRepo(IRepo[User]):
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
db.execute( db.execute(
""" """
SELECT * FROM users WHERE username = ? SELECT *
FROM users
WHERE username = ?
""", """,
(user,), (user,),
) as cursor, ) as cursor,
@ -153,7 +155,7 @@ class UserRepo(IRepo[User]):
) )
async def list( # type: ignore[empty-body] async def list( # type: ignore[empty-body]
self, obj: User, offset: int = 0, limit: int = 10, auto_commit=True self, offset: int = 0, limit: int = 10, auto_commit=True
) -> list[User]: ) -> list[User]:
pass # pragma: no cover pass # pragma: no cover
@ -171,16 +173,18 @@ class QuoteRepo(IRepo[Quote]):
channel_name=data["channel"], channel_name=data["channel"],
created_at=data["created_at"], created_at=data["created_at"],
last_updated_at=data["last_updated_at"], last_updated_at=data["last_updated_at"],
is_active=data["is_active"],
) )
async def create(self, obj: Quote, auto_commit=True) -> Quote: async def create(self, obj: Quote, auto_commit=True) -> Quote:
if await self.get_by_quote(obj.quote):
raise ValueError(f"Quote {obj.quote} already exists")
async with ( async with (
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
await db.execute( await db.execute(
""" """
INSERT INTO quotes (id, quote, author, channel, created_at, last_updated_at) INSERT INTO quotes (id, quote, author, channel, created_at, is_active, last_updated_at)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING *
RETURNING *
""", """,
( (
obj.id.hex, obj.id.hex,
@ -188,6 +192,7 @@ class QuoteRepo(IRepo[Quote]):
obj.author, obj.author,
obj.channel_name, obj.channel_name,
obj.created_at, obj.created_at,
obj.is_active,
obj.last_updated_at, obj.last_updated_at,
), ),
) as cursor, ) as cursor,
@ -196,28 +201,99 @@ class QuoteRepo(IRepo[Quote]):
return self._deserialize(data) return self._deserialize(data)
async def update(self, obj: Quote, auto_commit=True) -> Quote: # type: ignore[empty-body] async def update(self, obj: Quote, auto_commit=True) -> Quote: # type: ignore[empty-body]
pass # pragma: no cover if not await self.get_by_id(obj.id):
raise ValueError(f"Quote {obj.id} does not exist")
async with (
self.get_client(auto_commit=auto_commit) as db,
await db.execute(
"""
UPDATE quotes
SET quote = ?,
author = ?,
channel = ?,
is_active = ?,
last_updated_at = ?
WHERE id = ? RETURNING *
""",
(
obj.quote,
obj.author,
obj.channel_name,
obj.is_active,
utils.get_utc_now(),
obj.id.hex,
),
) as cursor,
):
data = await cursor.fetchone()
return self._deserialize(data)
async def delete(self, obj: Quote, auto_commit=True): async def delete(self, obj: Quote, auto_commit=True):
pass # pragma: no cover async with self.get_client(auto_commit=auto_commit) as db:
await db.execute(
"""
DELETE
FROM quotes
WHERE id = ?
""",
(obj.id.hex,),
)
async def get_by_id(self, obj_id: UUID, auto_commit=True) -> Quote | None: # type: ignore[empty-body] async def get_by_id(self, obj_id: UUID, auto_commit=True) -> Quote | None: # type: ignore[empty-body]
pass # pragma: no cover async with (
self.get_client(auto_commit=auto_commit) as db,
db.execute(
"""
SELECT *
FROM quotes
WHERE id = ?
""",
(obj_id.hex,),
) as cursor,
):
data = await cursor.fetchone()
if not data:
return None
return self._deserialize(data)
async def get_by_quote(self, quote: str, auto_commit=True) -> Quote | None:
async with (
self.get_client(auto_commit=auto_commit) as db,
db.execute(
"""
SELECT *
FROM quotes
WHERE quote = ?
""",
(quote,),
) as cursor,
):
data = await cursor.fetchone()
if not data:
return None
return self._deserialize(data)
async def list( # type: ignore[empty-body] async def list( # type: ignore[empty-body]
self, obj: T, offset: int = 0, limit: int = 10, auto_commit=True self, offset: int = 0, limit: int = 10, auto_commit=True
) -> list[T]: ) -> list[Quote]:
pass # pragma: no cover async with self.get_client() as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"SELECT * FROM quotes LIMIT ? OFFSET ?", (limit, offset)
) as cursor:
results = await cursor.fetchall()
return [self._deserialize(result) for result in results]
async def get_random(self, channel_name: str, auto_commit=True) -> Quote | None: async def get_random(self, channel_name: str, auto_commit=True) -> Quote | None:
async with ( async with (
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
db.execute( db.execute(
""" """
SELECT * FROM quotes SELECT *
FROM quotes
WHERE channel = ? WHERE channel = ?
ORDER BY RANDOM() AND is_active = 1
LIMIT 1 ORDER BY RANDOM() LIMIT 1
""", """,
(channel_name,), (channel_name,),
) as cursor, ) as cursor,
@ -247,16 +323,14 @@ class ChatbotRepo(IRepo[Chatbot]):
async with ( async with (
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
await db.execute( await db.execute(
"""INSERT INTO chatbot ( """INSERT INTO chatbot (id,
id,
user_id, user_id,
automatic_generation_timer, automatic_generation_timer,
automatic_quote_timer, automatic_quote_timer,
mods, mods,
created_at, created_at,
last_updated_at last_updated_at)
) VALUES(?,?,?,?,?,?,?) VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING *
RETURNING *
""", """,
( (
obj.id.hex, obj.id.hex,
@ -278,13 +352,12 @@ class ChatbotRepo(IRepo[Chatbot]):
async with ( async with (
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
await db.execute( await db.execute(
"""UPDATE chatbot SET """UPDATE chatbot
automatic_generation_timer = ?, SET automatic_generation_timer = ?,
automatic_quote_timer = ?, automatic_quote_timer = ?,
mods = ?, mods = ?,
last_updated_at = ? last_updated_at = ?
WHERE user_id = ? WHERE user_id = ? RETURNING *
RETURNING *
""", """,
( (
obj.automatic_generation_timer, obj.automatic_generation_timer,
@ -298,11 +371,22 @@ class ChatbotRepo(IRepo[Chatbot]):
data = await cursor.fetchone() data = await cursor.fetchone()
return self._deserialize(data) return self._deserialize(data)
async def delete(self, obj: T, auto_commit=True): async def delete(self, obj: Chatbot, auto_commit=True):
pass # pragma: no cover if not await self.get_by_id(obj.id):
raise ValueError(f"Chatbot {obj.id} does not exist")
async with self.get_client() as db:
await db.execute("DELETE FROM chatbot WHERE id = ?", (obj.id.hex,))
async def get_by_id(self, obj_id: UUID, auto_commit=True) -> Chatbot | None: # type: ignore[empty-body] async def get_by_id(self, obj_id: UUID, auto_commit=True) -> Chatbot | None: # type: ignore[empty-body]
pass # pragma: no cover async with self.get_client() as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"SELECT * FROM chatbot WHERE id = ?", (obj_id.hex,)
) as cursor:
result = await cursor.fetchone()
if not result:
return None
return self._deserialize(result)
async def get_by_user_id(self, user_id: UUID) -> Chatbot | None: async def get_by_user_id(self, user_id: UUID) -> Chatbot | None:
async with self.get_client() as db: async with self.get_client() as db:
@ -316,6 +400,12 @@ class ChatbotRepo(IRepo[Chatbot]):
return self._deserialize(result) return self._deserialize(result)
async def list( # type: ignore[empty-body] async def list( # type: ignore[empty-body]
self, obj: T, offset: int = 0, limit: int = 10, auto_commit=True self, offset: int = 0, limit: int = 10, auto_commit=True
) -> list[T]: ) -> list[Chatbot]:
pass # pragma: no cover async with self.get_client() as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"SELECT * FROM chatbot LIMIT ? OFFSET ?", (limit, offset)
) as cursor:
results = await cursor.fetchall()
return [self._deserialize(result) for result in results]

View file

@ -105,6 +105,7 @@ class Quote(BaseModel):
quote: str quote: str
author: str author: str
channel_name: str channel_name: str
is_active: bool = True
created_at: datetime.datetime = Field(default_factory=utils.get_utc_now) created_at: datetime.datetime = Field(default_factory=utils.get_utc_now)
last_updated_at: datetime.datetime = Field(default_factory=utils.get_utc_now) last_updated_at: datetime.datetime = Field(default_factory=utils.get_utc_now)

View file

@ -1,5 +1,5 @@
import tempfile import tempfile
from collections.abc import Generator from collections.abc import Iterator
from pathlib import Path from pathlib import Path
import yt_dlp import yt_dlp
@ -9,7 +9,7 @@ from pydantic import BaseModel
class DownloadClosedCaptionsSvc(BaseModel): class DownloadClosedCaptionsSvc(BaseModel):
@staticmethod @staticmethod
def run(youtube_url: str, sub_lang: str = "es") -> Generator[Path, None, None]: def run(youtube_url: str, sub_lang: str = "es") -> Iterator[Path]:
"""Download closed captions from a yt video and save it to a temp file """Download closed captions from a yt video and save it to a temp file
Args: Args:

View file

@ -65,7 +65,7 @@ def db(s, cdb):
pass pass
async def list( # type: ignore[empty-body] async def list( # type: ignore[empty-body]
self, obj: BaseModel, offset: int = 0, limit: int = 10, auto_commit=True self, offset: int = 0, limit: int = 10, auto_commit=True
) -> list[BaseModel]: ) -> list[BaseModel]:
pass pass
@ -169,6 +169,16 @@ async def chatbot(chatbot_factory, user):
return chatbot_factory.build(user_id=user.id) return chatbot_factory.build(user_id=user.id)
@pytest.fixture
async def five_chatbots(chatbot_factory, user):
return [chatbot_factory.build() for _ in range(5)]
@pytest.fixture
async def persisted_five_chatbots(five_chatbots, chatbot_repo):
return [await chatbot_repo.create(chatbot) for chatbot in five_chatbots]
@pytest.fixture @pytest.fixture
async def persisted_chatbot(chatbot_repo, chatbot, persisted_user): async def persisted_chatbot(chatbot_repo, chatbot, persisted_user):
return await chatbot_repo.create(chatbot) return await chatbot_repo.create(chatbot)
@ -204,11 +214,22 @@ async def quote(quote_factory):
return quote_factory.build() return quote_factory.build()
@pytest.fixture
async def five_quotes(quote_factory):
return [quote_factory.build() for _ in range(5)]
@pytest.fixture @pytest.fixture
async def persisted_quote(quote_repo, quote): async def persisted_quote(quote_repo, quote):
quote.is_active = True
return await quote_repo.create(quote) return await quote_repo.create(quote)
@pytest.fixture
async def persisted_five_quotes(five_quotes, quote_repo):
return [await quote_repo.create(quote) for quote in five_quotes]
@pytest.fixture @pytest.fixture
async def create_quote_svc(quote_repo): async def create_quote_svc(quote_repo):
return CreateQuoteSvc(repo=quote_repo) return CreateQuoteSvc(repo=quote_repo)

View file

@ -234,6 +234,7 @@ async def test_create_quote_action(
username=user.username, username=user.username,
channel=user.username, channel=user.username,
quote=quote.quote, quote=quote.quote,
is_active=quote.is_active,
author=quote.author, author=quote.author,
) )
assert new_quote assert new_quote

149
tests/test_bot_mentions.py Normal file
View file

@ -0,0 +1,149 @@
import types
import pytest
from huesoporro.bot import SaveMessagesCog
class DummyDB:
def __init__(self, *args, **kwargs):
pass
class FakeChannel:
def __init__(self):
self.sent: list[str] = []
async def send(self, content: str):
self.sent.append(content)
class FakeAuthor:
def __init__(self, name: str):
self.name = name
class FakeMessage:
def __init__(self, content: str | None, author_name: str = "alice"):
self.content = content
self.author = FakeAuthor(author_name)
self.channel = FakeChannel()
class FakeBot:
def __init__(self, nick: str = "Junie", channel: str = "testchan"):
self.nick = nick
self.channel = channel
@pytest.fixture(autouse=True)
def patch_markov_and_svcs(monkeypatch):
monkeypatch.setattr("huesoporro.bot.MarkovDatabase", DummyDB)
class _DummyStoreSvc:
def __init__(self, db):
self.db = db
async def run(self, content: str | None):
return None
class _DummyGenSvc:
def __init__(self, db):
self.db = db
async def run(self, seed: str):
return None
monkeypatch.setattr("huesoporro.bot.SentenceStorerSvc", _DummyStoreSvc)
monkeypatch.setattr("huesoporro.bot.SentenceGeneratorSvc", _DummyGenSvc)
@pytest.fixture
def cog(monkeypatch) -> SaveMessagesCog:
bot = FakeBot()
return SaveMessagesCog(bot) # type: ignore[arg-type]
def make_async_fn(result=None, exc: Exception | None = None):
async def _fn(*args, **kwargs):
if exc:
raise exc
return result
return _fn
@pytest.mark.asyncio
async def test_handle_bot_mention_returns_none_on_empty_content(cog: SaveMessagesCog):
msg = FakeMessage(content=None)
res = await cog._handle_bot_mention(msg) # type: ignore[attr-defined]
assert res is None
assert msg.channel.sent == []
@pytest.mark.asyncio
async def test_handle_bot_mention_returns_none_when_no_mention(cog: SaveMessagesCog):
msg = FakeMessage(content="hello world")
res = await cog._handle_bot_mention(msg) # type: ignore[attr-defined]
assert res is None
assert msg.channel.sent == []
@pytest.mark.asyncio
async def test_handle_bot_mention_returns_none_when_only_mention(
cog: SaveMessagesCog, monkeypatch
):
msg = FakeMessage(content=cog.bot.nick) # type: ignore[attr-defined]
res = await cog._handle_bot_mention(msg) # type: ignore[attr-defined]
assert res is None
assert msg.channel.sent == []
@pytest.mark.asyncio
async def test_handle_bot_mention_generates_and_sends_reply(
cog: SaveMessagesCog, monkeypatch
):
msg = FakeMessage(content="juNie hello there")
monkeypatch.setattr(
cog, "generate_svc", types.SimpleNamespace(run=make_async_fn("foo bar"))
)
res = await cog._handle_bot_mention(msg) # type: ignore[attr-defined]
assert res == "foo bar"
assert len(msg.channel.sent) == 1
assert msg.channel.sent[0].startswith(f"@{msg.author.name} ")
assert msg.channel.sent[0].endswith("foo bar")
@pytest.mark.asyncio
async def test_handle_bot_mention_no_send_when_generator_returns_none(
cog: SaveMessagesCog, monkeypatch
):
msg = FakeMessage(content=f"{cog.bot.nick} hello") # type: ignore[attr-defined]
monkeypatch.setattr(
cog, "generate_svc", types.SimpleNamespace(run=make_async_fn(None))
)
res = await cog._handle_bot_mention(msg) # type: ignore[attr-defined]
assert res is None
assert msg.channel.sent == []
@pytest.mark.asyncio
async def test_handle_bot_mention_swallows_exceptions_and_returns_none(
cog: SaveMessagesCog, monkeypatch
):
msg = FakeMessage(content=f"{cog.bot.nick} hello") # type: ignore[attr-defined]
monkeypatch.setattr(
cog,
"generate_svc",
types.SimpleNamespace(run=make_async_fn(exc=RuntimeError("boom"))),
)
res = await cog._handle_bot_mention(msg) # type: ignore[attr-defined]
assert res is None
assert msg.channel.sent == []

View file

@ -88,8 +88,79 @@ async def test_update_chatbot_raises_value_error_on_non_existing_chatbot(
await chatbot_repo.update(chatbot) await chatbot_repo.update(chatbot)
async def test_delete_chatbot_raises_value_error_on_non_existing_chatbot(
chatbot_repo, chatbot
):
with pytest.raises(ValueError, match=f"Chatbot {chatbot.id} does not exist"):
await chatbot_repo.delete(chatbot)
async def test_delete_chatbot(chatbot_repo, persisted_chatbot):
assert await chatbot_repo.delete(persisted_chatbot) is None
assert await chatbot_repo.get_by_id(persisted_chatbot.id) is None
async def test_get_by_id(chatbot_repo, persisted_chatbot):
chatbot = await chatbot_repo.get_by_id(persisted_chatbot.id)
assert chatbot == persisted_chatbot
async def test_list(chatbot_repo, persisted_five_chatbots):
chatbots = await chatbot_repo.list()
assert len(chatbots) == 5 # noqa: PLR2004
async def test_list_offset_limit(chatbot_repo, persisted_five_chatbots):
chatbots = await chatbot_repo.list(offset=1, limit=2)
assert len(chatbots) == 2 # noqa: PLR2004
async def test_get_random_quote(quote_repo: QuoteRepo, persisted_quote): async def test_get_random_quote(quote_repo: QuoteRepo, persisted_quote):
quote = await quote_repo.get_random(persisted_quote.channel_name) quote = await quote_repo.get_random(persisted_quote.channel_name)
assert quote assert quote
assert quote.author == persisted_quote.author assert quote.author == persisted_quote.author
assert quote.channel_name == persisted_quote.channel_name assert quote.channel_name == persisted_quote.channel_name
async def test_create_quote_raises_value_error_for_existing_quote(
quote_repo: QuoteRepo, persisted_quote
):
with pytest.raises(
ValueError, match=f"Quote {persisted_quote.quote} already exists"
):
await quote_repo.create(persisted_quote)
async def test_create_quote(quote_repo: QuoteRepo, quote_factory):
quote = quote_factory.build()
created_quote = await quote_repo.create(quote)
assert created_quote == quote
async def test_update_quote_raises_value_error_on_non_existing_quote(
quote_repo: QuoteRepo, quote
):
with pytest.raises(ValueError, match=f"Quote {quote.id} does not exist"):
await quote_repo.update(quote)
async def test_update_quote(quote_repo: QuoteRepo, persisted_quote):
persisted_quote.quote = "new quote"
updated_quote = await quote_repo.update(persisted_quote)
persisted_quote.last_updated_at = updated_quote.last_updated_at
assert updated_quote == persisted_quote
async def test_delete_quote(quote_repo: QuoteRepo, persisted_quote):
assert await quote_repo.delete(persisted_quote) is None
assert await quote_repo.get_by_id(persisted_quote.id) is None
async def test_list_quotes(quote_repo, persisted_five_quotes):
quotes = await quote_repo.list()
assert len(quotes) == 5 # noqa: PLR2004
async def test_list_quotes_offset_limit(quote_repo, persisted_five_quotes):
quotes = await quote_repo.list(offset=1, limit=2)
assert len(quotes) == 2 # noqa: PLR2004

2419
uv.lock generated

File diff suppressed because it is too large Load diff