[AC-METADATA-INFERENCE] feat(metadata): add automatic metadata inference service

- Add metadata_auto_inference_service, implementing automatic metadata inference
- Add kb_metadata_inference, providing a knowledge-base metadata inference tool
- Support automatic extraction of metadata fields from document content
- Integrate a caching layer to speed up inference
MerCry 2026-03-11 19:06:21 +08:00
parent b3343f9e52
commit 9196247578
2 changed files with 698 additions and 0 deletions


@@ -0,0 +1,202 @@
"""
KB document metadata inference using LLM.
[AC-IDSMETA-XX] Infer metadata for markdown uploads when missing.
"""
from __future__ import annotations
import json
import logging
import re
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.llm.base import LLMConfig
from app.services.llm.factory import LLMUsageType, get_llm_config_manager
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
from app.services.mid.metadata_discovery_tool import MetadataDiscoveryTool
logger = logging.getLogger(__name__)
_MAX_CONTENT_CHARS = 2000
def _extract_json_object(text: str) -> dict[str, Any] | None:
candidates: list[str] = []
code_block_match = re.search(r"```json\s*([\s\S]*?)\s*```", text, re.IGNORECASE)
if code_block_match:
candidates.append(code_block_match.group(1).strip())
fence_match = re.search(r"```\s*([\s\S]*?)\s*```", text)
if fence_match:
candidates.append(fence_match.group(1).strip())
brace_match = re.search(r"\{[\s\S]*\}", text)
if brace_match:
candidates.append(brace_match.group(0).strip())
for candidate in candidates:
if not candidate:
continue
try:
obj = json.loads(candidate)
if isinstance(obj, dict):
return obj
except json.JSONDecodeError:
fixed = candidate.replace("'", '"')
try:
obj = json.loads(fixed)
if isinstance(obj, dict):
return obj
except json.JSONDecodeError:
continue
return None
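# Illustrative examples of the fallback order above:
#   '```json\n{"category": "faq"}\n```'  -> parsed from the ```json fence
#   'Answer: {"category": "faq"}'        -> parsed from the outermost {...} span
#   "{'category': 'faq'}"                -> single quotes swapped for double quotes, then parsed
# Anything that still fails to parse into a JSON object yields None.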
def _truncate_content(text: str) -> str:
if len(text) <= _MAX_CONTENT_CHARS * 2:
return text
head = text[:_MAX_CONTENT_CHARS]
tail = text[-_MAX_CONTENT_CHARS:]
return f"{head}\n\n...\n\n{tail}"
class KBMetadataInferenceService:
"""Infer document metadata based on markdown content."""
def __init__(self, session: AsyncSession, max_tokens: int = 512, temperature: float = 0.2):
self._session = session
self._max_tokens = max_tokens
self._temperature = temperature
async def infer_metadata(
self,
tenant_id: str,
content: str,
filename: str | None = None,
kb_id: str | None = None,
) -> dict[str, Any]:
field_def_service = MetadataFieldDefinitionService(self._session)
field_defs = await field_def_service.get_active_field_definitions(tenant_id, "kb_document")
if not field_defs:
return {}
discovery_tool = MetadataDiscoveryTool(self._session)
discovery_result = await discovery_tool.execute(
tenant_id=tenant_id,
kb_id=kb_id,
include_values=True,
top_n=5,
)
common_values_map = {
field.field_key: field.common_values
for field in (discovery_result.fields if discovery_result.success else [])
}
fields_payload = []
for field_def in field_defs:
fields_payload.append({
"field_key": field_def.field_key,
"label": field_def.label,
"type": field_def.type,
"required": field_def.required,
"options": field_def.options or [],
"common_values": common_values_map.get(field_def.field_key, []),
})
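# Each payload entry has this shape (hypothetical field, shown before json.dumps):
#   {"field_key": "department", "label": "Department", "type": "enum",
#    "required": False, "options": ["HR", "Finance"], "common_values": ["HR"]}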
prompt = f"""你是知识库文档的元数据补全助手。请根据给定的 Markdown 内容,为文档补全元数据。
要求
1) 只能使用提供的字段
2) 如果字段有 options common_values优先从中选择最匹配的值
3) 不确定的字段不要填写
4) 输出必须是严格 JSON 对象只包含推断出的字段
5) 不要输出多余说明
可用字段定义JSON 数组
{json.dumps(fields_payload, ensure_ascii=False)}
文件名{filename or "unknown"}
Markdown 内容
{_truncate_content(content)}
""".strip()
try:
llm_manager = get_llm_config_manager()
llm_client = llm_manager.get_client(LLMUsageType.KB_PROCESSING)
except Exception as e:
logger.warning(f"[AC-IDSMETA-XX] Failed to get LLM client: {e}")
return {}
try:
response = await llm_client.generate(
messages=[
{"role": "system", "content": "你是严格的 JSON 生成器。"},
{"role": "user", "content": prompt},
],
config=LLMConfig(
max_tokens=self._max_tokens,
temperature=self._temperature,
),
)
except Exception as e:
logger.warning(f"[AC-IDSMETA-XX] Metadata inference failed: {e}")
return {}
if not response.content:
return {}
inferred = _extract_json_object(response.content)
if not inferred:
logger.warning("[AC-IDSMETA-XX] Metadata inference returned no JSON")
return {}
cleaned = self._clean_inferred_values(inferred, field_defs)
if not cleaned:
return {}
is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
tenant_id, cleaned, "kb_document"
)
if not is_valid:
logger.warning(f"[AC-IDSMETA-XX] Inferred metadata validation failed: {validation_errors}")
return {}
return cleaned
def _clean_inferred_values(
self,
inferred: dict[str, Any],
field_defs: list[Any],
) -> dict[str, Any]:
field_map = {f.field_key: f for f in field_defs}
cleaned: dict[str, Any] = {}
for key, value in inferred.items():
field_def = field_map.get(key)
if not field_def:
continue
if value is None or value == "":
continue
if field_def.type in {"enum", "array_enum"} and field_def.options:
if field_def.type == "enum":
if value not in field_def.options:
continue
else:
if not isinstance(value, list):
continue
filtered = [v for v in value if v in field_def.options]
if not filtered:
continue
value = filtered
cleaned[key] = value
return cleaned
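# Minimal usage sketch (illustrative; the session and IDs are placeholders):
#
#     service = KBMetadataInferenceService(session)
#     metadata = await service.infer_metadata(
#         tenant_id="tenant-123",
#         content=markdown_text,
#         filename="handbook.md",
#         kb_id="kb-456",
#     )
#     # `metadata` is {} when no field definitions exist, the LLM call fails,
#     # or the inferred values do not pass validation.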


@@ -0,0 +1,496 @@
"""
Metadata Auto Inference Service.
A service that automatically infers document metadata; supports both image and text inputs.
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.llm.factory import get_llm_config_manager
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
from app.services.metadata_cache_service import get_metadata_cache_service
logger = logging.getLogger(__name__)
_field_definitions_cache: dict[str, list[Any]] = {}
_cache_ttl_seconds = 300
_last_cache_refresh: dict[str, float] = {}
@dataclass
class InferenceFieldContext:
"""推断字段的上下文信息"""
field_key: str
label: str
type: str
required: bool
options: list[str] | None = None
description: str | None = None
@dataclass
class AutoInferenceResult:
"""自动推断结果"""
inferred_metadata: dict[str, Any]
confidence_scores: dict[str, float]
raw_response: str
success: bool
error_message: str | None = None
METADATA_INFERENCE_SYSTEM_PROMPT = """You are a professional document metadata analysis assistant. Your task is to infer and fill in metadata fields based on the document content.
## Output requirements
Output strictly in the following JSON format and do not add anything else:
```json
{
"inferred_metadata": {
"field_key_1": "inferred_value_1",
"field_key_2": "inferred_value_2"
},
"confidence_scores": {
"field_key_1": 0.95,
"field_key_2": 0.80
}
}
```
## Inference rules
1. **Analyze the document content carefully**: infer metadata from the document's topic, keywords, and context.
2. **Follow the field definitions**:
- For enum fields, you must choose from the given options.
- For array_enum fields, you may choose multiple options.
- For number fields, output a number.
- For boolean fields, output true or false.
- For text fields, output a string.
3. **Confidence scores**:
- 0.9-1.0: very certain
- 0.7-0.9: fairly certain
- 0.5-0.7: some evidence but uncertain
- 0.0-0.5: a guess or impossible to determine
4. **When inference is impossible**: if a field cannot be reasonably inferred from the document content, leave it out.
## Notes
- Fill in values strictly according to each field's defined type and options.
- Do not invent option values that do not exist.
- Stay objective and base the inference on the document content."""
class MetadataAutoInferenceService:
"""
Metadata auto-inference service.
Capabilities:
1. Fetch the tenant's configured metadata field definitions.
2. Use an LLM to infer metadata automatically from the document content.
3. Validate that the inferred results conform to the field definitions.
Use cases:
- Automatically infer metadata when an image is uploaded.
- Automatically infer metadata when a Markdown/text document is uploaded.
"""
def __init__(
self,
session: AsyncSession,
model: str | None = None,
max_tokens: int = 1024,
timeout_seconds: int = 60,
):
self._session = session
self._model = model
self._max_tokens = max_tokens
self._timeout_seconds = timeout_seconds
self._field_def_service = MetadataFieldDefinitionService(session)
async def infer_metadata(
self,
tenant_id: str,
content: str,
scope: str = "kb_document",
existing_metadata: dict[str, Any] | None = None,
image_base64: str | None = None,
mime_type: str | None = None,
) -> AutoInferenceResult:
"""
Automatically infer document metadata.
Args:
tenant_id: tenant ID
content: document text content
scope: metadata scope
existing_metadata: existing metadata; values filled in manually by the user override inferred ones
image_base64: base64-encoded image (when the input is an image)
mime_type: MIME type of the image
Returns:
AutoInferenceResult containing the inferred metadata
"""
logger.info(
f"[MetadataAutoInference] Starting inference: tenant={tenant_id}, "
f"content_length={len(content)}, scope={scope}"
)
field_definitions = await self._get_field_definitions_with_cache(tenant_id, scope)
if not field_definitions:
logger.info(f"[MetadataAutoInference] No field definitions found for tenant={tenant_id}")
return AutoInferenceResult(
inferred_metadata=existing_metadata or {},
confidence_scores={},
raw_response="",
success=True,
error_message="No field definitions configured",
)
field_contexts = self._build_field_contexts(field_definitions)
if not field_contexts:
return AutoInferenceResult(
inferred_metadata=existing_metadata or {},
confidence_scores={},
raw_response="",
success=True,
)
user_prompt = self._build_user_prompt(content, field_contexts, existing_metadata)
try:
if image_base64 and mime_type:
raw_response = await self._call_multimodal_llm(
user_prompt, image_base64, mime_type
)
else:
raw_response = await self._call_text_llm(user_prompt)
result = self._parse_llm_response(raw_response, field_contexts)
if existing_metadata:
result.inferred_metadata.update(existing_metadata)
logger.info(
f"[MetadataAutoInference] Inference completed: "
f"inferred_fields={list(result.inferred_metadata.keys())}, "
f"avg_confidence={sum(result.confidence_scores.values()) / len(result.confidence_scores) if result.confidence_scores else 0:.2f}"
)
return result
except Exception as e:
logger.error(f"[MetadataAutoInference] Inference failed: {e}")
return AutoInferenceResult(
inferred_metadata=existing_metadata or {},
confidence_scores={},
raw_response="",
success=False,
error_message=str(e),
)
def _build_field_contexts(
self,
field_definitions: list[Any],
) -> list[InferenceFieldContext]:
"""构建字段上下文列表"""
contexts = []
for f in field_definitions:
ctx = InferenceFieldContext(
field_key=f.field_key,
label=f.label,
type=f.type,
required=f.required,
options=f.options,
description=getattr(f, 'description', None),
)
contexts.append(ctx)
return contexts
async def _get_field_definitions_with_cache(
self,
tenant_id: str,
scope: str,
) -> list[Any]:
"""
Fetch field definitions, with caching.
Lookup priority:
1. Redis cache
2. Local in-memory cache
3. Database query
Args:
tenant_id: tenant ID
scope: metadata scope
Returns:
List of field definitions
"""
import time
cache_key = f"{tenant_id}:{scope}"
try:
redis_cache = await get_metadata_cache_service()
cached_fields = await redis_cache.get_fields(tenant_id)
if cached_fields:
logger.info(f"[MetadataAutoInference] Redis cache hit for tenant={tenant_id}")
return [self._dict_to_field_def(f) for f in cached_fields]
except Exception as e:
logger.warning(f"[MetadataAutoInference] Redis cache error: {e}")
current_time = time.time()
if cache_key in _field_definitions_cache:
last_refresh = _last_cache_refresh.get(cache_key, 0)
if current_time - last_refresh < _cache_ttl_seconds:
logger.info(f"[MetadataAutoInference] Local cache hit for tenant={tenant_id}")
return _field_definitions_cache[cache_key]
logger.info(f"[MetadataAutoInference] Cache miss, querying database for tenant={tenant_id}")
field_definitions = await self._field_def_service.get_active_field_definitions(
tenant_id, scope
)
_field_definitions_cache[cache_key] = field_definitions
_last_cache_refresh[cache_key] = current_time
try:
redis_cache = await get_metadata_cache_service()
await redis_cache.set_fields(
tenant_id,
[self._field_def_to_dict(f) for f in field_definitions]
)
except Exception as e:
logger.warning(f"[MetadataAutoInference] Failed to update Redis cache: {e}")
return field_definitions
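# Cache behaviour in practice: a repeat call is answered from the tenant-level
# Redis cache when available, otherwise from the local dict keyed by
# (tenant_id, scope) while its entry is younger than _cache_ttl_seconds (300 s);
# only when both caches miss is the database queried, and the result is then
# written back to both caches.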
def _field_def_to_dict(self, field_def: Any) -> dict[str, Any]:
"""将字段定义转换为字典"""
return {
"field_key": field_def.field_key,
"label": field_def.label,
"type": field_def.type,
"required": field_def.required,
"options": field_def.options,
}
def _dict_to_field_def(self, data: dict[str, Any]) -> Any:
"""将字典转换为字段定义对象"""
from dataclasses import dataclass
@dataclass
class CachedFieldDefinition:
field_key: str
label: str
type: str
required: bool
options: list[str] | None = None
return CachedFieldDefinition(
field_key=data["field_key"],
label=data["label"],
type=data["type"],
required=data["required"],
options=data.get("options"),
)
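# _field_def_to_dict and _dict_to_field_def are intended as inverses: the dict
# stored in Redis, e.g. {"field_key": "lang", "label": "Language", "type": "enum",
# "required": False, "options": ["zh", "en"]} (hypothetical field), round-trips
# back into a lightweight CachedFieldDefinition.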
def _build_user_prompt(
self,
content: str,
field_contexts: list[InferenceFieldContext],
existing_metadata: dict[str, Any] | None = None,
) -> str:
"""构建用户提示词"""
field_descriptions = []
for ctx in field_contexts:
desc = f"- **{ctx.label}** ({ctx.field_key})"
desc += f"\n - 类型: {ctx.type}"
desc += f"\n - 必填: {'' if ctx.required else ''}"
if ctx.options:
desc += f"\n - 可选值: {', '.join(ctx.options)}"
if existing_metadata and ctx.field_key in existing_metadata:
desc += f"\n - 已有值: {existing_metadata[ctx.field_key]}"
field_descriptions.append(desc)
fields_text = "\n".join(field_descriptions)
prompt = f"""请分析以下文档内容,并推断相应的元数据字段。
## 待推断的字段定义
{fields_text}
## 文档内容
{content[:4000]}
请根据文档内容推断上述字段的值并输出 JSON 格式的结果"""
return prompt
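# A single field renders roughly like this in the prompt (hypothetical field):
#   - **Language** (lang)
#    - Type: enum
#    - Required: no
#    - Allowed values: zh, en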
async def _call_text_llm(self, prompt: str) -> str:
"""调用文本 LLM"""
manager = get_llm_config_manager()
client = manager.get_kb_processing_client()
config = manager.kb_processing_config
model = self._model or config.get("model", "gpt-4o-mini")
from app.services.llm.base import LLMConfig
llm_config = LLMConfig(
model=model,
max_tokens=self._max_tokens,
temperature=0.3,
timeout_seconds=self._timeout_seconds,
)
messages = [
{"role": "system", "content": METADATA_INFERENCE_SYSTEM_PROMPT},
{"role": "user", "content": prompt},
]
response = await client.generate(messages=messages, config=llm_config)
return response.content or ""
async def _call_multimodal_llm(
self,
prompt: str,
image_base64: str,
mime_type: str,
) -> str:
"""调用多模态 LLM"""
manager = get_llm_config_manager()
client = manager.get_kb_processing_client()
config = manager.kb_processing_config
model = self._model or config.get("model", "gpt-4o-mini")
from app.services.llm.base import LLMConfig
llm_config = LLMConfig(
model=model,
max_tokens=self._max_tokens,
temperature=0.3,
timeout_seconds=self._timeout_seconds,
)
messages = [
{"role": "system", "content": METADATA_INFERENCE_SYSTEM_PROMPT},
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{image_base64}",
},
},
],
},
]
response = await client.generate(messages=messages, config=llm_config)
return response.content or ""
def _parse_llm_response(
self,
response: str,
field_contexts: list[InferenceFieldContext],
) -> AutoInferenceResult:
"""解析 LLM 响应"""
try:
json_str = self._extract_json(response)
data = json.loads(json_str)
inferred_metadata = data.get("inferred_metadata", {})
confidence_scores = data.get("confidence_scores", {})
field_map = {ctx.field_key: ctx for ctx in field_contexts}
validated_metadata = {}
validated_scores = {}
for field_key, value in inferred_metadata.items():
if field_key not in field_map:
continue
ctx = field_map[field_key]
validated_value = self._validate_field_value(ctx, value)
if validated_value is not None:
validated_metadata[field_key] = validated_value
validated_scores[field_key] = confidence_scores.get(field_key, 0.5)
return AutoInferenceResult(
inferred_metadata=validated_metadata,
confidence_scores=validated_scores,
raw_response=response,
success=True,
)
except json.JSONDecodeError as e:
logger.warning(f"[MetadataAutoInference] Failed to parse JSON: {e}")
return AutoInferenceResult(
inferred_metadata={},
confidence_scores={},
raw_response=response,
success=False,
error_message=f"JSON parse error: {e}",
)
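# Parsing notes: keys not present in field_contexts are dropped, values that
# fail _validate_field_value are dropped, a missing confidence score defaults
# to 0.5, and malformed JSON yields an unsuccessful result that still carries
# the raw response for debugging.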
def _validate_field_value(
self,
ctx: InferenceFieldContext,
value: Any,
) -> Any:
"""验证并转换字段值"""
if value is None:
return None
from app.models.entities import MetadataFieldType
if ctx.type == MetadataFieldType.NUMBER.value:
try:
return float(value) if isinstance(value, str) else value
except (ValueError, TypeError):
return None
elif ctx.type == MetadataFieldType.BOOLEAN.value:
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in ("true", "1", "yes")
return bool(value)
elif ctx.type == MetadataFieldType.ENUM.value:
if ctx.options and value in ctx.options:
return value
return None
elif ctx.type == MetadataFieldType.ARRAY_ENUM.value:
if not isinstance(value, list):
value = [value] if value else []
if ctx.options:
return [v for v in value if v in ctx.options]
return value
else:
return str(value) if value is not None else None
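# Coercion examples (illustrative): "3.5" -> 3.5 for number fields,
# "Yes"/"1"/"true" -> True for boolean fields, an enum value outside the
# configured options -> None (dropped), and a bare string for an array_enum
# field is wrapped in a list before being filtered against the options.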
def _extract_json(self, content: str) -> str:
"""从响应中提取 JSON"""
content = content.strip()
if content.startswith("{") and content.endswith("}"):
return content
json_start = content.find("{")
json_end = content.rfind("}")
if json_start != -1 and json_end != -1 and json_end > json_start:
return content[json_start:json_end + 1]
return content
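# Minimal usage sketch (illustrative; the session, tenant, and inputs are placeholders):
#
#     service = MetadataAutoInferenceService(session)
#     result = await service.infer_metadata(
#         tenant_id="tenant-123",
#         content=extracted_text,
#         scope="kb_document",
#         existing_metadata={"category": "manual"},  # user-provided values win
#     )
#     if result.success:
#         merged = result.inferred_metadata   # inferred values merged with existing_metadata
#         scores = result.confidence_scores
#     # For image uploads, pass image_base64 and mime_type to route through the
#     # multimodal call instead of the text-only one.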