"""
Tests for SSE state machine and error handling.

[AC-AISVC-08, AC-AISVC-09] Tests for proper event sequence and error handling.
"""

import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient
from sse_starlette.sse import ServerSentEvent

from app.core.sse import (
    SSEState,
    SSEStateMachine,
    create_error_event,
    create_final_event,
    create_message_event,
)
from app.main import app
from app.models import ChatRequest, ChannelType


def _sse_event_names(body: str):
    """
    Return the values of the ``event:`` fields in an SSE response body, in order.

    Only lines that *start* with ``event:`` are considered, so text that happens
    to appear inside a ``data:`` payload cannot be mistaken for an event. The
    space after the colon is optional in the SSE wire format, and servers may
    use CRLF line endings; ``splitlines`` plus ``strip`` handles both.
    """
    prefix = "event:"
    return [
        line[len(prefix):].strip()
        for line in body.splitlines()
        if line.startswith(prefix)
    ]


class TestSSEStateMachineTransitions:
    """
    [AC-AISVC-08, AC-AISVC-09] Test cases for SSE state machine transitions.
    """

    @pytest.mark.asyncio
    async def test_init_to_streaming_transition(self):
        """
        [AC-AISVC-08] Test INIT -> STREAMING transition.
        """
        state_machine = SSEStateMachine()
        assert state_machine.state == SSEState.INIT

        success = await state_machine.transition_to_streaming()

        assert success is True
        assert state_machine.state == SSEState.STREAMING

    @pytest.mark.asyncio
    async def test_streaming_to_final_transition(self):
        """
        [AC-AISVC-08] Test STREAMING -> FINAL_SENT transition.
        """
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()

        success = await state_machine.transition_to_final()

        assert success is True
        assert state_machine.state == SSEState.FINAL_SENT

    @pytest.mark.asyncio
    async def test_streaming_to_error_transition(self):
        """
        [AC-AISVC-09] Test STREAMING -> ERROR_SENT transition.
        """
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()

        success = await state_machine.transition_to_error()

        assert success is True
        assert state_machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_init_to_error_transition(self):
        """
        [AC-AISVC-09] Test INIT -> ERROR_SENT transition (error before streaming starts).
        """
        state_machine = SSEStateMachine()

        success = await state_machine.transition_to_error()

        assert success is True
        assert state_machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_cannot_transition_from_final(self):
        """
        [AC-AISVC-08] Test that no transitions are possible after FINAL_SENT.
        """
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()
        await state_machine.transition_to_final()

        assert await state_machine.transition_to_streaming() is False
        assert await state_machine.transition_to_error() is False
        assert state_machine.state == SSEState.FINAL_SENT

    @pytest.mark.asyncio
    async def test_cannot_transition_from_error(self):
        """
        [AC-AISVC-09] Test that no transitions are possible after ERROR_SENT.
        """
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()
        await state_machine.transition_to_error()

        assert await state_machine.transition_to_streaming() is False
        assert await state_machine.transition_to_final() is False
        assert state_machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_cannot_send_message_after_final(self):
        """
        [AC-AISVC-08] Test that can_send_message returns False after FINAL_SENT.
        """
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()
        await state_machine.transition_to_final()

        assert state_machine.can_send_message() is False

    @pytest.mark.asyncio
    async def test_cannot_send_message_after_error(self):
        """
        [AC-AISVC-09] Test that can_send_message returns False after ERROR_SENT.
        """
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()
        await state_machine.transition_to_error()

        assert state_machine.can_send_message() is False

    @pytest.mark.asyncio
    async def test_close_transition(self):
        """
        [AC-AISVC-08] Test that close() transitions to CLOSED state.
        """
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()
        await state_machine.transition_to_final()

        await state_machine.close()

        assert state_machine.state == SSEState.CLOSED


