feat: implement output guardrail with forbidden word detection and behavior rules [AC-AISVC-78~AC-AISVC-85]

This commit is contained in:
MerCry 2026-02-27 16:03:39 +08:00
parent 9d8ecf0bb2
commit 8c259cee30
12 changed files with 2137 additions and 49 deletions

511
AI中台对接文档.md Normal file
View File

@ -0,0 +1,511 @@
# AI 中台对接文档
## 1. 概述
本文档描述 Python AI 中台对渠道侧Java 主框架)暴露的 HTTP 接口规范,用于智能客服对话生成和服务健康检查。
### 1.1 服务信息
- **服务名称**: AI Service (Python AI 中台)
- **服务地址**: `http://ai-service:8080`
- **协议**: HTTP/1.1
- **数据格式**: JSON / SSE (Server-Sent Events)
- **字符编码**: UTF-8
- **契约版本**: v1.1.0
### 1.2 核心能力
- ✅ 智能对话生成(基于 LLM + RAG
- ✅ 多租户隔离(基于 `X-Tenant-Id`
- ✅ 会话上下文管理(基于 `sessionId`
- ✅ 流式/非流式双模式输出
- ✅ 置信度评估与转人工建议
- ✅ 服务健康检查
---
## 2. 认证与租户隔离
### 2.1 API Key 认证(必填)
所有接口请求(除健康检查外)必须在 HTTP Header 中携带 API Key
```http
X-API-Key: <your_api_key>
```
**说明**
- API Key 用于身份认证和访问控制
- 缺失或无效的 API Key 将返回 `401 Unauthorized`
- API Key 由 AI 中台管理员分配,请妥善保管
- 以下路径无需 API Key`/health`、`/ai/health`、`/docs`
### 2.2 租户标识(必填)
所有接口请求必须在 HTTP Header 中携带租户 ID
```http
X-Tenant-Id: <tenant_id>
```
**租户 ID 格式规范**`name@ash@year`
示例:
- `szmp@ash@2026` - 深圳某项目 2026 年
- `abc123@ash@2025` - ABC 项目 2025 年
**说明**
- 租户 ID 用于数据隔离(知识库、会话历史、配置等)
- 缺失或格式错误的租户 ID 将返回 `400 Bad Request`
- 不同租户的数据完全隔离,不可跨租户访问
- 租户不存在时会自动创建
---
## 3. 接口列表
| 接口路径 | 方法 | 功能 | 响应模式 |
|---------|------|------|---------|
| `/ai/chat` | POST | 生成 AI 回复 | JSON / SSE |
| `/ai/health` | GET | 健康检查 | JSON |
---
## 4. 接口详细说明
### 4.1 生成 AI 回复
**接口路径**: `POST /ai/chat`
**功能描述**: 根据用户消息和会话历史生成 AI 回复,支持 RAG 检索增强、上下文管理、置信度评估。
#### 4.1.1 请求参数
**Headers**:
```http
Content-Type: application/json
X-API-Key: <your_api_key>
X-Tenant-Id: <tenant_id>
Accept: application/json # 或 text/event-stream流式输出
```
**Body** (JSON):
| 字段 | 类型 | 必填 | 说明 |
|-----|------|------|------|
| `sessionId` | string | ✅ | 会话 ID用于关联同一会话的对话历史 |
| `currentMessage` | string | ✅ | 当前用户消息内容 |
| `channelType` | string | ✅ | 渠道类型,枚举值:`wechat`、`douyin`、`jd` |
| `history` | array | ❌ | 历史消息列表可选AI 中台会自动管理会话历史) |
| `metadata` | object | ❌ | 扩展元数据(可选) |
**history 数组元素结构**:
```json
{
"role": "user | assistant",
"content": "消息内容"
}
```
**请求示例**:
```json
{
"sessionId": "kf_001_wx123456_1708765432000",
"currentMessage": "我想了解产品价格",
"channelType": "wechat",
"metadata": {
"channelUserId": "wx123456",
"extra": "..."
}
}
```
#### 4.1.2 响应格式
##### 模式 1: JSON 响应(非流式)
**状态码**: `200 OK`
**响应体**:
```json
{
"reply": "您好,我们的产品价格根据套餐不同有所差异...",
"confidence": 0.92,
"shouldTransfer": false,
"transferReason": null,
"metadata": {
"retrieval_count": 3,
"rag_enabled": true
}
}
```
**字段说明**:
| 字段 | 类型 | 必填 | 说明 |
|-----|------|------|------|
| `reply` | string | ✅ | AI 生成的回复内容 |
| `confidence` | number | ✅ | 置信度评分0.0-1.0),越高表示回答越可靠 |
| `shouldTransfer` | boolean | ✅ | 是否建议转人工true=建议转人工) |
| `transferReason` | string | ❌ | 转人工原因(可选) |
| `metadata` | object | ❌ | 响应元数据(可选) |
##### 模式 2: SSE 流式响应
**触发条件**: 请求头包含 `Accept: text/event-stream`
**响应头**:
```http
Content-Type: text/event-stream
Cache-Control: no-cache
Connection: keep-alive
```
**事件流格式**:
1. **增量消息事件** (可多次发送)
```
event: message
data: {"delta": "您好,"}
event: message
data: {"delta": "我们的产品"}
```
2. **最终结果事件** (发送一次后关闭连接)
```
event: final
data: {"reply": "完整回复内容", "confidence": 0.92, "shouldTransfer": false}
```
3. **错误事件** (发生错误时发送)
```
event: error
data: {"code": "INTERNAL_ERROR", "message": "错误描述"}
```
**事件序列保证**:
- `message*` (0 或多次) → `final` (1 次) → 连接关闭
- 或 `message*` (0 或多次) → `error` (1 次) → 连接关闭
#### 4.1.3 错误响应
**401 Unauthorized** - 认证失败
```json
{
"code": "UNAUTHORIZED",
"message": "Missing required header: X-API-Key",
"details": []
}
```
**400 Bad Request** - 请求参数错误
```json
{
"code": "INVALID_REQUEST",
"message": "缺少必填字段: sessionId",
"details": []
}
```
**400 Bad Request** - 租户 ID 格式错误
```json
{
"code": "INVALID_TENANT_ID",
"message": "Invalid tenant ID format. Expected: name@ash@year (e.g., szmp@ash@2026)",
"details": []
}
```
**500 Internal Server Error** - 服务内部错误
```json
{
"code": "INTERNAL_ERROR",
"message": "LLM 调用失败",
"details": []
}
```
**503 Service Unavailable** - 服务不可用
```json
{
"code": "SERVICE_UNAVAILABLE",
"message": "向量数据库连接失败",
"details": []
}
```
---
### 4.2 健康检查
**接口路径**: `GET /ai/health`
**功能描述**: 检查 AI 服务是否正常运行,用于服务监控和负载均衡健康探测。
#### 4.2.1 请求参数
无需请求参数,无需认证头。
#### 4.2.2 响应格式
**200 OK** - 服务正常
```json
{
"status": "healthy"
}
```
**503 Service Unavailable** - 服务不健康
```json
{
"status": "unhealthy"
}
```
---
## 5. 调用示例
### 5.1 Java 调用示例(非流式)
```java
import org.springframework.http.*;
import org.springframework.web.client.RestTemplate;
public class AIServiceClient {
private final RestTemplate restTemplate;
private final String aiServiceUrl = "http://ai-service:8080";
private final String apiKey = "your_api_key_here";
public ChatResponse generateReply(String tenantId, ChatRequest request) {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.set("X-API-Key", apiKey);
headers.set("X-Tenant-Id", tenantId);
HttpEntity<ChatRequest> entity = new HttpEntity<>(request, headers);
ResponseEntity<ChatResponse> response = restTemplate.postForEntity(
aiServiceUrl + "/ai/chat",
entity,
ChatResponse.class
);
return response.getBody();
}
}
```
### 5.2 Java 调用示例(流式)
```java
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
public class AIServiceStreamClient {
private final WebClient webClient;
private final String apiKey = "your_api_key_here";
public Flux<ServerSentEvent<String>> generateReplyStream(
String tenantId,
ChatRequest request
) {
return webClient.post()
.uri("/ai/chat")
.header("X-API-Key", apiKey)
.header("X-Tenant-Id", tenantId)
.header("Accept", "text/event-stream")
.bodyValue(request)
.retrieve()
.bodyToFlux(ServerSentEvent.class);
}
}
```
### 5.3 cURL 调用示例
```bash
# 非流式调用
curl -X POST http://ai-service:8080/ai/chat \
-H "Content-Type: application/json" \
-H "X-API-Key: your_api_key_here" \
-H "X-Tenant-Id: szmp@ash@2026" \
-d '{
"sessionId": "kf_001_wx123456_1708765432000",
"currentMessage": "我想了解产品价格",
"channelType": "wechat"
}'
# 流式调用
curl -X POST http://ai-service:8080/ai/chat \
-H "Content-Type: application/json" \
-H "X-API-Key: your_api_key_here" \
-H "X-Tenant-Id: szmp@ash@2026" \
-H "Accept: text/event-stream" \
-d '{
"sessionId": "kf_001_wx123456_1708765432000",
"currentMessage": "我想了解产品价格",
"channelType": "wechat"
}'
# 健康检查(无需认证)
curl http://ai-service:8080/ai/health
```
---
## 6. 业务逻辑说明
### 6.1 会话管理
- **会话标识**: `sessionId` 用于唯一标识一个对话会话
- **自动持久化**: AI 中台会自动保存会话历史,无需调用方每次传递完整历史
- **可选历史**: 调用方可通过 `history` 字段提供外部历史AI 中台会合并处理
- **租户隔离**: 相同 `sessionId` 在不同 `tenantId` 下视为不同会话
### 6.2 RAG 检索增强
- **自动触发**: AI 中台会根据用户问题自动判断是否需要检索知识库
- **多知识库**: 支持按知识库类型产品知识、FAQ、话术模板等分类检索
- **置信度评估**: 检索结果质量会影响 `confidence` 评分
- **兜底策略**: 检索失败或无结果时AI 会基于通用知识回答并降低置信度
### 6.3 转人工建议
`shouldTransfer` 字段由以下因素决定:
- ✅ 置信度低于阈值(默认 0.6
- ✅ 检索无结果或结果质量差
- ✅ 用户明确要求人工服务
- ✅ 意图识别命中"转人工"规则
**注意**: `shouldTransfer=true` 仅为建议最终是否转人工由调用方Java 主框架)决策。
### 6.4 意图识别与规则引擎
- **前置处理**: 用户消息会先经过意图识别
- **固定回复**: 命中固定规则时直接返回预设话术(跳过 LLM 调用)
- **话术流程**: 命中流程规则时进入多轮引导对话
- **定向检索**: 命中 RAG 规则时使用指定知识库检索
### 6.5 输出护栏
- **禁词过滤**: AI 回复会自动过滤禁词(竞品名称、敏感词等)
- **替换策略**: 支持星号替换、文本替换、整条拦截三种策略
- **行为约束**: Prompt 中注入行为规则(如"不承诺具体赔偿金额"
---
## 7. 性能与限制
### 7.1 性能指标
| 指标 | 非流式 | 流式 |
|-----|-------|------|
| 首字响应时间 | 1-3 秒 | 200-500 毫秒 |
| 完整响应时间 | 2-5 秒 | 3-8 秒 |
| 并发支持 | 100+ QPS | 50+ QPS |
### 7.2 限制说明
- **消息长度**: 单条消息最大 4000 字符
- **历史长度**: 建议历史消息不超过 20 轮AI 中台会自动截断)
- **超时设置**: 建议调用方设置 10 秒超时非流式、30 秒超时(流式)
- **重试策略**: 503 错误建议指数退避重试500 错误建议降级处理
---
## 8. 错误码参考
| 错误码 | HTTP 状态码 | 说明 | 处理建议 |
|-------|-----------|------|---------|
| `UNAUTHORIZED` | 401 | 认证失败(缺少或无效 API Key | 检查 X-API-Key 请求头 |
| `INVALID_REQUEST` | 400 | 请求参数错误 | 检查必填字段和参数格式 |
| `MISSING_TENANT_ID` | 400 | 缺少租户 ID | 添加 X-Tenant-Id 请求头 |
| `INVALID_TENANT_ID` | 400 | 租户 ID 格式错误 | 使用正确格式name@ash@year |
| `INTERNAL_ERROR` | 500 | 服务内部错误 | 降级处理或重试 |
| `LLM_ERROR` | 500 | LLM 调用失败 | 降级处理或重试 |
| `SERVICE_UNAVAILABLE` | 503 | 服务不可用 | 指数退避重试 |
| `QDRANT_ERROR` | 503 | 向量库不可用 | 指数退避重试 |
| `STREAMING_ERROR` | 200 (SSE) | 流式传输错误 | 关闭连接并重试 |
---
## 9. 最佳实践
### 9.1 API Key 管理
- API Key 由 AI 中台管理员通过管理后台分配
- 建议为不同环境(开发/测试/生产)使用不同的 API Key
- API Key 应存储在配置文件或环境变量中,不要硬编码
- 定期轮换 API Key 以提高安全性
### 9.2 会话 ID 生成规范
建议格式: `{业务前缀}_{租户ID}_{渠道用户ID}_{时间戳}`
示例: `kf_001_wx123456_1708765432000`
### 9.3 流式 vs 非流式选择
- **流式**: 适用于 Web/App 实时对话场景,用户体验更好
- **非流式**: 适用于批量处理、异步任务、API 集成场景
### 9.4 降级策略建议
```java
public ChatResponse generateReplyWithFallback(String tenantId, ChatRequest request) {
try {
return aiServiceClient.generateReply(tenantId, request);
} catch (ServiceUnavailableException e) {
// 降级策略 1: 返回固定话术
return ChatResponse.builder()
.reply("抱歉,当前咨询量较大,请稍后再试或转人工服务。")
.confidence(0.0)
.shouldTransfer(true)
.build();
} catch (Exception e) {
// 降级策略 2: 直接转人工
return ChatResponse.builder()
.reply("系统繁忙,正在为您转接人工客服...")
.confidence(0.0)
.shouldTransfer(true)
.transferReason("AI 服务异常")
.build();
}
}
```
### 9.5 监控指标建议
- ✅ 接口响应时间P50/P95/P99
- ✅ 接口成功率
- ✅ 置信度分布
- ✅ 转人工率
- ✅ 错误码分布
---
## 10. 变更日志
| 版本 | 日期 | 变更内容 |
|-----|------|---------|
| v1.1.0 | 2026-02-27 | 新增流式输出支持、意图识别、输出护栏 |
| v1.0.0 | 2026-02-20 | 初始版本,支持基础对话生成和健康检查 |
---
## 11. 联系方式
- **技术支持**: AI 中台开发团队
- **问题反馈**: 提交 Issue 到项目仓库
- **文档更新**: 参考 `spec/ai-service/openapi.provider.yaml`
---
**文档生成时间**: 2026-02-27
**契约版本**: v1.1.0
**维护状态**: ✅ 活跃维护

View File

@ -15,4 +15,5 @@ from app.api.admin.rag import router as rag_router
from app.api.admin.script_flows import router as script_flows_router
from app.api.admin.sessions import router as sessions_router
from app.api.admin.tenants import router as tenants_router
__all__ = ["api_key_router", "dashboard_router", "embedding_router", "guardrails_router", "intent_rules_router", "kb_router", "llm_router", "prompt_templates_router", "rag_router", "script_flows_router", "sessions_router", "tenants_router"]

View File

@ -0,0 +1,296 @@
"""
Guardrail Management API.
[AC-AISVC-78~AC-AISVC-85] Forbidden words and behavior rules CRUD endpoints.
"""
import logging
import uuid
from typing import Any
from fastapi import APIRouter, Depends, Header, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.models.entities import (
BehaviorRuleCreate,
BehaviorRuleUpdate,
ForbiddenWordCreate,
ForbiddenWordUpdate,
)
from app.services.guardrail.behavior_service import BehaviorRuleService
from app.services.guardrail.word_service import ForbiddenWordService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/guardrails", tags=["Guardrails"])
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
"""Extract tenant ID from header."""
if not x_tenant_id:
raise HTTPException(status_code=400, detail="X-Tenant-Id header is required")
return x_tenant_id
@router.get("/forbidden-words")
async def list_forbidden_words(
tenant_id: str = Depends(get_tenant_id),
category: str | None = None,
is_enabled: bool | None = None,
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""
[AC-AISVC-79] List all forbidden words for a tenant.
"""
logger.info(
f"[AC-AISVC-79] Listing forbidden words for tenant={tenant_id}, "
f"category={category}, is_enabled={is_enabled}"
)
service = ForbiddenWordService(session)
words = await service.list_words(tenant_id, category, is_enabled)
data = []
for word in words:
data.append(await service.word_to_info_dict(word))
return {"data": data}
@router.post("/forbidden-words", status_code=201)
async def create_forbidden_word(
body: ForbiddenWordCreate,
tenant_id: str = Depends(get_tenant_id),
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""
[AC-AISVC-78] Create a new forbidden word.
"""
valid_categories = ["competitor", "sensitive", "political", "custom"]
if body.category not in valid_categories:
raise HTTPException(
status_code=400,
detail=f"Invalid category. Must be one of: {valid_categories}"
)
valid_strategies = ["mask", "replace", "block"]
if body.strategy not in valid_strategies:
raise HTTPException(
status_code=400,
detail=f"Invalid strategy. Must be one of: {valid_strategies}"
)
if body.strategy == "replace" and not body.replacement:
raise HTTPException(
status_code=400,
detail="replacement is required when strategy is 'replace'"
)
logger.info(
f"[AC-AISVC-78] Creating forbidden word for tenant={tenant_id}, word={body.word}"
)
service = ForbiddenWordService(session)
try:
word = await service.create_word(tenant_id, body)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return await service.word_to_info_dict(word)
@router.get("/forbidden-words/{word_id}")
async def get_forbidden_word(
word_id: uuid.UUID,
tenant_id: str = Depends(get_tenant_id),
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""
[AC-AISVC-79] Get forbidden word detail.
"""
logger.info(f"[AC-AISVC-79] Getting forbidden word for tenant={tenant_id}, id={word_id}")
service = ForbiddenWordService(session)
word = await service.get_word(tenant_id, word_id)
if not word:
raise HTTPException(status_code=404, detail="Forbidden word not found")
return await service.word_to_info_dict(word)
@router.put("/forbidden-words/{word_id}")
async def update_forbidden_word(
word_id: uuid.UUID,
body: ForbiddenWordUpdate,
tenant_id: str = Depends(get_tenant_id),
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""
[AC-AISVC-80] Update a forbidden word.
"""
valid_categories = ["competitor", "sensitive", "political", "custom"]
if body.category is not None and body.category not in valid_categories:
raise HTTPException(
status_code=400,
detail=f"Invalid category. Must be one of: {valid_categories}"
)
valid_strategies = ["mask", "replace", "block"]
if body.strategy is not None and body.strategy not in valid_strategies:
raise HTTPException(
status_code=400,
detail=f"Invalid strategy. Must be one of: {valid_strategies}"
)
logger.info(f"[AC-AISVC-80] Updating forbidden word for tenant={tenant_id}, id={word_id}")
service = ForbiddenWordService(session)
try:
word = await service.update_word(tenant_id, word_id, body)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
if not word:
raise HTTPException(status_code=404, detail="Forbidden word not found")
return await service.word_to_info_dict(word)
@router.delete("/forbidden-words/{word_id}", status_code=204)
async def delete_forbidden_word(
word_id: uuid.UUID,
tenant_id: str = Depends(get_tenant_id),
session: AsyncSession = Depends(get_session),
) -> None:
"""
[AC-AISVC-81] Delete a forbidden word.
"""
logger.info(f"[AC-AISVC-81] Deleting forbidden word for tenant={tenant_id}, id={word_id}")
service = ForbiddenWordService(session)
success = await service.delete_word(tenant_id, word_id)
if not success:
raise HTTPException(status_code=404, detail="Forbidden word not found")
@router.get("/behavior-rules")
async def list_behavior_rules(
tenant_id: str = Depends(get_tenant_id),
category: str | None = None,
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""
[AC-AISVC-85] List all behavior rules for a tenant.
"""
logger.info(
f"[AC-AISVC-85] Listing behavior rules for tenant={tenant_id}, category={category}"
)
service = BehaviorRuleService(session)
rules = await service.list_rules(tenant_id, category)
data = []
for rule in rules:
data.append(await service.rule_to_info_dict(rule))
return {"data": data}
@router.post("/behavior-rules", status_code=201)
async def create_behavior_rule(
body: BehaviorRuleCreate,
tenant_id: str = Depends(get_tenant_id),
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""
[AC-AISVC-84] Create a new behavior rule.
"""
valid_categories = ["compliance", "tone", "boundary", "custom"]
if body.category not in valid_categories:
raise HTTPException(
status_code=400,
detail=f"Invalid category. Must be one of: {valid_categories}"
)
logger.info(
f"[AC-AISVC-84] Creating behavior rule for tenant={tenant_id}, category={body.category}"
)
service = BehaviorRuleService(session)
try:
rule = await service.create_rule(tenant_id, body)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return await service.rule_to_info_dict(rule)
@router.get("/behavior-rules/{rule_id}")
async def get_behavior_rule(
rule_id: uuid.UUID,
tenant_id: str = Depends(get_tenant_id),
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""
[AC-AISVC-85] Get behavior rule detail.
"""
logger.info(f"[AC-AISVC-85] Getting behavior rule for tenant={tenant_id}, id={rule_id}")
service = BehaviorRuleService(session)
rule = await service.get_rule(tenant_id, rule_id)
if not rule:
raise HTTPException(status_code=404, detail="Behavior rule not found")
return await service.rule_to_info_dict(rule)
@router.put("/behavior-rules/{rule_id}")
async def update_behavior_rule(
rule_id: uuid.UUID,
body: BehaviorRuleUpdate,
tenant_id: str = Depends(get_tenant_id),
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""
[AC-AISVC-85] Update a behavior rule.
"""
valid_categories = ["compliance", "tone", "boundary", "custom"]
if body.category is not None and body.category not in valid_categories:
raise HTTPException(
status_code=400,
detail=f"Invalid category. Must be one of: {valid_categories}"
)
logger.info(f"[AC-AISVC-85] Updating behavior rule for tenant={tenant_id}, id={rule_id}")
service = BehaviorRuleService(session)
try:
rule = await service.update_rule(tenant_id, rule_id, body)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
if not rule:
raise HTTPException(status_code=404, detail="Behavior rule not found")
return await service.rule_to_info_dict(rule)
@router.delete("/behavior-rules/{rule_id}", status_code=204)
async def delete_behavior_rule(
rule_id: uuid.UUID,
tenant_id: str = Depends(get_tenant_id),
session: AsyncSession = Depends(get_session),
) -> None:
"""
[AC-AISVC-85] Delete a behavior rule.
"""
logger.info(f"[AC-AISVC-85] Deleting behavior rule for tenant={tenant_id}, id={rule_id}")
service = BehaviorRuleService(session)
success = await service.delete_rule(tenant_id, rule_id)
if not success:
raise HTTPException(status_code=404, detail="Behavior rule not found")

View File

@ -12,7 +12,20 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.api import chat_router, health_router
from app.api.admin import api_key_router, dashboard_router, embedding_router, guardrails_router, intent_rules_router, kb_router, llm_router, prompt_templates_router, rag_router, script_flows_router, sessions_router, tenants_router
from app.api.admin import (
api_key_router,
dashboard_router,
embedding_router,
guardrails_router,
intent_rules_router,
kb_router,
llm_router,
prompt_templates_router,
rag_router,
script_flows_router,
sessions_router,
tenants_router,
)
from app.api.admin.kb_optimized import router as kb_optimized_router
from app.core.config import get_settings
from app.core.database import close_db, init_db
@ -76,12 +89,12 @@ app = FastAPI(
version=settings.app_version,
description="""
Python AI Service for intelligent chat with RAG support.
## Features
- Multi-tenant isolation via X-Tenant-Id header
- SSE streaming support via Accept: text/event-stream
- RAG-powered responses with confidence scoring
## Response Modes
- **JSON**: Default response mode (Accept: application/json or no Accept header)
- **SSE Streaming**: Set Accept: text/event-stream for streaming responses
@ -130,6 +143,7 @@ app.include_router(chat_router)
app.include_router(api_key_router)
app.include_router(dashboard_router)
app.include_router(embedding_router)
app.include_router(guardrails_router)
app.include_router(intent_rules_router)
app.include_router(kb_router)
app.include_router(kb_optimized_router)

View File

@ -8,7 +8,7 @@ from datetime import datetime
from enum import Enum
from typing import Any
from sqlalchemy import Column, JSON
from sqlalchemy import JSON, Column
from sqlmodel import Field, Index, SQLModel
@ -141,7 +141,10 @@ class KnowledgeBase(SQLModel, table=True):
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
name: str = Field(..., description="Knowledge base name")
kb_type: str = Field(default=KBType.GENERAL.value, description="Knowledge base type: product/faq/script/policy/general")
kb_type: str = Field(
default=KBType.GENERAL.value,
description="Knowledge base type: product/faq/script/policy/general"
)
description: str | None = Field(default=None, description="Knowledge base description")
priority: int = Field(default=0, ge=0, description="Priority weight, higher value means higher priority")
is_enabled: bool = Field(default=True, description="Whether the knowledge base is enabled")
@ -289,14 +292,25 @@ class PromptTemplateVersion(SQLModel, table=True):
)
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
template_id: uuid.UUID = Field(..., description="Foreign key to prompt_templates.id", foreign_key="prompt_templates.id", index=True)
template_id: uuid.UUID = Field(
...,
description="Foreign key to prompt_templates.id",
foreign_key="prompt_templates.id",
index=True
)
version: int = Field(..., description="Version number (auto-incremented per template)")
status: str = Field(default=TemplateVersionStatus.DRAFT.value, description="Version status: draft/published/archived")
system_instruction: str = Field(..., description="System instruction content with {{variable}} placeholders")
status: str = Field(
default=TemplateVersionStatus.DRAFT.value,
description="Version status: draft/published/archived"
)
system_instruction: str = Field(
...,
description="System instruction content with {{variable}} placeholders"
)
variables: list[dict[str, Any]] | None = Field(
default=None,
sa_column=Column("variables", JSON, nullable=True),
description="Variable definitions, e.g., [{'name': 'persona_name', 'default': '小N', 'description': '人设名称'}]"
description="Variable definitions, e.g., [{'name': 'persona_name', 'default': '小N'}]"
)
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
@ -510,7 +524,10 @@ class BehaviorRule(SQLModel, table=True):
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
rule_text: str = Field(..., description="Behavior constraint description, e.g., 'Do not promise specific compensation amounts'")
rule_text: str = Field(
...,
description="Behavior constraint description, e.g., 'Do not promise specific compensation'"
)
category: str = Field(..., description="Category: compliance/tone/boundary/custom")
is_enabled: bool = Field(default=True, description="Whether the rule is enabled")
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
@ -618,7 +635,7 @@ class ScriptFlow(SQLModel, table=True):
steps: list[dict[str, Any]] = Field(
default=[],
sa_column=Column("steps", JSON, nullable=False),
description="Flow steps list with step_no, content, wait_input, timeout_seconds, timeout_action, next_conditions, default_next"
description="Flow steps list with step_no, content, wait_input, timeout_seconds"
)
is_enabled: bool = Field(default=True, description="Whether the flow is enabled")
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
@ -640,13 +657,21 @@ class FlowInstance(SQLModel, table=True):
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
session_id: str = Field(..., description="Session ID for conversation tracking", index=True)
flow_id: uuid.UUID = Field(..., description="Foreign key to script_flows.id", foreign_key="script_flows.id", index=True)
flow_id: uuid.UUID = Field(
...,
description="Foreign key to script_flows.id",
foreign_key="script_flows.id",
index=True
)
current_step: int = Field(default=1, ge=1, description="Current step number (1-indexed)")
status: str = Field(default=FlowInstanceStatus.ACTIVE.value, description="Instance status: active/completed/timeout/cancelled")
status: str = Field(
default=FlowInstanceStatus.ACTIVE.value,
description="Instance status: active/completed/timeout/cancelled"
)
context: dict[str, Any] | None = Field(
default=None,
sa_column=Column("context", JSON, nullable=True),
description="Flow execution context, stores user inputs etc."
description="Flow execution context, stores user inputs"
)
started_at: datetime = Field(default_factory=datetime.utcnow, description="Instance start time")
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
@ -660,10 +685,13 @@ class FlowStep(SQLModel):
content: str = Field(..., description="Script content for this step")
wait_input: bool = Field(default=True, description="Whether to wait for user input")
timeout_seconds: int = Field(default=120, ge=1, description="Timeout in seconds")
timeout_action: str = Field(default=TimeoutAction.REPEAT.value, description="Action on timeout: repeat/skip/transfer")
timeout_action: str = Field(
default=TimeoutAction.REPEAT.value,
description="Action on timeout: repeat/skip/transfer"
)
next_conditions: list[dict[str, Any]] | None = Field(
default=None,
description="Conditions for next step: [{'keywords': [...], 'goto_step': N}, {'pattern': '...', 'goto_step': N}]"
description="Conditions for next step: [{'keywords': [...], 'goto_step': N}]"
)
default_next: int | None = Field(default=None, description="Default next step if no condition matches")

View File

@ -0,0 +1,18 @@
"""
Guardrail services for AI Service.
[AC-AISVC-78~AC-AISVC-85] Output guardrail with forbidden word detection and behavior rules.
"""
from app.services.guardrail.behavior_service import BehaviorRuleService
from app.services.guardrail.input_scanner import InputScanner
from app.services.guardrail.output_filter import OutputFilter
from app.services.guardrail.streaming_filter import StreamingGuardrail
from app.services.guardrail.word_service import ForbiddenWordService
__all__ = [
"ForbiddenWordService",
"BehaviorRuleService",
"InputScanner",
"OutputFilter",
"StreamingGuardrail",
]

View File

@ -0,0 +1,260 @@
"""
Behavior rule service for AI Service.
[AC-AISVC-84, AC-AISVC-85] Behavior rule CRUD management.
"""
import logging
import time
import uuid
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col
from app.models.entities import (
BehaviorRule,
BehaviorRuleCategory,
BehaviorRuleCreate,
BehaviorRuleUpdate,
)
logger = logging.getLogger(__name__)
BEHAVIOR_CACHE_TTL_SECONDS = 60
class BehaviorRuleCache:
"""
[AC-AISVC-84] In-memory cache for behavior rules.
Key: tenant_id
Value: (rules_list, cached_at)
TTL: 60 seconds
"""
def __init__(self, ttl_seconds: int = BEHAVIOR_CACHE_TTL_SECONDS):
self._cache: dict[str, tuple[list[BehaviorRule], float]] = {}
self._ttl = ttl_seconds
def get(self, tenant_id: str) -> list[BehaviorRule] | None:
"""Get cached rules if not expired."""
if tenant_id in self._cache:
rules, cached_at = self._cache[tenant_id]
if time.time() - cached_at < self._ttl:
return rules
else:
del self._cache[tenant_id]
return None
def set(self, tenant_id: str, rules: list[BehaviorRule]) -> None:
"""Cache rules for a tenant."""
self._cache[tenant_id] = (rules, time.time())
def invalidate(self, tenant_id: str) -> None:
"""Invalidate cache for a tenant."""
if tenant_id in self._cache:
del self._cache[tenant_id]
logger.debug(f"Invalidated behavior rule cache for tenant={tenant_id}")
_behavior_cache = BehaviorRuleCache()
class BehaviorRuleService:
"""
[AC-AISVC-84, AC-AISVC-85] Service for managing behavior rules.
Features:
- Rule CRUD with tenant isolation
- In-memory caching with TTL
- Cache invalidation on CRUD operations
- Rules are injected into Prompt system instruction
"""
VALID_CATEGORIES = [c.value for c in BehaviorRuleCategory]
def __init__(self, session: AsyncSession):
self._session = session
self._cache = _behavior_cache
async def create_rule(
self,
tenant_id: str,
create_data: BehaviorRuleCreate,
) -> BehaviorRule:
"""
[AC-AISVC-84] Create a new behavior rule.
"""
if create_data.category not in self.VALID_CATEGORIES:
raise ValueError(f"Invalid category. Must be one of: {self.VALID_CATEGORIES}")
rule = BehaviorRule(
tenant_id=tenant_id,
rule_text=create_data.rule_text,
category=create_data.category,
is_enabled=True,
)
self._session.add(rule)
await self._session.flush()
self._cache.invalidate(tenant_id)
logger.info(
f"[AC-AISVC-84] Created behavior rule: tenant={tenant_id}, "
f"id={rule.id}, category={rule.category}"
)
return rule
async def list_rules(
self,
tenant_id: str,
category: str | None = None,
) -> Sequence[BehaviorRule]:
"""
[AC-AISVC-85] List rules for a tenant with optional filters.
"""
stmt = select(BehaviorRule).where(
BehaviorRule.tenant_id == tenant_id # type: ignore
)
if category is not None:
stmt = stmt.where(BehaviorRule.category == category) # type: ignore
stmt = stmt.order_by(col(BehaviorRule.category), col(BehaviorRule.created_at).desc())
result = await self._session.execute(stmt)
return result.scalars().all()
async def get_rule(
self,
tenant_id: str,
rule_id: uuid.UUID,
) -> BehaviorRule | None:
"""
[AC-AISVC-85] Get rule by ID with tenant isolation.
"""
stmt = select(BehaviorRule).where(
BehaviorRule.tenant_id == tenant_id, # type: ignore
BehaviorRule.id == rule_id, # type: ignore
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
async def update_rule(
self,
tenant_id: str,
rule_id: uuid.UUID,
update_data: BehaviorRuleUpdate,
) -> BehaviorRule | None:
"""
[AC-AISVC-85] Update a behavior rule.
"""
rule = await self.get_rule(tenant_id, rule_id)
if not rule:
return None
if update_data.rule_text is not None:
rule.rule_text = update_data.rule_text
if update_data.category is not None:
if update_data.category not in self.VALID_CATEGORIES:
raise ValueError(f"Invalid category. Must be one of: {self.VALID_CATEGORIES}")
rule.category = update_data.category
if update_data.is_enabled is not None:
rule.is_enabled = update_data.is_enabled
rule.updated_at = datetime.utcnow()
await self._session.flush()
self._cache.invalidate(tenant_id)
logger.info(
f"[AC-AISVC-85] Updated behavior rule: tenant={tenant_id}, id={rule_id}"
)
return rule
async def delete_rule(
self,
tenant_id: str,
rule_id: uuid.UUID,
) -> bool:
"""
[AC-AISVC-85] Delete a behavior rule.
"""
rule = await self.get_rule(tenant_id, rule_id)
if not rule:
return False
await self._session.delete(rule)
await self._session.flush()
self._cache.invalidate(tenant_id)
logger.info(
f"[AC-AISVC-85] Deleted behavior rule: tenant={tenant_id}, id={rule_id}"
)
return True
async def get_enabled_rules_for_injection(
self,
tenant_id: str,
) -> list[BehaviorRule]:
"""
[AC-AISVC-84] Get enabled rules for Prompt injection.
Uses cache for performance.
"""
cached = self._cache.get(tenant_id)
if cached is not None:
logger.debug(f"[AC-AISVC-84] Cache hit for behavior rules: tenant={tenant_id}")
return cached
stmt = (
select(BehaviorRule)
.where(
BehaviorRule.tenant_id == tenant_id, # type: ignore
BehaviorRule.is_enabled == True, # type: ignore
)
.order_by(col(BehaviorRule.category))
)
result = await self._session.execute(stmt)
rules = list(result.scalars().all())
self._cache.set(tenant_id, rules)
logger.info(
f"[AC-AISVC-84] Loaded {len(rules)} enabled behavior rules from DB: tenant={tenant_id}"
)
return rules
def invalidate_cache(self, tenant_id: str) -> None:
"""Manually invalidate cache for a tenant."""
self._cache.invalidate(tenant_id)
async def rule_to_info_dict(self, rule: BehaviorRule) -> dict[str, Any]:
"""Convert rule entity to API response dict."""
return {
"id": str(rule.id),
"rule_text": rule.rule_text,
"category": rule.category,
"is_enabled": rule.is_enabled,
"created_at": rule.created_at.isoformat(),
"updated_at": rule.updated_at.isoformat(),
}
async def format_rules_for_prompt(
self,
tenant_id: str,
) -> str:
"""
[AC-AISVC-84] Format behavior rules for Prompt injection.
Returns formatted string to append to system instruction.
"""
rules = await self.get_enabled_rules_for_injection(tenant_id)
if not rules:
return ""
lines = ["\n\n[行为约束 - 以下规则必须严格遵守]"]
for i, rule in enumerate(rules, 1):
lines.append(f"{i}. {rule.rule_text}")
return "\n".join(lines)

View File

@ -0,0 +1,110 @@
"""
Input scanner for AI Service.
[AC-AISVC-83] User input pre-detection (logging only, no blocking).
"""
import logging
from typing import Any
from app.models.entities import (
ForbiddenWord,
InputScanResult,
)
from app.services.guardrail.word_service import ForbiddenWordService
logger = logging.getLogger(__name__)
class InputScanner:
"""
[AC-AISVC-83] Input scanner for pre-detection of forbidden words.
Features:
- Scans user input for forbidden words
- Records matched words and categories in metadata
- Does NOT block the request (only logging)
- Used for monitoring and analytics
"""
def __init__(self, word_service: ForbiddenWordService):
self._word_service = word_service
async def scan(
self,
text: str,
tenant_id: str,
) -> InputScanResult:
"""
[AC-AISVC-83] Scan user input for forbidden words.
Args:
text: User input text to scan
tenant_id: Tenant ID for isolation
Returns:
InputScanResult with flagged status and matched words
"""
if not text or not text.strip():
return InputScanResult(flagged=False)
words = await self._word_service.get_enabled_words_for_filtering(tenant_id)
if not words:
return InputScanResult(flagged=False)
matched_words: list[str] = []
matched_categories: list[str] = []
matched_word_entities: list[ForbiddenWord] = []
for word in words:
if word.word in text:
matched_words.append(word.word)
if word.category not in matched_categories:
matched_categories.append(word.category)
matched_word_entities.append(word)
if matched_words:
logger.info(
f"[AC-AISVC-83] Input flagged: tenant={tenant_id}, "
f"matched_words={matched_words}, categories={matched_categories}"
)
for word_entity in matched_word_entities:
try:
await self._word_service.increment_hit_count(tenant_id, word_entity.id)
except Exception as e:
logger.warning(
f"Failed to increment hit count for word {word_entity.id}: {e}"
)
return InputScanResult(
flagged=len(matched_words) > 0,
matched_words=matched_words,
matched_categories=matched_categories,
)
async def scan_and_enrich_metadata(
self,
text: str,
tenant_id: str,
metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
[AC-AISVC-83] Scan input and enrich metadata with scan result.
Args:
text: User input text to scan
tenant_id: Tenant ID for isolation
metadata: Existing metadata dict to enrich
Returns:
Enriched metadata with input_flagged and matched info
"""
result = await self.scan(text, tenant_id)
if metadata is None:
metadata = {}
metadata.update(result.to_dict())
return metadata

View File

@ -0,0 +1,154 @@
"""
Output filter for AI Service.
[AC-AISVC-82] LLM output post-filtering with mask/replace/block strategies.
"""
import logging
from typing import Any
from app.models.entities import (
ForbiddenWord,
ForbiddenWordStrategy,
GuardrailResult,
)
from app.services.guardrail.word_service import ForbiddenWordService
logger = logging.getLogger(__name__)
class OutputFilter:
"""
[AC-AISVC-82] Output filter for post-filtering LLM responses.
Features:
- Scans LLM output for forbidden words
- Applies mask/replace/block strategies
- Returns fallback reply for block strategy
- Records triggered words in metadata
"""
DEFAULT_FALLBACK_REPLY = "抱歉,让我换个方式回答您"
def __init__(self, word_service: ForbiddenWordService):
self._word_service = word_service
async def filter(
self,
reply: str,
tenant_id: str,
) -> GuardrailResult:
"""
[AC-AISVC-82] Filter LLM output for forbidden words.
Args:
reply: LLM generated reply to filter
tenant_id: Tenant ID for isolation
Returns:
GuardrailResult with filtered reply and trigger info
"""
if not reply or not reply.strip():
return GuardrailResult(reply=reply)
words = await self._word_service.get_enabled_words_for_filtering(tenant_id)
if not words:
return GuardrailResult(reply=reply)
triggered_words: list[str] = []
triggered_categories: list[str] = []
filtered_reply = reply
blocked = False
fallback_reply = self.DEFAULT_FALLBACK_REPLY
for word in words:
if word.word in filtered_reply:
triggered_words.append(word.word)
if word.category not in triggered_categories:
triggered_categories.append(word.category)
if word.strategy == ForbiddenWordStrategy.BLOCK.value:
blocked = True
fallback_reply = word.fallback_reply or self.DEFAULT_FALLBACK_REPLY
logger.warning(
f"[AC-AISVC-82] Output blocked by forbidden word: tenant={tenant_id}, "
f"word={word.word}, category={word.category}"
)
break
elif word.strategy == ForbiddenWordStrategy.MASK.value:
filtered_reply = filtered_reply.replace(word.word, "*" * len(word.word))
logger.info(
f"[AC-AISVC-82] Output masked: tenant={tenant_id}, word={word.word}"
)
elif word.strategy == ForbiddenWordStrategy.REPLACE.value:
replacement = word.replacement or ""
filtered_reply = filtered_reply.replace(word.word, replacement)
logger.info(
f"[AC-AISVC-82] Output replaced: tenant={tenant_id}, "
f"word={word.word} -> {replacement}"
)
if blocked:
return GuardrailResult(
reply=fallback_reply,
blocked=True,
triggered_words=triggered_words,
triggered_categories=triggered_categories,
)
if triggered_words:
logger.info(
f"[AC-AISVC-82] Output filtered: tenant={tenant_id}, "
f"triggered_words={triggered_words}, categories={triggered_categories}"
)
for word_entity in self._get_triggered_word_entities(words, triggered_words):
try:
await self._word_service.increment_hit_count(tenant_id, word_entity.id)
except Exception as e:
logger.warning(
f"Failed to increment hit count for word {word_entity.id}: {e}"
)
return GuardrailResult(
reply=filtered_reply,
blocked=False,
triggered_words=triggered_words,
triggered_categories=triggered_categories,
)
def _get_triggered_word_entities(
self,
words: list[ForbiddenWord],
triggered_words: list[str],
) -> list[ForbiddenWord]:
"""Get word entities for triggered words."""
return [w for w in words if w.word in triggered_words]
async def filter_and_enrich_metadata(
self,
reply: str,
tenant_id: str,
metadata: dict[str, Any] | None = None,
) -> tuple[str, dict[str, Any]]:
"""
[AC-AISVC-82] Filter output and enrich metadata with filter result.
Args:
reply: LLM generated reply to filter
tenant_id: Tenant ID for isolation
metadata: Existing metadata dict to enrich
Returns:
Tuple of (filtered_reply, enriched_metadata)
"""
result = await self.filter(reply, tenant_id)
if metadata is None:
metadata = {}
metadata.update(result.to_dict())
return result.reply, metadata

View File

@ -0,0 +1,366 @@
"""
Streaming guardrail for AI Service.
[AC-AISVC-82] Streaming mode forbidden word detection with sliding window buffer.
"""
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from app.models.entities import (
ForbiddenWord,
ForbiddenWordStrategy,
)
from app.services.guardrail.word_service import ForbiddenWordService
logger = logging.getLogger(__name__)
class StreamingGuardrailState(str, Enum):
"""State of streaming guardrail."""
ACTIVE = "active"
BLOCKED = "blocked"
COMPLETED = "completed"
@dataclass
class StreamingGuardrailResult:
"""Result from streaming guardrail processing."""
delta: str
should_stop: bool = False
fallback_reply: str | None = None
triggered_words: list[str] = field(default_factory=list)
triggered_categories: list[str] = field(default_factory=list)
class StreamingGuardrail:
"""
[AC-AISVC-82] Streaming guardrail with sliding window buffer.
Features:
- Maintains a sliding window buffer for forbidden word detection
- Detects forbidden words across chunk boundaries
- Applies mask/replace strategies incrementally
- Stops streaming and returns fallback for block strategy
- Final check at stream end
"""
DEFAULT_FALLBACK_REPLY = "抱歉,让我换个方式回答您"
DEFAULT_WINDOW_SIZE = 50
def __init__(
self,
word_service: ForbiddenWordService,
window_size: int = DEFAULT_WINDOW_SIZE,
):
self._word_service = word_service
self._window_size = window_size
self._buffer = ""
self._state = StreamingGuardrailState.ACTIVE
self._triggered_words: list[str] = []
self._triggered_categories: list[str] = []
self._fallback_reply = self.DEFAULT_FALLBACK_REPLY
self._words_cache: list[ForbiddenWord] | None = None
self._max_word_length = 0
async def initialize(self, tenant_id: str) -> None:
"""
Initialize the guardrail with tenant's forbidden words.
Must be called before processing chunks.
"""
self._words_cache = await self._word_service.get_enabled_words_for_filtering(tenant_id)
if self._words_cache:
self._max_word_length = max(len(w.word) for w in self._words_cache)
effective_window = max(self._window_size, self._max_word_length)
if effective_window > self._window_size:
logger.info(
f"[AC-AISVC-82] Adjusted window size from {self._window_size} to {effective_window} "
f"to accommodate max word length {self._max_word_length}"
)
self._window_size = effective_window
logger.debug(
f"[AC-AISVC-82] StreamingGuardrail initialized: "
f"words_count={len(self._words_cache) if self._words_cache else 0}, "
f"window_size={self._window_size}"
)
async def process_chunk(
self,
delta: str,
tenant_id: str,
) -> StreamingGuardrailResult:
"""
Process a streaming chunk with sliding window detection.
Args:
delta: New text chunk from LLM
tenant_id: Tenant ID for isolation
Returns:
StreamingGuardrailResult with processed delta and state
"""
if self._state == StreamingGuardrailState.BLOCKED:
return StreamingGuardrailResult(
delta="",
should_stop=True,
fallback_reply=self._fallback_reply,
)
if self._state == StreamingGuardrailState.COMPLETED:
return StreamingGuardrailResult(delta="")
if not self._words_cache:
await self.initialize(tenant_id)
if not self._words_cache:
return StreamingGuardrailResult(delta=delta)
self._buffer += delta
result = self._check_buffer()
if result.should_stop:
self._state = StreamingGuardrailState.BLOCKED
logger.warning(
f"[AC-AISVC-82] Streaming blocked: tenant={tenant_id}, "
f"triggered_words={result.triggered_words}"
)
return result
def _check_buffer(self) -> StreamingGuardrailResult:
"""
Check buffer for forbidden words and apply strategies.
"""
triggered_words: list[str] = []
triggered_categories: list[str] = []
filtered_buffer = self._buffer
blocked = False
fallback_reply = self.DEFAULT_FALLBACK_REPLY
if not self._words_cache:
return StreamingGuardrailResult(delta=self._buffer)
for word in self._words_cache:
if word.word in filtered_buffer:
triggered_words.append(word.word)
if word.category not in triggered_categories:
triggered_categories.append(word.category)
if word.strategy == ForbiddenWordStrategy.BLOCK.value:
blocked = True
fallback_reply = word.fallback_reply or self.DEFAULT_FALLBACK_REPLY
break
elif word.strategy == ForbiddenWordStrategy.MASK.value:
filtered_buffer = filtered_buffer.replace(word.word, "*" * len(word.word))
elif word.strategy == ForbiddenWordStrategy.REPLACE.value:
replacement = word.replacement or ""
filtered_buffer = filtered_buffer.replace(word.word, replacement)
self._triggered_words.extend([w for w in triggered_words if w not in self._triggered_words])
self._triggered_categories.extend([c for c in triggered_categories if c not in self._triggered_categories])
if blocked:
return StreamingGuardrailResult(
delta="",
should_stop=True,
fallback_reply=fallback_reply,
triggered_words=triggered_words,
triggered_categories=triggered_categories,
)
safe_output, remaining = self._split_buffer(filtered_buffer)
self._buffer = remaining
return StreamingGuardrailResult(
delta=safe_output,
should_stop=False,
triggered_words=triggered_words,
triggered_categories=triggered_categories,
)
def _split_buffer(self, buffer: str) -> tuple[str, str]:
"""
Split buffer into safe output and remaining buffer.
Safe output: characters that are definitely safe (before window boundary)
Remaining: characters that might be part of a forbidden word
"""
if len(buffer) <= self._window_size:
return "", buffer
safe_length = len(buffer) - self._window_size
safe_output = buffer[:safe_length]
remaining = buffer[safe_length:]
return safe_output, remaining
async def finalize(self, tenant_id: str) -> StreamingGuardrailResult:
"""
Finalize the stream and process remaining buffer.
Must be called at the end of streaming.
Args:
tenant_id: Tenant ID for isolation
Returns:
StreamingGuardrailResult with final delta
"""
if self._state == StreamingGuardrailState.BLOCKED:
self._state = StreamingGuardrailState.COMPLETED
return StreamingGuardrailResult(
delta="",
should_stop=True,
fallback_reply=self._fallback_reply,
triggered_words=self._triggered_words,
triggered_categories=self._triggered_categories,
)
if not self._words_cache:
await self.initialize(tenant_id)
result = self._final_check()
self._state = StreamingGuardrailState.COMPLETED
if self._triggered_words:
logger.info(
f"[AC-AISVC-82] Streaming finalized with triggers: tenant={tenant_id}, "
f"triggered_words={self._triggered_words}"
)
for word_entity in self._get_triggered_word_entities():
try:
await self._word_service.increment_hit_count(tenant_id, word_entity.id)
except Exception as e:
logger.warning(
f"Failed to increment hit count for word {word_entity.id}: {e}"
)
return result
def _final_check(self) -> StreamingGuardrailResult:
"""
Final check of remaining buffer.
"""
if not self._buffer:
return StreamingGuardrailResult(
delta="",
triggered_words=self._triggered_words,
triggered_categories=self._triggered_categories,
)
if not self._words_cache:
return StreamingGuardrailResult(
delta=self._buffer,
triggered_words=self._triggered_words,
triggered_categories=self._triggered_categories,
)
filtered_buffer = self._buffer
blocked = False
fallback_reply = self.DEFAULT_FALLBACK_REPLY
for word in self._words_cache:
if word.word in filtered_buffer:
if word.word not in self._triggered_words:
self._triggered_words.append(word.word)
if word.category not in self._triggered_categories:
self._triggered_categories.append(word.category)
if word.strategy == ForbiddenWordStrategy.BLOCK.value:
blocked = True
fallback_reply = word.fallback_reply or self.DEFAULT_FALLBACK_REPLY
break
elif word.strategy == ForbiddenWordStrategy.MASK.value:
filtered_buffer = filtered_buffer.replace(word.word, "*" * len(word.word))
elif word.strategy == ForbiddenWordStrategy.REPLACE.value:
replacement = word.replacement or ""
filtered_buffer = filtered_buffer.replace(word.word, replacement)
if blocked:
return StreamingGuardrailResult(
delta="",
should_stop=True,
fallback_reply=fallback_reply,
triggered_words=self._triggered_words,
triggered_categories=self._triggered_categories,
)
self._buffer = ""
return StreamingGuardrailResult(
delta=filtered_buffer,
triggered_words=self._triggered_words,
triggered_categories=self._triggered_categories,
)
def _get_triggered_word_entities(self) -> list[ForbiddenWord]:
"""Get word entities for triggered words."""
if not self._words_cache:
return []
return [w for w in self._words_cache if w.word in self._triggered_words]
def get_triggered_info(self) -> dict[str, Any]:
"""Get triggered words and categories info."""
return {
"triggered_words": self._triggered_words,
"triggered_categories": self._triggered_categories,
"guardrail_triggered": len(self._triggered_words) > 0,
}
def reset(self) -> None:
"""Reset the guardrail state for reuse."""
self._buffer = ""
self._state = StreamingGuardrailState.ACTIVE
self._triggered_words = []
self._triggered_categories = []
self._fallback_reply = self.DEFAULT_FALLBACK_REPLY
self._words_cache = None
self._max_word_length = 0
async def wrap_stream_with_guardrail(
stream: AsyncIterator[str],
guardrail: StreamingGuardrail,
tenant_id: str,
) -> AsyncIterator[tuple[str, bool, str | None]]:
"""
Wrap an async stream with guardrail processing.
Args:
stream: Original LLM output stream
guardrail: StreamingGuardrail instance
tenant_id: Tenant ID for isolation
Yields:
Tuple of (delta, should_stop, fallback_reply)
"""
await guardrail.initialize(tenant_id)
async for delta in stream:
result = await guardrail.process_chunk(delta, tenant_id)
if result.delta:
yield (result.delta, False, None)
if result.should_stop:
yield ("", True, result.fallback_reply)
return
final_result = await guardrail.finalize(tenant_id)
if final_result.delta:
yield (final_result.delta, False, None)
if final_result.should_stop:
yield ("", True, final_result.fallback_reply)

View File

@ -0,0 +1,297 @@
"""
Forbidden word service for AI Service.
[AC-AISVC-78~AC-AISVC-81] Forbidden word CRUD and hit statistics management.
"""
import logging
import time
import uuid
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col
from app.models.entities import (
ForbiddenWord,
ForbiddenWordCategory,
ForbiddenWordCreate,
ForbiddenWordStrategy,
ForbiddenWordUpdate,
)
logger = logging.getLogger(__name__)
WORD_CACHE_TTL_SECONDS = 60
class WordCache:
"""
[AC-AISVC-82] In-memory cache for forbidden words.
Key: tenant_id
Value: (words_list, cached_at)
TTL: 60 seconds
"""
def __init__(self, ttl_seconds: int = WORD_CACHE_TTL_SECONDS):
self._cache: dict[str, tuple[list[ForbiddenWord], float]] = {}
self._ttl = ttl_seconds
def get(self, tenant_id: str) -> list[ForbiddenWord] | None:
"""Get cached words if not expired."""
if tenant_id in self._cache:
words, cached_at = self._cache[tenant_id]
if time.time() - cached_at < self._ttl:
return words
else:
del self._cache[tenant_id]
return None
def set(self, tenant_id: str, words: list[ForbiddenWord]) -> None:
"""Cache words for a tenant."""
self._cache[tenant_id] = (words, time.time())
def invalidate(self, tenant_id: str) -> None:
"""Invalidate cache for a tenant."""
if tenant_id in self._cache:
del self._cache[tenant_id]
logger.debug(f"Invalidated word cache for tenant={tenant_id}")
_word_cache = WordCache()
class ForbiddenWordService:
"""
[AC-AISVC-78~AC-AISVC-81] Service for managing forbidden words.
Features:
- Word CRUD with tenant isolation
- Hit count statistics
- In-memory caching with TTL
- Cache invalidation on CRUD operations
- Support for mask/replace/block strategies
"""
VALID_CATEGORIES = [c.value for c in ForbiddenWordCategory]
VALID_STRATEGIES = [s.value for s in ForbiddenWordStrategy]
def __init__(self, session: AsyncSession):
self._session = session
self._cache = _word_cache
async def create_word(
self,
tenant_id: str,
create_data: ForbiddenWordCreate,
) -> ForbiddenWord:
"""
[AC-AISVC-78] Create a new forbidden word.
"""
if create_data.category not in self.VALID_CATEGORIES:
raise ValueError(f"Invalid category. Must be one of: {self.VALID_CATEGORIES}")
if create_data.strategy not in self.VALID_STRATEGIES:
raise ValueError(f"Invalid strategy. Must be one of: {self.VALID_STRATEGIES}")
if create_data.strategy == ForbiddenWordStrategy.REPLACE.value and not create_data.replacement:
raise ValueError("replacement is required when strategy is 'replace'")
if create_data.strategy == ForbiddenWordStrategy.BLOCK.value and not create_data.fallback_reply:
logger.warning(
f"[AC-AISVC-78] Creating block word without fallback_reply: tenant={tenant_id}, word={create_data.word}"
)
word = ForbiddenWord(
tenant_id=tenant_id,
word=create_data.word,
category=create_data.category,
strategy=create_data.strategy,
replacement=create_data.replacement,
fallback_reply=create_data.fallback_reply,
is_enabled=True,
hit_count=0,
)
self._session.add(word)
await self._session.flush()
self._cache.invalidate(tenant_id)
logger.info(
f"[AC-AISVC-78] Created forbidden word: tenant={tenant_id}, "
f"id={word.id}, word={word.word}, strategy={word.strategy}"
)
return word
async def list_words(
self,
tenant_id: str,
category: str | None = None,
is_enabled: bool | None = None,
) -> Sequence[ForbiddenWord]:
"""
[AC-AISVC-79] List words for a tenant with optional filters.
"""
stmt = select(ForbiddenWord).where(
ForbiddenWord.tenant_id == tenant_id # type: ignore
)
if category is not None:
stmt = stmt.where(ForbiddenWord.category == category) # type: ignore
if is_enabled is not None:
stmt = stmt.where(ForbiddenWord.is_enabled == is_enabled) # type: ignore
stmt = stmt.order_by(col(ForbiddenWord.category), col(ForbiddenWord.created_at).desc())
result = await self._session.execute(stmt)
return result.scalars().all()
async def get_word(
self,
tenant_id: str,
word_id: uuid.UUID,
) -> ForbiddenWord | None:
"""
[AC-AISVC-79] Get word by ID with tenant isolation.
"""
stmt = select(ForbiddenWord).where(
ForbiddenWord.tenant_id == tenant_id, # type: ignore
ForbiddenWord.id == word_id, # type: ignore
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
async def update_word(
self,
tenant_id: str,
word_id: uuid.UUID,
update_data: ForbiddenWordUpdate,
) -> ForbiddenWord | None:
"""
[AC-AISVC-80] Update a forbidden word.
"""
word = await self.get_word(tenant_id, word_id)
if not word:
return None
if update_data.word is not None:
word.word = update_data.word
if update_data.category is not None:
if update_data.category not in self.VALID_CATEGORIES:
raise ValueError(f"Invalid category. Must be one of: {self.VALID_CATEGORIES}")
word.category = update_data.category
if update_data.strategy is not None:
if update_data.strategy not in self.VALID_STRATEGIES:
raise ValueError(f"Invalid strategy. Must be one of: {self.VALID_STRATEGIES}")
word.strategy = update_data.strategy
if update_data.replacement is not None:
word.replacement = update_data.replacement
if update_data.fallback_reply is not None:
word.fallback_reply = update_data.fallback_reply
if update_data.is_enabled is not None:
word.is_enabled = update_data.is_enabled
word.updated_at = datetime.utcnow()
await self._session.flush()
self._cache.invalidate(tenant_id)
logger.info(
f"[AC-AISVC-80] Updated forbidden word: tenant={tenant_id}, id={word_id}"
)
return word
async def delete_word(
self,
tenant_id: str,
word_id: uuid.UUID,
) -> bool:
"""
[AC-AISVC-81] Delete a forbidden word.
"""
word = await self.get_word(tenant_id, word_id)
if not word:
return False
await self._session.delete(word)
await self._session.flush()
self._cache.invalidate(tenant_id)
logger.info(
f"[AC-AISVC-81] Deleted forbidden word: tenant={tenant_id}, id={word_id}"
)
return True
async def increment_hit_count(
self,
tenant_id: str,
word_id: uuid.UUID,
) -> bool:
"""
[AC-AISVC-79] Increment hit count for a word.
"""
word = await self.get_word(tenant_id, word_id)
if not word:
return False
word.hit_count += 1
word.updated_at = datetime.utcnow()
await self._session.flush()
logger.debug(
f"[AC-AISVC-79] Incremented hit count for word: tenant={tenant_id}, "
f"id={word_id}, hit_count={word.hit_count}"
)
return True
async def get_enabled_words_for_filtering(
self,
tenant_id: str,
) -> list[ForbiddenWord]:
"""
[AC-AISVC-82] Get enabled words for filtering.
Uses cache for performance.
"""
cached = self._cache.get(tenant_id)
if cached is not None:
logger.debug(f"[AC-AISVC-82] Cache hit for words: tenant={tenant_id}")
return cached
stmt = (
select(ForbiddenWord)
.where(
ForbiddenWord.tenant_id == tenant_id, # type: ignore
ForbiddenWord.is_enabled == True, # type: ignore
)
.order_by(col(ForbiddenWord.category))
)
result = await self._session.execute(stmt)
words = list(result.scalars().all())
self._cache.set(tenant_id, words)
logger.info(
f"[AC-AISVC-82] Loaded {len(words)} enabled words from DB: tenant={tenant_id}"
)
return words
def invalidate_cache(self, tenant_id: str) -> None:
"""Manually invalidate cache for a tenant."""
self._cache.invalidate(tenant_id)
async def word_to_info_dict(self, word: ForbiddenWord) -> dict[str, Any]:
"""Convert word entity to API response dict."""
return {
"id": str(word.id),
"word": word.word,
"category": word.category,
"strategy": word.strategy,
"replacement": word.replacement,
"fallback_reply": word.fallback_reply,
"is_enabled": word.is_enabled,
"hit_count": word.hit_count,
"created_at": word.created_at.isoformat(),
"updated_at": word.updated_at.isoformat(),
}

View File

@ -6,7 +6,7 @@
- module: `ai-service`
- feature: `AISVC` (Python AI 中台)
- status: 🔄 进行中 (Phase 12)
- status: 🔄 进行中 (Phase 14 完成)
---
@ -35,27 +35,29 @@
- [x] Phase 10: Prompt 模板化 (80%) 🔄 (T10.1-T10.8 完成T10.9-T10.10 待集成阶段)
- [x] Phase 11: 多知识库管理 (63%) 🔄 (T11.1-T11.5 完成T11.6-T11.8 待集成阶段)
- [x] Phase 12: 意图识别与规则引擎 (71%) 🔄 (T12.1-T12.5 完成T12.6-T12.7 待集成阶段)
- [x] Phase 13: 话术流程引擎 (0%) ⏳ 待处理
- [x] Phase 14: 输出护栏 (88%) ✅ (T14.1-T14.7 完成T14.8 单元测试留到集成阶段)
---
## 🔄 Current Phase
### Goal
Phase 11 多知识库管理核心功能已完成 (T11.1-T11.5)T11.6OptimizedRetriever 多 Collection 检索、T11.7kb_default 迁移、T11.8(单元测试)留待集成阶段。
Phase 14 输出护栏核心功能已完成 (T14.1-T14.7)T14.8(单元测试)留到集成阶段。
### Completed Tasks (Phase 11)
### Completed Tasks (Phase 14)
- [x] T11.1 扩展 `KnowledgeBase` 实体:新增 `kb_type`、`priority`、`is_enabled`、`doc_count` 字段 `[AC-AISVC-59]`
- [x] T11.2 实现知识库 CRUD 服务:创建时初始化 Qdrant Collection删除时清理 Collection `[AC-AISVC-59, AC-AISVC-61, AC-AISVC-62]`
- [x] T11.3 实现知识库管理 API`POST/GET/PUT/DELETE /admin/kb/knowledge-bases` `[AC-AISVC-59, AC-AISVC-60, AC-AISVC-61, AC-AISVC-62]`
- [x] T11.4 升级 Qdrant Collection 命名:`kb_{tenant_id}_{kb_id}`,兼容现有 `kb_{tenant_id}` `[AC-AISVC-63]`
- [x] T11.5 修改文档上传流程:支持指定 `kbId` 参数,索引到对应 Collection `[AC-AISVC-63]`
- [x] T14.1 定义 `ForbiddenWord``BehaviorRule` SQLModel 实体,创建数据库表 `[AC-AISVC-78, AC-AISVC-84]`
- [x] T14.2 实现 `ForbiddenWordService`:禁词 CRUD + 命中统计 `[AC-AISVC-78, AC-AISVC-79, AC-AISVC-80, AC-AISVC-81]`
- [x] T14.3 实现 `BehaviorRuleService`:行为规则 CRUD `[AC-AISVC-84, AC-AISVC-85]`
- [x] T14.4 实现 `InputScanner`:用户输入前置禁词检测(仅记录,不阻断) `[AC-AISVC-83]`
- [x] T14.5 实现 `OutputFilter`LLM 输出后置过滤mask/replace/block 三种策略) `[AC-AISVC-82]`
- [x] T14.6 实现 Streaming 模式下的滑动窗口禁词检测 `[AC-AISVC-82]`
- [x] T14.7 实现护栏管理 API`/admin/guardrails` 相关端点 `[AC-AISVC-78~AC-AISVC-85]`
### Pending Tasks (Phase 11 - 集成阶段)
### Pending Tasks (Phase 14 - 集成阶段)
- [ ] T11.6 修改 `OptimizedRetriever`:支持 `target_kb_ids` 参数,实现多 Collection 并行检索 `[AC-AISVC-64]`
- [ ] T11.7 实现 `kb_default` 自动迁移:首次启动时为现有数据创建默认知识库记录 `[AC-AISVC-59]`
- [ ] T11.8 编写多知识库服务单元测试 `[AC-AISVC-59~AC-AISVC-64]`
- [ ] T14.8 编写输出护栏服务单元测试 `[AC-AISVC-78~AC-AISVC-85]`
---
@ -66,42 +68,73 @@ Phase 11 多知识库管理核心功能已完成 (T11.1-T11.5)T11.6Optimiz
- `ai-service/`
- `app/`
- `api/` - FastAPI 路由层
- `admin/intent_rules.py` - 意图规则管理 API ✅
- `admin/prompt_templates.py` - Prompt 模板管理 API ✅
- `admin/guardrails.py` - 护栏管理 API ✅
- `models/` - Pydantic 模型和 SQLModel 实体
- `entities.py` - IntentRule, PromptTemplate, PromptTemplateVersion 实体 ✅
- `entities.py` - ForbiddenWord, BehaviorRule, GuardrailResult, InputScanResult 实体 ✅
- `services/`
- `intent/` - 意图识别服务 ✅
- `guardrail/` - 输出护栏服务 ✅
- `__init__.py` - 模块导出
- `rule_service.py` - 规则 CRUD、命中统计、缓存
- `router.py` - IntentRouter 匹配引擎
- `prompt/` - Prompt 模板服务 ✅
- `__init__.py` - 模块导出
- `template_service.py` - 模板 CRUD、版本管理、发布/回滚、缓存
- `variable_resolver.py` - 变量替换引擎
- `word_service.py` - 禁词 CRUD、命中统计、缓存
- `behavior_service.py` - 行为规则 CRUD、缓存、Prompt 注入格式化
- `input_scanner.py` - 用户输入前置检测(仅记录,不阻断)
- `output_filter.py` - LLM 输出后置过滤mask/replace/block
- `streaming_filter.py` - Streaming 滑动窗口检测
### Key Decisions (Why / Impact)
- decision: 意图规则数据库驱动
reason: 支持动态配置意图识别规则,无需重启服务
impact: 规则存储在 PostgreSQL支持按租户隔离
- decision: 三种禁词替换策略
reason: 满足不同场景的内容合规需求
impact: mask 星号替换、replace 指定文本替换、block 拦截整条回复返回兜底话术
- decision: 关键词 + 正则双匹配机制
reason: 关键词匹配快速高效,正则匹配支持复杂模式
impact: 先关键词匹配再正则匹配,优先级高的规则先匹配
- decision: 输入检测不阻断
reason: 用户输入包含禁词时仍需正常处理,仅记录用于监控分析
impact: InputScanner 返回 flagged 状态和匹配信息,不抛异常
- decision: Streaming 滑动窗口检测
reason: 流式输出无法预知完整内容,需要缓冲区检测跨 chunk 的禁词
impact: 维护滑动窗口 buffer默认 50 字符,自动调整到最长禁词长度),检测到禁词后立即停止
- decision: 行为规则注入 Prompt
reason: 行为规则作为 LLM 的行为约束,不进行运行时检测
impact: BehaviorRuleService 提供 format_rules_for_prompt() 方法,追加到系统指令末尾
- decision: 内存缓存 + TTL 策略
reason: 减少数据库查询,提升匹配性能
reason: 减少数据库查询,提升过滤性能
impact: 缓存 TTL=60sCRUD 操作时主动失效
- decision: 四种响应类型
reason: 支持不同的处理链路
impact: fixed 直接返回、rag 定向检索、flow 进入流程、transfer 转人工
---
## 🧾 Session History
### Session #10 (2026-02-27)
- completed:
- T14.1-T14.7 输出护栏核心功能
- 实现 ForbiddenWord 和 BehaviorRule 实体
- 实现 ForbiddenWordServiceCRUD、命中统计、缓存
- 实现 BehaviorRuleServiceCRUD、缓存、Prompt 注入格式化)
- 实现 InputScanner用户输入前置检测仅记录不阻断
- 实现 OutputFilterLLM 输出后置过滤mask/replace/block 三种策略)
- 实现 StreamingGuardrailStreaming 滑动窗口检测)
- 实现护栏管理 API禁词和行为规则 CRUD
- changes:
- 新增 `app/models/entities.py` ForbiddenWord, BehaviorRule, GuardrailResult, InputScanResult 实体
- 新增 `app/services/guardrail/__init__.py` 模块导出
- 新增 `app/services/guardrail/word_service.py` 禁词服务
- 新增 `app/services/guardrail/behavior_service.py` 行为规则服务
- 新增 `app/services/guardrail/input_scanner.py` 输入扫描器
- 新增 `app/services/guardrail/output_filter.py` 输出过滤器
- 新增 `app/services/guardrail/streaming_filter.py` Streaming 过滤器
- 新增 `app/api/admin/guardrails.py` 护栏管理 API
- 更新 `app/api/admin/__init__.py` 导出新路由
- 更新 `app/main.py` 注册新路由
- 更新 `spec/ai-service/tasks.md` 标记任务完成
- notes:
- T14.8(单元测试)留到集成阶段
- 禁词检测三种策略mask星号替换、replace指定文本替换、block拦截返回兜底话术
- InputScanner 仅记录命中,不阻断请求
- OutputFilter 应用 mask/replace/block 策略
- StreamingGuardrail 使用滑动窗口 buffer默认 50 字符,自动调整)
### Session #9 (2026-02-27)
- completed:
- T11.1-T11.5 多知识库管理核心功能