From 4c534de47b06bf46a18b098d99e94b5296dd5fe9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C4=83t=C4=83lin?= Date: Tue, 17 Dec 2024 17:55:02 +0100 Subject: [PATCH] feat: add migrations, api bot endpoints and revamp the whole twitch backend by making use of twitchio --- .gitignore | 15 +- Dockerfile | 7 +- Makefile | 3 + charts/huesoporro/Chart.yaml | 4 +- charts/huesoporro/templates/deployment.yaml | 11 + charts/huesoporro/values.yaml | 2 +- migrations/20241213175820_auth.py | 29 + migrations/20241216204252_quotes.py | 38 ++ migrations/20241217000747_settings.py | 28 + pyproject.toml | 9 +- src/huesoporro/actions/__init__.py | 0 src/huesoporro/actions/store_quote.py | 18 + src/huesoporro/api/__init__.py | 0 src/huesoporro/api/dependencies.py | 48 ++ src/huesoporro/api/errors.py | 45 ++ src/huesoporro/api/main.py | 85 +++ src/huesoporro/api/routes/__init__.py | 0 src/huesoporro/api/routes/api.py | 86 +++ src/huesoporro/api/routes/auth.py | 31 + src/huesoporro/bot.py | 150 +++++ src/huesoporro/chatbot.py | 62 -- src/huesoporro/infra/__init__.py | 0 src/huesoporro/infra/authenticator.py | 62 ++ src/huesoporro/infra/db.py | 134 +++++ src/huesoporro/libs/markov_chain_bot.py | 543 ------------------ src/huesoporro/libs/settings.py | 118 ---- src/huesoporro/main.py | 243 +------- src/huesoporro/models.py | 50 ++ src/huesoporro/settings.py | 4 +- src/huesoporro/svc/__init__.py | 0 src/huesoporro/svc/authenticate.py | 26 + src/huesoporro/svc/generate.py | 158 +++++ src/huesoporro/svc/get_chatbot_settings.py | 11 + src/huesoporro/svc/get_random_quote.py | 10 + src/huesoporro/svc/hello.py | 10 + src/huesoporro/svc/is_mod.py | 12 + src/huesoporro/svc/refresh.py | 27 + src/huesoporro/svc/store.py | 63 ++ src/huesoporro/svc/store_quote.py | 10 + src/huesoporro/svc/store_settings.py | 15 + src/huesoporro/templates/header.html | 1 + src/huesoporro/templates/index.html | 238 ++++---- src/huesoporro/templates/login.html | 29 +- .../{lefunny.html => sentences.html} | 0 uv.lock | 394 ++++++++++++- 45 files changed, 1719 insertions(+), 1110 deletions(-) create mode 100644 migrations/20241213175820_auth.py create mode 100644 migrations/20241216204252_quotes.py create mode 100644 migrations/20241217000747_settings.py create mode 100644 src/huesoporro/actions/__init__.py create mode 100644 src/huesoporro/actions/store_quote.py create mode 100644 src/huesoporro/api/__init__.py create mode 100644 src/huesoporro/api/dependencies.py create mode 100644 src/huesoporro/api/errors.py create mode 100644 src/huesoporro/api/main.py create mode 100644 src/huesoporro/api/routes/__init__.py create mode 100644 src/huesoporro/api/routes/api.py create mode 100644 src/huesoporro/api/routes/auth.py create mode 100644 src/huesoporro/bot.py delete mode 100644 src/huesoporro/chatbot.py create mode 100644 src/huesoporro/infra/__init__.py create mode 100644 src/huesoporro/infra/authenticator.py create mode 100644 src/huesoporro/infra/db.py delete mode 100644 src/huesoporro/libs/markov_chain_bot.py delete mode 100644 src/huesoporro/libs/settings.py create mode 100644 src/huesoporro/models.py create mode 100644 src/huesoporro/svc/__init__.py create mode 100644 src/huesoporro/svc/authenticate.py create mode 100644 src/huesoporro/svc/generate.py create mode 100644 src/huesoporro/svc/get_chatbot_settings.py create mode 100644 src/huesoporro/svc/get_random_quote.py create mode 100644 src/huesoporro/svc/hello.py create mode 100644 src/huesoporro/svc/is_mod.py create mode 100644 src/huesoporro/svc/refresh.py create mode 100644 src/huesoporro/svc/store.py create mode 100644 src/huesoporro/svc/store_quote.py create mode 100644 src/huesoporro/svc/store_settings.py rename src/huesoporro/templates/{lefunny.html => sentences.html} (100%) diff --git a/.gitignore b/.gitignore index bc9dc20..8ad2561 100644 --- a/.gitignore +++ b/.gitignore @@ -114,18 +114,5 @@ src/huesoporro/tts_files/ # Devenv .devenv* devenv.local.nix - # direnv -.direnv - -# pre-commit -.pre-commit-config.yaml -# Devenv -.devenv* -devenv.local.nix - -# direnv -.direnv - -# pre-commit -.pre-commit-config.yaml +.direnv \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 88743df..226e72a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -32,8 +32,13 @@ COPY --chown=$USERNAME pyproject.toml uv.lock Makefile README.md ./ RUN uv sync COPY --chown=$USERNAME src/ src/ +COPY --chown=$USERNAME migrations/ migrations/ FROM base AS serve -CMD ["make", "serve"] \ No newline at end of file +CMD ["make", "serve"] + +FROM base AS migrate + +CMD ["make", "migrate"] \ No newline at end of file diff --git a/Makefile b/Makefile index 8f3b092..9b87ccf 100644 --- a/Makefile +++ b/Makefile @@ -16,3 +16,6 @@ serve: build: docker build . -t git.roboces.dev/catalin/$(PROJECT_NAME):$(PROJECT_TAG) --target $(PROJECT_TARGET) + +migrate: + uv run caribou upgrade ~/.local/share/huesoporro/huesoporro.db migrations/ \ No newline at end of file diff --git a/charts/huesoporro/Chart.yaml b/charts/huesoporro/Chart.yaml index be2cfe9..ac75742 100644 --- a/charts/huesoporro/Chart.yaml +++ b/charts/huesoporro/Chart.yaml @@ -15,10 +15,10 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.2.1 +version: 0.2.2 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to # follow Semantic Versioning. They should reflect the version the application is using. # It is recommended to use it with quotes. -appVersion: "0.2.1" +appVersion: "0.2.2" diff --git a/charts/huesoporro/templates/deployment.yaml b/charts/huesoporro/templates/deployment.yaml index 0a9fb22..180e12d 100644 --- a/charts/huesoporro/templates/deployment.yaml +++ b/charts/huesoporro/templates/deployment.yaml @@ -40,6 +40,17 @@ spec: mountPath: /data securityContext: runAsUser: 0 + - name: migrate + image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + command: + - make + - migrate + {{- if .Values.persistence.enabled }} + volumeMounts: + - name: data + mountPath: /home/huesoporro/.local/share/huesoporro + {{- end }} {{- end }} containers: - name: {{ .Chart.Name }} diff --git a/charts/huesoporro/values.yaml b/charts/huesoporro/values.yaml index 60a71f7..ffa36f2 100644 --- a/charts/huesoporro/values.yaml +++ b/charts/huesoporro/values.yaml @@ -11,7 +11,7 @@ image: # This sets the pull policy for images. pullPolicy: Always # Overrides the image tag whose default is the chart appVersion. - tag: "0.2.1" + tag: "0.2.2" # This is for the secretes for pulling an image from a private repository more information can be found here: https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/ imagePullSecrets: [] diff --git a/migrations/20241213175820_auth.py b/migrations/20241213175820_auth.py new file mode 100644 index 0000000..d7a6f87 --- /dev/null +++ b/migrations/20241213175820_auth.py @@ -0,0 +1,29 @@ +""" +This module contains a Caribou migration. + +Migration Name: auth +Migration Version: 20241213175820 +""" + + +def upgrade(connection): + # add your upgrade step here + sql = """ + create table users + ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + user varchar(255) NOT NULL UNIQUE, + access_token varchar(255) NOT NULL, + refresh_token varchar(255) NOT NULL, + expires_at TIMESTAMP NOT NULL, + last_updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + """ + connection.execute(sql) + connection.commit() + + +def downgrade(connection): + # add your downgrade step here + pass diff --git a/migrations/20241216204252_quotes.py b/migrations/20241216204252_quotes.py new file mode 100644 index 0000000..411cd4b --- /dev/null +++ b/migrations/20241216204252_quotes.py @@ -0,0 +1,38 @@ +""" +This module contains a Caribou migration. + +Migration Name: quotes +Migration Version: 20241216204252 +""" + + +def upgrade(connection): + # add your upgrade step here + sql = """ + create table quotes + ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + quote varchar(255) NOT NULL UNIQUE, + author varchar(255), + channel varchar(255), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + """ + connection.execute(sql) + sql = """ + create table sentences + ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + sentence varchar(255) NOT NULL UNIQUE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + """ + connection.execute(sql) + connection.commit() + + +def downgrade(connection): + # add your downgrade step here + pass diff --git a/migrations/20241217000747_settings.py b/migrations/20241217000747_settings.py new file mode 100644 index 0000000..23568ae --- /dev/null +++ b/migrations/20241217000747_settings.py @@ -0,0 +1,28 @@ +""" +This module contains a Caribou migration. + +Migration Name: settings +Migration Version: 20241217000747 +""" + + +def upgrade(connection): + sql = """ + create table settings( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + user_id VARCHAR(255) NOT NULL UNIQUE, + automatic_generation_timer INTENGER NOT NULL DEFAULT 300, + automatic_quote_timer INTEGER NOT NULL DEFAULT 500, + mods VARCHAR(255), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(user) + ); + """ + connection.execute(sql) + connection.commit() + + +def downgrade(connection): + # add your downgrade step here + pass diff --git a/pyproject.toml b/pyproject.toml index 032a6dc..543c278 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "huesoporro" -version = "0.2.1" +version = "0.2.2" description = "Misc Twitch bots" readme = "README.md" authors = [ @@ -20,6 +20,13 @@ dependencies = [ "gtts>=2.5.4", "litestar[standard]>=2.13.0", "httpx>=0.28.0", + "caribou>=0.4.1", + "aiosqlite>=0.20.0", + "pyjwt>=2.10.1", + "huey>=2.5.2", + "twitchio>=2.10.0", + "redis>=5.2.1", + "pytest>=8.3.4", ] [tool.uv] diff --git a/src/huesoporro/actions/__init__.py b/src/huesoporro/actions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/huesoporro/actions/store_quote.py b/src/huesoporro/actions/store_quote.py new file mode 100644 index 0000000..719c776 --- /dev/null +++ b/src/huesoporro/actions/store_quote.py @@ -0,0 +1,18 @@ +from pydantic import BaseModel + +from src.huesoporro.models import User +from src.huesoporro.svc.is_mod import IsModSvc +from src.huesoporro.svc.store_quote import QuoteStorerSvc + + +class StoreQuoteAction(BaseModel): + quote_storer_svc: QuoteStorerSvc + is_mod_svc: IsModSvc + + async def run( + self, user: User, channel: str, quote: str, author: str, username: str + ) -> str: + if not await self.is_mod_svc.run(user=user, username=username): + return f"{username} is not a mod and cannot add quotes. Only moderators can add quotes. Sorry!" + await self.quote_storer_svc.run(channel, quote, author) + return f"«{quote}» added by {author}." diff --git a/src/huesoporro/api/__init__.py b/src/huesoporro/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/huesoporro/api/dependencies.py b/src/huesoporro/api/dependencies.py new file mode 100644 index 0000000..5db8371 --- /dev/null +++ b/src/huesoporro/api/dependencies.py @@ -0,0 +1,48 @@ +from litestar import Request +from litestar.exceptions import HTTPException + +from src.huesoporro.infra.authenticator import TwitchAuthenticator +from src.huesoporro.infra.db import Database +from src.huesoporro.models import User +from src.huesoporro.settings import Settings +from src.huesoporro.svc.authenticate import CodeAuthenticatorSvc +from src.huesoporro.svc.get_chatbot_settings import ChatbotSettingsGetterSvc +from src.huesoporro.svc.store_settings import ChatbotSettingsStorerSvc + + +def get_settings() -> Settings: + return Settings.get() + + +def get_authenticator(s: Settings) -> TwitchAuthenticator: + return TwitchAuthenticator(s=s) + + +def get_db(s: Settings): + return Database(s=s) + + +async def authenticate(request: Request) -> User: + token = request.query_params.get("huesoporro_token") + if token: + return User.decode(token) + + cookies = request.cookies.get("huesoporroAuth") + if cookies: + return User.decode(cookies) + + raise HTTPException(status_code=401, detail="Unauthorized") + + +async def get_code_authenticator_svc( + a: TwitchAuthenticator, db: Database +) -> CodeAuthenticatorSvc: + return CodeAuthenticatorSvc(authenticator=a, db=db) + + +async def get_chatbot_settings_svc(db: Database): + return ChatbotSettingsGetterSvc(db=db) + + +async def store_chatbot_settings_svc(db: Database): + return ChatbotSettingsStorerSvc(db=db) diff --git a/src/huesoporro/api/errors.py b/src/huesoporro/api/errors.py new file mode 100644 index 0000000..f530ae3 --- /dev/null +++ b/src/huesoporro/api/errors.py @@ -0,0 +1,45 @@ +import httpx +from litestar import MediaType, Request, Response +from litestar.exceptions import HTTPException +from litestar.response import Redirect +from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR +from loguru import logger + + +def http_exception_handler(_: Request, exc: HTTPException) -> Response: + status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR) + detail = getattr(exc, "detail", "") + + if isinstance(exc, HTTPException) and (exc.status_code in [401, 403]): + logger.warning("User could not authenticate. Redirecting to /login page") + return Redirect("/login") + + return Response( + media_type=MediaType.TEXT, + content=detail, + status_code=status_code, + ) + + +def httpx_status_error_handler(_: Request, exc: httpx.HTTPStatusError): + logger.error(f"HTTPX error occurred: {exc}") + return Response( + media_type=MediaType.TEXT, + content=f"HTTPX error occurred: {exc}", + status_code=exc.response.status_code, + ) + + +async def after_exception_handler(exc: Exception, scope: "Scope") -> None: + """Hook function that will be invoked after each exception.""" + state = scope["app"].state + if not hasattr(state, "error_count"): + state.error_count = 1 + else: + state.error_count += 1 + logger.error( + f"an exception of type {type(exc).__name__} has occurred for requested path {scope['path']} and the application error count is {state.error_count}.", + ) + import traceback + + traceback.print_exc() diff --git a/src/huesoporro/api/main.py b/src/huesoporro/api/main.py new file mode 100644 index 0000000..07dba33 --- /dev/null +++ b/src/huesoporro/api/main.py @@ -0,0 +1,85 @@ +import httpx +from litestar import Litestar, get +from litestar.contrib.jinja import JinjaTemplateEngine +from litestar.di import Provide +from litestar.exceptions import HTTPException +from litestar.static_files import StaticFilesConfig +from litestar.template import TemplateConfig + +from src.huesoporro.api.dependencies import ( + authenticate, + get_authenticator, + get_chatbot_settings_svc, + get_code_authenticator_svc, + get_db, + get_settings, + store_chatbot_settings_svc, +) +from src.huesoporro.api.errors import ( + after_exception_handler, + http_exception_handler, + httpx_status_error_handler, +) +from src.huesoporro.api.routes.api import ( + get_bot_settings, + get_bot_status, + get_index, + get_tts_overlay, + get_tts_permalink, + manage_bot, + save_bot_settings, +) +from src.huesoporro.api.routes.auth import get_code, login +from src.huesoporro.bot import BotsManager +from src.huesoporro.settings import Settings + + +@get("/healthz") +def get_health() -> dict: + return {"status": "ok"} + + +def create_app(): + return Litestar( + route_handlers=[ + get_health, + login, + get_index, + get_tts_overlay, + get_tts_permalink, + get_code, + manage_bot, + get_bot_status, + save_bot_settings, + get_bot_settings, + ], + static_files_config=( + StaticFilesConfig( + path="/tts_files", + directories=[Settings.get().tts_cache_path], + ), + StaticFilesConfig( + path="static", + directories=[Settings.get().static_files_path], + ), + ), + template_config=TemplateConfig( + directory=Settings.get().templates_files_path, + engine=JinjaTemplateEngine, + ), + exception_handlers={ + HTTPException: http_exception_handler, + httpx.HTTPStatusError: httpx_status_error_handler, + }, + after_exception=[after_exception_handler], + dependencies={ + "s": Provide(get_settings, use_cache=True), + "a": Provide(get_authenticator, use_cache=True), + "user": Provide(authenticate), + "db": Provide(get_db, use_cache=True), + "code_authenticator_svc": Provide(get_code_authenticator_svc), + "bm": Provide(BotsManager, use_cache=True), + "gbs": Provide(get_chatbot_settings_svc), + "sbs": Provide(store_chatbot_settings_svc), + }, + ) diff --git a/src/huesoporro/api/routes/__init__.py b/src/huesoporro/api/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/huesoporro/api/routes/api.py b/src/huesoporro/api/routes/api.py new file mode 100644 index 0000000..9161f28 --- /dev/null +++ b/src/huesoporro/api/routes/api.py @@ -0,0 +1,86 @@ +from typing import Literal + +from litestar import MediaType, Response, get, put +from litestar.response import Template +from pydantic import BaseModel + +from src.huesoporro.bot import BotsManager +from src.huesoporro.models import ChatbotSettings, User +from src.huesoporro.svc.get_chatbot_settings import ChatbotSettingsGetterSvc +from src.huesoporro.svc.store_settings import ChatbotSettingsStorerSvc + + +class ManageBotDTO(BaseModel): + command: Literal["start", "stop"] + channel_name: str | None = None + + +@get( + "/tts", + media_type=MediaType.HTML, +) +async def get_tts_overlay() -> Template: + return Template(template_name="tts.html") + + +@get( + "/tts/permalink", + media_type=MediaType.HTML, +) +async def get_tts_permalink(access_token: str) -> Template: + """Handler for the /tts permalink endpoint to be used by apps that can only give the authentication as a query + param and not as a cookie, i.e. OBS""" + + # authenticate the user using the provided access token + + return Template( + template_name="tts.html", + ) + + +@get( + "/", + media_type=MediaType.HTML, +) +async def get_index(user: User, gbs: ChatbotSettingsGetterSvc) -> Template: + chatbot_settings = await gbs.run(user=user) + return Template(template_name="index.html", context=chatbot_settings.model_dump() if chatbot_settings else {}) + + +@put("/api/v1/bot") +async def manage_bot( + user: User, data: ManageBotDTO, gbs: ChatbotSettingsGetterSvc, bm: BotsManager +) -> Response: + chatbot_settings = await gbs.run(user=user) + if data.command == "start": + if not data.channel_name: + return Response({"message": "Channel name is required"}, status_code=400) + bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings) + if user.user in bm.bots: + await bm.run_user_bot(user) + return Response({"message": "Bot started"}) + if data.command == "stop" and user.user in bm.bots: + await bm.stop_user_bot(user) + return Response({"message": "Bot stopped"}) + + +@get("/api/v1/bot") +async def get_bot_status(user: User, bm: BotsManager) -> dict: + if user.user not in bm.bots: + return {"status": "ko"} + return {"status": "ok"} + + +@get("/api/v1/bot/settings") +async def get_bot_settings( + user: User, gbs: ChatbotSettingsGetterSvc +) -> ChatbotSettings: + return await gbs.run(user=user) + + +@put("/api/v1/bot/settings") +async def save_bot_settings( + user: User, data: ChatbotSettings, sbs: ChatbotSettingsStorerSvc +) -> dict: + await sbs.run(user=user, bot_settings=data) + return {"status": "ok"} diff --git a/src/huesoporro/api/routes/auth.py b/src/huesoporro/api/routes/auth.py new file mode 100644 index 0000000..2f1c287 --- /dev/null +++ b/src/huesoporro/api/routes/auth.py @@ -0,0 +1,31 @@ +import secrets + +from litestar import MediaType, get +from litestar.response import Redirect, Template + +from src.huesoporro.settings import Settings +from src.huesoporro.svc.authenticate import CodeAuthenticatorSvc + + +@get(path="/o/code") +async def get_code(code: str, code_authenticator_svc: CodeAuthenticatorSvc) -> Redirect: + user = await code_authenticator_svc.run(code) + return Redirect("/", cookies={"huesoporroAuth": user.encode()}) + + +@get( + "/login", + media_type=MediaType.HTML, +) +async def login(s: Settings) -> Template: + scopes = "+".join(s.twitch_scopes) + return Template( + "login.html", + context={ + "twitch_login_url": "https://id.twitch.tv/oauth2/authorize?response_type=code" + f"&client_id={s.twitch_client_id}" + f"&redirect_uri={s.server_hostname}o/code" + f"&scope={scopes}" + f"&state={secrets.token_urlsafe(32)}" + }, + ) diff --git a/src/huesoporro/bot.py b/src/huesoporro/bot.py new file mode 100644 index 0000000..a842586 --- /dev/null +++ b/src/huesoporro/bot.py @@ -0,0 +1,150 @@ +import asyncio + +from loguru import logger +from twitchio import Channel +from twitchio.ext import commands, routines + +from src.huesoporro.actions.store_quote import StoreQuoteAction +from src.huesoporro.infra.db import Database +from src.huesoporro.libs.db import Database as MarkovDB +from src.huesoporro.models import ChatbotSettings, User +from src.huesoporro.svc.generate import SentenceGeneratorSvc +from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc +from src.huesoporro.svc.hello import HelloGeneratorSvc +from src.huesoporro.svc.is_mod import IsModSvc +from src.huesoporro.svc.store import SentenceStorerSvc +from src.huesoporro.svc.store_quote import QuoteStorerSvc + + +class Bot(commands.Bot): + def __init__(self, user: User, chatbot_settings: ChatbotSettings, channel: str): + super().__init__( + token=user.twitch_auth.access_token, prefix="!", initial_channels=[channel] + ) + self.channel = channel + self.user = user + self.generate_svc = SentenceGeneratorSvc(db=MarkovDB(channel=channel)) + self.hello_svc = HelloGeneratorSvc() + db = Database() + self.store_quote_action = StoreQuoteAction( + quote_storer_svc=QuoteStorerSvc(db=db), is_mod_svc=IsModSvc(db=db) + ) + + self.get_random_quote_svc = RandomQuoteGetterSvc(db=db) + + self.quote_routine = routines.routine( + seconds=chatbot_settings.automatic_quote_timer, wait_first=True + )(self.send_quote) + self.generation_routine = routines.routine( + seconds=chatbot_settings.automatic_generation_timer, wait_first=True + )(self.send_generation) + + async def event_ready(self): + logger.info(f"Logged in as {self.nick}") + logger.info(f"User id is {self.user_id}") + + @commands.command() + async def hello(self, ctx: commands.Context, user: User | None = None): + username = user.name if user else ctx.author.name + await ctx.send(self.hello_svc.run(username)) + + @commands.command(aliases=["g"]) + async def generate(self, ctx: commands.Context, *, words: str | None = None): + sentence = await self.generate_svc.run(words) + if sentence: + await ctx.send(sentence) + + @commands.command(aliases=["qadd"]) + async def add_quote(self, ctx: commands.Context, *, quote: str): + # extract author from quote; the author is the last word + quote, author = quote.rsplit(" ", 1) + await ctx.send( + await self.store_quote_action.run( + user=self.user, + channel=self.channel, + quote=quote, + author=author, + username=ctx.author.name, + ) + ) + + @commands.command(aliases=["q", "quote"]) + async def get_random_quote(self, ctx: commands.Context): + quote = await self.get_random_quote_svc.run(channel_name=self.channel) + await ctx.send(f"«{quote[0]}» - {quote[1]}") + + def get_channel_conn(self) -> Channel: + return Channel(name=self.channel, websocket=self._connection) + + async def send_quote(self): + quote = await self.get_random_quote_svc.run(channel_name=self.channel) + channel = self.get_channel_conn() + logger.info(f"Sending random quote {quote[0]}") + await channel.send(f"«{quote[0]}» - {quote[1]}") + + async def send_generation(self): + sentence = await self.generate_svc.run() + if not sentence: + return + channel = self.get_channel_conn() + logger.info(f"Sending generated sentence {sentence}") + await channel.send(sentence) + + def start_routines(self): + logger.info("Starting routines") + self.quote_routine.start(stop_on_error=False) + self.generation_routine.start(stop_on_error=False) + + def stop_routines(self): + logger.info("Stopping routines") + self.quote_routine.cancel() + self.generation_routine.cancel() + + +class SaveMessagesCog(commands.Cog): + def __init__(self, bot): + self.bot = bot + self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel)) + + @commands.Cog.event() + async def event_message(self, message): + # An event inside a cog! + content = message.content + if content.startswith("!"): + return + + if not message.author: + return + + await self.store_svc.run(content) + + +class BotsManager: + def __init__(self): + self.bots: dict[str, Bot] = {} + + def add_bot(self, user: User, channel: str, chatbot_settings: ChatbotSettings): + if user.user in self.bots: + logger.info(f"Bot for {user.user} already exists") + return + logger.info(f"Adding bot for {user.user}") + bot = Bot(user=user, channel=channel, chatbot_settings=chatbot_settings) + bot.add_cog(SaveMessagesCog(bot)) + self.bots[user.user] = bot + + async def run_user_bot(self, user: User): + if user.user not in self.bots: + return + + logger.info(f"Starting bot for {user.user}") + bot = self.bots[user.user] + task = asyncio.create_task(bot.start()) + task.add_done_callback(lambda x: logger.info(f"Bot for {user.user} stopped")) + bot.start_routines() + + async def stop_user_bot(self, user: User): + if user.user not in self.bots: + return + bot = self.bots.pop(user.user) + await bot.close() + bot.stop_routines() diff --git a/src/huesoporro/chatbot.py b/src/huesoporro/chatbot.py deleted file mode 100644 index 30f04f8..0000000 --- a/src/huesoporro/chatbot.py +++ /dev/null @@ -1,62 +0,0 @@ -import asyncio -from asyncio import sleep as asleep -from queue import Queue -from time import sleep - -import nltk -from litestar import WebSocket -from loguru import logger - -from src.huesoporro.libs.markov_chain_bot import MarkovChain -from src.huesoporro.libs.settings import Settings as MarkovChainSettings -from src.huesoporro.value_objects import WebsocketCommands, WebsocketMessage - -nltk.download("punkt_tab") - - -class ChatbotManager: - def __init__(self): - self.bot: MarkovChain | None = None - self.clients: set[WebSocket] = set() - self.log_queue: Queue = Queue() - self.tasks: set = set() - - def start_bot( - self, - channel_name: str, - nickname: str, - authentication: str, - ): - task = asyncio.create_task(self.send_bot_status()) - self.tasks.add(task) - if self.bot: - return - self.bot = MarkovChain( - settings=MarkovChainSettings( - Channel=channel_name, - Nickname=nickname, - Authentication=authentication, - AutomaticGenerationTimer=300, - ), - ) - - self.bot.run_bot() - sleep(2) - - def stop_bot(self): - self.bot.stop_bot() - self.bot = None - - async def send_bot_status(self): - while True: - for client in self.clients: - message = WebsocketMessage( - command=WebsocketCommands.CHATBOT_STATUS, - data={"status": "ok" if self.bot else "ko"}, - ) - await client.send_text(message.model_dump_json()) - logger.info( - f"Sending bot status {message} to {client.client.host}:{client.client.port}" - ) - - await asleep(2) diff --git a/src/huesoporro/infra/__init__.py b/src/huesoporro/infra/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/huesoporro/infra/authenticator.py b/src/huesoporro/infra/authenticator.py new file mode 100644 index 0000000..9bad27c --- /dev/null +++ b/src/huesoporro/infra/authenticator.py @@ -0,0 +1,62 @@ +import httpx +from litestar.exceptions import HTTPException +from pydantic import BaseModel, ConfigDict, Field + +from src.huesoporro.models import TwitchAuth +from src.huesoporro.settings import Settings + + +class TwitchAuthenticator(BaseModel): + s: Settings = Field(default_factory=Settings.get) + client: httpx.AsyncClient = Field( + default_factory=lambda x: httpx.AsyncClient(base_url="https://id.twitch.tv/") + ) + model_config = ConfigDict(arbitrary_types_allowed=True) + + async def get_token(self, code: str, auto_refresh: bool = True) -> TwitchAuth: + response = await self.client.post( + "/oauth2/token", + data={ + "client_id": Settings.get().twitch_client_id, + "client_secret": Settings.get().twitch_client_secret.get_secret_value(), + "grant_type": "authorization_code", + "code": code, + "redirect_uri": f"{Settings.get().server_hostname}o/code", + }, + headers={"Accept": "application/json"}, + ) + + if auto_refresh and response.status_code == 401: + return await self.refresh_token(response.json()["refresh_token"]) + + response.raise_for_status() + return TwitchAuth(**response.json()) + + async def refresh_token(self, refresh_token: str) -> TwitchAuth: + response = await self.client.post( + "/oauth2/token", + data={ + "client_id": Settings.get().twitch_client_id, + "client_secret": Settings.get().twitch_client_secret.get_secret_value(), + "grant_type": "refresh_token", + "refresh_token": refresh_token, + }, + headers={"Accept": "application/json"}, + ) + response.raise_for_status() + return TwitchAuth(**response.json()) + + async def validate_token(self, access_token: str) -> str: + response = await self.client.get( + "/oauth2/validate", headers={"Authorization": f"OAuth {access_token}"} + ) + response.raise_for_status() + user_data = response.json() + + if user_data.get("status"): + raise HTTPException(status_code=401, detail="Unauthorized") + + if (user := user_data["login"]) not in self.s.allowed_users: + raise HTTPException(status_code=403, detail="Forbidden") + + return user diff --git a/src/huesoporro/infra/db.py b/src/huesoporro/infra/db.py new file mode 100644 index 0000000..3bb481c --- /dev/null +++ b/src/huesoporro/infra/db.py @@ -0,0 +1,134 @@ +import datetime +from contextlib import asynccontextmanager + +import aiosqlite +from pydantic import BaseModel, Field + +from src.huesoporro.models import ChatbotSettings, User +from src.huesoporro.settings import Settings +from loguru import logger + + +class Database(BaseModel): + s: Settings = Field(default_factory=Settings.get) + + @asynccontextmanager + async def get_client(self, auto_commit=True): + logger.info(f"Opening database connection: {self.s.db_filepath}") + async with aiosqlite.connect(self.s.db_filepath) as db: + yield db + if auto_commit: + await db.commit() + + @staticmethod + def get_now() -> float: + return datetime.datetime.now(datetime.UTC).timestamp() + + async def save_user(self, user: User, auto_commit=True): + async with self.get_client(auto_commit=auto_commit) as db: + async with db.execute( + "SELECT * FROM users WHERE user = ?", (user.user,) + ) as cursor: + result = await cursor.fetchone() + if result: + await db.execute( + "UPDATE users SET access_token = ?, refresh_token = ?, expires_at = ?, last_updated_at = ? WHERE user = ?", + ( + user.twitch_auth.access_token, + user.twitch_auth.refresh_token, + user.expires_at, + self.get_now(), + user.user, + ), + ) + return + + await db.execute( + "INSERT INTO users (user, access_token, refresh_token, expires_at, last_updated_at) VALUES (?,?,?,?,?)", + ( + user.user, + user.twitch_auth.access_token, + user.twitch_auth.refresh_token, + user.expires_at, + self.get_now(), + ), + ) + + async def save_quote(self, channel: str, quote: str, author: str, auto_commit=True): + async with self.get_client(auto_commit=auto_commit) as db: + await db.execute( + "INSERT INTO quotes (channel, quote, author) VALUES (?,?,?)", + (channel, quote, author), + ) + + async def save_chatbot_settings( + self, user: User, chatbot_settings: ChatbotSettings, auto_commit: bool = True + ): + async with self.get_client(auto_commit=auto_commit) as db: + current_settings = await self.get_chatbot_settings(user) + if current_settings: + await db.execute( + """UPDATE settings SET + automatic_generation_timer = ?, + automatic_quote_timer = ?, + mods = ?, + last_updated_at = ? + WHERE user_id = ? + """, + ( + chatbot_settings.automatic_generation_timer, + chatbot_settings.automatic_quote_timer, + chatbot_settings.mods_as_string, + self.get_now(), + user.user, + ), + ) + return + + await db.execute( + """INSERT INTO settings ( + user_id, + automatic_generation_timer, + automatic_quote_timer, + mods, + created_at, + last_updated_at + ) VALUES(?,?,?,?,?,?) + """, + ( + user.user, + chatbot_settings.automatic_generation_timer, + chatbot_settings.automatic_quote_timer, + chatbot_settings.mods_as_string, + self.get_now(), + self.get_now(), + ), + ) + + async def get_chatbot_settings(self, user: User) -> ChatbotSettings | None: + async with self.get_client() as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT * FROM settings WHERE user_id = ?", (user.user,) + ) as cursor: + result = await cursor.fetchone() + if not result: + return None + return ChatbotSettings(**dict(result)) + + async def save_sentence(self, sentence: str, auto_commit=True): + async with self.get_client(auto_commit=auto_commit) as db: + await db.execute( + "INSERT INTO sentences (sentence) VALUES (?)", + (sentence,), + ) + await db.commit() + + async def get_random_quote(self, channel_name: str): + async with self.get_client() as db: + async with db.execute( + "SELECT quote, author FROM quotes WHERE channel = ? ORDER BY RANDOM() LIMIT 1", + (channel_name,), + ) as cursor: + result = await cursor.fetchone() + return result diff --git a/src/huesoporro/libs/markov_chain_bot.py b/src/huesoporro/libs/markov_chain_bot.py deleted file mode 100644 index 68b67a8..0000000 --- a/src/huesoporro/libs/markov_chain_bot.py +++ /dev/null @@ -1,543 +0,0 @@ -import string -import time -from enum import StrEnum - -from loguru import logger -from nltk.tokenize import sent_tokenize -from TwitchWebsocket import Message, TwitchWebsocket - -from src.huesoporro.libs.db import Database -from src.huesoporro.libs.settings import Settings -from src.huesoporro.libs.timer import LoopingTimer -from src.huesoporro.libs.tokenizer import detokenize, tokenize - - -class Commands(StrEnum): - SET_COOLDOWN = "!setcd" - GENERATE = "!g" - BLACKLIST = "!blacklist" - GENERATE_HELP = "!ghelp" - QUOTE = "!q" - QUOTE_ADD = "!qadd" - - -class MarkovChain: - end_tag = "" - - def __init__(self, settings: Settings | None = None): - self.s = settings or Settings.read() - self.prev_message_t = 0.0 - self._enabled = True - - self.db = Database(self.s.channel_name) - - if self.s.help_message_timer > 0: - if self.s.help_message_timer < 300: # noqa: PLR2004 - raise ValueError( - 'Value for "HelpMessageTimer" in must be at least 300 seconds, ' # noqa: EM101 - "or a negative number for no help messages.", - ) - t = LoopingTimer(self.s.help_message_timer, self._command_help) - t.start() - - # Set up daemon Timer to send automatic generation messages - if self.s.automatic_generation_timer > 0: - if self.s.automatic_generation_timer < 30: # noqa: PLR2004 - raise ValueError( - 'Value for "Automatic_generation_message" must be at least 30 seconds, or a negative number for no ' # noqa: EM101 - "automatic generations.", - ) - logger.info( - f"Automatic generation enabled, will send messages every {self.s.automatic_generation_timer} seconds" - ) - t = LoopingTimer( - self.s.automatic_generation_timer, - self._command_automatic_generation, - ) - t.start() - - self.ws = TwitchWebsocket( - host=self.s.host, - port=self.s.port, - chan=self.s.channel_name, - nick=self.s.nickname, - auth=self.s.authentication, - callback=self.message_handler, - capability=["commands", "tags"], - live=True, - ) - - def run_bot(self): - self.ws.start_nonblocking() - - def stop_bot(self): - self.ws.leave_channel(self.s.channel_name) - self.ws.stop() - logger.info("Stopped bot") - - def _command_help(self) -> None: - """Send a Help message to the connected chat, as long as the bot wasn't disabled.""" - if self._enabled: - logger.info("Help message sent.") - try: - self.ws.send_message( - "Learn how this bot generates sentences here: https://github.com/CubieDev/TwitchMarkovChain#how-it-works", - ) - except OSError as error: - logger.warning( - f"[OSError: {error}] upon sending help message. Ignoring.", - ) - - def _command_set_cooldown(self, username: str, split_message: list[str]): - if len(split_message) == 2: # noqa: PLR2004 - try: - cooldown = int(split_message[1]) - except ValueError: - self.ws.send_whisper( - username, - "The parameter must be an integer amount, eg: !setcd 30", - ) - return - self.s.cooldown = cooldown - self.s.write() - self.ws.send_whisper( - username, - f"The !generate cooldown has been set to {cooldown} seconds.", - ) - - def _command_blacklist(self, username: str, split_message: list[str]): - if len(split_message) == 2: # noqa: PLR2004 - try: - blacklisted_username = split_message[1] - except ValueError: - self.ws.send_whisper( - username, - "The parameter must be a username, eg: !blacklist ibai", - ) - return - self.s.denied_users.append(blacklisted_username) - self.s.write() - - def _command_generate(self, username: str, message: str): - cur_time = time.time() - if self.prev_message_t + self.s.cooldown >= cur_time: - if not self.db.check_whisper_ignore(username): - self.send_whisper( - username, - f"Cooldown hit: {self.prev_message_t + self.s.cooldown - cur_time:0.2f} out of {self.s.cooldown:.0f}s remaining. !nopm to stop these cooldown pm's.", - ) - logger.info( - f"Cooldown hit with {self.prev_message_t + self.s.cooldown - cur_time:0.2f}s remaining.", - ) - params = tokenize(message)[2:] if self.s.allow_generate_params else None - # Generate an actual sentence - sentence, success = self.generate(params) - if success: - # Reset cooldown if a message was actually generated - self.prev_message_t = time.time() - logger.info(sentence) - self.ws.send_message(sentence) - - self.store_sentence(message) - - def _command_automatic_generation(self) -> None: - """Send an automatic generation message to the connected chat. - - As long as the bot wasn't disabled, just like if someone typed "!g" in chat. - """ - if self._enabled: - logger.debug("Automatically generating message") - sentence, success = self.generate() - if success: - logger.info( - f"Created '{sentence}'. Cooling down for {self.s.automatic_generation_timer} seconds before regenerating", - ) - try: - self.ws.send_message(sentence) - except OSError as error: - logger.warning( - f"[OSError: {error}] upon sending automatic generation message. Ignoring.", - ) - else: - logger.info( - "Attempted to output automatic generation message, but there is not enough learned information yet.", - ) - - def _command_quote(self): - """Retrieve a random quote from the `quotes` table and format it as - - > «» - - """ - data = self.db.execute( - "SELECT quote, author FROM quotes ORDER BY RANDOM() LIMIT 1;", fetch=True - ) - if data: - data = data[0] - quote, author = data[0], data[1] - self.ws.send_message(f"«{quote}» - {author}") - - def _command_add_quote(self, message: str): - """Add a quote to the quotes table. The message should follow the format: - - !qadd quote author - - The last word will be parsed as the author and anything in between !qadd and the author will be considered - as the quote itself - """ - # Split the message into quote and author - parts = message.split() - author = parts[-1] - quote = " ".join(parts[1:-1]) - - data = self.db.execute( - "SELECT 1 FROM quotes WHERE quote = ?", (quote,), fetch=True - ) - if data: - self.ws.send_message(f"Quote «{quote}» was already added.") - return - - self.db.execute( - "INSERT INTO quotes (quote, author) VALUES (?, ?)", - (quote, author), # type: ignore[arg-type] - ) - self.ws.send_message(f"Quote «{quote}» by {author} added.") - - def store_sentence(self, message: str): - logger.info(f"Processing {message} in order to store it") - stripped_message = message.strip() - try: - sentences = sent_tokenize(stripped_message) - except LookupError: - logger.debug("Downloading required punkt resource...") - import nltk - - nltk.download("punkt") - logger.debug("Downloaded required punkt resource.") - sentences = sent_tokenize(stripped_message) - - for sentence in sentences: - words = tokenize(sentence) - # Double spaces will lead to invalid rules. We remove empty words here - if "" in words: - words = [word for word in words if word] - - # If the sentence is too short, ignore it and move on to the next. - if len(words) <= self.s.key_length: - continue - - # Add a new starting point for a sentence to the - words = [words[x] for x in range(self.s.key_length)] - logger.debug(f"Adding {words} to start queue") - self.db.add_start_queue(words) - - # Create Key variable which will be used as a key in the Dictionary for the grammar - key: list[str] = [] - for word in words: - # Set up key for first use - if len(key) < self.s.key_length: - key.append(word) - continue - logger.debug(f"Adding {key}[{word}] to rule queue") - self.db.add_rule_queue([*key, word]) - - # Remove the first word, and add the current word, - # so that the key is correct for the next word. - key.pop(0) - key.append(word) - logger.debug(f"Adding {key} to rule queue") - # Add at the end of the sentence - self.db.add_rule_queue([*key, self.end_tag]) - - def message_handler(self, message: Message): # noqa: C901, PLR0911, PLR0912 - try: - """ - tts_message = { - "badge-info": "subscriber/4", - "badges": "vip/1,subscriber/3,sub-gifter/5", - "color": "#F79AC6", - "custom-reward-id": "8c454446-73b0-480f-946e-d6b5f5c5e331", - "display-name": "robosap1ens__", - "emotes": "", - "first-msg": "0", - "flags": "", - "id": "6cbd37eb-49ae-41f5-b073-345275c91a07", - "mod": "0", - "returning-chatter": "0", - "room-id": "600944302", - "subscriber": "1", - "tmi-sent-ts": "1733252657689", - "turbo": "0", - "user-id": "713968248", - "user-type": "", - "vip": "1", - } - """ - if not message.user or message.user in self.s.denied_users: - logger.debug(f"User {message.user} can't send messages") - return - - msgs = message.message.split() - if not msgs: - logger.debug("Message is empty") - return - - if "bits" in message.tags: - return - - if "emotes" in message.tags: - # Replace modified emotes with normal versions, - # as the bot will never have the modified emotes unlocked at the time. - for modifier in self.extract_modifiers(message.tags["emotes"]): - message.message = message.message.replace(modifier, "") - - logger.debug(f"Received {msgs[0]} command from {message.user}") - match msgs[0]: - case Commands.GENERATE_HELP: - logger.debug("Executing _command_help()") - self._command_help() - - case Commands.SET_COOLDOWN: - if self.is_mod(message.user, message.channel): - logger.debug( - f"User {message.user} is mod, executing _command_set_cooldown()", - ) - self._command_set_cooldown( - split_message=msgs, - username=message.user, - ) - - case Commands.BLACKLIST: - if self.is_mod(message.user, message.channel): - logger.debug( - f"User {message.user} is a mod, executing _command_blacklist()", - ) - self._command_blacklist( - split_message=msgs, - username=message.user, - ) - - case Commands.GENERATE: - if not self._enabled: - logger.info("Bot not enabled, skipping") - return - if message.user not in self.s.denied_users: - logger.info( - f"User {message.user} allowed to generate, executing _command_generate()", - ) - self._command_generate( - message=message.message, - username=message.user, - ) - - case Commands.QUOTE: - if not self._enabled: - logger.info("Bot not enabled, skipping") - return - if message.user not in self.s.denied_users: - logger.info( - f"User {message.user} allowed to generate, executing _command_quote()", - ) - self._command_quote() - - case Commands.QUOTE_ADD: - if self.is_mod(message.user, message.channel): - logger.info( - f"User {message.user} allowed to create quote, executing _command_quote()", - ) - self._command_add_quote(message.message) - return - self.ws.send_message( - f"@{message.user} you're not in the modlist, you can't add quotes" - ) - - case _: - logger.debug( - f"Not a command: {msgs[0]}. Storing into db as a plain message", - ) - if message.type == "366": - logger.info(f"Successfully joined channel: #{message.channel}") - return - self.store_sentence(message.message) - - except Exception: # noqa: BLE001 - logger.exception(f"Could not process message {message}") - - def generate(self, params: list[str] | None = None) -> tuple[str, bool]: # noqa: C901, PLR0912 - """Given an input sentence, generate the remainder of the sentence using the learned data. - - Args: - params (list[str]): A list of words to use as an input to use as the start of generating. - - Returns: - tuple[str, bool]: A tuple of a sentence as the first value, and a boolean indicating - whether the generation succeeded as the second value. - """ - params = params or [] - - # List of sentences that will be generated. In some cases, multiple sentences will be generated, - # e.g. when the first sentence has less words than self.min_sentence_length. - sentences: list[list | list[str]] = [[]] - - # Check for commands or recursion, eg: !generate !generate - if len(params) > 0 and self.is_command(params[0]): - return "You can't make me do commands, you madman!", False - - # Get the starting key and starting sentence. - # If there is more than 1 param, get the last 2 as the key. - # Note that self.s.key_length is fixed to 2 in this implementation - if len(params) > 1: - key = params[-self.s.key_length :] - # Copy the entire params for the sentence - sentences[0] = params.copy() - - elif len(params) == 1: - # First we try to find if this word was once used as the first word in a sentence: - key = self.db.get_next_single_start(params[0]) # type: ignore[assignment] - if key is None: - # If this failed, we try to find the next word in the grammar as a whole - key = self.db.get_next_single_initial(0, params[0]) - if key is None: - # Return a message that this word hasn't been learned yet - return f'I haven\'t extracted "{params[0]}" from chat yet.', False - # Copy this for the sentence - sentences[0] = key.copy() - - else: # if there are no params - # Get starting key - key = self.db.get_start() - if key: - # Copy this for the sentence - sentences[0] = key.copy() - else: - # If nothing's ever been said - return "There is not enough learned information yet.", False - - # Counter to prevent infinite loops (i.e. constantly generating while below the - # minimum number of words to generate) - i = 0 - while ( - self.get_sentence_length(sentences) < self.s.max_sentence_length - and i < self.s.max_sentence_length * 2 - ): - # Use key to get next word - if i == 0: - # Prevent fetching on the first word - word = self.db.get_next_initial(i, key) - else: - word = self.db.get_next(i, key) - - i += 1 - - if word == "" or word is None: - # Break, unless we are before the min_sentence_length - if i < self.s.min_sentence_length: - key = self.db.get_start() - # Ensure that the key can be generated. Otherwise, we still stop. - if key: - # Start a new sentence - sentences.append([]) - for entry in key: - sentences[-1].append(entry) - continue - break - - # Otherwise add the word - sentences[-1].append(word) - - # Shift the key so on the next iteration it gets the next item - key.pop(0) - key.append(word) - - # If there were params, but the sentence resulting is identical to the params - # Then the params did not result in an actual sentence - # If so, restart without params - if len(params) > 0 and params == sentences[0]: - return "I haven't learned what to do with \"" + detokenize( - params[-self.s.key_length :], - ) + '" yet.', False - - return self.s.sentence_separator.join( - detokenize(sentence) for sentence in sentences - ), True - - @staticmethod - def get_sentence_length(sentences: list[list[str]]) -> int: - """Given a list of tokens representing a sentence, return the number of words in there. - - Args: - sentences (List[List[str]]): List of lists of tokens that make up a sentence, - where a token is a word or punctuation. For example: - [['Hello', ',', 'you', "'re", 'Tom', '!'], ['Yes', ',', 'I', 'am', '.']] - This would return 6. - - Returns: - int: The number of words in the sentence. - """ - count = 0 - for sentence in sentences: - for token in sentence: - if token not in string.punctuation and token[0] != "'": - count += 1 - return count - - @staticmethod - def extract_modifiers(emotes: str) -> list[str]: - """Extract emote modifiers from emotes such as the horizontal flip. - - Args: - emotes (str): String containing all emotes used in the message. - - Returns: - list[str]: List of strings that show modifiers, such as "_HZ" for horizontal flip. - """ - output = [] - try: - while emotes: - u_index = emotes.index("_") - c_index = emotes.index(":", u_index) - output.append(emotes[u_index:c_index]) - emotes = emotes[c_index:] - except ValueError: - pass - return output - - def send_whisper(self, user: str, message: str) -> None: - """Optionally send a whisper, only if "WhisperCooldown" is True. - - Args: - user (str): The user to potentially whisper. - message (str): The message to potentially whisper - """ - if self.s.whisper_cooldown: - self.ws.send_whisper(user, message) - - @staticmethod - def is_command(message: str) -> bool: - """True if the message is any command, except /me. - - Is used to avoid learning and generating commands. - - Args: - message (str): The message to check. - - Returns: - bool: True if the message is any potential command (starts with a '!', '/' or '.') - except /me. - """ - return message in list(Commands) - - def is_mod(self, username: str, channel: str) -> bool: - """True if the user is a moderator. - - Args: - username (str): The name of the user to check - channel (str): The name of the channel - - Returns: - bool: True if the user is a moderator. - """ - return username in self.s.mods or username == channel - - -if __name__ == "__main__": - MarkovChain() diff --git a/src/huesoporro/libs/settings.py b/src/huesoporro/libs/settings.py deleted file mode 100644 index 18a8152..0000000 --- a/src/huesoporro/libs/settings.py +++ /dev/null @@ -1,118 +0,0 @@ -import json -from pathlib import Path -from typing import Literal - -import platformdirs -from loguru import logger -from pydantic import Field -from pydantic_settings import BaseSettings, SettingsConfigDict - - -class Settings(BaseSettings): - host: str = Field("irc.chat.twitch.tv", alias="Host", serialization_alias="Host") - port: int = Field(6667, alias="Port", serialization_alias="Port") - channel: str = Field(..., alias="Channel", serialization_alias="Channel") - nickname: str = Field(..., alias="Nickname", serialization_alias="Nickname") - authentication: str = Field( - ..., - alias="Authentication", - serialization_alias="Authentication", - ) - denied_users: list[str] = Field( - [ - "StreamElements", - "Nightbot", - "Moobot", - "Marbiebot", - ], - alias="DeniedUsers", - serialization_alias="DeniedUsers", - ) - banned_words: list[str] = Field( - default_factory=list, - alias="BannedWords", - serialization_alias="BannedWords", - ) - mods: list[str] = Field( - default_factory=list, - alias="Mods", - serialization_alias="Mods", - ) - cooldown: int = Field(210, alias="Cooldown", serialization_alias="Cooldown") - key_length: int = Field(2, alias="KeyLength", serialization_alias="KeyLength") - max_sentence_length: int = Field( - 25, - alias="MaxSentenceWordAmount", - serialization_alias="MaxSentenceWordAmount", - ) - min_sentence_length: int = Field( - -1, - alias="MinSentenceWordAmount", - serialization_alias="MinSentenceWordAmount", - ) - help_message_timer: int = Field( - 60 * 60 * 5, - alias="HelpMessageTimer", - serialization_alias="HelpMessageTimer", - ) - automatic_generation_timer: int = Field( - -1, - alias="AutomaticGenerationTimer", - serialization_alias="AutomaticGenerationTimer", - ) - whisper_cooldown: bool = Field( - True, - alias="WhisperCooldown", - serialization_alias="WhisperCooldown", - ) - enable_generate_command: bool = Field( - True, - alias="EnableGenerateCommand", - serialization_alias="EnableGenerateCommand", - ) - sentence_separator: str = Field( - " - ", - alias="SentenceSeparator", - serialization_alias="SentenceSeparator", - ) - allow_generate_params: bool = Field( - True, - alias="AllowGenerateParams", - serialization_alias="AllowGenerateParams", - ) - log_level: Literal[ - "CRITICAL", - "ERROR", - "WARNING", - "INFO", - "DEBUG", - "TRACE", - ] = Field("DEBUG", alias="LogLevel") - model_config = SettingsConfigDict(extra="ignore") - - @property - def channel_name(self): - return self.channel.replace("#", "").lower() - - @classmethod - def read(cls, filepath: Path | None = None) -> "Settings": - if not filepath: - filepath = ( - platformdirs.user_config_path("markovbot_gui", ensure_exists=True) - / "settings.json" - ) - - with filepath.open("r") as f: - data = json.load(f) - return Settings(**data) - - def write(self, filepath: Path | None = None): - if not filepath: - filepath = ( - platformdirs.user_config_path("markovbot_gui", ensure_exists=True) - / "settings.json" - ) - - with filepath.open("w") as f: - logger.info(f"Writing current settings to {filepath}") - json.dump(self.model_dump(by_alias=True), f, indent=4) diff --git a/src/huesoporro/main.py b/src/huesoporro/main.py index 2187dc4..e2616b8 100644 --- a/src/huesoporro/main.py +++ b/src/huesoporro/main.py @@ -1,248 +1,7 @@ -import json -import secrets -from json import JSONDecodeError - -import httpx import uvicorn -from litestar import Litestar, MediaType, Request, Response, WebSocket, get -from litestar.connection import ASGIConnection -from litestar.contrib.jinja import JinjaTemplateEngine -from litestar.datastructures.state import State -from litestar.di import Provide -from litestar.exceptions import HTTPException -from litestar.handlers import BaseRouteHandler, WebsocketListener -from litestar.response import Redirect, Template -from litestar.static_files import StaticFilesConfig -from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR -from litestar.template import TemplateConfig -from loguru import logger -from src.huesoporro.chatbot import ChatbotManager +from src.huesoporro.api.main import create_app from src.huesoporro.settings import Settings -from src.huesoporro.tts import TTSManager -from src.huesoporro.value_objects import WebsocketCommands, WebsocketMessage - - -async def _authenticate(access_token: str): - s = Settings.get() - client = httpx.AsyncClient( - base_url="https://id.twitch.tv", - ) - - resp = await client.get( - "/oauth2/validate", headers={"Authorization": f"OAuth {access_token}"} - ) - user_data = resp.json() - - if user_data.get("status"): - raise HTTPException(status_code=401, detail="Unauthorized") - - if (user := user_data["login"]) not in s.allowed_users: - raise HTTPException(status_code=403, detail="Forbidden") - - return user - - -async def authenticate( - connection: ASGIConnection, route_handler: BaseRouteHandler -) -> None: - """Extract cookie from connection and try to authenticate""" - - try: - login_data = json.loads(connection.cookies.get("twitchLoginData")) - except (JSONDecodeError, TypeError) as exc: - logger.warning(f"Error parsing twitch login data: {exc}") - raise HTTPException(status_code=401, detail="Unauthorized") from exc - - access_token = login_data.get("access_token") - if not login_data or not access_token: - raise HTTPException(status_code=401, detail="Unauthorized") - - user = await _authenticate(access_token) - - connection.state["user"] = user - connection.state["access_token"] = access_token - - -class WebsocketHandler(WebsocketListener): - path = "/ws" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.tts_manager = TTSManager() - self.chatbot_manager = ChatbotManager() - self.user = None - self.access_token = None - - async def on_accept(self, socket: WebSocket, state: State) -> None: - """If the authentication is correct, add the manager's clients list""" - - cookies = socket.cookies.get("twitchLoginData") - try: - access_token = json.loads(cookies).get("access_token") - except (JSONDecodeError, TypeError) as exc: - logger.warning(f"Error parsing twitch login data {exc}") - return - if not access_token: - return - user = await _authenticate(access_token) - - self.user = user - self.access_token = access_token - self.chatbot_manager.clients.add(socket) - self.tts_manager.clients.append(socket) - - logger.info( - f"Connection accepted from {socket.client.host}:{socket.client.port}" # type: ignore[union-attr] - ) - - async def on_disconnect(self, socket: WebSocket) -> None: - # Remove client from the list - if socket in self.tts_manager.clients: - self.tts_manager.clients.remove(socket) - self.chatbot_manager.clients.remove(socket) - logger.info(f"Connection closed by {socket.client.host}:{socket.client.port}") # type: ignore[union-attr] - - async def on_receive(self, data: str, state: State) -> None: - message = WebsocketMessage(**json.loads(data)) - logger.info(f"Received {message.command.value} command") - - match message.command: - case WebsocketCommands.TTS_SEND: - await self.tts_manager.add_to_queue(**message.data) - case WebsocketCommands.CHATBOT_START: - self.chatbot_manager.start_bot( - **message.data - | { - "nickname": self.user, - "authentication": f"oauth:{self.access_token}", - }, - ) - case WebsocketCommands.CHATBOT_STOP: - self.chatbot_manager.stop_bot() - - -@get( - "/tts", - media_type=MediaType.HTML, - guards=[authenticate], -) -async def get_tts_overlay() -> Template: - return Template(template_name="tts.html") - - -@get( - "/tts/permalink", - media_type=MediaType.HTML, -) -async def get_tts_permalink(access_token: str) -> Template: - """Handler for the /tts permalink endpoint to be used by apps that can only give the authentication as a query - param and not as a cookie, i.e. OBS""" - - # authenticate the user using the provided access token - await _authenticate(access_token) - - return Template( - template_name="tts.html", - ) - - -@get( - "/", - media_type=MediaType.HTML, - guards=[authenticate], -) -async def get_index() -> Template: - return Template( - template_name="index.html", - ) - - -@get("/login", media_type=MediaType.HTML, dependencies={"s": Provide(Settings.get)}) -async def login(s: Settings) -> Template: - scopes = "+".join(s.twitch_scopes) - return Template( - "login.html", - context={ - "twitch_login_url": "https://id.twitch.tv/oauth2/authorize?response_type=token" - f"&client_id={s.twitch_client_id}" - f"&redirect_uri={s.server_hostname}login" - f"&scope={scopes}" - f"&state={secrets.token_urlsafe(32)}" - }, - ) - - -@get("/healthz") -def get_health() -> dict: - return {"status": "ok"} - - -@get("/lefunny") -def get_lefunny() -> Template: - return Template( - template_name="lefunny.html", - context={"sentences": [{"sentence": "Hola huesoperro", "id": 1}]}, - ) - - -def exception_handler(_: Request, exc: Exception) -> Response: - status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR) - detail = getattr(exc, "detail", "") - - if isinstance(exc, HTTPException) and (exc.status_code in [401, 403]): - logger.warning("User could not authenticate. Redirecting to /login page") - return Redirect("/login") - - return Response( - media_type=MediaType.TEXT, - content=detail, - status_code=status_code, - ) - - -async def after_exception_handler(exc: Exception, scope: "Scope") -> None: - """Hook function that will be invoked after each exception.""" - state = scope["app"].state - if not hasattr(state, "error_count"): - state.error_count = 1 - else: - state.error_count += 1 - - logger.error( - f"an exception of type {type(exc).__name__} has occurred for requested path {scope['path']} and the application error count is {state.error_count}.", - ) - - -def create_app(): - return Litestar( - route_handlers=[ - get_health, - login, - get_index, - get_tts_overlay, - get_tts_permalink, - get_lefunny, - WebsocketHandler, - ], - static_files_config=( - StaticFilesConfig( - path="/tts_files", - directories=[Settings.get().tts_cache_path], - ), - StaticFilesConfig( - path="static", - directories=[Settings.get().static_files_path], - ), - ), - template_config=TemplateConfig( - directory=Settings.get().templates_files_path, - engine=JinjaTemplateEngine, - ), - exception_handlers={HTTPException: exception_handler}, - after_exception=[after_exception_handler], - ) - if __name__ == "__main__": settings = Settings.get() diff --git a/src/huesoporro/models.py b/src/huesoporro/models.py new file mode 100644 index 0000000..4680f92 --- /dev/null +++ b/src/huesoporro/models.py @@ -0,0 +1,50 @@ +from typing import Self + +import jwt +from pydantic import BaseModel, field_validator + +from src.huesoporro.settings import Settings + + +class TwitchAuth(BaseModel): + access_token: str + refresh_token: str + + +class User(BaseModel): + user: str + expires_at: float + twitch_auth: TwitchAuth + + def encode(self, settings: Settings | None = None) -> str: + s = settings or Settings.get() + return jwt.encode( + self.model_dump(), + key=s.jwt_secret.get_secret_value(), + algorithm="HS256", + ) + + @classmethod + def decode(cls, token: str, settings: Settings | None = None) -> Self: + s = settings or Settings.get() + decoded = jwt.decode( + token, key=s.jwt_secret.get_secret_value(), algorithms=["HS256"] + ) + return cls(**decoded) + + +class ChatbotSettings(BaseModel): + automatic_generation_timer: int = 300 + automatic_quote_timer: int = 500 + mods: list[str] | None = None + + @property + def mods_as_string(self): + return ",".join(self.mods) + + @field_validator("mods", mode="before") + @classmethod + def format_mods_from_string(cls, v): + if isinstance(v, str): + return v.split(",") + return v diff --git a/src/huesoporro/settings.py b/src/huesoporro/settings.py index f6c1f05..1bc19f9 100644 --- a/src/huesoporro/settings.py +++ b/src/huesoporro/settings.py @@ -1,7 +1,7 @@ from functools import lru_cache from pathlib import Path -from pydantic import Field, HttpUrl, field_validator +from pydantic import Field, HttpUrl, SecretStr, field_validator from pydantic_settings import BaseSettings @@ -21,6 +21,8 @@ class Settings(BaseSettings): default_factory=lambda: Path(__file__).parent / "huesoporro.db" ) twitch_client_id: str + twitch_client_secret: SecretStr + jwt_secret: SecretStr twitch_scopes: list[str] = Field( default_factory=lambda: ["channel:bot", "chat:edit", "chat:read"] ) diff --git a/src/huesoporro/svc/__init__.py b/src/huesoporro/svc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/huesoporro/svc/authenticate.py b/src/huesoporro/svc/authenticate.py new file mode 100644 index 0000000..346a407 --- /dev/null +++ b/src/huesoporro/svc/authenticate.py @@ -0,0 +1,26 @@ +import datetime + +from pydantic import BaseModel + +from src.huesoporro.infra.authenticator import TwitchAuthenticator +from src.huesoporro.infra.db import Database +from src.huesoporro.models import User + + +class CodeAuthenticatorSvc(BaseModel): + db: Database + authenticator: TwitchAuthenticator + + @staticmethod + def get_four_hours_from_now() -> float: + now = datetime.datetime.now(datetime.UTC) + four_hours_later = now + datetime.timedelta(hours=4) + return four_hours_later.timestamp() + + async def run(self, code: str) -> User: + auth = await self.authenticator.get_token(code) + username = await self.authenticator.validate_token(auth.access_token) + expires_at = self.get_four_hours_from_now() + user = User(user=username, expires_at=expires_at, twitch_auth=auth) + await self.db.save_user(user) + return user diff --git a/src/huesoporro/svc/generate.py b/src/huesoporro/svc/generate.py new file mode 100644 index 0000000..f6f91f6 --- /dev/null +++ b/src/huesoporro/svc/generate.py @@ -0,0 +1,158 @@ +import string + +from loguru import logger +from pydantic import BaseModel, ConfigDict + +from src.huesoporro.libs.db import Database as MarkovDB +from src.huesoporro.libs.tokenizer import detokenize, tokenize + + +class SentenceGeneratorSvc(BaseModel): + db: MarkovDB + min_sentence_length: int = 2 + key_length: int = 2 + max_sentence_length: int = 25 + sentence_separator: str = " " + model_config = ConfigDict(arbitrary_types_allowed=True) + + def is_mod(self, username: str, channel: str) -> bool: + """True if the user is a moderator. + + Args: + username (str): The name of the user to check + channel (str): The name of the channel + + Returns: + bool: True if the user is a moderator. + """ + return username in self.s.mods or username == channel + + @staticmethod + def get_sentence_length(sentences: list[list[str]]) -> int: + """Given a list of tokens representing a sentence, return the number of words in there. + + Args: + sentences (List[List[str]]): List of lists of tokens that make up a sentence, + where a token is a word or punctuation. For example: + [['Hello', ',', 'you', "'re", 'Tom', '!'], ['Yes', ',', 'I', 'am', '.']] + This would return 6. + + Returns: + int: The number of words in the sentence. + """ + count = 0 + for sentence in sentences: + for token in sentence: + if token not in string.punctuation and token[0] != "'": + count += 1 + return count + + def generate(self, params: list[str] | None = None) -> tuple[str, bool]: # noqa: C901, PLR0912 + """Given an input sentence, generate the remainder of the sentence using the learned data. + + Args: + params (list[str]): A list of words to use as an input to use as the start of generating. + + Returns: + tuple[str, bool]: A tuple of a sentence as the first value, and a boolean indicating + whether the generation succeeded as the second value. + """ + params = params or [] + + # List of sentences that will be generated. In some cases, multiple sentences will be generated, + # e.g. when the first sentence has fewer words than self.min_sentence_length. + sentences: list[list | list[str]] = [[]] + + # Check for commands or recursion, eg: !generate !generate + if len(params) > 0: + return "You can't make me do commands, you madman!", False + + # Get the starting key and starting sentence. + # If there is more than 1 param, get the last 2 as the key. + # Note that self.key_length is fixed to 2 in this implementation + if len(params) > 1: + key = params[-self.key_length :] + # Copy the entire params for the sentence + sentences[0] = params.copy() + + elif len(params) == 1: + # First we try to find if this word was once used as the first word in a sentence: + key = self.db.get_next_single_start(params[0]) # type: ignore[assignment] + if key is None: + # If this failed, we try to find the next word in the grammar as a whole + key = self.db.get_next_single_initial(0, params[0]) + if key is None: + # Return a message that this word hasn't been learned yet + return f'I haven\'t extracted "{params[0]}" from chat yet.', False + # Copy this for the sentence + sentences[0] = key.copy() + + else: # if there are no params + # Get starting key + key = self.db.get_start() + if key: + # Copy this for the sentence + sentences[0] = key.copy() + else: + # If nothing's ever been said + return "There is not enough learned information yet.", False + + # Counter to prevent infinite loops (i.e. constantly generating while below the + # minimum number of words to generate) + i = 0 + while ( + self.get_sentence_length(sentences) < self.max_sentence_length + and i < self.max_sentence_length * 2 + ): + # Use key to get next word + if i == 0: + # Prevent fetching on the first word + word = self.db.get_next_initial(i, key) + else: + word = self.db.get_next(i, key) + + i += 1 + + if word == "" or word is None: + # Break, unless we are before the min_sentence_length + if i < self.min_sentence_length: + key = self.db.get_start() + # Ensure that the key can be generated. Otherwise, we still stop. + if key: + # Start a new sentence + sentences.append([]) + for entry in key: + sentences[-1].append(entry) + continue + break + + # Otherwise add the word + sentences[-1].append(word) + + # Shift the key so on the next iteration it gets the next item + key.pop(0) + key.append(word) + + # If there were params, but the sentence resulting is identical to the params + # Then the params did not result in an actual sentence + # If so, restart without params + if len(params) > 0 and params == sentences[0]: + return "I haven't learned what to do with \"" + detokenize( + params[-self.key_length :], + ) + '" yet.', False + + return self.sentence_separator.join( + detokenize(sentence) for sentence in sentences + ), True + + async def run( + self, + sentence: str | None = None, + ) -> str|None: + if sentence: + sentence = tokenize(sentence) + logger.info(f"Generating sentence from {sentence}") + sentence, success = self.generate(sentence) + logger.info(f"Generated sentence: {sentence}") + if success: + return sentence diff --git a/src/huesoporro/svc/get_chatbot_settings.py b/src/huesoporro/svc/get_chatbot_settings.py new file mode 100644 index 0000000..a1bfc2f --- /dev/null +++ b/src/huesoporro/svc/get_chatbot_settings.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + +from src.huesoporro.infra.db import Database +from src.huesoporro.models import ChatbotSettings, User + + +class ChatbotSettingsGetterSvc(BaseModel): + db: Database + + async def run(self, user: User) -> ChatbotSettings | None: + return await self.db.get_chatbot_settings(user=user) diff --git a/src/huesoporro/svc/get_random_quote.py b/src/huesoporro/svc/get_random_quote.py new file mode 100644 index 0000000..4371d4e --- /dev/null +++ b/src/huesoporro/svc/get_random_quote.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + +from src.huesoporro.infra.db import Database + + +class RandomQuoteGetterSvc(BaseModel): + db: Database + + async def run(self, channel_name: str) -> tuple[str, str]: + return await self.db.get_random_quote(channel_name=channel_name) diff --git a/src/huesoporro/svc/hello.py b/src/huesoporro/svc/hello.py new file mode 100644 index 0000000..119095b --- /dev/null +++ b/src/huesoporro/svc/hello.py @@ -0,0 +1,10 @@ +import random + +from pydantic import BaseModel, Field + + +class HelloGeneratorSvc(BaseModel): + hellos: list[str] = Field(default_factory=lambda: ["Hola", "Ayo", "Hi", "Bon día"]) + + def run(self, username: str): + return f"{random.choice(self.hellos)} {username}" diff --git a/src/huesoporro/svc/is_mod.py b/src/huesoporro/svc/is_mod.py new file mode 100644 index 0000000..8450448 --- /dev/null +++ b/src/huesoporro/svc/is_mod.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + +from src.huesoporro.infra.db import Database +from src.huesoporro.models import User + + +class IsModSvc(BaseModel): + db: Database + + async def run(self, user: User, username: str) -> bool: + chatbot_settings = await self.db.get_chatbot_settings(user=user) + return username in chatbot_settings.mods diff --git a/src/huesoporro/svc/refresh.py b/src/huesoporro/svc/refresh.py new file mode 100644 index 0000000..9209743 --- /dev/null +++ b/src/huesoporro/svc/refresh.py @@ -0,0 +1,27 @@ +import datetime + +from pydantic import BaseModel + +from src.huesoporro.infra.authenticator import TwitchAuthenticator +from src.huesoporro.infra.db import Database +from src.huesoporro.models import User + + +class RefreshTokenAuthenticator(BaseModel): + db: Database + authenticator: TwitchAuthenticator + + @staticmethod + def get_four_hours_from_now() -> float: + now = datetime.datetime.now(datetime.UTC) + four_hours_later = now + datetime.timedelta(hours=4) + return four_hours_later.timestamp() + + async def run(self, refresh_token: str) -> User: + auth = await self.authenticator.refresh_token(refresh_token) + username = await self.authenticator.validate_token(auth.access_token) + expires_at = self.get_four_hours_from_now() + + user = User(user=username, expires_at=expires_at, twitch_auth=auth) + await self.db.save_user(user) + return user diff --git a/src/huesoporro/svc/store.py b/src/huesoporro/svc/store.py new file mode 100644 index 0000000..fc45d46 --- /dev/null +++ b/src/huesoporro/svc/store.py @@ -0,0 +1,63 @@ +from loguru import logger +from nltk.tokenize import sent_tokenize +from pydantic import BaseModel, ConfigDict + +from src.huesoporro.libs.db import Database as MarkovDB +from src.huesoporro.libs.tokenizer import tokenize + + +class SentenceStorerSvc(BaseModel): + db: MarkovDB + key_length: int = 2 + end_tag: str = "" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def store_sentence(self, message: str): + logger.info(f"Processing {message} in order to store it") + stripped_message = message.strip() + try: + sentences = sent_tokenize(stripped_message) + except LookupError: + logger.debug("Downloading required punkt resource...") + import nltk + + nltk.download("punkt") + logger.debug("Downloaded required punkt resource.") + sentences = sent_tokenize(stripped_message) + + for sentence in sentences: + words = tokenize(sentence) + # Double spaces will lead to invalid rules. We remove empty words here + if "" in words: + words = [word for word in words if word] + + # If the sentence is too short, ignore it and move on to the next. + if len(words) <= self.key_length: + continue + + # Add a new starting point for a sentence to the + words = [words[x] for x in range(self.key_length)] + logger.debug(f"Adding {words} to start queue") + self.db.add_start_queue(words) + + # Create Key variable which will be used as a key in the Dictionary for the grammar + key: list[str] = [] + for word in words: + # Set up key for first use + if len(key) < self.key_length: + key.append(word) + continue + logger.debug(f"Adding {key}[{word}] to rule queue") + self.db.add_rule_queue([*key, word]) + + # Remove the first word, and add the current word, + # so that the key is correct for the next word. + key.pop(0) + key.append(word) + logger.debug(f"Adding {key} to rule queue") + # Add at the end of the sentence + self.db.add_rule_queue([*key, self.end_tag]) + + async def run(self, sentence: str): + return self.store_sentence(sentence) diff --git a/src/huesoporro/svc/store_quote.py b/src/huesoporro/svc/store_quote.py new file mode 100644 index 0000000..fe56899 --- /dev/null +++ b/src/huesoporro/svc/store_quote.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + +from src.huesoporro.infra.db import Database + + +class QuoteStorerSvc(BaseModel): + db: Database + + async def run(self, channel: str, quote: str, author: str): + await self.db.save_quote(channel, quote, author) diff --git a/src/huesoporro/svc/store_settings.py b/src/huesoporro/svc/store_settings.py new file mode 100644 index 0000000..e912d2d --- /dev/null +++ b/src/huesoporro/svc/store_settings.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel + +from src.huesoporro.infra.db import Database +from src.huesoporro.models import ChatbotSettings, User + + +class ChatbotSettingsStorerSvc(BaseModel): + db: Database + + async def run( + self, user: User, bot_settings: ChatbotSettings + ) -> dict[str, str | int | None] | None: + return await self.db.save_chatbot_settings( + user=user, chatbot_settings=bot_settings + ) diff --git a/src/huesoporro/templates/header.html b/src/huesoporro/templates/header.html index 526601c..7c1cfd9 100644 --- a/src/huesoporro/templates/header.html +++ b/src/huesoporro/templates/header.html @@ -3,6 +3,7 @@ + diff --git a/src/huesoporro/templates/index.html b/src/huesoporro/templates/index.html index 6b9623d..bdf08ee 100644 --- a/src/huesoporro/templates/index.html +++ b/src/huesoporro/templates/index.html @@ -15,146 +15,180 @@
-
- + -
+
+
+ Chatbot settings +
+ + + + + + + + +
+
- -
- Log -
-
diff --git a/src/huesoporro/templates/login.html b/src/huesoporro/templates/login.html index 7d3c472..00e0a38 100644 --- a/src/huesoporro/templates/login.html +++ b/src/huesoporro/templates/login.html @@ -1,32 +1,17 @@ - - - - - - - - - - - - Huesoporro login - - - +{% include 'header.html' %}

Huesoporro🦴🚬

- - Login - with - Twitch - +
+ Login + with + Twitch + +
-