feat(ai-service): implement confidence calculation for T3.3 [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19]
- Add ConfidenceCalculator class for confidence scoring - Implement retrieval insufficiency detection (hit count, score threshold, evidence tokens) - Implement confidence calculation based on retrieval scores - Implement shouldTransfer logic with configurable threshold - Add transferReason for low confidence scenarios - Add comprehensive unit tests (19 test cases) - Update config with confidence-related settings
This commit is contained in:
parent
c9f2c1eb3a
commit
66fa2d2677
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Chat endpoint for AI Service.
|
Chat endpoint for AI Service.
|
||||||
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06] Main chat endpoint with streaming/non-streaming modes.
|
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-08, AC-AISVC-09] Main chat endpoint with streaming/non-streaming modes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -11,7 +11,7 @@ from fastapi.responses import JSONResponse
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
from app.core.middleware import get_response_mode, is_sse_request
|
from app.core.middleware import get_response_mode, is_sse_request
|
||||||
from app.core.sse import create_error_event
|
from app.core.sse import SSEStateMachine, create_error_event
|
||||||
from app.core.tenant import get_tenant_id
|
from app.core.tenant import get_tenant_id
|
||||||
from app.models import ChatRequest, ChatResponse, ErrorResponse
|
from app.models import ChatRequest, ChatResponse, ErrorResponse
|
||||||
from app.services.orchestrator import OrchestratorService, get_orchestrator_service
|
from app.services.orchestrator import OrchestratorService, get_orchestrator_service
|
||||||
|
|
@ -109,19 +109,58 @@ async def _handle_streaming_request(
|
||||||
) -> EventSourceResponse:
|
) -> EventSourceResponse:
|
||||||
"""
|
"""
|
||||||
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Handle SSE streaming request.
|
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Handle SSE streaming request.
|
||||||
Yields message events followed by final or error event.
|
|
||||||
|
SSE Event Sequence (per design.md Section 6.2):
|
||||||
|
- message* (0 or more) -> final (exactly 1) -> close
|
||||||
|
- OR message* (0 or more) -> error (exactly 1) -> close
|
||||||
|
|
||||||
|
State machine ensures:
|
||||||
|
- No events after final/error
|
||||||
|
- Only one final OR one error event
|
||||||
|
- Proper connection close
|
||||||
"""
|
"""
|
||||||
logger.info(f"[AC-AISVC-06] Processing SSE request for tenant={tenant_id}")
|
logger.info(f"[AC-AISVC-06] Processing SSE request for tenant={tenant_id}")
|
||||||
|
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08, AC-AISVC-09] Event generator with state machine enforcement.
|
||||||
|
Ensures proper event sequence and error handling.
|
||||||
|
"""
|
||||||
|
await state_machine.transition_to_streaming()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for event in orchestrator.generate_stream(tenant_id, chat_request):
|
async for event in orchestrator.generate_stream(tenant_id, chat_request):
|
||||||
yield event
|
if not state_machine.can_send_message():
|
||||||
|
logger.warning("[AC-AISVC-08] Received event after state machine closed, ignoring")
|
||||||
|
break
|
||||||
|
|
||||||
|
if event.event == "final":
|
||||||
|
if await state_machine.transition_to_final():
|
||||||
|
logger.info("[AC-AISVC-08] Sending final event and closing stream")
|
||||||
|
yield event
|
||||||
|
break
|
||||||
|
|
||||||
|
elif event.event == "error":
|
||||||
|
if await state_machine.transition_to_error():
|
||||||
|
logger.info("[AC-AISVC-09] Sending error event and closing stream")
|
||||||
|
yield event
|
||||||
|
break
|
||||||
|
|
||||||
|
elif event.event == "message":
|
||||||
|
yield event
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[AC-AISVC-09] Streaming error: {e}")
|
logger.error(f"[AC-AISVC-09] Streaming error: {e}")
|
||||||
yield create_error_event(
|
if await state_machine.transition_to_error():
|
||||||
code="STREAMING_ERROR",
|
yield create_error_event(
|
||||||
message=str(e),
|
code="STREAMING_ERROR",
|
||||||
)
|
message=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await state_machine.close()
|
||||||
|
logger.debug("[AC-AISVC-08] SSE connection closed")
|
||||||
|
|
||||||
return EventSourceResponse(event_generator(), ping=15)
|
return EventSourceResponse(event_generator(), ping=15)
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,9 @@ class Settings(BaseSettings):
|
||||||
rag_min_hits: int = 1
|
rag_min_hits: int = 1
|
||||||
rag_max_evidence_tokens: int = 2000
|
rag_max_evidence_tokens: int = 2000
|
||||||
|
|
||||||
confidence_threshold_low: float = 0.5
|
confidence_low_threshold: float = 0.5
|
||||||
|
confidence_high_threshold: float = 0.8
|
||||||
|
confidence_insufficient_penalty: float = 0.3
|
||||||
max_history_tokens: int = 4000
|
max_history_tokens: int = 4000
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,224 @@
|
||||||
|
"""
|
||||||
|
Confidence calculation for AI Service.
|
||||||
|
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Confidence scoring and transfer suggestion logic.
|
||||||
|
|
||||||
|
Design reference: design.md Section 4.3 - 检索不中兜底与置信度策略
|
||||||
|
- Retrieval insufficiency detection
|
||||||
|
- Confidence calculation based on retrieval scores
|
||||||
|
- shouldTransfer logic with threshold T_low
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
|
from app.services.retrieval.base import RetrievalResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConfidenceConfig:
|
||||||
|
"""
|
||||||
|
Configuration for confidence calculation.
|
||||||
|
[AC-AISVC-17, AC-AISVC-18] Configurable thresholds.
|
||||||
|
"""
|
||||||
|
score_threshold: float = 0.7
|
||||||
|
min_hits: int = 1
|
||||||
|
confidence_low_threshold: float = 0.5
|
||||||
|
confidence_high_threshold: float = 0.8
|
||||||
|
insufficient_penalty: float = 0.3
|
||||||
|
max_evidence_tokens: int = 2000
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConfidenceResult:
|
||||||
|
"""
|
||||||
|
Result of confidence calculation.
|
||||||
|
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Contains confidence and transfer suggestion.
|
||||||
|
"""
|
||||||
|
confidence: float
|
||||||
|
should_transfer: bool
|
||||||
|
transfer_reason: str | None = None
|
||||||
|
is_retrieval_insufficient: bool = False
|
||||||
|
diagnostics: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfidenceCalculator:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculator for response confidence.
|
||||||
|
|
||||||
|
Design reference: design.md Section 4.3
|
||||||
|
- MVP: confidence based on RAG retrieval scores
|
||||||
|
- Insufficient retrieval triggers confidence downgrade
|
||||||
|
- shouldTransfer when confidence < T_low
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: ConfidenceConfig | None = None):
|
||||||
|
settings = get_settings()
|
||||||
|
self._config = config or ConfidenceConfig(
|
||||||
|
score_threshold=getattr(settings, "rag_score_threshold", 0.7),
|
||||||
|
min_hits=getattr(settings, "rag_min_hits", 1),
|
||||||
|
confidence_low_threshold=getattr(settings, "confidence_low_threshold", 0.5),
|
||||||
|
confidence_high_threshold=getattr(settings, "confidence_high_threshold", 0.8),
|
||||||
|
insufficient_penalty=getattr(settings, "confidence_insufficient_penalty", 0.3),
|
||||||
|
max_evidence_tokens=getattr(settings, "rag_max_evidence_tokens", 2000),
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_retrieval_insufficient(
|
||||||
|
self,
|
||||||
|
retrieval_result: RetrievalResult,
|
||||||
|
evidence_tokens: int | None = None,
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-17] Determine if retrieval results are insufficient.
|
||||||
|
|
||||||
|
Conditions for insufficiency:
|
||||||
|
1. hits.size < min_hits
|
||||||
|
2. max(score) < score_threshold
|
||||||
|
3. evidence tokens exceed limit (optional)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
retrieval_result: Result from retrieval operation
|
||||||
|
evidence_tokens: Optional token count for evidence
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_insufficient, reason)
|
||||||
|
"""
|
||||||
|
reasons = []
|
||||||
|
|
||||||
|
if retrieval_result.hit_count < self._config.min_hits:
|
||||||
|
reasons.append(
|
||||||
|
f"hit_count({retrieval_result.hit_count}) < min_hits({self._config.min_hits})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if retrieval_result.max_score < self._config.score_threshold:
|
||||||
|
reasons.append(
|
||||||
|
f"max_score({retrieval_result.max_score:.3f}) < threshold({self._config.score_threshold})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if evidence_tokens is not None and evidence_tokens > self._config.max_evidence_tokens:
|
||||||
|
reasons.append(
|
||||||
|
f"evidence_tokens({evidence_tokens}) > max({self._config.max_evidence_tokens})"
|
||||||
|
)
|
||||||
|
|
||||||
|
is_insufficient = len(reasons) > 0
|
||||||
|
reason = "; ".join(reasons) if reasons else "sufficient"
|
||||||
|
|
||||||
|
return is_insufficient, reason
|
||||||
|
|
||||||
|
def calculate_confidence(
|
||||||
|
self,
|
||||||
|
retrieval_result: RetrievalResult,
|
||||||
|
evidence_tokens: int | None = None,
|
||||||
|
additional_factors: dict[str, float] | None = None,
|
||||||
|
) -> ConfidenceResult:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculate confidence and transfer suggestion.
|
||||||
|
|
||||||
|
MVP Strategy:
|
||||||
|
1. Base confidence from max retrieval score
|
||||||
|
2. Adjust for hit count (more hits = higher confidence)
|
||||||
|
3. Penalize if retrieval is insufficient
|
||||||
|
4. Determine shouldTransfer based on T_low threshold
|
||||||
|
|
||||||
|
Args:
|
||||||
|
retrieval_result: Result from retrieval operation
|
||||||
|
evidence_tokens: Optional token count for evidence
|
||||||
|
additional_factors: Optional additional confidence factors
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConfidenceResult with confidence and transfer suggestion
|
||||||
|
"""
|
||||||
|
is_insufficient, insufficiency_reason = self.is_retrieval_insufficient(
|
||||||
|
retrieval_result, evidence_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
base_confidence = retrieval_result.max_score
|
||||||
|
|
||||||
|
hit_count_factor = min(1.0, retrieval_result.hit_count / 5.0)
|
||||||
|
confidence = base_confidence * 0.7 + hit_count_factor * 0.3
|
||||||
|
|
||||||
|
if is_insufficient:
|
||||||
|
confidence -= self._config.insufficient_penalty
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-17] Retrieval insufficient: {insufficiency_reason}, "
|
||||||
|
f"applying penalty -{self._config.insufficient_penalty}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if additional_factors:
|
||||||
|
for factor_name, factor_value in additional_factors.items():
|
||||||
|
confidence += factor_value * 0.1
|
||||||
|
|
||||||
|
confidence = max(0.0, min(1.0, confidence))
|
||||||
|
|
||||||
|
should_transfer = confidence < self._config.confidence_low_threshold
|
||||||
|
transfer_reason = None
|
||||||
|
|
||||||
|
if should_transfer:
|
||||||
|
if is_insufficient:
|
||||||
|
transfer_reason = "检索结果不足,无法提供高置信度回答"
|
||||||
|
else:
|
||||||
|
transfer_reason = "置信度低于阈值,建议转人工"
|
||||||
|
elif confidence < self._config.confidence_high_threshold and is_insufficient:
|
||||||
|
transfer_reason = "检索结果有限,回答可能不够准确"
|
||||||
|
|
||||||
|
diagnostics = {
|
||||||
|
"base_confidence": base_confidence,
|
||||||
|
"hit_count": retrieval_result.hit_count,
|
||||||
|
"max_score": retrieval_result.max_score,
|
||||||
|
"is_insufficient": is_insufficient,
|
||||||
|
"insufficiency_reason": insufficiency_reason if is_insufficient else None,
|
||||||
|
"penalty_applied": self._config.insufficient_penalty if is_insufficient else 0.0,
|
||||||
|
"threshold_low": self._config.confidence_low_threshold,
|
||||||
|
"threshold_high": self._config.confidence_high_threshold,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-17, AC-AISVC-18] Confidence calculated: "
|
||||||
|
f"{confidence:.3f}, should_transfer={should_transfer}, "
|
||||||
|
f"insufficient={is_insufficient}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return ConfidenceResult(
|
||||||
|
confidence=round(confidence, 3),
|
||||||
|
should_transfer=should_transfer,
|
||||||
|
transfer_reason=transfer_reason,
|
||||||
|
is_retrieval_insufficient=is_insufficient,
|
||||||
|
diagnostics=diagnostics,
|
||||||
|
)
|
||||||
|
|
||||||
|
def calculate_confidence_no_retrieval(self) -> ConfidenceResult:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-17] Calculate confidence when no retrieval was performed.
|
||||||
|
|
||||||
|
Returns a low confidence result suggesting transfer.
|
||||||
|
"""
|
||||||
|
return ConfidenceResult(
|
||||||
|
confidence=0.3,
|
||||||
|
should_transfer=True,
|
||||||
|
transfer_reason="未进行知识库检索,建议转人工",
|
||||||
|
is_retrieval_insufficient=True,
|
||||||
|
diagnostics={
|
||||||
|
"base_confidence": 0.0,
|
||||||
|
"hit_count": 0,
|
||||||
|
"max_score": 0.0,
|
||||||
|
"is_insufficient": True,
|
||||||
|
"insufficiency_reason": "no_retrieval",
|
||||||
|
"penalty_applied": 0.0,
|
||||||
|
"threshold_low": self._config.confidence_low_threshold,
|
||||||
|
"threshold_high": self._config.confidence_high_threshold,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_confidence_calculator: ConfidenceCalculator | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_confidence_calculator() -> ConfidenceCalculator:
|
||||||
|
"""Get or create confidence calculator instance."""
|
||||||
|
global _confidence_calculator
|
||||||
|
if _confidence_calculator is None:
|
||||||
|
_confidence_calculator = ConfidenceCalculator()
|
||||||
|
return _confidence_calculator
|
||||||
|
|
@ -0,0 +1,302 @@
|
||||||
|
"""
|
||||||
|
Unit tests for Confidence Calculator.
|
||||||
|
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Tests for confidence scoring and transfer logic.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- Retrieval insufficiency detection
|
||||||
|
- Confidence calculation based on retrieval scores
|
||||||
|
- shouldTransfer logic with threshold T_low
|
||||||
|
- Edge cases (no retrieval, empty results)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.retrieval.base import RetrievalHit, RetrievalResult
|
||||||
|
from app.services.confidence import (
|
||||||
|
ConfidenceCalculator,
|
||||||
|
ConfidenceConfig,
|
||||||
|
ConfidenceResult,
|
||||||
|
get_confidence_calculator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_settings():
|
||||||
|
"""Mock settings for testing."""
|
||||||
|
settings = MagicMock()
|
||||||
|
settings.rag_score_threshold = 0.7
|
||||||
|
settings.rag_min_hits = 1
|
||||||
|
settings.confidence_low_threshold = 0.5
|
||||||
|
settings.confidence_high_threshold = 0.8
|
||||||
|
settings.confidence_insufficient_penalty = 0.3
|
||||||
|
settings.rag_max_evidence_tokens = 2000
|
||||||
|
return settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def confidence_calculator(mock_settings):
|
||||||
|
"""Create confidence calculator with mocked settings."""
|
||||||
|
with patch("app.services.confidence.get_settings", return_value=mock_settings):
|
||||||
|
calculator = ConfidenceCalculator()
|
||||||
|
yield calculator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def good_retrieval_result():
|
||||||
|
"""Sample retrieval result with good hits."""
|
||||||
|
return RetrievalResult(
|
||||||
|
hits=[
|
||||||
|
RetrievalHit(text="Result 1", score=0.9, source="kb"),
|
||||||
|
RetrievalHit(text="Result 2", score=0.85, source="kb"),
|
||||||
|
RetrievalHit(text="Result 3", score=0.8, source="kb"),
|
||||||
|
],
|
||||||
|
diagnostics={"query_length": 50},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def poor_retrieval_result():
|
||||||
|
"""Sample retrieval result with poor hits."""
|
||||||
|
return RetrievalResult(
|
||||||
|
hits=[
|
||||||
|
RetrievalHit(text="Result 1", score=0.5, source="kb"),
|
||||||
|
],
|
||||||
|
diagnostics={"query_length": 50},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def empty_retrieval_result():
|
||||||
|
"""Sample empty retrieval result."""
|
||||||
|
return RetrievalResult(
|
||||||
|
hits=[],
|
||||||
|
diagnostics={"query_length": 50},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetrievalInsufficiency:
|
||||||
|
"""Tests for retrieval insufficiency detection. [AC-AISVC-17]"""
|
||||||
|
|
||||||
|
def test_sufficient_retrieval(self, confidence_calculator, good_retrieval_result):
|
||||||
|
"""[AC-AISVC-17] Test sufficient retrieval detection."""
|
||||||
|
is_insufficient, reason = confidence_calculator.is_retrieval_insufficient(
|
||||||
|
good_retrieval_result
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_insufficient is False
|
||||||
|
assert reason == "sufficient"
|
||||||
|
|
||||||
|
def test_insufficient_hit_count(self, confidence_calculator):
|
||||||
|
"""[AC-AISVC-17] Test insufficiency due to low hit count."""
|
||||||
|
config = ConfidenceConfig(min_hits=3)
|
||||||
|
calculator = ConfidenceCalculator(config=config)
|
||||||
|
|
||||||
|
result = RetrievalResult(
|
||||||
|
hits=[
|
||||||
|
RetrievalHit(text="Result 1", score=0.9, source="kb"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
is_insufficient, reason = calculator.is_retrieval_insufficient(result)
|
||||||
|
|
||||||
|
assert is_insufficient is True
|
||||||
|
assert "hit_count" in reason.lower()
|
||||||
|
|
||||||
|
def test_insufficient_score(self, confidence_calculator, poor_retrieval_result):
|
||||||
|
"""[AC-AISVC-17] Test insufficiency due to low score."""
|
||||||
|
is_insufficient, reason = confidence_calculator.is_retrieval_insufficient(
|
||||||
|
poor_retrieval_result
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_insufficient is True
|
||||||
|
assert "max_score" in reason.lower()
|
||||||
|
|
||||||
|
def test_insufficient_empty_result(self, confidence_calculator, empty_retrieval_result):
|
||||||
|
"""[AC-AISVC-17] Test insufficiency with empty result."""
|
||||||
|
is_insufficient, reason = confidence_calculator.is_retrieval_insufficient(
|
||||||
|
empty_retrieval_result
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_insufficient is True
|
||||||
|
|
||||||
|
def test_insufficient_evidence_tokens(self, confidence_calculator, good_retrieval_result):
|
||||||
|
"""[AC-AISVC-17] Test insufficiency due to evidence token limit."""
|
||||||
|
is_insufficient, reason = confidence_calculator.is_retrieval_insufficient(
|
||||||
|
good_retrieval_result, evidence_tokens=3000
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_insufficient is True
|
||||||
|
assert "evidence_tokens" in reason.lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfidenceCalculation:
|
||||||
|
"""Tests for confidence calculation. [AC-AISVC-17, AC-AISVC-19]"""
|
||||||
|
|
||||||
|
def test_high_confidence_with_good_retrieval(
|
||||||
|
self, confidence_calculator, good_retrieval_result
|
||||||
|
):
|
||||||
|
"""[AC-AISVC-19] Test high confidence with good retrieval results."""
|
||||||
|
result = confidence_calculator.calculate_confidence(good_retrieval_result)
|
||||||
|
|
||||||
|
assert isinstance(result, ConfidenceResult)
|
||||||
|
assert result.confidence >= 0.5
|
||||||
|
assert result.should_transfer is False
|
||||||
|
assert result.is_retrieval_insufficient is False
|
||||||
|
|
||||||
|
def test_low_confidence_with_poor_retrieval(
|
||||||
|
self, confidence_calculator, poor_retrieval_result
|
||||||
|
):
|
||||||
|
"""[AC-AISVC-17] Test low confidence with poor retrieval results."""
|
||||||
|
result = confidence_calculator.calculate_confidence(poor_retrieval_result)
|
||||||
|
|
||||||
|
assert isinstance(result, ConfidenceResult)
|
||||||
|
assert result.confidence < 0.7
|
||||||
|
assert result.is_retrieval_insufficient is True
|
||||||
|
|
||||||
|
def test_confidence_with_empty_result(
|
||||||
|
self, confidence_calculator, empty_retrieval_result
|
||||||
|
):
|
||||||
|
"""[AC-AISVC-17] Test confidence with empty retrieval result."""
|
||||||
|
result = confidence_calculator.calculate_confidence(empty_retrieval_result)
|
||||||
|
|
||||||
|
assert result.confidence < 0.5
|
||||||
|
assert result.should_transfer is True
|
||||||
|
assert result.is_retrieval_insufficient is True
|
||||||
|
|
||||||
|
def test_confidence_includes_diagnostics(
|
||||||
|
self, confidence_calculator, good_retrieval_result
|
||||||
|
):
|
||||||
|
"""[AC-AISVC-17] Test that confidence result includes diagnostics."""
|
||||||
|
result = confidence_calculator.calculate_confidence(good_retrieval_result)
|
||||||
|
|
||||||
|
assert "base_confidence" in result.diagnostics
|
||||||
|
assert "hit_count" in result.diagnostics
|
||||||
|
assert "max_score" in result.diagnostics
|
||||||
|
assert "threshold_low" in result.diagnostics
|
||||||
|
|
||||||
|
def test_confidence_with_additional_factors(
|
||||||
|
self, confidence_calculator, good_retrieval_result
|
||||||
|
):
|
||||||
|
"""[AC-AISVC-17] Test confidence with additional factors."""
|
||||||
|
additional = {"model_certainty": 0.5}
|
||||||
|
result = confidence_calculator.calculate_confidence(
|
||||||
|
good_retrieval_result, additional_factors=additional
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.confidence > 0
|
||||||
|
|
||||||
|
def test_confidence_bounded_to_range(self, confidence_calculator):
|
||||||
|
"""[AC-AISVC-17] Test that confidence is bounded to [0, 1]."""
|
||||||
|
result_with_high_score = RetrievalResult(
|
||||||
|
hits=[RetrievalHit(text="Result", score=1.0, source="kb")]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = confidence_calculator.calculate_confidence(result_with_high_score)
|
||||||
|
|
||||||
|
assert 0.0 <= result.confidence <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestShouldTransfer:
|
||||||
|
"""Tests for shouldTransfer logic. [AC-AISVC-18]"""
|
||||||
|
|
||||||
|
def test_no_transfer_with_high_confidence(
|
||||||
|
self, confidence_calculator, good_retrieval_result
|
||||||
|
):
|
||||||
|
"""[AC-AISVC-18] Test no transfer when confidence is high."""
|
||||||
|
result = confidence_calculator.calculate_confidence(good_retrieval_result)
|
||||||
|
|
||||||
|
assert result.should_transfer is False
|
||||||
|
assert result.transfer_reason is None
|
||||||
|
|
||||||
|
def test_transfer_with_low_confidence(
|
||||||
|
self, confidence_calculator, empty_retrieval_result
|
||||||
|
):
|
||||||
|
"""[AC-AISVC-18] Test transfer when confidence is low."""
|
||||||
|
result = confidence_calculator.calculate_confidence(empty_retrieval_result)
|
||||||
|
|
||||||
|
assert result.should_transfer is True
|
||||||
|
assert result.transfer_reason is not None
|
||||||
|
|
||||||
|
def test_transfer_reason_for_insufficient_retrieval(
|
||||||
|
self, confidence_calculator, poor_retrieval_result
|
||||||
|
):
|
||||||
|
"""[AC-AISVC-18] Test transfer reason for insufficient retrieval."""
|
||||||
|
result = confidence_calculator.calculate_confidence(poor_retrieval_result)
|
||||||
|
|
||||||
|
assert result.is_retrieval_insufficient is True
|
||||||
|
if result.should_transfer:
|
||||||
|
assert "检索" in result.transfer_reason or "置信度" in result.transfer_reason
|
||||||
|
|
||||||
|
def test_custom_threshold(self):
|
||||||
|
"""[AC-AISVC-18] Test custom low threshold for transfer."""
|
||||||
|
config = ConfidenceConfig(
|
||||||
|
confidence_low_threshold=0.7,
|
||||||
|
score_threshold=0.7,
|
||||||
|
min_hits=1,
|
||||||
|
)
|
||||||
|
calculator = ConfidenceCalculator(config=config)
|
||||||
|
|
||||||
|
result = RetrievalResult(
|
||||||
|
hits=[RetrievalHit(text="Result", score=0.6, source="kb")]
|
||||||
|
)
|
||||||
|
|
||||||
|
conf_result = calculator.calculate_confidence(result)
|
||||||
|
|
||||||
|
assert conf_result.should_transfer is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoRetrieval:
|
||||||
|
"""Tests for no retrieval scenario. [AC-AISVC-17]"""
|
||||||
|
|
||||||
|
def test_no_retrieval_confidence(self, confidence_calculator):
|
||||||
|
"""[AC-AISVC-17] Test confidence when no retrieval was performed."""
|
||||||
|
result = confidence_calculator.calculate_confidence_no_retrieval()
|
||||||
|
|
||||||
|
assert result.confidence == 0.3
|
||||||
|
assert result.should_transfer is True
|
||||||
|
assert result.transfer_reason is not None
|
||||||
|
assert result.is_retrieval_insufficient is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfidenceConfig:
|
||||||
|
"""Tests for confidence configuration."""
|
||||||
|
|
||||||
|
def test_default_config(self, mock_settings):
|
||||||
|
"""Test default configuration values."""
|
||||||
|
with patch("app.services.confidence.get_settings", return_value=mock_settings):
|
||||||
|
calculator = ConfidenceCalculator()
|
||||||
|
|
||||||
|
assert calculator._config.score_threshold == 0.7
|
||||||
|
assert calculator._config.min_hits == 1
|
||||||
|
assert calculator._config.confidence_low_threshold == 0.5
|
||||||
|
|
||||||
|
def test_custom_config(self):
|
||||||
|
"""Test custom configuration values."""
|
||||||
|
config = ConfidenceConfig(
|
||||||
|
score_threshold=0.8,
|
||||||
|
min_hits=2,
|
||||||
|
confidence_low_threshold=0.6,
|
||||||
|
)
|
||||||
|
calculator = ConfidenceCalculator(config=config)
|
||||||
|
|
||||||
|
assert calculator._config.score_threshold == 0.8
|
||||||
|
assert calculator._config.min_hits == 2
|
||||||
|
assert calculator._config.confidence_low_threshold == 0.6
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfidenceCalculatorSingleton:
|
||||||
|
"""Tests for singleton pattern."""
|
||||||
|
|
||||||
|
def test_get_confidence_calculator_singleton(self, mock_settings):
|
||||||
|
"""Test that get_confidence_calculator returns singleton."""
|
||||||
|
with patch("app.services.confidence.get_settings", return_value=mock_settings):
|
||||||
|
from app.services.confidence import _confidence_calculator
|
||||||
|
import app.services.confidence as confidence_module
|
||||||
|
confidence_module._confidence_calculator = None
|
||||||
|
|
||||||
|
calculator1 = get_confidence_calculator()
|
||||||
|
calculator2 = get_confidence_calculator()
|
||||||
|
|
||||||
|
assert calculator1 is calculator2
|
||||||
|
|
@ -0,0 +1,376 @@
|
||||||
|
"""
|
||||||
|
Tests for SSE state machine and error handling.
|
||||||
|
[AC-AISVC-08, AC-AISVC-09] Tests for proper event sequence and error handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sse_starlette.sse import ServerSentEvent
|
||||||
|
|
||||||
|
from app.core.sse import (
|
||||||
|
SSEState,
|
||||||
|
SSEStateMachine,
|
||||||
|
create_error_event,
|
||||||
|
create_final_event,
|
||||||
|
create_message_event,
|
||||||
|
)
|
||||||
|
from app.main import app
|
||||||
|
from app.models import ChatRequest, ChannelType
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEStateMachineTransitions:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08, AC-AISVC-09] Test cases for SSE state machine transitions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_init_to_streaming_transition(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test INIT -> STREAMING transition.
|
||||||
|
"""
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
assert state_machine.state == SSEState.INIT
|
||||||
|
|
||||||
|
success = await state_machine.transition_to_streaming()
|
||||||
|
assert success is True
|
||||||
|
assert state_machine.state == SSEState.STREAMING
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_streaming_to_final_transition(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test STREAMING -> FINAL_SENT transition.
|
||||||
|
"""
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
await state_machine.transition_to_streaming()
|
||||||
|
|
||||||
|
success = await state_machine.transition_to_final()
|
||||||
|
assert success is True
|
||||||
|
assert state_machine.state == SSEState.FINAL_SENT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_streaming_to_error_transition(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-09] Test STREAMING -> ERROR_SENT transition.
|
||||||
|
"""
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
await state_machine.transition_to_streaming()
|
||||||
|
|
||||||
|
success = await state_machine.transition_to_error()
|
||||||
|
assert success is True
|
||||||
|
assert state_machine.state == SSEState.ERROR_SENT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_init_to_error_transition(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-09] Test INIT -> ERROR_SENT transition (error before streaming starts).
|
||||||
|
"""
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
|
||||||
|
success = await state_machine.transition_to_error()
|
||||||
|
assert success is True
|
||||||
|
assert state_machine.state == SSEState.ERROR_SENT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cannot_transition_from_final(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test that no transitions are possible after FINAL_SENT.
|
||||||
|
"""
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
await state_machine.transition_to_streaming()
|
||||||
|
await state_machine.transition_to_final()
|
||||||
|
|
||||||
|
assert await state_machine.transition_to_streaming() is False
|
||||||
|
assert await state_machine.transition_to_error() is False
|
||||||
|
assert state_machine.state == SSEState.FINAL_SENT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cannot_transition_from_error(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-09] Test that no transitions are possible after ERROR_SENT.
|
||||||
|
"""
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
await state_machine.transition_to_streaming()
|
||||||
|
await state_machine.transition_to_error()
|
||||||
|
|
||||||
|
assert await state_machine.transition_to_streaming() is False
|
||||||
|
assert await state_machine.transition_to_final() is False
|
||||||
|
assert state_machine.state == SSEState.ERROR_SENT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cannot_send_message_after_final(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test that can_send_message returns False after FINAL_SENT.
|
||||||
|
"""
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
await state_machine.transition_to_streaming()
|
||||||
|
await state_machine.transition_to_final()
|
||||||
|
|
||||||
|
assert state_machine.can_send_message() is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cannot_send_message_after_error(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-09] Test that can_send_message returns False after ERROR_SENT.
|
||||||
|
"""
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
await state_machine.transition_to_streaming()
|
||||||
|
await state_machine.transition_to_error()
|
||||||
|
|
||||||
|
assert state_machine.can_send_message() is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_transition(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test that close() transitions to CLOSED state.
|
||||||
|
"""
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
await state_machine.transition_to_streaming()
|
||||||
|
await state_machine.transition_to_final()
|
||||||
|
|
||||||
|
await state_machine.close()
|
||||||
|
assert state_machine.state == SSEState.CLOSED
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEEventSequence:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08, AC-AISVC-09] Test cases for SSE event sequence enforcement.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(self):
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_headers(self):
|
||||||
|
return {"X-Tenant-Id": "tenant_001", "Accept": "text/event-stream"}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_body(self):
|
||||||
|
return {
|
||||||
|
"sessionId": "test_session",
|
||||||
|
"currentMessage": "Hello",
|
||||||
|
"channelType": "wechat",
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_sse_sequence_message_then_final(self, client, valid_headers, valid_body):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test that SSE events follow: message* -> final -> close.
|
||||||
|
"""
|
||||||
|
response = client.post("/ai/chat", json=valid_body, headers=valid_headers)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
content = response.text
|
||||||
|
|
||||||
|
assert "event:message" in content or "event: message" in content
|
||||||
|
assert "event:final" in content or "event: final" in content
|
||||||
|
|
||||||
|
message_idx = content.find("event:message")
|
||||||
|
if message_idx == -1:
|
||||||
|
message_idx = content.find("event: message")
|
||||||
|
final_idx = content.find("event:final")
|
||||||
|
if final_idx == -1:
|
||||||
|
final_idx = content.find("event: final")
|
||||||
|
|
||||||
|
assert final_idx > message_idx, "final should come after message events"
|
||||||
|
|
||||||
|
def test_sse_only_one_final_event(self, client, valid_headers, valid_body):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test that there is exactly one final event.
|
||||||
|
"""
|
||||||
|
response = client.post("/ai/chat", json=valid_body, headers=valid_headers)
|
||||||
|
|
||||||
|
content = response.text
|
||||||
|
final_count = content.count("event:final") + content.count("event: final")
|
||||||
|
|
||||||
|
assert final_count == 1, f"Expected exactly 1 final event, got {final_count}"
|
||||||
|
|
||||||
|
def test_sse_no_events_after_final(self, client, valid_headers, valid_body):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test that no message events appear after final event.
|
||||||
|
"""
|
||||||
|
response = client.post("/ai/chat", json=valid_body, headers=valid_headers)
|
||||||
|
|
||||||
|
content = response.text
|
||||||
|
lines = content.split("\n")
|
||||||
|
|
||||||
|
final_found = False
|
||||||
|
for line in lines:
|
||||||
|
if "event:final" in line or "event: final" in line:
|
||||||
|
final_found = True
|
||||||
|
elif final_found and ("event:message" in line or "event: message" in line):
|
||||||
|
pytest.fail("Found message event after final event")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEErrorHandling:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-09] Test cases for SSE error handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_event_format(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-09] Test error event format.
|
||||||
|
"""
|
||||||
|
event = create_error_event(
|
||||||
|
code="TEST_ERROR",
|
||||||
|
message="Test error message",
|
||||||
|
details=[{"field": "test"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event == "error"
|
||||||
|
data = json.loads(event.data)
|
||||||
|
assert data["code"] == "TEST_ERROR"
|
||||||
|
assert data["message"] == "Test error message"
|
||||||
|
assert data["details"] == [{"field": "test"}]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_event_without_details(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-09] Test error event without details.
|
||||||
|
"""
|
||||||
|
event = create_error_event(
|
||||||
|
code="SIMPLE_ERROR",
|
||||||
|
message="Simple error",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event == "error"
|
||||||
|
data = json.loads(event.data)
|
||||||
|
assert data["code"] == "SIMPLE_ERROR"
|
||||||
|
assert data["message"] == "Simple error"
|
||||||
|
assert "details" not in data
|
||||||
|
|
||||||
|
def test_missing_tenant_id_returns_400(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-12] Test that missing X-Tenant-Id returns 400 error.
|
||||||
|
"""
|
||||||
|
client = TestClient(app)
|
||||||
|
headers = {"Accept": "text/event-stream"}
|
||||||
|
body = {
|
||||||
|
"sessionId": "test_session",
|
||||||
|
"currentMessage": "Hello",
|
||||||
|
"channelType": "wechat",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post("/ai/chat", json=body, headers=headers)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
data = response.json()
|
||||||
|
assert data["code"] == "MISSING_TENANT_ID"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEStateConcurrency:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08, AC-AISVC-09] Test cases for state machine thread safety.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_transitions(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test that concurrent transitions are handled correctly.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
results = []
|
||||||
|
|
||||||
|
async def try_transition():
|
||||||
|
success = await state_machine.transition_to_streaming()
|
||||||
|
results.append(success)
|
||||||
|
|
||||||
|
await asyncio.gather(
|
||||||
|
try_transition(),
|
||||||
|
try_transition(),
|
||||||
|
try_transition(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert sum(results) == 1, "Only one transition should succeed"
|
||||||
|
assert state_machine.state == SSEState.STREAMING
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_final_transitions(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test that only one final transition succeeds.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
await state_machine.transition_to_streaming()
|
||||||
|
results = []
|
||||||
|
|
||||||
|
async def try_final():
|
||||||
|
success = await state_machine.transition_to_final()
|
||||||
|
results.append(success)
|
||||||
|
|
||||||
|
await asyncio.gather(
|
||||||
|
try_final(),
|
||||||
|
try_final(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert sum(results) == 1, "Only one final transition should succeed"
|
||||||
|
assert state_machine.state == SSEState.FINAL_SENT
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEIntegrationWithOrchestrator:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08, AC-AISVC-09] Integration tests for SSE with Orchestrator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_orchestrator_stream_with_error(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-09] Test that orchestrator errors are properly handled.
|
||||||
|
"""
|
||||||
|
from app.services.orchestrator import OrchestratorService
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
|
||||||
|
async def failing_stream(*args, **kwargs):
|
||||||
|
yield MagicMock(delta="Hello", finish_reason=None)
|
||||||
|
raise Exception("LLM connection lost")
|
||||||
|
|
||||||
|
mock_llm.stream_generate = failing_stream
|
||||||
|
|
||||||
|
orchestrator = OrchestratorService(llm_client=mock_llm)
|
||||||
|
request = ChatRequest(
|
||||||
|
session_id="test",
|
||||||
|
current_message="Hi",
|
||||||
|
channel_type=ChannelType.WECHAT,
|
||||||
|
)
|
||||||
|
|
||||||
|
events = []
|
||||||
|
async for event in orchestrator.generate_stream("tenant", request):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
event_types = [e.event for e in events]
|
||||||
|
assert "message" in event_types
|
||||||
|
assert "error" in event_types
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_orchestrator_stream_normal_flow(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test normal streaming flow ends with final event.
|
||||||
|
"""
|
||||||
|
from app.services.orchestrator import OrchestratorService
|
||||||
|
|
||||||
|
orchestrator = OrchestratorService()
|
||||||
|
request = ChatRequest(
|
||||||
|
session_id="test",
|
||||||
|
current_message="Hi",
|
||||||
|
channel_type=ChannelType.WECHAT,
|
||||||
|
)
|
||||||
|
|
||||||
|
events = []
|
||||||
|
async for event in orchestrator.generate_stream("tenant", request):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
event_types = [e.event for e in events]
|
||||||
|
assert "message" in event_types
|
||||||
|
assert "final" in event_types
|
||||||
|
|
||||||
|
final_index = event_types.index("final")
|
||||||
|
for i, t in enumerate(event_types):
|
||||||
|
if t == "message":
|
||||||
|
assert i < final_index, "message events should come before final"
|
||||||
Loading…
Reference in New Issue