test: add unit tests and utility scripts for intent routing, slot management, and KB search [AC-TEST]

This commit is contained in:
MerCry 2026-03-10 12:10:22 +08:00
parent fe883cfff0
commit f4ca25b0d8
57 changed files with 9663 additions and 58 deletions

View File

@ -0,0 +1,81 @@
"""
Database Migration: Scene Slot Bundle Tables.
[AC-SCENE-SLOT-01] 场景-槽位映射配置表迁移
创建时间: 2025-03-07
变更说明:
- 新增 scene_slot_bundles 表用于存储场景槽位包配置
执行方式:
- SQLModel 会自动创建表通过 init_db
- 此脚本用于手动迁移或回滚
SQL DDL:
```sql
CREATE TABLE scene_slot_bundles (
id UUID PRIMARY KEY,
tenant_id VARCHAR NOT NULL,
scene_key VARCHAR(100) NOT NULL,
scene_name VARCHAR(100) NOT NULL,
description TEXT,
required_slots JSON NOT NULL DEFAULT '[]',
optional_slots JSON NOT NULL DEFAULT '[]',
slot_priority JSON,
completion_threshold FLOAT NOT NULL DEFAULT 1.0,
ask_back_order VARCHAR NOT NULL DEFAULT 'priority',
status VARCHAR NOT NULL DEFAULT 'draft',
version INTEGER NOT NULL DEFAULT 1,
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
);
CREATE INDEX ix_scene_slot_bundles_tenant ON scene_slot_bundles(tenant_id);
CREATE UNIQUE INDEX ix_scene_slot_bundles_tenant_scene ON scene_slot_bundles(tenant_id, scene_key);
CREATE INDEX ix_scene_slot_bundles_tenant_status ON scene_slot_bundles(tenant_id, status);
```
回滚 SQL:
```sql
DROP TABLE IF EXISTS scene_slot_bundles;
```
"""
# [AC-SCENE-SLOT-01] Idempotent create DDL; IF NOT EXISTS makes re-runs safe.
SCENE_SLOT_BUNDLES_DDL = """
CREATE TABLE IF NOT EXISTS scene_slot_bundles (
id UUID PRIMARY KEY,
tenant_id VARCHAR NOT NULL,
scene_key VARCHAR(100) NOT NULL,
scene_name VARCHAR(100) NOT NULL,
description TEXT,
required_slots JSON NOT NULL DEFAULT '[]',
optional_slots JSON NOT NULL DEFAULT '[]',
slot_priority JSON,
completion_threshold FLOAT NOT NULL DEFAULT 1.0,
ask_back_order VARCHAR NOT NULL DEFAULT 'priority',
status VARCHAR NOT NULL DEFAULT 'draft',
version INTEGER NOT NULL DEFAULT 1,
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
);
"""
# Secondary indexes: tenant lookup, unique (tenant, scene_key), status filter.
SCENE_SLOT_BUNDLES_INDEXES = """
CREATE INDEX IF NOT EXISTS ix_scene_slot_bundles_tenant ON scene_slot_bundles(tenant_id);
CREATE UNIQUE INDEX IF NOT EXISTS ix_scene_slot_bundles_tenant_scene ON scene_slot_bundles(tenant_id, scene_key);
CREATE INDEX IF NOT EXISTS ix_scene_slot_bundles_tenant_status ON scene_slot_bundles(tenant_id, status);
"""
# Rollback: dropping the table also removes its indexes.
SCENE_SLOT_BUNDLES_ROLLBACK = """
DROP TABLE IF EXISTS scene_slot_bundles;
"""
async def upgrade(conn) -> None:
    """Apply the migration: create the scene_slot_bundles table and its indexes.

    Args:
        conn: async DB connection exposing ``execute`` (presumably an
            SQLAlchemy AsyncConnection — TODO confirm at call site).
    """
    await conn.execute(SCENE_SLOT_BUNDLES_DDL)
    await conn.execute(SCENE_SLOT_BUNDLES_INDEXES)
async def downgrade(conn) -> None:
    """Roll back the migration: drop the scene_slot_bundles table."""
    await conn.execute(SCENE_SLOT_BUNDLES_ROLLBACK)

View File

@ -0,0 +1,49 @@
"""
Database Migration: Add display_name and description to slot_definitions.
添加槽位名称和槽位说明字段
创建时间: 2026-03-08
变更说明:
- 新增 display_name 字段槽位名称给运营/教研看的中文名
- 新增 description 字段槽位说明解释这个槽位采集什么用于哪里
执行方式:
- SQLModel 会自动处理新字段通过 init_db
- 此脚本用于手动迁移现有数据库
SQL DDL:
```sql
ALTER TABLE slot_definitions
ADD COLUMN IF NOT EXISTS display_name VARCHAR(100),
ADD COLUMN IF NOT EXISTS description VARCHAR(500);
```
回滚 SQL:
```sql
ALTER TABLE slot_definitions
DROP COLUMN IF EXISTS display_name,
DROP COLUMN IF EXISTS description;
```
"""
# Idempotent ALTER adding the operator-facing name and description columns.
ALTER_SLOT_DEFINITIONS_DDL = """
ALTER TABLE slot_definitions
ADD COLUMN IF NOT EXISTS display_name VARCHAR(100),
ADD COLUMN IF NOT EXISTS description VARCHAR(500);
"""
# Rollback: drop both columns again.
ALTER_SLOT_DEFINITIONS_ROLLBACK = """
ALTER TABLE slot_definitions
DROP COLUMN IF EXISTS display_name,
DROP COLUMN IF EXISTS description;
"""
async def upgrade(conn) -> None:
    """Apply the migration: add display_name / description to slot_definitions."""
    await conn.execute(ALTER_SLOT_DEFINITIONS_DDL)
async def downgrade(conn) -> None:
    """Roll back the migration: drop display_name / description columns."""
    await conn.execute(ALTER_SLOT_DEFINITIONS_ROLLBACK)

View File

@ -0,0 +1,51 @@
"""
检查所有意图规则包括未启用的
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.models.entities import IntentRule
from app.core.config import get_settings
async def check_all_rules():
    """Print every intent rule for the tenant, enabled or disabled."""
    cfg = get_settings()
    engine = create_async_engine(cfg.database_url)
    session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    sep = "=" * 80
    async with session_factory() as db:
        # Deliberately no is_enabled filter: disabled rules are included.
        stmt = select(IntentRule).where(IntentRule.tenant_id == "szmp@ash@2026")
        rules = (await db.execute(stmt)).scalars().all()
    print(sep)
    print(f"数据库中的所有意图规则 (tenant=szmp@ash@2026):")
    print(f"总计: {len(rules)}")
    print(sep)
    for rule in rules:
        enabled_label = '✅ 启用' if rule.is_enabled else '❌ 禁用'
        print(f"\n规则: {rule.name}")
        print(f" ID: {rule.id}")
        print(f" 响应类型: {rule.response_type}")
        print(f" 关键词: {rule.keywords}")
        print(f" 目标知识库: {rule.target_kb_ids}")
        print(f" 优先级: {rule.priority}")
        print(f" 启用状态: {enabled_label}")
    print("\n" + sep)
if __name__ == "__main__":
asyncio.run(check_all_rules())

View File

@ -0,0 +1,71 @@
"""
检查 Qdrant 中课程知识库的数据结构
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from qdrant_client import AsyncQdrantClient
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient
async def check_course_kb():
    """Inspect the course KB's Qdrant collection and sample its payloads.

    Verifies that the expected collection exists for the hard-coded
    tenant/KB pair and, if present, prints a few points' payload structure.
    """
    settings = get_settings()
    client = QdrantClient()
    qdrant = await client.get_client()
    tenant_id = "szmp@ash@2026"
    course_kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"
    # Naming convention used across these scripts:
    # <prefix><tenant with '@'->'_'>_<kb_id>
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = settings.qdrant_collection_prefix
    expected_collection = f"{prefix}{safe_tenant_id}_{course_kb_id}"
    print(f"\n{'='*80}")
    print(f"检查课程知识库 Collection")
    print(f"{'='*80}")
    print(f"租户 ID: {tenant_id}")
    print(f"课程知识库 ID: {course_kb_id}")
    print(f"预期 Collection 名称: {expected_collection}")
    collections = await qdrant.get_collections()
    collection_names = [c.name for c in collections.collections]
    print(f"\n租户的所有 Collections:")
    for name in collection_names:
        if safe_tenant_id in name:
            print(f" - {name}")
    if expected_collection in collection_names:
        print(f"\n✅ 课程知识库 Collection 存在: {expected_collection}")
        # BUG FIX: scroll() is a coroutine on the async client and must be
        # awaited; it returns a (points, next_offset) tuple.
        points, _ = await qdrant.scroll(
            collection_name=expected_collection,
            limit=3,
            with_vectors=False,
        )
        print(f"\n课程知识库数据 (共 {len(points)} 条):")
        for i, point in enumerate(points, 1):
            # Points are typically Record objects (attribute access);
            # tolerate plain dicts in case a wrapper returns mappings.
            if hasattr(point, 'payload'):
                payload = point.payload or {}
                point_id = point.id
            else:
                payload = point.get('payload', {})
                point_id = point.get('id')
            print(f"\n [{i}] id: {point_id}")
            print(f" payload keys: {list(payload.keys())}")
            if 'metadata' in payload:
                print(f" metadata: {payload['metadata']}")
    else:
        print(f"\n❌ 课程知识库 Collection 不存在!")
        print(f" 可用的 Collections: {collection_names}")
if __name__ == "__main__":
asyncio.run(check_course_kb())

View File

@ -0,0 +1,50 @@
"""
检查课程知识库的 collection 是否存在
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from qdrant_client import AsyncQdrantClient
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient
async def check_course_kb_collection():
    """Check whether the course KB's Qdrant collection exists.

    If present, prints a few sample points' payload keys and values,
    excluding the bulky 'text' and 'vector' fields.
    """
    settings = get_settings()
    client = QdrantClient()
    qdrant = await client.get_client()
    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"
    collection_name = client.get_kb_collection_name(tenant_id, kb_id)
    print(f"课程知识库 collection name: {collection_name}")
    exists = await qdrant.collection_exists(collection_name)
    print(f"Collection exists: {exists}")
    if exists:
        # BUG FIX: scroll() returns a (points, next_offset) tuple; the old
        # code took len() of the tuple itself (always 2) and iterated over
        # [points_list, offset] instead of the actual points.
        points, _ = await qdrant.scroll(
            collection_name=collection_name,
            limit=5,
            with_vectors=False,
        )
        print(f"\n课程知识库中有 {len(points)} 条数据:")
        for i, point in enumerate(points, 1):
            # Records expose payload as an attribute; tolerate dicts too.
            payload = point.payload if hasattr(point, 'payload') else point.get('payload', {})
            payload = payload or {}
            print(f" [{i}] payload keys: {list(payload.keys())}")
            for key, value in payload.items():
                if key != 'text' and key != 'vector':
                    print(f" {key}: {value}")
    else:
        print(f"课程知识库 collection 不存在!")
if __name__ == "__main__":
asyncio.run(check_course_kb_collection())

View File

@ -0,0 +1,98 @@
"""
检查课程知识库的录入情况
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient
from app.models.entities import Document
async def check_course_kb_status():
    """Report the ingestion status of the course knowledge base.

    Cross-checks two stores for the hard-coded tenant/KB pair:
    1. Postgres: Document rows registered for the KB.
    2. Qdrant: the KB's vector collection plus a few sample points.
    Prints a verdict based on which of the two actually holds data.
    """
    settings = get_settings()
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    # Hard-coded target tenant / knowledge base for this one-off check.
    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"
    print(f"\n{'='*80}")
    print(f"检查课程知识库的录入情况")
    print(f"{'='*80}")
    print(f"租户 ID: {tenant_id}")
    print(f"知识库 ID: {kb_id}")
    async with async_session() as session:
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.kb_id == kb_id,
        )
        result = await session.execute(stmt)
        documents = result.scalars().all()
        print(f"\n数据库中的文档记录: {len(documents)}")
        if documents:
            # Preview at most five file names to keep the output short.
            for doc in documents[:5]:
                print(f" - {doc.file_name} (status: {doc.status})")
            if len(documents) > 5:
                print(f" ... 还有 {len(documents) - 5} 个文档")
    client = QdrantClient()
    qdrant = await client.get_client()
    collection_name = client.get_kb_collection_name(tenant_id, kb_id)
    print(f"\nQdrant Collection 名称: {collection_name}")
    exists = await qdrant.collection_exists(collection_name)
    if exists:
        points_result = await qdrant.scroll(
            collection_name=collection_name,
            limit=5,
            with_vectors=False,
        )
        # scroll() may return (points, next_offset); unwrap defensively.
        points = points_result[0] if isinstance(points_result, tuple) else points_result
        print(f"Qdrant Collection 存在,有 {len(points)} 条数据")
        for i, point in enumerate(points, 1):
            # Points may be Record objects (attribute access) or plain dicts.
            if hasattr(point, 'payload'):
                payload = point.payload
                point_id = point.id
            else:
                payload = point.get('payload', {})
                point_id = point.get('id', 'unknown')
            print(f" [{i}] id: {point_id}")
            if 'text' in payload:
                text = payload['text'][:50] + '...' if len(payload['text']) > 50 else payload['text']
                print(f" text: {text}")
    else:
        print(f"Qdrant Collection 不存在!")
    print(f"\n{'='*80}")
    print(f"结论:")
    # Verdict matrix: DB document rows vs Qdrant collection presence.
    if len(documents) > 0 and not exists:
        print(" 数据库有文档记录,但 Qdrant Collection 不存在")
        print(" 需要等待文档向量化任务完成")
    elif len(documents) == 0 and exists:
        print(" 数据库没有文档记录,但 Qdrant Collection 存在")
        print(" 可能是旧数据")
    elif len(documents) > 0 and exists:
        print(f" 数据库有 {len(documents)} 个文档记录")
        print(f" Qdrant Collection 存在")
        print(" ✅ 知识库已录入完成")
    else:
        print(" 数据库没有文档记录")
        print(" Qdrant Collection 不存在")
        print(" ❌ 知识库未录入")
if __name__ == "__main__":
asyncio.run(check_course_kb_status())

View File

@ -0,0 +1,78 @@
"""
检查 Qdrant 中是否有 grade=五年级 的数据
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from qdrant_client.models import FieldCondition, Filter, MatchValue
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient
async def check_grade_data():
    """Check the distribution of metadata.grade values in the course KB.

    Scrolls up to 100 points, tallies grade values, then runs a filtered
    scroll for grade == "五年级" to confirm whether such points exist.
    """
    settings = get_settings()
    client = QdrantClient()
    qdrant = await client.get_client()
    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"
    collection_name = client.get_kb_collection_name(tenant_id, kb_id)
    print(f"\n{'='*80}")
    print(f"检查 Qdrant 中 grade 字段的分布")
    print(f"{'='*80}")
    print(f"Collection: {collection_name}")
    # Fetch a sample of up to 100 points; scroll() returns (points, offset).
    all_points = await qdrant.scroll(
        collection_name=collection_name,
        limit=100,
        with_vectors=False,
    )
    print(f"\n总数据量: {len(all_points[0])}")
    # Tally the 'grade' value from each point's nested metadata dict.
    grade_count = {}
    for point in all_points[0]:
        metadata = point.payload.get('metadata', {})
        grade = metadata.get('grade', '')
        grade_count[grade] = grade_count.get(grade, 0) + 1
    print(f"\ngrade 字段分布:")
    for grade, count in sorted(grade_count.items()):
        print(f" {grade}: {count}")
    # Now filter server-side for grade == "五年级".
    print(f"\n--- 检查 grade=五年级 的数据 ---")
    qdrant_filter = Filter(
        must=[
            FieldCondition(
                key="metadata.grade",
                match=MatchValue(value="五年级"),
            )
        ]
    )
    results = await qdrant.scroll(
        collection_name=collection_name,
        limit=10,
        with_vectors=False,
        scroll_filter=qdrant_filter,
    )
    print(f"grade=五年级 的数据: {len(results[0])}")
    for p in results[0]:
        print(f" text: {p.payload.get('text', '')[:80]}...")
        print(f" metadata: {p.payload.get('metadata', {})}")
if __name__ == "__main__":
asyncio.run(check_grade_data())

View File

@ -0,0 +1,88 @@
"""
查看指定知识库的内容
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from qdrant_client import AsyncQdrantClient
from app.core.config import get_settings
async def check_kb_content():
    """Dump the contents of one knowledge base's Qdrant collection.

    Scrolls through every point of the hard-coded collection and prints
    id, filename, a text preview, and metadata for each.
    """
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)
    tenant_id = "szmp@ash@2026"
    kb_id = "8559ebc9-bfaf-4211-8fe3-ee2b22a5e29c"
    # Naming convention: kb_<tenant with '@'->'_'>_<first 8 chars of kb_id>
    # (plain string — the old f-string had no placeholders).
    collection_name = "kb_szmp_ash_2026_8559ebc9"
    print("=" * 80)
    print(f"查看知识库: {kb_id}")
    print(f"Collection: {collection_name}")
    print("=" * 80)
    try:
        exists = await client.collection_exists(collection_name)
        print(f"\nCollection 存在: {exists}")
        if not exists:
            print("Collection 不存在!")
            return
        info = await client.get_collection(collection_name)
        print(f"\nCollection 信息:")
        print(f" 向量数: {info.points_count}")
        print(f"\n文档内容:")
        print("-" * 80)
        offset = None
        total = 0
        # Paginate through the whole collection, 10 points at a time.
        while True:
            result = await client.scroll(
                collection_name=collection_name,
                limit=10,
                offset=offset,
                with_payload=True,
            )
            points = result[0]
            if not points:
                break
            for point in points:
                total += 1
                payload = point.payload or {}
                text = payload.get('text', 'N/A')[:100]
                metadata = payload.get('metadata', {})
                filename = payload.get('filename', 'N/A')
                print(f"\n [{total}] ID: {point.id}")
                # BUG FIX: previously printed the literal "(unknown)" and
                # left the computed `filename` unused.
                print(f" Filename: {filename}")
                print(f" Text: {text}...")
                print(f" Metadata: {metadata}")
            offset = result[1]
            if offset is None:
                break
        print(f"\n总计 {total} 条记录")
    except Exception as e:
        print(f"\n错误: {e}")
        import traceback
        traceback.print_exc()
    finally:
        await client.close()
if __name__ == "__main__":
asyncio.run(check_kb_content())

View File

@ -0,0 +1,51 @@
"""
检查租户的所有知识库
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.core.config import get_settings
from app.models.entities import KnowledgeBase
async def check_knowledge_bases():
    """Print every knowledge base registered for the tenant."""
    cfg = get_settings()
    engine = create_async_engine(cfg.database_url)
    session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    tenant_id = "szmp@ash@2026"
    sep = '=' * 80
    async with session_factory() as db:
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.tenant_id == tenant_id,
        )
        kbs = (await db.execute(stmt)).scalars().all()
    print(f"\n{sep}")
    print(f"检查租户 {tenant_id} 的所有知识库")
    print(f"{sep}")
    print(f"\n找到 {len(kbs)} 个知识库:")
    for kb in kbs:
        print(f"\n 知识库: {kb.name}")
        print(f" id: {kb.id}")
        print(f" kb_type: {kb.kb_type}")
        print(f" description: {kb.description}")
        print(f" is_enabled: {kb.is_enabled}")
        print(f" doc_count: {kb.doc_count}")
        print(f" created_at: {kb.created_at}")
if __name__ == "__main__":
asyncio.run(check_knowledge_bases())

View File

@ -0,0 +1,65 @@
"""
检查知识库的元数据字段定义
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.core.config import get_settings
from app.models.entities import (
MetadataFieldDefinition,
MetadataFieldStatus,
FieldRole,
)
async def check_metadata_fields():
    """Print all active metadata field definitions for the tenant.

    Also lists the subset whose ``field_roles`` include
    ``FieldRole.RESOURCE_FILTER`` (fields usable for resource filtering).
    """
    settings = get_settings()
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    async with async_session() as session:
        tenant_id = "szmp@ash@2026"
        print(f"\n{'='*80}")
        print(f"检查租户 {tenant_id} 的元数据字段定义")
        print(f"{'='*80}")
        # Only ACTIVE definitions are of interest for this check.
        stmt = select(MetadataFieldDefinition).where(
            MetadataFieldDefinition.tenant_id == tenant_id,
            MetadataFieldDefinition.status == MetadataFieldStatus.ACTIVE,
        )
        result = await session.execute(stmt)
        fields = result.scalars().all()
        print(f"\n找到 {len(fields)} 个活跃字段定义:")
        for f in fields:
            print(f"\n 字段: {f.field_key}")
            print(f" label: {f.label}")
            print(f" type: {f.type}")
            print(f" required: {f.required}")
            print(f" field_roles: {f.field_roles}")
            print(f" options: {f.options}")
            print(f" default_value: {f.default_value}")
        # Fields explicitly flagged as usable for resource filtering.
        filterable_fields = [
            f for f in fields
            if f.field_roles and FieldRole.RESOURCE_FILTER.value in f.field_roles
        ]
        print(f"\n{'='*80}")
        print(f"可过滤字段 (field_roles 包含 resource_filter): {len(filterable_fields)}")
        for f in filterable_fields:
            print(f" - {f.field_key} (label: {f.label}, required: {f.required})")
if __name__ == "__main__":
asyncio.run(check_metadata_fields())

View File

@ -0,0 +1,68 @@
"""
检查 Qdrant 中数据的 metadata 存储结构
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient
async def check_metadata_structure():
    """Inspect how payload and nested metadata are stored on Qdrant points."""
    settings = get_settings()
    client = QdrantClient()
    qdrant = await client.get_client()
    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"
    collection_name = client.get_kb_collection_name(tenant_id, kb_id)
    print(f"\n{'='*80}")
    print(f"检查 Qdrant 数据结构")
    print(f"{'='*80}")
    print(f"Collection: {collection_name}")
    # scroll() returns (points, next_offset); [0] below is the point list.
    points = await qdrant.scroll(
        collection_name=collection_name,
        limit=3,
        with_vectors=False,
    )
    print(f"\n找到 {len(points[0])} 条数据:")
    for i, point in enumerate(points[0], 1):
        print(f"\n--- Point {i} ---")
        # Points may be Record objects (attribute access) or plain dicts.
        if hasattr(point, 'payload'):
            payload = point.payload
            point_id = point.id
        else:
            payload = point.get('payload', {})
            point_id = point.get('id', 'unknown')
        print(f"ID: {point_id}")
        print(f"Payload keys: {list(payload.keys())}")
        # Dump the full payload structure, truncating long text and
        # suppressing raw vector data.
        for key, value in payload.items():
            if key == 'text':
                print(f" {key}: {value[:50]}..." if len(str(value)) > 50 else f" {key}: {value}")
            elif key == 'vector':
                print(f" {key}: [向量数据]")
            else:
                print(f" {key}: {value}")
        # The nested 'metadata' dict holds the user-defined metadata fields.
        if 'metadata' in payload:
            print(f"\n metadata 字段内容:")
            for mk, mv in payload['metadata'].items():
                print(f" {mk}: {mv}")
if __name__ == "__main__":
asyncio.run(check_metadata_structure())

View File

@ -1,79 +1,112 @@
"""
Check Qdrant vector database contents - detailed view.
检查 Qdrant 向量数据库状态和知识库内容
"""
import asyncio
import sys
sys.path.insert(0, ".")
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from qdrant_client import AsyncQdrantClient
from app.core.config import get_settings
from collections import defaultdict
settings = get_settings()
from app.core.qdrant_client import get_qdrant_client
async def check_qdrant():
"""Check Qdrant collections and vectors."""
client = AsyncQdrantClient(url=settings.qdrant_url, check_compatibility=False)
"""检查 Qdrant 状态"""
settings = get_settings()
tenant_id = "szmp@ash@2026"
print(f"\n{'='*60}")
print(f"Database URL: {settings.database_url}")
print(f"Qdrant URL: {settings.qdrant_url}")
print(f"{'='*60}\n")
print(f"Tenant ID: {tenant_id}")
print()
# List all collections
collections = await client.get_collections()
# Check kb_default collection
for c in collections.collections:
if c.name == "kb_default":
print(f"\n--- Collection: {c.name} ---")
try:
qdrant_manager = await get_qdrant_client()
client = await qdrant_manager.get_client()
# 检查集合是否存在
collections = (await client.get_collections()).collections
collection_names = [c.name for c in collections]
print(f"Available collections: {collection_names}")
print()
# 筛选该租户的 collections
tenant_collections = [name for name in collection_names if "szmp_ash_2026" in name]
print(f"Tenant collections: {tenant_collections}")
print()
# 检查每个集合
for collection_name in tenant_collections:
print(f"\n{'='*60}")
print(f"Collection: {collection_name}")
print(f"{'='*60}")
# Get collection info
info = await client.get_collection(c.name)
print(f" Total vectors: {info.points_count}")
# 获取集合信息
collection_info = await client.get_collection(collection_name)
print(f" Points count: {collection_info.points_count}")
print(f" Vectors count: {collection_info.vectors_count}")
print(f" Status: {collection_info.status}")
# Scroll through all points and group by source
all_points = []
offset = None
if collection_info.points_count == 0:
print(" ⚠️ Collection is empty!")
continue
while True:
points, offset = await client.scroll(
collection_name=c.name,
limit=100,
offset=offset,
# 滚动获取一些数据
print(f"\n 前 3 条数据:")
points, next_page = await client.scroll(
collection_name=collection_name,
limit=3,
with_payload=True,
with_vectors=False,
)
for i, point in enumerate(points, 1):
payload = point.payload or {}
text = payload.get("text", "")[:100] + "..." if payload.get("text") else "N/A"
kb_id = payload.get("kb_id", "N/A")
metadata = payload.get("metadata", {})
print(f"\n Point {i}:")
print(f" ID: {point.id}")
print(f" KB ID: {kb_id}")
print(f" Text: {text}")
print(f" Metadata: {metadata}")
# 尝试向量搜索
print(f"\n\n{'='*60}")
print(f"尝试向量搜索 (query='课程'):")
print(f"{'='*60}")
from app.services.embedding.factory import get_embedding_provider
embedding_provider = await get_embedding_provider()
query_vector = await embedding_provider.embed("课程")
print(f"Query vector dimension: {len(query_vector)}")
for collection_name in tenant_collections:
print(f"\n搜索 collection: {collection_name}")
try:
search_results = await client.query_points(
collection_name=collection_name,
query=query_vector,
using="full", # 使用 full 向量
limit=3,
with_payload=True,
with_vectors=False,
)
all_points.extend(points)
if offset is None:
break
# Group by source
by_source = defaultdict(list)
for p in all_points:
source = p.payload.get("source", "unknown") if p.payload else "unknown"
by_source[source].append(p)
print(f"\n Documents by source:")
for source, points in by_source.items():
print(f"\n Source: {source}")
print(f" Chunks: {len(points)}")
# Check first chunk content
first_point = points[0]
text = first_point.payload.get("text", "") if first_point.payload else ""
# Check if it's binary garbage or proper text
is_garbage = any(ord(c) > 0xFFFF or (ord(c) < 32 and c not in '\n\r\t') for c in text[:200])
if is_garbage:
print(f" Status: ❌ BINARY GARBAGE (parsing failed)")
else:
print(f" Status: ✅ PROPER TEXT (parsed correctly)")
print(f" Preview: {text[:150]}...")
await client.close()
print(f" Search results: {len(search_results.points)}")
for i, result in enumerate(search_results.points, 1):
payload = result.payload or {}
text = payload.get("text", "")[:80] + "..." if payload.get("text") else "N/A"
print(f" {i}. [score={result.score:.4f}] {text}")
except Exception as e:
print(f" ❌ Search error: {e}")
except Exception as e:
print(f"❌ Error: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":

View File

@ -0,0 +1,54 @@
"""
检查 Qdrant 中实际存在的 collections
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from qdrant_client import AsyncQdrantClient
from app.core.config import get_settings
def extract_kb_id(coll_name, prefix):
    """Return the kb_id portion of a collection name, or None.

    Collection names follow ``<prefix>_<kb_id>`` (e.g.
    ``kb_szmp_ash_2026_8559ebc9``): everything after the prefix and its
    separating underscore is the kb_id.
    """
    if not coll_name.startswith(prefix):
        return None
    remainder = coll_name[len(prefix):].lstrip('_')
    return remainder or None


async def check_qdrant_collections():
    """List all Qdrant collections, then the tenant's KB collections."""
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)
    try:
        collections = await client.get_collections()
        print(f"\n{'='*80}")
        print(f"Qdrant 中所有 collections:")
        print(f"{'='*80}")
        for coll in collections.collections:
            print(f" - {coll.name}")
        tenant_id = "szmp@ash@2026"
        safe_tenant_id = tenant_id.replace('@', '_')
        prefix = f"kb_{safe_tenant_id}"
        tenant_collections = [coll.name for coll in collections.collections if coll.name.startswith(prefix)]
        print(f"\n租户 {tenant_id} 的 collections:")
        print(f"{'='*80}")
        for coll_name in tenant_collections:
            print(f" - {coll_name}")
            # BUG FIX: the old code used coll_name.split('_')[2], which for
            # "kb_szmp_ash_2026_<kb>" always yielded "ash", never the kb id.
            kb_id = extract_kb_id(coll_name, prefix)
            print(f" kb_id: {kb_id}")
    except Exception as e:
        print(f"错误: {e}")
    finally:
        await client.close()
if __name__ == "__main__":
asyncio.run(check_qdrant_collections())

View File

@ -0,0 +1,61 @@
"""
检查 Qdrant 中实际存储的数据结构
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient
async def check_qdrant_data():
    """List the tenant's Qdrant collections and sample their payloads."""
    client = QdrantClient()
    # NOTE(review): the raw handle is not used below — scrolling goes through
    # the wrapper's scroll_points(); presumably kept for parity with siblings.
    qdrant = await client.get_client()
    tenant_id = "szmp@ash@2026"
    print(f"\n{'='*80}")
    print(f"检查租户 {tenant_id} 的 Qdrant 数据")
    print(f"{'='*80}")
    collections = await client.list_collections(tenant_id)
    print(f"\n找到 {len(collections)} 个集合:")
    for coll in collections:
        print(f" - {coll}")
    # Inspect at most the first three collections.
    for collection_name in collections[:3]:
        print(f"\n{'='*80}")
        print(f"检查集合: {collection_name}")
        print(f"{'='*80}")
        try:
            points = await client.scroll_points(
                collection_name=collection_name,
                limit=5,
            )
            print(f"\n找到 {len(points)} 条数据:")
            for i, point in enumerate(points, 1):
                payload = point.get('payload', {})
                print(f"\n [{i}] id: {point.get('id')}")
                print(f" metadata 字段:")
                # Print all payload keys except the bulky text/vector fields.
                for key, value in payload.items():
                    if key != 'text' and key != 'vector':
                        print(f" {key}: {value}")
                text = payload.get('text', '')
                if text:
                    print(f" text 预览: {text[:100]}...")
        except Exception as e:
            # Best-effort: one bad collection should not stop the survey.
            print(f" 错误: {e}")
if __name__ == "__main__":
asyncio.run(check_qdrant_data())

View File

@ -0,0 +1,120 @@
"""
清理 szmp@ash@2026 租户下不需要的 Qdrant collections
保留8559ebc9-bfaf-4211-8fe3-ee2b22a5e29c, 30c19c84-8f69-4768-9d23-7f4a5bc3627a
删除其他所有 collections
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from qdrant_client import AsyncQdrantClient
from app.core.config import get_settings
async def cleanup_collections():
    """Delete the tenant's Qdrant collections except the keep-listed KBs."""
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)
    tenant_id = "szmp@ash@2026"
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = f"kb_{safe_tenant_id}"
    # kb_id prefixes (first 8 chars) whose collections must be preserved.
    keep_kb_ids = [
        "8559ebc9",
        "30c19c84",
    ]
    print(f"🔍 扫描租户 {tenant_id} 的 collections...")
    print(f" 前缀: {prefix}")
    print(f" 保留: {keep_kb_ids}")
    print("-" * 80)
    try:
        collections = await client.get_collections()
        # All collections belonging to this tenant, selected by name prefix.
        tenant_collections = [
            c.name for c in collections.collections
            if c.name.startswith(prefix)
        ]
        print(f"\n📊 找到 {len(tenant_collections)} 个 collections:")
        for name in sorted(tenant_collections):
            # Dry-run style preview: mark each collection as keep or delete.
            should_keep = any(kb_id in name for kb_id in keep_kb_ids)
            status = "✅ 保留" if should_keep else "❌ 删除"
            print(f" {status} {name}")
        print("\n" + "=" * 80)
        print("开始删除...")
        print("=" * 80)
        deleted = []
        skipped = []
        for collection_name in tenant_collections:
            # Re-check the keep list before actually deleting.
            should_keep = any(kb_id in collection_name for kb_id in keep_kb_ids)
            if should_keep:
                print(f"\n⏭️ 跳过 {collection_name} (保留)")
                skipped.append(collection_name)
                continue
            print(f"\n🗑️ 删除 {collection_name}...")
            try:
                await client.delete_collection(collection_name)
                print(f" ✅ 已删除")
                deleted.append(collection_name)
            except Exception as e:
                # Keep going so one failure does not abort the whole sweep.
                print(f" ❌ 删除失败: {e}")
        print("\n" + "=" * 80)
        print("清理完成!")
        print("=" * 80)
        print(f"\n📈 统计:")
        print(f" 保留: {len(skipped)}")
        for name in skipped:
            print(f" - {name}")
        print(f"\n 删除: {len(deleted)}")
        for name in deleted:
            print(f" - {name}")
    except Exception as e:
        print(f"\n❌ 错误: {e}")
        import traceback
        traceback.print_exc()
    finally:
        await client.close()
if __name__ == "__main__":
    # Safety confirmation before the destructive deletes run.
    print("=" * 80)
    print("⚠️ 警告: 此操作将永久删除以下 collections:")
    print(" - kb_szmp_ash_2026")
    print(" - kb_szmp_ash_2026_fa4c1d61")
    print(" - kb_szmp_ash_2026_3ddf0ce7")
    print("\n 保留:")
    print(" - kb_szmp_ash_2026_8559ebc9")
    print(" - kb_szmp_ash_2026_30c19c84")
    print("=" * 80)
    print("\n确认删除? (yes/no): ", end="")
    # Auto-confirm in non-interactive environments (AUTO_CONFIRM=true).
    import os
    if os.environ.get('AUTO_CONFIRM') == 'true':
        response = 'yes'
        print('yes (auto)')
    else:
        response = input().strip().lower()
    if response in ('yes', 'y'):
        asyncio.run(cleanup_collections())
    else:
        print("\n❌ 已取消")

View File

@ -0,0 +1,42 @@
"""
删除课程知识库的 Qdrant Collection
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient
async def delete_course_kb_collection():
    """Drop the course KB's Qdrant collection when it exists."""
    settings = get_settings()  # loaded for parity with sibling scripts
    manager = QdrantClient()
    qdrant = await manager.get_client()
    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"
    collection_name = manager.get_kb_collection_name(tenant_id, kb_id)
    banner = '=' * 80
    print(f"\n{banner}")
    print(f"删除课程知识库的 Qdrant Collection")
    print(f"{banner}")
    print(f"租户 ID: {tenant_id}")
    print(f"知识库 ID: {kb_id}")
    print(f"Collection 名称: {collection_name}")
    # Guard clause: nothing to delete when the collection is absent.
    if not await qdrant.collection_exists(collection_name):
        print(f"\n⚠️ Collection {collection_name} 不存在,无需删除")
        return
    await qdrant.delete_collection(collection_name)
    print(f"\n✅ Collection {collection_name} 已删除!")
if __name__ == "__main__":
asyncio.run(delete_course_kb_collection())

View File

@ -0,0 +1,75 @@
"""
删除课程知识库的文档记录
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.core.config import get_settings
from app.models.entities import Document, IndexJob
async def delete_course_kb_documents():
    """Delete the course KB's Document rows and their IndexJob rows.

    Removes index-job records first (they reference the documents), then
    the documents themselves, and commits once at the end.
    """
    settings = get_settings()
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    # Hard-coded target tenant / knowledge base for this one-off cleanup.
    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"
    print(f"\n{'='*80}")
    print(f"删除课程知识库的文档记录")
    print(f"{'='*80}")
    print(f"租户 ID: {tenant_id}")
    print(f"知识库 ID: {kb_id}")
    async with async_session() as session:
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.kb_id == kb_id,
        )
        result = await session.execute(stmt)
        documents = result.scalars().all()
        print(f"\n找到 {len(documents)} 个文档记录")
        if not documents:
            print("没有需要删除的文档记录")
            return
        # Preview at most five of the documents about to be deleted.
        for doc in documents[:5]:
            print(f" - {doc.file_name} (id: {doc.id})")
        if len(documents) > 5:
            print(f" ... 还有 {len(documents) - 5} 个文档")
        # Delete dependent IndexJob rows first to avoid dangling references.
        doc_ids = [doc.id for doc in documents]
        index_job_stmt = delete(IndexJob).where(
            IndexJob.tenant_id == tenant_id,
            IndexJob.doc_id.in_(doc_ids),
        )
        index_job_result = await session.execute(index_job_stmt)
        print(f"\n删除了 {index_job_result.rowcount} 个索引任务记录")
        doc_stmt = delete(Document).where(
            Document.tenant_id == tenant_id,
            Document.kb_id == kb_id,
        )
        doc_result = await session.execute(doc_stmt)
        print(f"删除了 {doc_result.rowcount} 个文档记录")
        # Single commit covers both deletes (all-or-nothing).
        await session.commit()
        print(f"\n✅ 删除完成!")
if __name__ == "__main__":
asyncio.run(delete_course_kb_documents())

View File

@ -0,0 +1,42 @@
"""
获取数据库中的 API key
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.models.entities import ApiKey
from app.core.config import get_settings
async def get_api_keys():
    """Print every active API key stored in the database."""
    cfg = get_settings()
    engine = create_async_engine(cfg.database_url)
    session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    sep = "=" * 80
    async with session_factory() as db:
        # Note: SQLAlchemy requires `== True` here (not `is True`).
        stmt = select(ApiKey).where(ApiKey.is_active == True)
        keys = (await db.execute(stmt)).scalars().all()
    print(sep)
    print("数据库中的 API Keys:")
    print(sep)
    for key in keys:
        print(f"\nKey: {key.key}")
        print(f" Name: {key.name}")
        print(f" Tenant: {key.tenant_id}")
        print(f" Active: {key.is_active}")
    print("\n" + sep)
