"""
Database Migration: Scene Slot Bundle tables.
[AC-SCENE-SLOT-01] Scene-to-slot-bundle mapping configuration table.

Created: 2025-03-07
Changes:
- Adds the ``scene_slot_bundles`` table, which stores per-scene slot bundle
  configuration (required/optional slots, ask-back ordering, thresholds).

Execution:
- SQLModel creates the table automatically (via ``init_db``).
- This script exists for manual migration or rollback only.

Rollback SQL:
    DROP TABLE IF EXISTS scene_slot_bundles;
"""

# Forward DDL.
# NOTE(review): assumes PostgreSQL (UUID / JSON / NOW()) — confirm target DB.
SCENE_SLOT_BUNDLES_DDL = """
CREATE TABLE IF NOT EXISTS scene_slot_bundles (
    id UUID PRIMARY KEY,
    tenant_id VARCHAR NOT NULL,
    scene_key VARCHAR(100) NOT NULL,
    scene_name VARCHAR(100) NOT NULL,
    description TEXT,
    required_slots JSON NOT NULL DEFAULT '[]',
    optional_slots JSON NOT NULL DEFAULT '[]',
    slot_priority JSON,
    completion_threshold FLOAT NOT NULL DEFAULT 1.0,
    ask_back_order VARCHAR NOT NULL DEFAULT 'priority',
    status VARCHAR NOT NULL DEFAULT 'draft',
    version INTEGER NOT NULL DEFAULT 1,
    created_at TIMESTAMP NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMP NOT NULL DEFAULT NOW()
);
"""

# Secondary indexes: tenant lookup, unique (tenant, scene) pair, status filter.
SCENE_SLOT_BUNDLES_INDEXES = """
CREATE INDEX IF NOT EXISTS ix_scene_slot_bundles_tenant ON scene_slot_bundles(tenant_id);
CREATE UNIQUE INDEX IF NOT EXISTS ix_scene_slot_bundles_tenant_scene ON scene_slot_bundles(tenant_id, scene_key);
CREATE INDEX IF NOT EXISTS ix_scene_slot_bundles_tenant_status ON scene_slot_bundles(tenant_id, status);
"""

# Rollback DDL: dropping the table also drops its indexes.
SCENE_SLOT_BUNDLES_ROLLBACK = """
DROP TABLE IF EXISTS scene_slot_bundles;
"""


async def upgrade(conn):
    """Apply the migration: create the table, then its indexes."""
    await conn.execute(SCENE_SLOT_BUNDLES_DDL)
    await conn.execute(SCENE_SLOT_BUNDLES_INDEXES)


async def downgrade(conn):
    """Roll back the migration: drop the table (indexes go with it)."""
    await conn.execute(SCENE_SLOT_BUNDLES_ROLLBACK)
NOW() +); +""" + +SCENE_SLOT_BUNDLES_INDEXES = """ +CREATE INDEX IF NOT EXISTS ix_scene_slot_bundles_tenant ON scene_slot_bundles(tenant_id); +CREATE UNIQUE INDEX IF NOT EXISTS ix_scene_slot_bundles_tenant_scene ON scene_slot_bundles(tenant_id, scene_key); +CREATE INDEX IF NOT EXISTS ix_scene_slot_bundles_tenant_status ON scene_slot_bundles(tenant_id, status); +""" + +SCENE_SLOT_BUNDLES_ROLLBACK = """ +DROP TABLE IF EXISTS scene_slot_bundles; +""" + + +async def upgrade(conn): + """执行迁移""" + await conn.execute(SCENE_SLOT_BUNDLES_DDL) + await conn.execute(SCENE_SLOT_BUNDLES_INDEXES) + + +async def downgrade(conn): + """回滚迁移""" + await conn.execute(SCENE_SLOT_BUNDLES_ROLLBACK) diff --git a/ai-service/migrations/002_slot_definitions_add_display_fields.py b/ai-service/migrations/002_slot_definitions_add_display_fields.py new file mode 100644 index 0000000..20634df --- /dev/null +++ b/ai-service/migrations/002_slot_definitions_add_display_fields.py @@ -0,0 +1,49 @@ +""" +Database Migration: Add display_name and description to slot_definitions. 
"""
Database Migration: add ``display_name`` and ``description`` to slot_definitions.

Created: 2025-03-08
(BUGFIX: was mistyped as 2026-03-08; sibling migration 001 is dated 2025-03-07.)

Changes:
- New ``display_name`` column: human-readable slot name for ops/curriculum staff.
- New ``description`` column: explains what the slot collects and where it is used.

Execution:
- SQLModel handles the new fields automatically (via ``init_db``).
- This script is for manually migrating an existing database.

Rollback SQL:
    ALTER TABLE slot_definitions
    DROP COLUMN IF EXISTS display_name,
    DROP COLUMN IF EXISTS description;
"""

# Forward DDL: add the two nullable display columns (idempotent via IF NOT EXISTS).
ALTER_SLOT_DEFINITIONS_DDL = """
ALTER TABLE slot_definitions
ADD COLUMN IF NOT EXISTS display_name VARCHAR(100),
ADD COLUMN IF NOT EXISTS description VARCHAR(500);
"""

# Rollback DDL: drop both columns (idempotent via IF EXISTS).
ALTER_SLOT_DEFINITIONS_ROLLBACK = """
ALTER TABLE slot_definitions
DROP COLUMN IF EXISTS display_name,
DROP COLUMN IF EXISTS description;
"""


async def upgrade(conn):
    """Apply the migration: add display_name/description columns."""
    await conn.execute(ALTER_SLOT_DEFINITIONS_DDL)


async def downgrade(conn):
    """Roll back the migration: drop display_name/description columns."""
    await conn.execute(ALTER_SLOT_DEFINITIONS_ROLLBACK)
"""
检查 Qdrant 中课程知识库的数据结构
(Inspect the payload layout of the course-KB collection in Qdrant.)
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from qdrant_client import AsyncQdrantClient  # noqa: F401  (client comes from app wrapper)

from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient


async def check_course_kb():
    """检查课程知识库

    Builds the expected collection name for the course KB and, if the
    collection exists, prints the payload structure of a few points.
    """
    settings = get_settings()

    client = QdrantClient()
    qdrant = await client.get_client()

    tenant_id = "szmp@ash@2026"
    course_kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

    # Qdrant collection names cannot contain '@'; the app stores them with '_'.
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = settings.qdrant_collection_prefix

    expected_collection = f"{prefix}{safe_tenant_id}_{course_kb_id}"

    print(f"\n{'='*80}")
    print(f"检查课程知识库 Collection")
    print(f"{'='*80}")
    print(f"租户 ID: {tenant_id}")
    print(f"课程知识库 ID: {course_kb_id}")
    print(f"预期 Collection 名称: {expected_collection}")

    collections = await qdrant.get_collections()
    collection_names = [c.name for c in collections.collections]

    print(f"\n租户的所有 Collections:")
    for name in collection_names:
        if safe_tenant_id in name:
            print(f"  - {name}")

    if expected_collection in collection_names:
        print(f"\n✅ 课程知识库 Collection 存在: {expected_collection}")

        # BUGFIX: scroll() is a coroutine on the async client and must be
        # awaited; the original unpacked the coroutine object directly,
        # which raises TypeError at runtime.
        points, _ = await qdrant.scroll(
            collection_name=expected_collection,
            limit=3,
            with_vectors=False,
        )

        print(f"\n课程知识库数据 (共 {len(points)} 条):")
        for i, point in enumerate(points, 1):
            # BUGFIX: scroll yields Record objects (attribute access), not
            # dicts — `point.get(...)` raised AttributeError.
            payload = point.payload or {}
            print(f"\n  [{i}] id: {point.id}")
            print(f"      payload keys: {list(payload.keys())}")
            if 'metadata' in payload:
                print(f"      metadata: {payload['metadata']}")
    else:
        print(f"\n❌ 课程知识库 Collection 不存在!")
        print(f"   可用的 Collections: {collection_names}")


if __name__ == "__main__":
    asyncio.run(check_course_kb())
"""
检查课程知识库的 collection 是否存在
(Check whether the course KB's Qdrant collection exists, and peek at its data.)
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from qdrant_client import AsyncQdrantClient  # noqa: F401  (client comes from app wrapper)

from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient


async def check_course_kb_collection():
    """检查课程知识库的 collection 是否存在"""
    settings = get_settings()
    client = QdrantClient()
    qdrant = await client.get_client()

    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

    collection_name = client.get_kb_collection_name(tenant_id, kb_id)
    print(f"课程知识库 collection name: {collection_name}")

    exists = await qdrant.collection_exists(collection_name)
    print(f"Collection exists: {exists}")

    if exists:
        # BUGFIX: scroll() returns a (points, next_offset) tuple; the original
        # bound the whole tuple, so len() was always 2 and the loop iterated
        # over [point_list, offset] instead of the points.
        points, _ = await qdrant.scroll(
            collection_name=collection_name,
            limit=5,
            with_vectors=False,
        )
        print(f"\n课程知识库中有 {len(points)} 条数据:")
        for i, point in enumerate(points, 1):
            # BUGFIX: scroll yields Record objects — use attribute access,
            # not dict .get().
            payload = point.payload or {}
            print(f"  [{i}] payload keys: {list(payload.keys())}")
            for key, value in payload.items():
                if key != 'text' and key != 'vector':
                    print(f"      {key}: {value}")
    else:
        print(f"课程知识库 collection 不存在!")


if __name__ == "__main__":
    asyncio.run(check_course_kb_collection())
asyncio.run(check_course_kb_collection()) diff --git a/ai-service/scripts/check_course_kb_status.py b/ai-service/scripts/check_course_kb_status.py new file mode 100644 index 0000000..bc32bac --- /dev/null +++ b/ai-service/scripts/check_course_kb_status.py @@ -0,0 +1,98 @@ +""" +检查课程知识库的录入情况 +""" + +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker +from app.core.config import get_settings +from app.core.qdrant_client import QdrantClient +from app.models.entities import Document + + +async def check_course_kb_status(): + """检查课程知识库的录入情况""" + settings = get_settings() + + engine = create_async_engine(settings.database_url) + async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + tenant_id = "szmp@ash@2026" + kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b" + + print(f"\n{'='*80}") + print(f"检查课程知识库的录入情况") + print(f"{'='*80}") + print(f"租户 ID: {tenant_id}") + print(f"知识库 ID: {kb_id}") + + async with async_session() as session: + stmt = select(Document).where( + Document.tenant_id == tenant_id, + Document.kb_id == kb_id, + ) + result = await session.execute(stmt) + documents = result.scalars().all() + + print(f"\n数据库中的文档记录: {len(documents)} 个") + if documents: + for doc in documents[:5]: + print(f" - {doc.file_name} (status: {doc.status})") + if len(documents) > 5: + print(f" ... 
还有 {len(documents) - 5} 个文档") + + client = QdrantClient() + qdrant = await client.get_client() + + collection_name = client.get_kb_collection_name(tenant_id, kb_id) + print(f"\nQdrant Collection 名称: {collection_name}") + + exists = await qdrant.collection_exists(collection_name) + if exists: + points_result = await qdrant.scroll( + collection_name=collection_name, + limit=5, + with_vectors=False, + ) + points = points_result[0] if isinstance(points_result, tuple) else points_result + print(f"Qdrant Collection 存在,有 {len(points)} 条数据") + for i, point in enumerate(points, 1): + if hasattr(point, 'payload'): + payload = point.payload + point_id = point.id + else: + payload = point.get('payload', {}) + point_id = point.get('id', 'unknown') + print(f" [{i}] id: {point_id}") + if 'text' in payload: + text = payload['text'][:50] + '...' if len(payload['text']) > 50 else payload['text'] + print(f" text: {text}") + else: + print(f"Qdrant Collection 不存在!") + + print(f"\n{'='*80}") + print(f"结论:") + if len(documents) > 0 and not exists: + print(" 数据库有文档记录,但 Qdrant Collection 不存在") + print(" 需要等待文档向量化任务完成") + elif len(documents) == 0 and exists: + print(" 数据库没有文档记录,但 Qdrant Collection 存在") + print(" 可能是旧数据") + elif len(documents) > 0 and exists: + print(f" 数据库有 {len(documents)} 个文档记录") + print(f" Qdrant Collection 存在") + print(" ✅ 知识库已录入完成") + else: + print(" 数据库没有文档记录") + print(" Qdrant Collection 不存在") + print(" ❌ 知识库未录入") + + +if __name__ == "__main__": + asyncio.run(check_course_kb_status()) diff --git a/ai-service/scripts/check_grade_data.py b/ai-service/scripts/check_grade_data.py new file mode 100644 index 0000000..ea6f9dc --- /dev/null +++ b/ai-service/scripts/check_grade_data.py @@ -0,0 +1,78 @@ +""" +检查 Qdrant 中是否有 grade=五年级 的数据 +""" + +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from qdrant_client.models import FieldCondition, Filter, MatchValue + +from app.core.config import get_settings +from 
app.core.qdrant_client import QdrantClient + + +async def check_grade_data(): + """检查 Qdrant 中是否有 grade=五年级 的数据""" + settings = get_settings() + client = QdrantClient() + qdrant = await client.get_client() + + tenant_id = "szmp@ash@2026" + kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b" + + collection_name = client.get_kb_collection_name(tenant_id, kb_id) + + print(f"\n{'='*80}") + print(f"检查 Qdrant 中 grade 字段的分布") + print(f"{'='*80}") + print(f"Collection: {collection_name}") + + # 获取所有数据 + all_points = await qdrant.scroll( + collection_name=collection_name, + limit=100, + with_vectors=False, + ) + + print(f"\n总数据量: {len(all_points[0])} 条") + + # 统计 grade 分布 + grade_count = {} + for point in all_points[0]: + metadata = point.payload.get('metadata', {}) + grade = metadata.get('grade', '无') + grade_count[grade] = grade_count.get(grade, 0) + 1 + + print(f"\ngrade 字段分布:") + for grade, count in sorted(grade_count.items()): + print(f" {grade}: {count} 条") + + # 检查是否有 五年级 的数据 + print(f"\n--- 检查 grade=五年级 的数据 ---") + qdrant_filter = Filter( + must=[ + FieldCondition( + key="metadata.grade", + match=MatchValue(value="五年级"), + ) + ] + ) + + results = await qdrant.scroll( + collection_name=collection_name, + limit=10, + with_vectors=False, + scroll_filter=qdrant_filter, + ) + + print(f"grade=五年级 的数据: {len(results[0])} 条") + for p in results[0]: + print(f" text: {p.payload.get('text', '')[:80]}...") + print(f" metadata: {p.payload.get('metadata', {})}") + + +if __name__ == "__main__": + asyncio.run(check_grade_data()) diff --git a/ai-service/scripts/check_kb_content.py b/ai-service/scripts/check_kb_content.py new file mode 100644 index 0000000..747923e --- /dev/null +++ b/ai-service/scripts/check_kb_content.py @@ -0,0 +1,88 @@ +""" +查看指定知识库的内容 +""" + +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from qdrant_client import AsyncQdrantClient +from app.core.config import get_settings + + +async def 
"""
查看指定知识库的内容
(Dump every point of one knowledge-base collection: id, filename, text, metadata.)
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from qdrant_client import AsyncQdrantClient
from app.core.config import get_settings


async def check_kb_content():
    """查看知识库内容 — scroll through all points and print a summary of each."""
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)

    kb_id = "8559ebc9-bfaf-4211-8fe3-ee2b22a5e29c"
    # Naming convention: kb_<tenant with '@'->'_'>_<first 8 chars of kb_id>.
    collection_name = "kb_szmp_ash_2026_8559ebc9"

    print("=" * 80)
    print(f"查看知识库: {kb_id}")
    print(f"Collection: {collection_name}")
    print("=" * 80)

    try:
        exists = await client.collection_exists(collection_name)
        print(f"\nCollection 存在: {exists}")

        if not exists:
            print("Collection 不存在!")
            return

        info = await client.get_collection(collection_name)
        print(f"\nCollection 信息:")
        print(f"  向量数: {info.points_count}")

        print(f"\n文档内容:")
        print("-" * 80)

        offset = None
        total = 0
        while True:
            result = await client.scroll(
                collection_name=collection_name,
                limit=10,
                offset=offset,
                with_payload=True,
            )

            points = result[0]
            if not points:
                break

            for point in points:
                total += 1
                payload = point.payload or {}
                text = payload.get('text', 'N/A')[:100]
                metadata = payload.get('metadata', {})
                filename = payload.get('filename', 'N/A')

                print(f"\n  [{total}] ID: {point.id}")
                # BUGFIX: `filename` was computed but never used — the print
                # emitted a hard-coded placeholder instead of the value.
                print(f"      Filename: {filename}")
                print(f"      Text: {text}...")
                print(f"      Metadata: {metadata}")

            offset = result[1]
            if offset is None:
                break

        print(f"\n总计 {total} 条记录")

    except Exception as e:
        print(f"\n错误: {e}")
        import traceback
        traceback.print_exc()
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(check_kb_content())
"""
检查租户的所有知识库
(List every knowledge base registered for the tenant in the relational DB.)
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.core.config import get_settings
from app.models.entities import KnowledgeBase


async def check_knowledge_bases():
    """检查租户的所有知识库 — print each KB's key attributes."""
    config = get_settings()

    engine = create_async_engine(config.database_url)
    session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async with session_factory() as db:
        tenant_id = "szmp@ash@2026"

        print(f"\n{'='*80}")
        print(f"检查租户 {tenant_id} 的所有知识库")
        print(f"{'='*80}")

        query = select(KnowledgeBase).where(
            KnowledgeBase.tenant_id == tenant_id,
        )
        rows = await db.execute(query)
        knowledge_bases = rows.scalars().all()

        print(f"\n找到 {len(knowledge_bases)} 个知识库:")

        for kb in knowledge_bases:
            print(f"\n  知识库: {kb.name}")
            # Print each attribute on its own indented line.
            details = [
                ("id", kb.id),
                ("kb_type", kb.kb_type),
                ("description", kb.description),
                ("is_enabled", kb.is_enabled),
                ("doc_count", kb.doc_count),
                ("created_at", kb.created_at),
            ]
            for label, value in details:
                print(f"    {label}: {value}")


if __name__ == "__main__":
    asyncio.run(check_knowledge_bases())
session: + tenant_id = "szmp@ash@2026" + + print(f"\n{'='*80}") + print(f"检查租户 {tenant_id} 的元数据字段定义") + print(f"{'='*80}") + + stmt = select(MetadataFieldDefinition).where( + MetadataFieldDefinition.tenant_id == tenant_id, + MetadataFieldDefinition.status == MetadataFieldStatus.ACTIVE, + ) + result = await session.execute(stmt) + fields = result.scalars().all() + + print(f"\n找到 {len(fields)} 个活跃字段定义:") + + for f in fields: + print(f"\n 字段: {f.field_key}") + print(f" label: {f.label}") + print(f" type: {f.type}") + print(f" required: {f.required}") + print(f" field_roles: {f.field_roles}") + print(f" options: {f.options}") + print(f" default_value: {f.default_value}") + + filterable_fields = [ + f for f in fields + if f.field_roles and FieldRole.RESOURCE_FILTER.value in f.field_roles + ] + print(f"\n{'='*80}") + print(f"可过滤字段 (field_roles 包含 resource_filter): {len(filterable_fields)} 个") + for f in filterable_fields: + print(f" - {f.field_key} (label: {f.label}, required: {f.required})") + + +if __name__ == "__main__": + asyncio.run(check_metadata_fields()) diff --git a/ai-service/scripts/check_metadata_structure.py b/ai-service/scripts/check_metadata_structure.py new file mode 100644 index 0000000..50725e5 --- /dev/null +++ b/ai-service/scripts/check_metadata_structure.py @@ -0,0 +1,68 @@ +""" +检查 Qdrant 中数据的 metadata 存储结构 +""" + +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from app.core.config import get_settings +from app.core.qdrant_client import QdrantClient + + +async def check_metadata_structure(): + """检查 Qdrant 中数据的 metadata 存储结构""" + settings = get_settings() + client = QdrantClient() + qdrant = await client.get_client() + + tenant_id = "szmp@ash@2026" + kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b" + + collection_name = client.get_kb_collection_name(tenant_id, kb_id) + + print(f"\n{'='*80}") + print(f"检查 Qdrant 数据结构") + print(f"{'='*80}") + print(f"Collection: {collection_name}") + + 
points = await qdrant.scroll( + collection_name=collection_name, + limit=3, + with_vectors=False, + ) + + print(f"\n找到 {len(points[0])} 条数据:") + + for i, point in enumerate(points[0], 1): + print(f"\n--- Point {i} ---") + if hasattr(point, 'payload'): + payload = point.payload + point_id = point.id + else: + payload = point.get('payload', {}) + point_id = point.get('id', 'unknown') + + print(f"ID: {point_id}") + print(f"Payload keys: {list(payload.keys())}") + + # 打印完整的 payload 结构 + for key, value in payload.items(): + if key == 'text': + print(f" {key}: {value[:50]}..." if len(str(value)) > 50 else f" {key}: {value}") + elif key == 'vector': + print(f" {key}: [向量数据]") + else: + print(f" {key}: {value}") + + # 检查 metadata 字段 + if 'metadata' in payload: + print(f"\n metadata 字段内容:") + for mk, mv in payload['metadata'].items(): + print(f" {mk}: {mv}") + + +if __name__ == "__main__": + asyncio.run(check_metadata_structure()) diff --git a/ai-service/scripts/check_qdrant.py b/ai-service/scripts/check_qdrant.py index 1612bf3..2ef35bf 100644 --- a/ai-service/scripts/check_qdrant.py +++ b/ai-service/scripts/check_qdrant.py @@ -1,79 +1,112 @@ """ -Check Qdrant vector database contents - detailed view. 
+检查 Qdrant 向量数据库状态和知识库内容 """ + import asyncio import sys -sys.path.insert(0, ".") +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) -from qdrant_client import AsyncQdrantClient from app.core.config import get_settings -from collections import defaultdict - -settings = get_settings() +from app.core.qdrant_client import get_qdrant_client async def check_qdrant(): - """Check Qdrant collections and vectors.""" - client = AsyncQdrantClient(url=settings.qdrant_url, check_compatibility=False) + """检查 Qdrant 状态""" + settings = get_settings() + tenant_id = "szmp@ash@2026" - print(f"\n{'='*60}") + print(f"Database URL: {settings.database_url}") print(f"Qdrant URL: {settings.qdrant_url}") - print(f"{'='*60}\n") + print(f"Tenant ID: {tenant_id}") + print() - # List all collections - collections = await client.get_collections() - - # Check kb_default collection - for c in collections.collections: - if c.name == "kb_default": - print(f"\n--- Collection: {c.name} ---") + try: + qdrant_manager = await get_qdrant_client() + client = await qdrant_manager.get_client() + + # 检查集合是否存在 + collections = (await client.get_collections()).collections + collection_names = [c.name for c in collections] + print(f"Available collections: {collection_names}") + print() + + # 筛选该租户的 collections + tenant_collections = [name for name in collection_names if "szmp_ash_2026" in name] + print(f"Tenant collections: {tenant_collections}") + print() + + # 检查每个集合 + for collection_name in tenant_collections: + print(f"\n{'='*60}") + print(f"Collection: {collection_name}") + print(f"{'='*60}") - # Get collection info - info = await client.get_collection(c.name) - print(f" Total vectors: {info.points_count}") + # 获取集合信息 + collection_info = await client.get_collection(collection_name) + print(f" Points count: {collection_info.points_count}") + print(f" Vectors count: {collection_info.vectors_count}") + print(f" Status: {collection_info.status}") - # Scroll through all points and 
group by source - all_points = [] - offset = None + if collection_info.points_count == 0: + print(" ⚠️ Collection is empty!") + continue - while True: - points, offset = await client.scroll( - collection_name=c.name, - limit=100, - offset=offset, + # 滚动获取一些数据 + print(f"\n 前 3 条数据:") + points, next_page = await client.scroll( + collection_name=collection_name, + limit=3, + with_payload=True, + with_vectors=False, + ) + + for i, point in enumerate(points, 1): + payload = point.payload or {} + text = payload.get("text", "")[:100] + "..." if payload.get("text") else "N/A" + kb_id = payload.get("kb_id", "N/A") + metadata = payload.get("metadata", {}) + print(f"\n Point {i}:") + print(f" ID: {point.id}") + print(f" KB ID: {kb_id}") + print(f" Text: {text}") + print(f" Metadata: {metadata}") + + # 尝试向量搜索 + print(f"\n\n{'='*60}") + print(f"尝试向量搜索 (query='课程'):") + print(f"{'='*60}") + + from app.services.embedding.factory import get_embedding_provider + + embedding_provider = await get_embedding_provider() + query_vector = await embedding_provider.embed("课程") + print(f"Query vector dimension: {len(query_vector)}") + + for collection_name in tenant_collections: + print(f"\n搜索 collection: {collection_name}") + try: + search_results = await client.query_points( + collection_name=collection_name, + query=query_vector, + using="full", # 使用 full 向量 + limit=3, with_payload=True, - with_vectors=False, ) - all_points.extend(points) - if offset is None: - break - - # Group by source - by_source = defaultdict(list) - for p in all_points: - source = p.payload.get("source", "unknown") if p.payload else "unknown" - by_source[source].append(p) - - print(f"\n Documents by source:") - for source, points in by_source.items(): - print(f"\n Source: {source}") - print(f" Chunks: {len(points)}") - # Check first chunk content - first_point = points[0] - text = first_point.payload.get("text", "") if first_point.payload else "" - - # Check if it's binary garbage or proper text - is_garbage = 
"""
检查 Qdrant 中实际存在的 collections
(List the collections that actually exist in Qdrant and the tenant's kb_ids.)
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from qdrant_client import AsyncQdrantClient

from app.core.config import get_settings


def _extract_kb_id(coll_name, prefix):
    """Return the kb_id segment of a collection name, or None.

    Collection names look like ``kb_<safe_tenant>_<kb_id>``. The kb_id is
    whatever follows the tenant prefix; the prefix itself contains
    underscores, so a plain ``split('_')[2]`` returns a tenant fragment,
    not the kb_id.
    """
    if not coll_name.startswith(prefix):
        return None
    suffix = coll_name[len(prefix):].lstrip('_')
    return suffix or None


async def check_qdrant_collections():
    """检查 Qdrant 中实际存在的 collections"""
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)

    try:
        collections = await client.get_collections()
        print(f"\n{'='*80}")
        print(f"Qdrant 中所有 collections:")
        print(f"{'='*80}")

        for coll in collections.collections:
            print(f"  - {coll.name}")

        tenant_id = "szmp@ash@2026"
        safe_tenant_id = tenant_id.replace('@', '_')
        prefix = f"kb_{safe_tenant_id}"

        tenant_collections = [
            coll.name for coll in collections.collections
            if coll.name.startswith(prefix)
        ]
        print(f"\n租户 {tenant_id} 的 collections:")
        print(f"{'='*80}")
        for coll_name in tenant_collections:
            print(f"  - {coll_name}")
            # BUGFIX: the kb_id extraction ran after the loop (so only the
            # last collection was inspected) and used split('_')[2], which
            # yielded 'ash' because the prefix contains underscores.
            kb_id = _extract_kb_id(coll_name, prefix)
            if kb_id:
                print(f"    kb_id: {kb_id}")

    except Exception as e:
        print(f"错误: {e}")
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(check_qdrant_collections())
b/ai-service/scripts/cleanup_collections.py @@ -0,0 +1,120 @@ +""" +清理 szmp@ash@2026 租户下不需要的 Qdrant collections +保留:8559ebc9-bfaf-4211-8fe3-ee2b22a5e29c, 30c19c84-8f69-4768-9d23-7f4a5bc3627a +删除:其他所有 collections +""" + +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from qdrant_client import AsyncQdrantClient +from app.core.config import get_settings + + +async def cleanup_collections(): + """清理 collections""" + settings = get_settings() + client = AsyncQdrantClient(url=settings.qdrant_url) + + tenant_id = "szmp@ash@2026" + safe_tenant_id = tenant_id.replace('@', '_') + prefix = f"kb_{safe_tenant_id}" + + # 保留的 kb_id 前缀(前8位) + keep_kb_ids = [ + "8559ebc9", + "30c19c84", + ] + + print(f"🔍 扫描租户 {tenant_id} 的 collections...") + print(f" 前缀: {prefix}") + print(f" 保留: {keep_kb_ids}") + print("-" * 80) + + try: + collections = await client.get_collections() + + # 找出该租户的所有 collections + tenant_collections = [ + c.name for c in collections.collections + if c.name.startswith(prefix) + ] + + print(f"\n📊 找到 {len(tenant_collections)} 个 collections:") + for name in sorted(tenant_collections): + # 检查是否需要保留 + should_keep = any(kb_id in name for kb_id in keep_kb_ids) + status = "✅ 保留" if should_keep else "❌ 删除" + print(f" {status} {name}") + + print("\n" + "=" * 80) + print("开始删除...") + print("=" * 80) + + deleted = [] + skipped = [] + + for collection_name in tenant_collections: + # 检查是否需要保留 + should_keep = any(kb_id in collection_name for kb_id in keep_kb_ids) + + if should_keep: + print(f"\n⏭️ 跳过 {collection_name} (保留)") + skipped.append(collection_name) + continue + + print(f"\n🗑️ 删除 {collection_name}...") + try: + await client.delete_collection(collection_name) + print(f" ✅ 已删除") + deleted.append(collection_name) + except Exception as e: + print(f" ❌ 删除失败: {e}") + + print("\n" + "=" * 80) + print("清理完成!") + print("=" * 80) + print(f"\n📈 统计:") + print(f" 保留: {len(skipped)} 个") + for name in skipped: + print(f" - 
{name}") + print(f"\n 删除: {len(deleted)} 个") + for name in deleted: + print(f" - {name}") + + except Exception as e: + print(f"\n❌ 错误: {e}") + import traceback + traceback.print_exc() + finally: + await client.close() + + +if __name__ == "__main__": + # 安全确认 + print("=" * 80) + print("⚠️ 警告: 此操作将永久删除以下 collections:") + print(" - kb_szmp_ash_2026") + print(" - kb_szmp_ash_2026_fa4c1d61") + print(" - kb_szmp_ash_2026_3ddf0ce7") + print("\n 保留:") + print(" - kb_szmp_ash_2026_8559ebc9") + print(" - kb_szmp_ash_2026_30c19c84") + print("=" * 80) + print("\n确认删除? (yes/no): ", end="") + + # 在非交互环境自动确认 + import os + if os.environ.get('AUTO_CONFIRM') == 'true': + response = 'yes' + print('yes (auto)') + else: + response = input().strip().lower() + + if response in ('yes', 'y'): + asyncio.run(cleanup_collections()) + else: + print("\n❌ 已取消") diff --git a/ai-service/scripts/delete_course_kb_collection.py b/ai-service/scripts/delete_course_kb_collection.py new file mode 100644 index 0000000..1757080 --- /dev/null +++ b/ai-service/scripts/delete_course_kb_collection.py @@ -0,0 +1,42 @@ +""" +删除课程知识库的 Qdrant Collection +""" + +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from app.core.config import get_settings +from app.core.qdrant_client import QdrantClient + + +async def delete_course_kb_collection(): + """删除课程知识库的 Qdrant Collection""" + settings = get_settings() + client = QdrantClient() + qdrant = await client.get_client() + + tenant_id = "szmp@ash@2026" + kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b" + + collection_name = client.get_kb_collection_name(tenant_id, kb_id) + + print(f"\n{'='*80}") + print(f"删除课程知识库的 Qdrant Collection") + print(f"{'='*80}") + print(f"租户 ID: {tenant_id}") + print(f"知识库 ID: {kb_id}") + print(f"Collection 名称: {collection_name}") + + exists = await qdrant.collection_exists(collection_name) + if exists: + await qdrant.delete_collection(collection_name) + print(f"\n✅ Collection 
"""Delete the Qdrant collection backing the course knowledge base."""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient


async def delete_course_kb_collection():
    """Drop the collection for one hard-coded tenant/KB pair, if it exists."""
    settings = get_settings()
    client = QdrantClient()
    qdrant = await client.get_client()

    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

    collection_name = client.get_kb_collection_name(tenant_id, kb_id)

    print(f"\n{'='*80}")
    print(f"删除课程知识库的 Qdrant Collection")
    print(f"{'='*80}")
    print(f"租户 ID: {tenant_id}")
    print(f"知识库 ID: {kb_id}")
    print(f"Collection 名称: {collection_name}")

    if await qdrant.collection_exists(collection_name):
        await qdrant.delete_collection(collection_name)
        print(f"\n✅ Collection {collection_name} 已删除!")
    else:
        print(f"\n⚠️ Collection {collection_name} 不存在,无需删除")


if __name__ == "__main__":
    asyncio.run(delete_course_kb_collection())


# --- file: delete_course_kb_documents.py -----------------------------------

"""Delete the document records of the course knowledge base."""


async def delete_course_kb_documents():
    """Remove every Document row (and its IndexJob rows) for one hard-coded KB."""
    from sqlalchemy import delete, select
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from sqlalchemy.orm import sessionmaker
    from app.models.entities import Document, IndexJob

    settings = get_settings()

    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    tenant_id = "szmp@ash@2026"
    kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

    print(f"\n{'='*80}")
    print(f"删除课程知识库的文档记录")
    print(f"{'='*80}")
    print(f"租户 ID: {tenant_id}")
    print(f"知识库 ID: {kb_id}")

    async with async_session() as session:
        documents = (
            await session.execute(
                select(Document).where(
                    Document.tenant_id == tenant_id,
                    Document.kb_id == kb_id,
                )
            )
        ).scalars().all()

        print(f"\n找到 {len(documents)} 个文档记录")

        if not documents:
            print("没有需要删除的文档记录")
            return

        # Preview so the operator can sanity-check the target rows.
        for doc in documents[:5]:
            print(f"  - {doc.file_name} (id: {doc.id})")
        if len(documents) > 5:
            print(f"  ... 还有 {len(documents) - 5} 个文档")

        doc_ids = [doc.id for doc in documents]

        # Child rows first: index jobs reference documents via doc_id.
        index_job_result = await session.execute(
            delete(IndexJob).where(
                IndexJob.tenant_id == tenant_id,
                IndexJob.doc_id.in_(doc_ids),
            )
        )
        print(f"\n删除了 {index_job_result.rowcount} 个索引任务记录")

        doc_result = await session.execute(
            delete(Document).where(
                Document.tenant_id == tenant_id,
                Document.kb_id == kb_id,
            )
        )
        print(f"删除了 {doc_result.rowcount} 个文档记录")

        await session.commit()

        print(f"\n✅ 删除完成!")
"""Print the API keys stored in the database."""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.models.entities import ApiKey
from app.core.config import get_settings


async def get_api_keys():
    """Dump every active ApiKey row to stdout."""
    settings = get_settings()
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async with async_session() as session:
        # `== True` is a SQLAlchemy column expression, not a Python comparison.
        rows = await session.execute(
            select(ApiKey).where(ApiKey.is_active == True)
        )
        keys = rows.scalars().all()

        print("=" * 80)
        print("数据库中的 API Keys:")
        print("=" * 80)
        for key in keys:
            print(f"\nKey: {key.key}")
            print(f"  Name: {key.name}")
            print(f"  Tenant: {key.tenant_id}")
            print(f"  Active: {key.is_active}")
        print("\n" + "=" * 80)


if __name__ == "__main__":
    asyncio.run(get_api_keys())
"""Print the intent rules stored in the database."""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.models.entities import IntentRule
from app.core.config import get_settings


async def get_intent_rules():
    """Dump every enabled IntentRule of the hard-coded tenant to stdout."""
    settings = get_settings()
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async with async_session() as session:
        rows = await session.execute(
            select(IntentRule).where(
                IntentRule.tenant_id == "szmp@ash@2026",
                IntentRule.is_enabled == True
            )
        )
        rules = rows.scalars().all()

        print("=" * 80)
        print("数据库中的意图规则 (tenant=szmp@ash@2026):")
        print("=" * 80)

        if not rules:
            print("\n没有找到任何启用的意图规则!")
        else:
            for rule in rules:
                print(f"\n规则: {rule.name}")
                print(f"  ID: {rule.id}")
                print(f"  响应类型: {rule.response_type}")
                print(f"  关键词: {rule.keywords}")
                print(f"  正则模式: {rule.patterns}")
                print(f"  目标知识库: {rule.target_kb_ids}")
                print(f"  优先级: {rule.priority}")
                print(f"  启用: {rule.is_enabled}")

        print("\n" + "=" * 80)


if __name__ == "__main__":
    asyncio.run(get_intent_rules())
-- === file: 008_add_usage_description_to_metadata_fields.sql ===
-- Migration: Add usage_description column to metadata_field_definitions table
-- [AC-IDSMETA-XX] Add usage description field for metadata field definitions

-- Idempotent: safe to rerun.
ALTER TABLE metadata_field_definitions
ADD COLUMN IF NOT EXISTS usage_description TEXT;

-- Column comment (stored in the DB catalog; text intentionally kept as-is).
COMMENT ON COLUMN metadata_field_definitions.usage_description IS '用途说明,描述该元数据字段的业务用途';

-- === file: 009_add_extract_strategies.sql ===
-- Migration: Add extract_strategies field to slot_definitions
-- Date: 2026-03-06
-- Issue: [AC-MRS-07-UPGRADE] extraction-strategy upgrade - strategy chains

-- 1. Add extract_strategies to slot_definitions (JSONB array format)
ALTER TABLE slot_definitions
ADD COLUMN IF NOT EXISTS extract_strategies JSONB DEFAULT NULL;

-- 2. Column comment
COMMENT ON COLUMN slot_definitions.extract_strategies IS
'[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功';

-- 3. Data migration: wrap the legacy scalar extract_strategy into a one-element
--    array.  The old column is kept for backward-compatible reads.
UPDATE slot_definitions
SET extract_strategies = CASE
    WHEN extract_strategy IS NOT NULL THEN jsonb_build_array(extract_strategy)
    ELSE NULL
END
WHERE extract_strategies IS NULL;

-- 4. Optional GIN index to support strategy queries (enable if needed)
-- CREATE INDEX IF NOT EXISTS idx_slot_definitions_extract_strategies
-- ON slot_definitions USING GIN (extract_strategies);

-- 5. Verification query (run manually)
-- SELECT
--     id,
--     slot_key,
--     extract_strategy as old_strategy,
--     extract_strategies as new_strategies
-- FROM slot_definitions
-- WHERE extract_strategy IS NOT NULL;

-- === file: 010_add_metadata_to_documents.sql ===
-- Add metadata field to documents table
-- Migration: 010_add_metadata_to_documents
-- Description: Add metadata JSON field to store document-level metadata

ALTER TABLE documents ADD COLUMN IF NOT EXISTS metadata JSONB;
"""
Migration script to add intent_vector and semantic_examples fields to intent_rules table.

