feat(AISVC-T6.9): 集成Ollama嵌入模型修复RAG检索问题

## 问题修复
- 替换假嵌入(SHA256 hash)为真实Ollama nomic-embed-text嵌入
- 修复Qdrant客户端版本不兼容导致score_threshold参数失效
- 降低默认分数阈值从0.7到0.3

## 新增文件
- ai-service/app/services/embedding/ollama_embedding.py

## 修改文件
- ai-service/app/api/admin/kb.py: 索引任务使用真实嵌入
- ai-service/app/core/config.py: 新增Ollama配置,向量维度改为768
- ai-service/app/core/qdrant_client.py: 移除score_threshold参数
- ai-service/app/services/retrieval/vector_retriever.py: 使用Ollama嵌入
This commit is contained in:
MerCry 2026-02-24 22:15:53 +08:00
parent 5148c6ef42
commit 4b64a4dbf4
5 changed files with 73 additions and 44 deletions

View File

@ -212,14 +212,13 @@ async def upload_document(
async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: bytes): async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: bytes):
""" """
Background indexing task. Background indexing task.
For MVP, we simulate indexing with a simple text extraction. Uses Ollama nomic-embed-text for real embeddings.
In production, this would use a task queue like Celery.
""" """
from app.core.database import async_session_maker from app.core.database import async_session_maker
from app.services.kb import KBService from app.services.kb import KBService
from app.core.qdrant_client import get_qdrant_client from app.core.qdrant_client import get_qdrant_client
from app.services.embedding.ollama_embedding import get_embedding
from qdrant_client.models import PointStruct from qdrant_client.models import PointStruct
import hashlib
import asyncio import asyncio
await asyncio.sleep(1) await asyncio.sleep(1)
@ -241,24 +240,12 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
points = [] points = []
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
hash_obj = hashlib.sha256(chunk.encode()) embedding = await get_embedding(chunk)
hash_bytes = hash_obj.digest()
embedding = []
for j in range(0, min(len(hash_bytes) * 8, 1536)):
byte_idx = j // 8
bit_idx = j % 8
if byte_idx < len(hash_bytes):
val = (hash_bytes[byte_idx] >> bit_idx) & 1
embedding.append(float(val))
else:
embedding.append(0.0)
while len(embedding) < 1536:
embedding.append(0.0)
points.append( points.append(
PointStruct( PointStruct(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
vector=embedding[:1536], vector=embedding,
payload={ payload={
"text": chunk, "text": chunk,
"source": doc_id, "source": doc_id,

View File

@ -38,10 +38,13 @@ class Settings(BaseSettings):
qdrant_url: str = "http://localhost:6333" qdrant_url: str = "http://localhost:6333"
qdrant_collection_prefix: str = "kb_" qdrant_collection_prefix: str = "kb_"
qdrant_vector_size: int = 1536 qdrant_vector_size: int = 768
ollama_base_url: str = "http://localhost:11434"
ollama_embedding_model: str = "nomic-embed-text"
rag_top_k: int = 5 rag_top_k: int = 5
rag_score_threshold: float = 0.7 rag_score_threshold: float = 0.3
rag_min_hits: int = 1 rag_min_hits: int = 1
rag_max_evidence_tokens: int = 2000 rag_max_evidence_tokens: int = 2000

View File

@ -119,7 +119,6 @@ class QdrantClient:
collection_name=collection_name, collection_name=collection_name,
query_vector=query_vector, query_vector=query_vector,
limit=limit, limit=limit,
score_threshold=score_threshold,
) )
hits = [ hits = [
@ -129,6 +128,7 @@ class QdrantClient:
"payload": result.payload or {}, "payload": result.payload or {},
} }
for result in results for result in results
if score_threshold is None or result.score >= score_threshold
] ]
logger.info( logger.info(

View File

@ -0,0 +1,58 @@
"""
Ollama embedding service for generating text embeddings.
Uses nomic-embed-text model via Ollama API.
"""
import logging
import httpx
from app.core.config import get_settings
logger = logging.getLogger(__name__)
async def get_embedding(text: str) -> list[float]:
    """
    Generate an embedding vector for *text* via the Ollama embeddings API.

    Sends POST {ollama_base_url}/api/embeddings with the configured model
    (nomic-embed-text by default) and returns the vector from the response.

    Falls back to a zero vector of length ``settings.qdrant_vector_size``
    when Ollama returns an empty embedding, so indexing callers still get
    a vector of the expected dimensionality.

    Raises:
        httpx.HTTPStatusError: non-2xx response from the Ollama API.
        httpx.RequestError: connection or timeout failure.
    """
    settings = get_settings()
    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            response = await client.post(
                f"{settings.ollama_base_url}/api/embeddings",
                json={
                    "model": settings.ollama_embedding_model,
                    "prompt": text,
                },
            )
            response.raise_for_status()
            data = response.json()
            embedding = data.get("embedding", [])
            if not embedding:
                # Degenerate input (or server quirk) yielded no vector;
                # return zeros rather than crashing the indexing job.
                # Lazy %-args: only formatted if WARNING is enabled.
                logger.warning("Empty embedding returned for text length=%d", len(text))
                return [0.0] * settings.qdrant_vector_size
            logger.debug("Generated embedding: dim=%d", len(embedding))
            return embedding
        except httpx.HTTPStatusError as e:
            logger.error("Ollama API error: %s - %s", e.response.status_code, e.response.text)
            raise
        except httpx.RequestError as e:
            logger.error("Ollama connection error: %s", e)
            raise
        except Exception as e:
            # Catch-all boundary: log and re-raise so callers see the failure.
            logger.error("Embedding generation failed: %s", e)
            raise
async def get_embeddings_batch(texts: list[str]) -> list[list[float]]:
    """
    Embed every text in *texts*, preserving input order.

    Requests are issued one at a time (sequential awaits), so the Ollama
    server is never hit concurrently; any exception raised by
    ``get_embedding`` propagates unchanged to the caller.
    """
    return [await get_embedding(item) for item in texts]

View File

@ -119,30 +119,11 @@ class VectorRetriever(BaseRetriever):
async def _get_embedding(self, text: str) -> list[float]: async def _get_embedding(self, text: str) -> list[float]:
""" """
Generate embedding for text. Generate embedding for text using Ollama nomic-embed-text model.
[AC-AISVC-16] Placeholder for embedding generation.
TODO: Integrate with actual embedding provider (OpenAI, local model, etc.)
""" """
import hashlib from app.services.embedding.ollama_embedding import get_embedding as get_ollama_embedding
hash_obj = hashlib.sha256(text.encode()) return await get_ollama_embedding(text)
hash_bytes = hash_obj.digest()
embedding = []
for i in range(0, min(len(hash_bytes) * 8, settings.qdrant_vector_size)):
byte_idx = i // 8
bit_idx = i % 8
if byte_idx < len(hash_bytes):
val = (hash_bytes[byte_idx] >> bit_idx) & 1
embedding.append(float(val))
else:
embedding.append(0.0)
while len(embedding) < settings.qdrant_vector_size:
embedding.append(0.0)
return embedding[: settings.qdrant_vector_size]
async def health_check(self) -> bool: async def health_check(self) -> bool:
""" """