if __name__ == "__main__":
asyncio.run(get_api_keys())

View File

@ -0,0 +1,54 @@
"""
获取数据库中的意图规则
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.models.entities import IntentRule
from app.core.config import get_settings
async def get_intent_rules():
    """Print every enabled intent rule for the hard-coded tenant."""
    cfg = get_settings()
    engine = create_async_engine(cfg.database_url)
    session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    sep = "=" * 80
    async with session_factory() as db:
        # Note: SQLAlchemy requires `== True` here (not `is True`).
        stmt = select(IntentRule).where(
            IntentRule.tenant_id == "szmp@ash@2026",
            IntentRule.is_enabled == True
        )
        rules = (await db.execute(stmt)).scalars().all()
    print(sep)
    print("数据库中的意图规则 (tenant=szmp@ash@2026):")
    print(sep)
    if not rules:
        print("\n没有找到任何启用的意图规则!")
    else:
        for rule in rules:
            print(f"\n规则: {rule.name}")
            print(f" ID: {rule.id}")
            print(f" 响应类型: {rule.response_type}")
            print(f" 关键词: {rule.keywords}")
            print(f" 正则模式: {rule.patterns}")
            print(f" 目标知识库: {rule.target_kb_ids}")
            print(f" 优先级: {rule.priority}")
            print(f" 启用: {rule.is_enabled}")
    print("\n" + sep)
if __name__ == "__main__":
asyncio.run(get_intent_rules())

View File

@ -0,0 +1,9 @@
-- Migration: Add usage_description column to metadata_field_definitions table
-- [AC-IDSMETA-XX] Add usage description field for metadata field definitions
-- Idempotent: IF NOT EXISTS makes re-running this migration a no-op.
-- Add usage_description column
ALTER TABLE metadata_field_definitions
ADD COLUMN IF NOT EXISTS usage_description TEXT;
-- Attach a human-readable description to the new column (PostgreSQL-specific).
COMMENT ON COLUMN metadata_field_definitions.usage_description IS '用途说明,描述该元数据字段的业务用途';

View File

@ -0,0 +1,33 @@
-- Migration: Add extract_strategies field to slot_definitions
-- Date: 2026-03-06
-- Issue: [AC-MRS-07-UPGRADE] extraction-strategy system upgrade - support strategy chains
-- 1. Add the extract_strategies column to slot_definitions (JSONB array format).
--    Idempotent: IF NOT EXISTS makes re-running this migration a no-op.
ALTER TABLE slot_definitions
ADD COLUMN IF NOT EXISTS extract_strategies JSONB DEFAULT NULL;
-- 2. Add a column comment (PostgreSQL-specific).
COMMENT ON COLUMN slot_definitions.extract_strategies IS
'[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功';
-- 3. Data migration: wrap the old scalar extract_strategy into a one-element
--    extract_strategies array. The WHERE clause keeps this idempotent.
--    NOTE: the old column is intentionally kept for backward-compatible reads.
UPDATE slot_definitions
SET extract_strategies = CASE
    WHEN extract_strategy IS NOT NULL THEN jsonb_build_array(extract_strategy)
    ELSE NULL
END
WHERE extract_strategies IS NULL;
-- 4. Optional: GIN index to support querying by strategy (enable if needed).
-- CREATE INDEX IF NOT EXISTS idx_slot_definitions_extract_strategies
-- ON slot_definitions USING GIN (extract_strategies);
-- 5. Verification query (run manually after migrating).
-- SELECT
--     id,
--     slot_key,
--     extract_strategy as old_strategy,
--     extract_strategies as new_strategies
-- FROM slot_definitions
-- WHERE extract_strategy IS NOT NULL;

View File

@ -0,0 +1,5 @@
-- Add metadata field to documents table
-- Migration: 010_add_metadata_to_documents
-- Description: Add metadata JSON field to store document-level metadata
-- Idempotent: IF NOT EXISTS makes re-running this migration a no-op.
ALTER TABLE documents ADD COLUMN IF NOT EXISTS metadata JSONB;

View File

@ -0,0 +1,55 @@
"""
Migration script to add intent_vector and semantic_examples fields to intent_rules table.
[v0.8.0] Hybrid routing - Intent vector fields
Run this script with: python scripts/migrations/011_add_intent_vector_fields.py
"""
import asyncio
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from sqlalchemy import text
from app.core.database import async_session_maker
async def run_migration():
    """Add the hybrid-routing columns to intent_rules and chat_messages.

    Each ALTER TABLE is tolerant of re-runs: statements whose column already
    exists are reported and skipped; any other failure aborts the migration.
    """
    ddl_statements = [
        """
        ALTER TABLE intent_rules
        ADD COLUMN IF NOT EXISTS intent_vector JSONB;
        """,
        """
        ALTER TABLE intent_rules
        ADD COLUMN IF NOT EXISTS semantic_examples JSONB;
        """,
        """
        ALTER TABLE chat_messages
        ADD COLUMN IF NOT EXISTS route_trace JSONB;
        """,
    ]
    async with async_session_maker() as session:
        for idx, ddl in enumerate(ddl_statements, 1):
            try:
                await session.execute(text(ddl))
            except Exception as exc:
                detail = str(exc).lower()
                # A previous run may already have created the column.
                if "already exists" in detail or "duplicate" in detail:
                    print(f"[{idx}] Skipped (already exists): {str(exc)[:50]}...")
                else:
                    raise
            else:
                print(f"[{idx}] Executed successfully")
        await session.commit()
    print("\nMigration completed successfully!")
if __name__ == "__main__":
    asyncio.run(run_migration())
else:
    # BUGFIX: the original branch also called sys.exit(1), which terminated
    # any process that merely *imported* this module (tests, tooling, REPL).
    # Importing is now harmless; the migration still only runs when the
    # script is executed directly.
    print("Please run this script directly, not imported.")

View File

@ -0,0 +1,20 @@
-- Add intent_vector and semantic_examples fields to intent_rules table
-- [v0.8.0] Hybrid routing support for semantic matching
-- Idempotent: IF NOT EXISTS makes re-running this migration a no-op.
-- Add intent_vector column (JSONB for storing pre-computed embedding vectors)
ALTER TABLE intent_rules
ADD COLUMN IF NOT EXISTS intent_vector JSONB;
-- Add semantic_examples column (JSONB for storing example sentences for dynamic vector computation)
ALTER TABLE intent_rules
ADD COLUMN IF NOT EXISTS semantic_examples JSONB;
-- Add comments for documentation (PostgreSQL-specific)
COMMENT ON COLUMN intent_rules.intent_vector IS '[v0.8.0] Pre-computed intent vector for semantic matching';
COMMENT ON COLUMN intent_rules.semantic_examples IS '[v0.8.0] Semantic example sentences for dynamic vector computation';
-- Add route_trace column to chat_messages table if not exists
ALTER TABLE chat_messages
ADD COLUMN IF NOT EXISTS route_trace JSONB;
COMMENT ON COLUMN chat_messages.route_trace IS '[v0.8.0] Intent routing trace log for hybrid routing observability';

View File

@ -0,0 +1,204 @@
"""
详细性能分析 - 确认每个环节的耗时
"""
import asyncio
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
from app.core.config import get_settings
from app.core.qdrant_client import get_qdrant_client
from app.services.embedding import get_embedding_provider
async def profile_detailed():
    """Profile each stage of the KB-search pipeline and print a timing report.

    Stages measured: embedding generation, collection-list lookup (with Redis
    cache), Qdrant per-collection search with a metadata filter, and the full
    KbSearchDynamicTool flow. Uses a fixed query/tenant/filter; all timings
    are wall-clock milliseconds printed to stdout.
    """
    settings = get_settings()
    print("=" * 80)
    print("详细性能分析")
    print("=" * 80)
    # Fixed inputs for a reproducible profiling run.
    query = "三年级语文学习"
    tenant_id = "szmp@ash@2026"
    metadata_filter = {"grade": "三年级", "subject": "语文"}
    # 1. Embedding generation (provider is presumably pre-initialized at
    #    startup, so acquisition should be fast -- TODO confirm).
    print("\n📊 1. Embedding 生成")
    print("-" * 80)
    start = time.time()
    embedding_service = await get_embedding_provider()
    init_time = (time.time() - start) * 1000
    start = time.time()
    embedding_result = await embedding_service.embed_query(query)
    embed_time = (time.time() - start) * 1000
    # Extract the raw vector; providers differ in result shape
    # (multi-vector exposes `embedding_full`, others `embedding` or the bare list).
    if hasattr(embedding_result, 'embedding_full'):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, 'embedding'):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result
    print(f" 获取服务实例: {init_time:.2f} ms")
    print(f" Embedding 生成: {embed_time:.2f} ms")
    print(f" 向量维度: {len(query_vector)}")
    # 2. Collection list lookup, trying the Redis cache first.
    print("\n📊 2. 获取 collections 列表")
    print("-" * 80)
    client = await get_qdrant_client()
    qdrant_client = await client.get_client()
    start = time.time()
    from app.services.metadata_cache_service import get_metadata_cache_service
    cache_service = await get_metadata_cache_service()
    cache_key = f"collections:{tenant_id}"
    # NOTE(review): reaches into private members (_get_redis, _enabled) of the
    # cache service -- acceptable in a throwaway profiling script only.
    redis_client = await cache_service._get_redis()
    cache_hit = False
    if redis_client and cache_service._enabled:
        cached = await redis_client.get(cache_key)
        if cached:
            import json
            tenant_collections = json.loads(cached)
            cache_hit = True
            cache_time = (time.time() - start) * 1000
            print(f" ✅ 缓存命中: {cache_time:.2f} ms")
            print(f" Collections: {tenant_collections}")
    if not cache_hit:
        import json
        # Cache miss: list collections from Qdrant and filter by tenant prefix.
        start = time.time()
        collections = await qdrant_client.get_collections()
        # '@' is not valid in a collection name, hence the substitution.
        safe_tenant_id = tenant_id.replace('@', '_')
        prefix = f"kb_{safe_tenant_id}"
        tenant_collections = [
            c.name for c in collections.collections
            if c.name.startswith(prefix)
        ]
        tenant_collections.sort()
        db_time = (time.time() - start) * 1000
        print(f" ❌ 缓存未命中,从 Qdrant 查询: {db_time:.2f} ms")
        print(f" Collections: {tenant_collections}")
        # Populate the cache (5-minute TTL) for subsequent runs.
        if redis_client and cache_service._enabled:
            await redis_client.setex(cache_key, 300, json.dumps(tenant_collections))
            print(f" 已缓存到 Redis")
    # 3. Qdrant search, one collection at a time.
    print("\n📊 3. Qdrant 搜索")
    print("-" * 80)
    from qdrant_client.models import FieldCondition, Filter, MatchValue
    # Build one equality condition per metadata key.
    start = time.time()
    must_conditions = []
    for key, value in metadata_filter.items():
        field_path = f"metadata.{key}"
        condition = FieldCondition(
            key=field_path,
            match=MatchValue(value=value),
        )
        must_conditions.append(condition)
    qdrant_filter = Filter(must=must_conditions) if must_conditions else None
    filter_time = (time.time() - start) * 1000
    print(f" 构建 filter: {filter_time:.2f} ms")
    # Search each tenant collection serially, accumulating the total time.
    total_search_time = 0
    for collection_name in tenant_collections:
        print(f"\n Collection: {collection_name}")
        # Existence check (the cached list may be stale).
        start = time.time()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = (time.time() - start) * 1000
        print(f" 检查存在: {check_time:.2f} ms")
        if not exists:
            print(f" ❌ 不存在")
            continue
        # Vector search using the "full" named vector.
        start = time.time()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                limit=5,
                score_threshold=0.5,
                query_filter=qdrant_filter,
            )
            search_time = (time.time() - start) * 1000
            total_search_time += search_time
            print(f" 搜索时间: {search_time:.2f} ms")
            print(f" 结果数: {len(results.points)}")
        except Exception as e:
            # Best-effort profiling: report and keep going.
            print(f" ❌ 搜索失败: {e}")
    print(f"\n 总搜索时间: {total_search_time:.2f} ms")
    # 4. End-to-end timing through the production KbSearchDynamicTool.
    print("\n📊 4. 完整 KB Search 流程")
    print("-" * 80)
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    async with async_session() as session:
        config = KbSearchDynamicConfig(
            enabled=True,
            top_k=5,
            timeout_ms=30000,
            min_score_threshold=0.5,
        )
        tool = KbSearchDynamicTool(session=session, config=config)
        start = time.time()
        result = await tool.execute(
            query=query,
            tenant_id=tenant_id,
            scene="学习方案",
            top_k=5,
            context=metadata_filter,
        )
        total_time = (time.time() - start) * 1000
        print(f" 总耗时: {total_time:.2f} ms")
        print(f" 工具内部耗时: {result.duration_ms} ms")
        print(f" 结果数: {len(result.hits)}")
    # 5. Summary: per-stage breakdown plus unattributed overhead.
    print("\n" + "=" * 80)
    print("📈 耗时总结")
    print("=" * 80)
    print(f"\n各环节耗时:")
    print(f" Embedding 获取服务: {init_time:.2f} ms")
    print(f" Embedding 生成: {embed_time:.2f} ms")
    print(f" Collections 获取: {cache_time if cache_hit else db_time:.2f} ms")
    print(f" Filter 构建: {filter_time:.2f} ms")
    print(f" Qdrant 搜索: {total_search_time:.2f} ms")
    print(f" 完整流程: {total_time:.2f} ms")
    # Whatever the full flow spent beyond the individually measured stages.
    other_time = total_time - embed_time - (cache_time if cache_hit else db_time) - filter_time - total_search_time
    print(f" 其他开销: {other_time:.2f} ms")
    print("\n" + "=" * 80)
if __name__ == "__main__":
asyncio.run(profile_detailed())

View File

@ -0,0 +1,246 @@
"""
详细分析完整参数查询的耗时
对比带 metadata_filter 和不带的区别
"""
import asyncio
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
from app.core.config import get_settings
from app.core.qdrant_client import get_qdrant_client
from qdrant_client.models import FieldCondition, Filter, MatchValue
async def profile_step_by_step():
    """Compare KB-search timing with and without a metadata filter.

    Runs the same fixed query twice at two levels: raw Qdrant searches
    (filtered vs. unfiltered) and the full KbSearchDynamicTool flow (with
    vs. without `context`), then prints the filter's extra cost.
    """
    settings = get_settings()
    print("=" * 80)
    print("完整参数查询耗时分析")
    print("=" * 80)
    # Fixed inputs for a reproducible comparison.
    query = "三年级语文学习"
    tenant_id = "szmp@ash@2026"
    metadata_filter = {"grade": "三年级", "subject": "语文"}
    # 1. Embedding generation.
    print("\n📊 1. Embedding 生成")
    print("-" * 80)
    from app.services.embedding import get_embedding_provider
    start = time.time()
    embedding_service = await get_embedding_provider()
    init_time = time.time() - start
    start = time.time()
    embedding_result = await embedding_service.embed_query(query)
    embed_time = (time.time() - start) * 1000
    # Extract the raw vector; providers differ in result shape.
    if hasattr(embedding_result, 'embedding_full'):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, 'embedding'):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result
    print(f" 初始化时间: {init_time * 1000:.2f} ms")
    print(f" Embedding 生成: {embed_time:.2f} ms")
    print(f" 向量维度: {len(query_vector)}")
    # 2. List this tenant's collections straight from Qdrant (no cache here).
    print("\n📊 2. 获取 collections 列表")
    print("-" * 80)
    client = await get_qdrant_client()
    qdrant_client = await client.get_client()
    start = time.time()
    collections = await qdrant_client.get_collections()
    # '@' is not valid in a collection name, hence the substitution.
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = f"kb_{safe_tenant_id}"
    tenant_collections = [
        c.name for c in collections.collections
        if c.name.startswith(prefix)
    ]
    list_time = (time.time() - start) * 1000
    print(f" 获取 collections: {list_time:.2f} ms")
    print(f" Collections: {tenant_collections}")
    # 3. Build the Qdrant metadata filter (one equality condition per key).
    print("\n📊 3. 构建 metadata filter")
    print("-" * 80)
    start = time.time()
    must_conditions = []
    for key, value in metadata_filter.items():
        field_path = f"metadata.{key}"
        condition = FieldCondition(
            key=field_path,
            match=MatchValue(value=value),
        )
        must_conditions.append(condition)
    qdrant_filter = Filter(must=must_conditions) if must_conditions else None
    filter_time = (time.time() - start) * 1000
    print(f" 构建 filter: {filter_time:.2f} ms")
    print(f" Filter: {qdrant_filter}")
    # 4. Filtered search, one collection at a time (serial).
    print("\n📊 4. Qdrant 搜索(带 metadata filter)")
    print("-" * 80)
    total_search_time = 0
    total_results = 0
    for collection_name in tenant_collections:
        print(f"\n Collection: {collection_name}")
        # Existence check (list may be stale).
        start = time.time()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = (time.time() - start) * 1000
        print(f" 检查存在: {check_time:.2f} ms")
        if not exists:
            print(f" ❌ 不存在")
            continue
        # Vector search on the "full" named vector, with the filter applied.
        start = time.time()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                limit=5,
                score_threshold=0.5,
                query_filter=qdrant_filter,
            )
            search_time = (time.time() - start) * 1000
            total_search_time += search_time
            total_results += len(results.points)
            print(f" 搜索时间: {search_time:.2f} ms")
            print(f" 结果数: {len(results.points)}")
        except Exception as e:
            # Best-effort profiling: report and keep going.
            print(f" ❌ 搜索失败: {e}")
    print(f"\n 总搜索时间: {total_search_time:.2f} ms")
    print(f" 总结果数: {total_results}")
    # 5. Baseline: identical searches without the filter.
    print("\n📊 5. Qdrant 搜索(不带 metadata filter,对比)")
    print("-" * 80)
    total_search_time_no_filter = 0
    total_results_no_filter = 0
    for collection_name in tenant_collections:
        start = time.time()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                limit=5,
                score_threshold=0.5,
                # no query_filter: this is the unfiltered baseline
            )
            search_time = (time.time() - start) * 1000
            total_search_time_no_filter += search_time
            total_results_no_filter += len(results.points)
        except Exception as e:
            print(f" {collection_name}: 失败 {e}")
    print(f" 总搜索时间(无 filter): {total_search_time_no_filter:.2f} ms")
    print(f" 总结果数(无 filter): {total_results_no_filter}")
    # 6. Full production flow through KbSearchDynamicTool, passing the
    #    metadata filter via `context`.
    print("\n📊 6. 完整 KB Search 流程(带 context)")
    print("-" * 80)
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    async with async_session() as session:
        config = KbSearchDynamicConfig(
            enabled=True,
            top_k=5,
            timeout_ms=30000,
            min_score_threshold=0.5,
        )
        tool = KbSearchDynamicTool(session=session, config=config)
        start = time.time()
        result = await tool.execute(
            query=query,
            tenant_id=tenant_id,
            scene="学习方案",
            top_k=5,
            context=metadata_filter,
        )
        total_time = (time.time() - start) * 1000
        print(f" 总耗时: {total_time:.2f} ms")
        print(f" 工具内部耗时: {result.duration_ms} ms")
        print(f" 结果数: {len(result.hits)}")
        print(f" 应用的 filter: {result.applied_filter}")
    # 7. Baseline: full flow without `context` (reuses `config` from step 6).
    print("\n📊 7. 完整 KB Search 流程(不带 context)")
    print("-" * 80)
    async with async_session() as session:
        tool = KbSearchDynamicTool(session=session, config=config)
        start = time.time()
        result = await tool.execute(
            query=query,
            tenant_id=tenant_id,
            scene="学习方案",
            top_k=5,
            # no context: unfiltered baseline
        )
        total_time_no_context = (time.time() - start) * 1000
        print(f" 总耗时: {total_time_no_context:.2f} ms")
        print(f" 工具内部耗时: {result.duration_ms} ms")
        print(f" 结果数: {len(result.hits)}")
    # 8. Summary: filtered vs. unfiltered, plus the filter's marginal cost.
    print("\n" + "=" * 80)
    print("📈 耗时分析总结")
    print("=" * 80)
    print(f"\n带 metadata filter:")
    print(f" Embedding: {embed_time:.2f} ms")
    print(f" 获取 collections: {list_time:.2f} ms")
    print(f" Qdrant 搜索: {total_search_time:.2f} ms")
    print(f" 完整流程: {total_time:.2f} ms")
    print(f"\n不带 metadata filter:")
    print(f" Qdrant 搜索: {total_search_time_no_filter:.2f} ms")
    print(f" 完整流程: {total_time_no_context:.2f} ms")
    print(f"\nMetadata filter 额外开销:")
    print(f" Qdrant 搜索: {total_search_time - total_search_time_no_filter:.2f} ms")
    print(f" 完整流程: {total_time - total_time_no_context:.2f} ms")
    if total_search_time > total_search_time_no_filter:
        print(f"\n⚠️ 带 filter 的搜索更慢,可能原因:")
        print(f" - Filter 增加了索引查找的复杂度")
        print(f" - 需要匹配 metadata 字段")
        print(f" - 建议: 检查 Qdrant 的 payload 索引配置")
if __name__ == "__main__":
asyncio.run(profile_step_by_step())

View File

@ -0,0 +1,259 @@
"""
知识库检索性能分析脚本
详细分析每个环节的耗时
"""
import asyncio
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
from app.services.retrieval.vector_retriever import VectorRetriever
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient
async def profile_embedding_generation(query: str):
    """Measure provider acquisition and query-embedding latency.

    Returns a dict with `init_time_ms`, `embed_time_ms`, and `dimension`
    (length of the produced vector).
    """
    from app.services.embedding import get_embedding_provider

    t0 = time.time()
    provider = await get_embedding_provider()
    acquire_seconds = time.time() - t0

    t0 = time.time()
    result = await provider.embed_query(query)
    embed_seconds = time.time() - t0

    # Providers differ in result shape: multi-vector results expose
    # `embedding_full`, single-vector ones `embedding`, and some return
    # the raw vector directly.
    if hasattr(result, 'embedding_full'):
        vector = result.embedding_full
    elif hasattr(result, 'embedding'):
        vector = result.embedding
    else:
        vector = result

    return {
        "init_time_ms": acquire_seconds * 1000,
        "embed_time_ms": embed_seconds * 1000,
        "dimension": len(vector),
    }
async def profile_qdrant_search(tenant_id: str, query_vector: list, metadata_filter: dict = None):
    """Time Qdrant searches across all of a tenant's collections.

    Args:
        tenant_id: Tenant whose `kb_<tenant>` collections are searched.
        query_vector: Pre-computed query embedding.
        metadata_filter: Optional key/value equality filter applied to the
            `metadata.*` payload fields.

    Returns:
        Dict with collection-list timing, collection count, and per-collection
        timing entries (existence check, search time, result count).
    """
    from app.core.qdrant_client import get_qdrant_client
    client = await get_qdrant_client()
    # List the tenant's collections by name prefix.
    start = time.time()
    qdrant_client = await client.get_client()
    collections = await qdrant_client.get_collections()
    # '@' is not valid in a collection name, hence the substitution.
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = f"kb_{safe_tenant_id}"
    tenant_collections = [
        c.name for c in collections.collections
        if c.name.startswith(prefix)
    ]
    list_collections_time = time.time() - start
    # Search each collection serially, recording per-collection timings.
    collection_times = []
    for collection_name in tenant_collections:
        start = time.time()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = time.time() - start
        if not exists:
            collection_times.append({
                "collection": collection_name,
                "exists": False,
                "time_ms": check_time * 1000,
            })
            continue
        start = time.time()
        # Build the filter (rebuilt per collection; cheap relative to search).
        qdrant_filter = None
        if metadata_filter:
            from qdrant_client.models import FieldCondition, Filter, MatchValue
            must_conditions = []
            for key, value in metadata_filter.items():
                field_path = f"metadata.{key}"
                condition = FieldCondition(
                    key=field_path,
                    match=MatchValue(value=value),
                )
                must_conditions.append(condition)
            qdrant_filter = Filter(must=must_conditions) if must_conditions else None
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",  # search the "full" named vector
                limit=5,
                score_threshold=0.5,
                query_filter=qdrant_filter,
            )
        except Exception as e:
            if "vector name" in str(e).lower():
                # Fallback for collections created without named vectors:
                # retry the same query without `using`.
                results = await qdrant_client.query_points(
                    collection_name=collection_name,
                    query=query_vector,
                    limit=5,
                    score_threshold=0.5,
                    query_filter=qdrant_filter,
                )
            else:
                raise
        search_time = time.time() - start
        collection_times.append({
            "collection": collection_name,
            "exists": True,
            "check_time_ms": check_time * 1000,
            "search_time_ms": search_time * 1000,
            "results_count": len(results.points),
        })
    return {
        "list_collections_time_ms": list_collections_time * 1000,
        "collections_count": len(tenant_collections),
        "collection_times": collection_times,
    }
async def profile_full_kb_search():
    """Run the whole KB-search profiling suite and print recommendations.

    Combines profile_embedding_generation and profile_qdrant_search with an
    end-to-end KbSearchDynamicTool run, then prints a bottleneck breakdown
    and heuristic optimization advice (thresholds at 1000 ms / 3 collections).
    """
    settings = get_settings()
    print("=" * 80)
    print("知识库检索性能分析")
    print("=" * 80)
    # 1. Embedding generation timing.
    print("\n📊 1. Embedding 生成分析")
    print("-" * 80)
    query = "三年级语文学习"
    embed_result = await profile_embedding_generation(query)
    print(f" 初始化时间: {embed_result['init_time_ms']:.2f} ms")
    print(f" Embedding 生成时间: {embed_result['embed_time_ms']:.2f} ms")
    print(f" 向量维度: {embed_result['dimension']}")
    # 2. Qdrant search timing (needs a fresh query embedding first).
    print("\n📊 2. Qdrant 搜索分析")
    print("-" * 80)
    from app.services.embedding import get_embedding_provider
    embedding_service = await get_embedding_provider()
    embedding_result = await embedding_service.embed_query(query)
    # Extract the raw vector; providers differ in result shape.
    if hasattr(embedding_result, 'embedding_full'):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, 'embedding'):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result
    tenant_id = "szmp@ash@2026"
    metadata_filter = {"grade": "三年级", "subject": "语文"}
    qdrant_result = await profile_qdrant_search(tenant_id, query_vector, metadata_filter)
    print(f" 获取 collections 列表时间: {qdrant_result['list_collections_time_ms']:.2f} ms")
    print(f" Collections 数量: {qdrant_result['collections_count']}")
    print(f"\n 各 Collection 搜索耗时:")
    for ct in qdrant_result['collection_times']:
        if ct['exists']:
            print(f" - {ct['collection']}: {ct['search_time_ms']:.2f} ms (结果: {ct['results_count']} 条)")
        else:
            print(f" - {ct['collection']}: 不存在 ({ct['time_ms']:.2f} ms)")
    # Collections are searched serially, so total time is the plain sum.
    total_search_time = sum(
        ct.get('search_time_ms', 0) for ct in qdrant_result['collection_times']
    )
    print(f"\n 总搜索时间(串行): {total_search_time:.2f} ms")
    # 3. End-to-end timing through the production KbSearchDynamicTool.
    print("\n📊 3. 完整 KB Search 流程分析")
    print("-" * 80)
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    async with async_session() as session:
        config = KbSearchDynamicConfig(
            enabled=True,
            top_k=5,
            timeout_ms=30000,  # 30 s timeout
            min_score_threshold=0.5,
        )
        tool = KbSearchDynamicTool(session=session, config=config)
        # NOTE(review): `stages` is collected but never filled/used below.
        stages = []
        start_total = time.time()
        # Execute the search.
        start = time.time()
        result = await tool.execute(
            query=query,
            tenant_id=tenant_id,
            scene="学习方案",
            top_k=5,
            context=metadata_filter,
        )
        total_time = (time.time() - start_total) * 1000
        print(f" 总耗时: {total_time:.2f} ms")
        print(f" 结果: success={result.success}, hits={len(result.hits)}")
        print(f" 工具内部耗时: {result.duration_ms} ms")
        # Difference between externally measured and tool-reported time.
        overhead = total_time - result.duration_ms
        print(f" 额外开销(初始化等): {overhead:.2f} ms")
    # 4. Bottleneck breakdown as percentages of the measured stages.
    print("\n📊 4. 性能瓶颈分析")
    print("-" * 80)
    embedding_time = embed_result['embed_time_ms']
    qdrant_time = total_search_time
    total_measured = embedding_time + qdrant_time
    print(f" Embedding 生成: {embedding_time:.2f} ms ({embedding_time/total_measured*100:.1f}%)")
    print(f" Qdrant 搜索: {qdrant_time:.2f} ms ({qdrant_time/total_measured*100:.1f}%)")
    print(f" 其他开销: {total_time - total_measured:.2f} ms")
    print("\n" + "=" * 80)
    print("优化建议:")
    print("=" * 80)
    # Heuristic advice triggered by >1 s stage times or >3 collections.
    if embedding_time > 1000:
        print(" ⚠️ Embedding 生成较慢,考虑:")
        print(" - 使用更快的 embedding 模型")
        print(" - 增加 embedding 服务缓存")
    if qdrant_time > 1000:
        print(" ⚠️ Qdrant 搜索较慢,考虑:")
        print(" - 并行查询多个 collections")
        print(" - 优化 Qdrant 索引")
        print(" - 减少 collections 数量")
    if len(qdrant_result['collection_times']) > 3:
        print(f" ⚠️ Collections 数量较多 ({len(qdrant_result['collection_times'])} 个)")
        print(" - 建议合并或归档空/少数据的 collections")
if __name__ == "__main__":
asyncio.run(profile_full_kb_search())

View File

@ -0,0 +1,120 @@
"""
查询 Qdrant Collection 中的所有内容
"""
import asyncio
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import ScrollRequest
from app.core.config import get_settings
async def query_all_points(collection_name: str):
    """Dump every point in a Qdrant collection to stdout.

    Scrolls through the collection in batches of 100 (payloads only, no
    vectors), prints each record's main payload fields and text, and ends
    with a per-kb_id count. Errors are printed, never raised; the client is
    always closed.

    Args:
        collection_name: Name of the Qdrant collection to dump.
    """
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url, check_compatibility=False)
    print(f"🔍 查询 Collection: {collection_name}")
    print("=" * 80)
    try:
        # Collection info gives us the total point count up front.
        info = await client.get_collection(collection_name)
        total_points = info.points_count
        print(f"📊 总向量数: {total_points}\n")
        # Page through all points with the scroll API.
        all_points = []
        offset = None
        batch_size = 100
        while True:
            scroll_result = await client.scroll(
                collection_name=collection_name,
                offset=offset,
                limit=batch_size,
                with_payload=True,
                with_vectors=False
            )
            points, next_offset = scroll_result
            all_points.extend(points)
            # A None offset marks the final page.
            if next_offset is None:
                break
            offset = next_offset
            # Progress indicator every 500 fetched points.
            if len(all_points) % 500 == 0:
                print(f" 已获取 {len(all_points)} / {total_points} 条记录...")
        print(f"✅ 成功获取全部 {len(all_points)} 条记录\n")
        print("=" * 80)
        # Print every record.
        for i, point in enumerate(all_points, 1):
            payload = point.payload or {}
            print(f"\n📄 记录 {i}/{len(all_points)} (ID: {point.id})")
            print("-" * 80)
            # Main payload fields (defaults keep partial payloads printable).
            text = payload.get('text', '')
            kb_id = payload.get('kb_id', 'N/A')
            source = payload.get('source', 'N/A')
            chunk_index = payload.get('chunk_index', 'N/A')
            metadata = payload.get('metadata', {})
            print(f" KB ID: {kb_id}")
            print(f" Source: {source}")
            print(f" Chunk Index: {chunk_index}")
            if metadata:
                print(f" Metadata: {json.dumps(metadata, ensure_ascii=False)}")
            # Chunk text, printed line by line to preserve its layout.
            print(f"\n 文本内容:")
            if text:
                lines = text.split('\n')
                for line in lines:
                    if line.strip():
                        print(f" {line}")
            else:
                print(" (无文本内容)")
        print("\n" + "=" * 80)
        print(f"✅ 查询完成,共 {len(all_points)} 条记录")
        # Summary: point count per kb_id.
        print("\n📈 统计信息:")
        kb_ids = {}
        for point in all_points:
            payload = point.payload or {}
            kb_id = payload.get('kb_id', 'N/A')
            kb_ids[kb_id] = kb_ids.get(kb_id, 0) + 1
        print(f" KB ID 分布:")
        for kb_id, count in sorted(kb_ids.items()):
            print(f" - {kb_id}: {count}")
    except Exception as e:
        # Diagnostic script: report the failure instead of crashing.
        print(f"❌ 查询失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        await client.close()
async def main():
    """Entry point: dump the collection named on the command line.

    Generalized from a hard-coded collection name: the first CLI argument,
    if given, selects the collection; without arguments the original
    default collection is used, so existing invocations behave the same.
    """
    default_collection = "kb_szmp_ash_2026_30c19c84"
    collection_name = sys.argv[1] if len(sys.argv) > 1 else default_collection
    await query_all_points(collection_name)
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,305 @@
"""
恢复处理中断的索引任务
用于服务重启后继续处理pending/processing状态的任务
"""
import asyncio
import logging
import sys
from pathlib import Path
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import select
from app.core.database import async_session_maker
from app.models.entities import IndexJob, Document, IndexJobStatus, DocumentStatus
from app.api.admin.kb import _index_document
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
async def resume_pending_jobs():
    """Re-run every index job left in PENDING or PROCESSING state.

    Intended for use after a service restart: jobs that were queued or in
    flight are re-read from the database and re-processed synchronously via
    process_job(). Jobs whose source file is missing are marked FAILED, as
    are jobs that raise during processing.
    """
    async with async_session_maker() as session:
        # Collect every job that never reached a terminal state.
        result = await session.execute(
            select(IndexJob).where(
                IndexJob.status.in_([IndexJobStatus.PENDING.value, IndexJobStatus.PROCESSING.value])
            )
        )
        pending_jobs = result.scalars().all()
        if not pending_jobs:
            logger.info("没有需要恢复的任务")
            return
        logger.info(f"发现 {len(pending_jobs)} 个未完成的任务")
        for job in pending_jobs:
            # BUGFIX: reset per iteration. Previously `doc` leaked from the
            # prior loop pass, so a failure before this job's document was
            # loaded could mark the *previous* document FAILED (or raise
            # UnboundLocalError on the very first iteration).
            doc = None
            try:
                # Load the document the job refers to.
                doc_result = await session.execute(
                    select(Document).where(Document.id == job.doc_id)
                )
                doc = doc_result.scalar_one_or_none()
                if not doc:
                    logger.error(f"找不到文档: {job.doc_id}")
                    continue
                if not doc.file_path or not Path(doc.file_path).exists():
                    logger.error(f"文档文件不存在: {doc.file_path}")
                    # Without the source file the job can never succeed.
                    job.status = IndexJobStatus.FAILED.value
                    job.error_msg = "文档文件不存在"
                    doc.status = DocumentStatus.FAILED.value
                    doc.error_msg = "文档文件不存在"
                    await session.commit()
                    continue
                logger.info(f"恢复处理: job_id={job.id}, doc_id={doc.id}, file={doc.file_name}")
                # Read the original file content back for re-indexing.
                with open(doc.file_path, 'rb') as f:
                    file_content = f.read()
                # Reset job state so process_job starts from a clean slate.
                job.status = IndexJobStatus.PENDING.value
                job.progress = 0
                job.error_msg = None
                await session.commit()
                # Process inline: in a standalone script there is no
                # BackgroundTasks machinery to hand the job to.
                await process_job(
                    tenant_id=job.tenant_id,
                    kb_id=doc.kb_id,
                    job_id=str(job.id),
                    doc_id=str(doc.id),
                    file_content=file_content,
                    filename=doc.file_name,
                    metadata=doc.doc_metadata or {}
                )
                logger.info(f"任务处理完成: job_id={job.id}")
            except Exception as e:
                logger.error(f"处理任务失败: job_id={job.id}, error={e}")
                # Mark the job (and its document, if it was loaded) FAILED.
                job.status = IndexJobStatus.FAILED.value
                job.error_msg = str(e)
                if doc is not None:
                    doc.status = DocumentStatus.FAILED.value
                    doc.error_msg = str(e)
                await session.commit()
        logger.info("所有任务处理完成")
async def process_job(tenant_id: str, kb_id: str, job_id: str, doc_id: str,
    file_content: bytes, filename: str, metadata: dict):
    """
    Process a single index job end to end.

    Copied from the `_index_document` function in app.api.admin.kb so the
    resume script can run the same pipeline without a web request context:
    decode/parse the file, chunk it, embed each chunk, upsert the vectors to
    Qdrant, and keep the job's status/progress updated throughout. On any
    failure the job is marked FAILED in a fresh session.

    Args:
        tenant_id: Owning tenant.
        kb_id: Target knowledge base.
        job_id: Index job to update (progress/status).
        doc_id: Document id, stored as the chunk payload `source`.
        file_content: Raw bytes of the uploaded file.
        filename: Original file name (drives format detection via extension).
        metadata: Document-level metadata copied into every chunk payload.
    """
    import tempfile
    from pathlib import Path
    from qdrant_client.models import PointStruct
    from app.core.qdrant_client import get_qdrant_client
    from app.services.document import DocumentParseException, UnsupportedFormatError, parse_document
    from app.services.embedding import get_embedding_provider
    from app.services.kb import KBService
    from app.api.admin.kb import chunk_text_by_lines, TextChunk
    logger.info(f"[RESUME] Starting indexing: tenant={tenant_id}, kb_id={kb_id}, job_id={job_id}, doc_id={doc_id}")
    async with async_session_maker() as session:
        kb_service = KBService(session)
        try:
            # Mark the job as in-flight (10%).
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=10
            )
            await session.commit()
            parse_result = None
            text = None
            file_ext = Path(filename or "").suffix.lower()
            logger.info(f"[RESUME] File extension: {file_ext}, content size: {len(file_content)} bytes")
            text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}
            if file_ext in text_extensions or not file_ext:
                # Plain-text path: try a list of encodings in priority order,
                # falling back to lossy utf-8 if none decode cleanly.
                logger.info("[RESUME] Treating as text file")
                text = None
                for encoding in ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]:
                    try:
                        text = file_content.decode(encoding)
                        logger.info(f"[RESUME] Successfully decoded with encoding: {encoding}")
                        break
                    except (UnicodeDecodeError, LookupError):
                        continue
                if text is None:
                    text = file_content.decode("utf-8", errors="replace")
            else:
                # Binary path: write to a temp file and run the document
                # parser; any parse failure falls back to lossy text decode.
                logger.info("[RESUME] Binary file detected, will parse with document parser")
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
                )
                await session.commit()
                # delete=False because parse_document needs the path after the
                # `with` block closes the handle; removed in `finally` below.
                with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
                    tmp_file.write(file_content)
                    tmp_path = tmp_file.name
                logger.info(f"[RESUME] Temp file created: {tmp_path}")
                try:
                    logger.info(f"[RESUME] Starting document parsing for {file_ext}...")
                    parse_result = parse_document(tmp_path)
                    text = parse_result.text
                    logger.info(
                        f"[RESUME] Parsed document SUCCESS: (unknown), chars={len(text)}"
                    )
                except UnsupportedFormatError as e:
                    logger.error(f"[RESUME] UnsupportedFormatError: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                except DocumentParseException as e:
                    logger.error(f"[RESUME] DocumentParseException: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                except Exception as e:
                    logger.error(f"[RESUME] Unexpected parsing error: {type(e).__name__}: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                finally:
                    Path(tmp_path).unlink(missing_ok=True)
            logger.info(f"[RESUME] Final text length: {len(text)} chars")
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=20
            )
            await session.commit()
            logger.info("[RESUME] Getting embedding provider...")
            embedding_provider = await get_embedding_provider()
            logger.info(f"[RESUME] Embedding provider: {type(embedding_provider).__name__}")
            all_chunks: list[TextChunk] = []
            if parse_result and parse_result.pages:
                # Paged documents (e.g. PDFs): chunk per page and tag each
                # chunk with its page number.
                logger.info(f"[RESUME] PDF with {len(parse_result.pages)} pages")
                for page in parse_result.pages:
                    page_chunks = chunk_text_by_lines(
                        page.text,
                        min_line_length=10,
                        source=filename,
                    )
                    for pc in page_chunks:
                        pc.page = page.page
                    all_chunks.extend(page_chunks)
            else:
                # Everything else: chunk the whole text by lines.
                logger.info("[RESUME] Using line-based chunking")
                all_chunks = chunk_text_by_lines(
                    text,
                    min_line_length=10,
                    source=filename,
                )
            logger.info(f"[RESUME] Total chunks: {len(all_chunks)}")
            qdrant = await get_qdrant_client()
            await qdrant.ensure_kb_collection_exists(tenant_id, kb_id, use_multi_vector=True)
            # Only the Nomic provider produces multi-resolution vectors;
            # point format depends on which provider is active.
            from app.services.embedding.nomic_provider import NomicEmbeddingProvider
            use_multi_vector = isinstance(embedding_provider, NomicEmbeddingProvider)
            logger.info(f"[RESUME] Using multi-vector format: {use_multi_vector}")
            import uuid
            points = []
            total_chunks = len(all_chunks)
            doc_metadata = metadata or {}
            for i, chunk in enumerate(all_chunks):
                payload = {
                    "text": chunk.text,
                    "source": doc_id,
                    "kb_id": kb_id,
                    "chunk_index": i,
                    "start_token": chunk.start_token,
                    "end_token": chunk.end_token,
                    "metadata": doc_metadata,
                }
                if chunk.page is not None:
                    payload["page"] = chunk.page
                if chunk.source:
                    payload["filename"] = chunk.source
                if use_multi_vector:
                    # Multi-vector points are plain dicts with named vectors.
                    embedding_result = await embedding_provider.embed_document(chunk.text)
                    points.append({
                        "id": str(uuid.uuid4()),
                        "vector": {
                            "full": embedding_result.embedding_full,
                            "dim_256": embedding_result.embedding_256,
                            "dim_512": embedding_result.embedding_512,
                        },
                        "payload": payload,
                    })
                else:
                    embedding = await embedding_provider.embed(chunk.text)
                    points.append(
                        PointStruct(
                            id=str(uuid.uuid4()),
                            vector=embedding,
                            payload=payload,
                        )
                    )
                # Map chunk progress onto the 20%..90% band; only persist
                # every 10 chunks (and on the last one) to limit commits.
                progress = 20 + int((i + 1) / total_chunks * 70)
                if i % 10 == 0 or i == total_chunks - 1:
                    await kb_service.update_job_status(
                        tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=progress
                    )
                    await session.commit()
            if points:
                logger.info(f"[RESUME] Upserting {len(points)} vectors to Qdrant...")
                if use_multi_vector:
                    await qdrant.upsert_multi_vector(tenant_id, points, kb_id=kb_id)
                else:
                    await qdrant.upsert_vectors(tenant_id, points, kb_id=kb_id)
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.COMPLETED.value, progress=100
            )
            await session.commit()
            logger.info(
                f"[RESUME] COMPLETED: tenant={tenant_id}, kb_id={kb_id}, "
                f"job_id={job_id}, chunks={len(all_chunks)}"
            )
        except Exception as e:
            import traceback
            logger.error(f"[RESUME] FAILED: {e}\n{traceback.format_exc()}")
            await session.rollback()
            # The original session may be unusable after the error, so the
            # FAILED status is written through a fresh session.
            async with async_session_maker() as error_session:
                kb_service = KBService(error_session)
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.FAILED.value,
                    progress=0, error_msg=str(e)
                )
                await error_session.commit()
# Script entry point: run the async index-job recovery coroutine to completion.
if __name__ == "__main__":
    logger.info("开始恢复索引任务...")
    asyncio.run(resume_pending_jobs())
    logger.info("恢复脚本执行完成")

View File

@ -0,0 +1,47 @@
"""
Migration script to add intent_vector and semantic_examples fields.
Run: python scripts/run_migration_011.py
"""
import asyncio
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sqlalchemy import text
from app.core.database import engine
async def run_migration():
    """Add intent_vector / semantic_examples / route_trace JSONB columns.

    Splits ``migration_sql`` on ``;`` and executes each statement in a
    single transaction.  SQL line comments are stripped from every
    fragment before execution.

    Bug fixed: the previous ``statement.startswith("--")`` check skipped
    every statement, because each one is preceded by a ``--`` comment line
    in ``migration_sql`` — after ``strip()`` the whole fragment began with
    ``--`` and was discarded, so the migration was a no-op.

    Per-statement errors are reported but do not abort the run
    (``ADD COLUMN IF NOT EXISTS`` keeps the migration idempotent).
    """
    migration_sql = """
    -- Add intent_vector column (JSONB for storing pre-computed embedding vectors)
    ALTER TABLE intent_rules
    ADD COLUMN IF NOT EXISTS intent_vector JSONB;
    -- Add semantic_examples column (JSONB for storing example sentences for dynamic vector computation)
    ALTER TABLE intent_rules
    ADD COLUMN IF NOT EXISTS semantic_examples JSONB;
    -- Add route_trace column to chat_messages table if not exists
    ALTER TABLE chat_messages
    ADD COLUMN IF NOT EXISTS route_trace JSONB;
    """
    failures = 0
    async with engine.begin() as conn:
        for fragment in migration_sql.split(";"):
            # Drop SQL line comments *within* the fragment instead of
            # discarding any fragment that merely starts with one.
            lines = [
                line for line in fragment.splitlines()
                if not line.strip().startswith("--")
            ]
            statement = "\n".join(lines).strip()
            if not statement:
                continue
            try:
                await conn.execute(text(statement))
                print(f"Executed: {statement[:80]}...")
            except Exception as e:
                failures += 1
                print(f"Error executing: {statement[:80]}...")
                print(f"  Error: {e}")
    # Report an honest final status instead of unconditional success.
    if failures:
        print(f"\nMigration finished with {failures} error(s).")
    else:
        print("\nMigration completed successfully!")
# Script entry point: execute the migration in a fresh event loop.
if __name__ == "__main__":
    asyncio.run(run_migration())

View File

@ -0,0 +1,82 @@
"""
通过 API 测试知识库检索性能
"""
import requests
import json
import time
# Target server for the smoke test (local dev instance).
API_BASE = "http://localhost:8000"
# NOTE(review): hard-coded credential committed to source control — move this
# to an environment variable and rotate the key.
API_KEY = "oQfkSAbL8iafzyHxqb--G7zRWSOYJHvlzQxia2KpYms"
TENANT_ID = "szmp@ash@2026"
def test_kb_search():
    """Smoke-test the kb-search-dynamic HTTP endpoint.

    Sends two POST requests to a locally running API server — one with a
    metadata ``context`` filter and one without — and prints the status
    code, wall-clock latency and a summary of the response body for manual
    inspection.  Standalone diagnostic script, not a pytest case.
    """
    print("=" * 80)
    print("测试知识库检索 API")
    print("=" * 80)
    # Headers expected by the middleware: API key plus tenant scoping.
    headers = {
        "Content-Type": "application/json",
        "X-API-Key": API_KEY,
        "X-Tenant-Id": TENANT_ID,
    }
    # Test payloads: identical query, differing only in the optional
    # ``context`` metadata filter.
    test_cases = [
        {
            "name": "完整参数含context过滤",
            "data": {
                "query": "三年级语文学习",
                "scene": "学习方案",
                "top_k": 5,
                "context": {"grade": "三年级", "subject": "语文"},
            }
        },
        {
            "name": "简化参数无context",
            "data": {
                "query": "三年级语文学习",
                "scene": "学习方案",
                "top_k": 5,
            }
        },
    ]
    for test_case in test_cases:
        print(f"\n{'='*80}")
        print(f"测试: {test_case['name']}")
        print(f"{'='*80}")
        print(f"请求数据: {json.dumps(test_case['data'], ensure_ascii=False)}")
        try:
            # Time the full HTTP round trip in milliseconds.
            start = time.time()
            response = requests.post(
                f"{API_BASE}/api/v1/mid/kb-search-dynamic",
                headers=headers,
                json=test_case['data'],
                timeout=30,
            )
            elapsed = (time.time() - start) * 1000
            print(f"\n响应状态: {response.status_code}")
            print(f"总耗时: {elapsed:.2f} ms")
            if response.status_code == 200:
                result = response.json()
                print(f"API 结果:")
                print(f"  success: {result.get('success')}")
                print(f"  hits count: {len(result.get('hits', []))}")
                print(f"  duration_ms: {result.get('duration_ms')}")
                print(f"  applied_filter: {result.get('applied_filter')}")
            else:
                print(f"错误: {response.text}")
        except Exception as e:
            # Broad catch is deliberate: any failure (connection refused,
            # timeout, ...) is reported and the next case still runs.
            print(f"请求失败: {e}")
    print("\n" + "=" * 80)
# Script entry point (synchronous: uses requests, no event loop needed).
if __name__ == "__main__":
    test_kb_search()

View File

@ -0,0 +1,73 @@
"""
测试 Redis 缓存是否正常工作
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from app.services.metadata_cache_service import get_metadata_cache_service
async def test_cache():
    """Verify the Redis-backed metadata cache service end to end.

    Three steps: read the current cache entry for the tenant, write a
    small fixture of field definitions, then read it back.  Results are
    printed for manual inspection; standalone diagnostic script.
    """
    print("=" * 80)
    print("测试 Redis 缓存")
    print("=" * 80)
    cache_service = await get_metadata_cache_service()
    tenant_id = "szmp@ash@2026"
    # 1. Check whether a cache entry already exists for this tenant.
    print("\n📊 1. 检查缓存是否存在")
    cached = await cache_service.get_fields(tenant_id)
    if cached:
        print(f"   ✅ 缓存存在,包含 {len(cached)} 个字段")
        for field in cached[:3]:
            print(f"      - {field['field_key']}: {field['label']}")
    else:
        print("   ❌ 缓存不存在")
    # 2. Write a known fixture so the read-back below is deterministic.
    print("\n📊 2. 手动设置测试缓存")
    test_fields = [
        {
            "field_key": "grade",
            "label": "年级",
            "field_type": "enum",
            "required": False,
            "options": ["三年级", "四年级", "五年级"],
            "default_value": None,
            "is_filterable": True,
        },
        {
            "field_key": "subject",
            "label": "学科",
            "field_type": "enum",
            "required": False,
            "options": ["语文", "数学", "英语"],
            "default_value": None,
            "is_filterable": True,
        },
    ]
    result = await cache_service.set_fields(tenant_id, test_fields, ttl=3600)
    print(f"   设置缓存结果: {result}")
    # 3. Read the entry back to confirm the write landed.
    print("\n📊 3. 再次获取缓存")
    cached = await cache_service.get_fields(tenant_id)
    if cached:
        print(f"   ✅ 缓存存在,包含 {len(cached)} 个字段")
    else:
        print("   ❌ 缓存不存在")
    print("\n" + "=" * 80)
    print("测试完成")
    print("=" * 80)
