feat: enhance agent orchestrator with runtime hardening and tool governance [AC-MARH-01~12] [AC-IDMP-11~18]

This commit is contained in:
MerCry 2026-03-06 01:10:24 +08:00
parent 978aaee885
commit 5f4bde8752
9 changed files with 416 additions and 38 deletions

View File

@ -933,6 +933,7 @@ async def _execute_agent_mode(
timeout_governor=timeout_governor,
llm_client=llm_client,
tool_registry=tool_registry,
tenant_id=tenant_id,
)
final_answer, react_ctx, agent_trace = await orchestrator.execute(

View File

@ -5,6 +5,7 @@ Middleware for AI Service.
import logging
import re
import uuid
from collections.abc import Callable
from fastapi import Request, Response, status
@ -20,6 +21,17 @@ TENANT_ID_HEADER = "X-Tenant-Id"
API_KEY_HEADER = "X-API-Key"
ACCEPT_HEADER = "Accept"
SSE_CONTENT_TYPE = "text/event-stream"
REQUEST_ID_HEADER = "X-Request-Id"
# Names of prompt-template variables that the system/runtime injects itself.
# They are reserved for internal orchestration and should not be overridden by
# user input.
# NOTE(review): these five names match the extra_context keys the agent
# orchestrator passes to the variable resolver for the "agent_react" scene;
# no consumer of this set is visible here — confirm where it is enforced.
PROMPT_PROTECTED_VARIABLES = {
    "available_tools",
    "query",
    "history",
    "internal_protocol",
    "output_contract",
}
TENANT_ID_PATTERN = re.compile(r'^[^@]+@ash@\d{4}$')
@ -29,6 +41,15 @@ PATHS_SKIP_API_KEY = {
"/docs",
"/redoc",
"/openapi.json",
"/favicon.ico",
"/openapi/v1/share/chat",
}
# Request paths exempt from tenant validation. TenantContextMiddleware matches
# these by prefix (see `_should_skip_tenant`), so any path that merely starts
# with one of these entries is exempt as well.
PATHS_SKIP_TENANT = {
    "/health",
    "/ai/health",
    "/favicon.ico",
    "/openapi/v1/share/chat",
}
@ -63,17 +84,24 @@ class ApiKeyMiddleware(BaseHTTPMiddleware):
if self._should_skip_api_key(request.url.path):
return await call_next(request)
request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid.uuid4())
request.state.request_id = request_id
api_key = request.headers.get(API_KEY_HEADER)
if not api_key or not api_key.strip():
logger.warning(f"[AC-AISVC-50] Missing X-API-Key header for {request.url.path}")
return JSONResponse(
logger.warning(
f"[AC-AISVC-50] Missing X-API-Key header for {request.url.path}, request_id={request_id}"
)
response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=ErrorResponse(
code=ErrorCode.UNAUTHORIZED.value,
message="Missing required header: X-API-Key",
).model_dump(exclude_none=True),
)
response.headers[REQUEST_ID_HEADER] = request_id
return response
api_key = api_key.strip()
@ -90,17 +118,45 @@ class ApiKeyMiddleware(BaseHTTPMiddleware):
except Exception as e:
logger.error(f"[AC-AISVC-50] Failed to initialize API key service: {e}")
if not service.validate_key(api_key):
logger.warning(f"[AC-AISVC-50] Invalid API key for {request.url.path}")
return JSONResponse(
client_ip = request.client.host if request.client else None
tenant_id = request.headers.get(TENANT_ID_HEADER, "")
validation = service.validate_key_with_context(api_key, client_ip=client_ip)
if not validation.ok:
if validation.reason == "rate_limited":
logger.warning(
f"[AC-AISVC-50] Rate limited: path={request.url.path}, tenant={tenant_id}, "
f"ip={client_ip}, qpm={validation.rate_limit_qpm}, request_id={request_id}"
)
response = JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content=ErrorResponse(
code=ErrorCode.SERVICE_UNAVAILABLE.value,
message="Rate limit exceeded",
details=[{"reason": "rate_limited", "limit_qpm": validation.rate_limit_qpm}],
).model_dump(exclude_none=True),
)
response.headers[REQUEST_ID_HEADER] = request_id
return response
logger.warning(
f"[AC-AISVC-50] API key validation failed: reason={validation.reason}, "
f"path={request.url.path}, tenant={tenant_id}, ip={client_ip}, request_id={request_id}"
)
response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=ErrorResponse(
code=ErrorCode.UNAUTHORIZED.value,
message="Invalid API key",
details=[{"reason": validation.reason}],
).model_dump(exclude_none=True),
)
response.headers[REQUEST_ID_HEADER] = request_id
return response
return await call_next(request)
response = await call_next(request)
response.headers[REQUEST_ID_HEADER] = request_id
return response
def _should_skip_api_key(self, path: str) -> bool:
"""Check if the path should skip API key validation."""
@ -122,7 +178,7 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable) -> Response:
clear_tenant_context()
if request.url.path in ("/health", "/ai/health"):
if self._should_skip_tenant(request.url.path):
return await call_next(request)
tenant_id = request.headers.get(TENANT_ID_HEADER)
@ -173,6 +229,15 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
return response
def _should_skip_tenant(self, path: str) -> bool:
    """Return True when tenant validation should be bypassed for *path*.

    A path is exempt when it equals, or begins with, any entry of
    ``PATHS_SKIP_TENANT``.
    """
    # An exact match is also a prefix match, so a single prefix test against
    # all entries covers both cases.
    return path.startswith(tuple(PATHS_SKIP_TENANT))
