feat: update core backend services including LLM, embedding, KB, orchestrator and admin APIs [AC-AISVC-CORE]

This commit is contained in:
MerCry 2026-03-10 12:09:45 +08:00
parent 759eafb490
commit fe883cfff0
27 changed files with 1704 additions and 109 deletions

View File

@ -2,6 +2,7 @@
Admin API routes for AI Service management. Admin API routes for AI Service management.
[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08, AC-AISVC-50] Admin management endpoints. [AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08, AC-AISVC-50] Admin management endpoints.
[AC-MRS-07,08,16] Slot definition management endpoints. [AC-MRS-07,08,16] Slot definition management endpoints.
[AC-SCENE-SLOT-01] Scene slot bundle management endpoints.
""" """
from app.api.admin.api_key import router as api_key_router from app.api.admin.api_key import router as api_key_router
@ -18,6 +19,7 @@ from app.api.admin.metadata_schema import router as metadata_schema_router
from app.api.admin.monitoring import router as monitoring_router from app.api.admin.monitoring import router as monitoring_router
from app.api.admin.prompt_templates import router as prompt_templates_router from app.api.admin.prompt_templates import router as prompt_templates_router
from app.api.admin.rag import router as rag_router from app.api.admin.rag import router as rag_router
from app.api.admin.scene_slot_bundle import router as scene_slot_bundle_router
from app.api.admin.script_flows import router as script_flows_router from app.api.admin.script_flows import router as script_flows_router
from app.api.admin.sessions import router as sessions_router from app.api.admin.sessions import router as sessions_router
from app.api.admin.slot_definition import router as slot_definition_router from app.api.admin.slot_definition import router as slot_definition_router
@ -38,6 +40,7 @@ __all__ = [
"monitoring_router", "monitoring_router",
"prompt_templates_router", "prompt_templates_router",
"rag_router", "rag_router",
"scene_slot_bundle_router",
"script_flows_router", "script_flows_router",
"sessions_router", "sessions_router",
"slot_definition_router", "slot_definition_router",

View File

@ -2,6 +2,8 @@
Intent Rule Management API. Intent Rule Management API.
[AC-AISVC-65~AC-AISVC-68] Intent rule CRUD endpoints. [AC-AISVC-65~AC-AISVC-68] Intent rule CRUD endpoints.
[AC-AISVC-96] Intent rule testing endpoint. [AC-AISVC-96] Intent rule testing endpoint.
[AC-AISVC-116] Fusion config management endpoints.
[AC-AISVC-114] Intent vector generation endpoint.
""" """
import logging import logging
@ -14,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session from app.core.database import get_session
from app.models.entities import IntentRuleCreate, IntentRuleUpdate from app.models.entities import IntentRuleCreate, IntentRuleUpdate
from app.services.intent.models import DEFAULT_FUSION_CONFIG, FusionConfig
from app.services.intent.rule_service import IntentRuleService from app.services.intent.rule_service import IntentRuleService
from app.services.intent.tester import IntentRuleTester from app.services.intent.tester import IntentRuleTester
@ -21,6 +24,8 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/intent-rules", tags=["Intent Rules"]) router = APIRouter(prefix="/admin/intent-rules", tags=["Intent Rules"])
_fusion_config = FusionConfig()
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str: def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
"""Extract tenant ID from header.""" """Extract tenant ID from header."""
@ -204,3 +209,109 @@ async def test_rule(
result = await tester.test_rule(rule, [body.message], all_rules) result = await tester.test_rule(rule, [body.message], all_rules)
return result.to_dict() return result.to_dict()
class FusionConfigUpdate(BaseModel):
"""Request body for updating fusion config."""
w_rule: float | None = None
w_semantic: float | None = None
w_llm: float | None = None
semantic_threshold: float | None = None
conflict_threshold: float | None = None
gray_zone_threshold: float | None = None
min_trigger_threshold: float | None = None
clarify_threshold: float | None = None
multi_intent_threshold: float | None = None
llm_judge_enabled: bool | None = None
semantic_matcher_enabled: bool | None = None
semantic_matcher_timeout_ms: int | None = None
llm_judge_timeout_ms: int | None = None
semantic_top_k: int | None = None
@router.get("/fusion-config")
async def get_fusion_config() -> dict[str, Any]:
"""
[AC-AISVC-116] Get current fusion configuration.
"""
logger.info("[AC-AISVC-116] Getting fusion config")
return _fusion_config.to_dict()
@router.put("/fusion-config")
async def update_fusion_config(
body: FusionConfigUpdate,
) -> dict[str, Any]:
"""
[AC-AISVC-116] Update fusion configuration.
"""
global _fusion_config
logger.info(f"[AC-AISVC-116] Updating fusion config: {body.model_dump()}")
current_dict = _fusion_config.to_dict()
update_dict = body.model_dump(exclude_none=True)
current_dict.update(update_dict)
_fusion_config = FusionConfig.from_dict(current_dict)
return _fusion_config.to_dict()
@router.post("/{rule_id}/generate-vector")
async def generate_intent_vector(
rule_id: uuid.UUID,
tenant_id: str = Depends(get_tenant_id),
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""
[AC-AISVC-114] Generate intent vector for a rule.
Uses the rule's semantic_examples to generate an average vector.
If no semantic_examples exist, returns an error.
"""
logger.info(
f"[AC-AISVC-114] Generating intent vector for tenant={tenant_id}, rule_id={rule_id}"
)
service = IntentRuleService(session)
rule = await service.get_rule(tenant_id, rule_id)
if not rule:
raise HTTPException(status_code=404, detail="Intent rule not found")
if not rule.semantic_examples:
raise HTTPException(
status_code=400,
detail="Rule has no semantic_examples. Please add semantic_examples first."
)
try:
from app.core.dependencies import get_embedding_provider
embedding_provider = get_embedding_provider()
vectors = await embedding_provider.embed_batch(rule.semantic_examples)
import numpy as np
avg_vector = np.mean(vectors, axis=0).tolist()
update_data = IntentRuleUpdate(intent_vector=avg_vector)
updated_rule = await service.update_rule(tenant_id, rule_id, update_data)
logger.info(
f"[AC-AISVC-114] Generated intent vector for rule={rule_id}, "
f"dimension={len(avg_vector)}"
)
return {
"id": str(updated_rule.id),
"intent_vector": updated_rule.intent_vector,
"semantic_examples": updated_rule.semantic_examples,
}
except Exception as e:
logger.error(f"[AC-AISVC-114] Failed to generate intent vector: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to generate intent vector: {str(e)}"
)

View File

@ -6,28 +6,35 @@ Knowledge Base management endpoints.
import logging import logging
import uuid import uuid
import json
import hashlib
from dataclasses import dataclass from dataclasses import dataclass
from typing import Annotated, Any, Optional from typing import Annotated, Any, Optional
import tiktoken import tiktoken
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, Query, UploadFile from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import get_settings
from app.core.database import get_session from app.core.database import get_session
from app.core.exceptions import MissingTenantIdException from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id from app.core.tenant import get_tenant_id
from app.models import ErrorResponse from app.models import ErrorResponse
from app.models.entities import ( from app.models.entities import (
Document,
DocumentStatus,
IndexJob, IndexJob,
IndexJobStatus, IndexJobStatus,
KBType, KBType,
KnowledgeBase,
KnowledgeBaseCreate, KnowledgeBaseCreate,
KnowledgeBaseUpdate, KnowledgeBaseUpdate,
) )
from app.services.kb import KBService from app.services.kb import KBService
from app.services.knowledge_base_service import KnowledgeBaseService from app.services.knowledge_base_service import KnowledgeBaseService
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -457,6 +464,7 @@ async def list_documents(
"kbId": doc.kb_id, "kbId": doc.kb_id,
"fileName": doc.file_name, "fileName": doc.file_name,
"status": doc.status, "status": doc.status,
"metadata": doc.doc_metadata,
"jobId": str(latest_job.id) if latest_job else None, "jobId": str(latest_job.id) if latest_job else None,
"createdAt": doc.created_at.isoformat() + "Z", "createdAt": doc.created_at.isoformat() + "Z",
"updatedAt": doc.updated_at.isoformat() + "Z", "updatedAt": doc.updated_at.isoformat() + "Z",
@ -585,6 +593,7 @@ async def upload_document(
file_name=file.filename or "unknown", file_name=file.filename or "unknown",
file_content=file_content, file_content=file_content,
file_type=file.content_type, file_type=file.content_type,
metadata=metadata_dict,
) )
await kb_service.update_doc_count(tenant_id, kb_id, delta=1) await kb_service.update_doc_count(tenant_id, kb_id, delta=1)
@ -915,3 +924,488 @@ async def delete_document(
"message": "Document deleted", "message": "Document deleted",
} }
) )
@router.put(
"/documents/{doc_id}/metadata",
operation_id="updateDocumentMetadata",
summary="Update document metadata",
description="[AC-ASA-08] Update metadata for a specific document.",
responses={
200: {"description": "Metadata updated"},
404: {"description": "Document not found"},
401: {"description": "Unauthorized", "model": ErrorResponse},
403: {"description": "Forbidden", "model": ErrorResponse},
},
)
async def update_document_metadata(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
doc_id: str,
body: dict,
) -> JSONResponse:
"""
[AC-ASA-08] Update document metadata.
"""
import json
metadata = body.get("metadata")
if metadata is not None and not isinstance(metadata, dict):
try:
metadata = json.loads(metadata) if isinstance(metadata, str) else metadata
except json.JSONDecodeError:
return JSONResponse(
status_code=400,
content={
"code": "INVALID_METADATA",
"message": "Invalid JSON format for metadata",
},
)
logger.info(
f"[AC-ASA-08] Updating document metadata: tenant={tenant_id}, doc_id={doc_id}"
)
from sqlalchemy import select
from app.models.entities import Document
stmt = select(Document).where(
Document.tenant_id == tenant_id,
Document.id == doc_id,
)
result = await session.execute(stmt)
document = result.scalar_one_or_none()
if not document:
return JSONResponse(
status_code=404,
content={
"code": "DOCUMENT_NOT_FOUND",
"message": f"Document {doc_id} not found",
},
)
document.doc_metadata = metadata
await session.commit()
return JSONResponse(
content={
"success": True,
"message": "Metadata updated",
"metadata": document.doc_metadata,
}
)
@router.post(
"/documents/batch-upload",
operation_id="batchUploadDocuments",
summary="Batch upload documents from zip",
description="Upload a zip file containing multiple folders, each with a markdown file and metadata.json",
responses={
200: {"description": "Batch upload result"},
400: {"description": "Bad Request - invalid zip or missing files"},
401: {"description": "Unauthorized", "model": ErrorResponse},
403: {"description": "Forbidden", "model": ErrorResponse},
},
)
async def batch_upload_documents(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
background_tasks: BackgroundTasks,
file: UploadFile = File(...),
kb_id: str = Form(...),
) -> JSONResponse:
"""
Batch upload documents from a zip file.
Zip structure:
- Each folder contains one .md file and one metadata.json
- metadata.json uses field_key from MetadataFieldDefinition as keys
Example metadata.json:
{
"grade": "高一",
"subject": "数学",
"type": "痛点"
}
"""
import json
import tempfile
import zipfile
from pathlib import Path
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
logger.info(
f"[BATCH-UPLOAD] Starting batch upload: tenant={tenant_id}, "
f"kb_id={kb_id}, filename={file.filename}"
)
if not file.filename or not file.filename.lower().endswith('.zip'):
return JSONResponse(
status_code=400,
content={
"code": "INVALID_FORMAT",
"message": "Only .zip files are supported",
},
)
kb_service = KnowledgeBaseService(session)
kb = await kb_service.get_knowledge_base(tenant_id, kb_id)
if not kb:
return JSONResponse(
status_code=404,
content={
"code": "KB_NOT_FOUND",
"message": f"Knowledge base {kb_id} not found",
},
)
file_content = await file.read()
results = []
succeeded = 0
failed = 0
with tempfile.TemporaryDirectory() as temp_dir:
zip_path = Path(temp_dir) / "upload.zip"
with open(zip_path, "wb") as f:
f.write(file_content)
try:
with zipfile.ZipFile(zip_path, 'r') as zf:
zf.extractall(temp_dir)
except zipfile.BadZipFile as e:
return JSONResponse(
status_code=400,
content={
"code": "INVALID_ZIP",
"message": f"Invalid zip file: {str(e)}",
},
)
extracted_path = Path(temp_dir)
# 列出解压后的所有内容,用于调试
all_items = list(extracted_path.iterdir())
logger.info(f"[BATCH-UPLOAD] Extracted items: {[item.name for item in all_items]}")
# 递归查找所有包含 content.txt/md 和 metadata.json 的文件夹
def find_document_folders(path: Path) -> list[Path]:
"""递归查找所有包含文档文件的文件夹"""
doc_folders = []
# 检查当前文件夹是否包含文档文件
content_files = (
list(path.glob("*.md")) +
list(path.glob("*.markdown")) +
list(path.glob("*.txt"))
)
if content_files:
# 这个文件夹包含文档文件,是一个文档文件夹
doc_folders.append(path)
logger.info(f"[BATCH-UPLOAD] Found document folder: {path.name}, files: {[f.name for f in content_files]}")
# 递归检查子文件夹
for subfolder in [p for p in path.iterdir() if p.is_dir()]:
doc_folders.extend(find_document_folders(subfolder))
return doc_folders
folders = find_document_folders(extracted_path)
if not folders:
logger.error(f"[BATCH-UPLOAD] No document folders found in zip. Items found: {[item.name for item in all_items]}")
return JSONResponse(
status_code=400,
content={
"code": "NO_DOCUMENTS_FOUND",
"message": "压缩包中没有找到包含 .txt/.md 文件的文件夹",
"details": {
"expected_structure": "每个文件夹应包含 content.txt (或 .md) 和 metadata.json",
"found_items": [item.name for item in all_items],
},
},
)
logger.info(f"[BATCH-UPLOAD] Found {len(folders)} document folders")
for folder in folders:
folder_name = folder.name if folder != extracted_path else "root"
content_files = (
list(folder.glob("*.md")) +
list(folder.glob("*.markdown")) +
list(folder.glob("*.txt"))
)
if not content_files:
# 这种情况不应该发生,因为我们已经过滤过了
failed += 1
results.append({
"folder": folder_name,
"status": "failed",
"error": "No content file found",
})
continue
content_file = content_files[0]
metadata_file = folder / "metadata.json"
metadata_dict = {}
if metadata_file.exists():
try:
with open(metadata_file, 'r', encoding='utf-8') as f:
metadata_dict = json.load(f)
except json.JSONDecodeError as e:
failed += 1
results.append({
"folder": folder_name,
"status": "failed",
"error": f"Invalid metadata.json: {str(e)}",
})
continue
else:
logger.warning(f"[BATCH-UPLOAD] No metadata.json in folder {folder_name}, using empty metadata")
field_def_service = MetadataFieldDefinitionService(session)
is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
tenant_id, metadata_dict, "kb_document"
)
if not is_valid:
failed += 1
results.append({
"folder": folder_name,
"status": "failed",
"error": f"Metadata validation failed: {validation_errors}",
})
continue
try:
with open(content_file, 'rb') as f:
doc_content = f.read()
file_ext = content_file.suffix.lower()
if file_ext == '.txt':
file_type = "text/plain"
else:
file_type = "text/markdown"
doc_kb_service = KBService(session)
document, job = await doc_kb_service.upload_document(
tenant_id=tenant_id,
kb_id=kb_id,
file_name=content_file.name,
file_content=doc_content,
file_type=file_type,
metadata=metadata_dict,
)
await kb_service.update_doc_count(tenant_id, kb_id, delta=1)
await session.commit()
background_tasks.add_task(
_index_document,
tenant_id,
kb_id,
str(job.id),
str(document.id),
doc_content,
content_file.name,
metadata_dict,
)
succeeded += 1
results.append({
"folder": folder_name,
"docId": str(document.id),
"jobId": str(job.id),
"status": "created",
"fileName": content_file.name,
})
logger.info(
f"[BATCH-UPLOAD] Created document: folder={folder_name}, "
f"doc_id={document.id}, job_id={job.id}"
)
except Exception as e:
failed += 1
results.append({
"folder": folder_name,
"status": "failed",
"error": str(e),
})
logger.error(f"[BATCH-UPLOAD] Failed to create document: folder={folder_name}, error={e}")
logger.info(
f"[BATCH-UPLOAD] Completed: total={len(results)}, succeeded={succeeded}, failed={failed}"
)
return JSONResponse(
content={
"success": True,
"total": len(results),
"succeeded": succeeded,
"failed": failed,
"results": results,
}
)
@router.post(
"/{kb_id}/documents/json-batch",
summary="[AC-KB-03] JSON批量上传文档",
description="上传JSONL格式文件每行一个JSON对象包含text和元数据字段",
)
async def upload_json_batch(
kb_id: str,
tenant_id: str = Query(..., description="租户ID"),
file: UploadFile = File(..., description="JSONL格式文件每行一个JSON对象"),
session: AsyncSession = Depends(get_session),
background_tasks: BackgroundTasks = None,
):
"""
JSON批量上传文档
文件格式JSONL (每行一个JSON对象)
必填字段text - 需要录入知识库的文本内容
可选字段元数据字段如grade, subject, kb_scene等
示例
{"text": "课程内容...", "grade": "初二", "subject": "数学", "kb_scene": "课程咨询"}
{"text": "另一条课程内容...", "grade": "初三", "info_type": "课程概述"}
"""
kb = await session.get(KnowledgeBase, kb_id)
if not kb:
raise HTTPException(status_code=404, detail="知识库不存在")
if kb.tenant_id != tenant_id:
raise HTTPException(status_code=403, detail="无权访问此知识库")
valid_field_keys = set()
try:
field_defs = await MetadataFieldDefinitionService(session).get_fields(
tenant_id=tenant_id,
include_inactive=False,
)
valid_field_keys = {f.field_key for f in field_defs}
logger.info(f"[AC-KB-03] Valid metadata fields for tenant {tenant_id}: {valid_field_keys}")
except Exception as e:
logger.warning(f"[AC-KB-03] Failed to get metadata fields: {e}")
content = await file.read()
try:
text_content = content.decode("utf-8")
except UnicodeDecodeError:
try:
text_content = content.decode("gbk")
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="文件编码不支持请使用UTF-8编码")
lines = text_content.strip().split("\n")
if not lines:
raise HTTPException(status_code=400, detail="文件内容为空")
results = []
succeeded = 0
failed = 0
kb_service = KBService(session)
for line_num, line in enumerate(lines, 1):
line = line.strip()
if not line:
continue
try:
json_obj = json.loads(line)
except json.JSONDecodeError as e:
failed += 1
results.append({
"line": line_num,
"success": False,
"error": f"JSON解析失败: {e}",
})
continue
text = json_obj.get("text")
if not text:
failed += 1
results.append({
"line": line_num,
"success": False,
"error": "缺少必填字段: text",
})
continue
metadata = {}
for key, value in json_obj.items():
if key == "text":
continue
if valid_field_keys and key not in valid_field_keys:
logger.debug(f"[AC-KB-03] Skipping invalid metadata field: {key}")
continue
if value is not None:
metadata[key] = value
try:
file_name = f"json_batch_line_{line_num}.txt"
file_content = text.encode("utf-8")
document, job = await kb_service.upload_document(
tenant_id=tenant_id,
kb_id=kb_id,
file_name=file_name,
file_content=file_content,
file_type="text/plain",
metadata=metadata,
)
if background_tasks:
background_tasks.add_task(
_index_document,
tenant_id,
kb_id,
str(job.id),
str(document.id),
file_content,
file_name,
metadata,
)
succeeded += 1
results.append({
"line": line_num,
"success": True,
"doc_id": str(document.id),
"job_id": str(job.id),
"metadata": metadata,
})
except Exception as e:
failed += 1
results.append({
"line": line_num,
"success": False,
"error": str(e),
})
logger.error(f"[AC-KB-03] Failed to upload document at line {line_num}: {e}")
await session.commit()
logger.info(f"[AC-KB-03] JSON batch upload completed: kb_id={kb_id}, total={len(lines)}, succeeded={succeeded}, failed={failed}")
return JSONResponse(
content={
"success": True,
"total": len(lines),
"succeeded": succeeded,
"failed": failed,
"valid_metadata_fields": list(valid_field_keys) if valid_field_keys else [],
"results": results,
}
)

View File

@ -51,6 +51,7 @@ def _field_to_dict(f: MetadataFieldDefinition) -> dict[str, Any]:
"scope": f.scope, "scope": f.scope,
"is_filterable": f.is_filterable, "is_filterable": f.is_filterable,
"is_rank_feature": f.is_rank_feature, "is_rank_feature": f.is_rank_feature,
"usage_description": f.usage_description,
"field_roles": f.field_roles or [], "field_roles": f.field_roles or [],
"status": f.status, "status": f.status,
"version": f.version, "version": f.version,

View File

@ -407,6 +407,7 @@ async def get_conversation_detail(
"guardrailTriggered": user_msg.guardrail_triggered, "guardrailTriggered": user_msg.guardrail_triggered,
"guardrailWords": user_msg.guardrail_words, "guardrailWords": user_msg.guardrail_words,
"executionSteps": execution_steps, "executionSteps": execution_steps,
"routeTrace": user_msg.route_trace,
"createdAt": user_msg.created_at.isoformat(), "createdAt": user_msg.created_at.isoformat(),
} }
@ -659,8 +660,56 @@ async def _process_export(
except Exception as e: except Exception as e:
logger.error(f"[AC-AISVC-110] Export failed: task_id={task_id}, error={e}") logger.error(f"[AC-AISVC-110] Export failed: task_id={task_id}, error={e}")
task = await session.get(ExportTask, task_id) task = task_status.get(ExportTask, task_id)
if task: if task:
task.status = ExportTaskStatus.FAILED.value task.status = ExportTaskStatus.FAILED.value
task.error_message = str(e) task.error_message = str(e)
await session.commit() await session.commit()
@router.get("/clarification-metrics")
async def get_clarification_metrics(
tenant_id: str = Depends(get_tenant_id),
total_requests: int = Query(100, ge=1, description="Total requests for rate calculation"),
) -> dict[str, Any]:
"""
[AC-CLARIFY] Get clarification metrics.
Returns:
- clarify_trigger_rate: 澄清触发率
- clarify_converge_rate: 澄清后收敛率
- misroute_rate: 误入流程率
"""
from app.services.intent.clarification import get_clarify_metrics
metrics = get_clarify_metrics()
counts = metrics.get_metrics()
rates = metrics.get_rates(total_requests)
return {
"counts": counts,
"rates": rates,
"thresholds": {
"t_high": 0.75,
"t_low": 0.45,
"max_retry": 3,
},
}
@router.post("/clarification-metrics/reset")
async def reset_clarification_metrics(
tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
"""
[AC-CLARIFY] Reset clarification metrics.
"""
from app.services.intent.clarification import get_clarify_metrics
metrics = get_clarify_metrics()
metrics.reset()
return {
"status": "reset",
"message": "Clarification metrics have been reset.",
}

View File

@ -64,6 +64,7 @@ class Settings(BaseSettings):
redis_enabled: bool = True redis_enabled: bool = True
dashboard_cache_ttl: int = 60 dashboard_cache_ttl: int = 60
stats_counter_ttl: int = 7776000 stats_counter_ttl: int = 7776000
slot_state_cache_ttl: int = 1800
frontend_base_url: str = "http://localhost:3000" frontend_base_url: str = "http://localhost:3000"

View File

@ -20,7 +20,7 @@ engine = create_async_engine(
settings.database_url, settings.database_url,
pool_size=settings.database_pool_size, pool_size=settings.database_pool_size,
max_overflow=settings.database_max_overflow, max_overflow=settings.database_max_overflow,
echo=settings.debug, echo=False,
pool_pre_ping=True, pool_pre_ping=True,
) )

View File

@ -114,7 +114,12 @@ class ApiKeyMiddleware(BaseHTTPMiddleware):
from app.core.database import async_session_maker from app.core.database import async_session_maker
async with async_session_maker() as session: async with async_session_maker() as session:
await service.initialize(session) await service.initialize(session)
logger.info(f"[AC-AISVC-50] API key service lazy initialized with {len(service._keys_cache)} keys") if service._initialized and len(service._keys_cache) > 0:
logger.info(f"[AC-AISVC-50] API key service lazy initialized with {len(service._keys_cache)} keys")
elif service._initialized and len(service._keys_cache) == 0:
logger.warning("[AC-AISVC-50] API key service initialized but no keys found in database")
else:
logger.error("[AC-AISVC-50] API key service lazy initialization failed")
except Exception as e: except Exception as e:
logger.error(f"[AC-AISVC-50] Failed to initialize API key service: {e}") logger.error(f"[AC-AISVC-50] Failed to initialize API key service: {e}")

View File

@ -272,20 +272,24 @@ class QdrantClient:
score_threshold: float | None = None, score_threshold: float | None = None,
vector_name: str = "full", vector_name: str = "full",
with_vectors: bool = False, with_vectors: bool = False,
metadata_filter: dict[str, Any] | None = None,
kb_ids: list[str] | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
[AC-AISVC-10] Search vectors in tenant's collection. [AC-AISVC-10] Search vectors in tenant's collections.
Returns results with score >= score_threshold if specified. Returns results with score >= score_threshold if specified.
Searches both old format (with @) and new format (with _) for backward compatibility. Searches all collections for the tenant (multi-KB support).
Args: Args:
tenant_id: Tenant identifier tenant_id: Tenant identifier
query_vector: Query vector for similarity search query_vector: Query vector for similarity search
limit: Maximum number of results limit: Maximum number of results per collection
score_threshold: Minimum score threshold for results score_threshold: Minimum score threshold for results
vector_name: Name of the vector to search (for multi-vector collections) vector_name: Name of the vector to search (for multi-vector collections)
Default is "full" for 768-dim vectors in Matryoshka setup. Default is "full" for 768-dim vectors in Matryoshka setup.
with_vectors: Whether to return vectors in results (for two-stage reranking) with_vectors: Whether to return vectors in results (for two-stage reranking)
metadata_filter: Optional metadata filter to apply during search
kb_ids: Optional list of knowledge base IDs to restrict search to specific KBs
""" """
client = await self.get_client() client = await self.get_client()
@ -293,21 +297,36 @@ class QdrantClient:
f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, " f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, "
f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}" f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}"
) )
if metadata_filter:
logger.info(f"[AC-AISVC-10] Metadata filter: {metadata_filter}")
collection_names = [self.get_collection_name(tenant_id)] # 构建 Qdrant filter
if '@' in tenant_id: qdrant_filter = None
old_format = f"{self._collection_prefix}{tenant_id}" if metadata_filter:
new_format = f"{self._collection_prefix}{tenant_id.replace('@', '_')}" qdrant_filter = self._build_qdrant_filter(metadata_filter)
collection_names = [new_format, old_format] logger.info(f"[AC-AISVC-10] Qdrant filter: {qdrant_filter}")
logger.info(f"[AC-AISVC-10] Will search in collections: {collection_names}") # 获取该租户的所有 collections
collection_names = await self._get_tenant_collections(client, tenant_id)
# 如果指定了 kb_ids则只搜索指定的知识库 collections
if kb_ids:
target_collections = []
for kb_id in kb_ids:
kb_collection_name = self.get_kb_collection_name(tenant_id, kb_id)
if kb_collection_name in collection_names:
target_collections.append(kb_collection_name)
else:
logger.warning(f"[AC-AISVC-10] KB collection not found: {kb_collection_name} for kb_id={kb_id}")
collection_names = target_collections
logger.info(f"[AC-AISVC-10] Restricted to {len(collection_names)} KB collections: {collection_names}")
else:
logger.info(f"[AC-AISVC-10] Will search in {len(collection_names)} collections: {collection_names}")
all_hits = [] all_hits = []
for collection_name in collection_names: for collection_name in collection_names:
try: try:
logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}")
exists = await client.collection_exists(collection_name) exists = await client.collection_exists(collection_name)
if not exists: if not exists:
logger.warning(f"[AC-AISVC-10] Collection {collection_name} does not exist") logger.warning(f"[AC-AISVC-10] Collection {collection_name} does not exist")
@ -321,6 +340,7 @@ class QdrantClient:
limit=limit, limit=limit,
with_vectors=with_vectors, with_vectors=with_vectors,
score_threshold=score_threshold, score_threshold=score_threshold,
query_filter=qdrant_filter,
) )
except Exception as e: except Exception as e:
if "vector name" in str(e).lower() or "Not existing vector" in str(e) or "using" in str(e).lower(): if "vector name" in str(e).lower() or "Not existing vector" in str(e) or "using" in str(e).lower():
@ -334,6 +354,7 @@ class QdrantClient:
limit=limit, limit=limit,
with_vectors=with_vectors, with_vectors=with_vectors,
score_threshold=score_threshold, score_threshold=score_threshold,
query_filter=qdrant_filter,
) )
else: else:
raise raise
@ -348,6 +369,7 @@ class QdrantClient:
"id": str(result.id), "id": str(result.id),
"score": result.score, "score": result.score,
"payload": result.payload or {}, "payload": result.payload or {},
"collection": collection_name, # 添加 collection 信息
} }
if with_vectors and result.vector: if with_vectors and result.vector:
hit["vector"] = result.vector hit["vector"] = result.vector
@ -358,10 +380,6 @@ class QdrantClient:
logger.info( logger.info(
f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}" f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}"
) )
for i, h in enumerate(hits[:3]):
logger.debug(
f"[AC-AISVC-10] Hit {i+1}: id={h['id']}, score={h['score']:.4f}"
)
else: else:
logger.warning( logger.warning(
f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)" f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)"
@ -370,9 +388,10 @@ class QdrantClient:
logger.warning( logger.warning(
f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}" f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}"
) )
continue
all_hits = sorted(all_hits, key=lambda x: x["score"], reverse=True)[:limit] # 按分数排序并返回 top results
all_hits.sort(key=lambda x: x["score"], reverse=True)
all_hits = all_hits[:limit]
logger.info( logger.info(
f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}" f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}"
@ -386,6 +405,113 @@ class QdrantClient:
return all_hits return all_hits
async def _get_tenant_collections(
    self,
    client: AsyncQdrantClient,
    tenant_id: str,
) -> list[str]:
    """
    Resolve all collections that belong to a tenant.

    Tries the Redis cache first; on a miss, lists collections from Qdrant,
    filters them by the tenant prefix, and caches the result with a
    5-minute TTL.

    Args:
        client: Async Qdrant client.
        tenant_id: Tenant identifier.

    Returns:
        Sorted list of collection names. Falls back to the single default
        collection name when Qdrant cannot be queried.
    """
    # Hoisted: the original body imported json twice and time once, all inline.
    import json
    import time

    start_time = time.time()

    # 1. Try the Redis cache first.
    from app.services.metadata_cache_service import get_metadata_cache_service
    cache_service = await get_metadata_cache_service()
    cache_key = f"collections:{tenant_id}"
    redis_client = None
    try:
        # Ensure the Redis connection is initialized.
        # NOTE(review): reaches into cache_service private members
        # (_get_redis/_enabled) — consider a public accessor.
        redis_client = await cache_service._get_redis()
        if redis_client and cache_service._enabled:
            cached = await redis_client.get(cache_key)
            if cached:
                collections = json.loads(cached)
                logger.info(
                    f"[AC-AISVC-10] Cache hit: Found {len(collections)} collections "
                    f"for tenant={tenant_id} in {(time.time() - start_time)*1000:.2f}ms"
                )
                return collections
    except Exception as e:
        logger.warning(f"[AC-AISVC-10] Cache get error: {e}")

    # 2. Query Qdrant. '@' (e.g. from email-like tenant ids) is sanitized
    # the same way the collection names are built.
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = f"{self._collection_prefix}{safe_tenant_id}"
    try:
        collections = await client.get_collections()
        tenant_collections = [
            c.name for c in collections.collections
            if c.name.startswith(prefix)
        ]
        # Sort by name for a stable, deterministic order.
        tenant_collections.sort()
        db_time = (time.time() - start_time) * 1000
        logger.info(
            f"[AC-AISVC-10] Found {len(tenant_collections)} collections from Qdrant "
            f"for tenant={tenant_id} in {db_time:.2f}ms: {tenant_collections}"
        )
        # 3. Cache the result (5-minute TTL). Reuse the client obtained in
        # step 1 when available instead of fetching it a second time.
        try:
            if redis_client is None:
                redis_client = await cache_service._get_redis()
            if redis_client and cache_service._enabled:
                await redis_client.setex(
                    cache_key,
                    300,  # 5 minutes
                    json.dumps(tenant_collections)
                )
                logger.info(f"[AC-AISVC-10] Cached collections for tenant={tenant_id}")
        except Exception as e:
            logger.warning(f"[AC-AISVC-10] Cache set error: {e}")
        return tenant_collections
    except Exception as e:
        logger.error(f"[AC-AISVC-10] Failed to get collections for tenant={tenant_id}: {e}")
        # Degrade gracefully to the tenant's default collection.
        return [self.get_collection_name(tenant_id)]
def _build_qdrant_filter(
    self,
    metadata_filter: dict[str, Any],
) -> Any:
    """
    Build a Qdrant payload filter from a flat metadata dict.

    Each key/value pair becomes an exact-match condition on the nested
    ``metadata.<key>`` payload field; all conditions are AND-ed together.

    Args:
        metadata_filter: Filter criteria, e.g. {"grade": "三年级", "subject": "语文"}.

    Returns:
        A Qdrant ``Filter`` object, or None when no conditions were built.
    """
    from qdrant_client.models import FieldCondition, Filter, MatchValue

    conditions = [
        # Nested metadata fields are addressed as "metadata.<key>".
        FieldCondition(key=f"metadata.{field}", match=MatchValue(value=value))
        for field, value in metadata_filter.items()
    ]
    if not conditions:
        return None
    return Filter(must=conditions)
async def delete_collection(self, tenant_id: str) -> bool: async def delete_collection(self, tenant_id: str) -> bool:
""" """
[AC-AISVC-10] Delete tenant's collection. [AC-AISVC-10] Delete tenant's collection.

View File

@ -29,6 +29,7 @@ from app.api.admin import (
monitoring_router, monitoring_router,
prompt_templates_router, prompt_templates_router,
rag_router, rag_router,
scene_slot_bundle_router,
script_flows_router, script_flows_router,
sessions_router, sessions_router,
slot_definition_router, slot_definition_router,
@ -55,6 +56,11 @@ logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
) )
# Silence noisy SQLAlchemy subsystem loggers; only WARNING and above get through.
for _noisy_logger in (
    "sqlalchemy.engine",
    "sqlalchemy.pool",
    "sqlalchemy.dialects",
    "sqlalchemy.orm",
):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -88,6 +94,28 @@ async def lifespan(app: FastAPI):
except Exception as e: except Exception as e:
logger.error(f"[AC-AISVC-50] API key initialization FAILED: {e}", exc_info=True) logger.error(f"[AC-AISVC-50] API key initialization FAILED: {e}", exc_info=True)
try:
from app.services.mid.tool_guide_registry import init_tool_guide_registry
logger.info("[ToolGuideRegistry] Starting tool guides initialization...")
tool_guide_registry = init_tool_guide_registry()
logger.info(f"[ToolGuideRegistry] Tool guides loaded: {tool_guide_registry.list_tools()}")
except Exception as e:
logger.error(f"[ToolRegistry] Tools initialization FAILED: {e}", exc_info=True)
# [AC-AISVC-29] 预初始化 Embedding 服务,避免首次查询时的延迟
try:
from app.services.embedding import get_embedding_provider
logger.info("[AC-AISVC-29] Pre-initializing embedding service...")
embedding_provider = await get_embedding_provider()
logger.info(
f"[AC-AISVC-29] Embedding service pre-initialized: "
f"provider={embedding_provider.PROVIDER_NAME}"
)
except Exception as e:
logger.error(f"[AC-AISVC-29] Embedding service pre-initialization FAILED: {e}", exc_info=True)
yield yield
await close_db() await close_db()
@ -171,6 +199,7 @@ app.include_router(metadata_schema_router)
app.include_router(monitoring_router) app.include_router(monitoring_router)
app.include_router(prompt_templates_router) app.include_router(prompt_templates_router)
app.include_router(rag_router) app.include_router(rag_router)
app.include_router(scene_slot_bundle_router)
app.include_router(script_flows_router) app.include_router(script_flows_router)
app.include_router(sessions_router) app.include_router(sessions_router)
app.include_router(slot_definition_router) app.include_router(slot_definition_router)

View File

@ -41,6 +41,7 @@ class ChatMessage(SQLModel, table=True):
[AC-AISVC-13] Chat message entity with tenant isolation. [AC-AISVC-13] Chat message entity with tenant isolation.
Messages are scoped by (tenant_id, session_id) for multi-tenant security. Messages are scoped by (tenant_id, session_id) for multi-tenant security.
[v0.7.0] Extended with monitoring fields for Dashboard statistics. [v0.7.0] Extended with monitoring fields for Dashboard statistics.
[v0.8.0] Extended with route_trace for hybrid routing observability.
""" """
__tablename__ = "chat_messages" __tablename__ = "chat_messages"
@ -90,6 +91,11 @@ class ChatMessage(SQLModel, table=True):
sa_column=Column("guardrail_words", JSON, nullable=True), sa_column=Column("guardrail_words", JSON, nullable=True),
description="[v0.7.0] Guardrail trigger details: words, categories, strategy" description="[v0.7.0] Guardrail trigger details: words, categories, strategy"
) )
route_trace: dict[str, Any] | None = Field(
default=None,
sa_column=Column("route_trace", JSON, nullable=True),
description="[v0.8.0] Intent routing trace log for hybrid routing observability"
)
class ChatSessionCreate(SQLModel): class ChatSessionCreate(SQLModel):
@ -227,6 +233,7 @@ class Document(SQLModel, table=True):
file_type: str | None = Field(default=None, description="File MIME type") file_type: str | None = Field(default=None, description="File MIME type")
status: str = Field(default=DocumentStatus.PENDING.value, description="Document status") status: str = Field(default=DocumentStatus.PENDING.value, description="Document status")
error_msg: str | None = Field(default=None, description="Error message if failed") error_msg: str | None = Field(default=None, description="Error message if failed")
doc_metadata: dict | None = Field(default=None, sa_type=JSON, description="Document metadata as JSON")
created_at: datetime = Field(default_factory=datetime.utcnow, description="Upload time") created_at: datetime = Field(default_factory=datetime.utcnow, description="Upload time")
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time") updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
@ -421,6 +428,7 @@ class IntentRule(SQLModel, table=True):
[AC-AISVC-65] Intent rule entity with tenant isolation. [AC-AISVC-65] Intent rule entity with tenant isolation.
Supports keyword and regex matching for intent recognition. Supports keyword and regex matching for intent recognition.
[AC-IDSMETA-16] Extended with metadata field for unified storage structure. [AC-IDSMETA-16] Extended with metadata field for unified storage structure.
[v0.8.0] Extended with intent_vector and semantic_examples for hybrid routing.
""" """
__tablename__ = "intent_rules" __tablename__ = "intent_rules"
@ -458,6 +466,16 @@ class IntentRule(SQLModel, table=True):
sa_column=Column("metadata", JSON, nullable=True), sa_column=Column("metadata", JSON, nullable=True),
description="[AC-IDSMETA-16] Structured metadata for the intent rule" description="[AC-IDSMETA-16] Structured metadata for the intent rule"
) )
intent_vector: list[float] | None = Field(
default=None,
sa_column=Column("intent_vector", JSON, nullable=True),
description="[v0.8.0] Pre-computed intent vector for semantic matching"
)
semantic_examples: list[str] | None = Field(
default=None,
sa_column=Column("semantic_examples", JSON, nullable=True),
description="[v0.8.0] Semantic example sentences for dynamic vector computation"
)
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time") created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time") updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
@ -475,6 +493,8 @@ class IntentRuleCreate(SQLModel):
fixed_reply: str | None = None fixed_reply: str | None = None
transfer_message: str | None = None transfer_message: str | None = None
metadata_: dict[str, Any] | None = None metadata_: dict[str, Any] | None = None
intent_vector: list[float] | None = None
semantic_examples: list[str] | None = None
class IntentRuleUpdate(SQLModel): class IntentRuleUpdate(SQLModel):
@ -491,6 +511,8 @@ class IntentRuleUpdate(SQLModel):
transfer_message: str | None = None transfer_message: str | None = None
is_enabled: bool | None = None is_enabled: bool | None = None
metadata_: dict[str, Any] | None = None metadata_: dict[str, Any] | None = None
intent_vector: list[float] | None = None
semantic_examples: list[str] | None = None
class IntentMatchResult: class IntentMatchResult:
@ -810,6 +832,24 @@ class FlowStep(SQLModel):
default=None, default=None,
description="RAG configuration for this step: {'enabled': true, 'tag_filter': {'grade': '${context.grade}', 'type': '痛点'}}" description="RAG configuration for this step: {'enabled': true, 'tag_filter': {'grade': '${context.grade}', 'type': '痛点'}}"
) )
allowed_kb_ids: list[str] | None = Field(
default=None,
description="[Step-KB-Binding] Allowed knowledge base IDs for this step. If set, KB search will be restricted to these KBs."
)
preferred_kb_ids: list[str] | None = Field(
default=None,
description="[Step-KB-Binding] Preferred knowledge base IDs for this step. These KBs will be searched first."
)
kb_query_hint: str | None = Field(
default=None,
description="[Step-KB-Binding] Query hint for KB search in this step, helps improve retrieval accuracy."
)
max_kb_calls_per_step: int | None = Field(
default=None,
ge=1,
le=5,
description="[Step-KB-Binding] Max KB calls allowed per step. Default is 1 if not set."
)
class ScriptFlowCreate(SQLModel): class ScriptFlowCreate(SQLModel):
@ -1078,6 +1118,7 @@ class MetadataFieldDefinition(SQLModel, table=True):
) )
is_filterable: bool = Field(default=True, description="是否可用于过滤") is_filterable: bool = Field(default=True, description="是否可用于过滤")
is_rank_feature: bool = Field(default=False, description="是否用于排序特征") is_rank_feature: bool = Field(default=False, description="是否用于排序特征")
usage_description: str | None = Field(default=None, description="用途说明")
field_roles: list[str] = Field( field_roles: list[str] = Field(
default_factory=list, default_factory=list,
sa_column=Column("field_roles", JSON, nullable=False, server_default="'[]'"), sa_column=Column("field_roles", JSON, nullable=False, server_default="'[]'"),
@ -1104,6 +1145,7 @@ class MetadataFieldDefinitionCreate(SQLModel):
scope: list[str] = Field(default_factory=lambda: [MetadataScope.KB_DOCUMENT.value]) scope: list[str] = Field(default_factory=lambda: [MetadataScope.KB_DOCUMENT.value])
is_filterable: bool = Field(default=True) is_filterable: bool = Field(default=True)
is_rank_feature: bool = Field(default=False) is_rank_feature: bool = Field(default=False)
usage_description: str | None = None
field_roles: list[str] = Field(default_factory=list) field_roles: list[str] = Field(default_factory=list)
status: str = Field(default=MetadataFieldStatus.DRAFT.value) status: str = Field(default=MetadataFieldStatus.DRAFT.value)
@ -1118,6 +1160,7 @@ class MetadataFieldDefinitionUpdate(SQLModel):
scope: list[str] | None = None scope: list[str] | None = None
is_filterable: bool | None = None is_filterable: bool | None = None
is_rank_feature: bool | None = None is_rank_feature: bool | None = None
usage_description: str | None = None
field_roles: list[str] | None = None field_roles: list[str] | None = None
status: str | None = None status: str | None = None
@ -1131,6 +1174,17 @@ class ExtractStrategy(str, Enum):
USER_INPUT = "user_input" USER_INPUT = "user_input"
class ExtractFailureType(str, Enum):
    """
    [AC-MRS-07-UPGRADE] Extraction failure categories.

    Unified failure classification used for tracing and logging.
    """
    EXTRACT_EMPTY = "EXTRACT_EMPTY"                      # extraction produced nothing
    EXTRACT_PARSE_FAIL = "EXTRACT_PARSE_FAIL"            # result could not be parsed
    EXTRACT_VALIDATION_FAIL = "EXTRACT_VALIDATION_FAIL"  # result failed validation
    EXTRACT_RUNTIME_ERROR = "EXTRACT_RUNTIME_ERROR"      # unexpected runtime error
class SlotValueSource(str, Enum): class SlotValueSource(str, Enum):
""" """
[AC-MRS-09] 槽位值来源 [AC-MRS-09] 槽位值来源
@ -1145,6 +1199,7 @@ class SlotDefinition(SQLModel, table=True):
""" """
[AC-MRS-07,08] 槽位定义表 [AC-MRS-07,08] 槽位定义表
独立的槽位定义模型与元数据字段解耦但可复用 独立的槽位定义模型与元数据字段解耦但可复用
[AC-MRS-07-UPGRADE] 支持提取策略链 extract_strategies
""" """
__tablename__ = "slot_definitions" __tablename__ = "slot_definitions"
@ -1162,14 +1217,31 @@ class SlotDefinition(SQLModel, table=True):
min_length=1, min_length=1,
max_length=100, max_length=100,
) )
display_name: str | None = Field(
default=None,
description="槽位名称,给运营/教研看的中文名grade -> '当前年级'",
max_length=100,
)
description: str | None = Field(
default=None,
description="槽位说明,解释这个槽位采集什么、用于哪里",
max_length=500,
)
type: str = Field( type: str = Field(
default=MetadataFieldType.STRING.value, default=MetadataFieldType.STRING.value,
description="槽位类型: string/number/boolean/enum/array_enum" description="槽位类型: string/number/boolean/enum/array_enum"
) )
required: bool = Field(default=False, description="是否必填槽位") required: bool = Field(default=False, description="是否必填槽位")
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容读取
extract_strategy: str | None = Field( extract_strategy: str | None = Field(
default=None, default=None,
description="提取策略: rule/llm/user_input" description="[兼容字段] 提取策略: rule/llm/user_input已废弃请使用 extract_strategies"
)
# [AC-MRS-07-UPGRADE] 新增策略链字段
extract_strategies: list[str] | None = Field(
default=None,
sa_column=Column("extract_strategies", JSON, nullable=True),
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input按顺序执行直到成功"
) )
validation_rule: str | None = Field( validation_rule: str | None = Field(
default=None, default=None,
@ -1192,14 +1264,72 @@ class SlotDefinition(SQLModel, table=True):
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间") created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间")
updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间") updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间")
def get_effective_strategies(self) -> list[str]:
    """
    [AC-MRS-07-UPGRADE] Return the effective extraction strategy chain.

    Prefers the new ``extract_strategies`` list; falls back to the legacy
    single ``extract_strategy`` field, and finally to an empty chain.
    """
    chain = self.extract_strategies
    if chain:
        return chain
    legacy = self.extract_strategy
    return [legacy] if legacy else []
def validate_strategies(self) -> tuple[bool, str]:
    """
    [AC-MRS-07-UPGRADE] Validate the extraction strategy chain.

    Checks that the chain has no duplicate entries and that every entry is
    one of the known strategies (rule/llm/user_input). An empty chain is
    considered valid (default extraction behavior applies).

    Returns:
        Tuple of (is_valid, error_message); error_message is "" when valid.
    """
    strategies = self.get_effective_strategies()
    # Empty chain is valid — the original also had an unreachable
    # `len(strategies) == 0` branch after this guard; it has been removed.
    if not strategies:
        return True, ""
    # Duplicate strategies are not allowed.
    if len(strategies) != len(set(strategies)):
        return False, "提取策略链中不允许重复的策略"
    # Every entry must be a known strategy.
    valid_strategies = {"rule", "llm", "user_input"}
    invalid = [s for s in strategies if s not in valid_strategies]
    if invalid:
        # sorted() keeps the error message deterministic (set order is not).
        return False, f"无效的提取策略: {invalid},有效值为: {sorted(valid_strategies)}"
    return True, ""
class SlotDefinitionCreate(SQLModel): class SlotDefinitionCreate(SQLModel):
"""[AC-MRS-07,08] 创建槽位定义""" """[AC-MRS-07,08] 创建槽位定义"""
slot_key: str = Field(..., min_length=1, max_length=100) slot_key: str = Field(..., min_length=1, max_length=100)
display_name: str | None = Field(
default=None,
description="槽位名称,给运营/教研看的中文名",
max_length=100,
)
description: str | None = Field(
default=None,
description="槽位说明,解释这个槽位采集什么、用于哪里",
max_length=500,
)
type: str = Field(default=MetadataFieldType.STRING.value) type: str = Field(default=MetadataFieldType.STRING.value)
required: bool = Field(default=False) required: bool = Field(default=False)
extract_strategy: str | None = None # [AC-MRS-07-UPGRADE] 支持策略链
extract_strategies: list[str] | None = Field(
default=None,
description="提取策略链:有序数组,元素为 rule/llm/user_input按顺序执行直到成功"
)
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
extract_strategy: str | None = Field(
default=None,
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
)
validation_rule: str | None = None validation_rule: str | None = None
ask_back_prompt: str | None = None ask_back_prompt: str | None = None
default_value: dict[str, Any] | None = None default_value: dict[str, Any] | None = None
@ -1209,9 +1339,28 @@ class SlotDefinitionCreate(SQLModel):
class SlotDefinitionUpdate(SQLModel): class SlotDefinitionUpdate(SQLModel):
"""[AC-MRS-07] 更新槽位定义""" """[AC-MRS-07] 更新槽位定义"""
display_name: str | None = Field(
default=None,
description="槽位名称,给运营/教研看的中文名",
max_length=100,
)
description: str | None = Field(
default=None,
description="槽位说明,解释这个槽位采集什么、用于哪里",
max_length=500,
)
type: str | None = None type: str | None = None
required: bool | None = None required: bool | None = None
extract_strategy: str | None = None # [AC-MRS-07-UPGRADE] 支持策略链
extract_strategies: list[str] | None = Field(
default=None,
description="提取策略链:有序数组,元素为 rule/llm/user_input按顺序执行直到成功"
)
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
extract_strategy: str | None = Field(
default=None,
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
)
validation_rule: str | None = None validation_rule: str | None = None
ask_back_prompt: str | None = None ask_back_prompt: str | None = None
default_value: dict[str, Any] | None = None default_value: dict[str, Any] | None = None
@ -1522,3 +1671,107 @@ class MidAuditLog(SQLModel, table=True):
high_risk_scenario: str | None = Field(default=None, description="触发的高风险场景") high_risk_scenario: str | None = Field(default=None, description="触发的高风险场景")
latency_ms: int | None = Field(default=None, description="总耗时(ms)") latency_ms: int | None = Field(default=None, description="总耗时(ms)")
created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间", index=True) created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间", index=True)
class SceneSlotBundleStatus(str, Enum):
    """[AC-SCENE-SLOT-01] Lifecycle status of a scene slot bundle."""
    DRAFT = "draft"            # being authored, not yet live
    ACTIVE = "active"          # in use by the runtime
    DEPRECATED = "deprecated"  # retired, kept for history
class SceneSlotBundle(SQLModel, table=True):
    """
    [AC-SCENE-SLOT-01] Scene-to-slot mapping configuration.

    Defines the set of slots each scene needs to collect.

    Three-layer relationship:
    - Layer 1: slot -> metadata (via linked_field_id)
    - Layer 2: scene -> slot_bundle (this model)
    - Layer 3: step.expected_variables -> slot_key (referenced by script-flow steps)
    """
    __tablename__ = "scene_slot_bundles"
    __table_args__ = (
        Index("ix_scene_slot_bundles_tenant", "tenant_id"),
        # One bundle per (tenant, scene): scene_key is unique within a tenant.
        Index("ix_scene_slot_bundles_tenant_scene", "tenant_id", "scene_key", unique=True),
        Index("ix_scene_slot_bundles_tenant_status", "tenant_id", "status"),
    )

    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
    scene_key: str = Field(
        ...,
        description="场景标识,如 'open_consult', 'refund_apply', 'course_recommend'",
        min_length=1,
        max_length=100,
    )
    scene_name: str = Field(
        ...,
        description="场景名称,如 '开放咨询', '退款申请', '课程推荐'",
        min_length=1,
        max_length=100,
    )
    description: str | None = Field(
        default=None,
        description="场景描述"
    )
    # Slot lists are stored as JSON arrays of slot_key strings.
    required_slots: list[str] = Field(
        default_factory=list,
        sa_column=Column("required_slots", JSON, nullable=False),
        description="必填槽位 slot_key 列表"
    )
    optional_slots: list[str] = Field(
        default_factory=list,
        sa_column=Column("optional_slots", JSON, nullable=False),
        description="可选槽位 slot_key 列表"
    )
    slot_priority: list[str] | None = Field(
        default=None,
        sa_column=Column("slot_priority", JSON, nullable=True),
        description="槽位采集优先级顺序slot_key 列表)"
    )
    # Ratio of filled required slots (0.0-1.0) at which the scene counts as complete.
    completion_threshold: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="完成阈值0.0-1.0),必填槽位填充比例达到此值视为完成"
    )
    ask_back_order: str = Field(
        default="priority",
        description="追问顺序策略: priority/required_first/parallel"
    )
    status: str = Field(
        default=SceneSlotBundleStatus.DRAFT.value,
        description="状态: draft/active/deprecated"
    )
    version: int = Field(default=1, description="版本号")
    created_at: datetime = Field(default_factory=datetime.utcnow, description="创建时间")
    updated_at: datetime = Field(default_factory=datetime.utcnow, description="更新时间")
class SceneSlotBundleCreate(SQLModel):
    """[AC-SCENE-SLOT-01] Payload for creating a scene slot bundle."""
    scene_key: str = Field(..., min_length=1, max_length=100)
    scene_name: str = Field(..., min_length=1, max_length=100)
    description: str | None = None
    required_slots: list[str] = Field(default_factory=list)
    optional_slots: list[str] = Field(default_factory=list)
    slot_priority: list[str] | None = None
    completion_threshold: float = Field(default=1.0, ge=0.0, le=1.0)
    ask_back_order: str = Field(default="priority")
    status: str = Field(default=SceneSlotBundleStatus.DRAFT.value)
class SceneSlotBundleUpdate(SQLModel):
    """[AC-SCENE-SLOT-01] Partial-update payload for a scene slot bundle; None fields are left unchanged."""
    scene_name: str | None = Field(default=None, min_length=1, max_length=100)
    description: str | None = None
    required_slots: list[str] | None = None
    optional_slots: list[str] | None = None
    slot_priority: list[str] | None = None
    completion_threshold: float | None = Field(default=None, ge=0.0, le=1.0)
    ask_back_order: str | None = None
    status: str | None = None

View File

@ -139,7 +139,16 @@ class SlotDefinitionResponse(BaseModel):
slot_key: str = Field(..., description="槽位键名") slot_key: str = Field(..., description="槽位键名")
type: str = Field(..., description="槽位类型") type: str = Field(..., description="槽位类型")
required: bool = Field(default=False, description="是否必填槽位") required: bool = Field(default=False, description="是否必填槽位")
extract_strategy: str | None = Field(default=None, description="提取策略") # [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
extract_strategy: str | None = Field(
default=None,
description="[兼容字段] 单提取策略,已废弃"
)
# [AC-MRS-07-UPGRADE] 新增策略链字段
extract_strategies: list[str] | None = Field(
default=None,
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input"
)
validation_rule: str | None = Field(default=None, description="校验规则") validation_rule: str | None = Field(default=None, description="校验规则")
ask_back_prompt: str | None = Field(default=None, description="追问提示语模板") ask_back_prompt: str | None = Field(default=None, description="追问提示语模板")
default_value: dict[str, Any] | None = Field(default=None, description="默认值") default_value: dict[str, Any] | None = Field(default=None, description="默认值")
@ -157,9 +166,15 @@ class SlotDefinitionCreateRequest(BaseModel):
slot_key: str = Field(..., min_length=1, max_length=100, description="槽位键名") slot_key: str = Field(..., min_length=1, max_length=100, description="槽位键名")
type: str = Field(default="string", description="槽位类型") type: str = Field(default="string", description="槽位类型")
required: bool = Field(default=False, description="是否必填槽位") required: bool = Field(default=False, description="是否必填槽位")
# [AC-MRS-07-UPGRADE] 支持策略链
extract_strategies: list[str] | None = Field(
default=None,
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input按顺序执行直到成功"
)
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
extract_strategy: str | None = Field( extract_strategy: str | None = Field(
default=None, default=None,
description="提取策略: rule/llm/user_input" description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
) )
validation_rule: str | None = Field(default=None, description="校验规则") validation_rule: str | None = Field(default=None, description="校验规则")
ask_back_prompt: str | None = Field(default=None, description="追问提示语模板") ask_back_prompt: str | None = Field(default=None, description="追问提示语模板")
@ -172,7 +187,16 @@ class SlotDefinitionUpdateRequest(BaseModel):
type: str | None = None type: str | None = None
required: bool | None = None required: bool | None = None
extract_strategy: str | None = None # [AC-MRS-07-UPGRADE] 支持策略链
extract_strategies: list[str] | None = Field(
default=None,
description="[AC-MRS-07-UPGRADE] 提取策略链:有序数组,元素为 rule/llm/user_input按顺序执行直到成功"
)
# [AC-MRS-07-UPGRADE] 保留旧字段用于兼容
extract_strategy: str | None = Field(
default=None,
description="[兼容字段] 单提取策略,已废弃,请使用 extract_strategies"
)
validation_rule: str | None = None validation_rule: str | None = None
ask_back_prompt: str | None = None ask_back_prompt: str | None = None
default_value: dict[str, Any] | None = None default_value: dict[str, Any] | None = None

View File

@ -81,7 +81,6 @@ class ApiKeyService:
return return
except Exception as e: except Exception as e:
logger.warning(f"[AC-AISVC-50] Full API key schema load failed, fallback to legacy columns: {e}") logger.warning(f"[AC-AISVC-50] Full API key schema load failed, fallback to legacy columns: {e}")
await session.rollback()
# Backward-compat fallback for environments without new columns # Backward-compat fallback for environments without new columns
try: try:

View File

@ -12,6 +12,9 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import redis
from app.core.config import get_settings
from app.services.embedding.base import EmbeddingException, EmbeddingProvider from app.services.embedding.base import EmbeddingException, EmbeddingProvider
from app.services.embedding.nomic_provider import NomicEmbeddingProvider from app.services.embedding.nomic_provider import NomicEmbeddingProvider
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
@ -20,6 +23,7 @@ from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
EMBEDDING_CONFIG_FILE = Path("config/embedding_config.json") EMBEDDING_CONFIG_FILE = Path("config/embedding_config.json")
EMBEDDING_CONFIG_REDIS_KEY = "ai_service:config:embedding"
class EmbeddingProviderFactory: class EmbeddingProviderFactory:
@ -170,8 +174,32 @@ class EmbeddingConfigManager:
self._config = self._default_config.copy() self._config = self._default_config.copy()
self._provider: EmbeddingProvider | None = None self._provider: EmbeddingProvider | None = None
self._settings = get_settings()
self._redis_client: redis.Redis | None = None
self._load_from_redis()
self._load_from_file() self._load_from_file()
def _load_from_redis(self) -> None:
"""Load configuration from Redis if exists."""
try:
if not self._settings.redis_enabled:
return
self._redis_client = redis.from_url(
self._settings.redis_url,
encoding="utf-8",
decode_responses=True,
)
saved_raw = self._redis_client.get(EMBEDDING_CONFIG_REDIS_KEY)
if not saved_raw:
return
saved = json.loads(saved_raw)
self._provider_name = saved.get("provider", self._default_provider)
self._config = saved.get("config", self._default_config.copy())
logger.info(f"Loaded embedding config from Redis: provider={self._provider_name}")
except Exception as e:
logger.warning(f"Failed to load embedding config from Redis: {e}")
def _load_from_file(self) -> None: def _load_from_file(self) -> None:
"""Load configuration from file if exists.""" """Load configuration from file if exists."""
try: try:
@ -184,6 +212,28 @@ class EmbeddingConfigManager:
except Exception as e: except Exception as e:
logger.warning(f"Failed to load embedding config from file: {e}") logger.warning(f"Failed to load embedding config from file: {e}")
def _save_to_redis(self) -> None:
"""Save configuration to Redis."""
try:
if not self._settings.redis_enabled:
return
if self._redis_client is None:
self._redis_client = redis.from_url(
self._settings.redis_url,
encoding="utf-8",
decode_responses=True,
)
self._redis_client.set(
EMBEDDING_CONFIG_REDIS_KEY,
json.dumps({
"provider": self._provider_name,
"config": self._config,
}, ensure_ascii=False),
)
logger.info(f"Saved embedding config to Redis: provider={self._provider_name}")
except Exception as e:
logger.warning(f"Failed to save embedding config to Redis: {e}")
def _save_to_file(self) -> None: def _save_to_file(self) -> None:
"""Save configuration to file.""" """Save configuration to file."""
try: try:
@ -262,6 +312,7 @@ class EmbeddingConfigManager:
self._config = config self._config = config
self._provider = new_provider_instance self._provider = new_provider_instance
self._save_to_redis()
self._save_to_file() self._save_to_file()
logger.info(f"Updated embedding config: provider={provider}") logger.info(f"Updated embedding config: provider={provider}")

View File

@ -322,7 +322,7 @@ class FlowEngine:
stmt = select(FlowInstance).where( stmt = select(FlowInstance).where(
FlowInstance.tenant_id == tenant_id, FlowInstance.tenant_id == tenant_id,
FlowInstance.session_id == session_id, FlowInstance.session_id == session_id,
).order_by(col(FlowInstance.created_at).desc()) ).order_by(col(FlowInstance.started_at).desc())
result = await self._session.execute(stmt) result = await self._session.execute(stmt)
instance = result.scalar_one_or_none() instance = result.scalar_one_or_none()

View File

@ -106,6 +106,8 @@ class IntentRuleService:
is_enabled=True, is_enabled=True,
hit_count=0, hit_count=0,
metadata_=create_data.metadata_, metadata_=create_data.metadata_,
intent_vector=create_data.intent_vector,
semantic_examples=create_data.semantic_examples,
) )
self._session.add(rule) self._session.add(rule)
await self._session.flush() await self._session.flush()
@ -195,6 +197,10 @@ class IntentRuleService:
rule.is_enabled = update_data.is_enabled rule.is_enabled = update_data.is_enabled
if update_data.metadata_ is not None: if update_data.metadata_ is not None:
rule.metadata_ = update_data.metadata_ rule.metadata_ = update_data.metadata_
if update_data.intent_vector is not None:
rule.intent_vector = update_data.intent_vector
if update_data.semantic_examples is not None:
rule.semantic_examples = update_data.semantic_examples
rule.updated_at = datetime.utcnow() rule.updated_at = datetime.utcnow()
await self._session.flush() await self._session.flush()
@ -267,7 +273,7 @@ class IntentRuleService:
select(IntentRule) select(IntentRule)
.where( .where(
IntentRule.tenant_id == tenant_id, IntentRule.tenant_id == tenant_id,
IntentRule.is_enabled == True, IntentRule.is_enabled == True, # noqa: E712
) )
.order_by(col(IntentRule.priority).desc()) .order_by(col(IntentRule.priority).desc())
) )
@ -300,6 +306,8 @@ class IntentRuleService:
"is_enabled": rule.is_enabled, "is_enabled": rule.is_enabled,
"hit_count": rule.hit_count, "hit_count": rule.hit_count,
"metadata": rule.metadata_, "metadata": rule.metadata_,
"created_at": rule.created_at.isoformat(), "intent_vector": rule.intent_vector,
"updated_at": rule.updated_at.isoformat(), "semantic_examples": rule.semantic_examples,
"created_at": rule.created_at.isoformat() if rule.created_at else None,
"updated_at": rule.updated_at.isoformat() if rule.updated_at else None,
} }

View File

@ -83,6 +83,7 @@ class KBService:
file_name: str, file_name: str,
file_content: bytes, file_content: bytes,
file_type: str | None = None, file_type: str | None = None,
metadata: dict | None = None,
) -> tuple[Document, IndexJob]: ) -> tuple[Document, IndexJob]:
""" """
[AC-ASA-01] Upload document and create indexing job. [AC-ASA-01] Upload document and create indexing job.
@ -108,6 +109,7 @@ class KBService:
file_size=len(file_content), file_size=len(file_content),
file_type=file_type, file_type=file_type,
status=DocumentStatus.PENDING.value, status=DocumentStatus.PENDING.value,
doc_metadata=metadata,
) )
self._session.add(document) self._session.add(document)

View File

@ -3,7 +3,14 @@ LLM Adapter module for AI Service.
[AC-AISVC-02, AC-AISVC-06] Provides unified interface for LLM providers. [AC-AISVC-02, AC-AISVC-06] Provides unified interface for LLM providers.
""" """
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk from app.services.llm.base import (
LLMClient,
LLMConfig,
LLMResponse,
LLMStreamChunk,
ToolCall,
ToolDefinition,
)
from app.services.llm.openai_client import OpenAIClient from app.services.llm.openai_client import OpenAIClient
__all__ = [ __all__ = [
@ -12,4 +19,6 @@ __all__ = [
"LLMResponse", "LLMResponse",
"LLMStreamChunk", "LLMStreamChunk",
"OpenAIClient", "OpenAIClient",
"ToolCall",
"ToolDefinition",
] ]

View File

@ -28,17 +28,45 @@ class LLMConfig:
extra_params: dict[str, Any] = field(default_factory=dict) extra_params: dict[str, Any] = field(default_factory=dict)
@dataclass
class ToolCall:
"""
Represents a function call from the LLM.
Used in Function Calling mode.
"""
id: str
name: str
arguments: dict[str, Any]
def to_dict(self) -> dict[str, Any]:
import json
return {
"id": self.id,
"type": "function",
"function": {
"name": self.name,
"arguments": json.dumps(self.arguments, ensure_ascii=False),
}
}
@dataclass @dataclass
class LLMResponse: class LLMResponse:
""" """
Response from LLM generation. Response from LLM generation.
[AC-AISVC-02] Contains generated content and metadata. [AC-AISVC-02] Contains generated content and metadata.
""" """
content: str content: str | None = None
model: str model: str = ""
usage: dict[str, int] = field(default_factory=dict) usage: dict[str, int] = field(default_factory=dict)
finish_reason: str = "stop" finish_reason: str = "stop"
tool_calls: list[ToolCall] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict)
@property
def has_tool_calls(self) -> bool:
"""Check if response contains tool calls."""
return len(self.tool_calls) > 0
@dataclass @dataclass
@ -50,9 +78,33 @@ class LLMStreamChunk:
delta: str delta: str
model: str model: str
finish_reason: str | None = None finish_reason: str | None = None
tool_calls_delta: list[dict[str, Any]] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class ToolDefinition:
"""
Tool definition for Function Calling.
Compatible with OpenAI/DeepSeek function calling format.
"""
name: str
description: str
parameters: dict[str, Any]
type: str = "function"
def to_openai_format(self) -> dict[str, Any]:
"""Convert to OpenAI tools format."""
return {
"type": self.type,
"function": {
"name": self.name,
"description": self.description,
"parameters": self.parameters,
}
}
class LLMClient(ABC): class LLMClient(ABC):
""" """
Abstract base class for LLM clients. Abstract base class for LLM clients.
@ -67,6 +119,8 @@ class LLMClient(ABC):
self, self,
messages: list[dict[str, str]], messages: list[dict[str, str]],
config: LLMConfig | None = None, config: LLMConfig | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResponse: ) -> LLMResponse:
""" """
@ -76,10 +130,12 @@ class LLMClient(ABC):
Args: Args:
messages: List of chat messages with 'role' and 'content'. messages: List of chat messages with 'role' and 'content'.
config: Optional LLM configuration overrides. config: Optional LLM configuration overrides.
tools: Optional list of tools for function calling.
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
**kwargs: Additional provider-specific parameters. **kwargs: Additional provider-specific parameters.
Returns: Returns:
LLMResponse with generated content and metadata. LLMResponse with generated content, tool_calls, and metadata.
Raises: Raises:
LLMException: If generation fails. LLMException: If generation fails.
@ -91,6 +147,8 @@ class LLMClient(ABC):
self, self,
messages: list[dict[str, str]], messages: list[dict[str, str]],
config: LLMConfig | None = None, config: LLMConfig | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncGenerator[LLMStreamChunk, None]: ) -> AsyncGenerator[LLMStreamChunk, None]:
""" """
@ -100,6 +158,8 @@ class LLMClient(ABC):
Args: Args:
messages: List of chat messages with 'role' and 'content'. messages: List of chat messages with 'role' and 'content'.
config: Optional LLM configuration overrides. config: Optional LLM configuration overrides.
tools: Optional list of tools for function calling.
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
**kwargs: Additional provider-specific parameters. **kwargs: Additional provider-specific parameters.
Yields: Yields:

View File

@ -11,12 +11,16 @@ from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import redis
from app.core.config import get_settings
from app.services.llm.base import LLMClient, LLMConfig from app.services.llm.base import LLMClient, LLMConfig
from app.services.llm.openai_client import OpenAIClient from app.services.llm.openai_client import OpenAIClient
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
LLM_CONFIG_FILE = Path("config/llm_config.json") LLM_CONFIG_FILE = Path("config/llm_config.json")
LLM_CONFIG_REDIS_KEY = "ai_service:config:llm"
@dataclass @dataclass
@ -286,6 +290,8 @@ class LLMConfigManager:
from app.core.config import get_settings from app.core.config import get_settings
settings = get_settings() settings = get_settings()
self._settings = settings
self._redis_client: redis.Redis | None = None
self._current_provider: str = settings.llm_provider self._current_provider: str = settings.llm_provider
self._current_config: dict[str, Any] = { self._current_config: dict[str, Any] = {
@ -299,8 +305,75 @@ class LLMConfigManager:
} }
self._client: LLMClient | None = None self._client: LLMClient | None = None
self._load_from_redis()
self._load_from_file() self._load_from_file()
def _load_from_redis(self) -> None:
"""Load configuration from Redis if exists."""
try:
if not self._settings.redis_enabled:
return
self._redis_client = redis.from_url(
self._settings.redis_url,
encoding="utf-8",
decode_responses=True,
)
saved_raw = self._redis_client.get(LLM_CONFIG_REDIS_KEY)
if not saved_raw:
return
saved = json.loads(saved_raw)
self._current_provider = saved.get("provider", self._current_provider)
saved_config = saved.get("config", {})
if saved_config:
self._current_config.update(saved_config)
logger.info(f"[AC-ASA-16] Loaded LLM config from Redis: provider={self._current_provider}")
except Exception as e:
logger.warning(f"[AC-ASA-16] Failed to load LLM config from Redis: {e}")
def _save_to_redis(self) -> None:
"""Save configuration to Redis."""
try:
if not self._settings.redis_enabled:
return
if self._redis_client is None:
self._redis_client = redis.from_url(
self._settings.redis_url,
encoding="utf-8",
decode_responses=True,
)
self._redis_client.set(
LLM_CONFIG_REDIS_KEY,
json.dumps({
"provider": self._current_provider,
"config": self._current_config,
}, ensure_ascii=False),
)
logger.info(f"[AC-ASA-16] Saved LLM config to Redis: provider={self._current_provider}")
except Exception as e:
logger.warning(f"[AC-ASA-16] Failed to save LLM config to Redis: {e}")
def _load_from_redis(self) -> None:
"""Load configuration from Redis if exists."""
try:
if not self._settings.redis_enabled:
return
self._redis_client = redis.from_url(
self._settings.redis_url,
encoding="utf-8",
decode_responses=True,
)
saved_raw = self._redis_client.get(LLM_CONFIG_REDIS_KEY)
if not saved_raw:
return
saved = json.loads(saved_raw)
self._current_provider = saved.get("provider", self._current_provider)
saved_config = saved.get("config", {})
if saved_config:
self._current_config.update(saved_config)
logger.info(f"[AC-ASA-16] Loaded LLM config from Redis: provider={self._current_provider}")
except Exception as e:
logger.warning(f"[AC-ASA-16] Failed to load LLM config from Redis: {e}")
def _load_from_file(self) -> None: def _load_from_file(self) -> None:
"""Load configuration from file if exists.""" """Load configuration from file if exists."""
try: try:
@ -364,6 +437,7 @@ class LLMConfigManager:
self._current_provider = provider self._current_provider = provider
self._current_config = validated_config self._current_config = validated_config
self._save_to_redis()
self._save_to_file() self._save_to_file()
logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}") logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")

View File

@ -22,7 +22,14 @@ from tenacity import (
from app.core.config import get_settings from app.core.config import get_settings
from app.core.exceptions import AIServiceException, ErrorCode, TimeoutException from app.core.exceptions import AIServiceException, ErrorCode, TimeoutException
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk from app.services.llm.base import (
LLMClient,
LLMConfig,
LLMResponse,
LLMStreamChunk,
ToolCall,
ToolDefinition,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -95,6 +102,8 @@ class OpenAIClient(LLMClient):
messages: list[dict[str, str]], messages: list[dict[str, str]],
config: LLMConfig, config: LLMConfig,
stream: bool = False, stream: bool = False,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any, **kwargs: Any,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Build request body for OpenAI API.""" """Build request body for OpenAI API."""
@ -106,6 +115,13 @@ class OpenAIClient(LLMClient):
"top_p": config.top_p, "top_p": config.top_p,
"stream": stream, "stream": stream,
} }
if tools:
body["tools"] = [tool.to_openai_format() for tool in tools]
if tool_choice:
body["tool_choice"] = tool_choice
body.update(config.extra_params) body.update(config.extra_params)
body.update(kwargs) body.update(kwargs)
return body return body
@ -119,6 +135,8 @@ class OpenAIClient(LLMClient):
self, self,
messages: list[dict[str, str]], messages: list[dict[str, str]],
config: LLMConfig | None = None, config: LLMConfig | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResponse: ) -> LLMResponse:
""" """
@ -128,10 +146,12 @@ class OpenAIClient(LLMClient):
Args: Args:
messages: List of chat messages with 'role' and 'content'. messages: List of chat messages with 'role' and 'content'.
config: Optional LLM configuration overrides. config: Optional LLM configuration overrides.
tools: Optional list of tools for function calling.
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
**kwargs: Additional provider-specific parameters. **kwargs: Additional provider-specific parameters.
Returns: Returns:
LLMResponse with generated content and metadata. LLMResponse with generated content, tool_calls, and metadata.
Raises: Raises:
LLMException: If generation fails. LLMException: If generation fails.
@ -140,9 +160,14 @@ class OpenAIClient(LLMClient):
effective_config = config or self._default_config effective_config = config or self._default_config
client = self._get_client(effective_config.timeout_seconds) client = self._get_client(effective_config.timeout_seconds)
body = self._build_request_body(messages, effective_config, stream=False, **kwargs) body = self._build_request_body(
messages, effective_config, stream=False,
tools=tools, tool_choice=tool_choice, **kwargs
)
logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}") logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}")
if tools:
logger.info(f"[AC-AISVC-02] Function calling enabled with {len(tools)} tools")
logger.info("[AC-AISVC-02] ========== FULL PROMPT TO AI ==========") logger.info("[AC-AISVC-02] ========== FULL PROMPT TO AI ==========")
for i, msg in enumerate(messages): for i, msg in enumerate(messages):
role = msg.get("role", "unknown") role = msg.get("role", "unknown")
@ -177,14 +202,18 @@ class OpenAIClient(LLMClient):
try: try:
choice = data["choices"][0] choice = data["choices"][0]
content = choice["message"]["content"] message = choice["message"]
content = message.get("content")
usage = data.get("usage", {}) usage = data.get("usage", {})
finish_reason = choice.get("finish_reason", "stop") finish_reason = choice.get("finish_reason", "stop")
tool_calls = self._parse_tool_calls(message)
logger.info( logger.info(
f"[AC-AISVC-02] Generated response: " f"[AC-AISVC-02] Generated response: "
f"tokens={usage.get('total_tokens', 'N/A')}, " f"tokens={usage.get('total_tokens', 'N/A')}, "
f"finish_reason={finish_reason}" f"finish_reason={finish_reason}, "
f"tool_calls={len(tool_calls)}"
) )
return LLMResponse( return LLMResponse(
@ -192,6 +221,7 @@ class OpenAIClient(LLMClient):
model=data.get("model", effective_config.model), model=data.get("model", effective_config.model),
usage=usage, usage=usage,
finish_reason=finish_reason, finish_reason=finish_reason,
tool_calls=tool_calls,
metadata={"raw_response": data}, metadata={"raw_response": data},
) )
@ -201,11 +231,34 @@ class OpenAIClient(LLMClient):
message=f"Unexpected LLM response format: {e}", message=f"Unexpected LLM response format: {e}",
details=[{"response": str(data)}], details=[{"response": str(data)}],
) )
def _parse_tool_calls(self, message: dict[str, Any]) -> list[ToolCall]:
"""Parse tool calls from LLM response message."""
tool_calls = []
raw_tool_calls = message.get("tool_calls", [])
for tc in raw_tool_calls:
if tc.get("type") == "function":
func = tc.get("function", {})
try:
arguments = json.loads(func.get("arguments", "{}"))
except json.JSONDecodeError:
arguments = {}
tool_calls.append(ToolCall(
id=tc.get("id", ""),
name=func.get("name", ""),
arguments=arguments,
))
return tool_calls
async def stream_generate( async def stream_generate(
self, self,
messages: list[dict[str, str]], messages: list[dict[str, str]],
config: LLMConfig | None = None, config: LLMConfig | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncGenerator[LLMStreamChunk, None]: ) -> AsyncGenerator[LLMStreamChunk, None]:
""" """
@ -215,6 +268,8 @@ class OpenAIClient(LLMClient):
Args: Args:
messages: List of chat messages with 'role' and 'content'. messages: List of chat messages with 'role' and 'content'.
config: Optional LLM configuration overrides. config: Optional LLM configuration overrides.
tools: Optional list of tools for function calling.
tool_choice: Tool choice strategy ("auto", "none", or specific tool).
**kwargs: Additional provider-specific parameters. **kwargs: Additional provider-specific parameters.
Yields: Yields:
@ -227,9 +282,14 @@ class OpenAIClient(LLMClient):
effective_config = config or self._default_config effective_config = config or self._default_config
client = self._get_client(effective_config.timeout_seconds) client = self._get_client(effective_config.timeout_seconds)
body = self._build_request_body(messages, effective_config, stream=True, **kwargs) body = self._build_request_body(
messages, effective_config, stream=True,
tools=tools, tool_choice=tool_choice, **kwargs
)
logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}") logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}")
if tools:
logger.info(f"[AC-AISVC-06] Function calling enabled with {len(tools)} tools")
logger.info("[AC-AISVC-06] ========== FULL PROMPT TO AI (STREAMING) ==========") logger.info("[AC-AISVC-06] ========== FULL PROMPT TO AI (STREAMING) ==========")
for i, msg in enumerate(messages): for i, msg in enumerate(messages):
role = msg.get("role", "unknown") role = msg.get("role", "unknown")

View File

@ -39,6 +39,19 @@ class MetadataFieldDefinitionService:
def __init__(self, session: AsyncSession): def __init__(self, session: AsyncSession):
self._session = session self._session = session
async def _invalidate_cache(self, tenant_id: str) -> None:
"""
清除租户的元数据字段缓存
在字段创建更新删除时调用
"""
try:
from app.services.metadata_cache_service import get_metadata_cache_service
cache_service = await get_metadata_cache_service()
await cache_service.invalidate(tenant_id)
except Exception as e:
# 缓存失效失败不影响主流程
logger.warning(f"[AC-IDSMETA-13] Failed to invalidate cache: {e}")
async def list_field_definitions( async def list_field_definitions(
self, self,
@ -180,6 +193,9 @@ class MetadataFieldDefinitionService:
self._session.add(field) self._session.add(field)
await self._session.flush() await self._session.flush()
# 清除缓存,使新字段在下次查询时生效
await self._invalidate_cache(tenant_id)
logger.info( logger.info(
f"[AC-IDSMETA-13] [AC-MRS-01] Created field definition: tenant={tenant_id}, " f"[AC-IDSMETA-13] [AC-MRS-01] Created field definition: tenant={tenant_id}, "
f"field_key={field.field_key}, status={field.status}, field_roles={field.field_roles}" f"field_key={field.field_key}, status={field.status}, field_roles={field.field_roles}"
@ -223,6 +239,10 @@ class MetadataFieldDefinitionService:
field.is_filterable = field_update.is_filterable field.is_filterable = field_update.is_filterable
if field_update.is_rank_feature is not None: if field_update.is_rank_feature is not None:
field.is_rank_feature = field_update.is_rank_feature field.is_rank_feature = field_update.is_rank_feature
# [AC-MRS-01] 修复:添加 field_roles 更新逻辑
if field_update.field_roles is not None:
self._validate_field_roles(field_update.field_roles)
field.field_roles = field_update.field_roles
if field_update.status is not None: if field_update.status is not None:
old_status = field.status old_status = field.status
field.status = field_update.status field.status = field_update.status
@ -235,6 +255,9 @@ class MetadataFieldDefinitionService:
field.updated_at = datetime.utcnow() field.updated_at = datetime.utcnow()
await self._session.flush() await self._session.flush()
# 清除缓存,使更新在下次查询时生效
await self._invalidate_cache(tenant_id)
logger.info( logger.info(
f"[AC-IDSMETA-14] Updated field definition: tenant={tenant_id}, " f"[AC-IDSMETA-14] Updated field definition: tenant={tenant_id}, "
f"field_id={field_id}, version={field.version}" f"field_id={field_id}, version={field.version}"

View File

@ -23,9 +23,9 @@ RAG Optimization (rag-optimization/spec.md):
""" """
import logging import logging
import uuid
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime
from typing import Any from typing import Any
from sse_starlette.sse import ServerSentEvent from sse_starlette.sse import ServerSentEvent
@ -46,7 +46,6 @@ from app.services.flow.engine import FlowEngine
from app.services.guardrail.behavior_service import BehaviorRuleService from app.services.guardrail.behavior_service import BehaviorRuleService
from app.services.guardrail.input_scanner import InputScanner from app.services.guardrail.input_scanner import InputScanner
from app.services.guardrail.output_filter import OutputFilter from app.services.guardrail.output_filter import OutputFilter
from app.services.guardrail.word_service import ForbiddenWordService
from app.services.intent.router import IntentRouter from app.services.intent.router import IntentRouter
from app.services.intent.rule_service import IntentRuleService from app.services.intent.rule_service import IntentRuleService
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse from app.services.llm.base import LLMClient, LLMConfig, LLMResponse
@ -90,6 +89,8 @@ class GenerationContext:
10. confidence_result: Confidence calculation result 10. confidence_result: Confidence calculation result
11. messages_saved: Whether messages were saved 11. messages_saved: Whether messages were saved
12. final_response: Final ChatResponse 12. final_response: Final ChatResponse
[v0.8.0] Extended with route_trace for hybrid routing observability.
""" """
tenant_id: str tenant_id: str
session_id: str session_id: str
@ -115,6 +116,11 @@ class GenerationContext:
target_kb_ids: list[str] | None = None target_kb_ids: list[str] | None = None
behavior_rules: list[str] = field(default_factory=list) behavior_rules: list[str] = field(default_factory=list)
# [v0.8.0] Hybrid routing fields
route_trace: dict[str, Any] | None = None
fusion_confidence: float | None = None
fusion_decision_reason: str | None = None
diagnostics: dict[str, Any] = field(default_factory=dict) diagnostics: dict[str, Any] = field(default_factory=dict)
execution_steps: list[dict[str, Any]] = field(default_factory=list) execution_steps: list[dict[str, Any]] = field(default_factory=list)
@ -487,7 +493,7 @@ class OrchestratorService:
finish_reason="flow_step", finish_reason="flow_step",
) )
ctx.diagnostics["flow_handled"] = True ctx.diagnostics["flow_handled"] = True
logger.info(f"[AC-AISVC-75] Flow provided reply, skipping LLM") logger.info("[AC-AISVC-75] Flow provided reply, skipping LLM")
else: else:
ctx.diagnostics["flow_check_enabled"] = True ctx.diagnostics["flow_check_enabled"] = True
@ -501,8 +507,8 @@ class OrchestratorService:
""" """
[AC-AISVC-69, AC-AISVC-70] Step 3: Match intent rules and route. [AC-AISVC-69, AC-AISVC-70] Step 3: Match intent rules and route.
Routes to: fixed reply, RAG with target KBs, flow start, or transfer. Routes to: fixed reply, RAG with target KBs, flow start, or transfer.
[v0.8.0] Upgraded to use match_hybrid() for hybrid routing.
""" """
# Skip if flow already handled the request
if ctx.diagnostics.get("flow_handled"): if ctx.diagnostics.get("flow_handled"):
logger.info("[AC-AISVC-69] Flow already handled, skipping intent matching") logger.info("[AC-AISVC-69] Flow already handled, skipping intent matching")
return return
@ -513,7 +519,6 @@ class OrchestratorService:
return return
try: try:
# Load enabled rules ordered by priority
async with get_session() as session: async with get_session() as session:
from app.services.intent.rule_service import IntentRuleService from app.services.intent.rule_service import IntentRuleService
rule_service = IntentRuleService(session) rule_service = IntentRuleService(session)
@ -524,33 +529,64 @@ class OrchestratorService:
ctx.diagnostics["intent_matched"] = False ctx.diagnostics["intent_matched"] = False
return return
# Match intent fusion_result = await self._intent_router.match_hybrid(
ctx.intent_match = self._intent_router.match(
message=ctx.current_message, message=ctx.current_message,
rules=rules, rules=rules,
tenant_id=ctx.tenant_id,
) )
if ctx.intent_match: ctx.route_trace = fusion_result.trace.to_dict()
ctx.fusion_confidence = fusion_result.final_confidence
ctx.fusion_decision_reason = fusion_result.decision_reason
if fusion_result.final_intent:
ctx.intent_match = type(
"IntentMatchResult",
(),
{
"rule": fusion_result.final_intent,
"match_type": fusion_result.decision_reason,
"matched": "",
"to_dict": lambda: {
"rule_id": str(fusion_result.final_intent.id),
"rule_name": fusion_result.final_intent.name,
"match_type": fusion_result.decision_reason,
"matched": "",
"response_type": fusion_result.final_intent.response_type,
"target_kb_ids": (
fusion_result.final_intent.target_kb_ids or []
),
"flow_id": (
str(fusion_result.final_intent.flow_id)
if fusion_result.final_intent.flow_id else None
),
"fixed_reply": fusion_result.final_intent.fixed_reply,
"transfer_message": fusion_result.final_intent.transfer_message,
},
},
)()
logger.info( logger.info(
f"[AC-AISVC-69] Intent matched: rule={ctx.intent_match.rule.name}, " f"[AC-AISVC-69] Intent matched: rule={fusion_result.final_intent.name}, "
f"response_type={ctx.intent_match.rule.response_type}" f"response_type={fusion_result.final_intent.response_type}, "
f"decision={fusion_result.decision_reason}, "
f"confidence={fusion_result.final_confidence:.3f}"
) )
ctx.diagnostics["intent_match"] = ctx.intent_match.to_dict() ctx.diagnostics["intent_match"] = ctx.intent_match.to_dict()
ctx.diagnostics["fusion_result"] = fusion_result.to_dict()
# Increment hit count
async with get_session() as session: async with get_session() as session:
rule_service = IntentRuleService(session) rule_service = IntentRuleService(session)
await rule_service.increment_hit_count( await rule_service.increment_hit_count(
tenant_id=ctx.tenant_id, tenant_id=ctx.tenant_id,
rule_id=ctx.intent_match.rule.id, rule_id=fusion_result.final_intent.id,
) )
# Route based on response_type rule = fusion_result.final_intent
if ctx.intent_match.rule.response_type == "fixed": if rule.response_type == "fixed":
# Fixed reply - skip LLM
ctx.llm_response = LLMResponse( ctx.llm_response = LLMResponse(
content=ctx.intent_match.rule.fixed_reply or "收到您的消息。", content=rule.fixed_reply or "收到您的消息。",
model="intent_fixed", model="intent_fixed",
usage={}, usage={},
finish_reason="intent_fixed", finish_reason="intent_fixed",
@ -558,20 +594,18 @@ class OrchestratorService:
ctx.diagnostics["intent_handled"] = True ctx.diagnostics["intent_handled"] = True
logger.info("[AC-AISVC-70] Intent fixed reply, skipping LLM") logger.info("[AC-AISVC-70] Intent fixed reply, skipping LLM")
elif ctx.intent_match.rule.response_type == "rag": elif rule.response_type == "rag":
# RAG with target KBs ctx.target_kb_ids = rule.target_kb_ids or []
ctx.target_kb_ids = ctx.intent_match.rule.target_kb_ids or []
logger.info(f"[AC-AISVC-70] Intent RAG, target_kb_ids={ctx.target_kb_ids}") logger.info(f"[AC-AISVC-70] Intent RAG, target_kb_ids={ctx.target_kb_ids}")
elif ctx.intent_match.rule.response_type == "flow": elif rule.response_type == "flow":
# Start script flow if rule.flow_id and self._flow_engine:
if ctx.intent_match.rule.flow_id and self._flow_engine:
async with get_session() as session: async with get_session() as session:
flow_engine = FlowEngine(session) flow_engine = FlowEngine(session)
instance, first_step = await flow_engine.start( instance, first_step = await flow_engine.start(
tenant_id=ctx.tenant_id, tenant_id=ctx.tenant_id,
session_id=ctx.session_id, session_id=ctx.session_id,
flow_id=ctx.intent_match.rule.flow_id, flow_id=rule.flow_id,
) )
if first_step: if first_step:
ctx.llm_response = LLMResponse( ctx.llm_response = LLMResponse(
@ -583,10 +617,9 @@ class OrchestratorService:
ctx.diagnostics["intent_handled"] = True ctx.diagnostics["intent_handled"] = True
logger.info("[AC-AISVC-70] Intent flow started, skipping LLM") logger.info("[AC-AISVC-70] Intent flow started, skipping LLM")
elif ctx.intent_match.rule.response_type == "transfer": elif rule.response_type == "transfer":
# Transfer to human
ctx.llm_response = LLMResponse( ctx.llm_response = LLMResponse(
content=ctx.intent_match.rule.transfer_message or "正在为您转接人工客服...", content=rule.transfer_message or "正在为您转接人工客服...",
model="intent_transfer", model="intent_transfer",
usage={}, usage={},
finish_reason="intent_transfer", finish_reason="intent_transfer",
@ -600,9 +633,25 @@ class OrchestratorService:
ctx.diagnostics["intent_handled"] = True ctx.diagnostics["intent_handled"] = True
logger.info("[AC-AISVC-70] Intent transfer, skipping LLM") logger.info("[AC-AISVC-70] Intent transfer, skipping LLM")
if fusion_result.need_clarify:
ctx.diagnostics["need_clarify"] = True
ctx.diagnostics["clarify_candidates"] = [
{"id": str(r.id), "name": r.name}
for r in (fusion_result.clarify_candidates or [])
]
logger.info(
f"[AC-AISVC-121] Low confidence, need clarify: "
f"confidence={fusion_result.final_confidence:.3f}, "
f"candidates={len(fusion_result.clarify_candidates or [])}"
)
else: else:
ctx.diagnostics["intent_match_enabled"] = True ctx.diagnostics["intent_match_enabled"] = True
ctx.diagnostics["intent_matched"] = False ctx.diagnostics["intent_matched"] = False
ctx.diagnostics["fusion_result"] = fusion_result.to_dict()
logger.info(
f"[AC-AISVC-69] No intent matched, decision={fusion_result.decision_reason}"
)
except Exception as e: except Exception as e:
logger.warning(f"[AC-AISVC-69] Intent matching failed: {e}") logger.warning(f"[AC-AISVC-69] Intent matching failed: {e}")
@ -724,43 +773,43 @@ class OrchestratorService:
async def _build_metadata_filters(self, ctx: GenerationContext): async def _build_metadata_filters(self, ctx: GenerationContext):
""" """
[AC-IDSMETA-19] Build metadata filters from context. [AC-IDSMETA-19] Build metadata filters from context.
Sources: Sources:
1. Intent rule metadata (if matched) 1. Intent rule metadata (if matched)
2. Session metadata 2. Session metadata
3. Request metadata 3. Request metadata
4. Extracted slots from conversation 4. Extracted slots from conversation
Returns: Returns:
TagFilter with at least grade, subject, scene if available TagFilter with at least grade, subject, scene if available
""" """
from app.services.retrieval.metadata import TagFilter from app.services.retrieval.metadata import TagFilter
filter_fields = {} filter_fields = {}
# 1. From intent rule metadata # 1. From intent rule metadata
if ctx.intent_match and hasattr(ctx.intent_match.rule, 'metadata_') and ctx.intent_match.rule.metadata_: if ctx.intent_match and hasattr(ctx.intent_match.rule, 'metadata_') and ctx.intent_match.rule.metadata_:
intent_metadata = ctx.intent_match.rule.metadata_ intent_metadata = ctx.intent_match.rule.metadata_
for key in ['grade', 'subject', 'scene']: for key in ['grade', 'subject', 'scene']:
if key in intent_metadata: if key in intent_metadata:
filter_fields[key] = intent_metadata[key] filter_fields[key] = intent_metadata[key]
# 2. From session/request metadata # 2. From session/request metadata
if ctx.request_metadata: if ctx.request_metadata:
for key in ['grade', 'subject', 'scene']: for key in ['grade', 'subject', 'scene']:
if key in ctx.request_metadata and key not in filter_fields: if key in ctx.request_metadata and key not in filter_fields:
filter_fields[key] = ctx.request_metadata[key] filter_fields[key] = ctx.request_metadata[key]
# 3. From merged context (extracted slots) # 3. From merged context (extracted slots)
if ctx.merged_context and hasattr(ctx.merged_context, 'slots'): if ctx.merged_context and hasattr(ctx.merged_context, 'slots'):
slots = ctx.merged_context.slots or {} slots = ctx.merged_context.slots or {}
for key in ['grade', 'subject', 'scene']: for key in ['grade', 'subject', 'scene']:
if key in slots and key not in filter_fields: if key in slots and key not in filter_fields:
filter_fields[key] = slots[key] filter_fields[key] = slots[key]
if not filter_fields: if not filter_fields:
return None return None
return TagFilter(fields=filter_fields) return TagFilter(fields=filter_fields)
async def _build_system_prompt(self, ctx: GenerationContext) -> None: async def _build_system_prompt(self, ctx: GenerationContext) -> None:
@ -981,11 +1030,11 @@ class OrchestratorService:
"根据知识库信息,我找到了一些相关内容," "根据知识库信息,我找到了一些相关内容,"
"但暂时无法生成完整回复。建议您稍后重试或联系人工客服。" "但暂时无法生成完整回复。建议您稍后重试或联系人工客服。"
) )
# [AC-IDSMETA-20] Record structured fallback reason code # [AC-IDSMETA-20] Record structured fallback reason code
fallback_reason_code = self._determine_fallback_reason_code(ctx) fallback_reason_code = self._determine_fallback_reason_code(ctx)
ctx.diagnostics["fallback_reason_code"] = fallback_reason_code ctx.diagnostics["fallback_reason_code"] = fallback_reason_code
logger.warning( logger.warning(
f"[AC-IDSMETA-20] No recall, using fallback: " f"[AC-IDSMETA-20] No recall, using fallback: "
f"intent_id={ctx.intent_match.rule.id if ctx.intent_match else None}, " f"intent_id={ctx.intent_match.rule.id if ctx.intent_match else None}, "
@ -993,7 +1042,7 @@ class OrchestratorService:
f"applied_metadata_filters={ctx.diagnostics.get('retrieval', {}).get('applied_metadata_filters')}, " f"applied_metadata_filters={ctx.diagnostics.get('retrieval', {}).get('applied_metadata_filters')}, "
f"fallback_reason_code={fallback_reason_code}" f"fallback_reason_code={fallback_reason_code}"
) )
return ( return (
"抱歉,我暂时无法处理您的请求。" "抱歉,我暂时无法处理您的请求。"
"请稍后重试或联系人工客服获取帮助。" "请稍后重试或联系人工客服获取帮助。"
@ -1002,7 +1051,7 @@ class OrchestratorService:
def _determine_fallback_reason_code(self, ctx: GenerationContext) -> str: def _determine_fallback_reason_code(self, ctx: GenerationContext) -> str:
""" """
[AC-IDSMETA-20] Determine structured fallback reason code. [AC-IDSMETA-20] Determine structured fallback reason code.
Reason codes: Reason codes:
- no_recall_after_metadata_filter: No results after applying metadata filters - no_recall_after_metadata_filter: No results after applying metadata filters
- no_recall_no_kb: No target knowledge bases configured - no_recall_no_kb: No target knowledge bases configured
@ -1011,27 +1060,27 @@ class OrchestratorService:
- no_recall_error: Retrieval error occurred - no_recall_error: Retrieval error occurred
""" """
retrieval_diag = ctx.diagnostics.get("retrieval", {}) retrieval_diag = ctx.diagnostics.get("retrieval", {})
# Check for retrieval error # Check for retrieval error
if ctx.diagnostics.get("retrieval_error"): if ctx.diagnostics.get("retrieval_error"):
return "no_recall_error" return "no_recall_error"
# Check if metadata filters were applied # Check if metadata filters were applied
if retrieval_diag.get("applied_metadata_filters"): if retrieval_diag.get("applied_metadata_filters"):
return "no_recall_after_metadata_filter" return "no_recall_after_metadata_filter"
# Check if target KBs were configured # Check if target KBs were configured
if not ctx.target_kb_ids: if not ctx.target_kb_ids:
return "no_recall_no_kb" return "no_recall_no_kb"
# Check if KB is empty (no candidates at all) # Check if KB is empty (no candidates at all)
if retrieval_diag.get("total_candidates", 0) == 0: if retrieval_diag.get("total_candidates", 0) == 0:
return "no_recall_kb_empty" return "no_recall_kb_empty"
# Results found but filtered out by score threshold # Results found but filtered out by score threshold
if retrieval_diag.get("total_candidates", 0) > 0 and retrieval_diag.get("filtered_hits", 0) == 0: if retrieval_diag.get("total_candidates", 0) > 0 and retrieval_diag.get("filtered_hits", 0) == 0:
return "no_recall_low_score" return "no_recall_low_score"
return "no_recall_unknown" return "no_recall_unknown"
def _calculate_confidence(self, ctx: GenerationContext) -> None: def _calculate_confidence(self, ctx: GenerationContext) -> None:
@ -1122,6 +1171,7 @@ class OrchestratorService:
[AC-AISVC-02] Build final ChatResponse from generation context. [AC-AISVC-02] Build final ChatResponse from generation context.
Step 12 of the 12-step pipeline. Step 12 of the 12-step pipeline.
Uses filtered_reply from Step 9. Uses filtered_reply from Step 9.
[v0.8.0] Includes route_trace in response metadata.
""" """
# Use filtered_reply if available, otherwise use llm_response.content # Use filtered_reply if available, otherwise use llm_response.content
if ctx.filtered_reply: if ctx.filtered_reply:
@ -1142,6 +1192,10 @@ class OrchestratorService:
"execution_steps": ctx.execution_steps, "execution_steps": ctx.execution_steps,
} }
# [v0.8.0] Include route_trace in response metadata
if ctx.route_trace:
response_metadata["route_trace"] = ctx.route_trace
return ChatResponse( return ChatResponse(
reply=reply, reply=reply,
confidence=confidence, confidence=confidence,

View File

@ -178,6 +178,9 @@ class PromptTemplateService:
current_version = v current_version = v
break break
# Get latest version for current_content (not just published)
latest_version = versions[0] if versions else None
return { return {
"id": str(template.id), "id": str(template.id),
"name": template.name, "name": template.name,
@ -185,6 +188,8 @@ class PromptTemplateService:
"description": template.description, "description": template.description,
"is_default": template.is_default, "is_default": template.is_default,
"metadata": template.metadata_, "metadata": template.metadata_,
"current_content": latest_version.system_instruction if latest_version else None,
"variables": latest_version.variables if latest_version else [],
"current_version": { "current_version": {
"version": current_version.version, "version": current_version.version,
"status": current_version.status, "status": current_version.status,

View File

@ -28,6 +28,7 @@ class RetrievalContext:
metadata: dict[str, Any] | None = None metadata: dict[str, Any] | None = None
tag_filter: "TagFilter | None" = None tag_filter: "TagFilter | None" = None
kb_ids: list[str] | None = None kb_ids: list[str] | None = None
metadata_filter: dict[str, Any] | None = None
def get_tag_filter_dict(self) -> dict[str, str | list[str] | None] | None: def get_tag_filter_dict(self) -> dict[str, str | list[str] | None] | None:
"""获取标签过滤器的字典表示""" """获取标签过滤器的字典表示"""

View File

@ -4,6 +4,7 @@ Vector retriever for AI Service.
""" """
import logging import logging
from typing import Any
from app.core.config import get_settings from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient, get_qdrant_client from app.core.qdrant_client import QdrantClient, get_qdrant_client
@ -76,16 +77,30 @@ class VectorRetriever(BaseRetriever):
query_vector = await self._get_embedding(ctx.query) query_vector = await self._get_embedding(ctx.query)
logger.info(f"[AC-AISVC-16] Embedding generated: dim={len(query_vector)}") logger.info(f"[AC-AISVC-16] Embedding generated: dim={len(query_vector)}")
logger.info(f"[AC-AISVC-16] Searching in tenant collection: tenant_id={ctx.tenant_id}") logger.info(f"[AC-AISVC-16] Searching in tenant collections: tenant_id={ctx.tenant_id}")
hits = await client.search( if ctx.kb_ids:
tenant_id=ctx.tenant_id, logger.info(f"[AC-AISVC-16] Restricting search to KB IDs: {ctx.kb_ids}")
query_vector=query_vector, hits = await client.search(
limit=self._top_k, tenant_id=ctx.tenant_id,
score_threshold=self._score_threshold, query_vector=query_vector,
) limit=self._top_k,
score_threshold=self._score_threshold,
logger.info(f"[AC-AISVC-16] Search returned {len(hits)} raw hits") vector_name="full",
metadata_filter=ctx.metadata_filter,
kb_ids=ctx.kb_ids,
)
else:
hits = await client.search(
tenant_id=ctx.tenant_id,
query_vector=query_vector,
limit=self._top_k,
score_threshold=self._score_threshold,
vector_name="full",
metadata_filter=ctx.metadata_filter,
)
logger.info(f"[AC-AISVC-16] Search returned {len(hits)} hits")
retrieval_hits = [ retrieval_hits = [
RetrievalHit( RetrievalHit(
text=hit.get("payload", {}).get("text", ""), text=hit.get("payload", {}).get("text", ""),
@ -133,6 +148,47 @@ class VectorRetriever(BaseRetriever):
diagnostics={"error": str(e), "is_insufficient": True}, diagnostics={"error": str(e), "is_insufficient": True},
) )
def _apply_metadata_filter(
self,
hits: list[dict[str, Any]],
metadata_filter: dict[str, Any],
) -> list[dict[str, Any]]:
"""
应用元数据过滤条件
支持的操作:
- {"$eq": value} : 等于
- {"$in": [values]} : 在列表中
"""
filtered = []
for hit in hits:
payload = hit.get("payload", {})
hit_metadata = payload.get("metadata", {})
match = True
for field_key, condition in metadata_filter.items():
hit_value = hit_metadata.get(field_key)
if isinstance(condition, dict):
if "$eq" in condition:
if hit_value != condition["$eq"]:
match = False
break
elif "$in" in condition:
if hit_value not in condition["$in"]:
match = False
break
else:
# 直接值比较
if hit_value != condition:
match = False
break
if match:
filtered.append(hit)
return filtered
async def _get_embedding(self, text: str) -> list[float]: async def _get_embedding(self, text: str) -> list[float]:
""" """
Generate embedding for text using pluggable embedding provider. Generate embedding for text using pluggable embedding provider.

View File

@ -1,6 +1,7 @@
""" """
Slot Definition Service. Slot Definition Service.
[AC-MRS-07, AC-MRS-08] 槽位定义管理服务 [AC-MRS-07, AC-MRS-08] 槽位定义管理服务
[AC-MRS-07-UPGRADE] 支持提取策略链 extract_strategies
""" """
import logging import logging
@ -27,6 +28,7 @@ logger = logging.getLogger(__name__)
class SlotDefinitionService: class SlotDefinitionService:
""" """
[AC-MRS-07, AC-MRS-08] 槽位定义服务 [AC-MRS-07, AC-MRS-08] 槽位定义服务
[AC-MRS-07-UPGRADE] 支持提取策略链管理
管理独立的槽位定义模型与元数据字段解耦但可复用 管理独立的槽位定义模型与元数据字段解耦但可复用
""" """
@ -114,6 +116,58 @@ class SlotDefinitionService:
result = await self._session.execute(stmt) result = await self._session.execute(stmt)
return result.scalar_one_or_none() return result.scalar_one_or_none()
def _validate_strategies(self, strategies: list[str] | None) -> tuple[bool, str]:
"""
[AC-MRS-07-UPGRADE] 校验提取策略链的有效性
Args:
strategies: 策略链列表
Returns:
Tuple of (是否有效, 错误信息)
"""
if strategies is None:
return True, ""
if not isinstance(strategies, list):
return False, "extract_strategies 必须是数组类型"
if len(strategies) == 0:
return False, "提取策略链不能为空数组"
# 校验不允许重复策略
if len(strategies) != len(set(strategies)):
return False, "提取策略链中不允许重复的策略"
# 校验策略值有效
invalid = [s for s in strategies if s not in self.VALID_EXTRACT_STRATEGIES]
if invalid:
return False, f"无效的提取策略: {invalid},有效值为: {self.VALID_EXTRACT_STRATEGIES}"
return True, ""
def _normalize_strategies(
self,
extract_strategies: list[str] | None,
extract_strategy: str | None,
) -> list[str] | None:
"""
[AC-MRS-07-UPGRADE] 规范化提取策略
优先使用 extract_strategies如果不存在则使用 extract_strategy
Args:
extract_strategies: 策略链新字段
extract_strategy: 单策略旧字段兼容
Returns:
规范化后的策略链或 None
"""
if extract_strategies is not None:
return extract_strategies
if extract_strategy:
return [extract_strategy]
return None
async def create_slot_definition( async def create_slot_definition(
self, self,
tenant_id: str, tenant_id: str,
@ -121,6 +175,7 @@ class SlotDefinitionService:
) -> SlotDefinition: ) -> SlotDefinition:
""" """
[AC-MRS-07, AC-MRS-08] 创建槽位定义 [AC-MRS-07, AC-MRS-08] 创建槽位定义
[AC-MRS-07-UPGRADE] 支持提取策略链
Args: Args:
tenant_id: 租户 ID tenant_id: 租户 ID
@ -148,11 +203,16 @@ class SlotDefinitionService:
f"有效类型为: {self.VALID_TYPES}" f"有效类型为: {self.VALID_TYPES}"
) )
if slot_create.extract_strategy and slot_create.extract_strategy not in self.VALID_EXTRACT_STRATEGIES: # [AC-MRS-07-UPGRADE] 规范化并校验提取策略链
raise ValueError( strategies = self._normalize_strategies(
f"无效的提取策略 '{slot_create.extract_strategy}'" slot_create.extract_strategies,
f"有效策略为: {self.VALID_EXTRACT_STRATEGIES}" slot_create.extract_strategy
) )
if strategies is not None:
is_valid, error_msg = self._validate_strategies(strategies)
if not is_valid:
raise ValueError(f"提取策略链校验失败: {error_msg}")
linked_field = None linked_field = None
if slot_create.linked_field_id: if slot_create.linked_field_id:
@ -162,12 +222,22 @@ class SlotDefinitionService:
f"[AC-MRS-08] 关联的元数据字段 '{slot_create.linked_field_id}' 不存在" f"[AC-MRS-08] 关联的元数据字段 '{slot_create.linked_field_id}' 不存在"
) )
# [AC-MRS-07-UPGRADE] 确定要保存的旧字段值
# 如果前端提交了 extract_strategies则使用第一个作为旧字段值
old_strategy = slot_create.extract_strategy
if not old_strategy and strategies and len(strategies) > 0:
old_strategy = strategies[0]
slot = SlotDefinition( slot = SlotDefinition(
tenant_id=tenant_id, tenant_id=tenant_id,
slot_key=slot_create.slot_key, slot_key=slot_create.slot_key,
display_name=slot_create.display_name,
description=slot_create.description,
type=slot_create.type, type=slot_create.type,
required=slot_create.required, required=slot_create.required,
extract_strategy=slot_create.extract_strategy, # [AC-MRS-07-UPGRADE] 同时保存新旧字段
extract_strategy=old_strategy,
extract_strategies=strategies,
validation_rule=slot_create.validation_rule, validation_rule=slot_create.validation_rule,
ask_back_prompt=slot_create.ask_back_prompt, ask_back_prompt=slot_create.ask_back_prompt,
default_value=slot_create.default_value, default_value=slot_create.default_value,
@ -180,6 +250,7 @@ class SlotDefinitionService:
logger.info( logger.info(
f"[AC-MRS-07] Created slot definition: tenant={tenant_id}, " f"[AC-MRS-07] Created slot definition: tenant={tenant_id}, "
f"slot_key={slot.slot_key}, required={slot.required}, " f"slot_key={slot.slot_key}, required={slot.required}, "
f"strategies={strategies}, "
f"linked_field_id={slot.linked_field_id}" f"linked_field_id={slot.linked_field_id}"
) )
@ -193,6 +264,7 @@ class SlotDefinitionService:
) -> SlotDefinition | None: ) -> SlotDefinition | None:
""" """
更新槽位定义 更新槽位定义
[AC-MRS-07-UPGRADE] 支持提取策略链更新
Args: Args:
tenant_id: 租户 ID tenant_id: 租户 ID
@ -206,6 +278,12 @@ class SlotDefinitionService:
if not slot: if not slot:
return None return None
if slot_update.display_name is not None:
slot.display_name = slot_update.display_name
if slot_update.description is not None:
slot.description = slot_update.description
if slot_update.type is not None: if slot_update.type is not None:
if slot_update.type not in self.VALID_TYPES: if slot_update.type not in self.VALID_TYPES:
raise ValueError( raise ValueError(
@ -217,13 +295,28 @@ class SlotDefinitionService:
if slot_update.required is not None: if slot_update.required is not None:
slot.required = slot_update.required slot.required = slot_update.required
if slot_update.extract_strategy is not None: # [AC-MRS-07-UPGRADE] 处理提取策略链更新
if slot_update.extract_strategy not in self.VALID_EXTRACT_STRATEGIES: # 如果传入了 extract_strategies 或 extract_strategy则更新
raise ValueError( if slot_update.extract_strategies is not None or slot_update.extract_strategy is not None:
f"无效的提取策略 '{slot_update.extract_strategy}'" strategies = self._normalize_strategies(
f"有效策略为: {self.VALID_EXTRACT_STRATEGIES}" slot_update.extract_strategies,
) slot_update.extract_strategy
slot.extract_strategy = slot_update.extract_strategy )
if strategies is not None:
is_valid, error_msg = self._validate_strategies(strategies)
if not is_valid:
raise ValueError(f"提取策略链校验失败: {error_msg}")
# [AC-MRS-07-UPGRADE] 同时更新新旧字段
slot.extract_strategies = strategies
# 如果前端提交了 extract_strategy则使用它否则使用策略链的第一个
if slot_update.extract_strategy is not None:
slot.extract_strategy = slot_update.extract_strategy
elif strategies and len(strategies) > 0:
slot.extract_strategy = strategies[0]
else:
slot.extract_strategy = None
if slot_update.validation_rule is not None: if slot_update.validation_rule is not None:
slot.validation_rule = slot_update.validation_rule slot.validation_rule = slot_update.validation_rule
@ -250,7 +343,7 @@ class SlotDefinitionService:
logger.info( logger.info(
f"[AC-MRS-07] Updated slot definition: tenant={tenant_id}, " f"[AC-MRS-07] Updated slot definition: tenant={tenant_id}, "
f"slot_id={slot_id}" f"slot_id={slot_id}, strategies={slot.extract_strategies}"
) )
return slot return slot
@ -331,9 +424,13 @@ class SlotDefinitionService:
"id": str(slot.id), "id": str(slot.id),
"tenant_id": slot.tenant_id, "tenant_id": slot.tenant_id,
"slot_key": slot.slot_key, "slot_key": slot.slot_key,
"display_name": slot.display_name,
"description": slot.description,
"type": slot.type, "type": slot.type,
"required": slot.required, "required": slot.required,
# [AC-MRS-07-UPGRADE] 返回新旧字段
"extract_strategy": slot.extract_strategy, "extract_strategy": slot.extract_strategy,
"extract_strategies": slot.extract_strategies,
"validation_rule": slot.validation_rule, "validation_rule": slot.validation_rule,
"ask_back_prompt": slot.ask_back_prompt, "ask_back_prompt": slot.ask_back_prompt,
"default_value": slot.default_value, "default_value": slot.default_value,