# File info (from the code viewer this was copied from): Python, 301 lines, 8.4 KiB
"""
Runtime Observer for Mid Platform.

[AC-MARH-12] Closed-loop runtime observation.

Aggregates observation fields such as guardrail, interrupt, kb_hit, timeouts
and segment_stats.
"""
|
|
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from app.models.mid.schemas import (
|
|
ExecutionMode,
|
|
MetricsSnapshot,
|
|
SegmentStats,
|
|
TimeoutProfile,
|
|
ToolCallTrace,
|
|
TraceInfo,
|
|
)
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class RuntimeContext:
    """Mutable record of what has been observed for a single request.

    Every field has a default, so a context can be created with only the
    identifiers that are known up front and filled in as the request runs.
    """

    # Identifiers of the observed request.
    tenant_id: str = ""
    session_id: str = ""
    request_id: str = ""
    generation_id: str = ""

    # Execution mode and optional intent.
    mode: ExecutionMode = ExecutionMode.AGENT
    intent: str | None = None

    # Guardrail outcome.
    guardrail_triggered: bool = False
    guardrail_rule_id: str | None = None

    # Interrupt outcome.
    interrupt_consumed: bool = False

    # Knowledge-base tool usage.
    kb_tool_called: bool = False
    kb_hit: bool = False

    # Fallback reason code, if any was recorded.
    fallback_reason_code: str | None = None

    # ReAct loop statistics and the tool-call traces collected so far.
    react_iterations: int = 0
    tool_calls: list[ToolCallTrace] = field(default_factory=list)

    # Optional detailed observation payloads.
    timeout_profile: TimeoutProfile | None = None
    segment_stats: SegmentStats | None = None
    metrics_snapshot: MetricsSnapshot | None = None

    # time.time() value captured when the context is created.
    start_time: float = field(default_factory=time.time)

    def to_trace_info(self) -> TraceInfo:
        """Convert this context into a TraceInfo.

        ``tools_used`` and ``tool_calls`` are passed as ``None`` when no tool
        calls have been recorded.

        Returns:
            A TraceInfo populated from the fields of this context.

        Raises:
            Exception: Any error raised while building the TraceInfo is
                logged together with its traceback and then re-raised.
        """
        try:
            # An empty list is reported as None rather than [].
            recorded_calls = self.tool_calls or None
            if recorded_calls:
                tool_names = [call.tool_name for call in recorded_calls]
            else:
                tool_names = None

            payload = {
                "mode": self.mode,
                "intent": self.intent,
                "request_id": self.request_id,
                "generation_id": self.generation_id,
                "guardrail_triggered": self.guardrail_triggered,
                "guardrail_rule_id": self.guardrail_rule_id,
                "interrupt_consumed": self.interrupt_consumed,
                "kb_tool_called": self.kb_tool_called,
                "kb_hit": self.kb_hit,
                "fallback_reason_code": self.fallback_reason_code,
                "react_iterations": self.react_iterations,
                "timeout_profile": self.timeout_profile,
                "segment_stats": self.segment_stats,
                "metrics_snapshot": self.metrics_snapshot,
                "tools_used": tool_names,
                "tool_calls": recorded_calls,
            }
            return TraceInfo(**payload)
        except Exception as exc:
            import traceback

            report = (
                f"[RuntimeObserver] Failed to create TraceInfo: {exc}\n"
                f"Exception type: {type(exc).__name__}\n"
                f"Context: mode={self.mode}, request_id={self.request_id}, "
                f"generation_id={self.generation_id}\n"
                f"Traceback:\n{traceback.format_exc()}"
            )
            logger.error(report)
            raise
