From 3bc4e19de1cb5b30377f0da2c799645093dfc317 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?c=C4=83t=C4=83lin?=
Date: Thu, 19 Dec 2024 18:13:38 +0100
Subject: [PATCH 01/27] feat: add backoff service and some message reactions
---
pyproject.toml | 2 +-
src/huesoporro/api/errors.py | 2 +-
src/huesoporro/api/routes/api.py | 18 ++-
src/huesoporro/bot.py | 172 ++++++++++++++++++++++++-
src/huesoporro/svc/backoff_service.py | 111 ++++++++++++++++
src/huesoporro/svc/generate.py | 12 +-
src/huesoporro/svc/get_random_quote.py | 2 +-
src/huesoporro/svc/hello.py | 4 +-
tests/conftest.py | 32 ++++-
tests/test_svc.py | 61 +++++++++
uv.lock | 2 +-
11 files changed, 394 insertions(+), 24 deletions(-)
create mode 100644 src/huesoporro/svc/backoff_service.py
diff --git a/pyproject.toml b/pyproject.toml
index 9d19c3b..9333e7f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,7 +52,7 @@ extend-select = [
"W", "C90", "I", "N", "UP", "S", "BLE", "B", "A", "COM", "C4", "DTZ", "T10", "EM", "ISC", "T20", "PT", "RSE", "RET",
"SIM", "PTH", "ERA", "PGH", "PL", "RUF", "FURB", "PERF"
]
-extend-ignore = ["S101", "ISC002", "COM812", "ISC001"]
+extend-ignore = ["S101", "ISC002", "COM812", "ISC001", "EM101", "EM102"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
diff --git a/src/huesoporro/api/errors.py b/src/huesoporro/api/errors.py
index e88a6f0..19426df 100644
--- a/src/huesoporro/api/errors.py
+++ b/src/huesoporro/api/errors.py
@@ -30,7 +30,7 @@ def httpx_status_error_handler(_: Request, exc: httpx.HTTPStatusError):
)
-async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # noqa: F821
+async def after_exception_handler(exc: Exception, scope: "Scope") -> None: # type: ignore[name-defined] # noqa: F821
"""Hook function that will be invoked after each exception."""
state = scope["app"].state
if not hasattr(state, "error_count"):
diff --git a/src/huesoporro/api/routes/api.py b/src/huesoporro/api/routes/api.py
index dff54c2..d427cde 100644
--- a/src/huesoporro/api/routes/api.py
+++ b/src/huesoporro/api/routes/api.py
@@ -52,13 +52,20 @@ async def get_index(user: User, gbs: ChatbotSettingsGetterSvc) -> Template:
@put("/api/v1/bot")
async def manage_bot(
- user: User, data: ManageBotDTO, gbs: ChatbotSettingsGetterSvc, bm: BotsManager
+ user: User,
+ data: ManageBotDTO,
+ gbs: ChatbotSettingsGetterSvc,
+ sbs: ChatbotSettingsStorerSvc,
+ bm: BotsManager,
) -> Response:
chatbot_settings = await gbs.run(user=user)
+ if not chatbot_settings:
+ await sbs.run(user=user, bot_settings=ChatbotSettings())
+ chatbot_settings = await gbs.run(user=user)
if data.command == "start":
if not data.channel_name:
return Response({"message": "Channel name is required"}, status_code=400)
- bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings)
+ bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings) # type: ignore[arg-type]
if user.user in bm.bots:
await bm.run_user_bot(user)
return Response({"message": "Bot started"})
@@ -78,8 +85,11 @@ async def get_bot_status(user: User, bm: BotsManager) -> dict:
@get("/api/v1/bot/settings")
async def get_bot_settings(
user: User, gbs: ChatbotSettingsGetterSvc
-) -> ChatbotSettings:
- return await gbs.run(user=user)
+) -> ChatbotSettings | dict:
+ cbs = await gbs.run(user=user)
+ if not cbs:
+ return {"status": "Not found"}
+ return cbs
@put("/api/v1/bot/settings")
diff --git a/src/huesoporro/bot.py b/src/huesoporro/bot.py
index 82151c0..fae650e 100644
--- a/src/huesoporro/bot.py
+++ b/src/huesoporro/bot.py
@@ -1,4 +1,7 @@
import asyncio
+import random
+from collections.abc import Callable
+from enum import StrEnum
from loguru import logger
from twitchio import Channel
@@ -8,6 +11,7 @@ from src.huesoporro.actions.store_quote import StoreQuoteAction
from src.huesoporro.infra.db import Database
from src.huesoporro.libs.db import Database as MarkovDB
from src.huesoporro.models import ChatbotSettings, User
+from src.huesoporro.svc.backoff_service import BackoffService
from src.huesoporro.svc.generate import SentenceGeneratorSvc
from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc
from src.huesoporro.svc.hello import HelloGeneratorSvc
@@ -75,16 +79,18 @@ class Bot(commands.Bot):
@commands.command(aliases=["q", "quote"])
async def get_random_quote(self, ctx: commands.Context):
quote = await self.get_random_quote_svc.run(channel_name=self.channel)
- await ctx.send(f"«{quote[0]}» - {quote[1]}")
+ if quote:
+ await ctx.send(f"«{quote[0]}» - {quote[1]}")
def get_channel_conn(self) -> Channel:
return Channel(name=self.channel, websocket=self._connection)
async def send_quote(self):
quote = await self.get_random_quote_svc.run(channel_name=self.channel)
- channel = self.get_channel_conn()
- logger.info(f"Sending random quote {quote[0]}")
- await channel.send(f"«{quote[0]}» - {quote[1]}")
+ if quote:
+ channel = self.get_channel_conn()
+ logger.info(f"Sending random quote {quote[0]}")
+ await channel.send(f"«{quote[0]}» - {quote[1]}")
async def send_generation(self):
sentence = await self.generate_svc.run()
@@ -108,14 +114,15 @@ class Bot(commands.Bot):
self.generation_routine.cancel()
-class SaveMessagesCog(commands.Cog):
+class SaveMessagesCog2(commands.Cog):
def __init__(self, bot):
self.bot = bot
self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel))
+ self.hello_svc = HelloGeneratorSvc()
+ self.backoff_svc = BackoffService()
@commands.Cog.event()
async def event_message(self, message):
- # An event inside a cog!
content = message.content
if content.startswith("!"):
return
@@ -125,6 +132,159 @@ class SaveMessagesCog(commands.Cog):
await self.store_svc.run(content)
+ if message.content in ["hola", "HOLA", "hiii", "ayo"]:
+ hello_message = self.hello_svc.run(message.author.name)
+ await message.channel.send(hello_message)
+ return
+
+ if message.content == "Yes":
+ await message.channel.send("Indeed")
+ return
+
+ if message.content.startswith("WHAT"):
+ await message.channel.send("WHAT Ramon")
+ return
+
+ laughs_messages = [
+ "om",
+ "KEK",
+ "LuL",
+ "LUL",
+ "OMEGALUL",
+ "kek",
+ "keking",
+ "KEKW",
+ "OMEGADANCEBUTFAST",
+ ]
+
+ if message.content in laughs_messages:
+ await message.channel.send(random.choice(laughs_messages)) # noqa: S311
+ return
+
+
+class MessageType(StrEnum):
+ COMMAND = "COMMAND"
+ HELLO = "HELLO"
+ YES = "YES"
+ WHAT = "WHAT"
+ LAUGH = "LAUGH"
+ OTHER = "OTHER"
+
+
+class MessageHandler:
+ """Handles different types of messages with their corresponding responses"""
+
+ def __init__(self, channel_send_func: Callable):
+ self.hello_patterns = ["hola", "HOLA", "hiii", "ayo"]
+ self.laugh_patterns = [
+ "om",
+ "KEK",
+ "LuL",
+ "LUL",
+ "OMEGALUL",
+ "kek",
+ "keking",
+ "KEKW",
+ "OMEGADANCEBUTFAST",
+ ]
+ self.send = channel_send_func
+
+ def get_message_type(self, content: str) -> MessageType:
+ """Determines the type of message based on its content"""
+ if content.startswith("!"):
+ return MessageType.COMMAND
+ if content in self.hello_patterns:
+ return MessageType.HELLO
+ if content == "Yes":
+ return MessageType.YES
+ if content.startswith("WHAT"):
+ return MessageType.WHAT
+ if content in self.laugh_patterns:
+ return MessageType.LAUGH
+ return MessageType.OTHER
+
+ async def handle_hello(self, author_name: str, hello_svc) -> str:
+ """Handles hello messages"""
+ return hello_svc.run(author_name)
+
+ async def handle_laugh(self) -> str:
+ """Handles laugh messages"""
+ return random.choice(self.laugh_patterns) # noqa: S311
+
+
+class SaveMessagesCog(commands.Cog):
+ def __init__(self, bot):
+ self.bot = bot
+ self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel))
+ self.hello_svc = HelloGeneratorSvc()
+ self.backoff_svc = BackoffService()
+ self.message_handler = MessageHandler(self._send_message)
+
+ # Register a separate send function for each message type
+ self.send_functions = {
+ MessageType.HELLO: self._create_typed_send("hello"),
+ MessageType.YES: self._create_typed_send("yes"),
+ MessageType.WHAT: self._create_typed_send("what"),
+ MessageType.LAUGH: self._create_typed_send("laugh"),
+ }
+
+ # Register each send function with its own backoff
+ for func in self.send_functions.values():
+ self.backoff_svc.add_callable(func, backoff_seconds=10)
+
+ def _create_typed_send(self, type_name: str):
+ """Creates a send function for a specific message type"""
+
+ async def typed_send(content: str):
+ if hasattr(self, "current_message"):
+ await self.current_message.channel.send(content)
+
+ # Set a unique name for the function to ensure it's treated as distinct
+ typed_send.__name__ = f"send_{type_name}"
+ return typed_send
+
+ async def _send_message(self, content: str):
+ """Generic send message function (for non-backoff uses)"""
+ if hasattr(self, "current_message"):
+ await self.current_message.channel.send(content)
+
+ @commands.Cog.event()
+ async def event_message(self, message):
+ """Main message event handler"""
+ if not message.author:
+ return
+
+ # Store reference to current message for send functions
+ self.current_message = message
+
+ # Store the message content
+ await self.store_svc.run(message.content)
+
+ # Determine message type and handle accordingly
+ msg_type = self.message_handler.get_message_type(message.content)
+
+ response = None
+
+ match msg_type:
+ case MessageType.COMMAND:
+ return
+ case MessageType.HELLO:
+ response = await self.message_handler.handle_hello(
+ message.author.name, self.hello_svc
+ )
+ case MessageType.YES:
+ response = "Indeed"
+ case MessageType.WHAT:
+ response = "WHAT Ramon"
+ case MessageType.LAUGH:
+ response = await self.message_handler.handle_laugh()
+ case MessageType.OTHER:
+ return
+
+ if response and msg_type in self.send_functions:
+ # Use the type-specific send function
+ await self.backoff_svc.call_async(self.send_functions[msg_type], response)
+
class BotsManager:
def __init__(self):
diff --git a/src/huesoporro/svc/backoff_service.py b/src/huesoporro/svc/backoff_service.py
new file mode 100644
index 0000000..08123d1
--- /dev/null
+++ b/src/huesoporro/svc/backoff_service.py
@@ -0,0 +1,111 @@
+import asyncio
+import time
+from collections.abc import Callable
+
+from pydantic import BaseModel
+
+
+class CallableInfo(BaseModel):
+ backoff_seconds: int
+ last_call: float | None = None
+ is_async: bool
+
+
+class BackoffService(BaseModel):
+ """Use this service to implement a backoff strategy on random callables.
+ The callable will be called the first time without delay but every subsequent
+ call may be hold off for a given time
+
+ Examples:
+ >>> def callable(x): print(f"foo {x}")
+ >>> backoff_service = BackoffService()
+ >>> backoff_service.add_callable(callable, backoff_time=3)
+ >>> backoff_service.call(callable, "bar") # prints "foo bar"
+ >>> backoff_service.call(callable, "baz") # prints nothing
+ >>> # wait 3 seconds before calling callable again
+ >>> backoff_service.call(callable, "qux") # prints "foo qux"
+ """
+
+ callables: dict[Callable, CallableInfo] = {}
+
+ def add_callable(self, func: Callable, backoff_seconds: int):
+ """Adds a callable to the local mapper with its backoff configuration.
+
+ Args:
+ func: The function to be registered
+ backoff_seconds: The number of seconds to wait between successive calls
+ """
+ self.callables[func] = CallableInfo(
+ backoff_seconds=backoff_seconds, is_async=self._is_async(func)
+ )
+
+ @staticmethod
+ def _is_async(func: Callable) -> bool:
+ """Checks if the callable is async"""
+ return asyncio.iscoroutinefunction(func)
+
+ def _can_call(self, func: Callable) -> bool:
+ """Determines if enough time has passed since the last call"""
+ if func not in self.callables:
+ raise ValueError(f"Function {func} not registered with backoff service")
+
+ func_info = self.callables[func]
+ last_call = func_info.last_call
+
+ if last_call is None:
+ return True
+
+ elapsed = time.time() - last_call
+ return elapsed >= func_info.backoff_seconds
+
+ def call(self, func: Callable, *args, **kwargs):
+ """Calls the callable with arguments and returns its result if it isn't held off
+
+ Args:
+ func: The function to call
+ *args: Positional arguments for the function
+ **kwargs: Keyword arguments for the function
+
+ Returns:
+ Optional[Any]: The result of the function call if executed, None if held off
+ """
+ if func not in self.callables:
+ raise ValueError(f"Function {func} not registered with backoff service")
+
+ if self.callables[func].is_async:
+ raise ValueError(
+ "Cannot call async function with .call(), use .call_async() instead"
+ )
+
+ if not self._can_call(func):
+ return None
+
+ result = func(*args, **kwargs)
+ self.callables[func].last_call = time.time()
+ return result
+
+ async def call_async(self, func: Callable, *args, **kwargs):
+ """Same as .call(...) but for async functions
+
+ Args:
+ func: The async function to call
+ *args: Positional arguments for the function
+ **kwargs: Keyword arguments for the function
+
+ Returns:
+ Optional[Any]: The result of the async function call if executed, None if held off
+ """
+ if func not in self.callables:
+ raise ValueError(f"Function {func} not registered with backoff service")
+
+ if not self.callables[func].is_async:
+ raise ValueError(
+ "Cannot call sync function with .call_async(), use .call() instead"
+ )
+
+ if not self._can_call(func):
+ return None
+
+ result = await func(*args, **kwargs)
+ self.callables[func].last_call = time.time()
+ return result
diff --git a/src/huesoporro/svc/generate.py b/src/huesoporro/svc/generate.py
index 0f5058d..5bef7fe 100644
--- a/src/huesoporro/svc/generate.py
+++ b/src/huesoporro/svc/generate.py
@@ -133,11 +133,11 @@ class SentenceGeneratorSvc(BaseModel):
self,
sentence: str | None = None,
) -> str | None:
- if sentence:
- sentence = tokenize(sentence)
- logger.info(f"Generating sentence from {sentence}")
- sentence, success = self.generate(sentence)
- logger.info(f"Generated sentence: {sentence}")
+ split_sentence = tokenize(sentence) if sentence else None
+
+ logger.info(f"Generating sentence from {split_sentence}")
+ generated_sentence, success = self.generate(split_sentence)
+ logger.info(f"Generated sentence: {generated_sentence}")
if not success:
return None
- return sentence
+ return generated_sentence
diff --git a/src/huesoporro/svc/get_random_quote.py b/src/huesoporro/svc/get_random_quote.py
index 4371d4e..f6b4f6c 100644
--- a/src/huesoporro/svc/get_random_quote.py
+++ b/src/huesoporro/svc/get_random_quote.py
@@ -6,5 +6,5 @@ from src.huesoporro.infra.db import Database
class RandomQuoteGetterSvc(BaseModel):
db: Database
- async def run(self, channel_name: str) -> tuple[str, str]:
+ async def run(self, channel_name: str) -> tuple[str, str] | None:
return await self.db.get_random_quote(channel_name=channel_name)
diff --git a/src/huesoporro/svc/hello.py b/src/huesoporro/svc/hello.py
index 19dd38f..5be2180 100644
--- a/src/huesoporro/svc/hello.py
+++ b/src/huesoporro/svc/hello.py
@@ -11,8 +11,10 @@ class HelloGeneratorSvc(BaseModel):
"Hi",
"Bon día",
"Hola mi tremendo elemento",
+ "HOLA",
+ "hiii",
]
)
def run(self, username: str):
- return f"{random.choice(self.hellos)} {username}" # noqa: S311
+ return f"{random.choice(self.hellos)} @{username}" # noqa: S311
diff --git a/tests/conftest.py b/tests/conftest.py
index b01710c..19218d9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,7 @@ from caribou.migrate import load_migrations
from src.huesoporro.infra.db import Database
from src.huesoporro.models import ChatbotSettings, TwitchAuth, User
from src.huesoporro.settings import Settings
+from src.huesoporro.svc.backoff_service import BackoffService
from src.huesoporro.svc.is_mod import IsModSvc
@@ -16,7 +17,8 @@ def user() -> User:
user="huesoporro",
expires_at=1671234567.0,
twitch_auth=TwitchAuth(
- access_token="test_access_token", refresh_token="test_refresh_token"
+ access_token="test_access_token", # noqa: S106
+ refresh_token="test_refresh_token", # noqa: S106
),
)
@@ -27,8 +29,8 @@ def s(tmp_path: Path, user: User) -> Settings:
static_files_path=tmp_path / "static_files",
db_filepath=tmp_path / "huesoporro.db",
twitch_client_id="test_client_id",
- twitch_client_secret="test_client_secret", # type: ignore[arg-type]
- jwt_secret="test_jwt_secret", # type: ignore[arg-type]
+ twitch_client_secret="test_client_secret", # type: ignore[arg-type] # noqa: S106
+ jwt_secret="test_jwt_secret", # type: ignore[arg-type] # noqa: S106
allowed_users=[user.user],
)
@@ -54,3 +56,27 @@ async def chatbot_settings(db: Database, user) -> ChatbotSettings:
cbs = ChatbotSettings(mods=[user.user, "allowed_user"])
await db.save_chatbot_settings(user=user, chatbot_settings=cbs)
return cbs
+
+
+@pytest.fixture
+def backoff_callable():
+ def foo():
+ return "foo"
+
+ return foo
+
+
+@pytest.fixture
+def async_backoff_callable():
+ async def async_foo():
+ return "async foo"
+
+ return async_foo
+
+
+@pytest.fixture
+async def backoff_svc(backoff_callable, async_backoff_callable):
+ backoff_svc = BackoffService()
+ backoff_svc.add_callable(backoff_callable, 3)
+ backoff_svc.add_callable(async_backoff_callable, 3)
+ return backoff_svc
diff --git a/tests/test_svc.py b/tests/test_svc.py
index 5a5aacd..c88e1c9 100644
--- a/tests/test_svc.py
+++ b/tests/test_svc.py
@@ -1,3 +1,6 @@
+import asyncio
+import time
+
import pytest
from src.huesoporro.models import ChatbotSettings, User
@@ -33,3 +36,61 @@ async def test_is_mod_svc_returns_false_for_user_not_in_modlist(
):
is_mod = await is_mod_svc.run(user=user, username="TestUser2", channel=user.user)
assert not is_mod
+
+
+async def test_backoff_svc_returns_for_first_attempt(
+ backoff_svc, backoff_callable, async_backoff_callable
+):
+ assert backoff_svc.call(backoff_callable) == "foo"
+
+ assert await backoff_svc.call_async(async_backoff_callable) == "async foo"
+
+
+async def test_backoff_svc_returns_none_for_second_attempt(
+ backoff_svc, backoff_callable, async_backoff_callable
+):
+ assert backoff_svc.call(backoff_callable) == "foo"
+ assert backoff_svc.call(backoff_callable) is None
+
+ assert await backoff_svc.call_async(async_backoff_callable) == "async foo"
+ assert await backoff_svc.call_async(async_backoff_callable) is None
+
+
+async def test_backoff_svc_returns_for_second_attempt_after_delay(
+ backoff_svc, backoff_callable, async_backoff_callable
+):
+ assert backoff_svc.call(backoff_callable) == "foo"
+ assert backoff_svc.call(backoff_callable) is None
+ time.sleep(3)
+ assert backoff_svc.call(backoff_callable) == "foo"
+
+ assert await backoff_svc.call_async(async_backoff_callable) == "async foo"
+ assert await backoff_svc.call_async(async_backoff_callable) is None
+ await asyncio.sleep(3)
+ assert await backoff_svc.call_async(async_backoff_callable) == "async foo"
+
+
+async def test_backoff_svc_raises_value_error_for_unknown_callable(backoff_svc):
+ with pytest.raises(ValueError, match="not registered with backoff service"):
+ backoff_svc.call(lambda: "foo")
+
+
+async def test_backoff_svc_raises_value_error_for_unknown_async_callable(backoff_svc):
+ with pytest.raises(ValueError, match="not registered with backoff service"):
+ await backoff_svc.call_async(lambda: "foo")
+
+
+async def test_backoff_svc_raises_value_error_for_async_called_from_sync(
+ backoff_svc, backoff_callable
+):
+ with pytest.raises(
+ ValueError, match="Cannot call sync function with .call_async()"
+ ):
+ await backoff_svc.call_async(backoff_callable)
+
+
+async def test_backoff_svc_raises_value_error_for_sync_called_from_async(
+ backoff_svc, async_backoff_callable
+):
+ with pytest.raises(ValueError, match="Cannot call async function with .call()"):
+ backoff_svc.call(async_backoff_callable)
diff --git a/uv.lock b/uv.lock
index 618cfee..458a1a4 100644
--- a/uv.lock
+++ b/uv.lock
@@ -460,7 +460,7 @@ wheels = [
[[package]]
name = "huesoporro"
-version = "0.2.2"
+version = "0.2.3"
source = { virtual = "." }
dependencies = [
{ name = "aiosqlite" },
From efac1cc33ccf807f3c2b1a176286db24af891c23 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?c=C4=83t=C4=83lin?=
Date: Thu, 19 Dec 2024 18:18:37 +0100
Subject: [PATCH 02/27] chore: update to v0.2.4 and remove useless code
---
charts/huesoporro/Chart.yaml | 4 +--
charts/huesoporro/values.yaml | 2 +-
pyproject.toml | 2 +-
src/huesoporro/bot.py | 48 -----------------------------------
uv.lock | 2 +-
5 files changed, 5 insertions(+), 53 deletions(-)
diff --git a/charts/huesoporro/Chart.yaml b/charts/huesoporro/Chart.yaml
index 1c6576a..22ca75f 100644
--- a/charts/huesoporro/Chart.yaml
+++ b/charts/huesoporro/Chart.yaml
@@ -15,10 +15,10 @@ type: application
# This is the chart version. This version number should be incremented each time you make changes
# to the chart and its templates, including the app version.
# Versions are expected to follow Semantic Versioning (https://semver.org/)
-version: 0.2.3
+version: 0.2.4
# This is the version number of the application being deployed. This version number should be
# incremented each time you make changes to the application. Versions are not expected to
# follow Semantic Versioning. They should reflect the version the application is using.
# It is recommended to use it with quotes.
-appVersion: "0.2.3"
+appVersion: "0.2.4"
diff --git a/charts/huesoporro/values.yaml b/charts/huesoporro/values.yaml
index ae50d59..48b749a 100644
--- a/charts/huesoporro/values.yaml
+++ b/charts/huesoporro/values.yaml
@@ -11,7 +11,7 @@ image:
# This sets the pull policy for images.
pullPolicy: Always
# Overrides the image tag whose default is the chart appVersion.
- tag: "0.2.3"
+ tag: "0.2.4"
# This is for the secretes for pulling an image from a private repository more information can be found here: https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/
imagePullSecrets: []
diff --git a/pyproject.toml b/pyproject.toml
index 9333e7f..5b913bf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "huesoporro"
-version = "0.2.3"
+version = "0.2.4"
description = "Misc Twitch bots"
readme = "README.md"
authors = [
diff --git a/src/huesoporro/bot.py b/src/huesoporro/bot.py
index fae650e..e6b3c36 100644
--- a/src/huesoporro/bot.py
+++ b/src/huesoporro/bot.py
@@ -114,54 +114,6 @@ class Bot(commands.Bot):
self.generation_routine.cancel()
-class SaveMessagesCog2(commands.Cog):
- def __init__(self, bot):
- self.bot = bot
- self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel))
- self.hello_svc = HelloGeneratorSvc()
- self.backoff_svc = BackoffService()
-
- @commands.Cog.event()
- async def event_message(self, message):
- content = message.content
- if content.startswith("!"):
- return
-
- if not message.author:
- return
-
- await self.store_svc.run(content)
-
- if message.content in ["hola", "HOLA", "hiii", "ayo"]:
- hello_message = self.hello_svc.run(message.author.name)
- await message.channel.send(hello_message)
- return
-
- if message.content == "Yes":
- await message.channel.send("Indeed")
- return
-
- if message.content.startswith("WHAT"):
- await message.channel.send("WHAT Ramon")
- return
-
- laughs_messages = [
- "om",
- "KEK",
- "LuL",
- "LUL",
- "OMEGALUL",
- "kek",
- "keking",
- "KEKW",
- "OMEGADANCEBUTFAST",
- ]
-
- if message.content in laughs_messages:
- await message.channel.send(random.choice(laughs_messages)) # noqa: S311
- return
-
-
class MessageType(StrEnum):
COMMAND = "COMMAND"
HELLO = "HELLO"
diff --git a/uv.lock b/uv.lock
index 458a1a4..c6caf63 100644
--- a/uv.lock
+++ b/uv.lock
@@ -460,7 +460,7 @@ wheels = [
[[package]]
name = "huesoporro"
-version = "0.2.3"
+version = "0.2.4"
source = { virtual = "." }
dependencies = [
{ name = "aiosqlite" },
From 3186afe96defbd040a0d640bdab1455e7f170e8c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?c=C4=83t=C4=83lin?=
Date: Thu, 19 Dec 2024 18:50:05 +0100
Subject: [PATCH 03/27] fix: fix logout flow which wasn't being triggered,
remove useless html code
---
charts/huesoporro/Chart.yaml | 2 +-
charts/huesoporro/values.yaml | 2 +-
pyproject.toml | 2 +-
src/huesoporro/api/routes/api.py | 2 +-
src/huesoporro/static/js/utils.js | 33 -------------
src/huesoporro/templates/index.html | 12 ++---
src/huesoporro/templates/login.html | 62 +------------------------
src/huesoporro/templates/logout.html | 13 ++++++
src/huesoporro/templates/sentences.html | 15 +++---
src/huesoporro/templates/tts.html | 7 +--
uv.lock | 2 +-
11 files changed, 34 insertions(+), 118 deletions(-)
create mode 100644 src/huesoporro/templates/logout.html
diff --git a/charts/huesoporro/Chart.yaml b/charts/huesoporro/Chart.yaml
index 22ca75f..3d392ad 100644
--- a/charts/huesoporro/Chart.yaml
+++ b/charts/huesoporro/Chart.yaml
@@ -21,4 +21,4 @@ version: 0.2.4
# incremented each time you make changes to the application. Versions are not expected to
# follow Semantic Versioning. They should reflect the version the application is using.
# It is recommended to use it with quotes.
-appVersion: "0.2.4"
+appVersion: "0.2.5"
diff --git a/charts/huesoporro/values.yaml b/charts/huesoporro/values.yaml
index 48b749a..ea27e5a 100644
--- a/charts/huesoporro/values.yaml
+++ b/charts/huesoporro/values.yaml
@@ -11,7 +11,7 @@ image:
# This sets the pull policy for images.
pullPolicy: Always
# Overrides the image tag whose default is the chart appVersion.
- tag: "0.2.4"
+ tag: "0.2.5"
# This is for the secretes for pulling an image from a private repository more information can be found here: https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/
imagePullSecrets: []
diff --git a/pyproject.toml b/pyproject.toml
index 5b913bf..1da3c58 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "huesoporro"
-version = "0.2.4"
+version = "0.2.5"
description = "Misc Twitch bots"
readme = "README.md"
authors = [
diff --git a/src/huesoporro/api/routes/api.py b/src/huesoporro/api/routes/api.py
index d427cde..9412598 100644
--- a/src/huesoporro/api/routes/api.py
+++ b/src/huesoporro/api/routes/api.py
@@ -19,7 +19,7 @@ class ManageBotDTO(BaseModel):
"/tts",
media_type=MediaType.HTML,
)
-async def get_tts_overlay() -> Template:
+async def get_tts_overlay(user: User) -> Template:
return Template(template_name="tts.html")
diff --git a/src/huesoporro/static/js/utils.js b/src/huesoporro/static/js/utils.js
index e014ab2..9bb1116 100644
--- a/src/huesoporro/static/js/utils.js
+++ b/src/huesoporro/static/js/utils.js
@@ -7,36 +7,3 @@ function getWebsocketProtocol() {
return "wss://";
}
}
-
-function addLogoutEvent() {
- const logoutButton = document.getElementById("logoutButton");
- logoutButton.addEventListener("click", () => {
- document.cookie = "twitchLoginData=; expires=Thu, 01 Jan 1970 00:00:00 UTC";
- window.location.href = "/";
- });
-
-}
-
-function setCookie(name, value, days) {
- const date = new Date();
- date.setTime(date.getTime() + (days * 24 * 60 * 60 * 1000));
- const expires = `expires=${date.toUTCString()}`;
- document.cookie = `${name}=${value};${expires};path=/;SameSite=Strict`;
-}
-
-function getCookie(name) {
- const cookieName = `${name}=`;
- const decodedCookie = decodeURIComponent(document.cookie);
- const cookieArray = decodedCookie.split(';');
-
- for (let i = 0; i < cookieArray.length; i++) {
- let cookie = cookieArray[i];
- while (cookie.charAt(0) === ' ') {
- cookie = cookie.substring(1);
- }
- if (cookie.indexOf(cookieName) === 0) {
- return cookie.substring(cookieName.length, cookie.length);
- }
- }
- return null;
-}
diff --git a/src/huesoporro/templates/index.html b/src/huesoporro/templates/index.html
index bdf08ee..b17419f 100644
--- a/src/huesoporro/templates/index.html
+++ b/src/huesoporro/templates/index.html
@@ -5,12 +5,10 @@
@@ -178,7 +176,7 @@
})
const data = await response.json()
console.log(data);
- if (response.ok){
+ if (response.ok) {
alert("Settings saved successfully")
}
}
@@ -186,8 +184,6 @@
const chatbotManager = new ChatbotManager();
chatbotManager.setEvents();
-
- addLogoutEvent()
});
Huesoporro🦴🍃
@@ -20,14 +17,18 @@
- | Sentence |
+ Sentence |
Action |
{% for sentence in sentences %}
| {{ sentence.sentence }} |
- |
+
+
+ |
{% endfor %}
diff --git a/src/huesoporro/templates/tts.html b/src/huesoporro/templates/tts.html
index a0f3266..5a7843e 100644
--- a/src/huesoporro/templates/tts.html
+++ b/src/huesoporro/templates/tts.html
@@ -8,10 +8,7 @@
TTS
Le Funny
-
-
+ {% include 'logout.html' %}
Huesoporro🦴🍃
@@ -186,7 +183,7 @@
// generate /tts/permalink?access_token=
// the access token is available in the twitchLoginData cookie
- const cookie = JSON.parse(getCookie("twitchLoginData"))
+ const cookie = JSON.parse(getCookie("huesoporroAuth"))
const permalinkUrl = `${window.location.origin}/tts/permalink?access_token=${cookie.access_token}`;
navigator.clipboard.writeText(permalinkUrl);
alert('OBS link copied to clipboard ' + permalinkUrl);
diff --git a/uv.lock b/uv.lock
index c6caf63..d861103 100644
--- a/uv.lock
+++ b/uv.lock
@@ -460,7 +460,7 @@ wheels = [
[[package]]
name = "huesoporro"
-version = "0.2.4"
+version = "0.2.5"
source = { virtual = "." }
dependencies = [
{ name = "aiosqlite" },
From 50900986fa5cb8b2d0065259ea95364d9bbe9beb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?c=C4=83t=C4=83lin?=
Date: Fri, 17 Jan 2025 18:15:58 +0100
Subject: [PATCH 04/27] feat: revamp authentication -- remove twitch's tokens
from our own wrapper token
---
devenv.lock | 50 +++++++-
devenv.nix | 7 +
devenv.yaml | 17 +--
migrations/20241219191711_sentences.py | 35 +++++
.../20250112153541_user_external_auth.py | 53 ++++++++
.../20250113142241_external_auth_json.py | 35 +++++
pyproject.toml | 2 +
src/huesoporro/actions/authenticate.py | 27 ++++
src/huesoporro/actions/get_user_by_jwt.py | 38 ++++++
src/huesoporro/actions/refresh.py | 27 ++++
src/huesoporro/api/dependencies.py | 39 ++++--
src/huesoporro/api/main.py | 15 ++-
src/huesoporro/api/routes/api.py | 17 ++-
src/huesoporro/api/routes/auth.py | 8 +-
src/huesoporro/bot.py | 2 +-
src/huesoporro/infra/authenticator.py | 16 ++-
src/huesoporro/infra/db.py | 43 ++-----
src/huesoporro/infra/repos.py | 114 +++++++++++++++++
src/huesoporro/models.py | 35 +++--
src/huesoporro/svc/authenticate.py | 26 ----
src/huesoporro/svc/get_sentences_svc.py | 11 ++
src/huesoporro/svc/refresh.py | 27 ----
src/huesoporro/templates/header.html | 11 +-
src/huesoporro/templates/index.html | 10 +-
.../templates/le_funny_dropdown.html | 10 ++
src/huesoporro/templates/login.html | 4 +-
src/huesoporro/templates/logout.html | 2 +-
src/huesoporro/templates/sentences.html | 121 ++++++++++++++++--
tests/conftest.py | 11 +-
tests/test_repos.py | 53 ++++++++
uv.lock | 25 ++++
31 files changed, 736 insertions(+), 155 deletions(-)
create mode 100644 migrations/20241219191711_sentences.py
create mode 100644 migrations/20250112153541_user_external_auth.py
create mode 100644 migrations/20250113142241_external_auth_json.py
create mode 100644 src/huesoporro/actions/authenticate.py
create mode 100644 src/huesoporro/actions/get_user_by_jwt.py
create mode 100644 src/huesoporro/actions/refresh.py
create mode 100644 src/huesoporro/infra/repos.py
delete mode 100644 src/huesoporro/svc/authenticate.py
create mode 100644 src/huesoporro/svc/get_sentences_svc.py
delete mode 100644 src/huesoporro/svc/refresh.py
create mode 100644 src/huesoporro/templates/le_funny_dropdown.html
create mode 100644 tests/test_repos.py
diff --git a/devenv.lock b/devenv.lock
index 8ab1706..97f804a 100644
--- a/devenv.lock
+++ b/devenv.lock
@@ -3,10 +3,10 @@
"devenv": {
"locked": {
"dir": "src/modules",
- "lastModified": 1733788855,
+ "lastModified": 1735530587,
"owner": "cachix",
"repo": "devenv",
- "rev": "d59fee8696cd48f69cf79f65992269df9891ba86",
+ "rev": "69645885c1052cc1ca398ac30ba7dfc63386c0e3",
"type": "github"
},
"original": {
@@ -31,6 +31,21 @@
"type": "github"
}
},
+ "flake-compat_2": {
+ "flake": false,
+ "locked": {
+ "lastModified": 1733328505,
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+ "type": "github"
+ },
+ "original": {
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "type": "github"
+ }
+ },
"gitignore": {
"inputs": {
"nixpkgs": [
@@ -66,12 +81,32 @@
"type": "github"
}
},
+ "nixpkgs-python": {
+ "inputs": {
+ "flake-compat": "flake-compat",
+ "nixpkgs": [
+ "nixpkgs"
+ ]
+ },
+ "locked": {
+ "lastModified": 1733319315,
+ "owner": "cachix",
+ "repo": "nixpkgs-python",
+ "rev": "01263eeb28c09f143d59cd6b0b7c4cc8478efd48",
+ "type": "github"
+ },
+ "original": {
+ "owner": "cachix",
+ "repo": "nixpkgs-python",
+ "type": "github"
+ }
+ },
"nixpkgs-stable": {
"locked": {
- "lastModified": 1733730953,
+ "lastModified": 1735286948,
"owner": "NixOS",
"repo": "nixpkgs",
- "rev": "7109b680d161993918b0a126f38bc39763e5a709",
+ "rev": "31ac92f9628682b294026f0860e14587a09ffb4b",
"type": "github"
},
"original": {
@@ -83,7 +118,7 @@
},
"pre-commit-hooks": {
"inputs": {
- "flake-compat": "flake-compat",
+ "flake-compat": "flake-compat_2",
"gitignore": "gitignore",
"nixpkgs": [
"nixpkgs"
@@ -91,10 +126,10 @@
"nixpkgs-stable": "nixpkgs-stable"
},
"locked": {
- "lastModified": 1733665616,
+ "lastModified": 1734797603,
"owner": "cachix",
"repo": "pre-commit-hooks.nix",
- "rev": "d8c02f0ffef0ef39f6063731fc539d8c71eb463a",
+ "rev": "f0f0dc4920a903c3e08f5bdb9246bb572fcae498",
"type": "github"
},
"original": {
@@ -107,6 +142,7 @@
"inputs": {
"devenv": "devenv",
"nixpkgs": "nixpkgs",
+ "nixpkgs-python": "nixpkgs-python",
"pre-commit-hooks": "pre-commit-hooks"
}
}
diff --git a/devenv.nix b/devenv.nix
index 771f4d7..eea6c6c 100644
--- a/devenv.nix
+++ b/devenv.nix
@@ -5,8 +5,15 @@
packages = [ pkgs.git ];
+ certificates = [
+ "id.twitch.tv"
+ "twitch.tv"
+ "discord.com"
+ ];
+
languages.python.enable = true;
languages.python.uv.enable = true;
+ languages.python.version = "3.12.8";
scripts.hello.exec = ''
echo hello from $GREET
diff --git a/devenv.yaml b/devenv.yaml
index 116a2ad..184b866 100644
--- a/devenv.yaml
+++ b/devenv.yaml
@@ -1,15 +1,8 @@
-# yaml-language-server: $schema=https://devenv.sh/devenv.schema.json
inputs:
nixpkgs:
url: github:cachix/devenv-nixpkgs/rolling
-
-# If you're using non-OSS software, you can set allowUnfree to true.
-# allowUnfree: true
-
-# If you're willing to use a package that's vulnerable
-# permittedInsecurePackages:
-# - "openssl-1.1.1w"
-
-# If you have more than one devenv you can merge them
-#imports:
-# - ./backend
+ nixpkgs-python:
+ url: github:cachix/nixpkgs-python
+ inputs:
+ nixpkgs:
+ follows: nixpkgs
diff --git a/migrations/20241219191711_sentences.py b/migrations/20241219191711_sentences.py
new file mode 100644
index 0000000..06830e0
--- /dev/null
+++ b/migrations/20241219191711_sentences.py
@@ -0,0 +1,35 @@
+"""
+This module contains a Caribou migration.
+
+Migration Name: sentences
+Migration Version: 20241219191711
+"""
+
+
+def upgrade(connection):
+ # update table `sentences` to have a user_id row
+ # which references users.id
+ # and a channel VARCHAR(255) row
+
+ sql = """
+ DROP TABLE IF EXISTS sentences;
+ """
+ connection.execute(sql)
+ connection.commit()
+ sql = """
+ CREATE TABLE sentences(
+ id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+ sentence VARCHAR(255) NOT NULL UNIQUE,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ last_updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ user_id VARCHAR(255) NOT NULL,
+ FOREIGN KEY (user_id) REFERENCES users(id)
+ );
+ """
+ connection.execute(sql)
+ connection.commit()
+
+
+def downgrade(connection):
+ # add your downgrade step here
+ pass
diff --git a/migrations/20250112153541_user_external_auth.py b/migrations/20250112153541_user_external_auth.py
new file mode 100644
index 0000000..d873a70
--- /dev/null
+++ b/migrations/20250112153541_user_external_auth.py
@@ -0,0 +1,53 @@
+"""
+This module contains a Caribou migration.
+
+Migration Name: user_external_auth
+Migration Version: 20250112153541
+"""
+
+
+def upgrade(connection):
+ """
+ - delete access_token, refresh_token, and expires_at from users
+ - add external_auth table which will store the external auths:
+ - type: twitch or discord
+ - credentials: JSON
+ """
+
+ sql = """
+ ALTER TABLE users DROP COLUMN access_token;
+ """
+ connection.execute(sql)
+ sql = """
+ ALTER TABLE users DROP COLUMN refresh_token;
+ """
+ connection.execute(sql)
+ sql = """
+ ALTER TABLE users DROP COLUMN expires_at;
+ """
+ connection.execute(sql)
+
+ sql = """
+ CREATE TABLE external_auth(
+ id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+ type VARCHAR(255) NOT NULL,
+ credentials JSON NOT NULL
+ );
+ """
+ connection.execute(sql)
+
+ sql = """
+ CREATE TABLE user_external_auth(
+ user_id VARCHAR(255) NOT NULL,
+ external_auth_id INTEGER NOT NULL,
+ FOREIGN KEY (user_id) REFERENCES users(id),
+ FOREIGN KEY (external_auth_id) REFERENCES external_auth(id)
+ );
+ """
+ connection.execute(sql)
+ connection.commit()
+
+
+def downgrade(connection):
+ # add your downgrade step here
+ pass
diff --git a/migrations/20250113142241_external_auth_json.py b/migrations/20250113142241_external_auth_json.py
new file mode 100644
index 0000000..e92e847
--- /dev/null
+++ b/migrations/20250113142241_external_auth_json.py
@@ -0,0 +1,35 @@
+"""
+This module contains a Caribou migration.
+
+Migration Name: external_auth_json
+Migration Version: 20250113142241
+"""
+
+
+def upgrade(connection):
+ """remove tables:
+ - external_auth
+ - user_external_auth
+ add column to users table:
+ - external_auth JSON
+ """
+ sql = """
+ DROP TABLE IF EXISTS external_auth;
+ """
+ connection.execute(sql)
+
+ sql = """
+ DROP TABLE IF EXISTS user_external_auth;
+ """
+ connection.execute(sql)
+
+ sql = """
+ ALTER TABLE users ADD COLUMN external_auth JSON;
+ """
+ connection.execute(sql)
+ connection.commit()
+
+
+def downgrade(connection):
+ # add your downgrade step here
+ pass
diff --git a/pyproject.toml b/pyproject.toml
index 1da3c58..d2f948a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,8 @@ dependencies = [
"pyjwt>=2.10.1",
"twitchio>=2.10.0",
"redis>=5.2.1",
+ "pytz>=2024.2",
+ "discord-py>=2.4.0",
]
[tool.uv]
diff --git a/src/huesoporro/actions/authenticate.py b/src/huesoporro/actions/authenticate.py
new file mode 100644
index 0000000..9bf7b6a
--- /dev/null
+++ b/src/huesoporro/actions/authenticate.py
@@ -0,0 +1,27 @@
+from pydantic import BaseModel
+
+from src.huesoporro.infra.authenticator import TwitchAuthenticator
+from src.huesoporro.infra.repos import UserRepo
+from src.huesoporro.models import User
+from src.huesoporro.settings import Settings
+
+
+class AuthenticateAction(BaseModel):
+ user_repo: UserRepo
+ authenticator: TwitchAuthenticator
+ s: Settings
+
+ async def run(
+ self,
+ auth_code: str,
+ ):
+ tokens = await self.authenticator.get_token(auth_code)
+ username = tokens.userinfo["preferred_username"]
+ if username not in self.s.allowed_users:
+ raise ValueError(f"User {username} is not allowed to use this bot")
+ user = User(user=username, external_auth={"twitch": tokens.model_dump()})
+ if await self.user_repo.get_by_user(user.user):
+ await self.user_repo.update(user)
+ else:
+ await self.user_repo.create(user)
+ return user.encode()
diff --git a/src/huesoporro/actions/get_user_by_jwt.py b/src/huesoporro/actions/get_user_by_jwt.py
new file mode 100644
index 0000000..b5311ab
--- /dev/null
+++ b/src/huesoporro/actions/get_user_by_jwt.py
@@ -0,0 +1,38 @@
+from loguru import logger
+from pydantic import BaseModel
+
+from src.huesoporro.infra.authenticator import TwitchAuthenticator
+from src.huesoporro.infra.repos import UserRepo
+from src.huesoporro.models import User
+from src.huesoporro.settings import Settings
+
+
+class GetUserByJWTAction(BaseModel):
+ user_repo: UserRepo
+ authenticator: TwitchAuthenticator
+ s: Settings
+
+ async def run(
+ self,
+ jwt_token: str,
+ ) -> User:
+ user_data = User.decode(jwt_token)
+ username = user_data["user"]
+ user = await self.user_repo.get_by_user(username)
+ if not user:
+ raise ValueError(f"User {username} not found")
+ is_valid = await self.authenticator.token_is_valid(
+ user.external_auth["twitch"]["access_token"]
+ )
+ logger.info(f"Token {user} is valid: {is_valid}")
+ if not is_valid:
+ logger.info(f"Refreshing token for user {user}")
+ new_tokens = await self.authenticator.refresh_token(
+ user.external_auth["twitch"]["refresh_token"]
+ )
+ user.external_auth["twitch"]["access_token"] = new_tokens["access_token"] # type: ignore[index]
+ user.external_auth["twitch"]["refresh_token"] = new_tokens["refresh_token"] # type: ignore[index]
+ await self.user_repo.update(user)
+ return user
+
+ return user
diff --git a/src/huesoporro/actions/refresh.py b/src/huesoporro/actions/refresh.py
new file mode 100644
index 0000000..4ba8543
--- /dev/null
+++ b/src/huesoporro/actions/refresh.py
@@ -0,0 +1,27 @@
+from pydantic import BaseModel
+
+from src.huesoporro.infra.authenticator import TwitchAuthenticator
+from src.huesoporro.infra.repos import UserRepo
+from src.huesoporro.models import User
+from src.huesoporro.settings import Settings
+
+
+class RefreshAction(BaseModel):
+ user_repo: UserRepo
+ authenticator: TwitchAuthenticator
+ s: Settings
+
+ async def run(self, user: User):
+ is_valid = await self.authenticator.token_is_valid(
+ user.external_auth["twitch"]["access_token"]
+ )
+
+ if not is_valid:
+ new_tokens = await self.authenticator.refresh_token(
+ user.external_auth["twitch"]["refresh_token"]
+ )
+ user.external_auth["twitch"]["access_token"] = new_tokens["access_token"] # type: ignore[index]
+ user.external_auth["twitch"]["refresh_token"] = new_tokens["refresh_token"] # type: ignore[index]
+ await self.user_repo.update(user)
+ return user.encode()
+ return None
diff --git a/src/huesoporro/api/dependencies.py b/src/huesoporro/api/dependencies.py
index 5db8371..d1d9812 100644
--- a/src/huesoporro/api/dependencies.py
+++ b/src/huesoporro/api/dependencies.py
@@ -1,12 +1,15 @@
from litestar import Request
from litestar.exceptions import HTTPException
+from src.huesoporro.actions.authenticate import AuthenticateAction
+from src.huesoporro.actions.get_user_by_jwt import GetUserByJWTAction
from src.huesoporro.infra.authenticator import TwitchAuthenticator
from src.huesoporro.infra.db import Database
+from src.huesoporro.infra.repos import UserRepo
from src.huesoporro.models import User
from src.huesoporro.settings import Settings
-from src.huesoporro.svc.authenticate import CodeAuthenticatorSvc
from src.huesoporro.svc.get_chatbot_settings import ChatbotSettingsGetterSvc
+from src.huesoporro.svc.get_sentences_svc import SentencesGetterSvc
from src.huesoporro.svc.store_settings import ChatbotSettingsStorerSvc
@@ -22,27 +25,43 @@ def get_db(s: Settings):
return Database(s=s)
-async def authenticate(request: Request) -> User:
+async def get_get_user_by_jwt_action(
+ user_repo: UserRepo, authenticator: TwitchAuthenticator, s: Settings
+):
+ return GetUserByJWTAction(user_repo=user_repo, authenticator=authenticator, s=s)
+
+
+async def authenticate(
+ request: Request, get_user_by_jwt_action: GetUserByJWTAction
+) -> User:
token = request.query_params.get("huesoporro_token")
if token:
- return User.decode(token)
+ return await get_user_by_jwt_action.run(token)
cookies = request.cookies.get("huesoporroAuth")
if cookies:
- return User.decode(cookies)
+ return await get_user_by_jwt_action.run(cookies)
raise HTTPException(status_code=401, detail="Unauthorized")
-async def get_code_authenticator_svc(
- a: TwitchAuthenticator, db: Database
-) -> CodeAuthenticatorSvc:
- return CodeAuthenticatorSvc(authenticator=a, db=db)
-
-
async def get_chatbot_settings_svc(db: Database):
return ChatbotSettingsGetterSvc(db=db)
async def store_chatbot_settings_svc(db: Database):
return ChatbotSettingsStorerSvc(db=db)
+
+
+async def get_sentences_svc(db: Database):
+ return SentencesGetterSvc(db=db)
+
+
+async def get_user_repo(s: Settings):
+ return UserRepo(s=s)
+
+
+async def get_authenticate_action(
+ user_repo: UserRepo, authenticator: TwitchAuthenticator, s: Settings
+):
+ return AuthenticateAction(user_repo=user_repo, authenticator=authenticator, s=s)
diff --git a/src/huesoporro/api/main.py b/src/huesoporro/api/main.py
index dd0489a..8df704e 100644
--- a/src/huesoporro/api/main.py
+++ b/src/huesoporro/api/main.py
@@ -8,11 +8,14 @@ from litestar.template import TemplateConfig
from src.huesoporro.api.dependencies import (
authenticate,
+ get_authenticate_action,
get_authenticator,
get_chatbot_settings_svc,
- get_code_authenticator_svc,
get_db,
+ get_get_user_by_jwt_action,
+ get_sentences_svc,
get_settings,
+ get_user_repo,
store_chatbot_settings_svc,
)
from src.huesoporro.api.errors import (
@@ -24,10 +27,12 @@ from src.huesoporro.api.routes.api import (
get_bot_settings,
get_bot_status,
get_index,
+ get_sentences,
get_tts_overlay,
get_tts_permalink,
manage_bot,
save_bot_settings,
+ save_new_sentence,
)
from src.huesoporro.api.routes.auth import get_code, login
from src.huesoporro.bot import BotsManager
@@ -52,6 +57,8 @@ def create_app():
get_bot_status,
save_bot_settings,
get_bot_settings,
+ get_sentences,
+ save_new_sentence,
],
static_files_config=(
StaticFilesConfig(
@@ -77,10 +84,14 @@ def create_app():
"a": Provide(get_authenticator, use_cache=True),
"user": Provide(authenticate),
"db": Provide(get_db, use_cache=True),
- "code_authenticator_svc": Provide(get_code_authenticator_svc),
"bm": Provide(BotsManager, use_cache=True),
"gbs": Provide(get_chatbot_settings_svc),
"sbs": Provide(store_chatbot_settings_svc),
+ "sgs": Provide(get_sentences_svc),
+ "authenticator": Provide(get_authenticator),
+ "authenticate_action": Provide(get_authenticate_action),
+ "user_repo": Provide(get_user_repo),
+ "get_user_by_jwt_action": Provide(get_get_user_by_jwt_action),
},
)
diff --git a/src/huesoporro/api/routes/api.py b/src/huesoporro/api/routes/api.py
index 9412598..55e6557 100644
--- a/src/huesoporro/api/routes/api.py
+++ b/src/huesoporro/api/routes/api.py
@@ -1,12 +1,13 @@
from typing import Literal
-from litestar import MediaType, Response, get, put
+from litestar import MediaType, Response, get, post, put
from litestar.response import Template
from pydantic import BaseModel
from src.huesoporro.bot import BotsManager
from src.huesoporro.models import ChatbotSettings, User
from src.huesoporro.svc.get_chatbot_settings import ChatbotSettingsGetterSvc
+from src.huesoporro.svc.get_sentences_svc import SentencesGetterSvc
from src.huesoporro.svc.store_settings import ChatbotSettingsStorerSvc
@@ -98,3 +99,17 @@ async def save_bot_settings(
) -> dict:
await sbs.run(user=user, bot_settings=data)
return {"status": "ok"}
+
+
+@get("/sentences")
+async def get_sentences(user: User, sgs: SentencesGetterSvc) -> Template:
+ sentences = await sgs.run(user=user)
+ return Template(
+ template_name="sentences.html",
+ context={"sentences": [sentence.model_dump() for sentence in sentences]},
+ )
+
+
+@post("/api/v1/sentences")
+async def save_new_sentence(user: User, data: dict) -> dict:
+ return {"id": 54, "sentence": data["sentence"]}
diff --git a/src/huesoporro/api/routes/auth.py b/src/huesoporro/api/routes/auth.py
index 2f1c287..df95441 100644
--- a/src/huesoporro/api/routes/auth.py
+++ b/src/huesoporro/api/routes/auth.py
@@ -3,14 +3,14 @@ import secrets
from litestar import MediaType, get
from litestar.response import Redirect, Template
+from src.huesoporro.actions.authenticate import AuthenticateAction
from src.huesoporro.settings import Settings
-from src.huesoporro.svc.authenticate import CodeAuthenticatorSvc
@get(path="/o/code")
-async def get_code(code: str, code_authenticator_svc: CodeAuthenticatorSvc) -> Redirect:
- user = await code_authenticator_svc.run(code)
- return Redirect("/", cookies={"huesoporroAuth": user.encode()})
+async def get_code(code: str, authenticate_action: AuthenticateAction) -> Redirect:
+ token = await authenticate_action.run(code)
+ return Redirect("/", cookies={"huesoporroAuth": token})
@get(
diff --git a/src/huesoporro/bot.py b/src/huesoporro/bot.py
index e6b3c36..025a2dd 100644
--- a/src/huesoporro/bot.py
+++ b/src/huesoporro/bot.py
@@ -23,7 +23,7 @@ from src.huesoporro.svc.store_quote import QuoteStorerSvc
class Bot(commands.Bot):
def __init__(self, user: User, chatbot_settings: ChatbotSettings, channel: str):
super().__init__(
- token=user.twitch_auth.access_token, prefix="!", initial_channels=[channel]
+ token=user.twitch_access_token, prefix="!", initial_channels=[channel]
)
self.channel = channel
self.user = user
diff --git a/src/huesoporro/infra/authenticator.py b/src/huesoporro/infra/authenticator.py
index d46922c..9c372d8 100644
--- a/src/huesoporro/infra/authenticator.py
+++ b/src/huesoporro/infra/authenticator.py
@@ -30,7 +30,15 @@ class TwitchAuthenticator(BaseModel):
return await self.refresh_token(response.json()["refresh_token"])
response.raise_for_status()
- return TwitchAuth(**response.json())
+ profile = await self.get_userinfo(response.json()["access_token"])
+ return TwitchAuth(**response.json(), userinfo=profile)
+
+ async def get_userinfo(self, access_token):
+ response = await self.client.get(
+ "/oauth2/userinfo", headers={"Authorization": f"Bearer {access_token}"}
+ )
+ response.raise_for_status()
+ return response.json()
async def refresh_token(self, refresh_token: str) -> TwitchAuth:
response = await self.client.post(
@@ -60,3 +68,9 @@ class TwitchAuthenticator(BaseModel):
raise HTTPException(status_code=403, detail="Forbidden")
return user
+
+ async def token_is_valid(self, access_token: str) -> bool:
+ response = await self.client.get(
+ "/oauth2/validate", headers={"Authorization": f"OAuth {access_token}"}
+ )
+ return response.status_code == 200 # noqa: PLR2004
diff --git a/src/huesoporro/infra/db.py b/src/huesoporro/infra/db.py
index b2f6f8f..86d8260 100644
--- a/src/huesoporro/infra/db.py
+++ b/src/huesoporro/infra/db.py
@@ -5,7 +5,7 @@ import aiosqlite
from loguru import logger
from pydantic import BaseModel, Field
-from src.huesoporro.models import ChatbotSettings, User
+from src.huesoporro.models import ChatbotSettings, Sentence, User
from src.huesoporro.settings import Settings
@@ -24,36 +24,6 @@ class Database(BaseModel):
def get_now() -> float:
return datetime.datetime.now(datetime.UTC).timestamp()
- async def save_user(self, user: User, auto_commit=True):
- async with self.get_client(auto_commit=auto_commit) as db:
- async with db.execute(
- "SELECT * FROM users WHERE user = ?", (user.user,)
- ) as cursor:
- result = await cursor.fetchone()
- if result:
- await db.execute(
- "UPDATE users SET access_token = ?, refresh_token = ?, expires_at = ?, last_updated_at = ? WHERE user = ?",
- (
- user.twitch_auth.access_token,
- user.twitch_auth.refresh_token,
- user.expires_at,
- self.get_now(),
- user.user,
- ),
- )
- return
-
- await db.execute(
- "INSERT INTO users (user, access_token, refresh_token, expires_at, last_updated_at) VALUES (?,?,?,?,?)",
- (
- user.user,
- user.twitch_auth.access_token,
- user.twitch_auth.refresh_token,
- user.expires_at,
- self.get_now(),
- ),
- )
-
async def save_quote(self, channel: str, quote: str, author: str, auto_commit=True):
async with self.get_client(auto_commit=auto_commit) as db:
await db.execute(
@@ -133,3 +103,14 @@ class Database(BaseModel):
) as cursor,
):
return await cursor.fetchone()
+
+ async def get_sentences(self, user: User) -> list[Sentence]:
+ async with self.get_client() as db:
+ db.row_factory = aiosqlite.Row
+ async with db.execute(
+ "SELECT * FROM sentences WHERE user_id = ?", (user.user,)
+ ) as cursor:
+ result = await cursor.fetchall()
+ if not result:
+ return []
+ return [Sentence(user=user, **dict(value)) for value in result]
diff --git a/src/huesoporro/infra/repos.py b/src/huesoporro/infra/repos.py
new file mode 100644
index 0000000..7e4e939
--- /dev/null
+++ b/src/huesoporro/infra/repos.py
@@ -0,0 +1,114 @@
+import json
+from abc import ABC, abstractmethod
+from contextlib import asynccontextmanager
+from typing import Generic, TypeVar
+
+import aiosqlite
+from pydantic import BaseModel, Field
+
+from src.huesoporro.models import User
+from src.huesoporro.settings import Settings
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class IRepo(BaseModel, ABC, Generic[T]):
+ s: Settings = Field(default_factory=Settings.get)
+
+ @asynccontextmanager
+ async def get_client(self, auto_commit=True):
+ async with aiosqlite.connect(self.s.db_filepath) as db:
+ db.row_factory = aiosqlite.Row
+ yield db
+ if auto_commit:
+ await db.commit()
+
+ @abstractmethod
+ async def create(self, obj: T, auto_commit=True) -> T:
+ pass
+
+ @abstractmethod
+ async def update(self, obj: T, auto_commit=True) -> T:
+ pass
+
+ @abstractmethod
+ async def delete(self, obj: T, auto_commit=True):
+ pass
+
+ @abstractmethod
+ async def get_by_id(self, obj_id: int | str, auto_commit=True) -> T | None:
+ pass
+
+ @abstractmethod
+ async def list(
+ self, obj: T, offset: int = 0, limit: int = 10, auto_commit=True
+ ) -> list[T]:
+ pass
+
+
+class UserRepo(IRepo[User]):
+ async def get_by_id(self, obj_id: int | str, auto_commit=True) -> User | None:
+ raise NotImplementedError("Not implemented since it's not needed")
+
+ async def create(self, obj: User, auto_commit=True) -> User:
+ async with self.get_client(auto_commit=auto_commit) as db:
+ await db.execute(
+ "INSERT INTO users (user, external_auth) VALUES (?, ?)",
+ (obj.user, json.dumps(obj.external_auth)),
+ )
+ return obj
+
+ async def update(self, obj: User, auto_commit=True) -> User:
+ if not await self.get_by_user(obj.user):
+ raise ValueError(f"User {obj.user} does not exist")
+
+ async with (
+ self.get_client(auto_commit=auto_commit) as db,
+ db.execute(
+ """
+ UPDATE users
+ SET external_auth = ?
+ WHERE user = ?
+ RETURNING *
+ """,
+ (json.dumps(obj.external_auth), obj.user),
+ ) as cursor,
+ ):
+ data = await cursor.fetchone()
+ return User(
+ user=data["user"], external_auth=json.loads(data["external_auth"])
+ )
+
+ async def delete(self, obj: User, auto_commit=True):
+ async with self.get_client(auto_commit=auto_commit) as db:
+ await db.execute(
+ """
+ DELETE FROM users WHERE user = ?
+ """,
+ (obj.user,),
+ )
+
+ async def get_by_user(self, user: str, auto_commit=True) -> User | None:
+ async with (
+ self.get_client(auto_commit=auto_commit) as db,
+ db.execute(
+ """
+ SELECT * FROM users WHERE user = ?
+ """,
+ (user,),
+ ) as cursor,
+ ):
+ data = await cursor.fetchone()
+ if not data:
+ return None
+ return User(
+ user=data["user"], external_auth=json.loads(data["external_auth"])
+ )
+
+ async def list(
+ self, obj: User, offset: int = 0, limit: int = 10, auto_commit=True
+ ) -> list[User]:
+ raise NotImplementedError("Not implemented since it's not needed")
+
+ async def count(self, obj: User, auto_commit=True):
+ raise NotImplementedError("Not implemented since it's not needed")
diff --git a/src/huesoporro/models.py b/src/huesoporro/models.py
index b4df9e9..f791cf6 100644
--- a/src/huesoporro/models.py
+++ b/src/huesoporro/models.py
@@ -1,4 +1,4 @@
-from typing import Self
+from typing import Literal
import jwt
from pydantic import BaseModel, Field, field_validator
@@ -9,28 +9,39 @@ from src.huesoporro.settings import Settings
class TwitchAuth(BaseModel):
access_token: str
refresh_token: str
+ userinfo: dict
+
+
+class ExternalAuth(BaseModel):
+ credentials: dict
+ type: Literal["twitch"] = "twitch"
class User(BaseModel):
user: str
- expires_at: float
- twitch_auth: TwitchAuth
+ external_auth: dict[Literal["twitch", "discord"], dict]
- def encode(self, settings: Settings | None = None) -> str:
+ def encode(
+ self, settings: Settings | None = None, exclude_fields: set[str] | None = None
+ ) -> str:
s = settings or Settings.get()
+ exclude_fields = exclude_fields or {"external_auth"}
return jwt.encode(
- self.model_dump(),
+ self.model_dump(exclude=exclude_fields),
key=s.jwt_secret.get_secret_value(),
algorithm="HS256",
)
@classmethod
- def decode(cls, token: str, settings: Settings | None = None) -> Self:
+ def decode(cls, token: str, settings: Settings | None = None) -> dict:
s = settings or Settings.get()
- decoded = jwt.decode(
+ return jwt.decode(
token, key=s.jwt_secret.get_secret_value(), algorithms=["HS256"]
)
- return cls(**decoded)
+
+ @property
+ def twitch_access_token(self):
+ return self.external_auth["twitch"]["access_token"]
class ChatbotSettings(BaseModel):
@@ -50,3 +61,11 @@ class ChatbotSettings(BaseModel):
if isinstance(v, str):
return v.split(",")
return v
+
+
+class Sentence(BaseModel):
+ id: int
+ sentence: str
+ created_at: float
+ last_updated_at: float
+ user: User
diff --git a/src/huesoporro/svc/authenticate.py b/src/huesoporro/svc/authenticate.py
deleted file mode 100644
index 346a407..0000000
--- a/src/huesoporro/svc/authenticate.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import datetime
-
-from pydantic import BaseModel
-
-from src.huesoporro.infra.authenticator import TwitchAuthenticator
-from src.huesoporro.infra.db import Database
-from src.huesoporro.models import User
-
-
-class CodeAuthenticatorSvc(BaseModel):
- db: Database
- authenticator: TwitchAuthenticator
-
- @staticmethod
- def get_four_hours_from_now() -> float:
- now = datetime.datetime.now(datetime.UTC)
- four_hours_later = now + datetime.timedelta(hours=4)
- return four_hours_later.timestamp()
-
- async def run(self, code: str) -> User:
- auth = await self.authenticator.get_token(code)
- username = await self.authenticator.validate_token(auth.access_token)
- expires_at = self.get_four_hours_from_now()
- user = User(user=username, expires_at=expires_at, twitch_auth=auth)
- await self.db.save_user(user)
- return user
diff --git a/src/huesoporro/svc/get_sentences_svc.py b/src/huesoporro/svc/get_sentences_svc.py
new file mode 100644
index 0000000..1cb0a1c
--- /dev/null
+++ b/src/huesoporro/svc/get_sentences_svc.py
@@ -0,0 +1,11 @@
+from pydantic import BaseModel
+
+from src.huesoporro.infra.db import Database
+from src.huesoporro.models import Sentence, User
+
+
+class SentencesGetterSvc(BaseModel):
+ db: Database
+
+ async def run(self, user: User) -> list[Sentence]:
+ return await self.db.get_sentences(user=user)
diff --git a/src/huesoporro/svc/refresh.py b/src/huesoporro/svc/refresh.py
deleted file mode 100644
index 9209743..0000000
--- a/src/huesoporro/svc/refresh.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import datetime
-
-from pydantic import BaseModel
-
-from src.huesoporro.infra.authenticator import TwitchAuthenticator
-from src.huesoporro.infra.db import Database
-from src.huesoporro.models import User
-
-
-class RefreshTokenAuthenticator(BaseModel):
- db: Database
- authenticator: TwitchAuthenticator
-
- @staticmethod
- def get_four_hours_from_now() -> float:
- now = datetime.datetime.now(datetime.UTC)
- four_hours_later = now + datetime.timedelta(hours=4)
- return four_hours_later.timestamp()
-
- async def run(self, refresh_token: str) -> User:
- auth = await self.authenticator.refresh_token(refresh_token)
- username = await self.authenticator.validate_token(auth.access_token)
- expires_at = self.get_four_hours_from_now()
-
- user = User(user=username, expires_at=expires_at, twitch_auth=auth)
- await self.db.save_user(user)
- return user
diff --git a/src/huesoporro/templates/header.html b/src/huesoporro/templates/header.html
index 7c1cfd9..e6fe60a 100644
--- a/src/huesoporro/templates/header.html
+++ b/src/huesoporro/templates/header.html
@@ -2,15 +2,24 @@
-
+
+
+
+
+
Huesoporro
diff --git a/src/huesoporro/templates/index.html b/src/huesoporro/templates/index.html
index b17419f..e7ffa8c 100644
--- a/src/huesoporro/templates/index.html
+++ b/src/huesoporro/templates/index.html
@@ -2,16 +2,16 @@
-
-
+