203 lines
6.4 KiB
Python
203 lines
6.4 KiB
Python
|
|
"""
|
|||
|
|
KB document metadata inference using LLM.
|
|||
|
|
[AC-IDSMETA-XX] Infer metadata for markdown uploads when missing.
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
import re
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
|
|
|||
|
|
from app.services.llm.base import LLMConfig
|
|||
|
|
from app.services.llm.factory import LLMUsageType, get_llm_config_manager
|
|||
|
|
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
|
|||
|
|
from app.services.mid.metadata_discovery_tool import MetadataDiscoveryTool
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
_MAX_CONTENT_CHARS = 2000
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_json_object(text: str) -> dict[str, Any] | None:
|
|||
|
|
candidates: list[str] = []
|
|||
|
|
|
|||
|
|
code_block_match = re.search(r"```json\s*([\s\S]*?)\s*```", text, re.IGNORECASE)
|
|||
|
|
if code_block_match:
|
|||
|
|
candidates.append(code_block_match.group(1).strip())
|
|||
|
|
|
|||
|
|
fence_match = re.search(r"```\s*([\s\S]*?)\s*```", text)
|
|||
|
|
if fence_match:
|
|||
|
|
candidates.append(fence_match.group(1).strip())
|
|||
|
|
|
|||
|
|
brace_match = re.search(r"\{[\s\S]*\}", text)
|
|||
|
|
if brace_match:
|
|||
|
|
candidates.append(brace_match.group(0).strip())
|
|||
|
|
|
|||
|
|
for candidate in candidates:
|
|||
|
|
if not candidate:
|
|||
|
|
continue
|
|||
|
|
try:
|
|||
|
|
obj = json.loads(candidate)
|
|||
|
|
if isinstance(obj, dict):
|
|||
|
|
return obj
|
|||
|
|
except json.JSONDecodeError:
|
|||
|
|
fixed = candidate.replace("'", '"')
|
|||
|
|
try:
|
|||
|
|
obj = json.loads(fixed)
|
|||
|
|
if isinstance(obj, dict):
|
|||
|
|
return obj
|
|||
|
|
except json.JSONDecodeError:
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _truncate_content(text: str) -> str:
|
|||
|
|
if len(text) <= _MAX_CONTENT_CHARS * 2:
|
|||
|
|
return text
|
|||
|
|
head = text[:_MAX_CONTENT_CHARS]
|
|||
|
|
tail = text[-_MAX_CONTENT_CHARS:]
|
|||
|
|
return f"{head}\n\n...\n\n{tail}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
class KBMetadataInferenceService:
    """Infer document metadata based on markdown content."""

    def __init__(self, session: AsyncSession, max_tokens: int = 512, temperature: float = 0.2):
        # The session is shared with the field-definition and discovery services.
        self._session = session
        self._max_tokens = max_tokens
        self._temperature = temperature

    async def infer_metadata(
        self,
        tenant_id: str,
        content: str,
        filename: str | None = None,
        kb_id: str | None = None,
    ) -> dict[str, Any]:
        """Ask the LLM to fill in metadata fields for a markdown document.

        Returns a dict of field_key -> inferred value containing only fields
        that survived cleaning against the active field definitions and
        passed validation. Degrades to an empty dict on any failure path
        (no field definitions, LLM client/call errors, unparseable output,
        or validation errors) — callers treat {} as "nothing inferred".
        """
        field_def_service = MetadataFieldDefinitionService(self._session)
        field_defs = await field_def_service.get_active_field_definitions(tenant_id, "kb_document")
        if not field_defs:
            # No active definitions for this tenant: nothing to infer against.
            return {}

        # Gather the most frequent existing values per field so the model can
        # prefer vocabulary already in use in this KB. Best-effort: an
        # unsuccessful discovery simply yields no common values.
        discovery = await MetadataDiscoveryTool(self._session).execute(
            tenant_id=tenant_id,
            kb_id=kb_id,
            include_values=True,
            top_n=5,
        )
        common_values_map = (
            {f.field_key: f.common_values for f in discovery.fields}
            if discovery.success
            else {}
        )

        fields_payload = [
            {
                "field_key": fd.field_key,
                "label": fd.label,
                "type": fd.type,
                "required": fd.required,
                "options": fd.options or [],
                "common_values": common_values_map.get(fd.field_key, []),
            }
            for fd in field_defs
        ]

        prompt = f"""你是知识库文档的元数据补全助手。请根据给定的 Markdown 内容,为文档补全元数据。

要求:
1) 只能使用提供的字段;
2) 如果字段有 options 或 common_values,优先从中选择最匹配的值;
3) 不确定的字段不要填写;
4) 输出必须是严格 JSON 对象,只包含推断出的字段;
5) 不要输出多余说明。

可用字段定义(JSON 数组):
{json.dumps(fields_payload, ensure_ascii=False)}

文件名:{filename or "unknown"}

Markdown 内容:
{_truncate_content(content)}
""".strip()

        try:
            llm_client = get_llm_config_manager().get_client(LLMUsageType.KB_PROCESSING)
        except Exception as e:
            logger.warning(f"[AC-IDSMETA-XX] Failed to get LLM client: {e}")
            return {}

        try:
            response = await llm_client.generate(
                messages=[
                    {"role": "system", "content": "你是严格的 JSON 生成器。"},
                    {"role": "user", "content": prompt},
                ],
                config=LLMConfig(
                    max_tokens=self._max_tokens,
                    temperature=self._temperature,
                ),
            )
        except Exception as e:
            logger.warning(f"[AC-IDSMETA-XX] Metadata inference failed: {e}")
            return {}

        if not response.content:
            return {}

        inferred = _extract_json_object(response.content)
        if not inferred:
            logger.warning("[AC-IDSMETA-XX] Metadata inference returned no JSON")
            return {}

        cleaned = self._clean_inferred_values(inferred, field_defs)
        if not cleaned:
            return {}

        # Final gate: run the same validation used for manually supplied
        # metadata, so inferred values can never bypass tenant rules.
        is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
            tenant_id, cleaned, "kb_document"
        )
        if not is_valid:
            logger.warning(f"[AC-IDSMETA-XX] Inferred metadata validation failed: {validation_errors}")
            return {}

        return cleaned

    def _clean_inferred_values(
        self,
        inferred: dict[str, Any],
        field_defs: list[Any],
    ) -> dict[str, Any]:
        """Drop unknown/empty inferred values and enforce enum options.

        - Keys with no matching field definition are discarded.
        - None and empty-string values are discarded.
        - "enum" values must be one of the defined options.
        - "array_enum" values must be lists; entries outside the options are
          filtered out, and a list left empty is discarded.
        """
        defs_by_key = {fd.field_key: fd for fd in field_defs}
        result: dict[str, Any] = {}

        for key, raw in inferred.items():
            fd = defs_by_key.get(key)
            if fd is None or raw is None or raw == "":
                continue

            value = raw
            if fd.options and fd.type == "enum":
                if value not in fd.options:
                    continue
            elif fd.options and fd.type == "array_enum":
                if not isinstance(value, list):
                    continue
                value = [item for item in value if item in fd.options]
                if not value:
                    continue

            result[key] = value

        return result
|