# File listing metadata: 403 lines, 12 KiB, Python
"""
|
|
Flow test API for AI Service Admin.
|
|
[AC-AISVC-93~AC-AISVC-95] Complete 12-step flow execution testing.
|
|
"""
|
|
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
from fastapi import APIRouter, Depends, Header, HTTPException, Query
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import desc, func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.database import get_session
|
|
from app.models.entities import FlowTestRecord, FlowTestRecordStatus
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(prefix="/admin/test", tags=["Flow Test"])
|
|
|
|
|
|
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
    """Return the tenant identifier carried by the ``X-Tenant-Id`` header.

    Used as a FastAPI dependency by the endpoints in this module.

    Args:
        x_tenant_id: Value of the ``X-Tenant-Id`` request header (declared
            as required).

    Returns:
        The header value, unchanged.

    Raises:
        HTTPException: 400 when the header value is empty.
    """
    # Happy path first: any non-empty value is passed through as-is.
    if x_tenant_id:
        return x_tenant_id
    raise HTTPException(status_code=400, detail="X-Tenant-Id header is required")
|
|
|
|
|
|
class FlowExecutionRequest(BaseModel):
    """Payload accepted by the flow execution test endpoint."""

    # Text of the user message pushed through the flow.
    message: str

    # Session to run under; the handler generates a "test_<hex>" id when omitted.
    session_id: str | None = None

    # Optional user identifier.
    # NOTE(review): not read by the handlers visible in this module.
    user_id: str | None = None

    # Per-stage toggles, all enabled by default.
    # NOTE(review): none of these flags are read by the handlers visible in
    # this module - confirm whether they should be forwarded to the
    # orchestrator.
    enable_flow: bool = True
    enable_intent: bool = True
    enable_rag: bool = True
    enable_guardrail: bool = True
    enable_memory: bool = True

    # Comparison-mode switch, off by default (also not read in this module).
    compare_mode: bool = False
|
|
|
|
|
|
class FlowExecutionResponse(BaseModel):
    """Result body returned by the flow execution test endpoint."""

    # Stringified id of the FlowTestRecord created for this run.
    test_id: str

    # Session id the run was executed under.
    session_id: str

    # Overall outcome; the handler fills this with a FlowTestRecordStatus value.
    status: str

    # Per-step execution details as reported by the orchestrator metadata.
    steps: list[dict[str, Any]]

    # Reply summary (reply / confidence / should_transfer), or None.
    final_response: dict[str, Any] | None

    # End-to-end duration of the run in milliseconds.
    total_duration_ms: int

    # Record creation time as an ISO-8601 string.
    created_at: str
|
|
|
|
|
|
@router.post(
    "/flow-execution",
    operation_id="executeFlowTest",
    summary="Execute complete 12-step flow",
    description="[AC-AISVC-93] Execute complete 12-step generation flow with detailed step logging.",
)
async def execute_flow_test(
    request: FlowExecutionRequest,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> FlowExecutionResponse:
    """
    [AC-AISVC-93] Execute complete 12-step flow for testing.

    Steps:
    1. InputScanner - Scan input for forbidden words
    2. FlowEngine - Check if flow is active
    3. IntentRouter - Match intent rules
    4. QueryRewriter - Rewrite query for better retrieval
    5. MultiKBRetrieval - Retrieve from multiple knowledge bases
    6. ResultRanker - Rank and filter results
    7. PromptBuilder - Build prompt from template
    8. LLMGenerate - Generate response via LLM
    9. OutputFilter - Filter output for forbidden words
    10. Confidence - Calculate confidence score
    11. Memory - Store conversation in memory
    12. Response - Return final response

    The message is sent through ``OrchestratorService.generate`` and the step
    list is taken from ``result.metadata["execution_steps"]``. The overall
    status is FAILED if any step reports ``"failed"``, PARTIAL if any step
    reports ``"skipped"``, and SUCCESS otherwise.

    Saving the ``FlowTestRecord`` is best-effort: a database failure is
    logged and the response is still returned, built from local values
    rather than from the (possibly expired) ORM object.

    NOTE(review): ``user_id``, the ``enable_*`` toggles and ``compare_mode``
    on the request are not read here - confirm whether they should be
    forwarded to the orchestrator.

    Args:
        request: Message and optional session id for the test run.
        tenant_id: Tenant resolved from the ``X-Tenant-Id`` header.
        session: Database session used by the memory service and for
            storing the test record.

    Returns:
        The recorded outcome of the run, including per-step details.

    Raises:
        HTTPException: 500 carrying the error text when the run fails.
    """
    import time
    from datetime import timezone

    from app.models import ChatRequest, ChannelType
    from app.services.llm.factory import get_llm_config_manager
    from app.services.memory import MemoryService
    from app.services.orchestrator import OrchestratorService
    from app.services.retrieval.optimized_retriever import get_optimized_retriever

    logger.info(
        f"[AC-AISVC-93] Executing flow test for tenant={tenant_id}, "
        f"message={request.message[:50]}..."
    )

    test_session_id = request.session_id or f"test_{uuid.uuid4().hex[:8]}"
    # Fallback timestamp for the response when the record could not be stored.
    started_at = datetime.now(timezone.utc)
    # perf_counter is monotonic, so durations are immune to wall-clock jumps.
    start_time = time.perf_counter()

    memory_service = MemoryService(session)
    llm_config_manager = get_llm_config_manager()
    llm_client = llm_config_manager.get_client()
    retriever = await get_optimized_retriever()

    orchestrator = OrchestratorService(
        llm_client=llm_client,
        memory_service=memory_service,
        retriever=retriever,
    )

    try:
        chat_request = ChatRequest(
            session_id=test_session_id,
            current_message=request.message,
            channel_type=ChannelType.WECHAT,
            history=[],
        )

        result = await orchestrator.generate(
            tenant_id=tenant_id,
            request=chat_request,
        )

        # `or []` also covers an explicit None stored under the key.
        steps = (result.metadata.get("execution_steps") or []) if result.metadata else []
        total_duration_ms = int((time.perf_counter() - start_time) * 1000)

        has_failure = any(s.get("status") == "failed" for s in steps)
        has_partial = any(s.get("status") == "skipped" for s in steps)

        if has_failure:
            status = FlowTestRecordStatus.FAILED.value
        elif has_partial:
            status = FlowTestRecordStatus.PARTIAL.value
        else:
            status = FlowTestRecordStatus.SUCCESS.value

        # Kept as a local so the response never has to read it back from the
        # ORM object after a rollback.
        final_response = {
            "reply": result.reply,
            "confidence": result.confidence,
            "should_transfer": result.should_transfer,
        }

        test_record = FlowTestRecord(
            tenant_id=tenant_id,
            session_id=test_session_id,
            status=status,
            steps=steps,
            final_response=final_response,
            total_duration_ms=total_duration_ms,
        )

        # Snapshot whatever the model populated client-side before touching
        # the database; after a failed commit/refresh the instance may be
        # expired and unsafe to read from an async session.
        pending_id = test_record.id
        pending_created_at = test_record.created_at

        if await _save_flow_test_record(session, test_record):
            test_id = str(test_record.id)
            created_at = test_record.created_at
        else:
            test_id = str(pending_id) if pending_id is not None else "unsaved"
            created_at = pending_created_at
        if created_at is None:
            created_at = started_at

        logger.info(
            f"[AC-AISVC-93] Flow test completed: id={test_id}, "
            f"status={status}, duration={total_duration_ms}ms"
        )

        return FlowExecutionResponse(
            test_id=test_id,
            session_id=test_session_id,
            status=status,
            steps=steps,
            final_response=final_response,
            total_duration_ms=total_duration_ms,
            created_at=created_at.isoformat(),
        )

    except Exception as e:
        # logger.exception keeps the traceback alongside the original message.
        logger.exception(f"[AC-AISVC-93] Flow test failed: {e}")

        total_duration_ms = int((time.perf_counter() - start_time) * 1000)

        # Discard any half-finished work before recording the failure; a
        # rollback error must not hide the original exception.
        try:
            await session.rollback()
        except Exception as rollback_error:
            logger.warning(f"Failed to roll back session: {rollback_error}")

        failed_record = FlowTestRecord(
            tenant_id=tenant_id,
            session_id=test_session_id,
            status=FlowTestRecordStatus.FAILED.value,
            steps=[{
                "step": 0,
                "name": "Error",
                "status": "failed",
                "error": str(e),
            }],
            final_response=None,
            total_duration_ms=total_duration_ms,
        )
        # Best-effort: if the database is the reason we are here, the caller
        # should still see the original error rather than a commit failure.
        await _save_flow_test_record(session, failed_record)

        raise HTTPException(status_code=500, detail=str(e)) from e


async def _save_flow_test_record(session: AsyncSession, record: FlowTestRecord) -> bool:
    """Best-effort add/commit/refresh of *record*; returns True on success.

    On any database error the failure is logged, the session is rolled back
    (also best-effort) and False is returned instead of raising.
    """
    try:
        session.add(record)
        await session.commit()
        await session.refresh(record)
    except Exception as db_error:
        logger.warning(f"Failed to save test record: {db_error}")
        try:
            await session.rollback()
        except Exception as rollback_error:
            logger.warning(f"Failed to roll back session: {rollback_error}")
        return False
    return True
|
|
|
|
|
|
@router.get(
    "/flow-execution/{test_id}",
    operation_id="getFlowTestResult",
    summary="Get flow test result",
    description="[AC-AISVC-94] Get detailed result of a flow execution test.",
)
async def get_flow_test_result(
    test_id: uuid.UUID,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-94] Fetch one stored flow test record with its step details.

    The lookup matches on both record id and tenant id, so a record that
    belongs to a different tenant is reported as not found.

    Args:
        test_id: Id of the flow test record.
        tenant_id: Tenant resolved from the ``X-Tenant-Id`` header.
        session: Database session used for the query.

    Returns:
        A camelCase dict with the record's id, session id, status, steps,
        final response, duration and ISO-formatted creation time.

    Raises:
        HTTPException: 404 when no matching record exists.
    """
    logger.info(
        f"[AC-AISVC-94] Getting flow test result for tenant={tenant_id}, "
        f"test_id={test_id}"
    )

    lookup = select(FlowTestRecord).where(
        FlowTestRecord.id == test_id,
        FlowTestRecord.tenant_id == tenant_id,
    )
    record = (await session.execute(lookup)).scalar_one_or_none()

    if record:
        return {
            "testId": str(record.id),
            "sessionId": record.session_id,
            "status": record.status,
            "steps": record.steps,
            "finalResponse": record.final_response,
            "totalDurationMs": record.total_duration_ms,
            "createdAt": record.created_at.isoformat(),
        }

    raise HTTPException(status_code=404, detail="Test record not found")
|
|
|
|
|
|
@router.get(
    "/flow-executions",
    operation_id="listFlowTests",
    summary="List flow test records",
    description="[AC-AISVC-95] List flow test execution records.",
)
async def list_flow_tests(
    tenant_id: str = Depends(get_tenant_id),
    session_id: str | None = Query(None, description="Filter by session ID"),
    status: str | None = Query(None, description="Filter by status"),
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(20, ge=1, le=100, description="Page size"),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-95] List flow test execution records.

    Records are retained for 7 days.
    NOTE(review): no retention filter is applied in this handler - confirm
    that cleanup happens elsewhere.

    Results are scoped to the tenant, optionally narrowed by session id and
    status, and returned newest first.

    Args:
        tenant_id: Tenant resolved from the ``X-Tenant-Id`` header.
        session_id: Optional exact-match filter on the record's session id.
        status: Optional exact-match filter on the record's status.
        page: 1-based page number.
        page_size: Number of records per page (1-100).
        session: Database session used for the queries.

    Returns:
        A dict with ``data`` (record summaries), ``page``, ``pageSize`` and
        ``total`` (count of all records matching the filters).
    """
    logger.info(
        f"[AC-AISVC-95] Listing flow tests for tenant={tenant_id}, "
        f"session={session_id}, page={page}"
    )

    stmt = select(FlowTestRecord).where(
        FlowTestRecord.tenant_id == tenant_id,
    )

    if session_id:
        stmt = stmt.where(FlowTestRecord.session_id == session_id)
    if status:
        stmt = stmt.where(FlowTestRecord.status == status)

    # Count against the filtered (but not yet paginated) statement.
    count_stmt = select(func.count()).select_from(stmt.subquery())
    total_result = await session.execute(count_stmt)
    total = total_result.scalar() or 0

    # Tie-break on id: created_at alone is not unique, and without a total
    # order OFFSET/LIMIT can repeat or skip rows between pages.
    stmt = stmt.order_by(desc(FlowTestRecord.created_at), desc(FlowTestRecord.id))
    stmt = stmt.offset((page - 1) * page_size).limit(page_size)

    result = await session.execute(stmt)
    records = result.scalars().all()

    return {
        "data": [
            {
                "testId": str(r.id),
                "sessionId": r.session_id,
                "status": r.status,
                # Guard against a null steps value so one bad row cannot
                # fail the whole listing.
                "stepCount": len(r.steps or []),
                "totalDurationMs": r.total_duration_ms,
                "createdAt": r.created_at.isoformat(),
            }
            for r in records
        ],
        "page": page,
        "pageSize": page_size,
        "total": total,
    }
|
|
|
|
|
|
class CompareRequest(BaseModel):
    """Payload accepted by the comparison test endpoint."""

    # Message sent through both the baseline run and the test run.
    message: str

    # Optional configuration for the baseline run.
    # NOTE(review): not read by the compare handler visible in this module.
    baseline_config: dict[str, Any] | None = None

    # Optional configuration for the test run (same caveat as above).
    test_config: dict[str, Any] | None = None
|
|
|
|
|
|
@router.post(
    "/compare",
    operation_id="compareFlowTest",
    summary="Compare two flow executions",
    description="[AC-AISVC-95] Compare baseline and test configurations.",
)
async def compare_flow_test(
    request: CompareRequest,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-95] Compare two flow executions with different configurations.

    Useful for:
    - A/B testing prompt templates
    - Comparing RAG retrieval strategies
    - Testing guardrail effectiveness

    The same message is run twice through one orchestrator, under two fresh
    session ids (baseline first, then test), and the replies, confidences,
    durations and step lists are returned side by side.

    NOTE(review): ``baseline_config`` and ``test_config`` are not forwarded
    to the orchestrator here, so both runs currently use the same
    configuration. A warning is logged when either is supplied.

    Args:
        request: Message to run plus the (currently unused) configurations.
        tenant_id: Tenant resolved from the ``X-Tenant-Id`` header.
        session: Database session handed to the memory service.

    Returns:
        A dict with ``baseline``, ``test`` and ``comparison`` sections.

    Raises:
        HTTPException: 500 carrying the error text when either run fails;
            an HTTPException raised by a run is passed through unchanged.
    """
    import time

    from app.models import ChatRequest, ChannelType
    from app.services.llm.factory import get_llm_config_manager
    from app.services.memory import MemoryService
    from app.services.orchestrator import OrchestratorService
    from app.services.retrieval.optimized_retriever import get_optimized_retriever

    logger.info(
        f"[AC-AISVC-95] Running comparison test for tenant={tenant_id}"
    )

    # Make the no-op explicit instead of silently ignoring caller input.
    if request.baseline_config is not None or request.test_config is not None:
        logger.warning(
            "[AC-AISVC-95] baseline_config/test_config were supplied but are "
            "not applied; both runs use the same orchestrator configuration"
        )

    baseline_session_id = f"compare_baseline_{uuid.uuid4().hex[:8]}"
    test_session_id = f"compare_test_{uuid.uuid4().hex[:8]}"

    memory_service = MemoryService(session)
    llm_config_manager = get_llm_config_manager()
    llm_client = llm_config_manager.get_client()
    retriever = await get_optimized_retriever()

    orchestrator = OrchestratorService(
        llm_client=llm_client,
        memory_service=memory_service,
        retriever=retriever,
    )

    async def _run_once(run_session_id: str) -> tuple[Any, int]:
        """Run the message once under *run_session_id*; return (result, ms)."""
        chat_request = ChatRequest(
            session_id=run_session_id,
            current_message=request.message,
            channel_type=ChannelType.WECHAT,
            history=[],
        )
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        started = time.perf_counter()
        outcome = await orchestrator.generate(
            tenant_id=tenant_id,
            request=chat_request,
        )
        return outcome, int((time.perf_counter() - started) * 1000)

    def _steps_of(outcome: Any) -> Any:
        """Return the recorded execution steps, or [] without metadata."""
        return outcome.metadata.get("execution_steps", []) if outcome.metadata else []

    try:
        # Sequential on purpose: baseline first, then test, as before.
        baseline_result, baseline_duration = await _run_once(baseline_session_id)
        test_result, test_duration = await _run_once(test_session_id)
    except HTTPException:
        # Preserve any deliberate HTTP status raised further down.
        raise
    except Exception as e:
        logger.exception(f"[AC-AISVC-95] Comparison test failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "baseline": {
            "sessionId": baseline_session_id,
            "reply": baseline_result.reply,
            "confidence": baseline_result.confidence,
            "durationMs": baseline_duration,
            "steps": _steps_of(baseline_result),
        },
        "test": {
            "sessionId": test_session_id,
            "reply": test_result.reply,
            "confidence": test_result.confidence,
            "durationMs": test_duration,
            "steps": _steps_of(test_result),
        },
        "comparison": {
            "durationDiffMs": test_duration - baseline_duration,
            # A missing confidence counts as 0 in the difference.
            "confidenceDiff": (test_result.confidence or 0) - (baseline_result.confidence or 0),
        },
    }
|