520 lines
20 KiB
Python
520 lines
20 KiB
Python
import string
|
|
import time
|
|
from enum import StrEnum
|
|
|
|
from loguru import logger
|
|
from nltk.tokenize import sent_tokenize
|
|
from TwitchWebsocket import Message, TwitchWebsocket
|
|
|
|
from src.markovbot_gui.libs.db import Database
|
|
from src.markovbot_gui.libs.settings import Settings
|
|
from src.markovbot_gui.libs.timer import LoopingTimer
|
|
from src.markovbot_gui.libs.tokenizer import detokenize, tokenize
|
|
|
|
|
|
class Commands(StrEnum):
|
|
SET_COOLDOWN = "!setcd"
|
|
GENERATE = "!g"
|
|
BLACKLIST = "!blacklist"
|
|
GENERATE_HELP = "!ghelp"
|
|
QUOTE = "!q"
|
|
QUOTE_ADD = "!qadd"
|
|
|
|
|
|
class MarkovChain:
|
|
end_tag = "<END>"
|
|
|
|
def __init__(self, settings: Settings | None = None):
|
|
self.s = settings or Settings.read()
|
|
self.prev_message_t = 0.0
|
|
self._enabled = True
|
|
|
|
self.db = Database(self.s.channel_name)
|
|
|
|
if self.s.help_message_timer > 0:
|
|
if self.s.help_message_timer < 300: # noqa: PLR2004
|
|
raise ValueError(
|
|
'Value for "HelpMessageTimer" in must be at least 300 seconds, ' # noqa: EM101
|
|
"or a negative number for no help messages.",
|
|
)
|
|
t = LoopingTimer(self.s.help_message_timer, self._command_help)
|
|
t.start()
|
|
|
|
# Set up daemon Timer to send automatic generation messages
|
|
if self.s.automatic_generation_timer > 0:
|
|
if self.s.automatic_generation_timer < 30: # noqa: PLR2004
|
|
raise ValueError(
|
|
'Value for "Automatic_generation_message" must be at least 30 seconds, or a negative number for no ' # noqa: EM101
|
|
"automatic generations.",
|
|
)
|
|
logger.info(
|
|
f"Automatic generation enabled, will send messages every {self.s.automatic_generation_timer} seconds"
|
|
)
|
|
t = LoopingTimer(
|
|
self.s.automatic_generation_timer,
|
|
self._command_automatic_generation,
|
|
)
|
|
t.start()
|
|
|
|
self.ws = TwitchWebsocket(
|
|
host=self.s.host,
|
|
port=self.s.port,
|
|
chan=self.s.channel_name,
|
|
nick=self.s.nickname,
|
|
auth=self.s.authentication,
|
|
callback=self.message_handler,
|
|
capability=["commands", "tags"],
|
|
live=True,
|
|
)
|
|
|
|
def run_bot(self):
|
|
self.ws.start_bot()
|
|
|
|
def stop_bot(self):
|
|
self.ws.leave_channel(self.s.channel_name)
|
|
self.ws.stop()
|
|
|
|
def _command_help(self) -> None:
|
|
"""Send a Help message to the connected chat, as long as the bot wasn't disabled."""
|
|
if self._enabled:
|
|
logger.info("Help message sent.")
|
|
try:
|
|
self.ws.send_message(
|
|
"Learn how this bot generates sentences here: https://github.com/CubieDev/TwitchMarkovChain#how-it-works",
|
|
)
|
|
except OSError as error:
|
|
logger.warning(
|
|
f"[OSError: {error}] upon sending help message. Ignoring.",
|
|
)
|
|
|
|
def _command_set_cooldown(self, username: str, split_message: list[str]):
|
|
if len(split_message) == 2: # noqa: PLR2004
|
|
try:
|
|
cooldown = int(split_message[1])
|
|
except ValueError:
|
|
self.ws.send_whisper(
|
|
username,
|
|
"The parameter must be an integer amount, eg: !setcd 30",
|
|
)
|
|
return
|
|
self.s.cooldown = cooldown
|
|
self.s.write()
|
|
self.ws.send_whisper(
|
|
username,
|
|
f"The !generate cooldown has been set to {cooldown} seconds.",
|
|
)
|
|
|
|
def _command_blacklist(self, username: str, split_message: list[str]):
|
|
if len(split_message) == 2: # noqa: PLR2004
|
|
try:
|
|
blacklisted_username = split_message[1]
|
|
except ValueError:
|
|
self.ws.send_whisper(
|
|
username,
|
|
"The parameter must be a username, eg: !blacklist ibai",
|
|
)
|
|
return
|
|
self.s.denied_users.append(blacklisted_username)
|
|
self.s.write()
|
|
|
|
def _command_generate(self, username: str, message: str):
|
|
cur_time = time.time()
|
|
if self.prev_message_t + self.s.cooldown >= cur_time:
|
|
if not self.db.check_whisper_ignore(username):
|
|
self.send_whisper(
|
|
username,
|
|
f"Cooldown hit: {self.prev_message_t + self.s.cooldown - cur_time:0.2f} out of {self.s.cooldown:.0f}s remaining. !nopm to stop these cooldown pm's.",
|
|
)
|
|
logger.info(
|
|
f"Cooldown hit with {self.prev_message_t + self.s.cooldown - cur_time:0.2f}s remaining.",
|
|
)
|
|
params = tokenize(message)[2:] if self.s.allow_generate_params else None
|
|
# Generate an actual sentence
|
|
sentence, success = self.generate(params)
|
|
if success:
|
|
# Reset cooldown if a message was actually generated
|
|
self.prev_message_t = time.time()
|
|
logger.info(sentence)
|
|
self.ws.send_message(sentence)
|
|
|
|
self.store_sentence(message)
|
|
|
|
def _command_automatic_generation(self) -> None:
|
|
"""Send an automatic generation message to the connected chat.
|
|
|
|
As long as the bot wasn't disabled, just like if someone typed "!g" in chat.
|
|
"""
|
|
if self._enabled:
|
|
logger.debug("Automatically generating message")
|
|
sentence, success = self.generate()
|
|
if success:
|
|
logger.info(
|
|
f"Created '{sentence}'. Cooling down for {self.s.automatic_generation_timer} seconds before regenerating",
|
|
)
|
|
try:
|
|
self.ws.send_message(sentence)
|
|
except OSError as error:
|
|
logger.warning(
|
|
f"[OSError: {error}] upon sending automatic generation message. Ignoring.",
|
|
)
|
|
else:
|
|
logger.info(
|
|
"Attempted to output automatic generation message, but there is not enough learned information yet.",
|
|
)
|
|
|
|
def _command_quote(self):
|
|
"""Retrieve a random quote from the `quotes` table and format it as
|
|
|
|
> «<quote>» - <author>
|
|
"""
|
|
data = self.db.execute(
|
|
"SELECT quote, author FROM quotes ORDER BY RANDOM() LIMIT 1;", fetch=True
|
|
)
|
|
if data:
|
|
data = data[0]
|
|
quote, author = data[0], data[1]
|
|
self.ws.send_message(f"«{quote}» - {author}")
|
|
|
|
def _command_add_quote(self, message: str):
|
|
"""Add a quote to the quotes table. The message should follow the format:
|
|
|
|
!qadd quote author
|
|
|
|
The last word will be parsed as the author and anything in between !qadd and the author will be considered
|
|
as the quote itself
|
|
"""
|
|
# Split the message into quote and author
|
|
parts = message.split()
|
|
author = parts[-1]
|
|
quote = " ".join(parts[1:-1])
|
|
|
|
data = self.db.execute(
|
|
"SELECT 1 FROM quotes WHERE quote = ?", (quote,), fetch=True
|
|
)
|
|
if data:
|
|
self.ws.send_message(f"Quote «{quote}» was already added.")
|
|
return
|
|
|
|
self.db.execute(
|
|
"INSERT INTO quotes (quote, author) VALUES (?, ?)",
|
|
(quote, author), # type: ignore[arg-type]
|
|
)
|
|
self.ws.send_message(f"Quote «{quote}» by {author} added.")
|
|
|
|
def store_sentence(self, message: str):
|
|
logger.info(f"Processing {message} in order to store it")
|
|
stripped_message = message.strip()
|
|
try:
|
|
sentences = sent_tokenize(stripped_message)
|
|
except LookupError:
|
|
logger.debug("Downloading required punkt resource...")
|
|
import nltk
|
|
|
|
nltk.download("punkt")
|
|
logger.debug("Downloaded required punkt resource.")
|
|
sentences = sent_tokenize(stripped_message)
|
|
|
|
for sentence in sentences:
|
|
words = tokenize(sentence)
|
|
# Double spaces will lead to invalid rules. We remove empty words here
|
|
if "" in words:
|
|
words = [word for word in words if word]
|
|
|
|
# If the sentence is too short, ignore it and move on to the next.
|
|
if len(words) <= self.s.key_length:
|
|
continue
|
|
|
|
# Add a new starting point for a sentence to the <START>
|
|
words = [words[x] for x in range(self.s.key_length)]
|
|
logger.debug(f"Adding {words} to start queue")
|
|
self.db.add_start_queue(words)
|
|
|
|
# Create Key variable which will be used as a key in the Dictionary for the grammar
|
|
key: list[str] = []
|
|
for word in words:
|
|
# Set up key for first use
|
|
if len(key) < self.s.key_length:
|
|
key.append(word)
|
|
continue
|
|
logger.debug(f"Adding {key}[{word}] to rule queue")
|
|
self.db.add_rule_queue([*key, word])
|
|
|
|
# Remove the first word, and add the current word,
|
|
# so that the key is correct for the next word.
|
|
key.pop(0)
|
|
key.append(word)
|
|
logger.debug(f"Adding {key} to rule queue")
|
|
# Add <END> at the end of the sentence
|
|
self.db.add_rule_queue([*key, self.end_tag])
|
|
|
|
def message_handler(self, message: Message): # noqa: C901, PLR0911, PLR0912
|
|
try:
|
|
if not message.user or message.user in self.s.denied_users:
|
|
logger.debug(f"User {message.user} can't send messages")
|
|
return
|
|
|
|
msgs = message.message.split()
|
|
if not msgs:
|
|
logger.debug("Message is empty")
|
|
return
|
|
|
|
if "bits" in message.tags:
|
|
return
|
|
|
|
if "emotes" in message.tags:
|
|
# Replace modified emotes with normal versions,
|
|
# as the bot will never have the modified emotes unlocked at the time.
|
|
for modifier in self.extract_modifiers(message.tags["emotes"]):
|
|
message.message = message.message.replace(modifier, "")
|
|
|
|
logger.debug(f"Received {msgs[0]} command from {message.user}")
|
|
match msgs[0]:
|
|
case Commands.GENERATE_HELP:
|
|
logger.debug("Executing _command_help()")
|
|
self._command_help()
|
|
|
|
case Commands.SET_COOLDOWN:
|
|
if self.is_mod(message.user, message.channel):
|
|
logger.debug(
|
|
f"User {message.user} is mod, executing _command_set_cooldown()",
|
|
)
|
|
self._command_set_cooldown(
|
|
split_message=msgs,
|
|
username=message.user,
|
|
)
|
|
|
|
case Commands.BLACKLIST:
|
|
if self.is_mod(message.user, message.channel):
|
|
logger.debug(
|
|
f"User {message.user} is a mod, executing _command_blacklist()",
|
|
)
|
|
self._command_blacklist(
|
|
split_message=msgs,
|
|
username=message.user,
|
|
)
|
|
|
|
case Commands.GENERATE:
|
|
if not self._enabled:
|
|
logger.info("Bot not enabled, skipping")
|
|
return
|
|
if message.user not in self.s.denied_users:
|
|
logger.info(
|
|
f"User {message.user} allowed to generate, executing _command_generate()",
|
|
)
|
|
self._command_generate(
|
|
message=message.message,
|
|
username=message.user,
|
|
)
|
|
|
|
case Commands.QUOTE:
|
|
if not self._enabled:
|
|
logger.info("Bot not enabled, skipping")
|
|
return
|
|
if message.user not in self.s.denied_users:
|
|
logger.info(
|
|
f"User {message.user} allowed to generate, executing _command_quote()",
|
|
)
|
|
self._command_quote()
|
|
|
|
case Commands.QUOTE_ADD:
|
|
if self.is_mod(message.user, message.channel):
|
|
logger.info(
|
|
f"User {message.user} allowed to create quote, executing _command_quote()",
|
|
)
|
|
self._command_add_quote(message.message)
|
|
return
|
|
self.ws.send_message(
|
|
f"@{message.user} you're not in the modlist, you can't add quotes"
|
|
)
|
|
|
|
case _:
|
|
logger.debug(
|
|
f"Not a command: {msgs[0]}. Storing into db as a plain message",
|
|
)
|
|
if message.type == "366":
|
|
logger.info(f"Successfully joined channel: #{message.channel}")
|
|
return
|
|
self.store_sentence(message.message)
|
|
|
|
except Exception: # noqa: BLE001
|
|
logger.exception(f"Could not process message {message}")
|
|
|
|
def generate(self, params: list[str] | None = None) -> tuple[str, bool]: # noqa: C901, PLR0912
|
|
"""Given an input sentence, generate the remainder of the sentence using the learned data.
|
|
|
|
Args:
|
|
params (list[str]): A list of words to use as an input to use as the start of generating.
|
|
|
|
Returns:
|
|
tuple[str, bool]: A tuple of a sentence as the first value, and a boolean indicating
|
|
whether the generation succeeded as the second value.
|
|
"""
|
|
params = params or []
|
|
|
|
# List of sentences that will be generated. In some cases, multiple sentences will be generated,
|
|
# e.g. when the first sentence has less words than self.min_sentence_length.
|
|
sentences: list[list | list[str]] = [[]]
|
|
|
|
# Check for commands or recursion, eg: !generate !generate
|
|
if len(params) > 0 and self.is_command(params[0]):
|
|
return "You can't make me do commands, you madman!", False
|
|
|
|
# Get the starting key and starting sentence.
|
|
# If there is more than 1 param, get the last 2 as the key.
|
|
# Note that self.s.key_length is fixed to 2 in this implementation
|
|
if len(params) > 1:
|
|
key = params[-self.s.key_length :]
|
|
# Copy the entire params for the sentence
|
|
sentences[0] = params.copy()
|
|
|
|
elif len(params) == 1:
|
|
# First we try to find if this word was once used as the first word in a sentence:
|
|
key = self.db.get_next_single_start(params[0]) # type: ignore[assignment]
|
|
if key is None:
|
|
# If this failed, we try to find the next word in the grammar as a whole
|
|
key = self.db.get_next_single_initial(0, params[0])
|
|
if key is None:
|
|
# Return a message that this word hasn't been learned yet
|
|
return f'I haven\'t extracted "{params[0]}" from chat yet.', False
|
|
# Copy this for the sentence
|
|
sentences[0] = key.copy()
|
|
|
|
else: # if there are no params
|
|
# Get starting key
|
|
key = self.db.get_start()
|
|
if key:
|
|
# Copy this for the sentence
|
|
sentences[0] = key.copy()
|
|
else:
|
|
# If nothing's ever been said
|
|
return "There is not enough learned information yet.", False
|
|
|
|
# Counter to prevent infinite loops (i.e. constantly generating <END> while below the
|
|
# minimum number of words to generate)
|
|
i = 0
|
|
while (
|
|
self.get_sentence_length(sentences) < self.s.max_sentence_length
|
|
and i < self.s.max_sentence_length * 2
|
|
):
|
|
# Use key to get next word
|
|
if i == 0:
|
|
# Prevent fetching <END> on the first word
|
|
word = self.db.get_next_initial(i, key)
|
|
else:
|
|
word = self.db.get_next(i, key)
|
|
|
|
i += 1
|
|
|
|
if word == "<END>" or word is None:
|
|
# Break, unless we are before the min_sentence_length
|
|
if i < self.s.min_sentence_length:
|
|
key = self.db.get_start()
|
|
# Ensure that the key can be generated. Otherwise, we still stop.
|
|
if key:
|
|
# Start a new sentence
|
|
sentences.append([])
|
|
for entry in key:
|
|
sentences[-1].append(entry)
|
|
continue
|
|
break
|
|
|
|
# Otherwise add the word
|
|
sentences[-1].append(word)
|
|
|
|
# Shift the key so on the next iteration it gets the next item
|
|
key.pop(0)
|
|
key.append(word)
|
|
|
|
# If there were params, but the sentence resulting is identical to the params
|
|
# Then the params did not result in an actual sentence
|
|
# If so, restart without params
|
|
if len(params) > 0 and params == sentences[0]:
|
|
return "I haven't learned what to do with \"" + detokenize(
|
|
params[-self.s.key_length :],
|
|
) + '" yet.', False
|
|
|
|
return self.s.sentence_separator.join(
|
|
detokenize(sentence) for sentence in sentences
|
|
), True
|
|
|
|
@staticmethod
|
|
def get_sentence_length(sentences: list[list[str]]) -> int:
|
|
"""Given a list of tokens representing a sentence, return the number of words in there.
|
|
|
|
Args:
|
|
sentences (List[List[str]]): List of lists of tokens that make up a sentence,
|
|
where a token is a word or punctuation. For example:
|
|
[['Hello', ',', 'you', "'re", 'Tom', '!'], ['Yes', ',', 'I', 'am', '.']]
|
|
This would return 6.
|
|
|
|
Returns:
|
|
int: The number of words in the sentence.
|
|
"""
|
|
count = 0
|
|
for sentence in sentences:
|
|
for token in sentence:
|
|
if token not in string.punctuation and token[0] != "'":
|
|
count += 1
|
|
return count
|
|
|
|
@staticmethod
|
|
def extract_modifiers(emotes: str) -> list[str]:
|
|
"""Extract emote modifiers from emotes such as the horizontal flip.
|
|
|
|
Args:
|
|
emotes (str): String containing all emotes used in the message.
|
|
|
|
Returns:
|
|
list[str]: List of strings that show modifiers, such as "_HZ" for horizontal flip.
|
|
"""
|
|
output = []
|
|
try:
|
|
while emotes:
|
|
u_index = emotes.index("_")
|
|
c_index = emotes.index(":", u_index)
|
|
output.append(emotes[u_index:c_index])
|
|
emotes = emotes[c_index:]
|
|
except ValueError:
|
|
pass
|
|
return output
|
|
|
|
def send_whisper(self, user: str, message: str) -> None:
|
|
"""Optionally send a whisper, only if "WhisperCooldown" is True.
|
|
|
|
Args:
|
|
user (str): The user to potentially whisper.
|
|
message (str): The message to potentially whisper
|
|
"""
|
|
if self.s.whisper_cooldown:
|
|
self.ws.send_whisper(user, message)
|
|
|
|
@staticmethod
|
|
def is_command(message: str) -> bool:
|
|
"""True if the message is any command, except /me.
|
|
|
|
Is used to avoid learning and generating commands.
|
|
|
|
Args:
|
|
message (str): The message to check.
|
|
|
|
Returns:
|
|
bool: True if the message is any potential command (starts with a '!', '/' or '.')
|
|
except /me.
|
|
"""
|
|
return message in list(Commands)
|
|
|
|
def is_mod(self, username: str, channel: str) -> bool:
|
|
"""True if the user is a moderator.
|
|
|
|
Args:
|
|
username (str): The name of the user to check
|
|
channel (str): The name of the channel
|
|
|
|
Returns:
|
|
bool: True if the user is a moderator.
|
|
"""
|
|
return username in self.s.mods or username == channel
|
|
|
|
|
|
if __name__ == "__main__":
    # The constructor only builds the websocket client; run_bot() is what
    # calls start_bot() on it, so without it the script would never connect.
    MarkovChain().run_bot()
|