# Script entry point: drive the async cache check in a fresh event loop.
if __name__ == "__main__":
    asyncio.run(test_cache())

View File

@ -0,0 +1,60 @@
"""
测试动态生成的工具 Schema
"""
import asyncio
import sys
import json
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from app.core.config import get_settings
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool
async def test_dynamic_tool_schema():
    """Compare the static vs. dynamically generated tool schema.

    Builds a DB session, asks KbSearchDynamicTool for its static schema and
    the tenant-specific dynamic schema, exercises the schema cache, and
    dumps the ``context`` filter properties.  Standalone diagnostic script.
    """
    settings = get_settings()
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    tenant_id = "szmp@ash@2026"
    print(f"\n{'='*80}")
    print(f"测试动态生成的工具 Schema")
    print(f"{'='*80}")
    print(f"租户 ID: {tenant_id}")
    async with async_session() as session:
        tool = KbSearchDynamicTool(session)
        # Static schema: fixed parameter definitions, no tenant data.
        static_schema = tool.get_tool_schema()
        print(f"\n--- 静态 Schema ---")
        print(json.dumps(static_schema, indent=2, ensure_ascii=False))
        # Dynamic schema: augmented with tenant-specific filter fields.
        dynamic_schema = await tool.get_dynamic_tool_schema(tenant_id)
        print(f"\n--- 动态 Schema ---")
        print(json.dumps(dynamic_schema, indent=2, ensure_ascii=False))
        # A second call should be served from cache and compare equal.
        print(f"\n--- 测试缓存 ---")
        dynamic_schema2 = await tool.get_dynamic_tool_schema(tenant_id)
        print(f"缓存命中: {dynamic_schema == dynamic_schema2}")
        # Inspect the per-tenant filterable fields under ``context``.
        print(f"\n--- context 字段详情 ---")
        context_props = dynamic_schema["parameters"]["properties"].get("context", {}).get("properties", {})
        print(f"过滤字段数量: {len(context_props)}")
        for key, value in context_props.items():
            print(f"  {key}: {value}")
# Script entry point: drive the async schema check in a fresh event loop.
if __name__ == "__main__":
    asyncio.run(test_dynamic_tool_schema())

View File

@ -0,0 +1,93 @@
"""
测试 kb_search_dynamic 工具是否能用给定的参数查出数据
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool
from app.core.config import get_settings
async def test_kb_search():
    """Run KbSearchDynamicTool.execute directly against the database.

    Executes three parameter combinations (with context filter, without,
    and minimal) and prints hit counts, applied filters and timing for
    manual inspection.  Standalone diagnostic script, not a pytest case.
    """
    settings = get_settings()
    # Build a dedicated async engine/session for this script run.
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    async with async_session() as session:
        tool = KbSearchDynamicTool(session=session)
        # Parameter matrix: progressively fewer optional arguments.
        test_cases = [
            {
                "name": "完整参数含context过滤",
                "params": {
                    "query": "三年级语文学习",
                    "tenant_id": "szmp@ash@2026",
                    "scene": "学习方案",
                    "top_k": 5,
                    "context": {"grade": "三年级", "subject": "语文"},
                }
            },
            {
                "name": "简化参数无context",
                "params": {
                    "query": "三年级语文学习",
                    "tenant_id": "szmp@ash@2026",
                    "scene": "学习方案",
                    "top_k": 5,
                }
            },
            {
                "name": "仅query和tenant_id",
                "params": {
                    "query": "三年级语文学习",
                    "tenant_id": "szmp@ash@2026",
                    "top_k": 5,
                }
            },
        ]
        for test_case in test_cases:
            print(f"\n{'='*80}")
            print(f"测试: {test_case['name']}")
            print(f"{'='*80}")
            print(f"参数: {test_case['params']}")
            try:
                result = await tool.execute(**test_case['params'])
                print(f"\n结果:")
                print(f"  success: {result.success}")
                print(f"  hits count: {len(result.hits)}")
                print(f"  applied_filter: {result.applied_filter}")
                print(f"  fallback_reason_code: {result.fallback_reason_code}")
                print(f"  duration_ms: {result.duration_ms}")
                if result.hits:
                    print(f"\n  前3条结果:")
                    for i, hit in enumerate(result.hits[:3], 1):
                        # Preview only: truncate long chunk text to 80 chars.
                        text = hit.get('text', '')[:80] + '...' if hit.get('text') else 'N/A'
                        score = hit.get('score', 0)
                        metadata = hit.get('metadata', {})
                        print(f"    {i}. [score={score:.4f}] {text}")
                        print(f"       metadata: {metadata}")
                else:
                    print(f"\n  ⚠️ 没有命中任何结果")
            except Exception as e:
                # Report the failure and continue with the remaining cases.
                print(f"\n  ❌ 错误: {e}")
                import traceback
                traceback.print_exc()
# Script entry point: drive the async tool test in a fresh event loop.
if __name__ == "__main__":
    asyncio.run(test_kb_search())

View File

@ -0,0 +1,120 @@
"""
测试 kb_search_dynamic 工具 - 课程咨询场景
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import (
KbSearchDynamicTool,
KbSearchDynamicConfig,
StepKbConfig,
)
from app.core.config import get_settings
async def test_kb_search():
    """Exercise kb_search_dynamic for the course-consultation scenario.

    Pins the search to one course knowledge base via StepKbConfig, applies
    a grade filter, and prints the full result (hits, filter debug, KB
    binding and tool trace) for manual inspection.  Standalone script.
    """
    settings = get_settings()
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    async with async_session() as session:
        # Generous timeout and a low score threshold so weak matches still
        # show up during debugging.
        config = KbSearchDynamicConfig(
            enabled=True,
            top_k=10,
            timeout_ms=15000,
            min_score_threshold=0.3,
        )
        tool = KbSearchDynamicTool(session=session, config=config)
        course_kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"
        # Restrict and prefer the single course KB for this step.
        step_kb_config = StepKbConfig(
            allowed_kb_ids=[course_kb_id],
            preferred_kb_ids=[course_kb_id],
            step_id="test_course_query",
        )
        test_params = {
            "query": "课程介绍",
            "tenant_id": "szmp@ash@2026",
            "top_k": 10,
            "context": {
                "grade": "五年级",
            },
            "step_kb_config": step_kb_config,
        }
        print(f"\n{'='*80}")
        print(f"测试: kb_search_dynamic - 课程知识库")
        print(f"{'='*80}")
        print(f"参数: query={test_params['query']}")
        print(f"      tenant_id={test_params['tenant_id']}")
        print(f"      context={test_params['context']}")
        print(f"      step_kb_config.allowed_kb_ids={step_kb_config.allowed_kb_ids}")
        print(f"超时设置: {config.timeout_ms}ms")
        print(f"最低分数阈值: {config.min_score_threshold}")
        try:
            result = await tool.execute(**test_params)
            print(f"\n结果:")
            print(f"  success: {result.success}")
            print(f"  hits count: {len(result.hits)}")
            print(f"  applied_filter: {result.applied_filter}")
            print(f"  fallback_reason_code: {result.fallback_reason_code}")
            print(f"  duration_ms: {result.duration_ms}")
            if result.filter_debug:
                print(f"  filter_debug: {result.filter_debug}")
            if result.step_kb_binding:
                print(f"  step_kb_binding: {result.step_kb_binding}")
            # Observability payload attached to the call, if present.
            if result.tool_trace:
                print(f"\n  Tool Trace:")
                print(f"    tool_name: {result.tool_trace.tool_name}")
                print(f"    status: {result.tool_trace.status}")
                print(f"    duration_ms: {result.tool_trace.duration_ms}")
                print(f"    args_digest: {result.tool_trace.args_digest}")
                print(f"    result_digest: {result.tool_trace.result_digest}")
                if hasattr(result.tool_trace, 'arguments') and result.tool_trace.arguments:
                    print(f"    arguments: {result.tool_trace.arguments}")
            if result.hits:
                print(f"\n  检索结果 (共 {len(result.hits)} 条):")
                for i, hit in enumerate(result.hits, 1):
                    text = hit.get('text', '')
                    # Preview only: truncate long chunk text to 200 chars.
                    text_preview = text[:200] + '...' if len(text) > 200 else text
                    score = hit.get('score', 0)
                    metadata = hit.get('metadata', {})
                    collection = hit.get('collection', 'unknown')
                    kb_id = hit.get('kb_id', 'unknown')
                    print(f"\n  [{i}] score={score:.4f}")
                    print(f"      collection: {collection}")
                    print(f"      kb_id: {kb_id}")
                    print(f"      metadata: {metadata}")
                    print(f"      text: {text_preview}")
            else:
                print(f"\n  ⚠️ 没有命中任何结果")
                print(f"  请检查:")
                print(f"    1. 知识库是否有数据")
                print(f"    2. 向量是否正确生成")
                print(f"    3. 过滤条件是否过于严格")
        except Exception as e:
            # Report the failure with a traceback for debugging.
            print(f"\n  ❌ 错误: {e}")
            import traceback
            traceback.print_exc()
# Script entry point: drive the async scenario test in a fresh event loop.
if __name__ == "__main__":
    asyncio.run(test_kb_search())

View File

@ -0,0 +1,96 @@
"""
测试 kb_search_dynamic 工具 - 增加超时时间
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
from app.core.config import get_settings
async def test_kb_search():
    """Exercise kb_search_dynamic with a longer timeout.

    Same two parameter combinations as the basic tool test, but with a
    10-second timeout and a higher score threshold so slow embedding/DB
    calls can complete.  Standalone diagnostic script.
    """
    settings = get_settings()
    # Build a dedicated async engine/session for this script run.
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    async with async_session() as session:
        # Use a longer timeout than the defaults for debugging slow runs.
        config = KbSearchDynamicConfig(
            enabled=True,
            top_k=5,
            timeout_ms=10000,  # 10-second timeout
            min_score_threshold=0.5,
        )
        tool = KbSearchDynamicTool(session=session, config=config)
        # Parameter matrix: with and without the context filter.
        test_cases = [
            {
                "name": "完整参数含context过滤",
                "params": {
                    "query": "三年级语文学习",
                    "tenant_id": "szmp@ash@2026",
                    "scene": "学习方案",
                    "top_k": 5,
                    "context": {"grade": "三年级", "subject": "语文"},
                }
            },
            {
                "name": "简化参数无context",
                "params": {
                    "query": "三年级语文学习",
                    "tenant_id": "szmp@ash@2026",
                    "scene": "学习方案",
                    "top_k": 5,
                }
            },
        ]
        for test_case in test_cases:
            print(f"\n{'='*80}")
            print(f"测试: {test_case['name']}")
            print(f"{'='*80}")
            print(f"参数: {test_case['params']}")
            print(f"超时设置: {config.timeout_ms}ms")
            try:
                result = await tool.execute(**test_case['params'])
                print(f"\n结果:")
                print(f"  success: {result.success}")
                print(f"  hits count: {len(result.hits)}")
                print(f"  applied_filter: {result.applied_filter}")
                print(f"  fallback_reason_code: {result.fallback_reason_code}")
                print(f"  duration_ms: {result.duration_ms}")
                if result.hits:
                    print(f"\n  所有结果:")
                    for i, hit in enumerate(result.hits, 1):
                        # Preview only: truncate long chunk text to 100 chars.
                        text = hit.get('text', '')[:100] + '...' if hit.get('text') else 'N/A'
                        score = hit.get('score', 0)
                        metadata = hit.get('metadata', {})
                        collection = hit.get('collection', 'unknown')
                        print(f"    {i}. [score={score:.4f}] [collection={collection}]")
                        print(f"       text: {text}")
                        print(f"       metadata: {metadata}")
                else:
                    print(f"\n  ⚠️ 没有命中任何结果")
            except Exception as e:
                # Report the failure and continue with the remaining cases.
                print(f"\n  ❌ 错误: {e}")
                import traceback
                traceback.print_exc()
# Script entry point: drive the async timeout test in a fresh event loop.
if __name__ == "__main__":
    asyncio.run(test_kb_search())

View File

@ -0,0 +1,89 @@
"""
直接测试 Qdrant 过滤功能
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from qdrant_client.models import FieldCondition, Filter, MatchValue
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient
async def test_qdrant_filter():
    """Directly exercise Qdrant payload filtering on one KB collection.

    Three probes: an unfiltered scroll as a baseline, a scroll constrained
    by a ``metadata.grade`` Filter, and a ``query_points`` vector search
    with the same filter (reusing a stored vector as the query so no
    embedding model is needed).  Standalone diagnostic script.
    """
    settings = get_settings()
    client = QdrantClient()
    qdrant = await client.get_client()
    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"
    collection_name = client.get_kb_collection_name(tenant_id, kb_id)
    print(f"\n{'='*80}")
    print(f"测试 Qdrant 过滤功能")
    print(f"{'='*80}")
    print(f"Collection: {collection_name}")
    # Probe 1: unfiltered scroll as a baseline.
    print(f"\n--- 测试 1: 无过滤 ---")
    results = await qdrant.scroll(
        collection_name=collection_name,
        limit=5,
        with_vectors=False,
    )
    print(f"无过滤结果数: {len(results[0])}")
    for p in results[0][:3]:
        print(f"  grade: {p.payload.get('metadata', {}).get('grade')}")
    # Probe 2: the same scroll constrained by a payload filter on
    # the nested metadata.grade key.
    print(f"\n--- 测试 2: 使用 Filter 对象过滤 (grade=五年级) ---")
    qdrant_filter = Filter(
        must=[
            FieldCondition(
                key="metadata.grade",
                match=MatchValue(value="五年级"),
            )
        ]
    )
    print(f"Filter: {qdrant_filter}")
    results = await qdrant.scroll(
        collection_name=collection_name,
        limit=10,
        with_vectors=False,
        scroll_filter=qdrant_filter,
    )
    print(f"过滤后结果数: {len(results[0])}")
    for p in results[0]:
        print(f"  grade: {p.payload.get('metadata', {}).get('grade')}, text: {p.payload.get('text', '')[:50]}...")
    # Probe 3: vector search (query_points) with the same filter.
    print(f"\n--- 测试 3: 使用 query_points 过滤 ---")
    # Grab any stored vector to use as the query vector.
    all_points = await qdrant.scroll(
        collection_name=collection_name,
        limit=1,
        with_vectors=True,
    )
    if all_points[0]:
        query_vector = all_points[0][0].vector
        results = await qdrant.query_points(
            collection_name=collection_name,
            query=query_vector,
            limit=10,
            query_filter=qdrant_filter,
        )
        print(f"query_points 过滤后结果数: {len(results.points)}")
        for p in results.points:
            print(f"  grade: {p.payload.get('metadata', {}).get('grade')}, score: {p.score:.4f}")
# Script entry point: drive the async Qdrant probe in a fresh event loop.
if __name__ == "__main__":
    asyncio.run(test_qdrant_filter())

View File

@ -0,0 +1,193 @@
"""
验证 Qdrant 向量数据库中的 collections 情况
用于检查 szmp@ash@2026 租户下的知识库 collections
"""
import asyncio
import sys
from pathlib import Path
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
from qdrant_client import AsyncQdrantClient
from app.core.config import get_settings
async def list_collections():
    """List all Qdrant collections and summarize szmp-tenant ones.

    Partitions collections by whether their name contains "szmp", prints a
    detailed view (point count, vector size/distance) for tenant
    collections and a compact one-liner for the rest, then checks the
    expectation that the tenant owns exactly two collections.  The client
    is always closed in ``finally``.
    """
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)
    print(f"🔗 Qdrant URL: {settings.qdrant_url}")
    print(f"📦 Collection Prefix: {settings.qdrant_collection_prefix}")
    print("-" * 60)
    try:
        collections = await client.get_collections()
        if not collections.collections:
            print("⚠️ 没有找到任何 collections")
            return
        print(f"✅ 找到 {len(collections.collections)} 个 collections:\n")
        # Partition by tenant: name containing "szmp" vs everything else.
        szmp_collections = []
        other_collections = []
        for collection in collections.collections:
            name = collection.name
            if "szmp" in name.lower():
                szmp_collections.append(name)
            else:
                other_collections.append(name)
        # Detailed view for the tenant's collections.
        if szmp_collections:
            print(f"🎯 szmp@ash@2026 租户相关的 collections ({len(szmp_collections)} 个):")
            print("-" * 60)
            for name in sorted(szmp_collections):
                try:
                    info = await client.get_collection(name)
                    points_count = info.points_count if hasattr(info, 'points_count') else 'N/A'
                    print(f"  📁 {name}")
                    print(f"     └─ 向量数量: {points_count}")
                    # Vector parameters are hasattr-guarded because the
                    # response shape varies across qdrant-client versions.
                    if hasattr(info, 'config') and hasattr(info.config, 'params'):
                        params = info.config.params
                        if hasattr(params, 'vectors'):
                            vector_params = params.vectors
                            if hasattr(vector_params, 'size'):
                                print(f"     └─ 向量维度: {vector_params.size}")
                            if hasattr(vector_params, 'distance'):
                                print(f"     └─ 距离函数: {vector_params.distance}")
                    print()
                except Exception as e:
                    print(f"  📁 {name}")
                    print(f"     └─ 获取信息失败: {e}\n")
        else:
            print("⚠️ 没有找到 szmp@ash@2026 租户相关的 collections\n")
        # Compact view for everything else.
        if other_collections:
            print(f"📂 其他 collections ({len(other_collections)} 个):")
            print("-" * 60)
            for name in sorted(other_collections):
                try:
                    info = await client.get_collection(name)
                    points_count = info.points_count if hasattr(info, 'points_count') else 'N/A'
                    print(f"  📁 {name} (向量数: {points_count})")
                except Exception as e:
                    print(f"  📁 {name} (获取信息失败: {e})")
        print("\n" + "=" * 60)
        print("📊 总结:")
        print(f"  - Collections 总数: {len(collections.collections)}")
        print(f"  - szmp 相关: {len(szmp_collections)}")
        print(f"  - 其他: {len(other_collections)}")
        # Expectation check: the tenant should own exactly two collections.
        print("\n✅ 验证:")
        if len(szmp_collections) == 2:
            print("  ✓ szmp 租户的 collection 数量符合预期 (2个)")
        else:
            print(f"  ⚠️ szmp 租户的 collection 数量不符合预期 (实际: {len(szmp_collections)} 个, 预期: 2个)")
    except Exception as e:
        print(f"❌ 连接 Qdrant 失败: {e}")
        print(f"   请检查 Qdrant 是否运行在 {settings.qdrant_url}")
    finally:
        await client.close()
async def check_collection_details(collection_name: str):
    """Print configuration and sample payloads for one Qdrant collection.

    Shows vector size/distance/on-disk flags plus shard and replication
    settings (all hasattr-guarded because the response shape varies across
    qdrant-client versions), then up to three sample points with payload
    previews.  The client is always closed in ``finally``.

    Args:
        collection_name: Fully qualified Qdrant collection name.
    """
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)
    try:
        print(f"\n📋 Collection '{collection_name}' 详细信息:")
        print("-" * 60)
        info = await client.get_collection(collection_name)
        print(f"  名称: {collection_name}")
        print(f"  向量数量: {info.points_count}")
        if hasattr(info, 'config') and hasattr(info.config, 'params'):
            params = info.config.params
            if hasattr(params, 'vectors'):
                vector_params = params.vectors
                print(f"  向量配置:")
                if hasattr(vector_params, 'size'):
                    print(f"    - 维度: {vector_params.size}")
                if hasattr(vector_params, 'distance'):
                    print(f"    - 距离函数: {vector_params.distance}")
                if hasattr(vector_params, 'on_disk'):
                    print(f"    - 磁盘存储: {vector_params.on_disk}")
            if hasattr(params, 'shard_number'):
                print(f"  分片数: {params.shard_number}")
            if hasattr(params, 'replication_factor'):
                print(f"  副本数: {params.replication_factor}")
        # Sample a few points so the payload schema is visible.
        # (Fixed: removed an unused `from qdrant_client.models import
        # ScrollRequest` import — scroll() takes plain keyword arguments.)
        try:
            scroll_result = await client.scroll(
                collection_name=collection_name,
                limit=3,
                with_payload=True,
                with_vectors=False
            )
            if scroll_result[0]:
                print(f"\n  样本数据 (前3条):")
                for i, point in enumerate(scroll_result[0], 1):
                    payload = point.payload or {}
                    # Preview only: truncate long chunk text to 50 chars.
                    text = payload.get('text', '')[:50] + '...' if payload.get('text') else 'N/A'
                    kb_id = payload.get('kb_id', 'N/A')
                    print(f"    {i}. ID: {point.id}")
                    print(f"       KB ID: {kb_id}")
                    print(f"       文本: {text}")
        except Exception as e:
            print(f"  获取样本数据失败: {e}")
    except Exception as e:
        print(f"❌ 获取 collection 信息失败: {e}")
    finally:
        await client.close()
async def main():
    """Entry point: list every collection, then drill into szmp ones."""
    print("=" * 60)
    print("🔍 Qdrant 向量数据库 Collections 验证工具")
    print("=" * 60)
    print()
    # Overview of all collections first.
    await list_collections()
    # Second pass with a fresh client (list_collections closes its own)
    # for per-collection details of the tenant's collections.
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)
    try:
        collections = await client.get_collections()
        szmp_collections = [c.name for c in collections.collections if "szmp" in c.name.lower()]
        for name in sorted(szmp_collections):
            await check_collection_details(name)
    except Exception as e:
        print(f"❌ 错误: {e}")
    finally:
        await client.close()
# Script entry point: run the verification tool in a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())

View File

@ -0,0 +1,333 @@
"""
Tests for Batch Ask-Back Service.
[AC-MRS-SLOT-ASKBACK-01] 批量追问测试
"""
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.mid.batch_ask_back_service import (
AskBackSlot,
BatchAskBackConfig,
BatchAskBackResult,
BatchAskBackService,
create_batch_ask_back_service,
)
class TestAskBackSlot:
    """Unit tests for the AskBackSlot value object."""

    def test_init(self):
        """Every constructor argument is kept verbatim on the instance."""
        region_slot = AskBackSlot(
            slot_key="region",
            label="地区",
            ask_back_prompt="请告诉我您的地区",
            priority=100,
            is_required=True,
        )
        assert region_slot.is_required is True
        assert region_slot.priority == 100
        assert region_slot.label == "地区"
        assert region_slot.slot_key == "region"
class TestBatchAskBackConfig:
    """Tests for BatchAskBackConfig defaults and overrides."""

    def test_default_config(self):
        """A bare config carries the documented default values."""
        cfg = BatchAskBackConfig()
        assert cfg.merge_prompts is True
        assert cfg.recent_asked_threshold_seconds == 60.0
        assert cfg.avoid_recent_asked is True
        assert cfg.prefer_scene_relevant is True
        assert cfg.prefer_required is True
        assert cfg.max_ask_back_slots_per_turn == 2

    def test_custom_config(self):
        """Keyword overrides replace the corresponding defaults."""
        cfg = BatchAskBackConfig(
            max_ask_back_slots_per_turn=3,
            prefer_required=False,
            merge_prompts=False,
        )
        assert cfg.merge_prompts is False
        assert cfg.prefer_required is False
        assert cfg.max_ask_back_slots_per_turn == 3
