feat: change QuoteStorerSvc to use the new quote repo instead of the legacy db object
This commit is contained in:
parent
75df191253
commit
3058ca112d
10 changed files with 75 additions and 47 deletions
|
|
@ -32,6 +32,7 @@ dev-dependencies = [
|
||||||
"pytest-asyncio>=0.25.0",
|
"pytest-asyncio>=0.25.0",
|
||||||
"ruff>=0.8.3",
|
"ruff>=0.8.3",
|
||||||
"pytest-coverage>=0.0",
|
"pytest-coverage>=0.0",
|
||||||
|
"polyfactory>=2.18.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
|
import datetime
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from src.huesoporro.models import User
|
from src.huesoporro.models import Quote, User
|
||||||
from src.huesoporro.svc.is_mod import IsModSvc
|
from src.huesoporro.svc.is_mod import IsModSvc
|
||||||
from src.huesoporro.svc.store_quote import QuoteStorerSvc
|
from src.huesoporro.svc.quote_storer_svc import QuoteStorerSvc
|
||||||
|
|
||||||
|
|
||||||
class StoreQuoteAction(BaseModel):
|
class StoreQuoteAction(BaseModel):
|
||||||
|
|
@ -11,8 +13,14 @@ class StoreQuoteAction(BaseModel):
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, user: User, channel: str, quote: str, author: str, username: str
|
self, user: User, channel: str, quote: str, author: str, username: str
|
||||||
) -> str:
|
) -> Quote | None:
|
||||||
if not await self.is_mod_svc.run(user=user, username=username, channel=channel):
|
if not await self.is_mod_svc.run(user=user, username=username, channel=channel):
|
||||||
return f"{username} is not a mod and cannot add quotes. Only moderators can add quotes. Sorry!"
|
return None
|
||||||
await self.quote_storer_svc.run(channel, quote, author)
|
new_quote = Quote(
|
||||||
return f"«{quote}» added by {author}."
|
quote=quote,
|
||||||
|
author=User(user=author, external_auth={}),
|
||||||
|
channel=User(user=channel, external_auth={}),
|
||||||
|
created_at=datetime.datetime.now(datetime.UTC),
|
||||||
|
last_updated_at=datetime.datetime.now(datetime.UTC),
|
||||||
|
)
|
||||||
|
return await self.quote_storer_svc.run(new_quote)
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,8 @@ from src.huesoporro.svc.generate import SentenceGeneratorSvc
|
||||||
from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc
|
from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc
|
||||||
from src.huesoporro.svc.hello import HelloGeneratorSvc
|
from src.huesoporro.svc.hello import HelloGeneratorSvc
|
||||||
from src.huesoporro.svc.is_mod import IsModSvc
|
from src.huesoporro.svc.is_mod import IsModSvc
|
||||||
|
from src.huesoporro.svc.quote_storer_svc import QuoteStorerSvc
|
||||||
from src.huesoporro.svc.store import SentenceStorerSvc
|
from src.huesoporro.svc.store import SentenceStorerSvc
|
||||||
from src.huesoporro.svc.store_quote import QuoteStorerSvc
|
|
||||||
|
|
||||||
|
|
||||||
class Bot(commands.Bot):
|
class Bot(commands.Bot):
|
||||||
|
|
@ -33,14 +33,15 @@ class Bot(commands.Bot):
|
||||||
self.generate_svc = SentenceGeneratorSvc(db=MarkovDB(channel=channel))
|
self.generate_svc = SentenceGeneratorSvc(db=MarkovDB(channel=channel))
|
||||||
self.hello_svc = HelloGeneratorSvc()
|
self.hello_svc = HelloGeneratorSvc()
|
||||||
db = Database()
|
db = Database()
|
||||||
self.store_quote_action = StoreQuoteAction(
|
|
||||||
quote_storer_svc=QuoteStorerSvc(db=db), is_mod_svc=IsModSvc(db=db)
|
|
||||||
)
|
|
||||||
self.quote_repo = QuoteRepo(s=get_settings())
|
self.quote_repo = QuoteRepo(s=get_settings())
|
||||||
self.get_random_quote_svc = RandomQuoteGetterSvc(quote_repo=self.quote_repo)
|
self.get_random_quote_svc = RandomQuoteGetterSvc(quote_repo=self.quote_repo)
|
||||||
self.get_random_quote_action = GetRandomQuoteAction(
|
self.get_random_quote_action = GetRandomQuoteAction(
|
||||||
quote_getter_svc=self.get_random_quote_svc
|
quote_getter_svc=self.get_random_quote_svc
|
||||||
)
|
)
|
||||||
|
self.store_quote_action = StoreQuoteAction(
|
||||||
|
quote_storer_svc=QuoteStorerSvc(quote_repo=self.quote_repo),
|
||||||
|
is_mod_svc=IsModSvc(db=db),
|
||||||
|
)
|
||||||
self.cbs = chatbot_settings
|
self.cbs = chatbot_settings
|
||||||
self.quote_routine = routines.routine(
|
self.quote_routine = routines.routine(
|
||||||
seconds=chatbot_settings.automatic_quote_timer, wait_first=True
|
seconds=chatbot_settings.automatic_quote_timer, wait_first=True
|
||||||
|
|
@ -72,15 +73,17 @@ class Bot(commands.Bot):
|
||||||
async def add_quote(self, ctx: commands.Context, *, quote: str):
|
async def add_quote(self, ctx: commands.Context, *, quote: str):
|
||||||
# extract author from quote; the author is the last word
|
# extract author from quote; the author is the last word
|
||||||
quote, author = quote.rsplit(" ", 1)
|
quote, author = quote.rsplit(" ", 1)
|
||||||
await ctx.send(
|
new_quote = await self.store_quote_action.run(
|
||||||
await self.store_quote_action.run(
|
|
||||||
user=self.user,
|
user=self.user,
|
||||||
channel=self.channel,
|
channel=self.channel,
|
||||||
quote=quote,
|
quote=quote,
|
||||||
author=author,
|
author=author,
|
||||||
username=ctx.author.name,
|
username=ctx.author.name,
|
||||||
)
|
)
|
||||||
)
|
if new_quote:
|
||||||
|
await ctx.send(new_quote.as_pretty_saved())
|
||||||
|
else:
|
||||||
|
await ctx.send(f"@{ctx.author.name} no tienes permisos para añadir citas")
|
||||||
|
|
||||||
@commands.command(aliases=["q", "quote"])
|
@commands.command(aliases=["q", "quote"])
|
||||||
async def get_random_quote(self, ctx: commands.Context):
|
async def get_random_quote(self, ctx: commands.Context):
|
||||||
|
|
|
||||||
|
|
@ -24,13 +24,6 @@ class Database(BaseModel):
|
||||||
def get_now() -> float:
|
def get_now() -> float:
|
||||||
return datetime.datetime.now(datetime.UTC).timestamp()
|
return datetime.datetime.now(datetime.UTC).timestamp()
|
||||||
|
|
||||||
async def save_quote(self, channel: str, quote: str, author: str, auto_commit=True):
|
|
||||||
async with self.get_client(auto_commit=auto_commit) as db:
|
|
||||||
await db.execute(
|
|
||||||
"INSERT INTO quotes (channel, quote, author) VALUES (?,?,?)",
|
|
||||||
(channel, quote, author),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def save_chatbot_settings(
|
async def save_chatbot_settings(
|
||||||
self, user: User, chatbot_settings: ChatbotSettings, auto_commit: bool = True
|
self, user: User, chatbot_settings: ChatbotSettings, auto_commit: bool = True
|
||||||
):
|
):
|
||||||
|
|
@ -94,16 +87,6 @@ class Database(BaseModel):
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
async def get_random_quote(self, channel_name: str) -> tuple[str, str] | None:
|
|
||||||
async with (
|
|
||||||
self.get_client() as db,
|
|
||||||
db.execute(
|
|
||||||
"SELECT quote, author FROM quotes WHERE channel = ? ORDER BY RANDOM() LIMIT 1",
|
|
||||||
(channel_name,),
|
|
||||||
) as cursor,
|
|
||||||
):
|
|
||||||
return await cursor.fetchone()
|
|
||||||
|
|
||||||
async def get_sentences(self, user: User) -> list[Sentence]:
|
async def get_sentences(self, user: User) -> list[Sentence]:
|
||||||
async with self.get_client() as db:
|
async with self.get_client() as db:
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = aiosqlite.Row
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,21 @@ class UserRepo(IRepo[User]):
|
||||||
|
|
||||||
class QuoteRepo(IRepo[Quote]):
|
class QuoteRepo(IRepo[Quote]):
|
||||||
async def create(self, obj: Quote, auto_commit=True) -> Quote:
|
async def create(self, obj: Quote, auto_commit=True) -> Quote:
|
||||||
raise NotImplementedError("Not implemented since it's not needed")
|
async with self.get_client(auto_commit=auto_commit) as db:
|
||||||
|
await db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO quotes (quote, author, channel, created_at, last_updated_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
obj.quote,
|
||||||
|
obj.author.user,
|
||||||
|
obj.channel.user,
|
||||||
|
obj.created_at,
|
||||||
|
obj.last_updated_at,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return obj
|
||||||
|
|
||||||
async def update(self, obj: Quote, auto_commit=True) -> Quote:
|
async def update(self, obj: Quote, auto_commit=True) -> Quote:
|
||||||
raise NotImplementedError("Not implemented since it's not needed")
|
raise NotImplementedError("Not implemented since it's not needed")
|
||||||
|
|
|
||||||
11
src/huesoporro/svc/quote_storer_svc.py
Normal file
11
src/huesoporro/svc/quote_storer_svc.py
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from src.huesoporro.infra.repos import QuoteRepo
|
||||||
|
from src.huesoporro.models import Quote
|
||||||
|
|
||||||
|
|
||||||
|
class QuoteStorerSvc(BaseModel):
|
||||||
|
quote_repo: QuoteRepo
|
||||||
|
|
||||||
|
async def run(self, quote: Quote) -> Quote:
|
||||||
|
return await self.quote_repo.create(quote)
|
||||||
|
|
@ -1,10 +0,0 @@
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from src.huesoporro.infra.db import Database
|
|
||||||
|
|
||||||
|
|
||||||
class QuoteStorerSvc(BaseModel):
|
|
||||||
db: Database
|
|
||||||
|
|
||||||
async def run(self, channel: str, quote: str, author: str):
|
|
||||||
await self.db.save_quote(channel, quote, author)
|
|
||||||
|
|
@ -3,9 +3,11 @@ from pathlib import Path
|
||||||
import pytest
|
import pytest
|
||||||
from caribou.migrate import Database as CaribouDatabase
|
from caribou.migrate import Database as CaribouDatabase
|
||||||
from caribou.migrate import load_migrations
|
from caribou.migrate import load_migrations
|
||||||
|
from polyfactory.factories.pydantic_factory import ModelFactory
|
||||||
|
from polyfactory.pytest_plugin import register_fixture
|
||||||
|
|
||||||
from src.huesoporro.infra.db import Database
|
from src.huesoporro.infra.db import Database
|
||||||
from src.huesoporro.models import ChatbotSettings, User
|
from src.huesoporro.models import ChatbotSettings, Quote, User
|
||||||
from src.huesoporro.settings import Settings
|
from src.huesoporro.settings import Settings
|
||||||
from src.huesoporro.svc.backoff_service import BackoffService
|
from src.huesoporro.svc.backoff_service import BackoffService
|
||||||
from src.huesoporro.svc.is_mod import IsModSvc
|
from src.huesoporro.svc.is_mod import IsModSvc
|
||||||
|
|
@ -79,3 +81,12 @@ async def backoff_svc(backoff_callable, async_backoff_callable):
|
||||||
backoff_svc.add_callable(backoff_callable, 3)
|
backoff_svc.add_callable(backoff_callable, 3)
|
||||||
backoff_svc.add_callable(async_backoff_callable, 3)
|
backoff_svc.add_callable(async_backoff_callable, 3)
|
||||||
return backoff_svc
|
return backoff_svc
|
||||||
|
|
||||||
|
|
||||||
|
@register_fixture()
|
||||||
|
class QuoteFactory(ModelFactory[Quote]): ...
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def quote(quote_factory):
|
||||||
|
return quote_factory.build()
|
||||||
|
|
|
||||||
|
|
@ -68,3 +68,8 @@ async def test_get_random_quote(quote_repo: QuoteRepo):
|
||||||
assert quote
|
assert quote
|
||||||
assert quote.author.user == "author"
|
assert quote.author.user == "author"
|
||||||
assert quote.channel.user == "channel"
|
assert quote.channel.user == "channel"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_quote(quote, quote_repo):
|
||||||
|
new_quote = await quote_repo.create(quote)
|
||||||
|
assert new_quote == quote
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -495,6 +495,7 @@ dependencies = [
|
||||||
[package.dev-dependencies]
|
[package.dev-dependencies]
|
||||||
dev = [
|
dev = [
|
||||||
{ name = "mypy" },
|
{ name = "mypy" },
|
||||||
|
{ name = "polyfactory" },
|
||||||
{ name = "pytest" },
|
{ name = "pytest" },
|
||||||
{ name = "pytest-asyncio" },
|
{ name = "pytest-asyncio" },
|
||||||
{ name = "pytest-coverage" },
|
{ name = "pytest-coverage" },
|
||||||
|
|
@ -523,6 +524,7 @@ requires-dist = [
|
||||||
[package.metadata.requires-dev]
|
[package.metadata.requires-dev]
|
||||||
dev = [
|
dev = [
|
||||||
{ name = "mypy", specifier = ">=1.13.0" },
|
{ name = "mypy", specifier = ">=1.13.0" },
|
||||||
|
{ name = "polyfactory", specifier = ">=2.18.1" },
|
||||||
{ name = "pytest", specifier = ">=8.3.4" },
|
{ name = "pytest", specifier = ">=8.3.4" },
|
||||||
{ name = "pytest-asyncio", specifier = ">=0.25.0" },
|
{ name = "pytest-asyncio", specifier = ">=0.25.0" },
|
||||||
{ name = "pytest-coverage", specifier = ">=0.0" },
|
{ name = "pytest-coverage", specifier = ">=0.0" },
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue