260 lines
9.2 KiB
Python
260 lines
9.2 KiB
Python
"""
RAG service using ChromaDB + LM Studio's /v1/embeddings endpoint.

No local ML libraries (torch, sentence-transformers, onnxruntime) are needed —
embeddings are generated by the same LM Studio instance used for chat.

Each chunk is stored with metadata: user_id, source_id, source_type (doc|url).
"""
|
|
|
|
import re
|
|
from typing import Optional
|
|
|
|
from chromadb import EmbeddingFunction, Documents, Embeddings
|
|
from flask import current_app
|
|
from openai import OpenAI
|
|
|
|
_chroma_client = None
|
|
_collection = None
|
|
|
|
|
|
class LMStudioEmbeddingFunction(EmbeddingFunction):
    """ChromaDB embedding function that calls LM Studio's /v1/embeddings.

    LM Studio exposes an OpenAI-compatible API, so the ``openai`` client is
    used with ``base_url`` pointed at the LM Studio server. Texts are sent in
    batches of at most ``batch_size``. This keeps each HTTP request (and its
    60 s timeout) bounded even when a large document yields many chunks.
    """

    def __init__(self, base_url: str, model: str, batch_size: int = 64):
        """Create the OpenAI-compatible client.

        Args:
            base_url: LM Studio server root (e.g. ``http://localhost:1234``).
                A trailing slash is tolerated, and ``/v1`` is appended.
            model: Embedding model identifier as known to LM Studio.
            batch_size: Maximum number of texts per embeddings request.
                Values below 1 are treated as 1.
        """
        self._client = OpenAI(
            base_url=f"{base_url.rstrip('/')}/v1",
            # Placeholder key; presumably ignored by LM Studio — confirm.
            api_key="lm-studio",
            timeout=60.0,
        )
        self._model = model
        self._batch_size = max(1, int(batch_size))

    def __call__(self, input: Documents) -> Embeddings:
        """Return one embedding per text in ``input``, in the same order.

        The parameter is named ``input`` to match ChromaDB's EmbeddingFunction
        protocol, even though it shadows the builtin. An empty input returns
        ``[]`` without contacting the server.
        """
        texts = list(input)
        vectors = []
        for start in range(0, len(texts), self._batch_size):
            batch = texts[start:start + self._batch_size]
            response = self._client.embeddings.create(model=self._model, input=batch)
            # Align vectors with their texts via each item's ``index`` rather
            # than trusting response order. If ``index`` is absent, every key
            # is 0 and the stable sort keeps server order.
            items = sorted(response.data, key=lambda item: getattr(item, "index", 0))
            vectors.extend(item.embedding for item in items)
        return vectors
|
|
|
|
|
|
def _get_collection():
    """Return the shared ChromaDB collection, creating it on first use.

    On the first call this reads VECTORDB_PATH, LM_STUDIO_URL and
    LM_STUDIO_EMBEDDING_MODEL from the active Flask app config. It then opens
    a PersistentClient at that path and gets or creates the "ki_context"
    collection (cosine space) with an LM Studio embedding function.

    Both objects are cached in module globals, so later calls return the
    cached collection and do not re-read the config.
    """
    global _chroma_client, _collection
    if _collection is not None:
        return _collection

    import chromadb

    cfg = current_app.config
    db_path = cfg["VECTORDB_PATH"]
    lm_url = cfg["LM_STUDIO_URL"]
    embed_model = cfg["LM_STUDIO_EMBEDDING_MODEL"]

    _chroma_client = chromadb.PersistentClient(path=db_path)
    _collection = _chroma_client.get_or_create_collection(
        name="ki_context",
        embedding_function=LMStudioEmbeddingFunction(lm_url, embed_model),
        metadata={"hnsw:space": "cosine"},
    )
    return _collection