class TestBatchAskBackResult:
    """Tests for the BatchAskBackResult container."""

    def test_has_ask_back(self):
        """has_ask_back() is true iff at least one slot was queued."""
        assert BatchAskBackResult(ask_back_count=2).has_ask_back() is True
        assert BatchAskBackResult(ask_back_count=0).has_ask_back() is False

    def test_get_prompt_with_merged(self):
        """A merged prompt takes precedence over the individual prompts."""
        res = BatchAskBackResult(
            merged_prompt="请告诉我您的地区和产品",
            prompts=["请告诉我您的地区", "请告诉我您的产品"],
            ask_back_count=2,
        )
        assert res.get_prompt() == "请告诉我您的地区和产品"

    def test_get_prompt_without_merged(self):
        """Without a merged prompt the individual prompt list is used."""
        res = BatchAskBackResult(
            prompts=["请告诉我您的地区"],
            ask_back_count=1,
        )
        assert res.get_prompt() == "请告诉我您的地区"

    def test_get_prompt_empty(self):
        """An empty result falls back to a generic ask-for-more message."""
        assert "请提供更多信息" in BatchAskBackResult().get_prompt()
class TestBatchAskBackService:
    """Unit tests for BatchAskBackService slot selection and prompting."""

    @pytest.fixture
    def mock_session(self):
        """Provide a mocked async DB session."""
        return AsyncMock()

    @pytest.fixture
    def config(self):
        """Provide a config capped at two ask-back slots per turn."""
        return BatchAskBackConfig(
            max_ask_back_slots_per_turn=2,
            prefer_required=True,
            merge_prompts=True,
        )

    @pytest.fixture
    def service(self, mock_session, config):
        """Provide the service under test."""
        return BatchAskBackService(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
            config=config,
        )

    def test_calculate_priority_required(self, service):
        """A required slot alone scores 100."""
        priority = service._calculate_priority(is_required=True, scene_relevance=0.0)
        assert priority == 100

    def test_calculate_priority_scene_relevant(self, service):
        """A fully scene-relevant optional slot scores 50."""
        priority = service._calculate_priority(is_required=False, scene_relevance=1.0)
        assert priority == 50

    def test_calculate_priority_both(self, service):
        """Required and scene-relevance contributions are additive (150)."""
        priority = service._calculate_priority(is_required=True, scene_relevance=1.0)
        assert priority == 150

    def test_select_slots_for_ask_back(self, service):
        """Selection keeps the top-priority slots, highest first, capped at 2."""
        slots = [
            AskBackSlot(slot_key="a", label="A", priority=50),
            AskBackSlot(slot_key="b", label="B", priority=100),
            AskBackSlot(slot_key="c", label="C", priority=75),
        ]
        selected = service._select_slots_for_ask_back(slots)
        assert len(selected) == 2
        assert selected[0].slot_key == "b"
        assert selected[1].slot_key == "c"

    def test_filter_recently_asked(self, service):
        """Slots asked within the recency window are dropped; older and
        never-asked slots survive."""
        current_time = time.time()
        # History maps slot_key -> last-asked unix timestamp.
        asked_history = {
            "recently_asked": current_time - 30,
            "old_asked": current_time - 120,
        }
        slots = [
            AskBackSlot(slot_key="recently_asked", label="最近追问过"),
            AskBackSlot(slot_key="old_asked", label="很久前追问过"),
            AskBackSlot(slot_key="never_asked", label="从未追问过"),
        ]
        filtered = service._filter_recently_asked(slots, asked_history)
        assert len(filtered) == 2
        slot_keys = [s.slot_key for s in filtered]
        assert "old_asked" in slot_keys
        assert "never_asked" in slot_keys
        assert "recently_asked" not in slot_keys

    def test_generate_prompts(self, service):
        """An explicit ask_back_prompt is used verbatim; otherwise the
        prompt is derived from the slot label."""
        slots = [
            AskBackSlot(slot_key="region", label="地区", ask_back_prompt="请告诉我您的地区"),
            AskBackSlot(slot_key="product", label="产品", ask_back_prompt=None),
        ]
        prompts = service._generate_prompts(slots)
        assert len(prompts) == 2
        assert prompts[0] == "请告诉我您的地区"
        assert "产品" in prompts[1]

    def test_merge_prompts_single(self, service):
        """A single prompt is returned unchanged."""
        prompts = ["请告诉我您的地区"]
        merged = service._merge_prompts(prompts)
        assert merged == "请告诉我您的地区"

    def test_merge_prompts_two(self, service):
        """Two prompts are joined with the connective 以及."""
        prompts = ["请告诉我您的地区", "请告诉我您的产品"]
        merged = service._merge_prompts(prompts)
        assert "地区" in merged
        assert "产品" in merged
        assert "以及" in merged

    def test_merge_prompts_multiple(self, service):
        """Three or more prompts keep all fragments plus the connective."""
        prompts = ["您的地区", "您的产品", "您的等级"]
        merged = service._merge_prompts(prompts)
        assert "地区" in merged
        assert "产品" in merged
        assert "等级" in merged
        # NOTE(review): "" is vacuously contained in any string, so this
        # assert can never fail — the separator character was probably
        # intended here; confirm against the service implementation.
        assert "" in merged
        assert "以及" in merged

    @pytest.mark.asyncio
    async def test_generate_batch_ask_back_empty(self, service):
        """No missing slots means no ask-back is produced."""
        result = await service.generate_batch_ask_back(missing_slots=[])
        assert result.has_ask_back() is False

    @pytest.mark.asyncio
    async def test_generate_batch_ask_back_single(self, service):
        """A single missing slot yields exactly one ask-back prompt."""
        missing_slots = [
            {
                "slot_key": "region",
                "label": "地区",
                "ask_back_prompt": "请告诉我您的地区",
            }
        ]
        # Stub the slot-definition lookup so the slot counts as required.
        with patch.object(service._slot_def_service, 'get_slot_definition_by_key') as mock_get:
            mock_get.return_value = MagicMock(required=True)
            result = await service.generate_batch_ask_back(missing_slots=missing_slots)
            assert result.has_ask_back() is True
            assert result.ask_back_count == 1
            assert "地区" in result.get_prompt()

    @pytest.mark.asyncio
    async def test_generate_batch_ask_back_multiple(self, service):
        """Three missing slots are capped at max_ask_back_slots_per_turn (2)."""
        missing_slots = [
            {"slot_key": "region", "label": "地区", "ask_back_prompt": "您的地区"},
            {"slot_key": "product", "label": "产品", "ask_back_prompt": "您的产品"},
            {"slot_key": "grade", "label": "等级", "ask_back_prompt": "您的等级"},
        ]
        with patch.object(service._slot_def_service, 'get_slot_definition_by_key') as mock_get:
            mock_get.return_value = MagicMock(required=True)
            result = await service.generate_batch_ask_back(missing_slots=missing_slots)
            assert result.has_ask_back() is True
            assert result.ask_back_count == 2

    @pytest.mark.asyncio
    async def test_generate_batch_ask_back_prioritize_required(self, service):
        """Required slots are asked before optional ones."""
        missing_slots = [
            {"slot_key": "optional", "label": "可选", "ask_back_prompt": "可选信息"},
            {"slot_key": "required", "label": "必填", "ask_back_prompt": "必填信息"},
        ]
        def mock_get_slot(tenant_id, slot_key):
            # Only the "required" slot is marked required in its definition.
            if slot_key == "required":
                return MagicMock(required=True)
            return MagicMock(required=False)
        with patch.object(service._slot_def_service, 'get_slot_definition_by_key', side_effect=mock_get_slot):
            result = await service.generate_batch_ask_back(missing_slots=missing_slots)
            assert result.has_ask_back() is True
            assert result.selected_slots[0].slot_key == "required"
class TestCreateBatchAskBackService:
    """Tests for the create_batch_ask_back_service factory."""

    def test_create(self):
        """The factory wires session, ids and config onto the service."""
        session_stub = AsyncMock()
        svc = create_batch_ask_back_service(
            session=session_stub,
            tenant_id="tenant_1",
            session_id="session_1",
            config=BatchAskBackConfig(max_ask_back_slots_per_turn=3),
        )
        assert isinstance(svc, BatchAskBackService)
        assert svc._config.max_ask_back_slots_per_turn == 3
        assert svc._session_id == "session_1"
        assert svc._tenant_id == "tenant_1"
class TestAskBackHistory:
    """Tests for reading the per-session ask-back history from Redis."""

    @pytest.fixture
    def service(self):
        """Provide a service instance backed by a mocked DB session."""
        mock_session = AsyncMock()
        return BatchAskBackService(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
        )

    @pytest.mark.asyncio
    async def test_get_asked_history_empty(self, service):
        """A Redis cache miss yields an empty history dict."""
        with patch.object(service._cache, '_get_client') as mock_client:
            mock_redis = AsyncMock()
            mock_redis.get = AsyncMock(return_value=None)
            mock_client.return_value = mock_redis
            history = await service._get_asked_history()
            assert history == {}

    @pytest.mark.asyncio
    async def test_get_asked_history_with_data(self, service):
        """Stored JSON is decoded into a slot_key -> timestamp mapping."""
        import json
        history_data = {"region": 12345.0, "product": 12346.0}
        with patch.object(service._cache, '_get_client') as mock_client:
            mock_redis = AsyncMock()
            mock_redis.get = AsyncMock(return_value=json.dumps(history_data))
            mock_client.return_value = mock_redis
            history = await service._get_asked_history()
            assert history == history_data

View File

@ -0,0 +1,543 @@
"""
Tests for clarification mechanism.
[AC-CLARIFY] 澄清机制测试
"""
import pytest
from unittest.mock import MagicMock, patch
from app.services.intent.clarification import (
ClarificationEngine,
ClarifyMetrics,
ClarifyReason,
ClarifySessionManager,
ClarifyState,
HybridIntentResult,
IntentCandidate,
T_HIGH,
T_LOW,
MAX_CLARIFY_RETRY,
get_clarify_metrics,
)
class TestClarifyMetrics:
    """Behavioral tests for the ClarifyMetrics singleton counters."""

    def test_singleton_pattern(self):
        """Constructing twice returns the identical object."""
        first = ClarifyMetrics()
        second = ClarifyMetrics()
        assert first is second

    def test_record_clarify_trigger(self):
        """Every trigger recording bumps the trigger counter by one."""
        m = ClarifyMetrics()
        m.reset()
        for _ in range(3):
            m.record_clarify_trigger()
        assert m.get_metrics()["clarify_trigger_rate"] == 3

    def test_record_clarify_converge(self):
        """Every convergence recording bumps the converge counter by one."""
        m = ClarifyMetrics()
        m.reset()
        for _ in range(2):
            m.record_clarify_converge()
        assert m.get_metrics()["clarify_converge_rate"] == 2

    def test_record_misroute(self):
        """A single misroute recording shows up in the counters."""
        m = ClarifyMetrics()
        m.reset()
        m.record_misroute()
        assert m.get_metrics()["misroute_rate"] == 1

    def test_get_rates(self):
        """Rates are normalized against the supplied request total."""
        m = ClarifyMetrics()
        m.reset()
        m.record_clarify_trigger()
        m.record_clarify_converge()
        m.record_misroute()
        computed = m.get_rates(100)
        assert computed["misroute_rate"] == 0.01
        assert computed["clarify_converge_rate"] == 1.0
        assert computed["clarify_trigger_rate"] == 0.01

    def test_get_rates_zero_requests(self):
        """Zero total requests yields all-zero rates (no division error)."""
        m = ClarifyMetrics()
        m.reset()
        computed = m.get_rates(0)
        assert computed["misroute_rate"] == 0.0
        assert computed["clarify_converge_rate"] == 0.0
        assert computed["clarify_trigger_rate"] == 0.0

    def test_reset(self):
        """reset() zeroes every counter."""
        m = ClarifyMetrics()
        m.record_clarify_trigger()
        m.record_clarify_converge()
        m.record_misroute()
        m.reset()
        zeroed = m.get_metrics()
        assert zeroed["misroute_rate"] == 0
        assert zeroed["clarify_converge_rate"] == 0
        assert zeroed["clarify_trigger_rate"] == 0
class TestIntentCandidate:
    """IntentCandidate serialization tests."""

    def test_to_dict(self):
        """to_dict() must echo every populated field."""
        candidate = IntentCandidate(
            intent_id="intent-1",
            intent_name="退货意图",
            confidence=0.85,
            response_type="flow",
            target_kb_ids=["kb-1"],
            flow_id="flow-1",
            fixed_reply=None,
            transfer_message=None,
        )
        payload = candidate.to_dict()
        expected = {
            "intent_id": "intent-1",
            "intent_name": "退货意图",
            "confidence": 0.85,
            "response_type": "flow",
            "target_kb_ids": ["kb-1"],
            "flow_id": "flow-1",
        }
        for key, value in expected.items():
            assert payload[key] == value
class TestHybridIntentResult:
    """Tests for HybridIntentResult serialization and fusion-result adaptation."""

    def test_to_dict(self):
        """to_dict() nests the winning intent and echoes top-level fields."""
        candidate = IntentCandidate(
            intent_id="intent-1",
            intent_name="退货意图",
            confidence=0.85,
        )
        result = HybridIntentResult(
            intent=candidate,
            confidence=0.85,
            candidates=[candidate],
            need_clarify=False,
            clarify_reason=None,
            missing_slots=[],
        )
        d = result.to_dict()
        assert d["intent"]["intent_id"] == "intent-1"
        assert d["confidence"] == 0.85
        assert len(d["candidates"]) == 1
        assert d["need_clarify"] is False

    def test_from_fusion_result(self):
        """A fusion result with a final intent maps to a confident result."""
        # Mock mirrors the attribute shape of the real FusionResult /
        # intent-rule objects that from_fusion_result reads.
        mock_fusion = MagicMock()
        mock_fusion.final_intent = MagicMock()
        mock_fusion.final_intent.id = "intent-1"
        mock_fusion.final_intent.name = "退货意图"
        mock_fusion.final_intent.response_type = "flow"
        mock_fusion.final_intent.target_kb_ids = ["kb-1"]
        mock_fusion.final_intent.flow_id = None
        mock_fusion.final_intent.fixed_reply = None
        mock_fusion.final_intent.transfer_message = None
        mock_fusion.final_confidence = 0.85
        mock_fusion.need_clarify = False
        mock_fusion.decision_reason = "rule_high_confidence"
        mock_fusion.clarify_candidates = []
        result = HybridIntentResult.from_fusion_result(mock_fusion)
        assert result.intent is not None
        assert result.intent.intent_id == "intent-1"
        assert result.confidence == 0.85
        assert result.need_clarify is False

    def test_from_fusion_result_with_clarify(self):
        """A multi-intent fusion result carries its candidates and reason over."""
        mock_fusion = MagicMock()
        mock_fusion.final_intent = None
        mock_fusion.final_confidence = 0.5
        mock_fusion.need_clarify = True
        mock_fusion.decision_reason = "multi_intent"
        # Two competing intent candidates with the full attribute surface.
        candidate1 = MagicMock()
        candidate1.id = "intent-1"
        candidate1.name = "退货意图"
        candidate1.response_type = "flow"
        candidate1.target_kb_ids = None
        candidate1.flow_id = None
        candidate1.fixed_reply = None
        candidate1.transfer_message = None
        candidate2 = MagicMock()
        candidate2.id = "intent-2"
        candidate2.name = "换货意图"
        candidate2.response_type = "flow"
        candidate2.target_kb_ids = None
        candidate2.flow_id = None
        candidate2.fixed_reply = None
        candidate2.transfer_message = None
        mock_fusion.clarify_candidates = [candidate1, candidate2]
        result = HybridIntentResult.from_fusion_result(mock_fusion)
        assert result.need_clarify is True
        # "multi_intent" decision_reason maps onto the MULTI_INTENT enum value.
        assert result.clarify_reason == ClarifyReason.MULTI_INTENT
        assert len(result.candidates) == 2
class TestClarifyState:
    """Tests for ClarifyState serialization and retry bookkeeping."""

    def test_to_dict(self):
        """to_dict() emits the reason value, retry count and candidates."""
        option = IntentCandidate(
            intent_id="intent-1",
            intent_name="退货意图",
            confidence=0.5,
        )
        snapshot = ClarifyState(
            reason=ClarifyReason.INTENT_AMBIGUITY,
            asked_slot=None,
            retry_count=1,
            candidates=[option],
            asked_intent_ids=["intent-1"],
        ).to_dict()
        assert snapshot["reason"] == "intent_ambiguity"
        assert snapshot["retry_count"] == 1
        assert len(snapshot["candidates"]) == 1

    def test_increment_retry(self):
        """Each increment_retry() call adds exactly one."""
        state = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE)
        for expected_count in (1, 2):
            state.increment_retry()
            assert state.retry_count == expected_count

    def test_is_max_retry(self):
        """is_max_retry() flips once retry_count reaches the cap."""
        state = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE)
        assert not state.is_max_retry()
        state.retry_count = MAX_CLARIFY_RETRY
        assert state.is_max_retry()
class TestClarificationEngine:
    """Tests for ClarificationEngine: confidence fusion, hard blocks,
    clarify triggering, prompt generation and clarify-response handling.

    NOTE(review): the engine records into a shared metrics singleton, so
    tests that touch it call get_clarify_metrics().reset() first.
    """

    def test_compute_confidence_rule_only(self):
        """With full rule weight only, confidence equals the rule score."""
        engine = ClarificationEngine()
        confidence = engine.compute_confidence(
            rule_score=1.0,
            semantic_score=0.0,
            llm_score=0.0,
            w_rule=1.0,
            w_semantic=0.0,
            w_llm=0.0,
        )
        assert confidence == 1.0

    def test_compute_confidence_semantic_only(self):
        """Only the semantic score contributes when the others are zero."""
        engine = ClarificationEngine()
        confidence = engine.compute_confidence(
            rule_score=0.0,
            semantic_score=0.8,
            llm_score=0.0,
            w_rule=0.3,
            w_semantic=0.5,
            w_llm=0.2,
        )
        # With weights w_rule=0.3, w_semantic=0.5, w_llm=0.2 and scores
        # rule=0.0, semantic=0.8, llm=0.0:
        # confidence = (0.0*0.3 + 0.8*0.5 + 0.0*0.2) / (0.3+0.5+0.2) = 0.4/1.0 = 0.4
        assert confidence == 0.4

    def test_compute_confidence_weighted(self):
        """Confidence is the weight-normalized sum of all three scores."""
        engine = ClarificationEngine()
        confidence = engine.compute_confidence(
            rule_score=1.0,
            semantic_score=0.8,
            llm_score=0.9,
            w_rule=0.5,
            w_semantic=0.3,
            w_llm=0.2,
        )
        expected = (1.0 * 0.5 + 0.8 * 0.3 + 0.9 * 0.2) / 1.0
        # Float tolerance rather than exact equality.
        assert abs(confidence - expected) < 0.001

    def test_check_hard_block_low_confidence(self):
        """A result with no intent and low confidence is hard-blocked."""
        engine = ClarificationEngine()
        result = HybridIntentResult(
            intent=None,
            confidence=0.5,
            candidates=[],
        )
        is_blocked, reason = engine.check_hard_block(result)
        assert is_blocked is True
        assert reason == ClarifyReason.LOW_CONFIDENCE

    def test_check_hard_block_high_confidence(self):
        """A confident result with an intent passes the hard-block check."""
        engine = ClarificationEngine()
        result = HybridIntentResult(
            intent=IntentCandidate(
                intent_id="intent-1",
                intent_name="退货意图",
                confidence=0.85,
            ),
            confidence=0.85,
            candidates=[],
        )
        is_blocked, reason = engine.check_hard_block(result)
        assert is_blocked is False
        assert reason is None

    def test_check_hard_block_missing_slots(self):
        """A confident intent is still blocked when a required slot is unfilled."""
        engine = ClarificationEngine()
        result = HybridIntentResult(
            intent=IntentCandidate(
                intent_id="intent-1",
                intent_name="退货意图",
                confidence=0.85,
            ),
            confidence=0.85,
            candidates=[],
        )
        # product_id is required but absent from filled_slots.
        is_blocked, reason = engine.check_hard_block(
            result,
            required_slots=["order_id", "product_id"],
            filled_slots={"order_id": "123"},
        )
        assert is_blocked is True
        assert reason == ClarifyReason.MISSING_SLOT

    def test_should_trigger_clarify_below_t_low(self):
        """Confidence below T_LOW forces a LOW_CONFIDENCE clarify."""
        engine = ClarificationEngine()
        get_clarify_metrics().reset()
        result = HybridIntentResult(
            intent=None,
            confidence=0.3,
            candidates=[],
        )
        should_clarify, state = engine.should_trigger_clarify(result)
        assert should_clarify is True
        assert state is not None
        assert state.reason == ClarifyReason.LOW_CONFIDENCE

    def test_should_trigger_clarify_gray_zone(self):
        """A result pre-flagged ambiguous in the gray zone keeps its reason."""
        engine = ClarificationEngine()
        get_clarify_metrics().reset()
        candidate = IntentCandidate(
            intent_id="intent-1",
            intent_name="退货意图",
            confidence=0.5,
        )
        result = HybridIntentResult(
            intent=candidate,
            confidence=0.5,
            candidates=[candidate],
            need_clarify=True,
            clarify_reason=ClarifyReason.INTENT_AMBIGUITY,
        )
        should_clarify, state = engine.should_trigger_clarify(result)
        assert should_clarify is True
        assert state is not None
        assert state.reason == ClarifyReason.INTENT_AMBIGUITY

    def test_should_trigger_clarify_above_t_high(self):
        """High-confidence results never trigger clarification."""
        engine = ClarificationEngine()
        get_clarify_metrics().reset()
        candidate = IntentCandidate(
            intent_id="intent-1",
            intent_name="退货意图",
            confidence=0.85,
        )
        result = HybridIntentResult(
            intent=candidate,
            confidence=0.85,
            candidates=[candidate],
        )
        should_clarify, state = engine.should_trigger_clarify(result)
        assert should_clarify is False
        assert state is None

    def test_generate_clarify_prompt_missing_slot(self):
        """A missing-slot prompt references the slot or asks for related info."""
        engine = ClarificationEngine()
        state = ClarifyState(
            reason=ClarifyReason.MISSING_SLOT,
            asked_slot="order_id",
        )
        prompt = engine.generate_clarify_prompt(state)
        assert "order_id" in prompt or "相关信息" in prompt

    def test_generate_clarify_prompt_low_confidence(self):
        """A low-confidence prompt asks the user to elaborate."""
        engine = ClarificationEngine()
        state = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE)
        prompt = engine.generate_clarify_prompt(state)
        assert "理解" in prompt or "详细" in prompt

    def test_generate_clarify_prompt_multi_intent(self):
        """A multi-intent prompt lists each candidate intent by name."""
        engine = ClarificationEngine()
        candidates = [
            IntentCandidate(intent_id="1", intent_name="退货", confidence=0.5),
            IntentCandidate(intent_id="2", intent_name="换货", confidence=0.4),
        ]
        state = ClarifyState(
            reason=ClarifyReason.MULTI_INTENT,
            candidates=candidates,
        )
        prompt = engine.generate_clarify_prompt(state)
        assert "退货" in prompt
        assert "换货" in prompt

    def test_process_clarify_response_max_retry(self):
        """At the retry cap the engine gives up: empty result, no more clarify."""
        engine = ClarificationEngine()
        get_clarify_metrics().reset()
        state = ClarifyState(
            reason=ClarifyReason.LOW_CONFIDENCE,
            retry_count=MAX_CLARIFY_RETRY,
        )
        result = engine.process_clarify_response("用户回复", state)
        assert result.intent is None
        assert result.confidence == 0.0
        assert result.need_clarify is False

    def test_process_clarify_response_missing_slot(self):
        """A reply to a missing-slot ask resolves to the pending candidate."""
        engine = ClarificationEngine()
        get_clarify_metrics().reset()
        candidate = IntentCandidate(
            intent_id="intent-1",
            intent_name="退货意图",
            confidence=0.8,
        )
        state = ClarifyState(
            reason=ClarifyReason.MISSING_SLOT,
            asked_slot="order_id",
            candidates=[candidate],
        )
        result = engine.process_clarify_response("订单号是123", state)
        assert result.intent is not None
        assert result.need_clarify is False

    def test_get_metrics(self):
        """get_metrics() surfaces the shared metrics counters."""
        engine = ClarificationEngine()
        get_clarify_metrics().reset()
        engine._metrics.record_clarify_trigger()
        engine._metrics.record_clarify_converge()
        metrics = engine.get_metrics()
        assert metrics["clarify_trigger_rate"] == 1
        assert metrics["clarify_converge_rate"] == 1
class TestClarifySessionManager:
    """Tests for the process-global clarify session store."""

    def test_set_and_get_session(self):
        """A stored state is returned intact for its session id."""
        ClarifySessionManager.clear_session("test-session")
        stored = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE)
        ClarifySessionManager.set_session("test-session", stored)
        fetched = ClarifySessionManager.get_session("test-session")
        assert fetched is not None
        assert fetched.reason == ClarifyReason.LOW_CONFIDENCE

    def test_clear_session(self):
        """clear_session() removes any stored state for the id."""
        ClarifySessionManager.set_session(
            "test-session",
            ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE),
        )
        ClarifySessionManager.clear_session("test-session")
        assert ClarifySessionManager.get_session("test-session") is None

    def test_has_active_clarify(self):
        """A session counts as active only while below the retry cap."""
        ClarifySessionManager.clear_session("test-session")
        assert not ClarifySessionManager.has_active_clarify("test-session")
        live = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE)
        ClarifySessionManager.set_session("test-session", live)
        assert ClarifySessionManager.has_active_clarify("test-session")
        # Exhausting retries deactivates the stored session in place.
        live.retry_count = MAX_CLARIFY_RETRY
        assert not ClarifySessionManager.has_active_clarify("test-session")
class TestThresholds:
    """Pin the clarification threshold constants to their contract values."""

    def test_t_high_value(self):
        assert T_HIGH == 0.75

    def test_t_low_value(self):
        assert T_LOW == 0.45

    def test_t_high_greater_than_t_low(self):
        # The gray zone between the thresholds must be a valid interval.
        assert T_LOW < T_HIGH

    def test_max_retry_value(self):
        assert MAX_CLARIFY_RETRY == 3

View File

@ -0,0 +1,308 @@
"""
Tests for Dialogue API with Slot State Integration.
[AC-MRS-SLOT-META-03] 对话 API 与槽位状态集成测试
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.api.mid.dialogue import _generate_ask_back_for_missing_slots
from app.models.mid.schemas import ExecutionMode, Segment, TraceInfo
from app.services.mid.slot_state_aggregator import SlotState
class TestGenerateAskBackResponse:
    """Tests for generating ask-back (follow-up question) responses."""

    @pytest.mark.asyncio
    async def test_generate_ask_back_with_prompt(self):
        """A configured ask_back_prompt is returned verbatim."""
        slot_state = SlotState()
        missing_slots = [
            {
                "slot_key": "region",
                "label": "地区",
                "ask_back_prompt": "请问您在哪个地区?",
            }
        ]
        mock_session = AsyncMock()
        response = await _generate_ask_back_for_missing_slots(
            slot_state=slot_state,
            missing_slots=missing_slots,
            session=mock_session,
            tenant_id="test_tenant",
        )
        assert response == "请问您在哪个地区?"

    @pytest.mark.asyncio
    async def test_generate_ask_back_generic(self):
        """Without ask_back_prompt, a generic template embeds the slot label."""
        slot_state = SlotState()
        missing_slots = [
            {
                "slot_key": "product_line",
                "label": "产品线",
                # no ask_back_prompt configured for this slot
            }
        ]
        mock_session = AsyncMock()
        response = await _generate_ask_back_for_missing_slots(
            slot_state=slot_state,
            missing_slots=missing_slots,
            session=mock_session,
            tenant_id="test_tenant",
        )
        assert "产品线" in response

    @pytest.mark.asyncio
    async def test_generate_ask_back_empty_slots(self):
        """An empty missing-slot list falls back to a generic more-info ask."""
        slot_state = SlotState()
        missing_slots = []
        mock_session = AsyncMock()
        response = await _generate_ask_back_for_missing_slots(
            slot_state=slot_state,
            missing_slots=missing_slots,
            session=mock_session,
            tenant_id="test_tenant",
        )
        assert "更多信息" in response
class TestDialogueAskBackResponse:
    """Structure of dialogue responses that ask the user a follow-up question."""

    def test_dialogue_response_with_ask_back(self):
        """An ask-back response carries one segment plus fallback trace metadata."""
        from app.models.mid.schemas import DialogueResponse

        trace = TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id="test_request_id",
            generation_id="test_generation_id",
            fallback_reason_code="missing_required_slots",
            kb_tool_called=True,
            kb_hit=False,
        )
        response = DialogueResponse(
            segments=[Segment(text="请问您咨询的是哪个产品线?", delay_after=0)],
            trace=trace,
        )
        assert len(response.segments) == 1
        assert "哪个产品线" in response.segments[0].text
        assert response.trace.fallback_reason_code == "missing_required_slots"
        assert response.trace.kb_tool_called is True
        assert response.trace.kb_hit is False
class TestSlotStateAggregationFlow:
    """Tests for the slot-state aggregation flow."""

    @pytest.mark.asyncio
    async def test_memory_slots_included_in_state(self):
        """Slots recalled from memory are included in the aggregated state."""
        from app.models.mid.schemas import MemorySlot, SlotSource
        from app.services.mid.slot_state_aggregator import SlotStateAggregator
        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )
        memory_slots = {
            "product_line": MemorySlot(
                key="product_line",
                value="vip_course",
                source=SlotSource.USER_CONFIRMED,
                confidence=1.0,
            )
        }
        # No slot definitions configured: aggregation uses memory slots only.
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[],
        ):
            state = await aggregator.aggregate(
                memory_slots=memory_slots,
                current_input_slots=None,
                context=None,
            )
        assert "product_line" in state.filled_slots
        assert state.filled_slots["product_line"] == "vip_course"
        # The source enum value is surfaced as its string form.
        assert state.slot_sources["product_line"] == "user_confirmed"

    @pytest.mark.asyncio
    async def test_missing_slots_identified(self):
        """Required slot definitions without values are reported as missing."""
        from unittest.mock import MagicMock
        from app.services.mid.slot_state_aggregator import SlotStateAggregator
        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )
        # Stub one required slot definition with an ask-back prompt.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "region"
        mock_slot_def.required = True
        mock_slot_def.ask_back_prompt = "请问您在哪个地区?"
        mock_slot_def.linked_field_id = None
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ):
            state = await aggregator.aggregate(
                memory_slots={},
                current_input_slots=None,
                context=None,
            )
        assert len(state.missing_required_slots) == 1
        assert state.missing_required_slots[0]["slot_key"] == "region"
        assert state.missing_required_slots[0]["ask_back_prompt"] == "请问您在哪个地区?"
class TestSlotMetadataLinkage:
    """Tests for the linkage between slots and metadata fields."""

    @pytest.mark.asyncio
    async def test_slot_to_field_mapping(self):
        """A slot with linked_field_id is mapped to its metadata field key."""
        from unittest.mock import MagicMock, patch
        from app.services.mid.slot_state_aggregator import SlotStateAggregator
        from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )
        # Slot definition stub carrying a linked_field_id.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "product"
        mock_slot_def.linked_field_id = "field-uuid-123"
        mock_slot_def.required = False
        mock_slot_def.type = "string"
        mock_slot_def.options = None
        # Linked metadata field stub resolved by its definition service.
        mock_field = MagicMock()
        mock_field.field_key = "product_line"
        mock_field.label = "产品线"
        mock_field.type = "string"
        mock_field.required = False
        mock_field.options = None
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ):
            with patch.object(
                MetadataFieldDefinitionService,
                "get_field_definition",
                return_value=mock_field,
            ):
                state = await aggregator.aggregate(
                    memory_slots={},
                    current_input_slots=None,
                    context=None,
                )
        # The slot key must resolve to the linked field's field_key.
        assert state.slot_to_field_map.get("product") == "product_line"
class TestBackwardCompatibility:
    """Backward-compatibility tests for KB search without slot state."""

    @pytest.mark.asyncio
    async def test_kb_search_without_slot_state(self):
        """KB retrieval still works when slot_state is not supplied."""
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicConfig,
            KbSearchDynamicTool,
        )
        from app.services.mid.metadata_filter_builder import MetadataFilterBuilder
        mock_session = AsyncMock()
        kb_tool = KbSearchDynamicTool(
            session=mock_session,
            config=KbSearchDynamicConfig(enabled=True),
        )
        # Filter builder yields no filterable fields; retrieval returns nothing.
        with patch.object(
            MetadataFilterBuilder,
            "_get_filterable_fields",
            return_value=[],
        ):
            with patch.object(
                kb_tool,
                "_retrieve_with_timeout",
                return_value=[],
            ):
                result = await kb_tool.execute(
                    query="退款政策",
                    tenant_id="test_tenant",
                    context={},
                    slot_state=None,  # legacy call path: no slot_state provided
                )
        # Execution must succeed without any fallback reason.
        assert result.success is True
        assert result.fallback_reason_code is None

    @pytest.mark.asyncio
    async def test_legacy_context_filter(self):
        """A plain context dict is used directly as the applied filter."""
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicConfig,
            KbSearchDynamicTool,
        )
        from app.services.mid.metadata_filter_builder import MetadataFilterBuilder
        mock_session = AsyncMock()
        kb_tool = KbSearchDynamicTool(
            session=mock_session,
            config=KbSearchDynamicConfig(enabled=True),
        )
        # Legacy-style flat context dict, as older callers pass it.
        context = {"product_line": "vip_course", "region": "beijing"}
        with patch.object(
            MetadataFilterBuilder,
            "_get_filterable_fields",
            return_value=[],
        ):
            with patch.object(
                kb_tool,
                "_retrieve_with_timeout",
                return_value=[],
            ):
                result = await kb_tool.execute(
                    query="退款政策",
                    tenant_id="test_tenant",
                    context=context,
                    slot_state=None,
                )
        # Execution must succeed.
        assert result.success is True
        # The flat context should flow straight into the applied filter.
        assert result.applied_filter.get("product_line") == "vip_course"

View File

@ -0,0 +1,149 @@
"""
Tests for field_roles update functionality.
[AC-MRS-01] 验证字段角色更新功能
"""
import pytest
import uuid
from unittest.mock import AsyncMock, MagicMock
from app.models.entities import (
MetadataFieldDefinition,
MetadataFieldDefinitionUpdate,
MetadataFieldStatus,
)
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
class TestFieldRolesUpdate:
    """Tests for updating a metadata field's field_roles."""

    @pytest.fixture
    def mock_session(self):
        """Create a mock DB session with async execute/flush/commit."""
        session = MagicMock()
        session.execute = AsyncMock()
        session.flush = AsyncMock()
        session.commit = AsyncMock()
        return session

    @pytest.fixture
    def service(self, mock_session):
        """Create the service under test bound to the mock session."""
        return MetadataFieldDefinitionService(mock_session)

    @pytest.fixture
    def existing_field(self):
        """Create an existing field stub that already carries field_roles."""
        field = MagicMock(spec=MetadataFieldDefinition)
        field.id = uuid.uuid4()
        field.tenant_id = "test-tenant"
        field.field_key = "grade"
        field.label = "年级"
        field.type = "string"
        field.required = True
        field.options = None
        field.default_value = None
        field.scope = ["kb_document"]
        field.is_filterable = True
        field.is_rank_feature = False
        field.field_roles = ["slot"]  # initial role before any update
        field.status = MetadataFieldStatus.ACTIVE.value
        field.version = 1
        return field

    @pytest.mark.asyncio
    async def test_update_field_roles_success(self, service, mock_session, existing_field):
        """[AC-MRS-01] Updating field_roles succeeds and bumps the version."""
        # Mock get_field_definition to return existing field
        service.get_field_definition = AsyncMock(return_value=existing_field)
        # Create update request with new field_roles
        field_update = MetadataFieldDefinitionUpdate(
            field_roles=["slot", "resource_filter"]
        )
        # Execute update
        result = await service.update_field_definition(
            "test-tenant",
            str(existing_field.id),
            field_update
        )
        # Verify result
        assert result is not None
        assert result.field_roles == ["slot", "resource_filter"]
        assert result.version == 2  # Version should increment

    @pytest.mark.asyncio
    async def test_update_field_roles_to_empty(self, service, mock_session, existing_field):
        """[AC-MRS-01] field_roles can be cleared with an empty list."""
        service.get_field_definition = AsyncMock(return_value=existing_field)
        field_update = MetadataFieldDefinitionUpdate(
            field_roles=[]
        )
        result = await service.update_field_definition(
            "test-tenant",
            str(existing_field.id),
            field_update
        )
        assert result is not None
        assert result.field_roles == []

    @pytest.mark.asyncio
    async def test_update_field_roles_invalid_role(self, service, mock_session, existing_field):
        """[AC-MRS-01] An unknown role value is rejected with ValueError."""
        service.get_field_definition = AsyncMock(return_value=existing_field)
        field_update = MetadataFieldDefinitionUpdate(
            field_roles=["invalid_role"]
        )
        with pytest.raises(ValueError) as exc_info:
            await service.update_field_definition(
                "test-tenant",
                str(existing_field.id),
                field_update
            )
        # The service raises with a Chinese "invalid field role" message.
        assert "无效的字段角色" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_update_without_field_roles_unchanged(self, service, mock_session, existing_field):
        """[AC-MRS-01] Omitting field_roles from the update leaves them untouched."""
        service.get_field_definition = AsyncMock(return_value=existing_field)
        # Update only label, not field_roles
        field_update = MetadataFieldDefinitionUpdate(
            label="新年级标签"
        )
        result = await service.update_field_definition(
            "test-tenant",
            str(existing_field.id),
            field_update
        )
        assert result is not None
        assert result.label == "新年级标签"
        assert result.field_roles == ["slot"]  # Should remain unchanged

    @pytest.mark.asyncio
    async def test_update_field_roles_not_found(self, service, mock_session):
        """[AC-MRS-01] Updating a nonexistent field returns None."""
        service.get_field_definition = AsyncMock(return_value=None)
        field_update = MetadataFieldDefinitionUpdate(
            field_roles=["slot"]
        )
        result = await service.update_field_definition(
            "test-tenant",
            str(uuid.uuid4()),
            field_update
        )
        assert result is None

