feat: add metadata discovery tool for dynamic metadata extraction [AC-METADATA-DISCOVERY]
This commit is contained in:
parent
812af6c7a1
commit
3b354ba041
|
|
@ -0,0 +1,281 @@
|
||||||
|
"""
|
||||||
|
Metadata Discovery Tool for Mid Platform.
|
||||||
|
[AC-MARH-XX] 元数据发现工具,用于查询当前可用的元数据字段及其常见值。
|
||||||
|
|
||||||
|
核心特性:
|
||||||
|
- 列出当前知识库文档中使用的元数据字段
|
||||||
|
- 返回每个字段的常见取值(从现有文档中聚合)
|
||||||
|
- 支持按知识库过滤
|
||||||
|
- 返回字段定义信息(类型、用途说明等)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections import Counter
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.entities import Document, MetadataFieldDefinition, MetadataFieldStatus
|
||||||
|
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.mid.tool_registry import ToolRegistry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_TIMEOUT_MS = 2000
|
||||||
|
DEFAULT_TOP_VALUES = 10
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MetadataFieldDiscoveryConfig:
|
||||||
|
"""Configuration for metadata field discovery tool."""
|
||||||
|
timeout_ms: int = DEFAULT_TIMEOUT_MS
|
||||||
|
top_values_count: int = DEFAULT_TOP_VALUES
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MetadataFieldInfo:
|
||||||
|
"""Information about a metadata field."""
|
||||||
|
field_key: str
|
||||||
|
field_type: str = "string"
|
||||||
|
label: str = ""
|
||||||
|
description: str | None = None
|
||||||
|
common_values: list[str] = field(default_factory=list)
|
||||||
|
value_count: int = 0
|
||||||
|
is_filterable: bool = True
|
||||||
|
options: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MetadataDiscoveryResult:
|
||||||
|
"""Result of metadata discovery."""
|
||||||
|
success: bool
|
||||||
|
fields: list[MetadataFieldInfo] = field(default_factory=list)
|
||||||
|
total_documents: int = 0
|
||||||
|
error: str | None = None
|
||||||
|
duration_ms: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataDiscoveryTool:
|
||||||
|
"""
|
||||||
|
[AC-MARH-XX] 元数据发现工具。
|
||||||
|
|
||||||
|
用于查询当前知识库文档中使用的元数据字段及其常见值,
|
||||||
|
帮助 AI 了解可用的过滤字段,从而更好地构造搜索请求。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
timeout_governor: TimeoutGovernor | None = None,
|
||||||
|
config: MetadataFieldDiscoveryConfig | None = None,
|
||||||
|
):
|
||||||
|
self._session = session
|
||||||
|
self._timeout_governor = timeout_governor or TimeoutGovernor()
|
||||||
|
self._config = config or MetadataFieldDiscoveryConfig()
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
kb_id: str | None = None,
|
||||||
|
include_values: bool = True,
|
||||||
|
top_n: int | None = None,
|
||||||
|
) -> MetadataDiscoveryResult:
|
||||||
|
"""
|
||||||
|
Execute metadata discovery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
kb_id: Optional knowledge base ID to filter
|
||||||
|
include_values: Whether to include common values (default True)
|
||||||
|
top_n: Number of top values to return per field (default from config)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MetadataDiscoveryResult with field information
|
||||||
|
"""
|
||||||
|
start_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
top_n = top_n or self._config.top_values_count
|
||||||
|
|
||||||
|
field_definitions = await self._get_field_definitions(tenant_id)
|
||||||
|
|
||||||
|
document_metadata = await self._get_document_metadata(tenant_id, kb_id)
|
||||||
|
|
||||||
|
total_docs = len(document_metadata)
|
||||||
|
|
||||||
|
field_values: dict[str, Counter] = {}
|
||||||
|
for doc_meta in document_metadata:
|
||||||
|
if not doc_meta:
|
||||||
|
continue
|
||||||
|
for key, value in doc_meta.items():
|
||||||
|
if key not in field_values:
|
||||||
|
field_values[key] = Counter()
|
||||||
|
str_value = str(value) if value is not None else ""
|
||||||
|
field_values[key].update([str_value])
|
||||||
|
|
||||||
|
fields: list[MetadataFieldInfo] = []
|
||||||
|
for field_key, values_counter in field_values.items():
|
||||||
|
field_def = field_definitions.get(field_key)
|
||||||
|
|
||||||
|
common_values = []
|
||||||
|
if include_values:
|
||||||
|
most_common = values_counter.most_common(top_n)
|
||||||
|
common_values = [v for v, _ in most_common if v]
|
||||||
|
|
||||||
|
field_info = MetadataFieldInfo(
|
||||||
|
field_key=field_key,
|
||||||
|
field_type=field_def.type if field_def else "string",
|
||||||
|
label=field_def.label if field_def else field_key,
|
||||||
|
description=field_def.usage_description if field_def else None,
|
||||||
|
common_values=common_values,
|
||||||
|
value_count=len(values_counter),
|
||||||
|
is_filterable=field_def.is_filterable if field_def else True,
|
||||||
|
options=field_def.options if field_def else None,
|
||||||
|
)
|
||||||
|
fields.append(field_info)
|
||||||
|
|
||||||
|
fields.sort(key=lambda f: f.value_count, reverse=True)
|
||||||
|
|
||||||
|
duration_ms = int((asyncio.get_event_loop().time() - start_time) * 1000)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[MetadataDiscovery] Discovered {len(fields)} fields from {total_docs} documents, "
|
||||||
|
f"duration={duration_ms}ms"
|
||||||
|
)
|
||||||
|
|
||||||
|
return MetadataDiscoveryResult(
|
||||||
|
success=True,
|
||||||
|
fields=fields,
|
||||||
|
total_documents=total_docs,
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[MetadataDiscovery] Discovery failed: {e}")
|
||||||
|
return MetadataDiscoveryResult(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_field_definitions(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> dict[str, MetadataFieldDefinition]:
|
||||||
|
"""Get field definitions for tenant."""
|
||||||
|
stmt = select(MetadataFieldDefinition).where(
|
||||||
|
MetadataFieldDefinition.tenant_id == tenant_id,
|
||||||
|
MetadataFieldDefinition.status == MetadataFieldStatus.ACTIVE.value,
|
||||||
|
)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
definitions = result.scalars().all()
|
||||||
|
|
||||||
|
return {d.field_key: d for d in definitions}
|
||||||
|
|
||||||
|
async def _get_document_metadata(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
kb_id: str | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Get all document metadata for tenant."""
|
||||||
|
stmt = select(Document.doc_metadata).where(
|
||||||
|
Document.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
if kb_id:
|
||||||
|
stmt = stmt.where(Document.kb_id == kb_id)
|
||||||
|
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
|
||||||
|
return [row for row in rows if row]
|
||||||
|
|
||||||
|
|
||||||
|
def register_metadata_discovery_tool(
|
||||||
|
registry: "ToolRegistry",
|
||||||
|
session: AsyncSession,
|
||||||
|
timeout_governor: TimeoutGovernor | None = None,
|
||||||
|
config: MetadataFieldDiscoveryConfig | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Register metadata discovery tool to registry."""
|
||||||
|
from app.services.mid.tool_registry import ToolType
|
||||||
|
|
||||||
|
cfg = config or MetadataFieldDiscoveryConfig()
|
||||||
|
|
||||||
|
async def metadata_discovery_handler(
|
||||||
|
tenant_id: str = "",
|
||||||
|
kb_id: str | None = None,
|
||||||
|
include_values: bool = True,
|
||||||
|
top_n: int | None = None,
|
||||||
|
**kwargs, # 接受系统注入的额外参数(user_id, session_id 等)
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Metadata discovery tool handler."""
|
||||||
|
tool = MetadataDiscoveryTool(
|
||||||
|
session=session,
|
||||||
|
timeout_governor=timeout_governor,
|
||||||
|
config=cfg,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await tool.execute(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
kb_id=kb_id,
|
||||||
|
include_values=include_values,
|
||||||
|
top_n=top_n,
|
||||||
|
)
|
||||||
|
# 将 dataclass 转换为 dict
|
||||||
|
return {
|
||||||
|
"success": result.success,
|
||||||
|
"fields": [
|
||||||
|
{
|
||||||
|
"field_key": f.field_key,
|
||||||
|
"field_type": f.field_type,
|
||||||
|
"label": f.label,
|
||||||
|
"description": f.description,
|
||||||
|
"common_values": f.common_values,
|
||||||
|
"value_count": f.value_count,
|
||||||
|
"is_filterable": f.is_filterable,
|
||||||
|
"options": f.options,
|
||||||
|
}
|
||||||
|
for f in result.fields
|
||||||
|
],
|
||||||
|
"total_documents": result.total_documents,
|
||||||
|
"error": result.error,
|
||||||
|
"duration_ms": result.duration_ms,
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.register(
|
||||||
|
name="list_document_metadata_fields",
|
||||||
|
description="列出当前知识库文档中使用的元数据字段及其常见取值,用于后续的知识库搜索过滤",
|
||||||
|
handler=metadata_discovery_handler,
|
||||||
|
tool_type=ToolType.INTERNAL,
|
||||||
|
version="1.0.0",
|
||||||
|
auth_required=False,
|
||||||
|
timeout_ms=cfg.timeout_ms,
|
||||||
|
enabled=True,
|
||||||
|
metadata={
|
||||||
|
"when_to_use": "当需要了解知识库中有哪些可用的元数据过滤字段时使用。",
|
||||||
|
"when_not_to_use": "当已知可用的过滤字段,或不需要元数据过滤时不需要调用。",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"tenant_id": {"type": "string", "description": "租户 ID"},
|
||||||
|
"kb_id": {"type": "string", "description": "知识库 ID(可选,用于限定范围)"},
|
||||||
|
"include_values": {"type": "boolean", "description": "是否包含常见值列表,默认 true"},
|
||||||
|
"top_n": {"type": "integer", "description": "每个字段返回的常见值数量,默认 10"},
|
||||||
|
},
|
||||||
|
"required": [],
|
||||||
|
},
|
||||||
|
"example_action_input": {
|
||||||
|
"include_values": True,
|
||||||
|
"top_n": 5,
|
||||||
|
},
|
||||||
|
"result_interpretation": "fields 数组包含每个字段的详细信息;common_values 是该字段在文档中的常见取值;value_count 表示该字段在多少文档中出现。",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("[MetadataDiscovery] Tool registered: list_document_metadata_fields")
|
||||||
Loading…
Reference in New Issue