# ai-robot-core/ai-service/tests/test_fusion_policy.py
"""
Unit tests for FusionPolicy.
[AC-AISVC-115~AC-AISVC-117] Tests for fusion decision policy.
"""
import pytest
from unittest.mock import MagicMock
import uuid
from app.services.intent.models import (
FusionConfig,
FusionResult,
LlmJudgeResult,
RuleMatchResult,
SemanticCandidate,
SemanticMatchResult,
RouteTrace,
)
class FusionPolicy:
    """[AC-AISVC-115] Fusion decision policy.

    Combines rule-match, semantic-match, and optional LLM-judge results
    into a single :class:`FusionResult`, recording a :class:`RouteTrace`
    of every stage for observability.
    """

    # Ordered (reason, predicate) pairs evaluated top to bottom; the first
    # truthy predicate fixes the decision reason. Predicates receive
    # (rule_result, semantic_result, llm_result); llm_result is never None
    # because fuse() substitutes an empty placeholder first.
    # The semantic branches guard on s.candidates so an inconsistent
    # top_score with an empty candidate list cannot cause an IndexError
    # when candidates[0] is read later.
    DECISION_PRIORITY = [
        ("rule_high_confidence", lambda r, s, l: r.score == 1.0 and r.rule is not None),
        ("llm_judge", lambda r, s, l: l.triggered and l.intent_id is not None),
        ("semantic_override", lambda r, s, l: bool(s.candidates) and r.score == 0 and s.top_score > 0.7),
        ("rule_semantic_agree", lambda r, s, l: (
            bool(s.candidates)
            and r.score > 0
            and s.top_score > 0.5
            and r.rule_id == s.candidates[0].rule.id
        )),
        ("semantic_fallback", lambda r, s, l: bool(s.candidates) and s.top_score > 0.5),
        ("rule_fallback", lambda r, s, l: r.score > 0),
        ("no_match", lambda r, s, l: True),
    ]

    def __init__(self, config: FusionConfig):
        """Store the fusion weights/thresholds configuration."""
        self._config = config

    def fuse(
        self,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        llm_result: LlmJudgeResult | None,
    ) -> FusionResult:
        """Combine the three matcher outputs into a final decision.

        Args:
            rule_result: Outcome of keyword/regex rule matching.
            semantic_result: Outcome of vector-similarity matching.
            llm_result: Optional LLM-judge verdict; ``None`` when the judge
                was not triggered.

        Returns:
            FusionResult carrying the chosen intent, fused confidence,
            clarification flag/candidates, and the full routing trace.
        """
        trace = self._build_trace(rule_result, semantic_result, llm_result)

        # Substitute a neutral placeholder once, outside the loop, so the
        # predicates never see None (and empty() is not rebuilt per branch).
        effective_llm = llm_result or LlmJudgeResult.empty()
        decision_reason = "no_match"
        for reason, predicate in self.DECISION_PRIORITY:
            if predicate(rule_result, semantic_result, effective_llm):
                decision_reason = reason
                break

        final_intent = None
        final_confidence = 0.0
        if decision_reason == "rule_high_confidence":
            final_intent = rule_result.rule
            final_confidence = 1.0
        elif decision_reason == "llm_judge" and llm_result:
            final_intent = self._find_rule_by_id(llm_result.intent_id, rule_result, semantic_result)
            final_confidence = llm_result.score
        elif decision_reason == "semantic_override":
            final_intent = semantic_result.candidates[0].rule
            final_confidence = semantic_result.top_score
        elif decision_reason == "rule_semantic_agree":
            final_intent = rule_result.rule
            final_confidence = self._calculate_weighted_confidence(rule_result, semantic_result, llm_result)
        elif decision_reason == "semantic_fallback":
            final_intent = semantic_result.candidates[0].rule
            final_confidence = semantic_result.top_score
        elif decision_reason == "rule_fallback":
            final_intent = rule_result.rule
            final_confidence = rule_result.score

        # Ask for clarification when confidence is below the configured
        # threshold and there is more than one plausible semantic candidate.
        need_clarify = final_confidence < self._config.clarify_threshold
        clarify_candidates = None
        if need_clarify and len(semantic_result.candidates) > 1:
            clarify_candidates = [c.rule for c in semantic_result.candidates[:3]]

        trace.fusion = {
            "weights": {
                "w_rule": self._config.w_rule,
                "w_semantic": self._config.w_semantic,
                "w_llm": self._config.w_llm,
            },
            "final_confidence": final_confidence,
            "decision_reason": decision_reason,
        }
        return FusionResult(
            final_intent=final_intent,
            final_confidence=final_confidence,
            decision_reason=decision_reason,
            need_clarify=need_clarify,
            clarify_candidates=clarify_candidates,
            trace=trace,
        )

    def _build_trace(
        self,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        llm_result: LlmJudgeResult | None,
    ) -> RouteTrace:
        """Assemble the per-stage RouteTrace; the fusion section is filled by fuse()."""
        return RouteTrace(
            rule_match={
                "rule_id": str(rule_result.rule_id) if rule_result.rule_id else None,
                "match_type": rule_result.match_type,
                "matched_text": rule_result.matched_text,
                "score": rule_result.score,
                "duration_ms": rule_result.duration_ms,
            },
            semantic_match={
                "top_candidates": [
                    {"rule_id": str(c.rule.id), "name": c.rule.name, "score": c.score}
                    for c in semantic_result.candidates
                ],
                "top_score": semantic_result.top_score,
                "duration_ms": semantic_result.duration_ms,
                "skipped": semantic_result.skipped,
                "skip_reason": semantic_result.skip_reason,
            },
            llm_judge={
                "triggered": llm_result.triggered if llm_result else False,
                "intent_id": llm_result.intent_id if llm_result else None,
                "score": llm_result.score if llm_result else 0.0,
                "duration_ms": llm_result.duration_ms if llm_result else 0,
                "tokens_used": llm_result.tokens_used if llm_result else 0,
            },
            fusion={},
        )

    def _calculate_weighted_confidence(
        self,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        llm_result: LlmJudgeResult | None,
    ) -> float:
        """Weighted average of the stage scores, clamped to [0, 1].

        The LLM weight participates only when the judge actually ran
        (llm_result present and triggered).
        """
        rule_score = rule_result.score
        semantic_score = semantic_result.top_score if not semantic_result.skipped else 0.0
        llm_triggered = bool(llm_result and llm_result.triggered)
        llm_score = llm_result.score if llm_triggered else 0.0

        total_weight = self._config.w_rule + self._config.w_semantic
        if llm_triggered:
            total_weight += self._config.w_llm
        if total_weight <= 0:
            # Degenerate configuration (all weights zero): avoid ZeroDivisionError.
            return 0.0

        confidence = (
            self._config.w_rule * rule_score
            + self._config.w_semantic * semantic_score
            + self._config.w_llm * llm_score
        ) / total_weight
        return min(1.0, max(0.0, confidence))

    def _find_rule_by_id(
        self,
        intent_id: str | None,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
    ):
        """Resolve an LLM-returned intent id to a rule object.

        Checks the rule-match result first, then the semantic candidates;
        returns None when the id is absent or matches nothing.
        """
        if not intent_id:
            return None
        if rule_result.rule_id and str(rule_result.rule_id) == intent_id:
            return rule_result.rule
        for candidate in semantic_result.candidates:
            if str(candidate.rule.id) == intent_id:
                return candidate.rule
        return None
@pytest.fixture
def config():
    """Provide a FusionConfig with its default weights and thresholds."""
    return FusionConfig()
@pytest.fixture
def mock_rule():
    """Provide an intent-rule stand-in with a stable UUID, name, and response type."""
    rule = MagicMock()
    # configure_mock is the documented way to set the `name` attribute,
    # which the MagicMock constructor would otherwise treat specially.
    rule.configure_mock(
        id=uuid.uuid4(),
        name="Test Intent",
        response_type="rag",
    )
    return rule
class TestFusionPolicy:
    """Tests for FusionPolicy class."""

    def test_init(self, config):
        """Test FusionPolicy initialization."""
        policy = FusionPolicy(config)
        assert policy._config == config

    def test_fuse_rule_high_confidence(self, config, mock_rule):
        """Test fusion with rule high confidence."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[],
            top_score=0.0,
            duration_ms=50,
            skipped=True,
            skip_reason="no_semantic_config",
        )
        result = policy.fuse(rule_result, semantic_result, None)
        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rule
        assert result.final_confidence == 1.0
        assert result.need_clarify is False

    def test_fuse_llm_judge(self, config, mock_rule):
        """Test fusion with LLM judge result."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.5)],
            top_score=0.5,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        llm_result = LlmJudgeResult(
            intent_id=str(mock_rule.id),
            intent_name="Test Intent",
            score=0.85,
            reasoning="Test reasoning",
            duration_ms=500,
            tokens_used=100,
            triggered=True,
        )
        result = policy.fuse(rule_result, semantic_result, llm_result)
        assert result.decision_reason == "llm_judge"
        assert result.final_intent == mock_rule
        assert result.final_confidence == 0.85

    def test_fuse_semantic_override(self, config, mock_rule):
        """Test fusion with semantic override."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.85)],
            top_score=0.85,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        result = policy.fuse(rule_result, semantic_result, None)
        assert result.decision_reason == "semantic_override"
        assert result.final_intent == mock_rule
        assert result.final_confidence == 0.85

    def test_fuse_rule_semantic_agree(self, config, mock_rule):
        """Test fusion when rule and semantic agree.

        Uses a partial rule score (< 1.0) so the higher-priority
        "rule_high_confidence" branch does not shadow the
        "rule_semantic_agree" branch actually under test.
        """
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=0.6,  # partial match: below the 1.0 high-confidence shortcut
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        result = policy.fuse(rule_result, semantic_result, None)
        assert result.decision_reason == "rule_semantic_agree"
        assert result.final_intent == mock_rule
        # Confidence is the rule/semantic weighted average (no LLM weight).
        expected = (config.w_rule * 0.6 + config.w_semantic * 0.8) / (
            config.w_rule + config.w_semantic
        )
        assert abs(result.final_confidence - expected) < 0.01

    def test_fuse_no_match(self, config):
        """Test fusion with no match."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[],
            top_score=0.0,
            duration_ms=50,
            skipped=True,
            skip_reason="no_semantic_config",
        )
        result = policy.fuse(rule_result, semantic_result, None)
        assert result.decision_reason == "no_match"
        assert result.final_intent is None
        assert result.final_confidence == 0.0

    def test_fuse_need_clarify(self, config, mock_rule):
        """Test fusion with clarify needed."""
        policy = FusionPolicy(config)
        other_rule = MagicMock()
        other_rule.id = uuid.uuid4()
        other_rule.name = "Other Intent"
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[
                SemanticCandidate(rule=mock_rule, score=0.35),
                SemanticCandidate(rule=other_rule, score=0.30),
            ],
            top_score=0.35,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        result = policy.fuse(rule_result, semantic_result, None)
        assert result.need_clarify is True
        assert result.clarify_candidates is not None
        assert len(result.clarify_candidates) == 2

    def test_calculate_weighted_confidence(self, config, mock_rule):
        """Test weighted confidence calculation."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        confidence = policy._calculate_weighted_confidence(rule_result, semantic_result, None)
        expected = (0.5 * 1.0 + 0.3 * 0.8) / (0.5 + 0.3)
        assert abs(confidence - expected) < 0.01

    def test_trace_generation(self, config, mock_rule):
        """Test that trace is properly generated."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        result = policy.fuse(rule_result, semantic_result, None)
        assert result.trace is not None
        assert result.trace.rule_match["rule_id"] == str(mock_rule.id)
        assert result.trace.semantic_match["top_score"] == 0.8
        assert result.trace.fusion["decision_reason"] == "rule_high_confidence"