""" SSE utilities for AI Service. [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] SSE event generation and state machine. """ import asyncio import json import logging from enum import Enum from typing import Any, AsyncGenerator from sse_starlette.sse import EventSourceResponse, ServerSentEvent from app.core.config import get_settings from app.models import SSEErrorEvent, SSEEventType, SSEFinalEvent, SSEMessageEvent logger = logging.getLogger(__name__) class SSEState(str, Enum): INIT = "INIT" STREAMING = "STREAMING" FINAL_SENT = "FINAL_SENT" ERROR_SENT = "ERROR_SENT" CLOSED = "CLOSED" class SSEStateMachine: """ [AC-AISVC-08, AC-AISVC-09] SSE state machine ensuring proper event sequence. State transitions: INIT -> STREAMING -> FINAL_SENT/ERROR_SENT -> CLOSED """ def __init__(self): self._state = SSEState.INIT self._lock = asyncio.Lock() @property def state(self) -> SSEState: return self._state async def transition_to_streaming(self) -> bool: async with self._lock: if self._state == SSEState.INIT: self._state = SSEState.STREAMING logger.debug(f"[AC-AISVC-07] SSE state transition: INIT -> STREAMING") return True return False async def transition_to_final(self) -> bool: async with self._lock: if self._state == SSEState.STREAMING: self._state = SSEState.FINAL_SENT logger.debug(f"[AC-AISVC-08] SSE state transition: STREAMING -> FINAL_SENT") return True return False async def transition_to_error(self) -> bool: async with self._lock: if self._state in (SSEState.INIT, SSEState.STREAMING): self._state = SSEState.ERROR_SENT logger.debug(f"[AC-AISVC-09] SSE state transition: {self._state} -> ERROR_SENT") return True return False async def close(self) -> None: async with self._lock: self._state = SSEState.CLOSED logger.debug("SSE state transition: -> CLOSED") def can_send_message(self) -> bool: return self._state == SSEState.STREAMING def format_sse_event(event_type: SSEEventType, data: dict[str, Any]) -> ServerSentEvent: """Format data as SSE event.""" return ServerSentEvent( event=event_type.value, data=json.dumps(data, ensure_ascii=False), ) def create_message_event(delta: str) -> ServerSentEvent: """[AC-AISVC-07] Create a message event with incremental content.""" event_data = SSEMessageEvent(delta=delta) return format_sse_event(SSEEventType.MESSAGE, event_data.model_dump()) def create_final_event( reply: str, confidence: float, should_transfer: bool, transfer_reason: str | None = None, metadata: dict[str, Any] | None = None, ) -> ServerSentEvent: """[AC-AISVC-08] Create a final event with complete response.""" event_data = SSEFinalEvent( reply=reply, confidence=confidence, should_transfer=should_transfer, transfer_reason=transfer_reason, metadata=metadata, ) return format_sse_event(SSEEventType.FINAL, event_data.model_dump(exclude_none=True)) def create_error_event( code: str, message: str, details: list[dict[str, Any]] | None = None, ) -> ServerSentEvent: """[AC-AISVC-09] Create an error event.""" event_data = SSEErrorEvent( code=code, message=message, details=details, ) return format_sse_event(SSEEventType.ERROR, event_data.model_dump(exclude_none=True)) async def ping_generator(interval_seconds: int) -> AsyncGenerator[str, None]: """ [AC-AISVC-06] Generate ping comments for SSE keep-alive. Sends ': ping' as comment lines (not events) to keep connection alive. """ while True: await asyncio.sleep(interval_seconds) yield ": ping\n\n" class SSEResponseBuilder: """ Builder for SSE response with proper event sequencing and ping keep-alive. """ def __init__(self): self._state_machine = SSEStateMachine() self._settings = get_settings() async def build_response( self, content_generator: AsyncGenerator[ServerSentEvent, None], ) -> EventSourceResponse: """ Build SSE response with ping keep-alive mechanism. [AC-AISVC-06] Implements ping keep-alive to prevent connection timeout. """ async def event_generator() -> AsyncGenerator[ServerSentEvent, None]: await self._state_machine.transition_to_streaming() try: async for event in content_generator: if self._state_machine.can_send_message(): yield event else: break except Exception as e: logger.error(f"[AC-AISVC-09] Error during SSE streaming: {e}") if await self._state_machine.transition_to_error(): yield create_error_event( code="STREAMING_ERROR", message=str(e), ) finally: await self._state_machine.close() return EventSourceResponse( event_generator(), ping=self._settings.sse_ping_interval_seconds, )