modified: config.py
modified: services/llm_service.py modified: services/rag_service.py
This commit is contained in:
@@ -30,6 +30,6 @@ class Config:
|
||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
||||
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o")
|
||||
|
||||
RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "6"))
|
||||
RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "15"))
|
||||
RAG_CHUNK_SIZE = int(os.environ.get("RAG_CHUNK_SIZE", "300"))
|
||||
RAG_CHUNK_OVERLAP = int(os.environ.get("RAG_CHUNK_OVERLAP", "75"))
|
||||
|
||||
@@ -55,13 +55,18 @@ def ask(
|
||||
model = _get_model()
|
||||
|
||||
system_parts = [
|
||||
"You are a helpful AI assistant. Answer questions accurately based on the provided context.",
|
||||
"If the context does not contain enough information, say so clearly.",
|
||||
"You are a helpful AI assistant. You will be given excerpts from one or more documents "
|
||||
"as context. Synthesize the information from ALL relevant excerpts to give a complete answer. "
|
||||
"The excerpts are ordered by their position in the document.",
|
||||
"If specific information is not contained in the context, say so clearly.",
|
||||
"When the answer spans multiple sections, summarize each relevant part.",
|
||||
]
|
||||
|
||||
if context_chunks:
|
||||
context_text = "\n\n---\n\n".join(context_chunks)
|
||||
system_parts.append(f"\n\n## Context\n\n{context_text}")
|
||||
context_text = "\n\n---\n\n".join(
|
||||
f"[Excerpt {i+1}]\n{chunk}" for i, chunk in enumerate(context_chunks)
|
||||
)
|
||||
system_parts.append(f"\n\n## Document excerpts\n\n{context_text}")
|
||||
|
||||
if system_extra:
|
||||
system_parts.append(system_extra)
|
||||
|
||||
@@ -133,6 +133,63 @@ def delete_source(user_id: int, source_id: int, source_type: str):
|
||||
pass
|
||||
|
||||
|
||||
def _expand_query(query: str) -> list[str]:
|
||||
"""Ask the LLM to generate 2 alternative search queries for better recall.
|
||||
Falls back silently to [query] on any error."""
|
||||
try:
|
||||
base_url = current_app.config.get("LM_STUDIO_URL", "")
|
||||
model = current_app.config.get("LM_STUDIO_MODEL", "")
|
||||
if not base_url or not model:
|
||||
return [query]
|
||||
_llm = OpenAI(
|
||||
base_url=f"{base_url.rstrip('/')}/v1",
|
||||
api_key="lm-studio",
|
||||
timeout=15.0,
|
||||
)
|
||||
resp = _llm.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Write 2 short alternative search queries to find relevant passages "
|
||||
"in a document for the following question.\n"
|
||||
f"Question: {query}\n"
|
||||
"Output only the 2 queries, one per line, no numbering or explanation."
|
||||
),
|
||||
}],
|
||||
temperature=0.3,
|
||||
max_tokens=80,
|
||||
)
|
||||
extras = [
|
||||
line.strip()
|
||||
for line in resp.choices[0].message.content.splitlines()
|
||||
if line.strip() and line.strip().lower() != query.lower()
|
||||
][:2]
|
||||
return [query] + extras
|
||||
except Exception:
|
||||
return [query]
|
||||
|
||||
|
||||
def _fetch_neighbor_docs(chunk_id: str, collection) -> list[tuple[str, str]]:
|
||||
"""Fetch the chunks immediately before and after the given chunk_id.
|
||||
IDs follow the pattern: {source_type}_{source_id}_chunk_{index}"""
|
||||
m = re.match(r'^(.+_chunk_)(\d+)$', chunk_id)
|
||||
if not m:
|
||||
return []
|
||||
prefix, idx = m.group(1), int(m.group(2))
|
||||
neighbors: list[tuple[str, str]] = []
|
||||
for offset in (-1, 1):
|
||||
nid = f"{prefix}{idx + offset}"
|
||||
try:
|
||||
res = collection.get(ids=[nid], include=["documents"])
|
||||
docs = res.get("documents") or []
|
||||
if docs:
|
||||
neighbors.append((nid, docs[0]))
|
||||
except Exception:
|
||||
pass
|
||||
return neighbors
|
||||
|
||||
|
||||
def similarity_search(
|
||||
query: str,
|
||||
user_id: int,
|
||||
@@ -140,33 +197,59 @@ def similarity_search(
|
||||
source_type: Optional[str] = None,
|
||||
top_k: int = 5,
|
||||
) -> list[str]:
|
||||
"""Search for relevant chunks via LM Studio embeddings."""
|
||||
"""Multi-query search with neighbor expansion and reading-order sorting."""
|
||||
collection = _get_collection()
|
||||
|
||||
conditions = [{"user_id": {"$eq": str(user_id)}}]
|
||||
|
||||
if source_ids is not None and len(source_ids) > 0:
|
||||
conditions.append(
|
||||
{"source_id": {"$in": [str(sid) for sid in source_ids]}}
|
||||
)
|
||||
|
||||
conditions.append({"source_id": {"$in": [str(sid) for sid in source_ids]}})
|
||||
if source_type:
|
||||
conditions.append({"source_type": {"$eq": source_type}})
|
||||
|
||||
where = {"$and": conditions} if len(conditions) > 1 else conditions[0]
|
||||
|
||||
# Generate multiple queries for broader recall
|
||||
queries = _expand_query(query)
|
||||
|
||||
# Collect (distance, chunk_id, text) across all queries — deduplicated
|
||||
seen_ids: set[str] = set()
|
||||
ranked: list[tuple[float, str, str]] = [] # (distance, id, text)
|
||||
|
||||
for q in queries:
|
||||
try:
|
||||
total = collection.count()
|
||||
n = max(1, min(top_k, total))
|
||||
results = collection.query(
|
||||
query_texts=[query],
|
||||
n_results=top_k,
|
||||
query_texts=[q],
|
||||
n_results=n,
|
||||
where=where,
|
||||
include=["documents", "distances"],
|
||||
include=["documents", "distances", "ids"],
|
||||
)
|
||||
docs = results["documents"][0] if results["documents"] else []
|
||||
distances = results["distances"][0] if results.get("distances") else []
|
||||
# Cosine distance: 0 = identical, 2 = opposite. Filter out poor matches (> 0.6).
|
||||
if distances:
|
||||
docs = [d for d, dist in zip(docs, distances) if dist < 0.6]
|
||||
return docs
|
||||
docs = (results.get("documents") or [[]])[0]
|
||||
dists = (results.get("distances") or [[]])[0]
|
||||
ids = (results.get("ids") or [[]])[0]
|
||||
for doc, dist, doc_id in zip(docs, dists, ids):
|
||||
if doc_id not in seen_ids and dist < 0.65:
|
||||
seen_ids.add(doc_id)
|
||||
ranked.append((dist, doc_id, doc))
|
||||
except Exception:
|
||||
return []
|
||||
pass
|
||||
|
||||
# Sort best-matching chunks first, keep only top_k
|
||||
ranked.sort(key=lambda x: x[0])
|
||||
top_chunks = ranked[:top_k]
|
||||
|
||||
# Expand each hit with its neighboring chunks
|
||||
id_to_doc: dict[str, str] = {cid: doc for _, cid, doc in top_chunks}
|
||||
for _, chunk_id, _ in top_chunks:
|
||||
for nid, ndoc in _fetch_neighbor_docs(chunk_id, collection):
|
||||
if nid not in seen_ids:
|
||||
seen_ids.add(nid)
|
||||
id_to_doc[nid] = ndoc
|
||||
|
||||
# Sort all collected chunks by (source_id, chunk_index) for reading order
|
||||
def _sort_key(cid: str):
|
||||
m = re.match(r'^[a-z]+_(\d+)_chunk_(\d+)$', cid)
|
||||
return (int(m.group(1)), int(m.group(2))) if m else (0, 0)
|
||||
|
||||
ordered_ids = sorted(id_to_doc.keys(), key=_sort_key)
|
||||
return [id_to_doc[cid] for cid in ordered_ids]
|
||||
|
||||
Reference in New Issue
Block a user