View File

@ -0,0 +1,408 @@
"""
Unit tests for FusionPolicy.
[AC-AISVC-115~AC-AISVC-117] Tests for fusion decision policy.
"""
import pytest
from unittest.mock import MagicMock
import uuid
from app.services.intent.models import (
FusionConfig,
FusionResult,
LlmJudgeResult,
RuleMatchResult,
SemanticCandidate,
SemanticMatchResult,
RouteTrace,
)
class FusionPolicy:
    """[AC-AISVC-115] Fusion decision policy.

    Combines rule-match, semantic-match and optional LLM-judge results into
    one routing decision, recording a RouteTrace for observability.

    NOTE(review): this class is defined inside the test module; if it mirrors
    a production FusionPolicy, the tests should import that one instead to
    avoid drift — confirm against app.services.intent.
    """

    # Ordered decision table: the first predicate that holds names the
    # decision. Predicates receive (rule_result, semantic_result, llm_result).
    DECISION_PRIORITY = [
        ("rule_high_confidence", lambda r, s, l: r.score == 1.0 and r.rule is not None),
        ("llm_judge", lambda r, s, l: l.triggered and l.intent_id is not None),
        ("semantic_override", lambda r, s, l: r.score == 0 and s.top_score > 0.7),
        # Parentheses make the candidate guard explicit; the previous form
        # relied on conditional-expression precedence over the `and` chain.
        ("rule_semantic_agree", lambda r, s, l: (r.score > 0 and s.top_score > 0.5 and r.rule_id == s.candidates[0].rule.id) if s.candidates else False),
        ("semantic_fallback", lambda r, s, l: s.top_score > 0.5),
        ("rule_fallback", lambda r, s, l: r.score > 0),
        ("no_match", lambda r, s, l: True),
    ]

    def __init__(self, config: FusionConfig):
        """Store the fusion weights/thresholds configuration."""
        self._config = config

    def fuse(
        self,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        llm_result: LlmJudgeResult | None,
    ) -> FusionResult:
        """Fuse the three matcher outputs into a final routing decision.

        Returns a FusionResult with the winning intent (or None), a fused
        confidence, the decision reason, clarify info and a full trace.
        """
        # Capture the raw matcher outputs in the trace before deciding.
        trace = RouteTrace(
            rule_match={
                "rule_id": str(rule_result.rule_id) if rule_result.rule_id else None,
                "match_type": rule_result.match_type,
                "matched_text": rule_result.matched_text,
                "score": rule_result.score,
                "duration_ms": rule_result.duration_ms,
            },
            semantic_match={
                "top_candidates": [
                    {"rule_id": str(c.rule.id), "name": c.rule.name, "score": c.score}
                    for c in semantic_result.candidates
                ],
                "top_score": semantic_result.top_score,
                "duration_ms": semantic_result.duration_ms,
                "skipped": semantic_result.skipped,
                "skip_reason": semantic_result.skip_reason,
            },
            llm_judge={
                "triggered": llm_result.triggered if llm_result else False,
                "intent_id": llm_result.intent_id if llm_result else None,
                "score": llm_result.score if llm_result else 0.0,
                "duration_ms": llm_result.duration_ms if llm_result else 0,
                "tokens_used": llm_result.tokens_used if llm_result else 0,
            },
            fusion={},
        )
        final_intent = None
        final_confidence = 0.0
        decision_reason = "no_match"
        # Walk the priority table; the first matching predicate wins.
        for reason, condition in self.DECISION_PRIORITY:
            if condition(rule_result, semantic_result, llm_result or LlmJudgeResult.empty()):
                decision_reason = reason
                break
        if decision_reason == "rule_high_confidence":
            final_intent = rule_result.rule
            final_confidence = 1.0
        elif decision_reason == "llm_judge" and llm_result:
            final_intent = self._find_rule_by_id(llm_result.intent_id, rule_result, semantic_result)
            final_confidence = llm_result.score
        elif decision_reason == "semantic_override":
            final_intent = semantic_result.candidates[0].rule
            final_confidence = semantic_result.top_score
        elif decision_reason == "rule_semantic_agree":
            final_intent = rule_result.rule
            final_confidence = self._calculate_weighted_confidence(rule_result, semantic_result, llm_result)
        elif decision_reason == "semantic_fallback":
            final_intent = semantic_result.candidates[0].rule
            final_confidence = semantic_result.top_score
        elif decision_reason == "rule_fallback":
            final_intent = rule_result.rule
            final_confidence = rule_result.score
        # Below the clarify threshold we ask the user to disambiguate,
        # offering up to three semantic candidates when several exist.
        need_clarify = final_confidence < self._config.clarify_threshold
        clarify_candidates = None
        if need_clarify and len(semantic_result.candidates) > 1:
            clarify_candidates = [c.rule for c in semantic_result.candidates[:3]]
        trace.fusion = {
            "weights": {
                "w_rule": self._config.w_rule,
                "w_semantic": self._config.w_semantic,
                "w_llm": self._config.w_llm,
            },
            "final_confidence": final_confidence,
            "decision_reason": decision_reason,
        }
        return FusionResult(
            final_intent=final_intent,
            final_confidence=final_confidence,
            decision_reason=decision_reason,
            need_clarify=need_clarify,
            clarify_candidates=clarify_candidates,
            trace=trace,
        )

    def _calculate_weighted_confidence(
        self,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
        llm_result: LlmJudgeResult | None,
    ) -> float:
        """Weighted average of the matcher scores, clamped to [0, 1].

        The LLM weight only contributes when the judge actually triggered.
        """
        rule_score = rule_result.score
        semantic_score = semantic_result.top_score if not semantic_result.skipped else 0.0
        llm_score = llm_result.score if llm_result and llm_result.triggered else 0.0
        total_weight = self._config.w_rule + self._config.w_semantic
        if llm_result and llm_result.triggered:
            total_weight += self._config.w_llm
        if total_weight <= 0:
            # All contributing weights disabled: no basis for confidence
            # (previously this raised ZeroDivisionError).
            return 0.0
        confidence = (
            self._config.w_rule * rule_score +
            self._config.w_semantic * semantic_score +
            self._config.w_llm * llm_score
        ) / total_weight
        return min(1.0, max(0.0, confidence))

    def _find_rule_by_id(
        self,
        intent_id: str | None,
        rule_result: RuleMatchResult,
        semantic_result: SemanticMatchResult,
    ):
        """Resolve an intent id to a rule from the rule or semantic results."""
        if not intent_id:
            return None
        if rule_result.rule_id and str(rule_result.rule_id) == intent_id:
            return rule_result.rule
        for candidate in semantic_result.candidates:
            if str(candidate.rule.id) == intent_id:
                return candidate.rule
        return None
@pytest.fixture
def config():
    """Default FusionConfig with its out-of-the-box weights and thresholds."""
    return FusionConfig()
@pytest.fixture
def mock_rule():
    """Stub intent rule carrying just the attributes FusionPolicy reads."""
    stub = MagicMock()
    stub.id = uuid.uuid4()
    stub.name = "Test Intent"
    stub.response_type = "rag"
    return stub
class TestFusionPolicy:
    """Tests for FusionPolicy class."""

    def test_init(self, config):
        """Test FusionPolicy initialization."""
        policy = FusionPolicy(config)
        assert policy._config == config

    def test_fuse_rule_high_confidence(self, config, mock_rule):
        """Test fusion with rule high confidence."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[],
            top_score=0.0,
            duration_ms=50,
            skipped=True,
            skip_reason="no_semantic_config",
        )
        result = policy.fuse(rule_result, semantic_result, None)
        # A perfect rule score wins outright with full confidence.
        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rule
        assert result.final_confidence == 1.0
        assert result.need_clarify is False

    def test_fuse_llm_judge(self, config, mock_rule):
        """Test fusion with LLM judge result."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.5)],
            top_score=0.5,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        llm_result = LlmJudgeResult(
            intent_id=str(mock_rule.id),
            intent_name="Test Intent",
            score=0.85,
            reasoning="Test reasoning",
            duration_ms=500,
            tokens_used=100,
            triggered=True,
        )
        result = policy.fuse(rule_result, semantic_result, llm_result)
        # Triggered LLM judge outranks semantic fallback; intent resolved
        # by id from the semantic candidates.
        assert result.decision_reason == "llm_judge"
        assert result.final_intent == mock_rule
        assert result.final_confidence == 0.85

    def test_fuse_semantic_override(self, config, mock_rule):
        """Test fusion with semantic override."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.85)],
            top_score=0.85,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        result = policy.fuse(rule_result, semantic_result, None)
        # No rule hit + semantic top score above 0.7 -> semantic overrides.
        assert result.decision_reason == "semantic_override"
        assert result.final_intent == mock_rule
        assert result.final_confidence == 0.85

    def test_fuse_rule_semantic_agree(self, config, mock_rule):
        """Test fusion when rule and semantic agree."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        result = policy.fuse(rule_result, semantic_result, None)
        # NOTE(review): rule score 1.0 short-circuits at the higher-priority
        # "rule_high_confidence" entry, so the agree branch is never reached
        # here; a rule score in (0, 1) would be needed to exercise it.
        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rule

    def test_fuse_no_match(self, config):
        """Test fusion with no match."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[],
            top_score=0.0,
            duration_ms=50,
            skipped=True,
            skip_reason="no_semantic_config",
        )
        result = policy.fuse(rule_result, semantic_result, None)
        assert result.decision_reason == "no_match"
        assert result.final_intent is None
        assert result.final_confidence == 0.0

    def test_fuse_need_clarify(self, config, mock_rule):
        """Test fusion with clarify needed."""
        policy = FusionPolicy(config)
        other_rule = MagicMock()
        other_rule.id = uuid.uuid4()
        other_rule.name = "Other Intent"
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        # Two weak, close candidates: confidence lands below the clarify
        # threshold, so both should be offered back to the user.
        semantic_result = SemanticMatchResult(
            candidates=[
                SemanticCandidate(rule=mock_rule, score=0.35),
                SemanticCandidate(rule=other_rule, score=0.30),
            ],
            top_score=0.35,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        result = policy.fuse(rule_result, semantic_result, None)
        assert result.need_clarify is True
        assert result.clarify_candidates is not None
        assert len(result.clarify_candidates) == 2

    def test_calculate_weighted_confidence(self, config, mock_rule):
        """Test weighted confidence calculation."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        confidence = policy._calculate_weighted_confidence(rule_result, semantic_result, None)
        # Assumes FusionConfig defaults w_rule=0.5, w_semantic=0.3 — TODO
        # confirm against the FusionConfig definition.
        expected = (0.5 * 1.0 + 0.3 * 0.8) / (0.5 + 0.3)
        assert abs(confidence - expected) < 0.01

    def test_trace_generation(self, config, mock_rule):
        """Test that trace is properly generated."""
        policy = FusionPolicy(config)
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.8)],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        result = policy.fuse(rule_result, semantic_result, None)
        # Each stage of the trace reflects the corresponding matcher output.
        assert result.trace is not None
        assert result.trace.rule_match["rule_id"] == str(mock_rule.id)
        assert result.trace.semantic_match["top_score"] == 0.8
        assert result.trace.fusion["decision_reason"] == "rule_high_confidence"

View File

@ -0,0 +1,142 @@
"""
Tests for intent router.
"""
import uuid
import pytest
from app.services.intent.router import IntentRouter, RuleMatcher
from app.services.intent.models import (
FusionConfig,
RuleMatchResult,
SemanticMatchResult,
LlmJudgeResult,
FusionResult,
)
from app.models.entities import IntentRule
class TestRuleMatcher:
    """Test RuleMatcher basic functionality."""

    def test_match_empty_message(self):
        """An empty message yields a no-match result."""
        matcher = RuleMatcher()
        result = matcher.match("", [])
        assert result.score == 0.0
        assert result.rule is None
        assert result.duration_ms >= 0

    def test_match_empty_rules(self):
        """An empty rule list yields a no-match result."""
        matcher = RuleMatcher()
        result = matcher.match("test message", [])
        assert result.score == 0.0
        assert result.rule is None

    def test_match_keyword(self):
        """A keyword hit returns score 1.0 with the matched keyword.

        NOTE(review): renamed from ``test_match_empty_rules`` — the old name
        contradicted the body, which matches a single keyword rule.
        """
        matcher = RuleMatcher()
        rule = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="Test Rule",
            keywords=["test", "demo"],
            is_enabled=True,
        )
        result = matcher.match("test message", [rule])
        assert result.score == 1.0
        assert result.rule == rule
        assert result.match_type == "keyword"
        assert result.matched_text == "test"

    def test_match_regex(self):
        """A regex pattern hit reports match_type 'regex' and the matched span."""
        matcher = RuleMatcher()
        rule = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="Test Regex Rule",
            patterns=[r"test.*pattern"],
            is_enabled=True,
        )
        result = matcher.match("this is a test regex pattern", [rule])
        assert result.score == 1.0
        assert result.rule == rule
        assert result.match_type == "regex"
        assert "pattern" in result.matched_text

    def test_no_match(self):
        """A message hitting no keyword yields a no-match result."""
        matcher = RuleMatcher()
        rule = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="Test Rule",
            keywords=["specific", "keyword"],
            is_enabled=True,
        )
        result = matcher.match("no match here", [rule])
        assert result.score == 0.0
        assert result.rule is None

    def test_priority_order(self):
        """When several rules could match, the higher-priority rule wins."""
        matcher = RuleMatcher()
        rule1 = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="High Priority",
            keywords=["high"],
            priority=10,
            is_enabled=True,
        )
        rule2 = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="Low Priority",
            keywords=["low"],
            priority=1,
            is_enabled=True,
        )
        result = matcher.match("high priority message", [rule1, rule2])
        assert result.rule == rule1
        assert result.rule.name == "High Priority"

    def test_disabled_rule(self):
        """Disabled rules are skipped even when their keywords appear."""
        matcher = RuleMatcher()
        rule = IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name="Disabled Rule",
            keywords=["disabled"],
            is_enabled=False,
        )
        result = matcher.match("disabled message", [rule])
        assert result.score == 0.0
        assert result.rule is None
class TestIntentRouterBackwardCompatibility:
    """Legacy IntentRouter entry points keep their original contract."""

    @staticmethod
    def _keyword_rule(name, keywords):
        """Build an enabled keyword-only IntentRule for these tests."""
        return IntentRule(
            id=uuid.uuid4(),
            tenant_id="test_tenant",
            name=name,
            keywords=keywords,
            is_enabled=True,
        )

    def test_match_backward_compatible(self):
        """router.match() still returns the legacy result shape."""
        router = IntentRouter()
        greeting_rule = self._keyword_rule("Test Rule", ["hello", "hi"])
        outcome = router.match("hello world", [greeting_rule])
        assert outcome is not None
        assert outcome.rule.name == "Test Rule"
        assert outcome.match_type == "keyword"
        assert outcome.matched == "hello"

    def test_match_with_stats(self):
        """match_with_stats() returns the result plus the matched rule id."""
        router = IntentRouter()
        plain_rule = self._keyword_rule("Test Rule", ["test"])
        outcome, matched_id = router.match_with_stats("test message", [plain_rule])
        assert outcome is not None
        assert matched_id == str(plain_rule.id)

View File

@ -0,0 +1,468 @@
"""
Integration tests for IntentRouter.match_hybrid().
[AC-AISVC-111] Tests for hybrid routing integration.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import uuid
import asyncio
from app.services.intent.models import (
FusionConfig,
FusionResult,
LlmJudgeInput,
LlmJudgeResult,
RuleMatchResult,
SemanticCandidate,
SemanticMatchResult,
RouteTrace,
)
@pytest.fixture
def mock_embedding_provider():
    """Async embedding provider stub returning fixed 768-dim vectors."""
    stub = AsyncMock()
    stub.embed = AsyncMock(return_value=[0.1] * 768)
    stub.embed_batch = AsyncMock(return_value=[[0.1] * 768])
    return stub
@pytest.fixture
def mock_llm_client():
    """Bare async LLM client stub; tests attach behaviour as needed."""
    return AsyncMock()
@pytest.fixture
def config():
    """Default FusionConfig shared by most tests."""
    return FusionConfig()
@pytest.fixture
def mock_rule():
    """Primary mocked intent rule ('Return Intent') with a stored vector."""
    rule = MagicMock()
    # configure_mock is the documented way to set .name on a MagicMock
    # (the Mock() constructor treats name= specially).
    rule.configure_mock(
        id=uuid.uuid4(),
        name="Return Intent",
        response_type="rag",
        keywords=["退货", "退款"],
        patterns=[],
        intent_vector=[0.1] * 768,
        semantic_examples=None,
        is_enabled=True,
        priority=10,
    )
    return rule
@pytest.fixture
def mock_rules(mock_rule):
    """Pair of mocked intent rules: the primary one plus an 'Order Query' rule."""
    order_rule = MagicMock()
    order_rule.configure_mock(
        id=uuid.uuid4(),
        name="Order Query",
        response_type="rag",
        keywords=["订单", "查询"],
        patterns=[],
        intent_vector=[0.5] * 768,
        semantic_examples=None,
        is_enabled=True,
        priority=5,
    )
    return [mock_rule, order_rule]
class MockRuleMatcher:
    """Deterministic stand-in for RuleMatcher: first enabled keyword hit wins."""

    def match(self, message: str, rules: list) -> RuleMatchResult:
        import time

        started = time.time()
        haystack = message.lower()

        def elapsed_ms() -> int:
            return int((time.time() - started) * 1000)

        for candidate in rules:
            if not candidate.is_enabled:
                continue
            hit = next(
                (kw for kw in (candidate.keywords or []) if kw.lower() in haystack),
                None,
            )
            if hit is not None:
                # Keyword hits are treated as exact matches (score 1.0).
                return RuleMatchResult(
                    rule_id=candidate.id,
                    rule=candidate,
                    match_type="keyword",
                    matched_text=hit,
                    score=1.0,
                    duration_ms=elapsed_ms(),
                )
        return RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=elapsed_ms(),
        )
class MockSemanticMatcher:
    """Stand-in semantic matcher: the first rule with a vector scores 0.85."""

    def __init__(self, config):
        self._config = config

    async def match(self, message: str, rules: list, tenant_id: str, top_k: int = 3) -> SemanticMatchResult:
        import time

        started = time.time()
        if not self._config.semantic_matcher_enabled:
            # Feature flag off: report a skipped run with no candidates.
            return SemanticMatchResult(
                candidates=[],
                top_score=0.0,
                duration_ms=0,
                skipped=True,
                skip_reason="disabled",
            )
        first_vectorised = next((r for r in rules if r.intent_vector), None)
        found = (
            [SemanticCandidate(rule=first_vectorised, score=0.85)]
            if first_vectorised is not None
            else []
        )
        return SemanticMatchResult(
            candidates=found[:top_k],
            top_score=found[0].score if found else 0.0,
            duration_ms=int((time.time() - started) * 1000),
            skipped=False,
            skip_reason=None,
        )
class MockLlmJudge:
    """Stand-in LLM judge mirroring the real trigger rules with a fixed verdict."""

    def __init__(self, config):
        self._config = config

    def should_trigger(self, rule_result, semantic_result, config=None) -> tuple:
        cfg = config or self._config
        if not cfg.llm_judge_enabled:
            return False, "disabled"
        # Conflict: rule and semantic legs both fired but point at different
        # rules with scores close enough to be ambiguous.
        if (
            rule_result.score > 0
            and semantic_result.top_score > 0
            and semantic_result.candidates
            and rule_result.rule_id != semantic_result.candidates[0].rule.id
            and abs(rule_result.score - semantic_result.top_score) < cfg.conflict_threshold
        ):
            return True, "rule_semantic_conflict"
        # Gray zone: the best score is neither a clear hit nor a clear miss.
        best = max(rule_result.score, semantic_result.top_score)
        if cfg.min_trigger_threshold < best < cfg.gray_zone_threshold:
            return True, "gray_zone"
        return False, ""

    async def judge(self, input_data: LlmJudgeInput, tenant_id: str) -> LlmJudgeResult:
        top = input_data.candidates[0] if input_data.candidates else None
        return LlmJudgeResult(
            intent_id=top["id"] if top else None,
            intent_name=top["name"] if top else None,
            score=0.9,
            reasoning="Test arbitration",
            duration_ms=500,
            tokens_used=100,
            triggered=True,
        )
class MockFusionPolicy:
    """Stand-in fusion policy: first matching entry of DECISION_PRIORITY wins."""

    # Ordered (reason, predicate) table; the final entry is a catch-all.
    DECISION_PRIORITY = [
        ("rule_high_confidence", lambda r, s, l: r.score == 1.0 and r.rule is not None),
        ("llm_judge", lambda r, s, l: l.triggered and l.intent_id is not None),
        ("semantic_override", lambda r, s, l: r.score == 0 and s.top_score > 0.7),
        ("no_match", lambda r, s, l: True),
    ]

    def __init__(self, config):
        self._config = config

    def fuse(self, rule_result, semantic_result, llm_result) -> FusionResult:
        # Evaluate the priority table against an always-present LLM view.
        llm_view = llm_result or LlmJudgeResult.empty()
        decision_reason = next(
            (
                reason
                for reason, check in self.DECISION_PRIORITY
                if check(rule_result, semantic_result, llm_view)
            ),
            "no_match",
        )
        chosen, confidence = None, 0.0
        if decision_reason == "rule_high_confidence":
            chosen, confidence = rule_result.rule, 1.0
        elif decision_reason == "llm_judge" and llm_result:
            chosen = self._find_rule_by_id(llm_result.intent_id, rule_result, semantic_result)
            confidence = llm_result.score
        elif decision_reason == "semantic_override":
            chosen = semantic_result.candidates[0].rule
            confidence = semantic_result.top_score
        return FusionResult(
            final_intent=chosen,
            final_confidence=confidence,
            decision_reason=decision_reason,
            need_clarify=confidence < 0.4,
            clarify_candidates=None,
            trace=RouteTrace(),
        )

    def _find_rule_by_id(self, intent_id, rule_result, semantic_result):
        """Resolve an intent id against the rule hit, then semantic candidates."""
        if not intent_id:
            return None
        if rule_result.rule_id and str(rule_result.rule_id) == intent_id:
            return rule_result.rule
        return next(
            (c.rule for c in semantic_result.candidates if str(c.rule.id) == intent_id),
            None,
        )
class MockIntentRouter:
    """Minimal IntentRouter double wiring the four mocked stages together."""

    def __init__(self, rule_matcher, semantic_matcher, llm_judge, fusion_policy, config=None):
        self._rule_matcher = rule_matcher
        self._semantic_matcher = semantic_matcher
        self._llm_judge = llm_judge
        self._fusion_policy = fusion_policy
        self._config = config or FusionConfig()

    async def match_hybrid(
        self,
        message: str,
        rules: list,
        tenant_id: str,
        config: FusionConfig | None = None,
    ) -> FusionResult:
        active_config = config or self._config
        # Rule matching is synchronous, so push it to a thread while the
        # semantic matcher awaits; both legs run concurrently.
        rule_result, semantic_result = await asyncio.gather(
            asyncio.to_thread(self._rule_matcher.match, message, rules),
            self._semantic_matcher.match(message, rules, tenant_id),
        )
        llm_result = None
        triggered, trigger_reason = self._llm_judge.should_trigger(
            rule_result, semantic_result, active_config
        )
        if triggered:
            llm_result = await self._llm_judge.judge(
                LlmJudgeInput(
                    message=message,
                    candidates=self._build_llm_candidates(rule_result, semantic_result),
                    conflict_type=trigger_reason,
                ),
                tenant_id,
            )
        return self._fusion_policy.fuse(rule_result, semantic_result, llm_result)

    def _build_llm_candidates(self, rule_result, semantic_result) -> list:
        """Merge the rule hit and top-3 semantic hits into unique candidate dicts."""
        merged = []
        if rule_result.rule:
            merged.append({
                "id": str(rule_result.rule_id),
                "name": rule_result.rule.name,
                "description": f"匹配方式: {rule_result.match_type}",
            })
        seen = {entry["id"] for entry in merged}
        for cand in semantic_result.candidates[:3]:
            cand_id = str(cand.rule.id)
            if cand_id in seen:
                continue
            seen.add(cand_id)
            merged.append({
                "id": cand_id,
                "name": cand.rule.name,
                "description": f"语义相似度: {cand.score:.2f}",
            })
        return merged
class TestIntentRouterHybrid:
    """Tests for IntentRouter.match_hybrid() integration."""
    @pytest.mark.asyncio
    async def test_match_hybrid_rule_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with rule match."""
        # "我想退货" contains the keyword "退货" of mock_rules[0], so the rule
        # matcher scores 1.0 and fusion short-circuits on rule_high_confidence.
        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)
        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )
        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")
        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rules[0]
        assert result.final_confidence == 1.0
    @pytest.mark.asyncio
    async def test_match_hybrid_semantic_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with semantic match only."""
        # No keyword matches "商品有问题", so only the semantic leg fires; the
        # mock scores 0.85, above the 0.7 semantic_override threshold.
        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)
        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )
        result = await router.match_hybrid("商品有问题", mock_rules, "tenant-1")
        assert result.decision_reason == "semantic_override"
        assert result.final_intent is not None
        assert result.final_confidence > 0.7
    @pytest.mark.asyncio
    async def test_match_hybrid_parallel_execution(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test that rule and semantic matching run in parallel."""
        import time
        # The semantic leg sleeps 100 ms; if the two legs ran sequentially the
        # wall-clock time would exceed the 0.2 s budget asserted below.
        class SlowSemanticMatcher(MockSemanticMatcher):
            async def match(self, message, rules, tenant_id, top_k=3):
                await asyncio.sleep(0.1)
                return await super().match(message, rules, tenant_id, top_k)
        rule_matcher = MockRuleMatcher()
        semantic_matcher = SlowSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)
        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )
        start_time = time.time()
        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")
        elapsed = time.time() - start_time
        assert elapsed < 0.2
        assert result is not None
    @pytest.mark.asyncio
    async def test_match_hybrid_llm_judge_triggered(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with LLM judge triggered."""
        # Overrides the `config` fixture with a wider conflict threshold so a
        # rule/semantic disagreement can trigger arbitration.
        config = FusionConfig(conflict_threshold=0.3)
        # Force the top semantic candidate to be a DIFFERENT rule at 0.9 so the
        # rule hit (1.0) and semantic hit conflict.
        class ConflictSemanticMatcher(MockSemanticMatcher):
            async def match(self, message, rules, tenant_id, top_k=3):
                result = await super().match(message, rules, tenant_id, top_k)
                if result.candidates:
                    result.candidates[0] = SemanticCandidate(rule=rules[1], score=0.9)
                    result.top_score = 0.9
                return result
        rule_matcher = MockRuleMatcher()
        semantic_matcher = ConflictSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)
        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )
        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")
        assert result.decision_reason in ["rule_high_confidence", "llm_judge"]
    @pytest.mark.asyncio
    async def test_match_hybrid_no_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test hybrid routing with no match."""
        # Both legs come back empty, so fusion must fall through to no_match.
        class NoMatchSemanticMatcher(MockSemanticMatcher):
            async def match(self, message, rules, tenant_id, top_k=3):
                return SemanticMatchResult(
                    candidates=[],
                    top_score=0.0,
                    duration_ms=10,
                    skipped=True,
                    skip_reason="no_semantic_config",
                )
        rule_matcher = MockRuleMatcher()
        semantic_matcher = NoMatchSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)
        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )
        result = await router.match_hybrid("随便说说", mock_rules, "tenant-1")
        assert result.decision_reason == "no_match"
        assert result.final_intent is None
        assert result.final_confidence == 0.0
    @pytest.mark.asyncio
    async def test_match_hybrid_semantic_disabled(self, mock_embedding_provider, mock_llm_client, mock_rules):
        """Test hybrid routing with semantic matcher disabled."""
        # Rule matching alone should still resolve the intent.
        config = FusionConfig(semantic_matcher_enabled=False)
        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)
        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )
        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")
        assert result.decision_reason == "rule_high_confidence"
        assert result.final_intent == mock_rules[0]
    @pytest.mark.asyncio
    async def test_match_hybrid_llm_disabled(self, mock_embedding_provider, mock_llm_client, mock_rules):
        """Test hybrid routing with LLM judge disabled."""
        config = FusionConfig(llm_judge_enabled=False)
        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)
        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )
        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")
        assert result.decision_reason == "rule_high_confidence"
    @pytest.mark.asyncio
    async def test_match_hybrid_trace_generated(self, mock_embedding_provider, mock_llm_client, config, mock_rules):
        """Test that route trace is generated."""
        rule_matcher = MockRuleMatcher()
        semantic_matcher = MockSemanticMatcher(config)
        llm_judge = MockLlmJudge(config)
        fusion_policy = MockFusionPolicy(config)
        router = MockIntentRouter(
            rule_matcher, semantic_matcher, llm_judge, fusion_policy, config
        )
        result = await router.match_hybrid("我想退货", mock_rules, "tenant-1")
        assert result.trace is not None

View File

@ -0,0 +1,308 @@
"""
Tests for KB Search Dynamic Tool with Slot State Integration.
[AC-MRS-SLOT-META-02] KB 检索与槽位状态集成测试
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.models.mid.schemas import MemorySlot, SlotSource, ToolCallStatus
from app.services.mid.kb_search_dynamic_tool import (
KbSearchDynamicConfig,
KbSearchDynamicTool,
)
from app.services.mid.slot_state_aggregator import SlotState
class TestKbSearchDynamicWithSlotState:
    """Integration tests for the KB Search Dynamic Tool with slot state."""
    @pytest.fixture
    def mock_session(self):
        """Mocked async database session."""
        return AsyncMock()
    @pytest.fixture
    def kb_tool(self, mock_session):
        """KB tool instance under test, with a small fixed config."""
        return KbSearchDynamicTool(
            session=mock_session,
            config=KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=10000,
                min_score_threshold=0.5,
            ),
        )
    @pytest.mark.asyncio
    async def test_execute_with_missing_required_slots(self, kb_tool, mock_session):
        """When a required slot is missing, execute() returns an ask-back response."""
        slot_state = SlotState(
            filled_slots={},
            missing_required_slots=[
                {
                    "slot_key": "product_line",
                    "label": "产品线",
                    "reason": "required_slot_missing",
                    "ask_back_prompt": "请问您咨询的是哪个产品线?",
                }
            ],
        )
        result = await kb_tool.execute(
            query="退款政策",
            tenant_id="test_tenant",
            context={},
            slot_state=slot_state,
        )
        # The tool must fail fast and surface the missing slot so the caller
        # can ask the user a follow-up question instead of searching.
        assert result.success is False
        assert result.fallback_reason_code == "MISSING_REQUIRED_SLOTS"
        assert len(result.missing_required_slots) == 1
        assert result.missing_required_slots[0]["slot_key"] == "product_line"
        assert result.tool_trace is not None
        assert result.tool_trace.error_code == "MISSING_REQUIRED_SLOTS"
    @pytest.mark.asyncio
    async def test_execute_with_filled_slots(self, kb_tool, mock_session):
        """Filled slots are used to build the retrieval filter."""
        slot_state = SlotState(
            filled_slots={"product_line": "vip_course"},
            missing_required_slots=[],
            slot_to_field_map={"product_line": "product_line"},
        )
        # Stub out the filter builder.
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[])
        kb_tool._filter_builder = mock_filter_builder
        # Stub the retrieval step to return no hits.
        with patch.object(kb_tool, "_retrieve_with_timeout", return_value=[]):
            result = await kb_tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context={},
                slot_state=slot_state,
            )
        # Execution should succeed even though there were no hits.
        assert result.success is True
    @pytest.mark.asyncio
    async def test_build_filter_from_slot_state_priority(self, kb_tool, mock_session):
        """Filter value source priority: slot > context > default."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo
        slot_state = SlotState(
            filled_slots={"product_line": "from_slot"},
            missing_required_slots=[],
            slot_to_field_map={"product_line": "product_line"},
        )
        context = {"product_line": "from_context"}
        # A filterable field offering all three sources: slot, context, default.
        mock_field = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=True,
            options=None,
            default_value="from_default",
            is_filterable=True,
        )
        # Stub out the filter builder.
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[mock_field])
        mock_filter_builder._build_field_filter = MagicMock(return_value={"$eq": "from_slot"})
        kb_tool._filter_builder = mock_filter_builder
        filter_result = await kb_tool._build_filter_from_slot_state(
            tenant_id="test_tenant",
            slot_state=slot_state,
            context=context,
        )
        # The slot value must win (highest priority).
        assert "product_line" in filter_result
        mock_filter_builder._build_field_filter.assert_called_once()
        call_args = mock_filter_builder._build_field_filter.call_args
        assert call_args[0][1] == "from_slot"
    @pytest.mark.asyncio
    async def test_build_filter_uses_context_when_slot_empty(self, kb_tool, mock_session):
        """The context value is used when the slot is empty."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo
        slot_state = SlotState(
            filled_slots={},  # no filled slots
            missing_required_slots=[],
            slot_to_field_map={},
        )
        context = {"product_line": "from_context"}
        # Filterable field with a default, so context must beat the default.
        mock_field = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=True,
            options=None,
            default_value="from_default",
            is_filterable=True,
        )
        # Stub out the filter builder.
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[mock_field])
        mock_filter_builder._build_field_filter = MagicMock(return_value={"$eq": "from_context"})
        kb_tool._filter_builder = mock_filter_builder
        filter_result = await kb_tool._build_filter_from_slot_state(
            tenant_id="test_tenant",
            slot_state=slot_state,
            context=context,
        )
        # The context value must be used.
        assert "product_line" in filter_result
        call_args = mock_filter_builder._build_field_filter.call_args
        assert call_args[0][1] == "from_context"
    @pytest.mark.asyncio
    async def test_build_filter_uses_default_when_no_other(self, kb_tool, mock_session):
        """The default value is used when both slot and context are empty."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo
        slot_state = SlotState(
            filled_slots={},
            missing_required_slots=[],
            slot_to_field_map={},
        )
        context = {}
        # Filterable field carrying only a default value.
        mock_field = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=False,  # optional
            options=None,
            default_value="from_default",
            is_filterable=True,
        )
        # Stub out the filter builder.
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[mock_field])
        mock_filter_builder._build_field_filter = MagicMock(return_value={"$eq": "from_default"})
        kb_tool._filter_builder = mock_filter_builder
        filter_result = await kb_tool._build_filter_from_slot_state(
            tenant_id="test_tenant",
            slot_state=slot_state,
            context=context,
        )
        # The default value must be used.
        assert "product_line" in filter_result
        call_args = mock_filter_builder._build_field_filter.call_args
        assert call_args[0][1] == "from_default"
class TestKbSearchDynamicSlotMapping:
    """Tests for slot-to-field mapping."""
    @pytest.mark.asyncio
    async def test_slot_to_field_mapping_in_filter(self):
        """Slot values are mapped onto filter fields via slot_to_field_map."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo
        from unittest.mock import AsyncMock
        mock_session = AsyncMock()
        kb_tool = KbSearchDynamicTool(session=mock_session)
        # slot_key is "product", but it maps to field_key "product_line".
        slot_state = SlotState(
            filled_slots={"product": "vip_course"},
            missing_required_slots=[],
            slot_to_field_map={"product": "product_line"},
        )
        # Filterable field keyed by field_key (not slot_key).
        mock_field = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=True,
            options=None,
            default_value=None,
            is_filterable=True,
        )
        # Stub out the filter builder.
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[mock_field])
        mock_filter_builder._build_field_filter = MagicMock(return_value={"$eq": "vip_course"})
        kb_tool._filter_builder = mock_filter_builder
        filter_result = await kb_tool._build_filter_from_slot_state(
            tenant_id="test_tenant",
            slot_state=slot_state,
            context={},
        )
        # The slot value must be found through the mapping.
        assert "product_line" in filter_result
        call_args = mock_filter_builder._build_field_filter.call_args
        assert call_args[0][1] == "vip_course"
class TestKbSearchDynamicDebugInfo:
    """Tests for debug-info output."""
    @pytest.mark.asyncio
    async def test_filter_debug_includes_sources(self):
        """Filter debug info includes the value-source markers."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo
        from unittest.mock import AsyncMock
        mock_session = AsyncMock()
        kb_tool = KbSearchDynamicTool(session=mock_session)
        slot_state = SlotState(
            filled_slots={"product_line": "vip_course"},
            missing_required_slots=[],
            slot_to_field_map={"product_line": "product_line"},
        )
        mock_field = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=True,
            options=None,
            default_value=None,
            is_filterable=True,
        )
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[mock_field])
        mock_filter_builder._build_field_filter = MagicMock(return_value={"$eq": "vip_course"})
        kb_tool._filter_builder = mock_filter_builder
        # Stub the retrieval step so execute() completes without a backend.
        with patch.object(kb_tool, "_retrieve_with_timeout", return_value=[]):
            result = await kb_tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context={},
                slot_state=slot_state,
            )
        # Debug info should carry the value sources.
        assert result.filter_debug is not None

