feat(ai-service): implement confidence calculation for T3.3 [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19]

- Add ConfidenceCalculator class for confidence scoring
- Implement retrieval insufficiency detection (hit count, score threshold, evidence tokens)
- Implement confidence calculation based on retrieval scores
- Implement shouldTransfer logic with configurable threshold
- Add transferReason for low confidence scenarios
- Add comprehensive unit tests (19 test cases)
- Update config with confidence-related settings
This commit is contained in:
MerCry 2026-02-24 13:31:42 +08:00
parent c9f2c1eb3a
commit 66fa2d2677
5 changed files with 952 additions and 9 deletions

View File

@ -1,6 +1,6 @@
"""
Chat endpoint for AI Service.
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06] Main chat endpoint with streaming/non-streaming modes.
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-08, AC-AISVC-09] Main chat endpoint with streaming/non-streaming modes.
"""
import logging
@ -11,7 +11,7 @@ from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from app.core.middleware import get_response_mode, is_sse_request
from app.core.sse import create_error_event
from app.core.sse import SSEStateMachine, create_error_event
from app.core.tenant import get_tenant_id
from app.models import ChatRequest, ChatResponse, ErrorResponse
from app.services.orchestrator import OrchestratorService, get_orchestrator_service
@ -109,19 +109,58 @@ async def _handle_streaming_request(
) -> EventSourceResponse:
"""
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Handle SSE streaming request.
Yields message events followed by final or error event.
SSE Event Sequence (per design.md Section 6.2):
- message* (0 or more) -> final (exactly 1) -> close
- OR message* (0 or more) -> error (exactly 1) -> close
State machine ensures:
- No events after final/error
- Only one final OR one error event
- Proper connection close
"""
logger.info(f"[AC-AISVC-06] Processing SSE request for tenant={tenant_id}")
state_machine = SSEStateMachine()
async def event_generator():
"""
[AC-AISVC-08, AC-AISVC-09] Event generator with state machine enforcement.
Ensures proper event sequence and error handling.
"""
await state_machine.transition_to_streaming()
try:
async for event in orchestrator.generate_stream(tenant_id, chat_request):
if not state_machine.can_send_message():
logger.warning("[AC-AISVC-08] Received event after state machine closed, ignoring")
break
if event.event == "final":
if await state_machine.transition_to_final():
logger.info("[AC-AISVC-08] Sending final event and closing stream")
yield event
break
elif event.event == "error":
if await state_machine.transition_to_error():
logger.info("[AC-AISVC-09] Sending error event and closing stream")
yield event
break
elif event.event == "message":
yield event
except Exception as e:
logger.error(f"[AC-AISVC-09] Streaming error: {e}")
if await state_machine.transition_to_error():
yield create_error_event(
code="STREAMING_ERROR",
message=str(e),
)
finally:
await state_machine.close()
logger.debug("[AC-AISVC-08] SSE connection closed")
return EventSourceResponse(event_generator(), ping=15)

View File

@ -45,7 +45,9 @@ class Settings(BaseSettings):
rag_min_hits: int = 1
rag_max_evidence_tokens: int = 2000
confidence_threshold_low: float = 0.5
confidence_low_threshold: float = 0.5
confidence_high_threshold: float = 0.8
confidence_insufficient_penalty: float = 0.3
max_history_tokens: int = 4000

View File

@ -0,0 +1,224 @@
"""
Confidence calculation for AI Service.
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Confidence scoring and transfer suggestion logic.
Design reference: design.md Section 4.3 - 检索不中兜底与置信度策略
- Retrieval insufficiency detection
- Confidence calculation based on retrieval scores
- shouldTransfer logic with threshold T_low
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from app.core.config import get_settings
from app.services.retrieval.base import RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class ConfidenceConfig:
"""
Configuration for confidence calculation.
[AC-AISVC-17, AC-AISVC-18] Configurable thresholds.
"""
score_threshold: float = 0.7
min_hits: int = 1
confidence_low_threshold: float = 0.5
confidence_high_threshold: float = 0.8
insufficient_penalty: float = 0.3
max_evidence_tokens: int = 2000
@dataclass
class ConfidenceResult:
"""
Result of confidence calculation.
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Contains confidence and transfer suggestion.
"""
confidence: float
should_transfer: bool
transfer_reason: str | None = None
is_retrieval_insufficient: bool = False
diagnostics: dict[str, Any] = field(default_factory=dict)
class ConfidenceCalculator:
"""
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculator for response confidence.
Design reference: design.md Section 4.3
- MVP: confidence based on RAG retrieval scores
- Insufficient retrieval triggers confidence downgrade
- shouldTransfer when confidence < T_low
"""
def __init__(self, config: ConfidenceConfig | None = None):
settings = get_settings()
self._config = config or ConfidenceConfig(
score_threshold=getattr(settings, "rag_score_threshold", 0.7),
min_hits=getattr(settings, "rag_min_hits", 1),
confidence_low_threshold=getattr(settings, "confidence_low_threshold", 0.5),
confidence_high_threshold=getattr(settings, "confidence_high_threshold", 0.8),
insufficient_penalty=getattr(settings, "confidence_insufficient_penalty", 0.3),
max_evidence_tokens=getattr(settings, "rag_max_evidence_tokens", 2000),
)
def is_retrieval_insufficient(
self,
retrieval_result: RetrievalResult,
evidence_tokens: int | None = None,
) -> tuple[bool, str]:
"""
[AC-AISVC-17] Determine if retrieval results are insufficient.
Conditions for insufficiency:
1. hits.size < min_hits
2. max(score) < score_threshold
3. evidence tokens exceed limit (optional)
Args:
retrieval_result: Result from retrieval operation
evidence_tokens: Optional token count for evidence
Returns:
Tuple of (is_insufficient, reason)
"""
reasons = []
if retrieval_result.hit_count < self._config.min_hits:
reasons.append(
f"hit_count({retrieval_result.hit_count}) < min_hits({self._config.min_hits})"
)
if retrieval_result.max_score < self._config.score_threshold:
reasons.append(
f"max_score({retrieval_result.max_score:.3f}) < threshold({self._config.score_threshold})"
)
if evidence_tokens is not None and evidence_tokens > self._config.max_evidence_tokens:
reasons.append(
f"evidence_tokens({evidence_tokens}) > max({self._config.max_evidence_tokens})"
)
is_insufficient = len(reasons) > 0
reason = "; ".join(reasons) if reasons else "sufficient"
return is_insufficient, reason
def calculate_confidence(
self,
retrieval_result: RetrievalResult,
evidence_tokens: int | None = None,
additional_factors: dict[str, float] | None = None,
) -> ConfidenceResult:
"""
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculate confidence and transfer suggestion.
MVP Strategy:
1. Base confidence from max retrieval score
2. Adjust for hit count (more hits = higher confidence)
3. Penalize if retrieval is insufficient
4. Determine shouldTransfer based on T_low threshold
Args:
retrieval_result: Result from retrieval operation
evidence_tokens: Optional token count for evidence
additional_factors: Optional additional confidence factors
Returns:
ConfidenceResult with confidence and transfer suggestion
"""
is_insufficient, insufficiency_reason = self.is_retrieval_insufficient(
retrieval_result, evidence_tokens
)
base_confidence = retrieval_result.max_score
hit_count_factor = min(1.0, retrieval_result.hit_count / 5.0)
confidence = base_confidence * 0.7 + hit_count_factor * 0.3
if is_insufficient:
confidence -= self._config.insufficient_penalty
logger.info(
f"[AC-AISVC-17] Retrieval insufficient: {insufficiency_reason}, "
f"applying penalty -{self._config.insufficient_penalty}"
)
if additional_factors:
for factor_name, factor_value in additional_factors.items():
confidence += factor_value * 0.1
confidence = max(0.0, min(1.0, confidence))
should_transfer = confidence < self._config.confidence_low_threshold
transfer_reason = None
if should_transfer:
if is_insufficient:
transfer_reason = "检索结果不足,无法提供高置信度回答"
else:
transfer_reason = "置信度低于阈值,建议转人工"
elif confidence < self._config.confidence_high_threshold and is_insufficient:
transfer_reason = "检索结果有限,回答可能不够准确"
diagnostics = {
"base_confidence": base_confidence,
"hit_count": retrieval_result.hit_count,
"max_score": retrieval_result.max_score,
"is_insufficient": is_insufficient,
"insufficiency_reason": insufficiency_reason if is_insufficient else None,
"penalty_applied": self._config.insufficient_penalty if is_insufficient else 0.0,
"threshold_low": self._config.confidence_low_threshold,
"threshold_high": self._config.confidence_high_threshold,
}
logger.info(
f"[AC-AISVC-17, AC-AISVC-18] Confidence calculated: "
f"{confidence:.3f}, should_transfer={should_transfer}, "
f"insufficient={is_insufficient}"
)
return ConfidenceResult(
confidence=round(confidence, 3),
should_transfer=should_transfer,
transfer_reason=transfer_reason,
is_retrieval_insufficient=is_insufficient,
diagnostics=diagnostics,
)
def calculate_confidence_no_retrieval(self) -> ConfidenceResult:
"""
[AC-AISVC-17] Calculate confidence when no retrieval was performed.
Returns a low confidence result suggesting transfer.
"""
return ConfidenceResult(
confidence=0.3,
should_transfer=True,
transfer_reason="未进行知识库检索,建议转人工",
is_retrieval_insufficient=True,
diagnostics={
"base_confidence": 0.0,
"hit_count": 0,
"max_score": 0.0,
"is_insufficient": True,
"insufficiency_reason": "no_retrieval",
"penalty_applied": 0.0,
"threshold_low": self._config.confidence_low_threshold,
"threshold_high": self._config.confidence_high_threshold,
},
)
_confidence_calculator: ConfidenceCalculator | None = None
def get_confidence_calculator() -> ConfidenceCalculator:
"""Get or create confidence calculator instance."""
global _confidence_calculator
if _confidence_calculator is None:
_confidence_calculator = ConfidenceCalculator()
return _confidence_calculator

View File

@ -0,0 +1,302 @@
"""
Unit tests for Confidence Calculator.
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Tests for confidence scoring and transfer logic.
Tests cover:
- Retrieval insufficiency detection
- Confidence calculation based on retrieval scores
- shouldTransfer logic with threshold T_low
- Edge cases (no retrieval, empty results)
"""
from unittest.mock import MagicMock, patch
import pytest
from app.services.retrieval.base import RetrievalHit, RetrievalResult
from app.services.confidence import (
ConfidenceCalculator,
ConfidenceConfig,
ConfidenceResult,
get_confidence_calculator,
)
@pytest.fixture
def mock_settings():
"""Mock settings for testing."""
settings = MagicMock()
settings.rag_score_threshold = 0.7
settings.rag_min_hits = 1
settings.confidence_low_threshold = 0.5
settings.confidence_high_threshold = 0.8
settings.confidence_insufficient_penalty = 0.3
settings.rag_max_evidence_tokens = 2000
return settings
@pytest.fixture
def confidence_calculator(mock_settings):
"""Create confidence calculator with mocked settings."""
with patch("app.services.confidence.get_settings", return_value=mock_settings):
calculator = ConfidenceCalculator()
yield calculator
@pytest.fixture
def good_retrieval_result():
"""Sample retrieval result with good hits."""
return RetrievalResult(
hits=[
RetrievalHit(text="Result 1", score=0.9, source="kb"),
RetrievalHit(text="Result 2", score=0.85, source="kb"),
RetrievalHit(text="Result 3", score=0.8, source="kb"),
],
diagnostics={"query_length": 50},
)
@pytest.fixture
def poor_retrieval_result():
"""Sample retrieval result with poor hits."""
return RetrievalResult(
hits=[
RetrievalHit(text="Result 1", score=0.5, source="kb"),
],
diagnostics={"query_length": 50},
)
@pytest.fixture
def empty_retrieval_result():
"""Sample empty retrieval result."""
return RetrievalResult(
hits=[],
diagnostics={"query_length": 50},
)
class TestRetrievalInsufficiency:
"""Tests for retrieval insufficiency detection. [AC-AISVC-17]"""
def test_sufficient_retrieval(self, confidence_calculator, good_retrieval_result):
"""[AC-AISVC-17] Test sufficient retrieval detection."""
is_insufficient, reason = confidence_calculator.is_retrieval_insufficient(
good_retrieval_result
)
assert is_insufficient is False
assert reason == "sufficient"
def test_insufficient_hit_count(self, confidence_calculator):
"""[AC-AISVC-17] Test insufficiency due to low hit count."""
config = ConfidenceConfig(min_hits=3)
calculator = ConfidenceCalculator(config=config)
result = RetrievalResult(
hits=[
RetrievalHit(text="Result 1", score=0.9, source="kb"),
]
)
is_insufficient, reason = calculator.is_retrieval_insufficient(result)
assert is_insufficient is True
assert "hit_count" in reason.lower()
def test_insufficient_score(self, confidence_calculator, poor_retrieval_result):
"""[AC-AISVC-17] Test insufficiency due to low score."""
is_insufficient, reason = confidence_calculator.is_retrieval_insufficient(
poor_retrieval_result
)
assert is_insufficient is True
assert "max_score" in reason.lower()
def test_insufficient_empty_result(self, confidence_calculator, empty_retrieval_result):
"""[AC-AISVC-17] Test insufficiency with empty result."""
is_insufficient, reason = confidence_calculator.is_retrieval_insufficient(
empty_retrieval_result
)
assert is_insufficient is True
def test_insufficient_evidence_tokens(self, confidence_calculator, good_retrieval_result):
"""[AC-AISVC-17] Test insufficiency due to evidence token limit."""
is_insufficient, reason = confidence_calculator.is_retrieval_insufficient(
good_retrieval_result, evidence_tokens=3000
)
assert is_insufficient is True
assert "evidence_tokens" in reason.lower()
class TestConfidenceCalculation:
"""Tests for confidence calculation. [AC-AISVC-17, AC-AISVC-19]"""
def test_high_confidence_with_good_retrieval(
self, confidence_calculator, good_retrieval_result
):
"""[AC-AISVC-19] Test high confidence with good retrieval results."""
result = confidence_calculator.calculate_confidence(good_retrieval_result)
assert isinstance(result, ConfidenceResult)
assert result.confidence >= 0.5
assert result.should_transfer is False
assert result.is_retrieval_insufficient is False
def test_low_confidence_with_poor_retrieval(
self, confidence_calculator, poor_retrieval_result
):
"""[AC-AISVC-17] Test low confidence with poor retrieval results."""
result = confidence_calculator.calculate_confidence(poor_retrieval_result)
assert isinstance(result, ConfidenceResult)
assert result.confidence < 0.7
assert result.is_retrieval_insufficient is True
def test_confidence_with_empty_result(
self, confidence_calculator, empty_retrieval_result
):
"""[AC-AISVC-17] Test confidence with empty retrieval result."""
result = confidence_calculator.calculate_confidence(empty_retrieval_result)
assert result.confidence < 0.5
assert result.should_transfer is True
assert result.is_retrieval_insufficient is True
def test_confidence_includes_diagnostics(
self, confidence_calculator, good_retrieval_result
):
"""[AC-AISVC-17] Test that confidence result includes diagnostics."""
result = confidence_calculator.calculate_confidence(good_retrieval_result)
assert "base_confidence" in result.diagnostics
assert "hit_count" in result.diagnostics
assert "max_score" in result.diagnostics
assert "threshold_low" in result.diagnostics
def test_confidence_with_additional_factors(
self, confidence_calculator, good_retrieval_result
):
"""[AC-AISVC-17] Test confidence with additional factors."""
additional = {"model_certainty": 0.5}
result = confidence_calculator.calculate_confidence(
good_retrieval_result, additional_factors=additional
)
assert result.confidence > 0
def test_confidence_bounded_to_range(self, confidence_calculator):
"""[AC-AISVC-17] Test that confidence is bounded to [0, 1]."""
result_with_high_score = RetrievalResult(
hits=[RetrievalHit(text="Result", score=1.0, source="kb")]
)
result = confidence_calculator.calculate_confidence(result_with_high_score)
assert 0.0 <= result.confidence <= 1.0
class TestShouldTransfer:
"""Tests for shouldTransfer logic. [AC-AISVC-18]"""
def test_no_transfer_with_high_confidence(
self, confidence_calculator, good_retrieval_result
):
"""[AC-AISVC-18] Test no transfer when confidence is high."""
result = confidence_calculator.calculate_confidence(good_retrieval_result)
assert result.should_transfer is False
assert result.transfer_reason is None
def test_transfer_with_low_confidence(
self, confidence_calculator, empty_retrieval_result
):
"""[AC-AISVC-18] Test transfer when confidence is low."""
result = confidence_calculator.calculate_confidence(empty_retrieval_result)
assert result.should_transfer is True
assert result.transfer_reason is not None
def test_transfer_reason_for_insufficient_retrieval(
self, confidence_calculator, poor_retrieval_result
):
"""[AC-AISVC-18] Test transfer reason for insufficient retrieval."""
result = confidence_calculator.calculate_confidence(poor_retrieval_result)
assert result.is_retrieval_insufficient is True
if result.should_transfer:
assert "检索" in result.transfer_reason or "置信度" in result.transfer_reason
def test_custom_threshold(self):
"""[AC-AISVC-18] Test custom low threshold for transfer."""
config = ConfidenceConfig(
confidence_low_threshold=0.7,
score_threshold=0.7,
min_hits=1,
)
calculator = ConfidenceCalculator(config=config)
result = RetrievalResult(
hits=[RetrievalHit(text="Result", score=0.6, source="kb")]
)
conf_result = calculator.calculate_confidence(result)
assert conf_result.should_transfer is True
class TestNoRetrieval:
"""Tests for no retrieval scenario. [AC-AISVC-17]"""
def test_no_retrieval_confidence(self, confidence_calculator):
"""[AC-AISVC-17] Test confidence when no retrieval was performed."""
result = confidence_calculator.calculate_confidence_no_retrieval()
assert result.confidence == 0.3
assert result.should_transfer is True
assert result.transfer_reason is not None
assert result.is_retrieval_insufficient is True
class TestConfidenceConfig:
"""Tests for confidence configuration."""
def test_default_config(self, mock_settings):
"""Test default configuration values."""
with patch("app.services.confidence.get_settings", return_value=mock_settings):
calculator = ConfidenceCalculator()
assert calculator._config.score_threshold == 0.7
assert calculator._config.min_hits == 1
assert calculator._config.confidence_low_threshold == 0.5
def test_custom_config(self):
"""Test custom configuration values."""
config = ConfidenceConfig(
score_threshold=0.8,
min_hits=2,
confidence_low_threshold=0.6,
)
calculator = ConfidenceCalculator(config=config)
assert calculator._config.score_threshold == 0.8
assert calculator._config.min_hits == 2
assert calculator._config.confidence_low_threshold == 0.6
class TestConfidenceCalculatorSingleton:
"""Tests for singleton pattern."""
def test_get_confidence_calculator_singleton(self, mock_settings):
"""Test that get_confidence_calculator returns singleton."""
with patch("app.services.confidence.get_settings", return_value=mock_settings):
from app.services.confidence import _confidence_calculator
import app.services.confidence as confidence_module
confidence_module._confidence_calculator = None
calculator1 = get_confidence_calculator()
calculator2 = get_confidence_calculator()
assert calculator1 is calculator2

View File

@ -0,0 +1,376 @@
"""
Tests for SSE state machine and error handling.
[AC-AISVC-08, AC-AISVC-09] Tests for proper event sequence and error handling.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
from sse_starlette.sse import ServerSentEvent
from app.core.sse import (
SSEState,
SSEStateMachine,
create_error_event,
create_final_event,
create_message_event,
)
from app.main import app
from app.models import ChatRequest, ChannelType
class TestSSEStateMachineTransitions:
"""
[AC-AISVC-08, AC-AISVC-09] Test cases for SSE state machine transitions.
"""
@pytest.mark.asyncio
async def test_init_to_streaming_transition(self):
"""
[AC-AISVC-08] Test INIT -> STREAMING transition.
"""
state_machine = SSEStateMachine()
assert state_machine.state == SSEState.INIT
success = await state_machine.transition_to_streaming()
assert success is True
assert state_machine.state == SSEState.STREAMING
@pytest.mark.asyncio
async def test_streaming_to_final_transition(self):
"""
[AC-AISVC-08] Test STREAMING -> FINAL_SENT transition.
"""
state_machine = SSEStateMachine()
await state_machine.transition_to_streaming()
success = await state_machine.transition_to_final()
assert success is True
assert state_machine.state == SSEState.FINAL_SENT
@pytest.mark.asyncio
async def test_streaming_to_error_transition(self):
"""
[AC-AISVC-09] Test STREAMING -> ERROR_SENT transition.
"""
state_machine = SSEStateMachine()
await state_machine.transition_to_streaming()
success = await state_machine.transition_to_error()
assert success is True
assert state_machine.state == SSEState.ERROR_SENT
@pytest.mark.asyncio
async def test_init_to_error_transition(self):
"""
[AC-AISVC-09] Test INIT -> ERROR_SENT transition (error before streaming starts).
"""
state_machine = SSEStateMachine()
success = await state_machine.transition_to_error()
assert success is True
assert state_machine.state == SSEState.ERROR_SENT
@pytest.mark.asyncio
async def test_cannot_transition_from_final(self):
"""
[AC-AISVC-08] Test that no transitions are possible after FINAL_SENT.
"""
state_machine = SSEStateMachine()
await state_machine.transition_to_streaming()
await state_machine.transition_to_final()
assert await state_machine.transition_to_streaming() is False
assert await state_machine.transition_to_error() is False
assert state_machine.state == SSEState.FINAL_SENT
@pytest.mark.asyncio
async def test_cannot_transition_from_error(self):
"""
[AC-AISVC-09] Test that no transitions are possible after ERROR_SENT.
"""
state_machine = SSEStateMachine()
await state_machine.transition_to_streaming()
await state_machine.transition_to_error()
assert await state_machine.transition_to_streaming() is False
assert await state_machine.transition_to_final() is False
assert state_machine.state == SSEState.ERROR_SENT
@pytest.mark.asyncio
async def test_cannot_send_message_after_final(self):
"""
[AC-AISVC-08] Test that can_send_message returns False after FINAL_SENT.
"""
state_machine = SSEStateMachine()
await state_machine.transition_to_streaming()
await state_machine.transition_to_final()
assert state_machine.can_send_message() is False
@pytest.mark.asyncio
async def test_cannot_send_message_after_error(self):
"""
[AC-AISVC-09] Test that can_send_message returns False after ERROR_SENT.
"""
state_machine = SSEStateMachine()
await state_machine.transition_to_streaming()
await state_machine.transition_to_error()
assert state_machine.can_send_message() is False
@pytest.mark.asyncio
async def test_close_transition(self):
"""
[AC-AISVC-08] Test that close() transitions to CLOSED state.
"""
state_machine = SSEStateMachine()
await state_machine.transition_to_streaming()
await state_machine.transition_to_final()
await state_machine.close()
assert state_machine.state == SSEState.CLOSED
class TestSSEEventSequence:
"""
[AC-AISVC-08, AC-AISVC-09] Test cases for SSE event sequence enforcement.
"""
@pytest.fixture
def client(self):
return TestClient(app)
@pytest.fixture
def valid_headers(self):
return {"X-Tenant-Id": "tenant_001", "Accept": "text/event-stream"}
@pytest.fixture
def valid_body(self):
return {
"sessionId": "test_session",
"currentMessage": "Hello",
"channelType": "wechat",
}
def test_sse_sequence_message_then_final(self, client, valid_headers, valid_body):
"""
[AC-AISVC-08] Test that SSE events follow: message* -> final -> close.
"""
response = client.post("/ai/chat", json=valid_body, headers=valid_headers)
assert response.status_code == 200
content = response.text
assert "event:message" in content or "event: message" in content
assert "event:final" in content or "event: final" in content
message_idx = content.find("event:message")
if message_idx == -1:
message_idx = content.find("event: message")
final_idx = content.find("event:final")
if final_idx == -1:
final_idx = content.find("event: final")
assert final_idx > message_idx, "final should come after message events"
def test_sse_only_one_final_event(self, client, valid_headers, valid_body):
"""
[AC-AISVC-08] Test that there is exactly one final event.
"""
response = client.post("/ai/chat", json=valid_body, headers=valid_headers)
content = response.text
final_count = content.count("event:final") + content.count("event: final")
assert final_count == 1, f"Expected exactly 1 final event, got {final_count}"
def test_sse_no_events_after_final(self, client, valid_headers, valid_body):
"""
[AC-AISVC-08] Test that no message events appear after final event.
"""
response = client.post("/ai/chat", json=valid_body, headers=valid_headers)
content = response.text
lines = content.split("\n")
final_found = False
for line in lines:
if "event:final" in line or "event: final" in line:
final_found = True
elif final_found and ("event:message" in line or "event: message" in line):
pytest.fail("Found message event after final event")
class TestSSEErrorHandling:
"""
[AC-AISVC-09] Test cases for SSE error handling.
"""
@pytest.mark.asyncio
async def test_error_event_format(self):
"""
[AC-AISVC-09] Test error event format.
"""
event = create_error_event(
code="TEST_ERROR",
message="Test error message",
details=[{"field": "test"}],
)
assert event.event == "error"
data = json.loads(event.data)
assert data["code"] == "TEST_ERROR"
assert data["message"] == "Test error message"
assert data["details"] == [{"field": "test"}]
@pytest.mark.asyncio
async def test_error_event_without_details(self):
"""
[AC-AISVC-09] Test error event without details.
"""
event = create_error_event(
code="SIMPLE_ERROR",
message="Simple error",
)
assert event.event == "error"
data = json.loads(event.data)
assert data["code"] == "SIMPLE_ERROR"
assert data["message"] == "Simple error"
assert "details" not in data
def test_missing_tenant_id_returns_400(self):
"""
[AC-AISVC-12] Test that missing X-Tenant-Id returns 400 error.
"""
client = TestClient(app)
headers = {"Accept": "text/event-stream"}
body = {
"sessionId": "test_session",
"currentMessage": "Hello",
"channelType": "wechat",
}
response = client.post("/ai/chat", json=body, headers=headers)
assert response.status_code == 400
data = response.json()
assert data["code"] == "MISSING_TENANT_ID"
class TestSSEStateConcurrency:
"""
[AC-AISVC-08, AC-AISVC-09] Test cases for state machine thread safety.
"""
@pytest.mark.asyncio
async def test_concurrent_transitions(self):
"""
[AC-AISVC-08] Test that concurrent transitions are handled correctly.
"""
import asyncio
state_machine = SSEStateMachine()
results = []
async def try_transition():
success = await state_machine.transition_to_streaming()
results.append(success)
await asyncio.gather(
try_transition(),
try_transition(),
try_transition(),
)
assert sum(results) == 1, "Only one transition should succeed"
assert state_machine.state == SSEState.STREAMING
@pytest.mark.asyncio
async def test_concurrent_final_transitions(self):
"""
[AC-AISVC-08] Test that only one final transition succeeds.
"""
import asyncio
state_machine = SSEStateMachine()
await state_machine.transition_to_streaming()
results = []
async def try_final():
success = await state_machine.transition_to_final()
results.append(success)
await asyncio.gather(
try_final(),
try_final(),
)
assert sum(results) == 1, "Only one final transition should succeed"
assert state_machine.state == SSEState.FINAL_SENT
class TestSSEIntegrationWithOrchestrator:
"""
[AC-AISVC-08, AC-AISVC-09] Integration tests for SSE with Orchestrator.
"""
@pytest.mark.asyncio
async def test_orchestrator_stream_with_error(self):
"""
[AC-AISVC-09] Test that orchestrator errors are properly handled.
"""
from app.services.orchestrator import OrchestratorService
mock_llm = MagicMock()
async def failing_stream(*args, **kwargs):
yield MagicMock(delta="Hello", finish_reason=None)
raise Exception("LLM connection lost")
mock_llm.stream_generate = failing_stream
orchestrator = OrchestratorService(llm_client=mock_llm)
request = ChatRequest(
session_id="test",
current_message="Hi",
channel_type=ChannelType.WECHAT,
)
events = []
async for event in orchestrator.generate_stream("tenant", request):
events.append(event)
event_types = [e.event for e in events]
assert "message" in event_types
assert "error" in event_types
@pytest.mark.asyncio
async def test_orchestrator_stream_normal_flow(self):
"""
[AC-AISVC-08] Test normal streaming flow ends with final event.
"""
from app.services.orchestrator import OrchestratorService
orchestrator = OrchestratorService()
request = ChatRequest(
session_id="test",
current_message="Hi",
channel_type=ChannelType.WECHAT,
)
events = []
async for event in orchestrator.generate_stream("tenant", request):
events.append(event)
event_types = [e.event for e in events]
assert "message" in event_types
assert "final" in event_types
final_index = event_types.index("final")
for i, t in enumerate(event_types):
if t == "message":
assert i < final_index, "message events should come before final"