"""
Tests for Slot State Cache.

[AC-MRS-SLOT-CACHE-01] 多轮状态持久化测试 (multi-turn slot state persistence tests)
"""
|
|
|
|
import json
|
|
import time
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.services.cache.slot_state_cache import (
|
|
CachedSlotState,
|
|
CachedSlotValue,
|
|
SlotStateCache,
|
|
get_slot_state_cache,
|
|
)
|
|
|
|
|
|
class TestCachedSlotValue:
    """Tests for CachedSlotValue construction and dict (de)serialization."""

    def test_init(self):
        """A freshly built value keeps its fields and stamps a positive updated_at."""
        slot_value = CachedSlotValue(
            value="test_value",
            source="user_confirmed",
            confidence=0.9,
        )
        observed = (slot_value.value, slot_value.source, slot_value.confidence)
        assert observed == ("test_value", "user_confirmed", 0.9)
        assert slot_value.updated_at > 0

    def test_to_dict(self):
        """to_dict() exposes value/source/confidence and includes an updated_at key."""
        serialized = CachedSlotValue(
            value="test_value",
            source="rule_extracted",
            confidence=0.8,
        ).to_dict()
        expected = {"value": "test_value", "source": "rule_extracted", "confidence": 0.8}
        for key, wanted in expected.items():
            assert serialized[key] == wanted
        assert "updated_at" in serialized

    def test_from_dict(self):
        """from_dict() restores every field verbatim, including updated_at."""
        payload = {
            "value": "test_value",
            "source": "llm_inferred",
            "confidence": 0.7,
            "updated_at": 12345.0,
        }
        restored = CachedSlotValue.from_dict(payload)
        # Each payload key is also an attribute name on CachedSlotValue.
        for field_name, wanted in payload.items():
            assert getattr(restored, field_name) == wanted
|
|
|
|
|
|
class TestCachedSlotState:
    """Tests for the CachedSlotState container and its derived views."""

    @staticmethod
    def _region_and_product():
        """Build the region/product slot pair shared by several tests (not collected by pytest)."""
        return {
            "region": CachedSlotValue(value="北京", source="user_confirmed"),
            "product": CachedSlotValue(value="手机", source="rule_extracted"),
        }

    def test_init(self):
        """An empty state has no slots, no field mapping, and positive timestamps."""
        empty = CachedSlotState()
        assert (empty.filled_slots, empty.slot_to_field_map) == ({}, {})
        assert min(empty.created_at, empty.updated_at) > 0

    def test_with_slots(self):
        """Slots and the slot-to-field mapping passed to the constructor are retained."""
        populated = CachedSlotState(
            filled_slots=self._region_and_product(),
            slot_to_field_map={"region": "region_field"},
        )
        assert len(populated.filled_slots) == 2
        assert populated.slot_to_field_map["region"] == "region_field"

    def test_to_dict_and_from_dict(self):
        """A to_dict()/from_dict() round trip preserves slots and the field mapping."""
        source_state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
            slot_to_field_map={"region": "region_field"},
        )

        round_tripped = CachedSlotState.from_dict(source_state.to_dict())

        assert len(round_tripped.filled_slots) == 1
        region = round_tripped.filled_slots["region"]
        assert (region.value, region.source) == ("北京", "user_confirmed")
        assert round_tripped.slot_to_field_map["region"] == "region_field"

    def test_get_simple_filled_slots(self):
        """get_simple_filled_slots() flattens each slot to its bare value."""
        state = CachedSlotState(filled_slots=self._region_and_product())
        assert state.get_simple_filled_slots() == {"region": "北京", "product": "手机"}

    def test_get_slot_sources(self):
        """get_slot_sources() maps each slot name to the source that filled it."""
        state = CachedSlotState(filled_slots=self._region_and_product())
        assert state.get_slot_sources() == {"region": "user_confirmed", "product": "rule_extracted"}

    def test_get_slot_confidence(self):
        """get_slot_confidence() maps each slot name to its confidence score."""
        state = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="北京", source="user_confirmed", confidence=1.0),
                "product": CachedSlotValue(value="手机", source="rule_extracted", confidence=0.8),
            },
        )
        assert state.get_slot_confidence() == {"region": 1.0, "product": 0.8}
