# ai-robot-core/ai-service/tests/test_sse_state_machine.py (Python, 377 lines, 12 KiB)

"""
Tests for SSE state machine and error handling.
[AC-AISVC-08, AC-AISVC-09] Tests for proper event sequence and error handling.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
from sse_starlette.sse import ServerSentEvent
from app.core.sse import (
SSEState,
SSEStateMachine,
create_error_event,
create_final_event,
create_message_event,
)
from app.main import app
from app.models import ChatRequest, ChannelType
class TestSSEStateMachineTransitions:
    """
    [AC-AISVC-08, AC-AISVC-09] Transition rules of ``SSEStateMachine``.

    Covers the happy path (INIT -> STREAMING -> FINAL_SENT), the error paths
    (INIT -> ERROR_SENT and STREAMING -> ERROR_SENT), the terminal states that
    reject every further transition, and ``close()``.
    """

    # Every test in this class is a coroutine; mark them all in one place.
    pytestmark = pytest.mark.asyncio

    @staticmethod
    async def _streaming_machine():
        """Build a fresh state machine and move it from INIT to STREAMING."""
        machine = SSEStateMachine()
        await machine.transition_to_streaming()
        return machine

    async def test_init_to_streaming_transition(self):
        """
        [AC-AISVC-08] INIT -> STREAMING is allowed.
        """
        machine = SSEStateMachine()
        assert machine.state == SSEState.INIT

        assert await machine.transition_to_streaming() is True
        assert machine.state == SSEState.STREAMING

    async def test_streaming_to_final_transition(self):
        """
        [AC-AISVC-08] STREAMING -> FINAL_SENT is allowed.
        """
        machine = await self._streaming_machine()

        assert await machine.transition_to_final() is True
        assert machine.state == SSEState.FINAL_SENT

    async def test_streaming_to_error_transition(self):
        """
        [AC-AISVC-09] STREAMING -> ERROR_SENT is allowed.
        """
        machine = await self._streaming_machine()

        assert await machine.transition_to_error() is True
        assert machine.state == SSEState.ERROR_SENT

    async def test_init_to_error_transition(self):
        """
        [AC-AISVC-09] INIT -> ERROR_SENT is allowed (failure before any output).
        """
        machine = SSEStateMachine()

        assert await machine.transition_to_error() is True
        assert machine.state == SSEState.ERROR_SENT

    async def test_cannot_transition_from_final(self):
        """
        [AC-AISVC-08] FINAL_SENT is terminal: every further transition fails.
        """
        machine = await self._streaming_machine()
        await machine.transition_to_final()

        assert await machine.transition_to_streaming() is False
        assert await machine.transition_to_error() is False
        assert machine.state == SSEState.FINAL_SENT

    async def test_cannot_transition_from_error(self):
        """
        [AC-AISVC-09] ERROR_SENT is terminal: every further transition fails.
        """
        machine = await self._streaming_machine()
        await machine.transition_to_error()

        assert await machine.transition_to_streaming() is False
        assert await machine.transition_to_final() is False
        assert machine.state == SSEState.ERROR_SENT

    async def test_cannot_send_message_after_final(self):
        """
        [AC-AISVC-08] ``can_send_message()`` is False once FINAL_SENT is reached.
        """
        machine = await self._streaming_machine()
        await machine.transition_to_final()

        assert machine.can_send_message() is False

    async def test_cannot_send_message_after_error(self):
        """
        [AC-AISVC-09] ``can_send_message()`` is False once ERROR_SENT is reached.
        """
        machine = await self._streaming_machine()
        await machine.transition_to_error()

        assert machine.can_send_message() is False

    async def test_close_transition(self):
        """
        [AC-AISVC-08] ``close()`` moves a finished stream to CLOSED.
        """
        machine = await self._streaming_machine()
        await machine.transition_to_final()
        await machine.close()

        assert machine.state == SSEState.CLOSED
class TestSSEEventSequence:
    """
    [AC-AISVC-08, AC-AISVC-09] End-to-end checks of the SSE event order
    produced by ``POST /ai/chat``.

    Event names are read from the ``event:`` field lines of the response body,
    not by substring search. This prevents text inside ``data:`` payloads from
    being mistaken for events.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        return {"X-Tenant-Id": "tenant_001", "Accept": "text/event-stream"}

    @pytest.fixture
    def valid_body(self):
        return {
            "sessionId": "test_session",
            "currentMessage": "Hello",
            "channelType": "wechat",
        }

    @staticmethod
    def _event_types(content):
        """
        Return the value of every ``event:`` field in an SSE body, in order.

        Only lines that *start* with ``event:`` count. Both ``event:message``
        and ``event: message`` are accepted because surrounding whitespace is
        stripped. ``str.splitlines`` copes with LF, CRLF and CR line endings,
        all of which the SSE format permits.
        """
        return [
            line[len("event:"):].strip()
            for line in content.splitlines()
            if line.startswith("event:")
        ]

    def test_sse_sequence_message_then_final(self, client, valid_headers, valid_body):
        """
        [AC-AISVC-08] Test that SSE events follow: message* -> final -> close.
        """
        response = client.post("/ai/chat", json=valid_body, headers=valid_headers)
        assert response.status_code == 200
        event_types = self._event_types(response.text)

        assert "message" in event_types
        assert "final" in event_types
        final_idx = event_types.index("final")
        message_positions = [i for i, t in enumerate(event_types) if t == "message"]
        # Compare the LAST message with the first final, so an interleaved
        # message -> final -> message sequence is also rejected.
        assert max(message_positions) < final_idx, "final should come after message events"

    def test_sse_only_one_final_event(self, client, valid_headers, valid_body):
        """
        [AC-AISVC-08] Test that there is exactly one final event.
        """
        response = client.post("/ai/chat", json=valid_body, headers=valid_headers)
        assert response.status_code == 200

        final_count = self._event_types(response.text).count("final")
        assert final_count == 1, f"Expected exactly 1 final event, got {final_count}"

    def test_sse_no_events_after_final(self, client, valid_headers, valid_body):
        """
        [AC-AISVC-08] Test that no message (or error) events appear after the
        final event.
        """
        response = client.post("/ai/chat", json=valid_body, headers=valid_headers)
        assert response.status_code == 200
        event_types = self._event_types(response.text)

        # Without this check the test would pass vacuously on a stream that
        # never sends a final event.
        assert "final" in event_types, "stream never sent a final event"
        after_final = event_types[event_types.index("final") + 1:]
        if "message" in after_final:
            pytest.fail("Found message event after final event")
        if "error" in after_final:
            pytest.fail("Found error event after final event")
