Compare commits

...
Sign in to create a new pull request.

5 commits

20 changed files with 1856 additions and 1014 deletions

View file

@ -2,7 +2,7 @@ files: src|tests
exclude: ^$ exclude: ^$
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0 rev: v6.0.0
hooks: hooks:
- id: trailing-whitespace - id: trailing-whitespace
args: [ --markdown-linebreak-ext=md ] args: [ --markdown-linebreak-ext=md ]

View file

@ -1 +1 @@
3.11 3.13

View file

@ -2,7 +2,24 @@
All notable changes to this project will be documented in this file. All notable changes to this project will be documented in this file.
## [unreleased] ## [0.3.6] - 2025-06-06
### 🚀 Features
- Add `Quote.is_active` field
## [0.3.5] - 2025-05-27
### 🚀 Features
- Implement remaining repo methods for chatbot and quote
- Add ANO_PREFIX bot response
### ⚙️ Miscellaneous Tasks
- Add renovate lockFileMaintenance
## [0.3.3] - 2025-03-06
### 🐛 Bug Fixes ### 🐛 Bug Fixes

View file

@ -16,7 +16,7 @@ ENV PYTHONPATH="$APP_PATH"
ENV PATH="$APP_HOME/.local/bin:$PATH" ENV PATH="$APP_HOME/.local/bin:$PATH"
# hadolint ignore=DL3001,DL3008,DL3018 # hadolint ignore=DL3001,DL3008,DL3018
RUN apk add --no-cache make python3~=3.12 curl \ RUN apk add --no-cache make python3~=3.13 curl git \
&& adduser -S -u "$USERID" -h "$APP_HOME" "$USERNAME" \ && adduser -S -u "$USERID" -h "$APP_HOME" "$USERNAME" \
&& mkdir -p "$APP_PATH" \ && mkdir -p "$APP_PATH" \
&& chown -R "$USERID:$GROUPID" "$APP_PATH" && chown -R "$USERID:$GROUPID" "$APP_PATH"

View file

@ -11,6 +11,9 @@ fmt--mypy:
fmt--add-noqa: fmt--add-noqa:
uvx ruff check --add-noqa . uvx ruff check --add-noqa .
fmt--autoupdate:
uvx pre-commit autoupdate
.PHONY: tests .PHONY: tests
tests: tests:

View file

@ -1,6 +1,6 @@
apiVersion: v2 apiVersion: v2
appVersion: 0.3.4 appVersion: 0.3.7
description: A Helm chart for Kubernetes description: A Helm chart for Kubernetes
name: huesoporro name: huesoporro
type: application type: application
version: 0.3.4 version: 0.3.7

View file

@ -8,7 +8,7 @@ fullnameOverride: ''
image: image:
pullPolicy: Always pullPolicy: Always
repository: git.roboces.dev/catalin/huesoporro repository: git.roboces.dev/catalin/huesoporro
tag: 0.3.4 tag: 0.3.7
imagePullSecrets: [] imagePullSecrets: []
ingress: ingress:
annotations: {} annotations: {}

View file

