[AC-METADATA-INFERENCE] feat(metadata): 新增元数据自动推断服务
- 新增 metadata_auto_inference_service 实现元数据自动推断 - 新增 kb_metadata_inference 提供知识库元数据推断工具 - 支持从文档内容自动提取元数据字段 - 集成缓存机制提升推断效率
This commit is contained in:
parent
b3343f9e52
commit
9196247578
|
|
@ -0,0 +1,202 @@
|
||||||
|
"""
|
||||||
|
KB document metadata inference using LLM.
|
||||||
|
[AC-IDSMETA-XX] Infer metadata for markdown uploads when missing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.services.llm.base import LLMConfig
|
||||||
|
from app.services.llm.factory import LLMUsageType, get_llm_config_manager
|
||||||
|
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
|
||||||
|
from app.services.mid.metadata_discovery_tool import MetadataDiscoveryTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_MAX_CONTENT_CHARS = 2000
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_object(text: str) -> dict[str, Any] | None:
|
||||||
|
candidates: list[str] = []
|
||||||
|
|
||||||
|
code_block_match = re.search(r"```json\s*([\s\S]*?)\s*```", text, re.IGNORECASE)
|
||||||
|
if code_block_match:
|
||||||
|
candidates.append(code_block_match.group(1).strip())
|
||||||
|
|
||||||
|
fence_match = re.search(r"```\s*([\s\S]*?)\s*```", text)
|
||||||
|
if fence_match:
|
||||||
|
candidates.append(fence_match.group(1).strip())
|
||||||
|
|
||||||
|
brace_match = re.search(r"\{[\s\S]*\}", text)
|
||||||
|
if brace_match:
|
||||||
|
candidates.append(brace_match.group(0).strip())
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
if not candidate:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
obj = json.loads(candidate)
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return obj
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
fixed = candidate.replace("'", '"')
|
||||||
|
try:
|
||||||
|
obj = json.loads(fixed)
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return obj
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate_content(text: str) -> str:
|
||||||
|
if len(text) <= _MAX_CONTENT_CHARS * 2:
|
||||||
|
return text
|
||||||
|
head = text[:_MAX_CONTENT_CHARS]
|
||||||
|
tail = text[-_MAX_CONTENT_CHARS:]
|
||||||
|
return f"{head}\n\n...\n\n{tail}"
|
||||||
|
|
||||||
|
|
||||||
|
class KBMetadataInferenceService:
|
||||||
|
"""Infer document metadata based on markdown content."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession, max_tokens: int = 512, temperature: float = 0.2):
|
||||||
|
self._session = session
|
||||||
|
self._max_tokens = max_tokens
|
||||||
|
self._temperature = temperature
|
||||||
|
|
||||||
|
async def infer_metadata(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
content: str,
|
||||||
|
filename: str | None = None,
|
||||||
|
kb_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
field_def_service = MetadataFieldDefinitionService(self._session)
|
||||||
|
field_defs = await field_def_service.get_active_field_definitions(tenant_id, "kb_document")
|
||||||
|
|
||||||
|
if not field_defs:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
discovery_tool = MetadataDiscoveryTool(self._session)
|
||||||
|
discovery_result = await discovery_tool.execute(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
kb_id=kb_id,
|
||||||
|
include_values=True,
|
||||||
|
top_n=5,
|
||||||
|
)
|
||||||
|
common_values_map = {
|
||||||
|
field.field_key: field.common_values
|
||||||
|
for field in (discovery_result.fields if discovery_result.success else [])
|
||||||
|
}
|
||||||
|
|
||||||
|
fields_payload = []
|
||||||
|
for field_def in field_defs:
|
||||||
|
fields_payload.append({
|
||||||
|
"field_key": field_def.field_key,
|
||||||
|
"label": field_def.label,
|
||||||
|
"type": field_def.type,
|
||||||
|
"required": field_def.required,
|
||||||
|
"options": field_def.options or [],
|
||||||
|
"common_values": common_values_map.get(field_def.field_key, []),
|
||||||
|
})
|
||||||
|
|
||||||
|
prompt = f"""你是知识库文档的元数据补全助手。请根据给定的 Markdown 内容,为文档补全元数据。
|
||||||
|
|
||||||
|
要求:
|
||||||
|
1) 只能使用提供的字段;
|
||||||
|
2) 如果字段有 options 或 common_values,优先从中选择最匹配的值;
|
||||||
|
3) 不确定的字段不要填写;
|
||||||
|
4) 输出必须是严格 JSON 对象,只包含推断出的字段;
|
||||||
|
5) 不要输出多余说明。
|
||||||
|
|
||||||
|
可用字段定义(JSON 数组):
|
||||||
|
{json.dumps(fields_payload, ensure_ascii=False)}
|
||||||
|
|
||||||
|
文件名:{filename or "unknown"}
|
||||||
|
|
||||||
|
Markdown 内容:
|
||||||
|
{_truncate_content(content)}
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm_manager = get_llm_config_manager()
|
||||||
|
llm_client = llm_manager.get_client(LLMUsageType.KB_PROCESSING)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AC-IDSMETA-XX] Failed to get LLM client: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await llm_client.generate(
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "你是严格的 JSON 生成器。"},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
],
|
||||||
|
config=LLMConfig(
|
||||||
|
max_tokens=self._max_tokens,
|
||||||
|
temperature=self._temperature,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AC-IDSMETA-XX] Metadata inference failed: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if not response.content:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
inferred = _extract_json_object(response.content)
|
||||||
|
if not inferred:
|
||||||
|
logger.warning("[AC-IDSMETA-XX] Metadata inference returned no JSON")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
cleaned = self._clean_inferred_values(inferred, field_defs)
|
||||||
|
if not cleaned:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
|
||||||
|
tenant_id, cleaned, "kb_document"
|
||||||
|
)
|
||||||
|
if not is_valid:
|
||||||
|
logger.warning(f"[AC-IDSMETA-XX] Inferred metadata validation failed: {validation_errors}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
def _clean_inferred_values(
|
||||||
|
self,
|
||||||
|
inferred: dict[str, Any],
|
||||||
|
field_defs: list[Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
field_map = {f.field_key: f for f in field_defs}
|
||||||
|
cleaned: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for key, value in inferred.items():
|
||||||
|
field_def = field_map.get(key)
|
||||||
|
if not field_def:
|
||||||
|
continue
|
||||||
|
if value is None or value == "":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if field_def.type in {"enum", "array_enum"} and field_def.options:
|
||||||
|
if field_def.type == "enum":
|
||||||
|
if value not in field_def.options:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
if not isinstance(value, list):
|
||||||
|
continue
|
||||||
|
filtered = [v for v in value if v in field_def.options]
|
||||||
|
if not filtered:
|
||||||
|
continue
|
||||||
|
value = filtered
|
||||||
|
|
||||||
|
cleaned[key] = value
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
@ -0,0 +1,496 @@
|
||||||
|
"""
|
||||||
|
Metadata Auto Inference Service.
|
||||||
|
自动推断文档元数据的服务,支持图片和文本格式。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.services.llm.factory import get_llm_config_manager
|
||||||
|
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
|
||||||
|
from app.services.metadata_cache_service import get_metadata_cache_service
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_field_definitions_cache: dict[str, list[Any]] = {}
|
||||||
|
_cache_ttl_seconds = 300
|
||||||
|
_last_cache_refresh: dict[str, float] = {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InferenceFieldContext:
|
||||||
|
"""推断字段的上下文信息"""
|
||||||
|
field_key: str
|
||||||
|
label: str
|
||||||
|
type: str
|
||||||
|
required: bool
|
||||||
|
options: list[str] | None = None
|
||||||
|
description: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AutoInferenceResult:
|
||||||
|
"""自动推断结果"""
|
||||||
|
inferred_metadata: dict[str, Any]
|
||||||
|
confidence_scores: dict[str, float]
|
||||||
|
raw_response: str
|
||||||
|
success: bool
|
||||||
|
error_message: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
METADATA_INFERENCE_SYSTEM_PROMPT = """你是一个专业的文档元数据分析助手。你的任务是根据文档内容,自动推断并填写元数据字段。
|
||||||
|
|
||||||
|
## 输出要求
|
||||||
|
请严格按照以下 JSON 格式输出,不要添加任何其他内容:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"inferred_metadata": {
|
||||||
|
"字段键名1": "推断的值1",
|
||||||
|
"字段键名2": "推断的值2"
|
||||||
|
},
|
||||||
|
"confidence_scores": {
|
||||||
|
"字段键名1": 0.95,
|
||||||
|
"字段键名2": 0.80
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 推断规则
|
||||||
|
1. **仔细分析文档内容**:根据文档的主题、关键词、上下文来推断元数据
|
||||||
|
2. **遵循字段定义**:
|
||||||
|
- 对于枚举类型(enum),必须从给定的选项中选择
|
||||||
|
- 对于数组枚举类型(array_enum),可以选择多个选项
|
||||||
|
- 对于数字类型(number),输出数字
|
||||||
|
- 对于布尔类型(boolean),输出 true 或 false
|
||||||
|
- 对于文本类型(text),输出字符串
|
||||||
|
3. **置信度评分**:
|
||||||
|
- 0.9-1.0: 非常确定
|
||||||
|
- 0.7-0.9: 比较确定
|
||||||
|
- 0.5-0.7: 有一定依据但不确定
|
||||||
|
- 0.0-0.5: 猜测或无法确定
|
||||||
|
4. **无法推断时**:如果无法从文档内容中合理推断某个字段,可以不填写该字段
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
- 必须严格按照字段定义的类型和选项填写
|
||||||
|
- 不要编造不存在的选项值
|
||||||
|
- 保持客观,基于文档内容推断"""
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataAutoInferenceService:
|
||||||
|
"""
|
||||||
|
元数据自动推断服务。
|
||||||
|
|
||||||
|
功能:
|
||||||
|
1. 获取租户配置的元数据字段定义
|
||||||
|
2. 使用 LLM 根据文档内容自动推断元数据
|
||||||
|
3. 验证推断结果符合字段定义
|
||||||
|
|
||||||
|
使用场景:
|
||||||
|
- 图片上传时自动推断元数据
|
||||||
|
- Markdown/文本上传时自动推断元数据
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
model: str | None = None,
|
||||||
|
max_tokens: int = 1024,
|
||||||
|
timeout_seconds: int = 60,
|
||||||
|
):
|
||||||
|
self._session = session
|
||||||
|
self._model = model
|
||||||
|
self._max_tokens = max_tokens
|
||||||
|
self._timeout_seconds = timeout_seconds
|
||||||
|
self._field_def_service = MetadataFieldDefinitionService(session)
|
||||||
|
|
||||||
|
async def infer_metadata(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
content: str,
|
||||||
|
scope: str = "kb_document",
|
||||||
|
existing_metadata: dict[str, Any] | None = None,
|
||||||
|
image_base64: str | None = None,
|
||||||
|
mime_type: str | None = None,
|
||||||
|
) -> AutoInferenceResult:
|
||||||
|
"""
|
||||||
|
自动推断文档元数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: 租户 ID
|
||||||
|
content: 文档文本内容
|
||||||
|
scope: 元数据作用范围
|
||||||
|
existing_metadata: 已有的元数据(用户手动填写的,会覆盖推断结果)
|
||||||
|
image_base64: 图片的 base64 编码(如果是图片)
|
||||||
|
mime_type: 图片的 MIME 类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AutoInferenceResult 包含推断的元数据
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
f"[MetadataAutoInference] Starting inference: tenant={tenant_id}, "
|
||||||
|
f"content_length={len(content)}, scope={scope}"
|
||||||
|
)
|
||||||
|
|
||||||
|
field_definitions = await self._get_field_definitions_with_cache(tenant_id, scope)
|
||||||
|
|
||||||
|
if not field_definitions:
|
||||||
|
logger.info(f"[MetadataAutoInference] No field definitions found for tenant={tenant_id}")
|
||||||
|
return AutoInferenceResult(
|
||||||
|
inferred_metadata=existing_metadata or {},
|
||||||
|
confidence_scores={},
|
||||||
|
raw_response="",
|
||||||
|
success=True,
|
||||||
|
error_message="No field definitions configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
field_contexts = self._build_field_contexts(field_definitions)
|
||||||
|
|
||||||
|
if not field_contexts:
|
||||||
|
return AutoInferenceResult(
|
||||||
|
inferred_metadata=existing_metadata or {},
|
||||||
|
confidence_scores={},
|
||||||
|
raw_response="",
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_prompt = self._build_user_prompt(content, field_contexts, existing_metadata)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if image_base64 and mime_type:
|
||||||
|
raw_response = await self._call_multimodal_llm(
|
||||||
|
user_prompt, image_base64, mime_type
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw_response = await self._call_text_llm(user_prompt)
|
||||||
|
|
||||||
|
result = self._parse_llm_response(raw_response, field_contexts)
|
||||||
|
|
||||||
|
if existing_metadata:
|
||||||
|
result.inferred_metadata.update(existing_metadata)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[MetadataAutoInference] Inference completed: "
|
||||||
|
f"inferred_fields={list(result.inferred_metadata.keys())}, "
|
||||||
|
f"avg_confidence={sum(result.confidence_scores.values()) / len(result.confidence_scores) if result.confidence_scores else 0:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[MetadataAutoInference] Inference failed: {e}")
|
||||||
|
return AutoInferenceResult(
|
||||||
|
inferred_metadata=existing_metadata or {},
|
||||||
|
confidence_scores={},
|
||||||
|
raw_response="",
|
||||||
|
success=False,
|
||||||
|
error_message=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_field_contexts(
|
||||||
|
self,
|
||||||
|
field_definitions: list[Any],
|
||||||
|
) -> list[InferenceFieldContext]:
|
||||||
|
"""构建字段上下文列表"""
|
||||||
|
contexts = []
|
||||||
|
for f in field_definitions:
|
||||||
|
ctx = InferenceFieldContext(
|
||||||
|
field_key=f.field_key,
|
||||||
|
label=f.label,
|
||||||
|
type=f.type,
|
||||||
|
required=f.required,
|
||||||
|
options=f.options,
|
||||||
|
description=getattr(f, 'description', None),
|
||||||
|
)
|
||||||
|
contexts.append(ctx)
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
async def _get_field_definitions_with_cache(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
scope: str,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""
|
||||||
|
获取字段定义(带缓存)。
|
||||||
|
|
||||||
|
优先级:
|
||||||
|
1. Redis 缓存
|
||||||
|
2. 本地内存缓存
|
||||||
|
3. 数据库查询
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: 租户 ID
|
||||||
|
scope: 作用范围
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字段定义列表
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
cache_key = f"{tenant_id}:{scope}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
redis_cache = await get_metadata_cache_service()
|
||||||
|
cached_fields = await redis_cache.get_fields(tenant_id)
|
||||||
|
|
||||||
|
if cached_fields:
|
||||||
|
logger.info(f"[MetadataAutoInference] Redis cache hit for tenant={tenant_id}")
|
||||||
|
return [self._dict_to_field_def(f) for f in cached_fields]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[MetadataAutoInference] Redis cache error: {e}")
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
if cache_key in _field_definitions_cache:
|
||||||
|
last_refresh = _last_cache_refresh.get(cache_key, 0)
|
||||||
|
if current_time - last_refresh < _cache_ttl_seconds:
|
||||||
|
logger.info(f"[MetadataAutoInference] Local cache hit for tenant={tenant_id}")
|
||||||
|
return _field_definitions_cache[cache_key]
|
||||||
|
|
||||||
|
logger.info(f"[MetadataAutoInference] Cache miss, querying database for tenant={tenant_id}")
|
||||||
|
field_definitions = await self._field_def_service.get_active_field_definitions(
|
||||||
|
tenant_id, scope
|
||||||
|
)
|
||||||
|
|
||||||
|
_field_definitions_cache[cache_key] = field_definitions
|
||||||
|
_last_cache_refresh[cache_key] = current_time
|
||||||
|
|
||||||
|
try:
|
||||||
|
redis_cache = await get_metadata_cache_service()
|
||||||
|
await redis_cache.set_fields(
|
||||||
|
tenant_id,
|
||||||
|
[self._field_def_to_dict(f) for f in field_definitions]
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[MetadataAutoInference] Failed to update Redis cache: {e}")
|
||||||
|
|
||||||
|
return field_definitions
|
||||||
|
|
||||||
|
def _field_def_to_dict(self, field_def: Any) -> dict[str, Any]:
|
||||||
|
"""将字段定义转换为字典"""
|
||||||
|
return {
|
||||||
|
"field_key": field_def.field_key,
|
||||||
|
"label": field_def.label,
|
||||||
|
"type": field_def.type,
|
||||||
|
"required": field_def.required,
|
||||||
|
"options": field_def.options,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _dict_to_field_def(self, data: dict[str, Any]) -> Any:
|
||||||
|
"""将字典转换为字段定义对象"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CachedFieldDefinition:
|
||||||
|
field_key: str
|
||||||
|
label: str
|
||||||
|
type: str
|
||||||
|
required: bool
|
||||||
|
options: list[str] | None = None
|
||||||
|
|
||||||
|
return CachedFieldDefinition(
|
||||||
|
field_key=data["field_key"],
|
||||||
|
label=data["label"],
|
||||||
|
type=data["type"],
|
||||||
|
required=data["required"],
|
||||||
|
options=data.get("options"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_user_prompt(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
field_contexts: list[InferenceFieldContext],
|
||||||
|
existing_metadata: dict[str, Any] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""构建用户提示词"""
|
||||||
|
field_descriptions = []
|
||||||
|
for ctx in field_contexts:
|
||||||
|
desc = f"- **{ctx.label}** ({ctx.field_key})"
|
||||||
|
desc += f"\n - 类型: {ctx.type}"
|
||||||
|
desc += f"\n - 必填: {'是' if ctx.required else '否'}"
|
||||||
|
if ctx.options:
|
||||||
|
desc += f"\n - 可选值: {', '.join(ctx.options)}"
|
||||||
|
if existing_metadata and ctx.field_key in existing_metadata:
|
||||||
|
desc += f"\n - 已有值: {existing_metadata[ctx.field_key]}"
|
||||||
|
field_descriptions.append(desc)
|
||||||
|
|
||||||
|
fields_text = "\n".join(field_descriptions)
|
||||||
|
|
||||||
|
prompt = f"""请分析以下文档内容,并推断相应的元数据字段。
|
||||||
|
|
||||||
|
## 待推断的字段定义
|
||||||
|
{fields_text}
|
||||||
|
|
||||||
|
## 文档内容
|
||||||
|
{content[:4000]}
|
||||||
|
|
||||||
|
请根据文档内容推断上述字段的值,并输出 JSON 格式的结果。"""
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
async def _call_text_llm(self, prompt: str) -> str:
|
||||||
|
"""调用文本 LLM"""
|
||||||
|
manager = get_llm_config_manager()
|
||||||
|
client = manager.get_kb_processing_client()
|
||||||
|
|
||||||
|
config = manager.kb_processing_config
|
||||||
|
model = self._model or config.get("model", "gpt-4o-mini")
|
||||||
|
|
||||||
|
from app.services.llm.base import LLMConfig
|
||||||
|
|
||||||
|
llm_config = LLMConfig(
|
||||||
|
model=model,
|
||||||
|
max_tokens=self._max_tokens,
|
||||||
|
temperature=0.3,
|
||||||
|
timeout_seconds=self._timeout_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": METADATA_INFERENCE_SYSTEM_PROMPT},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await client.generate(messages=messages, config=llm_config)
|
||||||
|
return response.content or ""
|
||||||
|
|
||||||
|
async def _call_multimodal_llm(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
image_base64: str,
|
||||||
|
mime_type: str,
|
||||||
|
) -> str:
|
||||||
|
"""调用多模态 LLM"""
|
||||||
|
manager = get_llm_config_manager()
|
||||||
|
client = manager.get_kb_processing_client()
|
||||||
|
|
||||||
|
config = manager.kb_processing_config
|
||||||
|
model = self._model or config.get("model", "gpt-4o-mini")
|
||||||
|
|
||||||
|
from app.services.llm.base import LLMConfig
|
||||||
|
|
||||||
|
llm_config = LLMConfig(
|
||||||
|
model=model,
|
||||||
|
max_tokens=self._max_tokens,
|
||||||
|
temperature=0.3,
|
||||||
|
timeout_seconds=self._timeout_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": METADATA_INFERENCE_SYSTEM_PROMPT},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": prompt},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:{mime_type};base64,{image_base64}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await client.generate(messages=messages, config=llm_config)
|
||||||
|
return response.content or ""
|
||||||
|
|
||||||
|
def _parse_llm_response(
|
||||||
|
self,
|
||||||
|
response: str,
|
||||||
|
field_contexts: list[InferenceFieldContext],
|
||||||
|
) -> AutoInferenceResult:
|
||||||
|
"""解析 LLM 响应"""
|
||||||
|
try:
|
||||||
|
json_str = self._extract_json(response)
|
||||||
|
data = json.loads(json_str)
|
||||||
|
|
||||||
|
inferred_metadata = data.get("inferred_metadata", {})
|
||||||
|
confidence_scores = data.get("confidence_scores", {})
|
||||||
|
|
||||||
|
field_map = {ctx.field_key: ctx for ctx in field_contexts}
|
||||||
|
validated_metadata = {}
|
||||||
|
validated_scores = {}
|
||||||
|
|
||||||
|
for field_key, value in inferred_metadata.items():
|
||||||
|
if field_key not in field_map:
|
||||||
|
continue
|
||||||
|
|
||||||
|
ctx = field_map[field_key]
|
||||||
|
validated_value = self._validate_field_value(ctx, value)
|
||||||
|
|
||||||
|
if validated_value is not None:
|
||||||
|
validated_metadata[field_key] = validated_value
|
||||||
|
validated_scores[field_key] = confidence_scores.get(field_key, 0.5)
|
||||||
|
|
||||||
|
return AutoInferenceResult(
|
||||||
|
inferred_metadata=validated_metadata,
|
||||||
|
confidence_scores=validated_scores,
|
||||||
|
raw_response=response,
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"[MetadataAutoInference] Failed to parse JSON: {e}")
|
||||||
|
return AutoInferenceResult(
|
||||||
|
inferred_metadata={},
|
||||||
|
confidence_scores={},
|
||||||
|
raw_response=response,
|
||||||
|
success=False,
|
||||||
|
error_message=f"JSON parse error: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_field_value(
|
||||||
|
self,
|
||||||
|
ctx: InferenceFieldContext,
|
||||||
|
value: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""验证并转换字段值"""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
from app.models.entities import MetadataFieldType
|
||||||
|
|
||||||
|
if ctx.type == MetadataFieldType.NUMBER.value:
|
||||||
|
try:
|
||||||
|
return float(value) if isinstance(value, str) else value
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
elif ctx.type == MetadataFieldType.BOOLEAN.value:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value.lower() in ("true", "1", "yes")
|
||||||
|
return bool(value)
|
||||||
|
|
||||||
|
elif ctx.type == MetadataFieldType.ENUM.value:
|
||||||
|
if ctx.options and value in ctx.options:
|
||||||
|
return value
|
||||||
|
return None
|
||||||
|
|
||||||
|
elif ctx.type == MetadataFieldType.ARRAY_ENUM.value:
|
||||||
|
if not isinstance(value, list):
|
||||||
|
value = [value] if value else []
|
||||||
|
if ctx.options:
|
||||||
|
return [v for v in value if v in ctx.options]
|
||||||
|
return value
|
||||||
|
|
||||||
|
else:
|
||||||
|
return str(value) if value is not None else None
|
||||||
|
|
||||||
|
def _extract_json(self, content: str) -> str:
|
||||||
|
"""从响应中提取 JSON"""
|
||||||
|
content = content.strip()
|
||||||
|
|
||||||
|
if content.startswith("{") and content.endswith("}"):
|
||||||
|
return content
|
||||||
|
|
||||||
|
json_start = content.find("{")
|
||||||
|
json_end = content.rfind("}")
|
||||||
|
|
||||||
|
if json_start != -1 and json_end != -1 and json_end > json_start:
|
||||||
|
return content[json_start:json_end + 1]
|
||||||
|
|
||||||
|
return content
|
||||||
Loading…
Reference in New Issue