feat: update core backend services including LLM, embedding, KB, orchestrator and admin APIs [AC-AISVC-CORE]
This commit is contained in:
parent
759eafb490
commit
fe883cfff0
|
|
@ -2,6 +2,7 @@
|
||||||
Admin API routes for AI Service management.
|
Admin API routes for AI Service management.
|
||||||
[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08, AC-AISVC-50] Admin management endpoints.
|
[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08, AC-AISVC-50] Admin management endpoints.
|
||||||
[AC-MRS-07,08,16] Slot definition management endpoints.
|
[AC-MRS-07,08,16] Slot definition management endpoints.
|
||||||
|
[AC-SCENE-SLOT-01] Scene slot bundle management endpoints.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.api.admin.api_key import router as api_key_router
|
from app.api.admin.api_key import router as api_key_router
|
||||||
|
|
@ -18,6 +19,7 @@ from app.api.admin.metadata_schema import router as metadata_schema_router
|
||||||
from app.api.admin.monitoring import router as monitoring_router
|
from app.api.admin.monitoring import router as monitoring_router
|
||||||
from app.api.admin.prompt_templates import router as prompt_templates_router
|
from app.api.admin.prompt_templates import router as prompt_templates_router
|
||||||
from app.api.admin.rag import router as rag_router
|
from app.api.admin.rag import router as rag_router
|
||||||
|
from app.api.admin.scene_slot_bundle import router as scene_slot_bundle_router
|
||||||
from app.api.admin.script_flows import router as script_flows_router
|
from app.api.admin.script_flows import router as script_flows_router
|
||||||
from app.api.admin.sessions import router as sessions_router
|
from app.api.admin.sessions import router as sessions_router
|
||||||
from app.api.admin.slot_definition import router as slot_definition_router
|
from app.api.admin.slot_definition import router as slot_definition_router
|
||||||
|
|
@ -38,6 +40,7 @@ __all__ = [
|
||||||
"monitoring_router",
|
"monitoring_router",
|
||||||
"prompt_templates_router",
|
"prompt_templates_router",
|
||||||
"rag_router",
|
"rag_router",
|
||||||
|
"scene_slot_bundle_router",
|
||||||
"script_flows_router",
|
"script_flows_router",
|
||||||
"sessions_router",
|
"sessions_router",
|
||||||
"slot_definition_router",
|
"slot_definition_router",
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@
|
||||||
Intent Rule Management API.
|
Intent Rule Management API.
|
||||||
[AC-AISVC-65~AC-AISVC-68] Intent rule CRUD endpoints.
|
[AC-AISVC-65~AC-AISVC-68] Intent rule CRUD endpoints.
|
||||||
[AC-AISVC-96] Intent rule testing endpoint.
|
[AC-AISVC-96] Intent rule testing endpoint.
|
||||||
|
[AC-AISVC-116] Fusion config management endpoints.
|
||||||
|
[AC-AISVC-114] Intent vector generation endpoint.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -14,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.database import get_session
|
from app.core.database import get_session
|
||||||
from app.models.entities import IntentRuleCreate, IntentRuleUpdate
|
from app.models.entities import IntentRuleCreate, IntentRuleUpdate
|
||||||
|
from app.services.intent.models import DEFAULT_FUSION_CONFIG, FusionConfig
|
||||||
from app.services.intent.rule_service import IntentRuleService
|
from app.services.intent.rule_service import IntentRuleService
|
||||||
from app.services.intent.tester import IntentRuleTester
|
from app.services.intent.tester import IntentRuleTester
|
||||||
|
|
||||||
|
|
@ -21,6 +24,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/admin/intent-rules", tags=["Intent Rules"])
|
router = APIRouter(prefix="/admin/intent-rules", tags=["Intent Rules"])
|
||||||
|
|
||||||
|
_fusion_config = FusionConfig()
|
||||||
|
|
||||||
|
|
||||||
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
|
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
|
||||||
"""Extract tenant ID from header."""
|
"""Extract tenant ID from header."""
|
||||||
|
|
@ -204,3 +209,109 @@ async def test_rule(
|
||||||
result = await tester.test_rule(rule, [body.message], all_rules)
|
result = await tester.test_rule(rule, [body.message], all_rules)
|
||||||
|
|
||||||
return result.to_dict()
|
return result.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
class FusionConfigUpdate(BaseModel):
    """Partial-update request body for the fusion configuration.

    Every field is optional and defaults to ``None``. The PUT
    ``/fusion-config`` handler in this module dumps the model with
    ``exclude_none=True``, so a field left as ``None`` keeps its current
    value instead of being overwritten.
    """

    # Weights (w_* fields).
    w_rule: float | None = None
    w_semantic: float | None = None
    w_llm: float | None = None

    # Thresholds. NOTE(review): valid ranges are not enforced here — confirm
    # whether FusionConfig.from_dict validates them.
    semantic_threshold: float | None = None
    conflict_threshold: float | None = None
    gray_zone_threshold: float | None = None
    min_trigger_threshold: float | None = None
    clarify_threshold: float | None = None
    multi_intent_threshold: float | None = None

    # Feature toggles.
    llm_judge_enabled: bool | None = None
    semantic_matcher_enabled: bool | None = None

    # Timeouts (milliseconds, per the ``_ms`` suffix) and semantic top-k.
    semantic_matcher_timeout_ms: int | None = None
    llm_judge_timeout_ms: int | None = None
    semantic_top_k: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): this route is registered after the rule CRUD routes. If a
# ``GET /{rule_id}`` handler is declared earlier on this router it would match
# "/fusion-config" first and fail UUID validation — confirm registration order.
@router.get("/fusion-config")
async def get_fusion_config() -> dict[str, Any]:
    """
    [AC-AISVC-116] Get current fusion configuration.

    Returns:
        The module-level fusion config serialized via ``to_dict()``.
    """
    logger.info("[AC-AISVC-116] Getting fusion config")
    snapshot = _fusion_config.to_dict()
    return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/fusion-config")
async def update_fusion_config(
    body: FusionConfigUpdate,
) -> dict[str, Any]:
    """
    [AC-AISVC-116] Update fusion configuration.

    Overlays the non-``None`` fields of ``body`` on the current module-level
    config, rebuilds it through ``FusionConfig.from_dict`` and rebinds the
    module global.

    Args:
        body: Partial update; fields left as ``None`` are not changed.

    Returns:
        The updated configuration serialized via ``to_dict()``.
    """
    # The config lives in a module global, so the rebinding below needs `global`.
    global _fusion_config

    logger.info(f"[AC-AISVC-116] Updating fusion config: {body.model_dump()}")

    # Later keys win: submitted values override the current ones.
    merged = {**_fusion_config.to_dict(), **body.model_dump(exclude_none=True)}
    _fusion_config = FusionConfig.from_dict(merged)

    return _fusion_config.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{rule_id}/generate-vector")
async def generate_intent_vector(
    rule_id: uuid.UUID,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-114] Generate intent vector for a rule.

    Embeds the rule's ``semantic_examples`` in one batch, averages the
    resulting vectors element-wise and stores the mean on the rule as
    ``intent_vector``.

    Args:
        rule_id: ID of the intent rule to update.
        tenant_id: Tenant taken from the ``X-Tenant-Id`` header.
        session: Database session used by ``IntentRuleService``.

    Returns:
        A dict with the rule ``id``, the stored ``intent_vector`` and the
        rule's ``semantic_examples``.

    Raises:
        HTTPException: 404 if the rule is not found, 400 if it has no
            ``semantic_examples``, 500 if embedding or persisting fails.
    """
    logger.info(
        f"[AC-AISVC-114] Generating intent vector for tenant={tenant_id}, rule_id={rule_id}"
    )

    service = IntentRuleService(session)
    rule = await service.get_rule(tenant_id, rule_id)

    if not rule:
        raise HTTPException(status_code=404, detail="Intent rule not found")

    if not rule.semantic_examples:
        raise HTTPException(
            status_code=400,
            detail="Rule has no semantic_examples. Please add semantic_examples first."
        )

    try:
        # Imported lazily, as in the original handler.
        import numpy as np

        from app.core.dependencies import get_embedding_provider

        embedding_provider = get_embedding_provider()
        vectors = await embedding_provider.embed_batch(rule.semantic_examples)

        # Element-wise mean across all example embeddings.
        avg_vector = np.mean(vectors, axis=0).tolist()

        update_data = IntentRuleUpdate(intent_vector=avg_vector)
        updated_rule = await service.update_rule(tenant_id, rule_id, update_data)

        # Presumably update_rule returns None when the rule vanished between
        # the lookup and the update (mirrors get_rule) — TODO confirm. Without
        # this guard the attribute access below turns into an opaque 500.
        if not updated_rule:
            raise HTTPException(status_code=404, detail="Intent rule not found")

        logger.info(
            f"[AC-AISVC-114] Generated intent vector for rule={rule_id}, "
            f"dimension={len(avg_vector)}"
        )

        return {
            "id": str(updated_rule.id),
            "intent_vector": updated_rule.intent_vector,
            "semantic_examples": updated_rule.semantic_examples,
        }

    except HTTPException:
        # Deliberate HTTP errors must not be rewritten into a 500.
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"[AC-AISVC-114] Failed to generate intent vector: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to generate intent vector: {str(e)}"
        ) from e
|
||||||
|
|
|
||||||
|
|
@ -6,28 +6,35 @@ Knowledge Base management endpoints.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
import json
|
||||||
|
import hashlib
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Annotated, Any, Optional
|
from typing import Annotated, Any, Optional
|
||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, Query, UploadFile
|
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
from app.core.database import get_session
|
from app.core.database import get_session
|
||||||
from app.core.exceptions import MissingTenantIdException
|
from app.core.exceptions import MissingTenantIdException
|
||||||
from app.core.tenant import get_tenant_id
|
from app.core.tenant import get_tenant_id
|
||||||
from app.models import ErrorResponse
|
from app.models import ErrorResponse
|
||||||
from app.models.entities import (
|
from app.models.entities import (
|
||||||
|
Document,
|
||||||
|
DocumentStatus,
|
||||||
IndexJob,
|
IndexJob,
|
||||||
IndexJobStatus,
|
IndexJobStatus,
|
||||||
KBType,
|
KBType,
|
||||||
|
KnowledgeBase,
|
||||||
KnowledgeBaseCreate,
|
KnowledgeBaseCreate,
|
||||||
KnowledgeBaseUpdate,
|
KnowledgeBaseUpdate,
|
||||||
)
|
)
|
||||||
from app.services.kb import KBService
|
from app.services.kb import KBService
|
||||||
from app.services.knowledge_base_service import KnowledgeBaseService
|
from app.services.knowledge_base_service import KnowledgeBaseService
|
||||||
|
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -457,6 +464,7 @@ async def list_documents(
|
||||||
"kbId": doc.kb_id,
|
"kbId": doc.kb_id,
|
||||||
"fileName": doc.file_name,
|
"fileName": doc.file_name,
|
||||||
"status": doc.status,
|
"status": doc.status,
|
||||||
|
"metadata": doc.doc_metadata,
|
||||||
"jobId": str(latest_job.id) if latest_job else None,
|
"jobId": str(latest_job.id) if latest_job else None,
|
||||||
"createdAt": doc.created_at.isoformat() + "Z",
|
"createdAt": doc.created_at.isoformat() + "Z",
|
||||||
"updatedAt": doc.updated_at.isoformat() + "Z",
|
"updatedAt": doc.updated_at.isoformat() + "Z",
|
||||||
|
|
@ -585,6 +593,7 @@ async def upload_document(
|
||||||
file_name=file.filename or "unknown",
|
file_name=file.filename or "unknown",
|
||||||
file_content=file_content,
|
file_content=file_content,
|
||||||
file_type=file.content_type,
|
file_type=file.content_type,
|
||||||
|
metadata=metadata_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
await kb_service.update_doc_count(tenant_id, kb_id, delta=1)
|
await kb_service.update_doc_count(tenant_id, kb_id, delta=1)
|
||||||
|
|
@ -915,3 +924,488 @@ async def delete_document(
|
||||||
"message": "Document deleted",
|
"message": "Document deleted",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put(
    "/documents/{doc_id}/metadata",
    operation_id="updateDocumentMetadata",
    summary="Update document metadata",
    description="[AC-ASA-08] Update metadata for a specific document.",
    responses={
        200: {"description": "Metadata updated"},
        404: {"description": "Document not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def update_document_metadata(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    doc_id: str,
    body: dict,
) -> JSONResponse:
    """
    [AC-ASA-08] Update document metadata.

    Replaces the stored metadata of one document belonging to the tenant.

    Args:
        tenant_id: Tenant resolved by the ``get_current_tenant_id`` dependency.
        session: Database session.
        doc_id: ID of the document to update.
        body: Request JSON; its ``"metadata"`` key holds either a JSON object,
            a JSON-encoded string, or is absent/``null`` (stored as ``None``).

    Returns:
        200 with the stored metadata, 400 (``INVALID_METADATA``) when the
        metadata is not valid JSON or not a JSON object, 404
        (``DOCUMENT_NOT_FOUND``) when no matching document exists.
    """
    metadata = body.get("metadata")

    # Clients may send the metadata as a JSON-encoded string; decode it first.
    if isinstance(metadata, str):
        try:
            metadata = json.loads(metadata)
        except json.JSONDecodeError:
            return JSONResponse(
                status_code=400,
                content={
                    "code": "INVALID_METADATA",
                    "message": "Invalid JSON format for metadata",
                },
            )

    # Previously lists/numbers (sent directly or produced by decoding the
    # string) were stored as-is; only an object or None is accepted now.
    if metadata is not None and not isinstance(metadata, dict):
        return JSONResponse(
            status_code=400,
            content={
                "code": "INVALID_METADATA",
                "message": "Metadata must be a JSON object",
            },
        )

    logger.info(
        f"[AC-ASA-08] Updating document metadata: tenant={tenant_id}, doc_id={doc_id}"
    )

    # Filtering on tenant_id as well keeps the lookup tenant-scoped.
    stmt = select(Document).where(
        Document.tenant_id == tenant_id,
        Document.id == doc_id,
    )
    result = await session.execute(stmt)
    document = result.scalar_one_or_none()

    if not document:
        return JSONResponse(
            status_code=404,
            content={
                "code": "DOCUMENT_NOT_FOUND",
                "message": f"Document {doc_id} not found",
            },
        )

    document.doc_metadata = metadata
    await session.commit()

    return JSONResponse(
        content={
            "success": True,
            "message": "Metadata updated",
            "metadata": document.doc_metadata,
        }
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/documents/batch-upload",
    operation_id="batchUploadDocuments",
    summary="Batch upload documents from zip",
    description="Upload a zip file containing multiple folders, each with a markdown file and metadata.json",
    responses={
        200: {"description": "Batch upload result"},
        400: {"description": "Bad Request - invalid zip or missing files"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def batch_upload_documents(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    kb_id: str = Form(...),
) -> JSONResponse:
    """
    Batch upload documents from a zip file.

    Zip structure:
    - Each folder contains one content file (.md, .markdown or .txt) and,
      optionally, a metadata.json.
    - metadata.json uses field_key from MetadataFieldDefinition as keys.

    Example metadata.json:
    {
        "grade": "Grade 10",
        "subject": "Math",
        "type": "Pain point"
    }

    Each folder is processed independently: it is validated, uploaded,
    committed and queued for indexing on its own, and a failure is recorded
    in the per-folder results without aborting the rest of the batch.

    Args:
        tenant_id: Tenant resolved by the ``get_current_tenant_id`` dependency.
        session: Database session shared by all services in this request.
        background_tasks: Queue that receives one indexing task per document.
        file: The uploaded ``.zip`` archive.
        kb_id: Target knowledge base ID.

    Returns:
        200 with ``total`` / ``succeeded`` / ``failed`` counters and per-folder
        ``results``; 400 for a non-zip upload, an invalid archive or an
        archive without content files; 404 when the knowledge base is missing.
    """
    import tempfile
    import zipfile
    from pathlib import Path

    logger.info(
        f"[BATCH-UPLOAD] Starting batch upload: tenant={tenant_id}, "
        f"kb_id={kb_id}, filename={file.filename}"
    )

    if not file.filename or not file.filename.lower().endswith('.zip'):
        return JSONResponse(
            status_code=400,
            content={
                "code": "INVALID_FORMAT",
                "message": "Only .zip files are supported",
            },
        )

    kb_service = KnowledgeBaseService(session)
    kb = await kb_service.get_knowledge_base(tenant_id, kb_id)
    if not kb:
        return JSONResponse(
            status_code=404,
            content={
                "code": "KB_NOT_FOUND",
                "message": f"Knowledge base {kb_id} not found",
            },
        )

    file_content = await file.read()

    results = []
    succeeded = 0
    failed = 0

    with tempfile.TemporaryDirectory() as temp_dir:
        extracted_path = Path(temp_dir)
        zip_path = extracted_path / "upload.zip"
        with open(zip_path, "wb") as f:
            f.write(file_content)

        # NOTE(review): the archive is extracted without any limit on member
        # count or uncompressed size; consider capping both for untrusted
        # uploads (decompression bombs).
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                zf.extractall(temp_dir)
        except zipfile.BadZipFile as e:
            return JSONResponse(
                status_code=400,
                content={
                    "code": "INVALID_ZIP",
                    "message": f"Invalid zip file: {str(e)}",
                },
            )

        # List everything that was extracted, for debugging.
        all_items = list(extracted_path.iterdir())
        logger.info(f"[BATCH-UPLOAD] Extracted items: {[item.name for item in all_items]}")

        def list_content_files(folder: Path) -> list[Path]:
            """Return the folder's content files: .md, then .markdown, then .txt."""
            found: list[Path] = []
            for pattern in ("*.md", "*.markdown", "*.txt"):
                # glob() order is filesystem-dependent; sort so the file that
                # gets picked for a folder is deterministic.
                found.extend(sorted(folder.glob(pattern)))
            return found

        def find_document_folders(path: Path) -> list[tuple[Path, list[Path]]]:
            """Recursively collect (folder, content files) for folders holding documents."""
            doc_folders: list[tuple[Path, list[Path]]] = []

            # Does this folder itself contain document files?
            content_files = list_content_files(path)
            if content_files:
                doc_folders.append((path, content_files))
                logger.info(f"[BATCH-UPLOAD] Found document folder: {path.name}, files: {[f.name for f in content_files]}")

            # Recurse into sub-folders.
            for subfolder in sorted(p for p in path.iterdir() if p.is_dir()):
                doc_folders.extend(find_document_folders(subfolder))

            return doc_folders

        folders = find_document_folders(extracted_path)

        if not folders:
            logger.error(f"[BATCH-UPLOAD] No document folders found in zip. Items found: {[item.name for item in all_items]}")
            return JSONResponse(
                status_code=400,
                content={
                    "code": "NO_DOCUMENTS_FOUND",
                    "message": "压缩包中没有找到包含 .txt/.md 文件的文件夹",
                    "details": {
                        "expected_structure": "每个文件夹应包含 content.txt (或 .md) 和 metadata.json",
                        "found_items": [item.name for item in all_items],
                    },
                },
            )

        logger.info(f"[BATCH-UPLOAD] Found {len(folders)} document folders")

        # The services are stateless wrappers around the session; build once.
        field_def_service = MetadataFieldDefinitionService(session)
        doc_kb_service = KBService(session)

        for folder, content_files in folders:
            folder_name = folder.name if folder != extracted_path else "root"

            # Only the first content file of a folder is ingested.
            content_file = content_files[0]
            metadata_file = folder / "metadata.json"

            metadata_dict = {}
            if metadata_file.exists():
                try:
                    with open(metadata_file, 'r', encoding='utf-8') as f:
                        metadata_dict = json.load(f)
                except (json.JSONDecodeError, UnicodeDecodeError) as e:
                    # A non-UTF-8 metadata.json used to abort the whole batch.
                    failed += 1
                    results.append({
                        "folder": folder_name,
                        "status": "failed",
                        "error": f"Invalid metadata.json: {str(e)}",
                    })
                    continue

                if not isinstance(metadata_dict, dict):
                    failed += 1
                    results.append({
                        "folder": folder_name,
                        "status": "failed",
                        "error": "Invalid metadata.json: expected a JSON object",
                    })
                    continue
            else:
                logger.warning(f"[BATCH-UPLOAD] No metadata.json in folder {folder_name}, using empty metadata")

            try:
                # Validation runs inside the try so that an unexpected error
                # fails only this folder rather than the entire request.
                is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
                    tenant_id, metadata_dict, "kb_document"
                )

                if not is_valid:
                    failed += 1
                    results.append({
                        "folder": folder_name,
                        "status": "failed",
                        "error": f"Metadata validation failed: {validation_errors}",
                    })
                    continue

                with open(content_file, 'rb') as f:
                    doc_content = f.read()

                file_type = "text/plain" if content_file.suffix.lower() == '.txt' else "text/markdown"

                document, job = await doc_kb_service.upload_document(
                    tenant_id=tenant_id,
                    kb_id=kb_id,
                    file_name=content_file.name,
                    file_content=doc_content,
                    file_type=file_type,
                    metadata=metadata_dict,
                )

                await kb_service.update_doc_count(tenant_id, kb_id, delta=1)
                await session.commit()

                background_tasks.add_task(
                    _index_document,
                    tenant_id,
                    kb_id,
                    str(job.id),
                    str(document.id),
                    doc_content,
                    content_file.name,
                    metadata_dict,
                )

                succeeded += 1
                results.append({
                    "folder": folder_name,
                    "docId": str(document.id),
                    "jobId": str(job.id),
                    "status": "created",
                    "fileName": content_file.name,
                })

                logger.info(
                    f"[BATCH-UPLOAD] Created document: folder={folder_name}, "
                    f"doc_id={document.id}, job_id={job.id}"
                )

            except Exception as e:
                # Discard this folder's partial work; a session left in a
                # failed transaction would make every later folder fail too.
                await session.rollback()
                failed += 1
                results.append({
                    "folder": folder_name,
                    "status": "failed",
                    "error": str(e),
                })
                logger.exception(f"[BATCH-UPLOAD] Failed to create document: folder={folder_name}, error={e}")

    logger.info(
        f"[BATCH-UPLOAD] Completed: total={len(results)}, succeeded={succeeded}, failed={failed}"
    )

    return JSONResponse(
        content={
            "success": True,
            "total": len(results),
            "succeeded": succeeded,
            "failed": failed,
            "results": results,
        }
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/{kb_id}/documents/json-batch",
    summary="[AC-KB-03] JSON批量上传文档",
    description="上传JSONL格式文件,每行一个JSON对象,包含text和元数据字段",
)
async def upload_json_batch(
    kb_id: str,
    tenant_id: str = Query(..., description="租户ID"),
    file: UploadFile = File(..., description="JSONL格式文件,每行一个JSON对象"),
    session: AsyncSession = Depends(get_session),
    background_tasks: BackgroundTasks = None,
):
    """
    Batch upload documents from a JSONL file.

    File format: JSONL (one JSON object per line).
    Required field: ``text`` - the text content to ingest into the knowledge base.
    Optional fields: metadata fields (e.g. grade, subject, kb_scene).

    Example:
    {"text": "Course content...", "grade": "Grade 8", "subject": "Math", "kb_scene": "Course inquiry"}
    {"text": "More course content...", "grade": "Grade 9", "info_type": "Course overview"}

    Each line is uploaded and committed independently, so one bad line never
    invalidates the lines already reported as successful.

    NOTE(review): unlike the other endpoints in this module, the tenant comes
    from a query parameter rather than ``Depends(get_current_tenant_id)`` —
    confirm that this is intended.

    Args:
        kb_id: Target knowledge base ID (path parameter).
        tenant_id: Tenant ID (query parameter); must own the knowledge base.
        file: The uploaded JSONL file (UTF-8, with GBK as a fallback).
        session: Database session.
        background_tasks: Queue that receives one indexing task per document.

    Returns:
        JSON with ``total`` (number of lines, blank ones included),
        ``succeeded``, ``failed``, the tenant's ``valid_metadata_fields`` and
        per-line ``results``.

    Raises:
        HTTPException: 404 if the knowledge base does not exist, 403 if it
            belongs to another tenant, 400 for an undecodable or empty file.
    """
    kb = await session.get(KnowledgeBase, kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")

    if kb.tenant_id != tenant_id:
        raise HTTPException(status_code=403, detail="无权访问此知识库")

    # Best effort: when the field definitions cannot be loaded the set stays
    # empty and no metadata key is filtered out below.
    valid_field_keys = set()
    try:
        field_defs = await MetadataFieldDefinitionService(session).get_fields(
            tenant_id=tenant_id,
            include_inactive=False,
        )
        valid_field_keys = {f.field_key for f in field_defs}
        logger.info(f"[AC-KB-03] Valid metadata fields for tenant {tenant_id}: {valid_field_keys}")
    except Exception as e:
        logger.warning(f"[AC-KB-03] Failed to get metadata fields: {e}")

    content = await file.read()

    # Try UTF-8 first, then GBK.
    text_content = None
    for encoding in ("utf-8", "gbk"):
        try:
            text_content = content.decode(encoding)
            break
        except UnicodeDecodeError:
            continue
    if text_content is None:
        raise HTTPException(status_code=400, detail="文件编码不支持,请使用UTF-8编码")

    # "".split("\n") yields [""], so the emptiness check has to look at the
    # stripped text; testing the list of lines could never trigger.
    stripped_content = text_content.strip()
    if not stripped_content:
        raise HTTPException(status_code=400, detail="文件内容为空")

    # Deliberately split("\n") rather than splitlines(): the latter also breaks
    # on U+2028/U+2029, which may legally appear inside JSON strings.
    lines = stripped_content.split("\n")

    results = []
    succeeded = 0
    failed = 0

    doc_service = KBService(session)
    kb_service = KnowledgeBaseService(session)

    for line_num, line in enumerate(lines, 1):
        line = line.strip()
        if not line:
            continue

        try:
            json_obj = json.loads(line)
        except json.JSONDecodeError as e:
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": f"JSON解析失败: {e}",
            })
            continue

        # A valid JSON array/number/string has no .get(); report it per line
        # instead of letting an AttributeError abort the whole request.
        if not isinstance(json_obj, dict):
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": "JSON解析失败: 每行必须是一个JSON对象",
            })
            continue

        text = json_obj.get("text")
        if not text:
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": "缺少必填字段: text",
            })
            continue

        if not isinstance(text, str):
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": "字段 text 必须是字符串",
            })
            continue

        # Everything except "text" is metadata; unknown keys are dropped when
        # the tenant has field definitions, and null values are always dropped.
        metadata = {}
        for key, value in json_obj.items():
            if key == "text":
                continue
            if valid_field_keys and key not in valid_field_keys:
                logger.debug(f"[AC-KB-03] Skipping invalid metadata field: {key}")
                continue
            if value is not None:
                metadata[key] = value

        try:
            file_name = f"json_batch_line_{line_num}.txt"
            file_content = text.encode("utf-8")

            document, job = await doc_service.upload_document(
                tenant_id=tenant_id,
                kb_id=kb_id,
                file_name=file_name,
                file_content=file_content,
                file_type="text/plain",
                metadata=metadata,
            )

            # Read the IDs before committing, as the original code did.
            doc_id = str(document.id)
            job_id = str(job.id)

            # Keep the knowledge base's document count in step, as the single
            # and zip upload endpoints in this module do.
            await kb_service.update_doc_count(tenant_id, kb_id, delta=1)

            # Commit per line so a later failure cannot undo this document.
            await session.commit()

            if background_tasks is not None:
                background_tasks.add_task(
                    _index_document,
                    tenant_id,
                    kb_id,
                    job_id,
                    doc_id,
                    file_content,
                    file_name,
                    metadata,
                )

            succeeded += 1
            results.append({
                "line": line_num,
                "success": True,
                "doc_id": doc_id,
                "job_id": job_id,
                "metadata": metadata,
            })
        except Exception as e:
            # Reset the session; otherwise a failed flush would make every
            # following line fail as well.
            await session.rollback()
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": str(e),
            })
            logger.exception(f"[AC-KB-03] Failed to upload document at line {line_num}: {e}")

    logger.info(f"[AC-KB-03] JSON batch upload completed: kb_id={kb_id}, total={len(lines)}, succeeded={succeeded}, failed={failed}")

    return JSONResponse(
        content={
            "success": True,
            "total": len(lines),
            "succeeded": succeeded,
            "failed": failed,
            "valid_metadata_fields": list(valid_field_keys) if valid_field_keys else [],
            "results": results,
        }
    )
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,7 @@ def _field_to_dict(f: MetadataFieldDefinition) -> dict[str, Any]:
|
||||||
"scope": f.scope,
|
"scope": f.scope,
|
||||||
"is_filterable": f.is_filterable,
|
"is_filterable": f.is_filterable,
|
||||||
"is_rank_feature": f.is_rank_feature,
|
"is_rank_feature": f.is_rank_feature,
|
||||||
|
"usage_description": f.usage_description,
|
||||||
"field_roles": f.field_roles or [],
|
"field_roles": f.field_roles or [],
|
||||||
"status": f.status,
|
"status": f.status,
|
||||||
"version": f.version,
|
"version": f.version,
|
||||||
|
|
|
||||||
|
|
@ -407,6 +407,7 @@ async def get_conversation_detail(
|
||||||
"guardrailTriggered": user_msg.guardrail_triggered,
|
"guardrailTriggered": user_msg.guardrail_triggered,
|
||||||
"guardrailWords": user_msg.guardrail_words,
|
"guardrailWords": user_msg.guardrail_words,
|
||||||
"executionSteps": execution_steps,
|
"executionSteps": execution_steps,
|
||||||
|
"routeTrace": user_msg.route_trace,
|
||||||
"createdAt": user_msg.created_at.isoformat(),
|
"createdAt": user_msg.created_at.isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -659,8 +660,56 @@ async def _process_export(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[AC-AISVC-110] Export failed: task_id={task_id}, error={e}")
|
logger.error(f"[AC-AISVC-110] Export failed: task_id={task_id}, error={e}")
|
||||||
|
|
||||||
task = await session.get(ExportTask, task_id)
|
task = task_status.get(ExportTask, task_id)
|
||||||
if task:
|
if task:
|
||||||
task.status = ExportTaskStatus.FAILED.value
|
task.status = ExportTaskStatus.FAILED.value
|
||||||
task.error_message = str(e)
|
task.error_message = str(e)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/clarification-metrics")
async def get_clarification_metrics(
    tenant_id: str = Depends(get_tenant_id),
    total_requests: int = Query(100, ge=1, description="Total requests for rate calculation"),
) -> dict[str, Any]:
    """
    [AC-CLARIFY] Get clarification metrics.

    Args:
        tenant_id: Resolved through ``get_tenant_id``; not otherwise used here.
        total_requests: Denominator handed to ``get_rates`` (at least 1).

    Returns:
        A dict with ``counts`` (from ``get_metrics()``), ``rates`` (from
        ``get_rates(total_requests)``) and fixed ``thresholds``. Per the
        original notes the rates are expected to cover:
        - clarify_trigger_rate: how often clarification is triggered
        - clarify_converge_rate: how often a clarification converges
        - misroute_rate: how often a request enters the wrong flow
    """
    from app.services.intent.clarification import get_clarify_metrics

    collector = get_clarify_metrics()

    # NOTE(review): these thresholds are hard-coded in the response and are not
    # read from the collector — confirm they match the values actually in use.
    fixed_thresholds = {
        "t_high": 0.75,
        "t_low": 0.45,
        "max_retry": 3,
    }

    return {
        "counts": collector.get_metrics(),
        "rates": collector.get_rates(total_requests),
        "thresholds": fixed_thresholds,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/clarification-metrics/reset")
async def reset_clarification_metrics(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    [AC-CLARIFY] Reset clarification metrics.

    Args:
        tenant_id: Resolved through ``get_tenant_id``; not otherwise used here.

    Returns:
        A fixed acknowledgement payload with ``status`` and ``message``.
    """
    from app.services.intent.clarification import get_clarify_metrics

    # NOTE(review): the collector is obtained without a tenant argument, so
    # this reset looks process-wide rather than per-tenant — confirm.
    get_clarify_metrics().reset()

    acknowledgement = {
        "status": "reset",
        "message": "Clarification metrics have been reset.",
    }
    return acknowledgement
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,7 @@ class Settings(BaseSettings):
|
||||||
redis_enabled: bool = True
|
redis_enabled: bool = True
|
||||||
dashboard_cache_ttl: int = 60
|
dashboard_cache_ttl: int = 60
|
||||||
stats_counter_ttl: int = 7776000
|
stats_counter_ttl: int = 7776000
|
||||||
|
slot_state_cache_ttl: int = 1800
|
||||||
|
|
||||||
frontend_base_url: str = "http://localhost:3000"
|
frontend_base_url: str = "http://localhost:3000"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ engine = create_async_engine(
|
||||||
settings.database_url,
|
settings.database_url,
|
||||||
pool_size=settings.database_pool_size,
|
pool_size=settings.database_pool_size,
|
||||||
max_overflow=settings.database_max_overflow,
|
max_overflow=settings.database_max_overflow,
|
||||||
echo=settings.debug,
|
echo=False,
|
||||||
pool_pre_ping=True,
|
pool_pre_ping=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -114,7 +114,12 @@ class ApiKeyMiddleware(BaseHTTPMiddleware):
|
||||||
from app.core.database import async_session_maker
|
from app.core.database import async_session_maker
|
||||||
async with async_session_maker() as session:
|
async with async_session_maker() as session:
|
||||||
await service.initialize(session)
|
await service.initialize(session)
|
||||||
logger.info(f"[AC-AISVC-50] API key service lazy initialized with {len(service._keys_cache)} keys")
|
if service._initialized and len(service._keys_cache) > 0:
|
||||||
|
logger.info(f"[AC-AISVC-50] API key service lazy initialized with {len(service._keys_cache)} keys")
|
||||||
|
elif service._initialized and len(service._keys_cache) == 0:
|
||||||
|
logger.warning("[AC-AISVC-50] API key service initialized but no keys found in database")
|
||||||
|
else:
|
||||||
|
logger.error("[AC-AISVC-50] API key service lazy initialization failed")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[AC-AISVC-50] Failed to initialize API key service: {e}")
|
logger.error(f"[AC-AISVC-50] Failed to initialize API key service: {e}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -272,20 +272,24 @@ class QdrantClient:
|
||||||
score_threshold: float | None = None,
|
score_threshold: float | None = None,
|
||||||
vector_name: str = "full",
|
vector_name: str = "full",
|
||||||
with_vectors: bool = False,
|
with_vectors: bool = False,
|
||||||
|
metadata_filter: dict[str, Any] | None = None,
|
||||||
|
kb_ids: list[str] | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
[AC-AISVC-10] Search vectors in tenant's collection.
|
[AC-AISVC-10] Search vectors in tenant's collections.
|
||||||
Returns results with score >= score_threshold if specified.
|
Returns results with score >= score_threshold if specified.
|
||||||
Searches both old format (with @) and new format (with _) for backward compatibility.
|
Searches all collections for the tenant (multi-KB support).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tenant_id: Tenant identifier
|
tenant_id: Tenant identifier
|
||||||
query_vector: Query vector for similarity search
|
query_vector: Query vector for similarity search
|
||||||
limit: Maximum number of results
|
limit: Maximum number of results per collection
|
||||||
score_threshold: Minimum score threshold for results
|
score_threshold: Minimum score threshold for results
|
||||||
vector_name: Name of the vector to search (for multi-vector collections)
|
vector_name: Name of the vector to search (for multi-vector collections)
|
||||||
Default is "full" for 768-dim vectors in Matryoshka setup.
|
Default is "full" for 768-dim vectors in Matryoshka setup.
|
||||||
with_vectors: Whether to return vectors in results (for two-stage reranking)
|
with_vectors: Whether to return vectors in results (for two-stage reranking)
|
||||||
|
metadata_filter: Optional metadata filter to apply during search
|
||||||
|
kb_ids: Optional list of knowledge base IDs to restrict search to specific KBs
|
||||||
"""
|
"""
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
|
|
||||||
|
|
@ -293,21 +297,36 @@ class QdrantClient:
|
||||||
f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, "
|
f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, "
|
||||||
f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}"
|
f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}"
|
||||||
)
|
)
|
||||||
|
if metadata_filter:
|
||||||
|
logger.info(f"[AC-AISVC-10] Metadata filter: {metadata_filter}")
|
||||||
|
|
||||||
collection_names = [self.get_collection_name(tenant_id)]
|
# 构建 Qdrant filter
|
||||||
if '@' in tenant_id:
|
qdrant_filter = None
|
||||||
old_format = f"{self._collection_prefix}{tenant_id}"
|
if metadata_filter:
|
||||||
new_format = f"{self._collection_prefix}{tenant_id.replace('@', '_')}"
|
qdrant_filter = self._build_qdrant_filter(metadata_filter)
|
||||||
collection_names = [new_format, old_format]
|
logger.info(f"[AC-AISVC-10] Qdrant filter: {qdrant_filter}")
|
||||||
|
|
||||||
logger.info(f"[AC-AISVC-10] Will search in collections: {collection_names}")
|
# 获取该租户的所有 collections
|
||||||
|
collection_names = await self._get_tenant_collections(client, tenant_id)
|
||||||
|
|
||||||
|
# 如果指定了 kb_ids,则只搜索指定的知识库 collections
|
||||||
|
if kb_ids:
|
||||||
|
target_collections = []
|
||||||
|
for kb_id in kb_ids:
|
||||||
|
kb_collection_name = self.get_kb_collection_name(tenant_id, kb_id)
|
||||||
|
if kb_collection_name in collection_names:
|
||||||
|
target_collections.append(kb_collection_name)
|
||||||
|
else:
|
||||||
|
logger.warning(f"[AC-AISVC-10] KB collection not found: {kb_collection_name} for kb_id={kb_id}")
|
||||||
|
collection_names = target_collections
|
||||||
|
logger.info(f"[AC-AISVC-10] Restricted to {len(collection_names)} KB collections: {collection_names}")
|
||||||
|
else:
|
||||||
|
logger.info(f"[AC-AISVC-10] Will search in {len(collection_names)} collections: {collection_names}")
|
||||||
|
|
||||||
all_hits = []
|
all_hits = []
|
||||||
|
|
||||||
for collection_name in collection_names:
|
for collection_name in collection_names:
|
||||||
try:
|
try:
|
||||||
logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}")
|
|
||||||
|
|
||||||
exists = await client.collection_exists(collection_name)
|
exists = await client.collection_exists(collection_name)
|
||||||
if not exists:
|
if not exists:
|
||||||
logger.warning(f"[AC-AISVC-10] Collection {collection_name} does not exist")
|
logger.warning(f"[AC-AISVC-10] Collection {collection_name} does not exist")
|
||||||
|
|
@ -321,6 +340,7 @@ class QdrantClient:
|
||||||
limit=limit,
|
limit=limit,
|
||||||
with_vectors=with_vectors,
|
with_vectors=with_vectors,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
|
query_filter=qdrant_filter,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "vector name" in str(e).lower() or "Not existing vector" in str(e) or "using" in str(e).lower():
|
if "vector name" in str(e).lower() or "Not existing vector" in str(e) or "using" in str(e).lower():
|
||||||
|
|
@ -334,6 +354,7 @@ class QdrantClient:
|
||||||
limit=limit,
|
limit=limit,
|
||||||
with_vectors=with_vectors,
|
with_vectors=with_vectors,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
|
query_filter=qdrant_filter,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
@ -348,6 +369,7 @@ class QdrantClient:
|
||||||
"id": str(result.id),
|
"id": str(result.id),
|
||||||
"score": result.score,
|
"score": result.score,
|
||||||
"payload": result.payload or {},
|
"payload": result.payload or {},
|
||||||
|
"collection": collection_name, # 添加 collection 信息
|
||||||
}
|
}
|
||||||
if with_vectors and result.vector:
|
if with_vectors and result.vector:
|
||||||
hit["vector"] = result.vector
|
hit["vector"] = result.vector
|
||||||
|
|
@ -358,10 +380,6 @@ class QdrantClient:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}"
|
f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}"
|
||||||
)
|
)
|
||||||
for i, h in enumerate(hits[:3]):
|
|
||||||
logger.debug(
|
|
||||||
f"[AC-AISVC-10] Hit {i+1}: id={h['id']}, score={h['score']:.4f}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)"
|
f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)"
|
||||||
|
|
@ -370,9 +388,10 @@ class QdrantClient:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}"
|
f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}"
|
||||||
)
|
)
|
||||||
continue
|
|
||||||
|
|
||||||
all_hits = sorted(all_hits, key=lambda x: x["score"], reverse=True)[:limit]
|
# 按分数排序并返回 top results
|
||||||
|
all_hits.sort(key=lambda x: x["score"], reverse=True)
|
||||||
|
all_hits = all_hits[:limit]
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}"
|
f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}"
|
||||||
|
|
@ -386,6 +405,113 @@ class QdrantClient:
|
||||||
|
|
||||||
return all_hits
|
return all_hits
|
||||||
|
|
||||||
|
async def _get_tenant_collections(
    self,
    client: AsyncQdrantClient,
    tenant_id: str,
) -> list[str]:
    """
    List the names of all Qdrant collections that belong to a tenant.

    The Redis cache is consulted first; on a miss the names are read from
    Qdrant, sorted, and written back to the cache with a 5-minute TTL.
    Every cache interaction is best effort: a cache failure is logged and
    the lookup continues against Qdrant.

    Args:
        client: Qdrant client used to enumerate collections.
        tenant_id: Tenant identifier.

    Returns:
        Sorted collection names whose name starts with the tenant prefix.
        If the Qdrant listing fails, a single-element list holding
        ``self.get_collection_name(tenant_id)`` is returned instead.
    """
    import json
    import time

    from app.services.metadata_cache_service import get_metadata_cache_service

    # Monotonic clock: the value is only used for elapsed-time logging.
    started = time.perf_counter()
    cache_key = f"collections:{tenant_id}"

    # 1. Try the cache. Acquiring the cache service is inside the guard so
    #    that a cache outage degrades to a Qdrant lookup instead of failing
    #    the caller.
    cache_service = None
    try:
        cache_service = await get_metadata_cache_service()
        # Makes sure the Redis connection has been initialised.
        redis_client = await cache_service._get_redis()
        if redis_client and cache_service._enabled:
            cached = await redis_client.get(cache_key)
            if cached:
                collections = json.loads(cached)
                elapsed_ms = (time.perf_counter() - started) * 1000
                logger.info(
                    f"[AC-AISVC-10] Cache hit: Found {len(collections)} collections "
                    f"for tenant={tenant_id} in {elapsed_ms:.2f}ms"
                )
                return collections
    except Exception as e:
        logger.warning(f"[AC-AISVC-10] Cache get error: {e}")

    # 2. Query Qdrant.
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = f"{self._collection_prefix}{safe_tenant_id}"
    # NOTE(review): a plain prefix match also accepts collections of another
    # tenant whose sanitized id merely starts with this tenant's id --
    # confirm the collection naming scheme rules that out.

    try:
        collections = await client.get_collections()
        tenant_collections = sorted(
            c.name for c in collections.collections if c.name.startswith(prefix)
        )

        db_time = (time.perf_counter() - started) * 1000
        logger.info(
            f"[AC-AISVC-10] Found {len(tenant_collections)} collections from Qdrant "
            f"for tenant={tenant_id} in {db_time:.2f}ms: {tenant_collections}"
        )

        # 3. Cache the result (5-minute TTL). Skipped when the cache service
        #    could not be obtained in step 1.
        if cache_service is not None:
            try:
                redis_client = await cache_service._get_redis()
                if redis_client and cache_service._enabled:
                    await redis_client.setex(
                        cache_key,
                        300,  # 5 minutes
                        json.dumps(tenant_collections),
                    )
                    logger.info(f"[AC-AISVC-10] Cached collections for tenant={tenant_id}")
            except Exception as e:
                logger.warning(f"[AC-AISVC-10] Cache set error: {e}")

        return tenant_collections
    except Exception as e:
        logger.error(f"[AC-AISVC-10] Failed to get collections for tenant={tenant_id}: {e}")
        # Fall back to the tenant's default collection name.
        return [self.get_collection_name(tenant_id)]
|
||||||
|
|
||||||
|
def _build_qdrant_filter(
    self,
    metadata_filter: dict[str, Any],
) -> Any:
    """
    Translate a metadata filter dict into a Qdrant ``Filter``.

    Each entry becomes an exact-match condition on the payload path
    ``metadata.<key>``; all conditions are placed in the ``must`` clause.

    Args:
        metadata_filter: Metadata filter conditions, e.g.
            ``{"grade": "Grade 3", "subject": "Chinese"}``.

    Returns:
        A Qdrant ``Filter`` object, or ``None`` when no condition was built.
    """
    from qdrant_client.models import FieldCondition, Filter, MatchValue

    # Keys are addressed under the nested "metadata" payload object,
    # e.g. "grade" -> "metadata.grade".
    clauses = [
        FieldCondition(
            key=f"metadata.{name}",
            match=MatchValue(value=expected),
        )
        for name, expected in metadata_filter.items()
    ]

    if not clauses:
        return None
    return Filter(must=clauses)
|
||||||
|
|
||||||
async def delete_collection(self, tenant_id: str) -> bool:
|
async def delete_collection(self, tenant_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
[AC-AISVC-10] Delete tenant's collection.
|
[AC-AISVC-10] Delete tenant's collection.
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ from app.api.admin import (
|
||||||
monitoring_router,
|
monitoring_router,
|
||||||
prompt_templates_router,
|
prompt_templates_router,
|
||||||
rag_router,
|
rag_router,
|
||||||
|
scene_slot_bundle_router,
|
||||||
script_flows_router,
|
script_flows_router,
|
||||||
sessions_router,
|
sessions_router,
|
||||||
slot_definition_router,
|
slot_definition_router,
|
||||||
|
|
@ -55,6 +56,11 @@ logging.basicConfig(
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("sqlalchemy.dialects").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("sqlalchemy.orm").setLevel(logging.WARNING)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -88,6 +94,28 @@ async def lifespan(app: FastAPI):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[AC-AISVC-50] API key initialization FAILED: {e}", exc_info=True)
|
logger.error(f"[AC-AISVC-50] API key initialization FAILED: {e}", exc_info=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.services.mid.tool_guide_registry import init_tool_guide_registry
|
||||||
|
|
||||||
|
logger.info("[ToolGuideRegistry] Starting tool guides initialization...")
|
||||||
|
tool_guide_registry = init_tool_guide_registry()
|
||||||
|
logger.info(f"[ToolGuideRegistry] Tool guides loaded: {tool_guide_registry.list_tools()}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ToolRegistry] Tools initialization FAILED: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# [AC-AISVC-29] 预初始化 Embedding 服务,避免首次查询时的延迟
|
||||||
|
try:
|
||||||
|
from app.services.embedding import get_embedding_provider
|
||||||
|
|
||||||
|
logger.info("[AC-AISVC-29] Pre-initializing embedding service...")
|
||||||
|
embedding_provider = await get_embedding_provider()
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-29] Embedding service pre-initialized: "
|
||||||
|
f"provider={embedding_provider.PROVIDER_NAME}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[AC-AISVC-29] Embedding service pre-initialization FAILED: {e}", exc_info=True)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
await close_db()
|
await close_db()
|
||||||
|
|
@ -171,6 +199,7 @@ app.include_router(metadata_schema_router)
|
||||||
app.include_router(monitoring_router)
|
app.include_router(monitoring_router)
|
||||||
app.include_router(prompt_templates_router)
|
app.include_router(prompt_templates_router)
|
||||||
app.include_router(rag_router)
|
app.include_router(rag_router)
|
||||||
|
app.include_router(scene_slot_bundle_router)
|
||||||
app.include_router(script_flows_router)
|
app.include_router(script_flows_router)
|
||||||
app.include_router(sessions_router)
|
app.include_router(sessions_router)
|
||||||
app.include_router(slot_definition_router)
|
app.include_router(slot_definition_router)
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,7 @@ class ChatMessage(SQLModel, table=True):
|
||||||
[AC-AISVC-13] Chat message entity with tenant isolation.
|
[AC-AISVC-13] Chat message entity with tenant isolation.
|
||||||
Messages are scoped by (tenant_id, session_id) for multi-tenant security.
|
Messages are scoped by (tenant_id, session_id) for multi-tenant security.
|
||||||
[v0.7.0] Extended with monitoring fields for Dashboard statistics.
|
[v0.7.0] Extended with monitoring fields for Dashboard statistics.
|
||||||
|
[v0.8.0] Extended with route_trace for hybrid routing observability.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "chat_messages"
|
__tablename__ = "chat_messages"
|
||||||
|
|
@ -90,6 +91,11 @@ class ChatMessage(SQLModel, table=True):
|
||||||
sa_column=Column("guardrail_words", JSON, nullable=True),
|
sa_column=Column("guardrail_words", JSON, nullable=True),
|
||||||
description="[v0.7.0] Guardrail trigger details: words, categories, strategy"
|
description="[v0.7.0] Guardrail trigger details: words, categories, strategy"
|
||||||
)
|
)
|
||||||
|
route_trace: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column=Column("route_trace", JSON, nullable=True),
|
||||||
|
description="[v0.8.0] Intent routing trace log for hybrid routing observability"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatSessionCreate(SQLModel):
|
class ChatSessionCreate(SQLModel):
|
||||||
|
|
@ -227,6 +233,7 @@ class Document(SQLModel, table=True):
|
||||||
file_type: str | None = Field(default=None, description="File MIME type")
|
file_type: str | None = Field(default=None, description="File MIME type")
|
||||||
status: str = Field(default=DocumentStatus.PENDING.value, description="Document status")
|
status: str = Field(default=DocumentStatus.PENDING.value, description="Document status")
|
||||||
error_msg: str | None = Field(default=None, description="Error message if failed")
|
error_msg: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
doc_metadata: dict | None = Field(default=None, sa_type=JSON, description="Document metadata as JSON")
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Upload time")
|
created_at: datetime = Field(default_factory=datetime.utcnow, description="Upload time")
|
||||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||||
|
|
||||||
|
|
@ -421,6 +428,7 @@ class IntentRule(SQLModel, table=True):
|
||||||
[AC-AISVC-65] Intent rule entity with tenant isolation.
|
[AC-AISVC-65] Intent rule entity with tenant isolation.
|
||||||
Supports keyword and regex matching for intent recognition.
|
Supports keyword and regex matching for intent recognition.
|
||||||
[AC-IDSMETA-16] Extended with metadata field for unified storage structure.
|
[AC-IDSMETA-16] Extended with metadata field for unified storage structure.
|
||||||
|
[v0.8.0] Extended with intent_vector and semantic_examples for hybrid routing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "intent_rules"
|
__tablename__ = "intent_rules"
|
||||||
|
|
@ -458,6 +466,16 @@ class IntentRule(SQLModel, table=True):
|
||||||
sa_column=Column("metadata", JSON, nullable=True),
|
sa_column=Column("metadata", JSON, nullable=True),
|
||||||
description="[AC-IDSMETA-16] Structured metadata for the intent rule"
|
description="[AC-IDSMETA-16] Structured metadata for the intent rule"
|
||||||
)
|
)
|
||||||
|
intent_vector: list[float] | None = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column=Column("intent_vector", JSON, nullable=True),
|
||||||
|
description="[v0.8.0] Pre-computed intent vector for semantic matching"
|
||||||
|
)
|
||||||
|
semantic_examples: list[str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column=Column("semantic_examples", JSON, nullable=True),
|
||||||
|
description="[v0.8.0] Semantic example sentences for dynamic vector computation"
|
||||||
|
)
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
|
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
|
||||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||||
|
|
||||||
|
|
@ -475,6 +493,8 @@ class IntentRuleCreate(SQLModel):
|
||||||
fixed_reply: str | None = None
|
fixed_reply: str | None = None
|
||||||
transfer_message: str | None = None
|
transfer_message: str | None = None
|
||||||
metadata_: dict[str, Any] | None = None
|
metadata_: dict[str, Any] | None = None
|
||||||
|
intent_vector: list[float] | None = None
|
||||||
|
semantic_examples: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class IntentRuleUpdate(SQLModel):
|
class IntentRuleUpdate(SQLModel):
|
||||||
|
|
@ -491,6 +511,8 @@ class IntentRuleUpdate(SQLModel):
|
||||||
transfer_message: str | None = None
|
transfer_message: str | None = None
|
||||||
is_enabled: bool | None = None
|
is_enabled: bool | None = None
|
||||||
metadata_: dict[str, Any] | None = None
|
metadata_: dict[str, Any] | None = None
|
||||||
|
intent_vector: list[float] | None = None
|
||||||
|
semantic_examples: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class IntentMatchResult:
|
class IntentMatchResult:
|
||||||
|
|
@ -810,6 +832,24 @@ class FlowStep(SQLModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="RAG configuration for this step: {'enabled': true, 'tag_filter': {'grade': '${context.grade}', 'type': '痛点'}}"
|
description="RAG configuration for this step: {'enabled': true, 'tag_filter': {'grade': '${context.grade}', 'type': '痛点'}}"
|
||||||
)
|
)
|
||||||
|
allowed_kb_ids: list[str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="[Step-KB-Binding] Allowed knowledge base IDs for this step. If set, KB search will be restricted to these KBs."
|
||||||
|
)
|
||||||
|
preferred_kb_ids: list[str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="[Step-KB-Binding] Preferred knowledge base IDs for this step. These KBs will be searched first."
|
||||||
|
)
|
||||||
|
kb_query_hint: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="[Step-KB-Binding] Query hint for KB search in this step, helps improve retrieval accuracy."
|
||||||
|
)
|
||||||
|
max_kb_calls_per_step: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
ge=1,
|
||||||
|
le=5,
|
||||||
|
description="[Step-KB-Binding] Max KB calls allowed per step. Default is 1 if not set."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ScriptFlowCreate(SQLModel):
|
class ScriptFlowCreate(SQLModel):
|
||||||
|
|
@ -1078,6 +1118,7 @@ class MetadataFieldDefinition(SQLModel, table=True):
|
||||||
)
|
)
|
||||||
is_filterable: bool = Field(default=True, description="是否可用于过滤")
|
is_filterable: bool = Field(default=True, description="是否可用于过滤")
|
||||||
is_rank_feature: bool = Field(default=False, description="是否用于排序特征")
|
is_rank_feature: bool = Field(default=False, description="是否用于排序特征")
|
||||||
|
usage_description: str | None = Field(default=None, description="用途说明")
|
||||||
field_roles: list[str] = Field(
|
field_roles: list[str] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
sa_column=Column("field_roles", JSON, nullable=False, server_default="'[]'"),
|
sa_column=Column("field_roles", JSON, nullable=False, server_default="'[]'"),
|
||||||
|
|
@ -1104,6 +1145,7 @@ class MetadataFieldDefinitionCreate(SQLModel):
|
||||||
scope: list[str] = Field(default_factory=lambda: [MetadataScope.KB_DOCUMENT.value])
|
scope: list[str] = Field(default_factory=lambda: [MetadataScope.KB_DOCUMENT.value])
|
||||||
is_filterable: bool = Field(default=True)
|
is_filterable: bool = Field(default=True)
|
||||||
is_rank_feature: bool = Field(default=False)
|
is_rank_feature: bool = Field(default=False)
|
||||||
|
usage_description: str | None = None
|
||||||
field_roles: list[str] = Field(default_factory=list)
|
field_roles: list[str] = Field(default_factory=list)
|
||||||
status: str = Field(default=MetadataFieldStatus.DRAFT.value)
|
status: str = Field(default=MetadataFieldStatus.DRAFT.value)
|
||||||
|
|
||||||
|
|
@ -1118,6 +1160,7 @@ class MetadataFieldDefinitionUpdate(SQLModel):
|
||||||
scope: list[str] | None = None
|
scope: list[str] | None = None
|
||||||
is_filterable: bool | None = None
|
is_filterable: bool | None = None
|
||||||
is_rank_feature: bool | None = None
|
is_rank_feature: bool | None = None
|
||||||
|
usage_description: str | None = None
|
||||||
field_roles: list[str] | None = None
|
field_roles: list[str] | None = None
|
||||||
status: str | None = None
|
status: str | None = None
|
||||||
|
|
||||||
|
|
@ -1131,6 +1174,17 @@ class ExtractStrategy(str, Enum):
|
||||||
USER_INPUT = "user_input"
|
USER_INPUT = "user_input"
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractFailureType(str, Enum):
    """
    [AC-MRS-07-UPGRADE] Extraction failure type.

    Unified failure classification, used for tracing and logging.
    """

    # The extraction produced an empty result.
    EXTRACT_EMPTY = "EXTRACT_EMPTY"
    # The extracted content could not be parsed.
    EXTRACT_PARSE_FAIL = "EXTRACT_PARSE_FAIL"
    # The extracted value failed validation.
    EXTRACT_VALIDATION_FAIL = "EXTRACT_VALIDATION_FAIL"
    # A runtime error occurred during extraction.
    EXTRACT_RUNTIME_ERROR = "EXTRACT_RUNTIME_ERROR"
|
||||||
|
|
||||||
|
|
||||||
class SlotValueSource(str, Enum):
|
class SlotValueSource(str, Enum):
|
||||||
"""
|
"""
|
||||||
[AC-MRS-09] 槽位值来源
|
[AC-MRS-09] 槽位值来源
|
||||||
|
|
@ -1145,6 +1199,7 @@ class SlotDefinition(SQLModel, table=True):
|
||||||
"""
|
"""
|
||||||
[AC-MRS-07,08] 槽位定义表
|
[AC-MRS-07,08] 槽位定义表
|
||||||
独立的槽位定义模型,与元数据字段解耦但可复用
|
独立的槽位定义模型,与元数据字段解耦但可复用
|
||||||
|
[AC-MRS-07-UPGRADE] 支持提取策略链 extract_strategies
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "slot_definitions"
|
__tablename__ = "slot_definitions"
|
||||||
|
|
@ -1162,14 +1217,31 @@ class SlotDefinition(SQLModel, table=True):
|
||||||
min_length=1,
|
min_length=1,
|
||||||
max_length=100,
|
max_length=100,
|
||||||
)
|
)
|
||||||
|
display_name: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="槽位名称,给运营/教研看的中文名,例:grade -> '当前年级'",
|
||||||
|
max_length=100,
|
||||||
|
)
|
||||||
|
description: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="槽位说明,解释这个槽位采集什么、用于哪里",
|
||||||
|
max_length=500,
|
||||||
|
)
|
||||||
type: str = Field(
|
type: str = Field(
|
||||||
default=MetadataFieldType.STRING.value,
|
default=MetadataFieldType.STRING.value,
|
||||||
description="槽位类型: string/number/boolean/enum/array_enum"
|
description="槽位类型: string/number/boolean/enum/array_enum"
|
||||||
)
|
)
|
||||||
required: bool = Field(default=False, description="是否必填槽位")
|
required: bool = Field(default=False, description="是否必填槽位")
|
||||||
|
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容读取
|
||||||
extract_strategy: str | None = Field(
|
extract_strategy: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="提取策略: rule/llm/user_input"
|
description="[兼容字段] 提取策略: rule/llm/user_input,已废弃,请使用 extract_strategies"
|
||||||
|
)
|
||||||
|
# [AC-MRS-07-UPGRADE] 新增策略链字段
|
||||||
|
extract_strategies: list[str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column=Column("extract_strategies", JSON, nullable=True),
|
||||||
|
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||||
)
|
)
|
||||||
validation_rule: str | None = Field(
|
validation_rule: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|
@ -1192,14 +1264,72 @@ class SlotDefinition(SQLModel, table=True):
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间")
|
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间")
|
||||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间")
|
updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间")
|
||||||
|
|
||||||
|
def get_effective_strategies(self) -> list[str]:
    """
    [AC-MRS-07-UPGRADE] Return the extraction strategy chain in effect.

    ``extract_strategies`` wins when it is non-empty; otherwise the legacy
    single-value ``extract_strategy`` is wrapped in a one-element list.

    Returns:
        The strategy chain, or an empty list when neither field is set.
    """
    chain = self.extract_strategies
    if chain:
        return chain

    # Compatibility path for rows that only carry the legacy field.
    legacy = self.extract_strategy
    return [legacy] if legacy else []
|
||||||
|
|
||||||
|
def validate_strategies(self) -> tuple[bool, str]:
    """
    [AC-MRS-07-UPGRADE] Validate the extraction strategy chain.

    The chain is taken from ``get_effective_strategies()``. An empty chain
    is accepted (default behaviour applies). A non-empty chain must not
    contain duplicates and may only contain the known strategy names.

    Returns:
        Tuple of (is_valid, error_message); the message is empty when the
        chain is valid.
    """
    # A tuple rather than a set keeps the order of the allowed values in
    # the error message stable from run to run.
    valid_strategies = ("rule", "llm", "user_input")
    strategies = self.get_effective_strategies()

    # An empty chain counts as valid, so no separate "must not be empty"
    # check is needed (the previous one could never be reached).
    if not strategies:
        return True, ""

    # Duplicate strategies are not allowed.
    if len(strategies) != len(set(strategies)):
        return False, "提取策略链中不允许重复的策略"

    # Every entry must be one of the known strategy names.
    invalid = [s for s in strategies if s not in valid_strategies]
    if invalid:
        return False, f"无效的提取策略: {invalid},有效值为: {list(valid_strategies)}"

    return True, ""
|
||||||
|
|
||||||
|
|
||||||
class SlotDefinitionCreate(SQLModel):
|
class SlotDefinitionCreate(SQLModel):
|
||||||
"""[AC-MRS-07,08] 创建槽位定义"""
|
"""[AC-MRS-07,08] 创建槽位定义"""
|
||||||
|
|
||||||
slot_key: str = Field(..., min_length=1, max_length=100)
|
slot_key: str = Field(..., min_length=1, max_length=100)
|
||||||
|
display_name: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="槽位名称,给运营/教研看的中文名",
|
||||||
|
max_length=100,
|
||||||
|
)
|
||||||
|
description: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="槽位说明,解释这个槽位采集什么、用于哪里",
|
||||||
|
max_length=500,
|
||||||
|
)
|
||||||
type: str = Field(default=MetadataFieldType.STRING.value)
|
type: str = Field(default=MetadataFieldType.STRING.value)
|
||||||
required: bool = Field(default=False)
|
required: bool = Field(default=False)
|
||||||
extract_strategy: str | None = None
|
# [AC-MRS-07-UPGRADE] 支持策略链
|
||||||
|
extract_strategies: list[str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||||
|
)
|
||||||
|
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||||
|
extract_strategy: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
|
||||||
|
)
|
||||||
validation_rule: str | None = None
|
validation_rule: str | None = None
|
||||||
ask_back_prompt: str | None = None
|
ask_back_prompt: str | None = None
|
||||||
default_value: dict[str, Any] | None = None
|
default_value: dict[str, Any] | None = None
|
||||||
|
|
@ -1209,9 +1339,28 @@ class SlotDefinitionCreate(SQLModel):
|
||||||
class SlotDefinitionUpdate(SQLModel):
|
class SlotDefinitionUpdate(SQLModel):
|
||||||
"""[AC-MRS-07] 更新槽位定义"""
|
"""[AC-MRS-07] 更新槽位定义"""
|
||||||
|
|
||||||
|
display_name: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="槽位名称,给运营/教研看的中文名",
|
||||||
|
max_length=100,
|
||||||
|
)
|
||||||
|
description: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="槽位说明,解释这个槽位采集什么、用于哪里",
|
||||||
|
max_length=500,
|
||||||
|
)
|
||||||
type: str | None = None
|
type: str | None = None
|
||||||
required: bool | None = None
|
required: bool | None = None
|
||||||
extract_strategy: str | None = None
|
# [AC-MRS-07-UPGRADE] 支持策略链
|
||||||
|
extract_strategies: list[str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||||
|
)
|
||||||
|
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||||
|
extract_strategy: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
|
||||||
|
)
|
||||||
validation_rule: str | None = None
|
validation_rule: str | None = None
|
||||||
ask_back_prompt: str | None = None
|
ask_back_prompt: str | None = None
|
||||||
default_value: dict[str, Any] | None = None
|
default_value: dict[str, Any] | None = None
|
||||||
|
|
@ -1522,3 +1671,107 @@ class MidAuditLog(SQLModel, table=True):
|
||||||
high_risk_scenario: str | None = Field(default=None, description="触发的高风险场景")
|
high_risk_scenario: str | None = Field(default=None, description="触发的高风险场景")
|
||||||
latency_ms: int | None = Field(default=None, description="总耗时(ms)")
|
latency_ms: int | None = Field(default=None, description="总耗时(ms)")
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间", index=True)
|
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间", index=True)
|
||||||
|
|
||||||
|
|
||||||
|
class SceneSlotBundleStatus(str, Enum):
    """[AC-SCENE-SLOT-01] Status values for a scene slot bundle.

    The enum mixes in ``str``, so every member compares equal to its raw
    string value (for example ``SceneSlotBundleStatus.DRAFT == "draft"``).
    """

    DRAFT = "draft"
    ACTIVE = "active"
    DEPRECATED = "deprecated"
|
||||||
|
|
||||||
|
|
||||||
|
class SceneSlotBundle(SQLModel, table=True):
    """
    [AC-SCENE-SLOT-01] Scene-to-slot mapping configuration.

    Defines the set of slots that has to be collected for each scene.

    Three-layer relationship:
    - Layer 1: slot <-> metadata (via linked_field_id)
    - Layer 2: scene <-> slot_bundle (this model)
    - Layer 3: step.expected_variables <-> slot_key (referenced by script steps)
    """

    __tablename__ = "scene_slot_bundles"
    __table_args__ = (
        Index("ix_scene_slot_bundles_tenant", "tenant_id"),
        # Unique per tenant: at most one bundle row for a given scene_key.
        Index("ix_scene_slot_bundles_tenant_scene", "tenant_id", "scene_key", unique=True),
        Index("ix_scene_slot_bundles_tenant_status", "tenant_id", "status"),
    )

    # --- Identity and tenant scoping ---
    id: uuid.UUID = Field(primary_key=True, default_factory=uuid.uuid4)
    tenant_id: str = Field(..., index=True, description="Tenant ID for multi-tenant isolation")

    # --- Scene identification ---
    scene_key: str = Field(
        ...,
        min_length=1,
        max_length=100,
        description="场景标识,如 'open_consult', 'refund_apply', 'course_recommend'",
    )
    scene_name: str = Field(
        ...,
        min_length=1,
        max_length=100,
        description="场景名称,如 '开放咨询', '退款申请', '课程推荐'",
    )
    description: str | None = Field(default=None, description="场景描述")

    # --- Slot membership; each list is stored in its own JSON column ---
    required_slots: list[str] = Field(
        sa_column=Column("required_slots", JSON, nullable=False),
        default_factory=list,
        description="必填槽位 slot_key 列表",
    )
    optional_slots: list[str] = Field(
        sa_column=Column("optional_slots", JSON, nullable=False),
        default_factory=list,
        description="可选槽位 slot_key 列表",
    )
    slot_priority: list[str] | None = Field(
        sa_column=Column("slot_priority", JSON, nullable=True),
        default=None,
        description="槽位采集优先级顺序(slot_key 列表)",
    )

    # --- Collection policy ---
    # Declared bounds keep the threshold inside [0.0, 1.0]; the default of 1.0
    # corresponds to "all required slots filled" per the field description.
    completion_threshold: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="完成阈值(0.0-1.0),必填槽位填充比例达到此值视为完成",
    )
    ask_back_order: str = Field(
        default="priority",
        description="追问顺序策略: priority/required_first/parallel",
    )

    # --- Lifecycle and bookkeeping ---
    # Stored as the plain string value of SceneSlotBundleStatus.
    status: str = Field(
        default=SceneSlotBundleStatus.DRAFT.value,
        description="状态: draft/active/deprecated",
    )
    version: int = Field(default=1, description="版本号")
    created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间")
    updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间")
|
||||||
|
|
||||||
|
|
||||||
|
class SceneSlotBundleCreate(SQLModel):
    """[AC-SCENE-SLOT-01] Payload for creating a scene slot bundle.

    Only ``scene_key`` and ``scene_name`` are required; every other field
    has a default.
    """

    # Required identification (1-100 characters each).
    scene_key: str = Field(
        ...,
        min_length=1,
        max_length=100,
    )
    scene_name: str = Field(
        ...,
        min_length=1,
        max_length=100,
    )
    description: str | None = None

    # Slot membership; both lists default to a fresh empty list.
    required_slots: list[str] = Field(default_factory=list)
    optional_slots: list[str] = Field(default_factory=list)
    slot_priority: list[str] | None = None

    # Collection policy.
    completion_threshold: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
    )
    ask_back_order: str = Field(default="priority")

    # New bundles start in the "draft" status unless the caller overrides it.
    status: str = Field(default=SceneSlotBundleStatus.DRAFT.value)
|
||||||
|
|
||||||
|
|
||||||
|
class SceneSlotBundleUpdate(SQLModel):
    """[AC-SCENE-SLOT-01] Payload for updating a scene slot bundle.

    Every field is optional and defaults to ``None``. ``scene_key`` is not
    part of this payload.
    NOTE(review): presumably ``None`` means "leave unchanged" - confirm
    against the service that applies this payload.
    """

    scene_name: str | None = Field(
        default=None,
        min_length=1,
        max_length=100,
    )
    description: str | None = None

    # Slot membership.
    required_slots: list[str] | None = None
    optional_slots: list[str] | None = None
    slot_priority: list[str] | None = None

    # Collection policy; when given, the threshold must lie in [0.0, 1.0].
    completion_threshold: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
    )
    ask_back_order: str | None = None

    status: str | None = None
|
||||||
|
|
|
||||||
|
|
@ -139,7 +139,16 @@ class SlotDefinitionResponse(BaseModel):
|
||||||
slot_key: str = Field(..., description="槽位键名")
|
slot_key: str = Field(..., description="槽位键名")
|
||||||
type: str = Field(..., description="槽位类型")
|
type: str = Field(..., description="槽位类型")
|
||||||
required: bool = Field(default=False, description="是否必填槽位")
|
required: bool = Field(default=False, description="是否必填槽位")
|
||||||
extract_strategy: str | None = Field(default=None, description="提取策略")
|
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||||
|
extract_strategy: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="[兼容字段] 单提取策略,已废弃"
|
||||||
|
)
|
||||||
|
# [AC-MRS-07-UPGRADE] 新增策略链字段
|
||||||
|
extract_strategies: list[str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input"
|
||||||
|
)
|
||||||
validation_rule: str | None = Field(default=None, description="校验规则")
|
validation_rule: str | None = Field(default=None, description="校验规则")
|
||||||
ask_back_prompt: str | None = Field(default=None, description="追问提示语模板")
|
ask_back_prompt: str | None = Field(default=None, description="追问提示语模板")
|
||||||
default_value: dict[str, Any] | None = Field(default=None, description="默认值")
|
default_value: dict[str, Any] | None = Field(default=None, description="默认值")
|
||||||
|
|
@ -157,9 +166,15 @@ class SlotDefinitionCreateRequest(BaseModel):
|
||||||
slot_key: str = Field(..., min_length=1, max_length=100, description="槽位键名")
|
slot_key: str = Field(..., min_length=1, max_length=100, description="槽位键名")
|
||||||
type: str = Field(default="string", description="槽位类型")
|
type: str = Field(default="string", description="槽位类型")
|
||||||
required: bool = Field(default=False, description="是否必填槽位")
|
required: bool = Field(default=False, description="是否必填槽位")
|
||||||
|
# [AC-MRS-07-UPGRADE] 支持策略链
|
||||||
|
extract_strategies: list[str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||||
|
)
|
||||||
|
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||||
extract_strategy: str | None = Field(
|
extract_strategy: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="提取策略: rule/llm/user_input"
|
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
|
||||||
)
|
)
|
||||||
validation_rule: str | None = Field(default=None, description="校验规则")
|
validation_rule: str | None = Field(default=None, description="校验规则")
|
||||||
ask_back_prompt: str | None = Field(default=None, description="追问提示语模板")
|
ask_back_prompt: str | None = Field(default=None, description="追问提示语模板")
|
||||||
|
|
@ -172,7 +187,16 @@ class SlotDefinitionUpdateRequest(BaseModel):
|
||||||
|
|
||||||
type: str | None = None
|
type: str | None = None
|
||||||
required: bool | None = None
|
required: bool | None = None
|
||||||
extract_strategy: str | None = None
|
# [AC-MRS-07-UPGRADE] 支持策略链
|
||||||
|
extract_strategies: list[str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input,按顺序执行直到成功"
|
||||||
|
)
|
||||||
|
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
|
||||||
|
extract_strategy: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
|
||||||
|
)
|
||||||
validation_rule: str | None = None
|
validation_rule: str | None = None
|
||||||
ask_back_prompt: str | None = None
|
ask_back_prompt: str | None = None
|
||||||
default_value: dict[str, Any] | None = None
|
default_value: dict[str, Any] | None = None
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,6 @@ class ApiKeyService:
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[AC-AISVC-50] Full API key schema load failed, fallback to legacy columns: {e}")
|
logger.warning(f"[AC-AISVC-50] Full API key schema load failed, fallback to legacy columns: {e}")
|
||||||
await session.rollback()
|
|
||||||
|
|
||||||
# Backward-compat fallback for environments without new columns
|
# Backward-compat fallback for environments without new columns
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,9 @@ import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import redis
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
|
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
|
||||||
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
|
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
|
||||||
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
|
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
|
||||||
|
|
@ -20,6 +23,7 @@ from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
EMBEDDING_CONFIG_FILE = Path("config/embedding_config.json")
|
EMBEDDING_CONFIG_FILE = Path("config/embedding_config.json")
|
||||||
|
EMBEDDING_CONFIG_REDIS_KEY = "ai_service:config:embedding"
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingProviderFactory:
|
class EmbeddingProviderFactory:
|
||||||
|
|
@ -170,8 +174,32 @@ class EmbeddingConfigManager:
|
||||||
self._config = self._default_config.copy()
|
self._config = self._default_config.copy()
|
||||||
self._provider: EmbeddingProvider | None = None
|
self._provider: EmbeddingProvider | None = None
|
||||||
|
|
||||||
|
self._settings = get_settings()
|
||||||
|
self._redis_client: redis.Redis | None = None
|
||||||
|
|
||||||
|
self._load_from_redis()
|
||||||
self._load_from_file()
|
self._load_from_file()
|
||||||
|
|
||||||
|
def _load_from_redis(self) -> None:
    """Restore the embedding provider name and config from Redis.

    Nothing happens when Redis is disabled in settings or when the key
    holds no value. A missing "provider" or "config" entry falls back to
    the manager's defaults. Any exception is logged as a warning and
    swallowed, so this method never raises.
    """
    try:
        if self._settings.redis_enabled:
            # A client is created on every load and kept for later saves.
            self._redis_client = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            payload = self._redis_client.get(EMBEDDING_CONFIG_REDIS_KEY)
            if payload:
                stored = json.loads(payload)
                self._provider_name = stored.get("provider", self._default_provider)
                self._config = stored.get("config", self._default_config.copy())
                logger.info(f"Loaded embedding config from Redis: provider={self._provider_name}")
    except Exception as e:
        logger.warning(f"Failed to load embedding config from Redis: {e}")
|
||||||
|
|
||||||
def _load_from_file(self) -> None:
|
def _load_from_file(self) -> None:
|
||||||
"""Load configuration from file if exists."""
|
"""Load configuration from file if exists."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -184,6 +212,28 @@ class EmbeddingConfigManager:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to load embedding config from file: {e}")
|
logger.warning(f"Failed to load embedding config from file: {e}")
|
||||||
|
|
||||||
|
def _save_to_redis(self) -> None:
    """Write the current embedding provider name and config to Redis.

    Nothing happens when Redis is disabled in settings. The Redis client
    is created on first use and cached on the instance. Any exception is
    logged as a warning and swallowed, so this method never raises.
    """
    try:
        if not self._settings.redis_enabled:
            return

        client = self._redis_client
        if client is None:
            client = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            self._redis_client = client

        # ensure_ascii=False keeps non-ASCII config values readable in Redis.
        snapshot = json.dumps(
            {"provider": self._provider_name, "config": self._config},
            ensure_ascii=False,
        )
        client.set(EMBEDDING_CONFIG_REDIS_KEY, snapshot)
        logger.info(f"Saved embedding config to Redis: provider={self._provider_name}")
    except Exception as e:
        logger.warning(f"Failed to save embedding config to Redis: {e}")
|
||||||
|
|
||||||
def _save_to_file(self) -> None:
|
def _save_to_file(self) -> None:
|
||||||
"""Save configuration to file."""
|
"""Save configuration to file."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -262,6 +312,7 @@ class EmbeddingConfigManager:
|
||||||
self._config = config
|
self._config = config
|
||||||
self._provider = new_provider_instance
|
self._provider = new_provider_instance
|
||||||
|
|
||||||
|
self._save_to_redis()
|
||||||
self._save_to_file()
|
self._save_to_file()
|
||||||
|
|
||||||
logger.info(f"Updated embedding config: provider={provider}")
|
logger.info(f"Updated embedding config: provider={provider}")
|
||||||
|
|
|
||||||
|
|
@ -322,7 +322,7 @@ class FlowEngine:
|
||||||
stmt = select(FlowInstance).where(
|
stmt = select(FlowInstance).where(
|
||||||
FlowInstance.tenant_id == tenant_id,
|
FlowInstance.tenant_id == tenant_id,
|
||||||
FlowInstance.session_id == session_id,
|
FlowInstance.session_id == session_id,
|
||||||
).order_by(col(FlowInstance.created_at).desc())
|
).order_by(col(FlowInstance.started_at).desc())
|
||||||
result = await self._session.execute(stmt)
|
result = await self._session.execute(stmt)
|
||||||
instance = result.scalar_one_or_none()
|
instance = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -106,6 +106,8 @@ class IntentRuleService:
|
||||||
is_enabled=True,
|
is_enabled=True,
|
||||||
hit_count=0,
|
hit_count=0,
|
||||||
metadata_=create_data.metadata_,
|
metadata_=create_data.metadata_,
|
||||||
|
intent_vector=create_data.intent_vector,
|
||||||
|
semantic_examples=create_data.semantic_examples,
|
||||||
)
|
)
|
||||||
self._session.add(rule)
|
self._session.add(rule)
|
||||||
await self._session.flush()
|
await self._session.flush()
|
||||||
|
|
@ -195,6 +197,10 @@ class IntentRuleService:
|
||||||
rule.is_enabled = update_data.is_enabled
|
rule.is_enabled = update_data.is_enabled
|
||||||
if update_data.metadata_ is not None:
|
if update_data.metadata_ is not None:
|
||||||
rule.metadata_ = update_data.metadata_
|
rule.metadata_ = update_data.metadata_
|
||||||
|
if update_data.intent_vector is not None:
|
||||||
|
rule.intent_vector = update_data.intent_vector
|
||||||
|
if update_data.semantic_examples is not None:
|
||||||
|
rule.semantic_examples = update_data.semantic_examples
|
||||||
|
|
||||||
rule.updated_at = datetime.utcnow()
|
rule.updated_at = datetime.utcnow()
|
||||||
await self._session.flush()
|
await self._session.flush()
|
||||||
|
|
@ -267,7 +273,7 @@ class IntentRuleService:
|
||||||
select(IntentRule)
|
select(IntentRule)
|
||||||
.where(
|
.where(
|
||||||
IntentRule.tenant_id == tenant_id,
|
IntentRule.tenant_id == tenant_id,
|
||||||
IntentRule.is_enabled == True,
|
IntentRule.is_enabled == True, # noqa: E712
|
||||||
)
|
)
|
||||||
.order_by(col(IntentRule.priority).desc())
|
.order_by(col(IntentRule.priority).desc())
|
||||||
)
|
)
|
||||||
|
|
@ -300,6 +306,8 @@ class IntentRuleService:
|
||||||
"is_enabled": rule.is_enabled,
|
"is_enabled": rule.is_enabled,
|
||||||
"hit_count": rule.hit_count,
|
"hit_count": rule.hit_count,
|
||||||
"metadata": rule.metadata_,
|
"metadata": rule.metadata_,
|
||||||
"created_at": rule.created_at.isoformat(),
|
"intent_vector": rule.intent_vector,
|
||||||
"updated_at": rule.updated_at.isoformat(),
|
"semantic_examples": rule.semantic_examples,
|
||||||
|
"created_at": rule.created_at.isoformat() if rule.created_at else None,
|
||||||
|
"updated_at": rule.updated_at.isoformat() if rule.updated_at else None,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -83,6 +83,7 @@ class KBService:
|
||||||
file_name: str,
|
file_name: str,
|
||||||
file_content: bytes,
|
file_content: bytes,
|
||||||
file_type: str | None = None,
|
file_type: str | None = None,
|
||||||
|
metadata: dict | None = None,
|
||||||
) -> tuple[Document, IndexJob]:
|
) -> tuple[Document, IndexJob]:
|
||||||
"""
|
"""
|
||||||
[AC-ASA-01] Upload document and create indexing job.
|
[AC-ASA-01] Upload document and create indexing job.
|
||||||
|
|
@ -108,6 +109,7 @@ class KBService:
|
||||||
file_size=len(file_content),
|
file_size=len(file_content),
|
||||||
file_type=file_type,
|
file_type=file_type,
|
||||||
status=DocumentStatus.PENDING.value,
|
status=DocumentStatus.PENDING.value,
|
||||||
|
doc_metadata=metadata,
|
||||||
)
|
)
|
||||||
self._session.add(document)
|
self._session.add(document)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,14 @@ LLM Adapter module for AI Service.
|
||||||
[AC-AISVC-02, AC-AISVC-06] Provides unified interface for LLM providers.
|
[AC-AISVC-02, AC-AISVC-06] Provides unified interface for LLM providers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
|
from app.services.llm.base import (
|
||||||
|
LLMClient,
|
||||||
|
LLMConfig,
|
||||||
|
LLMResponse,
|
||||||
|
LLMStreamChunk,
|
||||||
|
ToolCall,
|
||||||
|
ToolDefinition,
|
||||||
|
)
|
||||||
from app.services.llm.openai_client import OpenAIClient
|
from app.services.llm.openai_client import OpenAIClient
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -12,4 +19,6 @@ __all__ = [
|
||||||
"LLMResponse",
|
"LLMResponse",
|
||||||
"LLMStreamChunk",
|
"LLMStreamChunk",
|
||||||
"OpenAIClient",
|
"OpenAIClient",
|
||||||
|
"ToolCall",
|
||||||
|
"ToolDefinition",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -28,17 +28,45 @@ class LLMConfig:
|
||||||
extra_params: dict[str, Any] = field(default_factory=dict)
|
extra_params: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ToolCall:
    """
    A single function call requested by the LLM.
    Used in Function Calling mode.
    """

    id: str
    name: str
    arguments: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Return the call as a dict with "id", "type" and "function" keys.

        "type" is always "function"; the arguments dict is JSON-encoded
        into a string under function.arguments.
        """
        import json

        # ensure_ascii=False leaves non-ASCII argument values unescaped.
        encoded_arguments = json.dumps(self.arguments, ensure_ascii=False)
        function_part = {
            "name": self.name,
            "arguments": encoded_arguments,
        }
        return {
            "id": self.id,
            "type": "function",
            "function": function_part,
        }
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class LLMResponse:
    """
    Response from LLM generation.
    [AC-AISVC-02] Contains generated content and metadata.

    Every field has a default, so a response can be built from whichever
    parts are available; ``content`` may be ``None``.
    """

    content: str | None = None
    model: str = ""
    usage: dict[str, int] = field(default_factory=dict)
    finish_reason: str = "stop"
    tool_calls: list[ToolCall] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def has_tool_calls(self) -> bool:
        """Check if response contains tool calls."""
        # Truthiness instead of len(): an explicit tool_calls=None is treated
        # as "no tool calls" rather than raising TypeError.
        return bool(self.tool_calls)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -50,9 +78,33 @@ class LLMStreamChunk:
|
||||||
delta: str
|
delta: str
|
||||||
model: str
|
model: str
|
||||||
finish_reason: str | None = None
|
finish_reason: str | None = None
|
||||||
|
tool_calls_delta: list[dict[str, Any]] = field(default_factory=list)
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ToolDefinition:
    """
    Describes one tool that can be offered to the model for Function Calling.
    Compatible with OpenAI/DeepSeek function calling format.
    """

    name: str
    description: str
    parameters: dict[str, Any]
    type: str = "function"

    def to_openai_format(self) -> dict[str, Any]:
        """Convert to OpenAI tools format."""
        # The parameters dict is passed through by reference, not copied.
        function_spec = {
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
        }
        return {"type": self.type, "function": function_spec}
|
||||||
|
|
||||||
|
|
||||||
class LLMClient(ABC):
|
class LLMClient(ABC):
|
||||||
"""
|
"""
|
||||||
Abstract base class for LLM clients.
|
Abstract base class for LLM clients.
|
||||||
|
|
@ -67,6 +119,8 @@ class LLMClient(ABC):
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
config: LLMConfig | None = None,
|
config: LLMConfig | None = None,
|
||||||
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
"""
|
||||||
|
|
@ -76,10 +130,12 @@ class LLMClient(ABC):
|
||||||
Args:
|
Args:
|
||||||
messages: List of chat messages with 'role' and 'content'.
|
messages: List of chat messages with 'role' and 'content'.
|
||||||
config: Optional LLM configuration overrides.
|
config: Optional LLM configuration overrides.
|
||||||
|
tools: Optional list of tools for function calling.
|
||||||
|
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
|
||||||
**kwargs: Additional provider-specific parameters.
|
**kwargs: Additional provider-specific parameters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LLMResponse with generated content and metadata.
|
LLMResponse with generated content, tool_calls, and metadata.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
LLMException: If generation fails.
|
LLMException: If generation fails.
|
||||||
|
|
@ -91,6 +147,8 @@ class LLMClient(ABC):
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
config: LLMConfig | None = None,
|
config: LLMConfig | None = None,
|
||||||
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncGenerator[LLMStreamChunk, None]:
|
) -> AsyncGenerator[LLMStreamChunk, None]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -100,6 +158,8 @@ class LLMClient(ABC):
|
||||||
Args:
|
Args:
|
||||||
messages: List of chat messages with 'role' and 'content'.
|
messages: List of chat messages with 'role' and 'content'.
|
||||||
config: Optional LLM configuration overrides.
|
config: Optional LLM configuration overrides.
|
||||||
|
tools: Optional list of tools for function calling.
|
||||||
|
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
|
||||||
**kwargs: Additional provider-specific parameters.
|
**kwargs: Additional provider-specific parameters.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
|
|
|
||||||
|
|
@ -11,12 +11,16 @@ from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import redis
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
from app.services.llm.base import LLMClient, LLMConfig
|
from app.services.llm.base import LLMClient, LLMConfig
|
||||||
from app.services.llm.openai_client import OpenAIClient
|
from app.services.llm.openai_client import OpenAIClient
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
LLM_CONFIG_FILE = Path("config/llm_config.json")
|
LLM_CONFIG_FILE = Path("config/llm_config.json")
|
||||||
|
LLM_CONFIG_REDIS_KEY = "ai_service:config:llm"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -286,6 +290,8 @@ class LLMConfigManager:
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
self._settings = settings
|
||||||
|
self._redis_client: redis.Redis | None = None
|
||||||
|
|
||||||
self._current_provider: str = settings.llm_provider
|
self._current_provider: str = settings.llm_provider
|
||||||
self._current_config: dict[str, Any] = {
|
self._current_config: dict[str, Any] = {
|
||||||
|
|
@ -299,8 +305,75 @@ class LLMConfigManager:
|
||||||
}
|
}
|
||||||
self._client: LLMClient | None = None
|
self._client: LLMClient | None = None
|
||||||
|
|
||||||
|
self._load_from_redis()
|
||||||
self._load_from_file()
|
self._load_from_file()
|
||||||
|
|
||||||
|
def _load_from_redis(self) -> None:
    """Load the LLM provider and config from Redis if an entry exists.

    Nothing happens when Redis is disabled in settings or when the key
    holds no value. A stored provider replaces the current one; a stored,
    non-empty config is merged over the current config with dict.update,
    so keys missing from Redis keep their current values. Any exception is
    logged as a warning and swallowed, so this method never raises.
    """
    try:
        if not self._settings.redis_enabled:
            return
        self._redis_client = redis.from_url(
            self._settings.redis_url,
            encoding="utf-8",
            decode_responses=True,
        )
        saved_raw = self._redis_client.get(LLM_CONFIG_REDIS_KEY)
        if not saved_raw:
            return
        saved = json.loads(saved_raw)
        self._current_provider = saved.get("provider", self._current_provider)
        saved_config = saved.get("config", {})
        if saved_config:
            self._current_config.update(saved_config)
        logger.info(f"[AC-ASA-16] Loaded LLM config from Redis: provider={self._current_provider}")
    except Exception as e:
        logger.warning(f"[AC-ASA-16] Failed to load LLM config from Redis: {e}")

def _save_to_redis(self) -> None:
    """Save the current LLM provider and config to Redis.

    Nothing happens when Redis is disabled in settings. The Redis client
    is created on first use and cached on the instance. Any exception is
    logged as a warning and swallowed, so this method never raises.
    """
    try:
        if not self._settings.redis_enabled:
            return
        if self._redis_client is None:
            self._redis_client = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
        self._redis_client.set(
            LLM_CONFIG_REDIS_KEY,
            json.dumps({
                "provider": self._current_provider,
                "config": self._current_config,
            }, ensure_ascii=False),
        )
        logger.info(f"[AC-ASA-16] Saved LLM config to Redis: provider={self._current_provider}")
    except Exception as e:
        logger.warning(f"[AC-ASA-16] Failed to save LLM config to Redis: {e}")

# NOTE: a second, byte-identical definition of _load_from_redis used to follow
# here and silently shadowed the one above; it has been removed.
|
||||||
|
|
||||||
def _load_from_file(self) -> None:
|
def _load_from_file(self) -> None:
|
||||||
"""Load configuration from file if exists."""
|
"""Load configuration from file if exists."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -364,6 +437,7 @@ class LLMConfigManager:
|
||||||
self._current_provider = provider
|
self._current_provider = provider
|
||||||
self._current_config = validated_config
|
self._current_config = validated_config
|
||||||
|
|
||||||
|
self._save_to_redis()
|
||||||
self._save_to_file()
|
self._save_to_file()
|
||||||
|
|
||||||
logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")
|
logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,14 @@ from tenacity import (
|
||||||
|
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
from app.core.exceptions import AIServiceException, ErrorCode, TimeoutException
|
from app.core.exceptions import AIServiceException, ErrorCode, TimeoutException
|
||||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
|
from app.services.llm.base import (
|
||||||
|
LLMClient,
|
||||||
|
LLMConfig,
|
||||||
|
LLMResponse,
|
||||||
|
LLMStreamChunk,
|
||||||
|
ToolCall,
|
||||||
|
ToolDefinition,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -95,6 +102,8 @@ class OpenAIClient(LLMClient):
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
config: LLMConfig,
|
config: LLMConfig,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Build request body for OpenAI API."""
|
"""Build request body for OpenAI API."""
|
||||||
|
|
@ -106,6 +115,13 @@ class OpenAIClient(LLMClient):
|
||||||
"top_p": config.top_p,
|
"top_p": config.top_p,
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
body["tools"] = [tool.to_openai_format() for tool in tools]
|
||||||
|
|
||||||
|
if tool_choice:
|
||||||
|
body["tool_choice"] = tool_choice
|
||||||
|
|
||||||
body.update(config.extra_params)
|
body.update(config.extra_params)
|
||||||
body.update(kwargs)
|
body.update(kwargs)
|
||||||
return body
|
return body
|
||||||
|
|
@ -119,6 +135,8 @@ class OpenAIClient(LLMClient):
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
config: LLMConfig | None = None,
|
config: LLMConfig | None = None,
|
||||||
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
"""
|
||||||
|
|
@ -128,10 +146,12 @@ class OpenAIClient(LLMClient):
|
||||||
Args:
|
Args:
|
||||||
messages: List of chat messages with 'role' and 'content'.
|
messages: List of chat messages with 'role' and 'content'.
|
||||||
config: Optional LLM configuration overrides.
|
config: Optional LLM configuration overrides.
|
||||||
|
tools: Optional list of tools for function calling.
|
||||||
|
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
|
||||||
**kwargs: Additional provider-specific parameters.
|
**kwargs: Additional provider-specific parameters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LLMResponse with generated content and metadata.
|
LLMResponse with generated content, tool_calls, and metadata.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
LLMException: If generation fails.
|
LLMException: If generation fails.
|
||||||
|
|
@ -140,9 +160,14 @@ class OpenAIClient(LLMClient):
|
||||||
effective_config = config or self._default_config
|
effective_config = config or self._default_config
|
||||||
client = self._get_client(effective_config.timeout_seconds)
|
client = self._get_client(effective_config.timeout_seconds)
|
||||||
|
|
||||||
body = self._build_request_body(messages, effective_config, stream=False, **kwargs)
|
body = self._build_request_body(
|
||||||
|
messages, effective_config, stream=False,
|
||||||
|
tools=tools, tool_choice=tool_choice, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}")
|
logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}")
|
||||||
|
if tools:
|
||||||
|
logger.info(f"[AC-AISVC-02] Function calling enabled with {len(tools)} tools")
|
||||||
logger.info("[AC-AISVC-02] ========== FULL PROMPT TO AI ==========")
|
logger.info("[AC-AISVC-02] ========== FULL PROMPT TO AI ==========")
|
||||||
for i, msg in enumerate(messages):
|
for i, msg in enumerate(messages):
|
||||||
role = msg.get("role", "unknown")
|
role = msg.get("role", "unknown")
|
||||||
|
|
@ -177,14 +202,18 @@ class OpenAIClient(LLMClient):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
choice = data["choices"][0]
|
choice = data["choices"][0]
|
||||||
content = choice["message"]["content"]
|
message = choice["message"]
|
||||||
|
content = message.get("content")
|
||||||
usage = data.get("usage", {})
|
usage = data.get("usage", {})
|
||||||
finish_reason = choice.get("finish_reason", "stop")
|
finish_reason = choice.get("finish_reason", "stop")
|
||||||
|
|
||||||
|
tool_calls = self._parse_tool_calls(message)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-AISVC-02] Generated response: "
|
f"[AC-AISVC-02] Generated response: "
|
||||||
f"tokens={usage.get('total_tokens', 'N/A')}, "
|
f"tokens={usage.get('total_tokens', 'N/A')}, "
|
||||||
f"finish_reason={finish_reason}"
|
f"finish_reason={finish_reason}, "
|
||||||
|
f"tool_calls={len(tool_calls)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
|
|
@ -192,6 +221,7 @@ class OpenAIClient(LLMClient):
|
||||||
model=data.get("model", effective_config.model),
|
model=data.get("model", effective_config.model),
|
||||||
usage=usage,
|
usage=usage,
|
||||||
finish_reason=finish_reason,
|
finish_reason=finish_reason,
|
||||||
|
tool_calls=tool_calls,
|
||||||
metadata={"raw_response": data},
|
metadata={"raw_response": data},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -201,11 +231,34 @@ class OpenAIClient(LLMClient):
|
||||||
message=f"Unexpected LLM response format: {e}",
|
message=f"Unexpected LLM response format: {e}",
|
||||||
details=[{"response": str(data)}],
|
details=[{"response": str(data)}],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _parse_tool_calls(self, message: dict[str, Any]) -> list[ToolCall]:
|
||||||
|
"""Parse tool calls from LLM response message."""
|
||||||
|
tool_calls = []
|
||||||
|
raw_tool_calls = message.get("tool_calls", [])
|
||||||
|
|
||||||
|
for tc in raw_tool_calls:
|
||||||
|
if tc.get("type") == "function":
|
||||||
|
func = tc.get("function", {})
|
||||||
|
try:
|
||||||
|
arguments = json.loads(func.get("arguments", "{}"))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
arguments = {}
|
||||||
|
|
||||||
|
tool_calls.append(ToolCall(
|
||||||
|
id=tc.get("id", ""),
|
||||||
|
name=func.get("name", ""),
|
||||||
|
arguments=arguments,
|
||||||
|
))
|
||||||
|
|
||||||
|
return tool_calls
|
||||||
|
|
||||||
async def stream_generate(
|
async def stream_generate(
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
config: LLMConfig | None = None,
|
config: LLMConfig | None = None,
|
||||||
|
tools: list[ToolDefinition] | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncGenerator[LLMStreamChunk, None]:
|
) -> AsyncGenerator[LLMStreamChunk, None]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -215,6 +268,8 @@ class OpenAIClient(LLMClient):
|
||||||
Args:
|
Args:
|
||||||
messages: List of chat messages with 'role' and 'content'.
|
messages: List of chat messages with 'role' and 'content'.
|
||||||
config: Optional LLM configuration overrides.
|
config: Optional LLM configuration overrides.
|
||||||
|
tools: Optional list of tools for function calling.
|
||||||
|
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
|
||||||
**kwargs: Additional provider-specific parameters.
|
**kwargs: Additional provider-specific parameters.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
|
|
@ -227,9 +282,14 @@ class OpenAIClient(LLMClient):
|
||||||
effective_config = config or self._default_config
|
effective_config = config or self._default_config
|
||||||
client = self._get_client(effective_config.timeout_seconds)
|
client = self._get_client(effective_config.timeout_seconds)
|
||||||
|
|
||||||
body = self._build_request_body(messages, effective_config, stream=True, **kwargs)
|
body = self._build_request_body(
|
||||||
|
messages, effective_config, stream=True,
|
||||||
|
tools=tools, tool_choice=tool_choice, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}")
|
logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}")
|
||||||
|
if tools:
|
||||||
|
logger.info(f"[AC-AISVC-06] Function calling enabled with {len(tools)} tools")
|
||||||
logger.info("[AC-AISVC-06] ========== FULL PROMPT TO AI (STREAMING) ==========")
|
logger.info("[AC-AISVC-06] ========== FULL PROMPT TO AI (STREAMING) ==========")
|
||||||
for i, msg in enumerate(messages):
|
for i, msg in enumerate(messages):
|
||||||
role = msg.get("role", "unknown")
|
role = msg.get("role", "unknown")
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,19 @@ class MetadataFieldDefinitionService:
|
||||||
|
|
||||||
def __init__(self, session: AsyncSession):
|
def __init__(self, session: AsyncSession):
|
||||||
self._session = session
|
self._session = session
|
||||||
|
|
||||||
|
async def _invalidate_cache(self, tenant_id: str) -> None:
    """
    Clear the tenant's cached metadata field definitions.

    Meant to be called whenever a field definition is created, updated or
    deleted. Invalidation is best-effort: any failure is logged as a warning
    and never propagated to the caller.

    Args:
        tenant_id: Tenant whose cache entry is invalidated.
    """
    try:
        from app.services.metadata_cache_service import get_metadata_cache_service

        metadata_cache = await get_metadata_cache_service()
        await metadata_cache.invalidate(tenant_id)
    except Exception as exc:
        # A failed cache invalidation must not break the main write flow.
        logger.warning(f"[AC-IDSMETA-13] Failed to invalidate cache: {exc}")
|
||||||
|
|
||||||
async def list_field_definitions(
|
async def list_field_definitions(
|
||||||
self,
|
self,
|
||||||
|
|
@ -180,6 +193,9 @@ class MetadataFieldDefinitionService:
|
||||||
self._session.add(field)
|
self._session.add(field)
|
||||||
await self._session.flush()
|
await self._session.flush()
|
||||||
|
|
||||||
|
# 清除缓存,使新字段在下次查询时生效
|
||||||
|
await self._invalidate_cache(tenant_id)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-IDSMETA-13] [AC-MRS-01] Created field definition: tenant={tenant_id}, "
|
f"[AC-IDSMETA-13] [AC-MRS-01] Created field definition: tenant={tenant_id}, "
|
||||||
f"field_key={field.field_key}, status={field.status}, field_roles={field.field_roles}"
|
f"field_key={field.field_key}, status={field.status}, field_roles={field.field_roles}"
|
||||||
|
|
@ -223,6 +239,10 @@ class MetadataFieldDefinitionService:
|
||||||
field.is_filterable = field_update.is_filterable
|
field.is_filterable = field_update.is_filterable
|
||||||
if field_update.is_rank_feature is not None:
|
if field_update.is_rank_feature is not None:
|
||||||
field.is_rank_feature = field_update.is_rank_feature
|
field.is_rank_feature = field_update.is_rank_feature
|
||||||
|
# [AC-MRS-01] 修复:添加 field_roles 更新逻辑
|
||||||
|
if field_update.field_roles is not None:
|
||||||
|
self._validate_field_roles(field_update.field_roles)
|
||||||
|
field.field_roles = field_update.field_roles
|
||||||
if field_update.status is not None:
|
if field_update.status is not None:
|
||||||
old_status = field.status
|
old_status = field.status
|
||||||
field.status = field_update.status
|
field.status = field_update.status
|
||||||
|
|
@ -235,6 +255,9 @@ class MetadataFieldDefinitionService:
|
||||||
field.updated_at = datetime.utcnow()
|
field.updated_at = datetime.utcnow()
|
||||||
await self._session.flush()
|
await self._session.flush()
|
||||||
|
|
||||||
|
# 清除缓存,使更新在下次查询时生效
|
||||||
|
await self._invalidate_cache(tenant_id)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-IDSMETA-14] Updated field definition: tenant={tenant_id}, "
|
f"[AC-IDSMETA-14] Updated field definition: tenant={tenant_id}, "
|
||||||
f"field_id={field_id}, version={field.version}"
|
f"field_id={field_id}, version={field.version}"
|
||||||
|
|
|
||||||
|
|
@ -23,9 +23,9 @@ RAG Optimization (rag-optimization/spec.md):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sse_starlette.sse import ServerSentEvent
|
from sse_starlette.sse import ServerSentEvent
|
||||||
|
|
@ -46,7 +46,6 @@ from app.services.flow.engine import FlowEngine
|
||||||
from app.services.guardrail.behavior_service import BehaviorRuleService
|
from app.services.guardrail.behavior_service import BehaviorRuleService
|
||||||
from app.services.guardrail.input_scanner import InputScanner
|
from app.services.guardrail.input_scanner import InputScanner
|
||||||
from app.services.guardrail.output_filter import OutputFilter
|
from app.services.guardrail.output_filter import OutputFilter
|
||||||
from app.services.guardrail.word_service import ForbiddenWordService
|
|
||||||
from app.services.intent.router import IntentRouter
|
from app.services.intent.router import IntentRouter
|
||||||
from app.services.intent.rule_service import IntentRuleService
|
from app.services.intent.rule_service import IntentRuleService
|
||||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse
|
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse
|
||||||
|
|
@ -90,6 +89,8 @@ class GenerationContext:
|
||||||
10. confidence_result: Confidence calculation result
|
10. confidence_result: Confidence calculation result
|
||||||
11. messages_saved: Whether messages were saved
|
11. messages_saved: Whether messages were saved
|
||||||
12. final_response: Final ChatResponse
|
12. final_response: Final ChatResponse
|
||||||
|
|
||||||
|
[v0.8.0] Extended with route_trace for hybrid routing observability.
|
||||||
"""
|
"""
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
session_id: str
|
session_id: str
|
||||||
|
|
@ -115,6 +116,11 @@ class GenerationContext:
|
||||||
target_kb_ids: list[str] | None = None
|
target_kb_ids: list[str] | None = None
|
||||||
behavior_rules: list[str] = field(default_factory=list)
|
behavior_rules: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
# [v0.8.0] Hybrid routing fields
|
||||||
|
route_trace: dict[str, Any] | None = None
|
||||||
|
fusion_confidence: float | None = None
|
||||||
|
fusion_decision_reason: str | None = None
|
||||||
|
|
||||||
diagnostics: dict[str, Any] = field(default_factory=dict)
|
diagnostics: dict[str, Any] = field(default_factory=dict)
|
||||||
execution_steps: list[dict[str, Any]] = field(default_factory=list)
|
execution_steps: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
@ -487,7 +493,7 @@ class OrchestratorService:
|
||||||
finish_reason="flow_step",
|
finish_reason="flow_step",
|
||||||
)
|
)
|
||||||
ctx.diagnostics["flow_handled"] = True
|
ctx.diagnostics["flow_handled"] = True
|
||||||
logger.info(f"[AC-AISVC-75] Flow provided reply, skipping LLM")
|
logger.info("[AC-AISVC-75] Flow provided reply, skipping LLM")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
ctx.diagnostics["flow_check_enabled"] = True
|
ctx.diagnostics["flow_check_enabled"] = True
|
||||||
|
|
@ -501,8 +507,8 @@ class OrchestratorService:
|
||||||
"""
|
"""
|
||||||
[AC-AISVC-69, AC-AISVC-70] Step 3: Match intent rules and route.
|
[AC-AISVC-69, AC-AISVC-70] Step 3: Match intent rules and route.
|
||||||
Routes to: fixed reply, RAG with target KBs, flow start, or transfer.
|
Routes to: fixed reply, RAG with target KBs, flow start, or transfer.
|
||||||
|
[v0.8.0] Upgraded to use match_hybrid() for hybrid routing.
|
||||||
"""
|
"""
|
||||||
# Skip if flow already handled the request
|
|
||||||
if ctx.diagnostics.get("flow_handled"):
|
if ctx.diagnostics.get("flow_handled"):
|
||||||
logger.info("[AC-AISVC-69] Flow already handled, skipping intent matching")
|
logger.info("[AC-AISVC-69] Flow already handled, skipping intent matching")
|
||||||
return
|
return
|
||||||
|
|
@ -513,7 +519,6 @@ class OrchestratorService:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load enabled rules ordered by priority
|
|
||||||
async with get_session() as session:
|
async with get_session() as session:
|
||||||
from app.services.intent.rule_service import IntentRuleService
|
from app.services.intent.rule_service import IntentRuleService
|
||||||
rule_service = IntentRuleService(session)
|
rule_service = IntentRuleService(session)
|
||||||
|
|
@ -524,33 +529,64 @@ class OrchestratorService:
|
||||||
ctx.diagnostics["intent_matched"] = False
|
ctx.diagnostics["intent_matched"] = False
|
||||||
return
|
return
|
||||||
|
|
||||||
# Match intent
|
fusion_result = await self._intent_router.match_hybrid(
|
||||||
ctx.intent_match = self._intent_router.match(
|
|
||||||
message=ctx.current_message,
|
message=ctx.current_message,
|
||||||
rules=rules,
|
rules=rules,
|
||||||
|
tenant_id=ctx.tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if ctx.intent_match:
|
ctx.route_trace = fusion_result.trace.to_dict()
|
||||||
|
ctx.fusion_confidence = fusion_result.final_confidence
|
||||||
|
ctx.fusion_decision_reason = fusion_result.decision_reason
|
||||||
|
|
||||||
|
if fusion_result.final_intent:
|
||||||
|
ctx.intent_match = type(
|
||||||
|
"IntentMatchResult",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"rule": fusion_result.final_intent,
|
||||||
|
"match_type": fusion_result.decision_reason,
|
||||||
|
"matched": "",
|
||||||
|
"to_dict": lambda: {
|
||||||
|
"rule_id": str(fusion_result.final_intent.id),
|
||||||
|
"rule_name": fusion_result.final_intent.name,
|
||||||
|
"match_type": fusion_result.decision_reason,
|
||||||
|
"matched": "",
|
||||||
|
"response_type": fusion_result.final_intent.response_type,
|
||||||
|
"target_kb_ids": (
|
||||||
|
fusion_result.final_intent.target_kb_ids or []
|
||||||
|
),
|
||||||
|
"flow_id": (
|
||||||
|
str(fusion_result.final_intent.flow_id)
|
||||||
|
if fusion_result.final_intent.flow_id else None
|
||||||
|
),
|
||||||
|
"fixed_reply": fusion_result.final_intent.fixed_reply,
|
||||||
|
"transfer_message": fusion_result.final_intent.transfer_message,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-AISVC-69] Intent matched: rule={ctx.intent_match.rule.name}, "
|
f"[AC-AISVC-69] Intent matched: rule={fusion_result.final_intent.name}, "
|
||||||
f"response_type={ctx.intent_match.rule.response_type}"
|
f"response_type={fusion_result.final_intent.response_type}, "
|
||||||
|
f"decision={fusion_result.decision_reason}, "
|
||||||
|
f"confidence={fusion_result.final_confidence:.3f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx.diagnostics["intent_match"] = ctx.intent_match.to_dict()
|
ctx.diagnostics["intent_match"] = ctx.intent_match.to_dict()
|
||||||
|
ctx.diagnostics["fusion_result"] = fusion_result.to_dict()
|
||||||
|
|
||||||
# Increment hit count
|
|
||||||
async with get_session() as session:
|
async with get_session() as session:
|
||||||
rule_service = IntentRuleService(session)
|
rule_service = IntentRuleService(session)
|
||||||
await rule_service.increment_hit_count(
|
await rule_service.increment_hit_count(
|
||||||
tenant_id=ctx.tenant_id,
|
tenant_id=ctx.tenant_id,
|
||||||
rule_id=ctx.intent_match.rule.id,
|
rule_id=fusion_result.final_intent.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Route based on response_type
|
rule = fusion_result.final_intent
|
||||||
if ctx.intent_match.rule.response_type == "fixed":
|
if rule.response_type == "fixed":
|
||||||
# Fixed reply - skip LLM
|
|
||||||
ctx.llm_response = LLMResponse(
|
ctx.llm_response = LLMResponse(
|
||||||
content=ctx.intent_match.rule.fixed_reply or "收到您的消息。",
|
content=rule.fixed_reply or "收到您的消息。",
|
||||||
model="intent_fixed",
|
model="intent_fixed",
|
||||||
usage={},
|
usage={},
|
||||||
finish_reason="intent_fixed",
|
finish_reason="intent_fixed",
|
||||||
|
|
@ -558,20 +594,18 @@ class OrchestratorService:
|
||||||
ctx.diagnostics["intent_handled"] = True
|
ctx.diagnostics["intent_handled"] = True
|
||||||
logger.info("[AC-AISVC-70] Intent fixed reply, skipping LLM")
|
logger.info("[AC-AISVC-70] Intent fixed reply, skipping LLM")
|
||||||
|
|
||||||
elif ctx.intent_match.rule.response_type == "rag":
|
elif rule.response_type == "rag":
|
||||||
# RAG with target KBs
|
ctx.target_kb_ids = rule.target_kb_ids or []
|
||||||
ctx.target_kb_ids = ctx.intent_match.rule.target_kb_ids or []
|
|
||||||
logger.info(f"[AC-AISVC-70] Intent RAG, target_kb_ids={ctx.target_kb_ids}")
|
logger.info(f"[AC-AISVC-70] Intent RAG, target_kb_ids={ctx.target_kb_ids}")
|
||||||
|
|
||||||
elif ctx.intent_match.rule.response_type == "flow":
|
elif rule.response_type == "flow":
|
||||||
# Start script flow
|
if rule.flow_id and self._flow_engine:
|
||||||
if ctx.intent_match.rule.flow_id and self._flow_engine:
|
|
||||||
async with get_session() as session:
|
async with get_session() as session:
|
||||||
flow_engine = FlowEngine(session)
|
flow_engine = FlowEngine(session)
|
||||||
instance, first_step = await flow_engine.start(
|
instance, first_step = await flow_engine.start(
|
||||||
tenant_id=ctx.tenant_id,
|
tenant_id=ctx.tenant_id,
|
||||||
session_id=ctx.session_id,
|
session_id=ctx.session_id,
|
||||||
flow_id=ctx.intent_match.rule.flow_id,
|
flow_id=rule.flow_id,
|
||||||
)
|
)
|
||||||
if first_step:
|
if first_step:
|
||||||
ctx.llm_response = LLMResponse(
|
ctx.llm_response = LLMResponse(
|
||||||
|
|
@ -583,10 +617,9 @@ class OrchestratorService:
|
||||||
ctx.diagnostics["intent_handled"] = True
|
ctx.diagnostics["intent_handled"] = True
|
||||||
logger.info("[AC-AISVC-70] Intent flow started, skipping LLM")
|
logger.info("[AC-AISVC-70] Intent flow started, skipping LLM")
|
||||||
|
|
||||||
elif ctx.intent_match.rule.response_type == "transfer":
|
elif rule.response_type == "transfer":
|
||||||
# Transfer to human
|
|
||||||
ctx.llm_response = LLMResponse(
|
ctx.llm_response = LLMResponse(
|
||||||
content=ctx.intent_match.rule.transfer_message or "正在为您转接人工客服...",
|
content=rule.transfer_message or "正在为您转接人工客服...",
|
||||||
model="intent_transfer",
|
model="intent_transfer",
|
||||||
usage={},
|
usage={},
|
||||||
finish_reason="intent_transfer",
|
finish_reason="intent_transfer",
|
||||||
|
|
@ -600,9 +633,25 @@ class OrchestratorService:
|
||||||
ctx.diagnostics["intent_handled"] = True
|
ctx.diagnostics["intent_handled"] = True
|
||||||
logger.info("[AC-AISVC-70] Intent transfer, skipping LLM")
|
logger.info("[AC-AISVC-70] Intent transfer, skipping LLM")
|
||||||
|
|
||||||
|
if fusion_result.need_clarify:
|
||||||
|
ctx.diagnostics["need_clarify"] = True
|
||||||
|
ctx.diagnostics["clarify_candidates"] = [
|
||||||
|
{"id": str(r.id), "name": r.name}
|
||||||
|
for r in (fusion_result.clarify_candidates or [])
|
||||||
|
]
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-121] Low confidence, need clarify: "
|
||||||
|
f"confidence={fusion_result.final_confidence:.3f}, "
|
||||||
|
f"candidates={len(fusion_result.clarify_candidates or [])}"
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
ctx.diagnostics["intent_match_enabled"] = True
|
ctx.diagnostics["intent_match_enabled"] = True
|
||||||
ctx.diagnostics["intent_matched"] = False
|
ctx.diagnostics["intent_matched"] = False
|
||||||
|
ctx.diagnostics["fusion_result"] = fusion_result.to_dict()
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-69] No intent matched, decision={fusion_result.decision_reason}"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[AC-AISVC-69] Intent matching failed: {e}")
|
logger.warning(f"[AC-AISVC-69] Intent matching failed: {e}")
|
||||||
|
|
@ -724,43 +773,43 @@ class OrchestratorService:
|
||||||
async def _build_metadata_filters(self, ctx: GenerationContext):
|
async def _build_metadata_filters(self, ctx: GenerationContext):
|
||||||
"""
|
"""
|
||||||
[AC-IDSMETA-19] Build metadata filters from context.
|
[AC-IDSMETA-19] Build metadata filters from context.
|
||||||
|
|
||||||
Sources:
|
Sources:
|
||||||
1. Intent rule metadata (if matched)
|
1. Intent rule metadata (if matched)
|
||||||
2. Session metadata
|
2. Session metadata
|
||||||
3. Request metadata
|
3. Request metadata
|
||||||
4. Extracted slots from conversation
|
4. Extracted slots from conversation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TagFilter with at least grade, subject, scene if available
|
TagFilter with at least grade, subject, scene if available
|
||||||
"""
|
"""
|
||||||
from app.services.retrieval.metadata import TagFilter
|
from app.services.retrieval.metadata import TagFilter
|
||||||
|
|
||||||
filter_fields = {}
|
filter_fields = {}
|
||||||
|
|
||||||
# 1. From intent rule metadata
|
# 1. From intent rule metadata
|
||||||
if ctx.intent_match and hasattr(ctx.intent_match.rule, 'metadata_') and ctx.intent_match.rule.metadata_:
|
if ctx.intent_match and hasattr(ctx.intent_match.rule, 'metadata_') and ctx.intent_match.rule.metadata_:
|
||||||
intent_metadata = ctx.intent_match.rule.metadata_
|
intent_metadata = ctx.intent_match.rule.metadata_
|
||||||
for key in ['grade', 'subject', 'scene']:
|
for key in ['grade', 'subject', 'scene']:
|
||||||
if key in intent_metadata:
|
if key in intent_metadata:
|
||||||
filter_fields[key] = intent_metadata[key]
|
filter_fields[key] = intent_metadata[key]
|
||||||
|
|
||||||
# 2. From session/request metadata
|
# 2. From session/request metadata
|
||||||
if ctx.request_metadata:
|
if ctx.request_metadata:
|
||||||
for key in ['grade', 'subject', 'scene']:
|
for key in ['grade', 'subject', 'scene']:
|
||||||
if key in ctx.request_metadata and key not in filter_fields:
|
if key in ctx.request_metadata and key not in filter_fields:
|
||||||
filter_fields[key] = ctx.request_metadata[key]
|
filter_fields[key] = ctx.request_metadata[key]
|
||||||
|
|
||||||
# 3. From merged context (extracted slots)
|
# 3. From merged context (extracted slots)
|
||||||
if ctx.merged_context and hasattr(ctx.merged_context, 'slots'):
|
if ctx.merged_context and hasattr(ctx.merged_context, 'slots'):
|
||||||
slots = ctx.merged_context.slots or {}
|
slots = ctx.merged_context.slots or {}
|
||||||
for key in ['grade', 'subject', 'scene']:
|
for key in ['grade', 'subject', 'scene']:
|
||||||
if key in slots and key not in filter_fields:
|
if key in slots and key not in filter_fields:
|
||||||
filter_fields[key] = slots[key]
|
filter_fields[key] = slots[key]
|
||||||
|
|
||||||
if not filter_fields:
|
if not filter_fields:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return TagFilter(fields=filter_fields)
|
return TagFilter(fields=filter_fields)
|
||||||
|
|
||||||
async def _build_system_prompt(self, ctx: GenerationContext) -> None:
|
async def _build_system_prompt(self, ctx: GenerationContext) -> None:
|
||||||
|
|
@ -981,11 +1030,11 @@ class OrchestratorService:
|
||||||
"根据知识库信息,我找到了一些相关内容,"
|
"根据知识库信息,我找到了一些相关内容,"
|
||||||
"但暂时无法生成完整回复。建议您稍后重试或联系人工客服。"
|
"但暂时无法生成完整回复。建议您稍后重试或联系人工客服。"
|
||||||
)
|
)
|
||||||
|
|
||||||
# [AC-IDSMETA-20] Record structured fallback reason code
|
# [AC-IDSMETA-20] Record structured fallback reason code
|
||||||
fallback_reason_code = self._determine_fallback_reason_code(ctx)
|
fallback_reason_code = self._determine_fallback_reason_code(ctx)
|
||||||
ctx.diagnostics["fallback_reason_code"] = fallback_reason_code
|
ctx.diagnostics["fallback_reason_code"] = fallback_reason_code
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[AC-IDSMETA-20] No recall, using fallback: "
|
f"[AC-IDSMETA-20] No recall, using fallback: "
|
||||||
f"intent_id={ctx.intent_match.rule.id if ctx.intent_match else None}, "
|
f"intent_id={ctx.intent_match.rule.id if ctx.intent_match else None}, "
|
||||||
|
|
@ -993,7 +1042,7 @@ class OrchestratorService:
|
||||||
f"applied_metadata_filters={ctx.diagnostics.get('retrieval', {}).get('applied_metadata_filters')}, "
|
f"applied_metadata_filters={ctx.diagnostics.get('retrieval', {}).get('applied_metadata_filters')}, "
|
||||||
f"fallback_reason_code={fallback_reason_code}"
|
f"fallback_reason_code={fallback_reason_code}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
"抱歉,我暂时无法处理您的请求。"
|
"抱歉,我暂时无法处理您的请求。"
|
||||||
"请稍后重试或联系人工客服获取帮助。"
|
"请稍后重试或联系人工客服获取帮助。"
|
||||||
|
|
@ -1002,7 +1051,7 @@ class OrchestratorService:
|
||||||
def _determine_fallback_reason_code(self, ctx: GenerationContext) -> str:
|
def _determine_fallback_reason_code(self, ctx: GenerationContext) -> str:
|
||||||
"""
|
"""
|
||||||
[AC-IDSMETA-20] Determine structured fallback reason code.
|
[AC-IDSMETA-20] Determine structured fallback reason code.
|
||||||
|
|
||||||
Reason codes:
|
Reason codes:
|
||||||
- no_recall_after_metadata_filter: No results after applying metadata filters
|
- no_recall_after_metadata_filter: No results after applying metadata filters
|
||||||
- no_recall_no_kb: No target knowledge bases configured
|
- no_recall_no_kb: No target knowledge bases configured
|
||||||
|
|
@ -1011,27 +1060,27 @@ class OrchestratorService:
|
||||||
- no_recall_error: Retrieval error occurred
|
- no_recall_error: Retrieval error occurred
|
||||||
"""
|
"""
|
||||||
retrieval_diag = ctx.diagnostics.get("retrieval", {})
|
retrieval_diag = ctx.diagnostics.get("retrieval", {})
|
||||||
|
|
||||||
# Check for retrieval error
|
# Check for retrieval error
|
||||||
if ctx.diagnostics.get("retrieval_error"):
|
if ctx.diagnostics.get("retrieval_error"):
|
||||||
return "no_recall_error"
|
return "no_recall_error"
|
||||||
|
|
||||||
# Check if metadata filters were applied
|
# Check if metadata filters were applied
|
||||||
if retrieval_diag.get("applied_metadata_filters"):
|
if retrieval_diag.get("applied_metadata_filters"):
|
||||||
return "no_recall_after_metadata_filter"
|
return "no_recall_after_metadata_filter"
|
||||||
|
|
||||||
# Check if target KBs were configured
|
# Check if target KBs were configured
|
||||||
if not ctx.target_kb_ids:
|
if not ctx.target_kb_ids:
|
||||||
return "no_recall_no_kb"
|
return "no_recall_no_kb"
|
||||||
|
|
||||||
# Check if KB is empty (no candidates at all)
|
# Check if KB is empty (no candidates at all)
|
||||||
if retrieval_diag.get("total_candidates", 0) == 0:
|
if retrieval_diag.get("total_candidates", 0) == 0:
|
||||||
return "no_recall_kb_empty"
|
return "no_recall_kb_empty"
|
||||||
|
|
||||||
# Results found but filtered out by score threshold
|
# Results found but filtered out by score threshold
|
||||||
if retrieval_diag.get("total_candidates", 0) > 0 and retrieval_diag.get("filtered_hits", 0) == 0:
|
if retrieval_diag.get("total_candidates", 0) > 0 and retrieval_diag.get("filtered_hits", 0) == 0:
|
||||||
return "no_recall_low_score"
|
return "no_recall_low_score"
|
||||||
|
|
||||||
return "no_recall_unknown"
|
return "no_recall_unknown"
|
||||||
|
|
||||||
def _calculate_confidence(self, ctx: GenerationContext) -> None:
|
def _calculate_confidence(self, ctx: GenerationContext) -> None:
|
||||||
|
|
@ -1122,6 +1171,7 @@ class OrchestratorService:
|
||||||
[AC-AISVC-02] Build final ChatResponse from generation context.
|
[AC-AISVC-02] Build final ChatResponse from generation context.
|
||||||
Step 12 of the 12-step pipeline.
|
Step 12 of the 12-step pipeline.
|
||||||
Uses filtered_reply from Step 9.
|
Uses filtered_reply from Step 9.
|
||||||
|
[v0.8.0] Includes route_trace in response metadata.
|
||||||
"""
|
"""
|
||||||
# Use filtered_reply if available, otherwise use llm_response.content
|
# Use filtered_reply if available, otherwise use llm_response.content
|
||||||
if ctx.filtered_reply:
|
if ctx.filtered_reply:
|
||||||
|
|
@ -1142,6 +1192,10 @@ class OrchestratorService:
|
||||||
"execution_steps": ctx.execution_steps,
|
"execution_steps": ctx.execution_steps,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# [v0.8.0] Include route_trace in response metadata
|
||||||
|
if ctx.route_trace:
|
||||||
|
response_metadata["route_trace"] = ctx.route_trace
|
||||||
|
|
||||||
return ChatResponse(
|
return ChatResponse(
|
||||||
reply=reply,
|
reply=reply,
|
||||||
confidence=confidence,
|
confidence=confidence,
|
||||||
|
|
|
||||||
|
|
@ -178,6 +178,9 @@ class PromptTemplateService:
|
||||||
current_version = v
|
current_version = v
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Get latest version for current_content (not just published)
|
||||||
|
latest_version = versions[0] if versions else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": str(template.id),
|
"id": str(template.id),
|
||||||
"name": template.name,
|
"name": template.name,
|
||||||
|
|
@ -185,6 +188,8 @@ class PromptTemplateService:
|
||||||
"description": template.description,
|
"description": template.description,
|
||||||
"is_default": template.is_default,
|
"is_default": template.is_default,
|
||||||
"metadata": template.metadata_,
|
"metadata": template.metadata_,
|
||||||
|
"current_content": latest_version.system_instruction if latest_version else None,
|
||||||
|
"variables": latest_version.variables if latest_version else [],
|
||||||
"current_version": {
|
"current_version": {
|
||||||
"version": current_version.version,
|
"version": current_version.version,
|
||||||
"status": current_version.status,
|
"status": current_version.status,
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ class RetrievalContext:
|
||||||
metadata: dict[str, Any] | None = None
|
metadata: dict[str, Any] | None = None
|
||||||
tag_filter: "TagFilter | None" = None
|
tag_filter: "TagFilter | None" = None
|
||||||
kb_ids: list[str] | None = None
|
kb_ids: list[str] | None = None
|
||||||
|
metadata_filter: dict[str, Any] | None = None
|
||||||
|
|
||||||
def get_tag_filter_dict(self) -> dict[str, str | list[str] | None] | None:
|
def get_tag_filter_dict(self) -> dict[str, str | list[str] | None] | None:
|
||||||
"""获取标签过滤器的字典表示"""
|
"""获取标签过滤器的字典表示"""
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ Vector retriever for AI Service.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
||||||
|
|
@ -76,16 +77,30 @@ class VectorRetriever(BaseRetriever):
|
||||||
query_vector = await self._get_embedding(ctx.query)
|
query_vector = await self._get_embedding(ctx.query)
|
||||||
logger.info(f"[AC-AISVC-16] Embedding generated: dim={len(query_vector)}")
|
logger.info(f"[AC-AISVC-16] Embedding generated: dim={len(query_vector)}")
|
||||||
|
|
||||||
logger.info(f"[AC-AISVC-16] Searching in tenant collection: tenant_id={ctx.tenant_id}")
|
logger.info(f"[AC-AISVC-16] Searching in tenant collections: tenant_id={ctx.tenant_id}")
|
||||||
hits = await client.search(
|
if ctx.kb_ids:
|
||||||
tenant_id=ctx.tenant_id,
|
logger.info(f"[AC-AISVC-16] Restricting search to KB IDs: {ctx.kb_ids}")
|
||||||
query_vector=query_vector,
|
hits = await client.search(
|
||||||
limit=self._top_k,
|
tenant_id=ctx.tenant_id,
|
||||||
score_threshold=self._score_threshold,
|
query_vector=query_vector,
|
||||||
)
|
limit=self._top_k,
|
||||||
|
score_threshold=self._score_threshold,
|
||||||
logger.info(f"[AC-AISVC-16] Search returned {len(hits)} raw hits")
|
vector_name="full",
|
||||||
|
metadata_filter=ctx.metadata_filter,
|
||||||
|
kb_ids=ctx.kb_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hits = await client.search(
|
||||||
|
tenant_id=ctx.tenant_id,
|
||||||
|
query_vector=query_vector,
|
||||||
|
limit=self._top_k,
|
||||||
|
score_threshold=self._score_threshold,
|
||||||
|
vector_name="full",
|
||||||
|
metadata_filter=ctx.metadata_filter,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"[AC-AISVC-16] Search returned {len(hits)} hits")
|
||||||
|
|
||||||
retrieval_hits = [
|
retrieval_hits = [
|
||||||
RetrievalHit(
|
RetrievalHit(
|
||||||
text=hit.get("payload", {}).get("text", ""),
|
text=hit.get("payload", {}).get("text", ""),
|
||||||
|
|
@ -133,6 +148,47 @@ class VectorRetriever(BaseRetriever):
|
||||||
diagnostics={"error": str(e), "is_insufficient": True},
|
diagnostics={"error": str(e), "is_insufficient": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _apply_metadata_filter(
|
||||||
|
self,
|
||||||
|
hits: list[dict[str, Any]],
|
||||||
|
metadata_filter: dict[str, Any],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
应用元数据过滤条件。
|
||||||
|
|
||||||
|
支持的操作:
|
||||||
|
- {"$eq": value} : 等于
|
||||||
|
- {"$in": [values]} : 在列表中
|
||||||
|
"""
|
||||||
|
filtered = []
|
||||||
|
for hit in hits:
|
||||||
|
payload = hit.get("payload", {})
|
||||||
|
hit_metadata = payload.get("metadata", {})
|
||||||
|
|
||||||
|
match = True
|
||||||
|
for field_key, condition in metadata_filter.items():
|
||||||
|
hit_value = hit_metadata.get(field_key)
|
||||||
|
|
||||||
|
if isinstance(condition, dict):
|
||||||
|
if "$eq" in condition:
|
||||||
|
if hit_value != condition["$eq"]:
|
||||||
|
match = False
|
||||||
|
break
|
||||||
|
elif "$in" in condition:
|
||||||
|
if hit_value not in condition["$in"]:
|
||||||
|
match = False
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 直接值比较
|
||||||
|
if hit_value != condition:
|
||||||
|
match = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if match:
|
||||||
|
filtered.append(hit)
|
||||||
|
|
||||||
|
return filtered
|
||||||
|
|
||||||
async def _get_embedding(self, text: str) -> list[float]:
|
async def _get_embedding(self, text: str) -> list[float]:
|
||||||
"""
|
"""
|
||||||
Generate embedding for text using pluggable embedding provider.
|
Generate embedding for text using pluggable embedding provider.
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""
|
"""
|
||||||
Slot Definition Service.
|
Slot Definition Service.
|
||||||
[AC-MRS-07, AC-MRS-08] 槽位定义管理服务
|
[AC-MRS-07, AC-MRS-08] 槽位定义管理服务
|
||||||
|
[AC-MRS-07-UPGRADE] 支持提取策略链 extract_strategies
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -27,6 +28,7 @@ logger = logging.getLogger(__name__)
|
||||||
class SlotDefinitionService:
|
class SlotDefinitionService:
|
||||||
"""
|
"""
|
||||||
[AC-MRS-07, AC-MRS-08] 槽位定义服务
|
[AC-MRS-07, AC-MRS-08] 槽位定义服务
|
||||||
|
[AC-MRS-07-UPGRADE] 支持提取策略链管理
|
||||||
|
|
||||||
管理独立的槽位定义模型,与元数据字段解耦但可复用
|
管理独立的槽位定义模型,与元数据字段解耦但可复用
|
||||||
"""
|
"""
|
||||||
|
|
@ -114,6 +116,58 @@ class SlotDefinitionService:
|
||||||
result = await self._session.execute(stmt)
|
result = await self._session.execute(stmt)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
def _validate_strategies(self, strategies: list[str] | None) -> tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
[AC-MRS-07-UPGRADE] 校验提取策略链的有效性
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strategies: 策略链列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (是否有效, 错误信息)
|
||||||
|
"""
|
||||||
|
if strategies is None:
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
if not isinstance(strategies, list):
|
||||||
|
return False, "extract_strategies 必须是数组类型"
|
||||||
|
|
||||||
|
if len(strategies) == 0:
|
||||||
|
return False, "提取策略链不能为空数组"
|
||||||
|
|
||||||
|
# 校验不允许重复策略
|
||||||
|
if len(strategies) != len(set(strategies)):
|
||||||
|
return False, "提取策略链中不允许重复的策略"
|
||||||
|
|
||||||
|
# 校验策略值有效
|
||||||
|
invalid = [s for s in strategies if s not in self.VALID_EXTRACT_STRATEGIES]
|
||||||
|
if invalid:
|
||||||
|
return False, f"无效的提取策略: {invalid},有效值为: {self.VALID_EXTRACT_STRATEGIES}"
|
||||||
|
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
def _normalize_strategies(
|
||||||
|
self,
|
||||||
|
extract_strategies: list[str] | None,
|
||||||
|
extract_strategy: str | None,
|
||||||
|
) -> list[str] | None:
|
||||||
|
"""
|
||||||
|
[AC-MRS-07-UPGRADE] 规范化提取策略
|
||||||
|
优先使用 extract_strategies,如果不存在则使用 extract_strategy
|
||||||
|
|
||||||
|
Args:
|
||||||
|
extract_strategies: 策略链(新字段)
|
||||||
|
extract_strategy: 单策略(旧字段,兼容)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
规范化后的策略链或 None
|
||||||
|
"""
|
||||||
|
if extract_strategies is not None:
|
||||||
|
return extract_strategies
|
||||||
|
if extract_strategy:
|
||||||
|
return [extract_strategy]
|
||||||
|
return None
|
||||||
|
|
||||||
async def create_slot_definition(
|
async def create_slot_definition(
|
||||||
self,
|
self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
|
@ -121,6 +175,7 @@ class SlotDefinitionService:
|
||||||
) -> SlotDefinition:
|
) -> SlotDefinition:
|
||||||
"""
|
"""
|
||||||
[AC-MRS-07, AC-MRS-08] 创建槽位定义
|
[AC-MRS-07, AC-MRS-08] 创建槽位定义
|
||||||
|
[AC-MRS-07-UPGRADE] 支持提取策略链
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tenant_id: 租户 ID
|
tenant_id: 租户 ID
|
||||||
|
|
@ -148,11 +203,16 @@ class SlotDefinitionService:
|
||||||
f"有效类型为: {self.VALID_TYPES}"
|
f"有效类型为: {self.VALID_TYPES}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if slot_create.extract_strategy and slot_create.extract_strategy not in self.VALID_EXTRACT_STRATEGIES:
|
# [AC-MRS-07-UPGRADE] 规范化并校验提取策略链
|
||||||
raise ValueError(
|
strategies = self._normalize_strategies(
|
||||||
f"无效的提取策略 '{slot_create.extract_strategy}',"
|
slot_create.extract_strategies,
|
||||||
f"有效策略为: {self.VALID_EXTRACT_STRATEGIES}"
|
slot_create.extract_strategy
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if strategies is not None:
|
||||||
|
is_valid, error_msg = self._validate_strategies(strategies)
|
||||||
|
if not is_valid:
|
||||||
|
raise ValueError(f"提取策略链校验失败: {error_msg}")
|
||||||
|
|
||||||
linked_field = None
|
linked_field = None
|
||||||
if slot_create.linked_field_id:
|
if slot_create.linked_field_id:
|
||||||
|
|
@ -162,12 +222,22 @@ class SlotDefinitionService:
|
||||||
f"[AC-MRS-08] 关联的元数据字段 '{slot_create.linked_field_id}' 不存在"
|
f"[AC-MRS-08] 关联的元数据字段 '{slot_create.linked_field_id}' 不存在"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# [AC-MRS-07-UPGRADE] 确定要保存的旧字段值
|
||||||
|
# 如果前端提交了 extract_strategies,则使用第一个作为旧字段值
|
||||||
|
old_strategy = slot_create.extract_strategy
|
||||||
|
if not old_strategy and strategies and len(strategies) > 0:
|
||||||
|
old_strategy = strategies[0]
|
||||||
|
|
||||||
slot = SlotDefinition(
|
slot = SlotDefinition(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
slot_key=slot_create.slot_key,
|
slot_key=slot_create.slot_key,
|
||||||
|
display_name=slot_create.display_name,
|
||||||
|
description=slot_create.description,
|
||||||
type=slot_create.type,
|
type=slot_create.type,
|
||||||
required=slot_create.required,
|
required=slot_create.required,
|
||||||
extract_strategy=slot_create.extract_strategy,
|
# [AC-MRS-07-UPGRADE] 同时保存新旧字段
|
||||||
|
extract_strategy=old_strategy,
|
||||||
|
extract_strategies=strategies,
|
||||||
validation_rule=slot_create.validation_rule,
|
validation_rule=slot_create.validation_rule,
|
||||||
ask_back_prompt=slot_create.ask_back_prompt,
|
ask_back_prompt=slot_create.ask_back_prompt,
|
||||||
default_value=slot_create.default_value,
|
default_value=slot_create.default_value,
|
||||||
|
|
@ -180,6 +250,7 @@ class SlotDefinitionService:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-MRS-07] Created slot definition: tenant={tenant_id}, "
|
f"[AC-MRS-07] Created slot definition: tenant={tenant_id}, "
|
||||||
f"slot_key={slot.slot_key}, required={slot.required}, "
|
f"slot_key={slot.slot_key}, required={slot.required}, "
|
||||||
|
f"strategies={strategies}, "
|
||||||
f"linked_field_id={slot.linked_field_id}"
|
f"linked_field_id={slot.linked_field_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -193,6 +264,7 @@ class SlotDefinitionService:
|
||||||
) -> SlotDefinition | None:
|
) -> SlotDefinition | None:
|
||||||
"""
|
"""
|
||||||
更新槽位定义
|
更新槽位定义
|
||||||
|
[AC-MRS-07-UPGRADE] 支持提取策略链更新
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tenant_id: 租户 ID
|
tenant_id: 租户 ID
|
||||||
|
|
@ -206,6 +278,12 @@ class SlotDefinitionService:
|
||||||
if not slot:
|
if not slot:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if slot_update.display_name is not None:
|
||||||
|
slot.display_name = slot_update.display_name
|
||||||
|
|
||||||
|
if slot_update.description is not None:
|
||||||
|
slot.description = slot_update.description
|
||||||
|
|
||||||
if slot_update.type is not None:
|
if slot_update.type is not None:
|
||||||
if slot_update.type not in self.VALID_TYPES:
|
if slot_update.type not in self.VALID_TYPES:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
@ -217,13 +295,28 @@ class SlotDefinitionService:
|
||||||
if slot_update.required is not None:
|
if slot_update.required is not None:
|
||||||
slot.required = slot_update.required
|
slot.required = slot_update.required
|
||||||
|
|
||||||
if slot_update.extract_strategy is not None:
|
# [AC-MRS-07-UPGRADE] 处理提取策略链更新
|
||||||
if slot_update.extract_strategy not in self.VALID_EXTRACT_STRATEGIES:
|
# 如果传入了 extract_strategies 或 extract_strategy,则更新
|
||||||
raise ValueError(
|
if slot_update.extract_strategies is not None or slot_update.extract_strategy is not None:
|
||||||
f"无效的提取策略 '{slot_update.extract_strategy}',"
|
strategies = self._normalize_strategies(
|
||||||
f"有效策略为: {self.VALID_EXTRACT_STRATEGIES}"
|
slot_update.extract_strategies,
|
||||||
)
|
slot_update.extract_strategy
|
||||||
slot.extract_strategy = slot_update.extract_strategy
|
)
|
||||||
|
|
||||||
|
if strategies is not None:
|
||||||
|
is_valid, error_msg = self._validate_strategies(strategies)
|
||||||
|
if not is_valid:
|
||||||
|
raise ValueError(f"提取策略链校验失败: {error_msg}")
|
||||||
|
|
||||||
|
# [AC-MRS-07-UPGRADE] 同时更新新旧字段
|
||||||
|
slot.extract_strategies = strategies
|
||||||
|
# 如果前端提交了 extract_strategy,则使用它;否则使用策略链的第一个
|
||||||
|
if slot_update.extract_strategy is not None:
|
||||||
|
slot.extract_strategy = slot_update.extract_strategy
|
||||||
|
elif strategies and len(strategies) > 0:
|
||||||
|
slot.extract_strategy = strategies[0]
|
||||||
|
else:
|
||||||
|
slot.extract_strategy = None
|
||||||
|
|
||||||
if slot_update.validation_rule is not None:
|
if slot_update.validation_rule is not None:
|
||||||
slot.validation_rule = slot_update.validation_rule
|
slot.validation_rule = slot_update.validation_rule
|
||||||
|
|
@ -250,7 +343,7 @@ class SlotDefinitionService:
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-MRS-07] Updated slot definition: tenant={tenant_id}, "
|
f"[AC-MRS-07] Updated slot definition: tenant={tenant_id}, "
|
||||||
f"slot_id={slot_id}"
|
f"slot_id={slot_id}, strategies={slot.extract_strategies}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return slot
|
return slot
|
||||||
|
|
@ -331,9 +424,13 @@ class SlotDefinitionService:
|
||||||
"id": str(slot.id),
|
"id": str(slot.id),
|
||||||
"tenant_id": slot.tenant_id,
|
"tenant_id": slot.tenant_id,
|
||||||
"slot_key": slot.slot_key,
|
"slot_key": slot.slot_key,
|
||||||
|
"display_name": slot.display_name,
|
||||||
|
"description": slot.description,
|
||||||
"type": slot.type,
|
"type": slot.type,
|
||||||
"required": slot.required,
|
"required": slot.required,
|
||||||
|
# [AC-MRS-07-UPGRADE] 返回新旧字段
|
||||||
"extract_strategy": slot.extract_strategy,
|
"extract_strategy": slot.extract_strategy,
|
||||||
|
"extract_strategies": slot.extract_strategies,
|
||||||
"validation_rule": slot.validation_rule,
|
"validation_rule": slot.validation_rule,
|
||||||
"ask_back_prompt": slot.ask_back_prompt,
|
"ask_back_prompt": slot.ask_back_prompt,
|
||||||
"default_value": slot.default_value,
|
"default_value": slot.default_value,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue