feat: implement RoleBasedFieldProvider service for role-based field queries [AC-MRS-04,05,10]
This commit is contained in:
parent
68e5adaa28
commit
0db2971c73
|
|
@ -0,0 +1,299 @@
|
|||
"""
|
||||
Role Based Field Provider Service.
|
||||
[AC-MRS-04, AC-MRS-05, AC-MRS-10] 基于角色的字段提供者服务
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, cast
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.entities import (
|
||||
FieldRole,
|
||||
MetadataFieldDefinition,
|
||||
MetadataFieldStatus,
|
||||
SlotDefinition,
|
||||
)
|
||||
from app.schemas.metadata import VALID_FIELD_ROLES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InvalidRoleError(Exception):
|
||||
"""[AC-MRS-05] 无效角色异常"""
|
||||
|
||||
def __init__(self, role: str):
|
||||
self.role = role
|
||||
self.valid_roles = VALID_FIELD_ROLES
|
||||
super().__init__(
|
||||
f"Invalid role '{role}'. Valid roles are: {', '.join(self.valid_roles)}"
|
||||
)
|
||||
|
||||
|
||||
class RoleBasedFieldProvider:
|
||||
"""
|
||||
[AC-MRS-04, AC-MRS-05, AC-MRS-10] 基于角色的字段提供者
|
||||
|
||||
提供按角色查询字段定义的能力,供工具链按需消费
|
||||
"""
|
||||
|
||||
CACHE_KEY_PREFIX = "field_roles"
|
||||
CACHE_TTL = 300
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self._session = session
|
||||
|
||||
def _validate_role(self, role: str) -> str:
|
||||
"""
|
||||
[AC-MRS-05] 验证角色值是否有效
|
||||
|
||||
Args:
|
||||
role: 角色字符串
|
||||
|
||||
Returns:
|
||||
验证通过的角色字符串
|
||||
|
||||
Raises:
|
||||
InvalidRoleError: 如果角色无效
|
||||
"""
|
||||
if role not in VALID_FIELD_ROLES:
|
||||
raise InvalidRoleError(role)
|
||||
return role
|
||||
|
||||
async def get_fields_by_role(
|
||||
self,
|
||||
tenant_id: str,
|
||||
role: str,
|
||||
include_deprecated: bool = False,
|
||||
) -> list[MetadataFieldDefinition]:
|
||||
"""
|
||||
[AC-MRS-04] 按角色获取字段定义
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
role: 字段角色 (resource_filter/slot/prompt_var/routing_signal)
|
||||
include_deprecated: 是否包含已废弃字段
|
||||
|
||||
Returns:
|
||||
字段定义列表
|
||||
|
||||
Raises:
|
||||
InvalidRoleError: 如果角色无效
|
||||
"""
|
||||
self._validate_role(role)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-04] Getting fields by role: tenant={tenant_id}, "
|
||||
f"role={role}, include_deprecated={include_deprecated}"
|
||||
)
|
||||
|
||||
stmt = select(MetadataFieldDefinition).where(
|
||||
MetadataFieldDefinition.tenant_id == tenant_id,
|
||||
cast(MetadataFieldDefinition.field_roles, JSONB).op('?')(role),
|
||||
)
|
||||
|
||||
if not include_deprecated:
|
||||
stmt = stmt.where(
|
||||
MetadataFieldDefinition.status == MetadataFieldStatus.ACTIVE.value
|
||||
)
|
||||
else:
|
||||
stmt = stmt.where(
|
||||
MetadataFieldDefinition.status.in_([
|
||||
MetadataFieldStatus.ACTIVE.value,
|
||||
MetadataFieldStatus.DEPRECATED.value,
|
||||
])
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
fields = list(result.scalars().all())
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-04] Found {len(fields)} fields for role={role}"
|
||||
)
|
||||
|
||||
return fields
|
||||
|
||||
async def get_field_keys_by_role(
|
||||
self,
|
||||
tenant_id: str,
|
||||
role: str,
|
||||
include_deprecated: bool = False,
|
||||
) -> list[str]:
|
||||
"""
|
||||
按角色获取字段键名列表
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
role: 字段角色
|
||||
include_deprecated: 是否包含已废弃字段
|
||||
|
||||
Returns:
|
||||
字段键名列表
|
||||
"""
|
||||
fields = await self.get_fields_by_role(tenant_id, role, include_deprecated)
|
||||
return [f.field_key for f in fields]
|
||||
|
||||
async def get_slot_definitions_by_role(
|
||||
self,
|
||||
tenant_id: str,
|
||||
role: str = "slot",
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
[AC-MRS-10] 按角色获取槽位定义及关联字段信息
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
role: 字段角色,默认为 slot
|
||||
|
||||
Returns:
|
||||
槽位定义列表,包含关联字段信息
|
||||
"""
|
||||
self._validate_role(role)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-10] Getting slot definitions by role: tenant={tenant_id}, role={role}"
|
||||
)
|
||||
|
||||
fields = await self.get_fields_by_role(tenant_id, role)
|
||||
field_ids = [f.id for f in fields]
|
||||
field_map = {str(f.id): f for f in fields}
|
||||
|
||||
stmt = select(SlotDefinition).where(
|
||||
SlotDefinition.tenant_id == tenant_id,
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
all_slots = list(result.scalars().all())
|
||||
|
||||
slot_with_fields = []
|
||||
for slot in all_slots:
|
||||
slot_data = {
|
||||
"id": str(slot.id),
|
||||
"tenant_id": slot.tenant_id,
|
||||
"slot_key": slot.slot_key,
|
||||
"type": slot.type,
|
||||
"required": slot.required,
|
||||
"extract_strategy": slot.extract_strategy,
|
||||
"validation_rule": slot.validation_rule,
|
||||
"ask_back_prompt": slot.ask_back_prompt,
|
||||
"default_value": slot.default_value,
|
||||
"linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
|
||||
"created_at": slot.created_at.isoformat() if slot.created_at else None,
|
||||
"updated_at": slot.updated_at.isoformat() if slot.updated_at else None,
|
||||
"linked_field": None,
|
||||
}
|
||||
|
||||
if slot.linked_field_id and str(slot.linked_field_id) in field_map:
|
||||
linked_field = field_map[str(slot.linked_field_id)]
|
||||
slot_data["linked_field"] = {
|
||||
"id": str(linked_field.id),
|
||||
"field_key": linked_field.field_key,
|
||||
"label": linked_field.label,
|
||||
"type": linked_field.type,
|
||||
"required": linked_field.required,
|
||||
"options": linked_field.options,
|
||||
"default_value": linked_field.default_value,
|
||||
"scope": linked_field.scope,
|
||||
"is_filterable": linked_field.is_filterable,
|
||||
"is_rank_feature": linked_field.is_rank_feature,
|
||||
"field_roles": linked_field.field_roles,
|
||||
"status": linked_field.status,
|
||||
}
|
||||
|
||||
slot_with_fields.append(slot_data)
|
||||
|
||||
for field in fields:
|
||||
field_id_str = str(field.id)
|
||||
has_slot = any(
|
||||
s.get("linked_field_id") == field_id_str
|
||||
for s in slot_with_fields
|
||||
)
|
||||
if not has_slot and role == "slot":
|
||||
slot_with_fields.append({
|
||||
"id": None,
|
||||
"tenant_id": field.tenant_id,
|
||||
"slot_key": field.field_key,
|
||||
"type": field.type,
|
||||
"required": field.required,
|
||||
"extract_strategy": None,
|
||||
"validation_rule": None,
|
||||
"ask_back_prompt": None,
|
||||
"default_value": field.default_value,
|
||||
"linked_field_id": str(field.id),
|
||||
"created_at": None,
|
||||
"updated_at": None,
|
||||
"linked_field": {
|
||||
"id": str(field.id),
|
||||
"field_key": field.field_key,
|
||||
"label": field.label,
|
||||
"type": field.type,
|
||||
"required": field.required,
|
||||
"options": field.options,
|
||||
"default_value": field.default_value,
|
||||
"scope": field.scope,
|
||||
"is_filterable": field.is_filterable,
|
||||
"is_rank_feature": field.is_rank_feature,
|
||||
"field_roles": field.field_roles,
|
||||
"status": field.status,
|
||||
},
|
||||
})
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-10] Found {len(slot_with_fields)} slot definitions for role={role}"
|
||||
)
|
||||
|
||||
return slot_with_fields
|
||||
|
||||
async def get_resource_filter_fields(
|
||||
self,
|
||||
tenant_id: str,
|
||||
) -> list[MetadataFieldDefinition]:
|
||||
"""
|
||||
[AC-MRS-11] 获取资源过滤角色字段
|
||||
供 kb_search_dynamic 工具使用
|
||||
"""
|
||||
return await self.get_fields_by_role(
|
||||
tenant_id,
|
||||
FieldRole.RESOURCE_FILTER.value
|
||||
)
|
||||
|
||||
async def get_slot_fields(
|
||||
self,
|
||||
tenant_id: str,
|
||||
) -> list[MetadataFieldDefinition]:
|
||||
"""
|
||||
[AC-MRS-12] 获取槽位角色字段
|
||||
供 memory_recall 工具使用
|
||||
"""
|
||||
return await self.get_fields_by_role(
|
||||
tenant_id,
|
||||
FieldRole.SLOT.value
|
||||
)
|
||||
|
||||
async def get_routing_signal_fields(
|
||||
self,
|
||||
tenant_id: str,
|
||||
) -> list[MetadataFieldDefinition]:
|
||||
"""
|
||||
[AC-MRS-13] 获取路由信号角色字段
|
||||
供 intent_hint/high_risk_check 工具使用
|
||||
"""
|
||||
return await self.get_fields_by_role(
|
||||
tenant_id,
|
||||
FieldRole.ROUTING_SIGNAL.value
|
||||
)
|
||||
|
||||
async def get_prompt_var_fields(
|
||||
self,
|
||||
tenant_id: str,
|
||||
) -> list[MetadataFieldDefinition]:
|
||||
"""
|
||||
[AC-MRS-14] 获取提示词变量角色字段
|
||||
供 template_engine 使用
|
||||
"""
|
||||
return await self.get_fields_by_role(
|
||||
tenant_id,
|
||||
FieldRole.PROMPT_VAR.value
|
||||
)
|
||||
Loading…
Reference in New Issue