# ai-robot-core/ai-service/tests/test_accept_switching.py
# (repository-viewer metadata: 286 lines, 8.1 KiB, Python; "Raw | Normal View | History")

"""
Tests for response mode switching based on Accept header.
[AC-AISVC-06] Tests for automatic switching between JSON and SSE streaming modes.
"""
import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient
from app.main import app
class TestAcceptHeaderSwitching:
    """
    [AC-AISVC-06] Test cases for Accept header based response mode switching.

    POST /ai/chat is expected to answer with a JSON body by default or for
    ``Accept: application/json``, and with a Server-Sent Events stream for
    ``Accept: text/event-stream``.
    """

    @pytest.fixture
    def client(self):
        """Synchronous test client bound to the FastAPI application."""
        return TestClient(app)

    @pytest.fixture
    def valid_request_body(self):
        """Minimal well-formed chat request payload."""
        return {
            "sessionId": "test_session_001",
            "currentMessage": "Hello, how are you?",
            "channelType": "wechat",
        }

    @pytest.fixture
    def valid_headers(self):
        """Headers carrying the X-Tenant-Id that /ai/chat requires."""
        return {"X-Tenant-Id": "tenant_001"}

    @staticmethod
    def _sse_event_names(content):
        """
        Return the values of the SSE ``event:`` fields in ``content``, in order.

        In the SSE wire format the space after the field colon is optional, so
        both ``event:final`` and ``event: final`` yield ``"final"``. Matching
        whole lines, rather than searching for substrings, avoids false hits
        inside ``data:`` payloads or on longer event names.
        ``str.splitlines`` accepts the CRLF / LF / CR line endings SSE allows.
        """
        prefix = "event:"
        return [
            line[len(prefix):].strip()
            for line in content.splitlines()
            if line.startswith(prefix)
        ]

    @staticmethod
    def _assert_json_chat_reply(response):
        """Assert ``response`` is a 200 JSON chat reply with the core fields."""
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/json"
        data = response.json()
        for key in ("reply", "confidence", "shouldTransfer"):
            assert key in data, f"missing {key!r} in JSON reply: {data}"

    def test_json_response_with_default_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that default Accept header returns JSON response.
        """
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=valid_headers,
        )
        self._assert_json_chat_reply(response)

    def test_json_response_with_application_json_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that Accept: application/json returns JSON response.
        """
        headers = {**valid_headers, "Accept": "application/json"}
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )
        self._assert_json_chat_reply(response)

    def test_sse_response_with_text_event_stream_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that Accept: text/event-stream returns SSE response.
        """
        headers = {**valid_headers, "Accept": "text/event-stream"}
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )
        assert response.status_code == 200
        # Containment rather than equality: the header may carry parameters
        # (e.g. a charset) after the media type.
        assert "text/event-stream" in response.headers["content-type"]
        events = self._sse_event_names(response.text)
        assert "message" in events, f"Expected message event, got: {events}"
        assert "final" in events, f"Expected final event, got: {events}"

    def test_sse_response_event_sequence(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-07, AC-AISVC-08] Test that SSE events follow proper sequence.
        message* -> final -> close
        """
        headers = {**valid_headers, "Accept": "text/event-stream"}
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )
        assert response.status_code == 200
        content = response.text
        events = self._sse_event_names(content)
        assert "message" in events, f"Expected message event in: {content[:500]}"
        assert "final" in events, f"Expected final event in: {content[:500]}"
        # Every message must precede the final event. Comparing only the first
        # occurrence of each event would miss a message emitted after final.
        first_final = events.index("final")
        assert "message" not in events[first_final + 1:], (
            f"final event should come after message events; order: {events}"
        )

    def test_missing_tenant_id_returns_400(
        self, client: TestClient, valid_request_body: dict
    ):
        """
        [AC-AISVC-12] Test that missing X-Tenant-Id returns 400 error.
        """
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
        )
        assert response.status_code == 400
        data = response.json()
        assert data["code"] == "MISSING_TENANT_ID"
        assert "message" in data

    def test_invalid_channel_type_returns_400(
        self, client: TestClient, valid_headers: dict
    ):
        """
        [AC-AISVC-03] Test that invalid channel type returns 400 error.
        """
        invalid_body = {
            "sessionId": "test_session_001",
            "currentMessage": "Hello",
            "channelType": "invalid_channel",
        }
        response = client.post(
            "/ai/chat",
            json=invalid_body,
            headers=valid_headers,
        )
        assert response.status_code == 400

    def test_missing_required_fields_returns_400(
        self, client: TestClient, valid_headers: dict
    ):
        """
        [AC-AISVC-03] Test that missing required fields return 400 error.
        """
        incomplete_body = {
            "sessionId": "test_session_001",
        }
        response = client.post(
            "/ai/chat",
            json=incomplete_body,
            headers=valid_headers,
        )
        assert response.status_code == 400
