# Source: ai-robot-core/ai-service/tests/test_contract.py
# (viewer metadata: 454 lines, 14 KiB, Python)

"""
Contract validation tests for AI Service.
[AC-AISVC-02] Verify response fields match openapi.provider.yaml contract.
OpenAPI ChatResponse schema:
- reply: string (required)
- confidence: number (double, required)
- shouldTransfer: boolean (required)
- transferReason: string (optional)
- metadata: object (optional)
"""
import json
import pytest
from pydantic import ValidationError
from app.models import (
ChatResponse,
ChatRequest,
ChatMessage,
Role,
ChannelType,
ErrorResponse,
SSEFinalEvent,
SSEErrorEvent,
)
class TestChatResponseContract:
    """[AC-AISVC-02] ChatResponse must conform to the OpenAPI contract."""

    def test_required_fields_present(self):
        """[AC-AISVC-02] reply, confidence and shouldTransfer are stored as given."""
        resp = ChatResponse(reply="Test reply", confidence=0.85, should_transfer=False)

        assert resp.reply == "Test reply"
        assert resp.confidence == 0.85
        assert resp.should_transfer is False

    def test_json_serialization_uses_camel_case(self):
        """
        [AC-AISVC-02] Alias-based JSON output uses camelCase keys.

        shouldTransfer / transferReason must appear; their snake_case
        spellings must not.
        """
        resp = ChatResponse(
            reply="Test reply",
            confidence=0.85,
            should_transfer=True,
            transfer_reason="Low confidence",
            metadata={"key": "value"},
        )
        payload = json.loads(resp.model_dump_json(by_alias=True))

        key_pairs = (
            ("shouldTransfer", "should_transfer"),
            ("transferReason", "transfer_reason"),
        )
        for camel_key, snake_key in key_pairs:
            assert camel_key in payload
            assert snake_key not in payload

    def test_json_output_matches_contract_structure(self):
        """
        [AC-AISVC-02] Every contract field shows up in the JSON output.

        An optional field explicitly set to None is emitted as JSON null.
        """
        resp = ChatResponse(
            reply="AI response content",
            confidence=0.92,
            should_transfer=False,
            transfer_reason=None,
            metadata={"session_id": "test-123"},
        )
        raw_json = resp.model_dump_json(by_alias=True)
        payload = json.loads(raw_json)

        contract_keys = (
            "reply",
            "confidence",
            "shouldTransfer",
            "transferReason",
            "metadata",
        )
        for key in contract_keys:
            assert key in payload

        assert payload["reply"] == "AI response content"
        assert payload["confidence"] == 0.92
        assert payload["shouldTransfer"] is False
        assert payload["transferReason"] is None
        assert payload["metadata"]["session_id"] == "test-123"

    def test_optional_fields_can_be_omitted(self):
        """[AC-AISVC-02] A response can be built without transferReason or metadata."""
        resp = ChatResponse(
            reply="Reply without optional fields",
            confidence=0.5,
            should_transfer=True,
        )
        payload = json.loads(resp.model_dump_json(by_alias=True))

        assert payload["reply"] == "Reply without optional fields"
        assert payload["confidence"] == 0.5
        assert payload["shouldTransfer"] is True
        # .get() is used so the check passes whether the key is null or absent.
        assert payload.get("transferReason") is None
        assert payload.get("metadata") is None

    def test_confidence_must_be_between_0_and_1(self):
        """[AC-AISVC-02] Both ends of the [0.0, 1.0] confidence range are accepted."""
        for bound in (0.0, 1.0):
            resp = ChatResponse(
                reply="Valid",
                confidence=bound,
                should_transfer=False,
            )
            assert resp.confidence == bound

    def test_confidence_rejects_negative(self):
        """[AC-AISVC-02] A negative confidence fails validation."""
        bad_fields = {
            "reply": "Invalid",
            "confidence": -0.1,
            "should_transfer": False,
        }
        with pytest.raises(ValidationError):
            ChatResponse(**bad_fields)

    def test_confidence_rejects_above_1(self):
        """[AC-AISVC-02] A confidence greater than 1.0 fails validation."""
        bad_fields = {
            "reply": "Invalid",
            "confidence": 1.5,
            "should_transfer": False,
        }
        with pytest.raises(ValidationError):
            ChatResponse(**bad_fields)

    def test_reply_is_required(self):
        """[AC-AISVC-02] Leaving out reply fails validation."""
        without_reply = {"confidence": 0.5, "should_transfer": False}
        with pytest.raises(ValidationError):
            ChatResponse(**without_reply)

    def test_confidence_is_required(self):
        """[AC-AISVC-02] Leaving out confidence fails validation."""
        without_confidence = {"reply": "Test", "should_transfer": False}
        with pytest.raises(ValidationError):
            ChatResponse(**without_confidence)

    def test_should_transfer_is_required(self):
        """[AC-AISVC-02] Leaving out shouldTransfer fails validation."""
        without_should_transfer = {"reply": "Test", "confidence": 0.5}
        with pytest.raises(ValidationError):
            ChatResponse(**without_should_transfer)

    def test_transfer_reason_accepts_string(self):
        """[AC-AISVC-02] A non-ASCII transferReason string round-trips through JSON."""
        reason = "检索结果不足,建议转人工"
        resp = ChatResponse(
            reply="Test",
            confidence=0.3,
            should_transfer=True,
            transfer_reason=reason,
        )
        payload = json.loads(resp.model_dump_json(by_alias=True))

        assert payload["transferReason"] == reason

    def test_metadata_accepts_any_object(self):
        """[AC-AISVC-02] metadata keeps arbitrary, nested key/value content."""
        diagnostics = {
            "retrieval_hits": 5,
            "llm_model": "gpt-4o-mini",
        }
        meta = {
            "session_id": "session-123",
            "channel_type": "wechat",
            "diagnostics": diagnostics,
        }
        resp = ChatResponse(
            reply="Test",
            confidence=0.8,
            should_transfer=False,
            metadata=meta,
        )
        payload = json.loads(resp.model_dump_json(by_alias=True))
        dumped_meta = payload["metadata"]

        assert dumped_meta["session_id"] == "session-123"
        assert dumped_meta["diagnostics"]["retrieval_hits"] == 5
