modified: .env.example
modified: config.py modified: docker-compose.yml modified: requirements.txt modified: services/rag_service.py
This commit is contained in:
@@ -16,6 +16,8 @@ AI_PROVIDER=lmstudio
|
|||||||
# On Linux/Docker: use host.docker.internal to reach the host
|
# On Linux/Docker: use host.docker.internal to reach the host
|
||||||
LM_STUDIO_URL=http://host.docker.internal:1234
|
LM_STUDIO_URL=http://host.docker.internal:1234
|
||||||
LM_STUDIO_MODEL=local-model
|
LM_STUDIO_MODEL=local-model
|
||||||
|
# Model used for RAG embeddings — can be the same model or a dedicated embedding model
|
||||||
|
LM_STUDIO_EMBEDDING_MODEL=local-model
|
||||||
|
|
||||||
# OpenAI (only needed when AI_PROVIDER=openai)
|
# OpenAI (only needed when AI_PROVIDER=openai)
|
||||||
# OPENAI_API_KEY=sk-...
|
# OPENAI_API_KEY=sk-...
|
||||||
|
|||||||
11
config.py
11
config.py
@@ -8,16 +8,16 @@ BASE_DIR = os.path.abspath(os.path.dirname(__file__))
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
SECRET_KEY = os.environ.get("SECRET_KEY", "change-me-in-production")
|
SECRET_KEY = os.environ.get("SECRET_KEY", "change-me-in-production")
|
||||||
SQLALCHEMY_DATABASE_URI = os.environ.get(
|
_db_uri = os.environ.get("DATABASE_URI", "")
|
||||||
"DATABASE_URI", f"sqlite:///{os.path.join(BASE_DIR, 'app.db')}"
|
_default_uri = f"sqlite:///{os.path.join(BASE_DIR, 'app.db')}"
|
||||||
|
# Fall back to SQLite if DATABASE_URI is empty or not a valid SQLAlchemy URL
|
||||||
|
SQLALCHEMY_DATABASE_URI = (
|
||||||
|
_db_uri if _db_uri and "://" in _db_uri else _default_uri
|
||||||
)
|
)
|
||||||
SQLALCHEMY_TRACK_MODIFICATIONS = False
|
SQLALCHEMY_TRACK_MODIFICATIONS = False
|
||||||
|
|
||||||
UPLOAD_FOLDER = os.environ.get("UPLOAD_FOLDER", os.path.join(BASE_DIR, "uploads"))
|
UPLOAD_FOLDER = os.environ.get("UPLOAD_FOLDER", os.path.join(BASE_DIR, "uploads"))
|
||||||
VECTORDB_PATH = os.environ.get("VECTORDB_PATH", os.path.join(BASE_DIR, "vectordb"))
|
VECTORDB_PATH = os.environ.get("VECTORDB_PATH", os.path.join(BASE_DIR, "vectordb"))
|
||||||
TRANSFORMERS_CACHE = os.environ.get(
|
|
||||||
"TRANSFORMERS_CACHE", os.path.join(BASE_DIR, ".cache")
|
|
||||||
)
|
|
||||||
|
|
||||||
ALLOWED_EXTENSIONS = {"pdf", "txt", "docx", "md"}
|
ALLOWED_EXTENSIONS = {"pdf", "txt", "docx", "md"}
|
||||||
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50 MB
|
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50 MB
|
||||||
@@ -26,6 +26,7 @@ class Config:
|
|||||||
AI_PROVIDER = os.environ.get("AI_PROVIDER", "lmstudio")
|
AI_PROVIDER = os.environ.get("AI_PROVIDER", "lmstudio")
|
||||||
LM_STUDIO_URL = os.environ.get("LM_STUDIO_URL", "http://localhost:1234")
|
LM_STUDIO_URL = os.environ.get("LM_STUDIO_URL", "http://localhost:1234")
|
||||||
LM_STUDIO_MODEL = os.environ.get("LM_STUDIO_MODEL", "local-model")
|
LM_STUDIO_MODEL = os.environ.get("LM_STUDIO_MODEL", "local-model")
|
||||||
|
LM_STUDIO_EMBEDDING_MODEL = os.environ.get("LM_STUDIO_EMBEDDING_MODEL", "local-model")
|
||||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
||||||
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o")
|
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o")
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- uploads:/app/uploads
|
- uploads:/app/uploads
|
||||||
- vectordb:/app/vectordb
|
- vectordb:/app/vectordb
|
||||||
- hf_cache:/app/.cache
|
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "python", "-c",
|
test: ["CMD", "python", "-c",
|
||||||
"import urllib.request; urllib.request.urlopen('http://localhost:5000/auth/login')"]
|
"import urllib.request; urllib.request.urlopen('http://localhost:5000/auth/login')"]
|
||||||
@@ -21,4 +20,3 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
uploads:
|
uploads:
|
||||||
vectordb:
|
vectordb:
|
||||||
hf_cache:
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ python-docx==1.1.2
|
|||||||
markdown==3.6
|
markdown==3.6
|
||||||
beautifulsoup4==4.12.3
|
beautifulsoup4==4.12.3
|
||||||
requests==2.32.3
|
requests==2.32.3
|
||||||
sentence-transformers==3.0.1
|
|
||||||
chromadb==0.5.3
|
chromadb==0.5.3
|
||||||
openai==1.35.3
|
openai==1.35.3
|
||||||
gunicorn==22.0.0
|
gunicorn==22.0.0
|
||||||
|
|||||||
@@ -1,25 +1,31 @@
|
|||||||
"""
|
"""
|
||||||
RAG service using ChromaDB + sentence-transformers.
|
RAG service using ChromaDB + LM Studio's /v1/embeddings endpoint.
|
||||||
|
No local ML libraries (torch, sentence-transformers, onnxruntime) needed —
|
||||||
|
embeddings are generated by the same LM Studio instance used for chat.
|
||||||
Each chunk is stored with metadata: user_id, source_id, source_type (doc|url).
|
Each chunk is stored with metadata: user_id, source_id, source_type (doc|url).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from chromadb import EmbeddingFunction, Documents, Embeddings
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
_chroma_client = None
|
_chroma_client = None
|
||||||
_collection = None
|
_collection = None
|
||||||
_embedder = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_embedder():
|
class LMStudioEmbeddingFunction(EmbeddingFunction):
|
||||||
global _embedder
|
"""ChromaDB-compatible embedding function that calls LM Studio's /v1/embeddings."""
|
||||||
if _embedder is None:
|
|
||||||
from sentence_transformers import SentenceTransformer
|
def __init__(self, base_url: str, model: str):
|
||||||
cache = current_app.config.get("TRANSFORMERS_CACHE", ".cache")
|
self._client = OpenAI(base_url=f"{base_url}/v1", api_key="lm-studio")
|
||||||
_embedder = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache)
|
self._model = model
|
||||||
return _embedder
|
|
||||||
|
def __call__(self, input: Documents) -> Embeddings:
|
||||||
|
response = self._client.embeddings.create(model=self._model, input=input)
|
||||||
|
return [item.embedding for item in response.data]
|
||||||
|
|
||||||
|
|
||||||
def _get_collection():
|
def _get_collection():
|
||||||
@@ -27,16 +33,18 @@ def _get_collection():
|
|||||||
if _collection is None:
|
if _collection is None:
|
||||||
import chromadb
|
import chromadb
|
||||||
path = current_app.config["VECTORDB_PATH"]
|
path = current_app.config["VECTORDB_PATH"]
|
||||||
|
base_url = current_app.config["LM_STUDIO_URL"]
|
||||||
|
model = current_app.config["LM_STUDIO_EMBEDDING_MODEL"]
|
||||||
_chroma_client = chromadb.PersistentClient(path=path)
|
_chroma_client = chromadb.PersistentClient(path=path)
|
||||||
_collection = _chroma_client.get_or_create_collection(
|
_collection = _chroma_client.get_or_create_collection(
|
||||||
name="ki_context",
|
name="ki_context",
|
||||||
|
embedding_function=LMStudioEmbeddingFunction(base_url, model),
|
||||||
metadata={"hnsw:space": "cosine"},
|
metadata={"hnsw:space": "cosine"},
|
||||||
)
|
)
|
||||||
return _collection
|
return _collection
|
||||||
|
|
||||||
|
|
||||||
def chunk_text(text: str, chunk_size: int, overlap: int) -> list[str]:
|
def chunk_text(text: str, chunk_size: int, overlap: int) -> list[str]:
|
||||||
"""Split text into overlapping word-based chunks."""
|
|
||||||
words = text.split()
|
words = text.split()
|
||||||
chunks = []
|
chunks = []
|
||||||
start = 0
|
start = 0
|
||||||
@@ -55,26 +63,22 @@ def index_source(
|
|||||||
chunk_size: int = 500,
|
chunk_size: int = 500,
|
||||||
chunk_overlap: int = 50,
|
chunk_overlap: int = 50,
|
||||||
):
|
):
|
||||||
"""Chunk, embed and store text in ChromaDB. Replaces existing chunks for this source."""
|
"""Chunk, embed via LM Studio and store in ChromaDB. Replaces existing chunks."""
|
||||||
collection = _get_collection()
|
collection = _get_collection()
|
||||||
embedder = _get_embedder()
|
|
||||||
|
|
||||||
# Remove old chunks for this source first
|
|
||||||
delete_source(user_id, source_id, source_type)
|
delete_source(user_id, source_id, source_type)
|
||||||
|
|
||||||
chunks = chunk_text(text, chunk_size, chunk_overlap)
|
chunks = chunk_text(text, chunk_size, chunk_overlap)
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return
|
return
|
||||||
|
|
||||||
embeddings = embedder.encode(chunks, show_progress_bar=False).tolist()
|
|
||||||
|
|
||||||
ids = [f"{source_type}_{source_id}_chunk_{i}" for i in range(len(chunks))]
|
ids = [f"{source_type}_{source_id}_chunk_{i}" for i in range(len(chunks))]
|
||||||
metadatas = [
|
metadatas = [
|
||||||
{"user_id": str(user_id), "source_id": str(source_id), "source_type": source_type}
|
{"user_id": str(user_id), "source_id": str(source_id), "source_type": source_type}
|
||||||
for _ in chunks
|
for _ in chunks
|
||||||
]
|
]
|
||||||
|
|
||||||
collection.add(documents=chunks, embeddings=embeddings, ids=ids, metadatas=metadatas)
|
collection.add(documents=chunks, ids=ids, metadatas=metadatas)
|
||||||
|
|
||||||
|
|
||||||
def delete_source(user_id: int, source_id: int, source_type: str):
|
def delete_source(user_id: int, source_id: int, source_type: str):
|
||||||
@@ -101,17 +105,9 @@ def similarity_search(
|
|||||||
source_type: Optional[str] = None,
|
source_type: Optional[str] = None,
|
||||||
top_k: int = 5,
|
top_k: int = 5,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""
|
"""Search for relevant chunks via LM Studio embeddings."""
|
||||||
Search for relevant chunks.
|
|
||||||
Optionally filter by specific source_ids and/or source_type.
|
|
||||||
Returns list of chunk texts.
|
|
||||||
"""
|
|
||||||
collection = _get_collection()
|
collection = _get_collection()
|
||||||
embedder = _get_embedder()
|
|
||||||
|
|
||||||
query_embedding = embedder.encode([query], show_progress_bar=False).tolist()[0]
|
|
||||||
|
|
||||||
# Build where filter
|
|
||||||
conditions = [{"user_id": {"$eq": str(user_id)}}]
|
conditions = [{"user_id": {"$eq": str(user_id)}}]
|
||||||
|
|
||||||
if source_ids is not None and len(source_ids) > 0:
|
if source_ids is not None and len(source_ids) > 0:
|
||||||
@@ -126,7 +122,7 @@ def similarity_search(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
results = collection.query(
|
results = collection.query(
|
||||||
query_embeddings=[query_embedding],
|
query_texts=[query],
|
||||||
n_results=top_k,
|
n_results=top_k,
|
||||||
where=where,
|
where=where,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user