feat: update core backend services including LLM, embedding, KB, orchestrator and admin APIs [AC-AISVC-CORE]
This commit is contained in:
parent
759eafb490
commit
fe883cfff0
|
|
@ -2,6 +2,7 @@
|
|||
Admin API routes for AI Service management.
|
||||
[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08, AC-AISVC-50] Admin management endpoints.
|
||||
[AC-MRS-07,08,16] Slot definition management endpoints.
|
||||
[AC-SCENE-SLOT-01] Scene slot bundle management endpoints.
|
||||
"""
|
||||
|
||||
from app.api.admin.api_key import router as api_key_router
|
||||
|
|
@ -18,6 +19,7 @@ from app.api.admin.metadata_schema import router as metadata_schema_router
|
|||
from app.api.admin.monitoring import router as monitoring_router
|
||||
from app.api.admin.prompt_templates import router as prompt_templates_router
|
||||
from app.api.admin.rag import router as rag_router
|
||||
from app.api.admin.scene_slot_bundle import router as scene_slot_bundle_router
|
||||
from app.api.admin.script_flows import router as script_flows_router
|
||||
from app.api.admin.sessions import router as sessions_router
|
||||
from app.api.admin.slot_definition import router as slot_definition_router
|
||||
|
|
@ -38,6 +40,7 @@ __all__ = [
|
|||
"monitoring_router",
|
||||
"prompt_templates_router",
|
||||
"rag_router",
|
||||
"scene_slot_bundle_router",
|
||||
"script_flows_router",
|
||||
"sessions_router",
|
||||
"slot_definition_router",
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
Intent Rule Management API.
|
||||
[AC-AISVC-65~AC-AISVC-68] Intent rule CRUD endpoints.
|
||||
[AC-AISVC-96] Intent rule testing endpoint.
|
||||
[AC-AISVC-116] Fusion config management endpoints.
|
||||
[AC-AISVC-114] Intent vector generation endpoint.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -14,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.core.database import get_session
|
||||
from app.models.entities import IntentRuleCreate, IntentRuleUpdate
|
||||
from app.services.intent.models import DEFAULT_FUSION_CONFIG, FusionConfig
|
||||
from app.services.intent.rule_service import IntentRuleService
|
||||
from app.services.intent.tester import IntentRuleTester
|
||||
|
||||
|
|
@ -21,6 +24,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
router = APIRouter(prefix="/admin/intent-rules", tags=["Intent Rules"])
|
||||
|
||||
_fusion_config = FusionConfig()
|
||||
|
||||
|
||||
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
|
||||
"""Extract tenant ID from header."""
|
||||
|
|
@ -204,3 +209,109 @@ async def test_rule(
|
|||
result = await tester.test_rule(rule, [body.message], all_rules)
|
||||
|
||||
return result.to_dict()
|
||||
|
||||
|
||||
class FusionConfigUpdate(BaseModel):
    """Request body for updating fusion config.

    Every field is optional; only fields the caller explicitly sets
    (non-None) are intended to override the current configuration.
    """

    # Weights for the rule / semantic / LLM signals — presumably fusion
    # weights; confirm against FusionConfig in app.services.intent.models.
    w_rule: float | None = None
    w_semantic: float | None = None
    w_llm: float | None = None
    # Score thresholds used by the fusion decision logic.
    semantic_threshold: float | None = None
    conflict_threshold: float | None = None
    gray_zone_threshold: float | None = None
    min_trigger_threshold: float | None = None
    clarify_threshold: float | None = None
    multi_intent_threshold: float | None = None
    # Feature toggles for the optional judge/matcher stages.
    llm_judge_enabled: bool | None = None
    semantic_matcher_enabled: bool | None = None
    # Timeouts in milliseconds, plus top-k for the semantic matcher.
    semantic_matcher_timeout_ms: int | None = None
    llm_judge_timeout_ms: int | None = None
    semantic_top_k: int | None = None
|
||||
|
||||
|
||||
@router.get("/fusion-config")
async def get_fusion_config() -> dict[str, Any]:
    """[AC-AISVC-116] Return a snapshot of the current fusion configuration."""
    logger.info("[AC-AISVC-116] Getting fusion config")
    config_snapshot = _fusion_config.to_dict()
    return config_snapshot
|
||||
|
||||
|
||||
@router.put("/fusion-config")
async def update_fusion_config(
    body: FusionConfigUpdate,
) -> dict[str, Any]:
    """[AC-AISVC-116] Merge the provided overrides into the fusion config."""
    global _fusion_config

    logger.info(f"[AC-AISVC-116] Updating fusion config: {body.model_dump()}")

    # Only fields the caller actually set (non-None) override current values.
    overrides = body.model_dump(exclude_none=True)
    merged = {**_fusion_config.to_dict(), **overrides}
    _fusion_config = FusionConfig.from_dict(merged)

    return _fusion_config.to_dict()
|
||||
|
||||
|
||||
@router.post("/{rule_id}/generate-vector")
async def generate_intent_vector(
    rule_id: uuid.UUID,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-114] Generate intent vector for a rule.

    Uses the rule's semantic_examples to generate an average vector.
    If no semantic_examples exist, returns an error.

    Raises:
        HTTPException: 404 if the rule does not exist for this tenant,
            400 if it has no semantic_examples,
            500 if embedding or persistence fails.
    """
    logger.info(
        f"[AC-AISVC-114] Generating intent vector for tenant={tenant_id}, rule_id={rule_id}"
    )

    service = IntentRuleService(session)
    rule = await service.get_rule(tenant_id, rule_id)

    if not rule:
        raise HTTPException(status_code=404, detail="Intent rule not found")

    # The vector is the mean of the example embeddings, so at least one
    # semantic example is required.
    if not rule.semantic_examples:
        raise HTTPException(
            status_code=400,
            detail="Rule has no semantic_examples. Please add semantic_examples first."
        )

    try:
        # Imported lazily — presumably to avoid a module-level import
        # cycle; confirm against app.core.dependencies.
        from app.core.dependencies import get_embedding_provider
        embedding_provider = get_embedding_provider()

        vectors = await embedding_provider.embed_batch(rule.semantic_examples)

        import numpy as np
        # Element-wise mean across example vectors -> single centroid vector.
        avg_vector = np.mean(vectors, axis=0).tolist()

        # Persist the computed vector back onto the rule.
        update_data = IntentRuleUpdate(intent_vector=avg_vector)
        updated_rule = await service.update_rule(tenant_id, rule_id, update_data)

        logger.info(
            f"[AC-AISVC-114] Generated intent vector for rule={rule_id}, "
            f"dimension={len(avg_vector)}"
        )

        return {
            "id": str(updated_rule.id),
            "intent_vector": updated_rule.intent_vector,
            "semantic_examples": updated_rule.semantic_examples,
        }

    except Exception as e:
        # Any embedding/persistence failure surfaces as a 500 with the cause.
        logger.error(f"[AC-AISVC-114] Failed to generate intent vector: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to generate intent vector: {str(e)}"
        )
