Files
notes/services/rag_service.py
2026-05-23 14:01:35 +02:00

173 lines
5.8 KiB
Python

"""
RAG service using ChromaDB + LM Studio's /v1/embeddings endpoint.
No local ML libraries (torch, sentence-transformers, onnxruntime) needed —
embeddings are generated by the same LM Studio instance used for chat.
Each chunk is stored with metadata: user_id, source_id, source_type (doc|url).
"""
import re
from typing import Optional
from chromadb import EmbeddingFunction, Documents, Embeddings
from flask import current_app
from openai import OpenAI
# Module-level cache populated lazily by _get_collection(): the Chroma client and
# the "ki_context" collection opened through it. Both start out unset (None).
_chroma_client = _collection = None
class LMStudioEmbeddingFunction(EmbeddingFunction):
    """ChromaDB-compatible embedding function that calls LM Studio's /v1/embeddings.

    Requests go through the OpenAI client pointed at ``<base_url>/v1``, so any
    OpenAI-compatible embeddings endpoint works.

    Args:
        base_url: Root URL of the LM Studio server; a trailing "/" is tolerated
            and "/v1" is appended.
        model: Embedding model name passed on every request.
    """

    def __init__(self, base_url: str, model: str):
        self._client = OpenAI(
            base_url=f"{base_url.rstrip('/')}/v1",
            # The OpenAI client needs some key; presumably LM Studio does not check it.
            api_key="lm-studio",
            timeout=60.0,
        )
        self._model = model

    def __call__(self, input: Documents) -> Embeddings:
        """Embed a batch of documents and return one vector per input, in input order."""
        # NOTE: the parameter is presumably named `input` to match ChromaDB's
        # EmbeddingFunction.__call__ signature — keep it.
        response = self._client.embeddings.create(model=self._model, input=input)
        # Each returned item records which input it embeds (`index`); order by it
        # rather than trusting the server to answer in request order. If a server
        # omits the field, fall back to the order the items arrived in.
        ordered = sorted(
            enumerate(response.data),
            key=lambda pair: getattr(pair[1], "index", pair[0]),
        )
        return [item.embedding for _, item in ordered]
def _get_collection():
    """Return the shared ChromaDB collection, creating it on the first call.

    On first use this reads VECTORDB_PATH, LM_STUDIO_URL and
    LM_STUDIO_EMBEDDING_MODEL from the active Flask app's config, opens a
    persistent Chroma client at VECTORDB_PATH, and gets-or-creates the
    "ki_context" collection (HNSW cosine space) with an LM Studio embedding
    function. The client and collection are cached in module globals; later
    calls return the cached collection without touching the Flask config.

    NOTE(review): there is no lock, so concurrent first calls could each build a
    client — confirm this is acceptable for the deployment's threading model.
    """
    global _chroma_client, _collection
    if _collection is not None:
        return _collection

    # chromadb is already imported at module level (for EmbeddingFunction);
    # this local import just binds the package name here.
    import chromadb

    config = current_app.config
    db_path = config["VECTORDB_PATH"]
    lm_studio_url = config["LM_STUDIO_URL"]
    embedding_model = config["LM_STUDIO_EMBEDDING_MODEL"]

    _chroma_client = chromadb.PersistentClient(path=db_path)
    embedder = LMStudioEmbeddingFunction(lm_studio_url, embedding_model)
    _collection = _chroma_client.get_or_create_collection(
        name="ki_context",
        embedding_function=embedder,
        metadata={"hnsw:space": "cosine"},
    )
    return _collection
def chunk_text(
    text: str, chunk_size: int, overlap: int, min_words: int = 10
) -> list[str]:
    """
    Split *text* into overlapping, sentence-aligned chunks of at most *chunk_size* words.

    1. Split on blank lines into paragraphs.
    2. Split each paragraph into sentences at ".", "!" or "?" followed by
       whitespace and then an ASCII uppercase letter, a character in
       U+00C0-U+00FF (accented Latin letters, lowercase included) or a '"'.
    3. Greedily pack whole sentences into the current chunk; a sentence that
       would overflow it starts a new chunk. Words are re-joined with single
       spaces, so paragraph breaks are not preserved and a chunk may span
       paragraphs.
    4. A sentence that cannot fit (longer than *chunk_size*, or too long to sit
       next to the carried-over overlap) is split by word, so no chunk ever
       exceeds *chunk_size* words.

    Every chunk after the first begins with the last *overlap* words of the
    previous chunk. A chunk consisting only of carried-over overlap is never
    emitted, so no chunk merely duplicates the tail of its predecessor.

    Args:
        text: Input text.
        chunk_size: Maximum words per chunk; values below 1 are treated as 1.
        overlap: Words repeated from the end of one chunk at the start of the
            next; clamped to [0, chunk_size - 1] so every chunk adds new words.
        min_words: Chunks with fewer words than this are dropped (default 10).

    Returns:
        The chunks in document order.
    """
    chunk_size = max(1, chunk_size)
    overlap = max(0, min(overlap, chunk_size - 1))

    # Paragraphs: separated by one or more blank lines.
    paragraphs = [p.strip() for p in re.split(r'\n{2,}', text) if p.strip()]

    sentences: list[str] = []
    for para in paragraphs:
        # Split on ". ", "! ", "? " followed by an uppercase-ish letter or a quote.
        parts = re.split(r'(?<=[.!?])\s+(?=[A-ZÜÖÄ\u00C0-\u00FF"])', para)
        sentences.extend(s.strip() for s in parts if s.strip())

    chunks: list[str] = []
    current: list[str] = []  # words of the chunk being built
    fresh = 0  # words in `current` that have not yet been emitted in any chunk

    for sentence in sentences:
        words = sentence.split()
        # Close the chunk if this sentence would overflow it — but only when it
        # holds new words; otherwise we would emit an overlap-only duplicate.
        if fresh and len(current) + len(words) > chunk_size:
            chunks.append(" ".join(current))
            current = current[-overlap:] if overlap else []
            fresh = 0
        current.extend(words)
        fresh += len(words)
        # Emit every full chunk. This also splits an over-long sentence by word;
        # each step advances by chunk_size - overlap >= 1 words.
        while len(current) >= chunk_size:
            chunks.append(" ".join(current[:chunk_size]))
            current = current[chunk_size - overlap:]
            # The first `overlap` words of `current` were just emitted.
            fresh = len(current) - overlap

    if fresh:
        chunks.append(" ".join(current))
    return [c for c in chunks if len(c.split()) >= min_words]  # drop near-empty chunks
