"""
|
|
Unit tests for FusionPolicy.
|
|
[AC-AISVC-115~AC-AISVC-117] Tests for fusion decision policy.
|
|
"""
|
|
|
|
import uuid
from unittest.mock import MagicMock

import pytest

from app.services.intent.models import (
    FusionConfig,
    FusionResult,
    LlmJudgeResult,
    RouteTrace,
    RuleMatchResult,
    SemanticCandidate,
    SemanticMatchResult,
)


class FusionPolicy:
    """[AC-AISVC-115] Fusion decision policy.

    Combines the outputs of the rule matcher, the semantic matcher, and the
    optional LLM judge into a single routing decision, recording a full
    RouteTrace of every stage along the way.
    """

    # Ordered (reason, predicate) pairs evaluated against
    # (rule_result, semantic_result, llm_result); the first truthy predicate
    # fixes the decision reason.  "no_match" always holds, so the scan can
    # never fall through.  Candidate-list emptiness is checked explicitly
    # here so downstream code may safely index candidates[0].
    DECISION_PRIORITY = [
        # An exact rule hit is trusted unconditionally.
        ("rule_high_confidence", lambda r, s, l: r.score == 1.0 and r.rule is not None),
        # The LLM judge ran and committed to an intent.
        ("llm_judge", lambda r, s, l: l.triggered and l.intent_id is not None),
        # No rule evidence at all, but semantics are very confident.
        (
            "semantic_override",
            lambda r, s, l: bool(s.candidates) and r.score == 0 and s.top_score > 0.7,
        ),
        # Rule and top semantic candidate point at the same rule.
        # (Explicit conjunction — the original relied on ternary/`and`
        # precedence, which was correct but easy to misread.)
        (
            "rule_semantic_agree",
            lambda r, s, l: bool(s.candidates)
            and r.score > 0
            and s.top_score > 0.5
            and r.rule_id == s.candidates[0].rule.id,
        ),
        # Semantics alone are good enough.
        ("semantic_fallback", lambda r, s, l: bool(s.candidates) and s.top_score > 0.5),
        # A weak rule match beats nothing.
        ("rule_fallback", lambda r, s, l: r.score > 0),
        # Guaranteed terminal fallback.
        ("no_match", lambda r, s, l: True),
    ]

    def __init__(self, config: FusionConfig):
        # Holds the blend weights (w_rule / w_semantic / w_llm) and the
        # clarify_threshold used by fuse().
        self._config = config

    def fuse(
        self,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        llm_result: LlmJudgeResult | None,
    ) -> FusionResult:
        """Fuse the three matching stages into one routing decision.

        Args:
            rule_result: Outcome of the deterministic rule matcher.
            semantic_result: Outcome of the semantic (vector) matcher.
            llm_result: Outcome of the LLM judge, or None if it did not run.

        Returns:
            FusionResult carrying the chosen intent (or None), a confidence
            in [0, 1], the decision reason, a clarify flag with optional
            candidate list, and the full RouteTrace.
        """
        trace = self._build_trace(rule_result, semantic_result, llm_result)

        # Walk the priority list; the first predicate that holds wins.
        effective_llm = llm_result or LlmJudgeResult.empty()
        decision_reason = "no_match"
        for reason, condition in self.DECISION_PRIORITY:
            if condition(rule_result, semantic_result, effective_llm):
                decision_reason = reason
                break

        final_intent, final_confidence = self._resolve(
            decision_reason, rule_result, semantic_result, llm_result
        )

        # Low confidence triggers clarification when there is more than one
        # semantic candidate to disambiguate between (top 3 are offered).
        need_clarify = final_confidence < self._config.clarify_threshold
        clarify_candidates = None
        if need_clarify and len(semantic_result.candidates) > 1:
            clarify_candidates = [c.rule for c in semantic_result.candidates[:3]]

        trace.fusion = {
            "weights": {
                "w_rule": self._config.w_rule,
                "w_semantic": self._config.w_semantic,
                "w_llm": self._config.w_llm,
            },
            "final_confidence": final_confidence,
            "decision_reason": decision_reason,
        }

        return FusionResult(
            final_intent=final_intent,
            final_confidence=final_confidence,
            decision_reason=decision_reason,
            need_clarify=need_clarify,
            clarify_candidates=clarify_candidates,
            trace=trace,
        )

    def _build_trace(
        self,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        llm_result: LlmJudgeResult | None,
    ) -> RouteTrace:
        """Snapshot all three stage results into a RouteTrace.

        The fusion section is left empty here and filled in by fuse() once
        the decision has been made.
        """
        return RouteTrace(
            rule_match={
                "rule_id": str(rule_result.rule_id) if rule_result.rule_id else None,
                "match_type": rule_result.match_type,
                "matched_text": rule_result.matched_text,
                "score": rule_result.score,
                "duration_ms": rule_result.duration_ms,
            },
            semantic_match={
                "top_candidates": [
                    {"rule_id": str(c.rule.id), "name": c.rule.name, "score": c.score}
                    for c in semantic_result.candidates
                ],
                "top_score": semantic_result.top_score,
                "duration_ms": semantic_result.duration_ms,
                "skipped": semantic_result.skipped,
                "skip_reason": semantic_result.skip_reason,
            },
            llm_judge={
                "triggered": llm_result.triggered if llm_result else False,
                "intent_id": llm_result.intent_id if llm_result else None,
                "score": llm_result.score if llm_result else 0.0,
                "duration_ms": llm_result.duration_ms if llm_result else 0,
                "tokens_used": llm_result.tokens_used if llm_result else 0,
            },
            fusion={},
        )

    def _resolve(
        self,
        decision_reason: str,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        llm_result: LlmJudgeResult | None,
    ):
        """Map a decision reason to (final_intent, final_confidence)."""
        if decision_reason == "rule_high_confidence":
            return rule_result.rule, 1.0
        if decision_reason == "llm_judge" and llm_result:
            # The LLM only names an intent id; look the rule object up among
            # the results we already have (may legitimately return None).
            intent = self._find_rule_by_id(
                llm_result.intent_id, rule_result, semantic_result
            )
            return intent, llm_result.score
        if decision_reason in ("semantic_override", "semantic_fallback"):
            # DECISION_PRIORITY guarantees candidates is non-empty here.
            return semantic_result.candidates[0].rule, semantic_result.top_score
        if decision_reason == "rule_semantic_agree":
            confidence = self._calculate_weighted_confidence(
                rule_result, semantic_result, llm_result
            )
            return rule_result.rule, confidence
        if decision_reason == "rule_fallback":
            return rule_result.rule, rule_result.score
        return None, 0.0  # no_match

    def _calculate_weighted_confidence(
        self,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        llm_result: LlmJudgeResult | None,
    ) -> float:
        """Weighted average of stage scores, normalized by the active weights.

        The LLM weight participates only when the judge actually triggered; a
        skipped semantic stage contributes a score of 0 but its weight is
        still counted, matching the configured blend.
        """
        rule_score = rule_result.score
        semantic_score = semantic_result.top_score if not semantic_result.skipped else 0.0
        llm_score = llm_result.score if llm_result and llm_result.triggered else 0.0

        total_weight = self._config.w_rule + self._config.w_semantic
        if llm_result and llm_result.triggered:
            total_weight += self._config.w_llm

        # NOTE(review): assumes w_rule + w_semantic > 0, otherwise this
        # divides by zero — confirm FusionConfig validates its weights.
        confidence = (
            self._config.w_rule * rule_score
            + self._config.w_semantic * semantic_score
            + self._config.w_llm * llm_score
        ) / total_weight

        # Clamp into [0, 1] to absorb rounding or misconfigured weights.
        return min(1.0, max(0.0, confidence))

    def _find_rule_by_id(
        self,
        intent_id: str | None,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
    ):
        """Locate the rule object whose id matches *intent_id*.

        Checks the rule-match result first, then the semantic candidates;
        returns None when the id is missing or unknown.
        """
        if not intent_id:
            return None

        if rule_result.rule_id and str(rule_result.rule_id) == intent_id:
            return rule_result.rule

        for candidate in semantic_result.candidates:
            if str(candidate.rule.id) == intent_id:
                return candidate.rule

        return None


@pytest.fixture
def config():
    """Provide a FusionConfig with its default weights and thresholds."""
    return FusionConfig()


@pytest.fixture
def mock_rule():
    """Provide a MagicMock standing in for an intent rule.

    Carries a stable random UUID id, a human-readable name, and a
    RAG-style response type.
    """
    fake = MagicMock()
    fake.id = uuid.uuid4()
    fake.name = "Test Intent"
    fake.response_type = "rag"
    return fake


class TestFusionPolicy:
    """Tests for FusionPolicy class."""

    def test_init(self, config):
        """The policy stores the configuration it was constructed with."""
        policy = FusionPolicy(config)
        assert policy._config == config

    def test_fuse_rule_high_confidence(self, config, mock_rule):
        """An exact rule hit (score == 1.0) wins outright with confidence 1.0."""
        policy = FusionPolicy(config)

        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[],
            top_score=0.0,
            duration_ms=50,
            skipped=True,
            skip_reason="no_semantic_config",
        )

        result = policy.fuse(rule_result, semantic_result, None)

        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rule
        assert result.final_confidence == 1.0
        assert result.need_clarify is False

    def test_fuse_llm_judge(self, config, mock_rule):
        """A triggered LLM judge takes priority over weaker rule/semantic hits."""
        policy = FusionPolicy(config)

        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.5)],
            top_score=0.5,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        llm_result = LlmJudgeResult(
            intent_id=str(mock_rule.id),
            intent_name="Test Intent",
            score=0.85,
            reasoning="Test reasoning",
            duration_ms=500,
            tokens_used=100,
            triggered=True,
        )

        result = policy.fuse(rule_result, semantic_result, llm_result)

        # The intent id named by the LLM resolves to the semantic candidate.
        assert result.decision_reason == "llm_judge"
        assert result.final_intent == mock_rule
        assert result.final_confidence == 0.85

    def test_fuse_semantic_override(self, config, mock_rule):
        """With no rule match, a strong semantic score (> 0.7) wins."""
        policy = FusionPolicy(config)

        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.85)],
            top_score=0.85,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        result = policy.fuse(rule_result, semantic_result, None)

        assert result.decision_reason == "semantic_override"
        assert result.final_intent == mock_rule
        assert result.final_confidence == 0.85

    def test_fuse_rule_semantic_agree(self, config, mock_rule):
        """Rule and top semantic candidate agreeing on one rule blend scores.

        The rule score must be below 1.0 here: with score=1.0 the
        higher-priority "rule_high_confidence" branch fires and this path
        would never be exercised (the original test had that flaw).
        """
        policy = FusionPolicy(config)

        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=0.8,  # partial match: below the 1.0 high-confidence cutoff
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        result = policy.fuse(rule_result, semantic_result, None)

        assert result.decision_reason == "rule_semantic_agree"
        assert result.final_intent == mock_rule
        # Both inputs are 0.8, so any weighted average of them is 0.8.
        assert abs(result.final_confidence - 0.8) < 1e-9

    def test_fuse_no_match(self, config):
        """No rule, no semantic candidates, no LLM -> no_match with 0.0."""
        policy = FusionPolicy(config)

        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[],
            top_score=0.0,
            duration_ms=50,
            skipped=True,
            skip_reason="no_semantic_config",
        )

        result = policy.fuse(rule_result, semantic_result, None)

        assert result.decision_reason == "no_match"
        assert result.final_intent is None
        assert result.final_confidence == 0.0

    def test_fuse_need_clarify(self, config, mock_rule):
        """Low confidence plus multiple semantic candidates asks for clarify."""
        policy = FusionPolicy(config)

        other_rule = MagicMock()
        other_rule.id = uuid.uuid4()
        other_rule.name = "Other Intent"

        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        # Both candidate scores sit below every fusion threshold, so the
        # decision degrades to a low-confidence outcome.
        semantic_result = SemanticMatchResult(
            candidates=[
                SemanticCandidate(rule=mock_rule, score=0.35),
                SemanticCandidate(rule=other_rule, score=0.30),
            ],
            top_score=0.35,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        result = policy.fuse(rule_result, semantic_result, None)

        assert result.need_clarify is True
        assert result.clarify_candidates is not None
        assert len(result.clarify_candidates) == 2

    def test_calculate_weighted_confidence(self, config, mock_rule):
        """Weighted blend uses only rule+semantic weights when no LLM ran."""
        policy = FusionPolicy(config)

        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        confidence = policy._calculate_weighted_confidence(rule_result, semantic_result, None)

        # Expected value assumes FusionConfig defaults w_rule=0.5, w_semantic=0.3.
        expected = (0.5 * 1.0 + 0.3 * 0.8) / (0.5 + 0.3)
        assert abs(confidence - expected) < 0.01

    def test_trace_generation(self, config, mock_rule):
        """fuse() records rule, semantic, and fusion sections in the trace."""
        policy = FusionPolicy(config)

        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        result = policy.fuse(rule_result, semantic_result, None)

        assert result.trace is not None
        assert result.trace.rule_match["rule_id"] == str(mock_rule.id)
        assert result.trace.semantic_match["top_score"] == 0.8
        assert result.trace.fusion["decision_reason"] == "rule_high_confidence"