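# File: ai-robot-core/ai-service/tests/test_sse_events.py
#
# Repository-viewer metadata captured with this copy: 292 lines, 9.1 KiB, Python.
# (Viewer navigation labels: Raw / Permalink / Normal View / History.)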

"""
Tests for SSE event generator.
[AC-AISVC-07] Tests for message event generation with delta content.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock
from sse_starlette.sse import ServerSentEvent
from app.core.sse import (
create_message_event,
create_final_event,
create_error_event,
SSEStateMachine,
SSEState,
)
from app.services.orchestrator import OrchestratorService
from app.models import ChatRequest, ChannelType
class TestSSEEventGenerator:
    """
    [AC-AISVC-07] Unit tests for the SSE event factory helpers.

    Every test builds an event with one of the ``create_*_event`` helpers,
    then checks the event name and the JSON-decoded ``data`` payload.
    """

    def test_create_message_event_format(self):
        """
        [AC-AISVC-07] A message event is named "message" and carries a
        JSON payload exposing the streamed text under "delta".
        """
        sse_event = create_message_event(delta="Hello, ")

        assert sse_event.event == "message"
        assert sse_event.data is not None

        payload = json.loads(sse_event.data)
        assert "delta" in payload
        assert payload["delta"] == "Hello, "

    def test_create_message_event_with_unicode(self):
        """
        [AC-AISVC-07] Non-ASCII delta text survives the JSON round trip.
        """
        sse_event = create_message_event(delta="你好,世界!")

        assert sse_event.event == "message"

        payload = json.loads(sse_event.data)
        assert payload["delta"] == "你好,世界!"

    def test_create_message_event_with_empty_delta(self):
        """
        [AC-AISVC-07] An empty delta is still emitted as a message event.
        """
        sse_event = create_message_event(delta="")

        assert sse_event.event == "message"

        payload = json.loads(sse_event.data)
        assert payload["delta"] == ""

    def test_create_final_event_format(self):
        """
        [AC-AISVC-08] A final event is named "final" and its payload
        exposes "reply", "confidence" and "shouldTransfer".
        """
        sse_event = create_final_event(
            reply="Complete response",
            confidence=0.85,
            should_transfer=False,
        )

        assert sse_event.event == "final"

        payload = json.loads(sse_event.data)
        assert payload["reply"] == "Complete response"
        assert payload["confidence"] == 0.85
        assert payload["shouldTransfer"] is False

    def test_create_final_event_with_transfer_reason(self):
        """
        [AC-AISVC-08] When a transfer reason is supplied, the final event
        payload reports it under "transferReason".
        """
        sse_event = create_final_event(
            reply="I cannot help with this",
            confidence=0.3,
            should_transfer=True,
            transfer_reason="Low confidence score",
        )

        assert sse_event.event == "final"

        payload = json.loads(sse_event.data)
        assert payload["shouldTransfer"] is True
        assert payload["transferReason"] == "Low confidence score"

    def test_create_error_event_format(self):
        """
        [AC-AISVC-09] An error event is named "error" and its payload
        carries the given "code" and "message".
        """
        sse_event = create_error_event(
            code="GENERATION_ERROR",
            message="Failed to generate response",
        )

        assert sse_event.event == "error"

        payload = json.loads(sse_event.data)
        assert payload["code"] == "GENERATION_ERROR"
        assert payload["message"] == "Failed to generate response"

    def test_create_error_event_with_details(self):
        """
        [AC-AISVC-09] Structured details passed to the helper are returned
        unchanged under "details".
        """
        sse_event = create_error_event(
            code="VALIDATION_ERROR",
            message="Invalid input",
            details=[{"field": "message", "error": "too long"}],
        )

        assert sse_event.event == "error"

        payload = json.loads(sse_event.data)
        assert payload["details"] == [{"field": "message", "error": "too long"}]
class TestOrchestratorStreaming:
    """
    [AC-AISVC-07] Test cases for orchestrator streaming with SSE events.

    The tests drain ``OrchestratorService.generate_stream`` into a list and
    inspect the ``event`` name and JSON ``data`` of each yielded item.
    """

    @staticmethod
    async def _collect_events(orchestrator, chat_request, tenant_id="tenant_001"):
        """Drain ``generate_stream`` and return all yielded events as a list."""
        return [
            event
            async for event in orchestrator.generate_stream(tenant_id, chat_request)
        ]

    @staticmethod
    def _join_deltas(message_events):
        """Concatenate the "delta" text of the given message events, in order."""
        return "".join(json.loads(event.data)["delta"] for event in message_events)

    @pytest.fixture
    def orchestrator(self):
        """Orchestrator built with its default constructor arguments."""
        return OrchestratorService()

    @pytest.fixture
    def chat_request(self):
        """Minimal chat request whose current message is "Hello"."""
        return ChatRequest(
            session_id="test_session",
            current_message="Hello",
            channel_type=ChannelType.WECHAT,
        )

    @pytest.mark.asyncio
    async def test_stream_yields_message_events(self, orchestrator, chat_request):
        """
        [AC-AISVC-07] Test that streaming yields message events with delta content.
        """
        events = await self._collect_events(orchestrator, chat_request)

        message_events = [e for e in events if e.event == "message"]
        final_events = [e for e in events if e.event == "final"]

        assert len(message_events) > 0, "Should have at least one message event"
        assert len(final_events) == 1, "Should have exactly one final event"

        for event in message_events:
            data = json.loads(event.data)
            assert "delta" in data
            assert isinstance(data["delta"], str)

    @pytest.mark.asyncio
    async def test_stream_message_events_contain_content(self, orchestrator, chat_request):
        """
        [AC-AISVC-07] Test that message events contain the expected content.
        """
        events = await self._collect_events(orchestrator, chat_request)

        message_events = [e for e in events if e.event == "message"]
        full_content = self._join_deltas(message_events)

        assert "Hello" in full_content, "Content should contain the user message"

    @pytest.mark.asyncio
    async def test_stream_event_sequence(self, orchestrator, chat_request):
        """
        [AC-AISVC-07, AC-AISVC-08] Test that events follow proper sequence.
        message* -> final -> close
        """
        events = await self._collect_events(orchestrator, chat_request)
        event_types = [e.event for e in events]

        # Fail with a readable assertion instead of the ValueError that
        # list.index() would raise when no final event was emitted.
        assert "final" in event_types, "Stream should emit a final event"
        final_index = event_types.index("final")

        message_indices = [i for i, t in enumerate(event_types) if t == "message"]
        for msg_idx in message_indices:
            assert msg_idx < final_index, "All message events should come before final"

    @pytest.mark.asyncio
    async def test_stream_with_llm_client(self, chat_request):
        """
        [AC-AISVC-07] Test streaming with mock LLM client.

        The mock yields two non-empty deltas followed by an empty delta with
        finish_reason "stop"; the test expects exactly two message events
        whose deltas join to "Hello there!".
        """
        chunks = [
            MagicMock(delta="Hello", finish_reason=None),
            MagicMock(delta=" there!", finish_reason=None),
            MagicMock(delta="", finish_reason="stop"),
        ]

        async def mock_stream(*args, **kwargs):
            for chunk in chunks:
                yield chunk

        mock_llm = MagicMock()
        mock_llm.stream_generate = mock_stream

        orchestrator = OrchestratorService(llm_client=mock_llm)
        events = await self._collect_events(orchestrator, chat_request)

        message_events = [e for e in events if e.event == "message"]
        assert len(message_events) == 2, "Should have two message events"
        assert self._join_deltas(message_events) == "Hello there!"

    @pytest.mark.skip(
        reason="AC-AISVC-09 streaming error test is not implemented yet"
    )
    @pytest.mark.asyncio
    async def test_stream_handles_error(self, orchestrator, chat_request):
        """
        [AC-AISVC-09] Test that streaming errors are converted to error events.

        TODO: implement. The body was an empty placeholder that always
        passed, so it is reported as skipped until a real check exists.
        NOTE(review): presumably this should inject an LLM client whose
        stream raises and assert that an "error" event is yielded; confirm
        against the orchestrator's error handling before writing it.
        """
class TestSSEStateMachineIntegration:
    """
    [AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Integration tests for SSE state machine.

    Each test drives a fresh ``SSEStateMachine`` through its transition
    coroutines and checks ``can_send_message()`` and ``state`` afterwards.
    """

    @pytest.mark.asyncio
    async def test_state_machine_prevents_events_after_final(self):
        """
        [AC-AISVC-08] After the final transition, messages may no longer be
        sent and the state is FINAL_SENT.
        """
        machine = SSEStateMachine()

        await machine.transition_to_streaming()
        assert machine.can_send_message() is True

        await machine.transition_to_final()
        assert machine.can_send_message() is False
        assert machine.state == SSEState.FINAL_SENT

    @pytest.mark.asyncio
    async def test_state_machine_prevents_events_after_error(self):
        """
        [AC-AISVC-09] After the error transition, messages may no longer be
        sent and the state is ERROR_SENT.
        """
        machine = SSEStateMachine()

        await machine.transition_to_streaming()
        await machine.transition_to_error()

        assert machine.can_send_message() is False
        assert machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_state_machine_allows_multiple_message_events(self):
        """
        [AC-AISVC-07] While streaming, sending stays allowed across repeated
        checks; it is refused once the final transition has happened.
        """
        machine = SSEStateMachine()
        await machine.transition_to_streaming()

        # Query five times in a row: the permission must not be one-shot.
        for _attempt in range(5):
            assert machine.can_send_message() is True

        await machine.transition_to_final()
        assert machine.can_send_message() is False