From 3b354ba041602d58db3cb97748d808951fdf962a Mon Sep 17 00:00:00 2001 From: MerCry Date: Tue, 10 Mar 2026 12:11:31 +0800 Subject: [PATCH] feat: add metadata discovery tool for dynamic metadata extraction [AC-METADATA-DISCOVERY] --- .../services/mid/metadata_discovery_tool.py | 281 ++++++++++++++++++ 1 file changed, 281 insertions(+) create mode 100644 ai-service/app/services/mid/metadata_discovery_tool.py diff --git a/ai-service/app/services/mid/metadata_discovery_tool.py b/ai-service/app/services/mid/metadata_discovery_tool.py new file mode 100644 index 0000000..51c525f --- /dev/null +++ b/ai-service/app/services/mid/metadata_discovery_tool.py @@ -0,0 +1,281 @@ +""" +Metadata Discovery Tool for Mid Platform. +[AC-MARH-XX] 元数据发现工具,用于查询当前可用的元数据字段及其常见值。 + +核心特性: +- 列出当前知识库文档中使用的元数据字段 +- 返回每个字段的常见取值(从现有文档中聚合) +- 支持按知识库过滤 +- 返回字段定义信息(类型、用途说明等) +""" + +from __future__ import annotations + +import asyncio +import logging +from collections import Counter +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.entities import Document, MetadataFieldDefinition, MetadataFieldStatus +from app.services.mid.timeout_governor import TimeoutGovernor + +if TYPE_CHECKING: + from app.services.mid.tool_registry import ToolRegistry + +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT_MS = 2000 +DEFAULT_TOP_VALUES = 10 + + +@dataclass +class MetadataFieldDiscoveryConfig: + """Configuration for metadata field discovery tool.""" + timeout_ms: int = DEFAULT_TIMEOUT_MS + top_values_count: int = DEFAULT_TOP_VALUES + + +@dataclass +class MetadataFieldInfo: + """Information about a metadata field.""" + field_key: str + field_type: str = "string" + label: str = "" + description: str | None = None + common_values: list[str] = field(default_factory=list) + value_count: int = 0 + is_filterable: bool = True + options: list[str] | None = None + + +@dataclass +class MetadataDiscoveryResult: + """Result of metadata discovery.""" + success: bool + fields: list[MetadataFieldInfo] = field(default_factory=list) + total_documents: int = 0 + error: str | None = None + duration_ms: int = 0 + + +class MetadataDiscoveryTool: + """ + [AC-MARH-XX] 元数据发现工具。 + + 用于查询当前知识库文档中使用的元数据字段及其常见值, + 帮助 AI 了解可用的过滤字段,从而更好地构造搜索请求。 + """ + + def __init__( + self, + session: AsyncSession, + timeout_governor: TimeoutGovernor | None = None, + config: MetadataFieldDiscoveryConfig | None = None, + ): + self._session = session + self._timeout_governor = timeout_governor or TimeoutGovernor() + self._config = config or MetadataFieldDiscoveryConfig() + + async def execute( + self, + tenant_id: str, + kb_id: str | None = None, + include_values: bool = True, + top_n: int | None = None, + ) -> MetadataDiscoveryResult: + """ + Execute metadata discovery. + + Args: + tenant_id: Tenant ID + kb_id: Optional knowledge base ID to filter + include_values: Whether to include common values (default True) + top_n: Number of top values to return per field (default from config) + + Returns: + MetadataDiscoveryResult with field information + """ + start_time = asyncio.get_event_loop().time() + + try: + top_n = top_n or self._config.top_values_count + + field_definitions = await self._get_field_definitions(tenant_id) + + document_metadata = await self._get_document_metadata(tenant_id, kb_id) + + total_docs = len(document_metadata) + + field_values: dict[str, Counter] = {} + for doc_meta in document_metadata: + if not doc_meta: + continue + for key, value in doc_meta.items(): + if key not in field_values: + field_values[key] = Counter() + str_value = str(value) if value is not None else "" + field_values[key].update([str_value]) + + fields: list[MetadataFieldInfo] = [] + for field_key, values_counter in field_values.items(): + field_def = field_definitions.get(field_key) + + common_values = [] + if include_values: + most_common = values_counter.most_common(top_n) + common_values = [v for v, _ in most_common if v] + + field_info = MetadataFieldInfo( + field_key=field_key, + field_type=field_def.type if field_def else "string", + label=field_def.label if field_def else field_key, + description=field_def.usage_description if field_def else None, + common_values=common_values, + value_count=len(values_counter), + is_filterable=field_def.is_filterable if field_def else True, + options=field_def.options if field_def else None, + ) + fields.append(field_info) + + fields.sort(key=lambda f: f.value_count, reverse=True) + + duration_ms = int((asyncio.get_event_loop().time() - start_time) * 1000) + + logger.info( + f"[MetadataDiscovery] Discovered {len(fields)} fields from {total_docs} documents, " + f"duration={duration_ms}ms" + ) + + return MetadataDiscoveryResult( + success=True, + fields=fields, + total_documents=total_docs, + duration_ms=duration_ms, + ) + + except Exception as e: + logger.error(f"[MetadataDiscovery] Discovery failed: {e}") + return MetadataDiscoveryResult( + success=False, + error=str(e), + ) + + async def _get_field_definitions( + self, + tenant_id: str, + ) -> dict[str, MetadataFieldDefinition]: + """Get field definitions for tenant.""" + stmt = select(MetadataFieldDefinition).where( + MetadataFieldDefinition.tenant_id == tenant_id, + MetadataFieldDefinition.status == MetadataFieldStatus.ACTIVE.value, + ) + result = await self._session.execute(stmt) + definitions = result.scalars().all() + + return {d.field_key: d for d in definitions} + + async def _get_document_metadata( + self, + tenant_id: str, + kb_id: str | None = None, + ) -> list[dict[str, Any]]: + """Get all document metadata for tenant.""" + stmt = select(Document.doc_metadata).where( + Document.tenant_id == tenant_id, + ) + if kb_id: + stmt = stmt.where(Document.kb_id == kb_id) + + result = await self._session.execute(stmt) + rows = result.scalars().all() + + return [row for row in rows if row] + + +def register_metadata_discovery_tool( + registry: "ToolRegistry", + session: AsyncSession, + timeout_governor: TimeoutGovernor | None = None, + config: MetadataFieldDiscoveryConfig | None = None, +) -> None: + """Register metadata discovery tool to registry.""" + from app.services.mid.tool_registry import ToolType + + cfg = config or MetadataFieldDiscoveryConfig() + + async def metadata_discovery_handler( + tenant_id: str = "", + kb_id: str | None = None, + include_values: bool = True, + top_n: int | None = None, + **kwargs, # 接受系统注入的额外参数(user_id, session_id 等) + ) -> dict[str, Any]: + """Metadata discovery tool handler.""" + tool = MetadataDiscoveryTool( + session=session, + timeout_governor=timeout_governor, + config=cfg, + ) + + result = await tool.execute( + tenant_id=tenant_id, + kb_id=kb_id, + include_values=include_values, + top_n=top_n, + ) + # 将 dataclass 转换为 dict + return { + "success": result.success, + "fields": [ + { + "field_key": f.field_key, + "field_type": f.field_type, + "label": f.label, + "description": f.description, + "common_values": f.common_values, + "value_count": f.value_count, + "is_filterable": f.is_filterable, + "options": f.options, + } + for f in result.fields + ], + "total_documents": result.total_documents, + "error": result.error, + "duration_ms": result.duration_ms, + } + + registry.register( + name="list_document_metadata_fields", + description="列出当前知识库文档中使用的元数据字段及其常见取值,用于后续的知识库搜索过滤", + handler=metadata_discovery_handler, + tool_type=ToolType.INTERNAL, + version="1.0.0", + auth_required=False, + timeout_ms=cfg.timeout_ms, + enabled=True, + metadata={ + "when_to_use": "当需要了解知识库中有哪些可用的元数据过滤字段时使用。", + "when_not_to_use": "当已知可用的过滤字段,或不需要元数据过滤时不需要调用。", + "parameters": { + "type": "object", + "properties": { + "tenant_id": {"type": "string", "description": "租户 ID"}, + "kb_id": {"type": "string", "description": "知识库 ID(可选,用于限定范围)"}, + "include_values": {"type": "boolean", "description": "是否包含常见值列表,默认 true"}, + "top_n": {"type": "integer", "description": "每个字段返回的常见值数量,默认 10"}, + }, + "required": [], + }, + "example_action_input": { + "include_values": True, + "top_n": 5, + }, + "result_interpretation": "fields 数组包含每个字段的详细信息;common_values 是该字段在文档中的常见取值;value_count 表示该字段在多少文档中出现。", + }, + ) + + logger.info("[MetadataDiscovery] Tool registered: list_document_metadata_fields")