ai-robot-core/ai-service/app/services/kb_metadata_inference.py

203 lines
6.4 KiB
Python
Raw Normal View History

"""
KB document metadata inference using LLM.
[AC-IDSMETA-XX] Infer metadata for markdown uploads when missing.
"""
from __future__ import annotations
import json
import logging
import re
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.llm.base import LLMConfig
from app.services.llm.factory import LLMUsageType, get_llm_config_manager
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
from app.services.mid.metadata_discovery_tool import MetadataDiscoveryTool
logger = logging.getLogger(__name__)
_MAX_CONTENT_CHARS = 2000
def _extract_json_object(text: str) -> dict[str, Any] | None:
candidates: list[str] = []
code_block_match = re.search(r"```json\s*([\s\S]*?)\s*```", text, re.IGNORECASE)
if code_block_match:
candidates.append(code_block_match.group(1).strip())
fence_match = re.search(r"```\s*([\s\S]*?)\s*```", text)
if fence_match:
candidates.append(fence_match.group(1).strip())
brace_match = re.search(r"\{[\s\S]*\}", text)
if brace_match:
candidates.append(brace_match.group(0).strip())
for candidate in candidates:
if not candidate:
continue
try:
obj = json.loads(candidate)
if isinstance(obj, dict):
return obj
except json.JSONDecodeError:
fixed = candidate.replace("'", '"')
try:
obj = json.loads(fixed)
if isinstance(obj, dict):
return obj
except json.JSONDecodeError:
continue
return None
def _truncate_content(text: str) -> str:
if len(text) <= _MAX_CONTENT_CHARS * 2:
return text
head = text[:_MAX_CONTENT_CHARS]
tail = text[-_MAX_CONTENT_CHARS:]
return f"{head}\n\n...\n\n{tail}"
class KBMetadataInferenceService:
    """Infer document metadata for markdown uploads using an LLM.

    Pipeline: load the tenant's active ``kb_document`` field definitions,
    gather common existing values via the discovery tool, prompt the LLM
    for a strict-JSON answer, then clean and validate the inferred values.
    Every failure path is best-effort: an empty dict is returned rather
    than raising, so uploads proceed without inferred metadata.
    """

    def __init__(self, session: AsyncSession, max_tokens: int = 512, temperature: float = 0.2):
        """Store the DB session and LLM generation limits.

        Args:
            session: Async DB session shared with the field-definition and
                discovery services.
            max_tokens: Cap on LLM completion length.
            temperature: Low default keeps JSON output near-deterministic.
        """
        self._session = session
        self._max_tokens = max_tokens
        self._temperature = temperature

    async def infer_metadata(
        self,
        tenant_id: str,
        content: str,
        filename: str | None = None,
        kb_id: str | None = None,
    ) -> dict[str, Any]:
        """Return inferred metadata for *content*, or ``{}`` on any failure.

        Args:
            tenant_id: Tenant whose field definitions apply.
            content: Markdown document body (truncated before prompting).
            filename: Optional original filename, shown to the LLM as a hint.
            kb_id: Optional KB scope for the common-value discovery.

        Returns:
            A dict of validated field values; ``{}`` when there are no
            active field definitions, the LLM is unavailable, the response
            is not parseable JSON, or validation rejects the result.
        """
        field_def_service = MetadataFieldDefinitionService(self._session)
        field_defs = await field_def_service.get_active_field_definitions(tenant_id, "kb_document")
        if not field_defs:
            # Nothing to infer against — tenant has no metadata schema.
            return {}

        common_values_map = await self._load_common_values(tenant_id, kb_id)
        fields_payload = self._build_fields_payload(field_defs, common_values_map)
        prompt = self._build_prompt(fields_payload, filename, content)

        try:
            llm_manager = get_llm_config_manager()
            llm_client = llm_manager.get_client(LLMUsageType.KB_PROCESSING)
        except Exception as e:
            logger.warning("[AC-IDSMETA-XX] Failed to get LLM client: %s", e)
            return {}
        try:
            response = await llm_client.generate(
                messages=[
                    {"role": "system", "content": "你是严格的 JSON 生成器。"},
                    {"role": "user", "content": prompt},
                ],
                config=LLMConfig(
                    max_tokens=self._max_tokens,
                    temperature=self._temperature,
                ),
            )
        except Exception as e:
            logger.warning("[AC-IDSMETA-XX] Metadata inference failed: %s", e)
            return {}
        if not response.content:
            return {}

        inferred = _extract_json_object(response.content)
        if not inferred:
            logger.warning("[AC-IDSMETA-XX] Metadata inference returned no JSON")
            return {}
        cleaned = self._clean_inferred_values(inferred, field_defs)
        if not cleaned:
            return {}
        # Final gate: run the same validation applied to user-supplied
        # metadata, so inferred values cannot bypass schema rules.
        is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
            tenant_id, cleaned, "kb_document"
        )
        if not is_valid:
            logger.warning("[AC-IDSMETA-XX] Inferred metadata validation failed: %s", validation_errors)
            return {}
        return cleaned

    async def _load_common_values(self, tenant_id: str, kb_id: str | None) -> dict[str, list[Any]]:
        """Map field_key -> top existing values; empty map when discovery fails."""
        discovery_tool = MetadataDiscoveryTool(self._session)
        discovery_result = await discovery_tool.execute(
            tenant_id=tenant_id,
            kb_id=kb_id,
            include_values=True,
            top_n=5,
        )
        return {
            field.field_key: field.common_values
            for field in (discovery_result.fields if discovery_result.success else [])
        }

    @staticmethod
    def _build_fields_payload(
        field_defs: list[Any],
        common_values_map: dict[str, list[Any]],
    ) -> list[dict[str, Any]]:
        """Serialize field definitions into the JSON shape shown to the LLM."""
        return [
            {
                "field_key": field_def.field_key,
                "label": field_def.label,
                "type": field_def.type,
                "required": field_def.required,
                "options": field_def.options or [],
                "common_values": common_values_map.get(field_def.field_key, []),
            }
            for field_def in field_defs
        ]

    @staticmethod
    def _build_prompt(
        fields_payload: list[dict[str, Any]],
        filename: str | None,
        content: str,
    ) -> str:
        """Assemble the user prompt (field schema + filename + truncated content)."""
        return f"""你是知识库文档的元数据补全助手。请根据给定的 Markdown 内容,为文档补全元数据。
要求
1) 只能使用提供的字段
2) 如果字段有 options common_values优先从中选择最匹配的值
3) 不确定的字段不要填写
4) 输出必须是严格 JSON 对象只包含推断出的字段
5) 不要输出多余说明
可用字段定义JSON 数组
{json.dumps(fields_payload, ensure_ascii=False)}
文件名{filename or "unknown"}
Markdown 内容
{_truncate_content(content)}
""".strip()

    def _clean_inferred_values(
        self,
        inferred: dict[str, Any],
        field_defs: list[Any],
    ) -> dict[str, Any]:
        """Drop unknown keys, empty values, and out-of-options enum values.

        For ``enum`` fields a value outside ``options`` is discarded; for
        ``array_enum`` fields non-list values are discarded and list values
        are filtered down to the allowed options (dropped when the filter
        leaves nothing).
        """
        field_map = {f.field_key: f for f in field_defs}
        cleaned: dict[str, Any] = {}
        for key, value in inferred.items():
            field_def = field_map.get(key)
            if not field_def:
                continue
            if value is None or value == "":
                continue
            if field_def.type in {"enum", "array_enum"} and field_def.options:
                if field_def.type == "enum":
                    if value not in field_def.options:
                        continue
                else:
                    if not isinstance(value, list):
                        continue
                    filtered = [v for v in value if v in field_def.options]
                    if not filtered:
                        continue
                    value = filtered
            cleaned[key] = value
        return cleaned