""" Runtime Observer for Mid Platform. [AC-MARH-12] 运行时观测闭环。 汇总 guardrail、interrupt、kb_hit、timeouts、segment_stats 等观测字段。 """ import logging import time from dataclasses import dataclass, field from typing import Any from app.models.mid.schemas import ( ExecutionMode, MetricsSnapshot, SegmentStats, TimeoutProfile, ToolCallTrace, TraceInfo, ) logger = logging.getLogger(__name__) @dataclass class RuntimeContext: """运行时上下文。""" tenant_id: str = "" session_id: str = "" request_id: str = "" generation_id: str = "" mode: ExecutionMode = ExecutionMode.AGENT intent: str | None = None guardrail_triggered: bool = False guardrail_rule_id: str | None = None interrupt_consumed: bool = False kb_tool_called: bool = False kb_hit: bool = False fallback_reason_code: str | None = None react_iterations: int = 0 tool_calls: list[ToolCallTrace] = field(default_factory=list) timeout_profile: TimeoutProfile | None = None segment_stats: SegmentStats | None = None metrics_snapshot: MetricsSnapshot | None = None start_time: float = field(default_factory=time.time) def to_trace_info(self) -> TraceInfo: """转换为 TraceInfo。""" try: return TraceInfo( mode=self.mode, intent=self.intent, request_id=self.request_id, generation_id=self.generation_id, guardrail_triggered=self.guardrail_triggered, guardrail_rule_id=self.guardrail_rule_id, interrupt_consumed=self.interrupt_consumed, kb_tool_called=self.kb_tool_called, kb_hit=self.kb_hit, fallback_reason_code=self.fallback_reason_code, react_iterations=self.react_iterations, timeout_profile=self.timeout_profile, segment_stats=self.segment_stats, metrics_snapshot=self.metrics_snapshot, tools_used=[tc.tool_name for tc in self.tool_calls] if self.tool_calls else None, tool_calls=self.tool_calls if self.tool_calls else None, ) except Exception as e: import traceback logger.error( f"[RuntimeObserver] Failed to create TraceInfo: {e}\n" f"Exception type: {type(e).__name__}\n" f"Context: mode={self.mode}, request_id={self.request_id}, " f"generation_id={self.generation_id}\n" f"Traceback:\n{traceback.format_exc()}" ) raise class RuntimeObserver: """ [AC-MARH-12] 运行时观测器。 Features: - 汇总 guardrail、interrupt、kb_hit、timeouts、segment_stats - 生成完整 TraceInfo - 记录观测日志 """ def __init__(self): self._contexts: dict[str, RuntimeContext] = {} def start_observation( self, tenant_id: str, session_id: str, request_id: str, generation_id: str, ) -> RuntimeContext: """ [AC-MARH-12] 开始观测。 Args: tenant_id: 租户 ID session_id: 会话 ID request_id: 请求 ID generation_id: 生成 ID Returns: RuntimeContext 实例 """ ctx = RuntimeContext( tenant_id=tenant_id, session_id=session_id, request_id=request_id, generation_id=generation_id, ) self._contexts[request_id] = ctx logger.info( f"[AC-MARH-12] Observation started: request_id={request_id}, " f"session_id={session_id}" ) return ctx def get_context(self, request_id: str) -> RuntimeContext | None: """获取观测上下文。""" return self._contexts.get(request_id) def update_mode( self, request_id: str, mode: ExecutionMode, intent: str | None = None, ) -> None: """更新执行模式。""" ctx = self._contexts.get(request_id) if ctx: ctx.mode = mode ctx.intent = intent def record_guardrail( self, request_id: str, triggered: bool, rule_id: str | None = None, ) -> None: """[AC-MARH-12] 记录护栏触发。""" ctx = self._contexts.get(request_id) if ctx: ctx.guardrail_triggered = triggered ctx.guardrail_rule_id = rule_id logger.info( f"[AC-MARH-12] Guardrail recorded: request_id={request_id}, " f"triggered={triggered}, rule_id={rule_id}" ) def record_interrupt( self, request_id: str, consumed: bool, ) -> None: """[AC-MARH-12] 记录中断处理。""" ctx = self._contexts.get(request_id) if ctx: ctx.interrupt_consumed = consumed logger.info( f"[AC-MARH-12] Interrupt recorded: request_id={request_id}, " f"consumed={consumed}" ) def record_kb( self, request_id: str, tool_called: bool, hit: bool, fallback_reason: str | None = None, ) -> None: """[AC-MARH-12] 记录 KB 检索。""" ctx = self._contexts.get(request_id) if ctx: ctx.kb_tool_called = tool_called ctx.kb_hit = hit if fallback_reason: ctx.fallback_reason_code = fallback_reason logger.info( f"[AC-MARH-12] KB recorded: request_id={request_id}, " f"tool_called={tool_called}, hit={hit}, fallback={fallback_reason}" ) def record_react( self, request_id: str, iterations: int, tool_calls: list[ToolCallTrace] | None = None, ) -> None: """[AC-MARH-12] 记录 ReAct 循环。""" ctx = self._contexts.get(request_id) if ctx: ctx.react_iterations = iterations if tool_calls: ctx.tool_calls = tool_calls def record_timeout_profile( self, request_id: str, profile: TimeoutProfile, ) -> None: """[AC-MARH-12] 记录超时配置。""" ctx = self._contexts.get(request_id) if ctx: ctx.timeout_profile = profile def record_segment_stats( self, request_id: str, stats: SegmentStats, ) -> None: """[AC-MARH-12] 记录分段统计。""" ctx = self._contexts.get(request_id) if ctx: ctx.segment_stats = stats def record_metrics( self, request_id: str, metrics: MetricsSnapshot, ) -> None: """[AC-MARH-12] 记录指标快照。""" ctx = self._contexts.get(request_id) if ctx: ctx.metrics_snapshot = metrics def set_fallback_reason( self, request_id: str, reason: str, ) -> None: """设置降级原因。""" ctx = self._contexts.get(request_id) if ctx: ctx.fallback_reason_code = reason def end_observation( self, request_id: str, ) -> TraceInfo: """ [AC-MARH-12] 结束观测并生成 TraceInfo。 Args: request_id: 请求 ID Returns: 完整的 TraceInfo """ ctx = self._contexts.get(request_id) if not ctx: logger.warning(f"[AC-MARH-12] Context not found: {request_id}") return TraceInfo(mode=ExecutionMode.FIXED) duration_ms = int((time.time() - ctx.start_time) * 1000) trace_info = ctx.to_trace_info() logger.info( f"[AC-MARH-12] Observation ended: request_id={request_id}, " f"mode={ctx.mode.value}, duration_ms={duration_ms}, " f"guardrail={ctx.guardrail_triggered}, kb_hit={ctx.kb_hit}, " f"segments={ctx.segment_stats.segment_count if ctx.segment_stats else 0}" ) if request_id in self._contexts: del self._contexts[request_id] return trace_info _runtime_observer: RuntimeObserver | None = None def get_runtime_observer() -> RuntimeObserver: """获取或创建 RuntimeObserver 实例。""" global _runtime_observer if _runtime_observer is None: _runtime_observer = RuntimeObserver() return _runtime_observer