# ai-robot-core/ai-service/tests/test_intent_router_hybrid.py
#
# 469 lines
# 16 KiB
# Python

"""
Integration tests for IntentRouter.match_hybrid().
[AC-AISVC-111] Tests for hybrid routing integration.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import uuid
import asyncio
from app.services.intent.models import (
FusionConfig,
FusionResult,
LlmJudgeInput,
LlmJudgeResult,
RuleMatchResult,
SemanticCandidate,
SemanticMatchResult,
RouteTrace,
)
@pytest.fixture
def mock_embedding_provider():
    """Provide an ``AsyncMock`` standing in for the embedding provider.

    Awaiting ``embed`` yields one 768-element vector filled with ``0.1``;
    awaiting ``embed_batch`` yields a one-item batch holding such a vector.
    """
    embedding_dim = 768
    return AsyncMock(
        embed=AsyncMock(return_value=[0.1] * embedding_dim),
        embed_batch=AsyncMock(return_value=[[0.1] * embedding_dim]),
    )
@pytest.fixture
def mock_llm_client():
    """Provide an unconfigured ``AsyncMock`` standing in for the LLM client."""
    return AsyncMock()
@pytest.fixture
def config():
    """Provide a ``FusionConfig`` constructed without arguments."""
    default_config = FusionConfig()
    return default_config
@pytest.fixture
def mock_rule():
    """Provide one enabled mock intent rule named "Return Intent".

    The rule carries two keywords, no patterns, a 768-element intent vector
    filled with ``0.1``, no semantic examples, and priority 10.
    """
    intent = MagicMock()
    # configure_mock assigns each entry with setattr, so ``name`` becomes an
    # ordinary attribute instead of the mock's own display name.
    intent.configure_mock(
        id=uuid.uuid4(),
        name="Return Intent",
        response_type="rag",
        keywords=["退货", "退款"],
        patterns=[],
        intent_vector=[0.1] * 768,
        semantic_examples=None,
        is_enabled=True,
        priority=10,
    )
    return intent
@pytest.fixture
def mock_rules(mock_rule):
    """Provide two mock intent rules: ``mock_rule`` followed by "Order Query".

    The second rule is enabled, has its own keywords, a 768-element intent
    vector filled with ``0.5``, and priority 5.
    """
    order_query = MagicMock()
    # configure_mock assigns each entry with setattr, so ``name`` becomes an
    # ordinary attribute instead of the mock's own display name.
    order_query.configure_mock(
        id=uuid.uuid4(),
        name="Order Query",
        response_type="rag",
        keywords=["订单", "查询"],
        patterns=[],
        intent_vector=[0.5] * 768,
        semantic_examples=None,
        is_enabled=True,
        priority=5,
    )
    return [mock_rule, order_query]
class MockRuleMatcher:
    """Keyword-only test double for the rule matcher.

    Only ``rule.keywords`` is consulted; ``rule.patterns`` is ignored.
    """

    def match(self, message: str, rules: list) -> RuleMatchResult:
        """Return the first enabled rule that has a keyword inside *message*.

        Rules are visited in the order given and disabled rules are skipped.
        Comparison is done on lower-cased text, and a rule whose ``keywords``
        is ``None`` or empty never matches.

        Args:
            message: User utterance to inspect.
            rules: Rule objects exposing ``id``, ``is_enabled`` and ``keywords``.

        Returns:
            A ``RuleMatchResult`` with ``score=1.0`` and ``match_type="keyword"``
            on a hit; otherwise a result with ``score=0.0`` and every
            rule-related field set to ``None``.
        """
        import time

        # perf_counter is monotonic, so the reported duration cannot go
        # negative if the system clock is adjusted while matching.
        started = time.perf_counter()

        def elapsed_ms() -> int:
            return int((time.perf_counter() - started) * 1000)

        lowered_message = message.lower()
        for rule in rules:
            if not rule.is_enabled:
                continue
            for keyword in (rule.keywords or []):
                if keyword.lower() in lowered_message:
                    return RuleMatchResult(
                        rule_id=rule.id,
                        rule=rule,
                        match_type="keyword",
                        matched_text=keyword,
                        score=1.0,
                        duration_ms=elapsed_ms(),
                    )

        return RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=elapsed_ms(),
        )
class MockSemanticMatcher:
    """Test double for the semantic matcher that never computes embeddings."""

    def __init__(self, config):
        """Keep *config*; only ``semantic_matcher_enabled`` is read from it."""
        self._config = config

    async def match(self, message: str, rules: list, tenant_id: str, top_k: int = 3) -> SemanticMatchResult:
        """Return a canned semantic match for the first rule with a vector.

        Args:
            message: User utterance (not inspected by this mock).
            rules: Rule objects exposing ``intent_vector``.
            tenant_id: Tenant identifier (not inspected by this mock).
            top_k: Maximum number of candidates to return.

        Returns:
            A skipped result with reason ``"disabled"`` when the config turns
            semantic matching off. Otherwise a result holding at most one
            candidate (the first rule with a truthy ``intent_vector``, scored
            0.85); ``top_score`` is 0.0 when no candidate is returned.
        """
        import time

        # perf_counter is monotonic, so the reported duration cannot go
        # negative if the system clock is adjusted while matching.
        started = time.perf_counter()

        if not self._config.semantic_matcher_enabled:
            return SemanticMatchResult(
                candidates=[],
                top_score=0.0,
                duration_ms=0,
                skipped=True,
                skip_reason="disabled",
            )

        first_with_vector = next((rule for rule in rules if rule.intent_vector), None)
        found = []
        if first_with_vector is not None:
            found.append(SemanticCandidate(rule=first_with_vector, score=0.85))

        # Derive top_score from the list that is actually returned so that a
        # non-positive top_k cannot yield "no candidates" with a non-zero score.
        returned = found[:top_k]
        return SemanticMatchResult(
            candidates=returned,
            top_score=returned[0].score if returned else 0.0,
            duration_ms=int((time.perf_counter() - started) * 1000),
            skipped=False,
            skip_reason=None,
        )
class MockLlmJudge:
    """Test double for the LLM arbitration step.

    ``should_trigger`` decides whether arbitration is needed and ``judge``
    returns a canned verdict without calling any model.
    """

    def __init__(self, config):
        """Keep *config* as the fallback for ``should_trigger``."""
        self._config = config

    def should_trigger(self, rule_result, semantic_result, config=None) -> tuple:
        """Decide whether the judge should run for the given match results.

        Args:
            rule_result: Object exposing ``score`` and ``rule_id``.
            semantic_result: Object exposing ``top_score`` and ``candidates``.
            config: Optional override; the instance config is used when falsy.

        Returns:
            A ``(triggered, reason)`` pair. ``reason`` is ``"disabled"`` when
            the judge is turned off, ``"rule_semantic_conflict"`` when both
            matchers scored, point at different rules, and their scores differ
            by less than ``conflict_threshold``, ``"gray_zone"`` when the best
            score lies strictly between ``min_trigger_threshold`` and
            ``gray_zone_threshold``, and ``""`` when nothing triggers.
        """
        cfg = config if config else self._config
        if not cfg.llm_judge_enabled:
            return False, "disabled"

        both_scored = rule_result.score > 0 and semantic_result.top_score > 0
        if both_scored and semantic_result.candidates:
            best_semantic = semantic_result.candidates[0]
            if (
                rule_result.rule_id != best_semantic.rule.id
                and abs(rule_result.score - semantic_result.top_score) < cfg.conflict_threshold
            ):
                return True, "rule_semantic_conflict"

        strongest = max(rule_result.score, semantic_result.top_score)
        in_gray_zone = cfg.min_trigger_threshold < strongest < cfg.gray_zone_threshold
        return (True, "gray_zone") if in_gray_zone else (False, "")

    async def judge(self, input_data: LlmJudgeInput, tenant_id: str) -> LlmJudgeResult:
        """Return a fixed verdict that picks the first candidate, if any.

        Args:
            input_data: Judge input; only ``candidates`` is read.
            tenant_id: Tenant identifier (not inspected by this mock).

        Returns:
            A triggered ``LlmJudgeResult`` with score 0.9 whose intent id and
            name come from the first candidate, or are ``None`` when there are
            no candidates.
        """
        has_candidates = bool(input_data.candidates)
        first = input_data.candidates[0] if has_candidates else None
        return LlmJudgeResult(
            intent_id=first["id"] if has_candidates else None,
            intent_name=first["name"] if has_candidates else None,
            score=0.9,
            reasoning="Test arbitration",
            duration_ms=500,
            tokens_used=100,
            triggered=True,
        )
class MockFusionPolicy:
    """Test double for the fusion policy that merges the three match results."""

    # Ordered (reason, predicate) pairs; the first predicate that holds wins.
    # Predicates receive (rule_result, semantic_result, llm_result).
    DECISION_PRIORITY = [
        ("rule_high_confidence", lambda rule, sem, llm: rule.score == 1.0 and rule.rule is not None),
        ("llm_judge", lambda rule, sem, llm: llm.triggered and llm.intent_id is not None),
        ("semantic_override", lambda rule, sem, llm: rule.score == 0 and sem.top_score > 0.7),
        ("no_match", lambda rule, sem, llm: True),
    ]

    def __init__(self, config):
        """Keep *config*; this mock does not read it when fusing."""
        self._config = config

    def fuse(self, rule_result, semantic_result, llm_result) -> FusionResult:
        """Pick the final intent according to ``DECISION_PRIORITY``.

        Args:
            rule_result: Rule matcher output.
            semantic_result: Semantic matcher output.
            llm_result: Judge output, or ``None`` when the judge did not run;
                a falsy value is replaced by ``LlmJudgeResult.empty()`` while
                evaluating the predicates.

        Returns:
            A ``FusionResult`` whose ``need_clarify`` is set when the final
            confidence is below 0.4 and whose trace is a fresh ``RouteTrace``.
        """
        reason = next(
            (
                label
                for label, predicate in self.DECISION_PRIORITY
                if predicate(rule_result, semantic_result, llm_result or LlmJudgeResult.empty())
            ),
            "no_match",
        )

        chosen_rule, confidence = None, 0.0
        if reason == "semantic_override":
            chosen_rule = semantic_result.candidates[0].rule
            confidence = semantic_result.top_score
        elif reason == "llm_judge":
            if llm_result:
                chosen_rule = self._find_rule_by_id(
                    llm_result.intent_id, rule_result, semantic_result
                )
                confidence = llm_result.score
        elif reason == "rule_high_confidence":
            chosen_rule = rule_result.rule
            confidence = 1.0

        return FusionResult(
            final_intent=chosen_rule,
            final_confidence=confidence,
            decision_reason=reason,
            need_clarify=confidence < 0.4,
            clarify_candidates=None,
            trace=RouteTrace(),
        )

    def _find_rule_by_id(self, intent_id, rule_result, semantic_result):
        """Resolve *intent_id* (a string) to a rule from the match results.

        The rule matcher's hit is checked first, then the semantic candidates
        in order. Returns ``None`` when *intent_id* is falsy or unknown.
        """
        if intent_id:
            if rule_result.rule_id and str(rule_result.rule_id) == intent_id:
                return rule_result.rule
            return next(
                (
                    candidate.rule
                    for candidate in semantic_result.candidates
                    if str(candidate.rule.id) == intent_id
                ),
                None,
            )
        return None
class MockIntentRouter:
    """Test double for the intent router that exposes ``match_hybrid``."""

    def __init__(self, rule_matcher, semantic_matcher, llm_judge, fusion_policy, config=None):
        """Store the collaborators; a default ``FusionConfig`` is used if *config* is falsy."""
        self._rule_matcher = rule_matcher
        self._semantic_matcher = semantic_matcher
        self._llm_judge = llm_judge
        self._fusion_policy = fusion_policy
        self._config = config if config else FusionConfig()

    async def match_hybrid(
        self,
        message: str,
        rules: list,
        tenant_id: str,
        config: FusionConfig | None = None,
    ) -> FusionResult:
        """Run rule and semantic matching together, arbitrate if needed, then fuse.

        Args:
            message: User utterance to route.
            rules: Candidate intent rules.
            tenant_id: Tenant identifier forwarded to the semantic matcher and judge.
            config: Optional per-call override of the router's config.

        Returns:
            The ``FusionResult`` produced by the fusion policy.
        """
        active_config = config if config else self._config

        # The rule matcher is synchronous, so it is pushed to a worker thread
        # and awaited together with the semantic matcher.
        rule_call = asyncio.to_thread(self._rule_matcher.match, message, rules)
        semantic_call = self._semantic_matcher.match(message, rules, tenant_id)
        rule_result, semantic_result = await asyncio.gather(rule_call, semantic_call)

        triggered, reason = self._llm_judge.should_trigger(
            rule_result, semantic_result, active_config
        )

        verdict = None
        if triggered:
            judge_input = LlmJudgeInput(
                message=message,
                candidates=self._build_llm_candidates(rule_result, semantic_result),
                conflict_type=reason,
            )
            verdict = await self._llm_judge.judge(judge_input, tenant_id)

        return self._fusion_policy.fuse(rule_result, semantic_result, verdict)

    def _build_llm_candidates(self, rule_result, semantic_result) -> list:
        """Build the de-duplicated candidate list handed to the judge.

        The rule hit (if any) comes first, followed by up to three semantic
        candidates whose ids have not been listed yet. Each entry is a dict
        with ``id``, ``name`` and ``description`` keys.
        """
        shortlist = []
        seen_ids = set()

        if rule_result.rule:
            rule_key = str(rule_result.rule_id)
            seen_ids.add(rule_key)
            shortlist.append({
                "id": rule_key,
                "name": rule_result.rule.name,
                "description": f"匹配方式: {rule_result.match_type}",
            })

        for candidate in semantic_result.candidates[:3]:
            candidate_key = str(candidate.rule.id)
            if candidate_key in seen_ids:
                continue
            seen_ids.add(candidate_key)
            shortlist.append({
                "id": candidate_key,
                "name": candidate.rule.name,
                "description": f"语义相似度: {candidate.score:.2f}",
            })

        return shortlist
class TestIntentRouterHybrid:
    """Tests for IntentRouter.match_hybrid() integration.

    Every test wires a ``MockIntentRouter`` from the mock collaborators defined
    in this module and checks the resulting ``FusionResult``.
    """

    @staticmethod
    def _build_router(config, rule_matcher=None, semantic_matcher=None):
        """Wire a ``MockIntentRouter`` whose collaborators all share *config*.

        Args:
            config: ``FusionConfig`` handed to every collaborator and the router.
            rule_matcher: Optional replacement for the default ``MockRuleMatcher``.
            semantic_matcher: Optional replacement for the default
                ``MockSemanticMatcher``.
        """
        if rule_matcher is None:
            rule_matcher = MockRuleMatcher()
        if semantic_matcher is None:
            semantic_matcher = MockSemanticMatcher(config)
        return MockIntentRouter(
            rule_matcher,
            semantic_matcher,
            MockLlmJudge(config),
            MockFusionPolicy(config),
            config,
        )

    @pytest.mark.asyncio
    async def test_match_hybrid_rule_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with rule match."""
        router = self._build_router(config)

        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rules[0]
        assert result.final_confidence == 1.0

    @pytest.mark.asyncio
    async def test_match_hybrid_semantic_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with semantic match only."""
        router = self._build_router(config)

        # The message contains none of the rule keywords, so only the
        # semantic side can produce a hit.
        result = await router.match_hybrid("商品有问题", mock_rules, "tenant-1")

        assert result.decision_reason == "semantic_override"
        assert result.final_intent is not None
        assert result.final_confidence > 0.7

    @pytest.mark.asyncio
    async def test_match_hybrid_parallel_execution(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test that rule and semantic matching run in parallel."""
        import time

        delay = 0.1

        # Both matchers are slowed by the same amount: run concurrently the
        # call takes about one delay, run back to back it takes about two.
        # Slowing only one side would pass even if execution were sequential.
        class SlowRuleMatcher(MockRuleMatcher):
            def match(self, message, rules):
                time.sleep(delay)
                return super().match(message, rules)

        class SlowSemanticMatcher(MockSemanticMatcher):
            async def match(self, message, rules, tenant_id, top_k=3):
                await asyncio.sleep(delay)
                return await super().match(message, rules, tenant_id, top_k)

        router = self._build_router(
            config,
            rule_matcher=SlowRuleMatcher(),
            semantic_matcher=SlowSemanticMatcher(config),
        )

        # perf_counter is monotonic, so the measurement is unaffected by
        # system clock adjustments.
        started = time.perf_counter()
        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")
        elapsed = time.perf_counter() - started

        assert elapsed < 2 * delay
        assert result is not None

    @pytest.mark.asyncio
    async def test_match_hybrid_llm_judge_triggered(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with LLM judge triggered."""
        # A dedicated config is used here; it gets its own name so the
        # ``config`` fixture argument is not silently shadowed.
        conflict_config = FusionConfig(conflict_threshold=0.3)

        class ConflictSemanticMatcher(MockSemanticMatcher):
            async def match(self, message, rules, tenant_id, top_k=3):
                result = await super().match(message, rules, tenant_id, top_k)
                if result.candidates:
                    # Point the semantic side at the other rule so it
                    # disagrees with the keyword hit.
                    result.candidates[0] = SemanticCandidate(rule=rules[1], score=0.9)
                    result.top_score = 0.9
                return result

        router = self._build_router(
            conflict_config,
            semantic_matcher=ConflictSemanticMatcher(conflict_config),
        )

        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        assert result.decision_reason in ["rule_high_confidence", "llm_judge"]

    @pytest.mark.asyncio
    async def test_match_hybrid_no_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with no match."""

        class NoMatchSemanticMatcher(MockSemanticMatcher):
            async def match(self, message, rules, tenant_id, top_k=3):
                return SemanticMatchResult(
                    candidates=[],
                    top_score=0.0,
                    duration_ms=10,
                    skipped=True,
                    skip_reason="no_semantic_config",
                )

        router = self._build_router(
            config,
            semantic_matcher=NoMatchSemanticMatcher(config),
        )

        result = await router.match_hybrid("随便说说", mock_rules, "tenant-1")

        assert result.decision_reason == "no_match"
        assert result.final_intent is None
        assert result.final_confidence == 0.0

    @pytest.mark.asyncio
    async def test_match_hybrid_semantic_disabled(self, mock_embedding_provider, mock_llm_client, mock_rules):
        """Test hybrid routing with semantic matcher disabled."""
        router = self._build_router(FusionConfig(semantic_matcher_enabled=False))

        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rules[0]

    @pytest.mark.asyncio
    async def test_match_hybrid_llm_disabled(self, mock_embedding_provider, mock_llm_client, mock_rules):
        """Test hybrid routing with LLM judge disabled."""
        router = self._build_router(FusionConfig(llm_judge_enabled=False))

        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        assert result.decision_reason == "rule_high_confidence"

    @pytest.mark.asyncio
    async def test_match_hybrid_trace_generated(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test that route trace is generated."""
        router = self._build_router(config)

        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        assert result.trace is not None