"""
Unit tests for FusionPolicy.
[AC-AISVC-115~AC-AISVC-117] Tests for fusion decision policy.
"""

import uuid
from unittest.mock import MagicMock

import pytest

from app.services.intent.models import (
    FusionConfig,
    FusionResult,
    LlmJudgeResult,
    RuleMatchResult,
    SemanticCandidate,
    SemanticMatchResult,
    RouteTrace,
)


class FusionPolicy:
    """[AC-AISVC-115] Fusion decision policy.

    Combines a rule match, a semantic match and an optional LLM judge result
    into a single ``FusionResult``. The decision is the first entry of
    ``DECISION_PRIORITY`` whose predicate returns a truthy value.
    """

    # Ordered (reason, predicate) pairs. Each predicate is called positionally
    # with (rule_result, semantic_result, llm_result) and the first truthy one
    # wins, so order encodes priority. Every predicate that leads ``fuse`` to
    # read ``candidates[0]`` also requires a non-empty candidate list.
    DECISION_PRIORITY = [
        (
            "rule_high_confidence",
            lambda rule, sem, llm: rule.score == 1.0 and rule.rule is not None,
        ),
        (
            "llm_judge",
            lambda rule, sem, llm: llm.triggered and llm.intent_id is not None,
        ),
        (
            "semantic_override",
            lambda rule, sem, llm: (
                rule.score == 0 and bool(sem.candidates) and sem.top_score > 0.7
            ),
        ),
        (
            "rule_semantic_agree",
            lambda rule, sem, llm: (
                bool(sem.candidates)
                and rule.score > 0
                and sem.top_score > 0.5
                and rule.rule_id == sem.candidates[0].rule.id
            ),
        ),
        (
            "semantic_fallback",
            lambda rule, sem, llm: bool(sem.candidates) and sem.top_score > 0.5,
        ),
        ("rule_fallback", lambda rule, sem, llm: rule.score > 0),
        ("no_match", lambda rule, sem, llm: True),
    ]

    def __init__(self, config: FusionConfig):
        """Store the fusion configuration (weights and clarify threshold)."""
        self._config = config

    def fuse(
        self,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        llm_result: LlmJudgeResult | None,
    ) -> FusionResult:
        """Fuse the three matcher outputs into one routing decision.

        Args:
            rule_result: Outcome of the rule matcher.
            semantic_result: Outcome of the semantic matcher.
            llm_result: Outcome of the LLM judge, or ``None`` if it did not run.

        Returns:
            A ``FusionResult`` carrying the chosen intent, its confidence, the
            decision reason, the clarify flag/candidates and a ``RouteTrace``.
        """
        trace = RouteTrace(
            rule_match={
                "rule_id": str(rule_result.rule_id) if rule_result.rule_id else None,
                "match_type": rule_result.match_type,
                "matched_text": rule_result.matched_text,
                "score": rule_result.score,
                "duration_ms": rule_result.duration_ms,
            },
            semantic_match={
                "top_candidates": [
                    {"rule_id": str(c.rule.id), "name": c.rule.name, "score": c.score}
                    for c in semantic_result.candidates
                ],
                "top_score": semantic_result.top_score,
                "duration_ms": semantic_result.duration_ms,
                "skipped": semantic_result.skipped,
                "skip_reason": semantic_result.skip_reason,
            },
            llm_judge={
                "triggered": llm_result.triggered if llm_result else False,
                "intent_id": llm_result.intent_id if llm_result else None,
                "score": llm_result.score if llm_result else 0.0,
                "duration_ms": llm_result.duration_ms if llm_result else 0,
                "tokens_used": llm_result.tokens_used if llm_result else 0,
            },
            fusion={},
        )

        final_intent = None
        final_confidence = 0.0
        decision_reason = "no_match"

        # Resolve the placeholder once instead of on every loop iteration.
        effective_llm = llm_result or LlmJudgeResult.empty()
        for reason, condition in self.DECISION_PRIORITY:
            if condition(rule_result, semantic_result, effective_llm):
                decision_reason = reason
                break

        if decision_reason == "rule_high_confidence":
            final_intent = rule_result.rule
            final_confidence = 1.0
        elif decision_reason == "llm_judge" and llm_result:
            # NOTE(review): if the judged intent_id matches neither the rule
            # nor any semantic candidate, final_intent stays None while the
            # confidence is still the LLM score — confirm this is intended.
            final_intent = self._find_rule_by_id(
                llm_result.intent_id, rule_result, semantic_result
            )
            final_confidence = llm_result.score
        elif decision_reason == "semantic_override":
            final_intent = semantic_result.candidates[0].rule
            final_confidence = semantic_result.top_score
        elif decision_reason == "rule_semantic_agree":
            final_intent = rule_result.rule
            final_confidence = self._calculate_weighted_confidence(
                rule_result, semantic_result, llm_result
            )
        elif decision_reason == "semantic_fallback":
            final_intent = semantic_result.candidates[0].rule
            final_confidence = semantic_result.top_score
        elif decision_reason == "rule_fallback":
            final_intent = rule_result.rule
            final_confidence = rule_result.score

        need_clarify = final_confidence < self._config.clarify_threshold
        clarify_candidates = None
        # Clarification options are only offered when there is a real choice,
        # i.e. more than one semantic candidate; at most three are returned.
        if need_clarify and len(semantic_result.candidates) > 1:
            clarify_candidates = [c.rule for c in semantic_result.candidates[:3]]

        trace.fusion = {
            "weights": {
                "w_rule": self._config.w_rule,
                "w_semantic": self._config.w_semantic,
                "w_llm": self._config.w_llm,
            },
            "final_confidence": final_confidence,
            "decision_reason": decision_reason,
        }

        return FusionResult(
            final_intent=final_intent,
            final_confidence=final_confidence,
            decision_reason=decision_reason,
            need_clarify=need_clarify,
            clarify_candidates=clarify_candidates,
            trace=trace,
        )

    def _calculate_weighted_confidence(
        self,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        llm_result: LlmJudgeResult | None,
    ) -> float:
        """Return the weighted average of the matcher scores, clamped to [0, 1].

        The semantic score counts as 0.0 when the semantic stage was skipped.
        The LLM score and its weight only take part when the judge was
        triggered. Returns 0.0 when the participating weights do not sum to a
        positive value, instead of dividing by zero.
        """
        llm_active = bool(llm_result and llm_result.triggered)

        rule_score = rule_result.score
        semantic_score = 0.0 if semantic_result.skipped else semantic_result.top_score
        llm_score = llm_result.score if llm_active else 0.0

        total_weight = self._config.w_rule + self._config.w_semantic
        if llm_active:
            total_weight += self._config.w_llm

        if total_weight <= 0:
            # No usable weights configured: there is nothing to average.
            return 0.0

        confidence = (
            self._config.w_rule * rule_score
            + self._config.w_semantic * semantic_score
            + self._config.w_llm * llm_score
        ) / total_weight
        return min(1.0, max(0.0, confidence))

    def _find_rule_by_id(
        self,
        intent_id: str | None,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
    ):
        """Look up the rule object whose id, as a string, equals ``intent_id``.

        The rule match is checked first, then the semantic candidates in
        order. Returns ``None`` when ``intent_id`` is empty or nothing matches.
        """
        if not intent_id:
            return None

        if rule_result.rule_id and str(rule_result.rule_id) == intent_id:
            return rule_result.rule

        for candidate in semantic_result.candidates:
            if str(candidate.rule.id) == intent_id:
                return candidate.rule

        return None


@pytest.fixture
def config():
    """Provide a FusionConfig built with its default values."""
    return FusionConfig()


@pytest.fixture
def mock_rule():
    """Provide a MagicMock rule with a random UUID id, a name and a response type."""
    rule = MagicMock()
    rule.id = uuid.uuid4()
    rule.name = "Test Intent"
    rule.response_type = "rag"
    return rule


class TestFusionPolicy:
    """Tests for FusionPolicy class."""

    def test_init(self, config):
        """Test FusionPolicy initialization."""
        policy = FusionPolicy(config)
        assert policy._config == config

    def test_fuse_rule_high_confidence(self, config, mock_rule):
        """Test fusion with rule high confidence."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[],
            top_score=0.0,
            duration_ms=50,
            skipped=True,
            skip_reason="no_semantic_config",
        )

        result = policy.fuse(rule_result, semantic_result, None)

        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rule
        assert result.final_confidence == 1.0
        assert result.need_clarify is False

    def test_fuse_llm_judge(self, config, mock_rule):
        """Test fusion with LLM judge result."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.5)],
            top_score=0.5,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        llm_result = LlmJudgeResult(
            intent_id=str(mock_rule.id),
            intent_name="Test Intent",
            score=0.85,
            reasoning="Test reasoning",
            duration_ms=500,
            tokens_used=100,
            triggered=True,
        )

        result = policy.fuse(rule_result, semantic_result, llm_result)

        assert result.decision_reason == "llm_judge"
        assert result.final_intent == mock_rule
        assert result.final_confidence == 0.85

    def test_fuse_semantic_override(self, config, mock_rule):
        """Test fusion with semantic override."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.85)],
            top_score=0.85,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        result = policy.fuse(rule_result, semantic_result, None)

        assert result.decision_reason == "semantic_override"
        assert result.final_intent == mock_rule
        assert result.final_confidence == 0.85

    def test_fuse_semantic_score_without_candidates(self, config):
        """Test that a high top_score with no candidates does not crash."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[],
            top_score=0.9,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        result = policy.fuse(rule_result, semantic_result, None)

        assert result.decision_reason == "no_match"
        assert result.final_intent is None
        assert result.final_confidence == 0.0

    def test_fuse_rule_semantic_agree(self, config, mock_rule):
        """Test fusion when rule and semantic agree."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        result = policy.fuse(rule_result, semantic_result, None)

        # A rule score of 1.0 matches "rule_high_confidence", which precedes
        # "rule_semantic_agree" in DECISION_PRIORITY.
        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rule

    def test_fuse_rule_semantic_agree_partial_score(self, config, mock_rule):
        """Test the rule_semantic_agree branch with a rule score below 1.0."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=0.8,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        result = policy.fuse(rule_result, semantic_result, None)

        assert result.decision_reason == "rule_semantic_agree"
        assert result.final_intent == mock_rule
        # Both scores are 0.8, so their weighted average is 0.8 for any
        # positive weights.
        assert result.final_confidence == pytest.approx(0.8)

    def test_fuse_no_match(self, config):
        """Test fusion with no match."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[],
            top_score=0.0,
            duration_ms=50,
            skipped=True,
            skip_reason="no_semantic_config",
        )

        result = policy.fuse(rule_result, semantic_result, None)

        assert result.decision_reason == "no_match"
        assert result.final_intent is None
        assert result.final_confidence == 0.0

    def test_fuse_need_clarify(self, config, mock_rule):
        """Test fusion with clarify needed."""
        policy = FusionPolicy(config)
        other_rule = MagicMock()
        other_rule.id = uuid.uuid4()
        other_rule.name = "Other Intent"

        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[
                SemanticCandidate(rule=mock_rule, score=0.35),
                SemanticCandidate(rule=other_rule, score=0.30),
            ],
            top_score=0.35,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        result = policy.fuse(rule_result, semantic_result, None)

        assert result.need_clarify is True
        assert result.clarify_candidates is not None
        assert len(result.clarify_candidates) == 2

    def test_calculate_weighted_confidence(self, config, mock_rule):
        """Test weighted confidence calculation."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        confidence = policy._calculate_weighted_confidence(
            rule_result, semantic_result, None
        )

        # Read the weights from the config so the test checks the formula
        # rather than the default values.
        expected = (config.w_rule * 1.0 + config.w_semantic * 0.8) / (
            config.w_rule + config.w_semantic
        )
        assert abs(confidence - expected) < 0.01

    def test_calculate_weighted_confidence_zero_weights(self, mock_rule):
        """Test that all-zero weights yield 0.0 instead of dividing by zero."""
        zero_config = MagicMock()
        zero_config.w_rule = 0.0
        zero_config.w_semantic = 0.0
        zero_config.w_llm = 0.0
        policy = FusionPolicy(zero_config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=0.8,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        confidence = policy._calculate_weighted_confidence(
            rule_result, semantic_result, None
        )

        assert confidence == 0.0

    def test_trace_generation(self, config, mock_rule):
        """Test that trace is properly generated."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        result = policy.fuse(rule_result, semantic_result, None)

        assert result.trace is not None
        assert result.trace.rule_match["rule_id"] == str(mock_rule.id)
        assert result.trace.semantic_match["top_score"] == 0.8
        assert result.trace.fusion["decision_reason"] == "rule_high_confidence"