async def _ensure_tenant_exists(self, request: Request, tenant_id: str) -> None:
"""
[AC-AISVC-10] Ensure tenant exists in database, create if not.

View File

@ -9,10 +9,11 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, status
from fastapi.exceptions import HTTPException, RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, Response
from app.api import chat_router, health_router
from app.api.mid import router as mid_router
from app.api.openapi import router as openapi_router
from app.api.admin import (
api_key_router,
dashboard_router,
@ -130,6 +131,11 @@ app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(Exception, generic_exception_handler)
@app.get("/favicon.ico", include_in_schema=False)
async def favicon() -> Response:
    """Answer favicon requests with an empty 204 (No Content) response."""
    empty_reply = Response(status_code=status.HTTP_204_NO_CONTENT)
    return empty_reply
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""
@ -171,6 +177,7 @@ app.include_router(slot_definition_router)
app.include_router(tenants_router)
app.include_router(mid_router)
app.include_router(openapi_router)
if __name__ == "__main__":

View File

@ -10,7 +10,9 @@ ReAct Flow:
"""
import asyncio
import json
import logging
import re
import time
import uuid
from dataclasses import dataclass
@ -25,6 +27,8 @@ from app.models.mid.schemas import (
TraceInfo,
)
from app.services.mid.timeout_governor import TimeoutGovernor
from app.services.prompt.template_service import PromptTemplateService
from app.services.prompt.variable_resolver import VariableResolver
logger = logging.getLogger(__name__)
@ -57,6 +61,7 @@ class AgentOrchestrator:
- ReAct loop with max 5 iterations (min 3)
- Per-tool timeout (2s) and end-to-end timeout (8s)
- Automatic fallback on iteration limit or timeout
- Template-based prompt with variable injection
"""
def __init__(
@ -65,11 +70,17 @@ class AgentOrchestrator:
timeout_governor: TimeoutGovernor | None = None,
llm_client: Any = None,
tool_registry: Any = None,
template_service: PromptTemplateService | None = None,
variable_resolver: VariableResolver | None = None,
tenant_id: str | None = None,
):
self._max_iterations = max(min(max_iterations, 5), MIN_ITERATIONS)
self._timeout_governor = timeout_governor or TimeoutGovernor()
self._llm_client = llm_client
self._tool_registry = tool_registry
self._template_service = template_service
self._variable_resolver = variable_resolver or VariableResolver()
self._tenant_id = tenant_id
async def execute(
self,
@ -212,7 +223,7 @@ class AgentOrchestrator:
for tc in react_ctx.tool_calls[-3:]:
observations.append(f"工具 {tc.tool_name}: {tc.result_digest or '无结果'}")
prompt = self._build_react_prompt(user_message, observations)
prompt = await self._build_react_prompt(user_message, observations)
response = await self._llm_client.generate([{"role": "user", "content": prompt}])
logger.info(f"[AC-MARH-07] LLM response content: {response.content[:500] if response.content else 'None'}")
@ -223,38 +234,199 @@ class AgentOrchestrator:
logger.error(f"[AC-MARH-07] Think failed: {e}")
return AgentThought(content=f"思考失败: {str(e)}")
def _build_react_prompt(self, user_message: str, observations: list[str]) -> str:
"""Build ReAct prompt for LLM."""
async def _build_react_prompt(self, user_message: str, observations: list[str]) -> str:
"""Build ReAct prompt for LLM with template support."""
obs_text = "\n".join(observations) if observations else ""
return f"""你是一个智能助手,正在使用 ReAct 模式处理用户请求。
tools_text = self._build_tools_section()
internal_protocol = f"""你必须遵循以下决策协议:
1. 优先使用已有观察信息历史观察上一步工具结果避免重复调用同类工具
2. 当问题需要外部事实或结构化状态时再调用工具如果可直接回答则不要调用
3. 缺少关键参数时优先向用户追问不要使用空参数调用工具
4. 工具失败时先说明已尝试再给出降级方案或下一步引导
5. 只能调用可用工具列表中的工具工具名必须完全匹配区分大小写
6. tenant_id 由系统自动注入绝不能由你填写猜测或修改
7. 对用户输出必须拟人自然有同理心不暴露工具调用/路由/策略等内部术语
"""
output_contract = """输出格式(二选一):
A) 直接回答用户
Final Answer: [给用户的最终回答]
B) 调用工具
Thought: [你的思考]
Action: [工具名称]
Action Input:
```json
{"param1": "value1"}
```
要求
- Action Input 必须是合法 JSON 对象
- 不要输出不存在的工具名
- 如果要调用工具Action Action Input 必须同时出现
"""
default_template = f"""你是一个智能客服助手,正在使用 ReAct 模式处理用户请求。
{tools_text}
用户消息: {user_message}
历史观察:
{obs_text}
请思考下一步行动如果已经有足够信息回答用户请直接给出最终答案
如果需要使用工具请按以下格式回复:
Thought: [你的思考]
Action: [工具名称]
Action Input: {{"param1": "value1"}}
{internal_protocol}
{output_contract}
"""
def _parse_thought(self, content: str) -> AgentThought:
"""Parse LLM response into AgentThought."""
action = None
action_input = None
if not self._template_service or not self._tenant_id:
return default_template
if "Action:" in content:
lines = content.split("\n")
for line in lines:
if line.startswith("Action:"):
action = line.replace("Action:", "").strip()
elif line.startswith("Action Input:"):
import json
try:
action_input = json.loads(line.replace("Action Input:", "").strip())
template_version = await self._template_service.get_published_template(
tenant_id=self._tenant_id,
scene="agent_react",
)
if not template_version:
return default_template
extra_context = {
"available_tools": tools_text,
"query": user_message,
"history": obs_text,
"internal_protocol": internal_protocol,
"output_contract": output_contract,
}
resolved_template = self._variable_resolver.resolve(
template=template_version.system_instruction,
variables=template_version.variables,
extra_context=extra_context,
)
final_prompt = (
f"{resolved_template}\n\n"
f"【系统强制规则】\n{internal_protocol}\n"
f"【输出契约】\n{output_contract}"
)
logger.info(f"[AC-MARH-07] Using template: scene=agent_react, version={template_version.version}")
return final_prompt
except Exception as e:
logger.warning(f"[AC-MARH-07] Failed to load template, using default: {e}")
return default_template
def _build_tools_section(self) -> str:
"""Build rich tools section for ReAct prompt."""
if not self._tool_registry:
return "当前没有可用的工具。"
tools = self._tool_registry.list_tools(enabled_only=True)
if not tools:
return "当前没有可用的工具。"
lines = ["## 可用工具列表", "", "以下是你可以使用的工具,只能使用这些工具:", ""]
for tool in tools:
meta = tool.metadata or {}
lines.append(f"### {tool.name}")
lines.append(f"用途: {tool.description}")
when_to_use = meta.get("when_to_use")
when_not_to_use = meta.get("when_not_to_use")
if when_to_use:
lines.append(f"何时使用: {when_to_use}")
if when_not_to_use:
lines.append(f"何时不要使用: {when_not_to_use}")
params = meta.get("parameters")
if isinstance(params, dict):
properties = params.get("properties", {})
required = params.get("required", [])
if properties:
lines.append("参数:")
for param_name, param_info in properties.items():
param_desc = param_info.get("description", "") if isinstance(param_info, dict) else ""
line = f" - {param_name}: {param_desc}".strip()
if param_name == "tenant_id":
line += " (系统注入,模型不要填写)"
elif param_name in required:
line += " (必填)"
lines.append(line)
if meta.get("example_action_input"):
lines.append("示例入参(JSON):")
try:
example_text = json.dumps(meta["example_action_input"], ensure_ascii=False)
except Exception:
example_text = str(meta["example_action_input"])
lines.append(example_text)
if meta.get("result_interpretation"):
lines.append(f"结果解释: {meta['result_interpretation']}")
lines.append("")
return "\n".join(lines)
def _extract_json_object(self, text: str) -> dict[str, Any] | None:
"""Extract the first valid JSON object from free text."""
candidates = []
code_block_match = re.search(r"```json\s*([\s\S]*?)\s*```", text, re.IGNORECASE)
if code_block_match:
candidates.append(code_block_match.group(1).strip())
fence_match = re.search(r"```\s*([\s\S]*?)\s*```", text)
if fence_match:
candidates.append(fence_match.group(1).strip())
brace_match = re.search(r"\{[\s\S]*\}", text)
if brace_match:
candidates.append(brace_match.group(0).strip())
for candidate in candidates:
if not candidate:
continue
try:
obj = json.loads(candidate)
if isinstance(obj, dict):
return obj
except json.JSONDecodeError:
fixed = candidate.replace("'", '"')
try:
obj = json.loads(fixed)
if isinstance(obj, dict):
return obj
except json.JSONDecodeError:
continue
return None
def _parse_thought(self, content: str) -> AgentThought:
    """Turn a raw LLM reply into an AgentThought.

    The tool name is taken from the first line that starts with ``Action:``.
    Everything after ``Action Input:`` (matched case-insensitively) is handed
    to ``_extract_json_object``; an unparseable payload becomes an empty dict.
    When an action is present but no input section exists, the input also
    defaults to an empty dict. The original text is kept as the thought content.
    """
    tool_match = re.search(r"^Action:\s*(.+)$", content, re.MULTILINE)
    action = tool_match.group(1).strip() if tool_match else None

    action_input: dict[str, Any] | None = None
    payload_match = re.search(r"Action Input:\s*([\s\S]*)$", content, re.IGNORECASE)
    if payload_match:
        payload_text = payload_match.group(1).strip()
        decoded = self._extract_json_object(payload_text)
        if decoded is None:
            decoded = {}
        action_input = decoded

    # A tool call without any input section still gets an (empty) argument dict.
    if action_input is None and action:
        action_input = {}

    return AgentThought(content=content, action=action, action_input=action_input)
