modified: config.py

modified:   services/llm_service.py
	modified:   services/rag_service.py
This commit is contained in:
SimolZimol
2026-05-23 14:15:42 +02:00
parent 20f72fab98
commit 9eb0869c50
3 changed files with 115 additions and 27 deletions

View File

@@ -133,6 +133,63 @@ def delete_source(user_id: int, source_id: int, source_type: str):
pass
def _expand_query(query: str) -> list[str]:
"""Ask the LLM to generate 2 alternative search queries for better recall.
Falls back silently to [query] on any error."""
try:
base_url = current_app.config.get("LM_STUDIO_URL", "")
model = current_app.config.get("LM_STUDIO_MODEL", "")
if not base_url or not model:
return [query]
_llm = OpenAI(
base_url=f"{base_url.rstrip('/')}/v1",
api_key="lm-studio",
timeout=15.0,
)
resp = _llm.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": (
"Write 2 short alternative search queries to find relevant passages "
"in a document for the following question.\n"
f"Question: {query}\n"
"Output only the 2 queries, one per line, no numbering or explanation."
),
}],
temperature=0.3,
max_tokens=80,
)
extras = [
line.strip()
for line in resp.choices[0].message.content.splitlines()
if line.strip() and line.strip().lower() != query.lower()
][:2]
return [query] + extras
except Exception:
return [query]
def _fetch_neighbor_docs(chunk_id: str, collection) -> list[tuple[str, str]]:
"""Fetch the chunks immediately before and after the given chunk_id.
IDs follow the pattern: {source_type}_{source_id}_chunk_{index}"""
m = re.match(r'^(.+_chunk_)(\d+)$', chunk_id)
if not m:
return []
prefix, idx = m.group(1), int(m.group(2))
neighbors: list[tuple[str, str]] = []
for offset in (-1, 1):
nid = f"{prefix}{idx + offset}"
try:
res = collection.get(ids=[nid], include=["documents"])
docs = res.get("documents") or []
if docs:
neighbors.append((nid, docs[0]))
except Exception:
pass
return neighbors
def similarity_search(
query: str,
user_id: int,
@@ -140,33 +197,59 @@ def similarity_search(
source_type: Optional[str] = None,
top_k: int = 5,
) -> list[str]:
"""Search for relevant chunks via LM Studio embeddings."""
"""Multi-query search with neighbor expansion and reading-order sorting."""
collection = _get_collection()
conditions = [{"user_id": {"$eq": str(user_id)}}]
if source_ids is not None and len(source_ids) > 0:
conditions.append(
{"source_id": {"$in": [str(sid) for sid in source_ids]}}
)
conditions.append({"source_id": {"$in": [str(sid) for sid in source_ids]}})
if source_type:
conditions.append({"source_type": {"$eq": source_type}})
where = {"$and": conditions} if len(conditions) > 1 else conditions[0]
try:
results = collection.query(
query_texts=[query],
n_results=top_k,
where=where,
include=["documents", "distances"],
)
docs = results["documents"][0] if results["documents"] else []
distances = results["distances"][0] if results.get("distances") else []
# Cosine distance: 0 = identical, 2 = opposite. Filter out poor matches (> 0.6).
if distances:
docs = [d for d, dist in zip(docs, distances) if dist < 0.6]
return docs
except Exception:
return []
# Generate multiple queries for broader recall
queries = _expand_query(query)
# Collect (distance, chunk_id, text) across all queries — deduplicated
seen_ids: set[str] = set()
ranked: list[tuple[float, str, str]] = [] # (distance, id, text)
for q in queries:
try:
total = collection.count()
n = max(1, min(top_k, total))
results = collection.query(
query_texts=[q],
n_results=n,
where=where,
include=["documents", "distances", "ids"],
)
docs = (results.get("documents") or [[]])[0]
dists = (results.get("distances") or [[]])[0]
ids = (results.get("ids") or [[]])[0]
for doc, dist, doc_id in zip(docs, dists, ids):
if doc_id not in seen_ids and dist < 0.65:
seen_ids.add(doc_id)
ranked.append((dist, doc_id, doc))
except Exception:
pass
# Sort best-matching chunks first, keep only top_k
ranked.sort(key=lambda x: x[0])
top_chunks = ranked[:top_k]
# Expand each hit with its neighboring chunks
id_to_doc: dict[str, str] = {cid: doc for _, cid, doc in top_chunks}
for _, chunk_id, _ in top_chunks:
for nid, ndoc in _fetch_neighbor_docs(chunk_id, collection):
if nid not in seen_ids:
seen_ids.add(nid)
id_to_doc[nid] = ndoc
# Sort all collected chunks by (source_id, chunk_index) for reading order
def _sort_key(cid: str):
m = re.match(r'^[a-z]+_(\d+)_chunk_(\d+)$', cid)
return (int(m.group(1)), int(m.group(2))) if m else (0, 0)
ordered_ids = sorted(id_to_doc.keys(), key=_sort_key)
return [id_to_doc[cid] for cid in ordered_ids]