From 50900986fa5cb8b2d0065259ea95364d9bbe9beb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C4=83t=C4=83lin?= Date: Fri, 17 Jan 2025 18:15:58 +0100 Subject: [PATCH] feat: revamp authentication -- remove twitch's tokens from our own wrapper token --- devenv.lock | 50 +++++++- devenv.nix | 7 + devenv.yaml | 17 +-- migrations/20241219191711_sentences.py | 35 +++++ .../20250112153541_user_external_auth.py | 53 ++++++++ .../20250113142241_external_auth_json.py | 35 +++++ pyproject.toml | 2 + src/huesoporro/actions/authenticate.py | 27 ++++ src/huesoporro/actions/get_user_by_jwt.py | 38 ++++++ src/huesoporro/actions/refresh.py | 27 ++++ src/huesoporro/api/dependencies.py | 39 ++++-- src/huesoporro/api/main.py | 15 ++- src/huesoporro/api/routes/api.py | 17 ++- src/huesoporro/api/routes/auth.py | 8 +- src/huesoporro/bot.py | 2 +- src/huesoporro/infra/authenticator.py | 16 ++- src/huesoporro/infra/db.py | 43 ++----- src/huesoporro/infra/repos.py | 114 +++++++++++++++++ src/huesoporro/models.py | 35 +++-- src/huesoporro/svc/authenticate.py | 26 ---- src/huesoporro/svc/get_sentences_svc.py | 11 ++ src/huesoporro/svc/refresh.py | 27 ---- src/huesoporro/templates/header.html | 11 +- src/huesoporro/templates/index.html | 10 +- .../templates/le_funny_dropdown.html | 10 ++ src/huesoporro/templates/login.html | 4 +- src/huesoporro/templates/logout.html | 2 +- src/huesoporro/templates/sentences.html | 121 ++++++++++++++++-- tests/conftest.py | 11 +- tests/test_repos.py | 53 ++++++++ uv.lock | 25 ++++ 31 files changed, 736 insertions(+), 155 deletions(-) create mode 100644 migrations/20241219191711_sentences.py create mode 100644 migrations/20250112153541_user_external_auth.py create mode 100644 migrations/20250113142241_external_auth_json.py create mode 100644 src/huesoporro/actions/authenticate.py create mode 100644 src/huesoporro/actions/get_user_by_jwt.py create mode 100644 src/huesoporro/actions/refresh.py create mode 100644 src/huesoporro/infra/repos.py delete mode 100644 src/huesoporro/svc/authenticate.py create mode 100644 src/huesoporro/svc/get_sentences_svc.py delete mode 100644 src/huesoporro/svc/refresh.py create mode 100644 src/huesoporro/templates/le_funny_dropdown.html create mode 100644 tests/test_repos.py diff --git a/devenv.lock b/devenv.lock index 8ab1706..97f804a 100644 --- a/devenv.lock +++ b/devenv.lock @@ -3,10 +3,10 @@ "devenv": { "locked": { "dir": "src/modules", - "lastModified": 1733788855, + "lastModified": 1735530587, "owner": "cachix", "repo": "devenv", - "rev": "d59fee8696cd48f69cf79f65992269df9891ba86", + "rev": "69645885c1052cc1ca398ac30ba7dfc63386c0e3", "type": "github" }, "original": { @@ -31,6 +31,21 @@ "type": "github" } }, + "flake-compat_2": { + "flake": false, + "locked": { + "lastModified": 1733328505, + "owner": "edolstra", + "repo": "flake-compat", + "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, "gitignore": { "inputs": { "nixpkgs": [ @@ -66,12 +81,32 @@ "type": "github" } }, + "nixpkgs-python": { + "inputs": { + "flake-compat": "flake-compat", + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1733319315, + "owner": "cachix", + "repo": "nixpkgs-python", + "rev": "01263eeb28c09f143d59cd6b0b7c4cc8478efd48", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "nixpkgs-python", + "type": "github" + } + }, "nixpkgs-stable": { "locked": { - "lastModified": 1733730953, + "lastModified": 1735286948, "owner": "NixOS", "repo": "nixpkgs", - "rev": "7109b680d161993918b0a126f38bc39763e5a709", + "rev": "31ac92f9628682b294026f0860e14587a09ffb4b", "type": "github" }, "original": { @@ -83,7 +118,7 @@ }, "pre-commit-hooks": { "inputs": { - "flake-compat": "flake-compat", + "flake-compat": "flake-compat_2", "gitignore": "gitignore", "nixpkgs": [ "nixpkgs" @@ -91,10 +126,10 @@ "nixpkgs-stable": "nixpkgs-stable" }, "locked": { - "lastModified": 1733665616, + "lastModified": 1734797603, "owner": "cachix", "repo": "pre-commit-hooks.nix", - "rev": "d8c02f0ffef0ef39f6063731fc539d8c71eb463a", + "rev": "f0f0dc4920a903c3e08f5bdb9246bb572fcae498", "type": "github" }, "original": { @@ -107,6 +142,7 @@ "inputs": { "devenv": "devenv", "nixpkgs": "nixpkgs", + "nixpkgs-python": "nixpkgs-python", "pre-commit-hooks": "pre-commit-hooks" } } diff --git a/devenv.nix b/devenv.nix index 771f4d7..eea6c6c 100644 --- a/devenv.nix +++ b/devenv.nix @@ -5,8 +5,15 @@ packages = [ pkgs.git ]; + certificates = [ + "id.twitch.tv" + "twitch.tv" + "discord.com" + ]; + languages.python.enable = true; languages.python.uv.enable = true; + languages.python.version = "3.12.8"; scripts.hello.exec = '' echo hello from $GREET diff --git a/devenv.yaml b/devenv.yaml index 116a2ad..184b866 100644 --- a/devenv.yaml +++ b/devenv.yaml @@ -1,15 +1,8 @@ -# yaml-language-server: $schema=https://devenv.sh/devenv.schema.json inputs: nixpkgs: url: github:cachix/devenv-nixpkgs/rolling - -# If you're using non-OSS software, you can set allowUnfree to true. -# allowUnfree: true - -# If you're willing to use a package that's vulnerable -# permittedInsecurePackages: -# - "openssl-1.1.1w" - -# If you have more than one devenv you can merge them -#imports: -# - ./backend + nixpkgs-python: + url: github:cachix/nixpkgs-python + inputs: + nixpkgs: + follows: nixpkgs diff --git a/migrations/20241219191711_sentences.py b/migrations/20241219191711_sentences.py new file mode 100644 index 0000000..06830e0 --- /dev/null +++ b/migrations/20241219191711_sentences.py @@ -0,0 +1,35 @@ +""" +This module contains a Caribou migration. + +Migration Name: sentences +Migration Version: 20241219191711 +""" + + +def upgrade(connection): + # update table `sentences` to have a user_id row + # which references users.id + # and a channel VARCHAR(255) row + + sql = """ + DROP TABLE IF EXISTS sentences; + """ + connection.execute(sql) + connection.commit() + sql = """ + CREATE TABLE sentences( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + sentence VARCHAR(255) NOT NULL UNIQUE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + user_id VARCHAR(255) NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) + ); + """ + connection.execute(sql) + connection.commit() + + +def downgrade(connection): + # add your downgrade step here + pass diff --git a/migrations/20250112153541_user_external_auth.py b/migrations/20250112153541_user_external_auth.py new file mode 100644 index 0000000..d873a70 --- /dev/null +++ b/migrations/20250112153541_user_external_auth.py @@ -0,0 +1,53 @@ +""" +This module contains a Caribou migration. + +Migration Name: user_external_auth +Migration Version: 20250112153541 +""" + + +def upgrade(connection): + """ + - delete access_token, refresh_token, and expires_at from users + - add external_auth table which will store the external auths: + - type: twitch or discord + - credentials: JSON + """ + + sql = """ + ALTER TABLE users DROP COLUMN access_token; + """ + connection.execute(sql) + sql = """ + ALTER TABLE users DROP COLUMN refresh_token; + """ + connection.execute(sql) + sql = """ + ALTER TABLE users DROP COLUMN expires_at; + """ + connection.execute(sql) + + sql = """ + CREATE TABLE external_auth( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + type VARCHAR(255) NOT NULL, + credentials JSON NOT NULL + ); + """ + connection.execute(sql) + + sql = """ + CREATE TABLE user_external_auth( + user_id VARCHAR(255) NOT NULL, + external_auth_id INTEGER NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id), + FOREIGN KEY (external_auth_id) REFERENCES external_auth(id) + ); + """ + connection.execute(sql) + connection.commit() + + +def downgrade(connection): + # add your downgrade step here + pass diff --git a/migrations/20250113142241_external_auth_json.py b/migrations/20250113142241_external_auth_json.py new file mode 100644 index 0000000..e92e847 --- /dev/null +++ b/migrations/20250113142241_external_auth_json.py @@ -0,0 +1,35 @@ +""" +This module contains a Caribou migration. + +Migration Name: external_auth_json +Migration Version: 20250113142241 +""" + + +def upgrade(connection): + """remove tables: + - external_auth + - user_external_auth + add column to users table: + - external_auth JSON + """ + sql = """ + DROP TABLE IF EXISTS external_auth; + """ + connection.execute(sql) + + sql = """ + DROP TABLE IF EXISTS user_external_auth; + """ + connection.execute(sql) + + sql = """ + ALTER TABLE users ADD COLUMN external_auth JSON; + """ + connection.execute(sql) + connection.commit() + + +def downgrade(connection): + # add your downgrade step here + pass diff --git a/pyproject.toml b/pyproject.toml index 1da3c58..d2f948a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,8 @@ dependencies = [ "pyjwt>=2.10.1", "twitchio>=2.10.0", "redis>=5.2.1", + "pytz>=2024.2", + "discord-py>=2.4.0", ] [tool.uv] diff --git a/src/huesoporro/actions/authenticate.py b/src/huesoporro/actions/authenticate.py new file mode 100644 index 0000000..9bf7b6a --- /dev/null +++ b/src/huesoporro/actions/authenticate.py @@ -0,0 +1,27 @@ +from pydantic import BaseModel + +from src.huesoporro.infra.authenticator import TwitchAuthenticator +from src.huesoporro.infra.repos import UserRepo +from src.huesoporro.models import User +from src.huesoporro.settings import Settings + + +class AuthenticateAction(BaseModel): + user_repo: UserRepo + authenticator: TwitchAuthenticator + s: Settings + + async def run( + self, + auth_code: str, + ): + tokens = await self.authenticator.get_token(auth_code) + username = tokens.userinfo["preferred_username"] + if username not in self.s.allowed_users: + raise ValueError(f"User {username} is not allowed to use this bot") + user = User(user=username, external_auth={"twitch": tokens.model_dump()}) + if await self.user_repo.get_by_user(user.user): + await self.user_repo.update(user) + else: + await self.user_repo.create(user) + return user.encode() diff --git a/src/huesoporro/actions/get_user_by_jwt.py b/src/huesoporro/actions/get_user_by_jwt.py new file mode 100644 index 0000000..b5311ab --- /dev/null +++ b/src/huesoporro/actions/get_user_by_jwt.py @@ -0,0 +1,38 @@ +from loguru import logger +from pydantic import BaseModel + +from src.huesoporro.infra.authenticator import TwitchAuthenticator +from src.huesoporro.infra.repos import UserRepo +from src.huesoporro.models import User +from src.huesoporro.settings import Settings + + +class GetUserByJWTAction(BaseModel): + user_repo: UserRepo + authenticator: TwitchAuthenticator + s: Settings + + async def run( + self, + jwt_token: str, + ) -> User: + user_data = User.decode(jwt_token) + username = user_data["user"] + user = await self.user_repo.get_by_user(username) + if not user: + raise ValueError(f"User {username} not found") + is_valid = await self.authenticator.token_is_valid( + user.external_auth["twitch"]["access_token"] + ) + logger.info(f"Token {user} is valid: {is_valid}") + if not is_valid: + logger.info(f"Refreshing token for user {user}") + new_tokens = await self.authenticator.refresh_token( + user.external_auth["twitch"]["refresh_token"] + ) + user.external_auth["twitch"]["access_token"] = new_tokens["access_token"] # type: ignore[index] + user.external_auth["twitch"]["refresh_token"] = new_tokens["refresh_token"] # type: ignore[index] + await self.user_repo.update(user) + return user + + return user diff --git a/src/huesoporro/actions/refresh.py b/src/huesoporro/actions/refresh.py new file mode 100644 index 0000000..4ba8543 --- /dev/null +++ b/src/huesoporro/actions/refresh.py @@ -0,0 +1,27 @@ +from pydantic import BaseModel + +from src.huesoporro.infra.authenticator import TwitchAuthenticator +from src.huesoporro.infra.repos import UserRepo +from src.huesoporro.models import User +from src.huesoporro.settings import Settings + + +class RefreshAction(BaseModel): + user_repo: UserRepo + authenticator: TwitchAuthenticator + s: Settings + + async def run(self, user: User): + is_valid = await self.authenticator.token_is_valid( + user.external_auth["twitch"]["access_token"] + ) + + if not is_valid: + new_tokens = await self.authenticator.refresh_token( + user.external_auth["twitch"]["refresh_token"] + ) + user.external_auth["twitch"]["access_token"] = new_tokens["access_token"] # type: ignore[index] + user.external_auth["twitch"]["refresh_token"] = new_tokens["refresh_token"] # type: ignore[index] + await self.user_repo.update(user) + return user.encode() + return None diff --git a/src/huesoporro/api/dependencies.py b/src/huesoporro/api/dependencies.py index 5db8371..d1d9812 100644 --- a/src/huesoporro/api/dependencies.py +++ b/src/huesoporro/api/dependencies.py @@ -1,12 +1,15 @@ from litestar import Request from litestar.exceptions import HTTPException +from src.huesoporro.actions.authenticate import AuthenticateAction +from src.huesoporro.actions.get_user_by_jwt import GetUserByJWTAction from src.huesoporro.infra.authenticator import TwitchAuthenticator from src.huesoporro.infra.db import Database +from src.huesoporro.infra.repos import UserRepo from src.huesoporro.models import User from src.huesoporro.settings import Settings -from src.huesoporro.svc.authenticate import CodeAuthenticatorSvc from src.huesoporro.svc.get_chatbot_settings import ChatbotSettingsGetterSvc +from src.huesoporro.svc.get_sentences_svc import SentencesGetterSvc from src.huesoporro.svc.store_settings import ChatbotSettingsStorerSvc @@ -22,27 +25,43 @@ def get_db(s: Settings): return Database(s=s) -async def authenticate(request: Request) -> User: +async def get_get_user_by_jwt_action( + user_repo: UserRepo, authenticator: TwitchAuthenticator, s: Settings +): + return GetUserByJWTAction(user_repo=user_repo, authenticator=authenticator, s=s) + + +async def authenticate( + request: Request, get_user_by_jwt_action: GetUserByJWTAction +) -> User: token = request.query_params.get("huesoporro_token") if token: - return User.decode(token) + return await get_user_by_jwt_action.run(token) cookies = request.cookies.get("huesoporroAuth") if cookies: - return User.decode(cookies) + return await get_user_by_jwt_action.run(cookies) raise HTTPException(status_code=401, detail="Unauthorized") -async def get_code_authenticator_svc( - a: TwitchAuthenticator, db: Database -) -> CodeAuthenticatorSvc: - return CodeAuthenticatorSvc(authenticator=a, db=db) - - async def get_chatbot_settings_svc(db: Database): return ChatbotSettingsGetterSvc(db=db) async def store_chatbot_settings_svc(db: Database): return ChatbotSettingsStorerSvc(db=db) + + +async def get_sentences_svc(db: Database): + return SentencesGetterSvc(db=db) + + +async def get_user_repo(s: Settings): + return UserRepo(s=s) + + +async def get_authenticate_action( + user_repo: UserRepo, authenticator: TwitchAuthenticator, s: Settings +): + return AuthenticateAction(user_repo=user_repo, authenticator=authenticator, s=s) diff --git a/src/huesoporro/api/main.py b/src/huesoporro/api/main.py index dd0489a..8df704e 100644 --- a/src/huesoporro/api/main.py +++ b/src/huesoporro/api/main.py @@ -8,11 +8,14 @@ from litestar.template import TemplateConfig from src.huesoporro.api.dependencies import ( authenticate, + get_authenticate_action, get_authenticator, get_chatbot_settings_svc, - get_code_authenticator_svc, get_db, + get_get_user_by_jwt_action, + get_sentences_svc, get_settings, + get_user_repo, store_chatbot_settings_svc, ) from src.huesoporro.api.errors import ( @@ -24,10 +27,12 @@ from src.huesoporro.api.routes.api import ( get_bot_settings, get_bot_status, get_index, + get_sentences, get_tts_overlay, get_tts_permalink, manage_bot, save_bot_settings, + save_new_sentence, ) from src.huesoporro.api.routes.auth import get_code, login from src.huesoporro.bot import BotsManager @@ -52,6 +57,8 @@ def create_app(): get_bot_status, save_bot_settings, get_bot_settings, + get_sentences, + save_new_sentence, ], static_files_config=( StaticFilesConfig( @@ -77,10 +84,14 @@ def create_app(): "a": Provide(get_authenticator, use_cache=True), "user": Provide(authenticate), "db": Provide(get_db, use_cache=True), - "code_authenticator_svc": Provide(get_code_authenticator_svc), "bm": Provide(BotsManager, use_cache=True), "gbs": Provide(get_chatbot_settings_svc), "sbs": Provide(store_chatbot_settings_svc), + "sgs": Provide(get_sentences_svc), + "authenticator": Provide(get_authenticator), + "authenticate_action": Provide(get_authenticate_action), + "user_repo": Provide(get_user_repo), + "get_user_by_jwt_action": Provide(get_get_user_by_jwt_action), }, ) diff --git a/src/huesoporro/api/routes/api.py b/src/huesoporro/api/routes/api.py index 9412598..55e6557 100644 --- a/src/huesoporro/api/routes/api.py +++ b/src/huesoporro/api/routes/api.py @@ -1,12 +1,13 @@ from typing import Literal -from litestar import MediaType, Response, get, put +from litestar import MediaType, Response, get, post, put from litestar.response import Template from pydantic import BaseModel from src.huesoporro.bot import BotsManager from src.huesoporro.models import ChatbotSettings, User from src.huesoporro.svc.get_chatbot_settings import ChatbotSettingsGetterSvc +from src.huesoporro.svc.get_sentences_svc import SentencesGetterSvc from src.huesoporro.svc.store_settings import ChatbotSettingsStorerSvc @@ -98,3 +99,17 @@ async def save_bot_settings( ) -> dict: await sbs.run(user=user, bot_settings=data) return {"status": "ok"} + + +@get("/sentences") +async def get_sentences(user: User, sgs: SentencesGetterSvc) -> Template: + sentences = await sgs.run(user=user) + return Template( + template_name="sentences.html", + context={"sentences": [sentence.model_dump() for sentence in sentences]}, + ) + + +@post("/api/v1/sentences") +async def save_new_sentence(user: User, data: dict) -> dict: + return {"id": 54, "sentence": data["sentence"]} diff --git a/src/huesoporro/api/routes/auth.py b/src/huesoporro/api/routes/auth.py index 2f1c287..df95441 100644 --- a/src/huesoporro/api/routes/auth.py +++ b/src/huesoporro/api/routes/auth.py @@ -3,14 +3,14 @@ import secrets from litestar import MediaType, get from litestar.response import Redirect, Template +from src.huesoporro.actions.authenticate import AuthenticateAction from src.huesoporro.settings import Settings -from src.huesoporro.svc.authenticate import CodeAuthenticatorSvc @get(path="/o/code") -async def get_code(code: str, code_authenticator_svc: CodeAuthenticatorSvc) -> Redirect: - user = await code_authenticator_svc.run(code) - return Redirect("/", cookies={"huesoporroAuth": user.encode()}) +async def get_code(code: str, authenticate_action: AuthenticateAction) -> Redirect: + token = await authenticate_action.run(code) + return Redirect("/", cookies={"huesoporroAuth": token}) @get( diff --git a/src/huesoporro/bot.py b/src/huesoporro/bot.py index e6b3c36..025a2dd 100644 --- a/src/huesoporro/bot.py +++ b/src/huesoporro/bot.py @@ -23,7 +23,7 @@ from src.huesoporro.svc.store_quote import QuoteStorerSvc class Bot(commands.Bot): def __init__(self, user: User, chatbot_settings: ChatbotSettings, channel: str): super().__init__( - token=user.twitch_auth.access_token, prefix="!", initial_channels=[channel] + token=user.twitch_access_token, prefix="!", initial_channels=[channel] ) self.channel = channel self.user = user diff --git a/src/huesoporro/infra/authenticator.py b/src/huesoporro/infra/authenticator.py index d46922c..9c372d8 100644 --- a/src/huesoporro/infra/authenticator.py +++ b/src/huesoporro/infra/authenticator.py @@ -30,7 +30,15 @@ class TwitchAuthenticator(BaseModel): return await self.refresh_token(response.json()["refresh_token"]) response.raise_for_status() - return TwitchAuth(**response.json()) + profile = await self.get_userinfo(response.json()["access_token"]) + return TwitchAuth(**response.json(), userinfo=profile) + + async def get_userinfo(self, access_token): + response = await self.client.get( + "/oauth2/userinfo", headers={"Authorization": f"Bearer {access_token}"} + ) + response.raise_for_status() + return response.json() async def refresh_token(self, refresh_token: str) -> TwitchAuth: response = await self.client.post( @@ -60,3 +68,9 @@ class TwitchAuthenticator(BaseModel): raise HTTPException(status_code=403, detail="Forbidden") return user + + async def token_is_valid(self, access_token: str) -> bool: + response = await self.client.get( + "/oauth2/validate", headers={"Authorization": f"OAuth {access_token}"} + ) + return response.status_code == 200 # noqa: PLR2004 diff --git a/src/huesoporro/infra/db.py b/src/huesoporro/infra/db.py index b2f6f8f..86d8260 100644 --- a/src/huesoporro/infra/db.py +++ b/src/huesoporro/infra/db.py @@ -5,7 +5,7 @@ import aiosqlite from loguru import logger from pydantic import BaseModel, Field -from src.huesoporro.models import ChatbotSettings, User +from src.huesoporro.models import ChatbotSettings, Sentence, User from src.huesoporro.settings import Settings @@ -24,36 +24,6 @@ class Database(BaseModel): def get_now() -> float: return datetime.datetime.now(datetime.UTC).timestamp() - async def save_user(self, user: User, auto_commit=True): - async with self.get_client(auto_commit=auto_commit) as db: - async with db.execute( - "SELECT * FROM users WHERE user = ?", (user.user,) - ) as cursor: - result = await cursor.fetchone() - if result: - await db.execute( - "UPDATE users SET access_token = ?, refresh_token = ?, expires_at = ?, last_updated_at = ? WHERE user = ?", - ( - user.twitch_auth.access_token, - user.twitch_auth.refresh_token, - user.expires_at, - self.get_now(), - user.user, - ), - ) - return - - await db.execute( - "INSERT INTO users (user, access_token, refresh_token, expires_at, last_updated_at) VALUES (?,?,?,?,?)", - ( - user.user, - user.twitch_auth.access_token, - user.twitch_auth.refresh_token, - user.expires_at, - self.get_now(), - ), - ) - async def save_quote(self, channel: str, quote: str, author: str, auto_commit=True): async with self.get_client(auto_commit=auto_commit) as db: await db.execute( @@ -133,3 +103,14 @@ class Database(BaseModel): ) as cursor, ): return await cursor.fetchone() + + async def get_sentences(self, user: User) -> list[Sentence]: + async with self.get_client() as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT * FROM sentences WHERE user_id = ?", (user.user,) + ) as cursor: + result = await cursor.fetchall() + if not result: + return [] + return [Sentence(user=user, **dict(value)) for value in result] diff --git a/src/huesoporro/infra/repos.py b/src/huesoporro/infra/repos.py new file mode 100644 index 0000000..7e4e939 --- /dev/null +++ b/src/huesoporro/infra/repos.py @@ -0,0 +1,114 @@ +import json +from abc import ABC, abstractmethod +from contextlib import asynccontextmanager +from typing import Generic, TypeVar + +import aiosqlite +from pydantic import BaseModel, Field + +from src.huesoporro.models import User +from src.huesoporro.settings import Settings + +T = TypeVar("T", bound=BaseModel) + + +class IRepo(BaseModel, ABC, Generic[T]): + s: Settings = Field(default_factory=Settings.get) + + @asynccontextmanager + async def get_client(self, auto_commit=True): + async with aiosqlite.connect(self.s.db_filepath) as db: + db.row_factory = aiosqlite.Row + yield db + if auto_commit: + await db.commit() + + @abstractmethod + async def create(self, obj: T, auto_commit=True) -> T: + pass + + @abstractmethod + async def update(self, obj: T, auto_commit=True) -> T: + pass + + @abstractmethod + async def delete(self, obj: T, auto_commit=True): + pass + + @abstractmethod + async def get_by_id(self, obj_id: int | str, auto_commit=True) -> T | None: + pass + + @abstractmethod + async def list( + self, obj: T, offset: int = 0, limit: int = 10, auto_commit=True + ) -> list[T]: + pass + + +class UserRepo(IRepo[User]): + async def get_by_id(self, obj_id: int | str, auto_commit=True) -> User | None: + raise NotImplementedError("Not implemented since it's not needed") + + async def create(self, obj: User, auto_commit=True) -> User: + async with self.get_client(auto_commit=auto_commit) as db: + await db.execute( + "INSERT INTO users (user, external_auth) VALUES (?, ?)", + (obj.user, json.dumps(obj.external_auth)), + ) + return obj + + async def update(self, obj: User, auto_commit=True) -> User: + if not await self.get_by_user(obj.user): + raise ValueError(f"User {obj.user} does not exist") + + async with ( + self.get_client(auto_commit=auto_commit) as db, + db.execute( + """ + UPDATE users + SET external_auth = ? + WHERE user = ? + RETURNING * + """, + (json.dumps(obj.external_auth), obj.user), + ) as cursor, + ): + data = await cursor.fetchone() + return User( + user=data["user"], external_auth=json.loads(data["external_auth"]) + ) + + async def delete(self, obj: User, auto_commit=True): + async with self.get_client(auto_commit=auto_commit) as db: + await db.execute( + """ + DELETE FROM users WHERE user = ? + """, + (obj.user,), + ) + + async def get_by_user(self, user: str, auto_commit=True) -> User | None: + async with ( + self.get_client(auto_commit=auto_commit) as db, + db.execute( + """ + SELECT * FROM users WHERE user = ? + """, + (user,), + ) as cursor, + ): + data = await cursor.fetchone() + if not data: + return None + return User( + user=data["user"], external_auth=json.loads(data["external_auth"]) + ) + + async def list( + self, obj: User, offset: int = 0, limit: int = 10, auto_commit=True + ) -> list[User]: + raise NotImplementedError("Not implemented since it's not needed") + + async def count(self, obj: User, auto_commit=True): + raise NotImplementedError("Not implemented since it's not needed") diff --git a/src/huesoporro/models.py b/src/huesoporro/models.py index b4df9e9..f791cf6 100644 --- a/src/huesoporro/models.py +++ b/src/huesoporro/models.py @@ -1,4 +1,4 @@ -from typing import Self +from typing import Literal import jwt from pydantic import BaseModel, Field, field_validator @@ -9,28 +9,39 @@ from src.huesoporro.settings import Settings class TwitchAuth(BaseModel): access_token: str refresh_token: str + userinfo: dict + + +class ExternalAuth(BaseModel): + credentials: dict + type: Literal["twitch"] = "twitch" class User(BaseModel): user: str - expires_at: float - twitch_auth: TwitchAuth + external_auth: dict[Literal["twitch", "discord"], dict] - def encode(self, settings: Settings | None = None) -> str: + def encode( + self, settings: Settings | None = None, exclude_fields: set[str] | None = None + ) -> str: s = settings or Settings.get() + exclude_fields = exclude_fields or {"external_auth"} return jwt.encode( - self.model_dump(), + self.model_dump(exclude=exclude_fields), key=s.jwt_secret.get_secret_value(), algorithm="HS256", ) @classmethod - def decode(cls, token: str, settings: Settings | None = None) -> Self: + def decode(cls, token: str, settings: Settings | None = None) -> dict: s = settings or Settings.get() - decoded = jwt.decode( + return jwt.decode( token, key=s.jwt_secret.get_secret_value(), algorithms=["HS256"] ) - return cls(**decoded) + + @property + def twitch_access_token(self): + return self.external_auth["twitch"]["access_token"] class ChatbotSettings(BaseModel): @@ -50,3 +61,11 @@ class ChatbotSettings(BaseModel): if isinstance(v, str): return v.split(",") return v + + +class Sentence(BaseModel): + id: int + sentence: str + created_at: float + last_updated_at: float + user: User diff --git a/src/huesoporro/svc/authenticate.py b/src/huesoporro/svc/authenticate.py deleted file mode 100644 index 346a407..0000000 --- a/src/huesoporro/svc/authenticate.py +++ /dev/null @@ -1,26 +0,0 @@ -import datetime - -from pydantic import BaseModel - -from src.huesoporro.infra.authenticator import TwitchAuthenticator -from src.huesoporro.infra.db import Database -from src.huesoporro.models import User - - -class CodeAuthenticatorSvc(BaseModel): - db: Database - authenticator: TwitchAuthenticator - - @staticmethod - def get_four_hours_from_now() -> float: - now = datetime.datetime.now(datetime.UTC) - four_hours_later = now + datetime.timedelta(hours=4) - return four_hours_later.timestamp() - - async def run(self, code: str) -> User: - auth = await self.authenticator.get_token(code) - username = await self.authenticator.validate_token(auth.access_token) - expires_at = self.get_four_hours_from_now() - user = User(user=username, expires_at=expires_at, twitch_auth=auth) - await self.db.save_user(user) - return user diff --git a/src/huesoporro/svc/get_sentences_svc.py b/src/huesoporro/svc/get_sentences_svc.py new file mode 100644 index 0000000..1cb0a1c --- /dev/null +++ b/src/huesoporro/svc/get_sentences_svc.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + +from src.huesoporro.infra.db import Database +from src.huesoporro.models import Sentence, User + + +class SentencesGetterSvc(BaseModel): + db: Database + + async def run(self, user: User) -> list[Sentence]: + return await self.db.get_sentences(user=user) diff --git a/src/huesoporro/svc/refresh.py b/src/huesoporro/svc/refresh.py deleted file mode 100644 index 9209743..0000000 --- a/src/huesoporro/svc/refresh.py +++ /dev/null @@ -1,27 +0,0 @@ -import datetime - -from pydantic import BaseModel - -from src.huesoporro.infra.authenticator import TwitchAuthenticator -from src.huesoporro.infra.db import Database -from src.huesoporro.models import User - - -class RefreshTokenAuthenticator(BaseModel): - db: Database - authenticator: TwitchAuthenticator - - @staticmethod - def get_four_hours_from_now() -> float: - now = datetime.datetime.now(datetime.UTC) - four_hours_later = now + datetime.timedelta(hours=4) - return four_hours_later.timestamp() - - async def run(self, refresh_token: str) -> User: - auth = await self.authenticator.refresh_token(refresh_token) - username = await self.authenticator.validate_token(auth.access_token) - expires_at = self.get_four_hours_from_now() - - user = User(user=username, expires_at=expires_at, twitch_auth=auth) - await self.db.save_user(user) - return user diff --git a/src/huesoporro/templates/header.html b/src/huesoporro/templates/header.html index 7c1cfd9..e6fe60a 100644 --- a/src/huesoporro/templates/header.html +++ b/src/huesoporro/templates/header.html @@ -2,15 +2,24 @@ - + + + + + Huesoporro diff --git a/src/huesoporro/templates/index.html b/src/huesoporro/templates/index.html index b17419f..e7ffa8c 100644 --- a/src/huesoporro/templates/index.html +++ b/src/huesoporro/templates/index.html @@ -2,16 +2,16 @@
-
-
+
@@ -102,7 +102,7 @@ .catch((error) => { console.error('Failed to retrieve chatbot status', error); }); - }, 2000); + }, 5000); } async startBot() { @@ -184,6 +184,8 @@ const chatbotManager = new ChatbotManager(); chatbotManager.setEvents(); + + }); diff --git a/src/huesoporro/templates/le_funny_dropdown.html b/src/huesoporro/templates/le_funny_dropdown.html new file mode 100644 index 0000000..a7f63a6 --- /dev/null +++ b/src/huesoporro/templates/le_funny_dropdown.html @@ -0,0 +1,10 @@ +
  • + +
  • diff --git a/src/huesoporro/templates/login.html b/src/huesoporro/templates/login.html index 3a3181a..b40ac0c 100644 --- a/src/huesoporro/templates/login.html +++ b/src/huesoporro/templates/login.html @@ -1,9 +1,9 @@ {% include 'header.html' %} -
    +

    Huesoporro🦴🚬

    -
    +
    -
  • Logout
  • +
  • Logout
  • -
    +
    - + + + + + +
    - - + + + + {% for sentence in sentences %} - + + {% endfor %} +
    SentenceActionSentenceLast modifiedAction
    {{ sentence.sentence }}{{ sentence.last_updated_at }} - +
    + + +
    diff --git a/tests/conftest.py b/tests/conftest.py index 19218d9..c8c6c23 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from caribou.migrate import Database as CaribouDatabase from caribou.migrate import load_migrations from src.huesoporro.infra.db import Database -from src.huesoporro.models import ChatbotSettings, TwitchAuth, User +from src.huesoporro.models import ChatbotSettings, User from src.huesoporro.settings import Settings from src.huesoporro.svc.backoff_service import BackoffService from src.huesoporro.svc.is_mod import IsModSvc @@ -15,11 +15,10 @@ from src.huesoporro.svc.is_mod import IsModSvc def user() -> User: return User( user="huesoporro", - expires_at=1671234567.0, - twitch_auth=TwitchAuth( - access_token="test_access_token", # noqa: S106 - refresh_token="test_refresh_token", # noqa: S106 - ), + external_auth={ + "twitch": {"token": "twitch_token"}, + "discord": {"token": "discord_token"}, + }, ) diff --git a/tests/test_repos.py b/tests/test_repos.py new file mode 100644 index 0000000..423e195 --- /dev/null +++ b/tests/test_repos.py @@ -0,0 +1,53 @@ +import json + +import pytest + +from src.huesoporro.infra.repos import UserRepo +from src.huesoporro.models import User + + +@pytest.fixture +async def user_repo(s, db, user: User): + async with db.get_client() as client: + await client.execute( + "INSERT INTO users (user, external_auth) VALUES (?, ?)", + (user.user, json.dumps(user.external_auth)), + ) + + return UserRepo(s=s) + + +async def test_get_user(user_repo: UserRepo, user: User): + db_user = await user_repo.get_by_user(user.user) + assert db_user == user + + +async def test_get_user_returns_none(user_repo: UserRepo): + assert await user_repo.get_by_user("unknown_user") is None + + +async def test_create_user(user_repo: UserRepo): + new_user = User( + user="new_user", external_auth={"twitch": {"token": "twitch_token"}} + ) + assert await user_repo.create(new_user) == new_user + + +async def test_update_users_tokens(user_repo: UserRepo, user: User): + new_tokens = {"twitch": {"token": "new_tokens"}} + user.external_auth = new_tokens # type: ignore[assignment] + assert await user_repo.update(user) == user + + +async def test_update_non_existing_user_raises_value_error(user_repo: UserRepo): + with pytest.raises(ValueError, match="User unknown_user does not exist"): + await user_repo.update( + User( + user="unknown_user", external_auth={"twitch": {"token": "twitch_token"}} + ) + ) + + +async def test_delete_user(user_repo: UserRepo, user: User): + assert await user_repo.delete(user) is None + assert await user_repo.get_by_user(user.user) is None diff --git a/uv.lock b/uv.lock index d861103..ccca98c 100644 --- a/uv.lock +++ b/uv.lock @@ -283,6 +283,18 @@ toml = [ { name = "tomli", marker = "python_full_version <= '3.11'" }, ] +[[package]] +name = "discord-py" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/af/80cab4015722d3bee175509b7249a11d5adf77b5ff4c27f268558079d149/discord_py-2.4.0.tar.gz", hash = "sha256:d07cb2a223a185873a1d0ee78b9faa9597e45b3f6186df21a95cec1e9bcdc9a5", size = 1027707 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/10/3c44e9331a5ec3bae8b2919d51f611a5b94e179563b1b89eb6423a8f43eb/discord.py-2.4.0-py3-none-any.whl", hash = "sha256:b8af6711c70f7e62160bfbecb55be699b5cb69d007426759ab8ab06b1bd77d1d", size = 1125988 }, +] + [[package]] name = "editorconfig" version = "0.12.4" @@ -465,6 +477,7 @@ source = { virtual = "." } dependencies = [ { name = "aiosqlite" }, { name = "caribou" }, + { name = "discord-py" }, { name = "gtts" }, { name = "httpx" }, { name = "litestar", extra = ["standard"] }, @@ -474,6 +487,7 @@ dependencies = [ { name = "pydantic" }, { name = "pydantic-settings" }, { name = "pyjwt" }, + { name = "pytz" }, { name = "redis" }, { name = "twitchio" }, ] @@ -491,6 +505,7 @@ dev = [ requires-dist = [ { name = "aiosqlite", specifier = ">=0.20.0" }, { name = "caribou", specifier = ">=0.4.1" }, + { name = "discord-py", specifier = ">=2.4.0" }, { name = "gtts", specifier = ">=2.5.4" }, { name = "httpx", specifier = ">=0.28.0" }, { name = "litestar", extras = ["standard"], specifier = ">=2.13.0" }, @@ -500,6 +515,7 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.9.2" }, { name = "pydantic-settings", specifier = ">=2.6.0" }, { name = "pyjwt", specifier = ">=2.10.1" }, + { name = "pytz", specifier = ">=2024.2" }, { name = "redis", specifier = ">=5.2.1" }, { name = "twitchio", specifier = ">=2.10.0" }, ] @@ -1101,6 +1117,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 }, ] +[[package]] +name = "pytz" +version = "2024.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3a/31/3c70bf7603cc2dca0f19bdc53b4537a797747a58875b552c8c413d963a3f/pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a", size = 319692 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/c3/005fcca25ce078d2cc29fd559379817424e94885510568bc1bc53d7d5846/pytz-2024.2-py2.py3-none-any.whl", hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725", size = 508002 }, +] + [[package]] name = "pyyaml" version = "6.0.2"