# Source: ai-robot-core/ai-service/app/services/retrieval/optimized_retriever.py
"""
Optimized RAG retriever with two-stage retrieval and RRF hybrid ranking.
Reference: rag-optimization/spec.md Section 2.2, 2.4, 2.5
"""
import asyncio
import logging
import re
from dataclasses import dataclass, field
from typing import Any
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient, get_qdrant_client
from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult
from app.services.retrieval.base import (
BaseRetriever,
RetrievalContext,
RetrievalHit,
RetrievalResult,
)
from app.services.retrieval.metadata import (
ChunkMetadata,
MetadataFilter,
RetrieveResult,
RetrievalStrategy,
)
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class TwoStageResult:
    """Result from two-stage retrieval."""
    # Stage-1 candidate points as returned by the coarse (256-dim) search.
    candidates: list[dict[str, Any]]
    # Final reranked results handed back to the caller.
    final_results: list[RetrieveResult]
    # Wall-clock latency of each stage in milliseconds (0.0 when not measured).
    # NOTE(review): this dataclass is not referenced anywhere in this module's
    # visible code — confirm it is used by callers before removing.
    stage1_latency_ms: float = 0.0
    stage2_latency_ms: float = 0.0
class RRFCombiner:
    """
    Reciprocal Rank Fusion for combining multiple retrieval results.
    Reference: rag-optimization/spec.md Section 2.5

    Each list contributes ``weight / (k + rank + 1)`` per entry (rank is
    0-based, so this matches the classic ``1 / (k + rank)`` with 1-based
    ranks). Default k = 60.
    """

    def __init__(self, k: int = 60):
        self._k = k

    def combine(
        self,
        vector_results: list[dict[str, Any]],
        bm25_results: list[dict[str, Any]],
        vector_weight: float = 0.7,
        bm25_weight: float = 0.3,
    ) -> list[dict[str, Any]]:
        """
        Combine vector and BM25 results using RRF.

        Args:
            vector_results: Results from vector search
            bm25_results: Results from BM25 search
            vector_weight: Weight for vector results
            bm25_weight: Weight for BM25 results

        Returns:
            Combined results, sorted by fused score descending. Each entry
            records both per-source scores and per-source ranks (-1 when the
            chunk was absent from that source).
        """
        fused: dict[str, dict[str, Any]] = {}

        def entry_key(entry: dict[str, Any], position: int) -> str:
            # Prefer an explicit chunk_id, then the point id; fall back to the
            # list position so entries without ids still participate.
            return entry.get("chunk_id") or entry.get("id", str(position))

        # Pass 1: vector hits seed the fused table.
        for position, entry in enumerate(vector_results):
            cid = entry_key(entry, position)
            contribution = vector_weight / (self._k + position + 1)
            record = fused.get(cid)
            if record is None:
                fused[cid] = {
                    "score": contribution,
                    "vector_score": entry.get("score", 0.0),
                    "bm25_score": 0.0,
                    "vector_rank": position,
                    "bm25_rank": -1,
                    "payload": entry.get("payload", {}),
                    "id": cid,
                    "vector": entry.get("vector"),
                }
            else:
                record["vector_score"] = entry.get("score", 0.0)
                record["vector_rank"] = position
                # Keep a stored vector if this duplicate carries one.
                if entry.get("vector"):
                    record["vector"] = entry.get("vector")
                record["score"] += contribution

        # Pass 2: BM25 hits merge into (or extend) the fused table.
        for position, entry in enumerate(bm25_results):
            cid = entry_key(entry, position)
            contribution = bm25_weight / (self._k + position + 1)
            record = fused.get(cid)
            if record is None:
                fused[cid] = {
                    "score": contribution,
                    "vector_score": 0.0,
                    "bm25_score": entry.get("score", 0.0),
                    "vector_rank": -1,
                    "bm25_rank": position,
                    "payload": entry.get("payload", {}),
                    "id": cid,
                    "vector": entry.get("vector"),
                }
            else:
                record["bm25_score"] = entry.get("score", 0.0)
                record["bm25_rank"] = position
                record["score"] += contribution

        return sorted(fused.values(), key=lambda item: item["score"], reverse=True)
class OptimizedRetriever(BaseRetriever):
    """
    Optimized retriever with:
    - Task prefixes (search_document/search_query)
    - Two-stage retrieval (256 dim -> 768 dim)
    - RRF hybrid ranking (vector + BM25)
    - Metadata filtering
    Reference: rag-optimization/spec.md Section 2, 3, 4
    """

    def __init__(
        self,
        qdrant_client: QdrantClient | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
        min_hits: int | None = None,
        two_stage_enabled: bool | None = None,
        two_stage_expand_factor: int | None = None,
        hybrid_enabled: bool | None = None,
        rrf_k: int | None = None,
    ) -> None:
        # Every argument is an optional override; None means "use the value
        # from application settings".
        # NOTE(review): the `x or settings.y` fallbacks also trigger when a
        # caller passes 0 / 0.0 explicitly (falsy) — confirm that is intended.
        self._qdrant_client = qdrant_client
        self._top_k = top_k or settings.rag_top_k
        self._score_threshold = score_threshold or settings.rag_score_threshold
        self._min_hits = min_hits or settings.rag_min_hits
        # Booleans use an explicit `is not None` check so False overrides work.
        self._two_stage_enabled = two_stage_enabled if two_stage_enabled is not None else settings.rag_two_stage_enabled
        self._two_stage_expand_factor = two_stage_expand_factor or settings.rag_two_stage_expand_factor
        self._hybrid_enabled = hybrid_enabled if hybrid_enabled is not None else settings.rag_hybrid_enabled
        self._rrf_k = rrf_k or settings.rag_rrf_k
        self._rrf_combiner = RRFCombiner(k=self._rrf_k)

    async def _get_client(self) -> QdrantClient:
        """Return the injected Qdrant client, lazily fetching the shared one."""
        if self._qdrant_client is None:
            self._qdrant_client = await get_qdrant_client()
        return self._qdrant_client

    async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
        """Return a Nomic embedding provider.

        Reuses the configured provider when it already is a
        NomicEmbeddingProvider; otherwise constructs a fresh one from the
        Ollama settings so Matryoshka/two-stage embeddings are available.
        """
        # Local import — presumably to avoid a circular import with the
        # embedding factory module; confirm before moving to file level.
        from app.services.embedding.factory import get_embedding_config_manager
        manager = get_embedding_config_manager()
        provider = await manager.get_provider()
        if isinstance(provider, NomicEmbeddingProvider):
            return provider
        else:
            return NomicEmbeddingProvider(
                base_url=settings.ollama_base_url,
                model=settings.ollama_embedding_model,
                dimension=settings.qdrant_vector_size,
            )

    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """
        Retrieve documents using optimized strategy.
        Strategy selection:
        1. If two_stage_enabled: use two-stage retrieval
        2. If hybrid_enabled: use RRF hybrid ranking
        3. Otherwise: simple vector search
        (Both flags set selects the combined two-stage + hybrid pipeline.)

        Never raises: any failure is logged and an empty RetrievalResult with
        the error string in diagnostics is returned instead.
        """
        logger.info(
            f"[RAG-OPT] Starting retrieval for tenant={ctx.tenant_id}, "
            f"query={ctx.query[:50]}..., two_stage={self._two_stage_enabled}, hybrid={self._hybrid_enabled}"
        )
        logger.info(
            f"[RAG-OPT] Retrieval config: top_k={self._top_k}, "
            f"score_threshold={self._score_threshold}, min_hits={self._min_hits}"
        )
        try:
            provider = await self._get_embedding_provider()
            logger.info(f"[RAG-OPT] Using embedding provider: {type(provider).__name__}")
            embedding_result = await provider.embed_query(ctx.query)
            logger.info(
                f"[RAG-OPT] Embedding generated: full_dim={len(embedding_result.embedding_full)}, "
                f"dim_256={'available' if embedding_result.embedding_256 else 'not available'}"
            )
            # Dispatch to the configured strategy; each branch returns raw
            # result dicts with "score" and "payload" keys.
            if self._two_stage_enabled and self._hybrid_enabled:
                logger.info("[RAG-OPT] Using two-stage + hybrid retrieval strategy")
                results = await self._two_stage_hybrid_retrieve(
                    ctx.tenant_id,
                    embedding_result,
                    ctx.query,
                    self._top_k,
                )
            elif self._two_stage_enabled:
                logger.info("[RAG-OPT] Using two-stage retrieval strategy")
                results = await self._two_stage_retrieve(
                    ctx.tenant_id,
                    embedding_result,
                    self._top_k,
                )
            elif self._hybrid_enabled:
                logger.info("[RAG-OPT] Using hybrid retrieval strategy")
                results = await self._hybrid_retrieve(
                    ctx.tenant_id,
                    embedding_result,
                    ctx.query,
                    self._top_k,
                )
            else:
                logger.info("[RAG-OPT] Using simple vector retrieval strategy")
                results = await self._vector_retrieve(
                    ctx.tenant_id,
                    embedding_result.embedding_full,
                    self._top_k,
                )
            logger.info(f"[RAG-OPT] Raw results count: {len(results)}")
            # Convert raw dicts to RetrievalHit, dropping anything below the
            # score threshold.
            retrieval_hits = [
                RetrievalHit(
                    text=result.get("payload", {}).get("text", ""),
                    score=result.get("score", 0.0),
                    source="optimized_rag",
                    metadata=result.get("payload", {}),
                )
                for result in results
                if result.get("score", 0.0) >= self._score_threshold
            ]
            filtered_count = len(results) - len(retrieval_hits)
            if filtered_count > 0:
                logger.info(
                    f"[RAG-OPT] Filtered out {filtered_count} results below threshold {self._score_threshold}"
                )
            # Fewer than min_hits surviving hits flags the answer as weak for
            # downstream consumers.
            is_insufficient = len(retrieval_hits) < self._min_hits
            diagnostics = {
                "query_length": len(ctx.query),
                "top_k": self._top_k,
                "score_threshold": self._score_threshold,
                "two_stage_enabled": self._two_stage_enabled,
                "hybrid_enabled": self._hybrid_enabled,
                "total_hits": len(retrieval_hits),
                "is_insufficient": is_insufficient,
                "max_score": max((h.score for h in retrieval_hits), default=0.0),
                "raw_results_count": len(results),
                "filtered_below_threshold": filtered_count,
            }
            logger.info(
                f"[RAG-OPT] Retrieval complete: {len(retrieval_hits)} hits, "
                f"insufficient={is_insufficient}, max_score={diagnostics['max_score']:.3f}"
            )
            if len(retrieval_hits) == 0:
                logger.warning(
                    f"[RAG-OPT] No hits found! tenant={ctx.tenant_id}, query={ctx.query[:50]}..., "
                    f"raw_results={len(results)}, threshold={self._score_threshold}"
                )
            return RetrievalResult(
                hits=retrieval_hits,
                diagnostics=diagnostics,
            )
        except Exception as e:
            logger.error(f"[RAG-OPT] Retrieval error: {e}", exc_info=True)
            return RetrievalResult(
                hits=[],
                diagnostics={"error": str(e), "is_insufficient": True},
            )

    async def _two_stage_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Two-stage retrieval using Matryoshka dimensions.
        Stage 1: Fast retrieval with 256-dim vectors
        Stage 2: Precise reranking with 768-dim vectors
        Reference: rag-optimization/spec.md Section 2.4
        """
        import time
        client = await self._get_client()
        stage1_start = time.perf_counter()
        # Stage 1: over-fetch (top_k * expand_factor) candidates using the
        # cheap 256-dim index; with_vectors=True so stage 2 can rerank locally
        # without another Qdrant round trip.
        candidates = await self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_256, "dim_256",
            top_k * self._two_stage_expand_factor,
            with_vectors=True,
        )
        stage1_latency = (time.perf_counter() - stage1_start) * 1000
        logger.info(
            f"[RAG-OPT] Stage 1: {len(candidates)} candidates in {stage1_latency:.2f}ms"
        )
        stage2_start = time.perf_counter()
        reranked = []
        for candidate in candidates:
            # Stored vectors come back either as a named-vector dict or a bare
            # list, depending on collection configuration.
            vector_data = candidate.get("vector", {})
            stored_full_embedding = None
            if isinstance(vector_data, dict):
                stored_full_embedding = vector_data.get("full", [])
            elif isinstance(vector_data, list):
                stored_full_embedding = vector_data
            # Candidates without a stored full vector are silently dropped
            # from reranking.
            if stored_full_embedding and len(stored_full_embedding) > 0:
                similarity = self._cosine_similarity(
                    embedding_result.embedding_full,
                    stored_full_embedding
                )
                candidate["score"] = similarity
                candidate["stage"] = "reranked"
                reranked.append(candidate)
        reranked.sort(key=lambda x: x.get("score", 0), reverse=True)
        results = reranked[:top_k]
        stage2_latency = (time.perf_counter() - stage2_start) * 1000
        logger.info(
            f"[RAG-OPT] Stage 2: {len(results)} final results in {stage2_latency:.2f}ms"
        )
        return results

    async def _hybrid_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        query: str,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Hybrid retrieval using RRF to combine vector and BM25 results.
        Reference: rag-optimization/spec.md Section 2.5

        Both searches run concurrently; a failure in either branch degrades to
        the other branch's results rather than failing the whole retrieval.
        """
        client = await self._get_client()
        vector_task = self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_full, "full",
            top_k * 2
        )
        bm25_task = self._bm25_search(client, tenant_id, query, top_k * 2)
        vector_results, bm25_results = await asyncio.gather(
            vector_task, bm25_task, return_exceptions=True
        )
        if isinstance(vector_results, Exception):
            logger.warning(f"[RAG-OPT] Vector search failed: {vector_results}")
            vector_results = []
        if isinstance(bm25_results, Exception):
            logger.warning(f"[RAG-OPT] BM25 search failed: {bm25_results}")
            bm25_results = []
        combined = self._rrf_combiner.combine(
            vector_results,
            bm25_results,
            vector_weight=settings.rag_vector_weight,
            bm25_weight=settings.rag_bm25_weight,
        )
        return combined[:top_k]

    async def _two_stage_hybrid_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        query: str,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Two-stage + Hybrid retrieval strategy.
        Stage 1: Fast retrieval with 256-dim vectors + BM25 in parallel
        Stage 2: RRF fusion + Precise reranking with 768-dim vectors
        This combines the best of both worlds:
        - Two-stage: Speed from 256-dim, precision from 768-dim reranking
        - Hybrid: Semantic matching from vectors, keyword matching from BM25
        """
        import time
        client = await self._get_client()
        stage1_start = time.perf_counter()
        # Stage 1: run coarse vector search and BM25 concurrently; either
        # branch failing degrades to the other rather than raising.
        vector_task = self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_256, "dim_256",
            top_k * self._two_stage_expand_factor,
            with_vectors=True,
        )
        bm25_task = self._bm25_search(client, tenant_id, query, top_k * self._two_stage_expand_factor)
        vector_results, bm25_results = await asyncio.gather(
            vector_task, bm25_task, return_exceptions=True
        )
        if isinstance(vector_results, Exception):
            logger.warning(f"[RAG-OPT] Vector search failed: {vector_results}")
            vector_results = []
        if isinstance(bm25_results, Exception):
            logger.warning(f"[RAG-OPT] BM25 search failed: {bm25_results}")
            bm25_results = []
        stage1_latency = (time.perf_counter() - stage1_start) * 1000
        logger.info(
            f"[RAG-OPT] Two-stage Hybrid Stage 1: vector={len(vector_results)}, bm25={len(bm25_results)}, latency={stage1_latency:.2f}ms"
        )
        stage2_start = time.perf_counter()
        combined = self._rrf_combiner.combine(
            vector_results,
            bm25_results,
            vector_weight=settings.rag_vector_weight,
            bm25_weight=settings.rag_bm25_weight,
        )
        reranked = []
        # Rerank only the top 2*top_k fused candidates with the full vectors.
        # NOTE(review): BM25-only candidates carry no stored vector and are
        # therefore dropped here — confirm that losing pure-keyword hits in
        # the final ranking is intended.
        for candidate in combined[:top_k * 2]:
            vector_data = candidate.get("vector", {})
            stored_full_embedding = None
            if isinstance(vector_data, dict):
                stored_full_embedding = vector_data.get("full", [])
            elif isinstance(vector_data, list):
                stored_full_embedding = vector_data
            if stored_full_embedding and len(stored_full_embedding) > 0:
                similarity = self._cosine_similarity(
                    embedding_result.embedding_full,
                    stored_full_embedding
                )
                candidate["score"] = similarity
                candidate["stage"] = "two_stage_hybrid_reranked"
                reranked.append(candidate)
        reranked.sort(key=lambda x: x.get("score", 0), reverse=True)
        results = reranked[:top_k]
        stage2_latency = (time.perf_counter() - stage2_start) * 1000
        logger.info(
            f"[RAG-OPT] Two-stage Hybrid Stage 2 (reranking): {len(results)} final results in {stage2_latency:.2f}ms"
        )
        return results

    async def _vector_retrieve(
        self,
        tenant_id: str,
        embedding: list[float],
        top_k: int,
    ) -> list[dict[str, Any]]:
        """Simple vector retrieval against the full-dimension index."""
        client = await self._get_client()
        return await self._search_with_dimension(
            client, tenant_id, embedding, "full", top_k
        )

    async def _search_with_dimension(
        self,
        client: QdrantClient,
        tenant_id: str,
        query_vector: list[float],
        vector_name: str,
        limit: int,
        with_vectors: bool = False,
    ) -> list[dict[str, Any]]:
        """Search using specified vector dimension.

        Returns the client's result dicts (with "id" and "score" keys), or an
        empty list on any error — failures are logged, never raised.
        """
        try:
            logger.info(
                f"[RAG-OPT] Searching with vector_name={vector_name}, "
                f"limit={limit}, vector_dim={len(query_vector)}, with_vectors={with_vectors}"
            )
            results = await client.search(
                tenant_id=tenant_id,
                query_vector=query_vector,
                limit=limit,
                vector_name=vector_name,
                with_vectors=with_vectors,
            )
            logger.info(
                f"[RAG-OPT] Search returned {len(results)} results"
            )
            if len(results) > 0:
                # Trace only the first few hits to keep debug logs bounded.
                for i, r in enumerate(results[:3]):
                    logger.debug(
                        f"[RAG-OPT] Result {i+1}: id={r['id']}, score={r['score']:.4f}"
                    )
            return results
        except Exception as e:
            logger.error(
                f"[RAG-OPT] Search with {vector_name} failed: {e}",
                exc_info=True
            )
            return []

    async def _bm25_search(
        self,
        client: QdrantClient,
        tenant_id: str,
        query: str,
        limit: int,
    ) -> list[dict[str, Any]]:
        """
        BM25-like search using Qdrant's sparse vectors or fallback to text matching.
        This is a simplified implementation; for production, use Elasticsearch.

        Scores are actually word-set Jaccard similarity between query and
        chunk text, not true BM25 (no term frequency or IDF weighting).
        NOTE(review): only the first scroll page (limit * 3 points) is
        scanned, so large collections are sampled, not fully searched —
        confirm this is acceptable.
        """
        try:
            qdrant = await client.get_client()
            collection_name = client.get_collection_name(tenant_id)
            query_terms = set(re.findall(r'\w+', query.lower()))
            results = await qdrant.scroll(
                collection_name=collection_name,
                limit=limit * 3,
                with_payload=True,
            )
            scored_results = []
            # scroll() returns (points, next_page_offset); only points used.
            for point in results[0]:
                text = point.payload.get("text", "").lower()
                text_terms = set(re.findall(r'\w+', text))
                overlap = len(query_terms & text_terms)
                if overlap > 0:
                    # Jaccard: |A ∩ B| / |A ∪ B|.
                    score = overlap / (len(query_terms) + len(text_terms) - overlap)
                    scored_results.append({
                        "id": str(point.id),
                        "score": score,
                        "payload": point.payload or {},
                    })
            scored_results.sort(key=lambda x: x["score"], reverse=True)
            return scored_results[:limit]
        except Exception as e:
            logger.debug(f"[RAG-OPT] BM25 search failed: {e}")
            return []

    def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
        """Calculate cosine similarity between two vectors.

        NOTE(review): no zero-norm guard — an all-zero vector yields a NaN
        from the 0/0 division; upstream embeddings presumably never are.
        """
        import numpy as np
        a = np.array(vec1)
        b = np.array(vec2)
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    async def health_check(self) -> bool:
        """Check if retriever is healthy (Qdrant reachable and responding)."""
        try:
            client = await self._get_client()
            qdrant = await client.get_client()
            await qdrant.get_collections()
            return True
        except Exception as e:
            logger.error(f"[RAG-OPT] Health check failed: {e}")
            return False
# Process-wide singleton instance, created on first use.
_optimized_retriever: OptimizedRetriever | None = None


async def get_optimized_retriever() -> OptimizedRetriever:
    """Return the shared OptimizedRetriever, constructing it on first call."""
    global _optimized_retriever
    retriever = _optimized_retriever
    if retriever is None:
        retriever = OptimizedRetriever()
        _optimized_retriever = retriever
    return retriever