@ -285,27 +457,31 @@ Action Input: {{"param1": "value1"}}
)
try:
tool_args = dict(thought.action_input or {})
if self._tenant_id:
tool_args["tenant_id"] = self._tenant_id
result = await asyncio.wait_for(
self._tool_registry.execute(
tool_name=tool_name,
args=thought.action_input or {},
args=tool_args,
),
timeout=self._timeout_governor.per_tool_timeout_seconds
)
duration_ms = int((time.time() - start_time) * 1000)
return ToolResult(
success=result.get("success", False),
output=result.get("output"),
error=result.get("error"),
success=result.success,
output=result.output,
error=result.error,
duration_ms=duration_ms,
), ToolCallTrace(
tool_name=tool_name,
tool_type=ToolType.INTERNAL,
duration_ms=duration_ms,
status=ToolCallStatus.OK if result.get("success") else ToolCallStatus.ERROR,
status=ToolCallStatus.OK if result.success else ToolCallStatus.ERROR,
args_digest=str(thought.action_input)[:100] if thought.action_input else None,
result_digest=str(result.get("output"))[:100] if result.get("output") else None,
result_digest=str(result.output)[:100] if result.output else None,
)
except asyncio.TimeoutError:

View File

@ -470,6 +470,25 @@ def register_high_risk_check_tool(
"supports_metadata_driven": True,
"min_scenarios": ["refund", "complaint_escalation", "privacy_sensitive_promise", "transfer"],
"supports_routing_signal_filter": True,
"when_to_use": "当用户消息可能涉及退款、投诉升级、隐私承诺、转人工等高风险场景时使用。",
"when_not_to_use": "当已完成高风险判定且结果未变化,或当前仅需知识检索时不要重复调用。",
"parameters": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "用户消息原文"},
"tenant_id": {"type": "string", "description": "租户 ID"},
"domain": {"type": "string", "description": "业务域(可选)"},
"scene": {"type": "string", "description": "场景标识(可选)"},
"context": {"type": "object", "description": "上下文(仅 routing_signal 字段会被消费)"}
},
"required": ["message", "tenant_id"]
},
"example_action_input": {
"message": "我要投诉你们并且现在就给我退款不然我去12315",
"tenant_id": "default",
"scene": "open_consult"
},
"result_interpretation": "matched=true 时优先按 recommended_mode 执行;关注 risk_scenario、rule_id、fallback_reason_code。"
},
)

View File

@ -344,6 +344,26 @@ def register_intent_hint_tool(
"low_confidence_threshold": effective_config.low_confidence_threshold,
"top_n": effective_config.top_n,
"supports_routing_signal_filter": True,
"when_to_use": "当用户意图不明确、需要给 policy_router 提供软路由信号时使用。",
"when_not_to_use": "当已经明确进入固定模式/流程模式,或已有确定意图结果时不重复调用。",
"parameters": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "用户输入原文"},
"tenant_id": {"type": "string", "description": "租户 ID"},
"history": {"type": "array", "description": "会话历史(可选)"},
"top_n": {"type": "integer", "description": "返回建议数量(可选)"},
"context": {"type": "object", "description": "上下文字段(仅 routing_signal 字段会被消费)"}
},
"required": ["message", "tenant_id"]
},
"example_action_input": {
"message": "我想退款,但是也想先咨询下怎么处理",
"tenant_id": "default",
"top_n": 3,
"context": {"order_status": "delivered", "channel": "web"}
},
"result_interpretation": "关注输出中的 intent / confidence / suggested_mode。该工具只提供建议不做最终决策。"
},
)

View File

@ -479,6 +479,27 @@ def register_kb_search_dynamic_tool(
metadata={
"supports_dynamic_filter": True,
"min_score_threshold": config.min_score_threshold if config else 0.5,
"when_to_use": "当需要知识库事实支撑回答,且需按租户元数据动态过滤时使用。",
"when_not_to_use": "当用户问题不依赖知识库(纯闲聊/仅流程确认)或已有充分 KB 结果时不重复调用。",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "检索查询文本"},
"tenant_id": {"type": "string", "description": "租户 ID"},
"scene": {"type": "string", "description": "场景标识,如 open_consult"},
"top_k": {"type": "integer", "description": "返回条数"},
"context": {"type": "object", "description": "上下文,用于动态过滤字段"}
},
"required": ["query", "tenant_id"]
},
"example_action_input": {
"query": "退款到账一般要多久",
"tenant_id": "default",
"scene": "open_consult",
"top_k": 5,
"context": {"product_line": "vip_course", "region": "beijing"}
},
"result_interpretation": "success=true 且 hits 非空表示命中知识missing_required_slots 非空时应先向用户补采信息。"
},
)

