"""
KB document metadata inference using LLM.

[AC-IDSMETA-XX] Infer metadata for markdown uploads when missing.
"""
from __future__ import annotations

import json
import logging
import re
from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession

from app.services.llm.base import LLMConfig
from app.services.llm.factory import LLMUsageType, get_llm_config_manager
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
from app.services.mid.metadata_discovery_tool import MetadataDiscoveryTool

logger = logging.getLogger(__name__)

# Max characters kept from each end of the document when building the
# LLM prompt (see _truncate_content); bounds prompt size for long files.
_MAX_CONTENT_CHARS = 2000
def _extract_json_object(text: str) -> dict[str, Any] | None:
|
||
candidates: list[str] = []
|
||
|
||
code_block_match = re.search(r"```json\s*([\s\S]*?)\s*```", text, re.IGNORECASE)
|
||
if code_block_match:
|
||
candidates.append(code_block_match.group(1).strip())
|
||
|
||
fence_match = re.search(r"```\s*([\s\S]*?)\s*```", text)
|
||
if fence_match:
|
||
candidates.append(fence_match.group(1).strip())
|
||
|
||
brace_match = re.search(r"\{[\s\S]*\}", text)
|
||
if brace_match:
|
||
candidates.append(brace_match.group(0).strip())
|
||
|
||
for candidate in candidates:
|
||
if not candidate:
|
||
continue
|
||
try:
|
||
obj = json.loads(candidate)
|
||
if isinstance(obj, dict):
|
||
return obj
|
||
except json.JSONDecodeError:
|
||
fixed = candidate.replace("'", '"')
|
||
try:
|
||
obj = json.loads(fixed)
|
||
if isinstance(obj, dict):
|
||
return obj
|
||
except json.JSONDecodeError:
|
||
continue
|
||
|
||
return None
|
||
|
||
|
||
def _truncate_content(text: str) -> str:
|
||
if len(text) <= _MAX_CONTENT_CHARS * 2:
|
||
return text
|
||
head = text[:_MAX_CONTENT_CHARS]
|
||
tail = text[-_MAX_CONTENT_CHARS:]
|
||
return f"{head}\n\n...\n\n{tail}"
|
||
|
||
|
||
class KBMetadataInferenceService:
    """Infer document metadata based on markdown content.

    Builds a prompt from the tenant's active metadata field definitions
    (enriched with commonly used values discovered from existing
    documents), asks the KB-processing LLM for a strict-JSON answer,
    then cleans and validates the result. Every failure path degrades
    to an empty dict so uploads never fail on inference.
    """

    def __init__(self, session: AsyncSession, max_tokens: int = 512, temperature: float = 0.2):
        """Store the DB session and LLM sampling parameters.

        Args:
            session: Async DB session used by the field-definition and
                discovery services.
            max_tokens: Generation cap for the inference call.
            temperature: Low by default — metadata extraction should be
                deterministic rather than creative.
        """
        self._session = session
        self._max_tokens = max_tokens
        self._temperature = temperature

    async def infer_metadata(
        self,
        tenant_id: str,
        content: str,
        filename: str | None = None,
        kb_id: str | None = None,
    ) -> dict[str, Any]:
        """Infer metadata values for a kb_document from its markdown.

        Args:
            tenant_id: Tenant whose field definitions apply.
            content: Markdown body (truncated before prompting).
            filename: Optional original filename, given to the LLM as a hint.
            kb_id: Optional KB scope for common-value discovery.

        Returns:
            A dict of validated field values, or {} when there are no
            active field definitions, the LLM is unavailable/fails,
            no JSON can be extracted, or validation rejects the result.
        """
        field_def_service = MetadataFieldDefinitionService(self._session)
        field_defs = await field_def_service.get_active_field_definitions(tenant_id, "kb_document")

        # Nothing to infer without tenant-defined fields.
        if not field_defs:
            return {}

        # Discover values already common in this KB so the LLM can prefer
        # them; discovery failure is non-fatal (empty map).
        discovery_tool = MetadataDiscoveryTool(self._session)
        discovery_result = await discovery_tool.execute(
            tenant_id=tenant_id,
            kb_id=kb_id,
            include_values=True,
            top_n=5,
        )
        common_values_map = {
            field.field_key: field.common_values
            for field in (discovery_result.fields if discovery_result.success else [])
        }

        fields_payload = [
            {
                "field_key": field_def.field_key,
                "label": field_def.label,
                "type": field_def.type,
                "required": field_def.required,
                "options": field_def.options or [],
                "common_values": common_values_map.get(field_def.field_key, []),
            }
            for field_def in field_defs
        ]

        prompt = f"""你是知识库文档的元数据补全助手。请根据给定的 Markdown 内容,为文档补全元数据。

要求:
1) 只能使用提供的字段;
2) 如果字段有 options 或 common_values,优先从中选择最匹配的值;
3) 不确定的字段不要填写;
4) 输出必须是严格 JSON 对象,只包含推断出的字段;
5) 不要输出多余说明。

可用字段定义(JSON 数组):
{json.dumps(fields_payload, ensure_ascii=False)}

文件名:{filename or "unknown"}

Markdown 内容:
{_truncate_content(content)}
""".strip()

        try:
            llm_manager = get_llm_config_manager()
            llm_client = llm_manager.get_client(LLMUsageType.KB_PROCESSING)
        except Exception as e:
            logger.warning("[AC-IDSMETA-XX] Failed to get LLM client: %s", e)
            return {}

        try:
            response = await llm_client.generate(
                messages=[
                    {"role": "system", "content": "你是严格的 JSON 生成器。"},
                    {"role": "user", "content": prompt},
                ],
                config=LLMConfig(
                    max_tokens=self._max_tokens,
                    temperature=self._temperature,
                ),
            )
        except Exception as e:
            # Best-effort: inference failures must not block the upload.
            logger.warning("[AC-IDSMETA-XX] Metadata inference failed: %s", e)
            return {}

        if not response.content:
            return {}

        inferred = _extract_json_object(response.content)
        if not inferred:
            logger.warning("[AC-IDSMETA-XX] Metadata inference returned no JSON")
            return {}

        cleaned = self._clean_inferred_values(inferred, field_defs)
        if not cleaned:
            return {}

        # Final gate: reuse the same validation applied to user-supplied
        # metadata so inferred values can never bypass tenant rules.
        is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
            tenant_id, cleaned, "kb_document"
        )
        if not is_valid:
            logger.warning("[AC-IDSMETA-XX] Inferred metadata validation failed: %s", validation_errors)
            return {}

        return cleaned

    def _clean_inferred_values(
        self,
        inferred: dict[str, Any],
        field_defs: list[Any],
    ) -> dict[str, Any]:
        """Drop LLM output that does not fit the field definitions.

        Removes unknown keys, None/empty-string values, enum values not
        in the field's options, and (for array_enum) non-list values or
        list entries outside the options; an array_enum that filters to
        empty is dropped entirely.
        """
        field_map = {f.field_key: f for f in field_defs}
        cleaned: dict[str, Any] = {}

        for key, value in inferred.items():
            field_def = field_map.get(key)
            if not field_def:
                continue
            if value is None or value == "":
                continue

            if field_def.type in {"enum", "array_enum"} and field_def.options:
                if field_def.type == "enum":
                    if value not in field_def.options:
                        continue
                else:
                    if not isinstance(value, list):
                        continue
                    filtered = [v for v in value if v in field_def.options]
                    if not filtered:
                        continue
                    value = filtered

            cleaned[key] = value

        return cleaned