# Source: ai-robot-core/ai-service/app/models/mid/schemas.py
# (Repository web-viewer header captured with this file: 422 lines, 21 KiB, Python.)

"""
Mid Platform schemas.
[AC-IDMP-01, AC-IDMP-02, AC-IDMP-07, AC-IDMP-11, AC-IDMP-12, AC-IDMP-15, AC-IDMP-17, AC-IDMP-18, AC-IDMP-19, AC-IDMP-20]
Aligned with spec/intent-driven-mid-platform/openapi.provider.yaml
"""
from __future__ import annotations
import uuid
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class ExecutionMode(str, Enum):
    """[AC-IDMP-02] Execution mode for dialogue response.

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value (for example ``ExecutionMode.AGENT == "agent"``).
    """

    AGENT = "agent"
    MICRO_FLOW = "micro_flow"
    FIXED = "fixed"
    TRANSFER = "transfer"
class HighRiskScenario(str, Enum):
    """[AC-IDMP-20] High risk scenario types for mandatory takeover.

    String-valued enum; each member compares equal to its snake_case value.
    """

    REFUND = "refund"
    COMPLAINT_ESCALATION = "complaint_escalation"
    PRIVACY_SENSITIVE_PROMISE = "privacy_sensitive_promise"
    TRANSFER = "transfer"
class ToolCallStatus(str, Enum):
    """[AC-IDMP-15] Tool call status.

    String-valued enum with four outcomes: ok, timeout, error, rejected.
    """

    OK = "ok"
    TIMEOUT = "timeout"
    ERROR = "error"
    REJECTED = "rejected"
class ToolType(str, Enum):
    """[AC-IDMP-19] Tool type for registry governance.

    String-valued enum distinguishing ``internal`` tools from ``mcp`` tools.
    """

    INTERNAL = "internal"
    MCP = "mcp"
class SessionMode(str, Enum):
    """[AC-IDMP-09] Session mode for bot/human switching.

    Unlike the other enums in this module, the string values are
    upper-case and identical to the member names.
    """

    BOT_ACTIVE = "BOT_ACTIVE"
    HUMAN_ACTIVE = "HUMAN_ACTIVE"
class HistoryMessage(BaseModel):
    """[AC-IDMP-03] History message with only delivered content.

    Both fields are required. ``role`` is typed as a plain string; the
    role names listed in its description are not enforced by this class.
    """

    role: str = Field(..., description="Message role: user, assistant, or human")
    content: str = Field(..., description="Message content")
class InterruptedSegment(BaseModel):
    """[AC-IDMP-04] Interrupted segment for handling user interruption.

    Both fields are required strings.
    """

    segment_id: str = Field(..., description="Segment ID")
    content: str = Field(..., description="Segment content")
class FeatureFlags(BaseModel):
    """[AC-IDMP-17] Feature flags for session-level grayscale and rollback.

    Both flags are optional. When omitted, ``agent_enabled`` defaults to
    True and ``rollback_to_legacy`` defaults to False. The annotations also
    admit ``None``.
    """

    agent_enabled: bool | None = Field(default=True, description="Session-level Agent grayscale switch")
    rollback_to_legacy: bool | None = Field(default=False, description="Force rollback to legacy pipeline")
class HumanizeConfigRequest(BaseModel):
    """[AC-MARH-11] Humanize configuration request.

    All fields are optional and fall back to the defaults shown below.
    The two delay fields are each constrained to be non-negative.
    """

    enabled: bool | None = Field(default=True, description="Enable humanize strategy")
    # NOTE(review): no cross-field check that min_delay_ms <= max_delay_ms is
    # visible in this class - confirm whether the consumer handles that case.
    min_delay_ms: int | None = Field(default=50, ge=0, description="Minimum delay in milliseconds")
    max_delay_ms: int | None = Field(default=500, ge=0, description="Maximum delay in milliseconds")
    # Free-form string; the strategy names in the description are not validated here.
    length_bucket_strategy: str | None = Field(default="simple", description="Strategy: simple or semantic")
class DialogueRequest(BaseModel):
    """[AC-IDMP-01, AC-IDMP-03, AC-IDMP-04, AC-IDMP-17, AC-MARH-11] Dialogue request schema.

    Only ``session_id`` and ``user_message`` are required. ``history``
    defaults to an empty list; every other field defaults to ``None``.
    """

    session_id: str = Field(..., description="Session ID for conversation tracking")
    user_id: str | None = Field(default=None, description="User ID for memory recall and update")
    # Length is validated: between 1 and 2000 characters inclusive.
    user_message: str = Field(..., min_length=1, max_length=2000, description="User message content")
    history: list[HistoryMessage] = Field(default_factory=list, description="Only delivered history")
    interrupted_segments: list[InterruptedSegment] | None = Field(default=None, description="Interrupted segments")
    feature_flags: FeatureFlags | None = Field(default=None, description="Feature flags for grayscale control")
    humanize_config: HumanizeConfigRequest | None = Field(
        default=None, description="Humanize config for segment delay"
    )
    scene: str | None = Field(default=None, description="Scene identifier for KB filtering, e.g., 'open_consult', 'after_sale'")
class Segment(BaseModel):
    """[AC-IDMP-01] Response segment with delay control.

    Only ``text`` is required.
    """

    # A fresh UUID4 string is generated for each instance that omits segment_id.
    segment_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Segment ID")
    text: str = Field(..., description="Segment text content")
    # Non-negative; defaults to 0.
    delay_after: int = Field(default=0, ge=0, description="Delay after this segment in milliseconds")
class TimeoutProfile(BaseModel):
    """[AC-MARH-08, AC-MARH-09] Timeout configuration profile.

    Defaults are 30 s per tool, 60 s per LLM call and 120 s end to end,
    with upper limits of 60 s, 120 s and 180 s respectively.
    """

    # NOTE(review): only upper bounds (le=...) are declared; zero or negative
    # values pass validation as written - confirm whether that is intended.
    per_tool_timeout_ms: int = Field(default=30000, le=60000, description="Per-tool timeout in milliseconds")
    llm_timeout_ms: int = Field(default=60000, le=120000, description="LLM call timeout in milliseconds")
    end_to_end_timeout_ms: int = Field(default=120000, le=180000, description="End-to-end timeout in milliseconds")
class MetricsSnapshot(BaseModel):
    """[AC-IDMP-18] Runtime metrics snapshot.

    Every metric is optional and defaults to ``None``. The four rate
    fields are constrained to the closed interval [0.0, 1.0];
    ``avg_latency_ms`` is only constrained to be non-negative.
    """

    task_completion_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="Task completion rate")
    slot_completion_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="Slot completion rate")
    wrong_transfer_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="Wrong transfer rate")
    no_recall_rate: float | None = Field(default=None, ge=0.0, le=1.0, description="No recall rate")
    avg_latency_ms: float | None = Field(default=None, ge=0.0, description="Average latency in milliseconds")
class ToolCallTrace(BaseModel):
    """[AC-IDMP-15, AC-IDMP-19] Tool call trace for observability.

    ``tool_name``, ``duration_ms`` and ``status`` are required; all other
    fields have defaults.
    """

    tool_name: str = Field(..., description="Tool name")
    tool_type: ToolType | None = Field(default=ToolType.INTERNAL, description="Tool type: internal or mcp")
    registry_version: str | None = Field(default=None, description="Tool registry version")
    auth_applied: bool | None = Field(default=False, description="Whether auth was applied")
    duration_ms: int = Field(..., ge=0, description="Duration in milliseconds")
    status: ToolCallStatus = Field(..., description="Tool call status")
    error_code: str | None = Field(default=None, description="Error code if failed")
    # Digest fields and full payload fields are independent optionals; this
    # class does not derive one from the other.
    args_digest: str | None = Field(default=None, description="Arguments digest for logging")
    result_digest: str | None = Field(default=None, description="Result digest for logging")
    arguments: dict[str, Any] | None = Field(default=None, description="Full tool call arguments")
    # Typed as Any: the result payload is not constrained by this schema.
    result: Any = Field(default=None, description="Full tool call result")
class SegmentStats(BaseModel):
    """[AC-MARH-12] Segment statistics for humanize strategy.

    All fields are optional; the two numeric fields default to zero and
    must be non-negative.
    """

    segment_count: int = Field(default=0, ge=0, description="Number of segments")
    avg_segment_length: float = Field(default=0.0, ge=0.0, description="Average segment length")
    humanize_strategy: str | None = Field(default=None, description="Humanize strategy used")
class TraceInfo(BaseModel):
    """[AC-MARH-02, AC-MARH-03, AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-MARH-11,
    AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20, AC-SCENE-SLOT-02] Trace info for observability.

    ``mode`` is the only required field; every other field has a default.
    """

    mode: ExecutionMode = Field(..., description="Execution mode")
    intent: str | None = Field(default=None, description="Matched intent")
    # request_id and generation_id each default to a freshly generated
    # UUID4 string when the caller does not supply one.
    request_id: str | None = Field(
        default_factory=lambda: str(uuid.uuid4()), description="Request ID"
    )
    generation_id: str | None = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Generation ID for interrupt handling",
    )
    guardrail_triggered: bool | None = Field(default=False, description="Whether guardrail was triggered")
    guardrail_rule_id: str | None = Field(default=None, description="Guardrail rule ID that triggered")
    interrupt_consumed: bool | None = Field(default=False, description="Whether interrupted segments were consumed")
    kb_tool_called: bool | None = Field(default=False, description="Whether KB tool was called")
    kb_hit: bool | None = Field(default=False, description="Whether KB search had results")
    fallback_reason_code: str | None = Field(default=None, description="Fallback reason code")
    # Validated to lie in 0..5 inclusive.
    react_iterations: int | None = Field(default=0, ge=0, le=5, description="ReAct loop iterations")
    timeout_profile: TimeoutProfile | None = Field(default=None, description="Timeout profile")
    segment_stats: SegmentStats | None = Field(default=None, description="Segment statistics")
    metrics_snapshot: MetricsSnapshot | None = Field(default=None, description="Metrics snapshot")
    high_risk_policy_set: list[HighRiskScenario] | None = Field(default=None, description="Active high-risk policy set")
    tools_used: list[str] | None = Field(default=None, description="Tools used in this request")
    tool_calls: list[ToolCallTrace] | None = Field(default=None, description="Tool call traces")
    duration_ms: int = Field(default=0, ge=0, description="Execution duration in milliseconds")
    # Plain string; no timestamp format is enforced by this class.
    created_at: str | None = Field(default=None, description="Creation timestamp")
    # [AC-SCENE-SLOT-02] Scene slot tracking fields
    scene: str | None = Field(default=None, description="当前场景标识")
    scene_slot_context: dict[str, Any] | None = Field(default=None, description="场景槽位上下文信息")
    missing_slots: list[str] | None = Field(default=None, description="缺失的必填槽位列表")
    ask_back_triggered: bool | None = Field(default=False, description="是否触发了追问")
    slot_sources: dict[str, str] | None = Field(default=None, description="槽位值来源映射")
    kb_filter_sources: dict[str, str] | None = Field(default=None, description="KB 过滤条件来源映射")
    # [Step-KB-Binding] Step knowledge-base binding trace
    step_kb_binding: dict[str, Any] | None = Field(default=None, description="步骤知识库绑定信息,包含 step_id, allowed_kb_ids, used_kb_ids 等")
class DialogueResponse(BaseModel):
    """[AC-IDMP-01, AC-IDMP-02] Dialogue response with segments and trace.

    Both fields are required.
    """

    segments: list[Segment] = Field(..., description="Response segments")
    trace: TraceInfo = Field(..., description="Trace info for observability")
class ReportedMessage(BaseModel):
    """[AC-IDMP-08] Reported message for message report API.

    All fields except ``segment_id`` are required. ``role``, ``source``
    and ``timestamp`` are plain strings; the values and ISO format named
    in their descriptions are not validated by this class.
    """

    role: str = Field(..., description="Message role: user, assistant, human, or system")
    content: str = Field(..., description="Message content")
    source: str = Field(..., description="Message source: bot, human, or channel")
    timestamp: str = Field(..., description="Message timestamp in ISO format")
    segment_id: str | None = Field(default=None, description="Segment ID if applicable")
class MessageReportRequest(BaseModel):
    """[AC-IDMP-08] Message report request schema.

    Both fields are required; no minimum length is imposed on ``messages``.
    """

    session_id: str = Field(..., description="Session ID")
    messages: list[ReportedMessage] = Field(..., description="Messages to report")
class SwitchModeRequest(BaseModel):
    """[AC-IDMP-09] Switch session mode request.

    ``mode`` is required; ``reason`` is optional free text.
    """

    mode: SessionMode = Field(..., description="Target mode: BOT_ACTIVE or HUMAN_ACTIVE")
    reason: str | None = Field(default=None, description="Reason for mode switch")
class SwitchModeResponse(BaseModel):
    """[AC-IDMP-09] Switch session mode response.

    Both fields are required.
    """

    session_id: str = Field(..., description="Session ID")
    mode: SessionMode = Field(..., description="Current mode after switch")
class MidSessionState(BaseModel):
    """Internal session state for mid platform.

    ``session_id`` and ``tenant_id`` are required. A new state starts in
    ``SessionMode.BOT_ACTIVE`` unless another mode is given.
    """

    session_id: str
    tenant_id: str
    mode: SessionMode = SessionMode.BOT_ACTIVE
    # A fresh UUID4 string is generated when generation_id is omitted.
    generation_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    active_flow_id: str | None = None
    context: dict[str, Any] | None = None
    # Timestamps are plain optional strings; no format is enforced here.
    created_at: str | None = None
    updated_at: str | None = None
class PolicyRouterResult(BaseModel):
    """[AC-IDMP-02, AC-IDMP-05, AC-IDMP-16] Policy router decision result.

    ``mode`` is required. The mode-specific fields (``target_flow_id``,
    ``fixed_reply``, ``transfer_message``) are independent optionals; this
    class does not check that they match ``mode``.
    """

    mode: ExecutionMode = Field(..., description="Decided execution mode")
    intent: str | None = Field(default=None, description="Matched intent")
    # When provided, must lie in [0.0, 1.0].
    confidence: float | None = Field(default=None, ge=0.0, le=1.0, description="Intent confidence")
    fallback_reason_code: str | None = Field(default=None, description="Fallback reason if applicable")
    high_risk_triggered: bool = Field(default=False, description="Whether high-risk scenario triggered")
    target_flow_id: str | None = Field(default=None, description="Target flow ID for micro_flow mode")
    fixed_reply: str | None = Field(default=None, description="Fixed reply for fixed mode")
    transfer_message: str | None = Field(default=None, description="Transfer message for transfer mode")
class ReActContext(BaseModel):
    """[AC-IDMP-11] ReAct loop context for iteration control.

    All fields have defaults, so a context can be built with no arguments.
    """

    # iteration is validated to 0..5 and max_iterations to 3..5. No check
    # that iteration <= max_iterations is declared in this class.
    iteration: int = Field(default=0, ge=0, le=5, description="Current iteration count")
    max_iterations: int = Field(default=5, ge=3, le=5, description="Maximum iterations allowed")
    tool_calls: list[ToolCallTrace] = Field(default_factory=list, description="Tool call history")
    should_continue: bool = Field(default=True, description="Whether to continue ReAct loop")
    final_answer: str | None = Field(default=None, description="Final answer if completed")
class CreateShareRequest(BaseModel):
    """[AC-IDMP-SHARE] Request to create a shared session.

    Every field is optional. Defaults are a 7-day expiry and 10
    concurrent users.
    """

    title: str | None = Field(default=None, max_length=255, description="Share title")
    description: str | None = Field(default=None, max_length=1000, description="Share description")
    # Validated to 1..365 days.
    expires_in_days: int = Field(default=7, ge=1, le=365, description="Expiration time in days")
    # Validated to 1..100 users.
    max_concurrent_users: int = Field(default=10, ge=1, le=100, description="Maximum concurrent users")
class ShareResponse(BaseModel):
    """[AC-IDMP-SHARE] Response after creating a share.

    ``title`` and ``description`` are optional; the remaining fields are
    required. ``expires_at`` is a plain string whose ISO format is not
    validated by this class.
    """

    share_token: str = Field(..., description="Unique share token")
    share_url: str = Field(..., description="Full share URL")
    expires_at: str = Field(..., description="Expiration time in ISO format")
    title: str | None = Field(default=None, description="Share title")
    description: str | None = Field(default=None, description="Share description")
    max_concurrent_users: int = Field(..., description="Maximum concurrent users")
class SharedSessionInfo(BaseModel):
    """[AC-IDMP-SHARE] Information about a shared session.

    ``title`` and ``description`` are optional and ``history`` defaults
    to an empty list; the remaining fields are required.
    """

    session_id: str = Field(..., description="Session ID")
    title: str | None = Field(default=None, description="Share title")
    description: str | None = Field(default=None, description="Share description")
    expires_at: str = Field(..., description="Expiration time in ISO format")
    # Unlike CreateShareRequest, no numeric bounds are declared on these counts.
    max_concurrent_users: int = Field(..., description="Maximum concurrent users")
    current_users: int = Field(..., description="Current online users")
    history: list[HistoryMessage] = Field(default_factory=list, description="Historical messages")
class SharedMessageRequest(BaseModel):
    """[AC-IDMP-SHARE] Request to send a message via shared session.

    Carries a single required message with the same 1..2000 character
    limit that ``DialogueRequest.user_message`` declares.
    """

    user_message: str = Field(..., min_length=1, max_length=2000, description="User message content")
class ShareListItem(BaseModel):
    """[AC-IDMP-SHARE] Share list item for listing all shares of a session.

    ``title`` and ``description`` are optional; every other field is
    required. Timestamps are plain strings with no format validation here.
    """

    share_token: str = Field(..., description="Share token")
    share_url: str = Field(..., description="Full share URL")
    title: str | None = Field(default=None, description="Share title")
    description: str | None = Field(default=None, description="Share description")
    expires_at: str = Field(..., description="Expiration time in ISO format")
    is_active: bool = Field(..., description="Whether share is active")
    max_concurrent_users: int = Field(..., description="Maximum concurrent users")
    current_users: int = Field(..., description="Current online users")
    created_at: str = Field(..., description="Creation time in ISO format")
class ShareListResponse(BaseModel):
    """[AC-IDMP-SHARE] Response for listing shares.

    ``shares`` is required; an empty list is allowed by the type.
    """

    shares: list[ShareListItem] = Field(..., description="List of shares")
class KbSearchDynamicHit(BaseModel):
    """[AC-MARH-05] Single KB search hit.

    ``id``, ``content`` and ``score`` are required; ``metadata`` defaults
    to an empty dict.
    """

    id: str = Field(..., description="Hit ID")
    content: str = Field(..., description="Hit content")
    # Validated to lie in [0.0, 1.0].
    score: float = Field(..., ge=0.0, le=1.0, description="Relevance score")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Hit metadata")
class MissingRequiredSlot(BaseModel):
    """[AC-MARH-05] Missing required slot info.

    All three fields are required strings.
    """

    field_key: str = Field(..., description="Field key")
    label: str = Field(..., description="Field label")
    reason: str = Field(..., description="Missing reason")
class KbSearchDynamicResultSchema(BaseModel):
    """[AC-MARH-05, AC-MARH-06] KB dynamic search result schema.

    ``success`` is the only required field. Collection fields default to
    empty containers and ``duration_ms`` defaults to 0.
    """

    success: bool = Field(..., description="Whether search succeeded")
    hits: list[KbSearchDynamicHit] = Field(default_factory=list, description="Search hits")
    applied_filter: dict[str, Any] = Field(default_factory=dict, description="Applied filter")
    missing_required_slots: list[MissingRequiredSlot] = Field(
        default_factory=list, description="Missing required slots"
    )
    filter_debug: dict[str, Any] = Field(default_factory=dict, description="Filter debug info")
    fallback_reason_code: str | None = Field(default=None, description="Fallback reason code")
    duration_ms: int = Field(default=0, ge=0, description="Duration in milliseconds")
class IntentHintOutput(BaseModel):
    """[AC-IDMP-02, AC-IDMP-16] Output of the lightweight intent hint tool.

    Every field has a default, so an empty result can be built with no
    arguments.
    """

    intent: str | None = Field(default=None, description="识别到的意图名称")
    # Validated to lie in [0.0, 1.0]; defaults to 0.0.
    confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="置信度 0~1")
    # Free-form string; the values listed in the description are not
    # enforced by this class.
    response_type: str | None = Field(
        default=None,
        description="响应类型: fixed|rag|flow|transfer|null"
    )
    suggested_mode: ExecutionMode | None = Field(
        default=None,
        description="建议执行模式: agent|micro_flow|fixed|transfer"
    )
    target_flow_id: str | None = Field(default=None, description="目标流程IDflow模式")
    target_kb_ids: list[str] | None = Field(default=None, description="目标知识库ID列表")
    fallback_reason_code: str | None = Field(default=None, description="降级原因码")
    high_risk_detected: bool = Field(default=False, description="是否检测到高风险场景")
    duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)")
class HighRiskCheckResult(BaseModel):
    """[AC-IDMP-05, AC-IDMP-20] Output of the high-risk check tool.

    Every field has a default; ``matched`` defaults to False.
    """

    matched: bool = Field(default=False, description="是否命中高风险场景")
    # NOTE(review): the description lists "none", but HighRiskScenario as
    # defined in this module has no such member; the "no scenario" case is
    # expressed here by the None default.
    risk_scenario: HighRiskScenario | None = Field(
        default=None,
        description="风险场景: refund|complaint_escalation|privacy_sensitive_promise|transfer|none"
    )
    # Validated to lie in [0.0, 1.0]; defaults to 0.0.
    confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="置信度 0~1")
    recommended_mode: ExecutionMode | None = Field(
        default=None,
        description="推荐执行模式: micro_flow|transfer|agent"
    )
    rule_id: str | None = Field(default=None, description="匹配的规则ID")
    reason: str | None = Field(default=None, description="匹配原因说明")
    fallback_reason_code: str | None = Field(default=None, description="降级原因码(工具失败时)")
    duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)")
    matched_text: str | None = Field(default=None, description="匹配到的文本片段")
    matched_pattern: str | None = Field(default=None, description="匹配到的模式(关键词或正则)")
class SlotSource(str, Enum):
    """[AC-IDMP-13] Slot source type.

    String-valued enum naming where a slot value came from.
    """

    USER_CONFIRMED = "user_confirmed"
    RULE_EXTRACTED = "rule_extracted"
    LLM_INFERRED = "llm_inferred"
    DEFAULT = "default"
class MemorySlot(BaseModel):
    """[AC-IDMP-13] Information for a single slot.

    ``key`` and ``value`` are required. ``source`` defaults to
    ``SlotSource.DEFAULT`` and ``confidence`` defaults to 1.0.
    """

    key: str = Field(..., description="槽位键名")
    # Typed as Any: the slot value is not constrained by this schema.
    value: Any = Field(..., description="槽位值")
    source: SlotSource = Field(default=SlotSource.DEFAULT, description="槽位来源")
    # Validated to lie in [0.0, 1.0].
    confidence: float = Field(default=1.0, ge=0.0, le=1.0, description="置信度")
    updated_at: str | None = Field(default=None, description="最后更新时间")
class MemoryRecallResult(BaseModel):
    """[AC-IDMP-13] Output of the memory recall tool.

    Holds the recalled profile, facts, preferences, last-session summary
    and structured slots. Every field has a default, so an empty result
    can be built with no arguments.
    """

    profile: dict[str, Any] = Field(default_factory=dict, description="用户基础属性")
    facts: list[str] = Field(default_factory=list, description="事实型记忆列表")
    preferences: dict[str, Any] = Field(default_factory=dict, description="用户偏好")
    last_summary: str | None = Field(default=None, description="最近会话摘要")
    slots: dict[str, MemorySlot] = Field(default_factory=dict, description="结构化槽位")
    missing_slots: list[str] = Field(default_factory=list, description="缺失的必填槽位")
    fallback_reason_code: str | None = Field(default=None, description="降级原因码")
    duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)")

    def get_context_for_prompt(self) -> str:
        """Build the context string to inject into the prompt.

        Each non-empty section becomes one line that starts with its
        bracketed header, in this order: profile, facts, preferences,
        last summary, slots. Items inside a section are joined with an
        explicit separator so that adjacent entries cannot run together.

        Returns:
            The section lines joined by newlines, or an empty string when
            no section has content.
        """
        parts: list[str] = []

        # Profile entries with falsy values (None, "", 0, False, empty
        # containers) are skipped.
        profile_parts = [f"{key}: {value}" for key, value in self.profile.items() if value]
        if profile_parts:
            parts.append("【用户属性】" + ", ".join(profile_parts))

        if self.facts:
            # Only the first five facts are included. Facts are free text,
            # so "; " is used instead of ", " to keep them distinguishable.
            parts.append("【已知事实】" + "; ".join(self.facts[:5]))

        # Same falsy-value filtering as for the profile section.
        pref_parts = [f"{key}: {value}" for key, value in self.preferences.items() if value]
        if pref_parts:
            parts.append("【用户偏好】" + ", ".join(pref_parts))

        if self.last_summary:
            parts.append(f"【上次会话摘要】{self.last_summary}")

        if self.slots:
            # Every slot is rendered, regardless of its value.
            slot_parts = [f"{key}={slot.value}" for key, slot in self.slots.items()]
            parts.append("【已知槽位】" + ", ".join(slot_parts))

        # Joining an empty list already yields "".
        return "\n".join(parts)