diff --git a/pyproject.toml b/pyproject.toml index d2f948a..8b774ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dev-dependencies = [ "pytest-asyncio>=0.25.0", "ruff>=0.8.3", "pytest-coverage>=0.0", + "polyfactory>=2.18.1", ] [[tool.mypy.overrides]] diff --git a/src/huesoporro/actions/store_quote.py b/src/huesoporro/actions/store_quote.py index 2d569a6..87b5a84 100644 --- a/src/huesoporro/actions/store_quote.py +++ b/src/huesoporro/actions/store_quote.py @@ -1,8 +1,10 @@ +import datetime + from pydantic import BaseModel -from src.huesoporro.models import User +from src.huesoporro.models import Quote, User from src.huesoporro.svc.is_mod import IsModSvc -from src.huesoporro.svc.store_quote import QuoteStorerSvc +from src.huesoporro.svc.quote_storer_svc import QuoteStorerSvc class StoreQuoteAction(BaseModel): @@ -11,8 +13,14 @@ class StoreQuoteAction(BaseModel): async def run( self, user: User, channel: str, quote: str, author: str, username: str - ) -> str: + ) -> Quote | None: if not await self.is_mod_svc.run(user=user, username=username, channel=channel): - return f"{username} is not a mod and cannot add quotes. Only moderators can add quotes. Sorry!" - await self.quote_storer_svc.run(channel, quote, author) - return f"«{quote}» added by {author}." + return None + new_quote = Quote( + quote=quote, + author=User(user=author, external_auth={}), + channel=User(user=channel, external_auth={}), + created_at=datetime.datetime.now(datetime.UTC), + last_updated_at=datetime.datetime.now(datetime.UTC), + ) + return await self.quote_storer_svc.run(new_quote) diff --git a/src/huesoporro/bot.py b/src/huesoporro/bot.py index f682260..14c39db 100644 --- a/src/huesoporro/bot.py +++ b/src/huesoporro/bot.py @@ -19,8 +19,8 @@ from src.huesoporro.svc.generate import SentenceGeneratorSvc from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc from src.huesoporro.svc.hello import HelloGeneratorSvc from src.huesoporro.svc.is_mod import IsModSvc +from src.huesoporro.svc.quote_storer_svc import QuoteStorerSvc from src.huesoporro.svc.store import SentenceStorerSvc -from src.huesoporro.svc.store_quote import QuoteStorerSvc class Bot(commands.Bot): @@ -33,14 +33,15 @@ class Bot(commands.Bot): self.generate_svc = SentenceGeneratorSvc(db=MarkovDB(channel=channel)) self.hello_svc = HelloGeneratorSvc() db = Database() - self.store_quote_action = StoreQuoteAction( - quote_storer_svc=QuoteStorerSvc(db=db), is_mod_svc=IsModSvc(db=db) - ) self.quote_repo = QuoteRepo(s=get_settings()) self.get_random_quote_svc = RandomQuoteGetterSvc(quote_repo=self.quote_repo) self.get_random_quote_action = GetRandomQuoteAction( quote_getter_svc=self.get_random_quote_svc ) + self.store_quote_action = StoreQuoteAction( + quote_storer_svc=QuoteStorerSvc(quote_repo=self.quote_repo), + is_mod_svc=IsModSvc(db=db), + ) self.cbs = chatbot_settings self.quote_routine = routines.routine( seconds=chatbot_settings.automatic_quote_timer, wait_first=True @@ -72,15 +73,17 @@ class Bot(commands.Bot): async def add_quote(self, ctx: commands.Context, *, quote: str): # extract author from quote; the author is the last word quote, author = quote.rsplit(" ", 1) - await ctx.send( - await self.store_quote_action.run( - user=self.user, - channel=self.channel, - quote=quote, - author=author, - username=ctx.author.name, - ) + new_quote = await self.store_quote_action.run( + user=self.user, + channel=self.channel, + quote=quote, + author=author, + username=ctx.author.name, ) + if new_quote: + await ctx.send(new_quote.as_pretty_saved()) + else: + await ctx.send(f"@{ctx.author.name} no tienes permisos para añadir citas") @commands.command(aliases=["q", "quote"]) async def get_random_quote(self, ctx: commands.Context): diff --git a/src/huesoporro/infra/db.py b/src/huesoporro/infra/db.py index 86d8260..898b768 100644 --- a/src/huesoporro/infra/db.py +++ b/src/huesoporro/infra/db.py @@ -24,13 +24,6 @@ class Database(BaseModel): def get_now() -> float: return datetime.datetime.now(datetime.UTC).timestamp() - async def save_quote(self, channel: str, quote: str, author: str, auto_commit=True): - async with self.get_client(auto_commit=auto_commit) as db: - await db.execute( - "INSERT INTO quotes (channel, quote, author) VALUES (?,?,?)", - (channel, quote, author), - ) - async def save_chatbot_settings( self, user: User, chatbot_settings: ChatbotSettings, auto_commit: bool = True ): @@ -94,16 +87,6 @@ class Database(BaseModel): ) await db.commit() - async def get_random_quote(self, channel_name: str) -> tuple[str, str] | None: - async with ( - self.get_client() as db, - db.execute( - "SELECT quote, author FROM quotes WHERE channel = ? ORDER BY RANDOM() LIMIT 1", - (channel_name,), - ) as cursor, - ): - return await cursor.fetchone() - async def get_sentences(self, user: User) -> list[Sentence]: async with self.get_client() as db: db.row_factory = aiosqlite.Row diff --git a/src/huesoporro/infra/repos.py b/src/huesoporro/infra/repos.py index a4aab5b..3256ffa 100644 --- a/src/huesoporro/infra/repos.py +++ b/src/huesoporro/infra/repos.py @@ -116,7 +116,21 @@ class UserRepo(IRepo[User]): class QuoteRepo(IRepo[Quote]): async def create(self, obj: Quote, auto_commit=True) -> Quote: - raise NotImplementedError("Not implemented since it's not needed") + async with self.get_client(auto_commit=auto_commit) as db: + await db.execute( + """ + INSERT INTO quotes (quote, author, channel, created_at, last_updated_at) + VALUES (?, ?, ?, ?, ?) + """, + ( + obj.quote, + obj.author.user, + obj.channel.user, + obj.created_at, + obj.last_updated_at, + ), + ) + return obj async def update(self, obj: Quote, auto_commit=True) -> Quote: raise NotImplementedError("Not implemented since it's not needed") diff --git a/src/huesoporro/svc/quote_storer_svc.py b/src/huesoporro/svc/quote_storer_svc.py new file mode 100644 index 0000000..bc71320 --- /dev/null +++ b/src/huesoporro/svc/quote_storer_svc.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + +from src.huesoporro.infra.repos import QuoteRepo +from src.huesoporro.models import Quote + + +class QuoteStorerSvc(BaseModel): + quote_repo: QuoteRepo + + async def run(self, quote: Quote) -> Quote: + return await self.quote_repo.create(quote) diff --git a/src/huesoporro/svc/store_quote.py b/src/huesoporro/svc/store_quote.py deleted file mode 100644 index fe56899..0000000 --- a/src/huesoporro/svc/store_quote.py +++ /dev/null @@ -1,10 +0,0 @@ -from pydantic import BaseModel - -from src.huesoporro.infra.db import Database - - -class QuoteStorerSvc(BaseModel): - db: Database - - async def run(self, channel: str, quote: str, author: str): - await self.db.save_quote(channel, quote, author) diff --git a/tests/conftest.py b/tests/conftest.py index c8c6c23..2163026 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,9 +3,11 @@ from pathlib import Path import pytest from caribou.migrate import Database as CaribouDatabase from caribou.migrate import load_migrations +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.pytest_plugin import register_fixture from src.huesoporro.infra.db import Database -from src.huesoporro.models import ChatbotSettings, User +from src.huesoporro.models import ChatbotSettings, Quote, User from src.huesoporro.settings import Settings from src.huesoporro.svc.backoff_service import BackoffService from src.huesoporro.svc.is_mod import IsModSvc @@ -79,3 +81,12 @@ async def backoff_svc(backoff_callable, async_backoff_callable): backoff_svc.add_callable(backoff_callable, 3) backoff_svc.add_callable(async_backoff_callable, 3) return backoff_svc + + +@register_fixture() +class QuoteFactory(ModelFactory[Quote]): ... + + +@pytest.fixture +def quote(quote_factory): + return quote_factory.build() diff --git a/tests/test_repos.py b/tests/test_repos.py index 884d088..5327108 100644 --- a/tests/test_repos.py +++ b/tests/test_repos.py @@ -68,3 +68,8 @@ async def test_get_random_quote(quote_repo: QuoteRepo): assert quote assert quote.author.user == "author" assert quote.channel.user == "channel" + + +async def test_create_quote(quote, quote_repo): + new_quote = await quote_repo.create(quote) + assert new_quote == quote diff --git a/uv.lock b/uv.lock index ccca98c..3da5f24 100644 --- a/uv.lock +++ b/uv.lock @@ -495,6 +495,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "mypy" }, + { name = "polyfactory" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-coverage" }, @@ -523,6 +524,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "mypy", specifier = ">=1.13.0" }, + { name = "polyfactory", specifier = ">=2.18.1" }, { name = "pytest", specifier = ">=8.3.4" }, { name = "pytest-asyncio", specifier = ">=0.25.0" }, { name = "pytest-coverage", specifier = ">=0.0" },