|
|
|
|
|
|
class RuntimeObserver:
    """[AC-MARH-12] Runtime observer.

    Features:
    - Aggregates guardrail, interrupt, kb_hit, timeouts and segment_stats.
    - Produces a complete TraceInfo.
    - Writes observation logs.

    Contexts are kept in a dict keyed by ``request_id``. All ``record_*`` /
    ``update_*`` / ``set_*`` methods are silent no-ops when no context is
    registered for the given ``request_id``.
    """

    def __init__(self):
        """Create an observer with no active contexts."""
        self._contexts: dict[str, RuntimeContext] = {}

    def start_observation(
        self,
        tenant_id: str,
        session_id: str,
        request_id: str,
        generation_id: str,
    ) -> RuntimeContext:
        """[AC-MARH-12] Start observing a request.

        A context already stored under the same ``request_id`` is replaced.

        Args:
            tenant_id: Tenant ID.
            session_id: Session ID.
            request_id: Request ID; used as the lookup key for the context.
            generation_id: Generation ID.

        Returns:
            The newly created RuntimeContext.
        """
        ctx = RuntimeContext(
            tenant_id=tenant_id,
            session_id=session_id,
            request_id=request_id,
            generation_id=generation_id,
        )

        self._contexts[request_id] = ctx

        logger.info(
            f"[AC-MARH-12] Observation started: request_id={request_id}, "
            f"session_id={session_id}"
        )

        return ctx

    def get_context(self, request_id: str) -> RuntimeContext | None:
        """Return the context registered for ``request_id``, or None."""
        return self._contexts.get(request_id)

    def update_mode(
        self,
        request_id: str,
        mode: ExecutionMode,
        intent: str | None = None,
    ) -> None:
        """Update the execution mode and intent of a context.

        ``intent`` is always written, so omitting it resets the stored intent
        to None.
        """
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return

        ctx.mode = mode
        ctx.intent = intent

    def record_guardrail(
        self,
        request_id: str,
        triggered: bool,
        rule_id: str | None = None,
    ) -> None:
        """[AC-MARH-12] Record whether a guardrail was triggered, and by which rule."""
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return

        ctx.guardrail_triggered = triggered
        ctx.guardrail_rule_id = rule_id

        logger.info(
            f"[AC-MARH-12] Guardrail recorded: request_id={request_id}, "
            f"triggered={triggered}, rule_id={rule_id}"
        )

    def record_interrupt(
        self,
        request_id: str,
        consumed: bool,
    ) -> None:
        """[AC-MARH-12] Record whether an interrupt was consumed."""
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return

        ctx.interrupt_consumed = consumed

        logger.info(
            f"[AC-MARH-12] Interrupt recorded: request_id={request_id}, "
            f"consumed={consumed}"
        )

    def record_kb(
        self,
        request_id: str,
        tool_called: bool,
        hit: bool,
        fallback_reason: str | None = None,
    ) -> None:
        """[AC-MARH-12] Record a knowledge-base retrieval.

        ``fallback_reason`` only overwrites the stored fallback reason code
        when it is non-empty; a previously recorded reason is kept otherwise.
        """
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return

        ctx.kb_tool_called = tool_called
        ctx.kb_hit = hit

        if fallback_reason:
            ctx.fallback_reason_code = fallback_reason

        logger.info(
            f"[AC-MARH-12] KB recorded: request_id={request_id}, "
            f"tool_called={tool_called}, hit={hit}, fallback={fallback_reason}"
        )

    def record_react(
        self,
        request_id: str,
        iterations: int,
        tool_calls: list[ToolCallTrace] | None = None,
    ) -> None:
        """[AC-MARH-12] Record the ReAct loop outcome.

        The stored tool calls are replaced only when ``tool_calls`` is a
        non-empty list.
        """
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return

        ctx.react_iterations = iterations
        if tool_calls:
            ctx.tool_calls = tool_calls

    def record_timeout_profile(
        self,
        request_id: str,
        profile: TimeoutProfile,
    ) -> None:
        """[AC-MARH-12] Record the timeout profile."""
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return

        ctx.timeout_profile = profile

    def record_segment_stats(
        self,
        request_id: str,
        stats: SegmentStats,
    ) -> None:
        """[AC-MARH-12] Record the segment statistics."""
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return

        ctx.segment_stats = stats

    def record_metrics(
        self,
        request_id: str,
        metrics: MetricsSnapshot,
    ) -> None:
        """[AC-MARH-12] Record the metrics snapshot."""
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return

        ctx.metrics_snapshot = metrics

    def set_fallback_reason(
        self,
        request_id: str,
        reason: str,
    ) -> None:
        """Set the fallback reason code unconditionally."""
        ctx = self._contexts.get(request_id)
        if ctx is None:
            return

        ctx.fallback_reason_code = reason

    def end_observation(
        self,
        request_id: str,
    ) -> TraceInfo:
        """[AC-MARH-12] End the observation and build the TraceInfo.

        The context is removed from the observer before the TraceInfo is
        built, so it is released even when the conversion raises.

        Args:
            request_id: Request ID.

        Returns:
            The complete TraceInfo, or ``TraceInfo(mode=ExecutionMode.FIXED)``
            when no context is registered for ``request_id``.

        Raises:
            Exception: Whatever ``RuntimeContext.to_trace_info`` raises is
                propagated to the caller.
        """
        # Pop first: if to_trace_info() or the summary log below raises, the
        # entry must not stay in _contexts forever.
        ctx = self._contexts.pop(request_id, None)
        if ctx is None:
            logger.warning(f"[AC-MARH-12] Context not found: {request_id}")
            return TraceInfo(mode=ExecutionMode.FIXED)

        duration_ms = int((time.time() - ctx.start_time) * 1000)

        trace_info = ctx.to_trace_info()

        logger.info(
            f"[AC-MARH-12] Observation ended: request_id={request_id}, "
            f"mode={ctx.mode.value}, duration_ms={duration_ms}, "
            f"guardrail={ctx.guardrail_triggered}, kb_hit={ctx.kb_hit}, "
            f"segments={ctx.segment_stats.segment_count if ctx.segment_stats else 0}"
        )

        return trace_info
|
|
|
|
|
|
# Module-wide RuntimeObserver instance; created on first use by
# get_runtime_observer().
_runtime_observer: RuntimeObserver | None = None


def get_runtime_observer() -> RuntimeObserver:
    """Return the shared RuntimeObserver, creating it on the first call."""
    global _runtime_observer

    observer = _runtime_observer
    if observer is None:
        observer = RuntimeObserver()
        _runtime_observer = observer
    return observer
|