From 9196247578c15f5b11e5b1211e063a4dad993e93 Mon Sep 17 00:00:00 2001 From: MerCry Date: Wed, 11 Mar 2026 19:06:21 +0800 Subject: [PATCH] =?UTF-8?q?[AC-METADATA-INFERENCE]=20feat(metadata):=20?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=E5=85=83=E6=95=B0=E6=8D=AE=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=8E=A8=E6=96=AD=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 metadata_auto_inference_service 实现元数据自动推断 - 新增 kb_metadata_inference 提供知识库元数据推断工具 - 支持从文档内容自动提取元数据字段 - 集成缓存机制提升推断效率 --- .../app/services/kb_metadata_inference.py | 202 +++++++ .../metadata_auto_inference_service.py | 496 ++++++++++++++++++ 2 files changed, 698 insertions(+) create mode 100644 ai-service/app/services/kb_metadata_inference.py create mode 100644 ai-service/app/services/metadata_auto_inference_service.py diff --git a/ai-service/app/services/kb_metadata_inference.py b/ai-service/app/services/kb_metadata_inference.py new file mode 100644 index 0000000..b1335c7 --- /dev/null +++ b/ai-service/app/services/kb_metadata_inference.py @@ -0,0 +1,202 @@ +""" +KB document metadata inference using LLM. +[AC-IDSMETA-XX] Infer metadata for markdown uploads when missing. +""" + +from __future__ import annotations + +import json +import logging +import re +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.services.llm.base import LLMConfig +from app.services.llm.factory import LLMUsageType, get_llm_config_manager +from app.services.metadata_field_definition_service import MetadataFieldDefinitionService +from app.services.mid.metadata_discovery_tool import MetadataDiscoveryTool + +logger = logging.getLogger(__name__) + + +_MAX_CONTENT_CHARS = 2000 + + +def _extract_json_object(text: str) -> dict[str, Any] | None: + candidates: list[str] = [] + + code_block_match = re.search(r"```json\s*([\s\S]*?)\s*```", text, re.IGNORECASE) + if code_block_match: + candidates.append(code_block_match.group(1).strip()) + + fence_match = re.search(r"```\s*([\s\S]*?)\s*```", text) + if fence_match: + candidates.append(fence_match.group(1).strip()) + + brace_match = re.search(r"\{[\s\S]*\}", text) + if brace_match: + candidates.append(brace_match.group(0).strip()) + + for candidate in candidates: + if not candidate: + continue + try: + obj = json.loads(candidate) + if isinstance(obj, dict): + return obj + except json.JSONDecodeError: + fixed = candidate.replace("'", '"') + try: + obj = json.loads(fixed) + if isinstance(obj, dict): + return obj + except json.JSONDecodeError: + continue + + return None + + +def _truncate_content(text: str) -> str: + if len(text) <= _MAX_CONTENT_CHARS * 2: + return text + head = text[:_MAX_CONTENT_CHARS] + tail = text[-_MAX_CONTENT_CHARS:] + return f"{head}\n\n...\n\n{tail}" + + +class KBMetadataInferenceService: + """Infer document metadata based on markdown content.""" + + def __init__(self, session: AsyncSession, max_tokens: int = 512, temperature: float = 0.2): + self._session = session + self._max_tokens = max_tokens + self._temperature = temperature + + async def infer_metadata( + self, + tenant_id: str, + content: str, + filename: str | None = None, + kb_id: str | None = None, + ) -> dict[str, Any]: + field_def_service = MetadataFieldDefinitionService(self._session) + field_defs = await field_def_service.get_active_field_definitions(tenant_id, "kb_document") + + if not field_defs: + return {} + + discovery_tool = MetadataDiscoveryTool(self._session) + discovery_result = await discovery_tool.execute( + tenant_id=tenant_id, + kb_id=kb_id, + include_values=True, + top_n=5, + ) + common_values_map = { + field.field_key: field.common_values + for field in (discovery_result.fields if discovery_result.success else []) + } + + fields_payload = [] + for field_def in field_defs: + fields_payload.append({ + "field_key": field_def.field_key, + "label": field_def.label, + "type": field_def.type, + "required": field_def.required, + "options": field_def.options or [], + "common_values": common_values_map.get(field_def.field_key, []), + }) + + prompt = f"""你是知识库文档的元数据补全助手。请根据给定的 Markdown 内容,为文档补全元数据。 + +要求: +1) 只能使用提供的字段; +2) 如果字段有 options 或 common_values,优先从中选择最匹配的值; +3) 不确定的字段不要填写; +4) 输出必须是严格 JSON 对象,只包含推断出的字段; +5) 不要输出多余说明。 + +可用字段定义(JSON 数组): +{json.dumps(fields_payload, ensure_ascii=False)} + +文件名:{filename or "unknown"} + +Markdown 内容: +{_truncate_content(content)} +""".strip() + + try: + llm_manager = get_llm_config_manager() + llm_client = llm_manager.get_client(LLMUsageType.KB_PROCESSING) + except Exception as e: + logger.warning(f"[AC-IDSMETA-XX] Failed to get LLM client: {e}") + return {} + + try: + response = await llm_client.generate( + messages=[ + {"role": "system", "content": "你是严格的 JSON 生成器。"}, + {"role": "user", "content": prompt}, + ], + config=LLMConfig( + max_tokens=self._max_tokens, + temperature=self._temperature, + ), + ) + except Exception as e: + logger.warning(f"[AC-IDSMETA-XX] Metadata inference failed: {e}") + return {} + + if not response.content: + return {} + + inferred = _extract_json_object(response.content) + if not inferred: + logger.warning("[AC-IDSMETA-XX] Metadata inference returned no JSON") + return {} + + cleaned = self._clean_inferred_values(inferred, field_defs) + if not cleaned: + return {} + + is_valid, validation_errors = await field_def_service.validate_metadata_for_create( + tenant_id, cleaned, "kb_document" + ) + if not is_valid: + logger.warning(f"[AC-IDSMETA-XX] Inferred metadata validation failed: {validation_errors}") + return {} + + return cleaned + + def _clean_inferred_values( + self, + inferred: dict[str, Any], + field_defs: list[Any], + ) -> dict[str, Any]: + field_map = {f.field_key: f for f in field_defs} + cleaned: dict[str, Any] = {} + + for key, value in inferred.items(): + field_def = field_map.get(key) + if not field_def: + continue + if value is None or value == "": + continue + + if field_def.type in {"enum", "array_enum"} and field_def.options: + if field_def.type == "enum": + if value not in field_def.options: + continue + else: + if not isinstance(value, list): + continue + filtered = [v for v in value if v in field_def.options] + if not filtered: + continue + value = filtered + + cleaned[key] = value + + return cleaned diff --git a/ai-service/app/services/metadata_auto_inference_service.py b/ai-service/app/services/metadata_auto_inference_service.py new file mode 100644 index 0000000..cbd0778 --- /dev/null +++ b/ai-service/app/services/metadata_auto_inference_service.py @@ -0,0 +1,496 @@ +""" +Metadata Auto Inference Service. +自动推断文档元数据的服务,支持图片和文本格式。 +""" + +import json +import logging +from dataclasses import dataclass, field +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.services.llm.factory import get_llm_config_manager +from app.services.metadata_field_definition_service import MetadataFieldDefinitionService +from app.services.metadata_cache_service import get_metadata_cache_service + +logger = logging.getLogger(__name__) + +_field_definitions_cache: dict[str, list[Any]] = {} +_cache_ttl_seconds = 300 +_last_cache_refresh: dict[str, float] = {} + + +@dataclass +class InferenceFieldContext: + """推断字段的上下文信息""" + field_key: str + label: str + type: str + required: bool + options: list[str] | None = None + description: str | None = None + + +@dataclass +class AutoInferenceResult: + """自动推断结果""" + inferred_metadata: dict[str, Any] + confidence_scores: dict[str, float] + raw_response: str + success: bool + error_message: str | None = None + + +METADATA_INFERENCE_SYSTEM_PROMPT = """你是一个专业的文档元数据分析助手。你的任务是根据文档内容,自动推断并填写元数据字段。 + +## 输出要求 +请严格按照以下 JSON 格式输出,不要添加任何其他内容: + +```json +{ + "inferred_metadata": { + "字段键名1": "推断的值1", + "字段键名2": "推断的值2" + }, + "confidence_scores": { + "字段键名1": 0.95, + "字段键名2": 0.80 + } +} +``` + +## 推断规则 +1. **仔细分析文档内容**:根据文档的主题、关键词、上下文来推断元数据 +2. **遵循字段定义**: + - 对于枚举类型(enum),必须从给定的选项中选择 + - 对于数组枚举类型(array_enum),可以选择多个选项 + - 对于数字类型(number),输出数字 + - 对于布尔类型(boolean),输出 true 或 false + - 对于文本类型(text),输出字符串 +3. **置信度评分**: + - 0.9-1.0: 非常确定 + - 0.7-0.9: 比较确定 + - 0.5-0.7: 有一定依据但不确定 + - 0.0-0.5: 猜测或无法确定 +4. **无法推断时**:如果无法从文档内容中合理推断某个字段,可以不填写该字段 + +## 注意事项 +- 必须严格按照字段定义的类型和选项填写 +- 不要编造不存在的选项值 +- 保持客观,基于文档内容推断""" + + +class MetadataAutoInferenceService: + """ + 元数据自动推断服务。 + + 功能: + 1. 获取租户配置的元数据字段定义 + 2. 使用 LLM 根据文档内容自动推断元数据 + 3. 验证推断结果符合字段定义 + + 使用场景: + - 图片上传时自动推断元数据 + - Markdown/文本上传时自动推断元数据 + """ + + def __init__( + self, + session: AsyncSession, + model: str | None = None, + max_tokens: int = 1024, + timeout_seconds: int = 60, + ): + self._session = session + self._model = model + self._max_tokens = max_tokens + self._timeout_seconds = timeout_seconds + self._field_def_service = MetadataFieldDefinitionService(session) + + async def infer_metadata( + self, + tenant_id: str, + content: str, + scope: str = "kb_document", + existing_metadata: dict[str, Any] | None = None, + image_base64: str | None = None, + mime_type: str | None = None, + ) -> AutoInferenceResult: + """ + 自动推断文档元数据。 + + Args: + tenant_id: 租户 ID + content: 文档文本内容 + scope: 元数据作用范围 + existing_metadata: 已有的元数据(用户手动填写的,会覆盖推断结果) + image_base64: 图片的 base64 编码(如果是图片) + mime_type: 图片的 MIME 类型 + + Returns: + AutoInferenceResult 包含推断的元数据 + """ + logger.info( + f"[MetadataAutoInference] Starting inference: tenant={tenant_id}, " + f"content_length={len(content)}, scope={scope}" + ) + + field_definitions = await self._get_field_definitions_with_cache(tenant_id, scope) + + if not field_definitions: + logger.info(f"[MetadataAutoInference] No field definitions found for tenant={tenant_id}") + return AutoInferenceResult( + inferred_metadata=existing_metadata or {}, + confidence_scores={}, + raw_response="", + success=True, + error_message="No field definitions configured", + ) + + field_contexts = self._build_field_contexts(field_definitions) + + if not field_contexts: + return AutoInferenceResult( + inferred_metadata=existing_metadata or {}, + confidence_scores={}, + raw_response="", + success=True, + ) + + user_prompt = self._build_user_prompt(content, field_contexts, existing_metadata) + + try: + if image_base64 and mime_type: + raw_response = await self._call_multimodal_llm( + user_prompt, image_base64, mime_type + ) + else: + raw_response = await self._call_text_llm(user_prompt) + + result = self._parse_llm_response(raw_response, field_contexts) + + if existing_metadata: + result.inferred_metadata.update(existing_metadata) + + logger.info( + f"[MetadataAutoInference] Inference completed: " + f"inferred_fields={list(result.inferred_metadata.keys())}, " + f"avg_confidence={sum(result.confidence_scores.values()) / len(result.confidence_scores) if result.confidence_scores else 0:.2f}" + ) + + return result + + except Exception as e: + logger.error(f"[MetadataAutoInference] Inference failed: {e}") + return AutoInferenceResult( + inferred_metadata=existing_metadata or {}, + confidence_scores={}, + raw_response="", + success=False, + error_message=str(e), + ) + + def _build_field_contexts( + self, + field_definitions: list[Any], + ) -> list[InferenceFieldContext]: + """构建字段上下文列表""" + contexts = [] + for f in field_definitions: + ctx = InferenceFieldContext( + field_key=f.field_key, + label=f.label, + type=f.type, + required=f.required, + options=f.options, + description=getattr(f, 'description', None), + ) + contexts.append(ctx) + return contexts + + async def _get_field_definitions_with_cache( + self, + tenant_id: str, + scope: str, + ) -> list[Any]: + """ + 获取字段定义(带缓存)。 + + 优先级: + 1. Redis 缓存 + 2. 本地内存缓存 + 3. 数据库查询 + + Args: + tenant_id: 租户 ID + scope: 作用范围 + + Returns: + 字段定义列表 + """ + import time + cache_key = f"{tenant_id}:{scope}" + + try: + redis_cache = await get_metadata_cache_service() + cached_fields = await redis_cache.get_fields(tenant_id) + + if cached_fields: + logger.info(f"[MetadataAutoInference] Redis cache hit for tenant={tenant_id}") + return [self._dict_to_field_def(f) for f in cached_fields] + except Exception as e: + logger.warning(f"[MetadataAutoInference] Redis cache error: {e}") + + current_time = time.time() + if cache_key in _field_definitions_cache: + last_refresh = _last_cache_refresh.get(cache_key, 0) + if current_time - last_refresh < _cache_ttl_seconds: + logger.info(f"[MetadataAutoInference] Local cache hit for tenant={tenant_id}") + return _field_definitions_cache[cache_key] + + logger.info(f"[MetadataAutoInference] Cache miss, querying database for tenant={tenant_id}") + field_definitions = await self._field_def_service.get_active_field_definitions( + tenant_id, scope + ) + + _field_definitions_cache[cache_key] = field_definitions + _last_cache_refresh[cache_key] = current_time + + try: + redis_cache = await get_metadata_cache_service() + await redis_cache.set_fields( + tenant_id, + [self._field_def_to_dict(f) for f in field_definitions] + ) + except Exception as e: + logger.warning(f"[MetadataAutoInference] Failed to update Redis cache: {e}") + + return field_definitions + + def _field_def_to_dict(self, field_def: Any) -> dict[str, Any]: + """将字段定义转换为字典""" + return { + "field_key": field_def.field_key, + "label": field_def.label, + "type": field_def.type, + "required": field_def.required, + "options": field_def.options, + } + + def _dict_to_field_def(self, data: dict[str, Any]) -> Any: + """将字典转换为字段定义对象""" + from dataclasses import dataclass + + @dataclass + class CachedFieldDefinition: + field_key: str + label: str + type: str + required: bool + options: list[str] | None = None + + return CachedFieldDefinition( + field_key=data["field_key"], + label=data["label"], + type=data["type"], + required=data["required"], + options=data.get("options"), + ) + + def _build_user_prompt( + self, + content: str, + field_contexts: list[InferenceFieldContext], + existing_metadata: dict[str, Any] | None = None, + ) -> str: + """构建用户提示词""" + field_descriptions = [] + for ctx in field_contexts: + desc = f"- **{ctx.label}** ({ctx.field_key})" + desc += f"\n - 类型: {ctx.type}" + desc += f"\n - 必填: {'是' if ctx.required else '否'}" + if ctx.options: + desc += f"\n - 可选值: {', '.join(ctx.options)}" + if existing_metadata and ctx.field_key in existing_metadata: + desc += f"\n - 已有值: {existing_metadata[ctx.field_key]}" + field_descriptions.append(desc) + + fields_text = "\n".join(field_descriptions) + + prompt = f"""请分析以下文档内容,并推断相应的元数据字段。 + +## 待推断的字段定义 +{fields_text} + +## 文档内容 +{content[:4000]} + +请根据文档内容推断上述字段的值,并输出 JSON 格式的结果。""" + + return prompt + + async def _call_text_llm(self, prompt: str) -> str: + """调用文本 LLM""" + manager = get_llm_config_manager() + client = manager.get_kb_processing_client() + + config = manager.kb_processing_config + model = self._model or config.get("model", "gpt-4o-mini") + + from app.services.llm.base import LLMConfig + + llm_config = LLMConfig( + model=model, + max_tokens=self._max_tokens, + temperature=0.3, + timeout_seconds=self._timeout_seconds, + ) + + messages = [ + {"role": "system", "content": METADATA_INFERENCE_SYSTEM_PROMPT}, + {"role": "user", "content": prompt}, + ] + + response = await client.generate(messages=messages, config=llm_config) + return response.content or "" + + async def _call_multimodal_llm( + self, + prompt: str, + image_base64: str, + mime_type: str, + ) -> str: + """调用多模态 LLM""" + manager = get_llm_config_manager() + client = manager.get_kb_processing_client() + + config = manager.kb_processing_config + model = self._model or config.get("model", "gpt-4o-mini") + + from app.services.llm.base import LLMConfig + + llm_config = LLMConfig( + model=model, + max_tokens=self._max_tokens, + temperature=0.3, + timeout_seconds=self._timeout_seconds, + ) + + messages = [ + {"role": "system", "content": METADATA_INFERENCE_SYSTEM_PROMPT}, + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{image_base64}", + }, + }, + ], + }, + ] + + response = await client.generate(messages=messages, config=llm_config) + return response.content or "" + + def _parse_llm_response( + self, + response: str, + field_contexts: list[InferenceFieldContext], + ) -> AutoInferenceResult: + """解析 LLM 响应""" + try: + json_str = self._extract_json(response) + data = json.loads(json_str) + + inferred_metadata = data.get("inferred_metadata", {}) + confidence_scores = data.get("confidence_scores", {}) + + field_map = {ctx.field_key: ctx for ctx in field_contexts} + validated_metadata = {} + validated_scores = {} + + for field_key, value in inferred_metadata.items(): + if field_key not in field_map: + continue + + ctx = field_map[field_key] + validated_value = self._validate_field_value(ctx, value) + + if validated_value is not None: + validated_metadata[field_key] = validated_value + validated_scores[field_key] = confidence_scores.get(field_key, 0.5) + + return AutoInferenceResult( + inferred_metadata=validated_metadata, + confidence_scores=validated_scores, + raw_response=response, + success=True, + ) + + except json.JSONDecodeError as e: + logger.warning(f"[MetadataAutoInference] Failed to parse JSON: {e}") + return AutoInferenceResult( + inferred_metadata={}, + confidence_scores={}, + raw_response=response, + success=False, + error_message=f"JSON parse error: {e}", + ) + + def _validate_field_value( + self, + ctx: InferenceFieldContext, + value: Any, + ) -> Any: + """验证并转换字段值""" + if value is None: + return None + + from app.models.entities import MetadataFieldType + + if ctx.type == MetadataFieldType.NUMBER.value: + try: + return float(value) if isinstance(value, str) else value + except (ValueError, TypeError): + return None + + elif ctx.type == MetadataFieldType.BOOLEAN.value: + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.lower() in ("true", "1", "yes") + return bool(value) + + elif ctx.type == MetadataFieldType.ENUM.value: + if ctx.options and value in ctx.options: + return value + return None + + elif ctx.type == MetadataFieldType.ARRAY_ENUM.value: + if not isinstance(value, list): + value = [value] if value else [] + if ctx.options: + return [v for v in value if v in ctx.options] + return value + + else: + return str(value) if value is not None else None + + def _extract_json(self, content: str) -> str: + """从响应中提取 JSON""" + content = content.strip() + + if content.startswith("{") and content.endswith("}"): + return content + + json_start = content.find("{") + json_end = content.rfind("}") + + if json_start != -1 and json_end != -1 and json_end > json_start: + return content[json_start:json_end + 1] + + return content