feat: add metadata discovery tool for dynamic metadata extraction [AC-METADATA-DISCOVERY]
This commit is contained in:
parent
812af6c7a1
commit
3b354ba041
|
|
@ -0,0 +1,281 @@
|
|||
"""
|
||||
Metadata Discovery Tool for Mid Platform.
|
||||
[AC-MARH-XX] 元数据发现工具,用于查询当前可用的元数据字段及其常见值。
|
||||
|
||||
核心特性:
|
||||
- 列出当前知识库文档中使用的元数据字段
|
||||
- 返回每个字段的常见取值(从现有文档中聚合)
|
||||
- 支持按知识库过滤
|
||||
- 返回字段定义信息(类型、用途说明等)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.entities import Document, MetadataFieldDefinition, MetadataFieldStatus
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mid.tool_registry import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TIMEOUT_MS = 2000
|
||||
DEFAULT_TOP_VALUES = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataFieldDiscoveryConfig:
|
||||
"""Configuration for metadata field discovery tool."""
|
||||
timeout_ms: int = DEFAULT_TIMEOUT_MS
|
||||
top_values_count: int = DEFAULT_TOP_VALUES
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataFieldInfo:
|
||||
"""Information about a metadata field."""
|
||||
field_key: str
|
||||
field_type: str = "string"
|
||||
label: str = ""
|
||||
description: str | None = None
|
||||
common_values: list[str] = field(default_factory=list)
|
||||
value_count: int = 0
|
||||
is_filterable: bool = True
|
||||
options: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataDiscoveryResult:
|
||||
"""Result of metadata discovery."""
|
||||
success: bool
|
||||
fields: list[MetadataFieldInfo] = field(default_factory=list)
|
||||
total_documents: int = 0
|
||||
error: str | None = None
|
||||
duration_ms: int = 0
|
||||
|
||||
|
||||
class MetadataDiscoveryTool:
|
||||
"""
|
||||
[AC-MARH-XX] 元数据发现工具。
|
||||
|
||||
用于查询当前知识库文档中使用的元数据字段及其常见值,
|
||||
帮助 AI 了解可用的过滤字段,从而更好地构造搜索请求。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
config: MetadataFieldDiscoveryConfig | None = None,
|
||||
):
|
||||
self._session = session
|
||||
self._timeout_governor = timeout_governor or TimeoutGovernor()
|
||||
self._config = config or MetadataFieldDiscoveryConfig()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
tenant_id: str,
|
||||
kb_id: str | None = None,
|
||||
include_values: bool = True,
|
||||
top_n: int | None = None,
|
||||
) -> MetadataDiscoveryResult:
|
||||
"""
|
||||
Execute metadata discovery.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
kb_id: Optional knowledge base ID to filter
|
||||
include_values: Whether to include common values (default True)
|
||||
top_n: Number of top values to return per field (default from config)
|
||||
|
||||
Returns:
|
||||
MetadataDiscoveryResult with field information
|
||||
"""
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
try:
|
||||
top_n = top_n or self._config.top_values_count
|
||||
|
||||
field_definitions = await self._get_field_definitions(tenant_id)
|
||||
|
||||
document_metadata = await self._get_document_metadata(tenant_id, kb_id)
|
||||
|
||||
total_docs = len(document_metadata)
|
||||
|
||||
field_values: dict[str, Counter] = {}
|
||||
for doc_meta in document_metadata:
|
||||
if not doc_meta:
|
||||
continue
|
||||
for key, value in doc_meta.items():
|
||||
if key not in field_values:
|
||||
field_values[key] = Counter()
|
||||
str_value = str(value) if value is not None else ""
|
||||
field_values[key].update([str_value])
|
||||
|
||||
fields: list[MetadataFieldInfo] = []
|
||||
for field_key, values_counter in field_values.items():
|
||||
field_def = field_definitions.get(field_key)
|
||||
|
||||
common_values = []
|
||||
if include_values:
|
||||
most_common = values_counter.most_common(top_n)
|
||||
common_values = [v for v, _ in most_common if v]
|
||||
|
||||
field_info = MetadataFieldInfo(
|
||||
field_key=field_key,
|
||||
field_type=field_def.type if field_def else "string",
|
||||
label=field_def.label if field_def else field_key,
|
||||
description=field_def.usage_description if field_def else None,
|
||||
common_values=common_values,
|
||||
value_count=len(values_counter),
|
||||
is_filterable=field_def.is_filterable if field_def else True,
|
||||
options=field_def.options if field_def else None,
|
||||
)
|
||||
fields.append(field_info)
|
||||
|
||||
fields.sort(key=lambda f: f.value_count, reverse=True)
|
||||
|
||||
duration_ms = int((asyncio.get_event_loop().time() - start_time) * 1000)
|
||||
|
||||
logger.info(
|
||||
f"[MetadataDiscovery] Discovered {len(fields)} fields from {total_docs} documents, "
|
||||
f"duration={duration_ms}ms"
|
||||
)
|
||||
|
||||
return MetadataDiscoveryResult(
|
||||
success=True,
|
||||
fields=fields,
|
||||
total_documents=total_docs,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MetadataDiscovery] Discovery failed: {e}")
|
||||
return MetadataDiscoveryResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def _get_field_definitions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
) -> dict[str, MetadataFieldDefinition]:
|
||||
"""Get field definitions for tenant."""
|
||||
stmt = select(MetadataFieldDefinition).where(
|
||||
MetadataFieldDefinition.tenant_id == tenant_id,
|
||||
MetadataFieldDefinition.status == MetadataFieldStatus.ACTIVE.value,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
definitions = result.scalars().all()
|
||||
|
||||
return {d.field_key: d for d in definitions}
|
||||
|
||||
async def _get_document_metadata(
|
||||
self,
|
||||
tenant_id: str,
|
||||
kb_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get all document metadata for tenant."""
|
||||
stmt = select(Document.doc_metadata).where(
|
||||
Document.tenant_id == tenant_id,
|
||||
)
|
||||
if kb_id:
|
||||
stmt = stmt.where(Document.kb_id == kb_id)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
|
||||
return [row for row in rows if row]
|
||||
|
||||
|
||||
def register_metadata_discovery_tool(
|
||||
registry: "ToolRegistry",
|
||||
session: AsyncSession,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
config: MetadataFieldDiscoveryConfig | None = None,
|
||||
) -> None:
|
||||
"""Register metadata discovery tool to registry."""
|
||||
from app.services.mid.tool_registry import ToolType
|
||||
|
||||
cfg = config or MetadataFieldDiscoveryConfig()
|
||||
|
||||
async def metadata_discovery_handler(
|
||||
tenant_id: str = "",
|
||||
kb_id: str | None = None,
|
||||
include_values: bool = True,
|
||||
top_n: int | None = None,
|
||||
**kwargs, # 接受系统注入的额外参数(user_id, session_id 等)
|
||||
) -> dict[str, Any]:
|
||||
"""Metadata discovery tool handler."""
|
||||
tool = MetadataDiscoveryTool(
|
||||
session=session,
|
||||
timeout_governor=timeout_governor,
|
||||
config=cfg,
|
||||
)
|
||||
|
||||
result = await tool.execute(
|
||||
tenant_id=tenant_id,
|
||||
kb_id=kb_id,
|
||||
include_values=include_values,
|
||||
top_n=top_n,
|
||||
)
|
||||
# 将 dataclass 转换为 dict
|
||||
return {
|
||||
"success": result.success,
|
||||
"fields": [
|
||||
{
|
||||
"field_key": f.field_key,
|
||||
"field_type": f.field_type,
|
||||
"label": f.label,
|
||||
"description": f.description,
|
||||
"common_values": f.common_values,
|
||||
"value_count": f.value_count,
|
||||
"is_filterable": f.is_filterable,
|
||||
"options": f.options,
|
||||
}
|
||||
for f in result.fields
|
||||
],
|
||||
"total_documents": result.total_documents,
|
||||
"error": result.error,
|
||||
"duration_ms": result.duration_ms,
|
||||
}
|
||||
|
||||
registry.register(
|
||||
name="list_document_metadata_fields",
|
||||
description="列出当前知识库文档中使用的元数据字段及其常见取值,用于后续的知识库搜索过滤",
|
||||
handler=metadata_discovery_handler,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
version="1.0.0",
|
||||
auth_required=False,
|
||||
timeout_ms=cfg.timeout_ms,
|
||||
enabled=True,
|
||||
metadata={
|
||||
"when_to_use": "当需要了解知识库中有哪些可用的元数据过滤字段时使用。",
|
||||
"when_not_to_use": "当已知可用的过滤字段,或不需要元数据过滤时不需要调用。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tenant_id": {"type": "string", "description": "租户 ID"},
|
||||
"kb_id": {"type": "string", "description": "知识库 ID(可选,用于限定范围)"},
|
||||
"include_values": {"type": "boolean", "description": "是否包含常见值列表,默认 true"},
|
||||
"top_n": {"type": "integer", "description": "每个字段返回的常见值数量,默认 10"},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"example_action_input": {
|
||||
"include_values": True,
|
||||
"top_n": 5,
|
||||
},
|
||||
"result_interpretation": "fields 数组包含每个字段的详细信息;common_values 是该字段在文档中的常见取值;value_count 表示该字段在多少文档中出现。",
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("[MetadataDiscovery] Tool registered: list_document_metadata_fields")
|
||||
Loading…
Reference in New Issue