# ai-robot-core/ai-service/tests/test_confidence.py

"""
Unit tests for Confidence Calculator.
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Tests for confidence scoring and transfer logic.
Tests cover:
- Retrieval insufficiency detection
- Confidence calculation based on retrieval scores
- shouldTransfer logic with threshold T_low
- Edge cases (no retrieval, empty results)
"""
from unittest.mock import MagicMock, patch
import pytest
from app.services.retrieval.base import RetrievalHit, RetrievalResult
from app.services.confidence import (
ConfidenceCalculator,
ConfidenceConfig,
ConfidenceResult,
get_confidence_calculator,
)
@pytest.fixture
def mock_settings():
    """Return a MagicMock that stands in for the application settings.

    The RAG and confidence-related values are preset through the MagicMock
    constructor. Any other attribute access falls back to an auto-created
    MagicMock.
    """
    return MagicMock(
        rag_score_threshold=0.7,
        rag_min_hits=1,
        rag_max_evidence_tokens=2000,
        confidence_low_threshold=0.5,
        confidence_high_threshold=0.8,
        confidence_insufficient_penalty=0.3,
    )
@pytest.fixture
def confidence_calculator(mock_settings):
    """Yield a ConfidenceCalculator built while ``get_settings`` is patched.

    The fixture yields from inside the ``with`` block, so the patch stays in
    effect for the whole test that requests it.
    """
    settings_patch = patch(
        "app.services.confidence.get_settings", return_value=mock_settings
    )
    with settings_patch:
        yield ConfidenceCalculator()
@pytest.fixture
def good_retrieval_result():
    """Return a RetrievalResult with three strong hits scored 0.9, 0.85 and 0.8."""
    hit_scores = (0.9, 0.85, 0.8)
    hits = [
        RetrievalHit(text=f"Result {position}", score=value, source="kb")
        for position, value in enumerate(hit_scores, start=1)
    ]
    return RetrievalResult(hits=hits, diagnostics={"query_length": 50})
@pytest.fixture
def poor_retrieval_result():
    """Return a RetrievalResult with one hit scored 0.5.

    That score is below the 0.7 ``rag_score_threshold`` used in mock_settings.
    """
    lone_hit = RetrievalHit(text="Result 1", score=0.5, source="kb")
    return RetrievalResult(hits=[lone_hit], diagnostics={"query_length": 50})
@pytest.fixture
def empty_retrieval_result():
    """Return a RetrievalResult that contains no hits at all."""
    no_hits = []
    return RetrievalResult(diagnostics={"query_length": 50}, hits=no_hits)
class TestRetrievalInsufficiency:
    """[AC-AISVC-17] Detection of insufficient retrieval results."""

    def test_sufficient_retrieval(self, confidence_calculator, good_retrieval_result):
        """[AC-AISVC-17] Strong hits are reported as sufficient."""
        flagged, why = confidence_calculator.is_retrieval_insufficient(
            good_retrieval_result
        )
        assert flagged is False
        assert why == "sufficient"

    def test_insufficient_hit_count(self, confidence_calculator):
        """[AC-AISVC-17] Fewer hits than ``min_hits`` is reported as insufficient."""
        # NOTE(review): the injected fixture is not used directly. Requesting it
        # may keep get_settings patched for this test. Confirm it is not needed
        # before removing it.
        strict = ConfidenceCalculator(config=ConfidenceConfig(min_hits=3))
        single_hit = RetrievalHit(text="Result 1", score=0.9, source="kb")
        flagged, why = strict.is_retrieval_insufficient(
            RetrievalResult(hits=[single_hit])
        )
        assert flagged is True
        assert "hit_count" in why.lower()

    def test_insufficient_score(self, confidence_calculator, poor_retrieval_result):
        """[AC-AISVC-17] A top score below the threshold is reported as insufficient."""
        flagged, why = confidence_calculator.is_retrieval_insufficient(
            poor_retrieval_result
        )
        assert flagged is True
        assert "max_score" in why.lower()

    def test_insufficient_empty_result(self, confidence_calculator, empty_retrieval_result):
        """[AC-AISVC-17] A result with no hits is reported as insufficient."""
        flagged, _ = confidence_calculator.is_retrieval_insufficient(
            empty_retrieval_result
        )
        assert flagged is True

    def test_insufficient_evidence_tokens(self, confidence_calculator, good_retrieval_result):
        """[AC-AISVC-17] Evidence above the token budget is reported as insufficient."""
        # mock_settings caps rag_max_evidence_tokens at 2000, so 3000 exceeds it.
        flagged, why = confidence_calculator.is_retrieval_insufficient(
            good_retrieval_result, evidence_tokens=3000
        )
        assert flagged is True
        assert "evidence_tokens" in why.lower()
class TestConfidenceCalculation:
    """Tests for confidence calculation. [AC-AISVC-17, AC-AISVC-19]"""

    def test_high_confidence_with_good_retrieval(
        self, confidence_calculator, good_retrieval_result
    ):
        """[AC-AISVC-19] Test high confidence with good retrieval results."""
        result = confidence_calculator.calculate_confidence(good_retrieval_result)
        assert isinstance(result, ConfidenceResult)
        # 0.5 matches confidence_low_threshold in mock_settings.
        assert result.confidence >= 0.5
        assert result.should_transfer is False
        assert result.is_retrieval_insufficient is False

    def test_low_confidence_with_poor_retrieval(
        self, confidence_calculator, poor_retrieval_result
    ):
        """[AC-AISVC-17] Test low confidence with poor retrieval results."""
        result = confidence_calculator.calculate_confidence(poor_retrieval_result)
        assert isinstance(result, ConfidenceResult)
        assert result.confidence < 0.7
        assert result.is_retrieval_insufficient is True

    def test_confidence_with_empty_result(
        self, confidence_calculator, empty_retrieval_result
    ):
        """[AC-AISVC-17] Test confidence with empty retrieval result."""
        result = confidence_calculator.calculate_confidence(empty_retrieval_result)
        assert result.confidence < 0.5
        assert result.should_transfer is True
        assert result.is_retrieval_insufficient is True

    def test_confidence_includes_diagnostics(
        self, confidence_calculator, good_retrieval_result
    ):
        """[AC-AISVC-17] Test that confidence result includes diagnostics."""
        result = confidence_calculator.calculate_confidence(good_retrieval_result)
        for key in ("base_confidence", "hit_count", "max_score", "threshold_low"):
            assert key in result.diagnostics, f"missing diagnostics key: {key}"

    def test_confidence_with_additional_factors(
        self, confidence_calculator, good_retrieval_result
    ):
        """[AC-AISVC-17] Test confidence with additional factors.

        Additional factors must not push confidence out of its valid range.
        The previous check only asserted a positive value.
        """
        additional = {"model_certainty": 0.5}
        result = confidence_calculator.calculate_confidence(
            good_retrieval_result, additional_factors=additional
        )
        assert 0.0 < result.confidence <= 1.0

    def test_confidence_bounded_to_range(self, confidence_calculator):
        """[AC-AISVC-17] Test that confidence is bounded to [0, 1]."""
        result_with_high_score = RetrievalResult(
            hits=[RetrievalHit(text="Result", score=1.0, source="kb")]
        )
        result = confidence_calculator.calculate_confidence(result_with_high_score)
        assert 0.0 <= result.confidence <= 1.0
class TestShouldTransfer:
    """[AC-AISVC-18] Deciding whether the conversation should be transferred."""

    def test_no_transfer_with_high_confidence(
        self, confidence_calculator, good_retrieval_result
    ):
        """[AC-AISVC-18] Good retrieval gives no transfer and no transfer reason."""
        outcome = confidence_calculator.calculate_confidence(good_retrieval_result)
        assert outcome.should_transfer is False
        assert outcome.transfer_reason is None

    def test_transfer_with_low_confidence(
        self, confidence_calculator, empty_retrieval_result
    ):
        """[AC-AISVC-18] An empty retrieval triggers a transfer with a reason."""
        outcome = confidence_calculator.calculate_confidence(empty_retrieval_result)
        assert outcome.should_transfer is True
        assert outcome.transfer_reason is not None

    def test_transfer_reason_for_insufficient_retrieval(
        self, confidence_calculator, poor_retrieval_result
    ):
        """[AC-AISVC-18] Any transfer reason for weak retrieval names its cause."""
        outcome = confidence_calculator.calculate_confidence(poor_retrieval_result)
        assert outcome.is_retrieval_insufficient is True
        if not outcome.should_transfer:
            return
        # Reason text is Chinese: "检索" means retrieval, "置信度" means confidence.
        reason = outcome.transfer_reason
        assert "检索" in reason or "置信度" in reason

    def test_custom_threshold(self):
        """[AC-AISVC-18] Raising the low threshold to 0.7 makes a 0.6 hit transfer."""
        calculator = ConfidenceCalculator(
            config=ConfidenceConfig(
                confidence_low_threshold=0.7,
                score_threshold=0.7,
                min_hits=1,
            )
        )
        weak = RetrievalResult(
            hits=[RetrievalHit(text="Result", score=0.6, source="kb")]
        )
        assert calculator.calculate_confidence(weak).should_transfer is True
class TestNoRetrieval:
    """Tests for no retrieval scenario. [AC-AISVC-17]"""

    def test_no_retrieval_confidence(self, confidence_calculator):
        """[AC-AISVC-17] Test confidence when no retrieval was performed."""
        result = confidence_calculator.calculate_confidence_no_retrieval()
        # Compare with a tolerance. An exact float ``==`` breaks if the value is
        # ever derived arithmetically rather than returned as a literal.
        assert result.confidence == pytest.approx(0.3)
        assert result.should_transfer is True
        assert result.transfer_reason is not None
        assert result.is_retrieval_insufficient is True
class TestConfidenceConfig:
    """How ConfidenceCalculator picks up its configuration."""

    def test_default_config(self, mock_settings):
        """Without an explicit config, the defaults match the mocked settings values."""
        settings_patch = patch(
            "app.services.confidence.get_settings", return_value=mock_settings
        )
        with settings_patch:
            cfg = ConfidenceCalculator()._config
            assert cfg.score_threshold == 0.7
            assert cfg.min_hits == 1
            assert cfg.confidence_low_threshold == 0.5

    def test_custom_config(self):
        """The values of an explicit ConfidenceConfig appear on ``_config``."""
        expected = {
            "score_threshold": 0.8,
            "min_hits": 2,
            "confidence_low_threshold": 0.6,
        }
        cfg = ConfidenceCalculator(config=ConfidenceConfig(**expected))._config
        for field_name, value in expected.items():
            assert getattr(cfg, field_name) == value
class TestConfidenceCalculatorSingleton:
    """Tests for singleton pattern."""

    def test_get_confidence_calculator_singleton(self, mock_settings):
        """Test that get_confidence_calculator returns singleton.

        The module-level cache ``_confidence_calculator`` is reset to ``None``
        for the duration of the test. ``patch`` restores the previous value on
        exit, so the instance built against mocked settings cannot leak into
        other tests.
        """
        settings_patch = patch(
            "app.services.confidence.get_settings", return_value=mock_settings
        )
        cache_reset = patch("app.services.confidence._confidence_calculator", None)
        with settings_patch, cache_reset:
            calculator1 = get_confidence_calculator()
            calculator2 = get_confidence_calculator()
            assert calculator1 is calculator2