"""
Tests for response mode switching based on Accept header.

[AC-AISVC-06] Tests for automatic switching between JSON and SSE streaming modes.
Also covers request validation, the health endpoint, the SSE state machine and
tenant middleware.
"""

import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient

from app.main import app


def _media_type(response) -> str:
    """Return the response's media type with parameters (e.g. charset) removed.

    ``"application/json; charset=utf-8"`` becomes ``"application/json"``, so
    assertions do not depend on optional Content-Type parameters.
    """
    return response.headers["content-type"].split(";", 1)[0].strip().lower()


def _sse_event_names(body: str) -> list:
    """Return the ``event`` field values of an SSE body, in stream order.

    Follows the SSE wire format: the field name ends at the first colon and a
    single space after the colon is optional. Both ``event:message`` and
    ``event: message`` therefore yield ``"message"``.
    """
    names = []
    for line in body.splitlines():  # splitlines also handles "\r\n" endings
        if line.startswith("event:"):
            value = line[len("event:"):]
            names.append(value[1:] if value.startswith(" ") else value)
    return names


def _assert_chat_json_body(data: dict) -> None:
    """Assert that a JSON chat reply contains reply, confidence and shouldTransfer."""
    for key in ("reply", "confidence", "shouldTransfer"):
        assert key in data, f"Missing {key!r} in response body: {data}"


class TestAcceptHeaderSwitching:
    """
    [AC-AISVC-06] Test cases for Accept header based response mode switching.
    """

    @pytest.fixture
    def client(self):
        """Synchronous test client bound to the application under test."""
        return TestClient(app)

    @pytest.fixture
    def valid_request_body(self):
        """A minimal chat request body that passes validation."""
        return {
            "sessionId": "test_session_001",
            "currentMessage": "Hello, how are you?",
            "channelType": "wechat",
        }

    @pytest.fixture
    def valid_headers(self):
        """Headers carrying the tenant id required by the chat endpoint."""
        return {"X-Tenant-Id": "tenant_001"}

    def test_json_response_with_default_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that default Accept header returns JSON response.
        """
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=valid_headers,
        )

        assert response.status_code == 200
        assert _media_type(response) == "application/json"
        _assert_chat_json_body(response.json())

    def test_json_response_with_application_json_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that Accept: application/json returns JSON response.
        """
        headers = {**valid_headers, "Accept": "application/json"}
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )

        assert response.status_code == 200
        assert _media_type(response) == "application/json"
        _assert_chat_json_body(response.json())

    def test_sse_response_with_text_event_stream_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that Accept: text/event-stream returns SSE response.
        """
        headers = {**valid_headers, "Accept": "text/event-stream"}
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )

        assert response.status_code == 200
        assert _media_type(response) == "text/event-stream"

        content = response.text
        events = _sse_event_names(content)
        assert "message" in events, f"Expected message event in: {content[:500]}"
        assert "final" in events, f"Expected final event in: {content[:500]}"

    def test_sse_response_event_sequence(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-07, AC-AISVC-08] Test that SSE events follow proper sequence.
        message* -> final -> close
        """
        headers = {**valid_headers, "Accept": "text/event-stream"}
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )

        assert response.status_code == 200

        content = response.text
        events = _sse_event_names(content)
        assert "message" in events, f"Expected message event in: {content[:500]}"
        assert "final" in events, f"Expected final event in: {content[:500]}"

        # Every message must precede the final event, so compare the LAST
        # message against the FIRST final rather than only the first message.
        first_final_idx = events.index("final")
        last_message_idx = len(events) - 1 - events[::-1].index("message")
        assert last_message_idx < first_final_idx, (
            f"final event should come after message events, got sequence: {events}"
        )

    def test_missing_tenant_id_returns_400(
        self, client: TestClient, valid_request_body: dict
    ):
        """
        [AC-AISVC-12] Test that missing X-Tenant-Id returns 400 error.
        """
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
        )

        assert response.status_code == 400
        data = response.json()
        assert data["code"] == "MISSING_TENANT_ID"
        assert "message" in data

    def test_invalid_channel_type_returns_400(
        self, client: TestClient, valid_headers: dict
    ):
        """
        [AC-AISVC-03] Test that invalid channel type returns 400 error.
        """
        invalid_body = {
            "sessionId": "test_session_001",
            "currentMessage": "Hello",
            "channelType": "invalid_channel",
        }
        response = client.post(
            "/ai/chat",
            json=invalid_body,
            headers=valid_headers,
        )

        assert response.status_code == 400

    def test_missing_required_fields_returns_400(
        self, client: TestClient, valid_headers: dict
    ):
        """
        [AC-AISVC-03] Test that missing required fields return 400 error.
        """
        incomplete_body = {
            "sessionId": "test_session_001",
        }
        response = client.post(
            "/ai/chat",
            json=incomplete_body,
            headers=valid_headers,
        )

        assert response.status_code == 400


class TestHealthEndpoint:
    """
    [AC-AISVC-20] Test cases for health check endpoint.
    """

    @pytest.fixture
    def client(self):
        """Synchronous test client bound to the application under test."""
        return TestClient(app)

    def test_health_check_returns_200(self, client: TestClient):
        """
        [AC-AISVC-20] Test that health check returns 200 with status.
        """
        response = client.get("/ai/health")

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"


class TestSSEStateMachine:
    """
    [AC-AISVC-08, AC-AISVC-09] Test cases for SSE state machine.
    """

    @pytest.mark.asyncio
    async def test_state_transitions(self):
        """INIT -> STREAMING -> FINAL_SENT -> CLOSED, with message gating."""
        from app.core.sse import SSEState, SSEStateMachine

        state_machine = SSEStateMachine()
        assert state_machine.state == SSEState.INIT

        success = await state_machine.transition_to_streaming()
        assert success is True
        assert state_machine.state == SSEState.STREAMING
        assert state_machine.can_send_message() is True

        success = await state_machine.transition_to_final()
        assert success is True
        assert state_machine.state == SSEState.FINAL_SENT
        assert state_machine.can_send_message() is False

        await state_machine.close()
        assert state_machine.state == SSEState.CLOSED

    @pytest.mark.asyncio
    async def test_error_transition_from_streaming(self):
        """STREAMING -> ERROR_SENT is an allowed transition."""
        from app.core.sse import SSEState, SSEStateMachine

        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()

        success = await state_machine.transition_to_error()
        assert success is True
        assert state_machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_cannot_transition_to_final_from_init(self):
        """INIT -> FINAL_SENT is rejected (the call reports failure)."""
        from app.core.sse import SSEStateMachine

        state_machine = SSEStateMachine()
        success = await state_machine.transition_to_final()
        assert success is False


class TestMiddleware:
    """
    [AC-AISVC-10, AC-AISVC-12] Test cases for middleware.
    """

    @pytest.fixture
    def client(self):
        """Synchronous test client bound to the application under test."""
        return TestClient(app)

    def test_tenant_context_extraction(
        self, client: TestClient
    ):
        """
        [AC-AISVC-10] Test that X-Tenant-Id is properly extracted and used.
        """
        headers = {"X-Tenant-Id": "tenant_test_123"}
        body = {
            "sessionId": "session_001",
            "currentMessage": "Test message",
            "channelType": "wechat",
        }

        response = client.post("/ai/chat", json=body, headers=headers)
        assert response.status_code == 200

    def test_health_endpoint_bypasses_tenant_check(
        self, client: TestClient
    ):
        """
        Test that health endpoint doesn't require X-Tenant-Id.
        """
        response = client.get("/ai/health")
        assert response.status_code == 200