From cee884d9a02af0d5d2a0cf5bf6847c7cd8989c56 Mon Sep 17 00:00:00 2001
From: MerCry
Date: Wed, 25 Feb 2026 23:10:12 +0800
Subject: [PATCH] feat: RAG retrieval optimization - implement
 multi-dimensional vector storage and a Nomic embedding provider
 [AC-AISVC-16, AC-AISVC-29]
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ai-service/app/api/admin/kb_optimized.py      | 330 ++++++++++++
 ai-service/app/api/admin/rag.py               |   7 +-
 ai-service/app/api/chat.py                    |  29 +-
 ai-service/app/core/qdrant_client.py          | 205 +++++--
 ai-service/app/services/embedding/__init__.py |   8 +
 ai-service/app/services/embedding/factory.py  |   4 +
 .../app/services/embedding/nomic_provider.py  | 291 ++++++++++
 ai-service/app/services/orchestrator.py       |  86 ++-
 ai-service/app/services/retrieval/__init__.py |  36 ++
 ai-service/app/services/retrieval/indexer.py  | 339 ++++++++++++
 ai-service/app/services/retrieval/metadata.py | 210 ++++++++
 .../services/retrieval/optimized_retriever.py | 509 ++++++++++++++++++
 12 files changed, 2007 insertions(+), 47 deletions(-)
 create mode 100644 ai-service/app/api/admin/kb_optimized.py
 create mode 100644 ai-service/app/services/embedding/nomic_provider.py
 create mode 100644 ai-service/app/services/retrieval/indexer.py
 create mode 100644 ai-service/app/services/retrieval/metadata.py
 create mode 100644 ai-service/app/services/retrieval/optimized_retriever.py

diff --git a/ai-service/app/api/admin/kb_optimized.py b/ai-service/app/api/admin/kb_optimized.py
new file mode 100644
index 0000000..9bdfe2f
--- /dev/null
+++ b/ai-service/app/api/admin/kb_optimized.py
@@ -0,0 +1,330 @@
+"""
+Knowledge base management API with RAG optimization features.
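+Exposes endpoints for document indexing, indexing progress, optimized
+retrieval, metadata filter options, and tenant-wide reindexing.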
+Reference: rag-optimization/spec.md Section 4.2 +""" + +import logging +from datetime import date +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_session +from app.services.retrieval import ( + ChunkMetadata, + ChunkMetadataModel, + IndexingProgress, + IndexingResult, + KnowledgeIndexer, + MetadataFilter, + RetrievalStrategy, + get_knowledge_indexer, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/kb", tags=["Knowledge Base"]) + + +class IndexDocumentRequest(BaseModel): + """Request to index a document.""" + tenant_id: str = Field(..., description="Tenant ID") + document_id: str = Field(..., description="Document ID") + text: str = Field(..., description="Document text content") + metadata: ChunkMetadataModel | None = Field(default=None, description="Document metadata") + + +class IndexDocumentResponse(BaseModel): + """Response from document indexing.""" + success: bool + total_chunks: int + indexed_chunks: int + failed_chunks: int + elapsed_seconds: float + error_message: str | None = None + + +class IndexingProgressResponse(BaseModel): + """Response with current indexing progress.""" + total_chunks: int + processed_chunks: int + failed_chunks: int + progress_percent: int + elapsed_seconds: float + current_document: str + + +class MetadataFilterRequest(BaseModel): + """Request for metadata filtering.""" + categories: list[str] | None = None + target_audiences: list[str] | None = None + departments: list[str] | None = None + valid_only: bool = True + min_priority: int | None = None + keywords: list[str] | None = None + + +class RetrieveRequest(BaseModel): + """Request for knowledge retrieval.""" + tenant_id: str = Field(..., description="Tenant ID") + query: str = Field(..., description="Search query") + top_k: int = Field(default=10, ge=1, le=50, description="Number of results") + filters: MetadataFilterRequest | None = Field(default=None, description="Metadata filters") + strategy: RetrievalStrategy = Field(default=RetrievalStrategy.HYBRID, description="Retrieval strategy") + + +class RetrieveResponse(BaseModel): + """Response from knowledge retrieval.""" + hits: list[dict[str, Any]] + total_hits: int + max_score: float + is_insufficient: bool + diagnostics: dict[str, Any] + + +class MetadataOptionsResponse(BaseModel): + """Response with available metadata options.""" + categories: list[str] + departments: list[str] + target_audiences: list[str] + priorities: list[int] + + +@router.post("/index", response_model=IndexDocumentResponse) +async def index_document( + request: IndexDocumentRequest, + session: AsyncSession = Depends(get_session), +): + """ + Index a document with optimized embedding. 
+ + Features: + - Task prefixes (search_document:) for document embedding + - Multi-dimensional vectors (256/512/768) + - Metadata support + """ + try: + index = get_knowledge_indexer() + + chunk_metadata = None + if request.metadata: + chunk_metadata = ChunkMetadata( + category=request.metadata.category, + subcategory=request.metadata.subcategory, + target_audience=request.metadata.target_audience, + source_doc=request.metadata.source_doc, + source_url=request.metadata.source_url, + department=request.metadata.department, + priority=request.metadata.priority, + keywords=request.metadata.keywords, + ) + + result = await index.index_document( + tenant_id=request.tenant_id, + document_id=request.document_id, + text=request.text, + metadata=chunk_metadata, + ) + + return IndexDocumentResponse( + success=result.success, + total_chunks=result.total_chunks, + indexed_chunks=result.indexed_chunks, + failed_chunks=result.failed_chunks, + elapsed_seconds=result.elapsed_seconds, + error_message=result.error_message, + ) + + except Exception as e: + logger.error(f"[KB-API] Failed to index document: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"索引失败: {str(e)}" + ) + + +@router.get("/index/progress", response_model=IndexingProgressResponse | None) +async def get_indexing_progress(): + """Get current indexing progress.""" + try: + index = get_knowledge_indexer() + progress = index.get_progress() + + if progress is None: + return None + + return IndexingProgressResponse( + total_chunks=progress.total_chunks, + processed_chunks=progress.processed_chunks, + failed_chunks=progress.failed_chunks, + progress_percent=progress.progress_percent, + elapsed_seconds=progress.elapsed_seconds, + current_document=progress.current_document, + ) + + except Exception as e: + logger.error(f"[KB-API] Failed to get progress: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取进度失败: {str(e)}" + ) + + +@router.post("/retrieve", response_model=RetrieveResponse) +async def retrieve_knowledge(request: RetrieveRequest): + """ + Retrieve knowledge using optimized RAG. 
+
+    Strategies:
+    - vector: Simple vector search
+    - bm25: BM25 keyword search
+    - hybrid: RRF combination of vector + BM25 (default)
+    - two_stage: Two-stage retrieval with Matryoshka dimensions
+    """
+    try:
+        from app.services.retrieval.optimized_retriever import get_optimized_retriever
+        from app.services.retrieval.base import RetrievalContext
+
+        retriever = await get_optimized_retriever()
+
+        metadata_filter = None
+        if request.filters:
+            filter_dict = request.filters.model_dump(exclude_none=True)
+            metadata_filter = MetadataFilter(**filter_dict)
+
+        ctx = RetrievalContext(
+            tenant_id=request.tenant_id,
+            query=request.query,
+        )
+
+        if metadata_filter:
+            ctx.metadata = {"filter": metadata_filter.to_qdrant_filter()}
+
+        result = await retriever.retrieve(ctx)
+
+        return RetrieveResponse(
+            hits=[
+                {
+                    "text": hit.text,
+                    "score": hit.score,
+                    "source": hit.source,
+                    "metadata": hit.metadata,
+                }
+                for hit in result.hits
+            ],
+            total_hits=result.hit_count,
+            max_score=result.max_score,
+            is_insufficient=(result.diagnostics or {}).get("is_insufficient", False),
+            diagnostics=result.diagnostics or {},
+        )
+
+    except Exception as e:
+        logger.error(f"[KB-API] Failed to retrieve: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"检索失败: {str(e)}"
+        )
+
+
+@router.get("/metadata/options", response_model=MetadataOptionsResponse)
+async def get_metadata_options():
+    """
+    Get available metadata options for filtering.
+    These would typically be loaded from a database.
+    """
+    try:
+        return MetadataOptionsResponse(
+            categories=[
+                "课程咨询",
+                "考试政策",
+                "学籍管理",
+                "奖助学金",
+                "宿舍管理",
+                "校园服务",
+                "就业指导",
+                "其他",
+            ],
+            departments=[
+                "教务处",
+                "学生处",
+                "财务处",
+                "后勤处",
+                "就业指导中心",
+                "图书馆",
+                "信息中心",
+            ],
+            target_audiences=[
+                "本科生",
+                "研究生",
+                "留学生",
+                "新生",
+                "毕业生",
+                "教职工",
+            ],
+            priorities=list(range(1, 11)),
+        )
+
+    except Exception as e:
+        logger.error(f"[KB-API] Failed to get metadata options: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"获取选项失败: {str(e)}"
+        )
+
+
+@router.post("/reindex")
+async def reindex_all(
+    tenant_id: str,
+    session: AsyncSession = Depends(get_session),
+):
+    """
+    Reindex all documents for a tenant with optimized embedding.
+    This would typically read from the documents table and reindex.
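+    Example (illustrative; assumes completed documents with readable file_path values):
+        POST /api/kb/reindex?tenant_id=school_a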
+    """
+    try:
+        import os
+
+        from app.models.entities import Document, DocumentStatus
+
+        stmt = select(Document).where(
+            Document.tenant_id == tenant_id,
+            Document.status == DocumentStatus.COMPLETED.value,
+        )
+        result = await session.execute(stmt)
+        documents = result.scalars().all()
+
+        index = get_knowledge_indexer()
+
+        total_indexed = 0
+        total_failed = 0
+
+        for doc in documents:
+            if doc.file_path and os.path.exists(doc.file_path):
+                with open(doc.file_path, 'r', encoding='utf-8') as f:
+                    text = f.read()
+
+                index_result = await index.index_document(
+                    tenant_id=tenant_id,
+                    document_id=str(doc.id),
+                    text=text,
+                )
+
+                total_indexed += index_result.indexed_chunks
+                total_failed += index_result.failed_chunks
+
+        return {
+            "success": True,
+            "total_documents": len(documents),
+            "total_indexed": total_indexed,
+            "total_failed": total_failed,
+        }
+
+    except Exception as e:
+        logger.error(f"[KB-API] Failed to reindex: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"重新索引失败: {str(e)}"
+        )
diff --git a/ai-service/app/api/admin/rag.py b/ai-service/app/api/admin/rag.py
index 14f0652..63cd782 100644
--- a/ai-service/app/api/admin/rag.py
+++ b/ai-service/app/api/admin/rag.py
@@ -17,6 +17,7 @@ from app.core.exceptions import MissingTenantIdException
 from app.core.tenant import get_tenant_id
 from app.models import ErrorResponse
 from app.services.retrieval.vector_retriever import get_vector_retriever
+from app.services.retrieval.optimized_retriever import get_optimized_retriever
 from app.services.retrieval.base import RetrievalContext
 from app.services.llm.factory import get_llm_config_manager
 
@@ -91,7 +92,8 @@ async def run_rag_experiment(
     threshold = request.score_threshold or settings.rag_score_threshold
 
     try:
-        retriever = await get_vector_retriever()
+        # Use optimized retriever with RAG enhancements
+        retriever = await get_optimized_retriever()
 
         retrieval_ctx = RetrievalContext(
             tenant_id=tenant_id,
@@ -199,7 +201,8 @@ async def run_rag_experiment_stream(
 
     async def event_generator():
         try:
-            retriever = await get_vector_retriever()
+            # Use optimized retriever with RAG enhancements
+            retriever = await get_optimized_retriever()
 
             retrieval_ctx = RetrievalContext(
                 tenant_id=tenant_id,
diff --git a/ai-service/app/api/chat.py b/ai-service/app/api/chat.py
index 4eac671..f0a7828 100644
--- a/ai-service/app/api/chat.py
+++ b/ai-service/app/api/chat.py
@@ -9,18 +9,43 @@ from typing import Annotated, Any
 from fastapi import APIRouter, Depends, Header, Request
 from fastapi.responses import JSONResponse
 from sse_starlette.sse import EventSourceResponse
+from sqlalchemy.ext.asyncio import AsyncSession
 
+from app.core.database import get_session
 from app.core.middleware import get_response_mode, is_sse_request
 from app.core.sse import SSEStateMachine, create_error_event
 from app.core.tenant import get_tenant_id
 from app.models import ChatRequest, ChatResponse, ErrorResponse
-from app.services.orchestrator import OrchestratorService, get_orchestrator_service
+from app.services.memory import MemoryService
+from app.services.orchestrator import OrchestratorService
 
 logger = logging.getLogger(__name__)
 
 router = APIRouter(tags=["AI Chat"])
 
 
+async def get_orchestrator_service_with_memory(
+    session: Annotated[AsyncSession, Depends(get_session)]
+) -> OrchestratorService:
+    """
+    [AC-AISVC-13] Create orchestrator service with memory service and LLM client.
+    Ensures each request has a fresh MemoryService with a database session.
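+    Wiring sketch (this is how the chat route below consumes it):
+        orchestrator: OrchestratorService = Depends(get_orchestrator_service_with_memory)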
+ """ + from app.services.llm.factory import get_llm_config_manager + from app.services.retrieval.vector_retriever import get_vector_retriever + + memory_service = MemoryService(session) + llm_config_manager = get_llm_config_manager() + llm_client = llm_config_manager.get_client() + retriever = await get_vector_retriever() + + return OrchestratorService( + llm_client=llm_client, + memory_service=memory_service, + retriever=retriever, + ) + + @router.post( "/ai/chat", operation_id="generateReply", @@ -49,7 +74,7 @@ async def generate_reply( request: Request, chat_request: ChatRequest, accept: Annotated[str | None, Header()] = None, - orchestrator: OrchestratorService = Depends(get_orchestrator_service), + orchestrator: OrchestratorService = Depends(get_orchestrator_service_with_memory), ) -> Any: """ [AC-AISVC-06] Generate AI reply with automatic response mode switching. diff --git a/ai-service/app/core/qdrant_client.py b/ai-service/app/core/qdrant_client.py index 5de824b..5742b5a 100644 --- a/ai-service/app/core/qdrant_client.py +++ b/ai-service/app/core/qdrant_client.py @@ -1,13 +1,14 @@ """ Qdrant client for AI Service. [AC-AISVC-10] Vector database client with tenant-isolated collection management. +Supports multi-dimensional vectors for Matryoshka representation learning. """ import logging from typing import Any from qdrant_client import AsyncQdrantClient -from qdrant_client.models import Distance, PointStruct, VectorParams +from qdrant_client.models import Distance, PointStruct, VectorParams, MultiVectorConfig from app.core.config import get_settings @@ -20,6 +21,7 @@ class QdrantClient: """ [AC-AISVC-10] Qdrant client with tenant-isolated collection management. Collection naming: kb_{tenantId} for tenant isolation. + Supports multi-dimensional vectors (256/512/768) for Matryoshka retrieval. """ def __init__(self): @@ -45,13 +47,15 @@ class QdrantClient: """ [AC-AISVC-10] Get collection name for a tenant. Naming convention: kb_{tenantId} + Replaces @ with _ to ensure valid collection names. """ - return f"{self._collection_prefix}{tenant_id}" + safe_tenant_id = tenant_id.replace('@', '_') + return f"{self._collection_prefix}{safe_tenant_id}" - async def ensure_collection_exists(self, tenant_id: str) -> bool: + async def ensure_collection_exists(self, tenant_id: str, use_multi_vector: bool = True) -> bool: """ [AC-AISVC-10] Ensure collection exists for tenant. - Note: MVP uses pre-provisioned collections, this is for development/testing. + Supports multi-dimensional vectors for Matryoshka retrieval. 
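+        Named-vector layout created for new collections (mirrors the config below):
+            full    -> 768-dim, cosine distance
+            dim_256 -> 256-dim, cosine distance
+            dim_512 -> 512-dim, cosine distance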
""" client = await self.get_client() collection_name = self.get_collection_name(tenant_id) @@ -61,15 +65,34 @@ class QdrantClient: exists = any(c.name == collection_name for c in collections.collections) if not exists: - await client.create_collection( - collection_name=collection_name, - vectors_config=VectorParams( + if use_multi_vector: + vectors_config = { + "full": VectorParams( + size=768, + distance=Distance.COSINE, + ), + "dim_256": VectorParams( + size=256, + distance=Distance.COSINE, + ), + "dim_512": VectorParams( + size=512, + distance=Distance.COSINE, + ), + } + else: + vectors_config = VectorParams( size=self._vector_size, distance=Distance.COSINE, - ), + ) + + await client.create_collection( + collection_name=collection_name, + vectors_config=vectors_config, ) logger.info( - f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id}" + f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id} " + f"with multi_vector={use_multi_vector}" ) return True except Exception as e: @@ -100,44 +123,160 @@ class QdrantClient: logger.error(f"[AC-AISVC-10] Error upserting vectors: {e}") return False + async def upsert_multi_vector( + self, + tenant_id: str, + points: list[dict[str, Any]], + ) -> bool: + """ + Upsert points with multi-dimensional vectors. + + Args: + tenant_id: Tenant identifier + points: List of points with format: + { + "id": str | int, + "vector": { + "full": [768 floats], + "dim_256": [256 floats], + "dim_512": [512 floats], + }, + "payload": dict + } + """ + client = await self.get_client() + collection_name = self.get_collection_name(tenant_id) + + try: + qdrant_points = [] + for p in points: + point = PointStruct( + id=p["id"], + vector=p["vector"], + payload=p.get("payload", {}), + ) + qdrant_points.append(point) + + await client.upsert( + collection_name=collection_name, + points=qdrant_points, + ) + logger.info( + f"[RAG-OPT] Upserted {len(points)} multi-vector points for tenant={tenant_id}" + ) + return True + except Exception as e: + logger.error(f"[RAG-OPT] Error upserting multi-vectors: {e}") + return False + async def search( self, tenant_id: str, query_vector: list[float], limit: int = 5, score_threshold: float | None = None, + vector_name: str = "full", ) -> list[dict[str, Any]]: """ [AC-AISVC-10] Search vectors in tenant's collection. Returns results with score >= score_threshold if specified. + Searches both old format (with @) and new format (with _) for backward compatibility. + + Args: + tenant_id: Tenant identifier + query_vector: Query vector for similarity search + limit: Maximum number of results + score_threshold: Minimum score threshold for results + vector_name: Name of the vector to search (for multi-vector collections) + Default is "full" for 768-dim vectors in Matryoshka setup. 
""" client = await self.get_client() - collection_name = self.get_collection_name(tenant_id) + + logger.info( + f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, " + f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}" + ) + + collection_names = [self.get_collection_name(tenant_id)] + if '@' in tenant_id: + old_format = f"{self._collection_prefix}{tenant_id}" + new_format = f"{self._collection_prefix}{tenant_id.replace('@', '_')}" + collection_names = [new_format, old_format] + + logger.info(f"[AC-AISVC-10] Will search in collections: {collection_names}") + + all_hits = [] + + for collection_name in collection_names: + try: + logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}") + + try: + results = await client.search( + collection_name=collection_name, + query_vector=(vector_name, query_vector), + limit=limit, + ) + except Exception as e: + if "vector name" in str(e).lower() or "Not existing vector" in str(e): + logger.info( + f"[AC-AISVC-10] Collection {collection_name} doesn't have vector named '{vector_name}', " + f"trying without vector name (single-vector mode)" + ) + results = await client.search( + collection_name=collection_name, + query_vector=query_vector, + limit=limit, + ) + else: + raise + + logger.info( + f"[AC-AISVC-10] Collection {collection_name} returned {len(results)} raw results" + ) - try: - results = await client.search( - collection_name=collection_name, - query_vector=query_vector, - limit=limit, + hits = [ + { + "id": str(result.id), + "score": result.score, + "payload": result.payload or {}, + } + for result in results + if score_threshold is None or result.score >= score_threshold + ] + all_hits.extend(hits) + + if hits: + logger.info( + f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}" + ) + for i, h in enumerate(hits[:3]): + logger.debug( + f"[AC-AISVC-10] Hit {i+1}: id={h['id']}, score={h['score']:.4f}" + ) + else: + logger.warning( + f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)" + ) + except Exception as e: + logger.warning( + f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}" + ) + continue + + all_hits = sorted(all_hits, key=lambda x: x["score"], reverse=True)[:limit] + + logger.info( + f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}" + ) + + if len(all_hits) == 0: + logger.warning( + f"[AC-AISVC-10] No results found! 
tenant={tenant_id}, " + f"collections_tried={collection_names}, limit={limit}" ) - - hits = [ - { - "id": str(result.id), - "score": result.score, - "payload": result.payload or {}, - } - for result in results - if score_threshold is None or result.score >= score_threshold - ] - - logger.info( - f"[AC-AISVC-10] Search returned {len(hits)} results for tenant={tenant_id}" - ) - return hits - except Exception as e: - logger.error(f"[AC-AISVC-10] Error searching vectors: {e}") - return [] + + return all_hits async def delete_collection(self, tenant_id: str) -> bool: """ diff --git a/ai-service/app/services/embedding/__init__.py b/ai-service/app/services/embedding/__init__.py index 59aead2..8fe5844 100644 --- a/ai-service/app/services/embedding/__init__.py +++ b/ai-service/app/services/embedding/__init__.py @@ -17,6 +17,11 @@ from app.services.embedding.factory import ( ) from app.services.embedding.ollama_provider import OllamaEmbeddingProvider from app.services.embedding.openai_provider import OpenAIEmbeddingProvider +from app.services.embedding.nomic_provider import ( + NomicEmbeddingProvider, + NomicEmbeddingResult, + EmbeddingTask, +) __all__ = [ "EmbeddingConfig", @@ -29,4 +34,7 @@ __all__ = [ "get_embedding_provider", "OllamaEmbeddingProvider", "OpenAIEmbeddingProvider", + "NomicEmbeddingProvider", + "NomicEmbeddingResult", + "EmbeddingTask", ] diff --git a/ai-service/app/services/embedding/factory.py b/ai-service/app/services/embedding/factory.py index 9e61c42..e42e506 100644 --- a/ai-service/app/services/embedding/factory.py +++ b/ai-service/app/services/embedding/factory.py @@ -13,6 +13,7 @@ from typing import Any, Type from app.services.embedding.base import EmbeddingException, EmbeddingProvider from app.services.embedding.ollama_provider import OllamaEmbeddingProvider from app.services.embedding.openai_provider import OpenAIEmbeddingProvider +from app.services.embedding.nomic_provider import NomicEmbeddingProvider logger = logging.getLogger(__name__) @@ -26,6 +27,7 @@ class EmbeddingProviderFactory: _providers: dict[str, Type[EmbeddingProvider]] = { "ollama": OllamaEmbeddingProvider, "openai": OpenAIEmbeddingProvider, + "nomic": NomicEmbeddingProvider, } @classmethod @@ -63,11 +65,13 @@ class EmbeddingProviderFactory: display_names = { "ollama": "Ollama 本地模型", "openai": "OpenAI Embedding", + "nomic": "Nomic Embed (优化版)", } descriptions = { "ollama": "使用 Ollama 运行的本地嵌入模型,支持 nomic-embed-text 等开源模型", "openai": "使用 OpenAI 官方 Embedding API,支持 text-embedding-3 系列模型", + "nomic": "Nomic-embed-text v1.5 优化版,支持任务前缀和 Matryoshka 维度截断,专为RAG优化", } return { diff --git a/ai-service/app/services/embedding/nomic_provider.py b/ai-service/app/services/embedding/nomic_provider.py new file mode 100644 index 0000000..ba6a73b --- /dev/null +++ b/ai-service/app/services/embedding/nomic_provider.py @@ -0,0 +1,291 @@ +""" +Nomic embedding provider with task prefixes and Matryoshka support. 
+Implements RAG optimization spec: +- Task prefixes: search_document: / search_query: +- Matryoshka dimension truncation: 256/512/768 dimensions +""" + +import logging +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +import httpx +import numpy as np + +from app.services.embedding.base import ( + EmbeddingConfig, + EmbeddingException, + EmbeddingProvider, +) + +logger = logging.getLogger(__name__) + + +class EmbeddingTask(str, Enum): + """Task type for nomic-embed-text v1.5 model.""" + DOCUMENT = "search_document" + QUERY = "search_query" + + +@dataclass +class NomicEmbeddingResult: + """Result from Nomic embedding with multiple dimensions.""" + embedding_full: list[float] + embedding_256: list[float] + embedding_512: list[float] + dimension: int + model: str + task: EmbeddingTask + latency_ms: float = 0.0 + metadata: dict[str, Any] = field(default_factory=dict) + + +class NomicEmbeddingProvider(EmbeddingProvider): + """ + Nomic-embed-text v1.5 embedding provider with task prefixes. + + Key features: + - Task prefixes: search_document: for documents, search_query: for queries + - Matryoshka dimension truncation: 256/512/768 dimensions + - Automatic normalization after truncation + + Reference: rag-optimization/spec.md Section 2.1, 2.3 + """ + + PROVIDER_NAME = "nomic" + DOCUMENT_PREFIX = "search_document:" + QUERY_PREFIX = "search_query:" + FULL_DIMENSION = 768 + + def __init__( + self, + base_url: str = "http://localhost:11434", + model: str = "nomic-embed-text", + dimension: int = 768, + timeout_seconds: int = 60, + enable_matryoshka: bool = True, + **kwargs: Any, + ): + self._base_url = base_url.rstrip("/") + self._model = model + self._dimension = dimension + self._timeout = timeout_seconds + self._enable_matryoshka = enable_matryoshka + self._client: httpx.AsyncClient | None = None + self._extra_config = kwargs + + async def _get_client(self) -> httpx.AsyncClient: + if self._client is None: + self._client = httpx.AsyncClient(timeout=self._timeout) + return self._client + + def _add_prefix(self, text: str, task: EmbeddingTask) -> str: + """Add task prefix to text.""" + if task == EmbeddingTask.DOCUMENT: + prefix = self.DOCUMENT_PREFIX + else: + prefix = self.QUERY_PREFIX + + if text.startswith(prefix): + return text + return f"{prefix}{text}" + + def _truncate_and_normalize(self, embedding: list[float], target_dim: int) -> list[float]: + """ + Truncate embedding to target dimension and normalize. + Matryoshka representation learning allows dimension truncation. + """ + truncated = embedding[:target_dim] + + arr = np.array(truncated, dtype=np.float32) + norm = np.linalg.norm(arr) + if norm > 0: + arr = arr / norm + + return arr.tolist() + + async def embed_with_task( + self, + text: str, + task: EmbeddingTask, + ) -> NomicEmbeddingResult: + """ + Generate embedding with specified task prefix. 
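+        The prefix is prepended idempotently: text that already carries the
+        prefix is left unchanged before the request goes to the Ollama API.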
+ + Args: + text: Input text to embed + task: DOCUMENT for indexing, QUERY for retrieval + + Returns: + NomicEmbeddingResult with all dimension variants + """ + start_time = time.perf_counter() + + prefixed_text = self._add_prefix(text, task) + + try: + client = await self._get_client() + response = await client.post( + f"{self._base_url}/api/embeddings", + json={ + "model": self._model, + "prompt": prefixed_text, + } + ) + response.raise_for_status() + data = response.json() + embedding = data.get("embedding", []) + + if not embedding: + raise EmbeddingException( + "Empty embedding returned", + provider=self.PROVIDER_NAME, + details={"text_length": len(text), "task": task.value} + ) + + latency_ms = (time.perf_counter() - start_time) * 1000 + + embedding_256 = self._truncate_and_normalize(embedding, 256) + embedding_512 = self._truncate_and_normalize(embedding, 512) + + logger.debug( + f"Generated Nomic embedding: task={task.value}, " + f"dim={len(embedding)}, latency={latency_ms:.2f}ms" + ) + + return NomicEmbeddingResult( + embedding_full=embedding, + embedding_256=embedding_256, + embedding_512=embedding_512, + dimension=len(embedding), + model=self._model, + task=task, + latency_ms=latency_ms, + ) + + except httpx.HTTPStatusError as e: + raise EmbeddingException( + f"Ollama API error: {e.response.status_code}", + provider=self.PROVIDER_NAME, + details={"status_code": e.response.status_code, "response": e.response.text} + ) + except httpx.RequestError as e: + raise EmbeddingException( + f"Ollama connection error: {e}", + provider=self.PROVIDER_NAME, + details={"base_url": self._base_url} + ) + except EmbeddingException: + raise + except Exception as e: + raise EmbeddingException( + f"Embedding generation failed: {e}", + provider=self.PROVIDER_NAME + ) + + async def embed_document(self, text: str) -> NomicEmbeddingResult: + """ + Generate embedding for document (with search_document: prefix). + Use this when indexing documents into vector store. + """ + return await self.embed_with_task(text, EmbeddingTask.DOCUMENT) + + async def embed_query(self, text: str) -> NomicEmbeddingResult: + """ + Generate embedding for query (with search_query: prefix). + Use this when searching/retrieving documents. + """ + return await self.embed_with_task(text, EmbeddingTask.QUERY) + + async def embed(self, text: str) -> list[float]: + """ + Generate embedding vector for a single text. + Default uses QUERY task for backward compatibility. + """ + result = await self.embed_query(text) + return result.embedding_full + + async def embed_batch(self, texts: list[str]) -> list[list[float]]: + """ + Generate embedding vectors for multiple texts. + Uses QUERY task by default. + """ + embeddings = [] + for text in texts: + embedding = await self.embed(text) + embeddings.append(embedding) + return embeddings + + async def embed_documents_batch( + self, + texts: list[str], + ) -> list[NomicEmbeddingResult]: + """ + Generate embeddings for multiple documents (DOCUMENT task). + Use this when batch indexing documents. + """ + results = [] + for text in texts: + result = await self.embed_document(text) + results.append(result) + return results + + async def embed_queries_batch( + self, + texts: list[str], + ) -> list[NomicEmbeddingResult]: + """ + Generate embeddings for multiple queries (QUERY task). + Use this when batch processing queries. 
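+        Example (sketch):
+            results = await provider.embed_queries_batch(["query A", "query B"])
+            fast_vectors = [r.embedding_256 for r in results]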
+ """ + results = [] + for text in texts: + result = await self.embed_query(text) + results.append(result) + return results + + def get_dimension(self) -> int: + """Get the dimension of embedding vectors.""" + return self._dimension + + def get_provider_name(self) -> str: + """Get the name of this embedding provider.""" + return self.PROVIDER_NAME + + def get_config_schema(self) -> dict[str, Any]: + """Get the configuration schema for Nomic provider.""" + return { + "base_url": { + "type": "string", + "description": "Ollama API 地址", + "default": "http://localhost:11434", + }, + "model": { + "type": "string", + "description": "嵌入模型名称(推荐 nomic-embed-text v1.5)", + "default": "nomic-embed-text", + }, + "dimension": { + "type": "integer", + "description": "向量维度(支持 256/512/768)", + "default": 768, + }, + "timeout_seconds": { + "type": "integer", + "description": "请求超时时间(秒)", + "default": 60, + }, + "enable_matryoshka": { + "type": "boolean", + "description": "启用 Matryoshka 维度截断", + "default": True, + }, + } + + async def close(self) -> None: + """Close the HTTP client.""" + if self._client: + await self._client.aclose() + self._client = None diff --git a/ai-service/app/services/orchestrator.py b/ai-service/app/services/orchestrator.py index d0e6317..fde1158 100644 --- a/ai-service/app/services/orchestrator.py +++ b/ai-service/app/services/orchestrator.py @@ -11,6 +11,11 @@ Design reference: design.md Section 2.2 - 关键数据流 6. compute_confidence(...) 7. Memory.append(tenantId, sessionId, user/assistant messages) 8. Return ChatResponse (or output via SSE) + +RAG Optimization (rag-optimization/spec.md): +- Two-stage retrieval with Matryoshka dimensions +- RRF hybrid ranking +- Optimized prompt engineering """ import logging @@ -36,6 +41,16 @@ from app.services.retrieval.base import BaseRetriever, RetrievalContext, Retriev logger = logging.getLogger(__name__) +OPTIMIZED_SYSTEM_PROMPT = """你是学校智能客服助手,基于提供的知识库内容回答用户问题。 + +回答要求: +1. 严格基于提供的知识库内容回答,不要编造信息 +2. 如果知识库中没有相关信息,明确告知用户并建议转人工或稍后重试 +3. 保持专业、友好的语气,回答简洁明了,突出重点 +4. 如果引用知识库内容,请注明来源(如:根据[文档1]...) +5. 对于时效性问题,请提醒用户注意文档的有效期""" + + @dataclass class OrchestratorConfig: """ @@ -44,8 +59,9 @@ class OrchestratorConfig: """ max_history_tokens: int = 4000 max_evidence_tokens: int = 2000 - system_prompt: str = "你是一个智能客服助手,请根据提供的知识库内容回答用户问题。" + system_prompt: str = OPTIMIZED_SYSTEM_PROMPT enable_rag: bool = True + use_optimized_retriever: bool = True @dataclass @@ -141,7 +157,14 @@ class OrchestratorService: """ logger.info( f"[AC-AISVC-01] Starting generation for tenant={tenant_id}, " - f"session={request.session_id}" + f"session={request.session_id}, channel_type={request.channel_type}, " + f"current_message={request.current_message[:100]}..." + ) + logger.info( + f"[AC-AISVC-01] Config: enable_rag={self._config.enable_rag}, " + f"use_optimized_retriever={self._config.use_optimized_retriever}, " + f"llm_client={'configured' if self._llm_client else 'NOT configured'}, " + f"retriever={'configured' if self._retriever else 'NOT configured'}" ) ctx = GenerationContext( @@ -257,6 +280,10 @@ class OrchestratorService: [AC-AISVC-16, AC-AISVC-17] RAG retrieval for evidence. Step 3 of the generation pipeline. 
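+        On failure, generation degrades gracefully: the context receives an
+        empty RetrievalResult with the error recorded in its diagnostics.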
""" + logger.info( + f"[AC-AISVC-16] Starting retrieval: tenant={ctx.tenant_id}, " + f"query={ctx.current_message[:100]}..., retriever={type(self._retriever).__name__ if self._retriever else 'None'}" + ) try: retrieval_ctx = RetrievalContext( tenant_id=ctx.tenant_id, @@ -277,11 +304,19 @@ class OrchestratorService: logger.info( f"[AC-AISVC-16, AC-AISVC-17] Retrieval complete: " f"hits={ctx.retrieval_result.hit_count}, " - f"max_score={ctx.retrieval_result.max_score:.3f}" + f"max_score={ctx.retrieval_result.max_score:.3f}, " + f"is_empty={ctx.retrieval_result.is_empty}" ) + + if ctx.retrieval_result.hit_count > 0: + for i, hit in enumerate(ctx.retrieval_result.hits[:3]): + logger.info( + f"[AC-AISVC-16] Hit {i+1}: score={hit.score:.3f}, " + f"text_preview={hit.text[:100]}..." + ) except Exception as e: - logger.warning(f"[AC-AISVC-16] Retrieval failed: {e}") + logger.error(f"[AC-AISVC-16] Retrieval failed with exception: {e}", exc_info=True) ctx.retrieval_result = RetrievalResult( hits=[], diagnostics={"error": str(e)}, @@ -294,9 +329,18 @@ class OrchestratorService: Step 4-5 of the generation pipeline. """ messages = self._build_llm_messages(ctx) + logger.info( + f"[AC-AISVC-02] Building LLM messages: count={len(messages)}, " + f"has_retrieval_result={ctx.retrieval_result is not None}, " + f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else 'N/A'}, " + f"llm_client={'configured' if self._llm_client else 'NOT configured'}" + ) if not self._llm_client: - logger.warning("[AC-AISVC-02] No LLM client configured, using fallback") + logger.warning( + f"[AC-AISVC-02] No LLM client configured, using fallback. " + f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}" + ) ctx.llm_response = LLMResponse( content=self._fallback_response(ctx), model="fallback", @@ -304,6 +348,7 @@ class OrchestratorService: finish_reason="fallback", ) ctx.diagnostics["llm_mode"] = "fallback" + ctx.diagnostics["fallback_reason"] = "no_llm_client" return try: @@ -318,11 +363,16 @@ class OrchestratorService: logger.info( f"[AC-AISVC-02] LLM response generated: " f"model={ctx.llm_response.model}, " - f"tokens={ctx.llm_response.usage}" + f"tokens={ctx.llm_response.usage}, " + f"content_preview={ctx.llm_response.content[:100]}..." ) except Exception as e: - logger.error(f"[AC-AISVC-02] LLM generation failed: {e}") + logger.error( + f"[AC-AISVC-02] LLM generation failed: {e}, " + f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}", + exc_info=True + ) ctx.llm_response = LLMResponse( content=self._fallback_response(ctx), model="fallback", @@ -331,6 +381,8 @@ class OrchestratorService: metadata={"error": str(e)}, ) ctx.diagnostics["llm_error"] = str(e) + ctx.diagnostics["llm_mode"] = "fallback" + ctx.diagnostics["fallback_reason"] = f"llm_error: {str(e)}" def _build_llm_messages(self, ctx: GenerationContext) -> list[dict[str, str]]: """ @@ -356,12 +408,26 @@ class OrchestratorService: def _format_evidence(self, retrieval_result: RetrievalResult) -> str: """ [AC-AISVC-17] Format retrieval hits as evidence text. + Optimized format with source attribution and metadata. 
""" evidence_parts = [] for i, hit in enumerate(retrieval_result.hits[:5], 1): - evidence_parts.append(f"[{i}] (相关度: {hit.score:.2f}) {hit.text}") - - return "\n".join(evidence_parts) + metadata = hit.metadata or {} + source = metadata.get("metadata", {}).get("source_doc", "知识库") + category = metadata.get("metadata", {}).get("category", "") + department = metadata.get("metadata", {}).get("department", "") + + header = f"[文档{i}]" + if source and source != "知识库": + header += f" 来源:{source}" + if category: + header += f" | 类别:{category}" + if department: + header += f" | 部门:{department}" + + evidence_parts.append(f"{header}\n相关度:{hit.score:.2f}\n内容:{hit.text}") + + return "\n\n".join(evidence_parts) def _fallback_response(self, ctx: GenerationContext) -> str: """ diff --git a/ai-service/app/services/retrieval/__init__.py b/ai-service/app/services/retrieval/__init__.py index 61e2a2f..d6865d4 100644 --- a/ai-service/app/services/retrieval/__init__.py +++ b/ai-service/app/services/retrieval/__init__.py @@ -1,6 +1,7 @@ """ Retrieval module for AI Service. [AC-AISVC-16] Provides retriever implementations with plugin architecture. +RAG Optimization: Two-stage retrieval, RRF hybrid ranking, metadata filtering. """ from app.services.retrieval.base import ( @@ -10,6 +11,27 @@ from app.services.retrieval.base import ( RetrievalResult, ) from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever +from app.services.retrieval.metadata import ( + ChunkMetadata, + ChunkMetadataModel, + MetadataFilter, + KnowledgeChunk, + RetrieveRequest, + RetrieveResult, + RetrievalStrategy, +) +from app.services.retrieval.optimized_retriever import ( + OptimizedRetriever, + get_optimized_retriever, + TwoStageResult, + RRFCombiner, +) +from app.services.retrieval.indexer import ( + KnowledgeIndexer, + get_knowledge_indexer, + IndexingProgress, + IndexingResult, +) __all__ = [ "BaseRetriever", @@ -18,4 +40,18 @@ __all__ = [ "RetrievalResult", "VectorRetriever", "get_vector_retriever", + "ChunkMetadata", + "MetadataFilter", + "KnowledgeChunk", + "RetrieveRequest", + "RetrieveResult", + "RetrievalStrategy", + "OptimizedRetriever", + "get_optimized_retriever", + "TwoStageResult", + "RRFCombiner", + "KnowledgeIndexer", + "get_knowledge_indexer", + "IndexingProgress", + "IndexingResult", ] diff --git a/ai-service/app/services/retrieval/indexer.py b/ai-service/app/services/retrieval/indexer.py new file mode 100644 index 0000000..d701c57 --- /dev/null +++ b/ai-service/app/services/retrieval/indexer.py @@ -0,0 +1,339 @@ +""" +Knowledge base indexing service with optimized embedding. 
+Reference: rag-optimization/spec.md Section 5.1 +""" + +import asyncio +import logging +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +from app.core.config import get_settings +from app.core.qdrant_client import QdrantClient, get_qdrant_client +from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult +from app.services.retrieval.metadata import ChunkMetadata, KnowledgeChunk + +logger = logging.getLogger(__name__) +settings = get_settings() + + +@dataclass +class IndexingProgress: + """Progress tracking for indexing jobs.""" + total_chunks: int = 0 + processed_chunks: int = 0 + failed_chunks: int = 0 + current_document: str = "" + started_at: datetime = field(default_factory=datetime.utcnow) + + @property + def progress_percent(self) -> int: + if self.total_chunks == 0: + return 0 + return int((self.processed_chunks / self.total_chunks) * 100) + + @property + def elapsed_seconds(self) -> float: + return (datetime.utcnow() - self.started_at).total_seconds() + + +@dataclass +class IndexingResult: + """Result of an indexing operation.""" + success: bool + total_chunks: int + indexed_chunks: int + failed_chunks: int + elapsed_seconds: float + error_message: str | None = None + + +class KnowledgeIndexer: + """ + Knowledge base indexer with optimized embedding. + + Features: + - Task prefixes (search_document:) for document embedding + - Multi-dimensional vectors (256/512/768) + - Metadata support + - Batch processing + """ + + def __init__( + self, + qdrant_client: QdrantClient | None = None, + embedding_provider: NomicEmbeddingProvider | None = None, + chunk_size: int = 500, + chunk_overlap: int = 50, + batch_size: int = 10, + ): + self._qdrant_client = qdrant_client + self._embedding_provider = embedding_provider + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._batch_size = batch_size + self._progress: IndexingProgress | None = None + + async def _get_client(self) -> QdrantClient: + if self._qdrant_client is None: + self._qdrant_client = await get_qdrant_client() + return self._qdrant_client + + async def _get_embedding_provider(self) -> NomicEmbeddingProvider: + if self._embedding_provider is None: + self._embedding_provider = NomicEmbeddingProvider( + base_url=settings.ollama_base_url, + model=settings.ollama_embedding_model, + dimension=settings.qdrant_vector_size, + ) + return self._embedding_provider + + def chunk_text(self, text: str, metadata: ChunkMetadata | None = None) -> list[KnowledgeChunk]: + """ + Split text into chunks for indexing. + Each line becomes a separate chunk for better retrieval granularity. + + Args: + text: Full text to chunk + metadata: Metadata to attach to each chunk + + Returns: + List of KnowledgeChunk objects + """ + chunks = [] + doc_id = str(uuid.uuid4()) + + lines = text.split('\n') + + for i, line in enumerate(lines): + line = line.strip() + + if len(line) < 10: + continue + + chunk = KnowledgeChunk( + chunk_id=f"{doc_id}_{i}", + document_id=doc_id, + content=line, + metadata=metadata or ChunkMetadata(), + ) + chunks.append(chunk) + + return chunks + + def chunk_text_by_lines( + self, + text: str, + metadata: ChunkMetadata | None = None, + min_line_length: int = 10, + merge_short_lines: bool = False, + ) -> list[KnowledgeChunk]: + """ + Split text by lines, each line is a separate chunk. 
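+        With merge_short_lines=True, consecutive short lines are merged until
+        the combined line reaches roughly 2 * min_line_length characters.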
+ + Args: + text: Full text to chunk + metadata: Metadata to attach to each chunk + min_line_length: Minimum line length to be indexed + merge_short_lines: Whether to merge consecutive short lines + + Returns: + List of KnowledgeChunk objects + """ + chunks = [] + doc_id = str(uuid.uuid4()) + + lines = text.split('\n') + + if merge_short_lines: + merged_lines = [] + current_line = "" + + for line in lines: + line = line.strip() + if not line: + if current_line: + merged_lines.append(current_line) + current_line = "" + continue + + if current_line: + current_line += " " + line + else: + current_line = line + + if len(current_line) >= min_line_length * 2: + merged_lines.append(current_line) + current_line = "" + + if current_line: + merged_lines.append(current_line) + + lines = merged_lines + + for i, line in enumerate(lines): + line = line.strip() + + if len(line) < min_line_length: + continue + + chunk = KnowledgeChunk( + chunk_id=f"{doc_id}_{i}", + document_id=doc_id, + content=line, + metadata=metadata or ChunkMetadata(), + ) + chunks.append(chunk) + + return chunks + + async def index_document( + self, + tenant_id: str, + document_id: str, + text: str, + metadata: ChunkMetadata | None = None, + ) -> IndexingResult: + """ + Index a single document with optimized embedding. + + Args: + tenant_id: Tenant identifier + document_id: Document identifier + text: Document text content + metadata: Optional metadata for the document + + Returns: + IndexingResult with status and statistics + """ + start_time = datetime.utcnow() + + try: + client = await self._get_client() + provider = await self._get_embedding_provider() + + await client.ensure_collection_exists(tenant_id, use_multi_vector=True) + + chunks = self.chunk_text(text, metadata) + + self._progress = IndexingProgress( + total_chunks=len(chunks), + current_document=document_id, + ) + + points = [] + for i, chunk in enumerate(chunks): + try: + embedding_result = await provider.embed_document(chunk.content) + + chunk.embedding_full = embedding_result.embedding_full + chunk.embedding_256 = embedding_result.embedding_256 + chunk.embedding_512 = embedding_result.embedding_512 + + point = { + "id": str(uuid.uuid4()), # Generate a valid UUID for Qdrant + "vector": { + "full": chunk.embedding_full, + "dim_256": chunk.embedding_256, + "dim_512": chunk.embedding_512, + }, + "payload": { + "chunk_id": chunk.chunk_id, + "document_id": document_id, + "text": chunk.content, + "metadata": chunk.metadata.to_dict(), + "created_at": chunk.created_at.isoformat(), + } + } + points.append(point) + + self._progress.processed_chunks += 1 + + logger.debug( + f"[RAG-OPT] Indexed chunk {i+1}/{len(chunks)} for doc={document_id}" + ) + + except Exception as e: + logger.warning(f"[RAG-OPT] Failed to index chunk {i}: {e}") + self._progress.failed_chunks += 1 + + if points: + await client.upsert_multi_vector(tenant_id, points) + + elapsed = (datetime.utcnow() - start_time).total_seconds() + + logger.info( + f"[RAG-OPT] Indexed document {document_id}: " + f"{len(points)} chunks in {elapsed:.2f}s" + ) + + return IndexingResult( + success=True, + total_chunks=len(chunks), + indexed_chunks=len(points), + failed_chunks=self._progress.failed_chunks, + elapsed_seconds=elapsed, + ) + + except Exception as e: + elapsed = (datetime.utcnow() - start_time).total_seconds() + logger.error(f"[RAG-OPT] Failed to index document {document_id}: {e}") + + return IndexingResult( + success=False, + total_chunks=0, + indexed_chunks=0, + failed_chunks=0, + elapsed_seconds=elapsed, + 
error_message=str(e), + ) + + async def index_documents_batch( + self, + tenant_id: str, + documents: list[dict[str, Any]], + ) -> list[IndexingResult]: + """ + Index multiple documents in batch. + + Args: + tenant_id: Tenant identifier + documents: List of documents with format: + { + "document_id": str, + "text": str, + "metadata": ChunkMetadata (optional) + } + + Returns: + List of IndexingResult for each document + """ + results = [] + + for doc in documents: + result = await self.index_document( + tenant_id=tenant_id, + document_id=doc["document_id"], + text=doc["text"], + metadata=doc.get("metadata"), + ) + results.append(result) + + return results + + def get_progress(self) -> IndexingProgress | None: + """Get current indexing progress.""" + return self._progress + + +_knowledge_indexer: KnowledgeIndexer | None = None + + +def get_knowledge_indexer() -> KnowledgeIndexer: + """Get or create KnowledgeIndexer instance.""" + global _knowledge_indexer + if _knowledge_indexer is None: + _knowledge_indexer = KnowledgeIndexer() + return _knowledge_indexer diff --git a/ai-service/app/services/retrieval/metadata.py b/ai-service/app/services/retrieval/metadata.py new file mode 100644 index 0000000..3dbe753 --- /dev/null +++ b/ai-service/app/services/retrieval/metadata.py @@ -0,0 +1,210 @@ +""" +Metadata models for RAG optimization. +Implements structured metadata for knowledge chunks. +Reference: rag-optimization/spec.md Section 3.2 +""" + +from dataclasses import dataclass, field +from datetime import date, datetime +from enum import Enum +from typing import Any + +from pydantic import BaseModel + + +class RetrievalStrategy(str, Enum): + """Retrieval strategy options.""" + VECTOR_ONLY = "vector" + BM25_ONLY = "bm25" + HYBRID = "hybrid" + TWO_STAGE = "two_stage" + + +class ChunkMetadataModel(BaseModel): + """Pydantic model for API serialization.""" + category: str = "" + subcategory: str = "" + target_audience: list[str] = [] + source_doc: str = "" + source_url: str = "" + department: str = "" + valid_from: str | None = None + valid_until: str | None = None + priority: int = 5 + keywords: list[str] = [] + + +@dataclass +class ChunkMetadata: + """ + Metadata for knowledge chunks. 
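+    Example (illustrative values, matching the options exposed by the KB API):
+        ChunkMetadata(category="考试政策", department="教务处",
+                      target_audience=["本科生"], priority=8)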
+    Reference: rag-optimization/spec.md Section 3.2.2
+    """
+    category: str = ""
+    subcategory: str = ""
+    target_audience: list[str] = field(default_factory=list)
+    source_doc: str = ""
+    source_url: str = ""
+    department: str = ""
+    valid_from: date | None = None
+    valid_until: date | None = None
+    priority: int = 5
+    keywords: list[str] = field(default_factory=list)
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary for storage."""
+        return {
+            "category": self.category,
+            "subcategory": self.subcategory,
+            "target_audience": self.target_audience,
+            "source_doc": self.source_doc,
+            "source_url": self.source_url,
+            "department": self.department,
+            "valid_from": self.valid_from.isoformat() if self.valid_from else None,
+            "valid_until": self.valid_until.isoformat() if self.valid_until else None,
+            "priority": self.priority,
+            "keywords": self.keywords,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "ChunkMetadata":
+        """Create from dictionary."""
+        return cls(
+            category=data.get("category", ""),
+            subcategory=data.get("subcategory", ""),
+            target_audience=data.get("target_audience", []),
+            source_doc=data.get("source_doc", ""),
+            source_url=data.get("source_url", ""),
+            department=data.get("department", ""),
+            valid_from=date.fromisoformat(data["valid_from"]) if data.get("valid_from") else None,
+            valid_until=date.fromisoformat(data["valid_until"]) if data.get("valid_until") else None,
+            priority=data.get("priority", 5),
+            keywords=data.get("keywords", []),
+        )
+
+
+@dataclass
+class MetadataFilter:
+    """
+    Filter conditions for metadata-based retrieval.
+    Reference: rag-optimization/spec.md Section 4.1
+    """
+    categories: list[str] | None = None
+    target_audiences: list[str] | None = None
+    departments: list[str] | None = None
+    valid_only: bool = True
+    min_priority: int | None = None
+    keywords: list[str] | None = None
+
+    def to_qdrant_filter(self) -> dict[str, Any] | None:
+        """Convert to Qdrant filter format."""
+        conditions = []
+
+        if self.categories:
+            conditions.append({
+                "key": "metadata.category",
+                "match": {"any": self.categories}
+            })
+
+        if self.departments:
+            conditions.append({
+                "key": "metadata.department",
+                "match": {"any": self.departments}
+            })
+
+        if self.target_audiences:
+            conditions.append({
+                "key": "metadata.target_audience",
+                "match": {"any": self.target_audiences}
+            })
+
+        if self.valid_only:
+            today = date.today().isoformat()
+            conditions.append({
+                "should": [
+                    {"is_null": {"key": "metadata.valid_until"}},
+                    {"key": "metadata.valid_until", "range": {"gte": today}}
+                ]
+            })
+
+        if self.min_priority is not None:
+            # min_priority is a lower bound: keep chunks whose priority >= min_priority
+            conditions.append({
+                "key": "metadata.priority",
+                "range": {"gte": self.min_priority}
+            })
+
+        if not conditions:
+            return None
+
+        return {"must": conditions}
+
+
+@dataclass
+class KnowledgeChunk:
+    """
+    Knowledge chunk with multi-dimensional embeddings.
+ Reference: rag-optimization/spec.md Section 3.2.1 + """ + chunk_id: str + document_id: str + content: str + embedding_full: list[float] = field(default_factory=list) + embedding_256: list[float] = field(default_factory=list) + embedding_512: list[float] = field(default_factory=list) + metadata: ChunkMetadata = field(default_factory=ChunkMetadata) + created_at: datetime = field(default_factory=datetime.utcnow) + updated_at: datetime = field(default_factory=datetime.utcnow) + + def to_qdrant_point(self, point_id: int | str) -> dict[str, Any]: + """Convert to Qdrant point format.""" + return { + "id": point_id, + "vector": { + "full": self.embedding_full, + "dim_256": self.embedding_256, + "dim_512": self.embedding_512, + }, + "payload": { + "chunk_id": self.chunk_id, + "document_id": self.document_id, + "text": self.content, + "metadata": self.metadata.to_dict(), + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + } + } + + +@dataclass +class RetrieveRequest: + """ + Request for knowledge retrieval. + Reference: rag-optimization/spec.md Section 4.1 + """ + query: str + query_with_prefix: str = "" + top_k: int = 10 + filters: MetadataFilter | None = None + strategy: RetrievalStrategy = RetrievalStrategy.HYBRID + + def __post_init__(self): + if not self.query_with_prefix: + self.query_with_prefix = f"search_query:{self.query}" + + +@dataclass +class RetrieveResult: + """ + Result from knowledge retrieval. + Reference: rag-optimization/spec.md Section 4.1 + """ + chunk_id: str + content: str + score: float + vector_score: float = 0.0 + bm25_score: float = 0.0 + metadata: ChunkMetadata = field(default_factory=ChunkMetadata) + rank: int = 0 diff --git a/ai-service/app/services/retrieval/optimized_retriever.py b/ai-service/app/services/retrieval/optimized_retriever.py new file mode 100644 index 0000000..1c773d8 --- /dev/null +++ b/ai-service/app/services/retrieval/optimized_retriever.py @@ -0,0 +1,509 @@ +""" +Optimized RAG retriever with two-stage retrieval and RRF hybrid ranking. +Reference: rag-optimization/spec.md Section 2.2, 2.4, 2.5 +""" + +import asyncio +import logging +import re +from dataclasses import dataclass, field +from typing import Any + +from app.core.config import get_settings +from app.core.qdrant_client import QdrantClient, get_qdrant_client +from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult +from app.services.retrieval.base import ( + BaseRetriever, + RetrievalContext, + RetrievalHit, + RetrievalResult, +) +from app.services.retrieval.metadata import ( + ChunkMetadata, + MetadataFilter, + RetrieveResult, + RetrievalStrategy, +) + +logger = logging.getLogger(__name__) +settings = get_settings() + + +@dataclass +class TwoStageResult: + """Result from two-stage retrieval.""" + candidates: list[dict[str, Any]] + final_results: list[RetrieveResult] + stage1_latency_ms: float = 0.0 + stage2_latency_ms: float = 0.0 + + +class RRFCombiner: + """ + Reciprocal Rank Fusion for combining multiple retrieval results. + Reference: rag-optimization/spec.md Section 2.5 + + Formula: score = Σ(1 / (k + rank_i)) + Default k = 60 + """ + + def __init__(self, k: int = 60): + self._k = k + + def combine( + self, + vector_results: list[dict[str, Any]], + bm25_results: list[dict[str, Any]], + vector_weight: float = 0.7, + bm25_weight: float = 0.3, + ) -> list[dict[str, Any]]: + """ + Combine vector and BM25 results using RRF. 
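+        Implemented as a weighted variant: each source contributes
+        weight / (k + rank + 1) with 0-based ranks, rather than the
+        unweighted 1 / (k + rank) shown in the class docstring.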
+ + Args: + vector_results: Results from vector search + bm25_results: Results from BM25 search + vector_weight: Weight for vector results + bm25_weight: Weight for BM25 results + + Returns: + Combined and sorted results + """ + combined_scores: dict[str, dict[str, Any]] = {} + + for rank, result in enumerate(vector_results): + chunk_id = result.get("chunk_id") or result.get("id", str(rank)) + rrf_score = vector_weight / (self._k + rank + 1) + + if chunk_id not in combined_scores: + combined_scores[chunk_id] = { + "score": 0.0, + "vector_score": result.get("score", 0.0), + "bm25_score": 0.0, + "vector_rank": rank, + "bm25_rank": -1, + "payload": result.get("payload", {}), + "id": chunk_id, + } + + combined_scores[chunk_id]["score"] += rrf_score + + for rank, result in enumerate(bm25_results): + chunk_id = result.get("chunk_id") or result.get("id", str(rank)) + rrf_score = bm25_weight / (self._k + rank + 1) + + if chunk_id not in combined_scores: + combined_scores[chunk_id] = { + "score": 0.0, + "vector_score": 0.0, + "bm25_score": result.get("score", 0.0), + "vector_rank": -1, + "bm25_rank": rank, + "payload": result.get("payload", {}), + "id": chunk_id, + } + else: + combined_scores[chunk_id]["bm25_score"] = result.get("score", 0.0) + combined_scores[chunk_id]["bm25_rank"] = rank + + combined_scores[chunk_id]["score"] += rrf_score + + sorted_results = sorted( + combined_scores.values(), + key=lambda x: x["score"], + reverse=True + ) + + return sorted_results + + +class OptimizedRetriever(BaseRetriever): + """ + Optimized retriever with: + - Task prefixes (search_document/search_query) + - Two-stage retrieval (256 dim -> 768 dim) + - RRF hybrid ranking (vector + BM25) + - Metadata filtering + + Reference: rag-optimization/spec.md Section 2, 3, 4 + """ + + def __init__( + self, + qdrant_client: QdrantClient | None = None, + embedding_provider: NomicEmbeddingProvider | None = None, + top_k: int | None = None, + score_threshold: float | None = None, + min_hits: int | None = None, + two_stage_enabled: bool | None = None, + two_stage_expand_factor: int | None = None, + hybrid_enabled: bool | None = None, + rrf_k: int | None = None, + ): + self._qdrant_client = qdrant_client + self._embedding_provider = embedding_provider + self._top_k = top_k or settings.rag_top_k + self._score_threshold = score_threshold or settings.rag_score_threshold + self._min_hits = min_hits or settings.rag_min_hits + self._two_stage_enabled = two_stage_enabled if two_stage_enabled is not None else settings.rag_two_stage_enabled + self._two_stage_expand_factor = two_stage_expand_factor or settings.rag_two_stage_expand_factor + self._hybrid_enabled = hybrid_enabled if hybrid_enabled is not None else settings.rag_hybrid_enabled + self._rrf_k = rrf_k or settings.rag_rrf_k + self._rrf_combiner = RRFCombiner(k=self._rrf_k) + + async def _get_client(self) -> QdrantClient: + if self._qdrant_client is None: + self._qdrant_client = await get_qdrant_client() + return self._qdrant_client + + async def _get_embedding_provider(self) -> NomicEmbeddingProvider: + if self._embedding_provider is None: + from app.services.embedding.factory import get_embedding_config_manager + manager = get_embedding_config_manager() + provider = await manager.get_provider() + if isinstance(provider, NomicEmbeddingProvider): + self._embedding_provider = provider + else: + self._embedding_provider = NomicEmbeddingProvider( + base_url=settings.ollama_base_url, + model=settings.ollama_embedding_model, + dimension=settings.qdrant_vector_size, + ) + return 
self._embedding_provider + + async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult: + """ + Retrieve documents using optimized strategy. + + Strategy selection: + 1. If two_stage_enabled: use two-stage retrieval + 2. If hybrid_enabled: use RRF hybrid ranking + 3. Otherwise: simple vector search + """ + logger.info( + f"[RAG-OPT] Starting retrieval for tenant={ctx.tenant_id}, " + f"query={ctx.query[:50]}..., two_stage={self._two_stage_enabled}, hybrid={self._hybrid_enabled}" + ) + logger.info( + f"[RAG-OPT] Retrieval config: top_k={self._top_k}, " + f"score_threshold={self._score_threshold}, min_hits={self._min_hits}" + ) + + try: + provider = await self._get_embedding_provider() + logger.info(f"[RAG-OPT] Using embedding provider: {type(provider).__name__}") + + embedding_result = await provider.embed_query(ctx.query) + logger.info( + f"[RAG-OPT] Embedding generated: full_dim={len(embedding_result.embedding_full)}, " + f"dim_256={'available' if embedding_result.embedding_256 else 'not available'}" + ) + + if self._two_stage_enabled: + logger.info("[RAG-OPT] Using two-stage retrieval strategy") + results = await self._two_stage_retrieve( + ctx.tenant_id, + embedding_result, + self._top_k, + ) + elif self._hybrid_enabled: + logger.info("[RAG-OPT] Using hybrid retrieval strategy") + results = await self._hybrid_retrieve( + ctx.tenant_id, + embedding_result, + ctx.query, + self._top_k, + ) + else: + logger.info("[RAG-OPT] Using simple vector retrieval strategy") + results = await self._vector_retrieve( + ctx.tenant_id, + embedding_result.embedding_full, + self._top_k, + ) + + logger.info(f"[RAG-OPT] Raw results count: {len(results)}") + + retrieval_hits = [ + RetrievalHit( + text=result.get("payload", {}).get("text", ""), + score=result.get("score", 0.0), + source="optimized_rag", + metadata=result.get("payload", {}), + ) + for result in results + if result.get("score", 0.0) >= self._score_threshold + ] + + filtered_count = len(results) - len(retrieval_hits) + if filtered_count > 0: + logger.info( + f"[RAG-OPT] Filtered out {filtered_count} results below threshold {self._score_threshold}" + ) + + is_insufficient = len(retrieval_hits) < self._min_hits + + diagnostics = { + "query_length": len(ctx.query), + "top_k": self._top_k, + "score_threshold": self._score_threshold, + "two_stage_enabled": self._two_stage_enabled, + "hybrid_enabled": self._hybrid_enabled, + "total_hits": len(retrieval_hits), + "is_insufficient": is_insufficient, + "max_score": max((h.score for h in retrieval_hits), default=0.0), + "raw_results_count": len(results), + "filtered_below_threshold": filtered_count, + } + + logger.info( + f"[RAG-OPT] Retrieval complete: {len(retrieval_hits)} hits, " + f"insufficient={is_insufficient}, max_score={diagnostics['max_score']:.3f}" + ) + + if len(retrieval_hits) == 0: + logger.warning( + f"[RAG-OPT] No hits found! tenant={ctx.tenant_id}, query={ctx.query[:50]}..., " + f"raw_results={len(results)}, threshold={self._score_threshold}" + ) + + return RetrievalResult( + hits=retrieval_hits, + diagnostics=diagnostics, + ) + + except Exception as e: + logger.error(f"[RAG-OPT] Retrieval error: {e}", exc_info=True) + return RetrievalResult( + hits=[], + diagnostics={"error": str(e), "is_insufficient": True}, + ) + + async def _two_stage_retrieve( + self, + tenant_id: str, + embedding_result: NomicEmbeddingResult, + top_k: int, + ) -> list[dict[str, Any]]: + """ + Two-stage retrieval using Matryoshka dimensions. 
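+
+        Matryoshka-style embeddings front-load information into the leading
+        dimensions, so the first 256 of the 768 dims remain a usable (if
+        coarser) representation. Stage 1 therefore scans the cheap dim_256
+        vectors broadly; stage 2 reranks only the short candidate list at
+        full precision.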
+
+        Stage 1: Fast retrieval with 256-dim vectors
+        Stage 2: Precise reranking with 768-dim vectors
+
+        Reference: rag-optimization/spec.md Section 2.4
+        """
+        client = await self._get_client()
+
+        # Fall back to a single-stage full-dim search when the provider
+        # returned no 256-dim truncation (embed_query logs this case).
+        if not embedding_result.embedding_256:
+            return await self._vector_retrieve(
+                tenant_id, embedding_result.embedding_full, top_k
+            )
+
+        stage1_start = time.perf_counter()
+        candidates = await self._search_with_dimension(
+            client, tenant_id, embedding_result.embedding_256, "dim_256",
+            top_k * self._two_stage_expand_factor,
+            with_vectors=["full"],
+        )
+        stage1_latency = (time.perf_counter() - stage1_start) * 1000
+
+        logger.debug(
+            f"[RAG-OPT] Stage 1: {len(candidates)} candidates in {stage1_latency:.2f}ms"
+        )
+
+        stage2_start = time.perf_counter()
+        reranked = []
+        for candidate in candidates:
+            stored_full_embedding = (
+                (candidate.get("vector") or {}).get("full")
+                or candidate.get("payload", {}).get("embedding_full", [])
+            )
+            if stored_full_embedding:
+                candidate["score"] = self._cosine_similarity(
+                    embedding_result.embedding_full,
+                    stored_full_embedding,
+                )
+                candidate["stage"] = "reranked"
+            else:
+                # Keep the stage-1 score rather than dropping the candidate:
+                # to_qdrant_point() stores the full embedding as a named
+                # vector, not in the payload, so a payload-only lookup would
+                # otherwise empty the result set.
+                candidate["stage"] = "stage1_only"
+            reranked.append(candidate)
+
+        reranked.sort(key=lambda x: x.get("score", 0), reverse=True)
+        results = reranked[:top_k]
+        stage2_latency = (time.perf_counter() - stage2_start) * 1000
+
+        logger.debug(
+            f"[RAG-OPT] Stage 2: {len(results)} final results in {stage2_latency:.2f}ms"
+        )
+
+        return results
+
+    async def _hybrid_retrieve(
+        self,
+        tenant_id: str,
+        embedding_result: NomicEmbeddingResult,
+        query: str,
+        top_k: int,
+    ) -> list[dict[str, Any]]:
+        """
+        Hybrid retrieval using RRF to combine vector and BM25 results.
+
+        Reference: rag-optimization/spec.md Section 2.5
+        """
+        client = await self._get_client()
+
+        vector_task = self._search_with_dimension(
+            client, tenant_id, embedding_result.embedding_full, "full",
+            top_k * 2
+        )
+
+        bm25_task = self._bm25_search(client, tenant_id, query, top_k * 2)
+
+        vector_results, bm25_results = await asyncio.gather(
+            vector_task, bm25_task, return_exceptions=True
+        )
+
+        if isinstance(vector_results, Exception):
+            logger.warning(f"[RAG-OPT] Vector search failed: {vector_results}")
+            vector_results = []
+
+        if isinstance(bm25_results, Exception):
+            logger.warning(f"[RAG-OPT] BM25 search failed: {bm25_results}")
+            bm25_results = []
+
+        combined = self._rrf_combiner.combine(
+            vector_results,
+            bm25_results,
+            vector_weight=settings.rag_vector_weight,
+            bm25_weight=settings.rag_bm25_weight,
+        )
+
+        return combined[:top_k]
+
+    async def _vector_retrieve(
+        self,
+        tenant_id: str,
+        embedding: list[float],
+        top_k: int,
+    ) -> list[dict[str, Any]]:
+        """Single-stage vector retrieval over the full-dim named vector."""
+        client = await self._get_client()
+        return await self._search_with_dimension(
+            client, tenant_id, embedding, "full", top_k
+        )
+
+    async def _search_with_dimension(
+        self,
+        client: QdrantClient,
+        tenant_id: str,
+        query_vector: list[float],
+        vector_name: str,
+        limit: int,
+        with_vectors: bool | list[str] = False,
+    ) -> list[dict[str, Any]]:
+        """Search using the named vector of the specified dimension."""
+        try:
+            qdrant = await client.get_client()
+            collection_name = client.get_collection_name(tenant_id)
+
+            logger.info(
+                f"[RAG-OPT] Searching collection={collection_name}, "
+                f"vector_name={vector_name}, limit={limit}, vector_dim={len(query_vector)}"
+            )
+
+            results = await qdrant.search(
+                collection_name=collection_name,
+                query_vector=(vector_name, query_vector),
+                limit=limit,
+                with_vectors=with_vectors,
+            )
+
+            logger.info(
+                f"[RAG-OPT] Search returned {len(results)} results from collection={collection_name}"
+            )
+
+            for i, r in enumerate(results[:3]):
+                logger.debug(
+                    f"[RAG-OPT] Result {i+1}: id={r.id}, score={r.score:.4f}"
+                )
+
+            return [
+                {
+                    "id": str(result.id),
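+                    # Named vectors (e.g. {"full": [...]}); populated only
+                    # when the search was issued with with_vectors (assumes
+                    # qdrant-client exposes them on ScoredPoint.vector).
+                    "vector": getattr(result, "vector", None),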
+                    "score": result.score,
+                    "payload": result.payload or {},
+                }
+                for result in results
+            ]
+        except Exception as e:
+            logger.error(
+                f"[RAG-OPT] Search with {vector_name} failed: {e}, "
+                f"collection_name={client.get_collection_name(tenant_id)}",
+                exc_info=True
+            )
+            return []
+
+    async def _bm25_search(
+        self,
+        client: QdrantClient,
+        tenant_id: str,
+        query: str,
+        limit: int,
+    ) -> list[dict[str, Any]]:
+        """
+        Lexical fallback standing in for BM25: scores a scrolled sample of
+        points by Jaccard overlap between query and text token sets. For
+        production-grade ranking, use true BM25 (e.g. Elasticsearch or
+        Qdrant sparse vectors).
+        """
+        try:
+            qdrant = await client.get_client()
+            collection_name = client.get_collection_name(tenant_id)
+
+            query_terms = set(re.findall(r'\w+', query.lower()))
+
+            # Note: this samples only the first `limit * 3` points; it is
+            # not a full-collection scan.
+            results = await qdrant.scroll(
+                collection_name=collection_name,
+                limit=limit * 3,
+                with_payload=True,
+            )
+
+            scored_results = []
+            for point in results[0]:
+                text = (point.payload or {}).get("text", "").lower()
+                text_terms = set(re.findall(r'\w+', text))
+                overlap = len(query_terms & text_terms)
+                if overlap > 0:
+                    # Jaccard similarity: |A ∩ B| / |A ∪ B|
+                    score = overlap / (len(query_terms) + len(text_terms) - overlap)
+                    scored_results.append({
+                        "id": str(point.id),
+                        "score": score,
+                        "payload": point.payload or {},
+                    })
+
+            scored_results.sort(key=lambda x: x["score"], reverse=True)
+            return scored_results[:limit]
+
+        except Exception as e:
+            logger.warning(f"[RAG-OPT] BM25 fallback search failed: {e}")
+            return []
+
+    def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
+        """Calculate cosine similarity between two vectors."""
+        import numpy as np
+        a = np.array(vec1)
+        b = np.array(vec2)
+        denom = float(np.linalg.norm(a) * np.linalg.norm(b))
+        if denom == 0.0:
+            # Zero vectors have no direction; treat them as dissimilar.
+            return 0.0
+        return float(np.dot(a, b) / denom)
+
+    async def health_check(self) -> bool:
+        """Check if retriever is healthy."""
+        try:
+            client = await self._get_client()
+            qdrant = await client.get_client()
+            await qdrant.get_collections()
+            return True
+        except Exception as e:
+            logger.error(f"[RAG-OPT] Health check failed: {e}")
+            return False
+
+
+_optimized_retriever: OptimizedRetriever | None = None
+
+
+async def get_optimized_retriever() -> OptimizedRetriever:
+    """Get or create the process-wide OptimizedRetriever instance."""
+    global _optimized_retriever
+    if _optimized_retriever is None:
+        _optimized_retriever = OptimizedRetriever()
+    return _optimized_retriever
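
A minimal usage sketch for the module-level accessor above, assuming
RetrievalContext (from app/services/retrieval/base.py) is constructible
from the tenant_id and query fields that retrieve() reads; the tenant and
query values are illustrative:

    import asyncio

    from app.services.retrieval.base import RetrievalContext
    from app.services.retrieval.optimized_retriever import get_optimized_retriever

    async def main() -> None:
        retriever = await get_optimized_retriever()
        # Strategy is picked from settings: two-stage if enabled,
        # else hybrid RRF, else plain vector search.
        result = await retriever.retrieve(
            RetrievalContext(tenant_id="tenant-1", query="refund policy")
        )
        for hit in result.hits:
            print(f"{hit.score:.3f}  {hit.text[:80]}")

    asyncio.run(main())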