"""
Tests for response mode switching based on Accept header.

[AC-AISVC-06] Tests for automatic switching between JSON and SSE streaming modes.
"""

import pytest
|
||
|
|
from fastapi.testclient import TestClient
|
||
|
|
from httpx import AsyncClient
|
||
|
|
|
||
|
|
from app.main import app
|
||
|
|
|
||
|
|
|
||
|
|
class TestAcceptHeaderSwitching:
    """
    [AC-AISVC-06] Test cases for Accept header based response mode switching.

    The same ``POST /ai/chat`` request is sent with different ``Accept``
    headers. The endpoint is expected to answer with a JSON body when no
    ``Accept`` header is set or when it is ``application/json``. It is
    expected to answer with a Server-Sent Events stream for
    ``text/event-stream``. Validation failures (missing tenant header,
    bad channel type, missing fields) are expected to yield 400.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def valid_request_body(self):
        return {
            "sessionId": "test_session_001",
            "currentMessage": "Hello, how are you?",
            "channelType": "wechat",
        }

    @pytest.fixture
    def valid_headers(self):
        return {"X-Tenant-Id": "tenant_001"}

    @staticmethod
    def _sse_event_names(content: str) -> list:
        """Return the values of every ``event:`` field in an SSE payload, in order.

        Follows the SSE field rules: a single space after the colon is
        optional, so ``event:final`` and ``event: final`` both yield
        ``"final"``. ``str.splitlines`` accepts the ``\\n``, ``\\r\\n`` and
        ``\\r`` line terminators that SSE allows. Only lines that start with
        ``event:`` are considered, so text inside ``data:`` lines cannot
        produce false matches.
        """
        prefix = "event:"
        names = []
        for line in content.splitlines():
            if not line.startswith(prefix):
                continue
            value = line[len(prefix):]
            # SSE strips exactly one leading space from a field value.
            names.append(value[1:] if value.startswith(" ") else value)
        return names

    @staticmethod
    def _assert_json_chat_response(response) -> None:
        """Assert *response* is a 200 JSON chat reply containing the core fields."""
        assert response.status_code == 200
        # Compare the media type only, so an added "; charset=..." parameter
        # does not break the test.
        media_type = response.headers["content-type"].split(";")[0].strip()
        assert media_type == "application/json"

        data = response.json()
        for field in ("reply", "confidence", "shouldTransfer"):
            assert field in data, f"Missing {field!r} in JSON response: {data}"

    def test_json_response_with_default_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that default Accept header returns JSON response.
        """
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=valid_headers,
        )

        self._assert_json_chat_response(response)

    def test_json_response_with_application_json_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that Accept: application/json returns JSON response.
        """
        headers = {**valid_headers, "Accept": "application/json"}

        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )

        self._assert_json_chat_response(response)

    def test_sse_response_with_text_event_stream_accept(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-06] Test that Accept: text/event-stream returns SSE response.
        """
        headers = {**valid_headers, "Accept": "text/event-stream"}

        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )

        assert response.status_code == 200
        assert "text/event-stream" in response.headers["content-type"]

        content = response.text
        events = self._sse_event_names(content)
        assert "message" in events, f"Expected message event in: {content[:500]}"
        assert "final" in events, f"Expected final event in: {content[:500]}"

    def test_sse_response_event_sequence(
        self, client: TestClient, valid_request_body: dict, valid_headers: dict
    ):
        """
        [AC-AISVC-07, AC-AISVC-08] Test that SSE events follow proper sequence.

        message* -> final -> close

        At least one ``message`` event must be present, exactly one ``final``
        event must be sent, and no ``message`` event may follow it.
        """
        headers = {**valid_headers, "Accept": "text/event-stream"}

        response = client.post(
            "/ai/chat",
            json=valid_request_body,
            headers=headers,
        )

        assert response.status_code == 200

        content = response.text
        events = self._sse_event_names(content)

        assert "message" in events, f"Expected message event in: {content[:500]}"
        assert events.count("final") == 1, (
            f"Expected exactly one final event, got {events}: {content[:500]}"
        )

        final_idx = events.index("final")
        # Every message must precede the final event, not just the first one.
        assert "message" not in events[final_idx + 1:], (
            f"final event should come after all message events, got {events}"
        )

    def test_missing_tenant_id_returns_400(
        self, client: TestClient, valid_request_body: dict
    ):
        """
        [AC-AISVC-12] Test that missing X-Tenant-Id returns 400 error.
        """
        response = client.post(
            "/ai/chat",
            json=valid_request_body,
        )

        assert response.status_code == 400

        data = response.json()
        assert data["code"] == "MISSING_TENANT_ID"
        assert "message" in data

    def test_invalid_channel_type_returns_400(
        self, client: TestClient, valid_headers: dict
    ):
        """
        [AC-AISVC-03] Test that invalid channel type returns 400 error.
        """
        invalid_body = {
            "sessionId": "test_session_001",
            "currentMessage": "Hello",
            "channelType": "invalid_channel",
        }

        response = client.post(
            "/ai/chat",
            json=invalid_body,
            headers=valid_headers,
        )

        assert response.status_code == 400

    def test_missing_required_fields_returns_400(
        self, client: TestClient, valid_headers: dict
    ):
        """
        [AC-AISVC-03] Test that missing required fields return 400 error.
        """
        incomplete_body = {
            "sessionId": "test_session_001",
        }

        response = client.post(
            "/ai/chat",
            json=incomplete_body,
            headers=valid_headers,
        )

        assert response.status_code == 400


