modified: config.py

modified:   services/llm_service.py
	modified:   services/rag_service.py
This commit is contained in:
SimolZimol
2026-05-23 14:15:42 +02:00
parent 20f72fab98
commit 9eb0869c50
3 changed files with 115 additions and 27 deletions

View File

@@ -30,6 +30,6 @@ class Config:
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o") OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o")
RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "6")) RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "15"))
RAG_CHUNK_SIZE = int(os.environ.get("RAG_CHUNK_SIZE", "300")) RAG_CHUNK_SIZE = int(os.environ.get("RAG_CHUNK_SIZE", "300"))
RAG_CHUNK_OVERLAP = int(os.environ.get("RAG_CHUNK_OVERLAP", "75")) RAG_CHUNK_OVERLAP = int(os.environ.get("RAG_CHUNK_OVERLAP", "75"))

View File

@@ -55,13 +55,18 @@ def ask(
model = _get_model() model = _get_model()
system_parts = [ system_parts = [
"You are a helpful AI assistant. Answer questions accurately based on the provided context.", "You are a helpful AI assistant. You will be given excerpts from one or more documents "
"If the context does not contain enough information, say so clearly.", "as context. Synthesize the information from ALL relevant excerpts to give a complete answer. "
"The excerpts are ordered by their position in the document.",
"If specific information is not contained in the context, say so clearly.",
"When the answer spans multiple sections, summarize each relevant part.",
] ]
if context_chunks: if context_chunks:
context_text = "\n\n---\n\n".join(context_chunks) context_text = "\n\n---\n\n".join(
system_parts.append(f"\n\n## Context\n\n{context_text}") f"[Excerpt {i+1}]\n{chunk}" for i, chunk in enumerate(context_chunks)
)
system_parts.append(f"\n\n## Document excerpts\n\n{context_text}")
if system_extra: if system_extra:
system_parts.append(system_extra) system_parts.append(system_extra)

View File

@@ -133,6 +133,63 @@ def delete_source(user_id: int, source_id: int, source_type: str):
pass pass
def _expand_query(query: str) -> list[str]:
"""Ask the LLM to generate 2 alternative search queries for better recall.
Falls back silently to [query] on any error."""
try:
base_url = current_app.config.get("LM_STUDIO_URL", "")
model = current_app.config.get("LM_STUDIO_MODEL", "")
if not base_url or not model:
return [query]
_llm = OpenAI(
base_url=f"{base_url.rstrip('/')}/v1",
api_key="lm-studio",
timeout=15.0,
)
resp = _llm.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": (
"Write 2 short alternative search queries to find relevant passages "
"in a document for the following question.\n"
f"Question: {query}\n"
"Output only the 2 queries, one per line, no numbering or explanation."
),
}],
temperature=0.3,
max_tokens=80,
)
extras = [
line.strip()
for line in resp.choices[0].message.content.splitlines()
if line.strip() and line.strip().lower() != query.lower()
][:2]
return [query] + extras
except Exception:
return [query]
def _fetch_neighbor_docs(chunk_id: str, collection) -> list[tuple[str, str]]:
"""Fetch the chunks immediately before and after the given chunk_id.
IDs follow the pattern: {source_type}_{source_id}_chunk_{index}"""
m = re.match(r'^(.+_chunk_)(\d+)$', chunk_id)
if not m:
return []
prefix, idx = m.group(1), int(m.group(2))
neighbors: list[tuple[str, str]] = []
for offset in (-1, 1):
nid = f"{prefix}{idx + offset}"
try:
res = collection.get(ids=[nid], include=["documents"])
docs = res.get("documents") or []
if docs:
neighbors.append((nid, docs[0]))
except Exception:
pass
return neighbors
def similarity_search(
    query: str,
    user_id: int,
    source_ids: Optional[list[int]] = None,
    source_type: Optional[str] = None,
    top_k: int = 5,
) -> list[str]:
    """Multi-query search with neighbor expansion and reading-order sorting.

    Steps:
      1. Build a metadata filter from ``user_id`` and, when given,
         ``source_ids`` / ``source_type``.
      2. Run the original query plus the alternatives produced by
         ``_expand_query`` against the collection.
      3. Keep the ``top_k`` closest distinct chunks whose distance is below
         the cut-off.
      4. Add the chunk directly before and after every hit.
      5. Return all collected texts ordered by (source id, chunk index).

    Args:
        query: The user's question.
        user_id: Only chunks stored for this user are searched.
        source_ids: Optional restriction to these sources.
        source_type: Optional restriction to one source type.
        top_k: Maximum number of direct hits; neighbor expansion can make
            the returned list longer than this.

    Returns:
        Chunk texts in reading order; an empty list when nothing matches or
        the vector store cannot be queried.
    """
    # NOTE(review): the annotation and default of ``source_ids`` are inferred
    # from how the body uses the parameter - confirm against the declared
    # signature.
    import logging

    log = logging.getLogger(__name__)
    # Distance cut-off for a usable match (treated as cosine distance,
    # where 0 means identical - TODO confirm the collection's metric).
    max_distance = 0.65

    collection = _get_collection()

    conditions = [{"user_id": {"$eq": str(user_id)}}]
    if source_ids:
        conditions.append({"source_id": {"$in": [str(sid) for sid in source_ids]}})
    if source_type:
        conditions.append({"source_type": {"$eq": source_type}})
    where = {"$and": conditions} if len(conditions) > 1 else conditions[0]

    if top_k <= 0:
        return []
    try:
        # Counted once: the collection size does not depend on the query.
        total = collection.count()
    except Exception:
        log.warning("Could not count the vector collection", exc_info=True)
        return []
    if total == 0:
        return []
    n_results = min(top_k, total)

    # Best (smallest) distance seen per chunk id across all query variants.
    best: dict[str, tuple[float, str]] = {}
    for variant in _expand_query(query):
        try:
            results = collection.query(
                query_texts=[variant],
                n_results=n_results,
                where=where,
                # IDs are always part of the result; "ids" is not an
                # accepted include value and makes the call raise.
                include=["documents", "distances"],
            )
            docs = (results.get("documents") or [[]])[0]
            dists = (results.get("distances") or [[]])[0]
            ids = (results.get("ids") or [[]])[0]
        except Exception:
            # Best-effort per variant: one failed query must not discard
            # the results of the others.
            log.warning("Vector query failed for one query variant", exc_info=True)
            continue
        for doc, dist, doc_id in zip(docs, dists, ids):
            if dist >= max_distance:
                continue
            known = best.get(doc_id)
            if known is None or dist < known[0]:
                best[doc_id] = (dist, doc)

    # Closest chunks first, keep only top_k direct hits.
    top_ids = sorted(best, key=lambda cid: best[cid][0])[:top_k]
    id_to_doc: dict[str, str] = {cid: best[cid][1] for cid in top_ids}

    # Expand each hit with its neighboring chunks. A neighbor is skipped
    # only when it is already part of the result, so a chunk that matched
    # but fell below the top_k cut is still included as context.
    for chunk_id in top_ids:
        for nid, ndoc in _fetch_neighbor_docs(chunk_id, collection):
            id_to_doc.setdefault(nid, ndoc)

    def _sort_key(cid: str) -> tuple[int, int]:
        # Same ID layout as _fetch_neighbor_docs expects:
        # {source_type}_{source_id}_chunk_{index}. The source type is not
        # restricted to lowercase letters.
        m = re.match(r"^.+_(\d+)_chunk_(\d+)$", cid)
        return (int(m.group(1)), int(m.group(2))) if m else (0, 0)

    return [id_to_doc[cid] for cid in sorted(id_to_doc, key=_sort_key)]