class TestSSEErrorHandling:
    """
    [AC-AISVC-09] Test cases for SSE error handling.

    ``create_error_event`` is called without ``await``, so its tests are
    ordinary synchronous tests. They need neither an event loop nor the
    ``asyncio`` marker.
    """

    def test_error_event_format(self):
        """
        [AC-AISVC-09] An error event is named ``error`` and carries ``code``,
        ``message`` and ``details`` as JSON in its ``data`` field.
        """
        event = create_error_event(
            code="TEST_ERROR",
            message="Test error message",
            details=[{"field": "test"}],
        )
        assert event.event == "error"
        data = json.loads(event.data)
        assert data["code"] == "TEST_ERROR"
        assert data["message"] == "Test error message"
        assert data["details"] == [{"field": "test"}]

    def test_error_event_without_details(self):
        """
        [AC-AISVC-09] When no details are supplied, the JSON payload has no
        ``details`` key at all.
        """
        event = create_error_event(
            code="SIMPLE_ERROR",
            message="Simple error",
        )
        assert event.event == "error"
        data = json.loads(event.data)
        assert data["code"] == "SIMPLE_ERROR"
        assert data["message"] == "Simple error"
        assert "details" not in data

    def test_missing_tenant_id_returns_400(self):
        """
        [AC-AISVC-12] A request without ``X-Tenant-Id`` is rejected with
        HTTP 400 and error code ``MISSING_TENANT_ID``.
        """
        client = TestClient(app)
        headers = {"Accept": "text/event-stream"}
        body = {
            "sessionId": "test_session",
            "currentMessage": "Hello",
            "channelType": "wechat",
        }
        response = client.post("/ai/chat", json=body, headers=headers)
        assert response.status_code == 400
        data = response.json()
        assert data["code"] == "MISSING_TENANT_ID"