class TestHealthEndpoint:
    """
    [AC-AISVC-20] Test cases for health check endpoint.

    ``GET /ai/health`` is expected to answer 200 with a JSON body whose
    ``status`` field equals ``"healthy"``.
    """

    @pytest.fixture
    def client(self):
        test_client = TestClient(app)
        return test_client

    def test_health_check_returns_200(self, client: TestClient):
        """
        [AC-AISVC-20] Test that health check returns 200 with status.
        """
        health = client.get("/ai/health")

        assert health.status_code == 200
        payload = health.json()
        assert payload["status"] == "healthy"


class TestSSEStateMachine:
    """
    [AC-AISVC-08, AC-AISVC-09] Test cases for SSE state machine.

    Drives ``app.core.sse.SSEStateMachine`` directly (no HTTP) through:
    INIT -> STREAMING -> FINAL_SENT -> CLOSED, STREAMING -> ERROR_SENT, and
    an attempted INIT -> final transition that must be refused.
    """

    @pytest.mark.asyncio
    async def test_state_transitions(self):
        from app.core.sse import SSEState, SSEStateMachine

        machine = SSEStateMachine()

        # A fresh machine starts in INIT.
        assert machine.state == SSEState.INIT

        # INIT -> STREAMING: message events become sendable.
        started = await machine.transition_to_streaming()
        assert started is True
        assert machine.state == SSEState.STREAMING
        assert machine.can_send_message() is True

        # STREAMING -> FINAL_SENT: message events are no longer sendable.
        finished = await machine.transition_to_final()
        assert finished is True
        assert machine.state == SSEState.FINAL_SENT
        assert machine.can_send_message() is False

        # FINAL_SENT -> CLOSED.
        await machine.close()
        assert machine.state == SSEState.CLOSED

    @pytest.mark.asyncio
    async def test_error_transition_from_streaming(self):
        from app.core.sse import SSEState, SSEStateMachine

        machine = SSEStateMachine()
        await machine.transition_to_streaming()

        # STREAMING -> ERROR_SENT is an allowed transition.
        errored = await machine.transition_to_error()
        assert errored is True
        assert machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_cannot_transition_to_final_from_init(self):
        from app.core.sse import SSEStateMachine

        machine = SSEStateMachine()

        # Jumping straight from INIT to final must be refused via a False return.
        finished = await machine.transition_to_final()
        assert finished is False


class TestMiddleware:
    """
    [AC-AISVC-10, AC-AISVC-12] Test cases for middleware.

    Checks that a chat request carrying ``X-Tenant-Id`` is accepted. Also
    checks that ``GET /ai/health`` succeeds without that header.
    """

    @pytest.fixture
    def client(self):
        test_client = TestClient(app)
        return test_client

    def test_tenant_context_extraction(
        self, client: TestClient
    ):
        """
        [AC-AISVC-10] Test that X-Tenant-Id is properly extracted and used.
        """
        # NOTE(review): only the 200 status is asserted here; the test does not
        # observe which tenant id the request was processed under.
        tenant_headers = {"X-Tenant-Id": "tenant_test_123"}
        payload = {
            "sessionId": "session_001",
            "currentMessage": "Test message",
            "channelType": "wechat",
        }

        result = client.post("/ai/chat", json=payload, headers=tenant_headers)

        assert result.status_code == 200

    def test_health_endpoint_bypasses_tenant_check(
        self, client: TestClient
    ):
        """
        Test that health endpoint doesn't require X-Tenant-Id.
        """
        health = client.get("/ai/health")
        assert health.status_code == 200