ai-robot-core/ai-service/app/services/kb_metadata_inference.py

203 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
KB document metadata inference using LLM.
[AC-IDSMETA-XX] Infer metadata for markdown uploads when missing.
"""
from __future__ import annotations
import json
import logging
import re
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.llm.base import LLMConfig
from app.services.llm.factory import LLMUsageType, get_llm_config_manager
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
from app.services.mid.metadata_discovery_tool import MetadataDiscoveryTool
logger = logging.getLogger(__name__)
# Budget of characters kept from EACH end (head and tail) of the markdown
# content before it is embedded in the LLM prompt; see _truncate_content.
_MAX_CONTENT_CHARS = 2000
def _extract_json_object(text: str) -> dict[str, Any] | None:
candidates: list[str] = []
code_block_match = re.search(r"```json\s*([\s\S]*?)\s*```", text, re.IGNORECASE)
if code_block_match:
candidates.append(code_block_match.group(1).strip())
fence_match = re.search(r"```\s*([\s\S]*?)\s*```", text)
if fence_match:
candidates.append(fence_match.group(1).strip())
brace_match = re.search(r"\{[\s\S]*\}", text)
if brace_match:
candidates.append(brace_match.group(0).strip())
for candidate in candidates:
if not candidate:
continue
try:
obj = json.loads(candidate)
if isinstance(obj, dict):
return obj
except json.JSONDecodeError:
fixed = candidate.replace("'", '"')
try:
obj = json.loads(fixed)
if isinstance(obj, dict):
return obj
except json.JSONDecodeError:
continue
return None
def _truncate_content(text: str) -> str:
    """Clamp long text to its first and last ``_MAX_CONTENT_CHARS`` chars.

    Text no longer than twice the budget is returned untouched; otherwise
    the middle is replaced by an ellipsis marker so both the opening and
    closing of the document survive into the prompt.
    """
    budget = _MAX_CONTENT_CHARS
    if len(text) > 2 * budget:
        return "\n\n...\n\n".join((text[:budget], text[-budget:]))
    return text
class KBMetadataInferenceService:
    """Infer KB document metadata from markdown content via an LLM.

    Workflow (``infer_metadata``): load the tenant's active field
    definitions, discover commonly used values, build a constrained
    JSON-only prompt, call the KB-processing LLM, then clean and
    re-validate the model's answer.  Every failure path degrades to an
    empty dict so document uploads never fail because of inference.
    """

    def __init__(self, session: AsyncSession, max_tokens: int = 512, temperature: float = 0.2):
        """Store the DB session and generation limits.

        Args:
            session: Async SQLAlchemy session used by the field-definition
                and discovery services.
            max_tokens: Generation cap for the LLM response.
            temperature: Low default keeps the model close to the provided
                options / common values.
        """
        self._session = session
        self._max_tokens = max_tokens
        self._temperature = temperature

    async def infer_metadata(
        self,
        tenant_id: str,
        content: str,
        filename: str | None = None,
        kb_id: str | None = None,
    ) -> dict[str, Any]:
        """Return validated inferred metadata for one document, or {} on any failure.

        Args:
            tenant_id: Tenant owning the field definitions and KB.
            content: Markdown body of the uploaded document.
            filename: Optional original filename, passed to the LLM as a hint.
            kb_id: Optional KB scope for common-value discovery.

        Returns:
            Mapping of field_key -> inferred value; empty when there are no
            field definitions, the LLM is unavailable/fails, no JSON can be
            extracted, or validation rejects the result.
        """
        field_def_service = MetadataFieldDefinitionService(self._session)
        field_defs = await field_def_service.get_active_field_definitions(tenant_id, "kb_document")
        if not field_defs:
            # Nothing to infer against; skip the LLM call entirely.
            return {}

        # Collect the most frequent existing values per field so the model
        # can prefer values already in use (keeps the vocabulary consistent).
        discovery_tool = MetadataDiscoveryTool(self._session)
        discovery_result = await discovery_tool.execute(
            tenant_id=tenant_id,
            kb_id=kb_id,
            include_values=True,
            top_n=5,
        )
        common_values_map = {
            field.field_key: field.common_values
            for field in (discovery_result.fields if discovery_result.success else [])
        }

        fields_payload = [
            {
                "field_key": field_def.field_key,
                "label": field_def.label,
                "type": field_def.type,
                "required": field_def.required,
                "options": field_def.options or [],
                "common_values": common_values_map.get(field_def.field_key, []),
            }
            for field_def in field_defs
        ]
        prompt = self._build_prompt(fields_payload, filename, content)

        try:
            llm_manager = get_llm_config_manager()
            llm_client = llm_manager.get_client(LLMUsageType.KB_PROCESSING)
        except Exception as e:
            # Best-effort feature: a misconfigured LLM must not block uploads.
            logger.warning("[AC-IDSMETA-XX] Failed to get LLM client: %s", e)
            return {}
        try:
            response = await llm_client.generate(
                messages=[
                    {"role": "system", "content": "你是严格的 JSON 生成器。"},
                    {"role": "user", "content": prompt},
                ],
                config=LLMConfig(
                    max_tokens=self._max_tokens,
                    temperature=self._temperature,
                ),
            )
        except Exception as e:
            logger.warning("[AC-IDSMETA-XX] Metadata inference failed: %s", e)
            return {}

        if not response.content:
            return {}
        inferred = _extract_json_object(response.content)
        if not inferred:
            logger.warning("[AC-IDSMETA-XX] Metadata inference returned no JSON")
            return {}

        cleaned = self._clean_inferred_values(inferred, field_defs)
        if not cleaned:
            return {}
        # Re-validate with the authoritative service so inferred values obey
        # the same rules as user-supplied metadata (required/type/options).
        is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
            tenant_id, cleaned, "kb_document"
        )
        if not is_valid:
            logger.warning(
                "[AC-IDSMETA-XX] Inferred metadata validation failed: %s", validation_errors
            )
            return {}
        return cleaned

    def _build_prompt(
        self,
        fields_payload: list[dict[str, Any]],
        filename: str | None,
        content: str,
    ) -> str:
        """Assemble the (Chinese-language) instruction prompt for the LLM."""
        return f"""你是知识库文档的元数据补全助手。请根据给定的 Markdown 内容,为文档补全元数据。
