"""
|
||
RAG Lab endpoints for debugging and experimentation.
|
||
[AC-ASA-05, AC-ASA-19, AC-ASA-20, AC-ASA-21, AC-ASA-22] RAG experiment with AI output.
|
||
"""

import json
import logging
import time
from typing import Annotated, Any, List

from fastapi import APIRouter, Depends, Body
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field

from app.core.config import get_settings
from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.services.retrieval.optimized_retriever import get_optimized_retriever
from app.services.retrieval.base import RetrievalContext
from app.services.llm.factory import get_llm_config_manager

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/admin/rag", tags=["RAG Lab"])


def get_current_tenant_id() -> str:
    """Dependency to get the current tenant ID or raise an exception."""
    tenant_id = get_tenant_id()
    if not tenant_id:
        raise MissingTenantIdException()
    return tenant_id
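
# The endpoints below require the tenant via Depends(get_current_tenant_id);
# get_tenant_id() is assumed to resolve it from request-scoped context
# (e.g. middleware-populated state).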


class RAGExperimentRequest(BaseModel):
    query: str = Field(..., description="Query text for retrieval")
    kb_ids: List[str] | None = Field(default=None, description="Knowledge base IDs to search")
    top_k: int = Field(default=5, description="Number of results to retrieve")
    score_threshold: float = Field(default=0.5, description="Minimum similarity score")
    generate_response: bool = Field(default=True, description="Whether to generate AI response")
    llm_provider: str | None = Field(default=None, description="Specific LLM provider to use")


class AIResponse(BaseModel):
    content: str
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    latency_ms: float = 0
    model: str = ""


class RAGExperimentResult(BaseModel):
    query: str
    retrieval_results: List[dict] = []
    final_prompt: str = ""
    ai_response: AIResponse | None = None
    total_latency_ms: float = 0
    diagnostics: dict[str, Any] = {}
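
# NOTE: RAGExperimentResult documents the response payload shape; the handlers
# below assemble their JSON by hand, so this model is currently illustrative
# rather than enforced.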


@router.post(
    "/experiments/run",
    operation_id="runRagExperiment",
    summary="Run RAG debugging experiment with AI output",
    description="[AC-ASA-05, AC-ASA-19, AC-ASA-21, AC-ASA-22] Trigger RAG experiment with retrieval, prompt generation, and AI response.",
    responses={
        200: {"description": "Experiment results with retrieval, prompt, and AI response"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> JSONResponse:
    """
    [AC-ASA-05, AC-ASA-19, AC-ASA-21, AC-ASA-22] Run RAG experiment and return retrieval results with AI response.
    """
    start_time = time.time()

    logger.info(
        f"[AC-ASA-05] Running RAG experiment: tenant={tenant_id}, "
        f"query={request.query[:50]}..., kb_ids={request.kb_ids}, "
        f"generate_response={request.generate_response}"
    )

    settings = get_settings()
    # NOTE: `or` falls back to settings when the client explicitly sends 0;
    # 0 is treated as "use the configured default".
    top_k = request.top_k or settings.rag_top_k
    threshold = request.score_threshold or settings.rag_score_threshold

    try:
        # Use the optimized retriever with RAG enhancements.
        retriever = await get_optimized_retriever()

        retrieval_ctx = RetrievalContext(
            tenant_id=tenant_id,
            query=request.query,
            session_id="rag_experiment",
            channel_type="admin",
            # RetrievalContext has no dedicated top_k/score_threshold fields, so
            # the tuning knobs ride along in metadata for retrievers that honor them.
            metadata={"kb_ids": request.kb_ids, "top_k": top_k, "score_threshold": threshold},
        )

        result = await retriever.retrieve(retrieval_ctx)

        retrieval_results = [
            {
                "content": hit.text,
                "score": hit.score,
                "source": hit.source,
                "metadata": hit.metadata,
            }
            for hit in result.hits
        ]

        final_prompt = _build_final_prompt(request.query, retrieval_results)

        logger.info(
            f"[AC-ASA-05] RAG retrieval complete: hits={len(retrieval_results)}, "
            f"max_score={result.max_score:.3f}"
        )

        ai_response = None
        if request.generate_response:
            ai_response = await _generate_ai_response(
                final_prompt,
                provider=request.llm_provider,
            )

        total_latency_ms = (time.time() - start_time) * 1000

        return JSONResponse(
            content={
                "query": request.query,
                "retrieval_results": retrieval_results,
                "final_prompt": final_prompt,
                "ai_response": ai_response.model_dump() if ai_response else None,
                "total_latency_ms": round(total_latency_ms, 2),
                "diagnostics": result.diagnostics,
            }
        )

    except Exception as e:
        logger.error(f"[AC-ASA-05] RAG experiment failed: {e}")

        # Degrade gracefully: return simulated results with HTTP 200 so the lab UI
        # stays usable while retrieval is down; diagnostics carry the error and a
        # fallback flag.
        fallback_results = _get_fallback_results(request.query)
        fallback_prompt = _build_final_prompt(request.query, fallback_results)

        ai_response = None
        if request.generate_response:
            ai_response = await _generate_ai_response(
                fallback_prompt,
                provider=request.llm_provider,
            )

        total_latency_ms = (time.time() - start_time) * 1000

        return JSONResponse(
            content={
                "query": request.query,
                "retrieval_results": fallback_results,
                "final_prompt": fallback_prompt,
                "ai_response": ai_response.model_dump() if ai_response else None,
                "total_latency_ms": round(total_latency_ms, 2),
                "diagnostics": {
                    "error": str(e),
                    "fallback": True,
                },
            }
        )
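
# A successful run returns JSON shaped like RAGExperimentResult above, e.g.
# (values illustrative):
#   {"query": "...", "retrieval_results": [{"content": "...", "score": 0.87,
#    "source": "...", "metadata": {...}}], "final_prompt": "...",
#    "ai_response": {"content": "...", "total_tokens": 123, ...},
#    "total_latency_ms": 42.5, "diagnostics": {...}}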


@router.post(
    "/experiments/stream",
    operation_id="runRagExperimentStream",
    summary="Run RAG experiment with streaming AI output",
    description="[AC-ASA-20] Trigger RAG experiment with SSE streaming for AI response.",
    responses={
        200: {"description": "SSE stream with retrieval results and AI response"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment_stream(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> StreamingResponse:
    """
    [AC-ASA-20] Run RAG experiment with SSE streaming for AI response.
    """
    logger.info(
        f"[AC-ASA-20] Running RAG experiment stream: tenant={tenant_id}, "
        f"query={request.query[:50]}..."
    )

    settings = get_settings()
    # NOTE: `or` falls back to settings when the client explicitly sends 0;
    # 0 is treated as "use the configured default".
    top_k = request.top_k or settings.rag_top_k
    threshold = request.score_threshold or settings.rag_score_threshold

    async def event_generator():
        try:
            # Use the optimized retriever with RAG enhancements.
            retriever = await get_optimized_retriever()

            retrieval_ctx = RetrievalContext(
                tenant_id=tenant_id,
                query=request.query,
                session_id="rag_experiment_stream",
                channel_type="admin",
                # Tuning knobs ride along in metadata (see run_rag_experiment).
                metadata={"kb_ids": request.kb_ids, "top_k": top_k, "score_threshold": threshold},
            )

            result = await retriever.retrieve(retrieval_ctx)

            retrieval_results = [
                {
                    "content": hit.text,
                    "score": hit.score,
                    "source": hit.source,
                    "metadata": hit.metadata,
                }
                for hit in result.hits
            ]

            final_prompt = _build_final_prompt(request.query, retrieval_results)

            # SSE framing: each event is "event: <name>\ndata: <json>\n\n".
            yield f"event: retrieval\ndata: {json.dumps({'results': retrieval_results, 'count': len(retrieval_results)})}\n\n"

            yield f"event: prompt\ndata: {json.dumps({'prompt': final_prompt})}\n\n"

            if request.generate_response:
                # NOTE: request.llm_provider is not yet honored on this path; the
                # factory's default client is used.
                manager = get_llm_config_manager()
                client = manager.get_client()

                full_content = ""
                async for chunk in client.stream_generate(
                    messages=[{"role": "user", "content": final_prompt}],
                ):
                    if chunk.delta:
                        full_content += chunk.delta
                        yield f"event: message\ndata: {json.dumps({'delta': chunk.delta})}\n\n"

                yield f"event: final\ndata: {json.dumps({'content': full_content, 'finish_reason': 'stop'})}\n\n"
            else:
                yield f"event: final\ndata: {json.dumps({'content': '', 'finish_reason': 'skipped'})}\n\n"

        except Exception as e:
            logger.error(f"[AC-ASA-20] RAG experiment stream failed: {e}")
            yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable reverse-proxy (nginx) buffering so events flush immediately.
            "X-Accel-Buffering": "no",
        },
    )
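
# Event order on the stream: retrieval -> prompt -> message (one per delta) ->
# final, with error replacing the remainder on failure. Minimal consumer sketch
# (assumes the `httpx` package and a local server; adjust base_url/auth to your
# deployment):
#
#   import asyncio, json
#   import httpx
#
#   async def consume():
#       async with httpx.AsyncClient(base_url="http://localhost:8000") as http:
#           async with http.stream(
#               "POST", "/admin/rag/experiments/stream", json={"query": "hello"}
#           ) as resp:
#               async for line in resp.aiter_lines():
#                   if line.startswith("data: "):
#                       print(json.loads(line[len("data: "):]))
#
#   asyncio.run(consume())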


async def _generate_ai_response(
    prompt: str,
    provider: str | None = None,
) -> AIResponse | None:
    """
    [AC-ASA-19, AC-ASA-21] Generate AI response from prompt.
    """
    try:
        manager = get_llm_config_manager()
        # NOTE: `provider` is accepted for future routing but not yet forwarded;
        # the factory's default client is used.
        client = manager.get_client()

        start_time = time.time()
        response = await client.generate(
            messages=[{"role": "user", "content": prompt}],
        )
        latency_ms = (time.time() - start_time) * 1000

        return AIResponse(
            content=response.content,
            prompt_tokens=response.usage.get("prompt_tokens", 0),
            completion_tokens=response.usage.get("completion_tokens", 0),
            total_tokens=response.usage.get("total_tokens", 0),
            latency_ms=round(latency_ms, 2),
            model=response.model,
        )

    except Exception as e:
        logger.error(f"[AC-ASA-19] AI response generation failed: {e}")
        # Return an AIResponse carrying the error text (zh-CN: "AI response
        # generation failed: ...") rather than None, so the lab UI can show it.
        return AIResponse(
            content=f"AI 响应生成失败: {str(e)}",
            latency_ms=0,
        )


def _build_final_prompt(query: str, retrieval_results: list[dict]) -> str:
    """
    Build the final prompt from query and retrieval results.
    """
    if not retrieval_results:
        # zh-CN: "User question: {query}" / "No relevant retrieval results were
        # found; please answer from general knowledge."
        return f"""用户问题:{query}

未找到相关检索结果,请基于通用知识回答用户问题。"""

    # Cap evidence at the top five hits and truncate each snippet to 200 characters.
    evidence_text = "\n".join([
        f"{i+1}. [Score: {hit['score']:.2f}] {hit['content'][:200]}{'...' if len(hit['content']) > 200 else ''}"
        for i, hit in enumerate(retrieval_results[:5])
    ])

    # zh-CN system prompt, in brief: "You are an experienced customer-service
    # agent named '小N'. Use natural, conversational Chinese; avoid lists and
    # robotic structure; if the KB is thin, say you'll check rather than 'I
    # don't know'; never reveal an AI identity; use softeners/emoji per brand
    # tone; keep answers to ~3-5 sentences; route privacy-sensitive questions
    # to human support." Sections: 【知识库内容】 = KB content, 【用户问题】 = user question.
    return f"""【系统指令】
你是一名经验丰富的客服专员,名字叫“小N”。请遵循以下准则:
1. 使用自然、口语化的中文,语气亲切、有礼貌,像在和人聊天一样。
2. 不要使用列表、编号或过于机械的结构,而是用连贯的段落表达。
3. 如果知识库信息不足,可以委婉地说明“我帮您查一下”、“我确认一下”,而不是直接说“我不知道”。
4. 不要提及“作为AI”、“根据算法”等词,避免暴露机器身份。
5. 适当使用语气词(如“呢”、“哦”、“哈”)、表情符号(根据品牌调性可选)和人性化的表达(如“我理解”、“您别着急”)。
6. 回答长度适中,一般不超过3-5句话,如需详细说明可分段落,但保持易读。
7. 如果问题涉及隐私或敏感信息,请引导客户联系人工客服(如“为了您的账户安全,建议您拨打官方电话...”)。

【知识库内容】
{evidence_text}

【用户问题】
{query}
"""


def _get_fallback_results(query: str) -> list[dict]:
    """
    Provide fallback results when retrieval fails.
    """
    # NOTE: `query` is currently unused; kept for call-site symmetry.
    # zh-CN content: "The retrieval service is temporarily unavailable; this is
    # a simulated result."
    return [
        {
            "content": "检索服务暂时不可用,这是模拟结果。",
            "score": 0.5,
            "source": "fallback",
        }
    ]
|