292 lines
9.1 KiB
Python
292 lines
9.1 KiB
Python
|
|
"""
|
||
|
|
Tests for SSE event generator.
|
||
|
|
[AC-AISVC-07] Tests for message event generation with delta content.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import AsyncMock, MagicMock
|
||
|
|
|
||
|
|
from sse_starlette.sse import ServerSentEvent
|
||
|
|
|
||
|
|
from app.core.sse import (
|
||
|
|
create_message_event,
|
||
|
|
create_final_event,
|
||
|
|
create_error_event,
|
||
|
|
SSEStateMachine,
|
||
|
|
SSEState,
|
||
|
|
)
|
||
|
|
from app.services.orchestrator import OrchestratorService
|
||
|
|
from app.models import ChatRequest, ChannelType
|
||
|
|
|
||
|
|
|
||
|
|
class TestSSEEventGenerator:
    """
    [AC-AISVC-07] Test cases for SSE event generation.

    Each test builds a single event with one of the factories from
    ``app.core.sse`` (message / final / error) and checks the event name
    plus the JSON-decoded ``data`` payload.
    """

    @staticmethod
    def _decode(sse_event):
        """Return the JSON-decoded ``data`` payload of *sse_event*."""
        return json.loads(sse_event.data)

    def test_create_message_event_format(self):
        """
        [AC-AISVC-07] Test that message event has correct format.

        Expected shape:
        - event: "message"
        - data: JSON object containing a "delta" field
        """
        sse_event = create_message_event(delta="Hello, ")

        assert sse_event.event == "message"
        assert sse_event.data is not None

        payload = self._decode(sse_event)
        assert "delta" in payload
        assert payload["delta"] == "Hello, "

    def test_create_message_event_with_unicode(self):
        """
        [AC-AISVC-07] Non-ASCII delta text must survive the JSON round trip.
        """
        sse_event = create_message_event(delta="你好,世界!")

        assert sse_event.event == "message"
        assert self._decode(sse_event)["delta"] == "你好,世界!"

    def test_create_message_event_with_empty_delta(self):
        """
        [AC-AISVC-07] An empty delta is still emitted as a message event.
        """
        sse_event = create_message_event(delta="")

        assert sse_event.event == "message"
        assert self._decode(sse_event)["delta"] == ""

    def test_create_final_event_format(self):
        """
        [AC-AISVC-08] Final event carries reply, confidence and shouldTransfer.
        """
        sse_event = create_final_event(
            reply="Complete response",
            confidence=0.85,
            should_transfer=False,
        )

        assert sse_event.event == "final"
        payload = self._decode(sse_event)
        assert payload["reply"] == "Complete response"
        assert payload["confidence"] == 0.85
        assert payload["shouldTransfer"] is False

    def test_create_final_event_with_transfer_reason(self):
        """
        [AC-AISVC-08] Final event exposes the transfer reason when transferring.
        """
        sse_event = create_final_event(
            reply="I cannot help with this",
            confidence=0.3,
            should_transfer=True,
            transfer_reason="Low confidence score",
        )

        assert sse_event.event == "final"
        payload = self._decode(sse_event)
        assert payload["shouldTransfer"] is True
        assert payload["transferReason"] == "Low confidence score"

    def test_create_error_event_format(self):
        """
        [AC-AISVC-09] Error event carries a code and a human-readable message.
        """
        sse_event = create_error_event(
            code="GENERATION_ERROR",
            message="Failed to generate response",
        )

        assert sse_event.event == "error"
        payload = self._decode(sse_event)
        assert payload["code"] == "GENERATION_ERROR"
        assert payload["message"] == "Failed to generate response"

    def test_create_error_event_with_details(self):
        """
        [AC-AISVC-09] Optional error details are passed through unchanged.
        """
        sse_event = create_error_event(
            code="VALIDATION_ERROR",
            message="Invalid input",
            details=[{"field": "message", "error": "too long"}],
        )

        assert sse_event.event == "error"
        assert self._decode(sse_event)["details"] == [
            {"field": "message", "error": "too long"}
        ]
|
||
|
|
|
||
|
|
|
||
|
|
class TestOrchestratorStreaming:
    """
    [AC-AISVC-07] Test cases for orchestrator streaming with SSE events.

    Drives ``OrchestratorService.generate_stream`` end to end. Some tests use
    the service's default construction. Others inject a stub LLM client whose
    ``stream_generate`` is an async generator of chunk objects exposing
    ``delta`` and ``finish_reason``.
    """

    @pytest.fixture
    def orchestrator(self):
        return OrchestratorService()

    @pytest.fixture
    def chat_request(self):
        return ChatRequest(
            session_id="test_session",
            current_message="Hello",
            channel_type=ChannelType.WECHAT,
        )

    @staticmethod
    async def _collect_events(orchestrator, chat_request):
        """Drain ``generate_stream`` for tenant ``tenant_001`` into a list."""
        return [
            event
            async for event in orchestrator.generate_stream("tenant_001", chat_request)
        ]

    @staticmethod
    def _joined_deltas(events):
        """Concatenate, in order, the ``delta`` of every ``message`` event."""
        return "".join(
            json.loads(event.data)["delta"]
            for event in events
            if event.event == "message"
        )

    @staticmethod
    def _make_chunk(delta, finish_reason=None):
        """Build a stub LLM stream chunk with the given delta/finish_reason."""
        chunk = MagicMock()
        chunk.delta = delta
        chunk.finish_reason = finish_reason
        return chunk

    @pytest.mark.asyncio
    async def test_stream_yields_message_events(self, orchestrator, chat_request):
        """
        [AC-AISVC-07] Test that streaming yields message events with delta content.
        """
        events = await self._collect_events(orchestrator, chat_request)

        message_events = [e for e in events if e.event == "message"]
        final_events = [e for e in events if e.event == "final"]

        assert len(message_events) > 0, "Should have at least one message event"
        assert len(final_events) == 1, "Should have exactly one final event"

        for event in message_events:
            data = json.loads(event.data)
            assert "delta" in data
            assert isinstance(data["delta"], str)

    @pytest.mark.asyncio
    async def test_stream_message_events_contain_content(self, orchestrator, chat_request):
        """
        [AC-AISVC-07] Test that message events contain the expected content.
        """
        events = await self._collect_events(orchestrator, chat_request)

        full_content = self._joined_deltas(events)

        assert "Hello" in full_content, "Content should contain the user message"

    @pytest.mark.asyncio
    async def test_stream_event_sequence(self, orchestrator, chat_request):
        """
        [AC-AISVC-07, AC-AISVC-08] Test that events follow proper sequence.

        message* -> final -> close: every message precedes the single final
        event, and nothing is emitted after final.
        """
        events = await self._collect_events(orchestrator, chat_request)
        event_types = [e.event for e in events]

        # Assert explicitly so a missing final fails with a clear message
        # instead of a bare ValueError from list.index().
        assert "final" in event_types, "Stream should emit a final event"
        final_index = event_types.index("final")

        message_indices = [i for i, t in enumerate(event_types) if t == "message"]
        for msg_idx in message_indices:
            assert msg_idx < final_index, "All message events should come before final"

        assert final_index == len(event_types) - 1, "No events should follow final"

    @pytest.mark.asyncio
    async def test_stream_with_llm_client(self, chat_request):
        """
        [AC-AISVC-07] Test streaming with mock LLM client.

        Two non-empty chunks followed by an empty "stop" chunk should yield
        exactly two message events whose deltas join to the full reply.
        """
        chunks = [
            self._make_chunk("Hello"),
            self._make_chunk(" there!"),
            self._make_chunk("", finish_reason="stop"),
        ]

        async def mock_stream(*args, **kwargs):
            for chunk in chunks:
                yield chunk

        mock_llm = MagicMock()
        mock_llm.stream_generate = mock_stream

        orchestrator = OrchestratorService(llm_client=mock_llm)
        events = await self._collect_events(orchestrator, chat_request)

        message_events = [e for e in events if e.event == "message"]
        assert len(message_events) == 2, "Should have two message events"

        assert self._joined_deltas(events) == "Hello there!"

    @pytest.mark.asyncio
    async def test_stream_handles_error(self, chat_request):
        """
        [AC-AISVC-09] Test that streaming errors are converted to error events.

        The stub LLM yields one chunk and then raises mid-stream. The
        orchestrator is expected to surface this as a terminal ``error`` event
        carrying ``code`` and ``message`` rather than propagating the exception.
        """
        partial_chunk = self._make_chunk("Partial")

        async def failing_stream(*args, **kwargs):
            yield partial_chunk
            raise RuntimeError("LLM backend unavailable")

        mock_llm = MagicMock()
        mock_llm.stream_generate = failing_stream

        orchestrator = OrchestratorService(llm_client=mock_llm)
        events = await self._collect_events(orchestrator, chat_request)
        event_types = [e.event for e in events]

        assert "error" in event_types, "Stream failure should produce an error event"
        # The state machine forbids further events once an error is sent.
        assert event_types[-1] == "error", "No events should follow the error event"

        error_data = json.loads(events[-1].data)
        assert "code" in error_data
        assert "message" in error_data
|
||
|
|
|
||
|
|
|
||
|
|
class TestSSEStateMachineIntegration:
    """
    [AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Integration tests for SSE state machine.

    Every test starts from a machine already moved into the streaming state,
    then checks whether message events remain permitted after each transition.
    """

    @staticmethod
    async def _streaming_machine():
        """Create an ``SSEStateMachine`` and move it into the streaming state."""
        machine = SSEStateMachine()
        await machine.transition_to_streaming()
        return machine

    @pytest.mark.asyncio
    async def test_state_machine_prevents_events_after_final(self):
        """
        [AC-AISVC-08] Once final has been sent, message events are refused.
        """
        machine = await self._streaming_machine()
        assert machine.can_send_message() is True

        await machine.transition_to_final()

        assert machine.can_send_message() is False
        assert machine.state == SSEState.FINAL_SENT

    @pytest.mark.asyncio
    async def test_state_machine_prevents_events_after_error(self):
        """
        [AC-AISVC-09] Once an error has been sent, message events are refused.
        """
        machine = await self._streaming_machine()

        await machine.transition_to_error()

        assert machine.can_send_message() is False
        assert machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_state_machine_allows_multiple_message_events(self):
        """
        [AC-AISVC-07] While streaming, message permission is repeatable.
        """
        machine = await self._streaming_machine()

        assert all(machine.can_send_message() is True for _ in range(5))

        await machine.transition_to_final()
        assert machine.can_send_message() is False
|