feat: add migrations, api bot endpoints and revamp the whole twitch backend by making use of twitchio
This commit is contained in:
parent
8799bab900
commit
4c534de47b
45 changed files with 1718 additions and 1109 deletions
0
src/huesoporro/actions/__init__.py
Normal file
0
src/huesoporro/actions/__init__.py
Normal file
18
src/huesoporro/actions/store_quote.py
Normal file
18
src/huesoporro/actions/store_quote.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from src.huesoporro.models import User
|
||||
from src.huesoporro.svc.is_mod import IsModSvc
|
||||
from src.huesoporro.svc.store_quote import QuoteStorerSvc
|
||||
|
||||
|
||||
class StoreQuoteAction(BaseModel):
|
||||
quote_storer_svc: QuoteStorerSvc
|
||||
is_mod_svc: IsModSvc
|
||||
|
||||
async def run(
|
||||
self, user: User, channel: str, quote: str, author: str, username: str
|
||||
) -> str:
|
||||
if not await self.is_mod_svc.run(user=user, username=username):
|
||||
return f"{username} is not a mod and cannot add quotes. Only moderators can add quotes. Sorry!"
|
||||
await self.quote_storer_svc.run(channel, quote, author)
|
||||
return f"«{quote}» added by {author}."
|
||||
0
src/huesoporro/api/__init__.py
Normal file
0
src/huesoporro/api/__init__.py
Normal file
48
src/huesoporro/api/dependencies.py
Normal file
48
src/huesoporro/api/dependencies.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
from litestar import Request
|
||||
from litestar.exceptions import HTTPException
|
||||
|
||||
from src.huesoporro.infra.authenticator import TwitchAuthenticator
|
||||
from src.huesoporro.infra.db import Database
|
||||
from src.huesoporro.models import User
|
||||
from src.huesoporro.settings import Settings
|
||||
from src.huesoporro.svc.authenticate import CodeAuthenticatorSvc
|
||||
from src.huesoporro.svc.get_chatbot_settings import ChatbotSettingsGetterSvc
|
||||
from src.huesoporro.svc.store_settings import ChatbotSettingsStorerSvc
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
return Settings.get()
|
||||
|
||||
|
||||
def get_authenticator(s: Settings) -> TwitchAuthenticator:
|
||||
return TwitchAuthenticator(s=s)
|
||||
|
||||
|
||||
def get_db(s: Settings):
|
||||
return Database(s=s)
|
||||
|
||||
|
||||
async def authenticate(request: Request) -> User:
|
||||
token = request.query_params.get("huesoporro_token")
|
||||
if token:
|
||||
return User.decode(token)
|
||||
|
||||
cookies = request.cookies.get("huesoporroAuth")
|
||||
if cookies:
|
||||
return User.decode(cookies)
|
||||
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
|
||||
async def get_code_authenticator_svc(
|
||||
a: TwitchAuthenticator, db: Database
|
||||
) -> CodeAuthenticatorSvc:
|
||||
return CodeAuthenticatorSvc(authenticator=a, db=db)
|
||||
|
||||
|
||||
async def get_chatbot_settings_svc(db: Database):
|
||||
return ChatbotSettingsGetterSvc(db=db)
|
||||
|
||||
|
||||
async def store_chatbot_settings_svc(db: Database):
|
||||
return ChatbotSettingsStorerSvc(db=db)
|
||||
45
src/huesoporro/api/errors.py
Normal file
45
src/huesoporro/api/errors.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import httpx
|
||||
from litestar import MediaType, Request, Response
|
||||
from litestar.exceptions import HTTPException
|
||||
from litestar.response import Redirect
|
||||
from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def http_exception_handler(_: Request, exc: HTTPException) -> Response:
|
||||
status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
detail = getattr(exc, "detail", "")
|
||||
|
||||
if isinstance(exc, HTTPException) and (exc.status_code in [401, 403]):
|
||||
logger.warning("User could not authenticate. Redirecting to /login page")
|
||||
return Redirect("/login")
|
||||
|
||||
return Response(
|
||||
media_type=MediaType.TEXT,
|
||||
content=detail,
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
|
||||
def httpx_status_error_handler(_: Request, exc: httpx.HTTPStatusError):
|
||||
logger.error(f"HTTPX error occurred: {exc}")
|
||||
return Response(
|
||||
media_type=MediaType.TEXT,
|
||||
content=f"HTTPX error occurred: {exc}",
|
||||
status_code=exc.response.status_code,
|
||||
)
|
||||
|
||||
|
||||
async def after_exception_handler(exc: Exception, scope: "Scope") -> None:
|
||||
"""Hook function that will be invoked after each exception."""
|
||||
state = scope["app"].state
|
||||
if not hasattr(state, "error_count"):
|
||||
state.error_count = 1
|
||||
else:
|
||||
state.error_count += 1
|
||||
logger.error(
|
||||
f"an exception of type {type(exc).__name__} has occurred for requested path {scope['path']} and the application error count is {state.error_count}.",
|
||||
)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
85
src/huesoporro/api/main.py
Normal file
85
src/huesoporro/api/main.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
import httpx
|
||||
from litestar import Litestar, get
|
||||
from litestar.contrib.jinja import JinjaTemplateEngine
|
||||
from litestar.di import Provide
|
||||
from litestar.exceptions import HTTPException
|
||||
from litestar.static_files import StaticFilesConfig
|
||||
from litestar.template import TemplateConfig
|
||||
|
||||
from src.huesoporro.api.dependencies import (
|
||||
authenticate,
|
||||
get_authenticator,
|
||||
get_chatbot_settings_svc,
|
||||
get_code_authenticator_svc,
|
||||
get_db,
|
||||
get_settings,
|
||||
store_chatbot_settings_svc,
|
||||
)
|
||||
from src.huesoporro.api.errors import (
|
||||
after_exception_handler,
|
||||
http_exception_handler,
|
||||
httpx_status_error_handler,
|
||||
)
|
||||
from src.huesoporro.api.routes.api import (
|
||||
get_bot_settings,
|
||||
get_bot_status,
|
||||
get_index,
|
||||
get_tts_overlay,
|
||||
get_tts_permalink,
|
||||
manage_bot,
|
||||
save_bot_settings,
|
||||
)
|
||||
from src.huesoporro.api.routes.auth import get_code, login
|
||||
from src.huesoporro.bot import BotsManager
|
||||
from src.huesoporro.settings import Settings
|
||||
|
||||
|
||||
@get("/healthz")
|
||||
def get_health() -> dict:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
def create_app():
|
||||
return Litestar(
|
||||
route_handlers=[
|
||||
get_health,
|
||||
login,
|
||||
get_index,
|
||||
get_tts_overlay,
|
||||
get_tts_permalink,
|
||||
get_code,
|
||||
manage_bot,
|
||||
get_bot_status,
|
||||
save_bot_settings,
|
||||
get_bot_settings,
|
||||
],
|
||||
static_files_config=(
|
||||
StaticFilesConfig(
|
||||
path="/tts_files",
|
||||
directories=[Settings.get().tts_cache_path],
|
||||
),
|
||||
StaticFilesConfig(
|
||||
path="static",
|
||||
directories=[Settings.get().static_files_path],
|
||||
),
|
||||
),
|
||||
template_config=TemplateConfig(
|
||||
directory=Settings.get().templates_files_path,
|
||||
engine=JinjaTemplateEngine,
|
||||
),
|
||||
exception_handlers={
|
||||
HTTPException: http_exception_handler,
|
||||
httpx.HTTPStatusError: httpx_status_error_handler,
|
||||
},
|
||||
after_exception=[after_exception_handler],
|
||||
dependencies={
|
||||
"s": Provide(get_settings, use_cache=True),
|
||||
"a": Provide(get_authenticator, use_cache=True),
|
||||
"user": Provide(authenticate),
|
||||
"db": Provide(get_db, use_cache=True),
|
||||
"code_authenticator_svc": Provide(get_code_authenticator_svc),
|
||||
"bm": Provide(BotsManager, use_cache=True),
|
||||
"gbs": Provide(get_chatbot_settings_svc),
|
||||
"sbs": Provide(store_chatbot_settings_svc),
|
||||
},
|
||||
)
|
||||
0
src/huesoporro/api/routes/__init__.py
Normal file
0
src/huesoporro/api/routes/__init__.py
Normal file
86
src/huesoporro/api/routes/api.py
Normal file
86
src/huesoporro/api/routes/api.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
from typing import Literal
|
||||
|
||||
from litestar import MediaType, Response, get, put
|
||||
from litestar.response import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.huesoporro.bot import BotsManager
|
||||
from src.huesoporro.models import ChatbotSettings, User
|
||||
from src.huesoporro.svc.get_chatbot_settings import ChatbotSettingsGetterSvc
|
||||
from src.huesoporro.svc.store_settings import ChatbotSettingsStorerSvc
|
||||
|
||||
|
||||
class ManageBotDTO(BaseModel):
|
||||
command: Literal["start", "stop"]
|
||||
channel_name: str | None = None
|
||||
|
||||
|
||||
@get(
|
||||
"/tts",
|
||||
media_type=MediaType.HTML,
|
||||
)
|
||||
async def get_tts_overlay() -> Template:
|
||||
return Template(template_name="tts.html")
|
||||
|
||||
|
||||
@get(
|
||||
"/tts/permalink",
|
||||
media_type=MediaType.HTML,
|
||||
)
|
||||
async def get_tts_permalink(access_token: str) -> Template:
|
||||
"""Handler for the /tts permalink endpoint to be used by apps that can only give the authentication as a query
|
||||
param and not as a cookie, i.e. OBS"""
|
||||
|
||||
# authenticate the user using the provided access token
|
||||
|
||||
return Template(
|
||||
template_name="tts.html",
|
||||
)
|
||||
|
||||
|
||||
@get(
|
||||
"/",
|
||||
media_type=MediaType.HTML,
|
||||
)
|
||||
async def get_index(user: User, gbs: ChatbotSettingsGetterSvc) -> Template:
|
||||
chatbot_settings = await gbs.run(user=user)
|
||||
return Template(template_name="index.html", context=chatbot_settings.model_dump() if chatbot_settings else {})
|
||||
|
||||
|
||||
@put("/api/v1/bot")
|
||||
async def manage_bot(
|
||||
user: User, data: ManageBotDTO, gbs: ChatbotSettingsGetterSvc, bm: BotsManager
|
||||
) -> Response:
|
||||
chatbot_settings = await gbs.run(user=user)
|
||||
if data.command == "start":
|
||||
if not data.channel_name:
|
||||
return Response({"message": "Channel name is required"}, status_code=400)
|
||||
bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings)
|
||||
if user.user in bm.bots:
|
||||
await bm.run_user_bot(user)
|
||||
return Response({"message": "Bot started"})
|
||||
if data.command == "stop" and user.user in bm.bots:
|
||||
await bm.stop_user_bot(user)
|
||||
return Response({"message": "Bot stopped"})
|
||||
|
||||
|
||||
@get("/api/v1/bot")
|
||||
async def get_bot_status(user: User, bm: BotsManager) -> dict:
|
||||
if user.user not in bm.bots:
|
||||
return {"status": "ko"}
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@get("/api/v1/bot/settings")
|
||||
async def get_bot_settings(
|
||||
user: User, gbs: ChatbotSettingsGetterSvc
|
||||
) -> ChatbotSettings:
|
||||
return await gbs.run(user=user)
|
||||
|
||||
|
||||
@put("/api/v1/bot/settings")
|
||||
async def save_bot_settings(
|
||||
user: User, data: ChatbotSettings, sbs: ChatbotSettingsStorerSvc
|
||||
) -> dict:
|
||||
await sbs.run(user=user, bot_settings=data)
|
||||
return {"status": "ok"}
|
||||
31
src/huesoporro/api/routes/auth.py
Normal file
31
src/huesoporro/api/routes/auth.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import secrets
|
||||
|
||||
from litestar import MediaType, get
|
||||
from litestar.response import Redirect, Template
|
||||
|
||||
from src.huesoporro.settings import Settings
|
||||
from src.huesoporro.svc.authenticate import CodeAuthenticatorSvc
|
||||
|
||||
|
||||
@get(path="/o/code")
|
||||
async def get_code(code: str, code_authenticator_svc: CodeAuthenticatorSvc) -> Redirect:
|
||||
user = await code_authenticator_svc.run(code)
|
||||
return Redirect("/", cookies={"huesoporroAuth": user.encode()})
|
||||
|
||||
|
||||
@get(
|
||||
"/login",
|
||||
media_type=MediaType.HTML,
|
||||
)
|
||||
async def login(s: Settings) -> Template:
|
||||
scopes = "+".join(s.twitch_scopes)
|
||||
return Template(
|
||||
"login.html",
|
||||
context={
|
||||
"twitch_login_url": "https://id.twitch.tv/oauth2/authorize?response_type=code"
|
||||
f"&client_id={s.twitch_client_id}"
|
||||
f"&redirect_uri={s.server_hostname}o/code"
|
||||
f"&scope={scopes}"
|
||||
f"&state={secrets.token_urlsafe(32)}"
|
||||
},
|
||||
)
|
||||
150
src/huesoporro/bot.py
Normal file
150
src/huesoporro/bot.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
import asyncio
|
||||
|
||||
from loguru import logger
|
||||
from twitchio import Channel
|
||||
from twitchio.ext import commands, routines
|
||||
|
||||
from src.huesoporro.actions.store_quote import StoreQuoteAction
|
||||
from src.huesoporro.infra.db import Database
|
||||
from src.huesoporro.libs.db import Database as MarkovDB
|
||||
from src.huesoporro.models import ChatbotSettings, User
|
||||
from src.huesoporro.svc.generate import SentenceGeneratorSvc
|
||||
from src.huesoporro.svc.get_random_quote import RandomQuoteGetterSvc
|
||||
from src.huesoporro.svc.hello import HelloGeneratorSvc
|
||||
from src.huesoporro.svc.is_mod import IsModSvc
|
||||
from src.huesoporro.svc.store import SentenceStorerSvc
|
||||
from src.huesoporro.svc.store_quote import QuoteStorerSvc
|
||||
|
||||
|
||||
class Bot(commands.Bot):
|
||||
def __init__(self, user: User, chatbot_settings: ChatbotSettings, channel: str):
|
||||
super().__init__(
|
||||
token=user.twitch_auth.access_token, prefix="!", initial_channels=[channel]
|
||||
)
|
||||
self.channel = channel
|
||||
self.user = user
|
||||
self.generate_svc = SentenceGeneratorSvc(db=MarkovDB(channel=channel))
|
||||
self.hello_svc = HelloGeneratorSvc()
|
||||
db = Database()
|
||||
self.store_quote_action = StoreQuoteAction(
|
||||
quote_storer_svc=QuoteStorerSvc(db=db), is_mod_svc=IsModSvc(db=db)
|
||||
)
|
||||
|
||||
self.get_random_quote_svc = RandomQuoteGetterSvc(db=db)
|
||||
|
||||
self.quote_routine = routines.routine(
|
||||
seconds=chatbot_settings.automatic_quote_timer, wait_first=True
|
||||
)(self.send_quote)
|
||||
self.generation_routine = routines.routine(
|
||||
seconds=chatbot_settings.automatic_generation_timer, wait_first=True
|
||||
)(self.send_generation)
|
||||
|
||||
async def event_ready(self):
|
||||
logger.info(f"Logged in as {self.nick}")
|
||||
logger.info(f"User id is {self.user_id}")
|
||||
|
||||
@commands.command()
|
||||
async def hello(self, ctx: commands.Context, user: User | None = None):
|
||||
username = user.name if user else ctx.author.name
|
||||
await ctx.send(self.hello_svc.run(username))
|
||||
|
||||
@commands.command(aliases=["g"])
|
||||
async def generate(self, ctx: commands.Context, *, words: str | None = None):
|
||||
sentence = await self.generate_svc.run(words)
|
||||
if sentence:
|
||||
await ctx.send(sentence)
|
||||
|
||||
@commands.command(aliases=["qadd"])
|
||||
async def add_quote(self, ctx: commands.Context, *, quote: str):
|
||||
# extract author from quote; the author is the last word
|
||||
quote, author = quote.rsplit(" ", 1)
|
||||
await ctx.send(
|
||||
await self.store_quote_action.run(
|
||||
user=self.user,
|
||||
channel=self.channel,
|
||||
quote=quote,
|
||||
author=author,
|
||||
username=ctx.author.name,
|
||||
)
|
||||
)
|
||||
|
||||
@commands.command(aliases=["q", "quote"])
|
||||
async def get_random_quote(self, ctx: commands.Context):
|
||||
quote = await self.get_random_quote_svc.run(channel_name=self.channel)
|
||||
await ctx.send(f"«{quote[0]}» - {quote[1]}")
|
||||
|
||||
def get_channel_conn(self) -> Channel:
|
||||
return Channel(name=self.channel, websocket=self._connection)
|
||||
|
||||
async def send_quote(self):
|
||||
quote = await self.get_random_quote_svc.run(channel_name=self.channel)
|
||||
channel = self.get_channel_conn()
|
||||
logger.info(f"Sending random quote {quote[0]}")
|
||||
await channel.send(f"«{quote[0]}» - {quote[1]}")
|
||||
|
||||
async def send_generation(self):
|
||||
sentence = await self.generate_svc.run()
|
||||
if not sentence:
|
||||
return
|
||||
channel = self.get_channel_conn()
|
||||
logger.info(f"Sending generated sentence {sentence}")
|
||||
await channel.send(sentence)
|
||||
|
||||
def start_routines(self):
|
||||
logger.info("Starting routines")
|
||||
self.quote_routine.start(stop_on_error=False)
|
||||
self.generation_routine.start(stop_on_error=False)
|
||||
|
||||
def stop_routines(self):
|
||||
logger.info("Stopping routines")
|
||||
self.quote_routine.cancel()
|
||||
self.generation_routine.cancel()
|
||||
|
||||
|
||||
class SaveMessagesCog(commands.Cog):
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.store_svc = SentenceStorerSvc(db=MarkovDB(channel=bot.channel))
|
||||
|
||||
@commands.Cog.event()
|
||||
async def event_message(self, message):
|
||||
# An event inside a cog!
|
||||
content = message.content
|
||||
if content.startswith("!"):
|
||||
return
|
||||
|
||||
if not message.author:
|
||||
return
|
||||
|
||||
await self.store_svc.run(content)
|
||||
|
||||
|
||||
class BotsManager:
|
||||
def __init__(self):
|
||||
self.bots: dict[str, Bot] = {}
|
||||
|
||||
def add_bot(self, user: User, channel: str, chatbot_settings: ChatbotSettings):
|
||||
if user.user in self.bots:
|
||||
logger.info(f"Bot for {user.user} already exists")
|
||||
return
|
||||
logger.info(f"Adding bot for {user.user}")
|
||||
bot = Bot(user=user, channel=channel, chatbot_settings=chatbot_settings)
|
||||
bot.add_cog(SaveMessagesCog(bot))
|
||||
self.bots[user.user] = bot
|
||||
|
||||
async def run_user_bot(self, user: User):
|
||||
if user.user not in self.bots:
|
||||
return
|
||||
|
||||
logger.info(f"Starting bot for {user.user}")
|
||||
bot = self.bots[user.user]
|
||||
task = asyncio.create_task(bot.start())
|
||||
task.add_done_callback(lambda x: logger.info(f"Bot for {user.user} stopped"))
|
||||
bot.start_routines()
|
||||
|
||||
async def stop_user_bot(self, user: User):
|
||||
if user.user not in self.bots:
|
||||
return
|
||||
bot = self.bots.pop(user.user)
|
||||
await bot.close()
|
||||
bot.stop_routines()
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
import asyncio
|
||||
from asyncio import sleep as asleep
|
||||
from queue import Queue
|
||||
from time import sleep
|
||||
|
||||
import nltk
|
||||
from litestar import WebSocket
|
||||
from loguru import logger
|
||||
|
||||
from src.huesoporro.libs.markov_chain_bot import MarkovChain
|
||||
from src.huesoporro.libs.settings import Settings as MarkovChainSettings
|
||||
from src.huesoporro.value_objects import WebsocketCommands, WebsocketMessage
|
||||
|
||||
nltk.download("punkt_tab")
|
||||
|
||||
|
||||
class ChatbotManager:
|
||||
def __init__(self):
|
||||
self.bot: MarkovChain | None = None
|
||||
self.clients: set[WebSocket] = set()
|
||||
self.log_queue: Queue = Queue()
|
||||
self.tasks: set = set()
|
||||
|
||||
def start_bot(
|
||||
self,
|
||||
channel_name: str,
|
||||
nickname: str,
|
||||
authentication: str,
|
||||
):
|
||||
task = asyncio.create_task(self.send_bot_status())
|
||||
self.tasks.add(task)
|
||||
if self.bot:
|
||||
return
|
||||
self.bot = MarkovChain(
|
||||
settings=MarkovChainSettings(
|
||||
Channel=channel_name,
|
||||
Nickname=nickname,
|
||||
Authentication=authentication,
|
||||
AutomaticGenerationTimer=300,
|
||||
),
|
||||
)
|
||||
|
||||
self.bot.run_bot()
|
||||
sleep(2)
|
||||
|
||||
def stop_bot(self):
|
||||
self.bot.stop_bot()
|
||||
self.bot = None
|
||||
|
||||
async def send_bot_status(self):
|
||||
while True:
|
||||
for client in self.clients:
|
||||
message = WebsocketMessage(
|
||||
command=WebsocketCommands.CHATBOT_STATUS,
|
||||
data={"status": "ok" if self.bot else "ko"},
|
||||
)
|
||||
await client.send_text(message.model_dump_json())
|
||||
logger.info(
|
||||
f"Sending bot status {message} to {client.client.host}:{client.client.port}"
|
||||
)
|
||||
|
||||
await asleep(2)
|
||||
0
src/huesoporro/infra/__init__.py
Normal file
0
src/huesoporro/infra/__init__.py
Normal file
62
src/huesoporro/infra/authenticator.py
Normal file
62
src/huesoporro/infra/authenticator.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
import httpx
|
||||
from litestar.exceptions import HTTPException
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from src.huesoporro.models import TwitchAuth
|
||||
from src.huesoporro.settings import Settings
|
||||
|
||||
|
||||
class TwitchAuthenticator(BaseModel):
|
||||
s: Settings = Field(default_factory=Settings.get)
|
||||
client: httpx.AsyncClient = Field(
|
||||
default_factory=lambda x: httpx.AsyncClient(base_url="https://id.twitch.tv/")
|
||||
)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
async def get_token(self, code: str, auto_refresh: bool = True) -> TwitchAuth:
|
||||
response = await self.client.post(
|
||||
"/oauth2/token",
|
||||
data={
|
||||
"client_id": Settings.get().twitch_client_id,
|
||||
"client_secret": Settings.get().twitch_client_secret.get_secret_value(),
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": f"{Settings.get().server_hostname}o/code",
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
|
||||
if auto_refresh and response.status_code == 401:
|
||||
return await self.refresh_token(response.json()["refresh_token"])
|
||||
|
||||
response.raise_for_status()
|
||||
return TwitchAuth(**response.json())
|
||||
|
||||
async def refresh_token(self, refresh_token: str) -> TwitchAuth:
|
||||
response = await self.client.post(
|
||||
"/oauth2/token",
|
||||
data={
|
||||
"client_id": Settings.get().twitch_client_id,
|
||||
"client_secret": Settings.get().twitch_client_secret.get_secret_value(),
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return TwitchAuth(**response.json())
|
||||
|
||||
async def validate_token(self, access_token: str) -> str:
|
||||
response = await self.client.get(
|
||||
"/oauth2/validate", headers={"Authorization": f"OAuth {access_token}"}
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_data = response.json()
|
||||
|
||||
if user_data.get("status"):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
if (user := user_data["login"]) not in self.s.allowed_users:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
return user
|
||||
134
src/huesoporro/infra/db.py
Normal file
134
src/huesoporro/infra/db.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
import datetime
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import aiosqlite
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.huesoporro.models import ChatbotSettings, User
|
||||
from src.huesoporro.settings import Settings
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class Database(BaseModel):
|
||||
s: Settings = Field(default_factory=Settings.get)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_client(self, auto_commit=True):
|
||||
logger.info(f"Opening database connection: {self.s.db_filepath}")
|
||||
async with aiosqlite.connect(self.s.db_filepath) as db:
|
||||
yield db
|
||||
if auto_commit:
|
||||
await db.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_now() -> float:
|
||||
return datetime.datetime.now(datetime.UTC).timestamp()
|
||||
|
||||
async def save_user(self, user: User, auto_commit=True):
|
||||
async with self.get_client(auto_commit=auto_commit) as db:
|
||||
async with db.execute(
|
||||
"SELECT * FROM users WHERE user = ?", (user.user,)
|
||||
) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
if result:
|
||||
await db.execute(
|
||||
"UPDATE users SET access_token = ?, refresh_token = ?, expires_at = ?, last_updated_at = ? WHERE user = ?",
|
||||
(
|
||||
user.twitch_auth.access_token,
|
||||
user.twitch_auth.refresh_token,
|
||||
user.expires_at,
|
||||
self.get_now(),
|
||||
user.user,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
await db.execute(
|
||||
"INSERT INTO users (user, access_token, refresh_token, expires_at, last_updated_at) VALUES (?,?,?,?,?)",
|
||||
(
|
||||
user.user,
|
||||
user.twitch_auth.access_token,
|
||||
user.twitch_auth.refresh_token,
|
||||
user.expires_at,
|
||||
self.get_now(),
|
||||
),
|
||||
)
|
||||
|
||||
async def save_quote(self, channel: str, quote: str, author: str, auto_commit=True):
|
||||
async with self.get_client(auto_commit=auto_commit) as db:
|
||||
await db.execute(
|
||||
"INSERT INTO quotes (channel, quote, author) VALUES (?,?,?)",
|
||||
(channel, quote, author),
|
||||
)
|
||||
|
||||
async def save_chatbot_settings(
|
||||
self, user: User, chatbot_settings: ChatbotSettings, auto_commit: bool = True
|
||||
):
|
||||
async with self.get_client(auto_commit=auto_commit) as db:
|
||||
current_settings = await self.get_chatbot_settings(user)
|
||||
if current_settings:
|
||||
await db.execute(
|
||||
"""UPDATE settings SET
|
||||
automatic_generation_timer = ?,
|
||||
automatic_quote_timer = ?,
|
||||
mods = ?,
|
||||
last_updated_at = ?
|
||||
WHERE user_id = ?
|
||||
""",
|
||||
(
|
||||
chatbot_settings.automatic_generation_timer,
|
||||
chatbot_settings.automatic_quote_timer,
|
||||
chatbot_settings.mods_as_string,
|
||||
self.get_now(),
|
||||
user.user,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
await db.execute(
|
||||
"""INSERT INTO settings (
|
||||
user_id,
|
||||
automatic_generation_timer,
|
||||
automatic_quote_timer,
|
||||
mods,
|
||||
created_at,
|
||||
last_updated_at
|
||||
) VALUES(?,?,?,?,?,?)
|
||||
""",
|
||||
(
|
||||
user.user,
|
||||
chatbot_settings.automatic_generation_timer,
|
||||
chatbot_settings.automatic_quote_timer,
|
||||
chatbot_settings.mods_as_string,
|
||||
self.get_now(),
|
||||
self.get_now(),
|
||||
),
|
||||
)
|
||||
|
||||
async def get_chatbot_settings(self, user: User) -> ChatbotSettings | None:
|
||||
async with self.get_client() as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT * FROM settings WHERE user_id = ?", (user.user,)
|
||||
) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
if not result:
|
||||
return None
|
||||
return ChatbotSettings(**dict(result))
|
||||
|
||||
async def save_sentence(self, sentence: str, auto_commit=True):
|
||||
async with self.get_client(auto_commit=auto_commit) as db:
|
||||
await db.execute(
|
||||
"INSERT INTO sentences (sentence) VALUES (?)",
|
||||
(sentence,),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def get_random_quote(self, channel_name: str):
|
||||
async with self.get_client() as db:
|
||||
async with db.execute(
|
||||
"SELECT quote, author FROM quotes WHERE channel = ? ORDER BY RANDOM() LIMIT 1",
|
||||
(channel_name,),
|
||||
) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
return result
|
||||
|
|
@ -1,543 +0,0 @@
|
|||
import string
|
||||
import time
|
||||
from enum import StrEnum
|
||||
|
||||
from loguru import logger
|
||||
from nltk.tokenize import sent_tokenize
|
||||
from TwitchWebsocket import Message, TwitchWebsocket
|
||||
|
||||
from src.huesoporro.libs.db import Database
|
||||
from src.huesoporro.libs.settings import Settings
|
||||
from src.huesoporro.libs.timer import LoopingTimer
|
||||
from src.huesoporro.libs.tokenizer import detokenize, tokenize
|
||||
|
||||
|
||||
class Commands(StrEnum):
|
||||
SET_COOLDOWN = "!setcd"
|
||||
GENERATE = "!g"
|
||||
BLACKLIST = "!blacklist"
|
||||
GENERATE_HELP = "!ghelp"
|
||||
QUOTE = "!q"
|
||||
QUOTE_ADD = "!qadd"
|
||||
|
||||
|
||||
class MarkovChain:
|
||||
end_tag = "<END>"
|
||||
|
||||
def __init__(self, settings: Settings | None = None):
|
||||
self.s = settings or Settings.read()
|
||||
self.prev_message_t = 0.0
|
||||
self._enabled = True
|
||||
|
||||
self.db = Database(self.s.channel_name)
|
||||
|
||||
if self.s.help_message_timer > 0:
|
||||
if self.s.help_message_timer < 300: # noqa: PLR2004
|
||||
raise ValueError(
|
||||
'Value for "HelpMessageTimer" in must be at least 300 seconds, ' # noqa: EM101
|
||||
"or a negative number for no help messages.",
|
||||
)
|
||||
t = LoopingTimer(self.s.help_message_timer, self._command_help)
|
||||
t.start()
|
||||
|
||||
# Set up daemon Timer to send automatic generation messages
|
||||
if self.s.automatic_generation_timer > 0:
|
||||
if self.s.automatic_generation_timer < 30: # noqa: PLR2004
|
||||
raise ValueError(
|
||||
'Value for "Automatic_generation_message" must be at least 30 seconds, or a negative number for no ' # noqa: EM101
|
||||
"automatic generations.",
|
||||
)
|
||||
logger.info(
|
||||
f"Automatic generation enabled, will send messages every {self.s.automatic_generation_timer} seconds"
|
||||
)
|
||||
t = LoopingTimer(
|
||||
self.s.automatic_generation_timer,
|
||||
self._command_automatic_generation,
|
||||
)
|
||||
t.start()
|
||||
|
||||
self.ws = TwitchWebsocket(
|
||||
host=self.s.host,
|
||||
port=self.s.port,
|
||||
chan=self.s.channel_name,
|
||||
nick=self.s.nickname,
|
||||
auth=self.s.authentication,
|
||||
callback=self.message_handler,
|
||||
capability=["commands", "tags"],
|
||||
live=True,
|
||||
)
|
||||
|
||||
def run_bot(self):
|
||||
self.ws.start_nonblocking()
|
||||
|
||||
def stop_bot(self):
|
||||
self.ws.leave_channel(self.s.channel_name)
|
||||
self.ws.stop()
|
||||
logger.info("Stopped bot")
|
||||
|
||||
def _command_help(self) -> None:
|
||||
"""Send a Help message to the connected chat, as long as the bot wasn't disabled."""
|
||||
if self._enabled:
|
||||
logger.info("Help message sent.")
|
||||
try:
|
||||
self.ws.send_message(
|
||||
"Learn how this bot generates sentences here: https://github.com/CubieDev/TwitchMarkovChain#how-it-works",
|
||||
)
|
||||
except OSError as error:
|
||||
logger.warning(
|
||||
f"[OSError: {error}] upon sending help message. Ignoring.",
|
||||
)
|
||||
|
||||
def _command_set_cooldown(self, username: str, split_message: list[str]):
|
||||
if len(split_message) == 2: # noqa: PLR2004
|
||||
try:
|
||||
cooldown = int(split_message[1])
|
||||
except ValueError:
|
||||
self.ws.send_whisper(
|
||||
username,
|
||||
"The parameter must be an integer amount, eg: !setcd 30",
|
||||
)
|
||||
return
|
||||
self.s.cooldown = cooldown
|
||||
self.s.write()
|
||||
self.ws.send_whisper(
|
||||
username,
|
||||
f"The !generate cooldown has been set to {cooldown} seconds.",
|
||||
)
|
||||
|
||||
def _command_blacklist(self, username: str, split_message: list[str]):
|
||||
if len(split_message) == 2: # noqa: PLR2004
|
||||
try:
|
||||
blacklisted_username = split_message[1]
|
||||
except ValueError:
|
||||
self.ws.send_whisper(
|
||||
username,
|
||||
"The parameter must be a username, eg: !blacklist ibai",
|
||||
)
|
||||
return
|
||||
self.s.denied_users.append(blacklisted_username)
|
||||
self.s.write()
|
||||
|
||||
def _command_generate(self, username: str, message: str):
|
||||
cur_time = time.time()
|
||||
if self.prev_message_t + self.s.cooldown >= cur_time:
|
||||
if not self.db.check_whisper_ignore(username):
|
||||
self.send_whisper(
|
||||
username,
|
||||
f"Cooldown hit: {self.prev_message_t + self.s.cooldown - cur_time:0.2f} out of {self.s.cooldown:.0f}s remaining. !nopm to stop these cooldown pm's.",
|
||||
)
|
||||
logger.info(
|
||||
f"Cooldown hit with {self.prev_message_t + self.s.cooldown - cur_time:0.2f}s remaining.",
|
||||
)
|
||||
params = tokenize(message)[2:] if self.s.allow_generate_params else None
|
||||
# Generate an actual sentence
|
||||
sentence, success = self.generate(params)
|
||||
if success:
|
||||
# Reset cooldown if a message was actually generated
|
||||
self.prev_message_t = time.time()
|
||||
logger.info(sentence)
|
||||
self.ws.send_message(sentence)
|
||||
|
||||
self.store_sentence(message)
|
||||
|
||||
def _command_automatic_generation(self) -> None:
|
||||
"""Send an automatic generation message to the connected chat.
|
||||
|
||||
As long as the bot wasn't disabled, just like if someone typed "!g" in chat.
|
||||
"""
|
||||
if self._enabled:
|
||||
logger.debug("Automatically generating message")
|
||||
sentence, success = self.generate()
|
||||
if success:
|
||||
logger.info(
|
||||
f"Created '{sentence}'. Cooling down for {self.s.automatic_generation_timer} seconds before regenerating",
|
||||
)
|
||||
try:
|
||||
self.ws.send_message(sentence)
|
||||
except OSError as error:
|
||||
logger.warning(
|
||||
f"[OSError: {error}] upon sending automatic generation message. Ignoring.",
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Attempted to output automatic generation message, but there is not enough learned information yet.",
|
||||
)
|
||||
|
||||
def _command_quote(self):
|
||||
"""Retrieve a random quote from the `quotes` table and format it as
|
||||
|
||||
> «<quote>» - <author>
|
||||
"""
|
||||
data = self.db.execute(
|
||||
"SELECT quote, author FROM quotes ORDER BY RANDOM() LIMIT 1;", fetch=True
|
||||
)
|
||||
if data:
|
||||
data = data[0]
|
||||
quote, author = data[0], data[1]
|
||||
self.ws.send_message(f"«{quote}» - {author}")
|
||||
|
||||
def _command_add_quote(self, message: str):
|
||||
"""Add a quote to the quotes table. The message should follow the format:
|
||||
|
||||
!qadd quote author
|
||||
|
||||
The last word will be parsed as the author and anything in between !qadd and the author will be considered
|
||||
as the quote itself
|
||||
"""
|
||||
# Split the message into quote and author
|
||||
parts = message.split()
|
||||
author = parts[-1]
|
||||
quote = " ".join(parts[1:-1])
|
||||
|
||||
data = self.db.execute(
|
||||
"SELECT 1 FROM quotes WHERE quote = ?", (quote,), fetch=True
|
||||
)
|
||||
if data:
|
||||
self.ws.send_message(f"Quote «{quote}» was already added.")
|
||||
return
|
||||
|
||||
self.db.execute(
|
||||
"INSERT INTO quotes (quote, author) VALUES (?, ?)",
|
||||
(quote, author), # type: ignore[arg-type]
|
||||
)
|
||||
self.ws.send_message(f"Quote «{quote}» by {author} added.")
|
||||
|
||||
def store_sentence(self, message: str):
|
||||
logger.info(f"Processing {message} in order to store it")
|
||||
stripped_message = message.strip()
|
||||
try:
|
||||
sentences = sent_tokenize(stripped_message)
|
||||
except LookupError:
|
||||
logger.debug("Downloading required punkt resource...")
|
||||
import nltk
|
||||
|
||||
nltk.download("punkt")
|
||||
logger.debug("Downloaded required punkt resource.")
|
||||
sentences = sent_tokenize(stripped_message)
|
||||
|
||||
for sentence in sentences:
|
||||
words = tokenize(sentence)
|
||||
# Double spaces will lead to invalid rules. We remove empty words here
|
||||
if "" in words:
|
||||
words = [word for word in words if word]
|
||||
|
||||
# If the sentence is too short, ignore it and move on to the next.
|
||||
if len(words) <= self.s.key_length:
|
||||
continue
|
||||
|
||||
# Add a new starting point for a sentence to the <START>
|
||||
words = [words[x] for x in range(self.s.key_length)]
|
||||
logger.debug(f"Adding {words} to start queue")
|
||||
self.db.add_start_queue(words)
|
||||
|
||||
# Create Key variable which will be used as a key in the Dictionary for the grammar
|
||||
key: list[str] = []
|
||||
for word in words:
|
||||
# Set up key for first use
|
||||
if len(key) < self.s.key_length:
|
||||
key.append(word)
|
||||
continue
|
||||
logger.debug(f"Adding {key}[{word}] to rule queue")
|
||||
self.db.add_rule_queue([*key, word])
|
||||
|
||||
# Remove the first word, and add the current word,
|
||||
# so that the key is correct for the next word.
|
||||
key.pop(0)
|
||||
key.append(word)
|
||||
logger.debug(f"Adding {key} to rule queue")
|
||||
# Add <END> at the end of the sentence
|
||||
self.db.add_rule_queue([*key, self.end_tag])
|
||||
|
||||
def message_handler(self, message: Message): # noqa: C901, PLR0911, PLR0912
|
||||
try:
|
||||
"""
|
||||
tts_message = {
|
||||
"badge-info": "subscriber/4",
|
||||
"badges": "vip/1,subscriber/3,sub-gifter/5",
|
||||
"color": "#F79AC6",
|
||||
"custom-reward-id": "8c454446-73b0-480f-946e-d6b5f5c5e331",
|
||||
"display-name": "robosap1ens__",
|
||||
"emotes": "",
|
||||
"first-msg": "0",
|
||||
"flags": "",
|
||||
"id": "6cbd37eb-49ae-41f5-b073-345275c91a07",
|
||||
"mod": "0",
|
||||
"returning-chatter": "0",
|
||||
"room-id": "600944302",
|
||||
"subscriber": "1",
|
||||
"tmi-sent-ts": "1733252657689",
|
||||
"turbo": "0",
|
||||
"user-id": "713968248",
|
||||
"user-type": "",
|
||||
"vip": "1",
|
||||
}
|
||||
"""
|
||||
if not message.user or message.user in self.s.denied_users:
|
||||
logger.debug(f"User {message.user} can't send messages")
|
||||
return
|
||||
|
||||
msgs = message.message.split()
|
||||
if not msgs:
|
||||
logger.debug("Message is empty")
|
||||
return
|
||||
|
||||
if "bits" in message.tags:
|
||||
return
|
||||
|
||||
if "emotes" in message.tags:
|
||||
# Replace modified emotes with normal versions,
|
||||
# as the bot will never have the modified emotes unlocked at the time.
|
||||
for modifier in self.extract_modifiers(message.tags["emotes"]):
|
||||
message.message = message.message.replace(modifier, "")
|
||||
|
||||
logger.debug(f"Received {msgs[0]} command from {message.user}")
|
||||
match msgs[0]:
|
||||
case Commands.GENERATE_HELP:
|
||||
logger.debug("Executing _command_help()")
|
||||
self._command_help()
|
||||
|
||||
case Commands.SET_COOLDOWN:
|
||||
if self.is_mod(message.user, message.channel):
|
||||
logger.debug(
|
||||
f"User {message.user} is mod, executing _command_set_cooldown()",
|
||||
)
|
||||
self._command_set_cooldown(
|
||||
split_message=msgs,
|
||||
username=message.user,
|
||||
)
|
||||
|
||||
case Commands.BLACKLIST:
|
||||
if self.is_mod(message.user, message.channel):
|
||||
logger.debug(
|
||||
f"User {message.user} is a mod, executing _command_blacklist()",
|
||||
)
|
||||
self._command_blacklist(
|
||||
split_message=msgs,
|
||||
username=message.user,
|
||||
)
|
||||
|
||||
case Commands.GENERATE:
|
||||
if not self._enabled:
|
||||
logger.info("Bot not enabled, skipping")
|
||||
return
|
||||
if message.user not in self.s.denied_users:
|
||||
logger.info(
|
||||
f"User {message.user} allowed to generate, executing _command_generate()",
|
||||
)
|
||||
self._command_generate(
|
||||
message=message.message,
|
||||
username=message.user,
|
||||
)
|
||||
|
||||
case Commands.QUOTE:
|
||||
if not self._enabled:
|
||||
logger.info("Bot not enabled, skipping")
|
||||
return
|
||||
if message.user not in self.s.denied_users:
|
||||
logger.info(
|
||||
f"User {message.user} allowed to generate, executing _command_quote()",
|
||||
)
|
||||
self._command_quote()
|
||||
|
||||
case Commands.QUOTE_ADD:
|
||||
if self.is_mod(message.user, message.channel):
|
||||
logger.info(
|
||||
f"User {message.user} allowed to create quote, executing _command_quote()",
|
||||
)
|
||||
self._command_add_quote(message.message)
|
||||
return
|
||||
self.ws.send_message(
|
||||
f"@{message.user} you're not in the modlist, you can't add quotes"
|
||||
)
|
||||
|
||||
case _:
|
||||
logger.debug(
|
||||
f"Not a command: {msgs[0]}. Storing into db as a plain message",
|
||||
)
|
||||
if message.type == "366":
|
||||
logger.info(f"Successfully joined channel: #{message.channel}")
|
||||
return
|
||||
self.store_sentence(message.message)
|
||||
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(f"Could not process message {message}")
|
||||
|
||||
def generate(self, params: list[str] | None = None) -> tuple[str, bool]: # noqa: C901, PLR0912
|
||||
"""Given an input sentence, generate the remainder of the sentence using the learned data.
|
||||
|
||||
Args:
|
||||
params (list[str]): A list of words to use as an input to use as the start of generating.
|
||||
|
||||
Returns:
|
||||
tuple[str, bool]: A tuple of a sentence as the first value, and a boolean indicating
|
||||
whether the generation succeeded as the second value.
|
||||
"""
|
||||
params = params or []
|
||||
|
||||
# List of sentences that will be generated. In some cases, multiple sentences will be generated,
|
||||
# e.g. when the first sentence has less words than self.min_sentence_length.
|
||||
sentences: list[list | list[str]] = [[]]
|
||||
|
||||
# Check for commands or recursion, eg: !generate !generate
|
||||
if len(params) > 0 and self.is_command(params[0]):
|
||||
return "You can't make me do commands, you madman!", False
|
||||
|
||||
# Get the starting key and starting sentence.
|
||||
# If there is more than 1 param, get the last 2 as the key.
|
||||
# Note that self.s.key_length is fixed to 2 in this implementation
|
||||
if len(params) > 1:
|
||||
key = params[-self.s.key_length :]
|
||||
# Copy the entire params for the sentence
|
||||
sentences[0] = params.copy()
|
||||
|
||||
elif len(params) == 1:
|
||||
# First we try to find if this word was once used as the first word in a sentence:
|
||||
key = self.db.get_next_single_start(params[0]) # type: ignore[assignment]
|
||||
if key is None:
|
||||
# If this failed, we try to find the next word in the grammar as a whole
|
||||
key = self.db.get_next_single_initial(0, params[0])
|
||||
if key is None:
|
||||
# Return a message that this word hasn't been learned yet
|
||||
return f'I haven\'t extracted "{params[0]}" from chat yet.', False
|
||||
# Copy this for the sentence
|
||||
sentences[0] = key.copy()
|
||||
|
||||
else: # if there are no params
|
||||
# Get starting key
|
||||
key = self.db.get_start()
|
||||
if key:
|
||||
# Copy this for the sentence
|
||||
sentences[0] = key.copy()
|
||||
else:
|
||||
# If nothing's ever been said
|
||||
return "There is not enough learned information yet.", False
|
||||
|
||||
# Counter to prevent infinite loops (i.e. constantly generating <END> while below the
|
||||
# minimum number of words to generate)
|
||||
i = 0
|
||||
while (
|
||||
self.get_sentence_length(sentences) < self.s.max_sentence_length
|
||||
and i < self.s.max_sentence_length * 2
|
||||
):
|
||||
# Use key to get next word
|
||||
if i == 0:
|
||||
# Prevent fetching <END> on the first word
|
||||
word = self.db.get_next_initial(i, key)
|
||||
else:
|
||||
word = self.db.get_next(i, key)
|
||||
|
||||
i += 1
|
||||
|
||||
if word == "<END>" or word is None:
|
||||
# Break, unless we are before the min_sentence_length
|
||||
if i < self.s.min_sentence_length:
|
||||
key = self.db.get_start()
|
||||
# Ensure that the key can be generated. Otherwise, we still stop.
|
||||
if key:
|
||||
# Start a new sentence
|
||||
sentences.append([])
|
||||
for entry in key:
|
||||
sentences[-1].append(entry)
|
||||
continue
|
||||
break
|
||||
|
||||
# Otherwise add the word
|
||||
sentences[-1].append(word)
|
||||
|
||||
# Shift the key so on the next iteration it gets the next item
|
||||
key.pop(0)
|
||||
key.append(word)
|
||||
|
||||
# If there were params, but the sentence resulting is identical to the params
|
||||
# Then the params did not result in an actual sentence
|
||||
# If so, restart without params
|
||||
if len(params) > 0 and params == sentences[0]:
|
||||
return "I haven't learned what to do with \"" + detokenize(
|
||||
params[-self.s.key_length :],
|
||||
) + '" yet.', False
|
||||
|
||||
return self.s.sentence_separator.join(
|
||||
detokenize(sentence) for sentence in sentences
|
||||
), True
|
||||
|
||||
@staticmethod
|
||||
def get_sentence_length(sentences: list[list[str]]) -> int:
|
||||
"""Given a list of tokens representing a sentence, return the number of words in there.
|
||||
|
||||
Args:
|
||||
sentences (List[List[str]]): List of lists of tokens that make up a sentence,
|
||||
where a token is a word or punctuation. For example:
|
||||
[['Hello', ',', 'you', "'re", 'Tom', '!'], ['Yes', ',', 'I', 'am', '.']]
|
||||
This would return 6.
|
||||
|
||||
Returns:
|
||||
int: The number of words in the sentence.
|
||||
"""
|
||||
count = 0
|
||||
for sentence in sentences:
|
||||
for token in sentence:
|
||||
if token not in string.punctuation and token[0] != "'":
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def extract_modifiers(emotes: str) -> list[str]:
|
||||
"""Extract emote modifiers from emotes such as the horizontal flip.
|
||||
|
||||
Args:
|
||||
emotes (str): String containing all emotes used in the message.
|
||||
|
||||
Returns:
|
||||
list[str]: List of strings that show modifiers, such as "_HZ" for horizontal flip.
|
||||
"""
|
||||
output = []
|
||||
try:
|
||||
while emotes:
|
||||
u_index = emotes.index("_")
|
||||
c_index = emotes.index(":", u_index)
|
||||
output.append(emotes[u_index:c_index])
|
||||
emotes = emotes[c_index:]
|
||||
except ValueError:
|
||||
pass
|
||||
return output
|
||||
|
||||
def send_whisper(self, user: str, message: str) -> None:
|
||||
"""Optionally send a whisper, only if "WhisperCooldown" is True.
|
||||
|
||||
Args:
|
||||
user (str): The user to potentially whisper.
|
||||
message (str): The message to potentially whisper
|
||||
"""
|
||||
if self.s.whisper_cooldown:
|
||||
self.ws.send_whisper(user, message)
|
||||
|
||||
@staticmethod
|
||||
def is_command(message: str) -> bool:
|
||||
"""True if the message is any command, except /me.
|
||||
|
||||
Is used to avoid learning and generating commands.
|
||||
|
||||
Args:
|
||||
message (str): The message to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the message is any potential command (starts with a '!', '/' or '.')
|
||||
except /me.
|
||||
"""
|
||||
return message in list(Commands)
|
||||
|
||||
def is_mod(self, username: str, channel: str) -> bool:
|
||||
"""True if the user is a moderator.
|
||||
|
||||
Args:
|
||||
username (str): The name of the user to check
|
||||
channel (str): The name of the channel
|
||||
|
||||
Returns:
|
||||
bool: True if the user is a moderator.
|
||||
"""
|
||||
return username in self.s.mods or username == channel
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
MarkovChain()
|
||||
|
|
@ -1,118 +0,0 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import platformdirs
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
host: str = Field("irc.chat.twitch.tv", alias="Host", serialization_alias="Host")
|
||||
port: int = Field(6667, alias="Port", serialization_alias="Port")
|
||||
channel: str = Field(..., alias="Channel", serialization_alias="Channel")
|
||||
nickname: str = Field(..., alias="Nickname", serialization_alias="Nickname")
|
||||
authentication: str = Field(
|
||||
...,
|
||||
alias="Authentication",
|
||||
serialization_alias="Authentication",
|
||||
)
|
||||
denied_users: list[str] = Field(
|
||||
[
|
||||
"StreamElements",
|
||||
"Nightbot",
|
||||
"Moobot",
|
||||
"Marbiebot",
|
||||
],
|
||||
alias="DeniedUsers",
|
||||
serialization_alias="DeniedUsers",
|
||||
)
|
||||
banned_words: list[str] = Field(
|
||||
default_factory=list,
|
||||
alias="BannedWords",
|
||||
serialization_alias="BannedWords",
|
||||
)
|
||||
mods: list[str] = Field(
|
||||
default_factory=list,
|
||||
alias="Mods",
|
||||
serialization_alias="Mods",
|
||||
)
|
||||
cooldown: int = Field(210, alias="Cooldown", serialization_alias="Cooldown")
|
||||
key_length: int = Field(2, alias="KeyLength", serialization_alias="KeyLength")
|
||||
max_sentence_length: int = Field(
|
||||
25,
|
||||
alias="MaxSentenceWordAmount",
|
||||
serialization_alias="MaxSentenceWordAmount",
|
||||
)
|
||||
min_sentence_length: int = Field(
|
||||
-1,
|
||||
alias="MinSentenceWordAmount",
|
||||
serialization_alias="MinSentenceWordAmount",
|
||||
)
|
||||
help_message_timer: int = Field(
|
||||
60 * 60 * 5,
|
||||
alias="HelpMessageTimer",
|
||||
serialization_alias="HelpMessageTimer",
|
||||
)
|
||||
automatic_generation_timer: int = Field(
|
||||
-1,
|
||||
alias="AutomaticGenerationTimer",
|
||||
serialization_alias="AutomaticGenerationTimer",
|
||||
)
|
||||
whisper_cooldown: bool = Field(
|
||||
True,
|
||||
alias="WhisperCooldown",
|
||||
serialization_alias="WhisperCooldown",
|
||||
)
|
||||
enable_generate_command: bool = Field(
|
||||
True,
|
||||
alias="EnableGenerateCommand",
|
||||
serialization_alias="EnableGenerateCommand",
|
||||
)
|
||||
sentence_separator: str = Field(
|
||||
" - ",
|
||||
alias="SentenceSeparator",
|
||||
serialization_alias="SentenceSeparator",
|
||||
)
|
||||
allow_generate_params: bool = Field(
|
||||
True,
|
||||
alias="AllowGenerateParams",
|
||||
serialization_alias="AllowGenerateParams",
|
||||
)
|
||||
log_level: Literal[
|
||||
"CRITICAL",
|
||||
"ERROR",
|
||||
"WARNING",
|
||||
"INFO",
|
||||
"DEBUG",
|
||||
"TRACE",
|
||||
] = Field("DEBUG", alias="LogLevel")
|
||||
model_config = SettingsConfigDict(extra="ignore")
|
||||
|
||||
@property
|
||||
def channel_name(self):
|
||||
return self.channel.replace("#", "").lower()
|
||||
|
||||
@classmethod
|
||||
def read(cls, filepath: Path | None = None) -> "Settings":
|
||||
if not filepath:
|
||||
filepath = (
|
||||
platformdirs.user_config_path("markovbot_gui", ensure_exists=True)
|
||||
/ "settings.json"
|
||||
)
|
||||
|
||||
with filepath.open("r") as f:
|
||||
data = json.load(f)
|
||||
return Settings(**data)
|
||||
|
||||
def write(self, filepath: Path | None = None):
|
||||
if not filepath:
|
||||
filepath = (
|
||||
platformdirs.user_config_path("markovbot_gui", ensure_exists=True)
|
||||
/ "settings.json"
|
||||
)
|
||||
|
||||
with filepath.open("w") as f:
|
||||
logger.info(f"Writing current settings to {filepath}")
|
||||
json.dump(self.model_dump(by_alias=True), f, indent=4)
|
||||
|
|
@ -1,248 +1,7 @@
|
|||
import json
|
||||
import secrets
|
||||
from json import JSONDecodeError
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from litestar import Litestar, MediaType, Request, Response, WebSocket, get
|
||||
from litestar.connection import ASGIConnection
|
||||
from litestar.contrib.jinja import JinjaTemplateEngine
|
||||
from litestar.datastructures.state import State
|
||||
from litestar.di import Provide
|
||||
from litestar.exceptions import HTTPException
|
||||
from litestar.handlers import BaseRouteHandler, WebsocketListener
|
||||
from litestar.response import Redirect, Template
|
||||
from litestar.static_files import StaticFilesConfig
|
||||
from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR
|
||||
from litestar.template import TemplateConfig
|
||||
from loguru import logger
|
||||
|
||||
from src.huesoporro.chatbot import ChatbotManager
|
||||
from src.huesoporro.api.main import create_app
|
||||
from src.huesoporro.settings import Settings
|
||||
from src.huesoporro.tts import TTSManager
|
||||
from src.huesoporro.value_objects import WebsocketCommands, WebsocketMessage
|
||||
|
||||
|
||||
async def _authenticate(access_token: str):
|
||||
s = Settings.get()
|
||||
client = httpx.AsyncClient(
|
||||
base_url="https://id.twitch.tv",
|
||||
)
|
||||
|
||||
resp = await client.get(
|
||||
"/oauth2/validate", headers={"Authorization": f"OAuth {access_token}"}
|
||||
)
|
||||
user_data = resp.json()
|
||||
|
||||
if user_data.get("status"):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
if (user := user_data["login"]) not in s.allowed_users:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def authenticate(
|
||||
connection: ASGIConnection, route_handler: BaseRouteHandler
|
||||
) -> None:
|
||||
"""Extract cookie from connection and try to authenticate"""
|
||||
|
||||
try:
|
||||
login_data = json.loads(connection.cookies.get("twitchLoginData"))
|
||||
except (JSONDecodeError, TypeError) as exc:
|
||||
logger.warning(f"Error parsing twitch login data: {exc}")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized") from exc
|
||||
|
||||
access_token = login_data.get("access_token")
|
||||
if not login_data or not access_token:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
user = await _authenticate(access_token)
|
||||
|
||||
connection.state["user"] = user
|
||||
connection.state["access_token"] = access_token
|
||||
|
||||
|
||||
class WebsocketHandler(WebsocketListener):
|
||||
path = "/ws"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tts_manager = TTSManager()
|
||||
self.chatbot_manager = ChatbotManager()
|
||||
self.user = None
|
||||
self.access_token = None
|
||||
|
||||
async def on_accept(self, socket: WebSocket, state: State) -> None:
|
||||
"""If the authentication is correct, add the manager's clients list"""
|
||||
|
||||
cookies = socket.cookies.get("twitchLoginData")
|
||||
try:
|
||||
access_token = json.loads(cookies).get("access_token")
|
||||
except (JSONDecodeError, TypeError) as exc:
|
||||
logger.warning(f"Error parsing twitch login data {exc}")
|
||||
return
|
||||
if not access_token:
|
||||
return
|
||||
user = await _authenticate(access_token)
|
||||
|
||||
self.user = user
|
||||
self.access_token = access_token
|
||||
self.chatbot_manager.clients.add(socket)
|
||||
self.tts_manager.clients.append(socket)
|
||||
|
||||
logger.info(
|
||||
f"Connection accepted from {socket.client.host}:{socket.client.port}" # type: ignore[union-attr]
|
||||
)
|
||||
|
||||
async def on_disconnect(self, socket: WebSocket) -> None:
|
||||
# Remove client from the list
|
||||
if socket in self.tts_manager.clients:
|
||||
self.tts_manager.clients.remove(socket)
|
||||
self.chatbot_manager.clients.remove(socket)
|
||||
logger.info(f"Connection closed by {socket.client.host}:{socket.client.port}") # type: ignore[union-attr]
|
||||
|
||||
async def on_receive(self, data: str, state: State) -> None:
|
||||
message = WebsocketMessage(**json.loads(data))
|
||||
logger.info(f"Received {message.command.value} command")
|
||||
|
||||
match message.command:
|
||||
case WebsocketCommands.TTS_SEND:
|
||||
await self.tts_manager.add_to_queue(**message.data)
|
||||
case WebsocketCommands.CHATBOT_START:
|
||||
self.chatbot_manager.start_bot(
|
||||
**message.data
|
||||
| {
|
||||
"nickname": self.user,
|
||||
"authentication": f"oauth:{self.access_token}",
|
||||
},
|
||||
)
|
||||
case WebsocketCommands.CHATBOT_STOP:
|
||||
self.chatbot_manager.stop_bot()
|
||||
|
||||
|
||||
@get(
|
||||
"/tts",
|
||||
media_type=MediaType.HTML,
|
||||
guards=[authenticate],
|
||||
)
|
||||
async def get_tts_overlay() -> Template:
|
||||
return Template(template_name="tts.html")
|
||||
|
||||
|
||||
@get(
|
||||
"/tts/permalink",
|
||||
media_type=MediaType.HTML,
|
||||
)
|
||||
async def get_tts_permalink(access_token: str) -> Template:
|
||||
"""Handler for the /tts permalink endpoint to be used by apps that can only give the authentication as a query
|
||||
param and not as a cookie, i.e. OBS"""
|
||||
|
||||
# authenticate the user using the provided access token
|
||||
await _authenticate(access_token)
|
||||
|
||||
return Template(
|
||||
template_name="tts.html",
|
||||
)
|
||||
|
||||
|
||||
@get(
|
||||
"/",
|
||||
media_type=MediaType.HTML,
|
||||
guards=[authenticate],
|
||||
)
|
||||
async def get_index() -> Template:
|
||||
return Template(
|
||||
template_name="index.html",
|
||||
)
|
||||
|
||||
|
||||
@get("/login", media_type=MediaType.HTML, dependencies={"s": Provide(Settings.get)})
|
||||
async def login(s: Settings) -> Template:
|
||||
scopes = "+".join(s.twitch_scopes)
|
||||
return Template(
|
||||
"login.html",
|
||||
context={
|
||||
"twitch_login_url": "https://id.twitch.tv/oauth2/authorize?response_type=token"
|
||||
f"&client_id={s.twitch_client_id}"
|
||||
f"&redirect_uri={s.server_hostname}login"
|
||||
f"&scope={scopes}"
|
||||
f"&state={secrets.token_urlsafe(32)}"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@get("/healthz")
|
||||
def get_health() -> dict:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@get("/lefunny")
|
||||
def get_lefunny() -> Template:
|
||||
return Template(
|
||||
template_name="lefunny.html",
|
||||
context={"sentences": [{"sentence": "Hola huesoperro", "id": 1}]},
|
||||
)
|
||||
|
||||
|
||||
def exception_handler(_: Request, exc: Exception) -> Response:
|
||||
status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
detail = getattr(exc, "detail", "")
|
||||
|
||||
if isinstance(exc, HTTPException) and (exc.status_code in [401, 403]):
|
||||
logger.warning("User could not authenticate. Redirecting to /login page")
|
||||
return Redirect("/login")
|
||||
|
||||
return Response(
|
||||
media_type=MediaType.TEXT,
|
||||
content=detail,
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
|
||||
async def after_exception_handler(exc: Exception, scope: "Scope") -> None:
|
||||
"""Hook function that will be invoked after each exception."""
|
||||
state = scope["app"].state
|
||||
if not hasattr(state, "error_count"):
|
||||
state.error_count = 1
|
||||
else:
|
||||
state.error_count += 1
|
||||
|
||||
logger.error(
|
||||
f"an exception of type {type(exc).__name__} has occurred for requested path {scope['path']} and the application error count is {state.error_count}.",
|
||||
)
|
||||
|
||||
|
||||
def create_app():
|
||||
return Litestar(
|
||||
route_handlers=[
|
||||
get_health,
|
||||
login,
|
||||
get_index,
|
||||
get_tts_overlay,
|
||||
get_tts_permalink,
|
||||
get_lefunny,
|
||||
WebsocketHandler,
|
||||
],
|
||||
static_files_config=(
|
||||
StaticFilesConfig(
|
||||
path="/tts_files",
|
||||
directories=[Settings.get().tts_cache_path],
|
||||
),
|
||||
StaticFilesConfig(
|
||||
path="static",
|
||||
directories=[Settings.get().static_files_path],
|
||||
),
|
||||
),
|
||||
template_config=TemplateConfig(
|
||||
directory=Settings.get().templates_files_path,
|
||||
engine=JinjaTemplateEngine,
|
||||
),
|
||||
exception_handlers={HTTPException: exception_handler},
|
||||
after_exception=[after_exception_handler],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
settings = Settings.get()
|
||||
|
|
|
|||
50
src/huesoporro/models.py
Normal file
50
src/huesoporro/models.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from typing import Self
|
||||
|
||||
import jwt
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from src.huesoporro.settings import Settings
|
||||
|
||||
|
||||
class TwitchAuth(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
user: str
|
||||
expires_at: float
|
||||
twitch_auth: TwitchAuth
|
||||
|
||||
def encode(self, settings: Settings | None = None) -> str:
|
||||
s = settings or Settings.get()
|
||||
return jwt.encode(
|
||||
self.model_dump(),
|
||||
key=s.jwt_secret.get_secret_value(),
|
||||
algorithm="HS256",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def decode(cls, token: str, settings: Settings | None = None) -> Self:
|
||||
s = settings or Settings.get()
|
||||
decoded = jwt.decode(
|
||||
token, key=s.jwt_secret.get_secret_value(), algorithms=["HS256"]
|
||||
)
|
||||
return cls(**decoded)
|
||||
|
||||
|
||||
class ChatbotSettings(BaseModel):
|
||||
automatic_generation_timer: int = 300
|
||||
automatic_quote_timer: int = 500
|
||||
mods: list[str] | None = None
|
||||
|
||||
@property
|
||||
def mods_as_string(self):
|
||||
return ",".join(self.mods)
|
||||
|
||||
@field_validator("mods", mode="before")
|
||||
@classmethod
|
||||
def format_mods_from_string(cls, v):
|
||||
if isinstance(v, str):
|
||||
return v.split(",")
|
||||
return v
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, HttpUrl, field_validator
|
||||
from pydantic import Field, HttpUrl, SecretStr, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
|
|
@ -21,6 +21,8 @@ class Settings(BaseSettings):
|
|||
default_factory=lambda: Path(__file__).parent / "huesoporro.db"
|
||||
)
|
||||
twitch_client_id: str
|
||||
twitch_client_secret: SecretStr
|
||||
jwt_secret: SecretStr
|
||||
twitch_scopes: list[str] = Field(
|
||||
default_factory=lambda: ["channel:bot", "chat:edit", "chat:read"]
|
||||
)
|
||||
|
|
|
|||
0
src/huesoporro/svc/__init__.py
Normal file
0
src/huesoporro/svc/__init__.py
Normal file
26
src/huesoporro/svc/authenticate.py
Normal file
26
src/huesoporro/svc/authenticate.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.huesoporro.infra.authenticator import TwitchAuthenticator
|
||||
from src.huesoporro.infra.db import Database
|
||||
from src.huesoporro.models import User
|
||||
|
||||
|
||||
class CodeAuthenticatorSvc(BaseModel):
|
||||
db: Database
|
||||
authenticator: TwitchAuthenticator
|
||||
|
||||
@staticmethod
|
||||
def get_four_hours_from_now() -> float:
|
||||
now = datetime.datetime.now(datetime.UTC)
|
||||
four_hours_later = now + datetime.timedelta(hours=4)
|
||||
return four_hours_later.timestamp()
|
||||
|
||||
async def run(self, code: str) -> User:
|
||||
auth = await self.authenticator.get_token(code)
|
||||
username = await self.authenticator.validate_token(auth.access_token)
|
||||
expires_at = self.get_four_hours_from_now()
|
||||
user = User(user=username, expires_at=expires_at, twitch_auth=auth)
|
||||
await self.db.save_user(user)
|
||||
return user
|
||||
158
src/huesoporro/svc/generate.py
Normal file
158
src/huesoporro/svc/generate.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
import string
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from src.huesoporro.libs.db import Database as MarkovDB
|
||||
from src.huesoporro.libs.tokenizer import detokenize, tokenize
|
||||
|
||||
|
||||
class SentenceGeneratorSvc(BaseModel):
|
||||
db: MarkovDB
|
||||
min_sentence_length: int = 2
|
||||
key_length: int = 2
|
||||
max_sentence_length: int = 25
|
||||
sentence_separator: str = " "
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def is_mod(self, username: str, channel: str) -> bool:
|
||||
"""True if the user is a moderator.
|
||||
|
||||
Args:
|
||||
username (str): The name of the user to check
|
||||
channel (str): The name of the channel
|
||||
|
||||
Returns:
|
||||
bool: True if the user is a moderator.
|
||||
"""
|
||||
return username in self.s.mods or username == channel
|
||||
|
||||
@staticmethod
|
||||
def get_sentence_length(sentences: list[list[str]]) -> int:
|
||||
"""Given a list of tokens representing a sentence, return the number of words in there.
|
||||
|
||||
Args:
|
||||
sentences (List[List[str]]): List of lists of tokens that make up a sentence,
|
||||
where a token is a word or punctuation. For example:
|
||||
[['Hello', ',', 'you', "'re", 'Tom', '!'], ['Yes', ',', 'I', 'am', '.']]
|
||||
This would return 6.
|
||||
|
||||
Returns:
|
||||
int: The number of words in the sentence.
|
||||
"""
|
||||
count = 0
|
||||
for sentence in sentences:
|
||||
for token in sentence:
|
||||
if token not in string.punctuation and token[0] != "'":
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def generate(self, params: list[str] | None = None) -> tuple[str, bool]: # noqa: C901, PLR0912
|
||||
"""Given an input sentence, generate the remainder of the sentence using the learned data.
|
||||
|
||||
Args:
|
||||
params (list[str]): A list of words to use as an input to use as the start of generating.
|
||||
|
||||
Returns:
|
||||
tuple[str, bool]: A tuple of a sentence as the first value, and a boolean indicating
|
||||
whether the generation succeeded as the second value.
|
||||
"""
|
||||
params = params or []
|
||||
|
||||
# List of sentences that will be generated. In some cases, multiple sentences will be generated,
|
||||
# e.g. when the first sentence has fewer words than self.min_sentence_length.
|
||||
sentences: list[list | list[str]] = [[]]
|
||||
|
||||
# Check for commands or recursion, eg: !generate !generate
|
||||
if len(params) > 0:
|
||||
return "You can't make me do commands, you madman!", False
|
||||
|
||||
# Get the starting key and starting sentence.
|
||||
# If there is more than 1 param, get the last 2 as the key.
|
||||
# Note that self.key_length is fixed to 2 in this implementation
|
||||
if len(params) > 1:
|
||||
key = params[-self.key_length :]
|
||||
# Copy the entire params for the sentence
|
||||
sentences[0] = params.copy()
|
||||
|
||||
elif len(params) == 1:
|
||||
# First we try to find if this word was once used as the first word in a sentence:
|
||||
key = self.db.get_next_single_start(params[0]) # type: ignore[assignment]
|
||||
if key is None:
|
||||
# If this failed, we try to find the next word in the grammar as a whole
|
||||
key = self.db.get_next_single_initial(0, params[0])
|
||||
if key is None:
|
||||
# Return a message that this word hasn't been learned yet
|
||||
return f'I haven\'t extracted "{params[0]}" from chat yet.', False
|
||||
# Copy this for the sentence
|
||||
sentences[0] = key.copy()
|
||||
|
||||
else: # if there are no params
|
||||
# Get starting key
|
||||
key = self.db.get_start()
|
||||
if key:
|
||||
# Copy this for the sentence
|
||||
sentences[0] = key.copy()
|
||||
else:
|
||||
# If nothing's ever been said
|
||||
return "There is not enough learned information yet.", False
|
||||
|
||||
# Counter to prevent infinite loops (i.e. constantly generating <END> while below the
|
||||
# minimum number of words to generate)
|
||||
i = 0
|
||||
while (
|
||||
self.get_sentence_length(sentences) < self.max_sentence_length
|
||||
and i < self.max_sentence_length * 2
|
||||
):
|
||||
# Use key to get next word
|
||||
if i == 0:
|
||||
# Prevent fetching <END> on the first word
|
||||
word = self.db.get_next_initial(i, key)
|
||||
else:
|
||||
word = self.db.get_next(i, key)
|
||||
|
||||
i += 1
|
||||
|
||||
if word == "<END>" or word is None:
|
||||
# Break, unless we are before the min_sentence_length
|
||||
if i < self.min_sentence_length:
|
||||
key = self.db.get_start()
|
||||
# Ensure that the key can be generated. Otherwise, we still stop.
|
||||
if key:
|
||||
# Start a new sentence
|
||||
sentences.append([])
|
||||
for entry in key:
|
||||
sentences[-1].append(entry)
|
||||
continue
|
||||
break
|
||||
|
||||
# Otherwise add the word
|
||||
sentences[-1].append(word)
|
||||
|
||||
# Shift the key so on the next iteration it gets the next item
|
||||
key.pop(0)
|
||||
key.append(word)
|
||||
|
||||
# If there were params, but the sentence resulting is identical to the params
|
||||
# Then the params did not result in an actual sentence
|
||||
# If so, restart without params
|
||||
if len(params) > 0 and params == sentences[0]:
|
||||
return "I haven't learned what to do with \"" + detokenize(
|
||||
params[-self.key_length :],
|
||||
) + '" yet.', False
|
||||
|
||||
return self.sentence_separator.join(
|
||||
detokenize(sentence) for sentence in sentences
|
||||
), True
|
||||
|
||||
async def run(
|
||||
self,
|
||||
sentence: str | None = None,
|
||||
) -> str|None:
|
||||
if sentence:
|
||||
sentence = tokenize(sentence)
|
||||
logger.info(f"Generating sentence from {sentence}")
|
||||
sentence, success = self.generate(sentence)
|
||||
logger.info(f"Generated sentence: {sentence}")
|
||||
if success:
|
||||
return sentence
|
||||
11
src/huesoporro/svc/get_chatbot_settings.py
Normal file
11
src/huesoporro/svc/get_chatbot_settings.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from src.huesoporro.infra.db import Database
|
||||
from src.huesoporro.models import ChatbotSettings, User
|
||||
|
||||
|
||||
class ChatbotSettingsGetterSvc(BaseModel):
|
||||
db: Database
|
||||
|
||||
async def run(self, user: User) -> ChatbotSettings | None:
|
||||
return await self.db.get_chatbot_settings(user=user)
|
||||
10
src/huesoporro/svc/get_random_quote.py
Normal file
10
src/huesoporro/svc/get_random_quote.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from src.huesoporro.infra.db import Database
|
||||
|
||||
|
||||
class RandomQuoteGetterSvc(BaseModel):
|
||||
db: Database
|
||||
|
||||
async def run(self, channel_name: str) -> tuple[str, str]:
|
||||
return await self.db.get_random_quote(channel_name=channel_name)
|
||||
10
src/huesoporro/svc/hello.py
Normal file
10
src/huesoporro/svc/hello.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
import random
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class HelloGeneratorSvc(BaseModel):
|
||||
hellos: list[str] = Field(default_factory=lambda: ["Hola", "Ayo", "Hi", "Bon día"])
|
||||
|
||||
def run(self, username: str):
|
||||
return f"{random.choice(self.hellos)} {username}"
|
||||
12
src/huesoporro/svc/is_mod.py
Normal file
12
src/huesoporro/svc/is_mod.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from src.huesoporro.infra.db import Database
|
||||
from src.huesoporro.models import User
|
||||
|
||||
|
||||
class IsModSvc(BaseModel):
|
||||
db: Database
|
||||
|
||||
async def run(self, user: User, username: str) -> bool:
|
||||
chatbot_settings = await self.db.get_chatbot_settings(user=user)
|
||||
return username in chatbot_settings.mods
|
||||
27
src/huesoporro/svc/refresh.py
Normal file
27
src/huesoporro/svc/refresh.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.huesoporro.infra.authenticator import TwitchAuthenticator
|
||||
from src.huesoporro.infra.db import Database
|
||||
from src.huesoporro.models import User
|
||||
|
||||
|
||||
class RefreshTokenAuthenticator(BaseModel):
|
||||
db: Database
|
||||
authenticator: TwitchAuthenticator
|
||||
|
||||
@staticmethod
|
||||
def get_four_hours_from_now() -> float:
|
||||
now = datetime.datetime.now(datetime.UTC)
|
||||
four_hours_later = now + datetime.timedelta(hours=4)
|
||||
return four_hours_later.timestamp()
|
||||
|
||||
async def run(self, refresh_token: str) -> User:
|
||||
auth = await self.authenticator.refresh_token(refresh_token)
|
||||
username = await self.authenticator.validate_token(auth.access_token)
|
||||
expires_at = self.get_four_hours_from_now()
|
||||
|
||||
user = User(user=username, expires_at=expires_at, twitch_auth=auth)
|
||||
await self.db.save_user(user)
|
||||
return user
|
||||
63
src/huesoporro/svc/store.py
Normal file
63
src/huesoporro/svc/store.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
from loguru import logger
|
||||
from nltk.tokenize import sent_tokenize
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from src.huesoporro.libs.db import Database as MarkovDB
|
||||
from src.huesoporro.libs.tokenizer import tokenize
|
||||
|
||||
|
||||
class SentenceStorerSvc(BaseModel):
|
||||
db: MarkovDB
|
||||
key_length: int = 2
|
||||
end_tag: str = "<END>"
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def store_sentence(self, message: str):
|
||||
logger.info(f"Processing {message} in order to store it")
|
||||
stripped_message = message.strip()
|
||||
try:
|
||||
sentences = sent_tokenize(stripped_message)
|
||||
except LookupError:
|
||||
logger.debug("Downloading required punkt resource...")
|
||||
import nltk
|
||||
|
||||
nltk.download("punkt")
|
||||
logger.debug("Downloaded required punkt resource.")
|
||||
sentences = sent_tokenize(stripped_message)
|
||||
|
||||
for sentence in sentences:
|
||||
words = tokenize(sentence)
|
||||
# Double spaces will lead to invalid rules. We remove empty words here
|
||||
if "" in words:
|
||||
words = [word for word in words if word]
|
||||
|
||||
# If the sentence is too short, ignore it and move on to the next.
|
||||
if len(words) <= self.key_length:
|
||||
continue
|
||||
|
||||
# Add a new starting point for a sentence to the <START>
|
||||
words = [words[x] for x in range(self.key_length)]
|
||||
logger.debug(f"Adding {words} to start queue")
|
||||
self.db.add_start_queue(words)
|
||||
|
||||
# Create Key variable which will be used as a key in the Dictionary for the grammar
|
||||
key: list[str] = []
|
||||
for word in words:
|
||||
# Set up key for first use
|
||||
if len(key) < self.key_length:
|
||||
key.append(word)
|
||||
continue
|
||||
logger.debug(f"Adding {key}[{word}] to rule queue")
|
||||
self.db.add_rule_queue([*key, word])
|
||||
|
||||
# Remove the first word, and add the current word,
|
||||
# so that the key is correct for the next word.
|
||||
key.pop(0)
|
||||
key.append(word)
|
||||
logger.debug(f"Adding {key} to rule queue")
|
||||
# Add <END> at the end of the sentence
|
||||
self.db.add_rule_queue([*key, self.end_tag])
|
||||
|
||||
async def run(self, sentence: str):
|
||||
return self.store_sentence(sentence)
|
||||
10
src/huesoporro/svc/store_quote.py
Normal file
10
src/huesoporro/svc/store_quote.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from src.huesoporro.infra.db import Database
|
||||
|
||||
|
||||
class QuoteStorerSvc(BaseModel):
|
||||
db: Database
|
||||
|
||||
async def run(self, channel: str, quote: str, author: str):
|
||||
await self.db.save_quote(channel, quote, author)
|
||||
15
src/huesoporro/svc/store_settings.py
Normal file
15
src/huesoporro/svc/store_settings.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from src.huesoporro.infra.db import Database
|
||||
from src.huesoporro.models import ChatbotSettings, User
|
||||
|
||||
|
||||
class ChatbotSettingsStorerSvc(BaseModel):
|
||||
db: Database
|
||||
|
||||
async def run(
|
||||
self, user: User, bot_settings: ChatbotSettings
|
||||
) -> dict[str, str | int | None] | None:
|
||||
return await self.db.save_chatbot_settings(
|
||||
user=user, chatbot_settings=bot_settings
|
||||
)
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
<head>
|
||||
<link rel="stylesheet" href="/static/css/pico/pico.classless.min.css">
|
||||
<link rel="stylesheet" href="/static/css/pico/pico.colors.min.css">
|
||||
<link rel="icon"
|
||||
href="data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 100 100%22><text y=%22.9em%22 font-size=%2290%22>🦴</text></svg>">
|
||||
<meta charset="utf-8">
|
||||
|
|
|
|||
|
|
@ -15,146 +15,180 @@
|
|||
</header>
|
||||
<main>
|
||||
<section>
|
||||
|
||||
<form>
|
||||
<label for="channelName">Enter channel name:</label>
|
||||
<input type="text" id="channelName" placeholder="#huesoperro" aria-describedby="channelNameValidHelper">
|
||||
<input type="text" id="channelName" placeholder="huesoperro" aria-describedby="channelNameValidHelper">
|
||||
<small id="channelNameValidHelper"></small>
|
||||
<button id="startButton" type="button">Start chatbot</button>
|
||||
|
||||
<button id="stopButton" type="button" disabled style="background-color: #aa0000; border-color: #aa0000">Stop
|
||||
chatbot
|
||||
</button>
|
||||
<br/>
|
||||
|
||||
</form>
|
||||
<details>
|
||||
<summary>Chatbot settings</summary>
|
||||
<form>
|
||||
<label for="automaticGenerationTimer">Automatic generation timer (seconds)</label>
|
||||
<input type="number" id="automaticGenerationTimer" placeholder="300"
|
||||
value="{{ automatic_generation_timer }}">
|
||||
|
||||
<label for="automaticQuotesTimer">Automatic quotes timer (seconds)</label>
|
||||
<input type="number" id="automaticQuotesTimer" placeholder="500" value="{{ automatic_quote_timer }}">
|
||||
|
||||
<label for="mods">Chatbot mods (comma-separated)</label>
|
||||
<input type="text" id="mods" placeholder="huesoporro" value="{{ ','.join(mods) or '' }}">
|
||||
|
||||
<button id="saveSettings" type="button" style="background-color: #00c482; border-color: #00c482">
|
||||
Save settings
|
||||
</button>
|
||||
</form>
|
||||
</details>
|
||||
</section>
|
||||
|
||||
<details>
|
||||
<summary>Log</summary>
|
||||
<div><samp id="log"></samp></div>
|
||||
</details>
|
||||
</main>
|
||||
<script>
|
||||
|
||||
document.addEventListener("DOMContentLoaded", () => {
|
||||
class ChatbotManager {
|
||||
|
||||
constructor() {
|
||||
this.url = getWebsocketProtocol() + window.location.host + "/ws";
|
||||
this.logElement = document.getElementById('log');
|
||||
this.socket = null;
|
||||
this.stopButton = document.getElementById("stopButton");
|
||||
this.startButton = document.getElementById("startButton");
|
||||
this.automaticGenerationTimerInput = document.getElementById("automaticGenerationTimer");
|
||||
this.automaticQuotesTimerInput = document.getElementById("automaticQuotesTimer");
|
||||
this.modsInput = document.getElementById("mods");
|
||||
this.channelNameInput = document.getElementById("channelName");
|
||||
|
||||
}
|
||||
|
||||
log(message) {
|
||||
console.log(message);
|
||||
this.logElement.innerHTML += message + '<br>';
|
||||
}
|
||||
setEvents() {
|
||||
document.getElementById('saveSettings').addEventListener('click', () => {
|
||||
chatbotManager.saveBotSettings()
|
||||
})
|
||||
|
||||
async open() {
|
||||
return new Promise((resolve, reject) => {
|
||||
this.socket = new WebSocket(this.url);
|
||||
this.socket.withCredentials = true;
|
||||
this.socket.onopen = () => {
|
||||
this.log("Connected to WebSocket " + this.url);
|
||||
}
|
||||
this.socket.onmessage = async (event) => {
|
||||
try {
|
||||
const message = JSON.parse(event.data);
|
||||
if (message.command === "chatbot_message") {
|
||||
this.log(`[${message.data.username}]: ${message.data.message}`);
|
||||
} else if (message.command === "chatbot_status") {
|
||||
startButton.disabled = message.data.status === "ok";
|
||||
stopButton.disabled = message.data.status === "ko";
|
||||
this.log("Bot status is " + message.data.status)
|
||||
} else if (message.command === "chatbot_start") {
|
||||
this.log(message.data.log)
|
||||
}
|
||||
} catch (error) {
|
||||
this.log(`Error parsing message: ${error.message}`);
|
||||
}
|
||||
}
|
||||
this.socket.onerror = (error) => {
|
||||
this.log(`WebSocket Error: ${error}`);
|
||||
reject(error);
|
||||
}
|
||||
this.socket.onclose = () => {
|
||||
this.log(`WebSocket connection closed: ${event.code} ${event.reason}`);
|
||||
resolve();
|
||||
|
||||
this.startButton.addEventListener('click', () => {
|
||||
const channelName = this.channelNameInput ? this.channelNameInput.value : '';
|
||||
if (!channelName) {
|
||||
document.getElementById('channelNameValidHelper').textContent = 'Please enter a channel name';
|
||||
this.channelNameInput.setAttribute('aria-invalid', 'true');
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById('channelNameValidHelper').textContent = 'Looks good!';
|
||||
this.channelNameInput.setAttribute('aria-invalid', 'false');
|
||||
|
||||
this.startBot()
|
||||
.then(() => {
|
||||
console.log('Chatbot started successfully');
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Failed to start chatbot', error);
|
||||
});
|
||||
});
|
||||
|
||||
this.stopButton.addEventListener('click', () => {
|
||||
chatbotManager.stopBot()
|
||||
.then(() => {
|
||||
console.log('Chatbot stopped successfully');
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Failed to stop chatbot', error);
|
||||
});
|
||||
});
|
||||
|
||||
setInterval(() => {
|
||||
this.getBotStatus()
|
||||
.then(() => {
|
||||
console.log('Chatbot status retrieved successfully');
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Failed to retrieve chatbot status', error);
|
||||
});
|
||||
}, 2000);
|
||||
}
|
||||
|
||||
async startBot() {
|
||||
// call PUT /bot/start with the channel name as a query param
|
||||
const channelNameInput = document.getElementById('channelName');
|
||||
const channelName = channelNameInput ? channelNameInput.value : '';
|
||||
|
||||
const startCommand = {
|
||||
command: "chatbot_start",
|
||||
data: {
|
||||
channel_name: channelName
|
||||
}
|
||||
};
|
||||
|
||||
this.socket.send(JSON.stringify(startCommand));
|
||||
const response = await fetch(`/api/v1/bot`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
"command": "start",
|
||||
"channel_name": channelName
|
||||
})
|
||||
});
|
||||
const data = await response.json();
|
||||
console.log(data);
|
||||
}
|
||||
|
||||
async stopBot() {
|
||||
const stopCommand = {
|
||||
command: "chatbot_stop",
|
||||
data: {}
|
||||
};
|
||||
// call PUT /bot/stop
|
||||
const response = await fetch(`/api/v1/bot`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
"command": "stop",
|
||||
}),
|
||||
});
|
||||
const data = await response.json();
|
||||
console.log(data);
|
||||
// disable startButton
|
||||
startButton.disabled = true;
|
||||
stopButton.disabled = false;
|
||||
}
|
||||
|
||||
this.socket.send(JSON.stringify(stopCommand));
|
||||
setButtonsStatus(status) {
|
||||
this.startButton.disabled = status === "ok";
|
||||
this.stopButton.disabled = status === "ko";
|
||||
}
|
||||
|
||||
async getBotStatus() {
|
||||
const response = await fetch(`/api/v1/bot`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
});
|
||||
const data = await response.json();
|
||||
console.log(data);
|
||||
|
||||
this.setButtonsStatus(data.status)
|
||||
}
|
||||
|
||||
async saveBotSettings() {
|
||||
const automatic_generation_timer = this.automaticGenerationTimerInput.value
|
||||
const automatic_quote_timer = this.automaticQuotesTimerInput.value
|
||||
const mods = this.modsInput.value.split(",")
|
||||
const response = await fetch(`/api/v1/bot/settings`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
'automatic_generation_timer': automatic_generation_timer,
|
||||
"automatic_quote_timer": automatic_quote_timer,
|
||||
"mods": mods,
|
||||
})
|
||||
})
|
||||
const data = await response.json()
|
||||
console.log(data);
|
||||
if (response.ok){
|
||||
alert("Settings saved successfully")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const chatbotManager = new ChatbotManager();
|
||||
chatbotManager.open()
|
||||
|
||||
const startButton = document.getElementById('startButton');
|
||||
const stopButton = document.getElementById('stopButton');
|
||||
if (startButton) {
|
||||
startButton.addEventListener('click', () => {
|
||||
// check if the input has text
|
||||
const channelNameInput = document.getElementById('channelName');
|
||||
const channelName = channelNameInput ? channelNameInput.value : '';
|
||||
if (!channelName) {
|
||||
// if channelName is empty show error in the helper and add
|
||||
// aria-invalid="true" and aria-describedby="channelNameValidHelper" to the input
|
||||
document.getElementById('channelNameValidHelper').textContent = 'Please enter a channel name';
|
||||
channelNameInput.setAttribute('aria-invalid', 'true');
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById('channelNameValidHelper').textContent = 'Looks good!';
|
||||
channelNameInput.setAttribute('aria-invalid', 'false');
|
||||
|
||||
chatbotManager.startBot()
|
||||
.then(() => {
|
||||
console.log('Chatbot started successfully');
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Failed to start chatbot', error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
if (stopButton) {
|
||||
stopButton.addEventListener('click', () => {
|
||||
chatbotManager.stopBot()
|
||||
.then(() => {
|
||||
console.log('Chatbot stopped successfully');
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Failed to stop chatbot', error);
|
||||
});
|
||||
});
|
||||
}
|
||||
chatbotManager.setEvents();
|
||||
|
||||
addLogoutEvent()
|
||||
});
|
||||
|
||||
</script>
|
||||
</body>
|
||||
|
||||
|
|
|
|||
|
|
@ -1,32 +1,17 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en" xmlns="http://www.w3.org/1999/html">
|
||||
|
||||
<head>
|
||||
<link rel="stylesheet" href="/static/css/mvp.css">
|
||||
|
||||
<link rel="icon"
|
||||
href="data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 100 100%22><text y=%22.9em%22 font-size=%2290%22>🦴</text></svg>">
|
||||
<meta charset="utf-8">
|
||||
<meta name="description" content="Huesoporro Twitch bot">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
|
||||
<title>Huesoporro login</title>
|
||||
|
||||
</head>
|
||||
|
||||
{% include 'header.html' %}
|
||||
<body>
|
||||
<header>
|
||||
<h1>Huesoporro🦴🚬</h1>
|
||||
</header>
|
||||
<main>
|
||||
<section>
|
||||
|
||||
<a href="{{ twitch_login_url }}" id="loginButton" type="button" style="color: #9c36b5; border-color: #9c36b5">Login
|
||||
with
|
||||
Twitch
|
||||
</a>
|
||||
<form>
|
||||
<a role="button" href="{{ twitch_login_url }}" id="loginButton" type="button" style="background-color: #B645CD; border-color: #B645CD">Login
|
||||
with
|
||||
Twitch
|
||||
</a>
|
||||
</form>
|
||||
</section>
|
||||
|
||||
</main>
|
||||
<script>
|
||||
document.addEventListener("DOMContentLoaded", () => {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue