ai-robot-core/ai-service/app/models/mid/schemas.py

422 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Mid Platform schemas.
[AC-IDMP-01, AC-IDMP-02, AC-IDMP-07, AC-IDMP-11, AC-IDMP-12, AC-IDMP-15, AC-IDMP-17, AC-IDMP-18, AC-IDMP-19, AC-IDMP-20]
Aligned with spec/intent-driven-mid-platform/openapi.provider.yaml
"""
from __future__ import annotations
import uuid
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class ExecutionMode(str, Enum):
    """[AC-IDMP-02] Execution modes in which a dialogue response is produced.

    The ``str`` mix-in makes every member compare equal to its plain
    string value.
    """

    AGENT = "agent"
    MICRO_FLOW = "micro_flow"
    FIXED = "fixed"
    TRANSFER = "transfer"
class HighRiskScenario(str, Enum):
    """[AC-IDMP-20] High-risk scenario types that require mandatory takeover.

    The ``str`` mix-in makes every member compare equal to its plain
    string value.
    """

    REFUND = "refund"
    COMPLAINT_ESCALATION = "complaint_escalation"
    PRIVACY_SENSITIVE_PROMISE = "privacy_sensitive_promise"
    TRANSFER = "transfer"
class ToolCallStatus(str, Enum):
    """[AC-IDMP-15] Outcome status recorded for a tool call.

    The ``str`` mix-in makes every member compare equal to its plain
    string value.
    """

    OK = "ok"
    TIMEOUT = "timeout"
    ERROR = "error"
    REJECTED = "rejected"
class ToolType(str, Enum):
    """[AC-IDMP-19] Tool type used for registry governance.

    The ``str`` mix-in makes every member compare equal to its plain
    string value.
    """

    INTERNAL = "internal"
    MCP = "mcp"
class SessionMode(str, Enum):
    """[AC-IDMP-09] Session mode used for bot/human switching.

    Unlike the other enums in this module, the values are upper-case and
    identical to the member names.
    """

    BOT_ACTIVE = "BOT_ACTIVE"
    HUMAN_ACTIVE = "HUMAN_ACTIVE"
class HistoryMessage(BaseModel):
    """[AC-IDMP-03] A single history message; carries delivered content only.

    Both fields are required.
    """

    # Free-form string; the schema does not restrict it to the listed roles.
    role: str = Field(
        ...,
        description="Message role: user, assistant, or human",
    )
    content: str = Field(
        ...,
        description="Message content",
    )
class InterruptedSegment(BaseModel):
    """[AC-IDMP-04] A segment that was interrupted by the user.

    Both fields are required.
    """

    segment_id: str = Field(
        ...,
        description="Segment ID",
    )
    content: str = Field(
        ...,
        description="Segment content",
    )
class FeatureFlags(BaseModel):
    """[AC-IDMP-17] Session-level feature flags for grayscale and rollback.

    Both flags are optional and additionally accept an explicit ``None``.
    """

    # Defaults to enabled.
    agent_enabled: bool | None = Field(
        default=True,
        description="Session-level Agent grayscale switch",
    )
    # Defaults to no rollback.
    rollback_to_legacy: bool | None = Field(
        default=False,
        description="Force rollback to legacy pipeline",
    )
class HumanizeConfigRequest(BaseModel):
    """[AC-MARH-11] Humanize configuration request.

    Every field is optional and has a default; delays are validated as
    non-negative integers.
    """

    enabled: bool | None = Field(
        default=True,
        description="Enable humanize strategy",
    )
    # NOTE(review): nothing in this model enforces min_delay_ms <= max_delay_ms;
    # confirm whether consumers validate the ordering.
    min_delay_ms: int | None = Field(
        default=50,
        ge=0,
        description="Minimum delay in milliseconds",
    )
    max_delay_ms: int | None = Field(
        default=500,
        ge=0,
        description="Maximum delay in milliseconds",
    )
    # Free-form string; not restricted to the values named in the description.
    length_bucket_strategy: str | None = Field(
        default="simple",
        description="Strategy: simple or semantic",
    )
class DialogueRequest(BaseModel):
    """[AC-IDMP-01, AC-IDMP-03, AC-IDMP-04, AC-IDMP-17, AC-MARH-11] Dialogue request.

    Only ``session_id`` and ``user_message`` are required; every other field
    has a default.
    """

    session_id: str = Field(
        ...,
        description="Session ID for conversation tracking",
    )
    user_id: str | None = Field(
        default=None,
        description="User ID for memory recall and update",
    )
    # Length is validated to 1..2000 characters.
    user_message: str = Field(
        ...,
        min_length=1,
        max_length=2000,
        description="User message content",
    )
    # A fresh empty list is created per instance when omitted.
    history: list[HistoryMessage] = Field(
        default_factory=list,
        description="Only delivered history",
    )
    interrupted_segments: list[InterruptedSegment] | None = Field(
        default=None,
        description="Interrupted segments",
    )
    feature_flags: FeatureFlags | None = Field(
        default=None,
        description="Feature flags for grayscale control",
    )
    humanize_config: HumanizeConfigRequest | None = Field(
        default=None,
        description="Humanize config for segment delay",
    )
    scene: str | None = Field(
        default=None,
        description="Scene identifier for KB filtering, e.g., 'open_consult', 'after_sale'",
    )
class Segment(BaseModel):
    """[AC-IDMP-01] One response segment together with its delay control.

    Only ``text`` is required.
    """

    # A random UUID4 string is generated per instance when no ID is supplied.
    segment_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Segment ID",
    )
    text: str = Field(
        ...,
        description="Segment text content",
    )
    delay_after: int = Field(
        default=0,
        ge=0,
        description="Delay after this segment in milliseconds",
    )
class TimeoutProfile(BaseModel):
    """[AC-MARH-08, AC-MARH-09] Timeout configuration profile.

    All values are in milliseconds and each has an upper bound.
    """

    # NOTE(review): only upper bounds (``le``) are declared, so zero and
    # negative timeouts pass validation — confirm whether a lower bound is
    # intended.
    per_tool_timeout_ms: int = Field(
        default=30000,
        le=60000,
        description="Per-tool timeout in milliseconds",
    )
    llm_timeout_ms: int = Field(
        default=60000,
        le=120000,
        description="LLM call timeout in milliseconds",
    )
    end_to_end_timeout_ms: int = Field(
        default=120000,
        le=180000,
        description="End-to-end timeout in milliseconds",
    )
class MetricsSnapshot(BaseModel):
    """[AC-IDMP-18] Snapshot of runtime metrics.

    Every metric is optional and defaults to ``None``.
    """

    # Rates are validated to lie within [0.0, 1.0].
    task_completion_rate: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Task completion rate",
    )
    slot_completion_rate: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Slot completion rate",
    )
    wrong_transfer_rate: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Wrong transfer rate",
    )
    no_recall_rate: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="No recall rate",
    )
    # Latency is only bounded from below.
    avg_latency_ms: float | None = Field(
        default=None,
        ge=0.0,
        description="Average latency in milliseconds",
    )
class ToolCallTrace(BaseModel):
    """[AC-IDMP-15, AC-IDMP-19] Trace of one tool call, for observability.

    ``tool_name``, ``duration_ms`` and ``status`` are required; the remaining
    fields have defaults.
    """

    # What was called
    tool_name: str = Field(
        ...,
        description="Tool name",
    )
    tool_type: ToolType | None = Field(
        default=ToolType.INTERNAL,
        description="Tool type: internal or mcp",
    )
    registry_version: str | None = Field(
        default=None,
        description="Tool registry version",
    )
    auth_applied: bool | None = Field(
        default=False,
        description="Whether auth was applied",
    )

    # How it went
    duration_ms: int = Field(
        ...,
        ge=0,
        description="Duration in milliseconds",
    )
    status: ToolCallStatus = Field(
        ...,
        description="Tool call status",
    )
    error_code: str | None = Field(
        default=None,
        description="Error code if failed",
    )

    # Digests for logging, alongside the full payloads
    args_digest: str | None = Field(
        default=None,
        description="Arguments digest for logging",
    )
    result_digest: str | None = Field(
        default=None,
        description="Result digest for logging",
    )
    arguments: dict[str, Any] | None = Field(
        default=None,
        description="Full tool call arguments",
    )
    result: Any = Field(
        default=None,
        description="Full tool call result",
    )
class SegmentStats(BaseModel):
    """[AC-MARH-12] Segment statistics for the humanize strategy.

    All fields have defaults; the numeric ones are validated as non-negative.
    """

    segment_count: int = Field(
        default=0,
        ge=0,
        description="Number of segments",
    )
    avg_segment_length: float = Field(
        default=0.0,
        ge=0.0,
        description="Average segment length",
    )
    humanize_strategy: str | None = Field(
        default=None,
        description="Humanize strategy used",
    )
class TraceInfo(BaseModel):
    """[AC-MARH-02, AC-MARH-03, AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-MARH-11,
    AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20, AC-SCENE-SLOT-02] Trace info for observability.

    ``mode`` is the only required field; everything else has a default.
    """

    mode: ExecutionMode = Field(
        ...,
        description="Execution mode",
    )
    intent: str | None = Field(
        default=None,
        description="Matched intent",
    )

    # Both IDs default to a freshly generated UUID4 string per instance.
    request_id: str | None = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Request ID",
    )
    generation_id: str | None = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Generation ID for interrupt handling",
    )

    guardrail_triggered: bool | None = Field(
        default=False,
        description="Whether guardrail was triggered",
    )
    guardrail_rule_id: str | None = Field(
        default=None,
        description="Guardrail rule ID that triggered",
    )
    interrupt_consumed: bool | None = Field(
        default=False,
        description="Whether interrupted segments were consumed",
    )
    kb_tool_called: bool | None = Field(
        default=False,
        description="Whether KB tool was called",
    )
    kb_hit: bool | None = Field(
        default=False,
        description="Whether KB search had results",
    )
    fallback_reason_code: str | None = Field(
        default=None,
        description="Fallback reason code",
    )
    # Validated to the range 0..5.
    react_iterations: int | None = Field(
        default=0,
        ge=0,
        le=5,
        description="ReAct loop iterations",
    )

    timeout_profile: TimeoutProfile | None = Field(
        default=None,
        description="Timeout profile",
    )
    segment_stats: SegmentStats | None = Field(
        default=None,
        description="Segment statistics",
    )
    metrics_snapshot: MetricsSnapshot | None = Field(
        default=None,
        description="Metrics snapshot",
    )
    high_risk_policy_set: list[HighRiskScenario] | None = Field(
        default=None,
        description="Active high-risk policy set",
    )
    tools_used: list[str] | None = Field(
        default=None,
        description="Tools used in this request",
    )
    tool_calls: list[ToolCallTrace] | None = Field(
        default=None,
        description="Tool call traces",
    )
    duration_ms: int = Field(
        default=0,
        ge=0,
        description="Execution duration in milliseconds",
    )
    created_at: str | None = Field(
        default=None,
        description="Creation timestamp",
    )

    # [AC-SCENE-SLOT-02] Scene slot tracking fields
    scene: str | None = Field(
        default=None,
        description="当前场景标识",
    )
    scene_slot_context: dict[str, Any] | None = Field(
        default=None,
        description="场景槽位上下文信息",
    )
    missing_slots: list[str] | None = Field(
        default=None,
        description="缺失的必填槽位列表",
    )
    ask_back_triggered: bool | None = Field(
        default=False,
        description="是否触发了追问",
    )
    slot_sources: dict[str, str] | None = Field(
        default=None,
        description="槽位值来源映射",
    )
    kb_filter_sources: dict[str, str] | None = Field(
        default=None,
        description="KB 过滤条件来源映射",
    )

    # [Step-KB-Binding] Step knowledge-base binding tracking
    step_kb_binding: dict[str, Any] | None = Field(
        default=None,
        description="步骤知识库绑定信息,包含 step_id, allowed_kb_ids, used_kb_ids 等",
    )
class DialogueResponse(BaseModel):
    """[AC-IDMP-01, AC-IDMP-02] Dialogue response: segments plus trace.

    Both fields are required.
    """

    segments: list[Segment] = Field(
        ...,
        description="Response segments",
    )
    trace: TraceInfo = Field(
        ...,
        description="Trace info for observability",
    )
class ReportedMessage(BaseModel):
    """[AC-IDMP-08] One message submitted through the message report API.

    Every field except ``segment_id`` is required.
    """

    # role, source and timestamp are plain strings; the schema does not
    # validate them against the values/format named in their descriptions.
    role: str = Field(
        ...,
        description="Message role: user, assistant, human, or system",
    )
    content: str = Field(
        ...,
        description="Message content",
    )
    source: str = Field(
        ...,
        description="Message source: bot, human, or channel",
    )
    timestamp: str = Field(
        ...,
        description="Message timestamp in ISO format",
    )
    segment_id: str | None = Field(
        default=None,
        description="Segment ID if applicable",
    )
class MessageReportRequest(BaseModel):
    """[AC-IDMP-08] Request body for the message report API.

    Both fields are required.
    """

    session_id: str = Field(
        ...,
        description="Session ID",
    )
    messages: list[ReportedMessage] = Field(
        ...,
        description="Messages to report",
    )
class SwitchModeRequest(BaseModel):
    """[AC-IDMP-09] Request to switch a session's mode.

    ``mode`` is required; ``reason`` is optional.
    """

    mode: SessionMode = Field(
        ...,
        description="Target mode: BOT_ACTIVE or HUMAN_ACTIVE",
    )
    reason: str | None = Field(
        default=None,
        description="Reason for mode switch",
    )
class SwitchModeResponse(BaseModel):
    """[AC-IDMP-09] Response returned after a session mode switch.

    Both fields are required.
    """

    session_id: str = Field(
        ...,
        description="Session ID",
    )
    mode: SessionMode = Field(
        ...,
        description="Current mode after switch",
    )
class MidSessionState(BaseModel):
    """Internal session state for the mid platform.

    ``session_id`` and ``tenant_id`` are required; the rest have defaults.
    """

    session_id: str
    tenant_id: str

    # A new state starts in BOT_ACTIVE unless told otherwise.
    mode: SessionMode = SessionMode.BOT_ACTIVE
    # A random UUID4 string is generated per instance when omitted.
    generation_id: str = Field(default_factory=lambda: str(uuid.uuid4()))

    active_flow_id: str | None = None
    context: dict[str, Any] | None = None

    # Timestamps are stored as strings; no format is enforced here.
    created_at: str | None = None
    updated_at: str | None = None
class PolicyRouterResult(BaseModel):
    """[AC-IDMP-02, AC-IDMP-05, AC-IDMP-16] Decision produced by the policy router.

    ``mode`` is required; all other fields have defaults.
    """

    mode: ExecutionMode = Field(
        ...,
        description="Decided execution mode",
    )
    intent: str | None = Field(
        default=None,
        description="Matched intent",
    )
    # Validated to lie within [0.0, 1.0] when provided.
    confidence: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Intent confidence",
    )
    fallback_reason_code: str | None = Field(
        default=None,
        description="Fallback reason if applicable",
    )
    high_risk_triggered: bool = Field(
        default=False,
        description="Whether high-risk scenario triggered",
    )

    # Mode-specific payloads, per their descriptions.
    target_flow_id: str | None = Field(
        default=None,
        description="Target flow ID for micro_flow mode",
    )
    fixed_reply: str | None = Field(
        default=None,
        description="Fixed reply for fixed mode",
    )
    transfer_message: str | None = Field(
        default=None,
        description="Transfer message for transfer mode",
    )
class ReActContext(BaseModel):
    """[AC-IDMP-11] ReAct loop context used for iteration control.

    Every field has a default.
    """

    # Validated to the range 0..5.
    iteration: int = Field(
        default=0,
        ge=0,
        le=5,
        description="Current iteration count",
    )
    # Validated to the range 3..5.
    max_iterations: int = Field(
        default=5,
        ge=3,
        le=5,
        description="Maximum iterations allowed",
    )
    # A fresh empty list is created per instance when omitted.
    tool_calls: list[ToolCallTrace] = Field(
        default_factory=list,
        description="Tool call history",
    )
    should_continue: bool = Field(
        default=True,
        description="Whether to continue ReAct loop",
    )
    final_answer: str | None = Field(
        default=None,
        description="Final answer if completed",
    )
class CreateShareRequest(BaseModel):
    """[AC-IDMP-SHARE] Request to create a shared session.

    Every field has a default.
    """

    # At most 255 characters.
    title: str | None = Field(
        default=None,
        max_length=255,
        description="Share title",
    )
    # At most 1000 characters.
    description: str | None = Field(
        default=None,
        max_length=1000,
        description="Share description",
    )
    # Validated to the range 1..365.
    expires_in_days: int = Field(
        default=7,
        ge=1,
        le=365,
        description="Expiration time in days",
    )
    # Validated to the range 1..100.
    max_concurrent_users: int = Field(
        default=10,
        ge=1,
        le=100,
        description="Maximum concurrent users",
    )
class ShareResponse(BaseModel):
    """[AC-IDMP-SHARE] Response returned after a share is created.

    ``title`` and ``description`` are optional; the other fields are required.
    """

    share_token: str = Field(
        ...,
        description="Unique share token",
    )
    share_url: str = Field(
        ...,
        description="Full share URL",
    )
    expires_at: str = Field(
        ...,
        description="Expiration time in ISO format",
    )
    title: str | None = Field(
        default=None,
        description="Share title",
    )
    description: str | None = Field(
        default=None,
        description="Share description",
    )
    max_concurrent_users: int = Field(
        ...,
        description="Maximum concurrent users",
    )
class SharedSessionInfo(BaseModel):
    """[AC-IDMP-SHARE] Information about a shared session.

    ``title``, ``description`` and ``history`` have defaults; the other
    fields are required.
    """

    session_id: str = Field(
        ...,
        description="Session ID",
    )
    title: str | None = Field(
        default=None,
        description="Share title",
    )
    description: str | None = Field(
        default=None,
        description="Share description",
    )
    expires_at: str = Field(
        ...,
        description="Expiration time in ISO format",
    )
    max_concurrent_users: int = Field(
        ...,
        description="Maximum concurrent users",
    )
    current_users: int = Field(
        ...,
        description="Current online users",
    )
    # A fresh empty list is created per instance when omitted.
    history: list[HistoryMessage] = Field(
        default_factory=list,
        description="Historical messages",
    )
class SharedMessageRequest(BaseModel):
    """[AC-IDMP-SHARE] Request to send a message through a shared session."""

    # Required; length is validated to 1..2000 characters.
    user_message: str = Field(
        ...,
        min_length=1,
        max_length=2000,
        description="User message content",
    )
class ShareListItem(BaseModel):
    """[AC-IDMP-SHARE] One entry in the list of shares for a session.

    ``title`` and ``description`` are optional; the other fields are required.
    """

    share_token: str = Field(
        ...,
        description="Share token",
    )
    share_url: str = Field(
        ...,
        description="Full share URL",
    )
    title: str | None = Field(
        default=None,
        description="Share title",
    )
    description: str | None = Field(
        default=None,
        description="Share description",
    )
    expires_at: str = Field(
        ...,
        description="Expiration time in ISO format",
    )
    is_active: bool = Field(
        ...,
        description="Whether share is active",
    )
    max_concurrent_users: int = Field(
        ...,
        description="Maximum concurrent users",
    )
    current_users: int = Field(
        ...,
        description="Current online users",
    )
    created_at: str = Field(
        ...,
        description="Creation time in ISO format",
    )
class ShareListResponse(BaseModel):
    """[AC-IDMP-SHARE] Response wrapper for listing shares."""

    # Required.
    shares: list[ShareListItem] = Field(
        ...,
        description="List of shares",
    )
class KbSearchDynamicHit(BaseModel):
    """[AC-MARH-05] A single knowledge-base search hit.

    ``id``, ``content`` and ``score`` are required.
    """

    id: str = Field(
        ...,
        description="Hit ID",
    )
    content: str = Field(
        ...,
        description="Hit content",
    )
    # Validated to lie within [0.0, 1.0].
    score: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Relevance score",
    )
    # A fresh empty dict is created per instance when omitted.
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Hit metadata",
    )
class MissingRequiredSlot(BaseModel):
    """[AC-MARH-05] Details of one missing required slot.

    All three fields are required.
    """

    field_key: str = Field(
        ...,
        description="Field key",
    )
    label: str = Field(
        ...,
        description="Field label",
    )
    reason: str = Field(
        ...,
        description="Missing reason",
    )
class KbSearchDynamicResultSchema(BaseModel):
    """[AC-MARH-05, AC-MARH-06] Result schema of the KB dynamic search.

    ``success`` is required; the collection fields default to fresh empty
    containers.
    """

    success: bool = Field(
        ...,
        description="Whether search succeeded",
    )
    hits: list[KbSearchDynamicHit] = Field(
        default_factory=list,
        description="Search hits",
    )
    applied_filter: dict[str, Any] = Field(
        default_factory=dict,
        description="Applied filter",
    )
    missing_required_slots: list[MissingRequiredSlot] = Field(
        default_factory=list,
        description="Missing required slots",
    )
    filter_debug: dict[str, Any] = Field(
        default_factory=dict,
        description="Filter debug info",
    )
    fallback_reason_code: str | None = Field(
        default=None,
        description="Fallback reason code",
    )
    duration_ms: int = Field(
        default=0,
        ge=0,
        description="Duration in milliseconds",
    )
class IntentHintOutput(BaseModel):
    """[AC-IDMP-02, AC-IDMP-16] Output of the lightweight intent hint tool.

    Every field has a default.
    """

    intent: str | None = Field(
        default=None,
        description="识别到的意图名称",
    )
    # Validated to lie within [0.0, 1.0].
    confidence: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="置信度 0~1",
    )
    # Free-form string; not restricted to the values named in the description.
    response_type: str | None = Field(
        default=None,
        description="响应类型: fixed|rag|flow|transfer|null",
    )
    suggested_mode: ExecutionMode | None = Field(
        default=None,
        description="建议执行模式: agent|micro_flow|fixed|transfer",
    )
    target_flow_id: str | None = Field(
        default=None,
        description="目标流程IDflow模式",
    )
    target_kb_ids: list[str] | None = Field(
        default=None,
        description="目标知识库ID列表",
    )
    fallback_reason_code: str | None = Field(
        default=None,
        description="降级原因码",
    )
    high_risk_detected: bool = Field(
        default=False,
        description="是否检测到高风险场景",
    )
    duration_ms: int = Field(
        default=0,
        ge=0,
        description="执行耗时(毫秒)",
    )
class HighRiskCheckResult(BaseModel):
    """[AC-IDMP-05, AC-IDMP-20] Output of the high-risk detection tool.

    Every field has a default.
    """

    matched: bool = Field(
        default=False,
        description="是否命中高风险场景",
    )
    # NOTE(review): the description lists "none", which is not a
    # HighRiskScenario member; "no scenario" is presumably expressed as
    # None — confirm.
    risk_scenario: HighRiskScenario | None = Field(
        default=None,
        description="风险场景: refund|complaint_escalation|privacy_sensitive_promise|transfer|none",
    )
    # Validated to lie within [0.0, 1.0].
    confidence: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="置信度 0~1",
    )
    recommended_mode: ExecutionMode | None = Field(
        default=None,
        description="推荐执行模式: micro_flow|transfer|agent",
    )
    rule_id: str | None = Field(
        default=None,
        description="匹配的规则ID",
    )
    reason: str | None = Field(
        default=None,
        description="匹配原因说明",
    )
    fallback_reason_code: str | None = Field(
        default=None,
        description="降级原因码(工具失败时)",
    )
    duration_ms: int = Field(
        default=0,
        ge=0,
        description="执行耗时(毫秒)",
    )
    matched_text: str | None = Field(
        default=None,
        description="匹配到的文本片段",
    )
    matched_pattern: str | None = Field(
        default=None,
        description="匹配到的模式(关键词或正则)",
    )
class SlotSource(str, Enum):
    """[AC-IDMP-13] Source type of a slot value.

    The ``str`` mix-in makes every member compare equal to its plain
    string value.
    """

    USER_CONFIRMED = "user_confirmed"
    RULE_EXTRACTED = "rule_extracted"
    LLM_INFERRED = "llm_inferred"
    DEFAULT = "default"
class MemorySlot(BaseModel):
    """[AC-IDMP-13] Information about a single slot.

    ``key`` and ``value`` are required; the remaining fields have defaults.
    """

    key: str = Field(
        ...,
        description="槽位键名",
    )
    value: Any = Field(
        ...,
        description="槽位值",
    )
    source: SlotSource = Field(
        default=SlotSource.DEFAULT,
        description="槽位来源",
    )
    # Validated to lie within [0.0, 1.0]; defaults to full confidence.
    confidence: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="置信度",
    )
    updated_at: str | None = Field(
        default=None,
        description="最后更新时间",
    )
class MemoryRecallResult(BaseModel):
    """[AC-IDMP-13] Output of the memory recall tool.

    Every field has a default; collection fields default to fresh empty
    containers.
    """

    profile: dict[str, Any] = Field(
        default_factory=dict,
        description="用户基础属性",
    )
    facts: list[str] = Field(
        default_factory=list,
        description="事实型记忆列表",
    )
    preferences: dict[str, Any] = Field(
        default_factory=dict,
        description="用户偏好",
    )
    last_summary: str | None = Field(
        default=None,
        description="最近会话摘要",
    )
    slots: dict[str, MemorySlot] = Field(
        default_factory=dict,
        description="结构化槽位",
    )
    missing_slots: list[str] = Field(
        default_factory=list,
        description="缺失的必填槽位",
    )
    fallback_reason_code: str | None = Field(
        default=None,
        description="降级原因码",
    )
    duration_ms: int = Field(
        default=0,
        ge=0,
        description="执行耗时(毫秒)",
    )

    def get_context_for_prompt(self) -> str:
        """Build the context string that gets injected into the prompt.

        Each memory category with something to show becomes one line that
        starts with the category label. Lines appear in this order: profile,
        facts, preferences, last summary, slots.

        Returns:
            The lines joined by newlines, or an empty string when no
            category has content.
        """
        sections: list[str] = []

        # Profile entries with falsy values (None, "", 0, False, ...) are
        # skipped. Entries are joined with "; " so that adjacent items stay
        # distinguishable instead of running together.
        profile_items = [
            f"{key}: {value}" for key, value in self.profile.items() if value
        ]
        if profile_items:
            sections.append("【用户属性】" + "; ".join(profile_items))

        # Only the first five facts are included.
        if self.facts:
            sections.append("【已知事实】" + "; ".join(self.facts[:5]))

        # Same falsy-value filtering and separator as the profile section.
        preference_items = [
            f"{key}: {value}" for key, value in self.preferences.items() if value
        ]
        if preference_items:
            sections.append("【用户偏好】" + "; ".join(preference_items))

        if self.last_summary:
            sections.append(f"【上次会话摘要】{self.last_summary}")

        # Slots are rendered as key=value regardless of the slot's value.
        slot_items = [f"{key}={slot.value}" for key, slot in self.slots.items()]
        if slot_items:
            sections.append("【已知槽位】" + ", ".join(slot_items))

        # Joining an empty list already yields "".
        return "\n".join(sections)