View File

@ -0,0 +1,291 @@
"""
Unit tests for LlmJudge.
[AC-AISVC-118, AC-AISVC-119] Tests for LLM-based intent arbitration.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import uuid
from app.services.intent.llm_judge import LlmJudge
from app.services.intent.models import (
FusionConfig,
LlmJudgeInput,
LlmJudgeResult,
RuleMatchResult,
SemanticCandidate,
SemanticMatchResult,
)
@pytest.fixture
def mock_llm_client():
    """Bare async LLM client stub; each test attaches its own behaviour."""
    return AsyncMock()
@pytest.fixture
def config():
    """Default FusionConfig shared by the LlmJudge tests."""
    return FusionConfig()
@pytest.fixture
def mock_rule():
    """Mocked intent rule named 'Test Intent' with a random id."""
    rule = MagicMock()
    # configure_mock is the documented way to set .name on a MagicMock.
    rule.configure_mock(id=uuid.uuid4(), name="Test Intent")
    return rule
class TestLlmJudge:
    """Tests for LlmJudge class."""
    def test_init(self, mock_llm_client, config):
        """Test LlmJudge initialization."""
        judge = LlmJudge(mock_llm_client, config)
        assert judge._llm_client == mock_llm_client
        assert judge._config == config
    def test_should_trigger_disabled(self, mock_llm_client):
        """Test should_trigger when LLM judge is disabled."""
        config = FusionConfig(llm_judge_enabled=False)
        judge = LlmJudge(mock_llm_client, config)
        rule_result = RuleMatchResult(
            rule_id=uuid.uuid4(),
            rule=MagicMock(),
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        # The kill switch must win over any other trigger condition.
        triggered, reason = judge.should_trigger(rule_result, semantic_result)
        assert triggered is False
        assert reason == "disabled"
    def test_should_trigger_rule_semantic_conflict(self, mock_llm_client, config, mock_rule):
        """Test should_trigger for rule vs semantic conflict."""
        judge = LlmJudge(mock_llm_client, config)
        # Rule hit (1.0) and a strong semantic hit (0.95) on a DIFFERENT rule.
        rule_result = RuleMatchResult(
            rule_id=uuid.uuid4(),
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        other_rule = MagicMock()
        other_rule.id = uuid.uuid4()
        other_rule.name = "Other Intent"
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=other_rule, score=0.95)],
            top_score=0.95,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        triggered, reason = judge.should_trigger(rule_result, semantic_result)
        assert triggered is True
        assert reason == "rule_semantic_conflict"
    def test_should_trigger_gray_zone(self, mock_llm_client, config, mock_rule):
        """Test should_trigger for gray zone scenario."""
        judge = LlmJudge(mock_llm_client, config)
        # No rule hit; semantic score 0.5 sits between the trigger and
        # gray-zone thresholds, so arbitration is requested.
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.5)],
            top_score=0.5,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        triggered, reason = judge.should_trigger(rule_result, semantic_result)
        assert triggered is True
        assert reason == "gray_zone"
    def test_should_trigger_multi_intent(self, mock_llm_client, config, mock_rule):
        """Test should_trigger for multi-intent scenario."""
        judge = LlmJudge(mock_llm_client, config)
        rule_result = RuleMatchResult(
            rule_id=None,
            rule=None,
            match_type=None,
            matched_text=None,
            score=0.0,
            duration_ms=10,
        )
        other_rule = MagicMock()
        other_rule.id = uuid.uuid4()
        other_rule.name = "Other Intent"
        # Two semantic candidates scored closely (0.8 vs 0.75): ambiguous.
        semantic_result = SemanticMatchResult(
            candidates=[
                SemanticCandidate(rule=mock_rule, score=0.8),
                SemanticCandidate(rule=other_rule, score=0.75),
            ],
            top_score=0.8,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        triggered, reason = judge.should_trigger(rule_result, semantic_result)
        assert triggered is True
        assert reason == "multi_intent"
    def test_should_not_trigger_high_confidence(self, mock_llm_client, config, mock_rule):
        """Test should_trigger returns False for high confidence match."""
        judge = LlmJudge(mock_llm_client, config)
        # Rule and semantic agree on the SAME rule with high scores: no need
        # for arbitration.
        rule_result = RuleMatchResult(
            rule_id=mock_rule.id,
            rule=mock_rule,
            match_type="keyword",
            matched_text="test",
            score=1.0,
            duration_ms=10,
        )
        semantic_result = SemanticMatchResult(
            candidates=[SemanticCandidate(rule=mock_rule, score=0.9)],
            top_score=0.9,
            duration_ms=50,
            skipped=False,
            skip_reason=None,
        )
        triggered, reason = judge.should_trigger(rule_result, semantic_result)
        assert triggered is False
    @pytest.mark.asyncio
    async def test_judge_success(self, mock_llm_client, config):
        """Test successful LLM judge."""
        from app.services.llm.base import LLMResponse
        # The LLM returns well-formed JSON; judge() should parse it verbatim.
        mock_response = LLMResponse(
            content='{"intent_id": "test-id", "intent_name": "Test", "confidence": 0.85, "reasoning": "Test reasoning"}',
            model="gpt-4",
            usage={"total_tokens": 100},
        )
        mock_llm_client.generate = AsyncMock(return_value=mock_response)
        judge = LlmJudge(mock_llm_client, config)
        input_data = LlmJudgeInput(
            message="test message",
            candidates=[{"id": "test-id", "name": "Test", "description": "Test intent"}],
            conflict_type="gray_zone",
        )
        result = await judge.judge(input_data, "tenant-1")
        assert result.triggered is True
        assert result.intent_id == "test-id"
        assert result.intent_name == "Test"
        assert result.score == 0.85
        assert result.reasoning == "Test reasoning"
        assert result.tokens_used == 100
    @pytest.mark.asyncio
    async def test_judge_timeout(self, mock_llm_client, config):
        """Test LLM judge timeout."""
        import asyncio
        # A timeout must degrade gracefully: triggered, but no intent chosen.
        mock_llm_client.generate = AsyncMock(side_effect=asyncio.TimeoutError())
        judge = LlmJudge(mock_llm_client, config)
        input_data = LlmJudgeInput(
            message="test message",
            candidates=[{"id": "test-id", "name": "Test"}],
            conflict_type="gray_zone",
        )
        result = await judge.judge(input_data, "tenant-1")
        assert result.triggered is True
        assert result.intent_id is None
        assert "timeout" in result.reasoning.lower()
    @pytest.mark.asyncio
    async def test_judge_error(self, mock_llm_client, config):
        """Test LLM judge error handling."""
        # Any other client failure is caught and reported in the reasoning.
        mock_llm_client.generate = AsyncMock(side_effect=Exception("LLM error"))
        judge = LlmJudge(mock_llm_client, config)
        input_data = LlmJudgeInput(
            message="test message",
            candidates=[{"id": "test-id", "name": "Test"}],
            conflict_type="gray_zone",
        )
        result = await judge.judge(input_data, "tenant-1")
        assert result.triggered is True
        assert result.intent_id is None
        assert "error" in result.reasoning.lower()
    def test_parse_response_valid_json(self, mock_llm_client, config):
        """Test parsing valid JSON response."""
        judge = LlmJudge(mock_llm_client, config)
        content = '{"intent_id": "test", "confidence": 0.9}'
        result = judge._parse_response(content)
        assert result["intent_id"] == "test"
        assert result["confidence"] == 0.9
    def test_parse_response_with_markdown(self, mock_llm_client, config):
        """Test parsing JSON response with markdown code block."""
        judge = LlmJudge(mock_llm_client, config)
        # LLMs often wrap JSON in ```json fences; the parser must strip them.
        content = '```json\n{"intent_id": "test", "confidence": 0.9}\n```'
        result = judge._parse_response(content)
        assert result["intent_id"] == "test"
        assert result["confidence"] == 0.9
    def test_parse_response_invalid_json(self, mock_llm_client, config):
        """Test parsing invalid JSON response."""
        judge = LlmJudge(mock_llm_client, config)
        # Unparseable content degrades to an empty dict rather than raising.
        content = "This is not valid JSON"
        result = judge._parse_response(content)
        assert result == {}
    def test_llm_judge_result_empty(self):
        """Test LlmJudgeResult.empty() class method."""
        result = LlmJudgeResult.empty()
        assert result.intent_id is None
        assert result.intent_name is None
        assert result.score == 0.0
        assert result.reasoning is None
        assert result.duration_ms == 0
        assert result.tokens_used == 0
        assert result.triggered is False

View File

@ -0,0 +1,299 @@
"""
Tests for Scene Slot Bundle Loader.
[AC-SCENE-SLOT-02] 场景槽位包加载器测试
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
from app.services.mid.scene_slot_bundle_loader import (
SceneSlotBundleLoader,
SceneSlotContext,
SlotInfo,
)
from app.models.entities import (
SceneSlotBundle,
SceneSlotBundleStatus,
SlotDefinition,
)
@pytest.fixture
def mock_session():
    """Async database session stub with a stubbed execute()."""
    stub = AsyncMock()
    stub.execute = AsyncMock()
    return stub
@pytest.fixture
def sample_slot_definitions():
    """Three string slot definitions: two required, one optional."""
    specs = [
        ("course_type", True, "请问您想咨询哪种类型的课程?"),
        ("grade", True, "请问您是几年级?"),
        ("region", False, "请问您在哪个地区?"),
    ]
    return [
        SlotDefinition(
            id=uuid4(),
            tenant_id="test_tenant",
            slot_key=slot_key,
            type="string",
            required=required,
            ask_back_prompt=prompt,
        )
        for slot_key, required, prompt in specs
    ]
@pytest.fixture
def sample_bundle(sample_slot_definitions):
    """An active scene slot bundle for the open_consult scene."""
    bundle_fields = dict(
        id=uuid4(),
        tenant_id="test_tenant",
        scene_key="open_consult",
        scene_name="开放咨询",
        description="开放咨询场景的槽位配置",
        required_slots=["course_type", "grade"],
        optional_slots=["region"],
        slot_priority=["course_type", "grade", "region"],
        completion_threshold=1.0,
        ask_back_order="priority",
        status=SceneSlotBundleStatus.ACTIVE.value,
        version=1,
    )
    return SceneSlotBundle(**bundle_fields)
class TestSceneSlotContext:
    """Test cases for SceneSlotContext."""

    @staticmethod
    def _context(required, optional=(), **extra):
        # Factory for a context on a fixed test scene; keyword extras flow through.
        return SceneSlotContext(
            scene_key="test_scene",
            scene_name="测试场景",
            required_slots=list(required),
            optional_slots=list(optional),
            **extra,
        )

    def test_get_all_slot_keys(self):
        """Required and optional slot keys are reported together."""
        ctx = self._context(
            [SlotInfo(slot_key="course_type", type="string", required=True)],
            [SlotInfo(slot_key="region", type="string", required=False)],
        )
        keys = ctx.get_all_slot_keys()
        assert len(keys) == 2
        assert set(keys) == {"course_type", "region"}

    def test_get_missing_slots(self):
        """Slots absent from filled_slots are reported as missing."""
        ctx = self._context([
            SlotInfo(slot_key="course_type", type="string", required=True, ask_back_prompt="请问课程类型?"),
            SlotInfo(slot_key="grade", type="string", required=True, ask_back_prompt="请问年级?"),
        ])
        missing = ctx.get_missing_slots({"course_type": "数学"})
        assert [entry["slot_key"] for entry in missing] == ["grade"]

    def test_get_ordered_missing_slots_priority(self):
        """Missing slots come back in the configured priority order."""
        ctx = self._context(
            [
                SlotInfo(slot_key="course_type", type="string", required=True),
                SlotInfo(slot_key="grade", type="string", required=True),
            ],
            slot_priority=["grade", "course_type"],
            ask_back_order="priority",
        )
        ordered = ctx.get_ordered_missing_slots({})
        assert [entry["slot_key"] for entry in ordered] == ["grade", "course_type"]

    def test_get_completion_ratio(self):
        """One of two required slots filled yields a 0.5 completion ratio."""
        ctx = self._context(
            [
                SlotInfo(slot_key="course_type", type="string", required=True),
                SlotInfo(slot_key="grade", type="string", required=True),
            ],
            completion_threshold=0.5,
        )
        assert ctx.get_completion_ratio({"course_type": "数学"}) == 0.5

    def test_is_complete(self):
        """With threshold 1.0, completion requires every required slot."""
        ctx = self._context(
            [
                SlotInfo(slot_key="course_type", type="string", required=True),
                SlotInfo(slot_key="grade", type="string", required=True),
            ],
            completion_threshold=1.0,
        )
        assert ctx.is_complete({"course_type": "数学", "grade": "高一"}) is True
        assert ctx.is_complete({"course_type": "数学"}) is False
class TestSceneSlotBundleLoader:
    """Test cases for SceneSlotBundleLoader.

    All database access is stubbed via patch.object on the loader's internal
    services, so only the loader's assembly and ordering logic is exercised.
    """

    @pytest.mark.asyncio
    async def test_load_scene_context(self, mock_session, sample_bundle, sample_slot_definitions):
        """Test loading scene context."""
        loader = SceneSlotBundleLoader(mock_session)
        # Stub both lookups: the active bundle and the tenant's slot definitions.
        with patch.object(loader._bundle_service, 'get_active_bundle_by_scene', new_callable=AsyncMock) as mock_get_bundle:
            mock_get_bundle.return_value = sample_bundle
            with patch.object(loader._slot_service, 'list_slot_definitions', new_callable=AsyncMock) as mock_get_slots:
                mock_get_slots.return_value = sample_slot_definitions
                context = await loader.load_scene_context("test_tenant", "open_consult")
                assert context is not None
                assert context.scene_key == "open_consult"
                # sample_bundle declares 2 required + 1 optional slot keys.
                assert len(context.required_slots) == 2
                assert len(context.optional_slots) == 1

    @pytest.mark.asyncio
    async def test_load_scene_context_not_found(self, mock_session):
        """Test loading scene context when bundle not found."""
        loader = SceneSlotBundleLoader(mock_session)
        with patch.object(loader, '_bundle_service', create=False) if False else patch.object(loader._bundle_service, 'get_active_bundle_by_scene', new_callable=AsyncMock) as mock_get:
            mock_get.return_value = None
            context = await loader.load_scene_context("test_tenant", "unknown_scene")
            # No active bundle for the scene -> loader yields None, not an error.
            assert context is None

    @pytest.mark.asyncio
    async def test_get_missing_slots_for_scene(self, mock_session, sample_bundle, sample_slot_definitions):
        """Test getting missing slots for a scene."""
        loader = SceneSlotBundleLoader(mock_session)
        # Bypass load_scene_context entirely and hand back a pre-built context.
        with patch.object(loader, 'load_scene_context', new_callable=AsyncMock) as mock_load:
            mock_context = SceneSlotContext(
                scene_key="open_consult",
                scene_name="开放咨询",
                required_slots=[
                    SlotInfo(slot_key="course_type", type="string", required=True, ask_back_prompt="请问课程类型?"),
                    SlotInfo(slot_key="grade", type="string", required=True, ask_back_prompt="请问年级?"),
                ],
                optional_slots=[],
                slot_priority=["course_type", "grade"],
            )
            mock_load.return_value = mock_context
            filled_slots = {"course_type": "数学"}
            missing = await loader.get_missing_slots_for_scene(
                "test_tenant", "open_consult", filled_slots
            )
            # Only "grade" is still unfilled.
            assert len(missing) == 1
            assert missing[0]["slot_key"] == "grade"

    @pytest.mark.asyncio
    async def test_generate_ask_back_prompt_single(self, mock_session):
        """Test generating ask-back prompt for single missing slot."""
        loader = SceneSlotBundleLoader(mock_session)
        with patch.object(loader, 'load_scene_context', new_callable=AsyncMock) as mock_load:
            mock_context = SceneSlotContext(
                scene_key="open_consult",
                scene_name="开放咨询",
                required_slots=[
                    SlotInfo(slot_key="course_type", type="string", required=True, ask_back_prompt="请问您想咨询哪种课程?"),
                ],
                optional_slots=[],
                ask_back_order="priority",
            )
            mock_load.return_value = mock_context
            with patch.object(loader, 'get_missing_slots_for_scene', new_callable=AsyncMock) as mock_missing:
                mock_missing.return_value = [
                    {"slot_key": "course_type", "ask_back_prompt": "请问您想咨询哪种课程?"}
                ]
                prompt = await loader.generate_ask_back_prompt(
                    "test_tenant", "open_consult", {}
                )
                # A single missing slot uses the slot's own prompt verbatim.
                assert prompt == "请问您想咨询哪种课程?"

    @pytest.mark.asyncio
    async def test_generate_ask_back_prompt_parallel(self, mock_session):
        """Test generating ask-back prompt with parallel strategy."""
        loader = SceneSlotBundleLoader(mock_session)
        with patch.object(loader, 'load_scene_context', new_callable=AsyncMock) as mock_load:
            mock_context = SceneSlotContext(
                scene_key="open_consult",
                scene_name="开放咨询",
                required_slots=[
                    SlotInfo(slot_key="course_type", type="string", required=True, ask_back_prompt="课程类型"),
                    SlotInfo(slot_key="grade", type="string", required=True, ask_back_prompt="年级"),
                ],
                optional_slots=[],
                ask_back_order="parallel",
            )
            mock_load.return_value = mock_context
            with patch.object(loader, 'get_missing_slots_for_scene', new_callable=AsyncMock) as mock_missing:
                mock_missing.return_value = [
                    {"slot_key": "course_type", "ask_back_prompt": "课程类型"},
                    {"slot_key": "grade", "ask_back_prompt": "年级"},
                ]
                prompt = await loader.generate_ask_back_prompt(
                    "test_tenant", "open_consult", {}
                )
                # "parallel" mode merges every missing slot's prompt into one message.
                assert "课程类型" in prompt
                assert "年级" in prompt

View File

@ -0,0 +1,284 @@
"""
Tests for Scene Slot Bundle Service.
[AC-SCENE-SLOT-01] 场景-槽位映射配置服务测试
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
from app.models.entities import (
SceneSlotBundle,
SceneSlotBundleCreate,
SceneSlotBundleUpdate,
SceneSlotBundleStatus,
SlotDefinition,
)
from app.services.scene_slot_bundle_service import SceneSlotBundleService
@pytest.fixture
def mock_session():
    """Async DB session double: sync add(), async execute/flush/delete."""
    return AsyncMock(
        execute=AsyncMock(),
        add=MagicMock(),
        flush=AsyncMock(),
        delete=AsyncMock(),
    )
@pytest.fixture
def sample_slot_definition():
    """A single required string slot used across the service tests."""
    fields = {
        "id": uuid4(),
        "tenant_id": "test_tenant",
        "slot_key": "course_type",
        "type": "string",
        "required": True,
        "ask_back_prompt": "请问您想咨询哪种类型的课程?",
    }
    return SlotDefinition(**fields)
@pytest.fixture
def sample_bundle():
    """An active open_consult bundle with priority ask-back ordering."""
    slot_layout = {
        "required_slots": ["course_type", "grade"],
        "optional_slots": ["region"],
        "slot_priority": ["course_type", "grade", "region"],
    }
    return SceneSlotBundle(
        id=uuid4(),
        tenant_id="test_tenant",
        scene_key="open_consult",
        scene_name="开放咨询",
        description="开放咨询场景的槽位配置",
        completion_threshold=1.0,
        ask_back_order="priority",
        status=SceneSlotBundleStatus.ACTIVE.value,
        version=1,
        **slot_layout,
    )
class TestSceneSlotBundleService:
    """Test cases for SceneSlotBundleService.

    The SQLAlchemy result objects are faked with MagicMock so each test
    only exercises the service's own control flow, not the query layer.
    """

    @pytest.mark.asyncio
    async def test_list_bundles(self, mock_session, sample_bundle):
        """Test listing scene slot bundles."""
        mock_result = MagicMock()
        # Mirror the scalars().all() chain the service reads from the result.
        mock_result.scalars.return_value.all.return_value = [sample_bundle]
        mock_session.execute.return_value = mock_result
        service = SceneSlotBundleService(mock_session)
        bundles = await service.list_bundles("test_tenant")
        assert len(bundles) == 1
        assert bundles[0].scene_key == "open_consult"

    @pytest.mark.asyncio
    async def test_list_bundles_with_status_filter(self, mock_session, sample_bundle):
        """Test listing scene slot bundles with status filter."""
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [sample_bundle]
        mock_session.execute.return_value = mock_result
        service = SceneSlotBundleService(mock_session)
        # The filter itself is not observable through the mock; this only
        # checks the call path accepts a status keyword.
        bundles = await service.list_bundles("test_tenant", status="active")
        assert len(bundles) == 1

    @pytest.mark.asyncio
    async def test_get_bundle_by_id(self, mock_session, sample_bundle):
        """Test getting a bundle by ID."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result
        service = SceneSlotBundleService(mock_session)
        bundle = await service.get_bundle("test_tenant", str(sample_bundle.id))
        assert bundle is not None
        assert bundle.scene_key == "open_consult"

    @pytest.mark.asyncio
    async def test_get_bundle_by_scene_key(self, mock_session, sample_bundle):
        """Test getting a bundle by scene key."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result
        service = SceneSlotBundleService(mock_session)
        bundle = await service.get_bundle_by_scene_key("test_tenant", "open_consult")
        assert bundle is not None
        assert bundle.scene_key == "open_consult"

    @pytest.mark.asyncio
    async def test_get_active_bundle_by_scene(self, mock_session, sample_bundle):
        """Test getting an active bundle by scene key."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result
        service = SceneSlotBundleService(mock_session)
        bundle = await service.get_active_bundle_by_scene("test_tenant", "open_consult")
        assert bundle is not None
        assert bundle.status == SceneSlotBundleStatus.ACTIVE.value

    @pytest.mark.asyncio
    async def test_create_bundle_success(self, mock_session, sample_slot_definition):
        """Test creating a scene slot bundle successfully."""
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [sample_slot_definition]
        mock_session.execute.return_value = mock_result
        service = SceneSlotBundleService(mock_session)
        # Validation passes and no bundle exists yet for the scene key.
        with patch.object(service, '_validate_slot_keys', new_callable=AsyncMock) as mock_validate:
            mock_validate.return_value = {"course_type"}
            with patch.object(service, 'get_bundle_by_scene_key', new_callable=AsyncMock) as mock_get:
                mock_get.return_value = None
                bundle_create = SceneSlotBundleCreate(
                    scene_key="new_scene",
                    scene_name="新场景",
                    required_slots=["course_type"],
                    optional_slots=[],
                )
                bundle = await service.create_bundle("test_tenant", bundle_create)
                assert bundle is not None
                assert bundle.scene_key == "new_scene"
                # The new bundle must be staged on the session exactly once.
                mock_session.add.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_bundle_duplicate_scene_key(self, mock_session, sample_bundle):
        """Test creating a bundle with duplicate scene key."""
        service = SceneSlotBundleService(mock_session)
        # An existing bundle for the same scene key must trigger a ValueError
        # whose message contains "已存在" ("already exists").
        with patch.object(service, 'get_bundle_by_scene_key', new_callable=AsyncMock) as mock_get:
            mock_get.return_value = sample_bundle
            with patch.object(service, '_validate_bundle_data', new_callable=AsyncMock) as mock_validate:
                mock_validate.return_value = []
                bundle_create = SceneSlotBundleCreate(
                    scene_key="open_consult",
                    scene_name="开放咨询",
                )
                with pytest.raises(ValueError, match="已存在"):
                    await service.create_bundle("test_tenant", bundle_create)

    @pytest.mark.asyncio
    async def test_create_bundle_validation_error(self, mock_session):
        """Test creating a bundle with validation error."""
        service = SceneSlotBundleService(mock_session)
        # Validation errors returned as a list are surfaced as a ValueError.
        with patch.object(service, '_validate_bundle_data', new_callable=AsyncMock) as mock_validate:
            mock_validate.return_value = ["必填和可选槽位存在交叉"]
            bundle_create = SceneSlotBundleCreate(
                scene_key="new_scene",
                scene_name="新场景",
                required_slots=["course_type"],
                optional_slots=["course_type"],
            )
            with pytest.raises(ValueError, match="交叉"):
                await service.create_bundle("test_tenant", bundle_create)

    @pytest.mark.asyncio
    async def test_update_bundle_success(self, mock_session, sample_bundle):
        """Test updating a scene slot bundle successfully."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result
        service = SceneSlotBundleService(mock_session)
        with patch.object(service, '_validate_bundle_data', new_callable=AsyncMock) as mock_validate:
            mock_validate.return_value = []
            bundle_update = SceneSlotBundleUpdate(
                scene_name="更新后的场景名称",
            )
            bundle = await service.update_bundle("test_tenant", str(sample_bundle.id), bundle_update)
            assert bundle is not None
            assert bundle.scene_name == "更新后的场景名称"
            # An update bumps the version (fixture starts at 1).
            assert bundle.version == 2

    @pytest.mark.asyncio
    async def test_delete_bundle_success(self, mock_session, sample_bundle):
        """Test deleting a scene slot bundle successfully."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result
        service = SceneSlotBundleService(mock_session)
        success = await service.delete_bundle("test_tenant", str(sample_bundle.id))
        assert success is True
        mock_session.delete.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_bundle_not_found(self, mock_session):
        """Test deleting a non-existent bundle."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result
        service = SceneSlotBundleService(mock_session)
        # Missing bundle -> False, not an exception.
        success = await service.delete_bundle("test_tenant", str(uuid4()))
        assert success is False
class TestSceneSlotBundleValidation:
    """Validation rules enforced by SceneSlotBundleService._validate_bundle_data."""

    @pytest.mark.asyncio
    async def test_validate_required_optional_overlap(self, mock_session, sample_slot_definition):
        """A slot key listed as both required and optional is rejected."""
        query_result = MagicMock()
        query_result.scalars.return_value.all.return_value = [sample_slot_definition]
        mock_session.execute.return_value = query_result
        svc = SceneSlotBundleService(mock_session)
        errors = await svc._validate_bundle_data(
            tenant_id="test_tenant",
            required_slots=["course_type"],
            optional_slots=["course_type"],
            slot_priority=None,
        )
        assert errors
        assert any("交叉" in msg for msg in errors)

    @pytest.mark.asyncio
    async def test_validate_priority_unknown_slots(self, mock_session, sample_slot_definition):
        """Priority entries naming undefined slots produce an error."""
        query_result = MagicMock()
        query_result.scalars.return_value.all.return_value = [sample_slot_definition]
        mock_session.execute.return_value = query_result
        svc = SceneSlotBundleService(mock_session)
        errors = await svc._validate_bundle_data(
            tenant_id="test_tenant",
            required_slots=["course_type"],
            optional_slots=[],
            slot_priority=["course_type", "unknown_slot"],
        )
        assert errors
        assert any("未定义" in msg for msg in errors)

View File

@ -0,0 +1,210 @@
"""
Unit tests for SemanticMatcher.
[AC-AISVC-113, AC-AISVC-114] Tests for semantic matching.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import uuid
from app.services.intent.semantic_matcher import SemanticMatcher
from app.services.intent.models import (
FusionConfig,
SemanticCandidate,
SemanticMatchResult,
)
@pytest.fixture
def mock_embedding_provider():
    """Embedding provider double returning fixed 768-dim vectors."""
    return AsyncMock(
        embed=AsyncMock(return_value=[0.1] * 768),
        embed_batch=AsyncMock(return_value=[[0.1] * 768, [0.2] * 768]),
    )
@pytest.fixture
def mock_rule():
    """Enabled intent rule carrying only a pre-computed vector (Mode A)."""
    rule = MagicMock()
    # configure_mock is the documented way to set the `name` attribute,
    # which the MagicMock constructor reserves for the mock's repr name.
    rule.configure_mock(
        name="Test Intent",
        id=uuid.uuid4(),
        intent_vector=[0.1] * 768,
        semantic_examples=None,
        is_enabled=True,
    )
    return rule
@pytest.fixture
def mock_rule_with_examples():
    """Enabled intent rule carrying only example utterances (Mode B)."""
    rule = MagicMock()
    # configure_mock sets the `name` attribute that the constructor reserves.
    rule.configure_mock(
        name="Test Intent with Examples",
        id=uuid.uuid4(),
        intent_vector=None,
        semantic_examples=["我想退货", "如何退款"],
        is_enabled=True,
    )
    return rule
@pytest.fixture
def config():
    """Create a fusion config using all of FusionConfig's defaults."""
    return FusionConfig()
class TestSemanticMatcher:
    """Tests for SemanticMatcher class.

    The embedding provider is always an AsyncMock, so scores are fully
    determined by the fixed vectors wired into each test.
    """

    @pytest.mark.asyncio
    async def test_init(self, mock_embedding_provider, config):
        """Test SemanticMatcher initialization."""
        matcher = SemanticMatcher(mock_embedding_provider, config)
        assert matcher._embedding_provider == mock_embedding_provider
        assert matcher._config == config

    @pytest.mark.asyncio
    async def test_match_disabled(self, mock_embedding_provider):
        """Test match when semantic matcher is disabled."""
        config = FusionConfig(semantic_matcher_enabled=False)
        matcher = SemanticMatcher(mock_embedding_provider, config)
        result = await matcher.match("test message", [], "tenant-1")
        # A disabled matcher short-circuits: skipped, no candidates.
        assert result.skipped is True
        assert result.skip_reason == "disabled"
        assert result.candidates == []

    @pytest.mark.asyncio
    async def test_match_no_semantic_config(
        self, mock_embedding_provider, config, mock_rule
    ):
        """Test match when no rules have semantic config."""
        # Strip the rule of both vector and examples -> nothing to match against.
        mock_rule.intent_vector = None
        mock_rule.semantic_examples = None
        matcher = SemanticMatcher(mock_embedding_provider, config)
        result = await matcher.match("test message", [mock_rule], "tenant-1")
        assert result.skipped is True
        assert result.skip_reason == "no_semantic_config"

    @pytest.mark.asyncio
    async def test_match_mode_a_with_intent_vector(
        self, mock_embedding_provider, config, mock_rule
    ):
        """Test match with pre-computed intent vector (Mode A)."""
        # Query embedding equals the rule's stored vector -> similarity ~1.0.
        mock_embedding_provider.embed.return_value = [0.1] * 768
        matcher = SemanticMatcher(mock_embedding_provider, config)
        result = await matcher.match("我想退货", [mock_rule], "tenant-1")
        assert result.skipped is False
        assert result.skip_reason is None
        assert len(result.candidates) == 1
        assert result.top_score > 0.9
        assert result.duration_ms >= 0

    @pytest.mark.asyncio
    async def test_match_mode_b_with_examples(
        self, mock_embedding_provider, config, mock_rule_with_examples
    ):
        """Test match with semantic examples (Mode B)."""
        mock_embedding_provider.embed.return_value = [0.1] * 768
        # Both example embeddings equal the query embedding -> near-perfect score.
        mock_embedding_provider.embed_batch.return_value = [[0.1] * 768, [0.1] * 768]
        matcher = SemanticMatcher(mock_embedding_provider, config)
        result = await matcher.match("我想退货", [mock_rule_with_examples], "tenant-1")
        assert result.skipped is False
        assert len(result.candidates) == 1
        assert result.top_score > 0.9

    @pytest.mark.asyncio
    async def test_match_embedding_timeout(self, mock_embedding_provider, config, mock_rule):
        """Test match when embedding times out."""
        import asyncio
        mock_embedding_provider.embed.side_effect = asyncio.TimeoutError()
        config = FusionConfig(semantic_matcher_timeout_ms=100)
        matcher = SemanticMatcher(mock_embedding_provider, config)
        result = await matcher.match("test message", [mock_rule], "tenant-1")
        # Timeouts degrade to a skip rather than propagating the exception.
        assert result.skipped is True
        assert "embedding_timeout" in result.skip_reason

    @pytest.mark.asyncio
    async def test_match_embedding_error(self, mock_embedding_provider, config, mock_rule):
        """Test match when embedding fails with error."""
        mock_embedding_provider.embed.side_effect = Exception("Embedding failed")
        matcher = SemanticMatcher(mock_embedding_provider, config)
        result = await matcher.match("test message", [mock_rule], "tenant-1")
        # Provider failures also degrade to a skip with a tagged reason.
        assert result.skipped is True
        assert "embedding_error" in result.skip_reason

    @pytest.mark.asyncio
    async def test_match_top_k_limit(self, mock_embedding_provider, config):
        """Test that match returns only top_k candidates."""
        # Five eligible rules, but the config caps the result at three.
        rules = []
        for i in range(5):
            rule = MagicMock()
            rule.id = uuid.uuid4()
            rule.name = f"Intent {i}"
            rule.intent_vector = [0.1 + i * 0.01] * 768
            rule.semantic_examples = None
            rule.is_enabled = True
            rules.append(rule)
        mock_embedding_provider.embed.return_value = [0.1] * 768
        config = FusionConfig(semantic_top_k=3)
        matcher = SemanticMatcher(mock_embedding_provider, config)
        result = await matcher.match("test message", rules, "tenant-1")
        assert len(result.candidates) <= 3

    def test_cosine_similarity(self, mock_embedding_provider, config):
        """Test cosine similarity calculation."""
        matcher = SemanticMatcher(mock_embedding_provider, config)
        # Identical vectors -> 1.0.
        v1 = [1.0, 0.0, 0.0]
        v2 = [1.0, 0.0, 0.0]
        similarity = matcher._cosine_similarity(v1, v2)
        assert similarity == 1.0
        # Orthogonal vectors -> 0.0.
        v1 = [1.0, 0.0, 0.0]
        v2 = [0.0, 1.0, 0.0]
        similarity = matcher._cosine_similarity(v1, v2)
        assert similarity == 0.0
        # Partially aligned vectors -> strictly between 0 and 1.
        v1 = [1.0, 1.0, 0.0]
        v2 = [1.0, 0.0, 0.0]
        similarity = matcher._cosine_similarity(v1, v2)
        assert 0.0 < similarity < 1.0

    def test_cosine_similarity_empty_vectors(self, mock_embedding_provider, config):
        """Test cosine similarity with empty vectors."""
        matcher = SemanticMatcher(mock_embedding_provider, config)
        # Empty input on either side yields 0.0 instead of a ZeroDivisionError.
        assert matcher._cosine_similarity([], [1.0]) == 0.0
        assert matcher._cosine_similarity([1.0], []) == 0.0
        assert matcher._cosine_similarity([], []) == 0.0

    def test_has_semantic_config(self, mock_embedding_provider, config, mock_rule):
        """Test checking if rule has semantic config."""
        matcher = SemanticMatcher(mock_embedding_provider, config)
        # A pre-computed vector alone is sufficient.
        mock_rule.intent_vector = [0.1] * 768
        mock_rule.semantic_examples = None
        assert matcher._has_semantic_config(mock_rule) is True
        # Example utterances alone are also sufficient.
        mock_rule.intent_vector = None
        mock_rule.semantic_examples = ["example"]
        assert matcher._has_semantic_config(mock_rule) is True
        # Neither -> no semantic config.
        mock_rule.intent_vector = None
        mock_rule.semantic_examples = None
        assert matcher._has_semantic_config(mock_rule) is False

View File

@ -0,0 +1,419 @@
"""
Tests for Slot Backfill Service.
[AC-MRS-SLOT-BACKFILL-01] 槽位值回填确认测试
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.models.mid.schemas import SlotSource
from app.services.mid.slot_backfill_service import (
BackfillResult,
BackfillStatus,
BatchBackfillResult,
SlotBackfillService,
create_slot_backfill_service,
)
from app.services.mid.slot_manager import SlotWriteResult
from app.services.mid.slot_strategy_executor import (
StrategyChainResult,
StrategyStepResult,
)
class TestBackfillResult:
    """Behavior of BackfillResult's status predicates and serialization."""

    def test_is_success(self):
        """Only SUCCESS counts as a successful backfill."""
        assert BackfillResult(status=BackfillStatus.SUCCESS).is_success() is True
        assert BackfillResult(status=BackfillStatus.VALIDATION_FAILED).is_success() is False

    def test_needs_ask_back(self):
        """Validation and extraction failures both require an ask-back."""
        for failing_status in (BackfillStatus.VALIDATION_FAILED, BackfillStatus.EXTRACTION_FAILED):
            assert BackfillResult(status=failing_status).needs_ask_back() is True
        assert BackfillResult(status=BackfillStatus.SUCCESS).needs_ask_back() is False

    def test_needs_confirmation(self):
        """Only NEEDS_CONFIRMATION requires a user confirmation."""
        assert BackfillResult(status=BackfillStatus.NEEDS_CONFIRMATION).needs_confirmation() is True
        assert BackfillResult(status=BackfillStatus.SUCCESS).needs_confirmation() is False

    def test_to_dict(self):
        """to_dict() serializes status and scalar fields to plain values."""
        payload = BackfillResult(
            status=BackfillStatus.SUCCESS,
            slot_key="region",
            value="北京",
            normalized_value="北京",
            source="user_confirmed",
            confidence=1.0,
        ).to_dict()
        assert payload["status"] == "success"
        assert payload["slot_key"] == "region"
        assert payload["value"] == "北京"
        assert payload["source"] == "user_confirmed"