View File

@ -576,6 +576,27 @@ def register_memory_recall_tool(
"ac_ids": ["AC-IDMP-13"],
"recall_scope": cfg.default_recall_scope,
"max_recent_messages": cfg.max_recent_messages,
"when_to_use": "当需要补全用户画像、历史事实、偏好、槽位,避免重复追问时使用。",
"when_not_to_use": "当当前轮次已经有完整上下文且无需个性化记忆支撑时可不调用。",
"parameters": {
"type": "object",
"properties": {
"tenant_id": {"type": "string", "description": "租户 ID"},
"user_id": {"type": "string", "description": "用户 ID"},
"session_id": {"type": "string", "description": "会话 ID"},
"recall_scope": {"type": "array", "description": "召回范围,例如 profile/facts/preferences/summary/slots"},
"max_recent_messages": {"type": "integer", "description": "历史回填窗口大小"}
},
"required": ["tenant_id", "user_id", "session_id"]
},
"example_action_input": {
"tenant_id": "default",
"user_id": "u_10086",
"session_id": "s_abc_001",
"recall_scope": ["profile", "facts", "preferences", "summary", "slots"],
"max_recent_messages": 8
},
"result_interpretation": "关注 profile/facts/preferences/slots/missing_slots。若 fallback_reason_code 存在,需降级处理。"
},
)

View File

@ -259,6 +259,18 @@ class RoleBasedFieldProvider:
FieldRole.RESOURCE_FILTER.value
)
async def get_resource_filter_field_keys(
    self,
    tenant_id: str,
) -> list[str]:
    """
    [AC-MRS-11] Return the key names of the fields that have the resource-filter role.

    Delegates to ``get_field_keys_by_role`` with ``FieldRole.RESOURCE_FILTER``.
    """
    role_value = FieldRole.RESOURCE_FILTER.value
    field_keys = await self.get_field_keys_by_role(tenant_id, role_value)
    return field_keys
async def get_slot_fields(
self,
tenant_id: str,
@ -272,6 +284,18 @@ class RoleBasedFieldProvider:
FieldRole.SLOT.value
)
async def get_slot_field_keys(
    self,
    tenant_id: str,
) -> list[str]:
    """
    [AC-MRS-12] Return the key names of the fields that have the slot role.

    Delegates to ``get_field_keys_by_role`` with ``FieldRole.SLOT``.
    """
    role_value = FieldRole.SLOT.value
    field_keys = await self.get_field_keys_by_role(tenant_id, role_value)
    return field_keys
async def get_routing_signal_fields(
self,
tenant_id: str,
@ -285,6 +309,18 @@ class RoleBasedFieldProvider:
FieldRole.ROUTING_SIGNAL.value
)
async def get_routing_signal_field_keys(
    self,
    tenant_id: str,
) -> list[str]:
    """
    [AC-MRS-13] Return the key names of the fields that have the routing-signal role.

    Delegates to ``get_field_keys_by_role`` with ``FieldRole.ROUTING_SIGNAL``.
    """
    role_value = FieldRole.ROUTING_SIGNAL.value
    field_keys = await self.get_field_keys_by_role(tenant_id, role_value)
    return field_keys
async def get_prompt_var_fields(
self,
tenant_id: str,
@ -297,3 +333,15 @@ class RoleBasedFieldProvider:
tenant_id,
FieldRole.PROMPT_VAR.value
)
async def get_prompt_var_field_keys(
    self,
    tenant_id: str,
) -> list[str]:
    """
    [AC-MRS-14] Return the key names of the fields that have the prompt-variable role.

    Delegates to ``get_field_keys_by_role`` with ``FieldRole.PROMPT_VAR``.
    """
    role_value = FieldRole.PROMPT_VAR.value
    field_keys = await self.get_field_keys_by_role(tenant_id, role_value)
    return field_keys