@ -19,10 +19,10 @@
"flake-compat": { "flake-compat": {
"flake": false, "flake": false,
"locked": { "locked": {
"lastModified": 1733328505, "lastModified": 1747046372,
"owner": "edolstra", "owner": "edolstra",
"repo": "flake-compat", "repo": "flake-compat",
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -46,10 +46,31 @@
"type": "github" "type": "github"
} }
}, },
"git-hooks": {
"inputs": {
"flake-compat": "flake-compat",
"gitignore": "gitignore",
"nixpkgs": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1747372754,
"owner": "cachix",
"repo": "git-hooks.nix",
"rev": "80479b6ec16fefd9c1db3ea13aeb038c60530f46",
"type": "github"
},
"original": {
"owner": "cachix",
"repo": "git-hooks.nix",
"type": "github"
}
},
"gitignore": { "gitignore": {
"inputs": { "inputs": {
"nixpkgs": [ "nixpkgs": [
"pre-commit-hooks", "git-hooks",
"nixpkgs" "nixpkgs"
] ]
}, },
@ -83,7 +104,7 @@
}, },
"nixpkgs-python": { "nixpkgs-python": {
"inputs": { "inputs": {
"flake-compat": "flake-compat", "flake-compat": "flake-compat_2",
"nixpkgs": [ "nixpkgs": [
"nixpkgs" "nixpkgs"
] ]
@ -101,33 +122,15 @@
"type": "github" "type": "github"
} }
}, },
"pre-commit-hooks": {
"inputs": {
"flake-compat": "flake-compat_2",
"gitignore": "gitignore",
"nixpkgs": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1737465171,
"owner": "cachix",
"repo": "pre-commit-hooks.nix",
"rev": "9364dc02281ce2d37a1f55b6e51f7c0f65a75f17",
"type": "github"
},
"original": {
"owner": "cachix",
"repo": "pre-commit-hooks.nix",
"type": "github"
}
},
"root": { "root": {
"inputs": { "inputs": {
"devenv": "devenv", "devenv": "devenv",
"git-hooks": "git-hooks",
"nixpkgs": "nixpkgs", "nixpkgs": "nixpkgs",
"nixpkgs-python": "nixpkgs-python", "nixpkgs-python": "nixpkgs-python",
"pre-commit-hooks": "pre-commit-hooks" "pre-commit-hooks": [
"git-hooks"
]
} }
} }
}, },

View file

@ -0,0 +1,21 @@
"""
This module contains a Caribou migration.
Migration Name: active_quotes
Migration Version: 20250606143836
"""
def upgrade(connection):
# add `is_active` column to the `quotes` table
sql = """
ALTER TABLE quotes
ADD COLUMN is_active BOOLEAN DEFAULT TRUE;
"""
connection.execute(sql)
connection.commit()
def downgrade(connection):
# add your downgrade step here
pass

View file

@ -1,6 +1,6 @@
[project] [project]
name = "huesoporro" name = "huesoporro"
version = "0.3.4" version = "0.3.7"
description = "Misc Twitch bot" description = "Misc Twitch bot"
readme = "README.md" readme = "README.md"
authors = [ authors = [
@ -19,11 +19,13 @@ dependencies = [
"caribou>=0.4.1", "caribou>=0.4.1",
"aiosqlite>=0.20.0", "aiosqlite>=0.20.0",
"pyjwt>=2.10.1", "pyjwt>=2.10.1",
"twitchio>=2.10.0", "twitchio==2.10.0",
"redis>=5.2.1", "redis>=5.2.1",
"pytz>=2024.2", "pytz>=2024.2",
"discord-py>=2.4.0", "discord-py>=2.4.0",
"tenacity>=9.0.0", "tenacity>=9.0.0",
"uvicorn>=0.34.0",
"sniffio>=1.3.1",
] ]
[project.scripts] [project.scripts]

View file

@ -117,6 +117,6 @@ app = create_app()
if __name__ == "__main__": if __name__ == "__main__":
s = Settings.get() s = Settings.get()
config = uvicorn.Config("main:app", host=s.port, port=s.port, log_level="info") config = uvicorn.Config("main:app", host=s.host, port=s.port, log_level="info")
server = uvicorn.Server(config) server = uvicorn.Server(config)
server.run() server.run()

View file

@ -12,8 +12,14 @@ class CreateQuoteAction(BaseModel):
create_quote_svc: CreateQuoteSvc create_quote_svc: CreateQuoteSvc
is_mod_svc: IsModSvc is_mod_svc: IsModSvc
async def run( async def run( # noqa: PLR0913
self, user: User, channel: str, quote: str, author: str, username: str self,
user: User,
channel: str,
quote: str,
author: str,
username: str,
is_active: bool = True,
) -> Quote | None: ) -> Quote | None:
if not await self.is_mod_svc.run(user=user, username=username, channel=channel): if not await self.is_mod_svc.run(user=user, username=username, channel=channel):
return None return None
@ -23,6 +29,7 @@ class CreateQuoteAction(BaseModel):
author=author, author=author,
channel_name=channel, channel_name=channel,
created_at=datetime.datetime.now(datetime.UTC), created_at=datetime.datetime.now(datetime.UTC),
is_active=is_active,
last_updated_at=datetime.datetime.now(datetime.UTC), last_updated_at=datetime.datetime.now(datetime.UTC),
) )
return await self.create_quote_svc.run(new_quote) return await self.create_quote_svc.run(new_quote)

View file

@ -11,7 +11,7 @@ from tenacity import (
stop_after_attempt, stop_after_attempt,
wait_exponential, wait_exponential,
) )
from twitchio import Channel from twitchio import Channel, Message
from twitchio.ext import commands, routines from twitchio.ext import commands, routines
from huesoporro.actions.quotes.create_quote_action import CreateQuoteAction from huesoporro.actions.quotes.create_quote_action import CreateQuoteAction
@ -151,6 +151,7 @@ class MessageType(StrEnum):
YES = "YES" YES = "YES"
WHAT = "WHAT" WHAT = "WHAT"
LAUGH = "LAUGH" LAUGH = "LAUGH"
ANO_SUFFIX = "ANO_SUFFIX"
OTHER = "OTHER" OTHER = "OTHER"
@ -168,6 +169,14 @@ class MessageHandler:
"keking", "keking",
"KEKW", "KEKW",
"OMEGADANCEBUTFAST", "OMEGADANCEBUTFAST",
"xdd",
"xdding",
]
self.ano_suffix_reply_patterns = [
"me la agarras con la mano. venga, tira",
"me la agarras con la mano, espabila",
"me la agarras con la mano y te falta calle",
"vegetasmile",
] ]
self.send = channel_send_func self.send = channel_send_func
@ -175,18 +184,22 @@ class MessageHandler:
"""Determines the type of message based on its content""" """Determines the type of message based on its content"""
if content.startswith("!"): if content.startswith("!"):
return MessageType.COMMAND return MessageType.COMMAND
if content == "Yes": if content in ["Yes", "yes"]:
return MessageType.YES return MessageType.YES
if content.startswith("WHAT"): if content.startswith("WHAT"):
return MessageType.WHAT return MessageType.WHAT
if content.endswith("ano") and len(content) > 3: # noqa: PLR2004
return MessageType.ANO_SUFFIX
if content in self.laugh_patterns: if content in self.laugh_patterns:
return MessageType.LAUGH return MessageType.LAUGH
return MessageType.OTHER return MessageType.OTHER
async def handle_laugh(self) -> str: def handle_laugh(self) -> str:
"""Handles laugh messages"""
return random.choice(self.laugh_patterns) # noqa: S311 return random.choice(self.laugh_patterns) # noqa: S311
def handle_ano_suffix(self) -> str:
return random.choice(self.ano_suffix_reply_patterns) # noqa: S311
class SaveMessagesCog(commands.Cog): class SaveMessagesCog(commands.Cog):
def __init__(self, bot: Bot): def __init__(self, bot: Bot):
@ -200,6 +213,7 @@ class SaveMessagesCog(commands.Cog):
MessageType.YES: self._create_typed_send("yes"), MessageType.YES: self._create_typed_send("yes"),
MessageType.WHAT: self._create_typed_send("what"), MessageType.WHAT: self._create_typed_send("what"),
MessageType.LAUGH: self._create_typed_send("laugh"), MessageType.LAUGH: self._create_typed_send("laugh"),
MessageType.ANO_SUFFIX: self._create_typed_send("ano_suffix"),
} }
for func in self.send_functions.values(): for func in self.send_functions.values():
@ -221,19 +235,53 @@ class SaveMessagesCog(commands.Cog):
if hasattr(self, "current_message"): if hasattr(self, "current_message"):
await self.current_message.channel.send(content) await self.current_message.channel.send(content)
def is_bot_mention(self, tok: str) -> bool:
return tok.lower() == str(self.bot.nick).lower()
async def _handle_bot_mention(self, message: Message) -> str | None:
content = (message.content or "").strip()
if not content:
return None
tokens = content.split()
contains_mention = any(self.is_bot_mention(t) for t in tokens)
if not contains_mention:
return None
# Find the first non-mention token as seed
non_mention_tokens = (
t.strip(".,!?;:") for t in tokens if not self.is_bot_mention(t)
)
seed = next((t for t in non_mention_tokens if t), None)
if not seed:
return None
sentence = await self.generate_svc.run(seed)
if not sentence:
return None
await message.channel.send(f"@{message.author.name} {sentence}")
return sentence
@commands.Cog.event() @commands.Cog.event()
async def event_message(self, message): async def event_message(self, message):
"""Main message event handler""" """Main message event handler"""
if not message.author: if not message.author:
return return
# Store reference to current message for send functions
self.current_message = message self.current_message = message
# Store the message content
await self.store_svc.run(message.content) await self.store_svc.run(message.content)
# Determine message type and handle accordingly # If the message contains a mention to this bot, reply by generating
# a sentence from the first word that is not the bot username itself.
if await self._handle_bot_mention(message):
# If the bot actually replies with something, it should not try to send
# any other type of reply
return
msg_type = self.message_handler.get_message_type(message.content) msg_type = self.message_handler.get_message_type(message.content)
response = None response = None
@ -246,12 +294,15 @@ class SaveMessagesCog(commands.Cog):
case MessageType.WHAT: case MessageType.WHAT:
response = "WHAT Ramon" response = "WHAT Ramon"
case MessageType.LAUGH: case MessageType.LAUGH:
response = await self.message_handler.handle_laugh() response = self.message_handler.handle_laugh()
case MessageType.ANO_SUFFIX:
response = (
f"@{message.author.name} {self.message_handler.handle_ano_suffix()}"
)
case MessageType.OTHER: case MessageType.OTHER:
return return
if response and msg_type in self.send_functions: if response and msg_type in self.send_functions:
# Use the type-specific send function
await self.backoff_svc.call_async(self.send_functions[msg_type], response) await self.backoff_svc.call_async(self.send_functions[msg_type], response)