class TestBatchBackfillResult:
    """Aggregation behavior of BatchBackfillResult."""

    def test_add_result(self):
        """Counters track success / failure / confirmation-needed results."""
        batch = BatchBackfillResult()
        for status, key in (
            (BackfillStatus.SUCCESS, "region"),
            (BackfillStatus.VALIDATION_FAILED, "product"),
            (BackfillStatus.NEEDS_CONFIRMATION, "grade"),
        ):
            batch.add_result(BackfillResult(status=status, slot_key=key))
        assert batch.success_count == 1
        assert batch.failed_count == 1
        assert batch.confirmation_needed_count == 1

    def test_get_ask_back_prompts(self):
        """Only failed results contribute ask-back prompts."""
        batch = BatchBackfillResult()
        batch.add_result(BackfillResult(
            status=BackfillStatus.VALIDATION_FAILED,
            ask_back_prompt="请重新输入",
        ))
        batch.add_result(BackfillResult(status=BackfillStatus.SUCCESS))
        batch.add_result(BackfillResult(
            status=BackfillStatus.EXTRACTION_FAILED,
            ask_back_prompt="无法识别,请重试",
        ))
        prompts = batch.get_ask_back_prompts()
        assert len(prompts) == 2
        assert "请重新输入" in prompts
        assert "无法识别,请重试" in prompts

    def test_get_confirmation_prompts(self):
        """Only NEEDS_CONFIRMATION results contribute confirmation prompts."""
        batch = BatchBackfillResult()
        batch.add_result(BackfillResult(
            status=BackfillStatus.NEEDS_CONFIRMATION,
            confirmation_prompt="我理解您说的是「北京」,对吗?",
        ))
        batch.add_result(BackfillResult(status=BackfillStatus.SUCCESS))
        prompts = batch.get_confirmation_prompts()
        assert len(prompts) == 1
        assert "北京" in prompts[0]
class TestSlotBackfillService:
    """Tests for SlotBackfillService.

    The slot manager and state aggregator are mocked; only the service's
    own decision logic (thresholds, mappings, result statuses) is tested.
    """

    @pytest.fixture
    def mock_session(self):
        """Create a mock async DB session."""
        return AsyncMock()

    @pytest.fixture
    def mock_slot_manager(self):
        """Create a mock slot manager."""
        manager = MagicMock()
        manager.write_slot = AsyncMock()
        manager.get_ask_back_prompt = AsyncMock(return_value="请提供信息")
        return manager

    @pytest.fixture
    def service(self, mock_session, mock_slot_manager):
        """Create the service instance under test."""
        return SlotBackfillService(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
            slot_manager=mock_slot_manager,
        )

    def test_confidence_thresholds(self, service):
        """Confidence thresholds are fixed class-level constants."""
        assert service.CONFIDENCE_THRESHOLD_LOW == 0.5
        assert service.CONFIDENCE_THRESHOLD_HIGH == 0.8

    def test_get_source_for_strategy(self, service):
        """Extraction strategy names map to slot-source values."""
        assert service._get_source_for_strategy("rule") == SlotSource.RULE_EXTRACTED.value
        assert service._get_source_for_strategy("llm") == SlotSource.LLM_INFERRED.value
        assert service._get_source_for_strategy("user_input") == SlotSource.USER_CONFIRMED.value
        # Unknown strategy names pass through unchanged.
        assert service._get_source_for_strategy("unknown") == "unknown"

    def test_get_confidence_for_strategy(self, service):
        """Each slot source maps to a fixed default confidence."""
        assert service._get_confidence_for_strategy(SlotSource.USER_CONFIRMED.value) == 1.0
        assert service._get_confidence_for_strategy(SlotSource.RULE_EXTRACTED.value) == 0.9
        assert service._get_confidence_for_strategy(SlotSource.LLM_INFERRED.value) == 0.7
        assert service._get_confidence_for_strategy("context") == 0.5
        assert service._get_confidence_for_strategy(SlotSource.DEFAULT.value) == 0.3

    def test_generate_confirmation_prompt(self, service):
        """The generated confirmation prompt embeds the candidate value."""
        prompt = service._generate_confirmation_prompt("region", "北京")
        assert "北京" in prompt
        assert "对吗" in prompt

    @pytest.mark.asyncio
    async def test_backfill_single_slot_success(self, service, mock_slot_manager):
        """A successful write yields a SUCCESS result with the value."""
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=True,
            slot_key="region",
            value="北京",
        )
        # Stub the state aggregator so no session state is touched.
        with patch.object(service, '_get_state_aggregator') as mock_agg:
            mock_aggregator = AsyncMock()
            mock_aggregator.update_slot = AsyncMock()
            mock_agg.return_value = mock_aggregator
            result = await service.backfill_single_slot(
                slot_key="region",
                candidate_value="北京",
                source="user_confirmed",
                confidence=1.0,
            )
            assert result.status == BackfillStatus.SUCCESS
            assert result.slot_key == "region"
            assert result.normalized_value == "北京"

    @pytest.mark.asyncio
    async def test_backfill_single_slot_validation_failed(self, service, mock_slot_manager):
        """A failed write surfaces VALIDATION_FAILED plus the ask-back prompt."""
        from app.services.mid.slot_validation_service import SlotValidationError
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=False,
            slot_key="region",
            error=SlotValidationError(
                slot_key="region",
                error_code="INVALID_VALUE",
                error_message="无效的地区",
            ),
            ask_back_prompt="请提供有效的地区",
        )
        result = await service.backfill_single_slot(
            slot_key="region",
            candidate_value="无效地区",
            source="user_confirmed",
            confidence=1.0,
        )
        assert result.status == BackfillStatus.VALIDATION_FAILED
        assert result.ask_back_prompt == "请提供有效的地区"

    @pytest.mark.asyncio
    async def test_backfill_single_slot_low_confidence(self, service, mock_slot_manager):
        """Confidence below the low threshold triggers the confirmation flow."""
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=True,
            slot_key="region",
            value="北京",
        )
        with patch.object(service, '_get_state_aggregator') as mock_agg:
            mock_aggregator = AsyncMock()
            mock_aggregator.update_slot = AsyncMock()
            mock_agg.return_value = mock_aggregator
            # 0.4 < CONFIDENCE_THRESHOLD_LOW (0.5) -> needs confirmation.
            result = await service.backfill_single_slot(
                slot_key="region",
                candidate_value="北京",
                source="llm_inferred",
                confidence=0.4,
            )
            assert result.status == BackfillStatus.NEEDS_CONFIRMATION
            assert result.confirmation_prompt is not None
            assert "北京" in result.confirmation_prompt

    @pytest.mark.asyncio
    async def test_backfill_multiple_slots(self, service, mock_slot_manager):
        """Batch backfill counts successes and failures per slot."""
        # Third write fails -> expect 2 successes and 1 failure.
        mock_slot_manager.write_slot.side_effect = [
            SlotWriteResult(success=True, slot_key="region", value="北京"),
            SlotWriteResult(success=True, slot_key="product", value="手机"),
            SlotWriteResult(success=False, slot_key="grade", error=MagicMock()),
        ]
        with patch.object(service, '_get_state_aggregator') as mock_agg:
            mock_aggregator = AsyncMock()
            mock_aggregator.update_slot = AsyncMock()
            mock_agg.return_value = mock_aggregator
            result = await service.backfill_multiple_slots(
                candidates={
                    "region": "北京",
                    "product": "手机",
                    "grade": "无效等级",
                },
                source="user_confirmed",
            )
            assert result.success_count == 2
            assert result.failed_count == 1

    @pytest.mark.asyncio
    async def test_confirm_low_confidence_slot_confirmed(self, service):
        """User confirmation upgrades the slot to user_confirmed / 1.0."""
        with patch.object(service, '_get_state_aggregator') as mock_agg:
            mock_aggregator = AsyncMock()
            mock_aggregator.update_slot = AsyncMock()
            mock_agg.return_value = mock_aggregator
            result = await service.confirm_low_confidence_slot(
                slot_key="region",
                confirmed=True,
            )
            assert result.status == BackfillStatus.SUCCESS
            assert result.source == SlotSource.USER_CONFIRMED.value
            assert result.confidence == 1.0

    @pytest.mark.asyncio
    async def test_confirm_low_confidence_slot_rejected(self, service, mock_slot_manager):
        """User rejection fails the backfill and re-asks for the slot."""
        with patch.object(service, '_get_state_aggregator') as mock_agg:
            mock_aggregator = AsyncMock()
            mock_aggregator.clear_slot = AsyncMock()
            mock_agg.return_value = mock_aggregator
            result = await service.confirm_low_confidence_slot(
                slot_key="region",
                confirmed=False,
            )
            assert result.status == BackfillStatus.VALIDATION_FAILED
            assert result.ask_back_prompt is not None
class TestCreateSlotBackfillService:
    """Factory function create_slot_backfill_service."""

    def test_create(self):
        """The factory wires tenant and session ids into the new service."""
        session_double = AsyncMock()
        built = create_slot_backfill_service(
            session=session_double,
            tenant_id="tenant_1",
            session_id="session_1",
        )
        assert isinstance(built, SlotBackfillService)
        assert built._tenant_id == "tenant_1"
        assert built._session_id == "session_1"
class TestBackfillFromUserResponse:
    """Tests for backfilling slots from a free-form user response."""

    @pytest.fixture
    def service(self):
        """Create a service instance with a stubbed slot-definition service."""
        mock_session = AsyncMock()
        mock_slot_def_service = AsyncMock()
        service = SlotBackfillService(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
        )
        # Inject the stub directly instead of patching the lookup path.
        service._slot_def_service = mock_slot_def_service
        return service

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_success(self, service):
        """Extraction succeeds and the value is backfilled."""
        mock_slot_def = MagicMock()
        mock_slot_def.type = "string"
        mock_slot_def.validation_rule = None
        mock_slot_def.ask_back_prompt = "请提供地区"
        service._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=mock_slot_def
        )
        # Stub the whole strategy chain: extraction returns a value, and the
        # per-slot backfill reports SUCCESS.
        with patch.object(service, '_extract_value') as mock_extract:
            mock_extract.return_value = StrategyChainResult(
                slot_key="region",
                success=True,
                final_value="北京",
                final_strategy="rule",
            )
            with patch.object(service, 'backfill_single_slot') as mock_backfill:
                mock_backfill.return_value = BackfillResult(
                    status=BackfillStatus.SUCCESS,
                    slot_key="region",
                    value="北京",
                )
                result = await service.backfill_from_user_response(
                    user_response="我想查询北京的产品",
                    expected_slots=["region"],
                )
                assert result.success_count == 1

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_no_definition(self, service):
        """Unknown slot keys are skipped: no success and no failure recorded."""
        service._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=None
        )
        result = await service.backfill_from_user_response(
            user_response="我想查询北京的产品",
            expected_slots=["unknown_slot"],
        )
        assert result.success_count == 0
        assert result.failed_count == 0

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_extraction_failed(self, service):
        """A failed strategy chain yields an EXTRACTION_FAILED result."""
        mock_slot_def = MagicMock()
        mock_slot_def.type = "string"
        mock_slot_def.validation_rule = None
        mock_slot_def.ask_back_prompt = "请提供地区"
        service._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=mock_slot_def
        )
        with patch.object(service, '_extract_value') as mock_extract:
            mock_extract.return_value = StrategyChainResult(
                slot_key="region",
                success=False,
            )
            result = await service.backfill_from_user_response(
                user_response="我想查询产品",
                expected_slots=["region"],
            )
            assert result.failed_count == 1
            assert result.results[0].status == BackfillStatus.EXTRACTION_FAILED

View File

@ -0,0 +1,335 @@
"""
Tests for Slot Extraction Integration.
[AC-MRS-SLOT-EXTRACT-01] slot extraction 集成测试
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.models.mid.schemas import SlotSource
from app.services.mid.slot_extraction_integration import (
ExtractionResult,
ExtractionTrace,
SlotExtractionIntegration,
integrate_slot_extraction,
)
from app.services.mid.slot_strategy_executor import (
StrategyChainResult,
StrategyStepResult,
)
class TestExtractionTrace:
    """Tests for the ExtractionTrace record."""

    def test_init(self):
        """A fresh trace carries only the slot key plus inert defaults."""
        fresh = ExtractionTrace(slot_key="region")
        assert fresh.slot_key == "region"
        assert fresh.strategy is None
        assert fresh.validation_passed is False

    def test_to_dict(self):
        """Serialisation mirrors every populated field."""
        populated = ExtractionTrace(
            slot_key="region",
            strategy="rule",
            extracted_value="北京",
            validation_passed=True,
            final_value="北京",
            execution_time_ms=10.5,
        )
        payload = populated.to_dict()
        assert payload["validation_passed"] is True
        assert payload["slot_key"] == "region"
        assert payload["strategy"] == "rule"
        assert payload["extracted_value"] == "北京"
class TestExtractionResult:
    """Tests for the ExtractionResult aggregate."""

    def test_init(self):
        """A fresh result reports failure and empty collections."""
        fresh = ExtractionResult()
        assert fresh.success is False
        assert fresh.extracted_slots == {}
        assert fresh.failed_slots == []

    def test_to_dict(self):
        """Serialisation mirrors slots, failures and ask-back bookkeeping."""
        populated = ExtractionResult(
            success=True,
            extracted_slots={"region": "北京"},
            failed_slots=["product"],
            traces=[ExtractionTrace(slot_key="region")],
            total_execution_time_ms=50.0,
            ask_back_triggered=True,
            ask_back_prompts=["请提供产品信息"],
        )
        payload = populated.to_dict()
        assert payload["ask_back_triggered"] is True
        assert payload["success"] is True
        assert payload["extracted_slots"] == {"region": "北京"}
        assert payload["failed_slots"] == ["product"]
class TestSlotExtractionIntegration:
    """Tests for SlotExtractionIntegration."""

    @pytest.fixture
    def mock_session(self):
        """Mocked async DB session."""
        return AsyncMock()

    @pytest.fixture
    def integration(self, mock_session):
        """Integration instance under test."""
        return SlotExtractionIntegration(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
        )

    def test_default_strategies(self, integration):
        """Default strategy chain is rule first, then LLM."""
        assert integration.DEFAULT_STRATEGIES == ["rule", "llm"]

    def test_get_source_for_strategy(self, integration):
        """Each extraction strategy maps to its SlotSource value."""
        assert integration._get_source_for_strategy("rule") == SlotSource.RULE_EXTRACTED.value
        assert integration._get_source_for_strategy("llm") == SlotSource.LLM_INFERRED.value
        assert integration._get_source_for_strategy("user_input") == SlotSource.USER_CONFIRMED.value

    def test_get_confidence_for_source(self, integration):
        """Each source maps to a fixed confidence score."""
        assert integration._get_confidence_for_source(SlotSource.USER_CONFIRMED.value) == 1.0
        assert integration._get_confidence_for_source(SlotSource.RULE_EXTRACTED.value) == 0.9
        assert integration._get_confidence_for_source(SlotSource.LLM_INFERRED.value) == 0.7

    @pytest.mark.asyncio
    async def test_extract_and_fill_no_target_slots(self, integration):
        """With no target slots, the call trivially succeeds with nothing extracted."""
        result = await integration.extract_and_fill(
            user_input="测试输入",
            target_slots=[],
        )
        assert result.success is True
        assert result.extracted_slots == {}

    @pytest.mark.asyncio
    async def test_extract_and_fill_slot_not_found(self, integration):
        """A missing slot definition marks the slot as failed with a reason."""
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=None
        )
        integration._slot_manager.get_ask_back_prompt = AsyncMock(return_value=None)
        result = await integration.extract_and_fill(
            user_input="测试输入",
            target_slots=["unknown_slot"],
        )
        assert result.success is False
        assert "unknown_slot" in result.failed_slots
        assert result.traces[0].failure_reason == "Slot definition not found"

    @pytest.mark.asyncio
    async def test_extract_and_fill_extraction_success(self, integration):
        """A successful strategy-chain hit flows through backfill into the result."""
        mock_slot_def = MagicMock()
        mock_slot_def.type = "string"
        mock_slot_def.validation_rule = None
        mock_slot_def.ask_back_prompt = "请提供地区"
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=mock_slot_def
        )
        # Stub the strategy chain to return a rule-based success.
        with patch.object(integration._strategy_executor, 'execute_chain') as mock_chain:
            mock_chain.return_value = StrategyChainResult(
                slot_key="region",
                success=True,
                final_value="北京",
                final_strategy="rule",
            )
            # Stub the backfill service; is_success is a plain callable attribute.
            with patch.object(integration, '_get_backfill_service') as mock_backfill_svc:
                mock_backfill = AsyncMock()
                mock_backfill.backfill_single_slot = AsyncMock(
                    return_value=MagicMock(
                        is_success=lambda: True,
                        normalized_value="北京",
                        error_message=None,
                    )
                )
                mock_backfill_svc.return_value = mock_backfill
                # Persistence is mocked out entirely.
                with patch.object(integration, '_save_extracted_slots', new_callable=AsyncMock):
                    result = await integration.extract_and_fill(
                        user_input="我想查询北京的产品",
                        target_slots=["region"],
                    )
                    assert result.success is True
                    assert "region" in result.extracted_slots
                    assert result.extracted_slots["region"] == "北京"

    @pytest.mark.asyncio
    async def test_extract_and_fill_extraction_failed(self, integration):
        """A failed chain triggers the ask-back flow with the configured prompt."""
        mock_slot_def = MagicMock()
        mock_slot_def.type = "string"
        mock_slot_def.validation_rule = None
        mock_slot_def.ask_back_prompt = "请提供地区"
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=mock_slot_def
        )
        with patch.object(integration._strategy_executor, 'execute_chain') as mock_chain:
            mock_chain.return_value = StrategyChainResult(
                slot_key="region",
                success=False,
                steps=[
                    StrategyStepResult(
                        strategy="rule",
                        success=False,
                        failure_reason="无法提取",
                    )
                ],
            )
            integration._slot_manager.get_ask_back_prompt = AsyncMock(
                return_value="请提供地区"
            )
            result = await integration.extract_and_fill(
                user_input="测试输入",
                target_slots=["region"],
            )
            assert result.success is False
            assert "region" in result.failed_slots
            assert result.ask_back_triggered is True
            assert "请提供地区" in result.ask_back_prompts

    @pytest.mark.asyncio
    async def test_get_missing_required_slots_from_state(self, integration):
        """Missing slots are read from an aggregated SlotState when available."""
        from app.services.mid.slot_state_aggregator import SlotState
        slot_state = SlotState()
        slot_state.missing_required_slots = [
            {"slot_key": "region"},
            {"slot_key": "product"},
        ]
        result = await integration._get_missing_required_slots(slot_state)
        assert "region" in result
        assert "product" in result

    @pytest.mark.asyncio
    async def test_get_missing_required_slots_from_db(self, integration):
        """Without a state, missing slots fall back to the DB slot definitions."""
        mock_defs = [
            MagicMock(slot_key="region"),
            MagicMock(slot_key="product"),
        ]
        integration._slot_def_service.list_slot_definitions = AsyncMock(
            return_value=mock_defs
        )
        result = await integration._get_missing_required_slots(None)
        assert "region" in result
        assert "product" in result
class TestIntegrateSlotExtraction:
    """Tests for the integrate_slot_extraction convenience function."""

    @pytest.mark.asyncio
    async def test_integrate(self):
        """The helper constructs an integration and delegates to extract_and_fill."""
        mock_session = AsyncMock()
        # Patch the class so the helper builds our mock instead of a real instance.
        with patch('app.services.mid.slot_extraction_integration.SlotExtractionIntegration') as mock_cls:
            mock_instance = AsyncMock()
            mock_instance.extract_and_fill = AsyncMock(
                return_value=ExtractionResult(success=True, extracted_slots={"region": "北京"})
            )
            mock_cls.return_value = mock_instance
            result = await integrate_slot_extraction(
                session=mock_session,
                tenant_id="tenant_1",
                session_id="session_1",
                user_input="我想查询北京的产品",
            )
            assert result.success is True
            assert "region" in result.extracted_slots
class TestExtractionTraceFlow:
    """End-to-end trace bookkeeping through a full extraction pass."""

    @pytest.fixture
    def integration(self):
        """Integration instance with a mocked session."""
        mock_session = AsyncMock()
        return SlotExtractionIntegration(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
        )

    @pytest.mark.asyncio
    async def test_full_extraction_flow(self, integration):
        """A successful pass produces one fully-populated trace entry."""
        mock_slot_def = MagicMock()
        mock_slot_def.type = "string"
        mock_slot_def.validation_rule = None
        mock_slot_def.ask_back_prompt = None
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=mock_slot_def
        )
        # Stub the strategy chain with a rule-based success.
        with patch.object(integration._strategy_executor, 'execute_chain') as mock_chain:
            mock_chain.return_value = StrategyChainResult(
                slot_key="region",
                success=True,
                final_value="北京",
                final_strategy="rule",
            )
            # Stub backfill; is_success is a plain callable attribute on the mock.
            with patch.object(integration, '_get_backfill_service') as mock_backfill_svc:
                mock_backfill = AsyncMock()
                mock_backfill.backfill_single_slot = AsyncMock(
                    return_value=MagicMock(
                        is_success=lambda: True,
                        normalized_value="北京",
                        error_message=None,
                    )
                )
                mock_backfill_svc.return_value = mock_backfill
                with patch.object(integration, '_save_extracted_slots', new_callable=AsyncMock):
                    result = await integration.extract_and_fill(
                        user_input="北京",
                        target_slots=["region"],
                    )
                    assert len(result.traces) == 1
                    trace = result.traces[0]
                    assert trace.slot_key == "region"
                    assert trace.strategy == "rule"
                    assert trace.extracted_value == "北京"
                    assert trace.validation_passed is True
                    assert trace.final_value == "北京"

View File

@ -0,0 +1,256 @@
"""
Tests for Slot State Aggregator.
[AC-MRS-SLOT-META-01] 槽位状态聚合服务测试
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.models.mid.schemas import MemorySlot, SlotSource
from app.services.mid.slot_state_aggregator import (
SlotState,
SlotStateAggregator,
create_slot_state_aggregator,
)
class TestSlotState:
    """Tests for the SlotState data class."""

    def test_slot_state_initialization(self):
        """A default-constructed state starts with every container empty."""
        fresh = SlotState()
        assert fresh.filled_slots == {}
        assert fresh.missing_required_slots == []
        assert fresh.slot_sources == {}
        assert fresh.slot_confidence == {}
        assert fresh.slot_to_field_map == {}

    def test_get_value_for_filter_direct_match(self):
        """A filter named like a slot key resolves directly."""
        direct = SlotState(
            filled_slots={"product_line": "vip_course"},
            slot_to_field_map={},
        )
        assert direct.get_value_for_filter("product_line") == "vip_course"

    def test_get_value_for_filter_via_mapping(self):
        """A filter may resolve through the slot-to-field mapping."""
        mapped = SlotState(
            filled_slots={"product": "vip_course"},
            slot_to_field_map={"product": "product_line"},
        )
        assert mapped.get_value_for_filter("product_line") == "vip_course"

    def test_get_value_for_filter_not_found(self):
        """Unknown filter names resolve to None."""
        empty = SlotState(filled_slots={})
        assert empty.get_value_for_filter("non_existent") is None

    def test_to_debug_info(self):
        """Debug info surfaces filled and missing slots."""
        populated = SlotState(
            filled_slots={"key": "value"},
            missing_required_slots=[{"slot_key": "missing"}],
        )
        info = populated.to_debug_info()
        assert info["filled_slots"] == {"key": "value"}
        assert len(info["missing_required_slots"]) == 1
class TestSlotStateAggregator:
    """Tests for SlotStateAggregator."""

    @pytest.fixture
    def mock_session(self):
        """Mocked async database session."""
        return AsyncMock()

    @pytest.fixture
    def aggregator(self, mock_session):
        """Aggregator instance bound to a test tenant."""
        return SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

    @pytest.mark.asyncio
    async def test_aggregate_from_memory_slots(self, aggregator, mock_session):
        """Seeding from memory_slots fills value, source and confidence."""
        memory_slots = {
            "product_line": MemorySlot(
                key="product_line",
                value="vip_course",
                source=SlotSource.USER_CONFIRMED,
                confidence=1.0,
            )
        }
        # Slot-definition service returns an empty list (no required slots).
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[],
        ):
            state = await aggregator.aggregate(
                memory_slots=memory_slots,
                current_input_slots=None,
                context=None,
            )
            assert state.filled_slots["product_line"] == "vip_course"
            assert state.slot_sources["product_line"] == "user_confirmed"
            assert state.slot_confidence["product_line"] == 1.0

    @pytest.mark.asyncio
    async def test_aggregate_current_input_priority(self, aggregator, mock_session):
        """Current-turn input outranks values remembered from earlier turns."""
        memory_slots = {
            "product_line": MemorySlot(
                key="product_line",
                value="old_value",
                source=SlotSource.USER_CONFIRMED,
                confidence=1.0,
            )
        }
        current_input = {"product_line": "new_value"}
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[],
        ):
            state = await aggregator.aggregate(
                memory_slots=memory_slots,
                current_input_slots=current_input,
                context=current_input and None or None,  # noqa: kept as comment only
            )
            # The current input should override the remembered value.
            assert state.filled_slots["product_line"] == "new_value"
            assert state.slot_sources["product_line"] == "user_confirmed"

    @pytest.mark.asyncio
    async def test_aggregate_extract_from_context(self, aggregator, mock_session):
        """Known slot keys are lifted out of the request context."""
        context = {
            "scene": "open_consult",
            "product_line": "vip_course",
            "other_key": "other_value",
        }
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[],
        ):
            state = await aggregator.aggregate(
                memory_slots=None,
                current_input_slots=None,
                context=context,
            )
            # scene and product_line should be extracted from the context.
            assert state.filled_slots.get("scene") == "open_consult"
            assert state.filled_slots.get("product_line") == "vip_course"
            assert state.slot_sources.get("scene") == "context"

    @pytest.mark.asyncio
    async def test_generate_ask_back_response_with_prompt(self, aggregator):
        """Ask-back uses the configured ask_back_prompt verbatim when present."""
        state = SlotState(
            missing_required_slots=[
                {
                    "slot_key": "region",
                    "label": "地区",
                    "reason": "required_slot_missing",
                    "ask_back_prompt": "请问您在哪个地区?",
                }
            ]
        )
        response = await aggregator.generate_ask_back_response(state)
        assert response == "请问您在哪个地区?"

    @pytest.mark.asyncio
    async def test_generate_ask_back_response_generic(self, aggregator):
        """Without a configured prompt, a generic template mentions the label."""
        state = SlotState(
            missing_required_slots=[
                {
                    "slot_key": "region",
                    "label": "地区",
                    "reason": "required_slot_missing",
                    # no ask_back_prompt configured
                }
            ]
        )
        response = await aggregator.generate_ask_back_response(state)
        assert "地区" in response

    @pytest.mark.asyncio
    async def test_generate_ask_back_response_no_missing(self, aggregator):
        """No missing slots means no ask-back: None is returned."""
        state = SlotState(missing_required_slots=[])
        response = await aggregator.generate_ask_back_response(state)
        assert response is None
class TestCreateSlotStateAggregator:
    """Tests for the aggregator factory function."""

    def test_create_aggregator(self):
        """The factory builds a SlotStateAggregator bound to the tenant."""
        session = MagicMock()
        built = create_slot_state_aggregator(
            session=session,
            tenant_id="test_tenant",
        )
        assert isinstance(built, SlotStateAggregator)
        assert built._tenant_id == "test_tenant"
class TestSlotStateFilterPriority:
    """Priority among filter-value sources (slot > context > default)."""

    @pytest.mark.asyncio
    async def test_filter_priority_slot_first(self):
        """When the same key arrives from memory, current input and context,
        the current-turn input wins.

        Note: the redundant function-local ``from unittest.mock import
        AsyncMock, MagicMock`` was removed — the module-level import already
        provides both names.
        """
        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )
        # Fake slot definition: optional and not linked to any KB field.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "product"
        mock_slot_def.linked_field_id = None
        mock_slot_def.required = False
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ):
            state = await aggregator.aggregate(
                memory_slots={
                    "product": MemorySlot(
                        key="product",
                        value="from_memory",
                        source=SlotSource.USER_CONFIRMED,
                        confidence=1.0,
                    )
                },
                current_input_slots={"product": "from_input"},
                context={"product": "from_context"},
            )
            # The current turn's input has the highest priority.
            assert state.filled_slots["product"] == "from_input"

View File

@ -0,0 +1,399 @@
"""
Tests for Slot State Cache.
[AC-MRS-SLOT-CACHE-01] 多轮状态持久化测试
"""
import json
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.cache.slot_state_cache import (
CachedSlotState,
CachedSlotValue,
SlotStateCache,
get_slot_state_cache,
)
class TestCachedSlotValue:
    """Tests for the CachedSlotValue record."""

    def test_init(self):
        """Construction stores value/source/confidence and stamps updated_at."""
        cached = CachedSlotValue(
            value="test_value",
            source="user_confirmed",
            confidence=0.9,
        )
        assert cached.value == "test_value"
        assert cached.source == "user_confirmed"
        assert cached.confidence == 0.9
        assert cached.updated_at > 0

    def test_to_dict(self):
        """Serialisation exposes all fields including the timestamp."""
        cached = CachedSlotValue(
            value="test_value",
            source="rule_extracted",
            confidence=0.8,
        )
        payload = cached.to_dict()
        assert "updated_at" in payload
        assert payload["value"] == "test_value"
        assert payload["source"] == "rule_extracted"
        assert payload["confidence"] == 0.8

    def test_from_dict(self):
        """Deserialisation restores each field verbatim."""
        payload = {
            "value": "test_value",
            "source": "llm_inferred",
            "confidence": 0.7,
            "updated_at": 12345.0,
        }
        restored = CachedSlotValue.from_dict(payload)
        assert restored.value == "test_value"
        assert restored.source == "llm_inferred"
        assert restored.confidence == 0.7
        assert restored.updated_at == 12345.0
class TestCachedSlotState:
    """Tests for the CachedSlotState container."""

    def test_init(self):
        """A fresh state has empty maps and both timestamps set."""
        state = CachedSlotState()
        assert state.filled_slots == {}
        assert state.slot_to_field_map == {}
        assert state.created_at > 0
        assert state.updated_at > 0

    def test_with_slots(self):
        """Constructing with slots and a field map stores both."""
        slots = {
            "region": CachedSlotValue(value="北京", source="user_confirmed"),
            "product": CachedSlotValue(value="手机", source="rule_extracted"),
        }
        state = CachedSlotState(
            filled_slots=slots,
            slot_to_field_map={"region": "region_field"},
        )
        assert len(state.filled_slots) == 2
        assert state.slot_to_field_map["region"] == "region_field"

    def test_to_dict_and_from_dict(self):
        """A round trip through to_dict/from_dict preserves all data."""
        slots = {
            "region": CachedSlotValue(value="北京", source="user_confirmed"),
        }
        original = CachedSlotState(
            filled_slots=slots,
            slot_to_field_map={"region": "region_field"},
        )
        d = original.to_dict()
        restored = CachedSlotState.from_dict(d)
        assert len(restored.filled_slots) == 1
        assert restored.filled_slots["region"].value == "北京"
        assert restored.filled_slots["region"].source == "user_confirmed"
        assert restored.slot_to_field_map["region"] == "region_field"

    def test_get_simple_filled_slots(self):
        """The simplified view maps slot key straight to raw value."""
        slots = {
            "region": CachedSlotValue(value="北京", source="user_confirmed"),
            "product": CachedSlotValue(value="手机", source="rule_extracted"),
        }
        state = CachedSlotState(filled_slots=slots)
        simple = state.get_simple_filled_slots()
        assert simple == {"region": "北京", "product": "手机"}

    def test_get_slot_sources(self):
        """Per-slot sources are extracted into a flat dict."""
        slots = {
            "region": CachedSlotValue(value="北京", source="user_confirmed"),
            "product": CachedSlotValue(value="手机", source="rule_extracted"),
        }
        state = CachedSlotState(filled_slots=slots)
        sources = state.get_slot_sources()
        assert sources == {"region": "user_confirmed", "product": "rule_extracted"}

    def test_get_slot_confidence(self):
        """Per-slot confidences are extracted into a flat dict."""
        slots = {
            "region": CachedSlotValue(value="北京", source="user_confirmed", confidence=1.0),
            "product": CachedSlotValue(value="手机", source="rule_extracted", confidence=0.8),
        }
        state = CachedSlotState(filled_slots=slots)
        confidence = state.get_slot_confidence()
        assert confidence == {"region": 1.0, "product": 0.8}
class TestSlotStateCache:
    """Tests for SlotStateCache (two-level slot-state cache)."""

    def test_source_priority(self):
        """Source priority ladder: user > rule > llm > context > default > unknown."""
        cache = SlotStateCache()
        assert cache._get_source_priority("user_confirmed") == 100
        assert cache._get_source_priority("rule_extracted") == 80
        assert cache._get_source_priority("llm_inferred") == 60
        assert cache._get_source_priority("context") == 40
        assert cache._get_source_priority("default") == 20
        assert cache._get_source_priority("unknown") == 0

    def test_make_key(self):
        """Redis keys follow the slot_state:<tenant>:<session> pattern."""
        cache = SlotStateCache()
        key = cache._make_key("tenant_123", "session_456")
        assert key == "slot_state:tenant_123:session_456"

    @pytest.mark.asyncio
    async def test_l1_cache_hit(self):
        """A fresh entry planted in the local (L1) cache is returned by get."""
        cache = SlotStateCache()
        tenant_id = "tenant_1"
        session_id = "session_1"
        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )
        # Plant the entry directly in the L1 dict with a current timestamp.
        cache._local_cache[f"{tenant_id}:{session_id}"] = (state, time.time())
        result = await cache.get(tenant_id, session_id)
        assert result is not None
        assert result.filled_slots["region"].value == "北京"

    @pytest.mark.asyncio
    async def test_l1_cache_expired(self):
        """An entry older than the L1 TTL is evicted and get returns None."""
        cache = SlotStateCache()
        tenant_id = "tenant_1"
        session_id = "session_1"
        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )
        # 400 s ago — beyond the 300 s local TTL.
        old_time = time.time() - 400
        cache._local_cache[f"{tenant_id}:{session_id}"] = (state, old_time)
        result = await cache.get(tenant_id, session_id)
        assert result is None
        assert f"{tenant_id}:{session_id}" not in cache._local_cache

    @pytest.mark.asyncio
    async def test_set_and_get_l1(self):
        """With Redis disabled, set/get round-trips through the L1 cache."""
        cache = SlotStateCache(redis_client=None)
        cache._enabled = False
        tenant_id = "tenant_1"
        session_id = "session_1"
        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )
        await cache.set(tenant_id, session_id, state)
        local_key = f"{tenant_id}:{session_id}"
        assert local_key in cache._local_cache
        result = await cache.get(tenant_id, session_id)
        assert result is not None
        assert result.filled_slots["region"].value == "北京"

    @pytest.mark.asyncio
    async def test_delete(self):
        """delete removes the entry so a later get misses."""
        cache = SlotStateCache(redis_client=None)
        cache._enabled = False
        tenant_id = "tenant_1"
        session_id = "session_1"
        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )
        await cache.set(tenant_id, session_id, state)
        await cache.delete(tenant_id, session_id)
        result = await cache.get(tenant_id, session_id)
        assert result is None

    @pytest.mark.asyncio
    async def test_clear_slot(self):
        """clear_slot removes one slot but leaves the others intact."""
        cache = SlotStateCache(redis_client=None)
        cache._enabled = False
        tenant_id = "tenant_1"
        session_id = "session_1"
        state = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="北京", source="user_confirmed"),
                "product": CachedSlotValue(value="手机", source="rule_extracted"),
            },
        )
        await cache.set(tenant_id, session_id, state)
        await cache.clear_slot(tenant_id, session_id, "region")
        result = await cache.get(tenant_id, session_id)
        assert result is not None
        assert "region" not in result.filled_slots
        assert "product" in result.filled_slots

    @pytest.mark.asyncio
    async def test_merge_and_set_priority(self):
        """A higher-priority incoming value replaces the cached one."""
        cache = SlotStateCache(redis_client=None)
        cache._enabled = False
        tenant_id = "tenant_1"
        session_id = "session_1"
        existing_state = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="上海", source="llm_inferred", confidence=0.6),
            },
        )
        await cache.set(tenant_id, session_id, existing_state)
        new_slots = {
            "region": CachedSlotValue(value="北京", source="user_confirmed", confidence=1.0),
        }
        result = await cache.merge_and_set(tenant_id, session_id, new_slots)
        assert result.filled_slots["region"].value == "北京"
        assert result.filled_slots["region"].source == "user_confirmed"

    @pytest.mark.asyncio
    async def test_merge_and_set_lower_priority_ignored(self):
        """A lower-priority incoming value is ignored on merge."""
        cache = SlotStateCache(redis_client=None)
        cache._enabled = False
        tenant_id = "tenant_1"
        session_id = "session_1"
        existing_state = CachedSlotState(
            filled_slots={
                "region": CachedSlotValue(value="北京", source="user_confirmed", confidence=1.0),
            },
        )
        await cache.set(tenant_id, session_id, existing_state)
        new_slots = {
            "region": CachedSlotValue(value="上海", source="llm_inferred", confidence=0.6),
        }
        result = await cache.merge_and_set(tenant_id, session_id, new_slots)
        assert result.filled_slots["region"].value == "北京"
        assert result.filled_slots["region"].source == "user_confirmed"
class TestGetSlotStateCache:
    """Singleton behaviour of get_slot_state_cache."""

    def test_singleton(self):
        """Repeated calls hand back the very same cache object."""
        first = get_slot_state_cache()
        second = get_slot_state_cache()
        assert first is second
