"""
RAG Lab endpoints for debugging and experimentation.
[AC-ASA-05, AC-ASA-19, AC-ASA-20, AC-ASA-21, AC-ASA-22] RAG experiment with AI output.
"""

import json
import logging
import time
from typing import Annotated, Any, List

from fastapi import APIRouter, Depends, Body
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field

from app.core.config import get_settings
from app.core.exceptions import MissingTenantIdException
from app.core.prompts import format_evidence_for_prompt, build_user_prompt_with_evidence
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.services.retrieval.vector_retriever import get_vector_retriever
from app.services.retrieval.optimized_retriever import get_optimized_retriever
from app.services.retrieval.base import RetrievalContext
from app.services.llm.factory import get_llm_config_manager

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/admin/rag", tags=["RAG Lab"])


def get_current_tenant_id() -> str:
    """Dependency returning the current tenant ID.

    Raises:
        MissingTenantIdException: if ``get_tenant_id()`` yields a falsy value.
    """
    tenant_id = get_tenant_id()
    if not tenant_id:
        raise MissingTenantIdException()
    return tenant_id


class RAGExperimentRequest(BaseModel):
    """Request body shared by the blocking and streaming experiment endpoints."""

    query: str = Field(..., description="Query text for retrieval")
    kb_ids: List[str] | None = Field(default=None, description="Knowledge base IDs to search")
    top_k: int = Field(default=5, description="Number of results to retrieve")
    score_threshold: float = Field(default=0.5, description="Minimum similarity score")
    generate_response: bool = Field(default=True, description="Whether to generate AI response")
    llm_provider: str | None = Field(default=None, description="Specific LLM provider to use")


class AIResponse(BaseModel):
    """LLM completion plus token usage and latency, as built by `_generate_ai_response`."""

    content: str
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    latency_ms: float = 0
    model: str = ""


class RAGExperimentResult(BaseModel):
    """Shape of an experiment result (same keys as the JSON payload built below)."""

    query: str
    retrieval_results: List[dict] = []
    final_prompt: str = ""
    ai_response: AIResponse | None = None
    total_latency_ms: float = 0
    diagnostics: dict[str, Any] = {}


def _hits_to_dicts(hits: Any) -> list[dict[str, Any]]:
    """Convert retriever hits into plain dicts for prompts and JSON payloads."""
    return [
        {
            "content": hit.text,
            "score": hit.score,
            "source": hit.source,
            "metadata": hit.metadata,
        }
        for hit in hits
    ]


async def _retrieve(tenant_id: str, request: RAGExperimentRequest, session_id: str) -> Any:
    """Run the optimized retriever for ``request`` and return its raw result.

    NOTE(review): ``request.top_k`` and ``request.score_threshold`` are not
    forwarded to the retriever. The original code computed them (with settings
    fallbacks) but never used them, and the RetrievalContext / retriever
    interface for passing them is not visible from this file. Wire them
    through once the intended parameter names are confirmed.
    """
    # Use optimized retriever with RAG enhancements
    retriever = await get_optimized_retriever()
    retrieval_ctx = RetrievalContext(
        tenant_id=tenant_id,
        query=request.query,
        session_id=session_id,
        channel_type="admin",
        metadata={"kb_ids": request.kb_ids},
    )
    return await retriever.retrieve(retrieval_ctx)


async def _maybe_generate_ai_response(
    request: RAGExperimentRequest,
    prompt: str,
) -> AIResponse | None:
    """Generate an AI response for ``prompt`` unless the request opted out."""
    if not request.generate_response:
        return None
    return await _generate_ai_response(prompt, provider=request.llm_provider)


def _experiment_response(
    query: str,
    retrieval_results: list[dict],
    final_prompt: str,
    ai_response: AIResponse | None,
    started_at: float,
    diagnostics: Any,
) -> JSONResponse:
    """Build the JSON payload returned by the blocking experiment endpoint.

    ``started_at`` is a ``time.perf_counter()`` reading taken when the request
    began; the total latency is measured at the moment this is called.
    """
    total_latency_ms = (time.perf_counter() - started_at) * 1000
    return JSONResponse(
        content={
            "query": query,
            "retrieval_results": retrieval_results,
            "final_prompt": final_prompt,
            "ai_response": ai_response.model_dump() if ai_response else None,
            "total_latency_ms": round(total_latency_ms, 2),
            "diagnostics": diagnostics,
        }
    )


