test: add unit tests and utility scripts for intent routing, slot management, and KB search [AC-TEST]
This commit is contained in:
parent
fe883cfff0
commit
f4ca25b0d8
|
|
@ -0,0 +1,81 @@
|
|||
"""
|
||||
Database Migration: Scene Slot Bundle Tables.
|
||||
[AC-SCENE-SLOT-01] 场景-槽位映射配置表迁移
|
||||
|
||||
创建时间: 2025-03-07
|
||||
变更说明:
|
||||
- 新增 scene_slot_bundles 表用于存储场景槽位包配置
|
||||
|
||||
执行方式:
|
||||
- SQLModel 会自动创建表(通过 init_db)
|
||||
- 此脚本用于手动迁移或回滚
|
||||
|
||||
SQL DDL:
|
||||
```sql
|
||||
CREATE TABLE scene_slot_bundles (
|
||||
id UUID PRIMARY KEY,
|
||||
tenant_id VARCHAR NOT NULL,
|
||||
scene_key VARCHAR(100) NOT NULL,
|
||||
scene_name VARCHAR(100) NOT NULL,
|
||||
description TEXT,
|
||||
required_slots JSON NOT NULL DEFAULT '[]',
|
||||
optional_slots JSON NOT NULL DEFAULT '[]',
|
||||
slot_priority JSON,
|
||||
completion_threshold FLOAT NOT NULL DEFAULT 1.0,
|
||||
ask_back_order VARCHAR NOT NULL DEFAULT 'priority',
|
||||
status VARCHAR NOT NULL DEFAULT 'draft',
|
||||
version INTEGER NOT NULL DEFAULT 1,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX ix_scene_slot_bundles_tenant ON scene_slot_bundles(tenant_id);
|
||||
CREATE UNIQUE INDEX ix_scene_slot_bundles_tenant_scene ON scene_slot_bundles(tenant_id, scene_key);
|
||||
CREATE INDEX ix_scene_slot_bundles_tenant_status ON scene_slot_bundles(tenant_id, status);
|
||||
```
|
||||
|
||||
回滚 SQL:
|
||||
```sql
|
||||
DROP TABLE IF EXISTS scene_slot_bundles;
|
||||
```
|
||||
"""
|
||||
|
||||
# DDL executed by upgrade(): table first, then its indexes.
SCENE_SLOT_BUNDLES_DDL = """
CREATE TABLE IF NOT EXISTS scene_slot_bundles (
    id UUID PRIMARY KEY,
    tenant_id VARCHAR NOT NULL,
    scene_key VARCHAR(100) NOT NULL,
    scene_name VARCHAR(100) NOT NULL,
    description TEXT,
    required_slots JSON NOT NULL DEFAULT '[]',
    optional_slots JSON NOT NULL DEFAULT '[]',
    slot_priority JSON,
    completion_threshold FLOAT NOT NULL DEFAULT 1.0,
    ask_back_order VARCHAR NOT NULL DEFAULT 'priority',
    status VARCHAR NOT NULL DEFAULT 'draft',
    version INTEGER NOT NULL DEFAULT 1,
    created_at TIMESTAMP NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMP NOT NULL DEFAULT NOW()
);
"""

SCENE_SLOT_BUNDLES_INDEXES = """
CREATE INDEX IF NOT EXISTS ix_scene_slot_bundles_tenant ON scene_slot_bundles(tenant_id);
CREATE UNIQUE INDEX IF NOT EXISTS ix_scene_slot_bundles_tenant_scene ON scene_slot_bundles(tenant_id, scene_key);
CREATE INDEX IF NOT EXISTS ix_scene_slot_bundles_tenant_status ON scene_slot_bundles(tenant_id, status);
"""

SCENE_SLOT_BUNDLES_ROLLBACK = """
DROP TABLE IF EXISTS scene_slot_bundles;
"""


async def upgrade(conn):
    """Apply the migration: create the scene_slot_bundles table, then its indexes."""
    for statement in (SCENE_SLOT_BUNDLES_DDL, SCENE_SLOT_BUNDLES_INDEXES):
        await conn.execute(statement)


async def downgrade(conn):
    """Revert the migration by dropping the scene_slot_bundles table."""
    await conn.execute(SCENE_SLOT_BUNDLES_ROLLBACK)
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
"""
|
||||
Database Migration: Add display_name and description to slot_definitions.
|
||||
添加槽位名称和槽位说明字段
|
||||
|
||||
创建时间: 2026-03-08
|
||||
变更说明:
|
||||
- 新增 display_name 字段:槽位名称,给运营/教研看的中文名
|
||||
- 新增 description 字段:槽位说明,解释这个槽位采集什么、用于哪里
|
||||
|
||||
执行方式:
|
||||
- SQLModel 会自动处理新字段(通过 init_db)
|
||||
- 此脚本用于手动迁移现有数据库
|
||||
|
||||
SQL DDL:
|
||||
```sql
|
||||
ALTER TABLE slot_definitions
|
||||
ADD COLUMN IF NOT EXISTS display_name VARCHAR(100),
|
||||
ADD COLUMN IF NOT EXISTS description VARCHAR(500);
|
||||
```
|
||||
|
||||
回滚 SQL:
|
||||
```sql
|
||||
ALTER TABLE slot_definitions
|
||||
DROP COLUMN IF EXISTS display_name,
|
||||
DROP COLUMN IF EXISTS description;
|
||||
```
|
||||
"""
|
||||
|
||||
# Forward DDL: add the two new optional columns.
ALTER_SLOT_DEFINITIONS_DDL = """
ALTER TABLE slot_definitions
ADD COLUMN IF NOT EXISTS display_name VARCHAR(100),
ADD COLUMN IF NOT EXISTS description VARCHAR(500);
"""

# Rollback DDL: drop the same columns again.
ALTER_SLOT_DEFINITIONS_ROLLBACK = """
ALTER TABLE slot_definitions
DROP COLUMN IF EXISTS display_name,
DROP COLUMN IF EXISTS description;
"""


async def upgrade(conn):
    """Apply the migration: add display_name and description to slot_definitions."""
    await conn.execute(ALTER_SLOT_DEFINITIONS_DDL)


async def downgrade(conn):
    """Revert the migration: drop the display_name and description columns."""
    await conn.execute(ALTER_SLOT_DEFINITIONS_ROLLBACK)
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
"""
|
||||
检查所有意图规则(包括未启用的)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.models.entities import IntentRule
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
async def check_all_rules():
    """Print every intent rule for the target tenant, enabled or not."""
    settings = get_settings()
    engine = create_async_engine(settings.database_url)
    session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async with session_factory() as session:
        # No is_enabled filter — disabled rules are listed as well.
        query = select(IntentRule).where(IntentRule.tenant_id == "szmp@ash@2026")
        rules = (await session.execute(query)).scalars().all()

        divider = "=" * 80
        print(divider)
        print("数据库中的所有意图规则 (tenant=szmp@ash@2026):")
        print(f"总计: {len(rules)} 条")
        print(divider)

        for rule in rules:
            enabled_label = '✅ 启用' if rule.is_enabled else '❌ 禁用'
            print(f"\n规则: {rule.name}")
            print(f"  ID: {rule.id}")
            print(f"  响应类型: {rule.response_type}")
            print(f"  关键词: {rule.keywords}")
            print(f"  目标知识库: {rule.target_kb_ids}")
            print(f"  优先级: {rule.priority}")
            print(f"  启用状态: {enabled_label}")

        print("\n" + divider)


if __name__ == "__main__":
    asyncio.run(check_all_rules())
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
"""
|
||||
检查 Qdrant 中课程知识库的数据结构
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
from app.core.qdrant_client import QdrantClient
|
||||
|
||||
|
||||
async def check_course_kb():
    """Inspect the course KB's Qdrant collection and sample its payload layout.

    Derives the expected collection name from the tenant / KB ids, lists the
    tenant's collections, and prints payload keys + metadata of a few points.
    """
    settings = get_settings()

    client = QdrantClient()
    qdrant = await client.get_client()

    tenant_id = "szmp@ash@2026"
    course_kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

    # '@' is not valid in a collection name; names are stored with '_'.
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = settings.qdrant_collection_prefix

    expected_collection = f"{prefix}{safe_tenant_id}_{course_kb_id}"

    print(f"\n{'='*80}")
    print(f"检查课程知识库 Collection")
    print(f"{'='*80}")
    print(f"租户 ID: {tenant_id}")
    print(f"课程知识库 ID: {course_kb_id}")
    print(f"预期 Collection 名称: {expected_collection}")

    collections = await qdrant.get_collections()
    collection_names = [c.name for c in collections.collections]

    print(f"\n租户的所有 Collections:")
    for name in collection_names:
        if safe_tenant_id in name:
            print(f"  - {name}")

    if expected_collection in collection_names:
        print(f"\n✅ 课程知识库 Collection 存在: {expected_collection}")

        # BUG FIX: scroll() on the async client is a coroutine — it must be
        # awaited before the (points, next_offset) tuple can be unpacked.
        points, _ = await qdrant.scroll(
            collection_name=expected_collection,
            limit=3,
            with_vectors=False,
        )

        print(f"\n课程知识库数据 (共 {len(points)} 条):")
        for i, point in enumerate(points, 1):
            # Points may be Record objects (async client) or plain dicts;
            # support both, matching the sibling check scripts.
            if hasattr(point, 'payload'):
                payload = point.payload or {}
                point_id = point.id
            else:
                payload = point.get('payload', {})
                point_id = point.get('id')
            print(f"\n  [{i}] id: {point_id}")
            print(f"      payload keys: {list(payload.keys())}")
            if 'metadata' in payload:
                print(f"      metadata: {payload['metadata']}")
    else:
        print(f"\n❌ 课程知识库 Collection 不存在!")
        print(f"   可用的 Collections: {collection_names}")


if __name__ == "__main__":
    asyncio.run(check_course_kb())
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
"""
|
||||
检查课程知识库的 collection 是否存在
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient
|
||||
|
||||
|
||||
async def check_course_kb_collection():
    """Check whether the course KB's Qdrant collection exists and sample it."""
    client = QdrantClient()
    qdrant = await client.get_client()

    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

    collection_name = client.get_kb_collection_name(tenant_id, kb_id)
    print(f"课程知识库 collection name: {collection_name}")

    exists = await qdrant.collection_exists(collection_name)
    print(f"Collection exists: {exists}")

    if exists:
        # BUG FIX: scroll() returns a (points, next_offset) tuple; the
        # original bound the whole tuple to `points`, so len(points) was
        # always 2 and iteration walked the tuple halves, not the points.
        points, _ = await qdrant.scroll(
            collection_name=collection_name,
            limit=5,
            with_vectors=False,
        )
        print(f"\n课程知识库中有 {len(points)} 条数据:")
        for i, point in enumerate(points, 1):
            # Support both Record objects and plain dicts, like the
            # sibling check scripts do.
            if hasattr(point, 'payload'):
                payload = point.payload or {}
            else:
                payload = point.get('payload', {})
            print(f"  [{i}] payload keys: {list(payload.keys())}")
            for key, value in payload.items():
                if key != 'text' and key != 'vector':
                    print(f"      {key}: {value}")
    else:
        print(f"课程知识库 collection 不存在!")


if __name__ == "__main__":
    asyncio.run(check_course_kb_collection())
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
"""
|
||||
检查课程知识库的录入情况
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient
|
||||
from app.models.entities import Document
|
||||
|
||||
|
||||
async def check_course_kb_status():
    """Report ingestion status of the course KB.

    Cross-checks document rows in the relational database against the
    presence/contents of the corresponding Qdrant collection, then prints
    a verdict (ingested / pending vectorization / stale / empty).
    """
    settings = get_settings()

    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    # Hard-coded target tenant and knowledge-base id for this diagnostic.
    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

    print(f"\n{'='*80}")
    print(f"检查课程知识库的录入情况")
    print(f"{'='*80}")
    print(f"租户 ID: {tenant_id}")
    print(f"知识库 ID: {kb_id}")

    async with async_session() as session:
        # Documents registered for this KB in the relational database.
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.kb_id == kb_id,
        )
        result = await session.execute(stmt)
        documents = result.scalars().all()

        print(f"\n数据库中的文档记录: {len(documents)} 个")
        if documents:
            for doc in documents[:5]:
                print(f"  - {doc.file_name} (status: {doc.status})")
            if len(documents) > 5:
                print(f"  ... 还有 {len(documents) - 5} 个文档")

        client = QdrantClient()
        qdrant = await client.get_client()

        collection_name = client.get_kb_collection_name(tenant_id, kb_id)
        print(f"\nQdrant Collection 名称: {collection_name}")

        exists = await qdrant.collection_exists(collection_name)
        if exists:
            points_result = await qdrant.scroll(
                collection_name=collection_name,
                limit=5,
                with_vectors=False,
            )
            # scroll() may return (points, next_offset) or a bare list
            # depending on the client version — normalize to the point list.
            points = points_result[0] if isinstance(points_result, tuple) else points_result
            print(f"Qdrant Collection 存在,有 {len(points)} 条数据")
            for i, point in enumerate(points, 1):
                # Points may be Record objects or plain dicts.
                if hasattr(point, 'payload'):
                    payload = point.payload
                    point_id = point.id
                else:
                    payload = point.get('payload', {})
                    point_id = point.get('id', 'unknown')
                print(f"  [{i}] id: {point_id}")
                if 'text' in payload:
                    # Truncate long chunks to a 50-char preview.
                    text = payload['text'][:50] + '...' if len(payload['text']) > 50 else payload['text']
                    print(f"      text: {text}")
        else:
            print(f"Qdrant Collection 不存在!")

        print(f"\n{'='*80}")
        print(f"结论:")
        # Classify state from the DB-rows / collection-exists combination.
        if len(documents) > 0 and not exists:
            print("  数据库有文档记录,但 Qdrant Collection 不存在")
            print("  需要等待文档向量化任务完成")
        elif len(documents) == 0 and exists:
            print("  数据库没有文档记录,但 Qdrant Collection 存在")
            print("  可能是旧数据")
        elif len(documents) > 0 and exists:
            print(f"  数据库有 {len(documents)} 个文档记录")
            print(f"  Qdrant Collection 存在")
            print("  ✅ 知识库已录入完成")
        else:
            print("  数据库没有文档记录")
            print("  Qdrant Collection 不存在")
            print("  ❌ 知识库未录入")


if __name__ == "__main__":
    asyncio.run(check_course_kb_status())
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
"""
|
||||
检查 Qdrant 中是否有 grade=五年级 的数据
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient
|
||||
|
||||
|
||||
async def check_grade_data():
    """Show the metadata.grade distribution and probe for grade == '五年级'.

    First scrolls a sample of points and tallies grades, then runs a
    filtered scroll to see whether any point matches the target grade.
    """
    settings = get_settings()
    client = QdrantClient()
    qdrant = await client.get_client()

    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

    collection_name = client.get_kb_collection_name(tenant_id, kb_id)

    print(f"\n{'='*80}")
    print(f"检查 Qdrant 中 grade 字段的分布")
    print(f"{'='*80}")
    print(f"Collection: {collection_name}")

    # Fetch a sample of points; scroll() returns (points, next_offset).
    all_points = await qdrant.scroll(
        collection_name=collection_name,
        limit=100,
        with_vectors=False,
    )

    print(f"\n总数据量: {len(all_points[0])} 条")

    # Tally the distribution of metadata.grade across the sample
    # ('无' marks points without a grade).
    grade_count = {}
    for point in all_points[0]:
        metadata = point.payload.get('metadata', {})
        grade = metadata.get('grade', '无')
        grade_count[grade] = grade_count.get(grade, 0) + 1

    print(f"\ngrade 字段分布:")
    for grade, count in sorted(grade_count.items()):
        print(f"  {grade}: {count} 条")

    # Filtered scroll: only points whose nested metadata.grade matches.
    print(f"\n--- 检查 grade=五年级 的数据 ---")
    qdrant_filter = Filter(
        must=[
            FieldCondition(
                key="metadata.grade",
                match=MatchValue(value="五年级"),
            )
        ]
    )

    results = await qdrant.scroll(
        collection_name=collection_name,
        limit=10,
        with_vectors=False,
        scroll_filter=qdrant_filter,
    )

    print(f"grade=五年级 的数据: {len(results[0])} 条")
    for p in results[0]:
        print(f"  text: {p.payload.get('text', '')[:80]}...")
        print(f"  metadata: {p.payload.get('metadata', {})}")


if __name__ == "__main__":
    asyncio.run(check_grade_data())
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
"""
|
||||
查看指定知识库的内容
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
async def check_kb_content():
    """Dump the contents of one knowledge-base collection in Qdrant.

    Pages through every point of the hard-coded collection and prints id,
    filename, a text preview and metadata for each; closes the client on exit.
    """
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)

    tenant_id = "szmp@ash@2026"
    kb_id = "8559ebc9-bfaf-4211-8fe3-ee2b22a5e29c"
    collection_name = "kb_szmp_ash_2026_8559ebc9"

    print("=" * 80)
    print(f"查看知识库: {kb_id}")
    print(f"Collection: {collection_name}")
    print("=" * 80)

    try:
        # Bail out early when the collection is missing.
        exists = await client.collection_exists(collection_name)
        print(f"\nCollection 存在: {exists}")

        if not exists:
            print("Collection 不存在!")
            return

        info = await client.get_collection(collection_name)
        print(f"\nCollection 信息:")
        print(f"  向量数: {info.points_count}")

        print(f"\n文档内容:")
        print("-" * 80)

        # Page through the collection; scroll() yields (points, next_offset).
        offset = None
        total = 0
        while True:
            result = await client.scroll(
                collection_name=collection_name,
                limit=10,
                offset=offset,
                with_payload=True,
            )

            points = result[0]
            if not points:
                break

            for point in points:
                total += 1
                payload = point.payload or {}
                text = payload.get('text', 'N/A')[:100]
                metadata = payload.get('metadata', {})
                filename = payload.get('filename', 'N/A')

                print(f"\n  [{total}] ID: {point.id}")
                # BUG FIX: `filename` was computed but a literal placeholder
                # string was printed instead of the value.
                print(f"      Filename: {filename}")
                print(f"      Text: {text}...")
                print(f"      Metadata: {metadata}")

            offset = result[1]
            if offset is None:
                break

        print(f"\n总计 {total} 条记录")

    except Exception as e:
        print(f"\n错误: {e}")
        import traceback
        traceback.print_exc()
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(check_kb_content())
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
"""
|
||||
检查租户的所有知识库
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.core.config import get_settings
|
||||
from app.models.entities import KnowledgeBase
|
||||
|
||||
|
||||
async def check_knowledge_bases():
    """Print every knowledge base registered for the target tenant."""
    settings = get_settings()

    engine = create_async_engine(settings.database_url)
    session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    banner = "=" * 80

    async with session_factory() as session:
        tenant_id = "szmp@ash@2026"

        print(f"\n{banner}")
        print(f"检查租户 {tenant_id} 的所有知识库")
        print(f"{banner}")

        query = select(KnowledgeBase).where(
            KnowledgeBase.tenant_id == tenant_id,
        )
        kbs = (await session.execute(query)).scalars().all()

        print(f"\n找到 {len(kbs)} 个知识库:")

        for kb in kbs:
            print(f"\n  知识库: {kb.name}")
            # One line per attribute, in the same order as the original.
            for label, value in (
                ("id", kb.id),
                ("kb_type", kb.kb_type),
                ("description", kb.description),
                ("is_enabled", kb.is_enabled),
                ("doc_count", kb.doc_count),
                ("created_at", kb.created_at),
            ):
                print(f"    {label}: {value}")


if __name__ == "__main__":
    asyncio.run(check_knowledge_bases())
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
"""
|
||||
检查知识库的元数据字段定义
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.core.config import get_settings
|
||||
from app.models.entities import (
|
||||
MetadataFieldDefinition,
|
||||
MetadataFieldStatus,
|
||||
FieldRole,
|
||||
)
|
||||
|
||||
|
||||
async def check_metadata_fields():
    """Inspect the tenant's ACTIVE metadata field definitions.

    Prints each field's configuration and then the subset whose
    field_roles include the RESOURCE_FILTER role.
    """
    settings = get_settings()

    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async with async_session() as session:
        tenant_id = "szmp@ash@2026"

        print(f"\n{'='*80}")
        print(f"检查租户 {tenant_id} 的元数据字段定义")
        print(f"{'='*80}")

        # Only ACTIVE definitions are listed.
        stmt = select(MetadataFieldDefinition).where(
            MetadataFieldDefinition.tenant_id == tenant_id,
            MetadataFieldDefinition.status == MetadataFieldStatus.ACTIVE,
        )
        result = await session.execute(stmt)
        fields = result.scalars().all()

        print(f"\n找到 {len(fields)} 个活跃字段定义:")

        for f in fields:
            print(f"\n  字段: {f.field_key}")
            print(f"    label: {f.label}")
            print(f"    type: {f.type}")
            print(f"    required: {f.required}")
            print(f"    field_roles: {f.field_roles}")
            print(f"    options: {f.options}")
            print(f"    default_value: {f.default_value}")

        # Fields carrying the RESOURCE_FILTER role can be used to filter
        # retrieval results.
        filterable_fields = [
            f for f in fields
            if f.field_roles and FieldRole.RESOURCE_FILTER.value in f.field_roles
        ]
        print(f"\n{'='*80}")
        print(f"可过滤字段 (field_roles 包含 resource_filter): {len(filterable_fields)} 个")
        for f in filterable_fields:
            print(f"  - {f.field_key} (label: {f.label}, required: {f.required})")


if __name__ == "__main__":
    asyncio.run(check_metadata_fields())
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
"""
|
||||
检查 Qdrant 中数据的 metadata 存储结构
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient
|
||||
|
||||
|
||||
async def check_metadata_structure():
    """Inspect how payload / metadata is stored on points in Qdrant.

    Samples a few points from the KB collection and prints every payload
    key, then expands the nested `metadata` dict when present.
    """
    settings = get_settings()
    client = QdrantClient()
    qdrant = await client.get_client()

    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

    collection_name = client.get_kb_collection_name(tenant_id, kb_id)

    print(f"\n{'='*80}")
    print(f"检查 Qdrant 数据结构")
    print(f"{'='*80}")
    print(f"Collection: {collection_name}")

    # scroll() returns (points, next_offset); only a small sample is needed.
    points = await qdrant.scroll(
        collection_name=collection_name,
        limit=3,
        with_vectors=False,
    )

    print(f"\n找到 {len(points[0])} 条数据:")

    for i, point in enumerate(points[0], 1):
        print(f"\n--- Point {i} ---")
        # Points may be Record objects or plain dicts depending on client.
        if hasattr(point, 'payload'):
            payload = point.payload
            point_id = point.id
        else:
            payload = point.get('payload', {})
            point_id = point.get('id', 'unknown')

        print(f"ID: {point_id}")
        print(f"Payload keys: {list(payload.keys())}")

        # Print the full payload, truncating long text and masking vectors.
        for key, value in payload.items():
            if key == 'text':
                print(f"  {key}: {value[:50]}..." if len(str(value)) > 50 else f"  {key}: {value}")
            elif key == 'vector':
                print(f"  {key}: [向量数据]")
            else:
                print(f"  {key}: {value}")

        # Expand the nested metadata dict, if the payload carries one.
        if 'metadata' in payload:
            print(f"\n  metadata 字段内容:")
            for mk, mv in payload['metadata'].items():
                print(f"    {mk}: {mv}")


if __name__ == "__main__":
    asyncio.run(check_metadata_structure())
|
||||
|
|
@ -1,79 +1,112 @@
|
|||
"""
|
||||
Check Qdrant vector database contents - detailed view.
|
||||
检查 Qdrant 向量数据库状态和知识库内容
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
sys.path.insert(0, ".")
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from app.core.config import get_settings
|
||||
from collections import defaultdict
|
||||
|
||||
settings = get_settings()
|
||||
from app.core.qdrant_client import get_qdrant_client
|
||||
|
||||
|
||||
async def check_qdrant():
|
||||
"""Check Qdrant collections and vectors."""
|
||||
client = AsyncQdrantClient(url=settings.qdrant_url, check_compatibility=False)
|
||||
"""检查 Qdrant 状态"""
|
||||
settings = get_settings()
|
||||
tenant_id = "szmp@ash@2026"
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Database URL: {settings.database_url}")
|
||||
print(f"Qdrant URL: {settings.qdrant_url}")
|
||||
print(f"{'='*60}\n")
|
||||
print(f"Tenant ID: {tenant_id}")
|
||||
print()
|
||||
|
||||
# List all collections
|
||||
collections = await client.get_collections()
|
||||
|
||||
# Check kb_default collection
|
||||
for c in collections.collections:
|
||||
if c.name == "kb_default":
|
||||
print(f"\n--- Collection: {c.name} ---")
|
||||
try:
|
||||
qdrant_manager = await get_qdrant_client()
|
||||
client = await qdrant_manager.get_client()
|
||||
|
||||
# 检查集合是否存在
|
||||
collections = (await client.get_collections()).collections
|
||||
collection_names = [c.name for c in collections]
|
||||
print(f"Available collections: {collection_names}")
|
||||
print()
|
||||
|
||||
# 筛选该租户的 collections
|
||||
tenant_collections = [name for name in collection_names if "szmp_ash_2026" in name]
|
||||
print(f"Tenant collections: {tenant_collections}")
|
||||
print()
|
||||
|
||||
# 检查每个集合
|
||||
for collection_name in tenant_collections:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Collection: {collection_name}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Get collection info
|
||||
info = await client.get_collection(c.name)
|
||||
print(f" Total vectors: {info.points_count}")
|
||||
# 获取集合信息
|
||||
collection_info = await client.get_collection(collection_name)
|
||||
print(f" Points count: {collection_info.points_count}")
|
||||
print(f" Vectors count: {collection_info.vectors_count}")
|
||||
print(f" Status: {collection_info.status}")
|
||||
|
||||
# Scroll through all points and group by source
|
||||
all_points = []
|
||||
offset = None
|
||||
if collection_info.points_count == 0:
|
||||
print(" ⚠️ Collection is empty!")
|
||||
continue
|
||||
|
||||
while True:
|
||||
points, offset = await client.scroll(
|
||||
collection_name=c.name,
|
||||
limit=100,
|
||||
offset=offset,
|
||||
# 滚动获取一些数据
|
||||
print(f"\n 前 3 条数据:")
|
||||
points, next_page = await client.scroll(
|
||||
collection_name=collection_name,
|
||||
limit=3,
|
||||
with_payload=True,
|
||||
with_vectors=False,
|
||||
)
|
||||
|
||||
for i, point in enumerate(points, 1):
|
||||
payload = point.payload or {}
|
||||
text = payload.get("text", "")[:100] + "..." if payload.get("text") else "N/A"
|
||||
kb_id = payload.get("kb_id", "N/A")
|
||||
metadata = payload.get("metadata", {})
|
||||
print(f"\n Point {i}:")
|
||||
print(f" ID: {point.id}")
|
||||
print(f" KB ID: {kb_id}")
|
||||
print(f" Text: {text}")
|
||||
print(f" Metadata: {metadata}")
|
||||
|
||||
# 尝试向量搜索
|
||||
print(f"\n\n{'='*60}")
|
||||
print(f"尝试向量搜索 (query='课程'):")
|
||||
print(f"{'='*60}")
|
||||
|
||||
from app.services.embedding.factory import get_embedding_provider
|
||||
|
||||
embedding_provider = await get_embedding_provider()
|
||||
query_vector = await embedding_provider.embed("课程")
|
||||
print(f"Query vector dimension: {len(query_vector)}")
|
||||
|
||||
for collection_name in tenant_collections:
|
||||
print(f"\n搜索 collection: {collection_name}")
|
||||
try:
|
||||
search_results = await client.query_points(
|
||||
collection_name=collection_name,
|
||||
query=query_vector,
|
||||
using="full", # 使用 full 向量
|
||||
limit=3,
|
||||
with_payload=True,
|
||||
with_vectors=False,
|
||||
)
|
||||
all_points.extend(points)
|
||||
if offset is None:
|
||||
break
|
||||
|
||||
# Group by source
|
||||
by_source = defaultdict(list)
|
||||
for p in all_points:
|
||||
source = p.payload.get("source", "unknown") if p.payload else "unknown"
|
||||
by_source[source].append(p)
|
||||
|
||||
print(f"\n Documents by source:")
|
||||
for source, points in by_source.items():
|
||||
print(f"\n Source: {source}")
|
||||
print(f" Chunks: {len(points)}")
|
||||
|
||||
# Check first chunk content
|
||||
first_point = points[0]
|
||||
text = first_point.payload.get("text", "") if first_point.payload else ""
|
||||
|
||||
# Check if it's binary garbage or proper text
|
||||
is_garbage = any(ord(c) > 0xFFFF or (ord(c) < 32 and c not in '\n\r\t') for c in text[:200])
|
||||
|
||||
if is_garbage:
|
||||
print(f" Status: ❌ BINARY GARBAGE (parsing failed)")
|
||||
else:
|
||||
print(f" Status: ✅ PROPER TEXT (parsed correctly)")
|
||||
|
||||
print(f" Preview: {text[:150]}...")
|
||||
|
||||
await client.close()
|
||||
print(f" Search results: {len(search_results.points)}")
|
||||
for i, result in enumerate(search_results.points, 1):
|
||||
payload = result.payload or {}
|
||||
text = payload.get("text", "")[:80] + "..." if payload.get("text") else "N/A"
|
||||
print(f" {i}. [score={result.score:.4f}] {text}")
|
||||
except Exception as e:
|
||||
print(f" ❌ Search error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
检查 Qdrant 中实际存在的 collections
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
async def check_qdrant_collections():
    """List collections in Qdrant and, for the tenant's ones, the kb_id.

    Collection names follow `kb_<safe_tenant_id>_<kb_id>`, so the kb_id is
    recovered by stripping the tenant prefix.
    """
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)

    try:
        collections = await client.get_collections()
        print(f"\n{'='*80}")
        print(f"Qdrant 中所有 collections:")
        print(f"{'='*80}")

        for coll in collections.collections:
            print(f"  - {coll.name}")

        tenant_id = "szmp@ash@2026"
        # '@' is not valid in a collection name; names are stored with '_'.
        safe_tenant_id = tenant_id.replace('@', '_')
        prefix = f"kb_{safe_tenant_id}"

        tenant_collections = [coll.name for coll in collections.collections if coll.name.startswith(prefix)]
        print(f"\n租户 {tenant_id} 的 collections:")
        print(f"{'='*80}")
        for coll_name in tenant_collections:
            print(f"  - {coll_name}")

            # BUG FIX: the original took coll_name.split('_')[2], which for a
            # prefix like 'kb_szmp_ash_2026' yields a tenant-name fragment
            # ('ash'), never the kb id.  Strip the known prefix instead.
            kb_id = coll_name[len(prefix):].lstrip('_')
            if kb_id:
                print(f"    kb_id: {kb_id}")

    except Exception as e:
        print(f"错误: {e}")
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(check_qdrant_collections())
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
"""
|
||||
检查 Qdrant 中实际存储的数据结构
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient
|
||||
|
||||
|
||||
async def check_qdrant_data():
    """Sample the stored payload structure of the tenant's Qdrant collections.

    Lists the tenant's collections via the project wrapper, then scrolls a
    few points from each of the first three and prints their payloads.
    """

    client = QdrantClient()
    # NOTE(review): `qdrant` is obtained but unused below (scroll_points is
    # called on the wrapper) — presumably kept for parity with sibling scripts.
    qdrant = await client.get_client()

    tenant_id = "szmp@ash@2026"

    print(f"\n{'='*80}")
    print(f"检查租户 {tenant_id} 的 Qdrant 数据")
    print(f"{'='*80}")

    collections = await client.list_collections(tenant_id)
    print(f"\n找到 {len(collections)} 个集合:")
    for coll in collections:
        print(f"  - {coll}")

    # Inspect at most the first three collections.
    for collection_name in collections[:3]:
        print(f"\n{'='*80}")
        print(f"检查集合: {collection_name}")
        print(f"{'='*80}")

        try:
            points = await client.scroll_points(
                collection_name=collection_name,
                limit=5,
            )

            print(f"\n找到 {len(points)} 条数据:")
            for i, point in enumerate(points, 1):
                payload = point.get('payload', {})
                print(f"\n  [{i}] id: {point.get('id')}")
                print(f"    metadata 字段:")
                # Show every payload key except the bulky text/vector fields.
                for key, value in payload.items():
                    if key != 'text' and key != 'vector':
                        print(f"      {key}: {value}")

                text = payload.get('text', '')
                if text:
                    print(f"    text 预览: {text[:100]}...")

        except Exception as e:
            print(f"  错误: {e}")


if __name__ == "__main__":
    asyncio.run(check_qdrant_data())
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
"""
|
||||
清理 szmp@ash@2026 租户下不需要的 Qdrant collections
|
||||
保留:8559ebc9-bfaf-4211-8fe3-ee2b22a5e29c, 30c19c84-8f69-4768-9d23-7f4a5bc3627a
|
||||
删除:其他所有 collections
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
async def cleanup_collections():
    """Delete all of one tenant's Qdrant collections except a keep-list.

    Scans collections whose name starts with ``kb_<safe_tenant_id>``,
    keeps those whose name contains one of the kb-id prefixes in
    ``keep_kb_ids``, deletes the rest, and prints a summary.

    Destructive — intended to be run via the confirmation prompt in the
    ``__main__`` block below.
    """
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)

    tenant_id = "szmp@ash@2026"
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = f"kb_{safe_tenant_id}"

    # Keep-list: first 8 hex chars of the kb UUIDs to preserve.
    keep_kb_ids = [
        "8559ebc9",
        "30c19c84",
    ]

    def should_keep(name: str) -> bool:
        # Fix: single definition of the keep check (was duplicated in the
        # listing pass and the deletion pass).
        # NOTE(review): substring matching on an 8-char prefix could in
        # principle match an unrelated collection name — verify keep-list
        # prefixes are unique among this tenant's kb ids.
        return any(kb_id in name for kb_id in keep_kb_ids)

    print(f"🔍 扫描租户 {tenant_id} 的 collections...")
    print(f" 前缀: {prefix}")
    print(f" 保留: {keep_kb_ids}")
    print("-" * 80)

    try:
        collections = await client.get_collections()

        # All collections belonging to this tenant.
        tenant_collections = [
            c.name for c in collections.collections
            if c.name.startswith(prefix)
        ]

        print(f"\n📊 找到 {len(tenant_collections)} 个 collections:")
        for name in sorted(tenant_collections):
            status = "✅ 保留" if should_keep(name) else "❌ 删除"
            print(f" {status} {name}")

        print("\n" + "=" * 80)
        print("开始删除...")
        print("=" * 80)

        deleted = []
        skipped = []

        for collection_name in tenant_collections:
            if should_keep(collection_name):
                print(f"\n⏭️ 跳过 {collection_name} (保留)")
                skipped.append(collection_name)
                continue

            print(f"\n🗑️ 删除 {collection_name}...")
            try:
                await client.delete_collection(collection_name)
                print(f" ✅ 已删除")
                deleted.append(collection_name)
            except Exception as e:
                # Best-effort: report and continue with the next one.
                print(f" ❌ 删除失败: {e}")

        print("\n" + "=" * 80)
        print("清理完成!")
        print("=" * 80)
        print(f"\n📈 统计:")
        print(f" 保留: {len(skipped)} 个")
        for name in skipped:
            print(f" - {name}")
        print(f"\n 删除: {len(deleted)} 个")
        for name in deleted:
            print(f" - {name}")

    except Exception as e:
        print(f"\n❌ 错误: {e}")
        import traceback
        traceback.print_exc()
    finally:
        await client.close()


if __name__ == "__main__":
    # Safety confirmation before the destructive run.
    # NOTE(review): the collection names printed here are hard-coded and
    # may drift from what the scan actually finds — verify before running.
    print("=" * 80)
    print("⚠️ 警告: 此操作将永久删除以下 collections:")
    print(" - kb_szmp_ash_2026")
    print(" - kb_szmp_ash_2026_fa4c1d61")
    print(" - kb_szmp_ash_2026_3ddf0ce7")
    print("\n 保留:")
    print(" - kb_szmp_ash_2026_8559ebc9")
    print(" - kb_szmp_ash_2026_30c19c84")
    print("=" * 80)
    print("\n确认删除? (yes/no): ", end="")

    # Auto-confirm for non-interactive environments (CI, cron).
    import os
    if os.environ.get('AUTO_CONFIRM') == 'true':
        response = 'yes'
        print('yes (auto)')
    else:
        response = input().strip().lower()

    if response in ('yes', 'y'):
        asyncio.run(cleanup_collections())
    else:
        print("\n❌ 已取消")
||||
|
|
@ -0,0 +1,42 @@
|
|||
"""
|
||||
删除课程知识库的 Qdrant Collection
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient
|
||||
|
||||
|
||||
async def delete_course_kb_collection():
    """Delete the Qdrant collection for one hard-coded course KB.

    Resolves the collection name from the tenant id and kb id via the
    project's ``QdrantClient`` naming helper, then drops the collection
    if it exists. One-off cleanup utility; prints progress to stdout.
    """
    # Fix: removed unused local `settings = get_settings()`.
    client = QdrantClient()
    qdrant = await client.get_client()

    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

    collection_name = client.get_kb_collection_name(tenant_id, kb_id)

    print(f"\n{'='*80}")
    print(f"删除课程知识库的 Qdrant Collection")
    print(f"{'='*80}")
    print(f"租户 ID: {tenant_id}")
    print(f"知识库 ID: {kb_id}")
    print(f"Collection 名称: {collection_name}")

    # Idempotent: check existence first so a second run is a no-op.
    exists = await qdrant.collection_exists(collection_name)
    if exists:
        await qdrant.delete_collection(collection_name)
        print(f"\n✅ Collection {collection_name} 已删除!")
    else:
        print(f"\n⚠️ Collection {collection_name} 不存在,无需删除")


if __name__ == "__main__":
    asyncio.run(delete_course_kb_collection())
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
"""
|
||||
删除课程知识库的文档记录
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.core.config import get_settings
|
||||
from app.models.entities import Document, IndexJob
|
||||
|
||||
|
||||
async def delete_course_kb_documents():
    """Delete all Document rows (and their IndexJobs) for one course KB.

    Lists matching documents for visibility, deletes the dependent
    ``IndexJob`` rows first, then the ``Document`` rows, and commits.
    One-off cleanup utility; prints progress to stdout.
    """
    settings = get_settings()

    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

    print(f"\n{'='*80}")
    print(f"删除课程知识库的文档记录")
    print(f"{'='*80}")
    print(f"租户 ID: {tenant_id}")
    print(f"知识库 ID: {kb_id}")

    try:
        async with async_session() as session:
            stmt = select(Document).where(
                Document.tenant_id == tenant_id,
                Document.kb_id == kb_id,
            )
            result = await session.execute(stmt)
            documents = result.scalars().all()

            print(f"\n找到 {len(documents)} 个文档记录")

            if not documents:
                print("没有需要删除的文档记录")
                return

            # Show at most 5 names so the operator can sanity-check.
            for doc in documents[:5]:
                print(f" - {doc.file_name} (id: {doc.id})")
            if len(documents) > 5:
                print(f" ... 还有 {len(documents) - 5} 个文档")

            doc_ids = [doc.id for doc in documents]

            # Delete dependent index jobs before the documents themselves.
            index_job_stmt = delete(IndexJob).where(
                IndexJob.tenant_id == tenant_id,
                IndexJob.doc_id.in_(doc_ids),
            )
            index_job_result = await session.execute(index_job_stmt)
            print(f"\n删除了 {index_job_result.rowcount} 个索引任务记录")

            doc_stmt = delete(Document).where(
                Document.tenant_id == tenant_id,
                Document.kb_id == kb_id,
            )
            doc_result = await session.execute(doc_stmt)
            print(f"删除了 {doc_result.rowcount} 个文档记录")

            await session.commit()

            print(f"\n✅ 删除完成!")
    finally:
        # Fix: dispose the engine so pooled connections are released.
        await engine.dispose()


if __name__ == "__main__":
    asyncio.run(delete_course_kb_documents())
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
"""
|
||||
获取数据库中的 API key
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.models.entities import ApiKey
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
async def get_api_keys():
    """Print every active API key in the database.

    Debug utility — connects with the configured database URL, selects
    active ``ApiKey`` rows, and prints key, name, tenant, and status.

    NOTE(review): this prints raw API key secrets to stdout; keep the
    output out of shared logs.
    """
    settings = get_settings()
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    try:
        async with async_session() as session:
            result = await session.execute(
                # Idiom fix: `.is_(True)` instead of `== True` (flake8 E712).
                select(ApiKey).where(ApiKey.is_active.is_(True))
            )
            keys = result.scalars().all()

            print("=" * 80)
            print("数据库中的 API Keys:")
            print("=" * 80)
            for key in keys:
                print(f"\nKey: {key.key}")
                print(f" Name: {key.name}")
                print(f" Tenant: {key.tenant_id}")
                print(f" Active: {key.is_active}")
            print("\n" + "=" * 80)
    finally:
        # Fix: dispose the engine so pooled connections are released.
        await engine.dispose()


if __name__ == "__main__":
    asyncio.run(get_api_keys())
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
获取数据库中的意图规则
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.models.entities import IntentRule
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
async def get_intent_rules():
    """Print all enabled intent rules for the hard-coded tenant.

    Debug utility — selects enabled ``IntentRule`` rows for tenant
    ``szmp@ash@2026`` and dumps their routing configuration to stdout.
    """
    settings = get_settings()
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    try:
        async with async_session() as session:
            result = await session.execute(
                select(IntentRule).where(
                    IntentRule.tenant_id == "szmp@ash@2026",
                    # Idiom fix: `.is_(True)` instead of `== True` (E712).
                    IntentRule.is_enabled.is_(True)
                )
            )
            rules = result.scalars().all()

            print("=" * 80)
            print("数据库中的意图规则 (tenant=szmp@ash@2026):")
            print("=" * 80)

            if not rules:
                print("\n没有找到任何启用的意图规则!")
            else:
                for rule in rules:
                    print(f"\n规则: {rule.name}")
                    print(f" ID: {rule.id}")
                    print(f" 响应类型: {rule.response_type}")
                    print(f" 关键词: {rule.keywords}")
                    print(f" 正则模式: {rule.patterns}")
                    print(f" 目标知识库: {rule.target_kb_ids}")
                    print(f" 优先级: {rule.priority}")
                    print(f" 启用: {rule.is_enabled}")

            print("\n" + "=" * 80)
    finally:
        # Fix: dispose the engine so pooled connections are released.
        await engine.dispose()


if __name__ == "__main__":
    asyncio.run(get_intent_rules())
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
-- Migration: Add usage_description column to metadata_field_definitions table
-- [AC-IDSMETA-XX] Add usage description field for metadata field definitions
-- Idempotent: safe to re-run (IF NOT EXISTS).
-- Rollback: ALTER TABLE metadata_field_definitions DROP COLUMN IF EXISTS usage_description;

-- Add usage_description column
ALTER TABLE metadata_field_definitions
ADD COLUMN IF NOT EXISTS usage_description TEXT;

-- Add comment
COMMENT ON COLUMN metadata_field_definitions.usage_description IS '用途说明,描述该元数据字段的业务用途';
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
-- Migration: Add extract_strategies field to slot_definitions
-- Date: 2026-03-06
-- Issue: [AC-MRS-07-UPGRADE] Extraction-strategy system upgrade - strategy chain support
-- Idempotent: safe to re-run (IF NOT EXISTS / WHERE extract_strategies IS NULL).
-- Rollback: ALTER TABLE slot_definitions DROP COLUMN IF EXISTS extract_strategies;

-- 1. Add the extract_strategies column (JSONB array format) to slot_definitions
ALTER TABLE slot_definitions
ADD COLUMN IF NOT EXISTS extract_strategies JSONB DEFAULT NULL;

-- 2. Add a column comment
COMMENT ON COLUMN slot_definitions.extract_strategies IS
'[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功';

-- 3. Data migration: convert the legacy extract_strategy value into a
--    one-element extract_strategies array.
--    Note: the legacy column is kept for backward-compatible reads.
UPDATE slot_definitions
SET extract_strategies = CASE
    WHEN extract_strategy IS NOT NULL THEN jsonb_build_array(extract_strategy)
    ELSE NULL
END
WHERE extract_strategies IS NULL;

-- 4. Optional GIN index to support strategy queries (enable if needed)
-- CREATE INDEX IF NOT EXISTS idx_slot_definitions_extract_strategies
-- ON slot_definitions USING GIN (extract_strategies);

-- 5. Verify the migration result
-- SELECT
--     id,
--     slot_key,
--     extract_strategy as old_strategy,
--     extract_strategies as new_strategies
-- FROM slot_definitions
-- WHERE extract_strategy IS NOT NULL;
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
-- Add metadata field to documents table
-- Migration: 010_add_metadata_to_documents
-- Description: Add metadata JSON field to store document-level metadata
-- Idempotent: safe to re-run (IF NOT EXISTS).
-- Rollback: ALTER TABLE documents DROP COLUMN IF EXISTS metadata;

ALTER TABLE documents ADD COLUMN IF NOT EXISTS metadata JSONB;
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
"""
|
||||
Migration script to add intent_vector and semantic_examples fields to intent_rules table.
|
||||
[v0.8.0] Hybrid routing - Intent vector fields
|
||||
|
||||
Run this script with: python scripts/migrations/011_add_intent_vector_fields.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
from sqlalchemy import text
|
||||
from app.core.database import async_session_maker
|
||||
|
||||
|
||||
async def run_migration():
    """Add the v0.8.0 hybrid-routing columns to the database.

    Adds ``intent_vector`` and ``semantic_examples`` to ``intent_rules``
    and ``route_trace`` to ``chat_messages``. Each statement is
    idempotent (``IF NOT EXISTS``); "already exists"/"duplicate" errors
    are reported and skipped, anything else is re-raised.
    """
    statements = [
        """
        ALTER TABLE intent_rules
        ADD COLUMN IF NOT EXISTS intent_vector JSONB;
        """,
        """
        ALTER TABLE intent_rules
        ADD COLUMN IF NOT EXISTS semantic_examples JSONB;
        """,
        """
        ALTER TABLE chat_messages
        ADD COLUMN IF NOT EXISTS route_trace JSONB;
        """,
    ]

    async with async_session_maker() as session:
        for i, statement in enumerate(statements, 1):
            try:
                await session.execute(text(statement))
                print(f"[{i}] Executed successfully")
            except Exception as e:
                # Tolerate re-runs on databases where a column was
                # already added; fail loudly on anything else.
                if "already exists" in str(e).lower() or "duplicate" in str(e).lower():
                    print(f"[{i}] Skipped (already exists): {str(e)[:50]}...")
                else:
                    raise

        await session.commit()
        print("\nMigration completed successfully!")


if __name__ == "__main__":
    asyncio.run(run_migration())
else:
    # Fix: the original called sys.exit(1) here, which killed any process
    # that merely imported this module. Importing is now a harmless no-op
    # with a warning.
    print("Please run this script directly, not imported.")
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
-- Add intent_vector and semantic_examples fields to intent_rules table
-- [v0.8.0] Hybrid routing support for semantic matching
-- Idempotent: safe to re-run (IF NOT EXISTS).
-- Rollback:
--   ALTER TABLE intent_rules DROP COLUMN IF EXISTS intent_vector;
--   ALTER TABLE intent_rules DROP COLUMN IF EXISTS semantic_examples;
--   ALTER TABLE chat_messages DROP COLUMN IF EXISTS route_trace;

-- Add intent_vector column (JSONB for storing pre-computed embedding vectors)
ALTER TABLE intent_rules
ADD COLUMN IF NOT EXISTS intent_vector JSONB;

-- Add semantic_examples column (JSONB for storing example sentences for dynamic vector computation)
ALTER TABLE intent_rules
ADD COLUMN IF NOT EXISTS semantic_examples JSONB;

-- Add comments for documentation
COMMENT ON COLUMN intent_rules.intent_vector IS '[v0.8.0] Pre-computed intent vector for semantic matching';
COMMENT ON COLUMN intent_rules.semantic_examples IS '[v0.8.0] Semantic example sentences for dynamic vector computation';

-- Add route_trace column to chat_messages table if not exists
ALTER TABLE chat_messages
ADD COLUMN IF NOT EXISTS route_trace JSONB;

COMMENT ON COLUMN chat_messages.route_trace IS '[v0.8.0] Intent routing trace log for hybrid routing observability';
|
||||
|
|
@ -0,0 +1,204 @@
|
|||
"""
|
||||
详细性能分析 - 确认每个环节的耗时
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import get_qdrant_client
|
||||
from app.services.embedding import get_embedding_provider
|
||||
|
||||
|
||||
async def profile_detailed():
    """Profile each stage of the dynamic KB search pipeline.

    Times, in order: embedding generation, collection-list lookup (with
    Redis cache probe), Qdrant filter construction and per-collection
    search, then the end-to-end ``KbSearchDynamicTool`` flow, and prints
    a summary. Debug utility; output goes to stdout.
    """
    import json  # Fix: hoisted (was imported separately in two branches)

    settings = get_settings()

    print("=" * 80)
    print("详细性能分析")
    print("=" * 80)

    query = "三年级语文学习"
    tenant_id = "szmp@ash@2026"
    metadata_filter = {"grade": "三年级", "subject": "语文"}

    # 1. Embedding generation (provider expected to be pre-initialized)
    print("\n📊 1. Embedding 生成")
    print("-" * 80)
    start = time.time()
    embedding_service = await get_embedding_provider()
    init_time = (time.time() - start) * 1000

    start = time.time()
    embedding_result = await embedding_service.embed_query(query)
    embed_time = (time.time() - start) * 1000

    # Unwrap the vector — result shape differs per provider.
    if hasattr(embedding_result, 'embedding_full'):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, 'embedding'):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result

    print(f" 获取服务实例: {init_time:.2f} ms")
    print(f" Embedding 生成: {embed_time:.2f} ms")
    print(f" 向量维度: {len(query_vector)}")

    # 2. Collection list (Redis cache first, Qdrant on miss)
    print("\n📊 2. 获取 collections 列表")
    print("-" * 80)
    client = await get_qdrant_client()
    qdrant_client = await client.get_client()

    start = time.time()
    from app.services.metadata_cache_service import get_metadata_cache_service
    cache_service = await get_metadata_cache_service()
    cache_key = f"collections:{tenant_id}"

    # NOTE(review): uses the cache service's private `_get_redis`/`_enabled`
    # members — acceptable in a throwaway profiler, not for production code.
    redis_client = await cache_service._get_redis()
    cache_hit = False
    if redis_client and cache_service._enabled:
        cached = await redis_client.get(cache_key)
        if cached:
            tenant_collections = json.loads(cached)
            cache_hit = True
            cache_time = (time.time() - start) * 1000
            print(f" ✅ 缓存命中: {cache_time:.2f} ms")
            print(f" Collections: {tenant_collections}")

    if not cache_hit:
        start = time.time()
        collections = await qdrant_client.get_collections()
        safe_tenant_id = tenant_id.replace('@', '_')
        prefix = f"kb_{safe_tenant_id}"
        tenant_collections = [
            c.name for c in collections.collections
            if c.name.startswith(prefix)
        ]
        tenant_collections.sort()
        db_time = (time.time() - start) * 1000
        print(f" ❌ 缓存未命中,从 Qdrant 查询: {db_time:.2f} ms")
        print(f" Collections: {tenant_collections}")

        # Populate the cache for subsequent runs (5-minute TTL).
        if redis_client and cache_service._enabled:
            await redis_client.setex(cache_key, 300, json.dumps(tenant_collections))
            print(f" 已缓存到 Redis")

    # 3. Qdrant search with the metadata filter
    print("\n📊 3. Qdrant 搜索")
    print("-" * 80)
    from qdrant_client.models import FieldCondition, Filter, MatchValue

    start = time.time()
    must_conditions = []
    for key, value in metadata_filter.items():
        field_path = f"metadata.{key}"
        condition = FieldCondition(
            key=field_path,
            match=MatchValue(value=value),
        )
        must_conditions.append(condition)
    qdrant_filter = Filter(must=must_conditions) if must_conditions else None
    filter_time = (time.time() - start) * 1000
    print(f" 构建 filter: {filter_time:.2f} ms")

    total_search_time = 0
    for collection_name in tenant_collections:
        print(f"\n Collection: {collection_name}")

        start = time.time()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = (time.time() - start) * 1000
        print(f" 检查存在: {check_time:.2f} ms")

        if not exists:
            print(f" ❌ 不存在")
            continue

        start = time.time()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                limit=5,
                score_threshold=0.5,
                query_filter=qdrant_filter,
            )
            search_time = (time.time() - start) * 1000
            total_search_time += search_time
            print(f" 搜索时间: {search_time:.2f} ms")
            print(f" 结果数: {len(results.points)}")
        except Exception as e:
            print(f" ❌ 搜索失败: {e}")

    print(f"\n 总搜索时间: {total_search_time:.2f} ms")

    # 4. Full KB search flow through the tool
    print("\n📊 4. 完整 KB Search 流程")
    print("-" * 80)

    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    try:
        async with async_session() as session:
            config = KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=30000,
                min_score_threshold=0.5,
            )

            tool = KbSearchDynamicTool(session=session, config=config)

            start = time.time()
            result = await tool.execute(
                query=query,
                tenant_id=tenant_id,
                scene="学习方案",
                top_k=5,
                context=metadata_filter,
            )
            total_time = (time.time() - start) * 1000

            print(f" 总耗时: {total_time:.2f} ms")
            print(f" 工具内部耗时: {result.duration_ms} ms")
            print(f" 结果数: {len(result.hits)}")
    finally:
        # Fix: dispose the engine so pooled connections are released.
        await engine.dispose()

    # 5. Summary
    print("\n" + "=" * 80)
    print("📈 耗时总结")
    print("=" * 80)
    print(f"\n各环节耗时:")
    print(f" Embedding 获取服务: {init_time:.2f} ms")
    print(f" Embedding 生成: {embed_time:.2f} ms")
    print(f" Collections 获取: {cache_time if cache_hit else db_time:.2f} ms")
    print(f" Filter 构建: {filter_time:.2f} ms")
    print(f" Qdrant 搜索: {total_search_time:.2f} ms")
    print(f" 完整流程: {total_time:.2f} ms")

    other_time = total_time - embed_time - (cache_time if cache_hit else db_time) - filter_time - total_search_time
    print(f" 其他开销: {other_time:.2f} ms")

    print("\n" + "=" * 80)


if __name__ == "__main__":
    asyncio.run(profile_detailed())
|
||||
|
|
@ -0,0 +1,246 @@
|
|||
"""
|
||||
详细分析完整参数查询的耗时
|
||||
对比带 metadata_filter 和不带的区别
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import get_qdrant_client
|
||||
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
||||
|
||||
|
||||
async def profile_step_by_step():
    """Compare KB search cost with and without a metadata filter.

    Times embedding generation, collection listing, filter construction,
    per-collection Qdrant search (filtered vs unfiltered), and the full
    ``KbSearchDynamicTool`` flow (with vs without ``context``), then
    prints a comparison summary. Debug utility; output goes to stdout.
    """
    settings = get_settings()

    print("=" * 80)
    print("完整参数查询耗时分析")
    print("=" * 80)

    query = "三年级语文学习"
    tenant_id = "szmp@ash@2026"
    metadata_filter = {"grade": "三年级", "subject": "语文"}

    # 1. Embedding generation
    print("\n📊 1. Embedding 生成")
    print("-" * 80)
    from app.services.embedding import get_embedding_provider

    start = time.time()
    embedding_service = await get_embedding_provider()
    init_time = time.time() - start

    start = time.time()
    embedding_result = await embedding_service.embed_query(query)
    embed_time = (time.time() - start) * 1000

    # Unwrap the vector — result shape differs per provider.
    if hasattr(embedding_result, 'embedding_full'):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, 'embedding'):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result

    print(f" 初始化时间: {init_time * 1000:.2f} ms")
    print(f" Embedding 生成: {embed_time:.2f} ms")
    print(f" 向量维度: {len(query_vector)}")

    # 2. Collection listing
    print("\n📊 2. 获取 collections 列表")
    print("-" * 80)
    client = await get_qdrant_client()
    qdrant_client = await client.get_client()

    start = time.time()
    collections = await qdrant_client.get_collections()
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = f"kb_{safe_tenant_id}"
    tenant_collections = [
        c.name for c in collections.collections
        if c.name.startswith(prefix)
    ]
    list_time = (time.time() - start) * 1000

    print(f" 获取 collections: {list_time:.2f} ms")
    print(f" Collections: {tenant_collections}")

    # 3. Metadata filter construction
    print("\n📊 3. 构建 metadata filter")
    print("-" * 80)
    start = time.time()

    must_conditions = []
    for key, value in metadata_filter.items():
        field_path = f"metadata.{key}"
        condition = FieldCondition(
            key=field_path,
            match=MatchValue(value=value),
        )
        must_conditions.append(condition)
    qdrant_filter = Filter(must=must_conditions) if must_conditions else None

    filter_time = (time.time() - start) * 1000
    print(f" 构建 filter: {filter_time:.2f} ms")
    print(f" Filter: {qdrant_filter}")

    # 4. Filtered per-collection search
    print("\n📊 4. Qdrant 搜索(带 metadata filter)")
    print("-" * 80)

    total_search_time = 0
    total_results = 0

    for collection_name in tenant_collections:
        print(f"\n Collection: {collection_name}")

        start = time.time()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = (time.time() - start) * 1000
        print(f" 检查存在: {check_time:.2f} ms")

        if not exists:
            print(f" ❌ 不存在")
            continue

        start = time.time()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                limit=5,
                score_threshold=0.5,
                query_filter=qdrant_filter,
            )
            search_time = (time.time() - start) * 1000
            total_search_time += search_time
            total_results += len(results.points)

            print(f" 搜索时间: {search_time:.2f} ms")
            print(f" 结果数: {len(results.points)}")
        except Exception as e:
            print(f" ❌ 搜索失败: {e}")

    print(f"\n 总搜索时间: {total_search_time:.2f} ms")
    print(f" 总结果数: {total_results}")

    # 5. Unfiltered search for comparison
    print("\n📊 5. Qdrant 搜索(不带 metadata filter)对比")
    print("-" * 80)

    total_search_time_no_filter = 0
    total_results_no_filter = 0

    for collection_name in tenant_collections:
        start = time.time()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                limit=5,
                score_threshold=0.5,
            )
            search_time = (time.time() - start) * 1000
            total_search_time_no_filter += search_time
            total_results_no_filter += len(results.points)
        except Exception as e:
            print(f" {collection_name}: 失败 {e}")

    print(f" 总搜索时间(无 filter): {total_search_time_no_filter:.2f} ms")
    print(f" 总结果数(无 filter): {total_results_no_filter}")

    # 6. Full tool flow with context
    print("\n📊 6. 完整 KB Search 流程(带 context)")
    print("-" * 80)

    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    try:
        async with async_session() as session:
            config = KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=30000,
                min_score_threshold=0.5,
            )

            tool = KbSearchDynamicTool(session=session, config=config)

            start = time.time()
            result = await tool.execute(
                query=query,
                tenant_id=tenant_id,
                scene="学习方案",
                top_k=5,
                context=metadata_filter,
            )
            total_time = (time.time() - start) * 1000

            print(f" 总耗时: {total_time:.2f} ms")
            print(f" 工具内部耗时: {result.duration_ms} ms")
            print(f" 结果数: {len(result.hits)}")
            print(f" 应用的 filter: {result.applied_filter}")

        # 7. Full tool flow without context
        print("\n📊 7. 完整 KB Search 流程(不带 context)")
        print("-" * 80)

        async with async_session() as session:
            tool = KbSearchDynamicTool(session=session, config=config)

            start = time.time()
            result = await tool.execute(
                query=query,
                tenant_id=tenant_id,
                scene="学习方案",
                top_k=5,
            )
            total_time_no_context = (time.time() - start) * 1000

            print(f" 总耗时: {total_time_no_context:.2f} ms")
            print(f" 工具内部耗时: {result.duration_ms} ms")
            print(f" 结果数: {len(result.hits)}")
    finally:
        # Fix: dispose the engine so pooled connections are released.
        await engine.dispose()

    # 8. Summary
    print("\n" + "=" * 80)
    print("📈 耗时分析总结")
    print("=" * 80)

    print(f"\n带 metadata filter:")
    print(f" Embedding: {embed_time:.2f} ms")
    print(f" 获取 collections: {list_time:.2f} ms")
    print(f" Qdrant 搜索: {total_search_time:.2f} ms")
    print(f" 完整流程: {total_time:.2f} ms")

    print(f"\n不带 metadata filter:")
    print(f" Qdrant 搜索: {total_search_time_no_filter:.2f} ms")
    print(f" 完整流程: {total_time_no_context:.2f} ms")

    print(f"\nMetadata filter 额外开销:")
    print(f" Qdrant 搜索: {total_search_time - total_search_time_no_filter:.2f} ms")
    print(f" 完整流程: {total_time - total_time_no_context:.2f} ms")

    if total_search_time > total_search_time_no_filter:
        print(f"\n⚠️ 带 filter 的搜索更慢,可能原因:")
        print(f" - Filter 增加了索引查找的复杂度")
        print(f" - 需要匹配 metadata 字段")
        print(f" - 建议: 检查 Qdrant 的 payload 索引配置")


if __name__ == "__main__":
    asyncio.run(profile_step_by_step())
|
||||
|
|
@ -0,0 +1,259 @@
|
|||
"""
|
||||
知识库检索性能分析脚本
|
||||
详细分析每个环节的耗时
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
|
||||
from app.services.retrieval.vector_retriever import VectorRetriever
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient
|
||||
|
||||
|
||||
async def profile_embedding_generation(query: str):
    """Time provider initialization and embedding generation for *query*.

    Returns a dict with ``init_time_ms``, ``embed_time_ms`` and the
    vector ``dimension``.
    """
    from app.services.embedding import get_embedding_provider

    t0 = time.time()
    provider = await get_embedding_provider()
    provider_secs = time.time() - t0

    t0 = time.time()
    result = await provider.embed_query(query)
    embed_secs = time.time() - t0

    # Unwrap the raw vector — provider result shapes differ.
    vector = result
    for attr in ('embedding_full', 'embedding'):
        if hasattr(result, attr):
            vector = getattr(result, attr)
            break

    return {
        "init_time_ms": provider_secs * 1000,
        "embed_time_ms": embed_secs * 1000,
        "dimension": len(vector),
    }
|
||||
|
||||
|
||||
async def profile_qdrant_search(tenant_id: str, query_vector: list, metadata_filter: dict | None = None):
    """Time Qdrant collection listing and per-collection vector search.

    Args:
        tenant_id: Tenant whose ``kb_<tenant>`` collections are searched.
        query_vector: Pre-computed query embedding.
        metadata_filter: Optional exact-match payload filter
            (``metadata.<key> == value`` for each entry).

    Returns:
        Dict with ``list_collections_time_ms``, ``collections_count``,
        and per-collection timing entries in ``collection_times``.
    """
    from app.core.qdrant_client import get_qdrant_client

    client = await get_qdrant_client()

    # List this tenant's collections.
    start = time.time()
    qdrant_client = await client.get_client()
    collections = await qdrant_client.get_collections()
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = f"kb_{safe_tenant_id}"
    tenant_collections = [
        c.name for c in collections.collections
        if c.name.startswith(prefix)
    ]
    list_collections_time = time.time() - start

    # Fix: the filter is loop-invariant — build it (and import the models)
    # once instead of once per collection, so per-collection timings
    # measure the search itself.
    qdrant_filter = None
    if metadata_filter:
        from qdrant_client.models import FieldCondition, Filter, MatchValue
        must_conditions = []
        for key, value in metadata_filter.items():
            field_path = f"metadata.{key}"
            condition = FieldCondition(
                key=field_path,
                match=MatchValue(value=value),
            )
            must_conditions.append(condition)
        qdrant_filter = Filter(must=must_conditions) if must_conditions else None

    # Search each collection in turn.
    collection_times = []
    for collection_name in tenant_collections:
        start = time.time()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = time.time() - start

        if not exists:
            collection_times.append({
                "collection": collection_name,
                "exists": False,
                "time_ms": check_time * 1000,
            })
            continue

        start = time.time()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",  # prefer the named "full" vector
                limit=5,
                score_threshold=0.5,
                query_filter=qdrant_filter,
            )
        except Exception as e:
            if "vector name" in str(e).lower():
                # Fallback for collections with an unnamed default vector.
                results = await qdrant_client.query_points(
                    collection_name=collection_name,
                    query=query_vector,
                    limit=5,
                    score_threshold=0.5,
                    query_filter=qdrant_filter,
                )
            else:
                raise
        search_time = time.time() - start

        collection_times.append({
            "collection": collection_name,
            "exists": True,
            "check_time_ms": check_time * 1000,
            "search_time_ms": search_time * 1000,
            "results_count": len(results.points),
        })

    return {
        "list_collections_time_ms": list_collections_time * 1000,
        "collections_count": len(tenant_collections),
        "collection_times": collection_times,
    }
|
||||
|
||||
|
||||
async def profile_full_kb_search():
    """Run an end-to-end latency breakdown of the KB search pipeline.

    Stages measured and printed to stdout:
      1. embedding generation,
      2. raw Qdrant search per collection,
      3. the full ``KbSearchDynamicTool.execute()`` call,
      4. a bottleneck summary with tuning suggestions.

    Fixes vs. the original: removed the dead locals (``stages`` and an
    unused ``start`` timestamp) and the engine is now disposed so the
    connection pool is not leaked when the script exits.
    """
    settings = get_settings()

    print("=" * 80)
    print("知识库检索性能分析")
    print("=" * 80)

    # 1. Embedding generation
    print("\n📊 1. Embedding 生成分析")
    print("-" * 80)
    query = "三年级语文学习"
    embed_result = await profile_embedding_generation(query)
    print(f"  初始化时间: {embed_result['init_time_ms']:.2f} ms")
    print(f"  Embedding 生成时间: {embed_result['embed_time_ms']:.2f} ms")
    print(f"  向量维度: {embed_result['dimension']}")

    # 2. Raw Qdrant search
    print("\n📊 2. Qdrant 搜索分析")
    print("-" * 80)

    # Generate the query embedding first (providers differ in return shape).
    from app.services.embedding import get_embedding_provider
    embedding_service = await get_embedding_provider()
    embedding_result = await embedding_service.embed_query(query)
    if hasattr(embedding_result, 'embedding_full'):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, 'embedding'):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result

    tenant_id = "szmp@ash@2026"
    metadata_filter = {"grade": "三年级", "subject": "语文"}

    qdrant_result = await profile_qdrant_search(tenant_id, query_vector, metadata_filter)
    print(f"  获取 collections 列表时间: {qdrant_result['list_collections_time_ms']:.2f} ms")
    print(f"  Collections 数量: {qdrant_result['collections_count']}")
    print(f"\n  各 Collection 搜索耗时:")
    for ct in qdrant_result['collection_times']:
        if ct['exists']:
            print(f"    - {ct['collection']}: {ct['search_time_ms']:.2f} ms (结果: {ct['results_count']} 条)")
        else:
            print(f"    - {ct['collection']}: 不存在 ({ct['time_ms']:.2f} ms)")

    total_search_time = sum(
        ct.get('search_time_ms', 0) for ct in qdrant_result['collection_times']
    )
    print(f"\n  总搜索时间(串行): {total_search_time:.2f} ms")

    # 3. Full pipeline through the tool
    print("\n📊 3. 完整 KB Search 流程分析")
    print("-" * 80)

    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    try:
        async with async_session() as session:
            config = KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=30000,  # 30 s: generous, so we measure work rather than timeouts
                min_score_threshold=0.5,
            )

            tool = KbSearchDynamicTool(session=session, config=config)

            start_total = time.time()
            result = await tool.execute(
                query=query,
                tenant_id=tenant_id,
                scene="学习方案",
                top_k=5,
                context=metadata_filter,
            )
            total_time = (time.time() - start_total) * 1000

            print(f"  总耗时: {total_time:.2f} ms")
            print(f"  结果: success={result.success}, hits={len(result.hits)}")
            print(f"  工具内部耗时: {result.duration_ms} ms")

            # Gap between our wall-clock measurement and the tool's own timer.
            overhead = total_time - result.duration_ms
            print(f"  额外开销(初始化等): {overhead:.2f} ms")
    finally:
        # Fix: the original never disposed the engine (pool leak).
        await engine.dispose()

    # 4. Bottleneck summary
    print("\n📊 4. 性能瓶颈分析")
    print("-" * 80)

    embedding_time = embed_result['embed_time_ms']
    qdrant_time = total_search_time
    total_measured = embedding_time + qdrant_time

    print(f"  Embedding 生成: {embedding_time:.2f} ms ({embedding_time/total_measured*100:.1f}%)")
    print(f"  Qdrant 搜索: {qdrant_time:.2f} ms ({qdrant_time/total_measured*100:.1f}%)")
    print(f"  其他开销: {total_time - total_measured:.2f} ms")

    print("\n" + "=" * 80)
    print("优化建议:")
    print("=" * 80)

    if embedding_time > 1000:
        print("  ⚠️ Embedding 生成较慢,考虑:")
        print("    - 使用更快的 embedding 模型")
        print("    - 增加 embedding 服务缓存")

    if qdrant_time > 1000:
        print("  ⚠️ Qdrant 搜索较慢,考虑:")
        print("    - 并行查询多个 collections")
        print("    - 优化 Qdrant 索引")
        print("    - 减少 collections 数量")

    if len(qdrant_result['collection_times']) > 3:
        print(f"  ⚠️ Collections 数量较多 ({len(qdrant_result['collection_times'])} 个)")
        print("    - 建议合并或归档空/少数据的 collections")
|
||||
|
||||
|
||||
# Script entry point: run the full profiling pass once.
if __name__ == "__main__":
    asyncio.run(profile_full_kb_search())
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
"""
|
||||
查询 Qdrant Collection 中的所有内容
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from qdrant_client.models import ScrollRequest
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
async def query_all_points(collection_name: str):
    """Dump every point in a Qdrant collection to stdout.

    Pages through the collection with the scroll API (payload only, no
    vectors), prints each point's key fields and chunk text, then prints
    a per-``kb_id`` count summary. The client is always closed on exit.

    Args:
        collection_name: Name of the Qdrant collection to inspect.
    """
    from collections import Counter

    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url, check_compatibility=False)

    print(f"🔍 查询 Collection: {collection_name}")
    print("=" * 80)

    try:
        # Collection stats first, so the progress indicator has a total.
        info = await client.get_collection(collection_name)
        total_points = info.points_count
        print(f"📊 总向量数: {total_points}\n")

        # Page through the collection; scroll() returns (points, next_offset).
        all_points = []
        offset = None
        batch_size = 100

        while True:
            scroll_result = await client.scroll(
                collection_name=collection_name,
                offset=offset,
                limit=batch_size,
                with_payload=True,
                with_vectors=False
            )

            points, next_offset = scroll_result
            all_points.extend(points)

            if next_offset is None:
                break
            offset = next_offset

            # Progress indicator for large collections.
            if len(all_points) % 500 == 0:
                print(f"  已获取 {len(all_points)} / {total_points} 条记录...")

        print(f"✅ 成功获取全部 {len(all_points)} 条记录\n")
        print("=" * 80)

        # Print every record.
        for i, point in enumerate(all_points, 1):
            payload = point.payload or {}

            print(f"\n📄 记录 {i}/{len(all_points)} (ID: {point.id})")
            print("-" * 80)

            # Main payload fields.
            text = payload.get('text', '')
            kb_id = payload.get('kb_id', 'N/A')
            source = payload.get('source', 'N/A')
            chunk_index = payload.get('chunk_index', 'N/A')
            metadata = payload.get('metadata', {})

            print(f"  KB ID: {kb_id}")
            print(f"  Source: {source}")
            print(f"  Chunk Index: {chunk_index}")

            if metadata:
                print(f"  Metadata: {json.dumps(metadata, ensure_ascii=False)}")

            # Chunk text, line by line, skipping blank lines to keep output tight.
            print(f"\n  文本内容:")
            if text:
                lines = text.split('\n')
                for line in lines:
                    if line.strip():
                        print(f"    {line}")
            else:
                print("    (无文本内容)")

        print("\n" + "=" * 80)
        print(f"✅ 查询完成,共 {len(all_points)} 条记录")

        # Summary: points per kb_id. Counter replaces the original
        # hand-rolled dict-increment loop.
        print("\n📈 统计信息:")
        kb_ids = Counter(
            (point.payload or {}).get('kb_id', 'N/A') for point in all_points
        )

        print(f"  KB ID 分布:")
        for kb_id, count in sorted(kb_ids.items()):
            print(f"    - {kb_id}: {count} 条")

    except Exception as e:
        print(f"❌ 查询失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        await client.close()
|
||||
|
||||
|
||||
async def main():
    """Entry point: dump the contents of the target collection."""
    await query_all_points("kb_szmp_ash_2026_30c19c84")
|
||||
|
||||
|
||||
# Script entry point: query the hard-coded collection once.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
|
|
@ -0,0 +1,305 @@
|
|||
"""
|
||||
恢复处理中断的索引任务
|
||||
用于服务重启后继续处理pending/processing状态的任务
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy import select
|
||||
from app.core.database import async_session_maker
|
||||
from app.models.entities import IndexJob, Document, IndexJobStatus, DocumentStatus
|
||||
from app.api.admin.kb import _index_document
|
||||
|
||||
# Script-level logging: timestamped INFO-and-above lines for console use.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def resume_pending_jobs():
    """Resume every index job stuck in PENDING or PROCESSING state.

    For each unfinished job: re-read its document, mark both failed if the
    source file is gone, otherwise reset the job to PENDING and re-run the
    indexing pipeline synchronously via :func:`process_job`.

    Bug fix: ``doc`` is now reset to ``None`` at the top of each iteration.
    Previously, if an exception fired before ``doc`` was assigned, the
    except-handler either raised NameError (first iteration) or marked the
    *previous* iteration's document as failed (stale binding).
    """
    async with async_session_maker() as session:
        # All jobs that never reached a terminal state.
        result = await session.execute(
            select(IndexJob).where(
                IndexJob.status.in_([IndexJobStatus.PENDING.value, IndexJobStatus.PROCESSING.value])
            )
        )
        pending_jobs = result.scalars().all()

        if not pending_jobs:
            logger.info("没有需要恢复的任务")
            return

        logger.info(f"发现 {len(pending_jobs)} 个未完成的任务")

        for job in pending_jobs:
            doc = None  # reset so the except-handler never sees a stale document
            try:
                # Load the document this job indexes.
                doc_result = await session.execute(
                    select(Document).where(Document.id == job.doc_id)
                )
                doc = doc_result.scalar_one_or_none()

                if not doc:
                    logger.error(f"找不到文档: {job.doc_id}")
                    continue

                if not doc.file_path or not Path(doc.file_path).exists():
                    logger.error(f"文档文件不存在: {doc.file_path}")
                    # Source file is gone — nothing to retry; fail both records.
                    job.status = IndexJobStatus.FAILED.value
                    job.error_msg = "文档文件不存在"
                    doc.status = DocumentStatus.FAILED.value
                    doc.error_msg = "文档文件不存在"
                    await session.commit()
                    continue

                logger.info(f"恢复处理: job_id={job.id}, doc_id={doc.id}, file={doc.file_name}")

                with open(doc.file_path, 'rb') as f:
                    file_content = f.read()

                # Reset to PENDING so process_job starts from a clean state.
                job.status = IndexJobStatus.PENDING.value
                job.progress = 0
                job.error_msg = None
                await session.commit()

                # Run synchronously — no FastAPI background_tasks available here.
                await process_job(
                    tenant_id=job.tenant_id,
                    kb_id=doc.kb_id,
                    job_id=str(job.id),
                    doc_id=str(doc.id),
                    file_content=file_content,
                    filename=doc.file_name,
                    metadata=doc.doc_metadata or {}
                )

                logger.info(f"任务处理完成: job_id={job.id}")

            except Exception as e:
                logger.error(f"处理任务失败: job_id={job.id}, error={e}")
                # Mark failed; only touch the document if we actually loaded it.
                job.status = IndexJobStatus.FAILED.value
                job.error_msg = str(e)
                if doc is not None:
                    doc.status = DocumentStatus.FAILED.value
                    doc.error_msg = str(e)
                await session.commit()

        logger.info("所有任务处理完成")
|
||||
|
||||
|
||||
async def process_job(tenant_id: str, kb_id: str, job_id: str, doc_id: str,
                      file_content: bytes, filename: str, metadata: dict):
    """
    Process a single index job end to end: decode/parse the document,
    chunk it, embed each chunk, and upsert the vectors into Qdrant,
    updating job progress along the way.

    Copied from the ``_index_document`` function in app.api.admin.kb.

    On any failure the job is marked FAILED in a fresh session (the
    original session is rolled back first).
    """
    import tempfile
    from pathlib import Path

    from qdrant_client.models import PointStruct

    from app.core.qdrant_client import get_qdrant_client
    from app.services.document import DocumentParseException, UnsupportedFormatError, parse_document
    from app.services.embedding import get_embedding_provider
    from app.services.kb import KBService
    from app.api.admin.kb import chunk_text_by_lines, TextChunk

    logger.info(f"[RESUME] Starting indexing: tenant={tenant_id}, kb_id={kb_id}, job_id={job_id}, doc_id={doc_id}")

    async with async_session_maker() as session:
        kb_service = KBService(session)
        try:
            # Mark the job as in-progress before doing any work.
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=10
            )
            await session.commit()

            parse_result = None
            text = None
            file_ext = Path(filename or "").suffix.lower()
            logger.info(f"[RESUME] File extension: {file_ext}, content size: {len(file_content)} bytes")

            # Extensions treated as plain text (decoded directly, no parser).
            text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}

            if file_ext in text_extensions or not file_ext:
                logger.info("[RESUME] Treating as text file")
                # Try common encodings in order; fall back to lossy UTF-8.
                text = None
                for encoding in ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]:
                    try:
                        text = file_content.decode(encoding)
                        logger.info(f"[RESUME] Successfully decoded with encoding: {encoding}")
                        break
                    except (UnicodeDecodeError, LookupError):
                        continue

                if text is None:
                    text = file_content.decode("utf-8", errors="replace")
            else:
                logger.info("[RESUME] Binary file detected, will parse with document parser")
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
                )
                await session.commit()

                # parse_document works on a path, so spill the bytes to a temp file.
                with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
                    tmp_file.write(file_content)
                    tmp_path = tmp_file.name

                logger.info(f"[RESUME] Temp file created: {tmp_path}")

                try:
                    logger.info(f"[RESUME] Starting document parsing for {file_ext}...")
                    parse_result = parse_document(tmp_path)
                    text = parse_result.text
                    logger.info(
                        f"[RESUME] Parsed document SUCCESS: (unknown), chars={len(text)}"
                    )
                # Any parse failure degrades to a lossy raw-bytes decode rather
                # than failing the whole job.
                except UnsupportedFormatError as e:
                    logger.error(f"[RESUME] UnsupportedFormatError: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                except DocumentParseException as e:
                    logger.error(f"[RESUME] DocumentParseException: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                except Exception as e:
                    logger.error(f"[RESUME] Unexpected parsing error: {type(e).__name__}: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                finally:
                    Path(tmp_path).unlink(missing_ok=True)

            logger.info(f"[RESUME] Final text length: {len(text)} chars")

            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=20
            )
            await session.commit()

            logger.info("[RESUME] Getting embedding provider...")
            embedding_provider = await get_embedding_provider()
            logger.info(f"[RESUME] Embedding provider: {type(embedding_provider).__name__}")

            all_chunks: list[TextChunk] = []

            # Page-aware chunking when the parser produced pages (e.g. PDF);
            # otherwise chunk the flat text by lines.
            if parse_result and parse_result.pages:
                logger.info(f"[RESUME] PDF with {len(parse_result.pages)} pages")
                for page in parse_result.pages:
                    page_chunks = chunk_text_by_lines(
                        page.text,
                        min_line_length=10,
                        source=filename,
                    )
                    for pc in page_chunks:
                        pc.page = page.page
                    all_chunks.extend(page_chunks)
            else:
                logger.info("[RESUME] Using line-based chunking")
                all_chunks = chunk_text_by_lines(
                    text,
                    min_line_length=10,
                    source=filename,
                )

            logger.info(f"[RESUME] Total chunks: {len(all_chunks)}")

            qdrant = await get_qdrant_client()
            await qdrant.ensure_kb_collection_exists(tenant_id, kb_id, use_multi_vector=True)

            # Nomic providers emit multiple vector sizes per chunk; the point
            # format differs accordingly.
            from app.services.embedding.nomic_provider import NomicEmbeddingProvider
            use_multi_vector = isinstance(embedding_provider, NomicEmbeddingProvider)
            logger.info(f"[RESUME] Using multi-vector format: {use_multi_vector}")

            import uuid
            points = []
            total_chunks = len(all_chunks)
            doc_metadata = metadata or {}

            for i, chunk in enumerate(all_chunks):
                payload = {
                    "text": chunk.text,
                    "source": doc_id,
                    "kb_id": kb_id,
                    "chunk_index": i,
                    "start_token": chunk.start_token,
                    "end_token": chunk.end_token,
                    "metadata": doc_metadata,
                }
                if chunk.page is not None:
                    payload["page"] = chunk.page
                if chunk.source:
                    payload["filename"] = chunk.source

                if use_multi_vector:
                    # Dict-shaped point carrying all three vector resolutions.
                    embedding_result = await embedding_provider.embed_document(chunk.text)
                    points.append({
                        "id": str(uuid.uuid4()),
                        "vector": {
                            "full": embedding_result.embedding_full,
                            "dim_256": embedding_result.embedding_256,
                            "dim_512": embedding_result.embedding_512,
                        },
                        "payload": payload,
                    })
                else:
                    embedding = await embedding_provider.embed(chunk.text)
                    points.append(
                        PointStruct(
                            id=str(uuid.uuid4()),
                            vector=embedding,
                            payload=payload,
                        )
                    )

                # Progress spans 20..90 across the embedding phase; commit
                # every 10 chunks (and on the last one) to limit DB churn.
                progress = 20 + int((i + 1) / total_chunks * 70)
                if i % 10 == 0 or i == total_chunks - 1:
                    await kb_service.update_job_status(
                        tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=progress
                    )
                    await session.commit()

            if points:
                logger.info(f"[RESUME] Upserting {len(points)} vectors to Qdrant...")
                if use_multi_vector:
                    await qdrant.upsert_multi_vector(tenant_id, points, kb_id=kb_id)
                else:
                    await qdrant.upsert_vectors(tenant_id, points, kb_id=kb_id)

            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.COMPLETED.value, progress=100
            )
            await session.commit()

            logger.info(
                f"[RESUME] COMPLETED: tenant={tenant_id}, kb_id={kb_id}, "
                f"job_id={job_id}, chunks={len(all_chunks)}"
            )

        except Exception as e:
            import traceback
            logger.error(f"[RESUME] FAILED: {e}\n{traceback.format_exc()}")
            await session.rollback()
            # Record the failure in a fresh session — the original one may be
            # unusable after the rollback.
            async with async_session_maker() as error_session:
                kb_service = KBService(error_session)
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.FAILED.value,
                    progress=0, error_msg=str(e)
                )
                await error_session.commit()
|
||||
|
||||
|
||||
# Script entry point: run one recovery pass over all unfinished jobs.
if __name__ == "__main__":
    logger.info("开始恢复索引任务...")
    asyncio.run(resume_pending_jobs())
    logger.info("恢复脚本执行完成")
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
"""
|
||||
Migration script to add intent_vector and semantic_examples fields.
|
||||
Run: python scripts/run_migration_011.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sqlalchemy import text
|
||||
from app.core.database import engine
|
||||
|
||||
|
||||
async def run_migration():
|
||||
"""Execute migration to add new fields."""
|
||||
migration_sql = """
|
||||
-- Add intent_vector column (JSONB for storing pre-computed embedding vectors)
|
||||
ALTER TABLE intent_rules
|
||||
ADD COLUMN IF NOT EXISTS intent_vector JSONB;
|
||||
|
||||
-- Add semantic_examples column (JSONB for storing example sentences for dynamic vector computation)
|
||||
ALTER TABLE intent_rules
|
||||
ADD COLUMN IF NOT EXISTS semantic_examples JSONB;
|
||||
|
||||
-- Add route_trace column to chat_messages table if not exists
|
||||
ALTER TABLE chat_messages
|
||||
ADD COLUMN IF NOT EXISTS route_trace JSONB;
|
||||
"""
|
||||
|
||||
async with engine.begin() as conn:
|
||||
for statement in migration_sql.strip().split(";"):
|
||||
statement = statement.strip()
|
||||
if statement and not statement.startswith("--"):
|
||||
try:
|
||||
await conn.execute(text(statement))
|
||||
print(f"Executed: {statement[:80]}...")
|
||||
except Exception as e:
|
||||
print(f"Error executing: {statement[:80]}...")
|
||||
print(f" Error: {e}")
|
||||
|
||||
print("\nMigration completed successfully!")
|
||||
|
||||
|
||||
# Script entry point: apply the migration once.
if __name__ == "__main__":
    asyncio.run(run_migration())
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
"""
|
||||
通过 API 测试知识库检索性能
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
|
||||
# Target service and credentials for the manual API test below.
API_BASE = "http://localhost:8000"
# NOTE(review): hard-coded API key committed to source — rotate this key and
# load it from an environment variable instead.
API_KEY = "oQfkSAbL8iafzyHxqb--G7zRWSOYJHvlzQxia2KpYms"
TENANT_ID = "szmp@ash@2026"
|
||||
|
||||
def test_kb_search():
    """Exercise the kb-search-dynamic endpoint over HTTP and print timing/result info."""
    banner = "=" * 80
    print(banner)
    print("测试知识库检索 API")
    print(banner)

    headers = {
        "Content-Type": "application/json",
        "X-API-Key": API_KEY,
        "X-Tenant-Id": TENANT_ID,
    }

    # Request bodies to exercise: with and without a context filter.
    base_payload = {
        "query": "三年级语文学习",
        "scene": "学习方案",
        "top_k": 5,
    }
    test_cases = [
        {
            "name": "完整参数(含context过滤)",
            "data": {**base_payload, "context": {"grade": "三年级", "subject": "语文"}},
        },
        {
            "name": "简化参数(无context)",
            "data": dict(base_payload),
        },
    ]

    for case in test_cases:
        sep = "=" * 80
        print(f"\n{sep}")
        print(f"测试: {case['name']}")
        print(f"{sep}")
        print(f"请求数据: {json.dumps(case['data'], ensure_ascii=False)}")

        try:
            t0 = time.time()
            response = requests.post(
                f"{API_BASE}/api/v1/mid/kb-search-dynamic",
                headers=headers,
                json=case['data'],
                timeout=30,
            )
            elapsed = (time.time() - t0) * 1000

            print(f"\n响应状态: {response.status_code}")
            print(f"总耗时: {elapsed:.2f} ms")

            if response.status_code != 200:
                print(f"错误: {response.text}")
            else:
                result = response.json()
                print(f"API 结果:")
                print(f"  success: {result.get('success')}")
                print(f"  hits count: {len(result.get('hits', []))}")
                print(f"  duration_ms: {result.get('duration_ms')}")
                print(f"  applied_filter: {result.get('applied_filter')}")

        except Exception as e:
            print(f"请求失败: {e}")

    print("\n" + "=" * 80)
|
||||
|
||||
|
||||
# Script entry point: run the HTTP smoke test once.
if __name__ == "__main__":
    test_kb_search()
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
"""
|
||||
测试 Redis 缓存是否正常工作
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from app.services.metadata_cache_service import get_metadata_cache_service
|
||||
|
||||
|
||||
async def test_cache():
    """Smoke-test the metadata-field Redis cache: read, write a fixture, read back."""
    rule = "=" * 80
    print(rule)
    print("测试 Redis 缓存")
    print(rule)

    cache_service = await get_metadata_cache_service()
    tenant_id = "szmp@ash@2026"

    # Step 1: is there already a cached entry for this tenant?
    print("\n📊 1. 检查缓存是否存在")
    cached = await cache_service.get_fields(tenant_id)
    if not cached:
        print("  ❌ 缓存不存在")
    else:
        print(f"  ✅ 缓存存在,包含 {len(cached)} 个字段")
        for field in cached[:3]:
            print(f"    - {field['field_key']}: {field['label']}")

    # Step 2: write a known fixture into the cache.
    print("\n📊 2. 手动设置测试缓存")

    def _enum_field(key, label, options):
        # Small factory for the repeated enum-field dict shape.
        return {
            "field_key": key,
            "label": label,
            "field_type": "enum",
            "required": False,
            "options": options,
            "default_value": None,
            "is_filterable": True,
        }

    test_fields = [
        _enum_field("grade", "年级", ["三年级", "四年级", "五年级"]),
        _enum_field("subject", "学科", ["语文", "数学", "英语"]),
    ]

    result = await cache_service.set_fields(tenant_id, test_fields, ttl=3600)
    print(f"  设置缓存结果: {result}")

    # Step 3: read back what we just wrote.
    print("\n📊 3. 再次获取缓存")
    cached = await cache_service.get_fields(tenant_id)
    if cached:
        print(f"  ✅ 缓存存在,包含 {len(cached)} 个字段")
    else:
        print("  ❌ 缓存不存在")

    print("\n" + "=" * 80)
    print("测试完成")
    print("=" * 80)
|
||||
|
||||
|
||||
# Script entry point: run the cache smoke test once.
if __name__ == "__main__":
    asyncio.run(test_cache())
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
"""
|
||||
测试动态生成的工具 Schema
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool
|
||||
|
||||
|
||||
async def test_dynamic_tool_schema():
    """Print the static vs. dynamically-generated tool schema for one tenant,
    then check that a second dynamic-schema call is served from the cache."""
    cfg = get_settings()

    engine = create_async_engine(cfg.database_url)
    session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    tenant_id = "szmp@ash@2026"

    rule = "=" * 80
    print(f"\n{rule}")
    print(f"测试动态生成的工具 Schema")
    print(f"{rule}")
    print(f"租户 ID: {tenant_id}")

    async with session_factory() as session:
        tool = KbSearchDynamicTool(session)

        def _dump(title, schema):
            # Pretty-print one schema section with its header.
            print(f"\n--- {title} ---")
            print(json.dumps(schema, indent=2, ensure_ascii=False))

        _dump("静态 Schema", tool.get_tool_schema())

        dynamic_schema = await tool.get_dynamic_tool_schema(tenant_id)
        _dump("动态 Schema", dynamic_schema)

        # A second call should come back identical from the tool's cache.
        print(f"\n--- 测试缓存 ---")
        second = await tool.get_dynamic_tool_schema(tenant_id)
        print(f"缓存命中: {dynamic_schema == second}")

        # Detail view of the generated context (filter) properties.
        print(f"\n--- context 字段详情 ---")
        context_props = dynamic_schema["parameters"]["properties"].get("context", {}).get("properties", {})
        print(f"过滤字段数量: {len(context_props)}")
        for key, value in context_props.items():
            print(f"  {key}: {value}")
|
||||
|
||||
|
||||
# Script entry point: dump and compare the tool schemas once.
if __name__ == "__main__":
    asyncio.run(test_dynamic_tool_schema())
|
||||
|
|
@ -0,0 +1,93 @@
|
|||
"""
|
||||
测试 kb_search_dynamic 工具是否能用给定的参数查出数据
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
async def test_kb_search():
    """Call KbSearchDynamicTool.execute() directly (no HTTP) with three
    parameter sets of decreasing specificity and print the results."""
    settings = get_settings()

    # Database session for the tool (it loads tenant filter config from the DB).
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async with async_session() as session:
        tool = KbSearchDynamicTool(session=session)

        # Parameter sets: full (with context filter), no context, minimal.
        test_cases = [
            {
                "name": "完整参数(含context过滤)",
                "params": {
                    "query": "三年级语文学习",
                    "tenant_id": "szmp@ash@2026",
                    "scene": "学习方案",
                    "top_k": 5,
                    "context": {"grade": "三年级", "subject": "语文"},
                }
            },
            {
                "name": "简化参数(无context)",
                "params": {
                    "query": "三年级语文学习",
                    "tenant_id": "szmp@ash@2026",
                    "scene": "学习方案",
                    "top_k": 5,
                }
            },
            {
                "name": "仅query和tenant_id",
                "params": {
                    "query": "三年级语文学习",
                    "tenant_id": "szmp@ash@2026",
                    "top_k": 5,
                }
            },
        ]

        for test_case in test_cases:
            print(f"\n{'='*80}")
            print(f"测试: {test_case['name']}")
            print(f"{'='*80}")
            print(f"参数: {test_case['params']}")

            try:
                result = await tool.execute(**test_case['params'])

                # Summary of the tool's result object.
                print(f"\n结果:")
                print(f"  success: {result.success}")
                print(f"  hits count: {len(result.hits)}")
                print(f"  applied_filter: {result.applied_filter}")
                print(f"  fallback_reason_code: {result.fallback_reason_code}")
                print(f"  duration_ms: {result.duration_ms}")

                if result.hits:
                    # Show a truncated preview of the top three hits.
                    print(f"\n  前3条结果:")
                    for i, hit in enumerate(result.hits[:3], 1):
                        text = hit.get('text', '')[:80] + '...' if hit.get('text') else 'N/A'
                        score = hit.get('score', 0)
                        metadata = hit.get('metadata', {})
                        print(f"    {i}. [score={score:.4f}] {text}")
                        print(f"       metadata: {metadata}")
                else:
                    print(f"\n  ⚠️ 没有命中任何结果")

            except Exception as e:
                print(f"\n  ❌ 错误: {e}")
                import traceback
                traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_kb_search())
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
"""
|
||||
测试 kb_search_dynamic 工具 - 课程咨询场景
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.services.mid.kb_search_dynamic_tool import (
|
||||
KbSearchDynamicTool,
|
||||
KbSearchDynamicConfig,
|
||||
StepKbConfig,
|
||||
)
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
async def test_kb_search():
    """KB search smoke test for the course-consultation scenario.

    Builds a KbSearchDynamicTool against the real database, runs one query
    restricted to the course knowledge base via StepKbConfig, and prints the
    full result (filter debug, tool trace, hits) for manual inspection.
    """
    settings = get_settings()

    # Review fix: the engine was never disposed, leaking the connection
    # pool on every run — ensure cleanup even if the search raises.
    engine = create_async_engine(settings.database_url)
    try:
        async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

        async with async_session() as session:
            config = KbSearchDynamicConfig(
                enabled=True,
                top_k=10,
                timeout_ms=15000,
                min_score_threshold=0.3,
            )

            tool = KbSearchDynamicTool(session=session, config=config)

            # Hard-coded test fixture: course KB id for tenant szmp@ash@2026.
            course_kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

            step_kb_config = StepKbConfig(
                allowed_kb_ids=[course_kb_id],
                preferred_kb_ids=[course_kb_id],
                step_id="test_course_query",
            )

            test_params = {
                "query": "课程介绍",
                "tenant_id": "szmp@ash@2026",
                "top_k": 10,
                "context": {
                    "grade": "五年级",
                },
                "step_kb_config": step_kb_config,
            }

            print(f"\n{'='*80}")
            print(f"测试: kb_search_dynamic - 课程知识库")
            print(f"{'='*80}")
            print(f"参数: query={test_params['query']}")
            print(f" tenant_id={test_params['tenant_id']}")
            print(f" context={test_params['context']}")
            print(f" step_kb_config.allowed_kb_ids={step_kb_config.allowed_kb_ids}")
            print(f"超时设置: {config.timeout_ms}ms")
            print(f"最低分数阈值: {config.min_score_threshold}")

            try:
                result = await tool.execute(**test_params)

                print(f"\n结果:")
                print(f" success: {result.success}")
                print(f" hits count: {len(result.hits)}")
                print(f" applied_filter: {result.applied_filter}")
                print(f" fallback_reason_code: {result.fallback_reason_code}")
                print(f" duration_ms: {result.duration_ms}")

                if result.filter_debug:
                    print(f" filter_debug: {result.filter_debug}")

                if result.step_kb_binding:
                    print(f" step_kb_binding: {result.step_kb_binding}")

                if result.tool_trace:
                    print(f"\n Tool Trace:")
                    print(f" tool_name: {result.tool_trace.tool_name}")
                    print(f" status: {result.tool_trace.status}")
                    print(f" duration_ms: {result.tool_trace.duration_ms}")
                    print(f" args_digest: {result.tool_trace.args_digest}")
                    print(f" result_digest: {result.tool_trace.result_digest}")
                    # 'arguments' is optional on the trace object.
                    if hasattr(result.tool_trace, 'arguments') and result.tool_trace.arguments:
                        print(f" arguments: {result.tool_trace.arguments}")

                if result.hits:
                    print(f"\n 检索结果 (共 {len(result.hits)} 条):")
                    for i, hit in enumerate(result.hits, 1):
                        text = hit.get('text', '')
                        # Truncate long chunks for readable console output.
                        text_preview = text[:200] + '...' if len(text) > 200 else text
                        score = hit.get('score', 0)
                        metadata = hit.get('metadata', {})
                        collection = hit.get('collection', 'unknown')
                        kb_id = hit.get('kb_id', 'unknown')
                        print(f"\n [{i}] score={score:.4f}")
                        print(f" collection: {collection}")
                        print(f" kb_id: {kb_id}")
                        print(f" metadata: {metadata}")
                        print(f" text: {text_preview}")
                else:
                    print(f"\n ⚠️ 没有命中任何结果")
                    print(f" 请检查:")
                    print(f" 1. 知识库是否有数据")
                    print(f" 2. 向量是否正确生成")
                    print(f" 3. 过滤条件是否过于严格")

            except Exception as e:
                print(f"\n ❌ 错误: {e}")
                import traceback
                traceback.print_exc()
    finally:
        # Close the pooled connections before the event loop shuts down.
        await engine.dispose()
|
||||
|
||||
|
||||
# Script entry point: drive the async test with a fresh event loop.
if __name__ == "__main__":
    asyncio.run(test_kb_search())
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
"""
|
||||
测试 kb_search_dynamic 工具 - 增加超时时间
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
async def test_kb_search():
    """KB search smoke test with a longer (10s) timeout.

    Runs the same query with and without a metadata ``context`` filter and
    prints the hits, so filter behaviour can be compared manually.
    """
    settings = get_settings()

    # Create the database session factory.
    # Review fix: dispose the engine on exit — it was previously leaked.
    engine = create_async_engine(settings.database_url)
    try:
        async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

        async with async_session() as session:
            # Use a longer timeout than the default configuration.
            config = KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=10000,  # 10-second timeout
                min_score_threshold=0.5,
            )

            tool = KbSearchDynamicTool(session=session, config=config)

            # Two cases: one with a context filter, one without.
            test_cases = [
                {
                    "name": "完整参数(含context过滤)",
                    "params": {
                        "query": "三年级语文学习",
                        "tenant_id": "szmp@ash@2026",
                        "scene": "学习方案",
                        "top_k": 5,
                        "context": {"grade": "三年级", "subject": "语文"},
                    }
                },
                {
                    "name": "简化参数(无context)",
                    "params": {
                        "query": "三年级语文学习",
                        "tenant_id": "szmp@ash@2026",
                        "scene": "学习方案",
                        "top_k": 5,
                    }
                },
            ]

            for test_case in test_cases:
                print(f"\n{'='*80}")
                print(f"测试: {test_case['name']}")
                print(f"{'='*80}")
                print(f"参数: {test_case['params']}")
                print(f"超时设置: {config.timeout_ms}ms")

                try:
                    result = await tool.execute(**test_case['params'])

                    print(f"\n结果:")
                    print(f" success: {result.success}")
                    print(f" hits count: {len(result.hits)}")
                    print(f" applied_filter: {result.applied_filter}")
                    print(f" fallback_reason_code: {result.fallback_reason_code}")
                    print(f" duration_ms: {result.duration_ms}")

                    if result.hits:
                        print(f"\n 所有结果:")
                        for i, hit in enumerate(result.hits, 1):
                            # NOTE(review): '...' is appended even when the text
                            # is shorter than 100 chars — cosmetic only.
                            text = hit.get('text', '')[:100] + '...' if hit.get('text') else 'N/A'
                            score = hit.get('score', 0)
                            metadata = hit.get('metadata', {})
                            collection = hit.get('collection', 'unknown')
                            print(f" {i}. [score={score:.4f}] [collection={collection}]")
                            print(f" text: {text}")
                            print(f" metadata: {metadata}")
                    else:
                        print(f"\n ⚠️ 没有命中任何结果")

                except Exception as e:
                    print(f"\n ❌ 错误: {e}")
                    import traceback
                    traceback.print_exc()
    finally:
        # Close the pooled connections before the event loop shuts down.
        await engine.dispose()
|
||||
|
||||
|
||||
# Script entry point: drive the async test with a fresh event loop.
if __name__ == "__main__":
    asyncio.run(test_kb_search())
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
直接测试 Qdrant 过滤功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient
|
||||
|
||||
|
||||
async def test_qdrant_filter():
    """Exercise Qdrant payload filtering directly (bypassing the service layer).

    Three probes against the tenant's course-KB collection:
      1. scroll with no filter,
      2. scroll with a ``metadata.grade`` Filter,
      3. query_points with the same filter and a sampled stored vector.
    """
    # Review fix: dropped the unused ``settings = get_settings()`` local.
    client = QdrantClient()
    qdrant = await client.get_client()

    # Hard-coded test fixtures: tenant and course KB id.
    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

    collection_name = client.get_kb_collection_name(tenant_id, kb_id)

    print(f"\n{'='*80}")
    print(f"测试 Qdrant 过滤功能")
    print(f"{'='*80}")
    print(f"Collection: {collection_name}")

    # Test 1: unfiltered scroll — baseline result count.
    print(f"\n--- 测试 1: 无过滤 ---")
    results = await qdrant.scroll(
        collection_name=collection_name,
        limit=5,
        with_vectors=False,
    )
    print(f"无过滤结果数: {len(results[0])}")
    for p in results[0][:3]:
        print(f" grade: {p.payload.get('metadata', {}).get('grade')}")

    # Test 2: scroll restricted by a Filter object (grade == 五年级).
    print(f"\n--- 测试 2: 使用 Filter 对象过滤 (grade=五年级) ---")
    qdrant_filter = Filter(
        must=[
            FieldCondition(
                key="metadata.grade",
                match=MatchValue(value="五年级"),
            )
        ]
    )
    print(f"Filter: {qdrant_filter}")

    results = await qdrant.scroll(
        collection_name=collection_name,
        limit=10,
        with_vectors=False,
        scroll_filter=qdrant_filter,
    )
    print(f"过滤后结果数: {len(results[0])}")
    for p in results[0]:
        print(f" grade: {p.payload.get('metadata', {}).get('grade')}, text: {p.payload.get('text', '')[:50]}...")

    # Test 3: query_points with the same filter.
    print(f"\n--- 测试 3: 使用 query_points 过滤 ---")
    # Grab one stored vector to use as the query vector.
    all_points = await qdrant.scroll(
        collection_name=collection_name,
        limit=1,
        with_vectors=True,
    )
    if all_points[0]:
        query_vector = all_points[0][0].vector

        results = await qdrant.query_points(
            collection_name=collection_name,
            query=query_vector,
            limit=10,
            query_filter=qdrant_filter,
        )
        print(f"query_points 过滤后结果数: {len(results.points)}")
        for p in results.points:
            print(f" grade: {p.payload.get('metadata', {}).get('grade')}, score: {p.score:.4f}")
    # NOTE(review): the project QdrantClient wrapper is never closed here —
    # confirm whether it needs an explicit close()/aclose().
|
||||
|
||||
|
||||
# Script entry point: drive the async test with a fresh event loop.
if __name__ == "__main__":
    asyncio.run(test_qdrant_filter())
|
||||
|
|
@ -0,0 +1,193 @@
|
|||
"""
|
||||
验证 Qdrant 向量数据库中的 collections 情况
|
||||
用于检查 szmp@ash@2026 租户下的知识库 collections
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
async def list_collections():
    """List all Qdrant collections and summarize the szmp tenant's ones.

    Splits collections into szmp-tenant vs. others by a case-insensitive
    substring match on the name, prints per-collection vector counts and
    vector params, and checks the tenant has exactly 2 collections
    (the expected fixture state for this verification script).
    """
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)

    print(f"🔗 Qdrant URL: {settings.qdrant_url}")
    print(f"📦 Collection Prefix: {settings.qdrant_collection_prefix}")
    print("-" * 60)

    try:
        collections = await client.get_collections()

        if not collections.collections:
            print("⚠️ 没有找到任何 collections")
            return

        print(f"✅ 找到 {len(collections.collections)} 个 collections:\n")

        # Partition into szmp-tenant collections and everything else.
        szmp_collections = []
        other_collections = []

        for collection in collections.collections:
            name = collection.name
            if "szmp" in name.lower():
                szmp_collections.append(name)
            else:
                other_collections.append(name)

        # Show the szmp-tenant collections with per-collection details.
        if szmp_collections:
            print(f"🎯 szmp@ash@2026 租户相关的 collections ({len(szmp_collections)} 个):")
            print("-" * 60)
            for name in sorted(szmp_collections):
                try:
                    info = await client.get_collection(name)
                    # hasattr-guarded: attribute availability varies by client version.
                    points_count = info.points_count if hasattr(info, 'points_count') else 'N/A'
                    print(f" 📁 {name}")
                    print(f" └─ 向量数量: {points_count}")

                    # Vector configuration (size / distance), when exposed.
                    if hasattr(info, 'config') and hasattr(info.config, 'params'):
                        params = info.config.params
                        if hasattr(params, 'vectors'):
                            vector_params = params.vectors
                            if hasattr(vector_params, 'size'):
                                print(f" └─ 向量维度: {vector_params.size}")
                            if hasattr(vector_params, 'distance'):
                                print(f" └─ 距离函数: {vector_params.distance}")
                    print()
                except Exception as e:
                    # Keep going: one failing collection must not abort the listing.
                    print(f" 📁 {name}")
                    print(f" └─ 获取信息失败: {e}\n")
        else:
            print("⚠️ 没有找到 szmp@ash@2026 租户相关的 collections\n")

        # Show the remaining collections (name + vector count only).
        if other_collections:
            print(f"📂 其他 collections ({len(other_collections)} 个):")
            print("-" * 60)
            for name in sorted(other_collections):
                try:
                    info = await client.get_collection(name)
                    points_count = info.points_count if hasattr(info, 'points_count') else 'N/A'
                    print(f" 📁 {name} (向量数: {points_count})")
                except Exception as e:
                    print(f" 📁 {name} (获取信息失败: {e})")

        print("\n" + "=" * 60)
        print("📊 总结:")
        print(f" - Collections 总数: {len(collections.collections)}")
        print(f" - szmp 相关: {len(szmp_collections)} 个")
        print(f" - 其他: {len(other_collections)} 个")

        # Expectation check: the tenant should own exactly 2 collections.
        print("\n✅ 验证:")
        if len(szmp_collections) == 2:
            print(" ✓ szmp 租户的 collection 数量符合预期 (2个)")
        else:
            print(f" ⚠️ szmp 租户的 collection 数量不符合预期 (实际: {len(szmp_collections)} 个, 预期: 2个)")

    except Exception as e:
        print(f"❌ 连接 Qdrant 失败: {e}")
        print(f" 请检查 Qdrant 是否运行在 {settings.qdrant_url}")
    finally:
        await client.close()
|
||||
|
||||
|
||||
async def check_collection_details(collection_name: str):
    """Print detailed info for one collection: vector params, shards, sample points.

    Args:
        collection_name: Fully-qualified Qdrant collection name.
    """
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)

    try:
        print(f"\n📋 Collection '{collection_name}' 详细信息:")
        print("-" * 60)

        info = await client.get_collection(collection_name)
        print(f" 名称: {collection_name}")
        print(f" 向量数量: {info.points_count}")

        # hasattr-guarded: attribute availability varies by client version.
        if hasattr(info, 'config') and hasattr(info.config, 'params'):
            params = info.config.params

            if hasattr(params, 'vectors'):
                vector_params = params.vectors
                print(f" 向量配置:")
                if hasattr(vector_params, 'size'):
                    print(f" - 维度: {vector_params.size}")
                if hasattr(vector_params, 'distance'):
                    print(f" - 距离函数: {vector_params.distance}")
                if hasattr(vector_params, 'on_disk'):
                    print(f" - 磁盘存储: {vector_params.on_disk}")

            if hasattr(params, 'shard_number'):
                print(f" 分片数: {params.shard_number}")
            if hasattr(params, 'replication_factor'):
                print(f" 副本数: {params.replication_factor}")

        # Fetch a few sample points (payload only, no vectors).
        # Review fix: removed the unused ``from qdrant_client.models import
        # ScrollRequest`` import that previously lived in this try block.
        try:
            scroll_result = await client.scroll(
                collection_name=collection_name,
                limit=3,
                with_payload=True,
                with_vectors=False
            )

            if scroll_result[0]:
                print(f"\n 样本数据 (前3条):")
                for i, point in enumerate(scroll_result[0], 1):
                    payload = point.payload or {}
                    text = payload.get('text', '')[:50] + '...' if payload.get('text') else 'N/A'
                    kb_id = payload.get('kb_id', 'N/A')
                    print(f" {i}. ID: {point.id}")
                    print(f" KB ID: {kb_id}")
                    print(f" 文本: {text}")
        except Exception as e:
            # Sample fetch is best-effort; detail view already printed above.
            print(f" 获取样本数据失败: {e}")

    except Exception as e:
        print(f"❌ 获取 collection 信息失败: {e}")
    finally:
        await client.close()
|
||||
|
||||
|
||||
async def main():
    """Entry point: print a banner, list all collections, then dump details for the szmp tenant's collections."""
    banner = "=" * 60
    print(banner)
    print("🔍 Qdrant 向量数据库 Collections 验证工具")
    print(banner)
    print()

    # Overview pass over every collection.
    await list_collections()

    # Detail pass: re-query the catalog and inspect szmp-tenant collections.
    qdrant = AsyncQdrantClient(url=get_settings().qdrant_url)

    try:
        catalog = await qdrant.get_collections()
        tenant_names = [entry.name for entry in catalog.collections if "szmp" in entry.name.lower()]

        for collection_name in sorted(tenant_names):
            await check_collection_details(collection_name)

    except Exception as e:
        print(f"❌ 错误: {e}")
    finally:
        await qdrant.close()
|
||||
|
||||
|
||||
# Script entry point: drive the async verifier with a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
|
|
@ -0,0 +1,333 @@
|
|||
"""
|
||||
Tests for Batch Ask-Back Service.
|
||||
[AC-MRS-SLOT-ASKBACK-01] 批量追问测试
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.mid.batch_ask_back_service import (
|
||||
AskBackSlot,
|
||||
BatchAskBackConfig,
|
||||
BatchAskBackResult,
|
||||
BatchAskBackService,
|
||||
create_batch_ask_back_service,
|
||||
)
|
||||
|
||||
|
||||
class TestAskBackSlot:
    """Tests for the AskBackSlot value object."""

    def test_init(self):
        """Every constructor argument is stored verbatim on the instance."""
        kwargs = {
            "slot_key": "region",
            "label": "地区",
            "ask_back_prompt": "请告诉我您的地区",
            "priority": 100,
            "is_required": True,
        }
        slot = AskBackSlot(**kwargs)

        assert slot.slot_key == kwargs["slot_key"]
        assert slot.label == kwargs["label"]
        assert slot.priority == kwargs["priority"]
        assert slot.is_required is True
|
||||
|
||||
|
||||
class TestBatchAskBackConfig:
    """Tests for BatchAskBackConfig defaults and overrides."""

    def test_default_config(self):
        """A bare config exposes the documented defaults."""
        cfg = BatchAskBackConfig()

        assert cfg.max_ask_back_slots_per_turn == 2
        assert cfg.prefer_required is True
        assert cfg.prefer_scene_relevant is True
        assert cfg.avoid_recent_asked is True
        assert cfg.recent_asked_threshold_seconds == 60.0
        assert cfg.merge_prompts is True

    def test_custom_config(self):
        """Explicit keyword overrides are reflected on the instance."""
        cfg = BatchAskBackConfig(
            max_ask_back_slots_per_turn=3,
            prefer_required=False,
            merge_prompts=False,
        )

        assert cfg.max_ask_back_slots_per_turn == 3
        assert cfg.prefer_required is False
        assert cfg.merge_prompts is False
|
||||
|
||||
|
||||
class TestBatchAskBackResult:
    """Tests for BatchAskBackResult flag and prompt accessors."""

    def test_has_ask_back(self):
        """has_ask_back() is true iff ask_back_count is non-zero."""
        assert BatchAskBackResult(ask_back_count=2).has_ask_back() is True
        assert BatchAskBackResult(ask_back_count=0).has_ask_back() is False

    def test_get_prompt_with_merged(self):
        """When a merged prompt exists it wins over the individual prompts."""
        outcome = BatchAskBackResult(
            merged_prompt="请告诉我您的地区和产品",
            prompts=["请告诉我您的地区", "请告诉我您的产品"],
            ask_back_count=2,
        )

        assert outcome.get_prompt() == "请告诉我您的地区和产品"

    def test_get_prompt_without_merged(self):
        """Without a merged prompt, the single prompt is returned as-is."""
        outcome = BatchAskBackResult(
            prompts=["请告诉我您的地区"],
            ask_back_count=1,
        )

        assert outcome.get_prompt() == "请告诉我您的地区"

    def test_get_prompt_empty(self):
        """An empty result falls back to a generic ask-for-more message."""
        assert "请提供更多信息" in BatchAskBackResult().get_prompt()
|
||||
|
||||
|
||||
class TestBatchAskBackService:
    """Tests for BatchAskBackService: priority scoring, slot selection,
    prompt generation/merging, and the end-to-end batch ask-back flow."""

    @pytest.fixture
    def mock_session(self):
        """Async DB session stand-in; no query behaviour is configured."""
        return AsyncMock()

    @pytest.fixture
    def config(self):
        """Config capping ask-backs at 2 per turn, with merging enabled."""
        return BatchAskBackConfig(
            max_ask_back_slots_per_turn=2,
            prefer_required=True,
            merge_prompts=True,
        )

    @pytest.fixture
    def service(self, mock_session, config):
        """Service under test, bound to a fixed tenant/session pair."""
        return BatchAskBackService(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
            config=config,
        )

    def test_calculate_priority_required(self, service):
        """A required slot with no scene relevance scores 100."""
        priority = service._calculate_priority(is_required=True, scene_relevance=0.0)
        assert priority == 100

    def test_calculate_priority_scene_relevant(self, service):
        """Full scene relevance alone scores 50."""
        priority = service._calculate_priority(is_required=False, scene_relevance=1.0)
        assert priority == 50

    def test_calculate_priority_both(self, service):
        """Required + fully scene-relevant contributions are additive (150)."""
        priority = service._calculate_priority(is_required=True, scene_relevance=1.0)
        assert priority == 150

    def test_select_slots_for_ask_back(self, service):
        """Selection keeps the highest-priority slots in descending order,
        capped by max_ask_back_slots_per_turn (2 in this fixture)."""
        slots = [
            AskBackSlot(slot_key="a", label="A", priority=50),
            AskBackSlot(slot_key="b", label="B", priority=100),
            AskBackSlot(slot_key="c", label="C", priority=75),
        ]

        selected = service._select_slots_for_ask_back(slots)

        assert len(selected) == 2
        assert selected[0].slot_key == "b"
        assert selected[1].slot_key == "c"

    def test_filter_recently_asked(self, service):
        """Slots asked within the recent-ask window (60s default) are dropped;
        older or never-asked slots survive."""
        current_time = time.time()
        asked_history = {
            "recently_asked": current_time - 30,   # inside the window -> filtered out
            "old_asked": current_time - 120,       # outside the window -> kept
        }

        slots = [
            AskBackSlot(slot_key="recently_asked", label="最近追问过"),
            AskBackSlot(slot_key="old_asked", label="很久前追问过"),
            AskBackSlot(slot_key="never_asked", label="从未追问过"),
        ]

        filtered = service._filter_recently_asked(slots, asked_history)

        assert len(filtered) == 2
        slot_keys = [s.slot_key for s in filtered]
        assert "old_asked" in slot_keys
        assert "never_asked" in slot_keys
        assert "recently_asked" not in slot_keys

    def test_generate_prompts(self, service):
        """An explicit ask_back_prompt is used verbatim; a missing prompt is
        synthesized from the slot label."""
        slots = [
            AskBackSlot(slot_key="region", label="地区", ask_back_prompt="请告诉我您的地区"),
            AskBackSlot(slot_key="product", label="产品", ask_back_prompt=None),
        ]

        prompts = service._generate_prompts(slots)

        assert len(prompts) == 2
        assert prompts[0] == "请告诉我您的地区"
        assert "产品" in prompts[1]

    def test_merge_prompts_single(self, service):
        """A single prompt is returned unchanged."""
        prompts = ["请告诉我您的地区"]
        merged = service._merge_prompts(prompts)
        assert merged == "请告诉我您的地区"

    def test_merge_prompts_two(self, service):
        """Two prompts are joined with the connective '以及'."""
        prompts = ["请告诉我您的地区", "请告诉我您的产品"]
        merged = service._merge_prompts(prompts)
        assert "地区" in merged
        assert "产品" in merged
        assert "以及" in merged

    def test_merge_prompts_multiple(self, service):
        """Three-plus prompts use '、' separators plus a final '以及'."""
        prompts = ["您的地区", "您的产品", "您的等级"]
        merged = service._merge_prompts(prompts)
        assert "地区" in merged
        assert "产品" in merged
        assert "等级" in merged
        assert "、" in merged
        assert "以及" in merged

    @pytest.mark.asyncio
    async def test_generate_batch_ask_back_empty(self, service):
        """No missing slots -> no ask-back is produced."""
        result = await service.generate_batch_ask_back(missing_slots=[])
        assert result.has_ask_back() is False

    @pytest.mark.asyncio
    async def test_generate_batch_ask_back_single(self, service):
        """One missing required slot yields one ask-back using its prompt."""
        missing_slots = [
            {
                "slot_key": "region",
                "label": "地区",
                "ask_back_prompt": "请告诉我您的地区",
            }
        ]

        # Slot-definition lookup reports the slot as required.
        with patch.object(service._slot_def_service, 'get_slot_definition_by_key') as mock_get:
            mock_get.return_value = MagicMock(required=True)

            result = await service.generate_batch_ask_back(missing_slots=missing_slots)

            assert result.has_ask_back() is True
            assert result.ask_back_count == 1
            assert "地区" in result.get_prompt()

    @pytest.mark.asyncio
    async def test_generate_batch_ask_back_multiple(self, service):
        """Three missing slots are capped to the per-turn maximum of 2."""
        missing_slots = [
            {"slot_key": "region", "label": "地区", "ask_back_prompt": "您的地区"},
            {"slot_key": "product", "label": "产品", "ask_back_prompt": "您的产品"},
            {"slot_key": "grade", "label": "等级", "ask_back_prompt": "您的等级"},
        ]

        with patch.object(service._slot_def_service, 'get_slot_definition_by_key') as mock_get:
            mock_get.return_value = MagicMock(required=True)

            result = await service.generate_batch_ask_back(missing_slots=missing_slots)

            assert result.has_ask_back() is True
            assert result.ask_back_count == 2

    @pytest.mark.asyncio
    async def test_generate_batch_ask_back_prioritize_required(self, service):
        """With prefer_required on, the required slot is asked first."""
        missing_slots = [
            {"slot_key": "optional", "label": "可选", "ask_back_prompt": "可选信息"},
            {"slot_key": "required", "label": "必填", "ask_back_prompt": "必填信息"},
        ]

        # Only the slot keyed "required" is flagged required by the lookup.
        def mock_get_slot(tenant_id, slot_key):
            if slot_key == "required":
                return MagicMock(required=True)
            return MagicMock(required=False)

        with patch.object(service._slot_def_service, 'get_slot_definition_by_key', side_effect=mock_get_slot):
            result = await service.generate_batch_ask_back(missing_slots=missing_slots)

            assert result.has_ask_back() is True
            assert result.selected_slots[0].slot_key == "required"
|
||||
|
||||
|
||||
class TestCreateBatchAskBackService:
    """Tests for the create_batch_ask_back_service factory function."""

    def test_create(self):
        """The factory wires session/tenant/session-id/config onto a BatchAskBackService."""
        session_stub = AsyncMock()
        cfg = BatchAskBackConfig(max_ask_back_slots_per_turn=3)

        built = create_batch_ask_back_service(
            session=session_stub,
            tenant_id="tenant_1",
            session_id="session_1",
            config=cfg,
        )

        assert isinstance(built, BatchAskBackService)
        assert built._tenant_id == "tenant_1"
        assert built._session_id == "session_1"
        assert built._config.max_ask_back_slots_per_turn == 3
|
||||
|
||||
|
||||
class TestAskBackHistory:
    """Tests for reading the per-session ask-back history from the cache."""

    @pytest.fixture
    def service(self):
        """Service instance with a mocked DB session and default config."""
        mock_session = AsyncMock()
        return BatchAskBackService(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
        )

    @pytest.mark.asyncio
    async def test_get_asked_history_empty(self, service):
        """A cache miss (None from the client) yields an empty history dict."""
        # Patch the cache's client factory to hand back a stubbed Redis-like
        # object whose get() returns None.
        with patch.object(service._cache, '_get_client') as mock_client:
            mock_redis = AsyncMock()
            mock_redis.get = AsyncMock(return_value=None)
            mock_client.return_value = mock_redis

            history = await service._get_asked_history()
            assert history == {}

    @pytest.mark.asyncio
    async def test_get_asked_history_with_data(self, service):
        """A cached JSON blob is decoded back into the slot->timestamp mapping."""
        import json

        # Presumably keys are slot_keys and values are epoch timestamps —
        # matches how _filter_recently_asked consumes the history.
        history_data = {"region": 12345.0, "product": 12346.0}

        with patch.object(service._cache, '_get_client') as mock_client:
            mock_redis = AsyncMock()
            mock_redis.get = AsyncMock(return_value=json.dumps(history_data))
            mock_client.return_value = mock_redis

            history = await service._get_asked_history()
            assert history == history_data
|
||||
|
|
@ -0,0 +1,543 @@
|
|||
"""
|
||||
Tests for clarification mechanism.
|
||||
[AC-CLARIFY] 澄清机制测试
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.services.intent.clarification import (
|
||||
ClarificationEngine,
|
||||
ClarifyMetrics,
|
||||
ClarifyReason,
|
||||
ClarifySessionManager,
|
||||
ClarifyState,
|
||||
HybridIntentResult,
|
||||
IntentCandidate,
|
||||
T_HIGH,
|
||||
T_LOW,
|
||||
MAX_CLARIFY_RETRY,
|
||||
get_clarify_metrics,
|
||||
)
|
||||
|
||||
|
||||
class TestClarifyMetrics:
    """Tests for the singleton ClarifyMetrics counters and rate reporting."""

    def test_singleton_pattern(self):
        """Every construction returns the same shared instance."""
        first, second = ClarifyMetrics(), ClarifyMetrics()
        assert first is second

    def test_record_clarify_trigger(self):
        """Each trigger call bumps the trigger counter by one."""
        metrics = ClarifyMetrics()
        metrics.reset()

        for _ in range(3):
            metrics.record_clarify_trigger()

        assert metrics.get_metrics()["clarify_trigger_rate"] == 3

    def test_record_clarify_converge(self):
        """Each converge call bumps the converge counter by one."""
        metrics = ClarifyMetrics()
        metrics.reset()

        for _ in range(2):
            metrics.record_clarify_converge()

        assert metrics.get_metrics()["clarify_converge_rate"] == 2

    def test_record_misroute(self):
        """A single misroute is reflected in the counter."""
        metrics = ClarifyMetrics()
        metrics.reset()

        metrics.record_misroute()

        assert metrics.get_metrics()["misroute_rate"] == 1

    def test_get_rates(self):
        """Trigger/misroute rates normalise against the request total;
        converge rate normalises against triggers (1/1 -> 1.0)."""
        metrics = ClarifyMetrics()
        metrics.reset()

        metrics.record_clarify_trigger()
        metrics.record_clarify_converge()
        metrics.record_misroute()

        rates = metrics.get_rates(100)

        assert rates["clarify_trigger_rate"] == 0.01
        assert rates["clarify_converge_rate"] == 1.0
        assert rates["misroute_rate"] == 0.01

    def test_get_rates_zero_requests(self):
        """Zero total requests must not divide by zero: every rate is 0.0."""
        metrics = ClarifyMetrics()
        metrics.reset()

        rates = metrics.get_rates(0)

        assert rates["clarify_trigger_rate"] == 0.0
        assert rates["clarify_converge_rate"] == 0.0
        assert rates["misroute_rate"] == 0.0

    def test_reset(self):
        """reset() zeroes every counter regardless of prior activity."""
        metrics = ClarifyMetrics()
        metrics.record_clarify_trigger()
        metrics.record_clarify_converge()
        metrics.record_misroute()

        metrics.reset()

        counts = metrics.get_metrics()
        assert counts["clarify_trigger_rate"] == 0
        assert counts["clarify_converge_rate"] == 0
        assert counts["misroute_rate"] == 0
|
||||
|
||||
|
||||
class TestIntentCandidate:
    """Tests for IntentCandidate serialization."""

    def test_to_dict(self):
        """to_dict() exposes every routing field of the candidate."""
        candidate = IntentCandidate(
            intent_id="intent-1",
            intent_name="退货意图",
            confidence=0.85,
            response_type="flow",
            target_kb_ids=["kb-1"],
            flow_id="flow-1",
            fixed_reply=None,
            transfer_message=None,
        )

        payload = candidate.to_dict()

        expected = {
            "intent_id": "intent-1",
            "intent_name": "退货意图",
            "confidence": 0.85,
            "response_type": "flow",
            "target_kb_ids": ["kb-1"],
            "flow_id": "flow-1",
        }
        for key, value in expected.items():
            assert payload[key] == value
|
||||
|
||||
|
||||
class TestHybridIntentResult:
    """Tests for HybridIntentResult serialization and adaptation from
    the intent-fusion result object."""

    def test_to_dict(self):
        """to_dict() nests the winning intent, candidates and clarify flag."""
        candidate = IntentCandidate(
            intent_id="intent-1",
            intent_name="退货意图",
            confidence=0.85,
        )

        result = HybridIntentResult(
            intent=candidate,
            confidence=0.85,
            candidates=[candidate],
            need_clarify=False,
            clarify_reason=None,
            missing_slots=[],
        )

        d = result.to_dict()

        assert d["intent"]["intent_id"] == "intent-1"
        assert d["confidence"] == 0.85
        assert len(d["candidates"]) == 1
        assert d["need_clarify"] is False

    def test_from_fusion_result(self):
        """A confident fusion result maps onto a non-clarifying hybrid result."""
        # Fusion-result stub with a single decided intent (all routing
        # attributes the adapter reads are set explicitly).
        mock_fusion = MagicMock()
        mock_fusion.final_intent = MagicMock()
        mock_fusion.final_intent.id = "intent-1"
        mock_fusion.final_intent.name = "退货意图"
        mock_fusion.final_intent.response_type = "flow"
        mock_fusion.final_intent.target_kb_ids = ["kb-1"]
        mock_fusion.final_intent.flow_id = None
        mock_fusion.final_intent.fixed_reply = None
        mock_fusion.final_intent.transfer_message = None
        mock_fusion.final_confidence = 0.85
        mock_fusion.need_clarify = False
        mock_fusion.decision_reason = "rule_high_confidence"
        mock_fusion.clarify_candidates = []

        result = HybridIntentResult.from_fusion_result(mock_fusion)

        assert result.intent is not None
        assert result.intent.intent_id == "intent-1"
        assert result.confidence == 0.85
        assert result.need_clarify is False

    def test_from_fusion_result_with_clarify(self):
        """An undecided fusion result ("multi_intent") carries its candidates
        and the mapped clarify reason across to the hybrid result."""
        mock_fusion = MagicMock()
        mock_fusion.final_intent = None
        mock_fusion.final_confidence = 0.5
        mock_fusion.need_clarify = True
        mock_fusion.decision_reason = "multi_intent"

        # Two competing intent stubs (return vs. exchange).
        candidate1 = MagicMock()
        candidate1.id = "intent-1"
        candidate1.name = "退货意图"
        candidate1.response_type = "flow"
        candidate1.target_kb_ids = None
        candidate1.flow_id = None
        candidate1.fixed_reply = None
        candidate1.transfer_message = None

        candidate2 = MagicMock()
        candidate2.id = "intent-2"
        candidate2.name = "换货意图"
        candidate2.response_type = "flow"
        candidate2.target_kb_ids = None
        candidate2.flow_id = None
        candidate2.fixed_reply = None
        candidate2.transfer_message = None

        mock_fusion.clarify_candidates = [candidate1, candidate2]

        result = HybridIntentResult.from_fusion_result(mock_fusion)

        assert result.need_clarify is True
        assert result.clarify_reason == ClarifyReason.MULTI_INTENT
        assert len(result.candidates) == 2
|
||||
|
||||
|
||||
class TestClarifyState:
    """Tests for ClarifyState serialization and retry bookkeeping."""

    def test_to_dict(self):
        """to_dict() emits the reason value, retry count and nested candidates."""
        option = IntentCandidate(
            intent_id="intent-1",
            intent_name="退货意图",
            confidence=0.5,
        )

        state = ClarifyState(
            reason=ClarifyReason.INTENT_AMBIGUITY,
            asked_slot=None,
            retry_count=1,
            candidates=[option],
            asked_intent_ids=["intent-1"],
        )

        snapshot = state.to_dict()

        assert snapshot["reason"] == "intent_ambiguity"
        assert snapshot["retry_count"] == 1
        assert len(snapshot["candidates"]) == 1

    def test_increment_retry(self):
        """increment_retry() advances the counter one step at a time."""
        state = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE)

        for expected in (1, 2):
            state.increment_retry()
            assert state.retry_count == expected

    def test_is_max_retry(self):
        """is_max_retry() flips once retry_count reaches MAX_CLARIFY_RETRY."""
        state = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE)
        assert not state.is_max_retry()

        state.retry_count = MAX_CLARIFY_RETRY
        assert state.is_max_retry()
|
||||
|
||||
|
||||
class TestClarificationEngine:
    """Unit tests for ClarificationEngine: confidence fusion, hard blocks,
    clarify triggering, prompt generation and metrics."""

    def test_compute_confidence_rule_only(self):
        """Full rule weight with a perfect rule score yields confidence 1.0."""
        clarifier = ClarificationEngine()

        fused = clarifier.compute_confidence(
            rule_score=1.0,
            semantic_score=0.0,
            llm_score=0.0,
            w_rule=1.0,
            w_semantic=0.0,
            w_llm=0.0,
        )

        assert fused == 1.0

    def test_compute_confidence_semantic_only(self):
        """Only the semantic signal contributes when the other scores are zero."""
        clarifier = ClarificationEngine()

        fused = clarifier.compute_confidence(
            rule_score=0.0,
            semantic_score=0.8,
            llm_score=0.0,
            w_rule=0.3,
            w_semantic=0.5,
            w_llm=0.2,
        )

        # (0.0*0.3 + 0.8*0.5 + 0.0*0.2) / (0.3 + 0.5 + 0.2) = 0.4 / 1.0 = 0.4
        assert fused == 0.4

    def test_compute_confidence_weighted(self):
        """All three signals are blended by their respective weights."""
        clarifier = ClarificationEngine()

        fused = clarifier.compute_confidence(
            rule_score=1.0,
            semantic_score=0.8,
            llm_score=0.9,
            w_rule=0.5,
            w_semantic=0.3,
            w_llm=0.2,
        )

        weighted_mean = (1.0 * 0.5 + 0.8 * 0.3 + 0.9 * 0.2) / 1.0
        assert abs(fused - weighted_mean) < 0.001

    def test_check_hard_block_low_confidence(self):
        """A result with no intent and mid confidence is hard-blocked as LOW_CONFIDENCE."""
        clarifier = ClarificationEngine()
        routing = HybridIntentResult(
            intent=None,
            confidence=0.5,
            candidates=[],
        )

        blocked, block_reason = clarifier.check_hard_block(routing)

        assert blocked is True
        assert block_reason == ClarifyReason.LOW_CONFIDENCE

    def test_check_hard_block_high_confidence(self):
        """A confidently-resolved intent passes the hard-block check."""
        clarifier = ClarificationEngine()
        routing = HybridIntentResult(
            intent=IntentCandidate(
                intent_id="intent-1",
                intent_name="退货意图",
                confidence=0.85,
            ),
            confidence=0.85,
            candidates=[],
        )

        blocked, block_reason = clarifier.check_hard_block(routing)

        assert blocked is False
        assert block_reason is None

    def test_check_hard_block_missing_slots(self):
        """Unfilled required slots hard-block even a confident intent."""
        clarifier = ClarificationEngine()
        routing = HybridIntentResult(
            intent=IntentCandidate(
                intent_id="intent-1",
                intent_name="退货意图",
                confidence=0.85,
            ),
            confidence=0.85,
            candidates=[],
        )

        blocked, block_reason = clarifier.check_hard_block(
            routing,
            required_slots=["order_id", "product_id"],
            filled_slots={"order_id": "123"},
        )

        assert blocked is True
        assert block_reason == ClarifyReason.MISSING_SLOT

    def test_should_trigger_clarify_below_t_low(self):
        """Confidence under T_LOW triggers a LOW_CONFIDENCE clarification."""
        clarifier = ClarificationEngine()
        get_clarify_metrics().reset()

        routing = HybridIntentResult(
            intent=None,
            confidence=0.3,
            candidates=[],
        )

        triggered, clarify_state = clarifier.should_trigger_clarify(routing)

        assert triggered is True
        assert clarify_state is not None
        assert clarify_state.reason == ClarifyReason.LOW_CONFIDENCE

    def test_should_trigger_clarify_gray_zone(self):
        """A gray-zone result already flagged as ambiguous keeps its clarify reason."""
        clarifier = ClarificationEngine()
        get_clarify_metrics().reset()

        gray_candidate = IntentCandidate(
            intent_id="intent-1",
            intent_name="退货意图",
            confidence=0.5,
        )
        routing = HybridIntentResult(
            intent=gray_candidate,
            confidence=0.5,
            candidates=[gray_candidate],
            need_clarify=True,
            clarify_reason=ClarifyReason.INTENT_AMBIGUITY,
        )

        triggered, clarify_state = clarifier.should_trigger_clarify(routing)

        assert triggered is True
        assert clarify_state is not None
        assert clarify_state.reason == ClarifyReason.INTENT_AMBIGUITY

    def test_should_trigger_clarify_above_t_high(self):
        """Confidence above T_HIGH never triggers clarification."""
        clarifier = ClarificationEngine()
        get_clarify_metrics().reset()

        confident_candidate = IntentCandidate(
            intent_id="intent-1",
            intent_name="退货意图",
            confidence=0.85,
        )
        routing = HybridIntentResult(
            intent=confident_candidate,
            confidence=0.85,
            candidates=[confident_candidate],
        )

        triggered, clarify_state = clarifier.should_trigger_clarify(routing)

        assert triggered is False
        assert clarify_state is None

    def test_generate_clarify_prompt_missing_slot(self):
        """Missing-slot prompts reference the slot key or a generic info request."""
        clarifier = ClarificationEngine()
        clarify_state = ClarifyState(
            reason=ClarifyReason.MISSING_SLOT,
            asked_slot="order_id",
        )

        ask_back = clarifier.generate_clarify_prompt(clarify_state)

        assert "order_id" in ask_back or "相关信息" in ask_back

    def test_generate_clarify_prompt_low_confidence(self):
        """Low-confidence prompts ask the user to elaborate."""
        clarifier = ClarificationEngine()
        clarify_state = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE)

        ask_back = clarifier.generate_clarify_prompt(clarify_state)

        assert "理解" in ask_back or "详细" in ask_back

    def test_generate_clarify_prompt_multi_intent(self):
        """Multi-intent prompts list every candidate intent name."""
        clarifier = ClarificationEngine()
        clarify_state = ClarifyState(
            reason=ClarifyReason.MULTI_INTENT,
            candidates=[
                IntentCandidate(intent_id="1", intent_name="退货", confidence=0.5),
                IntentCandidate(intent_id="2", intent_name="换货", confidence=0.4),
            ],
        )

        ask_back = clarifier.generate_clarify_prompt(clarify_state)

        assert "退货" in ask_back
        assert "换货" in ask_back

    def test_process_clarify_response_max_retry(self):
        """Once the retry cap is hit, the engine gives up instead of re-asking."""
        clarifier = ClarificationEngine()
        get_clarify_metrics().reset()

        exhausted_state = ClarifyState(
            reason=ClarifyReason.LOW_CONFIDENCE,
            retry_count=MAX_CLARIFY_RETRY,
        )

        outcome = clarifier.process_clarify_response("用户回复", exhausted_state)

        assert outcome.intent is None
        assert outcome.confidence == 0.0
        assert outcome.need_clarify is False

    def test_process_clarify_response_missing_slot(self):
        """A slot-providing reply resolves the pending missing-slot clarification."""
        clarifier = ClarificationEngine()
        get_clarify_metrics().reset()

        pending_candidate = IntentCandidate(
            intent_id="intent-1",
            intent_name="退货意图",
            confidence=0.8,
        )
        clarify_state = ClarifyState(
            reason=ClarifyReason.MISSING_SLOT,
            asked_slot="order_id",
            candidates=[pending_candidate],
        )

        outcome = clarifier.process_clarify_response("订单号是123", clarify_state)

        assert outcome.intent is not None
        assert outcome.need_clarify is False

    def test_get_metrics(self):
        """Recorded trigger/converge events are surfaced by get_metrics."""
        clarifier = ClarificationEngine()
        get_clarify_metrics().reset()

        clarifier._metrics.record_clarify_trigger()
        clarifier._metrics.record_clarify_converge()

        snapshot = clarifier.get_metrics()

        assert snapshot["clarify_trigger_rate"] == 1
        assert snapshot["clarify_converge_rate"] == 1
|
||||
|
||||
|
||||
class TestClarifySessionManager:
    """Unit tests for the session-scoped clarify-state store."""

    def test_set_and_get_session(self):
        """A stored state can be read back with its reason intact."""
        ClarifySessionManager.clear_session("test-session")

        stored_state = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE)
        ClarifySessionManager.set_session("test-session", stored_state)

        loaded = ClarifySessionManager.get_session("test-session")

        assert loaded is not None
        assert loaded.reason == ClarifyReason.LOW_CONFIDENCE

    def test_clear_session(self):
        """Clearing a session removes its clarify state."""
        ClarifySessionManager.set_session(
            "test-session",
            ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE),
        )

        ClarifySessionManager.clear_session("test-session")

        assert ClarifySessionManager.get_session("test-session") is None

    def test_has_active_clarify(self):
        """A session is active only while a non-exhausted state is stored."""
        ClarifySessionManager.clear_session("test-session")
        assert not ClarifySessionManager.has_active_clarify("test-session")

        live_state = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE)
        ClarifySessionManager.set_session("test-session", live_state)
        assert ClarifySessionManager.has_active_clarify("test-session")

        # Hitting the retry cap deactivates the session even though it is still stored.
        live_state.retry_count = MAX_CLARIFY_RETRY
        assert not ClarifySessionManager.has_active_clarify("test-session")
|
||||
|
||||
|
||||
class TestThresholds:
    """Pin the clarify threshold constants so config drift is caught by CI."""

    def test_t_high_value(self):
        """Upper threshold: confidence above this never clarifies."""
        assert T_HIGH == 0.75

    def test_t_low_value(self):
        """Lower threshold: confidence below this always clarifies."""
        assert T_LOW == 0.45

    def test_t_high_greater_than_t_low(self):
        """The gray zone between the two thresholds must be non-degenerate."""
        assert T_HIGH > T_LOW

    def test_max_retry_value(self):
        """The clarify loop is capped at three attempts."""
        assert MAX_CLARIFY_RETRY == 3
|
||||
|
|
@ -0,0 +1,308 @@
|
|||
"""
|
||||
Tests for Dialogue API with Slot State Integration.
|
||||
[AC-MRS-SLOT-META-03] 对话 API 与槽位状态集成测试
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from app.api.mid.dialogue import _generate_ask_back_for_missing_slots
|
||||
from app.models.mid.schemas import ExecutionMode, Segment, TraceInfo
|
||||
from app.services.mid.slot_state_aggregator import SlotState
|
||||
|
||||
|
||||
class TestGenerateAskBackResponse:
    """Tests for ask-back generation when required slots are missing."""

    @staticmethod
    async def _ask_back(missing_slots):
        """Invoke the helper under test with a fresh slot state and a mocked session."""
        return await _generate_ask_back_for_missing_slots(
            slot_state=SlotState(),
            missing_slots=missing_slots,
            session=AsyncMock(),
            tenant_id="test_tenant",
        )

    @pytest.mark.asyncio
    async def test_generate_ask_back_with_prompt(self):
        """A configured ask_back_prompt is returned verbatim."""
        reply = await self._ask_back(
            [
                {
                    "slot_key": "region",
                    "label": "地区",
                    "ask_back_prompt": "请问您在哪个地区?",
                }
            ]
        )

        assert reply == "请问您在哪个地区?"

    @pytest.mark.asyncio
    async def test_generate_ask_back_generic(self):
        """Without an ask_back_prompt, a generic template mentioning the label is used."""
        reply = await self._ask_back(
            [
                {
                    "slot_key": "product_line",
                    "label": "产品线",
                    # no ask_back_prompt configured for this slot
                }
            ]
        )

        assert "产品线" in reply

    @pytest.mark.asyncio
    async def test_generate_ask_back_empty_slots(self):
        """An empty missing-slot list falls back to a generic request for more info."""
        reply = await self._ask_back([])

        assert "更多信息" in reply
|
||||
|
||||
|
||||
class TestDialogueAskBackResponse:
    """Tests for the shape of an ask-back dialogue response."""

    def test_dialogue_response_with_ask_back(self):
        """An ask-back response carries one segment plus tracing metadata."""
        from app.models.mid.schemas import DialogueResponse

        ask_segment = Segment(text="请问您咨询的是哪个产品线?", delay_after=0)
        trace_info = TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id="test_request_id",
            generation_id="test_generation_id",
            fallback_reason_code="missing_required_slots",
            kb_tool_called=True,
            kb_hit=False,
        )
        dialogue_response = DialogueResponse(
            segments=[ask_segment],
            trace=trace_info,
        )

        assert len(dialogue_response.segments) == 1
        assert "哪个产品线" in dialogue_response.segments[0].text
        assert dialogue_response.trace.fallback_reason_code == "missing_required_slots"
        assert dialogue_response.trace.kb_tool_called is True
        assert dialogue_response.trace.kb_hit is False
|
||||
|
||||
|
||||
class TestSlotStateAggregationFlow:
    """Tests for the slot-state aggregation flow.

    Exercises SlotStateAggregator.aggregate with memory-recalled slots and
    with tenant slot definitions, mocking out the slot-definition service.
    """

    @pytest.mark.asyncio
    async def test_memory_slots_included_in_state(self):
        """Slots recalled from memory must appear in the aggregated state
        with their value and source preserved."""
        from app.models.mid.schemas import MemorySlot, SlotSource
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        memory_slots = {
            "product_line": MemorySlot(
                key="product_line",
                value="vip_course",
                source=SlotSource.USER_CONFIRMED,
                confidence=1.0,
            )
        }

        # With no tenant slot definitions, the state is driven purely by memory.
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[],
        ):
            state = await aggregator.aggregate(
                memory_slots=memory_slots,
                current_input_slots=None,
                context=None,
            )

        assert "product_line" in state.filled_slots
        assert state.filled_slots["product_line"] == "vip_course"
        assert state.slot_sources["product_line"] == "user_confirmed"

    @pytest.mark.asyncio
    async def test_missing_slots_identified(self):
        """A required slot definition with no filled value is reported as
        missing, together with its configured ask-back prompt."""
        # MagicMock is available from the module-level unittest.mock import.
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        # Stub a single required slot definition with an ask-back prompt.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "region"
        mock_slot_def.required = True
        mock_slot_def.ask_back_prompt = "请问您在哪个地区?"
        mock_slot_def.linked_field_id = None

        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ):
            state = await aggregator.aggregate(
                memory_slots={},
                current_input_slots=None,
                context=None,
            )

        assert len(state.missing_required_slots) == 1
        assert state.missing_required_slots[0]["slot_key"] == "region"
        assert state.missing_required_slots[0]["ask_back_prompt"] == "请问您在哪个地区?"
|
||||
|
||||
|
||||
class TestSlotMetadataLinkage:
    """Tests for linking slot definitions to metadata field definitions."""

    @pytest.mark.asyncio
    async def test_slot_to_field_mapping(self):
        """A slot with linked_field_id resolves to the linked field's key in
        the aggregated slot-to-field map."""
        # MagicMock/patch come from the module-level unittest.mock import;
        # the redundant local import was removed.
        from app.services.mid.slot_state_aggregator import SlotStateAggregator
        from app.services.metadata_field_definition_service import MetadataFieldDefinitionService

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        # Stub a slot definition that points at a metadata field.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "product"
        mock_slot_def.linked_field_id = "field-uuid-123"
        mock_slot_def.required = False
        mock_slot_def.type = "string"
        mock_slot_def.options = None

        # Stub the metadata field the slot links to.
        mock_field = MagicMock()
        mock_field.field_key = "product_line"
        mock_field.label = "产品线"
        mock_field.type = "string"
        mock_field.required = False
        mock_field.options = None

        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ), patch.object(
            MetadataFieldDefinitionService,
            "get_field_definition",
            return_value=mock_field,
        ):
            state = await aggregator.aggregate(
                memory_slots={},
                current_input_slots=None,
                context=None,
            )

        # The slot key must map to the linked field's key.
        assert state.slot_to_field_map.get("product") == "product_line"
|
||||
|
||||
|
||||
class TestBackwardCompatibility:
    """Backward-compatibility tests: KB search must keep working without the
    newer slot_state argument."""

    @pytest.mark.asyncio
    async def test_kb_search_without_slot_state(self):
        """KB retrieval succeeds when slot_state is omitted (None)."""
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicConfig,
            KbSearchDynamicTool,
        )
        from app.services.mid.metadata_filter_builder import MetadataFilterBuilder

        mock_session = AsyncMock()
        kb_tool = KbSearchDynamicTool(
            session=mock_session,
            config=KbSearchDynamicConfig(enabled=True),
        )

        # No filterable fields and an empty retrieval result.
        with patch.object(
            MetadataFilterBuilder,
            "_get_filterable_fields",
            return_value=[],
        ), patch.object(
            kb_tool,
            "_retrieve_with_timeout",
            return_value=[],
        ):
            outcome = await kb_tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context={},
                slot_state=None,  # legacy call shape: no slot state supplied
            )

        # The tool must succeed without falling back.
        assert outcome.success is True
        assert outcome.fallback_reason_code is None

    @pytest.mark.asyncio
    async def test_legacy_context_filter(self):
        """A plain key/value context is still usable as the retrieval filter."""
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicConfig,
            KbSearchDynamicTool,
        )
        from app.services.mid.metadata_filter_builder import MetadataFilterBuilder

        mock_session = AsyncMock()
        kb_tool = KbSearchDynamicTool(
            session=mock_session,
            config=KbSearchDynamicConfig(enabled=True),
        )

        legacy_context = {"product_line": "vip_course", "region": "beijing"}

        with patch.object(
            MetadataFilterBuilder,
            "_get_filterable_fields",
            return_value=[],
        ), patch.object(
            kb_tool,
            "_retrieve_with_timeout",
            return_value=[],
        ):
            outcome = await kb_tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context=legacy_context,
                slot_state=None,
            )

        # The call succeeds, and the simple context is applied directly as the filter.
        assert outcome.success is True
        assert outcome.applied_filter.get("product_line") == "vip_course"
|
||||
|
|
@ -0,0 +1,149 @@
|
|||
"""
|
||||
Tests for field_roles update functionality.
|
||||
[AC-MRS-01] 验证字段角色更新功能
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from app.models.entities import (
|
||||
MetadataFieldDefinition,
|
||||
MetadataFieldDefinitionUpdate,
|
||||
MetadataFieldStatus,
|
||||
)
|
||||
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
|
||||
|
||||
|
||||
class TestFieldRolesUpdate:
    """[AC-MRS-01] Tests for updating a metadata field's role list."""

    @pytest.fixture
    def mock_session(self):
        """A session stub with async execute/flush/commit."""
        session = MagicMock()
        session.execute = AsyncMock()
        session.flush = AsyncMock()
        session.commit = AsyncMock()
        return session

    @pytest.fixture
    def service(self, mock_session):
        """Service under test, bound to the stubbed session."""
        return MetadataFieldDefinitionService(mock_session)

    @pytest.fixture
    def existing_field(self):
        """An ACTIVE field stub at version 1 carrying an initial ["slot"] role."""
        field_stub = MagicMock(spec=MetadataFieldDefinition)
        initial_attrs = {
            "id": uuid.uuid4(),
            "tenant_id": "test-tenant",
            "field_key": "grade",
            "label": "年级",
            "type": "string",
            "required": True,
            "options": None,
            "default_value": None,
            "scope": ["kb_document"],
            "is_filterable": True,
            "is_rank_feature": False,
            "field_roles": ["slot"],  # starting role set
            "status": MetadataFieldStatus.ACTIVE.value,
            "version": 1,
        }
        for attr_name, attr_value in initial_attrs.items():
            setattr(field_stub, attr_name, attr_value)
        return field_stub

    @pytest.mark.asyncio
    async def test_update_field_roles_success(self, service, mock_session, existing_field):
        """[AC-MRS-01] A valid role list replaces the old one and bumps the version."""
        service.get_field_definition = AsyncMock(return_value=existing_field)

        update_payload = MetadataFieldDefinitionUpdate(
            field_roles=["slot", "resource_filter"]
        )

        updated = await service.update_field_definition(
            "test-tenant",
            str(existing_field.id),
            update_payload
        )

        assert updated is not None
        assert updated.field_roles == ["slot", "resource_filter"]
        assert updated.version == 2  # every update increments the version

    @pytest.mark.asyncio
    async def test_update_field_roles_to_empty(self, service, mock_session, existing_field):
        """[AC-MRS-01] Roles may be cleared entirely with an empty list."""
        service.get_field_definition = AsyncMock(return_value=existing_field)

        update_payload = MetadataFieldDefinitionUpdate(
            field_roles=[]
        )

        updated = await service.update_field_definition(
            "test-tenant",
            str(existing_field.id),
            update_payload
        )

        assert updated is not None
        assert updated.field_roles == []

    @pytest.mark.asyncio
    async def test_update_field_roles_invalid_role(self, service, mock_session, existing_field):
        """[AC-MRS-01] An unknown role name is rejected with ValueError."""
        service.get_field_definition = AsyncMock(return_value=existing_field)

        update_payload = MetadataFieldDefinitionUpdate(
            field_roles=["invalid_role"]
        )

        with pytest.raises(ValueError) as exc_info:
            await service.update_field_definition(
                "test-tenant",
                str(existing_field.id),
                update_payload
            )

        assert "无效的字段角色" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_update_without_field_roles_unchanged(self, service, mock_session, existing_field):
        """[AC-MRS-01] Omitting field_roles from the update leaves it untouched."""
        service.get_field_definition = AsyncMock(return_value=existing_field)

        # Only the label is updated; field_roles is not part of the payload.
        update_payload = MetadataFieldDefinitionUpdate(
            label="新年级标签"
        )

        updated = await service.update_field_definition(
            "test-tenant",
            str(existing_field.id),
            update_payload
        )

        assert updated is not None
        assert updated.label == "新年级标签"
        assert updated.field_roles == ["slot"]  # original roles preserved

    @pytest.mark.asyncio
    async def test_update_field_roles_not_found(self, service, mock_session):
        """[AC-MRS-01] Updating a non-existent field returns None."""
        service.get_field_definition = AsyncMock(return_value=None)

        update_payload = MetadataFieldDefinitionUpdate(
            field_roles=["slot"]
        )

        updated = await service.update_field_definition(
            "test-tenant",
            str(uuid.uuid4()),
            update_payload
        )

        assert updated is None