|
|
|
|
|
|
class TestSlotStateCache:
    """Tests for SlotStateCache helpers, the in-process (L1) layer and priority-aware merging."""

    @staticmethod
    def _local_only_cache():
        """Return a SlotStateCache with the Redis layer switched off.

        Every L1-focused test goes through this helper. It ensures that a missed
        or expired L1 entry can never fall through to whatever Redis backend the
        default constructor might be configured with, which would make the
        outcome depend on the test environment.
        """
        cache = SlotStateCache(redis_client=None)
        cache._enabled = False
        return cache

    def test_source_priority(self):
        """Sources rank user_confirmed > rule_extracted > llm_inferred > context > default; unknown is 0."""
        cache = SlotStateCache()
        assert cache._get_source_priority("user_confirmed") == 100
        assert cache._get_source_priority("rule_extracted") == 80
        assert cache._get_source_priority("llm_inferred") == 60
        assert cache._get_source_priority("context") == 40
        assert cache._get_source_priority("default") == 20
        assert cache._get_source_priority("unknown") == 0

    def test_make_key(self):
        """Redis keys are namespaced as slot_state:<tenant>:<session>."""
        cache = SlotStateCache()
        key = cache._make_key("tenant_123", "session_456")
        assert key == "slot_state:tenant_123:session_456"

    @pytest.mark.asyncio
    async def test_l1_cache_hit(self):
        """A fresh entry in the local (L1) cache is returned by get()."""
        cache = self._local_only_cache()
        tenant_id = "tenant_1"
        session_id = "session_1"

        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )
        # L1 entries are (state, stored_at) tuples keyed by "<tenant>:<session>".
        cache._local_cache[f"{tenant_id}:{session_id}"] = (state, time.time())

        result = await cache.get(tenant_id, session_id)
        assert result is not None
        assert result.filled_slots["region"].value == "北京"

    @pytest.mark.asyncio
    async def test_l1_cache_expired(self):
        """An L1 entry older than the local TTL is treated as a miss and evicted."""
        cache = self._local_only_cache()
        tenant_id = "tenant_1"
        session_id = "session_1"
        local_key = f"{tenant_id}:{session_id}"

        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )

        # Back-date relative to the configured local TTL instead of a hard-coded
        # 400s, so the test stays meaningful if the L1 TTL is retuned.
        stale_at = time.time() - cache._local_cache_ttl - 100
        cache._local_cache[local_key] = (state, stale_at)

        result = await cache.get(tenant_id, session_id)
        assert result is None
        assert local_key not in cache._local_cache

    @pytest.mark.asyncio
    async def test_set_and_get_l1(self):
        """set() populates the L1 cache and a following get() reads it back."""
        cache = self._local_only_cache()
        tenant_id = "tenant_1"
        session_id = "session_1"

        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )

        await cache.set(tenant_id, session_id, state)

        local_key = f"{tenant_id}:{session_id}"
        assert local_key in cache._local_cache

        result = await cache.get(tenant_id, session_id)
        assert result is not None
        assert result.filled_slots["region"].value == "北京"

    @pytest.mark.asyncio
    async def test_delete(self):
        """After delete(), get() for the same session returns None."""
        cache = self._local_only_cache()
        tenant_id = "tenant_1"
        session_id = "session_1"

        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )

        await cache.set(tenant_id, session_id, state)
        await cache.delete(tenant_id, session_id)

        result = await cache.get(tenant_id, session_id)
        assert result is None

    @pytest.mark.asyncio
    async def test_clear_slot(self):
        """clear_slot() removes only the named slot and leaves the others in place."""
        cache = self._local_only_cache()
        tenant_id = "tenant_1"
        session_id = "session_1"

        state = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="北京", source="user_confirmed"),
                "product": CachedSlotValue(value="手机", source="rule_extracted"),
            },
        )

        await cache.set(tenant_id, session_id, state)
        await cache.clear_slot(tenant_id, session_id, "region")

        result = await cache.get(tenant_id, session_id)
        assert result is not None
        assert "region" not in result.filled_slots
        assert "product" in result.filled_slots

    @pytest.mark.asyncio
    async def test_merge_and_set_priority(self):
        """A higher-priority source overwrites an existing lower-priority value on merge."""
        cache = self._local_only_cache()
        tenant_id = "tenant_1"
        session_id = "session_1"

        existing_state = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="上海", source="llm_inferred", confidence=0.6),
            },
        )
        await cache.set(tenant_id, session_id, existing_state)

        new_slots = {
            "region": CachedSlotValue(value="北京", source="user_confirmed", confidence=1.0),
        }

        result = await cache.merge_and_set(tenant_id, session_id, new_slots)

        assert result.filled_slots["region"].value == "北京"
        assert result.filled_slots["region"].source == "user_confirmed"

    @pytest.mark.asyncio
    async def test_merge_and_set_lower_priority_ignored(self):
        """A lower-priority source does not overwrite an existing higher-priority value."""
        cache = self._local_only_cache()
        tenant_id = "tenant_1"
        session_id = "session_1"

        existing_state = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="北京", source="user_confirmed", confidence=1.0),
            },
        )
        await cache.set(tenant_id, session_id, existing_state)

        new_slots = {
            "region": CachedSlotValue(value="上海", source="llm_inferred", confidence=0.6),
        }

        result = await cache.merge_and_set(tenant_id, session_id, new_slots)

        assert result.filled_slots["region"].value == "北京"
        assert result.filled_slots["region"].source == "user_confirmed"