def _sse_event(event: str, payload: dict[str, Any]) -> str:
    """Format one Server-Sent Events frame with a JSON-encoded data line."""
    return f"event: {event}\ndata: {json.dumps(payload)}\n\n"


@router.post(
    "/experiments/run",
    operation_id="runRagExperiment",
    summary="Run RAG debugging experiment with AI output",
    description="[AC-ASA-05, AC-ASA-19, AC-ASA-21, AC-ASA-22] Trigger RAG experiment with retrieval, prompt generation, and AI response.",
    responses={
        200: {"description": "Experiment results with retrieval, prompt, and AI response"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> JSONResponse:
    """
    [AC-ASA-05, AC-ASA-19, AC-ASA-21, AC-ASA-22] Run RAG experiment and return
    retrieval results with AI response.

    Any exception raised on the main path (retrieval, prompt building,
    response construction) is logged and answered with fallback results whose
    diagnostics carry ``{"error": ..., "fallback": True}``.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments.
    started_at = time.perf_counter()

    logger.info(
        "[AC-ASA-05] Running RAG experiment: tenant=%s, "
        "query=%s..., kb_ids=%s, "
        "generate_response=%s",
        tenant_id,
        request.query[:50],
        request.kb_ids,
        request.generate_response,
    )

    try:
        result = await _retrieve(tenant_id, request, session_id="rag_experiment")
        retrieval_results = _hits_to_dicts(result.hits)
        final_prompt = _build_final_prompt(request.query, retrieval_results)

        # Lazy %-formatting: a non-numeric max_score is reported by the
        # logging module instead of raising here and forcing the fallback path.
        logger.info(
            "[AC-ASA-05] RAG retrieval complete: hits=%d, max_score=%.3f",
            len(retrieval_results),
            result.max_score,
        )

        ai_response = await _maybe_generate_ai_response(request, final_prompt)
        return _experiment_response(
            request.query,
            retrieval_results,
            final_prompt,
            ai_response,
            started_at,
            result.diagnostics,
        )

    except Exception as e:
        # Best-effort endpoint: keep the lab usable when retrieval is down,
        # but record the traceback so the root cause is not lost.
        logger.exception("[AC-ASA-05] RAG experiment failed: %s", e)

        fallback_results = _get_fallback_results(request.query)
        fallback_prompt = _build_final_prompt(request.query, fallback_results)
        ai_response = await _maybe_generate_ai_response(request, fallback_prompt)

        return _experiment_response(
            request.query,
            fallback_results,
            fallback_prompt,
            ai_response,
            started_at,
            {
                "error": str(e),
                "fallback": True,
            },
        )


@router.post(
    "/experiments/stream",
    operation_id="runRagExperimentStream",
    summary="Run RAG experiment with streaming AI output",
    description="[AC-ASA-20] Trigger RAG experiment with SSE streaming for AI response.",
    responses={
        200: {"description": "SSE stream with retrieval results and AI response"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment_stream(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> StreamingResponse:
    """
    [AC-ASA-20] Run RAG experiment with SSE streaming for AI response.

    Emitted events, in order: ``retrieval``, ``prompt``, zero or more
    ``message`` deltas, then ``final``. Any exception inside the generator is
    reported as a single ``error`` event.
    """
    logger.info(
        "[AC-ASA-20] Running RAG experiment stream: tenant=%s, "
        "query=%s...",
        tenant_id,
        request.query[:50],
    )

    async def event_generator():
        try:
            result = await _retrieve(tenant_id, request, session_id="rag_experiment_stream")
            retrieval_results = _hits_to_dicts(result.hits)
            final_prompt = _build_final_prompt(request.query, retrieval_results)

            logger.info("[AC-ASA-20] ========== RAG LAB STREAM FULL PROMPT ==========")
            logger.info("[AC-ASA-20] Prompt length: %d", len(final_prompt))
            logger.info("[AC-ASA-20] Prompt content:\n%s", final_prompt)
            logger.info("[AC-ASA-20] ==============================================")

            yield _sse_event(
                "retrieval",
                {"results": retrieval_results, "count": len(retrieval_results)},
            )
            yield _sse_event("prompt", {"prompt": final_prompt})

            if not request.generate_response:
                yield _sse_event("final", {"content": "", "finish_reason": "skipped"})
                return

            # NOTE(review): request.llm_provider is not applied here; the
            # client is obtained with no arguments, as in the original code.
            manager = get_llm_config_manager()
            client = manager.get_client()

            # Collect deltas in a list and join once instead of repeated +=.
            parts: list[str] = []
            async for chunk in client.stream_generate(
                messages=[{"role": "user", "content": final_prompt}],
            ):
                if chunk.delta:
                    parts.append(chunk.delta)
                    yield _sse_event("message", {"delta": chunk.delta})

            yield _sse_event(
                "final",
                {"content": "".join(parts), "finish_reason": "stop"},
            )

        except Exception as e:
            logger.exception("[AC-ASA-20] RAG experiment stream failed: %s", e)
            yield _sse_event("error", {"error": str(e)})

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


async def _generate_ai_response(
    prompt: str,
    provider: str | None = None,
) -> AIResponse | None:
    """
    [AC-ASA-19, AC-ASA-21] Generate AI response from prompt.

    Args:
        prompt: Full prompt text, sent as a single user message.
        provider: Requested LLM provider.
            NOTE(review): currently unused - the client is obtained via
            ``manager.get_client()`` with no arguments; the signature needed
            to select a provider is not visible from this file.

    Returns:
        An AIResponse. On failure the exception is logged and an AIResponse
        whose ``content`` carries the error message is returned instead of
        raising.
    """
    logger.info("[AC-ASA-19] ========== RAG LAB FULL PROMPT ==========")
    logger.info("[AC-ASA-19] Prompt length: %d", len(prompt))
    logger.info("[AC-ASA-19] Prompt content:\n%s", prompt)
    logger.info("[AC-ASA-19] ==========================================")

    try:
        manager = get_llm_config_manager()
        client = manager.get_client()

        started_at = time.perf_counter()
        response = await client.generate(
            messages=[{"role": "user", "content": prompt}],
        )
        latency_ms = (time.perf_counter() - started_at) * 1000

        # A missing/None usage must not discard an otherwise valid completion;
        # fall back to zero token counts instead of raising on `.get`.
        usage = getattr(response, "usage", None) or {}

        return AIResponse(
            content=response.content,
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            total_tokens=usage.get("total_tokens", 0),
            latency_ms=round(latency_ms, 2),
            model=response.model,
        )

    except Exception as e:
        logger.exception("[AC-ASA-19] AI response generation failed: %s", e)
        return AIResponse(
            content=f"AI 响应生成失败: {e}",
            latency_ms=0,
        )


def _build_final_prompt(query: str, retrieval_results: list[dict]) -> str:
    """
    Build the final prompt from query and retrieval results.
    Uses shared prompt configuration for consistency with orchestrator.

    At most 5 results are passed on, each capped at 500 characters of content
    (limits handed to ``format_evidence_for_prompt``).
    """
    evidence_text = format_evidence_for_prompt(
        retrieval_results,
        max_results=5,
        max_content_length=500,
    )
    return build_user_prompt_with_evidence(query, evidence_text)


def _get_fallback_results(query: str) -> list[dict]:
    """
    Provide fallback results when retrieval fails.

    ``query`` is accepted for interface symmetry but does not influence the
    returned placeholder.
    """
    return [
        {
            "content": "检索服务暂时不可用,这是模拟结果。",
            "score": 0.5,
            "source": "fallback",
        }
    ]