View file

@ -62,7 +62,9 @@ class UserRepo(IRepo[User]):
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
await db.execute( await db.execute(
""" """
SELECT * FROM users WHERE id = ? SELECT *
FROM users
WHERE id = ?
""", """,
(obj_id.hex,), (obj_id.hex,),
) as cursor, ) as cursor,
@ -79,8 +81,7 @@ class UserRepo(IRepo[User]):
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
await db.execute( await db.execute(
"""INSERT INTO users (id, username, external_auth, created_at, last_updated_at) """INSERT INTO users (id, username, external_auth, created_at, last_updated_at)
VALUES (?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?) RETURNING *
RETURNING *
""", """,
( (
obj.id.hex, obj.id.hex,
@ -102,13 +103,12 @@ class UserRepo(IRepo[User]):
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
db.execute( db.execute(
""" """
UPDATE users UPDATE users
SET username = ?, SET username = ?,
external_auth = ?, external_auth = ?,
last_updated_at = ? last_updated_at = ?
WHERE id = ? WHERE id = ? RETURNING *
RETURNING * """,
""",
( (
obj.username, obj.username,
obj.serialize_external_auth(), obj.serialize_external_auth(),
@ -124,7 +124,9 @@ class UserRepo(IRepo[User]):
async with self.get_client(auto_commit=auto_commit) as db: async with self.get_client(auto_commit=auto_commit) as db:
await db.execute( await db.execute(
""" """
DELETE FROM users WHERE id = ? DELETE
FROM users
WHERE id = ?
""", """,
(obj.id.hex,), (obj.id.hex,),
) )
@ -134,8 +136,10 @@ class UserRepo(IRepo[User]):
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
db.execute( db.execute(
""" """
SELECT * FROM users WHERE username = ? SELECT *
""", FROM users
WHERE username = ?
""",
(user,), (user,),
) as cursor, ) as cursor,
): ):
@ -169,6 +173,7 @@ class QuoteRepo(IRepo[Quote]):
channel_name=data["channel"], channel_name=data["channel"],
created_at=data["created_at"], created_at=data["created_at"],
last_updated_at=data["last_updated_at"], last_updated_at=data["last_updated_at"],
is_active=data["is_active"],
) )
async def create(self, obj: Quote, auto_commit=True) -> Quote: async def create(self, obj: Quote, auto_commit=True) -> Quote:
@ -178,9 +183,8 @@ class QuoteRepo(IRepo[Quote]):
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
await db.execute( await db.execute(
""" """
INSERT INTO quotes (id, quote, author, channel, created_at, last_updated_at) INSERT INTO quotes (id, quote, author, channel, created_at, is_active, last_updated_at)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING *
RETURNING *
""", """,
( (
obj.id.hex, obj.id.hex,
@ -188,6 +192,7 @@ class QuoteRepo(IRepo[Quote]):
obj.author, obj.author,
obj.channel_name, obj.channel_name,
obj.created_at, obj.created_at,
obj.is_active,
obj.last_updated_at, obj.last_updated_at,
), ),
) as cursor, ) as cursor,
@ -202,18 +207,19 @@ class QuoteRepo(IRepo[Quote]):
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
await db.execute( await db.execute(
""" """
UPDATE quotes UPDATE quotes
SET quote = ?, SET quote = ?,
author = ?, author = ?,
channel = ?, channel = ?,
last_updated_at = ? is_active = ?,
WHERE id = ? last_updated_at = ?
RETURNING * WHERE id = ? RETURNING *
""", """,
( (
obj.quote, obj.quote,
obj.author, obj.author,
obj.channel_name, obj.channel_name,
obj.is_active,
utils.get_utc_now(), utils.get_utc_now(),
obj.id.hex, obj.id.hex,
), ),
@ -226,7 +232,9 @@ class QuoteRepo(IRepo[Quote]):
async with self.get_client(auto_commit=auto_commit) as db: async with self.get_client(auto_commit=auto_commit) as db:
await db.execute( await db.execute(
""" """
DELETE FROM quotes WHERE id = ? DELETE
FROM quotes
WHERE id = ?
""", """,
(obj.id.hex,), (obj.id.hex,),
) )
@ -236,7 +244,9 @@ class QuoteRepo(IRepo[Quote]):
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
db.execute( db.execute(
""" """
SELECT * FROM quotes WHERE id = ? SELECT *
FROM quotes
WHERE id = ?
""", """,
(obj_id.hex,), (obj_id.hex,),
) as cursor, ) as cursor,
@ -251,7 +261,9 @@ class QuoteRepo(IRepo[Quote]):
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
db.execute( db.execute(
""" """
SELECT * FROM quotes WHERE quote = ? SELECT *
FROM quotes
WHERE quote = ?
""", """,
(quote,), (quote,),
) as cursor, ) as cursor,
@ -277,10 +289,11 @@ class QuoteRepo(IRepo[Quote]):
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
db.execute( db.execute(
""" """
SELECT * FROM quotes SELECT *
WHERE channel = ? FROM quotes
ORDER BY RANDOM() WHERE channel = ?
LIMIT 1 AND is_active = 1
ORDER BY RANDOM() LIMIT 1
""", """,
(channel_name,), (channel_name,),
) as cursor, ) as cursor,
@ -310,17 +323,15 @@ class ChatbotRepo(IRepo[Chatbot]):
async with ( async with (
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
await db.execute( await db.execute(
"""INSERT INTO chatbot ( """INSERT INTO chatbot (id,
id, user_id,
user_id, automatic_generation_timer,
automatic_generation_timer, automatic_quote_timer,
automatic_quote_timer, mods,
mods, created_at,
created_at, last_updated_at)
last_updated_at VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING *
) VALUES(?,?,?,?,?,?,?) """,
RETURNING *
""",
( (
obj.id.hex, obj.id.hex,
obj.user_id.hex, obj.user_id.hex,
@ -341,13 +352,12 @@ class ChatbotRepo(IRepo[Chatbot]):
async with ( async with (
self.get_client(auto_commit=auto_commit) as db, self.get_client(auto_commit=auto_commit) as db,
await db.execute( await db.execute(
"""UPDATE chatbot SET """UPDATE chatbot
automatic_generation_timer = ?, SET automatic_generation_timer = ?,
automatic_quote_timer = ?, automatic_quote_timer = ?,
mods = ?, mods = ?,
last_updated_at = ? last_updated_at = ?
WHERE user_id = ? WHERE user_id = ? RETURNING *
RETURNING *
""", """,
( (
obj.automatic_generation_timer, obj.automatic_generation_timer,

View file

@ -105,6 +105,7 @@ class Quote(BaseModel):
quote: str quote: str
author: str author: str
channel_name: str channel_name: str
is_active: bool = True
created_at: datetime.datetime = Field(default_factory=utils.get_utc_now) created_at: datetime.datetime = Field(default_factory=utils.get_utc_now)
last_updated_at: datetime.datetime = Field(default_factory=utils.get_utc_now) last_updated_at: datetime.datetime = Field(default_factory=utils.get_utc_now)

View file

@ -1,5 +1,5 @@
import tempfile import tempfile
from collections.abc import Generator from collections.abc import Iterator
from pathlib import Path from pathlib import Path
import yt_dlp import yt_dlp
@ -9,7 +9,7 @@ from pydantic import BaseModel
class DownloadClosedCaptionsSvc(BaseModel): class DownloadClosedCaptionsSvc(BaseModel):
@staticmethod @staticmethod
def run(youtube_url: str, sub_lang: str = "es") -> Generator[Path, None, None]: def run(youtube_url: str, sub_lang: str = "es") -> Iterator[Path]:
"""Download closed captions from a yt video and save it to a temp file """Download closed captions from a yt video and save it to a temp file
Args: Args:

View file

@ -176,7 +176,7 @@ async def five_chatbots(chatbot_factory, user):
@pytest.fixture @pytest.fixture
async def persisted_five_chatbots(five_chatbots, chatbot_repo): async def persisted_five_chatbots(five_chatbots, chatbot_repo):
return [await chatbot_repo.create(chatbot) for chat in five_chatbots] return [await chatbot_repo.create(chatbot) for chatbot in five_chatbots]
@pytest.fixture @pytest.fixture
@ -221,6 +221,7 @@ async def five_quotes(quote_factory):
@pytest.fixture @pytest.fixture
async def persisted_quote(quote_repo, quote): async def persisted_quote(quote_repo, quote):
quote.is_active = True
return await quote_repo.create(quote) return await quote_repo.create(quote)

View file

@ -234,6 +234,7 @@ async def test_create_quote_action(
username=user.username, username=user.username,
channel=user.username, channel=user.username,
quote=quote.quote, quote=quote.quote,
is_active=quote.is_active,
author=quote.author, author=quote.author,
) )
assert new_quote assert new_quote

149
tests/test_bot_mentions.py Normal file
View file

@ -0,0 +1,149 @@
import types
import pytest
from huesoporro.bot import SaveMessagesCog
class DummyDB:
def __init__(self, *args, **kwargs):
pass
class FakeChannel:
def __init__(self):
self.sent: list[str] = []
async def send(self, content: str):
self.sent.append(content)
class FakeAuthor:
def __init__(self, name: str):
self.name = name
class FakeMessage:
def __init__(self, content: str | None, author_name: str = "alice"):
self.content = content
self.author = FakeAuthor(author_name)
self.channel = FakeChannel()
class FakeBot:
def __init__(self, nick: str = "Junie", channel: str = "testchan"):
self.nick = nick
self.channel = channel
@pytest.fixture(autouse=True)
def patch_markov_and_svcs(monkeypatch):
monkeypatch.setattr("huesoporro.bot.MarkovDatabase", DummyDB)
class _DummyStoreSvc:
def __init__(self, db):
self.db = db
async def run(self, content: str | None):
return None
class _DummyGenSvc:
def __init__(self, db):
self.db = db
async def run(self, seed: str):
return None
monkeypatch.setattr("huesoporro.bot.SentenceStorerSvc", _DummyStoreSvc)
monkeypatch.setattr("huesoporro.bot.SentenceGeneratorSvc", _DummyGenSvc)
@pytest.fixture
def cog(monkeypatch) -> SaveMessagesCog:
bot = FakeBot()
return SaveMessagesCog(bot) # type: ignore[arg-type]
def make_async_fn(result=None, exc: Exception | None = None):
async def _fn(*args, **kwargs):
if exc:
raise exc
return result
return _fn
@pytest.mark.asyncio
async def test_handle_bot_mention_returns_none_on_empty_content(cog: SaveMessagesCog):
msg = FakeMessage(content=None)
res = await cog._handle_bot_mention(msg) # type: ignore[attr-defined]
assert res is None
assert msg.channel.sent == []
@pytest.mark.asyncio
async def test_handle_bot_mention_returns_none_when_no_mention(cog: SaveMessagesCog):
msg = FakeMessage(content="hello world")
res = await cog._handle_bot_mention(msg) # type: ignore[attr-defined]
assert res is None
assert msg.channel.sent == []
@pytest.mark.asyncio
async def test_handle_bot_mention_returns_none_when_only_mention(
cog: SaveMessagesCog, monkeypatch
):
msg = FakeMessage(content=cog.bot.nick) # type: ignore[attr-defined]
res = await cog._handle_bot_mention(msg) # type: ignore[attr-defined]
assert res is None
assert msg.channel.sent == []
@pytest.mark.asyncio
async def test_handle_bot_mention_generates_and_sends_reply(
cog: SaveMessagesCog, monkeypatch
):
msg = FakeMessage(content="juNie hello there")
monkeypatch.setattr(
cog, "generate_svc", types.SimpleNamespace(run=make_async_fn("foo bar"))
)
res = await cog._handle_bot_mention(msg) # type: ignore[attr-defined]
assert res == "foo bar"
assert len(msg.channel.sent) == 1
assert msg.channel.sent[0].startswith(f"@{msg.author.name} ")
assert msg.channel.sent[0].endswith("foo bar")
@pytest.mark.asyncio
async def test_handle_bot_mention_no_send_when_generator_returns_none(
cog: SaveMessagesCog, monkeypatch
):
msg = FakeMessage(content=f"{cog.bot.nick} hello") # type: ignore[attr-defined]
monkeypatch.setattr(
cog, "generate_svc", types.SimpleNamespace(run=make_async_fn(None))
)
res = await cog._handle_bot_mention(msg) # type: ignore[attr-defined]
assert res is None
assert msg.channel.sent == []
@pytest.mark.asyncio
async def test_handle_bot_mention_swallows_exceptions_and_returns_none(
cog: SaveMessagesCog, monkeypatch
):
msg = FakeMessage(content=f"{cog.bot.nick} hello") # type: ignore[attr-defined]
monkeypatch.setattr(
cog,
"generate_svc",
types.SimpleNamespace(run=make_async_fn(exc=RuntimeError("boom"))),
)
res = await cog._handle_bot_mention(msg) # type: ignore[attr-defined]
assert res is None
assert msg.channel.sent == []

2342
uv.lock generated

File diff suppressed because it is too large Load diff