class TestChatRequestContract:
    """
    [AC-AISVC-02] Test ChatRequest matches OpenAPI contract.

    Covers the required fields (sessionId, currentMessage, channelType),
    camelCase JSON input, the history list and the accepted channelType
    values.
    """

    def test_required_fields(self):
        """
        [AC-AISVC-02] ChatRequest required fields: sessionId, currentMessage, channelType.
        """
        request = ChatRequest(
            session_id="session-123",
            current_message="Hello",
            channel_type=ChannelType.WECHAT,
        )
        assert request.session_id == "session-123"
        assert request.current_message == "Hello"
        assert request.channel_type == ChannelType.WECHAT

    def test_json_input_uses_camel_case(self):
        """
        [AC-AISVC-02] JSON input should accept camelCase field names.

        All three required fields are supplied in camelCase and each one is
        checked on the validated model.
        """
        json_data = {
            "sessionId": "session-456",
            "currentMessage": "What is the price?",
            "channelType": "wechat",
        }
        request = ChatRequest.model_validate(json_data)
        assert request.session_id == "session-456"
        assert request.current_message == "What is the price?"
        # channelType is the third camelCase key in the payload; verify it was
        # mapped as well (same .value check as test_channel_type_enum_values).
        assert request.channel_type.value == "wechat"

    def test_optional_history_field(self):
        """
        [AC-AISVC-02] history is optional.

        When supplied, the messages are kept in order with their roles.
        """
        request = ChatRequest(
            session_id="session-789",
            current_message="Follow-up question",
            channel_type=ChannelType.DOUYIN,
            history=[
                ChatMessage(role=Role.USER, content="Previous question"),
                ChatMessage(role=Role.ASSISTANT, content="Previous answer"),
            ],
        )
        assert len(request.history) == 2
        assert request.history[0].role == Role.USER

    def test_channel_type_enum_values(self):
        """
        [AC-AISVC-02] channelType must be one of: wechat, douyin, jd.

        Each raw string is passed as channel_type and must come back
        unchanged through the parsed field's ``.value``.
        """
        valid_types = ["wechat", "douyin", "jd"]
        for channel in valid_types:
            request = ChatRequest(
                session_id="test",
                current_message="Test",
                channel_type=channel,
            )
            assert request.channel_type.value == channel
class TestErrorResponseContract:
    """[AC-AISVC-02] ErrorResponse must conform to the OpenAPI contract."""

    def test_required_fields(self):
        """[AC-AISVC-02] An ErrorResponse can be built from code and message only."""
        error = ErrorResponse(code="INVALID_REQUEST", message="Missing required field")

        assert error.code == "INVALID_REQUEST"
        assert error.message == "Missing required field"

    def test_optional_details(self):
        """[AC-AISVC-02] details accepts a list of error entries."""
        field_errors = [
            {"field": "sessionId", "error": "required"},
            {"field": "channelType", "error": "invalid value"},
        ]
        error = ErrorResponse(
            code="VALIDATION_ERROR",
            message="Multiple validation errors",
            details=field_errors,
        )

        assert len(error.details) == 2
class TestSSEFinalEventContract:
    """[AC-AISVC-02] The SSE final event must mirror the ChatResponse structure."""

    def test_sse_final_event_structure(self):
        """[AC-AISVC-02] The serialized final event exposes the ChatResponse keys."""
        final_event = SSEFinalEvent(
            reply="Complete AI response",
            confidence=0.88,
            should_transfer=False,
            transfer_reason=None,
            metadata={"tokens": 150},
        )
        raw_json = final_event.model_dump_json(by_alias=True)
        payload = json.loads(raw_json)

        for key in ("reply", "confidence", "shouldTransfer"):
            assert key in payload
        assert payload["shouldTransfer"] is False

    def test_sse_final_event_matches_chat_response(self):
        """
        [AC-AISVC-02] Identical inputs serialize identically for both models.

        The same field values are fed to ChatResponse and SSEFinalEvent and
        the decoded JSON payloads are compared for equality.
        """
        shared_fields = {
            "reply": "Test reply",
            "confidence": 0.75,
            "should_transfer": True,
            "transfer_reason": "Low confidence",
            "metadata": {"test": "value"},
        }

        chat_payload = json.loads(
            ChatResponse(**shared_fields).model_dump_json(by_alias=True)
        )
        sse_payload = json.loads(
            SSEFinalEvent(**shared_fields).model_dump_json(by_alias=True)
        )

        assert chat_payload == sse_payload
class TestSSEErrorEventContract:
    """
    [AC-AISVC-02] Test SSE error event matches OpenAPI ErrorResponse structure.
    """

    def test_sse_error_event_structure(self):
        """
        [AC-AISVC-02] SSE error event must have same structure as ErrorResponse.

        The event is serialized to JSON and the code, message and details
        fields are checked on the decoded payload.
        """
        event = SSEErrorEvent(
            code="GENERATION_ERROR",
            message="LLM service unavailable",
            details=[{"reason": "timeout"}],
        )
        # Serialize by alias, as the other contract tests in this module do,
        # so the assertions run against the wire-format field names.
        data = json.loads(event.model_dump_json(by_alias=True))
        assert data["code"] == "GENERATION_ERROR"
        assert data["message"] == "LLM service unavailable"
        assert len(data["details"]) == 1
class TestEndToEndContractValidation:
    """
    [AC-AISVC-02] Contract checks on responses produced by an
    OrchestratorService built with RAG disabled.
    """

    @pytest.mark.asyncio
    async def test_orchestrator_response_matches_contract(self):
        """[AC-AISVC-02] generate() yields a ChatResponse with correctly typed fields."""
        from app.services.orchestrator import OrchestratorConfig, OrchestratorService

        service = OrchestratorService(config=OrchestratorConfig(enable_rag=False))
        chat_request = ChatRequest(
            session_id="contract-test-session",
            current_message="Test message",
            channel_type=ChannelType.WECHAT,
        )

        result = await service.generate(tenant_id="tenant-1", request=chat_request)

        assert isinstance(result, ChatResponse)
        assert isinstance(result.reply, str)
        assert isinstance(result.confidence, float)
        assert 0.0 <= result.confidence <= 1.0
        assert isinstance(result.should_transfer, bool)

    @pytest.mark.asyncio
    async def test_orchestrator_response_json_serializable(self):
        """[AC-AISVC-02] The generate() result serializes to camelCase JSON."""
        from app.services.orchestrator import OrchestratorConfig, OrchestratorService

        service = OrchestratorService(config=OrchestratorConfig(enable_rag=False))
        chat_request = ChatRequest(
            session_id="json-test-session",
            current_message="JSON serialization test",
            channel_type=ChannelType.JD,
        )

        result = await service.generate(tenant_id="tenant-1", request=chat_request)
        payload = json.loads(result.model_dump_json(by_alias=True))

        for key in ("reply", "confidence", "shouldTransfer"):
            assert key in payload
        assert "should_transfer" not in payload