""" RAG service using ChromaDB + sentence-transformers. Each chunk is stored with metadata: user_id, source_id, source_type (doc|url). """ import os import re from typing import Optional from flask import current_app _chroma_client = None _collection = None _embedder = None def _get_embedder(): global _embedder if _embedder is None: from sentence_transformers import SentenceTransformer cache = current_app.config.get("TRANSFORMERS_CACHE", ".cache") _embedder = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache) return _embedder def _get_collection(): global _chroma_client, _collection if _collection is None: import chromadb path = current_app.config["VECTORDB_PATH"] _chroma_client = chromadb.PersistentClient(path=path) _collection = _chroma_client.get_or_create_collection( name="ki_context", metadata={"hnsw:space": "cosine"}, ) return _collection def chunk_text(text: str, chunk_size: int, overlap: int) -> list[str]: """Split text into overlapping word-based chunks.""" words = text.split() chunks = [] start = 0 while start < len(words): end = start + chunk_size chunks.append(" ".join(words[start:end])) start += chunk_size - overlap return [c for c in chunks if c.strip()] def index_source( text: str, user_id: int, source_id: int, source_type: str, # "doc" | "url" chunk_size: int = 500, chunk_overlap: int = 50, ): """Chunk, embed and store text in ChromaDB. Replaces existing chunks for this source.""" collection = _get_collection() embedder = _get_embedder() # Remove old chunks for this source first delete_source(user_id, source_id, source_type) chunks = chunk_text(text, chunk_size, chunk_overlap) if not chunks: return embeddings = embedder.encode(chunks, show_progress_bar=False).tolist() ids = [f"{source_type}_{source_id}_chunk_{i}" for i in range(len(chunks))] metadatas = [ {"user_id": str(user_id), "source_id": str(source_id), "source_type": source_type} for _ in chunks ] collection.add(documents=chunks, embeddings=embeddings, ids=ids, metadatas=metadatas) def delete_source(user_id: int, source_id: int, source_type: str): """Remove all chunks belonging to a source from ChromaDB.""" collection = _get_collection() try: collection.delete( where={ "$and": [ {"user_id": {"$eq": str(user_id)}}, {"source_id": {"$eq": str(source_id)}}, {"source_type": {"$eq": source_type}}, ] } ) except Exception: pass def similarity_search( query: str, user_id: int, source_ids: Optional[list[int]] = None, source_type: Optional[str] = None, top_k: int = 5, ) -> list[str]: """ Search for relevant chunks. Optionally filter by specific source_ids and/or source_type. Returns list of chunk texts. """ collection = _get_collection() embedder = _get_embedder() query_embedding = embedder.encode([query], show_progress_bar=False).tolist()[0] # Build where filter conditions = [{"user_id": {"$eq": str(user_id)}}] if source_ids is not None and len(source_ids) > 0: conditions.append( {"source_id": {"$in": [str(sid) for sid in source_ids]}} ) if source_type: conditions.append({"source_type": {"$eq": source_type}}) where = {"$and": conditions} if len(conditions) > 1 else conditions[0] try: results = collection.query( query_embeddings=[query_embedding], n_results=top_k, where=where, ) return results["documents"][0] if results["documents"] else [] except Exception: return []