class TestSlotStateCacheWithRedis:
    """SlotStateCache integration tests against a mocked Redis client."""

    @pytest.mark.asyncio
    async def test_redis_set_and_get(self):
        """set writes to Redis via setex under the expected key."""
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value=None)
        mock_redis.setex = AsyncMock(return_value=True)
        cache = SlotStateCache(redis_client=mock_redis)
        tenant_id = "tenant_1"
        session_id = "session_1"
        state = CachedSlotState(
            filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")},
        )
        await cache.set(tenant_id, session_id, state)
        mock_redis.setex.assert_called_once()
        # First positional arg of setex is the Redis key.
        call_args = mock_redis.setex.call_args
        assert call_args[0][0] == f"slot_state:{tenant_id}:{session_id}"

    @pytest.mark.asyncio
    async def test_redis_get_hit(self):
        """A JSON payload stored in Redis is deserialised back into state."""
        state_dict = {
            "filled_slots": {
                "region": {
                    "value": "北京",
                    "source": "user_confirmed",
                    "confidence": 1.0,
                    "updated_at": 12345.0,
                }
            },
            "slot_to_field_map": {"region": "region_field"},
            "created_at": 12340.0,
            "updated_at": 12345.0,
        }
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(return_value=json.dumps(state_dict))
        cache = SlotStateCache(redis_client=mock_redis)
        tenant_id = "tenant_1"
        session_id = "session_1"
        result = await cache.get(tenant_id, session_id)
        assert result is not None
        assert result.filled_slots["region"].value == "北京"
        assert result.filled_slots["region"].source == "user_confirmed"

    @pytest.mark.asyncio
    async def test_redis_delete(self):
        """delete issues a Redis DELETE for the namespaced key."""
        mock_redis = AsyncMock()
        mock_redis.delete = AsyncMock(return_value=1)
        cache = SlotStateCache(redis_client=mock_redis)
        tenant_id = "tenant_1"
        session_id = "session_1"
        await cache.delete(tenant_id, session_id)
        mock_redis.delete.assert_called_once_with(f"slot_state:{tenant_id}:{session_id}")
class TestCacheTTL:
    """Default TTL configuration of SlotStateCache."""

    def test_default_ttl(self):
        """The Redis-level TTL defaults to 1800 seconds."""
        assert SlotStateCache()._cache_ttl == 1800

    def test_local_cache_ttl(self):
        """The in-process (L1) TTL defaults to 300 seconds."""
        assert SlotStateCache()._local_cache_ttl == 300

View File

@ -0,0 +1,244 @@
"""
Tests for Slot Strategy Executor.
[AC-MRS-07-UPGRADE] 提取策略链执行器测试
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from app.models.entities import ExtractFailureType
from app.services.mid.slot_strategy_executor import (
SlotStrategyExecutor,
ExtractContext,
StrategyChainResult,
execute_extract_strategies,
)
class TestSlotStrategyExecutor:
    """Tests for the slot strategy-chain executor."""

    @pytest.fixture
    def executor(self):
        """Executor instance under test."""
        return SlotStrategyExecutor()

    @pytest.fixture
    def context(self):
        """Extraction context for the 'grade' slot."""
        return ExtractContext(
            tenant_id="test-tenant",
            slot_key="grade",
            user_input="我想了解初一语文课程",
            slot_type="string",
        )

    @pytest.mark.asyncio
    async def test_execute_chain_success_on_first_step(self, executor, context):
        """The chain stops as soon as the first strategy succeeds."""
        # Mock rule extractor succeeding.
        mock_rule = AsyncMock(return_value="初一")
        executor._extractors["rule"] = mock_rule
        result = await executor.execute_chain(
            strategies=["rule", "llm", "user_input"],
            context=context,
        )
        assert result.success is True
        assert result.final_value == "初一"
        assert result.final_strategy == "rule"
        assert len(result.steps) == 1
        assert result.steps[0].success is True
        mock_rule.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_chain_fallback_to_second_step(self, executor, context):
        """First strategy fails, second succeeds — two steps recorded."""
        # Rule extractor fails (returns nothing).
        mock_rule = AsyncMock(return_value=None)
        # LLM extractor succeeds.
        mock_llm = AsyncMock(return_value="初一")
        executor._extractors["rule"] = mock_rule
        executor._extractors["llm"] = mock_llm
        result = await executor.execute_chain(
            strategies=["rule", "llm", "user_input"],
            context=context,
        )
        assert result.success is True
        assert result.final_value == "初一"
        assert result.final_strategy == "llm"
        assert len(result.steps) == 2
        assert result.steps[0].success is False
        assert result.steps[0].failure_type == ExtractFailureType.EXTRACT_EMPTY
        assert result.steps[1].success is True

    @pytest.mark.asyncio
    async def test_execute_chain_all_failed(self, executor, context):
        """When every strategy fails, the ask-back prompt is surfaced."""
        # All extractors fail.
        mock_rule = AsyncMock(return_value=None)
        mock_llm = AsyncMock(return_value=None)
        mock_user_input = AsyncMock(return_value=None)
        executor._extractors["rule"] = mock_rule
        executor._extractors["llm"] = mock_llm
        executor._extractors["user_input"] = mock_user_input
        result = await executor.execute_chain(
            strategies=["rule", "llm", "user_input"],
            context=context,
            ask_back_prompt="请告诉我您的年级",
        )
        assert result.success is False
        assert result.final_value is None
        assert result.final_strategy is None
        assert len(result.steps) == 3
        assert result.ask_back_prompt == "请告诉我您的年级"
        # Every step failed with an empty-extraction failure.
        for step in result.steps:
            assert step.success is False
            assert step.failure_type == ExtractFailureType.EXTRACT_EMPTY

    @pytest.mark.asyncio
    async def test_execute_chain_validation_failure(self, executor, context):
        """A value that fails the validation regex is rejected."""
        context.validation_rule = r"^初[一二三]$"  # only 初一/初二/初三 allowed
        # Rule extractor returns a value outside the allowed set.
        mock_rule = AsyncMock(return_value="高一")
        executor._extractors["rule"] = mock_rule
        result = await executor.execute_chain(
            strategies=["rule"],
            context=context,
        )
        assert result.success is False
        assert result.steps[0].failure_type == ExtractFailureType.EXTRACT_VALIDATION_FAIL
        assert "Validation failed" in result.steps[0].failure_reason

    @pytest.mark.asyncio
    async def test_execute_chain_runtime_error(self, executor, context):
        """An extractor exception is captured as a runtime-error step."""
        # Rule extractor raises.
        mock_rule = AsyncMock(side_effect=Exception("LLM service unavailable"))
        executor._extractors["rule"] = mock_rule
        result = await executor.execute_chain(
            strategies=["rule"],
            context=context,
        )
        assert result.success is False
        assert result.steps[0].failure_type == ExtractFailureType.EXTRACT_RUNTIME_ERROR
        assert "Runtime error" in result.steps[0].failure_reason

    @pytest.mark.asyncio
    async def test_execute_chain_empty_strategies(self, executor, context):
        """An empty chain fails with no recorded steps."""
        result = await executor.execute_chain(
            strategies=[],
            context=context,
        )
        assert result.success is False
        assert len(result.steps) == 0

    @pytest.mark.asyncio
    async def test_execute_chain_unknown_strategy(self, executor, context):
        """An unregistered strategy name is reported as a runtime error."""
        result = await executor.execute_chain(
            strategies=["unknown_strategy"],
            context=context,
        )
        assert result.success is False
        assert result.steps[0].failure_type == ExtractFailureType.EXTRACT_RUNTIME_ERROR
        assert "Unknown strategy" in result.steps[0].failure_reason

    @pytest.mark.asyncio
    async def test_execute_chain_result_to_dict(self, executor, context):
        """to_dict mirrors outcome, steps and timing of the chain result."""
        mock_rule = AsyncMock(return_value="初一")
        executor._extractors["rule"] = mock_rule
        result = await executor.execute_chain(
            strategies=["rule"],
            context=context,
        )
        result_dict = result.to_dict()
        assert result_dict["slot_key"] == "grade"
        assert result_dict["success"] is True
        assert result_dict["final_value"] == "初一"
        assert result_dict["final_strategy"] == "rule"
        assert "steps" in result_dict
        assert "total_execution_time_ms" in result_dict
class TestExecuteExtractStrategies:
    """Tests for the execute_extract_strategies convenience wrapper."""

    @pytest.mark.asyncio
    async def test_convenience_function(self):
        """A single rule strategy succeeds and its value is surfaced."""
        rule_mock = AsyncMock(return_value="初一")
        chain_result = await execute_extract_strategies(
            strategies=["rule"],
            tenant_id="test-tenant",
            slot_key="grade",
            user_input="我想了解初一语文课程",
            rule_extractor=rule_mock,
        )
        assert chain_result.success is True
        assert chain_result.final_value == "初一"

    @pytest.mark.asyncio
    async def test_convenience_function_with_validation(self):
        """A value matching the validation regex passes through unchanged."""
        rule_mock = AsyncMock(return_value="初一")
        chain_result = await execute_extract_strategies(
            strategies=["rule"],
            tenant_id="test-tenant",
            slot_key="grade",
            user_input="我想了解初一语文课程",
            validation_rule=r"^初[一二三]$",
            rule_extractor=rule_mock,
        )
        assert chain_result.success is True
        assert chain_result.final_value == "初一"
class TestExtractContext:
    """Tests for the ExtractContext container."""

    def test_context_creation(self):
        """Every constructor argument is stored on the context unchanged."""
        ctx = ExtractContext(
            tenant_id="test-tenant",
            slot_key="grade",
            user_input="测试输入",
            slot_type="string",
            validation_rule=r"^初[一二三]$",
            history=[{"role": "user", "content": "你好"}],
            session_id="session-123",
        )
        expected = {
            "tenant_id": "test-tenant",
            "slot_key": "grade",
            "user_input": "测试输入",
            "slot_type": "string",
            "validation_rule": r"^初[一二三]$",
            "session_id": "session-123",
        }
        for attr, value in expected.items():
            assert getattr(ctx, attr) == value
        assert len(ctx.history) == 1

View File

@ -0,0 +1,541 @@
"""
Tests for Slot Validation Service.
槽位校验服务单元测试
"""
import pytest
from app.services.mid.slot_validation_service import (
SlotValidationService,
SlotValidationErrorCode,
ValidationResult,
SlotValidationError,
BatchValidationResult,
)
class TestSlotValidationService:
    """Shared fixtures for the slot validation test suites."""

    @pytest.fixture
    def service(self):
        """A fresh SlotValidationService instance for each test."""
        return SlotValidationService()

    @pytest.fixture
    def string_slot_def(self):
        """Optional string slot without a validation rule."""
        return {
            "slot_key": "name",
            "type": "string",
            "required": False,
            "validation_rule": None,
            "ask_back_prompt": "请输入您的姓名",
        }

    @pytest.fixture
    def required_string_slot_def(self):
        """Required string slot validated by a CN mobile-number regex."""
        return {
            "slot_key": "phone",
            "type": "string",
            "required": True,
            "validation_rule": r"^1[3-9]\d{9}$",
            "ask_back_prompt": "请输入正确的手机号码",
        }

    @pytest.fixture
    def number_slot_def(self):
        """Optional numeric slot."""
        return {
            "slot_key": "age",
            "type": "number",
            "required": False,
            "validation_rule": None,
            "ask_back_prompt": "请输入年龄",
        }

    @pytest.fixture
    def boolean_slot_def(self):
        """Optional boolean slot."""
        return {
            "slot_key": "is_student",
            "type": "boolean",
            "required": False,
            "validation_rule": None,
            "ask_back_prompt": "是否是学生?",
        }

    @pytest.fixture
    def enum_slot_def(self):
        """Optional single-choice enum slot (school grades)."""
        return {
            "slot_key": "grade",
            "type": "enum",
            "required": False,
            "options": ["初一", "初二", "初三", "高一", "高二", "高三"],
            "validation_rule": None,
            "ask_back_prompt": "请选择年级",
        }

    @pytest.fixture
    def array_enum_slot_def(self):
        """Optional multi-choice enum slot (school subjects)."""
        return {
            "slot_key": "subjects",
            "type": "array_enum",
            "options": ["语文", "数学", "英语", "物理", "化学"],
            "required": False,
            "validation_rule": None,
            "ask_back_prompt": "请选择学科",
        }

    @pytest.fixture
    def json_schema_slot_def(self):
        """Required string slot validated via a JSON-Schema rule."""
        return {
            "slot_key": "email",
            "type": "string",
            "required": True,
            "validation_rule": '{"type": "string", "format": "email"}',
            "ask_back_prompt": "请输入有效的邮箱地址",
        }
class TestBasicValidation:
    """Behavior when no usable validation rule is configured."""

    def test_empty_validation_rule(self, service, string_slot_def):
        """A None validation rule lets any value through unchanged."""
        string_slot_def["validation_rule"] = None
        res = service.validate_slot_value(string_slot_def, "test")
        assert res.ok is True
        assert res.normalized_value == "test"

    def test_whitespace_validation_rule(self, service, string_slot_def):
        """A whitespace-only rule is treated the same as no rule."""
        string_slot_def["validation_rule"] = " "
        assert service.validate_slot_value(string_slot_def, "test").ok is True

    def test_no_slot_definition(self, service):
        """A minimal (dynamic) slot definition accepts any value."""
        res = service.validate_slot_value({"slot_key": "dynamic_field"}, "any_value")
        assert res.ok is True
class TestRegexValidation:
    """Regex-based validation rules."""

    def test_regex_match(self, service, required_string_slot_def):
        """A value matching the pattern passes and is returned normalized."""
        res = service.validate_slot_value(required_string_slot_def, "13800138000")
        assert res.ok is True
        assert res.normalized_value == "13800138000"

    def test_regex_mismatch(self, service, required_string_slot_def):
        """A non-matching value fails with SLOT_REGEX_MISMATCH and its prompt."""
        res = service.validate_slot_value(required_string_slot_def, "invalid_phone")
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_REGEX_MISMATCH
        assert res.ask_back_prompt == "请输入正确的手机号码"

    def test_regex_invalid_pattern(self, service, string_slot_def):
        """A malformed pattern fails with SLOT_VALIDATION_RULE_INVALID."""
        string_slot_def["validation_rule"] = "[invalid("
        res = service.validate_slot_value(string_slot_def, "test")
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID

    def test_regex_with_chinese(self, service, string_slot_def):
        """A CJK-range pattern matches Chinese names and rejects Latin ones."""
        string_slot_def["validation_rule"] = r"^[\u4e00-\u9fa5]{2,4}$"
        assert service.validate_slot_value(string_slot_def, "张三").ok is True
        assert service.validate_slot_value(string_slot_def, "John").ok is False
class TestJsonSchemaValidation:
    """JSON-Schema validation rules (rule supplied as a JSON string)."""

    def test_json_schema_match(self, service):
        """A conforming object passes the schema."""
        slot = {
            "slot_key": "config",
            "type": "object",
            "validation_rule": '{"type": "object", "properties": {"name": {"type": "string"}}}',
        }
        assert service.validate_slot_value(slot, {"name": "test"}).ok is True

    def test_json_schema_mismatch(self, service):
        """A value violating the schema fails with the configured prompt."""
        slot = {
            "slot_key": "count",
            "type": "number",
            "validation_rule": '{"type": "integer", "minimum": 0, "maximum": 100}',
            "ask_back_prompt": "请输入0-100之间的整数",
        }
        res = service.validate_slot_value(slot, 150)
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_JSON_SCHEMA_MISMATCH
        assert res.ask_back_prompt == "请输入0-100之间的整数"

    def test_json_schema_invalid_json(self, service, string_slot_def):
        """An unparseable schema fails with SLOT_VALIDATION_RULE_INVALID."""
        string_slot_def["validation_rule"] = "{invalid json}"
        res = service.validate_slot_value(string_slot_def, "test")
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID

    def test_json_schema_array(self, service):
        """Array schemas check the item types."""
        slot = {
            "slot_key": "items",
            "type": "array",
            "validation_rule": '{"type": "array", "items": {"type": "string"}}',
        }
        assert service.validate_slot_value(slot, ["a", "b", "c"]).ok is True
        assert service.validate_slot_value(slot, [1, 2, 3]).ok is False
class TestRequiredValidation:
    """Enforcement of the required flag."""

    def test_required_missing_none(self, service, required_string_slot_def):
        """None for a required slot fails with SLOT_REQUIRED_MISSING."""
        res = service.validate_slot_value(required_string_slot_def, None)
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_REQUIRED_MISSING

    def test_required_missing_empty_string(self, service, required_string_slot_def):
        """An empty string counts as missing for a required slot."""
        res = service.validate_slot_value(required_string_slot_def, "")
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_REQUIRED_MISSING

    def test_required_missing_whitespace(self, service, required_string_slot_def):
        """Whitespace-only input counts as missing for a required slot."""
        res = service.validate_slot_value(required_string_slot_def, " ")
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_REQUIRED_MISSING

    def test_required_present(self, service, required_string_slot_def):
        """A concrete value satisfies the required check."""
        assert service.validate_slot_value(required_string_slot_def, "13800138000").ok is True

    def test_not_required_empty(self, service, string_slot_def):
        """Optional slots accept an empty value."""
        assert service.validate_slot_value(string_slot_def, "").ok is True
class TestTypeValidation:
    """Type coercion and type checking for each supported slot type."""

    # --- string ---

    def test_string_type(self, service, string_slot_def):
        """A plain string passes unchanged."""
        res = service.validate_slot_value(string_slot_def, "hello")
        assert res.ok is True
        assert res.normalized_value == "hello"

    def test_string_type_conversion(self, service, string_slot_def):
        """Non-string scalars are coerced to their string form."""
        res = service.validate_slot_value(string_slot_def, 123)
        assert res.ok is True
        assert res.normalized_value == "123"

    # --- number ---

    def test_number_type_integer(self, service, number_slot_def):
        """Integers pass through as-is."""
        res = service.validate_slot_value(number_slot_def, 25)
        assert res.ok is True
        assert res.normalized_value == 25

    def test_number_type_float(self, service, number_slot_def):
        """Floats pass through as-is."""
        res = service.validate_slot_value(number_slot_def, 25.5)
        assert res.ok is True
        assert res.normalized_value == 25.5

    def test_number_type_string_conversion(self, service, number_slot_def):
        """Numeric strings are coerced to numbers."""
        res = service.validate_slot_value(number_slot_def, "30")
        assert res.ok is True
        assert res.normalized_value == 30

    def test_number_type_invalid(self, service, number_slot_def):
        """Non-numeric strings fail with SLOT_TYPE_INVALID."""
        res = service.validate_slot_value(number_slot_def, "not_a_number")
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID

    def test_number_type_reject_boolean(self, service, number_slot_def):
        """Booleans are rejected for number slots (bool is not a number here)."""
        res = service.validate_slot_value(number_slot_def, True)
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID

    # --- boolean ---

    def test_boolean_type_true(self, service, boolean_slot_def):
        """True passes through as True."""
        res = service.validate_slot_value(boolean_slot_def, True)
        assert res.ok is True
        assert res.normalized_value is True

    def test_boolean_type_false(self, service, boolean_slot_def):
        """False passes through as False."""
        res = service.validate_slot_value(boolean_slot_def, False)
        assert res.ok is True
        assert res.normalized_value is False

    def test_boolean_type_string_true(self, service, boolean_slot_def):
        """The string "true" normalizes to True."""
        res = service.validate_slot_value(boolean_slot_def, "true")
        assert res.ok is True
        assert res.normalized_value is True

    def test_boolean_type_string_yes(self, service, boolean_slot_def):
        """Affirmative strings such as "yes" normalize to True."""
        # Fix: the original passed an empty string "" here, which contradicts
        # the test's stated intent (covering "yes"-style affirmative input) —
        # the affirmative literal was evidently lost. Use "yes" instead.
        res = service.validate_slot_value(boolean_slot_def, "yes")
        assert res.ok is True
        assert res.normalized_value is True

    def test_boolean_type_string_false(self, service, boolean_slot_def):
        """The string "false" normalizes to False."""
        res = service.validate_slot_value(boolean_slot_def, "false")
        assert res.ok is True
        assert res.normalized_value is False

    def test_boolean_type_invalid(self, service, boolean_slot_def):
        """Unrecognized strings fail with SLOT_TYPE_INVALID."""
        res = service.validate_slot_value(boolean_slot_def, "maybe")
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID

    # --- enum ---

    def test_enum_type_valid(self, service, enum_slot_def):
        """A value in the options list passes."""
        res = service.validate_slot_value(enum_slot_def, "高一")
        assert res.ok is True
        assert res.normalized_value == "高一"

    def test_enum_type_invalid(self, service, enum_slot_def):
        """A value outside the options list fails with SLOT_ENUM_INVALID."""
        res = service.validate_slot_value(enum_slot_def, "大一")
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_ENUM_INVALID

    def test_enum_type_not_string(self, service, enum_slot_def):
        """Non-string values fail the enum type check."""
        res = service.validate_slot_value(enum_slot_def, 123)
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID

    # --- array_enum ---

    def test_array_enum_type_valid(self, service, array_enum_slot_def):
        """A list whose items are all in the options passes."""
        assert service.validate_slot_value(array_enum_slot_def, ["语文", "数学"]).ok is True

    def test_array_enum_type_invalid_item(self, service, array_enum_slot_def):
        """Any item outside the options fails with SLOT_ARRAY_ENUM_INVALID."""
        res = service.validate_slot_value(array_enum_slot_def, ["语文", "生物"])
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_ARRAY_ENUM_INVALID

    def test_array_enum_type_not_array(self, service, array_enum_slot_def):
        """A bare string (not a list) fails the array type check."""
        res = service.validate_slot_value(array_enum_slot_def, "语文")
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID

    def test_array_enum_type_non_string_item(self, service, array_enum_slot_def):
        """Non-string items fail with SLOT_ARRAY_ENUM_INVALID."""
        res = service.validate_slot_value(array_enum_slot_def, ["语文", 123])
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_ARRAY_ENUM_INVALID
class TestBatchValidation:
    """validate_slots over multiple slot definitions at once."""

    def test_batch_all_valid(self, service, string_slot_def, number_slot_def):
        """All values valid: ok, no errors, normalized values returned."""
        outcome = service.validate_slots(
            [string_slot_def, number_slot_def],
            {"name": "张三", "age": 25},
        )
        assert outcome.ok is True
        assert not outcome.errors
        assert outcome.validated_values["name"] == "张三"
        assert outcome.validated_values["age"] == 25

    def test_batch_some_invalid(self, service, string_slot_def, number_slot_def):
        """One bad value produces exactly one error for that slot."""
        outcome = service.validate_slots(
            [string_slot_def, number_slot_def],
            {"name": "张三", "age": "not_a_number"},
        )
        assert outcome.ok is False
        assert len(outcome.errors) == 1
        assert outcome.errors[0].slot_key == "age"

    def test_batch_missing_required(
        self, service, required_string_slot_def, string_slot_def
    ):
        """A missing required slot is reported as SLOT_REQUIRED_MISSING."""
        values = {"name": "张三"}  # "phone" intentionally absent
        outcome = service.validate_slots(
            [required_string_slot_def, string_slot_def], values
        )
        assert outcome.ok is False
        assert len(outcome.errors) == 1
        err = outcome.errors[0]
        assert err.slot_key == "phone"
        assert err.error_code == SlotValidationErrorCode.SLOT_REQUIRED_MISSING

    def test_batch_undefined_slot(self, service, string_slot_def):
        """Values without a matching definition pass through untouched."""
        outcome = service.validate_slots(
            [string_slot_def],
            {"name": "张三", "undefined_field": "value"},
        )
        assert outcome.ok is True
        assert "undefined_field" in outcome.validated_values
class TestCombinedValidation:
    """Interaction between the type check and regex/JSON-Schema rules."""

    def _make_code_slot(self, slot_type, rule):
        """Build a required 'code' slot of the given type and rule."""
        return {
            "slot_key": "code",
            "type": slot_type,
            "required": True,
            "validation_rule": rule,
        }

    def test_type_and_regex_both_pass(self, service):
        """Value satisfies both the declared type and the regex."""
        slot = self._make_code_slot("string", r"^[A-Z]{2}\d{4}$")
        assert service.validate_slot_value(slot, "AB1234").ok is True

    def test_type_pass_regex_fail(self, service):
        """Type check passes but the regex rejects the value."""
        slot = self._make_code_slot("string", r"^[A-Z]{2}\d{4}$")
        res = service.validate_slot_value(slot, "ab1234")
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_REGEX_MISMATCH

    def test_type_fail_no_regex_check(self, service):
        """When the type check fails, the regex is never consulted."""
        slot = self._make_code_slot("number", r"^\d+$")
        res = service.validate_slot_value(slot, "not_a_number")
        assert res.ok is False
        assert res.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID
class TestAskBackPrompt:
    """Propagation of the configured ask-back prompt."""

    def test_ask_back_prompt_on_validation_fail(self, service):
        """A failed rule surfaces the slot's ask_back_prompt."""
        slot = {
            "slot_key": "email",
            "type": "string",
            "required": True,
            "validation_rule": r"^[\w\.-]+@[\w\.-]+\.\w+$",
            "ask_back_prompt": "请输入有效的邮箱地址,如 example@domain.com",
        }
        res = service.validate_slot_value(slot, "invalid_email")
        assert res.ok is False
        assert res.ask_back_prompt == "请输入有效的邮箱地址,如 example@domain.com"

    def test_no_ask_back_prompt_on_success(self, service, string_slot_def):
        """A passing value carries no ask-back prompt."""
        res = service.validate_slot_value(string_slot_def, "valid")
        assert res.ok is True
        assert res.ask_back_prompt is None

    def test_ask_back_prompt_on_required_missing(self, service):
        """A missing required value also surfaces the prompt."""
        slot = {
            "slot_key": "name",
            "type": "string",
            "required": True,
            "ask_back_prompt": "请告诉我们您的姓名",
        }
        res = service.validate_slot_value(slot, "")
        assert res.ok is False
        assert res.ask_back_prompt == "请告诉我们您的姓名"
class TestSlotValidationErrorCode:
    """Each error-code member compares equal to its own name string."""

    def test_error_code_values(self):
        """String values of the error-code members are their names."""
        members = {
            "SLOT_REQUIRED_MISSING": SlotValidationErrorCode.SLOT_REQUIRED_MISSING,
            "SLOT_TYPE_INVALID": SlotValidationErrorCode.SLOT_TYPE_INVALID,
            "SLOT_REGEX_MISMATCH": SlotValidationErrorCode.SLOT_REGEX_MISMATCH,
            "SLOT_JSON_SCHEMA_MISMATCH": SlotValidationErrorCode.SLOT_JSON_SCHEMA_MISMATCH,
            "SLOT_VALIDATION_RULE_INVALID": SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID,
        }
        for name, member in members.items():
            assert member == name
class TestValidationResult:
    """Construction of ValidationResult for success and failure cases."""

    def test_success_result(self):
        """A success result keeps the normalized value and no error fields."""
        ok_res = ValidationResult(ok=True, normalized_value="test")
        assert ok_res.ok is True
        assert ok_res.normalized_value == "test"
        assert ok_res.error_code is None
        assert ok_res.error_message is None

    def test_failure_result(self):
        """A failure result carries code, message, and ask-back prompt."""
        fail_res = ValidationResult(
            ok=False,
            error_code="SLOT_REGEX_MISMATCH",
            error_message="格式不正确",
            ask_back_prompt="请重新输入",
        )
        assert fail_res.ok is False
        assert fail_res.error_code == "SLOT_REGEX_MISMATCH"
        assert fail_res.error_message == "格式不正确"
        assert fail_res.ask_back_prompt == "请重新输入"

View File

@ -0,0 +1,328 @@
"""
Test cases for Step-KB Binding feature.
[Step-KB-Binding] 步骤关联知识库功能的测试用例
测试覆盖
1. 步骤配置的增删改查与参数校验
2. 配置步骤KB范围后检索仅在范围内发生
3. 未配置时回退原逻辑
4. 多知识库同名内容场景下步骤约束生效
5. trace 字段完整性校验
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import dataclass
from typing import Any
class TestStepKbBindingModel:
    """Data-model tests for the step-level KB binding fields on FlowStep."""

    def test_flow_step_with_kb_binding_fields(self):
        """All four KB-binding fields are stored when provided."""
        from app.models.entities import FlowStep

        step = FlowStep(
            step_no=1,
            content="测试步骤",
            allowed_kb_ids=["kb-1", "kb-2"],
            preferred_kb_ids=["kb-1"],
            kb_query_hint="查找产品相关信息",
            max_kb_calls_per_step=2,
        )
        assert step.allowed_kb_ids == ["kb-1", "kb-2"]
        assert step.preferred_kb_ids == ["kb-1"]
        assert step.kb_query_hint == "查找产品相关信息"
        assert step.max_kb_calls_per_step == 2

    def test_flow_step_without_kb_binding(self):
        """KB-binding fields default to None when omitted."""
        from app.models.entities import FlowStep

        step = FlowStep(step_no=1, content="测试步骤")
        assert step.allowed_kb_ids is None
        assert step.preferred_kb_ids is None
        assert step.kb_query_hint is None
        assert step.max_kb_calls_per_step is None

    def test_max_kb_calls_validation(self):
        """max_kb_calls_per_step accepts 1-5 and rejects values outside the range."""
        from app.models.entities import FlowStep
        from pydantic import ValidationError

        # Valid range is 1-5.
        step = FlowStep(step_no=1, content="test", max_kb_calls_per_step=3)
        assert step.max_kb_calls_per_step == 3
        # Above the upper bound: expect the specific pydantic ValidationError
        # (a bare Exception would also mask unrelated failures).
        with pytest.raises(ValidationError):
            FlowStep(step_no=1, content="test", max_kb_calls_per_step=10)
        # Below the lower bound.
        with pytest.raises(ValidationError):
            FlowStep(step_no=1, content="test", max_kb_calls_per_step=0)
class TestStepKbConfig:
    """Construction of the StepKbConfig dataclass."""

    def test_step_kb_config_creation(self):
        """Explicit constructor arguments are stored verbatim."""
        from app.services.mid.kb_search_dynamic_tool import StepKbConfig

        cfg = StepKbConfig(
            allowed_kb_ids=["kb-1", "kb-2"],
            preferred_kb_ids=["kb-1"],
            kb_query_hint="查找产品信息",
            max_kb_calls=2,
            step_id="flow-1_step_1",
        )
        assert cfg.allowed_kb_ids == ["kb-1", "kb-2"]
        assert cfg.preferred_kb_ids == ["kb-1"]
        assert cfg.kb_query_hint == "查找产品信息"
        assert cfg.max_kb_calls == 2
        assert cfg.step_id == "flow-1_step_1"

    def test_step_kb_config_defaults(self):
        """Defaults: every field None except max_kb_calls, which is 1."""
        from app.services.mid.kb_search_dynamic_tool import StepKbConfig

        cfg = StepKbConfig()
        for attr in ("allowed_kb_ids", "preferred_kb_ids", "kb_query_hint", "step_id"):
            assert getattr(cfg, attr) is None
        assert cfg.max_kb_calls == 1
class TestKbSearchDynamicToolWithStepConfig:
    """Integration between KbSearchDynamicTool and per-step KB configuration."""

    def _build_tool(self):
        """Construct a tool with mocked session/governor and an enabled config."""
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicTool,
            KbSearchDynamicConfig,
        )

        return KbSearchDynamicTool(
            session=MagicMock(),
            timeout_governor=MagicMock(),
            config=KbSearchDynamicConfig(enabled=True),
        )

    @pytest.mark.asyncio
    async def test_kb_search_with_allowed_kb_ids(self):
        """With allowed_kb_ids set, retrieval is scoped and the binding echoed."""
        from app.services.mid.kb_search_dynamic_tool import StepKbConfig

        tool = self._build_tool()
        step_config = StepKbConfig(
            allowed_kb_ids=["kb-allowed-1", "kb-allowed-2"],
            step_id="test_step",
        )
        with patch.object(tool, '_do_retrieve', new_callable=AsyncMock) as mock_retrieve:
            mock_retrieve.return_value = [
                {"id": "1", "content": "test", "score": 0.8, "metadata": {"kb_id": "kb-allowed-1"}}
            ]
            result = await tool.execute(
                query="测试查询",
                tenant_id="tenant-1",
                step_kb_config=step_config,
            )
        # The retrieval call must receive the step config unchanged.
        assert mock_retrieve.call_args[1]['step_kb_config'] == step_config
        # The result must echo the binding back for tracing.
        assert result.step_kb_binding is not None
        assert result.step_kb_binding['allowed_kb_ids'] == ["kb-allowed-1", "kb-allowed-2"]

    @pytest.mark.asyncio
    async def test_kb_search_without_step_config(self):
        """Without a step config the tool falls back to the default path."""
        tool = self._build_tool()
        with patch.object(tool, '_do_retrieve', new_callable=AsyncMock) as mock_retrieve:
            mock_retrieve.return_value = [
                {"id": "1", "content": "test", "score": 0.8, "metadata": {}}
            ]
            result = await tool.execute(query="测试查询", tenant_id="tenant-1")
        # No step config was forwarded, and no binding is reported.
        assert mock_retrieve.call_args[1]['step_kb_config'] is None
        assert result.step_kb_binding is None

    @pytest.mark.asyncio
    async def test_kb_search_result_includes_used_kb_ids(self):
        """The binding trace lists every KB that actually produced a hit."""
        from app.services.mid.kb_search_dynamic_tool import StepKbConfig

        tool = self._build_tool()
        step_config = StepKbConfig(allowed_kb_ids=["kb-1", "kb-2"], step_id="test_step")
        hits = [
            {"id": "1", "content": "test1", "score": 0.9, "metadata": {"kb_id": "kb-1"}},
            {"id": "2", "content": "test2", "score": 0.8, "metadata": {"kb_id": "kb-1"}},
            {"id": "3", "content": "test3", "score": 0.7, "metadata": {"kb_id": "kb-2"}},
        ]
        with patch.object(tool, '_do_retrieve', new_callable=AsyncMock) as mock_retrieve:
            mock_retrieve.return_value = hits
            result = await tool.execute(
                query="测试查询",
                tenant_id="tenant-1",
                step_kb_config=step_config,
            )
        binding = result.step_kb_binding
        assert binding is not None
        assert set(binding['used_kb_ids']) == {"kb-1", "kb-2"}
        assert binding['kb_hit'] is True
class TestTraceInfoStepKbBinding:
    """Presence and shape of the TraceInfo.step_kb_binding field."""

    def test_trace_info_with_step_kb_binding(self):
        """A supplied binding dict is stored and readable field by field."""
        from app.models.mid.schemas import TraceInfo, ExecutionMode

        binding = {
            "step_id": "flow-1_step_2",
            "allowed_kb_ids": ["kb-1", "kb-2"],
            "used_kb_ids": ["kb-1"],
            "kb_hit": True,
        }
        trace = TraceInfo(mode=ExecutionMode.AGENT, step_kb_binding=binding)
        assert trace.step_kb_binding is not None
        assert trace.step_kb_binding['step_id'] == "flow-1_step_2"
        assert trace.step_kb_binding['allowed_kb_ids'] == ["kb-1", "kb-2"]
        assert trace.step_kb_binding['used_kb_ids'] == ["kb-1"]

    def test_trace_info_without_step_kb_binding(self):
        """The field defaults to None when not supplied."""
        from app.models.mid.schemas import TraceInfo, ExecutionMode

        assert TraceInfo(mode=ExecutionMode.AGENT).step_kb_binding is None
class TestFlowStepKbBindingIntegration:
    """KB binding configuration carried on ScriptFlow step payloads."""

    def test_script_flow_steps_with_kb_binding(self):
        """Steps may individually carry (or omit) KB binding keys."""
        from app.models.entities import ScriptFlowCreate

        bound_step = {
            "step_no": 1,
            "content": "步骤1",
            "allowed_kb_ids": ["kb-1"],
            "preferred_kb_ids": None,
            "kb_query_hint": "查找产品信息",
            "max_kb_calls_per_step": 2,
        }
        # The second step deliberately carries no KB binding keys.
        unbound_step = {"step_no": 2, "content": "步骤2"}
        flow_create = ScriptFlowCreate(name="测试流程", steps=[bound_step, unbound_step])
        assert flow_create.steps[0]['allowed_kb_ids'] == ["kb-1"]
        assert flow_create.steps[0]['kb_query_hint'] == "查找产品信息"
        assert flow_create.steps[1].get('allowed_kb_ids') is None
class TestKbBindingLogging:
    """Logging emitted when a step KB config is applied."""

    @pytest.mark.asyncio
    async def test_step_kb_config_logging(self, caplog):
        """Executing with a step config logs a Step-KB-Binding marker at INFO."""
        import logging
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicTool,
            KbSearchDynamicConfig,
            StepKbConfig,
        )

        tool = KbSearchDynamicTool(
            session=MagicMock(),
            timeout_governor=MagicMock(),
            config=KbSearchDynamicConfig(enabled=True),
        )
        step_config = StepKbConfig(allowed_kb_ids=["kb-1"], step_id="flow-1_step_1")
        with patch.object(tool, '_do_retrieve', new_callable=AsyncMock) as mock_retrieve:
            mock_retrieve.return_value = []
            with caplog.at_level(logging.INFO):
                await tool.execute(
                    query="测试",
                    tenant_id="tenant-1",
                    step_kb_config=step_config,
                )
        marker_logged = any(
            "Step-KB-Binding" in record.message for record in caplog.records
        )
        assert marker_logged
# Script entry point: run this module's tests directly.
# Fix: propagate pytest's exit status — the original discarded the return
# value of pytest.main(), so the process exited 0 even when tests failed.
if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__, "-v"]))