refactor: clean the search command a bit
This commit is contained in:
parent
eeb1573f99
commit
701d79583d
4 changed files with 93 additions and 141 deletions
|
|
@ -1,4 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
|
||||
from halig.settings import Settings
|
||||
|
||||
|
|
@ -12,3 +13,6 @@ class ICommand(ABC):
|
|||
class BaseCommand(ICommand):
|
||||
def __init__(self, settings: Settings, *args, **kwargs):
|
||||
self.settings = settings
|
||||
|
||||
def traverse_notebooks(self, callback_on_item: Callable):
|
||||
"""Traverse root_path"""
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
|
||||
import platformdirs
|
||||
|
|
@ -14,156 +12,105 @@ from halig.settings import Settings
|
|||
|
||||
|
||||
class SearchCommand(BaseCommand):
|
||||
"""Full text search against a SQLite located at $HOME/.cache/halig.db
|
||||
|
||||
The database schema is pretty simple and it uses SQLite's FTS5 for
|
||||
the full text search capabilities:
|
||||
|
||||
CREATE VIRTUAL TABLE notes USING fts5(name, last_timestamp, hash, filepath, body);
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
search_term: str,
|
||||
settings: Settings,
|
||||
should_index: bool = False,
|
||||
):
|
||||
self.search_term = search_term
|
||||
self.settings = settings
|
||||
self.should_index = should_index
|
||||
self.encryptor = Encryptor(self.settings)
|
||||
def __init__(self, term: str, index: bool, settings: Settings, *args, **kwargs):
    """Set up a search run.

    Stores the query term and the reindex flag, builds the Encryptor,
    and opens the SQLite cache database under the user cache directory.
    """
    super().__init__(settings, *args, **kwargs)
    self.term = term
    self.index = index
    self.encryptor = Encryptor(settings)
    # `ensure_exists=True` creates the cache directory on first use.
    self.cache_path = platformdirs.user_cache_path("halig", ensure_exists=True)
    self.db_path = self.cache_path / "halig.db"
    self.db_conn = sqlite3.connect(self.db_path)
|
||||
|
||||
def _create_schema(self):
|
||||
"""Create or repair the database schema"""
|
||||
db_path = self._get_database_path()
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create or repair the schema
|
||||
cursor.execute(
|
||||
with self.db_conn:
|
||||
self.db_conn.execute(
|
||||
"""CREATE VIRTUAL TABLE IF NOT EXISTS notes
|
||||
USING fts5(last_timestamp, hash, filepath, body);
|
||||
""",
|
||||
USING fts5(name, last_timestamp, hash, filepath, body);""",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def _check_index_status(self):
|
||||
"""Check the db's notes indexing status using the hash and the timestamp"""
|
||||
db_path = self._get_database_path()
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Query the database to check if it's already indexed
|
||||
cursor.execute("SELECT COUNT(*) FROM note;")
|
||||
count = cursor.fetchone()[0]
|
||||
|
||||
conn.close()
|
||||
|
||||
return count > 0
|
||||
|
||||
def _do_index(self):
|
||||
"""Index the notes, either partially or fully"""
|
||||
db_path = self._get_database_path()
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Delete existing records before re-indexing
|
||||
cursor.execute("DELETE FROM note;")
|
||||
|
||||
# Traverse the notebook directory and index the notes
|
||||
for path in self._get_notebook_files():
|
||||
encrypted_data = self._read_encrypted_file(path)
|
||||
decrypted_data = self.encryptor.decrypt(encrypted_data)
|
||||
|
||||
# Calculate the hash of the decrypted data
|
||||
hash_value = self._calculate_hash(decrypted_data)
|
||||
|
||||
# Insert the indexed data into the database
|
||||
cursor.execute(
|
||||
"""INSERT INTO notes (last_timestamp, hash, filepath, body)
|
||||
VALUES (?, ?, ?, ?);""",
|
||||
(os.path.getmtime(path), hash_value, str(path), decrypted_data),
|
||||
def _search_note_in_db_by_path(self, path: Path) -> tuple[str | None, str | None]:
|
||||
with self.db_conn:
|
||||
cursor = self.db_conn.execute(
|
||||
"SELECT hash, last_timestamp FROM notes where filepath = ?",
|
||||
(str(path),),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def run(self):
|
||||
"""`halig search` entrypoint, which does a few checks before running
|
||||
the query.
|
||||
|
||||
1. Check if the notes are indexed
|
||||
2. If there are notes to be indexed or the database does not exist
|
||||
or it has an incorrect schema, the user is prompted to allow
|
||||
the program to reindex
|
||||
3. After we're sure the database is in a correct state, we perform the
|
||||
query
|
||||
4. We print the results as if it were `grep -rin` output
|
||||
"""
|
||||
self._create_schema()
|
||||
# Check if indexing is required or if the database is in an incorrect state
|
||||
index_status = self._check_index_status()
|
||||
if self.should_index or not index_status:
|
||||
self._do_index()
|
||||
|
||||
# Perform the search query
|
||||
self._perform_search()
|
||||
|
||||
def _perform_search(self):
|
||||
"""Perform the search query and print the results
|
||||
with highlighted search term
|
||||
"""
|
||||
db_path = self._get_database_path()
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Execute the search query
|
||||
cursor.execute(
|
||||
"SELECT filepath, body FROM note WHERE body MATCH ?;",
|
||||
(self.search_term,),
|
||||
)
|
||||
|
||||
# Fetch and print the results with highlighted search term
|
||||
console = Console()
|
||||
search_regex = re.compile(re.escape(self.search_term), re.IGNORECASE)
|
||||
|
||||
results = cursor.fetchall()
|
||||
if not results:
|
||||
return None, None
|
||||
return results[0] # type: ignore[no-any-return]
|
||||
|
||||
def _index_note(
|
||||
self,
|
||||
updated_at: float,
|
||||
body_hash: str,
|
||||
note_path: Path,
|
||||
body: str,
|
||||
):
|
||||
with self.db_conn:
|
||||
self.db_conn.execute(
|
||||
""""INSERT INTO notes (name, last_timestamp, hash, filepath, body)
|
||||
VALUES (?, ?, ?, ?, ?);""",
|
||||
(note_path.name, updated_at, body_hash, str(note_path), body),
|
||||
)
|
||||
|
||||
def _update_index_note(
|
||||
self,
|
||||
updated_at: float,
|
||||
body_hash: str,
|
||||
note_path: Path,
|
||||
body: str,
|
||||
):
|
||||
with self.db_conn:
|
||||
self.db_conn.execute(
|
||||
"""UPDATE notes SET
|
||||
last_timestamp = (?),
|
||||
hash = (?),
|
||||
body = (?)
|
||||
WHERE
|
||||
filepath = (?);
|
||||
""",
|
||||
(updated_at, body_hash, body, str(note_path)),
|
||||
)
|
||||
|
||||
def _index_notebooks(self):
|
||||
for note_path in self.settings.notebooks_root_path.glob("./**/*.age"):
|
||||
updated_at = note_path.stat().st_mtime
|
||||
with note_path.open("rb") as f:
|
||||
body = self.encryptor.decrypt(f.read())
|
||||
body_hash = hashlib.sha512(body).hexdigest()
|
||||
original_hash, last_timestamp = self._search_note_in_db_by_path(note_path)
|
||||
if not original_hash:
|
||||
self._index_note(updated_at, body_hash, note_path, body.decode())
|
||||
continue
|
||||
|
||||
if hash != original_hash:
|
||||
self._update_index_note(updated_at, body_hash, note_path, body.decode())
|
||||
|
||||
def _search(self):
    """Run the FTS query and print grep-style `path:lineno: line` output.

    The term is issued as a prefix query (`term*`) ranked by FTS5's
    `rank`, and each matching line is printed with the term highlighted
    in bold red via rich markup.
    """
    with self.db_conn:
        cursor = self.db_conn.execute(
            "SELECT filepath, body FROM notes WHERE body MATCH ? ORDER BY rank;",
            (f"{self.term}*",),
        )
    results = cursor.fetchall()
    console = Console()
    search_regex = re.compile(re.escape(self.term), re.IGNORECASE)
    for filepath, body in results:
        # body is stored as TEXT (see the insert path), so no .decode()
        # is needed here; split it into lines to report line numbers.
        for lineno, line in enumerate(body.split("\n"), start=1):
            if search_regex.search(line):
                content_line = search_regex.sub("[bold red]\\g<0>[/bold red]", line)
                console.print(f"{filepath}:{lineno}: {content_line}")
|
||||
|
||||
def _get_database_path(self) -> Path:
|
||||
"""Get the path to the SQLite database"""
|
||||
cache_dir = platformdirs.user_cache_path("halig", ensure_exists=True)
|
||||
db_path = cache_dir / "halig.db"
|
||||
db_path.touch()
|
||||
return db_path
|
||||
|
||||
def _get_notebook_files(self) -> Generator[Path, None, None]:
|
||||
"""Get the list of notebook files to index"""
|
||||
return self.settings.notebooks_root_path.glob("**/*.age")
|
||||
|
||||
def _read_encrypted_file(self, file_path: Path) -> bytes:
|
||||
"""Read the encrypted contents of a file"""
|
||||
with file_path.open("rb") as file:
|
||||
return file.read()
|
||||
|
||||
def _calculate_hash(self, data: bytes) -> str:
|
||||
"""Calculate the hash of the data"""
|
||||
# Use an appropriate hash algorithm, e.g., hashlib.sha256()
|
||||
# Adjust the hashing algorithm based on your requirements
|
||||
hash_object = hashlib.sha256(data)
|
||||
return hash_object.hexdigest()
|
||||
def run(self):
    """`halig search` entrypoint.

    Ensures the schema exists, reindexes when requested, runs the
    query, then closes the database connection.
    """
    self._create_schema()
    if self.index:
        self._index_notebooks()
    self._search()
    self.db_conn.close()
|
||||
|
|
|
|||
|
|
@ -117,12 +117,13 @@ def import_unencrypted(
|
|||
|
||||
@app.command()
def search(
    term: str,
    index: bool = False,
):
    """CLI entrypoint: full-text search `term` across indexed notes.

    Args:
        term: search term, passed to the FTS query.
        index: when True, (re)index the notebooks before searching.
    """
    settings = load_from_file()
    command = SearchCommand(
        term=term,
        index=index,
        settings=settings,
    )
    command.run()
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ def settings_file_path(halig_path: Path, notebooks_path: Path) -> Path:
|
|||
yaml_file.touch()
|
||||
s = Settings(notebooks_root_path=notebooks_path)
|
||||
# `.dict()` doesn't serialize some fields that yaml doesn't understand
|
||||
serialized = json.loads(s.json())
|
||||
serialized = json.loads(s.model_dump_json())
|
||||
with yaml_file.open("w") as f:
|
||||
yaml.safe_dump(serialized, f)
|
||||
return yaml_file
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue