"""Integration tests for IntentRouter.match_hybrid().

[AC-AISVC-111] Tests for hybrid routing integration.
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
import uuid
|
|
import asyncio
|
|
|
|
from app.services.intent.models import (
|
|
FusionConfig,
|
|
FusionResult,
|
|
LlmJudgeInput,
|
|
LlmJudgeResult,
|
|
RuleMatchResult,
|
|
SemanticCandidate,
|
|
SemanticMatchResult,
|
|
RouteTrace,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def mock_embedding_provider():
    """Async embedding provider double that always yields a constant 768-dim vector."""
    vector = [0.1] * 768
    provider = AsyncMock()
    provider.embed = AsyncMock(return_value=list(vector))
    provider.embed_batch = AsyncMock(return_value=[list(vector)])
    return provider
|
|
|
|
|
|
@pytest.fixture
def mock_llm_client():
    """Bare async LLM client double; tests only need it to exist."""
    return AsyncMock()
|
|
|
|
|
|
@pytest.fixture
def config():
    """Fusion configuration with all default thresholds and toggles."""
    return FusionConfig()
|
|
|
|
|
|
@pytest.fixture
def mock_rule():
    """Mock 'Return Intent' rule: two keywords, a constant intent vector, priority 10."""
    # Assign via setattr so the special MagicMock ``name`` constructor
    # argument is not triggered and ``rule.name`` is a plain attribute.
    attrs = {
        "id": uuid.uuid4(),
        "name": "Return Intent",
        "response_type": "rag",
        "keywords": ["退货", "退款"],
        "patterns": [],
        "intent_vector": [0.1] * 768,
        "semantic_examples": None,
        "is_enabled": True,
        "priority": 10,
    }
    rule = MagicMock()
    for attr, value in attrs.items():
        setattr(rule, attr, value)
    return rule
|
|
|
|
|
|
@pytest.fixture
def mock_rules(mock_rule):
    """Two-rule list: the return-intent fixture plus an 'Order Query' rule."""
    order_rule = MagicMock()
    # setattr keeps MagicMock's ``name`` ctor argument out of the picture.
    for attr, value in {
        "id": uuid.uuid4(),
        "name": "Order Query",
        "response_type": "rag",
        "keywords": ["订单", "查询"],
        "patterns": [],
        "intent_vector": [0.5] * 768,
        "semantic_examples": None,
        "is_enabled": True,
        "priority": 5,
    }.items():
        setattr(order_rule, attr, value)
    return [mock_rule, order_rule]
|
|
|
|
|
|
class MockRuleMatcher:
    """Keyword-only RuleMatcher stand-in for the hybrid-routing tests."""

    def match(self, message: str, rules: list) -> RuleMatchResult:
        """Return a score-1.0 hit on the first enabled rule whose keyword
        occurs (case-insensitively) in *message*, else an empty result."""
        import time

        started = time.time()
        haystack = message.lower()

        # First (rule, keyword) pair that matches, scanning rules in order.
        hit = next(
            (
                (rule, keyword)
                for rule in rules
                if rule.is_enabled
                for keyword in (rule.keywords or [])
                if keyword.lower() in haystack
            ),
            None,
        )

        elapsed_ms = int((time.time() - started) * 1000)
        if hit is None:
            return RuleMatchResult(
                rule_id=None,
                rule=None,
                match_type=None,
                matched_text=None,
                score=0.0,
                duration_ms=elapsed_ms,
            )

        rule, keyword = hit
        return RuleMatchResult(
            rule_id=rule.id,
            rule=rule,
            match_type="keyword",
            matched_text=keyword,
            score=1.0,
            duration_ms=elapsed_ms,
        )
|
|
|
|
|
|
class MockSemanticMatcher:
    """SemanticMatcher stand-in: scores the first rule with an intent vector at 0.85."""

    def __init__(self, config):
        self._config = config

    async def match(self, message: str, rules: list, tenant_id: str, top_k: int = 3) -> SemanticMatchResult:
        """Return at most one fixed-score candidate, or a skipped result
        when semantic matching is disabled in the config."""
        import time

        started = time.time()

        if not self._config.semantic_matcher_enabled:
            return SemanticMatchResult(
                candidates=[],
                top_score=0.0,
                duration_ms=0,
                skipped=True,
                skip_reason="disabled",
            )

        # Only the first rule carrying a (truthy) intent vector is considered.
        vectored = next((rule for rule in rules if rule.intent_vector), None)
        candidates = [] if vectored is None else [SemanticCandidate(rule=vectored, score=0.85)]

        return SemanticMatchResult(
            candidates=candidates[:top_k],
            top_score=candidates[0].score if candidates else 0.0,
            duration_ms=int((time.time() - started) * 1000),
            skipped=False,
            skip_reason=None,
        )
|
|
|
|
|
|
class MockLlmJudge:
    """LlmJudge stand-in: deterministic trigger logic plus a canned verdict."""

    def __init__(self, config):
        self._config = config

    def should_trigger(self, rule_result, semantic_result, config=None) -> tuple:
        """Return (should_trigger, reason); reason is '' when not triggered."""
        cfg = config or self._config
        if not cfg.llm_judge_enabled:
            return False, "disabled"

        # Conflict: both matchers fired, on different rules, with close scores.
        if (
            rule_result.score > 0
            and semantic_result.top_score > 0
            and semantic_result.candidates
            and rule_result.rule_id != semantic_result.candidates[0].rule.id
            and abs(rule_result.score - semantic_result.top_score) < cfg.conflict_threshold
        ):
            return True, "rule_semantic_conflict"

        # Gray zone: best score is neither confidently low nor high.
        best = max(rule_result.score, semantic_result.top_score)
        if cfg.min_trigger_threshold < best < cfg.gray_zone_threshold:
            return True, "gray_zone"

        return False, ""

    async def judge(self, input_data: LlmJudgeInput, tenant_id: str) -> LlmJudgeResult:
        """Return a fixed arbitration picking the first candidate (if any)."""
        top = input_data.candidates[0] if input_data.candidates else None
        return LlmJudgeResult(
            intent_id=top["id"] if top else None,
            intent_name=top["name"] if top else None,
            score=0.9,
            reasoning="Test arbitration",
            duration_ms=500,
            tokens_used=100,
            triggered=True,
        )
|
|
|
|
|
|
class MockFusionPolicy:
    """FusionPolicy stand-in: the first applicable decision rule wins."""

    # Ordered (reason, predicate) pairs; predicates see (rule, semantic, llm)
    # results and the last entry is a catch-all.
    DECISION_PRIORITY = [
        ("rule_high_confidence", lambda r, s, l: r.score == 1.0 and r.rule is not None),
        ("llm_judge", lambda r, s, l: l.triggered and l.intent_id is not None),
        ("semantic_override", lambda r, s, l: r.score == 0 and s.top_score > 0.7),
        ("no_match", lambda r, s, l: True),
    ]

    def __init__(self, config):
        self._config = config

    def fuse(self, rule_result, semantic_result, llm_result) -> FusionResult:
        """Pick the highest-priority applicable decision and build a FusionResult."""
        # Predicates need a non-None LLM view even when the judge never ran.
        llm_view = llm_result or LlmJudgeResult.empty()
        reason = next(
            (
                name
                for name, applies in self.DECISION_PRIORITY
                if applies(rule_result, semantic_result, llm_view)
            ),
            "no_match",
        )

        intent, confidence = None, 0.0
        if reason == "rule_high_confidence":
            intent, confidence = rule_result.rule, 1.0
        elif reason == "llm_judge" and llm_result:
            intent = self._find_rule_by_id(llm_result.intent_id, rule_result, semantic_result)
            confidence = llm_result.score
        elif reason == "semantic_override":
            intent, confidence = semantic_result.candidates[0].rule, semantic_result.top_score

        return FusionResult(
            final_intent=intent,
            final_confidence=confidence,
            decision_reason=reason,
            need_clarify=confidence < 0.4,
            clarify_candidates=None,
            trace=RouteTrace(),
        )

    def _find_rule_by_id(self, intent_id, rule_result, semantic_result):
        """Resolve an intent id: rule hit first, then semantic candidates."""
        if not intent_id:
            return None
        if rule_result.rule_id and str(rule_result.rule_id) == intent_id:
            return rule_result.rule
        return next(
            (c.rule for c in semantic_result.candidates if str(c.rule.id) == intent_id),
            None,
        )
|
|
|
|
|
|
class MockIntentRouter:
    """IntentRouter stand-in wiring matcher, judge and fusion for match_hybrid tests."""

    def __init__(self, rule_matcher, semantic_matcher, llm_judge, fusion_policy, config=None):
        self._rule_matcher = rule_matcher
        self._semantic_matcher = semantic_matcher
        self._llm_judge = llm_judge
        self._fusion_policy = fusion_policy
        self._config = config or FusionConfig()

    async def match_hybrid(
        self,
        message: str,
        rules: list,
        tenant_id: str,
        config: FusionConfig | None = None,
    ) -> FusionResult:
        """Run rule and semantic matching concurrently, optionally arbitrate
        with the LLM judge, and fuse everything into one FusionResult."""
        effective = config or self._config

        # Rule matching is synchronous, so it runs in a worker thread while
        # the async semantic match proceeds on the event loop.
        rule_result, semantic_result = await asyncio.gather(
            asyncio.to_thread(self._rule_matcher.match, message, rules),
            self._semantic_matcher.match(message, rules, tenant_id),
        )

        triggered, trigger_reason = self._llm_judge.should_trigger(
            rule_result, semantic_result, effective
        )

        llm_result = None
        if triggered:
            llm_result = await self._llm_judge.judge(
                LlmJudgeInput(
                    message=message,
                    candidates=self._build_llm_candidates(rule_result, semantic_result),
                    conflict_type=trigger_reason,
                ),
                tenant_id,
            )

        return self._fusion_policy.fuse(rule_result, semantic_result, llm_result)

    def _build_llm_candidates(self, rule_result, semantic_result) -> list:
        """Collect deduplicated candidate dicts (rule hit + top-3 semantic)."""
        candidates = []

        if rule_result.rule:
            candidates.append({
                "id": str(rule_result.rule_id),
                "name": rule_result.rule.name,
                "description": f"匹配方式: {rule_result.match_type}",
            })

        for candidate in semantic_result.candidates[:3]:
            # Skip a semantic candidate already contributed by the rule hit.
            if any(c["id"] == str(candidate.rule.id) for c in candidates):
                continue
            candidates.append({
                "id": str(candidate.rule.id),
                "name": candidate.rule.name,
                "description": f"语义相似度: {candidate.score:.2f}",
            })

        return candidates
|
|
|
|
|
|
class TestIntentRouterHybrid:
    """Tests for IntentRouter.match_hybrid() integration."""

    @pytest.mark.asyncio
    async def test_match_hybrid_rule_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with rule match."""
        # "退货" is a keyword of mock_rules[0], so the rule matcher scores 1.0
        # and the fusion policy short-circuits on rule_high_confidence.
        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rules[0]
        assert result.final_confidence == 1.0

    @pytest.mark.asyncio
    async def test_match_hybrid_semantic_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with semantic match only."""
        # The message contains no rule keywords (rule score 0); the semantic
        # mock scores 0.85 > 0.7, so semantic_override should win.
        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        result = await router.match_hybrid("商品有问题", mock_rules, "tenant-1")

        assert result.decision_reason == "semantic_override"
        assert result.final_intent is not None
        assert result.final_confidence > 0.7

    @pytest.mark.asyncio
    async def test_match_hybrid_parallel_execution(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test that rule and semantic matching run in parallel."""
        import time

        # A 100 ms sleep in the semantic path; total elapsed must stay well
        # under the sum of both paths if they really run concurrently.
        class SlowSemanticMatcher(MockSemanticMatcher):
            async def match(self, message, rules, tenant_id, top_k=3):
                await asyncio.sleep(0.1)
                return await super().match(message, rules, tenant_id, top_k)

        rule_matcher = MockRuleMatcher()
        semantic_matcher = SlowSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        start_time = time.time()
        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")
        elapsed = time.time() - start_time

        assert elapsed < 0.2
        assert result is not None

    @pytest.mark.asyncio
    async def test_match_hybrid_llm_judge_triggered(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with LLM judge triggered."""
        # Shadows the fixture: widen the conflict threshold so that a
        # rule/semantic disagreement can trigger arbitration.
        config = FusionConfig(conflict_threshold=0.3)

        # Forces the top semantic candidate onto a different rule (rules[1])
        # with a score close to the rule hit, creating a conflict.
        class ConflictSemanticMatcher(MockSemanticMatcher):
            async def match(self, message, rules, tenant_id, top_k=3):
                result = await super().match(message, rules, tenant_id, top_k)
                if result.candidates:
                    result.candidates[0] = SemanticCandidate(rule=rules[1], score=0.9)
                    result.top_score = 0.9
                return result

        rule_matcher = MockRuleMatcher()
        semantic_matcher = ConflictSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        # Either the rule hit still dominates or the judge arbitrated; both
        # are acceptable outcomes for a 1.0-vs-0.9 conflict.
        assert result.decision_reason in ["rule_high_confidence", "llm_judge"]

    @pytest.mark.asyncio
    async def test_match_hybrid_no_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with no match."""
        # Semantic path yields nothing and the message hits no keyword, so
        # the fusion catch-all "no_match" should apply.
        class NoMatchSemanticMatcher(MockSemanticMatcher):
            async def match(self, message, rules, tenant_id, top_k=3):
                return SemanticMatchResult(
                    candidates=[],
                    top_score=0.0,
                    duration_ms=10,
                    skipped=True,
                    skip_reason="no_semantic_config",
                )

        rule_matcher = MockRuleMatcher()
        semantic_matcher = NoMatchSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        result = await router.match_hybrid("随便说说", mock_rules, "tenant-1")

        assert result.decision_reason == "no_match"
        assert result.final_intent is None
        assert result.final_confidence == 0.0

    @pytest.mark.asyncio
    async def test_match_hybrid_semantic_disabled(self, mock_embedding_provider, mock_llm_client, mock_rules):
        """Test hybrid routing with semantic matcher disabled."""
        # With the semantic path disabled, the keyword hit alone must still
        # produce a rule_high_confidence decision.
        config = FusionConfig(semantic_matcher_enabled=False)

        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rules[0]

    @pytest.mark.asyncio
    async def test_match_hybrid_llm_disabled(self, mock_embedding_provider, mock_llm_client, mock_rules):
        """Test hybrid routing with LLM judge disabled."""
        # should_trigger() returns (False, "disabled"), so fusion must decide
        # without arbitration.
        config = FusionConfig(llm_judge_enabled=False)

        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        assert result.decision_reason == "rule_high_confidence"

    @pytest.mark.asyncio
    async def test_match_hybrid_trace_generated(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test that route trace is generated."""
        # The mock fusion policy always attaches a RouteTrace() to the result.
        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        assert result.trace is not None
|