feat: implement output guardrail with forbidden word detection and behavior rules [AC-AISVC-78~AC-AISVC-85]
parent 9d8ecf0bb2
commit 8c259cee30

@@ -0,0 +1,511 @@
# AI Middle Platform Integration Guide

## 1. Overview

This document specifies the HTTP interface that the Python AI middle platform exposes to the channel side (the Java main framework). The interface covers customer-service reply generation and service health checks.

### 1.1 Service Information

- **Service name**: AI Service (Python AI middle platform)
- **Service address**: `http://ai-service:8080`
- **Protocol**: HTTP/1.1
- **Data format**: JSON / SSE (Server-Sent Events)
- **Character encoding**: UTF-8
- **Contract version**: v1.1.0

### 1.2 Core Capabilities

- ✅ Intelligent reply generation (LLM + RAG)
- ✅ Multi-tenant isolation (based on `X-Tenant-Id`)
- ✅ Session context management (based on `sessionId`)
- ✅ Streaming and non-streaming output modes
- ✅ Confidence scoring and human-handoff suggestions
- ✅ Service health checks

---

## 2. Authentication and Tenant Isolation

### 2.1 API Key Authentication (Required)

Every request, except health checks, must carry an API key in the HTTP headers:

```http
X-API-Key: <your_api_key>
```

**Notes**:
- The API key is used for identity authentication and access control
- A missing or invalid API key returns `401 Unauthorized`
- API keys are issued by the AI middle platform administrator; keep them safe
- The following paths do not require an API key: `/health`, `/ai/health`, `/docs`

### 2.2 Tenant Identifier (Required)

Every request must carry the tenant ID in the HTTP headers:

```http
X-Tenant-Id: <tenant_id>
```

**Tenant ID format**: `name@ash@year`

Examples:
- `szmp@ash@2026` - a Shenzhen project, 2026
- `abc123@ash@2025` - project ABC, 2025

**Notes**:
- The tenant ID scopes data isolation (knowledge bases, session history, configuration, and so on)
- A missing or malformed tenant ID returns `400 Bad Request`
- Data is fully isolated between tenants; cross-tenant access is not possible
- A tenant that does not yet exist is created automatically
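
A caller may want to reject malformed tenant IDs before sending a request. The sketch below checks the `name@ash@year` shape with a regular expression. The character set allowed for `name` and the four-digit year are assumptions drawn only from the two examples above, so the service's own validation may be stricter or looser.

```python
import re

# Assumed pattern: <name>@ash@<4-digit year>. The characters allowed in
# <name> are a guess; this document only shows alphanumeric examples.
TENANT_ID_PATTERN = re.compile(r"[A-Za-z0-9_-]+@ash@\d{4}")


def is_valid_tenant_id(tenant_id: str) -> bool:
    """Return True if tenant_id matches the assumed name@ash@year shape."""
    return TENANT_ID_PATTERN.fullmatch(tenant_id) is not None
```

For example, `is_valid_tenant_id("szmp@ash@2026")` is true, while a value without the `@ash@` segment is rejected.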

---

## 3. Endpoint List

| Path | Method | Purpose | Response mode |
|------|--------|---------|---------------|
| `/ai/chat` | POST | Generate an AI reply | JSON / SSE |
| `/ai/health` | GET | Health check | JSON |

---

## 4. Endpoint Details

### 4.1 Generate an AI Reply

**Endpoint**: `POST /ai/chat`

**Description**: Generates an AI reply from the user's message and the session history, with RAG retrieval, context management, and confidence scoring.

#### 4.1.1 Request Parameters

**Headers**:
```http
Content-Type: application/json
X-API-Key: <your_api_key>
X-Tenant-Id: <tenant_id>
Accept: application/json  # or text/event-stream (streaming output)
```

**Body** (JSON):

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `sessionId` | string | ✅ | Session ID, used to associate the conversation history of one session |
| `currentMessage` | string | ✅ | Content of the current user message |
| `channelType` | string | ✅ | Channel type; one of `wechat`, `douyin`, `jd` |
| `history` | array | ❌ | Message history (optional; the AI middle platform manages session history automatically) |
| `metadata` | object | ❌ | Extension metadata (optional) |

**Structure of a `history` element**:
```json
{
  "role": "user | assistant",
  "content": "message content"
}
```

**Request example**:
```json
{
  "sessionId": "kf_001_wx123456_1708765432000",
  "currentMessage": "I would like to know the product pricing",
  "channelType": "wechat",
  "metadata": {
    "channelUserId": "wx123456",
    "extra": "..."
  }
}
```

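
The headers and body above can be assembled as in the following minimal sketch, which uses only the Python standard library. The function name `build_chat_request` and its parameters are illustrative and not part of the service contract. The request is built but not sent, so the snippet makes no network call.

```python
import json
import urllib.request


def build_chat_request(base_url, api_key, tenant_id, session_id,
                       message, channel_type, stream=False):
    """Build (but do not send) a POST /ai/chat request."""
    body = {
        "sessionId": session_id,
        "currentMessage": message,
        "channelType": channel_type,
    }
    headers = {
        "Content-Type": "application/json",
        "X-API-Key": api_key,
        "X-Tenant-Id": tenant_id,
        # The Accept header selects JSON or SSE output.
        "Accept": "text/event-stream" if stream else "application/json",
    }
    return urllib.request.Request(
        url=f"{base_url}/ai/chat",
        data=json.dumps(body).encode("utf-8"),
        headers=headers,
        method="POST",
    )
```

Passing the returned object to `urllib.request.urlopen` would send it; a production caller would also set a timeout in line with the limits this document recommends.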
#### 4.1.2 Response Format

##### Mode 1: JSON Response (Non-Streaming)

**Status code**: `200 OK`

**Response body**:
```json
{
  "reply": "Hello, our product pricing varies by plan...",
  "confidence": 0.92,
  "shouldTransfer": false,
  "transferReason": null,
  "metadata": {
    "retrieval_count": 3,
    "rag_enabled": true
  }
}
```

**Field descriptions**:

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `reply` | string | ✅ | The reply generated by the AI |
| `confidence` | number | ✅ | Confidence score (0.0-1.0); higher means a more reliable answer |
| `shouldTransfer` | boolean | ✅ | Whether a handoff to a human agent is suggested (`true` = suggested) |
| `transferReason` | string | ❌ | Reason for the handoff suggestion (optional) |
| `metadata` | object | ❌ | Response metadata (optional) |

##### Mode 2: SSE Streaming Response

**Trigger**: the request headers include `Accept: text/event-stream`

**Response headers**:
```http
Content-Type: text/event-stream
Cache-Control: no-cache
Connection: keep-alive
```

**Event stream format**:

1. **Incremental message event** (may be sent multiple times)
   ```
   event: message
   data: {"delta": "Hello, "}

   event: message
   data: {"delta": "our products"}
   ```

2. **Final result event** (sent once, then the connection closes)
   ```
   event: final
   data: {"reply": "full reply text", "confidence": 0.92, "shouldTransfer": false}
   ```

3. **Error event** (sent when an error occurs)
   ```
   event: error
   data: {"code": "INTERNAL_ERROR", "message": "error description"}
   ```

**Event ordering guarantees**:
- `message*` (zero or more) → `final` (once) → connection closed
- or `message*` (zero or more) → `error` (once) → connection closed
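
The event sequence above can be consumed with a small line-based parser. The following is a minimal sketch under stated assumptions: it handles only the `event:` and `data:` fields shown in this document, it expects each `data:` payload to be a JSON object, and the helper names `parse_sse` and `collect_reply` are illustrative rather than part of any client library.

```python
import json
from typing import Iterable, Iterator, Tuple


def parse_sse(lines: Iterable[str]) -> Iterator[Tuple[str, dict]]:
    """Yield (event_name, payload) pairs from raw SSE lines."""
    event, data = "message", []
    for raw in lines:
        line = raw.rstrip("\r\n")
        if line == "":  # a blank line terminates one event
            if data:
                yield event, json.loads("\n".join(data))
            event, data = "message", []
        elif line.startswith("event:"):
            event = line[len("event:"):].strip()
        elif line.startswith("data:"):
            data.append(line[len("data:"):].strip())
    if data:  # flush a trailing event that lacks a closing blank line
        yield event, json.loads("\n".join(data))


def collect_reply(lines: Iterable[str]) -> Tuple[str, dict]:
    """Return (streamed_text, final_payload); raise if the stream ends in error."""
    deltas = []
    for event, payload in parse_sse(lines):
        if event == "message":
            deltas.append(payload["delta"])
        elif event == "final":
            return "".join(deltas), payload
        elif event == "error":
            raise RuntimeError(f'{payload["code"]}: {payload["message"]}')
    raise RuntimeError("stream closed before a final or error event")
```

A caller can render each `delta` as it arrives and treat the `final` payload as the authoritative reply, since it carries `confidence` and `shouldTransfer`.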

#### 4.1.3 Error Responses

**401 Unauthorized** - authentication failed
```json
{
  "code": "UNAUTHORIZED",
  "message": "Missing required header: X-API-Key",
  "details": []
}
```

**400 Bad Request** - invalid request parameters
```json
{
  "code": "INVALID_REQUEST",
  "message": "Missing required field: sessionId",
  "details": []
}
```

**400 Bad Request** - malformed tenant ID
```json
{
  "code": "INVALID_TENANT_ID",
  "message": "Invalid tenant ID format. Expected: name@ash@year (e.g., szmp@ash@2026)",
  "details": []
}
```

**500 Internal Server Error** - internal service error
```json
{
  "code": "INTERNAL_ERROR",
  "message": "LLM call failed",
  "details": []
}
```

**503 Service Unavailable** - service unavailable
```json
{
  "code": "SERVICE_UNAVAILABLE",
  "message": "Vector database connection failed",
  "details": []
}
```

---

### 4.2 Health Check

**Endpoint**: `GET /ai/health`

**Description**: Checks whether the AI service is running normally. Intended for service monitoring and load-balancer health probes.

#### 4.2.1 Request Parameters

No request parameters and no authentication headers are required.

#### 4.2.2 Response Format

**200 OK** - service healthy
```json
{
  "status": "healthy"
}
```

**503 Service Unavailable** - service unhealthy
```json
{
  "status": "unhealthy"
}
```

---

## 5. Usage Examples

### 5.1 Java Example (Non-Streaming)

```java
import org.springframework.http.*;
import org.springframework.web.client.RestTemplate;

public class AIServiceClient {

    private final RestTemplate restTemplate;
    private final String aiServiceUrl = "http://ai-service:8080";
    // Placeholder only; load the real key from configuration or an environment variable.
    private final String apiKey = "your_api_key_here";

    public AIServiceClient(RestTemplate restTemplate) {
        this.restTemplate = restTemplate;
    }

    // ChatRequest and ChatResponse are caller-defined DTOs that mirror the JSON bodies in section 4.1.
    public ChatResponse generateReply(String tenantId, ChatRequest request) {
        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);
        headers.set("X-API-Key", apiKey);
        headers.set("X-Tenant-Id", tenantId);

        HttpEntity<ChatRequest> entity = new HttpEntity<>(request, headers);

        ResponseEntity<ChatResponse> response = restTemplate.postForEntity(
            aiServiceUrl + "/ai/chat",
            entity,
            ChatResponse.class
        );

        return response.getBody();
    }
}
```

### 5.2 Java Example (Streaming)

```java
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;

public class AIServiceStreamClient {

    private final WebClient webClient;
    // Placeholder only; load the real key from configuration or an environment variable.
    private final String apiKey = "your_api_key_here";

    // The WebClient is expected to be built with the AI service base URL.
    public AIServiceStreamClient(WebClient webClient) {
        this.webClient = webClient;
    }

    public Flux<ServerSentEvent<String>> generateReplyStream(
        String tenantId,
        ChatRequest request
    ) {
        return webClient.post()
            .uri("/ai/chat")
            .contentType(MediaType.APPLICATION_JSON)
            .header("X-API-Key", apiKey)
            .header("X-Tenant-Id", tenantId)
            .accept(MediaType.TEXT_EVENT_STREAM)
            .bodyValue(request)
            .retrieve()
            // A typed reference keeps the generic parameter; ServerSentEvent.class alone would not compile here.
            .bodyToFlux(new ParameterizedTypeReference<ServerSentEvent<String>>() {});
    }
}
```

### 5.3 cURL Examples

```bash
# Non-streaming call
curl -X POST http://ai-service:8080/ai/chat \
  -H "Content-Type: application/json" \
  -H "X-API-Key: your_api_key_here" \
  -H "X-Tenant-Id: szmp@ash@2026" \
  -d '{
    "sessionId": "kf_001_wx123456_1708765432000",
    "currentMessage": "I would like to know the product pricing",
    "channelType": "wechat"
  }'

# Streaming call
curl -X POST http://ai-service:8080/ai/chat \
  -H "Content-Type: application/json" \
  -H "X-API-Key: your_api_key_here" \
  -H "X-Tenant-Id: szmp@ash@2026" \
  -H "Accept: text/event-stream" \
  -d '{
    "sessionId": "kf_001_wx123456_1708765432000",
    "currentMessage": "I would like to know the product pricing",
    "channelType": "wechat"
  }'

# Health check (no authentication required)
curl http://ai-service:8080/ai/health
```

---

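## 6. Business Logic

### 6.1 Session Management

- **Session identifier**: `sessionId` uniquely identifies one conversation session
- **Automatic persistence**: the AI middle platform stores session history itself, so callers do not need to resend the full history on every request
- **Optional history**: callers may supply external history through the `history` field; the AI middle platform merges it with its own
- **Tenant isolation**: the same `sessionId` under different `tenantId` values is treated as different sessions

### 6.2 RAG Retrieval Augmentation

- **Automatic triggering**: the AI middle platform decides from the user's question whether the knowledge base needs to be searched
- **Multiple knowledge bases**: retrieval can be scoped by knowledge base type (product knowledge, FAQ, script templates, and so on)
- **Confidence scoring**: the quality of the retrieved results affects the `confidence` score
- **Fallback**: when retrieval fails or returns nothing, the AI answers from general knowledge and lowers the confidence

### 6.3 Human-Handoff Suggestion

The `shouldTransfer` field is determined by the following factors:

- ✅ Confidence below the threshold (default 0.6)
- ✅ Retrieval returned nothing or low-quality results
- ✅ The user explicitly asked for a human agent
- ✅ Intent recognition matched a "transfer to human" rule

**Note**: `shouldTransfer=true` is only a suggestion. The caller (the Java main framework) makes the final handoff decision.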
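
The factors listed above can be combined as in the following minimal sketch. It is an illustration, not the service's implementation: the `TransferSignals` fields, the order in which the factors are checked, and the reason strings are all assumptions. Only the 0.6 default threshold comes from this document, and "low-quality results" is simplified here to "no results".

```python
from dataclasses import dataclass
from typing import Optional, Tuple

DEFAULT_THRESHOLD = 0.6  # default confidence threshold stated above


@dataclass
class TransferSignals:
    confidence: float
    retrieval_count: int        # number of retrieved passages; 0 means no hits
    user_asked_for_human: bool  # hypothetical flag derived from the user message
    intent_rule_hit: bool       # hypothetical flag: a "transfer to human" rule matched


def suggest_transfer(
    signals: TransferSignals, threshold: float = DEFAULT_THRESHOLD
) -> Tuple[bool, Optional[str]]:
    """Return (shouldTransfer, transferReason) for one reply."""
    if signals.user_asked_for_human:
        return True, "user requested a human agent"
    if signals.intent_rule_hit:
        return True, "intent rule matched"
    if signals.retrieval_count == 0:
        return True, "no retrieval results"
    if signals.confidence < threshold:
        return True, "low confidence"
    return False, None
```

Returning the reason alongside the flag mirrors the `transferReason` field of the response, which lets the caller log why a handoff was suggested.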
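
### 6.4 Intent Recognition and Rule Engine

- **Pre-processing**: every user message first passes through intent recognition
- **Fixed reply**: when a fixed rule matches, the preset script is returned directly (the LLM call is skipped)
- **Script flow**: when a flow rule matches, the session enters a multi-turn guided dialogue
- **Targeted retrieval**: when a RAG rule matches, retrieval uses the designated knowledge base

### 6.5 Output Guardrails

- **Forbidden-word filtering**: AI replies are automatically filtered for forbidden words (competitor names, sensitive terms, and so on)
- **Handling strategies**: three strategies are supported: masking with asterisks, text replacement, and blocking the whole reply
- **Behavior constraints**: behavior rules are injected into the prompt (for example, "do not promise specific compensation amounts")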
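
The three strategies can be sketched as below. The strategy names `mask`, `replace`, and `block` and the fields `word`, `strategy`, and `replacement` match the admin API added in this commit; everything else is an assumption for illustration, including plain substring matching, checking `block` entries first, and the `BLOCKED_FALLBACK` text.

```python
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass
class ForbiddenWord:
    word: str
    strategy: str                      # "mask", "replace", or "block"
    replacement: Optional[str] = None  # used only by the "replace" strategy


# Hypothetical text returned when a reply is blocked outright.
BLOCKED_FALLBACK = "Sorry, I cannot answer that question."


def apply_guardrail(reply: str, words: Iterable[ForbiddenWord]) -> str:
    """Apply forbidden-word strategies to a reply using substring matching."""
    words = list(words)
    # One "block" hit discards the whole reply, so check those entries first.
    for w in words:
        if w.strategy == "block" and w.word in reply:
            return BLOCKED_FALLBACK
    for w in words:
        if w.strategy == "mask":
            reply = reply.replace(w.word, "*" * len(w.word))
        elif w.strategy == "replace":
            reply = reply.replace(w.word, w.replacement or "")
    return reply
```

Substring matching is the simplest choice; a production filter would likely need normalization (case, width, spacing) and a multi-pattern matcher for large word lists.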

---

## 7. Performance and Limits

### 7.1 Performance Figures

| Metric | Non-streaming | Streaming |
|--------|---------------|-----------|
| Time to first character | 1-3 s | 200-500 ms |
| Time to full response | 2-5 s | 3-8 s |
| Concurrency | 100+ QPS | 50+ QPS |

### 7.2 Limits

- **Message length**: at most 4000 characters per message
- **History length**: keep history within 20 turns (the AI middle platform truncates automatically)
- **Timeouts**: callers should set a 10-second timeout for non-streaming calls and a 30-second timeout for streaming calls
- **Retry policy**: retry 503 errors with exponential backoff; handle 500 errors by degrading gracefully
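
The retry advice for 503 responses can be sketched as follows. The `ServiceUnavailable` exception, the attempt count, and the delay values are assumptions chosen for illustration; the document only recommends exponential backoff and does not prescribe parameters.

```python
import random
import time
from typing import Callable, TypeVar

T = TypeVar("T")


class ServiceUnavailable(Exception):
    """Hypothetical exception the caller raises on an HTTP 503 response."""


def call_with_backoff(
    call: Callable[[], T],
    max_attempts: int = 4,
    base_delay: float = 0.5,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Retry `call` on ServiceUnavailable, doubling the delay each time."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(max_attempts):
        try:
            return call()
        except ServiceUnavailable:
            if attempt == max_attempts - 1:
                raise  # out of attempts: let the caller degrade
            delay = base_delay * (2 ** attempt)
            # Jitter keeps many clients from retrying in lockstep.
            sleep(delay + random.uniform(0, delay / 2))
    raise AssertionError("unreachable")
```

The `sleep` parameter is injectable so the retry schedule can be tested without waiting. The total retry time should stay within the caller's overall timeout budget.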

---

## 8. Error Code Reference

| Error code | HTTP status | Description | Suggested handling |
|------------|-------------|-------------|--------------------|
| `UNAUTHORIZED` | 401 | Authentication failed (missing or invalid API key) | Check the `X-API-Key` header |
| `INVALID_REQUEST` | 400 | Invalid request parameters | Check required fields and parameter formats |
| `MISSING_TENANT_ID` | 400 | Tenant ID missing | Add the `X-Tenant-Id` header |
| `INVALID_TENANT_ID` | 400 | Malformed tenant ID | Use the correct format: `name@ash@year` |
| `INTERNAL_ERROR` | 500 | Internal service error | Degrade gracefully or retry |
| `LLM_ERROR` | 500 | LLM call failed | Degrade gracefully or retry |
| `SERVICE_UNAVAILABLE` | 503 | Service unavailable | Retry with exponential backoff |
| `QDRANT_ERROR` | 503 | Vector store unavailable | Retry with exponential backoff |
| `STREAMING_ERROR` | 200 (SSE) | Streaming transport error | Close the connection and retry |

---

## 9. Best Practices

### 9.1 API Key Management

- API keys are issued by the AI middle platform administrator through the admin console
- Use a different API key for each environment (development, testing, production)
- Store API keys in configuration files or environment variables; do not hard-code them
- Rotate API keys regularly to improve security

### 9.2 Session ID Convention

Recommended format: `{business prefix}_{tenant ID}_{channel user ID}_{timestamp}`

Example: `kf_001_wx123456_1708765432000`
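
A session ID in the recommended format can be assembled as in this minimal sketch. The function and parameter names are illustrative. The example above uses a short tenant code (`001`) rather than the full `name@ash@year` tenant ID, so the sketch takes whatever tenant segment the caller chooses, and it assumes the timestamp is in Unix milliseconds, which is what the 13-digit example value suggests.

```python
import time
from typing import Optional


def build_session_id(
    prefix: str,
    tenant_code: str,
    channel_user_id: str,
    timestamp_ms: Optional[int] = None,
) -> str:
    """Join the four segments with underscores, defaulting to the current time."""
    if timestamp_ms is None:
        timestamp_ms = int(time.time() * 1000)  # Unix time in milliseconds
    return f"{prefix}_{tenant_code}_{channel_user_id}_{timestamp_ms}"
```

Calling `build_session_id("kf", "001", "wx123456", 1708765432000)` reproduces the example ID shown above.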

### 9.3 Choosing Streaming or Non-Streaming

- **Streaming**: suited to real-time Web/App conversations, where it gives a better user experience
- **Non-streaming**: suited to batch processing, asynchronous jobs, and API integrations

### 9.4 Recommended Fallback Strategy

```java
// Requires: import org.springframework.web.client.HttpServerErrorException;
public ChatResponse generateReplyWithFallback(String tenantId, ChatRequest request) {
    try {
        return aiServiceClient.generateReply(tenantId, request);
    } catch (HttpServerErrorException.ServiceUnavailable e) {
        // RestTemplate surfaces an HTTP 503 as HttpServerErrorException.ServiceUnavailable.
        // Fallback 1: return a fixed reply
        return ChatResponse.builder()
            .reply("Sorry, we are handling a high volume of inquiries. Please try again later or ask for a human agent.")
            .confidence(0.0)
            .shouldTransfer(true)
            .build();
    } catch (Exception e) {
        // Fallback 2: hand off to a human agent directly
        return ChatResponse.builder()
            .reply("The system is busy. Transferring you to a human agent...")
            .confidence(0.0)
            .shouldTransfer(true)
            .transferReason("AI service error")
            .build();
    }
}
```

### 9.5 Recommended Monitoring Metrics

- ✅ Endpoint response time (P50/P95/P99)
- ✅ Endpoint success rate
- ✅ Confidence distribution
- ✅ Human-handoff rate
- ✅ Error code distribution

---

## 10. Changelog

| Version | Date | Changes |
|---------|------|---------|
| v1.1.0 | 2026-02-27 | Added streaming output, intent recognition, and output guardrails |
| v1.0.0 | 2026-02-20 | Initial version with basic reply generation and health checks |

---

## 11. Contact

- **Technical support**: AI middle platform development team
- **Issue reporting**: open an issue in the project repository
- **Documentation updates**: see `spec/ai-service/openapi.provider.yaml`

---

**Document generated**: 2026-02-27
**Contract version**: v1.1.0
**Maintenance status**: ✅ Actively maintained

@@ -15,4 +15,5 @@ from app.api.admin.rag import router as rag_router
 from app.api.admin.script_flows import router as script_flows_router
 from app.api.admin.sessions import router as sessions_router
 from app.api.admin.tenants import router as tenants_router
+
 __all__ = ["api_key_router", "dashboard_router", "embedding_router", "guardrails_router", "intent_rules_router", "kb_router", "llm_router", "prompt_templates_router", "rag_router", "script_flows_router", "sessions_router", "tenants_router"]

@@ -0,0 +1,296 @@
"""
Guardrail Management API.
[AC-AISVC-78~AC-AISVC-85] Forbidden words and behavior rules CRUD endpoints.
"""

import logging
import uuid
from typing import Any

from fastapi import APIRouter, Depends, Header, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_session
from app.models.entities import (
    BehaviorRuleCreate,
    BehaviorRuleUpdate,
    ForbiddenWordCreate,
    ForbiddenWordUpdate,
)
from app.services.guardrail.behavior_service import BehaviorRuleService
from app.services.guardrail.word_service import ForbiddenWordService

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/admin/guardrails", tags=["Guardrails"])


def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
    """Extract tenant ID from header."""
    if not x_tenant_id:
        raise HTTPException(status_code=400, detail="X-Tenant-Id header is required")
    return x_tenant_id


@router.get("/forbidden-words")
async def list_forbidden_words(
    tenant_id: str = Depends(get_tenant_id),
    category: str | None = None,
    is_enabled: bool | None = None,
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-79] List all forbidden words for a tenant.
    """
    logger.info(
        f"[AC-AISVC-79] Listing forbidden words for tenant={tenant_id}, "
        f"category={category}, is_enabled={is_enabled}"
    )

    service = ForbiddenWordService(session)
    words = await service.list_words(tenant_id, category, is_enabled)

    data = []
    for word in words:
        data.append(await service.word_to_info_dict(word))

    return {"data": data}


@router.post("/forbidden-words", status_code=201)
async def create_forbidden_word(
    body: ForbiddenWordCreate,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-78] Create a new forbidden word.
    """
    valid_categories = ["competitor", "sensitive", "political", "custom"]
    if body.category not in valid_categories:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid category. Must be one of: {valid_categories}"
        )

    valid_strategies = ["mask", "replace", "block"]
    if body.strategy not in valid_strategies:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid strategy. Must be one of: {valid_strategies}"
        )

    if body.strategy == "replace" and not body.replacement:
        raise HTTPException(
            status_code=400,
            detail="replacement is required when strategy is 'replace'"
        )

    logger.info(
        f"[AC-AISVC-78] Creating forbidden word for tenant={tenant_id}, word={body.word}"
    )

    service = ForbiddenWordService(session)
    try:
        word = await service.create_word(tenant_id, body)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    return await service.word_to_info_dict(word)


@router.get("/forbidden-words/{word_id}")
async def get_forbidden_word(
    word_id: uuid.UUID,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-79] Get forbidden word detail.
    """
    logger.info(f"[AC-AISVC-79] Getting forbidden word for tenant={tenant_id}, id={word_id}")

    service = ForbiddenWordService(session)
    word = await service.get_word(tenant_id, word_id)

    if not word:
        raise HTTPException(status_code=404, detail="Forbidden word not found")

    return await service.word_to_info_dict(word)


@router.put("/forbidden-words/{word_id}")
async def update_forbidden_word(
    word_id: uuid.UUID,
    body: ForbiddenWordUpdate,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-80] Update a forbidden word.
    """
    valid_categories = ["competitor", "sensitive", "political", "custom"]
    if body.category is not None and body.category not in valid_categories:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid category. Must be one of: {valid_categories}"
        )

    valid_strategies = ["mask", "replace", "block"]
    if body.strategy is not None and body.strategy not in valid_strategies:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid strategy. Must be one of: {valid_strategies}"
        )

    logger.info(f"[AC-AISVC-80] Updating forbidden word for tenant={tenant_id}, id={word_id}")

    service = ForbiddenWordService(session)
    try:
        word = await service.update_word(tenant_id, word_id, body)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    if not word:
        raise HTTPException(status_code=404, detail="Forbidden word not found")

    return await service.word_to_info_dict(word)


@router.delete("/forbidden-words/{word_id}", status_code=204)
async def delete_forbidden_word(
    word_id: uuid.UUID,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> None:
    """
    [AC-AISVC-81] Delete a forbidden word.
    """
    logger.info(f"[AC-AISVC-81] Deleting forbidden word for tenant={tenant_id}, id={word_id}")

    service = ForbiddenWordService(session)
    success = await service.delete_word(tenant_id, word_id)

    if not success:
        raise HTTPException(status_code=404, detail="Forbidden word not found")


@router.get("/behavior-rules")
async def list_behavior_rules(
    tenant_id: str = Depends(get_tenant_id),
    category: str | None = None,
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-85] List all behavior rules for a tenant.
    """
    logger.info(
        f"[AC-AISVC-85] Listing behavior rules for tenant={tenant_id}, category={category}"
    )

    service = BehaviorRuleService(session)
    rules = await service.list_rules(tenant_id, category)

    data = []
    for rule in rules:
        data.append(await service.rule_to_info_dict(rule))

    return {"data": data}


@router.post("/behavior-rules", status_code=201)
async def create_behavior_rule(
    body: BehaviorRuleCreate,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-84] Create a new behavior rule.
    """
    valid_categories = ["compliance", "tone", "boundary", "custom"]
    if body.category not in valid_categories:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid category. Must be one of: {valid_categories}"
        )

    logger.info(
        f"[AC-AISVC-84] Creating behavior rule for tenant={tenant_id}, category={body.category}"
    )

    service = BehaviorRuleService(session)
    try:
        rule = await service.create_rule(tenant_id, body)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    return await service.rule_to_info_dict(rule)


@router.get("/behavior-rules/{rule_id}")
async def get_behavior_rule(
    rule_id: uuid.UUID,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-85] Get behavior rule detail.
    """
    logger.info(f"[AC-AISVC-85] Getting behavior rule for tenant={tenant_id}, id={rule_id}")

    service = BehaviorRuleService(session)
    rule = await service.get_rule(tenant_id, rule_id)

    if not rule:
        raise HTTPException(status_code=404, detail="Behavior rule not found")

    return await service.rule_to_info_dict(rule)


@router.put("/behavior-rules/{rule_id}")
async def update_behavior_rule(
    rule_id: uuid.UUID,
    body: BehaviorRuleUpdate,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """
    [AC-AISVC-85] Update a behavior rule.
    """
    valid_categories = ["compliance", "tone", "boundary", "custom"]
    if body.category is not None and body.category not in valid_categories:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid category. Must be one of: {valid_categories}"
        )

    logger.info(f"[AC-AISVC-85] Updating behavior rule for tenant={tenant_id}, id={rule_id}")

    service = BehaviorRuleService(session)
    try:
        rule = await service.update_rule(tenant_id, rule_id, body)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    if not rule:
        raise HTTPException(status_code=404, detail="Behavior rule not found")

    return await service.rule_to_info_dict(rule)


@router.delete("/behavior-rules/{rule_id}", status_code=204)
async def delete_behavior_rule(
    rule_id: uuid.UUID,
    tenant_id: str = Depends(get_tenant_id),
    session: AsyncSession = Depends(get_session),
) -> None:
    """
    [AC-AISVC-85] Delete a behavior rule.
    """
    logger.info(f"[AC-AISVC-85] Deleting behavior rule for tenant={tenant_id}, id={rule_id}")

    service = BehaviorRuleService(session)
    success = await service.delete_rule(tenant_id, rule_id)

    if not success:
        raise HTTPException(status_code=404, detail="Behavior rule not found")

@@ -12,7 +12,20 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 
 from app.api import chat_router, health_router
-from app.api.admin import api_key_router, dashboard_router, embedding_router, guardrails_router, intent_rules_router, kb_router, llm_router, prompt_templates_router, rag_router, script_flows_router, sessions_router, tenants_router
+from app.api.admin import (
+    api_key_router,
+    dashboard_router,
+    embedding_router,
+    guardrails_router,
+    intent_rules_router,
+    kb_router,
+    llm_router,
+    prompt_templates_router,
+    rag_router,
+    script_flows_router,
+    sessions_router,
+    tenants_router,
+)
 from app.api.admin.kb_optimized import router as kb_optimized_router
 from app.core.config import get_settings
 from app.core.database import close_db, init_db

@@ -130,6 +143,7 @@ app.include_router(chat_router)
 app.include_router(api_key_router)
 app.include_router(dashboard_router)
 app.include_router(embedding_router)
+app.include_router(guardrails_router)
 app.include_router(intent_rules_router)
 app.include_router(kb_router)
 app.include_router(kb_optimized_router)

@@ -8,7 +8,7 @@ from datetime import datetime
 from enum import Enum
 from typing import Any
 
-from sqlalchemy import Column, JSON
+from sqlalchemy import JSON, Column
 from sqlmodel import Field, Index, SQLModel
 
 

@@ -141,7 +141,10 @@ class KnowledgeBase(SQLModel, table=True):
     id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
     tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
     name: str = Field(..., description="Knowledge base name")
-    kb_type: str = Field(default=KBType.GENERAL.value, description="Knowledge base type: product/faq/script/policy/general")
+    kb_type: str = Field(
+        default=KBType.GENERAL.value,
+        description="Knowledge base type: product/faq/script/policy/general"
+    )
     description: str | None = Field(default=None, description="Knowledge base description")
     priority: int = Field(default=0, ge=0, description="Priority weight, higher value means higher priority")
     is_enabled: bool = Field(default=True, description="Whether the knowledge base is enabled")

@@ -289,14 +292,25 @@ class PromptTemplateVersion(SQLModel, table=True):
     )
 
     id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
-    template_id: uuid.UUID = Field(..., description="Foreign key to prompt_templates.id", foreign_key="prompt_templates.id", index=True)
+    template_id: uuid.UUID = Field(
+        ...,
+        description="Foreign key to prompt_templates.id",
+        foreign_key="prompt_templates.id",
+        index=True
+    )
     version: int = Field(..., description="Version number (auto-incremented per template)")
-    status: str = Field(default=TemplateVersionStatus.DRAFT.value, description="Version status: draft/published/archived")
-    system_instruction: str = Field(..., description="System instruction content with {{variable}} placeholders")
+    status: str = Field(
+        default=TemplateVersionStatus.DRAFT.value,
+        description="Version status: draft/published/archived"
+    )
+    system_instruction: str = Field(
+        ...,
+        description="System instruction content with {{variable}} placeholders"
+    )
     variables: list[dict[str, Any]] | None = Field(
         default=None,
         sa_column=Column("variables", JSON, nullable=True),
-        description="Variable definitions, e.g., [{'name': 'persona_name', 'default': '小N', 'description': '人设名称'}]"
+        description="Variable definitions, e.g., [{'name': 'persona_name', 'default': '小N'}]"
     )
     created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
 

@@ -510,7 +524,10 @@ class BehaviorRule(SQLModel, table=True):
 
     id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
     tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
-    rule_text: str = Field(..., description="Behavior constraint description, e.g., 'Do not promise specific compensation amounts'")
+    rule_text: str = Field(
+        ...,
+        description="Behavior constraint description, e.g., 'Do not promise specific compensation'"
+    )
     category: str = Field(..., description="Category: compliance/tone/boundary/custom")
     is_enabled: bool = Field(default=True, description="Whether the rule is enabled")
     created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
@@ -618,7 +635,7 @@ class ScriptFlow(SQLModel, table=True):
     steps: list[dict[str, Any]] = Field(
         default=[],
         sa_column=Column("steps", JSON, nullable=False),
-        description="Flow steps list with step_no, content, wait_input, timeout_seconds, timeout_action, next_conditions, default_next"
+        description="Flow steps list with step_no, content, wait_input, timeout_seconds"
     )
     is_enabled: bool = Field(default=True, description="Whether the flow is enabled")
     created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
@@ -640,13 +657,21 @@ class FlowInstance(SQLModel, table=True):
     id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
     tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
     session_id: str = Field(..., description="Session ID for conversation tracking", index=True)
-    flow_id: uuid.UUID = Field(..., description="Foreign key to script_flows.id", foreign_key="script_flows.id", index=True)
+    flow_id: uuid.UUID = Field(
+        ...,
+        description="Foreign key to script_flows.id",
+        foreign_key="script_flows.id",
+        index=True
+    )
     current_step: int = Field(default=1, ge=1, description="Current step number (1-indexed)")
-    status: str = Field(default=FlowInstanceStatus.ACTIVE.value, description="Instance status: active/completed/timeout/cancelled")
+    status: str = Field(
+        default=FlowInstanceStatus.ACTIVE.value,
+        description="Instance status: active/completed/timeout/cancelled"
+    )
     context: dict[str, Any] | None = Field(
         default=None,
         sa_column=Column("context", JSON, nullable=True),
-        description="Flow execution context, stores user inputs etc."
+        description="Flow execution context, stores user inputs"
     )
     started_at: datetime = Field(default_factory=datetime.utcnow, description="Instance start time")
     updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
@@ -660,10 +685,13 @@ class FlowStep(SQLModel):
     content: str = Field(..., description="Script content for this step")
     wait_input: bool = Field(default=True, description="Whether to wait for user input")
     timeout_seconds: int = Field(default=120, ge=1, description="Timeout in seconds")
-    timeout_action: str = Field(default=TimeoutAction.REPEAT.value, description="Action on timeout: repeat/skip/transfer")
+    timeout_action: str = Field(
+        default=TimeoutAction.REPEAT.value,
+        description="Action on timeout: repeat/skip/transfer"
+    )
     next_conditions: list[dict[str, Any]] | None = Field(
         default=None,
-        description="Conditions for next step: [{'keywords': [...], 'goto_step': N}, {'pattern': '...', 'goto_step': N}]"
+        description="Conditions for next step: [{'keywords': [...], 'goto_step': N}]"
     )
     default_next: int | None = Field(default=None, description="Default next step if no condition matches")
 
@@ -0,0 +1,18 @@
+"""
+Guardrail services for AI Service.
+[AC-AISVC-78~AC-AISVC-85] Output guardrail with forbidden word detection and behavior rules.
+"""
+
+from app.services.guardrail.behavior_service import BehaviorRuleService
+from app.services.guardrail.input_scanner import InputScanner
+from app.services.guardrail.output_filter import OutputFilter
+from app.services.guardrail.streaming_filter import StreamingGuardrail
+from app.services.guardrail.word_service import ForbiddenWordService
+
+__all__ = [
+    "ForbiddenWordService",
+    "BehaviorRuleService",
+    "InputScanner",
+    "OutputFilter",
+    "StreamingGuardrail",
+]
@@ -0,0 +1,260 @@
+"""
+Behavior rule service for AI Service.
+[AC-AISVC-84, AC-AISVC-85] Behavior rule CRUD management.
+"""
+
+import logging
+import time
+import uuid
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Any
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlmodel import col
+
+from app.models.entities import (
+    BehaviorRule,
+    BehaviorRuleCategory,
+    BehaviorRuleCreate,
+    BehaviorRuleUpdate,
+)
+
+logger = logging.getLogger(__name__)
+
+BEHAVIOR_CACHE_TTL_SECONDS = 60
+
+
+class BehaviorRuleCache:
+    """
+    [AC-AISVC-84] In-memory cache for behavior rules.
+    Key: tenant_id
+    Value: (rules_list, cached_at)
+    TTL: 60 seconds
+    """
+
+    def __init__(self, ttl_seconds: int = BEHAVIOR_CACHE_TTL_SECONDS):
+        self._cache: dict[str, tuple[list[BehaviorRule], float]] = {}
+        self._ttl = ttl_seconds
+
+    def get(self, tenant_id: str) -> list[BehaviorRule] | None:
+        """Get cached rules if not expired."""
+        if tenant_id in self._cache:
+            rules, cached_at = self._cache[tenant_id]
+            if time.time() - cached_at < self._ttl:
+                return rules
+            else:
+                del self._cache[tenant_id]
+        return None
+
+    def set(self, tenant_id: str, rules: list[BehaviorRule]) -> None:
+        """Cache rules for a tenant."""
+        self._cache[tenant_id] = (rules, time.time())
+
+    def invalidate(self, tenant_id: str) -> None:
+        """Invalidate cache for a tenant."""
+        if tenant_id in self._cache:
+            del self._cache[tenant_id]
+            logger.debug(f"Invalidated behavior rule cache for tenant={tenant_id}")
+
+
+_behavior_cache = BehaviorRuleCache()
+
+
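The per-tenant TTL cache above can be illustrated in isolation. The following is a minimal sketch, not the project's class: `TTLCache` and its `clock` parameter are hypothetical names introduced here, and it swaps `time.time()` for an injectable clock defaulting to `time.monotonic()` so that expiry can be exercised without sleeping.

```python
from __future__ import annotations

import time
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class TTLCache(Generic[T]):
    """Hypothetical per-key cache with lazy expiry (illustration only)."""

    def __init__(
        self,
        ttl_seconds: float = 60.0,
        clock: Callable[[], float] = time.monotonic,  # assumption: injectable clock
    ) -> None:
        self._ttl = ttl_seconds
        self._clock = clock
        self._entries: dict[str, tuple[T, float]] = {}

    def get(self, key: str) -> T | None:
        """Return the cached value, or None if missing or expired."""
        entry = self._entries.get(key)
        if entry is None:
            return None
        value, cached_at = entry
        if self._clock() - cached_at < self._ttl:
            return value
        # Expired entries are dropped on read rather than by a background task.
        del self._entries[key]
        return None

    def set(self, key: str, value: T) -> None:
        """Store a value together with the time it was cached."""
        self._entries[key] = (value, self._clock())

    def invalidate(self, key: str) -> None:
        """Remove a key; a no-op if the key is absent."""
        self._entries.pop(key, None)
```

Because the cache in the diff is a module-level singleton, it is shared by every service instance in one process; with several worker processes, each would hold its own copy until the TTL elapses.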
+class BehaviorRuleService:
+    """
+    [AC-AISVC-84, AC-AISVC-85] Service for managing behavior rules.
+
+    Features:
+    - Rule CRUD with tenant isolation
+    - In-memory caching with TTL
+    - Cache invalidation on CRUD operations
+    - Rules are injected into Prompt system instruction
+    """
+
+    VALID_CATEGORIES = [c.value for c in BehaviorRuleCategory]
+
+    def __init__(self, session: AsyncSession):
+        self._session = session
+        self._cache = _behavior_cache
+
+    async def create_rule(
+        self,
+        tenant_id: str,
+        create_data: BehaviorRuleCreate,
+    ) -> BehaviorRule:
+        """
+        [AC-AISVC-84] Create a new behavior rule.
+        """
+        if create_data.category not in self.VALID_CATEGORIES:
+            raise ValueError(f"Invalid category. Must be one of: {self.VALID_CATEGORIES}")
+
+        rule = BehaviorRule(
+            tenant_id=tenant_id,
+            rule_text=create_data.rule_text,
+            category=create_data.category,
+            is_enabled=True,
+        )
+        self._session.add(rule)
+        await self._session.flush()
+
+        self._cache.invalidate(tenant_id)
+
+        logger.info(
+            f"[AC-AISVC-84] Created behavior rule: tenant={tenant_id}, "
+            f"id={rule.id}, category={rule.category}"
+        )
+        return rule
+
+    async def list_rules(
+        self,
+        tenant_id: str,
+        category: str | None = None,
+    ) -> Sequence[BehaviorRule]:
+        """
+        [AC-AISVC-85] List rules for a tenant with optional filters.
+        """
+        stmt = select(BehaviorRule).where(
+            BehaviorRule.tenant_id == tenant_id  # type: ignore
+        )
+
+        if category is not None:
+            stmt = stmt.where(BehaviorRule.category == category)  # type: ignore
+
+        stmt = stmt.order_by(col(BehaviorRule.category), col(BehaviorRule.created_at).desc())
+        result = await self._session.execute(stmt)
+        return result.scalars().all()
+
+    async def get_rule(
+        self,
+        tenant_id: str,
+        rule_id: uuid.UUID,
+    ) -> BehaviorRule | None:
+        """
+        [AC-AISVC-85] Get rule by ID with tenant isolation.
+        """
+        stmt = select(BehaviorRule).where(
+            BehaviorRule.tenant_id == tenant_id,  # type: ignore
+            BehaviorRule.id == rule_id,  # type: ignore
+        )
+        result = await self._session.execute(stmt)
+        return result.scalar_one_or_none()
+
+    async def update_rule(
+        self,
+        tenant_id: str,
+        rule_id: uuid.UUID,
+        update_data: BehaviorRuleUpdate,
+    ) -> BehaviorRule | None:
+        """
+        [AC-AISVC-85] Update a behavior rule.
+        """
+        rule = await self.get_rule(tenant_id, rule_id)
+        if not rule:
+            return None
+
+        if update_data.rule_text is not None:
+            rule.rule_text = update_data.rule_text
+        if update_data.category is not None:
+            if update_data.category not in self.VALID_CATEGORIES:
+                raise ValueError(f"Invalid category. Must be one of: {self.VALID_CATEGORIES}")
+            rule.category = update_data.category
+        if update_data.is_enabled is not None:
+            rule.is_enabled = update_data.is_enabled
+
+        rule.updated_at = datetime.utcnow()
+        await self._session.flush()
+
+        self._cache.invalidate(tenant_id)
+
+        logger.info(
+            f"[AC-AISVC-85] Updated behavior rule: tenant={tenant_id}, id={rule_id}"
+        )
+        return rule
+
+    async def delete_rule(
+        self,
+        tenant_id: str,
+        rule_id: uuid.UUID,
+    ) -> bool:
+        """
+        [AC-AISVC-85] Delete a behavior rule.
+        """
+        rule = await self.get_rule(tenant_id, rule_id)
+        if not rule:
+            return False
+
+        await self._session.delete(rule)
+        await self._session.flush()
+
+        self._cache.invalidate(tenant_id)
+
+        logger.info(
+            f"[AC-AISVC-85] Deleted behavior rule: tenant={tenant_id}, id={rule_id}"
+        )
+        return True
+
+    async def get_enabled_rules_for_injection(
+        self,
+        tenant_id: str,
+    ) -> list[BehaviorRule]:
+        """
+        [AC-AISVC-84] Get enabled rules for Prompt injection.
+        Uses cache for performance.
+        """
+        cached = self._cache.get(tenant_id)
+        if cached is not None:
+            logger.debug(f"[AC-AISVC-84] Cache hit for behavior rules: tenant={tenant_id}")
+            return cached
+
+        stmt = (
+            select(BehaviorRule)
+            .where(
+                BehaviorRule.tenant_id == tenant_id,  # type: ignore
+                BehaviorRule.is_enabled == True,  # type: ignore
+            )
+            .order_by(col(BehaviorRule.category))
+        )
+        result = await self._session.execute(stmt)
+        rules = list(result.scalars().all())
+
+        self._cache.set(tenant_id, rules)
+        logger.info(
+            f"[AC-AISVC-84] Loaded {len(rules)} enabled behavior rules from DB: tenant={tenant_id}"
+        )
+        return rules
+
+    def invalidate_cache(self, tenant_id: str) -> None:
+        """Manually invalidate cache for a tenant."""
+        self._cache.invalidate(tenant_id)
+
+    async def rule_to_info_dict(self, rule: BehaviorRule) -> dict[str, Any]:
+        """Convert rule entity to API response dict."""
+        return {
+            "id": str(rule.id),
+            "rule_text": rule.rule_text,
+            "category": rule.category,
+            "is_enabled": rule.is_enabled,
+            "created_at": rule.created_at.isoformat(),
+            "updated_at": rule.updated_at.isoformat(),
+        }
+
+    async def format_rules_for_prompt(
+        self,
+        tenant_id: str,
+    ) -> str:
+        """
+        [AC-AISVC-84] Format behavior rules for Prompt injection.
+        Returns formatted string to append to system instruction.
+        """
+        rules = await self.get_enabled_rules_for_injection(tenant_id)
+
+        if not rules:
+            return ""
+
+        lines = ["\n\n[行为约束 - 以下规则必须严格遵守]"]
+        for i, rule in enumerate(rules, 1):
+            lines.append(f"{i}. {rule.rule_text}")
+
+        return "\n".join(lines)
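To make the shape of the injected text concrete, here is a small standalone sketch of the numbering scheme used by `format_rules_for_prompt`. It takes plain strings rather than `BehaviorRule` entities, and both the `header` parameter and `build_system_instruction` are hypothetical helpers added for illustration; the commit itself uses a fixed Chinese header and does not show how the result is attached to the system instruction.

```python
def format_rules_for_prompt(rule_texts: list[str], header: str = "[Behavior constraints]") -> str:
    """Render rules as a numbered block, or an empty string when there are none."""
    if not rule_texts:
        return ""
    # Two leading newlines separate the block from the preceding instruction text.
    lines = ["\n\n" + header]
    lines.extend(f"{i}. {text}" for i, text in enumerate(rule_texts, 1))
    return "\n".join(lines)


def build_system_instruction(base: str, rule_texts: list[str]) -> str:
    """Hypothetical helper: append the rule block to a base system instruction."""
    return base + format_rules_for_prompt(rule_texts)
```

Returning an empty string for tenants without rules means the caller can always concatenate without a conditional.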
@@ -0,0 +1,110 @@
+"""
+Input scanner for AI Service.
+[AC-AISVC-83] User input pre-detection (logging only, no blocking).
+"""
+
+import logging
+from typing import Any
+
+from app.models.entities import (
+    ForbiddenWord,
+    InputScanResult,
+)
+from app.services.guardrail.word_service import ForbiddenWordService
+
+logger = logging.getLogger(__name__)
+
+
+class InputScanner:
+    """
+    [AC-AISVC-83] Input scanner for pre-detection of forbidden words.
+
+    Features:
+    - Scans user input for forbidden words
+    - Records matched words and categories in metadata
+    - Does NOT block the request (only logging)
+    - Used for monitoring and analytics
+    """
+
+    def __init__(self, word_service: ForbiddenWordService):
+        self._word_service = word_service
+
+    async def scan(
+        self,
+        text: str,
+        tenant_id: str,
+    ) -> InputScanResult:
+        """
+        [AC-AISVC-83] Scan user input for forbidden words.
+
+        Args:
+            text: User input text to scan
+            tenant_id: Tenant ID for isolation
+
+        Returns:
+            InputScanResult with flagged status and matched words
+        """
+        if not text or not text.strip():
+            return InputScanResult(flagged=False)
+
+        words = await self._word_service.get_enabled_words_for_filtering(tenant_id)
+
+        if not words:
+            return InputScanResult(flagged=False)
+
+        matched_words: list[str] = []
+        matched_categories: list[str] = []
+        matched_word_entities: list[ForbiddenWord] = []
+
+        for word in words:
+            if word.word in text:
+                matched_words.append(word.word)
+                if word.category not in matched_categories:
+                    matched_categories.append(word.category)
+                matched_word_entities.append(word)
+
+        if matched_words:
+            logger.info(
+                f"[AC-AISVC-83] Input flagged: tenant={tenant_id}, "
+                f"matched_words={matched_words}, categories={matched_categories}"
+            )
+
+            for word_entity in matched_word_entities:
+                try:
+                    await self._word_service.increment_hit_count(tenant_id, word_entity.id)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to increment hit count for word {word_entity.id}: {e}"
+                    )
+
+        return InputScanResult(
+            flagged=len(matched_words) > 0,
+            matched_words=matched_words,
+            matched_categories=matched_categories,
+        )
+
+    async def scan_and_enrich_metadata(
+        self,
+        text: str,
+        tenant_id: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> dict[str, Any]:
+        """
+        [AC-AISVC-83] Scan input and enrich metadata with scan result.
+
+        Args:
+            text: User input text to scan
+            tenant_id: Tenant ID for isolation
+            metadata: Existing metadata dict to enrich
+
+        Returns:
+            Enriched metadata with input_flagged and matched info
+        """
+        result = await self.scan(text, tenant_id)
+
+        if metadata is None:
+            metadata = {}
+
+        metadata.update(result.to_dict())
+
+        return metadata
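The matching rule in `InputScanner.scan` is a plain case-sensitive substring test with de-duplicated categories. A minimal synchronous sketch of that logic follows; `Word` and `ScanResult` are hypothetical stand-ins for `ForbiddenWord` and `InputScanResult`, whose full definitions are not shown in this part of the diff, and the hit-count update is left out.

```python
from dataclasses import dataclass, field


@dataclass
class Word:
    """Hypothetical stand-in for the ForbiddenWord entity."""
    word: str
    category: str


@dataclass
class ScanResult:
    """Hypothetical stand-in for InputScanResult."""
    flagged: bool = False
    matched_words: list[str] = field(default_factory=list)
    matched_categories: list[str] = field(default_factory=list)


def scan(text: str, words: list[Word]) -> ScanResult:
    """Flag the text if any forbidden word occurs in it; never blocks."""
    if not text or not text.strip():
        return ScanResult()
    matched_words: list[str] = []
    matched_categories: list[str] = []
    for w in words:
        if w.word in text:  # case-sensitive substring match
            matched_words.append(w.word)
            if w.category not in matched_categories:
                matched_categories.append(w.category)
    return ScanResult(bool(matched_words), matched_words, matched_categories)
```

Substring matching is simple and works for languages without word boundaries, at the cost of flagging a word that appears inside a longer, harmless one.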
@@ -0,0 +1,154 @@
+"""
+Output filter for AI Service.
+[AC-AISVC-82] LLM output post-filtering with mask/replace/block strategies.
+"""
+
+import logging
+from typing import Any
+
+from app.models.entities import (
+    ForbiddenWord,
+    ForbiddenWordStrategy,
+    GuardrailResult,
+)
+from app.services.guardrail.word_service import ForbiddenWordService
+
+logger = logging.getLogger(__name__)
+
+
+class OutputFilter:
+    """
+    [AC-AISVC-82] Output filter for post-filtering LLM responses.
+
+    Features:
+    - Scans LLM output for forbidden words
+    - Applies mask/replace/block strategies
+    - Returns fallback reply for block strategy
+    - Records triggered words in metadata
+    """
+
+    DEFAULT_FALLBACK_REPLY = "抱歉,让我换个方式回答您"
+
+    def __init__(self, word_service: ForbiddenWordService):
+        self._word_service = word_service
+
+    async def filter(
+        self,
+        reply: str,
+        tenant_id: str,
+    ) -> GuardrailResult:
+        """
+        [AC-AISVC-82] Filter LLM output for forbidden words.
+
+        Args:
+            reply: LLM generated reply to filter
+            tenant_id: Tenant ID for isolation
+
+        Returns:
+            GuardrailResult with filtered reply and trigger info
+        """
+        if not reply or not reply.strip():
+            return GuardrailResult(reply=reply)
+
+        words = await self._word_service.get_enabled_words_for_filtering(tenant_id)
+
+        if not words:
+            return GuardrailResult(reply=reply)
+
+        triggered_words: list[str] = []
+        triggered_categories: list[str] = []
+        filtered_reply = reply
+        blocked = False
+        fallback_reply = self.DEFAULT_FALLBACK_REPLY
+
+        for word in words:
+            if word.word in filtered_reply:
+                triggered_words.append(word.word)
+                if word.category not in triggered_categories:
+                    triggered_categories.append(word.category)
+
+                if word.strategy == ForbiddenWordStrategy.BLOCK.value:
+                    blocked = True
+                    fallback_reply = word.fallback_reply or self.DEFAULT_FALLBACK_REPLY
+                    logger.warning(
+                        f"[AC-AISVC-82] Output blocked by forbidden word: tenant={tenant_id}, "
+                        f"word={word.word}, category={word.category}"
+                    )
+                    break
+
+                elif word.strategy == ForbiddenWordStrategy.MASK.value:
+                    filtered_reply = filtered_reply.replace(word.word, "*" * len(word.word))
+                    logger.info(
+                        f"[AC-AISVC-82] Output masked: tenant={tenant_id}, word={word.word}"
+                    )
+
+                elif word.strategy == ForbiddenWordStrategy.REPLACE.value:
+                    replacement = word.replacement or ""
+                    filtered_reply = filtered_reply.replace(word.word, replacement)
+                    logger.info(
+                        f"[AC-AISVC-82] Output replaced: tenant={tenant_id}, "
+                        f"word={word.word} -> {replacement}"
+                    )
+
+        if blocked:
+            return GuardrailResult(
+                reply=fallback_reply,
+                blocked=True,
+                triggered_words=triggered_words,
+                triggered_categories=triggered_categories,
+            )
+
+        if triggered_words:
+            logger.info(
+                f"[AC-AISVC-82] Output filtered: tenant={tenant_id}, "
+                f"triggered_words={triggered_words}, categories={triggered_categories}"
+            )
+
+            for word_entity in self._get_triggered_word_entities(words, triggered_words):
+                try:
+                    await self._word_service.increment_hit_count(tenant_id, word_entity.id)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to increment hit count for word {word_entity.id}: {e}"
+                    )
+
+        return GuardrailResult(
+            reply=filtered_reply,
+            blocked=False,
+            triggered_words=triggered_words,
+            triggered_categories=triggered_categories,
+        )
+
+    def _get_triggered_word_entities(
+        self,
+        words: list[ForbiddenWord],
+        triggered_words: list[str],
+    ) -> list[ForbiddenWord]:
+        """Get word entities for triggered words."""
+        return [w for w in words if w.word in triggered_words]
+
+    async def filter_and_enrich_metadata(
+        self,
+        reply: str,
+        tenant_id: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> tuple[str, dict[str, Any]]:
+        """
+        [AC-AISVC-82] Filter output and enrich metadata with filter result.
+
+        Args:
+            reply: LLM generated reply to filter
+            tenant_id: Tenant ID for isolation
+            metadata: Existing metadata dict to enrich
+
+        Returns:
+            Tuple of (filtered_reply, enriched_metadata)
+        """
+        result = await self.filter(reply, tenant_id)
+
+        if metadata is None:
+            metadata = {}
+
+        metadata.update(result.to_dict())
+
+        return result.reply, metadata
@@ -0,0 +1,366 @@
+"""
+Streaming guardrail for AI Service.
+[AC-AISVC-82] Streaming mode forbidden word detection with sliding window buffer.
+"""
+
+import logging
+from collections.abc import AsyncIterator
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any
+
+from app.models.entities import (
+    ForbiddenWord,
+    ForbiddenWordStrategy,
+)
+from app.services.guardrail.word_service import ForbiddenWordService
+
+logger = logging.getLogger(__name__)
+
+
+class StreamingGuardrailState(str, Enum):
+    """State of streaming guardrail."""
+    ACTIVE = "active"
+    BLOCKED = "blocked"
+    COMPLETED = "completed"
+
+
+@dataclass
+class StreamingGuardrailResult:
+    """Result from streaming guardrail processing."""
+    delta: str
+    should_stop: bool = False
+    fallback_reply: str | None = None
+    triggered_words: list[str] = field(default_factory=list)
+    triggered_categories: list[str] = field(default_factory=list)
+
+
+class StreamingGuardrail:
+    """
+    [AC-AISVC-82] Streaming guardrail with sliding window buffer.
+
+    Features:
+    - Maintains a sliding window buffer for forbidden word detection
+    - Detects forbidden words across chunk boundaries
+    - Applies mask/replace strategies incrementally
+    - Stops streaming and returns fallback for block strategy
+    - Final check at stream end
+    """
+
+    DEFAULT_FALLBACK_REPLY = "抱歉,让我换个方式回答您"
+    DEFAULT_WINDOW_SIZE = 50
+
+    def __init__(
+        self,
+        word_service: ForbiddenWordService,
+        window_size: int = DEFAULT_WINDOW_SIZE,
+    ):
+        self._word_service = word_service
+        self._window_size = window_size
+        self._buffer = ""
+        self._state = StreamingGuardrailState.ACTIVE
+        self._triggered_words: list[str] = []
+        self._triggered_categories: list[str] = []
+        self._fallback_reply = self.DEFAULT_FALLBACK_REPLY
+        self._words_cache: list[ForbiddenWord] | None = None
+        self._max_word_length = 0
+
+    async def initialize(self, tenant_id: str) -> None:
+        """
+        Initialize the guardrail with tenant's forbidden words.
+        Must be called before processing chunks.
+        """
+        self._words_cache = await self._word_service.get_enabled_words_for_filtering(tenant_id)
+
+        if self._words_cache:
+            self._max_word_length = max(len(w.word) for w in self._words_cache)
+            effective_window = max(self._window_size, self._max_word_length)
+            if effective_window > self._window_size:
+                logger.info(
+                    f"[AC-AISVC-82] Adjusted window size from {self._window_size} to {effective_window} "
+                    f"to accommodate max word length {self._max_word_length}"
+                )
+                self._window_size = effective_window
+
+        logger.debug(
+            f"[AC-AISVC-82] StreamingGuardrail initialized: "
+            f"words_count={len(self._words_cache) if self._words_cache else 0}, "
+            f"window_size={self._window_size}"
+        )
+
+    async def process_chunk(
+        self,
+        delta: str,
+        tenant_id: str,
+    ) -> StreamingGuardrailResult:
+        """
+        Process a streaming chunk with sliding window detection.
+
+        Args:
+            delta: New text chunk from LLM
+            tenant_id: Tenant ID for isolation
+
+        Returns:
+            StreamingGuardrailResult with processed delta and state
+        """
+        if self._state == StreamingGuardrailState.BLOCKED:
+            return StreamingGuardrailResult(
+                delta="",
+                should_stop=True,
+                fallback_reply=self._fallback_reply,
+            )
+
+        if self._state == StreamingGuardrailState.COMPLETED:
+            return StreamingGuardrailResult(delta="")
+
+        if not self._words_cache:
+            await self.initialize(tenant_id)
+
+        if not self._words_cache:
+            return StreamingGuardrailResult(delta=delta)
+
+        self._buffer += delta
+
+        result = self._check_buffer()
+
+        if result.should_stop:
+            self._state = StreamingGuardrailState.BLOCKED
+            logger.warning(
+                f"[AC-AISVC-82] Streaming blocked: tenant={tenant_id}, "
+                f"triggered_words={result.triggered_words}"
+            )
+
+        return result
+
+    def _check_buffer(self) -> StreamingGuardrailResult:
+        """
+        Check buffer for forbidden words and apply strategies.
+        """
+        triggered_words: list[str] = []
+        triggered_categories: list[str] = []
+        filtered_buffer = self._buffer
+        blocked = False
+        fallback_reply = self.DEFAULT_FALLBACK_REPLY
+
+        if not self._words_cache:
+            return StreamingGuardrailResult(delta=self._buffer)
+
+        for word in self._words_cache:
+            if word.word in filtered_buffer:
+                triggered_words.append(word.word)
+                if word.category not in triggered_categories:
+                    triggered_categories.append(word.category)
+
+                if word.strategy == ForbiddenWordStrategy.BLOCK.value:
+                    blocked = True
+                    fallback_reply = word.fallback_reply or self.DEFAULT_FALLBACK_REPLY
+                    break
+
+                elif word.strategy == ForbiddenWordStrategy.MASK.value:
+                    filtered_buffer = filtered_buffer.replace(word.word, "*" * len(word.word))
+
+                elif word.strategy == ForbiddenWordStrategy.REPLACE.value:
+                    replacement = word.replacement or ""
+                    filtered_buffer = filtered_buffer.replace(word.word, replacement)
+
+        self._triggered_words.extend([w for w in triggered_words if w not in self._triggered_words])
+        self._triggered_categories.extend([c for c in triggered_categories if c not in self._triggered_categories])
+
+        if blocked:
+            return StreamingGuardrailResult(
+                delta="",
+                should_stop=True,
+                fallback_reply=fallback_reply,
+                triggered_words=triggered_words,
+                triggered_categories=triggered_categories,
+            )
+
+        safe_output, remaining = self._split_buffer(filtered_buffer)
+
+        self._buffer = remaining
+
+        return StreamingGuardrailResult(
+            delta=safe_output,
+            should_stop=False,
+            triggered_words=triggered_words,
+            triggered_categories=triggered_categories,
+        )
+
def _split_buffer(self, buffer: str) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Split buffer into safe output and remaining buffer.
|
||||||
|
|
||||||
|
Safe output: characters that are definitely safe (before window boundary)
|
||||||
|
Remaining: characters that might be part of a forbidden word
|
||||||
|
"""
|
||||||
|
if len(buffer) <= self._window_size:
|
||||||
|
return "", buffer
|
||||||
|
|
||||||
|
safe_length = len(buffer) - self._window_size
|
||||||
|
safe_output = buffer[:safe_length]
|
||||||
|
remaining = buffer[safe_length:]
|
||||||
|
|
||||||
|
return safe_output, remaining
|
||||||
|
|
||||||
|
async def finalize(self, tenant_id: str) -> StreamingGuardrailResult:
|
||||||
|
"""
|
||||||
|
Finalize the stream and process remaining buffer.
|
||||||
|
Must be called at the end of streaming.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant ID for isolation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StreamingGuardrailResult with final delta
|
||||||
|
"""
|
||||||
|
if self._state == StreamingGuardrailState.BLOCKED:
|
||||||
|
self._state = StreamingGuardrailState.COMPLETED
|
||||||
|
return StreamingGuardrailResult(
|
||||||
|
delta="",
|
||||||
|
should_stop=True,
|
||||||
|
fallback_reply=self._fallback_reply,
|
||||||
|
triggered_words=self._triggered_words,
|
||||||
|
triggered_categories=self._triggered_categories,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._words_cache:
|
||||||
|
await self.initialize(tenant_id)
|
||||||
|
|
||||||
|
result = self._final_check()
|
||||||
|
|
||||||
|
self._state = StreamingGuardrailState.COMPLETED
|
||||||
|
|
||||||
|
if self._triggered_words:
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-82] Streaming finalized with triggers: tenant={tenant_id}, "
|
||||||
|
f"triggered_words={self._triggered_words}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for word_entity in self._get_triggered_word_entities():
|
||||||
|
try:
|
||||||
|
await self._word_service.increment_hit_count(tenant_id, word_entity.id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to increment hit count for word {word_entity.id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _final_check(self) -> StreamingGuardrailResult:
|
||||||
|
"""
|
||||||
|
Final check of remaining buffer.
|
||||||
|
"""
|
||||||
|
if not self._buffer:
|
||||||
|
return StreamingGuardrailResult(
|
||||||
|
delta="",
|
||||||
|
triggered_words=self._triggered_words,
|
||||||
|
triggered_categories=self._triggered_categories,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._words_cache:
|
||||||
|
return StreamingGuardrailResult(
|
||||||
|
delta=self._buffer,
|
||||||
|
triggered_words=self._triggered_words,
|
||||||
|
triggered_categories=self._triggered_categories,
|
||||||
|
)
|
||||||
|
|
||||||
|
filtered_buffer = self._buffer
|
||||||
|
blocked = False
|
||||||
|
fallback_reply = self.DEFAULT_FALLBACK_REPLY
|
||||||
|
|
||||||
|
for word in self._words_cache:
|
||||||
|
if word.word in filtered_buffer:
|
||||||
|
if word.word not in self._triggered_words:
|
||||||
|
self._triggered_words.append(word.word)
|
||||||
|
if word.category not in self._triggered_categories:
|
||||||
|
self._triggered_categories.append(word.category)
|
||||||
|
|
||||||
|
if word.strategy == ForbiddenWordStrategy.BLOCK.value:
|
||||||
|
blocked = True
|
||||||
|
fallback_reply = word.fallback_reply or self.DEFAULT_FALLBACK_REPLY
|
||||||
|
break
|
||||||
|
|
||||||
|
elif word.strategy == ForbiddenWordStrategy.MASK.value:
|
||||||
|
filtered_buffer = filtered_buffer.replace(word.word, "*" * len(word.word))
|
||||||
|
|
||||||
|
elif word.strategy == ForbiddenWordStrategy.REPLACE.value:
|
||||||
|
replacement = word.replacement or ""
|
||||||
|
filtered_buffer = filtered_buffer.replace(word.word, replacement)
|
||||||
|
|
||||||
|
if blocked:
|
||||||
|
return StreamingGuardrailResult(
|
||||||
|
delta="",
|
||||||
|
should_stop=True,
|
||||||
|
fallback_reply=fallback_reply,
|
||||||
|
triggered_words=self._triggered_words,
|
||||||
|
triggered_categories=self._triggered_categories,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._buffer = ""
|
||||||
|
|
||||||
|
return StreamingGuardrailResult(
|
||||||
|
delta=filtered_buffer,
|
||||||
|
triggered_words=self._triggered_words,
|
||||||
|
triggered_categories=self._triggered_categories,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_triggered_word_entities(self) -> list[ForbiddenWord]:
|
||||||
|
"""Get word entities for triggered words."""
|
||||||
|
if not self._words_cache:
|
||||||
|
return []
|
||||||
|
return [w for w in self._words_cache if w.word in self._triggered_words]
|
||||||
|
|
||||||
|
def get_triggered_info(self) -> dict[str, Any]:
|
||||||
|
"""Get triggered words and categories info."""
|
||||||
|
return {
|
||||||
|
"triggered_words": self._triggered_words,
|
||||||
|
"triggered_categories": self._triggered_categories,
|
||||||
|
"guardrail_triggered": len(self._triggered_words) > 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset the guardrail state for reuse."""
|
||||||
|
self._buffer = ""
|
||||||
|
self._state = StreamingGuardrailState.ACTIVE
|
||||||
|
self._triggered_words = []
|
||||||
|
self._triggered_categories = []
|
||||||
|
self._fallback_reply = self.DEFAULT_FALLBACK_REPLY
|
||||||
|
self._words_cache = None
|
||||||
|
self._max_word_length = 0
|
||||||
|
|
||||||
|
|
||||||
|
async def wrap_stream_with_guardrail(
|
||||||
|
stream: AsyncIterator[str],
|
||||||
|
guardrail: StreamingGuardrail,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> AsyncIterator[tuple[str, bool, str | None]]:
|
||||||
|
"""
|
||||||
|
Wrap an async stream with guardrail processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream: Original LLM output stream
|
||||||
|
guardrail: StreamingGuardrail instance
|
||||||
|
tenant_id: Tenant ID for isolation
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Tuple of (delta, should_stop, fallback_reply)
|
||||||
|
"""
|
||||||
|
await guardrail.initialize(tenant_id)
|
||||||
|
|
||||||
|
async for delta in stream:
|
||||||
|
result = await guardrail.process_chunk(delta, tenant_id)
|
||||||
|
|
||||||
|
if result.delta:
|
||||||
|
yield (result.delta, False, None)
|
||||||
|
|
||||||
|
if result.should_stop:
|
||||||
|
yield ("", True, result.fallback_reply)
|
||||||
|
return
|
||||||
|
|
||||||
|
final_result = await guardrail.finalize(tenant_id)
|
||||||
|
|
||||||
|
if final_result.delta:
|
||||||
|
yield (final_result.delta, False, None)
|
||||||
|
|
||||||
|
if final_result.should_stop:
|
||||||
|
yield ("", True, final_result.fallback_reply)
|
||||||
|
|
@@ -0,0 +1,297 @@
|
||||||
|
"""
|
||||||
|
Forbidden word service for AI Service.
|
||||||
|
[AC-AISVC-78~AC-AISVC-81] Forbidden word CRUD and hit statistics management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlmodel import col
|
||||||
|
|
||||||
|
from app.models.entities import (
|
||||||
|
ForbiddenWord,
|
||||||
|
ForbiddenWordCategory,
|
||||||
|
ForbiddenWordCreate,
|
||||||
|
ForbiddenWordStrategy,
|
||||||
|
ForbiddenWordUpdate,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
WORD_CACHE_TTL_SECONDS = 60
|
||||||
|
|
||||||
|
|
||||||
|
class WordCache:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-82] In-memory cache for forbidden words.
|
||||||
|
Key: tenant_id
|
||||||
|
Value: (words_list, cached_at)
|
||||||
|
TTL: 60 seconds
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ttl_seconds: int = WORD_CACHE_TTL_SECONDS):
|
||||||
|
self._cache: dict[str, tuple[list[ForbiddenWord], float]] = {}
|
||||||
|
self._ttl = ttl_seconds
|
||||||
|
|
||||||
|
def get(self, tenant_id: str) -> list[ForbiddenWord] | None:
|
||||||
|
"""Get cached words if not expired."""
|
||||||
|
if tenant_id in self._cache:
|
||||||
|
words, cached_at = self._cache[tenant_id]
|
||||||
|
if time.time() - cached_at < self._ttl:
|
||||||
|
return words
|
||||||
|
else:
|
||||||
|
del self._cache[tenant_id]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, tenant_id: str, words: list[ForbiddenWord]) -> None:
|
||||||
|
"""Cache words for a tenant."""
|
||||||
|
self._cache[tenant_id] = (words, time.time())
|
||||||
|
|
||||||
|
def invalidate(self, tenant_id: str) -> None:
|
||||||
|
"""Invalidate cache for a tenant."""
|
||||||
|
if tenant_id in self._cache:
|
||||||
|
del self._cache[tenant_id]
|
||||||
|
logger.debug(f"Invalidated word cache for tenant={tenant_id}")
|
||||||
|
|
||||||
|
|
||||||
|
_word_cache = WordCache()
|
||||||
|
|
||||||
|
|
||||||
|
class ForbiddenWordService:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-78~AC-AISVC-81] Service for managing forbidden words.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Word CRUD with tenant isolation
|
||||||
|
- Hit count statistics
|
||||||
|
- In-memory caching with TTL
|
||||||
|
- Cache invalidation on CRUD operations
|
||||||
|
- Support for mask/replace/block strategies
|
||||||
|
"""
|
||||||
|
|
||||||
|
VALID_CATEGORIES = [c.value for c in ForbiddenWordCategory]
|
||||||
|
VALID_STRATEGIES = [s.value for s in ForbiddenWordStrategy]
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self._session = session
|
||||||
|
self._cache = _word_cache
|
||||||
|
|
||||||
|
async def create_word(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
create_data: ForbiddenWordCreate,
|
||||||
|
) -> ForbiddenWord:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-78] Create a new forbidden word.
|
||||||
|
"""
|
||||||
|
if create_data.category not in self.VALID_CATEGORIES:
|
||||||
|
raise ValueError(f"Invalid category. Must be one of: {self.VALID_CATEGORIES}")
|
||||||
|
|
||||||
|
if create_data.strategy not in self.VALID_STRATEGIES:
|
||||||
|
raise ValueError(f"Invalid strategy. Must be one of: {self.VALID_STRATEGIES}")
|
||||||
|
|
||||||
|
if create_data.strategy == ForbiddenWordStrategy.REPLACE.value and not create_data.replacement:
|
||||||
|
raise ValueError("replacement is required when strategy is 'replace'")
|
||||||
|
|
||||||
|
if create_data.strategy == ForbiddenWordStrategy.BLOCK.value and not create_data.fallback_reply:
|
||||||
|
logger.warning(
|
||||||
|
f"[AC-AISVC-78] Creating block word without fallback_reply: tenant={tenant_id}, word={create_data.word}"
|
||||||
|
)
|
||||||
|
|
||||||
|
word = ForbiddenWord(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
word=create_data.word,
|
||||||
|
category=create_data.category,
|
||||||
|
strategy=create_data.strategy,
|
||||||
|
replacement=create_data.replacement,
|
||||||
|
fallback_reply=create_data.fallback_reply,
|
||||||
|
is_enabled=True,
|
||||||
|
hit_count=0,
|
||||||
|
)
|
||||||
|
self._session.add(word)
|
||||||
|
await self._session.flush()
|
||||||
|
|
||||||
|
self._cache.invalidate(tenant_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-78] Created forbidden word: tenant={tenant_id}, "
|
||||||
|
f"id={word.id}, word={word.word}, strategy={word.strategy}"
|
||||||
|
)
|
||||||
|
return word
|
||||||
|
|
||||||
|
async def list_words(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
category: str | None = None,
|
||||||
|
is_enabled: bool | None = None,
|
||||||
|
) -> Sequence[ForbiddenWord]:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-79] List words for a tenant with optional filters.
|
||||||
|
"""
|
||||||
|
stmt = select(ForbiddenWord).where(
|
||||||
|
ForbiddenWord.tenant_id == tenant_id # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
if category is not None:
|
||||||
|
stmt = stmt.where(ForbiddenWord.category == category) # type: ignore
|
||||||
|
|
||||||
|
if is_enabled is not None:
|
||||||
|
stmt = stmt.where(ForbiddenWord.is_enabled == is_enabled) # type: ignore
|
||||||
|
|
||||||
|
stmt = stmt.order_by(col(ForbiddenWord.category), col(ForbiddenWord.created_at).desc())
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_word(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
word_id: uuid.UUID,
|
||||||
|
) -> ForbiddenWord | None:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-79] Get word by ID with tenant isolation.
|
||||||
|
"""
|
||||||
|
stmt = select(ForbiddenWord).where(
|
||||||
|
ForbiddenWord.tenant_id == tenant_id, # type: ignore
|
||||||
|
ForbiddenWord.id == word_id, # type: ignore
|
||||||
|
)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def update_word(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
word_id: uuid.UUID,
|
||||||
|
update_data: ForbiddenWordUpdate,
|
||||||
|
) -> ForbiddenWord | None:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-80] Update a forbidden word.
|
||||||
|
"""
|
||||||
|
word = await self.get_word(tenant_id, word_id)
|
||||||
|
if not word:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if update_data.word is not None:
|
||||||
|
word.word = update_data.word
|
||||||
|
if update_data.category is not None:
|
||||||
|
if update_data.category not in self.VALID_CATEGORIES:
|
||||||
|
raise ValueError(f"Invalid category. Must be one of: {self.VALID_CATEGORIES}")
|
||||||
|
word.category = update_data.category
|
||||||
|
if update_data.strategy is not None:
|
||||||
|
if update_data.strategy not in self.VALID_STRATEGIES:
|
||||||
|
raise ValueError(f"Invalid strategy. Must be one of: {self.VALID_STRATEGIES}")
|
||||||
|
word.strategy = update_data.strategy
|
||||||
|
if update_data.replacement is not None:
|
||||||
|
word.replacement = update_data.replacement
|
||||||
|
if update_data.fallback_reply is not None:
|
||||||
|
word.fallback_reply = update_data.fallback_reply
|
||||||
|
if update_data.is_enabled is not None:
|
||||||
|
word.is_enabled = update_data.is_enabled
|
||||||
|
|
||||||
|
word.updated_at = datetime.utcnow()
|
||||||
|
await self._session.flush()
|
||||||
|
|
||||||
|
self._cache.invalidate(tenant_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-80] Updated forbidden word: tenant={tenant_id}, id={word_id}"
|
||||||
|
)
|
||||||
|
return word
|
||||||
|
|
||||||
|
async def delete_word(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
word_id: uuid.UUID,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-81] Delete a forbidden word.
|
||||||
|
"""
|
||||||
|
word = await self.get_word(tenant_id, word_id)
|
||||||
|
if not word:
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self._session.delete(word)
|
||||||
|
await self._session.flush()
|
||||||
|
|
||||||
|
self._cache.invalidate(tenant_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-81] Deleted forbidden word: tenant={tenant_id}, id={word_id}"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def increment_hit_count(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
word_id: uuid.UUID,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-79] Increment hit count for a word.
|
||||||
|
"""
|
||||||
|
word = await self.get_word(tenant_id, word_id)
|
||||||
|
if not word:
|
||||||
|
return False
|
||||||
|
|
||||||
|
word.hit_count += 1
|
||||||
|
word.updated_at = datetime.utcnow()
|
||||||
|
await self._session.flush()
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"[AC-AISVC-79] Incremented hit count for word: tenant={tenant_id}, "
|
||||||
|
f"id={word_id}, hit_count={word.hit_count}"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_enabled_words_for_filtering(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> list[ForbiddenWord]:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-82] Get enabled words for filtering.
|
||||||
|
Uses cache for performance.
|
||||||
|
"""
|
||||||
|
cached = self._cache.get(tenant_id)
|
||||||
|
if cached is not None:
|
||||||
|
logger.debug(f"[AC-AISVC-82] Cache hit for words: tenant={tenant_id}")
|
||||||
|
return cached
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(ForbiddenWord)
|
||||||
|
.where(
|
||||||
|
ForbiddenWord.tenant_id == tenant_id, # type: ignore
|
||||||
|
ForbiddenWord.is_enabled == True, # type: ignore
|
||||||
|
)
|
||||||
|
.order_by(col(ForbiddenWord.category))
|
||||||
|
)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
words = list(result.scalars().all())
|
||||||
|
|
||||||
|
self._cache.set(tenant_id, words)
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-82] Loaded {len(words)} enabled words from DB: tenant={tenant_id}"
|
||||||
|
)
|
||||||
|
return words
|
||||||
|
|
||||||
|
def invalidate_cache(self, tenant_id: str) -> None:
|
||||||
|
"""Manually invalidate cache for a tenant."""
|
||||||
|
self._cache.invalidate(tenant_id)
|
||||||
|
|
||||||
|
async def word_to_info_dict(self, word: ForbiddenWord) -> dict[str, Any]:
|
||||||
|
"""Convert word entity to API response dict."""
|
||||||
|
return {
|
||||||
|
"id": str(word.id),
|
||||||
|
"word": word.word,
|
||||||
|
"category": word.category,
|
||||||
|
"strategy": word.strategy,
|
||||||
|
"replacement": word.replacement,
|
||||||
|
"fallback_reply": word.fallback_reply,
|
||||||
|
"is_enabled": word.is_enabled,
|
||||||
|
"hit_count": word.hit_count,
|
||||||
|
"created_at": word.created_at.isoformat(),
|
||||||
|
"updated_at": word.updated_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
@@ -6,7 +6,7 @@
|
||||||
|
|
||||||
- module: `ai-service`
|
- module: `ai-service`
|
||||||
- feature: `AISVC` (Python AI middle platform)
|
- feature: `AISVC` (Python AI middle platform)
|
||||||
- status: 🔄 In progress (Phase 12)
|
- status: 🔄 In progress (Phase 14 complete)
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
@@ -35,27 +35,29 @@
|
||||||
- [x] Phase 10: Prompt templating (80%) 🔄 (T10.1-T10.8 done; T10.9-T10.10 deferred to the integration phase)
|
- [x] Phase 10: Prompt templating (80%) 🔄 (T10.1-T10.8 done; T10.9-T10.10 deferred to the integration phase)
|
||||||
- [x] Phase 11: Multi-knowledge-base management (63%) 🔄 (T11.1-T11.5 done; T11.6-T11.8 deferred to the integration phase)
|
- [x] Phase 11: Multi-knowledge-base management (63%) 🔄 (T11.1-T11.5 done; T11.6-T11.8 deferred to the integration phase)
|
||||||
- [x] Phase 12: Intent recognition and rule engine (71%) 🔄 (T12.1-T12.5 done; T12.6-T12.7 deferred to the integration phase)
|
- [x] Phase 12: Intent recognition and rule engine (71%) 🔄 (T12.1-T12.5 done; T12.6-T12.7 deferred to the integration phase)
|
||||||
|
- [ ] Phase 13: Script flow engine (0%) ⏳ Pending
|
||||||
|
- [x] Phase 14: Output guardrail (88%) ✅ (T14.1-T14.7 done; T14.8 unit tests deferred to the integration phase)
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🔄 Current Phase
|
## 🔄 Current Phase
|
||||||
|
|
||||||
### Goal
|
### Goal
|
||||||
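The core multi-knowledge-base management features of Phase 11 are complete (T11.1-T11.5); T11.6 (OptimizedRetriever multi-Collection retrieval), T11.7 (kb_default migration), and T11.8 (unit tests) are deferred to the integration phase.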
|
The core output guardrail features of Phase 14 are complete (T14.1-T14.7); T14.8 (unit tests) is deferred to the integration phase.
|
||||||
|
|
||||||
### Completed Tasks (Phase 11)
|
### Completed Tasks (Phase 14)
|
||||||
|
|
||||||
- [x] T11.1 Extend the `KnowledgeBase` entity: add the `kb_type`, `priority`, `is_enabled`, and `doc_count` fields `[AC-AISVC-59]` ✅
|
- [x] T14.1 Define the `ForbiddenWord` and `BehaviorRule` SQLModel entities and create the database tables `[AC-AISVC-78, AC-AISVC-84]` ✅
|
||||||
- [x] T11.2 Implement the knowledge base CRUD service: initialize the Qdrant Collection on create, clean up the Collection on delete `[AC-AISVC-59, AC-AISVC-61, AC-AISVC-62]` ✅
|
- [x] T14.2 Implement `ForbiddenWordService`: forbidden word CRUD + hit statistics `[AC-AISVC-78, AC-AISVC-79, AC-AISVC-80, AC-AISVC-81]` ✅
|
||||||
- [x] T11.3 Implement the knowledge base management API: `POST/GET/PUT/DELETE /admin/kb/knowledge-bases` `[AC-AISVC-59, AC-AISVC-60, AC-AISVC-61, AC-AISVC-62]` ✅
|
- [x] T14.3 Implement `BehaviorRuleService`: behavior rule CRUD `[AC-AISVC-84, AC-AISVC-85]` ✅
|
||||||
- [x] T11.4 Upgrade Qdrant Collection naming to `kb_{tenant_id}_{kb_id}`, staying compatible with the existing `kb_{tenant_id}` `[AC-AISVC-63]` ✅
|
- [x] T14.4 Implement `InputScanner`: pre-check of user input for forbidden words (log only, no blocking) `[AC-AISVC-83]` ✅
|
||||||
- [x] T11.5 Modify the document upload flow: accept a `kbId` parameter and index into the corresponding Collection `[AC-AISVC-63]` ✅
|
- [x] T14.5 Implement `OutputFilter`: post-filtering of LLM output (three strategies: mask/replace/block) `[AC-AISVC-82]` ✅
|
||||||
|
- [x] T14.6 Implement sliding-window forbidden word detection for Streaming mode `[AC-AISVC-82]` ✅
|
||||||
|
- [x] T14.7 Implement the guardrail management API: `/admin/guardrails` endpoints `[AC-AISVC-78~AC-AISVC-85]` ✅
|
||||||
|
|
||||||
### Pending Tasks (Phase 11 - integration phase)
|
### Pending Tasks (Phase 14 - integration phase)
|
||||||
|
|
||||||
- [ ] T11.6 Modify `OptimizedRetriever`: support a `target_kb_ids` parameter and implement parallel retrieval across multiple Collections `[AC-AISVC-64]`
|
- [ ] T14.8 Write unit tests for the output guardrail services `[AC-AISVC-78~AC-AISVC-85]`
|
||||||
- [ ] T11.7 Implement automatic `kb_default` migration: on first startup, create a default knowledge base record for existing data `[AC-AISVC-59]`
|
|
||||||
- [ ] T11.8 Write unit tests for the multi-knowledge-base services `[AC-AISVC-59~AC-AISVC-64]`
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
@@ -66,42 +68,73 @@ The core multi-knowledge-base management features of Phase 11 are complete (T11.1-T11.5); T11.6 (Optimiz
|
||||||
- `ai-service/`
|
- `ai-service/`
|
||||||
- `app/`
|
- `app/`
|
||||||
- `api/` - FastAPI routing layer
|
- `api/` - FastAPI routing layer
|
||||||
- `admin/intent_rules.py` - Intent rule management API ✅
|
- `admin/guardrails.py` - Guardrail management API ✅
|
||||||
- `admin/prompt_templates.py` - Prompt template management API ✅
|
|
||||||
- `models/` - Pydantic models and SQLModel entities
|
- `models/` - Pydantic models and SQLModel entities
|
||||||
- `entities.py` - IntentRule, PromptTemplate, PromptTemplateVersion entities ✅
|
- `entities.py` - ForbiddenWord, BehaviorRule, GuardrailResult, InputScanResult entities ✅
|
||||||
- `services/`
|
- `services/`
|
||||||
- `intent/` - Intent recognition service ✅
|
- `guardrail/` - Output guardrail service ✅
|
||||||
- `__init__.py` - Module exports
|
- `__init__.py` - Module exports
|
||||||
- `rule_service.py` - Rule CRUD, hit statistics, caching
|
- `word_service.py` - Forbidden word CRUD, hit statistics, caching
|
||||||
- `router.py` - IntentRouter matching engine
|
- `behavior_service.py` - Behavior rule CRUD, caching, Prompt injection formatting
|
||||||
- `prompt/` - Prompt template service ✅
|
- `input_scanner.py` - Pre-check of user input (log only, no blocking)
|
||||||
- `__init__.py` - Module exports
|
- `output_filter.py` - Post-filtering of LLM output (mask/replace/block)
|
||||||
- `template_service.py` - Template CRUD, version management, publish/rollback, caching
|
- `streaming_filter.py` - Streaming sliding-window detection
|
||||||
- `variable_resolver.py` - Variable substitution engine
|
|
||||||
|
|
||||||
### Key Decisions (Why / Impact)
|
### Key Decisions (Why / Impact)
|
||||||
|
|
||||||
- decision: Database-driven intent rules
|
- decision: Three forbidden word handling strategies
|
||||||
reason: Supports dynamic configuration of intent recognition rules without restarting the service
|
reason: Meets content compliance needs across different scenarios
|
||||||
impact: Rules are stored in PostgreSQL with per-tenant isolation
|
impact: mask replaces the word with asterisks, replace substitutes the specified text, block intercepts the whole reply and returns a fallback reply
|
||||||
|
|
||||||
- decision: Dual matching with keywords + regex
|
- decision: Input detection does not block
|
||||||
reason: Keyword matching is fast and efficient; regex matching supports complex patterns
|
reason: User input containing forbidden words must still be processed normally; hits are only logged for monitoring and analysis
|
||||||
impact: Keyword matching runs before regex matching; higher-priority rules are matched first
|
impact: InputScanner returns a flagged status and match details and does not raise an exception
|
||||||
|
|
||||||
|
- decision: Streaming sliding-window detection
|
||||||
|
reason: Streaming output cannot know the full content in advance, so a buffer is needed to detect forbidden words that span chunks
|
||||||
|
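impact: Maintains a sliding-window buffer (50 characters by default, automatically adjusted to the length of the longest forbidden word) and stops immediately once a forbidden word is detected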
|
||||||
|
|
||||||
|
- decision: Behavior rules injected into the Prompt
|
||||||
|
reason: Behavior rules act as behavioral constraints on the LLM; no runtime detection is performed
|
||||||
|
impact: BehaviorRuleService provides a format_rules_for_prompt() method whose output is appended to the end of the system instruction
|
||||||
|
|
||||||
- decision: In-memory cache + TTL strategy
|
- decision: In-memory cache + TTL strategy
|
||||||
reason: Reduces database queries and improves matching performance
|
reason: Reduces database queries and improves filtering performance
|
||||||
impact: Cache TTL=60s; invalidated proactively on CRUD operations
|
impact: Cache TTL=60s; invalidated proactively on CRUD operations
|
||||||
|
|
||||||
- decision: Four response types
|
|
||||||
reason: Supports different processing pipelines
|
|
||||||
impact: fixed returns a reply directly, rag performs targeted retrieval, flow enters a script flow, transfer hands off to a human agent
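
The mask/replace/block behaviour listed in the decisions above can be sketched as follows. This is a minimal, self-contained illustration under stated assumptions, not the project's `OutputFilter`: the `Word` and `FilterResult` dataclasses, the `apply_strategies` function, and the `DEFAULT_FALLBACK` text are hypothetical names invented for the example.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Word:
    """Hypothetical stand-in for a forbidden word record."""
    word: str
    strategy: str  # "mask", "replace" or "block"
    replacement: str | None = None
    fallback_reply: str | None = None


@dataclass
class FilterResult:
    text: str
    blocked: bool = False
    triggered: list[str] = field(default_factory=list)


# Assumed default; the real fallback text is defined elsewhere in the service.
DEFAULT_FALLBACK = "Sorry, I cannot answer that."


def apply_strategies(text: str, words: list[Word]) -> FilterResult:
    """Apply each matching word's strategy to a complete reply."""
    triggered: list[str] = []
    for w in words:
        if w.word not in text:
            continue
        triggered.append(w.word)
        if w.strategy == "block":
            # block discards the whole reply and returns the fallback instead
            return FilterResult(w.fallback_reply or DEFAULT_FALLBACK, True, triggered)
        if w.strategy == "mask":
            text = text.replace(w.word, "*" * len(w.word))
        elif w.strategy == "replace":
            text = text.replace(w.word, w.replacement or "")
    return FilterResult(text, False, triggered)
```

In this sketch a single block hit overrides any mask or replace edits already made, which matches the "intercept the whole reply" description above.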
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🧾 Session History
|
## 🧾 Session History
|
||||||
|
|
||||||
|
### Session #10 (2026-02-27)
|
||||||
|
- completed:
|
||||||
|
- T14.1-T14.7 output guardrail core features
|
||||||
|
- Implemented the ForbiddenWord and BehaviorRule entities
|
||||||
|
- Implemented ForbiddenWordService (CRUD, hit statistics, caching)
|
||||||
|
- Implemented BehaviorRuleService (CRUD, caching, Prompt injection formatting)
|
||||||
|
- Implemented InputScanner (pre-check of user input; log only, no blocking)
|
||||||
|
- Implemented OutputFilter (post-filtering of LLM output; three strategies: mask/replace/block)
|
||||||
|
- Implemented StreamingGuardrail (Streaming sliding-window detection)
|
||||||
|
- Implemented the guardrail management API (forbidden word and behavior rule CRUD)
|
||||||
|
- changes:
|
||||||
|
- Added the ForbiddenWord, BehaviorRule, GuardrailResult, InputScanResult entities to `app/models/entities.py`
|
||||||
|
- Added `app/services/guardrail/__init__.py` module exports
|
||||||
|
- Added `app/services/guardrail/word_service.py` forbidden word service
|
||||||
|
- Added `app/services/guardrail/behavior_service.py` behavior rule service
|
||||||
|
- Added `app/services/guardrail/input_scanner.py` input scanner
|
||||||
|
- Added `app/services/guardrail/output_filter.py` output filter
|
||||||
|
- Added `app/services/guardrail/streaming_filter.py` Streaming filter
|
||||||
|
- Added `app/api/admin/guardrails.py` guardrail management API
|
||||||
|
- Updated `app/api/admin/__init__.py` to export the new router
|
||||||
|
- Updated `app/main.py` to register the new router
|
||||||
|
- Updated `spec/ai-service/tasks.md` to mark tasks complete
|
||||||
|
- notes:
|
||||||
|
- T14.8 (unit tests) deferred to the integration phase
|
||||||
|
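- Three forbidden word strategies: mask (asterisk replacement), replace (substitution with the specified text), block (intercept and return a fallback reply)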
|
||||||
|
- InputScanner only logs hits and does not block the request
|
||||||
|
- OutputFilter applies the mask/replace/block strategies
|
||||||
|
- StreamingGuardrail uses a sliding-window buffer (50 characters by default, adjusted automatically)
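
The sliding-window idea in the notes above can be illustrated with a small synchronous sketch. It is a simplified model under stated assumptions rather than the `StreamingGuardrail` class itself: `guard_stream` is a hypothetical helper, every forbidden word is treated as a block word, and the window is assumed to be the larger of `window_size` and the longest forbidden word.

```python
from collections.abc import Iterable, Iterator


def guard_stream(
    chunks: Iterable[str],
    forbidden: list[str],
    window_size: int = 50,
) -> Iterator[tuple[str, bool]]:
    """Yield (delta, should_stop) pairs, holding back the last `window` characters."""
    # Keep at least as many characters as the longest word so a word split
    # across two chunks is still fully inside the buffer when it completes.
    window = max([window_size, *(len(w) for w in forbidden)])
    buffer = ""
    for chunk in chunks:
        buffer += chunk
        if any(w in buffer for w in forbidden):
            # Nothing held in the buffer is emitted once a word is found.
            yield ("", True)
            return
        if len(buffer) > window:
            cut = len(buffer) - window
            yield (buffer[:cut], False)  # text before the window is safe to emit
            buffer = buffer[cut:]
    if buffer:
        # End of stream: the held-back tail was already checked above.
        yield (buffer, False)
```

For example, with `forbidden=["secret"]` and the chunks `"hello sec"` and `"ret world"`, only the prefix that has already left the window is emitted before the stop signal; the held-back tail containing the start of the word is never sent.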
|
||||||
|
|
||||||
### Session #9 (2026-02-27)
|
### Session #9 (2026-02-27)
|
||||||
- completed:
|
- completed:
|
||||||
- T11.1-T11.5 multi-knowledge-base management core features
|
- T11.1-T11.5 multi-knowledge-base management core features
|
||||||
|
|
|
||||||