"""
Trace Logger for Mid Platform.

[AC-MARH-02, AC-MARH-03, AC-MARH-12] Trace collection and audit logging.

Audit Fields:
- session_id, request_id, generation_id
- mode, intent, tool_calls
- guardrail_triggered, guardrail_rule_id
- interrupt_consumed
"""

import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any

from app.models.mid.schemas import (
    ExecutionMode,
    ToolCallTrace,
    TraceInfo,
)

logger = logging.getLogger(__name__)


@dataclass
class AuditRecord:
    """[AC-MARH-12] Audit record for database persistence.

    A flat snapshot of one finished request. ``TraceLogger.end_trace``
    builds it from the caller-supplied tenant/session/latency values plus
    the fields accumulated on the request's ``TraceInfo``.
    """

    tenant_id: str
    session_id: str
    request_id: str
    generation_id: str
    mode: ExecutionMode
    intent: str | None = None
    # Tool calls are stored as plain dicts (``model_dump()`` output in
    # ``end_trace``), not as ToolCallTrace objects.
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    guardrail_triggered: bool = False
    guardrail_rule_id: str | None = None
    interrupt_consumed: bool = False
    fallback_reason_code: str | None = None
    react_iterations: int = 0
    latency_ms: int = 0
    # "YYYY-MM-DD HH:MM:SS" string set by ``end_trace``; stays None on the
    # fallback record returned when the trace was not found.
    created_at: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return the record as a plain dict.

        ``mode`` is emitted as its enum ``.value``; every other field is
        passed through unchanged (``tool_calls`` is the same list object,
        not a copy).
        """
        return {
            "tenant_id": self.tenant_id,
            "session_id": self.session_id,
            "request_id": self.request_id,
            "generation_id": self.generation_id,
            "mode": self.mode.value,
            "intent": self.intent,
            "tool_calls": self.tool_calls,
            "guardrail_triggered": self.guardrail_triggered,
            "guardrail_rule_id": self.guardrail_rule_id,
            "interrupt_consumed": self.interrupt_consumed,
            "fallback_reason_code": self.fallback_reason_code,
            "react_iterations": self.react_iterations,
            "latency_ms": self.latency_ms,
            "created_at": self.created_at,
        }


class TraceLogger:
    """
    [AC-MARH-02, AC-MARH-03, AC-MARH-12] Trace logger for observability and audit.

    Features:
    - Request-scoped trace context
    - Tool call tracing
    - Guardrail event logging with rule_id
    - Interrupt consumption tracking
    - Audit record generation

    State is held in memory only: active traces are keyed by request ID and
    finished audit records are appended to a list that is never trimmed
    except through ``clear_audit_records``.
    """

    def __init__(self) -> None:
        # Active (not yet ended) traces, keyed by request_id.
        self._traces: dict[str, TraceInfo] = {}
        # Audit records of ended traces, in the order they were ended.
        self._audit_records: list[AuditRecord] = []

    def start_trace(
        self,
        tenant_id: str,
        session_id: str,
        request_id: str | None = None,
        generation_id: str | None = None,
    ) -> TraceInfo:
        """
        [AC-MARH-12] Start a new trace context.

        The trace starts in ``ExecutionMode.AGENT``. If a trace with the same
        ``request_id`` is already active it is replaced.

        Args:
            tenant_id: Tenant ID (accepted for interface symmetry; not stored
                on the trace and not logged here)
            session_id: Session ID (only written to the log line)
            request_id: Request ID (auto-generated if not provided or empty)
            generation_id: Generation ID for interrupt handling
                (auto-generated if not provided or empty)

        Returns:
            TraceInfo for the new trace
        """
        request_id = request_id or str(uuid.uuid4())
        generation_id = generation_id or str(uuid.uuid4())

        trace = TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id=request_id,
            generation_id=generation_id,
        )
        self._traces[request_id] = trace

        logger.info(
            f"[AC-MARH-12] Trace started: request_id={request_id}, "
            f"session_id={session_id}, generation_id={generation_id}"
        )
        return trace

    def get_trace(self, request_id: str) -> TraceInfo | None:
        """Get the active trace for a request ID, or None if there is none."""
        return self._traces.get(request_id)

    def update_trace(
        self,
        request_id: str,
        mode: ExecutionMode | None = None,
        intent: str | None = None,
        guardrail_triggered: bool | None = None,
        guardrail_rule_id: str | None = None,
        interrupt_consumed: bool | None = None,
        fallback_reason_code: str | None = None,
        react_iterations: int | None = None,
        tool_calls: list[ToolCallTrace] | None = None,
    ) -> TraceInfo | None:
        """
        [AC-MARH-02, AC-MARH-03, AC-MARH-12] Update trace with execution details.

        Only arguments that are not None are written, so a field cannot be
        reset to None through this method. ``tool_calls`` replaces the
        trace's list with the given list object (no copy is made).

        Returns:
            The updated TraceInfo, or None (after a warning) if no active
            trace exists for ``request_id``.
        """
        trace = self._traces.get(request_id)
        if not trace:
            logger.warning(f"[AC-MARH-12] Trace not found: {request_id}")
            return None

        updates: dict[str, Any] = {
            "mode": mode,
            "intent": intent,
            "guardrail_triggered": guardrail_triggered,
            "guardrail_rule_id": guardrail_rule_id,
            "interrupt_consumed": interrupt_consumed,
            "fallback_reason_code": fallback_reason_code,
            "react_iterations": react_iterations,
            "tool_calls": tool_calls,
        }
        for attr_name, value in updates.items():
            # `is not None` (not truthiness) so False / 0 / [] are applied.
            if value is not None:
                setattr(trace, attr_name, value)

        return trace

    def add_tool_call(
        self,
        request_id: str,
        tool_call: ToolCallTrace,
    ) -> None:
        """
        [AC-MARH-12] Add tool call trace to request.

        Appends ``tool_call`` to the trace's ``tool_calls`` and records its
        ``tool_name`` in ``tools_used`` once (first-seen order). Logs a
        warning and does nothing if no active trace exists for
        ``request_id``.
        """
        trace = self._traces.get(request_id)
        if not trace:
            logger.warning(f"[AC-MARH-12] Trace not found for tool call: {request_id}")
            return

        # Both lists may still be None on a fresh trace; create them lazily.
        if trace.tool_calls is None:
            trace.tool_calls = []
        trace.tool_calls.append(tool_call)

        if trace.tools_used is None:
            trace.tools_used = []
        if tool_call.tool_name not in trace.tools_used:
            trace.tools_used.append(tool_call.tool_name)

        logger.debug(
            f"[AC-MARH-12] Tool call recorded: request_id={request_id}, "
            f"tool={tool_call.tool_name}, status={tool_call.status.value}"
        )

    def end_trace(
        self,
        request_id: str,
        tenant_id: str,
        session_id: str,
        latency_ms: int,
    ) -> AuditRecord:
        """
        [AC-MARH-12] End trace and create audit record.

        When the trace exists, an AuditRecord is built from it, appended to
        the in-memory audit list, and the trace is removed from the active
        set.

        When the trace does not exist, a warning is logged and a minimal
        fallback record is returned (fresh ``generation_id``,
        ``ExecutionMode.FIXED``, no ``created_at``). That fallback record is
        NOT stored, so it will not show up in ``get_audit_records``.

        Returns:
            The audit record for this request.
        """
        trace = self._traces.get(request_id)
        if not trace:
            logger.warning(f"[AC-MARH-12] Trace not found for end: {request_id}")
            return AuditRecord(
                tenant_id=tenant_id,
                session_id=session_id,
                request_id=request_id,
                generation_id=str(uuid.uuid4()),
                mode=ExecutionMode.FIXED,
                latency_ms=latency_ms,
            )

        audit = AuditRecord(
            tenant_id=tenant_id,
            session_id=session_id,
            request_id=request_id,
            generation_id=trace.generation_id or str(uuid.uuid4()),
            mode=trace.mode,
            intent=trace.intent,
            tool_calls=[tc.model_dump() for tc in trace.tool_calls] if trace.tool_calls else [],
            # `or` normalises unset (None) trace fields to the record defaults.
            guardrail_triggered=trace.guardrail_triggered or False,
            guardrail_rule_id=trace.guardrail_rule_id,
            interrupt_consumed=trace.interrupt_consumed or False,
            fallback_reason_code=trace.fallback_reason_code,
            react_iterations=trace.react_iterations or 0,
            latency_ms=latency_ms,
            # time.strftime without a time tuple formats the current *local*
            # time; the string carries no timezone offset.
            created_at=time.strftime("%Y-%m-%d %H:%M:%S"),
        )

        self._audit_records.append(audit)
        # Removed only after the record was built, so a failure while
        # building it leaves the trace in place.
        self._traces.pop(request_id, None)

        logger.info(
            f"[AC-MARH-12] Trace ended: request_id={request_id}, "
            f"mode={trace.mode.value}, latency_ms={latency_ms}"
        )
        return audit

    def get_audit_records(
        self,
        tenant_id: str,
        session_id: str | None = None,
        limit: int = 100,
    ) -> list[AuditRecord]:
        """Get audit records for a tenant/session.

        Args:
            tenant_id: Only records with exactly this tenant ID are returned.
            session_id: Optional session filter. None or an empty string
                means "all sessions of the tenant".
            limit: Maximum number of records to return. A non-positive
                limit returns an empty list.

        Returns:
            Up to ``limit`` of the most recently ended matching records, in
            the order they were ended (oldest first).
        """
        # Guard first: `records[-0:]` is `records[0:]`, i.e. *everything*,
        # and a negative limit would slice off the oldest records instead
        # of limiting the result.
        if limit <= 0:
            return []

        records = [
            r
            for r in self._audit_records
            if r.tenant_id == tenant_id and (not session_id or r.session_id == session_id)
        ]
        return records[-limit:]

    def clear_audit_records(self, tenant_id: str | None = None) -> int:
        """Clear audit records, optionally filtered by tenant.

        Args:
            tenant_id: When given, only that tenant's records are removed.
                When None, all records are removed.

        Returns:
            The number of records removed.
        """
        original_count = len(self._audit_records)
        # `is not None` rather than truthiness: an empty-string tenant ID
        # must not fall through to the "clear every tenant" branch.
        if tenant_id is not None:
            self._audit_records = [
                r for r in self._audit_records if r.tenant_id != tenant_id
            ]
        else:
            self._audit_records = []
        return original_count - len(self._audit_records)