[v0.8.0] Hybrid routing - Intent vector fields

Run this script with: python scripts/migrations/011_add_intent_vector_fields.py
"""

import asyncio
import sys
import os

# Make the service root importable when run from scripts/migrations/.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from sqlalchemy import text
from app.core.database import async_session_maker


async def run_migration():
    """Run the migration to add new columns to intent_rules / chat_messages.

    Executes three idempotent ALTER TABLE statements (IF NOT EXISTS).
    "already exists" / duplicate-column errors are tolerated so the script is
    safe to rerun; any other DB error aborts the migration.
    """
    statements = [
        """
        ALTER TABLE intent_rules
        ADD COLUMN IF NOT EXISTS intent_vector JSONB;
        """,
        """
        ALTER TABLE intent_rules
        ADD COLUMN IF NOT EXISTS semantic_examples JSONB;
        """,
        """
        ALTER TABLE chat_messages
        ADD COLUMN IF NOT EXISTS route_trace JSONB;
        """,
    ]

    async with async_session_maker() as session:
        for i, statement in enumerate(statements, 1):
            try:
                await session.execute(text(statement))
                print(f"[{i}] Executed successfully")
            except Exception as e:
                # Tolerate reruns: IF NOT EXISTS should be a no-op, but some
                # backends still raise duplicate-column errors.
                if "already exists" in str(e).lower() or "duplicate" in str(e).lower():
                    print(f"[{i}] Skipped (already exists): {str(e)[:50]}...")
                else:
                    raise

        await session.commit()
        print("\nMigration completed successfully!")


if __name__ == "__main__":
    asyncio.run(run_migration())
# BUG FIX: removed the original module-level `else: print(...); sys.exit(1)`
# branch.  It ran on *import*, so any tool that merely imported this module
# (test collector, migration runner) was killed with exit code 1.
-- Add intent_vector and semantic_examples fields to intent_rules table
-- [v0.8.0] Hybrid routing support for semantic matching
-- Idempotent: all ALTERs use IF NOT EXISTS, safe to rerun.

-- Add intent_vector column (JSONB for storing pre-computed embedding vectors)
ALTER TABLE intent_rules
ADD COLUMN IF NOT EXISTS intent_vector JSONB;

-- Add semantic_examples column (JSONB for storing example sentences for dynamic vector computation)
ALTER TABLE intent_rules
ADD COLUMN IF NOT EXISTS semantic_examples JSONB;

-- Add comments for documentation
COMMENT ON COLUMN intent_rules.intent_vector IS '[v0.8.0] Pre-computed intent vector for semantic matching';
COMMENT ON COLUMN intent_rules.semantic_examples IS '[v0.8.0] Semantic example sentences for dynamic vector computation';

-- Add route_trace column to chat_messages table if not exists
ALTER TABLE chat_messages
ADD COLUMN IF NOT EXISTS route_trace JSONB;

COMMENT ON COLUMN chat_messages.route_trace IS '[v0.8.0] Intent routing trace log for hybrid routing observability';
"""
Detailed performance profiling - confirm how long each stage of KB search takes.
"""

import asyncio
import sys
import time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
from app.core.config import get_settings
from app.core.qdrant_client import get_qdrant_client
from app.services.embedding import get_embedding_provider


async def profile_detailed():
    """Time each stage (embedding, collection listing, filter build, Qdrant
    search, full tool run) for one hard-coded query and print a summary."""
    settings = get_settings()

    print("=" * 80)
    print("详细性能分析")
    print("=" * 80)

    query = "三年级语文学习"
    tenant_id = "szmp@ash@2026"
    metadata_filter = {"grade": "三年级", "subject": "语文"}

    # 1. Embedding generation (provider is expected to be pre-initialised)
    print("\n📊 1. Embedding 生成")
    print("-" * 80)
    start = time.time()
    embedding_service = await get_embedding_provider()
    init_time = (time.time() - start) * 1000

    start = time.time()
    embedding_result = await embedding_service.embed_query(query)
    embed_time = (time.time() - start) * 1000

    # Unwrap the vector — return shape differs between providers.
    if hasattr(embedding_result, 'embedding_full'):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, 'embedding'):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result

    print(f"  获取服务实例: {init_time:.2f} ms")
    print(f"  Embedding 生成: {embed_time:.2f} ms")
    print(f"  向量维度: {len(query_vector)}")

    # 2. List the tenant's collections (Redis cache first, Qdrant on miss)
    print("\n📊 2. 获取 collections 列表")
    print("-" * 80)
    client = await get_qdrant_client()
    qdrant_client = await client.get_client()

    start = time.time()
    from app.services.metadata_cache_service import get_metadata_cache_service
    cache_service = await get_metadata_cache_service()
    cache_key = f"collections:{tenant_id}"

    # NOTE(review): reaches into private members (_get_redis, _enabled) of the
    # cache service — acceptable for a throwaway profiling script only.
    redis_client = await cache_service._get_redis()
    cache_hit = False
    if redis_client and cache_service._enabled:
        cached = await redis_client.get(cache_key)
        if cached:
            import json
            tenant_collections = json.loads(cached)
            cache_hit = True
            cache_time = (time.time() - start) * 1000
            print(f"  ✅ 缓存命中: {cache_time:.2f} ms")
            print(f"  Collections: {tenant_collections}")

    if not cache_hit:
        import json
        # Fall back to querying Qdrant directly.
        start = time.time()
        collections = await qdrant_client.get_collections()
        safe_tenant_id = tenant_id.replace('@', '_')
        prefix = f"kb_{safe_tenant_id}"
        tenant_collections = [
            c.name for c in collections.collections
            if c.name.startswith(prefix)
        ]
        tenant_collections.sort()
        db_time = (time.time() - start) * 1000
        print(f"  ❌ 缓存未命中,从 Qdrant 查询: {db_time:.2f} ms")
        print(f"  Collections: {tenant_collections}")

        # Cache the listing for 5 minutes.
        if redis_client and cache_service._enabled:
            await redis_client.setex(cache_key, 300, json.dumps(tenant_collections))
            print(f"  已缓存到 Redis")

    # 3. Qdrant search (per collection, serial)
    print("\n📊 3. Qdrant 搜索")
    print("-" * 80)
    from qdrant_client.models import FieldCondition, Filter, MatchValue

    # Build the payload filter from the metadata key/value pairs.
    start = time.time()
    must_conditions = []
    for key, value in metadata_filter.items():
        field_path = f"metadata.{key}"
        condition = FieldCondition(
            key=field_path,
            match=MatchValue(value=value),
        )
        must_conditions.append(condition)
    qdrant_filter = Filter(must=must_conditions) if must_conditions else None
    filter_time = (time.time() - start) * 1000
    print(f"  构建 filter: {filter_time:.2f} ms")

    # Search each collection serially and accumulate the total time.
    total_search_time = 0
    for collection_name in tenant_collections:
        print(f"\n  Collection: {collection_name}")

        start = time.time()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = (time.time() - start) * 1000
        print(f"    检查存在: {check_time:.2f} ms")

        if not exists:
            print(f"    ❌ 不存在")
            continue

        start = time.time()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                limit=5,
                score_threshold=0.5,
                query_filter=qdrant_filter,
            )
            search_time = (time.time() - start) * 1000
            total_search_time += search_time
            print(f"    搜索时间: {search_time:.2f} ms")
            print(f"    结果数: {len(results.points)}")
        except Exception as e:
            print(f"    ❌ 搜索失败: {e}")

    print(f"\n  总搜索时间: {total_search_time:.2f} ms")

    # 4. Full KB-search flow through the tool
    print("\n📊 4. 完整 KB Search 流程")
    print("-" * 80)

    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async with async_session() as session:
        config = KbSearchDynamicConfig(
            enabled=True,
            top_k=5,
            timeout_ms=30000,
            min_score_threshold=0.5,
        )

        tool = KbSearchDynamicTool(session=session, config=config)

        start = time.time()
        result = await tool.execute(
            query=query,
            tenant_id=tenant_id,
            scene="学习方案",
            top_k=5,
            context=metadata_filter,
        )
        total_time = (time.time() - start) * 1000

        print(f"  总耗时: {total_time:.2f} ms")
        print(f"  工具内部耗时: {result.duration_ms} ms")
        print(f"  结果数: {len(result.hits)}")

    # 5. Summary — conditional expressions keep the unused branch unevaluated,
    # so only one of cache_time / db_time needs to be bound.
    print("\n" + "=" * 80)
    print("📈 耗时总结")
    print("=" * 80)
    print(f"\n各环节耗时:")
    print(f"  Embedding 获取服务: {init_time:.2f} ms")
    print(f"  Embedding 生成: {embed_time:.2f} ms")
    print(f"  Collections 获取: {cache_time if cache_hit else db_time:.2f} ms")
    print(f"  Filter 构建: {filter_time:.2f} ms")
    print(f"  Qdrant 搜索: {total_search_time:.2f} ms")
    print(f"  完整流程: {total_time:.2f} ms")

    other_time = total_time - embed_time - (cache_time if cache_hit else db_time) - filter_time - total_search_time
    print(f"  其他开销: {other_time:.2f} ms")

    print("\n" + "=" * 80)


if __name__ == "__main__":
    asyncio.run(profile_detailed())
"""
Detailed timing analysis for the fully-parameterised query.
Compares runs with and without metadata_filter.
"""

import asyncio
import sys
import time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
from app.core.config import get_settings
from app.core.qdrant_client import get_qdrant_client
from qdrant_client.models import FieldCondition, Filter, MatchValue


async def profile_step_by_step():
    """Profile each stage of a filtered KB search, then compare the same
    searches and the full tool run without the metadata filter."""
    settings = get_settings()

    print("=" * 80)
    print("完整参数查询耗时分析")
    print("=" * 80)

    query = "三年级语文学习"
    tenant_id = "szmp@ash@2026"
    metadata_filter = {"grade": "三年级", "subject": "语文"}

    # 1. Embedding generation
    print("\n📊 1. Embedding 生成")
    print("-" * 80)
    from app.services.embedding import get_embedding_provider

    start = time.time()
    embedding_service = await get_embedding_provider()
    init_time = time.time() - start  # seconds here; converted to ms when printed

    start = time.time()
    embedding_result = await embedding_service.embed_query(query)
    embed_time = (time.time() - start) * 1000

    # Unwrap the vector — return shape differs between providers.
    if hasattr(embedding_result, 'embedding_full'):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, 'embedding'):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result

    print(f"  初始化时间: {init_time * 1000:.2f} ms")
    print(f"  Embedding 生成: {embed_time:.2f} ms")
    print(f"  向量维度: {len(query_vector)}")

    # 2. List the tenant's collections
    print("\n📊 2. 获取 collections 列表")
    print("-" * 80)
    client = await get_qdrant_client()
    qdrant_client = await client.get_client()

    start = time.time()
    collections = await qdrant_client.get_collections()
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = f"kb_{safe_tenant_id}"
    tenant_collections = [
        c.name for c in collections.collections
        if c.name.startswith(prefix)
    ]
    list_time = (time.time() - start) * 1000

    print(f"  获取 collections: {list_time:.2f} ms")
    print(f"  Collections: {tenant_collections}")

    # 3. Build the metadata filter
    print("\n📊 3. 构建 metadata filter")
    print("-" * 80)
    start = time.time()

    must_conditions = []
    for key, value in metadata_filter.items():
        field_path = f"metadata.{key}"
        condition = FieldCondition(
            key=field_path,
            match=MatchValue(value=value),
        )
        must_conditions.append(condition)
    qdrant_filter = Filter(must=must_conditions) if must_conditions else None

    filter_time = (time.time() - start) * 1000
    print(f"  构建 filter: {filter_time:.2f} ms")
    print(f"  Filter: {qdrant_filter}")

    # 4. Per-collection search WITH the metadata filter
    print("\n📊 4. Qdrant 搜索(带 metadata filter)")
    print("-" * 80)

    total_search_time = 0
    total_results = 0

    for collection_name in tenant_collections:
        print(f"\n  Collection: {collection_name}")

        start = time.time()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = (time.time() - start) * 1000
        print(f"    检查存在: {check_time:.2f} ms")

        if not exists:
            print(f"    ❌ 不存在")
            continue

        start = time.time()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                limit=5,
                score_threshold=0.5,
                query_filter=qdrant_filter,
            )
            search_time = (time.time() - start) * 1000
            total_search_time += search_time
            total_results += len(results.points)

            print(f"    搜索时间: {search_time:.2f} ms")
            print(f"    结果数: {len(results.points)}")
        except Exception as e:
            print(f"    ❌ 搜索失败: {e}")

    print(f"\n  总搜索时间: {total_search_time:.2f} ms")
    print(f"  总结果数: {total_results}")

    # 5. Comparison: same searches WITHOUT the filter
    print("\n📊 5. Qdrant 搜索(不带 metadata filter)对比")
    print("-" * 80)

    total_search_time_no_filter = 0
    total_results_no_filter = 0

    for collection_name in tenant_collections:
        start = time.time()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                limit=5,
                score_threshold=0.5,
                # no filter on purpose
            )
            search_time = (time.time() - start) * 1000
            total_search_time_no_filter += search_time
            total_results_no_filter += len(results.points)
        except Exception as e:
            print(f"  {collection_name}: 失败 {e}")

    print(f"  总搜索时间(无 filter): {total_search_time_no_filter:.2f} ms")
    print(f"  总结果数(无 filter): {total_results_no_filter}")

    # 6. Full KB-search flow (with context)
    print("\n📊 6. 完整 KB Search 流程(带 context)")
    print("-" * 80)

    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async with async_session() as session:
        config = KbSearchDynamicConfig(
            enabled=True,
            top_k=5,
            timeout_ms=30000,
            min_score_threshold=0.5,
        )

        tool = KbSearchDynamicTool(session=session, config=config)

        start = time.time()
        result = await tool.execute(
            query=query,
            tenant_id=tenant_id,
            scene="学习方案",
            top_k=5,
            context=metadata_filter,
        )
        total_time = (time.time() - start) * 1000

        print(f"  总耗时: {total_time:.2f} ms")
        print(f"  工具内部耗时: {result.duration_ms} ms")
        print(f"  结果数: {len(result.hits)}")
        print(f"  应用的 filter: {result.applied_filter}")

    # 7. Comparison: full flow WITHOUT context (reuses `config` from above)
    print("\n📊 7. 完整 KB Search 流程(不带 context)")
    print("-" * 80)

    async with async_session() as session:
        tool = KbSearchDynamicTool(session=session, config=config)

        start = time.time()
        result = await tool.execute(
            query=query,
            tenant_id=tenant_id,
            scene="学习方案",
            top_k=5,
            # no context on purpose
        )
        total_time_no_context = (time.time() - start) * 1000

        print(f"  总耗时: {total_time_no_context:.2f} ms")
        print(f"  工具内部耗时: {result.duration_ms} ms")
        print(f"  结果数: {len(result.hits)}")

    # 8. Summary
    print("\n" + "=" * 80)
    print("📈 耗时分析总结")
    print("=" * 80)

    print(f"\n带 metadata filter:")
    print(f"  Embedding: {embed_time:.2f} ms")
    print(f"  获取 collections: {list_time:.2f} ms")
    print(f"  Qdrant 搜索: {total_search_time:.2f} ms")
    print(f"  完整流程: {total_time:.2f} ms")

    print(f"\n不带 metadata filter:")
    print(f"  Qdrant 搜索: {total_search_time_no_filter:.2f} ms")
    print(f"  完整流程: {total_time_no_context:.2f} ms")

    print(f"\nMetadata filter 额外开销:")
    print(f"  Qdrant 搜索: {total_search_time - total_search_time_no_filter:.2f} ms")
    print(f"  完整流程: {total_time - total_time_no_context:.2f} ms")

    if total_search_time > total_search_time_no_filter:
        print(f"\n⚠️ 带 filter 的搜索更慢,可能原因:")
        print(f"  - Filter 增加了索引查找的复杂度")
        print(f"  - 需要匹配 metadata 字段")
        print(f"  - 建议: 检查 Qdrant 的 payload 索引配置")


if __name__ == "__main__":
    asyncio.run(profile_step_by_step())
"""
Knowledge-base retrieval performance profiling script.
Breaks down the time spent in each stage.
"""

import asyncio
import sys
import time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
from app.services.retrieval.vector_retriever import VectorRetriever
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient
# NOTE(review): VectorRetriever and QdrantClient are imported but unused below.


async def profile_embedding_generation(query: str):
    """Time embedding-provider initialisation and one embed_query call.

    Returns a dict with init/embed times (ms) and the vector dimension.
    """
    from app.services.embedding import get_embedding_provider

    start = time.time()
    embedding_service = await get_embedding_provider()
    init_time = time.time() - start

    start = time.time()
    embedding = await embedding_service.embed_query(query)
    embed_time = time.time() - start

    # Unwrap the vector — return shape differs between providers.
    if hasattr(embedding, 'embedding_full'):
        vector = embedding.embedding_full
    elif hasattr(embedding, 'embedding'):
        vector = embedding.embedding
    else:
        vector = embedding

    return {
        "init_time_ms": init_time * 1000,
        "embed_time_ms": embed_time * 1000,
        "dimension": len(vector),
    }


async def profile_qdrant_search(tenant_id: str, query_vector: list, metadata_filter: dict = None):
    """Time listing the tenant's collections and searching each serially.

    Returns listing time, collection count and per-collection timings.
    """
    from app.core.qdrant_client import get_qdrant_client

    client = await get_qdrant_client()

    # List collections and keep only those with the tenant prefix.
    start = time.time()
    qdrant_client = await client.get_client()
    collections = await qdrant_client.get_collections()
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = f"kb_{safe_tenant_id}"
    tenant_collections = [
        c.name for c in collections.collections
        if c.name.startswith(prefix)
    ]
    list_collections_time = time.time() - start

    # Search each collection one by one.
    collection_times = []
    for collection_name in tenant_collections:
        start = time.time()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = time.time() - start

        if not exists:
            collection_times.append({
                "collection": collection_name,
                "exists": False,
                "time_ms": check_time * 1000,
            })
            continue

        start = time.time()
        # Build the payload filter from metadata key/value pairs, if any.
        qdrant_filter = None
        if metadata_filter:
            from qdrant_client.models import FieldCondition, Filter, MatchValue
            must_conditions = []
            for key, value in metadata_filter.items():
                field_path = f"metadata.{key}"
                condition = FieldCondition(
                    key=field_path,
                    match=MatchValue(value=value),
                )
                must_conditions.append(condition)
            qdrant_filter = Filter(must=must_conditions) if must_conditions else None

        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",  # use the "full" named vector
                limit=5,
                score_threshold=0.5,
                query_filter=qdrant_filter,
            )
        except Exception as e:
            if "vector name" in str(e).lower():
                # Retry without a named vector (collection has a single
                # unnamed vector).
                results = await qdrant_client.query_points(
                    collection_name=collection_name,
                    query=query_vector,
                    limit=5,
                    score_threshold=0.5,
                    query_filter=qdrant_filter,
                )
            else:
                raise
        search_time = time.time() - start

        collection_times.append({
            "collection": collection_name,
            "exists": True,
            "check_time_ms": check_time * 1000,
            "search_time_ms": search_time * 1000,
            "results_count": len(results.points),
        })

    return {
        "list_collections_time_ms": list_collections_time * 1000,
        "collections_count": len(tenant_collections),
        "collection_times": collection_times,
    }


async def profile_full_kb_search():
    """Run the stage-by-stage profile and print a bottleneck summary."""
    settings = get_settings()

    print("=" * 80)
    print("知识库检索性能分析")
    print("=" * 80)

    # 1. Embedding generation
    print("\n📊 1. Embedding 生成分析")
    print("-" * 80)
    query = "三年级语文学习"
    embed_result = await profile_embedding_generation(query)
    print(f"  初始化时间: {embed_result['init_time_ms']:.2f} ms")
    print(f"  Embedding 生成时间: {embed_result['embed_time_ms']:.2f} ms")
    print(f"  向量维度: {embed_result['dimension']}")

    # 2. Qdrant search
    print("\n📊 2. Qdrant 搜索分析")
    print("-" * 80)

    # Generate the query embedding first.
    from app.services.embedding import get_embedding_provider
    embedding_service = await get_embedding_provider()
    embedding_result = await embedding_service.embed_query(query)
    # Unwrap the vector — return shape differs between providers.
    if hasattr(embedding_result, 'embedding_full'):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, 'embedding'):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result

    tenant_id = "szmp@ash@2026"
    metadata_filter = {"grade": "三年级", "subject": "语文"}

    qdrant_result = await profile_qdrant_search(tenant_id, query_vector, metadata_filter)
    print(f"  获取 collections 列表时间: {qdrant_result['list_collections_time_ms']:.2f} ms")
    print(f"  Collections 数量: {qdrant_result['collections_count']}")
    print(f"\n  各 Collection 搜索耗时:")
    for ct in qdrant_result['collection_times']:
        if ct['exists']:
            print(f"    - {ct['collection']}: {ct['search_time_ms']:.2f} ms (结果: {ct['results_count']} 条)")
        else:
            print(f"    - {ct['collection']}: 不存在 ({ct['time_ms']:.2f} ms)")

    total_search_time = sum(
        ct.get('search_time_ms', 0) for ct in qdrant_result['collection_times']
    )
    print(f"\n  总搜索时间(串行): {total_search_time:.2f} ms")

    # 3. Full KB-search flow through the tool
    print("\n📊 3. 完整 KB Search 流程分析")
    print("-" * 80)

    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async with async_session() as session:
        config = KbSearchDynamicConfig(
            enabled=True,
            top_k=5,
            timeout_ms=30000,  # 30 s
            min_score_threshold=0.5,
        )

        tool = KbSearchDynamicTool(session=session, config=config)

        stages = []  # NOTE(review): never used — dead local

        start_total = time.time()

        start = time.time()  # NOTE(review): shadowed — only start_total is read below
        result = await tool.execute(
            query=query,
            tenant_id=tenant_id,
            scene="学习方案",
            top_k=5,
            context=metadata_filter,
        )
        total_time = (time.time() - start_total) * 1000

        print(f"  总耗时: {total_time:.2f} ms")
        print(f"  结果: success={result.success}, hits={len(result.hits)}")
        print(f"  工具内部耗时: {result.duration_ms} ms")

        # Difference between external measurement and the tool's own timer.
        overhead = total_time - result.duration_ms
        print(f"  额外开销(初始化等): {overhead:.2f} ms")

    # 4. Bottleneck analysis
    print("\n📊 4. 性能瓶颈分析")
    print("-" * 80)

    embedding_time = embed_result['embed_time_ms']
    qdrant_time = total_search_time
    total_measured = embedding_time + qdrant_time

    print(f"  Embedding 生成: {embedding_time:.2f} ms ({embedding_time/total_measured*100:.1f}%)")
    print(f"  Qdrant 搜索: {qdrant_time:.2f} ms ({qdrant_time/total_measured*100:.1f}%)")
    print(f"  其他开销: {total_time - total_measured:.2f} ms")

    print("\n" + "=" * 80)
    print("优化建议:")
    print("=" * 80)

    if embedding_time > 1000:
        print("  ⚠️ Embedding 生成较慢,考虑:")
        print("     - 使用更快的 embedding 模型")
        print("     - 增加 embedding 服务缓存")

    if qdrant_time > 1000:
        print("  ⚠️ Qdrant 搜索较慢,考虑:")
        print("     - 并行查询多个 collections")
        print("     - 优化 Qdrant 索引")
        print("     - 减少 collections 数量")

    if len(qdrant_result['collection_times']) > 3:
        print(f"  ⚠️ Collections 数量较多 ({len(qdrant_result['collection_times'])} 个)")
        print("     - 建议合并或归档空/少数据的 collections")


if __name__ == "__main__":
    asyncio.run(profile_full_kb_search())
"""Dump every point stored in a Qdrant collection."""

import asyncio
import json
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from qdrant_client import AsyncQdrantClient
from qdrant_client.models import ScrollRequest
from app.core.config import get_settings


async def query_all_points(collection_name: str):
    """Scroll through *collection_name*, print every payload, then a per-KB tally."""
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url, check_compatibility=False)

    print(f"🔍 查询 Collection: {collection_name}")
    print("=" * 80)

    try:
        info = await client.get_collection(collection_name)
        total_points = info.points_count
        print(f"📊 总向量数: {total_points}\n")

        # Page through the collection with the scroll API.
        all_points = []
        offset = None
        batch_size = 100

        while True:
            points, next_offset = await client.scroll(
                collection_name=collection_name,
                offset=offset,
                limit=batch_size,
                with_payload=True,
                with_vectors=False
            )
            all_points.extend(points)

            if next_offset is None:
                break
            offset = next_offset

            # Progress every 500 fetched records.
            if len(all_points) % 500 == 0:
                print(f"  已获取 {len(all_points)} / {total_points} 条记录...")

        print(f"✅ 成功获取全部 {len(all_points)} 条记录\n")
        print("=" * 80)

        # Print every record's main payload fields and text.
        for i, point in enumerate(all_points, 1):
            payload = point.payload or {}

            print(f"\n📄 记录 {i}/{len(all_points)} (ID: {point.id})")
            print("-" * 80)

            text = payload.get('text', '')
            kb_id = payload.get('kb_id', 'N/A')
            source = payload.get('source', 'N/A')
            chunk_index = payload.get('chunk_index', 'N/A')
            metadata = payload.get('metadata', {})

            print(f"  KB ID: {kb_id}")
            print(f"  Source: {source}")
            print(f"  Chunk Index: {chunk_index}")

            if metadata:
                print(f"  Metadata: {json.dumps(metadata, ensure_ascii=False)}")

            print(f"\n  文本内容:")
            if text:
                # Line by line, preserving the chunk's layout.
                for line in text.split('\n'):
                    if line.strip():
                        print(f"    {line}")
            else:
                print("    (无文本内容)")

        print("\n" + "=" * 80)
        print(f"✅ 查询完成,共 {len(all_points)} 条记录")

        # Tally points per kb_id.
        print("\n📈 统计信息:")
        kb_counts = {}
        for point in all_points:
            kb = (point.payload or {}).get('kb_id', 'N/A')
            kb_counts[kb] = kb_counts.get(kb, 0) + 1

        print(f"  KB ID 分布:")
        for kb, count in sorted(kb_counts.items()):
            print(f"    - {kb}: {count} 条")

    except Exception as e:
        print(f"❌ 查询失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        await client.close()


async def main():
    collection_name = "kb_szmp_ash_2026_30c19c84"
    await query_all_points(collection_name)


if __name__ == "__main__":
    asyncio.run(main())
Metadata: {json.dumps(metadata, ensure_ascii=False)}") + + # 显示文本内容(格式化) + print(f"\n 文本内容:") + if text: + # 按行显示,保持格式 + lines = text.split('\n') + for line in lines: + if line.strip(): + print(f" {line}") + else: + print(" (无文本内容)") + + print("\n" + "=" * 80) + print(f"✅ 查询完成,共 {len(all_points)} 条记录") + + # 统计信息 + print("\n📈 统计信息:") + kb_ids = {} + for point in all_points: + payload = point.payload or {} + kb_id = payload.get('kb_id', 'N/A') + kb_ids[kb_id] = kb_ids.get(kb_id, 0) + 1 + + print(f" KB ID 分布:") + for kb_id, count in sorted(kb_ids.items()): + print(f" - {kb_id}: {count} 条") + + except Exception as e: + print(f"❌ 查询失败: {e}") + import traceback + traceback.print_exc() + finally: + await client.close() + + +async def main(): + collection_name = "kb_szmp_ash_2026_30c19c84" + await query_all_points(collection_name) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/ai-service/scripts/resume_index_jobs.py b/ai-service/scripts/resume_index_jobs.py new file mode 100644 index 0000000..eaefa6a --- /dev/null +++ b/ai-service/scripts/resume_index_jobs.py @@ -0,0 +1,305 @@ +""" +恢复处理中断的索引任务 +用于服务重启后继续处理pending/processing状态的任务 +""" + +import asyncio +import logging +import sys +from pathlib import Path + +# 添加项目根目录到路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from sqlalchemy import select +from app.core.database import async_session_maker +from app.models.entities import IndexJob, Document, IndexJobStatus, DocumentStatus +from app.api.admin.kb import _index_document + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +async def resume_pending_jobs(): + """恢复所有pending和processing状态的任务""" + async with async_session_maker() as session: + # 查询所有未完成的任务 + result = await session.execute( + select(IndexJob).where( + IndexJob.status.in_([IndexJobStatus.PENDING.value, IndexJobStatus.PROCESSING.value]) + ) + ) + pending_jobs = 
async def resume_pending_jobs():
    """Resume every index job left in PENDING/PROCESSING state.

    Intended to run after a service restart: re-reads each job's document,
    resets the job to PENDING, and re-runs the indexing pipeline inline
    (no background task). Jobs whose document or file is missing are marked
    FAILED.
    """
    async with async_session_maker() as session:
        # All jobs that were interrupted mid-flight.
        result = await session.execute(
            select(IndexJob).where(
                IndexJob.status.in_([IndexJobStatus.PENDING.value, IndexJobStatus.PROCESSING.value])
            )
        )
        pending_jobs = result.scalars().all()

        if not pending_jobs:
            logger.info("没有需要恢复的任务")
            return

        logger.info(f"发现 {len(pending_jobs)} 个未完成的任务")

        for job in pending_jobs:
            # BUGFIX: reset per iteration. Previously `doc` leaked from the
            # prior loop pass, so an early failure referenced an unbound name
            # (NameError on the first job) or marked the WRONG document FAILED.
            doc = None
            try:
                # Load the document this job indexes.
                doc_result = await session.execute(
                    select(Document).where(Document.id == job.doc_id)
                )
                doc = doc_result.scalar_one_or_none()

                if not doc:
                    logger.error(f"找不到文档: {job.doc_id}")
                    continue

                if not doc.file_path or not Path(doc.file_path).exists():
                    logger.error(f"文档文件不存在: {doc.file_path}")
                    # Source file gone — this job can never succeed.
                    job.status = IndexJobStatus.FAILED.value
                    job.error_msg = "文档文件不存在"
                    doc.status = DocumentStatus.FAILED.value
                    doc.error_msg = "文档文件不存在"
                    await session.commit()
                    continue

                logger.info(f"恢复处理: job_id={job.id}, doc_id={doc.id}, file={doc.file_name}")

                # Re-read the raw file bytes for re-indexing.
                with open(doc.file_path, 'rb') as f:
                    file_content = f.read()

                # Reset job state before re-processing.
                job.status = IndexJobStatus.PENDING.value
                job.progress = 0
                job.error_msg = None
                await session.commit()

                # Process inline (synchronously awaited) — no background_tasks here.
                await process_job(
                    tenant_id=job.tenant_id,
                    kb_id=doc.kb_id,
                    job_id=str(job.id),
                    doc_id=str(doc.id),
                    file_content=file_content,
                    filename=doc.file_name,
                    metadata=doc.doc_metadata or {}
                )

                logger.info(f"任务处理完成: job_id={job.id}")

            except Exception as e:
                logger.error(f"处理任务失败: job_id={job.id}, error={e}")
                # Best-effort failure bookkeeping; `doc` may be None if the
                # failure happened before the document was loaded.
                job.status = IndexJobStatus.FAILED.value
                job.error_msg = str(e)
                if doc:
                    doc.status = DocumentStatus.FAILED.value
                    doc.error_msg = str(e)
                await session.commit()

        logger.info("所有任务处理完成")


async def process_job(tenant_id: str, kb_id: str, job_id: str, doc_id: str,
                      file_content: bytes, filename: str, metadata: dict):
    """Run the full indexing pipeline for one job.

    Copied from `_index_document` in app.api.admin.kb: decode/parse the file,
    chunk it, embed each chunk, upsert the vectors into Qdrant, and keep the
    job's progress/status updated throughout. On any failure the job is
    marked FAILED in a fresh session.
    """
    import tempfile
    from pathlib import Path

    from qdrant_client.models import PointStruct

    from app.core.qdrant_client import get_qdrant_client
    from app.services.document import DocumentParseException, UnsupportedFormatError, parse_document
    from app.services.embedding import get_embedding_provider
    from app.services.kb import KBService
    from app.api.admin.kb import chunk_text_by_lines, TextChunk

    logger.info(f"[RESUME] Starting indexing: tenant={tenant_id}, kb_id={kb_id}, job_id={job_id}, doc_id={doc_id}")

    async with async_session_maker() as session:
        kb_service = KBService(session)
        try:
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=10
            )
            await session.commit()

            parse_result = None
            text = None
            file_ext = Path(filename or "").suffix.lower()
            logger.info(f"[RESUME] File extension: {file_ext}, content size: {len(file_content)} bytes")

            text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}

            if file_ext in text_extensions or not file_ext:
                # Plain-text path: try a list of encodings, most likely first.
                logger.info("[RESUME] Treating as text file")
                text = None
                for encoding in ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]:
                    try:
                        text = file_content.decode(encoding)
                        logger.info(f"[RESUME] Successfully decoded with encoding: {encoding}")
                        break
                    except (UnicodeDecodeError, LookupError):
                        continue

                if text is None:
                    # Last resort: lossy decode so indexing still proceeds.
                    text = file_content.decode("utf-8", errors="replace")
            else:
                # Binary path: write to a temp file and hand off to the parser.
                logger.info("[RESUME] Binary file detected, will parse with document parser")
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
                )
                await session.commit()

                with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
                    tmp_file.write(file_content)
                    tmp_path = tmp_file.name

                logger.info(f"[RESUME] Temp file created: {tmp_path}")

                try:
                    logger.info(f"[RESUME] Starting document parsing for {file_ext}...")
                    parse_result = parse_document(tmp_path)
                    text = parse_result.text
                    logger.info(
                        f"[RESUME] Parsed document SUCCESS: (unknown), chars={len(text)}"
                    )
                except UnsupportedFormatError as e:
                    logger.error(f"[RESUME] UnsupportedFormatError: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                except DocumentParseException as e:
                    logger.error(f"[RESUME] DocumentParseException: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                except Exception as e:
                    logger.error(f"[RESUME] Unexpected parsing error: {type(e).__name__}: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                finally:
                    # Always remove the temp file, even on parse failure.
                    Path(tmp_path).unlink(missing_ok=True)

            logger.info(f"[RESUME] Final text length: {len(text)} chars")

            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=20
            )
            await session.commit()

            logger.info("[RESUME] Getting embedding provider...")
            embedding_provider = await get_embedding_provider()
            logger.info(f"[RESUME] Embedding provider: {type(embedding_provider).__name__}")

            all_chunks: list[TextChunk] = []

            if parse_result and parse_result.pages:
                # Page-aware chunking (e.g. PDF): tag each chunk with its page.
                logger.info(f"[RESUME] PDF with {len(parse_result.pages)} pages")
                for page in parse_result.pages:
                    page_chunks = chunk_text_by_lines(
                        page.text,
                        min_line_length=10,
                        source=filename,
                    )
                    for pc in page_chunks:
                        pc.page = page.page
                    all_chunks.extend(page_chunks)
            else:
                logger.info("[RESUME] Using line-based chunking")
                all_chunks = chunk_text_by_lines(
                    text,
                    min_line_length=10,
                    source=filename,
                )

            logger.info(f"[RESUME] Total chunks: {len(all_chunks)}")

            qdrant = await get_qdrant_client()
            await qdrant.ensure_kb_collection_exists(tenant_id, kb_id, use_multi_vector=True)

            # Nomic providers emit full/256/512-dim vectors; others a single one.
            from app.services.embedding.nomic_provider import NomicEmbeddingProvider
            use_multi_vector = isinstance(embedding_provider, NomicEmbeddingProvider)
            logger.info(f"[RESUME] Using multi-vector format: {use_multi_vector}")

            import uuid
            points = []
            total_chunks = len(all_chunks)
            doc_metadata = metadata or {}

            for i, chunk in enumerate(all_chunks):
                payload = {
                    "text": chunk.text,
                    "source": doc_id,
                    "kb_id": kb_id,
                    "chunk_index": i,
                    "start_token": chunk.start_token,
                    "end_token": chunk.end_token,
                    "metadata": doc_metadata,
                }
                if chunk.page is not None:
                    payload["page"] = chunk.page
                if chunk.source:
                    payload["filename"] = chunk.source

                if use_multi_vector:
                    embedding_result = await embedding_provider.embed_document(chunk.text)
                    points.append({
                        "id": str(uuid.uuid4()),
                        "vector": {
                            "full": embedding_result.embedding_full,
                            "dim_256": embedding_result.embedding_256,
                            "dim_512": embedding_result.embedding_512,
                        },
                        "payload": payload,
                    })
                else:
                    embedding = await embedding_provider.embed(chunk.text)
                    points.append(
                        PointStruct(
                            id=str(uuid.uuid4()),
                            vector=embedding,
                            payload=payload,
                        )
                    )

                # Progress 20 -> 90 across the embedding loop; commit every 10th.
                progress = 20 + int((i + 1) / total_chunks * 70)
                if i % 10 == 0 or i == total_chunks - 1:
                    await kb_service.update_job_status(
                        tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=progress
                    )
                    await session.commit()

            if points:
                logger.info(f"[RESUME] Upserting {len(points)} vectors to Qdrant...")
                if use_multi_vector:
                    await qdrant.upsert_multi_vector(tenant_id, points, kb_id=kb_id)
                else:
                    await qdrant.upsert_vectors(tenant_id, points, kb_id=kb_id)

            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.COMPLETED.value, progress=100
            )
            await session.commit()

            logger.info(
                f"[RESUME] COMPLETED: tenant={tenant_id}, kb_id={kb_id}, "
                f"job_id={job_id}, chunks={len(all_chunks)}"
            )

        except Exception as e:
            import traceback
            logger.error(f"[RESUME] FAILED: {e}\n{traceback.format_exc()}")
            await session.rollback()
            # Record the failure in a fresh session — the original one may be
            # unusable after the rollback.
            async with async_session_maker() as error_session:
                kb_service = KBService(error_session)
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.FAILED.value,
                    progress=0, error_msg=str(e)
                )
                await error_session.commit()


if __name__ == "__main__":
    logger.info("开始恢复索引任务...")
    asyncio.run(resume_pending_jobs())
    logger.info("恢复脚本执行完成")
b/ai-service/scripts/run_migration_011.py new file mode 100644 index 0000000..348dbc4 --- /dev/null +++ b/ai-service/scripts/run_migration_011.py @@ -0,0 +1,47 @@ +""" +Migration script to add intent_vector and semantic_examples fields. +Run: python scripts/run_migration_011.py +""" + +import asyncio +import os +import sys + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from sqlalchemy import text +from app.core.database import engine + + +async def run_migration(): + """Execute migration to add new fields.""" + migration_sql = """ + -- Add intent_vector column (JSONB for storing pre-computed embedding vectors) + ALTER TABLE intent_rules + ADD COLUMN IF NOT EXISTS intent_vector JSONB; + + -- Add semantic_examples column (JSONB for storing example sentences for dynamic vector computation) + ALTER TABLE intent_rules + ADD COLUMN IF NOT EXISTS semantic_examples JSONB; + + -- Add route_trace column to chat_messages table if not exists + ALTER TABLE chat_messages + ADD COLUMN IF NOT EXISTS route_trace JSONB; + """ + + async with engine.begin() as conn: + for statement in migration_sql.strip().split(";"): + statement = statement.strip() + if statement and not statement.startswith("--"): + try: + await conn.execute(text(statement)) + print(f"Executed: {statement[:80]}...") + except Exception as e: + print(f"Error executing: {statement[:80]}...") + print(f" Error: {e}") + + print("\nMigration completed successfully!") + + +if __name__ == "__main__": + asyncio.run(run_migration()) diff --git a/ai-service/scripts/test_api_search.py b/ai-service/scripts/test_api_search.py new file mode 100644 index 0000000..f605ebc --- /dev/null +++ b/ai-service/scripts/test_api_search.py @@ -0,0 +1,82 @@ +""" +通过 API 测试知识库检索性能 +""" + +import requests +import json +import time + +API_BASE = "http://localhost:8000" +API_KEY = "oQfkSAbL8iafzyHxqb--G7zRWSOYJHvlzQxia2KpYms" +TENANT_ID = "szmp@ash@2026" + +def test_kb_search(): + """测试知识库搜索 API""" + print("=" 
* 80) + print("测试知识库检索 API") + print("=" * 80) + + headers = { + "Content-Type": "application/json", + "X-API-Key": API_KEY, + "X-Tenant-Id": TENANT_ID, + } + + # 测试数据 + test_cases = [ + { + "name": "完整参数(含context过滤)", + "data": { + "query": "三年级语文学习", + "scene": "学习方案", + "top_k": 5, + "context": {"grade": "三年级", "subject": "语文"}, + } + }, + { + "name": "简化参数(无context)", + "data": { + "query": "三年级语文学习", + "scene": "学习方案", + "top_k": 5, + } + }, + ] + + for test_case in test_cases: + print(f"\n{'='*80}") + print(f"测试: {test_case['name']}") + print(f"{'='*80}") + print(f"请求数据: {json.dumps(test_case['data'], ensure_ascii=False)}") + + try: + start = time.time() + response = requests.post( + f"{API_BASE}/api/v1/mid/kb-search-dynamic", + headers=headers, + json=test_case['data'], + timeout=30, + ) + elapsed = (time.time() - start) * 1000 + + print(f"\n响应状态: {response.status_code}") + print(f"总耗时: {elapsed:.2f} ms") + + if response.status_code == 200: + result = response.json() + print(f"API 结果:") + print(f" success: {result.get('success')}") + print(f" hits count: {len(result.get('hits', []))}") + print(f" duration_ms: {result.get('duration_ms')}") + print(f" applied_filter: {result.get('applied_filter')}") + else: + print(f"错误: {response.text}") + + except Exception as e: + print(f"请求失败: {e}") + + print("\n" + "=" * 80) + + +if __name__ == "__main__": + test_kb_search() diff --git a/ai-service/scripts/test_cache.py b/ai-service/scripts/test_cache.py new file mode 100644 index 0000000..b150774 --- /dev/null +++ b/ai-service/scripts/test_cache.py @@ -0,0 +1,73 @@ +""" +测试 Redis 缓存是否正常工作 +""" + +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from app.services.metadata_cache_service import get_metadata_cache_service + + +async def test_cache(): + """测试缓存服务""" + print("=" * 80) + print("测试 Redis 缓存") + print("=" * 80) + + cache_service = await get_metadata_cache_service() + tenant_id = "szmp@ash@2026" + + # 
1. 检查缓存是否存在 + print("\n📊 1. 检查缓存是否存在") + cached = await cache_service.get_fields(tenant_id) + if cached: + print(f" ✅ 缓存存在,包含 {len(cached)} 个字段") + for field in cached[:3]: + print(f" - {field['field_key']}: {field['label']}") + else: + print(" ❌ 缓存不存在") + + # 2. 手动设置缓存 + print("\n📊 2. 手动设置测试缓存") + test_fields = [ + { + "field_key": "grade", + "label": "年级", + "field_type": "enum", + "required": False, + "options": ["三年级", "四年级", "五年级"], + "default_value": None, + "is_filterable": True, + }, + { + "field_key": "subject", + "label": "学科", + "field_type": "enum", + "required": False, + "options": ["语文", "数学", "英语"], + "default_value": None, + "is_filterable": True, + }, + ] + + result = await cache_service.set_fields(tenant_id, test_fields, ttl=3600) + print(f" 设置缓存结果: {result}") + + # 3. 再次获取缓存 + print("\n📊 3. 再次获取缓存") + cached = await cache_service.get_fields(tenant_id) + if cached: + print(f" ✅ 缓存存在,包含 {len(cached)} 个字段") + else: + print(" ❌ 缓存不存在") + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + +if __name__ == "__main__": + asyncio.run(test_cache()) diff --git a/ai-service/scripts/test_dynamic_tool_schema.py b/ai-service/scripts/test_dynamic_tool_schema.py new file mode 100644 index 0000000..57fb4c4 --- /dev/null +++ b/ai-service/scripts/test_dynamic_tool_schema.py @@ -0,0 +1,60 @@ +""" +测试动态生成的工具 Schema +""" + +import asyncio +import sys +import json +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker + +from app.core.config import get_settings +from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool + + +async def test_dynamic_tool_schema(): + """测试动态生成的工具 Schema""" + settings = get_settings() + + engine = create_async_engine(settings.database_url) + async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + tenant_id = "szmp@ash@2026" + + print(f"\n{'='*80}") 
+ print(f"测试动态生成的工具 Schema") + print(f"{'='*80}") + print(f"租户 ID: {tenant_id}") + + async with async_session() as session: + tool = KbSearchDynamicTool(session) + + # 获取静态 Schema + static_schema = tool.get_tool_schema() + print(f"\n--- 静态 Schema ---") + print(json.dumps(static_schema, indent=2, ensure_ascii=False)) + + # 获取动态 Schema + dynamic_schema = await tool.get_dynamic_tool_schema(tenant_id) + print(f"\n--- 动态 Schema ---") + print(json.dumps(dynamic_schema, indent=2, ensure_ascii=False)) + + # 再次获取,测试缓存 + print(f"\n--- 测试缓存 ---") + dynamic_schema2 = await tool.get_dynamic_tool_schema(tenant_id) + print(f"缓存命中: {dynamic_schema == dynamic_schema2}") + + # 打印 context 字段的详细结构 + print(f"\n--- context 字段详情 ---") + context_props = dynamic_schema["parameters"]["properties"].get("context", {}).get("properties", {}) + print(f"过滤字段数量: {len(context_props)}") + for key, value in context_props.items(): + print(f" {key}: {value}") + + +if __name__ == "__main__": + asyncio.run(test_dynamic_tool_schema()) diff --git a/ai-service/scripts/test_kb_search.py b/ai-service/scripts/test_kb_search.py new file mode 100644 index 0000000..26ab619 --- /dev/null +++ b/ai-service/scripts/test_kb_search.py @@ -0,0 +1,93 @@ +""" +测试 kb_search_dynamic 工具是否能用给定的参数查出数据 +""" + +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker +from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool +from app.core.config import get_settings + + +async def test_kb_search(): + """测试知识库搜索""" + settings = get_settings() + + # 创建数据库会话 + engine = create_async_engine(settings.database_url) + async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async with async_session() as session: + tool = KbSearchDynamicTool(session=session) + + # 测试参数 + test_cases = [ + { + "name": "完整参数(含context过滤)", + "params": { 
async def test_kb_search():
    """Exercise KbSearchDynamicTool.execute with three parameter variants.

    Runs the same query with (1) a context filter, (2) no context, and
    (3) only query + tenant_id, printing hits and diagnostics for each.
    """
    settings = get_settings()

    # Fresh engine/session for a standalone script run.
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async with async_session() as session:
        tool = KbSearchDynamicTool(session=session)

        # Progressively fewer parameters to see how filtering affects recall.
        test_cases = [
            {
                "name": "完整参数(含context过滤)",
                "params": {
                    "query": "三年级语文学习",
                    "tenant_id": "szmp@ash@2026",
                    "scene": "学习方案",
                    "top_k": 5,
                    "context": {"grade": "三年级", "subject": "语文"},
                }
            },
            {
                "name": "简化参数(无context)",
                "params": {
                    "query": "三年级语文学习",
                    "tenant_id": "szmp@ash@2026",
                    "scene": "学习方案",
                    "top_k": 5,
                }
            },
            {
                "name": "仅query和tenant_id",
                "params": {
                    "query": "三年级语文学习",
                    "tenant_id": "szmp@ash@2026",
                    "top_k": 5,
                }
            },
        ]

        for test_case in test_cases:
            print(f"\n{'='*80}")
            print(f"测试: {test_case['name']}")
            print(f"{'='*80}")
            print(f"参数: {test_case['params']}")

            try:
                result = await tool.execute(**test_case['params'])

                print(f"\n结果:")
                print(f"  success: {result.success}")
                print(f"  hits count: {len(result.hits)}")
                print(f"  applied_filter: {result.applied_filter}")
                print(f"  fallback_reason_code: {result.fallback_reason_code}")
                print(f"  duration_ms: {result.duration_ms}")

                if result.hits:
                    # Show a short preview of the top 3 hits.
                    print(f"\n  前3条结果:")
                    for i, hit in enumerate(result.hits[:3], 1):
                        text = hit.get('text', '')[:80] + '...' if hit.get('text') else 'N/A'
                        score = hit.get('score', 0)
                        metadata = hit.get('metadata', {})
                        print(f"    {i}. [score={score:.4f}] {text}")
                        print(f"       metadata: {metadata}")
                else:
                    print(f"\n  ⚠️ 没有命中任何结果")

            except Exception as e:
                print(f"\n  ❌ 错误: {e}")
                import traceback
                traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(test_kb_search())
async def test_kb_search():
    """Exercise kb_search_dynamic against one specific course KB.

    Pins the search to a single knowledge base via StepKbConfig
    (allowed == preferred == the course KB), applies a grade filter, and
    prints full diagnostics including tool trace and every hit.
    """
    settings = get_settings()

    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async with async_session() as session:
        # Permissive config: low score threshold, generous timeout.
        config = KbSearchDynamicConfig(
            enabled=True,
            top_k=10,
            timeout_ms=15000,
            min_score_threshold=0.3,
        )

        tool = KbSearchDynamicTool(session=session, config=config)

        course_kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

        # Restrict retrieval to the course KB only.
        step_kb_config = StepKbConfig(
            allowed_kb_ids=[course_kb_id],
            preferred_kb_ids=[course_kb_id],
            step_id="test_course_query",
        )

        test_params = {
            "query": "课程介绍",
            "tenant_id": "szmp@ash@2026",
            "top_k": 10,
            "context": {
                "grade": "五年级",
            },
            "step_kb_config": step_kb_config,
        }

        print(f"\n{'='*80}")
        print(f"测试: kb_search_dynamic - 课程知识库")
        print(f"{'='*80}")
        print(f"参数: query={test_params['query']}")
        print(f"      tenant_id={test_params['tenant_id']}")
        print(f"      context={test_params['context']}")
        print(f"      step_kb_config.allowed_kb_ids={step_kb_config.allowed_kb_ids}")
        print(f"超时设置: {config.timeout_ms}ms")
        print(f"最低分数阈值: {config.min_score_threshold}")

        try:
            result = await tool.execute(**test_params)

            print(f"\n结果:")
            print(f"  success: {result.success}")
            print(f"  hits count: {len(result.hits)}")
            print(f"  applied_filter: {result.applied_filter}")
            print(f"  fallback_reason_code: {result.fallback_reason_code}")
            print(f"  duration_ms: {result.duration_ms}")

            if result.filter_debug:
                print(f"  filter_debug: {result.filter_debug}")

            if result.step_kb_binding:
                print(f"  step_kb_binding: {result.step_kb_binding}")

            # Observability: dump the recorded tool trace if present.
            if result.tool_trace:
                print(f"\n  Tool Trace:")
                print(f"    tool_name: {result.tool_trace.tool_name}")
                print(f"    status: {result.tool_trace.status}")
                print(f"    duration_ms: {result.tool_trace.duration_ms}")
                print(f"    args_digest: {result.tool_trace.args_digest}")
                print(f"    result_digest: {result.tool_trace.result_digest}")
                if hasattr(result.tool_trace, 'arguments') and result.tool_trace.arguments:
                    print(f"    arguments: {result.tool_trace.arguments}")

            if result.hits:
                print(f"\n  检索结果 (共 {len(result.hits)} 条):")
                for i, hit in enumerate(result.hits, 1):
                    text = hit.get('text', '')
                    text_preview = text[:200] + '...' if len(text) > 200 else text
                    score = hit.get('score', 0)
                    metadata = hit.get('metadata', {})
                    collection = hit.get('collection', 'unknown')
                    kb_id = hit.get('kb_id', 'unknown')
                    print(f"\n  [{i}] score={score:.4f}")
                    print(f"      collection: {collection}")
                    print(f"      kb_id: {kb_id}")
                    print(f"      metadata: {metadata}")
                    print(f"      text: {text_preview}")
            else:
                # Troubleshooting hints for an empty result set.
                print(f"\n  ⚠️ 没有命中任何结果")
                print(f"  请检查:")
                print(f"    1. 知识库是否有数据")
                print(f"    2. 向量是否正确生成")
                print(f"    3. 过滤条件是否过于严格")

        except Exception as e:
            print(f"\n  ❌ 错误: {e}")
            import traceback
            traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(test_kb_search())
print(f" applied_filter: {result.applied_filter}") + print(f" fallback_reason_code: {result.fallback_reason_code}") + print(f" duration_ms: {result.duration_ms}") + + if result.hits: + print(f"\n 所有结果:") + for i, hit in enumerate(result.hits, 1): + text = hit.get('text', '')[:100] + '...' if hit.get('text') else 'N/A' + score = hit.get('score', 0) + metadata = hit.get('metadata', {}) + collection = hit.get('collection', 'unknown') + print(f" {i}. [score={score:.4f}] [collection={collection}]") + print(f" text: {text}") + print(f" metadata: {metadata}") + else: + print(f"\n ⚠️ 没有命中任何结果") + + except Exception as e: + print(f"\n ❌ 错误: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + asyncio.run(test_kb_search()) diff --git a/ai-service/scripts/test_qdrant_filter.py b/ai-service/scripts/test_qdrant_filter.py new file mode 100644 index 0000000..3d4578c --- /dev/null +++ b/ai-service/scripts/test_qdrant_filter.py @@ -0,0 +1,89 @@ +""" +直接测试 Qdrant 过滤功能 +""" + +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from qdrant_client.models import FieldCondition, Filter, MatchValue + +from app.core.config import get_settings +from app.core.qdrant_client import QdrantClient + + +async def test_qdrant_filter(): + """直接测试 Qdrant 过滤功能""" + settings = get_settings() + client = QdrantClient() + qdrant = await client.get_client() + + tenant_id = "szmp@ash@2026" + kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b" + + collection_name = client.get_kb_collection_name(tenant_id, kb_id) + + print(f"\n{'='*80}") + print(f"测试 Qdrant 过滤功能") + print(f"{'='*80}") + print(f"Collection: {collection_name}") + + # 测试 1: 无过滤 + print(f"\n--- 测试 1: 无过滤 ---") + results = await qdrant.scroll( + collection_name=collection_name, + limit=5, + with_vectors=False, + ) + print(f"无过滤结果数: {len(results[0])}") + for p in results[0][:3]: + print(f" grade: {p.payload.get('metadata', {}).get('grade')}") + + # 测试 2: 使用 
Filter 对象过滤 + print(f"\n--- 测试 2: 使用 Filter 对象过滤 (grade=五年级) ---") + qdrant_filter = Filter( + must=[ + FieldCondition( + key="metadata.grade", + match=MatchValue(value="五年级"), + ) + ] + ) + print(f"Filter: {qdrant_filter}") + + results = await qdrant.scroll( + collection_name=collection_name, + limit=10, + with_vectors=False, + scroll_filter=qdrant_filter, + ) + print(f"过滤后结果数: {len(results[0])}") + for p in results[0]: + print(f" grade: {p.payload.get('metadata', {}).get('grade')}, text: {p.payload.get('text', '')[:50]}...") + + # 测试 3: 使用 query_points 过滤 + print(f"\n--- 测试 3: 使用 query_points 过滤 ---") + # 先获取一个向量用于测试 + all_points = await qdrant.scroll( + collection_name=collection_name, + limit=1, + with_vectors=True, + ) + if all_points[0]: + query_vector = all_points[0][0].vector + + results = await qdrant.query_points( + collection_name=collection_name, + query=query_vector, + limit=10, + query_filter=qdrant_filter, + ) + print(f"query_points 过滤后结果数: {len(results.points)}") + for p in results.points: + print(f" grade: {p.payload.get('metadata', {}).get('grade')}, score: {p.score:.4f}") + + +if __name__ == "__main__": + asyncio.run(test_qdrant_filter()) diff --git a/ai-service/scripts/verify_qdrant_collections.py b/ai-service/scripts/verify_qdrant_collections.py new file mode 100644 index 0000000..caedfff --- /dev/null +++ b/ai-service/scripts/verify_qdrant_collections.py @@ -0,0 +1,193 @@ +""" +验证 Qdrant 向量数据库中的 collections 情况 +用于检查 szmp@ash@2026 租户下的知识库 collections +""" + +import asyncio +import sys +from pathlib import Path + +# 添加项目根目录到路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from qdrant_client import AsyncQdrantClient +from app.core.config import get_settings + + +async def list_collections(): + """列出所有 collections""" + settings = get_settings() + client = AsyncQdrantClient(url=settings.qdrant_url) + + print(f"🔗 Qdrant URL: {settings.qdrant_url}") + print(f"📦 Collection Prefix: {settings.qdrant_collection_prefix}") + print("-" * 60) + + 
async def list_collections():
    """List all Qdrant collections, split into szmp-tenant vs. others.

    Prints per-collection point counts and vector config where available,
    then a summary plus a sanity check that the szmp tenant owns exactly 2
    collections.
    """
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)

    print(f"🔗 Qdrant URL: {settings.qdrant_url}")
    print(f"📦 Collection Prefix: {settings.qdrant_collection_prefix}")
    print("-" * 60)

    try:
        collections = await client.get_collections()

        if not collections.collections:
            print("⚠️ 没有找到任何 collections")
            return

        print(f"✅ 找到 {len(collections.collections)} 个 collections:\n")

        # Partition by tenant: name containing "szmp" vs. everything else.
        szmp_collections = []
        other_collections = []

        for collection in collections.collections:
            name = collection.name
            if "szmp" in name.lower():
                szmp_collections.append(name)
            else:
                other_collections.append(name)

        if szmp_collections:
            print(f"🎯 szmp@ash@2026 租户相关的 collections ({len(szmp_collections)} 个):")
            print("-" * 60)
            for name in sorted(szmp_collections):
                try:
                    info = await client.get_collection(name)
                    # hasattr-guards: tolerate older client result shapes.
                    points_count = info.points_count if hasattr(info, 'points_count') else 'N/A'
                    print(f"  📁 {name}")
                    print(f"     └─ 向量数量: {points_count}")

                    if hasattr(info, 'config') and hasattr(info.config, 'params'):
                        params = info.config.params
                        if hasattr(params, 'vectors'):
                            vector_params = params.vectors
                            if hasattr(vector_params, 'size'):
                                print(f"     └─ 向量维度: {vector_params.size}")
                            if hasattr(vector_params, 'distance'):
                                print(f"     └─ 距离函数: {vector_params.distance}")
                    print()
                except Exception as e:
                    print(f"  📁 {name}")
                    print(f"     └─ 获取信息失败: {e}\n")
        else:
            print("⚠️ 没有找到 szmp@ash@2026 租户相关的 collections\n")

        if other_collections:
            print(f"📂 其他 collections ({len(other_collections)} 个):")
            print("-" * 60)
            for name in sorted(other_collections):
                try:
                    info = await client.get_collection(name)
                    points_count = info.points_count if hasattr(info, 'points_count') else 'N/A'
                    print(f"  📁 {name} (向量数: {points_count})")
                except Exception as e:
                    print(f"  📁 {name} (获取信息失败: {e})")

        print("\n" + "=" * 60)
        print("📊 总结:")
        print(f"  - Collections 总数: {len(collections.collections)}")
        print(f"  - szmp 相关: {len(szmp_collections)} 个")
        print(f"  - 其他: {len(other_collections)} 个")

        # Expectation check: the szmp tenant should have exactly 2 collections.
        print("\n✅ 验证:")
        if len(szmp_collections) == 2:
            print("  ✓ szmp 租户的 collection 数量符合预期 (2个)")
        else:
            print(f"  ⚠️ szmp 租户的 collection 数量不符合预期 (实际: {len(szmp_collections)} 个, 预期: 2个)")

    except Exception as e:
        print(f"❌ 连接 Qdrant 失败: {e}")
        print(f"   请检查 Qdrant 是否运行在 {settings.qdrant_url}")
    finally:
        await client.close()


async def check_collection_details(collection_name: str):
    """Print vector config, sharding info and up to 3 sample payloads for one collection."""
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)

    try:
        print(f"\n📋 Collection '{collection_name}' 详细信息:")
        print("-" * 60)

        info = await client.get_collection(collection_name)
        print(f"  名称: {collection_name}")
        print(f"  向量数量: {info.points_count}")

        if hasattr(info, 'config') and hasattr(info.config, 'params'):
            params = info.config.params

            if hasattr(params, 'vectors'):
                vector_params = params.vectors
                print(f"  向量配置:")
                if hasattr(vector_params, 'size'):
                    print(f"    - 维度: {vector_params.size}")
                if hasattr(vector_params, 'distance'):
                    print(f"    - 距离函数: {vector_params.distance}")
                if hasattr(vector_params, 'on_disk'):
                    print(f"    - 磁盘存储: {vector_params.on_disk}")

            if hasattr(params, 'shard_number'):
                print(f"  分片数: {params.shard_number}")
            if hasattr(params, 'replication_factor'):
                print(f"  副本数: {params.replication_factor}")

        # Sample up to 3 points to eyeball the stored payload shape.
        try:
            from qdrant_client.models import ScrollRequest  # NOTE(review): unused import

            scroll_result = await client.scroll(
                collection_name=collection_name,
                limit=3,
                with_payload=True,
                with_vectors=False
            )

            if scroll_result[0]:
                print(f"\n  样本数据 (前3条):")
                for i, point in enumerate(scroll_result[0], 1):
                    payload = point.payload or {}
                    text = payload.get('text', '')[:50] + '...' if payload.get('text') else 'N/A'
                    kb_id = payload.get('kb_id', 'N/A')
                    print(f"    {i}. ID: {point.id}")
                    print(f"       KB ID: {kb_id}")
                    print(f"       文本: {text}")
        except Exception as e:
            print(f"  获取样本数据失败: {e}")

    except Exception as e:
        print(f"❌ 获取 collection 信息失败: {e}")
    finally:
        await client.close()


async def main():
    """Entry point: overview listing, then details for each szmp collection."""
    print("=" * 60)
    print("🔍 Qdrant 向量数据库 Collections 验证工具")
    print("=" * 60)
    print()

    await list_collections()

    # Re-fetch the collection list to drive the per-collection detail pass.
    settings = get_settings()
    client = AsyncQdrantClient(url=settings.qdrant_url)

    try:
        collections = await client.get_collections()
        szmp_collections = [c.name for c in collections.collections if "szmp" in c.name.lower()]

        for name in sorted(szmp_collections):
            await check_collection_details(name)

    except Exception as e:
        print(f"❌ 错误: {e}")
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(main())
+[AC-MRS-SLOT-ASKBACK-01] 批量追问测试 +""" + +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.services.mid.batch_ask_back_service import ( + AskBackSlot, + BatchAskBackConfig, + BatchAskBackResult, + BatchAskBackService, + create_batch_ask_back_service, +) + + +class TestAskBackSlot: + """AskBackSlot 测试""" + + def test_init(self): + """测试初始化""" + slot = AskBackSlot( + slot_key="region", + label="地区", + ask_back_prompt="请告诉我您的地区", + priority=100, + is_required=True, + ) + assert slot.slot_key == "region" + assert slot.label == "地区" + assert slot.priority == 100 + assert slot.is_required is True + + +class TestBatchAskBackConfig: + """BatchAskBackConfig 测试""" + + def test_default_config(self): + """测试默认配置""" + config = BatchAskBackConfig() + assert config.max_ask_back_slots_per_turn == 2 + assert config.prefer_required is True + assert config.prefer_scene_relevant is True + assert config.avoid_recent_asked is True + assert config.recent_asked_threshold_seconds == 60.0 + assert config.merge_prompts is True + + def test_custom_config(self): + """测试自定义配置""" + config = BatchAskBackConfig( + max_ask_back_slots_per_turn=3, + prefer_required=False, + merge_prompts=False, + ) + assert config.max_ask_back_slots_per_turn == 3 + assert config.prefer_required is False + assert config.merge_prompts is False + + +class TestBatchAskBackResult: + """BatchAskBackResult 测试""" + + def test_has_ask_back(self): + """测试是否有追问""" + result = BatchAskBackResult(ask_back_count=2) + assert result.has_ask_back() is True + + result = BatchAskBackResult(ask_back_count=0) + assert result.has_ask_back() is False + + def test_get_prompt_with_merged(self): + """测试获取合并后的提示""" + result = BatchAskBackResult( + merged_prompt="请告诉我您的地区和产品", + prompts=["请告诉我您的地区", "请告诉我您的产品"], + ask_back_count=2, + ) + assert result.get_prompt() == "请告诉我您的地区和产品" + + def test_get_prompt_without_merged(self): + """测试获取未合并的提示""" + result = BatchAskBackResult( + prompts=["请告诉我您的地区"], + 
ask_back_count=1, + ) + assert result.get_prompt() == "请告诉我您的地区" + + def test_get_prompt_empty(self): + """测试空结果的提示""" + result = BatchAskBackResult() + assert "请提供更多信息" in result.get_prompt() + + +class TestBatchAskBackService: + """BatchAskBackService 测试""" + + @pytest.fixture + def mock_session(self): + """创建 mock session""" + return AsyncMock() + + @pytest.fixture + def config(self): + """创建配置""" + return BatchAskBackConfig( + max_ask_back_slots_per_turn=2, + prefer_required=True, + merge_prompts=True, + ) + + @pytest.fixture + def service(self, mock_session, config): + """创建服务实例""" + return BatchAskBackService( + session=mock_session, + tenant_id="tenant_1", + session_id="session_1", + config=config, + ) + + def test_calculate_priority_required(self, service): + """测试必填槽位优先级""" + priority = service._calculate_priority(is_required=True, scene_relevance=0.0) + assert priority == 100 + + def test_calculate_priority_scene_relevant(self, service): + """测试场景相关优先级""" + priority = service._calculate_priority(is_required=False, scene_relevance=1.0) + assert priority == 50 + + def test_calculate_priority_both(self, service): + """测试必填且场景相关优先级""" + priority = service._calculate_priority(is_required=True, scene_relevance=1.0) + assert priority == 150 + + def test_select_slots_for_ask_back(self, service): + """测试选择追问槽位""" + slots = [ + AskBackSlot(slot_key="a", label="A", priority=50), + AskBackSlot(slot_key="b", label="B", priority=100), + AskBackSlot(slot_key="c", label="C", priority=75), + ] + + selected = service._select_slots_for_ask_back(slots) + + assert len(selected) == 2 + assert selected[0].slot_key == "b" + assert selected[1].slot_key == "c" + + def test_filter_recently_asked(self, service): + """测试过滤最近追问过的槽位""" + current_time = time.time() + asked_history = { + "recently_asked": current_time - 30, + "old_asked": current_time - 120, + } + + slots = [ + AskBackSlot(slot_key="recently_asked", label="最近追问过"), + AskBackSlot(slot_key="old_asked", label="很久前追问过"), + 
AskBackSlot(slot_key="never_asked", label="从未追问过"), + ] + + filtered = service._filter_recently_asked(slots, asked_history) + + assert len(filtered) == 2 + slot_keys = [s.slot_key for s in filtered] + assert "old_asked" in slot_keys + assert "never_asked" in slot_keys + assert "recently_asked" not in slot_keys + + def test_generate_prompts(self, service): + """测试生成追问提示""" + slots = [ + AskBackSlot(slot_key="region", label="地区", ask_back_prompt="请告诉我您的地区"), + AskBackSlot(slot_key="product", label="产品", ask_back_prompt=None), + ] + + prompts = service._generate_prompts(slots) + + assert len(prompts) == 2 + assert prompts[0] == "请告诉我您的地区" + assert "产品" in prompts[1] + + def test_merge_prompts_single(self, service): + """测试合并单个提示""" + prompts = ["请告诉我您的地区"] + merged = service._merge_prompts(prompts) + assert merged == "请告诉我您的地区" + + def test_merge_prompts_two(self, service): + """测试合并两个提示""" + prompts = ["请告诉我您的地区", "请告诉我您的产品"] + merged = service._merge_prompts(prompts) + assert "地区" in merged + assert "产品" in merged + assert "以及" in merged + + def test_merge_prompts_multiple(self, service): + """测试合并多个提示""" + prompts = ["您的地区", "您的产品", "您的等级"] + merged = service._merge_prompts(prompts) + assert "地区" in merged + assert "产品" in merged + assert "等级" in merged + assert "、" in merged + assert "以及" in merged + + @pytest.mark.asyncio + async def test_generate_batch_ask_back_empty(self, service): + """测试空缺失槽位""" + result = await service.generate_batch_ask_back(missing_slots=[]) + assert result.has_ask_back() is False + + @pytest.mark.asyncio + async def test_generate_batch_ask_back_single(self, service): + """测试单个缺失槽位""" + missing_slots = [ + { + "slot_key": "region", + "label": "地区", + "ask_back_prompt": "请告诉我您的地区", + } + ] + + with patch.object(service._slot_def_service, 'get_slot_definition_by_key') as mock_get: + mock_get.return_value = MagicMock(required=True) + + result = await service.generate_batch_ask_back(missing_slots=missing_slots) + + assert result.has_ask_back() 
is True + assert result.ask_back_count == 1 + assert "地区" in result.get_prompt() + + @pytest.mark.asyncio + async def test_generate_batch_ask_back_multiple(self, service): + """测试多个缺失槽位""" + missing_slots = [ + {"slot_key": "region", "label": "地区", "ask_back_prompt": "您的地区"}, + {"slot_key": "product", "label": "产品", "ask_back_prompt": "您的产品"}, + {"slot_key": "grade", "label": "等级", "ask_back_prompt": "您的等级"}, + ] + + with patch.object(service._slot_def_service, 'get_slot_definition_by_key') as mock_get: + mock_get.return_value = MagicMock(required=True) + + result = await service.generate_batch_ask_back(missing_slots=missing_slots) + + assert result.has_ask_back() is True + assert result.ask_back_count == 2 + + @pytest.mark.asyncio + async def test_generate_batch_ask_back_prioritize_required(self, service): + """测试优先追问必填槽位""" + missing_slots = [ + {"slot_key": "optional", "label": "可选", "ask_back_prompt": "可选信息"}, + {"slot_key": "required", "label": "必填", "ask_back_prompt": "必填信息"}, + ] + + def mock_get_slot(tenant_id, slot_key): + if slot_key == "required": + return MagicMock(required=True) + return MagicMock(required=False) + + with patch.object(service._slot_def_service, 'get_slot_definition_by_key', side_effect=mock_get_slot): + result = await service.generate_batch_ask_back(missing_slots=missing_slots) + + assert result.has_ask_back() is True + assert result.selected_slots[0].slot_key == "required" + + +class TestCreateBatchAskBackService: + """create_batch_ask_back_service 工厂函数测试""" + + def test_create(self): + """测试创建服务实例""" + mock_session = AsyncMock() + config = BatchAskBackConfig(max_ask_back_slots_per_turn=3) + + service = create_batch_ask_back_service( + session=mock_session, + tenant_id="tenant_1", + session_id="session_1", + config=config, + ) + + assert isinstance(service, BatchAskBackService) + assert service._tenant_id == "tenant_1" + assert service._session_id == "session_1" + assert service._config.max_ask_back_slots_per_turn == 3 + + +class 
TestAskBackHistory: + """追问历史测试""" + + @pytest.fixture + def service(self): + """创建服务实例""" + mock_session = AsyncMock() + return BatchAskBackService( + session=mock_session, + tenant_id="tenant_1", + session_id="session_1", + ) + + @pytest.mark.asyncio + async def test_get_asked_history_empty(self, service): + """测试获取空历史""" + with patch.object(service._cache, '_get_client') as mock_client: + mock_redis = AsyncMock() + mock_redis.get = AsyncMock(return_value=None) + mock_client.return_value = mock_redis + + history = await service._get_asked_history() + assert history == {} + + @pytest.mark.asyncio + async def test_get_asked_history_with_data(self, service): + """测试获取有数据的历史""" + import json + + history_data = {"region": 12345.0, "product": 12346.0} + + with patch.object(service._cache, '_get_client') as mock_client: + mock_redis = AsyncMock() + mock_redis.get = AsyncMock(return_value=json.dumps(history_data)) + mock_client.return_value = mock_redis + + history = await service._get_asked_history() + assert history == history_data diff --git a/ai-service/tests/test_clarification.py b/ai-service/tests/test_clarification.py new file mode 100644 index 0000000..3164db4 --- /dev/null +++ b/ai-service/tests/test_clarification.py @@ -0,0 +1,543 @@ +""" +Tests for clarification mechanism. 
+[AC-CLARIFY] 澄清机制测试 +""" + +import pytest +from unittest.mock import MagicMock, patch + +from app.services.intent.clarification import ( + ClarificationEngine, + ClarifyMetrics, + ClarifyReason, + ClarifySessionManager, + ClarifyState, + HybridIntentResult, + IntentCandidate, + T_HIGH, + T_LOW, + MAX_CLARIFY_RETRY, + get_clarify_metrics, +) + + +class TestClarifyMetrics: + def test_singleton_pattern(self): + m1 = ClarifyMetrics() + m2 = ClarifyMetrics() + assert m1 is m2 + + def test_record_clarify_trigger(self): + metrics = ClarifyMetrics() + metrics.reset() + + metrics.record_clarify_trigger() + metrics.record_clarify_trigger() + metrics.record_clarify_trigger() + + counts = metrics.get_metrics() + assert counts["clarify_trigger_rate"] == 3 + + def test_record_clarify_converge(self): + metrics = ClarifyMetrics() + metrics.reset() + + metrics.record_clarify_converge() + metrics.record_clarify_converge() + + counts = metrics.get_metrics() + assert counts["clarify_converge_rate"] == 2 + + def test_record_misroute(self): + metrics = ClarifyMetrics() + metrics.reset() + + metrics.record_misroute() + + counts = metrics.get_metrics() + assert counts["misroute_rate"] == 1 + + def test_get_rates(self): + metrics = ClarifyMetrics() + metrics.reset() + + metrics.record_clarify_trigger() + metrics.record_clarify_converge() + metrics.record_misroute() + + rates = metrics.get_rates(100) + assert rates["clarify_trigger_rate"] == 0.01 + assert rates["clarify_converge_rate"] == 1.0 + assert rates["misroute_rate"] == 0.01 + + def test_get_rates_zero_requests(self): + metrics = ClarifyMetrics() + metrics.reset() + + rates = metrics.get_rates(0) + assert rates["clarify_trigger_rate"] == 0.0 + assert rates["clarify_converge_rate"] == 0.0 + assert rates["misroute_rate"] == 0.0 + + def test_reset(self): + metrics = ClarifyMetrics() + metrics.record_clarify_trigger() + metrics.record_clarify_converge() + metrics.record_misroute() + + metrics.reset() + + counts = metrics.get_metrics() + 
assert counts["clarify_trigger_rate"] == 0 + assert counts["clarify_converge_rate"] == 0 + assert counts["misroute_rate"] == 0 + + +class TestIntentCandidate: + def test_to_dict(self): + candidate = IntentCandidate( + intent_id="intent-1", + intent_name="退货意图", + confidence=0.85, + response_type="flow", + target_kb_ids=["kb-1"], + flow_id="flow-1", + fixed_reply=None, + transfer_message=None, + ) + + result = candidate.to_dict() + + assert result["intent_id"] == "intent-1" + assert result["intent_name"] == "退货意图" + assert result["confidence"] == 0.85 + assert result["response_type"] == "flow" + assert result["target_kb_ids"] == ["kb-1"] + assert result["flow_id"] == "flow-1" + + +class TestHybridIntentResult: + def test_to_dict(self): + candidate = IntentCandidate( + intent_id="intent-1", + intent_name="退货意图", + confidence=0.85, + ) + + result = HybridIntentResult( + intent=candidate, + confidence=0.85, + candidates=[candidate], + need_clarify=False, + clarify_reason=None, + missing_slots=[], + ) + + d = result.to_dict() + + assert d["intent"]["intent_id"] == "intent-1" + assert d["confidence"] == 0.85 + assert len(d["candidates"]) == 1 + assert d["need_clarify"] is False + + def test_from_fusion_result(self): + mock_fusion = MagicMock() + mock_fusion.final_intent = MagicMock() + mock_fusion.final_intent.id = "intent-1" + mock_fusion.final_intent.name = "退货意图" + mock_fusion.final_intent.response_type = "flow" + mock_fusion.final_intent.target_kb_ids = ["kb-1"] + mock_fusion.final_intent.flow_id = None + mock_fusion.final_intent.fixed_reply = None + mock_fusion.final_intent.transfer_message = None + mock_fusion.final_confidence = 0.85 + mock_fusion.need_clarify = False + mock_fusion.decision_reason = "rule_high_confidence" + mock_fusion.clarify_candidates = [] + + result = HybridIntentResult.from_fusion_result(mock_fusion) + + assert result.intent is not None + assert result.intent.intent_id == "intent-1" + assert result.confidence == 0.85 + assert 
result.need_clarify is False + + def test_from_fusion_result_with_clarify(self): + mock_fusion = MagicMock() + mock_fusion.final_intent = None + mock_fusion.final_confidence = 0.5 + mock_fusion.need_clarify = True + mock_fusion.decision_reason = "multi_intent" + + candidate1 = MagicMock() + candidate1.id = "intent-1" + candidate1.name = "退货意图" + candidate1.response_type = "flow" + candidate1.target_kb_ids = None + candidate1.flow_id = None + candidate1.fixed_reply = None + candidate1.transfer_message = None + + candidate2 = MagicMock() + candidate2.id = "intent-2" + candidate2.name = "换货意图" + candidate2.response_type = "flow" + candidate2.target_kb_ids = None + candidate2.flow_id = None + candidate2.fixed_reply = None + candidate2.transfer_message = None + + mock_fusion.clarify_candidates = [candidate1, candidate2] + + result = HybridIntentResult.from_fusion_result(mock_fusion) + + assert result.need_clarify is True + assert result.clarify_reason == ClarifyReason.MULTI_INTENT + assert len(result.candidates) == 2 + + +class TestClarifyState: + def test_to_dict(self): + candidate = IntentCandidate( + intent_id="intent-1", + intent_name="退货意图", + confidence=0.5, + ) + + state = ClarifyState( + reason=ClarifyReason.INTENT_AMBIGUITY, + asked_slot=None, + retry_count=1, + candidates=[candidate], + asked_intent_ids=["intent-1"], + ) + + d = state.to_dict() + + assert d["reason"] == "intent_ambiguity" + assert d["retry_count"] == 1 + assert len(d["candidates"]) == 1 + + def test_increment_retry(self): + state = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE) + + state.increment_retry() + + assert state.retry_count == 1 + + state.increment_retry() + + assert state.retry_count == 2 + + def test_is_max_retry(self): + state = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE) + + assert not state.is_max_retry() + + state.retry_count = MAX_CLARIFY_RETRY + + assert state.is_max_retry() + + +class TestClarificationEngine: + def test_compute_confidence_rule_only(self): + engine = 
ClarificationEngine() + + confidence = engine.compute_confidence( + rule_score=1.0, + semantic_score=0.0, + llm_score=0.0, + w_rule=1.0, + w_semantic=0.0, + w_llm=0.0, + ) + + assert confidence == 1.0 + + def test_compute_confidence_semantic_only(self): + engine = ClarificationEngine() + + confidence = engine.compute_confidence( + rule_score=0.0, + semantic_score=0.8, + llm_score=0.0, + w_rule=0.3, + w_semantic=0.5, + w_llm=0.2, + ) + + # With weights w_rule=0.3, w_semantic=0.5, w_llm=0.2 and scores + # rule=0.0, semantic=0.8, llm=0.0: + # confidence = (0.0*0.3 + 0.8*0.5 + 0.0*0.2) / (0.3+0.5+0.2) = 0.4/1.0 = 0.4 + assert confidence == 0.4 + + def test_compute_confidence_weighted(self): + engine = ClarificationEngine() + + confidence = engine.compute_confidence( + rule_score=1.0, + semantic_score=0.8, + llm_score=0.9, + w_rule=0.5, + w_semantic=0.3, + w_llm=0.2, + ) + + expected = (1.0 * 0.5 + 0.8 * 0.3 + 0.9 * 0.2) / 1.0 + assert abs(confidence - expected) < 0.001 + + def test_check_hard_block_low_confidence(self): + engine = ClarificationEngine() + + result = HybridIntentResult( + intent=None, + confidence=0.5, + candidates=[], + ) + + is_blocked, reason = engine.check_hard_block(result) + + assert is_blocked is True + assert reason == ClarifyReason.LOW_CONFIDENCE + + def test_check_hard_block_high_confidence(self): + engine = ClarificationEngine() + + result = HybridIntentResult( + intent=IntentCandidate( + intent_id="intent-1", + intent_name="退货意图", + confidence=0.85, + ), + confidence=0.85, + candidates=[], + ) + + is_blocked, reason = engine.check_hard_block(result) + + assert is_blocked is False + assert reason is None + + def test_check_hard_block_missing_slots(self): + engine = ClarificationEngine() + + result = HybridIntentResult( + intent=IntentCandidate( + intent_id="intent-1", + intent_name="退货意图", + confidence=0.85, + ), + confidence=0.85, + candidates=[], + ) + + is_blocked, reason = engine.check_hard_block( + result, + required_slots=["order_id", 
"product_id"], + filled_slots={"order_id": "123"}, + ) + + assert is_blocked is True + assert reason == ClarifyReason.MISSING_SLOT + + def test_should_trigger_clarify_below_t_low(self): + engine = ClarificationEngine() + get_clarify_metrics().reset() + + result = HybridIntentResult( + intent=None, + confidence=0.3, + candidates=[], + ) + + should_clarify, state = engine.should_trigger_clarify(result) + + assert should_clarify is True + assert state is not None + assert state.reason == ClarifyReason.LOW_CONFIDENCE + + def test_should_trigger_clarify_gray_zone(self): + engine = ClarificationEngine() + get_clarify_metrics().reset() + + candidate = IntentCandidate( + intent_id="intent-1", + intent_name="退货意图", + confidence=0.5, + ) + + result = HybridIntentResult( + intent=candidate, + confidence=0.5, + candidates=[candidate], + need_clarify=True, + clarify_reason=ClarifyReason.INTENT_AMBIGUITY, + ) + + should_clarify, state = engine.should_trigger_clarify(result) + + assert should_clarify is True + assert state is not None + assert state.reason == ClarifyReason.INTENT_AMBIGUITY + + def test_should_trigger_clarify_above_t_high(self): + engine = ClarificationEngine() + get_clarify_metrics().reset() + + candidate = IntentCandidate( + intent_id="intent-1", + intent_name="退货意图", + confidence=0.85, + ) + + result = HybridIntentResult( + intent=candidate, + confidence=0.85, + candidates=[candidate], + ) + + should_clarify, state = engine.should_trigger_clarify(result) + + assert should_clarify is False + assert state is None + + def test_generate_clarify_prompt_missing_slot(self): + engine = ClarificationEngine() + + state = ClarifyState( + reason=ClarifyReason.MISSING_SLOT, + asked_slot="order_id", + ) + + prompt = engine.generate_clarify_prompt(state) + + assert "order_id" in prompt or "相关信息" in prompt + + def test_generate_clarify_prompt_low_confidence(self): + engine = ClarificationEngine() + + state = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE) + + prompt = 
engine.generate_clarify_prompt(state) + + assert "理解" in prompt or "详细" in prompt + + def test_generate_clarify_prompt_multi_intent(self): + engine = ClarificationEngine() + + candidates = [ + IntentCandidate(intent_id="1", intent_name="退货", confidence=0.5), + IntentCandidate(intent_id="2", intent_name="换货", confidence=0.4), + ] + + state = ClarifyState( + reason=ClarifyReason.MULTI_INTENT, + candidates=candidates, + ) + + prompt = engine.generate_clarify_prompt(state) + + assert "退货" in prompt + assert "换货" in prompt + + def test_process_clarify_response_max_retry(self): + engine = ClarificationEngine() + get_clarify_metrics().reset() + + state = ClarifyState( + reason=ClarifyReason.LOW_CONFIDENCE, + retry_count=MAX_CLARIFY_RETRY, + ) + + result = engine.process_clarify_response("用户回复", state) + + assert result.intent is None + assert result.confidence == 0.0 + assert result.need_clarify is False + + def test_process_clarify_response_missing_slot(self): + engine = ClarificationEngine() + get_clarify_metrics().reset() + + candidate = IntentCandidate( + intent_id="intent-1", + intent_name="退货意图", + confidence=0.8, + ) + + state = ClarifyState( + reason=ClarifyReason.MISSING_SLOT, + asked_slot="order_id", + candidates=[candidate], + ) + + result = engine.process_clarify_response("订单号是123", state) + + assert result.intent is not None + assert result.need_clarify is False + + def test_get_metrics(self): + engine = ClarificationEngine() + get_clarify_metrics().reset() + + engine._metrics.record_clarify_trigger() + engine._metrics.record_clarify_converge() + + metrics = engine.get_metrics() + + assert metrics["clarify_trigger_rate"] == 1 + assert metrics["clarify_converge_rate"] == 1 + + +class TestClarifySessionManager: + def test_set_and_get_session(self): + ClarifySessionManager.clear_session("test-session") + + state = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE) + + ClarifySessionManager.set_session("test-session", state) + + retrieved = 
ClarifySessionManager.get_session("test-session") + + assert retrieved is not None + assert retrieved.reason == ClarifyReason.LOW_CONFIDENCE + + def test_clear_session(self): + ClarifySessionManager.set_session( + "test-session", + ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE), + ) + + ClarifySessionManager.clear_session("test-session") + + retrieved = ClarifySessionManager.get_session("test-session") + + assert retrieved is None + + def test_has_active_clarify(self): + ClarifySessionManager.clear_session("test-session") + + assert not ClarifySessionManager.has_active_clarify("test-session") + + state = ClarifyState(reason=ClarifyReason.LOW_CONFIDENCE) + ClarifySessionManager.set_session("test-session", state) + + assert ClarifySessionManager.has_active_clarify("test-session") + + state.retry_count = MAX_CLARIFY_RETRY + + assert not ClarifySessionManager.has_active_clarify("test-session") + + +class TestThresholds: + def test_t_high_value(self): + assert T_HIGH == 0.75 + + def test_t_low_value(self): + assert T_LOW == 0.45 + + def test_t_high_greater_than_t_low(self): + assert T_HIGH > T_LOW + + def test_max_retry_value(self): + assert MAX_CLARIFY_RETRY == 3 diff --git a/ai-service/tests/test_dialogue_slot_integration.py b/ai-service/tests/test_dialogue_slot_integration.py new file mode 100644 index 0000000..559ce47 --- /dev/null +++ b/ai-service/tests/test_dialogue_slot_integration.py @@ -0,0 +1,308 @@ +""" +Tests for Dialogue API with Slot State Integration. 
+[AC-MRS-SLOT-META-03] 对话 API 与槽位状态集成测试 +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.api.mid.dialogue import _generate_ask_back_for_missing_slots +from app.models.mid.schemas import ExecutionMode, Segment, TraceInfo +from app.services.mid.slot_state_aggregator import SlotState + + +class TestGenerateAskBackResponse: + """测试生成追问响应""" + + @pytest.mark.asyncio + async def test_generate_ask_back_with_prompt(self): + """测试使用配置的 ask_back_prompt 生成追问""" + slot_state = SlotState() + missing_slots = [ + { + "slot_key": "region", + "label": "地区", + "ask_back_prompt": "请问您在哪个地区?", + } + ] + mock_session = AsyncMock() + + response = await _generate_ask_back_for_missing_slots( + slot_state=slot_state, + missing_slots=missing_slots, + session=mock_session, + tenant_id="test_tenant", + ) + + assert response == "请问您在哪个地区?" + + @pytest.mark.asyncio + async def test_generate_ask_back_generic(self): + """测试使用通用模板生成追问""" + slot_state = SlotState() + missing_slots = [ + { + "slot_key": "product_line", + "label": "产品线", + # 没有 ask_back_prompt + } + ] + mock_session = AsyncMock() + + response = await _generate_ask_back_for_missing_slots( + slot_state=slot_state, + missing_slots=missing_slots, + session=mock_session, + tenant_id="test_tenant", + ) + + assert "产品线" in response + + @pytest.mark.asyncio + async def test_generate_ask_back_empty_slots(self): + """测试空缺失槽位列表""" + slot_state = SlotState() + missing_slots = [] + mock_session = AsyncMock() + + response = await _generate_ask_back_for_missing_slots( + slot_state=slot_state, + missing_slots=missing_slots, + session=mock_session, + tenant_id="test_tenant", + ) + + assert "更多信息" in response + + +class TestDialogueAskBackResponse: + """测试对话追问响应""" + + def test_dialogue_response_with_ask_back(self): + """测试追问响应的结构""" + from app.models.mid.schemas import DialogueResponse + + response = DialogueResponse( + segments=[Segment(text="请问您咨询的是哪个产品线?", delay_after=0)], + trace=TraceInfo( + 
mode=ExecutionMode.AGENT, + request_id="test_request_id", + generation_id="test_generation_id", + fallback_reason_code="missing_required_slots", + kb_tool_called=True, + kb_hit=False, + ), + ) + + assert len(response.segments) == 1 + assert "哪个产品线" in response.segments[0].text + assert response.trace.fallback_reason_code == "missing_required_slots" + assert response.trace.kb_tool_called is True + assert response.trace.kb_hit is False + + +class TestSlotStateAggregationFlow: + """测试槽位状态聚合流程""" + + @pytest.mark.asyncio + async def test_memory_slots_included_in_state(self): + """测试 memory_recall 的槽位被包含在状态中""" + from app.models.mid.schemas import MemorySlot, SlotSource + from app.services.mid.slot_state_aggregator import SlotStateAggregator + + mock_session = AsyncMock() + aggregator = SlotStateAggregator( + session=mock_session, + tenant_id="test_tenant", + ) + + memory_slots = { + "product_line": MemorySlot( + key="product_line", + value="vip_course", + source=SlotSource.USER_CONFIRMED, + confidence=1.0, + ) + } + + with patch.object( + aggregator._slot_def_service, + "list_slot_definitions", + return_value=[], + ): + state = await aggregator.aggregate( + memory_slots=memory_slots, + current_input_slots=None, + context=None, + ) + + assert "product_line" in state.filled_slots + assert state.filled_slots["product_line"] == "vip_course" + assert state.slot_sources["product_line"] == "user_confirmed" + + @pytest.mark.asyncio + async def test_missing_slots_identified(self): + """测试缺失的必填槽位被正确识别""" + from unittest.mock import MagicMock + from app.services.mid.slot_state_aggregator import SlotStateAggregator + + mock_session = AsyncMock() + aggregator = SlotStateAggregator( + session=mock_session, + tenant_id="test_tenant", + ) + + # 模拟一个 required 的槽位定义 + mock_slot_def = MagicMock() + mock_slot_def.slot_key = "region" + mock_slot_def.required = True + mock_slot_def.ask_back_prompt = "请问您在哪个地区?" 
+ mock_slot_def.linked_field_id = None + + with patch.object( + aggregator._slot_def_service, + "list_slot_definitions", + return_value=[mock_slot_def], + ): + state = await aggregator.aggregate( + memory_slots={}, + current_input_slots=None, + context=None, + ) + + assert len(state.missing_required_slots) == 1 + assert state.missing_required_slots[0]["slot_key"] == "region" + assert state.missing_required_slots[0]["ask_back_prompt"] == "请问您在哪个地区?" + + +class TestSlotMetadataLinkage: + """测试槽位与元数据关联""" + + @pytest.mark.asyncio + async def test_slot_to_field_mapping(self): + """测试槽位到元数据字段的映射""" + from unittest.mock import MagicMock, patch + from app.services.mid.slot_state_aggregator import SlotStateAggregator + from app.services.metadata_field_definition_service import MetadataFieldDefinitionService + + mock_session = AsyncMock() + aggregator = SlotStateAggregator( + session=mock_session, + tenant_id="test_tenant", + ) + + # 模拟槽位定义(带 linked_field_id) + mock_slot_def = MagicMock() + mock_slot_def.slot_key = "product" + mock_slot_def.linked_field_id = "field-uuid-123" + mock_slot_def.required = False + mock_slot_def.type = "string" + mock_slot_def.options = None + + # 模拟关联的元数据字段 + mock_field = MagicMock() + mock_field.field_key = "product_line" + mock_field.label = "产品线" + mock_field.type = "string" + mock_field.required = False + mock_field.options = None + + with patch.object( + aggregator._slot_def_service, + "list_slot_definitions", + return_value=[mock_slot_def], + ): + with patch.object( + MetadataFieldDefinitionService, + "get_field_definition", + return_value=mock_field, + ): + state = await aggregator.aggregate( + memory_slots={}, + current_input_slots=None, + context=None, + ) + + # 验证映射已建立 + assert state.slot_to_field_map.get("product") == "product_line" + + +class TestBackwardCompatibility: + """测试向后兼容性""" + + @pytest.mark.asyncio + async def test_kb_search_without_slot_state(self): + """测试不使用 slot_state 时 KB 检索仍然工作""" + from 
app.services.mid.kb_search_dynamic_tool import ( + KbSearchDynamicConfig, + KbSearchDynamicTool, + ) + from app.services.mid.metadata_filter_builder import MetadataFilterBuilder + + mock_session = AsyncMock() + kb_tool = KbSearchDynamicTool( + session=mock_session, + config=KbSearchDynamicConfig(enabled=True), + ) + + # 模拟 filter_builder 返回空结果 + with patch.object( + MetadataFilterBuilder, + "_get_filterable_fields", + return_value=[], + ): + with patch.object( + kb_tool, + "_retrieve_with_timeout", + return_value=[], + ): + result = await kb_tool.execute( + query="退款政策", + tenant_id="test_tenant", + context={}, + slot_state=None, # 不提供 slot_state + ) + + # 应该成功执行 + assert result.success is True + assert result.fallback_reason_code is None + + @pytest.mark.asyncio + async def test_legacy_context_filter(self): + """测试使用传统 context 构建过滤器""" + from app.services.mid.kb_search_dynamic_tool import ( + KbSearchDynamicConfig, + KbSearchDynamicTool, + ) + from app.services.mid.metadata_filter_builder import MetadataFilterBuilder + + mock_session = AsyncMock() + kb_tool = KbSearchDynamicTool( + session=mock_session, + config=KbSearchDynamicConfig(enabled=True), + ) + + # 使用简单 context + context = {"product_line": "vip_course", "region": "beijing"} + + with patch.object( + MetadataFilterBuilder, + "_get_filterable_fields", + return_value=[], + ): + with patch.object( + kb_tool, + "_retrieve_with_timeout", + return_value=[], + ): + result = await kb_tool.execute( + query="退款政策", + tenant_id="test_tenant", + context=context, + slot_state=None, + ) + + # 应该成功执行 + assert result.success is True + # 简单 context 应该直接使用作为 filter + assert result.applied_filter.get("product_line") == "vip_course" diff --git a/ai-service/tests/test_field_roles_update.py b/ai-service/tests/test_field_roles_update.py new file mode 100644 index 0000000..82b5276 --- /dev/null +++ b/ai-service/tests/test_field_roles_update.py @@ -0,0 +1,149 @@ +""" +Tests for field_roles update functionality. 
+[AC-MRS-01] 验证字段角色更新功能 +""" + +import pytest +import uuid +from unittest.mock import AsyncMock, MagicMock + +from app.models.entities import ( + MetadataFieldDefinition, + MetadataFieldDefinitionUpdate, + MetadataFieldStatus, +) +from app.services.metadata_field_definition_service import MetadataFieldDefinitionService + + +class TestFieldRolesUpdate: + """测试字段角色更新功能""" + + @pytest.fixture + def mock_session(self): + """Create mock session""" + session = MagicMock() + session.execute = AsyncMock() + session.flush = AsyncMock() + session.commit = AsyncMock() + return session + + @pytest.fixture + def service(self, mock_session): + """Create service instance""" + return MetadataFieldDefinitionService(mock_session) + + @pytest.fixture + def existing_field(self): + """Create existing field with field_roles""" + field = MagicMock(spec=MetadataFieldDefinition) + field.id = uuid.uuid4() + field.tenant_id = "test-tenant" + field.field_key = "grade" + field.label = "年级" + field.type = "string" + field.required = True + field.options = None + field.default_value = None + field.scope = ["kb_document"] + field.is_filterable = True + field.is_rank_feature = False + field.field_roles = ["slot"] # 初始角色 + field.status = MetadataFieldStatus.ACTIVE.value + field.version = 1 + return field + + @pytest.mark.asyncio + async def test_update_field_roles_success(self, service, mock_session, existing_field): + """[AC-MRS-01] 测试成功更新字段角色""" + # Mock get_field_definition to return existing field + service.get_field_definition = AsyncMock(return_value=existing_field) + + # Create update request with new field_roles + field_update = MetadataFieldDefinitionUpdate( + field_roles=["slot", "resource_filter"] + ) + + # Execute update + result = await service.update_field_definition( + "test-tenant", + str(existing_field.id), + field_update + ) + + # Verify result + assert result is not None + assert result.field_roles == ["slot", "resource_filter"] + assert result.version == 2 # Version should 
increment + + @pytest.mark.asyncio + async def test_update_field_roles_to_empty(self, service, mock_session, existing_field): + """[AC-MRS-01] 测试将字段角色更新为空列表""" + service.get_field_definition = AsyncMock(return_value=existing_field) + + field_update = MetadataFieldDefinitionUpdate( + field_roles=[] + ) + + result = await service.update_field_definition( + "test-tenant", + str(existing_field.id), + field_update + ) + + assert result is not None + assert result.field_roles == [] + + @pytest.mark.asyncio + async def test_update_field_roles_invalid_role(self, service, mock_session, existing_field): + """[AC-MRS-01] 测试更新无效的字段角色""" + service.get_field_definition = AsyncMock(return_value=existing_field) + + field_update = MetadataFieldDefinitionUpdate( + field_roles=["invalid_role"] + ) + + with pytest.raises(ValueError) as exc_info: + await service.update_field_definition( + "test-tenant", + str(existing_field.id), + field_update + ) + + assert "无效的字段角色" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_update_without_field_roles_unchanged(self, service, mock_session, existing_field): + """[AC-MRS-01] 测试不更新 field_roles 时保持原值""" + service.get_field_definition = AsyncMock(return_value=existing_field) + + # Update only label, not field_roles + field_update = MetadataFieldDefinitionUpdate( + label="新年级标签" + ) + + result = await service.update_field_definition( + "test-tenant", + str(existing_field.id), + field_update + ) + + assert result is not None + assert result.label == "新年级标签" + assert result.field_roles == ["slot"] # Should remain unchanged + + @pytest.mark.asyncio + async def test_update_field_roles_not_found(self, service, mock_session): + """[AC-MRS-01] 测试更新不存在的字段""" + service.get_field_definition = AsyncMock(return_value=None) + + field_update = MetadataFieldDefinitionUpdate( + field_roles=["slot"] + ) + + result = await service.update_field_definition( + "test-tenant", + str(uuid.uuid4()), + field_update + ) + + assert result is None diff --git 
a/ai-service/tests/test_fusion_policy.py b/ai-service/tests/test_fusion_policy.py new file mode 100644 index 0000000..17187c2 --- /dev/null +++ b/ai-service/tests/test_fusion_policy.py @@ -0,0 +1,408 @@ +""" +Unit tests for FusionPolicy. +[AC-AISVC-115~AC-AISVC-117] Tests for fusion decision policy. +""" + +import pytest +from unittest.mock import MagicMock +import uuid + +from app.services.intent.models import ( + FusionConfig, + FusionResult, + LlmJudgeResult, + RuleMatchResult, + SemanticCandidate, + SemanticMatchResult, + RouteTrace, +) + + +class FusionPolicy: + """[AC-AISVC-115] Fusion decision policy.""" + + DECISION_PRIORITY = [ + ("rule_high_confidence", lambda r, s, l: r.score == 1.0 and r.rule is not None), + ("llm_judge", lambda r, s, l: l.triggered and l.intent_id is not None), + ("semantic_override", lambda r, s, l: r.score == 0 and s.top_score > 0.7), + ("rule_semantic_agree", lambda r, s, l: r.score > 0 and s.top_score > 0.5 and r.rule_id == s.candidates[0].rule.id if s.candidates else False), + ("semantic_fallback", lambda r, s, l: s.top_score > 0.5), + ("rule_fallback", lambda r, s, l: r.score > 0), + ("no_match", lambda r, s, l: True), + ] + + def __init__(self, config: FusionConfig): + self._config = config + + def fuse( + self, + rule_result: RuleMatchResult, + semantic_result: SemanticMatchResult, + llm_result: LlmJudgeResult | None, + ) -> FusionResult: + trace = RouteTrace( + rule_match={ + "rule_id": str(rule_result.rule_id) if rule_result.rule_id else None, + "match_type": rule_result.match_type, + "matched_text": rule_result.matched_text, + "score": rule_result.score, + "duration_ms": rule_result.duration_ms, + }, + semantic_match={ + "top_candidates": [ + {"rule_id": str(c.rule.id), "name": c.rule.name, "score": c.score} + for c in semantic_result.candidates + ], + "top_score": semantic_result.top_score, + "duration_ms": semantic_result.duration_ms, + "skipped": semantic_result.skipped, + "skip_reason": semantic_result.skip_reason, + }, 
+ llm_judge={ + "triggered": llm_result.triggered if llm_result else False, + "intent_id": llm_result.intent_id if llm_result else None, + "score": llm_result.score if llm_result else 0.0, + "duration_ms": llm_result.duration_ms if llm_result else 0, + "tokens_used": llm_result.tokens_used if llm_result else 0, + }, + fusion={}, + ) + + final_intent = None + final_confidence = 0.0 + decision_reason = "no_match" + + for reason, condition in self.DECISION_PRIORITY: + if condition(rule_result, semantic_result, llm_result or LlmJudgeResult.empty()): + decision_reason = reason + break + + if decision_reason == "rule_high_confidence": + final_intent = rule_result.rule + final_confidence = 1.0 + elif decision_reason == "llm_judge" and llm_result: + final_intent = self._find_rule_by_id(llm_result.intent_id, rule_result, semantic_result) + final_confidence = llm_result.score + elif decision_reason == "semantic_override": + final_intent = semantic_result.candidates[0].rule + final_confidence = semantic_result.top_score + elif decision_reason == "rule_semantic_agree": + final_intent = rule_result.rule + final_confidence = self._calculate_weighted_confidence(rule_result, semantic_result, llm_result) + elif decision_reason == "semantic_fallback": + final_intent = semantic_result.candidates[0].rule + final_confidence = semantic_result.top_score + elif decision_reason == "rule_fallback": + final_intent = rule_result.rule + final_confidence = rule_result.score + + need_clarify = final_confidence < self._config.clarify_threshold + clarify_candidates = None + if need_clarify and len(semantic_result.candidates) > 1: + clarify_candidates = [c.rule for c in semantic_result.candidates[:3]] + + trace.fusion = { + "weights": { + "w_rule": self._config.w_rule, + "w_semantic": self._config.w_semantic, + "w_llm": self._config.w_llm, + }, + "final_confidence": final_confidence, + "decision_reason": decision_reason, + } + + return FusionResult( + final_intent=final_intent, + 
final_confidence=final_confidence, + decision_reason=decision_reason, + need_clarify=need_clarify, + clarify_candidates=clarify_candidates, + trace=trace, + ) + + def _calculate_weighted_confidence( + self, + rule_result: RuleMatchResult, + semantic_result: SemanticMatchResult, + llm_result: LlmJudgeResult | None, + ) -> float: + rule_score = rule_result.score + semantic_score = semantic_result.top_score if not semantic_result.skipped else 0.0 + llm_score = llm_result.score if llm_result and llm_result.triggered else 0.0 + + total_weight = self._config.w_rule + self._config.w_semantic + if llm_result and llm_result.triggered: + total_weight += self._config.w_llm + + confidence = ( + self._config.w_rule * rule_score + + self._config.w_semantic * semantic_score + + self._config.w_llm * llm_score + ) / total_weight + + return min(1.0, max(0.0, confidence)) + + def _find_rule_by_id( + self, + intent_id: str | None, + rule_result: RuleMatchResult, + semantic_result: SemanticMatchResult, + ): + if not intent_id: + return None + + if rule_result.rule_id and str(rule_result.rule_id) == intent_id: + return rule_result.rule + + for candidate in semantic_result.candidates: + if str(candidate.rule.id) == intent_id: + return candidate.rule + + return None + + +@pytest.fixture +def config(): + return FusionConfig() + + +@pytest.fixture +def mock_rule(): + rule = MagicMock() + rule.id = uuid.uuid4() + rule.name = "Test Intent" + rule.response_type = "rag" + return rule + + +class TestFusionPolicy: + """Tests for FusionPolicy class.""" + + def test_init(self, config): + """Test FusionPolicy initialization.""" + policy = FusionPolicy(config) + assert policy._config == config + + def test_fuse_rule_high_confidence(self, config, mock_rule): + """Test fusion with rule high confidence.""" + policy = FusionPolicy(config) + + rule_result = RuleMatchResult( + rule_id=mock_rule.id, + rule=mock_rule, + match_type="keyword", + matched_text="test", + score=1.0, + duration_ms=10, + ) + 
semantic_result = SemanticMatchResult( + candidates=[], + top_score=0.0, + duration_ms=50, + skipped=True, + skip_reason="no_semantic_config", + ) + + result = policy.fuse(rule_result, semantic_result, None) + + assert result.decision_reason == "rule_high_confidence" + assert result.final_intent == mock_rule + assert result.final_confidence == 1.0 + assert result.need_clarify is False + + def test_fuse_llm_judge(self, config, mock_rule): + """Test fusion with LLM judge result.""" + policy = FusionPolicy(config) + + rule_result = RuleMatchResult( + rule_id=None, + rule=None, + match_type=None, + matched_text=None, + score=0.0, + duration_ms=10, + ) + semantic_result = SemanticMatchResult( + candidates=[SemanticCandidate(rule=mock_rule, score=0.5)], + top_score=0.5, + duration_ms=50, + skipped=False, + skip_reason=None, + ) + llm_result = LlmJudgeResult( + intent_id=str(mock_rule.id), + intent_name="Test Intent", + score=0.85, + reasoning="Test reasoning", + duration_ms=500, + tokens_used=100, + triggered=True, + ) + + result = policy.fuse(rule_result, semantic_result, llm_result) + + assert result.decision_reason == "llm_judge" + assert result.final_intent == mock_rule + assert result.final_confidence == 0.85 + + def test_fuse_semantic_override(self, config, mock_rule): + """Test fusion with semantic override.""" + policy = FusionPolicy(config) + + rule_result = RuleMatchResult( + rule_id=None, + rule=None, + match_type=None, + matched_text=None, + score=0.0, + duration_ms=10, + ) + semantic_result = SemanticMatchResult( + candidates=[SemanticCandidate(rule=mock_rule, score=0.85)], + top_score=0.85, + duration_ms=50, + skipped=False, + skip_reason=None, + ) + + result = policy.fuse(rule_result, semantic_result, None) + + assert result.decision_reason == "semantic_override" + assert result.final_intent == mock_rule + assert result.final_confidence == 0.85 + + def test_fuse_rule_semantic_agree(self, config, mock_rule): + """Test fusion when rule and semantic agree.""" 
+ policy = FusionPolicy(config) + + rule_result = RuleMatchResult( + rule_id=mock_rule.id, + rule=mock_rule, + match_type="keyword", + matched_text="test", + score=1.0, + duration_ms=10, + ) + semantic_result = SemanticMatchResult( + candidates=[SemanticCandidate(rule=mock_rule, score=0.8)], + top_score=0.8, + duration_ms=50, + skipped=False, + skip_reason=None, + ) + + result = policy.fuse(rule_result, semantic_result, None) + + assert result.decision_reason == "rule_high_confidence" + assert result.final_intent == mock_rule + + def test_fuse_no_match(self, config): + """Test fusion with no match.""" + policy = FusionPolicy(config) + + rule_result = RuleMatchResult( + rule_id=None, + rule=None, + match_type=None, + matched_text=None, + score=0.0, + duration_ms=10, + ) + semantic_result = SemanticMatchResult( + candidates=[], + top_score=0.0, + duration_ms=50, + skipped=True, + skip_reason="no_semantic_config", + ) + + result = policy.fuse(rule_result, semantic_result, None) + + assert result.decision_reason == "no_match" + assert result.final_intent is None + assert result.final_confidence == 0.0 + + def test_fuse_need_clarify(self, config, mock_rule): + """Test fusion with clarify needed.""" + policy = FusionPolicy(config) + + other_rule = MagicMock() + other_rule.id = uuid.uuid4() + other_rule.name = "Other Intent" + + rule_result = RuleMatchResult( + rule_id=None, + rule=None, + match_type=None, + matched_text=None, + score=0.0, + duration_ms=10, + ) + semantic_result = SemanticMatchResult( + candidates=[ + SemanticCandidate(rule=mock_rule, score=0.35), + SemanticCandidate(rule=other_rule, score=0.30), + ], + top_score=0.35, + duration_ms=50, + skipped=False, + skip_reason=None, + ) + + result = policy.fuse(rule_result, semantic_result, None) + + assert result.need_clarify is True + assert result.clarify_candidates is not None + assert len(result.clarify_candidates) == 2 + + def test_calculate_weighted_confidence(self, config, mock_rule): + """Test weighted 
confidence calculation.""" + policy = FusionPolicy(config) + + rule_result = RuleMatchResult( + rule_id=mock_rule.id, + rule=mock_rule, + match_type="keyword", + matched_text="test", + score=1.0, + duration_ms=10, + ) + semantic_result = SemanticMatchResult( + candidates=[SemanticCandidate(rule=mock_rule, score=0.8)], + top_score=0.8, + duration_ms=50, + skipped=False, + skip_reason=None, + ) + + confidence = policy._calculate_weighted_confidence(rule_result, semantic_result, None) + + expected = (0.5 * 1.0 + 0.3 * 0.8) / (0.5 + 0.3) + assert abs(confidence - expected) < 0.01 + + def test_trace_generation(self, config, mock_rule): + """Test that trace is properly generated.""" + policy = FusionPolicy(config) + + rule_result = RuleMatchResult( + rule_id=mock_rule.id, + rule=mock_rule, + match_type="keyword", + matched_text="test", + score=1.0, + duration_ms=10, + ) + semantic_result = SemanticMatchResult( + candidates=[SemanticCandidate(rule=mock_rule, score=0.8)], + top_score=0.8, + duration_ms=50, + skipped=False, + skip_reason=None, + ) + + result = policy.fuse(rule_result, semantic_result, None) + + assert result.trace is not None + assert result.trace.rule_match["rule_id"] == str(mock_rule.id) + assert result.trace.semantic_match["top_score"] == 0.8 + assert result.trace.fusion["decision_reason"] == "rule_high_confidence" diff --git a/ai-service/tests/test_intent.py b/ai-service/tests/test_intent.py new file mode 100644 index 0000000..50e0aca --- /dev/null +++ b/ai-service/tests/test_intent.py @@ -0,0 +1,142 @@ +""" +Tests for intent router. 
+""" + +import uuid + +import pytest + +from app.services.intent.router import IntentRouter, RuleMatcher +from app.services.intent.models import ( + FusionConfig, + RuleMatchResult, + SemanticMatchResult, + LlmJudgeResult, + FusionResult, +) +from app.models.entities import IntentRule + + +class TestRuleMatcher: + """Test RuleMatcher basic functionality.""" + + def test_match_empty_message(self): + matcher = RuleMatcher() + result = matcher.match("", []) + assert result.score == 0.0 + assert result.rule is None + assert result.duration_ms >= 0 + + def test_match_empty_rules(self): + matcher = RuleMatcher() + rule = IntentRule( + id=uuid.uuid4(), + tenant_id="test_tenant", + name="Test Rule", + keywords=["test", "demo"], + is_enabled=True, + ) + result = matcher.match("test message", [rule]) + assert result.score == 1.0 + assert result.rule == rule + assert result.match_type == "keyword" + assert result.matched_text == "test" + + def test_match_regex(self): + matcher = RuleMatcher() + rule = IntentRule( + id=uuid.uuid4(), + tenant_id="test_tenant", + name="Test Regex Rule", + patterns=[r"test.*pattern"], + is_enabled=True, + ) + result = matcher.match("this is a test regex pattern", [rule]) + assert result.score == 1.0 + assert result.rule == rule + assert result.match_type == "regex" + assert "pattern" in result.matched_text + + def test_no_match(self): + matcher = RuleMatcher() + rule = IntentRule( + id=uuid.uuid4(), + tenant_id="test_tenant", + name="Test Rule", + keywords=["specific", "keyword"], + is_enabled=True, + ) + result = matcher.match("no match here", [rule]) + assert result.score == 0.0 + assert result.rule is None + + + def test_priority_order(self): + matcher = RuleMatcher() + rule1 = IntentRule( + id=uuid.uuid4(), + tenant_id="test_tenant", + name="High Priority", + keywords=["high"], + priority=10, + is_enabled=True, + ) + rule2 = IntentRule( + id=uuid.uuid4(), + tenant_id="test_tenant", + name="Low Priority", + keywords=["low"], + priority=1, + 
is_enabled=True, + ) + result = matcher.match("high priority message", [rule1, rule2]) + assert result.rule == rule1 + assert result.rule.name == "High Priority" + + + def test_disabled_rule(self): + matcher = RuleMatcher() + rule = IntentRule( + id=uuid.uuid4(), + tenant_id="test_tenant", + name="Disabled Rule", + keywords=["disabled"], + is_enabled=False, + ) + result = matcher.match("disabled message", [rule]) + assert result.score == 0.0 + assert result.rule is None + + +class TestIntentRouterBackwardCompatibility: + """Test IntentRouter backward compatibility.""" + + def test_match_backward_compatible(self): + router = IntentRouter() + rule = IntentRule( + id=uuid.uuid4(), + tenant_id="test_tenant", + name="Test Rule", + keywords=["hello", "hi"], + is_enabled=True, + ) + result = router.match("hello world", [rule]) + assert result is not None + assert result.rule.name == "Test Rule" + assert result.match_type == "keyword" + assert result.matched == "hello" + + def test_match_with_stats(self): + router = IntentRouter() + rule = IntentRule( + id=uuid.uuid4(), + tenant_id="test_tenant", + name="Test Rule", + keywords=["test"], + is_enabled=True, + ) + result, rule_id = router.match_with_stats("test message", [rule]) + assert result is not None + assert rule_id == str(rule.id) + + diff --git a/ai-service/tests/test_intent_router_hybrid.py b/ai-service/tests/test_intent_router_hybrid.py new file mode 100644 index 0000000..814eca2 --- /dev/null +++ b/ai-service/tests/test_intent_router_hybrid.py @@ -0,0 +1,468 @@ +""" +Integration tests for IntentRouter.match_hybrid(). +[AC-AISVC-111] Tests for hybrid routing integration. 
+""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +import uuid +import asyncio + +from app.services.intent.models import ( + FusionConfig, + FusionResult, + LlmJudgeInput, + LlmJudgeResult, + RuleMatchResult, + SemanticCandidate, + SemanticMatchResult, + RouteTrace, +) + + +@pytest.fixture +def mock_embedding_provider(): + """Create a mock embedding provider.""" + provider = AsyncMock() + provider.embed = AsyncMock(return_value=[0.1] * 768) + provider.embed_batch = AsyncMock(return_value=[[0.1] * 768]) + return provider + + +@pytest.fixture +def mock_llm_client(): + """Create a mock LLM client.""" + client = AsyncMock() + return client + + +@pytest.fixture +def config(): + """Create a fusion config.""" + return FusionConfig() + + +@pytest.fixture +def mock_rule(): + """Create a mock intent rule.""" + rule = MagicMock() + rule.id = uuid.uuid4() + rule.name = "Return Intent" + rule.response_type = "rag" + rule.keywords = ["退货", "退款"] + rule.patterns = [] + rule.intent_vector = [0.1] * 768 + rule.semantic_examples = None + rule.is_enabled = True + rule.priority = 10 + return rule + + +@pytest.fixture +def mock_rules(mock_rule): + """Create a list of mock intent rules.""" + other_rule = MagicMock() + other_rule.id = uuid.uuid4() + other_rule.name = "Order Query" + other_rule.response_type = "rag" + other_rule.keywords = ["订单", "查询"] + other_rule.patterns = [] + other_rule.intent_vector = [0.5] * 768 + other_rule.semantic_examples = None + other_rule.is_enabled = True + other_rule.priority = 5 + + return [mock_rule, other_rule] + + +class MockRuleMatcher: + """Mock RuleMatcher for testing.""" + + def match(self, message: str, rules: list) -> RuleMatchResult: + import time + start_time = time.time() + message_lower = message.lower() + + for rule in rules: + if not rule.is_enabled: + continue + for keyword in (rule.keywords or []): + if keyword.lower() in message_lower: + return RuleMatchResult( + rule_id=rule.id, + rule=rule, + 
match_type="keyword", + matched_text=keyword, + score=1.0, + duration_ms=int((time.time() - start_time) * 1000), + ) + return RuleMatchResult( + rule_id=None, + rule=None, + match_type=None, + matched_text=None, + score=0.0, + duration_ms=int((time.time() - start_time) * 1000), + ) + + +class MockSemanticMatcher: + """Mock SemanticMatcher for testing.""" + + def __init__(self, config): + self._config = config + + async def match(self, message: str, rules: list, tenant_id: str, top_k: int = 3) -> SemanticMatchResult: + import time + start_time = time.time() + + if not self._config.semantic_matcher_enabled: + return SemanticMatchResult( + candidates=[], + top_score=0.0, + duration_ms=0, + skipped=True, + skip_reason="disabled", + ) + + candidates = [] + for rule in rules: + if rule.intent_vector: + candidates.append(SemanticCandidate(rule=rule, score=0.85)) + break + + return SemanticMatchResult( + candidates=candidates[:top_k], + top_score=candidates[0].score if candidates else 0.0, + duration_ms=int((time.time() - start_time) * 1000), + skipped=False, + skip_reason=None, + ) + + +class MockLlmJudge: + """Mock LlmJudge for testing.""" + + def __init__(self, config): + self._config = config + + def should_trigger(self, rule_result, semantic_result, config=None) -> tuple: + effective_config = config or self._config + if not effective_config.llm_judge_enabled: + return False, "disabled" + + if rule_result.score > 0 and semantic_result.top_score > 0: + if semantic_result.candidates: + if rule_result.rule_id != semantic_result.candidates[0].rule.id: + if abs(rule_result.score - semantic_result.top_score) < effective_config.conflict_threshold: + return True, "rule_semantic_conflict" + + max_score = max(rule_result.score, semantic_result.top_score) + if effective_config.min_trigger_threshold < max_score < effective_config.gray_zone_threshold: + return True, "gray_zone" + + return False, "" + + async def judge(self, input_data: LlmJudgeInput, tenant_id: str) -> 
LlmJudgeResult: + return LlmJudgeResult( + intent_id=input_data.candidates[0]["id"] if input_data.candidates else None, + intent_name=input_data.candidates[0]["name"] if input_data.candidates else None, + score=0.9, + reasoning="Test arbitration", + duration_ms=500, + tokens_used=100, + triggered=True, + ) + + +class MockFusionPolicy: + """Mock FusionPolicy for testing.""" + + DECISION_PRIORITY = [ + ("rule_high_confidence", lambda r, s, l: r.score == 1.0 and r.rule is not None), + ("llm_judge", lambda r, s, l: l.triggered and l.intent_id is not None), + ("semantic_override", lambda r, s, l: r.score == 0 and s.top_score > 0.7), + ("no_match", lambda r, s, l: True), + ] + + def __init__(self, config): + self._config = config + + def fuse(self, rule_result, semantic_result, llm_result) -> FusionResult: + decision_reason = "no_match" + for reason, condition in self.DECISION_PRIORITY: + if condition(rule_result, semantic_result, llm_result or LlmJudgeResult.empty()): + decision_reason = reason + break + + final_intent = None + final_confidence = 0.0 + + if decision_reason == "rule_high_confidence": + final_intent = rule_result.rule + final_confidence = 1.0 + elif decision_reason == "llm_judge" and llm_result: + final_intent = self._find_rule_by_id(llm_result.intent_id, rule_result, semantic_result) + final_confidence = llm_result.score + elif decision_reason == "semantic_override": + final_intent = semantic_result.candidates[0].rule + final_confidence = semantic_result.top_score + + return FusionResult( + final_intent=final_intent, + final_confidence=final_confidence, + decision_reason=decision_reason, + need_clarify=final_confidence < 0.4, + clarify_candidates=None, + trace=RouteTrace(), + ) + + def _find_rule_by_id(self, intent_id, rule_result, semantic_result): + if not intent_id: + return None + if rule_result.rule_id and str(rule_result.rule_id) == intent_id: + return rule_result.rule + for c in semantic_result.candidates: + if str(c.rule.id) == intent_id: + 
return c.rule + return None + + +class MockIntentRouter: + """Mock IntentRouter for testing match_hybrid.""" + + def __init__(self, rule_matcher, semantic_matcher, llm_judge, fusion_policy, config=None): + self._rule_matcher = rule_matcher + self._semantic_matcher = semantic_matcher + self._llm_judge = llm_judge + self._fusion_policy = fusion_policy + self._config = config or FusionConfig() + + async def match_hybrid( + self, + message: str, + rules: list, + tenant_id: str, + config: FusionConfig | None = None, + ) -> FusionResult: + effective_config = config or self._config + + rule_result, semantic_result = await asyncio.gather( + asyncio.to_thread(self._rule_matcher.match, message, rules), + self._semantic_matcher.match(message, rules, tenant_id), + ) + + llm_result = None + should_trigger, trigger_reason = self._llm_judge.should_trigger( + rule_result, semantic_result, effective_config + ) + + if should_trigger: + candidates = self._build_llm_candidates(rule_result, semantic_result) + llm_result = await self._llm_judge.judge( + LlmJudgeInput( + message=message, + candidates=candidates, + conflict_type=trigger_reason, + ), + tenant_id, + ) + + fusion_result = self._fusion_policy.fuse( + rule_result, semantic_result, llm_result + ) + + return fusion_result + + def _build_llm_candidates(self, rule_result, semantic_result) -> list: + candidates = [] + + if rule_result.rule: + candidates.append({ + "id": str(rule_result.rule_id), + "name": rule_result.rule.name, + "description": f"匹配方式: {rule_result.match_type}", + }) + + for candidate in semantic_result.candidates[:3]: + if not any(c["id"] == str(candidate.rule.id) for c in candidates): + candidates.append({ + "id": str(candidate.rule.id), + "name": candidate.rule.name, + "description": f"语义相似度: {candidate.score:.2f}", + }) + + return candidates + + +class TestIntentRouterHybrid: + """Tests for IntentRouter.match_hybrid() integration.""" + + @pytest.mark.asyncio + async def test_match_hybrid_rule_match(self, 
mock_embedding_provider, mock_llm_client, config, mock_rules): + """Test hybrid routing with rule match.""" + rule_matcher = MockRuleMatcher() + semantic_matcher = MockSemanticMatcher(config) + llm_judge = MockLlmJudge(config) + fusion_policy = MockFusionPolicy(config) + + router = MockIntentRouter( + rule_matcher, semantic_matcher, llm_judge, fusion_policy, config + ) + + result = await router.match_hybrid("我想退货", mock_rules, "tenant-1") + + assert result.decision_reason == "rule_high_confidence" + assert result.final_intent == mock_rules[0] + assert result.final_confidence == 1.0 + + @pytest.mark.asyncio + async def test_match_hybrid_semantic_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules): + """Test hybrid routing with semantic match only.""" + rule_matcher = MockRuleMatcher() + semantic_matcher = MockSemanticMatcher(config) + llm_judge = MockLlmJudge(config) + fusion_policy = MockFusionPolicy(config) + + router = MockIntentRouter( + rule_matcher, semantic_matcher, llm_judge, fusion_policy, config + ) + + result = await router.match_hybrid("商品有问题", mock_rules, "tenant-1") + + assert result.decision_reason == "semantic_override" + assert result.final_intent is not None + assert result.final_confidence > 0.7 + + @pytest.mark.asyncio + async def test_match_hybrid_parallel_execution(self, mock_embedding_provider, mock_llm_client, config, mock_rules): + """Test that rule and semantic matching run in parallel.""" + import time + + class SlowSemanticMatcher(MockSemanticMatcher): + async def match(self, message, rules, tenant_id, top_k=3): + await asyncio.sleep(0.1) + return await super().match(message, rules, tenant_id, top_k) + + rule_matcher = MockRuleMatcher() + semantic_matcher = SlowSemanticMatcher(config) + llm_judge = MockLlmJudge(config) + fusion_policy = MockFusionPolicy(config) + + router = MockIntentRouter( + rule_matcher, semantic_matcher, llm_judge, fusion_policy, config + ) + + start_time = time.time() + result = await 
router.match_hybrid("我想退货", mock_rules, "tenant-1") + elapsed = time.time() - start_time + + assert elapsed < 0.2 + assert result is not None + + @pytest.mark.asyncio + async def test_match_hybrid_llm_judge_triggered(self, mock_embedding_provider, mock_llm_client, config, mock_rules): + """Test hybrid routing with LLM judge triggered.""" + config = FusionConfig(conflict_threshold=0.3) + + class ConflictSemanticMatcher(MockSemanticMatcher): + async def match(self, message, rules, tenant_id, top_k=3): + result = await super().match(message, rules, tenant_id, top_k) + if result.candidates: + result.candidates[0] = SemanticCandidate(rule=rules[1], score=0.9) + result.top_score = 0.9 + return result + + rule_matcher = MockRuleMatcher() + semantic_matcher = ConflictSemanticMatcher(config) + llm_judge = MockLlmJudge(config) + fusion_policy = MockFusionPolicy(config) + + router = MockIntentRouter( + rule_matcher, semantic_matcher, llm_judge, fusion_policy, config + ) + + result = await router.match_hybrid("我想退货", mock_rules, "tenant-1") + + assert result.decision_reason in ["rule_high_confidence", "llm_judge"] + + @pytest.mark.asyncio + async def test_match_hybrid_no_match(self, mock_embedding_provider, mock_llm_client, config, mock_rules): + """Test hybrid routing with no match.""" + class NoMatchSemanticMatcher(MockSemanticMatcher): + async def match(self, message, rules, tenant_id, top_k=3): + return SemanticMatchResult( + candidates=[], + top_score=0.0, + duration_ms=10, + skipped=True, + skip_reason="no_semantic_config", + ) + + rule_matcher = MockRuleMatcher() + semantic_matcher = NoMatchSemanticMatcher(config) + llm_judge = MockLlmJudge(config) + fusion_policy = MockFusionPolicy(config) + + router = MockIntentRouter( + rule_matcher, semantic_matcher, llm_judge, fusion_policy, config + ) + + result = await router.match_hybrid("随便说说", mock_rules, "tenant-1") + + assert result.decision_reason == "no_match" + assert result.final_intent is None + assert 
result.final_confidence == 0.0 + + @pytest.mark.asyncio + async def test_match_hybrid_semantic_disabled(self, mock_embedding_provider, mock_llm_client, mock_rules): + """Test hybrid routing with semantic matcher disabled.""" + config = FusionConfig(semantic_matcher_enabled=False) + + rule_matcher = MockRuleMatcher() + semantic_matcher = MockSemanticMatcher(config) + llm_judge = MockLlmJudge(config) + fusion_policy = MockFusionPolicy(config) + + router = MockIntentRouter( + rule_matcher, semantic_matcher, llm_judge, fusion_policy, config + ) + + result = await router.match_hybrid("我想退货", mock_rules, "tenant-1") + + assert result.decision_reason == "rule_high_confidence" + assert result.final_intent == mock_rules[0] + + @pytest.mark.asyncio + async def test_match_hybrid_llm_disabled(self, mock_embedding_provider, mock_llm_client, mock_rules): + """Test hybrid routing with LLM judge disabled.""" + config = FusionConfig(llm_judge_enabled=False) + + rule_matcher = MockRuleMatcher() + semantic_matcher = MockSemanticMatcher(config) + llm_judge = MockLlmJudge(config) + fusion_policy = MockFusionPolicy(config) + + router = MockIntentRouter( + rule_matcher, semantic_matcher, llm_judge, fusion_policy, config + ) + + result = await router.match_hybrid("我想退货", mock_rules, "tenant-1") + + assert result.decision_reason == "rule_high_confidence" + + @pytest.mark.asyncio + async def test_match_hybrid_trace_generated(self, mock_embedding_provider, mock_llm_client, config, mock_rules): + """Test that route trace is generated.""" + rule_matcher = MockRuleMatcher() + semantic_matcher = MockSemanticMatcher(config) + llm_judge = MockLlmJudge(config) + fusion_policy = MockFusionPolicy(config) + + router = MockIntentRouter( + rule_matcher, semantic_matcher, llm_judge, fusion_policy, config + ) + + result = await router.match_hybrid("我想退货", mock_rules, "tenant-1") + + assert result.trace is not None diff --git a/ai-service/tests/test_kb_search_dynamic_slot_integration.py 
b/ai-service/tests/test_kb_search_dynamic_slot_integration.py new file mode 100644 index 0000000..62d25df --- /dev/null +++ b/ai-service/tests/test_kb_search_dynamic_slot_integration.py @@ -0,0 +1,308 @@
"""
Tests for KB Search Dynamic Tool with Slot State Integration.
[AC-MRS-SLOT-META-02] KB retrieval / slot-state integration tests.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from app.models.mid.schemas import MemorySlot, SlotSource, ToolCallStatus
from app.services.mid.kb_search_dynamic_tool import (
    KbSearchDynamicConfig,
    KbSearchDynamicTool,
)
from app.services.mid.slot_state_aggregator import SlotState


class TestKbSearchDynamicWithSlotState:
    """Integration of KbSearchDynamicTool with the aggregated slot state."""

    @pytest.fixture
    def mock_session(self):
        """Mocked async database session."""
        return AsyncMock()

    @pytest.fixture
    def kb_tool(self, mock_session):
        """KB tool instance backed by the mocked session."""
        return KbSearchDynamicTool(
            session=mock_session,
            config=KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=10000,
                min_score_threshold=0.5,
            ),
        )

    @pytest.mark.asyncio
    async def test_execute_with_missing_required_slots(self, kb_tool, mock_session):
        """When required slots are missing, execute() returns an ask-back response."""
        slot_state = SlotState(
            filled_slots={},
            missing_required_slots=[
                {
                    "slot_key": "product_line",
                    "label": "产品线",
                    "reason": "required_slot_missing",
                    "ask_back_prompt": "请问您咨询的是哪个产品线?",
                }
            ],
        )

        result = await kb_tool.execute(
            query="退款政策",
            tenant_id="test_tenant",
            context={},
            slot_state=slot_state,
        )

        # The tool must short-circuit with a fallback instead of searching the KB.
        assert result.success is False
        assert result.fallback_reason_code == "MISSING_REQUIRED_SLOTS"
        assert len(result.missing_required_slots) == 1
        assert result.missing_required_slots[0]["slot_key"] == "product_line"
        assert result.tool_trace is not None
        assert result.tool_trace.error_code == "MISSING_REQUIRED_SLOTS"

    @pytest.mark.asyncio
    async def test_execute_with_filled_slots(self, kb_tool, mock_session):
        """Filled slots are used to build the retrieval filter."""
        slot_state = SlotState(
            filled_slots={"product_line": "vip_course"},
            missing_required_slots=[],
            slot_to_field_map={"product_line": "product_line"},
        )

        # Stub the filter builder (no filterable fields configured).
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[])
        kb_tool._filter_builder = mock_filter_builder

        # Stub retrieval to return no hits.
        with patch.object(kb_tool, "_retrieve_with_timeout", return_value=[]):
            result = await kb_tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context={},
                slot_state=slot_state,
            )

        # Execution succeeds even though nothing matched.
        assert result.success is True

    @pytest.mark.asyncio
    async def test_build_filter_from_slot_state_priority(self, kb_tool, mock_session):
        """Filter value source priority: slot > context > default."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo

        slot_state = SlotState(
            filled_slots={"product_line": "from_slot"},
            missing_required_slots=[],
            slot_to_field_map={"product_line": "product_line"},
        )

        context = {"product_line": "from_context"}

        # Filterable field that also carries a default value.
        mock_field = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=True,
            options=None,
            default_value="from_default",
            is_filterable=True,
        )

        # Stub the filter builder.
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[mock_field])
        mock_filter_builder._build_field_filter = MagicMock(return_value={"$eq": "from_slot"})
        kb_tool._filter_builder = mock_filter_builder

        filter_result = await kb_tool._build_filter_from_slot_state(
            tenant_id="test_tenant",
            slot_state=slot_state,
            context=context,
        )

        # The slot value wins (highest priority).
        assert "product_line" in filter_result
        mock_filter_builder._build_field_filter.assert_called_once()
        call_args = mock_filter_builder._build_field_filter.call_args
        assert call_args[0][1] == "from_slot"

    @pytest.mark.asyncio
    async def test_build_filter_uses_context_when_slot_empty(self, kb_tool, mock_session):
        """When no slot is filled, the context value is used."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo

        slot_state = SlotState(
            filled_slots={},  # no slots filled
            missing_required_slots=[],
            slot_to_field_map={},
        )

        context = {"product_line": "from_context"}

        # Filterable field with a default value that should NOT be used here.
        mock_field = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=True,
            options=None,
            default_value="from_default",
            is_filterable=True,
        )

        # Stub the filter builder.
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[mock_field])
        mock_filter_builder._build_field_filter = MagicMock(return_value={"$eq": "from_context"})
        kb_tool._filter_builder = mock_filter_builder

        filter_result = await kb_tool._build_filter_from_slot_state(
            tenant_id="test_tenant",
            slot_state=slot_state,
            context=context,
        )

        # The context value is used (second priority).
        assert "product_line" in filter_result
        call_args = mock_filter_builder._build_field_filter.call_args
        assert call_args[0][1] == "from_context"

    @pytest.mark.asyncio
    async def test_build_filter_uses_default_when_no_other(self, kb_tool, mock_session):
        """When both slot and context are empty, the field default is used."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo

        slot_state = SlotState(
            filled_slots={},
            missing_required_slots=[],
            slot_to_field_map={},
        )

        context = {}

        # Filterable field carrying a default value.
        mock_field = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=False,  # optional field
            options=None,
            default_value="from_default",
            is_filterable=True,
        )

        # Stub the filter builder.
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[mock_field])
        mock_filter_builder._build_field_filter = MagicMock(return_value={"$eq": "from_default"})
        kb_tool._filter_builder = mock_filter_builder

        filter_result = await kb_tool._build_filter_from_slot_state(
            tenant_id="test_tenant",
            slot_state=slot_state,
            context=context,
        )

        # The default value is used (lowest priority).
        assert "product_line" in filter_result
        call_args = mock_filter_builder._build_field_filter.call_args
        assert call_args[0][1] == "from_default"


