"""
Role Based Field Provider Service.

[AC-MRS-04, AC-MRS-05, AC-MRS-10] Service that looks up metadata field
definitions by the role(s) they are tagged with.
"""

import logging
from typing import Any

from sqlalchemy import select, cast
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.entities import (
    FieldRole,
    MetadataFieldDefinition,
    MetadataFieldStatus,
    SlotDefinition,
)
from app.schemas.metadata import VALID_FIELD_ROLES

logger = logging.getLogger(__name__)


class InvalidRoleError(Exception):
    """[AC-MRS-05] Raised when a role string is not in VALID_FIELD_ROLES."""

    def __init__(self, role: str):
        # Keep the offending value and the accepted values on the exception
        # so callers can build their own error responses.
        self.role = role
        self.valid_roles = VALID_FIELD_ROLES
        super().__init__(
            f"Invalid role '{role}'. Valid roles are: {', '.join(self.valid_roles)}"
        )


class RoleBasedFieldProvider:
    """
    [AC-MRS-04, AC-MRS-05, AC-MRS-10] Role-based field provider.

    Queries field definitions by role so that the tool chain can consume
    only the fields it needs.
    """

    # NOTE(review): neither constant is referenced inside this class; they
    # look like settings for a cache layer (TTL presumably in seconds) -
    # confirm against callers before relying on them.
    CACHE_KEY_PREFIX = "field_roles"
    CACHE_TTL = 300

    def __init__(self, session: AsyncSession):
        self._session = session

    def _validate_role(self, role: str) -> str:
        """
        [AC-MRS-05] Check that a role value is valid.

        Args:
            role: Role string to check.

        Returns:
            The same role string when it is valid.

        Raises:
            InvalidRoleError: If the role is not in VALID_FIELD_ROLES.
        """
        if role not in VALID_FIELD_ROLES:
            raise InvalidRoleError(role)
        return role

    @staticmethod
    def _serialize_field(field: MetadataFieldDefinition) -> dict[str, Any]:
        """Build the ``linked_field`` payload dict for a field definition."""
        return {
            "id": str(field.id),
            "field_key": field.field_key,
            "label": field.label,
            "type": field.type,
            "required": field.required,
            "options": field.options,
            "default_value": field.default_value,
            "scope": field.scope,
            "is_filterable": field.is_filterable,
            "is_rank_feature": field.is_rank_feature,
            "field_roles": field.field_roles,
            "status": field.status,
        }

    async def get_fields_by_role(
        self,
        tenant_id: str,
        role: str,
        include_deprecated: bool = False,
    ) -> list[MetadataFieldDefinition]:
        """
        [AC-MRS-04] Get the field definitions that carry a given role.

        Args:
            tenant_id: Tenant ID.
            role: Field role (resource_filter/slot/prompt_var/routing_signal).
            include_deprecated: When true, deprecated fields are returned in
                addition to active ones.

        Returns:
            List of matching field definitions.

        Raises:
            InvalidRoleError: If the role is invalid.
        """
        self._validate_role(role)

        logger.info(
            "[AC-MRS-04] Getting fields by role: tenant=%s, "
            "role=%s, include_deprecated=%s",
            tenant_id,
            role,
            include_deprecated,
        )

        # field_roles is cast to JSONB so the PostgreSQL `?` operator can be
        # used to test whether the role string is present in it.
        stmt = select(MetadataFieldDefinition).where(
            MetadataFieldDefinition.tenant_id == tenant_id,
            cast(MetadataFieldDefinition.field_roles, JSONB).op('?')(role),
        )

        if not include_deprecated:
            stmt = stmt.where(
                MetadataFieldDefinition.status == MetadataFieldStatus.ACTIVE.value
            )
        else:
            # Even with deprecated fields included, only ACTIVE and
            # DEPRECATED rows are returned; any other status is excluded.
            stmt = stmt.where(
                MetadataFieldDefinition.status.in_([
                    MetadataFieldStatus.ACTIVE.value,
                    MetadataFieldStatus.DEPRECATED.value,
                ])
            )

        result = await self._session.execute(stmt)
        fields = list(result.scalars().all())

        logger.info(
            "[AC-MRS-04] Found %s fields for role=%s", len(fields), role
        )

        return fields

    async def get_field_keys_by_role(
        self,
        tenant_id: str,
        role: str,
        include_deprecated: bool = False,
    ) -> list[str]:
        """
        Get the field keys of the definitions that carry a given role.

        Args:
            tenant_id: Tenant ID.
            role: Field role.
            include_deprecated: Whether to include deprecated fields.

        Returns:
            List of field keys.

        Raises:
            InvalidRoleError: If the role is invalid.
        """
        fields = await self.get_fields_by_role(tenant_id, role, include_deprecated)
        return [f.field_key for f in fields]

    async def get_slot_definitions_by_role(
        self,
        tenant_id: str,
        role: str = "slot",
    ) -> list[dict[str, Any]]:
        """
        [AC-MRS-10] Get slot definitions together with linked field info.

        Every slot definition of the tenant is returned. ``linked_field`` is
        filled in only when the slot links to an active field carrying
        ``role``. When ``role`` is ``"slot"``, each such field that no slot
        links to is appended as a synthesized entry with ``id`` set to
        ``None``.

        Args:
            tenant_id: Tenant ID.
            role: Field role, defaults to ``slot``.

        Returns:
            List of slot definition dicts, including linked field info.

        Raises:
            InvalidRoleError: If the role is invalid.
        """
        self._validate_role(role)

        logger.info(
            "[AC-MRS-10] Getting slot definitions by role: tenant=%s, role=%s",
            tenant_id,
            role,
        )

        fields = await self.get_fields_by_role(tenant_id, role)
        field_map = {str(f.id): f for f in fields}

        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
        )

        result = await self._session.execute(stmt)
        all_slots = list(result.scalars().all())

        slot_with_fields: list[dict[str, Any]] = []
        for slot in all_slots:
            linked_id = str(slot.linked_field_id) if slot.linked_field_id else None
            linked_field = field_map.get(linked_id) if linked_id else None
            slot_with_fields.append({
                "id": str(slot.id),
                "tenant_id": slot.tenant_id,
                "slot_key": slot.slot_key,
                "type": slot.type,
                "required": slot.required,
                "extract_strategy": slot.extract_strategy,
                "validation_rule": slot.validation_rule,
                "ask_back_prompt": slot.ask_back_prompt,
                "default_value": slot.default_value,
                "linked_field_id": linked_id,
                "created_at": slot.created_at.isoformat() if slot.created_at else None,
                "updated_at": slot.updated_at.isoformat() if slot.updated_at else None,
                "linked_field": (
                    self._serialize_field(linked_field)
                    if linked_field is not None
                    else None
                ),
            })

        # Only the "slot" role synthesizes entries for fields without a slot.
        if role == "slot":
            # Track covered field ids in a set instead of rescanning the
            # result list for every field.
            covered_ids = {entry["linked_field_id"] for entry in slot_with_fields}
            for field in fields:
                field_id_str = str(field.id)
                if field_id_str in covered_ids:
                    continue
                covered_ids.add(field_id_str)
                slot_with_fields.append({
                    "id": None,
                    "tenant_id": field.tenant_id,
                    "slot_key": field.field_key,
                    "type": field.type,
                    "required": field.required,
                    "extract_strategy": None,
                    "validation_rule": None,
                    "ask_back_prompt": None,
                    "default_value": field.default_value,
                    "linked_field_id": field_id_str,
                    "created_at": None,
                    "updated_at": None,
                    "linked_field": self._serialize_field(field),
                })

        logger.info(
            "[AC-MRS-10] Found %s slot definitions for role=%s",
            len(slot_with_fields),
            role,
        )

        return slot_with_fields

    async def get_resource_filter_fields(
        self,
        tenant_id: str,
    ) -> list[MetadataFieldDefinition]:
        """
        [AC-MRS-11] Get the fields with the resource-filter role.

        Intended for the kb_search_dynamic tool.
        """
        return await self.get_fields_by_role(
            tenant_id,
            FieldRole.RESOURCE_FILTER.value
        )

    async def get_slot_fields(
        self,
        tenant_id: str,
    ) -> list[MetadataFieldDefinition]:
        """
        [AC-MRS-12] Get the fields with the slot role.

        Intended for the memory_recall tool.
        """
        return await self.get_fields_by_role(
            tenant_id,
            FieldRole.SLOT.value
        )

    async def get_routing_signal_fields(
        self,
        tenant_id: str,
    ) -> list[MetadataFieldDefinition]:
        """
        [AC-MRS-13] Get the fields with the routing-signal role.

        Intended for the intent_hint/high_risk_check tools.
        """
        return await self.get_fields_by_role(
            tenant_id,
            FieldRole.ROUTING_SIGNAL.value
        )

    async def get_prompt_var_fields(
        self,
        tenant_id: str,
    ) -> list[MetadataFieldDefinition]:
        """
        [AC-MRS-14] Get the fields with the prompt-variable role.

        Intended for the template_engine.
        """
        return await self.get_fields_by_role(
            tenant_id,
            FieldRole.PROMPT_VAR.value
        )