From 3bc4e19de1cb5b30377f0da2c799645093dfc317 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C4=83t=C4=83lin?= Date: Thu, 19 Dec 2024 18:13:38 +0100 Subject: [PATCH 01/27] feat: add backoff service and some message reactions --- pyproject.toml | 2 +- src/huesoporro/api/errors.py | 2 +- src/huesoporro/api/routes/api.py | 18 ++- src/huesoporro/bot.py | 172 ++++++++++++++++++++++++- src/huesoporro/svc/backoff_service.py | 111 ++++++++++++++++ src/huesoporro/svc/generate.py | 12 +- src/huesoporro/svc/get_random_quote.py | 2 +- src/huesoporro/svc/hello.py | 4 +- tests/conftest.py | 32 ++++- tests/test_svc.py | 61 +++++++++ uv.lock | 2 +- 11 files changed, 394 insertions(+), 24 deletions(-) create mode 100644 src/huesoporro/svc/backoff_service.py diff --git a/pyproject.toml b/pyproject.toml index 9d19c3b..9333e7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ extend-select = [ "W", "C90", "I", "N", "UP", "S", "BLE", "B", "A", "COM", "C4", "DTZ", "T10", "EM", "ISC", "T20", "PT", "RSE", "RET", "SIM", "PTH", "ERA", "PGH", "PL", "RUF", "FURB", "PERF" ] -extend-ignore = ["S101", "ISC002", "COM812", "ISC001"] +extend-ignore = ["S101", "ISC002", "COM812", "ISC001", "EM101", "EM102"] [tool.pytest.ini_options] asyncio_mode = "auto" diff --git a/src/huesoporro/api/errors.py b/src/huesoporro/api/errors.py index e88a6f0..19426df 100644 --- a/src/huesoporro/api/errors.py +++ b/src/huesoporro/api/errors.py @@ -30,7 +30,7 @@ def httpx_status_error_handler(_: Request, exc: httpx.HTTPStatusError): ) -async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # noqa: F821 +async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # type: ignore[name-defined] # noqa: F821 """Hook function that will be invoked after each exception.""" state = scope["app"].state if not hasattr(state, "error_count"): diff --git a/src/huesoporro/api/routes/api.py b/src/huesoporro/api/routes/api.py index dff54c2..d427cde 100644 --- a/src/huesoporro/api/routes/api.py +++ b/src/huesoporro/api/routes/api.py @@ -52,13 +52,20 @@ async def get_index(user: User, gbs: ChatbotSettingsGetterSvc) -> Template: @put("/api/v1/bot") async def manage_bot( - user: User, data: ManageBotDTO, gbs: ChatbotSettingsGetterSvc, bm: BotsManager + user: User, + data: ManageBotDTO, + gbs: ChatbotSettingsGetterSvc, + sbs: ChatbotSettingsStorerSvc, + bm: BotsManager, ) -> Response: chatbot_settings = await gbs.run(user=user) + if not chatbot_settings: + await sbs.run(user=user, bot_settings=ChatbotSettings()) + chatbot_settings = await gbs.run(user=user) if data.command == "start": if not data.channel_name: return Response({"message": "Channel name is required"}, status_code=400) - bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings) + bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings) # type: ignore[arg-type] if user.user in bm.bots: await bm.run_user_bot(user) return Response({"message": "Bot started"}) @@ -78,8 +85,11 @@ async def get_bot_status(user: User, bm: BotsManager) -> dict: @get("/api/v1/bot/settings") async def get_bot_settings( user: User, gbs: ChatbotSettingsGetterSvc -) -> ChatbotSettings: - return await gbs.run(user=user) +) -> ChatbotSettings | dict: + cbs = await gbs.run(user=user) + if not cbs: + return {"status": "Not found"} + return cbs @put("/api/v1/bot/settings") diff --git a/src/huesoporro/bot.py b/src/huesoporro/bot.py index 82151c0..fae650e 100644 --- a/src/huesoporro/bot.py +++ b/src/huesoporro/bot.py @@ -1,4 +1,7 @@ import asyncio +import random +from collections.abc import Callable +from enum import StrEnum from loguru import logger from twitchio import Channel @@ -8,6 +11,7 @@ from src.huesoporro.actions.store_quote import StoreQuoteAction from src.huesoporro.infra.db import Database from src.huesoporro.libs.db import Database as MarkovDB from src.huesoporro.models import ChatbotSettings, User +from src.huesoporro.svc.backoff_service import BackoffService from src.huesoporro.svc.generate import SentenceGeneratorSvc from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc from src.huesoporro.svc.hello import HelloGeneratorSvc @@ -75,16 +79,18 @@ class Bot(commands.Bot): @commands.command(aliases=["q", "quote"]) async def get_random_quote(self, ctx: commands.Context): quote = await self.get_random_quote_svc.run(channel_name=self.channel) - await ctx.send(f"«{quote[0]}» - {quote[1]}") + if quote: + await ctx.send(f"«{quote[0]}» - {quote[1]}") def get_channel_conn(self) -> Channel: return Channel(name=self.channel, websocket=self._connection) async def send_quote(self): quote = await self.get_random_quote_svc.run(channel_name=self.channel) - channel = self.get_channel_conn() - logger.info(f"Sending random quote {quote[0]}") - await channel.send(f"«{quote[0]}» - {quote[1]}") + if quote: + channel = self.get_channel_conn() + logger.info(f"Sending random quote {quote[0]}") + await channel.send(f"«{quote[0]}» - {quote[1]}") async def send_generation(self): sentence = await self.generate_svc.run() @@ -108,14 +114,15 @@ class Bot(commands.Bot): self.generation_routine.cancel() -class SaveMessagesCog(commands.Cog): +class SaveMessagesCog2(commands.Cog): def __init__(self, bot): self.bot = bot self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel)) + self.hello_svc = HelloGeneratorSvc() + self.backoff_svc = BackoffService() @commands.Cog.event() async def event_message(self, message): - # An event inside a cog! content = message.content if content.startswith("!"): return @@ -125,6 +132,159 @@ class SaveMessagesCog(commands.Cog): await self.store_svc.run(content) + if message.content in ["hola", "HOLA", "hiii", "ayo"]: + hello_message = self.hello_svc.run(message.author.name) + await message.channel.send(hello_message) + return + + if message.content == "Yes": + await message.channel.send("Indeed") + return + + if message.content.startswith("WHAT"): + await message.channel.send("WHAT Ramon") + return + + laughs_messages = [ + "om", + "KEK", + "LuL", + "LUL", + "OMEGALUL", + "kek", + "keking", + "KEKW", + "OMEGADANCEBUTFAST", + ] + + if message.content in laughs_messages: + await message.channel.send(random.choice(laughs_messages)) # noqa: S311 + return + + +class MessageType(StrEnum): + COMMAND = "COMMAND" + HELLO = "HELLO" + YES = "YES" + WHAT = "WHAT" + LAUGH = "LAUGH" + OTHER = "OTHER" + + +class MessageHandler: + """Handles different types of messages with their corresponding responses""" + + def __init__(self, channel_send_func: Callable): + self.hello_patterns = ["hola", "HOLA", "hiii", "ayo"] + self.laugh_patterns = [ + "om", + "KEK", + "LuL", + "LUL", + "OMEGALUL", + "kek", + "keking", + "KEKW", + "OMEGADANCEBUTFAST", + ] + self.send = channel_send_func + + def get_message_type(self, content: str) -> MessageType: + """Determines the type of message based on its content""" + if content.startswith("!"): + return MessageType.COMMAND + if content in self.hello_patterns: + return MessageType.HELLO + if content == "Yes": + return MessageType.YES + if content.startswith("WHAT"): + return MessageType.WHAT + if content in self.laugh_patterns: + return MessageType.LAUGH + return MessageType.OTHER + + async def handle_hello(self, author_name: str, hello_svc) -> str: + """Handles hello messages""" + return hello_svc.run(author_name) + + async def handle_laugh(self) -> str: + """Handles laugh messages""" + return random.choice(self.laugh_patterns) # noqa: S311 + + +class SaveMessagesCog(commands.Cog): + def __init__(self, bot): + self.bot = bot + self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel)) + self.hello_svc = HelloGeneratorSvc() + self.backoff_svc = BackoffService() + self.message_handler = MessageHandler(self._send_message) + + # Register a separate send function for each message type + self.send_functions = { + MessageType.HELLO: self._create_typed_send("hello"), + MessageType.YES: self._create_typed_send("yes"), + MessageType.WHAT: self._create_typed_send("what"), + MessageType.LAUGH: self._create_typed_send("laugh"), + } + + # Register each send function with its own backoff + for func in self.send_functions.values(): + self.backoff_svc.add_callable(func, backoff_seconds=10) + + def _create_typed_send(self, type_name: str): + """Creates a send function for a specific message type""" + + async def typed_send(content: str): + if hasattr(self, "current_message"): + await self.current_message.channel.send(content) + + # Set a unique name for the function to ensure it's treated as distinct + typed_send.__name__ = f"send_{type_name}" + return typed_send + + async def _send_message(self, content: str): + """Generic send message function (for non-backoff uses)""" + if hasattr(self, "current_message"): + await self.current_message.channel.send(content) + + @commands.Cog.event() + async def event_message(self, message): + """Main message event handler""" + if not message.author: + return + + # Store reference to current message for send functions + self.current_message = message + + # Store the message content + await self.store_svc.run(message.content) + + # Determine message type and handle accordingly + msg_type = self.message_handler.get_message_type(message.content) + + response = None + + match msg_type: + case MessageType.COMMAND: + return + case MessageType.HELLO: + response = await self.message_handler.handle_hello( + message.author.name, self.hello_svc + ) + case MessageType.YES: + response = "Indeed" + case MessageType.WHAT: + response = "WHAT Ramon" + case MessageType.LAUGH: + response = await self.message_handler.handle_laugh() + case MessageType.OTHER: + return + + if response and msg_type in self.send_functions: + # Use the type-specific send function + await self.backoff_svc.call_async(self.send_functions[msg_type], response) + class BotsManager: def __init__(self): diff --git a/src/huesoporro/svc/backoff_service.py b/src/huesoporro/svc/backoff_service.py new file mode 100644 index 0000000..08123d1 --- /dev/null +++ b/src/huesoporro/svc/backoff_service.py @@ -0,0 +1,111 @@ +import asyncio +import time +from collections.abc import Callable + +from pydantic import BaseModel + + +class CallableInfo(BaseModel): + backoff_seconds: int + last_call: float | None = None + is_async: bool + + +class BackoffService(BaseModel): + """Use this service to implement a backoff strategy on random callables. + The callable will be called the first time without delay but every subsequent + call may be hold off for a given time + + Examples: + >>> def callable(x): print(f"foo {x}") + >>> backoff_service = BackoffService() + >>> backoff_service.add_callable(callable, backoff_time=3) + >>> backoff_service.call(callable, "bar") # prints "foo bar" + >>> backoff_service.call(callable, "baz") # prints nothing + >>> # wait 3 seconds before calling callable again + >>> backoff_service.call(callable, "qux") # prints "foo qux" + """ + + callables: dict[Callable, CallableInfo] = {} + + def add_callable(self, func: Callable, backoff_seconds: int): + """Adds a callable to the local mapper with its backoff configuration. + + Args: + func: The function to be registered + backoff_seconds: The number of seconds to wait between successive calls + """ + self.callables[func] = CallableInfo( + backoff_seconds=backoff_seconds, is_async=self._is_async(func) + ) + + @staticmethod + def _is_async(func: Callable) -> bool: + """Checks if the callable is async""" + return asyncio.iscoroutinefunction(func) + + def _can_call(self, func: Callable) -> bool: + """Determines if enough time has passed since the last call""" + if func not in self.callables: + raise ValueError(f"Function {func} not registered with backoff service") + + func_info = self.callables[func] + last_call = func_info.last_call + + if last_call is None: + return True + + elapsed = time.time() - last_call + return elapsed >= func_info.backoff_seconds + + def call(self, func: Callable, *args, **kwargs): + """Calls the callable with arguments and returns its result if it isn't held off + + Args: + func: The function to call + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Optional[Any]: The result of the function call if executed, None if held off + """ + if func not in self.callables: + raise ValueError(f"Function {func} not registered with backoff service") + + if self.callables[func].is_async: + raise ValueError( + "Cannot call async function with .call(), use .call_async() instead" + ) + + if not self._can_call(func): + return None + + result = func(*args, **kwargs) + self.callables[func].last_call = time.time() + return result + + async def call_async(self, func: Callable, *args, **kwargs): + """Same as .call(...) but for async functions + + Args: + func: The async function to call + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Optional[Any]: The result of the async function call if executed, None if held off + """ + if func not in self.callables: + raise ValueError(f"Function {func} not registered with backoff service") + + if not self.callables[func].is_async: + raise ValueError( + "Cannot call sync function with .call_async(), use .call() instead" + ) + + if not self._can_call(func): + return None + + result = await func(*args, **kwargs) + self.callables[func].last_call = time.time() + return result diff --git a/src/huesoporro/svc/generate.py b/src/huesoporro/svc/generate.py index 0f5058d..5bef7fe 100644 --- a/src/huesoporro/svc/generate.py +++ b/src/huesoporro/svc/generate.py @@ -133,11 +133,11 @@ class SentenceGeneratorSvc(BaseModel): self, sentence: str | None = None, ) -> str | None: - if sentence: - sentence = tokenize(sentence) - logger.info(f"Generating sentence from {sentence}") - sentence, success = self.generate(sentence) - logger.info(f"Generated sentence: {sentence}") + split_sentence = tokenize(sentence) if sentence else None + + logger.info(f"Generating sentence from {split_sentence}") + generated_sentence, success = self.generate(split_sentence) + logger.info(f"Generated sentence: {generated_sentence}") if not success: return None - return sentence + return generated_sentence diff --git a/src/huesoporro/svc/get_random_quote.py b/src/huesoporro/svc/get_random_quote.py index 4371d4e..f6b4f6c 100644 --- a/src/huesoporro/svc/get_random_quote.py +++ b/src/huesoporro/svc/get_random_quote.py @@ -6,5 +6,5 @@ from src.huesoporro.infra.db import Database class RandomQuoteGetterSvc(BaseModel): db: Database - async def run(self, channel_name: str) -> tuple[str, str]: + async def run(self, channel_name: str) -> tuple[str, str] | None: return await self.db.get_random_quote(channel_name=channel_name) diff --git a/src/huesoporro/svc/hello.py b/src/huesoporro/svc/hello.py index 19dd38f..5be2180 100644 --- a/src/huesoporro/svc/hello.py +++ b/src/huesoporro/svc/hello.py @@ -11,8 +11,10 @@ class HelloGeneratorSvc(BaseModel): "Hi", "Bon día", "Hola mi tremendo elemento", + "HOLA", + "hiii", ] ) def run(self, username: str): - return f"{random.choice(self.hellos)} {username}" # noqa: S311 + return f"{random.choice(self.hellos)} @{username}" # noqa: S311 diff --git a/tests/conftest.py b/tests/conftest.py index b01710c..19218d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from caribou.migrate import load_migrations from src.huesoporro.infra.db import Database from src.huesoporro.models import ChatbotSettings, TwitchAuth, User from src.huesoporro.settings import Settings +from src.huesoporro.svc.backoff_service import BackoffService from src.huesoporro.svc.is_mod import IsModSvc @@ -16,7 +17,8 @@ def user() -> User: user="huesoporro", expires_at=1671234567.0, twitch_auth=TwitchAuth( - access_token="test_access_token", refresh_token="test_refresh_token" + access_token="test_access_token", # noqa: S106 + refresh_token="test_refresh_token", # noqa: S106 ), ) @@ -27,8 +29,8 @@ def s(tmp_path: Path, user: User) -> Settings: static_files_path=tmp_path / "static_files", db_filepath=tmp_path / "huesoporro.db", twitch_client_id="test_client_id", - twitch_client_secret="test_client_secret", # type: ignore[arg-type] - jwt_secret="test_jwt_secret", # type: ignore[arg-type] + twitch_client_secret="test_client_secret", # type: ignore[arg-type] # noqa: S106 + jwt_secret="test_jwt_secret", # type: ignore[arg-type] # noqa: S106 allowed_users=[user.user], ) @@ -54,3 +56,27 @@ async def chatbot_settings(db: Database, user) -> ChatbotSettings: cbs = ChatbotSettings(mods=[user.user, "allowed_user"]) await db.save_chatbot_settings(user=user, chatbot_settings=cbs) return cbs + + +@pytest.fixture +def backoff_callable(): + def foo(): + return "foo" + + return foo + + +@pytest.fixture +def async_backoff_callable(): + async def async_foo(): + return "async foo" + + return async_foo + + +@pytest.fixture +async def backoff_svc(backoff_callable, async_backoff_callable): + backoff_svc = BackoffService() + backoff_svc.add_callable(backoff_callable, 3) + backoff_svc.add_callable(async_backoff_callable, 3) + return backoff_svc diff --git a/tests/test_svc.py b/tests/test_svc.py index 5a5aacd..c88e1c9 100644 --- a/tests/test_svc.py +++ b/tests/test_svc.py @@ -1,3 +1,6 @@ +import asyncio +import time + import pytest from src.huesoporro.models import ChatbotSettings, User @@ -33,3 +36,61 @@ async def test_is_mod_svc_returns_false_for_user_not_in_modlist( ): is_mod = await is_mod_svc.run(user=user, username="TestUser2", channel=user.user) assert not is_mod + + +async def test_backoff_svc_returns_for_first_attempt( + backoff_svc, backoff_callable, async_backoff_callable +): + assert backoff_svc.call(backoff_callable) == "foo" + + assert await backoff_svc.call_async(async_backoff_callable) == "async foo" + + +async def test_backoff_svc_returns_none_for_second_attempt( + backoff_svc, backoff_callable, async_backoff_callable +): + assert backoff_svc.call(backoff_callable) == "foo" + assert backoff_svc.call(backoff_callable) is None + + assert await backoff_svc.call_async(async_backoff_callable) == "async foo" + assert await backoff_svc.call_async(async_backoff_callable) is None + + +async def test_backoff_svc_returns_for_second_attempt_after_delay( + backoff_svc, backoff_callable, async_backoff_callable +): + assert backoff_svc.call(backoff_callable) == "foo" + assert backoff_svc.call(backoff_callable) is None + time.sleep(3) + assert backoff_svc.call(backoff_callable) == "foo" + + assert await backoff_svc.call_async(async_backoff_callable) == "async foo" + assert await backoff_svc.call_async(async_backoff_callable) is None + await asyncio.sleep(3) + assert await backoff_svc.call_async(async_backoff_callable) == "async foo" + + +async def test_backoff_svc_raises_value_error_for_unknown_callable(backoff_svc): + with pytest.raises(ValueError, match="not registered with backoff service"): + backoff_svc.call(lambda: "foo") + + +async def test_backoff_svc_raises_value_error_for_unknown_async_callable(backoff_svc): + with pytest.raises(ValueError, match="not registered with backoff service"): + await backoff_svc.call_async(lambda: "foo") + + +async def test_backoff_svc_raises_value_error_for_async_called_from_sync( + backoff_svc, backoff_callable +): + with pytest.raises( + ValueError, match="Cannot call sync function with .call_async()" + ): + await backoff_svc.call_async(backoff_callable) + + +async def test_backoff_svc_raises_value_error_for_sync_called_from_async( + backoff_svc, async_backoff_callable +): + with pytest.raises(ValueError, match="Cannot call async function with .call()"): + backoff_svc.call(async_backoff_callable) diff --git a/uv.lock b/uv.lock index 618cfee..458a1a4 100644 --- a/uv.lock +++ b/uv.lock @@ -460,7 +460,7 @@ wheels = [ [[package]] name = "huesoporro" -version = "0.2.2" +version = "0.2.3" source = { virtual = "." } dependencies = [ { name = "aiosqlite" }, From efac1cc33ccf807f3c2b1a176286db24af891c23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C4=83t=C4=83lin?= Date: Thu, 19 Dec 2024 18:18:37 +0100 Subject: [PATCH 02/27] chore: update to v0.2.4 and remove useless code --- charts/huesoporro/Chart.yaml | 4 +-- charts/huesoporro/values.yaml | 2 +- pyproject.toml | 2 +- src/huesoporro/bot.py | 48 ----------------------------------- uv.lock | 2 +- 5 files changed, 5 insertions(+), 53 deletions(-) diff --git a/charts/huesoporro/Chart.yaml b/charts/huesoporro/Chart.yaml index 1c6576a..22ca75f 100644 --- a/charts/huesoporro/Chart.yaml +++ b/charts/huesoporro/Chart.yaml @@ -15,10 +15,10 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.2.3 +version: 0.2.4 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to # follow Semantic Versioning. They should reflect the version the application is using. # It is recommended to use it with quotes. -appVersion: "0.2.3" +appVersion: "0.2.4" diff --git a/charts/huesoporro/values.yaml b/charts/huesoporro/values.yaml index ae50d59..48b749a 100644 --- a/charts/huesoporro/values.yaml +++ b/charts/huesoporro/values.yaml @@ -11,7 +11,7 @@ image: # This sets the pull policy for images. pullPolicy: Always # Overrides the image tag whose default is the chart appVersion. - tag: "0.2.3" + tag: "0.2.4" # This is for the secretes for pulling an image from a private repository more information can be found here: https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/ imagePullSecrets: [] diff --git a/pyproject.toml b/pyproject.toml index 9333e7f..5b913bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "huesoporro" -version = "0.2.3" +version = "0.2.4" description = "Misc Twitch bots" readme = "README.md" authors = [ diff --git a/src/huesoporro/bot.py b/src/huesoporro/bot.py index fae650e..e6b3c36 100644 --- a/src/huesoporro/bot.py +++ b/src/huesoporro/bot.py @@ -114,54 +114,6 @@ class Bot(commands.Bot): self.generation_routine.cancel() -class SaveMessagesCog2(commands.Cog): - def __init__(self, bot): - self.bot = bot - self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel)) - self.hello_svc = HelloGeneratorSvc() - self.backoff_svc = BackoffService() - - @commands.Cog.event() - async def event_message(self, message): - content = message.content - if content.startswith("!"): - return - - if not message.author: - return - - await self.store_svc.run(content) - - if message.content in ["hola", "HOLA", "hiii", "ayo"]: - hello_message = self.hello_svc.run(message.author.name) - await message.channel.send(hello_message) - return - - if message.content == "Yes": - await message.channel.send("Indeed") - return - - if message.content.startswith("WHAT"): - await message.channel.send("WHAT Ramon") - return - - laughs_messages = [ - "om", - "KEK", - "LuL", - "LUL", - "OMEGALUL", - "kek", - "keking", - "KEKW", - "OMEGADANCEBUTFAST", - ] - - if message.content in laughs_messages: - await message.channel.send(random.choice(laughs_messages)) # noqa: S311 - return - - class MessageType(StrEnum): COMMAND = "COMMAND" HELLO = "HELLO" diff --git a/uv.lock b/uv.lock index 458a1a4..c6caf63 100644 --- a/uv.lock +++ b/uv.lock @@ -460,7 +460,7 @@ wheels = [ [[package]] name = "huesoporro" -version = "0.2.3" +version = "0.2.4" source = { virtual = "." } dependencies = [ { name = "aiosqlite" }, From 3186afe96defbd040a0d640bdab1455e7f170e8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C4=83t=C4=83lin?= Date: Thu, 19 Dec 2024 18:50:05 +0100 Subject: [PATCH 03/27] fix: fix logout flow which wasn't being triggered, remove useless html code --- charts/huesoporro/Chart.yaml | 2 +- charts/huesoporro/values.yaml | 2 +- pyproject.toml | 2 +- src/huesoporro/api/routes/api.py | 2 +- src/huesoporro/static/js/utils.js | 33 ------------- src/huesoporro/templates/index.html | 12 ++--- src/huesoporro/templates/login.html | 62 +------------------------ src/huesoporro/templates/logout.html | 13 ++++++ src/huesoporro/templates/sentences.html | 15 +++--- src/huesoporro/templates/tts.html | 7 +-- uv.lock | 2 +- 11 files changed, 34 insertions(+), 118 deletions(-) create mode 100644 src/huesoporro/templates/logout.html diff --git a/charts/huesoporro/Chart.yaml b/charts/huesoporro/Chart.yaml index 22ca75f..3d392ad 100644 --- a/charts/huesoporro/Chart.yaml +++ b/charts/huesoporro/Chart.yaml @@ -21,4 +21,4 @@ version: 0.2.4 # incremented each time you make changes to the application. Versions are not expected to # follow Semantic Versioning. They should reflect the version the application is using. # It is recommended to use it with quotes. -appVersion: "0.2.4" +appVersion: "0.2.5" diff --git a/charts/huesoporro/values.yaml b/charts/huesoporro/values.yaml index 48b749a..ea27e5a 100644 --- a/charts/huesoporro/values.yaml +++ b/charts/huesoporro/values.yaml @@ -11,7 +11,7 @@ image: # This sets the pull policy for images. pullPolicy: Always # Overrides the image tag whose default is the chart appVersion. - tag: "0.2.4" + tag: "0.2.5" # This is for the secretes for pulling an image from a private repository more information can be found here: https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/ imagePullSecrets: [] diff --git a/pyproject.toml b/pyproject.toml index 5b913bf..1da3c58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "huesoporro" -version = "0.2.4" +version = "0.2.5" description = "Misc Twitch bots" readme = "README.md" authors = [ diff --git a/src/huesoporro/api/routes/api.py b/src/huesoporro/api/routes/api.py index d427cde..9412598 100644 --- a/src/huesoporro/api/routes/api.py +++ b/src/huesoporro/api/routes/api.py @@ -19,7 +19,7 @@ class ManageBotDTO(BaseModel): "/tts", media_type=MediaType.HTML, ) -async def get_tts_overlay() -> Template: +async def get_tts_overlay(user: User) -> Template: return Template(template_name="tts.html") diff --git a/src/huesoporro/static/js/utils.js b/src/huesoporro/static/js/utils.js index e014ab2..9bb1116 100644 --- a/src/huesoporro/static/js/utils.js +++ b/src/huesoporro/static/js/utils.js @@ -7,36 +7,3 @@ function getWebsocketProtocol() { return "wss://"; } } - -function addLogoutEvent() { - const logoutButton = document.getElementById("logoutButton"); - logoutButton.addEventListener("click", () => { - document.cookie = "twitchLoginData=; expires=Thu, 01 Jan 1970 00:00:00 UTC"; - window.location.href = "/"; - }); - -} - -function setCookie(name, value, days) { - const date = new Date(); - date.setTime(date.getTime() + (days * 24 * 60 * 60 * 1000)); - const expires = `expires=${date.toUTCString()}`; - document.cookie = `${name}=${value};${expires};path=/;SameSite=Strict`; -} - -function getCookie(name) { - const cookieName = `${name}=`; - const decodedCookie = decodeURIComponent(document.cookie); - const cookieArray = decodedCookie.split(';'); - - for (let i = 0; i < cookieArray.length; i++) { - let cookie = cookieArray[i]; - while (cookie.charAt(0) === ' ') { - cookie = cookie.substring(1); - } - if (cookie.indexOf(cookieName) === 0) { - return cookie.substring(cookieName.length, cookie.length); - } - } - return null; -} diff --git a/src/huesoporro/templates/index.html b/src/huesoporro/templates/index.html index bdf08ee..b17419f 100644 --- a/src/huesoporro/templates/index.html +++ b/src/huesoporro/templates/index.html @@ -5,12 +5,10 @@
@@ -178,7 +176,7 @@ }) const data = await response.json() console.log(data); - if (response.ok){ + if (response.ok) { alert("Settings saved successfully") } } @@ -186,8 +184,6 @@ const chatbotManager = new ChatbotManager(); chatbotManager.setEvents(); - - addLogoutEvent() }); diff --git a/src/huesoporro/templates/login.html b/src/huesoporro/templates/login.html index 00e0a38..3a3181a 100644 --- a/src/huesoporro/templates/login.html +++ b/src/huesoporro/templates/login.html @@ -6,7 +6,8 @@
- Login + Login with Twitch @@ -14,66 +15,7 @@
diff --git a/src/huesoporro/templates/logout.html b/src/huesoporro/templates/logout.html new file mode 100644 index 0000000..61f189c --- /dev/null +++ b/src/huesoporro/templates/logout.html @@ -0,0 +1,13 @@ + + diff --git a/src/huesoporro/templates/sentences.html b/src/huesoporro/templates/sentences.html index 7a0b611..e390f79 100644 --- a/src/huesoporro/templates/sentences.html +++ b/src/huesoporro/templates/sentences.html @@ -1,17 +1,14 @@ - {% include 'header.html' %}

