ai-robot-core/ai-service/app/api/admin/rag.py

343 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG Lab endpoints for debugging and experimentation.
[AC-ASA-05, AC-ASA-19, AC-ASA-20, AC-ASA-21, AC-ASA-22] RAG experiment with AI output.
"""
import json
import logging
import time
from typing import Annotated, Any, List
from fastapi import APIRouter, Depends, Body
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from app.core.config import get_settings
from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.services.retrieval.vector_retriever import get_vector_retriever
from app.services.retrieval.optimized_retriever import get_optimized_retriever
from app.services.retrieval.base import RetrievalContext
from app.services.llm.factory import get_llm_config_manager
# Module-scoped logger so records from this router can be filtered by module name.
logger = logging.getLogger(name=__name__)
# Every RAG Lab endpoint below is mounted under /admin/rag and grouped under the
# "RAG Lab" tag in the generated OpenAPI docs.
router = APIRouter(tags=["RAG Lab"], prefix="/admin/rag")
def get_current_tenant_id() -> str:
    """
    FastAPI dependency resolving the tenant of the current request.

    Returns:
        The tenant ID reported by ``get_tenant_id()``.

    Raises:
        MissingTenantIdException: when ``get_tenant_id()`` yields an empty/falsy value.
    """
    current = get_tenant_id()
    if current:
        return current
    raise MissingTenantIdException()
class RAGExperimentRequest(BaseModel):
    # Request body accepted by both /experiments/run and /experiments/stream.
    # Field order and the Field(...) arguments define the OpenAPI schema, so keep them stable.
    query: str = Field(..., description="Query text for retrieval")  # required
    # Forwarded unchanged to the retriever as RetrievalContext.metadata["kb_ids"];
    # how None (no explicit list) is interpreted is up to the retriever — TODO confirm.
    kb_ids: List[str] | None = Field(default=None, description="Knowledge base IDs to search")
    # NOTE(review): no bounds are enforced on top_k / score_threshold (e.g. top_k may be 0 or negative).
    top_k: int = Field(default=5, description="Number of results to retrieve")
    score_threshold: float = Field(default=0.5, description="Minimum similarity score")
    # When False, the endpoints skip the LLM call and return only retrieval results and the prompt.
    generate_response: bool = Field(default=True, description="Whether to generate AI response")
    llm_provider: str | None = Field(default=None, description="Specific LLM provider to use")
class AIResponse(BaseModel):
    # LLM answer plus usage/latency metadata, serialized into the "ai_response" field.
    content: str  # generated text, or an error description when generation failed
    # Token counts stay 0 when they are unknown (e.g. on the generation-failure path).
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    latency_ms: float = 0  # duration of the LLM call in milliseconds
    model: str = ""  # model name reported by the LLM client; empty when unavailable
class RAGExperimentResult(BaseModel):
    # Shape of one experiment result: query, hits, rendered prompt, AI answer, timing, diagnostics.
    # NOTE(review): the endpoints in this module build their JSON by hand (JSONResponse with a
    # plain dict) instead of returning this model — keep both in sync or switch to response_model.
    query: str
    # Hit dicts as built by the endpoints (content, score, source, ...).
    retrieval_results: List[dict] = []
    final_prompt: str = ""
    ai_response: AIResponse | None = None  # None when no AI response was requested
    total_latency_ms: float = 0
    diagnostics: dict[str, Any] = {}  # retriever diagnostics, or error/fallback info
@router.post(
    "/experiments/run",
    operation_id="runRagExperiment",
    summary="Run RAG debugging experiment with AI output",
    description="[AC-ASA-05, AC-ASA-19, AC-ASA-21, AC-ASA-22] Trigger RAG experiment with retrieval, prompt generation, and AI response.",
    responses={
        200: {"description": "Experiment results with retrieval, prompt, and AI response"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> JSONResponse:
    """
    [AC-ASA-05, AC-ASA-19, AC-ASA-21, AC-ASA-22] Run RAG experiment and return retrieval results with AI response.

    Retrieves hits with the optimized retriever and keeps at most ``top_k`` of them.
    It then builds the final prompt. When ``request.generate_response`` is true, it
    also asks the LLM for an answer.

    Retrieval failures never surface as HTTP errors. This covers building the
    context, calling the retriever, serializing hits and rendering the prompt. On
    any such failure the endpoint substitutes ``_get_fallback_results`` and reports
    ``{"error": ..., "fallback": True}`` in ``diagnostics``.

    Returns:
        JSONResponse with ``query``, ``retrieval_results``, ``final_prompt``,
        ``ai_response`` (None when not requested), ``total_latency_ms`` and ``diagnostics``.
    """
    # perf_counter is monotonic, so latency is not skewed by wall-clock adjustments.
    start_time = time.perf_counter()
    logger.info(
        "[AC-ASA-05] Running RAG experiment: tenant=%s, query=%s..., kb_ids=%s, generate_response=%s",
        tenant_id,
        request.query[:50],
        request.kb_ids,
        request.generate_response,
    )

    settings = get_settings()
    # A non-positive top_k is meaningless, so fall back to the configured default.
    top_k = request.top_k if request.top_k > 0 else settings.rag_top_k
    # NOTE(review): request.score_threshold is not applied here because the score scale of the
    # optimized retriever is not visible from this module; confirm before filtering on it.

    # Only retrieval and its post-processing are guarded: a failure after this point must
    # not re-run the fallback path (which would call the LLM a second time).
    try:
        retriever = await get_optimized_retriever()
        retrieval_ctx = RetrievalContext(
            tenant_id=tenant_id,
            query=request.query,
            session_id="rag_experiment",
            channel_type="admin",
            metadata={"kb_ids": request.kb_ids},
        )
        result = await retriever.retrieve(retrieval_ctx)
        # Build the full list first (works for any iterable of hits), then honour top_k.
        retrieval_results = [
            {
                "content": hit.text,
                "score": hit.score,
                "source": hit.source,
                "metadata": hit.metadata,
            }
            for hit in result.hits
        ][:top_k]
        final_prompt = _build_final_prompt(request.query, retrieval_results)
        diagnostics = result.diagnostics
        # Lazy %-formatting: a logging/format problem cannot push a successful
        # retrieval into the fallback branch.
        logger.info(
            "[AC-ASA-05] RAG retrieval complete: hits=%d, max_score=%.3f",
            len(retrieval_results),
            result.max_score,
        )
    except Exception as e:
        # Best-effort debugging endpoint: keep going with canned results, but keep the traceback.
        logger.exception("[AC-ASA-05] RAG experiment failed: %s", e)
        retrieval_results = _get_fallback_results(request.query)
        final_prompt = _build_final_prompt(request.query, retrieval_results)
        diagnostics = {
            "error": str(e),
            "fallback": True,
        }

    ai_response = None
    if request.generate_response:
        ai_response = await _generate_ai_response(
            final_prompt,
            provider=request.llm_provider,
        )

    total_latency_ms = (time.perf_counter() - start_time) * 1000
    return JSONResponse(
        content={
            "query": request.query,
            "retrieval_results": retrieval_results,
            "final_prompt": final_prompt,
            "ai_response": ai_response.model_dump() if ai_response else None,
            "total_latency_ms": round(total_latency_ms, 2),
            "diagnostics": diagnostics,
        }
    )
@router.post(
    "/experiments/stream",
    operation_id="runRagExperimentStream",
    summary="Run RAG experiment with streaming AI output",
    description="[AC-ASA-20] Trigger RAG experiment with SSE streaming for AI response.",
    responses={
        200: {"description": "SSE stream with retrieval results and AI response"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment_stream(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> StreamingResponse:
    """
    [AC-ASA-20] Run RAG experiment with SSE streaming for AI response.

    Emitted events, in order:
      - ``retrieval``: ``{"results": [...], "count": n}`` (at most ``top_k`` hits)
      - ``prompt``: ``{"prompt": "..."}``
      - ``message`` (zero or more, only when ``generate_response``): ``{"delta": "..."}``
      - ``final``: ``{"content": ..., "finish_reason": "stop" | "skipped"}``

    Any exception raised while producing the stream is logged and reported as a
    single ``error`` event (``{"error": "..."}``), which ends the stream.
    """
    logger.info(
        "[AC-ASA-20] Running RAG experiment stream: tenant=%s, query=%s...",
        tenant_id,
        request.query[:50],
    )

    settings = get_settings()
    # A non-positive top_k is meaningless, so fall back to the configured default.
    top_k = request.top_k if request.top_k > 0 else settings.rag_top_k
    # NOTE(review): request.score_threshold is not applied here because the score scale of the
    # optimized retriever is not visible from this module; confirm before filtering on it.

    def _sse(event: str, payload: dict) -> str:
        # Format one SSE frame: "event: <name>\ndata: <json>\n\n".
        return f"event: {event}\ndata: {json.dumps(payload)}\n\n"

    async def event_generator():
        try:
            retriever = await get_optimized_retriever()
            retrieval_ctx = RetrievalContext(
                tenant_id=tenant_id,
                query=request.query,
                session_id="rag_experiment_stream",
                channel_type="admin",
                metadata={"kb_ids": request.kb_ids},
            )
            result = await retriever.retrieve(retrieval_ctx)
            # Build the full list first (works for any iterable of hits), then honour top_k.
            retrieval_results = [
                {
                    "content": hit.text,
                    "score": hit.score,
                    "source": hit.source,
                    "metadata": hit.metadata,
                }
                for hit in result.hits
            ][:top_k]
            final_prompt = _build_final_prompt(request.query, retrieval_results)

            yield _sse("retrieval", {"results": retrieval_results, "count": len(retrieval_results)})
            yield _sse("prompt", {"prompt": final_prompt})

            if request.generate_response:
                # NOTE(review): request.llm_provider is not forwarded — get_client() is called
                # without arguments, so the default client is always used.
                manager = get_llm_config_manager()
                client = manager.get_client()
                parts: list[str] = []
                async for chunk in client.stream_generate(
                    messages=[{"role": "user", "content": final_prompt}],
                ):
                    if chunk.delta:
                        parts.append(chunk.delta)
                        yield _sse("message", {"delta": chunk.delta})
                yield _sse("final", {"content": "".join(parts), "finish_reason": "stop"})
            else:
                yield _sse("final", {"content": "", "finish_reason": "skipped"})
        except Exception as e:
            logger.exception("[AC-ASA-20] RAG experiment stream failed: %s", e)
            yield _sse("error", {"error": str(e)})

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
async def _generate_ai_response(
    prompt: str,
    provider: str | None = None,
) -> AIResponse | None:
    """
    [AC-ASA-19, AC-ASA-21] Generate AI response from prompt.

    Sends ``prompt`` as a single user message to the client returned by
    ``get_llm_config_manager().get_client()``.

    Args:
        prompt: Fully rendered prompt text.
        provider: Requested LLM provider. NOTE(review): not forwarded yet — the default
            client is always used; confirm ``get_client()``'s signature before wiring it.

    Returns:
        An ``AIResponse``. Failures do not propagate. They are logged and returned as an
        ``AIResponse`` whose ``content`` holds the error message, with latency 0. The
        ``| None`` in the annotation is kept for interface compatibility; None is never returned.
    """
    try:
        manager = get_llm_config_manager()
        client = manager.get_client()

        # Time only the LLM call itself; perf_counter is monotonic.
        start_time = time.perf_counter()
        response = await client.generate(
            messages=[{"role": "user", "content": prompt}],
        )
        latency_ms = (time.perf_counter() - start_time) * 1000

        # Tolerate a missing usage dict or None token counts instead of turning a
        # successful generation into an error response.
        usage = response.usage or {}
        return AIResponse(
            content=response.content,
            prompt_tokens=usage.get("prompt_tokens") or 0,
            completion_tokens=usage.get("completion_tokens") or 0,
            total_tokens=usage.get("total_tokens") or 0,
            latency_ms=round(latency_ms, 2),
            # AIResponse.model is a non-optional str; map a missing model name to "".
            model=response.model or "",
        )
    except Exception as e:
        logger.exception("[AC-ASA-19] AI response generation failed: %s", e)
        return AIResponse(
            content=f"AI 响应生成失败: {str(e)}",
            latency_ms=0,
        )
def _build_final_prompt(query: str, retrieval_results: list[dict]) -> str:
"""
Build the final prompt from query and retrieval results.
"""
if not retrieval_results:
return f"""用户问题:{query}
未找到相关检索结果,请基于通用知识回答用户问题。"""
evidence_text = "\n".join([
f"{i+1}. [Score: {hit['score']:.2f}] {hit['content'][:200]}{'...' if len(hit['content']) > 200 else ''}"
for i, hit in enumerate(retrieval_results[:5])
])
return f"""【系统指令】
你是一名经验丰富的客服专员名字叫“小N”。请遵循以下准则
1. 使用自然、口语化的中文,语气亲切、有礼貌,像在和人聊天一样。
2. 不要使用列表、编号或过于机械的结构,而是用连贯的段落表达。
3. 如果知识库信息不足,可以委婉地说明“我帮您查一下”、“我确认一下”,而不是直接说“我不知道”。
4. 不要提及“作为AI”、“根据算法”等词避免暴露机器身份。
5. 适当使用语气词(如“呢”、“哦”、“哈”)、表情符号(根据品牌调性可选)和人性化的表达(如“我理解”、“您别着急”)。
6. 回答长度适中一般不超过3-5句话如需详细说明可分段落但保持易读。
7. 如果问题涉及隐私或敏感信息,请引导客户联系人工客服(如“为了您的账户安全,建议您拨打官方电话...”)。
【知识库内容】
{evidence_text}
【用户问题】
{query}
"""
def _get_fallback_results(query: str) -> list[dict]:
"""
Provide fallback results when retrieval fails.
"""
return [
{
"content": "检索服务暂时不可用,这是模拟结果。",
"score": 0.5,
"source": "fallback",
}
]