162 lines
5.0 KiB
Python
162 lines
5.0 KiB
Python
"""
|
|
RAG Lab endpoints for debugging and experimentation.
|
|
[AC-ASA-05] RAG experiment debugging with retrieval results and prompt visualization.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Annotated, Any, List
|
|
|
|
from fastapi import APIRouter, Depends, Body
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel, Field
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import get_settings
|
|
from app.core.database import get_session
|
|
from app.core.exceptions import MissingTenantIdException
|
|
from app.core.tenant import get_tenant_id
|
|
from app.core.qdrant_client import get_qdrant_client
|
|
from app.models import ErrorResponse
|
|
from app.services.retrieval.vector_retriever import get_vector_retriever
|
|
from app.services.retrieval.base import RetrievalContext
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

# Router for the RAG Lab endpoints: every route is mounted under /admin/rag
# and tagged "RAG Lab" (FastAPI uses tags to group routes in the OpenAPI schema).
router = APIRouter(tags=["RAG Lab"], prefix="/admin/rag")
|
|
|
|
|
|
def get_current_tenant_id() -> str:
    """
    FastAPI dependency that resolves the tenant ID for the current request.

    Returns:
        The tenant ID reported by ``get_tenant_id()`` when it is truthy.

    Raises:
        MissingTenantIdException: If ``get_tenant_id()`` yields a falsy
            value (e.g. ``None`` or an empty string).
    """
    current = get_tenant_id()
    if current:
        return current
    raise MissingTenantIdException()
|
|
|
|
|
|
class RAGExperimentRequest(BaseModel):
    # Request body for the RAG Lab "run experiment" endpoint.
    # (Comments instead of a docstring so the generated JSON schema is unchanged.)

    # Required text to run retrieval for.
    query: str = Field(default=..., description="Query text for retrieval")
    # Optional knowledge-base IDs; the endpoint forwards them to the retriever
    # inside the retrieval context's metadata under "kb_ids".
    kb_ids: list[str] | None = Field(None, description="Knowledge base IDs to search")
    # Optional retrieval overrides; the endpoint reads the "topK" and
    # "threshold" keys from this mapping.
    params: dict[str, Any] | None = Field(None, description="Retrieval parameters")
|
|
|
|
|
|
def _resolve_experiment_params(
    params: dict[str, Any] | None,
    default_top_k: int,
    default_threshold: float,
) -> tuple[int, float]:
    """
    Extract ``topK`` and ``threshold`` from user-supplied experiment params.

    Args:
        params: Raw ``params`` mapping from the request body (may be None).
        default_top_k: Value used when ``topK`` is missing, None or invalid.
        default_threshold: Value used when ``threshold`` is missing, None
            or not convertible to float.

    Returns:
        ``(top_k, threshold)``. ``topK`` values that are not integers >= 1
        are rejected (with a warning) in favour of the default.
    """
    params = params or {}

    top_k = default_top_k
    raw_top_k = params.get("topK")
    if raw_top_k is not None:
        try:
            candidate = int(raw_top_k)
        except (TypeError, ValueError, OverflowError):
            candidate = 0  # treated as invalid below
        if candidate >= 1:
            top_k = candidate
        else:
            logger.warning(
                "[AC-ASA-05] Ignoring invalid topK=%r; using %s", raw_top_k, default_top_k
            )

    threshold = default_threshold
    raw_threshold = params.get("threshold")
    if raw_threshold is not None:
        try:
            threshold = float(raw_threshold)
        except (TypeError, ValueError):
            logger.warning(
                "[AC-ASA-05] Ignoring invalid threshold=%r; using %s",
                raw_threshold,
                default_threshold,
            )

    return top_k, threshold


@router.post(
    "/experiments/run",
    operation_id="runRagExperiment",
    summary="Run RAG debugging experiment",
    description="[AC-ASA-05] Trigger RAG experiment with retrieval and prompt generation.",
    responses={
        200: {"description": "Experiment results with retrieval and prompt"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> JSONResponse:
    """
    [AC-ASA-05] Run RAG experiment and return retrieval results with final prompt.

    Args:
        tenant_id: Tenant resolved by ``get_current_tenant_id``.
        request: Query, optional knowledge-base IDs and optional ``params``
            (``topK``, ``threshold``). Missing or invalid params fall back to
            ``settings.rag_top_k`` / ``settings.rag_score_threshold``.

    Returns:
        JSONResponse with ``retrievalResults`` (content/score/source/metadata
        per hit), ``finalPrompt`` and ``diagnostics``. If any step of the
        retrieve-and-render path raises, canned fallback results are returned
        instead and ``diagnostics`` holds ``error`` and ``fallback: True``.
    """
    logger.info(
        "[AC-ASA-05] Running RAG experiment: tenant=%s, query=%s..., kb_ids=%s",
        tenant_id,
        request.query[:50],
        request.kb_ids,
    )

    settings = get_settings()
    top_k, threshold = _resolve_experiment_params(
        request.params, settings.rag_top_k, settings.rag_score_threshold
    )

    try:
        retriever = await get_vector_retriever()

        retrieval_ctx = RetrievalContext(
            tenant_id=tenant_id,
            query=request.query,
            session_id="rag_experiment",
            channel_type="admin",
            metadata={"kb_ids": request.kb_ids},
        )

        result = await retriever.retrieve(retrieval_ctx)

        # Apply the experiment's threshold/topK to what the retriever returned.
        # NOTE(review): assumes hits arrive best-first; the retriever may also
        # cut results with the settings defaults, in which case a looser
        # topK/threshold here cannot bring back hits it already dropped.
        selected_hits = [hit for hit in result.hits if hit.score >= threshold][:top_k]

        retrieval_results = [
            {
                "content": hit.text,
                "score": hit.score,
                "source": hit.source,
                "metadata": hit.metadata,
            }
            for hit in selected_hits
        ]

        final_prompt = _build_final_prompt(request.query, retrieval_results)

        # Computed from the returned hits so an empty result (where a
        # retriever-side max_score could be None) cannot break the log format
        # and push a successful run into the fallback path.
        max_score = max((item["score"] for item in retrieval_results), default=0.0)
        logger.info(
            "[AC-ASA-05] RAG experiment complete: hits=%d, max_score=%.3f",
            len(retrieval_results),
            max_score,
        )

        return JSONResponse(
            content={
                "retrievalResults": retrieval_results,
                "finalPrompt": final_prompt,
                "diagnostics": result.diagnostics,
            }
        )

    except Exception as e:
        # Deliberate best-effort: this is a debugging endpoint, so failures are
        # reported in-band with simulated results rather than as an HTTP error.
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("[AC-ASA-05] RAG experiment failed: %s", e)

        fallback_results = _get_fallback_results(request.query)
        fallback_prompt = _build_final_prompt(request.query, fallback_results)

        return JSONResponse(
            content={
                "retrievalResults": fallback_results,
                "finalPrompt": fallback_prompt,
                "diagnostics": {
                    "error": str(e),
                    "fallback": True,
                },
            }
        )
|
|
|
|
|
|
def _build_final_prompt(query: str, retrieval_results: list[dict]) -> str:
|
|
"""
|
|
Build the final prompt from query and retrieval results.
|
|
"""
|
|
if not retrieval_results:
|
|
return f"""用户问题:{query}
|
|
|
|
未找到相关检索结果,请基于通用知识回答用户问题。"""
|
|
|
|
evidence_text = "\n".join([
|
|
f"{i+1}. [Score: {hit['score']:.2f}] {hit['content'][:200]}{'...' if len(hit['content']) > 200 else ''}"
|
|
for i, hit in enumerate(retrieval_results[:5])
|
|
])
|
|
|
|
return f"""基于以下检索到的信息,回答用户问题:
|
|
|
|
用户问题:{query}
|
|
|
|
检索结果:
|
|
{evidence_text}
|
|
|
|
请基于以上信息生成专业、准确的回答。"""
|
|
|
|
|
|
def _get_fallback_results(query: str) -> list[dict]:
|
|
"""
|
|
Provide fallback results when retrieval fails.
|
|
"""
|
|
return [
|
|
{
|
|
"content": "检索服务暂时不可用,这是模拟结果。",
|
|
"score": 0.5,
|
|
"source": "fallback",
|
|
}
|
|
]
|