|
||||
|
|
@ -0,0 +1,408 @@
|
|||
"""
|
||||
Unit tests for FusionPolicy.
|
||||
[AC-AISVC-115~AC-AISVC-117] Tests for fusion decision policy.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
import uuid
|
||||
|
||||
from app.services.intent.models import (
|
||||
FusionConfig,
|
||||
FusionResult,
|
||||
LlmJudgeResult,
|
||||
RuleMatchResult,
|
||||
SemanticCandidate,
|
||||
SemanticMatchResult,
|
||||
RouteTrace,
|
||||
)
|
||||
|
||||
|
||||
class FusionPolicy:
    """[AC-AISVC-115] Fusion decision policy.

    Combines the rule-, semantic- and LLM-matcher outputs into one routed
    intent.  Decision branches are evaluated in DECISION_PRIORITY order and
    the first predicate that holds determines the decision reason.
    """

    # Ordered (reason, predicate) pairs evaluated top-down; first hit wins.
    # Predicates take (rule_result, semantic_result, llm_result).
    # NOTE: branches that later index s.candidates[0] guard on s.candidates
    # here so an inconsistent top_score with no candidates cannot raise
    # IndexError inside fuse().
    DECISION_PRIORITY = [
        ("rule_high_confidence", lambda r, s, lj: r.score == 1.0 and r.rule is not None),
        ("llm_judge", lambda r, s, lj: lj.triggered and lj.intent_id is not None),
        ("semantic_override", lambda r, s, lj: r.score == 0 and s.top_score > 0.7 and bool(s.candidates)),
        # Explicit conjunction (the original relied on conditional-expression
        # precedence): rule and semantic agree only when both fired, the
        # semantic list is non-empty, and the top semantic rule is the rule hit.
        ("rule_semantic_agree", lambda r, s, lj: (
            r.score > 0
            and s.top_score > 0.5
            and bool(s.candidates)
            and r.rule_id == s.candidates[0].rule.id
        )),
        ("semantic_fallback", lambda r, s, lj: s.top_score > 0.5 and bool(s.candidates)),
        ("rule_fallback", lambda r, s, lj: r.score > 0),
        ("no_match", lambda r, s, lj: True),
    ]

    def __init__(self, config: FusionConfig):
        """Store the fusion weights/thresholds configuration."""
        self._config = config

    def fuse(
        self,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        llm_result: LlmJudgeResult | None,
    ) -> FusionResult:
        """Fuse the three matcher outputs into a final routing decision.

        Returns a FusionResult carrying the chosen intent (or None), the
        final confidence, the decision reason, clarify information and a
        fully-populated RouteTrace for observability.
        """
        trace = RouteTrace(
            rule_match={
                "rule_id": str(rule_result.rule_id) if rule_result.rule_id else None,
                "match_type": rule_result.match_type,
                "matched_text": rule_result.matched_text,
                "score": rule_result.score,
                "duration_ms": rule_result.duration_ms,
            },
            semantic_match={
                "top_candidates": [
                    {"rule_id": str(c.rule.id), "name": c.rule.name, "score": c.score}
                    for c in semantic_result.candidates
                ],
                "top_score": semantic_result.top_score,
                "duration_ms": semantic_result.duration_ms,
                "skipped": semantic_result.skipped,
                "skip_reason": semantic_result.skip_reason,
            },
            llm_judge={
                "triggered": llm_result.triggered if llm_result else False,
                "intent_id": llm_result.intent_id if llm_result else None,
                "score": llm_result.score if llm_result else 0.0,
                "duration_ms": llm_result.duration_ms if llm_result else 0,
                "tokens_used": llm_result.tokens_used if llm_result else 0,
            },
            fusion={},
        )

        final_intent = None
        final_confidence = 0.0
        decision_reason = "no_match"

        # Substitute an empty LLM result once (the original rebuilt it per
        # iteration) and take the first matching branch.
        effective_llm = llm_result or LlmJudgeResult.empty()
        for reason, condition in self.DECISION_PRIORITY:
            if condition(rule_result, semantic_result, effective_llm):
                decision_reason = reason
                break

        if decision_reason == "rule_high_confidence":
            final_intent = rule_result.rule
            final_confidence = 1.0
        elif decision_reason == "llm_judge" and llm_result:
            final_intent = self._find_rule_by_id(llm_result.intent_id, rule_result, semantic_result)
            final_confidence = llm_result.score
        elif decision_reason == "semantic_override":
            final_intent = semantic_result.candidates[0].rule
            final_confidence = semantic_result.top_score
        elif decision_reason == "rule_semantic_agree":
            final_intent = rule_result.rule
            final_confidence = self._calculate_weighted_confidence(rule_result, semantic_result, llm_result)
        elif decision_reason == "semantic_fallback":
            final_intent = semantic_result.candidates[0].rule
            final_confidence = semantic_result.top_score
        elif decision_reason == "rule_fallback":
            final_intent = rule_result.rule
            final_confidence = rule_result.score

        # Ask the user to clarify when confidence is under the threshold and
        # there is more than one semantic candidate to choose between.
        need_clarify = final_confidence < self._config.clarify_threshold
        clarify_candidates = None
        if need_clarify and len(semantic_result.candidates) > 1:
            clarify_candidates = [c.rule for c in semantic_result.candidates[:3]]

        trace.fusion = {
            "weights": {
                "w_rule": self._config.w_rule,
                "w_semantic": self._config.w_semantic,
                "w_llm": self._config.w_llm,
            },
            "final_confidence": final_confidence,
            "decision_reason": decision_reason,
        }

        return FusionResult(
            final_intent=final_intent,
            final_confidence=final_confidence,
            decision_reason=decision_reason,
            need_clarify=need_clarify,
            clarify_candidates=clarify_candidates,
            trace=trace,
        )

    def _calculate_weighted_confidence(
        self,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        llm_result: LlmJudgeResult | None,
    ) -> float:
        """Weighted average of the available signal scores, clamped to [0, 1].

        The LLM weight only enters the denominator when the LLM actually
        ran, so a skipped LLM judge does not dilute the other signals.
        """
        rule_score = rule_result.score
        semantic_score = semantic_result.top_score if not semantic_result.skipped else 0.0
        llm_score = llm_result.score if llm_result and llm_result.triggered else 0.0

        total_weight = self._config.w_rule + self._config.w_semantic
        if llm_result and llm_result.triggered:
            total_weight += self._config.w_llm

        confidence = (
            self._config.w_rule * rule_score +
            self._config.w_semantic * semantic_score +
            self._config.w_llm * llm_score
        ) / total_weight

        return min(1.0, max(0.0, confidence))

    def _find_rule_by_id(
        self,
        intent_id: str | None,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
    ):
        """Resolve an intent id (as chosen by the LLM judge) back to a rule
        object, searching the rule hit first and then the semantic candidates.
        Returns None when the id is unknown or absent."""
        if not intent_id:
            return None

        if rule_result.rule_id and str(rule_result.rule_id) == intent_id:
            return rule_result.rule

        for candidate in semantic_result.candidates:
            if str(candidate.rule.id) == intent_id:
                return candidate.rule

        return None
|
||||
|
||||
|
||||
@pytest.fixture
def config():
    """Default FusionConfig used by all FusionPolicy tests."""
    return FusionConfig()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_rule():
    """A minimal intent-rule stub with a stable id, name and response type."""
    rule_stub = MagicMock()
    rule_stub.id = uuid.uuid4()
    rule_stub.name = "Test Intent"
    rule_stub.response_type = "rag"
    return rule_stub
|
||||
|
||||
|
||||
class TestFusionPolicy:
|
||||
"""Tests for FusionPolicy class."""
|
||||
|
||||
def test_init(self, config):
|
||||
"""Test FusionPolicy initialization."""
|
||||
policy = FusionPolicy(config)
|
||||
assert policy._config == config
|
||||
|
||||
def test_fuse_rule_high_confidence(self, config, mock_rule):
|
||||
"""Test fusion with rule high confidence."""
|
||||
policy = FusionPolicy(config)
|
||||
|
||||
rule_result = RuleMatchResult(
|
||||
rule_id=mock_rule.id,
|
||||
rule=mock_rule,
|
||||
match_type="keyword",
|
||||
matched_text="test",
|
||||
score=1.0,
|
||||
duration_ms=10,
|
||||
)
|
||||
semantic_result = SemanticMatchResult(
|
||||
candidates=[],
|
||||
top_score=0.0,
|
||||
duration_ms=50,
|
||||
skipped=True,
|
||||
skip_reason="no_semantic_config",
|
||||
)
|
||||
|
||||
result = policy.fuse(rule_result, semantic_result, None)
|
||||
|
||||
assert result.decision_reason == "rule_high_confidence"
|
||||
assert result.final_intent == mock_rule
|
||||
assert result.final_confidence == 1.0
|
||||
assert result.need_clarify is False
|
||||
|
||||
def test_fuse_llm_judge(self, config, mock_rule):
|
||||
"""Test fusion with LLM judge result."""
|
||||
policy = FusionPolicy(config)
|
||||
|
||||
rule_result = RuleMatchResult(
|
||||
rule_id=None,
|
||||
rule=None,
|
||||
match_type=None,
|
||||
matched_text=None,
|
||||
score=0.0,
|
||||
duration_ms=10,
|
||||
)
|
||||
semantic_result = SemanticMatchResult(
|
||||
candidates=[SemanticCandidate(rule=mock_rule, score=0.5)],
|
||||
top_score=0.5,
|
||||
duration_ms=50,
|
||||
skipped=False,
|
||||
skip_reason=None,
|
||||
)
|
||||
llm_result = LlmJudgeResult(
|
||||
intent_id=str(mock_rule.id),
|
||||
intent_name="Test Intent",
|
||||
score=0.85,
|
||||
reasoning="Test reasoning",
|
||||
duration_ms=500,
|
||||
tokens_used=100,
|
||||
triggered=True,
|
||||
)
|
||||
|
||||
result = policy.fuse(rule_result, semantic_result, llm_result)
|
||||
|
||||
assert result.decision_reason == "llm_judge"
|
||||
assert result.final_intent == mock_rule
|
||||
assert result.final_confidence == 0.85
|
||||
|
||||
def test_fuse_semantic_override(self, config, mock_rule):
    """A strong semantic hit should win when the rule matcher finds nothing."""
    policy = FusionPolicy(config)

    miss = RuleMatchResult(
        rule_id=None,
        rule=None,
        match_type=None,
        matched_text=None,
        score=0.0,
        duration_ms=10,
    )
    strong_semantic = SemanticMatchResult(
        candidates=[SemanticCandidate(rule=mock_rule, score=0.85)],
        top_score=0.85,
        duration_ms=50,
        skipped=False,
        skip_reason=None,
    )

    fused = policy.fuse(miss, strong_semantic, None)

    assert fused.decision_reason == "semantic_override"
    assert fused.final_intent == mock_rule
    assert fused.final_confidence == 0.85
def test_fuse_rule_semantic_agree(self, config, mock_rule):
    """When rule and semantic point at the same intent, the rule path wins."""
    policy = FusionPolicy(config)

    hit = RuleMatchResult(
        rule_id=mock_rule.id,
        rule=mock_rule,
        match_type="keyword",
        matched_text="test",
        score=1.0,
        duration_ms=10,
    )
    agreeing_semantic = SemanticMatchResult(
        candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
        top_score=0.8,
        duration_ms=50,
        skipped=False,
        skip_reason=None,
    )

    fused = policy.fuse(hit, agreeing_semantic, None)

    assert fused.decision_reason == "rule_high_confidence"
    assert fused.final_intent == mock_rule
def test_fuse_no_match(self, config):
    """With nothing matched anywhere, fusion must report no_match."""
    policy = FusionPolicy(config)

    empty_rule = RuleMatchResult(
        rule_id=None,
        rule=None,
        match_type=None,
        matched_text=None,
        score=0.0,
        duration_ms=10,
    )
    empty_semantic = SemanticMatchResult(
        candidates=[],
        top_score=0.0,
        duration_ms=50,
        skipped=True,
        skip_reason="no_semantic_config",
    )

    fused = policy.fuse(empty_rule, empty_semantic, None)

    assert fused.decision_reason == "no_match"
    assert fused.final_intent is None
    assert fused.final_confidence == 0.0
def test_fuse_need_clarify(self, config, mock_rule):
    """Two low, close semantic scores should trigger a clarify ask-back."""
    policy = FusionPolicy(config)

    competing_rule = MagicMock()
    competing_rule.id = uuid.uuid4()
    competing_rule.name = "Other Intent"

    miss = RuleMatchResult(
        rule_id=None,
        rule=None,
        match_type=None,
        matched_text=None,
        score=0.0,
        duration_ms=10,
    )
    ambiguous_semantic = SemanticMatchResult(
        candidates=[
            SemanticCandidate(rule=mock_rule, score=0.35),
            SemanticCandidate(rule=competing_rule, score=0.30),
        ],
        top_score=0.35,
        duration_ms=50,
        skipped=False,
        skip_reason=None,
    )

    fused = policy.fuse(miss, ambiguous_semantic, None)

    assert fused.need_clarify is True
    assert fused.clarify_candidates is not None
    assert len(fused.clarify_candidates) == 2
def test_calculate_weighted_confidence(self, config, mock_rule):
    """Test weighted confidence calculation."""
    policy = FusionPolicy(config)

    rule_result = RuleMatchResult(
        rule_id=mock_rule.id,
        rule=mock_rule,
        match_type="keyword",
        matched_text="test",
        score=1.0,
        duration_ms=10,
    )
    semantic_result = SemanticMatchResult(
        candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
        top_score=0.8,
        duration_ms=50,
        skipped=False,
        skip_reason=None,
    )

    confidence = policy._calculate_weighted_confidence(rule_result, semantic_result, None)

    # NOTE(review): 0.5 and 0.3 presumably mirror the default rule/semantic
    # weights of FusionConfig -- confirm, and consider reading them from
    # `config` so the test does not break silently if defaults change.
    expected = (0.5 * 1.0 + 0.3 * 0.8) / (0.5 + 0.3)
    assert abs(confidence - expected) < 0.01
def test_trace_generation(self, config, mock_rule):
    """fuse() must attach a populated trace describing every stage."""
    policy = FusionPolicy(config)

    hit = RuleMatchResult(
        rule_id=mock_rule.id,
        rule=mock_rule,
        match_type="keyword",
        matched_text="test",
        score=1.0,
        duration_ms=10,
    )
    agreeing_semantic = SemanticMatchResult(
        candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
        top_score=0.8,
        duration_ms=50,
        skipped=False,
        skip_reason=None,
    )

    fused = policy.fuse(hit, agreeing_semantic, None)

    assert fused.trace is not None
    assert fused.trace.rule_match["rule_id"] == str(mock_rule.id)
    assert fused.trace.semantic_match["top_score"] == 0.8
    assert fused.trace.fusion["decision_reason"] == "rule_high_confidence"
|
@ -0,0 +1,142 @@
|
|||
"""
|
||||
Tests for intent router.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.intent.router import IntentRouter, RuleMatcher
|
||||
from app.services.intent.models import (
|
||||
FusionConfig,
|
||||
RuleMatchResult,
|
||||
SemanticMatchResult,
|
||||
LlmJudgeResult,
|
||||
FusionResult,
|
||||
)
|
||||
from app.models.entities import IntentRule
|
||||
|
||||
|
||||
class TestRuleMatcher:
    """Test RuleMatcher basic functionality."""

    def test_match_empty_message(self):
        # An empty message can never match: score stays 0 and no rule is chosen.
        matcher = RuleMatcher()
        result = matcher.match("", [])
        assert result.score == 0.0
        assert result.rule is None
        assert result.duration_ms >= 0

    def test_match_empty_rules(self):
        # NOTE(review): the name says "empty rules" but this actually verifies
        # a successful keyword match -- consider renaming to test_match_keyword.
        matcher = RuleMatcher()
        rule = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="Test Rule",
            keywords=["test", "demo"],
            is_enabled=True,
        )
        result = matcher.match("test message", [rule])
        assert result.score == 1.0
        assert result.rule == rule
        assert result.match_type == "keyword"
        assert result.matched_text == "test"

    def test_match_regex(self):
        # A rule with regex patterns should match anywhere in the message.
        matcher = RuleMatcher()
        rule = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="Test Regex Rule",
            patterns=[r"test.*pattern"],
            is_enabled=True,
        )
        result = matcher.match("this is a test regex pattern", [rule])
        assert result.score == 1.0
        assert result.rule == rule
        assert result.match_type == "regex"
        assert "pattern" in result.matched_text

    def test_no_match(self):
        # No keyword is present in the message -> empty result.
        matcher = RuleMatcher()
        rule = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="Test Rule",
            keywords=["specific", "keyword"],
            is_enabled=True,
        )
        result = matcher.match("no match here", [rule])
        assert result.score == 0.0
        assert result.rule is None


    def test_priority_order(self):
        # The high-priority rule is returned.  Note only rule1's keyword
        # actually occurs in the message, so this presumably also passes on a
        # matcher that ignores priority -- confirm RuleMatcher's ordering
        # contract with a message matching both rules.
        matcher = RuleMatcher()
        rule1 = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="High Priority",
            keywords=["high"],
            priority=10,
            is_enabled=True,
        )
        rule2 = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="Low Priority",
            keywords=["low"],
            priority=1,
            is_enabled=True,
        )
        result = matcher.match("high priority message", [rule1, rule2])
        assert result.rule == rule1
        assert result.rule.name == "High Priority"


    def test_disabled_rule(self):
        # Disabled rules must be ignored even when their keywords match.
        matcher = RuleMatcher()
        rule = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="Disabled Rule",
            keywords=["disabled"],
            is_enabled=False,
        )
        result = matcher.match("disabled message", [rule])
        assert result.score == 0.0
        assert result.rule is None
||||
class TestIntentRouterBackwardCompatibility:
    """The legacy IntentRouter entry points should keep working."""

    def test_match_backward_compatible(self):
        """The classic match() call still returns a keyword hit."""
        router = IntentRouter()
        greeting_rule = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="Test Rule",
            keywords=["hello", "hi"],
            is_enabled=True,
        )

        hit = router.match("hello world", [greeting_rule])

        assert hit is not None
        assert hit.rule.name == "Test Rule"
        assert hit.match_type == "keyword"
        assert hit.matched == "hello"

    def test_match_with_stats(self):
        """match_with_stats() returns the result together with the rule id."""
        router = IntentRouter()
        simple_rule = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="Test Rule",
            keywords=["test"],
            is_enabled=True,
        )

        hit, matched_rule_id = router.match_with_stats("test message", [simple_rule])

        assert hit is not None
        assert matched_rule_id == str(simple_rule.id)
|
@ -0,0 +1,468 @@
|
|||
"""
|
||||
Integration tests for IntentRouter.match_hybrid().
|
||||
[AC-AISVC-111] Tests for hybrid routing integration.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import uuid
|
||||
import asyncio
|
||||
|
||||
from app.services.intent.models import (
|
||||
FusionConfig,
|
||||
FusionResult,
|
||||
LlmJudgeInput,
|
||||
LlmJudgeResult,
|
||||
RuleMatchResult,
|
||||
SemanticCandidate,
|
||||
SemanticMatchResult,
|
||||
RouteTrace,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_embedding_provider():
    """Async embedding provider stub returning fixed 768-dim vectors."""
    stub = AsyncMock()
    stub.embed = AsyncMock(return_value=[0.1] * 768)
    stub.embed_batch = AsyncMock(return_value=[[0.1] * 768])
    return stub
||||
@pytest.fixture
def mock_llm_client():
    """Bare async LLM client stub."""
    return AsyncMock()
||||
@pytest.fixture
def config():
    """Default FusionConfig used by most tests."""
    return FusionConfig()
||||
@pytest.fixture
def mock_rule():
    """Fake return-goods intent rule with a precomputed intent vector."""
    fake = MagicMock()
    fake.id = uuid.uuid4()
    fake.name = "Return Intent"
    fake.response_type = "rag"
    fake.keywords = ["退货", "退款"]
    fake.patterns = []
    fake.intent_vector = [0.1] * 768
    fake.semantic_examples = None
    fake.is_enabled = True
    fake.priority = 10
    return fake
||||
@pytest.fixture
def mock_rules(mock_rule):
    """Two fake rules: the return intent plus an order-query intent."""
    second = MagicMock()
    second.id = uuid.uuid4()
    second.name = "Order Query"
    second.response_type = "rag"
    second.keywords = ["订单", "查询"]
    second.patterns = []
    second.intent_vector = [0.5] * 768
    second.semantic_examples = None
    second.is_enabled = True
    second.priority = 5

    return [mock_rule, second]
||||
class MockRuleMatcher:
    """Keyword-only stand-in for RuleMatcher.

    Scans rules in order and returns a perfect-score hit on the first
    enabled rule whose keyword occurs (case-insensitively) in the message.
    """

    def match(self, message: str, rules: list) -> RuleMatchResult:
        import time
        started = time.time()
        haystack = message.lower()

        def elapsed_ms() -> int:
            return int((time.time() - started) * 1000)

        for candidate in rules:
            if not candidate.is_enabled:
                continue
            # First keyword contained in the message wins.
            keyword_hit = next(
                (kw for kw in (candidate.keywords or []) if kw.lower() in haystack),
                None,
            )
            if keyword_hit is not None:
                return RuleMatchResult(
                    rule_id=candidate.id,
                    rule=candidate,
                    match_type="keyword",
                    matched_text=keyword_hit,
                    score=1.0,
                    duration_ms=elapsed_ms(),
                )

        # Nothing matched: empty result with zero score.
        return RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=elapsed_ms(),
        )
||||
class MockSemanticMatcher:
    """Stub semantic matcher: the first vectorised rule scores 0.85."""

    def __init__(self, config):
        self._config = config

    async def match(self, message: str, rules: list, tenant_id: str, top_k: int = 3) -> SemanticMatchResult:
        import time
        started = time.time()

        # Feature flag off: report a skipped stage without scoring anything.
        if not self._config.semantic_matcher_enabled:
            return SemanticMatchResult(
                candidates=[],
                top_score=0.0,
                duration_ms=0,
                skipped=True,
                skip_reason="disabled",
            )

        # Only the first rule carrying an intent vector becomes a candidate.
        found = [
            SemanticCandidate(rule=r, score=0.85)
            for r in rules
            if r.intent_vector
        ][:1]

        return SemanticMatchResult(
            candidates=found[:top_k],
            top_score=found[0].score if found else 0.0,
            duration_ms=int((time.time() - started) * 1000),
            skipped=False,
            skip_reason=None,
        )
||||
class MockLlmJudge:
    """Stub LLM judge: real trigger heuristics, canned 0.9 verdict."""

    def __init__(self, config):
        self._config = config

    def should_trigger(self, rule_result, semantic_result, config=None) -> tuple:
        """Decide whether arbitration is needed; returns (flag, reason)."""
        cfg = config or self._config
        if not cfg.llm_judge_enabled:
            return False, "disabled"

        # Conflict: both stages matched, on different rules, with close scores.
        top_candidate = semantic_result.candidates[0] if semantic_result.candidates else None
        has_conflict = (
            rule_result.score > 0
            and semantic_result.top_score > 0
            and top_candidate is not None
            and rule_result.rule_id != top_candidate.rule.id
            and abs(rule_result.score - semantic_result.top_score) < cfg.conflict_threshold
        )
        if has_conflict:
            return True, "rule_semantic_conflict"

        # Gray zone: the best score is neither clearly weak nor clearly strong.
        best = max(rule_result.score, semantic_result.top_score)
        if cfg.min_trigger_threshold < best < cfg.gray_zone_threshold:
            return True, "gray_zone"

        return False, ""

    async def judge(self, input_data: LlmJudgeInput, tenant_id: str) -> LlmJudgeResult:
        """Always pick the first candidate with a fixed 0.9 score."""
        first = input_data.candidates[0] if input_data.candidates else None
        return LlmJudgeResult(
            intent_id=first["id"] if first else None,
            intent_name=first["name"] if first else None,
            score=0.9,
            reasoning="Test arbitration",
            duration_ms=500,
            tokens_used=100,
            triggered=True,
        )
||||
class MockFusionPolicy:
    """Mock FusionPolicy for testing."""

    # Ordered decision table: the first predicate that holds determines the
    # decision reason.  Each predicate receives (rule_result, semantic_result,
    # llm_result); fuse() substitutes an empty LlmJudgeResult so `l` is never
    # None inside a predicate.
    DECISION_PRIORITY = [
        ("rule_high_confidence", lambda r, s, l: r.score == 1.0 and r.rule is not None),
        ("llm_judge", lambda r, s, l: l.triggered and l.intent_id is not None),
        ("semantic_override", lambda r, s, l: r.score == 0 and s.top_score > 0.7),
        ("no_match", lambda r, s, l: True),
    ]

    def __init__(self, config):
        self._config = config

    def fuse(self, rule_result, semantic_result, llm_result) -> FusionResult:
        # Walk the priority table; the final catch-all guarantees a reason.
        # NOTE(review): relies on LlmJudgeResult.empty() existing and having a
        # falsy `triggered` / None `intent_id` -- confirm the factory's shape.
        decision_reason = "no_match"
        for reason, condition in self.DECISION_PRIORITY:
            if condition(rule_result, semantic_result, llm_result or LlmJudgeResult.empty()):
                decision_reason = reason
                break

        final_intent = None
        final_confidence = 0.0

        if decision_reason == "rule_high_confidence":
            final_intent = rule_result.rule
            final_confidence = 1.0
        elif decision_reason == "llm_judge" and llm_result:
            final_intent = self._find_rule_by_id(llm_result.intent_id, rule_result, semantic_result)
            final_confidence = llm_result.score
        elif decision_reason == "semantic_override":
            final_intent = semantic_result.candidates[0].rule
            final_confidence = semantic_result.top_score

        return FusionResult(
            final_intent=final_intent,
            final_confidence=final_confidence,
            decision_reason=decision_reason,
            # Low confidence outcomes are flagged for an ask-back clarify turn.
            need_clarify=final_confidence < 0.4,
            clarify_candidates=None,
            trace=RouteTrace(),
        )

    def _find_rule_by_id(self, intent_id, rule_result, semantic_result):
        # Resolve an intent id back to its rule object: check the rule hit
        # first, then fall back to the semantic candidates.
        if not intent_id:
            return None
        if rule_result.rule_id and str(rule_result.rule_id) == intent_id:
            return rule_result.rule
        for c in semantic_result.candidates:
            if str(c.rule.id) == intent_id:
                return c.rule
        return None
||||
class MockIntentRouter:
    """Mock IntentRouter for testing match_hybrid."""

    def __init__(self, rule_matcher, semantic_matcher, llm_judge, fusion_policy, config=None):
        self._rule_matcher = rule_matcher
        self._semantic_matcher = semantic_matcher
        self._llm_judge = llm_judge
        self._fusion_policy = fusion_policy
        self._config = config or FusionConfig()

    async def match_hybrid(
        self,
        message: str,
        rules: list,
        tenant_id: str,
        config: FusionConfig | None = None,
    ) -> FusionResult:
        """Run rule and semantic matching concurrently, optionally arbitrate
        with the LLM judge, then fuse everything into a FusionResult."""
        effective_config = config or self._config

        # Rule matching is synchronous, so it is pushed onto a worker thread
        # and awaited alongside the async semantic match.
        rule_result, semantic_result = await asyncio.gather(
            asyncio.to_thread(self._rule_matcher.match, message, rules),
            self._semantic_matcher.match(message, rules, tenant_id),
        )

        llm_result = None
        should_trigger, trigger_reason = self._llm_judge.should_trigger(
            rule_result, semantic_result, effective_config
        )

        # Only pay for the LLM call when the heuristics demand arbitration.
        if should_trigger:
            candidates = self._build_llm_candidates(rule_result, semantic_result)
            llm_result = await self._llm_judge.judge(
                LlmJudgeInput(
                    message=message,
                    candidates=candidates,
                    conflict_type=trigger_reason,
                ),
                tenant_id,
            )

        fusion_result = self._fusion_policy.fuse(
            rule_result, semantic_result, llm_result
        )

        return fusion_result

    def _build_llm_candidates(self, rule_result, semantic_result) -> list:
        # De-duplicated candidate list for the judge: the rule hit first,
        # then up to three semantic candidates not already present.
        candidates = []

        if rule_result.rule:
            candidates.append({
                "id": str(rule_result.rule_id),
                "name": rule_result.rule.name,
                "description": f"匹配方式: {rule_result.match_type}",
            })

        for candidate in semantic_result.candidates[:3]:
            if not any(c["id"] == str(candidate.rule.id) for c in candidates):
                candidates.append({
                    "id": str(candidate.rule.id),
                    "name": candidate.rule.name,
                    "description": f"语义相似度: {candidate.score:.2f}",
                })

        return candidates
||||
class TestIntentRouterHybrid:
    """Tests for IntentRouter.match_hybrid() integration."""

    @pytest.mark.asyncio
    async def test_match_hybrid_rule_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with rule match."""
        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        # The message contains the keyword of the first rule -> perfect hit.
        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rules[0]
        assert result.final_confidence == 1.0

    @pytest.mark.asyncio
    async def test_match_hybrid_semantic_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with semantic match only."""
        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        # No keyword matches, so only the semantic stage (score 0.85) fires.
        result = await router.match_hybrid("商品有问题", mock_rules, "tenant-1")

        assert result.decision_reason == "semantic_override"
        assert result.final_intent is not None
        assert result.final_confidence > 0.7

    @pytest.mark.asyncio
    async def test_match_hybrid_parallel_execution(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test that rule and semantic matching run in parallel."""
        import time

        # A semantic matcher with an artificial 100ms delay: if the stages ran
        # serially, total latency would exceed the 200ms budget asserted below.
        class SlowSemanticMatcher(MockSemanticMatcher):
            async def match(self, message, rules, tenant_id, top_k=3):
                await asyncio.sleep(0.1)
                return await super().match(message, rules, tenant_id, top_k)

        rule_matcher = MockRuleMatcher()
        semantic_matcher = SlowSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        start_time = time.time()
        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")
        elapsed = time.time() - start_time

        assert elapsed < 0.2
        assert result is not None

    @pytest.mark.asyncio
    async def test_match_hybrid_llm_judge_triggered(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with LLM judge triggered."""
        # Local config shadows the fixture: widen the conflict threshold so a
        # rule-vs-semantic disagreement can trigger arbitration.
        config = FusionConfig(conflict_threshold=0.3)

        # Force the top semantic candidate onto the *other* rule at 0.9 so it
        # disagrees with the keyword hit on rule[0].
        class ConflictSemanticMatcher(MockSemanticMatcher):
            async def match(self, message, rules, tenant_id, top_k=3):
                result = await super().match(message, rules, tenant_id, top_k)
                if result.candidates:
                    result.candidates[0] = SemanticCandidate(rule=rules[1], score=0.9)
                    result.top_score = 0.9
                return result

        rule_matcher = MockRuleMatcher()
        semantic_matcher = ConflictSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        # Either the perfect rule hit wins outright or the judge arbitrates.
        assert result.decision_reason in ["rule_high_confidence", "llm_judge"]

    @pytest.mark.asyncio
    async def test_match_hybrid_no_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with no match."""
        # A semantic matcher that always reports a skipped, empty stage.
        class NoMatchSemanticMatcher(MockSemanticMatcher):
            async def match(self, message, rules, tenant_id, top_k=3):
                return SemanticMatchResult(
                    candidates=[],
                    top_score=0.0,
                    duration_ms=10,
                    skipped=True,
                    skip_reason="no_semantic_config",
                )

        rule_matcher = MockRuleMatcher()
        semantic_matcher = NoMatchSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        result = await router.match_hybrid("随便说说", mock_rules, "tenant-1")

        assert result.decision_reason == "no_match"
        assert result.final_intent is None
        assert result.final_confidence == 0.0

    @pytest.mark.asyncio
    async def test_match_hybrid_semantic_disabled(self, mock_embedding_provider, mock_llm_client, mock_rules):
        """Test hybrid routing with semantic matcher disabled."""
        config = FusionConfig(semantic_matcher_enabled=False)

        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        # The keyword path alone should still produce a confident decision.
        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rules[0]

    @pytest.mark.asyncio
    async def test_match_hybrid_llm_disabled(self, mock_embedding_provider, mock_llm_client, mock_rules):
        """Test hybrid routing with LLM judge disabled."""
        config = FusionConfig(llm_judge_enabled=False)

        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        assert result.decision_reason == "rule_high_confidence"

    @pytest.mark.asyncio
    async def test_match_hybrid_trace_generated(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test that route trace is generated."""
        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)

        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )

        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")

        assert result.trace is not None
|
@ -0,0 +1,308 @@
|
|||
"""
|
||||
Tests for KB Search Dynamic Tool with Slot State Integration.
|
||||
[AC-MRS-SLOT-META-02] KB 检索与槽位状态集成测试
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from app.models.mid.schemas import MemorySlot, SlotSource, ToolCallStatus
|
||||
from app.services.mid.kb_search_dynamic_tool import (
|
||||
KbSearchDynamicConfig,
|
||||
KbSearchDynamicTool,
|
||||
)
|
||||
from app.services.mid.slot_state_aggregator import SlotState
|
||||
|
||||
|
||||
class TestKbSearchDynamicWithSlotState:
    """Integration tests for the KB Search Dynamic Tool with slot state."""

    @pytest.fixture
    def mock_session(self):
        """Mocked database session."""
        return AsyncMock()

    @pytest.fixture
    def kb_tool(self, mock_session):
        """Build a KB tool instance with a small, explicit config."""
        return KbSearchDynamicTool(
            session=mock_session,
            config=KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=10000,
                min_score_threshold=0.5,
            ),
        )

    @pytest.mark.asyncio
    async def test_execute_with_missing_required_slots(self, kb_tool, mock_session):
        """A missing required slot should short-circuit into an ask-back response."""
        slot_state = SlotState(
            filled_slots={},
            missing_required_slots=[
                {
                    "slot_key": "product_line",
                    "label": "产品线",
                    "reason": "required_slot_missing",
                    "ask_back_prompt": "请问您咨询的是哪个产品线?",
                }
            ],
        )

        result = await kb_tool.execute(
            query="退款政策",
            tenant_id="test_tenant",
            context={},
            slot_state=slot_state,
        )

        # Retrieval must not run; the failure is surfaced as an explicit code
        # plus the ask-back payload and a trace entry.
        assert result.success is False
        assert result.fallback_reason_code == "MISSING_REQUIRED_SLOTS"
        assert len(result.missing_required_slots) == 1
        assert result.missing_required_slots[0]["slot_key"] == "product_line"
        assert result.tool_trace is not None
        assert result.tool_trace.error_code == "MISSING_REQUIRED_SLOTS"

    @pytest.mark.asyncio
    async def test_execute_with_filled_slots(self, kb_tool, mock_session):
        """Filled slots should be used to build the metadata filter."""
        slot_state = SlotState(
            filled_slots={"product_line": "vip_course"},
            missing_required_slots=[],
            slot_to_field_map={"product_line": "product_line"},
        )

        # Stub out the filter builder.
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[])
        kb_tool._filter_builder = mock_filter_builder

        # Stub retrieval to return no hits.
        with patch.object(kb_tool, "_retrieve_with_timeout", return_value=[]):
            result = await kb_tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context={},
                slot_state=slot_state,
            )

        # Execution succeeds even though nothing was retrieved.
        assert result.success is True

    @pytest.mark.asyncio
    async def test_build_filter_from_slot_state_priority(self, kb_tool, mock_session):
        """Filter value source priority: slot > context > default."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo

        slot_state = SlotState(
            filled_slots={"product_line": "from_slot"},
            missing_required_slots=[],
            slot_to_field_map={"product_line": "product_line"},
        )

        context = {"product_line": "from_context"}

        # Filterable field that also carries a default value.
        mock_field = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=True,
            options=None,
            default_value="from_default",
            is_filterable=True,
        )

        # Stub the filter builder.
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[mock_field])
        mock_filter_builder._build_field_filter = MagicMock(return_value={"$eq": "from_slot"})
        kb_tool._filter_builder = mock_filter_builder

        filter_result = await kb_tool._build_filter_from_slot_state(
            tenant_id="test_tenant",
            slot_state=slot_state,
            context=context,
        )

        # The slot value must win (highest priority).
        assert "product_line" in filter_result
        mock_filter_builder._build_field_filter.assert_called_once()
        call_args = mock_filter_builder._build_field_filter.call_args
        assert call_args[0][1] == "from_slot"

    @pytest.mark.asyncio
    async def test_build_filter_uses_context_when_slot_empty(self, kb_tool, mock_session):
        """When the slot is empty, the context value should be used."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo

        slot_state = SlotState(
            filled_slots={},  # no slots filled
            missing_required_slots=[],
            slot_to_field_map={},
        )

        context = {"product_line": "from_context"}

        # Filterable field with a default value that must NOT be used here.
        mock_field = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=True,
            options=None,
            default_value="from_default",
            is_filterable=True,
        )

        # Stub the filter builder.
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[mock_field])
        mock_filter_builder._build_field_filter = MagicMock(return_value={"$eq": "from_context"})
        kb_tool._filter_builder = mock_filter_builder

        filter_result = await kb_tool._build_filter_from_slot_state(
            tenant_id="test_tenant",
            slot_state=slot_state,
            context=context,
        )

        # The context value must be used.
        assert "product_line" in filter_result
        call_args = mock_filter_builder._build_field_filter.call_args
        assert call_args[0][1] == "from_context"

    @pytest.mark.asyncio
    async def test_build_filter_uses_default_when_no_other(self, kb_tool, mock_session):
        """When both slot and context are empty, the field default is used."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo

        slot_state = SlotState(
            filled_slots={},
            missing_required_slots=[],
            slot_to_field_map={},
        )

        context = {}

        # Filterable field that only has a default value.
        mock_field = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=False,  # optional
            options=None,
            default_value="from_default",
            is_filterable=True,
        )

        # Stub the filter builder.
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[mock_field])
        mock_filter_builder._build_field_filter = MagicMock(return_value={"$eq": "from_default"})
        kb_tool._filter_builder = mock_filter_builder

        filter_result = await kb_tool._build_filter_from_slot_state(
            tenant_id="test_tenant",
            slot_state=slot_state,
            context=context,
        )

        # The default value must be used.
        assert "product_line" in filter_result
        call_args = mock_filter_builder._build_field_filter.call_args
        assert call_args[0][1] == "from_default"
||||
class TestKbSearchDynamicSlotMapping:
    """Tests for routing slot values to metadata fields via slot_to_field_map."""

    @pytest.mark.asyncio
    async def test_slot_to_field_mapping_in_filter(self):
        """A slot keyed "product" is mapped onto the metadata field "product_line"."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo
        from unittest.mock import AsyncMock

        session_stub = AsyncMock()
        tool = KbSearchDynamicTool(session=session_stub)

        # Slot key deliberately differs from the target field key.
        state = SlotState(
            filled_slots={"product": "vip_course"},
            missing_required_slots=[],
            slot_to_field_map={"product": "product_line"},
        )

        # Filterable field declared under the mapped field key.
        field_info = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=True,
            options=None,
            default_value=None,
            is_filterable=True,
        )

        builder_stub = MagicMock()
        builder_stub._get_filterable_fields = AsyncMock(return_value=[field_info])
        builder_stub._build_field_filter = MagicMock(return_value={"$eq": "vip_course"})
        tool._filter_builder = builder_stub

        built_filter = await tool._build_filter_from_slot_state(
            tenant_id="test_tenant",
            slot_state=state,
            context={},
        )

        # Mapped field key is present and the slot value reached the builder.
        assert "product_line" in built_filter
        assert builder_stub._build_field_filter.call_args[0][1] == "vip_course"
