refactor: many changes
- Add missing actions and make a clear boundary between actions, services and nfra (i.e: actions shouldn't use stuff from infra/) - Delete stuff not in use: tts, gtts, etc - Add a ton of tests
This commit is contained in:
parent
b2185f4174
commit
152546982c
46 changed files with 2328 additions and 700 deletions
|
|
@ -1,17 +1,31 @@
|
|||
from litestar import Request
|
||||
from litestar.exceptions import HTTPException
|
||||
|
||||
from huesoporro.actions.authenticate import AuthenticateAction
|
||||
from huesoporro.actions.get_user_by_jwt import GetUserByJWTAction
|
||||
from huesoporro.actions.chatbot.create_or_update_chatbot import (
|
||||
CreateOrUpdateChatbotAction,
|
||||
)
|
||||
from huesoporro.actions.chatbot.get_chatbot_by_user_id import GetChatbotByUserIdAction
|
||||
from huesoporro.actions.users.authenticate_user import AuthenticateUserAction
|
||||
from huesoporro.actions.users.get_user_by_jwt import GetUserByJWTAction
|
||||
from huesoporro.infra.authenticator import TwitchAuthenticator
|
||||
from huesoporro.infra.db import Database
|
||||
from huesoporro.infra.repos import UserRepo
|
||||
from huesoporro.infra.repos import ChatbotRepo, UserRepo
|
||||
from huesoporro.libs.db import MarkovDatabase
|
||||
from huesoporro.models import User
|
||||
from huesoporro.models import Chatbot, User
|
||||
from huesoporro.settings import Settings
|
||||
from huesoporro.svc.get_chatbot_settings import ChatbotSettingsGetterSvc
|
||||
from huesoporro.svc.chatbot_svcs import (
|
||||
CreateChatbotSvc,
|
||||
GetChatbotByUserIdSvc,
|
||||
UpdateChatbotSvc,
|
||||
)
|
||||
from huesoporro.svc.store import SentenceStorerSvc
|
||||
from huesoporro.svc.store_settings import ChatbotSettingsStorerSvc
|
||||
from huesoporro.svc.users_svcs import (
|
||||
CreateUserSvc,
|
||||
GetTwitchAuthByAuthCodeSvc,
|
||||
GetUserByUsernameSvc,
|
||||
IsValidTokenSvc,
|
||||
RefreshTokenSvc,
|
||||
UpdateUserSvc,
|
||||
)
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
|
|
@ -22,14 +36,62 @@ def get_authenticator(s: Settings) -> TwitchAuthenticator:
|
|||
return TwitchAuthenticator(s=s)
|
||||
|
||||
|
||||
def get_db(s: Settings):
|
||||
return Database(s=s)
|
||||
def get_chatbot_repo(s: Settings):
|
||||
return ChatbotRepo(s=s)
|
||||
|
||||
|
||||
def get_get_chatbot_by_user_id_svc(chatbot_repo: ChatbotRepo):
|
||||
return GetChatbotByUserIdSvc(repo=chatbot_repo)
|
||||
|
||||
|
||||
def get_get_tokens_by_auth_code_svc(
|
||||
twitch_authenticator: TwitchAuthenticator, s: Settings
|
||||
):
|
||||
return GetTwitchAuthByAuthCodeSvc(s=s, authenticator=twitch_authenticator)
|
||||
|
||||
|
||||
def get_create_chatbot_svc(chatbot_repo: ChatbotRepo):
|
||||
return CreateChatbotSvc(repo=chatbot_repo)
|
||||
|
||||
|
||||
async def get_user_repo(s: Settings):
|
||||
return UserRepo(s=s)
|
||||
|
||||
|
||||
def get_create_user_svc(user_repo: UserRepo):
|
||||
return CreateUserSvc(user_repo=user_repo)
|
||||
|
||||
|
||||
def get_update_user_svc(user_repo: UserRepo):
|
||||
return UpdateUserSvc(user_repo=user_repo)
|
||||
|
||||
|
||||
def get_refresh_token_svc(twitch_authenticator: TwitchAuthenticator):
|
||||
return RefreshTokenSvc(twitch_authenticator=twitch_authenticator)
|
||||
|
||||
|
||||
def get_is_valid_token_svc(twitch_authenticator: TwitchAuthenticator):
|
||||
return IsValidTokenSvc(authenticator=twitch_authenticator)
|
||||
|
||||
|
||||
async def get_get_user_by_username_svc(user_repo: UserRepo):
|
||||
return GetUserByUsernameSvc(user_repo=user_repo)
|
||||
|
||||
|
||||
async def get_get_user_by_jwt_action(
|
||||
user_repo: UserRepo, authenticator: TwitchAuthenticator, s: Settings
|
||||
get_user_by_username_svc: GetUserByUsernameSvc,
|
||||
update_user_svc: UpdateUserSvc,
|
||||
is_valid_token_svc: IsValidTokenSvc,
|
||||
refresh_token_svc: RefreshTokenSvc,
|
||||
s: Settings,
|
||||
):
|
||||
return GetUserByJWTAction(user_repo=user_repo, authenticator=authenticator, s=s)
|
||||
return GetUserByJWTAction(
|
||||
get_user_by_username_svc=get_user_by_username_svc,
|
||||
update_user_svc=update_user_svc,
|
||||
refresh_token_svc=refresh_token_svc,
|
||||
is_valid_token_svc=is_valid_token_svc,
|
||||
s=s,
|
||||
)
|
||||
|
||||
|
||||
async def authenticate(
|
||||
|
|
@ -37,32 +99,73 @@ async def authenticate(
|
|||
) -> User:
|
||||
token = request.query_params.get("huesoporro_token")
|
||||
if token:
|
||||
return await get_user_by_jwt_action.run(token)
|
||||
user = await get_user_by_jwt_action.run(token)
|
||||
if not user:
|
||||
raise HTTPException(detail="User does not exist", status_code=404)
|
||||
return user
|
||||
|
||||
cookies = request.cookies.get("huesoporroAuth")
|
||||
if cookies:
|
||||
return await get_user_by_jwt_action.run(cookies)
|
||||
user = await get_user_by_jwt_action.run(cookies)
|
||||
if not user:
|
||||
raise HTTPException(detail="User does not exist", status_code=404)
|
||||
return user
|
||||
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
|
||||
async def get_chatbot_settings_svc(db: Database):
|
||||
return ChatbotSettingsGetterSvc(db=db)
|
||||
|
||||
|
||||
async def store_chatbot_settings_svc(db: Database):
|
||||
return ChatbotSettingsStorerSvc(db=db)
|
||||
|
||||
|
||||
async def get_sentences_storer_svc(db: MarkovDatabase):
|
||||
return SentenceStorerSvc(db=db)
|
||||
|
||||
|
||||
async def get_user_repo(s: Settings):
|
||||
return UserRepo(s=s)
|
||||
def get_update_chatbot_svc(chatbot_repo: ChatbotRepo):
|
||||
return UpdateChatbotSvc(repo=chatbot_repo)
|
||||
|
||||
|
||||
def get_create_or_update_chatbot_action(
|
||||
create_chatbot_svc: CreateChatbotSvc,
|
||||
update_chatbot_svc: UpdateChatbotSvc,
|
||||
get_chatbot_by_user_id_svc: GetChatbotByUserIdSvc,
|
||||
):
|
||||
return CreateOrUpdateChatbotAction(
|
||||
create_chatbot_svc=create_chatbot_svc,
|
||||
update_chatbot_svc=update_chatbot_svc,
|
||||
get_chatbot_by_user_id_svc=get_chatbot_by_user_id_svc,
|
||||
)
|
||||
|
||||
|
||||
def get_get_chatbot_by_user_id_action(
|
||||
get_chatbot_by_user_id_svc: GetChatbotByUserIdSvc,
|
||||
):
|
||||
return GetChatbotByUserIdAction(
|
||||
get_chatbot_by_user_id_svc=get_chatbot_by_user_id_svc
|
||||
)
|
||||
|
||||
|
||||
async def get_authenticate_action(
|
||||
user_repo: UserRepo, authenticator: TwitchAuthenticator, s: Settings
|
||||
s: Settings,
|
||||
get_tokens_by_auth_code_svc: GetTwitchAuthByAuthCodeSvc,
|
||||
get_user_by_username_svc: GetUserByUsernameSvc,
|
||||
create_user_svc: CreateUserSvc,
|
||||
update_user_svc: UpdateUserSvc,
|
||||
):
|
||||
return AuthenticateAction(user_repo=user_repo, authenticator=authenticator, s=s)
|
||||
return AuthenticateUserAction(
|
||||
s=s,
|
||||
get_tokens_by_auth_code_svc=get_tokens_by_auth_code_svc,
|
||||
get_user_by_username_svc=get_user_by_username_svc,
|
||||
create_user_svc=create_user_svc,
|
||||
update_user_svc=update_user_svc,
|
||||
)
|
||||
|
||||
|
||||
async def chatbot(
|
||||
get_chatbot_by_user_id_action: GetChatbotByUserIdAction,
|
||||
create_or_update_chatbot_action: CreateOrUpdateChatbotAction,
|
||||
user: User,
|
||||
) -> Chatbot:
|
||||
cb = await get_chatbot_by_user_id_action.run(user_id=user.id)
|
||||
if cb:
|
||||
return cb
|
||||
return await create_or_update_chatbot_action.run(
|
||||
user_id=user.id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,13 +10,22 @@ from apps.httpapi.litestar.dependencies import (
|
|||
authenticate,
|
||||
get_authenticate_action,
|
||||
get_authenticator,
|
||||
get_chatbot_settings_svc,
|
||||
get_db,
|
||||
get_chatbot_repo,
|
||||
get_create_chatbot_svc,
|
||||
get_create_or_update_chatbot_action,
|
||||
get_create_user_svc,
|
||||
get_get_chatbot_by_user_id_action,
|
||||
get_get_chatbot_by_user_id_svc,
|
||||
get_get_tokens_by_auth_code_svc,
|
||||
get_get_user_by_jwt_action,
|
||||
get_get_user_by_username_svc,
|
||||
get_is_valid_token_svc,
|
||||
get_refresh_token_svc,
|
||||
get_sentences_storer_svc,
|
||||
get_settings,
|
||||
get_update_chatbot_svc,
|
||||
get_update_user_svc,
|
||||
get_user_repo,
|
||||
store_chatbot_settings_svc,
|
||||
)
|
||||
from apps.httpapi.litestar.errors import (
|
||||
after_exception_handler,
|
||||
|
|
@ -79,15 +88,26 @@ def create_app():
|
|||
"s": Provide(get_settings, use_cache=True),
|
||||
"a": Provide(get_authenticator, use_cache=True),
|
||||
"user": Provide(authenticate),
|
||||
"db": Provide(get_db, use_cache=True),
|
||||
"bm": Provide(BotsManager, use_cache=True),
|
||||
"gbs": Provide(get_chatbot_settings_svc),
|
||||
"sbs": Provide(store_chatbot_settings_svc),
|
||||
"sss": Provide(get_sentences_storer_svc),
|
||||
"authenticator": Provide(get_authenticator),
|
||||
"twitch_authenticator": Provide(get_authenticator),
|
||||
"authenticate_action": Provide(get_authenticate_action),
|
||||
"user_repo": Provide(get_user_repo),
|
||||
"chatbot_repo": Provide(get_chatbot_repo),
|
||||
"create_user_svc": Provide(get_create_user_svc),
|
||||
"update_chatbot_svc": Provide(get_update_chatbot_svc),
|
||||
"update_user_svc": Provide(get_update_user_svc),
|
||||
"create_chatbot_svc": Provide(get_create_chatbot_svc),
|
||||
"refresh_token_svc": Provide(get_refresh_token_svc),
|
||||
"is_valid_token_svc": Provide(get_is_valid_token_svc),
|
||||
"get_user_by_username_svc": Provide(get_get_user_by_username_svc),
|
||||
"get_chatbot_by_user_id_svc": Provide(get_get_chatbot_by_user_id_svc),
|
||||
"get_tokens_by_auth_code_svc": Provide(get_get_tokens_by_auth_code_svc),
|
||||
"get_user_by_jwt_action": Provide(get_get_user_by_jwt_action),
|
||||
"get_chatbot_by_user_id_action": Provide(get_get_chatbot_by_user_id_action),
|
||||
"create_or_update_chatbot_action": Provide(
|
||||
get_create_or_update_chatbot_action
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,10 +5,12 @@ from litestar.datastructures import UploadFile
|
|||
from litestar.response import Template
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from huesoporro.actions.chatbot.create_or_update_chatbot import (
|
||||
CreateOrUpdateChatbotAction,
|
||||
)
|
||||
from huesoporro.actions.chatbot.get_chatbot_by_user_id import GetChatbotByUserIdAction
|
||||
from huesoporro.bot import BotsManager
|
||||
from huesoporro.models import ChatbotSettings, User
|
||||
from huesoporro.svc.get_chatbot_settings import ChatbotSettingsGetterSvc
|
||||
from huesoporro.svc.store_settings import ChatbotSettingsStorerSvc
|
||||
from huesoporro.models import Chatbot, User
|
||||
|
||||
|
||||
class ManageBotDTO(BaseModel):
|
||||
|
|
@ -23,6 +25,12 @@ class ImportTextFileDTO(BaseModel):
|
|||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class UpdateChatbotDTO(BaseModel):
|
||||
automatic_generation_timer: int = 300
|
||||
automatic_quote_timer: int = 500
|
||||
mods: list[str]
|
||||
|
||||
|
||||
@get(
|
||||
"/tts",
|
||||
media_type=MediaType.HTML,
|
||||
|
|
@ -48,8 +56,10 @@ async def get_tts_permalink(access_token: str) -> Template:
|
|||
"/",
|
||||
media_type=MediaType.HTML,
|
||||
)
|
||||
async def get_index(user: User, gbs: ChatbotSettingsGetterSvc) -> Template:
|
||||
chatbot_settings = await gbs.run(user=user)
|
||||
async def get_index(
|
||||
user: User, get_chatbot_by_user_id_action: GetChatbotByUserIdAction
|
||||
) -> Template:
|
||||
chatbot_settings = await get_chatbot_by_user_id_action.run(user_id=user.id)
|
||||
return Template(
|
||||
template_name="index.html",
|
||||
context=chatbot_settings.model_dump() if chatbot_settings else {},
|
||||
|
|
@ -60,22 +70,24 @@ async def get_index(user: User, gbs: ChatbotSettingsGetterSvc) -> Template:
|
|||
async def manage_bot(
|
||||
user: User,
|
||||
data: ManageBotDTO,
|
||||
gbs: ChatbotSettingsGetterSvc,
|
||||
sbs: ChatbotSettingsStorerSvc,
|
||||
create_or_update_chatbot_action: CreateOrUpdateChatbotAction,
|
||||
get_chatbot_by_user_id_action: GetChatbotByUserIdAction,
|
||||
bm: BotsManager,
|
||||
) -> Response:
|
||||
chatbot_settings = await gbs.run(user=user)
|
||||
if not chatbot_settings:
|
||||
await sbs.run(user=user, bot_settings=ChatbotSettings())
|
||||
chatbot_settings = await gbs.run(user=user)
|
||||
chatbot = await get_chatbot_by_user_id_action.run(
|
||||
user_id=user.id
|
||||
) or await create_or_update_chatbot_action.run(
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
if data.command == "start":
|
||||
if not data.channel_name:
|
||||
return Response({"message": "Channel name is required"}, status_code=400)
|
||||
bm.add_bot(user, data.channel_name, chatbot_settings=chatbot_settings) # type: ignore[arg-type]
|
||||
if user.user in bm.bots:
|
||||
bm.add_bot(user, data.channel_name, chatbot=chatbot) # type: ignore[arg-type]
|
||||
if user.username in bm.bots:
|
||||
await bm.run_user_bot(user)
|
||||
return Response({"message": "Bot started"})
|
||||
if data.command == "stop" and user.user in bm.bots:
|
||||
if data.command == "stop" and user.username in bm.bots:
|
||||
await bm.stop_user_bot(user)
|
||||
return Response({"message": "Bot stopped"})
|
||||
return Response({"message": "Invalid command"}, status_code=400)
|
||||
|
|
@ -83,24 +95,26 @@ async def manage_bot(
|
|||
|
||||
@get("/api/v1/bot")
|
||||
async def get_bot_status(user: User, bm: BotsManager) -> dict:
|
||||
if user.user not in bm.bots:
|
||||
if user.username not in bm.bots:
|
||||
return {"status": "ko"}
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@get("/api/v1/bot/settings")
|
||||
async def get_bot_settings(
|
||||
user: User, gbs: ChatbotSettingsGetterSvc
|
||||
) -> ChatbotSettings | dict:
|
||||
cbs = await gbs.run(user=user)
|
||||
if not cbs:
|
||||
return {"status": "Not found"}
|
||||
return cbs
|
||||
async def get_bot_settings(chatbot: Chatbot) -> Chatbot:
|
||||
return chatbot
|
||||
|
||||
|
||||
@put("/api/v1/bot/settings")
|
||||
async def save_bot_settings(
|
||||
user: User, data: ChatbotSettings, sbs: ChatbotSettingsStorerSvc
|
||||
user: User,
|
||||
data: UpdateChatbotDTO,
|
||||
create_or_update_chatbot_action: CreateOrUpdateChatbotAction,
|
||||
) -> dict:
|
||||
await sbs.run(user=user, bot_settings=data)
|
||||
await create_or_update_chatbot_action.run(
|
||||
user_id=user.id,
|
||||
automatic_generation_timer=data.automatic_generation_timer,
|
||||
automatic_quote_timer=data.automatic_quote_timer,
|
||||
mods=data.mods,
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
|
|
|||
|
|
@ -4,13 +4,14 @@ from litestar import MediaType, get
|
|||
from litestar.datastructures.cookie import Cookie
|
||||
from litestar.response import Redirect, Template
|
||||
|
||||
from huesoporro.actions.authenticate import AuthenticateAction
|
||||
from huesoporro.actions.users.authenticate_user import AuthenticateUserAction
|
||||
from huesoporro.settings import Settings
|
||||
|
||||
|
||||
@get(path="/o/code")
|
||||
async def get_code(code: str, authenticate_action: AuthenticateAction) -> Redirect:
|
||||
token = await authenticate_action.run(code)
|
||||
async def get_code(code: str, authenticate_action: AuthenticateUserAction) -> Redirect:
|
||||
user = await authenticate_action.run(code)
|
||||
token = user.encode()
|
||||
return Redirect(
|
||||
"/",
|
||||
cookies=[
|
||||
|
|
|
|||
|
|
@ -1,27 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.infra.authenticator import TwitchAuthenticator
|
||||
from huesoporro.infra.repos import UserRepo
|
||||
from huesoporro.models import User
|
||||
from huesoporro.settings import Settings
|
||||
|
||||
|
||||
class AuthenticateAction(BaseModel):
|
||||
user_repo: UserRepo
|
||||
authenticator: TwitchAuthenticator
|
||||
s: Settings
|
||||
|
||||
async def run(
|
||||
self,
|
||||
auth_code: str,
|
||||
):
|
||||
tokens = await self.authenticator.get_token(auth_code)
|
||||
username = tokens.userinfo["preferred_username"]
|
||||
if username not in self.s.allowed_users:
|
||||
raise ValueError(f"User {username} is not allowed to use this bot")
|
||||
user = User(user=username, external_auth={"twitch": tokens.model_dump()})
|
||||
if await self.user_repo.get_by_user(user.user):
|
||||
await self.user_repo.update(user)
|
||||
else:
|
||||
await self.user_repo.create(user)
|
||||
return user.encode()
|
||||
0
src/huesoporro/actions/chatbot/__init__.py
Normal file
0
src/huesoporro/actions/chatbot/__init__.py
Normal file
40
src/huesoporro/actions/chatbot/create_or_update_chatbot.py
Normal file
40
src/huesoporro/actions/chatbot/create_or_update_chatbot.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.models import Chatbot
|
||||
from huesoporro.svc.chatbot_svcs import (
|
||||
CreateChatbotSvc,
|
||||
GetChatbotByUserIdSvc,
|
||||
UpdateChatbotSvc,
|
||||
)
|
||||
|
||||
|
||||
class CreateOrUpdateChatbotAction(BaseModel):
|
||||
create_chatbot_svc: CreateChatbotSvc
|
||||
update_chatbot_svc: UpdateChatbotSvc
|
||||
get_chatbot_by_user_id_svc: GetChatbotByUserIdSvc
|
||||
|
||||
async def run(
|
||||
self,
|
||||
user_id: UUID,
|
||||
automatic_generation_timer: int = 300,
|
||||
automatic_quote_timer: int = 500,
|
||||
mods: list[str] | None = None,
|
||||
) -> Chatbot:
|
||||
mods = mods or []
|
||||
chatbot = await self.get_chatbot_by_user_id_svc.run(user_id=user_id)
|
||||
if chatbot:
|
||||
chatbot.automatic_generation_timer = automatic_generation_timer
|
||||
chatbot.automatic_quote_timer = automatic_quote_timer
|
||||
chatbot.mods = mods
|
||||
return await self.update_chatbot_svc.run(chatbot=chatbot)
|
||||
chatbot = Chatbot(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
automatic_generation_timer=automatic_generation_timer,
|
||||
automatic_quote_timer=automatic_quote_timer,
|
||||
mods=mods,
|
||||
)
|
||||
return await self.create_chatbot_svc.run(chatbot=chatbot)
|
||||
16
src/huesoporro/actions/chatbot/get_chatbot_by_user_id.py
Normal file
16
src/huesoporro/actions/chatbot/get_chatbot_by_user_id.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.models import Chatbot
|
||||
from huesoporro.svc.chatbot_svcs import GetChatbotByUserIdSvc
|
||||
|
||||
|
||||
class GetChatbotByUserIdAction(BaseModel):
|
||||
get_chatbot_by_user_id_svc: GetChatbotByUserIdSvc
|
||||
|
||||
async def run(
|
||||
self,
|
||||
user_id: UUID,
|
||||
) -> Chatbot | None:
|
||||
return await self.get_chatbot_by_user_id_svc.run(user_id=user_id)
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.models import Quote
|
||||
from huesoporro.svc.get_random_quote import RandomQuoteGetterSvc
|
||||
|
||||
|
||||
class GetRandomQuoteAction(BaseModel):
|
||||
quote_getter_svc: RandomQuoteGetterSvc
|
||||
|
||||
async def run(self, channel_name: str) -> Quote | None:
|
||||
return await self.quote_getter_svc.run(channel_name=channel_name)
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.infra.authenticator import TwitchAuthenticator
|
||||
from huesoporro.infra.repos import UserRepo
|
||||
from huesoporro.models import User
|
||||
from huesoporro.settings import Settings
|
||||
|
||||
|
||||
class GetUserByJWTAction(BaseModel):
|
||||
user_repo: UserRepo
|
||||
authenticator: TwitchAuthenticator
|
||||
s: Settings
|
||||
|
||||
async def run(
|
||||
self,
|
||||
jwt_token: str,
|
||||
) -> User:
|
||||
user_data = User.decode(jwt_token)
|
||||
username = user_data["user"]
|
||||
user = await self.user_repo.get_by_user(username)
|
||||
if not user:
|
||||
raise ValueError(f"User {username} not found")
|
||||
is_valid = await self.authenticator.token_is_valid(
|
||||
user.external_auth["twitch"]["access_token"]
|
||||
)
|
||||
|
||||
logger.info(f"Token {user} is valid: {is_valid}")
|
||||
if is_valid:
|
||||
return user
|
||||
|
||||
logger.info(f"Refreshing token for user {user}")
|
||||
twitch_auth = await self.authenticator.refresh_token(
|
||||
user.external_auth["twitch"]["refresh_token"]
|
||||
)
|
||||
user.external_auth["twitch"]["access_token"] = twitch_auth.access_token
|
||||
user.external_auth["twitch"]["refresh_token"] = twitch_auth.refresh_token
|
||||
await self.user_repo.update(user)
|
||||
return user
|
||||
0
src/huesoporro/actions/quotes/__init__.py
Normal file
0
src/huesoporro/actions/quotes/__init__.py
Normal file
|
|
@ -1,14 +1,15 @@
|
|||
import datetime
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.models import Quote, User
|
||||
from huesoporro.svc.is_mod import IsModSvc
|
||||
from huesoporro.svc.quote_storer_svc import QuoteStorerSvc
|
||||
from huesoporro.svc.quotes_svcs import CreateQuoteSvc
|
||||
|
||||
|
||||
class StoreQuoteAction(BaseModel):
|
||||
quote_storer_svc: QuoteStorerSvc
|
||||
class CreateQuoteAction(BaseModel):
|
||||
create_quote_svc: CreateQuoteSvc
|
||||
is_mod_svc: IsModSvc
|
||||
|
||||
async def run(
|
||||
|
|
@ -17,10 +18,11 @@ class StoreQuoteAction(BaseModel):
|
|||
if not await self.is_mod_svc.run(user=user, username=username, channel=channel):
|
||||
return None
|
||||
new_quote = Quote(
|
||||
id=uuid.uuid4(),
|
||||
quote=quote,
|
||||
author=User(user=author, external_auth={}),
|
||||
channel=User(user=channel, external_auth={}),
|
||||
author=author,
|
||||
channel_name=channel,
|
||||
created_at=datetime.datetime.now(datetime.UTC),
|
||||
last_updated_at=datetime.datetime.now(datetime.UTC),
|
||||
)
|
||||
return await self.quote_storer_svc.run(new_quote)
|
||||
return await self.create_quote_svc.run(new_quote)
|
||||
11
src/huesoporro/actions/quotes/get_random_quote.py
Normal file
11
src/huesoporro/actions/quotes/get_random_quote.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.models import Quote
|
||||
from huesoporro.svc.quotes_svcs import GetRandomQuoteSvc
|
||||
|
||||
|
||||
class GetRandomQuoteAction(BaseModel):
|
||||
get_random_quote_svc: GetRandomQuoteSvc
|
||||
|
||||
async def run(self, channel_name: str) -> Quote | None:
|
||||
return await self.get_random_quote_svc.run(channel_name=channel_name)
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.infra.authenticator import TwitchAuthenticator
|
||||
from huesoporro.infra.repos import UserRepo
|
||||
from huesoporro.models import User
|
||||
from huesoporro.settings import Settings
|
||||
|
||||
|
||||
class RefreshAction(BaseModel):
|
||||
user_repo: UserRepo
|
||||
authenticator: TwitchAuthenticator
|
||||
s: Settings
|
||||
|
||||
async def run(self, user: User) -> str | None:
|
||||
is_valid = await self.authenticator.token_is_valid(
|
||||
user.external_auth["twitch"]["access_token"]
|
||||
)
|
||||
if is_valid:
|
||||
return None
|
||||
|
||||
twitch_auth = await self.authenticator.refresh_token(
|
||||
user.external_auth["twitch"]["refresh_token"]
|
||||
)
|
||||
user.external_auth["twitch"]["access_token"] = twitch_auth.access_token
|
||||
user.external_auth["twitch"]["refresh_token"] = twitch_auth.refresh_token
|
||||
await self.user_repo.update(user)
|
||||
return user.encode()
|
||||
0
src/huesoporro/actions/users/__init__.py
Normal file
0
src/huesoporro/actions/users/__init__.py
Normal file
38
src/huesoporro/actions/users/authenticate_user.py
Normal file
38
src/huesoporro/actions/users/authenticate_user.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
import uuid
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.models import User
|
||||
from huesoporro.settings import Settings
|
||||
from huesoporro.svc.users_svcs import (
|
||||
CreateUserSvc,
|
||||
GetTwitchAuthByAuthCodeSvc,
|
||||
GetUserByUsernameSvc,
|
||||
UpdateUserSvc,
|
||||
)
|
||||
|
||||
|
||||
class AuthenticateUserAction(BaseModel):
|
||||
get_tokens_by_auth_code_svc: GetTwitchAuthByAuthCodeSvc
|
||||
get_user_by_username_svc: GetUserByUsernameSvc
|
||||
create_user_svc: CreateUserSvc
|
||||
update_user_svc: UpdateUserSvc
|
||||
s: Settings
|
||||
|
||||
async def run(
|
||||
self,
|
||||
auth_code: str,
|
||||
) -> User:
|
||||
auth = await self.get_tokens_by_auth_code_svc.run(auth_code=auth_code)
|
||||
username = auth.userinfo["preferred_username"]
|
||||
user = await self.get_user_by_username_svc.run(username=username)
|
||||
if user:
|
||||
user.external_auth = {"twitch": auth}
|
||||
await self.update_user_svc.run(user)
|
||||
return user
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
username=username,
|
||||
external_auth={"twitch": auth},
|
||||
)
|
||||
return await self.create_user_svc.run(user=user)
|
||||
40
src/huesoporro/actions/users/get_user_by_jwt.py
Normal file
40
src/huesoporro/actions/users/get_user_by_jwt.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.models import User
|
||||
from huesoporro.settings import Settings
|
||||
from huesoporro.svc.users_svcs import (
|
||||
GetUserByUsernameSvc,
|
||||
IsValidTokenSvc,
|
||||
RefreshTokenSvc,
|
||||
UpdateUserSvc,
|
||||
)
|
||||
|
||||
|
||||
class GetUserByJWTAction(BaseModel):
|
||||
get_user_by_username_svc: GetUserByUsernameSvc
|
||||
update_user_svc: UpdateUserSvc
|
||||
refresh_token_svc: RefreshTokenSvc
|
||||
is_valid_token_svc: IsValidTokenSvc
|
||||
s: Settings
|
||||
|
||||
async def run(
|
||||
self,
|
||||
jwt_token: str,
|
||||
) -> User | None:
|
||||
user_data = User.decode(jwt_token, settings=self.s)
|
||||
username = user_data["username"]
|
||||
user = await self.get_user_by_username_svc.run(username=username)
|
||||
if not user:
|
||||
raise ValueError(f"User {username} not found")
|
||||
if await self.is_valid_token_svc.run(user=user):
|
||||
logger.info(
|
||||
f"User {username} has a correct twitch authentication token, returning user"
|
||||
)
|
||||
return user
|
||||
|
||||
logger.info(
|
||||
f"User {username} has an invalid twitch authentication token, refreshing it"
|
||||
)
|
||||
user = await self.refresh_token_svc.run(user=user)
|
||||
return await self.update_user_svc.run(user=user)
|
||||
26
src/huesoporro/actions/users/refresh_user_jwt.py
Normal file
26
src/huesoporro/actions/users/refresh_user_jwt.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.models import User
|
||||
from huesoporro.settings import Settings
|
||||
from huesoporro.svc.users_svcs import (
|
||||
GetUserByUsernameSvc,
|
||||
IsValidTokenSvc,
|
||||
RefreshTokenSvc,
|
||||
UpdateUserSvc,
|
||||
)
|
||||
|
||||
|
||||
class RefreshUserJwtAction(BaseModel):
|
||||
get_user_by_username_svc: GetUserByUsernameSvc
|
||||
update_user_svc: UpdateUserSvc
|
||||
refresh_token_svc: RefreshTokenSvc
|
||||
is_valid_token_svc: IsValidTokenSvc
|
||||
s: Settings
|
||||
|
||||
async def run(self, user: User) -> User | None:
|
||||
"""Return None if the user has a valid token, otherwise refresh it and return the new token"""
|
||||
if await self.is_valid_token_svc.run(user=user):
|
||||
return None
|
||||
|
||||
user = await self.refresh_token_svc.run(user=user)
|
||||
return await self.update_user_svc.run(user=user)
|
||||
|
|
@ -8,47 +8,47 @@ from loguru import logger
|
|||
from twitchio import Channel
|
||||
from twitchio.ext import commands, routines
|
||||
|
||||
from huesoporro.actions.get_random_quote import GetRandomQuoteAction
|
||||
from huesoporro.actions.store_quote import StoreQuoteAction
|
||||
from huesoporro.infra.db import Database
|
||||
from huesoporro.infra.repos import QuoteRepo
|
||||
from huesoporro.actions.quotes.create_quote_action import CreateQuoteAction
|
||||
from huesoporro.actions.quotes.get_random_quote import GetRandomQuoteAction
|
||||
from huesoporro.infra.repos import ChatbotRepo, QuoteRepo
|
||||
from huesoporro.libs.db import MarkovDatabase
|
||||
from huesoporro.models import ChatbotSettings, User
|
||||
from huesoporro.models import Chatbot, User
|
||||
from huesoporro.settings import Settings
|
||||
from huesoporro.svc.backoff_service import BackoffService
|
||||
from huesoporro.svc.generate import SentenceGeneratorSvc
|
||||
from huesoporro.svc.get_random_quote import RandomQuoteGetterSvc
|
||||
from huesoporro.svc.hello import get_hello_generator_svc
|
||||
from huesoporro.svc.is_mod import IsModSvc
|
||||
from huesoporro.svc.quote_storer_svc import QuoteStorerSvc
|
||||
from huesoporro.svc.quotes_svcs import CreateQuoteSvc, GetRandomQuoteSvc
|
||||
from huesoporro.svc.store import SentenceStorerSvc
|
||||
|
||||
|
||||
class Bot(commands.Bot):
|
||||
def __init__(self, user: User, chatbot_settings: ChatbotSettings, channel: str):
|
||||
def __init__dependencies(self, channel: str, settings: Settings):
|
||||
self.quote_repo = QuoteRepo(s=settings)
|
||||
self.chatbot_repo = ChatbotRepo(s=settings)
|
||||
self.get_random_quote_action = GetRandomQuoteAction(
|
||||
get_random_quote_svc=GetRandomQuoteSvc(repo=self.quote_repo)
|
||||
)
|
||||
self.create_quote_action = CreateQuoteAction(
|
||||
create_quote_svc=CreateQuoteSvc(repo=self.quote_repo),
|
||||
is_mod_svc=IsModSvc(repo=self.chatbot_repo),
|
||||
)
|
||||
self.generate_svc = SentenceGeneratorSvc(db=MarkovDatabase(channel=channel))
|
||||
self.hello_svc = get_hello_generator_svc()
|
||||
|
||||
def __init__(self, user: User, chatbot: Chatbot, channel: str, settings: Settings):
|
||||
super().__init__(
|
||||
token=user.twitch_access_token, prefix="!", initial_channels=[channel]
|
||||
)
|
||||
self.__init__dependencies(channel=channel, settings=settings)
|
||||
self.channel = channel
|
||||
self.user = user
|
||||
self.generate_svc = SentenceGeneratorSvc(db=MarkovDatabase(channel=channel))
|
||||
self.hello_svc = get_hello_generator_svc()
|
||||
db = Database()
|
||||
self.quote_repo = QuoteRepo(s=Settings.get())
|
||||
self.get_random_quote_svc = RandomQuoteGetterSvc(quote_repo=self.quote_repo)
|
||||
self.get_random_quote_action = GetRandomQuoteAction(
|
||||
quote_getter_svc=self.get_random_quote_svc
|
||||
)
|
||||
self.store_quote_action = StoreQuoteAction(
|
||||
quote_storer_svc=QuoteStorerSvc(quote_repo=self.quote_repo),
|
||||
is_mod_svc=IsModSvc(db=db),
|
||||
)
|
||||
self.cbs = chatbot_settings
|
||||
self.chatbot = chatbot
|
||||
self.quote_routine = routines.routine(
|
||||
seconds=chatbot_settings.automatic_quote_timer, wait_first=True
|
||||
seconds=chatbot.automatic_quote_timer, wait_first=True
|
||||
)(self.send_quote)
|
||||
self.generation_routine = routines.routine(
|
||||
seconds=chatbot_settings.automatic_generation_timer, wait_first=True
|
||||
seconds=chatbot.automatic_generation_timer, wait_first=True
|
||||
)(self.send_generation)
|
||||
|
||||
async def event_ready(self):
|
||||
|
|
@ -69,7 +69,7 @@ class Bot(commands.Bot):
|
|||
async def add_quote(self, ctx: commands.Context, *, quote: str):
|
||||
# extract author from quote; the author is the last word
|
||||
quote, author = quote.rsplit(" ", 1)
|
||||
new_quote = await self.store_quote_action.run(
|
||||
new_quote = await self.create_quote_action.run(
|
||||
user=self.user,
|
||||
channel=self.channel,
|
||||
quote=quote,
|
||||
|
|
@ -107,10 +107,10 @@ class Bot(commands.Bot):
|
|||
await channel.send(sentence)
|
||||
|
||||
def start_routines(self):
|
||||
if self.cbs.automatic_quote_timer > 0:
|
||||
if self.chatbot.automatic_quote_timer > 0:
|
||||
logger.info("Starting quote routine")
|
||||
self.quote_routine.start(stop_on_error=False)
|
||||
if self.cbs.automatic_generation_timer > 0:
|
||||
if self.chatbot.automatic_generation_timer > 0:
|
||||
logger.info("Starting generation routine")
|
||||
self.generation_routine.start(stop_on_error=False)
|
||||
|
||||
|
|
@ -134,7 +134,7 @@ class HelloMessagesCog(commands.Cog):
|
|||
if message.content in self.hello_patterns:
|
||||
hello = self.hello_svc.run(message.author.name)
|
||||
if hello:
|
||||
await message.channel.send(hello)
|
||||
await message.channel_name.send(hello)
|
||||
|
||||
|
||||
class MessageType(StrEnum):
|
||||
|
|
@ -183,8 +183,10 @@ class MessageHandler:
|
|||
class SaveMessagesCog(commands.Cog):
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.store_svc = SentenceStorerSvc(db=MarkovDatabase(channel=bot.channel))
|
||||
self.generate_svc = SentenceGeneratorSvc(db=MarkovDatabase(channel=bot.channel))
|
||||
self.store_svc = SentenceStorerSvc(db=MarkovDatabase(channel=bot.channel_name))
|
||||
self.generate_svc = SentenceGeneratorSvc(
|
||||
db=MarkovDatabase(channel=bot.channel_name)
|
||||
)
|
||||
self.backoff_svc = BackoffService()
|
||||
self.message_handler = MessageHandler(self._send_message)
|
||||
|
||||
|
|
@ -202,7 +204,7 @@ class SaveMessagesCog(commands.Cog):
|
|||
|
||||
async def typed_send(content: str):
|
||||
if hasattr(self, "current_message"):
|
||||
await self.current_message.channel.send(content)
|
||||
await self.current_message.channel_name.send(content)
|
||||
|
||||
# Set a unique name for the function to ensure it's treated as distinct
|
||||
typed_send.__name__ = f"send_{type_name}"
|
||||
|
|
@ -211,7 +213,7 @@ class SaveMessagesCog(commands.Cog):
|
|||
async def _send_message(self, content: str):
|
||||
"""Generic send message function (for non-backoff uses)"""
|
||||
if hasattr(self, "current_message"):
|
||||
await self.current_message.channel.send(content)
|
||||
await self.current_message.channel_name.send(content)
|
||||
|
||||
@commands.Cog.event()
|
||||
async def event_message(self, message):
|
||||
|
|
@ -248,32 +250,36 @@ class SaveMessagesCog(commands.Cog):
|
|||
|
||||
|
||||
class BotsManager:
|
||||
def __init__(self):
|
||||
def __init__(self, s: Settings):
|
||||
self.bots: dict[str, Bot] = {}
|
||||
self.s = s
|
||||
|
||||
def add_bot(self, user: User, channel: str, chatbot_settings: ChatbotSettings):
|
||||
if user.user in self.bots:
|
||||
logger.info(f"Bot for {user.user} already exists")
|
||||
def add_bot(self, user: User, channel: str, chatbot: Chatbot):
|
||||
if user.username in self.bots:
|
||||
logger.info(f"Bot for {user.username} already exists")
|
||||
return
|
||||
logger.info(f"Adding bot for {user.user}")
|
||||
bot = Bot(user=user, channel=channel, chatbot_settings=chatbot_settings)
|
||||
|
||||
logger.info(f"Adding bot for {user.username}")
|
||||
bot = Bot(user=user, channel=channel, chatbot=chatbot, settings=self.s)
|
||||
bot.add_cog(SaveMessagesCog(bot))
|
||||
bot.add_cog(HelloMessagesCog(bot))
|
||||
self.bots[user.user] = bot
|
||||
self.bots[user.username] = bot
|
||||
|
||||
async def run_user_bot(self, user: User):
|
||||
if user.user not in self.bots:
|
||||
if user.username not in self.bots:
|
||||
return
|
||||
|
||||
logger.info(f"Starting bot for {user.user}")
|
||||
bot = self.bots[user.user]
|
||||
logger.info(f"Starting bot for {user.username}")
|
||||
bot = self.bots[user.username]
|
||||
task = asyncio.create_task(bot.start())
|
||||
task.add_done_callback(lambda x: logger.info(f"Bot for {user.user} stopped"))
|
||||
task.add_done_callback(
|
||||
lambda x: logger.info(f"Bot for {user.username} stopped")
|
||||
)
|
||||
bot.start_routines()
|
||||
|
||||
async def stop_user_bot(self, user: User):
|
||||
if user.user not in self.bots:
|
||||
if user.username not in self.bots:
|
||||
return
|
||||
bot = self.bots.pop(user.user)
|
||||
bot = self.bots.pop(user.username)
|
||||
await bot.close()
|
||||
bot.stop_routines()
|
||||
|
|
|
|||
|
|
@ -17,11 +17,11 @@ class TwitchAuthenticator(BaseModel):
|
|||
response = await self.client.post(
|
||||
"/oauth2/token",
|
||||
data={
|
||||
"client_id": Settings.get().twitch_client_id,
|
||||
"client_secret": Settings.get().twitch_client_secret.get_secret_value(),
|
||||
"client_id": self.s.twitch_client_id,
|
||||
"client_secret": self.s.twitch_client_secret.get_secret_value(),
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": f"{Settings.get().server_hostname}o/code",
|
||||
"redirect_uri": f"{self.s.server_hostname}o/code",
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,80 +0,0 @@
|
|||
import datetime
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import aiosqlite
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from huesoporro.models import ChatbotSettings, User
|
||||
from huesoporro.settings import Settings
|
||||
|
||||
|
||||
class Database(BaseModel):
|
||||
s: Settings = Field(default_factory=Settings.get)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_client(self, auto_commit=True):
|
||||
logger.info(f"Opening database connection: {self.s.db_filepath}")
|
||||
async with aiosqlite.connect(self.s.db_filepath) as db:
|
||||
yield db
|
||||
if auto_commit:
|
||||
await db.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_now() -> float:
|
||||
return datetime.datetime.now(datetime.UTC).timestamp()
|
||||
|
||||
async def save_chatbot_settings(
|
||||
self, user: User, chatbot_settings: ChatbotSettings, auto_commit: bool = True
|
||||
):
|
||||
async with self.get_client(auto_commit=auto_commit) as db:
|
||||
current_settings = await self.get_chatbot_settings(user)
|
||||
if current_settings:
|
||||
await db.execute(
|
||||
"""UPDATE settings SET
|
||||
automatic_generation_timer = ?,
|
||||
automatic_quote_timer = ?,
|
||||
mods = ?,
|
||||
last_updated_at = ?
|
||||
WHERE user_id = ?
|
||||
""",
|
||||
(
|
||||
chatbot_settings.automatic_generation_timer,
|
||||
chatbot_settings.automatic_quote_timer,
|
||||
chatbot_settings.mods_as_string,
|
||||
self.get_now(),
|
||||
user.user,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
await db.execute(
|
||||
"""INSERT INTO settings (
|
||||
user_id,
|
||||
automatic_generation_timer,
|
||||
automatic_quote_timer,
|
||||
mods,
|
||||
created_at,
|
||||
last_updated_at
|
||||
) VALUES(?,?,?,?,?,?)
|
||||
""",
|
||||
(
|
||||
user.user,
|
||||
chatbot_settings.automatic_generation_timer,
|
||||
chatbot_settings.automatic_quote_timer,
|
||||
chatbot_settings.mods_as_string,
|
||||
self.get_now(),
|
||||
self.get_now(),
|
||||
),
|
||||
)
|
||||
|
||||
async def get_chatbot_settings(self, user: User) -> ChatbotSettings | None:
|
||||
async with self.get_client() as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT * FROM settings WHERE user_id = ?", (user.user,)
|
||||
) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
if not result:
|
||||
return None
|
||||
return ChatbotSettings(**dict(result))
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
from collections import deque
|
||||
from hashlib import sha512
|
||||
from pathlib import Path
|
||||
|
||||
from gtts import gTTS
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.settings import Settings
|
||||
|
||||
|
||||
class GTTS(BaseModel):
|
||||
s: Settings
|
||||
chunk_size: int = 128
|
||||
text_max_length: int = 100
|
||||
queue: deque = deque()
|
||||
|
||||
async def generate(self, text: str, lang: str = "pt", tld="com.br") -> Path:
|
||||
text = text[: self.text_max_length]
|
||||
raw_filename = f"{text.lower()}_{lang}_{tld}"
|
||||
logger.info(f"Generating TTS for {raw_filename}")
|
||||
filepath = (
|
||||
self.s.tts_cache_path / f"{sha512(raw_filename.encode()).hexdigest()}.mp3"
|
||||
)
|
||||
tts = gTTS(text=text, lang=lang, tld=tld)
|
||||
logger.info(f"Saving TTS to {filepath}")
|
||||
tts.save(str(filepath))
|
||||
self.queue.append(filepath)
|
||||
return filepath
|
||||
|
||||
async def consume(self):
|
||||
"""If there are items in the queue, return a generator
|
||||
that reads the file's bytes by chunks of self.chunk_size"""
|
||||
while self.queue:
|
||||
filepath = self.queue.popleft()
|
||||
if not filepath.exists():
|
||||
logger.warning(f"File {filepath} does not exist, skipping")
|
||||
continue
|
||||
|
||||
logger.info(f"Reading file {filepath}")
|
||||
try:
|
||||
with filepath.open("rb") as f:
|
||||
while chunk := f.read(self.chunk_size):
|
||||
yield chunk
|
||||
logger.info(f"Finished reading {filepath}")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
|
@ -2,11 +2,13 @@ import json
|
|||
from abc import ABC, abstractmethod
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Generic, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
import aiosqlite
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from huesoporro.models import Quote, User
|
||||
from huesoporro import utils
|
||||
from huesoporro.models import Chatbot, Quote, User
|
||||
from huesoporro.settings import Settings
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
|
@ -25,75 +27,116 @@ class IRepo(BaseModel, ABC, Generic[T]):
|
|||
|
||||
@abstractmethod
|
||||
async def create(self, obj: T, auto_commit=True) -> T:
|
||||
pass
|
||||
pass # pragma: no cover
|
||||
|
||||
@abstractmethod
|
||||
async def update(self, obj: T, auto_commit=True) -> T:
|
||||
pass
|
||||
pass # pragma: no cover
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, obj: T, auto_commit=True):
|
||||
pass
|
||||
pass # pragma: no cover
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_id(self, obj_id: int | str, auto_commit=True) -> T | None:
|
||||
pass
|
||||
async def get_by_id(self, obj_id: UUID, auto_commit=True) -> T | None:
|
||||
pass # pragma: no cover
|
||||
|
||||
@abstractmethod
|
||||
async def list(
|
||||
self, obj: T, offset: int = 0, limit: int = 10, auto_commit=True
|
||||
) -> list[T]:
|
||||
pass
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
class UserRepo(IRepo[User]):
|
||||
async def get_by_id(self, obj_id: int | str, auto_commit=True) -> User | None:
|
||||
raise NotImplementedError("Not implemented since it's not needed")
|
||||
@staticmethod
|
||||
def _deserialize(data: dict) -> User:
|
||||
return User(
|
||||
id=UUID(data["id"]),
|
||||
username=data["username"],
|
||||
created_at=data["created_at"],
|
||||
last_updated_at=data["last_updated_at"],
|
||||
external_auth=json.loads(data["external_auth"]),
|
||||
)
|
||||
|
||||
async def get_by_id(self, obj_id: UUID, auto_commit=True) -> User | None:
|
||||
async with (
|
||||
self.get_client(auto_commit=auto_commit) as db,
|
||||
await db.execute(
|
||||
"""
|
||||
SELECT * FROM users WHERE id = ?
|
||||
""",
|
||||
(obj_id.hex,),
|
||||
) as cursor,
|
||||
):
|
||||
data = await cursor.fetchone()
|
||||
if not data:
|
||||
return None
|
||||
return self._deserialize(data)
|
||||
|
||||
async def create(self, obj: User, auto_commit=True) -> User:
|
||||
async with self.get_client(auto_commit=auto_commit) as db:
|
||||
if await self.get_by_username(obj.username):
|
||||
raise ValueError(f"User {obj.username} already exists")
|
||||
async with (
|
||||
self.get_client(auto_commit=auto_commit) as db,
|
||||
await db.execute(
|
||||
"INSERT INTO users (user, external_auth) VALUES (?, ?)",
|
||||
(obj.user, json.dumps(obj.external_auth)),
|
||||
)
|
||||
return obj
|
||||
"""INSERT INTO users (id, username, external_auth, created_at, last_updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
RETURNING *
|
||||
""",
|
||||
(
|
||||
obj.id.hex,
|
||||
obj.username,
|
||||
obj.serialize_external_auth(),
|
||||
obj.created_at,
|
||||
obj.last_updated_at,
|
||||
),
|
||||
) as cursor,
|
||||
):
|
||||
data = await cursor.fetchone()
|
||||
return self._deserialize(data)
|
||||
|
||||
async def update(self, obj: User, auto_commit=True) -> User:
|
||||
if not await self.get_by_user(obj.user):
|
||||
raise ValueError(f"User {obj.user} does not exist")
|
||||
if not await self.get_by_id(obj.id):
|
||||
raise ValueError(f"User {obj.username} does not exist")
|
||||
|
||||
async with (
|
||||
self.get_client(auto_commit=auto_commit) as db,
|
||||
db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET external_auth = ?
|
||||
WHERE user = ?
|
||||
SET username = ?,
|
||||
external_auth = ?,
|
||||
last_updated_at = ?
|
||||
WHERE id = ?
|
||||
RETURNING *
|
||||
""",
|
||||
(json.dumps(obj.external_auth), obj.user),
|
||||
(
|
||||
obj.username,
|
||||
obj.serialize_external_auth(),
|
||||
obj.last_updated_at,
|
||||
obj.id.hex,
|
||||
),
|
||||
) as cursor,
|
||||
):
|
||||
data = await cursor.fetchone()
|
||||
return User(
|
||||
user=data["user"], external_auth=json.loads(data["external_auth"])
|
||||
)
|
||||
return self._deserialize(data)
|
||||
|
||||
async def delete(self, obj: User, auto_commit=True):
|
||||
async with self.get_client(auto_commit=auto_commit) as db:
|
||||
await db.execute(
|
||||
"""
|
||||
DELETE FROM users WHERE user = ?
|
||||
DELETE FROM users WHERE id = ?
|
||||
""",
|
||||
(obj.user,),
|
||||
(obj.id.hex,),
|
||||
)
|
||||
|
||||
async def get_by_user(self, user: str, auto_commit=True) -> User | None:
|
||||
async def get_by_username(self, user: str, auto_commit=True) -> User | None:
|
||||
async with (
|
||||
self.get_client(auto_commit=auto_commit) as db,
|
||||
db.execute(
|
||||
"""
|
||||
SELECT * FROM users WHERE user = ?
|
||||
SELECT * FROM users WHERE username = ?
|
||||
""",
|
||||
(user,),
|
||||
) as cursor,
|
||||
|
|
@ -102,49 +145,69 @@ class UserRepo(IRepo[User]):
|
|||
if not data:
|
||||
return None
|
||||
return User(
|
||||
user=data["user"], external_auth=json.loads(data["external_auth"])
|
||||
id=UUID(data["id"]),
|
||||
username=data["username"],
|
||||
created_at=data["created_at"],
|
||||
last_updated_at=data["last_updated_at"],
|
||||
external_auth=json.loads(data["external_auth"]),
|
||||
)
|
||||
|
||||
async def list(
|
||||
async def list( # type: ignore[empty-body]
|
||||
self, obj: User, offset: int = 0, limit: int = 10, auto_commit=True
|
||||
) -> list[User]:
|
||||
raise NotImplementedError("Not implemented since it's not needed")
|
||||
pass # pragma: no cover
|
||||
|
||||
async def count(self, obj: User, auto_commit=True):
|
||||
raise NotImplementedError("Not implemented since it's not needed")
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
class QuoteRepo(IRepo[Quote]):
|
||||
@staticmethod
|
||||
def _deserialize(data: dict) -> Quote:
|
||||
return Quote(
|
||||
id=UUID(data["id"]),
|
||||
quote=data["quote"],
|
||||
author=data["author"],
|
||||
channel_name=data["channel"],
|
||||
created_at=data["created_at"],
|
||||
last_updated_at=data["last_updated_at"],
|
||||
)
|
||||
|
||||
async def create(self, obj: Quote, auto_commit=True) -> Quote:
|
||||
async with self.get_client(auto_commit=auto_commit) as db:
|
||||
async with (
|
||||
self.get_client(auto_commit=auto_commit) as db,
|
||||
await db.execute(
|
||||
"""
|
||||
INSERT INTO quotes (quote, author, channel, created_at, last_updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
INSERT INTO quotes (id, quote, author, channel, created_at, last_updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
RETURNING *
|
||||
""",
|
||||
(
|
||||
obj.id.hex,
|
||||
obj.quote,
|
||||
obj.author.user,
|
||||
obj.channel.user,
|
||||
obj.author,
|
||||
obj.channel_name,
|
||||
obj.created_at,
|
||||
obj.last_updated_at,
|
||||
),
|
||||
)
|
||||
return obj
|
||||
) as cursor,
|
||||
):
|
||||
data = await cursor.fetchone()
|
||||
return self._deserialize(data)
|
||||
|
||||
async def update(self, obj: Quote, auto_commit=True) -> Quote:
|
||||
raise NotImplementedError("Not implemented since it's not needed")
|
||||
async def update(self, obj: Quote, auto_commit=True) -> Quote: # type: ignore[empty-body]
|
||||
pass # pragma: no cover
|
||||
|
||||
async def delete(self, obj: Quote, auto_commit=True):
|
||||
raise NotImplementedError("Not implemented since it's not needed")
|
||||
pass # pragma: no cover
|
||||
|
||||
async def get_by_id(self, obj_id: int | str, auto_commit=True) -> Quote | None:
|
||||
raise NotImplementedError("Not implemented since it's not needed")
|
||||
async def get_by_id(self, obj_id: UUID, auto_commit=True) -> Quote | None: # type: ignore[empty-body]
|
||||
pass # pragma: no cover
|
||||
|
||||
async def list(
|
||||
async def list( # type: ignore[empty-body]
|
||||
self, obj: T, offset: int = 0, limit: int = 10, auto_commit=True
|
||||
) -> list[T]:
|
||||
raise NotImplementedError("Not implemented since it's not needed")
|
||||
pass # pragma: no cover
|
||||
|
||||
async def get_random(self, channel_name: str, auto_commit=True) -> Quote | None:
|
||||
async with (
|
||||
|
|
@ -162,10 +225,97 @@ class QuoteRepo(IRepo[Quote]):
|
|||
data = await cursor.fetchone()
|
||||
if not data:
|
||||
return None
|
||||
return Quote(
|
||||
quote=data["quote"],
|
||||
author=User(user=data["author"], external_auth={}),
|
||||
channel=User(user=data["channel"], external_auth={}),
|
||||
created_at=data["created_at"],
|
||||
last_updated_at=data["last_updated_at"],
|
||||
)
|
||||
return self._deserialize(data)
|
||||
|
||||
|
||||
class ChatbotRepo(IRepo[Chatbot]):
|
||||
@staticmethod
|
||||
def _deserialize(data: dict) -> Chatbot:
|
||||
return Chatbot(
|
||||
id=UUID(data["id"]),
|
||||
user_id=data["user_id"],
|
||||
automatic_generation_timer=data["automatic_generation_timer"],
|
||||
automatic_quote_timer=data["automatic_quote_timer"],
|
||||
mods=data["mods"].split(","),
|
||||
last_updated_at=data["last_updated_at"],
|
||||
created_at=data["created_at"],
|
||||
)
|
||||
|
||||
async def create(self, obj: Chatbot, auto_commit=True) -> Chatbot:
|
||||
if await self.get_by_user_id(obj.user_id):
|
||||
raise ValueError(f"Chatbot {obj.user_id} already exists")
|
||||
async with (
|
||||
self.get_client(auto_commit=auto_commit) as db,
|
||||
await db.execute(
|
||||
"""INSERT INTO chatbot (
|
||||
id,
|
||||
user_id,
|
||||
automatic_generation_timer,
|
||||
automatic_quote_timer,
|
||||
mods,
|
||||
created_at,
|
||||
last_updated_at
|
||||
) VALUES(?,?,?,?,?,?,?)
|
||||
RETURNING *
|
||||
""",
|
||||
(
|
||||
obj.id.hex,
|
||||
obj.user_id.hex,
|
||||
obj.automatic_generation_timer,
|
||||
obj.automatic_quote_timer,
|
||||
obj.mods_as_string,
|
||||
obj.created_at,
|
||||
obj.last_updated_at,
|
||||
),
|
||||
) as cursor,
|
||||
):
|
||||
data = await cursor.fetchone()
|
||||
return self._deserialize(data)
|
||||
|
||||
async def update(self, obj: Chatbot, auto_commit=True) -> Chatbot:
|
||||
if not await self.get_by_user_id(obj.user_id):
|
||||
raise ValueError(f"Chatbot {obj.user_id} does not exist")
|
||||
async with (
|
||||
self.get_client(auto_commit=auto_commit) as db,
|
||||
await db.execute(
|
||||
"""UPDATE chatbot SET
|
||||
automatic_generation_timer = ?,
|
||||
automatic_quote_timer = ?,
|
||||
mods = ?,
|
||||
last_updated_at = ?
|
||||
WHERE user_id = ?
|
||||
RETURNING *
|
||||
""",
|
||||
(
|
||||
obj.automatic_generation_timer,
|
||||
obj.automatic_quote_timer,
|
||||
obj.mods_as_string,
|
||||
utils.get_utc_now(),
|
||||
obj.user_id.hex,
|
||||
),
|
||||
) as cursor,
|
||||
):
|
||||
data = await cursor.fetchone()
|
||||
return self._deserialize(data)
|
||||
|
||||
async def delete(self, obj: T, auto_commit=True):
|
||||
pass # pragma: no cover
|
||||
|
||||
async def get_by_id(self, obj_id: UUID, auto_commit=True) -> Chatbot | None: # type: ignore[empty-body]
|
||||
pass # pragma: no cover
|
||||
|
||||
async def get_by_user_id(self, user_id: UUID) -> Chatbot | None:
|
||||
async with self.get_client() as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT * FROM chatbot WHERE user_id = ?", (user_id.hex,)
|
||||
) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
if not result:
|
||||
return None
|
||||
return self._deserialize(result)
|
||||
|
||||
async def list( # type: ignore[empty-body]
|
||||
self, obj: T, offset: int = 0, limit: int = 10, auto_commit=True
|
||||
) -> list[T]:
|
||||
pass # pragma: no cover
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import datetime
|
||||
import json
|
||||
from typing import Literal
|
||||
|
||||
import jwt
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import UUID4, AwareDatetime, BaseModel, Field, field_validator
|
||||
|
||||
from huesoporro import utils
|
||||
from huesoporro.settings import Settings
|
||||
|
||||
|
||||
|
|
@ -19,8 +21,11 @@ class ExternalAuth(BaseModel):
|
|||
|
||||
|
||||
class User(BaseModel):
|
||||
user: str
|
||||
external_auth: dict[Literal["twitch", "discord"], dict]
|
||||
id: UUID4
|
||||
username: str
|
||||
external_auth: dict[Literal["twitch", "discord"], TwitchAuth]
|
||||
created_at: AwareDatetime = Field(default_factory=utils.get_utc_now)
|
||||
last_updated_at: AwareDatetime = Field(default_factory=utils.get_utc_now)
|
||||
|
||||
def encode(
|
||||
self, settings: Settings | None = None, exclude_fields: set[str] | None = None
|
||||
|
|
@ -28,7 +33,7 @@ class User(BaseModel):
|
|||
s = settings or Settings.get()
|
||||
exclude_fields = exclude_fields or {"external_auth"}
|
||||
return jwt.encode(
|
||||
self.model_dump(exclude=exclude_fields),
|
||||
self.model_dump(exclude=exclude_fields, mode="json"),
|
||||
key=s.jwt_secret.get_secret_value(),
|
||||
algorithm="HS256",
|
||||
)
|
||||
|
|
@ -42,12 +47,43 @@ class User(BaseModel):
|
|||
|
||||
@property
|
||||
def twitch_access_token(self):
|
||||
return self.external_auth["twitch"]["access_token"]
|
||||
return self.external_auth["twitch"].access_token
|
||||
|
||||
@property
|
||||
def twitch_refresh_token(self):
|
||||
return self.external_auth["twitch"].refresh_token
|
||||
|
||||
@twitch_access_token.setter # type: ignore[attr-defined,no-redef]
|
||||
def twitch_access_token(self, value):
|
||||
self.external_auth["twitch"].access_token = value
|
||||
|
||||
@twitch_refresh_token.setter # type: ignore[attr-defined,no-redef]
|
||||
def twitch_refresh_token(self, value):
|
||||
self.external_auth["twitch"].refresh_token = value
|
||||
|
||||
def serialize_external_auth(self) -> str:
|
||||
"""Return a JSON string with the inner pydantic model of external_auth serialized using model_dump"""
|
||||
|
||||
return json.dumps({k: v.model_dump() for k, v in self.external_auth.items()})
|
||||
|
||||
|
||||
class ChatbotSettings(BaseModel):
|
||||
class Chatbot(BaseModel):
|
||||
"""A chatbot is an entity that holds settings for a given user, it is NOT tied to a channel.
|
||||
|
||||
Attributes:
|
||||
id (UUID4): The unique identifier for the chatbot.
|
||||
user_id (UUID): The user_id of the user that owns the chatbot.
|
||||
automatic_generation_timer (int): The timer for automatic generation of quotes.
|
||||
automatic_quote_timer (int): The timer for automatic quotes.
|
||||
mods (list[str]): The list of mods for the chatbot.
|
||||
"""
|
||||
|
||||
id: UUID4
|
||||
user_id: UUID4
|
||||
automatic_generation_timer: int = 300
|
||||
automatic_quote_timer: int = 500
|
||||
created_at: AwareDatetime = Field(default_factory=utils.get_utc_now)
|
||||
last_updated_at: AwareDatetime = Field(default_factory=utils.get_utc_now)
|
||||
mods: list[str] = Field(default_factory=list)
|
||||
|
||||
@property
|
||||
|
|
@ -64,23 +100,16 @@ class ChatbotSettings(BaseModel):
|
|||
return v
|
||||
|
||||
|
||||
class Sentence(BaseModel):
|
||||
id: int
|
||||
sentence: str
|
||||
created_at: float
|
||||
last_updated_at: float
|
||||
user: User
|
||||
|
||||
|
||||
class Quote(BaseModel):
|
||||
id: UUID4
|
||||
quote: str
|
||||
author: User
|
||||
channel: User
|
||||
created_at: datetime.datetime
|
||||
last_updated_at: datetime.datetime
|
||||
author: str
|
||||
channel_name: str
|
||||
created_at: datetime.datetime = Field(default_factory=utils.get_utc_now)
|
||||
last_updated_at: datetime.datetime = Field(default_factory=utils.get_utc_now)
|
||||
|
||||
def as_pretty(self) -> str:
|
||||
return f"«{self.quote}» - {self.author.user}"
|
||||
return f"«{self.quote}» - {self.author}"
|
||||
|
||||
def as_pretty_saved(self):
|
||||
return f"He añadido la cita «{self.quote}» de {self.author.user}"
|
||||
return f"He añadido la cita «{self.quote}» de {self.author}"
|
||||
|
|
|
|||
27
src/huesoporro/svc/chatbot_svcs.py
Normal file
27
src/huesoporro/svc/chatbot_svcs.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.infra.repos import ChatbotRepo
|
||||
from huesoporro.models import Chatbot
|
||||
|
||||
|
||||
class CreateChatbotSvc(BaseModel):
|
||||
repo: ChatbotRepo
|
||||
|
||||
async def run(self, chatbot: Chatbot):
|
||||
return await self.repo.create(chatbot)
|
||||
|
||||
|
||||
class GetChatbotByUserIdSvc(BaseModel):
|
||||
repo: ChatbotRepo
|
||||
|
||||
async def run(self, user_id: UUID) -> Chatbot | None:
|
||||
return await self.repo.get_by_user_id(user_id=user_id)
|
||||
|
||||
|
||||
class UpdateChatbotSvc(BaseModel):
|
||||
repo: ChatbotRepo
|
||||
|
||||
async def run(self, chatbot: Chatbot):
|
||||
return await self.repo.update(obj=chatbot)
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.infra.db import Database
|
||||
from huesoporro.models import ChatbotSettings, User
|
||||
|
||||
|
||||
class ChatbotSettingsGetterSvc(BaseModel):
|
||||
db: Database
|
||||
|
||||
async def run(self, user: User) -> ChatbotSettings | None:
|
||||
return await self.db.get_chatbot_settings(user=user)
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.infra.repos import QuoteRepo
|
||||
from huesoporro.models import Quote
|
||||
|
||||
|
||||
class RandomQuoteGetterSvc(BaseModel):
|
||||
quote_repo: QuoteRepo
|
||||
|
||||
async def run(self, channel_name: str) -> Quote | None:
|
||||
return await self.quote_repo.get_random(channel_name=channel_name)
|
||||
|
|
@ -32,4 +32,4 @@ class HelloGeneratorSvc(BaseModel):
|
|||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_hello_generator_svc() -> HelloGeneratorSvc:
|
||||
return HelloGeneratorSvc()
|
||||
return HelloGeneratorSvc() # pragma: no cover
|
||||
|
|
|
|||
|
|
@ -1,20 +1,26 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.infra.db import Database
|
||||
from huesoporro.infra.repos import ChatbotRepo
|
||||
from huesoporro.models import User
|
||||
|
||||
|
||||
class IsModSvc(BaseModel):
|
||||
db: Database
|
||||
repo: ChatbotRepo
|
||||
|
||||
async def run(self, user: User, username: str, channel: str) -> bool:
|
||||
async def run(self, username: str, channel: str, user: User) -> bool:
|
||||
"""A user given username is a mod if they're the same as the current channel or if they're in the modlist
|
||||
available in a user's settings"""
|
||||
available in a user's settings
|
||||
|
||||
Args:
|
||||
username (str): The username to check if it's a mod
|
||||
channel (str): The current channel
|
||||
user (User): User object the chatbot belongs to
|
||||
"""
|
||||
|
||||
if channel == username:
|
||||
return True
|
||||
|
||||
chatbot_settings = await self.db.get_chatbot_settings(user=user)
|
||||
if not chatbot_settings:
|
||||
chatbot = await self.repo.get_by_user_id(user_id=user.id)
|
||||
if not chatbot:
|
||||
return False
|
||||
return username in chatbot_settings.mods
|
||||
return username in chatbot.mods
|
||||
|
|
|
|||
|
|
@ -1,11 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.infra.repos import QuoteRepo
|
||||
from huesoporro.models import Quote
|
||||
|
||||
|
||||
class QuoteStorerSvc(BaseModel):
|
||||
quote_repo: QuoteRepo
|
||||
|
||||
async def run(self, quote: Quote) -> Quote:
|
||||
return await self.quote_repo.create(quote)
|
||||
18
src/huesoporro/svc/quotes_svcs.py
Normal file
18
src/huesoporro/svc/quotes_svcs.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.infra.repos import QuoteRepo
|
||||
from huesoporro.models import Quote
|
||||
|
||||
|
||||
class GetRandomQuoteSvc(BaseModel):
|
||||
repo: QuoteRepo
|
||||
|
||||
async def run(self, channel_name: str) -> Quote | None:
|
||||
return await self.repo.get_random(channel_name=channel_name)
|
||||
|
||||
|
||||
class CreateQuoteSvc(BaseModel):
|
||||
repo: QuoteRepo
|
||||
|
||||
async def run(self, quote: Quote) -> Quote:
|
||||
return await self.repo.create(obj=quote)
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
# pragma: no cover
|
||||
from loguru import logger
|
||||
from nltk.tokenize import sent_tokenize
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
|
|
|||
|
|
@ -1,13 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro.infra.db import Database
|
||||
from huesoporro.models import ChatbotSettings, User
|
||||
|
||||
|
||||
class ChatbotSettingsStorerSvc(BaseModel):
|
||||
db: Database
|
||||
|
||||
async def run(self, user: User, bot_settings: ChatbotSettings):
|
||||
return await self.db.save_chatbot_settings(
|
||||
user=user, chatbot_settings=bot_settings
|
||||
)
|
||||
127
src/huesoporro/svc/users_svcs.py
Normal file
127
src/huesoporro/svc/users_svcs.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from huesoporro import utils
|
||||
from huesoporro.infra.authenticator import TwitchAuthenticator
|
||||
from huesoporro.infra.repos import UserRepo
|
||||
from huesoporro.models import TwitchAuth, User
|
||||
from huesoporro.settings import Settings
|
||||
|
||||
|
||||
class CreateUserSvc(BaseModel):
|
||||
user_repo: UserRepo
|
||||
|
||||
async def run(self, user: User) -> User:
|
||||
"""Create a new user in the system
|
||||
|
||||
Args:
|
||||
user: User object to be created
|
||||
|
||||
Returns:
|
||||
The created User with any system-generated fields populated
|
||||
"""
|
||||
return await self.user_repo.create(user)
|
||||
|
||||
|
||||
class UpdateUserSvc(BaseModel):
|
||||
user_repo: UserRepo
|
||||
|
||||
async def run(self, user: User) -> User:
|
||||
"""Update an existing user in the system
|
||||
|
||||
Args:
|
||||
user: User object with updated fields
|
||||
|
||||
Returns:
|
||||
The updated User
|
||||
|
||||
Raises:
|
||||
ValueError: If the user doesn't exist
|
||||
"""
|
||||
user.last_updated_at = utils.get_utc_now()
|
||||
return await self.user_repo.update(user)
|
||||
|
||||
|
||||
class DeleteUserSvc(BaseModel):
|
||||
user_repo: UserRepo
|
||||
|
||||
async def run(self, user: User) -> None:
|
||||
"""Delete a user from the system
|
||||
|
||||
Args:
|
||||
user: User object to be deleted
|
||||
"""
|
||||
await self.user_repo.delete(user)
|
||||
|
||||
|
||||
class GetUserByIdSvc(BaseModel):
|
||||
user_repo: UserRepo
|
||||
|
||||
async def run(self, user_id: UUID) -> User | None:
|
||||
"""Retrieve a user by their ID
|
||||
|
||||
Args:
|
||||
user_id: UUID of the user to retrieve
|
||||
|
||||
Returns:
|
||||
User object if found, None otherwise
|
||||
"""
|
||||
return await self.user_repo.get_by_id(user_id)
|
||||
|
||||
|
||||
class GetUserByUsernameSvc(BaseModel):
|
||||
user_repo: UserRepo
|
||||
|
||||
async def run(self, username: str) -> User | None:
|
||||
"""Retrieve a user by their username
|
||||
|
||||
Args:
|
||||
username: Username of the user to retrieve
|
||||
|
||||
Returns:
|
||||
User object if found, None otherwise
|
||||
"""
|
||||
return await self.user_repo.get_by_username(username)
|
||||
|
||||
|
||||
class IsValidTokenSvc(BaseModel):
|
||||
authenticator: TwitchAuthenticator
|
||||
|
||||
async def run(self, user: User) -> bool:
|
||||
return await self.authenticator.token_is_valid(user.twitch_access_token)
|
||||
|
||||
|
||||
class RefreshTokenSvc(BaseModel):
|
||||
twitch_authenticator: TwitchAuthenticator
|
||||
|
||||
async def run(self, user: User) -> User:
|
||||
"""Refresh a user's Twitch token
|
||||
|
||||
Args:
|
||||
user: User object with the Twitch token to be refreshed
|
||||
|
||||
Returns:
|
||||
The updated User with the refreshed Twitch token
|
||||
"""
|
||||
|
||||
logger.info(f"Refreshing token for user {user}")
|
||||
twitch_auth = await self.twitch_authenticator.refresh_token(
|
||||
user.twitch_refresh_token
|
||||
)
|
||||
user.twitch_access_token = twitch_auth.access_token # type: ignore[misc]
|
||||
user.twitch_refresh_token = twitch_auth.refresh_token # type: ignore[misc]
|
||||
return user
|
||||
|
||||
|
||||
class GetTwitchAuthByAuthCodeSvc(BaseModel):
|
||||
authenticator: TwitchAuthenticator
|
||||
s: Settings
|
||||
|
||||
async def run(self, auth_code: str) -> TwitchAuth:
|
||||
auth = await self.authenticator.get_token(auth_code)
|
||||
username = auth.userinfo["preferred_username"]
|
||||
if username not in self.s.allowed_users:
|
||||
raise ValueError(f"User {username} is not allowed to use this bot")
|
||||
return auth
|
||||
|
|
@ -1,131 +0,0 @@
|
|||
import asyncio
|
||||
from collections import deque
|
||||
from hashlib import sha512
|
||||
from pathlib import Path
|
||||
|
||||
from gtts import gTTS
|
||||
from litestar import WebSocket
|
||||
from loguru import logger
|
||||
|
||||
from huesoporro.settings import Settings
|
||||
|
||||
|
||||
class TTSManager:
|
||||
TEXT_MAX_LENGTH: int = 400
|
||||
|
||||
def __init__(self, max_queue_size=10):
|
||||
self.queue: deque = deque(maxlen=max_queue_size)
|
||||
|
||||
# Connected WebSocket clients
|
||||
self.clients: list[WebSocket] = []
|
||||
|
||||
# Currently playing audio
|
||||
self.current_audio = None
|
||||
|
||||
# Lock to prevent race conditions
|
||||
self._lock = asyncio.Lock()
|
||||
self._tasks = []
|
||||
self.s = Settings.get()
|
||||
|
||||
def generate_tts(self, text, language="pt", tld="com.br"):
|
||||
# Generate unique filename
|
||||
text = text[0 : self.TEXT_MAX_LENGTH]
|
||||
filename = (
|
||||
self.s.tts_cache_path / f"{sha512(text.lower().encode()).hexdigest()}.mp3"
|
||||
)
|
||||
|
||||
if filename.exists():
|
||||
logger.info(
|
||||
f"TTS already exists for '{text[:50]}' at {filename}. Returning it"
|
||||
)
|
||||
return {
|
||||
"filename": filename.name,
|
||||
"text": text,
|
||||
"filepath": str(filename),
|
||||
"language": language,
|
||||
"tld": tld,
|
||||
}
|
||||
logger.info(f"Generating TTS for '{text[:50]}'")
|
||||
|
||||
# Generate TTS
|
||||
tts = gTTS(text=text, lang=language, tld=tld)
|
||||
tts.save(str(filename))
|
||||
|
||||
return {
|
||||
"filename": filename.name,
|
||||
"text": text,
|
||||
"filepath": filename,
|
||||
"language": language,
|
||||
"tld": tld,
|
||||
}
|
||||
|
||||
async def add_to_queue(self, text, language="pt", tld="com.br"):
|
||||
"""Add TTS request to queue and start processing if not already running"""
|
||||
async with self._lock:
|
||||
# Generate TTS file
|
||||
audio_info = self.generate_tts(text, language, tld)
|
||||
|
||||
# Add to queue
|
||||
self.queue.append(audio_info)
|
||||
|
||||
# If this is the only item, start processing
|
||||
if len(self.queue) == 1:
|
||||
self._tasks.append(asyncio.create_task(self.process_queue()))
|
||||
|
||||
return audio_info
|
||||
|
||||
async def process_queue(self):
|
||||
"""Process queue and stream audio to connected clients"""
|
||||
while True:
|
||||
async with self._lock:
|
||||
# Check if queue is empty
|
||||
if not self.queue:
|
||||
return
|
||||
|
||||
# Get next audio file
|
||||
audio_info = self.queue[0]
|
||||
|
||||
try:
|
||||
# Read the entire audio file
|
||||
audio_path = Path(audio_info["filepath"])
|
||||
with audio_path.open("rb") as audio_file:
|
||||
file_size = audio_path.stat().st_size
|
||||
logger.info(
|
||||
f"Streaming file: {audio_info['filename']}, Size: {file_size} bytes"
|
||||
)
|
||||
|
||||
# Stream audio to all connected clients
|
||||
for client in self.clients:
|
||||
try:
|
||||
# Reset file pointer to beginning
|
||||
audio_file.seek(0)
|
||||
|
||||
# Send file size first (as a header)
|
||||
await client.send_text(f"FILE_HEADER:{file_size}")
|
||||
|
||||
# Stream file in chunks
|
||||
chunk = audio_file.read(128) # Larger chunk size
|
||||
chunk_count = 0
|
||||
while chunk:
|
||||
logger.info(f"Streamed {chunk_count} chunks")
|
||||
chunk_count += 1
|
||||
await client.send_bytes(chunk)
|
||||
chunk = audio_file.read(128)
|
||||
|
||||
# Send file footer
|
||||
await client.send_text("FILE_FOOTER")
|
||||
|
||||
except Exception: # noqa: BLE001
|
||||
logger.error(
|
||||
f"Error streaming to client {client.client}. Removing it."
|
||||
)
|
||||
if client in self.clients:
|
||||
self.clients.remove(client)
|
||||
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Error processing audio file: {e}")
|
||||
|
||||
# Remove the processed item from the queue
|
||||
async with self._lock:
|
||||
if self.queue and self.queue[0] == audio_info:
|
||||
self.queue.popleft()
|
||||
5
src/huesoporro/utils.py
Normal file
5
src/huesoporro/utils.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
import datetime
|
||||
|
||||
|
||||
def get_utc_now():
|
||||
return datetime.datetime.now(datetime.UTC)
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class WebsocketCommands(StrEnum):
|
||||
TTS_SEND = "tts_send"
|
||||
CHATBOT_START = "chatbot_start"
|
||||
CHATBOT_STOP = "chatbot_stop"
|
||||
CHATBOT_STATUS = "chatbot_status"
|
||||
CHATBOT_UPDATE = "chatbot_update"
|
||||
|
||||
|
||||
class WebsocketMessage(BaseModel):
|
||||
command: WebsocketCommands
|
||||
data: dict
|
||||
Loading…
Add table
Add a link
Reference in a new issue