|
||||
|
|
|
|||
|
|
@ -6,28 +6,35 @@ Knowledge Base management endpoints.
|
|||
|
||||
import logging
|
||||
import uuid
|
||||
import json
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Optional
|
||||
|
||||
import tiktoken
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, Query, UploadFile
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.database import get_session
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models import ErrorResponse
|
||||
from app.models.entities import (
|
||||
Document,
|
||||
DocumentStatus,
|
||||
IndexJob,
|
||||
IndexJobStatus,
|
||||
KBType,
|
||||
KnowledgeBase,
|
||||
KnowledgeBaseCreate,
|
||||
KnowledgeBaseUpdate,
|
||||
)
|
||||
from app.services.kb import KBService
|
||||
from app.services.knowledge_base_service import KnowledgeBaseService
|
||||
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -457,6 +464,7 @@ async def list_documents(
|
|||
"kbId": doc.kb_id,
|
||||
"fileName": doc.file_name,
|
||||
"status": doc.status,
|
||||
"metadata": doc.doc_metadata,
|
||||
"jobId": str(latest_job.id) if latest_job else None,
|
||||
"createdAt": doc.created_at.isoformat() + "Z",
|
||||
"updatedAt": doc.updated_at.isoformat() + "Z",
|
||||
|
|
@ -585,6 +593,7 @@ async def upload_document(
|
|||
file_name=file.filename or "unknown",
|
||||
file_content=file_content,
|
||||
file_type=file.content_type,
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
|
||||
await kb_service.update_doc_count(tenant_id, kb_id, delta=1)
|
||||
|
|
@ -915,3 +924,488 @@ async def delete_document(
|
|||
"message": "Document deleted",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.put(
    "/documents/{doc_id}/metadata",
    operation_id="updateDocumentMetadata",
    summary="Update document metadata",
    description="[AC-ASA-08] Update metadata for a specific document.",
    responses={
        200: {"description": "Metadata updated"},
        404: {"description": "Document not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def update_document_metadata(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    doc_id: str,
    body: dict,
) -> JSONResponse:
    """
    [AC-ASA-08] Update document metadata.

    Accepts ``{"metadata": {...}}``; the metadata value may also be a
    JSON-encoded string, which is parsed. Any other non-dict value is
    rejected with 400. Passing ``metadata: null`` clears the document's
    metadata — NOTE(review): confirm that is the intended semantics.

    Returns:
        200 with the stored metadata, 400 for malformed metadata,
        404 if the document does not exist for this tenant.
    """
    # `json`, `select` and `Document` are imported at module level;
    # the former function-local re-imports were redundant.
    metadata = body.get("metadata")

    if metadata is not None and not isinstance(metadata, dict):
        # Tolerate a JSON-encoded string payload.
        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except json.JSONDecodeError:
                return JSONResponse(
                    status_code=400,
                    content={
                        "code": "INVALID_METADATA",
                        "message": "Invalid JSON format for metadata",
                    },
                )
        # Reject anything that is still not an object (e.g. a JSON list
        # or number) instead of silently writing it to doc_metadata.
        if not isinstance(metadata, dict):
            return JSONResponse(
                status_code=400,
                content={
                    "code": "INVALID_METADATA",
                    "message": "metadata must be a JSON object",
                },
            )

    logger.info(
        f"[AC-ASA-08] Updating document metadata: tenant={tenant_id}, doc_id={doc_id}"
    )

    # Scope the lookup to the tenant so one tenant cannot touch another's docs.
    stmt = select(Document).where(
        Document.tenant_id == tenant_id,
        Document.id == doc_id,
    )
    result = await session.execute(stmt)
    document = result.scalar_one_or_none()

    if not document:
        return JSONResponse(
            status_code=404,
            content={
                "code": "DOCUMENT_NOT_FOUND",
                "message": f"Document {doc_id} not found",
            },
        )

    document.doc_metadata = metadata
    await session.commit()

    return JSONResponse(
        content={
            "success": True,
            "message": "Metadata updated",
            "metadata": document.doc_metadata,
        }
    )
|
||||
|
||||
|
||||
@router.post(
    "/documents/batch-upload",
    operation_id="batchUploadDocuments",
    summary="Batch upload documents from zip",
    description="Upload a zip file containing multiple folders, each with a markdown file and metadata.json",
    responses={
        200: {"description": "Batch upload result"},
        400: {"description": "Bad Request - invalid zip or missing files"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def batch_upload_documents(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    kb_id: str = Form(...),
) -> JSONResponse:
    """
    Batch upload documents from a zip file.

    Zip structure:
    - Each folder contains one .md (or .markdown/.txt) file and one metadata.json
    - metadata.json uses field_key from MetadataFieldDefinition as keys

    Example metadata.json:
    {
        "grade": "高一",
        "subject": "数学",
        "type": "痛点"
    }

    Per-folder failures (bad metadata, validation errors, upload errors)
    are collected into the response instead of aborting the whole batch;
    each successful document gets an indexing job queued as a background
    task.
    """
    import tempfile
    import zipfile
    from pathlib import Path

    logger.info(
        f"[BATCH-UPLOAD] Starting batch upload: tenant={tenant_id}, "
        f"kb_id={kb_id}, filename={file.filename}"
    )

    if not file.filename or not file.filename.lower().endswith('.zip'):
        return JSONResponse(
            status_code=400,
            content={
                "code": "INVALID_FORMAT",
                "message": "Only .zip files are supported",
            },
        )

    kb_service = KnowledgeBaseService(session)
    kb = await kb_service.get_knowledge_base(tenant_id, kb_id)
    if not kb:
        return JSONResponse(
            status_code=404,
            content={
                "code": "KB_NOT_FOUND",
                "message": f"Knowledge base {kb_id} not found",
            },
        )

    file_content = await file.read()

    results = []
    succeeded = 0
    failed = 0

    # Hoisted out of the per-folder loop: the service only wraps the shared
    # session, so one instance suffices for the whole batch. (It is imported
    # at module level; the former function-local re-import was redundant.)
    field_def_service = MetadataFieldDefinitionService(session)

    with tempfile.TemporaryDirectory() as temp_dir:
        zip_path = Path(temp_dir) / "upload.zip"
        with open(zip_path, "wb") as f:
            f.write(file_content)

        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                zf.extractall(temp_dir)
        except zipfile.BadZipFile as e:
            return JSONResponse(
                status_code=400,
                content={
                    "code": "INVALID_ZIP",
                    "message": f"Invalid zip file: {str(e)}",
                },
            )

        extracted_path = Path(temp_dir)

        # List everything that was extracted, for debugging.
        all_items = list(extracted_path.iterdir())
        logger.info(f"[BATCH-UPLOAD] Extracted items: {[item.name for item in all_items]}")

        def find_document_folders(path: Path) -> list[Path]:
            """Recursively collect every folder that contains a content file."""
            doc_folders = []

            # Does the current folder contain a document file?
            content_files = (
                list(path.glob("*.md")) +
                list(path.glob("*.markdown")) +
                list(path.glob("*.txt"))
            )

            if content_files:
                # This folder holds document files, so it is a document folder.
                doc_folders.append(path)
                logger.info(f"[BATCH-UPLOAD] Found document folder: {path.name}, files: {[f.name for f in content_files]}")

            # Recurse into subfolders.
            for subfolder in [p for p in path.iterdir() if p.is_dir()]:
                doc_folders.extend(find_document_folders(subfolder))

            return doc_folders

        folders = find_document_folders(extracted_path)

        if not folders:
            logger.error(f"[BATCH-UPLOAD] No document folders found in zip. Items found: {[item.name for item in all_items]}")
            return JSONResponse(
                status_code=400,
                content={
                    "code": "NO_DOCUMENTS_FOUND",
                    "message": "压缩包中没有找到包含 .txt/.md 文件的文件夹",
                    "details": {
                        "expected_structure": "每个文件夹应包含 content.txt (或 .md) 和 metadata.json",
                        "found_items": [item.name for item in all_items],
                    },
                },
            )

        logger.info(f"[BATCH-UPLOAD] Found {len(folders)} document folders")

        for folder in folders:
            folder_name = folder.name if folder != extracted_path else "root"

            content_files = (
                list(folder.glob("*.md")) +
                list(folder.glob("*.markdown")) +
                list(folder.glob("*.txt"))
            )

            if not content_files:
                # Should not happen — find_document_folders() already filtered.
                failed += 1
                results.append({
                    "folder": folder_name,
                    "status": "failed",
                    "error": "No content file found",
                })
                continue

            # Only the first content file per folder is ingested.
            content_file = content_files[0]
            metadata_file = folder / "metadata.json"

            metadata_dict = {}
            if metadata_file.exists():
                try:
                    with open(metadata_file, 'r', encoding='utf-8') as f:
                        metadata_dict = json.load(f)
                except json.JSONDecodeError as e:
                    failed += 1
                    results.append({
                        "folder": folder_name,
                        "status": "failed",
                        "error": f"Invalid metadata.json: {str(e)}",
                    })
                    continue
            else:
                logger.warning(f"[BATCH-UPLOAD] No metadata.json in folder {folder_name}, using empty metadata")

            is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
                tenant_id, metadata_dict, "kb_document"
            )

            if not is_valid:
                failed += 1
                results.append({
                    "folder": folder_name,
                    "status": "failed",
                    "error": f"Metadata validation failed: {validation_errors}",
                })
                continue

            try:
                with open(content_file, 'rb') as f:
                    doc_content = f.read()

                file_ext = content_file.suffix.lower()
                if file_ext == '.txt':
                    file_type = "text/plain"
                else:
                    file_type = "text/markdown"

                doc_kb_service = KBService(session)
                document, job = await doc_kb_service.upload_document(
                    tenant_id=tenant_id,
                    kb_id=kb_id,
                    file_name=content_file.name,
                    file_content=doc_content,
                    file_type=file_type,
                    metadata=metadata_dict,
                )

                await kb_service.update_doc_count(tenant_id, kb_id, delta=1)
                await session.commit()

                # Indexing runs after the response is sent.
                background_tasks.add_task(
                    _index_document,
                    tenant_id,
                    kb_id,
                    str(job.id),
                    str(document.id),
                    doc_content,
                    content_file.name,
                    metadata_dict,
                )

                succeeded += 1
                results.append({
                    "folder": folder_name,
                    "docId": str(document.id),
                    "jobId": str(job.id),
                    "status": "created",
                    "fileName": content_file.name,
                })

                logger.info(
                    f"[BATCH-UPLOAD] Created document: folder={folder_name}, "
                    f"doc_id={document.id}, job_id={job.id}"
                )

            except Exception as e:
                failed += 1
                results.append({
                    "folder": folder_name,
                    "status": "failed",
                    "error": str(e),
                })
                logger.error(f"[BATCH-UPLOAD] Failed to create document: folder={folder_name}, error={e}")

    logger.info(
        f"[BATCH-UPLOAD] Completed: total={len(results)}, succeeded={succeeded}, failed={failed}"
    )

    return JSONResponse(
        content={
            "success": True,
            "total": len(results),
            "succeeded": succeeded,
            "failed": failed,
            "results": results,
        }
    )
|
||||
|
||||
|
||||
@router.post(
    "/{kb_id}/documents/json-batch",
    summary="[AC-KB-03] JSON批量上传文档",
    description="上传JSONL格式文件,每行一个JSON对象,包含text和元数据字段",
)
async def upload_json_batch(
    kb_id: str,
    tenant_id: str = Query(..., description="租户ID"),
    file: UploadFile = File(..., description="JSONL格式文件,每行一个JSON对象"),
    session: AsyncSession = Depends(get_session),
    background_tasks: BackgroundTasks = None,
):
    """
    JSON batch upload of documents.

    File format: JSONL (one JSON object per line).
    Required field: ``text`` — the content to ingest into the KB.
    Optional fields: metadata keys (e.g. grade, subject, kb_scene);
    keys not declared as metadata field definitions are dropped.

    Example:
    {"text": "课程内容...", "grade": "初二", "subject": "数学", "kb_scene": "课程咨询"}
    {"text": "另一条课程内容...", "grade": "初三", "info_type": "课程概述"}
    """
    kb = await session.get(KnowledgeBase, kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")

    if kb.tenant_id != tenant_id:
        raise HTTPException(status_code=403, detail="无权访问此知识库")

    # Look up the tenant's declared metadata fields; used below to filter
    # out unknown keys. Best-effort: a lookup failure disables filtering.
    valid_field_keys = set()
    try:
        field_defs = await MetadataFieldDefinitionService(session).get_fields(
            tenant_id=tenant_id,
            include_inactive=False,
        )
        valid_field_keys = {f.field_key for f in field_defs}
        logger.info(f"[AC-KB-03] Valid metadata fields for tenant {tenant_id}: {valid_field_keys}")
    except Exception as e:
        logger.warning(f"[AC-KB-03] Failed to get metadata fields: {e}")

    content = await file.read()
    try:
        text_content = content.decode("utf-8")
    except UnicodeDecodeError:
        try:
            text_content = content.decode("gbk")
        except UnicodeDecodeError:
            raise HTTPException(status_code=400, detail="文件编码不支持,请使用UTF-8编码")

    # BUGFIX: str.split() never returns an empty list (an empty string
    # splits to [""]), so the previous `if not lines` check could never
    # fire. Reject empty files before splitting instead.
    stripped = text_content.strip()
    if not stripped:
        raise HTTPException(status_code=400, detail="文件内容为空")
    lines = stripped.split("\n")

    results = []
    succeeded = 0
    failed = 0

    kb_service = KBService(session)

    for line_num, line in enumerate(lines, 1):
        line = line.strip()
        if not line:
            continue

        try:
            json_obj = json.loads(line)
        except json.JSONDecodeError as e:
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": f"JSON解析失败: {e}",
            })
            continue

        text = json_obj.get("text")
        if not text:
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": "缺少必填字段: text",
            })
            continue

        # Everything except "text" is candidate metadata; unknown keys are
        # skipped when field definitions were loaded, None values dropped.
        metadata = {}
        for key, value in json_obj.items():
            if key == "text":
                continue
            if valid_field_keys and key not in valid_field_keys:
                logger.debug(f"[AC-KB-03] Skipping invalid metadata field: {key}")
                continue
            if value is not None:
                metadata[key] = value

        try:
            file_name = f"json_batch_line_{line_num}.txt"
            file_content = text.encode("utf-8")

            document, job = await kb_service.upload_document(
                tenant_id=tenant_id,
                kb_id=kb_id,
                file_name=file_name,
                file_content=file_content,
                file_type="text/plain",
                metadata=metadata,
            )

            if background_tasks:
                background_tasks.add_task(
                    _index_document,
                    tenant_id,
                    kb_id,
                    str(job.id),
                    str(document.id),
                    file_content,
                    file_name,
                    metadata,
                )

            succeeded += 1
            results.append({
                "line": line_num,
                "success": True,
                "doc_id": str(document.id),
                "job_id": str(job.id),
                "metadata": metadata,
            })
        except Exception as e:
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": str(e),
            })
            logger.error(f"[AC-KB-03] Failed to upload document at line {line_num}: {e}")

    await session.commit()

    logger.info(f"[AC-KB-03] JSON batch upload completed: kb_id={kb_id}, total={len(lines)}, succeeded={succeeded}, failed={failed}")

    return JSONResponse(
        content={
            "success": True,
            "total": len(lines),
            "succeeded": succeeded,
            "failed": failed,
            "valid_metadata_fields": list(valid_field_keys) if valid_field_keys else [],
            "results": results,
        }
    )
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ def _field_to_dict(f: MetadataFieldDefinition) -> dict[str, Any]:
|
|||
"scope": f.scope,
|
||||
"is_filterable": f.is_filterable,
|
||||
"is_rank_feature": f.is_rank_feature,
|
||||
"usage_description": f.usage_description,
|
||||
"field_roles": f.field_roles or [],
|
||||
"status": f.status,
|
||||
"version": f.version,
|
||||
|
|
|
|||
|
|
@ -407,6 +407,7 @@ async def get_conversation_detail(
|
|||
"guardrailTriggered": user_msg.guardrail_triggered,
|
||||
"guardrailWords": user_msg.guardrail_words,
|
||||
"executionSteps": execution_steps,
|
||||
"routeTrace": user_msg.route_trace,
|
||||
"createdAt": user_msg.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
|
@ -659,8 +660,56 @@ async def _process_export(
|
|||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-110] Export failed: task_id={task_id}, error={e}")
|
||||
|
||||
task = await session.get(ExportTask, task_id)
|
||||
task = task_status.get(ExportTask, task_id)
|
||||
if task:
|
||||
task.status = ExportTaskStatus.FAILED.value
|
||||
task.error_message = str(e)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@router.get("/clarification-metrics")
async def get_clarification_metrics(
    tenant_id: str = Depends(get_tenant_id),
    total_requests: int = Query(100, ge=1, description="Total requests for rate calculation"),
) -> dict[str, Any]:
    """
    [AC-CLARIFY] Report clarification counters and derived rates.

    Returns:
        - clarify_trigger_rate: how often clarification was triggered
        - clarify_converge_rate: how often it converged after clarifying
        - misroute_rate: how often a wrong flow was entered
    """
    from app.services.intent.clarification import get_clarify_metrics

    collector = get_clarify_metrics()
    response = {
        "counts": collector.get_metrics(),
        "rates": collector.get_rates(total_requests),
        "thresholds": {
            "t_high": 0.75,
            "t_low": 0.45,
            "max_retry": 3,
        },
    }
    return response
|
||||
|
||||
|
||||
@router.post("/clarification-metrics/reset")
async def reset_clarification_metrics(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """[AC-CLARIFY] Zero out all clarification counters."""
    from app.services.intent.clarification import get_clarify_metrics

    get_clarify_metrics().reset()

    return {
        "status": "reset",
        "message": "Clarification metrics have been reset.",
    }
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ class Settings(BaseSettings):
|
|||
redis_enabled: bool = True
|
||||
dashboard_cache_ttl: int = 60
|
||||
stats_counter_ttl: int = 7776000
|
||||
slot_state_cache_ttl: int = 1800
|
||||
|
||||
frontend_base_url: str = "http://localhost:3000"
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ engine = create_async_engine(
|
|||
settings.database_url,
|
||||
pool_size=settings.database_pool_size,
|
||||
max_overflow=settings.database_max_overflow,
|
||||
echo=settings.debug,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -114,7 +114,12 @@ class ApiKeyMiddleware(BaseHTTPMiddleware):
|
|||
from app.core.database import async_session_maker
|
||||
async with async_session_maker() as session:
|
||||
await service.initialize(session)
|
||||
logger.info(f"[AC-AISVC-50] API key service lazy initialized with {len(service._keys_cache)} keys")
|
||||
if service._initialized and len(service._keys_cache) > 0:
|
||||
logger.info(f"[AC-AISVC-50] API key service lazy initialized with {len(service._keys_cache)} keys")
|
||||
elif service._initialized and len(service._keys_cache) == 0:
|
||||
logger.warning("[AC-AISVC-50] API key service initialized but no keys found in database")
|
||||
else:
|
||||
logger.error("[AC-AISVC-50] API key service lazy initialization failed")
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-50] Failed to initialize API key service: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -272,20 +272,24 @@ class QdrantClient:
|
|||
score_threshold: float | None = None,
|
||||
vector_name: str = "full",
|
||||
with_vectors: bool = False,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
kb_ids: list[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
[AC-AISVC-10] Search vectors in tenant's collection.
|
||||
[AC-AISVC-10] Search vectors in tenant's collections.
|
||||
Returns results with score >= score_threshold if specified.
|
||||
Searches both old format (with @) and new format (with _) for backward compatibility.
|
||||
Searches all collections for the tenant (multi-KB support).
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
query_vector: Query vector for similarity search
|
||||
limit: Maximum number of results
|
||||
limit: Maximum number of results per collection
|
||||
score_threshold: Minimum score threshold for results
|
||||
vector_name: Name of the vector to search (for multi-vector collections)
|
||||
Default is "full" for 768-dim vectors in Matryoshka setup.
|
||||
with_vectors: Whether to return vectors in results (for two-stage reranking)
|
||||
metadata_filter: Optional metadata filter to apply during search
|
||||
kb_ids: Optional list of knowledge base IDs to restrict search to specific KBs
|
||||
"""
|
||||
client = await self.get_client()
|
||||
|
||||
|
|
@ -293,21 +297,36 @@ class QdrantClient:
|
|||
f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, "
|
||||
f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}"
|
||||
)
|
||||
if metadata_filter:
|
||||
logger.info(f"[AC-AISVC-10] Metadata filter: {metadata_filter}")
|
||||
|
||||
collection_names = [self.get_collection_name(tenant_id)]
|
||||
if '@' in tenant_id:
|
||||
old_format = f"{self._collection_prefix}{tenant_id}"
|
||||
new_format = f"{self._collection_prefix}{tenant_id.replace('@', '_')}"
|
||||
collection_names = [new_format, old_format]
|
||||
# 构建 Qdrant filter
|
||||
qdrant_filter = None
|
||||
if metadata_filter:
|
||||
qdrant_filter = self._build_qdrant_filter(metadata_filter)
|
||||
logger.info(f"[AC-AISVC-10] Qdrant filter: {qdrant_filter}")
|
||||
|
||||
logger.info(f"[AC-AISVC-10] Will search in collections: {collection_names}")
|
||||
# 获取该租户的所有 collections
|
||||
collection_names = await self._get_tenant_collections(client, tenant_id)
|
||||
|
||||
# 如果指定了 kb_ids,则只搜索指定的知识库 collections
|
||||
if kb_ids:
|
||||
target_collections = []
|
||||
for kb_id in kb_ids:
|
||||
kb_collection_name = self.get_kb_collection_name(tenant_id, kb_id)
|
||||
if kb_collection_name in collection_names:
|
||||
target_collections.append(kb_collection_name)
|
||||
else:
|
||||
logger.warning(f"[AC-AISVC-10] KB collection not found: {kb_collection_name} for kb_id={kb_id}")
|
||||
collection_names = target_collections
|
||||
logger.info(f"[AC-AISVC-10] Restricted to {len(collection_names)} KB collections: {collection_names}")
|
||||
else:
|
||||
logger.info(f"[AC-AISVC-10] Will search in {len(collection_names)} collections: {collection_names}")
|
||||
|
||||
all_hits = []
|
||||
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}")
|
||||
|
||||
exists = await client.collection_exists(collection_name)
|
||||
if not exists:
|
||||
logger.warning(f"[AC-AISVC-10] Collection {collection_name} does not exist")
|
||||
|
|
@ -321,6 +340,7 @@ class QdrantClient:
|
|||
limit=limit,
|
||||
with_vectors=with_vectors,
|
||||
score_threshold=score_threshold,
|
||||
query_filter=qdrant_filter,
|
||||
)
|
||||
except Exception as e:
|
||||
if "vector name" in str(e).lower() or "Not existing vector" in str(e) or "using" in str(e).lower():
|
||||
|
|
@ -334,6 +354,7 @@ class QdrantClient:
|
|||
limit=limit,
|
||||
with_vectors=with_vectors,
|
||||
score_threshold=score_threshold,
|
||||
query_filter=qdrant_filter,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
|
@ -348,6 +369,7 @@ class QdrantClient:
|
|||
"id": str(result.id),
|
||||
"score": result.score,
|
||||
"payload": result.payload or {},
|
||||
"collection": collection_name, # 添加 collection 信息
|
||||
}
|
||||
if with_vectors and result.vector:
|
||||
hit["vector"] = result.vector
|
||||
|
|
@ -358,10 +380,6 @@ class QdrantClient:
|
|||
logger.info(
|
||||
f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}"
|
||||
)
|
||||
for i, h in enumerate(hits[:3]):
|
||||
logger.debug(
|
||||
f"[AC-AISVC-10] Hit {i+1}: id={h['id']}, score={h['score']:.4f}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)"
|
||||
|
|
@ -370,9 +388,10 @@ class QdrantClient:
|
|||
logger.warning(
|
||||
f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
all_hits = sorted(all_hits, key=lambda x: x["score"], reverse=True)[:limit]
|
||||
# 按分数排序并返回 top results
|
||||
all_hits.sort(key=lambda x: x["score"], reverse=True)
|
||||
all_hits = all_hits[:limit]
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}"
|
||||
|
|
@ -386,6 +405,113 @@ class QdrantClient:
|
|||
|
||||
return all_hits
|
||||
|
||||
async def _get_tenant_collections(
|
||||
self,
|
||||
client: AsyncQdrantClient,
|
||||
tenant_id: str,
|
||||
) -> list[str]:
|
||||
"""
|
||||
获取指定租户的所有 collections。
|
||||
优先从 Redis 缓存获取,未缓存则从 Qdrant 查询并缓存。
|
||||
|
||||
Args:
|
||||
client: Qdrant client
|
||||
tenant_id: 租户 ID
|
||||
|
||||
Returns:
|
||||
Collection 名称列表
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
from app.services.metadata_cache_service import get_metadata_cache_service
|
||||
cache_service = await get_metadata_cache_service()
|
||||
cache_key = f"collections:{tenant_id}"
|
||||
|
||||
try:
|
||||
# 确保 Redis 连接已初始化
|
||||
redis_client = await cache_service._get_redis()
|
||||
if redis_client and cache_service._enabled:
|
||||
cached = await redis_client.get(cache_key)
|
||||
if cached:
|
||||
import json
|
||||
collections = json.loads(cached)
|
||||
logger.info(
|
||||
f"[AC-AISVC-10] Cache hit: Found {len(collections)} collections "
|
||||
f"for tenant={tenant_id} in {(time.time() - start_time)*1000:.2f}ms"
|
||||
)
|
||||
return collections
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-AISVC-10] Cache get error: {e}")
|
||||
|
||||
# 2. 从 Qdrant 查询
|
||||
safe_tenant_id = tenant_id.replace('@', '_')
|
||||
prefix = f"{self._collection_prefix}{safe_tenant_id}"
|
||||
|
||||
try:
|
||||
collections = await client.get_collections()
|
||||
tenant_collections = [
|
||||
c.name for c in collections.collections
|
||||
if c.name.startswith(prefix)
|
||||
]
|
||||
|
||||
# 按名称排序
|
||||
tenant_collections.sort()
|
||||
|
||||
db_time = (time.time() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[AC-AISVC-10] Found {len(tenant_collections)} collections from Qdrant "
|
||||
f"for tenant={tenant_id} in {db_time:.2f}ms: {tenant_collections}"
|
||||
)
|
||||
|
||||
# 3. 缓存结果(5分钟 TTL)
|
||||
try:
|
||||
redis_client = await cache_service._get_redis()
|
||||
if redis_client and cache_service._enabled:
|
||||
import json
|
||||
await redis_client.setex(
|
||||
cache_key,
|
||||
300, # 5分钟
|
||||
json.dumps(tenant_collections)
|
||||
)
|
||||
logger.info(f"[AC-AISVC-10] Cached collections for tenant={tenant_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-AISVC-10] Cache set error: {e}")
|
||||
|
||||
return tenant_collections
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-10] Failed to get collections for tenant={tenant_id}: {e}")
|
||||
return [self.get_collection_name(tenant_id)]
|
||||
|
||||
def _build_qdrant_filter(
|
||||
self,
|
||||
metadata_filter: dict[str, Any],
|
||||
) -> Any:
|
||||
"""
|
||||
构建 Qdrant 过滤条件。
|
||||
|
||||
Args:
|
||||
metadata_filter: 元数据过滤条件,如 {"grade": "三年级", "subject": "语文"}
|
||||
|
||||
Returns:
|
||||
Qdrant Filter 对象
|
||||
"""
|
||||
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
||||
|
||||
must_conditions = []
|
||||
|
||||
for key, value in metadata_filter.items():
|
||||
# 支持嵌套 metadata 字段,如 metadata.grade
|
||||
field_path = f"metadata.{key}"
|
||||
condition = FieldCondition(
|
||||
key=field_path,
|
||||
match=MatchValue(value=value),
|
||||
)
|
||||
must_conditions.append(condition)
|
||||
|
||||
return Filter(must=must_conditions) if must_conditions else None
|
||||
|
||||
async def delete_collection(self, tenant_id: str) -> bool:
|
||||
"""
|
||||
[AC-AISVC-10] Delete tenant's collection.
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from app.api.admin import (
|
|||
monitoring_router,
|
||||
prompt_templates_router,
|
||||
rag_router,
|
||||
scene_slot_bundle_router,
|
||||
script_flows_router,
|
||||
sessions_router,
|
||||
slot_definition_router,
|
||||
|
|
@ -55,6 +56,11 @@ logging.basicConfig(
|
|||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.dialects").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.orm").setLevel(logging.WARNING)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -88,6 +94,28 @@ async def lifespan(app: FastAPI):
|
|||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-50] API key initialization FAILED: {e}", exc_info=True)
|
||||
|
||||
try:
|
||||
from app.services.mid.tool_guide_registry import init_tool_guide_registry
|
||||
|
||||
logger.info("[ToolGuideRegistry] Starting tool guides initialization...")
|
||||
tool_guide_registry = init_tool_guide_registry()
|
||||
logger.info(f"[ToolGuideRegistry] Tool guides loaded: {tool_guide_registry.list_tools()}")
|
||||
except Exception as e:
|
||||
logger.error(f"[ToolRegistry] Tools initialization FAILED: {e}", exc_info=True)
|
||||
|
||||
# [AC-AISVC-29] 预初始化 Embedding 服务,避免首次查询时的延迟
|
||||
try:
|
||||
from app.services.embedding import get_embedding_provider
|
||||
|
||||
logger.info("[AC-AISVC-29] Pre-initializing embedding service...")
|
||||
embedding_provider = await get_embedding_provider()
|
||||
logger.info(
|
||||
f"[AC-AISVC-29] Embedding service pre-initialized: "
|
||||
f"provider={embedding_provider.PROVIDER_NAME}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-29] Embedding service pre-initialization FAILED: {e}", exc_info=True)
|
||||
|
||||
yield
|
||||
|
||||
await close_db()
|
||||
|
|
@ -171,6 +199,7 @@ app.include_router(metadata_schema_router)
|
|||
app.include_router(monitoring_router)
|
||||
app.include_router(prompt_templates_router)
|
||||
app.include_router(rag_router)
|
||||
app.include_router(scene_slot_bundle_router)
|
||||
app.include_router(script_flows_router)
|
||||
app.include_router(sessions_router)
|
||||
app.include_router(slot_definition_router)
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ class ChatMessage(SQLModel, table=True):
|
|||
[AC-AISVC-13] Chat message entity with tenant isolation.
|
||||
Messages are scoped by (tenant_id, session_id) for multi-tenant security.
|
||||
[v0.7.0] Extended with monitoring fields for Dashboard statistics.
|
||||
[v0.8.0] Extended with route_trace for hybrid routing observability.
|
||||
"""
|
||||
|
||||
__tablename__ = "chat_messages"
|
||||
|
|
@ -90,6 +91,11 @@ class ChatMessage(SQLModel, table=True):
|
|||
sa_column=Column("guardrail_words", JSON, nullable=True),
|
||||
description="[v0.7.0] Guardrail trigger details: words, categories, strategy"
|
||||
)
|
||||
route_trace: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("route_trace", JSON, nullable=True),
|
||||
description="[v0.8.0] Intent routing trace log for hybrid routing observability"
|
||||
)
|
||||
|
||||
|
||||
class ChatSessionCreate(SQLModel):
|
||||
|
|
@ -227,6 +233,7 @@ class Document(SQLModel, table=True):
|
|||
file_type: str | None = Field(default=None, description="File MIME type")
|
||||
status: str = Field(default=DocumentStatus.PENDING.value, description="Document status")
|
||||
error_msg: str | None = Field(default=None, description="Error message if failed")
|
||||
doc_metadata: dict | None = Field(default=None, sa_type=JSON, description="Document metadata as JSON")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Upload time")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
|
||||
|
|
@ -421,6 +428,7 @@ class IntentRule(SQLModel, table=True):
|
|||
[AC-AISVC-65] Intent rule entity with tenant isolation.
|
||||
Supports keyword and regex matching for intent recognition.
|
||||
[AC-IDSMETA-16] Extended with metadata field for unified storage structure.
|
||||
[v0.8.0] Extended with intent_vector and semantic_examples for hybrid routing.
|
||||
"""
|
||||
|
||||
__tablename__ = "intent_rules"
|
||||
|
|
@ -458,6 +466,16 @@ class IntentRule(SQLModel, table=True):
|
|||
sa_column=Column("metadata", JSON, nullable=True),
|
||||
description="[AC-IDSMETA-16] Structured metadata for the intent rule"
|
||||
)
|
||||
intent_vector: list[float] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("intent_vector", JSON, nullable=True),
|
||||
description="[v0.8.0] Pre-computed intent vector for semantic matching"
|
||||
)
|
||||
semantic_examples: list[str] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("semantic_examples", JSON, nullable=True),
|
||||
description="[v0.8.0] Semantic example sentences for dynamic vector computation"
|
||||
)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
|
||||
|
|
@ -475,6 +493,8 @@ class IntentRuleCreate(SQLModel):
|
|||
fixed_reply: str | None = None
|
||||
transfer_message: str | None = None
|
||||
metadata_: dict[str, Any] | None = None
|
||||
intent_vector: list[float] | None = None
|
||||
semantic_examples: list[str] | None = None
|
||||
|
||||
|
||||
class IntentRuleUpdate(SQLModel):
|
||||
|
|
@ -491,6 +511,8 @@ class IntentRuleUpdate(SQLModel):
|
|||
transfer_message: str | None = None
|
||||
is_enabled: bool | None = None
|
||||
metadata_: dict[str, Any] | None = None
|
||||
intent_vector: list[float] | None = None
|
||||
semantic_examples: list[str] | None = None
|
||||
|
||||
|
||||
class IntentMatchResult:
|
||||
|
|
@ -810,6 +832,24 @@ class FlowStep(SQLModel):
|
|||
default=None,
|
||||
description="RAG configuration for this step: {'enabled': true, 'tag_filter': {'grade': '${context.grade}', 'type': '痛点'}}"
|
||||
)
|
||||
allowed_kb_ids: list[str] | None = Field(
|
||||
default=None,
|
||||
description="[Step-KB-Binding] Allowed knowledge base IDs for this step. If set, KB search will be restricted to these KBs."
|
||||
)
|
||||
preferred_kb_ids: list[str] | None = Field(
|
||||
default=None,
|
||||
description="[Step-KB-Binding] Preferred knowledge base IDs for this step. These KBs will be searched first."
|
||||
)
|
||||
kb_query_hint: str | None = Field(
|
||||
default=None,
|
||||
description="[Step-KB-Binding] Query hint for KB search in this step, helps improve retrieval accuracy."
|
||||
)
|
||||
max_kb_calls_per_step: int | None = Field(
|
||||
default=None,
|
||||
ge=1,
|
||||
le=5,
|
||||
description="[Step-KB-Binding] Max KB calls allowed per step. Default is 1 if not set."
|
||||
)
|
||||
|
||||
|
||||
class ScriptFlowCreate(SQLModel):
|
||||
|
|
@ -1078,6 +1118,7 @@ class MetadataFieldDefinition(SQLModel, table=True):
|
|||
)
|
||||
is_filterable: bool = Field(default=True, description="是否可用于过滤")
|
||||
is_rank_feature: bool = Field(default=False, description="是否用于排序特征")
|
||||
usage_description: str | None = Field(default=None, description="用途说明")
|
||||
field_roles: list[str] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column("field_roles", JSON, nullable=False, server_default="'[]'"),
|
||||
|
|
@ -1104,6 +1145,7 @@ class MetadataFieldDefinitionCreate(SQLModel):
|
|||
scope: list[str] = Field(default_factory=lambda: [MetadataScope.KB_DOCUMENT.value])
|
||||
is_filterable: bool = Field(default=True)
|
||||
is_rank_feature: bool = Field(default=False)
|
||||
usage_description: str | None = None
|
||||
field_roles: list[str] = Field(default_factory=list)
|
||||
status: str = Field(default=MetadataFieldStatus.DRAFT.value)
|
||||
|
||||
|
|
@ -1118,6 +1160,7 @@ class MetadataFieldDefinitionUpdate(SQLModel):
|
|||
scope: list[str] | None = None
|
||||
is_filterable: bool | None = None
|
||||
is_rank_feature: bool | None = None
|
||||
usage_description: str | None = None
|
||||
field_roles: list[str] | None = None
|
||||
status: str | None = None
|
||||
|
||||
|
|
@ -1131,6 +1174,17 @@ class ExtractStrategy(str, Enum):
|
|||
USER_INPUT = "user_input"
|
||||
|
||||
|
||||
class ExtractFailureType(str, Enum):
|
||||
"""
|
||||
[AC-MRS-07-UPGRADE] 提取失败类型
|
||||
统一失败分类,用于追踪和日志
|
||||
"""
|
||||
EXTRACT_EMPTY = "EXTRACT_EMPTY" # 提取结果为空
|
||||
EXTRACT_PARSE_FAIL = "EXTRACT_PARSE_FAIL" # 解析失败
|
||||
EXTRACT_VALIDATION_FAIL = "EXTRACT_VALIDATION_FAIL" # 校验失败
|
||||
EXTRACT_RUNTIME_ERROR = "EXTRACT_RUNTIME_ERROR" # 运行时错误
|
||||
|
||||
|
||||
class SlotValueSource(str, Enum):
|
||||
"""
|
||||
[AC-MRS-09] 槽位值来源
|
||||
|
|
@ -1145,6 +1199,7 @@ class SlotDefinition(SQLModel, table=True):
|
|||
"""
|
||||
[AC-MRS-07,08] 槽位定义表
|
||||
独立的槽位定义模型,与元数据字段解耦但可复用
|
||||
[AC-MRS-07-UPGRADE] 支持提取策略链 extract_strategies
|
||||
"""
|
||||
|
||||
__tablename__ = "slot_definitions"
|
||||
|
|
@ -1162,14 +1217,31 @@ class SlotDefinition(SQLModel, table=True):
|
|||
min_length=1,
|
||||
max_length=100,
|
||||
)
|
||||
display_name: str | None = Field(
|
||||
default=None,
|
||||
description="槽位名称,给运营/教研看的中文名,例:grade -> '当前年级'",
|
||||
max_length=100,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="槽位说明,解释这个槽位采集什么、用于哪里",
|
||||
max_length=500,
|
||||
)
|
||||
type: str = Field(
|
||||
default=MetadataFieldType.STRING.value,
|
||||
description="槽位类型: string/number/boolean/enum/array_enum"
|
||||
)
|
||||
required: bool = Field(default=False, description="是否必填槽位")
|
||||
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容读取
|
||||
extract_strategy: str | None = Field(
|
||||
default=None,
|
||||
description="提取策略: rule/llm/user_input"
|
||||
description="[兼容字段] 提取策略: rule/llm/user_input,已废弃,请使用 extract_strategies"
|
||||
)
|
||||
# [AC-MRS-07-UPGRADE] 新增策略链字段
|
||||
extract_strategies: list[str] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("extract_strategies", JSON, nullable=True),
|
||||
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||
)
|
||||
validation_rule: str | None = Field(
|
||||
default=None,
|
||||
|
|
@ -1192,14 +1264,72 @@ class SlotDefinition(SQLModel, table=True):
|
|||
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间")
|
||||
|
||||
def get_effective_strategies(self) -> list[str]:
|
||||
"""
|
||||
[AC-MRS-07-UPGRADE] 获取有效的提取策略链
|
||||
优先使用 extract_strategies,如果不存在则兼容读取 extract_strategy
|
||||
"""
|
||||
if self.extract_strategies and len(self.extract_strategies) > 0:
|
||||
return self.extract_strategies
|
||||
if self.extract_strategy:
|
||||
return [self.extract_strategy]
|
||||
return []
|
||||
|
||||
def validate_strategies(self) -> tuple[bool, str]:
|
||||
"""
|
||||
[AC-MRS-07-UPGRADE] 校验提取策略链的有效性
|
||||
|
||||
Returns:
|
||||
Tuple of (是否有效, 错误信息)
|
||||
"""
|
||||
valid_strategies = {"rule", "llm", "user_input"}
|
||||
strategies = self.get_effective_strategies()
|
||||
|
||||
if not strategies:
|
||||
return True, "" # 空策略链视为有效(使用默认行为)
|
||||
|
||||
# 校验至少1个策略
|
||||
if len(strategies) == 0:
|
||||
return False, "提取策略链不能为空"
|
||||
|
||||
# 校验不允许重复策略
|
||||
if len(strategies) != len(set(strategies)):
|
||||
return False, "提取策略链中不允许重复的策略"
|
||||
|
||||
# 校验策略值有效
|
||||
invalid = [s for s in strategies if s not in valid_strategies]
|
||||
if invalid:
|
||||
return False, f"无效的提取策略: {invalid},有效值为: {list(valid_strategies)}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
class SlotDefinitionCreate(SQLModel):
|
||||
"""[AC-MRS-07,08] 创建槽位定义"""
|
||||
|
||||
slot_key: str = Field(..., min_length=1, max_length=100)
|
||||
display_name: str | None = Field(
|
||||
default=None,
|
||||
description="槽位名称,给运营/教研看的中文名",
|
||||
max_length=100,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="槽位说明,解释这个槽位采集什么、用于哪里",
|
||||
max_length=500,
|
||||
)
|
||||
type: str = Field(default=MetadataFieldType.STRING.value)
|
||||
required: bool = Field(default=False)
|
||||
extract_strategy: str | None = None
|
||||
# [AC-MRS-07-UPGRADE] 支持策略链
|
||||
extract_strategies: list[str] | None = Field(
|
||||
default=None,
|
||||
description="提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||
)
|
||||
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||
extract_strategy: str | None = Field(
|
||||
default=None,
|
||||
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
|
||||
)
|
||||
validation_rule: str | None = None
|
||||
ask_back_prompt: str | None = None
|
||||
default_value: dict[str, Any] | None = None
|
||||
|
|
@ -1209,9 +1339,28 @@ class SlotDefinitionCreate(SQLModel):
|
|||
class SlotDefinitionUpdate(SQLModel):
|
||||
"""[AC-MRS-07] 更新槽位定义"""
|
||||
|
||||
display_name: str | None = Field(
|
||||
default=None,
|
||||
description="槽位名称,给运营/教研看的中文名",
|
||||
max_length=100,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="槽位说明,解释这个槽位采集什么、用于哪里",
|
||||
max_length=500,
|
||||
)
|
||||
type: str | None = None
|
||||
required: bool | None = None
|
||||
extract_strategy: str | None = None
|
||||
# [AC-MRS-07-UPGRADE] 支持策略链
|
||||
extract_strategies: list[str] | None = Field(
|
||||
default=None,
|
||||
description="提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||
)
|
||||
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||
extract_strategy: str | None = Field(
|
||||
default=None,
|
||||
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
|
||||
)
|
||||
validation_rule: str | None = None
|
||||
ask_back_prompt: str | None = None
|
||||
default_value: dict[str, Any] | None = None
|
||||
|
|
@ -1522,3 +1671,107 @@ class MidAuditLog(SQLModel, table=True):
|
|||
high_risk_scenario: str | None = Field(default=None, description="触发的高风险场景")
|
||||
latency_ms: int | None = Field(default=None, description="总耗时(ms)")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间", index=True)
|
||||
|
||||
|
||||
class SceneSlotBundleStatus(str, Enum):
|
||||
"""[AC-SCENE-SLOT-01] 场景槽位包状态"""
|
||||
DRAFT = "draft"
|
||||
ACTIVE = "active"
|
||||
DEPRECATED = "deprecated"
|
||||
|
||||
|
||||
class SceneSlotBundle(SQLModel, table=True):
|
||||
"""
|
||||
[AC-SCENE-SLOT-01] 场景-槽位映射配置
|
||||
定义每个场景需要采集的槽位集合
|
||||
|
||||
三层关系:
|
||||
- 层1:slot ↔ metadata(通过 linked_field_id)
|
||||
- 层2:scene ↔ slot_bundle(本模型)
|
||||
- 层3:step.expected_variables ↔ slot_key(话术步骤引用)
|
||||
"""
|
||||
|
||||
__tablename__ = "scene_slot_bundles"
|
||||
__table_args__ = (
|
||||
Index("ix_scene_slot_bundles_tenant", "tenant_id"),
|
||||
Index("ix_scene_slot_bundles_tenant_scene", "tenant_id", "scene_key", unique=True),
|
||||
Index("ix_scene_slot_bundles_tenant_status", "tenant_id", "status"),
|
||||
)
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
|
||||
scene_key: str = Field(
|
||||
...,
|
||||
description="场景标识,如 'open_consult', 'refund_apply', 'course_recommend'",
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
)
|
||||
scene_name: str = Field(
|
||||
...,
|
||||
description="场景名称,如 '开放咨询', '退款申请', '课程推荐'",
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="场景描述"
|
||||
)
|
||||
required_slots: list[str] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column("required_slots", JSON, nullable=False),
|
||||
description="必填槽位 slot_key 列表"
|
||||
)
|
||||
optional_slots: list[str] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column("optional_slots", JSON, nullable=False),
|
||||
description="可选槽位 slot_key 列表"
|
||||
)
|
||||
slot_priority: list[str] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("slot_priority", JSON, nullable=True),
|
||||
description="槽位采集优先级顺序(slot_key 列表)"
|
||||
)
|
||||
completion_threshold: float = Field(
|
||||
default=1.0,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="完成阈值(0.0-1.0),必填槽位填充比例达到此值视为完成"
|
||||
)
|
||||
ask_back_order: str = Field(
|
||||
default="priority",
|
||||
description="追问顺序策略: priority/required_first/parallel"
|
||||
)
|
||||
status: str = Field(
|
||||
default=SceneSlotBundleStatus.DRAFT.value,
|
||||
description="状态: draft/active/deprecated"
|
||||
)
|
||||
version: int = Field(default=1, description="版本号")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间")
|
||||
|
||||
|
||||
class SceneSlotBundleCreate(SQLModel):
|
||||
"""[AC-SCENE-SLOT-01] 创建场景槽位包"""
|
||||
|
||||
scene_key: str = Field(..., min_length=1, max_length=100)
|
||||
scene_name: str = Field(..., min_length=1, max_length=100)
|
||||
description: str | None = None
|
||||
required_slots: list[str] = Field(default_factory=list)
|
||||
optional_slots: list[str] = Field(default_factory=list)
|
||||
slot_priority: list[str] | None = None
|
||||
completion_threshold: float = Field(default=1.0, ge=0.0, le=1.0)
|
||||
ask_back_order: str = Field(default="priority")
|
||||
status: str = Field(default=SceneSlotBundleStatus.DRAFT.value)
|
||||
|
||||
|
||||
class SceneSlotBundleUpdate(SQLModel):
|
||||
"""[AC-SCENE-SLOT-01] 更新场景槽位包"""
|
||||
|
||||
scene_name: str | None = Field(default=None, min_length=1, max_length=100)
|
||||
description: str | None = None
|
||||
required_slots: list[str] | None = None
|
||||
optional_slots: list[str] | None = None
|
||||
slot_priority: list[str] | None = None
|
||||
completion_threshold: float | None = Field(default=None, ge=0.0, le=1.0)
|
||||
ask_back_order: str | None = None
|
||||
status: str | None = None
|
||||
|
|
|
|||
|
|
@ -139,7 +139,16 @@ class SlotDefinitionResponse(BaseModel):
|
|||
slot_key: str = Field(..., description="槽位键名")
|
||||
type: str = Field(..., description="槽位类型")
|
||||
required: bool = Field(default=False, description="是否必填槽位")
|
||||
extract_strategy: str | None = Field(default=None, description="提取策略")
|
||||
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||
extract_strategy: str | None = Field(
|
||||
default=None,
|
||||
description="[兼容字段] 单提取策略,已废弃"
|
||||
)
|
||||
# [AC-MRS-07-UPGRADE] 新增策略链字段
|
||||
extract_strategies: list[str] | None = Field(
|
||||
default=None,
|
||||
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input"
|
||||
)
|
||||
validation_rule: str | None = Field(default=None, description="校验规则")
|
||||
ask_back_prompt: str | None = Field(default=None, description="追问提示语模板")
|
||||
default_value: dict[str, Any] | None = Field(default=None, description="默认值")
|
||||
|
|
@ -157,9 +166,15 @@ class SlotDefinitionCreateRequest(BaseModel):
|
|||
slot_key: str = Field(..., min_length=1, max_length=100, description="槽位键名")
|
||||
type: str = Field(default="string", description="槽位类型")
|
||||
required: bool = Field(default=False, description="是否必填槽位")
|
||||
# [AC-MRS-07-UPGRADE] 支持策略链
|
||||
extract_strategies: list[str] | None = Field(
|
||||
default=None,
|
||||
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||
)
|
||||
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||
extract_strategy: str | None = Field(
|
||||
default=None,
|
||||
description="提取策略: rule/llm/user_input"
|
||||
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
|
||||
)
|
||||
validation_rule: str | None = Field(default=None, description="校验规则")
|
||||
ask_back_prompt: str | None = Field(default=None, description="追问提示语模板")
|
||||
|
|
@ -172,7 +187,16 @@ class SlotDefinitionUpdateRequest(BaseModel):
|
|||
|
||||
type: str | None = None
|
||||
required: bool | None = None
|
||||
extract_strategy: str | None = None
|
||||
# [AC-MRS-07-UPGRADE] 支持策略链
|
||||
extract_strategies: list[str] | None = Field(
|
||||
default=None,
|
||||
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||
)
|
||||
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||
extract_strategy: str | None = Field(
|
||||
default=None,
|
||||
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
|
||||
)
|
||||
validation_rule: str | None = None
|
||||
ask_back_prompt: str | None = None
|
||||
default_value: dict[str, Any] | None = None
|
||||
|
|
|
|||
|
|
@ -81,7 +81,6 @@ class ApiKeyService:
|
|||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-AISVC-50] Full API key schema load failed, fallback to legacy columns: {e}")
|
||||
await session.rollback()
|
||||
|
||||
# Backward-compat fallback for environments without new columns
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,9 @@ import logging
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
|
||||
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
|
||||
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
|
||||
|
|
@ -20,6 +23,7 @@ from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
EMBEDDING_CONFIG_FILE = Path("config/embedding_config.json")
|
||||
EMBEDDING_CONFIG_REDIS_KEY = "ai_service:config:embedding"
|
||||
|
||||
|
||||
class EmbeddingProviderFactory:
|
||||
|
|
@ -170,8 +174,32 @@ class EmbeddingConfigManager:
|
|||
self._config = self._default_config.copy()
|
||||
self._provider: EmbeddingProvider | None = None
|
||||
|
||||
self._settings = get_settings()
|
||||
self._redis_client: redis.Redis | None = None
|
||||
|
||||
self._load_from_redis()
|
||||
self._load_from_file()
|
||||
|
||||
def _load_from_redis(self) -> None:
|
||||
"""Load configuration from Redis if exists."""
|
||||
try:
|
||||
if not self._settings.redis_enabled:
|
||||
return
|
||||
self._redis_client = redis.from_url(
|
||||
self._settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
saved_raw = self._redis_client.get(EMBEDDING_CONFIG_REDIS_KEY)
|
||||
if not saved_raw:
|
||||
return
|
||||
saved = json.loads(saved_raw)
|
||||
self._provider_name = saved.get("provider", self._default_provider)
|
||||
self._config = saved.get("config", self._default_config.copy())
|
||||
logger.info(f"Loaded embedding config from Redis: provider={self._provider_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load embedding config from Redis: {e}")
|
||||
|
||||
def _load_from_file(self) -> None:
|
||||
"""Load configuration from file if exists."""
|
||||
try:
|
||||
|
|
@ -184,6 +212,28 @@ class EmbeddingConfigManager:
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to load embedding config from file: {e}")
|
||||
|
||||
def _save_to_redis(self) -> None:
|
||||
"""Save configuration to Redis."""
|
||||
try:
|
||||
if not self._settings.redis_enabled:
|
||||
return
|
||||
if self._redis_client is None:
|
||||
self._redis_client = redis.from_url(
|
||||
self._settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
self._redis_client.set(
|
||||
EMBEDDING_CONFIG_REDIS_KEY,
|
||||
json.dumps({
|
||||
"provider": self._provider_name,
|
||||
"config": self._config,
|
||||
}, ensure_ascii=False),
|
||||
)
|
||||
logger.info(f"Saved embedding config to Redis: provider={self._provider_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save embedding config to Redis: {e}")
|
||||
|
||||
def _save_to_file(self) -> None:
|
||||
"""Save configuration to file."""
|
||||
try:
|
||||
|
|
@ -262,6 +312,7 @@ class EmbeddingConfigManager:
|
|||
self._config = config
|
||||
self._provider = new_provider_instance
|
||||
|
||||
self._save_to_redis()
|
||||
self._save_to_file()
|
||||
|
||||
logger.info(f"Updated embedding config: provider={provider}")
|
||||
|
|
|
|||
|
|
@ -322,7 +322,7 @@ class FlowEngine:
|
|||
stmt = select(FlowInstance).where(
|
||||
FlowInstance.tenant_id == tenant_id,
|
||||
FlowInstance.session_id == session_id,
|
||||
).order_by(col(FlowInstance.created_at).desc())
|
||||
).order_by(col(FlowInstance.started_at).desc())
|
||||
result = await self._session.execute(stmt)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
|
|
|
|||
|
|
@ -106,6 +106,8 @@ class IntentRuleService:
|
|||
is_enabled=True,
|
||||
hit_count=0,
|
||||
metadata_=create_data.metadata_,
|
||||
intent_vector=create_data.intent_vector,
|
||||
semantic_examples=create_data.semantic_examples,
|
||||
)
|
||||
self._session.add(rule)
|
||||
await self._session.flush()
|
||||
|
|
@ -195,6 +197,10 @@ class IntentRuleService:
|
|||
rule.is_enabled = update_data.is_enabled
|
||||
if update_data.metadata_ is not None:
|
||||
rule.metadata_ = update_data.metadata_
|
||||
if update_data.intent_vector is not None:
|
||||
rule.intent_vector = update_data.intent_vector
|
||||
if update_data.semantic_examples is not None:
|
||||
rule.semantic_examples = update_data.semantic_examples
|
||||
|
||||
rule.updated_at = datetime.utcnow()
|
||||
await self._session.flush()
|
||||
|
|
@ -267,7 +273,7 @@ class IntentRuleService:
|
|||
select(IntentRule)
|
||||
.where(
|
||||
IntentRule.tenant_id == tenant_id,
|
||||
IntentRule.is_enabled == True,
|
||||
IntentRule.is_enabled == True, # noqa: E712
|
||||
)
|
||||
.order_by(col(IntentRule.priority).desc())
|
||||
)
|
||||
|
|
@ -300,6 +306,8 @@ class IntentRuleService:
|
|||
"is_enabled": rule.is_enabled,
|
||||
"hit_count": rule.hit_count,
|
||||
"metadata": rule.metadata_,
|
||||
"created_at": rule.created_at.isoformat(),
|
||||
"updated_at": rule.updated_at.isoformat(),
|
||||
"intent_vector": rule.intent_vector,
|
||||
"semantic_examples": rule.semantic_examples,
|
||||
"created_at": rule.created_at.isoformat() if rule.created_at else None,
|
||||
"updated_at": rule.updated_at.isoformat() if rule.updated_at else None,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ class KBService:
|
|||
file_name: str,
|
||||
file_content: bytes,
|
||||
file_type: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> tuple[Document, IndexJob]:
|
||||
"""
|
||||
[AC-ASA-01] Upload document and create indexing job.
|
||||
|
|
@ -108,6 +109,7 @@ class KBService:
|
|||
file_size=len(file_content),
|
||||
file_type=file_type,
|
||||
status=DocumentStatus.PENDING.value,
|
||||
doc_metadata=metadata,
|
||||
)
|
||||
self._session.add(document)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,14 @@ LLM Adapter module for AI Service.
|
|||
[AC-AISVC-02, AC-AISVC-06] Provides unified interface for LLM providers.
|
||||
"""
|
||||
|
||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
|
||||
from app.services.llm.base import (
|
||||
LLMClient,
|
||||
LLMConfig,
|
||||
LLMResponse,
|
||||
LLMStreamChunk,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
)
|
||||
from app.services.llm.openai_client import OpenAIClient
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -12,4 +19,6 @@ __all__ = [
|
|||
"LLMResponse",
|
||||
"LLMStreamChunk",
|
||||
"OpenAIClient",
|
||||
"ToolCall",
|
||||
"ToolDefinition",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -28,17 +28,45 @@ class LLMConfig:
|
|||
extra_params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""
|
||||
Represents a function call from the LLM.
|
||||
Used in Function Calling mode.
|
||||
"""
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
import json
|
||||
return {
|
||||
"id": self.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"arguments": json.dumps(self.arguments, ensure_ascii=False),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""
|
||||
Response from LLM generation.
|
||||
[AC-AISVC-02] Contains generated content and metadata.
|
||||
"""
|
||||
content: str
|
||||
model: str
|
||||
content: str | None = None
|
||||
model: str = ""
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
finish_reason: str = "stop"
|
||||
tool_calls: list[ToolCall] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if response contains tool calls."""
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -50,9 +78,33 @@ class LLMStreamChunk:
|
|||
delta: str
|
||||
model: str
|
||||
finish_reason: str | None = None
|
||||
tool_calls_delta: list[dict[str, Any]] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""
|
||||
Tool definition for Function Calling.
|
||||
Compatible with OpenAI/DeepSeek function calling format.
|
||||
"""
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict[str, Any]
|
||||
type: str = "function"
|
||||
|
||||
def to_openai_format(self) -> dict[str, Any]:
|
||||
"""Convert to OpenAI tools format."""
|
||||
return {
|
||||
"type": self.type,
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class LLMClient(ABC):
|
||||
"""
|
||||
Abstract base class for LLM clients.
|
||||
|
|
@ -67,6 +119,8 @@ class LLMClient(ABC):
|
|||
self,
|
||||
messages: list[dict[str, str]],
|
||||
config: LLMConfig | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
|
|
@ -76,10 +130,12 @@ class LLMClient(ABC):
|
|||
Args:
|
||||
messages: List of chat messages with 'role' and 'content'.
|
||||
config: Optional LLM configuration overrides.
|
||||
tools: Optional list of tools for function calling.
|
||||
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
|
||||
**kwargs: Additional provider-specific parameters.
|
||||
|
||||
Returns:
|
||||
LLMResponse with generated content and metadata.
|
||||
LLMResponse with generated content, tool_calls, and metadata.
|
||||
|
||||
Raises:
|
||||
LLMException: If generation fails.
|
||||
|
|
@ -91,6 +147,8 @@ class LLMClient(ABC):
|
|||
self,
|
||||
messages: list[dict[str, str]],
|
||||
config: LLMConfig | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncGenerator[LLMStreamChunk, None]:
|
||||
"""
|
||||
|
|
@ -100,6 +158,8 @@ class LLMClient(ABC):
|
|||
Args:
|
||||
messages: List of chat messages with 'role' and 'content'.
|
||||
config: Optional LLM configuration overrides.
|
||||
tools: Optional list of tools for function calling.
|
||||
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
|
||||
**kwargs: Additional provider-specific parameters.
|
||||
|
||||
Yields:
|
||||
|
|
|
|||
|
|
@ -11,12 +11,16 @@ from dataclasses import dataclass
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.services.llm.base import LLMClient, LLMConfig
|
||||
from app.services.llm.openai_client import OpenAIClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LLM_CONFIG_FILE = Path("config/llm_config.json")
|
||||
LLM_CONFIG_REDIS_KEY = "ai_service:config:llm"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -286,6 +290,8 @@ class LLMConfigManager:
|
|||
from app.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
self._settings = settings
|
||||
self._redis_client: redis.Redis | None = None
|
||||
|
||||
self._current_provider: str = settings.llm_provider
|
||||
self._current_config: dict[str, Any] = {
|
||||
|
|
@ -299,8 +305,75 @@ class LLMConfigManager:
|
|||
}
|
||||
self._client: LLMClient | None = None
|
||||
|
||||
self._load_from_redis()
|
||||
self._load_from_file()
|
||||
|
||||
def _load_from_redis(self) -> None:
|
||||
"""Load configuration from Redis if exists."""
|
||||
try:
|
||||
if not self._settings.redis_enabled:
|
||||
return
|
||||
self._redis_client = redis.from_url(
|
||||
self._settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
saved_raw = self._redis_client.get(LLM_CONFIG_REDIS_KEY)
|
||||
if not saved_raw:
|
||||
return
|
||||
saved = json.loads(saved_raw)
|
||||
self._current_provider = saved.get("provider", self._current_provider)
|
||||
saved_config = saved.get("config", {})
|
||||
if saved_config:
|
||||
self._current_config.update(saved_config)
|
||||
logger.info(f"[AC-ASA-16] Loaded LLM config from Redis: provider={self._current_provider}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-ASA-16] Failed to load LLM config from Redis: {e}")
|
||||
|
||||
def _save_to_redis(self) -> None:
|
||||
"""Save configuration to Redis."""
|
||||
try:
|
||||
if not self._settings.redis_enabled:
|
||||
return
|
||||
if self._redis_client is None:
|
||||
self._redis_client = redis.from_url(
|
||||
self._settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
self._redis_client.set(
|
||||
LLM_CONFIG_REDIS_KEY,
|
||||
json.dumps({
|
||||
"provider": self._current_provider,
|
||||
"config": self._current_config,
|
||||
}, ensure_ascii=False),
|
||||
)
|
||||
logger.info(f"[AC-ASA-16] Saved LLM config to Redis: provider={self._current_provider}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-ASA-16] Failed to save LLM config to Redis: {e}")
|
||||
|
||||
def _load_from_redis(self) -> None:
|
||||
"""Load configuration from Redis if exists."""
|
||||
try:
|
||||
if not self._settings.redis_enabled:
|
||||
return
|
||||
self._redis_client = redis.from_url(
|
||||
self._settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
saved_raw = self._redis_client.get(LLM_CONFIG_REDIS_KEY)
|
||||
if not saved_raw:
|
||||
return
|
||||
saved = json.loads(saved_raw)
|
||||
self._current_provider = saved.get("provider", self._current_provider)
|
||||
saved_config = saved.get("config", {})
|
||||
if saved_config:
|
||||
self._current_config.update(saved_config)
|
||||
logger.info(f"[AC-ASA-16] Loaded LLM config from Redis: provider={self._current_provider}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-ASA-16] Failed to load LLM config from Redis: {e}")
|
||||
|
||||
def _load_from_file(self) -> None:
|
||||
"""Load configuration from file if exists."""
|
||||
try:
|
||||
|
|
@ -364,6 +437,7 @@ class LLMConfigManager:
|
|||
self._current_provider = provider
|
||||
self._current_config = validated_config
|
||||
|
||||
self._save_to_redis()
|
||||
self._save_to_file()
|
||||
|
||||
logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")
|
||||
|
|
|
|||
|
|
@ -22,7 +22,14 @@ from tenacity import (
|
|||
|
||||
from app.core.config import get_settings
|
||||
from app.core.exceptions import AIServiceException, ErrorCode, TimeoutException
|
||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
|
||||
from app.services.llm.base import (
|
||||
LLMClient,
|
||||
LLMConfig,
|
||||
LLMResponse,
|
||||
LLMStreamChunk,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -95,6 +102,8 @@ class OpenAIClient(LLMClient):
|
|||
messages: list[dict[str, str]],
|
||||
config: LLMConfig,
|
||||
stream: bool = False,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Build request body for OpenAI API."""
|
||||
|
|
@ -106,6 +115,13 @@ class OpenAIClient(LLMClient):
|
|||
"top_p": config.top_p,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
if tools:
|
||||
body["tools"] = [tool.to_openai_format() for tool in tools]
|
||||
|
||||
if tool_choice:
|
||||
body["tool_choice"] = tool_choice
|
||||
|
||||
body.update(config.extra_params)
|
||||
body.update(kwargs)
|
||||
return body
|
||||
|
|
@ -119,6 +135,8 @@ class OpenAIClient(LLMClient):
|
|||
self,
|
||||
messages: list[dict[str, str]],
|
||||
config: LLMConfig | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
|
|
@ -128,10 +146,12 @@ class OpenAIClient(LLMClient):
|
|||
Args:
|
||||
messages: List of chat messages with 'role' and 'content'.
|
||||
config: Optional LLM configuration overrides.
|
||||
tools: Optional list of tools for function calling.
|
||||
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
|
||||
**kwargs: Additional provider-specific parameters.
|
||||
|
||||
Returns:
|
||||
LLMResponse with generated content and metadata.
|
||||
LLMResponse with generated content, tool_calls, and metadata.
|
||||
|
||||
Raises:
|
||||
LLMException: If generation fails.
|
||||
|
|
@ -140,9 +160,14 @@ class OpenAIClient(LLMClient):
|
|||
effective_config = config or self._default_config
|
||||
client = self._get_client(effective_config.timeout_seconds)
|
||||
|
||||
body = self._build_request_body(messages, effective_config, stream=False, **kwargs)
|
||||
body = self._build_request_body(
|
||||
messages, effective_config, stream=False,
|
||||
tools=tools, tool_choice=tool_choice, **kwargs
|
||||
)
|
||||
|
||||
logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}")
|
||||
if tools:
|
||||
logger.info(f"[AC-AISVC-02] Function calling enabled with {len(tools)} tools")
|
||||
logger.info("[AC-AISVC-02] ========== FULL PROMPT TO AI ==========")
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role", "unknown")
|
||||
|
|
@ -177,14 +202,18 @@ class OpenAIClient(LLMClient):
|
|||
|
||||
try:
|
||||
choice = data["choices"][0]
|
||||
content = choice["message"]["content"]
|
||||
message = choice["message"]
|
||||
content = message.get("content")
|
||||
usage = data.get("usage", {})
|
||||
finish_reason = choice.get("finish_reason", "stop")
|
||||
|
||||
tool_calls = self._parse_tool_calls(message)
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-02] Generated response: "
|
||||
f"tokens={usage.get('total_tokens', 'N/A')}, "
|
||||
f"finish_reason={finish_reason}"
|
||||
f"finish_reason={finish_reason}, "
|
||||
f"tool_calls={len(tool_calls)}"
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
|
|
@ -192,6 +221,7 @@ class OpenAIClient(LLMClient):
|
|||
model=data.get("model", effective_config.model),
|
||||
usage=usage,
|
||||
finish_reason=finish_reason,
|
||||
tool_calls=tool_calls,
|
||||
metadata={"raw_response": data},
|
||||
)
|
||||
|
||||
|
|
@ -201,11 +231,34 @@ class OpenAIClient(LLMClient):
|
|||
message=f"Unexpected LLM response format: {e}",
|
||||
details=[{"response": str(data)}],
|
||||
)
|
||||
|
||||
def _parse_tool_calls(self, message: dict[str, Any]) -> list[ToolCall]:
|
||||
"""Parse tool calls from LLM response message."""
|
||||
tool_calls = []
|
||||
raw_tool_calls = message.get("tool_calls", [])
|
||||
|
||||
for tc in raw_tool_calls:
|
||||
if tc.get("type") == "function":
|
||||
func = tc.get("function", {})
|
||||
try:
|
||||
arguments = json.loads(func.get("arguments", "{}"))
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
|
||||
tool_calls.append(ToolCall(
|
||||
id=tc.get("id", ""),
|
||||
name=func.get("name", ""),
|
||||
arguments=arguments,
|
||||
))
|
||||
|
||||
return tool_calls
|
||||
|
||||
async def stream_generate(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
config: LLMConfig | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncGenerator[LLMStreamChunk, None]:
|
||||
"""
|
||||
|
|
@ -215,6 +268,8 @@ class OpenAIClient(LLMClient):
|
|||
Args:
|
||||
messages: List of chat messages with 'role' and 'content'.
|
||||
config: Optional LLM configuration overrides.
|
||||
tools: Optional list of tools for function calling.
|
||||
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
|
||||
**kwargs: Additional provider-specific parameters.
|
||||
|
||||
Yields:
|
||||
|
|
@ -227,9 +282,14 @@ class OpenAIClient(LLMClient):
|
|||
effective_config = config or self._default_config
|
||||
client = self._get_client(effective_config.timeout_seconds)
|
||||
|
||||
body = self._build_request_body(messages, effective_config, stream=True, **kwargs)
|
||||
body = self._build_request_body(
|
||||
messages, effective_config, stream=True,
|
||||
tools=tools, tool_choice=tool_choice, **kwargs
|
||||
)
|
||||
|
||||
logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}")
|
||||
if tools:
|
||||
logger.info(f"[AC-AISVC-06] Function calling enabled with {len(tools)} tools")
|
||||
logger.info("[AC-AISVC-06] ========== FULL PROMPT TO AI (STREAMING) ==========")
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role", "unknown")
|
||||
|
|
|
|||
|
|
@ -39,6 +39,19 @@ class MetadataFieldDefinitionService:
|
|||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self._session = session
|
||||
|
||||
async def _invalidate_cache(self, tenant_id: str) -> None:
|
||||
"""
|
||||
清除租户的元数据字段缓存
|
||||
在字段创建、更新、删除时调用
|
||||
"""
|
||||
try:
|
||||
from app.services.metadata_cache_service import get_metadata_cache_service
|
||||
cache_service = await get_metadata_cache_service()
|
||||
await cache_service.invalidate(tenant_id)
|
||||
except Exception as e:
|
||||
# 缓存失效失败不影响主流程
|
||||
logger.warning(f"[AC-IDSMETA-13] Failed to invalidate cache: {e}")
|
||||
|
||||
async def list_field_definitions(
|
||||
self,
|
||||
|
|
@ -180,6 +193,9 @@ class MetadataFieldDefinitionService:
|
|||
self._session.add(field)
|
||||
await self._session.flush()
|
||||
|
||||
# 清除缓存,使新字段在下次查询时生效
|
||||
await self._invalidate_cache(tenant_id)
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDSMETA-13] [AC-MRS-01] Created field definition: tenant={tenant_id}, "
|
||||
f"field_key={field.field_key}, status={field.status}, field_roles={field.field_roles}"
|
||||
|
|
@ -223,6 +239,10 @@ class MetadataFieldDefinitionService:
|
|||
field.is_filterable = field_update.is_filterable
|
||||
if field_update.is_rank_feature is not None:
|
||||
field.is_rank_feature = field_update.is_rank_feature
|
||||
# [AC-MRS-01] 修复:添加 field_roles 更新逻辑
|
||||
if field_update.field_roles is not None:
|
||||
self._validate_field_roles(field_update.field_roles)
|
||||
field.field_roles = field_update.field_roles
|
||||
if field_update.status is not None:
|
||||
old_status = field.status
|
||||
field.status = field_update.status
|
||||
|
|
@ -235,6 +255,9 @@ class MetadataFieldDefinitionService:
|
|||
field.updated_at = datetime.utcnow()
|
||||
await self._session.flush()
|
||||
|
||||
# 清除缓存,使更新在下次查询时生效
|
||||
await self._invalidate_cache(tenant_id)
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDSMETA-14] Updated field definition: tenant={tenant_id}, "
|
||||
f"field_id={field_id}, version={field.version}"
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ RAG Optimization (rag-optimization/spec.md):
|
|||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sse_starlette.sse import ServerSentEvent
|
||||
|
|
@ -46,7 +46,6 @@ from app.services.flow.engine import FlowEngine
|
|||
from app.services.guardrail.behavior_service import BehaviorRuleService
|
||||
from app.services.guardrail.input_scanner import InputScanner
|
||||
from app.services.guardrail.output_filter import OutputFilter
|
||||
from app.services.guardrail.word_service import ForbiddenWordService
|
||||
from app.services.intent.router import IntentRouter
|
||||
from app.services.intent.rule_service import IntentRuleService
|
||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse
|
||||
|
|
@ -90,6 +89,8 @@ class GenerationContext:
|
|||
10. confidence_result: Confidence calculation result
|
||||
11. messages_saved: Whether messages were saved
|
||||
12. final_response: Final ChatResponse
|
||||
|
||||
[v0.8.0] Extended with route_trace for hybrid routing observability.
|
||||
"""
|
||||
tenant_id: str
|
||||
session_id: str
|
||||
|
|
@ -115,6 +116,11 @@ class GenerationContext:
|
|||
target_kb_ids: list[str] | None = None
|
||||
behavior_rules: list[str] = field(default_factory=list)
|
||||
|
||||
# [v0.8.0] Hybrid routing fields
|
||||
route_trace: dict[str, Any] | None = None
|
||||
fusion_confidence: float | None = None
|
||||
fusion_decision_reason: str | None = None
|
||||
|
||||
diagnostics: dict[str, Any] = field(default_factory=dict)
|
||||
execution_steps: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
|
@ -487,7 +493,7 @@ class OrchestratorService:
|
|||
finish_reason="flow_step",
|
||||
)
|
||||
ctx.diagnostics["flow_handled"] = True
|
||||
logger.info(f"[AC-AISVC-75] Flow provided reply, skipping LLM")
|
||||
logger.info("[AC-AISVC-75] Flow provided reply, skipping LLM")
|
||||
|
||||
else:
|
||||
ctx.diagnostics["flow_check_enabled"] = True
|
||||
|
|
@ -501,8 +507,8 @@ class OrchestratorService:
|
|||
"""
|
||||
[AC-AISVC-69, AC-AISVC-70] Step 3: Match intent rules and route.
|
||||
Routes to: fixed reply, RAG with target KBs, flow start, or transfer.
|
||||
[v0.8.0] Upgraded to use match_hybrid() for hybrid routing.
|
||||
"""
|
||||
# Skip if flow already handled the request
|
||||
if ctx.diagnostics.get("flow_handled"):
|
||||
logger.info("[AC-AISVC-69] Flow already handled, skipping intent matching")
|
||||
return
|
||||
|
|
@ -513,7 +519,6 @@ class OrchestratorService:
|
|||
return
|
||||
|
||||
try:
|
||||
# Load enabled rules ordered by priority
|
||||
async with get_session() as session:
|
||||
from app.services.intent.rule_service import IntentRuleService
|
||||
rule_service = IntentRuleService(session)
|
||||
|
|
@ -524,33 +529,64 @@ class OrchestratorService:
|
|||
ctx.diagnostics["intent_matched"] = False
|
||||
return
|
||||
|
||||
# Match intent
|
||||
ctx.intent_match = self._intent_router.match(
|
||||
fusion_result = await self._intent_router.match_hybrid(
|
||||
message=ctx.current_message,
|
||||
rules=rules,
|
||||
tenant_id=ctx.tenant_id,
|
||||
)
|
||||
|
||||
if ctx.intent_match:
|
||||
ctx.route_trace = fusion_result.trace.to_dict()
|
||||
ctx.fusion_confidence = fusion_result.final_confidence
|
||||
ctx.fusion_decision_reason = fusion_result.decision_reason
|
||||
|
||||
if fusion_result.final_intent:
|
||||
ctx.intent_match = type(
|
||||
"IntentMatchResult",
|
||||
(),
|
||||
{
|
||||
"rule": fusion_result.final_intent,
|
||||
"match_type": fusion_result.decision_reason,
|
||||
"matched": "",
|
||||
"to_dict": lambda: {
|
||||
"rule_id": str(fusion_result.final_intent.id),
|
||||
"rule_name": fusion_result.final_intent.name,
|
||||
"match_type": fusion_result.decision_reason,
|
||||
"matched": "",
|
||||
"response_type": fusion_result.final_intent.response_type,
|
||||
"target_kb_ids": (
|
||||
fusion_result.final_intent.target_kb_ids or []
|
||||
),
|
||||
"flow_id": (
|
||||
str(fusion_result.final_intent.flow_id)
|
||||
if fusion_result.final_intent.flow_id else None
|
||||
),
|
||||
"fixed_reply": fusion_result.final_intent.fixed_reply,
|
||||
"transfer_message": fusion_result.final_intent.transfer_message,
|
||||
},
|
||||
},
|
||||
)()
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-69] Intent matched: rule={ctx.intent_match.rule.name}, "
|
||||
f"response_type={ctx.intent_match.rule.response_type}"
|
||||
f"[AC-AISVC-69] Intent matched: rule={fusion_result.final_intent.name}, "
|
||||
f"response_type={fusion_result.final_intent.response_type}, "
|
||||
f"decision={fusion_result.decision_reason}, "
|
||||
f"confidence={fusion_result.final_confidence:.3f}"
|
||||
)
|
||||
|
||||
ctx.diagnostics["intent_match"] = ctx.intent_match.to_dict()
|
||||
ctx.diagnostics["fusion_result"] = fusion_result.to_dict()
|
||||
|
||||
# Increment hit count
|
||||
async with get_session() as session:
|
||||
rule_service = IntentRuleService(session)
|
||||
await rule_service.increment_hit_count(
|
||||
tenant_id=ctx.tenant_id,
|
||||
rule_id=ctx.intent_match.rule.id,
|
||||
rule_id=fusion_result.final_intent.id,
|
||||
)
|
||||
|
||||
# Route based on response_type
|
||||
if ctx.intent_match.rule.response_type == "fixed":
|
||||
# Fixed reply - skip LLM
|
||||
rule = fusion_result.final_intent
|
||||
if rule.response_type == "fixed":
|
||||
ctx.llm_response = LLMResponse(
|
||||
content=ctx.intent_match.rule.fixed_reply or "收到您的消息。",
|
||||
content=rule.fixed_reply or "收到您的消息。",
|
||||
model="intent_fixed",
|
||||
usage={},
|
||||
finish_reason="intent_fixed",
|
||||
|
|
@ -558,20 +594,18 @@ class OrchestratorService:
|
|||
ctx.diagnostics["intent_handled"] = True
|
||||
logger.info("[AC-AISVC-70] Intent fixed reply, skipping LLM")
|
||||
|
||||
elif ctx.intent_match.rule.response_type == "rag":
|
||||
# RAG with target KBs
|
||||
ctx.target_kb_ids = ctx.intent_match.rule.target_kb_ids or []
|
||||
elif rule.response_type == "rag":
|
||||
ctx.target_kb_ids = rule.target_kb_ids or []
|
||||
logger.info(f"[AC-AISVC-70] Intent RAG, target_kb_ids={ctx.target_kb_ids}")
|
||||
|
||||
elif ctx.intent_match.rule.response_type == "flow":
|
||||
# Start script flow
|
||||
if ctx.intent_match.rule.flow_id and self._flow_engine:
|
||||
elif rule.response_type == "flow":
|
||||
if rule.flow_id and self._flow_engine:
|
||||
async with get_session() as session:
|
||||
flow_engine = FlowEngine(session)
|
||||
instance, first_step = await flow_engine.start(
|
||||
tenant_id=ctx.tenant_id,
|
||||
session_id=ctx.session_id,
|
||||
flow_id=ctx.intent_match.rule.flow_id,
|
||||
flow_id=rule.flow_id,
|
||||
)
|
||||
if first_step:
|
||||
ctx.llm_response = LLMResponse(
|
||||
|
|
@ -583,10 +617,9 @@ class OrchestratorService:
|
|||
ctx.diagnostics["intent_handled"] = True
|
||||
logger.info("[AC-AISVC-70] Intent flow started, skipping LLM")
|
||||
|
||||
elif ctx.intent_match.rule.response_type == "transfer":
|
||||
# Transfer to human
|
||||
elif rule.response_type == "transfer":
|
||||
ctx.llm_response = LLMResponse(
|
||||
content=ctx.intent_match.rule.transfer_message or "正在为您转接人工客服...",
|
||||
content=rule.transfer_message or "正在为您转接人工客服...",
|
||||
model="intent_transfer",
|
||||
usage={},
|
||||
finish_reason="intent_transfer",
|
||||
|
|
@ -600,9 +633,25 @@ class OrchestratorService:
|
|||
ctx.diagnostics["intent_handled"] = True
|
||||
logger.info("[AC-AISVC-70] Intent transfer, skipping LLM")
|
||||
|
||||
if fusion_result.need_clarify:
|
||||
ctx.diagnostics["need_clarify"] = True
|
||||
ctx.diagnostics["clarify_candidates"] = [
|
||||
{"id": str(r.id), "name": r.name}
|
||||
for r in (fusion_result.clarify_candidates or [])
|
||||
]
|
||||
logger.info(
|
||||
f"[AC-AISVC-121] Low confidence, need clarify: "
|
||||
f"confidence={fusion_result.final_confidence:.3f}, "
|
||||
f"candidates={len(fusion_result.clarify_candidates or [])}"
|
||||
)
|
||||
|
||||
else:
|
||||
ctx.diagnostics["intent_match_enabled"] = True
|
||||
ctx.diagnostics["intent_matched"] = False
|
||||
ctx.diagnostics["fusion_result"] = fusion_result.to_dict()
|
||||
logger.info(
|
||||
f"[AC-AISVC-69] No intent matched, decision={fusion_result.decision_reason}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-AISVC-69] Intent matching failed: {e}")
|
||||
|
|
@ -724,43 +773,43 @@ class OrchestratorService:
|
|||
async def _build_metadata_filters(self, ctx: GenerationContext):
|
||||
"""
|
||||
[AC-IDSMETA-19] Build metadata filters from context.
|
||||
|
||||
|
||||
Sources:
|
||||
1. Intent rule metadata (if matched)
|
||||
2. Session metadata
|
||||
3. Request metadata
|
||||
4. Extracted slots from conversation
|
||||
|
||||
|
||||
Returns:
|
||||
TagFilter with at least grade, subject, scene if available
|
||||
"""
|
||||
from app.services.retrieval.metadata import TagFilter
|
||||
|
||||
|
||||
filter_fields = {}
|
||||
|
||||
|
||||
# 1. From intent rule metadata
|
||||
if ctx.intent_match and hasattr(ctx.intent_match.rule, 'metadata_') and ctx.intent_match.rule.metadata_:
|
||||
intent_metadata = ctx.intent_match.rule.metadata_
|
||||
for key in ['grade', 'subject', 'scene']:
|
||||
if key in intent_metadata:
|
||||
filter_fields[key] = intent_metadata[key]
|
||||
|
||||
|
||||
# 2. From session/request metadata
|
||||
if ctx.request_metadata:
|
||||
for key in ['grade', 'subject', 'scene']:
|
||||
if key in ctx.request_metadata and key not in filter_fields:
|
||||
filter_fields[key] = ctx.request_metadata[key]
|
||||
|
||||
|
||||
# 3. From merged context (extracted slots)
|
||||
if ctx.merged_context and hasattr(ctx.merged_context, 'slots'):
|
||||
slots = ctx.merged_context.slots or {}
|
||||
for key in ['grade', 'subject', 'scene']:
|
||||
if key in slots and key not in filter_fields:
|
||||
filter_fields[key] = slots[key]
|
||||
|
||||
|
||||
if not filter_fields:
|
||||
return None
|
||||
|
||||
|
||||
return TagFilter(fields=filter_fields)
|
||||
|
||||
async def _build_system_prompt(self, ctx: GenerationContext) -> None:
|
||||
|
|
@ -981,11 +1030,11 @@ class OrchestratorService:
|
|||
"根据知识库信息,我找到了一些相关内容,"
|
||||
"但暂时无法生成完整回复。建议您稍后重试或联系人工客服。"
|
||||
)
|
||||
|
||||
|
||||
# [AC-IDSMETA-20] Record structured fallback reason code
|
||||
fallback_reason_code = self._determine_fallback_reason_code(ctx)
|
||||
ctx.diagnostics["fallback_reason_code"] = fallback_reason_code
|
||||
|
||||
|
||||
logger.warning(
|
||||
f"[AC-IDSMETA-20] No recall, using fallback: "
|
||||
f"intent_id={ctx.intent_match.rule.id if ctx.intent_match else None}, "
|
||||
|
|
@ -993,7 +1042,7 @@ class OrchestratorService:
|
|||
f"applied_metadata_filters={ctx.diagnostics.get('retrieval', {}).get('applied_metadata_filters')}, "
|
||||
f"fallback_reason_code={fallback_reason_code}"
|
||||
)
|
||||
|
||||
|
||||
return (
|
||||
"抱歉,我暂时无法处理您的请求。"
|
||||
"请稍后重试或联系人工客服获取帮助。"
|
||||
|
|
@ -1002,7 +1051,7 @@ class OrchestratorService:
|
|||
def _determine_fallback_reason_code(self, ctx: GenerationContext) -> str:
|
||||
"""
|
||||
[AC-IDSMETA-20] Determine structured fallback reason code.
|
||||
|
||||
|
||||
Reason codes:
|
||||
- no_recall_after_metadata_filter: No results after applying metadata filters
|
||||
- no_recall_no_kb: No target knowledge bases configured
|
||||
|
|
@ -1011,27 +1060,27 @@ class OrchestratorService:
|
|||
- no_recall_error: Retrieval error occurred
|
||||
"""
|
||||
retrieval_diag = ctx.diagnostics.get("retrieval", {})
|
||||
|
||||
|
||||
# Check for retrieval error
|
||||
if ctx.diagnostics.get("retrieval_error"):
|
||||
return "no_recall_error"
|
||||
|
||||
|
||||
# Check if metadata filters were applied
|
||||
if retrieval_diag.get("applied_metadata_filters"):
|
||||
return "no_recall_after_metadata_filter"
|
||||
|
||||
|
||||
# Check if target KBs were configured
|
||||
if not ctx.target_kb_ids:
|
||||
return "no_recall_no_kb"
|
||||
|
||||
|
||||
# Check if KB is empty (no candidates at all)
|
||||
if retrieval_diag.get("total_candidates", 0) == 0:
|
||||
return "no_recall_kb_empty"
|
||||
|
||||
|
||||
# Results found but filtered out by score threshold
|
||||
if retrieval_diag.get("total_candidates", 0) > 0 and retrieval_diag.get("filtered_hits", 0) == 0:
|
||||
return "no_recall_low_score"
|
||||
|
||||
|
||||
return "no_recall_unknown"
|
||||
|
||||
def _calculate_confidence(self, ctx: GenerationContext) -> None:
|
||||
|
|
@ -1122,6 +1171,7 @@ class OrchestratorService:
|
|||
[AC-AISVC-02] Build final ChatResponse from generation context.
|
||||
Step 12 of the 12-step pipeline.
|
||||
Uses filtered_reply from Step 9.
|
||||
[v0.8.0] Includes route_trace in response metadata.
|
||||
"""
|
||||
# Use filtered_reply if available, otherwise use llm_response.content
|
||||
if ctx.filtered_reply:
|
||||
|
|
@ -1142,6 +1192,10 @@ class OrchestratorService:
|
|||
"execution_steps": ctx.execution_steps,
|
||||
}
|
||||
|
||||
# [v0.8.0] Include route_trace in response metadata
|
||||
if ctx.route_trace:
|
||||
response_metadata["route_trace"] = ctx.route_trace
|
||||
|
||||
return ChatResponse(
|
||||
reply=reply,
|
||||
confidence=confidence,
|
||||
|
|
|
|||
|
|
@ -178,6 +178,9 @@ class PromptTemplateService:
|
|||
current_version = v
|
||||
break
|
||||
|
||||
# Get latest version for current_content (not just published)
|
||||
latest_version = versions[0] if versions else None
|
||||
|
||||
return {
|
||||
"id": str(template.id),
|
||||
"name": template.name,
|
||||
|
|
@ -185,6 +188,8 @@ class PromptTemplateService:
|
|||
"description": template.description,
|
||||
"is_default": template.is_default,
|
||||
"metadata": template.metadata_,
|
||||
"current_content": latest_version.system_instruction if latest_version else None,
|
||||
"variables": latest_version.variables if latest_version else [],
|
||||
"current_version": {
|
||||
"version": current_version.version,
|
||||
"status": current_version.status,
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ class RetrievalContext:
|
|||
metadata: dict[str, Any] | None = None
|
||||
tag_filter: "TagFilter | None" = None
|
||||
kb_ids: list[str] | None = None
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
|
||||
def get_tag_filter_dict(self) -> dict[str, str | list[str] | None] | None:
|
||||
"""获取标签过滤器的字典表示"""
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ Vector retriever for AI Service.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
||||
|
|
@ -76,16 +77,30 @@ class VectorRetriever(BaseRetriever):
|
|||
query_vector = await self._get_embedding(ctx.query)
|
||||
logger.info(f"[AC-AISVC-16] Embedding generated: dim={len(query_vector)}")
|
||||
|
||||
logger.info(f"[AC-AISVC-16] Searching in tenant collection: tenant_id={ctx.tenant_id}")
|
||||
hits = await client.search(
|
||||
tenant_id=ctx.tenant_id,
|
||||
query_vector=query_vector,
|
||||
limit=self._top_k,
|
||||
score_threshold=self._score_threshold,
|
||||
)
|
||||
|
||||
logger.info(f"[AC-AISVC-16] Search returned {len(hits)} raw hits")
|
||||
logger.info(f"[AC-AISVC-16] Searching in tenant collections: tenant_id={ctx.tenant_id}")
|
||||
if ctx.kb_ids:
|
||||
logger.info(f"[AC-AISVC-16] Restricting search to KB IDs: {ctx.kb_ids}")
|
||||
hits = await client.search(
|
||||
tenant_id=ctx.tenant_id,
|
||||
query_vector=query_vector,
|
||||
limit=self._top_k,
|
||||
score_threshold=self._score_threshold,
|
||||
vector_name="full",
|
||||
metadata_filter=ctx.metadata_filter,
|
||||
kb_ids=ctx.kb_ids,
|
||||
)
|
||||
else:
|
||||
hits = await client.search(
|
||||
tenant_id=ctx.tenant_id,
|
||||
query_vector=query_vector,
|
||||
limit=self._top_k,
|
||||
score_threshold=self._score_threshold,
|
||||
vector_name="full",
|
||||
metadata_filter=ctx.metadata_filter,
|
||||
)
|
||||
|
||||
logger.info(f"[AC-AISVC-16] Search returned {len(hits)} hits")
|
||||
|
||||
retrieval_hits = [
|
||||
RetrievalHit(
|
||||
text=hit.get("payload", {}).get("text", ""),
|
||||
|
|
@ -133,6 +148,47 @@ class VectorRetriever(BaseRetriever):
|
|||
diagnostics={"error": str(e), "is_insufficient": True},
|
||||
)
|
||||
|
||||
def _apply_metadata_filter(
|
||||
self,
|
||||
hits: list[dict[str, Any]],
|
||||
metadata_filter: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
应用元数据过滤条件。
|
||||
|
||||
支持的操作:
|
||||
- {"$eq": value} : 等于
|
||||
- {"$in": [values]} : 在列表中
|
||||
"""
|
||||
filtered = []
|
||||
for hit in hits:
|
||||
payload = hit.get("payload", {})
|
||||
hit_metadata = payload.get("metadata", {})
|
||||
|
||||
match = True
|
||||
for field_key, condition in metadata_filter.items():
|
||||
hit_value = hit_metadata.get(field_key)
|
||||
|
||||
if isinstance(condition, dict):
|
||||
if "$eq" in condition:
|
||||
if hit_value != condition["$eq"]:
|
||||
match = False
|
||||
break
|
||||
elif "$in" in condition:
|
||||
if hit_value not in condition["$in"]:
|
||||
match = False
|
||||
break
|
||||
else:
|
||||
# 直接值比较
|
||||
if hit_value != condition:
|
||||
match = False
|
||||
break
|
||||
|
||||
if match:
|
||||
filtered.append(hit)
|
||||
|
||||
return filtered
|
||||
|
||||
async def _get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
Generate embedding for text using pluggable embedding provider.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Slot Definition Service.
|
||||
[AC-MRS-07, AC-MRS-08] 槽位定义管理服务
|
||||
[AC-MRS-07-UPGRADE] 支持提取策略链 extract_strategies
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -27,6 +28,7 @@ logger = logging.getLogger(__name__)
|
|||
class SlotDefinitionService:
|
||||
"""
|
||||
[AC-MRS-07, AC-MRS-08] 槽位定义服务
|
||||
[AC-MRS-07-UPGRADE] 支持提取策略链管理
|
||||
|
||||
管理独立的槽位定义模型,与元数据字段解耦但可复用
|
||||
"""
|
||||
|
|
@ -114,6 +116,58 @@ class SlotDefinitionService:
|
|||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def _validate_strategies(self, strategies: list[str] | None) -> tuple[bool, str]:
|
||||
"""
|
||||
[AC-MRS-07-UPGRADE] 校验提取策略链的有效性
|
||||
|
||||
Args:
|
||||
strategies: 策略链列表
|
||||
|
||||
Returns:
|
||||
Tuple of (是否有效, 错误信息)
|
||||
"""
|
||||
if strategies is None:
|
||||
return True, ""
|
||||
|
||||
if not isinstance(strategies, list):
|
||||
return False, "extract_strategies 必须是数组类型"
|
||||
|
||||
if len(strategies) == 0:
|
||||
return False, "提取策略链不能为空数组"
|
||||
|
||||
# 校验不允许重复策略
|
||||
if len(strategies) != len(set(strategies)):
|
||||
return False, "提取策略链中不允许重复的策略"
|
||||
|
||||
# 校验策略值有效
|
||||
invalid = [s for s in strategies if s not in self.VALID_EXTRACT_STRATEGIES]
|
||||
if invalid:
|
||||
return False, f"无效的提取策略: {invalid},有效值为: {self.VALID_EXTRACT_STRATEGIES}"
|
||||
|
||||
return True, ""
|
||||
|
||||
def _normalize_strategies(
|
||||
self,
|
||||
extract_strategies: list[str] | None,
|
||||
extract_strategy: str | None,
|
||||
) -> list[str] | None:
|
||||
"""
|
||||
[AC-MRS-07-UPGRADE] 规范化提取策略
|
||||
优先使用 extract_strategies,如果不存在则使用 extract_strategy
|
||||
|
||||
Args:
|
||||
extract_strategies: 策略链(新字段)
|
||||
extract_strategy: 单策略(旧字段,兼容)
|
||||
|
||||
Returns:
|
||||
规范化后的策略链或 None
|
||||
"""
|
||||
if extract_strategies is not None:
|
||||
return extract_strategies
|
||||
if extract_strategy:
|
||||
return [extract_strategy]
|
||||
return None
|
||||
|
||||
async def create_slot_definition(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
|
@ -121,6 +175,7 @@ class SlotDefinitionService:
|
|||
) -> SlotDefinition:
|
||||
"""
|
||||
[AC-MRS-07, AC-MRS-08] 创建槽位定义
|
||||
[AC-MRS-07-UPGRADE] 支持提取策略链
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
|
|
@ -148,11 +203,16 @@ class SlotDefinitionService:
|
|||
f"有效类型为: {self.VALID_TYPES}"
|
||||
)
|
||||
|
||||
if slot_create.extract_strategy and slot_create.extract_strategy not in self.VALID_EXTRACT_STRATEGIES:
|
||||
raise ValueError(
|
||||
f"无效的提取策略 '{slot_create.extract_strategy}',"
|
||||
f"有效策略为: {self.VALID_EXTRACT_STRATEGIES}"
|
||||
)
|
||||
# [AC-MRS-07-UPGRADE] 规范化并校验提取策略链
|
||||
strategies = self._normalize_strategies(
|
||||
slot_create.extract_strategies,
|
||||
slot_create.extract_strategy
|
||||
)
|
||||
|
||||
if strategies is not None:
|
||||
is_valid, error_msg = self._validate_strategies(strategies)
|
||||
if not is_valid:
|
||||
raise ValueError(f"提取策略链校验失败: {error_msg}")
|
||||
|
||||
linked_field = None
|
||||
if slot_create.linked_field_id:
|
||||
|
|
@ -162,12 +222,22 @@ class SlotDefinitionService:
|
|||
f"[AC-MRS-08] 关联的元数据字段 '{slot_create.linked_field_id}' 不存在"
|
||||
)
|
||||
|
||||
# [AC-MRS-07-UPGRADE] 确定要保存的旧字段值
|
||||
# 如果前端提交了 extract_strategies,则使用第一个作为旧字段值
|
||||
old_strategy = slot_create.extract_strategy
|
||||
if not old_strategy and strategies and len(strategies) > 0:
|
||||
old_strategy = strategies[0]
|
||||
|
||||
slot = SlotDefinition(
|
||||
tenant_id=tenant_id,
|
||||
slot_key=slot_create.slot_key,
|
||||
display_name=slot_create.display_name,
|
||||
description=slot_create.description,
|
||||
type=slot_create.type,
|
||||
required=slot_create.required,
|
||||
extract_strategy=slot_create.extract_strategy,
|
||||
# [AC-MRS-07-UPGRADE] 同时保存新旧字段
|
||||
extract_strategy=old_strategy,
|
||||
extract_strategies=strategies,
|
||||
validation_rule=slot_create.validation_rule,
|
||||
ask_back_prompt=slot_create.ask_back_prompt,
|
||||
default_value=slot_create.default_value,
|
||||
|
|
@ -180,6 +250,7 @@ class SlotDefinitionService:
|
|||
logger.info(
|
||||
f"[AC-MRS-07] Created slot definition: tenant={tenant_id}, "
|
||||
f"slot_key={slot.slot_key}, required={slot.required}, "
|
||||
f"strategies={strategies}, "
|
||||
f"linked_field_id={slot.linked_field_id}"
|
||||
)
|
||||
|
||||
|
|
@ -193,6 +264,7 @@ class SlotDefinitionService:
|
|||
) -> SlotDefinition | None:
|
||||
"""
|
||||
更新槽位定义
|
||||
[AC-MRS-07-UPGRADE] 支持提取策略链更新
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
|
|
@ -206,6 +278,12 @@ class SlotDefinitionService:
|
|||
if not slot:
|
||||
return None
|
||||
|
||||
if slot_update.display_name is not None:
|
||||
slot.display_name = slot_update.display_name
|
||||
|
||||
if slot_update.description is not None:
|
||||
slot.description = slot_update.description
|
||||
|
||||
if slot_update.type is not None:
|
||||
if slot_update.type not in self.VALID_TYPES:
|
||||
raise ValueError(
|
||||
|
|
@ -217,13 +295,28 @@ class SlotDefinitionService:
|
|||
if slot_update.required is not None:
|
||||
slot.required = slot_update.required
|
||||
|
||||
if slot_update.extract_strategy is not None:
|
||||
if slot_update.extract_strategy not in self.VALID_EXTRACT_STRATEGIES:
|
||||
raise ValueError(
|
||||
f"无效的提取策略 '{slot_update.extract_strategy}',"
|
||||
f"有效策略为: {self.VALID_EXTRACT_STRATEGIES}"
|
||||
)
|
||||
slot.extract_strategy = slot_update.extract_strategy
|
||||
# [AC-MRS-07-UPGRADE] 处理提取策略链更新
|
||||
# 如果传入了 extract_strategies 或 extract_strategy,则更新
|
||||
if slot_update.extract_strategies is not None or slot_update.extract_strategy is not None:
|
||||
strategies = self._normalize_strategies(
|
||||
slot_update.extract_strategies,
|
||||
slot_update.extract_strategy
|
||||
)
|
||||
|
||||
if strategies is not None:
|
||||
is_valid, error_msg = self._validate_strategies(strategies)
|
||||
if not is_valid:
|
||||
raise ValueError(f"提取策略链校验失败: {error_msg}")
|
||||
|
||||
# [AC-MRS-07-UPGRADE] 同时更新新旧字段
|
||||
slot.extract_strategies = strategies
|
||||
# 如果前端提交了 extract_strategy,则使用它;否则使用策略链的第一个
|
||||
if slot_update.extract_strategy is not None:
|
||||
slot.extract_strategy = slot_update.extract_strategy
|
||||
elif strategies and len(strategies) > 0:
|
||||
slot.extract_strategy = strategies[0]
|
||||
else:
|
||||
slot.extract_strategy = None
|
||||
|
||||
if slot_update.validation_rule is not None:
|
||||
slot.validation_rule = slot_update.validation_rule
|
||||
|
|
@ -250,7 +343,7 @@ class SlotDefinitionService:
|
|||
|
||||
logger.info(
|
||||
f"[AC-MRS-07] Updated slot definition: tenant={tenant_id}, "
|
||||
f"slot_id={slot_id}"
|
||||
f"slot_id={slot_id}, strategies={slot.extract_strategies}"
|
||||
)
|
||||
|
||||
return slot
|
||||
|
|
@ -331,9 +424,13 @@ class SlotDefinitionService:
|
|||
"id": str(slot.id),
|
||||
"tenant_id": slot.tenant_id,
|
||||
"slot_key": slot.slot_key,
|
||||
"display_name": slot.display_name,
|
||||
"description": slot.description,
|
||||
"type": slot.type,
|
||||
"required": slot.required,
|
||||
# [AC-MRS-07-UPGRADE] 返回新旧字段
|
||||
"extract_strategy": slot.extract_strategy,
|
||||
"extract_strategies": slot.extract_strategies,
|
||||
"validation_rule": slot.validation_rule,
|
||||
"ask_back_prompt": slot.ask_back_prompt,
|
||||
"default_value": slot.default_value,
|
||||
|
|
|
|||
Loading…
Reference in New Issue