def index_source(
    text: str,
    user_id: int,
    source_id: int,
    source_type: str,  # "doc" | "url"
    chunk_size: int = 500,
    chunk_overlap: int = 50,
):
    """Chunk, embed via LM Studio and store in ChromaDB. Replaces existing chunks.

    All chunks previously stored for (user_id, source_id, source_type) are
    deleted first; if *text* yields no chunks, the source ends up with nothing
    indexed. Embedding happens inside the ChromaDB write, via the collection's
    LM Studio embedding function (no embeddings are passed explicitly).

    Args:
        text: Full text of the source.
        user_id: Owner; stored as a string in each chunk's metadata.
        source_id: Source identifier; stored as a string and used in chunk IDs.
        source_type: Source kind ("doc" or "url"); stored in metadata and IDs.
        chunk_size: Maximum words per chunk (passed to chunk_text).
        chunk_overlap: Words shared between consecutive chunks (passed to chunk_text).
    """
    collection = _get_collection()
    # NOTE(review): old chunks are removed before the new ones are embedded, so if
    # the embedding call fails the source stays unindexed until it is re-indexed.
    delete_source(user_id, source_id, source_type)
    chunks = chunk_text(text, chunk_size, chunk_overlap)
    if not chunks:
        return
    # Deterministic IDs: re-indexing a source reuses the same IDs.
    ids = [f"{source_type}_{source_id}_chunk_{i}" for i in range(len(chunks))]
    base_metadata = {
        "user_id": str(user_id),
        "source_id": str(source_id),
        "source_type": source_type,
    }
    metadatas = [dict(base_metadata) for _ in chunks]
    # upsert rather than add: delete_source swallows its errors, and Chroma's add()
    # does not overwrite records whose IDs already exist, so after a failed delete
    # add() would leave the stale text in place under these IDs.
    collection.upsert(documents=chunks, ids=ids, metadatas=metadatas)
def delete_source(user_id: int, source_id: int, source_type: str):
    """Remove all chunks belonging to a source from ChromaDB.

    Matches on the three metadata fields written by index_source (IDs are
    compared as strings, as they are stored). Best-effort: if the delete call
    fails, the error is logged and swallowed so callers are not interrupted.
    """
    collection = _get_collection()
    try:
        collection.delete(
            where={
                "$and": [
                    {"user_id": {"$eq": str(user_id)}},
                    {"source_id": {"$eq": str(source_id)}},
                    {"source_type": {"$eq": source_type}},
                ]
            }
        )
    except Exception:
        # Keep the best-effort contract, but don't hide the failure: a delete that
        # silently fails leaves stale chunks in the index.
        import logging

        logging.getLogger(__name__).exception(
            "Failed to delete chunks for %s source %s (user %s)",
            source_type,
            source_id,
            user_id,
        )
def similarity_search(
    query: str,
    user_id: int,
    source_ids: Optional[list[int]] = None,
    source_type: Optional[str] = None,
    top_k: int = 5,
    max_distance: float = 0.6,
) -> list[str]:
    """Search for relevant chunks via LM Studio embeddings.

    The query text is embedded by the collection's LM Studio embedding function
    and matched against the user's stored chunks.

    Args:
        query: Free-text query.
        user_id: Only chunks whose metadata user_id equals this are searched.
        source_ids: Restrict to these sources; None or an empty list applies no
            source restriction.
        source_type: Restrict to this source type when given.
        top_k: Number of results requested from ChromaDB.
        max_distance: Results with cosine distance >= this are dropped
            (default 0.6).

    Returns:
        The matching chunk texts, or an empty list if the search fails (the
        error is logged).
    """
    collection = _get_collection()
    conditions = [{"user_id": {"$eq": str(user_id)}}]
    # NOTE(review): an empty source_ids list is treated like None, i.e. it searches
    # all of the user's sources — confirm callers never mean "no sources selected".
    if source_ids:
        conditions.append(
            {"source_id": {"$in": [str(sid) for sid in source_ids]}}
        )
    if source_type:
        conditions.append({"source_type": {"$eq": source_type}})
    # A single clause is passed bare; "$and" is only used to combine several
    # (presumably because Chroma expects at least two clauses under "$and").
    where = {"$and": conditions} if len(conditions) > 1 else conditions[0]
    try:
        results = collection.query(
            query_texts=[query],
            n_results=top_k,
            where=where,
            include=["documents", "distances"],
        )
        docs = results["documents"][0] if results["documents"] else []
        distances = results["distances"][0] if results.get("distances") else []
        # Cosine distance: 0 = identical, 2 = opposite. Filter out poor matches.
        if distances:
            docs = [d for d, dist in zip(docs, distances) if dist < max_distance]
        return docs
    except Exception:
        # Best-effort: a failed search (e.g. LM Studio unreachable) yields no
        # context rather than an error, but it is logged instead of vanishing.
        import logging

        logging.getLogger(__name__).exception(
            "Similarity search failed for user %s", user_id
        )
        return []