要求:
1) 只能使用提供的字段;
2) 如果字段有 options 或 common_values优先从中选择最匹配的值
3) 不确定的字段不要填写;
4) 输出必须是严格 JSON 对象,只包含推断出的字段;
5) 不要输出多余说明。
可用字段定义JSON 数组):
{json.dumps(fields_payload, ensure_ascii=False)}
文件名:{filename or "unknown"}
Markdown 内容:
{_truncate_content(content)}
""".strip()

    def _clean_inferred_values(
        self,
        inferred: dict[str, Any],
        field_defs: list[Any],
    ) -> dict[str, Any]:
        """Drop unknown fields, empty values, and out-of-option enum values.

        For ``enum`` fields with options the value must be one of the
        options; for ``array_enum`` the value must be a list, which is
        filtered down to allowed options (and dropped if nothing survives).
        Non-enum values pass through unchanged.
        """
        field_map = {f.field_key: f for f in field_defs}
        cleaned: dict[str, Any] = {}
        for key, value in inferred.items():
            field_def = field_map.get(key)
            if not field_def:
                continue
            if value is None or value == "":
                continue
            if field_def.type in {"enum", "array_enum"} and field_def.options:
                if field_def.type == "enum":
                    if value not in field_def.options:
                        continue
                else:
                    if not isinstance(value, list):
                        continue
                    filtered = [v for v in value if v in field_def.options]
                    if not filtered:
                        continue
                    value = filtered
            cleaned[key] = value
        return cleaned