class TestKbSearchDynamicSlotMapping:
    """Slot-key to field-key mapping."""

    @pytest.mark.asyncio
    async def test_slot_to_field_mapping_in_filter(self):
        """Slot values are mapped onto field keys via slot_to_field_map."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo
        from unittest.mock import AsyncMock

        mock_session = AsyncMock()
        kb_tool = KbSearchDynamicTool(session=mock_session)

        # slot_key is "product" but it maps to field_key "product_line".
        slot_state = SlotState(
            filled_slots={"product": "vip_course"},
            missing_required_slots=[],
            slot_to_field_map={"product": "product_line"},
        )

        # Filterable field keyed by the mapped field_key.
        mock_field = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=True,
            options=None,
            default_value=None,
            is_filterable=True,
        )

        # Stub the filter builder.
        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[mock_field])
        mock_filter_builder._build_field_filter = MagicMock(return_value={"$eq": "vip_course"})
        kb_tool._filter_builder = mock_filter_builder

        filter_result = await kb_tool._build_filter_from_slot_state(
            tenant_id="test_tenant",
            slot_state=slot_state,
            context={},
        )

        # The value is located through the slot-to-field mapping.
        assert "product_line" in filter_result
        call_args = mock_filter_builder._build_field_filter.call_args
        assert call_args[0][1] == "vip_course"


class TestKbSearchDynamicDebugInfo:
    """Debug information emitted by the tool."""

    @pytest.mark.asyncio
    async def test_filter_debug_includes_sources(self):
        """Filter debug info is populated (includes value-source markers)."""
        from app.services.mid.metadata_filter_builder import FilterFieldInfo
        from unittest.mock import AsyncMock

        mock_session = AsyncMock()
        kb_tool = KbSearchDynamicTool(session=mock_session)

        slot_state = SlotState(
            filled_slots={"product_line": "vip_course"},
            missing_required_slots=[],
            slot_to_field_map={"product_line": "product_line"},
        )

        mock_field = FilterFieldInfo(
            field_key="product_line",
            label="产品线",
            field_type="string",
            required=True,
            options=None,
            default_value=None,
            is_filterable=True,
        )

        mock_filter_builder = MagicMock()
        mock_filter_builder._get_filterable_fields = AsyncMock(return_value=[mock_field])
        mock_filter_builder._build_field_filter = MagicMock(return_value={"$eq": "vip_course"})
        kb_tool._filter_builder = mock_filter_builder

        with patch.object(kb_tool, "_retrieve_with_timeout", return_value=[]):
            result = await kb_tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context={},
                slot_state=slot_state,
            )

        # Debug payload must be present in the result.
        assert result.filter_debug is not None
diff --git a/ai-service/tests/test_llm_judge.py b/ai-service/tests/test_llm_judge.py new file mode 100644 index 0000000..37f6d00 --- /dev/null +++ b/ai-service/tests/test_llm_judge.py @@ -0,0 +1,291 @@
"""
Unit tests for LlmJudge.
[AC-AISVC-118, AC-AISVC-119] Tests for LLM-based intent arbitration.
+""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +import uuid + +from app.services.intent.llm_judge import LlmJudge +from app.services.intent.models import ( + FusionConfig, + LlmJudgeInput, + LlmJudgeResult, + RuleMatchResult, + SemanticCandidate, + SemanticMatchResult, +) + + +@pytest.fixture +def mock_llm_client(): + """Create a mock LLM client.""" + client = AsyncMock() + return client + + +@pytest.fixture +def config(): + """Create a fusion config.""" + return FusionConfig() + + +@pytest.fixture +def mock_rule(): + """Create a mock intent rule.""" + rule = MagicMock() + rule.id = uuid.uuid4() + rule.name = "Test Intent" + return rule + + +class TestLlmJudge: + """Tests for LlmJudge class.""" + + def test_init(self, mock_llm_client, config): + """Test LlmJudge initialization.""" + judge = LlmJudge(mock_llm_client, config) + assert judge._llm_client == mock_llm_client + assert judge._config == config + + def test_should_trigger_disabled(self, mock_llm_client): + """Test should_trigger when LLM judge is disabled.""" + config = FusionConfig(llm_judge_enabled=False) + judge = LlmJudge(mock_llm_client, config) + + rule_result = RuleMatchResult( + rule_id=uuid.uuid4(), + rule=MagicMock(), + match_type="keyword", + matched_text="test", + score=1.0, + duration_ms=10, + ) + semantic_result = SemanticMatchResult( + candidates=[], + top_score=0.8, + duration_ms=50, + skipped=False, + skip_reason=None, + ) + + triggered, reason = judge.should_trigger(rule_result, semantic_result) + assert triggered is False + assert reason == "disabled" + + def test_should_trigger_rule_semantic_conflict(self, mock_llm_client, config, mock_rule): + """Test should_trigger for rule vs semantic conflict.""" + judge = LlmJudge(mock_llm_client, config) + + rule_result = RuleMatchResult( + rule_id=uuid.uuid4(), + rule=mock_rule, + match_type="keyword", + matched_text="test", + score=1.0, + duration_ms=10, + ) + + other_rule = MagicMock() + other_rule.id = uuid.uuid4() 
+ other_rule.name = "Other Intent" + + semantic_result = SemanticMatchResult( + candidates=[SemanticCandidate(rule=other_rule, score=0.95)], + top_score=0.95, + duration_ms=50, + skipped=False, + skip_reason=None, + ) + + triggered, reason = judge.should_trigger(rule_result, semantic_result) + assert triggered is True + assert reason == "rule_semantic_conflict" + + def test_should_trigger_gray_zone(self, mock_llm_client, config, mock_rule): + """Test should_trigger for gray zone scenario.""" + judge = LlmJudge(mock_llm_client, config) + + rule_result = RuleMatchResult( + rule_id=None, + rule=None, + match_type=None, + matched_text=None, + score=0.0, + duration_ms=10, + ) + + semantic_result = SemanticMatchResult( + candidates=[SemanticCandidate(rule=mock_rule, score=0.5)], + top_score=0.5, + duration_ms=50, + skipped=False, + skip_reason=None, + ) + + triggered, reason = judge.should_trigger(rule_result, semantic_result) + assert triggered is True + assert reason == "gray_zone" + + def test_should_trigger_multi_intent(self, mock_llm_client, config, mock_rule): + """Test should_trigger for multi-intent scenario.""" + judge = LlmJudge(mock_llm_client, config) + + rule_result = RuleMatchResult( + rule_id=None, + rule=None, + match_type=None, + matched_text=None, + score=0.0, + duration_ms=10, + ) + + other_rule = MagicMock() + other_rule.id = uuid.uuid4() + other_rule.name = "Other Intent" + + semantic_result = SemanticMatchResult( + candidates=[ + SemanticCandidate(rule=mock_rule, score=0.8), + SemanticCandidate(rule=other_rule, score=0.75), + ], + top_score=0.8, + duration_ms=50, + skipped=False, + skip_reason=None, + ) + + triggered, reason = judge.should_trigger(rule_result, semantic_result) + assert triggered is True + assert reason == "multi_intent" + + def test_should_not_trigger_high_confidence(self, mock_llm_client, config, mock_rule): + """Test should_trigger returns False for high confidence match.""" + judge = LlmJudge(mock_llm_client, config) + + 
rule_result = RuleMatchResult( + rule_id=mock_rule.id, + rule=mock_rule, + match_type="keyword", + matched_text="test", + score=1.0, + duration_ms=10, + ) + + semantic_result = SemanticMatchResult( + candidates=[SemanticCandidate(rule=mock_rule, score=0.9)], + top_score=0.9, + duration_ms=50, + skipped=False, + skip_reason=None, + ) + + triggered, reason = judge.should_trigger(rule_result, semantic_result) + assert triggered is False + + @pytest.mark.asyncio + async def test_judge_success(self, mock_llm_client, config): + """Test successful LLM judge.""" + from app.services.llm.base import LLMResponse + + mock_response = LLMResponse( + content='{"intent_id": "test-id", "intent_name": "Test", "confidence": 0.85, "reasoning": "Test reasoning"}', + model="gpt-4", + usage={"total_tokens": 100}, + ) + mock_llm_client.generate = AsyncMock(return_value=mock_response) + + judge = LlmJudge(mock_llm_client, config) + input_data = LlmJudgeInput( + message="test message", + candidates=[{"id": "test-id", "name": "Test", "description": "Test intent"}], + conflict_type="gray_zone", + ) + + result = await judge.judge(input_data, "tenant-1") + + assert result.triggered is True + assert result.intent_id == "test-id" + assert result.intent_name == "Test" + assert result.score == 0.85 + assert result.reasoning == "Test reasoning" + assert result.tokens_used == 100 + + @pytest.mark.asyncio + async def test_judge_timeout(self, mock_llm_client, config): + """Test LLM judge timeout.""" + import asyncio + mock_llm_client.generate = AsyncMock(side_effect=asyncio.TimeoutError()) + + judge = LlmJudge(mock_llm_client, config) + input_data = LlmJudgeInput( + message="test message", + candidates=[{"id": "test-id", "name": "Test"}], + conflict_type="gray_zone", + ) + + result = await judge.judge(input_data, "tenant-1") + + assert result.triggered is True + assert result.intent_id is None + assert "timeout" in result.reasoning.lower() + + @pytest.mark.asyncio + async def test_judge_error(self, 
mock_llm_client, config): + """Test LLM judge error handling.""" + mock_llm_client.generate = AsyncMock(side_effect=Exception("LLM error")) + + judge = LlmJudge(mock_llm_client, config) + input_data = LlmJudgeInput( + message="test message", + candidates=[{"id": "test-id", "name": "Test"}], + conflict_type="gray_zone", + ) + + result = await judge.judge(input_data, "tenant-1") + + assert result.triggered is True + assert result.intent_id is None + assert "error" in result.reasoning.lower() + + def test_parse_response_valid_json(self, mock_llm_client, config): + """Test parsing valid JSON response.""" + judge = LlmJudge(mock_llm_client, config) + + content = '{"intent_id": "test", "confidence": 0.9}' + result = judge._parse_response(content) + + assert result["intent_id"] == "test" + assert result["confidence"] == 0.9 + + def test_parse_response_with_markdown(self, mock_llm_client, config): + """Test parsing JSON response with markdown code block.""" + judge = LlmJudge(mock_llm_client, config) + + content = '```json\n{"intent_id": "test", "confidence": 0.9}\n```' + result = judge._parse_response(content) + + assert result["intent_id"] == "test" + assert result["confidence"] == 0.9 + + def test_parse_response_invalid_json(self, mock_llm_client, config): + """Test parsing invalid JSON response.""" + judge = LlmJudge(mock_llm_client, config) + + content = "This is not valid JSON" + result = judge._parse_response(content) + + assert result == {} + + def test_llm_judge_result_empty(self): + """Test LlmJudgeResult.empty() class method.""" + result = LlmJudgeResult.empty() + + assert result.intent_id is None + assert result.intent_name is None + assert result.score == 0.0 + assert result.reasoning is None + assert result.duration_ms == 0 + assert result.tokens_used == 0 + assert result.triggered is False diff --git a/ai-service/tests/test_scene_slot_bundle_loader.py b/ai-service/tests/test_scene_slot_bundle_loader.py new file mode 100644 index 0000000..fc537e4 --- 
/dev/null +++ b/ai-service/tests/test_scene_slot_bundle_loader.py @@ -0,0 +1,299 @@
"""
Tests for Scene Slot Bundle Loader.
[AC-SCENE-SLOT-02] Scene slot bundle loader tests.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

from app.services.mid.scene_slot_bundle_loader import (
    SceneSlotBundleLoader,
    SceneSlotContext,
    SlotInfo,
)
from app.models.entities import (
    SceneSlotBundle,
    SceneSlotBundleStatus,
    SlotDefinition,
)


@pytest.fixture
def mock_session():
    """Mock database session."""
    session = AsyncMock()
    session.execute = AsyncMock()
    return session


@pytest.fixture
def sample_slot_definitions():
    """Sample slot definitions for testing (two required, one optional)."""
    return [
        SlotDefinition(
            id=uuid4(),
            tenant_id="test_tenant",
            slot_key="course_type",
            type="string",
            required=True,
            ask_back_prompt="请问您想咨询哪种类型的课程?",
        ),
        SlotDefinition(
            id=uuid4(),
            tenant_id="test_tenant",
            slot_key="grade",
            type="string",
            required=True,
            ask_back_prompt="请问您是几年级?",
        ),
        SlotDefinition(
            id=uuid4(),
            tenant_id="test_tenant",
            slot_key="region",
            type="string",
            required=False,
            ask_back_prompt="请问您在哪个地区?",
        ),
    ]


@pytest.fixture
def sample_bundle(sample_slot_definitions):
    """Sample active scene slot bundle referencing the definitions above."""
    return SceneSlotBundle(
        id=uuid4(),
        tenant_id="test_tenant",
        scene_key="open_consult",
        scene_name="开放咨询",
        description="开放咨询场景的槽位配置",
        required_slots=["course_type", "grade"],
        optional_slots=["region"],
        slot_priority=["course_type", "grade", "region"],
        completion_threshold=1.0,
        ask_back_order="priority",
        status=SceneSlotBundleStatus.ACTIVE.value,
        version=1,
    )


class TestSceneSlotContext:
    """Test cases for SceneSlotContext."""

    def test_get_all_slot_keys(self):
        """All required and optional slot keys are returned."""
        context = SceneSlotContext(
            scene_key="test_scene",
            scene_name="测试场景",
            required_slots=[
                SlotInfo(slot_key="course_type", type="string", required=True),
            ],
            optional_slots=[
                SlotInfo(slot_key="region", type="string", required=False),
            ],
        )

        all_keys = context.get_all_slot_keys()

        assert "course_type" in all_keys
        assert "region" in all_keys
        assert len(all_keys) == 2

    def test_get_missing_slots(self):
        """Only unfilled required slots are reported as missing."""
        context = SceneSlotContext(
            scene_key="test_scene",
            scene_name="测试场景",
            required_slots=[
                SlotInfo(slot_key="course_type", type="string", required=True, ask_back_prompt="请问课程类型?"),
                SlotInfo(slot_key="grade", type="string", required=True, ask_back_prompt="请问年级?"),
            ],
            optional_slots=[],
        )

        filled_slots = {"course_type": "数学"}
        missing = context.get_missing_slots(filled_slots)

        assert len(missing) == 1
        assert missing[0]["slot_key"] == "grade"

    def test_get_ordered_missing_slots_priority(self):
        """Missing slots follow the configured priority order."""
        context = SceneSlotContext(
            scene_key="test_scene",
            scene_name="测试场景",
            required_slots=[
                SlotInfo(slot_key="course_type", type="string", required=True),
                SlotInfo(slot_key="grade", type="string", required=True),
            ],
            optional_slots=[],
            slot_priority=["grade", "course_type"],  # reversed vs declaration order
            ask_back_order="priority",
        )

        filled_slots = {}
        missing = context.get_ordered_missing_slots(filled_slots)

        assert len(missing) == 2
        assert missing[0]["slot_key"] == "grade"
        assert missing[1]["slot_key"] == "course_type"

    def test_get_completion_ratio(self):
        """Completion ratio is filled-required / total-required."""
        context = SceneSlotContext(
            scene_key="test_scene",
            scene_name="测试场景",
            required_slots=[
                SlotInfo(slot_key="course_type", type="string", required=True),
                SlotInfo(slot_key="grade", type="string", required=True),
            ],
            optional_slots=[],
            completion_threshold=0.5,
        )

        filled_slots = {"course_type": "数学"}
        ratio = context.get_completion_ratio(filled_slots)

        # One of two required slots filled.
        assert ratio == 0.5

    def test_is_complete(self):
        """is_complete honors the completion threshold (1.0 = all required)."""
        context = SceneSlotContext(
            scene_key="test_scene",
            scene_name="测试场景",
            required_slots=[
                SlotInfo(slot_key="course_type", type="string", required=True),
                SlotInfo(slot_key="grade", type="string", required=True),
            ],
            optional_slots=[],
            completion_threshold=1.0,
        )

        filled_slots = {"course_type": "数学", "grade": "高一"}
        is_complete = context.is_complete(filled_slots)

        assert is_complete is True

        filled_slots_partial = {"course_type": "数学"}
        is_complete_partial = context.is_complete(filled_slots_partial)

        assert is_complete_partial is False


class TestSceneSlotBundleLoader:
    """Test cases for SceneSlotBundleLoader."""

    @pytest.mark.asyncio
    async def test_load_scene_context(self, mock_session, sample_bundle, sample_slot_definitions):
        """Bundle plus slot definitions are assembled into a scene context."""
        loader = SceneSlotBundleLoader(mock_session)

        with patch.object(loader._bundle_service, 'get_active_bundle_by_scene', new_callable=AsyncMock) as mock_get_bundle:
            mock_get_bundle.return_value = sample_bundle

            with patch.object(loader._slot_service, 'list_slot_definitions', new_callable=AsyncMock) as mock_get_slots:
                mock_get_slots.return_value = sample_slot_definitions

                context = await loader.load_scene_context("test_tenant", "open_consult")

                assert context is not None
                assert context.scene_key == "open_consult"
                assert len(context.required_slots) == 2
                assert len(context.optional_slots) == 1

    @pytest.mark.asyncio
    async def test_load_scene_context_not_found(self, mock_session):
        """Unknown scene yields None instead of raising."""
        loader = SceneSlotBundleLoader(mock_session)

        with patch.object(loader._bundle_service, 'get_active_bundle_by_scene', new_callable=AsyncMock) as mock_get:
            mock_get.return_value = None

            context = await loader.load_scene_context("test_tenant", "unknown_scene")

            assert context is None

    @pytest.mark.asyncio
    async def test_get_missing_slots_for_scene(self, mock_session, sample_bundle, sample_slot_definitions):
        """Missing slots are computed against the loaded scene context."""
        loader = SceneSlotBundleLoader(mock_session)

        with patch.object(loader, 'load_scene_context', new_callable=AsyncMock) as mock_load:
            mock_context = SceneSlotContext(
                scene_key="open_consult",
                scene_name="开放咨询",
                required_slots=[
                    SlotInfo(slot_key="course_type", type="string", required=True, ask_back_prompt="请问课程类型?"),
                    SlotInfo(slot_key="grade", type="string", required=True, ask_back_prompt="请问年级?"),
                ],
                optional_slots=[],
                slot_priority=["course_type", "grade"],
            )
            mock_load.return_value = mock_context

            filled_slots = {"course_type": "数学"}
            missing = await loader.get_missing_slots_for_scene(
                "test_tenant", "open_consult", filled_slots
            )

            assert len(missing) == 1
            assert missing[0]["slot_key"] == "grade"

    @pytest.mark.asyncio
    async def test_generate_ask_back_prompt_single(self, mock_session):
        """A single missing slot produces exactly its ask-back prompt."""
        loader = SceneSlotBundleLoader(mock_session)

        with patch.object(loader, 'load_scene_context', new_callable=AsyncMock) as mock_load:
            mock_context = SceneSlotContext(
                scene_key="open_consult",
                scene_name="开放咨询",
                required_slots=[
                    SlotInfo(slot_key="course_type", type="string", required=True, ask_back_prompt="请问您想咨询哪种课程?"),
                ],
                optional_slots=[],
                ask_back_order="priority",
            )
            mock_load.return_value = mock_context

            with patch.object(loader, 'get_missing_slots_for_scene', new_callable=AsyncMock) as mock_missing:
                mock_missing.return_value = [
                    {"slot_key": "course_type", "ask_back_prompt": "请问您想咨询哪种课程?"}
                ]

                prompt = await loader.generate_ask_back_prompt(
                    "test_tenant", "open_consult", {}
                )

                assert prompt == "请问您想咨询哪种课程?"

    @pytest.mark.asyncio
    async def test_generate_ask_back_prompt_parallel(self, mock_session):
        """The parallel strategy combines all missing-slot prompts into one."""
        loader = SceneSlotBundleLoader(mock_session)

        with patch.object(loader, 'load_scene_context', new_callable=AsyncMock) as mock_load:
            mock_context = SceneSlotContext(
                scene_key="open_consult",
                scene_name="开放咨询",
                required_slots=[
                    SlotInfo(slot_key="course_type", type="string", required=True, ask_back_prompt="课程类型"),
                    SlotInfo(slot_key="grade", type="string", required=True, ask_back_prompt="年级"),
                ],
                optional_slots=[],
                ask_back_order="parallel",
            )
            mock_load.return_value = mock_context

            with patch.object(loader, 'get_missing_slots_for_scene', new_callable=AsyncMock) as mock_missing:
                mock_missing.return_value = [
                    {"slot_key": "course_type", "ask_back_prompt": "课程类型"},
                    {"slot_key": "grade", "ask_back_prompt": "年级"},
                ]

                prompt = await loader.generate_ask_back_prompt(
                    "test_tenant", "open_consult", {}
                )

                # Both prompts appear in the combined ask-back text.
                assert "课程类型" in prompt
                assert "年级" in prompt
diff --git a/ai-service/tests/test_scene_slot_bundle_service.py b/ai-service/tests/test_scene_slot_bundle_service.py new file mode 100644 index 0000000..61077aa --- /dev/null +++ b/ai-service/tests/test_scene_slot_bundle_service.py @@ -0,0 +1,284 @@
"""
Tests for Scene Slot Bundle Service.
[AC-SCENE-SLOT-01] Scene-slot mapping configuration service tests.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