Huesoporro🦴🍃

@@ -20,14 +17,18 @@ - + {% for sentence in sentences %} - + {% endfor %}
SentenceSentence Action
{{ sentence.sentence }} + +
diff --git a/src/huesoporro/templates/tts.html b/src/huesoporro/templates/tts.html index a0f3266..5a7843e 100644 --- a/src/huesoporro/templates/tts.html +++ b/src/huesoporro/templates/tts.html @@ -8,10 +8,7 @@
  • TTS
  • Le Funny
  • - - + {% include 'logout.html' %}

    Huesoporro🦴🍃

    @@ -186,7 +183,7 @@ // generate /tts/permalink?access_token= // the access token is available in the twitchLoginData cookie - const cookie = JSON.parse(getCookie("twitchLoginData")) + const cookie = JSON.parse(getCookie("huesoporroAuth")) const permalinkUrl = `${window.location.origin}/tts/permalink?access_token=${cookie.access_token}`; navigator.clipboard.writeText(permalinkUrl); alert('OBS link copied to clipboard ' + permalinkUrl); diff --git a/uv.lock b/uv.lock index c6caf63..d861103 100644 --- a/uv.lock +++ b/uv.lock @@ -460,7 +460,7 @@ wheels = [ [[package]] name = "huesoporro" -version = "0.2.4" +version = "0.2.5" source = { virtual = "." } dependencies = [ { name = "aiosqlite" }, From 50900986fa5cb8b2d0065259ea95364d9bbe9beb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C4=83t=C4=83lin?= Date: Fri, 17 Jan 2025 18:15:58 +0100 Subject: [PATCH 04/27] feat: revamp authentication -- remove twitch's tokens from our own wrapper token --- devenv.lock | 50 +++++++- devenv.nix | 7 + devenv.yaml | 17 +-- migrations/20241219191711_sentences.py | 35 +++++ .../20250112153541_user_external_auth.py | 53 ++++++++ .../20250113142241_external_auth_json.py | 35 +++++ pyproject.toml | 2 + src/huesoporro/actions/authenticate.py | 27 ++++ src/huesoporro/actions/get_user_by_jwt.py | 38 ++++++ src/huesoporro/actions/refresh.py | 27 ++++ src/huesoporro/api/dependencies.py | 39 ++++-- src/huesoporro/api/main.py | 15 ++- src/huesoporro/api/routes/api.py | 17 ++- src/huesoporro/api/routes/auth.py | 8 +- src/huesoporro/bot.py | 2 +- src/huesoporro/infra/authenticator.py | 16 ++- src/huesoporro/infra/db.py | 43 ++----- src/huesoporro/infra/repos.py | 114 +++++++++++++++++ src/huesoporro/models.py | 35 +++-- src/huesoporro/svc/authenticate.py | 26 ---- src/huesoporro/svc/get_sentences_svc.py | 11 ++ src/huesoporro/svc/refresh.py | 27 ---- src/huesoporro/templates/header.html | 11 +- src/huesoporro/templates/index.html | 10 +- .../templates/le_funny_dropdown.html | 10 ++ src/huesoporro/templates/login.html | 4 +- src/huesoporro/templates/logout.html | 2 +- src/huesoporro/templates/sentences.html | 121 ++++++++++++++++-- tests/conftest.py | 11 +- tests/test_repos.py | 53 ++++++++ uv.lock | 25 ++++ 31 files changed, 736 insertions(+), 155 deletions(-) create mode 100644 migrations/20241219191711_sentences.py create mode 100644 migrations/20250112153541_user_external_auth.py create mode 100644 migrations/20250113142241_external_auth_json.py create mode 100644 src/huesoporro/actions/authenticate.py create mode 100644 src/huesoporro/actions/get_user_by_jwt.py create mode 100644 src/huesoporro/actions/refresh.py create mode 100644 src/huesoporro/infra/repos.py delete mode 100644 src/huesoporro/svc/authenticate.py create mode 100644 src/huesoporro/svc/get_sentences_svc.py delete mode 100644 src/huesoporro/svc/refresh.py create mode 100644 src/huesoporro/templates/le_funny_dropdown.html create mode 100644 tests/test_repos.py diff --git a/devenv.lock b/devenv.lock index 8ab1706..97f804a 100644 --- a/devenv.lock +++ b/devenv.lock @@ -3,10 +3,10 @@ "devenv": { "locked": { "dir": "src/modules", - "lastModified": 1733788855, + "lastModified": 1735530587, "owner": "cachix", "repo": "devenv", - "rev": "d59fee8696cd48f69cf79f65992269df9891ba86", + "rev": "69645885c1052cc1ca398ac30ba7dfc63386c0e3", "type": "github" }, "original": { @@ -31,6 +31,21 @@ "type": "github" } }, + "flake-compat_2": { + "flake": false, + "locked": { + "lastModified": 1733328505, + "owner": "edolstra", + "repo": "flake-compat", + "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, "gitignore": { "inputs": { "nixpkgs": [ @@ -66,12 +81,32 @@ "type": "github" } }, + "nixpkgs-python": { + "inputs": { + "flake-compat": "flake-compat", + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1733319315, + "owner": "cachix", + "repo": "nixpkgs-python", + "rev": "01263eeb28c09f143d59cd6b0b7c4cc8478efd48", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "nixpkgs-python", + "type": "github" + } + }, "nixpkgs-stable": { "locked": { - "lastModified": 1733730953, + "lastModified": 1735286948, "owner": "NixOS", "repo": "nixpkgs", - "rev": "7109b680d161993918b0a126f38bc39763e5a709", + "rev": "31ac92f9628682b294026f0860e14587a09ffb4b", "type": "github" }, "original": { @@ -83,7 +118,7 @@ }, "pre-commit-hooks": { "inputs": { - "flake-compat": "flake-compat", + "flake-compat": "flake-compat_2", "gitignore": "gitignore", "nixpkgs": [ "nixpkgs" @@ -91,10 +126,10 @@ "nixpkgs-stable": "nixpkgs-stable" }, "locked": { - "lastModified": 1733665616, + "lastModified": 1734797603, "owner": "cachix", "repo": "pre-commit-hooks.nix", - "rev": "d8c02f0ffef0ef39f6063731fc539d8c71eb463a", + "rev": "f0f0dc4920a903c3e08f5bdb9246bb572fcae498", "type": "github" }, "original": { @@ -107,6 +142,7 @@ "inputs": { "devenv": "devenv", "nixpkgs": "nixpkgs", + "nixpkgs-python": "nixpkgs-python", "pre-commit-hooks": "pre-commit-hooks" } } diff --git a/devenv.nix b/devenv.nix index 771f4d7..eea6c6c 100644 --- a/devenv.nix +++ b/devenv.nix @@ -5,8 +5,15 @@ packages = [ pkgs.git ]; + certificates = [ + "id.twitch.tv" + "twitch.tv" + "discord.com" + ]; + languages.python.enable = true; languages.python.uv.enable = true; + languages.python.version = "3.12.8"; scripts.hello.exec = '' echo hello from $GREET diff --git a/devenv.yaml b/devenv.yaml index 116a2ad..184b866 100644 --- a/devenv.yaml +++ b/devenv.yaml @@ -1,15 +1,8 @@ -# yaml-language-server: $schema=https://devenv.sh/devenv.schema.json inputs: nixpkgs: url: github:cachix/devenv-nixpkgs/rolling - -# If you're using non-OSS software, you can set allowUnfree to true. -# allowUnfree: true - -# If you're willing to use a package that's vulnerable -# permittedInsecurePackages: -# - "openssl-1.1.1w" - -# If you have more than one devenv you can merge them -#imports: -# - ./backend + nixpkgs-python: + url: github:cachix/nixpkgs-python + inputs: + nixpkgs: + follows: nixpkgs diff --git a/migrations/20241219191711_sentences.py b/migrations/20241219191711_sentences.py new file mode 100644 index 0000000..06830e0 --- /dev/null +++ b/migrations/20241219191711_sentences.py @@ -0,0 +1,35 @@ +""" +This module contains a Caribou migration. + +Migration Name: sentences +Migration Version: 20241219191711 +""" + + +def upgrade(connection): + # update table `sentences` to have a user_id row + # which references users.id + # and a channel VARCHAR(255) row + + sql = """ + DROP TABLE IF EXISTS sentences; + """ + connection.execute(sql) + connection.commit() + sql = """ + CREATE TABLE sentences( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + sentence VARCHAR(255) NOT NULL UNIQUE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + user_id VARCHAR(255) NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) + ); + """ + connection.execute(sql) + connection.commit() + + +def downgrade(connection): + # add your downgrade step here + pass diff --git a/migrations/20250112153541_user_external_auth.py b/migrations/20250112153541_user_external_auth.py new file mode 100644 index 0000000..d873a70 --- /dev/null +++ b/migrations/20250112153541_user_external_auth.py @@ -0,0 +1,53 @@ +""" +This module contains a Caribou migration. + +Migration Name: user_external_auth +Migration Version: 20250112153541 +""" + + +def upgrade(connection): + """ + - delete access_token, refresh_token, and expires_at from users + - add external_auth table which will store the external auths: + - type: twitch or discord + - credentials: JSON + """ + + sql = """ + ALTER TABLE users DROP COLUMN access_token; + """ + connection.execute(sql) + sql = """ + ALTER TABLE users DROP COLUMN refresh_token; + """ + connection.execute(sql) + sql = """ + ALTER TABLE users DROP COLUMN expires_at; + """ + connection.execute(sql) + + sql = """ + CREATE TABLE external_auth( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + type VARCHAR(255) NOT NULL, + credentials JSON NOT NULL + ); + """ + connection.execute(sql) + + sql = """ + CREATE TABLE user_external_auth( + user_id VARCHAR(255) NOT NULL, + external_auth_id INTEGER NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id), + FOREIGN KEY (external_auth_id) REFERENCES external_auth(id) + ); + """ + connection.execute(sql) + connection.commit() + + +def downgrade(connection): + # add your downgrade step here + pass diff --git a/migrations/20250113142241_external_auth_json.py b/migrations/20250113142241_external_auth_json.py new file mode 100644 index 0000000..e92e847 --- /dev/null +++ b/migrations/20250113142241_external_auth_json.py @@ -0,0 +1,35 @@ +""" +This module contains a Caribou migration. + +Migration Name: external_auth_json +Migration Version: 20250113142241 +""" + + +def upgrade(connection): + """remove tables: + - external_auth + - user_external_auth + add column to users table: + - external_auth JSON + """ + sql = """ + DROP TABLE IF EXISTS external_auth; + """ + connection.execute(sql) + + sql = """ + DROP TABLE IF EXISTS user_external_auth; + """ + connection.execute(sql) + + sql = """ + ALTER TABLE users ADD COLUMN external_auth JSON; + """ + connection.execute(sql) + connection.commit() + + +def downgrade(connection): + # add your downgrade step here + pass diff --git a/pyproject.toml b/pyproject.toml index 1da3c58..d2f948a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,8 @@ dependencies = [ "pyjwt>=2.10.1", "twitchio>=2.10.0", "redis>=5.2.1", + "pytz>=2024.2", + "discord-py>=2.4.0", ] [tool.uv] diff --git a/src/huesoporro/actions/authenticate.py b/src/huesoporro/actions/authenticate.py new file mode 100644 index 0000000..9bf7b6a --- /dev/null +++ b/src/huesoporro/actions/authenticate.py @@ -0,0 +1,27 @@ +from pydantic import BaseModel + +from src.huesoporro.infra.authenticator import TwitchAuthenticator +from src.huesoporro.infra.repos import UserRepo +from src.huesoporro.models import User +from src.huesoporro.settings import Settings + + +class AuthenticateAction(BaseModel): + user_repo: UserRepo + authenticator: TwitchAuthenticator + s: Settings + + async def run( + self, + auth_code: str, + ): + tokens = await self.authenticator.get_token(auth_code) + username = tokens.userinfo["preferred_username"] + if username not in self.s.allowed_users: + raise ValueError(f"User {username} is not allowed to use this bot") + user = User(user=username, external_auth={"twitch": tokens.model_dump()}) + if await self.user_repo.get_by_user(user.user): + await self.user_repo.update(user) + else: + await self.user_repo.create(user) + return user.encode() diff --git a/src/huesoporro/actions/get_user_by_jwt.py b/src/huesoporro/actions/get_user_by_jwt.py new file mode 100644 index 0000000..b5311ab --- /dev/null +++ b/src/huesoporro/actions/get_user_by_jwt.py @@ -0,0 +1,38 @@ +from loguru import logger +from pydantic import BaseModel + +from src.huesoporro.infra.authenticator import TwitchAuthenticator +from src.huesoporro.infra.repos import UserRepo +from src.huesoporro.models import User +from src.huesoporro.settings import Settings + + +class GetUserByJWTAction(BaseModel): + user_repo: UserRepo + authenticator: TwitchAuthenticator + s: Settings + + async def run( + self, + jwt_token: str, + ) -> User: + user_data = User.decode(jwt_token) + username = user_data["user"] + user = await self.user_repo.get_by_user(username) + if not user: + raise ValueError(f"User {username} not found") + is_valid = await self.authenticator.token_is_valid( + user.external_auth["twitch"]["access_token"] + ) + logger.info(f"Token {user} is valid: {is_valid}") + if not is_valid: + logger.info(f"Refreshing token for user {user}") + new_tokens = await self.authenticator.refresh_token( + user.external_auth["twitch"]["refresh_token"] + ) + user.external_auth["twitch"]["access_token"] = new_tokens["access_token"] # type: ignore[index] + user.external_auth["twitch"]["refresh_token"] = new_tokens["refresh_token"] # type: ignore[index] + await self.user_repo.update(user) + return user + + return user diff --git a/src/huesoporro/actions/refresh.py b/src/huesoporro/actions/refresh.py new file mode 100644 index 0000000..4ba8543 --- /dev/null +++ b/src/huesoporro/actions/refresh.py @@ -0,0 +1,27 @@ +from pydantic import BaseModel + +from src.huesoporro.infra.authenticator import TwitchAuthenticator +from src.huesoporro.infra.repos import UserRepo +from src.huesoporro.models import User +from src.huesoporro.settings import Settings + + +class RefreshAction(BaseModel): + user_repo: UserRepo + authenticator: TwitchAuthenticator + s: Settings + + async def run(self, user: User): + is_valid = await self.authenticator.token_is_valid( + user.external_auth["twitch"]["access_token"] + ) + + if not is_valid: + new_tokens = await self.authenticator.refresh_token( + user.external_auth["twitch"]["refresh_token"] + ) + user.external_auth["twitch"]["access_token"] = new_tokens["access_token"] # type: ignore[index] + user.external_auth["twitch"]["refresh_token"] = new_tokens["refresh_token"] # type: ignore[index] + await self.user_repo.update(user) + return user.encode() + return None diff --git a/src/huesoporro/api/dependencies.py b/src/huesoporro/api/dependencies.py index 5db8371..d1d9812 100644 --- a/src/huesoporro/api/dependencies.py +++ b/src/huesoporro/api/dependencies.py @@ -1,12 +1,15 @@ from litestar import Request from litestar.exceptions import HTTPException +from src.huesoporro.actions.authenticate import AuthenticateAction +from src.huesoporro.actions.get_user_by_jwt import GetUserByJWTAction from src.huesoporro.infra.authenticator import TwitchAuthenticator from src.huesoporro.infra.db import Database +from src.huesoporro.infra.repos import UserRepo from src.huesoporro.models import User from src.huesoporro.settings import Settings -from src.huesoporro.svc.authenticate import CodeAuthenticatorSvc from src.huesoporro.svc.get_chatbot_settings import ChatbotSettingsGetterSvc +from src.huesoporro.svc.get_sentences_svc import SentencesGetterSvc from src.huesoporro.svc.store_settings import ChatbotSettingsStorerSvc @@ -22,27 +25,43 @@ def get_db(s: Settings): return Database(s=s) -async def authenticate(request: Request) -> User: +async def get_get_user_by_jwt_action( + user_repo: UserRepo, authenticator: TwitchAuthenticator, s: Settings +): + return GetUserByJWTAction(user_repo=user_repo, authenticator=authenticator, s=s) + + +async def authenticate( + request: Request, get_user_by_jwt_action: GetUserByJWTAction +) -> User: token = request.query_params.get("huesoporro_token") if token: - return User.decode(token) + return await get_user_by_jwt_action.run(token) cookies = request.cookies.get("huesoporroAuth") if cookies: - return User.decode(cookies) + return await get_user_by_jwt_action.run(cookies) raise HTTPException(status_code=401, detail="Unauthorized") -async def get_code_authenticator_svc( - a: TwitchAuthenticator, db: Database -) -> CodeAuthenticatorSvc: - return CodeAuthenticatorSvc(authenticator=a, db=db) - - async def get_chatbot_settings_svc(db: Database): return ChatbotSettingsGetterSvc(db=db) async def store_chatbot_settings_svc(db: Database): return ChatbotSettingsStorerSvc(db=db) + + +async def get_sentences_svc(db: Database): + return SentencesGetterSvc(db=db) + + +async def get_user_repo(s: Settings): + return UserRepo(s=s) + + +async def get_authenticate_action( + user_repo: UserRepo, authenticator: TwitchAuthenticator, s: Settings +): + return AuthenticateAction(user_repo=user_repo, authenticator=authenticator, s=s) diff --git a/src/huesoporro/api/main.py b/src/huesoporro/api/main.py index dd0489a..8df704e 100644 --- a/src/huesoporro/api/main.py +++ b/src/huesoporro/api/main.py @@ -8,11 +8,14 @@ from litestar.template import TemplateConfig from src.huesoporro.api.dependencies import ( authenticate, + get_authenticate_action, get_authenticator, get_chatbot_settings_svc, - get_code_authenticator_svc, get_db, + get_get_user_by_jwt_action, + get_sentences_svc, get_settings, + get_user_repo, store_chatbot_settings_svc, ) from src.huesoporro.api.errors import ( @@ -24,10 +27,12 @@ from src.huesoporro.api.routes.api import ( get_bot_settings, get_bot_status, get_index, + get_sentences, get_tts_overlay, get_tts_permalink, manage_bot, save_bot_settings, + save_new_sentence, ) from src.huesoporro.api.routes.auth import get_code, login from src.huesoporro.bot import BotsManager @@ -52,6 +57,8 @@ def create_app(): get_bot_status, save_bot_settings, get_bot_settings, + get_sentences, + save_new_sentence, ], static_files_config=( StaticFilesConfig( @@ -77,10 +84,14 @@ def create_app(): "a": Provide(get_authenticator, use_cache=True), "user": Provide(authenticate), "db": Provide(get_db, use_cache=True), - "code_authenticator_svc": Provide(get_code_authenticator_svc), "bm": Provide(BotsManager, use_cache=True), "gbs": Provide(get_chatbot_settings_svc), "sbs": Provide(store_chatbot_settings_svc), + "sgs": Provide(get_sentences_svc), + "authenticator": Provide(get_authenticator), + "authenticate_action": Provide(get_authenticate_action), + "user_repo": Provide(get_user_repo), + "get_user_by_jwt_action": Provide(get_get_user_by_jwt_action), }, ) diff --git a/src/huesoporro/api/routes/api.py b/src/huesoporro/api/routes/api.py index 9412598..55e6557 100644 --- a/src/huesoporro/api/routes/api.py +++ b/src/huesoporro/api/routes/api.py @@ -1,12 +1,13 @@ from typing import Literal -from litestar import MediaType, Response, get, put +from litestar import MediaType, Response, get, post, put from litestar.response import Template from pydantic import BaseModel from src.huesoporro.bot import BotsManager from src.huesoporro.models import ChatbotSettings, User from src.huesoporro.svc.get_chatbot_settings import ChatbotSettingsGetterSvc +from src.huesoporro.svc.get_sentences_svc import SentencesGetterSvc from src.huesoporro.svc.store_settings import ChatbotSettingsStorerSvc @@ -98,3 +99,17 @@ async def save_bot_settings( ) -> dict: await sbs.run(user=user, bot_settings=data) return {"status": "ok"} + + +@get("/sentences") +async def get_sentences(user: User, sgs: SentencesGetterSvc) -> Template: + sentences = await sgs.run(user=user) + return Template( + template_name="sentences.html", + context={"sentences": [sentence.model_dump() for sentence in sentences]}, + ) + + +@post("/api/v1/sentences") +async def save_new_sentence(user: User, data: dict) -> dict: + return {"id": 54, "sentence": data["sentence"]} diff --git a/src/huesoporro/api/routes/auth.py b/src/huesoporro/api/routes/auth.py index 2f1c287..df95441 100644 --- a/src/huesoporro/api/routes/auth.py +++ b/src/huesoporro/api/routes/auth.py @@ -3,14 +3,14 @@ import secrets from litestar import MediaType, get from litestar.response import Redirect, Template +from src.huesoporro.actions.authenticate import AuthenticateAction from src.huesoporro.settings import Settings -from src.huesoporro.svc.authenticate import CodeAuthenticatorSvc @get(path="/o/code") -async def get_code(code: str, code_authenticator_svc: CodeAuthenticatorSvc) -> Redirect: - user = await code_authenticator_svc.run(code) - return Redirect("/", cookies={"huesoporroAuth": user.encode()}) +async def get_code(code: str, authenticate_action: AuthenticateAction) -> Redirect: + token = await authenticate_action.run(code) + return Redirect("/", cookies={"huesoporroAuth": token}) @get( diff --git a/src/huesoporro/bot.py b/src/huesoporro/bot.py index e6b3c36..025a2dd 100644 --- a/src/huesoporro/bot.py +++ b/src/huesoporro/bot.py @@ -23,7 +23,7 @@ from src.huesoporro.svc.store_quote import QuoteStorerSvc class Bot(commands.Bot): def __init__(self, user: User, chatbot_settings: ChatbotSettings, channel: str): super().__init__( - token=user.twitch_auth.access_token, prefix="!", initial_channels=[channel] + token=user.twitch_access_token, prefix="!", initial_channels=[channel] ) self.channel = channel self.user = user diff --git a/src/huesoporro/infra/authenticator.py b/src/huesoporro/infra/authenticator.py index d46922c..9c372d8 100644 --- a/src/huesoporro/infra/authenticator.py +++ b/src/huesoporro/infra/authenticator.py @@ -30,7 +30,15 @@ class TwitchAuthenticator(BaseModel): return await self.refresh_token(response.json()["refresh_token"]) response.raise_for_status() - return TwitchAuth(**response.json()) + profile = await self.get_userinfo(response.json()["access_token"]) + return TwitchAuth(**response.json(), userinfo=profile) + + async def get_userinfo(self, access_token): + response = await self.client.get( + "/oauth2/userinfo", headers={"Authorization": f"Bearer {access_token}"} + ) + response.raise_for_status() + return response.json() async def refresh_token(self, refresh_token: str) -> TwitchAuth: response = await self.client.post( @@ -60,3 +68,9 @@ class TwitchAuthenticator(BaseModel): raise HTTPException(status_code=403, detail="Forbidden") return user + + async def token_is_valid(self, access_token: str) -> bool: + response = await self.client.get( + "/oauth2/validate", headers={"Authorization": f"OAuth {access_token}"} + ) + return response.status_code == 200 # noqa: PLR2004 diff --git a/src/huesoporro/infra/db.py b/src/huesoporro/infra/db.py index b2f6f8f..86d8260 100644 --- a/src/huesoporro/infra/db.py +++ b/src/huesoporro/infra/db.py @@ -5,7 +5,7 @@ import aiosqlite from loguru import logger from pydantic import BaseModel, Field -from src.huesoporro.models import ChatbotSettings, User +from src.huesoporro.models import ChatbotSettings, Sentence, User from src.huesoporro.settings import Settings @@ -24,36 +24,6 @@ class Database(BaseModel): def get_now() -> float: return datetime.datetime.now(datetime.UTC).timestamp() - async def save_user(self, user: User, auto_commit=True): - async with self.get_client(auto_commit=auto_commit) as db: - async with db.execute( - "SELECT * FROM users WHERE user = ?", (user.user,) - ) as cursor: - result = await cursor.fetchone() - if result: - await db.execute( - "UPDATE users SET access_token = ?, refresh_token = ?, expires_at = ?, last_updated_at = ? WHERE user = ?", - ( - user.twitch_auth.access_token, - user.twitch_auth.refresh_token, - user.expires_at, - self.get_now(), - user.user, - ), - ) - return - - await db.execute( - "INSERT INTO users (user, access_token, refresh_token, expires_at, last_updated_at) VALUES (?,?,?,?,?)", - ( - user.user, - user.twitch_auth.access_token, - user.twitch_auth.refresh_token, - user.expires_at, - self.get_now(), - ), - ) - async def save_quote(self, channel: str, quote: str, author: str, auto_commit=True): async with self.get_client(auto_commit=auto_commit) as db: await db.execute( @@ -133,3 +103,14 @@ class Database(BaseModel): ) as cursor, ): return await cursor.fetchone() + + async def get_sentences(self, user: User) -> list[Sentence]: + async with self.get_client() as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT * FROM sentences WHERE user_id = ?", (user.user,) + ) as cursor: + result = await cursor.fetchall() + if not result: + return [] + return [Sentence(user=user, **dict(value)) for value in result] diff --git a/src/huesoporro/infra/repos.py b/src/huesoporro/infra/repos.py new file mode 100644 index 0000000..7e4e939 --- /dev/null +++ b/src/huesoporro/infra/repos.py @@ -0,0 +1,114 @@ +import json +from abc import ABC, abstractmethod +from contextlib import asynccontextmanager +from typing import Generic, TypeVar + +import aiosqlite +from pydantic import BaseModel, Field + +from src.huesoporro.models import User +from src.huesoporro.settings import Settings + +T = TypeVar("T", bound=BaseModel) + + +class IRepo(BaseModel, ABC, Generic[T]): + s: Settings = Field(default_factory=Settings.get) + + @asynccontextmanager + async def get_client(self, auto_commit=True): + async with aiosqlite.connect(self.s.db_filepath) as db: + db.row_factory = aiosqlite.Row + yield db + if auto_commit: + await db.commit() + + @abstractmethod + async def create(self, obj: T, auto_commit=True) -> T: + pass + + @abstractmethod + async def update(self, obj: T, auto_commit=True) -> T: + pass + + @abstractmethod + async def delete(self, obj: T, auto_commit=True): + pass + + @abstractmethod + async def get_by_id(self, obj_id: int | str, auto_commit=True) -> T | None: + pass + + @abstractmethod + async def list( + self, obj: T, offset: int = 0, limit: int = 10, auto_commit=True + ) -> list[T]: + pass + + +class UserRepo(IRepo[User]): + async def get_by_id(self, obj_id: int | str, auto_commit=True) -> User | None: + raise NotImplementedError("Not implemented since it's not needed") + + async def create(self, obj: User, auto_commit=True) -> User: + async with self.get_client(auto_commit=auto_commit) as db: + await db.execute( + "INSERT INTO users (user, external_auth) VALUES (?, ?)", + (obj.user, json.dumps(obj.external_auth)), + ) + return obj + + async def update(self, obj: User, auto_commit=True) -> User: + if not await self.get_by_user(obj.user): + raise ValueError(f"User {obj.user} does not exist") + + async with ( + self.get_client(auto_commit=auto_commit) as db, + db.execute( + """ + UPDATE users + SET external_auth = ? + WHERE user = ? + RETURNING * + """, + (json.dumps(obj.external_auth), obj.user), + ) as cursor, + ): + data = await cursor.fetchone() + return User( + user=data["user"], external_auth=json.loads(data["external_auth"]) + ) + + async def delete(self, obj: User, auto_commit=True): + async with self.get_client(auto_commit=auto_commit) as db: + await db.execute( + """ + DELETE FROM users WHERE user = ? + """, + (obj.user,), + ) + + async def get_by_user(self, user: str, auto_commit=True) -> User | None: + async with ( + self.get_client(auto_commit=auto_commit) as db, + db.execute( + """ + SELECT * FROM users WHERE user = ? + """, + (user,), + ) as cursor, + ): + data = await cursor.fetchone() + if not data: + return None + return User( + user=data["user"], external_auth=json.loads(data["external_auth"]) + ) + + async def list( + self, obj: User, offset: int = 0, limit: int = 10, auto_commit=True + ) -> list[User]: + raise NotImplementedError("Not implemented since it's not needed") + + async def count(self, obj: User, auto_commit=True): + raise NotImplementedError("Not implemented since it's not needed") diff --git a/src/huesoporro/models.py b/src/huesoporro/models.py index b4df9e9..f791cf6 100644 --- a/src/huesoporro/models.py +++ b/src/huesoporro/models.py @@ -1,4 +1,4 @@ -from typing import Self +from typing import Literal import jwt from pydantic import BaseModel, Field, field_validator @@ -9,28 +9,39 @@ from src.huesoporro.settings import Settings class TwitchAuth(BaseModel): access_token: str refresh_token: str + userinfo: dict + + +class ExternalAuth(BaseModel): + credentials: dict + type: Literal["twitch"] = "twitch" class User(BaseModel): user: str - expires_at: float - twitch_auth: TwitchAuth + external_auth: dict[Literal["twitch", "discord"], dict] - def encode(self, settings: Settings | None = None) -> str: + def encode( + self, settings: Settings | None = None, exclude_fields: set[str] | None = None + ) -> str: s = settings or Settings.get() + exclude_fields = exclude_fields or {"external_auth"} return jwt.encode( - self.model_dump(), + self.model_dump(exclude=exclude_fields), key=s.jwt_secret.get_secret_value(), algorithm="HS256", ) @classmethod - def decode(cls, token: str, settings: Settings | None = None) -> Self: + def decode(cls, token: str, settings: Settings | None = None) -> dict: s = settings or Settings.get() - decoded = jwt.decode( + return jwt.decode( token, key=s.jwt_secret.get_secret_value(), algorithms=["HS256"] ) - return cls(**decoded) + + @property + def twitch_access_token(self): + return self.external_auth["twitch"]["access_token"] class ChatbotSettings(BaseModel): @@ -50,3 +61,11 @@ class ChatbotSettings(BaseModel): if isinstance(v, str): return v.split(",") return v + + +class Sentence(BaseModel): + id: int + sentence: str + created_at: float + last_updated_at: float + user: User diff --git a/src/huesoporro/svc/authenticate.py b/src/huesoporro/svc/authenticate.py deleted file mode 100644 index 346a407..0000000 --- a/src/huesoporro/svc/authenticate.py +++ /dev/null @@ -1,26 +0,0 @@ -import datetime - -from pydantic import BaseModel - -from src.huesoporro.infra.authenticator import TwitchAuthenticator -from src.huesoporro.infra.db import Database -from src.huesoporro.models import User - - -class CodeAuthenticatorSvc(BaseModel): - db: Database - authenticator: TwitchAuthenticator - - @staticmethod - def get_four_hours_from_now() -> float: - now = datetime.datetime.now(datetime.UTC) - four_hours_later = now + datetime.timedelta(hours=4) - return four_hours_later.timestamp() - - async def run(self, code: str) -> User: - auth = await self.authenticator.get_token(code) - username = await self.authenticator.validate_token(auth.access_token) - expires_at = self.get_four_hours_from_now() - user = User(user=username, expires_at=expires_at, twitch_auth=auth) - await self.db.save_user(user) - return user diff --git a/src/huesoporro/svc/get_sentences_svc.py b/src/huesoporro/svc/get_sentences_svc.py new file mode 100644 index 0000000..1cb0a1c --- /dev/null +++ b/src/huesoporro/svc/get_sentences_svc.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + +from src.huesoporro.infra.db import Database +from src.huesoporro.models import Sentence, User + + +class SentencesGetterSvc(BaseModel): + db: Database + + async def run(self, user: User) -> list[Sentence]: + return await self.db.get_sentences(user=user) diff --git a/src/huesoporro/svc/refresh.py b/src/huesoporro/svc/refresh.py deleted file mode 100644 index 9209743..0000000 --- a/src/huesoporro/svc/refresh.py +++ /dev/null @@ -1,27 +0,0 @@ -import datetime - -from pydantic import BaseModel - -from src.huesoporro.infra.authenticator import TwitchAuthenticator -from src.huesoporro.infra.db import Database -from src.huesoporro.models import User - - -class RefreshTokenAuthenticator(BaseModel): - db: Database - authenticator: TwitchAuthenticator - - @staticmethod - def get_four_hours_from_now() -> float: - now = datetime.datetime.now(datetime.UTC) - four_hours_later = now + datetime.timedelta(hours=4) - return four_hours_later.timestamp() - - async def run(self, refresh_token: str) -> User: - auth = await self.authenticator.refresh_token(refresh_token) - username = await self.authenticator.validate_token(auth.access_token) - expires_at = self.get_four_hours_from_now() - - user = User(user=username, expires_at=expires_at, twitch_auth=auth) - await self.db.save_user(user) - return user diff --git a/src/huesoporro/templates/header.html b/src/huesoporro/templates/header.html index 7c1cfd9..e6fe60a 100644 --- a/src/huesoporro/templates/header.html +++ b/src/huesoporro/templates/header.html @@ -2,15 +2,24 @@ - + + + + + Huesoporro diff --git a/src/huesoporro/templates/index.html b/src/huesoporro/templates/index.html index b17419f..e7ffa8c 100644 --- a/src/huesoporro/templates/index.html +++ b/src/huesoporro/templates/index.html @@ -2,16 +2,16 @@
    -
    -
    +
    @@ -102,7 +102,7 @@ .catch((error) => { console.error('Failed to retrieve chatbot status', error); }); - }, 2000); + }, 5000); } async startBot() { @@ -184,6 +184,8 @@ const chatbotManager = new ChatbotManager(); chatbotManager.setEvents(); + + }); diff --git a/src/huesoporro/templates/le_funny_dropdown.html b/src/huesoporro/templates/le_funny_dropdown.html new file mode 100644 index 0000000..a7f63a6 --- /dev/null +++ b/src/huesoporro/templates/le_funny_dropdown.html @@ -0,0 +1,10 @@ +
  • + +
  • diff --git a/src/huesoporro/templates/login.html b/src/huesoporro/templates/login.html index 3a3181a..b40ac0c 100644 --- a/src/huesoporro/templates/login.html +++ b/src/huesoporro/templates/login.html @@ -1,9 +1,9 @@ {% include 'header.html' %} -
    +

    Huesoporro🦴🚬

    -
    +
    -
  • Logout
  • +
  • Logout
  • -
    +
    - + + + + + +
    - - + + + + {% for sentence in sentences %} - + + {% endfor %} +
    SentenceActionSentenceLast modifiedAction
    {{ sentence.sentence }}{{ sentence.last_updated_at }} - +
    + + +
    diff --git a/tests/conftest.py b/tests/conftest.py index 19218d9..c8c6c23 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from caribou.migrate import Database as CaribouDatabase from caribou.migrate import load_migrations from src.huesoporro.infra.db import Database -from src.huesoporro.models import ChatbotSettings, TwitchAuth, User +from src.huesoporro.models import ChatbotSettings, User from src.huesoporro.settings import Settings from src.huesoporro.svc.backoff_service import BackoffService from src.huesoporro.svc.is_mod import IsModSvc @@ -15,11 +15,10 @@ from src.huesoporro.svc.is_mod import IsModSvc def user() -> User: return User( user="huesoporro", - expires_at=1671234567.0, - twitch_auth=TwitchAuth( - access_token="test_access_token", # noqa: S106 - refresh_token="test_refresh_token", # noqa: S106 - ), + external_auth={ + "twitch": {"token": "twitch_token"}, + "discord": {"token": "discord_token"}, + }, ) diff --git a/tests/test_repos.py b/tests/test_repos.py new file mode 100644 index 0000000..423e195 --- /dev/null +++ b/tests/test_repos.py @@ -0,0 +1,53 @@ +import json + +import pytest + +from src.huesoporro.infra.repos import UserRepo +from src.huesoporro.models import User + + +@pytest.fixture +async def user_repo(s, db, user: User): + async with db.get_client() as client: + await client.execute( + "INSERT INTO users (user, external_auth) VALUES (?, ?)", + (user.user, json.dumps(user.external_auth)), + ) + + return UserRepo(s=s) + + +async def test_get_user(user_repo: UserRepo, user: User): + db_user = await user_repo.get_by_user(user.user) + assert db_user == user + + +async def test_get_user_returns_none(user_repo: UserRepo): + assert await user_repo.get_by_user("unknown_user") is None + + +async def test_create_user(user_repo: UserRepo): + new_user = User( + user="new_user", external_auth={"twitch": {"token": "twitch_token"}} + ) + assert await user_repo.create(new_user) == new_user + + +async def test_update_users_tokens(user_repo: UserRepo, user: User): + new_tokens = {"twitch": {"token": "new_tokens"}} + user.external_auth = new_tokens # type: ignore[assignment] + assert await user_repo.update(user) == user + + +async def test_update_non_existing_user_raises_value_error(user_repo: UserRepo): + with pytest.raises(ValueError, match="User unknown_user does not exist"): + await user_repo.update( + User( + user="unknown_user", external_auth={"twitch": {"token": "twitch_token"}} + ) + ) + + +async def test_delete_user(user_repo: UserRepo, user: User): + assert await user_repo.delete(user) is None + assert await user_repo.get_by_user(user.user) is None diff --git a/uv.lock b/uv.lock index d861103..ccca98c 100644 --- a/uv.lock +++ b/uv.lock @@ -283,6 +283,18 @@ toml = [ { name = "tomli", marker = "python_full_version <= '3.11'" }, ] +[[package]] +name = "discord-py" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/af/80cab4015722d3bee175509b7249a11d5adf77b5ff4c27f268558079d149/discord_py-2.4.0.tar.gz", hash = "sha256:d07cb2a223a185873a1d0ee78b9faa9597e45b3f6186df21a95cec1e9bcdc9a5", size = 1027707 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/10/3c44e9331a5ec3bae8b2919d51f611a5b94e179563b1b89eb6423a8f43eb/discord.py-2.4.0-py3-none-any.whl", hash = "sha256:b8af6711c70f7e62160bfbecb55be699b5cb69d007426759ab8ab06b1bd77d1d", size = 1125988 }, +] + [[package]] name = "editorconfig" version = "0.12.4" @@ -465,6 +477,7 @@ source = { virtual = "." } dependencies = [ { name = "aiosqlite" }, { name = "caribou" }, + { name = "discord-py" }, { name = "gtts" }, { name = "httpx" }, { name = "litestar", extra = ["standard"] }, @@ -474,6 +487,7 @@ dependencies = [ { name = "pydantic" }, { name = "pydantic-settings" }, { name = "pyjwt" }, + { name = "pytz" }, { name = "redis" }, { name = "twitchio" }, ] @@ -491,6 +505,7 @@ dev = [ requires-dist = [ { name = "aiosqlite", specifier = ">=0.20.0" }, { name = "caribou", specifier = ">=0.4.1" }, + { name = "discord-py", specifier = ">=2.4.0" }, { name = "gtts", specifier = ">=2.5.4" }, { name = "httpx", specifier = ">=0.28.0" }, { name = "litestar", extras = ["standard"], specifier = ">=2.13.0" }, @@ -500,6 +515,7 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.9.2" }, { name = "pydantic-settings", specifier = ">=2.6.0" }, { name = "pyjwt", specifier = ">=2.10.1" }, + { name = "pytz", specifier = ">=2024.2" }, { name = "redis", specifier = ">=5.2.1" }, { name = "twitchio", specifier = ">=2.10.0" }, ] @@ -1101,6 +1117,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 }, ] +[[package]] +name = "pytz" +version = "2024.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3a/31/3c70bf7603cc2dca0f19bdc53b4537a797747a58875b552c8c413d963a3f/pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a", size = 319692 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/c3/005fcca25ce078d2cc29fd559379817424e94885510568bc1bc53d7d5846/pytz-2024.2-py2.py3-none-any.whl", hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725", size = 508002 }, +] + [[package]] name = "pyyaml" version = "6.0.2" From 75df191253987cd7d4b06c3b837b42430c7edd33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C4=83t=C4=83lin?= Date: Thu, 13 Feb 2025 09:52:15 +0100 Subject: [PATCH 05/27] feat: add GetRandomQuoteAction --- devenv.lock | 26 +-- devenv.nix | 13 -- src/huesoporro/actions/get_random_quote.py | 11 ++ src/huesoporro/api/routes/api.py | 2 - src/huesoporro/api/routes/auth.py | 12 +- src/huesoporro/bot.py | 20 ++- src/huesoporro/infra/gtts.py | 48 ++++++ src/huesoporro/infra/repos.py | 45 ++++- src/huesoporro/models.py | 15 ++ src/huesoporro/svc/get_random_quote.py | 9 +- src/huesoporro/templates/index.html | 2 +- src/huesoporro/templates/tts.html | 181 ++------------------- tests/test_repos.py | 19 ++- 13 files changed, 185 insertions(+), 218 deletions(-) create mode 100644 src/huesoporro/actions/get_random_quote.py create mode 100644 src/huesoporro/infra/gtts.py diff --git a/devenv.lock b/devenv.lock index 97f804a..d776fa9 100644 --- a/devenv.lock +++ b/devenv.lock @@ -3,10 +3,10 @@ "devenv": { "locked": { "dir": "src/modules", - "lastModified": 1735530587, + "lastModified": 1739362938, "owner": "cachix", "repo": "devenv", - "rev": "69645885c1052cc1ca398ac30ba7dfc63386c0e3", + "rev": "27276816caa1718f8b8e8d53d64cc18da059e101", "type": "github" }, "original": { @@ -101,35 +101,19 @@ "type": "github" } }, - "nixpkgs-stable": { - "locked": { - "lastModified": 1735286948, - "owner": "NixOS", - "repo": "nixpkgs", - "rev": "31ac92f9628682b294026f0860e14587a09ffb4b", - "type": "github" - }, - "original": { - "owner": "NixOS", - "ref": "nixos-24.05", - "repo": "nixpkgs", - "type": "github" - } - }, "pre-commit-hooks": { "inputs": { "flake-compat": "flake-compat_2", "gitignore": "gitignore", "nixpkgs": [ "nixpkgs" - ], - "nixpkgs-stable": "nixpkgs-stable" + ] }, "locked": { - "lastModified": 1734797603, + "lastModified": 1737465171, "owner": "cachix", "repo": "pre-commit-hooks.nix", - "rev": "f0f0dc4920a903c3e08f5bdb9246bb572fcae498", + "rev": "9364dc02281ce2d37a1f55b6e51f7c0f65a75f17", "type": "github" }, "original": { diff --git a/devenv.nix b/devenv.nix index eea6c6c..06edec4 100644 --- a/devenv.nix +++ b/devenv.nix @@ -5,24 +5,11 @@ packages = [ pkgs.git ]; - certificates = [ - "id.twitch.tv" - "twitch.tv" - "discord.com" - ]; - languages.python.enable = true; languages.python.uv.enable = true; languages.python.version = "3.12.8"; - scripts.hello.exec = '' - echo hello from $GREET - ''; - enterShell = '' - hello - git --version - fish ''; dotenv.enable = true; diff --git a/src/huesoporro/actions/get_random_quote.py b/src/huesoporro/actions/get_random_quote.py new file mode 100644 index 0000000..47c7b5b --- /dev/null +++ b/src/huesoporro/actions/get_random_quote.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + +from src.huesoporro.models import Quote +from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc + + +class GetRandomQuoteAction(BaseModel): + quote_getter_svc: RandomQuoteGetterSvc + + async def run(self, channel_name: str) -> Quote | None: + return await self.quote_getter_svc.run(channel_name=channel_name) diff --git a/src/huesoporro/api/routes/api.py b/src/huesoporro/api/routes/api.py index 55e6557..607789f 100644 --- a/src/huesoporro/api/routes/api.py +++ b/src/huesoporro/api/routes/api.py @@ -32,8 +32,6 @@ async def get_tts_permalink(access_token: str) -> Template: """Handler for the /tts permalink endpoint to be used by apps that can only give the authentication as a query param and not as a cookie, i.e. OBS""" - # authenticate the user using the provided access token - return Template( template_name="tts.html", ) diff --git a/src/huesoporro/api/routes/auth.py b/src/huesoporro/api/routes/auth.py index df95441..7d930b2 100644 --- a/src/huesoporro/api/routes/auth.py +++ b/src/huesoporro/api/routes/auth.py @@ -1,6 +1,7 @@ import secrets from litestar import MediaType, get +from litestar.datastructures.cookie import Cookie from litestar.response import Redirect, Template from src.huesoporro.actions.authenticate import AuthenticateAction @@ -10,7 +11,16 @@ from src.huesoporro.settings import Settings @get(path="/o/code") async def get_code(code: str, authenticate_action: AuthenticateAction) -> Redirect: token = await authenticate_action.run(code) - return Redirect("/", cookies={"huesoporroAuth": token}) + return Redirect( + "/", + cookies=[ + Cookie( + key="huesoporroAuth", + value=token, + expires=604800, # 1 week + ) + ], + ) @get( diff --git a/src/huesoporro/bot.py b/src/huesoporro/bot.py index 025a2dd..f682260 100644 --- a/src/huesoporro/bot.py +++ b/src/huesoporro/bot.py @@ -7,8 +7,11 @@ from loguru import logger from twitchio import Channel from twitchio.ext import commands, routines +from src.huesoporro.actions.get_random_quote import GetRandomQuoteAction from src.huesoporro.actions.store_quote import StoreQuoteAction +from src.huesoporro.api.dependencies import get_settings from src.huesoporro.infra.db import Database +from src.huesoporro.infra.repos import QuoteRepo from src.huesoporro.libs.db import Database as MarkovDB from src.huesoporro.models import ChatbotSettings, User from src.huesoporro.svc.backoff_service import BackoffService @@ -33,8 +36,11 @@ class Bot(commands.Bot): self.store_quote_action = StoreQuoteAction( quote_storer_svc=QuoteStorerSvc(db=db), is_mod_svc=IsModSvc(db=db) ) - - self.get_random_quote_svc = RandomQuoteGetterSvc(db=db) + self.quote_repo = QuoteRepo(s=get_settings()) + self.get_random_quote_svc = RandomQuoteGetterSvc(quote_repo=self.quote_repo) + self.get_random_quote_action = GetRandomQuoteAction( + quote_getter_svc=self.get_random_quote_svc + ) self.cbs = chatbot_settings self.quote_routine = routines.routine( seconds=chatbot_settings.automatic_quote_timer, wait_first=True @@ -78,19 +84,19 @@ class Bot(commands.Bot): @commands.command(aliases=["q", "quote"]) async def get_random_quote(self, ctx: commands.Context): - quote = await self.get_random_quote_svc.run(channel_name=self.channel) + quote = await self.get_random_quote_action.run(channel_name=self.channel) if quote: - await ctx.send(f"«{quote[0]}» - {quote[1]}") + await ctx.send(quote.as_pretty()) def get_channel_conn(self) -> Channel: return Channel(name=self.channel, websocket=self._connection) async def send_quote(self): - quote = await self.get_random_quote_svc.run(channel_name=self.channel) + quote = await self.get_random_quote_action.run(channel_name=self.channel) if quote: channel = self.get_channel_conn() - logger.info(f"Sending random quote {quote[0]}") - await channel.send(f"«{quote[0]}» - {quote[1]}") + logger.info(f"Sending random quote {quote.quote}") + await channel.send(quote.quote) async def send_generation(self): sentence = await self.generate_svc.run() diff --git a/src/huesoporro/infra/gtts.py b/src/huesoporro/infra/gtts.py new file mode 100644 index 0000000..10814ec --- /dev/null +++ b/src/huesoporro/infra/gtts.py @@ -0,0 +1,48 @@ +from collections import deque +from hashlib import sha512 +from pathlib import Path + +from gtts import gTTS +from loguru import logger +from pydantic import BaseModel + +from src.huesoporro.settings import Settings + + +class GTTS(BaseModel): + s: Settings + chunk_size: int = 128 + text_max_length: int = 100 + queue: deque = deque() + + async def generate(self, text: str, lang: str = "pt", tld="com.br") -> Path: + text = text[: self.text_max_length] + raw_filename = f"{text.lower()}_{lang}_{tld}" + logger.info(f"Generating TTS for {raw_filename}") + filepath = ( + self.s.tts_cache_path / f"{sha512(raw_filename.encode()).hexdigest()}.mp3" + ) + tts = gTTS(text=text, lang=lang, tld=tld) + logger.info(f"Saving TTS to {filepath}") + tts.save(str(filepath)) + self.queue.append(filepath) + return filepath + + async def consume(self): + """If there are items in the queue, return a generator + that reads the file's bytes by chunks of self.chunk_size""" + while self.queue: + filepath = self.queue.popleft() + if not filepath.exists(): + logger.warning(f"File {filepath} does not exist, skipping") + continue + + logger.info(f"Reading file {filepath}") + try: + with filepath.open("rb") as f: + while chunk := f.read(self.chunk_size): + yield chunk + logger.info(f"Finished reading {filepath}") + except Exception as e: # noqa: BLE001 + logger.error(f"Error reading file {filepath}: {e}") + continue diff --git a/src/huesoporro/infra/repos.py b/src/huesoporro/infra/repos.py index 7e4e939..a4aab5b 100644 --- a/src/huesoporro/infra/repos.py +++ b/src/huesoporro/infra/repos.py @@ -6,7 +6,7 @@ from typing import Generic, TypeVar import aiosqlite from pydantic import BaseModel, Field -from src.huesoporro.models import User +from src.huesoporro.models import Quote, User from src.huesoporro.settings import Settings T = TypeVar("T", bound=BaseModel) @@ -112,3 +112,46 @@ class UserRepo(IRepo[User]): async def count(self, obj: User, auto_commit=True): raise NotImplementedError("Not implemented since it's not needed") + + +class QuoteRepo(IRepo[Quote]): + async def create(self, obj: Quote, auto_commit=True) -> Quote: + raise NotImplementedError("Not implemented since it's not needed") + + async def update(self, obj: Quote, auto_commit=True) -> Quote: + raise NotImplementedError("Not implemented since it's not needed") + + async def delete(self, obj: Quote, auto_commit=True): + raise NotImplementedError("Not implemented since it's not needed") + + async def get_by_id(self, obj_id: int | str, auto_commit=True) -> Quote | None: + raise NotImplementedError("Not implemented since it's not needed") + + async def list( + self, obj: T, offset: int = 0, limit: int = 10, auto_commit=True + ) -> list[T]: + raise NotImplementedError("Not implemented since it's not needed") + + async def get_random(self, channel_name: str, auto_commit=True) -> Quote | None: + async with ( + self.get_client(auto_commit=auto_commit) as db, + db.execute( + """ + SELECT * FROM quotes + WHERE channel = ? + ORDER BY RANDOM() + LIMIT 1 + """, + (channel_name,), + ) as cursor, + ): + data = await cursor.fetchone() + if not data: + return None + return Quote( + quote=data["quote"], + author=User(user=data["author"], external_auth={}), + channel=User(user=data["channel"], external_auth={}), + created_at=data["created_at"], + last_updated_at=data["last_updated_at"], + ) diff --git a/src/huesoporro/models.py b/src/huesoporro/models.py index f791cf6..7222887 100644 --- a/src/huesoporro/models.py +++ b/src/huesoporro/models.py @@ -1,3 +1,4 @@ +import datetime from typing import Literal import jwt @@ -69,3 +70,17 @@ class Sentence(BaseModel): created_at: float last_updated_at: float user: User + + +class Quote(BaseModel): + quote: str + author: User + channel: User + created_at: datetime.datetime + last_updated_at: datetime.datetime + + def as_pretty(self) -> str: + return f"«{self.quote}» - {self.author}" + + def as_pretty_saved(self): + return f"He añadido la cita «{self.quote}» de {self.author}" diff --git a/src/huesoporro/svc/get_random_quote.py b/src/huesoporro/svc/get_random_quote.py index f6b4f6c..055b633 100644 --- a/src/huesoporro/svc/get_random_quote.py +++ b/src/huesoporro/svc/get_random_quote.py @@ -1,10 +1,11 @@ from pydantic import BaseModel -from src.huesoporro.infra.db import Database +from src.huesoporro.infra.repos import QuoteRepo +from src.huesoporro.models import Quote class RandomQuoteGetterSvc(BaseModel): - db: Database + quote_repo: QuoteRepo - async def run(self, channel_name: str) -> tuple[str, str] | None: - return await self.db.get_random_quote(channel_name=channel_name) + async def run(self, channel_name: str) -> Quote | None: + return await self.quote_repo.get_random(channel_name=channel_name) diff --git a/src/huesoporro/templates/index.html b/src/huesoporro/templates/index.html index e7ffa8c..082c17e 100644 --- a/src/huesoporro/templates/index.html +++ b/src/huesoporro/templates/index.html @@ -5,7 +5,7 @@