422 lines
21 KiB
Python
422 lines
21 KiB
Python
"""
Mid Platform request/response schemas.

[AC-IDMP-01, AC-IDMP-02, AC-IDMP-07, AC-IDMP-11, AC-IDMP-12, AC-IDMP-15, AC-IDMP-17, AC-IDMP-18, AC-IDMP-19, AC-IDMP-20]

Aligned with spec/intent-driven-mid-platform/openapi.provider.yaml
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import uuid
|
||
from enum import Enum
|
||
from typing import Any
|
||
|
||
from pydantic import BaseModel, Field
|
||
|
||
|
||
class ExecutionMode(str, Enum):
    """[AC-IDMP-02] Execution mode chosen for a dialogue response.

    Because the enum mixes in ``str``, every member compares equal to its
    string value (e.g. ``ExecutionMode.AGENT == "agent"``).
    """

    AGENT = "agent"
    MICRO_FLOW = "micro_flow"
    FIXED = "fixed"
    TRANSFER = "transfer"
|
||
|
||
|
||
class HighRiskScenario(str, Enum):
    """[AC-IDMP-20] High-risk scenario categories that force a mandatory takeover.

    String-valued enum: members compare equal to their lowercase values.
    """

    REFUND = "refund"
    COMPLAINT_ESCALATION = "complaint_escalation"
    PRIVACY_SENSITIVE_PROMISE = "privacy_sensitive_promise"
    TRANSFER = "transfer"
|
||
|
||
|
||
class ToolCallStatus(str, Enum):
    """[AC-IDMP-15] Outcome status recorded for a single tool call.

    String-valued enum: members compare equal to their lowercase values.
    """

    OK = "ok"
    TIMEOUT = "timeout"
    ERROR = "error"
    REJECTED = "rejected"
|
||
|
||
|
||
class ToolType(str, Enum):
    """[AC-IDMP-19] Kind of tool, used for tool-registry governance.

    String-valued enum: members compare equal to their lowercase values.
    """

    INTERNAL = "internal"
    MCP = "mcp"
|
||
|
||
|
||
class SessionMode(str, Enum):
    """[AC-IDMP-09] Which party currently drives a session: the bot or a human.

    Unlike the other enums in this module, the values are upper-case and
    identical to the member names.
    """

    BOT_ACTIVE = "BOT_ACTIVE"
    HUMAN_ACTIVE = "HUMAN_ACTIVE"
|
||
|
||
|
||
class HistoryMessage(BaseModel):
    """[AC-IDMP-03] One history entry; only content actually delivered belongs here."""

    role: str = Field(
        ...,
        description="Message role: user, assistant, or human",
    )
    content: str = Field(
        ...,
        description="Message content",
    )
|
||
|
||
|
||
class InterruptedSegment(BaseModel):
    """[AC-IDMP-04] A response segment that was cut off by a user interruption."""

    segment_id: str = Field(
        ...,
        description="Segment ID",
    )
    content: str = Field(
        ...,
        description="Segment content",
    )
|
||
|
||
|
||
class FeatureFlags(BaseModel):
    """[AC-IDMP-17] Per-session switches for grayscale rollout and rollback.

    Both flags are optional; when omitted, ``agent_enabled`` defaults to True
    and ``rollback_to_legacy`` defaults to False.
    """

    agent_enabled: bool | None = Field(
        default=True,
        description="Session-level Agent grayscale switch",
    )
    rollback_to_legacy: bool | None = Field(
        default=False,
        description="Force rollback to legacy pipeline",
    )
|
||
|
||
|
||
class HumanizeConfigRequest(BaseModel):
    """[AC-MARH-11] Humanize (human-like pacing) configuration request.

    Delays are in milliseconds and must be non-negative.
    NOTE(review): nothing enforces ``min_delay_ms <= max_delay_ms``.
    """

    enabled: bool | None = Field(
        default=True,
        description="Enable humanize strategy",
    )
    min_delay_ms: int | None = Field(
        default=50,
        ge=0,
        description="Minimum delay in milliseconds",
    )
    max_delay_ms: int | None = Field(
        default=500,
        ge=0,
        description="Maximum delay in milliseconds",
    )
    # Free-form string; the description lists "simple" and "semantic".
    length_bucket_strategy: str | None = Field(
        default="simple",
        description="Strategy: simple or semantic",
    )
|
||
|
||
|
||
class DialogueRequest(BaseModel):
    """[AC-IDMP-01, AC-IDMP-03, AC-IDMP-04, AC-IDMP-17, AC-MARH-11] Inbound dialogue request.

    Only ``session_id`` and ``user_message`` (1-2000 characters) are required;
    every other field is optional.
    """

    session_id: str = Field(..., description="Session ID for conversation tracking")
    user_id: str | None = Field(
        default=None,
        description="User ID for memory recall and update",
    )
    user_message: str = Field(
        ...,
        min_length=1,
        max_length=2000,
        description="User message content",
    )
    history: list[HistoryMessage] = Field(
        default_factory=list,
        description="Only delivered history",
    )
    interrupted_segments: list[InterruptedSegment] | None = Field(
        default=None,
        description="Interrupted segments",
    )
    feature_flags: FeatureFlags | None = Field(
        default=None,
        description="Feature flags for grayscale control",
    )
    humanize_config: HumanizeConfigRequest | None = Field(
        default=None,
        description="Humanize config for segment delay",
    )
    scene: str | None = Field(
        default=None,
        description="Scene identifier for KB filtering, e.g., 'open_consult', 'after_sale'",
    )
|
||
|
||
|
||
class Segment(BaseModel):
    """[AC-IDMP-01] One chunk of response text plus the pause that follows it.

    When not supplied, ``segment_id`` is a fresh UUID4 string.
    """

    segment_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Segment ID",
    )
    text: str = Field(..., description="Segment text content")
    delay_after: int = Field(
        default=0,
        ge=0,
        description="Delay after this segment in milliseconds",
    )
|
||
|
||
|
||
class TimeoutProfile(BaseModel):
    """[AC-MARH-08, AC-MARH-09] Timeout budgets for one request, in milliseconds.

    Each budget has an upper cap (60s per tool, 120s per LLM call, 180s end
    to end). Each must also be strictly positive; previously zero and
    negative values passed validation even though they are not meaningful
    timeouts.
    """

    per_tool_timeout_ms: int = Field(
        default=30000,
        gt=0,
        le=60000,
        description="Per-tool timeout in milliseconds",
    )
    llm_timeout_ms: int = Field(
        default=60000,
        gt=0,
        le=120000,
        description="LLM call timeout in milliseconds",
    )
    end_to_end_timeout_ms: int = Field(
        default=120000,
        gt=0,
        le=180000,
        description="End-to-end timeout in milliseconds",
    )
|
||
|
||
|
||
class MetricsSnapshot(BaseModel):
    """[AC-IDMP-18] Point-in-time runtime metrics.

    All metrics are optional. The four ``*_rate`` fields are ratios in
    [0, 1]; ``avg_latency_ms`` is any non-negative number of milliseconds.
    """

    task_completion_rate: float | None = Field(
        default=None, ge=0.0, le=1.0, description="Task completion rate"
    )
    slot_completion_rate: float | None = Field(
        default=None, ge=0.0, le=1.0, description="Slot completion rate"
    )
    wrong_transfer_rate: float | None = Field(
        default=None, ge=0.0, le=1.0, description="Wrong transfer rate"
    )
    no_recall_rate: float | None = Field(
        default=None, ge=0.0, le=1.0, description="No recall rate"
    )
    avg_latency_ms: float | None = Field(
        default=None, ge=0.0, description="Average latency in milliseconds"
    )
|
||
|
||
|
||
class ToolCallTrace(BaseModel):
    """[AC-IDMP-15, AC-IDMP-19] Observability record for a single tool invocation.

    Required fields are ``tool_name``, ``duration_ms`` (>= 0) and ``status``.
    Fields are grouped below as follows:

    - governance metadata: type, registry version, auth;
    - outcome: duration, status, error code;
    - logging digests;
    - raw payloads.
    """

    # --- identity & governance ---
    tool_name: str = Field(..., description="Tool name")
    tool_type: ToolType | None = Field(
        default=ToolType.INTERNAL,
        description="Tool type: internal or mcp",
    )
    registry_version: str | None = Field(default=None, description="Tool registry version")
    auth_applied: bool | None = Field(default=False, description="Whether auth was applied")

    # --- outcome ---
    duration_ms: int = Field(..., ge=0, description="Duration in milliseconds")
    status: ToolCallStatus = Field(..., description="Tool call status")
    error_code: str | None = Field(default=None, description="Error code if failed")

    # --- digests for logging ---
    args_digest: str | None = Field(default=None, description="Arguments digest for logging")
    result_digest: str | None = Field(default=None, description="Result digest for logging")

    # --- full payloads ---
    arguments: dict[str, Any] | None = Field(default=None, description="Full tool call arguments")
    result: Any = Field(default=None, description="Full tool call result")
|
||
|
||
|
||
class SegmentStats(BaseModel):
    """[AC-MARH-12] Summary of how a response was segmented by the humanize strategy."""

    segment_count: int = Field(
        default=0,
        ge=0,
        description="Number of segments",
    )
    avg_segment_length: float = Field(
        default=0.0,
        ge=0.0,
        description="Average segment length",
    )
    humanize_strategy: str | None = Field(
        default=None,
        description="Humanize strategy used",
    )
|
||
|
||
|
||
class TraceInfo(BaseModel):
    """Observability trace attached to every dialogue response.

    Covers [AC-MARH-02, AC-MARH-03, AC-MARH-05, AC-MARH-06, AC-MARH-07,
    AC-MARH-11, AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20,
    AC-SCENE-SLOT-02].

    Only ``mode`` is required. ``request_id`` and ``generation_id`` default to
    fresh UUID4 strings.
    """

    # --- routing ---
    mode: ExecutionMode = Field(..., description="Execution mode")
    intent: str | None = Field(default=None, description="Matched intent")

    # --- identifiers ---
    request_id: str | None = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Request ID",
    )
    generation_id: str | None = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Generation ID for interrupt handling",
    )

    # --- guardrails / interrupts ---
    guardrail_triggered: bool | None = Field(
        default=False, description="Whether guardrail was triggered"
    )
    guardrail_rule_id: str | None = Field(
        default=None, description="Guardrail rule ID that triggered"
    )
    interrupt_consumed: bool | None = Field(
        default=False, description="Whether interrupted segments were consumed"
    )

    # --- knowledge base ---
    kb_tool_called: bool | None = Field(default=False, description="Whether KB tool was called")
    kb_hit: bool | None = Field(default=False, description="Whether KB search had results")
    fallback_reason_code: str | None = Field(default=None, description="Fallback reason code")

    # --- execution ---
    react_iterations: int | None = Field(
        default=0, ge=0, le=5, description="ReAct loop iterations"
    )
    timeout_profile: TimeoutProfile | None = Field(default=None, description="Timeout profile")
    segment_stats: SegmentStats | None = Field(default=None, description="Segment statistics")
    metrics_snapshot: MetricsSnapshot | None = Field(default=None, description="Metrics snapshot")
    high_risk_policy_set: list[HighRiskScenario] | None = Field(
        default=None, description="Active high-risk policy set"
    )
    tools_used: list[str] | None = Field(default=None, description="Tools used in this request")
    tool_calls: list[ToolCallTrace] | None = Field(default=None, description="Tool call traces")
    duration_ms: int = Field(default=0, ge=0, description="Execution duration in milliseconds")
    created_at: str | None = Field(default=None, description="Creation timestamp")

    # [AC-SCENE-SLOT-02] Scene / slot tracking fields
    scene: str | None = Field(default=None, description="当前场景标识")
    scene_slot_context: dict[str, Any] | None = Field(default=None, description="场景槽位上下文信息")
    missing_slots: list[str] | None = Field(default=None, description="缺失的必填槽位列表")
    ask_back_triggered: bool | None = Field(default=False, description="是否触发了追问")
    slot_sources: dict[str, str] | None = Field(default=None, description="槽位值来源映射")
    kb_filter_sources: dict[str, str] | None = Field(default=None, description="KB 过滤条件来源映射")

    # [Step-KB-Binding] Step-level knowledge-base binding trace
    step_kb_binding: dict[str, Any] | None = Field(
        default=None,
        description="步骤知识库绑定信息,包含 step_id, allowed_kb_ids, used_kb_ids 等",
    )
|
||
|
||
|
||
class DialogueResponse(BaseModel):
    """[AC-IDMP-01, AC-IDMP-02] Dialogue reply: ordered segments plus a trace."""

    segments: list[Segment] = Field(
        ...,
        description="Response segments",
    )
    trace: TraceInfo = Field(
        ...,
        description="Trace info for observability",
    )
|
||
|
||
|
||
class ReportedMessage(BaseModel):
    """[AC-IDMP-08] One message submitted through the message-report API.

    All fields are plain strings. ``segment_id`` is the only optional field.
    """

    role: str = Field(..., description="Message role: user, assistant, human, or system")
    content: str = Field(..., description="Message content")
    source: str = Field(..., description="Message source: bot, human, or channel")
    timestamp: str = Field(..., description="Message timestamp in ISO format")
    segment_id: str | None = Field(
        default=None,
        description="Segment ID if applicable",
    )
|
||
|
||
|
||
class MessageReportRequest(BaseModel):
    """[AC-IDMP-08] Batch of messages reported for a single session."""

    session_id: str = Field(
        ...,
        description="Session ID",
    )
    messages: list[ReportedMessage] = Field(
        ...,
        description="Messages to report",
    )
|
||
|
||
|
||
class SwitchModeRequest(BaseModel):
    """[AC-IDMP-09] Request to move a session to a different SessionMode."""

    mode: SessionMode = Field(
        ...,
        description="Target mode: BOT_ACTIVE or HUMAN_ACTIVE",
    )
    reason: str | None = Field(
        default=None,
        description="Reason for mode switch",
    )
|
||
|
||
|
||
class SwitchModeResponse(BaseModel):
    """[AC-IDMP-09] Result of a mode switch: the session and its mode afterwards."""

    session_id: str = Field(
        ...,
        description="Session ID",
    )
    mode: SessionMode = Field(
        ...,
        description="Current mode after switch",
    )
|
||
|
||
|
||
class MidSessionState(BaseModel):
    """Internal per-session state for the mid platform.

    A new session starts in ``SessionMode.BOT_ACTIVE`` with a freshly
    generated UUID4 ``generation_id``. All other optional fields default
    to None.
    """

    # Required identity.
    session_id: str
    tenant_id: str

    # Control state.
    mode: SessionMode = SessionMode.BOT_ACTIVE
    generation_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    active_flow_id: str | None = None

    # Free-form context and timestamps.
    context: dict[str, Any] | None = None
    created_at: str | None = None
    updated_at: str | None = None
|
||
|
||
|
||
class PolicyRouterResult(BaseModel):
    """[AC-IDMP-02, AC-IDMP-05, AC-IDMP-16] Routing decision produced by the policy router.

    ``mode`` is always set. Which of ``target_flow_id``, ``fixed_reply`` or
    ``transfer_message`` is relevant depends on that mode, per the field
    descriptions below.
    """

    mode: ExecutionMode = Field(..., description="Decided execution mode")
    intent: str | None = Field(default=None, description="Matched intent")
    confidence: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Intent confidence",
    )
    fallback_reason_code: str | None = Field(
        default=None,
        description="Fallback reason if applicable",
    )
    high_risk_triggered: bool = Field(
        default=False,
        description="Whether high-risk scenario triggered",
    )

    # Mode-specific payloads.
    target_flow_id: str | None = Field(default=None, description="Target flow ID for micro_flow mode")
    fixed_reply: str | None = Field(default=None, description="Fixed reply for fixed mode")
    transfer_message: str | None = Field(default=None, description="Transfer message for transfer mode")
|
||
|
||
|
||
class ReActContext(BaseModel):
    """[AC-IDMP-11] Mutable bookkeeping for one ReAct reasoning loop.

    ``iteration`` is bounded to 0..5 and ``max_iterations`` to 3..5 (default 5).
    """

    iteration: int = Field(
        default=0,
        ge=0,
        le=5,
        description="Current iteration count",
    )
    max_iterations: int = Field(
        default=5,
        ge=3,
        le=5,
        description="Maximum iterations allowed",
    )
    tool_calls: list[ToolCallTrace] = Field(
        default_factory=list,
        description="Tool call history",
    )
    should_continue: bool = Field(
        default=True,
        description="Whether to continue ReAct loop",
    )
    final_answer: str | None = Field(
        default=None,
        description="Final answer if completed",
    )
|
||
|
||
|
||
class CreateShareRequest(BaseModel):
    """[AC-IDMP-SHARE] Parameters for creating a shared session link.

    Limits enforced by validation:

    - title is at most 255 characters;
    - description is at most 1000 characters;
    - expiry is 1-365 days (default 7);
    - concurrency is 1-100 users (default 10).
    """

    title: str | None = Field(
        default=None, max_length=255, description="Share title"
    )
    description: str | None = Field(
        default=None, max_length=1000, description="Share description"
    )
    expires_in_days: int = Field(
        default=7, ge=1, le=365, description="Expiration time in days"
    )
    max_concurrent_users: int = Field(
        default=10, ge=1, le=100, description="Maximum concurrent users"
    )
|
||
|
||
|
||
class ShareResponse(BaseModel):
    """[AC-IDMP-SHARE] Details returned once a share has been created."""

    # Link identity.
    share_token: str = Field(..., description="Unique share token")
    share_url: str = Field(..., description="Full share URL")
    expires_at: str = Field(..., description="Expiration time in ISO format")

    # Echoed share metadata.
    title: str | None = Field(default=None, description="Share title")
    description: str | None = Field(default=None, description="Share description")
    max_concurrent_users: int = Field(..., description="Maximum concurrent users")
|
||
|
||
|
||
class SharedSessionInfo(BaseModel):
    """[AC-IDMP-SHARE] What a visitor sees when opening a shared session."""

    session_id: str = Field(..., description="Session ID")
    title: str | None = Field(default=None, description="Share title")
    description: str | None = Field(default=None, description="Share description")
    expires_at: str = Field(..., description="Expiration time in ISO format")

    # Occupancy.
    max_concurrent_users: int = Field(..., description="Maximum concurrent users")
    current_users: int = Field(..., description="Current online users")

    history: list[HistoryMessage] = Field(
        default_factory=list,
        description="Historical messages",
    )
|
||
|
||
|
||
class SharedMessageRequest(BaseModel):
    """[AC-IDMP-SHARE] A message sent through a shared session (1-2000 characters)."""

    user_message: str = Field(
        ...,
        min_length=1,
        max_length=2000,
        description="User message content",
    )
|
||
|
||
|
||
class ShareListItem(BaseModel):
    """[AC-IDMP-SHARE] One entry when listing every share of a session.

    Everything except ``title`` and ``description`` is required.
    """

    share_token: str = Field(..., description="Share token")
    share_url: str = Field(..., description="Full share URL")
    title: str | None = Field(default=None, description="Share title")
    description: str | None = Field(default=None, description="Share description")

    # Lifecycle & occupancy.
    expires_at: str = Field(..., description="Expiration time in ISO format")
    is_active: bool = Field(..., description="Whether share is active")
    max_concurrent_users: int = Field(..., description="Maximum concurrent users")
    current_users: int = Field(..., description="Current online users")
    created_at: str = Field(..., description="Creation time in ISO format")
|
||
|
||
|
||
class ShareListResponse(BaseModel):
    """[AC-IDMP-SHARE] Wrapper around the list of shares for a session."""

    shares: list[ShareListItem] = Field(
        ...,
        description="List of shares",
    )
|
||
|
||
|
||
class KbSearchDynamicHit(BaseModel):
    """[AC-MARH-05] A single knowledge-base search hit.

    ``score`` must lie in [0, 1]; ``metadata`` defaults to an empty dict.
    """

    id: str = Field(..., description="Hit ID")
    content: str = Field(..., description="Hit content")
    score: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Relevance score",
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Hit metadata",
    )
|
||
|
||
|
||
class MissingRequiredSlot(BaseModel):
    """[AC-MARH-05] Describes one required slot that has no value yet, and why."""

    field_key: str = Field(
        ...,
        description="Field key",
    )
    label: str = Field(
        ...,
        description="Field label",
    )
    reason: str = Field(
        ...,
        description="Missing reason",
    )
|
||
|
||
|
||
class KbSearchDynamicResultSchema(BaseModel):
    """[AC-MARH-05, AC-MARH-06] Result envelope of a dynamic KB search.

    Only ``success`` is required. Collections default to empty and
    ``duration_ms`` defaults to 0.
    """

    success: bool = Field(..., description="Whether search succeeded")
    hits: list[KbSearchDynamicHit] = Field(default_factory=list, description="Search hits")

    # Filtering details.
    applied_filter: dict[str, Any] = Field(default_factory=dict, description="Applied filter")
    missing_required_slots: list[MissingRequiredSlot] = Field(
        default_factory=list,
        description="Missing required slots",
    )
    filter_debug: dict[str, Any] = Field(default_factory=dict, description="Filter debug info")

    # Diagnostics.
    fallback_reason_code: str | None = Field(default=None, description="Fallback reason code")
    duration_ms: int = Field(default=0, ge=0, description="Duration in milliseconds")
|
||
|
||
|
||
class IntentHintOutput(BaseModel):
    """[AC-IDMP-02, AC-IDMP-16] Output of the lightweight intent-hint tool.

    Every field is optional. ``confidence`` must lie in [0, 1] and
    ``duration_ms`` must be non-negative.
    """

    intent: str | None = Field(default=None, description="识别到的意图名称")
    confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="置信度 0~1")
    response_type: str | None = Field(
        default=None,
        description="响应类型: fixed|rag|flow|transfer|null",
    )
    suggested_mode: ExecutionMode | None = Field(
        default=None,
        description="建议执行模式: agent|micro_flow|fixed|transfer",
    )

    # Routing targets.
    target_flow_id: str | None = Field(default=None, description="目标流程ID(flow模式)")
    target_kb_ids: list[str] | None = Field(default=None, description="目标知识库ID列表")

    # Diagnostics.
    fallback_reason_code: str | None = Field(default=None, description="降级原因码")
    high_risk_detected: bool = Field(default=False, description="是否检测到高风险场景")
    duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)")
|
||
|
||
|
||
class HighRiskCheckResult(BaseModel):
    """[AC-IDMP-05, AC-IDMP-20] Output of the high-risk detection tool.

    When no scenario matches, ``matched`` stays False and ``risk_scenario``
    stays None. ``matched_text`` and ``matched_pattern`` record which input
    fragment and which keyword/regex triggered the match.
    """

    matched: bool = Field(default=False, description="是否命中高风险场景")
    risk_scenario: HighRiskScenario | None = Field(
        default=None,
        description="风险场景: refund|complaint_escalation|privacy_sensitive_promise|transfer|none",
    )
    confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="置信度 0~1")
    recommended_mode: ExecutionMode | None = Field(
        default=None,
        description="推荐执行模式: micro_flow|transfer|agent",
    )

    # Explanation of the match.
    rule_id: str | None = Field(default=None, description="匹配的规则ID")
    reason: str | None = Field(default=None, description="匹配原因说明")

    # Diagnostics.
    fallback_reason_code: str | None = Field(default=None, description="降级原因码(工具失败时)")
    duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)")

    # Evidence.
    matched_text: str | None = Field(default=None, description="匹配到的文本片段")
    matched_pattern: str | None = Field(default=None, description="匹配到的模式(关键词或正则)")
|
||
|
||
|
||
class SlotSource(str, Enum):
    """[AC-IDMP-13] Where a slot's value came from.

    String-valued enum: members compare equal to their lowercase values.
    """

    USER_CONFIRMED = "user_confirmed"
    RULE_EXTRACTED = "rule_extracted"
    LLM_INFERRED = "llm_inferred"
    DEFAULT = "default"
|
||
|
||
|
||
class MemorySlot(BaseModel):
    """[AC-IDMP-13] A single structured slot.

    ``key`` and ``value`` are required. ``source`` defaults to
    ``SlotSource.DEFAULT`` and ``confidence`` (in [0, 1]) defaults to 1.0.
    """

    key: str = Field(..., description="槽位键名")
    value: Any = Field(..., description="槽位值")
    source: SlotSource = Field(
        default=SlotSource.DEFAULT,
        description="槽位来源",
    )
    confidence: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="置信度",
    )
    updated_at: str | None = Field(default=None, description="最后更新时间")
|
||
|
||
|
||
class MemoryRecallResult(BaseModel):
    """[AC-IDMP-13] Output of the memory-recall tool.

    The first five fields hold what was recalled for the user: profile,
    facts, preferences, last summary and slots. The rest describe the
    recall itself: missing slots, fallback code and duration.
    """

    profile: dict[str, Any] = Field(default_factory=dict, description="用户基础属性")
    facts: list[str] = Field(default_factory=list, description="事实型记忆列表")
    preferences: dict[str, Any] = Field(default_factory=dict, description="用户偏好")
    last_summary: str | None = Field(default=None, description="最近会话摘要")
    slots: dict[str, MemorySlot] = Field(default_factory=dict, description="结构化槽位")
    missing_slots: list[str] = Field(default_factory=list, description="缺失的必填槽位")
    fallback_reason_code: str | None = Field(default=None, description="降级原因码")
    duration_ms: int = Field(default=0, ge=0, description="执行耗时(毫秒)")

    def get_context_for_prompt(self, max_facts: int = 5) -> str:
        """Build the context string to inject into a prompt.

        Each non-empty section contributes one line, in this order:
        profile, facts, preferences, last summary, slots.

        Args:
            max_facts: How many facts (taken from the start of ``facts``) to
                include. Defaults to 5. Negative values are treated as 0.

        Returns:
            The section lines joined with newlines, or ``""`` when every
            section is empty.
        """

        def truthy_pairs(mapping: dict[str, Any]) -> list[str]:
            # Entries with a falsy value (None, "", 0, False, empty
            # containers) are skipped.
            # NOTE(review): this also drops legitimate 0/False values —
            # confirm that is intended.
            return [f"{key}: {value}" for key, value in mapping.items() if value]

        parts: list[str] = []

        profile_parts = truthy_pairs(self.profile)
        if profile_parts:
            parts.append("【用户属性】" + "、".join(profile_parts))

        shown_facts = self.facts[: max(max_facts, 0)]
        if shown_facts:
            parts.append("【已知事实】" + ";".join(shown_facts))

        pref_parts = truthy_pairs(self.preferences)
        if pref_parts:
            parts.append("【用户偏好】" + "、".join(pref_parts))

        if self.last_summary:
            parts.append(f"【上次会话摘要】{self.last_summary}")

        if self.slots:
            # Unlike profile/preferences, every slot is rendered, including
            # ones whose value is falsy.
            slot_parts = [f"{key}={slot.value}" for key, slot in self.slots.items()]
            parts.append("【已知槽位】" + ", ".join(slot_parts))

        # Joining an empty list already yields "", so no special case is needed.
        return "\n".join(parts)
|