from app.models.entities import (
    SceneSlotBundle,
    SceneSlotBundleCreate,
    SceneSlotBundleUpdate,
    SceneSlotBundleStatus,
    SlotDefinition,
)
from app.services.scene_slot_bundle_service import SceneSlotBundleService


@pytest.fixture
def mock_session():
    """Mock database session."""
    session = AsyncMock()
    session.execute = AsyncMock()
    session.add = MagicMock()
    session.flush = AsyncMock()
    session.delete = AsyncMock()
    return session


@pytest.fixture
def sample_slot_definition():
    """Sample slot definition for testing."""
    return SlotDefinition(
        id=uuid4(),
        tenant_id="test_tenant",
        slot_key="course_type",
        type="string",
        required=True,
        ask_back_prompt="请问您想咨询哪种类型的课程?",
    )


@pytest.fixture
def sample_bundle():
    """Sample active scene slot bundle for testing."""
    return SceneSlotBundle(
        id=uuid4(),
        tenant_id="test_tenant",
        scene_key="open_consult",
        scene_name="开放咨询",
        description="开放咨询场景的槽位配置",
        required_slots=["course_type", "grade"],
        optional_slots=["region"],
        slot_priority=["course_type", "grade", "region"],
        completion_threshold=1.0,
        ask_back_order="priority",
        status=SceneSlotBundleStatus.ACTIVE.value,
        version=1,
    )