class TestSSEEventSequence:
    """
    [AC-AISVC-08, AC-AISVC-09] Test cases for SSE event sequence enforcement
    on the /ai/chat endpoint.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        return {"X-Tenant-Id": "tenant_001", "Accept": "text/event-stream"}

    @pytest.fixture
    def valid_body(self):
        return {
            "sessionId": "test_session",
            "currentMessage": "Hello",
            "channelType": "wechat",
        }

    def test_sse_sequence_message_then_final(self, client, valid_headers, valid_body):
        """
        [AC-AISVC-08] Test that SSE events follow: message* -> final -> close.
        """
        response = client.post("/ai/chat", json=valid_body, headers=valid_headers)

        assert response.status_code == 200
        events = _sse_event_names(response.text)
        assert "message" in events
        assert "final" in events

        # Every message event (not just the first one) must precede the final event.
        final_idx = events.index("final")
        message_indices = [i for i, name in enumerate(events) if name == "message"]
        assert all(i < final_idx for i in message_indices), (
            "final should come after message events"
        )

    def test_sse_only_one_final_event(self, client, valid_headers, valid_body):
        """
        [AC-AISVC-08] Test that there is exactly one final event.
        """
        response = client.post("/ai/chat", json=valid_body, headers=valid_headers)

        assert response.status_code == 200
        final_count = _sse_event_names(response.text).count("final")
        assert final_count == 1, f"Expected exactly 1 final event, got {final_count}"

    def test_sse_no_events_after_final(self, client, valid_headers, valid_body):
        """
        [AC-AISVC-08] Test that no message events appear after final event.
        """
        response = client.post("/ai/chat", json=valid_body, headers=valid_headers)

        assert response.status_code == 200
        events = _sse_event_names(response.text)
        # Without this check the test would pass vacuously on a stream that
        # never sends a final event at all.
        assert "final" in events, "Expected a final event in the stream"

        trailing = events[events.index("final") + 1:]
        assert "message" not in trailing, "Found message event after final event"


class TestSSEErrorHandling:
    """
    [AC-AISVC-09] Test cases for SSE error handling.
    """

    # create_error_event is called synchronously here, so these tests need no event loop.
    def test_error_event_format(self):
        """
        [AC-AISVC-09] Test error event format.
        """
        event = create_error_event(
            code="TEST_ERROR",
            message="Test error message",
            details=[{"field": "test"}],
        )

        assert event.event == "error"
        data = json.loads(event.data)
        assert data["code"] == "TEST_ERROR"
        assert data["message"] == "Test error message"
        assert data["details"] == [{"field": "test"}]

    def test_error_event_without_details(self):
        """
        [AC-AISVC-09] Test error event without details.
        """
        event = create_error_event(
            code="SIMPLE_ERROR",
            message="Simple error",
        )

        assert event.event == "error"
        data = json.loads(event.data)
        assert data["code"] == "SIMPLE_ERROR"
        assert data["message"] == "Simple error"
        assert "details" not in data

    def test_missing_tenant_id_returns_400(self):
        """
        [AC-AISVC-12] Test that missing X-Tenant-Id returns 400 error.
        """
        client = TestClient(app)
        headers = {"Accept": "text/event-stream"}
        body = {
            "sessionId": "test_session",
            "currentMessage": "Hello",
            "channelType": "wechat",
        }

        response = client.post("/ai/chat", json=body, headers=headers)

        assert response.status_code == 400
        data = response.json()
        assert data["code"] == "MISSING_TENANT_ID"


class TestSSEStateConcurrency:
    """
    [AC-AISVC-08, AC-AISVC-09] Test cases for state machine safety under
    concurrently scheduled coroutines.
    """

    @pytest.mark.asyncio
    async def test_concurrent_transitions(self):
        """
        [AC-AISVC-08] Test that concurrent transitions are handled correctly.
        """
        state_machine = SSEStateMachine()
        results = []

        async def try_transition():
            success = await state_machine.transition_to_streaming()
            results.append(success)

        await asyncio.gather(
            try_transition(),
            try_transition(),
            try_transition(),
        )

        assert sum(results) == 1, "Only one transition should succeed"
        assert state_machine.state == SSEState.STREAMING

    @pytest.mark.asyncio
    async def test_concurrent_final_transitions(self):
        """
        [AC-AISVC-08] Test that only one final transition succeeds.
        """
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()
        results = []

        async def try_final():
            success = await state_machine.transition_to_final()
            results.append(success)

        await asyncio.gather(
            try_final(),
            try_final(),
        )

        assert sum(results) == 1, "Only one final transition should succeed"
        assert state_machine.state == SSEState.FINAL_SENT


class TestSSEIntegrationWithOrchestrator:
    """
    [AC-AISVC-08, AC-AISVC-09] Integration tests for SSE with Orchestrator.
    """

    @pytest.mark.asyncio
    async def test_orchestrator_stream_with_error(self):
        """
        [AC-AISVC-09] Test that orchestrator errors are properly handled.
        """
        from app.services.orchestrator import OrchestratorService

        mock_llm = MagicMock()

        async def failing_stream(*args, **kwargs):
            # Emit one chunk, then fail mid-stream.
            yield MagicMock(delta="Hello", finish_reason=None)
            raise Exception("LLM connection lost")

        mock_llm.stream_generate = failing_stream

        orchestrator = OrchestratorService(llm_client=mock_llm)
        request = ChatRequest(
            session_id="test",
            current_message="Hi",
            channel_type=ChannelType.WECHAT,
        )

        events = []
        async for event in orchestrator.generate_stream("tenant", request):
            events.append(event)

        event_types = [e.event for e in events]
        assert "message" in event_types
        assert "error" in event_types
        # The chunk produced before the failure must be emitted before the error.
        assert event_types.index("message") < event_types.index("error")

    @pytest.mark.asyncio
    async def test_orchestrator_stream_normal_flow(self):
        """
        [AC-AISVC-08] Test normal streaming flow ends with final event.
        """
        from app.services.orchestrator import OrchestratorService

        orchestrator = OrchestratorService()
        request = ChatRequest(
            session_id="test",
            current_message="Hi",
            channel_type=ChannelType.WECHAT,
        )

        events = []
        async for event in orchestrator.generate_stream("tenant", request):
            events.append(event)

        event_types = [e.event for e in events]
        assert "message" in event_types
        assert "final" in event_types

        final_index = event_types.index("final")
        for i, t in enumerate(event_types):
            if t == "message":
                assert i < final_index, "message events should come before final"