# ai-robot-core/ai-service/app/core/sse.py
#
# 174 lines
# 5.4 KiB
# Python

"""
SSE utilities for AI Service.
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] SSE event generation and state machine.
"""
import asyncio
import json
import logging
from enum import Enum
from typing import Any, AsyncGenerator
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from app.core.config import get_settings
from app.models import SSEErrorEvent, SSEEventType, SSEFinalEvent, SSEMessageEvent
# Module-level logger: used for debug-level state-transition messages and for
# errors raised while streaming SSE content.
logger = logging.getLogger(__name__)
class SSEState(str, Enum):
    """Lifecycle states of a single SSE stream, tracked by ``SSEStateMachine``.

    Members also subclass ``str``, so each compares equal to its plain-string
    value (e.g. ``SSEState.INIT == "INIT"``).
    """

    # Stream object created; nothing has been sent yet.
    INIT = "INIT"
    # Streaming in progress; this is the only state in which message events
    # may be forwarded.
    STREAMING = "STREAMING"
    # A terminal event has gone out: either the final reply ...
    FINAL_SENT = "FINAL_SENT"
    # ... or an error.
    ERROR_SENT = "ERROR_SENT"
    # Stream torn down.
    CLOSED = "CLOSED"
class SSEStateMachine:
    """
    [AC-AISVC-08, AC-AISVC-09] SSE state machine ensuring proper event sequence.

    State transitions: INIT -> STREAMING -> FINAL_SENT/ERROR_SENT -> CLOSED

    ERROR_SENT may also be entered directly from INIT, and ``close()`` moves to
    CLOSED from any state. Each transition runs under an ``asyncio.Lock``, so
    two coroutines racing for the same transition cannot both succeed.
    """

    def __init__(self):
        self._state = SSEState.INIT
        self._lock = asyncio.Lock()

    @property
    def state(self) -> SSEState:
        """Current state (read without taking the lock)."""
        return self._state

    async def transition_to_streaming(self) -> bool:
        """Move INIT -> STREAMING.

        Returns:
            True if the transition happened, False if the machine was not in INIT.
        """
        async with self._lock:
            if self._state == SSEState.INIT:
                self._state = SSEState.STREAMING
                logger.debug("[AC-AISVC-07] SSE state transition: INIT -> STREAMING")
                return True
            return False

    async def transition_to_final(self) -> bool:
        """Move STREAMING -> FINAL_SENT.

        Returns:
            True if the transition happened, False if the machine was not in
            STREAMING (e.g. a final or error event was already sent).
        """
        async with self._lock:
            if self._state == SSEState.STREAMING:
                self._state = SSEState.FINAL_SENT
                logger.debug("[AC-AISVC-08] SSE state transition: STREAMING -> FINAL_SENT")
                return True
            return False

    async def transition_to_error(self) -> bool:
        """Move INIT or STREAMING -> ERROR_SENT.

        Returns:
            True if the transition happened, False if a terminal event was
            already sent or the machine is closed.
        """
        async with self._lock:
            if self._state in (SSEState.INIT, SSEState.STREAMING):
                # Capture the source state BEFORE overwriting it; logging
                # self._state after the assignment always printed ERROR_SENT.
                previous = self._state
                self._state = SSEState.ERROR_SENT
                logger.debug(
                    "[AC-AISVC-09] SSE state transition: %s -> ERROR_SENT",
                    previous.value,
                )
                return True
            return False

    async def close(self) -> None:
        """Unconditionally move to CLOSED, whatever the current state."""
        async with self._lock:
            self._state = SSEState.CLOSED
            logger.debug("SSE state transition: -> CLOSED")

    def can_send_message(self) -> bool:
        """Return True only while in STREAMING (no terminal event sent yet)."""
        return self._state == SSEState.STREAMING
def format_sse_event(event_type: SSEEventType, data: dict[str, Any]) -> ServerSentEvent:
    """Serialize *data* to JSON and wrap it in a ``ServerSentEvent``.

    Args:
        event_type: enum member whose ``.value`` becomes the SSE event name.
        data: JSON-serializable payload. Non-ASCII characters are written
            as-is (``ensure_ascii=False``) instead of being escaped to ASCII.

    Returns:
        The event, ready to be yielded into an ``EventSourceResponse``.
    """
    payload = json.dumps(data, ensure_ascii=False)
    return ServerSentEvent(data=payload, event=event_type.value)
