106 lines
3 KiB
Python
106 lines
3 KiB
Python
from typing import Literal
|
|
|
|
from litestar import MediaType, Response, get, put
|
|
from litestar.datastructures import UploadFile
|
|
from litestar.response import Template
|
|
from pydantic import BaseModel, ConfigDict
|
|
|
|
from huesoporro.bot import BotsManager
|
|
from huesoporro.models import ChatbotSettings, User
|
|
from huesoporro.svc.get_chatbot_settings import ChatbotSettingsGetterSvc
|
|
from huesoporro.svc.store_settings import ChatbotSettingsStorerSvc
|
|
|
|
|
|
class ManageBotDTO(BaseModel):
    """Request body accepted by ``PUT /api/v1/bot``.

    ``command`` picks the action to apply to the caller's bot. ``channel_name``
    is optional at the schema level; the handler enforces it for ``start``.
    """

    # Which action to perform; Pydantic rejects any other value.
    command: Literal["start", "stop"]
    # Channel the bot should join; None when the client omits it.
    channel_name: str | None = None
|
|
|
|
|
|
class ImportTextFileDTO(BaseModel):
    """Payload pairing an uploaded file with a channel name.

    Presumably used to import a text file for that channel's bot; it is not
    referenced by the handlers shown alongside it — confirm its consumer.
    """

    # UploadFile is not a Pydantic-native type, so arbitrary types must be
    # allowed for the model to be constructible at all.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # The uploaded file object.
    file: UploadFile
    # Name of the channel the file belongs to.
    channel_name: str
|
|
|
|
|
|
@get("/tts", media_type=MediaType.HTML)
async def get_tts_overlay(user: User) -> Template:
    """Render the text-to-speech overlay page (``tts.html``).

    ``user`` is never read in the body. Declaring it presumably forces the
    framework to resolve an authenticated user before serving the page —
    confirm against the app's dependency/auth configuration.
    """
    overlay = Template(template_name="tts.html")
    return overlay
|
|
|
|
|
|
@get("/tts/permalink", media_type=MediaType.HTML)
async def get_tts_permalink(access_token: str) -> Template:
    """Serve the /tts overlay to apps that can only authenticate via a query parameter.

    Some clients (e.g. OBS) cannot send the authentication cookie, so they call
    this permalink with the token in the ``access_token`` query parameter instead.
    """
    # NOTE(review): access_token is a required query parameter but is never
    # checked in this handler. Presumably auth middleware or the template's own
    # script validates it — confirm; otherwise this page is served to anyone.
    overlay = Template(template_name="tts.html")
    return overlay
|
|
|
|
|
|
@get("/", media_type=MediaType.HTML)
async def get_index(user: User, gbs: ChatbotSettingsGetterSvc) -> Template:
    """Render ``index.html`` with the user's chatbot settings as template context.

    If the user has no stored settings, the template receives an empty context.
    """
    settings = await gbs.run(user=user)
    if settings:
        page_context = settings.model_dump()
    else:
        page_context = {}
    return Template(template_name="index.html", context=page_context)
|
|
|
|
|
|
@put("/api/v1/bot")
async def manage_bot(
    user: User,
    data: ManageBotDTO,
    gbs: ChatbotSettingsGetterSvc,
    sbs: ChatbotSettingsStorerSvc,
    bm: BotsManager,
) -> Response:
    """Start or stop the authenticated user's chatbot.

    ``start`` requires ``data.channel_name``. The user's chatbot settings are
    loaded; if the user has none yet, ``ChatbotSettings()`` defaults are stored
    first and then re-read. The settings are passed to ``bm.add_bot`` and the
    bot is run once it is registered in ``bm.bots``.

    ``stop`` stops the user's bot when one is registered in ``bm.bots``.

    Args:
        user: The authenticated user issuing the command.
        data: The command and, for ``start``, the channel to join.
        gbs: Service that reads the user's chatbot settings.
        sbs: Service used to persist default settings on a first start.
        bm: Manager holding the registered bots.

    Returns:
        A JSON ``Response`` with a ``message`` key. Status 400 for a missing
        channel name, for stopping a bot that is not registered, or for an
        unknown command. Status 500 when settings are still missing after
        storing defaults, or when ``add_bot`` did not register the bot.
    """
    if data.command == "start":
        channel_name = data.channel_name
        # Validate before touching storage so a rejected request has no side effects.
        if not channel_name:
            return Response({"message": "Channel name is required"}, status_code=400)

        chatbot_settings = await gbs.run(user=user)
        if not chatbot_settings:
            # First start for this user: persist defaults, then read them back.
            await sbs.run(user=user, bot_settings=ChatbotSettings())
            chatbot_settings = await gbs.run(user=user)
        if not chatbot_settings:
            # Without this guard a None would be handed to add_bot.
            return Response(
                {"message": "Could not load chatbot settings"}, status_code=500
            )

        bm.add_bot(user, channel_name, chatbot_settings=chatbot_settings)
        # Don't report success for a bot that was never registered.
        if user.user not in bm.bots:
            return Response({"message": "Bot could not be started"}, status_code=500)
        await bm.run_user_bot(user)
        return Response({"message": "Bot started"})

    if data.command == "stop":
        if user.user not in bm.bots:
            # A valid command with nothing to stop: say so rather than "Invalid command".
            return Response({"message": "Bot is not running"}, status_code=400)
        await bm.stop_user_bot(user)
        return Response({"message": "Bot stopped"})

    # Unreachable while ManageBotDTO.command is Literal["start", "stop"];
    # kept as a defensive fallback.
    return Response({"message": "Invalid command"}, status_code=400)
|
|
|
|
|
|
@get("/api/v1/bot")
async def get_bot_status(user: User, bm: BotsManager) -> dict:
    """Report whether the user has a bot registered in the bots manager.

    Returns ``{"status": "ok"}`` when ``user.user`` is in ``bm.bots``,
    otherwise ``{"status": "ko"}``.
    """
    has_bot = user.user in bm.bots
    return {"status": "ok" if has_bot else "ko"}
|
|
|
|
|
|
@get("/api/v1/bot/settings")
async def get_bot_settings(
    user: User, gbs: ChatbotSettingsGetterSvc
) -> ChatbotSettings | dict:
    """Return the user's stored chatbot settings.

    When the user has no settings, returns ``{"status": "Not found"}`` instead.
    """
    settings = await gbs.run(user=user)
    if settings:
        return settings
    # NOTE(review): this "Not found" body is sent with the handler's default
    # success status, not 404 — confirm the frontend relies on that before changing it.
    return {"status": "Not found"}
|
|
|
|
|
|
@put("/api/v1/bot/settings")
async def save_bot_settings(
    user: User, data: ChatbotSettings, sbs: ChatbotSettingsStorerSvc
) -> dict:
    """Store the submitted chatbot settings for the user and acknowledge with ``{"status": "ok"}``."""
    await sbs.run(bot_settings=data, user=user)
    return {"status": "ok"}
|