diff --git a/config.py b/config.py index b0cca86..9e3aad1 100644 --- a/config.py +++ b/config.py @@ -30,6 +30,6 @@ class Config: OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o") - RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "6")) + RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "15")) RAG_CHUNK_SIZE = int(os.environ.get("RAG_CHUNK_SIZE", "300")) RAG_CHUNK_OVERLAP = int(os.environ.get("RAG_CHUNK_OVERLAP", "75")) diff --git a/services/llm_service.py b/services/llm_service.py index 0ae8716..e445bf0 100644 --- a/services/llm_service.py +++ b/services/llm_service.py @@ -55,13 +55,18 @@ def ask( model = _get_model() system_parts = [ - "You are a helpful AI assistant. Answer questions accurately based on the provided context.", - "If the context does not contain enough information, say so clearly.", + "You are a helpful AI assistant. You will be given excerpts from one or more documents " + "as context. Synthesize the information from ALL relevant excerpts to give a complete answer. " + "The excerpts are ordered by their position in the document.", + "If specific information is not contained in the context, say so clearly.", + "When the answer spans multiple sections, summarize each relevant part.", ] if context_chunks: - context_text = "\n\n---\n\n".join(context_chunks) - system_parts.append(f"\n\n## Context\n\n{context_text}") + context_text = "\n\n---\n\n".join( + f"[Excerpt {i+1}]\n{chunk}" for i, chunk in enumerate(context_chunks) + ) + system_parts.append(f"\n\n## Document excerpts\n\n{context_text}") if system_extra: system_parts.append(system_extra) diff --git a/services/rag_service.py b/services/rag_service.py index 6249e85..3f2c2d9 100644 --- a/services/rag_service.py +++ b/services/rag_service.py @@ -133,6 +133,63 @@ def delete_source(user_id: int, source_id: int, source_type: str): pass +def _expand_query(query: str) -> list[str]: + """Ask the LLM to generate 2 alternative search queries for better recall. + Falls back silently to [query] on any error.""" + try: + base_url = current_app.config.get("LM_STUDIO_URL", "") + model = current_app.config.get("LM_STUDIO_MODEL", "") + if not base_url or not model: + return [query] + _llm = OpenAI( + base_url=f"{base_url.rstrip('/')}/v1", + api_key="lm-studio", + timeout=15.0, + ) + resp = _llm.chat.completions.create( + model=model, + messages=[{ + "role": "user", + "content": ( + "Write 2 short alternative search queries to find relevant passages " + "in a document for the following question.\n" + f"Question: {query}\n" + "Output only the 2 queries, one per line, no numbering or explanation." + ), + }], + temperature=0.3, + max_tokens=80, + ) + extras = [ + line.strip() + for line in resp.choices[0].message.content.splitlines() + if line.strip() and line.strip().lower() != query.lower() + ][:2] + return [query] + extras + except Exception: + return [query] + + +def _fetch_neighbor_docs(chunk_id: str, collection) -> list[tuple[str, str]]: + """Fetch the chunks immediately before and after the given chunk_id. + IDs follow the pattern: {source_type}_{source_id}_chunk_{index}""" + m = re.match(r'^(.+_chunk_)(\d+)$', chunk_id) + if not m: + return [] + prefix, idx = m.group(1), int(m.group(2)) + neighbors: list[tuple[str, str]] = [] + for offset in (-1, 1): + nid = f"{prefix}{idx + offset}" + try: + res = collection.get(ids=[nid], include=["documents"]) + docs = res.get("documents") or [] + if docs: + neighbors.append((nid, docs[0])) + except Exception: + pass + return neighbors + + def similarity_search( query: str, user_id: int, @@ -140,33 +197,59 @@ def similarity_search( source_type: Optional[str] = None, top_k: int = 5, ) -> list[str]: - """Search for relevant chunks via LM Studio embeddings.""" + """Multi-query search with neighbor expansion and reading-order sorting.""" collection = _get_collection() conditions = [{"user_id": {"$eq": str(user_id)}}] - if source_ids is not None and len(source_ids) > 0: - conditions.append( - {"source_id": {"$in": [str(sid) for sid in source_ids]}} - ) - + conditions.append({"source_id": {"$in": [str(sid) for sid in source_ids]}}) if source_type: conditions.append({"source_type": {"$eq": source_type}}) - where = {"$and": conditions} if len(conditions) > 1 else conditions[0] - try: - results = collection.query( - query_texts=[query], - n_results=top_k, - where=where, - include=["documents", "distances"], - ) - docs = results["documents"][0] if results["documents"] else [] - distances = results["distances"][0] if results.get("distances") else [] - # Cosine distance: 0 = identical, 2 = opposite. Filter out poor matches (> 0.6). - if distances: - docs = [d for d, dist in zip(docs, distances) if dist < 0.6] - return docs - except Exception: - return [] + # Generate multiple queries for broader recall + queries = _expand_query(query) + + # Collect (distance, chunk_id, text) across all queries — deduplicated + seen_ids: set[str] = set() + ranked: list[tuple[float, str, str]] = [] # (distance, id, text) + + for q in queries: + try: + total = collection.count() + n = max(1, min(top_k, total)) + results = collection.query( + query_texts=[q], + n_results=n, + where=where, + include=["documents", "distances", "ids"], + ) + docs = (results.get("documents") or [[]])[0] + dists = (results.get("distances") or [[]])[0] + ids = (results.get("ids") or [[]])[0] + for doc, dist, doc_id in zip(docs, dists, ids): + if doc_id not in seen_ids and dist < 0.65: + seen_ids.add(doc_id) + ranked.append((dist, doc_id, doc)) + except Exception: + pass + + # Sort best-matching chunks first, keep only top_k + ranked.sort(key=lambda x: x[0]) + top_chunks = ranked[:top_k] + + # Expand each hit with its neighboring chunks + id_to_doc: dict[str, str] = {cid: doc for _, cid, doc in top_chunks} + for _, chunk_id, _ in top_chunks: + for nid, ndoc in _fetch_neighbor_docs(chunk_id, collection): + if nid not in seen_ids: + seen_ids.add(nid) + id_to_doc[nid] = ndoc + + # Sort all collected chunks by (source_id, chunk_index) for reading order + def _sort_key(cid: str): + m = re.match(r'^[a-z]+_(\d+)_chunk_(\d+)$', cid) + return (int(m.group(1)), int(m.group(2))) if m else (0, 0) + + ordered_ids = sorted(id_to_doc.keys(), key=_sort_key) + return [id_to_doc[cid] for cid in ordered_ids]