""" RAG service using ChromaDB + LM Studio's /v1/embeddings endpoint. No local ML libraries (torch, sentence-transformers, onnxruntime) needed — embeddings are generated by the same LM Studio instance used for chat. Each chunk is stored with metadata: user_id, source_id, source_type (doc|url). """ import re from typing import Optional from chromadb import EmbeddingFunction, Documents, Embeddings from flask import current_app from openai import OpenAI _chroma_client = None _collection = None class LMStudioEmbeddingFunction(EmbeddingFunction): """ChromaDB-compatible embedding function that calls LM Studio's /v1/embeddings.""" def __init__(self, base_url: str, model: str): self._client = OpenAI( base_url=f"{base_url.rstrip('/')}/v1", api_key="lm-studio", timeout=60.0, ) self._model = model def __call__(self, input: Documents) -> Embeddings: response = self._client.embeddings.create(model=self._model, input=input) return [item.embedding for item in response.data] def _get_collection(): global _chroma_client, _collection if _collection is None: import chromadb path = current_app.config["VECTORDB_PATH"] base_url = current_app.config["LM_STUDIO_URL"] model = current_app.config["LM_STUDIO_EMBEDDING_MODEL"] _chroma_client = chromadb.PersistentClient(path=path) _collection = _chroma_client.get_or_create_collection( name="ki_context", embedding_function=LMStudioEmbeddingFunction(base_url, model), metadata={"hnsw:space": "cosine"}, ) return _collection def chunk_text(text: str, chunk_size: int, overlap: int) -> list[str]: """ Paragraph-aware chunking: 1. Split on blank lines to keep paragraphs intact. 2. Paragraphs shorter than chunk_size are merged together. 3. Paragraphs longer than chunk_size are split by sentence, then by word. Overlap is applied by re-including the last words of the previous chunk. """ import re as _re # Split into paragraphs (one or more blank lines) paragraphs = [p.strip() for p in _re.split(r'\n{2,}', text) if p.strip()] # Further split very long paragraphs at sentence boundaries sentences: list[str] = [] for para in paragraphs: # Split on ". ", "! ", "? " followed by uppercase or end-of-string parts = _re.split(r'(?<=[.!?])\s+(?=[A-ZÜÖÄ\u00C0-\u00FF"])', para) sentences.extend([s.strip() for s in parts if s.strip()]) chunks: list[str] = [] current_words: list[str] = [] for sentence in sentences: words = sentence.split() # If adding this sentence would exceed chunk_size, flush current chunk if current_words and len(current_words) + len(words) > chunk_size: chunks.append(" ".join(current_words)) # Keep last words as context for the next chunk current_words = current_words[-overlap:] if overlap else [] current_words.extend(words) # If a single sentence is already longer than chunk_size, flush immediately if len(current_words) >= chunk_size: chunks.append(" ".join(current_words)) current_words = current_words[-overlap:] if overlap else [] if current_words: chunks.append(" ".join(current_words)) return [c for c in chunks if len(c.split()) >= 10] # drop near-empty chunks def index_source( text: str, user_id: int, source_id: int, source_type: str, # "doc" | "url" chunk_size: int = 500, chunk_overlap: int = 50, ): """Chunk, embed via LM Studio and store in ChromaDB. Replaces existing chunks.""" collection = _get_collection() delete_source(user_id, source_id, source_type) chunks = chunk_text(text, chunk_size, chunk_overlap) if not chunks: return ids = [f"{source_type}_{source_id}_chunk_{i}" for i in range(len(chunks))] metadatas = [ {"user_id": str(user_id), "source_id": str(source_id), "source_type": source_type} for _ in chunks ] collection.add(documents=chunks, ids=ids, metadatas=metadatas) def delete_source(user_id: int, source_id: int, source_type: str): """Remove all chunks belonging to a source from ChromaDB.""" collection = _get_collection() try: collection.delete( where={ "$and": [ {"user_id": {"$eq": str(user_id)}}, {"source_id": {"$eq": str(source_id)}}, {"source_type": {"$eq": source_type}}, ] } ) except Exception: pass def _expand_query(query: str) -> list[str]: """Ask the LLM to generate 2 alternative search queries for better recall. Falls back silently to [query] on any error.""" try: base_url = current_app.config.get("LM_STUDIO_URL", "") model = current_app.config.get("LM_STUDIO_MODEL", "") if not base_url or not model: return [query] _llm = OpenAI( base_url=f"{base_url.rstrip('/')}/v1", api_key="lm-studio", timeout=15.0, ) resp = _llm.chat.completions.create( model=model, messages=[{ "role": "user", "content": ( "Write 2 short alternative search queries to find relevant passages " "in a document for the following question.\n" f"Question: {query}\n" "Output only the 2 queries, one per line, no numbering or explanation." ), }], temperature=0.3, max_tokens=80, ) extras = [ line.strip() for line in resp.choices[0].message.content.splitlines() if line.strip() and line.strip().lower() != query.lower() ][:2] return [query] + extras except Exception: return [query] def _fetch_neighbor_docs(chunk_id: str, collection) -> list[tuple[str, str]]: """Fetch the chunks immediately before and after the given chunk_id. IDs follow the pattern: {source_type}_{source_id}_chunk_{index}""" m = re.match(r'^(.+_chunk_)(\d+)$', chunk_id) if not m: return [] prefix, idx = m.group(1), int(m.group(2)) neighbors: list[tuple[str, str]] = [] for offset in (-1, 1): nid = f"{prefix}{idx + offset}" try: res = collection.get(ids=[nid], include=["documents"]) docs = res.get("documents") or [] if docs: neighbors.append((nid, docs[0])) except Exception: pass return neighbors def _build_where(user_id: int, source_ids=None, source_type=None) -> dict: """Build a ChromaDB where-filter for user/source scoping.""" conditions = [{"user_id": {"$eq": str(user_id)}}] if source_ids is not None and len(source_ids) > 0: conditions.append({"source_id": {"$in": [str(sid) for sid in source_ids]}}) if source_type: conditions.append({"source_type": {"$eq": source_type}}) return {"$and": conditions} if len(conditions) > 1 else conditions[0] def similarity_search( query: str, user_id: int, source_ids: Optional[list[int]] = None, source_type: Optional[str] = None, top_k: int = 5, ) -> list[str]: """Multi-query search with neighbor expansion and reading-order sorting.""" collection = _get_collection() where = _build_where(user_id, source_ids, source_type) # Generate multiple queries for broader recall queries = _expand_query(query) # Collect (distance, chunk_id, text) across all queries — deduplicated seen_ids: set[str] = set() ranked: list[tuple[float, str, str]] = [] # (distance, id, text) for q in queries: try: total = collection.count() n = max(1, min(top_k, total)) results = collection.query( query_texts=[q], n_results=n, where=where, include=["documents", "distances", "ids"], ) docs = (results.get("documents") or [[]])[0] dists = (results.get("distances") or [[]])[0] ids = (results.get("ids") or [[]])[0] for doc, dist, doc_id in zip(docs, dists, ids): if doc_id not in seen_ids and dist < 0.80: seen_ids.add(doc_id) ranked.append((dist, doc_id, doc)) except Exception: pass # Sort best-matching chunks first, keep only top_k ranked.sort(key=lambda x: x[0]) top_chunks = ranked[:top_k] # Expand each hit with its neighboring chunks id_to_doc: dict[str, str] = {cid: doc for _, cid, doc in top_chunks} for _, chunk_id, _ in top_chunks: for nid, ndoc in _fetch_neighbor_docs(chunk_id, collection): if nid not in seen_ids: seen_ids.add(nid) id_to_doc[nid] = ndoc # Sort all collected chunks by (source_id, chunk_index) for reading order def _sort_key(cid: str): m = re.match(r'^[a-z]+_(\d+)_chunk_(\d+)$', cid) return (int(m.group(1)), int(m.group(2))) if m else (0, 0) ordered_ids = sorted(id_to_doc.keys(), key=_sort_key) return [id_to_doc[cid] for cid in ordered_ids]