class TestSceneSlotBundleService:
    """Test cases for SceneSlotBundleService."""

    @pytest.mark.asyncio
    async def test_list_bundles(self, mock_session, sample_bundle):
        """Listing returns the bundles produced by the query."""
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [sample_bundle]
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)
        bundles = await service.list_bundles("test_tenant")

        assert len(bundles) == 1
        assert bundles[0].scene_key == "open_consult"

    @pytest.mark.asyncio
    async def test_list_bundles_with_status_filter(self, mock_session, sample_bundle):
        """Listing with a status filter still returns matching bundles."""
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [sample_bundle]
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)
        bundles = await service.list_bundles("test_tenant", status="active")

        assert len(bundles) == 1

    @pytest.mark.asyncio
    async def test_get_bundle_by_id(self, mock_session, sample_bundle):
        """Lookup by primary key returns the single matching bundle."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)
        bundle = await service.get_bundle("test_tenant", str(sample_bundle.id))

        assert bundle is not None
        assert bundle.scene_key == "open_consult"

    @pytest.mark.asyncio
    async def test_get_bundle_by_scene_key(self, mock_session, sample_bundle):
        """Lookup by scene key returns the matching bundle."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)
        bundle = await service.get_bundle_by_scene_key("test_tenant", "open_consult")

        assert bundle is not None
        assert bundle.scene_key == "open_consult"

    @pytest.mark.asyncio
    async def test_get_active_bundle_by_scene(self, mock_session, sample_bundle):
        """Active lookup returns a bundle whose status is ACTIVE."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)
        bundle = await service.get_active_bundle_by_scene("test_tenant", "open_consult")

        assert bundle is not None
        assert bundle.status == SceneSlotBundleStatus.ACTIVE.value

    @pytest.mark.asyncio
    async def test_create_bundle_success(self, mock_session, sample_slot_definition):
        """Creation succeeds when slot keys validate and the scene key is free."""
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [sample_slot_definition]
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)

        with patch.object(service, '_validate_slot_keys', new_callable=AsyncMock) as mock_validate:
            mock_validate.return_value = {"course_type"}

            with patch.object(service, 'get_bundle_by_scene_key', new_callable=AsyncMock) as mock_get:
                mock_get.return_value = None  # scene key not taken

                bundle_create = SceneSlotBundleCreate(
                    scene_key="new_scene",
                    scene_name="新场景",
                    required_slots=["course_type"],
                    optional_slots=[],
                )

                bundle = await service.create_bundle("test_tenant", bundle_create)

                assert bundle is not None
                assert bundle.scene_key == "new_scene"
                # The new entity was staged on the session.
                mock_session.add.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_bundle_duplicate_scene_key(self, mock_session, sample_bundle):
        """Creating a bundle whose scene key already exists raises ValueError."""
        service = SceneSlotBundleService(mock_session)

        with patch.object(service, 'get_bundle_by_scene_key', new_callable=AsyncMock) as mock_get:
            mock_get.return_value = sample_bundle  # conflict

            with patch.object(service, '_validate_bundle_data', new_callable=AsyncMock) as mock_validate:
                mock_validate.return_value = []

                bundle_create = SceneSlotBundleCreate(
                    scene_key="open_consult",
                    scene_name="开放咨询",
                )

                with pytest.raises(ValueError, match="已存在"):
                    await service.create_bundle("test_tenant", bundle_create)

    @pytest.mark.asyncio
    async def test_create_bundle_validation_error(self, mock_session):
        """Validation errors abort creation with a ValueError."""
        service = SceneSlotBundleService(mock_session)

        with patch.object(service, '_validate_bundle_data', new_callable=AsyncMock) as mock_validate:
            mock_validate.return_value = ["必填和可选槽位存在交叉"]

            bundle_create = SceneSlotBundleCreate(
                scene_key="new_scene",
                scene_name="新场景",
                required_slots=["course_type"],
                optional_slots=["course_type"],
            )

            with pytest.raises(ValueError, match="交叉"):
                await service.create_bundle("test_tenant", bundle_create)

    @pytest.mark.asyncio
    async def test_update_bundle_success(self, mock_session, sample_bundle):
        """Updates apply the changed fields and bump the version."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)

        with patch.object(service, '_validate_bundle_data', new_callable=AsyncMock) as mock_validate:
            mock_validate.return_value = []

            bundle_update = SceneSlotBundleUpdate(
                scene_name="更新后的场景名称",
            )

            bundle = await service.update_bundle("test_tenant", str(sample_bundle.id), bundle_update)

            assert bundle is not None
            assert bundle.scene_name == "更新后的场景名称"
            assert bundle.version == 2  # bumped from 1

    @pytest.mark.asyncio
    async def test_delete_bundle_success(self, mock_session, sample_bundle):
        """Deleting an existing bundle removes it and returns True."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = sample_bundle
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)

        success = await service.delete_bundle("test_tenant", str(sample_bundle.id))

        assert success is True
        mock_session.delete.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_bundle_not_found(self, mock_session):
        """Deleting a non-existent bundle returns False without raising."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)

        success = await service.delete_bundle("test_tenant", str(uuid4()))

        assert success is False


