Merge pull request '[AC-AISVC-50] Merge the first stable version' (#2) from feature/prompt-unification-and-logging into main

Reviewed-on: #2
This commit is contained in:
MerCry 2026-02-26 13:03:31 +00:00
commit 60bf649d96
45 changed files with 1888 additions and 1085 deletions

.env.example (new file, 21 lines)

@@ -0,0 +1,21 @@
# AI Service Environment Variables
# Copy this file to .env and modify as needed
# LLM Configuration (OpenAI)
AI_SERVICE_LLM_PROVIDER=openai
AI_SERVICE_LLM_API_KEY=your-api-key-here
AI_SERVICE_LLM_BASE_URL=https://api.openai.com/v1
AI_SERVICE_LLM_MODEL=gpt-4o-mini
# If using DeepSeek
# AI_SERVICE_LLM_PROVIDER=deepseek
# AI_SERVICE_LLM_API_KEY=your-deepseek-api-key
# AI_SERVICE_LLM_MODEL=deepseek-chat
# Ollama Configuration (for embedding model)
AI_SERVICE_OLLAMA_BASE_URL=http://ollama:11434
AI_SERVICE_OLLAMA_EMBEDDING_MODEL=nomic-embed-text
# Frontend API Key (required for admin panel authentication)
# Get this key from the backend logs after first startup, or from /admin/api-keys
VITE_APP_API_KEY=your-frontend-api-key-here

.gitignore (vendored, 1 line changed)

@@ -162,5 +162,6 @@ cython_debug/
 # Project specific
 ai-service/uploads/
+ai-service/config/
 *.local

README.md (295 changed lines)

@@ -1,3 +1,294 @@
# AI Robot Core
Capability support for the AI middle platform, providing intelligent customer service, RAG knowledge-base retrieval, LLM conversation, and other core capabilities.
## Project Structure
```
ai-robot-core/
├── ai-service/            # Python backend service
│   ├── app/               # FastAPI application
│   ├── tests/             # Test cases
│   ├── Dockerfile         # Backend image
│   └── pyproject.toml     # Python dependencies
├── ai-service-admin/      # Vue admin frontend
│   ├── src/               # Vue source code
│   ├── Dockerfile         # Frontend image
│   ├── nginx.conf         # Nginx configuration
│   └── package.json       # Node dependencies
├── docker-compose.yaml    # Container orchestration
├── .env.example           # Environment variable template
└── README.md
```
## Features
- **Multi-tenancy**: tenant isolation via the X-Tenant-Id header
- **RAG knowledge base**: retrieval-augmented generation backed by Qdrant
- **LLM integration**: supports OpenAI, DeepSeek, Ollama, and other LLM providers
- **SSE streaming**: real-time responses via Server-Sent Events
- **Confidence scoring**: automatically scores reply quality and suggests a human handoff when confidence is low
## Quick Start
### Requirements
- Docker 20.10+
- Docker Compose 2.0+
### Deployment
#### 1. Clone the repository
```bash
git clone http://49.232.209.156:3005/MerCry/ai-robot-core.git
cd ai-robot-core
```
#### 2. Configure environment variables
```bash
cp .env.example .env
```
Edit the `.env` file and configure the LLM API:
```env
# OpenAI configuration
AI_SERVICE_LLM_PROVIDER=openai
AI_SERVICE_LLM_API_KEY=your-openai-api-key
AI_SERVICE_LLM_BASE_URL=https://api.openai.com/v1
AI_SERVICE_LLM_MODEL=gpt-4o-mini
# Or use DeepSeek
# AI_SERVICE_LLM_PROVIDER=deepseek
# AI_SERVICE_LLM_API_KEY=your-deepseek-api-key
# AI_SERVICE_LLM_MODEL=deepseek-chat
```
#### 3. Start the services
```bash
# Docker Compose V2 (recommended; built into Docker)
docker compose up -d --build
# Or Docker Compose V1 (legacy; installed separately)
docker-compose up -d --build
```
#### 4. Pull the embedding model
After the services start, pull the embedding model inside the Ollama container. `nomic-embed-text-v2-moe` is recommended for its better Chinese support:
```bash
# Pull the model inside the Ollama container
docker exec -it ai-ollama ollama pull toshk0/nomic-embed-text-v2-moe:Q6_K
```
**Available models**
| Model | Dimensions | Notes |
|------|------|------|
| `toshk0/nomic-embed-text-v2-moe:Q6_K` | 768 | Recommended; good Chinese support; supports task prefixes |
| `nomic-embed-text:v1.5` | 768 | Original version; supports task prefixes and Matryoshka |
| `bge-large-zh` | 1024 | Chinese-specific; best quality for Chinese |
#### 5. Configure the embedding model
Open the admin frontend and go to the **Embedding Model Configuration** page:
1. Select the provider: **Nomic Embed (optimized)**
2. Configure the parameters:
   - **API URL**: `http://ollama:11434` (Docker) or `http://localhost:11434` (local development)
   - **Model name**: `toshk0/nomic-embed-text-v2-moe:Q6_K`
   - **Vector dimensions**: `768`
   - **Matryoshka truncation**: `true`
3. Click **Save**
> **Notes**:
> - The Nomic Embed (optimized) provider enables the full set of RAG optimizations: task prefixes, Matryoshka multi-vectors, and two-stage retrieval.
> - The embedding configuration is persisted to `ai-service/config/embedding_config.json` and reloaded automatically after a restart.
> - **Important**: after switching embedding models, delete the existing knowledge base and re-upload your documents, because vectors produced by different models are not compatible.
#### 6. Verify the services
```bash
# Check service status
docker ps
# Inspect the backend logs to find the auto-generated API key
docker logs -f ai-service | grep "Default API Key"
```
> **Important**: on first startup the backend generates a default API key. Copy it from the logs; it is needed for the frontend configuration.
#### 7. Configure the frontend API key
```bash
# Create the frontend environment file
cd ai-service-admin
cp .env.example .env
```
Edit `ai-service-admin/.env` and set `VITE_APP_API_KEY` to the API key from the backend logs:
```env
VITE_APP_BASE_API=/api
VITE_APP_API_KEY=<API key copied from the backend logs>
```
Then rebuild the frontend:
```bash
cd ..
docker compose up -d --build ai-service-admin
```
#### 8. Access the services
| Service | URL | Notes |
|------|------|------|
| Admin frontend | http://SERVER_IP:8181 | Vue admin console |
| Backend API | http://SERVER_IP:8182 | FastAPI service, called by the Java channel side |
| API docs | http://SERVER_IP:8182/docs | Swagger UI |
| Qdrant console | http://SERVER_IP:6333/dashboard | Vector database management |
| Ollama API | http://SERVER_IP:11434 | Embedding model service |
> **Ports**:
> - `8181`: admin frontend; proxies the backend API internally
> - `8182`: backend API, for direct calls from the Java channel side
## Architecture
```
┌─────────────────────────────────────────────────────────┐
│                      User access                        │
└─────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────┐
│              ai-service-admin (port 8181)               │
│  - Nginx static file serving                            │
│  - Reverse proxy /api/* → ai-service:8080               │
└─────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────┐
│                 ai-service (port 8080)                  │
│  - FastAPI backend                                      │
│  - RAG / LLM / knowledge-base management                │
└─────────────────────────────────────────────────────────┘
        │                  │                  │
        ▼                  ▼                  ▼
┌──────────────────┐ ┌──────────────────┐ ┌──────────────────┐
│   PostgreSQL     │ │     Qdrant       │ │     Ollama       │
│   (port 5432)    │ │   (port 6333)    │ │   (port 11434)   │
│ - session store  │ │ - vector store   │ │ - nomic-embed    │
│ - KB metadata    │ │ - doc indexes    │ │ - embedding model│
└──────────────────┘ └──────────────────┘ └──────────────────┘
```
## Common Commands
```bash
# Start all services
docker compose up -d
# Rebuild and start
docker compose up -d --build
# Check service status
docker compose ps
# View logs
docker compose logs -f ai-service
docker compose logs -f ai-service-admin
# Restart a service
docker compose restart ai-service
# Stop all services
docker compose down
# Stop services and delete data volumes (wipes all data)
docker compose down -v
```
## Host Nginx Configuration (optional)
To manage the entry point through the host's Nginx (custom domain, SSL certificates), see `deploy/nginx.conf.example`:
```bash
# Copy the configuration file
sudo cp deploy/nginx.conf.example /etc/nginx/conf.d/ai-service.conf
# Update the domain in the configuration
sudo vim /etc/nginx/conf.d/ai-service.conf
# Test the configuration
sudo nginx -t
# Reload Nginx
sudo nginx -s reload
```
## Local Development
### Backend
```bash
cd ai-service
# Create a virtual environment
python -m venv .venv
source .venv/bin/activate  # Linux/Mac
# .venv\Scripts\activate   # Windows
# Install dependencies
pip install -e ".[dev]"
# Start the development server
uvicorn app.main:app --reload --port 8000
```
### Frontend
```bash
cd ai-service-admin
# Install dependencies
npm install
# Start the development server
npm run dev
```
## API Endpoints
### Core endpoints
| Endpoint | Method | Description |
|------|------|------|
| `/ai/chat` | POST | AI chat |
| `/admin/kb` | GET/POST | Knowledge-base management |
| `/admin/rag/experiments/run` | POST | RAG lab |
| `/admin/llm/config` | GET/PUT | LLM configuration |
| `/admin/embedding/config` | GET/PUT | Embedding model configuration |
For full API documentation, visit http://SERVER_IP:8182/docs
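As a sketch of how a caller (e.g. the Java channel side) would invoke `/ai/chat`: both headers below are required by the backend middleware, while the request-body field names are illustrative assumptions, not confirmed by this repository.

```python
import json

def build_chat_request(tenant_id: str, api_key: str, question: str) -> tuple[dict, bytes]:
    """Build headers and a JSON body for POST /ai/chat.

    X-Tenant-Id must follow the name@ash@year format and X-API-Key must be
    a key issued via /admin/api-keys; the body field name is hypothetical.
    """
    headers = {
        "Content-Type": "application/json",
        "Accept": "text/event-stream",  # request SSE streaming
        "X-Tenant-Id": tenant_id,
        "X-API-Key": api_key,
    }
    body = json.dumps({"message": question}).encode("utf-8")
    return headers, body

headers, body = build_chat_request("szmp@ash@2026", "my-api-key", "hello")
```

Any HTTP client can then POST `body` with `headers` to `http://SERVER_IP:8182/ai/chat`.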
## Environment Variables
| Variable | Default | Description |
|--------|--------|------|
| `AI_SERVICE_LLM_PROVIDER` | openai | LLM provider |
| `AI_SERVICE_LLM_API_KEY` | - | API key |
| `AI_SERVICE_LLM_BASE_URL` | https://api.openai.com/v1 | API base URL |
| `AI_SERVICE_LLM_MODEL` | gpt-4o-mini | Model name |
| `AI_SERVICE_DATABASE_URL` | postgresql+asyncpg://... | Database connection |
| `AI_SERVICE_QDRANT_URL` | http://qdrant:6333 | Qdrant URL |
| `AI_SERVICE_LOG_LEVEL` | INFO | Log level |
## License
MIT

@@ -0,0 +1,19 @@
node_modules
dist
.env
.env.local
.env.*.local
*.log
.idea/
.vscode/
*.swp
*.swo
.git
.gitignore
*.md
!README.md

@@ -0,0 +1,8 @@
# API Base URL
VITE_APP_BASE_API=/api
# Default API Key for authentication
# IMPORTANT: You must set this to a valid API key from the backend
# The backend creates a default API key on first startup (check backend logs)
# Or you can create one via the API: POST /admin/api-keys
VITE_APP_API_KEY=your-api-key-here

@@ -0,0 +1,28 @@
# AI Service Admin Frontend Dockerfile
FROM docker.1ms.run/node:20-alpine AS builder
WORKDIR /app
ARG VITE_APP_API_KEY
ARG VITE_APP_BASE_API=/api
ENV VITE_APP_API_KEY=$VITE_APP_API_KEY
ENV VITE_APP_BASE_API=$VITE_APP_BASE_API
COPY package*.json ./
RUN npm install && npm install @rollup/rollup-linux-x64-musl --save-optional
COPY . .
RUN npm run build
FROM docker.1ms.run/nginx:alpine
COPY --from=builder /app/dist /usr/share/nginx/html
COPY nginx.conf /etc/nginx/conf.d/default.conf
EXPOSE 80
CMD ["nginx", "-g", "daemon off;"]

@@ -0,0 +1,29 @@
server {
listen 80;
server_name localhost;
root /usr/share/nginx/html;
index index.html;
location / {
try_files $uri $uri/ /index.html;
}
location /api/ {
proxy_pass http://ai-service:8080/;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection 'upgrade';
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_cache_bypass $http_upgrade;
proxy_read_timeout 300s;
proxy_connect_timeout 75s;
proxy_buffering off;
}
gzip on;
gzip_types text/plain text/css application/json application/javascript text/xml application/xml application/xml+rss text/javascript;
gzip_min_length 1000;
}

File diff suppressed because it is too large

@@ -17,8 +17,8 @@
   },
   "devDependencies": {
     "@vitejs/plugin-vue": "^5.0.4",
-    "typescript": "^5.2.2",
+    "typescript": "~5.6.0",
     "vite": "^5.1.4",
-    "vue-tsc": "^1.8.27"
+    "vue-tsc": "^2.1.0"
   }
 }

@@ -98,7 +98,6 @@ const isValidTenantId = (tenantId: string): boolean => {
 const fetchTenantList = async () => {
   loading.value = true
   try {
-    // ID
     if (!isValidTenantId(currentTenantId.value)) {
       console.warn('Invalid tenant ID format, resetting to default:', currentTenantId.value)
       currentTenantId.value = 'default@ash@2026'
@@ -108,7 +107,6 @@ const fetchTenantList = async () => {
     const response = await getTenantList()
     tenantList.value = response.tenants || []
-    //
     if (tenantList.value.length > 0 && !tenantList.value.find(t => t.id === currentTenantId.value)) {
       const firstTenant = tenantList.value[0]
       currentTenantId.value = firstTenant.id
@@ -117,8 +115,7 @@ const fetchTenantList = async () => {
   } catch (error) {
     ElMessage.error('获取租户列表失败')
     console.error('Failed to fetch tenant list:', error)
-    // 使
-    tenantList.value = [{ id: 'default@ash@2026', name: 'default (2026)', displayName: 'default', year: '2026', createdAt: new Date().toISOString() }]
+    tenantList.value = [{ id: 'default@ash@2026', name: 'default (2026)' }]
   } finally {
     loading.value = false
   }

@@ -13,7 +13,7 @@ export interface TenantListResponse {
   total: number
 }
-export function getTenantList() {
+export function getTenantList(): Promise<TenantListResponse> {
   return request<TenantListResponse>({
     url: '/admin/tenants',
     method: 'get'

@@ -92,6 +92,7 @@ const emit = defineEmits<{
 const formRef = ref<FormInstance>()
 const formData = ref<Record<string, any>>({})
+const isUpdating = ref(false)

 const schemaProperties = computed(() => {
   return props.schema?.properties || {}
@@ -173,8 +174,11 @@ const initFormData = () => {
 watch(
   () => props.modelValue,
-  () => {
+  (newVal) => {
+    if (isUpdating.value) return
+    if (JSON.stringify(newVal) !== JSON.stringify(formData.value)) {
       initFormData()
+    }
   },
   { deep: true }
 )
@@ -190,7 +194,14 @@ watch(
 watch(
   formData,
   (val) => {
+    if (isUpdating.value) return
+    if (JSON.stringify(val) !== JSON.stringify(props.modelValue)) {
+      isUpdating.value = true
       emit('update:modelValue', val)
+      Promise.resolve().then(() => {
+        isUpdating.value = false
+      })
+    }
   },
   { deep: true }
 )

@@ -92,6 +92,7 @@ const emit = defineEmits<{
 const formRef = ref<FormInstance>()
 const formData = ref<Record<string, any>>({})
+const isUpdating = ref(false)

 const schemaProperties = computed(() => {
   return props.schema?.properties || {}
@@ -173,8 +174,11 @@ const initFormData = () => {
 watch(
   () => props.modelValue,
-  () => {
+  (newVal) => {
+    if (isUpdating.value) return
+    if (JSON.stringify(newVal) !== JSON.stringify(formData.value)) {
       initFormData()
+    }
   },
   { deep: true }
 )
@@ -190,7 +194,14 @@ watch(
 watch(
   formData,
   (val) => {
+    if (isUpdating.value) return
+    if (JSON.stringify(val) !== JSON.stringify(props.modelValue)) {
+      isUpdating.value = true
       emit('update:modelValue', val)
+      Promise.resolve().then(() => {
+        isUpdating.value = false
+      })
+    }
   },
   { deep: true }
 )

@@ -74,7 +74,8 @@ export const useEmbeddingStore = defineStore('embedding', () => {
         provider: currentConfig.value.provider,
         config: currentConfig.value.config
       }
-      await saveConfig(updateData)
+      const response = await saveConfig(updateData)
+      return response
     } catch (error) {
       console.error('Failed to save config:', error)
       throw error

@@ -1,21 +1,22 @@
-import axios from 'axios'
+import axios, { type AxiosRequestConfig } from 'axios'
 import { ElMessage, ElMessageBox } from 'element-plus'
 import { useTenantStore } from '@/stores/tenant'

-// 创建 axios 实例
 const service = axios.create({
   baseURL: import.meta.env.VITE_APP_BASE_API || '/api',
   timeout: 60000
 })

-// 请求拦截器
 service.interceptors.request.use(
   (config) => {
     const tenantStore = useTenantStore()
     if (tenantStore.currentTenantId) {
       config.headers['X-Tenant-Id'] = tenantStore.currentTenantId
     }
-    // TODO: 如果有 token 也可以在这里注入 Authorization
+    const apiKey = import.meta.env.VITE_APP_API_KEY
+    if (apiKey) {
+      config.headers['X-API-Key'] = apiKey
+    }
     return config
   },
   (error) => {
@@ -24,11 +25,9 @@ service.interceptors.request.use(
   }
 )

-// 响应拦截器
 service.interceptors.response.use(
   (response) => {
     const res = response.data
-    // 这里可以根据后端的 code 进行统一处理
     return res
   },
   (error) => {
@@ -42,7 +41,6 @@ service.interceptors.response.use(
       cancelButtonText: '取消',
       type: 'warning'
     }).then(() => {
-      // TODO: 跳转到登录页或执行退出逻辑
       location.href = '/login'
     })
   } else if (status === 403) {
@@ -69,4 +67,13 @@ service.interceptors.response.use(
   }
 )

-export default service
+interface RequestConfig extends AxiosRequestConfig {
+  url: string
+  method?: string
+}
+
+function request<T = any>(config: RequestConfig): Promise<T> {
+  return service.request<any, T>(config)
+}
+
+export default request

@@ -169,8 +169,19 @@ const handleSave = async () => {
   saving.value = true
   try {
-    await embeddingStore.saveCurrentConfig()
+    const response: any = await embeddingStore.saveCurrentConfig()
     ElMessage.success('配置保存成功')
+    if (response?.warning || response?.requires_reindex) {
+      ElMessageBox.alert(
+        response.warning || '嵌入模型已更改,请重新上传文档以确保检索效果正常。',
+        '重要提示',
+        {
+          confirmButtonText: '我知道了',
+          type: 'warning',
+        }
+      )
+    }
   } catch (error) {
     ElMessage.error('配置保存失败')
   } finally {

@@ -102,10 +102,17 @@ interface DocumentItem {
   createTime: string
 }

+interface IndexJob {
+  jobId: string
+  status: string
+  progress: number
+  errorMsg?: string
+}
+
 const tableData = ref<DocumentItem[]>([])
 const loading = ref(false)
 const jobDialogVisible = ref(false)
-const currentJob = ref<any>(null)
+const currentJob = ref<IndexJob | null>(null)
 const pollingJobs = ref<Set<string>>(new Set())
 let pollingInterval: number | null = null
@@ -150,10 +157,15 @@ const fetchDocuments = async () => {
   }
 }

-const fetchJobStatus = async (jobId: string) => {
+const fetchJobStatus = async (jobId: string): Promise<IndexJob | null> => {
   try {
-    const res = await getIndexJob(jobId)
-    return res
+    const res: any = await getIndexJob(jobId)
+    return {
+      jobId: res.jobId || jobId,
+      status: res.status || 'pending',
+      progress: res.progress || 0,
+      errorMsg: res.errorMsg
+    }
   } catch (error) {
     console.error('Failed to fetch job status:', error)
     return null
@@ -246,19 +258,21 @@ const handleFileChange = async (event: Event) => {
   try {
     loading.value = true
-    const res = await uploadDocument(formData)
-    ElMessage.success(`文档上传成功,任务ID: ${res.jobId}`)
+    const res: any = await uploadDocument(formData)
+    const jobId = res.jobId as string
+    ElMessage.success(`文档上传成功,任务ID: ${jobId}`)
     console.log('Upload response:', res)

     const newDoc: DocumentItem = {
+      docId: res.docId || '',
       name: file.name,
-      status: res.status || 'pending',
+      status: (res.status as string) || 'pending',
-      jobId: res.jobId,
+      jobId: jobId,
       createTime: new Date().toLocaleString('zh-CN')
     }
     tableData.value.unshift(newDoc)
-    startPolling(res.jobId)
+    startPolling(jobId)
   } catch (error) {
     ElMessage.error('文档上传失败')
     console.error('Upload error:', error)

@@ -327,7 +327,7 @@ const runStreamExperiment = async () => {
       } else if (parsed.type === 'error') {
         streamError.value = parsed.message || '流式输出错误'
         streaming.value = false
-        ElMessage.error(streamError.value)
+        ElMessage.error(streamError.value || '未知错误')
       }
     } catch {
       streamContent.value += data
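The frontend above consumes the `/ai/chat` SSE stream by reading `data:` lines, JSON-parsing each one, and falling back to raw text when parsing fails. A minimal sketch of that framing logic (the `type`/`message` event shape is taken from the diff above; the `content` fallback shape is an assumption):

```python
import json

def parse_sse_chunk(chunk: str) -> list[dict]:
    """Parse an SSE text chunk into events.

    JSON payloads on `data:` lines become event dicts; non-JSON data lines
    pass through as raw content, mirroring the frontend's catch branch.
    """
    events = []
    for line in chunk.splitlines():
        if not line.startswith("data:"):
            continue  # skip blank separators and non-data SSE fields
        data = line[len("data:"):].strip()
        try:
            events.append(json.loads(data))
        except json.JSONDecodeError:
            events.append({"type": "content", "content": data})
    return events

events = parse_sse_chunk('data: {"type": "error", "message": "boom"}\n\ndata: hello')
```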

ai-service-admin/src/vite-env.d.ts (vendored, new file, 10 lines)

@@ -0,0 +1,10 @@
/// <reference types="vite/client" />
interface ImportMetaEnv {
readonly VITE_APP_BASE_API: string
readonly VITE_APP_API_KEY: string
}
interface ImportMeta {
readonly env: ImportMetaEnv
}

@@ -15,7 +15,8 @@
     "baseUrl": ".",
     "paths": {
       "@/*": ["src/*"]
-    }
+    },
+    "types": ["vite/client"]
   },
   "include": ["src/**/*.ts", "src/**/*.d.ts", "src/**/*.tsx", "src/**/*.vue"],
   "references": [{ "path": "./tsconfig.node.json" }]

ai-service/.dockerignore (new file, 53 lines)

@@ -0,0 +1,53 @@
__pycache__
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
.pytest_cache
.coverage
htmlcov/
.tox/
.hypothesis/
.mypy_cache/
.ruff_cache/
.env
.env.local
.env.*.local
*.log
*.pot
*.pyc
.idea/
.vscode/
*.swp
*.swo
tests/
scripts/
*.md
!README.md
.git
.gitignore
.gitea
check_qdrant.py

ai-service/Dockerfile (new file, 32 lines)

@@ -0,0 +1,32 @@
# AI Service Backend Dockerfile
FROM docker.1ms.run/python:3.11-slim AS builder
WORKDIR /app
RUN pip install --no-cache-dir uv
COPY pyproject.toml README.md ./
RUN uv pip install --system --no-cache-dir .
FROM docker.1ms.run/python:3.11-slim
WORKDIR /app
RUN groupadd -r appgroup && useradd -r -g appgroup appuser
COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
COPY --from=builder /usr/local/bin /usr/local/bin
COPY app ./app
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8080
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8080"]

@@ -1,8 +1,9 @@
 """
 Admin API routes for AI Service management.
-[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08] Admin management endpoints.
+[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08, AC-AISVC-50] Admin management endpoints.
 """

+from app.api.admin.api_key import router as api_key_router
 from app.api.admin.dashboard import router as dashboard_router
 from app.api.admin.embedding import router as embedding_router
 from app.api.admin.kb import router as kb_router
@@ -11,4 +12,4 @@ from app.api.admin.rag import router as rag_router
 from app.api.admin.sessions import router as sessions_router
 from app.api.admin.tenants import router as tenants_router

-__all__ = ["dashboard_router", "embedding_router", "kb_router", "llm_router", "rag_router", "sessions_router", "tenants_router"]
+__all__ = ["api_key_router", "dashboard_router", "embedding_router", "kb_router", "llm_router", "rag_router", "sessions_router", "tenants_router"]

@@ -0,0 +1,154 @@
"""
API Key management endpoints.
[AC-AISVC-50] CRUD operations for API keys.
"""
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.models.entities import ApiKey, ApiKeyCreate
from app.services.api_key import get_api_key_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/api-keys", tags=["API Keys"])
class ApiKeyResponse(BaseModel):
"""Response model for API key."""
id: str = Field(..., description="API key ID")
key: str = Field(..., description="API key value")
name: str = Field(..., description="API key name")
is_active: bool = Field(..., description="Whether the key is active")
created_at: str = Field(..., description="Creation time")
updated_at: str = Field(..., description="Last update time")
class ApiKeyListResponse(BaseModel):
"""Response model for API key list."""
keys: list[ApiKeyResponse] = Field(..., description="List of API keys")
total: int = Field(..., description="Total count")
class CreateApiKeyRequest(BaseModel):
"""Request model for creating API key."""
name: str = Field(..., description="API key name/description")
key: str | None = Field(default=None, description="Custom API key (auto-generated if not provided)")
class ToggleApiKeyRequest(BaseModel):
"""Request model for toggling API key status."""
is_active: bool = Field(..., description="New active status")
def api_key_to_response(api_key: ApiKey) -> ApiKeyResponse:
"""Convert ApiKey entity to response model."""
return ApiKeyResponse(
id=str(api_key.id),
key=api_key.key,
name=api_key.name,
is_active=api_key.is_active,
created_at=api_key.created_at.isoformat(),
updated_at=api_key.updated_at.isoformat(),
)
@router.get("", response_model=ApiKeyListResponse)
async def list_api_keys(
session: Annotated[AsyncSession, Depends(get_session)],
):
"""
[AC-AISVC-50] List all API keys.
"""
service = get_api_key_service()
keys = await service.list_keys(session)
return ApiKeyListResponse(
keys=[api_key_to_response(k) for k in keys],
total=len(keys),
)
@router.post("", response_model=ApiKeyResponse, status_code=status.HTTP_201_CREATED)
async def create_api_key(
request: CreateApiKeyRequest,
session: Annotated[AsyncSession, Depends(get_session)],
):
"""
[AC-AISVC-50] Create a new API key.
"""
service = get_api_key_service()
key_value = request.key or service.generate_key()
key_create = ApiKeyCreate(
key=key_value,
name=request.name,
is_active=True,
)
api_key = await service.create_key(session, key_create)
logger.info(f"[AC-AISVC-50] Created API key: {api_key.name}")
return api_key_to_response(api_key)
@router.delete("/{key_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_api_key(
key_id: str,
session: Annotated[AsyncSession, Depends(get_session)],
):
"""
[AC-AISVC-50] Delete an API key.
"""
service = get_api_key_service()
deleted = await service.delete_key(session, key_id)
if not deleted:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found",
)
@router.patch("/{key_id}/toggle", response_model=ApiKeyResponse)
async def toggle_api_key(
key_id: str,
request: ToggleApiKeyRequest,
session: Annotated[AsyncSession, Depends(get_session)],
):
"""
[AC-AISVC-50] Toggle API key active status.
"""
service = get_api_key_service()
api_key = await service.toggle_key(session, key_id, request.is_active)
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found",
)
return api_key_to_response(api_key)
@router.post("/reload-cache", status_code=status.HTTP_204_NO_CONTENT)
async def reload_api_key_cache(
session: Annotated[AsyncSession, Depends(get_session)],
):
"""
[AC-AISVC-50] Reload API key cache from database.
"""
service = get_api_key_service()
await service.reload_cache(session)
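The `generate_key` method used by the create endpoint above is not shown in this diff. A plausible implementation (an assumption, not the repository's actual code, including the `sk-` prefix) would use the standard library's `secrets` module:

```python
import secrets

def generate_key(prefix: str = "sk-") -> str:
    """Generate a cryptographically random, URL-safe API key.

    The 'sk-' prefix is an illustrative convention, not confirmed by the source;
    token_urlsafe(32) yields ~43 characters of entropy-rich text.
    """
    return prefix + secrets.token_urlsafe(32)

key = generate_key()
```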

@@ -78,12 +78,32 @@ async def update_embedding_config(
     manager = get_embedding_config_manager()

+    old_config = manager.get_full_config()
+    old_provider = old_config.get("provider")
+    old_model = old_config.get("config", {}).get("model", "")
+    new_model = config.get("model", "")
+
     try:
         await manager.update_config(provider, config)
-        return {
+        response = {
             "success": True,
             "message": f"Configuration updated to use {provider}",
         }
+
+        if old_provider != provider or old_model != new_model:
+            response["warning"] = (
+                "嵌入模型已更改。由于不同模型生成的向量不兼容,"
+                "请删除现有知识库并重新上传文档,以确保检索效果正常。"
+            )
+            response["requires_reindex"] = True
+            logger.warning(
+                f"[EMBEDDING] Model changed from {old_provider}/{old_model} to {provider}/{new_model}. "
+                f"Documents need to be re-uploaded."
+            )
+
+        return response
     except EmbeddingException as e:
         raise InvalidRequestException(str(e))

@@ -442,13 +442,15 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
     logger.info(f"[INDEX] Total chunks: {len(all_chunks)}")

     qdrant = await get_qdrant_client()
-    await qdrant.ensure_collection_exists(tenant_id)
+    await qdrant.ensure_collection_exists(tenant_id, use_multi_vector=True)
+
+    from app.services.embedding.nomic_provider import NomicEmbeddingProvider
+    use_multi_vector = isinstance(embedding_provider, NomicEmbeddingProvider)
+    logger.info(f"[INDEX] Using multi-vector format: {use_multi_vector}")

     points = []
     total_chunks = len(all_chunks)
     for i, chunk in enumerate(all_chunks):
-        embedding = await embedding_provider.embed(chunk.text)
         payload = {
             "text": chunk.text,
             "source": doc_id,
@@ -461,6 +463,19 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
         if chunk.source:
             payload["filename"] = chunk.source

+        if use_multi_vector:
+            embedding_result = await embedding_provider.embed_document(chunk.text)
+            points.append({
+                "id": str(uuid.uuid4()),
+                "vector": {
+                    "full": embedding_result.embedding_full,
+                    "dim_256": embedding_result.embedding_256,
+                    "dim_512": embedding_result.embedding_512,
+                },
+                "payload": payload,
+            })
+        else:
+            embedding = await embedding_provider.embed(chunk.text)
             points.append(
                 PointStruct(
                     id=str(uuid.uuid4()),
@@ -478,6 +493,9 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
     if points:
         logger.info(f"[INDEX] Upserting {len(points)} vectors to Qdrant...")
+        if use_multi_vector:
+            await qdrant.upsert_multi_vector(tenant_id, points)
+        else:
             await qdrant.upsert_vectors(tenant_id, points)

     await kb_service.update_job_status(

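The `dim_256` and `dim_512` vectors stored above are Matryoshka truncations of the full 768-dimensional embedding. A sketch of the usual truncate-and-renormalize step, assuming the provider follows the standard Matryoshka recipe (the repository's actual `NomicEmbeddingProvider` implementation is not shown in this diff):

```python
import math

def matryoshka_truncate(embedding: list[float], dim: int) -> list[float]:
    """Keep the first `dim` components and L2-renormalize.

    Renormalizing keeps cosine similarity meaningful after truncation,
    which is what makes the cheap low-dim first-stage search work.
    """
    head = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    if norm == 0.0:
        return head  # degenerate all-zero prefix; nothing to scale
    return [x / norm for x in head]

full = [0.3, 0.4, 0.0, 1.2]            # stand-in for a 768-dim vector
small = matryoshka_truncate(full, 2)   # e.g. dim_256 in the real pipeline
```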
@@ -1,6 +1,6 @@
 """
 Middleware for AI Service.
-[AC-AISVC-10, AC-AISVC-12] X-Tenant-Id header validation and tenant context injection.
+[AC-AISVC-10, AC-AISVC-12, AC-AISVC-50] X-Tenant-Id header validation, tenant context injection, and API Key authentication.
 """

 import logging
@@ -17,12 +17,20 @@ from app.core.tenant import clear_tenant_context, set_tenant_context
 logger = logging.getLogger(__name__)

 TENANT_ID_HEADER = "X-Tenant-Id"
+API_KEY_HEADER = "X-API-Key"
 ACCEPT_HEADER = "Accept"
 SSE_CONTENT_TYPE = "text/event-stream"

+# Tenant ID format: name@ash@year (e.g., szmp@ash@2026)
 TENANT_ID_PATTERN = re.compile(r'^[^@]+@ash@\d{4}$')

+PATHS_SKIP_API_KEY = {
+    "/health",
+    "/ai/health",
+    "/docs",
+    "/redoc",
+    "/openapi.json",
+}

 def validate_tenant_id_format(tenant_id: str) -> bool:
     """
@@ -41,6 +49,59 @@ def parse_tenant_id(tenant_id: str) -> tuple[str, str]:
     return parts[0], parts[2]

+class ApiKeyMiddleware(BaseHTTPMiddleware):
+    """
+    [AC-AISVC-50] Middleware to validate API Key for all requests.
+
+    Features:
+    - Validates X-API-Key header against in-memory cache
+    - Skips validation for health/docs endpoints
+    - Returns 401 for missing or invalid API key
+    """
+
+    async def dispatch(self, request: Request, call_next: Callable) -> Response:
+        if self._should_skip_api_key(request.url.path):
+            return await call_next(request)
+
+        api_key = request.headers.get(API_KEY_HEADER)
+        if not api_key or not api_key.strip():
+            logger.warning(f"[AC-AISVC-50] Missing X-API-Key header for {request.url.path}")
+            return JSONResponse(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                content=ErrorResponse(
+                    code=ErrorCode.UNAUTHORIZED.value,
+                    message="Missing required header: X-API-Key",
+                ).model_dump(exclude_none=True),
+            )
+
+        api_key = api_key.strip()
+
+        from app.services.api_key import get_api_key_service
+        service = get_api_key_service()
+
+        if not service.validate_key(api_key):
+            logger.warning(f"[AC-AISVC-50] Invalid API key for {request.url.path}")
+            return JSONResponse(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                content=ErrorResponse(
+                    code=ErrorCode.UNAUTHORIZED.value,
+                    message="Invalid API key",
+                ).model_dump(exclude_none=True),
+            )
+
+        return await call_next(request)
+
+    def _should_skip_api_key(self, path: str) -> bool:
+        """Check if the path should skip API key validation."""
+        if path in PATHS_SKIP_API_KEY:
+            return True
+        for skip_path in PATHS_SKIP_API_KEY:
+            if path.startswith(skip_path):
+                return True
+        return False
+
 class TenantContextMiddleware(BaseHTTPMiddleware):
     """
     [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header.
@@ -51,7 +112,7 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next: Callable) -> Response:
         clear_tenant_context()

-        if request.url.path == "/ai/health":
+        if request.url.path in ("/health", "/ai/health"):
             return await call_next(request)

         tenant_id = request.headers.get(TENANT_ID_HEADER)
@@ -68,7 +129,6 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
         tenant_id = tenant_id.strip()

-        # Validate tenant ID format
         if not validate_tenant_id_format(tenant_id):
             logger.warning(f"[AC-AISVC-10] Invalid tenant ID format: {tenant_id}")
             return JSONResponse(
@@ -79,13 +139,11 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
             ).model_dump(exclude_none=True),
         )

-        # Auto-create tenant if not exists (for admin endpoints)
         if request.url.path.startswith("/admin/") or request.url.path.startswith("/ai/"):
             try:
                 await self._ensure_tenant_exists(request, tenant_id)
             except Exception as e:
                 logger.error(f"[AC-AISVC-10] Failed to ensure tenant exists: {e}")
-                # Continue processing even if tenant creation fails

         set_tenant_context(tenant_id)
         request.state.tenant_id = tenant_id
@@ -112,7 +170,6 @@
name, year = parse_tenant_id(tenant_id) name, year = parse_tenant_id(tenant_id)
async with async_session_maker() as session: async with async_session_maker() as session:
# Check if tenant exists
stmt = select(Tenant).where(Tenant.tenant_id == tenant_id) stmt = select(Tenant).where(Tenant.tenant_id == tenant_id)
result = await session.execute(stmt) result = await session.execute(stmt)
existing_tenant = result.scalar_one_or_none() existing_tenant = result.scalar_one_or_none()
@ -121,7 +178,6 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
logger.debug(f"[AC-AISVC-10] Tenant already exists: {tenant_id}") logger.debug(f"[AC-AISVC-10] Tenant already exists: {tenant_id}")
return return
# Create new tenant
new_tenant = Tenant( new_tenant = Tenant(
tenant_id=tenant_id, tenant_id=tenant_id,
name=name, name=name,

@@ -8,7 +8,7 @@ import logging
 from typing import Any
 from qdrant_client import AsyncQdrantClient
-from qdrant_client.models import Distance, PointStruct, VectorParams, MultiVectorConfig
+from qdrant_client.models import Distance, PointStruct, VectorParams, QueryRequest
 from app.core.config import get_settings
@@ -61,8 +61,7 @@ class QdrantClient:
         collection_name = self.get_collection_name(tenant_id)
         try:
-            collections = await client.get_collections()
-            exists = any(c.name == collection_name for c in collections.collections)
+            exists = await client.collection_exists(collection_name)
             if not exists:
                 if use_multi_vector:
@@ -176,6 +175,7 @@ class QdrantClient:
         limit: int = 5,
         score_threshold: float | None = None,
         vector_name: str = "full",
+        with_vectors: bool = False,
     ) -> list[dict[str, Any]]:
         """
         [AC-AISVC-10] Search vectors in tenant's collection.
@@ -189,6 +189,7 @@ class QdrantClient:
             score_threshold: Minimum score threshold for results
             vector_name: Name of the vector to search (for multi-vector collections)
                 Default is "full" for 768-dim vectors in Matryoshka setup.
+            with_vectors: Whether to return vectors in results (for two-stage reranking)
         """
         client = await self.get_client()
@@ -211,39 +212,50 @@ class QdrantClient:
         try:
             logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}")
+            exists = await client.collection_exists(collection_name)
+            if not exists:
+                logger.warning(f"[AC-AISVC-10] Collection {collection_name} does not exist")
+                continue
             try:
-                results = await client.search(
+                results = await client.query_points(
                     collection_name=collection_name,
-                    query_vector=(vector_name, query_vector),
+                    query=query_vector,
+                    using=vector_name,
                     limit=limit,
+                    with_vectors=with_vectors,
+                    score_threshold=score_threshold,
                 )
             except Exception as e:
-                if "vector name" in str(e).lower() or "Not existing vector" in str(e):
+                if "vector name" in str(e).lower() or "Not existing vector" in str(e) or "using" in str(e).lower():
                     logger.info(
                         f"[AC-AISVC-10] Collection {collection_name} doesn't have vector named '{vector_name}', "
                         f"trying without vector name (single-vector mode)"
                     )
-                    results = await client.search(
+                    results = await client.query_points(
                         collection_name=collection_name,
-                        query_vector=query_vector,
+                        query=query_vector,
                         limit=limit,
+                        with_vectors=with_vectors,
+                        score_threshold=score_threshold,
                     )
                 else:
                     raise
             logger.info(
-                f"[AC-AISVC-10] Collection {collection_name} returned {len(results)} raw results"
+                f"[AC-AISVC-10] Collection {collection_name} returned {len(results.points)} raw results"
             )
-            hits = [
-                {
-                    "id": str(result.id),
-                    "score": result.score,
-                    "payload": result.payload or {},
-                }
-                for result in results
-                if score_threshold is None or result.score >= score_threshold
-            ]
+            hits = []
+            for result in results.points:
+                hit = {
+                    "id": str(result.id),
+                    "score": result.score,
+                    "payload": result.payload or {},
+                }
+                if with_vectors and result.vector:
+                    hit["vector"] = result.vector
+                hits.append(hit)
             all_hits.extend(hits)
             if hits:
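The new `with_vectors` flag supports two-stage retrieval: a cheap coarse search first, then a rerank over the stored full vectors. A minimal pure-Python sketch of the rerank step (the hit shape mirrors the dicts built above; the dot-product scorer is an assumption for illustration, not the service's actual reranker):

```python
def dot(a: list[float], b: list[float]) -> float:
    """Plain dot product; if the vectors are L2-normalized this equals cosine similarity."""
    return sum(x * y for x, y in zip(a, b))

def rerank(hits: list[dict], query_vector: list[float], top_k: int = 2) -> list[dict]:
    """Re-score hits that carry a 'vector' field and keep the best top_k."""
    scored = [{**h, "score": dot(h["vector"], query_vector)} for h in hits if h.get("vector")]
    return sorted(scored, key=lambda h: h["score"], reverse=True)[:top_k]

hits = [
    {"id": "a", "score": 0.5, "vector": [1.0, 0.0]},
    {"id": "b", "score": 0.4, "vector": [0.0, 1.0]},
]
print([h["id"] for h in rerank(hits, [0.1, 0.9])])  # ['b', 'a']
```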

@@ -12,7 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from app.api import chat_router, health_router
-from app.api.admin import dashboard_router, embedding_router, kb_router, llm_router, rag_router, sessions_router, tenants_router
+from app.api.admin import api_key_router, dashboard_router, embedding_router, kb_router, llm_router, rag_router, sessions_router, tenants_router
 from app.api.admin.kb_optimized import router as kb_optimized_router
 from app.core.config import get_settings
 from app.core.database import close_db, init_db
@@ -24,7 +24,7 @@ from app.core.exceptions import (
     generic_exception_handler,
     http_exception_handler,
 )
-from app.core.middleware import TenantContextMiddleware
+from app.core.middleware import ApiKeyMiddleware, TenantContextMiddleware
 from app.core.qdrant_client import close_qdrant_client

 settings = get_settings()
@@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """
-    [AC-AISVC-01, AC-AISVC-11] Application lifespan manager.
+    [AC-AISVC-01, AC-AISVC-11, AC-AISVC-50] Application lifespan manager.
     Handles startup and shutdown of database and external connections.
     """
     logger.info(f"[AC-AISVC-01] Starting {settings.app_name} v{settings.app_version}")
@@ -51,6 +51,19 @@ async def lifespan(app: FastAPI):
     except Exception as e:
         logger.warning(f"[AC-AISVC-11] Database initialization skipped: {e}")
+
+    try:
+        from app.core.database import async_session_maker
+        from app.services.api_key import get_api_key_service
+        async with async_session_maker() as session:
+            api_key_service = get_api_key_service()
+            await api_key_service.initialize(session)
+            default_key = await api_key_service.create_default_key(session)
+            if default_key:
+                logger.info(f"[AC-AISVC-50] Default API key created: {default_key.key}")
+    except Exception as e:
+        logger.warning(f"[AC-AISVC-50] API key initialization skipped: {e}")
+
     yield
     await close_db()
@@ -87,6 +100,7 @@ app.add_middleware(
 )
 app.add_middleware(TenantContextMiddleware)
+app.add_middleware(ApiKeyMiddleware)
 app.add_exception_handler(AIServiceException, ai_service_exception_handler)
 app.add_exception_handler(HTTPException, http_exception_handler)
@@ -113,6 +127,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
 app.include_router(health_router)
 app.include_router(chat_router)
+app.include_router(api_key_router)
 app.include_router(dashboard_router)
 app.include_router(embedding_router)
 app.include_router(kb_router)

@@ -50,6 +50,7 @@ class ErrorCode(str, Enum):
     INVALID_REQUEST = "INVALID_REQUEST"
     MISSING_TENANT_ID = "MISSING_TENANT_ID"
     INVALID_TENANT_ID = "INVALID_TENANT_ID"
+    UNAUTHORIZED = "UNAUTHORIZED"
     INTERNAL_ERROR = "INTERNAL_ERROR"
     SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
     TIMEOUT = "TIMEOUT"

@@ -198,3 +198,27 @@ class DocumentCreate(SQLModel):
     file_path: str | None = None
     file_size: int | None = None
     file_type: str | None = None
+
+class ApiKey(SQLModel, table=True):
+    """
+    [AC-AISVC-50] API Key entity for lightweight authentication.
+    Keys are loaded into memory on startup for fast validation.
+    """
+    __tablename__ = "api_keys"
+
+    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
+    key: str = Field(..., description="API Key (unique)", unique=True, index=True)
+    name: str = Field(..., description="Key name/description for identification")
+    is_active: bool = Field(default=True, description="Whether the key is active")
+    created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
+    updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
+
+class ApiKeyCreate(SQLModel):
+    """Schema for creating a new API key."""
+    key: str
+    name: str
+    is_active: bool = True
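The lifecycle of an `ApiKey` is simple: generate a random token, persist it, and validate incoming requests against an in-memory set. A minimal sketch of that flow without the database layer (variable names are illustrative):

```python
import secrets

def generate_key() -> str:
    """URL-safe random key, the same scheme the service uses (secrets.token_urlsafe(32))."""
    return secrets.token_urlsafe(32)

# In-memory cache, as the middleware consults for O(1) validation
keys_cache: set[str] = set()

key = generate_key()
keys_cache.add(key)

print(key in keys_cache)          # True
print("not-a-key" in keys_cache)  # False
```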

@@ -0,0 +1,249 @@
+"""
+API Key management service.
+
+[AC-AISVC-50] Lightweight authentication with in-memory cache.
+"""
+import logging
+import secrets
+from datetime import datetime
+from typing import Optional
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.models.entities import ApiKey, ApiKeyCreate
+
+logger = logging.getLogger(__name__)
+
+
+class ApiKeyService:
+    """
+    [AC-AISVC-50] API Key management service.
+
+    Features:
+    - In-memory cache for fast validation
+    - Database persistence
+    - Hot-reload support
+    """
+
+    def __init__(self):
+        self._keys_cache: set[str] = set()
+        self._initialized: bool = False
+
+    async def initialize(self, session: AsyncSession) -> None:
+        """
+        Load all active API keys from database into memory.
+        Should be called on application startup.
+        """
+        result = await session.execute(
+            select(ApiKey).where(ApiKey.is_active == True)
+        )
+        keys = result.scalars().all()
+        self._keys_cache = {key.key for key in keys}
+        self._initialized = True
+        logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys into memory")
+
+    def validate_key(self, key: str) -> bool:
+        """
+        Validate an API key against the in-memory cache.
+
+        Args:
+            key: The API key to validate
+
+        Returns:
+            True if the key is valid, False otherwise
+        """
+        if not self._initialized:
+            logger.warning("[AC-AISVC-50] API key service not initialized")
+            return False
+        return key in self._keys_cache
+
+    def generate_key(self) -> str:
+        """
+        Generate a new secure API key.
+
+        Returns:
+            A URL-safe random string
+        """
+        return secrets.token_urlsafe(32)
+
+    async def create_key(
+        self,
+        session: AsyncSession,
+        key_create: ApiKeyCreate
+    ) -> ApiKey:
+        """
+        Create a new API key.
+
+        Args:
+            session: Database session
+            key_create: Key creation data
+
+        Returns:
+            The created ApiKey entity
+        """
+        api_key = ApiKey(
+            key=key_create.key,
+            name=key_create.name,
+            is_active=key_create.is_active,
+        )
+        session.add(api_key)
+        await session.commit()
+        await session.refresh(api_key)
+        if api_key.is_active:
+            self._keys_cache.add(api_key.key)
+        logger.info(f"[AC-AISVC-50] Created API key: {api_key.name}")
+        return api_key
+
+    async def create_default_key(self, session: AsyncSession) -> Optional[ApiKey]:
+        """
+        Create a default API key if none exists.
+
+        Returns:
+            The created ApiKey or None if keys already exist
+        """
+        result = await session.execute(select(ApiKey).limit(1))
+        existing = result.scalar_one_or_none()
+        if existing:
+            return None
+        default_key = secrets.token_urlsafe(32)
+        api_key = ApiKey(
+            key=default_key,
+            name="Default API Key",
+            is_active=True,
+        )
+        session.add(api_key)
+        await session.commit()
+        await session.refresh(api_key)
+        self._keys_cache.add(api_key.key)
+        logger.info(f"[AC-AISVC-50] Created default API key: {api_key.key}")
+        return api_key
+
+    async def delete_key(
+        self,
+        session: AsyncSession,
+        key_id: str
+    ) -> bool:
+        """
+        Delete an API key.
+
+        Args:
+            session: Database session
+            key_id: The key ID to delete
+
+        Returns:
+            True if deleted, False if not found
+        """
+        import uuid
+        try:
+            key_uuid = uuid.UUID(key_id)
+        except ValueError:
+            return False
+        result = await session.execute(
+            select(ApiKey).where(ApiKey.id == key_uuid)
+        )
+        api_key = result.scalar_one_or_none()
+        if not api_key:
+            return False
+        key_value = api_key.key
+        await session.delete(api_key)
+        await session.commit()
+        self._keys_cache.discard(key_value)
+        logger.info(f"[AC-AISVC-50] Deleted API key: {api_key.name}")
+        return True
+
+    async def toggle_key(
+        self,
+        session: AsyncSession,
+        key_id: str,
+        is_active: bool
+    ) -> Optional[ApiKey]:
+        """
+        Toggle API key active status.
+
+        Args:
+            session: Database session
+            key_id: The key ID to toggle
+            is_active: New active status
+
+        Returns:
+            The updated ApiKey or None if not found
+        """
+        import uuid
+        try:
+            key_uuid = uuid.UUID(key_id)
+        except ValueError:
+            return None
+        result = await session.execute(
+            select(ApiKey).where(ApiKey.id == key_uuid)
+        )
+        api_key = result.scalar_one_or_none()
+        if not api_key:
+            return None
+        api_key.is_active = is_active
+        api_key.updated_at = datetime.utcnow()
+        session.add(api_key)
+        await session.commit()
+        await session.refresh(api_key)
+        if is_active:
+            self._keys_cache.add(api_key.key)
+        else:
+            self._keys_cache.discard(api_key.key)
+        logger.info(f"[AC-AISVC-50] Toggled API key {api_key.name}: active={is_active}")
+        return api_key
+
+    async def list_keys(self, session: AsyncSession) -> list[ApiKey]:
+        """
+        List all API keys.
+
+        Args:
+            session: Database session
+
+        Returns:
+            List of all ApiKey entities
+        """
+        result = await session.execute(select(ApiKey))
+        return list(result.scalars().all())
+
+    async def reload_cache(self, session: AsyncSession) -> None:
+        """
+        Reload all API keys from database into memory.
+        """
+        self._keys_cache.clear()
+        await self.initialize(session)
+        logger.info("[AC-AISVC-50] API key cache reloaded")
+
+
+_api_key_service: ApiKeyService | None = None
+
+
+def get_api_key_service() -> ApiKeyService:
+    """Get the global API key service instance."""
+    global _api_key_service
+    if _api_key_service is None:
+        _api_key_service = ApiKeyService()
+    return _api_key_service
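`get_api_key_service` is the module-level lazy-singleton pattern used elsewhere in the codebase, stripped to its essentials here:

```python
class ApiKeyService:
    def __init__(self):
        self._keys_cache: set[str] = set()

_api_key_service: "ApiKeyService | None" = None

def get_api_key_service() -> ApiKeyService:
    """Create the service on first call, then always return the same instance."""
    global _api_key_service
    if _api_key_service is None:
        _api_key_service = ApiKeyService()
    return _api_key_service

print(get_api_key_service() is get_api_key_service())  # True
```

Note the cache is per-process: when the app runs with multiple workers, each process loads its own copy of the keys at startup.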

@@ -7,7 +7,9 @@ Design reference: progress.md Section 7.1 - Architecture
 - EmbeddingConfigManager: manages configuration with hot-reload support
 """
+import json
 import logging
+from pathlib import Path
 from typing import Any, Type

 from app.services.embedding.base import EmbeddingException, EmbeddingProvider
@@ -17,6 +19,8 @@ from app.services.embedding.nomic_provider import NomicEmbeddingProvider

 logger = logging.getLogger(__name__)

+EMBEDDING_CONFIG_FILE = Path("config/embedding_config.json")
+
 class EmbeddingProviderFactory:
     """
@@ -74,11 +78,38 @@ class EmbeddingProviderFactory:
             "nomic": "Nomic-embed-text v1.5 优化版,支持任务前缀和 Matryoshka 维度截断,专为RAG优化",
         }

+        raw_schema = temp_instance.get_config_schema()
+        properties = {}
+        required = []
+        for key, field in raw_schema.items():
+            properties[key] = {
+                "type": field.get("type", "string"),
+                "title": field.get("title", key),
+                "description": field.get("description", ""),
+                "default": field.get("default"),
+            }
+            if field.get("enum"):
+                properties[key]["enum"] = field.get("enum")
+            if field.get("minimum") is not None:
+                properties[key]["minimum"] = field.get("minimum")
+            if field.get("maximum") is not None:
+                properties[key]["maximum"] = field.get("maximum")
+            if field.get("required"):
+                required.append(key)
+
+        config_schema = {
+            "type": "object",
+            "properties": properties,
+        }
+        if required:
+            config_schema["required"] = required
+
         return {
             "name": name,
             "display_name": display_names.get(name, name),
             "description": descriptions.get(name, ""),
-            "config_schema": temp_instance.get_config_schema(),
+            "config_schema": config_schema,
         }

     @classmethod
@@ -125,18 +156,47 @@ class EmbeddingProviderFactory:
 class EmbeddingConfigManager:
     """
     Manager for embedding configuration.
-    [AC-AISVC-31] Supports hot-reload of configuration.
+    [AC-AISVC-31] Supports hot-reload of configuration with persistence.
     """

     def __init__(self, default_provider: str = "ollama", default_config: dict[str, Any] | None = None):
-        self._provider_name = default_provider
-        self._config = default_config or {
+        self._default_provider = default_provider
+        self._default_config = default_config or {
             "base_url": "http://localhost:11434",
             "model": "nomic-embed-text",
             "dimension": 768,
         }
+        self._provider_name = default_provider
+        self._config = self._default_config.copy()
         self._provider: EmbeddingProvider | None = None
+        self._load_from_file()
+
+    def _load_from_file(self) -> None:
+        """Load configuration from file if exists."""
+        try:
+            if EMBEDDING_CONFIG_FILE.exists():
+                with open(EMBEDDING_CONFIG_FILE, 'r', encoding='utf-8') as f:
+                    saved = json.load(f)
+                self._provider_name = saved.get("provider", self._default_provider)
+                self._config = saved.get("config", self._default_config.copy())
+                logger.info(f"Loaded embedding config from file: provider={self._provider_name}")
+        except Exception as e:
+            logger.warning(f"Failed to load embedding config from file: {e}")
+
+    def _save_to_file(self) -> None:
+        """Save configuration to file."""
+        try:
+            EMBEDDING_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
+            with open(EMBEDDING_CONFIG_FILE, 'w', encoding='utf-8') as f:
+                json.dump({
+                    "provider": self._provider_name,
+                    "config": self._config,
+                }, f, indent=2, ensure_ascii=False)
+            logger.info(f"Saved embedding config to file: provider={self._provider_name}")
+        except Exception as e:
+            logger.error(f"Failed to save embedding config to file: {e}")

     def get_provider_name(self) -> str:
         """Get current provider name."""
         return self._provider_name
@@ -174,7 +234,7 @@ class EmbeddingConfigManager:
     ) -> bool:
         """
         Update embedding configuration.
-        [AC-AISVC-31, AC-AISVC-40] Supports hot-reload.
+        [AC-AISVC-31, AC-AISVC-40] Supports hot-reload with persistence.

         Args:
             provider: New provider name
@@ -202,6 +262,8 @@ class EmbeddingConfigManager:
         self._config = config
         self._provider = new_provider_instance

+        self._save_to_file()
+
         logger.info(f"Updated embedding config: provider={provider}")
         return True
@@ -286,7 +348,7 @@ def get_embedding_config_manager() -> EmbeddingConfigManager:
     settings = get_settings()
     _embedding_config_manager = EmbeddingConfigManager(
-        default_provider="ollama",
+        default_provider="nomic",
         default_config={
             "base_url": settings.ollama_base_url,
             "model": settings.ollama_embedding_model,
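The new `_load_from_file`/`_save_to_file` pair amounts to a small JSON round-trip keyed on `provider` and `config`. A self-contained sketch of the same behavior (the temp path is for illustration; the service writes to `config/embedding_config.json`):

```python
import json
import tempfile
from pathlib import Path

# Illustrative location; the real service uses config/embedding_config.json
config_file = Path(tempfile.mkdtemp()) / "embedding_config.json"

def save_config(provider: str, config: dict) -> None:
    """Persist the active provider and its config as pretty-printed JSON."""
    config_file.parent.mkdir(parents=True, exist_ok=True)
    config_file.write_text(
        json.dumps({"provider": provider, "config": config}, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )

def load_config(default_provider: str = "ollama") -> tuple[str, dict]:
    """Restore the saved provider/config, falling back to defaults when no file exists."""
    if config_file.exists():
        saved = json.loads(config_file.read_text(encoding="utf-8"))
        return saved.get("provider", default_provider), saved.get("config", {})
    return default_provider, {}

save_config("nomic", {"model": "nomic-embed-text", "dimension": 768})
print(load_config())  # ('nomic', {'model': 'nomic-embed-text', 'dimension': 768})
```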

@@ -149,6 +149,7 @@ class NomicEmbeddingProvider(EmbeddingProvider):
         embedding_256 = self._truncate_and_normalize(embedding, 256)
         embedding_512 = self._truncate_and_normalize(embedding, 512)
+        embedding_full = self._truncate_and_normalize(embedding, len(embedding))

         logger.debug(
             f"Generated Nomic embedding: task={task.value}, "
@@ -156,7 +157,7 @@ class NomicEmbeddingProvider(EmbeddingProvider):
         )

         return NomicEmbeddingResult(
-            embedding_full=embedding,
+            embedding_full=embedding_full,
             embedding_256=embedding_256,
             embedding_512=embedding_512,
             dimension=len(embedding),
@@ -259,26 +260,31 @@ class NomicEmbeddingProvider(EmbeddingProvider):
         return {
             "base_url": {
                 "type": "string",
+                "title": "API 地址",
                 "description": "Ollama API 地址",
                 "default": "http://localhost:11434",
             },
             "model": {
                 "type": "string",
+                "title": "模型名称",
                 "description": "嵌入模型名称(推荐 nomic-embed-text v1.5)",
                 "default": "nomic-embed-text",
             },
             "dimension": {
                 "type": "integer",
+                "title": "向量维度",
                 "description": "向量维度(支持 256/512/768)",
                 "default": 768,
             },
             "timeout_seconds": {
                 "type": "integer",
+                "title": "超时时间",
                 "description": "请求超时时间(秒)",
                 "default": 60,
             },
             "enable_matryoshka": {
                 "type": "boolean",
+                "title": "Matryoshka 截断",
                 "description": "启用 Matryoshka 维度截断",
                 "default": True,
             },

@@ -130,21 +130,25 @@ class OllamaEmbeddingProvider(EmbeddingProvider):
         return {
             "base_url": {
                 "type": "string",
+                "title": "API 地址",
                 "description": "Ollama API 地址",
                 "default": "http://localhost:11434",
             },
             "model": {
                 "type": "string",
+                "title": "模型名称",
                 "description": "嵌入模型名称",
                 "default": "nomic-embed-text",
             },
             "dimension": {
                 "type": "integer",
+                "title": "向量维度",
                 "description": "向量维度",
                 "default": 768,
             },
             "timeout_seconds": {
                 "type": "integer",
+                "title": "超时时间",
                 "description": "请求超时时间(秒)",
                 "default": 60,
             },

@@ -159,28 +159,33 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
         return {
             "api_key": {
                 "type": "string",
+                "title": "API 密钥",
                 "description": "OpenAI API 密钥",
                 "required": True,
                 "secret": True,
             },
             "model": {
                 "type": "string",
+                "title": "模型名称",
                 "description": "嵌入模型名称",
                 "default": "text-embedding-3-small",
                 "enum": list(self.MODEL_DIMENSIONS.keys()),
             },
             "base_url": {
                 "type": "string",
+                "title": "API 地址",
                 "description": "OpenAI API 地址(支持兼容接口)",
                 "default": "https://api.openai.com/v1",
             },
             "dimension": {
                 "type": "integer",
+                "title": "向量维度",
                 "description": "向量维度(仅 text-embedding-3 系列支持自定义)",
                 "default": 1536,
             },
             "timeout_seconds": {
                 "type": "integer",
+                "title": "超时时间",
                 "description": "请求超时时间(秒)",
                 "default": 60,
             },

@@ -5,8 +5,10 @@ LLM Provider Factory and Configuration Management.
 Design pattern: Factory pattern for pluggable LLM providers.
 """
+import json
 import logging
 from dataclasses import dataclass, field
+from pathlib import Path
 from typing import Any

 from app.services.llm.base import LLMClient, LLMConfig
@@ -14,6 +16,8 @@ from app.services.llm.openai_client import OpenAIClient

 logger = logging.getLogger(__name__)

+LLM_CONFIG_FILE = Path("config/llm_config.json")
+
 @dataclass
 class LLMProviderInfo:
@@ -257,7 +261,7 @@ class LLMProviderFactory:
 class LLMConfigManager:
     """
     Manager for LLM configuration.
-    [AC-ASA-16, AC-ASA-17, AC-ASA-18] Configuration management with hot-reload.
+    [AC-ASA-16, AC-ASA-17, AC-ASA-18] Configuration management with hot-reload and persistence.
     """

     def __init__(self):
@@ -275,11 +279,40 @@ class LLMConfigManager:
         }
         self._client: LLMClient | None = None
+        self._load_from_file()
+
+    def _load_from_file(self) -> None:
+        """Load configuration from file if exists."""
+        try:
+            if LLM_CONFIG_FILE.exists():
+                with open(LLM_CONFIG_FILE, 'r', encoding='utf-8') as f:
+                    saved = json.load(f)
+                self._current_provider = saved.get("provider", self._current_provider)
+                saved_config = saved.get("config", {})
+                if saved_config:
+                    self._current_config.update(saved_config)
+                logger.info(f"[AC-ASA-16] Loaded LLM config from file: provider={self._current_provider}")
+        except Exception as e:
+            logger.warning(f"[AC-ASA-16] Failed to load LLM config from file: {e}")
+
+    def _save_to_file(self) -> None:
+        """Save configuration to file."""
+        try:
+            LLM_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
+            with open(LLM_CONFIG_FILE, 'w', encoding='utf-8') as f:
+                json.dump({
+                    "provider": self._current_provider,
+                    "config": self._current_config,
+                }, f, indent=2, ensure_ascii=False)
+            logger.info(f"[AC-ASA-16] Saved LLM config to file: provider={self._current_provider}")
+        except Exception as e:
+            logger.error(f"[AC-ASA-16] Failed to save LLM config to file: {e}")

     def get_current_config(self) -> dict[str, Any]:
         """Get current LLM configuration."""
         return {
             "provider": self._current_provider,
-            "config": self._current_config,
+            "config": self._current_config.copy(),
         }

     async def update_config(
@@ -289,7 +322,7 @@ class LLMConfigManager:
     ) -> bool:
         """
         Update LLM configuration.
-        [AC-ASA-16] Hot-reload configuration.
+        [AC-ASA-16] Hot-reload configuration with persistence.

         Args:
             provider: Provider name
@@ -311,6 +344,8 @@ class LLMConfigManager:
         self._current_provider = provider
         self._current_config = validated_config

+        self._save_to_file()
+
         logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")
         return True
@@ -365,7 +400,7 @@ class LLMConfigManager:
         test_provider = provider or self._current_provider
         test_config = config if config else self._current_config

-        logger.info(f"[AC-ASA-17] Test connection: provider={test_provider}, config={test_config}")
+        logger.info(f"[AC-ASA-17] Test connection: provider={test_provider}, model={test_config.get('model')}")

         if test_provider not in LLM_PROVIDERS:
             return {

@@ -119,13 +119,7 @@ class OrchestratorService:
             max_evidence_tokens=getattr(settings, "rag_max_evidence_tokens", 2000),
             enable_rag=True,
         )
-        self._llm_config = LLMConfig(
-            model=getattr(settings, "llm_model", "gpt-4o-mini"),
-            max_tokens=getattr(settings, "llm_max_tokens", 2048),
-            temperature=getattr(settings, "llm_temperature", 0.7),
-            timeout_seconds=getattr(settings, "llm_timeout_seconds", 30),
-            max_retries=getattr(settings, "llm_max_retries", 3),
-        )
+        self._llm_config: LLMConfig | None = None

     async def generate(
         self,
@@ -345,7 +339,6 @@ class OrchestratorService:
         try:
             ctx.llm_response = await self._llm_client.generate(
                 messages=messages,
-                config=self._llm_config,
             )
             ctx.diagnostics["llm_mode"] = "live"
             ctx.diagnostics["llm_model"] = ctx.llm_response.model
@@ -627,7 +620,7 @@ class OrchestratorService:
         """
         messages = self._build_llm_messages(ctx)

-        async for chunk in self._llm_client.stream_generate(messages, self._llm_config):
+        async for chunk in self._llm_client.stream_generate(messages):
             if not state_machine.can_send_message():
                 break
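The streaming loop checks the session state machine between chunks, so generation stops as soon as the client can no longer receive messages. A toy async sketch of that pattern (`fake_stream` and the budget-based `can_send` stand in for the real LLM stream and state machine):

```python
import asyncio

async def fake_stream():
    """Stand-in for the LLM's streaming generator."""
    for chunk in ["Hel", "lo ", "wor", "ld"]:
        yield chunk

async def stream_with_cancel(can_send_message) -> str:
    """Consume chunks, bailing out as soon as sending is no longer allowed."""
    sent = []
    async for chunk in fake_stream():
        if not can_send_message():
            break
        sent.append(chunk)
    return "".join(sent)

# Allow exactly two chunks through, then simulate a closed connection
budget = {"left": 2}
def can_send():
    budget["left"] -= 1
    return budget["left"] >= 0

print(asyncio.run(stream_with_cancel(can_send)))  # Hello 
```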

@@ -84,7 +84,13 @@ class RRFCombiner:
                     "bm25_rank": -1,
                     "payload": result.get("payload", {}),
                     "id": chunk_id,
+                    "vector": result.get("vector"),
                 }
+            else:
+                combined_scores[chunk_id]["vector_score"] = result.get("score", 0.0)
+                combined_scores[chunk_id]["vector_rank"] = rank
+                if result.get("vector"):
+                    combined_scores[chunk_id]["vector"] = result.get("vector")
             combined_scores[chunk_id]["score"] += rrf_score
@@ -101,6 +107,7 @@
                     "bm25_rank": rank,
                     "payload": result.get("payload", {}),
                     "id": chunk_id,
+                    "vector": result.get("vector"),
                 }
             else:
                 combined_scores[chunk_id]["bm25_score"] = result.get("score", 0.0)
@@ -131,7 +138,6 @@ class OptimizedRetriever(BaseRetriever):
     def __init__(
         self,
         qdrant_client: QdrantClient | None = None,
-        embedding_provider: NomicEmbeddingProvider | None = None,
         top_k: int | None = None,
         score_threshold: float | None = None,
         min_hits: int | None = None,
@@ -141,7 +147,6 @@
         rrf_k: int | None = None,
     ):
         self._qdrant_client = qdrant_client
-        self._embedding_provider = embedding_provider
         self._top_k = top_k or settings.rag_top_k
         self._score_threshold = score_threshold or settings.rag_score_threshold
         self._min_hits = min_hits or settings.rag_min_hits
@@ -157,19 +162,17 @@
         return self._qdrant_client

     async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
-        if self._embedding_provider is None:
         from app.services.embedding.factory import get_embedding_config_manager
         manager = get_embedding_config_manager()
         provider = await manager.get_provider()
         if isinstance(provider, NomicEmbeddingProvider):
-            self._embedding_provider = provider
+            return provider
         else:
-            self._embedding_provider = NomicEmbeddingProvider(
+            return NomicEmbeddingProvider(
                 base_url=settings.ollama_base_url,
                 model=settings.ollama_embedding_model,
                 dimension=settings.qdrant_vector_size,
             )
-        return self._embedding_provider

     async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
         """
@@ -199,7 +202,15 @@
             f"dim_256={'available' if embedding_result.embedding_256 else 'not available'}"
         )
-        if self._two_stage_enabled:
+        if self._two_stage_enabled and self._hybrid_enabled:
+            logger.info("[RAG-OPT] Using two-stage + hybrid retrieval strategy")
+            results = await self._two_stage_hybrid_retrieve(
+                ctx.tenant_id,
+                embedding_result,
+                ctx.query,
+                self._top_k,
+            )
+        elif self._two_stage_enabled:
             logger.info("[RAG-OPT] Using two-stage retrieval strategy")
             results = await self._two_stage_retrieve(
                 ctx.tenant_id,
@@ -300,20 +311,27 @@
         stage1_start = time.perf_counter()
         candidates = await self._search_with_dimension(
             client, tenant_id, embedding_result.embedding_256, "dim_256",
-            top_k * self._two_stage_expand_factor
+            top_k * self._two_stage_expand_factor,
+            with_vectors=True,
         )
         stage1_latency = (time.perf_counter() - stage1_start) * 1000
-        logger.debug(
+        logger.info(
             f"[RAG-OPT] Stage 1: {len(candidates)} candidates in {stage1_latency:.2f}ms"
         )
         stage2_start = time.perf_counter()
         reranked = []
         for candidate in candidates:
-            stored_full_embedding = candidate.get("payload", {}).get("embedding_full", [])
-            if stored_full_embedding:
-                import numpy as np
+            vector_data = candidate.get("vector", {})
+            stored_full_embedding = None
+            if isinstance(vector_data, dict):
+                stored_full_embedding = vector_data.get("full", [])
+            elif isinstance(vector_data, list):
+                stored_full_embedding = vector_data
+            if stored_full_embedding and len(stored_full_embedding) > 0:
                 similarity = self._cosine_similarity(
                     embedding_result.embedding_full,
                     stored_full_embedding
@@ -326,7 +344,7 @@
         results = reranked[:top_k]
         stage2_latency = (time.perf_counter() - stage2_start) * 1000
-        logger.debug(
+        logger.info(
             f"[RAG-OPT] Stage 2: {len(results)} final results in {stage2_latency:.2f}ms"
         )
@@ -374,6 +392,92 @@
         return combined[:top_k]
async def _two_stage_hybrid_retrieve(
self,
tenant_id: str,
embedding_result: NomicEmbeddingResult,
query: str,
top_k: int,
) -> list[dict[str, Any]]:
"""
Two-stage + Hybrid retrieval strategy.
Stage 1: Fast retrieval with 256-dim vectors + BM25 in parallel
Stage 2: RRF fusion + Precise reranking with 768-dim vectors
This combines the best of both worlds:
- Two-stage: Speed from 256-dim, precision from 768-dim reranking
- Hybrid: Semantic matching from vectors, keyword matching from BM25
"""
import time
client = await self._get_client()
stage1_start = time.perf_counter()
vector_task = self._search_with_dimension(
client, tenant_id, embedding_result.embedding_256, "dim_256",
top_k * self._two_stage_expand_factor,
with_vectors=True,
)
bm25_task = self._bm25_search(client, tenant_id, query, top_k * self._two_stage_expand_factor)
vector_results, bm25_results = await asyncio.gather(
vector_task, bm25_task, return_exceptions=True
)
if isinstance(vector_results, Exception):
logger.warning(f"[RAG-OPT] Vector search failed: {vector_results}")
vector_results = []
if isinstance(bm25_results, Exception):
logger.warning(f"[RAG-OPT] BM25 search failed: {bm25_results}")
bm25_results = []
stage1_latency = (time.perf_counter() - stage1_start) * 1000
logger.info(
f"[RAG-OPT] Two-stage Hybrid Stage 1: vector={len(vector_results)}, bm25={len(bm25_results)}, latency={stage1_latency:.2f}ms"
)
stage2_start = time.perf_counter()
combined = self._rrf_combiner.combine(
vector_results,
bm25_results,
vector_weight=settings.rag_vector_weight,
bm25_weight=settings.rag_bm25_weight,
)
reranked = []
for candidate in combined[:top_k * 2]:
vector_data = candidate.get("vector", {})
stored_full_embedding = None
if isinstance(vector_data, dict):
stored_full_embedding = vector_data.get("full", [])
elif isinstance(vector_data, list):
stored_full_embedding = vector_data
if stored_full_embedding and len(stored_full_embedding) > 0:
similarity = self._cosine_similarity(
embedding_result.embedding_full,
stored_full_embedding
)
candidate["score"] = similarity
candidate["stage"] = "two_stage_hybrid_reranked"
reranked.append(candidate)
reranked.sort(key=lambda x: x.get("score", 0), reverse=True)
results = reranked[:top_k]
stage2_latency = (time.perf_counter() - stage2_start) * 1000
logger.info(
f"[RAG-OPT] Two-stage Hybrid Stage 2 (reranking): {len(results)} final results in {stage2_latency:.2f}ms"
)
return results
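The docstring above describes fusing the vector and BM25 rankings with RRF before the 768-dim rerank. The scoring rule itself can be sketched standalone; this is a hypothetical simplified version with the conventional `k = 60`, not the service's actual `RRFCombiner`:

```python
def rrf_combine(vector_ids, bm25_ids, k=60, vector_weight=1.0, bm25_weight=1.0):
    """Reciprocal Rank Fusion: score(d) = sum over lists of weight / (k + rank)."""
    scores: dict[str, float] = {}
    for rank, doc_id in enumerate(vector_ids, start=1):
        scores[doc_id] = scores.get(doc_id, 0.0) + vector_weight / (k + rank)
    for rank, doc_id in enumerate(bm25_ids, start=1):
        scores[doc_id] = scores.get(doc_id, 0.0) + bm25_weight / (k + rank)
    # Highest fused score first
    return sorted(scores, key=scores.get, reverse=True)

# "b" ranks near the top of both lists, so it comes out first after fusion.
fused = rrf_combine(["a", "b", "c"], ["b", "d", "a"])
```

A document that ranks moderately well in both lists can overtake one that ranks first in only one, which is why the combiner keeps separate `vector_rank`/`bm25_rank` fields for diagnostics.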
     async def _vector_retrieve(
         self,
         tenant_id: str,
@@ -393,45 +497,37 @@
         query_vector: list[float],
         vector_name: str,
         limit: int,
+        with_vectors: bool = False,
     ) -> list[dict[str, Any]]:
         """Search using specified vector dimension."""
         try:
-            qdrant = await client.get_client()
-            collection_name = client.get_collection_name(tenant_id)
             logger.info(
-                f"[RAG-OPT] Searching collection={collection_name}, "
-                f"vector_name={vector_name}, limit={limit}, vector_dim={len(query_vector)}"
+                f"[RAG-OPT] Searching with vector_name={vector_name}, "
+                f"limit={limit}, vector_dim={len(query_vector)}, with_vectors={with_vectors}"
             )
-            results = await qdrant.search(
-                collection_name=collection_name,
-                query_vector=(vector_name, query_vector),
+            results = await client.search(
+                tenant_id=tenant_id,
+                query_vector=query_vector,
                 limit=limit,
+                vector_name=vector_name,
+                with_vectors=with_vectors,
             )
             logger.info(
-                f"[RAG-OPT] Search returned {len(results)} results from collection={collection_name}"
+                f"[RAG-OPT] Search returned {len(results)} results"
             )
             if len(results) > 0:
                 for i, r in enumerate(results[:3]):
                     logger.debug(
-                        f"[RAG-OPT] Result {i+1}: id={r.id}, score={r.score:.4f}"
+                        f"[RAG-OPT] Result {i+1}: id={r['id']}, score={r['score']:.4f}"
                     )
-            return [
-                {
-                    "id": str(result.id),
-                    "score": result.score,
-                    "payload": result.payload or {},
-                }
-                for result in results
-            ]
+            return results
         except Exception as e:
             logger.error(
-                f"[RAG-OPT] Search with {vector_name} failed: {e}, "
-                f"collection_name={client.get_collection_name(tenant_id)}",
+                f"[RAG-OPT] Search with {vector_name} failed: {e}",
                 exc_info=True
             )
             return []
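Stage 2 in both retrieval paths reranks candidates with `self._cosine_similarity` between the full 768-dim query embedding and each stored full vector. A minimal pure-Python sketch of that similarity (the standard formula; the service's actual helper may be vectorized differently):

```python
import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)

# Parallel vectors score ~1.0; orthogonal vectors score 0.0.
same = cosine_similarity([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
orth = cosine_similarity([1.0, 0.0], [0.0, 1.0])
```

Because only direction matters, reranking by cosine over the full vectors corrects for precision lost in the truncated 256-dim first stage.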


@@ -14,12 +14,13 @@ dependencies = [
     "tenacity>=8.2.0",
     "sqlmodel>=0.0.14",
     "asyncpg>=0.29.0",
-    "qdrant-client>=1.7.0",
+    "qdrant-client>=1.9.0,<2.0.0",
     "tiktoken>=0.5.0",
     "openpyxl>=3.1.0",
     "python-docx>=1.1.0",
     "pymupdf>=1.23.0",
     "pdfplumber>=0.10.0",
+    "python-multipart>=0.0.6",
 ]

 [project.optional-dependencies]


@@ -0,0 +1,89 @@
"""
Script to cleanup Qdrant collections and data.
"""
import asyncio
import logging
import sys
from pathlib import Path

# Make the ai-service package importable regardless of where the repo is checked out
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from app.core.config import get_settings
from app.core.qdrant_client import get_qdrant_client
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def list_collections():
"""List all collections in Qdrant."""
client = await get_qdrant_client()
qdrant = await client.get_client()
collections = await qdrant.get_collections()
return [c.name for c in collections.collections]
async def delete_collection(collection_name: str):
"""Delete a specific collection."""
client = await get_qdrant_client()
qdrant = await client.get_client()
try:
await qdrant.delete_collection(collection_name)
logger.info(f"Deleted collection: {collection_name}")
return True
except Exception as e:
logger.error(f"Failed to delete collection {collection_name}: {e}")
return False
async def delete_all_collections():
"""Delete all collections."""
collections = await list_collections()
logger.info(f"Found {len(collections)} collections: {collections}")
for name in collections:
await delete_collection(name)
logger.info("All collections deleted")
async def delete_tenant_collection(tenant_id: str):
"""Delete collection for a specific tenant."""
client = await get_qdrant_client()
collection_name = client.get_collection_name(tenant_id)
success = await delete_collection(collection_name)
if success:
logger.info(f"Deleted collection for tenant: {tenant_id}")
return success
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Cleanup Qdrant data")
parser.add_argument("--all", action="store_true", help="Delete all collections")
parser.add_argument("--tenant", type=str, help="Delete collection for specific tenant")
parser.add_argument("--list", action="store_true", help="List all collections")
args = parser.parse_args()
if args.list:
collections = asyncio.run(list_collections())
print(f"Collections: {collections}")
elif args.all:
confirm = input("Are you sure you want to delete ALL collections? (yes/no): ")
if confirm.lower() == "yes":
asyncio.run(delete_all_collections())
else:
print("Cancelled")
elif args.tenant:
confirm = input(f"Delete collection for tenant '{args.tenant}'? (yes/no): ")
if confirm.lower() == "yes":
asyncio.run(delete_tenant_collection(args.tenant))
else:
print("Cancelled")
else:
parser.print_help()


@@ -28,6 +28,13 @@ CREATE TABLE IF NOT EXISTS chat_messages (
     session_id VARCHAR NOT NULL,
     role VARCHAR NOT NULL,
     content TEXT NOT NULL,
+    prompt_tokens INTEGER,
+    completion_tokens INTEGER,
+    total_tokens INTEGER,
+    latency_ms INTEGER,
+    first_token_ms INTEGER,
+    is_error BOOLEAN NOT NULL DEFAULT FALSE,
+    error_message VARCHAR,
     created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL
 );
@@ -74,6 +81,18 @@
     updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL
 );

+-- ============================================
+-- API Keys Table [AC-AISVC-50]
+-- ============================================
+CREATE TABLE IF NOT EXISTS api_keys (
+    id UUID NOT NULL PRIMARY KEY,
+    key VARCHAR NOT NULL UNIQUE,
+    name VARCHAR NOT NULL,
+    is_active BOOLEAN NOT NULL DEFAULT TRUE,
+    created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL,
+    updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL
+);
+
 -- ============================================
 -- Indexes
 -- ============================================
@@ -100,6 +119,10 @@ CREATE INDEX IF NOT EXISTS ix_index_jobs_tenant_id ON index_jobs (tenant_id);
 CREATE INDEX IF NOT EXISTS ix_index_jobs_tenant_doc ON index_jobs (tenant_id, doc_id);
 CREATE INDEX IF NOT EXISTS ix_index_jobs_tenant_status ON index_jobs (tenant_id, status);

+-- API Keys Indexes [AC-AISVC-50]
+CREATE INDEX IF NOT EXISTS ix_api_keys_key ON api_keys (key);
+CREATE INDEX IF NOT EXISTS ix_api_keys_is_active ON api_keys (is_active);
+
 -- ============================================
 -- Verification
 -- ============================================


@@ -0,0 +1,29 @@
-- Migration: Add missing columns to chat_messages table
-- Execute this on existing database to add new columns
-- Add token tracking columns
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS prompt_tokens INTEGER;
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS completion_tokens INTEGER;
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS total_tokens INTEGER;
-- Add latency tracking columns
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS latency_ms INTEGER;
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS first_token_ms INTEGER;
-- Add error tracking columns
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS is_error BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE chat_messages ADD COLUMN IF NOT EXISTS error_message VARCHAR;
-- Create API Keys table if not exists
CREATE TABLE IF NOT EXISTS api_keys (
id UUID NOT NULL PRIMARY KEY,
key VARCHAR NOT NULL UNIQUE,
name VARCHAR NOT NULL,
is_active BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL,
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL
);
-- Create API Keys indexes
CREATE INDEX IF NOT EXISTS ix_api_keys_key ON api_keys (key);
CREATE INDEX IF NOT EXISTS ix_api_keys_is_active ON api_keys (is_active);

deploy/nginx.conf.example Normal file

@@ -0,0 +1,134 @@
# AI Service Nginx Configuration
# Place this file at /etc/nginx/conf.d/ai-service.conf,
# or include it from the main configuration file.

# Backend API upstream (called by the Java channel side)
upstream ai_service_backend {
server 127.0.0.1:8182;
}
# Frontend admin UI upstream
upstream ai_service_admin {
server 127.0.0.1:8181;
}
# Frontend admin UI
server {
listen 80;
server_name your-domain.com; # Replace with your domain or server IP
# Access logs
access_log /var/log/nginx/ai-service-admin.access.log;
error_log /var/log/nginx/ai-service-admin.error.log;
location / {
proxy_pass http://ai_service_admin;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection 'upgrade';
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_cache_bypass $http_upgrade;
# SSE streaming response support
proxy_read_timeout 300s;
proxy_connect_timeout 75s;
proxy_buffering off;
}
}
# Backend API (called by the Java channel side)
# If you use a domain, this can live on a separate path or subdomain
# Examples: api.your-domain.com or your-domain.com/api/
server {
listen 80;
server_name api.your-domain.com; # Replace with your API subdomain
# Access logs
access_log /var/log/nginx/ai-service-api.access.log;
error_log /var/log/nginx/ai-service-api.error.log;
location / {
proxy_pass http://ai_service_backend;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection 'upgrade';
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_cache_bypass $http_upgrade;
# SSE streaming response support
proxy_read_timeout 300s;
proxy_connect_timeout 75s;
proxy_buffering off;
}
}
# ============================================================
# HTTPS configuration example (using Let's Encrypt)
# ============================================================
# server {
# listen 443 ssl http2;
# server_name your-domain.com;
#
# ssl_certificate /etc/letsencrypt/live/your-domain.com/fullchain.pem;
# ssl_certificate_key /etc/letsencrypt/live/your-domain.com/privkey.pem;
#
# ssl_protocols TLSv1.2 TLSv1.3;
# ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256;
# ssl_prefer_server_ciphers off;
#
# location / {
# proxy_pass http://ai_service_admin;
# proxy_http_version 1.1;
# proxy_set_header Upgrade $http_upgrade;
# proxy_set_header Connection 'upgrade';
# proxy_set_header Host $host;
# proxy_set_header X-Real-IP $remote_addr;
# proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# proxy_set_header X-Forwarded-Proto $scheme;
# proxy_cache_bypass $http_upgrade;
# proxy_read_timeout 300s;
# proxy_connect_timeout 75s;
# proxy_buffering off;
# }
# }
# server {
# listen 443 ssl http2;
# server_name api.your-domain.com;
#
# ssl_certificate /etc/letsencrypt/live/your-domain.com/fullchain.pem;
# ssl_certificate_key /etc/letsencrypt/live/your-domain.com/privkey.pem;
#
# ssl_protocols TLSv1.2 TLSv1.3;
# ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256;
# ssl_prefer_server_ciphers off;
#
# location / {
# proxy_pass http://ai_service_backend;
# proxy_http_version 1.1;
# proxy_set_header Upgrade $http_upgrade;
# proxy_set_header Connection 'upgrade';
# proxy_set_header Host $host;
# proxy_set_header X-Real-IP $remote_addr;
# proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# proxy_set_header X-Forwarded-Proto $scheme;
# proxy_cache_bypass $http_upgrade;
# proxy_read_timeout 300s;
# proxy_connect_timeout 75s;
# proxy_buffering off;
# }
# }
# Redirect HTTP to HTTPS
# server {
# listen 80;
# server_name your-domain.com api.your-domain.com;
# return 301 https://$server_name$request_uri;
# }

docker-compose.yaml Normal file

@@ -0,0 +1,108 @@
services:
ai-service:
build:
context: ./ai-service
dockerfile: Dockerfile
container_name: ai-service
restart: unless-stopped
ports:
- "8182:8080"
environment:
- AI_SERVICE_DEBUG=false
- AI_SERVICE_LOG_LEVEL=INFO
- AI_SERVICE_DATABASE_URL=postgresql+asyncpg://postgres:postgres@postgres:5432/ai_service
- AI_SERVICE_QDRANT_URL=http://qdrant:6333
- AI_SERVICE_LLM_PROVIDER=${AI_SERVICE_LLM_PROVIDER:-openai}
- AI_SERVICE_LLM_API_KEY=${AI_SERVICE_LLM_API_KEY:-}
- AI_SERVICE_LLM_BASE_URL=${AI_SERVICE_LLM_BASE_URL:-https://api.openai.com/v1}
- AI_SERVICE_LLM_MODEL=${AI_SERVICE_LLM_MODEL:-gpt-4o-mini}
- AI_SERVICE_OLLAMA_BASE_URL=${AI_SERVICE_OLLAMA_BASE_URL:-http://ollama:11434}
volumes:
- ai_service_config:/app/config
depends_on:
postgres:
condition: service_healthy
qdrant:
condition: service_started
networks:
- ai-network
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8080/ai/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 40s
ai-service-admin:
build:
context: ./ai-service-admin
dockerfile: Dockerfile
args:
VITE_APP_API_KEY: ${VITE_APP_API_KEY:-}
VITE_APP_BASE_API: /api
container_name: ai-service-admin
restart: unless-stopped
ports:
- "8183:80"
depends_on:
- ai-service
networks:
- ai-network
postgres:
image: postgres:15-alpine
container_name: ai-postgres
restart: unless-stopped
environment:
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=postgres
- POSTGRES_DB=ai_service
volumes:
- postgres_data:/var/lib/postgresql/data
- ./ai-service/scripts/init_db.sql:/docker-entrypoint-initdb.d/init_db.sql:ro
ports:
- "5432:5432"
networks:
- ai-network
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres -d ai_service"]
interval: 10s
timeout: 5s
retries: 5
qdrant:
image: qdrant/qdrant:latest
container_name: ai-qdrant
restart: unless-stopped
ports:
- "6333:6333"
- "6334:6334"
volumes:
- qdrant_data:/qdrant/storage
networks:
- ai-network
ollama:
image: ollama/ollama:latest
container_name: ai-ollama
restart: unless-stopped
ports:
- "11434:11434"
volumes:
- ollama_data:/root/.ollama
networks:
- ai-network
deploy:
resources:
reservations:
memory: 1G
networks:
ai-network:
driver: bridge
volumes:
postgres_data:
qdrant_data:
ollama_data:
ai_service_config: