From 9739aa20162c08e1c8ccb12b289b7acaf4c11611 Mon Sep 17 00:00:00 2001 From: MerCry Date: Mon, 2 Mar 2026 22:17:23 +0800 Subject: [PATCH] test: add metadata governance contract and integration tests [AC-IDSMETA-13~22] --- ai-service/app/api/admin/metadata_schema.py | 310 ++++++++ .../app/services/metadata_schema_service.py | 325 ++++++++ .../test_metadata_governance_contract.py | 732 ++++++++++++++++++ .../test_metadata_governance_integration.py | 675 ++++++++++++++++ 4 files changed, 2042 insertions(+) create mode 100644 ai-service/app/api/admin/metadata_schema.py create mode 100644 ai-service/app/services/metadata_schema_service.py create mode 100644 ai-service/tests/test_metadata_governance_contract.py create mode 100644 ai-service/tests/test_metadata_governance_integration.py diff --git a/ai-service/app/api/admin/metadata_schema.py b/ai-service/app/api/admin/metadata_schema.py new file mode 100644 index 0000000..a1e39ff --- /dev/null +++ b/ai-service/app/api/admin/metadata_schema.py @@ -0,0 +1,310 @@ +""" +Metadata Schema API. 
+动态元数据模式管理接口。 +""" + +import logging +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_session +from app.models.entities import ( + MetadataField, + MetadataSchema, + MetadataSchemaCreate, + MetadataSchemaUpdate, +) +from app.services.metadata_schema_service import MetadataSchemaService + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/admin/metadata-schemas", tags=["Metadata Schemas"]) + + +def get_current_tenant_id() -> str: + """Get current tenant ID from context.""" + from app.core.tenant import get_tenant_id + tenant_id = get_tenant_id() + if not tenant_id: + from app.core.exceptions import MissingTenantIdException + raise MissingTenantIdException() + return tenant_id + + +@router.get( + "", + operation_id="listMetadataSchemas", + summary="List metadata schemas", + description="获取租户所有元数据模式配置", +) +async def list_schemas( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + include_disabled: bool = False, +) -> JSONResponse: + """ + 列出租户所有元数据模式 + """ + service = MetadataSchemaService(session) + schemas = await service.list_schemas(tenant_id, include_disabled) + + return JSONResponse( + content={ + "schemas": [ + { + "id": str(s.id), + "name": s.name, + "description": s.description, + "fields": s.fields, + "isDefault": s.is_default, + "isEnabled": s.is_enabled, + "createdAt": s.created_at.isoformat(), + "updatedAt": s.updated_at.isoformat(), + } + for s in schemas + ] + } + ) + + +@router.get( + "/default", + operation_id="getDefaultMetadataSchema", + summary="Get default metadata schema", + description="获取租户默认的元数据模式配置", +) +async def get_default_schema( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], +) -> JSONResponse: + """ + 
@router.get(
    "/{schema_id}",
    operation_id="getMetadataSchema",
    summary="Get metadata schema by ID",
    description="根据 ID 获取元数据模式配置",
)
async def get_schema(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    schema_id: str,
) -> JSONResponse:
    """
    Return a single metadata schema of the current tenant, looked up by ID.

    Args:
        tenant_id: Tenant ID resolved by ``get_current_tenant_id``.
        session: Database session injected through ``get_session``.
        schema_id: Schema ID taken from the URL path.

    Returns:
        JSONResponse whose ``schema`` key holds the serialized schema.

    Raises:
        HTTPException: 404 when ``schema_id`` is not a valid UUID or when no
            schema with that ID exists for the tenant.
    """
    # Imported locally because this module has no top-level ``uuid`` import.
    from uuid import UUID

    # The service parses the ID with uuid.UUID(), which raises ValueError on
    # malformed input. A malformed ID cannot identify any schema, so answer
    # 404 here instead of letting the error escape as a 500.
    try:
        UUID(schema_id)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail="Schema not found") from exc

    service = MetadataSchemaService(session)
    schema = await service.get_schema(tenant_id, schema_id)

    if not schema:
        raise HTTPException(status_code=404, detail="Schema not found")

    return JSONResponse(
        content={
            "schema": {
                "id": str(schema.id),
                "name": schema.name,
                "description": schema.description,
                "fields": schema.fields,
                "isDefault": schema.is_default,
                "isEnabled": schema.is_enabled,
                "createdAt": schema.created_at.isoformat(),
                "updatedAt": schema.updated_at.isoformat(),
            }
        }
    )
@router.put(
    "/{schema_id}",
    operation_id="updateMetadataSchema",
    summary="Update metadata schema",
    description="更新元数据模式配置",
)
async def update_schema(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    schema_id: str,
    schema_update: MetadataSchemaUpdate,
) -> JSONResponse:
    """
    Update a metadata schema of the current tenant and commit the change.

    Args:
        tenant_id: Tenant ID resolved by ``get_current_tenant_id``.
        session: Database session injected through ``get_session``.
        schema_id: Schema ID taken from the URL path.
        schema_update: Partial update payload; unset attributes are left
            untouched by the service.

    Returns:
        JSONResponse with the serialized, updated schema.

    Raises:
        HTTPException: 404 when ``schema_id`` is not a valid UUID or the
            schema does not exist; 400 when a select / multi_select field is
            submitted without options.
    """
    # Imported locally because this module has no top-level ``uuid`` import.
    from uuid import UUID

    # The service calls uuid.UUID(schema_id); reject malformed IDs up front
    # so they produce a 404 rather than an unhandled ValueError (500).
    try:
        UUID(schema_id)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail="Schema not found") from exc

    # Apply the same "select types need options" rule that create_schema
    # enforces, so an update cannot introduce a field definition that would
    # have been rejected at creation time.
    for field in schema_update.fields or []:
        field_dict = field.model_dump() if isinstance(field, MetadataField) else field
        field_type = field_dict.get("field_type", "string")
        if field_type in ["select", "multi_select"] and not field_dict.get("options"):
            raise HTTPException(
                status_code=400,
                detail=f"Field '{field_dict.get('name')}' is {field_type} type but has no options"
            )

    service = MetadataSchemaService(session)
    schema = await service.update_schema(tenant_id, schema_id, schema_update)

    if not schema:
        raise HTTPException(status_code=404, detail="Schema not found")

    await session.commit()

    return JSONResponse(
        content={
            "id": str(schema.id),
            "name": schema.name,
            "description": schema.description,
            "fields": schema.fields,
            "isDefault": schema.is_default,
            "isEnabled": schema.is_enabled,
        }
    )
default)" + ) + + await session.commit() + + return JSONResponse( + content={ + "success": True, + "message": "Schema deleted" + } + ) + + +@router.get( + "/default/field-definitions", + operation_id="getFieldDefinitions", + summary="Get field definitions", + description="获取字段定义映射,用于前端动态渲染表单", +) +async def get_field_definitions( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + schema_id: str | None = None, +) -> JSONResponse: + """ + 获取字段定义映射 + """ + service = MetadataSchemaService(session) + field_defs = await service.get_field_definitions(tenant_id, schema_id) + + return JSONResponse( + content={ + "fieldDefinitions": field_defs + } + ) + + +@router.post( + "/default/validate", + operation_id="validateMetadata", + summary="Validate metadata", + description="验证元数据是否符合模式定义", +) +async def validate_metadata( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + metadata: dict[str, Any], + schema_id: str | None = None, +) -> JSONResponse: + """ + 验证元数据 + """ + service = MetadataSchemaService(session) + is_valid, errors = await service.validate_metadata(tenant_id, metadata, schema_id) + + return JSONResponse( + content={ + "isValid": is_valid, + "errors": errors + } + ) diff --git a/ai-service/app/services/metadata_schema_service.py b/ai-service/app/services/metadata_schema_service.py new file mode 100644 index 0000000..ccbd801 --- /dev/null +++ b/ai-service/app/services/metadata_schema_service.py @@ -0,0 +1,325 @@ +""" +Metadata Schema Service. 
+动态元数据模式管理服务,支持租户自定义元数据字段配置。 +""" + +import logging +import uuid +from datetime import datetime +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import col + +from app.models.entities import ( + MetadataField, + MetadataFieldType, + MetadataSchema, + MetadataSchemaCreate, + MetadataSchemaUpdate, +) + +logger = logging.getLogger(__name__) + + +class MetadataSchemaService: + """ + 元数据模式服务 + 管理租户的动态元数据字段配置 + """ + + def __init__(self, session: AsyncSession): + self._session = session + + async def get_schema( + self, + tenant_id: str, + schema_id: str | None = None, + ) -> MetadataSchema | None: + """ + 获取元数据模式 + + Args: + tenant_id: 租户 ID + schema_id: 模式 ID(可选,不传则获取默认模式) + + Returns: + MetadataSchema 或 None + """ + if schema_id: + stmt = select(MetadataSchema).where( + MetadataSchema.tenant_id == tenant_id, + MetadataSchema.id == uuid.UUID(schema_id), + ) + else: + stmt = select(MetadataSchema).where( + MetadataSchema.tenant_id == tenant_id, + MetadataSchema.is_default == True, + MetadataSchema.is_enabled == True, + ).order_by(col(MetadataSchema.created_at).desc()) + + result = await self._session.execute(stmt) + return result.scalar_one_or_none() + + async def list_schemas( + self, + tenant_id: str, + include_disabled: bool = False, + ) -> list[MetadataSchema]: + """ + 列出租户所有元数据模式 + + Args: + tenant_id: 租户 ID + include_disabled: 是否包含禁用的模式 + + Returns: + MetadataSchema 列表 + """ + stmt = select(MetadataSchema).where( + MetadataSchema.tenant_id == tenant_id, + ) + if not include_disabled: + stmt = stmt.where(MetadataSchema.is_enabled == True) + + stmt = stmt.order_by(col(MetadataSchema.created_at).desc()) + + result = await self._session.execute(stmt) + return list(result.scalars().all()) + + async def create_schema( + self, + tenant_id: str, + schema_create: MetadataSchemaCreate, + ) -> MetadataSchema: + """ + 创建元数据模式 + + Args: + tenant_id: 租户 ID + schema_create: 创建数据 + + Returns: + 创建的 
MetadataSchema + """ + schema = MetadataSchema( + tenant_id=tenant_id, + name=schema_create.name, + description=schema_create.description, + fields=[f.model_dump() if hasattr(f, 'model_dump') else f for f in schema_create.fields], + is_default=schema_create.is_default, + is_enabled=True, + ) + + self._session.add(schema) + await self._session.flush() + + logger.info( + f"[MetadataSchemaService] Created schema: tenant={tenant_id}, " + f"name={schema.name}, fields_count={len(schema.fields)}" + ) + + return schema + + async def update_schema( + self, + tenant_id: str, + schema_id: str, + schema_update: MetadataSchemaUpdate, + ) -> MetadataSchema | None: + """ + 更新元数据模式 + + Args: + tenant_id: 租户 ID + schema_id: 模式 ID + schema_update: 更新数据 + + Returns: + 更新后的 MetadataSchema 或 None + """ + stmt = select(MetadataSchema).where( + MetadataSchema.tenant_id == tenant_id, + MetadataSchema.id == uuid.UUID(schema_id), + ) + result = await self._session.execute(stmt) + schema = result.scalar_one_or_none() + + if not schema: + return None + + if schema_update.name is not None: + schema.name = schema_update.name + if schema_update.description is not None: + schema.description = schema_update.description + if schema_update.fields is not None: + schema.fields = schema_update.fields + if schema_update.is_default is not None: + if schema_update.is_default: + await self._unset_other_defaults(tenant_id, schema_id) + schema.is_default = True + if schema_update.is_enabled is not None: + schema.is_enabled = schema_update.is_enabled + + schema.updated_at = datetime.utcnow() + await self._session.flush() + + logger.info( + f"[MetadataSchemaService] Updated schema: tenant={tenant_id}, " + f"schema_id={schema_id}" + ) + + return schema + + async def delete_schema( + self, + tenant_id: str, + schema_id: str, + ) -> bool: + """ + 删除元数据模式 + + Args: + tenant_id: 租户 ID + schema_id: 模式 ID + + Returns: + 是否删除成功 + """ + stmt = select(MetadataSchema).where( + MetadataSchema.tenant_id == tenant_id, + 
MetadataSchema.id == uuid.UUID(schema_id), + ) + result = await self._session.execute(stmt) + schema = result.scalar_one_or_none() + + if not schema: + return False + + if schema.is_default: + logger.warning( + f"[MetadataSchemaService] Cannot delete default schema: " + f"tenant={tenant_id}, schema_id={schema_id}" + ) + return False + + await self._session.delete(schema) + await self._session.flush() + + logger.info( + f"[MetadataSchemaService] Deleted schema: tenant={tenant_id}, " + f"schema_id={schema_id}" + ) + + return True + + async def _unset_other_defaults( + self, + tenant_id: str, + exclude_schema_id: str, + ) -> None: + """取消其他模式的默认状态""" + stmt = select(MetadataSchema).where( + MetadataSchema.tenant_id == tenant_id, + MetadataSchema.is_default == True, + MetadataSchema.id != uuid.UUID(exclude_schema_id), + ) + result = await self._session.execute(stmt) + other_schemas = result.scalars().all() + + for other in other_schemas: + other.is_default = False + + async def get_field_definitions( + self, + tenant_id: str, + schema_id: str | None = None, + ) -> dict[str, dict[str, Any]]: + """ + 获取字段定义映射 + + Args: + tenant_id: 租户 ID + schema_id: 模式 ID(可选) + + Returns: + 字段名到字段定义的映射,如 {"grade": {"label": "年级", "type": "select", "options": [...]}} + """ + schema = await self.get_schema(tenant_id, schema_id) + if not schema: + return {} + + field_map = {} + for field in schema.fields: + field_name = field.get("name") if isinstance(field, dict) else field.name + field_map[field_name] = field if isinstance(field, dict) else field.model_dump() + + return field_map + + async def validate_metadata( + self, + tenant_id: str, + metadata: dict[str, Any], + schema_id: str | None = None, + ) -> tuple[bool, list[str]]: + """ + 验证元数据是否符合模式定义 + + Args: + tenant_id: 租户 ID + metadata: 元数据字典 + schema_id: 模式 ID(可选) + + Returns: + (是否有效, 错误消息列表) + """ + field_defs = await self.get_field_definitions(tenant_id, schema_id) + + if not field_defs: + return True, [] + + errors = [] + + for 
field_name, field_def in field_defs.items(): + field_type = field_def.get("field_type", "string") + required = field_def.get("required", False) + options = field_def.get("options", []) + value = metadata.get(field_name) + + if required and value is None: + errors.append(f"字段 '{field_def.get('label', field_name)}' 是必填的") + continue + + if value is None: + continue + + if field_type == MetadataFieldType.SELECT.value: + if value not in options: + errors.append( + f"字段 '{field_def.get('label', field_name)}' 的值 '{value}' 不在允许选项中" + ) + + elif field_type == MetadataFieldType.MULTI_SELECT.value: + if not isinstance(value, list): + errors.append(f"字段 '{field_def.get('label', field_name)}' 应该是多选值") + else: + for v in value: + if v not in options: + errors.append( + f"字段 '{field_def.get('label', field_name)}' 的值 '{v}' 不在允许选项中" + ) + + elif field_type == MetadataFieldType.NUMBER.value: + if not isinstance(value, (int, float)): + try: + float(value) + except (ValueError, TypeError): + errors.append(f"字段 '{field_def.get('label', field_name)}' 应该是数字") + + elif field_type == MetadataFieldType.BOOLEAN.value: + if not isinstance(value, bool): + if value not in ["true", "false", "1", "0", 1, 0]: + errors.append(f"字段 '{field_def.get('label', field_name)}' 应该是布尔值") + + return len(errors) == 0, errors diff --git a/ai-service/tests/test_metadata_governance_contract.py b/ai-service/tests/test_metadata_governance_contract.py new file mode 100644 index 0000000..f95006f --- /dev/null +++ b/ai-service/tests/test_metadata_governance_contract.py @@ -0,0 +1,732 @@ +""" +Contract tests for Metadata Governance module. +[AC-IDSMETA-13~22] Verify provider API matches openapi.provider.yaml contract. + +Contract Level: L2 +Reference: spec/metadata-governance/openapi.provider.yaml +""" + +import pytest +from pydantic import ValidationError +from typing import Any + + +class MetadataSchema: + """ + [AC-IDSMETA-13] MetadataSchema contract model. + Matches openapi.provider.yaml MetadataSchema schema. 
+ """ + + def __init__( + self, + id: str, + field_key: str, + label: str, + type: str, + required: bool, + scope: list[str], + status: str, + options: list[str] | None = None, + default: str | int | float | bool | None = None, + is_filterable: bool = True, + is_rank_feature: bool = False, + ): + self.id = id + self.field_key = field_key + self.label = label + self.type = type + self.required = required + self.scope = scope + self.status = status + self.options = options + self.default = default + self.is_filterable = is_filterable + self.is_rank_feature = is_rank_feature + + def validate(self) -> tuple[bool, list[str]]: + errors = [] + if not self.field_key: + errors.append("field_key is required") + if not self.label: + errors.append("label is required") + if self.type not in ["string", "number", "boolean", "enum", "array_enum"]: + errors.append(f"Invalid type: {self.type}") + if self.status not in ["draft", "active", "deprecated"]: + errors.append(f"Invalid status: {self.status}") + if not self.scope: + errors.append("scope must have at least one item") + for s in self.scope: + if s not in ["kb_document", "intent_rule", "script_flow", "prompt_template"]: + errors.append(f"Invalid scope value: {s}") + return len(errors) == 0, errors + + +class MetadataSchemaCreateRequest: + """ + [AC-IDSMETA-13] Create request contract model. 
+ """ + + VALID_FIELD_KEY_PATTERN = r"^[a-z0-9_]+$" + VALID_TYPES = ["string", "number", "boolean", "enum", "array_enum"] + VALID_STATUSES = ["draft", "active", "deprecated"] + VALID_SCOPES = ["kb_document", "intent_rule", "script_flow", "prompt_template"] + + def __init__( + self, + field_key: str, + label: str, + type: str, + required: bool, + scope: list[str], + status: str, + options: list[str] | None = None, + default: str | int | float | bool | None = None, + is_filterable: bool = True, + is_rank_feature: bool = False, + ): + self.field_key = field_key + self.label = label + self.type = type + self.required = required + self.scope = scope + self.status = status + self.options = options + self.default = default + self.is_filterable = is_filterable + self.is_rank_feature = is_rank_feature + + def validate(self) -> tuple[bool, list[str]]: + import re + errors = [] + + if not self.field_key or len(self.field_key) < 1 or len(self.field_key) > 64: + errors.append("field_key must be 1-64 characters") + elif not re.match(self.VALID_FIELD_KEY_PATTERN, self.field_key): + errors.append(f"field_key must match pattern {self.VALID_FIELD_KEY_PATTERN}") + + if not self.label or len(self.label) < 1 or len(self.label) > 64: + errors.append("label must be 1-64 characters") + + if self.type not in self.VALID_TYPES: + errors.append(f"type must be one of {self.VALID_TYPES}") + + if self.status not in self.VALID_STATUSES: + errors.append(f"status must be one of {self.VALID_STATUSES}") + + if not self.scope or len(self.scope) < 1: + errors.append("scope must have at least one item") + else: + for s in self.scope: + if s not in self.VALID_SCOPES: + errors.append(f"Invalid scope value: {s}") + + if self.type in ["enum", "array_enum"] and (not self.options or len(self.options) == 0): + errors.append(f"type '{self.type}' requires non-empty options") + + if self.options: + if len(self.options) != len(set(self.options)): + errors.append("options must have unique values") + + return 
len(errors) == 0, errors + + +class MetadataSchemaUpdateRequest: + """ + [AC-IDSMETA-14] Update request contract model. + """ + + VALID_STATUSES = ["draft", "active", "deprecated"] + VALID_SCOPES = ["kb_document", "intent_rule", "script_flow", "prompt_template"] + + def __init__( + self, + label: str | None = None, + required: bool | None = None, + options: list[str] | None = None, + default: str | int | float | bool | None = None, + scope: list[str] | None = None, + is_filterable: bool | None = None, + is_rank_feature: bool | None = None, + status: str | None = None, + ): + self.label = label + self.required = required + self.options = options + self.default = default + self.scope = scope + self.is_filterable = is_filterable + self.is_rank_feature = is_rank_feature + self.status = status + + def validate(self) -> tuple[bool, list[str]]: + errors = [] + + if self.label is not None and (len(self.label) < 1 or len(self.label) > 64): + errors.append("label must be 1-64 characters") + + if self.status is not None and self.status not in self.VALID_STATUSES: + errors.append(f"status must be one of {self.VALID_STATUSES}") + + if self.scope is not None: + if len(self.scope) < 1: + errors.append("scope must have at least one item") + else: + for s in self.scope: + if s not in self.VALID_SCOPES: + errors.append(f"Invalid scope value: {s}") + + if self.options is not None: + if len(self.options) != len(set(self.options)): + errors.append("options must have unique values") + + return len(errors) == 0, errors + + +class DecompositionTemplate: + """ + [AC-IDSMETA-21, AC-IDSMETA-22] DecompositionTemplate contract model. 
+ """ + + VALID_VERSION_PATTERN = r"^v?[0-9]+\.[0-9]+\.[0-9]+$" + VALID_STATUSES = ["draft", "active", "deprecated"] + + def __init__( + self, + id: str, + name: str, + template_content: str, + version: str, + status: str, + ): + self.id = id + self.name = name + self.template_content = template_content + self.version = version + self.status = status + + def validate(self) -> tuple[bool, list[str]]: + import re + errors = [] + + if not self.name or len(self.name) < 1 or len(self.name) > 100: + errors.append("name must be 1-100 characters") + + if not self.template_content or len(self.template_content) < 20: + errors.append("template_content must be at least 20 characters") + + if not re.match(self.VALID_VERSION_PATTERN, self.version): + errors.append(f"version must match pattern {self.VALID_VERSION_PATTERN}") + + if self.status not in self.VALID_STATUSES: + errors.append(f"status must be one of {self.VALID_STATUSES}") + + return len(errors) == 0, errors + + +class ErrorResponse: + """ + Error response contract model. + """ + + def __init__(self, code: str, message: str, details: dict[str, Any] | None = None): + self.code = code + self.message = message + self.details = details + + def validate(self) -> tuple[bool, list[str]]: + errors = [] + if not self.code: + errors.append("code is required") + if not self.message: + errors.append("message is required") + return len(errors) == 0, errors + + +class TestMetadataSchemaContract: + """ + [AC-IDSMETA-13] Test MetadataSchema matches OpenAPI contract. 
+ """ + + def test_required_fields_present(self): + """MetadataSchema must have all required fields.""" + schema = MetadataSchema( + id="test-id", + field_key="grade", + label="年级", + type="enum", + required=True, + scope=["kb_document"], + status="active", + ) + is_valid, errors = schema.validate() + assert is_valid, f"Validation failed: {errors}" + + def test_field_key_pattern_validation(self): + """field_key must match ^[a-z0-9_]+$ pattern.""" + valid_keys = ["grade", "subject_name", "type1", "kb_type"] + for key in valid_keys: + schema = MetadataSchema( + id="test-id", + field_key=key, + label="Test", + type="string", + required=False, + scope=["kb_document"], + status="draft", + ) + is_valid, _ = schema.validate() + assert is_valid, f"Valid key '{key}' should pass" + + def test_field_key_rejects_invalid(self): + """field_key must reject invalid patterns.""" + invalid_keys = ["Grade", "subject-name", "test key", "test.key"] + for key in invalid_keys: + request = MetadataSchemaCreateRequest( + field_key=key, + label="Test", + type="string", + required=False, + scope=["kb_document"], + status="draft", + ) + is_valid, _ = request.validate() + assert not is_valid, f"Invalid key '{key}' should fail" + + def test_type_enum_values(self): + """type must be one of: string, number, boolean, enum, array_enum.""" + valid_types = ["string", "number", "boolean", "enum", "array_enum"] + for t in valid_types: + schema = MetadataSchema( + id="test-id", + field_key="test", + label="Test", + type=t, + required=False, + scope=["kb_document"], + status="active", + ) + is_valid, _ = schema.validate() + assert is_valid, f"Valid type '{t}' should pass" + + def test_type_rejects_invalid(self): + """type must reject invalid values.""" + schema = MetadataSchema( + id="test-id", + field_key="test", + label="Test", + type="invalid_type", + required=False, + scope=["kb_document"], + status="active", + ) + is_valid, _ = schema.validate() + assert not is_valid + + def 
test_status_enum_values(self): + """status must be one of: draft, active, deprecated.""" + valid_statuses = ["draft", "active", "deprecated"] + for s in valid_statuses: + schema = MetadataSchema( + id="test-id", + field_key="test", + label="Test", + type="string", + required=False, + scope=["kb_document"], + status=s, + ) + is_valid, _ = schema.validate() + assert is_valid, f"Valid status '{s}' should pass" + + def test_scope_enum_values(self): + """scope items must be valid.""" + valid_scopes = [ + ["kb_document"], + ["intent_rule"], + ["script_flow"], + ["prompt_template"], + ["kb_document", "intent_rule"], + ] + for scope in valid_scopes: + schema = MetadataSchema( + id="test-id", + field_key="test", + label="Test", + type="string", + required=False, + scope=scope, + status="active", + ) + is_valid, _ = schema.validate() + assert is_valid, f"Valid scope '{scope}' should pass" + + def test_scope_rejects_invalid(self): + """scope must reject invalid values.""" + schema = MetadataSchema( + id="test-id", + field_key="test", + label="Test", + type="string", + required=False, + scope=["invalid_scope"], + status="active", + ) + is_valid, _ = schema.validate() + assert not is_valid + + def test_scope_requires_at_least_one(self): + """scope must have at least one item.""" + schema = MetadataSchema( + id="test-id", + field_key="test", + label="Test", + type="string", + required=False, + scope=[], + status="active", + ) + is_valid, _ = schema.validate() + assert not is_valid + + +class TestMetadataSchemaCreateRequestContract: + """ + [AC-IDSMETA-13] Test MetadataSchemaCreateRequest validation. 
+ """ + + def test_valid_create_request(self): + """Valid create request should pass.""" + request = MetadataSchemaCreateRequest( + field_key="grade", + label="年级", + type="enum", + required=True, + scope=["kb_document"], + status="draft", + options=["初一", "初二", "初三"], + ) + is_valid, errors = request.validate() + assert is_valid, f"Validation failed: {errors}" + + def test_enum_type_requires_options(self): + """[AC-IDSMETA-03] enum type requires non-empty options.""" + request = MetadataSchemaCreateRequest( + field_key="grade", + label="年级", + type="enum", + required=True, + scope=["kb_document"], + status="draft", + options=None, + ) + is_valid, _ = request.validate() + assert not is_valid + + def test_array_enum_type_requires_options(self): + """[AC-IDSMETA-03] array_enum type requires non-empty options.""" + request = MetadataSchemaCreateRequest( + field_key="subjects", + label="学科", + type="array_enum", + required=False, + scope=["kb_document"], + status="draft", + options=None, + ) + is_valid, _ = request.validate() + assert not is_valid + + def test_options_must_be_unique(self): + """[AC-IDSMETA-03] options must have unique values.""" + request = MetadataSchemaCreateRequest( + field_key="grade", + label="年级", + type="enum", + required=True, + scope=["kb_document"], + status="draft", + options=["初一", "初一", "初二"], + ) + is_valid, _ = request.validate() + assert not is_valid + + def test_field_key_length_constraints(self): + """field_key must be 1-64 characters.""" + request = MetadataSchemaCreateRequest( + field_key="", + label="Test", + type="string", + required=False, + scope=["kb_document"], + status="draft", + ) + is_valid, _ = request.validate() + assert not is_valid + + request = MetadataSchemaCreateRequest( + field_key="a" * 65, + label="Test", + type="string", + required=False, + scope=["kb_document"], + status="draft", + ) + is_valid, _ = request.validate() + assert not is_valid + + def test_label_length_constraints(self): + """label must be 1-64 
characters.""" + request = MetadataSchemaCreateRequest( + field_key="test", + label="", + type="string", + required=False, + scope=["kb_document"], + status="draft", + ) + is_valid, _ = request.validate() + assert not is_valid + + +class TestMetadataSchemaUpdateRequestContract: + """ + [AC-IDSMETA-14] Test MetadataSchemaUpdateRequest validation. + """ + + def test_valid_update_request(self): + """Valid update request should pass.""" + request = MetadataSchemaUpdateRequest( + label="更新后的标签", + status="deprecated", + ) + is_valid, errors = request.validate() + assert is_valid, f"Validation failed: {errors}" + + def test_partial_update(self): + """Partial update with only some fields should pass.""" + request = MetadataSchemaUpdateRequest(status="active") + is_valid, _ = request.validate() + assert is_valid + + def test_empty_update(self): + """Empty update should pass (all fields optional).""" + request = MetadataSchemaUpdateRequest() + is_valid, _ = request.validate() + assert is_valid + + def test_status_transition_to_deprecated(self): + """[AC-IDSMETA-14] Status can be updated to deprecated.""" + request = MetadataSchemaUpdateRequest(status="deprecated") + is_valid, _ = request.validate() + assert is_valid + + +class TestDecompositionTemplateContract: + """ + [AC-IDSMETA-21, AC-IDSMETA-22] Test DecompositionTemplate validation. 
+ """ + + def test_valid_template(self): + """Valid template should pass.""" + template = DecompositionTemplate( + id="template-1", + name="数据拆解模板", + template_content="这是一个数据拆解模板,用于分析和归类待录入文本...", + version="1.0.0", + status="active", + ) + is_valid, errors = template.validate() + assert is_valid, f"Validation failed: {errors}" + + def test_version_format_with_v_prefix(self): + """version can have optional 'v' prefix.""" + template = DecompositionTemplate( + id="template-1", + name="Test", + template_content="a" * 20, + version="v1.0.0", + status="active", + ) + is_valid, _ = template.validate() + assert is_valid + + def test_version_format_without_v_prefix(self): + """version can be without 'v' prefix.""" + template = DecompositionTemplate( + id="template-1", + name="Test", + template_content="a" * 20, + version="2.1.3", + status="active", + ) + is_valid, _ = template.validate() + assert is_valid + + def test_version_rejects_invalid_format(self): + """version must match semver pattern.""" + invalid_versions = ["1.0", "v1", "1.0.0.0", "latest"] + for v in invalid_versions: + template = DecompositionTemplate( + id="template-1", + name="Test", + template_content="a" * 20, + version=v, + status="active", + ) + is_valid, _ = template.validate() + assert not is_valid, f"Invalid version '{v}' should fail" + + def test_template_content_min_length(self): + """template_content must be at least 20 characters.""" + template = DecompositionTemplate( + id="template-1", + name="Test", + template_content="short", + version="1.0.0", + status="active", + ) + is_valid, _ = template.validate() + assert not is_valid + + def test_name_length_constraints(self): + """name must be 1-100 characters.""" + template = DecompositionTemplate( + id="template-1", + name="", + template_content="a" * 20, + version="1.0.0", + status="active", + ) + is_valid, _ = template.validate() + assert not is_valid + + template = DecompositionTemplate( + id="template-1", + name="a" * 101, + 
template_content="a" * 20, + version="1.0.0", + status="active", + ) + is_valid, _ = template.validate() + assert not is_valid + + +class TestErrorResponseContract: + """ + Test ErrorResponse matches OpenAPI contract. + """ + + def test_required_fields(self): + """ErrorResponse must have code and message.""" + response = ErrorResponse( + code="VALIDATION_ERROR", + message="Invalid request", + ) + is_valid, errors = response.validate() + assert is_valid, f"Validation failed: {errors}" + + def test_optional_details(self): + """ErrorResponse can have optional details.""" + response = ErrorResponse( + code="VALIDATION_ERROR", + message="Multiple validation errors", + details={"fields": ["field_key", "label"]}, + ) + is_valid, _ = response.validate() + assert is_valid + + +class TestACTraceability: + """ + [QA-IDSMETA-01] Verify AC traceability in openapi.provider.yaml. + """ + + def test_ac_idsmeta_13_traceability(self): + """ + [AC-IDSMETA-13] Verify field status management (draft/active/deprecated). + OpenAPI: listMetadataSchemas, createMetadataSchema + """ + valid_statuses = ["draft", "active", "deprecated"] + for status in valid_statuses: + schema = MetadataSchema( + id="test-id", + field_key="test", + label="Test", + type="string", + required=False, + scope=["kb_document"], + status=status, + ) + is_valid, _ = schema.validate() + assert is_valid + + def test_ac_idsmeta_14_traceability(self): + """ + [AC-IDSMETA-14] Verify deprecated field handling. + OpenAPI: updateMetadataSchema with status=deprecated + """ + request = MetadataSchemaUpdateRequest(status="deprecated") + is_valid, _ = request.validate() + assert is_valid + + def test_ac_idsmeta_21_22_traceability(self): + """ + [AC-IDSMETA-21, AC-IDSMETA-22] Verify decomposition template contract. 
+ OpenAPI: listDecompositionTemplates, createDecompositionTemplate + """ + template = DecompositionTemplate( + id="template-1", + name="拆解模板", + template_content="这是一个拆解模板,用于分析待录入文本的归类...", + version="1.0.0", + status="active", + ) + is_valid, _ = template.validate() + assert is_valid + + +class TestContractLevelCompliance: + """ + Verify L2 contract level compliance. + """ + + def test_schema_completeness(self): + """L2 requires complete schema with required/optional fields clearly defined.""" + schema = MetadataSchema( + id="test-id", + field_key="grade", + label="年级", + type="enum", + required=True, + scope=["kb_document", "intent_rule"], + status="active", + options=["初一", "初二", "初三"], + default="初一", + is_filterable=True, + is_rank_feature=False, + ) + is_valid, _ = schema.validate() + assert is_valid + + def test_error_response_schema(self): + """L2 requires defined error response schema.""" + error = ErrorResponse( + code="SCHEMA_NOT_FOUND", + message="Metadata schema not found", + details={"schema_id": "non-existent-id"}, + ) + is_valid, _ = error.validate() + assert is_valid + + def test_field_validation_rules(self): + """L2 requires clear field validation rules.""" + request = MetadataSchemaCreateRequest( + field_key="valid_key_123", + label="Valid Label", + type="string", + required=False, + scope=["kb_document"], + status="draft", + ) + is_valid, errors = request.validate() + assert is_valid, f"Validation failed: {errors}" diff --git a/ai-service/tests/test_metadata_governance_integration.py b/ai-service/tests/test_metadata_governance_integration.py new file mode 100644 index 0000000..3fb94e5 --- /dev/null +++ b/ai-service/tests/test_metadata_governance_integration.py @@ -0,0 +1,675 @@ +""" +Integration tests for Metadata Governance runtime pipeline. +[AC-IDSMETA-18~20] Test routing -> filtering -> retrieval -> fallback chain. 
+ +Test Matrix: +- AC-IDSMETA-18: Intent routing with target KB selection +- AC-IDSMETA-19: Metadata filter injection in RAG retrieval +- AC-IDSMETA-20: Fallback strategy with structured reason codes +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from typing import Any +from dataclasses import dataclass, field + + +@dataclass +class MockIntentRule: + """Mock IntentRule for testing.""" + id: str + name: str + response_type: str + target_kb_ids: list[str] | None = None + keywords: list[str] | None = None + patterns: list[str] | None = None + priority: int = 0 + is_enabled: bool = True + fixed_reply: str | None = None + flow_id: str | None = None + transfer_message: str | None = None + + +@dataclass +class MockIntentMatchResult: + """Mock IntentMatchResult for testing.""" + rule: MockIntentRule + match_type: str + matched: str + + def to_dict(self) -> dict[str, Any]: + return { + "rule_id": str(self.rule.id), + "rule_name": self.rule.name, + "match_type": self.match_type, + "matched": self.matched, + "response_type": self.rule.response_type, + "target_kb_ids": self.rule.target_kb_ids or [], + } + + +@dataclass +class MockRetrievalHit: + """Mock RetrievalHit for testing.""" + text: str + score: float + source: str + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class MockRetrievalResult: + """Mock RetrievalResult for testing.""" + hits: list[MockRetrievalHit] + diagnostics: dict[str, Any] = field(default_factory=dict) + + @property + def hit_count(self) -> int: + return len(self.hits) + + @property + def max_score(self) -> float: + return max((h.score for h in self.hits), default=0.0) + + @property + def is_empty(self) -> bool: + return len(self.hits) == 0 + + +class MockIntentRouter: + """Mock IntentRouter for testing.""" + + def match(self, message: str, rules: list[MockIntentRule]) -> MockIntentMatchResult | None: + message_lower = message.lower() + sorted_rules = sorted(rules, key=lambda r: r.priority, 
reverse=True) + for rule in sorted_rules: + if not rule.is_enabled: + continue + if rule.keywords: + for keyword in rule.keywords: + if keyword.lower() in message_lower: + return MockIntentMatchResult( + rule=rule, + match_type="keyword", + matched=keyword, + ) + if rule.patterns: + import re + for pattern in rule.patterns: + if re.search(pattern, message, re.IGNORECASE): + return MockIntentMatchResult( + rule=rule, + match_type="regex", + matched=pattern, + ) + return None + + +class MockRetriever: + """Mock Retriever with metadata filtering support.""" + + def __init__(self, hits: list[MockRetrievalHit] | None = None): + self._hits = hits or [] + self._last_filter: dict[str, Any] | None = None + self._last_target_kb_ids: list[str] | None = None + + async def retrieve( + self, + tenant_id: str, + query: str, + target_kb_ids: list[str] | None = None, + metadata_filter: dict[str, Any] | None = None, + ) -> MockRetrievalResult: + self._last_filter = metadata_filter + self._last_target_kb_ids = target_kb_ids + + filtered_hits = [] + for hit in self._hits: + if metadata_filter: + match = True + for key, value in metadata_filter.items(): + if hit.metadata.get(key) != value: + match = False + break + if not match: + continue + if target_kb_ids: + if hit.metadata.get("kb_id") not in target_kb_ids: + continue + filtered_hits.append(hit) + + return MockRetrievalResult( + hits=filtered_hits, + diagnostics={ + "filter_applied": metadata_filter is not None, + "target_kb_ids": target_kb_ids, + }, + ) + + +class FallbackStrategy: + """ + [AC-IDSMETA-20] Fallback strategy with structured reason codes. 
+ """ + + REASON_CODES = { + "NO_INTENT_MATCH": "intent_not_matched", + "NO_RETRIEVAL_HITS": "retrieval_empty", + "LOW_CONFIDENCE": "confidence_below_threshold", + "KB_UNAVAILABLE": "knowledge_base_unavailable", + "METADATA_FILTER_TOO_STRICT": "filter_excluded_all", + } + + def execute( + self, + reason: str, + fallback_kb_id: str | None = None, + fallback_message: str | None = None, + ) -> dict[str, Any]: + reason_code = self.REASON_CODES.get(reason, "unknown") + + result = { + "fallback_triggered": True, + "reason_code": reason_code, + "fallback_type": None, + "fallback_content": None, + } + + if fallback_kb_id: + result["fallback_type"] = "kb" + result["fallback_kb_id"] = fallback_kb_id + elif fallback_message: + result["fallback_type"] = "fixed" + result["fallback_content"] = fallback_message + else: + result["fallback_type"] = "default" + result["fallback_content"] = "抱歉,我暂时无法回答您的问题,请稍后重试或联系人工客服。" + + return result + + +class TestIntentRouting: + """ + [AC-IDSMETA-18] Test intent routing with target KB selection. 
+ """ + + def setup_method(self): + self.router = MockIntentRouter() + + def test_keyword_match_routes_to_rag(self): + """Intent with response_type=rag should route to RAG with target KBs.""" + rules = [ + MockIntentRule( + id="rule-1", + name="退货咨询", + response_type="rag", + target_kb_ids=["kb-return", "kb-policy"], + keywords=["退货", "退款"], + priority=10, + ) + ] + + result = self.router.match("我想退货怎么办", rules) + + assert result is not None + assert result.rule.response_type == "rag" + assert result.rule.target_kb_ids == ["kb-return", "kb-policy"] + assert result.match_type == "keyword" + + def test_regex_match_routes_to_flow(self): + """Intent with response_type=flow should start script flow.""" + rules = [ + MockIntentRule( + id="rule-2", + name="订单查询", + response_type="flow", + flow_id="flow-order-query", + patterns=[r"订单.*查询", r"查询.*订单"], + priority=5, + ) + ] + + result = self.router.match("帮我查询订单状态", rules) + + assert result is not None + assert result.rule.response_type == "flow" + assert result.rule.flow_id == "flow-order-query" + assert result.match_type == "regex" + + def test_fixed_reply_intent(self): + """Intent with response_type=fixed should return fixed reply.""" + rules = [ + MockIntentRule( + id="rule-3", + name="问候", + response_type="fixed", + fixed_reply="您好,请问有什么可以帮您?", + keywords=["你好", "您好"], + priority=1, + ) + ] + + result = self.router.match("你好", rules) + + assert result is not None + assert result.rule.response_type == "fixed" + assert result.rule.fixed_reply == "您好,请问有什么可以帮您?" + + def test_transfer_intent(self): + """Intent with response_type=transfer should trigger transfer.""" + rules = [ + MockIntentRule( + id="rule-4", + name="人工服务", + response_type="transfer", + transfer_message="正在为您转接人工客服...", + keywords=["人工", "转人工"], + priority=100, + ) + ] + + result = self.router.match("我要转人工", rules) + + assert result is not None + assert result.rule.response_type == "transfer" + assert result.rule.transfer_message == "正在为您转接人工客服..." 
+ + def test_priority_ordering(self): + """Higher priority rules should be matched first.""" + rules = [ + MockIntentRule( + id="rule-low", + name="通用问候", + response_type="fixed", + fixed_reply="通用问候回复", + keywords=["你好"], + priority=1, + ), + MockIntentRule( + id="rule-high", + name="VIP问候", + response_type="fixed", + fixed_reply="VIP问候回复", + keywords=["你好"], + priority=10, + ), + ] + + result = self.router.match("你好", rules) + + assert result is not None + assert result.rule.id == "rule-high" + assert result.rule.fixed_reply == "VIP问候回复" + + def test_no_match_returns_none(self): + """No matching intent should return None.""" + rules = [ + MockIntentRule( + id="rule-1", + name="退货", + response_type="rag", + keywords=["退货"], + priority=10, + ) + ] + + result = self.router.match("今天天气怎么样", rules) + + assert result is None + + +class TestMetadataFilterInjection: + """ + [AC-IDSMETA-19] Test metadata filter injection in RAG retrieval. + """ + + @pytest.mark.asyncio + async def test_filter_injection_with_grade_subject_scene(self): + """RAG retrieval should inject grade/subject/scene metadata filters.""" + retriever = MockRetriever(hits=[ + MockRetrievalHit( + text="初一数学知识点", + score=0.9, + source="kb", + metadata={"grade": "初一", "subject": "数学", "scene": "课后辅导", "kb_id": "kb-1"}, + ), + MockRetrievalHit( + text="初二物理知识点", + score=0.85, + source="kb", + metadata={"grade": "初二", "subject": "物理", "scene": "课后辅导", "kb_id": "kb-1"}, + ), + ]) + + metadata_filter = { + "grade": "初一", + "subject": "数学", + "scene": "课后辅导", + } + + result = await retriever.retrieve( + tenant_id="tenant-1", + query="数学知识点", + metadata_filter=metadata_filter, + ) + + assert retriever._last_filter == metadata_filter + assert result.hit_count == 1 + assert result.hits[0].metadata["grade"] == "初一" + + @pytest.mark.asyncio + async def test_target_kb_ids_filtering(self): + """RAG retrieval should filter by target KB IDs from intent.""" + retriever = MockRetriever(hits=[ + MockRetrievalHit( + 
text="退货政策", + score=0.9, + source="kb", + metadata={"kb_id": "kb-return"}, + ), + MockRetrievalHit( + text="产品介绍", + score=0.85, + source="kb", + metadata={"kb_id": "kb-product"}, + ), + ]) + + result = await retriever.retrieve( + tenant_id="tenant-1", + query="退货", + target_kb_ids=["kb-return"], + ) + + assert retriever._last_target_kb_ids == ["kb-return"] + assert result.hit_count == 1 + assert result.hits[0].metadata["kb_id"] == "kb-return" + + @pytest.mark.asyncio + async def test_combined_filters(self): + """RAG retrieval should combine target KB and metadata filters.""" + retriever = MockRetriever(hits=[ + MockRetrievalHit( + text="初一数学教材", + score=0.9, + source="kb", + metadata={"grade": "初一", "subject": "数学", "kb_id": "kb-edu"}, + ), + MockRetrievalHit( + text="初二数学教材", + score=0.85, + source="kb", + metadata={"grade": "初二", "subject": "数学", "kb_id": "kb-edu"}, + ), + MockRetrievalHit( + text="初一数学练习", + score=0.8, + source="kb", + metadata={"grade": "初一", "subject": "数学", "kb_id": "kb-exercise"}, + ), + ]) + + result = await retriever.retrieve( + tenant_id="tenant-1", + query="数学", + target_kb_ids=["kb-edu"], + metadata_filter={"grade": "初一"}, + ) + + assert result.hit_count == 1 + assert result.hits[0].metadata["grade"] == "初一" + assert result.hits[0].metadata["kb_id"] == "kb-edu" + + +class TestFallbackStrategy: + """ + [AC-IDSMETA-20] Test fallback strategy with structured reason codes. 
+ """ + + def setup_method(self): + self.fallback = FallbackStrategy() + + def test_no_intent_match_fallback(self): + """No intent match should trigger fallback with reason code.""" + result = self.fallback.execute( + reason="NO_INTENT_MATCH", + fallback_message="抱歉,我不太理解您的问题,请换种方式描述。", + ) + + assert result["fallback_triggered"] is True + assert result["reason_code"] == "intent_not_matched" + assert result["fallback_type"] == "fixed" + assert "不太理解" in result["fallback_content"] + + def test_no_retrieval_hits_fallback(self): + """No retrieval hits should trigger fallback with reason code.""" + result = self.fallback.execute( + reason="NO_RETRIEVAL_HITS", + fallback_kb_id="kb-general", + ) + + assert result["fallback_triggered"] is True + assert result["reason_code"] == "retrieval_empty" + assert result["fallback_type"] == "kb" + assert result["fallback_kb_id"] == "kb-general" + + def test_low_confidence_fallback(self): + """Low confidence should trigger fallback with reason code.""" + result = self.fallback.execute( + reason="LOW_CONFIDENCE", + fallback_message="我对这个回答不太确定,建议您咨询人工客服。", + ) + + assert result["fallback_triggered"] is True + assert result["reason_code"] == "confidence_below_threshold" + assert result["fallback_type"] == "fixed" + + def test_metadata_filter_too_strict_fallback(self): + """Too strict metadata filter should trigger fallback.""" + result = self.fallback.execute( + reason="METADATA_FILTER_TOO_STRICT", + fallback_message="没有找到符合条件的答案,请尝试调整筛选条件。", + ) + + assert result["fallback_triggered"] is True + assert result["reason_code"] == "filter_excluded_all" + + def test_default_fallback(self): + """Default fallback should be used when no specific fallback provided.""" + result = self.fallback.execute(reason="NO_RETRIEVAL_HITS") + + assert result["fallback_triggered"] is True + assert result["fallback_type"] == "default" + assert "人工客服" in result["fallback_content"] + + +class TestRoutingFilterRetrievalFallbackChain: + """ + [AC-IDSMETA-18, 
AC-IDSMETA-19, AC-IDSMETA-20] Test complete chain. + """ + + @pytest.mark.asyncio + async def test_full_chain_with_intent_match_and_retrieval(self): + """Full chain: intent match -> metadata filter -> retrieval -> response.""" + router = MockIntentRouter() + retriever = MockRetriever(hits=[ + MockRetrievalHit( + text="退货需在7天内,商品未拆封", + score=0.9, + source="kb", + metadata={"kb_id": "kb-return"}, + ), + ]) + fallback = FallbackStrategy() + + rules = [ + MockIntentRule( + id="rule-1", + name="退货", + response_type="rag", + target_kb_ids=["kb-return"], + keywords=["退货"], + priority=10, + ) + ] + + user_message = "我想退货" + intent_result = router.match(user_message, rules) + + assert intent_result is not None + assert intent_result.rule.response_type == "rag" + + retrieval_result = await retriever.retrieve( + tenant_id="tenant-1", + query=user_message, + target_kb_ids=intent_result.rule.target_kb_ids, + ) + + assert retrieval_result.hit_count > 0 + assert not retrieval_result.is_empty + + @pytest.mark.asyncio + async def test_full_chain_no_intent_match_fallback(self): + """Full chain: no intent match -> fallback.""" + router = MockIntentRouter() + fallback = FallbackStrategy() + + rules = [ + MockIntentRule( + id="rule-1", + name="退货", + response_type="rag", + keywords=["退货"], + priority=10, + ) + ] + + user_message = "今天天气怎么样" + intent_result = router.match(user_message, rules) + + assert intent_result is None + + fallback_result = fallback.execute( + reason="NO_INTENT_MATCH", + fallback_message="抱歉,我无法回答这个问题。", + ) + + assert fallback_result["fallback_triggered"] is True + assert fallback_result["reason_code"] == "intent_not_matched" + + @pytest.mark.asyncio + async def test_full_chain_no_retrieval_hits_fallback(self): + """Full chain: intent match -> no retrieval hits -> fallback.""" + router = MockIntentRouter() + retriever = MockRetriever(hits=[]) + fallback = FallbackStrategy() + + rules = [ + MockIntentRule( + id="rule-1", + name="退货", + response_type="rag", + 
target_kb_ids=["kb-return"], + keywords=["退货"], + priority=10, + ) + ] + + user_message = "退货流程是什么" + intent_result = router.match(user_message, rules) + + assert intent_result is not None + + retrieval_result = await retriever.retrieve( + tenant_id="tenant-1", + query=user_message, + target_kb_ids=intent_result.rule.target_kb_ids, + ) + + assert retrieval_result.is_empty + + fallback_result = fallback.execute( + reason="NO_RETRIEVAL_HITS", + fallback_kb_id="kb-general", + ) + + assert fallback_result["fallback_triggered"] is True + assert fallback_result["reason_code"] == "retrieval_empty" + + @pytest.mark.asyncio + async def test_full_chain_with_metadata_filter(self): + """Full chain with metadata filter injection.""" + router = MockIntentRouter() + retriever = MockRetriever(hits=[ + MockRetrievalHit( + text="初一数学课程大纲", + score=0.9, + source="kb", + metadata={"grade": "初一", "subject": "数学", "scene": "咨询", "kb_id": "kb-edu"}, + ), + MockRetrievalHit( + text="初二数学课程大纲", + score=0.85, + source="kb", + metadata={"grade": "初二", "subject": "数学", "scene": "咨询", "kb_id": "kb-edu"}, + ), + ]) + fallback = FallbackStrategy() + + rules = [ + MockIntentRule( + id="rule-1", + name="课程咨询", + response_type="rag", + target_kb_ids=["kb-edu"], + keywords=["课程", "大纲"], + priority=10, + ) + ] + + user_message = "初一数学课程大纲" + session_metadata = {"grade": "初一", "subject": "数学", "scene": "咨询"} + + intent_result = router.match(user_message, rules) + + retrieval_result = await retriever.retrieve( + tenant_id="tenant-1", + query=user_message, + target_kb_ids=intent_result.rule.target_kb_ids if intent_result else None, + metadata_filter=session_metadata, + ) + + assert retrieval_result.hit_count == 1 + assert retrieval_result.hits[0].metadata["grade"] == "初一" + + +class TestReasonCodeStructure: + """ + [AC-IDSMETA-20] Test structured reason codes for fallback. 
+ """ + + def test_reason_code_format(self): + """Reason codes should follow snake_case format.""" + fallback = FallbackStrategy() + + for reason_key, expected_code in FallbackStrategy.REASON_CODES.items(): + result = fallback.execute(reason=reason_key) + assert result["reason_code"] == expected_code + assert "_" in expected_code or expected_code.islower() + + def test_reason_code_in_diagnostics(self): + """Reason code should be included in diagnostics.""" + fallback = FallbackStrategy() + + result = fallback.execute(reason="NO_RETRIEVAL_HITS") + + assert "reason_code" in result + assert result["reason_code"] == "retrieval_empty" + + def test_unknown_reason_code(self): + """Unknown reason should return 'unknown' code.""" + fallback = FallbackStrategy() + + result = fallback.execute(reason="UNKNOWN_REASON") + + assert result["reason_code"] == "unknown"