class TestSceneSlotBundleValidation:
    """Test cases for bundle validation."""

    @pytest.mark.asyncio
    async def test_validate_required_optional_overlap(self, mock_session, sample_slot_definition):
        """Overlap between required and optional slots is reported."""
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [sample_slot_definition]
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)

        errors = await service._validate_bundle_data(
            tenant_id="test_tenant",
            required_slots=["course_type"],
            optional_slots=["course_type"],  # overlaps with required
            slot_priority=None,
        )

        assert len(errors) > 0
        assert any("交叉" in e for e in errors)

    @pytest.mark.asyncio
    async def test_validate_priority_unknown_slots(self, mock_session, sample_slot_definition):
        """Priority entries not defined as slots are reported."""
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [sample_slot_definition]
        mock_session.execute.return_value = mock_result

        service = SceneSlotBundleService(mock_session)

        errors = await service._validate_bundle_data(
            tenant_id="test_tenant",
            required_slots=["course_type"],
            optional_slots=[],
            slot_priority=["course_type", "unknown_slot"],  # unknown_slot is undefined
        )

        assert len(errors) > 0
        assert any("未定义" in e for e in errors)
diff --git a/ai-service/tests/test_semantic_matcher.py b/ai-service/tests/test_semantic_matcher.py new file mode 100644 index 0000000..241ddea --- /dev/null +++ b/ai-service/tests/test_semantic_matcher.py @@ -0,0 +1,210 @@
"""
Unit tests for SemanticMatcher.
[AC-AISVC-113, AC-AISVC-114] Tests for semantic matching.
+""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +import uuid + +from app.services.intent.semantic_matcher import SemanticMatcher +from app.services.intent.models import ( + FusionConfig, + SemanticCandidate, + SemanticMatchResult, +) + + +@pytest.fixture +def mock_embedding_provider(): + """Create a mock embedding provider.""" + provider = AsyncMock() + provider.embed = AsyncMock(return_value=[0.1] * 768) + provider.embed_batch = AsyncMock(return_value=[[0.1] * 768, [0.2] * 768]) + return provider + + +@pytest.fixture +def mock_rule(): + """Create a mock intent rule with semantic config.""" + rule = MagicMock() + rule.id = uuid.uuid4() + rule.name = "Test Intent" + rule.intent_vector = [0.1] * 768 + rule.semantic_examples = None + rule.is_enabled = True + return rule + + +@pytest.fixture +def mock_rule_with_examples(): + """Create a mock intent rule with semantic examples.""" + rule = MagicMock() + rule.id = uuid.uuid4() + rule.name = "Test Intent with Examples" + rule.intent_vector = None + rule.semantic_examples = ["我想退货", "如何退款"] + rule.is_enabled = True + return rule + + +@pytest.fixture +def config(): + """Create a fusion config.""" + return FusionConfig() + + +class TestSemanticMatcher: + """Tests for SemanticMatcher class.""" + + @pytest.mark.asyncio + async def test_init(self, mock_embedding_provider, config): + """Test SemanticMatcher initialization.""" + matcher = SemanticMatcher(mock_embedding_provider, config) + assert matcher._embedding_provider == mock_embedding_provider + assert matcher._config == config + + @pytest.mark.asyncio + async def test_match_disabled(self, mock_embedding_provider): + """Test match when semantic matcher is disabled.""" + config = FusionConfig(semantic_matcher_enabled=False) + matcher = SemanticMatcher(mock_embedding_provider, config) + + result = await matcher.match("test message", [], "tenant-1") + + assert result.skipped is True + assert result.skip_reason == "disabled" + assert 
result.candidates == [] + + @pytest.mark.asyncio + async def test_match_no_semantic_config( + self, mock_embedding_provider, config, mock_rule + ): + """Test match when no rules have semantic config.""" + mock_rule.intent_vector = None + mock_rule.semantic_examples = None + + matcher = SemanticMatcher(mock_embedding_provider, config) + result = await matcher.match("test message", [mock_rule], "tenant-1") + + assert result.skipped is True + assert result.skip_reason == "no_semantic_config" + + @pytest.mark.asyncio + async def test_match_mode_a_with_intent_vector( + self, mock_embedding_provider, config, mock_rule + ): + """Test match with pre-computed intent vector (Mode A).""" + mock_embedding_provider.embed.return_value = [0.1] * 768 + + matcher = SemanticMatcher(mock_embedding_provider, config) + result = await matcher.match("我想退货", [mock_rule], "tenant-1") + + assert result.skipped is False + assert result.skip_reason is None + assert len(result.candidates) == 1 + assert result.top_score > 0.9 + assert result.duration_ms >= 0 + + @pytest.mark.asyncio + async def test_match_mode_b_with_examples( + self, mock_embedding_provider, config, mock_rule_with_examples + ): + """Test match with semantic examples (Mode B).""" + mock_embedding_provider.embed.return_value = [0.1] * 768 + mock_embedding_provider.embed_batch.return_value = [[0.1] * 768, [0.1] * 768] + + matcher = SemanticMatcher(mock_embedding_provider, config) + result = await matcher.match("我想退货", [mock_rule_with_examples], "tenant-1") + + assert result.skipped is False + assert len(result.candidates) == 1 + assert result.top_score > 0.9 + + @pytest.mark.asyncio + async def test_match_embedding_timeout(self, mock_embedding_provider, config, mock_rule): + """Test match when embedding times out.""" + import asyncio + mock_embedding_provider.embed.side_effect = asyncio.TimeoutError() + + config = FusionConfig(semantic_matcher_timeout_ms=100) + matcher = SemanticMatcher(mock_embedding_provider, config) + result = 
await matcher.match("test message", [mock_rule], "tenant-1") + + assert result.skipped is True + assert "embedding_timeout" in result.skip_reason + + @pytest.mark.asyncio + async def test_match_embedding_error(self, mock_embedding_provider, config, mock_rule): + """Test match when embedding fails with error.""" + mock_embedding_provider.embed.side_effect = Exception("Embedding failed") + + matcher = SemanticMatcher(mock_embedding_provider, config) + result = await matcher.match("test message", [mock_rule], "tenant-1") + + assert result.skipped is True + assert "embedding_error" in result.skip_reason + + @pytest.mark.asyncio + async def test_match_top_k_limit(self, mock_embedding_provider, config): + """Test that match returns only top_k candidates.""" + rules = [] + for i in range(5): + rule = MagicMock() + rule.id = uuid.uuid4() + rule.name = f"Intent {i}" + rule.intent_vector = [0.1 + i * 0.01] * 768 + rule.semantic_examples = None + rule.is_enabled = True + rules.append(rule) + + mock_embedding_provider.embed.return_value = [0.1] * 768 + + config = FusionConfig(semantic_top_k=3) + matcher = SemanticMatcher(mock_embedding_provider, config) + result = await matcher.match("test message", rules, "tenant-1") + + assert len(result.candidates) <= 3 + + def test_cosine_similarity(self, mock_embedding_provider, config): + """Test cosine similarity calculation.""" + matcher = SemanticMatcher(mock_embedding_provider, config) + + v1 = [1.0, 0.0, 0.0] + v2 = [1.0, 0.0, 0.0] + similarity = matcher._cosine_similarity(v1, v2) + assert similarity == 1.0 + + v1 = [1.0, 0.0, 0.0] + v2 = [0.0, 1.0, 0.0] + similarity = matcher._cosine_similarity(v1, v2) + assert similarity == 0.0 + + v1 = [1.0, 1.0, 0.0] + v2 = [1.0, 0.0, 0.0] + similarity = matcher._cosine_similarity(v1, v2) + assert 0.0 < similarity < 1.0 + + def test_cosine_similarity_empty_vectors(self, mock_embedding_provider, config): + """Test cosine similarity with empty vectors.""" + matcher = 
SemanticMatcher(mock_embedding_provider, config) + + assert matcher._cosine_similarity([], [1.0]) == 0.0 + assert matcher._cosine_similarity([1.0], []) == 0.0 + assert matcher._cosine_similarity([], []) == 0.0 + + def test_has_semantic_config(self, mock_embedding_provider, config, mock_rule): + """Test checking if rule has semantic config.""" + matcher = SemanticMatcher(mock_embedding_provider, config) + + mock_rule.intent_vector = [0.1] * 768 + mock_rule.semantic_examples = None + assert matcher._has_semantic_config(mock_rule) is True + + mock_rule.intent_vector = None + mock_rule.semantic_examples = ["example"] + assert matcher._has_semantic_config(mock_rule) is True + + mock_rule.intent_vector = None + mock_rule.semantic_examples = None + assert matcher._has_semantic_config(mock_rule) is False diff --git a/ai-service/tests/test_slot_backfill_service.py b/ai-service/tests/test_slot_backfill_service.py new file mode 100644 index 0000000..b445779 --- /dev/null +++ b/ai-service/tests/test_slot_backfill_service.py @@ -0,0 +1,419 @@ +""" +Tests for Slot Backfill Service. 
+[AC-MRS-SLOT-BACKFILL-01] 槽位值回填确认测试 +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.models.mid.schemas import SlotSource +from app.services.mid.slot_backfill_service import ( + BackfillResult, + BackfillStatus, + BatchBackfillResult, + SlotBackfillService, + create_slot_backfill_service, +) +from app.services.mid.slot_manager import SlotWriteResult +from app.services.mid.slot_strategy_executor import ( + StrategyChainResult, + StrategyStepResult, +) + + +class TestBackfillResult: + """BackfillResult 测试""" + + def test_is_success(self): + """测试成功判断""" + result = BackfillResult(status=BackfillStatus.SUCCESS) + assert result.is_success() is True + + result = BackfillResult(status=BackfillStatus.VALIDATION_FAILED) + assert result.is_success() is False + + def test_needs_ask_back(self): + """测试需要追问判断""" + result = BackfillResult(status=BackfillStatus.VALIDATION_FAILED) + assert result.needs_ask_back() is True + + result = BackfillResult(status=BackfillStatus.EXTRACTION_FAILED) + assert result.needs_ask_back() is True + + result = BackfillResult(status=BackfillStatus.SUCCESS) + assert result.needs_ask_back() is False + + def test_needs_confirmation(self): + """测试需要确认判断""" + result = BackfillResult(status=BackfillStatus.NEEDS_CONFIRMATION) + assert result.needs_confirmation() is True + + result = BackfillResult(status=BackfillStatus.SUCCESS) + assert result.needs_confirmation() is False + + def test_to_dict(self): + """测试转换为字典""" + result = BackfillResult( + status=BackfillStatus.SUCCESS, + slot_key="region", + value="北京", + normalized_value="北京", + source="user_confirmed", + confidence=1.0, + ) + d = result.to_dict() + assert d["status"] == "success" + assert d["slot_key"] == "region" + assert d["value"] == "北京" + assert d["source"] == "user_confirmed" + + +class TestBatchBackfillResult: + """BatchBackfillResult 测试""" + + def test_add_result(self): + """测试添加结果""" + batch = BatchBackfillResult() + + 
batch.add_result(BackfillResult(status=BackfillStatus.SUCCESS, slot_key="region")) + batch.add_result(BackfillResult(status=BackfillStatus.VALIDATION_FAILED, slot_key="product")) + batch.add_result(BackfillResult(status=BackfillStatus.NEEDS_CONFIRMATION, slot_key="grade")) + + assert batch.success_count == 1 + assert batch.failed_count == 1 + assert batch.confirmation_needed_count == 1 + + def test_get_ask_back_prompts(self): + """测试获取追问提示""" + batch = BatchBackfillResult() + + batch.add_result(BackfillResult( + status=BackfillStatus.VALIDATION_FAILED, + ask_back_prompt="请重新输入", + )) + batch.add_result(BackfillResult( + status=BackfillStatus.SUCCESS, + )) + batch.add_result(BackfillResult( + status=BackfillStatus.EXTRACTION_FAILED, + ask_back_prompt="无法识别,请重试", + )) + + prompts = batch.get_ask_back_prompts() + assert len(prompts) == 2 + assert "请重新输入" in prompts + assert "无法识别,请重试" in prompts + + def test_get_confirmation_prompts(self): + """测试获取确认提示""" + batch = BatchBackfillResult() + + batch.add_result(BackfillResult( + status=BackfillStatus.NEEDS_CONFIRMATION, + confirmation_prompt="我理解您说的是「北京」,对吗?", + )) + batch.add_result(BackfillResult( + status=BackfillStatus.SUCCESS, + )) + + prompts = batch.get_confirmation_prompts() + assert len(prompts) == 1 + assert "北京" in prompts[0] + + +class TestSlotBackfillService: + """SlotBackfillService 测试""" + + @pytest.fixture + def mock_session(self): + """创建 mock session""" + return AsyncMock() + + @pytest.fixture + def mock_slot_manager(self): + """创建 mock slot manager""" + manager = MagicMock() + manager.write_slot = AsyncMock() + manager.get_ask_back_prompt = AsyncMock(return_value="请提供信息") + return manager + + @pytest.fixture + def service(self, mock_session, mock_slot_manager): + """创建服务实例""" + return SlotBackfillService( + session=mock_session, + tenant_id="tenant_1", + session_id="session_1", + slot_manager=mock_slot_manager, + ) + + def test_confidence_thresholds(self, service): + """测试置信度阈值""" + assert 
service.CONFIDENCE_THRESHOLD_LOW == 0.5 + assert service.CONFIDENCE_THRESHOLD_HIGH == 0.8 + + def test_get_source_for_strategy(self, service): + """测试策略到来源的映射""" + assert service._get_source_for_strategy("rule") == SlotSource.RULE_EXTRACTED.value + assert service._get_source_for_strategy("llm") == SlotSource.LLM_INFERRED.value + assert service._get_source_for_strategy("user_input") == SlotSource.USER_CONFIRMED.value + assert service._get_source_for_strategy("unknown") == "unknown" + + def test_get_confidence_for_strategy(self, service): + """测试来源到置信度的映射""" + assert service._get_confidence_for_strategy(SlotSource.USER_CONFIRMED.value) == 1.0 + assert service._get_confidence_for_strategy(SlotSource.RULE_EXTRACTED.value) == 0.9 + assert service._get_confidence_for_strategy(SlotSource.LLM_INFERRED.value) == 0.7 + assert service._get_confidence_for_strategy("context") == 0.5 + assert service._get_confidence_for_strategy(SlotSource.DEFAULT.value) == 0.3 + + def test_generate_confirmation_prompt(self, service): + """测试生成确认提示""" + prompt = service._generate_confirmation_prompt("region", "北京") + assert "北京" in prompt + assert "对吗" in prompt + + @pytest.mark.asyncio + async def test_backfill_single_slot_success(self, service, mock_slot_manager): + """测试单个槽位回填成功""" + mock_slot_manager.write_slot.return_value = SlotWriteResult( + success=True, + slot_key="region", + value="北京", + ) + + with patch.object(service, '_get_state_aggregator') as mock_agg: + mock_aggregator = AsyncMock() + mock_aggregator.update_slot = AsyncMock() + mock_agg.return_value = mock_aggregator + + result = await service.backfill_single_slot( + slot_key="region", + candidate_value="北京", + source="user_confirmed", + confidence=1.0, + ) + + assert result.status == BackfillStatus.SUCCESS + assert result.slot_key == "region" + assert result.normalized_value == "北京" + + @pytest.mark.asyncio + async def test_backfill_single_slot_validation_failed(self, service, mock_slot_manager): + """测试单个槽位回填校验失败""" + from 
app.services.mid.slot_validation_service import SlotValidationError + + mock_slot_manager.write_slot.return_value = SlotWriteResult( + success=False, + slot_key="region", + error=SlotValidationError( + slot_key="region", + error_code="INVALID_VALUE", + error_message="无效的地区", + ), + ask_back_prompt="请提供有效的地区", + ) + + result = await service.backfill_single_slot( + slot_key="region", + candidate_value="无效地区", + source="user_confirmed", + confidence=1.0, + ) + + assert result.status == BackfillStatus.VALIDATION_FAILED + assert result.ask_back_prompt == "请提供有效的地区" + + @pytest.mark.asyncio + async def test_backfill_single_slot_low_confidence(self, service, mock_slot_manager): + """测试低置信度槽位需要确认""" + mock_slot_manager.write_slot.return_value = SlotWriteResult( + success=True, + slot_key="region", + value="北京", + ) + + with patch.object(service, '_get_state_aggregator') as mock_agg: + mock_aggregator = AsyncMock() + mock_aggregator.update_slot = AsyncMock() + mock_agg.return_value = mock_aggregator + + result = await service.backfill_single_slot( + slot_key="region", + candidate_value="北京", + source="llm_inferred", + confidence=0.4, + ) + + assert result.status == BackfillStatus.NEEDS_CONFIRMATION + assert result.confirmation_prompt is not None + assert "北京" in result.confirmation_prompt + + @pytest.mark.asyncio + async def test_backfill_multiple_slots(self, service, mock_slot_manager): + """测试批量回填槽位""" + mock_slot_manager.write_slot.side_effect = [ + SlotWriteResult(success=True, slot_key="region", value="北京"), + SlotWriteResult(success=True, slot_key="product", value="手机"), + SlotWriteResult(success=False, slot_key="grade", error=MagicMock()), + ] + + with patch.object(service, '_get_state_aggregator') as mock_agg: + mock_aggregator = AsyncMock() + mock_aggregator.update_slot = AsyncMock() + mock_agg.return_value = mock_aggregator + + result = await service.backfill_multiple_slots( + candidates={ + "region": "北京", + "product": "手机", + "grade": "无效等级", + }, + 
source="user_confirmed", + ) + + assert result.success_count == 2 + assert result.failed_count == 1 + + @pytest.mark.asyncio + async def test_confirm_low_confidence_slot_confirmed(self, service): + """测试确认低置信度槽位 - 用户确认""" + with patch.object(service, '_get_state_aggregator') as mock_agg: + mock_aggregator = AsyncMock() + mock_aggregator.update_slot = AsyncMock() + mock_agg.return_value = mock_aggregator + + result = await service.confirm_low_confidence_slot( + slot_key="region", + confirmed=True, + ) + + assert result.status == BackfillStatus.SUCCESS + assert result.source == SlotSource.USER_CONFIRMED.value + assert result.confidence == 1.0 + + @pytest.mark.asyncio + async def test_confirm_low_confidence_slot_rejected(self, service, mock_slot_manager): + """测试确认低置信度槽位 - 用户拒绝""" + with patch.object(service, '_get_state_aggregator') as mock_agg: + mock_aggregator = AsyncMock() + mock_aggregator.clear_slot = AsyncMock() + mock_agg.return_value = mock_aggregator + + result = await service.confirm_low_confidence_slot( + slot_key="region", + confirmed=False, + ) + + assert result.status == BackfillStatus.VALIDATION_FAILED + assert result.ask_back_prompt is not None + + +class TestCreateSlotBackfillService: + """create_slot_backfill_service 工厂函数测试""" + + def test_create(self): + """测试创建服务实例""" + mock_session = AsyncMock() + service = create_slot_backfill_service( + session=mock_session, + tenant_id="tenant_1", + session_id="session_1", + ) + assert isinstance(service, SlotBackfillService) + assert service._tenant_id == "tenant_1" + assert service._session_id == "session_1" + + +class TestBackfillFromUserResponse: + """从用户回复回填测试""" + + @pytest.fixture + def service(self): + """创建服务实例""" + mock_session = AsyncMock() + mock_slot_def_service = AsyncMock() + + service = SlotBackfillService( + session=mock_session, + tenant_id="tenant_1", + session_id="session_1", + ) + service._slot_def_service = mock_slot_def_service + return service + + @pytest.mark.asyncio + async def 
test_backfill_from_user_response_success(self, service): + """测试从用户回复成功提取并回填""" + mock_slot_def = MagicMock() + mock_slot_def.type = "string" + mock_slot_def.validation_rule = None + mock_slot_def.ask_back_prompt = "请提供地区" + + service._slot_def_service.get_slot_definition_by_key = AsyncMock( + return_value=mock_slot_def + ) + + with patch.object(service, '_extract_value') as mock_extract: + mock_extract.return_value = StrategyChainResult( + slot_key="region", + success=True, + final_value="北京", + final_strategy="rule", + ) + + with patch.object(service, 'backfill_single_slot') as mock_backfill: + mock_backfill.return_value = BackfillResult( + status=BackfillStatus.SUCCESS, + slot_key="region", + value="北京", + ) + + result = await service.backfill_from_user_response( + user_response="我想查询北京的产品", + expected_slots=["region"], + ) + + assert result.success_count == 1 + + @pytest.mark.asyncio + async def test_backfill_from_user_response_no_definition(self, service): + """测试槽位定义不存在""" + service._slot_def_service.get_slot_definition_by_key = AsyncMock( + return_value=None + ) + + result = await service.backfill_from_user_response( + user_response="我想查询北京的产品", + expected_slots=["unknown_slot"], + ) + + assert result.success_count == 0 + assert result.failed_count == 0 + + @pytest.mark.asyncio + async def test_backfill_from_user_response_extraction_failed(self, service): + """测试提取失败""" + mock_slot_def = MagicMock() + mock_slot_def.type = "string" + mock_slot_def.validation_rule = None + mock_slot_def.ask_back_prompt = "请提供地区" + + service._slot_def_service.get_slot_definition_by_key = AsyncMock( + return_value=mock_slot_def + ) + + with patch.object(service, '_extract_value') as mock_extract: + mock_extract.return_value = StrategyChainResult( + slot_key="region", + success=False, + ) + + result = await service.backfill_from_user_response( + user_response="我想查询产品", + expected_slots=["region"], + ) + + assert result.failed_count == 1 + assert result.results[0].status == 
BackfillStatus.EXTRACTION_FAILED diff --git a/ai-service/tests/test_slot_extraction_integration.py b/ai-service/tests/test_slot_extraction_integration.py new file mode 100644 index 0000000..2ac17c1 --- /dev/null +++ b/ai-service/tests/test_slot_extraction_integration.py @@ -0,0 +1,335 @@ +""" +Tests for Slot Extraction Integration. +[AC-MRS-SLOT-EXTRACT-01] slot extraction 集成测试 +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.models.mid.schemas import SlotSource +from app.services.mid.slot_extraction_integration import ( + ExtractionResult, + ExtractionTrace, + SlotExtractionIntegration, + integrate_slot_extraction, +) +from app.services.mid.slot_strategy_executor import ( + StrategyChainResult, + StrategyStepResult, +) + + +class TestExtractionTrace: + """ExtractionTrace 测试""" + + def test_init(self): + """测试初始化""" + trace = ExtractionTrace(slot_key="region") + assert trace.slot_key == "region" + assert trace.strategy is None + assert trace.validation_passed is False + + def test_to_dict(self): + """测试转换为字典""" + trace = ExtractionTrace( + slot_key="region", + strategy="rule", + extracted_value="北京", + validation_passed=True, + final_value="北京", + execution_time_ms=10.5, + ) + d = trace.to_dict() + assert d["slot_key"] == "region" + assert d["strategy"] == "rule" + assert d["extracted_value"] == "北京" + assert d["validation_passed"] is True + + +class TestExtractionResult: + """ExtractionResult 测试""" + + def test_init(self): + """测试初始化""" + result = ExtractionResult() + assert result.success is False + assert result.extracted_slots == {} + assert result.failed_slots == [] + + def test_to_dict(self): + """测试转换为字典""" + result = ExtractionResult( + success=True, + extracted_slots={"region": "北京"}, + failed_slots=["product"], + traces=[ExtractionTrace(slot_key="region")], + total_execution_time_ms=50.0, + ask_back_triggered=True, + ask_back_prompts=["请提供产品信息"], + ) + d = result.to_dict() + assert d["success"] is True + assert 
d["extracted_slots"] == {"region": "北京"} + assert d["failed_slots"] == ["product"] + assert d["ask_back_triggered"] is True + + +class TestSlotExtractionIntegration: + """SlotExtractionIntegration 测试""" + + @pytest.fixture + def mock_session(self): + """创建 mock session""" + return AsyncMock() + + @pytest.fixture + def integration(self, mock_session): + """创建集成实例""" + return SlotExtractionIntegration( + session=mock_session, + tenant_id="tenant_1", + session_id="session_1", + ) + + def test_default_strategies(self, integration): + """测试默认策略""" + assert integration.DEFAULT_STRATEGIES == ["rule", "llm"] + + def test_get_source_for_strategy(self, integration): + """测试策略到来源的映射""" + assert integration._get_source_for_strategy("rule") == SlotSource.RULE_EXTRACTED.value + assert integration._get_source_for_strategy("llm") == SlotSource.LLM_INFERRED.value + assert integration._get_source_for_strategy("user_input") == SlotSource.USER_CONFIRMED.value + + def test_get_confidence_for_source(self, integration): + """测试来源到置信度的映射""" + assert integration._get_confidence_for_source(SlotSource.USER_CONFIRMED.value) == 1.0 + assert integration._get_confidence_for_source(SlotSource.RULE_EXTRACTED.value) == 0.9 + assert integration._get_confidence_for_source(SlotSource.LLM_INFERRED.value) == 0.7 + + @pytest.mark.asyncio + async def test_extract_and_fill_no_target_slots(self, integration): + """测试没有目标槽位""" + result = await integration.extract_and_fill( + user_input="测试输入", + target_slots=[], + ) + + assert result.success is True + assert result.extracted_slots == {} + + @pytest.mark.asyncio + async def test_extract_and_fill_slot_not_found(self, integration): + """测试槽位定义不存在""" + integration._slot_def_service.get_slot_definition_by_key = AsyncMock( + return_value=None + ) + integration._slot_manager.get_ask_back_prompt = AsyncMock(return_value=None) + + result = await integration.extract_and_fill( + user_input="测试输入", + target_slots=["unknown_slot"], + ) + + assert result.success is False 
+ assert "unknown_slot" in result.failed_slots + assert result.traces[0].failure_reason == "Slot definition not found" + + @pytest.mark.asyncio + async def test_extract_and_fill_extraction_success(self, integration): + """测试提取成功""" + mock_slot_def = MagicMock() + mock_slot_def.type = "string" + mock_slot_def.validation_rule = None + mock_slot_def.ask_back_prompt = "请提供地区" + + integration._slot_def_service.get_slot_definition_by_key = AsyncMock( + return_value=mock_slot_def + ) + + with patch.object(integration._strategy_executor, 'execute_chain') as mock_chain: + mock_chain.return_value = StrategyChainResult( + slot_key="region", + success=True, + final_value="北京", + final_strategy="rule", + ) + + with patch.object(integration, '_get_backfill_service') as mock_backfill_svc: + mock_backfill = AsyncMock() + mock_backfill.backfill_single_slot = AsyncMock( + return_value=MagicMock( + is_success=lambda: True, + normalized_value="北京", + error_message=None, + ) + ) + mock_backfill_svc.return_value = mock_backfill + + with patch.object(integration, '_save_extracted_slots', new_callable=AsyncMock): + result = await integration.extract_and_fill( + user_input="我想查询北京的产品", + target_slots=["region"], + ) + + assert result.success is True + assert "region" in result.extracted_slots + assert result.extracted_slots["region"] == "北京" + + @pytest.mark.asyncio + async def test_extract_and_fill_extraction_failed(self, integration): + """测试提取失败""" + mock_slot_def = MagicMock() + mock_slot_def.type = "string" + mock_slot_def.validation_rule = None + mock_slot_def.ask_back_prompt = "请提供地区" + + integration._slot_def_service.get_slot_definition_by_key = AsyncMock( + return_value=mock_slot_def + ) + + with patch.object(integration._strategy_executor, 'execute_chain') as mock_chain: + mock_chain.return_value = StrategyChainResult( + slot_key="region", + success=False, + steps=[ + StrategyStepResult( + strategy="rule", + success=False, + failure_reason="无法提取", + ) + ], + ) + + 
integration._slot_manager.get_ask_back_prompt = AsyncMock( + return_value="请提供地区" + ) + + result = await integration.extract_and_fill( + user_input="测试输入", + target_slots=["region"], + ) + + assert result.success is False + assert "region" in result.failed_slots + assert result.ask_back_triggered is True + assert "请提供地区" in result.ask_back_prompts + + @pytest.mark.asyncio + async def test_get_missing_required_slots_from_state(self, integration): + """测试从状态获取缺失槽位""" + from app.services.mid.slot_state_aggregator import SlotState + + slot_state = SlotState() + slot_state.missing_required_slots = [ + {"slot_key": "region"}, + {"slot_key": "product"}, + ] + + result = await integration._get_missing_required_slots(slot_state) + + assert "region" in result + assert "product" in result + + @pytest.mark.asyncio + async def test_get_missing_required_slots_from_db(self, integration): + """测试从数据库获取缺失槽位""" + mock_defs = [ + MagicMock(slot_key="region"), + MagicMock(slot_key="product"), + ] + + integration._slot_def_service.list_slot_definitions = AsyncMock( + return_value=mock_defs + ) + + result = await integration._get_missing_required_slots(None) + + assert "region" in result + assert "product" in result + + +class TestIntegrateSlotExtraction: + """integrate_slot_extraction 便捷函数测试""" + + @pytest.mark.asyncio + async def test_integrate(self): + """测试便捷函数""" + mock_session = AsyncMock() + + with patch('app.services.mid.slot_extraction_integration.SlotExtractionIntegration') as mock_cls: + mock_instance = AsyncMock() + mock_instance.extract_and_fill = AsyncMock( + return_value=ExtractionResult(success=True, extracted_slots={"region": "北京"}) + ) + mock_cls.return_value = mock_instance + + result = await integrate_slot_extraction( + session=mock_session, + tenant_id="tenant_1", + session_id="session_1", + user_input="我想查询北京的产品", + ) + + assert result.success is True + assert "region" in result.extracted_slots + + +class TestExtractionTraceFlow: + """提取追踪流程测试""" + + 
@pytest.fixture + def integration(self): + """创建集成实例""" + mock_session = AsyncMock() + return SlotExtractionIntegration( + session=mock_session, + tenant_id="tenant_1", + session_id="session_1", + ) + + @pytest.mark.asyncio + async def test_full_extraction_flow(self, integration): + """测试完整提取流程""" + mock_slot_def = MagicMock() + mock_slot_def.type = "string" + mock_slot_def.validation_rule = None + mock_slot_def.ask_back_prompt = None + + integration._slot_def_service.get_slot_definition_by_key = AsyncMock( + return_value=mock_slot_def + ) + + with patch.object(integration._strategy_executor, 'execute_chain') as mock_chain: + mock_chain.return_value = StrategyChainResult( + slot_key="region", + success=True, + final_value="北京", + final_strategy="rule", + ) + + with patch.object(integration, '_get_backfill_service') as mock_backfill_svc: + mock_backfill = AsyncMock() + mock_backfill.backfill_single_slot = AsyncMock( + return_value=MagicMock( + is_success=lambda: True, + normalized_value="北京", + error_message=None, + ) + ) + mock_backfill_svc.return_value = mock_backfill + + with patch.object(integration, '_save_extracted_slots', new_callable=AsyncMock): + result = await integration.extract_and_fill( + user_input="北京", + target_slots=["region"], + ) + + assert len(result.traces) == 1 + trace = result.traces[0] + assert trace.slot_key == "region" + assert trace.strategy == "rule" + assert trace.extracted_value == "北京" + assert trace.validation_passed is True + assert trace.final_value == "北京" diff --git a/ai-service/tests/test_slot_state_aggregator.py b/ai-service/tests/test_slot_state_aggregator.py new file mode 100644 index 0000000..dad35dc --- /dev/null +++ b/ai-service/tests/test_slot_state_aggregator.py @@ -0,0 +1,256 @@ +""" +Tests for Slot State Aggregator. 
[AC-MRS-SLOT-META-01] Slot state aggregation service tests.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from app.models.mid.schemas import MemorySlot, SlotSource
from app.services.mid.slot_state_aggregator import (
    SlotState,
    SlotStateAggregator,
    create_slot_state_aggregator,
)


class TestSlotState:
    """Tests for the SlotState data class."""

    def test_slot_state_initialization(self):
        """A fresh SlotState starts with all collections empty."""
        state = SlotState()
        assert state.filled_slots == {}
        assert state.missing_required_slots == []
        assert state.slot_sources == {}
        assert state.slot_confidence == {}
        assert state.slot_to_field_map == {}

    def test_get_value_for_filter_direct_match(self):
        """A filled slot keyed by the filter name is returned directly."""
        state = SlotState(
            filled_slots={"product_line": "vip_course"},
            slot_to_field_map={},
        )
        value = state.get_value_for_filter("product_line")
        assert value == "vip_course"

    def test_get_value_for_filter_via_mapping(self):
        """A slot mapped to the filter field via slot_to_field_map is found."""
        state = SlotState(
            filled_slots={"product": "vip_course"},
            slot_to_field_map={"product": "product_line"},
        )
        value = state.get_value_for_filter("product_line")
        assert value == "vip_course"

    def test_get_value_for_filter_not_found(self):
        """Unknown filter names resolve to None."""
        state = SlotState(filled_slots={})
        value = state.get_value_for_filter("non_existent")
        assert value is None

    def test_to_debug_info(self):
        """to_debug_info exposes filled and missing slots for diagnostics."""
        state = SlotState(
            filled_slots={"key": "value"},
            missing_required_slots=[{"slot_key": "missing"}],
        )
        debug_info = state.to_debug_info()
        assert debug_info["filled_slots"] == {"key": "value"}
        assert len(debug_info["missing_required_slots"]) == 1


class TestSlotStateAggregator:
    """Tests for SlotStateAggregator."""

    @pytest.fixture
    def mock_session(self):
        """Mock async database session."""
        return AsyncMock()

    @pytest.fixture
    def aggregator(self, mock_session):
        """Aggregator instance bound to a fixed test tenant."""
        return SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

    @pytest.mark.asyncio
    async def 
test_aggregate_from_memory_slots(self, aggregator, mock_session): + """测试从 memory_slots 初始化""" + memory_slots = { + "product_line": MemorySlot( + key="product_line", + value="vip_course", + source=SlotSource.USER_CONFIRMED, + confidence=1.0, + ) + } + + # 模拟槽位定义服务返回空列表(没有 required 槽位) + with patch.object( + aggregator._slot_def_service, + "list_slot_definitions", + return_value=[], + ): + state = await aggregator.aggregate( + memory_slots=memory_slots, + current_input_slots=None, + context=None, + ) + + assert state.filled_slots["product_line"] == "vip_course" + assert state.slot_sources["product_line"] == "user_confirmed" + assert state.slot_confidence["product_line"] == 1.0 + + @pytest.mark.asyncio + async def test_aggregate_current_input_priority(self, aggregator, mock_session): + """测试当前输入优先级高于 memory""" + memory_slots = { + "product_line": MemorySlot( + key="product_line", + value="old_value", + source=SlotSource.USER_CONFIRMED, + confidence=1.0, + ) + } + current_input = {"product_line": "new_value"} + + with patch.object( + aggregator._slot_def_service, + "list_slot_definitions", + return_value=[], + ): + state = await aggregator.aggregate( + memory_slots=memory_slots, + current_input_slots=current_input, + context=None, + ) + + # 当前输入应该覆盖 memory 的值 + assert state.filled_slots["product_line"] == "new_value" + assert state.slot_sources["product_line"] == "user_confirmed" + + @pytest.mark.asyncio + async def test_aggregate_extract_from_context(self, aggregator, mock_session): + """测试从 context 提取槽位值""" + context = { + "scene": "open_consult", + "product_line": "vip_course", + "other_key": "other_value", + } + + with patch.object( + aggregator._slot_def_service, + "list_slot_definitions", + return_value=[], + ): + state = await aggregator.aggregate( + memory_slots=None, + current_input_slots=None, + context=context, + ) + + # 应该提取 scene 和 product_line + assert state.filled_slots.get("scene") == "open_consult" + assert state.filled_slots.get("product_line") == 
"vip_course" + assert state.slot_sources.get("scene") == "context" + + @pytest.mark.asyncio + async def test_generate_ask_back_response_with_prompt(self, aggregator): + """测试生成追问响应 - 使用配置的 ask_back_prompt""" + state = SlotState( + missing_required_slots=[ + { + "slot_key": "region", + "label": "地区", + "reason": "required_slot_missing", + "ask_back_prompt": "请问您在哪个地区?", + } + ] + ) + + response = await aggregator.generate_ask_back_response(state) + assert response == "请问您在哪个地区?" + + @pytest.mark.asyncio + async def test_generate_ask_back_response_generic(self, aggregator): + """测试生成追问响应 - 使用通用模板""" + state = SlotState( + missing_required_slots=[ + { + "slot_key": "region", + "label": "地区", + "reason": "required_slot_missing", + # 没有 ask_back_prompt + } + ] + ) + + response = await aggregator.generate_ask_back_response(state) + assert "地区" in response + + @pytest.mark.asyncio + async def test_generate_ask_back_response_no_missing(self, aggregator): + """测试没有缺失槽位时返回 None""" + state = SlotState(missing_required_slots=[]) + response = await aggregator.generate_ask_back_response(state) + assert response is None + + +class TestCreateSlotStateAggregator: + """测试工厂函数""" + + def test_create_aggregator(self): + """测试创建聚合器实例""" + mock_session = MagicMock() + aggregator = create_slot_state_aggregator( + session=mock_session, + tenant_id="test_tenant", + ) + assert isinstance(aggregator, SlotStateAggregator) + assert aggregator._tenant_id == "test_tenant" + + +class TestSlotStateFilterPriority: + """测试过滤值来源优先级""" + + @pytest.mark.asyncio + async def test_filter_priority_slot_first(self): + """测试优先级:slot > context > default""" + from unittest.mock import AsyncMock, MagicMock + + mock_session = AsyncMock() + aggregator = SlotStateAggregator( + session=mock_session, + tenant_id="test_tenant", + ) + + # 模拟槽位定义 + mock_slot_def = MagicMock() + mock_slot_def.slot_key = "product" + mock_slot_def.linked_field_id = None + mock_slot_def.required = False + + with patch.object( + 
aggregator._slot_def_service, + "list_slot_definitions", + return_value=[mock_slot_def], + ): + state = await aggregator.aggregate( + memory_slots={ + "product": MemorySlot( + key="product", + value="from_memory", + source=SlotSource.USER_CONFIRMED, + confidence=1.0, + ) + }, + current_input_slots={"product": "from_input"}, + context={"product": "from_context"}, + ) + + # 当前输入应该优先级最高 + assert state.filled_slots["product"] == "from_input" diff --git a/ai-service/tests/test_slot_state_cache.py b/ai-service/tests/test_slot_state_cache.py new file mode 100644 index 0000000..0fa625d --- /dev/null +++ b/ai-service/tests/test_slot_state_cache.py @@ -0,0 +1,399 @@ +""" +Tests for Slot State Cache. +[AC-MRS-SLOT-CACHE-01] 多轮状态持久化测试 +""" + +import json +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.services.cache.slot_state_cache import ( + CachedSlotState, + CachedSlotValue, + SlotStateCache, + get_slot_state_cache, +) + + +class TestCachedSlotValue: + """CachedSlotValue 测试""" + + def test_init(self): + """测试初始化""" + value = CachedSlotValue( + value="test_value", + source="user_confirmed", + confidence=0.9, + ) + assert value.value == "test_value" + assert value.source == "user_confirmed" + assert value.confidence == 0.9 + assert value.updated_at > 0 + + def test_to_dict(self): + """测试转换为字典""" + value = CachedSlotValue( + value="test_value", + source="rule_extracted", + confidence=0.8, + ) + d = value.to_dict() + assert d["value"] == "test_value" + assert d["source"] == "rule_extracted" + assert d["confidence"] == 0.8 + assert "updated_at" in d + + def test_from_dict(self): + """测试从字典创建""" + d = { + "value": "test_value", + "source": "llm_inferred", + "confidence": 0.7, + "updated_at": 12345.0, + } + value = CachedSlotValue.from_dict(d) + assert value.value == "test_value" + assert value.source == "llm_inferred" + assert value.confidence == 0.7 + assert value.updated_at == 12345.0 + + +class TestCachedSlotState: + 
    """Tests for CachedSlotState."""

    def test_init(self):
        """A fresh state has no slots and positive creation timestamps."""
        state = CachedSlotState()
        assert state.filled_slots == {}
        assert state.slot_to_field_map == {}
        assert state.created_at > 0
        assert state.updated_at > 0

    def test_with_slots(self):
        """State can be constructed pre-populated with slots and a field map."""
        slots = {
            "region": CachedSlotValue(value="北京", source="user_confirmed"),
            "product": CachedSlotValue(value="手机", source="rule_extracted"),
        }
        state = CachedSlotState(
            filled_slots=slots,
            slot_to_field_map={"region": "region_field"},
        )
        assert len(state.filled_slots) == 2
        assert state.slot_to_field_map["region"] == "region_field"

    def test_to_dict_and_from_dict(self):
        """to_dict/from_dict round-trip preserves slots and the field map."""
        slots = {
            "region": CachedSlotValue(value="北京", source="user_confirmed"),
        }
        original = CachedSlotState(
            filled_slots=slots,
            slot_to_field_map={"region": "region_field"},
        )

        d = original.to_dict()
        restored = CachedSlotState.from_dict(d)

        assert len(restored.filled_slots) == 1
        assert restored.filled_slots["region"].value == "北京"
        assert restored.filled_slots["region"].source == "user_confirmed"
        assert restored.slot_to_field_map["region"] == "region_field"

    def test_get_simple_filled_slots(self):
        """get_simple_filled_slots flattens slot objects to key -> raw value."""
        slots = {
            "region": CachedSlotValue(value="北京", source="user_confirmed"),
            "product": CachedSlotValue(value="手机", source="rule_extracted"),
        }
        state = CachedSlotState(filled_slots=slots)
        simple = state.get_simple_filled_slots()
        assert simple == {"region": "北京", "product": "手机"}

    def test_get_slot_sources(self):
        """get_slot_sources maps each slot key to its provenance string."""
        slots = {
            "region": CachedSlotValue(value="北京", source="user_confirmed"),
            "product": CachedSlotValue(value="手机", source="rule_extracted"),
        }
        state = CachedSlotState(filled_slots=slots)
        sources = state.get_slot_sources()
        assert sources == {"region": "user_confirmed", "product": "rule_extracted"}

    def test_get_slot_confidence(self):
        """get_slot_confidence maps each slot key to its confidence score."""
        slots = {
            "region": CachedSlotValue(value="北京", 
source="user_confirmed", confidence=1.0), + "product": CachedSlotValue(value="手机", source="rule_extracted", confidence=0.8), + } + state = CachedSlotState(filled_slots=slots) + confidence = state.get_slot_confidence() + assert confidence == {"region": 1.0, "product": 0.8} + + +class TestSlotStateCache: + """SlotStateCache 测试""" + + def test_source_priority(self): + """测试来源优先级""" + cache = SlotStateCache() + assert cache._get_source_priority("user_confirmed") == 100 + assert cache._get_source_priority("rule_extracted") == 80 + assert cache._get_source_priority("llm_inferred") == 60 + assert cache._get_source_priority("context") == 40 + assert cache._get_source_priority("default") == 20 + assert cache._get_source_priority("unknown") == 0 + + def test_make_key(self): + """测试 key 生成""" + cache = SlotStateCache() + key = cache._make_key("tenant_123", "session_456") + assert key == "slot_state:tenant_123:session_456" + + @pytest.mark.asyncio + async def test_l1_cache_hit(self): + """测试 L1 缓存命中""" + cache = SlotStateCache() + tenant_id = "tenant_1" + session_id = "session_1" + + state = CachedSlotState( + filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")}, + ) + + cache._local_cache[f"{tenant_id}:{session_id}"] = (state, time.time()) + + result = await cache.get(tenant_id, session_id) + assert result is not None + assert result.filled_slots["region"].value == "北京" + + @pytest.mark.asyncio + async def test_l1_cache_expired(self): + """测试 L1 缓存过期""" + cache = SlotStateCache() + tenant_id = "tenant_1" + session_id = "session_1" + + state = CachedSlotState( + filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")}, + ) + + old_time = time.time() - 400 + cache._local_cache[f"{tenant_id}:{session_id}"] = (state, old_time) + + result = await cache.get(tenant_id, session_id) + assert result is None + assert f"{tenant_id}:{session_id}" not in cache._local_cache + + @pytest.mark.asyncio + async def test_set_and_get_l1(self): + 
"""测试设置和获取 L1 缓存""" + cache = SlotStateCache(redis_client=None) + cache._enabled = False + + tenant_id = "tenant_1" + session_id = "session_1" + + state = CachedSlotState( + filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")}, + ) + + await cache.set(tenant_id, session_id, state) + + local_key = f"{tenant_id}:{session_id}" + assert local_key in cache._local_cache + + result = await cache.get(tenant_id, session_id) + assert result is not None + assert result.filled_slots["region"].value == "北京" + + @pytest.mark.asyncio + async def test_delete(self): + """测试删除缓存""" + cache = SlotStateCache(redis_client=None) + cache._enabled = False + + tenant_id = "tenant_1" + session_id = "session_1" + + state = CachedSlotState( + filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")}, + ) + + await cache.set(tenant_id, session_id, state) + await cache.delete(tenant_id, session_id) + + result = await cache.get(tenant_id, session_id) + assert result is None + + @pytest.mark.asyncio + async def test_clear_slot(self): + """测试清除单个槽位""" + cache = SlotStateCache(redis_client=None) + cache._enabled = False + + tenant_id = "tenant_1" + session_id = "session_1" + + state = CachedSlotState( + filled_slots={ + "region": CachedSlotValue(value="北京", source="user_confirmed"), + "product": CachedSlotValue(value="手机", source="rule_extracted"), + }, + ) + + await cache.set(tenant_id, session_id, state) + await cache.clear_slot(tenant_id, session_id, "region") + + result = await cache.get(tenant_id, session_id) + assert result is not None + assert "region" not in result.filled_slots + assert "product" in result.filled_slots + + @pytest.mark.asyncio + async def test_merge_and_set_priority(self): + """测试合并时优先级处理""" + cache = SlotStateCache(redis_client=None) + cache._enabled = False + + tenant_id = "tenant_1" + session_id = "session_1" + + existing_state = CachedSlotState( + filled_slots={ + "region": CachedSlotValue(value="上海", source="llm_inferred", 
confidence=0.6), + }, + ) + await cache.set(tenant_id, session_id, existing_state) + + new_slots = { + "region": CachedSlotValue(value="北京", source="user_confirmed", confidence=1.0), + } + + result = await cache.merge_and_set(tenant_id, session_id, new_slots) + + assert result.filled_slots["region"].value == "北京" + assert result.filled_slots["region"].source == "user_confirmed" + + @pytest.mark.asyncio + async def test_merge_and_set_lower_priority_ignored(self): + """测试低优先级值被忽略""" + cache = SlotStateCache(redis_client=None) + cache._enabled = False + + tenant_id = "tenant_1" + session_id = "session_1" + + existing_state = CachedSlotState( + filled_slots={ + "region": CachedSlotValue(value="北京", source="user_confirmed", confidence=1.0), + }, + ) + await cache.set(tenant_id, session_id, existing_state) + + new_slots = { + "region": CachedSlotValue(value="上海", source="llm_inferred", confidence=0.6), + } + + result = await cache.merge_and_set(tenant_id, session_id, new_slots) + + assert result.filled_slots["region"].value == "北京" + assert result.filled_slots["region"].source == "user_confirmed" + + +class TestGetSlotStateCache: + """get_slot_state_cache 单例测试""" + + def test_singleton(self): + """测试单例模式""" + cache1 = get_slot_state_cache() + cache2 = get_slot_state_cache() + assert cache1 is cache2 + + +class TestSlotStateCacheWithRedis: + """SlotStateCache Redis 集成测试""" + + @pytest.mark.asyncio + async def test_redis_set_and_get(self): + """测试 Redis 存取""" + mock_redis = AsyncMock() + mock_redis.get = AsyncMock(return_value=None) + mock_redis.setex = AsyncMock(return_value=True) + + cache = SlotStateCache(redis_client=mock_redis) + + tenant_id = "tenant_1" + session_id = "session_1" + + state = CachedSlotState( + filled_slots={"region": CachedSlotValue(value="北京", source="user_confirmed")}, + ) + + await cache.set(tenant_id, session_id, state) + + mock_redis.setex.assert_called_once() + call_args = mock_redis.setex.call_args + assert call_args[0][0] == 
f"slot_state:{tenant_id}:{session_id}" + + @pytest.mark.asyncio + async def test_redis_get_hit(self): + """测试 Redis 命中""" + state_dict = { + "filled_slots": { + "region": { + "value": "北京", + "source": "user_confirmed", + "confidence": 1.0, + "updated_at": 12345.0, + } + }, + "slot_to_field_map": {"region": "region_field"}, + "created_at": 12340.0, + "updated_at": 12345.0, + } + + mock_redis = AsyncMock() + mock_redis.get = AsyncMock(return_value=json.dumps(state_dict)) + + cache = SlotStateCache(redis_client=mock_redis) + + tenant_id = "tenant_1" + session_id = "session_1" + + result = await cache.get(tenant_id, session_id) + + assert result is not None + assert result.filled_slots["region"].value == "北京" + assert result.filled_slots["region"].source == "user_confirmed" + + @pytest.mark.asyncio + async def test_redis_delete(self): + """测试 Redis 删除""" + mock_redis = AsyncMock() + mock_redis.delete = AsyncMock(return_value=1) + + cache = SlotStateCache(redis_client=mock_redis) + + tenant_id = "tenant_1" + session_id = "session_1" + + await cache.delete(tenant_id, session_id) + + mock_redis.delete.assert_called_once_with(f"slot_state:{tenant_id}:{session_id}") + + +class TestCacheTTL: + """TTL 配置测试""" + + def test_default_ttl(self): + """测试默认 TTL""" + cache = SlotStateCache() + assert cache._cache_ttl == 1800 + + def test_local_cache_ttl(self): + """测试本地缓存 TTL""" + cache = SlotStateCache() + assert cache._local_cache_ttl == 300 diff --git a/ai-service/tests/test_slot_strategy_executor.py b/ai-service/tests/test_slot_strategy_executor.py new file mode 100644 index 0000000..663e167 --- /dev/null +++ b/ai-service/tests/test_slot_strategy_executor.py @@ -0,0 +1,244 @@ +""" +Tests for Slot Strategy Executor. 
[AC-MRS-07-UPGRADE] Extraction strategy chain executor tests.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock

from app.models.entities import ExtractFailureType
from app.services.mid.slot_strategy_executor import (
    SlotStrategyExecutor,
    ExtractContext,
    StrategyChainResult,
    execute_extract_strategies,
)


class TestSlotStrategyExecutor:
    """Tests for the slot strategy executor."""

    @pytest.fixture
    def executor(self):
        """Fresh executor instance."""
        return SlotStrategyExecutor()

    @pytest.fixture
    def context(self):
        """Extraction context for a string-typed 'grade' slot."""
        return ExtractContext(
            tenant_id="test-tenant",
            slot_key="grade",
            user_input="我想了解初一语文课程",
            slot_type="string",
        )

    @pytest.mark.asyncio
    async def test_execute_chain_success_on_first_step(self, executor, context):
        """The chain stops as soon as the first strategy succeeds."""
        # The rule extractor succeeds immediately.
        mock_rule = AsyncMock(return_value="初一")
        executor._extractors["rule"] = mock_rule

        result = await executor.execute_chain(
            strategies=["rule", "llm", "user_input"],
            context=context,
        )

        assert result.success is True
        assert result.final_value == "初一"
        assert result.final_strategy == "rule"
        assert len(result.steps) == 1
        assert result.steps[0].success is True
        mock_rule.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_chain_fallback_to_second_step(self, executor, context):
        """If the first strategy yields nothing, the chain falls back to the next."""
        # Rule extractor fails (returns no value) ...
        mock_rule = AsyncMock(return_value=None)
        # ... so the LLM extractor is tried next and succeeds.
        mock_llm = AsyncMock(return_value="初一")

        executor._extractors["rule"] = mock_rule
        executor._extractors["llm"] = mock_llm

        result = await executor.execute_chain(
            strategies=["rule", "llm", "user_input"],
            context=context,
        )

        assert result.success is True
        assert result.final_value == "初一"
        assert result.final_strategy == "llm"
        assert len(result.steps) == 2
        assert result.steps[0].success is False
        assert result.steps[0].failure_type == ExtractFailureType.EXTRACT_EMPTY
        assert result.steps[1].success is True

    @pytest.mark.asyncio
    async def test_execute_chain_all_failed(self, executor, context):
        """When every strategy fails, the chain reports failure plus the ask-back prompt."""
        # All extractors return no value.
        mock_rule = AsyncMock(return_value=None)
        mock_llm = AsyncMock(return_value=None)
        mock_user_input = AsyncMock(return_value=None)

        executor._extractors["rule"] = mock_rule
        executor._extractors["llm"] = mock_llm
        executor._extractors["user_input"] = mock_user_input

        result = await executor.execute_chain(
            strategies=["rule", "llm", "user_input"],
            context=context,
            ask_back_prompt="请告诉我您的年级",
        )

        assert result.success is False
        assert result.final_value is None
        assert result.final_strategy is None
        assert len(result.steps) == 3
        assert result.ask_back_prompt == "请告诉我您的年级"

        # Every step must record an empty-extraction failure.
        for step in result.steps:
            assert step.success is False
            assert step.failure_type == ExtractFailureType.EXTRACT_EMPTY

    @pytest.mark.asyncio
    async def test_execute_chain_validation_failure(self, executor, context):
        """A value that fails the validation rule is rejected."""
        context.validation_rule = r"^初[一二三]$"  # only 初一/初二/初三 allowed

        # Rule extractor returns a value outside the allowed set.
        mock_rule = AsyncMock(return_value="高一")
        executor._extractors["rule"] = mock_rule

        result = await executor.execute_chain(
            strategies=["rule"],
            context=context,
        )

        assert result.success is False
        assert result.steps[0].failure_type == ExtractFailureType.EXTRACT_VALIDATION_FAIL
        assert "Validation failed" in result.steps[0].failure_reason

    @pytest.mark.asyncio
    async def test_execute_chain_runtime_error(self, executor, context):
        """An extractor raising an exception is reported as a runtime error."""
        # The extractor raises instead of returning a value.
        mock_rule = AsyncMock(side_effect=Exception("LLM service unavailable"))
        executor._extractors["rule"] = mock_rule

        result = await executor.execute_chain(
            strategies=["rule"],
            context=context,
        )

        assert result.success is False
        assert result.steps[0].failure_type == ExtractFailureType.EXTRACT_RUNTIME_ERROR
        assert "Runtime error" in result.steps[0].failure_reason

    @pytest.mark.asyncio
    async def 
test_execute_chain_empty_strategies(self, executor, context): + """测试空策略链""" + result = await executor.execute_chain( + strategies=[], + context=context, + ) + + assert result.success is False + assert len(result.steps) == 0 + + @pytest.mark.asyncio + async def test_execute_chain_unknown_strategy(self, executor, context): + """测试未知策略""" + result = await executor.execute_chain( + strategies=["unknown_strategy"], + context=context, + ) + + assert result.success is False + assert result.steps[0].failure_type == ExtractFailureType.EXTRACT_RUNTIME_ERROR + assert "Unknown strategy" in result.steps[0].failure_reason + + @pytest.mark.asyncio + async def test_execute_chain_result_to_dict(self, executor, context): + """测试结果转换为字典""" + mock_rule = AsyncMock(return_value="初一") + executor._extractors["rule"] = mock_rule + + result = await executor.execute_chain( + strategies=["rule"], + context=context, + ) + + result_dict = result.to_dict() + + assert result_dict["slot_key"] == "grade" + assert result_dict["success"] is True + assert result_dict["final_value"] == "初一" + assert result_dict["final_strategy"] == "rule" + assert "steps" in result_dict + assert "total_execution_time_ms" in result_dict + + +class TestExecuteExtractStrategies: + """测试便捷函数""" + + @pytest.mark.asyncio + async def test_convenience_function(self): + """测试便捷函数 execute_extract_strategies""" + mock_rule = AsyncMock(return_value="初一") + + result = await execute_extract_strategies( + strategies=["rule"], + tenant_id="test-tenant", + slot_key="grade", + user_input="我想了解初一语文课程", + rule_extractor=mock_rule, + ) + + assert result.success is True + assert result.final_value == "初一" + + @pytest.mark.asyncio + async def test_convenience_function_with_validation(self): + """测试带校验的便捷函数""" + mock_rule = AsyncMock(return_value="初一") + + result = await execute_extract_strategies( + strategies=["rule"], + tenant_id="test-tenant", + slot_key="grade", + user_input="我想了解初一语文课程", + validation_rule=r"^初[一二三]$", + 
rule_extractor=mock_rule, + ) + + assert result.success is True + assert result.final_value == "初一" + + +class TestExtractContext: + """测试提取上下文""" + + def test_context_creation(self): + """测试上下文创建""" + context = ExtractContext( + tenant_id="test-tenant", + slot_key="grade", + user_input="测试输入", + slot_type="string", + validation_rule=r"^初[一二三]$", + history=[{"role": "user", "content": "你好"}], + session_id="session-123", + ) + + assert context.tenant_id == "test-tenant" + assert context.slot_key == "grade" + assert context.user_input == "测试输入" + assert context.slot_type == "string" + assert context.validation_rule == r"^初[一二三]$" + assert len(context.history) == 1 + assert context.session_id == "session-123" diff --git a/ai-service/tests/test_slot_validation_service.py b/ai-service/tests/test_slot_validation_service.py new file mode 100644 index 0000000..d5e040c --- /dev/null +++ b/ai-service/tests/test_slot_validation_service.py @@ -0,0 +1,541 @@ +""" +Tests for Slot Validation Service. +槽位校验服务单元测试 +""" + +import pytest + +from app.services.mid.slot_validation_service import ( + SlotValidationService, + SlotValidationErrorCode, + ValidationResult, + SlotValidationError, + BatchValidationResult, +) + + +class TestSlotValidationService: + """槽位校验服务测试类""" + + @pytest.fixture + def service(self): + """创建校验服务实例""" + return SlotValidationService() + + @pytest.fixture + def string_slot_def(self): + """字符串类型槽位定义""" + return { + "slot_key": "name", + "type": "string", + "required": False, + "validation_rule": None, + "ask_back_prompt": "请输入您的姓名", + } + + @pytest.fixture + def required_string_slot_def(self): + """必填字符串类型槽位定义""" + return { + "slot_key": "phone", + "type": "string", + "required": True, + "validation_rule": r"^1[3-9]\d{9}$", + "ask_back_prompt": "请输入正确的手机号码", + } + + @pytest.fixture + def number_slot_def(self): + """数字类型槽位定义""" + return { + "slot_key": "age", + "type": "number", + "required": False, + "validation_rule": None, + "ask_back_prompt": "请输入年龄", + } + 
    @pytest.fixture
    def boolean_slot_def(self):
        """Boolean-typed slot definition."""
        return {
            "slot_key": "is_student",
            "type": "boolean",
            "required": False,
            "validation_rule": None,
            "ask_back_prompt": "是否是学生?",
        }

    @pytest.fixture
    def enum_slot_def(self):
        """Enum-typed slot definition with a fixed option list."""
        return {
            "slot_key": "grade",
            "type": "enum",
            "required": False,
            "options": ["初一", "初二", "初三", "高一", "高二", "高三"],
            "validation_rule": None,
            "ask_back_prompt": "请选择年级",
        }

    @pytest.fixture
    def array_enum_slot_def(self):
        """Array-of-enum slot definition (multi-select)."""
        return {
            "slot_key": "subjects",
            "type": "array_enum",
            "required": False,
            "options": ["语文", "数学", "英语", "物理", "化学"],
            "validation_rule": None,
            "ask_back_prompt": "请选择学科",
        }

    @pytest.fixture
    def json_schema_slot_def(self):
        """Slot definition validated by a JSON Schema rule."""
        return {
            "slot_key": "email",
            "type": "string",
            "required": True,
            "validation_rule": '{"type": "string", "format": "email"}',
            "ask_back_prompt": "请输入有效的邮箱地址",
        }

    class TestBasicValidation:
        """Basic validation behaviour."""

        def test_empty_validation_rule(self, service, string_slot_def):
            """An absent validation rule lets any value pass."""
            string_slot_def["validation_rule"] = None
            result = service.validate_slot_value(string_slot_def, "test")
            assert result.ok is True
            assert result.normalized_value == "test"

        def test_whitespace_validation_rule(self, service, string_slot_def):
            """A whitespace-only rule is treated as no rule."""
            string_slot_def["validation_rule"] = " "
            result = service.validate_slot_value(string_slot_def, "test")
            assert result.ok is True

        def test_no_slot_definition(self, service):
            """A minimal (dynamic) definition without rules validates anything."""
            # Minimal definition: only the slot key is known.
            minimal_def = {"slot_key": "dynamic_field"}
            result = service.validate_slot_value(minimal_def, "any_value")
            assert result.ok is True

    class TestRegexValidation:
        """Regular-expression validation rules."""

        def test_regex_match(self, service, required_string_slot_def):
            """A value matching the regex passes validation."""
            result = service.validate_slot_value(
                required_string_slot_def, "13800138000"
            )
            assert result.ok is True
            assert 
result.normalized_value == "13800138000" + + def test_regex_mismatch(self, service, required_string_slot_def): + """测试正则匹配失败""" + result = service.validate_slot_value( + required_string_slot_def, "invalid_phone" + ) + assert result.ok is False + assert result.error_code == SlotValidationErrorCode.SLOT_REGEX_MISMATCH + assert result.ask_back_prompt == "请输入正确的手机号码" + + def test_regex_invalid_pattern(self, service, string_slot_def): + """测试非法正则表达式""" + string_slot_def["validation_rule"] = "[invalid(" + result = service.validate_slot_value(string_slot_def, "test") + assert result.ok is False + assert ( + result.error_code + == SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID + ) + + def test_regex_with_chinese(self, service, string_slot_def): + """测试包含中文的正则""" + string_slot_def["validation_rule"] = r"^[\u4e00-\u9fa5]{2,4}$" + result = service.validate_slot_value(string_slot_def, "张三") + assert result.ok is True + + result = service.validate_slot_value(string_slot_def, "John") + assert result.ok is False + + class TestJsonSchemaValidation: + """JSON Schema 校验测试""" + + def test_json_schema_match(self, service): + """测试 JSON Schema 匹配成功""" + slot_def = { + "slot_key": "config", + "type": "object", + "validation_rule": '{"type": "object", "properties": {"name": {"type": "string"}}}', + } + result = service.validate_slot_value(slot_def, {"name": "test"}) + assert result.ok is True + + def test_json_schema_mismatch(self, service): + """测试 JSON Schema 匹配失败""" + slot_def = { + "slot_key": "count", + "type": "number", + "validation_rule": '{"type": "integer", "minimum": 0, "maximum": 100}', + "ask_back_prompt": "请输入0-100之间的整数", + } + result = service.validate_slot_value(slot_def, 150) + assert result.ok is False + assert ( + result.error_code == SlotValidationErrorCode.SLOT_JSON_SCHEMA_MISMATCH + ) + assert result.ask_back_prompt == "请输入0-100之间的整数" + + def test_json_schema_invalid_json(self, service, string_slot_def): + """测试非法 JSON Schema""" + 
string_slot_def["validation_rule"] = "{invalid json}" + result = service.validate_slot_value(string_slot_def, "test") + assert result.ok is False + assert ( + result.error_code + == SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID + ) + + def test_json_schema_array(self, service): + """测试数组类型的 JSON Schema""" + slot_def = { + "slot_key": "items", + "type": "array", + "validation_rule": '{"type": "array", "items": {"type": "string"}}', + } + result = service.validate_slot_value(slot_def, ["a", "b", "c"]) + assert result.ok is True + + result = service.validate_slot_value(slot_def, [1, 2, 3]) + assert result.ok is False + + class TestRequiredValidation: + """必填校验测试""" + + def test_required_missing_none(self, service, required_string_slot_def): + """测试必填字段为 None""" + result = service.validate_slot_value( + required_string_slot_def, None + ) + assert result.ok is False + assert result.error_code == SlotValidationErrorCode.SLOT_REQUIRED_MISSING + + def test_required_missing_empty_string(self, service, required_string_slot_def): + """测试必填字段为空字符串""" + result = service.validate_slot_value(required_string_slot_def, "") + assert result.ok is False + assert result.error_code == SlotValidationErrorCode.SLOT_REQUIRED_MISSING + + def test_required_missing_whitespace(self, service, required_string_slot_def): + """测试必填字段为空白字符""" + result = service.validate_slot_value(required_string_slot_def, " ") + assert result.ok is False + assert result.error_code == SlotValidationErrorCode.SLOT_REQUIRED_MISSING + + def test_required_present(self, service, required_string_slot_def): + """测试必填字段有值""" + result = service.validate_slot_value( + required_string_slot_def, "13800138000" + ) + assert result.ok is True + + def test_not_required_empty(self, service, string_slot_def): + """测试非必填字段为空""" + result = service.validate_slot_value(string_slot_def, "") + assert result.ok is True + + class TestTypeValidation: + """类型校验测试""" + + def test_string_type(self, service, string_slot_def): + 
"""测试字符串类型""" + result = service.validate_slot_value(string_slot_def, "hello") + assert result.ok is True + assert result.normalized_value == "hello" + + def test_string_type_conversion(self, service, string_slot_def): + """测试字符串类型自动转换""" + result = service.validate_slot_value(string_slot_def, 123) + assert result.ok is True + assert result.normalized_value == "123" + + def test_number_type_integer(self, service, number_slot_def): + """测试数字类型 - 整数""" + result = service.validate_slot_value(number_slot_def, 25) + assert result.ok is True + assert result.normalized_value == 25 + + def test_number_type_float(self, service, number_slot_def): + """测试数字类型 - 浮点数""" + result = service.validate_slot_value(number_slot_def, 25.5) + assert result.ok is True + assert result.normalized_value == 25.5 + + def test_number_type_string_conversion(self, service, number_slot_def): + """测试数字类型 - 字符串转换""" + result = service.validate_slot_value(number_slot_def, "30") + assert result.ok is True + assert result.normalized_value == 30 + + def test_number_type_invalid(self, service, number_slot_def): + """测试数字类型 - 无效值""" + result = service.validate_slot_value(number_slot_def, "not_a_number") + assert result.ok is False + assert result.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID + + def test_number_type_reject_boolean(self, service, number_slot_def): + """测试数字类型 - 拒绝布尔值""" + result = service.validate_slot_value(number_slot_def, True) + assert result.ok is False + assert result.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID + + def test_boolean_type_true(self, service, boolean_slot_def): + """测试布尔类型 - True""" + result = service.validate_slot_value(boolean_slot_def, True) + assert result.ok is True + assert result.normalized_value is True + + def test_boolean_type_false(self, service, boolean_slot_def): + """测试布尔类型 - False""" + result = service.validate_slot_value(boolean_slot_def, False) + assert result.ok is True + assert result.normalized_value is False + + def 
test_boolean_type_string_true(self, service, boolean_slot_def): + """测试布尔类型 - 字符串 true""" + result = service.validate_slot_value(boolean_slot_def, "true") + assert result.ok is True + assert result.normalized_value is True + + def test_boolean_type_string_yes(self, service, boolean_slot_def): + """测试布尔类型 - 字符串 yes/是""" + result = service.validate_slot_value(boolean_slot_def, "是") + assert result.ok is True + assert result.normalized_value is True + + def test_boolean_type_string_false(self, service, boolean_slot_def): + """测试布尔类型 - 字符串 false""" + result = service.validate_slot_value(boolean_slot_def, "false") + assert result.ok is True + assert result.normalized_value is False + + def test_boolean_type_invalid(self, service, boolean_slot_def): + """测试布尔类型 - 无效值""" + result = service.validate_slot_value(boolean_slot_def, "maybe") + assert result.ok is False + assert result.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID + + def test_enum_type_valid(self, service, enum_slot_def): + """测试枚举类型 - 有效值""" + result = service.validate_slot_value(enum_slot_def, "高一") + assert result.ok is True + assert result.normalized_value == "高一" + + def test_enum_type_invalid(self, service, enum_slot_def): + """测试枚举类型 - 无效值""" + result = service.validate_slot_value(enum_slot_def, "大一") + assert result.ok is False + assert result.error_code == SlotValidationErrorCode.SLOT_ENUM_INVALID + + def test_enum_type_not_string(self, service, enum_slot_def): + """测试枚举类型 - 非字符串""" + result = service.validate_slot_value(enum_slot_def, 123) + assert result.ok is False + assert result.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID + + def test_array_enum_type_valid(self, service, array_enum_slot_def): + """测试数组枚举类型 - 有效值""" + result = service.validate_slot_value( + array_enum_slot_def, ["语文", "数学"] + ) + assert result.ok is True + + def test_array_enum_type_invalid_item(self, service, array_enum_slot_def): + """测试数组枚举类型 - 无效元素""" + result = service.validate_slot_value( + 
array_enum_slot_def, ["语文", "生物"] + ) + assert result.ok is False + assert ( + result.error_code == SlotValidationErrorCode.SLOT_ARRAY_ENUM_INVALID + ) + + def test_array_enum_type_not_array(self, service, array_enum_slot_def): + """测试数组枚举类型 - 非数组""" + result = service.validate_slot_value(array_enum_slot_def, "语文") + assert result.ok is False + assert result.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID + + def test_array_enum_type_non_string_item(self, service, array_enum_slot_def): + """测试数组枚举类型 - 非字符串元素""" + result = service.validate_slot_value(array_enum_slot_def, ["语文", 123]) + assert result.ok is False + assert ( + result.error_code == SlotValidationErrorCode.SLOT_ARRAY_ENUM_INVALID + ) + + class TestBatchValidation: + """批量校验测试""" + + def test_batch_all_valid(self, service, string_slot_def, number_slot_def): + """测试批量校验 - 全部通过""" + slot_defs = [string_slot_def, number_slot_def] + values = {"name": "张三", "age": 25} + result = service.validate_slots(slot_defs, values) + assert result.ok is True + assert len(result.errors) == 0 + assert result.validated_values["name"] == "张三" + assert result.validated_values["age"] == 25 + + def test_batch_some_invalid(self, service, string_slot_def, number_slot_def): + """测试批量校验 - 部分失败""" + slot_defs = [string_slot_def, number_slot_def] + values = {"name": "张三", "age": "not_a_number"} + result = service.validate_slots(slot_defs, values) + assert result.ok is False + assert len(result.errors) == 1 + assert result.errors[0].slot_key == "age" + + def test_batch_missing_required( + self, service, required_string_slot_def, string_slot_def + ): + """测试批量校验 - 缺失必填字段""" + slot_defs = [required_string_slot_def, string_slot_def] + values = {"name": "张三"} # 缺少 phone + result = service.validate_slots(slot_defs, values) + assert result.ok is False + assert len(result.errors) == 1 + assert result.errors[0].slot_key == "phone" + assert ( + result.errors[0].error_code + == SlotValidationErrorCode.SLOT_REQUIRED_MISSING + ) + + def 
test_batch_undefined_slot(self, service, string_slot_def): + """测试批量校验 - 未定义槽位""" + slot_defs = [string_slot_def] + values = {"name": "张三", "undefined_field": "value"} + result = service.validate_slots(slot_defs, values) + assert result.ok is True + # 未定义槽位应允许通过 + assert "undefined_field" in result.validated_values + + class TestCombinedValidation: + """组合校验测试(类型 + 正则/JSON Schema)""" + + def test_type_and_regex_both_pass(self, service): + """测试类型和正则都通过""" + slot_def = { + "slot_key": "code", + "type": "string", + "required": True, + "validation_rule": r"^[A-Z]{2}\d{4}$", + } + result = service.validate_slot_value(slot_def, "AB1234") + assert result.ok is True + + def test_type_pass_regex_fail(self, service): + """测试类型通过但正则失败""" + slot_def = { + "slot_key": "code", + "type": "string", + "required": True, + "validation_rule": r"^[A-Z]{2}\d{4}$", + } + result = service.validate_slot_value(slot_def, "ab1234") + assert result.ok is False + assert result.error_code == SlotValidationErrorCode.SLOT_REGEX_MISMATCH + + def test_type_fail_no_regex_check(self, service): + """测试类型失败时不执行正则校验""" + slot_def = { + "slot_key": "code", + "type": "number", + "required": True, + "validation_rule": r"^\d+$", + } + result = service.validate_slot_value(slot_def, "not_a_number") + assert result.ok is False + assert result.error_code == SlotValidationErrorCode.SLOT_TYPE_INVALID + + class TestAskBackPrompt: + """追问提示语测试""" + + def test_ask_back_prompt_on_validation_fail(self, service): + """测试校验失败时返回 ask_back_prompt""" + slot_def = { + "slot_key": "email", + "type": "string", + "required": True, + "validation_rule": r"^[\w\.-]+@[\w\.-]+\.\w+$", + "ask_back_prompt": "请输入有效的邮箱地址,如 example@domain.com", + } + result = service.validate_slot_value(slot_def, "invalid_email") + assert result.ok is False + assert result.ask_back_prompt == "请输入有效的邮箱地址,如 example@domain.com" + + def test_no_ask_back_prompt_on_success(self, service, string_slot_def): + """测试校验通过时不返回 ask_back_prompt""" + result = 
service.validate_slot_value(string_slot_def, "valid") + assert result.ok is True + assert result.ask_back_prompt is None + + def test_ask_back_prompt_on_required_missing(self, service): + """测试必填缺失时返回 ask_back_prompt""" + slot_def = { + "slot_key": "name", + "type": "string", + "required": True, + "ask_back_prompt": "请告诉我们您的姓名", + } + result = service.validate_slot_value(slot_def, "") + assert result.ok is False + assert result.ask_back_prompt == "请告诉我们您的姓名" + + +class TestSlotValidationErrorCode: + """错误码测试""" + + def test_error_code_values(self): + """测试错误码值""" + assert SlotValidationErrorCode.SLOT_REQUIRED_MISSING == "SLOT_REQUIRED_MISSING" + assert SlotValidationErrorCode.SLOT_TYPE_INVALID == "SLOT_TYPE_INVALID" + assert SlotValidationErrorCode.SLOT_REGEX_MISMATCH == "SLOT_REGEX_MISMATCH" + assert ( + SlotValidationErrorCode.SLOT_JSON_SCHEMA_MISMATCH + == "SLOT_JSON_SCHEMA_MISMATCH" + ) + assert ( + SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID + == "SLOT_VALIDATION_RULE_INVALID" + ) + + +class TestValidationResult: + """ValidationResult 测试""" + + def test_success_result(self): + """测试成功结果""" + result = ValidationResult(ok=True, normalized_value="test") + assert result.ok is True + assert result.normalized_value == "test" + assert result.error_code is None + assert result.error_message is None + + def test_failure_result(self): + """测试失败结果""" + result = ValidationResult( + ok=False, + error_code="SLOT_REGEX_MISMATCH", + error_message="格式不正确", + ask_back_prompt="请重新输入", + ) + assert result.ok is False + assert result.error_code == "SLOT_REGEX_MISMATCH" + assert result.error_message == "格式不正确" + assert result.ask_back_prompt == "请重新输入" diff --git a/ai-service/tests/test_step_kb_binding.py b/ai-service/tests/test_step_kb_binding.py new file mode 100644 index 0000000..1032034 --- /dev/null +++ b/ai-service/tests/test_step_kb_binding.py @@ -0,0 +1,328 @@ +""" +Test cases for Step-KB Binding feature. +[Step-KB-Binding] 步骤关联知识库功能的测试用例 + +测试覆盖: +1. 
步骤配置的增删改查与参数校验 +2. 配置步骤KB范围后,检索仅在范围内发生 +3. 未配置时回退原逻辑 +4. 多知识库同名内容场景下,步骤约束生效 +5. trace 字段完整性校验 +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from dataclasses import dataclass +from typing import Any + + +class TestStepKbBindingModel: + """测试步骤KB绑定数据模型""" + + def test_flow_step_with_kb_binding_fields(self): + """测试 FlowStep 包含 KB 绑定字段""" + from app.models.entities import FlowStep + + step = FlowStep( + step_no=1, + content="测试步骤", + allowed_kb_ids=["kb-1", "kb-2"], + preferred_kb_ids=["kb-1"], + kb_query_hint="查找产品相关信息", + max_kb_calls_per_step=2, + ) + + assert step.allowed_kb_ids == ["kb-1", "kb-2"] + assert step.preferred_kb_ids == ["kb-1"] + assert step.kb_query_hint == "查找产品相关信息" + assert step.max_kb_calls_per_step == 2 + + def test_flow_step_without_kb_binding(self): + """测试 FlowStep 不配置 KB 绑定时的默认值""" + from app.models.entities import FlowStep + + step = FlowStep( + step_no=1, + content="测试步骤", + ) + + assert step.allowed_kb_ids is None + assert step.preferred_kb_ids is None + assert step.kb_query_hint is None + assert step.max_kb_calls_per_step is None + + def test_max_kb_calls_validation(self): + """测试 max_kb_calls_per_step 的范围校验""" + from app.models.entities import FlowStep + from pydantic import ValidationError + + # 有效范围 1-5 + step = FlowStep(step_no=1, content="test", max_kb_calls_per_step=3) + assert step.max_kb_calls_per_step == 3 + + # 超出上限 + with pytest.raises(Exception): # ValidationError + FlowStep(step_no=1, content="test", max_kb_calls_per_step=10) + + +class TestStepKbConfig: + """测试 StepKbConfig 数据类""" + + def test_step_kb_config_creation(self): + """测试 StepKbConfig 创建""" + from app.services.mid.kb_search_dynamic_tool import StepKbConfig + + config = StepKbConfig( + allowed_kb_ids=["kb-1", "kb-2"], + preferred_kb_ids=["kb-1"], + kb_query_hint="查找产品信息", + max_kb_calls=2, + step_id="flow-1_step_1", + ) + + assert config.allowed_kb_ids == ["kb-1", "kb-2"] + assert config.preferred_kb_ids == ["kb-1"] + assert 
config.kb_query_hint == "查找产品信息" + assert config.max_kb_calls == 2 + assert config.step_id == "flow-1_step_1" + + def test_step_kb_config_defaults(self): + """测试 StepKbConfig 默认值""" + from app.services.mid.kb_search_dynamic_tool import StepKbConfig + + config = StepKbConfig() + + assert config.allowed_kb_ids is None + assert config.preferred_kb_ids is None + assert config.kb_query_hint is None + assert config.max_kb_calls == 1 + assert config.step_id is None + + +class TestKbSearchDynamicToolWithStepConfig: + """测试 KbSearchDynamicTool 与步骤配置的集成""" + + @pytest.mark.asyncio + async def test_kb_search_with_allowed_kb_ids(self): + """测试配置 allowed_kb_ids 后检索范围受限""" + from app.services.mid.kb_search_dynamic_tool import ( + KbSearchDynamicTool, + KbSearchDynamicConfig, + StepKbConfig, + ) + + mock_session = MagicMock() + mock_timeout_governor = MagicMock() + + tool = KbSearchDynamicTool( + session=mock_session, + timeout_governor=mock_timeout_governor, + config=KbSearchDynamicConfig(enabled=True), + ) + + step_config = StepKbConfig( + allowed_kb_ids=["kb-allowed-1", "kb-allowed-2"], + step_id="test_step", + ) + + with patch.object(tool, '_do_retrieve', new_callable=AsyncMock) as mock_retrieve: + mock_retrieve.return_value = [ + {"id": "1", "content": "test", "score": 0.8, "metadata": {"kb_id": "kb-allowed-1"}} + ] + + result = await tool.execute( + query="测试查询", + tenant_id="tenant-1", + step_kb_config=step_config, + ) + + # 验证检索调用时传入了正确的 kb_ids + call_args = mock_retrieve.call_args + assert call_args[1]['step_kb_config'] == step_config + + # 验证返回结果包含 step_kb_binding 信息 + assert result.step_kb_binding is not None + assert result.step_kb_binding['allowed_kb_ids'] == ["kb-allowed-1", "kb-allowed-2"] + + @pytest.mark.asyncio + async def test_kb_search_without_step_config(self): + """测试未配置步骤KB时的回退行为""" + from app.services.mid.kb_search_dynamic_tool import ( + KbSearchDynamicTool, + KbSearchDynamicConfig, + ) + + mock_session = MagicMock() + mock_timeout_governor = MagicMock() 
+ + tool = KbSearchDynamicTool( + session=mock_session, + timeout_governor=mock_timeout_governor, + config=KbSearchDynamicConfig(enabled=True), + ) + + with patch.object(tool, '_do_retrieve', new_callable=AsyncMock) as mock_retrieve: + mock_retrieve.return_value = [ + {"id": "1", "content": "test", "score": 0.8, "metadata": {}} + ] + + result = await tool.execute( + query="测试查询", + tenant_id="tenant-1", + ) + + # 验证检索调用时未传入 step_kb_config + call_args = mock_retrieve.call_args + assert call_args[1]['step_kb_config'] is None + + # 验证返回结果不包含 step_kb_binding + assert result.step_kb_binding is None + + @pytest.mark.asyncio + async def test_kb_search_result_includes_used_kb_ids(self): + """测试检索结果包含实际使用的知识库ID""" + from app.services.mid.kb_search_dynamic_tool import ( + KbSearchDynamicTool, + KbSearchDynamicConfig, + StepKbConfig, + ) + + mock_session = MagicMock() + mock_timeout_governor = MagicMock() + + tool = KbSearchDynamicTool( + session=mock_session, + timeout_governor=mock_timeout_governor, + config=KbSearchDynamicConfig(enabled=True), + ) + + step_config = StepKbConfig( + allowed_kb_ids=["kb-1", "kb-2"], + step_id="test_step", + ) + + with patch.object(tool, '_do_retrieve', new_callable=AsyncMock) as mock_retrieve: + mock_retrieve.return_value = [ + {"id": "1", "content": "test1", "score": 0.9, "metadata": {"kb_id": "kb-1"}}, + {"id": "2", "content": "test2", "score": 0.8, "metadata": {"kb_id": "kb-1"}}, + {"id": "3", "content": "test3", "score": 0.7, "metadata": {"kb_id": "kb-2"}}, + ] + + result = await tool.execute( + query="测试查询", + tenant_id="tenant-1", + step_kb_config=step_config, + ) + + # 验证 used_kb_ids 包含所有命中的知识库 + assert result.step_kb_binding is not None + assert set(result.step_kb_binding['used_kb_ids']) == {"kb-1", "kb-2"} + assert result.step_kb_binding['kb_hit'] is True + + +class TestTraceInfoStepKbBinding: + """测试 TraceInfo 中的 step_kb_binding 字段""" + + def test_trace_info_with_step_kb_binding(self): + """测试 TraceInfo 包含 step_kb_binding 字段""" + 
from app.models.mid.schemas import TraceInfo, ExecutionMode + + trace = TraceInfo( + mode=ExecutionMode.AGENT, + step_kb_binding={ + "step_id": "flow-1_step_2", + "allowed_kb_ids": ["kb-1", "kb-2"], + "used_kb_ids": ["kb-1"], + "kb_hit": True, + }, + ) + + assert trace.step_kb_binding is not None + assert trace.step_kb_binding['step_id'] == "flow-1_step_2" + assert trace.step_kb_binding['allowed_kb_ids'] == ["kb-1", "kb-2"] + assert trace.step_kb_binding['used_kb_ids'] == ["kb-1"] + + def test_trace_info_without_step_kb_binding(self): + """测试 TraceInfo 默认不包含 step_kb_binding""" + from app.models.mid.schemas import TraceInfo, ExecutionMode + + trace = TraceInfo(mode=ExecutionMode.AGENT) + + assert trace.step_kb_binding is None + + +class TestFlowStepKbBindingIntegration: + """测试流程步骤与KB绑定的集成""" + + def test_script_flow_steps_with_kb_binding(self): + """测试 ScriptFlow 的 steps 包含 KB 绑定配置""" + from app.models.entities import ScriptFlowCreate + + flow_create = ScriptFlowCreate( + name="测试流程", + steps=[ + { + "step_no": 1, + "content": "步骤1", + "allowed_kb_ids": ["kb-1"], + "preferred_kb_ids": None, + "kb_query_hint": "查找产品信息", + "max_kb_calls_per_step": 2, + }, + { + "step_no": 2, + "content": "步骤2", + # 不配置 KB 绑定 + }, + ], + ) + + assert flow_create.steps[0]['allowed_kb_ids'] == ["kb-1"] + assert flow_create.steps[0]['kb_query_hint'] == "查找产品信息" + assert flow_create.steps[1].get('allowed_kb_ids') is None + + +class TestKbBindingLogging: + """测试 KB 绑定的日志记录""" + + @pytest.mark.asyncio + async def test_step_kb_config_logging(self, caplog): + """测试步骤KB配置的日志记录""" + import logging + from app.services.mid.kb_search_dynamic_tool import ( + KbSearchDynamicTool, + KbSearchDynamicConfig, + StepKbConfig, + ) + + mock_session = MagicMock() + mock_timeout_governor = MagicMock() + + tool = KbSearchDynamicTool( + session=mock_session, + timeout_governor=mock_timeout_governor, + config=KbSearchDynamicConfig(enabled=True), + ) + + step_config = StepKbConfig( + allowed_kb_ids=["kb-1"], + 
step_id="flow-1_step_1", + ) + + with patch.object(tool, '_do_retrieve', new_callable=AsyncMock) as mock_retrieve: + mock_retrieve.return_value = [] + + with caplog.at_level(logging.INFO): + await tool.execute( + query="测试", + tenant_id="tenant-1", + step_kb_config=step_config, + ) + + # 验证日志包含 Step-KB-Binding 标记 + assert any("Step-KB-Binding" in record.message for record in caplog.records) + + +# 运行测试的入口 +if __name__ == "__main__": + pytest.main([__file__, "-v"])