|
||||
|
||||
|
||||
class TestKbSearchDynamicDebugInfo:
    """Tests for the filter debug payload emitted by execute()."""

    @pytest.mark.asyncio
    async def test_filter_debug_includes_sources(self):
        """execute() reports filter debug info (value sources) alongside results."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo
        from unittest.mock import AsyncMock

        session_stub = AsyncMock()
        tool = KbSearchDynamicTool(session=session_stub)

        state = SlotState(
            filled_slots={"product_line": "vip_course"},
            missing_required_slots=[],
            slot_to_field_map={"product_line": "product_line"},
        )

        field_info = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=True,
            options=None,
            default_value=None,
            is_filterable=True,
        )

        builder_stub = MagicMock()
        builder_stub._get_filterable_fields = AsyncMock(return_value=[field_info])
        builder_stub._build_field_filter = MagicMock(return_value={"$eq": "vip_course"})
        tool._filter_builder = builder_stub

        # Bypass the actual retrieval backend; an empty hit list is enough here.
        with patch.object(tool, "_retrieve_with_timeout", return_value=[]):
            outcome = await tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context={},
                slot_state=state,
            )

        # Debug info about filter value sources must be attached to the result.
        assert outcome.filter_debug is not None
|
||||
|
|
@ -0,0 +1,291 @@
|
|||
"""
|
||||
Unit tests for LlmJudge.
|
||||
[AC-AISVC-118, AC-AISVC-119] Tests for LLM-based intent arbitration.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import uuid
|
||||
|
||||
from app.services.intent.llm_judge import LlmJudge
|
||||
from app.services.intent.models import (
|
||||
FusionConfig,
|
||||
LlmJudgeInput,
|
||||
LlmJudgeResult,
|
||||
RuleMatchResult,
|
||||
SemanticCandidate,
|
||||
SemanticMatchResult,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm_client():
    """Provide an async-capable stand-in for the LLM client."""
    return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
def config():
    """Provide a FusionConfig populated with its default settings."""
    fusion_config = FusionConfig()
    return fusion_config
|
||||
|
||||
|
||||
@pytest.fixture
def mock_rule():
    """Provide a mock intent rule with a fresh id and a fixed name."""
    stub = MagicMock()
    # `name` is a reserved Mock constructor argument, so set it via configure_mock.
    stub.configure_mock(id=uuid.uuid4(), name="Test Intent")
    return stub
|
||||
|
||||
|
||||
class TestLlmJudge:
    """Tests for LlmJudge class.

    Covers trigger gating (should_trigger), the async judge() call with
    success/timeout/error paths, and response parsing helpers.
    """

    def test_init(self, mock_llm_client, config):
        """Test LlmJudge initialization."""
        judge = LlmJudge(mock_llm_client, config)
        # Constructor stores collaborators on private attributes as-is.
        assert judge._llm_client == mock_llm_client
        assert judge._config == config

    def test_should_trigger_disabled(self, mock_llm_client):
        """Test should_trigger when LLM judge is disabled."""
        # llm_judge_enabled=False must short-circuit every other trigger rule.
        config = FusionConfig(llm_judge_enabled=False)
        judge = LlmJudge(mock_llm_client, config)

        rule_result = RuleMatchResult(
            rule_id=uuid.uuid4(),
            rule=MagicMock(),
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        triggered, reason = judge.should_trigger(rule_result, semantic_result)
        assert triggered is False
        assert reason == "disabled"

    def test_should_trigger_rule_semantic_conflict(self, mock_llm_client, config, mock_rule):
        """Test should_trigger for rule vs semantic conflict."""
        judge = LlmJudge(mock_llm_client, config)

        # Rule engine matched one intent...
        rule_result = RuleMatchResult(
            rule_id=uuid.uuid4(),
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )

        # ...while the semantic top candidate is a DIFFERENT intent with high score.
        other_rule = MagicMock()
        other_rule.id = uuid.uuid4()
        other_rule.name = "Other Intent"

        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=other_rule, score=0.95)],
            top_score=0.95,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        triggered, reason = judge.should_trigger(rule_result, semantic_result)
        assert triggered is True
        assert reason == "rule_semantic_conflict"

    def test_should_trigger_gray_zone(self, mock_llm_client, config, mock_rule):
        """Test should_trigger for gray zone scenario."""
        judge = LlmJudge(mock_llm_client, config)

        # No rule match at all (all fields empty/zero).
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )

        # Semantic score 0.5 sits in the ambiguous middle band — TODO confirm
        # against FusionConfig's gray-zone thresholds.
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.5)],
            top_score=0.5,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        triggered, reason = judge.should_trigger(rule_result, semantic_result)
        assert triggered is True
        assert reason == "gray_zone"

    def test_should_trigger_multi_intent(self, mock_llm_client, config, mock_rule):
        """Test should_trigger for multi-intent scenario."""
        judge = LlmJudge(mock_llm_client, config)

        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )

        other_rule = MagicMock()
        other_rule.id = uuid.uuid4()
        other_rule.name = "Other Intent"

        # Two semantic candidates with close scores (0.8 vs 0.75) → ambiguous.
        semantic_result = SemanticMatchResult(
            candidates=[
                SemanticCandidate(rule=mock_rule, score=0.8),
                SemanticCandidate(rule=other_rule, score=0.75),
            ],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        triggered, reason = judge.should_trigger(rule_result, semantic_result)
        assert triggered is True
        assert reason == "multi_intent"

    def test_should_not_trigger_high_confidence(self, mock_llm_client, config, mock_rule):
        """Test should_trigger returns False for high confidence match."""
        judge = LlmJudge(mock_llm_client, config)

        # Rule and semantic results AGREE on the same rule with high scores,
        # so no LLM arbitration is needed.
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )

        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.9)],
            top_score=0.9,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )

        triggered, reason = judge.should_trigger(rule_result, semantic_result)
        assert triggered is False

    @pytest.mark.asyncio
    async def test_judge_success(self, mock_llm_client, config):
        """Test successful LLM judge."""
        from app.services.llm.base import LLMResponse

        # Simulate the LLM returning a well-formed JSON verdict.
        mock_response = LLMResponse(
            content='{"intent_id": "test-id", "intent_name": "Test", "confidence": 0.85, "reasoning": "Test reasoning"}',
            model="gpt-4",
            usage={"total_tokens": 100},
        )
        mock_llm_client.generate = AsyncMock(return_value=mock_response)

        judge = LlmJudge(mock_llm_client, config)
        input_data = LlmJudgeInput(
            message="test message",
            candidates=[{"id": "test-id", "name": "Test", "description": "Test intent"}],
            conflict_type="gray_zone",
        )

        result = await judge.judge(input_data, "tenant-1")

        # Each JSON field is surfaced on the result; token usage is copied over.
        assert result.triggered is True
        assert result.intent_id == "test-id"
        assert result.intent_name == "Test"
        assert result.score == 0.85
        assert result.reasoning == "Test reasoning"
        assert result.tokens_used == 100

    @pytest.mark.asyncio
    async def test_judge_timeout(self, mock_llm_client, config):
        """Test LLM judge timeout."""
        import asyncio
        # A timeout must not raise: judge() degrades to a no-decision result.
        mock_llm_client.generate = AsyncMock(side_effect=asyncio.TimeoutError())

        judge = LlmJudge(mock_llm_client, config)
        input_data = LlmJudgeInput(
            message="test message",
            candidates=[{"id": "test-id", "name": "Test"}],
            conflict_type="gray_zone",
        )

        result = await judge.judge(input_data, "tenant-1")

        assert result.triggered is True
        assert result.intent_id is None
        assert "timeout" in result.reasoning.lower()

    @pytest.mark.asyncio
    async def test_judge_error(self, mock_llm_client, config):
        """Test LLM judge error handling."""
        # Any generic client failure is likewise swallowed into the result.
        mock_llm_client.generate = AsyncMock(side_effect=Exception("LLM error"))

        judge = LlmJudge(mock_llm_client, config)
        input_data = LlmJudgeInput(
            message="test message",
            candidates=[{"id": "test-id", "name": "Test"}],
            conflict_type="gray_zone",
        )

        result = await judge.judge(input_data, "tenant-1")

        assert result.triggered is True
        assert result.intent_id is None
        assert "error" in result.reasoning.lower()

    def test_parse_response_valid_json(self, mock_llm_client, config):
        """Test parsing valid JSON response."""
        judge = LlmJudge(mock_llm_client, config)

        content = '{"intent_id": "test", "confidence": 0.9}'
        result = judge._parse_response(content)

        assert result["intent_id"] == "test"
        assert result["confidence"] == 0.9

    def test_parse_response_with_markdown(self, mock_llm_client, config):
        """Test parsing JSON response with markdown code block."""
        judge = LlmJudge(mock_llm_client, config)

        # LLMs often wrap JSON in ```json fences; the parser must strip them.
        content = '```json\n{"intent_id": "test", "confidence": 0.9}\n```'
        result = judge._parse_response(content)

        assert result["intent_id"] == "test"
        assert result["confidence"] == 0.9

    def test_parse_response_invalid_json(self, mock_llm_client, config):
        """Test parsing invalid JSON response."""
        judge = LlmJudge(mock_llm_client, config)

        # Garbage input yields an empty dict rather than raising.
        content = "This is not valid JSON"
        result = judge._parse_response(content)

        assert result == {}

    def test_llm_judge_result_empty(self):
        """Test LlmJudgeResult.empty() class method."""
        result = LlmJudgeResult.empty()

        # empty() produces the all-defaults, not-triggered sentinel.
        assert result.intent_id is None
        assert result.intent_name is None
        assert result.score == 0.0
        assert result.reasoning is None
        assert result.duration_ms == 0
        assert result.tokens_used == 0
        assert result.triggered is False
|
||||
|
|
@ -0,0 +1,299 @@
|
|||
"""
|
||||
Tests for Scene Slot Bundle Loader.
|
||||
[AC-SCENE-SLOT-02] 场景槽位包加载器测试
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from app.services.mid.scene_slot_bundle_loader import (
|
||||
SceneSlotBundleLoader,
|
||||
SceneSlotContext,
|
||||
SlotInfo,
|
||||
)
|
||||
from app.models.entities import (
|
||||
SceneSlotBundle,
|
||||
SceneSlotBundleStatus,
|
||||
SlotDefinition,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_session():
    """Provide an async DB-session stand-in whose execute() is awaitable."""
    db = AsyncMock()
    db.execute = AsyncMock()
    return db
|
||||
|
||||
|
||||
@pytest.fixture
def sample_slot_definitions():
    """Provide three slot definitions: two required, one optional."""
    specs = [
        ("course_type", True, "请问您想咨询哪种类型的课程?"),
        ("grade", True, "请问您是几年级?"),
        ("region", False, "请问您在哪个地区?"),
    ]
    return [
        SlotDefinition(
            id=uuid4(),
            tenant_id="test_tenant",
            slot_key=slot_key,
            type="string",
            required=required,
            ask_back_prompt=prompt,
        )
        for slot_key, required, prompt in specs
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def sample_bundle(sample_slot_definitions):
    """Provide an ACTIVE bundle for the open_consult scene (2 required, 1 optional slot)."""
    bundle = SceneSlotBundle(
        id=uuid4(),
        tenant_id="test_tenant",
        scene_key="open_consult",
        scene_name="开放咨询",
        description="开放咨询场景的槽位配置",
        required_slots=["course_type", "grade"],
        optional_slots=["region"],
        slot_priority=["course_type", "grade", "region"],
        completion_threshold=1.0,
        ask_back_order="priority",
        status=SceneSlotBundleStatus.ACTIVE.value,
        version=1,
    )
    return bundle
|
||||
|
||||
|
||||
class TestSceneSlotContext:
    """Behavioral tests for the SceneSlotContext value object."""

    @staticmethod
    def _build_context(required, optional, **extra):
        """Assemble a context from SlotInfo lists plus optional keyword overrides."""
        return SceneSlotContext(
            scene_key="test_scene",
            scene_name="测试场景",
            required_slots=required,
            optional_slots=optional,
            **extra,
        )

    def test_get_all_slot_keys(self):
        """All required and optional slot keys are reported together."""
        ctx = self._build_context(
            [SlotInfo(slot_key="course_type", type="string", required=True)],
            [SlotInfo(slot_key="region", type="string", required=False)],
        )

        keys = ctx.get_all_slot_keys()

        assert "course_type" in keys
        assert "region" in keys
        assert len(keys) == 2

    def test_get_missing_slots(self):
        """Only required slots without a filled value count as missing."""
        ctx = self._build_context(
            [
                SlotInfo(slot_key="course_type", type="string", required=True, ask_back_prompt="请问课程类型?"),
                SlotInfo(slot_key="grade", type="string", required=True, ask_back_prompt="请问年级?"),
            ],
            [],
        )

        missing = ctx.get_missing_slots({"course_type": "数学"})

        assert len(missing) == 1
        assert missing[0]["slot_key"] == "grade"

    def test_get_ordered_missing_slots_priority(self):
        """With ask_back_order="priority", missing slots follow slot_priority order."""
        ctx = self._build_context(
            [
                SlotInfo(slot_key="course_type", type="string", required=True),
                SlotInfo(slot_key="grade", type="string", required=True),
            ],
            [],
            slot_priority=["grade", "course_type"],
            ask_back_order="priority",
        )

        ordered = ctx.get_ordered_missing_slots({})

        assert len(ordered) == 2
        assert ordered[0]["slot_key"] == "grade"
        assert ordered[1]["slot_key"] == "course_type"

    def test_get_completion_ratio(self):
        """Ratio is filled-required over total-required: 1 of 2 → 0.5."""
        ctx = self._build_context(
            [
                SlotInfo(slot_key="course_type", type="string", required=True),
                SlotInfo(slot_key="grade", type="string", required=True),
            ],
            [],
            completion_threshold=0.5,
        )

        assert ctx.get_completion_ratio({"course_type": "数学"}) == 0.5

    def test_is_complete(self):
        """With threshold 1.0, completeness requires every required slot filled."""
        ctx = self._build_context(
            [
                SlotInfo(slot_key="course_type", type="string", required=True),
                SlotInfo(slot_key="grade", type="string", required=True),
            ],
            [],
            completion_threshold=1.0,
        )

        assert ctx.is_complete({"course_type": "数学", "grade": "高一"}) is True
        assert ctx.is_complete({"course_type": "数学"}) is False
|
||||
|
||||
|
||||
class TestSceneSlotBundleLoader:
    """Test cases for SceneSlotBundleLoader.

    The loader composes a bundle service and a slot-definition service; tests
    patch those collaborators (or the loader's own methods) rather than the DB.
    """

    @pytest.mark.asyncio
    async def test_load_scene_context(self, mock_session, sample_bundle, sample_slot_definitions):
        """Test loading scene context."""
        loader = SceneSlotBundleLoader(mock_session)

        # Patch the two service lookups the loader performs: active bundle by
        # scene key, then the tenant's slot definitions.
        with patch.object(loader._bundle_service, 'get_active_bundle_by_scene', new_callable=AsyncMock) as mock_get_bundle:
            mock_get_bundle.return_value = sample_bundle

            with patch.object(loader._slot_service, 'list_slot_definitions', new_callable=AsyncMock) as mock_get_slots:
                mock_get_slots.return_value = sample_slot_definitions

                context = await loader.load_scene_context("test_tenant", "open_consult")

                # sample_bundle declares 2 required + 1 optional slot.
                assert context is not None
                assert context.scene_key == "open_consult"
                assert len(context.required_slots) == 2
                assert len(context.optional_slots) == 1

    @pytest.mark.asyncio
    async def test_load_scene_context_not_found(self, mock_session):
        """Test loading scene context when bundle not found."""
        loader = SceneSlotBundleLoader(mock_session)

        # No active bundle for the scene → loader returns None, not an error.
        with patch.object(loader._bundle_service, 'get_active_bundle_by_scene', new_callable=AsyncMock) as mock_get:
            mock_get.return_value = None

            context = await loader.load_scene_context("test_tenant", "unknown_scene")

            assert context is None

    @pytest.mark.asyncio
    async def test_get_missing_slots_for_scene(self, mock_session, sample_bundle, sample_slot_definitions):
        """Test getting missing slots for a scene."""
        loader = SceneSlotBundleLoader(mock_session)

        # Bypass context loading entirely and feed a hand-built context.
        with patch.object(loader, 'load_scene_context', new_callable=AsyncMock) as mock_load:
            mock_context = SceneSlotContext(
                scene_key="open_consult",
                scene_name="开放咨询",
                required_slots=[
                    SlotInfo(slot_key="course_type", type="string", required=True, ask_back_prompt="请问课程类型?"),
                    SlotInfo(slot_key="grade", type="string", required=True, ask_back_prompt="请问年级?"),
                ],
                optional_slots=[],
                slot_priority=["course_type", "grade"],
            )
            mock_load.return_value = mock_context

            # course_type is filled, so only grade should be reported missing.
            filled_slots = {"course_type": "数学"}
            missing = await loader.get_missing_slots_for_scene(
                "test_tenant", "open_consult", filled_slots
            )

            assert len(missing) == 1
            assert missing[0]["slot_key"] == "grade"

    @pytest.mark.asyncio
    async def test_generate_ask_back_prompt_single(self, mock_session):
        """Test generating ask-back prompt for single missing slot."""
        loader = SceneSlotBundleLoader(mock_session)

        with patch.object(loader, 'load_scene_context', new_callable=AsyncMock) as mock_load:
            mock_context = SceneSlotContext(
                scene_key="open_consult",
                scene_name="开放咨询",
                required_slots=[
                    SlotInfo(slot_key="course_type", type="string", required=True, ask_back_prompt="请问您想咨询哪种课程?"),
                ],
                optional_slots=[],
                ask_back_order="priority",
            )
            mock_load.return_value = mock_context

            with patch.object(loader, 'get_missing_slots_for_scene', new_callable=AsyncMock) as mock_missing:
                mock_missing.return_value = [
                    {"slot_key": "course_type", "ask_back_prompt": "请问您想咨询哪种课程?"}
                ]

                prompt = await loader.generate_ask_back_prompt(
                    "test_tenant", "open_consult", {}
                )

                # With a single missing slot, the slot's own prompt is used verbatim.
                assert prompt == "请问您想咨询哪种课程?"

    @pytest.mark.asyncio
    async def test_generate_ask_back_prompt_parallel(self, mock_session):
        """Test generating ask-back prompt with parallel strategy."""
        loader = SceneSlotBundleLoader(mock_session)

        with patch.object(loader, 'load_scene_context', new_callable=AsyncMock) as mock_load:
            mock_context = SceneSlotContext(
                scene_key="open_consult",
                scene_name="开放咨询",
                required_slots=[
                    SlotInfo(slot_key="course_type", type="string", required=True, ask_back_prompt="课程类型"),
                    SlotInfo(slot_key="grade", type="string", required=True, ask_back_prompt="年级"),
                ],
                optional_slots=[],
                ask_back_order="parallel",
            )
            mock_load.return_value = mock_context

            with patch.object(loader, 'get_missing_slots_for_scene', new_callable=AsyncMock) as mock_missing:
                mock_missing.return_value = [
                    {"slot_key": "course_type", "ask_back_prompt": "课程类型"},
                    {"slot_key": "grade", "ask_back_prompt": "年级"},
                ]

                prompt = await loader.generate_ask_back_prompt(
                    "test_tenant", "open_consult", {}
                )

                # "parallel" merges all missing-slot prompts into one question;
                # only containment is asserted, not the exact join format.
                assert "课程类型" in prompt
                assert "年级" in prompt
|
||||
|
|
@ -0,0 +1,284 @@
|
|||
"""
|
||||
Tests for Scene Slot Bundle Service.
|
||||
[AC-SCENE-SLOT-01] 场景-槽位映射配置服务测试
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from app.models.entities import (
|
||||
SceneSlotBundle,
|
||||
SceneSlotBundleCreate,
|
||||
SceneSlotBundleUpdate,
|
||||
SceneSlotBundleStatus,
|
||||
SlotDefinition,
|
||||
)
|
||||
from app.services.scene_slot_bundle_service import SceneSlotBundleService
|
||||
|
||||
|
||||
@pytest.fixture
def mock_session():
    """Provide a DB-session stand-in: async execute/flush/delete, sync add."""
    db = AsyncMock()
    db.execute = AsyncMock()
    db.flush = AsyncMock()
    db.delete = AsyncMock()
    # session.add is synchronous in SQLAlchemy's async API.
    db.add = MagicMock()
    return db
|
||||
|
||||
|
||||
@pytest.fixture
def sample_slot_definition():
    """Provide a single required string slot ("course_type") for validation tests."""
    definition = SlotDefinition(
        id=uuid4(),
        tenant_id="test_tenant",
        slot_key="course_type",
        type="string",
        required=True,
        ask_back_prompt="请问您想咨询哪种类型的课程?",
    )
    return definition
|
||||
|
||||
|
||||
@pytest.fixture
def sample_bundle():
    """Provide an ACTIVE open_consult bundle at version 1."""
    bundle = SceneSlotBundle(
        id=uuid4(),
        tenant_id="test_tenant",
        scene_key="open_consult",
        scene_name="开放咨询",
        description="开放咨询场景的槽位配置",
        required_slots=["course_type", "grade"],
        optional_slots=["region"],
        slot_priority=["course_type", "grade", "region"],
        completion_threshold=1.0,
        ask_back_order="priority",
        status=SceneSlotBundleStatus.ACTIVE.value,
        version=1,
    )
    return bundle
|
||||
|
||||
|
||||
class TestSceneSlotBundleService:
    """Test cases for SceneSlotBundleService.

    DB access is faked by assigning a MagicMock result object to
    mock_session.execute; service-internal helpers are patched where the
    test targets a single method's behavior.
    """

    @pytest.mark.asyncio
    async def test_list_bundles(self, mock_session, sample_bundle):
        """Test listing scene slot bundles."""
        # Fake the SQLAlchemy result chain: execute() → scalars() → all().
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [sample_bundle]
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)
        bundles = await service.list_bundles("test_tenant")

        assert len(bundles) == 1
        assert bundles[0].scene_key == "open_consult"

    @pytest.mark.asyncio
    async def test_list_bundles_with_status_filter(self, mock_session, sample_bundle):
        """Test listing scene slot bundles with status filter."""
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [sample_bundle]
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)
        # Only checks the call succeeds with a status kwarg; the WHERE clause
        # itself is not inspected here.
        bundles = await service.list_bundles("test_tenant", status="active")

        assert len(bundles) == 1

    @pytest.mark.asyncio
    async def test_get_bundle_by_id(self, mock_session, sample_bundle):
        """Test getting a bundle by ID."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)
        bundle = await service.get_bundle("test_tenant", str(sample_bundle.id))

        assert bundle is not None
        assert bundle.scene_key == "open_consult"

    @pytest.mark.asyncio
    async def test_get_bundle_by_scene_key(self, mock_session, sample_bundle):
        """Test getting a bundle by scene key."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)
        bundle = await service.get_bundle_by_scene_key("test_tenant", "open_consult")

        assert bundle is not None
        assert bundle.scene_key == "open_consult"

    @pytest.mark.asyncio
    async def test_get_active_bundle_by_scene(self, mock_session, sample_bundle):
        """Test getting an active bundle by scene key."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)
        bundle = await service.get_active_bundle_by_scene("test_tenant", "open_consult")

        assert bundle is not None
        # sample_bundle is built with ACTIVE status, so this reflects the fixture.
        assert bundle.status == SceneSlotBundleStatus.ACTIVE.value

    @pytest.mark.asyncio
    async def test_create_bundle_success(self, mock_session, sample_slot_definition):
        """Test creating a scene slot bundle successfully."""
        # Slot-definition lookup during validation returns the known slot.
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [sample_slot_definition]
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)

        with patch.object(service, '_validate_slot_keys', new_callable=AsyncMock) as mock_validate:
            mock_validate.return_value = {"course_type"}

            # No existing bundle under this scene key → creation may proceed.
            with patch.object(service, 'get_bundle_by_scene_key', new_callable=AsyncMock) as mock_get:
                mock_get.return_value = None

                bundle_create = SceneSlotBundleCreate(
                    scene_key="new_scene",
                    scene_name="新场景",
                    required_slots=["course_type"],
                    optional_slots=[],
                )

                bundle = await service.create_bundle("test_tenant", bundle_create)

                assert bundle is not None
                assert bundle.scene_key == "new_scene"
                # The new entity must have been staged on the session exactly once.
                mock_session.add.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_bundle_duplicate_scene_key(self, mock_session, sample_bundle):
        """Test creating a bundle with duplicate scene key."""
        service = SceneSlotBundleService(mock_session)

        # An existing bundle under the same scene key triggers the duplicate check.
        with patch.object(service, 'get_bundle_by_scene_key', new_callable=AsyncMock) as mock_get:
            mock_get.return_value = sample_bundle

            with patch.object(service, '_validate_bundle_data', new_callable=AsyncMock) as mock_validate:
                mock_validate.return_value = []

                bundle_create = SceneSlotBundleCreate(
                    scene_key="open_consult",
                    scene_name="开放咨询",
                )

                # Error message is Chinese ("已存在" = already exists).
                with pytest.raises(ValueError, match="已存在"):
                    await service.create_bundle("test_tenant", bundle_create)

    @pytest.mark.asyncio
    async def test_create_bundle_validation_error(self, mock_session):
        """Test creating a bundle with validation error."""
        service = SceneSlotBundleService(mock_session)

        with patch.object(service, '_validate_bundle_data', new_callable=AsyncMock) as mock_validate:
            # Validation reports the required/optional overlap ("交叉" = overlap).
            mock_validate.return_value = ["必填和可选槽位存在交叉"]

            bundle_create = SceneSlotBundleCreate(
                scene_key="new_scene",
                scene_name="新场景",
                required_slots=["course_type"],
                optional_slots=["course_type"],
            )

            with pytest.raises(ValueError, match="交叉"):
                await service.create_bundle("test_tenant", bundle_create)

    @pytest.mark.asyncio
    async def test_update_bundle_success(self, mock_session, sample_bundle):
        """Test updating a scene slot bundle successfully."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)

        with patch.object(service, '_validate_bundle_data', new_callable=AsyncMock) as mock_validate:
            mock_validate.return_value = []

            bundle_update = SceneSlotBundleUpdate(
                scene_name="更新后的场景名称",
            )

            bundle = await service.update_bundle("test_tenant", str(sample_bundle.id), bundle_update)

            assert bundle is not None
            assert bundle.scene_name == "更新后的场景名称"
            # An update bumps version: fixture starts at 1 → expect 2.
            assert bundle.version == 2

    @pytest.mark.asyncio
    async def test_delete_bundle_success(self, mock_session, sample_bundle):
        """Test deleting a scene slot bundle successfully."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)

        success = await service.delete_bundle("test_tenant", str(sample_bundle.id))

        assert success is True
        mock_session.delete.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_bundle_not_found(self, mock_session):
        """Test deleting a non-existent bundle."""
        # Lookup misses → delete reports failure instead of raising.
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)

        success = await service.delete_bundle("test_tenant", str(uuid4()))

        assert success is False
|
||||
|
||||
|
||||
class TestSceneSlotBundleValidation:
    """Validation-rule tests for SceneSlotBundleService._validate_bundle_data."""

    @staticmethod
    def _service_with_known_slots(mock_session, definitions):
        """Wire the session so the slot-definition query returns *definitions*."""
        query_result = MagicMock()
        query_result.scalars.return_value.all.return_value = definitions
        mock_session.execute.return_value = query_result
        return SceneSlotBundleService(mock_session)

    @pytest.mark.asyncio
    async def test_validate_required_optional_overlap(self, mock_session, sample_slot_definition):
        """A slot key listed as both required and optional is rejected."""
        service = self._service_with_known_slots(mock_session, [sample_slot_definition])

        issues = await service._validate_bundle_data(
            tenant_id="test_tenant",
            required_slots=["course_type"],
            optional_slots=["course_type"],
            slot_priority=None,
        )

        assert len(issues) > 0
        # "交叉" = overlap between required and optional slot sets.
        assert any("交叉" in issue for issue in issues)

    @pytest.mark.asyncio
    async def test_validate_priority_unknown_slots(self, mock_session, sample_slot_definition):
        """A priority entry that names an undefined slot key is rejected."""
        service = self._service_with_known_slots(mock_session, [sample_slot_definition])

        issues = await service._validate_bundle_data(
            tenant_id="test_tenant",
            required_slots=["course_type"],
            optional_slots=[],
            slot_priority=["course_type", "unknown_slot"],
        )

        assert len(issues) > 0
        # "未定义" = slot key not defined for this tenant.
        assert any("未定义" in issue for issue in issues)
|
||||
|
|
@ -0,0 +1,210 @@
|
|||
"""
|
||||
Unit tests for SemanticMatcher.
|
||||
[AC-AISVC-113, AC-AISVC-114] Tests for semantic matching.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import uuid
|
||||
|
||||
from app.services.intent.semantic_matcher import SemanticMatcher
|
||||
from app.services.intent.models import (
|
||||
FusionConfig,
|
||||
SemanticCandidate,
|
||||
SemanticMatchResult,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_embedding_provider():
    """Async embedding-provider stub that returns fixed 768-dim vectors."""
    stub = AsyncMock()
    stub.embed = AsyncMock(return_value=[0.1] * 768)
    stub.embed_batch = AsyncMock(return_value=[[0.1] * 768, [0.2] * 768])
    return stub
|
||||
|
||||
@pytest.fixture
def mock_rule():
    """Enabled intent-rule stub carrying a precomputed vector (Mode A), no examples."""
    stub = MagicMock()
    stub.id = uuid.uuid4()
    stub.name = "Test Intent"
    stub.intent_vector = [0.1] * 768
    stub.semantic_examples = None
    stub.is_enabled = True
    return stub
||||
|
||||
|
||||
@pytest.fixture
def mock_rule_with_examples():
    """Enabled intent-rule stub with semantic examples (Mode B), no vector."""
    stub = MagicMock()
    stub.id = uuid.uuid4()
    stub.name = "Test Intent with Examples"
    stub.intent_vector = None
    stub.semantic_examples = ["我想退货", "如何退款"]
    stub.is_enabled = True
    return stub
||||
|
||||
|
||||
@pytest.fixture
def config():
    """Default fusion configuration for the matcher under test."""
    return FusionConfig()
||||
|
||||
|
||||
class TestSemanticMatcher:
    """Behavioral tests for the SemanticMatcher class."""

    @pytest.mark.asyncio
    async def test_init(self, mock_embedding_provider, config):
        """The constructor should keep the provider and config it was given."""
        matcher = SemanticMatcher(mock_embedding_provider, config)
        assert matcher._embedding_provider == mock_embedding_provider
        assert matcher._config == config

    @pytest.mark.asyncio
    async def test_match_disabled(self, mock_embedding_provider):
        """A disabled matcher should skip with reason 'disabled' and no candidates."""
        cfg = FusionConfig(semantic_matcher_enabled=False)
        matcher = SemanticMatcher(mock_embedding_provider, cfg)

        outcome = await matcher.match("test message", [], "tenant-1")

        assert outcome.skipped is True
        assert outcome.skip_reason == "disabled"
        assert outcome.candidates == []

    @pytest.mark.asyncio
    async def test_match_no_semantic_config(
        self, mock_embedding_provider, config, mock_rule
    ):
        """Rules with neither vectors nor examples should cause a skip."""
        mock_rule.intent_vector = None
        mock_rule.semantic_examples = None

        matcher = SemanticMatcher(mock_embedding_provider, config)
        outcome = await matcher.match("test message", [mock_rule], "tenant-1")

        assert outcome.skipped is True
        assert outcome.skip_reason == "no_semantic_config"

    @pytest.mark.asyncio
    async def test_match_mode_a_with_intent_vector(
        self, mock_embedding_provider, config, mock_rule
    ):
        """Mode A: a precomputed intent vector should yield a near-perfect hit."""
        mock_embedding_provider.embed.return_value = [0.1] * 768

        matcher = SemanticMatcher(mock_embedding_provider, config)
        outcome = await matcher.match("我想退货", [mock_rule], "tenant-1")

        assert outcome.skipped is False
        assert outcome.skip_reason is None
        assert len(outcome.candidates) == 1
        assert outcome.top_score > 0.9
        assert outcome.duration_ms >= 0

    @pytest.mark.asyncio
    async def test_match_mode_b_with_examples(
        self, mock_embedding_provider, config, mock_rule_with_examples
    ):
        """Mode B: semantic examples should be embedded and matched."""
        mock_embedding_provider.embed.return_value = [0.1] * 768
        mock_embedding_provider.embed_batch.return_value = [[0.1] * 768, [0.1] * 768]

        matcher = SemanticMatcher(mock_embedding_provider, config)
        outcome = await matcher.match("我想退货", [mock_rule_with_examples], "tenant-1")

        assert outcome.skipped is False
        assert len(outcome.candidates) == 1
        assert outcome.top_score > 0.9

    @pytest.mark.asyncio
    async def test_match_embedding_timeout(self, mock_embedding_provider, config, mock_rule):
        """An embedding timeout should skip the match rather than raise."""
        import asyncio

        mock_embedding_provider.embed.side_effect = asyncio.TimeoutError()

        matcher = SemanticMatcher(
            mock_embedding_provider, FusionConfig(semantic_matcher_timeout_ms=100)
        )
        outcome = await matcher.match("test message", [mock_rule], "tenant-1")

        assert outcome.skipped is True
        assert "embedding_timeout" in outcome.skip_reason

    @pytest.mark.asyncio
    async def test_match_embedding_error(self, mock_embedding_provider, config, mock_rule):
        """Any embedding failure should skip the match rather than raise."""
        mock_embedding_provider.embed.side_effect = Exception("Embedding failed")

        matcher = SemanticMatcher(mock_embedding_provider, config)
        outcome = await matcher.match("test message", [mock_rule], "tenant-1")

        assert outcome.skipped is True
        assert "embedding_error" in outcome.skip_reason

    @pytest.mark.asyncio
    async def test_match_top_k_limit(self, mock_embedding_provider, config):
        """At most semantic_top_k candidates should be returned."""
        def build_rule(idx):
            # Each stub gets a slightly different vector so scores differ.
            stub = MagicMock()
            stub.id = uuid.uuid4()
            stub.name = f"Intent {idx}"
            stub.intent_vector = [0.1 + idx * 0.01] * 768
            stub.semantic_examples = None
            stub.is_enabled = True
            return stub

        rules = [build_rule(i) for i in range(5)]
        mock_embedding_provider.embed.return_value = [0.1] * 768

        matcher = SemanticMatcher(mock_embedding_provider, FusionConfig(semantic_top_k=3))
        outcome = await matcher.match("test message", rules, "tenant-1")

        assert len(outcome.candidates) <= 3

    def test_cosine_similarity(self, mock_embedding_provider, config):
        """Identical vectors score 1, orthogonal 0, oblique strictly between."""
        matcher = SemanticMatcher(mock_embedding_provider, config)

        assert matcher._cosine_similarity([1.0, 0.0, 0.0], [1.0, 0.0, 0.0]) == 1.0
        assert matcher._cosine_similarity([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]) == 0.0
        assert 0.0 < matcher._cosine_similarity([1.0, 1.0, 0.0], [1.0, 0.0, 0.0]) < 1.0

    def test_cosine_similarity_empty_vectors(self, mock_embedding_provider, config):
        """Empty input on either side should score 0 rather than divide by zero."""
        matcher = SemanticMatcher(mock_embedding_provider, config)

        assert matcher._cosine_similarity([], [1.0]) == 0.0
        assert matcher._cosine_similarity([1.0], []) == 0.0
        assert matcher._cosine_similarity([], []) == 0.0

    def test_has_semantic_config(self, mock_embedding_provider, config, mock_rule):
        """A rule counts as configured if it has a vector OR examples, else not."""
        matcher = SemanticMatcher(mock_embedding_provider, config)

        mock_rule.intent_vector = [0.1] * 768
        mock_rule.semantic_examples = None
        assert matcher._has_semantic_config(mock_rule) is True

        mock_rule.intent_vector = None
        mock_rule.semantic_examples = ["example"]
        assert matcher._has_semantic_config(mock_rule) is True

        mock_rule.intent_vector = None
        mock_rule.semantic_examples = None
        assert matcher._has_semantic_config(mock_rule) is False
|
|
@ -0,0 +1,419 @@
|
|||
"""
|
||||
Tests for Slot Backfill Service.
|
||||
[AC-MRS-SLOT-BACKFILL-01] 槽位值回填确认测试
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.mid.schemas import SlotSource
|
||||
from app.services.mid.slot_backfill_service import (
|
||||
BackfillResult,
|
||||
BackfillStatus,
|
||||
BatchBackfillResult,
|
||||
SlotBackfillService,
|
||||
create_slot_backfill_service,
|
||||
)
|
||||
from app.services.mid.slot_manager import SlotWriteResult
|
||||
from app.services.mid.slot_strategy_executor import (
|
||||
StrategyChainResult,
|
||||
StrategyStepResult,
|
||||
)
|
||||
|
||||
|
||||
class TestBackfillResult:
    """Tests for the BackfillResult value object."""

    def test_is_success(self):
        """Only SUCCESS status counts as success."""
        assert BackfillResult(status=BackfillStatus.SUCCESS).is_success() is True
        assert BackfillResult(status=BackfillStatus.VALIDATION_FAILED).is_success() is False

    def test_needs_ask_back(self):
        """Validation/extraction failures require an ask-back; success does not."""
        assert BackfillResult(status=BackfillStatus.VALIDATION_FAILED).needs_ask_back() is True
        assert BackfillResult(status=BackfillStatus.EXTRACTION_FAILED).needs_ask_back() is True
        assert BackfillResult(status=BackfillStatus.SUCCESS).needs_ask_back() is False

    def test_needs_confirmation(self):
        """Only NEEDS_CONFIRMATION status requires user confirmation."""
        assert BackfillResult(status=BackfillStatus.NEEDS_CONFIRMATION).needs_confirmation() is True
        assert BackfillResult(status=BackfillStatus.SUCCESS).needs_confirmation() is False

    def test_to_dict(self):
        """to_dict should expose status, slot key, value and source."""
        payload = BackfillResult(
            status=BackfillStatus.SUCCESS,
            slot_key="region",
            value="北京",
            normalized_value="北京",
            source="user_confirmed",
            confidence=1.0,
        ).to_dict()

        assert payload["status"] == "success"
        assert payload["slot_key"] == "region"
        assert payload["value"] == "北京"
        assert payload["source"] == "user_confirmed"
|
||||
class TestBatchBackfillResult:
    """Tests for batch backfill aggregation."""

    def test_add_result(self):
        """Counters should track success / failure / confirmation buckets."""
        batch = BatchBackfillResult()
        for outcome in (
            BackfillResult(status=BackfillStatus.SUCCESS, slot_key="region"),
            BackfillResult(status=BackfillStatus.VALIDATION_FAILED, slot_key="product"),
            BackfillResult(status=BackfillStatus.NEEDS_CONFIRMATION, slot_key="grade"),
        ):
            batch.add_result(outcome)

        assert batch.success_count == 1
        assert batch.failed_count == 1
        assert batch.confirmation_needed_count == 1

    def test_get_ask_back_prompts(self):
        """Only failed results contribute ask-back prompts."""
        batch = BatchBackfillResult()
        batch.add_result(BackfillResult(
            status=BackfillStatus.VALIDATION_FAILED,
            ask_back_prompt="请重新输入",
        ))
        batch.add_result(BackfillResult(status=BackfillStatus.SUCCESS))
        batch.add_result(BackfillResult(
            status=BackfillStatus.EXTRACTION_FAILED,
            ask_back_prompt="无法识别,请重试",
        ))

        prompts = batch.get_ask_back_prompts()
        assert len(prompts) == 2
        assert "请重新输入" in prompts
        assert "无法识别,请重试" in prompts

    def test_get_confirmation_prompts(self):
        """Only confirmation-needed results contribute confirmation prompts."""
        batch = BatchBackfillResult()
        batch.add_result(BackfillResult(
            status=BackfillStatus.NEEDS_CONFIRMATION,
            confirmation_prompt="我理解您说的是「北京」,对吗?",
        ))
        batch.add_result(BackfillResult(status=BackfillStatus.SUCCESS))

        prompts = batch.get_confirmation_prompts()
        assert len(prompts) == 1
        assert "北京" in prompts[0]
||||
|
||||
class TestSlotBackfillService:
    """Tests for SlotBackfillService behaviors."""

    @pytest.fixture
    def mock_session(self):
        """Async DB session stub."""
        return AsyncMock()

    @pytest.fixture
    def mock_slot_manager(self):
        """Slot-manager stub with async write and prompt hooks."""
        stub = MagicMock()
        stub.write_slot = AsyncMock()
        stub.get_ask_back_prompt = AsyncMock(return_value="请提供信息")
        return stub

    @pytest.fixture
    def service(self, mock_session, mock_slot_manager):
        """Service under test, wired to the stubs above."""
        return SlotBackfillService(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
            slot_manager=mock_slot_manager,
        )

    def test_confidence_thresholds(self, service):
        """Low/high confidence thresholds are fixed class constants."""
        assert service.CONFIDENCE_THRESHOLD_LOW == 0.5
        assert service.CONFIDENCE_THRESHOLD_HIGH == 0.8

    def test_get_source_for_strategy(self, service):
        """Each strategy maps to its slot source; unknown names pass through."""
        expectations = {
            "rule": SlotSource.RULE_EXTRACTED.value,
            "llm": SlotSource.LLM_INFERRED.value,
            "user_input": SlotSource.USER_CONFIRMED.value,
            "unknown": "unknown",
        }
        for strategy, source in expectations.items():
            assert service._get_source_for_strategy(strategy) == source

    def test_get_confidence_for_strategy(self, service):
        """Each source maps to a fixed confidence score."""
        expectations = [
            (SlotSource.USER_CONFIRMED.value, 1.0),
            (SlotSource.RULE_EXTRACTED.value, 0.9),
            (SlotSource.LLM_INFERRED.value, 0.7),
            ("context", 0.5),
            (SlotSource.DEFAULT.value, 0.3),
        ]
        for source, confidence in expectations:
            assert service._get_confidence_for_strategy(source) == confidence

    def test_generate_confirmation_prompt(self, service):
        """Confirmation prompt should echo the value and ask for agreement."""
        prompt = service._generate_confirmation_prompt("region", "北京")
        assert "北京" in prompt
        assert "对吗" in prompt

    @pytest.mark.asyncio
    async def test_backfill_single_slot_success(self, service, mock_slot_manager):
        """A valid high-confidence value should be written and reported as success."""
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=True,
            slot_key="region",
            value="北京",
        )

        with patch.object(service, '_get_state_aggregator') as aggregator_factory:
            aggregator = AsyncMock()
            aggregator.update_slot = AsyncMock()
            aggregator_factory.return_value = aggregator

            outcome = await service.backfill_single_slot(
                slot_key="region",
                candidate_value="北京",
                source="user_confirmed",
                confidence=1.0,
            )

        assert outcome.status == BackfillStatus.SUCCESS
        assert outcome.slot_key == "region"
        assert outcome.normalized_value == "北京"

    @pytest.mark.asyncio
    async def test_backfill_single_slot_validation_failed(self, service, mock_slot_manager):
        """A write rejected by validation should surface the ask-back prompt."""
        from app.services.mid.slot_validation_service import SlotValidationError

        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=False,
            slot_key="region",
            error=SlotValidationError(
                slot_key="region",
                error_code="INVALID_VALUE",
                error_message="无效的地区",
            ),
            ask_back_prompt="请提供有效的地区",
        )

        outcome = await service.backfill_single_slot(
            slot_key="region",
            candidate_value="无效地区",
            source="user_confirmed",
            confidence=1.0,
        )

        assert outcome.status == BackfillStatus.VALIDATION_FAILED
        assert outcome.ask_back_prompt == "请提供有效的地区"

    @pytest.mark.asyncio
    async def test_backfill_single_slot_low_confidence(self, service, mock_slot_manager):
        """A below-threshold confidence should trigger a confirmation round."""
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=True,
            slot_key="region",
            value="北京",
        )

        with patch.object(service, '_get_state_aggregator') as aggregator_factory:
            aggregator = AsyncMock()
            aggregator.update_slot = AsyncMock()
            aggregator_factory.return_value = aggregator

            outcome = await service.backfill_single_slot(
                slot_key="region",
                candidate_value="北京",
                source="llm_inferred",
                confidence=0.4,
            )

        assert outcome.status == BackfillStatus.NEEDS_CONFIRMATION
        assert outcome.confirmation_prompt is not None
        assert "北京" in outcome.confirmation_prompt

    @pytest.mark.asyncio
    async def test_backfill_multiple_slots(self, service, mock_slot_manager):
        """Batch backfill should tally per-slot successes and failures."""
        mock_slot_manager.write_slot.side_effect = [
            SlotWriteResult(success=True, slot_key="region", value="北京"),
            SlotWriteResult(success=True, slot_key="product", value="手机"),
            SlotWriteResult(success=False, slot_key="grade", error=MagicMock()),
        ]

        with patch.object(service, '_get_state_aggregator') as aggregator_factory:
            aggregator = AsyncMock()
            aggregator.update_slot = AsyncMock()
            aggregator_factory.return_value = aggregator

            outcome = await service.backfill_multiple_slots(
                candidates={
                    "region": "北京",
                    "product": "手机",
                    "grade": "无效等级",
                },
                source="user_confirmed",
            )

        assert outcome.success_count == 2
        assert outcome.failed_count == 1

    @pytest.mark.asyncio
    async def test_confirm_low_confidence_slot_confirmed(self, service):
        """User confirmation should promote the slot to full confidence."""
        with patch.object(service, '_get_state_aggregator') as aggregator_factory:
            aggregator = AsyncMock()
            aggregator.update_slot = AsyncMock()
            aggregator_factory.return_value = aggregator

            outcome = await service.confirm_low_confidence_slot(
                slot_key="region",
                confirmed=True,
            )

        assert outcome.status == BackfillStatus.SUCCESS
        assert outcome.source == SlotSource.USER_CONFIRMED.value
        assert outcome.confidence == 1.0

    @pytest.mark.asyncio
    async def test_confirm_low_confidence_slot_rejected(self, service, mock_slot_manager):
        """User rejection should clear the slot and ask again."""
        with patch.object(service, '_get_state_aggregator') as aggregator_factory:
            aggregator = AsyncMock()
            aggregator.clear_slot = AsyncMock()
            aggregator_factory.return_value = aggregator

            outcome = await service.confirm_low_confidence_slot(
                slot_key="region",
                confirmed=False,
            )

        assert outcome.status == BackfillStatus.VALIDATION_FAILED
        assert outcome.ask_back_prompt is not None
||||
|
||||
class TestCreateSlotBackfillService:
    """Tests for the create_slot_backfill_service factory function."""

    def test_create(self):
        """The factory should build a properly wired SlotBackfillService."""
        svc = create_slot_backfill_service(
            session=AsyncMock(),
            tenant_id="tenant_1",
            session_id="session_1",
        )

        assert isinstance(svc, SlotBackfillService)
        assert svc._tenant_id == "tenant_1"
        assert svc._session_id == "session_1"
||||
|
||||
class TestBackfillFromUserResponse:
    """Tests for backfilling slots directly from a user reply."""

    @pytest.fixture
    def service(self):
        """Service wired with a stubbed session and slot-definition lookup."""
        svc = SlotBackfillService(
            session=AsyncMock(),
            tenant_id="tenant_1",
            session_id="session_1",
        )
        svc._slot_def_service = AsyncMock()
        return svc

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_success(self, service):
        """A reply containing the slot value should extract and backfill it."""
        slot_def = MagicMock()
        slot_def.type = "string"
        slot_def.validation_rule = None
        slot_def.ask_back_prompt = "请提供地区"
        service._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=slot_def
        )

        with patch.object(service, '_extract_value') as extract_stub, \
                patch.object(service, 'backfill_single_slot') as backfill_stub:
            extract_stub.return_value = StrategyChainResult(
                slot_key="region",
                success=True,
                final_value="北京",
                final_strategy="rule",
            )
            backfill_stub.return_value = BackfillResult(
                status=BackfillStatus.SUCCESS,
                slot_key="region",
                value="北京",
            )

            outcome = await service.backfill_from_user_response(
                user_response="我想查询北京的产品",
                expected_slots=["region"],
            )

        assert outcome.success_count == 1

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_no_definition(self, service):
        """Unknown slots are skipped entirely — neither success nor failure."""
        service._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=None
        )

        outcome = await service.backfill_from_user_response(
            user_response="我想查询北京的产品",
            expected_slots=["unknown_slot"],
        )

        assert outcome.success_count == 0
        assert outcome.failed_count == 0

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_extraction_failed(self, service):
        """A reply with no extractable value should record an extraction failure."""
        slot_def = MagicMock()
        slot_def.type = "string"
        slot_def.validation_rule = None
        slot_def.ask_back_prompt = "请提供地区"
        service._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=slot_def
        )

        with patch.object(service, '_extract_value') as extract_stub:
            extract_stub.return_value = StrategyChainResult(
                slot_key="region",
                success=False,
            )

            outcome = await service.backfill_from_user_response(
                user_response="我想查询产品",
                expected_slots=["region"],
            )

        assert outcome.failed_count == 1
        assert outcome.results[0].status == BackfillStatus.EXTRACTION_FAILED
|
|
@ -0,0 +1,335 @@
|
|||
"""
|
||||
Tests for Slot Extraction Integration.
|
||||
[AC-MRS-SLOT-EXTRACT-01] slot extraction 集成测试
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.mid.schemas import SlotSource
|
||||
from app.services.mid.slot_extraction_integration import (
|
||||
ExtractionResult,
|
||||
ExtractionTrace,
|
||||
SlotExtractionIntegration,
|
||||
integrate_slot_extraction,
|
||||
)
|
||||
from app.services.mid.slot_strategy_executor import (
|
||||
StrategyChainResult,
|
||||
StrategyStepResult,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractionTrace:
    """Tests for the per-slot ExtractionTrace record."""

    def test_init(self):
        """A fresh trace knows its slot but nothing else yet."""
        trace = ExtractionTrace(slot_key="region")
        assert trace.slot_key == "region"
        assert trace.strategy is None
        assert trace.validation_passed is False

    def test_to_dict(self):
        """to_dict should expose strategy, value and validation outcome."""
        payload = ExtractionTrace(
            slot_key="region",
            strategy="rule",
            extracted_value="北京",
            validation_passed=True,
            final_value="北京",
            execution_time_ms=10.5,
        ).to_dict()

        assert payload["slot_key"] == "region"
        assert payload["strategy"] == "rule"
        assert payload["extracted_value"] == "北京"
        assert payload["validation_passed"] is True
||||
|
||||
class TestExtractionResult:
    """Tests for the aggregate ExtractionResult."""

    def test_init(self):
        """Defaults: not successful, no extracted slots, no failures."""
        outcome = ExtractionResult()
        assert outcome.success is False
        assert outcome.extracted_slots == {}
        assert outcome.failed_slots == []

    def test_to_dict(self):
        """to_dict should carry slots, failures and ask-back state."""
        payload = ExtractionResult(
            success=True,
            extracted_slots={"region": "北京"},
            failed_slots=["product"],
            traces=[ExtractionTrace(slot_key="region")],
            total_execution_time_ms=50.0,
            ask_back_triggered=True,
            ask_back_prompts=["请提供产品信息"],
        ).to_dict()

        assert payload["success"] is True
        assert payload["extracted_slots"] == {"region": "北京"}
        assert payload["failed_slots"] == ["product"]
        assert payload["ask_back_triggered"] is True
||||
|
||||
class TestSlotExtractionIntegration:
    """Tests for the SlotExtractionIntegration orchestration layer."""

    @pytest.fixture
    def mock_session(self):
        """Async DB session stub."""
        return AsyncMock()

    @pytest.fixture
    def integration(self, mock_session):
        """Integration instance under test."""
        return SlotExtractionIntegration(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
        )

    def test_default_strategies(self, integration):
        """Rule extraction runs before LLM extraction by default."""
        assert integration.DEFAULT_STRATEGIES == ["rule", "llm"]

    def test_get_source_for_strategy(self, integration):
        """Each strategy name maps to its slot source."""
        mapping = {
            "rule": SlotSource.RULE_EXTRACTED.value,
            "llm": SlotSource.LLM_INFERRED.value,
            "user_input": SlotSource.USER_CONFIRMED.value,
        }
        for strategy, source in mapping.items():
            assert integration._get_source_for_strategy(strategy) == source

    def test_get_confidence_for_source(self, integration):
        """Each source maps to a fixed confidence score."""
        for source, confidence in [
            (SlotSource.USER_CONFIRMED.value, 1.0),
            (SlotSource.RULE_EXTRACTED.value, 0.9),
            (SlotSource.LLM_INFERRED.value, 0.7),
        ]:
            assert integration._get_confidence_for_source(source) == confidence

    @pytest.mark.asyncio
    async def test_extract_and_fill_no_target_slots(self, integration):
        """No target slots means trivial success with nothing extracted."""
        outcome = await integration.extract_and_fill(
            user_input="测试输入",
            target_slots=[],
        )

        assert outcome.success is True
        assert outcome.extracted_slots == {}

    @pytest.mark.asyncio
    async def test_extract_and_fill_slot_not_found(self, integration):
        """A missing slot definition should fail the slot with a trace reason."""
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=None
        )
        integration._slot_manager.get_ask_back_prompt = AsyncMock(return_value=None)

        outcome = await integration.extract_and_fill(
            user_input="测试输入",
            target_slots=["unknown_slot"],
        )

        assert outcome.success is False
        assert "unknown_slot" in outcome.failed_slots
        assert outcome.traces[0].failure_reason == "Slot definition not found"

    @pytest.mark.asyncio
    async def test_extract_and_fill_extraction_success(self, integration):
        """A successful strategy chain should land the value in extracted_slots."""
        slot_def = MagicMock()
        slot_def.type = "string"
        slot_def.validation_rule = None
        slot_def.ask_back_prompt = "请提供地区"
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=slot_def
        )

        with patch.object(integration._strategy_executor, 'execute_chain') as chain_stub, \
                patch.object(integration, '_get_backfill_service') as backfill_factory, \
                patch.object(integration, '_save_extracted_slots', new_callable=AsyncMock):
            chain_stub.return_value = StrategyChainResult(
                slot_key="region",
                success=True,
                final_value="北京",
                final_strategy="rule",
            )
            backfill = AsyncMock()
            backfill.backfill_single_slot = AsyncMock(
                return_value=MagicMock(
                    is_success=lambda: True,
                    normalized_value="北京",
                    error_message=None,
                )
            )
            backfill_factory.return_value = backfill

            outcome = await integration.extract_and_fill(
                user_input="我想查询北京的产品",
                target_slots=["region"],
            )

        assert outcome.success is True
        assert "region" in outcome.extracted_slots
        assert outcome.extracted_slots["region"] == "北京"

    @pytest.mark.asyncio
    async def test_extract_and_fill_extraction_failed(self, integration):
        """An exhausted strategy chain should fail the slot and trigger ask-back."""
        slot_def = MagicMock()
        slot_def.type = "string"
        slot_def.validation_rule = None
        slot_def.ask_back_prompt = "请提供地区"
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=slot_def
        )

        with patch.object(integration._strategy_executor, 'execute_chain') as chain_stub:
            chain_stub.return_value = StrategyChainResult(
                slot_key="region",
                success=False,
                steps=[
                    StrategyStepResult(
                        strategy="rule",
                        success=False,
                        failure_reason="无法提取",
                    )
                ],
            )
            integration._slot_manager.get_ask_back_prompt = AsyncMock(
                return_value="请提供地区"
            )

            outcome = await integration.extract_and_fill(
                user_input="测试输入",
                target_slots=["region"],
            )

        assert outcome.success is False
        assert "region" in outcome.failed_slots
        assert outcome.ask_back_triggered is True
        assert "请提供地区" in outcome.ask_back_prompts

    @pytest.mark.asyncio
    async def test_get_missing_required_slots_from_state(self, integration):
        """When a slot state is available, missing slots come straight from it."""
        from app.services.mid.slot_state_aggregator import SlotState

        state = SlotState()
        state.missing_required_slots = [
            {"slot_key": "region"},
            {"slot_key": "product"},
        ]

        missing = await integration._get_missing_required_slots(state)

        assert "region" in missing
        assert "product" in missing

    @pytest.mark.asyncio
    async def test_get_missing_required_slots_from_db(self, integration):
        """Without a state, missing slots fall back to the DB definitions."""
        integration._slot_def_service.list_slot_definitions = AsyncMock(
            return_value=[
                MagicMock(slot_key="region"),
                MagicMock(slot_key="product"),
            ]
        )

        missing = await integration._get_missing_required_slots(None)

        assert "region" in missing
        assert "product" in missing
||||
|
||||
class TestIntegrateSlotExtraction:
    """Tests for the integrate_slot_extraction convenience wrapper."""

    @pytest.mark.asyncio
    async def test_integrate(self):
        """The wrapper should construct the integration and delegate to it."""
        with patch('app.services.mid.slot_extraction_integration.SlotExtractionIntegration') as cls_stub:
            instance = AsyncMock()
            instance.extract_and_fill = AsyncMock(
                return_value=ExtractionResult(success=True, extracted_slots={"region": "北京"})
            )
            cls_stub.return_value = instance

            outcome = await integrate_slot_extraction(
                session=AsyncMock(),
                tenant_id="tenant_1",
                session_id="session_1",
                user_input="我想查询北京的产品",
            )

        assert outcome.success is True
        assert "region" in outcome.extracted_slots
||||
|
||||
class TestExtractionTraceFlow:
    """Extraction trace flow tests: one successful end-to-end run should
    produce exactly one trace entry recording strategy, value, and validation."""

    @pytest.fixture
    def integration(self):
        """Create an integration instance backed by a mocked DB session."""
        mock_session = AsyncMock()
        return SlotExtractionIntegration(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
        )

    @pytest.mark.asyncio
    async def test_full_extraction_flow(self, integration):
        """Full extraction flow: rule strategy succeeds, backfill normalizes,
        and the resulting trace reflects every stage."""
        # Slot definition stub: plain string slot, no validation, no ask-back.
        mock_slot_def = MagicMock()
        mock_slot_def.type = "string"
        mock_slot_def.validation_rule = None
        mock_slot_def.ask_back_prompt = None

        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=mock_slot_def
        )

        # Strategy chain succeeds on the "rule" strategy with value "北京".
        with patch.object(integration._strategy_executor, 'execute_chain') as mock_chain:
            mock_chain.return_value = StrategyChainResult(
                slot_key="region",
                success=True,
                final_value="北京",
                final_strategy="rule",
            )

            # Backfill service stub reports success and echoes the value back.
            with patch.object(integration, '_get_backfill_service') as mock_backfill_svc:
                mock_backfill = AsyncMock()
                mock_backfill.backfill_single_slot = AsyncMock(
                    return_value=MagicMock(
                        is_success=lambda: True,
                        normalized_value="北京",
                        error_message=None,
                    )
                )
                mock_backfill_svc.return_value = mock_backfill

                # Persisting the slots is stubbed out entirely.
                with patch.object(integration, '_save_extracted_slots', new_callable=AsyncMock):
                    result = await integration.extract_and_fill(
                        user_input="北京",
                        target_slots=["region"],
                    )

                    # Exactly one trace, capturing strategy, raw value,
                    # validation outcome, and the final normalized value.
                    assert len(result.traces) == 1
                    trace = result.traces[0]
                    assert trace.slot_key == "region"
                    assert trace.strategy == "rule"
                    assert trace.extracted_value == "北京"
                    assert trace.validation_passed is True
                    assert trace.final_value == "北京"
|
||||
|
|
@ -0,0 +1,256 @@
|
|||
"""
|
||||
Tests for Slot State Aggregator.
|
||||
[AC-MRS-SLOT-META-01] 槽位状态聚合服务测试
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from app.models.mid.schemas import MemorySlot, SlotSource
|
||||
from app.services.mid.slot_state_aggregator import (
|
||||
SlotState,
|
||||
SlotStateAggregator,
|
||||
create_slot_state_aggregator,
|
||||
)
|
||||
|
||||
|
||||
class TestSlotState:
    """Unit tests for the SlotState data class."""

    def test_slot_state_initialization(self):
        """A fresh SlotState starts with every collection empty."""
        fresh = SlotState()
        assert fresh.filled_slots == {}
        assert fresh.missing_required_slots == []
        assert fresh.slot_sources == {}
        assert fresh.slot_confidence == {}
        assert fresh.slot_to_field_map == {}

    def test_get_value_for_filter_direct_match(self):
        """A filter key present in filled_slots resolves directly."""
        populated = SlotState(
            filled_slots={"product_line": "vip_course"},
            slot_to_field_map={},
        )
        assert populated.get_value_for_filter("product_line") == "vip_course"

    def test_get_value_for_filter_via_mapping(self):
        """A filter key resolves through the slot-to-field mapping."""
        mapped = SlotState(
            filled_slots={"product": "vip_course"},
            slot_to_field_map={"product": "product_line"},
        )
        assert mapped.get_value_for_filter("product_line") == "vip_course"

    def test_get_value_for_filter_not_found(self):
        """An unknown filter key yields None."""
        empty = SlotState(filled_slots={})
        assert empty.get_value_for_filter("non_existent") is None

    def test_to_debug_info(self):
        """Debug info mirrors the filled and missing slot collections."""
        snapshot = SlotState(
            filled_slots={"key": "value"},
            missing_required_slots=[{"slot_key": "missing"}],
        )
        info = snapshot.to_debug_info()
        assert info["filled_slots"] == {"key": "value"}
        assert len(info["missing_required_slots"]) == 1
|
||||
|
||||
|
||||
class TestSlotStateAggregator:
    """Tests for SlotStateAggregator: merging memory, current input, and
    context into a single slot state, plus ask-back response generation."""

    @pytest.fixture
    def mock_session(self):
        """Mocked database session."""
        return AsyncMock()

    @pytest.fixture
    def aggregator(self, mock_session):
        """Create an aggregator instance bound to a test tenant."""
        return SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

    @pytest.mark.asyncio
    async def test_aggregate_from_memory_slots(self, aggregator, mock_session):
        """State is initialized from memory_slots (value, source, confidence)."""
        memory_slots = {
            "product_line": MemorySlot(
                key="product_line",
                value="vip_course",
                source=SlotSource.USER_CONFIRMED,
                confidence=1.0,
            )
        }

        # Slot-definition service returns an empty list (no required slots).
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[],
        ):
            state = await aggregator.aggregate(
                memory_slots=memory_slots,
                current_input_slots=None,
                context=None,
            )

            assert state.filled_slots["product_line"] == "vip_course"
            assert state.slot_sources["product_line"] == "user_confirmed"
            assert state.slot_confidence["product_line"] == 1.0

    @pytest.mark.asyncio
    async def test_aggregate_current_input_priority(self, aggregator, mock_session):
        """Current-turn input takes priority over a memory value."""
        memory_slots = {
            "product_line": MemorySlot(
                key="product_line",
                value="old_value",
                source=SlotSource.USER_CONFIRMED,
                confidence=1.0,
            )
        }
        current_input = {"product_line": "new_value"}

        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[],
        ):
            state = await aggregator.aggregate(
                memory_slots=memory_slots,
                current_input_slots=current_input,
                context=None,
            )

            # Current input should override the value from memory.
            assert state.filled_slots["product_line"] == "new_value"
            assert state.slot_sources["product_line"] == "user_confirmed"

    @pytest.mark.asyncio
    async def test_aggregate_extract_from_context(self, aggregator, mock_session):
        """Slot values can be extracted from the context dict."""
        context = {
            "scene": "open_consult",
            "product_line": "vip_course",
            "other_key": "other_value",
        }

        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[],
        ):
            state = await aggregator.aggregate(
                memory_slots=None,
                current_input_slots=None,
                context=context,
            )

            # scene and product_line should be extracted from the context.
            assert state.filled_slots.get("scene") == "open_consult"
            assert state.filled_slots.get("product_line") == "vip_course"
            assert state.slot_sources.get("scene") == "context"

    @pytest.mark.asyncio
    async def test_generate_ask_back_response_with_prompt(self, aggregator):
        """Ask-back response uses the configured ask_back_prompt verbatim."""
        state = SlotState(
            missing_required_slots=[
                {
                    "slot_key": "region",
                    "label": "地区",
                    "reason": "required_slot_missing",
                    "ask_back_prompt": "请问您在哪个地区?",
                }
            ]
        )

        response = await aggregator.generate_ask_back_response(state)
        assert response == "请问您在哪个地区?"

    @pytest.mark.asyncio
    async def test_generate_ask_back_response_generic(self, aggregator):
        """Without an ask_back_prompt, a generic template mentioning the
        slot label is used."""
        state = SlotState(
            missing_required_slots=[
                {
                    "slot_key": "region",
                    "label": "地区",
                    "reason": "required_slot_missing",
                    # No ask_back_prompt configured.
                }
            ]
        )

        response = await aggregator.generate_ask_back_response(state)
        assert "地区" in response

    @pytest.mark.asyncio
    async def test_generate_ask_back_response_no_missing(self, aggregator):
        """No missing required slots -> no ask-back response (None)."""
        state = SlotState(missing_required_slots=[])
        response = await aggregator.generate_ask_back_response(state)
        assert response is None
|
||||
|
||||
|
||||
class TestCreateSlotStateAggregator:
    """Tests for the create_slot_state_aggregator factory."""

    def test_create_aggregator(self):
        """The factory builds an aggregator bound to the given tenant."""
        fake_session = MagicMock()
        built = create_slot_state_aggregator(
            session=fake_session,
            tenant_id="test_tenant",
        )
        assert isinstance(built, SlotStateAggregator)
        assert built._tenant_id == "test_tenant"
|
||||
|
||||
|
||||
class TestSlotStateFilterPriority:
    """Tests for the precedence of filter-value sources."""

    @pytest.mark.asyncio
    async def test_filter_priority_slot_first(self):
        """Precedence for the same slot key: current input > memory > context.

        The same key "product" is supplied through all three channels; the
        aggregated state must keep the current-turn input value.
        """
        # NOTE: AsyncMock/MagicMock come from the module-level import; the
        # previous in-method re-import was redundant and has been removed.
        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        # Fake slot definition: optional, not linked to a KB field.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "product"
        mock_slot_def.linked_field_id = None
        mock_slot_def.required = False

        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ):
            state = await aggregator.aggregate(
                memory_slots={
                    "product": MemorySlot(
                        key="product",
                        value="from_memory",
                        source=SlotSource.USER_CONFIRMED,
                        confidence=1.0,
                    )
                },
                current_input_slots={"product": "from_input"},
                context={"product": "from_context"},
            )

            # The current-turn input must win over both memory and context.
            assert state.filled_slots["product"] == "from_input"
|
||||
|
|
@ -0,0 +1,399 @@
|
|||
"""
|
||||
Tests for Slot State Cache.
|
||||
[AC-MRS-SLOT-CACHE-01] 多轮状态持久化测试
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.cache.slot_state_cache import (
|
||||
CachedSlotState,
|
||||
CachedSlotValue,
|
||||
SlotStateCache,
|
||||
get_slot_state_cache,
|
||||
)
|
||||
|
||||
|
||||
class TestCachedSlotValue:
    """Unit tests for CachedSlotValue."""

    def test_init(self):
        """Constructor stores value/source/confidence and stamps updated_at."""
        cached = CachedSlotValue(
            value="test_value",
            source="user_confirmed",
            confidence=0.9,
        )
        assert cached.value == "test_value"
        assert cached.source == "user_confirmed"
        assert cached.confidence == 0.9
        assert cached.updated_at > 0

    def test_to_dict(self):
        """to_dict exposes every field, including the timestamp."""
        cached = CachedSlotValue(
            value="test_value",
            source="rule_extracted",
            confidence=0.8,
        )
        payload = cached.to_dict()
        assert payload["value"] == "test_value"
        assert payload["source"] == "rule_extracted"
        assert payload["confidence"] == 0.8
        assert "updated_at" in payload

    def test_from_dict(self):
        """from_dict restores every field verbatim."""
        payload = {
            "value": "test_value",
            "source": "llm_inferred",
            "confidence": 0.7,
            "updated_at": 12345.0,
        }
        restored = CachedSlotValue.from_dict(payload)
        assert restored.value == "test_value"
        assert restored.source == "llm_inferred"
        assert restored.confidence == 0.7
        assert restored.updated_at == 12345.0
|
||||
|
||||
|
||||
class TestCachedSlotState:
    """Unit tests for CachedSlotState."""

    def test_init(self):
        """A fresh state is empty and carries creation timestamps."""
        fresh = CachedSlotState()
        assert fresh.filled_slots == {}
        assert fresh.slot_to_field_map == {}
        assert fresh.created_at > 0
        assert fresh.updated_at > 0

    def test_with_slots(self):
        """Constructor accepts pre-filled slots plus a field mapping."""
        prefilled = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="北京", source="user_confirmed"),
                "product": CachedSlotValue(value="手机", source="rule_extracted"),
            },
            slot_to_field_map={"region": "region_field"},
        )
        assert len(prefilled.filled_slots) == 2
        assert prefilled.slot_to_field_map["region"] == "region_field"

    def test_to_dict_and_from_dict(self):
        """A state survives a to_dict / from_dict round trip."""
        original = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="北京", source="user_confirmed"),
            },
            slot_to_field_map={"region": "region_field"},
        )

        restored = CachedSlotState.from_dict(original.to_dict())

        assert len(restored.filled_slots) == 1
        assert restored.filled_slots["region"].value == "北京"
        assert restored.filled_slots["region"].source == "user_confirmed"
        assert restored.slot_to_field_map["region"] == "region_field"

    def test_get_simple_filled_slots(self):
        """Simplified view maps each slot key straight to its raw value."""
        packed = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="北京", source="user_confirmed"),
                "product": CachedSlotValue(value="手机", source="rule_extracted"),
            }
        )
        assert packed.get_simple_filled_slots() == {"region": "北京", "product": "手机"}

    def test_get_slot_sources(self):
        """Source view maps each slot key to its provenance tag."""
        packed = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="北京", source="user_confirmed"),
                "product": CachedSlotValue(value="手机", source="rule_extracted"),
            }
        )
        assert packed.get_slot_sources() == {
            "region": "user_confirmed",
            "product": "rule_extracted",
        }

    def test_get_slot_confidence(self):
        """Confidence view maps each slot key to its confidence score."""
        packed = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="北京", source="user_confirmed", confidence=1.0),
                "product": CachedSlotValue(value="手机", source="rule_extracted", confidence=0.8),
            }
        )
        assert packed.get_slot_confidence() == {"region": 1.0, "product": 0.8}
|
||||
|
||||
|
||||
class TestSlotStateCache:
    """Tests for SlotStateCache with Redis disabled: key building, source
    priorities, the in-process (L1) cache, and merge semantics."""

    def test_source_priority(self):
        """Source priority ordering: user_confirmed > rule_extracted >
        llm_inferred > context > default; unrecognized sources rank 0."""
        cache = SlotStateCache()
        assert cache._get_source_priority("user_confirmed") == 100
        assert cache._get_source_priority("rule_extracted") == 80
        assert cache._get_source_priority("llm_inferred") == 60
        assert cache._get_source_priority("context") == 40
        assert cache._get_source_priority("default") == 20
        assert cache._get_source_priority("unknown") == 0

    def test_make_key(self):
        """Cache key format: slot_state:<tenant>:<session>."""
        cache = SlotStateCache()
        key = cache._make_key("tenant_123", "session_456")
        assert key == "slot_state:tenant_123:session_456"

    @pytest.mark.asyncio
    async def test_l1_cache_hit(self):
        """L1 (local) cache hit returns the cached state."""
        cache = SlotStateCache()
        tenant_id = "tenant_1"
        session_id = "session_1"

        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )

        # Seed the L1 cache directly; entries are (state, inserted_at) tuples.
        cache._local_cache[f"{tenant_id}:{session_id}"] = (state, time.time())

        result = await cache.get(tenant_id, session_id)
        assert result is not None
        assert result.filled_slots["region"].value == "北京"

    @pytest.mark.asyncio
    async def test_l1_cache_expired(self):
        """An L1 entry older than the local TTL is evicted and missed."""
        cache = SlotStateCache()
        tenant_id = "tenant_1"
        session_id = "session_1"

        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )

        # 400s in the past — beyond the 300s local TTL (see TestCacheTTL).
        old_time = time.time() - 400
        cache._local_cache[f"{tenant_id}:{session_id}"] = (state, old_time)

        result = await cache.get(tenant_id, session_id)
        assert result is None
        # The stale entry must also be purged from the local cache.
        assert f"{tenant_id}:{session_id}" not in cache._local_cache

    @pytest.mark.asyncio
    async def test_set_and_get_l1(self):
        """set/get round-trip through the L1 cache with Redis disabled."""
        cache = SlotStateCache(redis_client=None)
        cache._enabled = False  # force L1-only behavior

        tenant_id = "tenant_1"
        session_id = "session_1"

        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )

        await cache.set(tenant_id, session_id, state)

        local_key = f"{tenant_id}:{session_id}"
        assert local_key in cache._local_cache

        result = await cache.get(tenant_id, session_id)
        assert result is not None
        assert result.filled_slots["region"].value == "北京"

    @pytest.mark.asyncio
    async def test_delete(self):
        """delete removes the cached state entirely."""
        cache = SlotStateCache(redis_client=None)
        cache._enabled = False

        tenant_id = "tenant_1"
        session_id = "session_1"

        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )

        await cache.set(tenant_id, session_id, state)
        await cache.delete(tenant_id, session_id)

        result = await cache.get(tenant_id, session_id)
        assert result is None

    @pytest.mark.asyncio
    async def test_clear_slot(self):
        """clear_slot removes one slot and leaves the others intact."""
        cache = SlotStateCache(redis_client=None)
        cache._enabled = False

        tenant_id = "tenant_1"
        session_id = "session_1"

        state = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="北京", source="user_confirmed"),
                "product": CachedSlotValue(value="手机", source="rule_extracted"),
            },
        )

        await cache.set(tenant_id, session_id, state)
        await cache.clear_slot(tenant_id, session_id, "region")

        result = await cache.get(tenant_id, session_id)
        assert result is not None
        assert "region" not in result.filled_slots
        assert "product" in result.filled_slots

    @pytest.mark.asyncio
    async def test_merge_and_set_priority(self):
        """A higher-priority incoming value replaces the stored one."""
        cache = SlotStateCache(redis_client=None)
        cache._enabled = False

        tenant_id = "tenant_1"
        session_id = "session_1"

        existing_state = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="上海", source="llm_inferred", confidence=0.6),
            },
        )
        await cache.set(tenant_id, session_id, existing_state)

        # user_confirmed (100) outranks llm_inferred (60), so it wins.
        new_slots = {
            "region": CachedSlotValue(value="北京", source="user_confirmed", confidence=1.0),
        }

        result = await cache.merge_and_set(tenant_id, session_id, new_slots)

        assert result.filled_slots["region"].value == "北京"
        assert result.filled_slots["region"].source == "user_confirmed"

    @pytest.mark.asyncio
    async def test_merge_and_set_lower_priority_ignored(self):
        """A lower-priority incoming value is ignored on merge."""
        cache = SlotStateCache(redis_client=None)
        cache._enabled = False

        tenant_id = "tenant_1"
        session_id = "session_1"

        existing_state = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="北京", source="user_confirmed", confidence=1.0),
            },
        )
        await cache.set(tenant_id, session_id, existing_state)

        # llm_inferred (60) cannot displace user_confirmed (100).
        new_slots = {
            "region": CachedSlotValue(value="上海", source="llm_inferred", confidence=0.6),
        }

        result = await cache.merge_and_set(tenant_id, session_id, new_slots)

        assert result.filled_slots["region"].value == "北京"
        assert result.filled_slots["region"].source == "user_confirmed"
|
||||
|
||||
|
||||
class TestGetSlotStateCache:
    """Singleton behavior of get_slot_state_cache."""

    def test_singleton(self):
        """Repeated calls hand back the exact same instance."""
        first = get_slot_state_cache()
        second = get_slot_state_cache()
        assert first is second
|
||||
|
||||
|
||||
class TestSlotStateCacheWithRedis:
    """SlotStateCache integration tests against a mocked Redis client."""

    @pytest.mark.asyncio
    async def test_redis_set_and_get(self):
        """set writes to Redis via setex using the composed cache key."""
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value=None)
        mock_redis.setex = AsyncMock(return_value=True)

        cache = SlotStateCache(redis_client=mock_redis)

        tenant_id = "tenant_1"
        session_id = "session_1"

        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )

        await cache.set(tenant_id, session_id, state)

        # First positional argument of setex must be the composed key.
        mock_redis.setex.assert_called_once()
        call_args = mock_redis.setex.call_args
        assert call_args[0][0] == f"slot_state:{tenant_id}:{session_id}"

    @pytest.mark.asyncio
    async def test_redis_get_hit(self):
        """get deserializes a JSON payload returned by Redis."""
        # A serialized CachedSlotState as Redis would hand it back.
        state_dict = {
            "filled_slots": {
                "region": {
                    "value": "北京",
                    "source": "user_confirmed",
                    "confidence": 1.0,
                    "updated_at": 12345.0,
                }
            },
            "slot_to_field_map": {"region": "region_field"},
            "created_at": 12340.0,
            "updated_at": 12345.0,
        }

        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value=json.dumps(state_dict))

        cache = SlotStateCache(redis_client=mock_redis)

        tenant_id = "tenant_1"
        session_id = "session_1"

        result = await cache.get(tenant_id, session_id)

        assert result is not None
        assert result.filled_slots["region"].value == "北京"
        assert result.filled_slots["region"].source == "user_confirmed"

    @pytest.mark.asyncio
    async def test_redis_delete(self):
        """delete issues a Redis DELETE with the composed cache key."""
        mock_redis = AsyncMock()
        mock_redis.delete = AsyncMock(return_value=1)

        cache = SlotStateCache(redis_client=mock_redis)

        tenant_id = "tenant_1"
        session_id = "session_1"

        await cache.delete(tenant_id, session_id)

        mock_redis.delete.assert_called_once_with(f"slot_state:{tenant_id}:{session_id}")
|
||||
|
||||
|
||||
class TestCacheTTL:
    """TTL configuration defaults."""

    def test_default_ttl(self):
        """Redis-level TTL defaults to 1800 seconds."""
        assert SlotStateCache()._cache_ttl == 1800

    def test_local_cache_ttl(self):
        """In-process (L1) TTL defaults to 300 seconds."""
        assert SlotStateCache()._local_cache_ttl == 300
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
"""
|
||||
Tests for Slot Strategy Executor.
|
||||
[AC-MRS-07-UPGRADE] 提取策略链执行器测试
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from app.models.entities import ExtractFailureType
|
||||
from app.services.mid.slot_strategy_executor import (
|
||||
SlotStrategyExecutor,
|
||||
ExtractContext,
|
||||
StrategyChainResult,
|
||||
execute_extract_strategies,
|
||||
)
|
||||
|
||||
|
||||
class TestSlotStrategyExecutor:
    """Tests for the slot extraction strategy-chain executor: short-circuit
    on success, fallback, validation, runtime errors, and serialization."""

    @pytest.fixture
    def executor(self):
        """Create an executor instance."""
        return SlotStrategyExecutor()

    @pytest.fixture
    def context(self):
        """Create a test extraction context for the "grade" slot."""
        return ExtractContext(
            tenant_id="test-tenant",
            slot_key="grade",
            user_input="我想了解初一语文课程",
            slot_type="string",
        )

    @pytest.mark.asyncio
    async def test_execute_chain_success_on_first_step(self, executor, context):
        """The chain stops as soon as the first step succeeds."""
        # Mock rule extractor to succeed.
        mock_rule = AsyncMock(return_value="初一")
        executor._extractors["rule"] = mock_rule

        result = await executor.execute_chain(
            strategies=["rule", "llm", "user_input"],
            context=context,
        )

        assert result.success is True
        assert result.final_value == "初一"
        assert result.final_strategy == "rule"
        # Only one step executed — later strategies were never tried.
        assert len(result.steps) == 1
        assert result.steps[0].success is True
        mock_rule.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_chain_fallback_to_second_step(self, executor, context):
        """First step fails, second step succeeds."""
        # Mock rule extractor to fail (returns nothing).
        mock_rule = AsyncMock(return_value=None)
        # Mock llm extractor to succeed.
        mock_llm = AsyncMock(return_value="初一")

        executor._extractors["rule"] = mock_rule
        executor._extractors["llm"] = mock_llm

        result = await executor.execute_chain(
            strategies=["rule", "llm", "user_input"],
            context=context,
        )

        assert result.success is True
        assert result.final_value == "初一"
        assert result.final_strategy == "llm"
        assert len(result.steps) == 2
        assert result.steps[0].success is False
        assert result.steps[0].failure_type == ExtractFailureType.EXTRACT_EMPTY
        assert result.steps[1].success is True

    @pytest.mark.asyncio
    async def test_execute_chain_all_failed(self, executor, context):
        """Every strategy fails; the ask-back prompt is carried through."""
        # Mock every extractor to fail.
        mock_rule = AsyncMock(return_value=None)
        mock_llm = AsyncMock(return_value=None)
        mock_user_input = AsyncMock(return_value=None)

        executor._extractors["rule"] = mock_rule
        executor._extractors["llm"] = mock_llm
        executor._extractors["user_input"] = mock_user_input

        result = await executor.execute_chain(
            strategies=["rule", "llm", "user_input"],
            context=context,
            ask_back_prompt="请告诉我您的年级",
        )

        assert result.success is False
        assert result.final_value is None
        assert result.final_strategy is None
        assert len(result.steps) == 3
        assert result.ask_back_prompt == "请告诉我您的年级"

        # Every step failed with an empty-extraction failure.
        for step in result.steps:
            assert step.success is False
            assert step.failure_type == ExtractFailureType.EXTRACT_EMPTY

    @pytest.mark.asyncio
    async def test_execute_chain_validation_failure(self, executor, context):
        """An extracted value that fails the validation regex is rejected."""
        context.validation_rule = r"^初[一二三]$"  # Only 初一/初二/初三 allowed

        # Mock rule extractor to return a value outside the allowed set.
        mock_rule = AsyncMock(return_value="高一")
        executor._extractors["rule"] = mock_rule

        result = await executor.execute_chain(
            strategies=["rule"],
            context=context,
        )

        assert result.success is False
        assert result.steps[0].failure_type == ExtractFailureType.EXTRACT_VALIDATION_FAIL
        assert "Validation failed" in result.steps[0].failure_reason

    @pytest.mark.asyncio
    async def test_execute_chain_runtime_error(self, executor, context):
        """An extractor that raises is recorded as a runtime error."""
        # Mock rule extractor to raise an exception.
        mock_rule = AsyncMock(side_effect=Exception("LLM service unavailable"))
        executor._extractors["rule"] = mock_rule

        result = await executor.execute_chain(
            strategies=["rule"],
            context=context,
        )

        assert result.success is False
        assert result.steps[0].failure_type == ExtractFailureType.EXTRACT_RUNTIME_ERROR
        assert "Runtime error" in result.steps[0].failure_reason

    @pytest.mark.asyncio
    async def test_execute_chain_empty_strategies(self, executor, context):
        """An empty strategy chain fails with zero steps."""
        result = await executor.execute_chain(
            strategies=[],
            context=context,
        )

        assert result.success is False
        assert len(result.steps) == 0

    @pytest.mark.asyncio
    async def test_execute_chain_unknown_strategy(self, executor, context):
        """An unknown strategy name is reported as a runtime error."""
        result = await executor.execute_chain(
            strategies=["unknown_strategy"],
            context=context,
        )

        assert result.success is False
        assert result.steps[0].failure_type == ExtractFailureType.EXTRACT_RUNTIME_ERROR
        assert "Unknown strategy" in result.steps[0].failure_reason

    @pytest.mark.asyncio
    async def test_execute_chain_result_to_dict(self, executor, context):
        """The chain result serializes to a dict with all key fields."""
        mock_rule = AsyncMock(return_value="初一")
        executor._extractors["rule"] = mock_rule

        result = await executor.execute_chain(
            strategies=["rule"],
            context=context,
        )

        result_dict = result.to_dict()

        assert result_dict["slot_key"] == "grade"
        assert result_dict["success"] is True
        assert result_dict["final_value"] == "初一"
        assert result_dict["final_strategy"] == "rule"
        assert "steps" in result_dict
        assert "total_execution_time_ms" in result_dict
|
||||
|
||||
|
||||
class TestExecuteExtractStrategies:
    """Tests for the execute_extract_strategies convenience wrapper."""

    @pytest.mark.asyncio
    async def test_convenience_function(self):
        """The wrapper runs a single rule strategy end to end."""
        rule_stub = AsyncMock(return_value="初一")

        outcome = await execute_extract_strategies(
            strategies=["rule"],
            tenant_id="test-tenant",
            slot_key="grade",
            user_input="我想了解初一语文课程",
            rule_extractor=rule_stub,
        )

        assert outcome.success is True
        assert outcome.final_value == "初一"

    @pytest.mark.asyncio
    async def test_convenience_function_with_validation(self):
        """A value matching the validation regex passes through unchanged."""
        rule_stub = AsyncMock(return_value="初一")

        outcome = await execute_extract_strategies(
            strategies=["rule"],
            tenant_id="test-tenant",
            slot_key="grade",
            user_input="我想了解初一语文课程",
            validation_rule=r"^初[一二三]$",
            rule_extractor=rule_stub,
        )

        assert outcome.success is True
        assert outcome.final_value == "初一"
|
||||
|
||||
|
||||
class TestExtractContext:
    """Tests for the ExtractContext container."""

    def test_context_creation(self):
        """All constructor arguments are stored unchanged."""
        ctx = ExtractContext(
            tenant_id="test-tenant",
            slot_key="grade",
            user_input="测试输入",
            slot_type="string",
            validation_rule=r"^初[一二三]$",
            history=[{"role": "user", "content": "你好"}],
            session_id="session-123",
        )

        assert ctx.tenant_id == "test-tenant"
        assert ctx.slot_key == "grade"
        assert ctx.user_input == "测试输入"
        assert ctx.slot_type == "string"
        assert ctx.validation_rule == r"^初[一二三]$"
        assert len(ctx.history) == 1
        assert ctx.session_id == "session-123"
|
||||
|
|
@ -0,0 +1,541 @@
|
|||
"""
|
||||
Tests for Slot Validation Service.
|
||||
槽位校验服务单元测试
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.mid.slot_validation_service import (
|
||||
SlotValidationService,
|
||||
SlotValidationErrorCode,
|
||||
ValidationResult,
|
||||
SlotValidationError,
|
||||
BatchValidationResult,
|
||||
)
|
||||
|
||||
|
||||
class TestSlotValidationService:
    """Test suite for ``SlotValidationService``.

    Slot-definition fixtures are declared at class level so the nested
    test-group classes below can all consume them.
    """

    @pytest.fixture
    def service(self):
        """Fresh validation-service instance for each test."""
        return SlotValidationService()

    @pytest.fixture
    def string_slot_def(self):
        """Optional string-typed slot with no validation rule."""
        return {
            "slot_key": "name",
            "type": "string",
            "required": False,
            "validation_rule": None,
            "ask_back_prompt": "请输入您的姓名",
        }

    @pytest.fixture
    def required_string_slot_def(self):
        """Required string slot validated by a CN mobile-number regex."""
        return {
            "slot_key": "phone",
            "type": "string",
            "required": True,
            "validation_rule": r"^1[3-9]\d{9}$",
            "ask_back_prompt": "请输入正确的手机号码",
        }

    @pytest.fixture
    def number_slot_def(self):
        """Optional number-typed slot definition."""
        return {
            "slot_key": "age",
            "type": "number",
            "required": False,
            "validation_rule": None,
            "ask_back_prompt": "请输入年龄",
        }

    @pytest.fixture
    def boolean_slot_def(self):
        """Optional boolean-typed slot definition."""
        return {
            "slot_key": "is_student",
            "type": "boolean",
            "required": False,
            "validation_rule": None,
            "ask_back_prompt": "是否是学生?",
        }

    @pytest.fixture
    def enum_slot_def(self):
        """Enum-typed slot with a fixed set of grade options."""
        return {
            "slot_key": "grade",
            "type": "enum",
            "required": False,
            "options": ["初一", "初二", "初三", "高一", "高二", "高三"],
            "validation_rule": None,
            "ask_back_prompt": "请选择年级",
        }

    @pytest.fixture
    def array_enum_slot_def(self):
        """Array-of-enum slot with a fixed set of subject options."""
        return {
            "slot_key": "subjects",
            "type": "array_enum",
            "required": False,
            "options": ["语文", "数学", "英语", "物理", "化学"],
            "validation_rule": None,
            "ask_back_prompt": "请选择学科",
        }

    @pytest.fixture
    def json_schema_slot_def(self):
        """Slot whose validation rule is a JSON Schema (email format)."""
        return {
            "slot_key": "email",
            "type": "string",
            "required": True,
            "validation_rule": '{"type": "string", "format": "email"}',
            "ask_back_prompt": "请输入有效的邮箱地址",
        }

    class TestBasicValidation:
        """Baseline behaviour when no (usable) validation rule is present."""

        def test_empty_validation_rule(self, service, string_slot_def):
            """A None validation rule passes any value through unchanged."""
            string_slot_def["validation_rule"] = None
            result = service.validate_slot_value(string_slot_def, "test")
            assert result.ok is True
            assert result.normalized_value == "test"

        def test_whitespace_validation_rule(self, service, string_slot_def):
            """A whitespace-only rule is treated as no rule (passes)."""
            string_slot_def["validation_rule"] = " "
            result = service.validate_slot_value(string_slot_def, "test")
            assert result.ok is True

        def test_no_slot_definition(self, service):
            """A minimal/dynamic slot definition accepts any value."""
            # Only slot_key is supplied — the service must tolerate missing keys.
            minimal_def = {"slot_key": "dynamic_field"}
            result = service.validate_slot_value(minimal_def, "any_value")
            assert result.ok is True

    class TestRegexValidation:
        """Regular-expression validation rules."""

        def test_regex_match(self, service, required_string_slot_def):
            """A value matching the regex validates successfully."""
            result = service.validate_slot_value(
                required_string_slot_def, "13800138000"
            )
            assert result.ok is True
            assert result.normalized_value == "13800138000"

        def test_regex_mismatch(self, service, required_string_slot_def):
            """A non-matching value fails with SLOT_REGEX_MISMATCH and
            returns the slot's ask-back prompt."""
            result = service.validate_slot_value(
                required_string_slot_def, "invalid_phone"
            )
            assert result.ok is False
            assert result.error_code == SlotValidationErrorCode.SLOT_REGEX_MISMATCH
            assert result.ask_back_prompt == "请输入正确的手机号码"

        def test_regex_invalid_pattern(self, service, string_slot_def):
            """A syntactically invalid regex yields SLOT_VALIDATION_RULE_INVALID."""
            string_slot_def["validation_rule"] = "[invalid("
            result = service.validate_slot_value(string_slot_def, "test")
            assert result.ok is False
            assert (
                result.error_code
                == SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID
            )

        def test_regex_with_chinese(self, service, string_slot_def):
            """Regexes with CJK character classes work on Chinese input."""
            string_slot_def["validation_rule"] = r"^[\u4e00-\u9fa5]{2,4}$"
            result = service.validate_slot_value(string_slot_def, "张三")
            assert result.ok is True

            result = service.validate_slot_value(string_slot_def, "John")
            assert result.ok is False

    class TestJsonSchemaValidation:
        """JSON-Schema validation rules."""

        def test_json_schema_match(self, service):
            """A value conforming to the schema validates successfully."""
            slot_def = {
                "slot_key": "config",
                "type": "object",
                "validation_rule": '{"type": "object", "properties": {"name": {"type": "string"}}}',
            }
            result = service.validate_slot_value(slot_def, {"name": "test"})
            assert result.ok is True

        def test_json_schema_mismatch(self, service):
            """A value violating the schema fails with
            SLOT_JSON_SCHEMA_MISMATCH and the ask-back prompt is surfaced."""
            slot_def = {
                "slot_key": "count",
                "type": "number",
                "validation_rule": '{"type": "integer", "minimum": 0, "maximum": 100}',
                "ask_back_prompt": "请输入0-100之间的整数",
            }
            result = service.validate_slot_value(slot_def, 150)
            assert result.ok is False
            assert (
                result.error_code == SlotValidationErrorCode.SLOT_JSON_SCHEMA_MISMATCH
            )
            assert result.ask_back_prompt == "请输入0-100之间的整数"

        def test_json_schema_invalid_json(self, service, string_slot_def):
            """A rule that is not valid JSON yields SLOT_VALIDATION_RULE_INVALID."""
            string_slot_def["validation_rule"] = "{invalid json}"
            result = service.validate_slot_value(string_slot_def, "test")
            assert result.ok is False
            assert (
                result.error_code
                == SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID
            )

        def test_json_schema_array(self, service):
            """Array-typed schemas validate element types."""
            slot_def = {
                "slot_key": "items",
                "type": "array",
                "validation_rule": '{"type": "array", "items": {"type": "string"}}',
            }
            result = service.validate_slot_value(slot_def, ["a", "b", "c"])
            assert result.ok is True

            result = service.validate_slot_value(slot_def, [1, 2, 3])
            assert result.ok is False

    class TestRequiredValidation:
        """Required-field (presence) checks."""

        def test_required_missing_none(self, service, required_string_slot_def):
            """None for a required slot fails with SLOT_REQUIRED_MISSING."""
            result = service.validate_slot_value(
                required_string_slot_def, None
            )
            assert result.ok is False
            assert result.error_code == SlotValidationErrorCode.SLOT_REQUIRED_MISSING

        def test_required_missing_empty_string(self, service, required_string_slot_def):
            """An empty string counts as missing for a required slot."""
            result = service.validate_slot_value(required_string_slot_def, "")
            assert result.ok is False
            assert result.error_code == SlotValidationErrorCode.SLOT_REQUIRED_MISSING

        def test_required_missing_whitespace(self, service, required_string_slot_def):
            """Whitespace-only input counts as missing for a required slot."""
            result = service.validate_slot_value(required_string_slot_def, " ")
            assert result.ok is False
            assert result.error_code == SlotValidationErrorCode.SLOT_REQUIRED_MISSING

        def test_required_present(self, service, required_string_slot_def):
            """A present, valid value passes the required check."""
            result = service.validate_slot_value(
                required_string_slot_def, "13800138000"
            )
            assert result.ok is True

        def test_not_required_empty(self, service, string_slot_def):
            """An empty value on a non-required slot passes."""
            result = service.validate_slot_value(string_slot_def, "")
            assert result.ok is True

    class TestTypeValidation:
        """Type checks and type coercion per declared slot type."""

        def test_string_type(self, service, string_slot_def):
            """A plain string passes and is returned unchanged."""
            result = service.validate_slot_value(string_slot_def, "hello")
            assert result.ok is True
            assert result.normalized_value == "hello"

        def test_string_type_conversion(self, service, string_slot_def):
            """Non-string input is coerced to its string form."""
            result = service.validate_slot_value(string_slot_def, 123)
            assert result.ok is True
            assert result.normalized_value == "123"

        def test_number_type_integer(self, service, number_slot_def):
            """Integers are accepted as numbers."""
            result = service.validate_slot_value(number_slot_def, 25)
            assert result.ok is True
            assert result.normalized_value == 25

        def test_number_type_float(self, service, number_slot_def):
            """Floats are accepted as numbers."""
            result = service.validate_slot_value(number_slot_def, 25.5)
            assert result.ok is True
            assert result.normalized_value == 25.5

        def test_number_type_string_conversion(self, service, number_slot_def):
            """Numeric strings are coerced to numbers."""
            result = service.validate_slot_value(number_slot_def, "30")
            assert result.ok is True
            assert result.normalized_value == 30

        def test_number_type_invalid(self, service, number_slot_def):
            """Non-numeric strings fail with SLOT_TYPE_INVALID."""
            result = service.validate_slot_value(number_slot_def, "not_a_number")
            assert result.ok is False
            assert result.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID

        def test_number_type_reject_boolean(self, service, number_slot_def):
            """Booleans are rejected even though bool subclasses int."""
            result = service.validate_slot_value(number_slot_def, True)
            assert result.ok is False
            assert result.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID

        def test_boolean_type_true(self, service, boolean_slot_def):
            """Native True passes through."""
            result = service.validate_slot_value(boolean_slot_def, True)
            assert result.ok is True
            assert result.normalized_value is True

        def test_boolean_type_false(self, service, boolean_slot_def):
            """Native False passes through."""
            result = service.validate_slot_value(boolean_slot_def, False)
            assert result.ok is True
            assert result.normalized_value is False

        def test_boolean_type_string_true(self, service, boolean_slot_def):
            """The string "true" coerces to True."""
            result = service.validate_slot_value(boolean_slot_def, "true")
            assert result.ok is True
            assert result.normalized_value is True

        def test_boolean_type_string_yes(self, service, boolean_slot_def):
            """Affirmative Chinese input ("是") coerces to True."""
            result = service.validate_slot_value(boolean_slot_def, "是")
            assert result.ok is True
            assert result.normalized_value is True

        def test_boolean_type_string_false(self, service, boolean_slot_def):
            """The string "false" coerces to False."""
            result = service.validate_slot_value(boolean_slot_def, "false")
            assert result.ok is True
            assert result.normalized_value is False

        def test_boolean_type_invalid(self, service, boolean_slot_def):
            """Unrecognised boolean-ish input fails with SLOT_TYPE_INVALID."""
            result = service.validate_slot_value(boolean_slot_def, "maybe")
            assert result.ok is False
            assert result.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID

        def test_enum_type_valid(self, service, enum_slot_def):
            """A value from the options list is accepted."""
            result = service.validate_slot_value(enum_slot_def, "高一")
            assert result.ok is True
            assert result.normalized_value == "高一"

        def test_enum_type_invalid(self, service, enum_slot_def):
            """A value outside the options list fails with SLOT_ENUM_INVALID."""
            result = service.validate_slot_value(enum_slot_def, "大一")
            assert result.ok is False
            assert result.error_code == SlotValidationErrorCode.SLOT_ENUM_INVALID

        def test_enum_type_not_string(self, service, enum_slot_def):
            """Non-string input to an enum slot fails with SLOT_TYPE_INVALID."""
            result = service.validate_slot_value(enum_slot_def, 123)
            assert result.ok is False
            assert result.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID

        def test_array_enum_type_valid(self, service, array_enum_slot_def):
            """A list whose items all come from the options list is accepted."""
            result = service.validate_slot_value(
                array_enum_slot_def, ["语文", "数学"]
            )
            assert result.ok is True

        def test_array_enum_type_invalid_item(self, service, array_enum_slot_def):
            """A list containing a non-option item fails with
            SLOT_ARRAY_ENUM_INVALID."""
            result = service.validate_slot_value(
                array_enum_slot_def, ["语文", "生物"]
            )
            assert result.ok is False
            assert (
                result.error_code == SlotValidationErrorCode.SLOT_ARRAY_ENUM_INVALID
            )

        def test_array_enum_type_not_array(self, service, array_enum_slot_def):
            """Non-list input to an array-enum slot fails with SLOT_TYPE_INVALID."""
            result = service.validate_slot_value(array_enum_slot_def, "语文")
            assert result.ok is False
            assert result.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID

        def test_array_enum_type_non_string_item(self, service, array_enum_slot_def):
            """A list containing a non-string item fails with
            SLOT_ARRAY_ENUM_INVALID."""
            result = service.validate_slot_value(array_enum_slot_def, ["语文", 123])
            assert result.ok is False
            assert (
                result.error_code == SlotValidationErrorCode.SLOT_ARRAY_ENUM_INVALID
            )

    class TestBatchValidation:
        """Batch validation of several slot values at once."""

        def test_batch_all_valid(self, service, string_slot_def, number_slot_def):
            """All values valid: ok, no errors, normalised values returned."""
            slot_defs = [string_slot_def, number_slot_def]
            values = {"name": "张三", "age": 25}
            result = service.validate_slots(slot_defs, values)
            assert result.ok is True
            assert len(result.errors) == 0
            assert result.validated_values["name"] == "张三"
            assert result.validated_values["age"] == 25

        def test_batch_some_invalid(self, service, string_slot_def, number_slot_def):
            """One bad value: batch fails and the error names that slot."""
            slot_defs = [string_slot_def, number_slot_def]
            values = {"name": "张三", "age": "not_a_number"}
            result = service.validate_slots(slot_defs, values)
            assert result.ok is False
            assert len(result.errors) == 1
            assert result.errors[0].slot_key == "age"

        def test_batch_missing_required(
            self, service, required_string_slot_def, string_slot_def
        ):
            """A missing required slot is reported as SLOT_REQUIRED_MISSING."""
            slot_defs = [required_string_slot_def, string_slot_def]
            values = {"name": "张三"}  # "phone" deliberately omitted
            result = service.validate_slots(slot_defs, values)
            assert result.ok is False
            assert len(result.errors) == 1
            assert result.errors[0].slot_key == "phone"
            assert (
                result.errors[0].error_code
                == SlotValidationErrorCode.SLOT_REQUIRED_MISSING
            )

        def test_batch_undefined_slot(self, service, string_slot_def):
            """Values for undefined (dynamic) slots pass through untouched."""
            slot_defs = [string_slot_def]
            values = {"name": "张三", "undefined_field": "value"}
            result = service.validate_slots(slot_defs, values)
            assert result.ok is True
            # Undefined slots are allowed through rather than rejected.
            assert "undefined_field" in result.validated_values

    class TestCombinedValidation:
        """Interaction of type checks with regex / JSON-Schema rules."""

        def test_type_and_regex_both_pass(self, service):
            """Value satisfying both the type and the regex validates."""
            slot_def = {
                "slot_key": "code",
                "type": "string",
                "required": True,
                "validation_rule": r"^[A-Z]{2}\d{4}$",
            }
            result = service.validate_slot_value(slot_def, "AB1234")
            assert result.ok is True

        def test_type_pass_regex_fail(self, service):
            """Correct type but failing regex yields SLOT_REGEX_MISMATCH."""
            slot_def = {
                "slot_key": "code",
                "type": "string",
                "required": True,
                "validation_rule": r"^[A-Z]{2}\d{4}$",
            }
            result = service.validate_slot_value(slot_def, "ab1234")
            assert result.ok is False
            assert result.error_code == SlotValidationErrorCode.SLOT_REGEX_MISMATCH

        def test_type_fail_no_regex_check(self, service):
            """Type failure short-circuits: the regex is never evaluated."""
            slot_def = {
                "slot_key": "code",
                "type": "number",
                "required": True,
                "validation_rule": r"^\d+$",
            }
            result = service.validate_slot_value(slot_def, "not_a_number")
            assert result.ok is False
            assert result.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID

    class TestAskBackPrompt:
        """Propagation of the ask-back (re-prompt) message."""

        def test_ask_back_prompt_on_validation_fail(self, service):
            """Validation failure surfaces the configured ask_back_prompt."""
            slot_def = {
                "slot_key": "email",
                "type": "string",
                "required": True,
                "validation_rule": r"^[\w\.-]+@[\w\.-]+\.\w+$",
                "ask_back_prompt": "请输入有效的邮箱地址,如 example@domain.com",
            }
            result = service.validate_slot_value(slot_def, "invalid_email")
            assert result.ok is False
            assert result.ask_back_prompt == "请输入有效的邮箱地址,如 example@domain.com"

        def test_no_ask_back_prompt_on_success(self, service, string_slot_def):
            """No ask-back prompt is returned when validation passes."""
            result = service.validate_slot_value(string_slot_def, "valid")
            assert result.ok is True
            assert result.ask_back_prompt is None

        def test_ask_back_prompt_on_required_missing(self, service):
            """A missing required value also surfaces the ask_back_prompt."""
            slot_def = {
                "slot_key": "name",
                "type": "string",
                "required": True,
                "ask_back_prompt": "请告诉我们您的姓名",
            }
            result = service.validate_slot_value(slot_def, "")
            assert result.ok is False
            assert result.ask_back_prompt == "请告诉我们您的姓名"
|
||||
|
||||
|
||||
class TestSlotValidationErrorCode:
    """Error-code constant tests."""

    def test_error_code_values(self):
        """Each error-code member compares equal to its wire-format string."""
        expected = {
            SlotValidationErrorCode.SLOT_REQUIRED_MISSING: "SLOT_REQUIRED_MISSING",
            SlotValidationErrorCode.SLOT_TYPE_INVALID: "SLOT_TYPE_INVALID",
            SlotValidationErrorCode.SLOT_REGEX_MISMATCH: "SLOT_REGEX_MISMATCH",
            SlotValidationErrorCode.SLOT_JSON_SCHEMA_MISMATCH: "SLOT_JSON_SCHEMA_MISMATCH",
            SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID: "SLOT_VALIDATION_RULE_INVALID",
        }
        for member, wire_value in expected.items():
            assert member == wire_value
|
||||
|
||||
|
||||
class TestValidationResult:
    """Behaviour of the ValidationResult container."""

    def test_success_result(self):
        """A passing result carries the normalised value and no error fields."""
        ok_result = ValidationResult(ok=True, normalized_value="test")

        assert ok_result.ok is True
        assert ok_result.normalized_value == "test"
        assert ok_result.error_code is None
        assert ok_result.error_message is None

    def test_failure_result(self):
        """A failing result exposes error code, message and ask-back prompt."""
        failed = ValidationResult(
            ok=False,
            error_code="SLOT_REGEX_MISMATCH",
            error_message="格式不正确",
            ask_back_prompt="请重新输入",
        )

        assert failed.ok is False
        assert failed.error_code == "SLOT_REGEX_MISMATCH"
        assert failed.error_message == "格式不正确"
        assert failed.ask_back_prompt == "请重新输入"
|
||||
|
|
@ -0,0 +1,328 @@
|
|||
"""
|
||||
Test cases for Step-KB Binding feature.
|
||||
[Step-KB-Binding] 步骤关联知识库功能的测试用例
|
||||
|
||||
测试覆盖:
|
||||
1. 步骤配置的增删改查与参数校验
|
||||
2. 配置步骤KB范围后,检索仅在范围内发生
|
||||
3. 未配置时回退原逻辑
|
||||
4. 多知识库同名内容场景下,步骤约束生效
|
||||
5. trace 字段完整性校验
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
class TestStepKbBindingModel:
    """Tests for the step-KB binding fields on the FlowStep model."""

    def test_flow_step_with_kb_binding_fields(self):
        """FlowStep stores all four KB-binding fields as supplied."""
        from app.models.entities import FlowStep

        step = FlowStep(
            step_no=1,
            content="测试步骤",
            allowed_kb_ids=["kb-1", "kb-2"],
            preferred_kb_ids=["kb-1"],
            kb_query_hint="查找产品相关信息",
            max_kb_calls_per_step=2,
        )

        assert step.allowed_kb_ids == ["kb-1", "kb-2"]
        assert step.preferred_kb_ids == ["kb-1"]
        assert step.kb_query_hint == "查找产品相关信息"
        assert step.max_kb_calls_per_step == 2

    def test_flow_step_without_kb_binding(self):
        """KB-binding fields default to None when not supplied."""
        from app.models.entities import FlowStep

        step = FlowStep(
            step_no=1,
            content="测试步骤",
        )

        assert step.allowed_kb_ids is None
        assert step.preferred_kb_ids is None
        assert step.kb_query_hint is None
        assert step.max_kb_calls_per_step is None

    def test_max_kb_calls_validation(self):
        """max_kb_calls_per_step is range-checked (valid range 1-5)."""
        from app.models.entities import FlowStep
        from pydantic import ValidationError

        # Within the valid range.
        step = FlowStep(step_no=1, content="test", max_kb_calls_per_step=3)
        assert step.max_kb_calls_per_step == 3

        # Above the upper bound: expect the specific pydantic ValidationError
        # (already imported above) rather than a bare Exception, so unrelated
        # failures such as an ImportError or TypeError cannot mask the check.
        with pytest.raises(ValidationError):
            FlowStep(step_no=1, content="test", max_kb_calls_per_step=10)
|
||||
|
||||
|
||||
class TestStepKbConfig:
    """Tests for the StepKbConfig data class."""

    def test_step_kb_config_creation(self):
        """Explicitly supplied fields are stored as-is."""
        from app.services.mid.kb_search_dynamic_tool import StepKbConfig

        cfg = StepKbConfig(
            allowed_kb_ids=["kb-1", "kb-2"],
            preferred_kb_ids=["kb-1"],
            kb_query_hint="查找产品信息",
            max_kb_calls=2,
            step_id="flow-1_step_1",
        )

        assert (cfg.allowed_kb_ids, cfg.preferred_kb_ids) == (
            ["kb-1", "kb-2"],
            ["kb-1"],
        )
        assert cfg.kb_query_hint == "查找产品信息"
        assert cfg.max_kb_calls == 2
        assert cfg.step_id == "flow-1_step_1"

    def test_step_kb_config_defaults(self):
        """Defaults: every field None, except max_kb_calls which is 1."""
        from app.services.mid.kb_search_dynamic_tool import StepKbConfig

        cfg = StepKbConfig()

        for attr in ("allowed_kb_ids", "preferred_kb_ids", "kb_query_hint", "step_id"):
            assert getattr(cfg, attr) is None
        assert cfg.max_kb_calls == 1
|
||||
|
||||
|
||||
class TestKbSearchDynamicTool WithStepConfig:
    """Integration of KbSearchDynamicTool with a per-step KB configuration.

    The internal retrieval call (``_do_retrieve``) is patched with an
    AsyncMock so these tests exercise only the config plumbing and the
    trace bookkeeping, not actual retrieval.
    """

    @pytest.mark.asyncio
    async def test_kb_search_with_allowed_kb_ids(self):
        """With allowed_kb_ids configured, the step config is forwarded to
        retrieval and echoed back in the result's step_kb_binding."""
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicTool,
            KbSearchDynamicConfig,
            StepKbConfig,
        )

        mock_session = MagicMock()
        mock_timeout_governor = MagicMock()

        tool = KbSearchDynamicTool(
            session=mock_session,
            timeout_governor=mock_timeout_governor,
            config=KbSearchDynamicConfig(enabled=True),
        )

        step_config = StepKbConfig(
            allowed_kb_ids=["kb-allowed-1", "kb-allowed-2"],
            step_id="test_step",
        )

        with patch.object(tool, '_do_retrieve', new_callable=AsyncMock) as mock_retrieve:
            # One hit from an allowed KB.
            mock_retrieve.return_value = [
                {"id": "1", "content": "test", "score": 0.8, "metadata": {"kb_id": "kb-allowed-1"}}
            ]

            result = await tool.execute(
                query="测试查询",
                tenant_id="tenant-1",
                step_kb_config=step_config,
            )

            # The retrieval call received the step config as a kwarg.
            call_args = mock_retrieve.call_args
            assert call_args[1]['step_kb_config'] == step_config

            # The result carries the step-KB binding trace info.
            assert result.step_kb_binding is not None
            assert result.step_kb_binding['allowed_kb_ids'] == ["kb-allowed-1", "kb-allowed-2"]

    @pytest.mark.asyncio
    async def test_kb_search_without_step_config(self):
        """Without a step config, retrieval falls back to the original
        behaviour and no step_kb_binding is attached to the result."""
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicTool,
            KbSearchDynamicConfig,
        )

        mock_session = MagicMock()
        mock_timeout_governor = MagicMock()

        tool = KbSearchDynamicTool(
            session=mock_session,
            timeout_governor=mock_timeout_governor,
            config=KbSearchDynamicConfig(enabled=True),
        )

        with patch.object(tool, '_do_retrieve', new_callable=AsyncMock) as mock_retrieve:
            mock_retrieve.return_value = [
                {"id": "1", "content": "test", "score": 0.8, "metadata": {}}
            ]

            result = await tool.execute(
                query="测试查询",
                tenant_id="tenant-1",
            )

            # Retrieval was invoked without a step config.
            call_args = mock_retrieve.call_args
            assert call_args[1]['step_kb_config'] is None

            # No step-KB binding info is reported.
            assert result.step_kb_binding is None

    @pytest.mark.asyncio
    async def test_kb_search_result_includes_used_kb_ids(self):
        """used_kb_ids aggregates the kb_id of every hit and kb_hit is set."""
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicTool,
            KbSearchDynamicConfig,
            StepKbConfig,
        )

        mock_session = MagicMock()
        mock_timeout_governor = MagicMock()

        tool = KbSearchDynamicTool(
            session=mock_session,
            timeout_governor=mock_timeout_governor,
            config=KbSearchDynamicConfig(enabled=True),
        )

        step_config = StepKbConfig(
            allowed_kb_ids=["kb-1", "kb-2"],
            step_id="test_step",
        )

        with patch.object(tool, '_do_retrieve', new_callable=AsyncMock) as mock_retrieve:
            # Hits spread across both allowed KBs.
            mock_retrieve.return_value = [
                {"id": "1", "content": "test1", "score": 0.9, "metadata": {"kb_id": "kb-1"}},
                {"id": "2", "content": "test2", "score": 0.8, "metadata": {"kb_id": "kb-1"}},
                {"id": "3", "content": "test3", "score": 0.7, "metadata": {"kb_id": "kb-2"}},
            ]

            result = await tool.execute(
                query="测试查询",
                tenant_id="tenant-1",
                step_kb_config=step_config,
            )

            # used_kb_ids covers every KB that produced a hit (order-agnostic).
            assert result.step_kb_binding is not None
            assert set(result.step_kb_binding['used_kb_ids']) == {"kb-1", "kb-2"}
            assert result.step_kb_binding['kb_hit'] is True
|
||||
|
||||
|
||||
class TestTraceInfoStepKbBinding:
    """Tests for the step_kb_binding field on TraceInfo."""

    def test_trace_info_with_step_kb_binding(self):
        """TraceInfo accepts and stores a step_kb_binding dict verbatim."""
        from app.models.mid.schemas import TraceInfo, ExecutionMode

        trace = TraceInfo(
            mode=ExecutionMode.AGENT,
            step_kb_binding={
                "step_id": "flow-1_step_2",
                "allowed_kb_ids": ["kb-1", "kb-2"],
                "used_kb_ids": ["kb-1"],
                "kb_hit": True,
            },
        )

        assert trace.step_kb_binding is not None
        assert trace.step_kb_binding['step_id'] == "flow-1_step_2"
        assert trace.step_kb_binding['allowed_kb_ids'] == ["kb-1", "kb-2"]
        assert trace.step_kb_binding['used_kb_ids'] == ["kb-1"]

    def test_trace_info_without_step_kb_binding(self):
        """step_kb_binding defaults to None when not supplied."""
        from app.models.mid.schemas import TraceInfo, ExecutionMode

        trace = TraceInfo(mode=ExecutionMode.AGENT)

        assert trace.step_kb_binding is None
|
||||
|
||||
|
||||
class TestFlowStepKbBindingIntegration:
    """Integration of flow steps with KB-binding configuration."""

    def test_script_flow_steps_with_kb_binding(self):
        """ScriptFlowCreate accepts steps with and without KB-binding keys."""
        from app.models.entities import ScriptFlowCreate

        flow_create = ScriptFlowCreate(
            name="测试流程",
            steps=[
                {
                    "step_no": 1,
                    "content": "步骤1",
                    "allowed_kb_ids": ["kb-1"],
                    "preferred_kb_ids": None,
                    "kb_query_hint": "查找产品信息",
                    "max_kb_calls_per_step": 2,
                },
                {
                    "step_no": 2,
                    "content": "步骤2",
                    # No KB binding configured on this step.
                },
            ],
        )

        # Step 1 keeps its binding config; step 2 has none.
        # NOTE(review): steps appear to be stored as plain dicts here
        # (indexed with [...] / .get) — confirm against ScriptFlowCreate.
        assert flow_create.steps[0]['allowed_kb_ids'] == ["kb-1"]
        assert flow_create.steps[0]['kb_query_hint'] == "查找产品信息"
        assert flow_create.steps[1].get('allowed_kb_ids') is None
|
||||
|
||||
|
||||
class TestKbBindingLogging:
    """Logging behaviour when a step-KB binding is in effect."""

    @pytest.mark.asyncio
    async def test_step_kb_config_logging(self, caplog):
        """Executing with a step config emits a log line tagged
        'Step-KB-Binding'."""
        import logging
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicTool,
            KbSearchDynamicConfig,
            StepKbConfig,
        )

        mock_session = MagicMock()
        mock_timeout_governor = MagicMock()

        tool = KbSearchDynamicTool(
            session=mock_session,
            timeout_governor=mock_timeout_governor,
            config=KbSearchDynamicConfig(enabled=True),
        )

        step_config = StepKbConfig(
            allowed_kb_ids=["kb-1"],
            step_id="flow-1_step_1",
        )

        # Retrieval is stubbed to return nothing — only logging is under test.
        with patch.object(tool, '_do_retrieve', new_callable=AsyncMock) as mock_retrieve:
            mock_retrieve.return_value = []

            with caplog.at_level(logging.INFO):
                await tool.execute(
                    query="测试",
                    tenant_id="tenant-1",
                    step_kb_config=step_config,
                )

        # At least one INFO record carries the Step-KB-Binding marker.
        assert any("Step-KB-Binding" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
# Entry point: allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
Loading…
Reference in New Issue