class TestSSEStateConcurrency:
    """
    [AC-AISVC-08, AC-AISVC-09] Concurrency tests for ``SSEStateMachine``.

    Competing transitions are awaited together via ``asyncio.gather`` on one
    event loop. Exactly one of them may report success.
    """

    @pytest.mark.asyncio
    async def test_concurrent_transitions(self):
        """
        [AC-AISVC-08] Race three INIT -> STREAMING transitions; one wins.
        """
        import asyncio

        machine = SSEStateMachine()
        # gather returns each coroutine's result, so no shared list is needed.
        outcomes = await asyncio.gather(
            *(machine.transition_to_streaming() for _ in range(3))
        )

        assert sum(outcomes) == 1, "Only one transition should succeed"
        assert machine.state == SSEState.STREAMING

    @pytest.mark.asyncio
    async def test_concurrent_final_transitions(self):
        """
        [AC-AISVC-08] Race two STREAMING -> FINAL_SENT transitions; one wins.
        """
        import asyncio

        machine = SSEStateMachine()
        await machine.transition_to_streaming()

        outcomes = await asyncio.gather(
            *(machine.transition_to_final() for _ in range(2))
        )

        assert sum(outcomes) == 1, "Only one final transition should succeed"
        assert machine.state == SSEState.FINAL_SENT
class TestSSEIntegrationWithOrchestrator:
    """
    [AC-AISVC-08, AC-AISVC-09] Integration tests for SSE with Orchestrator.

    Each test drains ``OrchestratorService.generate_stream`` and checks the
    order of the emitted event names.
    """

    @staticmethod
    async def _collect_event_types(orchestrator, request):
        """Drain ``generate_stream`` for tenant ``"tenant"``; return event names in order."""
        return [
            event.event
            async for event in orchestrator.generate_stream("tenant", request)
        ]

    @pytest.mark.asyncio
    async def test_orchestrator_stream_with_error(self):
        """
        [AC-AISVC-09] A mid-stream LLM failure yields message event(s) followed
        by an error event, and no final event.
        """
        from app.services.orchestrator import OrchestratorService

        mock_llm = MagicMock()

        async def failing_stream(*args, **kwargs):
            # One good chunk, then the connection "drops".
            yield MagicMock(delta="Hello", finish_reason=None)
            raise Exception("LLM connection lost")

        mock_llm.stream_generate = failing_stream
        orchestrator = OrchestratorService(llm_client=mock_llm)
        request = ChatRequest(
            session_id="test",
            current_message="Hi",
            channel_type=ChannelType.WECHAT,
        )

        event_types = await self._collect_event_types(orchestrator, request)

        assert "message" in event_types
        assert "error" in event_types
        assert event_types.index("message") < event_types.index("error")
        # ERROR_SENT is terminal for SSEStateMachine (it cannot move to
        # FINAL_SENT), so a failed stream must not also report success.
        assert "final" not in event_types, "final must not be sent after an error"

    @pytest.mark.asyncio
    async def test_orchestrator_stream_normal_flow(self):
        """
        [AC-AISVC-08] A normal stream is message* followed by exactly one
        final event, with no error event.
        """
        from app.services.orchestrator import OrchestratorService

        orchestrator = OrchestratorService()
        request = ChatRequest(
            session_id="test",
            current_message="Hi",
            channel_type=ChannelType.WECHAT,
        )

        event_types = await self._collect_event_types(orchestrator, request)

        assert "message" in event_types
        assert "final" in event_types
        assert "error" not in event_types
        assert event_types.count("final") == 1
        final_index = event_types.index("final")
        message_positions = [i for i, t in enumerate(event_types) if t == "message"]
        assert max(message_positions) < final_index, "message events should come before final"