|
|
|
|
|
|
class TestGetSlotStateCache:
    """Tests for the module-level get_slot_state_cache() accessor."""

    def test_singleton(self):
        """Repeated calls hand back the very same SlotStateCache object."""
        first_call, second_call = get_slot_state_cache(), get_slot_state_cache()
        assert first_call is second_call
|
|
|
|
|
|
class TestSlotStateCacheWithRedis:
    """SlotStateCache tests against a mocked async Redis client."""

    @pytest.mark.asyncio
    async def test_redis_set_and_get(self):
        """set() writes through to Redis via SETEX, and that payload can be read back by get()."""
        tenant_id = "tenant_1"
        session_id = "session_1"
        redis_key = f"slot_state:{tenant_id}:{session_id}"

        writer_redis = AsyncMock()
        writer_redis.get = AsyncMock(return_value=None)
        writer_redis.setex = AsyncMock(return_value=True)
        writer = SlotStateCache(redis_client=writer_redis)

        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )

        await writer.set(tenant_id, session_id, state)

        writer_redis.setex.assert_called_once()
        setex_args = writer_redis.setex.call_args[0]
        assert setex_args[0] == redis_key

        # Previously only the write half was checked. Redis SETEX is
        # setex(name, time, value), so the serialized state is the third
        # positional argument. Serve it from a fresh cache (empty L1) so that
        # get() has to deserialize what set() actually wrote.
        reader_redis = AsyncMock()
        reader_redis.get = AsyncMock(return_value=setex_args[2])
        reader = SlotStateCache(redis_client=reader_redis)

        restored = await reader.get(tenant_id, session_id)

        assert restored is not None
        assert restored.filled_slots["region"].value == "北京"
        assert restored.filled_slots["region"].source == "user_confirmed"

    @pytest.mark.asyncio
    async def test_redis_get_hit(self):
        """A fresh cache (nothing in L1) deserializes the JSON state returned by Redis."""
        state_dict = {
            "filled_slots": {
                "region": {
                    "value": "北京",
                    "source": "user_confirmed",
                    "confidence": 1.0,
                    "updated_at": 12345.0,
                }
            },
            "slot_to_field_map": {"region": "region_field"},
            "created_at": 12340.0,
            "updated_at": 12345.0,
        }

        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value=json.dumps(state_dict))

        cache = SlotStateCache(redis_client=mock_redis)

        tenant_id = "tenant_1"
        session_id = "session_1"

        result = await cache.get(tenant_id, session_id)

        assert result is not None
        assert result.filled_slots["region"].value == "北京"
        assert result.filled_slots["region"].source == "user_confirmed"

    @pytest.mark.asyncio
    async def test_redis_delete(self):
        """delete() removes the namespaced slot_state key from Redis."""
        mock_redis = AsyncMock()
        mock_redis.delete = AsyncMock(return_value=1)

        cache = SlotStateCache(redis_client=mock_redis)

        tenant_id = "tenant_1"
        session_id = "session_1"

        await cache.delete(tenant_id, session_id)

        mock_redis.delete.assert_called_once_with(f"slot_state:{tenant_id}:{session_id}")
|
|
|
|
|
|
class TestCacheTTL:
    """Tests pinning the default TTL settings of SlotStateCache."""

    def test_default_ttl(self):
        """The main cache TTL (_cache_ttl) defaults to 1800 seconds (30 minutes)."""
        assert SlotStateCache()._cache_ttl == 1800

    def test_local_cache_ttl(self):
        """The in-process cache TTL (_local_cache_ttl) defaults to 300 seconds (5 minutes)."""
        assert SlotStateCache()._local_cache_ttl == 300
|