|
|
|
|
|
|
def chunk_text(text: str, chunk_size: int, overlap: int, min_words: int = 10) -> list[str]:
    """Split ``text`` into overlapping chunks of at most ``chunk_size`` words.

    Paragraph- and sentence-aware:
      1. Split on blank lines (empty or whitespace-only, LF or CRLF) into
         paragraphs.
      2. Split each paragraph into sentences at ".", "!" or "?" followed by
         whitespace and then an uppercase ASCII letter, a character in
         U+00C0-U+00FF, or a double quote.
      3. Pack whole sentences greedily into chunks. A sentence that does not
         fit is split at word boundaries, so no chunk exceeds ``chunk_size``.

    Each new chunk starts with the last ``overlap`` words of the previous one.
    A chunk is only emitted if it holds at least one word beyond that
    carried-over overlap, so no chunk is pure duplication.

    Args:
        text: Raw text to chunk.
        chunk_size: Maximum number of words per chunk; must be >= 1.
        overlap: Number of words repeated from the previous chunk. Clamped to
            [0, chunk_size - 1] so every chunk makes progress.
        min_words: Chunks with fewer words than this are dropped.

    Returns:
        Chunk strings, words joined by single spaces.

    Raises:
        ValueError: If ``chunk_size`` < 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    overlap = max(0, min(overlap, chunk_size - 1))

    # Paragraphs: separated by a line that is empty or whitespace-only.
    paragraphs = [p.strip() for p in re.split(r'\n\s*\n', text) if p.strip()]

    sentences: list[str] = []
    for para in paragraphs:
        # Split on ". ", "! ", "? " followed by an uppercase/Latin-1 letter or a quote
        parts = re.split(r'(?<=[.!?])\s+(?=[A-ZÜÖÄ\u00C0-\u00FF"])', para)
        sentences.extend(s.strip() for s in parts if s.strip())

    chunks: list[str] = []
    current: list[str] = []
    fresh = 0  # words in `current` that were not carried over as overlap

    def flush() -> None:
        # Emit the buffer and seed the next chunk with the overlap tail.
        nonlocal current, fresh
        chunks.append(" ".join(current))
        current = current[-overlap:] if overlap else []
        fresh = 0

    for sentence in sentences:
        words = sentence.split()
        # Start a new chunk rather than break this sentence, but only if the
        # buffer holds new content (never flush an overlap-only buffer).
        if fresh and len(current) + len(words) > chunk_size:
            flush()
        # Adding word by word splits over-long sentences at word boundaries.
        for word in words:
            current.append(word)
            fresh += 1
            if len(current) >= chunk_size:
                flush()

    if fresh:
        chunks.append(" ".join(current))

    return [c for c in chunks if len(c.split()) >= min_words]  # drop near-empty chunks
|
|
|
|
|
|
def index_source(
    text: str,
    user_id: int,
    source_id: int,
    source_type: str,  # "doc" | "url"
    chunk_size: int = 500,
    chunk_overlap: int = 50,
):
    """Chunk ``text``, embed it via LM Studio and store it in ChromaDB.

    Replaces any chunks previously stored for the same (user_id, source_id,
    source_type). New chunks are written first (embedding happens during that
    write) and leftover chunks are pruned afterwards. So if embedding fails,
    the source's previous chunks stay in place instead of being deleted up
    front.

    Chunk IDs are ``{source_type}_{source_id}_chunk_{i}``. The metadata values
    user_id and source_id are stored as strings. If ``text`` yields no chunks,
    the source's existing chunks are simply deleted.
    """
    chunks = chunk_text(text, chunk_size, chunk_overlap)
    if not chunks:
        delete_source(user_id, source_id, source_type)
        return

    collection = _get_collection()

    ids = [f"{source_type}_{source_id}_chunk_{i}" for i in range(len(chunks))]
    metadatas = [
        {"user_id": str(user_id), "source_id": str(source_id), "source_type": source_type}
        for _ in chunks
    ]

    # upsert rather than add: add may skip IDs that already exist, which would
    # leave old text under the same ID.
    collection.upsert(documents=chunks, ids=ids, metadatas=metadatas)

    # Prune chunks from an earlier, longer version of this source.
    where = {
        "$and": [
            {"user_id": {"$eq": str(user_id)}},
            {"source_id": {"$eq": str(source_id)}},
            {"source_type": {"$eq": source_type}},
        ]
    }
    try:
        existing = collection.get(where=where, include=[])
        keep = set(ids)
        stale = [cid for cid in (existing.get("ids") or []) if cid not in keep]
        if stale:
            collection.delete(ids=stale)
    except Exception:
        # The new content is already stored; failing to prune is logged, not fatal.
        current_app.logger.warning(
            "Failed to prune stale chunks for %s source %s (user %s)",
            source_type, source_id, user_id, exc_info=True,
        )
|
|
|
|
|
|
def delete_source(user_id: int, source_id: int, source_type: str):
    """Remove all chunks belonging to one source from ChromaDB.

    A chunk matches when its ``user_id``, ``source_id`` and ``source_type``
    metadata all equal the given values. The IDs are compared as strings.

    Best effort: an error raised by the delete call is logged as a warning on
    the Flask app logger and not propagated. Errors from _get_collection()
    still propagate.
    """
    collection = _get_collection()
    where = {
        "$and": [
            {"user_id": {"$eq": str(user_id)}},
            {"source_id": {"$eq": str(source_id)}},
            {"source_type": {"$eq": source_type}},
        ]
    }
    try:
        collection.delete(where=where)
    except Exception:
        # Stay best effort, but log it: a silent failure would leave the
        # source's chunks in the collection with no trace of why.
        current_app.logger.warning(
            "Failed to delete chunks for %s source %s (user %s)",
            source_type, source_id, user_id, exc_info=True,
        )
|
|
|
|
|
|
def _expand_query(query: str) -> list[str]:
|
|
"""Ask the LLM to generate 2 alternative search queries for better recall.
|
|
Falls back silently to [query] on any error."""
|
|
try:
|
|
base_url = current_app.config.get("LM_STUDIO_URL", "")
|
|
model = current_app.config.get("LM_STUDIO_MODEL", "")
|
|
if not base_url or not model:
|
|
return [query]
|
|
_llm = OpenAI(
|
|
base_url=f"{base_url.rstrip('/')}/v1",
|
|
api_key="lm-studio",
|
|
timeout=15.0,
|
|
)
|
|
resp = _llm.chat.completions.create(
|
|
model=model,
|
|
messages=[{
|
|
"role": "user",
|
|
"content": (
|
|
"Write 2 short alternative search queries to find relevant passages "
|
|
"in a document for the following question.\n"
|
|
f"Question: {query}\n"
|
|
"Output only the 2 queries, one per line, no numbering or explanation."
|
|
),
|
|
}],
|
|
temperature=0.3,
|
|
max_tokens=80,
|
|
)
|
|
extras = [
|
|
line.strip()
|
|
for line in resp.choices[0].message.content.splitlines()
|
|
if line.strip() and line.strip().lower() != query.lower()
|
|
][:2]
|
|
return [query] + extras
|
|
except Exception:
|
|
return [query]
|
|
|
|
|
|
def _fetch_neighbor_docs(chunk_id: str, collection) -> list[tuple[str, str]]:
|
|
"""Fetch the chunks immediately before and after the given chunk_id.
|
|
IDs follow the pattern: {source_type}_{source_id}_chunk_{index}"""
|
|
m = re.match(r'^(.+_chunk_)(\d+)$', chunk_id)
|
|
if not m:
|
|
return []
|
|
prefix, idx = m.group(1), int(m.group(2))
|
|
neighbors: list[tuple[str, str]] = []
|
|
for offset in (-1, 1):
|
|
nid = f"{prefix}{idx + offset}"
|
|
try:
|
|
res = collection.get(ids=[nid], include=["documents"])
|
|
docs = res.get("documents") or []
|
|
if docs:
|
|
neighbors.append((nid, docs[0]))
|
|
except Exception:
|
|
pass
|
|
return neighbors
|
|
|
|
|
|
def _build_where(user_id: int, source_ids=None, source_type=None) -> dict:
|
|
"""Build a ChromaDB where-filter for user/source scoping."""
|
|
conditions = [{"user_id": {"$eq": str(user_id)}}]
|
|
if source_ids is not None and len(source_ids) > 0:
|
|
conditions.append({"source_id": {"$in": [str(sid) for sid in source_ids]}})
|
|
if source_type:
|
|
conditions.append({"source_type": {"$eq": source_type}})
|
|
return {"$and": conditions} if len(conditions) > 1 else conditions[0]
|
|
|
|
|
|
def similarity_search(
    query: str,
    user_id: int,
    source_ids: Optional[list[int]] = None,
    source_type: Optional[str] = None,
    top_k: int = 5,
    max_distance: float = 0.80,
) -> list[str]:
    """Retrieve chunk texts relevant to ``query`` from one user's sources.

    Steps:
      1. Expand the query into up to 3 variants (_expand_query). Run each
         variant against the collection, filtered by user and optionally by
         ``source_ids`` / ``source_type``.
      2. Keep hits with distance < ``max_distance`` (the collection is created
         with cosine space). A chunk found by several variants keeps its best
         (smallest) distance.
      3. Take the ``top_k`` best hits and add their immediate neighbour chunks.
      4. Return texts in reading order: by source_id, source_type, then chunk
         index.

    A failing variant query is logged and skipped.

    Args:
        query: User question.
        user_id: Owner whose chunks may be returned.
        source_ids: Restrict to these sources; None or empty means all.
        source_type: Restrict to "doc" or "url"; falsy means both.
        top_k: Number of best hits to keep before neighbour expansion.
        max_distance: Exclusive upper bound on a hit's distance.

    Returns:
        Chunk texts. Empty if the collection is empty or nothing is close
        enough.
    """
    collection = _get_collection()
    total = collection.count()
    if total == 0:
        # Nothing indexed at all: skip the LLM query-expansion call too.
        return []

    where = _build_where(user_id, source_ids, source_type)
    # count() spans the whole collection (all users); it only caps n_results.
    n_results = max(1, min(top_k, total))

    best: dict[str, tuple[float, str]] = {}  # chunk id -> (best distance, text)
    for q in _expand_query(query):
        try:
            results = collection.query(
                query_texts=[q],
                n_results=n_results,
                where=where,
                # IDs are always returned. "ids" is not a valid include value;
                # ChromaDB rejects it with ValueError.
                include=["documents", "distances"],
            )
            docs = (results.get("documents") or [[]])[0]
            dists = (results.get("distances") or [[]])[0]
            ids = (results.get("ids") or [[]])[0]
        except Exception:
            current_app.logger.warning("RAG similarity query failed for %r", q, exc_info=True)
            continue

        for doc, dist, doc_id in zip(docs, dists, ids):
            if dist >= max_distance:
                continue
            previous = best.get(doc_id)
            if previous is None or dist < previous[0]:
                best[doc_id] = (dist, doc)

    # Best-matching chunks first; keep only top_k.
    top_hits = sorted(best.items(), key=lambda item: item[1][0])[:top_k]

    # Expand each kept hit with its neighbours. Membership is checked against
    # the result set, not against every hit seen: a hit ranked below top_k may
    # still be added back as a neighbour.
    id_to_doc: dict[str, str] = {cid: doc for cid, (_, doc) in top_hits}
    for cid, _ in top_hits:
        for nid, ndoc in _fetch_neighbor_docs(cid, collection):
            id_to_doc.setdefault(nid, ndoc)

    def _reading_order(cid: str) -> tuple[int, str, int]:
        # Include source_type so "doc_5" and "url_5" chunks are not interleaved.
        m = re.match(r'^([a-z]+)_(\d+)_chunk_(\d+)$', cid)
        if not m:
            return (0, "", 0)
        return (int(m.group(2)), m.group(1), int(m.group(3)))

    return [id_to_doc[cid] for cid in sorted(id_to_doc, key=_reading_order)]
|