270 lines
8.2 KiB
Python
270 lines
8.2 KiB
Python
"""
|
|
Trace Logger for Mid Platform.
|
|
[AC-MARH-02, AC-MARH-03, AC-MARH-12] Trace collection and audit logging.
|
|
|
|
Audit Fields:
|
|
- session_id, request_id, generation_id
|
|
- mode, intent, tool_calls
|
|
- guardrail_triggered, guardrail_rule_id
|
|
- interrupt_consumed
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from app.models.mid.schemas import (
|
|
ExecutionMode,
|
|
ToolCallTrace,
|
|
TraceInfo,
|
|
)
|
|
|
|
# Module-level logger named after this module (``__name__``); TraceLogger
# writes its trace lifecycle messages through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class AuditRecord:
    """[AC-MARH-12] Audit record for database persistence.

    A flat snapshot of one request: who it belonged to, how it was
    executed, and which guardrail / interrupt / fallback events occurred.
    The first five fields are required; everything else has a neutral
    default so a minimal record can be built from identifiers and a mode.
    """

    # --- Required identifiers -------------------------------------------
    tenant_id: str
    session_id: str
    request_id: str
    generation_id: str
    # Execution mode enum; exported through ``.value`` by ``to_dict``.
    mode: ExecutionMode

    # --- Optional execution details -------------------------------------
    intent: str | None = None
    # One plain dict per tool invocation; a fresh list per instance.
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    guardrail_triggered: bool = False
    guardrail_rule_id: str | None = None
    interrupt_consumed: bool = False
    fallback_reason_code: str | None = None
    react_iterations: int = 0
    latency_ms: int = 0
    # Timestamp string supplied by whoever builds the record; may be absent.
    created_at: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return the record as a plain dict keyed by field name.

        Keys appear in field-declaration order. Every value is passed
        through unchanged (``tool_calls`` is the same list object, not a
        copy) except ``mode``, which is replaced by ``self.mode.value``.
        """
        exported = (
            "tenant_id",
            "session_id",
            "request_id",
            "generation_id",
            "mode",
            "intent",
            "tool_calls",
            "guardrail_triggered",
            "guardrail_rule_id",
            "interrupt_consumed",
            "fallback_reason_code",
            "react_iterations",
            "latency_ms",
            "created_at",
        )
        payload: dict[str, Any] = {name: getattr(self, name) for name in exported}
        # Re-assigning an existing key keeps its position in the dict.
        payload["mode"] = self.mode.value
        return payload
