ai-robot-core/ai-service/app/services/mid/runtime_observer.py

301 lines
8.4 KiB
Python

"""
Runtime Observer for Mid Platform.
[AC-MARH-12] 运行时观测闭环。
汇总 guardrail、interrupt、kb_hit、timeouts、segment_stats 等观测字段。
"""
import logging
import time
from dataclasses import dataclass, field
from typing import Any
from app.models.mid.schemas import (
ExecutionMode,
MetricsSnapshot,
SegmentStats,
TimeoutProfile,
ToolCallTrace,
TraceInfo,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
@dataclass
class RuntimeContext:
    """Mutable per-request record of runtime observations.

    Every field has a default, so a context can be created empty and filled
    in incrementally; ``to_trace_info`` turns the collected values into a
    ``TraceInfo``.
    """

    # Identifiers of the request being observed.
    tenant_id: str = ""
    session_id: str = ""
    request_id: str = ""
    generation_id: str = ""

    # Execution mode and the intent associated with it (if any).
    mode: ExecutionMode = ExecutionMode.AGENT
    intent: str | None = None

    # Guardrail and interrupt observations.
    guardrail_triggered: bool = False
    guardrail_rule_id: str | None = None
    interrupt_consumed: bool = False

    # Knowledge-base observations and the fallback reason code, if one was set.
    kb_tool_called: bool = False
    kb_hit: bool = False
    fallback_reason_code: str | None = None

    # ReAct loop observations.
    react_iterations: int = 0
    tool_calls: list[ToolCallTrace] = field(default_factory=list)

    # Optional snapshots attached during the request.
    timeout_profile: TimeoutProfile | None = None
    segment_stats: SegmentStats | None = None
    metrics_snapshot: MetricsSnapshot | None = None

    # Wall-clock timestamp (time.time()) taken when the context is created.
    start_time: float = field(default_factory=time.time)

    def to_trace_info(self) -> TraceInfo:
        """Build a ``TraceInfo`` from the values collected so far.

        ``tools_used`` and ``tool_calls`` are passed as ``None`` when no tool
        calls were recorded.

        Returns:
            The populated ``TraceInfo``.

        Raises:
            Exception: Any error raised while building the ``TraceInfo`` is
                logged (with traceback) and re-raised unchanged.
        """
        try:
            # An empty list is reported as "no tool calls" (None).
            recorded_calls = self.tool_calls or None
            if recorded_calls is None:
                tool_names = None
            else:
                tool_names = [call.tool_name for call in recorded_calls]

            trace_fields = dict(
                mode=self.mode,
                intent=self.intent,
                request_id=self.request_id,
                generation_id=self.generation_id,
                guardrail_triggered=self.guardrail_triggered,
                guardrail_rule_id=self.guardrail_rule_id,
                interrupt_consumed=self.interrupt_consumed,
                kb_tool_called=self.kb_tool_called,
                kb_hit=self.kb_hit,
                fallback_reason_code=self.fallback_reason_code,
                react_iterations=self.react_iterations,
                timeout_profile=self.timeout_profile,
                segment_stats=self.segment_stats,
                metrics_snapshot=self.metrics_snapshot,
                tools_used=tool_names,
                tool_calls=recorded_calls,
            )
            return TraceInfo(**trace_fields)
        except Exception as e:
            import traceback

            logger.error(
                f"[RuntimeObserver] Failed to create TraceInfo: {e}\n"
                f"Exception type: {type(e).__name__}\n"
                f"Context: mode={self.mode}, request_id={self.request_id}, "
                f"generation_id={self.generation_id}\n"
                f"Traceback:\n{traceback.format_exc()}"
            )
            raise
class RuntimeObserver:
    """
    [AC-MARH-12] Runtime observer.

    Features:
    - Aggregates guardrail, interrupt, kb_hit, timeouts and segment_stats
      observations for a request.
    - Produces a complete TraceInfo when the observation ends.
    - Writes observation log lines.

    One RuntimeContext is kept per request, keyed by ``request_id``, from
    ``start_observation`` until ``end_observation``. The ``update_*``,
    ``record_*`` and ``set_*`` methods do nothing when no context is
    registered for the given ``request_id``.
    """

    def __init__(self):
        # request_id -> context of an observation that has started but not ended.
        self._contexts: dict[str, RuntimeContext] = {}

    def start_observation(
        self,
        tenant_id: str,
        session_id: str,
        request_id: str,
        generation_id: str,
    ) -> RuntimeContext:
        """
        [AC-MARH-12] Start an observation.

        A context already registered under ``request_id`` is replaced.

        Args:
            tenant_id: Tenant ID.
            session_id: Session ID.
            request_id: Request ID; used as the lookup key for all later calls.
            generation_id: Generation ID.

        Returns:
            The newly created RuntimeContext.
        """
        ctx = RuntimeContext(
            tenant_id=tenant_id,
            session_id=session_id,
            request_id=request_id,
            generation_id=generation_id,
        )
        self._contexts[request_id] = ctx

        logger.info(
            f"[AC-MARH-12] Observation started: request_id={request_id}, "
            f"session_id={session_id}"
        )
        return ctx

    def get_context(self, request_id: str) -> RuntimeContext | None:
        """Return the context registered for ``request_id``, or None."""
        return self._contexts.get(request_id)

    def update_mode(
        self,
        request_id: str,
        mode: ExecutionMode,
        intent: str | None = None,
    ) -> None:
        """Set the execution mode and intent (intent is overwritten, even with None)."""
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return
        ctx.mode = mode
        ctx.intent = intent

    def record_guardrail(
        self,
        request_id: str,
        triggered: bool,
        rule_id: str | None = None,
    ) -> None:
        """[AC-MARH-12] Record whether a guardrail triggered and which rule."""
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return
        ctx.guardrail_triggered = triggered
        ctx.guardrail_rule_id = rule_id

        logger.info(
            f"[AC-MARH-12] Guardrail recorded: request_id={request_id}, "
            f"triggered={triggered}, rule_id={rule_id}"
        )

    def record_interrupt(
        self,
        request_id: str,
        consumed: bool,
    ) -> None:
        """[AC-MARH-12] Record whether an interrupt was consumed."""
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return
        ctx.interrupt_consumed = consumed

        logger.info(
            f"[AC-MARH-12] Interrupt recorded: request_id={request_id}, "
            f"consumed={consumed}"
        )

    def record_kb(
        self,
        request_id: str,
        tool_called: bool,
        hit: bool,
        fallback_reason: str | None = None,
    ) -> None:
        """
        [AC-MARH-12] Record a knowledge-base retrieval.

        The fallback reason code is only updated when ``fallback_reason`` is
        non-empty, so an earlier reason is not cleared by this call.
        """
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return
        ctx.kb_tool_called = tool_called
        ctx.kb_hit = hit
        if fallback_reason:
            ctx.fallback_reason_code = fallback_reason

        logger.info(
            f"[AC-MARH-12] KB recorded: request_id={request_id}, "
            f"tool_called={tool_called}, hit={hit}, fallback={fallback_reason}"
        )

    def record_react(
        self,
        request_id: str,
        iterations: int,
        tool_calls: list[ToolCallTrace] | None = None,
    ) -> None:
        """
        [AC-MARH-12] Record the ReAct loop.

        ``tool_calls`` replaces the stored list only when it is non-empty; the
        list object is stored as passed (not copied).
        """
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return
        ctx.react_iterations = iterations
        if tool_calls:
            ctx.tool_calls = tool_calls

    def record_timeout_profile(
        self,
        request_id: str,
        profile: TimeoutProfile,
    ) -> None:
        """[AC-MARH-12] Record the timeout profile."""
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return
        ctx.timeout_profile = profile

    def record_segment_stats(
        self,
        request_id: str,
        stats: SegmentStats,
    ) -> None:
        """[AC-MARH-12] Record segment statistics."""
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return
        ctx.segment_stats = stats

    def record_metrics(
        self,
        request_id: str,
        metrics: MetricsSnapshot,
    ) -> None:
        """[AC-MARH-12] Record the metrics snapshot."""
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return
        ctx.metrics_snapshot = metrics

    def set_fallback_reason(
        self,
        request_id: str,
        reason: str,
    ) -> None:
        """Set the fallback reason code unconditionally."""
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return
        ctx.fallback_reason_code = reason

    def end_observation(
        self,
        request_id: str,
    ) -> TraceInfo:
        """
        [AC-MARH-12] End the observation and produce its TraceInfo.

        The context is removed from the registry whether or not building the
        TraceInfo succeeds, so a failing conversion cannot leave a stale
        entry behind.

        Args:
            request_id: Request ID.

        Returns:
            The complete TraceInfo. When no context is registered for
            ``request_id``, a warning is logged and
            ``TraceInfo(mode=ExecutionMode.FIXED)`` is returned.

        Raises:
            Exception: Whatever ``RuntimeContext.to_trace_info`` raises.
        """
        ctx = self._contexts.get(request_id)
        if ctx is None:
            logger.warning(f"[AC-MARH-12] Context not found: {request_id}")
            return TraceInfo(mode=ExecutionMode.FIXED)

        try:
            duration_ms = int((time.time() - ctx.start_time) * 1000)
            trace_info = ctx.to_trace_info()

            logger.info(
                f"[AC-MARH-12] Observation ended: request_id={request_id}, "
                f"mode={ctx.mode.value}, duration_ms={duration_ms}, "
                f"guardrail={ctx.guardrail_triggered}, kb_hit={ctx.kb_hit}, "
                f"segments={ctx.segment_stats.segment_count if ctx.segment_stats else 0}"
            )
            return trace_info
        finally:
            # Always release the context; to_trace_info re-raises on failure
            # and the entry would otherwise stay registered indefinitely.
            self._contexts.pop(request_id, None)
# Process-wide observer instance; created on first use by get_runtime_observer().
_runtime_observer: RuntimeObserver | None = None


def get_runtime_observer() -> RuntimeObserver:
    """Return the shared RuntimeObserver, creating it on the first call."""
    global _runtime_observer
    if _runtime_observer is not None:
        return _runtime_observer
    _runtime_observer = RuntimeObserver()
    return _runtime_observer