def create_message_event(delta: str) -> ServerSentEvent:
    """[AC-AISVC-07] Build a ``message`` event carrying one incremental chunk.

    Args:
        delta: the incremental piece of content for this event.
    """
    payload = SSEMessageEvent(delta=delta).model_dump()
    return format_sse_event(SSEEventType.MESSAGE, payload)
def create_final_event(
    reply: str,
    confidence: float,
    should_transfer: bool,
    transfer_reason: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> ServerSentEvent:
    """[AC-AISVC-08] Build the ``final`` event holding the complete response.

    Optional fields left as ``None`` are dropped from the payload, and the
    model's field aliases (if ``SSEFinalEvent`` declares any) are used as the
    JSON keys.
    """
    final_model = SSEFinalEvent(
        reply=reply,
        confidence=confidence,
        should_transfer=should_transfer,
        transfer_reason=transfer_reason,
        metadata=metadata,
    )
    payload = final_model.model_dump(by_alias=True, exclude_none=True)
    return format_sse_event(SSEEventType.FINAL, payload)
def create_error_event(
    code: str,
    message: str,
    details: list[dict[str, Any]] | None = None,
) -> ServerSentEvent:
    """[AC-AISVC-09] Build an ``error`` event.

    Args:
        code: error code string.
        message: error description.
        details: optional extra entries; omitted from the payload when None.
    """
    error_model = SSEErrorEvent(code=code, message=message, details=details)
    return format_sse_event(
        SSEEventType.ERROR,
        error_model.model_dump(exclude_none=True),
    )
async def ping_generator(interval_seconds: int) -> AsyncGenerator[str, None]:
    """
    [AC-AISVC-06] Generate ping comments for SSE keep-alive.

    Runs forever: sleeps ``interval_seconds``, then yields the SSE comment
    line ``: ping`` terminated by a blank line. Comment lines are not
    events, so they keep the connection alive without reaching client
    event handlers.
    """
    keepalive_frame = ": ping\n\n"
    while True:
        await asyncio.sleep(interval_seconds)
        yield keepalive_frame
class SSEResponseBuilder:
    """
    Builder for SSE response with proper event sequencing and ping keep-alive.

    Every ``build_response`` call gets its own ``SSEStateMachine``, so a single
    builder can produce several responses without a finished (CLOSED) stream
    silencing the next one.
    """

    def __init__(self):
        self._state_machine = SSEStateMachine()
        self._settings = get_settings()

    async def build_response(
        self,
        content_generator: AsyncGenerator[ServerSentEvent, None],
    ) -> EventSourceResponse:
        """
        Build SSE response with ping keep-alive mechanism.
        [AC-AISVC-06] Implements ping keep-alive to prevent connection timeout.

        Args:
            content_generator: async generator of events to forward. Once it
                yields a ``final`` or ``error`` event, any event it yields
                afterwards is dropped. If it raises before such a terminal
                event, a single ``STREAMING_ERROR`` error event is emitted.
                It is closed when the response stream ends.

        Returns:
            An ``EventSourceResponse`` wrapping the event stream, with
            ``settings.sse_ping_interval_seconds`` passed as its ``ping``
            interval.
        """
        # Fresh machine per response: a machine left in CLOSED by an earlier
        # stream would refuse every event of the next one.
        state_machine = SSEStateMachine()
        self._state_machine = state_machine

        async def event_generator() -> AsyncGenerator[ServerSentEvent, None]:
            await state_machine.transition_to_streaming()
            try:
                async for event in content_generator:
                    if not state_machine.can_send_message():
                        # A terminal event already went out; drop the rest.
                        break
                    yield event
                    # Record terminal events so the machine reaches
                    # FINAL_SENT / ERROR_SENT and no second terminal event
                    # (e.g. an error after a final) can be emitted.
                    event_name = getattr(event, "event", None)
                    if event_name == SSEEventType.FINAL.value:
                        await state_machine.transition_to_final()
                    elif event_name == SSEEventType.ERROR.value:
                        await state_machine.transition_to_error()
            except Exception as e:
                logger.exception("[AC-AISVC-09] Error during SSE streaming: %s", e)
                if await state_machine.transition_to_error():
                    yield create_error_event(
                        code="STREAMING_ERROR",
                        message=str(e),
                    )
            finally:
                await state_machine.close()
                # Finalize the producer deterministically (after a `break` or
                # a client disconnect it would otherwise stay suspended until
                # garbage collection).
                aclose = getattr(content_generator, "aclose", None)
                if aclose is not None:
                    try:
                        await aclose()
                    except Exception:
                        logger.exception("Error while closing SSE content generator")

        return EventSourceResponse(
            event_generator(),
            ping=self._settings.sse_ping_interval_seconds,
        )