"""
Tests for SSE event generator.

[AC-AISVC-07] Tests for message event generation with delta content.
"""

import json
from unittest.mock import AsyncMock, MagicMock

import pytest
from sse_starlette.sse import ServerSentEvent

from app.core.sse import (
    create_message_event,
    create_final_event,
    create_error_event,
    SSEStateMachine,
    SSEState,
)
from app.services.orchestrator import OrchestratorService
from app.models import ChatRequest, ChannelType


async def _collect_events(stream):
    """Drain an async iterator of SSE events into a list, preserving order."""
    return [event async for event in stream]


def _joined_deltas(message_events):
    """Concatenate the JSON ``delta`` field of each message event, in order."""
    return "".join(json.loads(event.data)["delta"] for event in message_events)


class TestSSEEventGenerator:
    """
    [AC-AISVC-07] Test cases for SSE event generation.
    """

    def test_create_message_event_format(self):
        """
        [AC-AISVC-07] Test that message event has correct format.

        Event should have:
        - event: "message"
        - data: JSON with "delta" field
        """
        event = create_message_event(delta="Hello, ")

        assert event.event == "message"
        assert event.data is not None

        data = json.loads(event.data)
        assert "delta" in data
        assert data["delta"] == "Hello, "

    def test_create_message_event_with_unicode(self):
        """
        [AC-AISVC-07] Test that message event handles unicode correctly.
        """
        event = create_message_event(delta="你好,世界!")

        assert event.event == "message"
        data = json.loads(event.data)
        assert data["delta"] == "你好,世界!"

    def test_create_message_event_with_empty_delta(self):
        """
        [AC-AISVC-07] Test that message event handles empty delta.
        """
        event = create_message_event(delta="")

        assert event.event == "message"
        data = json.loads(event.data)
        assert data["delta"] == ""

    def test_create_final_event_format(self):
        """
        [AC-AISVC-08] Test that final event has correct format.
        """
        event = create_final_event(
            reply="Complete response",
            confidence=0.85,
            should_transfer=False,
        )

        assert event.event == "final"
        data = json.loads(event.data)
        assert data["reply"] == "Complete response"
        assert data["confidence"] == 0.85
        # The payload uses camelCase keys even though the Python kwargs are snake_case.
        assert data["shouldTransfer"] is False

    def test_create_final_event_with_transfer_reason(self):
        """
        [AC-AISVC-08] Test final event with transfer reason.
        """
        event = create_final_event(
            reply="I cannot help with this",
            confidence=0.3,
            should_transfer=True,
            transfer_reason="Low confidence score",
        )

        assert event.event == "final"
        data = json.loads(event.data)
        assert data["shouldTransfer"] is True
        assert data["transferReason"] == "Low confidence score"

    def test_create_error_event_format(self):
        """
        [AC-AISVC-09] Test that error event has correct format.
        """
        event = create_error_event(
            code="GENERATION_ERROR",
            message="Failed to generate response",
        )

        assert event.event == "error"
        data = json.loads(event.data)
        assert data["code"] == "GENERATION_ERROR"
        assert data["message"] == "Failed to generate response"

    def test_create_error_event_with_details(self):
        """
        [AC-AISVC-09] Test error event with details.
        """
        event = create_error_event(
            code="VALIDATION_ERROR",
            message="Invalid input",
            details=[{"field": "message", "error": "too long"}],
        )

        assert event.event == "error"
        data = json.loads(event.data)
        assert data["details"] == [{"field": "message", "error": "too long"}]


class TestOrchestratorStreaming:
    """
    [AC-AISVC-07] Test cases for orchestrator streaming with SSE events.
    """

    @pytest.fixture
    def orchestrator(self):
        """Orchestrator built with default dependencies (no LLM client injected)."""
        return OrchestratorService()

    @pytest.fixture
    def chat_request(self):
        """Minimal chat request whose user message is "Hello"."""
        return ChatRequest(
            session_id="test_session",
            current_message="Hello",
            channel_type=ChannelType.WECHAT,
        )

    @pytest.mark.asyncio
    async def test_stream_yields_message_events(self, orchestrator, chat_request):
        """
        [AC-AISVC-07] Test that streaming yields message events with delta content.
        """
        events = await _collect_events(
            orchestrator.generate_stream("tenant_001", chat_request)
        )

        message_events = [e for e in events if e.event == "message"]
        final_events = [e for e in events if e.event == "final"]

        assert len(message_events) > 0, "Should have at least one message event"
        assert len(final_events) == 1, "Should have exactly one final event"

        for event in message_events:
            data = json.loads(event.data)
            assert "delta" in data
            assert isinstance(data["delta"], str)

    @pytest.mark.asyncio
    async def test_stream_message_events_contain_content(self, orchestrator, chat_request):
        """
        [AC-AISVC-07] Test that message events contain the expected content.
        """
        events = await _collect_events(
            orchestrator.generate_stream("tenant_001", chat_request)
        )

        message_events = [e for e in events if e.event == "message"]
        full_content = _joined_deltas(message_events)

        assert "Hello" in full_content, "Content should contain the user message"

    @pytest.mark.asyncio
    async def test_stream_event_sequence(self, orchestrator, chat_request):
        """
        [AC-AISVC-07, AC-AISVC-08] Test that events follow proper sequence.

        message* -> final -> close
        """
        events = await _collect_events(
            orchestrator.generate_stream("tenant_001", chat_request)
        )

        event_types = [e.event for e in events]
        # Assert explicitly so a missing final event fails with a clear message
        # instead of a bare ValueError from list.index().
        assert "final" in event_types, "Stream should emit a final event"
        final_index = event_types.index("final")

        message_indices = [i for i, t in enumerate(event_types) if t == "message"]
        for msg_idx in message_indices:
            assert msg_idx < final_index, "All message events should come before final"

    @pytest.mark.asyncio
    async def test_stream_with_llm_client(self, chat_request):
        """
        [AC-AISVC-07] Test streaming with mock LLM client.
        """
        mock_llm = MagicMock()

        mock_chunk1 = MagicMock()
        mock_chunk1.delta = "Hello"
        mock_chunk1.finish_reason = None

        mock_chunk2 = MagicMock()
        mock_chunk2.delta = " there!"
        mock_chunk2.finish_reason = None

        # Terminal chunk: empty delta with a finish reason. The assertion below
        # expects it NOT to produce a message event.
        mock_chunk3 = MagicMock()
        mock_chunk3.delta = ""
        mock_chunk3.finish_reason = "stop"

        async def mock_stream(*args, **kwargs):
            for chunk in [mock_chunk1, mock_chunk2, mock_chunk3]:
                yield chunk

        mock_llm.stream_generate = mock_stream

        orchestrator = OrchestratorService(llm_client=mock_llm)
        events = await _collect_events(
            orchestrator.generate_stream("tenant_001", chat_request)
        )

        message_events = [e for e in events if e.event == "message"]
        assert len(message_events) == 2, "Should have two message events"

        assert _joined_deltas(message_events) == "Hello there!"

    @pytest.mark.asyncio
    @pytest.mark.skip(
        reason="TODO: AC-AISVC-09 error-to-SSE conversion test not implemented yet"
    )
    async def test_stream_handles_error(self, orchestrator, chat_request):
        """
        [AC-AISVC-09] Test that streaming errors are converted to error events.

        Placeholder: previously an empty body that always passed. It is skipped
        so the missing coverage is visible in test reports.
        """


class TestSSEStateMachineIntegration:
    """
    [AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Integration tests for SSE state machine.
    """

    @pytest.mark.asyncio
    async def test_state_machine_prevents_events_after_final(self):
        """
        [AC-AISVC-08] Test that no events can be sent after final.
        """
        state_machine = SSEStateMachine()

        await state_machine.transition_to_streaming()
        assert state_machine.can_send_message() is True

        await state_machine.transition_to_final()
        assert state_machine.can_send_message() is False
        assert state_machine.state == SSEState.FINAL_SENT

    @pytest.mark.asyncio
    async def test_state_machine_prevents_events_after_error(self):
        """
        [AC-AISVC-09] Test that no events can be sent after error.
        """
        state_machine = SSEStateMachine()

        await state_machine.transition_to_streaming()
        await state_machine.transition_to_error()

        assert state_machine.can_send_message() is False
        assert state_machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_state_machine_allows_multiple_message_events(self):
        """
        [AC-AISVC-07] Test that multiple message events can be sent during streaming.
        """
        state_machine = SSEStateMachine()

        await state_machine.transition_to_streaming()

        # NOTE(review): this only re-checks the guard; no message is actually
        # recorded between checks. Strengthen once a "send" API is exercised here.
        for _ in range(5):
            assert state_machine.can_send_message() is True

        await state_machine.transition_to_final()
        assert state_machine.can_send_message() is False