|
|
|
|
|
|
class TraceLogger:
    """
    [AC-MARH-02, AC-MARH-03, AC-MARH-12] Trace logger for observability and audit.

    Features:
    - Request-scoped trace context
    - Tool call tracing
    - Guardrail event logging with rule_id
    - Interrupt consumption tracking
    - Audit record generation

    All state is held in memory on the instance: active traces are keyed by
    request ID, and finished requests accumulate as ``AuditRecord`` objects
    until ``clear_audit_records`` is called.
    """

    # Format of ``AuditRecord.created_at``. ``time.strftime`` without an
    # explicit time tuple renders the current *local* time.
    _TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self) -> None:
        """Create a logger with no active traces and no audit records."""
        self._traces: dict[str, TraceInfo] = {}
        self._audit_records: list[AuditRecord] = []

    def start_trace(
        self,
        tenant_id: str,
        session_id: str,
        request_id: str | None = None,
        generation_id: str | None = None,
    ) -> TraceInfo:
        """
        [AC-MARH-12] Start a new trace context.

        The trace starts in ``ExecutionMode.AGENT`` and is registered under
        ``request_id``; an existing trace with the same request ID is
        replaced.

        Args:
            tenant_id: Tenant ID (accepted for symmetry with ``end_trace``;
                not stored on the trace)
            session_id: Session ID (only written to the log line)
            request_id: Request ID (auto-generated if not provided)
            generation_id: Generation ID for interrupt handling
                (auto-generated if not provided)

        Returns:
            TraceInfo for the new trace
        """
        request_id = request_id or str(uuid.uuid4())
        generation_id = generation_id or str(uuid.uuid4())

        trace = TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id=request_id,
            generation_id=generation_id,
        )
        self._traces[request_id] = trace

        logger.info(
            "[AC-MARH-12] Trace started: request_id=%s, "
            "session_id=%s, generation_id=%s",
            request_id,
            session_id,
            generation_id,
        )
        return trace

    def get_trace(self, request_id: str) -> TraceInfo | None:
        """Return the active trace for ``request_id``, or ``None`` if absent."""
        return self._traces.get(request_id)

    def update_trace(
        self,
        request_id: str,
        mode: ExecutionMode | None = None,
        intent: str | None = None,
        guardrail_triggered: bool | None = None,
        guardrail_rule_id: str | None = None,
        interrupt_consumed: bool | None = None,
        fallback_reason_code: str | None = None,
        react_iterations: int | None = None,
        tool_calls: list[ToolCallTrace] | None = None,
    ) -> TraceInfo | None:
        """
        [AC-MARH-02, AC-MARH-03, AC-MARH-12] Update trace with execution details.

        Only arguments that are not ``None`` are written, so a field cannot
        be reset to ``None`` through this method. ``tool_calls`` replaces
        the whole list; use ``add_tool_call`` to append.

        Returns:
            The updated trace, or ``None`` (after a warning) when no trace
            is registered for ``request_id``.
        """
        trace = self._traces.get(request_id)
        if trace is None:
            logger.warning("[AC-MARH-12] Trace not found: %s", request_id)
            return None

        if mode is not None:
            trace.mode = mode
        if intent is not None:
            trace.intent = intent
        if guardrail_triggered is not None:
            trace.guardrail_triggered = guardrail_triggered
        if guardrail_rule_id is not None:
            trace.guardrail_rule_id = guardrail_rule_id
        if interrupt_consumed is not None:
            trace.interrupt_consumed = interrupt_consumed
        if fallback_reason_code is not None:
            trace.fallback_reason_code = fallback_reason_code
        if react_iterations is not None:
            trace.react_iterations = react_iterations
        if tool_calls is not None:
            trace.tool_calls = tool_calls

        return trace

    def add_tool_call(
        self,
        request_id: str,
        tool_call: ToolCallTrace,
    ) -> None:
        """
        [AC-MARH-12] Add tool call trace to request.

        Appends ``tool_call`` to the trace's ``tool_calls`` and records its
        ``tool_name`` in ``tools_used`` once. Logs a warning and does
        nothing when no trace is registered for ``request_id``.
        """
        trace = self._traces.get(request_id)
        if trace is None:
            logger.warning(
                "[AC-MARH-12] Trace not found for tool call: %s", request_id
            )
            return

        # Both lists may be unset on the trace; create them on first use.
        if trace.tool_calls is None:
            trace.tool_calls = []
        trace.tool_calls.append(tool_call)

        if trace.tools_used is None:
            trace.tools_used = []
        if tool_call.tool_name not in trace.tools_used:
            trace.tools_used.append(tool_call.tool_name)

        logger.debug(
            "[AC-MARH-12] Tool call recorded: request_id=%s, tool=%s, status=%s",
            request_id,
            tool_call.tool_name,
            tool_call.status.value,
        )

    def end_trace(
        self,
        request_id: str,
        tenant_id: str,
        session_id: str,
        latency_ms: int,
    ) -> AuditRecord:
        """
        [AC-MARH-12] End trace and create audit record.

        When a trace exists, its details are copied into an ``AuditRecord``
        that is stored on this logger, and the trace is removed. When no
        trace exists, a minimal record (``ExecutionMode.FIXED``, fresh
        generation ID) is returned but *not* stored. Both kinds of record
        carry a ``created_at`` timestamp.

        Args:
            request_id: Request whose trace is being closed
            tenant_id: Tenant ID written to the audit record
            session_id: Session ID written to the audit record
            latency_ms: Total request latency in milliseconds

        Returns:
            The audit record for this request
        """
        created_at = time.strftime(self._TIMESTAMP_FORMAT)
        trace = self._traces.get(request_id)

        if trace is None:
            logger.warning("[AC-MARH-12] Trace not found for end: %s", request_id)
            return AuditRecord(
                tenant_id=tenant_id,
                session_id=session_id,
                request_id=request_id,
                generation_id=str(uuid.uuid4()),
                mode=ExecutionMode.FIXED,
                latency_ms=latency_ms,
                created_at=created_at,
            )

        audit = AuditRecord(
            tenant_id=tenant_id,
            session_id=session_id,
            request_id=request_id,
            generation_id=trace.generation_id or str(uuid.uuid4()),
            mode=trace.mode,
            intent=trace.intent,
            tool_calls=(
                [tc.model_dump() for tc in trace.tool_calls]
                if trace.tool_calls
                else []
            ),
            guardrail_triggered=trace.guardrail_triggered or False,
            guardrail_rule_id=trace.guardrail_rule_id,
            interrupt_consumed=trace.interrupt_consumed or False,
            fallback_reason_code=trace.fallback_reason_code,
            react_iterations=trace.react_iterations or 0,
            latency_ms=latency_ms,
            created_at=created_at,
        )

        self._audit_records.append(audit)
        # Drop the trace only after the record is stored, so a failure while
        # building the record leaves the trace in place.
        self._traces.pop(request_id, None)

        logger.info(
            "[AC-MARH-12] Trace ended: request_id=%s, mode=%s, latency_ms=%s",
            request_id,
            trace.mode.value,
            latency_ms,
        )
        return audit

    def get_audit_records(
        self,
        tenant_id: str,
        session_id: str | None = None,
        limit: int = 100,
    ) -> list[AuditRecord]:
        """
        Get the most recent audit records for a tenant/session.

        Args:
            tenant_id: Only records with this tenant ID are returned
            session_id: Optional session filter; ``None`` or an empty
                string means "any session"
            limit: Maximum number of records to return; zero or a negative
                value yields an empty list

        Returns:
            Up to ``limit`` matching records, oldest first, taken from the
            end of the stored list.
        """
        # ``records[-0:]`` would be the whole list and a negative limit would
        # slice from the front, so non-positive limits are handled up front.
        if limit <= 0:
            return []

        matches = [
            record
            for record in self._audit_records
            if record.tenant_id == tenant_id
            and (not session_id or record.session_id == session_id)
        ]
        return matches[-limit:]

    def clear_audit_records(self, tenant_id: str | None = None) -> int:
        """
        Clear audit records, optionally filtered by tenant.

        Args:
            tenant_id: Remove only this tenant's records. ``None`` removes
                everything. An empty string is treated as a (probably
                non-matching) tenant ID, never as "all tenants", so a blank
                value cannot wipe other tenants' records by accident.

        Returns:
            Number of records removed
        """
        before = len(self._audit_records)

        if tenant_id is None:
            self._audit_records = []
            return before

        self._audit_records = [
            record for record in self._audit_records if record.tenant_id != tenant_id
        ]
        return before - len(self._audit_records)
|