171 lines
5.3 KiB
Python
171 lines
5.3 KiB
Python
"""
|
|
SSE utilities for AI Service.
|
|
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] SSE event generation and state machine.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from enum import Enum
|
|
from typing import Any, AsyncGenerator
|
|
|
|
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
|
|
|
|
from app.core.config import get_settings
|
|
from app.models import SSEErrorEvent, SSEEventType, SSEFinalEvent, SSEMessageEvent
|
|
|
|
logger = logging.getLogger(__name__)


class SSEState(str, Enum):
    """Lifecycle states of a single SSE response stream."""

    INIT = "INIT"
    STREAMING = "STREAMING"
    FINAL_SENT = "FINAL_SENT"
    ERROR_SENT = "ERROR_SENT"
    CLOSED = "CLOSED"


class SSEStateMachine:
    """
    [AC-AISVC-08, AC-AISVC-09] SSE state machine ensuring proper event sequence.

    State transitions: INIT -> STREAMING -> FINAL_SENT/ERROR_SENT -> CLOSED

    Each ``transition_to_*`` coroutine checks and updates the state while
    holding an ``asyncio.Lock`` and returns ``True`` only if the transition
    actually happened, so callers can use the result to decide whether to
    emit the corresponding event.
    """

    def __init__(self) -> None:
        self._state = SSEState.INIT
        self._lock = asyncio.Lock()

    @property
    def state(self) -> SSEState:
        """Current state (read without taking the lock)."""
        return self._state

    async def transition_to_streaming(self) -> bool:
        """Move INIT -> STREAMING; return False and change nothing from any other state."""
        async with self._lock:
            if self._state != SSEState.INIT:
                return False
            self._state = SSEState.STREAMING
            logger.debug("[AC-AISVC-07] SSE state transition: INIT -> STREAMING")
            return True

    async def transition_to_final(self) -> bool:
        """Move STREAMING -> FINAL_SENT; return False and change nothing from any other state."""
        async with self._lock:
            if self._state != SSEState.STREAMING:
                return False
            self._state = SSEState.FINAL_SENT
            logger.debug("[AC-AISVC-08] SSE state transition: STREAMING -> FINAL_SENT")
            return True

    async def transition_to_error(self) -> bool:
        """Move INIT or STREAMING -> ERROR_SENT; return False and change nothing otherwise."""
        async with self._lock:
            if self._state not in (SSEState.INIT, SSEState.STREAMING):
                return False
            # Capture the source state before overwriting it; reading
            # self._state after the assignment would always log ERROR_SENT.
            previous = self._state
            self._state = SSEState.ERROR_SENT
            logger.debug(
                "[AC-AISVC-09] SSE state transition: %s -> ERROR_SENT",
                previous.value,
            )
            return True

    async def close(self) -> None:
        """Unconditionally move to CLOSED from any state."""
        async with self._lock:
            self._state = SSEState.CLOSED
            logger.debug("SSE state transition: -> CLOSED")

    def can_send_message(self) -> bool:
        """Return True only while the stream is in the STREAMING state."""
        return self._state == SSEState.STREAMING
|
|
|
|
|
|
def format_sse_event(event_type: SSEEventType, data: dict[str, Any]) -> ServerSentEvent:
    """
    Wrap a JSON-serializable payload in a ``ServerSentEvent``.

    The event name is ``event_type.value``. The payload is JSON-encoded with
    ``ensure_ascii=False``, so non-ASCII text is emitted as-is rather than
    as ``\\uXXXX`` escapes.
    """
    name = event_type.value
    payload = json.dumps(data, ensure_ascii=False)
    return ServerSentEvent(event=name, data=payload)
|
|
|
|
|
|
def create_message_event(delta: str) -> ServerSentEvent:
    """[AC-AISVC-07] Build a ``message`` event carrying one incremental text chunk."""
    payload = SSEMessageEvent(delta=delta).model_dump()
    return format_sse_event(SSEEventType.MESSAGE, payload)
|
|
|
|
|
|
def create_final_event(
    reply: str,
    confidence: float,
    should_transfer: bool,
    transfer_reason: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> ServerSentEvent:
    """
    [AC-AISVC-08] Build the ``final`` event carrying the complete response.

    Fields whose value is ``None`` are dropped from the serialized payload
    (``model_dump(exclude_none=True)``).
    """
    fields: dict[str, Any] = {
        "reply": reply,
        "confidence": confidence,
        "should_transfer": should_transfer,
        "transfer_reason": transfer_reason,
        "metadata": metadata,
    }
    model = SSEFinalEvent(**fields)
    return format_sse_event(SSEEventType.FINAL, model.model_dump(exclude_none=True))
|
|
|
|
|
|
def create_error_event(
    code: str,
    message: str,
    details: list[dict[str, Any]] | None = None,
) -> ServerSentEvent:
    """
    [AC-AISVC-09] Build an ``error`` event.

    ``details`` is omitted from the serialized payload when it is ``None``
    (``model_dump(exclude_none=True)``).
    """
    fields: dict[str, Any] = {"code": code, "message": message, "details": details}
    model = SSEErrorEvent(**fields)
    return format_sse_event(SSEEventType.ERROR, model.model_dump(exclude_none=True))
|
|
|
|
|
|
async def ping_generator(interval_seconds: int) -> AsyncGenerator[str, None]:
    """
    [AC-AISVC-06] Endlessly yield SSE keep-alive comment frames.

    Sleeps ``interval_seconds`` before each frame, then yields the literal
    comment line ``": ping"`` terminated by a blank line. Comment frames are
    ignored by SSE clients, so they keep the connection alive without
    producing events.
    """
    ping_frame = ": ping\n\n"
    while True:
        await asyncio.sleep(interval_seconds)
        yield ping_frame
|
|
|
|
|
|
class SSEResponseBuilder:
    """
    Builder for SSE response with proper event sequencing and ping keep-alive.

    Every call to ``build_response`` gets its own ``SSEStateMachine``, so a
    single builder instance can produce several responses.
    """

    def __init__(self) -> None:
        # Kept for backward compatibility; build_response replaces it with the
        # state machine of the most recently built response.
        self._state_machine = SSEStateMachine()
        self._settings = get_settings()

    async def build_response(
        self,
        content_generator: AsyncGenerator[ServerSentEvent, None],
    ) -> EventSourceResponse:
        """
        Build SSE response with ping keep-alive mechanism.

        [AC-AISVC-06] Implements ping keep-alive to prevent connection timeout.
        Pings are delegated to ``EventSourceResponse`` via its ``ping``
        argument, using ``sse_ping_interval_seconds`` from settings.

        Args:
            content_generator: Async generator of events to stream. When it
                yields a ``final`` or ``error`` event, the state machine leaves
                STREAMING and any further events are not forwarded. If it
                raises, a single ``STREAMING_ERROR`` error event is emitted,
                unless a final/error event was already sent.

        Returns:
            An ``EventSourceResponse`` wrapping the sequenced event stream.
        """
        # Fresh state machine per response: the previous one ends in CLOSED,
        # and reusing it would make every later stream silently empty.
        state_machine = SSEStateMachine()
        self._state_machine = state_machine

        async def event_generator() -> AsyncGenerator[ServerSentEvent, None]:
            await state_machine.transition_to_streaming()
            try:
                async for event in content_generator:
                    if not state_machine.can_send_message():
                        break
                    yield event
                    # Record terminal events so that a later exception does not
                    # emit an error event after a final/error event was sent.
                    event_name = getattr(event, "event", None)
                    if event_name == SSEEventType.FINAL.value:
                        await state_machine.transition_to_final()
                    elif event_name == SSEEventType.ERROR.value:
                        await state_machine.transition_to_error()
            except Exception as e:
                # logger.exception keeps the traceback that logger.error dropped.
                logger.exception("[AC-AISVC-09] Error during SSE streaming: %s", e)
                if await state_machine.transition_to_error():
                    yield create_error_event(
                        code="STREAMING_ERROR",
                        message=str(e),
                    )
            finally:
                try:
                    # `break` (or a client disconnect) leaves the upstream
                    # generator suspended; close it now so its cleanup runs
                    # deterministically instead of at garbage-collection time.
                    aclose = getattr(content_generator, "aclose", None)
                    if aclose is not None:
                        await aclose()
                finally:
                    await state_machine.close()

        return EventSourceResponse(
            event_generator(),
            ping=self._settings.sse_ping_interval_seconds,
        )
|