class TestHealthEndpoint:
    """
    [AC-AISVC-20] Health-check endpoint tests.
    """

    @pytest.fixture
    def client(self):
        """TestClient wrapping the FastAPI app under test."""
        http = TestClient(app)
        return http

    def test_health_check_returns_200(self, client: TestClient):
        """
        [AC-AISVC-20] GET /ai/health answers 200 and reports status "healthy".
        """
        resp = client.get("/ai/health")
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["status"] == "healthy"
class TestSSEStateMachine:
    """
    [AC-AISVC-08, AC-AISVC-09] SSE state-machine transition tests.
    """

    @pytest.mark.asyncio
    async def test_state_transitions(self):
        """Walk INIT -> STREAMING -> FINAL_SENT -> CLOSED and check message gating."""
        from app.core.sse import SSEState, SSEStateMachine

        machine = SSEStateMachine()
        assert machine.state == SSEState.INIT

        # Entering STREAMING succeeds and can_send_message() reports True.
        assert (await machine.transition_to_streaming()) is True
        assert machine.state == SSEState.STREAMING
        assert machine.can_send_message() is True

        # Entering FINAL_SENT succeeds and can_send_message() reports False.
        assert (await machine.transition_to_final()) is True
        assert machine.state == SSEState.FINAL_SENT
        assert machine.can_send_message() is False

        await machine.close()
        assert machine.state == SSEState.CLOSED

    @pytest.mark.asyncio
    async def test_error_transition_from_streaming(self):
        """From STREAMING, the error transition succeeds and lands in ERROR_SENT."""
        from app.core.sse import SSEState, SSEStateMachine

        machine = SSEStateMachine()
        await machine.transition_to_streaming()
        assert (await machine.transition_to_error()) is True
        assert machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_cannot_transition_to_final_from_init(self):
        """Jumping straight from INIT to final is rejected (returns False)."""
        from app.core.sse import SSEStateMachine

        machine = SSEStateMachine()
        assert (await machine.transition_to_final()) is False
class TestMiddleware:
    """
    [AC-AISVC-10, AC-AISVC-12] Middleware tests covering tenant-header handling.
    """

    @pytest.fixture
    def client(self):
        """TestClient wrapping the FastAPI app under test."""
        api_client = TestClient(app)
        return api_client

    def test_tenant_context_extraction(self, client: TestClient):
        """
        [AC-AISVC-10] A /ai/chat request carrying X-Tenant-Id is accepted (200).
        """
        chat_payload = {
            "sessionId": "session_001",
            "currentMessage": "Test message",
            "channelType": "wechat",
        }
        tenant_headers = {"X-Tenant-Id": "tenant_test_123"}
        reply = client.post("/ai/chat", json=chat_payload, headers=tenant_headers)
        assert reply.status_code == 200

    def test_health_endpoint_bypasses_tenant_check(self, client: TestClient):
        """
        GET /ai/health succeeds even though no X-Tenant-Id header is sent.
        """
        assert client.get("/ai/health").status_code == 200