Merge pull request '[AC-AISVC-02, AC-AISVC-16] 多个需求合并' (#1) from feature/prompt-unification-and-logging into main
Reviewed-on: http://ashai.com.cn:3005/MerCry/ai-robot-core/pulls/1
This commit is contained in:
commit
1e3fe808e8
|
|
@ -158,5 +158,9 @@ cython_debug/
|
|||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
.idea/
|
||||
|
||||
# Project specific
|
||||
ai-service/uploads/
|
||||
*.local
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
node_modules/
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<link rel="icon" href="/favicon.ico" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>AI Service Admin</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="app"></div>
|
||||
<script type="module" src="/src/main.ts"></script>
|
||||
</body>
|
||||
</html>
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,24 @@
|
|||
{
|
||||
"name": "ai-service-admin",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vue-tsc --noEmit && vite build",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@element-plus/icons-vue": "^2.3.1",
|
||||
"axios": "^1.6.7",
|
||||
"element-plus": "^2.6.1",
|
||||
"pinia": "^2.1.7",
|
||||
"vue": "^3.4.21",
|
||||
"vue-router": "^4.3.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@vitejs/plugin-vue": "^5.0.4",
|
||||
"typescript": "^5.2.2",
|
||||
"vite": "^5.1.4",
|
||||
"vue-tsc": "^1.8.27"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,279 @@
|
|||
<template>
|
||||
<div class="app-wrapper">
|
||||
<header class="app-header">
|
||||
<div class="header-left">
|
||||
<div class="logo">
|
||||
<div class="logo-icon">
|
||||
<svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M12 2L2 7L12 12L22 7L12 2Z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M2 17L12 22L22 17" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M2 12L12 17L22 12" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
</div>
|
||||
<span class="logo-text">AI Robot</span>
|
||||
</div>
|
||||
<nav class="main-nav">
|
||||
<router-link to="/dashboard" class="nav-item" :class="{ active: isActive('/dashboard') }">
|
||||
<el-icon><Odometer /></el-icon>
|
||||
<span>控制台</span>
|
||||
</router-link>
|
||||
<router-link to="/kb" class="nav-item" :class="{ active: isActive('/kb') }">
|
||||
<el-icon><FolderOpened /></el-icon>
|
||||
<span>知识库</span>
|
||||
</router-link>
|
||||
<router-link to="/rag-lab" class="nav-item" :class="{ active: isActive('/rag-lab') }">
|
||||
<el-icon><Cpu /></el-icon>
|
||||
<span>RAG 实验室</span>
|
||||
</router-link>
|
||||
<router-link to="/monitoring" class="nav-item" :class="{ active: isActive('/monitoring') }">
|
||||
<el-icon><Monitor /></el-icon>
|
||||
<span>会话监控</span>
|
||||
</router-link>
|
||||
<div class="nav-divider"></div>
|
||||
<router-link to="/admin/embedding" class="nav-item" :class="{ active: isActive('/admin/embedding') }">
|
||||
<el-icon><Connection /></el-icon>
|
||||
<span>嵌入模型</span>
|
||||
</router-link>
|
||||
<router-link to="/admin/llm" class="nav-item" :class="{ active: isActive('/admin/llm') }">
|
||||
<el-icon><ChatDotSquare /></el-icon>
|
||||
<span>LLM 配置</span>
|
||||
</router-link>
|
||||
</nav>
|
||||
</div>
|
||||
<div class="header-right">
|
||||
<div class="tenant-selector">
|
||||
<el-select
|
||||
v-model="currentTenantId"
|
||||
placeholder="选择租户"
|
||||
size="default"
|
||||
:loading="loading"
|
||||
@change="handleTenantChange"
|
||||
>
|
||||
<el-option
|
||||
v-for="tenant in tenantList"
|
||||
:key="tenant.id"
|
||||
:label="tenant.name"
|
||||
:value="tenant.id"
|
||||
/>
|
||||
</el-select>
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
<main class="app-main">
|
||||
<router-view />
|
||||
</main>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { useRoute } from 'vue-router'
|
||||
import { useTenantStore } from '@/stores/tenant'
|
||||
import { getTenantList, type Tenant } from '@/api/tenant'
|
||||
import { Odometer, FolderOpened, Cpu, Monitor, Connection, ChatDotSquare } from '@element-plus/icons-vue'
|
||||
import { ElMessage } from 'element-plus'
|
||||
|
||||
const route = useRoute()
|
||||
const tenantStore = useTenantStore()
|
||||
|
||||
const currentTenantId = ref(tenantStore.currentTenantId)
|
||||
const tenantList = ref<Tenant[]>([])
|
||||
const loading = ref(false)
|
||||
|
||||
const isActive = (path: string) => {
|
||||
return route.path === path || route.path.startsWith(path + '/')
|
||||
}
|
||||
|
||||
const handleTenantChange = (val: string) => {
|
||||
tenantStore.setTenant(val)
|
||||
// 刷新页面以加载新租户的数据
|
||||
window.location.reload()
|
||||
}
|
||||
|
||||
// Validate tenant ID format: name@ash@year
|
||||
const isValidTenantId = (tenantId: string): boolean => {
|
||||
return /^[^@]+@ash@\d{4}$/.test(tenantId)
|
||||
}
|
||||
|
||||
const fetchTenantList = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
// 检查当前租户ID格式是否有效
|
||||
if (!isValidTenantId(currentTenantId.value)) {
|
||||
console.warn('Invalid tenant ID format, resetting to default:', currentTenantId.value)
|
||||
currentTenantId.value = 'default@ash@2026'
|
||||
tenantStore.setTenant(currentTenantId.value)
|
||||
}
|
||||
|
||||
const response = await getTenantList()
|
||||
tenantList.value = response.tenants || []
|
||||
|
||||
// 如果当前租户不在列表中,默认选择第一个
|
||||
if (tenantList.value.length > 0 && !tenantList.value.find(t => t.id === currentTenantId.value)) {
|
||||
const firstTenant = tenantList.value[0]
|
||||
currentTenantId.value = firstTenant.id
|
||||
tenantStore.setTenant(firstTenant.id)
|
||||
}
|
||||
} catch (error) {
|
||||
ElMessage.error('获取租户列表失败')
|
||||
console.error('Failed to fetch tenant list:', error)
|
||||
// 失败时使用默认租户
|
||||
tenantList.value = [{ id: 'default@ash@2026', name: 'default (2026)' }]
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
fetchTenantList()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.app-wrapper {
|
||||
min-height: 100vh;
|
||||
background-color: var(--bg-primary, #F8FAFC);
|
||||
}
|
||||
|
||||
.app-header {
|
||||
position: sticky;
|
||||
top: 0;
|
||||
z-index: 100;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 0 24px;
|
||||
height: 60px;
|
||||
background-color: var(--bg-secondary, #FFFFFF);
|
||||
border-bottom: 1px solid var(--border-color, #E2E8F0);
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.04);
|
||||
}
|
||||
|
||||
.header-left {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 32px;
|
||||
}
|
||||
|
||||
.logo {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.logo-icon {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: var(--primary-color, #4F7CFF);
|
||||
}
|
||||
|
||||
.logo-icon svg {
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
}
|
||||
|
||||
.logo-text {
|
||||
font-size: 18px;
|
||||
font-weight: 700;
|
||||
color: var(--text-primary, #1E293B);
|
||||
letter-spacing: -0.5px;
|
||||
}
|
||||
|
||||
.main-nav {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.nav-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 8px 14px;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: var(--text-secondary, #64748B);
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.nav-item:hover {
|
||||
color: var(--primary-color, #4F7CFF);
|
||||
background-color: var(--primary-lighter, #E8EEFF);
|
||||
}
|
||||
|
||||
.nav-item.active {
|
||||
color: var(--primary-color, #4F7CFF);
|
||||
background-color: var(--primary-lighter, #E8EEFF);
|
||||
}
|
||||
|
||||
.nav-item .el-icon {
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.nav-divider {
|
||||
width: 1px;
|
||||
height: 20px;
|
||||
margin: 0 8px;
|
||||
background-color: var(--border-color, #E2E8F0);
|
||||
}
|
||||
|
||||
.header-right {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
.tenant-selector {
|
||||
min-width: 140px;
|
||||
}
|
||||
|
||||
.tenant-selector :deep(.el-input__wrapper) {
|
||||
background-color: var(--bg-tertiary, #F1F5F9);
|
||||
border-color: transparent;
|
||||
}
|
||||
|
||||
.tenant-selector :deep(.el-input__wrapper:hover) {
|
||||
background-color: var(--bg-hover, #E2E8F0);
|
||||
}
|
||||
|
||||
.app-main {
|
||||
min-height: calc(100vh - 60px);
|
||||
}
|
||||
|
||||
@media (max-width: 1024px) {
|
||||
.app-header {
|
||||
padding: 0 16px;
|
||||
}
|
||||
|
||||
.header-left {
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
.main-nav {
|
||||
gap: 2px;
|
||||
}
|
||||
|
||||
.nav-item {
|
||||
padding: 8px 10px;
|
||||
}
|
||||
|
||||
.nav-item span {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.nav-divider {
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 640px) {
|
||||
.logo-text {
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
import request from '@/utils/request'
|
||||
|
||||
export function getDashboardStats() {
|
||||
return request({
|
||||
url: '/admin/dashboard/stats',
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
import request from '@/utils/request'
|
||||
|
||||
export interface EmbeddingProviderInfo {
|
||||
name: string
|
||||
display_name: string
|
||||
description?: string
|
||||
config_schema: Record<string, any>
|
||||
}
|
||||
|
||||
export interface EmbeddingConfig {
|
||||
provider: string
|
||||
config: Record<string, any>
|
||||
updated_at?: string
|
||||
}
|
||||
|
||||
export interface EmbeddingConfigUpdate {
|
||||
provider: string
|
||||
config?: Record<string, any>
|
||||
}
|
||||
|
||||
export interface EmbeddingTestResult {
|
||||
success: boolean
|
||||
dimension: number
|
||||
latency_ms?: number
|
||||
message?: string
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface EmbeddingTestRequest {
|
||||
test_text?: string
|
||||
config?: EmbeddingConfigUpdate
|
||||
}
|
||||
|
||||
export interface DocumentFormat {
|
||||
extension: string
|
||||
name: string
|
||||
description?: string
|
||||
}
|
||||
|
||||
export interface EmbeddingProvidersResponse {
|
||||
providers: EmbeddingProviderInfo[]
|
||||
}
|
||||
|
||||
export interface EmbeddingConfigUpdateResponse {
|
||||
success: boolean
|
||||
message: string
|
||||
}
|
||||
|
||||
export interface SupportedFormatsResponse {
|
||||
formats: DocumentFormat[]
|
||||
}
|
||||
|
||||
export function getProviders() {
|
||||
return request({
|
||||
url: '/embedding/providers',
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
||||
export function getConfig() {
|
||||
return request({
|
||||
url: '/embedding/config',
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
||||
export function saveConfig(data: EmbeddingConfigUpdate) {
|
||||
return request({
|
||||
url: '/embedding/config',
|
||||
method: 'put',
|
||||
data
|
||||
})
|
||||
}
|
||||
|
||||
export function testEmbedding(data: EmbeddingTestRequest): Promise<EmbeddingTestResult> {
|
||||
return request({
|
||||
url: '/embedding/test',
|
||||
method: 'post',
|
||||
data
|
||||
})
|
||||
}
|
||||
|
||||
export function getSupportedFormats() {
|
||||
return request({
|
||||
url: '/embedding/formats',
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
import request from '@/utils/request'
|
||||
|
||||
export function listKnowledgeBases() {
|
||||
return request({
|
||||
url: '/admin/kb/knowledge-bases',
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
||||
export function listDocuments(params: any) {
|
||||
return request({
|
||||
url: '/admin/kb/documents',
|
||||
method: 'get',
|
||||
params
|
||||
})
|
||||
}
|
||||
|
||||
export function uploadDocument(data: FormData) {
|
||||
return request({
|
||||
url: '/admin/kb/documents',
|
||||
method: 'post',
|
||||
data
|
||||
})
|
||||
}
|
||||
|
||||
export function getIndexJob(jobId: string) {
|
||||
return request({
|
||||
url: `/admin/kb/index/jobs/${jobId}`,
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
||||
export function deleteDocument(docId: string) {
|
||||
return request({
|
||||
url: `/admin/kb/documents/${docId}`,
|
||||
method: 'delete'
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
import request from '@/utils/request'
|
||||
import type {
|
||||
LLMProviderInfo,
|
||||
LLMConfig,
|
||||
LLMConfigUpdate,
|
||||
LLMTestResult,
|
||||
LLMTestRequest,
|
||||
LLMProvidersResponse,
|
||||
LLMConfigUpdateResponse
|
||||
} from '@/types/llm'
|
||||
|
||||
export function getLLMProviders(): Promise<LLMProvidersResponse> {
|
||||
return request({
|
||||
url: '/admin/llm/providers',
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
||||
export function getLLMConfig(): Promise<LLMConfig> {
|
||||
return request({
|
||||
url: '/admin/llm/config',
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
||||
export function updateLLMConfig(data: LLMConfigUpdate): Promise<LLMConfigUpdateResponse> {
|
||||
return request({
|
||||
url: '/admin/llm/config',
|
||||
method: 'put',
|
||||
data
|
||||
})
|
||||
}
|
||||
|
||||
export function testLLM(data: LLMTestRequest): Promise<LLMTestResult> {
|
||||
return request({
|
||||
url: '/admin/llm/test',
|
||||
method: 'post',
|
||||
data
|
||||
})
|
||||
}
|
||||
|
||||
export type {
|
||||
LLMProviderInfo,
|
||||
LLMConfig,
|
||||
LLMConfigUpdate,
|
||||
LLMTestResult,
|
||||
LLMTestRequest,
|
||||
LLMProvidersResponse,
|
||||
LLMConfigUpdateResponse
|
||||
}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
import request from '@/utils/request'
|
||||
|
||||
export function listSessions(params: any) {
|
||||
return request({
|
||||
url: '/admin/sessions',
|
||||
method: 'get',
|
||||
params
|
||||
})
|
||||
}
|
||||
|
||||
export function getSessionDetail(sessionId: string) {
|
||||
return request({
|
||||
url: `/admin/sessions/${sessionId}`,
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
import request from '@/utils/request'
|
||||
import { useTenantStore } from '@/stores/tenant'
|
||||
|
||||
export interface AIResponse {
|
||||
content: string
|
||||
prompt_tokens?: number
|
||||
completion_tokens?: number
|
||||
total_tokens?: number
|
||||
latency_ms?: number
|
||||
model?: string
|
||||
}
|
||||
|
||||
export interface RetrievalResult {
|
||||
content: string
|
||||
score: number
|
||||
source: string
|
||||
metadata?: Record<string, any>
|
||||
}
|
||||
|
||||
export interface RagExperimentRequest {
|
||||
query: string
|
||||
kb_ids?: string[]
|
||||
top_k?: number
|
||||
score_threshold?: number
|
||||
llm_provider?: string
|
||||
generate_response?: boolean
|
||||
}
|
||||
|
||||
export interface RagExperimentResult {
|
||||
query: string
|
||||
retrieval_results?: RetrievalResult[]
|
||||
final_prompt?: string
|
||||
ai_response?: AIResponse
|
||||
total_latency_ms?: number
|
||||
}
|
||||
|
||||
export function runRagExperiment(data: RagExperimentRequest): Promise<RagExperimentResult> {
|
||||
return request({
|
||||
url: '/admin/rag/experiments/run',
|
||||
method: 'post',
|
||||
data
|
||||
})
|
||||
}
|
||||
|
||||
export function runRagExperimentStream(
|
||||
data: RagExperimentRequest,
|
||||
onMessage: (event: MessageEvent) => void,
|
||||
onError?: (error: Event) => void,
|
||||
onComplete?: () => void
|
||||
): EventSource {
|
||||
const baseUrl = import.meta.env.VITE_APP_BASE_API || '/api'
|
||||
const url = `${baseUrl}/admin/rag/experiments/stream`
|
||||
|
||||
const eventSource = new EventSource(url, {
|
||||
withCredentials: true
|
||||
})
|
||||
|
||||
eventSource.onmessage = onMessage
|
||||
eventSource.onerror = (error) => {
|
||||
eventSource.close()
|
||||
onError?.(error)
|
||||
}
|
||||
|
||||
return eventSource
|
||||
}
|
||||
|
||||
export function createSSEConnection(
|
||||
url: string,
|
||||
body: RagExperimentRequest,
|
||||
onMessage: (data: string) => void,
|
||||
onError?: (error: Error) => void,
|
||||
onComplete?: () => void
|
||||
): () => void {
|
||||
const baseUrl = import.meta.env.VITE_APP_BASE_API || '/api'
|
||||
const fullUrl = `${baseUrl}${url}`
|
||||
|
||||
const tenantStore = useTenantStore()
|
||||
|
||||
const controller = new AbortController()
|
||||
|
||||
fetch(fullUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'text/event-stream',
|
||||
'X-Tenant-Id': tenantStore.currentTenantId || '',
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
signal: controller.signal
|
||||
})
|
||||
.then(async (response) => {
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`)
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader()
|
||||
if (!reader) {
|
||||
throw new Error('No response body')
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder()
|
||||
let buffer = ''
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
|
||||
if (done) {
|
||||
onComplete?.()
|
||||
break
|
||||
}
|
||||
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || ''
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
const data = line.slice(6)
|
||||
if (data === '[DONE]') {
|
||||
onComplete?.()
|
||||
return
|
||||
}
|
||||
onMessage(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error.name !== 'AbortError') {
|
||||
onError?.(error)
|
||||
}
|
||||
})
|
||||
|
||||
return () => controller.abort()
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
import request from '@/utils/request'
|
||||
|
||||
export interface Tenant {
|
||||
id: string
|
||||
name: string
|
||||
displayName: string
|
||||
year: string
|
||||
createdAt: string
|
||||
}
|
||||
|
||||
export interface TenantListResponse {
|
||||
tenants: Tenant[]
|
||||
total: number
|
||||
}
|
||||
|
||||
export function getTenantList() {
|
||||
return request<TenantListResponse>({
|
||||
url: '/admin/tenants',
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
<template>
|
||||
<el-form :model="model" v-bind="$attrs" ref="formRef">
|
||||
<slot />
|
||||
<el-form-item v-if="showFooter">
|
||||
<el-button type="primary" @click="handleSubmit">提交</el-button>
|
||||
<el-button @click="handleReset">重置</el-button>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
import type { FormInstance } from 'element-plus'
|
||||
|
||||
const props = defineProps<{
|
||||
model: any
|
||||
showFooter?: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits(['submit', 'reset'])
|
||||
|
||||
const formRef = ref<FormInstance>()
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (!formRef.value) return
|
||||
await formRef.value.validate((valid) => {
|
||||
if (valid) {
|
||||
emit('submit', props.model)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const handleReset = () => {
|
||||
if (!formRef.value) return
|
||||
formRef.value.resetFields()
|
||||
emit('reset')
|
||||
}
|
||||
|
||||
defineExpose({
|
||||
validate: () => formRef.value?.validate(),
|
||||
resetFields: () => formRef.value?.resetFields()
|
||||
})
|
||||
</script>
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
<template>
|
||||
<el-table :data="data" v-bind="$attrs" style="width: 100%">
|
||||
<slot />
|
||||
</el-table>
|
||||
<div v-if="total > 0" class="pagination-container">
|
||||
<el-pagination
|
||||
v-model:current-page="currentPage"
|
||||
v-model:page-size="pageSize"
|
||||
:total="total"
|
||||
:page-sizes="[10, 20, 50, 100]"
|
||||
layout="total, sizes, prev, pager, next, jumper"
|
||||
@size-change="handleSizeChange"
|
||||
@current-change="handleCurrentChange"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
|
||||
const props = defineProps<{
|
||||
data: any[]
|
||||
total: number
|
||||
pageNum: number
|
||||
pageSize: number
|
||||
}>()
|
||||
|
||||
const emit = defineEmits(['update:pageNum', 'update:pageSize', 'pagination'])
|
||||
|
||||
const currentPage = computed({
|
||||
get: () => props.pageNum,
|
||||
set: (val) => emit('update:pageNum', val)
|
||||
})
|
||||
|
||||
const pageSize = computed({
|
||||
get: () => props.pageSize,
|
||||
set: (val) => emit('update:pageSize', val)
|
||||
})
|
||||
|
||||
const handleSizeChange = (val: number) => {
|
||||
emit('pagination', { page: currentPage.value, limit: val })
|
||||
}
|
||||
|
||||
const handleCurrentChange = (val: number) => {
|
||||
emit('pagination', { page: val, limit: pageSize.value })
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.pagination-container {
|
||||
padding: 32px 16px;
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,219 @@
|
|||
<template>
|
||||
<el-form
|
||||
ref="formRef"
|
||||
:model="formData"
|
||||
:rules="formRules"
|
||||
:label-width="labelWidth"
|
||||
v-bind="$attrs"
|
||||
>
|
||||
<el-form-item
|
||||
v-for="(field, key) in schemaProperties"
|
||||
:key="key"
|
||||
:label="field.title || key"
|
||||
:prop="key"
|
||||
>
|
||||
<template #label>
|
||||
<span>{{ field.title || key }}</span>
|
||||
<el-tooltip v-if="field.description" :content="field.description" placement="top">
|
||||
<el-icon class="ml-1 cursor-help"><QuestionFilled /></el-icon>
|
||||
</el-tooltip>
|
||||
</template>
|
||||
<el-input
|
||||
v-if="field.type === 'string'"
|
||||
v-model="formData[key]"
|
||||
:placeholder="`请输入${field.title || key}`"
|
||||
clearable
|
||||
:show-password="isPasswordField(key)"
|
||||
/>
|
||||
<el-input-number
|
||||
v-else-if="field.type === 'integer' || field.type === 'number'"
|
||||
v-model="formData[key]"
|
||||
:placeholder="`请输入${field.title || key}`"
|
||||
:min="field.minimum"
|
||||
:max="field.maximum"
|
||||
:step="field.type === 'number' ? 0.1 : 1"
|
||||
:precision="field.type === 'number' ? 2 : 0"
|
||||
controls-position="right"
|
||||
class="w-full"
|
||||
/>
|
||||
<el-switch
|
||||
v-else-if="field.type === 'boolean'"
|
||||
v-model="formData[key]"
|
||||
/>
|
||||
<el-select
|
||||
v-else-if="field.enum && field.enum.length > 0"
|
||||
v-model="formData[key]"
|
||||
:placeholder="`请选择${field.title || key}`"
|
||||
clearable
|
||||
class="w-full"
|
||||
>
|
||||
<el-option
|
||||
v-for="option in field.enum"
|
||||
:key="option"
|
||||
:label="option"
|
||||
:value="option"
|
||||
/>
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, watch, onMounted } from 'vue'
|
||||
import { QuestionFilled } from '@element-plus/icons-vue'
|
||||
import type { FormInstance, FormRules } from 'element-plus'
|
||||
|
||||
export interface SchemaProperty {
|
||||
type: string
|
||||
title?: string
|
||||
description?: string
|
||||
default?: any
|
||||
enum?: string[]
|
||||
minimum?: number
|
||||
maximum?: number
|
||||
required?: boolean
|
||||
}
|
||||
|
||||
export interface ConfigSchema {
|
||||
type?: string
|
||||
properties?: Record<string, SchemaProperty>
|
||||
required?: string[]
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
schema: ConfigSchema
|
||||
modelValue: Record<string, any>
|
||||
labelWidth?: string
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
(e: 'update:modelValue', value: Record<string, any>): void
|
||||
}>()
|
||||
|
||||
const formRef = ref<FormInstance>()
|
||||
const formData = ref<Record<string, any>>({})
|
||||
|
||||
const schemaProperties = computed(() => {
|
||||
return props.schema?.properties || {}
|
||||
})
|
||||
|
||||
const requiredFields = computed(() => {
|
||||
const required = props.schema?.required || []
|
||||
const propsRequired = Object.entries(schemaProperties.value)
|
||||
.filter(([, field]) => field.required)
|
||||
.map(([key]) => key)
|
||||
return [...new Set([...required, ...propsRequired])]
|
||||
})
|
||||
|
||||
const formRules = computed<FormRules>(() => {
|
||||
const rules: FormRules = {}
|
||||
Object.entries(schemaProperties.value).forEach(([key, field]) => {
|
||||
const fieldRules: any[] = []
|
||||
if (requiredFields.value.includes(key)) {
|
||||
fieldRules.push({
|
||||
required: true,
|
||||
message: `${field.title || key}不能为空`,
|
||||
trigger: ['blur', 'change']
|
||||
})
|
||||
}
|
||||
if (field.type === 'string' && field.minimum !== undefined) {
|
||||
fieldRules.push({
|
||||
min: field.minimum,
|
||||
message: `${field.title || key}长度不能小于${field.minimum}`,
|
||||
trigger: ['blur']
|
||||
})
|
||||
}
|
||||
if (field.type === 'string' && field.maximum !== undefined) {
|
||||
fieldRules.push({
|
||||
max: field.maximum,
|
||||
message: `${field.title || key}长度不能大于${field.maximum}`,
|
||||
trigger: ['blur']
|
||||
})
|
||||
}
|
||||
if (rules[key]) {
|
||||
rules[key] = fieldRules
|
||||
} else if (fieldRules.length > 0) {
|
||||
rules[key] = fieldRules
|
||||
}
|
||||
})
|
||||
return rules
|
||||
})
|
||||
|
||||
const isPasswordField = (key: string): boolean => {
|
||||
const lowerKey = key.toLowerCase()
|
||||
return lowerKey.includes('password') || lowerKey.includes('secret') || lowerKey.includes('key') || lowerKey.includes('token')
|
||||
}
|
||||
|
||||
const initFormData = () => {
|
||||
const data: Record<string, any> = {}
|
||||
Object.entries(schemaProperties.value).forEach(([key, field]) => {
|
||||
if (props.modelValue && props.modelValue[key] !== undefined) {
|
||||
data[key] = props.modelValue[key]
|
||||
} else if (field.default !== undefined) {
|
||||
data[key] = field.default
|
||||
} else {
|
||||
switch (field.type) {
|
||||
case 'string':
|
||||
data[key] = ''
|
||||
break
|
||||
case 'integer':
|
||||
case 'number':
|
||||
data[key] = field.minimum ?? 0
|
||||
break
|
||||
case 'boolean':
|
||||
data[key] = false
|
||||
break
|
||||
default:
|
||||
data[key] = null
|
||||
}
|
||||
}
|
||||
})
|
||||
formData.value = data
|
||||
}
|
||||
|
||||
watch(
|
||||
() => props.modelValue,
|
||||
() => {
|
||||
initFormData()
|
||||
},
|
||||
{ deep: true }
|
||||
)
|
||||
|
||||
watch(
|
||||
() => props.schema,
|
||||
() => {
|
||||
initFormData()
|
||||
},
|
||||
{ deep: true }
|
||||
)
|
||||
|
||||
watch(
|
||||
formData,
|
||||
(val) => {
|
||||
emit('update:modelValue', val)
|
||||
},
|
||||
{ deep: true }
|
||||
)
|
||||
|
||||
onMounted(() => {
|
||||
initFormData()
|
||||
})
|
||||
|
||||
defineExpose({
|
||||
validate: () => formRef.value?.validate(),
|
||||
resetFields: () => formRef.value?.resetFields(),
|
||||
clearValidate: () => formRef.value?.clearValidate()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.w-full {
|
||||
width: 100%;
|
||||
}
|
||||
.ml-1 {
|
||||
margin-left: 4px;
|
||||
}
|
||||
.cursor-help {
|
||||
cursor: help;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
<template>
|
||||
<el-select
|
||||
:model-value="modelValue"
|
||||
:loading="loading"
|
||||
:placeholder="placeholder"
|
||||
:disabled="disabled"
|
||||
:clearable="clearable"
|
||||
:teleported="true"
|
||||
:popper-options="{ modifiers: [{ name: 'flip', enabled: true }, { name: 'preventOverflow', enabled: true }] }"
|
||||
@update:model-value="handleChange"
|
||||
>
|
||||
<el-option
|
||||
v-for="provider in providers"
|
||||
:key="provider.name"
|
||||
:label="provider.display_name"
|
||||
:value="provider.name"
|
||||
>
|
||||
<div class="provider-option">
|
||||
<span class="provider-name">{{ provider.display_name }}</span>
|
||||
<span v-if="provider.description" class="provider-desc">{{ provider.description }}</span>
|
||||
</div>
|
||||
</el-option>
|
||||
</el-select>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
export interface ProviderInfo {
|
||||
name: string
|
||||
display_name: string
|
||||
description?: string
|
||||
config_schema: Record<string, any>
|
||||
}
|
||||
|
||||
const props = withDefaults(
|
||||
defineProps<{
|
||||
modelValue?: string
|
||||
providers: ProviderInfo[]
|
||||
loading?: boolean
|
||||
disabled?: boolean
|
||||
clearable?: boolean
|
||||
placeholder?: string
|
||||
}>(),
|
||||
{
|
||||
modelValue: '',
|
||||
loading: false,
|
||||
disabled: false,
|
||||
clearable: false,
|
||||
placeholder: '请选择提供者'
|
||||
}
|
||||
)
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:modelValue': [value: string]
|
||||
change: [provider: ProviderInfo | undefined]
|
||||
}>()
|
||||
|
||||
const handleChange = (value: string) => {
|
||||
emit('update:modelValue', value)
|
||||
const selectedProvider = props.providers.find((p) => p.name === value)
|
||||
emit('change', selectedProvider)
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.provider-option {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
line-height: 1.5;
|
||||
padding: 4px 0;
|
||||
}
|
||||
|
||||
.provider-name {
|
||||
font-weight: 500;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.provider-desc {
|
||||
font-size: 12px;
|
||||
color: var(--text-secondary);
|
||||
margin-top: 2px;
|
||||
line-height: 1.4;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,523 @@
|
|||
<template>
|
||||
<el-card shadow="hover" class="test-panel">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<div class="header-left">
|
||||
<div class="icon-wrapper">
|
||||
<el-icon><Connection /></el-icon>
|
||||
</div>
|
||||
<span class="header-title">{{ title }}</span>
|
||||
</div>
|
||||
<el-tag v-if="testResult" :type="testResult.success ? 'success' : 'danger'" size="small" effect="dark">
|
||||
{{ testResult.success ? '连接成功' : '连接失败' }}
|
||||
</el-tag>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div class="test-content">
|
||||
<div class="test-form-section">
|
||||
<div class="section-label">
|
||||
<el-icon><Edit /></el-icon>
|
||||
<span>{{ inputLabel }}</span>
|
||||
</div>
|
||||
<el-input
|
||||
v-model="testInputValue"
|
||||
type="textarea"
|
||||
:rows="3"
|
||||
:placeholder="inputPlaceholder"
|
||||
clearable
|
||||
class="test-textarea"
|
||||
/>
|
||||
<el-button
|
||||
type="primary"
|
||||
size="large"
|
||||
:loading="loading"
|
||||
:disabled="!canTest"
|
||||
class="test-button"
|
||||
@click="handleTest"
|
||||
>
|
||||
<el-icon v-if="!loading"><Connection /></el-icon>
|
||||
{{ loading ? '测试中...' : '测试连接' }}
|
||||
</el-button>
|
||||
</div>
|
||||
|
||||
<transition name="result-fade">
|
||||
<div v-if="testResult" class="test-result">
|
||||
<el-divider />
|
||||
|
||||
<div v-if="testResult.success" class="success-result">
|
||||
<div class="result-header">
|
||||
<div class="success-icon">
|
||||
<el-icon><CircleCheck /></el-icon>
|
||||
</div>
|
||||
<span class="result-title">{{ testResult.message || '连接成功' }}</span>
|
||||
</div>
|
||||
<div class="success-details">
|
||||
<div v-if="testResult.dimension !== undefined" class="detail-card">
|
||||
<div class="detail-icon">
|
||||
<el-icon><Grid /></el-icon>
|
||||
</div>
|
||||
<div class="detail-info">
|
||||
<span class="detail-label">向量维度</span>
|
||||
<span class="detail-value">{{ testResult.dimension }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="testResult.response" class="detail-card response-card">
|
||||
<div class="detail-icon">
|
||||
<el-icon><ChatDotRound /></el-icon>
|
||||
</div>
|
||||
<div class="detail-info">
|
||||
<span class="detail-label">模型响应</span>
|
||||
<span class="detail-value response-text">{{ testResult.response }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="testResult.latency_ms" class="detail-card">
|
||||
<div class="detail-icon">
|
||||
<el-icon><Timer /></el-icon>
|
||||
</div>
|
||||
<div class="detail-info">
|
||||
<span class="detail-label">响应延迟</span>
|
||||
<span class="detail-value">{{ testResult.latency_ms.toFixed(2) }} ms</span>
|
||||
</div>
|
||||
</div>
|
||||
<template v-if="showTokenStats && testResult.total_tokens !== undefined">
|
||||
<div class="detail-card">
|
||||
<div class="detail-icon">
|
||||
<el-icon><Document /></el-icon>
|
||||
</div>
|
||||
<div class="detail-info">
|
||||
<span class="detail-label">输入 Token</span>
|
||||
<span class="detail-value">{{ testResult.prompt_tokens || 0 }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="detail-card">
|
||||
<div class="detail-icon">
|
||||
<el-icon><EditPen /></el-icon>
|
||||
</div>
|
||||
<div class="detail-info">
|
||||
<span class="detail-label">输出 Token</span>
|
||||
<span class="detail-value">{{ testResult.completion_tokens || 0 }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="detail-card">
|
||||
<div class="detail-icon">
|
||||
<el-icon><DataAnalysis /></el-icon>
|
||||
</div>
|
||||
<div class="detail-info">
|
||||
<span class="detail-label">总 Token</span>
|
||||
<span class="detail-value">{{ testResult.total_tokens }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-else class="error-result">
|
||||
<div class="result-header">
|
||||
<div class="error-icon">
|
||||
<el-icon><CircleClose /></el-icon>
|
||||
</div>
|
||||
<span class="result-title error">连接失败</span>
|
||||
</div>
|
||||
<div class="error-message-box">
|
||||
<p class="error-text">{{ testResult.error || '未知错误' }}</p>
|
||||
</div>
|
||||
<div class="troubleshooting">
|
||||
<div class="troubleshoot-header">
|
||||
<el-icon><Warning /></el-icon>
|
||||
<span>排查建议</span>
|
||||
</div>
|
||||
<ul class="troubleshoot-list">
|
||||
<li v-for="(tip, index) in troubleshootingTips" :key="index">
|
||||
<el-icon class="list-icon"><Right /></el-icon>
|
||||
{{ tip }}
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</transition>
|
||||
</div>
|
||||
</el-card>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import { Connection, Edit, CircleCheck, CircleClose, Timer, Grid, Warning, Right, ChatDotRound, Document, EditPen, DataAnalysis } from '@element-plus/icons-vue'
|
||||
|
||||
export interface TestResult {
|
||||
success: boolean
|
||||
dimension?: number
|
||||
latency_ms?: number
|
||||
message?: string
|
||||
error?: string
|
||||
response?: string
|
||||
prompt_tokens?: number
|
||||
completion_tokens?: number
|
||||
total_tokens?: number
|
||||
}
|
||||
|
||||
const props = withDefaults(
|
||||
defineProps<{
|
||||
testFn: (input?: string) => Promise<TestResult>
|
||||
canTest?: boolean
|
||||
title?: string
|
||||
inputLabel?: string
|
||||
inputPlaceholder?: string
|
||||
showTokenStats?: boolean
|
||||
}>(),
|
||||
{
|
||||
canTest: true,
|
||||
title: '连接测试',
|
||||
inputLabel: '测试输入',
|
||||
inputPlaceholder: '请输入测试内容(可选,默认使用系统预设内容)',
|
||||
showTokenStats: false
|
||||
}
|
||||
)
|
||||
|
||||
const loading = ref(false)
|
||||
const testResult = ref<TestResult | null>(null)
|
||||
const testInputValue = ref('')
|
||||
|
||||
const troubleshootingTips = computed(() => {
|
||||
const tips: string[] = []
|
||||
const error = testResult.value?.error?.toLowerCase() || ''
|
||||
|
||||
if (error.includes('timeout') || error.includes('超时')) {
|
||||
tips.push('检查网络连接是否正常')
|
||||
tips.push('确认服务地址是否正确且可访问')
|
||||
tips.push('尝试增加请求超时时间')
|
||||
} else if (error.includes('auth') || error.includes('unauthorized') || error.includes('认证') || error.includes('api key')) {
|
||||
tips.push('检查 API Key 是否正确')
|
||||
tips.push('确认 API Key 是否已过期或被禁用')
|
||||
tips.push('验证 API Key 是否具有足够的权限')
|
||||
} else if (error.includes('connection') || error.includes('连接') || error.includes('refused')) {
|
||||
tips.push('确认服务地址(host/port)配置正确')
|
||||
tips.push('检查目标服务是否正在运行')
|
||||
tips.push('验证防火墙是否允许访问')
|
||||
} else if (error.includes('model') || error.includes('模型')) {
|
||||
tips.push('确认模型名称是否正确')
|
||||
tips.push('检查模型是否已部署或可用')
|
||||
tips.push('验证模型配置参数是否符合要求')
|
||||
} else {
|
||||
tips.push('检查所有配置参数是否正确')
|
||||
tips.push('确认服务是否正常运行')
|
||||
tips.push('查看服务端日志获取详细错误信息')
|
||||
}
|
||||
|
||||
return tips
|
||||
})
|
||||
|
||||
const handleTest = async () => {
|
||||
loading.value = true
|
||||
testResult.value = null
|
||||
|
||||
try {
|
||||
const input = testInputValue.value?.trim() || undefined
|
||||
const result = await props.testFn(input)
|
||||
testResult.value = result
|
||||
} catch (error: any) {
|
||||
testResult.value = {
|
||||
success: false,
|
||||
error: error?.message || '请求失败,请检查网络连接'
|
||||
}
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.test-panel {
|
||||
border-radius: 16px;
|
||||
border: none;
|
||||
background: rgba(255, 255, 255, 0.98);
|
||||
backdrop-filter: blur(10px);
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.test-panel:hover {
|
||||
box-shadow: 0 12px 48px rgba(0, 0, 0, 0.15);
|
||||
transform: translateY(-4px);
|
||||
}
|
||||
|
||||
.card-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.header-left {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.icon-wrapper {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
border-radius: 10px;
|
||||
color: #ffffff;
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
.header-title {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: #303133;
|
||||
}
|
||||
|
||||
.test-content {
|
||||
padding: 8px 0;
|
||||
}
|
||||
|
||||
.test-form-section {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
.section-label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
color: #606266;
|
||||
}
|
||||
|
||||
.section-label .el-icon {
|
||||
color: #667eea;
|
||||
}
|
||||
|
||||
.test-textarea {
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.test-textarea :deep(.el-textarea__inner) {
|
||||
border-radius: 10px;
|
||||
border: 1px solid #dcdfe6;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.test-textarea :deep(.el-textarea__inner:focus) {
|
||||
border-color: #667eea;
|
||||
box-shadow: 0 0 0 2px rgba(102, 126, 234, 0.2);
|
||||
}
|
||||
|
||||
.test-button {
|
||||
align-self: flex-start;
|
||||
border-radius: 10px;
|
||||
padding: 12px 24px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
border: none;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.test-button:hover:not(:disabled) {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
|
||||
.test-button:disabled {
|
||||
opacity: 0.6;
|
||||
}
|
||||
|
||||
.test-result {
|
||||
animation: fadeIn 0.4s ease-out;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(10px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.result-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.success-icon,
|
||||
.error-icon {
|
||||
width: 36px;
|
||||
height: 36px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
border-radius: 50%;
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
.success-icon {
|
||||
background: linear-gradient(135deg, #67c23a 0%, #85ce61 100%);
|
||||
color: #ffffff;
|
||||
}
|
||||
|
||||
.error-icon {
|
||||
background: linear-gradient(135deg, #f56c6c 0%, #f89898 100%);
|
||||
color: #ffffff;
|
||||
}
|
||||
|
||||
.result-title {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: #67c23a;
|
||||
}
|
||||
|
||||
.result-title.error {
|
||||
color: #f56c6c;
|
||||
}
|
||||
|
||||
.success-details {
|
||||
display: flex;
|
||||
gap: 16px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.detail-card {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
padding: 14px 18px;
|
||||
background: linear-gradient(135deg, #f0f9eb 0%, #e1f3d8 100%);
|
||||
border-radius: 12px;
|
||||
border: 1px solid #e1f3d8;
|
||||
}
|
||||
|
||||
.response-card {
|
||||
flex: 1;
|
||||
min-width: 200px;
|
||||
}
|
||||
|
||||
.detail-icon {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background: linear-gradient(135deg, #67c23a 0%, #85ce61 100%);
|
||||
border-radius: 10px;
|
||||
color: #ffffff;
|
||||
font-size: 18px;
|
||||
}
|
||||
|
||||
.detail-info {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.detail-label {
|
||||
font-size: 12px;
|
||||
color: #909399;
|
||||
}
|
||||
|
||||
.detail-value {
|
||||
font-size: 18px;
|
||||
font-weight: 700;
|
||||
color: #303133;
|
||||
}
|
||||
|
||||
.response-text {
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
max-width: 300px;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.error-result {
|
||||
animation: shake 0.5s ease-out;
|
||||
}
|
||||
|
||||
@keyframes shake {
|
||||
0%, 100% { transform: translateX(0); }
|
||||
25% { transform: translateX(-5px); }
|
||||
75% { transform: translateX(5px); }
|
||||
}
|
||||
|
||||
.error-message-box {
|
||||
padding: 14px 16px;
|
||||
background: linear-gradient(135deg, #fef0f0 0%, #fde2e2 100%);
|
||||
border-radius: 10px;
|
||||
border-left: 3px solid #f56c6c;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.error-text {
|
||||
margin: 0;
|
||||
color: #f56c6c;
|
||||
font-size: 14px;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.troubleshooting {
|
||||
padding: 16px;
|
||||
background: linear-gradient(135deg, #fdf6ec 0%, #faecd8 100%);
|
||||
border-radius: 12px;
|
||||
border: 1px solid #faecd8;
|
||||
}
|
||||
|
||||
.troubleshoot-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
margin-bottom: 12px;
|
||||
font-weight: 600;
|
||||
color: #e6a23c;
|
||||
}
|
||||
|
||||
.troubleshoot-list {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
list-style: none;
|
||||
}
|
||||
|
||||
.troubleshoot-list li {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 8px;
|
||||
margin-bottom: 8px;
|
||||
color: #606266;
|
||||
font-size: 13px;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.list-icon {
|
||||
margin-top: 4px;
|
||||
color: #e6a23c;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.result-fade-enter-active {
|
||||
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
}
|
||||
|
||||
.result-fade-leave-active {
|
||||
transition: all 0.3s cubic-bezier(1, 0.5, 0.8, 1);
|
||||
}
|
||||
|
||||
.result-fade-enter-from {
|
||||
opacity: 0;
|
||||
transform: translateY(20px);
|
||||
}
|
||||
|
||||
.result-fade-leave-to {
|
||||
opacity: 0;
|
||||
transform: translateY(-10px);
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,219 @@
|
|||
<template>
|
||||
<el-form
|
||||
ref="formRef"
|
||||
:model="formData"
|
||||
:rules="formRules"
|
||||
:label-width="labelWidth"
|
||||
v-bind="$attrs"
|
||||
>
|
||||
<el-form-item
|
||||
v-for="(field, key) in schemaProperties"
|
||||
:key="key"
|
||||
:label="field.title || key"
|
||||
:prop="key"
|
||||
>
|
||||
<template #label>
|
||||
<span>{{ field.title || key }}</span>
|
||||
<el-tooltip v-if="field.description" :content="field.description" placement="top">
|
||||
<el-icon class="ml-1 cursor-help"><QuestionFilled /></el-icon>
|
||||
</el-tooltip>
|
||||
</template>
|
||||
<el-input
|
||||
v-if="field.type === 'string'"
|
||||
v-model="formData[key]"
|
||||
:placeholder="`请输入${field.title || key}`"
|
||||
clearable
|
||||
:show-password="isPasswordField(key)"
|
||||
/>
|
||||
<el-input-number
|
||||
v-else-if="field.type === 'integer' || field.type === 'number'"
|
||||
v-model="formData[key]"
|
||||
:placeholder="`请输入${field.title || key}`"
|
||||
:min="field.minimum"
|
||||
:max="field.maximum"
|
||||
:step="field.type === 'number' ? 0.1 : 1"
|
||||
:precision="field.type === 'number' ? 2 : 0"
|
||||
controls-position="right"
|
||||
class="w-full"
|
||||
/>
|
||||
<el-switch
|
||||
v-else-if="field.type === 'boolean'"
|
||||
v-model="formData[key]"
|
||||
/>
|
||||
<el-select
|
||||
v-else-if="field.enum && field.enum.length > 0"
|
||||
v-model="formData[key]"
|
||||
:placeholder="`请选择${field.title || key}`"
|
||||
clearable
|
||||
class="w-full"
|
||||
>
|
||||
<el-option
|
||||
v-for="option in field.enum"
|
||||
:key="option"
|
||||
:label="option"
|
||||
:value="option"
|
||||
/>
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, watch, onMounted } from 'vue'
|
||||
import { QuestionFilled } from '@element-plus/icons-vue'
|
||||
import type { FormInstance, FormRules } from 'element-plus'
|
||||
|
||||
interface SchemaProperty {
|
||||
type: string
|
||||
title?: string
|
||||
description?: string
|
||||
default?: any
|
||||
enum?: string[]
|
||||
minimum?: number
|
||||
maximum?: number
|
||||
required?: boolean
|
||||
}
|
||||
|
||||
interface ConfigSchema {
|
||||
type?: string
|
||||
properties?: Record<string, SchemaProperty>
|
||||
required?: string[]
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
schema: ConfigSchema
|
||||
modelValue: Record<string, any>
|
||||
labelWidth?: string
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
(e: 'update:modelValue', value: Record<string, any>): void
|
||||
}>()
|
||||
|
||||
const formRef = ref<FormInstance>()
|
||||
const formData = ref<Record<string, any>>({})
|
||||
|
||||
const schemaProperties = computed(() => {
|
||||
return props.schema?.properties || {}
|
||||
})
|
||||
|
||||
const requiredFields = computed(() => {
|
||||
const required = props.schema?.required || []
|
||||
const propsRequired = Object.entries(schemaProperties.value)
|
||||
.filter(([, field]) => field.required)
|
||||
.map(([key]) => key)
|
||||
return [...new Set([...required, ...propsRequired])]
|
||||
})
|
||||
|
||||
const formRules = computed<FormRules>(() => {
|
||||
const rules: FormRules = {}
|
||||
Object.entries(schemaProperties.value).forEach(([key, field]) => {
|
||||
const fieldRules: any[] = []
|
||||
if (requiredFields.value.includes(key)) {
|
||||
fieldRules.push({
|
||||
required: true,
|
||||
message: `${field.title || key}不能为空`,
|
||||
trigger: ['blur', 'change']
|
||||
})
|
||||
}
|
||||
if (field.type === 'string' && field.minimum !== undefined) {
|
||||
fieldRules.push({
|
||||
min: field.minimum,
|
||||
message: `${field.title || key}长度不能小于${field.minimum}`,
|
||||
trigger: ['blur']
|
||||
})
|
||||
}
|
||||
if (field.type === 'string' && field.maximum !== undefined) {
|
||||
fieldRules.push({
|
||||
max: field.maximum,
|
||||
message: `${field.title || key}长度不能大于${field.maximum}`,
|
||||
trigger: ['blur']
|
||||
})
|
||||
}
|
||||
if (rules[key]) {
|
||||
rules[key] = fieldRules
|
||||
} else if (fieldRules.length > 0) {
|
||||
rules[key] = fieldRules
|
||||
}
|
||||
})
|
||||
return rules
|
||||
})
|
||||
|
||||
const isPasswordField = (key: string): boolean => {
|
||||
const lowerKey = key.toLowerCase()
|
||||
return lowerKey.includes('password') || lowerKey.includes('secret') || lowerKey.includes('key') || lowerKey.includes('token')
|
||||
}
|
||||
|
||||
const initFormData = () => {
|
||||
const data: Record<string, any> = {}
|
||||
Object.entries(schemaProperties.value).forEach(([key, field]) => {
|
||||
if (props.modelValue && props.modelValue[key] !== undefined) {
|
||||
data[key] = props.modelValue[key]
|
||||
} else if (field.default !== undefined) {
|
||||
data[key] = field.default
|
||||
} else {
|
||||
switch (field.type) {
|
||||
case 'string':
|
||||
data[key] = ''
|
||||
break
|
||||
case 'integer':
|
||||
case 'number':
|
||||
data[key] = field.minimum ?? 0
|
||||
break
|
||||
case 'boolean':
|
||||
data[key] = false
|
||||
break
|
||||
default:
|
||||
data[key] = null
|
||||
}
|
||||
}
|
||||
})
|
||||
formData.value = data
|
||||
}
|
||||
|
||||
watch(
|
||||
() => props.modelValue,
|
||||
() => {
|
||||
initFormData()
|
||||
},
|
||||
{ deep: true }
|
||||
)
|
||||
|
||||
watch(
|
||||
() => props.schema,
|
||||
() => {
|
||||
initFormData()
|
||||
},
|
||||
{ deep: true }
|
||||
)
|
||||
|
||||
watch(
|
||||
formData,
|
||||
(val) => {
|
||||
emit('update:modelValue', val)
|
||||
},
|
||||
{ deep: true }
|
||||
)
|
||||
|
||||
onMounted(() => {
|
||||
initFormData()
|
||||
})
|
||||
|
||||
defineExpose({
|
||||
validate: () => formRef.value?.validate(),
|
||||
resetFields: () => formRef.value?.resetFields(),
|
||||
clearValidate: () => formRef.value?.clearValidate()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.w-full {
|
||||
width: 100%;
|
||||
}
|
||||
.ml-1 {
|
||||
margin-left: 4px;
|
||||
}
|
||||
.cursor-help {
|
||||
cursor: help;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
<template>
|
||||
<el-select
|
||||
:model-value="modelValue"
|
||||
:loading="loading"
|
||||
:placeholder="placeholder"
|
||||
:disabled="disabled"
|
||||
:clearable="clearable"
|
||||
:teleported="true"
|
||||
:popper-options="{ modifiers: [{ name: 'flip', enabled: true }, { name: 'preventOverflow', enabled: true }] }"
|
||||
@update:model-value="handleChange"
|
||||
>
|
||||
<el-option
|
||||
v-for="provider in providers"
|
||||
:key="provider.name"
|
||||
:label="provider.display_name"
|
||||
:value="provider.name"
|
||||
>
|
||||
<div class="provider-option">
|
||||
<span class="provider-name">{{ provider.display_name }}</span>
|
||||
<span v-if="provider.description" class="provider-desc">{{ provider.description }}</span>
|
||||
</div>
|
||||
</el-option>
|
||||
</el-select>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import type { EmbeddingProviderInfo } from '@/types/embedding'
|
||||
|
||||
const props = withDefaults(
|
||||
defineProps<{
|
||||
modelValue?: string
|
||||
providers: EmbeddingProviderInfo[]
|
||||
loading?: boolean
|
||||
disabled?: boolean
|
||||
clearable?: boolean
|
||||
placeholder?: string
|
||||
}>(),
|
||||
{
|
||||
modelValue: '',
|
||||
loading: false,
|
||||
disabled: false,
|
||||
clearable: false,
|
||||
placeholder: '请选择嵌入模型提供者'
|
||||
}
|
||||
)
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:modelValue': [value: string]
|
||||
change: [provider: EmbeddingProviderInfo | undefined]
|
||||
}>()
|
||||
|
||||
const handleChange = (value: string) => {
|
||||
emit('update:modelValue', value)
|
||||
const selectedProvider = props.providers.find((p) => p.name === value)
|
||||
emit('change', selectedProvider)
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.provider-option {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
line-height: 1.5;
|
||||
padding: 4px 0;
|
||||
}
|
||||
|
||||
.provider-name {
|
||||
font-weight: 500;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.provider-desc {
|
||||
font-size: 12px;
|
||||
color: var(--text-secondary);
|
||||
margin-top: 2px;
|
||||
line-height: 1.4;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,460 @@
|
|||
<template>
|
||||
<el-card shadow="hover" class="test-panel">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<div class="header-left">
|
||||
<div class="icon-wrapper">
|
||||
<el-icon><Connection /></el-icon>
|
||||
</div>
|
||||
<span class="header-title">连接测试</span>
|
||||
</div>
|
||||
<el-tag v-if="testResult" :type="testResult.success ? 'success' : 'danger'" size="small" effect="dark">
|
||||
{{ testResult.success ? '连接成功' : '连接失败' }}
|
||||
</el-tag>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div class="test-content">
|
||||
<div class="test-form-section">
|
||||
<div class="section-label">
|
||||
<el-icon><Edit /></el-icon>
|
||||
<span>测试文本</span>
|
||||
</div>
|
||||
<el-input
|
||||
v-model="testForm.test_text"
|
||||
type="textarea"
|
||||
:rows="3"
|
||||
placeholder="请输入测试文本(可选,默认使用系统预设文本)"
|
||||
clearable
|
||||
class="test-textarea"
|
||||
/>
|
||||
<el-button
|
||||
type="primary"
|
||||
size="large"
|
||||
:loading="loading"
|
||||
:disabled="!config?.provider"
|
||||
class="test-button"
|
||||
@click="handleTest"
|
||||
>
|
||||
<el-icon v-if="!loading"><Connection /></el-icon>
|
||||
{{ loading ? '测试中...' : '测试连接' }}
|
||||
</el-button>
|
||||
</div>
|
||||
|
||||
<transition name="result-fade">
|
||||
<div v-if="testResult" class="test-result">
|
||||
<el-divider />
|
||||
|
||||
<div v-if="testResult.success" class="success-result">
|
||||
<div class="result-header">
|
||||
<div class="success-icon">
|
||||
<el-icon><CircleCheck /></el-icon>
|
||||
</div>
|
||||
<span class="result-title">{{ testResult.message || '连接成功' }}</span>
|
||||
</div>
|
||||
<div class="success-details">
|
||||
<div class="detail-card">
|
||||
<div class="detail-icon">
|
||||
<el-icon><Grid /></el-icon>
|
||||
</div>
|
||||
<div class="detail-info">
|
||||
<span class="detail-label">向量维度</span>
|
||||
<span class="detail-value">{{ testResult.dimension }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="testResult.latency_ms" class="detail-card">
|
||||
<div class="detail-icon">
|
||||
<el-icon><Timer /></el-icon>
|
||||
</div>
|
||||
<div class="detail-info">
|
||||
<span class="detail-label">响应延迟</span>
|
||||
<span class="detail-value">{{ testResult.latency_ms.toFixed(2) }} ms</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-else class="error-result">
|
||||
<div class="result-header">
|
||||
<div class="error-icon">
|
||||
<el-icon><CircleClose /></el-icon>
|
||||
</div>
|
||||
<span class="result-title error">连接失败</span>
|
||||
</div>
|
||||
<div class="error-message-box">
|
||||
<p class="error-text">{{ testResult.error || '未知错误' }}</p>
|
||||
</div>
|
||||
<div class="troubleshooting">
|
||||
<div class="troubleshoot-header">
|
||||
<el-icon><Warning /></el-icon>
|
||||
<span>排查建议</span>
|
||||
</div>
|
||||
<ul class="troubleshoot-list">
|
||||
<li v-for="(tip, index) in troubleshootingTips" :key="index">
|
||||
<el-icon class="list-icon"><Right /></el-icon>
|
||||
{{ tip }}
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</transition>
|
||||
</div>
|
||||
</el-card>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import { Connection, Edit, CircleCheck, CircleClose, Timer, Grid, Warning, Right } from '@element-plus/icons-vue'
|
||||
import { testEmbedding, type EmbeddingConfigUpdate, type EmbeddingTestResult } from '@/api/embedding'
|
||||
|
||||
const props = defineProps<{
|
||||
config: EmbeddingConfigUpdate | null
|
||||
}>()
|
||||
|
||||
const loading = ref(false)
|
||||
const testResult = ref<EmbeddingTestResult | null>(null)
|
||||
|
||||
const testForm = ref({
|
||||
test_text: ''
|
||||
})
|
||||
|
||||
const troubleshootingTips = computed(() => {
|
||||
const tips: string[] = []
|
||||
const error = testResult.value?.error?.toLowerCase() || ''
|
||||
|
||||
if (error.includes('timeout') || error.includes('超时')) {
|
||||
tips.push('检查网络连接是否正常')
|
||||
tips.push('确认服务地址是否正确且可访问')
|
||||
tips.push('尝试增加请求超时时间')
|
||||
} else if (error.includes('auth') || error.includes('unauthorized') || error.includes('认证') || error.includes('api key')) {
|
||||
tips.push('检查 API Key 是否正确')
|
||||
tips.push('确认 API Key 是否已过期或被禁用')
|
||||
tips.push('验证 API Key 是否具有足够的权限')
|
||||
} else if (error.includes('connection') || error.includes('连接') || error.includes('refused')) {
|
||||
tips.push('确认服务地址(host/port)配置正确')
|
||||
tips.push('检查目标服务是否正在运行')
|
||||
tips.push('验证防火墙是否允许访问')
|
||||
} else if (error.includes('model') || error.includes('模型')) {
|
||||
tips.push('确认模型名称是否正确')
|
||||
tips.push('检查模型是否已部署或可用')
|
||||
tips.push('验证模型配置参数是否符合要求')
|
||||
} else {
|
||||
tips.push('检查所有配置参数是否正确')
|
||||
tips.push('确认服务是否正常运行')
|
||||
tips.push('查看服务端日志获取详细错误信息')
|
||||
}
|
||||
|
||||
return tips
|
||||
})
|
||||
|
||||
const handleTest = async () => {
|
||||
if (!props.config?.provider) {
|
||||
return
|
||||
}
|
||||
|
||||
loading.value = true
|
||||
testResult.value = null
|
||||
|
||||
try {
|
||||
const requestData: any = {
|
||||
config: props.config
|
||||
}
|
||||
if (testForm.value.test_text?.trim()) {
|
||||
requestData.test_text = testForm.value.test_text.trim()
|
||||
}
|
||||
|
||||
const result = await testEmbedding(requestData)
|
||||
testResult.value = result
|
||||
} catch (error: any) {
|
||||
testResult.value = {
|
||||
success: false,
|
||||
dimension: 0,
|
||||
error: error?.message || '请求失败,请检查网络连接'
|
||||
}
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.test-panel {
|
||||
border-radius: 16px;
|
||||
border: none;
|
||||
background: rgba(255, 255, 255, 0.98);
|
||||
backdrop-filter: blur(10px);
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.test-panel:hover {
|
||||
box-shadow: 0 12px 48px rgba(0, 0, 0, 0.15);
|
||||
transform: translateY(-4px);
|
||||
}
|
||||
|
||||
.card-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.header-left {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.icon-wrapper {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
border-radius: 10px;
|
||||
color: #ffffff;
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
.header-title {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: #303133;
|
||||
}
|
||||
|
||||
.test-content {
|
||||
padding: 8px 0;
|
||||
}
|
||||
|
||||
.test-form-section {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
.section-label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
color: #606266;
|
||||
}
|
||||
|
||||
.section-label .el-icon {
|
||||
color: #667eea;
|
||||
}
|
||||
|
||||
.test-textarea {
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.test-textarea :deep(.el-textarea__inner) {
|
||||
border-radius: 10px;
|
||||
border: 1px solid #dcdfe6;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.test-textarea :deep(.el-textarea__inner:focus) {
|
||||
border-color: #667eea;
|
||||
box-shadow: 0 0 0 2px rgba(102, 126, 234, 0.2);
|
||||
}
|
||||
|
||||
.test-button {
|
||||
align-self: flex-start;
|
||||
border-radius: 10px;
|
||||
padding: 12px 24px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
border: none;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.test-button:hover:not(:disabled) {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
|
||||
.test-button:disabled {
|
||||
opacity: 0.6;
|
||||
}
|
||||
|
||||
.test-result {
|
||||
animation: fadeIn 0.4s ease-out;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(10px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.result-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.success-icon,
|
||||
.error-icon {
|
||||
width: 36px;
|
||||
height: 36px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
border-radius: 50%;
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
.success-icon {
|
||||
background: linear-gradient(135deg, #67c23a 0%, #85ce61 100%);
|
||||
color: #ffffff;
|
||||
}
|
||||
|
||||
.error-icon {
|
||||
background: linear-gradient(135deg, #f56c6c 0%, #f89898 100%);
|
||||
color: #ffffff;
|
||||
}
|
||||
|
||||
.result-title {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: #67c23a;
|
||||
}
|
||||
|
||||
.result-title.error {
|
||||
color: #f56c6c;
|
||||
}
|
||||
|
||||
.success-details {
|
||||
display: flex;
|
||||
gap: 16px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.detail-card {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
padding: 14px 18px;
|
||||
background: linear-gradient(135deg, #f0f9eb 0%, #e1f3d8 100%);
|
||||
border-radius: 12px;
|
||||
border: 1px solid #e1f3d8;
|
||||
}
|
||||
|
||||
.detail-icon {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background: linear-gradient(135deg, #67c23a 0%, #85ce61 100%);
|
||||
border-radius: 10px;
|
||||
color: #ffffff;
|
||||
font-size: 18px;
|
||||
}
|
||||
|
||||
.detail-info {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.detail-label {
|
||||
font-size: 12px;
|
||||
color: #909399;
|
||||
}
|
||||
|
||||
.detail-value {
|
||||
font-size: 18px;
|
||||
font-weight: 700;
|
||||
color: #303133;
|
||||
}
|
||||
|
||||
.error-result {
|
||||
animation: shake 0.5s ease-out;
|
||||
}
|
||||
|
||||
@keyframes shake {
|
||||
0%, 100% { transform: translateX(0); }
|
||||
25% { transform: translateX(-5px); }
|
||||
75% { transform: translateX(5px); }
|
||||
}
|
||||
|
||||
.error-message-box {
|
||||
padding: 14px 16px;
|
||||
background: linear-gradient(135deg, #fef0f0 0%, #fde2e2 100%);
|
||||
border-radius: 10px;
|
||||
border-left: 3px solid #f56c6c;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.error-text {
|
||||
margin: 0;
|
||||
color: #f56c6c;
|
||||
font-size: 14px;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.troubleshooting {
|
||||
padding: 16px;
|
||||
background: linear-gradient(135deg, #fdf6ec 0%, #faecd8 100%);
|
||||
border-radius: 12px;
|
||||
border: 1px solid #faecd8;
|
||||
}
|
||||
|
||||
.troubleshoot-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
margin-bottom: 12px;
|
||||
font-weight: 600;
|
||||
color: #e6a23c;
|
||||
}
|
||||
|
||||
.troubleshoot-list {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
list-style: none;
|
||||
}
|
||||
|
||||
.troubleshoot-list li {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 8px;
|
||||
margin-bottom: 8px;
|
||||
color: #606266;
|
||||
font-size: 13px;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.list-icon {
|
||||
margin-top: 4px;
|
||||
color: #e6a23c;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.result-fade-enter-active {
|
||||
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
}
|
||||
|
||||
.result-fade-leave-active {
|
||||
transition: all 0.3s cubic-bezier(1, 0.5, 0.8, 1);
|
||||
}
|
||||
|
||||
.result-fade-enter-from {
|
||||
opacity: 0;
|
||||
transform: translateY(20px);
|
||||
}
|
||||
|
||||
.result-fade-leave-to {
|
||||
opacity: 0;
|
||||
transform: translateY(-10px);
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,161 @@
|
|||
<template>
  <!-- Read-only list of document formats the embedding pipeline accepts. -->
  <div class="supported-formats">
    <div v-loading="loading" class="formats-content">
      <transition-group name="tag-fade" tag="div" class="formats-grid">
        <el-tooltip
          v-for="format in formats"
          :key="format.extension"
          :content="format.description"
          placement="top"
          :disabled="!format.description"
          effect="light"
        >
          <div class="format-item">
            <div class="format-icon">
              <span class="extension">{{ format.extension }}</span>
            </div>
            <div class="format-info">
              <span class="format-name">{{ format.name }}</span>
            </div>
          </div>
        </el-tooltip>
      </transition-group>
      <el-empty v-if="!loading && formats.length === 0" description="暂无支持的格式" :image-size="80">
        <template #image>
          <div class="empty-icon">
            <el-icon><Document /></el-icon>
          </div>
        </template>
      </el-empty>
    </div>
  </div>
</template>

<script setup lang="ts">
import { computed, onMounted } from 'vue'
import { Document } from '@element-plus/icons-vue'
import { useEmbeddingStore } from '@/stores/embedding'

const embeddingStore = useEmbeddingStore()

// Formats and their loading flag live in the shared embedding store so the
// list is fetched at most once per session across components.
const formats = computed(() => embeddingStore.formats)
const loading = computed(() => embeddingStore.formatsLoading)

onMounted(() => {
  if (formats.value.length === 0) {
    // loadFormats() logs and then RETHROWS on failure; swallow the rejection
    // here so a failed fetch does not become an unhandled promise rejection.
    // The empty state below already covers the "no formats" UI.
    embeddingStore.loadFormats().catch(() => {
      /* error already logged inside the store */
    })
  }
})
</script>

<style scoped>
.supported-formats {
  padding: 8px 0;
}

.formats-content {
  min-height: 60px;
}

.formats-grid {
  display: flex;
  flex-wrap: wrap;
  gap: 12px;
}

.format-item {
  display: flex;
  align-items: center;
  gap: 10px;
  padding: 10px 14px;
  background: linear-gradient(135deg, #f8f9fc 0%, #eef1f5 100%);
  border-radius: 10px;
  border: 1px solid #e4e7ed;
  cursor: default;
  transition: all 0.3s ease;
}

.format-item:hover {
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
  border-color: transparent;
  transform: translateY(-2px);
  box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
}

.format-item:hover .extension,
.format-item:hover .format-name {
  color: #ffffff;
}

.format-icon {
  width: 36px;
  height: 36px;
  display: flex;
  align-items: center;
  justify-content: center;
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
  border-radius: 8px;
}

.format-item:hover .format-icon {
  background: rgba(255, 255, 255, 0.2);
}

.extension {
  font-size: 11px;
  font-weight: 700;
  color: #ffffff;
  text-transform: uppercase;
  letter-spacing: 0.5px;
}

.format-info {
  display: flex;
  flex-direction: column;
}

.format-name {
  font-size: 13px;
  font-weight: 600;
  color: #303133;
  transition: color 0.3s ease;
}

.empty-icon {
  width: 80px;
  height: 80px;
  display: flex;
  align-items: center;
  justify-content: center;
  background: linear-gradient(135deg, #f5f7fa 0%, #e8ecf1 100%);
  border-radius: 50%;
  margin: 0 auto;
}

.empty-icon .el-icon {
  font-size: 40px;
  color: #c0c4cc;
}

.tag-fade-enter-active {
  transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
}

.tag-fade-leave-active {
  transition: all 0.3s cubic-bezier(1, 0.5, 0.8, 1);
}

.tag-fade-enter-from {
  opacity: 0;
  transform: scale(0.8);
}

.tag-fade-leave-to {
  opacity: 0;
  transform: scale(0.8);
}

.tag-fade-move {
  transition: transform 0.3s ease;
}
</style>
|
||||
|
|
@ -0,0 +1,351 @@
|
|||
<template>
  <!-- Displays a completed (non-streaming) AI response plus token/latency stats. -->
  <el-card shadow="hover" class="ai-response-viewer">
    <template #header>
      <div class="card-header">
        <div class="header-left">
          <div class="icon-wrapper">
            <el-icon><ChatDotRound /></el-icon>
          </div>
          <span class="header-title">AI 回复</span>
        </div>
        <el-tag v-if="response" type="success" size="small" effect="dark">
          已生成
        </el-tag>
      </div>
    </template>

    <div class="response-content">
      <div v-if="!response" class="placeholder-text">
        <el-icon class="placeholder-icon"><Document /></el-icon>
        <p>运行实验后将在此显示 AI 回复</p>
      </div>

      <template v-else>
        <div class="markdown-content" v-html="renderedContent"></div>

        <el-divider />

        <div class="stats-section">
          <div class="section-label">
            <el-icon><DataAnalysis /></el-icon>
            <span>统计信息</span>
          </div>

          <div class="stats-grid">
            <div v-if="response.model" class="stat-card">
              <div class="stat-icon model-icon">
                <el-icon><Cpu /></el-icon>
              </div>
              <div class="stat-info">
                <span class="stat-label">模型</span>
                <span class="stat-value">{{ response.model }}</span>
              </div>
            </div>

            <div v-if="response.latency_ms" class="stat-card">
              <div class="stat-icon latency-icon">
                <el-icon><Timer /></el-icon>
              </div>
              <div class="stat-info">
                <span class="stat-label">响应耗时</span>
                <span class="stat-value">{{ response.latency_ms.toFixed(2) }} ms</span>
              </div>
            </div>

            <div v-if="response.prompt_tokens" class="stat-card">
              <div class="stat-icon prompt-icon">
                <el-icon><EditPen /></el-icon>
              </div>
              <div class="stat-info">
                <span class="stat-label">Prompt Tokens</span>
                <span class="stat-value">{{ response.prompt_tokens }}</span>
              </div>
            </div>

            <div v-if="response.completion_tokens" class="stat-card">
              <div class="stat-icon completion-icon">
                <el-icon><DocumentCopy /></el-icon>
              </div>
              <div class="stat-info">
                <span class="stat-label">Completion Tokens</span>
                <span class="stat-value">{{ response.completion_tokens }}</span>
              </div>
            </div>

            <div v-if="response.total_tokens" class="stat-card highlight">
              <div class="stat-icon total-icon">
                <el-icon><Coin /></el-icon>
              </div>
              <div class="stat-info">
                <span class="stat-label">Total Tokens</span>
                <span class="stat-value">{{ response.total_tokens }}</span>
              </div>
            </div>
          </div>
        </div>
      </template>
    </div>
  </el-card>
</template>

<script setup lang="ts">
import { computed } from 'vue'
import { ChatDotRound, Document, DataAnalysis, Timer, EditPen, DocumentCopy, Coin, Cpu } from '@element-plus/icons-vue'
import type { AIResponse } from '@/api/rag'

const props = defineProps<{
  response: AIResponse | null
}>()

/**
 * Escape raw HTML special characters. The rendered string is injected through
 * v-html, so model output MUST be escaped first — otherwise a response
 * containing e.g. <img onerror=...> would execute in the admin UI (XSS).
 */
const escapeHtml = (text: string): string =>
  text
    .replace(/&/g, '&amp;')
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;')

const renderedContent = computed(() => {
  if (!props.response?.content) return ''
  return renderMarkdown(props.response.content)
})

/**
 * Minimal markdown-to-HTML renderer: fenced/inline code, #/##/### headings,
 * bold, italic, list items, paragraph breaks. Input is HTML-escaped before
 * any transform so only the tags generated here reach v-html.
 */
const renderMarkdown = (text: string): string => {
  let html = escapeHtml(text)

  html = html.replace(/```(\w*)\n([\s\S]*?)```/g, '<pre><code class="language-$1">$2</code></pre>')
  html = html.replace(/`([^`]+)`/g, '<code class="inline-code">$1</code>')
  html = html.replace(/^### (.+)$/gm, '<h3>$1</h3>')
  html = html.replace(/^## (.+)$/gm, '<h2>$1</h2>')
  html = html.replace(/^# (.+)$/gm, '<h1>$1</h1>')
  html = html.replace(/\*\*([^*]+)\*\*/g, '<strong>$1</strong>')
  html = html.replace(/\*([^*]+)\*/g, '<em>$1</em>')
  html = html.replace(/^\- (.+)$/gm, '<li>$1</li>')
  html = html.replace(/^\d+\. (.+)$/gm, '<li>$1</li>')
  html = html.replace(/\n\n/g, '</p><p>')
  html = html.replace(/\n/g, '<br>')

  return `<p>${html}</p>`
}
</script>

<style scoped>
.ai-response-viewer {
  border-radius: 16px;
  border: none;
  background: rgba(255, 255, 255, 0.98);
  backdrop-filter: blur(10px);
  box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
  transition: all 0.3s ease;
}

.ai-response-viewer:hover {
  box-shadow: 0 12px 48px rgba(0, 0, 0, 0.15);
}

.card-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
  padding: 0;
}

.header-left {
  display: flex;
  align-items: center;
  gap: 12px;
}

.icon-wrapper {
  width: 40px;
  height: 40px;
  display: flex;
  align-items: center;
  justify-content: center;
  background: linear-gradient(135deg, #36d1dc 0%, #5b86e5 100%);
  border-radius: 10px;
  color: #ffffff;
  font-size: 20px;
}

.header-title {
  font-size: 16px;
  font-weight: 600;
  color: #303133;
}

.placeholder-text {
  display: flex;
  flex-direction: column;
  align-items: center;
  justify-content: center;
  padding: 60px 20px;
  color: #909399;
}

.placeholder-icon {
  font-size: 48px;
  margin-bottom: 16px;
  opacity: 0.5;
}

.placeholder-text p {
  margin: 0;
  font-size: 14px;
}

.markdown-content {
  padding: 16px;
  background: #f8f9fa;
  border-radius: 12px;
  line-height: 1.8;
  color: #303133;
  max-height: 400px;
  overflow-y: auto;
}

.markdown-content :deep(h1) {
  font-size: 24px;
  font-weight: 700;
  margin: 16px 0 12px;
  color: #303133;
}

.markdown-content :deep(h2) {
  font-size: 20px;
  font-weight: 600;
  margin: 14px 0 10px;
  color: #303133;
}

.markdown-content :deep(h3) {
  font-size: 16px;
  font-weight: 600;
  margin: 12px 0 8px;
  color: #303133;
}

.markdown-content :deep(pre) {
  background: #1e1e1e;
  border-radius: 8px;
  padding: 16px;
  overflow-x: auto;
  margin: 12px 0;
}

.markdown-content :deep(code) {
  font-family: 'Consolas', 'Monaco', monospace;
  font-size: 13px;
}

.markdown-content :deep(pre code) {
  color: #d4d4d4;
}

.markdown-content :deep(.inline-code) {
  background: #e8e8e8;
  padding: 2px 6px;
  border-radius: 4px;
  font-size: 13px;
  color: #e83e8c;
}

.markdown-content :deep(strong) {
  font-weight: 600;
  color: #303133;
}

.markdown-content :deep(em) {
  font-style: italic;
  color: #606266;
}

.markdown-content :deep(li) {
  margin: 4px 0;
  padding-left: 8px;
}

.stats-section {
  margin-top: 8px;
}

.section-label {
  display: flex;
  align-items: center;
  gap: 8px;
  font-size: 14px;
  font-weight: 600;
  color: #606266;
  margin-bottom: 16px;
}

.section-label .el-icon {
  color: #5b86e5;
}

.stats-grid {
  display: grid;
  grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
  gap: 12px;
}

.stat-card {
  display: flex;
  align-items: center;
  gap: 12px;
  padding: 14px 16px;
  background: linear-gradient(135deg, #f5f7fa 0%, #e8ecf1 100%);
  border-radius: 12px;
  border: 1px solid #e4e7ed;
  transition: all 0.3s ease;
}

.stat-card:hover {
  transform: translateY(-2px);
  box-shadow: 0 4px 12px rgba(0, 0, 0, 0.08);
}

.stat-card.highlight {
  background: linear-gradient(135deg, #e8f4fd 0%, #d4e9f7 100%);
  border-color: #b8d9f0;
}

.stat-icon {
  width: 36px;
  height: 36px;
  display: flex;
  align-items: center;
  justify-content: center;
  border-radius: 8px;
  color: #ffffff;
  font-size: 16px;
}

.model-icon {
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}

.latency-icon {
  background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
}

.prompt-icon {
  background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
}

.completion-icon {
  background: linear-gradient(135deg, #43e97b 0%, #38f9d7 100%);
}

.total-icon {
  background: linear-gradient(135deg, #fa709a 0%, #fee140 100%);
}

.stat-info {
  display: flex;
  flex-direction: column;
}

.stat-label {
  font-size: 12px;
  color: #909399;
}

.stat-value {
  font-size: 16px;
  font-weight: 600;
  color: #303133;
}
</style>
|
||||
|
|
@ -0,0 +1,159 @@
|
|||
<template>
  <!-- Dropdown for picking an LLM provider; tags mark the configured default
       and the currently selected entry. -->
  <el-select
    :model-value="displayValue"
    :loading="loading"
    :placeholder="computedPlaceholder"
    :disabled="disabled"
    :clearable="clearable"
    :teleported="true"
    :popper-class="popperClass"
    :popper-options="{ modifiers: [{ name: 'flip', enabled: true }, { name: 'preventOverflow', enabled: true }, { name: 'computeStyles', options: { adaptive: false, gpuAcceleration: false } }] }"
    @update:model-value="handleChange"
    @clear="handleClear"
  >
    <el-option
      v-for="provider in providers"
      :key="provider.name"
      :label="provider.display_name"
      :value="provider.name"
    >
      <div class="provider-option">
        <div class="provider-info">
          <span class="provider-name">{{ provider.display_name }}</span>
          <span v-if="provider.description" class="provider-desc">{{ provider.description }}</span>
        </div>
        <el-tag v-if="provider.name === currentProvider" type="success" size="small" effect="plain" class="current-tag">
          当前配置
        </el-tag>
        <el-tag v-else-if="provider.name === modelValue" type="primary" size="small" effect="plain" class="selected-tag">
          已选择
        </el-tag>
      </div>
    </el-option>
  </el-select>
</template>

<script setup lang="ts">
import { computed } from 'vue'
import type { LLMProviderInfo } from '@/api/llm'

// Fixed popper class so the global styles below can target this dropdown only.
const popperClass = 'llm-selector-popper'

const props = withDefaults(
  defineProps<{
    modelValue?: string
    providers: LLMProviderInfo[]
    loading?: boolean
    disabled?: boolean
    clearable?: boolean
    placeholder?: string
    currentProvider?: string
  }>(),
  {
    modelValue: '',
    loading: false,
    disabled: false,
    clearable: false,
    placeholder: '请选择 LLM 提供者',
    currentProvider: ''
  }
)

const emit = defineEmits<{
  'update:modelValue': [value: string]
  change: [provider: LLMProviderInfo | undefined]
}>()

// Value shown by the select; empty string when nothing is chosen.
const displayValue = computed(() => props.modelValue || '')

// When nothing is selected but a default provider is configured, surface the
// default's display name in the placeholder instead of the generic prompt.
const computedPlaceholder = computed(() => {
  if (!props.modelValue && props.currentProvider) {
    const fallback = props.providers.find(p => p.name === props.currentProvider)
    return `默认: ${fallback?.display_name || props.currentProvider}`
  }
  return props.placeholder
})

// Propagate the new value and the full provider record (if found) upward.
const handleChange = (value: string) => {
  emit('update:modelValue', value)
  emit('change', props.providers.find(p => p.name === value))
}

// Clearing resets to the empty selection and reports "no provider".
const handleClear = () => {
  emit('update:modelValue', '')
  emit('change', undefined)
}
</script>

<style>
/* Global (unscoped): the dropdown is teleported to <body>, so scoped rules
   cannot reach it. */
.llm-selector-popper {
  min-width: 320px !important;
  z-index: 9999 !important;
}

.llm-selector-popper .el-select-dropdown__wrap {
  max-height: 400px;
}

.llm-selector-popper .el-select-dropdown__item {
  height: auto;
  padding: 8px 12px;
  line-height: 1.5;
}
</style>

<style scoped>
.provider-option {
  display: flex;
  justify-content: space-between;
  align-items: center;
  width: 100%;
  padding: 4px 0;
  gap: 12px;
}

.provider-info {
  display: flex;
  flex-direction: column;
  line-height: 1.5;
  flex: 1;
  min-width: 0;
  overflow: hidden;
}

.provider-name {
  font-weight: 500;
  color: var(--text-primary);
  white-space: nowrap;
  overflow: hidden;
  text-overflow: ellipsis;
}

.provider-desc {
  font-size: 12px;
  color: var(--text-secondary);
  margin-top: 2px;
  line-height: 1.4;
  white-space: nowrap;
  overflow: hidden;
  text-overflow: ellipsis;
}

.current-tag {
  flex-shrink: 0;
  margin-left: 8px;
  white-space: nowrap;
}

.selected-tag {
  flex-shrink: 0;
  margin-left: 8px;
  white-space: nowrap;
}
</style>
|
||||
|
|
@ -0,0 +1,299 @@
|
|||
<template>
  <!-- Live view of a streaming AI reply with generating/done/waiting status. -->
  <el-card shadow="hover" class="stream-output">
    <template #header>
      <div class="card-header">
        <div class="header-left">
          <div class="icon-wrapper" :class="{ streaming: isStreaming }">
            <el-icon><Promotion /></el-icon>
          </div>
          <span class="header-title">流式输出</span>
        </div>
        <div class="header-actions">
          <el-tag v-if="isStreaming" type="warning" size="small" effect="dark" class="pulse-tag">
            <el-icon class="is-loading"><Loading /></el-icon>
            生成中...
          </el-tag>
          <el-tag v-else-if="hasContent" type="success" size="small" effect="dark">
            已完成
          </el-tag>
          <el-tag v-else type="info" size="small" effect="plain">
            等待中
          </el-tag>
        </div>
      </div>
    </template>

    <div class="stream-content">
      <div v-if="!hasContent && !isStreaming" class="placeholder-text">
        <el-icon class="placeholder-icon"><ChatLineSquare /></el-icon>
        <p>启用流式输出后,AI 回复将实时显示</p>
      </div>

      <div v-else class="output-area">
        <div class="stream-text" v-html="renderedContent"></div>

        <div v-if="isStreaming" class="typing-indicator">
          <span class="dot"></span>
          <span class="dot"></span>
          <span class="dot"></span>
        </div>
      </div>

      <div v-if="error" class="error-section">
        <el-alert
          :title="error"
          type="error"
          :closable="false"
          show-icon
        />
      </div>
    </div>
  </el-card>
</template>

<script setup lang="ts">
import { computed } from 'vue'
import { Promotion, Loading, ChatLineSquare } from '@element-plus/icons-vue'

const props = defineProps<{
  content: string
  isStreaming: boolean
  error?: string | null
}>()

const hasContent = computed(() => props.content && props.content.length > 0)

const renderedContent = computed(() => {
  if (!props.content) return ''
  return renderMarkdown(props.content)
})

/**
 * Escape raw HTML special characters. The rendered string is injected through
 * v-html, so streamed model output MUST be escaped first — otherwise a chunk
 * containing e.g. <script> or <img onerror=...> would execute (XSS).
 */
const escapeHtml = (text: string): string =>
  text
    .replace(/&/g, '&amp;')
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;')

/**
 * Minimal markdown-to-HTML renderer: fenced/inline code, #/##/### headings,
 * bold, italic, list items, paragraph breaks. Input is HTML-escaped before
 * any transform so only the tags generated here reach v-html.
 */
const renderMarkdown = (text: string): string => {
  let html = escapeHtml(text)

  html = html.replace(/```(\w*)\n([\s\S]*?)```/g, '<pre><code class="language-$1">$2</code></pre>')
  html = html.replace(/`([^`]+)`/g, '<code class="inline-code">$1</code>')
  html = html.replace(/^### (.+)$/gm, '<h3>$1</h3>')
  html = html.replace(/^## (.+)$/gm, '<h2>$1</h2>')
  html = html.replace(/^# (.+)$/gm, '<h1>$1</h1>')
  html = html.replace(/\*\*([^*]+)\*\*/g, '<strong>$1</strong>')
  html = html.replace(/\*([^*]+)\*/g, '<em>$1</em>')
  html = html.replace(/^\- (.+)$/gm, '<li>$1</li>')
  html = html.replace(/^\d+\. (.+)$/gm, '<li>$1</li>')
  html = html.replace(/\n\n/g, '</p><p>')
  html = html.replace(/\n/g, '<br>')

  return `<p>${html}</p>`
}
</script>

<style scoped>
.stream-output {
  border-radius: 16px;
  border: none;
  background: rgba(255, 255, 255, 0.98);
  backdrop-filter: blur(10px);
  box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
  transition: all 0.3s ease;
}

.stream-output:hover {
  box-shadow: 0 12px 48px rgba(0, 0, 0, 0.15);
}

.card-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
  padding: 0;
}

.header-left {
  display: flex;
  align-items: center;
  gap: 12px;
}

.icon-wrapper {
  width: 40px;
  height: 40px;
  display: flex;
  align-items: center;
  justify-content: center;
  background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);
  border-radius: 10px;
  color: #ffffff;
  font-size: 20px;
  transition: all 0.3s ease;
}

.icon-wrapper.streaming {
  animation: pulse 1.5s ease-in-out infinite;
}

@keyframes pulse {
  0%, 100% {
    transform: scale(1);
    box-shadow: 0 0 0 0 rgba(17, 153, 142, 0.4);
  }
  50% {
    transform: scale(1.05);
    box-shadow: 0 0 0 10px rgba(17, 153, 142, 0);
  }
}

.header-title {
  font-size: 16px;
  font-weight: 600;
  color: #303133;
}

.pulse-tag {
  animation: pulse-tag 1.5s ease-in-out infinite;
}

@keyframes pulse-tag {
  0%, 100% {
    opacity: 1;
  }
  50% {
    opacity: 0.7;
  }
}

.placeholder-text {
  display: flex;
  flex-direction: column;
  align-items: center;
  justify-content: center;
  padding: 60px 20px;
  color: #909399;
}

.placeholder-icon {
  font-size: 48px;
  margin-bottom: 16px;
  opacity: 0.5;
}

.placeholder-text p {
  margin: 0;
  font-size: 14px;
}

.output-area {
  padding: 16px;
  background: #f8f9fa;
  border-radius: 12px;
  min-height: 200px;
  max-height: 400px;
  overflow-y: auto;
}

.stream-text {
  line-height: 1.8;
  color: #303133;
}

.stream-text :deep(h1) {
  font-size: 24px;
  font-weight: 700;
  margin: 16px 0 12px;
  color: #303133;
}

.stream-text :deep(h2) {
  font-size: 20px;
  font-weight: 600;
  margin: 14px 0 10px;
  color: #303133;
}

.stream-text :deep(h3) {
  font-size: 16px;
  font-weight: 600;
  margin: 12px 0 8px;
  color: #303133;
}

.stream-text :deep(pre) {
  background: #1e1e1e;
  border-radius: 8px;
  padding: 16px;
  overflow-x: auto;
  margin: 12px 0;
}

.stream-text :deep(code) {
  font-family: 'Consolas', 'Monaco', monospace;
  font-size: 13px;
}

.stream-text :deep(pre code) {
  color: #d4d4d4;
}

.stream-text :deep(.inline-code) {
  background: #e8e8e8;
  padding: 2px 6px;
  border-radius: 4px;
  font-size: 13px;
  color: #e83e8c;
}

.stream-text :deep(strong) {
  font-weight: 600;
  color: #303133;
}

.stream-text :deep(em) {
  font-style: italic;
  color: #606266;
}

.stream-text :deep(li) {
  margin: 4px 0;
  padding-left: 8px;
}

.typing-indicator {
  display: flex;
  gap: 4px;
  margin-top: 12px;
  padding: 8px 12px;
  background: rgba(17, 153, 142, 0.1);
  border-radius: 8px;
  width: fit-content;
}

.typing-indicator .dot {
  width: 8px;
  height: 8px;
  background: #11998e;
  border-radius: 50%;
  animation: typing 1.4s infinite ease-in-out both;
}

.typing-indicator .dot:nth-child(1) {
  animation-delay: -0.32s;
}

.typing-indicator .dot:nth-child(2) {
  animation-delay: -0.16s;
}

@keyframes typing {
  0%, 80%, 100% {
    transform: scale(0.6);
    opacity: 0.5;
  }
  40% {
    transform: scale(1);
    opacity: 1;
  }
}

.error-section {
  margin-top: 16px;
}
</style>
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
import { createApp } from 'vue'
|
||||
import { createPinia } from 'pinia'
|
||||
import ElementPlus from 'element-plus'
|
||||
import 'element-plus/dist/index.css'
|
||||
import './styles/main.css'
|
||||
import App from './App.vue'
|
||||
import router from './router'
|
||||
|
||||
const app = createApp(App)
|
||||
const pinia = createPinia()
|
||||
|
||||
app.use(pinia)
|
||||
app.use(router)
|
||||
app.use(ElementPlus)
|
||||
|
||||
app.mount('#app')
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
import { createRouter, createWebHistory, RouteRecordRaw } from 'vue-router'
|
||||
|
||||
const routes: Array<RouteRecordRaw> = [
|
||||
{
|
||||
path: '/',
|
||||
redirect: '/dashboard'
|
||||
},
|
||||
{
|
||||
path: '/dashboard',
|
||||
name: 'Dashboard',
|
||||
component: () => import('@/views/dashboard/index.vue'),
|
||||
meta: { title: '控制台' }
|
||||
},
|
||||
{
|
||||
path: '/kb',
|
||||
name: 'KBManagement',
|
||||
component: () => import('@/views/kb/index.vue'),
|
||||
meta: { title: '知识库管理' }
|
||||
},
|
||||
{
|
||||
path: '/rag-lab',
|
||||
name: 'RagLab',
|
||||
component: () => import('@/views/rag-lab/index.vue'),
|
||||
meta: { title: 'RAG 实验室' }
|
||||
},
|
||||
{
|
||||
path: '/monitoring',
|
||||
name: 'Monitoring',
|
||||
component: () => import('@/views/monitoring/index.vue'),
|
||||
meta: { title: '会话监控' }
|
||||
},
|
||||
{
|
||||
path: '/admin/embedding',
|
||||
name: 'EmbeddingConfig',
|
||||
component: () => import('@/views/admin/embedding/index.vue'),
|
||||
meta: { title: '嵌入模型配置' }
|
||||
},
|
||||
{
|
||||
path: '/admin/llm',
|
||||
name: 'LLMConfig',
|
||||
component: () => import('@/views/admin/llm/index.vue'),
|
||||
meta: { title: 'LLM 模型配置' }
|
||||
}
|
||||
]
|
||||
|
||||
const router = createRouter({
|
||||
history: createWebHistory(),
|
||||
routes
|
||||
})
|
||||
|
||||
export default router
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
import { defineStore } from 'pinia'
|
||||
import { ref, computed } from 'vue'
|
||||
import {
|
||||
getProviders,
|
||||
getConfig,
|
||||
saveConfig,
|
||||
testEmbedding,
|
||||
getSupportedFormats,
|
||||
type EmbeddingProviderInfo,
|
||||
type EmbeddingConfig,
|
||||
type EmbeddingConfigUpdate,
|
||||
type EmbeddingTestResult,
|
||||
type DocumentFormat
|
||||
} from '@/api/embedding'
|
||||
|
||||
export const useEmbeddingStore = defineStore('embedding', () => {
|
||||
const providers = ref<EmbeddingProviderInfo[]>([])
|
||||
const currentConfig = ref<EmbeddingConfig>({
|
||||
provider: '',
|
||||
config: {}
|
||||
})
|
||||
const formats = ref<DocumentFormat[]>([])
|
||||
const loading = ref(false)
|
||||
const providersLoading = ref(false)
|
||||
const formatsLoading = ref(false)
|
||||
const testResult = ref<EmbeddingTestResult | null>(null)
|
||||
const testLoading = ref(false)
|
||||
|
||||
const currentProvider = computed(() => {
|
||||
return providers.value.find(p => p.name === currentConfig.value.provider)
|
||||
})
|
||||
|
||||
const configSchema = computed(() => {
|
||||
return currentProvider.value?.config_schema || { properties: {} }
|
||||
})
|
||||
|
||||
const loadProviders = async () => {
|
||||
providersLoading.value = true
|
||||
try {
|
||||
const res: any = await getProviders()
|
||||
providers.value = res?.providers || res?.data?.providers || []
|
||||
} catch (error) {
|
||||
console.error('Failed to load providers:', error)
|
||||
throw error
|
||||
} finally {
|
||||
providersLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const loadConfig = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const res: any = await getConfig()
|
||||
const config = res?.data || res
|
||||
if (config) {
|
||||
currentConfig.value = {
|
||||
provider: config.provider || '',
|
||||
config: config.config || {},
|
||||
updated_at: config.updated_at
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load config:', error)
|
||||
throw error
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const saveCurrentConfig = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const updateData: EmbeddingConfigUpdate = {
|
||||
provider: currentConfig.value.provider,
|
||||
config: currentConfig.value.config
|
||||
}
|
||||
await saveConfig(updateData)
|
||||
} catch (error) {
|
||||
console.error('Failed to save config:', error)
|
||||
throw error
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const runTest = async (testText?: string) => {
|
||||
testLoading.value = true
|
||||
testResult.value = null
|
||||
try {
|
||||
const result = await testEmbedding({
|
||||
test_text: testText,
|
||||
config: {
|
||||
provider: currentConfig.value.provider,
|
||||
config: currentConfig.value.config
|
||||
}
|
||||
})
|
||||
testResult.value = result
|
||||
} catch (error: any) {
|
||||
testResult.value = {
|
||||
success: false,
|
||||
dimension: 0,
|
||||
error: error?.message || '连接测试失败'
|
||||
}
|
||||
} finally {
|
||||
testLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const loadFormats = async () => {
|
||||
formatsLoading.value = true
|
||||
try {
|
||||
const res: any = await getSupportedFormats()
|
||||
formats.value = res?.formats || res?.data?.formats || []
|
||||
} catch (error) {
|
||||
console.error('Failed to load formats:', error)
|
||||
throw error
|
||||
} finally {
|
||||
formatsLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const setProvider = (providerName: string) => {
|
||||
currentConfig.value.provider = providerName
|
||||
const provider = providers.value.find(p => p.name === providerName)
|
||||
if (provider?.config_schema?.properties) {
|
||||
const newConfig: Record<string, any> = {}
|
||||
Object.entries(provider.config_schema.properties).forEach(([key, field]: [string, any]) => {
|
||||
newConfig[key] = field.default !== undefined ? field.default : ''
|
||||
})
|
||||
currentConfig.value.config = newConfig
|
||||
} else {
|
||||
currentConfig.value.config = {}
|
||||
}
|
||||
}
|
||||
|
||||
const updateConfigValue = (key: string, value: any) => {
|
||||
currentConfig.value.config[key] = value
|
||||
}
|
||||
|
||||
const clearTestResult = () => {
|
||||
testResult.value = null
|
||||
}
|
||||
|
||||
return {
|
||||
providers,
|
||||
currentConfig,
|
||||
formats,
|
||||
loading,
|
||||
providersLoading,
|
||||
formatsLoading,
|
||||
testResult,
|
||||
testLoading,
|
||||
currentProvider,
|
||||
configSchema,
|
||||
loadProviders,
|
||||
loadConfig,
|
||||
saveCurrentConfig,
|
||||
runTest,
|
||||
loadFormats,
|
||||
setProvider,
|
||||
updateConfigValue,
|
||||
clearTestResult
|
||||
}
|
||||
})
|
||||
|
|
@ -0,0 +1,161 @@
|
|||
import { defineStore } from 'pinia'
|
||||
import { ref, computed } from 'vue'
|
||||
import {
|
||||
getLLMProviders,
|
||||
getLLMConfig,
|
||||
updateLLMConfig,
|
||||
testLLM,
|
||||
type LLMProviderInfo,
|
||||
type LLMConfig,
|
||||
type LLMConfigUpdate,
|
||||
type LLMTestResult
|
||||
} from '@/api/llm'
|
||||
|
||||
export const useLLMStore = defineStore('llm', () => {
|
||||
const providers = ref<LLMProviderInfo[]>([])
|
||||
const currentConfig = ref<LLMConfig>({
|
||||
provider: '',
|
||||
config: {}
|
||||
})
|
||||
const loading = ref(false)
|
||||
const providersLoading = ref(false)
|
||||
const testResult = ref<LLMTestResult | null>(null)
|
||||
const testLoading = ref(false)
|
||||
|
||||
const currentProvider = computed(() => {
|
||||
return providers.value.find(p => p.name === currentConfig.value.provider)
|
||||
})
|
||||
|
||||
const configSchema = computed(() => {
|
||||
return currentProvider.value?.config_schema || { properties: {} }
|
||||
})
|
||||
|
||||
const loadProviders = async () => {
|
||||
providersLoading.value = true
|
||||
try {
|
||||
const res: any = await getLLMProviders()
|
||||
providers.value = res?.providers || res?.data?.providers || []
|
||||
} catch (error) {
|
||||
console.error('Failed to load LLM providers:', error)
|
||||
throw error
|
||||
} finally {
|
||||
providersLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const loadConfig = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const res: any = await getLLMConfig()
|
||||
const config = res?.data || res
|
||||
if (config) {
|
||||
currentConfig.value = {
|
||||
provider: config.provider || '',
|
||||
config: config.config || {},
|
||||
updated_at: config.updated_at
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load LLM config:', error)
|
||||
throw error
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const saveCurrentConfig = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const updateData: LLMConfigUpdate = {
|
||||
provider: currentConfig.value.provider,
|
||||
config: currentConfig.value.config
|
||||
}
|
||||
await updateLLMConfig(updateData)
|
||||
} catch (error) {
|
||||
console.error('Failed to save LLM config:', error)
|
||||
throw error
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const runTest = async (testPrompt?: string): Promise<LLMTestResult> => {
|
||||
testLoading.value = true
|
||||
testResult.value = null
|
||||
try {
|
||||
const result = await testLLM({
|
||||
test_prompt: testPrompt,
|
||||
provider: currentConfig.value.provider,
|
||||
config: currentConfig.value.config
|
||||
})
|
||||
testResult.value = result
|
||||
return result
|
||||
} catch (error: any) {
|
||||
const errorResult: LLMTestResult = {
|
||||
success: false,
|
||||
error: error?.message || '连接测试失败'
|
||||
}
|
||||
testResult.value = errorResult
|
||||
return errorResult
|
||||
} finally {
|
||||
testLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const setProvider = (providerName: string) => {
|
||||
currentConfig.value.provider = providerName
|
||||
const provider = providers.value.find(p => p.name === providerName)
|
||||
if (provider?.config_schema?.properties) {
|
||||
const newConfig: Record<string, any> = {}
|
||||
Object.entries(provider.config_schema.properties).forEach(([key, field]: [string, any]) => {
|
||||
if (field.default !== undefined) {
|
||||
newConfig[key] = field.default
|
||||
} else {
|
||||
switch (field.type) {
|
||||
case 'string':
|
||||
newConfig[key] = ''
|
||||
break
|
||||
case 'integer':
|
||||
case 'number':
|
||||
newConfig[key] = field.minimum ?? 0
|
||||
break
|
||||
case 'boolean':
|
||||
newConfig[key] = false
|
||||
break
|
||||
default:
|
||||
newConfig[key] = null
|
||||
}
|
||||
}
|
||||
})
|
||||
currentConfig.value.config = newConfig
|
||||
} else {
|
||||
currentConfig.value.config = {}
|
||||
}
|
||||
}
|
||||
|
||||
const updateConfigValue = (key: string, value: any) => {
|
||||
currentConfig.value.config[key] = value
|
||||
}
|
||||
|
||||
const clearTestResult = () => {
|
||||
testResult.value = null
|
||||
}
|
||||
|
||||
return {
|
||||
providers,
|
||||
currentConfig,
|
||||
loading,
|
||||
providersLoading,
|
||||
testResult,
|
||||
testLoading,
|
||||
currentProvider,
|
||||
configSchema,
|
||||
loadProviders,
|
||||
loadConfig,
|
||||
saveCurrentConfig,
|
||||
runTest,
|
||||
setProvider,
|
||||
updateConfigValue,
|
||||
clearTestResult
|
||||
}
|
||||
})
|
||||
|
|
@ -0,0 +1,126 @@
|
|||
import { defineStore } from 'pinia'
|
||||
import { ref, computed } from 'vue'
|
||||
import {
|
||||
runRagExperiment,
|
||||
createSSEConnection,
|
||||
type AIResponse,
|
||||
type RetrievalResult,
|
||||
type RagExperimentRequest,
|
||||
type RagExperimentResult
|
||||
} from '@/api/rag'
|
||||
|
||||
export const useRagStore = defineStore('rag', () => {
|
||||
const retrievalResults = ref<RetrievalResult[]>([])
|
||||
const finalPrompt = ref('')
|
||||
const aiResponse = ref<AIResponse | null>(null)
|
||||
const totalLatencyMs = ref<number>(0)
|
||||
|
||||
const loading = ref(false)
|
||||
const streaming = ref(false)
|
||||
const streamContent = ref('')
|
||||
const streamError = ref<string | null>(null)
|
||||
|
||||
const hasResults = computed(() => retrievalResults.value.length > 0 || aiResponse.value !== null)
|
||||
|
||||
const abortStream = ref<(() => void) | null>(null)
|
||||
|
||||
const runExperiment = async (params: RagExperimentRequest) => {
|
||||
loading.value = true
|
||||
streamError.value = null
|
||||
|
||||
try {
|
||||
const result: RagExperimentResult = await runRagExperiment(params)
|
||||
|
||||
retrievalResults.value = result.retrieval_results || []
|
||||
finalPrompt.value = result.final_prompt || ''
|
||||
aiResponse.value = result.ai_response || null
|
||||
totalLatencyMs.value = result.total_latency_ms || 0
|
||||
|
||||
return result
|
||||
} catch (error: any) {
|
||||
streamError.value = error?.message || '实验运行失败'
|
||||
throw error
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const startStream = (params: RagExperimentRequest) => {
|
||||
streaming.value = true
|
||||
streamContent.value = ''
|
||||
streamError.value = null
|
||||
aiResponse.value = null
|
||||
|
||||
abortStream.value = createSSEConnection(
|
||||
'/admin/rag/experiments/stream',
|
||||
params,
|
||||
(data: string) => {
|
||||
try {
|
||||
const parsed = JSON.parse(data)
|
||||
|
||||
if (parsed.type === 'content') {
|
||||
streamContent.value += parsed.content || ''
|
||||
} else if (parsed.type === 'retrieval') {
|
||||
retrievalResults.value = parsed.results || []
|
||||
} else if (parsed.type === 'prompt') {
|
||||
finalPrompt.value = parsed.prompt || ''
|
||||
} else if (parsed.type === 'complete') {
|
||||
aiResponse.value = {
|
||||
content: streamContent.value,
|
||||
prompt_tokens: parsed.prompt_tokens,
|
||||
completion_tokens: parsed.completion_tokens,
|
||||
total_tokens: parsed.total_tokens,
|
||||
latency_ms: parsed.latency_ms,
|
||||
model: parsed.model
|
||||
}
|
||||
totalLatencyMs.value = parsed.total_latency_ms || 0
|
||||
} else if (parsed.type === 'error') {
|
||||
streamError.value = parsed.message || '流式输出错误'
|
||||
}
|
||||
} catch {
|
||||
streamContent.value += data
|
||||
}
|
||||
},
|
||||
(error: Error) => {
|
||||
streaming.value = false
|
||||
streamError.value = error.message
|
||||
},
|
||||
() => {
|
||||
streaming.value = false
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
const stopStream = () => {
|
||||
if (abortStream.value) {
|
||||
abortStream.value()
|
||||
abortStream.value = null
|
||||
}
|
||||
streaming.value = false
|
||||
}
|
||||
|
||||
const clearResults = () => {
|
||||
retrievalResults.value = []
|
||||
finalPrompt.value = ''
|
||||
aiResponse.value = null
|
||||
totalLatencyMs.value = 0
|
||||
streamContent.value = ''
|
||||
streamError.value = null
|
||||
}
|
||||
|
||||
return {
|
||||
retrievalResults,
|
||||
finalPrompt,
|
||||
aiResponse,
|
||||
totalLatencyMs,
|
||||
loading,
|
||||
streaming,
|
||||
streamContent,
|
||||
streamError,
|
||||
hasResults,
|
||||
runExperiment,
|
||||
startStream,
|
||||
stopStream,
|
||||
clearResults
|
||||
}
|
||||
})
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
import { defineStore } from 'pinia'
|
||||
import { ref, watch } from 'vue'
|
||||
|
||||
export const useRagLabStore = defineStore('ragLab', () => {
|
||||
const query = ref(localStorage.getItem('ragLab_query') || '')
|
||||
const kbIds = ref<string[]>(JSON.parse(localStorage.getItem('ragLab_kbIds') || '[]'))
|
||||
const llmProvider = ref(localStorage.getItem('ragLab_llmProvider') || '')
|
||||
const topK = ref(parseInt(localStorage.getItem('ragLab_topK') || '3', 10))
|
||||
const scoreThreshold = ref(parseFloat(localStorage.getItem('ragLab_scoreThreshold') || '0.5'))
|
||||
const generateResponse = ref(localStorage.getItem('ragLab_generateResponse') !== 'false')
|
||||
const streamOutput = ref(localStorage.getItem('ragLab_streamOutput') === 'true')
|
||||
|
||||
watch(query, (val) => localStorage.setItem('ragLab_query', val))
|
||||
watch(kbIds, (val) => localStorage.setItem('ragLab_kbIds', JSON.stringify(val)), { deep: true })
|
||||
watch(llmProvider, (val) => localStorage.setItem('ragLab_llmProvider', val))
|
||||
watch(topK, (val) => localStorage.setItem('ragLab_topK', String(val)))
|
||||
watch(scoreThreshold, (val) => localStorage.setItem('ragLab_scoreThreshold', String(val)))
|
||||
watch(generateResponse, (val) => localStorage.setItem('ragLab_generateResponse', String(val)))
|
||||
watch(streamOutput, (val) => localStorage.setItem('ragLab_streamOutput', String(val)))
|
||||
|
||||
const clearParams = () => {
|
||||
query.value = ''
|
||||
kbIds.value = []
|
||||
llmProvider.value = ''
|
||||
topK.value = 3
|
||||
scoreThreshold.value = 0.5
|
||||
generateResponse.value = true
|
||||
streamOutput.value = false
|
||||
}
|
||||
|
||||
return {
|
||||
query,
|
||||
kbIds,
|
||||
llmProvider,
|
||||
topK,
|
||||
scoreThreshold,
|
||||
generateResponse,
|
||||
streamOutput,
|
||||
clearParams
|
||||
}
|
||||
})
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
import { defineStore } from 'pinia'
|
||||
|
||||
// Default tenant ID format: name@ash@year
|
||||
const DEFAULT_TENANT_ID = 'default@ash@2026'
|
||||
|
||||
export const useTenantStore = defineStore('tenant', {
|
||||
state: () => ({
|
||||
currentTenantId: localStorage.getItem('X-Tenant-Id') || DEFAULT_TENANT_ID
|
||||
}),
|
||||
actions: {
|
||||
setTenant(id: string) {
|
||||
this.currentTenantId = id
|
||||
localStorage.setItem('X-Tenant-Id', id)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
@ -0,0 +1,486 @@
|
|||
/* Global theme for the AI Service admin UI.
 * Defines the design-token CSS variables, overrides Element Plus component
 * defaults (hence the pervasive !important), and provides small animation
 * and layout utility classes. */

@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=DM+Sans:wght@400;500;600;700&display=swap');

/* ---- Design tokens ---- */
:root {
  /* Brand palette */
  --primary-color: #4F7CFF;
  --primary-light: #6B91FF;
  --primary-lighter: #E8EEFF;
  --primary-dark: #3A5FD9;

  --secondary-color: #6366F1;
  --secondary-light: #818CF8;

  --accent-color: #10B981;
  --accent-light: #34D399;

  /* Semantic status colors */
  --warning-color: #F59E0B;
  --danger-color: #EF4444;
  --success-color: #10B981;
  --info-color: #3B82F6;

  /* Surfaces */
  --bg-primary: #F8FAFC;
  --bg-secondary: #FFFFFF;
  --bg-tertiary: #F1F5F9;
  --bg-hover: #F1F5F9;

  /* Text */
  --text-primary: #1E293B;
  --text-secondary: #64748B;
  --text-tertiary: #94A3B8;
  --text-placeholder: #CBD5E1;

  /* Borders */
  --border-color: #E2E8F0;
  --border-light: #F1F5F9;

  /* Elevation */
  --shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.05);
  --shadow-md: 0 4px 6px -1px rgba(0, 0, 0, 0.05), 0 2px 4px -2px rgba(0, 0, 0, 0.05);
  --shadow-lg: 0 10px 15px -3px rgba(0, 0, 0, 0.05), 0 4px 6px -4px rgba(0, 0, 0, 0.05);
  --shadow-xl: 0 20px 25px -5px rgba(0, 0, 0, 0.05), 0 8px 10px -6px rgba(0, 0, 0, 0.05);

  /* Corner radii */
  --radius-sm: 6px;
  --radius-md: 10px;
  --radius-lg: 14px;
  --radius-xl: 20px;

  /* Motion */
  --transition-fast: 0.15s ease;
  --transition-normal: 0.25s ease;
  --transition-slow: 0.35s ease;

  /* Typography */
  --font-sans: 'DM Sans', 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
  --font-mono: 'JetBrains Mono', 'Fira Code', 'SF Mono', Consolas, monospace;
}

/* ---- Reset ---- */
* {
  margin: 0;
  padding: 0;
  box-sizing: border-box;
}

html, body {
  font-family: var(--font-sans);
  font-size: 14px;
  line-height: 1.6;
  color: var(--text-primary);
  background-color: var(--bg-primary);
  -webkit-font-smoothing: antialiased;
  -moz-osx-font-smoothing: grayscale;
}

/* ---- Element Plus: menu ---- */
.el-menu {
  font-family: var(--font-sans) !important;
  border-bottom: 1px solid var(--border-color) !important;
  background-color: var(--bg-secondary) !important;
}

.el-menu-item {
  font-weight: 500 !important;
  transition: all var(--transition-fast) !important;
}

.el-menu-item:hover {
  background-color: var(--bg-hover) !important;
}

.el-menu-item.is-active {
  color: var(--primary-color) !important;
  border-bottom-color: var(--primary-color) !important;
  background-color: var(--primary-lighter) !important;
}

/* ---- Element Plus: card ---- */
.el-card {
  border-radius: var(--radius-lg) !important;
  border: 1px solid var(--border-color) !important;
  box-shadow: var(--shadow-sm) !important;
  transition: all var(--transition-normal) !important;
  background-color: var(--bg-secondary) !important;
}

.el-card:hover {
  box-shadow: var(--shadow-md) !important;
}

.el-card__header {
  border-bottom: 1px solid var(--border-light) !important;
  padding: 16px 20px !important;
  font-weight: 600 !important;
  color: var(--text-primary) !important;
}

.el-card__body {
  padding: 20px !important;
}

/* ---- Element Plus: buttons ---- */
.el-button--primary {
  background-color: var(--primary-color) !important;
  border-color: var(--primary-color) !important;
  font-weight: 500 !important;
  transition: all var(--transition-fast) !important;
}

.el-button--primary:hover {
  background-color: var(--primary-light) !important;
  border-color: var(--primary-light) !important;
  transform: translateY(-1px);
  box-shadow: var(--shadow-md) !important;
}

.el-button--default {
  font-weight: 500 !important;
  border-color: var(--border-color) !important;
  transition: all var(--transition-fast) !important;
}

.el-button--default:hover {
  border-color: var(--primary-color) !important;
  color: var(--primary-color) !important;
  background-color: var(--primary-lighter) !important;
}

/* ---- Element Plus: inputs & selects ---- */
.el-input__wrapper {
  border-radius: var(--radius-md) !important;
  box-shadow: none !important;
  border: 1px solid var(--border-color) !important;
  transition: all var(--transition-fast) !important;
}

.el-input__wrapper:hover {
  border-color: var(--primary-light) !important;
}

.el-input__wrapper.is-focus {
  border-color: var(--primary-color) !important;
  box-shadow: 0 0 0 3px var(--primary-lighter) !important;
}

.el-select {
  --el-select-border-color-hover: var(--primary-light) !important;
}

.el-select .el-input__wrapper {
  border-radius: var(--radius-md) !important;
}

.el-select-dropdown {
  border-radius: var(--radius-lg) !important;
  border: 1px solid var(--border-color) !important;
  box-shadow: var(--shadow-xl) !important;
  margin-top: 8px !important;
}

.el-select-dropdown__wrap {
  max-height: 320px !important;
}

.el-select-dropdown__item {
  padding: 10px 16px !important;
  line-height: 1.5 !important;
  transition: all var(--transition-fast) !important;
}

.el-select-dropdown__item.hover,
.el-select-dropdown__item:hover {
  background-color: var(--primary-lighter) !important;
  color: var(--primary-color) !important;
}

.el-select-dropdown__item.is-selected {
  background-color: var(--primary-lighter) !important;
  color: var(--primary-color) !important;
  font-weight: 600 !important;
}

/* ---- Element Plus: tags ---- */
.el-tag {
  border-radius: var(--radius-sm) !important;
  font-weight: 500 !important;
  border: none !important;
}

.el-tag--success {
  background-color: #D1FAE5 !important;
  color: #059669 !important;
}

.el-tag--warning {
  background-color: #FEF3C7 !important;
  color: #D97706 !important;
}

.el-tag--danger {
  background-color: #FEE2E2 !important;
  color: #DC2626 !important;
}

.el-tag--info {
  background-color: #E0E7FF !important;
  color: #4F46E5 !important;
}

/* ---- Element Plus: table ---- */
.el-table {
  border-radius: var(--radius-lg) !important;
  overflow: hidden !important;
}

.el-table th.el-table__cell {
  background-color: var(--bg-tertiary) !important;
  font-weight: 600 !important;
  color: var(--text-secondary) !important;
}

.el-table td.el-table__cell {
  border-bottom: 1px solid var(--border-light) !important;
}

.el-table--striped .el-table__body tr.el-table__row--striped td.el-table__cell {
  background-color: var(--bg-tertiary) !important;
}

.el-table__row:hover > td.el-table__cell {
  background-color: var(--primary-lighter) !important;
}

/* ---- Element Plus: tabs ---- */
.el-tabs__item {
  font-weight: 500 !important;
  transition: all var(--transition-fast) !important;
}

.el-tabs__item.is-active {
  color: var(--primary-color) !important;
}

.el-tabs__active-bar {
  background-color: var(--primary-color) !important;
}

/* ---- Element Plus: dialog ---- */
.el-dialog {
  border-radius: var(--radius-xl) !important;
  overflow: hidden !important;
}

.el-dialog__header {
  padding: 20px 24px !important;
  border-bottom: 1px solid var(--border-light) !important;
}

.el-dialog__title {
  font-weight: 600 !important;
  font-size: 18px !important;
}

.el-dialog__body {
  padding: 24px !important;
}

/* ---- Element Plus: form / slider / switch / alert / misc ---- */
.el-form-item__label {
  font-weight: 500 !important;
  color: var(--text-secondary) !important;
}

.el-slider__bar {
  background-color: var(--primary-color) !important;
}

.el-slider__button {
  border-color: var(--primary-color) !important;
}

.el-switch.is-checked .el-switch__core {
  background-color: var(--primary-color) !important;
  border-color: var(--primary-color) !important;
}

.el-alert {
  border-radius: var(--radius-md) !important;
  border: none !important;
}

.el-alert--info {
  background-color: #EFF6FF !important;
}

.el-alert--success {
  background-color: #ECFDF5 !important;
}

.el-alert--warning {
  background-color: #FFFBEB !important;
}

.el-alert--error {
  background-color: #FEF2F2 !important;
}

.el-divider {
  border-color: var(--border-light) !important;
}

.el-empty__description {
  color: var(--text-tertiary) !important;
}

.el-loading-mask {
  border-radius: var(--radius-lg) !important;
}

.el-descriptions {
  border-radius: var(--radius-md) !important;
  overflow: hidden !important;
}

.el-descriptions__label {
  background-color: var(--bg-tertiary) !important;
  font-weight: 500 !important;
}

/* ---- Animation utilities ---- */
.fade-in-up {
  animation: fadeInUp 0.5s ease-out forwards;
}

@keyframes fadeInUp {
  from {
    opacity: 0;
    transform: translateY(20px);
  }
  to {
    opacity: 1;
    transform: translateY(0);
  }
}

.slide-in-left {
  animation: slideInLeft 0.4s ease-out forwards;
}

@keyframes slideInLeft {
  from {
    opacity: 0;
    transform: translateX(-20px);
  }
  to {
    opacity: 1;
    transform: translateX(0);
  }
}

.scale-in {
  animation: scaleIn 0.3s ease-out forwards;
}

@keyframes scaleIn {
  from {
    opacity: 0;
    transform: scale(0.95);
  }
  to {
    opacity: 1;
    transform: scale(1);
  }
}

/* ---- Page layout helpers ---- */
.page-container {
  padding: 24px;
  min-height: calc(100vh - 60px);
  background-color: var(--bg-primary);
}

.page-header {
  margin-bottom: 24px;
}

.page-title {
  font-size: 24px;
  font-weight: 700;
  color: var(--text-primary);
  margin: 0 0 8px 0;
  letter-spacing: -0.5px;
}

.page-desc {
  font-size: 14px;
  color: var(--text-secondary);
  margin: 0;
  line-height: 1.6;
}

/* ---- Card icon variants ---- */
.card-icon {
  width: 40px;
  height: 40px;
  display: flex;
  align-items: center;
  justify-content: center;
  border-radius: var(--radius-md);
  font-size: 18px;
}

.card-icon.primary {
  background-color: var(--primary-lighter);
  color: var(--primary-color);
}

.card-icon.success {
  background-color: #D1FAE5;
  color: #059669;
}

.card-icon.warning {
  background-color: #FEF3C7;
  color: #D97706;
}

.card-icon.info {
  background-color: #E0E7FF;
  color: #4F46E5;
}

/* ---- Stat card: gradient top border revealed on hover ---- */
.stat-card {
  position: relative;
  overflow: hidden;
}

.stat-card::before {
  content: '';
  position: absolute;
  top: 0;
  left: 0;
  right: 0;
  height: 3px;
  background: linear-gradient(90deg, var(--primary-color), var(--secondary-color));
  opacity: 0;
  transition: opacity var(--transition-fast);
}

.stat-card:hover::before {
  opacity: 1;
}

/* ---- Scrollbar & selection styling (WebKit) ---- */
::-webkit-scrollbar {
  width: 8px;
  height: 8px;
}

::-webkit-scrollbar-track {
  background: var(--bg-tertiary);
  border-radius: 4px;
}

::-webkit-scrollbar-thumb {
  background: var(--text-tertiary);
  border-radius: 4px;
  transition: background var(--transition-fast);
}

::-webkit-scrollbar-thumb:hover {
  background: var(--text-secondary);
}

::selection {
  background-color: var(--primary-lighter);
  color: var(--primary-dark);
}

/* ---- Responsive tweaks ---- */
@media (max-width: 768px) {
  .page-container {
    padding: 16px;
  }

  .page-title {
    font-size: 20px;
  }
}
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
/** Metadata describing one embedding provider, as returned by the backend. */
export interface EmbeddingProviderInfo {
  name: string                         // machine-readable id, used as the config `provider` value
  display_name: string                 // human-readable label for the UI
  description?: string
  config_schema: Record<string, any>   // JSON-schema-like description of provider-specific fields
}

/** Persisted embedding configuration (provider + its settings). */
export interface EmbeddingConfig {
  provider: string
  config: Record<string, any>
  updated_at?: string                  // set by the backend on save
}

/** Payload for saving or test-overriding the embedding configuration. */
export interface EmbeddingConfigUpdate {
  provider: string
  config?: Record<string, any>
}

/** Outcome of an embedding connectivity test. */
export interface EmbeddingTestResult {
  success: boolean
  dimension: number                    // embedding vector length (0 on failure)
  latency_ms?: number
  message?: string
  error?: string                       // populated only when success is false
}

/** One document format the ingestion pipeline can handle. */
export interface DocumentFormat {
  extension: string                    // e.g. '.pdf'
  name: string
  description?: string
}

/** Response envelope for the provider-list endpoint. */
export interface EmbeddingProvidersResponse {
  providers: EmbeddingProviderInfo[]
}

/** Response envelope for the config-update endpoint. */
export interface EmbeddingConfigUpdateResponse {
  success: boolean
  message: string
}

/** Response envelope for the supported-formats endpoint. */
export interface SupportedFormatsResponse {
  formats: DocumentFormat[]
}

/** Request body for the embedding test endpoint; omitted fields use server defaults. */
export interface EmbeddingTestRequest {
  test_text?: string
  config?: EmbeddingConfigUpdate       // when present, tests this config instead of the saved one
}
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
/** Metadata describing one LLM provider, as returned by the backend. */
export interface LLMProviderInfo {
  name: string                         // machine-readable id, used as the config `provider` value
  display_name: string                 // human-readable label for the UI
  description?: string
  config_schema: Record<string, any>   // JSON-schema-like description of provider-specific fields
}

/** Persisted LLM configuration (provider + its settings). */
export interface LLMConfig {
  provider: string
  config: Record<string, any>
  updated_at?: string                  // set by the backend on save
}

/** Payload for saving or test-overriding the LLM configuration. */
export interface LLMConfigUpdate {
  provider: string
  config?: Record<string, any>
}

/** Outcome of an LLM connectivity test, including token usage when available. */
export interface LLMTestResult {
  success: boolean
  response?: string                    // model output for the test prompt
  latency_ms?: number
  prompt_tokens?: number
  completion_tokens?: number
  total_tokens?: number
  message?: string
  error?: string                       // populated only when success is false
}

/** Request body for the LLM test endpoint; omitted fields use server defaults. */
export interface LLMTestRequest {
  test_prompt?: string
  provider?: string
  config?: Record<string, any>
}

/** Response envelope for the provider-list endpoint. */
export interface LLMProvidersResponse {
  providers: LLMProviderInfo[]
}

/** Response envelope for the config-update endpoint. */
export interface LLMConfigUpdateResponse {
  success: boolean
  message: string
}
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
import axios from 'axios'
|
||||
import { ElMessage, ElMessageBox } from 'element-plus'
|
||||
import { useTenantStore } from '@/stores/tenant'
|
||||
|
||||
// 创建 axios 实例
|
||||
const service = axios.create({
|
||||
baseURL: import.meta.env.VITE_APP_BASE_API || '/api',
|
||||
timeout: 60000
|
||||
})
|
||||
|
||||
// 请求拦截器
|
||||
service.interceptors.request.use(
|
||||
(config) => {
|
||||
const tenantStore = useTenantStore()
|
||||
if (tenantStore.currentTenantId) {
|
||||
config.headers['X-Tenant-Id'] = tenantStore.currentTenantId
|
||||
}
|
||||
// TODO: 如果有 token 也可以在这里注入 Authorization
|
||||
return config
|
||||
},
|
||||
(error) => {
|
||||
console.log(error)
|
||||
return Promise.reject(error)
|
||||
}
|
||||
)
|
||||
|
||||
// 响应拦截器
|
||||
service.interceptors.response.use(
|
||||
(response) => {
|
||||
const res = response.data
|
||||
// 这里可以根据后端的 code 进行统一处理
|
||||
return res
|
||||
},
|
||||
(error) => {
|
||||
console.log('err' + error)
|
||||
let { message, response } = error
|
||||
if (response) {
|
||||
const status = response.status
|
||||
if (status === 401) {
|
||||
ElMessageBox.confirm('登录状态已过期,您可以继续留在该页面,或者重新登录', '系统提示', {
|
||||
confirmButtonText: '重新登录',
|
||||
cancelButtonText: '取消',
|
||||
type: 'warning'
|
||||
}).then(() => {
|
||||
// TODO: 跳转到登录页或执行退出逻辑
|
||||
location.href = '/login'
|
||||
})
|
||||
} else if (status === 403) {
|
||||
ElMessage({
|
||||
message: '当前操作无权限',
|
||||
type: 'error',
|
||||
duration: 5 * 1000
|
||||
})
|
||||
} else {
|
||||
ElMessage({
|
||||
message: message || '后端接口未知异常',
|
||||
type: 'error',
|
||||
duration: 5 * 1000
|
||||
})
|
||||
}
|
||||
} else {
|
||||
ElMessage({
|
||||
message: '网络连接异常',
|
||||
type: 'error',
|
||||
duration: 5 * 1000
|
||||
})
|
||||
}
|
||||
return Promise.reject(error)
|
||||
}
|
||||
)
|
||||
|
||||
export default service
|
||||
|
|
@ -0,0 +1,483 @@
|
|||
<template>
|
||||
<div class="embedding-config-page">
|
||||
<div class="page-header">
|
||||
<div class="header-content">
|
||||
<div class="title-section">
|
||||
<h1 class="page-title">嵌入模型配置</h1>
|
||||
<p class="page-desc">配置和管理系统使用的嵌入模型,支持多种提供者切换。配置修改后需保存才能生效。</p>
|
||||
</div>
|
||||
<div class="header-actions" v-if="currentConfig.updated_at">
|
||||
<div class="update-info">
|
||||
<el-icon class="update-icon"><Clock /></el-icon>
|
||||
<span>上次更新: {{ formatDate(currentConfig.updated_at) }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<el-row :gutter="24" v-loading="pageLoading" element-loading-text="加载中...">
|
||||
<el-col :xs="24" :sm="24" :md="12" :lg="12">
|
||||
<el-card shadow="hover" class="config-card">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<div class="header-left">
|
||||
<div class="icon-wrapper">
|
||||
<el-icon><Setting /></el-icon>
|
||||
</div>
|
||||
<span class="header-title">模型配置</span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div class="card-content">
|
||||
<div class="provider-select-section">
|
||||
<div class="section-label">
|
||||
<el-icon><Connection /></el-icon>
|
||||
<span>选择提供者</span>
|
||||
</div>
|
||||
<EmbeddingProviderSelect
|
||||
v-model="currentConfig.provider"
|
||||
:providers="providers"
|
||||
:loading="providersLoading"
|
||||
placeholder="请选择嵌入模型提供者"
|
||||
@change="handleProviderChange"
|
||||
/>
|
||||
<transition name="fade">
|
||||
<div v-if="currentProvider" class="provider-info">
|
||||
<el-icon class="info-icon"><InfoFilled /></el-icon>
|
||||
<span class="info-text">{{ currentProvider.description }}</span>
|
||||
</div>
|
||||
</transition>
|
||||
</div>
|
||||
|
||||
<el-divider />
|
||||
|
||||
<transition name="slide-fade" mode="out-in">
|
||||
<div v-if="currentConfig.provider" key="form" class="config-form-section">
|
||||
<EmbeddingConfigForm
|
||||
ref="configFormRef"
|
||||
:schema="configSchema"
|
||||
v-model="currentConfig.config"
|
||||
label-width="140px"
|
||||
/>
|
||||
</div>
|
||||
<el-empty v-else key="empty" description="请先选择一个嵌入模型提供者" :image-size="120">
|
||||
<template #image>
|
||||
<div class="empty-icon">
|
||||
<el-icon><Box /></el-icon>
|
||||
</div>
|
||||
</template>
|
||||
</el-empty>
|
||||
</transition>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<div class="card-footer">
|
||||
<el-button size="large" @click="handleReset">
|
||||
<el-icon><RefreshLeft /></el-icon>
|
||||
重置
|
||||
</el-button>
|
||||
<el-button type="primary" size="large" :loading="saving" @click="handleSave">
|
||||
<el-icon><Check /></el-icon>
|
||||
保存配置
|
||||
</el-button>
|
||||
</div>
|
||||
</template>
|
||||
</el-card>
|
||||
</el-col>
|
||||
|
||||
<el-col :xs="24" :sm="24" :md="12" :lg="12">
|
||||
<div class="right-column">
|
||||
<EmbeddingTestPanel
|
||||
:config="{ provider: currentConfig.provider, config: currentConfig.config }"
|
||||
/>
|
||||
|
||||
<el-card shadow="hover" class="formats-card">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<div class="header-left">
|
||||
<div class="icon-wrapper">
|
||||
<el-icon><Document /></el-icon>
|
||||
</div>
|
||||
<span class="header-title">支持的文档格式</span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
<SupportedFormats />
|
||||
</el-card>
|
||||
</div>
|
||||
</el-col>
|
||||
</el-row>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { ElMessage, ElMessageBox } from 'element-plus'
|
||||
import { Setting, Connection, InfoFilled, Box, RefreshLeft, Check, Clock, Document } from '@element-plus/icons-vue'
|
||||
import { useEmbeddingStore } from '@/stores/embedding'
|
||||
import EmbeddingProviderSelect from '@/components/embedding/EmbeddingProviderSelect.vue'
|
||||
import EmbeddingConfigForm from '@/components/embedding/EmbeddingConfigForm.vue'
|
||||
import EmbeddingTestPanel from '@/components/embedding/EmbeddingTestPanel.vue'
|
||||
import SupportedFormats from '@/components/embedding/SupportedFormats.vue'
|
||||
|
||||
const embeddingStore = useEmbeddingStore()
|
||||
|
||||
const configFormRef = ref<InstanceType<typeof EmbeddingConfigForm>>()
|
||||
const saving = ref(false)
|
||||
const pageLoading = ref(false)
|
||||
|
||||
const providers = computed(() => embeddingStore.providers)
|
||||
const currentConfig = computed(() => embeddingStore.currentConfig)
|
||||
const currentProvider = computed(() => embeddingStore.currentProvider)
|
||||
const configSchema = computed(() => embeddingStore.configSchema)
|
||||
const providersLoading = computed(() => embeddingStore.providersLoading)
|
||||
|
||||
const formatDate = (dateStr: string) => {
|
||||
if (!dateStr) return ''
|
||||
const date = new Date(dateStr)
|
||||
return date.toLocaleString('zh-CN', {
|
||||
year: 'numeric',
|
||||
month: '2-digit',
|
||||
day: '2-digit',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit'
|
||||
})
|
||||
}
|
||||
|
||||
const handleProviderChange = (provider: any) => {
|
||||
if (provider) {
|
||||
embeddingStore.setProvider(provider.name)
|
||||
}
|
||||
}
|
||||
|
||||
const handleSave = async () => {
|
||||
if (!currentConfig.value.provider) {
|
||||
ElMessage.warning('请先选择嵌入模型提供者')
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const valid = await configFormRef.value?.validate()
|
||||
if (!valid) {
|
||||
return
|
||||
}
|
||||
} catch (error) {
|
||||
ElMessage.warning('请检查配置表单中的必填项')
|
||||
return
|
||||
}
|
||||
|
||||
saving.value = true
|
||||
try {
|
||||
await embeddingStore.saveCurrentConfig()
|
||||
ElMessage.success('配置保存成功')
|
||||
} catch (error) {
|
||||
ElMessage.error('配置保存失败')
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const handleReset = async () => {
|
||||
try {
|
||||
await ElMessageBox.confirm('确定要重置配置吗?将恢复为当前保存的配置。', '确认重置', {
|
||||
confirmButtonText: '确定',
|
||||
cancelButtonText: '取消',
|
||||
type: 'warning'
|
||||
})
|
||||
await embeddingStore.loadConfig()
|
||||
ElMessage.success('配置已重置')
|
||||
} catch (error) {
|
||||
}
|
||||
}
|
||||
|
||||
const initPage = async () => {
|
||||
pageLoading.value = true
|
||||
try {
|
||||
await Promise.all([
|
||||
embeddingStore.loadProviders(),
|
||||
embeddingStore.loadConfig(),
|
||||
embeddingStore.loadFormats()
|
||||
])
|
||||
} catch (error) {
|
||||
ElMessage.error('初始化页面失败')
|
||||
} finally {
|
||||
pageLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
initPage()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.embedding-config-page {
|
||||
padding: 24px;
|
||||
min-height: calc(100vh - 60px);
|
||||
}
|
||||
|
||||
.page-header {
|
||||
margin-bottom: 24px;
|
||||
animation: slideDown 0.4s ease-out;
|
||||
}
|
||||
|
||||
@keyframes slideDown {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-16px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.header-content {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: flex-start;
|
||||
gap: 20px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.title-section {
|
||||
flex: 1;
|
||||
min-width: 300px;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
margin: 0 0 8px 0;
|
||||
font-size: 24px;
|
||||
font-weight: 700;
|
||||
color: var(--text-primary);
|
||||
letter-spacing: -0.5px;
|
||||
}
|
||||
|
||||
.page-desc {
|
||||
margin: 0;
|
||||
font-size: 14px;
|
||||
color: var(--text-secondary);
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.header-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.update-info {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 8px 14px;
|
||||
background-color: var(--bg-tertiary);
|
||||
border-radius: 8px;
|
||||
font-size: 13px;
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
.update-icon {
|
||||
font-size: 14px;
|
||||
color: var(--text-tertiary);
|
||||
}
|
||||
|
||||
.config-card {
|
||||
animation: fadeInUp 0.5s ease-out;
|
||||
}
|
||||
|
||||
@keyframes fadeInUp {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.card-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.header-left {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.icon-wrapper {
|
||||
width: 36px;
|
||||
height: 36px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background-color: var(--primary-lighter);
|
||||
border-radius: 10px;
|
||||
color: var(--primary-color);
|
||||
font-size: 18px;
|
||||
}
|
||||
|
||||
.header-title {
|
||||
font-size: 15px;
|
||||
font-weight: 600;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.card-content {
|
||||
padding: 8px 0;
|
||||
}
|
||||
|
||||
.provider-select-section {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.section-label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
margin-bottom: 12px;
|
||||
font-size: 13px;
|
||||
font-weight: 600;
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
.section-label .el-icon {
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
.provider-info {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 10px;
|
||||
margin-top: 14px;
|
||||
padding: 14px 16px;
|
||||
background-color: var(--bg-tertiary);
|
||||
border-radius: 10px;
|
||||
font-size: 13px;
|
||||
color: var(--text-secondary);
|
||||
line-height: 1.6;
|
||||
border-left: 3px solid var(--primary-color);
|
||||
}
|
||||
|
||||
.info-icon {
|
||||
margin-top: 2px;
|
||||
color: var(--primary-color);
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.info-text {
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.config-form-section {
|
||||
max-height: 400px;
|
||||
overflow-y: auto;
|
||||
padding-right: 8px;
|
||||
}
|
||||
|
||||
.config-form-section::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
}
|
||||
|
||||
.config-form-section::-webkit-scrollbar-track {
|
||||
background: var(--bg-tertiary);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.config-form-section::-webkit-scrollbar-thumb {
|
||||
background: var(--text-tertiary);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.config-form-section::-webkit-scrollbar-thumb:hover {
|
||||
background: var(--text-secondary);
|
||||
}
|
||||
|
||||
.card-footer {
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
gap: 12px;
|
||||
padding-top: 8px;
|
||||
}
|
||||
|
||||
.right-column {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 24px;
|
||||
}
|
||||
|
||||
.formats-card {
|
||||
animation: fadeInUp 0.6s ease-out;
|
||||
}
|
||||
|
||||
.empty-icon {
|
||||
width: 100px;
|
||||
height: 100px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background-color: var(--bg-tertiary);
|
||||
border-radius: 50%;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
.empty-icon .el-icon {
|
||||
font-size: 48px;
|
||||
color: var(--text-tertiary);
|
||||
}
|
||||
|
||||
.fade-enter-active,
|
||||
.fade-leave-active {
|
||||
transition: opacity 0.25s ease;
|
||||
}
|
||||
|
||||
.fade-enter-from,
|
||||
.fade-leave-to {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
.slide-fade-enter-active {
|
||||
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
}
|
||||
|
||||
.slide-fade-leave-active {
|
||||
transition: all 0.2s cubic-bezier(1, 0.5, 0.8, 1);
|
||||
}
|
||||
|
||||
.slide-fade-enter-from {
|
||||
opacity: 0;
|
||||
transform: translateX(-16px);
|
||||
}
|
||||
|
||||
.slide-fade-leave-to {
|
||||
opacity: 0;
|
||||
transform: translateX(16px);
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.embedding-config-page {
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
.header-content {
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.title-section {
|
||||
min-width: 100%;
|
||||
}
|
||||
|
||||
.config-form-section {
|
||||
max-height: 300px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,470 @@
|
|||
<template>
|
||||
<div class="llm-config-page">
|
||||
<div class="page-header">
|
||||
<div class="header-content">
|
||||
<div class="title-section">
|
||||
<h1 class="page-title">LLM 模型配置</h1>
|
||||
<p class="page-desc">配置和管理系统使用的大语言模型,支持多种提供者切换。配置修改后需保存才能生效。</p>
|
||||
</div>
|
||||
<div class="header-actions" v-if="currentConfig.updated_at">
|
||||
<div class="update-info">
|
||||
<el-icon class="update-icon"><Clock /></el-icon>
|
||||
<span>上次更新: {{ formatDate(currentConfig.updated_at) }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<el-row :gutter="24" v-loading="pageLoading" element-loading-text="加载中...">
|
||||
<el-col :xs="24" :sm="24" :md="12" :lg="12">
|
||||
<el-card shadow="hover" class="config-card">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<div class="header-left">
|
||||
<div class="icon-wrapper llm-icon">
|
||||
<el-icon><Cpu /></el-icon>
|
||||
</div>
|
||||
<span class="header-title">模型配置</span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div class="card-content">
|
||||
<div class="provider-select-section">
|
||||
<div class="section-label">
|
||||
<el-icon><Connection /></el-icon>
|
||||
<span>选择提供者</span>
|
||||
</div>
|
||||
<ProviderSelect
|
||||
v-model="currentConfig.provider"
|
||||
:providers="providers"
|
||||
:loading="providersLoading"
|
||||
placeholder="请选择 LLM 提供者"
|
||||
@change="handleProviderChange"
|
||||
/>
|
||||
<transition name="fade">
|
||||
<div v-if="currentProvider" class="provider-info">
|
||||
<el-icon class="info-icon"><InfoFilled /></el-icon>
|
||||
<span class="info-text">{{ currentProvider.description }}</span>
|
||||
</div>
|
||||
</transition>
|
||||
</div>
|
||||
|
||||
<el-divider />
|
||||
|
||||
<transition name="slide-fade" mode="out-in">
|
||||
<div v-if="currentConfig.provider" key="form" class="config-form-section">
|
||||
<ConfigForm
|
||||
ref="configFormRef"
|
||||
:schema="configSchema"
|
||||
v-model="currentConfig.config"
|
||||
label-width="140px"
|
||||
/>
|
||||
</div>
|
||||
<el-empty v-else key="empty" description="请先选择一个 LLM 提供者" :image-size="120">
|
||||
<template #image>
|
||||
<div class="empty-icon">
|
||||
<el-icon><Box /></el-icon>
|
||||
</div>
|
||||
</template>
|
||||
</el-empty>
|
||||
</transition>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<div class="card-footer">
|
||||
<el-button size="large" @click="handleReset">
|
||||
<el-icon><RefreshLeft /></el-icon>
|
||||
重置
|
||||
</el-button>
|
||||
<el-button type="primary" size="large" :loading="saving" @click="handleSave">
|
||||
<el-icon><Check /></el-icon>
|
||||
保存配置
|
||||
</el-button>
|
||||
</div>
|
||||
</template>
|
||||
</el-card>
|
||||
</el-col>
|
||||
|
||||
<el-col :xs="24" :sm="24" :md="12" :lg="12">
|
||||
<TestPanel
|
||||
:test-fn="handleTest"
|
||||
:can-test="!!currentConfig.provider"
|
||||
title="LLM 连接测试"
|
||||
input-label="测试提示词"
|
||||
input-placeholder="请输入测试提示词(可选,默认使用系统预设提示词)"
|
||||
:show-token-stats="true"
|
||||
/>
|
||||
</el-col>
|
||||
</el-row>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { ElMessage, ElMessageBox } from 'element-plus'
|
||||
import { Cpu, Connection, InfoFilled, Box, RefreshLeft, Check, Clock } from '@element-plus/icons-vue'
|
||||
import { useLLMStore } from '@/stores/llm'
|
||||
import ProviderSelect from '@/components/common/ProviderSelect.vue'
|
||||
import ConfigForm from '@/components/common/ConfigForm.vue'
|
||||
import TestPanel from '@/components/common/TestPanel.vue'
|
||||
import type { TestResult } from '@/components/common/TestPanel.vue'
|
||||
|
||||
const llmStore = useLLMStore()
|
||||
|
||||
const configFormRef = ref<InstanceType<typeof ConfigForm>>()
|
||||
const saving = ref(false)
|
||||
const pageLoading = ref(false)
|
||||
|
||||
const providers = computed(() => llmStore.providers)
|
||||
const currentConfig = computed(() => llmStore.currentConfig)
|
||||
const currentProvider = computed(() => llmStore.currentProvider)
|
||||
const configSchema = computed(() => llmStore.configSchema)
|
||||
const providersLoading = computed(() => llmStore.providersLoading)
|
||||
|
||||
const formatDate = (dateStr: string) => {
|
||||
if (!dateStr) return ''
|
||||
const date = new Date(dateStr)
|
||||
return date.toLocaleString('zh-CN', {
|
||||
year: 'numeric',
|
||||
month: '2-digit',
|
||||
day: '2-digit',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit'
|
||||
})
|
||||
}
|
||||
|
||||
const handleProviderChange = (provider: any) => {
|
||||
if (provider?.name) {
|
||||
llmStore.setProvider(provider.name)
|
||||
}
|
||||
}
|
||||
|
||||
const handleSave = async () => {
|
||||
if (!currentConfig.value.provider) {
|
||||
ElMessage.warning('请先选择 LLM 提供者')
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const valid = await configFormRef.value?.validate()
|
||||
if (!valid) {
|
||||
return
|
||||
}
|
||||
} catch (error) {
|
||||
ElMessage.warning('请检查配置表单中的必填项')
|
||||
return
|
||||
}
|
||||
|
||||
saving.value = true
|
||||
try {
|
||||
await llmStore.saveCurrentConfig()
|
||||
ElMessage.success('配置保存成功')
|
||||
} catch (error) {
|
||||
ElMessage.error('配置保存失败')
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const handleReset = async () => {
|
||||
try {
|
||||
await ElMessageBox.confirm('确定要重置配置吗?将恢复为当前保存的配置。', '确认重置', {
|
||||
confirmButtonText: '确定',
|
||||
cancelButtonText: '取消',
|
||||
type: 'warning'
|
||||
})
|
||||
await llmStore.loadConfig()
|
||||
ElMessage.success('配置已重置')
|
||||
} catch (error) {
|
||||
}
|
||||
}
|
||||
|
||||
const handleTest = async (input?: string): Promise<TestResult> => {
|
||||
return await llmStore.runTest(input)
|
||||
}
|
||||
|
||||
const initPage = async () => {
|
||||
pageLoading.value = true
|
||||
try {
|
||||
await Promise.all([
|
||||
llmStore.loadProviders(),
|
||||
llmStore.loadConfig()
|
||||
])
|
||||
} catch (error) {
|
||||
ElMessage.error('初始化页面失败')
|
||||
} finally {
|
||||
pageLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
initPage()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.llm-config-page {
|
||||
padding: 24px;
|
||||
min-height: calc(100vh - 60px);
|
||||
}
|
||||
|
||||
.page-header {
|
||||
margin-bottom: 24px;
|
||||
animation: slideDown 0.4s ease-out;
|
||||
}
|
||||
|
||||
@keyframes slideDown {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-16px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.header-content {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: flex-start;
|
||||
gap: 20px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.title-section {
|
||||
flex: 1;
|
||||
min-width: 300px;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
margin: 0 0 8px 0;
|
||||
font-size: 24px;
|
||||
font-weight: 700;
|
||||
color: var(--text-primary);
|
||||
letter-spacing: -0.5px;
|
||||
}
|
||||
|
||||
.page-desc {
|
||||
margin: 0;
|
||||
font-size: 14px;
|
||||
color: var(--text-secondary);
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.header-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.update-info {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 8px 14px;
|
||||
background-color: var(--bg-tertiary);
|
||||
border-radius: 8px;
|
||||
font-size: 13px;
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
.update-icon {
|
||||
font-size: 14px;
|
||||
color: var(--text-tertiary);
|
||||
}
|
||||
|
||||
.config-card {
|
||||
animation: fadeInUp 0.5s ease-out;
|
||||
}
|
||||
|
||||
@keyframes fadeInUp {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.card-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.header-left {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.icon-wrapper {
|
||||
width: 36px;
|
||||
height: 36px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background-color: var(--primary-lighter);
|
||||
border-radius: 10px;
|
||||
color: var(--primary-color);
|
||||
font-size: 18px;
|
||||
}
|
||||
|
||||
.icon-wrapper.llm-icon {
|
||||
background: linear-gradient(135deg, #E0F2FE 0%, #BAE6FD 100%);
|
||||
color: #0284C7;
|
||||
}
|
||||
|
||||
.header-title {
|
||||
font-size: 15px;
|
||||
font-weight: 600;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.card-content {
|
||||
padding: 8px 0;
|
||||
}
|
||||
|
||||
.provider-select-section {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.section-label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
margin-bottom: 12px;
|
||||
font-size: 13px;
|
||||
font-weight: 600;
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
.section-label .el-icon {
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
.provider-info {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 10px;
|
||||
margin-top: 14px;
|
||||
padding: 14px 16px;
|
||||
background-color: var(--bg-tertiary);
|
||||
border-radius: 10px;
|
||||
font-size: 13px;
|
||||
color: var(--text-secondary);
|
||||
line-height: 1.6;
|
||||
border-left: 3px solid var(--primary-color);
|
||||
}
|
||||
|
||||
.info-icon {
|
||||
margin-top: 2px;
|
||||
color: var(--primary-color);
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.info-text {
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.config-form-section {
|
||||
max-height: 400px;
|
||||
overflow-y: auto;
|
||||
padding-right: 8px;
|
||||
}
|
||||
|
||||
.config-form-section::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
}
|
||||
|
||||
.config-form-section::-webkit-scrollbar-track {
|
||||
background: var(--bg-tertiary);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.config-form-section::-webkit-scrollbar-thumb {
|
||||
background: var(--text-tertiary);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.config-form-section::-webkit-scrollbar-thumb:hover {
|
||||
background: var(--text-secondary);
|
||||
}
|
||||
|
||||
.card-footer {
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
gap: 12px;
|
||||
padding-top: 8px;
|
||||
}
|
||||
|
||||
.empty-icon {
|
||||
width: 100px;
|
||||
height: 100px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background-color: var(--bg-tertiary);
|
||||
border-radius: 50%;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
.empty-icon .el-icon {
|
||||
font-size: 48px;
|
||||
color: var(--text-tertiary);
|
||||
}
|
||||
|
||||
.fade-enter-active,
|
||||
.fade-leave-active {
|
||||
transition: opacity 0.25s ease;
|
||||
}
|
||||
|
||||
.fade-enter-from,
|
||||
.fade-leave-to {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
.slide-fade-enter-active {
|
||||
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
}
|
||||
|
||||
.slide-fade-leave-active {
|
||||
transition: all 0.2s cubic-bezier(1, 0.5, 0.8, 1);
|
||||
}
|
||||
|
||||
.slide-fade-enter-from {
|
||||
opacity: 0;
|
||||
transform: translateX(-16px);
|
||||
}
|
||||
|
||||
.slide-fade-leave-to {
|
||||
opacity: 0;
|
||||
transform: translateX(16px);
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.llm-config-page {
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
.header-content {
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.title-section {
|
||||
min-width: 100%;
|
||||
}
|
||||
|
||||
.config-form-section {
|
||||
max-height: 300px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,720 @@
|
|||
<template>
|
||||
<div class="dashboard-page">
|
||||
<div class="page-header">
|
||||
<h1 class="page-title">控制台</h1>
|
||||
<p class="page-desc">系统概览与数据统计</p>
|
||||
</div>
|
||||
|
||||
<el-row :gutter="20" v-loading="loading">
|
||||
<el-col :xs="12" :sm="12" :md="6" :lg="6">
|
||||
<el-card shadow="hover" class="stat-card">
|
||||
<div class="stat-content">
|
||||
<div class="stat-icon primary">
|
||||
<el-icon><FolderOpened /></el-icon>
|
||||
</div>
|
||||
<div class="stat-info">
|
||||
<span class="stat-value">{{ stats.knowledgeBases }}</span>
|
||||
<span class="stat-label">知识库总数</span>
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</el-col>
|
||||
<el-col :xs="12" :sm="12" :md="6" :lg="6">
|
||||
<el-card shadow="hover" class="stat-card">
|
||||
<div class="stat-content">
|
||||
<div class="stat-icon success">
|
||||
<el-icon><Document /></el-icon>
|
||||
</div>
|
||||
<div class="stat-info">
|
||||
<span class="stat-value">{{ stats.totalDocuments }}</span>
|
||||
<span class="stat-label">文档总数</span>
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</el-col>
|
||||
<el-col :xs="12" :sm="12" :md="6" :lg="6">
|
||||
<el-card shadow="hover" class="stat-card">
|
||||
<div class="stat-content">
|
||||
<div class="stat-icon warning">
|
||||
<el-icon><ChatDotSquare /></el-icon>
|
||||
</div>
|
||||
<div class="stat-info">
|
||||
<span class="stat-value">{{ stats.totalMessages.toLocaleString() }}</span>
|
||||
<span class="stat-label">总消息数</span>
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</el-col>
|
||||
<el-col :xs="12" :sm="12" :md="6" :lg="6">
|
||||
<el-card shadow="hover" class="stat-card">
|
||||
<div class="stat-content">
|
||||
<div class="stat-icon info">
|
||||
<el-icon><Monitor /></el-icon>
|
||||
</div>
|
||||
<div class="stat-info">
|
||||
<span class="stat-value">{{ stats.totalSessions }}</span>
|
||||
<span class="stat-label">会话总数</span>
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</el-col>
|
||||
</el-row>
|
||||
|
||||
<el-row :gutter="20" style="margin-top: 20px;">
|
||||
<el-col :xs="24" :sm="24" :md="8" :lg="8">
|
||||
<el-card shadow="hover" class="metric-card">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<div class="header-left">
|
||||
<div class="icon-wrapper primary">
|
||||
<el-icon><Cpu /></el-icon>
|
||||
</div>
|
||||
<span class="header-title">Token 消耗统计</span>
|
||||
</div>
|
||||
<el-tag type="primary" size="small" effect="plain">实时</el-tag>
|
||||
</div>
|
||||
</template>
|
||||
<div class="metric-content">
|
||||
<div class="metric-main">
|
||||
<span class="metric-value primary">{{ formatNumber(stats.totalTokens) }}</span>
|
||||
<span class="metric-label">总消耗</span>
|
||||
</div>
|
||||
<div class="metric-divider"></div>
|
||||
<div class="metric-detail">
|
||||
<div class="detail-item">
|
||||
<span class="detail-label">输入</span>
|
||||
<span class="detail-value">{{ formatNumber(stats.promptTokens) }}</span>
|
||||
</div>
|
||||
<div class="detail-item">
|
||||
<span class="detail-label">输出</span>
|
||||
<span class="detail-value">{{ formatNumber(stats.completionTokens) }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</el-col>
|
||||
<el-col :xs="24" :sm="24" :md="8" :lg="8">
|
||||
<el-card shadow="hover" class="metric-card">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<div class="header-left">
|
||||
<div class="icon-wrapper success">
|
||||
<el-icon><Timer /></el-icon>
|
||||
</div>
|
||||
<span class="header-title">响应时间统计</span>
|
||||
</div>
|
||||
<el-tag :type="stats.slowRequestsCount > 0 ? 'warning' : 'success'" size="small" effect="plain">
|
||||
{{ stats.slowRequestsCount }} 次超时
|
||||
</el-tag>
|
||||
</div>
|
||||
</template>
|
||||
<div class="metric-content">
|
||||
<div class="latency-grid">
|
||||
<div class="latency-item">
|
||||
<span class="metric-value success">{{ formatLatency(stats.avgLatencyMs) }}</span>
|
||||
<span class="metric-label">平均耗时</span>
|
||||
</div>
|
||||
<div class="latency-item">
|
||||
<span class="metric-value">{{ formatLatency(stats.lastLatencyMs) }}</span>
|
||||
<span class="metric-label">上次耗时</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="metric-divider"></div>
|
||||
<div class="latency-stats">
|
||||
<div class="stat-row">
|
||||
<div class="stat-col">
|
||||
<span class="stat-label">P95</span>
|
||||
<span class="stat-value">{{ formatLatency(stats.p95LatencyMs) }}</span>
|
||||
</div>
|
||||
<div class="stat-col">
|
||||
<span class="stat-label">P99</span>
|
||||
<span class="stat-value">{{ formatLatency(stats.p99LatencyMs) }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="stat-row">
|
||||
<div class="stat-col">
|
||||
<span class="stat-label">最小</span>
|
||||
<span class="stat-value">{{ formatLatency(stats.minLatencyMs) }}</span>
|
||||
</div>
|
||||
<div class="stat-col">
|
||||
<span class="stat-label">最大</span>
|
||||
<span class="stat-value">{{ formatLatency(stats.maxLatencyMs) }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</el-col>
|
||||
<el-col :xs="24" :sm="24" :md="8" :lg="8">
|
||||
<el-card shadow="hover" class="metric-card">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<div class="header-left">
|
||||
<div class="icon-wrapper warning">
|
||||
<el-icon><DataLine /></el-icon>
|
||||
</div>
|
||||
<span class="header-title">请求统计</span>
|
||||
</div>
|
||||
<el-tag type="info" size="small" effect="plain">阈值 {{ stats.latencyThresholdMs }}ms</el-tag>
|
||||
</div>
|
||||
</template>
|
||||
<div class="metric-content">
|
||||
<div class="requests-grid">
|
||||
<div class="request-item">
|
||||
<span class="metric-value primary">{{ stats.aiRequestsCount }}</span>
|
||||
<span class="metric-label">AI 请求</span>
|
||||
</div>
|
||||
<div class="request-item">
|
||||
<span class="metric-value" :class="{ danger: stats.errorRequestsCount > 0 }">{{ stats.errorRequestsCount }}</span>
|
||||
<span class="metric-label">错误</span>
|
||||
</div>
|
||||
<div class="request-item">
|
||||
<span class="metric-value" :class="{ warning: stats.slowRequestsCount > 0 }">{{ stats.slowRequestsCount }}</span>
|
||||
<span class="metric-label">超时</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="metric-divider"></div>
|
||||
<div class="rate-stats">
|
||||
<div class="rate-item">
|
||||
<span class="rate-label">错误率</span>
|
||||
<span class="rate-value" :class="{ danger: parseFloat(errorRate) > 5 }">{{ errorRate }}%</span>
|
||||
</div>
|
||||
<div class="rate-item">
|
||||
<span class="rate-label">超时率</span>
|
||||
<span class="rate-value" :class="{ warning: parseFloat(timeoutRate) > 10 }">{{ timeoutRate }}%</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</el-col>
|
||||
</el-row>
|
||||
|
||||
<el-row :gutter="20" style="margin-top: 20px;">
|
||||
<el-col :span="24">
|
||||
<el-card shadow="hover">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<div class="header-left">
|
||||
<div class="icon-wrapper info">
|
||||
<el-icon><InfoFilled /></el-icon>
|
||||
</div>
|
||||
<span class="header-title">使用说明</span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
<div class="help-content">
|
||||
<div class="help-item">
|
||||
<div class="help-icon primary">
|
||||
<el-icon><FolderOpened /></el-icon>
|
||||
</div>
|
||||
<div class="help-text">
|
||||
<h4>知识库管理</h4>
|
||||
<p>上传文档并建立向量索引,支持 PDF、Word、TXT 等格式。</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="help-item">
|
||||
<div class="help-icon success">
|
||||
<el-icon><Cpu /></el-icon>
|
||||
</div>
|
||||
<div class="help-text">
|
||||
<h4>RAG 实验室</h4>
|
||||
<p>测试检索增强生成效果,查看检索结果和 AI 响应。</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="help-item">
|
||||
<div class="help-icon warning">
|
||||
<el-icon><Connection /></el-icon>
|
||||
</div>
|
||||
<div class="help-text">
|
||||
<h4>嵌入模型配置</h4>
|
||||
<p>配置文本嵌入模型,用于文档向量化。</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="help-item">
|
||||
<div class="help-icon info">
|
||||
<el-icon><ChatDotSquare /></el-icon>
|
||||
</div>
|
||||
<div class="help-text">
|
||||
<h4>LLM 模型配置</h4>
|
||||
<p>配置大语言模型,支持 OpenAI、DeepSeek、Ollama 等。</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</el-col>
|
||||
</el-row>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, reactive, computed, onMounted } from 'vue'
|
||||
import { FolderOpened, Document, ChatDotSquare, Monitor, Cpu, InfoFilled, Connection, Timer, DataLine } from '@element-plus/icons-vue'
|
||||
import { getDashboardStats } from '@/api/dashboard'
|
||||
|
||||
const loading = ref(false)
|
||||
const stats = reactive({
|
||||
knowledgeBases: 0,
|
||||
totalDocuments: 0,
|
||||
totalMessages: 0,
|
||||
totalSessions: 0,
|
||||
totalTokens: 0,
|
||||
promptTokens: 0,
|
||||
completionTokens: 0,
|
||||
aiRequestsCount: 0,
|
||||
avgLatencyMs: 0,
|
||||
lastLatencyMs: 0,
|
||||
slowRequestsCount: 0,
|
||||
errorRequestsCount: 0,
|
||||
p95LatencyMs: 0,
|
||||
p99LatencyMs: 0,
|
||||
minLatencyMs: 0,
|
||||
maxLatencyMs: 0,
|
||||
latencyThresholdMs: 5000
|
||||
})
|
||||
|
||||
const errorRate = computed(() => {
|
||||
if (stats.aiRequestsCount === 0) return '0.00'
|
||||
return ((stats.errorRequestsCount / stats.aiRequestsCount) * 100).toFixed(2)
|
||||
})
|
||||
|
||||
const timeoutRate = computed(() => {
|
||||
if (stats.aiRequestsCount === 0) return '0.00'
|
||||
return ((stats.slowRequestsCount / stats.aiRequestsCount) * 100).toFixed(2)
|
||||
})
|
||||
|
||||
const formatNumber = (num: number) => {
|
||||
if (num >= 1000000) {
|
||||
return (num / 1000000).toFixed(2) + 'M'
|
||||
} else if (num >= 1000) {
|
||||
return (num / 1000).toFixed(1) + 'K'
|
||||
}
|
||||
return num.toString()
|
||||
}
|
||||
|
||||
const formatLatency = (ms: number | null | undefined) => {
|
||||
if (ms === null || ms === undefined) return '-'
|
||||
if (ms >= 1000) {
|
||||
return (ms / 1000).toFixed(2) + 's'
|
||||
}
|
||||
return ms.toFixed(0) + 'ms'
|
||||
}
|
||||
|
||||
const fetchStats = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const res: any = await getDashboardStats()
|
||||
stats.knowledgeBases = res.knowledgeBases || 0
|
||||
stats.totalDocuments = res.totalDocuments || 0
|
||||
stats.totalMessages = res.totalMessages || 0
|
||||
stats.totalSessions = res.totalSessions || 0
|
||||
stats.totalTokens = res.totalTokens || 0
|
||||
stats.promptTokens = res.promptTokens || 0
|
||||
stats.completionTokens = res.completionTokens || 0
|
||||
stats.aiRequestsCount = res.aiRequestsCount || 0
|
||||
stats.avgLatencyMs = res.avgLatencyMs || 0
|
||||
stats.lastLatencyMs = res.lastLatencyMs || 0
|
||||
stats.slowRequestsCount = res.slowRequestsCount || 0
|
||||
stats.errorRequestsCount = res.errorRequestsCount || 0
|
||||
stats.p95LatencyMs = res.p95LatencyMs || 0
|
||||
stats.p99LatencyMs = res.p99LatencyMs || 0
|
||||
stats.minLatencyMs = res.minLatencyMs || 0
|
||||
stats.maxLatencyMs = res.maxLatencyMs || 0
|
||||
stats.latencyThresholdMs = res.latencyThresholdMs || 5000
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch dashboard stats:', error)
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
fetchStats()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.dashboard-page {
|
||||
padding: 24px;
|
||||
min-height: calc(100vh - 60px);
|
||||
}
|
||||
|
||||
.page-header {
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
margin: 0 0 8px 0;
|
||||
font-size: 24px;
|
||||
font-weight: 700;
|
||||
color: var(--text-primary);
|
||||
letter-spacing: -0.5px;
|
||||
}
|
||||
|
||||
.page-desc {
|
||||
margin: 0;
|
||||
font-size: 14px;
|
||||
color: var(--text-secondary);
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.stat-card {
|
||||
animation: fadeInUp 0.5s ease-out;
|
||||
}
|
||||
|
||||
.stat-card:nth-child(1) { animation-delay: 0s; }
|
||||
.stat-card:nth-child(2) { animation-delay: 0.1s; }
|
||||
.stat-card:nth-child(3) { animation-delay: 0.2s; }
|
||||
.stat-card:nth-child(4) { animation-delay: 0.3s; }
|
||||
|
||||
@keyframes fadeInUp {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.stat-content {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
.stat-icon {
|
||||
width: 48px;
|
||||
height: 48px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
border-radius: 12px;
|
||||
font-size: 22px;
|
||||
}
|
||||
|
||||
.stat-icon.primary {
|
||||
background-color: var(--primary-lighter);
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
.stat-icon.success {
|
||||
background-color: #D1FAE5;
|
||||
color: #059669;
|
||||
}
|
||||
|
||||
.stat-icon.warning {
|
||||
background-color: #FEF3C7;
|
||||
color: #D97706;
|
||||
}
|
||||
|
||||
.stat-icon.info {
|
||||
background-color: #E0E7FF;
|
||||
color: #4F46E5;
|
||||
}
|
||||
|
||||
.stat-info {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.stat-value {
|
||||
font-size: 28px;
|
||||
font-weight: 700;
|
||||
color: var(--text-primary);
|
||||
line-height: 1.2;
|
||||
}
|
||||
|
||||
.stat-label {
|
||||
font-size: 13px;
|
||||
color: var(--text-secondary);
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
.metric-card {
|
||||
animation: fadeInUp 0.6s ease-out;
|
||||
}
|
||||
|
||||
.card-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.header-left {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.icon-wrapper {
|
||||
width: 36px;
|
||||
height: 36px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
border-radius: 10px;
|
||||
font-size: 18px;
|
||||
}
|
||||
|
||||
.icon-wrapper.primary {
|
||||
background-color: var(--primary-lighter);
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
.icon-wrapper.success {
|
||||
background-color: #D1FAE5;
|
||||
color: #059669;
|
||||
}
|
||||
|
||||
.icon-wrapper.warning {
|
||||
background-color: #FEF3C7;
|
||||
color: #D97706;
|
||||
}
|
||||
|
||||
.icon-wrapper.info {
|
||||
background-color: #E0E7FF;
|
||||
color: #4F46E5;
|
||||
}
|
||||
|
||||
.header-title {
|
||||
font-size: 15px;
|
||||
font-weight: 600;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.metric-content {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
.metric-main {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
padding: 8px 0;
|
||||
}
|
||||
|
||||
.metric-value {
|
||||
font-size: 32px;
|
||||
font-weight: 700;
|
||||
color: var(--text-primary);
|
||||
line-height: 1.2;
|
||||
}
|
||||
|
||||
.metric-value.primary {
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
.metric-value.success {
|
||||
color: #059669;
|
||||
}
|
||||
|
||||
.metric-value.warning {
|
||||
color: #D97706;
|
||||
}
|
||||
|
||||
.metric-value.danger {
|
||||
color: #DC2626;
|
||||
}
|
||||
|
||||
.metric-label {
|
||||
font-size: 13px;
|
||||
color: var(--text-secondary);
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
.metric-divider {
|
||||
height: 1px;
|
||||
background-color: var(--border-light);
|
||||
}
|
||||
|
||||
.metric-detail {
|
||||
display: flex;
|
||||
justify-content: space-around;
|
||||
}
|
||||
|
||||
.detail-item {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.detail-label {
|
||||
font-size: 12px;
|
||||
color: var(--text-tertiary);
|
||||
}
|
||||
|
||||
.detail-value {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.latency-grid {
|
||||
display: flex;
|
||||
justify-content: space-around;
|
||||
padding: 8px 0;
|
||||
}
|
||||
|
||||
.latency-item {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.latency-stats {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.stat-row {
|
||||
display: flex;
|
||||
justify-content: space-around;
|
||||
}
|
||||
|
||||
.stat-col {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 2px;
|
||||
}
|
||||
|
||||
.stat-label {
|
||||
font-size: 12px;
|
||||
color: var(--text-tertiary);
|
||||
}
|
||||
|
||||
.requests-grid {
|
||||
display: flex;
|
||||
justify-content: space-around;
|
||||
padding: 8px 0;
|
||||
}
|
||||
|
||||
.request-item {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.rate-stats {
|
||||
display: flex;
|
||||
justify-content: space-around;
|
||||
}
|
||||
|
||||
.rate-item {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.rate-label {
|
||||
font-size: 12px;
|
||||
color: var(--text-tertiary);
|
||||
}
|
||||
|
||||
.rate-value {
|
||||
font-size: 18px;
|
||||
font-weight: 600;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.rate-value.warning {
|
||||
color: #D97706;
|
||||
}
|
||||
|
||||
.rate-value.danger {
|
||||
color: #DC2626;
|
||||
}
|
||||
|
||||
.help-content {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(2, 1fr);
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
.help-item {
|
||||
display: flex;
|
||||
gap: 14px;
|
||||
padding: 16px;
|
||||
background-color: var(--bg-tertiary);
|
||||
border-radius: 12px;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.help-item:hover {
|
||||
background-color: var(--primary-lighter);
|
||||
}
|
||||
|
||||
.help-icon {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
border-radius: 10px;
|
||||
font-size: 18px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.help-icon.primary {
|
||||
background-color: var(--primary-lighter);
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
.help-icon.success {
|
||||
background-color: #D1FAE5;
|
||||
color: #059669;
|
||||
}
|
||||
|
||||
.help-icon.warning {
|
||||
background-color: #FEF3C7;
|
||||
color: #D97706;
|
||||
}
|
||||
|
||||
.help-icon.info {
|
||||
background-color: #E0E7FF;
|
||||
color: #4F46E5;
|
||||
}
|
||||
|
||||
.help-text h4 {
|
||||
margin: 0 0 4px 0;
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.help-text p {
|
||||
margin: 0;
|
||||
font-size: 13px;
|
||||
color: var(--text-secondary);
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.dashboard-page {
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
.stat-value {
|
||||
font-size: 24px;
|
||||
}
|
||||
|
||||
.help-content {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
|
||||
.metric-value {
|
||||
font-size: 28px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,379 @@
|
|||
<template>
|
||||
<div class="kb-page">
|
||||
<div class="page-header">
|
||||
<div class="header-content">
|
||||
<div class="title-section">
|
||||
<h1 class="page-title">知识库管理</h1>
|
||||
<p class="page-desc">上传文档并建立向量索引,支持多种文档格式。</p>
|
||||
</div>
|
||||
<div class="header-actions">
|
||||
<el-button type="primary" @click="handleUploadClick">
|
||||
<el-icon><Upload /></el-icon>
|
||||
上传文档
|
||||
</el-button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<el-card shadow="hover" class="table-card">
|
||||
<el-table v-loading="loading" :data="tableData" style="width: 100%">
|
||||
<el-table-column prop="name" label="文件名" min-width="200">
|
||||
<template #default="scope">
|
||||
<div class="file-name">
|
||||
<el-icon class="file-icon"><Document /></el-icon>
|
||||
<span>{{ scope.row.name }}</span>
|
||||
</div>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="status" label="状态" width="120">
|
||||
<template #default="scope">
|
||||
<el-tag :type="getStatusType(scope.row.status)" size="small">
|
||||
{{ getStatusText(scope.row.status) }}
|
||||
</el-tag>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="jobId" label="任务ID" width="180">
|
||||
<template #default="scope">
|
||||
<span class="job-id">{{ scope.row.jobId || '-' }}</span>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="createTime" label="上传时间" width="180" />
|
||||
<el-table-column label="操作" width="160" fixed="right">
|
||||
<template #default="scope">
|
||||
<el-button link type="primary" @click="handleViewJob(scope.row)">
|
||||
<el-icon><View /></el-icon>
|
||||
详情
|
||||
</el-button>
|
||||
<el-button link type="danger" @click="handleDelete(scope.row)">
|
||||
<el-icon><Delete /></el-icon>
|
||||
删除
|
||||
</el-button>
|
||||
</template>
|
||||
</el-table-column>
|
||||
</el-table>
|
||||
</el-card>
|
||||
|
||||
<el-dialog v-model="jobDialogVisible" title="索引任务详情" width="500px" class="job-dialog">
|
||||
<el-descriptions :column="1" border v-if="currentJob">
|
||||
<el-descriptions-item label="任务ID">
|
||||
<span class="job-id">{{ currentJob.jobId }}</span>
|
||||
</el-descriptions-item>
|
||||
<el-descriptions-item label="状态">
|
||||
<el-tag :type="getStatusType(currentJob.status)" size="small">
|
||||
{{ getStatusText(currentJob.status) }}
|
||||
</el-tag>
|
||||
</el-descriptions-item>
|
||||
<el-descriptions-item label="进度">
|
||||
<div class="progress-wrapper">
|
||||
<el-progress :percentage="currentJob.progress" :status="getProgressStatus(currentJob.status)" />
|
||||
</div>
|
||||
</el-descriptions-item>
|
||||
<el-descriptions-item label="错误信息" v-if="currentJob.errorMsg">
|
||||
<el-alert type="error" :closable="false">{{ currentJob.errorMsg }}</el-alert>
|
||||
</el-descriptions-item>
|
||||
</el-descriptions>
|
||||
<template #footer>
|
||||
<el-button @click="jobDialogVisible = false">关闭</el-button>
|
||||
<el-button
|
||||
v-if="currentJob?.status === 'pending' || currentJob?.status === 'processing'"
|
||||
type="primary"
|
||||
@click="refreshJobStatus"
|
||||
>
|
||||
刷新状态
|
||||
</el-button>
|
||||
</template>
|
||||
</el-dialog>
|
||||
|
||||
<input ref="fileInput" type="file" style="display: none" @change="handleFileChange" />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
import { ElMessage, ElMessageBox } from 'element-plus'
|
||||
import { Upload, Document, View, Delete } from '@element-plus/icons-vue'
|
||||
import { uploadDocument, listDocuments, getIndexJob, deleteDocument } from '@/api/kb'
|
||||
|
||||
interface DocumentItem {
|
||||
docId: string
|
||||
name: string
|
||||
status: string
|
||||
jobId: string
|
||||
createTime: string
|
||||
}
|
||||
|
||||
const tableData = ref<DocumentItem[]>([])
|
||||
const loading = ref(false)
|
||||
const jobDialogVisible = ref(false)
|
||||
const currentJob = ref<any>(null)
|
||||
const pollingJobs = ref<Set<string>>(new Set())
|
||||
let pollingInterval: number | null = null
|
||||
|
||||
const getStatusType = (status: string) => {
|
||||
const typeMap: Record<string, string> = {
|
||||
completed: 'success',
|
||||
processing: 'warning',
|
||||
pending: 'info',
|
||||
failed: 'danger'
|
||||
}
|
||||
return typeMap[status] || 'info'
|
||||
}
|
||||
|
||||
const getStatusText = (status: string) => {
|
||||
const textMap: Record<string, string> = {
|
||||
completed: '已完成',
|
||||
processing: '处理中',
|
||||
pending: '等待中',
|
||||
failed: '失败'
|
||||
}
|
||||
return textMap[status] || status
|
||||
}
|
||||
|
||||
const getProgressStatus = (status: string) => {
|
||||
if (status === 'completed') return 'success'
|
||||
if (status === 'failed') return 'exception'
|
||||
return undefined
|
||||
}
|
||||
|
||||
const fetchDocuments = async () => {
|
||||
try {
|
||||
const res = await listDocuments({})
|
||||
tableData.value = res.data.map((doc: any) => ({
|
||||
docId: doc.docId,
|
||||
name: doc.fileName,
|
||||
status: doc.status,
|
||||
jobId: doc.jobId,
|
||||
createTime: new Date(doc.createdAt).toLocaleString('zh-CN')
|
||||
}))
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch documents:', error)
|
||||
}
|
||||
}
|
||||
|
||||
const fetchJobStatus = async (jobId: string) => {
|
||||
try {
|
||||
const res = await getIndexJob(jobId)
|
||||
return res
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch job status:', error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const handleViewJob = async (row: DocumentItem) => {
|
||||
if (!row.jobId) {
|
||||
ElMessage.warning('该文档没有任务ID')
|
||||
return
|
||||
}
|
||||
currentJob.value = await fetchJobStatus(row.jobId)
|
||||
jobDialogVisible.value = true
|
||||
}
|
||||
|
||||
const handleDelete = async (row: DocumentItem) => {
|
||||
try {
|
||||
await ElMessageBox.confirm(`确定要删除文档 "${row.name}" 吗?`, '确认删除', {
|
||||
confirmButtonText: '确定',
|
||||
cancelButtonText: '取消',
|
||||
type: 'warning'
|
||||
})
|
||||
|
||||
await deleteDocument(row.docId)
|
||||
ElMessage.success('文档删除成功')
|
||||
fetchDocuments()
|
||||
} catch (error: any) {
|
||||
if (error !== 'cancel') {
|
||||
ElMessage.error(error.message || '删除失败')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const refreshJobStatus = async () => {
|
||||
if (currentJob.value?.jobId) {
|
||||
currentJob.value = await fetchJobStatus(currentJob.value.jobId)
|
||||
}
|
||||
}
|
||||
|
||||
const startPolling = (jobId: string) => {
|
||||
pollingJobs.value.add(jobId)
|
||||
if (!pollingInterval) {
|
||||
pollingInterval = window.setInterval(async () => {
|
||||
for (const jobId of pollingJobs.value) {
|
||||
const job = await fetchJobStatus(jobId)
|
||||
if (job) {
|
||||
if (job.status === 'completed') {
|
||||
pollingJobs.value.delete(jobId)
|
||||
ElMessage.success('文档索引任务已完成')
|
||||
fetchDocuments()
|
||||
} else if (job.status === 'failed') {
|
||||
pollingJobs.value.delete(jobId)
|
||||
ElMessage.error('文档索引任务失败')
|
||||
ElMessage.warning(`错误: ${job.errorMsg}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (pollingJobs.value.size === 0 && pollingInterval) {
|
||||
clearInterval(pollingInterval)
|
||||
pollingInterval = null
|
||||
}
|
||||
}, 3000)
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
fetchDocuments()
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
if (pollingInterval) {
|
||||
clearInterval(pollingInterval)
|
||||
}
|
||||
})
|
||||
|
||||
const fileInput = ref<HTMLInputElement | null>(null)
|
||||
|
||||
const handleUploadClick = () => {
|
||||
fileInput.value?.click()
|
||||
}
|
||||
|
||||
const handleFileChange = async (event: Event) => {
|
||||
const target = event.target as HTMLInputElement
|
||||
const file = target.files?.[0]
|
||||
if (!file) return
|
||||
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
formData.append('kb_id', 'kb_default')
|
||||
|
||||
try {
|
||||
loading.value = true
|
||||
const res = await uploadDocument(formData)
|
||||
ElMessage.success(`文档上传成功!任务ID: ${res.jobId}`)
|
||||
console.log('Upload response:', res)
|
||||
|
||||
const newDoc: DocumentItem = {
|
||||
name: file.name,
|
||||
status: res.status || 'pending',
|
||||
jobId: res.jobId,
|
||||
createTime: new Date().toLocaleString('zh-CN')
|
||||
}
|
||||
tableData.value.unshift(newDoc)
|
||||
|
||||
startPolling(res.jobId)
|
||||
} catch (error) {
|
||||
ElMessage.error('文档上传失败')
|
||||
console.error('Upload error:', error)
|
||||
} finally {
|
||||
loading.value = false
|
||||
target.value = ''
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.kb-page {
|
||||
padding: 24px;
|
||||
min-height: calc(100vh - 60px);
|
||||
}
|
||||
|
||||
.page-header {
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
|
||||
.header-content {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: flex-start;
|
||||
gap: 20px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.title-section {
|
||||
flex: 1;
|
||||
min-width: 200px;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
margin: 0 0 8px 0;
|
||||
font-size: 24px;
|
||||
font-weight: 700;
|
||||
color: var(--text-primary);
|
||||
letter-spacing: -0.5px;
|
||||
}
|
||||
|
||||
.page-desc {
|
||||
margin: 0;
|
||||
font-size: 14px;
|
||||
color: var(--text-secondary);
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.header-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.table-card {
|
||||
animation: fadeInUp 0.5s ease-out;
|
||||
}
|
||||
|
||||
@keyframes fadeInUp {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.file-name {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.file-icon {
|
||||
color: var(--primary-color);
|
||||
font-size: 18px;
|
||||
}
|
||||
|
||||
.job-id {
|
||||
font-family: var(--font-mono);
|
||||
font-size: 12px;
|
||||
color: var(--text-secondary);
|
||||
background-color: var(--bg-tertiary);
|
||||
padding: 2px 8px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.progress-wrapper {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.job-dialog :deep(.el-dialog__header) {
|
||||
border-bottom: 1px solid var(--border-light);
|
||||
}
|
||||
|
||||
.job-dialog :deep(.el-dialog__body) {
|
||||
padding: 24px;
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.kb-page {
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
.header-content {
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.title-section {
|
||||
min-width: 100%;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,191 @@
|
|||
<template>
|
||||
<div class="monitoring-container">
|
||||
<el-card shadow="never">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<span class="title">会话监控 [AC-ASA-09]</span>
|
||||
<div class="header-ops">
|
||||
<el-form :inline="true" :model="queryParams" class="search-form">
|
||||
<el-form-item label="状态">
|
||||
<el-select v-model="queryParams.status" placeholder="会话状态" clearable style="width: 120px">
|
||||
<el-option label="活跃" value="active" />
|
||||
<el-option label="已关闭" value="closed" />
|
||||
<el-option label="已过期" value="expired" />
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
<el-form-item>
|
||||
<el-button type="primary" @click="handleQuery">查询</el-button>
|
||||
<el-button @click="resetQuery">重置</el-button>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<base-table
|
||||
:data="tableData"
|
||||
:total="total"
|
||||
v-model:page-num="queryParams.page"
|
||||
v-model:page-size="queryParams.pageSize"
|
||||
@pagination="getList"
|
||||
v-loading="loading"
|
||||
>
|
||||
<el-table-column prop="sessionId" label="会话 ID" width="280" show-overflow-tooltip />
|
||||
<el-table-column prop="tenantId" label="租户 ID" width="280" show-overflow-tooltip />
|
||||
<el-table-column prop="messageCount" label="消息数" width="100" align="center" />
|
||||
<el-table-column prop="status" label="状态" width="100" align="center">
|
||||
<template #default="scope">
|
||||
<el-tag :type="statusMap[scope.row.status]?.type" size="small">
|
||||
{{ statusMap[scope.row.status]?.label || scope.row.status }}
|
||||
</el-tag>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="channelType" label="渠道" width="100" />
|
||||
<el-table-column prop="startTime" label="开始时间" width="180" />
|
||||
<el-table-column label="操作" fixed="right" width="120" align="center">
|
||||
<template #default="scope">
|
||||
<el-button link type="primary" @click="handleTrace(scope.row)">全链路追踪</el-button>
|
||||
</template>
|
||||
</el-table-column>
|
||||
</base-table>
|
||||
</el-card>
|
||||
|
||||
<el-drawer
|
||||
v-model="drawerVisible"
|
||||
title="会话全链路追踪详情"
|
||||
size="50%"
|
||||
destroy-on-close
|
||||
>
|
||||
<div v-loading="detailLoading" class="detail-container">
|
||||
<el-empty v-if="!sessionDetail && !detailLoading" description="暂无追踪详情" />
|
||||
<div v-else>
|
||||
<el-descriptions :column="1" border class="session-info">
|
||||
<el-descriptions-item label="会话ID">{{ sessionDetail?.sessionId }}</el-descriptions-item>
|
||||
<el-descriptions-item label="消息数">{{ sessionDetail?.messages?.length || 0 }}</el-descriptions-item>
|
||||
</el-descriptions>
|
||||
|
||||
<el-divider content-position="left">消息记录</el-divider>
|
||||
|
||||
<el-timeline>
|
||||
<el-timeline-item
|
||||
v-for="(msg, index) in sessionDetail?.messages"
|
||||
:key="index"
|
||||
:timestamp="msg.timestamp"
|
||||
placement="top"
|
||||
:type="msg.role === 'user' ? 'primary' : 'success'"
|
||||
>
|
||||
<el-card shadow="never" class="msg-card">
|
||||
<div class="msg-header">
|
||||
<span class="role-tag" :class="msg.role">{{ msg.role === 'user' ? 'USER' : 'ASSISTANT' }}</span>
|
||||
</div>
|
||||
<div class="msg-content">{{ msg.content }}</div>
|
||||
</el-card>
|
||||
</el-timeline-item>
|
||||
</el-timeline>
|
||||
|
||||
<el-divider content-position="left" v-if="sessionDetail?.trace && (sessionDetail.trace.retrieval?.length || sessionDetail.trace.tools?.length)">
|
||||
追踪信息
|
||||
</el-divider>
|
||||
|
||||
<el-collapse v-if="sessionDetail?.trace">
|
||||
<el-collapse-item v-if="sessionDetail.trace.retrieval?.length" title="检索追踪 (Retrieval)" name="retrieval">
|
||||
<div v-for="(hit, hIdx) in sessionDetail.trace.retrieval" :key="hIdx" class="hit-item">
|
||||
<div class="hit-meta">
|
||||
<el-tag size="small" type="success">Score: {{ hit.score }}</el-tag>
|
||||
<span class="hit-source" v-if="hit.source">来源: {{ hit.source }}</span>
|
||||
</div>
|
||||
<div class="hit-text">{{ hit.content }}</div>
|
||||
</div>
|
||||
</el-collapse-item>
|
||||
<el-collapse-item v-if="sessionDetail.trace.tools?.length" title="工具调用 (Tool Calls)" name="tools">
|
||||
<pre class="code-block"><code>{{ JSON.stringify(sessionDetail.trace.tools, null, 2) }}</code></pre>
|
||||
</el-collapse-item>
|
||||
</el-collapse>
|
||||
</div>
|
||||
</div>
|
||||
</el-drawer>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, reactive, onMounted } from 'vue'
|
||||
import BaseTable from '@/components/BaseTable.vue'
|
||||
import { listSessions, getSessionDetail } from '@/api/monitoring'
|
||||
|
||||
const statusMap: Record<string, { label: string, type: string }> = {
|
||||
active: { label: '活跃', type: 'success' },
|
||||
closed: { label: '已关闭', type: 'info' },
|
||||
expired: { label: '已过期', type: 'warning' }
|
||||
}
|
||||
|
||||
const loading = ref(false)
|
||||
const tableData = ref([])
|
||||
const total = ref(0)
|
||||
const queryParams = reactive({
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
status: ''
|
||||
})
|
||||
|
||||
const drawerVisible = ref(false)
|
||||
const detailLoading = ref(false)
|
||||
const sessionDetail = ref<any>(null)
|
||||
|
||||
const getList = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const res: any = await listSessions(queryParams)
|
||||
tableData.value = res.data || []
|
||||
total.value = res.pagination?.total || 0
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch sessions:', error)
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const handleQuery = () => {
|
||||
queryParams.page = 1
|
||||
getList()
|
||||
}
|
||||
|
||||
const resetQuery = () => {
|
||||
queryParams.status = ''
|
||||
handleQuery()
|
||||
}
|
||||
|
||||
const handleTrace = async (row: any) => {
|
||||
drawerVisible.value = true
|
||||
detailLoading.value = true
|
||||
try {
|
||||
sessionDetail.value = await getSessionDetail(row.sessionId)
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch session detail:', error)
|
||||
} finally {
|
||||
detailLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
getList()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.monitoring-container { padding: 20px; }
|
||||
.card-header { display: flex; justify-content: space-between; align-items: center; }
|
||||
.title { font-size: 16px; font-weight: bold; }
|
||||
.detail-container { padding: 10px 20px; }
|
||||
.session-info { margin-bottom: 20px; }
|
||||
.msg-card { border-radius: 8px; margin-bottom: 10px; }
|
||||
.msg-header { margin-bottom: 8px; }
|
||||
.role-tag { font-size: 11px; font-weight: bold; padding: 2px 6px; border-radius: 4px; }
|
||||
.role-tag.user { background-color: #ecf5ff; color: #409eff; }
|
||||
.role-tag.assistant { background-color: #f0f9eb; color: #67c23a; }
|
||||
.msg-content { font-size: 14px; line-height: 1.6; color: #333; white-space: pre-wrap; }
|
||||
.hit-item { padding: 10px; background-color: #f8f9fa; border-radius: 4px; margin-bottom: 8px; }
|
||||
.hit-meta { display: flex; justify-content: space-between; align-items: center; margin-bottom: 6px; }
|
||||
.hit-source { font-size: 11px; color: #999; }
|
||||
.hit-text { font-size: 12px; color: #666; line-height: 1.5; }
|
||||
.code-block { background-color: #fafafa; border: 1px solid #eaeaea; padding: 8px; border-radius: 4px; font-size: 12px; overflow-x: auto; margin: 0; }
|
||||
</style>
|
||||
|
|
@ -0,0 +1,545 @@
|
|||
<template>
|
||||
<div class="rag-lab-page">
|
||||
<div class="page-header">
|
||||
<h1 class="page-title">RAG 实验室</h1>
|
||||
<p class="page-desc">测试检索增强生成效果,查看检索结果和 AI 响应。</p>
|
||||
</div>
|
||||
|
||||
<el-row :gutter="24">
|
||||
<el-col :xs="24" :sm="24" :md="10" :lg="10">
|
||||
<el-card shadow="hover" class="input-card">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<div class="header-left">
|
||||
<div class="icon-wrapper">
|
||||
<el-icon><Edit /></el-icon>
|
||||
</div>
|
||||
<span class="header-title">调试输入</span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
<el-form label-position="top">
|
||||
<el-form-item label="查询 Query">
|
||||
<el-input
|
||||
v-model="query"
|
||||
type="textarea"
|
||||
:rows="4"
|
||||
placeholder="输入测试问题..."
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item label="知识库范围">
|
||||
<el-select
|
||||
v-model="kbIds"
|
||||
multiple
|
||||
placeholder="请选择知识库"
|
||||
style="width: 100%"
|
||||
:loading="kbLoading"
|
||||
:teleported="true"
|
||||
:popper-options="{ modifiers: [{ name: 'flip', enabled: true }, { name: 'preventOverflow', enabled: true }] }"
|
||||
>
|
||||
<el-option
|
||||
v-for="kb in knowledgeBases"
|
||||
:key="kb.id"
|
||||
:label="`${kb.name} (${kb.documentCount}个文档)`"
|
||||
:value="kb.id"
|
||||
/>
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
<el-form-item label="LLM 模型">
|
||||
<LLMSelector
|
||||
v-model="llmProvider"
|
||||
:providers="llmProviders"
|
||||
:loading="llmLoading"
|
||||
:current-provider="currentLLMProvider"
|
||||
placeholder="使用默认配置"
|
||||
clearable
|
||||
@change="handleLLMChange"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item label="参数配置">
|
||||
<div class="param-item">
|
||||
<span class="label">Top-K</span>
|
||||
<el-input-number v-model="topK" :min="1" :max="10" />
|
||||
</div>
|
||||
<div class="param-item">
|
||||
<span class="label">Score Threshold</span>
|
||||
<el-slider
|
||||
v-model="scoreThreshold"
|
||||
:min="0"
|
||||
:max="1"
|
||||
:step="0.1"
|
||||
show-input
|
||||
/>
|
||||
</div>
|
||||
<div class="param-item">
|
||||
<span class="label">生成 AI 回复</span>
|
||||
<el-switch v-model="generateResponse" />
|
||||
</div>
|
||||
<div class="param-item" v-if="generateResponse">
|
||||
<span class="label">流式输出</span>
|
||||
<el-switch v-model="streamOutput" />
|
||||
</div>
|
||||
</el-form-item>
|
||||
<el-button
|
||||
type="primary"
|
||||
block
|
||||
@click="handleRun"
|
||||
:loading="loading || streaming"
|
||||
>
|
||||
{{ streaming ? '生成中...' : '运行实验' }}
|
||||
</el-button>
|
||||
<el-button
|
||||
v-if="streaming"
|
||||
type="danger"
|
||||
block
|
||||
@click="handleStopStream"
|
||||
style="margin-top: 10px;"
|
||||
>
|
||||
停止生成
|
||||
</el-button>
|
||||
</el-form>
|
||||
</el-card>
|
||||
</el-col>
|
||||
|
||||
<el-col :xs="24" :sm="24" :md="14" :lg="14">
|
||||
<el-tabs v-model="activeTab" type="border-card" class="result-tabs">
|
||||
<el-tab-pane label="召回片段" name="retrieval">
|
||||
<div v-if="retrievalResults.length === 0" class="placeholder-text">
|
||||
暂无实验数据
|
||||
</div>
|
||||
<div v-else class="result-list">
|
||||
<el-card
|
||||
v-for="(item, index) in retrievalResults"
|
||||
:key="index"
|
||||
class="result-card"
|
||||
shadow="never"
|
||||
>
|
||||
<div class="result-header">
|
||||
<el-tag size="small" type="primary">Score: {{ item.score.toFixed(4) }}</el-tag>
|
||||
<span class="source">来源: {{ item.source }}</span>
|
||||
</div>
|
||||
<div class="result-content">{{ item.content }}</div>
|
||||
</el-card>
|
||||
</div>
|
||||
</el-tab-pane>
|
||||
<el-tab-pane label="最终 Prompt" name="prompt">
|
||||
<div v-if="!finalPrompt" class="placeholder-text">
|
||||
等待实验运行...
|
||||
</div>
|
||||
<div v-else class="prompt-view">
|
||||
<pre><code>{{ finalPrompt }}</code></pre>
|
||||
</div>
|
||||
</el-tab-pane>
|
||||
<el-tab-pane label="AI 回复" name="ai-response" v-if="generateResponse">
|
||||
<StreamOutput
|
||||
v-if="streamOutput"
|
||||
:content="streamContent"
|
||||
:is-streaming="streaming"
|
||||
:error="streamError"
|
||||
/>
|
||||
<AIResponseViewer
|
||||
v-else
|
||||
:response="aiResponse"
|
||||
/>
|
||||
</el-tab-pane>
|
||||
<el-tab-pane label="诊断信息" name="diagnostics">
|
||||
<div v-if="!diagnostics" class="placeholder-text">
|
||||
等待实验运行...
|
||||
</div>
|
||||
<div v-else class="diagnostics-view">
|
||||
<pre><code>{{ JSON.stringify(diagnostics, null, 2) }}</code></pre>
|
||||
</div>
|
||||
</el-tab-pane>
|
||||
</el-tabs>
|
||||
</el-col>
|
||||
</el-row>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { ElMessage } from 'element-plus'
|
||||
import { Edit } from '@element-plus/icons-vue'
|
||||
import { runRagExperiment, createSSEConnection, type AIResponse, type RetrievalResult } from '@/api/rag'
|
||||
import { getLLMProviders, getLLMConfig, type LLMProviderInfo } from '@/api/llm'
|
||||
import { listKnowledgeBases } from '@/api/kb'
|
||||
import { useRagLabStore } from '@/stores/ragLab'
|
||||
import { storeToRefs } from 'pinia'
|
||||
import AIResponseViewer from '@/components/rag/AIResponseViewer.vue'
|
||||
import StreamOutput from '@/components/rag/StreamOutput.vue'
|
||||
import LLMSelector from '@/components/rag/LLMSelector.vue'
|
||||
|
||||
/** Minimal knowledge-base shape consumed by the KB scope selector. */
interface KnowledgeBase {
  id: string
  name: string
  documentCount: number // shown in the option label next to the name
}
|
||||
|
||||
// Form state lives in the Pinia rag-lab store so it survives route changes.
const ragLabStore = useRagLabStore()
const {
  query,
  kbIds,
  llmProvider,
  topK,
  scoreThreshold,
  generateResponse,
  streamOutput
} = storeToRefs(ragLabStore)

// Transient UI state (not persisted).
const loading = ref(false)    // non-streaming experiment request in flight
const kbLoading = ref(false)  // knowledge-base list loading
const llmLoading = ref(false) // LLM provider list loading
const streaming = ref(false)  // SSE stream currently open
const activeTab = ref('retrieval')
const knowledgeBases = ref<KnowledgeBase[]>([])
const llmProviders = ref<LLMProviderInfo[]>([])
const currentLLMProvider = ref('')

// Experiment results rendered in the result tabs.
const retrievalResults = ref<RetrievalResult[]>([])
const finalPrompt = ref('')
const aiResponse = ref<AIResponse | null>(null)
const diagnostics = ref<any>(null)
const streamContent = ref('')
const streamError = ref<string | null>(null)

const totalLatencyMs = ref(0)

// Cancel callback for the active SSE connection, if any.
let abortStream: (() => void) | null = null
|
||||
|
||||
const fetchKnowledgeBases = async () => {
|
||||
kbLoading.value = true
|
||||
try {
|
||||
const res: any = await listKnowledgeBases()
|
||||
knowledgeBases.value = res.data || []
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch knowledge bases:', error)
|
||||
} finally {
|
||||
kbLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const fetchLLMProviders = async () => {
|
||||
llmLoading.value = true
|
||||
try {
|
||||
const [providersRes, configRes]: [any, any] = await Promise.all([
|
||||
getLLMProviders(),
|
||||
getLLMConfig()
|
||||
])
|
||||
llmProviders.value = providersRes?.providers || []
|
||||
currentLLMProvider.value = configRes?.provider || ''
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch LLM providers:', error)
|
||||
} finally {
|
||||
llmLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const handleLLMChange = (provider: LLMProviderInfo | undefined) => {
|
||||
llmProvider.value = provider?.name || ''
|
||||
}
|
||||
|
||||
const handleRun = async () => {
|
||||
if (!query.value.trim()) {
|
||||
ElMessage.warning('请输入查询 Query')
|
||||
return
|
||||
}
|
||||
|
||||
clearResults()
|
||||
|
||||
if (streamOutput.value && generateResponse.value) {
|
||||
await runStreamExperiment()
|
||||
} else {
|
||||
await runNormalExperiment()
|
||||
}
|
||||
}
|
||||
|
||||
const runNormalExperiment = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const res: any = await runRagExperiment({
|
||||
query: query.value,
|
||||
kb_ids: kbIds.value,
|
||||
top_k: topK.value,
|
||||
score_threshold: scoreThreshold.value,
|
||||
llm_provider: llmProvider.value || undefined,
|
||||
generate_response: generateResponse.value
|
||||
})
|
||||
|
||||
retrievalResults.value = res.retrieval_results || res.retrievalResults || []
|
||||
finalPrompt.value = res.final_prompt || res.finalPrompt || ''
|
||||
aiResponse.value = res.ai_response || res.aiResponse || null
|
||||
diagnostics.value = res.diagnostics || null
|
||||
totalLatencyMs.value = res.total_latency_ms || res.totalLatencyMs || 0
|
||||
|
||||
if (generateResponse.value) {
|
||||
activeTab.value = 'ai-response'
|
||||
} else {
|
||||
activeTab.value = 'retrieval'
|
||||
}
|
||||
|
||||
ElMessage.success('实验运行成功')
|
||||
} catch (err: any) {
|
||||
console.error(err)
|
||||
ElMessage.error(err?.message || '实验运行失败')
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const runStreamExperiment = async () => {
|
||||
streaming.value = true
|
||||
streamContent.value = ''
|
||||
streamError.value = null
|
||||
activeTab.value = 'ai-response'
|
||||
|
||||
abortStream = createSSEConnection(
|
||||
'/admin/rag/experiments/stream',
|
||||
{
|
||||
query: query.value,
|
||||
kb_ids: kbIds.value,
|
||||
top_k: topK.value,
|
||||
score_threshold: scoreThreshold.value,
|
||||
llm_provider: llmProvider.value || undefined,
|
||||
generate_response: true
|
||||
},
|
||||
(data: string) => {
|
||||
try {
|
||||
const parsed = JSON.parse(data)
|
||||
|
||||
if (parsed.type === 'content') {
|
||||
streamContent.value += parsed.content || ''
|
||||
} else if (parsed.type === 'retrieval') {
|
||||
retrievalResults.value = parsed.results || []
|
||||
} else if (parsed.type === 'prompt') {
|
||||
finalPrompt.value = parsed.prompt || ''
|
||||
} else if (parsed.type === 'complete') {
|
||||
aiResponse.value = {
|
||||
content: streamContent.value,
|
||||
prompt_tokens: parsed.prompt_tokens,
|
||||
completion_tokens: parsed.completion_tokens,
|
||||
total_tokens: parsed.total_tokens,
|
||||
latency_ms: parsed.latency_ms,
|
||||
model: parsed.model
|
||||
}
|
||||
totalLatencyMs.value = parsed.total_latency_ms || 0
|
||||
streaming.value = false
|
||||
ElMessage.success('生成完成')
|
||||
} else if (parsed.type === 'error') {
|
||||
streamError.value = parsed.message || '流式输出错误'
|
||||
streaming.value = false
|
||||
ElMessage.error(streamError.value)
|
||||
}
|
||||
} catch {
|
||||
streamContent.value += data
|
||||
}
|
||||
},
|
||||
(error: Error) => {
|
||||
streaming.value = false
|
||||
streamError.value = error.message
|
||||
ElMessage.error(error.message)
|
||||
},
|
||||
() => {
|
||||
streaming.value = false
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
const handleStopStream = () => {
|
||||
if (abortStream) {
|
||||
abortStream()
|
||||
abortStream = null
|
||||
}
|
||||
streaming.value = false
|
||||
ElMessage.info('已停止生成')
|
||||
}
|
||||
|
||||
const clearResults = () => {
|
||||
retrievalResults.value = []
|
||||
finalPrompt.value = ''
|
||||
aiResponse.value = null
|
||||
diagnostics.value = null
|
||||
streamContent.value = ''
|
||||
streamError.value = null
|
||||
totalLatencyMs.value = 0
|
||||
}
|
||||
|
||||
// Populate both selectors (knowledge bases + LLM providers) once on mount.
onMounted(() => {
  fetchKnowledgeBases()
  fetchLLMProviders()
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.rag-lab-page {
|
||||
padding: 24px;
|
||||
min-height: calc(100vh - 60px);
|
||||
}
|
||||
|
||||
.page-header {
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
margin: 0 0 8px 0;
|
||||
font-size: 24px;
|
||||
font-weight: 700;
|
||||
color: var(--text-primary);
|
||||
letter-spacing: -0.5px;
|
||||
}
|
||||
|
||||
.page-desc {
|
||||
margin: 0;
|
||||
font-size: 14px;
|
||||
color: var(--text-secondary);
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.input-card {
|
||||
animation: fadeInUp 0.5s ease-out;
|
||||
}
|
||||
|
||||
@keyframes fadeInUp {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.card-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.header-left {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.icon-wrapper {
|
||||
width: 36px;
|
||||
height: 36px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background-color: var(--primary-lighter);
|
||||
border-radius: 10px;
|
||||
color: var(--primary-color);
|
||||
font-size: 18px;
|
||||
}
|
||||
|
||||
.header-title {
|
||||
font-size: 15px;
|
||||
font-weight: 600;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.param-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-bottom: 16px;
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
.param-item .label {
|
||||
width: 140px;
|
||||
font-size: 13px;
|
||||
font-weight: 500;
|
||||
color: var(--text-secondary);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.param-item :deep(.el-slider) {
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.result-tabs {
|
||||
animation: fadeInUp 0.6s ease-out;
|
||||
}
|
||||
|
||||
.result-tabs :deep(.el-tabs__header) {
|
||||
border-radius: 12px 12px 0 0;
|
||||
}
|
||||
|
||||
.placeholder-text {
|
||||
color: var(--text-tertiary);
|
||||
text-align: center;
|
||||
padding: 60px 20px;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.result-list {
|
||||
max-height: 600px;
|
||||
overflow-y: auto;
|
||||
padding-right: 8px;
|
||||
}
|
||||
|
||||
.result-card {
|
||||
margin-bottom: 16px;
|
||||
border: 1px solid var(--border-color);
|
||||
}
|
||||
|
||||
.result-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.source {
|
||||
font-size: 12px;
|
||||
color: var(--text-tertiary);
|
||||
}
|
||||
|
||||
.result-content {
|
||||
font-size: 14px;
|
||||
line-height: 1.7;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.prompt-view,
|
||||
.diagnostics-view {
|
||||
background-color: var(--bg-tertiary);
|
||||
padding: 16px;
|
||||
border-radius: 10px;
|
||||
max-height: 600px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.prompt-view pre,
|
||||
.diagnostics-view pre {
|
||||
margin: 0;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
font-family: var(--font-mono);
|
||||
font-size: 13px;
|
||||
line-height: 1.6;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.rag-lab-page {
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
.param-item {
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.param-item .label {
|
||||
width: 100%;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "ESNext",
|
||||
"useDefineForClassFields": true,
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "Node",
|
||||
"strict": true,
|
||||
"jsx": "preserve",
|
||||
"resolveJsonModule": true,
|
||||
"isolatedModules": true,
|
||||
"esModuleInterop": true,
|
||||
"lib": ["ESNext", "DOM"],
|
||||
"skipLibCheck": true,
|
||||
"noEmit": true,
|
||||
"baseUrl": ".",
|
||||
"paths": {
|
||||
"@/*": ["src/*"]
|
||||
}
|
||||
},
|
||||
"include": ["src/**/*.ts", "src/**/*.d.ts", "src/**/*.tsx", "src/**/*.vue"],
|
||||
"references": [{ "path": "./tsconfig.node.json" }]
|
||||
}
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"composite": true,
|
||||
"skipLibCheck": true,
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "Node",
|
||||
"allowSyntheticDefaultImports": true
|
||||
},
|
||||
"include": ["vite.config.ts"]
|
||||
}
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
import { defineConfig } from 'vite'
import vue from '@vitejs/plugin-vue'
import path from 'path'

// https://vitejs.dev/config/
export default defineConfig({
  plugins: [vue()],
  resolve: {
    alias: {
      // Map '@/…' imports to ./src (mirrors the "paths" entry in tsconfig).
      '@': path.resolve(__dirname, './src'),
    },
  },
  server: {
    // Dev server port; '/api/*' calls are proxied to the backend with the
    // '/api' prefix stripped (e.g. /api/ai/chat -> localhost:8000/ai/chat).
    port: 3000,
    proxy: {
      '/api': {
        target: 'http://localhost:8000',
        changeOrigin: true,
        // NOTE(review): the 'path' parameter shadows the imported node
        // 'path' module — harmless here, but easy to trip over.
        rewrite: (path) => path.replace(/^\/api/, ''),
      },
    },
  },
})
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
# AI Service
|
||||
|
||||
Python AI Service for intelligent chat with RAG support.
|
||||
|
||||
## Features
|
||||
|
||||
- Multi-tenant isolation via X-Tenant-Id header
|
||||
- SSE streaming support via Accept: text/event-stream
|
||||
- RAG-powered responses with confidence scoring
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- PostgreSQL 12+
|
||||
- Qdrant vector database
|
||||
- Python 3.10+
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
## Database Initialization
|
||||
|
||||
### Option 1: Using Python script (Recommended)
|
||||
|
||||
```bash
|
||||
# Create database and tables
|
||||
python scripts/init_db.py --create-db
|
||||
|
||||
# Or just create tables (database must exist)
|
||||
python scripts/init_db.py
|
||||
```
|
||||
|
||||
### Option 2: Using SQL script
|
||||
|
||||
```bash
|
||||
# Connect to PostgreSQL and run
|
||||
psql -U postgres -f scripts/init_db.sql
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Create a `.env` file in the project root:
|
||||
|
||||
```env
|
||||
AI_SERVICE_DATABASE_URL=postgresql+asyncpg://postgres:password@localhost:5432/ai_service
|
||||
AI_SERVICE_QDRANT_URL=http://localhost:6333
|
||||
AI_SERVICE_LLM_API_KEY=your-api-key
|
||||
AI_SERVICE_LLM_BASE_URL=https://api.openai.com/v1
|
||||
AI_SERVICE_LLM_MODEL=gpt-4o-mini
|
||||
AI_SERVICE_DEBUG=true
|
||||
```
|
||||
|
||||
## Running
|
||||
|
||||
```bash
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Chat API
|
||||
- `POST /ai/chat` - Generate AI reply (supports SSE streaming)
|
||||
- `GET /ai/health` - Health check
|
||||
|
||||
### Admin API
|
||||
- `GET /admin/kb/documents` - List documents
|
||||
- `POST /admin/kb/documents` - Upload document
|
||||
- `GET /admin/kb/index/jobs/{jobId}` - Get indexing job status
|
||||
- `DELETE /admin/kb/documents/{docId}` - Delete document
|
||||
- `POST /admin/rag/experiments/run` - Run RAG experiment
|
||||
- `GET /admin/sessions` - List chat sessions
|
||||
- `GET /admin/sessions/{sessionId}` - Get session details
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
"""
|
||||
AI Service - Python AI Middle Platform
|
||||
[AC-AISVC-01] FastAPI-based AI chat service with multi-tenant support.
|
||||
"""
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
"""
|
||||
API module for AI Service.
|
||||
"""
|
||||
|
||||
from app.api.chat import router as chat_router
|
||||
from app.api.health import router as health_router
|
||||
|
||||
__all__ = ["chat_router", "health_router"]
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
"""
|
||||
Admin API routes for AI Service management.
|
||||
[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08] Admin management endpoints.
|
||||
"""
|
||||
|
||||
from app.api.admin.dashboard import router as dashboard_router
|
||||
from app.api.admin.embedding import router as embedding_router
|
||||
from app.api.admin.kb import router as kb_router
|
||||
from app.api.admin.llm import router as llm_router
|
||||
from app.api.admin.rag import router as rag_router
|
||||
from app.api.admin.sessions import router as sessions_router
|
||||
from app.api.admin.tenants import router as tenants_router
|
||||
|
||||
__all__ = ["dashboard_router", "embedding_router", "kb_router", "llm_router", "rag_router", "sessions_router", "tenants_router"]
|
||||
|
|
@ -0,0 +1,202 @@
|
|||
"""
|
||||
Dashboard statistics endpoints.
|
||||
Provides overview statistics for the admin dashboard.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import select, func, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models import ErrorResponse
|
||||
from app.models.entities import ChatMessage, ChatSession, Document, KnowledgeBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/dashboard", tags=["Dashboard"])
|
||||
|
||||
LATENCY_THRESHOLD_MS = 5000
|
||||
|
||||
|
||||
def get_current_tenant_id() -> str:
    """Resolve the tenant ID from request context, failing fast when absent."""
    tenant_id = get_tenant_id()
    if tenant_id:
        return tenant_id
    raise MissingTenantIdException()
|
||||
|
||||
|
||||
@router.get(
    "/stats",
    operation_id="getDashboardStats",
    summary="Get dashboard statistics",
    description="Get overview statistics for the admin dashboard.",
    responses={
        200: {"description": "Dashboard statistics"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def get_dashboard_stats(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    latency_threshold: int = Query(default=LATENCY_THRESHOLD_MS, description="Latency threshold in ms"),
) -> JSONResponse:
    """
    Get dashboard statistics including knowledge bases, messages, and activity.

    Every aggregate is scoped to the current tenant; latency figures only
    consider assistant messages that recorded a latency value.
    """
    logger.info(f"Getting dashboard stats: tenant={tenant_id}")

    async def scalar(stmt):
        # Execute a single-value aggregate statement and unwrap its scalar.
        return (await session.execute(stmt)).scalar()

    def count_where(model, *conditions):
        # COUNT(*) over `model` rows matching `conditions`.
        return select(func.count()).select_from(model).where(*conditions)

    def token_sum(column):
        # SUM over a token column, COALESCEd to 0 for empty tenants.
        return select(func.coalesce(func.sum(column), 0)).select_from(ChatMessage).where(
            ChatMessage.tenant_id == tenant_id
        )

    # Filters shared by the assistant-reply aggregates below.
    assistant_filter = (
        ChatMessage.tenant_id == tenant_id,
        ChatMessage.role == "assistant",
    )
    latency_filter = (*assistant_filter, ChatMessage.latency_ms.isnot(None))

    def latency_agg(expr):
        # Aggregate over assistant messages that recorded a latency.
        return select(func.coalesce(expr, 0)).select_from(ChatMessage).where(*latency_filter)

    # Entity counts.
    kb_count = await scalar(count_where(KnowledgeBase, KnowledgeBase.tenant_id == tenant_id)) or 0
    msg_count = await scalar(count_where(ChatMessage, ChatMessage.tenant_id == tenant_id)) or 0
    doc_count = await scalar(count_where(Document, Document.tenant_id == tenant_id)) or 0
    session_count = await scalar(count_where(ChatSession, ChatSession.tenant_id == tenant_id)) or 0

    # Token usage totals.
    total_tokens = await scalar(token_sum(ChatMessage.total_tokens)) or 0
    prompt_tokens = await scalar(token_sum(ChatMessage.prompt_tokens)) or 0
    completion_tokens = await scalar(token_sum(ChatMessage.completion_tokens)) or 0

    # AI request count and average latency.
    ai_requests_count = await scalar(count_where(ChatMessage, *assistant_filter)) or 0
    avg_latency_ms = float(await scalar(latency_agg(func.avg(ChatMessage.latency_ms))) or 0)

    # Most recent assistant reply (latency + timestamp), if any.
    last_request_row = (
        await session.execute(
            select(ChatMessage.latency_ms, ChatMessage.created_at)
            .where(*assistant_filter)
            .order_by(desc(ChatMessage.created_at))
            .limit(1)
        )
    ).fetchone()
    last_latency_ms = last_request_row[0] if last_request_row else None
    last_request_time = (
        last_request_row[1].isoformat() if last_request_row and last_request_row[1] else None
    )

    # Health counters: replies slower than the threshold, and errored replies.
    slow_requests_count = await scalar(
        count_where(ChatMessage, *latency_filter, ChatMessage.latency_ms >= latency_threshold)
    ) or 0
    error_requests_count = await scalar(
        count_where(ChatMessage, *assistant_filter, ChatMessage.is_error == True)  # noqa: E712
    ) or 0

    # Latency distribution (p95/p99 require PostgreSQL percentile_cont).
    p95_latency_ms = float(
        await scalar(latency_agg(func.percentile_cont(0.95).within_group(ChatMessage.latency_ms))) or 0
    )
    p99_latency_ms = float(
        await scalar(latency_agg(func.percentile_cont(0.99).within_group(ChatMessage.latency_ms))) or 0
    )
    min_latency_ms = float(await scalar(latency_agg(func.min(ChatMessage.latency_ms))) or 0)
    max_latency_ms = float(await scalar(latency_agg(func.max(ChatMessage.latency_ms))) or 0)

    return JSONResponse(
        content={
            "knowledgeBases": kb_count,
            "totalMessages": msg_count,
            "totalDocuments": doc_count,
            "totalSessions": session_count,
            "totalTokens": total_tokens,
            "promptTokens": prompt_tokens,
            "completionTokens": completion_tokens,
            "aiRequestsCount": ai_requests_count,
            "avgLatencyMs": round(avg_latency_ms, 2),
            "lastLatencyMs": last_latency_ms,
            "lastRequestTime": last_request_time,
            "slowRequestsCount": slow_requests_count,
            "errorRequestsCount": error_requests_count,
            "p95LatencyMs": round(p95_latency_ms, 2),
            "p99LatencyMs": round(p99_latency_ms, 2),
            "minLatencyMs": round(min_latency_ms, 2),
            "maxLatencyMs": round(max_latency_ms, 2),
            "latencyThresholdMs": latency_threshold,
        }
    )
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
"""
|
||||
Embedding management API endpoints.
|
||||
[AC-AISVC-38, AC-AISVC-39, AC-AISVC-40, AC-AISVC-41] Embedding model management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
|
||||
from app.core.exceptions import InvalidRequestException
|
||||
from app.services.embedding import (
|
||||
EmbeddingException,
|
||||
EmbeddingProviderFactory,
|
||||
get_embedding_config_manager,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/embedding", tags=["Embedding Management"])
|
||||
|
||||
|
||||
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
    """Extract the tenant ID from the X-Tenant-Id header, rejecting blanks."""
    if x_tenant_id:
        return x_tenant_id
    raise HTTPException(status_code=400, detail="X-Tenant-Id header is required")
|
||||
|
||||
|
||||
@router.get("/providers")
async def list_embedding_providers(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Get available embedding providers.

    [AC-AISVC-38] Returns all registered providers with config schemas.
    """
    names = EmbeddingProviderFactory.get_available_providers()
    return {
        "providers": [EmbeddingProviderFactory.get_provider_info(name) for name in names],
    }
|
||||
|
||||
|
||||
@router.get("/config")
async def get_embedding_config(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Get current embedding configuration.

    [AC-AISVC-39] Returns current provider and config.
    """
    return get_embedding_config_manager().get_full_config()
|
||||
|
||||
|
||||
@router.put("/config")
async def update_embedding_config(
    request: dict[str, Any],
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Update embedding configuration.

    [AC-AISVC-40, AC-AISVC-31] Updates config with hot-reload support.

    Request body:
        provider: name of the embedding provider to switch to (required).
        config: provider-specific configuration mapping (optional).

    Raises:
        InvalidRequestException: if ``provider`` is missing, unknown, or the
            provider rejects the supplied configuration.
    """
    provider = request.get("provider")
    config = request.get("config", {})

    if not provider:
        raise InvalidRequestException("provider is required")

    available = EmbeddingProviderFactory.get_available_providers()
    if provider not in available:
        raise InvalidRequestException(
            f"Unknown provider: {provider}. "
            f"Available: {available}"
        )

    manager = get_embedding_config_manager()

    try:
        await manager.update_config(provider, config)
    except EmbeddingException as e:
        # Chain the original error so the root cause survives in tracebacks.
        raise InvalidRequestException(str(e)) from e

    return {
        "success": True,
        "message": f"Configuration updated to use {provider}",
    }
|
||||
|
||||
|
||||
@router.post("/test")
async def test_embedding(
    request: dict[str, Any] | None = None,
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Test embedding connection.

    [AC-AISVC-41] Tests provider connectivity and returns dimension info.
    """
    # The body is optional; fall back to an empty mapping when omitted.
    body = request or {}
    manager = get_embedding_config_manager()
    return await manager.test_connection(
        test_text=body.get("test_text", "这是一个测试文本"),
        provider=body.get("provider"),
        config=body.get("config"),
    )
|
||||
|
||||
|
||||
@router.get("/formats")
async def get_supported_document_formats(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Get supported document formats for embedding.

    Returns the supported file extensions plus per-parser metadata.
    """
    # Alias the service helper: it shares this endpoint's name, and importing
    # it un-aliased shadows the route function inside its own body.
    from app.services.document import (
        DocumentParserFactory,
        get_supported_document_formats as list_document_formats,
    )

    return {
        "formats": list_document_formats(),
        "parsers": DocumentParserFactory.get_parser_info(),
    }
|
||||
|
|
@ -0,0 +1,593 @@
|
|||
"""
|
||||
Knowledge Base management endpoints.
|
||||
[AC-ASA-01, AC-ASA-02, AC-ASA-08] Document upload, list, and index job status.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import tiktoken
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Query, UploadFile, File, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models import ErrorResponse
|
||||
from app.models.entities import DocumentStatus, IndexJob, IndexJobStatus
|
||||
from app.services.kb import KBService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/kb", tags=["KB Management"])
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextChunk:
|
||||
"""Text chunk with metadata."""
|
||||
text: str
|
||||
start_token: int
|
||||
end_token: int
|
||||
page: int | None = None
|
||||
source: str | None = None
|
||||
|
||||
|
||||
def chunk_text_by_lines(
|
||||
text: str,
|
||||
min_line_length: int = 10,
|
||||
source: str | None = None,
|
||||
) -> list[TextChunk]:
|
||||
"""
|
||||
按行分块,每行作为一个独立的检索单元。
|
||||
|
||||
Args:
|
||||
text: 要分块的文本
|
||||
min_line_length: 最小行长度,低于此长度的行会被跳过
|
||||
source: 来源文件路径(可选)
|
||||
|
||||
Returns:
|
||||
分块列表,每个块对应一行文本
|
||||
"""
|
||||
lines = text.split('\n')
|
||||
chunks: list[TextChunk] = []
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
|
||||
if len(line) < min_line_length:
|
||||
continue
|
||||
|
||||
chunks.append(TextChunk(
|
||||
text=line,
|
||||
start_token=i,
|
||||
end_token=i + 1,
|
||||
page=None,
|
||||
source=source,
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def chunk_text_with_tiktoken(
    text: str,
    chunk_size: int = 512,
    overlap: int = 100,
    page: int | None = None,
    source: str | None = None,
) -> list[TextChunk]:
    """Chunk ``text`` into overlapping fixed-size token windows via tiktoken.

    Windows are at most ``chunk_size`` tokens and consecutive windows share
    ``overlap`` tokens, i.e. the stride is ``chunk_size - overlap``.

    Args:
        text: Text to chunk.
        chunk_size: Maximum tokens per chunk; must be positive.
        overlap: Tokens shared between consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size`` so the loop can advance.
        page: Optional page number attached to every chunk.
        source: Optional source file path attached to every chunk.

    Returns:
        Chunks with token start/end offsets; empty list for empty text.

    Raises:
        ValueError: If ``chunk_size <= 0`` or ``overlap`` is outside
            ``[0, chunk_size)``.
    """
    # Guard against parameter combinations that previously hung the caller:
    # with overlap >= chunk_size the stride is <= 0 and the while-loop
    # below would never terminate.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must be in [0, chunk_size), got overlap={overlap}, chunk_size={chunk_size}"
        )

    encoding = tiktoken.get_encoding("cl100k_base")
    tokens = encoding.encode(text)
    chunks: list[TextChunk] = []
    start = 0

    while start < len(tokens):
        end = min(start + chunk_size, len(tokens))
        # Decode the window back to text for storage/embedding.
        chunk_text = encoding.decode(tokens[start:end])
        chunks.append(TextChunk(
            text=chunk_text,
            start_token=start,
            end_token=end,
            page=page,
            source=source,
        ))
        if end == len(tokens):
            break
        start += chunk_size - overlap

    return chunks
|
||||
|
||||
|
||||
def get_current_tenant_id() -> str:
    """FastAPI dependency: return the tenant ID bound to the current request.

    Returns:
        The non-empty tenant ID string.

    Raises:
        MissingTenantIdException: When no tenant ID is available in the
            request context (missing header/middleware binding).
    """
    tenant_id = get_tenant_id()
    if not tenant_id:
        raise MissingTenantIdException()
    return tenant_id
|
||||
|
||||
|
||||
@router.get(
    "/knowledge-bases",
    operation_id="listKnowledgeBases",
    summary="Query knowledge base list",
    description="Get list of knowledge bases for the current tenant.",
    responses={
        200: {"description": "Knowledge base list"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def list_knowledge_bases(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
) -> JSONResponse:
    """List all knowledge bases for the current tenant.

    Returns ``{"data": [...]}`` where each entry has ``id``, ``name``,
    ``documentCount`` and ``createdAt``.
    """
    logger.info(f"Listing knowledge bases: tenant={tenant_id}")

    kb_service = KBService(session)
    knowledge_bases = await kb_service.list_knowledge_bases(tenant_id)

    kb_ids = [str(kb.id) for kb in knowledge_bases]

    # Count documents per KB with one grouped query instead of one query
    # per knowledge base.
    doc_counts = {}
    if kb_ids:
        from sqlalchemy import func
        from app.models.entities import Document
        count_stmt = (
            select(Document.kb_id, func.count(Document.id).label("count"))
            .where(Document.tenant_id == tenant_id, Document.kb_id.in_(kb_ids))
            .group_by(Document.kb_id)
        )
        count_result = await session.execute(count_stmt)
        for row in count_result:
            # Keys are Document.kb_id values; the lookup below uses str(kb.id),
            # which matches because kb_ids above were built the same way.
            doc_counts[row.kb_id] = row.count

    data = []
    for kb in knowledge_bases:
        kb_id_str = str(kb.id)
        data.append({
            "id": kb_id_str,
            "name": kb.name,
            "documentCount": doc_counts.get(kb_id_str, 0),
            # NOTE(review): appending "Z" assumes created_at is a naive UTC
            # datetime — confirm the column is stored in UTC.
            "createdAt": kb.created_at.isoformat() + "Z",
        })

    return JSONResponse(content={"data": data})
|
||||
|
||||
|
||||
@router.get(
    "/documents",
    operation_id="listDocuments",
    summary="Query document list",
    description="[AC-ASA-08] Get list of documents with pagination and filtering.",
    responses={
        200: {"description": "Document list with pagination"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def list_documents(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    kb_id: Annotated[Optional[str], Query()] = None,
    status: Annotated[Optional[str], Query()] = None,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
) -> JSONResponse:
    """[AC-ASA-08] List documents with filtering and pagination.

    Query params:
        kb_id: Restrict to one knowledge base (optional).
        status: Restrict to one document status (optional).
        page / page_size: 1-based pagination controls.

    Returns ``{"data": [...], "pagination": {...}}``; each document entry
    includes the ID of its most recent index job (or None).
    """
    logger.info(
        f"[AC-ASA-08] Listing documents: tenant={tenant_id}, kb_id={kb_id}, "
        f"status={status}, page={page}, page_size={page_size}"
    )

    kb_service = KBService(session)
    documents, total = await kb_service.list_documents(
        tenant_id=tenant_id,
        kb_id=kb_id,
        status=status,
        page=page,
        page_size=page_size,
    )

    # Ceiling division; 0 pages when there are no matches.
    total_pages = (total + page_size - 1) // page_size if total > 0 else 0

    data = []
    for doc in documents:
        # Fetch only the newest job. The .limit(1) is required: without it,
        # scalar_one_or_none() raises MultipleResultsFound as soon as a
        # document has been indexed more than once.
        # NOTE(review): still one query per document (N+1); acceptable at
        # page_size <= 100 but worth batching if it shows up in traces.
        job_stmt = (
            select(IndexJob)
            .where(
                IndexJob.tenant_id == tenant_id,
                IndexJob.doc_id == doc.id,
            )
            .order_by(IndexJob.created_at.desc())
            .limit(1)
        )
        job_result = await session.execute(job_stmt)
        latest_job = job_result.scalar_one_or_none()

        data.append({
            "docId": str(doc.id),
            "kbId": doc.kb_id,
            "fileName": doc.file_name,
            "status": doc.status,
            "jobId": str(latest_job.id) if latest_job else None,
            "createdAt": doc.created_at.isoformat() + "Z",
            "updatedAt": doc.updated_at.isoformat() + "Z",
        })

    return JSONResponse(
        content={
            "data": data,
            "pagination": {
                "page": page,
                "pageSize": page_size,
                "total": total,
                "totalPages": total_pages,
            },
        }
    )
|
||||
|
||||
|
||||
@router.post(
    "/documents",
    operation_id="uploadDocument",
    summary="Upload/import document",
    description="[AC-ASA-01] Upload document and trigger indexing job.",
    responses={
        202: {"description": "Accepted - async indexing job started"},
        400: {"description": "Bad Request - unsupported format"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def upload_document(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    kb_id: str = Form(...),
) -> JSONResponse:
    """[AC-ASA-01] Upload a document and start an async indexing job.

    [AC-AISVC-33, AC-AISVC-34, AC-AISVC-35, AC-AISVC-37] Support multiple
    document formats.

    Flow: validate file extension -> persist document + job rows ->
    commit -> schedule ``_index_document`` as a background task ->
    return 202 with the job/document IDs.
    """
    # Local import keeps router import time light. (The previously
    # imported UnsupportedFormatError was never used in this handler.)
    from app.services.document import get_supported_document_formats
    from pathlib import Path

    logger.info(
        f"[AC-ASA-01] Uploading document: tenant={tenant_id}, "
        f"kb_id={kb_id}, filename={file.filename}"
    )

    file_ext = Path(file.filename or "").suffix.lower()
    supported_formats = get_supported_document_formats()

    # NOTE(review): files with no extension skip this check and are later
    # treated as plain text by the indexer — confirm that is intended.
    if file_ext and file_ext not in supported_formats:
        return JSONResponse(
            status_code=400,
            content={
                "code": "UNSUPPORTED_FORMAT",
                "message": f"Unsupported file format: {file_ext}",
                "details": {
                    "supported_formats": supported_formats,
                },
            },
        )

    kb_service = KBService(session)

    # Creates the KB row on first use for this tenant.
    kb = await kb_service.get_or_create_kb(tenant_id, kb_id)

    # NOTE(review): the whole upload is buffered in memory; fine for small
    # admin uploads, revisit for large files.
    file_content = await file.read()
    document, job = await kb_service.upload_document(
        tenant_id=tenant_id,
        kb_id=str(kb.id),
        file_name=file.filename or "unknown",
        file_content=file_content,
        file_type=file.content_type,
    )

    # Commit before scheduling so the background task can see the rows.
    await session.commit()

    background_tasks.add_task(
        _index_document, tenant_id, str(job.id), str(document.id), file_content, file.filename
    )

    return JSONResponse(
        status_code=202,
        content={
            "jobId": str(job.id),
            "docId": str(document.id),
            "status": job.status,
        },
    )
|
||||
|
||||
|
||||
async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: bytes, filename: str | None = None):
    """
    Background indexing task: parse -> chunk -> embed -> upsert to Qdrant.

    [AC-AISVC-33, AC-AISVC-34, AC-AISVC-35] Uses document parsing and
    pluggable embedding.

    Runs outside the request lifecycle, so it opens its own DB session.
    Progress milestones written to the job row: 10 (started), 15 (binary
    parse begins), 20 (text ready), 20-90 (embedding loop), 100 (done).
    Any exception marks the job FAILED via a fresh session.
    """
    # Imports are local because this runs as a background task and pulls in
    # heavyweight dependencies (qdrant client, embedding provider).
    # NOTE(review): PageText is imported but never used below.
    from app.core.database import async_session_maker
    from app.services.kb import KBService
    from app.core.qdrant_client import get_qdrant_client
    from app.services.embedding import get_embedding_provider
    from app.services.document import parse_document, UnsupportedFormatError, DocumentParseException, PageText
    from qdrant_client.models import PointStruct
    import asyncio
    import tempfile
    from pathlib import Path

    # NOTE(review): the log literally prints "filename=(unknown)" — the
    # filename argument is never interpolated here; likely a leftover.
    logger.info(f"[INDEX] Starting indexing: tenant={tenant_id}, job_id={job_id}, doc_id={doc_id}, filename=(unknown)")
    # Brief delay — presumably to let the upload request's 202 return first;
    # TODO confirm this is still needed.
    await asyncio.sleep(1)

    async with async_session_maker() as session:
        kb_service = KBService(session)
        try:
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=10
            )
            await session.commit()

            parse_result = None
            text = None
            file_ext = Path(filename or "").suffix.lower()
            logger.info(f"[INDEX] File extension: {file_ext}, content size: {len(content)} bytes")

            # Extensions handled as plain text (no document parser needed).
            text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}

            if file_ext in text_extensions or not file_ext:
                # Plain-text path: try common encodings in priority order,
                # falling back to lossy UTF-8 decoding.
                logger.info(f"[INDEX] Treating as text file, trying multiple encodings")
                text = None
                for encoding in ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]:
                    try:
                        text = content.decode(encoding)
                        logger.info(f"[INDEX] Successfully decoded with encoding: {encoding}")
                        break
                    except (UnicodeDecodeError, LookupError):
                        continue

                if text is None:
                    text = content.decode("utf-8", errors="replace")
                    logger.warning(f"[INDEX] Failed to decode with known encodings, using utf-8 with replacement")
            else:
                # Binary path: write to a temp file so parse_document can
                # dispatch on the file extension.
                logger.info(f"[INDEX] Binary file detected, will parse with document parser")
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
                )
                await session.commit()

                with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
                    tmp_file.write(content)
                    tmp_path = tmp_file.name

                logger.info(f"[INDEX] Temp file created: {tmp_path}")

                try:
                    logger.info(f"[INDEX] Starting document parsing for {file_ext}...")
                    parse_result = parse_document(tmp_path)
                    text = parse_result.text
                    logger.info(
                        f"[INDEX] Parsed document SUCCESS: (unknown), "
                        f"chars={len(text)}, format={parse_result.metadata.get('format')}, "
                        f"pages={len(parse_result.pages) if parse_result.pages else 'N/A'}, "
                        f"metadata={parse_result.metadata}"
                    )
                    if len(text) < 100:
                        logger.warning(f"[INDEX] Parsed text is very short, preview: {text[:200]}")
                # All parse failures degrade to a lossy raw-bytes decode
                # rather than failing the job.
                except UnsupportedFormatError as e:
                    logger.error(f"[INDEX] UnsupportedFormatError: {e}")
                    text = content.decode("utf-8", errors="ignore")
                except DocumentParseException as e:
                    logger.error(f"[INDEX] DocumentParseException: {e}, details={getattr(e, 'details', {})}")
                    text = content.decode("utf-8", errors="ignore")
                except Exception as e:
                    logger.error(f"[INDEX] Unexpected parsing error: {type(e).__name__}: {e}")
                    text = content.decode("utf-8", errors="ignore")
                finally:
                    # Always remove the temp file, even on parse failure.
                    Path(tmp_path).unlink(missing_ok=True)
                    logger.info(f"[INDEX] Temp file cleaned up")

            logger.info(f"[INDEX] Final text length: {len(text)} chars")
            if len(text) < 50:
                logger.warning(f"[INDEX] Text too short, preview: {repr(text[:200])}")

            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=20
            )
            await session.commit()

            logger.info(f"[INDEX] Getting embedding provider...")
            embedding_provider = await get_embedding_provider()
            logger.info(f"[INDEX] Embedding provider: {type(embedding_provider).__name__}")

            all_chunks: list[TextChunk] = []

            if parse_result and parse_result.pages:
                # Paged documents (e.g. PDF): chunk per page so each chunk
                # keeps its page number for citation.
                logger.info(f"[INDEX] PDF with {len(parse_result.pages)} pages, using line-based chunking with page metadata")
                for page in parse_result.pages:
                    page_chunks = chunk_text_by_lines(
                        page.text,
                        min_line_length=10,
                        source=filename,
                    )
                    for pc in page_chunks:
                        pc.page = page.page
                    all_chunks.extend(page_chunks)
                logger.info(f"[INDEX] Total chunks from PDF: {len(all_chunks)}")
            else:
                logger.info(f"[INDEX] Using line-based chunking")
                all_chunks = chunk_text_by_lines(
                    text,
                    min_line_length=10,
                    source=filename,
                )
                logger.info(f"[INDEX] Total chunks: {len(all_chunks)}")

            qdrant = await get_qdrant_client()
            await qdrant.ensure_collection_exists(tenant_id)

            points = []
            total_chunks = len(all_chunks)
            for i, chunk in enumerate(all_chunks):
                # One embedding call per chunk (sequential).
                embedding = await embedding_provider.embed(chunk.text)

                payload = {
                    "text": chunk.text,
                    "source": doc_id,
                    "chunk_index": i,
                    "start_token": chunk.start_token,
                    "end_token": chunk.end_token,
                }
                if chunk.page is not None:
                    payload["page"] = chunk.page
                if chunk.source:
                    payload["filename"] = chunk.source

                points.append(
                    PointStruct(
                        id=str(uuid.uuid4()),
                        vector=embedding,
                        payload=payload,
                    )
                )

                # Embedding spans the 20-90% progress window; persist every
                # 10 chunks and on the last one to limit commit churn.
                progress = 20 + int((i + 1) / total_chunks * 70)
                if i % 10 == 0 or i == total_chunks - 1:
                    await kb_service.update_job_status(
                        tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=progress
                    )
                    await session.commit()

            if points:
                logger.info(f"[INDEX] Upserting {len(points)} vectors to Qdrant...")
                await qdrant.upsert_vectors(tenant_id, points)

            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.COMPLETED.value, progress=100
            )
            await session.commit()

            logger.info(
                f"[INDEX] COMPLETED: tenant={tenant_id}, "
                f"job_id={job_id}, chunks={len(all_chunks)}, text_len={len(text)}"
            )

        except Exception as e:
            import traceback
            logger.error(f"[INDEX] FAILED: {e}\n{traceback.format_exc()}")
            await session.rollback()
            # Use a fresh session: the original may be unusable after the
            # rollback, and the FAILED status must be persisted regardless.
            async with async_session_maker() as error_session:
                kb_service = KBService(error_session)
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.FAILED.value,
                    progress=0, error_msg=str(e)
                )
                await error_session.commit()
|
||||
|
||||
|
||||
@router.get(
    "/index/jobs/{job_id}",
    operation_id="getIndexJob",
    summary="Query index job status",
    description="[AC-ASA-02] Get indexing job status and progress.",
    responses={
        200: {"description": "Job status details"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def get_index_job(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    job_id: str,
) -> JSONResponse:
    """[AC-ASA-02] Return status, progress and error message for one job."""
    logger.info(
        f"[AC-ASA-02] Getting job status: tenant={tenant_id}, job_id={job_id}"
    )

    job = await KBService(session).get_index_job(tenant_id, job_id)

    if not job:
        # Unknown job for this tenant: structured 404 body.
        not_found_body = {
            "code": "JOB_NOT_FOUND",
            "message": f"Job {job_id} not found",
        }
        return JSONResponse(status_code=404, content=not_found_body)

    status_body = {
        "jobId": str(job.id),
        "docId": str(job.doc_id),
        "status": job.status,
        "progress": job.progress,
        "errorMsg": job.error_msg,
    }
    return JSONResponse(content=status_body)
|
||||
|
||||
|
||||
@router.delete(
    "/documents/{doc_id}",
    operation_id="deleteDocument",
    summary="Delete document",
    description="[AC-ASA-08] Delete a document and its associated files.",
    responses={
        200: {"description": "Document deleted"},
        404: {"description": "Document not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def delete_document(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    doc_id: str,
) -> JSONResponse:
    """[AC-ASA-08] Delete one document owned by the current tenant."""
    logger.info(
        f"[AC-ASA-08] Deleting document: tenant={tenant_id}, doc_id={doc_id}"
    )

    deleted = await KBService(session).delete_document(tenant_id, doc_id)

    if not deleted:
        # The service reports False when the document does not exist
        # (or belongs to another tenant).
        not_found_body = {
            "code": "DOCUMENT_NOT_FOUND",
            "message": f"Document {doc_id} not found",
        }
        return JSONResponse(status_code=404, content=not_found_body)

    return JSONResponse(
        content={
            "success": True,
            "message": "Document deleted",
        }
    )
|
||||
|
|
@ -0,0 +1,330 @@
|
|||
"""
|
||||
Knowledge base management API with RAG optimization features.
|
||||
Reference: rag-optimization/spec.md Section 4.2
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import date
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.services.retrieval import (
|
||||
ChunkMetadata,
|
||||
ChunkMetadataModel,
|
||||
IndexingProgress,
|
||||
IndexingResult,
|
||||
KnowledgeIndexer,
|
||||
MetadataFilter,
|
||||
RetrievalStrategy,
|
||||
get_knowledge_indexer,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/kb", tags=["Knowledge Base"])
|
||||
|
||||
|
||||
class IndexDocumentRequest(BaseModel):
    """Request body for POST /api/kb/index: one document to index."""
    tenant_id: str = Field(..., description="Tenant ID")
    document_id: str = Field(..., description="Document ID")
    text: str = Field(..., description="Document text content")
    metadata: ChunkMetadataModel | None = Field(default=None, description="Document metadata")
|
||||
|
||||
|
||||
class IndexDocumentResponse(BaseModel):
    """Response from document indexing: per-chunk success/failure counts."""
    success: bool
    total_chunks: int
    indexed_chunks: int
    failed_chunks: int
    elapsed_seconds: float
    # Populated only when success is False.
    error_message: str | None = None
|
||||
|
||||
|
||||
class IndexingProgressResponse(BaseModel):
    """Snapshot of the currently running indexing job's progress."""
    total_chunks: int
    processed_chunks: int
    failed_chunks: int
    # Integer percentage 0-100 as reported by the indexer.
    progress_percent: int
    elapsed_seconds: float
    current_document: str
|
||||
|
||||
|
||||
class MetadataFilterRequest(BaseModel):
    """Optional metadata filters for retrieval; None fields are ignored."""
    categories: list[str] | None = None
    target_audiences: list[str] | None = None
    departments: list[str] | None = None
    # When True, restrict to chunks that are currently valid.
    valid_only: bool = True
    min_priority: int | None = None
    keywords: list[str] | None = None
|
||||
|
||||
|
||||
class RetrieveRequest(BaseModel):
    """Request body for POST /api/kb/retrieve."""
    tenant_id: str = Field(..., description="Tenant ID")
    query: str = Field(..., description="Search query")
    top_k: int = Field(default=10, ge=1, le=50, description="Number of results")
    filters: MetadataFilterRequest | None = Field(default=None, description="Metadata filters")
    strategy: RetrievalStrategy = Field(default=RetrievalStrategy.HYBRID, description="Retrieval strategy")
|
||||
|
||||
|
||||
class RetrieveResponse(BaseModel):
    """Response from knowledge retrieval: ranked hits plus diagnostics."""
    hits: list[dict[str, Any]]
    total_hits: int
    max_score: float
    # True when the retriever judges the evidence insufficient to answer.
    is_insufficient: bool
    diagnostics: dict[str, Any]
|
||||
|
||||
|
||||
class MetadataOptionsResponse(BaseModel):
    """Available metadata values the UI can offer as filter options."""
    categories: list[str]
    departments: list[str]
    target_audiences: list[str]
    priorities: list[int]
|
||||
|
||||
|
||||
@router.post("/index", response_model=IndexDocumentResponse)
async def index_document(
    request: IndexDocumentRequest,
    session: AsyncSession = Depends(get_session),
):
    """
    Index a document with optimized embedding.

    Features:
    - Task prefixes (search_document:) for document embedding
    - Multi-dimensional vectors (256/512/768)
    - Metadata support

    NOTE(review): ``session`` is injected but never used in this handler —
    confirm it can be dropped or is reserved for future persistence.
    """
    try:
        index = get_knowledge_indexer()

        # Convert the optional API metadata model into the service-layer
        # ChunkMetadata dataclass field by field.
        chunk_metadata = None
        if request.metadata:
            chunk_metadata = ChunkMetadata(
                category=request.metadata.category,
                subcategory=request.metadata.subcategory,
                target_audience=request.metadata.target_audience,
                source_doc=request.metadata.source_doc,
                source_url=request.metadata.source_url,
                department=request.metadata.department,
                priority=request.metadata.priority,
                keywords=request.metadata.keywords,
            )

        result = await index.index_document(
            tenant_id=request.tenant_id,
            document_id=request.document_id,
            text=request.text,
            metadata=chunk_metadata,
        )

        # Mirror the service result onto the API response model.
        return IndexDocumentResponse(
            success=result.success,
            total_chunks=result.total_chunks,
            indexed_chunks=result.indexed_chunks,
            failed_chunks=result.failed_chunks,
            elapsed_seconds=result.elapsed_seconds,
            error_message=result.error_message,
        )

    except Exception as e:
        # Collapse all failures into a 500 with a localized message.
        logger.error(f"[KB-API] Failed to index document: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"索引失败: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get("/index/progress", response_model=IndexingProgressResponse | None)
async def get_indexing_progress():
    """Return the current indexing progress, or None when idle."""
    try:
        progress = get_knowledge_indexer().get_progress()

        # No indexing run is active right now.
        if progress is None:
            return None

        snapshot = {
            "total_chunks": progress.total_chunks,
            "processed_chunks": progress.processed_chunks,
            "failed_chunks": progress.failed_chunks,
            "progress_percent": progress.progress_percent,
            "elapsed_seconds": progress.elapsed_seconds,
            "current_document": progress.current_document,
        }
        return IndexingProgressResponse(**snapshot)

    except Exception as e:
        logger.error(f"[KB-API] Failed to get progress: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取进度失败: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post("/retrieve", response_model=RetrieveResponse)
async def retrieve_knowledge(request: RetrieveRequest):
    """
    Retrieve knowledge using optimized RAG.

    Strategies:
    - vector: Simple vector search
    - bm25: BM25 keyword search
    - hybrid: RRF combination of vector + BM25 (default)
    - two_stage: Two-stage retrieval with Matryoshka dimensions

    NOTE(review): ``request.top_k`` and ``request.strategy`` are accepted
    but never forwarded to the retriever — confirm whether the retriever
    reads them from elsewhere or this is an omission.
    """
    try:
        from app.services.retrieval.optimized_retriever import get_optimized_retriever
        from app.services.retrieval.base import RetrievalContext

        retriever = await get_optimized_retriever()

        # Only forward filter fields the caller actually set.
        metadata_filter = None
        if request.filters:
            filter_dict = request.filters.model_dump(exclude_none=True)
            metadata_filter = MetadataFilter(**filter_dict)

        ctx = RetrievalContext(
            tenant_id=request.tenant_id,
            query=request.query,
        )

        if metadata_filter:
            ctx.metadata = {"filter": metadata_filter.to_qdrant_filter()}

        result = await retriever.retrieve(ctx)

        # Normalize diagnostics once. Previously `.get()` was called on
        # result.diagnostics BEFORE the `or {}` fallback, so a None
        # diagnostics value raised AttributeError instead of defaulting.
        diagnostics = result.diagnostics or {}

        return RetrieveResponse(
            hits=[
                {
                    "text": hit.text,
                    "score": hit.score,
                    "source": hit.source,
                    "metadata": hit.metadata,
                }
                for hit in result.hits
            ],
            total_hits=result.hit_count,
            max_score=result.max_score,
            is_insufficient=diagnostics.get("is_insufficient", False),
            diagnostics=diagnostics,
        )

    except Exception as e:
        logger.error(f"[KB-API] Failed to retrieve: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"检索失败: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get("/metadata/options", response_model=MetadataOptionsResponse)
async def get_metadata_options():
    """
    Get available metadata options for filtering.
    These would typically be loaded from a database.
    """
    try:
        # Hard-coded option lists; replace with DB lookups when available.
        category_options = [
            "课程咨询",
            "考试政策",
            "学籍管理",
            "奖助学金",
            "宿舍管理",
            "校园服务",
            "就业指导",
            "其他",
        ]
        department_options = [
            "教务处",
            "学生处",
            "财务处",
            "后勤处",
            "就业指导中心",
            "图书馆",
            "信息中心",
        ]
        audience_options = [
            "本科生",
            "研究生",
            "留学生",
            "新生",
            "毕业生",
            "教职工",
        ]

        return MetadataOptionsResponse(
            categories=category_options,
            departments=department_options,
            target_audiences=audience_options,
            priorities=list(range(1, 11)),
        )

    except Exception as e:
        logger.error(f"[KB-API] Failed to get metadata options: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取选项失败: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post("/reindex")
async def reindex_all(
    tenant_id: str,
    session: AsyncSession = Depends(get_session),
):
    """
    Reindex all completed documents for a tenant with optimized embedding.

    Best-effort: a document whose stored file is missing or unreadable is
    skipped (counted in ``total_failed``) instead of aborting the run.

    Returns:
        ``{"success", "total_documents", "total_indexed", "total_failed"}``.
    """
    try:
        # Hoisted out of the per-document loop (was re-imported each iteration).
        import os

        from app.models.entities import Document, DocumentStatus

        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.status == DocumentStatus.COMPLETED.value,
        )
        doc_result = await session.execute(stmt)
        documents = doc_result.scalars().all()

        index = get_knowledge_indexer()

        total_indexed = 0
        total_failed = 0

        for doc in documents:
            if not doc.file_path or not os.path.exists(doc.file_path):
                continue

            try:
                # NOTE(review): assumes stored files are UTF-8 text; binary
                # originals (PDF/DOCX) would need the document parser here.
                with open(doc.file_path, 'r', encoding='utf-8') as f:
                    text = f.read()
            except (OSError, UnicodeDecodeError) as read_err:
                # One bad file must not abort the whole reindex run.
                logger.error(f"[KB-API] Failed to read {doc.file_path}: {read_err}")
                total_failed += 1
                continue

            # Renamed from `result`, which shadowed the query result above.
            index_result = await index.index_document(
                tenant_id=tenant_id,
                document_id=str(doc.id),
                text=text,
            )

            total_indexed += index_result.indexed_chunks
            total_failed += index_result.failed_chunks

        return {
            "success": True,
            "total_documents": len(documents),
            "total_indexed": total_indexed,
            "total_failed": total_failed,
        }

    except Exception as e:
        logger.error(f"[KB-API] Failed to reindex: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"重新索引失败: {str(e)}"
        )
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
"""
|
||||
LLM Configuration Management API.
|
||||
[AC-ASA-14, AC-ASA-15, AC-ASA-16, AC-ASA-17, AC-ASA-18] LLM provider management endpoints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
|
||||
from app.services.llm.factory import (
|
||||
LLMConfigManager,
|
||||
LLMProviderFactory,
|
||||
get_llm_config_manager,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/llm", tags=["LLM Management"])
|
||||
|
||||
|
||||
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
    """FastAPI dependency: extract the tenant ID from the X-Tenant-Id header.

    FastAPI already rejects requests that omit the header (it is declared
    required); the explicit check below additionally rejects an empty
    header value with a 400.
    """
    if not x_tenant_id:
        raise HTTPException(status_code=400, detail="X-Tenant-Id header is required")
    return x_tenant_id
|
||||
|
||||
|
||||
@router.get("/providers")
async def list_providers(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    List all available LLM providers.
    [AC-ASA-15] Returns provider list with configuration schemas.
    """
    logger.info(f"[AC-ASA-15] Listing LLM providers for tenant={tenant_id}")

    provider_entries = []
    for provider in LLMProviderFactory.get_providers():
        provider_entries.append({
            "name": provider.name,
            "display_name": provider.display_name,
            "description": provider.description,
            "config_schema": provider.config_schema,
        })

    return {"providers": provider_entries}
|
||||
|
||||
|
||||
@router.get("/config")
async def get_config(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Get current LLM configuration.
    [AC-ASA-14] Returns current provider and config.

    Secret values (api_key/password/secret) are masked via
    ``_mask_secrets`` before leaving the server.
    """
    logger.info(f"[AC-ASA-14] Getting LLM config for tenant={tenant_id}")

    manager = get_llm_config_manager()
    config = manager.get_current_config()

    # Never return raw credentials to the admin UI.
    masked_config = _mask_secrets(config.get("config", {}))

    return {
        "provider": config["provider"],
        "config": masked_config,
    }
|
||||
|
||||
|
||||
@router.put("/config")
async def update_config(
    body: dict[str, Any],
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Update LLM configuration.
    [AC-ASA-16] Updates provider and config with validation.

    Returns ``{"success": bool, "message": str}``. Validation failures are
    reported in the body with HTTP 200 rather than an error status —
    NOTE(review): confirm the admin frontend relies on this convention.
    """
    provider = body.get("provider")
    config = body.get("config", {})

    logger.info(f"[AC-ASA-16] Updating LLM config for tenant={tenant_id}, provider={provider}")

    if not provider:
        return {
            "success": False,
            "message": "Provider is required",
        }

    try:
        manager = get_llm_config_manager()
        await manager.update_config(provider, config)

        return {
            "success": True,
            "message": f"LLM configuration updated to {provider}",
        }

    except ValueError as e:
        # Raised by the manager for unknown providers or invalid config.
        logger.error(f"[AC-ASA-16] Invalid LLM config: {e}")
        return {
            "success": False,
            "message": str(e),
        }
|
||||
|
||||
|
||||
@router.post("/test")
async def test_connection(
    body: dict[str, Any] | None = None,
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Test LLM connection.
    [AC-ASA-17, AC-ASA-18] Tests connection and returns response.
    """
    payload = body or {}

    test_prompt = payload.get("test_prompt", "你好,请简单介绍一下自己。")
    provider = payload.get("provider")
    config = payload.get("config")

    logger.info(
        f"[AC-ASA-17] Testing LLM connection for tenant={tenant_id}, "
        f"provider={provider or 'current'}"
    )

    # When provider/config are omitted the manager tests the current config.
    manager = get_llm_config_manager()
    return await manager.test_connection(
        test_prompt=test_prompt,
        provider=provider,
        config=config,
    )
|
||||
|
||||
|
||||
def _mask_secrets(config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Mask secret fields in config for display."""
|
||||
masked = {}
|
||||
for key, value in config.items():
|
||||
if key in ("api_key", "password", "secret"):
|
||||
if value:
|
||||
masked[key] = f"{str(value)[:4]}***"
|
||||
else:
|
||||
masked[key] = ""
|
||||
else:
|
||||
masked[key] = value
|
||||
return masked
|
||||
|
|
@ -0,0 +1,330 @@
|
|||
"""
|
||||
RAG Lab endpoints for debugging and experimentation.
|
||||
[AC-ASA-05, AC-ASA-19, AC-ASA-20, AC-ASA-21, AC-ASA-22] RAG experiment with AI output.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Annotated, Any, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
from app.core.prompts import format_evidence_for_prompt, build_user_prompt_with_evidence
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models import ErrorResponse
|
||||
from app.services.retrieval.vector_retriever import get_vector_retriever
|
||||
from app.services.retrieval.optimized_retriever import get_optimized_retriever
|
||||
from app.services.retrieval.base import RetrievalContext
|
||||
from app.services.llm.factory import get_llm_config_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/rag", tags=["RAG Lab"])
|
||||
|
||||
|
||||
def get_current_tenant_id() -> str:
    """Resolve the request's tenant ID, raising MissingTenantIdException when absent."""
    tenant_id = get_tenant_id()
    if tenant_id:
        return tenant_id
    raise MissingTenantIdException()
|
||||
|
||||
|
||||
class RAGExperimentRequest(BaseModel):
    """Request payload for one RAG lab experiment run."""

    # Query text used for retrieval (required).
    query: str = Field(..., description="Query text for retrieval")
    # Optional restriction to specific knowledge bases; None means all.
    kb_ids: List[str] | None = Field(default=None, description="Knowledge base IDs to search")
    # Number of results to retrieve.
    top_k: int = Field(default=5, description="Number of results to retrieve")
    # Minimum similarity score for a hit to be kept.
    score_threshold: float = Field(default=0.5, description="Minimum similarity score")
    # When False, the experiment stops after retrieval + prompt building.
    generate_response: bool = Field(default=True, description="Whether to generate AI response")
    # Optional override of the configured LLM provider.
    llm_provider: str | None = Field(default=None, description="Specific LLM provider to use")
|
||||
|
||||
|
||||
class AIResponse(BaseModel):
    """Generated LLM answer plus usage/latency metadata for the RAG lab."""

    content: str  # model output text (or an error message on failure)
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    latency_ms: float = 0  # wall-clock generation time in milliseconds
    model: str = ""  # model identifier reported by the provider
|
||||
|
||||
|
||||
class RAGExperimentResult(BaseModel):
    """Aggregate result of a RAG lab experiment run."""

    query: str
    # default_factory instead of a literal mutable default: pydantic copies
    # mutable defaults per instance, but a factory makes the intent explicit
    # and is the recommended form.
    retrieval_results: List[dict] = Field(default_factory=list)
    final_prompt: str = ""
    ai_response: AIResponse | None = None
    total_latency_ms: float = 0
    diagnostics: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@router.post(
    "/experiments/run",
    operation_id="runRagExperiment",
    summary="Run RAG debugging experiment with AI output",
    description="[AC-ASA-05, AC-ASA-19, AC-ASA-21, AC-ASA-22] Trigger RAG experiment with retrieval, prompt generation, and AI response.",
    responses={
        200: {"description": "Experiment results with retrieval, prompt, and AI response"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> JSONResponse:
    """
    [AC-ASA-05, AC-ASA-19, AC-ASA-21, AC-ASA-22] Run RAG experiment and return retrieval results with AI response.

    Flow: retrieve evidence -> build the final prompt -> optionally call the
    LLM. On ANY retrieval error the endpoint degrades to canned fallback
    results and still answers HTTP 200, flagging ``diagnostics.fallback``.
    """
    start_time = time.time()

    logger.info(
        f"[AC-ASA-05] Running RAG experiment: tenant={tenant_id}, "
        f"query={request.query[:50]}..., kb_ids={request.kb_ids}, "
        f"generate_response={request.generate_response}"
    )

    settings = get_settings()
    # NOTE(review): top_k/threshold are computed here but never passed to the
    # retriever call below — confirm whether the optimized retriever reads
    # these from settings itself or whether request.top_k is silently ignored.
    top_k = request.top_k or settings.rag_top_k
    threshold = request.score_threshold or settings.rag_score_threshold

    try:
        # Use optimized retriever with RAG enhancements
        retriever = await get_optimized_retriever()

        retrieval_ctx = RetrievalContext(
            tenant_id=tenant_id,
            query=request.query,
            session_id="rag_experiment",
            channel_type="admin",
            metadata={"kb_ids": request.kb_ids},
        )

        result = await retriever.retrieve(retrieval_ctx)

        # Flatten hits into plain dicts for the JSON payload.
        retrieval_results = [
            {
                "content": hit.text,
                "score": hit.score,
                "source": hit.source,
                "metadata": hit.metadata,
            }
            for hit in result.hits
        ]

        final_prompt = _build_final_prompt(request.query, retrieval_results)

        logger.info(
            f"[AC-ASA-05] RAG retrieval complete: hits={len(retrieval_results)}, "
            f"max_score={result.max_score:.3f}"
        )

        ai_response = None
        if request.generate_response:
            # _generate_ai_response never raises; failures come back as an
            # AIResponse carrying the error text.
            ai_response = await _generate_ai_response(
                final_prompt,
                provider=request.llm_provider,
            )

        total_latency_ms = (time.time() - start_time) * 1000

        return JSONResponse(
            content={
                "query": request.query,
                "retrieval_results": retrieval_results,
                "final_prompt": final_prompt,
                "ai_response": ai_response.model_dump() if ai_response else None,
                "total_latency_ms": round(total_latency_ms, 2),
                "diagnostics": result.diagnostics,
            }
        )

    except Exception as e:
        logger.error(f"[AC-ASA-05] RAG experiment failed: {e}")

        # Degrade gracefully: run the same prompt/LLM flow on canned results.
        fallback_results = _get_fallback_results(request.query)
        fallback_prompt = _build_final_prompt(request.query, fallback_results)

        ai_response = None
        if request.generate_response:
            ai_response = await _generate_ai_response(
                fallback_prompt,
                provider=request.llm_provider,
            )

        total_latency_ms = (time.time() - start_time) * 1000

        return JSONResponse(
            content={
                "query": request.query,
                "retrieval_results": fallback_results,
                "final_prompt": fallback_prompt,
                "ai_response": ai_response.model_dump() if ai_response else None,
                "total_latency_ms": round(total_latency_ms, 2),
                "diagnostics": {
                    "error": str(e),
                    "fallback": True,
                },
            }
        )
|
||||
|
||||
|
||||
@router.post(
    "/experiments/stream",
    operation_id="runRagExperimentStream",
    summary="Run RAG experiment with streaming AI output",
    description="[AC-ASA-20] Trigger RAG experiment with SSE streaming for AI response.",
    responses={
        200: {"description": "SSE stream with retrieval results and AI response"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment_stream(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> StreamingResponse:
    """
    [AC-ASA-20] Run RAG experiment with SSE streaming for AI response.

    Event order on the stream:
      ``retrieval`` (hit list) -> ``prompt`` (final prompt) ->
      ``message``* (LLM deltas) -> ``final`` (full answer),
    or a single ``error`` event if anything fails.
    """
    logger.info(
        f"[AC-ASA-20] Running RAG experiment stream: tenant={tenant_id}, "
        f"query={request.query[:50]}..."
    )

    settings = get_settings()
    # NOTE(review): top_k/threshold are computed but never handed to the
    # retriever below — confirm whether they are intentionally unused here.
    top_k = request.top_k or settings.rag_top_k
    threshold = request.score_threshold or settings.rag_score_threshold

    async def event_generator():
        # All work happens inside the generator so failures surface as SSE
        # ``error`` events instead of an HTTP error status.
        try:
            # Use optimized retriever with RAG enhancements
            retriever = await get_optimized_retriever()

            retrieval_ctx = RetrievalContext(
                tenant_id=tenant_id,
                query=request.query,
                session_id="rag_experiment_stream",
                channel_type="admin",
                metadata={"kb_ids": request.kb_ids},
            )

            result = await retriever.retrieve(retrieval_ctx)

            # Flatten hits into JSON-serializable dicts.
            retrieval_results = [
                {
                    "content": hit.text,
                    "score": hit.score,
                    "source": hit.source,
                    "metadata": hit.metadata,
                }
                for hit in result.hits
            ]

            final_prompt = _build_final_prompt(request.query, retrieval_results)

            # Full prompt is logged for debugging the RAG lab.
            logger.info(f"[AC-ASA-20] ========== RAG LAB STREAM FULL PROMPT ==========")
            logger.info(f"[AC-ASA-20] Prompt length: {len(final_prompt)}")
            logger.info(f"[AC-ASA-20] Prompt content:\n{final_prompt}")
            logger.info(f"[AC-ASA-20] ==============================================")

            yield f"event: retrieval\ndata: {json.dumps({'results': retrieval_results, 'count': len(retrieval_results)})}\n\n"

            yield f"event: prompt\ndata: {json.dumps({'prompt': final_prompt})}\n\n"

            if request.generate_response:
                manager = get_llm_config_manager()
                client = manager.get_client()

                # Accumulate deltas so the terminal ``final`` event carries
                # the complete answer.
                full_content = ""
                async for chunk in client.stream_generate(
                    messages=[{"role": "user", "content": final_prompt}],
                ):
                    if chunk.delta:
                        full_content += chunk.delta
                        yield f"event: message\ndata: {json.dumps({'delta': chunk.delta})}\n\n"

                yield f"event: final\ndata: {json.dumps({'content': full_content, 'finish_reason': 'stop'})}\n\n"
            else:
                # Caller asked for retrieval/prompt only; emit an empty final.
                yield f"event: final\ndata: {json.dumps({'content': '', 'finish_reason': 'skipped'})}\n\n"

        except Exception as e:
            logger.error(f"[AC-ASA-20] RAG experiment stream failed: {e}")
            yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable proxy buffering (e.g. nginx) so events flush immediately.
            "X-Accel-Buffering": "no",
        },
    )
|
||||
|
||||
|
||||
async def _generate_ai_response(
    prompt: str,
    provider: str | None = None,
) -> AIResponse | None:
    """
    [AC-ASA-19, AC-ASA-21] Generate a single (non-streaming) AI response for ``prompt``.

    On success returns an AIResponse with token usage and latency; on
    failure returns an AIResponse whose content carries the error text —
    this helper never raises.

    NOTE(review): ``provider`` is accepted but not forwarded —
    ``manager.get_client()`` is called without arguments, so the currently
    configured provider is always used. Confirm whether get_client() should
    receive the override.
    """
    # Log the full prompt for debugging/auditing the RAG lab.
    logger.info("[AC-ASA-19] ========== RAG LAB FULL PROMPT ==========")
    logger.info(f"[AC-ASA-19] Prompt length: {len(prompt)}")
    logger.info(f"[AC-ASA-19] Prompt content:\n{prompt}")
    logger.info("[AC-ASA-19] ==========================================")

    try:
        manager = get_llm_config_manager()
        client = manager.get_client()

        # ``time`` is imported at module level; the redundant function-local
        # ``import time`` was removed.
        start_time = time.time()
        response = await client.generate(
            messages=[{"role": "user", "content": prompt}],
        )
        latency_ms = (time.time() - start_time) * 1000

        return AIResponse(
            content=response.content,
            prompt_tokens=response.usage.get("prompt_tokens", 0),
            completion_tokens=response.usage.get("completion_tokens", 0),
            total_tokens=response.usage.get("total_tokens", 0),
            latency_ms=round(latency_ms, 2),
            model=response.model,
        )

    except Exception as e:
        logger.error(f"[AC-ASA-19] AI response generation failed: {e}")
        # Surface the failure in the response body rather than raising, so
        # the RAG lab UI can still render the rest of the experiment.
        return AIResponse(
            content=f"AI 响应生成失败: {str(e)}",
            latency_ms=0,
        )
|
||||
|
||||
|
||||
def _build_final_prompt(query: str, retrieval_results: list[dict]) -> str:
    """
    Build the final prompt from query and retrieval results.

    Uses shared prompt configuration for consistency with orchestrator.
    """
    # Cap evidence at 5 results / 500 chars each, then splice it into the
    # shared user-prompt template.
    evidence = format_evidence_for_prompt(
        retrieval_results,
        max_results=5,
        max_content_length=500,
    )
    return build_user_prompt_with_evidence(query, evidence)
|
||||
|
||||
|
||||
def _get_fallback_results(query: str) -> list[dict]:
|
||||
"""
|
||||
Provide fallback results when retrieval fails.
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"content": "检索服务暂时不可用,这是模拟结果。",
|
||||
"score": 0.5,
|
||||
"source": "fallback",
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,293 @@
|
|||
"""
|
||||
Session monitoring and management endpoints.
|
||||
[AC-ASA-07, AC-ASA-09] Session list and detail monitoring.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Optional, Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import col
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models import ErrorResponse
|
||||
from app.models.entities import ChatSession, ChatMessage, SessionStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/sessions", tags=["Session Monitoring"])
|
||||
|
||||
|
||||
def get_current_tenant_id() -> str:
    """Resolve the request's tenant ID, raising MissingTenantIdException when absent."""
    tenant_id = get_tenant_id()
    if not tenant_id:
        raise MissingTenantIdException()
    return tenant_id
|
||||
|
||||
|
||||
@router.get(
    "",
    operation_id="listSessions",
    summary="Query session list",
    description="[AC-ASA-09] Get list of sessions with pagination and filtering.",
    responses={
        200: {"description": "Session list with pagination"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def list_sessions(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    status: Annotated[Optional[str], Query()] = None,
    start_time: Annotated[Optional[str], Query(alias="startTime")] = None,
    end_time: Annotated[Optional[str], Query(alias="endTime")] = None,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
) -> JSONResponse:
    """
    [AC-ASA-09] List sessions with filtering and pagination.

    Filters: ``status`` (read from the session's JSON metadata) and an
    ISO-8601 ``created_at`` window via startTime/endTime. Timestamps that
    fail to parse are silently ignored (that filter is simply not applied).
    """
    logger.info(
        f"[AC-ASA-09] Listing sessions: tenant={tenant_id}, status={status}, "
        f"start_time={start_time}, end_time={end_time}, page={page}, page_size={page_size}"
    )

    # Base query is always tenant-scoped.
    stmt = select(ChatSession).where(ChatSession.tenant_id == tenant_id)

    if status:
        # Status lives inside the JSON metadata column, not a dedicated field.
        stmt = stmt.where(ChatSession.metadata_["status"].as_string() == status)

    if start_time:
        try:
            # Accept trailing 'Z' by rewriting it to an explicit UTC offset.
            start_dt = datetime.fromisoformat(start_time.replace("Z", "+00:00"))
            stmt = stmt.where(ChatSession.created_at >= start_dt)
        except ValueError:
            pass  # unparseable startTime -> no lower bound

    if end_time:
        try:
            end_dt = datetime.fromisoformat(end_time.replace("Z", "+00:00"))
            stmt = stmt.where(ChatSession.created_at <= end_dt)
        except ValueError:
            pass  # unparseable endTime -> no upper bound

    # Count total rows of the filtered query before applying pagination.
    count_stmt = select(func.count()).select_from(stmt.subquery())
    total_result = await session.execute(count_stmt)
    total = total_result.scalar() or 0

    # Newest sessions first, then page.
    stmt = stmt.order_by(col(ChatSession.created_at).desc())
    stmt = stmt.offset((page - 1) * page_size).limit(page_size)

    result = await session.execute(stmt)
    sessions = result.scalars().all()

    session_ids = [s.session_id for s in sessions]

    # One grouped query for per-session message counts (avoids N+1 queries).
    if session_ids:
        msg_count_stmt = (
            select(
                ChatMessage.session_id,
                func.count(ChatMessage.id).label("count")
            )
            .where(
                ChatMessage.tenant_id == tenant_id,
                ChatMessage.session_id.in_(session_ids)
            )
            .group_by(ChatMessage.session_id)
        )
        msg_count_result = await session.execute(msg_count_stmt)
        msg_counts = {row.session_id: row.count for row in msg_count_result}
    else:
        msg_counts = {}

    data = []
    for s in sessions:
        # Sessions without an explicit metadata status are considered active.
        session_status = SessionStatus.ACTIVE.value
        if s.metadata_ and "status" in s.metadata_:
            session_status = s.metadata_["status"]

        end_time_val = None
        if s.metadata_ and "endTime" in s.metadata_:
            end_time_val = s.metadata_["endTime"]

        data.append({
            "sessionId": s.session_id,
            "tenantId": tenant_id,
            "status": session_status,
            # NOTE(review): appending "Z" assumes created_at is stored as
            # naive UTC — confirm against the DB schema.
            "startTime": s.created_at.isoformat() + "Z",
            "endTime": end_time_val,
            "messageCount": msg_counts.get(s.session_id, 0),
            "channelType": s.channel_type,
        })

    # Ceiling division for the page count.
    total_pages = (total + page_size - 1) // page_size if total > 0 else 0

    return JSONResponse(
        content={
            "data": data,
            "pagination": {
                "page": page,
                "pageSize": page_size,
                "total": total,
                "totalPages": total_pages,
            },
        }
    )
|
||||
|
||||
|
||||
@router.get(
    "/{session_id}",
    operation_id="getSessionDetail",
    summary="Get session details",
    description="[AC-ASA-07] Get full session details with messages and trace.",
    responses={
        200: {"description": "Full session details with messages and trace"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def get_session_detail(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    session_id: str,
) -> JSONResponse:
    """
    [AC-ASA-07] Get session detail with messages and trace information.

    Returns 404 with code SESSION_NOT_FOUND when the session does not
    exist for this tenant; otherwise messages in chronological order plus
    a trace structure built from message metadata.
    """
    logger.info(
        f"[AC-ASA-07] Getting session detail: tenant={tenant_id}, session_id={session_id}"
    )

    # Tenant-scoped lookup so one tenant cannot read another's sessions.
    session_stmt = select(ChatSession).where(
        ChatSession.tenant_id == tenant_id,
        ChatSession.session_id == session_id,
    )
    session_result = await session.execute(session_stmt)
    chat_session = session_result.scalar_one_or_none()

    if not chat_session:
        return JSONResponse(
            status_code=404,
            content={
                "code": "SESSION_NOT_FOUND",
                "message": f"Session {session_id} not found",
            },
        )

    # Messages in chronological order.
    messages_stmt = (
        select(ChatMessage)
        .where(
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.session_id == session_id,
        )
        .order_by(col(ChatMessage.created_at).asc())
    )
    messages_result = await session.execute(messages_stmt)
    messages = messages_result.scalars().all()

    messages_data = []
    for msg in messages:
        msg_data = {
            "role": msg.role,
            "content": msg.content,
            # NOTE(review): "Z" suffix assumes created_at is naive UTC — confirm.
            "timestamp": msg.created_at.isoformat() + "Z",
        }
        messages_data.append(msg_data)

    trace = _build_trace_info(messages)

    return JSONResponse(
        content={
            "sessionId": session_id,
            "messages": messages_data,
            "trace": trace,
            "metadata": chat_session.metadata_ or {},
        }
    )
|
||||
|
||||
|
||||
def _build_trace_info(messages: Sequence[ChatMessage]) -> dict:
|
||||
"""
|
||||
Build trace information from messages.
|
||||
This extracts retrieval and tool call information from message metadata.
|
||||
"""
|
||||
trace = {
|
||||
"retrieval": [],
|
||||
"tools": [],
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
for msg in messages:
|
||||
if msg.role == "assistant":
|
||||
pass
|
||||
|
||||
return trace
|
||||
|
||||
|
||||
@router.put(
    "/{session_id}/status",
    operation_id="updateSessionStatus",
    summary="Update session status",
    description="[AC-ASA-09] Update session status (active, closed, expired).",
    responses={
        200: {"description": "Session status updated"},
        404: {"description": "Session not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def update_session_status(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    db_session: Annotated[AsyncSession, Depends(get_session)],
    session_id: str,
    status: str = Query(..., description="New status: active, closed, expired"),
) -> JSONResponse:
    """
    [AC-ASA-09] Update session status.

    Stores the status in the session's JSON metadata; stamps an endTime
    when the session is closed or expired.

    NOTE(review): ``status`` is not validated against SessionStatus, so
    any string is accepted and persisted — confirm whether invalid values
    should be rejected with 400.
    """
    logger.info(
        f"[AC-ASA-09] Updating session status: tenant={tenant_id}, "
        f"session_id={session_id}, status={status}"
    )

    # Tenant-scoped lookup.
    stmt = select(ChatSession).where(
        ChatSession.tenant_id == tenant_id,
        ChatSession.session_id == session_id,
    )
    result = await db_session.execute(stmt)
    chat_session = result.scalar_one_or_none()

    if not chat_session:
        return JSONResponse(
            status_code=404,
            content={
                "code": "SESSION_NOT_FOUND",
                "message": f"Session {session_id} not found",
            },
        )

    metadata = chat_session.metadata_ or {}
    metadata["status"] = status

    # Terminal states get an end timestamp in the metadata.
    if status == SessionStatus.CLOSED.value or status == SessionStatus.EXPIRED.value:
        # NOTE(review): datetime.utcnow() is deprecated in Python 3.12;
        # datetime.now(timezone.utc) is the modern equivalent.
        metadata["endTime"] = datetime.utcnow().isoformat() + "Z"

    # Reassigning (not mutating in place) ensures the JSON column change is
    # detected by the ORM.
    chat_session.metadata_ = metadata
    chat_session.updated_at = datetime.utcnow()
    await db_session.flush()

    return JSONResponse(
        content={
            "success": True,
            "sessionId": session_id,
            "status": status,
        }
    )
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
"""
|
||||
Tenant management endpoints.
|
||||
Provides tenant list and management functionality.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
from app.core.middleware import parse_tenant_id
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models import ErrorResponse
|
||||
from app.models.entities import Tenant
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/tenants", tags=["Tenants"])
|
||||
|
||||
|
||||
def get_current_tenant_id() -> str:
    """Resolve the request's tenant ID, raising MissingTenantIdException when absent."""
    resolved = get_tenant_id()
    if resolved:
        return resolved
    raise MissingTenantIdException()
|
||||
|
||||
|
||||
@router.get(
    "",
    operation_id="listTenants",
    summary="List all tenants",
    description="Get a list of all tenants from the system.",
    responses={
        200: {"description": "List of tenants"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def list_tenants(
    session: Annotated[AsyncSession, Depends(get_session)],
) -> JSONResponse:
    """
    Get a list of all tenants from the tenants table.
    Returns tenant ID and display name (first part of tenant_id).
    """
    logger.info("Getting all tenants")

    # Newest tenants first.
    rows = await session.execute(select(Tenant).order_by(Tenant.created_at.desc()))
    tenants = rows.scalars().all()

    tenant_list = []
    for tenant in tenants:
        # parse_tenant_id splits the raw tenant_id into a display name and year.
        name, year = parse_tenant_id(tenant.tenant_id)
        tenant_list.append({
            "id": tenant.tenant_id,
            "name": f"{name} ({year})",
            "displayName": name,
            "year": year,
            "createdAt": tenant.created_at.isoformat() if tenant.created_at else None,
        })

    logger.info(f"Found {len(tenant_list)} tenants: {[t['id'] for t in tenant_list]}")

    return JSONResponse(
        content={
            "tenants": tenant_list,
            "total": len(tenant_list)
        }
    )
|
||||
|
|
@ -0,0 +1,191 @@
|
|||
"""
|
||||
Chat endpoint for AI Service.
|
||||
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-08, AC-AISVC-09] Main chat endpoint with streaming/non-streaming modes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.middleware import get_response_mode, is_sse_request
|
||||
from app.core.sse import SSEStateMachine, create_error_event
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models import ChatRequest, ChatResponse, ErrorResponse
|
||||
from app.services.memory import MemoryService
|
||||
from app.services.orchestrator import OrchestratorService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["AI Chat"])
|
||||
|
||||
|
||||
async def get_orchestrator_service_with_memory(
    session: Annotated[AsyncSession, Depends(get_session)]
) -> OrchestratorService:
    """
    [AC-AISVC-13] Build an OrchestratorService for one request.

    A fresh MemoryService is created per request so it is bound to the
    request-scoped database session; the LLM client and retriever come
    from their factories.
    """
    # Imported lazily inside the function, matching the original —
    # presumably to avoid import cycles at module load; confirm.
    from app.services.llm.factory import get_llm_config_manager
    from app.services.retrieval.optimized_retriever import get_optimized_retriever

    return OrchestratorService(
        llm_client=get_llm_config_manager().get_client(),
        memory_service=MemoryService(session),
        retriever=await get_optimized_retriever(),
    )
|
||||
|
||||
|
||||
@router.post(
    "/ai/chat",
    operation_id="generateReply",
    summary="Generate AI reply",
    description="""
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06] Generate AI reply based on user message.

Response mode is determined by Accept header:
- Accept: text/event-stream -> SSE streaming response
- Other -> JSON response
""",
    responses={
        200: {
            "description": "Success - JSON or SSE stream",
            "content": {
                "application/json": {"schema": {"$ref": "#/components/schemas/ChatResponse"}},
                "text/event-stream": {"schema": {"type": "string"}},
            },
        },
        400: {"description": "Invalid request", "model": ErrorResponse},
        500: {"description": "Internal error", "model": ErrorResponse},
        503: {"description": "Service unavailable", "model": ErrorResponse},
    },
)
async def generate_reply(
    request: Request,
    chat_request: ChatRequest,
    accept: Annotated[str | None, Header()] = None,
    orchestrator: OrchestratorService = Depends(get_orchestrator_service_with_memory),
) -> Any:
    """
    [AC-AISVC-06] Generate AI reply with automatic response mode switching.

    Based on Accept header:
    - text/event-stream: Returns SSE stream with message/final/error events
    - Other: Returns JSON ChatResponse

    Raises:
        MissingTenantIdException: when no tenant ID is resolvable for the request.
    """
    # Tenant comes from request context (set by middleware), not the body.
    tenant_id = get_tenant_id()
    if not tenant_id:
        from app.core.exceptions import MissingTenantIdException
        raise MissingTenantIdException()

    logger.info(
        f"[AC-AISVC-06] Processing chat request: tenant={tenant_id}, "
        f"session={chat_request.session_id}, mode={get_response_mode(request)}"
    )

    # Dispatch on the negotiated response mode; both handlers share the
    # same orchestrator instance built for this request.
    if is_sse_request(request):
        return await _handle_streaming_request(tenant_id, chat_request, orchestrator)
    else:
        return await _handle_json_request(tenant_id, chat_request, orchestrator)
|
||||
|
||||
|
||||
async def _handle_json_request(
    tenant_id: str,
    chat_request: ChatRequest,
    orchestrator: OrchestratorService,
) -> JSONResponse:
    """
    [AC-AISVC-02] Handle non-streaming JSON request.

    Runs the orchestrator once and returns the ChatResponse as JSON
    (camelCase keys via ``by_alias``, nulls stripped via ``exclude_none``).

    Raises:
        AIServiceException: re-raised unchanged when the orchestrator
            signals a service error; any other exception is wrapped in an
            INTERNAL_ERROR AIServiceException.
    """
    logger.info(f"[AC-AISVC-02] Processing JSON request for tenant={tenant_id}")

    try:
        response = await orchestrator.generate(tenant_id, chat_request)
        return JSONResponse(
            content=response.model_dump(exclude_none=True, by_alias=True),
        )
    except Exception as e:
        logger.error(f"[AC-AISVC-04] Error generating response: {e}")
        # Single import — the original re-imported AIServiceException twice.
        from app.core.exceptions import AIServiceException, ErrorCode

        if isinstance(e, AIServiceException):
            raise  # bare raise preserves the original traceback
        raise AIServiceException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
        )
|
||||
|
||||
|
||||
async def _handle_streaming_request(
    tenant_id: str,
    chat_request: ChatRequest,
    orchestrator: OrchestratorService,
) -> EventSourceResponse:
    """
    [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Handle SSE streaming request.

    SSE Event Sequence (per design.md Section 6.2):
    - message* (0 or more) -> final (exactly 1) -> close
    - OR message* (0 or more) -> error (exactly 1) -> close

    State machine ensures:
    - No events after final/error
    - Only one final OR one error event
    - Proper connection close
    """
    logger.info(f"[AC-AISVC-06] Processing SSE request for tenant={tenant_id}")

    # One state machine per request guards the event-ordering invariants.
    state_machine = SSEStateMachine()

    async def event_generator():
        """
        [AC-AISVC-08, AC-AISVC-09] Event generator with state machine enforcement.
        Ensures proper event sequence and error handling.
        """
        await state_machine.transition_to_streaming()

        try:
            async for event in orchestrator.generate_stream(tenant_id, chat_request):
                # Drop anything the orchestrator emits after the stream has
                # already reached a terminal state.
                if not state_machine.can_send_message():
                    logger.warning("[AC-AISVC-08] Received event after state machine closed, ignoring")
                    break

                if event.event == "final":
                    # transition_to_final returns False if a terminal event
                    # was already sent, preventing a duplicate final.
                    if await state_machine.transition_to_final():
                        logger.info("[AC-AISVC-08] Sending final event and closing stream")
                        yield event
                    break

                elif event.event == "error":
                    if await state_machine.transition_to_error():
                        logger.info("[AC-AISVC-09] Sending error event and closing stream")
                        yield event
                    break

                elif event.event == "message":
                    # Intermediate delta; any other event type is ignored.
                    yield event

        except Exception as e:
            logger.error(f"[AC-AISVC-09] Streaming error: {e}")
            # Only emit the synthetic error event if no terminal event was
            # sent yet.
            if await state_machine.transition_to_error():
                yield create_error_event(
                    code="STREAMING_ERROR",
                    message=str(e),
                )

        finally:
            await state_machine.close()
            logger.debug("[AC-AISVC-08] SSE connection closed")

    # ping=15: keep-alive comment frames every 15s so proxies don't drop
    # idle connections.
    return EventSourceResponse(event_generator(), ping=15)
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
"""
|
||||
Health check endpoint.
|
||||
[AC-AISVC-20] Health check for service monitoring.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
router = APIRouter(tags=["Health"])
|
||||
|
||||
|
||||
@router.get(
    "/ai/health",
    operation_id="healthCheck",
    summary="Health check",
    description="[AC-AISVC-20] Check if AI service is healthy",
    responses={
        200: {"description": "Service is healthy"},
        503: {"description": "Service is unhealthy"},
    },
)
async def health_check() -> JSONResponse:
    """
    [AC-AISVC-20] Liveness probe.

    Always answers 200 with {"status": "healthy"}; no dependency checks
    are performed here, so this handler itself never produces a 503.
    """
    payload = {"status": "healthy"}
    return JSONResponse(status_code=status.HTTP_200_OK, content=payload)
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
"""
|
||||
Core module - Configuration, dependencies, and utilities.
|
||||
[AC-AISVC-01, AC-AISVC-10, AC-AISVC-11] Core infrastructure components.
|
||||
"""
|
||||
|
||||
from app.core.config import Settings, get_settings
|
||||
from app.core.database import async_session_maker, get_session, init_db, close_db
|
||||
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
||||
|
||||
__all__ = [
|
||||
"Settings",
|
||||
"get_settings",
|
||||
"async_session_maker",
|
||||
"get_session",
|
||||
"init_db",
|
||||
"close_db",
|
||||
"QdrantClient",
|
||||
"get_qdrant_client",
|
||||
]
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
"""
|
||||
Configuration management for AI Service.
|
||||
[AC-AISVC-01] Centralized configuration with environment variable support.
|
||||
"""
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """
    [AC-AISVC-01] Centralized application settings.

    Every field can be overridden by an `AI_SERVICE_`-prefixed environment
    variable or a local `.env` file; unrecognized variables are ignored.
    """

    model_config = SettingsConfigDict(env_prefix="AI_SERVICE_", env_file=".env", extra="ignore")

    # --- Application identity ---
    app_name: str = "AI Service"
    app_version: str = "0.1.0"
    debug: bool = False  # also controls SQL echo in the database engine

    # --- HTTP bind address ---
    host: str = "0.0.0.0"
    port: int = 8080

    # --- Request / SSE timing (seconds) ---
    request_timeout_seconds: int = 20
    sse_ping_interval_seconds: int = 15

    log_level: str = "INFO"

    # --- LLM provider (OpenAI-compatible endpoint) ---
    llm_provider: str = "openai"
    llm_api_key: str = ""  # empty default; supply via environment in production
    llm_base_url: str = "https://api.openai.com/v1"
    llm_model: str = "gpt-4o-mini"
    llm_max_tokens: int = 2048
    llm_temperature: float = 0.7
    llm_timeout_seconds: int = 30
    llm_max_retries: int = 3

    # --- PostgreSQL (asyncpg) connection pool ---
    database_url: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/ai_service"
    database_pool_size: int = 10
    database_max_overflow: int = 20

    # --- Qdrant vector store; collections are named {prefix}{tenant} ---
    qdrant_url: str = "http://localhost:6333"
    qdrant_collection_prefix: str = "kb_"
    qdrant_vector_size: int = 768

    # --- Ollama embedding backend ---
    ollama_base_url: str = "http://localhost:11434"
    ollama_embedding_model: str = "nomic-embed-text"

    # --- RAG retrieval parameters ---
    rag_top_k: int = 5
    rag_score_threshold: float = 0.01
    rag_min_hits: int = 1
    rag_max_evidence_tokens: int = 2000

    # --- Two-stage / hybrid retrieval tuning ---
    rag_two_stage_enabled: bool = True
    rag_two_stage_expand_factor: int = 10
    rag_hybrid_enabled: bool = True
    rag_rrf_k: int = 60  # rank-fusion constant (per field name — confirm usage)
    rag_vector_weight: float = 0.7
    rag_bm25_weight: float = 0.3

    # --- Answer-confidence thresholds ---
    confidence_low_threshold: float = 0.5
    confidence_high_threshold: float = 0.8
    confidence_insufficient_penalty: float = 0.3
    max_history_tokens: int = 4000
|
||||
|
||||
|
||||
@lru_cache
def get_settings() -> Settings:
    """[AC-AISVC-01] Build the Settings once and memoize it for the process lifetime."""
    settings_instance = Settings()
    return settings_instance
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""
|
||||
Database client for AI Service.
|
||||
[AC-AISVC-11] PostgreSQL database with SQLModel for multi-tenant data isolation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)

# Settings are resolved once at import time; get_settings() is cached.
settings = get_settings()

# [AC-AISVC-11] Process-wide async engine.
engine = create_async_engine(
    settings.database_url,
    pool_size=settings.database_pool_size,
    max_overflow=settings.database_max_overflow,
    echo=settings.debug,  # log SQL statements only in debug mode
    pool_pre_ping=True,  # validate pooled connections before handing them out
)

# Session factory used by get_session(); expire_on_commit=False keeps ORM
# objects usable after commit, and commits/flushes are always explicit.
async_session_maker = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)
|
||||
|
||||
|
||||
async def init_db() -> None:
    """
    [AC-AISVC-11] Initialize database tables.
    Creates all tables defined in SQLModel metadata.

    Idempotent: create_all only creates tables that do not already exist.
    Call once at application startup.
    """
    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)
    logger.info("[AC-AISVC-11] Database tables initialized")
|
||||
|
||||
|
||||
async def close_db() -> None:
    """
    Close database connections.

    Disposes the engine's connection pool; call once at application shutdown.
    """
    await engine.dispose()
    logger.info("Database connections closed")
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    [AC-AISVC-11] Dependency injection for database session.
    Ensures proper session lifecycle management.

    Commits after the request handler returns normally, rolls back (and
    re-raises) on any exception, and closes the session in all cases.
    """
    async with async_session_maker() as session:
        try:
            yield session
            # Reached only if the handler completed without raising.
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            # NOTE(review): `async with` already closes the session on exit,
            # so this explicit close looks redundant (but harmless) — confirm.
            await session.close()
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
"""
|
||||
Exception handling for AI Service.
|
||||
[AC-AISVC-03, AC-AISVC-04, AC-AISVC-05] Structured error responses.
|
||||
"""
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.models import ErrorCode, ErrorResponse
|
||||
|
||||
|
||||
class AIServiceException(Exception):
    """
    Base exception for the service.

    Carries a structured error code, a human-readable message, the HTTP
    status to respond with, and optional detail entries.
    """

    def __init__(
        self,
        code: ErrorCode,
        message: str,
        status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
        details: list[dict] | None = None,
    ):
        super().__init__(message)
        self.code = code
        self.message = message
        self.status_code = status_code
        self.details = details
|
||||
|
||||
|
||||
class MissingTenantIdException(AIServiceException):
    """Raised when the mandatory X-Tenant-Id header is absent (HTTP 400)."""

    def __init__(self, message: str = "Missing required header: X-Tenant-Id"):
        super().__init__(
            ErrorCode.MISSING_TENANT_ID,
            message,
            status.HTTP_400_BAD_REQUEST,
        )
|
||||
|
||||
|
||||
class InvalidRequestException(AIServiceException):
    """Raised when the request payload fails validation (HTTP 400)."""

    def __init__(self, message: str, details: list[dict] | None = None):
        super().__init__(
            ErrorCode.INVALID_REQUEST,
            message,
            status.HTTP_400_BAD_REQUEST,
            details=details,
        )
|
||||
|
||||
|
||||
class ServiceUnavailableException(AIServiceException):
    """Raised when a downstream dependency is unavailable (HTTP 503)."""

    def __init__(self, message: str = "Service temporarily unavailable"):
        super().__init__(
            ErrorCode.SERVICE_UNAVAILABLE,
            message,
            status.HTTP_503_SERVICE_UNAVAILABLE,
        )
|
||||
|
||||
|
||||
class TimeoutException(AIServiceException):
    """Raised when a request exceeds its time budget (reported as HTTP 503)."""

    def __init__(self, message: str = "Request timeout"):
        super().__init__(
            ErrorCode.TIMEOUT,
            message,
            status.HTTP_503_SERVICE_UNAVAILABLE,
        )
|
||||
|
||||
|
||||
async def ai_service_exception_handler(request: Request, exc: AIServiceException) -> JSONResponse:
    """Convert a domain AIServiceException into the structured JSON error payload."""
    body = ErrorResponse(
        code=exc.code.value,
        message=exc.message,
        details=exc.details,
    ).model_dump(exclude_none=True)
    return JSONResponse(status_code=exc.status_code, content=body)
|
||||
|
||||
|
||||
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """
    Translate FastAPI HTTPExceptions into the service's structured error format.

    Maps 400 -> INVALID_REQUEST and 503 -> SERVICE_UNAVAILABLE; every other
    status falls back to INTERNAL_ERROR.
    """
    status_to_code = {
        status.HTTP_400_BAD_REQUEST: ErrorCode.INVALID_REQUEST,
        status.HTTP_503_SERVICE_UNAVAILABLE: ErrorCode.SERVICE_UNAVAILABLE,
    }
    code = status_to_code.get(exc.status_code, ErrorCode.INTERNAL_ERROR)

    body = ErrorResponse(
        code=code.value,
        message=exc.detail or "An error occurred",
    ).model_dump(exclude_none=True)
    return JSONResponse(status_code=exc.status_code, content=body)
|
||||
|
||||
|
||||
async def generic_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Last-resort handler: opaque 500 payload, no internal details leaked to clients."""
    body = ErrorResponse(
        code=ErrorCode.INTERNAL_ERROR.value,
        message="An unexpected error occurred",
    ).model_dump(exclude_none=True)
    return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=body)
|
||||
|
|
@ -0,0 +1,150 @@
|
|||
"""
|
||||
Middleware for AI Service.
|
||||
[AC-AISVC-10, AC-AISVC-12] X-Tenant-Id header validation and tenant context injection.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.core.exceptions import ErrorCode, ErrorResponse, MissingTenantIdException
|
||||
from app.core.tenant import clear_tenant_context, set_tenant_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TENANT_ID_HEADER = "X-Tenant-Id"
|
||||
ACCEPT_HEADER = "Accept"
|
||||
SSE_CONTENT_TYPE = "text/event-stream"
|
||||
|
||||
# Tenant ID format: name@ash@year (e.g., szmp@ash@2026)
|
||||
TENANT_ID_PATTERN = re.compile(r'^[^@]+@ash@\d{4}$')
|
||||
|
||||
|
||||
def validate_tenant_id_format(tenant_id: str) -> bool:
    """
    [AC-AISVC-10] Validate tenant ID format: name@ash@year
    Examples: szmp@ash@2026, abc123@ash@2025

    Uses re.fullmatch so the whole string must match. The previous
    implementation used re.match with a '$' anchor, which silently accepted
    a trailing newline; that is now rejected.
    """
    return re.fullmatch(r"[^@]+@ash@\d{4}", tenant_id) is not None
|
||||
|
||||
|
||||
def parse_tenant_id(tenant_id: str) -> tuple[str, str]:
    """
    [AC-AISVC-10] Parse tenant ID into name and year.

    Returns:
        (name, year) from the name@ash@year format.

    Raises:
        ValueError: if the id does not have exactly three '@'-separated parts.
            (Previously a malformed id surfaced as an opaque IndexError.)
    """
    parts = tenant_id.split('@')
    if len(parts) != 3:
        raise ValueError(f"Invalid tenant ID {tenant_id!r}: expected name@ash@year")
    name, _, year = parts
    return name, year
|
||||
|
||||
|
||||
class TenantContextMiddleware(BaseHTTPMiddleware):
    """
    [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header.
    Injects tenant context into request state for downstream processing.
    Validates tenant ID format and auto-creates tenant if not exists.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Defensive reset so stale context from a previous task cannot leak in.
        clear_tenant_context()

        # Health checks are tenant-agnostic and bypass validation entirely.
        if request.url.path == "/ai/health":
            return await call_next(request)

        tenant_id = request.headers.get(TENANT_ID_HEADER)

        # [AC-AISVC-12] Missing/blank header -> structured 400 response.
        if not tenant_id or not tenant_id.strip():
            logger.warning("[AC-AISVC-12] Missing or empty X-Tenant-Id header")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=ErrorResponse(
                    code=ErrorCode.MISSING_TENANT_ID.value,
                    message="Missing required header: X-Tenant-Id",
                ).model_dump(exclude_none=True),
            )

        tenant_id = tenant_id.strip()

        # [AC-AISVC-10] Reject ids that do not match name@ash@year.
        if not validate_tenant_id_format(tenant_id):
            logger.warning(f"[AC-AISVC-10] Invalid tenant ID format: {tenant_id}")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=ErrorResponse(
                    code=ErrorCode.INVALID_TENANT_ID.value,
                    message="Invalid tenant ID format. Expected: name@ash@year (e.g., szmp@ash@2026)",
                ).model_dump(exclude_none=True),
            )

        # Auto-create tenant if not exists (for admin endpoints)
        if request.url.path.startswith("/admin/") or request.url.path.startswith("/ai/"):
            try:
                await self._ensure_tenant_exists(request, tenant_id)
            except Exception as e:
                logger.error(f"[AC-AISVC-10] Failed to ensure tenant exists: {e}")
                # Continue processing even if tenant creation fails

        # Expose the tenant both via ContextVar (for code without access to
        # the request) and via request.state (for handlers/dependencies).
        set_tenant_context(tenant_id)
        request.state.tenant_id = tenant_id

        logger.info(f"[AC-AISVC-10] Tenant context set: tenant_id={tenant_id}")

        try:
            response = await call_next(request)
        finally:
            # Always clear, even if the handler raised.
            clear_tenant_context()

        return response

    async def _ensure_tenant_exists(self, request: Request, tenant_id: str) -> None:
        """
        [AC-AISVC-10] Ensure tenant exists in database, create if not.

        NOTE(review): two concurrent first requests for the same tenant can
        race on the insert; the caller deliberately swallows the failure.
        """
        # Imported locally — presumably to avoid import cycles at module load;
        # confirm before hoisting to module level.
        from sqlalchemy import select
        from sqlalchemy.ext.asyncio import AsyncSession

        from app.core.database import async_session_maker
        from app.models.entities import Tenant

        name, year = parse_tenant_id(tenant_id)

        async with async_session_maker() as session:
            # Check if tenant exists
            stmt = select(Tenant).where(Tenant.tenant_id == tenant_id)
            result = await session.execute(stmt)
            existing_tenant = result.scalar_one_or_none()

            if existing_tenant:
                logger.debug(f"[AC-AISVC-10] Tenant already exists: {tenant_id}")
                return

            # Create new tenant
            new_tenant = Tenant(
                tenant_id=tenant_id,
                name=name,
                year=year,
            )
            session.add(new_tenant)
            await session.commit()

            logger.info(f"[AC-AISVC-10] Auto-created new tenant: {tenant_id} (name={name}, year={year})")
|
||||
|
||||
|
||||
def is_sse_request(request: Request) -> bool:
    """
    [AC-AISVC-06] True when the client asked for SSE streaming.

    The decision is based solely on whether the Accept header contains
    text/event-stream.
    """
    accepted = request.headers.get(ACCEPT_HEADER, "")
    return SSE_CONTENT_TYPE in accepted
|
||||
|
||||
|
||||
def get_response_mode(request: Request) -> str:
    """
    [AC-AISVC-06] Map the Accept header to a response mode name.

    Returns "streaming" when the client accepts text/event-stream,
    otherwise "json".
    """
    if is_sse_request(request):
        return "streaming"
    return "json"
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
"""
|
||||
Shared prompt templates for AI Service.
|
||||
Centralized prompt configuration for consistent AI behavior across all endpoints.
|
||||
"""
|
||||
|
||||
SYSTEM_PROMPT = """你是一名经验丰富的客服专员,名字叫"小N"。请遵循以下准则:
|
||||
像朋友一样聊天:使用自然、口语化的中文,语气亲切有礼貌,就像在和一位老朋友聊天。多使用"咱们"、"您看"、"是不是"这样拉近距离的表达。
|
||||
避免生硬罗列:介绍多个信息时,不要用"第一、第二"或分点列举,而是用流畅的段落把它们串联起来。比如:"我们这边有几类课程挺受欢迎的,一个是……另外还有……您对哪方面更感兴趣呀?"
|
||||
主动引导对话:不要一次性把知识库里的所有内容都倒出来。可以先简单概括,然后问问客户具体想了解什么,比如:"您是想提升职场英语,还是有留学方面的打算呢?"
|
||||
委婉处理未知:如果知识库里没有答案,别直接说"不知道",可以说:"您稍等,我帮您确认一下"、"这个问题我需要问问同事,马上给您回复"。
|
||||
隐藏机器身份:绝不要提及"作为AI"、"根据算法"等词,也避免使用过于官方的术语。就像普通人一样用"我"、"我们"。
|
||||
多用语气词和情感表达:适当加入"呢"、"哦"、"哈"、"呀"等语气词,根据情况使用表情符号(😊)。表达理解和关心,比如"我明白您的意思"、"别着急,我们一起看看"。
|
||||
保持简洁但有温度:回答一般控制在3-5句话,如果需要详细说明,可以分成小段落,但每段都要口语化,读起来不累。
|
||||
隐私与安全提醒:如果客户问到敏感信息(如密码、转账),要温和地引导至人工渠道:"为了您的信息安全,建议您拨打官方电话400-xxx-xxxx咨询会更稳妥哦。"""
|
||||
|
||||
|
||||
def format_evidence_for_prompt(
    retrieval_results: list,
    max_results: int = 5,
    max_content_length: int = 500
) -> str:
    """
    Render retrieval hits as a numbered evidence section for LLM prompts.

    Args:
        retrieval_results: Hits in either of two shapes:
            - plain dicts with 'content', 'score', 'source', 'metadata'
            - RetrievalHit-like objects exposing .text/.score/.source/.metadata
        max_results: Cap on how many hits are rendered.
        max_content_length: Each snippet is truncated to this many characters.

    Returns:
        Evidence text, one block per hit, separated by blank lines; empty
        string when there are no hits.
    """
    if not retrieval_results:
        return ""

    sections = []
    for index, hit in enumerate(retrieval_results[:max_results], start=1):
        # Normalize both supported hit shapes into local variables.
        if hasattr(hit, 'text'):
            snippet = hit.text
            relevance = hit.score
            origin = getattr(hit, 'source', '知识库')
            meta = getattr(hit, 'metadata', {}) or {}
        else:
            snippet = hit.get('content', '')
            relevance = hit.get('score', 0)
            origin = hit.get('source', '知识库')
            meta = hit.get('metadata', {}) or {}

        if len(snippet) > max_content_length:
            snippet = snippet[:max_content_length] + '...'

        # Document-level attributes live one level down under 'metadata'.
        inner = meta.get('metadata', {})
        doc_name = inner.get('source_doc', origin) if inner else origin
        doc_category = inner.get('category', '') if inner else ''
        doc_department = inner.get('department', '') if inner else ''

        title = f"[文档{index}]"
        if doc_name and doc_name != "知识库":
            title += f" 来源:{doc_name}"
        if doc_category:
            title += f" | 类别:{doc_category}"
        if doc_department:
            title += f" | 部门:{doc_department}"

        sections.append(f"{title}\n相关度:{relevance:.2f}\n内容:{snippet}")

    return "\n\n".join(sections)
|
||||
|
||||
|
||||
def build_system_prompt_with_evidence(evidence_text: str) -> str:
    """
    Return the base system prompt, optionally extended with KB evidence.

    Args:
        evidence_text: Pre-formatted evidence section (may be empty).

    Returns:
        SYSTEM_PROMPT alone when there is no evidence, otherwise the prompt
        followed by a labelled knowledge-base section.
    """
    if not evidence_text:
        return SYSTEM_PROMPT

    parts = (SYSTEM_PROMPT, "", "知识库参考内容:", evidence_text)
    return "\n".join(parts)
|
||||
|
||||
|
||||
def build_user_prompt_with_evidence(query: str, evidence_text: str) -> str:
    """
    Assemble a single-message user prompt embedding system rules and evidence.

    Args:
        query: The user's question.
        evidence_text: Pre-formatted evidence section (may be empty).

    Returns:
        When evidence exists: system instructions, knowledge-base content and
        the question in labelled sections. Otherwise: the question plus an
        instruction to answer from general knowledge.
    """
    if not evidence_text:
        # No retrieval hits — ask the model to answer from general knowledge.
        return (
            f"用户问题:{query}\n"
            "\n"
            "未找到相关检索结果,请基于通用知识回答用户问题。"
        )

    segments = (
        "【系统指令】",
        SYSTEM_PROMPT,
        "",
        "【知识库内容】",
        evidence_text,
        "",
        "【用户问题】",
        query,
    )
    return "\n".join(segments)
|
||||
|
|
@ -0,0 +1,314 @@
|
|||
"""
|
||||
Qdrant client for AI Service.
|
||||
[AC-AISVC-10] Vector database client with tenant-isolated collection management.
|
||||
Supports multi-dimensional vectors for Matryoshka representation learning.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from qdrant_client.models import Distance, PointStruct, VectorParams, MultiVectorConfig
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class QdrantClient:
    """
    [AC-AISVC-10] Qdrant client with tenant-isolated collection management.
    Collection naming: kb_{tenantId} for tenant isolation.
    Supports multi-dimensional vectors (256/512/768) for Matryoshka retrieval.
    """

    def __init__(self):
        # The underlying async client is created lazily on first use.
        self._client: AsyncQdrantClient | None = None
        self._collection_prefix = settings.qdrant_collection_prefix
        self._vector_size = settings.qdrant_vector_size

    async def get_client(self) -> AsyncQdrantClient:
        """Get or create Qdrant client instance (lazy singleton per wrapper)."""
        if self._client is None:
            self._client = AsyncQdrantClient(url=settings.qdrant_url)
            logger.info(f"[AC-AISVC-10] Qdrant client initialized: {settings.qdrant_url}")
        return self._client

    async def close(self) -> None:
        """Close Qdrant client connection (no-op if never opened)."""
        if self._client:
            await self._client.close()
            self._client = None
            logger.info("Qdrant client connection closed")

    def get_collection_name(self, tenant_id: str) -> str:
        """
        [AC-AISVC-10] Get collection name for a tenant.
        Naming convention: kb_{tenantId}
        Replaces @ with _ to ensure valid collection names.
        """
        safe_tenant_id = tenant_id.replace('@', '_')
        return f"{self._collection_prefix}{safe_tenant_id}"

    async def ensure_collection_exists(self, tenant_id: str, use_multi_vector: bool = True) -> bool:
        """
        [AC-AISVC-10] Ensure collection exists for tenant.
        Supports multi-dimensional vectors for Matryoshka retrieval.

        Returns True if the collection exists or was created, False on error.
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)

        try:
            collections = await client.get_collections()
            exists = any(c.name == collection_name for c in collections.collections)

            if not exists:
                if use_multi_vector:
                    # Named vectors: "full" (768) plus truncated 256/512
                    # variants for coarse-to-fine search.
                    vectors_config = {
                        "full": VectorParams(
                            size=768,
                            distance=Distance.COSINE,
                        ),
                        "dim_256": VectorParams(
                            size=256,
                            distance=Distance.COSINE,
                        ),
                        "dim_512": VectorParams(
                            size=512,
                            distance=Distance.COSINE,
                        ),
                    }
                else:
                    # Single unnamed vector of the configured size.
                    vectors_config = VectorParams(
                        size=self._vector_size,
                        distance=Distance.COSINE,
                    )

                await client.create_collection(
                    collection_name=collection_name,
                    vectors_config=vectors_config,
                )
                logger.info(
                    f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id} "
                    f"with multi_vector={use_multi_vector}"
                )
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error ensuring collection: {e}")
            return False

    async def upsert_vectors(
        self,
        tenant_id: str,
        points: list[PointStruct],
    ) -> bool:
        """
        [AC-AISVC-10] Upsert vectors into tenant's collection.

        Returns True on success, False on any error (logged, not raised).
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)

        try:
            await client.upsert(
                collection_name=collection_name,
                points=points,
            )
            logger.info(
                f"[AC-AISVC-10] Upserted {len(points)} vectors for tenant={tenant_id}"
            )
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error upserting vectors: {e}")
            return False

    async def upsert_multi_vector(
        self,
        tenant_id: str,
        points: list[dict[str, Any]],
    ) -> bool:
        """
        Upsert points with multi-dimensional vectors.

        Args:
            tenant_id: Tenant identifier
            points: List of points with format:
                {
                    "id": str | int,
                    "vector": {
                        "full": [768 floats],
                        "dim_256": [256 floats],
                        "dim_512": [512 floats],
                    },
                    "payload": dict
                }

        Returns True on success, False on any error (logged, not raised).
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)

        try:
            # Convert the plain dicts into Qdrant PointStructs.
            qdrant_points = []
            for p in points:
                point = PointStruct(
                    id=p["id"],
                    vector=p["vector"],
                    payload=p.get("payload", {}),
                )
                qdrant_points.append(point)

            await client.upsert(
                collection_name=collection_name,
                points=qdrant_points,
            )
            logger.info(
                f"[RAG-OPT] Upserted {len(points)} multi-vector points for tenant={tenant_id}"
            )
            return True
        except Exception as e:
            logger.error(f"[RAG-OPT] Error upserting multi-vectors: {e}")
            return False

    async def search(
        self,
        tenant_id: str,
        query_vector: list[float],
        limit: int = 5,
        score_threshold: float | None = None,
        vector_name: str = "full",
    ) -> list[dict[str, Any]]:
        """
        [AC-AISVC-10] Search vectors in tenant's collection.
        Returns results with score >= score_threshold if specified.
        Searches both old format (with @) and new format (with _) for backward compatibility.

        Args:
            tenant_id: Tenant identifier
            query_vector: Query vector for similarity search
            limit: Maximum number of results
            score_threshold: Minimum score threshold for results
            vector_name: Name of the vector to search (for multi-vector collections)
                Default is "full" for 768-dim vectors in Matryoshka setup.
        """
        client = await self.get_client()

        logger.info(
            f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, "
            f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}"
        )

        # Backward compatibility: older deployments stored collections under
        # the raw tenant id (containing '@'); try the new format first.
        collection_names = [self.get_collection_name(tenant_id)]
        if '@' in tenant_id:
            old_format = f"{self._collection_prefix}{tenant_id}"
            new_format = f"{self._collection_prefix}{tenant_id.replace('@', '_')}"
            collection_names = [new_format, old_format]

        logger.info(f"[AC-AISVC-10] Will search in collections: {collection_names}")

        all_hits = []

        for collection_name in collection_names:
            try:
                logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}")

                try:
                    # Named-vector search first (multi-vector collections).
                    results = await client.search(
                        collection_name=collection_name,
                        query_vector=(vector_name, query_vector),
                        limit=limit,
                    )
                except Exception as e:
                    # Fall back to unnamed-vector search for legacy
                    # single-vector collections.
                    if "vector name" in str(e).lower() or "Not existing vector" in str(e):
                        logger.info(
                            f"[AC-AISVC-10] Collection {collection_name} doesn't have vector named '{vector_name}', "
                            f"trying without vector name (single-vector mode)"
                        )
                        results = await client.search(
                            collection_name=collection_name,
                            query_vector=query_vector,
                            limit=limit,
                        )
                    else:
                        raise

                logger.info(
                    f"[AC-AISVC-10] Collection {collection_name} returned {len(results)} raw results"
                )

                # Score threshold is applied client-side, after retrieval.
                hits = [
                    {
                        "id": str(result.id),
                        "score": result.score,
                        "payload": result.payload or {},
                    }
                    for result in results
                    if score_threshold is None or result.score >= score_threshold
                ]
                all_hits.extend(hits)

                if hits:
                    logger.info(
                        f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}"
                    )
                    for i, h in enumerate(hits[:3]):
                        logger.debug(
                            f"[AC-AISVC-10] Hit {i+1}: id={h['id']}, score={h['score']:.4f}"
                        )
                else:
                    logger.warning(
                        f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)"
                    )
            except Exception as e:
                # A missing collection is expected for tenants created after
                # the naming migration; log and try the next candidate.
                logger.warning(
                    f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}"
                )
                continue

        # Merge results from all candidate collections, best score first,
        # capped at the requested limit.
        all_hits = sorted(all_hits, key=lambda x: x["score"], reverse=True)[:limit]

        logger.info(
            f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}"
        )

        if len(all_hits) == 0:
            logger.warning(
                f"[AC-AISVC-10] No results found! tenant={tenant_id}, "
                f"collections_tried={collection_names}, limit={limit}"
            )

        return all_hits

    async def delete_collection(self, tenant_id: str) -> bool:
        """
        [AC-AISVC-10] Delete tenant's collection.
        Used when tenant is removed.

        Returns True on success, False on any error (logged, not raised).
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)

        try:
            await client.delete_collection(collection_name=collection_name)
            logger.info(f"[AC-AISVC-10] Deleted collection: {collection_name}")
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error deleting collection: {e}")
            return False
|
||||
|
||||
|
||||
# Process-wide singleton wrapper instance, created lazily.
_qdrant_client: QdrantClient | None = None


async def get_qdrant_client() -> QdrantClient:
    """Get or create Qdrant client instance (module-level singleton)."""
    global _qdrant_client
    if _qdrant_client is None:
        _qdrant_client = QdrantClient()
    return _qdrant_client


async def close_qdrant_client() -> None:
    """Close Qdrant client connection and drop the singleton (idempotent)."""
    global _qdrant_client
    if _qdrant_client:
        await _qdrant_client.close()
        _qdrant_client = None
||||
|
|
@ -0,0 +1,173 @@
|
|||
"""
|
||||
SSE utilities for AI Service.
|
||||
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] SSE event generation and state machine.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.models import SSEErrorEvent, SSEEventType, SSEFinalEvent, SSEMessageEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSEState(str, Enum):
    """SSE stream lifecycle states; legal transitions are enforced by SSEStateMachine."""

    INIT = "INIT"  # connection opened, nothing sent yet
    STREAMING = "STREAMING"  # incremental message events may be sent
    FINAL_SENT = "FINAL_SENT"  # terminal 'final' event has been sent
    ERROR_SENT = "ERROR_SENT"  # terminal 'error' event has been sent
    CLOSED = "CLOSED"  # stream torn down
|
||||
|
||||
|
||||
class SSEStateMachine:
    """
    [AC-AISVC-08, AC-AISVC-09] SSE state machine ensuring proper event sequence.
    State transitions: INIT -> STREAMING -> FINAL_SENT/ERROR_SENT -> CLOSED

    Every transition is guarded by an asyncio lock so that, with concurrent
    producers, exactly one caller observes True for a terminal transition.
    """

    def __init__(self):
        # Streams start in INIT; the endpoint moves to STREAMING before
        # emitting any event.
        self._state = SSEState.INIT
        self._lock = asyncio.Lock()

    @property
    def state(self) -> SSEState:
        """Current state (unlocked read; informational only)."""
        return self._state

    async def transition_to_streaming(self) -> bool:
        """INIT -> STREAMING. Returns True iff the transition was applied."""
        async with self._lock:
            if self._state == SSEState.INIT:
                self._state = SSEState.STREAMING
                logger.debug("[AC-AISVC-07] SSE state transition: INIT -> STREAMING")
                return True
            return False

    async def transition_to_final(self) -> bool:
        """STREAMING -> FINAL_SENT. Returns True iff the transition was applied."""
        async with self._lock:
            if self._state == SSEState.STREAMING:
                self._state = SSEState.FINAL_SENT
                logger.debug("[AC-AISVC-08] SSE state transition: STREAMING -> FINAL_SENT")
                return True
            return False

    async def transition_to_error(self) -> bool:
        """INIT or STREAMING -> ERROR_SENT. Returns True iff the transition was applied."""
        async with self._lock:
            if self._state in (SSEState.INIT, SSEState.STREAMING):
                # Capture the source state before mutating: the previous code
                # logged self._state *after* assignment, so the message always
                # read "ERROR_SENT -> ERROR_SENT".
                previous = self._state
                self._state = SSEState.ERROR_SENT
                logger.debug(f"[AC-AISVC-09] SSE state transition: {previous.value} -> ERROR_SENT")
                return True
            return False

    async def close(self) -> None:
        """Force the machine into CLOSED from any state (idempotent)."""
        async with self._lock:
            self._state = SSEState.CLOSED
            logger.debug("SSE state transition: -> CLOSED")

    def can_send_message(self) -> bool:
        """Incremental message events are only legal while STREAMING."""
        return self._state == SSEState.STREAMING
|
||||
|
||||
|
||||
def format_sse_event(event_type: SSEEventType, data: dict[str, Any]) -> ServerSentEvent:
    """Serialize payload as JSON (non-ASCII preserved) and wrap it in a typed SSE frame."""
    serialized = json.dumps(data, ensure_ascii=False)
    return ServerSentEvent(event=event_type.value, data=serialized)
|
||||
|
||||
|
||||
def create_message_event(delta: str) -> ServerSentEvent:
    """[AC-AISVC-07] Build an incremental 'message' event carrying one delta chunk."""
    payload = SSEMessageEvent(delta=delta).model_dump()
    return format_sse_event(SSEEventType.MESSAGE, payload)
|
||||
|
||||
|
||||
def create_final_event(
    reply: str,
    confidence: float,
    should_transfer: bool,
    transfer_reason: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> ServerSentEvent:
    """
    [AC-AISVC-08] Build the terminal 'final' event with the complete response.

    None-valued fields are dropped and field aliases are honored in the
    serialized payload.
    """
    final_payload = SSEFinalEvent(
        reply=reply,
        confidence=confidence,
        should_transfer=should_transfer,
        transfer_reason=transfer_reason,
        metadata=metadata,
    ).model_dump(exclude_none=True, by_alias=True)
    return format_sse_event(SSEEventType.FINAL, final_payload)
|
||||
|
||||
|
||||
def create_error_event(
    code: str,
    message: str,
    details: list[dict[str, Any]] | None = None,
) -> ServerSentEvent:
    """[AC-AISVC-09] Build the terminal 'error' event (None fields omitted)."""
    error_payload = SSEErrorEvent(code=code, message=message, details=details)
    return format_sse_event(SSEEventType.ERROR, error_payload.model_dump(exclude_none=True))
|
||||
|
||||
|
||||
async def ping_generator(interval_seconds: int) -> AsyncGenerator[str, None]:
    """
    [AC-AISVC-06] Yield SSE comment frames forever to keep the connection alive.

    Emits ': ping' as a comment line (not an event) every `interval_seconds`.
    """
    keep_alive_frame = ": ping\n\n"
    while True:
        await asyncio.sleep(interval_seconds)
        yield keep_alive_frame
|
||||
|
||||
|
||||
class SSEResponseBuilder:
|
||||
"""
|
||||
Builder for SSE response with proper event sequencing and ping keep-alive.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._state_machine = SSEStateMachine()
|
||||
self._settings = get_settings()
|
||||
|
||||
async def build_response(
|
||||
self,
|
||||
content_generator: AsyncGenerator[ServerSentEvent, None],
|
||||
) -> EventSourceResponse:
|
||||
"""
|
||||
Build SSE response with ping keep-alive mechanism.
|
||||
[AC-AISVC-06] Implements ping keep-alive to prevent connection timeout.
|
||||
"""
|
||||
|
||||
async def event_generator() -> AsyncGenerator[ServerSentEvent, None]:
|
||||
await self._state_machine.transition_to_streaming()
|
||||
try:
|
||||
async for event in content_generator:
|
||||
if self._state_machine.can_send_message():
|
||||
yield event
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-09] Error during SSE streaming: {e}")
|
||||
if await self._state_machine.transition_to_error():
|
||||
yield create_error_event(
|
||||
code="STREAMING_ERROR",
|
||||
message=str(e),
|
||||
)
|
||||
finally:
|
||||
await self._state_machine.close()
|
||||
|
||||
return EventSourceResponse(
|
||||
event_generator(),
|
||||
ping=self._settings.sse_ping_interval_seconds,
|
||||
)
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
"""
|
||||
Tenant context management.
|
||||
[AC-AISVC-10, AC-AISVC-12] Multi-tenant isolation via X-Tenant-Id header.
|
||||
"""
|
||||
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
|
||||
tenant_context: ContextVar["TenantContext | None"] = ContextVar("tenant_context", default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TenantContext:
|
||||
tenant_id: str
|
||||
|
||||
|
||||
def set_tenant_context(tenant_id: str) -> None:
|
||||
tenant_context.set(TenantContext(tenant_id=tenant_id))
|
||||
|
||||
|
||||
def get_tenant_context() -> TenantContext | None:
|
||||
return tenant_context.get()
|
||||
|
||||
|
||||
def get_tenant_id() -> str | None:
|
||||
ctx = get_tenant_context()
|
||||
return ctx.tenant_id if ctx else None
|
||||
|
||||
|
||||
def clear_tenant_context() -> None:
|
||||
tenant_context.set(None)
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
"""
|
||||
Main FastAPI application for AI Service.
|
||||
[AC-AISVC-01] Entry point with middleware and exception handlers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.exceptions import HTTPException, RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.api import chat_router, health_router
|
||||
from app.api.admin import dashboard_router, embedding_router, kb_router, llm_router, rag_router, sessions_router, tenants_router
|
||||
from app.api.admin.kb_optimized import router as kb_optimized_router
|
||||
from app.core.config import get_settings
|
||||
from app.core.database import close_db, init_db
|
||||
from app.core.exceptions import (
|
||||
AIServiceException,
|
||||
ErrorCode,
|
||||
ErrorResponse,
|
||||
ai_service_exception_handler,
|
||||
generic_exception_handler,
|
||||
http_exception_handler,
|
||||
)
|
||||
from app.core.middleware import TenantContextMiddleware
|
||||
from app.core.qdrant_client import close_qdrant_client
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, settings.log_level.upper()),
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
[AC-AISVC-01, AC-AISVC-11] Application lifespan manager.
|
||||
Handles startup and shutdown of database and external connections.
|
||||
"""
|
||||
logger.info(f"[AC-AISVC-01] Starting {settings.app_name} v{settings.app_version}")
|
||||
|
||||
try:
|
||||
await init_db()
|
||||
logger.info("[AC-AISVC-11] Database initialized successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-AISVC-11] Database initialization skipped: {e}")
|
||||
|
||||
yield
|
||||
|
||||
await close_db()
|
||||
await close_qdrant_client()
|
||||
logger.info(f"Shutting down {settings.app_name}")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.app_name,
|
||||
version=settings.app_version,
|
||||
description="""
|
||||
Python AI Service for intelligent chat with RAG support.
|
||||
|
||||
## Features
|
||||
- Multi-tenant isolation via X-Tenant-Id header
|
||||
- SSE streaming support via Accept: text/event-stream
|
||||
- RAG-powered responses with confidence scoring
|
||||
|
||||
## Response Modes
|
||||
- **JSON**: Default response mode (Accept: application/json or no Accept header)
|
||||
- **SSE Streaming**: Set Accept: text/event-stream for streaming responses
|
||||
""",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.add_middleware(TenantContextMiddleware)
|
||||
|
||||
app.add_exception_handler(AIServiceException, ai_service_exception_handler)
|
||||
app.add_exception_handler(HTTPException, http_exception_handler)
|
||||
app.add_exception_handler(Exception, generic_exception_handler)
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
"""
|
||||
[AC-AISVC-03] Handle request validation errors with structured response.
|
||||
"""
|
||||
logger.warning(f"[AC-AISVC-03] Request validation error: {exc.errors()}")
|
||||
error_response = ErrorResponse(
|
||||
code=ErrorCode.INVALID_REQUEST.value,
|
||||
message="Request validation failed",
|
||||
details=[{"loc": list(err["loc"]), "msg": err["msg"], "type": err["type"]} for err in exc.errors()],
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content=error_response.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
app.include_router(health_router)
|
||||
app.include_router(chat_router)
|
||||
|
||||
app.include_router(dashboard_router)
|
||||
app.include_router(embedding_router)
|
||||
app.include_router(kb_router)
|
||||
app.include_router(kb_optimized_router)
|
||||
app.include_router(llm_router)
|
||||
app.include_router(rag_router)
|
||||
app.include_router(sessions_router)
|
||||
app.include_router(tenants_router)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
reload=settings.debug,
|
||||
)
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
Data models for AI Service.
|
||||
[AC-AISVC-02] Request/Response models aligned with OpenAPI contract.
|
||||
[AC-AISVC-13] Entity models for database persistence.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChannelType(str, Enum):
|
||||
WECHAT = "wechat"
|
||||
DOUYIN = "douyin"
|
||||
JD = "jd"
|
||||
|
||||
|
||||
class Role(str, Enum):
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Role = Field(..., description="Message role: user or assistant")
|
||||
content: str = Field(..., description="Message content")
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
session_id: str = Field(..., alias="sessionId", description="Session ID for conversation tracking")
|
||||
current_message: str = Field(..., alias="currentMessage", description="Current user message")
|
||||
channel_type: ChannelType = Field(..., alias="channelType", description="Channel type: wechat, douyin, jd")
|
||||
history: list[ChatMessage] | None = Field(default=None, description="Optional conversation history")
|
||||
metadata: dict[str, Any] | None = Field(default=None, description="Optional metadata")
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
reply: str = Field(..., description="AI generated reply content")
|
||||
confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score between 0.0 and 1.0")
|
||||
should_transfer: bool = Field(..., alias="shouldTransfer", description="Whether to suggest transfer to human agent")
|
||||
transfer_reason: str | None = Field(default=None, alias="transferReason", description="Reason for transfer suggestion")
|
||||
metadata: dict[str, Any] | None = Field(default=None, description="Response metadata")
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
class ErrorCode(str, Enum):
|
||||
INVALID_REQUEST = "INVALID_REQUEST"
|
||||
MISSING_TENANT_ID = "MISSING_TENANT_ID"
|
||||
INVALID_TENANT_ID = "INVALID_TENANT_ID"
|
||||
INTERNAL_ERROR = "INTERNAL_ERROR"
|
||||
SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
|
||||
TIMEOUT = "TIMEOUT"
|
||||
LLM_ERROR = "LLM_ERROR"
|
||||
RETRIEVAL_ERROR = "RETRIEVAL_ERROR"
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
code: str = Field(..., description="Error code")
|
||||
message: str = Field(..., description="Error message")
|
||||
details: list[dict[str, Any]] | None = Field(default=None, description="Detailed error information")
|
||||
|
||||
|
||||
class SSEEventType(str, Enum):
|
||||
MESSAGE = "message"
|
||||
FINAL = "final"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class SSEMessageEvent(BaseModel):
|
||||
delta: str = Field(..., description="Incremental text content")
|
||||
|
||||
|
||||
class SSEFinalEvent(BaseModel):
|
||||
reply: str = Field(..., description="Complete AI reply")
|
||||
confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
|
||||
should_transfer: bool = Field(..., alias="shouldTransfer", description="Transfer suggestion")
|
||||
transfer_reason: str | None = Field(default=None, alias="transferReason", description="Transfer reason")
|
||||
metadata: dict[str, Any] | None = Field(default=None, description="Response metadata")
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
class SSEErrorEvent(BaseModel):
|
||||
code: str = Field(..., description="Error code")
|
||||
message: str = Field(..., description="Error message")
|
||||
details: list[dict[str, Any]] | None = Field(default=None, description="Error details")
|
||||
|
|
@ -0,0 +1,200 @@
|
|||
"""
|
||||
Memory layer entities for AI Service.
|
||||
[AC-AISVC-13] SQLModel entities for chat sessions and messages with tenant isolation.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Column, JSON
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class ChatSession(SQLModel, table=True):
|
||||
"""
|
||||
[AC-AISVC-13] Chat session entity with tenant isolation.
|
||||
Primary key: (tenant_id, session_id) composite unique constraint.
|
||||
"""
|
||||
|
||||
__tablename__ = "chat_sessions"
|
||||
__table_args__ = (
|
||||
Index("ix_chat_sessions_tenant_session", "tenant_id", "session_id", unique=True),
|
||||
Index("ix_chat_sessions_tenant_id", "tenant_id"),
|
||||
)
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
|
||||
session_id: str = Field(..., description="Session ID for conversation tracking")
|
||||
channel_type: str | None = Field(default=None, description="Channel type: wechat, douyin, jd")
|
||||
metadata_: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("metadata", JSON, nullable=True),
|
||||
description="Session metadata"
|
||||
)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Session creation time")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
|
||||
|
||||
class ChatMessage(SQLModel, table=True):
|
||||
"""
|
||||
[AC-AISVC-13] Chat message entity with tenant isolation.
|
||||
Messages are scoped by (tenant_id, session_id) for multi-tenant security.
|
||||
"""
|
||||
|
||||
__tablename__ = "chat_messages"
|
||||
__table_args__ = (
|
||||
Index("ix_chat_messages_tenant_session", "tenant_id", "session_id"),
|
||||
Index("ix_chat_messages_tenant_session_created", "tenant_id", "session_id", "created_at"),
|
||||
)
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
|
||||
session_id: str = Field(..., description="Session ID for conversation tracking", index=True)
|
||||
role: str = Field(..., description="Message role: user or assistant")
|
||||
content: str = Field(..., description="Message content")
|
||||
prompt_tokens: int | None = Field(default=None, description="Number of prompt tokens used")
|
||||
completion_tokens: int | None = Field(default=None, description="Number of completion tokens used")
|
||||
total_tokens: int | None = Field(default=None, description="Total tokens used")
|
||||
latency_ms: int | None = Field(default=None, description="Response latency in milliseconds")
|
||||
first_token_ms: int | None = Field(default=None, description="Time to first token in milliseconds (for streaming)")
|
||||
is_error: bool = Field(default=False, description="Whether this message is an error response")
|
||||
error_message: str | None = Field(default=None, description="Error message if any")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Message creation time")
|
||||
|
||||
|
||||
class ChatSessionCreate(SQLModel):
|
||||
"""Schema for creating a new chat session."""
|
||||
|
||||
tenant_id: str
|
||||
session_id: str
|
||||
channel_type: str | None = None
|
||||
metadata_: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatMessageCreate(SQLModel):
|
||||
"""Schema for creating a new chat message."""
|
||||
|
||||
tenant_id: str
|
||||
session_id: str
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class DocumentStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class IndexJobStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class SessionStatus(str, Enum):
|
||||
ACTIVE = "active"
|
||||
CLOSED = "closed"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class Tenant(SQLModel, table=True):
|
||||
"""
|
||||
[AC-AISVC-10] Tenant entity for storing tenant information.
|
||||
Tenant ID format: name@ash@year (e.g., szmp@ash@2026)
|
||||
"""
|
||||
|
||||
__tablename__ = "tenants"
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
tenant_id: str = Field(..., description="Full tenant ID (format: name@ash@year)", unique=True, index=True)
|
||||
name: str = Field(..., description="Tenant display name (first part of tenant_id)")
|
||||
year: str = Field(..., description="Year part from tenant_id")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
|
||||
|
||||
class KnowledgeBase(SQLModel, table=True):
|
||||
"""
|
||||
[AC-ASA-01] Knowledge base entity with tenant isolation.
|
||||
"""
|
||||
|
||||
__tablename__ = "knowledge_bases"
|
||||
__table_args__ = (
|
||||
Index("ix_knowledge_bases_tenant_id", "tenant_id"),
|
||||
)
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
|
||||
name: str = Field(..., description="Knowledge base name")
|
||||
description: str | None = Field(default=None, description="Knowledge base description")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
|
||||
|
||||
class Document(SQLModel, table=True):
|
||||
"""
|
||||
[AC-ASA-01, AC-ASA-08] Document entity with tenant isolation.
|
||||
"""
|
||||
|
||||
__tablename__ = "documents"
|
||||
__table_args__ = (
|
||||
Index("ix_documents_tenant_kb", "tenant_id", "kb_id"),
|
||||
Index("ix_documents_tenant_status", "tenant_id", "status"),
|
||||
)
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
|
||||
kb_id: str = Field(..., description="Knowledge base ID")
|
||||
file_name: str = Field(..., description="Original file name")
|
||||
file_path: str | None = Field(default=None, description="Storage path")
|
||||
file_size: int | None = Field(default=None, description="File size in bytes")
|
||||
file_type: str | None = Field(default=None, description="File MIME type")
|
||||
status: str = Field(default=DocumentStatus.PENDING.value, description="Document status")
|
||||
error_msg: str | None = Field(default=None, description="Error message if failed")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Upload time")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
|
||||
|
||||
class IndexJob(SQLModel, table=True):
|
||||
"""
|
||||
[AC-ASA-02] Index job entity for tracking document indexing progress.
|
||||
"""
|
||||
|
||||
__tablename__ = "index_jobs"
|
||||
__table_args__ = (
|
||||
Index("ix_index_jobs_tenant_doc", "tenant_id", "doc_id"),
|
||||
Index("ix_index_jobs_tenant_status", "tenant_id", "status"),
|
||||
)
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
|
||||
doc_id: uuid.UUID = Field(..., description="Document ID being indexed")
|
||||
status: str = Field(default=IndexJobStatus.PENDING.value, description="Job status")
|
||||
progress: int = Field(default=0, ge=0, le=100, description="Progress percentage")
|
||||
error_msg: str | None = Field(default=None, description="Error message if failed")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Job creation time")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
|
||||
|
||||
class KnowledgeBaseCreate(SQLModel):
|
||||
"""Schema for creating a new knowledge base."""
|
||||
|
||||
tenant_id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class DocumentCreate(SQLModel):
|
||||
"""Schema for creating a new document."""
|
||||
|
||||
tenant_id: str
|
||||
kb_id: str
|
||||
file_name: str
|
||||
file_path: str | None = None
|
||||
file_size: int | None = None
|
||||
file_type: str | None = None
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""
|
||||
Services module for AI Service.
|
||||
[AC-AISVC-13, AC-AISVC-16] Core services for memory and retrieval.
|
||||
"""
|
||||
|
||||
from app.services.memory import MemoryService
|
||||
from app.services.orchestrator import OrchestratorService, get_orchestrator_service
|
||||
|
||||
__all__ = ["MemoryService", "OrchestratorService", "get_orchestrator_service"]
|
||||
|
|
@ -0,0 +1,224 @@
|
|||
"""
|
||||
Confidence calculation for AI Service.
|
||||
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Confidence scoring and transfer suggestion logic.
|
||||
|
||||
Design reference: design.md Section 4.3 - 检索不中兜底与置信度策略
|
||||
- Retrieval insufficiency detection
|
||||
- Confidence calculation based on retrieval scores
|
||||
- shouldTransfer logic with threshold T_low
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.services.retrieval.base import RetrievalResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfidenceConfig:
|
||||
"""
|
||||
Configuration for confidence calculation.
|
||||
[AC-AISVC-17, AC-AISVC-18] Configurable thresholds.
|
||||
"""
|
||||
score_threshold: float = 0.7
|
||||
min_hits: int = 1
|
||||
confidence_low_threshold: float = 0.5
|
||||
confidence_high_threshold: float = 0.8
|
||||
insufficient_penalty: float = 0.3
|
||||
max_evidence_tokens: int = 2000
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfidenceResult:
|
||||
"""
|
||||
Result of confidence calculation.
|
||||
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Contains confidence and transfer suggestion.
|
||||
"""
|
||||
confidence: float
|
||||
should_transfer: bool
|
||||
transfer_reason: str | None = None
|
||||
is_retrieval_insufficient: bool = False
|
||||
diagnostics: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ConfidenceCalculator:
|
||||
"""
|
||||
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculator for response confidence.
|
||||
|
||||
Design reference: design.md Section 4.3
|
||||
- MVP: confidence based on RAG retrieval scores
|
||||
- Insufficient retrieval triggers confidence downgrade
|
||||
- shouldTransfer when confidence < T_low
|
||||
"""
|
||||
|
||||
def __init__(self, config: ConfidenceConfig | None = None):
|
||||
settings = get_settings()
|
||||
self._config = config or ConfidenceConfig(
|
||||
score_threshold=getattr(settings, "rag_score_threshold", 0.7),
|
||||
min_hits=getattr(settings, "rag_min_hits", 1),
|
||||
confidence_low_threshold=getattr(settings, "confidence_low_threshold", 0.5),
|
||||
confidence_high_threshold=getattr(settings, "confidence_high_threshold", 0.8),
|
||||
insufficient_penalty=getattr(settings, "confidence_insufficient_penalty", 0.3),
|
||||
max_evidence_tokens=getattr(settings, "rag_max_evidence_tokens", 2000),
|
||||
)
|
||||
|
||||
def is_retrieval_insufficient(
|
||||
self,
|
||||
retrieval_result: RetrievalResult,
|
||||
evidence_tokens: int | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
[AC-AISVC-17] Determine if retrieval results are insufficient.
|
||||
|
||||
Conditions for insufficiency:
|
||||
1. hits.size < min_hits
|
||||
2. max(score) < score_threshold
|
||||
3. evidence tokens exceed limit (optional)
|
||||
|
||||
Args:
|
||||
retrieval_result: Result from retrieval operation
|
||||
evidence_tokens: Optional token count for evidence
|
||||
|
||||
Returns:
|
||||
Tuple of (is_insufficient, reason)
|
||||
"""
|
||||
reasons = []
|
||||
|
||||
if retrieval_result.hit_count < self._config.min_hits:
|
||||
reasons.append(
|
||||
f"hit_count({retrieval_result.hit_count}) < min_hits({self._config.min_hits})"
|
||||
)
|
||||
|
||||
if retrieval_result.max_score < self._config.score_threshold:
|
||||
reasons.append(
|
||||
f"max_score({retrieval_result.max_score:.3f}) < threshold({self._config.score_threshold})"
|
||||
)
|
||||
|
||||
if evidence_tokens is not None and evidence_tokens > self._config.max_evidence_tokens:
|
||||
reasons.append(
|
||||
f"evidence_tokens({evidence_tokens}) > max({self._config.max_evidence_tokens})"
|
||||
)
|
||||
|
||||
is_insufficient = len(reasons) > 0
|
||||
reason = "; ".join(reasons) if reasons else "sufficient"
|
||||
|
||||
return is_insufficient, reason
|
||||
|
||||
def calculate_confidence(
|
||||
self,
|
||||
retrieval_result: RetrievalResult,
|
||||
evidence_tokens: int | None = None,
|
||||
additional_factors: dict[str, float] | None = None,
|
||||
) -> ConfidenceResult:
|
||||
"""
|
||||
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculate confidence and transfer suggestion.
|
||||
|
||||
MVP Strategy:
|
||||
1. Base confidence from max retrieval score
|
||||
2. Adjust for hit count (more hits = higher confidence)
|
||||
3. Penalize if retrieval is insufficient
|
||||
4. Determine shouldTransfer based on T_low threshold
|
||||
|
||||
Args:
|
||||
retrieval_result: Result from retrieval operation
|
||||
evidence_tokens: Optional token count for evidence
|
||||
additional_factors: Optional additional confidence factors
|
||||
|
||||
Returns:
|
||||
ConfidenceResult with confidence and transfer suggestion
|
||||
"""
|
||||
is_insufficient, insufficiency_reason = self.is_retrieval_insufficient(
|
||||
retrieval_result, evidence_tokens
|
||||
)
|
||||
|
||||
base_confidence = retrieval_result.max_score
|
||||
|
||||
hit_count_factor = min(1.0, retrieval_result.hit_count / 5.0)
|
||||
confidence = base_confidence * 0.7 + hit_count_factor * 0.3
|
||||
|
||||
if is_insufficient:
|
||||
confidence -= self._config.insufficient_penalty
|
||||
logger.info(
|
||||
f"[AC-AISVC-17] Retrieval insufficient: {insufficiency_reason}, "
|
||||
f"applying penalty -{self._config.insufficient_penalty}"
|
||||
)
|
||||
|
||||
if additional_factors:
|
||||
for factor_name, factor_value in additional_factors.items():
|
||||
confidence += factor_value * 0.1
|
||||
|
||||
confidence = max(0.0, min(1.0, confidence))
|
||||
|
||||
should_transfer = confidence < self._config.confidence_low_threshold
|
||||
transfer_reason = None
|
||||
|
||||
if should_transfer:
|
||||
if is_insufficient:
|
||||
transfer_reason = "检索结果不足,无法提供高置信度回答"
|
||||
else:
|
||||
transfer_reason = "置信度低于阈值,建议转人工"
|
||||
elif confidence < self._config.confidence_high_threshold and is_insufficient:
|
||||
transfer_reason = "检索结果有限,回答可能不够准确"
|
||||
|
||||
diagnostics = {
|
||||
"base_confidence": base_confidence,
|
||||
"hit_count": retrieval_result.hit_count,
|
||||
"max_score": retrieval_result.max_score,
|
||||
"is_insufficient": is_insufficient,
|
||||
"insufficiency_reason": insufficiency_reason if is_insufficient else None,
|
||||
"penalty_applied": self._config.insufficient_penalty if is_insufficient else 0.0,
|
||||
"threshold_low": self._config.confidence_low_threshold,
|
||||
"threshold_high": self._config.confidence_high_threshold,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-17, AC-AISVC-18] Confidence calculated: "
|
||||
f"{confidence:.3f}, should_transfer={should_transfer}, "
|
||||
f"insufficient={is_insufficient}"
|
||||
)
|
||||
|
||||
return ConfidenceResult(
|
||||
confidence=round(confidence, 3),
|
||||
should_transfer=should_transfer,
|
||||
transfer_reason=transfer_reason,
|
||||
is_retrieval_insufficient=is_insufficient,
|
||||
diagnostics=diagnostics,
|
||||
)
|
||||
|
||||
def calculate_confidence_no_retrieval(self) -> ConfidenceResult:
|
||||
"""
|
||||
[AC-AISVC-17] Calculate confidence when no retrieval was performed.
|
||||
|
||||
Returns a low confidence result suggesting transfer.
|
||||
"""
|
||||
return ConfidenceResult(
|
||||
confidence=0.3,
|
||||
should_transfer=True,
|
||||
transfer_reason="未进行知识库检索,建议转人工",
|
||||
is_retrieval_insufficient=True,
|
||||
diagnostics={
|
||||
"base_confidence": 0.0,
|
||||
"hit_count": 0,
|
||||
"max_score": 0.0,
|
||||
"is_insufficient": True,
|
||||
"insufficiency_reason": "no_retrieval",
|
||||
"penalty_applied": 0.0,
|
||||
"threshold_low": self._config.confidence_low_threshold,
|
||||
"threshold_high": self._config.confidence_high_threshold,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
_confidence_calculator: ConfidenceCalculator | None = None
|
||||
|
||||
|
||||
def get_confidence_calculator() -> ConfidenceCalculator:
|
||||
"""Get or create confidence calculator instance."""
|
||||
global _confidence_calculator
|
||||
if _confidence_calculator is None:
|
||||
_confidence_calculator = ConfidenceCalculator()
|
||||
return _confidence_calculator
|
||||
|
|
@ -0,0 +1,245 @@
|
|||
"""
|
||||
Context management utilities for AI Service.
|
||||
[AC-AISVC-14, AC-AISVC-15] Context merging and truncation strategies.
|
||||
|
||||
Design reference: design.md Section 7 - 上下文合并规则
|
||||
- H_local: Memory layer history (sorted by time)
|
||||
- H_ext: External history from Java request (in passed order)
|
||||
- Deduplication: fingerprint = hash(role + "|" + normalized(content))
|
||||
- Truncation: Keep most recent N messages within token budget
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.models import ChatMessage, Role
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MergedContext:
|
||||
"""
|
||||
Result of context merging.
|
||||
[AC-AISVC-14, AC-AISVC-15] Contains merged messages and diagnostics.
|
||||
"""
|
||||
messages: list[dict[str, str]] = field(default_factory=list)
|
||||
total_tokens: int = 0
|
||||
local_count: int = 0
|
||||
external_count: int = 0
|
||||
duplicates_skipped: int = 0
|
||||
truncated_count: int = 0
|
||||
diagnostics: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class ContextMerger:
    """
    [AC-AISVC-14, AC-AISVC-15] Context merger for combining local and external history.

    Design reference: design.md Section 7
    - Deduplication based on message fingerprint
    - Priority: local history takes precedence
    - Token-based truncation using tiktoken
    """

    # Approximate token overhead added per message for chat framing.
    _MESSAGE_OVERHEAD_TOKENS = 4

    def __init__(
        self,
        max_history_tokens: int | None = None,
        encoding_name: str = "cl100k_base",
    ):
        """
        Args:
            max_history_tokens: Token budget for merged history; defaults to 4096.
            encoding_name: Name of the tiktoken encoding used for counting.
        """
        # NOTE(review): the original called get_settings() but never used the
        # result while hard-coding the 4096 default. Presumably the budget was
        # meant to come from settings -- confirm and wire it up if so.
        self._max_history_tokens = max_history_tokens or 4096
        self._encoding = tiktoken.get_encoding(encoding_name)

    def compute_fingerprint(self, role: str, content: str) -> str:
        """
        Compute message fingerprint for deduplication.
        [AC-AISVC-15] fingerprint = hash(role + "|" + normalized(content))

        Args:
            role: Message role (user/assistant)
            content: Message content

        Returns:
            SHA256 hash of the normalized message
        """
        normalized_content = content.strip()
        fingerprint_input = f"{role}|{normalized_content}"
        return hashlib.sha256(fingerprint_input.encode("utf-8")).hexdigest()

    def _message_to_dict(self, message: ChatMessage | dict[str, str]) -> dict[str, str]:
        """Convert ChatMessage or dict to standard dict format."""
        if isinstance(message, ChatMessage):
            return {"role": message.role.value, "content": message.content}
        return message

    def _count_message_tokens(self, msg: dict[str, str]) -> int:
        """Token cost of one message: role + content + framing overhead."""
        tokens = len(self._encoding.encode(msg.get("role", "")))
        tokens += len(self._encoding.encode(msg.get("content", "")))
        return tokens + self._MESSAGE_OVERHEAD_TOKENS

    def _count_tokens(self, messages: list[dict[str, str]]) -> int:
        """
        Count total tokens in messages using tiktoken.
        [AC-AISVC-14] Token counting for history truncation.
        """
        return sum(self._count_message_tokens(msg) for msg in messages)

    def merge_context(
        self,
        local_history: list[ChatMessage] | list[dict[str, str]] | None,
        external_history: list[ChatMessage] | list[dict[str, str]] | None,
    ) -> MergedContext:
        """
        Merge local and external history with deduplication.
        [AC-AISVC-14, AC-AISVC-15] Implements context merging strategy.

        Design reference: design.md Section 7.2
        1. Build seen set from H_local
        2. Traverse H_ext, append if fingerprint not seen
        3. Local history takes priority

        Args:
            local_history: History from Memory layer (H_local)
            external_history: History from Java request (H_ext)

        Returns:
            MergedContext with merged messages and diagnostics
        """
        result = MergedContext()
        seen_fingerprints: set[str] = set()
        merged_messages: list[dict[str, str]] = []
        diagnostics: list[dict[str, Any]] = []

        local_messages = [self._message_to_dict(m) for m in (local_history or [])]
        external_messages = [self._message_to_dict(m) for m in (external_history or [])]

        # Local history is authoritative: keep every message and record its
        # fingerprint so external duplicates can be detected below.
        for msg in local_messages:
            seen_fingerprints.add(self.compute_fingerprint(msg["role"], msg["content"]))
            merged_messages.append(msg)
            result.local_count += 1

        # External messages are appended only when not already present.
        for msg in external_messages:
            fingerprint = self.compute_fingerprint(msg["role"], msg["content"])
            if fingerprint not in seen_fingerprints:
                seen_fingerprints.add(fingerprint)
                merged_messages.append(msg)
                result.external_count += 1
            else:
                result.duplicates_skipped += 1
                diagnostics.append({
                    "type": "duplicate_skipped",
                    "role": msg["role"],
                    "content_preview": msg["content"][:50] + "..." if len(msg["content"]) > 50 else msg["content"],
                })

        result.messages = merged_messages
        result.diagnostics = diagnostics
        result.total_tokens = self._count_tokens(merged_messages)

        logger.info(
            f"[AC-AISVC-14, AC-AISVC-15] Context merged: "
            f"local={result.local_count}, external={result.external_count}, "
            f"duplicates_skipped={result.duplicates_skipped}, "
            f"total_tokens={result.total_tokens}"
        )

        return result

    def truncate_context(
        self,
        messages: list[dict[str, str]],
        max_tokens: int | None = None,
    ) -> tuple[list[dict[str, str]], int]:
        """
        Truncate context to fit within token budget.
        [AC-AISVC-14] Keep most recent N messages within budget.

        Design reference: design.md Section 7.4
        - Budget = maxHistoryTokens (configurable)
        - Strategy: Keep most recent messages (from tail backward)

        Args:
            messages: List of messages to truncate
            max_tokens: Maximum token budget (uses default if not provided)

        Returns:
            Tuple of (truncated messages, truncated count)
        """
        budget = max_tokens or self._max_history_tokens
        if not messages:
            return [], 0

        if self._count_tokens(messages) <= budget:
            return messages, 0

        kept: list[dict[str, str]] = []
        current_tokens = 0

        # Walk from the newest message backward and STOP at the first message
        # that does not fit, so the kept window is a contiguous run of the
        # most recent messages. (The previous implementation skipped the
        # oversized message and kept scanning, which could drop a recent
        # message while retaining older ones -- leaving gaps in the dialogue.)
        for msg in reversed(messages):
            msg_tokens = self._count_message_tokens(msg)
            if current_tokens + msg_tokens > budget:
                break
            kept.insert(0, msg)
            current_tokens += msg_tokens

        truncated_count = len(messages) - len(kept)

        logger.info(
            f"[AC-AISVC-14] Context truncated: "
            f"original={len(messages)}, truncated={len(kept)}, "
            f"removed={truncated_count}, tokens={current_tokens}/{budget}"
        )

        return kept, truncated_count

    def merge_and_truncate(
        self,
        local_history: list[ChatMessage] | list[dict[str, str]] | None,
        external_history: list[ChatMessage] | list[dict[str, str]] | None,
        max_tokens: int | None = None,
    ) -> MergedContext:
        """
        Merge and truncate context in one operation.
        [AC-AISVC-14, AC-AISVC-15] Complete context preparation pipeline.

        Args:
            local_history: History from Memory layer (H_local)
            external_history: History from Java request (H_ext)
            max_tokens: Maximum token budget

        Returns:
            MergedContext with final messages after merge and truncate
        """
        merged = self.merge_context(local_history, external_history)

        truncated_messages, truncated_count = self.truncate_context(
            merged.messages, max_tokens
        )

        merged.messages = truncated_messages
        merged.truncated_count = truncated_count
        merged.total_tokens = self._count_tokens(truncated_messages)

        return merged
|
||||
|
||||
|
||||
# Lazily-created process-wide merger instance.
_context_merger: ContextMerger | None = None


def get_context_merger() -> ContextMerger:
    """Return the process-wide ContextMerger singleton, creating it on first use."""
    global _context_merger
    merger = _context_merger
    if merger is None:
        merger = ContextMerger()
        _context_merger = merger
    return merger
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
"""
|
||||
Document parsing services package.
|
||||
[AC-AISVC-33] Provides document parsers for various formats.
|
||||
"""
|
||||
|
||||
from app.services.document.base import (
|
||||
DocumentParseException,
|
||||
DocumentParser,
|
||||
PageText,
|
||||
ParseResult,
|
||||
UnsupportedFormatError,
|
||||
)
|
||||
from app.services.document.excel_parser import CSVParser, ExcelParser
|
||||
from app.services.document.factory import (
|
||||
DocumentParserFactory,
|
||||
get_supported_document_formats,
|
||||
parse_document,
|
||||
)
|
||||
from app.services.document.pdf_parser import PDFParser, PDFPlumberParser
|
||||
from app.services.document.text_parser import TextParser
|
||||
from app.services.document.word_parser import WordParser
|
||||
|
||||
# Public API of the document-parsing package, grouped by concern.
__all__ = [
    # Core abstractions and result types (base.py)
    "DocumentParseException",
    "DocumentParser",
    "PageText",
    "ParseResult",
    "UnsupportedFormatError",
    # Factory entry points (factory.py)
    "DocumentParserFactory",
    "get_supported_document_formats",
    "parse_document",
    # Concrete parsers
    "PDFParser",
    "PDFPlumberParser",
    "WordParser",
    "ExcelParser",
    "CSVParser",
    "TextParser",
]
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
"""
|
||||
Base document parser interface.
|
||||
[AC-AISVC-33] Abstract interface for document parsers.
|
||||
|
||||
Design reference: progress.md Section 7.2 - DocumentParser interface
|
||||
- parse(file_path) -> str
|
||||
- get_supported_extensions() -> list[str]
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
class PageText:
    """
    Text content from a single page.
    """
    # 1-based page number (PDF parsers construct this with page_num + 1).
    page: int
    # Extracted, stripped text for that page.
    text: str
|
||||
|
||||
|
||||
@dataclass
class ParseResult:
    """
    Result from document parsing.
    [AC-AISVC-33] Contains parsed text and metadata.
    """
    # Full extracted text; for Excel/CSV parsers this is a JSON string.
    text: str
    # Path of the parsed source file, as a string.
    source_path: str
    # Size of the source file in bytes (from Path.stat()).
    file_size: int
    # Page count for paged formats (PDF); None for others.
    page_count: int | None = None
    # Parser-specific metadata, e.g. {"format": "pdf", "page_count": ...}.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Per-page text, populated by the PDF parsers; empty otherwise.
    pages: list[PageText] = field(default_factory=list)
|
||||
|
||||
|
||||
class DocumentParser(ABC):
    """
    Abstract base class for document parsers.
    [AC-AISVC-33] Provides unified interface for different document formats.

    Concrete parsers implement parse() and get_supported_extensions();
    supports_extension() is a shared convenience built on the latter.
    """

    @abstractmethod
    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse a document and extract text content.
        [AC-AISVC-33] Returns parsed text content.

        Args:
            file_path: Path to the document file.

        Returns:
            ParseResult with extracted text and metadata.

        Raises:
            DocumentParseException: If parsing fails.
        """
        ...

    @abstractmethod
    def get_supported_extensions(self) -> list[str]:
        """
        Get list of supported file extensions.
        [AC-AISVC-37] Returns supported format list.

        Returns:
            List of file extensions (e.g., [".pdf", ".txt"])
        """
        ...

    def supports_extension(self, extension: str) -> bool:
        """
        Check if this parser supports a given file extension.
        [AC-AISVC-37] Validates file format support.

        The check is case-insensitive and tolerates a missing leading dot.

        Args:
            extension: File extension to check.

        Returns:
            True if extension is supported.
        """
        ext = extension.lower()
        if not ext.startswith("."):
            ext = "." + ext
        return ext in self.get_supported_extensions()
|
||||
|
||||
|
||||
class DocumentParseException(Exception):
|
||||
"""Exception raised when document parsing fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
file_path: str = "",
|
||||
parser: str = "",
|
||||
details: dict[str, Any] | None = None
|
||||
):
|
||||
self.file_path = file_path
|
||||
self.parser = parser
|
||||
self.details = details or {}
|
||||
super().__init__(f"[{parser}] {message}" if parser else message)
|
||||
|
||||
|
||||
class UnsupportedFormatError(DocumentParseException):
    """
    Raised when a file extension has no registered parser.

    Exposes the rejected extension and the list of supported formats.
    """

    def __init__(self, extension: str, supported: list[str]):
        supported_list = ", ".join(supported)
        super().__init__(
            f"Unsupported file format: {extension}. "
            f"Supported formats: {supported_list}",
            parser="format_checker"
        )
        self.extension = extension
        self.supported_formats = supported
|
||||
|
|
@ -0,0 +1,273 @@
|
|||
"""
|
||||
Excel document parser implementation.
|
||||
[AC-AISVC-35] Excel (.xlsx) parsing using openpyxl.
|
||||
|
||||
Extracts text content from Excel spreadsheets and converts to JSON format
|
||||
to preserve structural relationships for better RAG retrieval.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.services.document.base import (
|
||||
DocumentParseException,
|
||||
DocumentParser,
|
||||
ParseResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExcelParser(DocumentParser):
    """
    Parser for Excel documents.
    [AC-AISVC-35] Uses openpyxl for text extraction.
    Converts spreadsheet data to JSON format to preserve structure.
    """

    def __init__(
        self,
        include_empty_cells: bool = False,
        max_rows_per_sheet: int = 10000,
        **kwargs: Any
    ):
        """
        Args:
            include_empty_cells: When True, keep None-valued cells (and
                all-empty rows) in the output records.
            max_rows_per_sheet: Cap on rows read per sheet (header included)
                to bound memory on huge workbooks.
            **kwargs: Extra configuration, stored but currently unused.
        """
        self._include_empty_cells = include_empty_cells
        self._max_rows_per_sheet = max_rows_per_sheet
        self._extra_config = kwargs
        self._openpyxl = None

    def _get_openpyxl(self):
        """Lazy import of openpyxl so the dependency is only needed on use."""
        if self._openpyxl is None:
            try:
                import openpyxl
                self._openpyxl = openpyxl
            except ImportError:
                raise DocumentParseException(
                    "openpyxl not installed. Install with: pip install openpyxl",
                    parser="excel"
                )
        return self._openpyxl

    def _sheet_to_records(self, sheet, sheet_name: str) -> list[dict[str, Any]]:
        """
        Convert a worksheet to a list of record dictionaries.
        First row is treated as header (column names); blank header cells
        become "column_<i>". Each record carries its sheet name in "_sheet".
        """
        records = []
        rows = list(sheet.iter_rows(max_row=self._max_rows_per_sheet, values_only=True))

        if not rows:
            return records

        headers = rows[0]
        header_list = [str(h) if h is not None else f"column_{i}" for i, h in enumerate(headers)]

        for row in rows[1:]:
            record = {"_sheet": sheet_name}
            has_content = False

            for i, value in enumerate(row):
                key = header_list[i] if i < len(header_list) else f"column_{i}"

                if value is not None:
                    has_content = True
                    # Keep JSON-native scalars as-is; stringify everything
                    # else (dates, etc.).
                    if isinstance(value, (int, float, bool)):
                        record[key] = value
                    else:
                        record[key] = str(value)
                elif self._include_empty_cells:
                    record[key] = None

            if has_content or self._include_empty_cells:
                records.append(record)

        return records

    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse an Excel document and extract text content as JSON.
        [AC-AISVC-35] Converts spreadsheet data to JSON format.

        Raises:
            DocumentParseException: If the file is missing, has an
                unsupported extension, or openpyxl fails to read it.
        """
        path = Path(file_path)

        if not path.exists():
            raise DocumentParseException(
                f"File not found: {path}",
                file_path=str(path),
                parser="excel"
            )

        if not self.supports_extension(path.suffix):
            raise DocumentParseException(
                f"Unsupported file extension: {path.suffix}",
                file_path=str(path),
                parser="excel"
            )

        openpyxl = self._get_openpyxl()

        try:
            workbook = openpyxl.load_workbook(path, read_only=True, data_only=True)
            try:
                all_records: list[dict[str, Any]] = []
                sheet_count = len(workbook.sheetnames)
                total_rows = 0

                for sheet_name in workbook.sheetnames:
                    sheet = workbook[sheet_name]
                    records = self._sheet_to_records(sheet, sheet_name)
                    all_records.extend(records)
                    total_rows += len(records)
            finally:
                # Fix: the original only closed the workbook on the success
                # path, leaking the file handle when a sheet conversion raised.
                workbook.close()

            json_str = json.dumps(all_records, ensure_ascii=False, indent=2)
            file_size = path.stat().st_size

            logger.info(
                f"Parsed Excel (JSON): {path.name}, sheets={sheet_count}, "
                f"rows={total_rows}, chars={len(json_str)}, size={file_size}"
            )

            return ParseResult(
                text=json_str,
                source_path=str(path),
                file_size=file_size,
                metadata={
                    "format": "xlsx",
                    "output_format": "json",
                    "sheet_count": sheet_count,
                    "total_rows": total_rows,
                }
            )

        except DocumentParseException:
            raise
        except Exception as e:
            raise DocumentParseException(
                f"Failed to parse Excel document: {e}",
                file_path=str(path),
                parser="excel",
                details={"error": str(e)}
            )

    def get_supported_extensions(self) -> list[str]:
        """Get supported file extensions."""
        # NOTE(review): openpyxl reads .xlsx only; legacy .xls files routed
        # here will fail at load time. Kept for interface compatibility --
        # confirm whether .xls should really be advertised.
        return [".xlsx", ".xls"]
|
||||
|
||||
|
||||
class CSVParser(DocumentParser):
    """
    Parser for CSV files.
    [AC-AISVC-35] Uses Python's built-in csv module.
    Converts CSV data to JSON format to preserve structure.
    """

    def __init__(self, delimiter: str = ",", encoding: str = "utf-8", **kwargs: Any):
        """
        Args:
            delimiter: Field delimiter passed to csv.reader.
            encoding: Preferred text encoding; GBK is tried as a fallback
                on decode errors.
            **kwargs: Extra configuration, stored but currently unused.
        """
        self._delimiter = delimiter
        self._encoding = encoding
        self._extra_config = kwargs

    def _parse_csv_to_records(self, path: Path, encoding: str) -> list[dict[str, Any]]:
        """
        Parse CSV file and return list of record dictionaries.

        The first row is treated as the header; blank header cells become
        "column_<i>". Rows whose cells are all empty are dropped.
        """
        import csv

        records: list[dict[str, Any]] = []

        with open(path, "r", encoding=encoding, newline="") as f:
            reader = csv.reader(f, delimiter=self._delimiter)
            rows = list(reader)

        if not rows:
            return records

        headers = rows[0]
        header_list = [str(h) if h else f"column_{i}" for i, h in enumerate(headers)]

        for row in rows[1:]:
            record: dict[str, Any] = {}
            has_content = False

            for i, value in enumerate(row):
                key = header_list[i] if i < len(header_list) else f"column_{i}"
                if value:
                    has_content = True
                    record[key] = value

            if has_content:
                records.append(record)

        return records

    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse a CSV file and extract text content as JSON.
        [AC-AISVC-35] Converts CSV data to JSON format.

        Raises:
            DocumentParseException: If the file is missing or cannot be
                decoded/parsed with the preferred or fallback encoding.
        """
        path = Path(file_path)

        if not path.exists():
            raise DocumentParseException(
                f"File not found: {path}",
                file_path=str(path),
                parser="csv"
            )

        try:
            records = self._parse_csv_to_records(path, self._encoding)
            used_encoding = self._encoding
        except UnicodeDecodeError:
            # Fallback for files written by Chinese-locale tools.
            # NOTE(review): when self._encoding is already "gbk" this retries
            # the same encoding and fails identically -- harmless but
            # pointless; kept so error messages stay unchanged.
            try:
                records = self._parse_csv_to_records(path, "gbk")
                used_encoding = "gbk"
            except Exception as e:
                raise DocumentParseException(
                    f"Failed to parse CSV with encoding fallback: {e}",
                    file_path=str(path),
                    parser="csv",
                    details={"error": str(e)}
                )
        except Exception as e:
            raise DocumentParseException(
                f"Failed to parse CSV: {e}",
                file_path=str(path),
                parser="csv",
                details={"error": str(e)}
            )

        # Computed once here; the original duplicated this in both success
        # branches above.
        row_count = len(records)

        json_str = json.dumps(records, ensure_ascii=False, indent=2)
        file_size = path.stat().st_size

        logger.info(
            f"Parsed CSV (JSON): {path.name}, rows={row_count}, "
            f"chars={len(json_str)}, size={file_size}"
        )

        return ParseResult(
            text=json_str,
            source_path=str(path),
            file_size=file_size,
            metadata={
                "format": "csv",
                "output_format": "json",
                "row_count": row_count,
                "delimiter": self._delimiter,
                "encoding": used_encoding,
            }
        )

    def get_supported_extensions(self) -> list[str]:
        """Get supported file extensions."""
        return [".csv"]
|
||||
|
|
@ -0,0 +1,215 @@
|
|||
"""
|
||||
Document parser factory.
|
||||
[AC-AISVC-33, AC-AISVC-34, AC-AISVC-35] Factory for document parsers.
|
||||
|
||||
Design reference: progress.md Section 7.2 - DocumentParserFactory
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Type
|
||||
|
||||
from app.services.document.base import (
|
||||
DocumentParser,
|
||||
DocumentParseException,
|
||||
ParseResult,
|
||||
UnsupportedFormatError,
|
||||
)
|
||||
from app.services.document.excel_parser import CSVParser, ExcelParser
|
||||
from app.services.document.pdf_parser import PDFParser, PDFPlumberParser
|
||||
from app.services.document.text_parser import TextParser
|
||||
from app.services.document.word_parser import WordParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentParserFactory:
    """
    Factory for creating document parsers.
    [AC-AISVC-33, AC-AISVC-34, AC-AISVC-35] Auto-selects parser based on file extension.
    """

    # Registry: parser name -> parser class.
    _parsers: dict[str, Type[DocumentParser]] = {}
    # Routing table: lowercase extension (with dot) -> parser name.
    _extension_map: dict[str, str] = {}

    # Loop-invariant metadata for get_parser_info(); the original rebuilt
    # both dicts on every loop iteration.
    _DISPLAY_NAMES = {
        "pdf": "PDF 文档",
        "pdfplumber": "PDF 文档 (pdfplumber)",
        "word": "Word 文档",
        "excel": "Excel 电子表格",
        "csv": "CSV 文件",
        "text": "文本文件",
    }
    _DESCRIPTIONS = {
        "pdf": "使用 PyMuPDF 解析 PDF 文档,速度快",
        "pdfplumber": "使用 pdfplumber 解析 PDF 文档,表格提取效果更好",
        "word": "解析 Word 文档 (.docx),保留段落结构",
        "excel": "解析 Excel 电子表格,支持多工作表",
        "csv": "解析 CSV 文件,自动检测编码",
        "text": "解析纯文本文件,支持多种编码",
    }

    @classmethod
    def _initialize(cls) -> None:
        """Populate the default parser registry once (idempotent)."""
        if cls._parsers:
            return

        cls._parsers = {
            "pdf": PDFParser,
            "pdfplumber": PDFPlumberParser,
            "word": WordParser,
            "excel": ExcelParser,
            "csv": CSVParser,
            "text": TextParser,
        }

        cls._extension_map = {
            ".pdf": "pdf",
            ".docx": "word",
            ".xlsx": "excel",
            ".xls": "excel",
            ".csv": "csv",
            ".txt": "text",
            ".md": "text",
            ".markdown": "text",
            ".rst": "text",
            ".log": "text",
            ".json": "text",
            ".xml": "text",
            ".yaml": "text",
            ".yml": "text",
        }

    @classmethod
    def register_parser(
        cls,
        name: str,
        parser_class: Type[DocumentParser],
        extensions: list[str],
    ) -> None:
        """
        Register a new document parser.
        [AC-AISVC-33] Allows runtime registration of parsers.

        Args:
            name: Registry key for the parser.
            parser_class: Class implementing DocumentParser.
            extensions: File extensions routed to this parser (lowercased).
        """
        cls._initialize()
        cls._parsers[name] = parser_class
        for ext in extensions:
            cls._extension_map[ext.lower()] = name
        logger.info(f"Registered document parser: {name} for extensions: {extensions}")

    @classmethod
    def get_supported_extensions(cls) -> list[str]:
        """
        Get all supported file extensions.
        [AC-AISVC-37] Returns list of supported formats.
        """
        cls._initialize()
        return list(cls._extension_map.keys())

    @classmethod
    def _resolve_parser_class(cls, extension: str) -> Type[DocumentParser]:
        """
        Map a file extension to its registered parser class.

        Raises:
            UnsupportedFormatError: If no parser is registered for it.
        """
        cls._initialize()

        normalized = extension.lower()
        if not normalized.startswith("."):
            normalized = f".{normalized}"

        if normalized not in cls._extension_map:
            raise UnsupportedFormatError(normalized, cls.get_supported_extensions())

        return cls._parsers[cls._extension_map[normalized]]

    @classmethod
    def get_parser_for_extension(cls, extension: str) -> DocumentParser:
        """
        Get a parser instance for a file extension.
        [AC-AISVC-33] Creates appropriate parser based on extension.
        """
        return cls._resolve_parser_class(extension)()

    @classmethod
    def parse_file(
        cls,
        file_path: str | Path,
        parser_name: str | None = None,
        parser_config: dict[str, Any] | None = None,
    ) -> ParseResult:
        """
        Parse a document file.
        [AC-AISVC-33, AC-AISVC-34, AC-AISVC-35] Main entry point for parsing.

        Args:
            file_path: Path to the document file
            parser_name: Optional specific parser to use
            parser_config: Optional configuration for the parser

        Returns:
            ParseResult with extracted text and metadata

        Raises:
            UnsupportedFormatError: If file format is not supported
            DocumentParseException: If parsing fails
        """
        cls._initialize()

        path = Path(file_path)

        if parser_name:
            if parser_name not in cls._parsers:
                raise DocumentParseException(
                    f"Unknown parser: {parser_name}",
                    file_path=str(path),
                    parser="factory"
                )
            parser_class = cls._parsers[parser_name]
        else:
            # Resolve the class first so a config-carrying parser is built
            # exactly once (the original instantiated the auto-detected
            # parser twice when parser_config was supplied).
            parser_class = cls._resolve_parser_class(path.suffix.lower())

        parser = parser_class(**(parser_config or {}))
        return parser.parse(path)

    @classmethod
    def get_parser_info(cls) -> list[dict[str, Any]]:
        """
        Get information about available parsers.
        [AC-AISVC-37] Returns parser metadata.
        """
        cls._initialize()

        info = []
        for name, parser_class in cls._parsers.items():
            # __new__ bypasses __init__ so extensions can be queried without
            # constructing parser dependencies; get_supported_extensions()
            # reads no instance state on any registered parser.
            temp_instance = parser_class.__new__(parser_class)
            extensions = temp_instance.get_supported_extensions()

            info.append({
                "name": name,
                "display_name": cls._DISPLAY_NAMES.get(name, name),
                "description": cls._DESCRIPTIONS.get(name, ""),
                "extensions": extensions,
            })

        return info
|
||||
|
||||
|
||||
def parse_document(
    file_path: str | Path,
    parser_name: str | None = None,
    parser_config: dict[str, Any] | None = None,
) -> ParseResult:
    """
    Parse a document through the factory.
    [AC-AISVC-33] Simple entry point for document parsing.

    Args:
        file_path: Path to the document.
        parser_name: Optional explicit parser name; auto-detected when None.
        parser_config: Optional keyword arguments for the parser constructor.

    Returns:
        ParseResult produced by the selected parser.
    """
    return DocumentParserFactory.parse_file(file_path, parser_name, parser_config)
|
||||
|
||||
|
||||
def get_supported_document_formats() -> list[str]:
    """
    List every file extension the registered document parsers accept.
    [AC-AISVC-37] Returns supported format extensions.
    """
    return DocumentParserFactory.get_supported_extensions()
|
||||
|
|
@ -0,0 +1,229 @@
|
|||
"""
|
||||
PDF document parser implementation.
|
||||
[AC-AISVC-33] PDF parsing using PyMuPDF (fitz).
|
||||
|
||||
Extracts text content from PDF files.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.services.document.base import (
|
||||
DocumentParseException,
|
||||
DocumentParser,
|
||||
PageText,
|
||||
ParseResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PDFParser(DocumentParser):
    """
    Parser for PDF documents.
    [AC-AISVC-33] Uses PyMuPDF for text extraction.
    """

    def __init__(self, extract_images: bool = False, **kwargs: Any):
        """
        Args:
            extract_images: Stored flag; image extraction is not used by
                this implementation yet.
            **kwargs: Extra configuration, stored but currently unused.
        """
        self._extract_images = extract_images
        self._extra_config = kwargs
        self._fitz = None

    def _get_fitz(self):
        """Lazy import of PyMuPDF so the dependency is only needed on use."""
        if self._fitz is None:
            try:
                import fitz
                self._fitz = fitz
            except ImportError:
                raise DocumentParseException(
                    "PyMuPDF (fitz) not installed. Install with: pip install pymupdf",
                    parser="pdf"
                )
        return self._fitz

    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse a PDF document and extract text content.
        [AC-AISVC-33] Extracts text from all pages.

        Raises:
            DocumentParseException: If the file is missing, has an
                unsupported extension, or PyMuPDF fails to read it.
        """
        path = Path(file_path)

        if not path.exists():
            raise DocumentParseException(
                f"File not found: {path}",
                file_path=str(path),
                parser="pdf"
            )

        if not self.supports_extension(path.suffix):
            raise DocumentParseException(
                f"Unsupported file extension: {path.suffix}",
                file_path=str(path),
                parser="pdf"
            )

        fitz = self._get_fitz()

        try:
            doc = fitz.open(path)
            try:
                pages: list[PageText] = []
                text_parts = []
                page_count = len(doc)

                for page_num in range(page_count):
                    page = doc[page_num]
                    text = page.get_text().strip()
                    if text:
                        # Pages are reported 1-based to callers.
                        pages.append(PageText(page=page_num + 1, text=text))
                        text_parts.append(f"[Page {page_num + 1}]\n{text}")
            finally:
                # Fix: the original only closed the document on the success
                # path, leaking the file handle if a page raised mid-loop.
                doc.close()

            full_text = "\n\n".join(text_parts)
            file_size = path.stat().st_size

            logger.info(
                f"Parsed PDF: {path.name}, pages={page_count}, "
                f"chars={len(full_text)}, size={file_size}"
            )

            return ParseResult(
                text=full_text,
                source_path=str(path),
                file_size=file_size,
                page_count=page_count,
                metadata={
                    "format": "pdf",
                    "page_count": page_count,
                },
                pages=pages,
            )

        except DocumentParseException:
            raise
        except Exception as e:
            raise DocumentParseException(
                f"Failed to parse PDF: {e}",
                file_path=str(path),
                parser="pdf",
                details={"error": str(e)}
            )

    def get_supported_extensions(self) -> list[str]:
        """Get supported file extensions."""
        return [".pdf"]
|
||||
|
||||
|
||||
class PDFPlumberParser(DocumentParser):
    """
    Alternative PDF parser using pdfplumber.
    [AC-AISVC-33] Uses pdfplumber for text extraction.

    pdfplumber is better for table extraction but slower than PyMuPDF.
    """

    def __init__(self, extract_tables: bool = True, **kwargs: Any):
        # When True, tables found on a page are appended (pipe-formatted)
        # after that page's plain text.
        self._extract_tables = extract_tables
        # Extra configuration, stored but currently unused.
        self._extra_config = kwargs
        # Cached pdfplumber module; filled lazily by _get_pdfplumber().
        self._pdfplumber = None

    def _get_pdfplumber(self):
        """Lazy import of pdfplumber."""
        # Import only on first use so the dependency is optional until a
        # PDF is actually parsed with this backend.
        if self._pdfplumber is None:
            try:
                import pdfplumber
                self._pdfplumber = pdfplumber
            except ImportError:
                raise DocumentParseException(
                    "pdfplumber not installed. Install with: pip install pdfplumber",
                    parser="pdfplumber"
                )
        return self._pdfplumber

    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse a PDF document and extract text content.
        [AC-AISVC-33] Extracts text and optionally tables.

        Args:
            file_path: Path to the PDF file.

        Returns:
            ParseResult with per-page text (1-based page numbers) and the
            concatenated full text.

        Raises:
            DocumentParseException: If the file is missing or pdfplumber
                fails to read it.
        """
        path = Path(file_path)

        if not path.exists():
            raise DocumentParseException(
                f"File not found: {path}",
                file_path=str(path),
                parser="pdfplumber"
            )

        pdfplumber = self._get_pdfplumber()

        try:
            pages: list[PageText] = []
            text_parts = []
            page_count = 0

            # Context manager guarantees the PDF handle is released even
            # if extraction raises mid-document.
            with pdfplumber.open(path) as pdf:
                page_count = len(pdf.pages)

                for page_num, page in enumerate(pdf.pages):
                    # extract_text() may return None for image-only pages.
                    text = page.extract_text() or ""

                    if self._extract_tables:
                        tables = page.extract_tables()
                        for table in tables:
                            table_text = self._format_table(table)
                            text += f"\n\n{table_text}"

                    text = text.strip()
                    # Pages with no extractable content are dropped entirely.
                    if text:
                        pages.append(PageText(page=page_num + 1, text=text))
                        text_parts.append(f"[Page {page_num + 1}]\n{text}")

            full_text = "\n\n".join(text_parts)
            file_size = path.stat().st_size

            logger.info(
                f"Parsed PDF (pdfplumber): {path.name}, pages={page_count}, "
                f"chars={len(full_text)}, size={file_size}"
            )

            return ParseResult(
                text=full_text,
                source_path=str(path),
                file_size=file_size,
                page_count=page_count,
                metadata={
                    "format": "pdf",
                    "parser": "pdfplumber",
                    "page_count": page_count,
                },
                pages=pages,
            )

        # Re-raise our own exception type untouched; wrap everything else.
        except DocumentParseException:
            raise
        except Exception as e:
            raise DocumentParseException(
                f"Failed to parse PDF: {e}",
                file_path=str(path),
                parser="pdfplumber",
                details={"error": str(e)}
            )

    def _format_table(self, table: list[list[str | None]]) -> str:
        """Format a table as text: one row per line, cells joined by ' | '.

        None (and any other falsy) cells are rendered as empty strings.
        """
        if not table:
            return ""

        lines = []
        for row in table:
            cells = [str(cell) if cell else "" for cell in row]
            lines.append(" | ".join(cells))

        return "\n".join(lines)

    def get_supported_extensions(self) -> list[str]:
        """Get supported file extensions."""
        return [".pdf"]
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
"""
|
||||
Text file parser implementation.
|
||||
[AC-AISVC-33] Text file parsing for plain text and markdown.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.services.document.base import (
|
||||
DocumentParseException,
|
||||
DocumentParser,
|
||||
ParseResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ENCODINGS_TO_TRY = ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]
|
||||
|
||||
|
||||
class TextParser(DocumentParser):
    """
    Parser for plain text files.
    [AC-AISVC-33] Direct text extraction with multiple encoding support.
    """

    def __init__(self, encoding: str = "utf-8", **kwargs: Any):
        """
        Args:
            encoding: Preferred encoding, tried before the common fallbacks.
            **kwargs: Extra configuration, stored but currently unused.
        """
        self._encoding = encoding
        self._extra_config = kwargs

    def _try_encodings(self, path: Path) -> tuple[str, str]:
        """
        Read the file, trying the configured encoding first and then the
        common fallbacks in ENCODINGS_TO_TRY.

        Fix: the original ignored self._encoding entirely and only walked
        the module-level list; the configured preference is now honored
        (behavior is unchanged for the utf-8 default, which already headed
        the fallback list).

        Returns:
            (text, encoding_used)

        Raises:
            DocumentParseException: If no candidate decodes the file. In
                practice latin-1 accepts any byte sequence, so this is a
                defensive safety net rather than a reachable path.
        """
        candidates = [self._encoding] + [e for e in ENCODINGS_TO_TRY if e != self._encoding]

        for enc in candidates:
            try:
                with open(path, "r", encoding=enc) as f:
                    text = f.read()
                logger.info(f"Successfully parsed with encoding: {enc}")
                return text, enc
            except (UnicodeDecodeError, LookupError):
                continue

        raise DocumentParseException(
            "Failed to decode file with any known encoding",
            file_path=str(path),
            parser="text"
        )

    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse a text file and extract content.
        [AC-AISVC-33] Direct file reading.

        Raises:
            DocumentParseException: If the file is missing or unreadable.
        """
        path = Path(file_path)

        if not path.exists():
            raise DocumentParseException(
                f"File not found: {path}",
                file_path=str(path),
                parser="text"
            )

        try:
            text, encoding_used = self._try_encodings(path)

            file_size = path.stat().st_size
            line_count = text.count("\n") + 1

            logger.info(
                f"Parsed text: {path.name}, lines={line_count}, "
                f"chars={len(text)}, size={file_size}, encoding={encoding_used}"
            )

            return ParseResult(
                text=text,
                source_path=str(path),
                file_size=file_size,
                metadata={
                    "format": "text",
                    "line_count": line_count,
                    "encoding": encoding_used,
                }
            )

        except DocumentParseException:
            raise
        except Exception as e:
            raise DocumentParseException(
                f"Failed to parse text file: {e}",
                file_path=str(path),
                parser="text",
                details={"error": str(e)}
            )

    def get_supported_extensions(self) -> list[str]:
        """Get supported file extensions."""
        return [".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"]
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
"""
|
||||
Word document parser implementation.
|
||||
[AC-AISVC-34] Word (.docx) parsing using python-docx.
|
||||
|
||||
Extracts text content from Word documents.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.services.document.base import (
|
||||
DocumentParseException,
|
||||
DocumentParser,
|
||||
ParseResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WordParser(DocumentParser):
    """
    Parser for Word documents.
    [AC-AISVC-34] Uses python-docx for text extraction.

    Extracts body paragraphs (marking heading styles), tables, and optionally
    per-section headers/footers, joined into a single text blob.
    """

    def __init__(self, include_headers: bool = True, include_footers: bool = True, **kwargs: Any):
        # Whether to include section headers/footers in the extracted text.
        self._include_headers = include_headers
        self._include_footers = include_footers
        # Extra configuration retained but not interpreted here.
        self._extra_config = kwargs
        # Cached python-docx Document class; populated lazily by _get_docx().
        self._docx = None

    def _get_docx(self):
        """Lazy import of python-docx.

        Returns the docx.Document factory, caching it on first use.
        Raises DocumentParseException if python-docx is not installed.
        """
        if self._docx is None:
            try:
                from docx import Document
                self._docx = Document
            except ImportError:
                raise DocumentParseException(
                    "python-docx not installed. Install with: pip install python-docx",
                    parser="word"
                )
        return self._docx

    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse a Word document and extract text content.
        [AC-AISVC-34] Extracts text while preserving paragraph structure.

        Extraction order: headers (optional) -> body paragraphs -> tables ->
        footers (optional). Parts are joined with blank lines.

        Raises:
            DocumentParseException: If the file is missing, has an unsupported
                extension, python-docx is absent, or parsing fails.
        """
        path = Path(file_path)

        if not path.exists():
            raise DocumentParseException(
                f"File not found: {path}",
                file_path=str(path),
                parser="word"
            )

        if not self.supports_extension(path.suffix):
            raise DocumentParseException(
                f"Unsupported file extension: {path.suffix}",
                file_path=str(path),
                parser="word"
            )

        # Resolved before the try-block so a missing dependency is reported
        # as its own error rather than wrapped as a parse failure.
        Document = self._get_docx()

        try:
            doc = Document(path)

            text_parts = []

            # Headers are collected per section; blank paragraphs are skipped.
            if self._include_headers:
                for section in doc.sections:
                    header = section.header
                    if header and header.paragraphs:
                        header_text = "\n".join(p.text for p in header.paragraphs if p.text.strip())
                        if header_text:
                            text_parts.append(f"[Header]\n{header_text}")

            # Body: heading-styled paragraphs are marked with a markdown-like
            # "##" prefix; other non-empty paragraphs are kept verbatim.
            for para in doc.paragraphs:
                if para.text.strip():
                    style_name = para.style.name if para.style else ""
                    if "Heading" in style_name:
                        text_parts.append(f"\n## {para.text}")
                    else:
                        text_parts.append(para.text)

            # Tables are rendered as pipe-delimited rows (see _format_table).
            for table in doc.tables:
                table_text = self._format_table(table)
                if table_text.strip():
                    text_parts.append(f"\n[Table]\n{table_text}")

            if self._include_footers:
                for section in doc.sections:
                    footer = section.footer
                    if footer and footer.paragraphs:
                        footer_text = "\n".join(p.text for p in footer.paragraphs if p.text.strip())
                        if footer_text:
                            text_parts.append(f"[Footer]\n{footer_text}")

            full_text = "\n\n".join(text_parts)
            file_size = path.stat().st_size

            # Counts cover ALL paragraphs/tables, including empty ones skipped above.
            paragraph_count = len(doc.paragraphs)
            table_count = len(doc.tables)

            logger.info(
                f"Parsed Word: {path.name}, paragraphs={paragraph_count}, "
                f"tables={table_count}, chars={len(full_text)}, size={file_size}"
            )

            return ParseResult(
                text=full_text,
                source_path=str(path),
                file_size=file_size,
                metadata={
                    "format": "docx",
                    "paragraph_count": paragraph_count,
                    "table_count": table_count,
                }
            )

        except DocumentParseException:
            # Already carries parser context; re-raise untouched.
            raise
        except Exception as e:
            raise DocumentParseException(
                f"Failed to parse Word document: {e}",
                file_path=str(path),
                parser="word",
                details={"error": str(e)}
            )

    def _format_table(self, table) -> str:
        """Format a table as text.

        Renders each row as " | "-joined, whitespace-stripped cell texts.
        `table` is a python-docx Table object.
        """
        lines = []
        for row in table.rows:
            cells = [cell.text.strip() for cell in row.cells]
            lines.append(" | ".join(cells))
        return "\n".join(lines)

    def get_supported_extensions(self) -> list[str]:
        """Get supported file extensions."""
        return [".docx"]
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
"""
|
||||
Embedding services package.
|
||||
[AC-AISVC-29] Provides pluggable embedding providers.
|
||||
"""
|
||||
|
||||
from app.services.embedding.base import (
|
||||
EmbeddingConfig,
|
||||
EmbeddingException,
|
||||
EmbeddingProvider,
|
||||
EmbeddingResult,
|
||||
)
|
||||
from app.services.embedding.factory import (
|
||||
EmbeddingConfigManager,
|
||||
EmbeddingProviderFactory,
|
||||
get_embedding_config_manager,
|
||||
get_embedding_provider,
|
||||
)
|
||||
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
|
||||
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
|
||||
from app.services.embedding.nomic_provider import (
|
||||
NomicEmbeddingProvider,
|
||||
NomicEmbeddingResult,
|
||||
EmbeddingTask,
|
||||
)
|
||||
|
||||
# Public API of the embedding package; keep this list in sync with the
# imports at the top of this module.
__all__ = [
    "EmbeddingConfig",
    "EmbeddingException",
    "EmbeddingProvider",
    "EmbeddingResult",
    "EmbeddingConfigManager",
    "EmbeddingProviderFactory",
    "get_embedding_config_manager",
    "get_embedding_provider",
    "OllamaEmbeddingProvider",
    "OpenAIEmbeddingProvider",
    "NomicEmbeddingProvider",
    "NomicEmbeddingResult",
    "EmbeddingTask",
]
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
"""
|
||||
Base embedding provider interface.
|
||||
[AC-AISVC-29] Abstract interface for embedding providers.
|
||||
|
||||
Design reference: progress.md Section 7.1 - EmbeddingProvider interface
|
||||
- embed(text) -> list[float]
|
||||
- embed_batch(texts) -> list[list[float]]
|
||||
- get_dimension() -> int
|
||||
- get_provider_name() -> str
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
class EmbeddingConfig:
    """
    Configuration for embedding provider.
    [AC-AISVC-31] Supports configurable embedding parameters.
    """
    # Expected output vector dimension.
    dimension: int = 768
    # Number of texts per batch request.
    batch_size: int = 32
    # Per-request timeout against the provider backend.
    timeout_seconds: int = 60
    # Provider-specific options not covered by the fields above.
    extra_params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class EmbeddingResult:
    """
    Result from embedding generation.
    [AC-AISVC-29] Contains embedding vector and metadata.
    """
    # The embedding vector itself.
    embedding: list[float]
    # Length of `embedding`.
    dimension: int
    # Model that produced the vector.
    model: str
    # Wall-clock generation latency in milliseconds.
    latency_ms: float = 0.0
    # Provider-specific extras.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
    """
    Abstract base class for embedding providers.
    [AC-AISVC-29] Provides unified interface for different embedding providers.

    Design reference: progress.md Section 7.1 - Architecture
    - OllamaEmbeddingProvider / OpenAIEmbeddingProvider can be swapped
    - Factory pattern for dynamic loading
    """

    @abstractmethod
    async def embed(self, text: str) -> list[float]:
        """
        Generate embedding vector for a single text.
        [AC-AISVC-29] Returns embedding vector.

        Args:
            text: Input text to embed.

        Returns:
            List of floats representing the embedding vector.

        Raises:
            EmbeddingException: If embedding generation fails.
        """
        pass

    @abstractmethod
    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embedding vectors for multiple texts.
        [AC-AISVC-29] Returns list of embedding vectors, one per input,
        in the same order as ``texts``.

        Args:
            texts: List of input texts to embed.

        Returns:
            List of embedding vectors.

        Raises:
            EmbeddingException: If embedding generation fails.
        """
        pass

    @abstractmethod
    def get_dimension(self) -> int:
        """
        Get the dimension of embedding vectors.
        [AC-AISVC-29] Returns vector dimension.

        Returns:
            Integer dimension of embedding vectors.
        """
        pass

    @abstractmethod
    def get_provider_name(self) -> str:
        """
        Get the name of this embedding provider.
        [AC-AISVC-29] Returns provider identifier.

        Returns:
            String identifier for this provider.
        """
        pass

    @abstractmethod
    def get_config_schema(self) -> dict[str, Any]:
        """
        Get the configuration schema for this provider.
        [AC-AISVC-38] Returns JSON Schema for configuration parameters.

        Returns:
            Dict describing configuration parameters.
        """
        pass

    async def close(self) -> None:
        """Close the provider and release resources. Default no-op.

        Subclasses holding network clients should override this.
        """
        pass
|
||||
|
||||
|
||||
class EmbeddingException(Exception):
|
||||
"""Exception raised when embedding generation fails."""
|
||||
|
||||
def __init__(self, message: str, provider: str = "", details: dict[str, Any] | None = None):
|
||||
self.provider = provider
|
||||
self.details = details or {}
|
||||
super().__init__(f"[{provider}] {message}" if provider else message)
|
||||
|
|
@ -0,0 +1,305 @@
|
|||
"""
|
||||
Embedding provider factory and configuration manager.
|
||||
[AC-AISVC-30, AC-AISVC-31] Factory pattern for dynamic provider loading.
|
||||
|
||||
Design reference: progress.md Section 7.1 - Architecture
|
||||
- EmbeddingProviderFactory: creates providers based on config
|
||||
- EmbeddingConfigManager: manages configuration with hot-reload support
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Type
|
||||
|
||||
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
|
||||
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
|
||||
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
|
||||
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingProviderFactory:
    """
    Factory for creating embedding providers.
    [AC-AISVC-30] Supports dynamic loading based on configuration.
    """

    # Registry of provider name -> implementation class.
    # Extend at runtime via register_provider().
    _providers: dict[str, Type[EmbeddingProvider]] = {
        "ollama": OllamaEmbeddingProvider,
        "openai": OpenAIEmbeddingProvider,
        "nomic": NomicEmbeddingProvider,
    }

    @classmethod
    def register_provider(cls, name: str, provider_class: Type[EmbeddingProvider]) -> None:
        """
        Register a new embedding provider.
        [AC-AISVC-30] Allows runtime registration of providers. An existing
        registration under the same name is silently replaced.
        """
        cls._providers[name] = provider_class
        logger.info(f"Registered embedding provider: {name}")

    @classmethod
    def get_available_providers(cls) -> list[str]:
        """
        Get list of available provider names.
        [AC-AISVC-38] Returns registered provider identifiers.
        """
        return list(cls._providers.keys())

    @classmethod
    def get_provider_info(cls, name: str) -> dict[str, Any]:
        """
        Get provider information including config schema.
        [AC-AISVC-38] Returns provider metadata.

        Raises:
            EmbeddingException: If `name` is not a registered provider.
        """
        if name not in cls._providers:
            raise EmbeddingException(
                f"Unknown provider: {name}",
                provider="factory"
            )

        provider_class = cls._providers[name]
        # __new__ deliberately bypasses __init__ so describing a provider's
        # schema never requires a valid configuration.
        # NOTE(review): this assumes get_config_schema() reads no instance
        # state — confirm when registering new provider classes.
        temp_instance = provider_class.__new__(provider_class)

        display_names = {
            "ollama": "Ollama 本地模型",
            "openai": "OpenAI Embedding",
            "nomic": "Nomic Embed (优化版)",
        }

        descriptions = {
            "ollama": "使用 Ollama 运行的本地嵌入模型,支持 nomic-embed-text 等开源模型",
            "openai": "使用 OpenAI 官方 Embedding API,支持 text-embedding-3 系列模型",
            "nomic": "Nomic-embed-text v1.5 优化版,支持任务前缀和 Matryoshka 维度截断,专为RAG优化",
        }

        return {
            "name": name,
            "display_name": display_names.get(name, name),
            "description": descriptions.get(name, ""),
            "config_schema": temp_instance.get_config_schema(),
        }

    @classmethod
    def create_provider(
        cls,
        name: str,
        config: dict[str, Any],
    ) -> EmbeddingProvider:
        """
        Create an embedding provider instance.
        [AC-AISVC-30] Creates provider based on configuration.

        Args:
            name: Provider identifier (e.g., "ollama", "openai")
            config: Provider-specific configuration, expanded as keyword
                arguments into the provider's constructor.

        Returns:
            Configured EmbeddingProvider instance

        Raises:
            EmbeddingException: If provider is unknown or configuration is invalid
        """
        if name not in cls._providers:
            raise EmbeddingException(
                f"Unknown embedding provider: {name}. "
                f"Available: {cls.get_available_providers()}",
                provider="factory"
            )

        provider_class = cls._providers[name]

        try:
            instance = provider_class(**config)
        except EmbeddingException:
            # [Fix] Provider constructors may raise EmbeddingException
            # themselves; re-raise instead of double-wrapping (matches the
            # re-raise pattern used by the document parsers).
            raise
        except Exception as e:
            # [Fix] Chain the original exception so the real cause survives.
            raise EmbeddingException(
                f"Failed to create provider '{name}': {e}",
                provider="factory",
                details={"config": config}
            ) from e

        logger.info(f"Created embedding provider: {name}")
        return instance
|
||||
|
||||
|
||||
class EmbeddingConfigManager:
    """
    Manager for embedding configuration.
    [AC-AISVC-31] Supports hot-reload of configuration.

    Holds the current provider name, its config dict, and a lazily created
    provider instance; update_config() swaps all three atomically with
    rollback on failure.
    """

    def __init__(self, default_provider: str = "ollama", default_config: dict[str, Any] | None = None):
        self._provider_name = default_provider
        # Falls back to a local Ollama nomic-embed-text setup when no config given.
        self._config = default_config or {
            "base_url": "http://localhost:11434",
            "model": "nomic-embed-text",
            "dimension": 768,
        }
        # Lazily created by get_provider(); None until first use.
        self._provider: EmbeddingProvider | None = None

    def get_provider_name(self) -> str:
        """Get current provider name."""
        return self._provider_name

    def get_config(self) -> dict[str, Any]:
        """Get current configuration (shallow copy; callers may mutate safely)."""
        return self._config.copy()

    def get_full_config(self) -> dict[str, Any]:
        """
        Get full configuration including provider name.
        [AC-AISVC-39] Returns complete configuration for API response.
        """
        return {
            "provider": self._provider_name,
            "config": self._config.copy(),
        }

    async def get_provider(self) -> EmbeddingProvider:
        """
        Get or create the embedding provider.
        [AC-AISVC-29] Returns configured provider instance, creating and
        caching it on first call.
        """
        if self._provider is None:
            self._provider = EmbeddingProviderFactory.create_provider(
                self._provider_name,
                self._config
            )
        return self._provider

    async def update_config(
        self,
        provider: str,
        config: dict[str, Any],
    ) -> bool:
        """
        Update embedding configuration.
        [AC-AISVC-31, AC-AISVC-40] Supports hot-reload.

        The new provider is constructed BEFORE the old one is closed, so a
        bad configuration leaves the previous provider intact. On any
        failure the previous name/config are restored.

        Args:
            provider: New provider name
            config: New provider configuration

        Returns:
            True if update was successful

        Raises:
            EmbeddingException: If configuration is invalid
        """
        # Snapshot for rollback on failure.
        old_provider = self._provider_name
        old_config = self._config.copy()

        try:
            new_provider_instance = EmbeddingProviderFactory.create_provider(
                provider,
                config
            )

            # Release the old provider's resources before swapping in the new one.
            if self._provider:
                await self._provider.close()

            self._provider_name = provider
            # NOTE(review): stores the caller's dict by reference (other paths
            # use .copy()) — later caller-side mutation would leak in; confirm.
            self._config = config
            self._provider = new_provider_instance

            logger.info(f"Updated embedding config: provider={provider}")
            return True

        except Exception as e:
            # Roll back to the previous known-good configuration.
            self._provider_name = old_provider
            self._config = old_config
            raise EmbeddingException(
                f"Failed to update config: {e}",
                provider="config_manager",
                details={"provider": provider, "config": config}
            )

    async def test_connection(
        self,
        test_text: str = "这是一个测试文本",
        provider: str | None = None,
        config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Test embedding connection.
        [AC-AISVC-41] Tests provider connectivity.

        Builds a throwaway provider (does NOT touch the managed one), embeds
        one text, and reports success/latency. Never raises: failures are
        reported in the returned dict.

        Args:
            test_text: Text to embed for testing
            provider: Provider to test (uses current if None)
            config: Config to test (uses current if None)

        Returns:
            Dict with test results including success, dimension, latency
        """
        import time

        test_provider_name = provider or self._provider_name
        test_config = config or self._config

        try:
            test_provider = EmbeddingProviderFactory.create_provider(
                test_provider_name,
                test_config
            )

            start_time = time.perf_counter()
            embedding = await test_provider.embed(test_text)
            latency_ms = (time.perf_counter() - start_time) * 1000

            await test_provider.close()

            return {
                "success": True,
                "dimension": len(embedding),
                "latency_ms": latency_ms,
                "message": f"连接成功,向量维度: {len(embedding)}",
            }

        except Exception as e:
            return {
                "success": False,
                "dimension": 0,
                "latency_ms": 0,
                "error": str(e),
                "message": f"连接失败: {e}",
            }

    async def close(self) -> None:
        """Close the current provider and drop the cached instance."""
        if self._provider:
            await self._provider.close()
            self._provider = None
||||
|
||||
|
||||
# Module-level singleton; created lazily by get_embedding_config_manager().
_embedding_config_manager: EmbeddingConfigManager | None = None


def get_embedding_config_manager() -> EmbeddingConfigManager:
    """
    Get the global embedding config manager.
    [AC-AISVC-31] Singleton pattern for configuration management.

    On first call, seeds the manager with the Ollama defaults from app
    settings; subsequent calls return the same instance.
    """
    global _embedding_config_manager
    if _embedding_config_manager is None:
        # Imported here to avoid a circular import at module load time.
        from app.core.config import get_settings
        settings = get_settings()

        _embedding_config_manager = EmbeddingConfigManager(
            default_provider="ollama",
            default_config={
                "base_url": settings.ollama_base_url,
                "model": settings.ollama_embedding_model,
                "dimension": settings.qdrant_vector_size,
            }
        )
    return _embedding_config_manager
|
||||
|
||||
|
||||
async def get_embedding_provider() -> EmbeddingProvider:
    """
    Get the current embedding provider.
    [AC-AISVC-29] Convenience wrapper around the global config manager.
    """
    return await get_embedding_config_manager().get_provider()
|
||||
|
|
@ -0,0 +1,291 @@
|
|||
"""
|
||||
Nomic embedding provider with task prefixes and Matryoshka support.
|
||||
Implements RAG optimization spec:
|
||||
- Task prefixes: search_document: / search_query:
|
||||
- Matryoshka dimension truncation: 256/512/768 dimensions
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
|
||||
from app.services.embedding.base import (
|
||||
EmbeddingConfig,
|
||||
EmbeddingException,
|
||||
EmbeddingProvider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingTask(str, Enum):
    """Task type for nomic-embed-text v1.5 model.

    The value is the task prefix keyword prepended to input text.
    """
    # Use when indexing documents into the vector store.
    DOCUMENT = "search_document"
    # Use when embedding a retrieval query.
    QUERY = "search_query"
|
||||
|
||||
|
||||
@dataclass
class NomicEmbeddingResult:
    """Result from Nomic embedding with multiple dimensions."""
    # Full-dimension vector as returned by the model.
    embedding_full: list[float]
    # Matryoshka-truncated variants (normalized after truncation).
    embedding_256: list[float]
    embedding_512: list[float]
    # Length of embedding_full.
    dimension: int
    # Model that produced the vectors.
    model: str
    # Task prefix used for this embedding (DOCUMENT or QUERY).
    task: EmbeddingTask
    # Wall-clock generation latency in milliseconds.
    latency_ms: float = 0.0
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class NomicEmbeddingProvider(EmbeddingProvider):
    """
    Nomic-embed-text v1.5 embedding provider with task prefixes.

    Key features:
    - Task prefixes: search_document: for documents, search_query: for queries
    - Matryoshka dimension truncation: 256/512/768 dimensions
    - Automatic normalization after truncation

    Reference: rag-optimization/spec.md Section 2.1, 2.3
    """

    PROVIDER_NAME = "nomic"
    DOCUMENT_PREFIX = "search_document:"
    QUERY_PREFIX = "search_query:"
    FULL_DIMENSION = 768

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        model: str = "nomic-embed-text",
        dimension: int = 768,
        timeout_seconds: int = 60,
        enable_matryoshka: bool = True,
        **kwargs: Any,
    ):
        """
        Args:
            base_url: Ollama server address (trailing slash stripped).
            model: Embedding model name served by Ollama.
            dimension: Advertised vector dimension (reported by get_dimension()).
            timeout_seconds: Per-request HTTP timeout.
            enable_matryoshka: When True, also compute normalized 256/512-dim
                truncations of each embedding.
        """
        self._base_url = base_url.rstrip("/")
        self._model = model
        self._dimension = dimension
        self._timeout = timeout_seconds
        self._enable_matryoshka = enable_matryoshka
        # Created lazily so constructing the provider performs no I/O.
        self._client: httpx.AsyncClient | None = None
        self._extra_config = kwargs

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the cached async HTTP client, creating it on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self._timeout)
        return self._client

    def _add_prefix(self, text: str, task: EmbeddingTask) -> str:
        """Prepend the task prefix unless the text already starts with it."""
        if task == EmbeddingTask.DOCUMENT:
            prefix = self.DOCUMENT_PREFIX
        else:
            prefix = self.QUERY_PREFIX

        if text.startswith(prefix):
            return text
        return f"{prefix}{text}"

    def _truncate_and_normalize(self, embedding: list[float], target_dim: int) -> list[float]:
        """
        Truncate embedding to target dimension and L2-normalize the result.
        Matryoshka representation learning allows dimension truncation.
        Zero vectors are returned untouched (norm guard avoids division by zero).
        """
        truncated = embedding[:target_dim]

        arr = np.array(truncated, dtype=np.float32)
        norm = np.linalg.norm(arr)
        if norm > 0:
            arr = arr / norm

        return arr.tolist()

    async def embed_with_task(
        self,
        text: str,
        task: EmbeddingTask,
    ) -> NomicEmbeddingResult:
        """
        Generate embedding with specified task prefix.

        Args:
            text: Input text to embed
            task: DOCUMENT for indexing, QUERY for retrieval

        Returns:
            NomicEmbeddingResult with all dimension variants. When
            enable_matryoshka is False, the 256/512 slots carry the full
            (untruncated) vector.

        Raises:
            EmbeddingException: On API errors, connection errors, or an
                empty embedding response.
        """
        start_time = time.perf_counter()

        prefixed_text = self._add_prefix(text, task)

        try:
            client = await self._get_client()
            response = await client.post(
                f"{self._base_url}/api/embeddings",
                json={
                    "model": self._model,
                    "prompt": prefixed_text,
                }
            )
            response.raise_for_status()
            data = response.json()
            embedding = data.get("embedding", [])

            if not embedding:
                raise EmbeddingException(
                    "Empty embedding returned",
                    provider=self.PROVIDER_NAME,
                    details={"text_length": len(text), "task": task.value}
                )

            latency_ms = (time.perf_counter() - start_time) * 1000

            # [Fix] Honor enable_matryoshka: the flag was previously stored
            # but the truncated variants were computed unconditionally.
            if self._enable_matryoshka:
                embedding_256 = self._truncate_and_normalize(embedding, 256)
                embedding_512 = self._truncate_and_normalize(embedding, 512)
            else:
                # Truncation disabled: expose the full vector in every slot
                # so downstream consumers always find a usable vector.
                embedding_256 = embedding
                embedding_512 = embedding

            logger.debug(
                f"Generated Nomic embedding: task={task.value}, "
                f"dim={len(embedding)}, latency={latency_ms:.2f}ms"
            )

            return NomicEmbeddingResult(
                embedding_full=embedding,
                embedding_256=embedding_256,
                embedding_512=embedding_512,
                dimension=len(embedding),
                model=self._model,
                task=task,
                latency_ms=latency_ms,
            )

        except httpx.HTTPStatusError as e:
            raise EmbeddingException(
                f"Ollama API error: {e.response.status_code}",
                provider=self.PROVIDER_NAME,
                details={"status_code": e.response.status_code, "response": e.response.text}
            ) from e
        except httpx.RequestError as e:
            raise EmbeddingException(
                f"Ollama connection error: {e}",
                provider=self.PROVIDER_NAME,
                details={"base_url": self._base_url}
            ) from e
        except EmbeddingException:
            raise
        except Exception as e:
            raise EmbeddingException(
                f"Embedding generation failed: {e}",
                provider=self.PROVIDER_NAME
            ) from e

    async def embed_document(self, text: str) -> NomicEmbeddingResult:
        """
        Generate embedding for document (with search_document: prefix).
        Use this when indexing documents into vector store.
        """
        return await self.embed_with_task(text, EmbeddingTask.DOCUMENT)

    async def embed_query(self, text: str) -> NomicEmbeddingResult:
        """
        Generate embedding for query (with search_query: prefix).
        Use this when searching/retrieving documents.
        """
        return await self.embed_with_task(text, EmbeddingTask.QUERY)

    async def embed(self, text: str) -> list[float]:
        """
        Generate embedding vector for a single text.
        Default uses QUERY task for backward compatibility.

        NOTE(review): returns the full-dimension vector even when the
        configured `dimension` is smaller — callers relying on
        get_dimension() matching len(embed(...)) should confirm.
        """
        result = await self.embed_query(text)
        return result.embedding_full

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embedding vectors for multiple texts (sequentially).
        Uses QUERY task by default.
        """
        embeddings = []
        for text in texts:
            embeddings.append(await self.embed(text))
        return embeddings

    async def embed_documents_batch(
        self,
        texts: list[str],
    ) -> list[NomicEmbeddingResult]:
        """
        Generate embeddings for multiple documents (DOCUMENT task).
        Use this when batch indexing documents.
        """
        results = []
        for text in texts:
            results.append(await self.embed_document(text))
        return results

    async def embed_queries_batch(
        self,
        texts: list[str],
    ) -> list[NomicEmbeddingResult]:
        """
        Generate embeddings for multiple queries (QUERY task).
        Use this when batch processing queries.
        """
        results = []
        for text in texts:
            results.append(await self.embed_query(text))
        return results

    def get_dimension(self) -> int:
        """Get the configured dimension of embedding vectors."""
        return self._dimension

    def get_provider_name(self) -> str:
        """Get the name of this embedding provider."""
        return self.PROVIDER_NAME

    def get_config_schema(self) -> dict[str, Any]:
        """Get the configuration schema for Nomic provider."""
        return {
            "base_url": {
                "type": "string",
                "description": "Ollama API 地址",
                "default": "http://localhost:11434",
            },
            "model": {
                "type": "string",
                "description": "嵌入模型名称(推荐 nomic-embed-text v1.5)",
                "default": "nomic-embed-text",
            },
            "dimension": {
                "type": "integer",
                "description": "向量维度(支持 256/512/768)",
                "default": 768,
            },
            "timeout_seconds": {
                "type": "integer",
                "description": "请求超时时间(秒)",
                "default": 60,
            },
            "enable_matryoshka": {
                "type": "boolean",
                "description": "启用 Matryoshka 维度截断",
                "default": True,
            },
        }

    async def close(self) -> None:
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
"""
|
||||
Ollama embedding service for generating text embeddings.
|
||||
Uses nomic-embed-text model via Ollama API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import httpx
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_embedding(text: str) -> list[float]:
    """
    Generate embedding vector for text using Ollama nomic-embed-text model.

    Returns a zero vector of the configured size when Ollama responds with
    an empty embedding (best-effort fallback); raises on HTTP/connection
    errors after logging them.
    """
    settings = get_settings()

    # A fresh client per call: simple, at the cost of no connection reuse.
    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            response = await client.post(
                f"{settings.ollama_base_url}/api/embeddings",
                json={
                    "model": settings.ollama_embedding_model,
                    "prompt": text,
                }
            )
            response.raise_for_status()
            data = response.json()
            embedding = data.get("embedding", [])

            if not embedding:
                # Deliberate fallback: keep the pipeline moving with a
                # zero vector rather than failing the whole request.
                logger.warning(f"Empty embedding returned for text length={len(text)}")
                return [0.0] * settings.qdrant_vector_size

            logger.debug(f"Generated embedding: dim={len(embedding)}")
            return embedding

        except httpx.HTTPStatusError as e:
            logger.error(f"Ollama API error: {e.response.status_code} - {e.response.text}")
            raise
        except httpx.RequestError as e:
            logger.error(f"Ollama connection error: {e}")
            raise
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise
|
||||
|
||||
|
||||
async def get_embeddings_batch(texts: list[str]) -> list[list[float]]:
    """
    Generate embedding vectors for multiple texts, sequentially,
    preserving input order.
    """
    return [await get_embedding(item) for item in texts]
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
"""
|
||||
Ollama embedding provider implementation.
|
||||
[AC-AISVC-29, AC-AISVC-30] Ollama-based embedding provider.
|
||||
|
||||
Uses Ollama API for generating text embeddings.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.services.embedding.base import (
|
||||
EmbeddingConfig,
|
||||
EmbeddingException,
|
||||
EmbeddingProvider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OllamaEmbeddingProvider(EmbeddingProvider):
    """
    Embedding provider using Ollama API.
    [AC-AISVC-29, AC-AISVC-30] Supports local embedding models via Ollama.

    The HTTP client is created lazily on first use and reused until
    ``close()`` is called. All failures are normalized to
    ``EmbeddingException`` with the original exception chained as cause.
    """

    PROVIDER_NAME = "ollama"

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        model: str = "nomic-embed-text",
        dimension: int = 768,
        timeout_seconds: int = 60,
        **kwargs: Any,
    ):
        # Normalize the base URL so the f-string path join below never
        # produces a double slash.
        self._base_url = base_url.rstrip("/")
        self._model = model
        self._dimension = dimension
        self._timeout = timeout_seconds
        self._client: httpx.AsyncClient | None = None
        # Extra keyword arguments are retained but not interpreted here.
        self._extra_config = kwargs

    async def _get_client(self) -> httpx.AsyncClient:
        """Lazily create and cache the shared async HTTP client."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self._timeout)
        return self._client

    async def embed(self, text: str) -> list[float]:
        """
        Generate embedding vector for a single text using Ollama API.
        [AC-AISVC-29] Returns embedding vector.

        Args:
            text: Text to embed.

        Returns:
            Embedding vector as a list of floats.

        Raises:
            EmbeddingException: On HTTP errors, connection failures, an
                empty embedding in the response, or any unexpected error.
        """
        start_time = time.perf_counter()

        try:
            client = await self._get_client()
            response = await client.post(
                f"{self._base_url}/api/embeddings",
                json={
                    "model": self._model,
                    "prompt": text,
                }
            )
            response.raise_for_status()
            data = response.json()
            embedding = data.get("embedding", [])

            if not embedding:
                raise EmbeddingException(
                    "Empty embedding returned",
                    provider=self.PROVIDER_NAME,
                    details={"text_length": len(text)}
                )

            latency_ms = (time.perf_counter() - start_time) * 1000
            logger.debug(
                f"Generated embedding via Ollama: dim={len(embedding)}, "
                f"latency={latency_ms:.2f}ms"
            )

            return embedding

        except httpx.HTTPStatusError as e:
            # Chain the original exception (``from e``) so the root cause
            # survives in tracebacks instead of being discarded.
            raise EmbeddingException(
                f"Ollama API error: {e.response.status_code}",
                provider=self.PROVIDER_NAME,
                details={"status_code": e.response.status_code, "response": e.response.text}
            ) from e
        except httpx.RequestError as e:
            raise EmbeddingException(
                f"Ollama connection error: {e}",
                provider=self.PROVIDER_NAME,
                details={"base_url": self._base_url}
            ) from e
        except EmbeddingException:
            # Already our error type (empty-embedding case above): re-raise as-is.
            raise
        except Exception as e:
            raise EmbeddingException(
                f"Embedding generation failed: {e}",
                provider=self.PROVIDER_NAME
            ) from e

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embedding vectors for multiple texts.
        [AC-AISVC-29] Sequential embedding generation.

        Note: texts are embedded one request at a time (the Ollama
        embeddings endpoint used here takes a single prompt).
        """
        embeddings = []
        for text in texts:
            embedding = await self.embed(text)
            embeddings.append(embedding)
        return embeddings

    def get_dimension(self) -> int:
        """Get the dimension of embedding vectors."""
        return self._dimension

    def get_provider_name(self) -> str:
        """Get the name of this embedding provider."""
        return self.PROVIDER_NAME

    def get_config_schema(self) -> dict[str, Any]:
        """
        Get the configuration schema for Ollama provider.
        [AC-AISVC-38] Returns JSON Schema for configuration parameters.
        """
        return {
            "base_url": {
                "type": "string",
                "description": "Ollama API 地址",
                "default": "http://localhost:11434",
            },
            "model": {
                "type": "string",
                "description": "嵌入模型名称",
                "default": "nomic-embed-text",
            },
            "dimension": {
                "type": "integer",
                "description": "向量维度",
                "default": 768,
            },
            "timeout_seconds": {
                "type": "integer",
                "description": "请求超时时间(秒)",
                "default": 60,
            },
        }

    async def close(self) -> None:
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None
|
||||
|
|
@ -0,0 +1,193 @@
|
|||
"""
|
||||
OpenAI embedding provider implementation.
|
||||
[AC-AISVC-29, AC-AISVC-30] OpenAI-based embedding provider.
|
||||
|
||||
Uses OpenAI API for generating text embeddings.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.services.embedding.base import (
|
||||
EmbeddingException,
|
||||
EmbeddingProvider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """
    Embedding provider using OpenAI API.
    [AC-AISVC-29, AC-AISVC-30] Supports OpenAI embedding models.

    ``embed`` delegates to ``embed_batch`` with a single-element list, so
    all request/error handling lives in one place. Failures are normalized
    to ``EmbeddingException`` with the original exception chained as cause.
    """

    PROVIDER_NAME = "openai"

    # Known default output dimensions per model; used when the caller does
    # not specify an explicit dimension.
    MODEL_DIMENSIONS = {
        "text-embedding-ada-002": 1536,
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
    }

    def __init__(
        self,
        api_key: str,
        model: str = "text-embedding-3-small",
        base_url: str = "https://api.openai.com/v1",
        dimension: int | None = None,
        timeout_seconds: int = 60,
        **kwargs: Any,
    ):
        self._api_key = api_key
        self._model = model
        self._base_url = base_url.rstrip("/")
        self._timeout = timeout_seconds
        self._client: httpx.AsyncClient | None = None
        self._extra_config = kwargs

        # Resolution order: explicit dimension > known model default > 1536.
        if dimension:
            self._dimension = dimension
        elif model in self.MODEL_DIMENSIONS:
            self._dimension = self.MODEL_DIMENSIONS[model]
        else:
            self._dimension = 1536

    async def _get_client(self) -> httpx.AsyncClient:
        """Lazily create and cache the shared async HTTP client."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self._timeout)
        return self._client

    async def embed(self, text: str) -> list[float]:
        """
        Generate embedding vector for a single text using OpenAI API.
        [AC-AISVC-29] Returns embedding vector.
        """
        embeddings = await self.embed_batch([text])
        return embeddings[0]

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embedding vectors for multiple texts using OpenAI API.
        [AC-AISVC-29] Supports batch embedding for efficiency.

        Raises:
            EmbeddingException: On HTTP errors, connection failures, empty
                embeddings, or a count mismatch between input and output.
        """
        start_time = time.perf_counter()

        try:
            client = await self._get_client()

            request_body: dict[str, Any] = {
                "model": self._model,
                "input": texts,
            }
            # Only the text-embedding-3 family accepts a custom "dimensions"
            # parameter.
            if self._dimension and self._model.startswith("text-embedding-3"):
                request_body["dimensions"] = self._dimension

            response = await client.post(
                f"{self._base_url}/embeddings",
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
                json=request_body,
            )
            response.raise_for_status()
            data = response.json()

            # NOTE(review): items are consumed in response order; this
            # assumes the API returns them in input order — confirm, or
            # sort by each item's "index" field if that cannot be relied on.
            embeddings = []
            for item in data.get("data", []):
                embedding = item.get("embedding", [])
                if not embedding:
                    raise EmbeddingException(
                        "Empty embedding returned",
                        provider=self.PROVIDER_NAME,
                        details={"index": item.get("index", 0)}
                    )
                embeddings.append(embedding)

            if len(embeddings) != len(texts):
                raise EmbeddingException(
                    f"Embedding count mismatch: expected {len(texts)}, got {len(embeddings)}",
                    provider=self.PROVIDER_NAME
                )

            latency_ms = (time.perf_counter() - start_time) * 1000
            logger.debug(
                f"Generated {len(embeddings)} embeddings via OpenAI: "
                f"dim={len(embeddings[0]) if embeddings else 0}, "
                f"latency={latency_ms:.2f}ms"
            )

            return embeddings

        except httpx.HTTPStatusError as e:
            # Chain the original exception (``from e``) so the root cause
            # survives in tracebacks instead of being discarded.
            raise EmbeddingException(
                f"OpenAI API error: {e.response.status_code}",
                provider=self.PROVIDER_NAME,
                details={"status_code": e.response.status_code, "response": e.response.text}
            ) from e
        except httpx.RequestError as e:
            raise EmbeddingException(
                f"OpenAI connection error: {e}",
                provider=self.PROVIDER_NAME,
                details={"base_url": self._base_url}
            ) from e
        except EmbeddingException:
            # Already our error type (empty/mismatch cases above): re-raise as-is.
            raise
        except Exception as e:
            raise EmbeddingException(
                f"Embedding generation failed: {e}",
                provider=self.PROVIDER_NAME
            ) from e

    def get_dimension(self) -> int:
        """Get the dimension of embedding vectors."""
        return self._dimension

    def get_provider_name(self) -> str:
        """Get the name of this embedding provider."""
        return self.PROVIDER_NAME

    def get_config_schema(self) -> dict[str, Any]:
        """
        Get the configuration schema for OpenAI provider.
        [AC-AISVC-38] Returns JSON Schema for configuration parameters.
        """
        return {
            "api_key": {
                "type": "string",
                "description": "OpenAI API 密钥",
                "required": True,
                "secret": True,
            },
            "model": {
                "type": "string",
                "description": "嵌入模型名称",
                "default": "text-embedding-3-small",
                "enum": list(self.MODEL_DIMENSIONS.keys()),
            },
            "base_url": {
                "type": "string",
                "description": "OpenAI API 地址(支持兼容接口)",
                "default": "https://api.openai.com/v1",
            },
            "dimension": {
                "type": "integer",
                "description": "向量维度(仅 text-embedding-3 系列支持自定义)",
                "default": 1536,
            },
            "timeout_seconds": {
                "type": "integer",
                "description": "请求超时时间(秒)",
                "default": 60,
            },
        }

    async def close(self) -> None:
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None
|
||||
|
|
@ -0,0 +1,294 @@
|
|||
"""
|
||||
Knowledge Base service for AI Service.
|
||||
[AC-ASA-01, AC-ASA-02, AC-ASA-08] KB management with document upload, indexing, and listing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Sequence
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import col
|
||||
|
||||
from app.models.entities import (
|
||||
Document,
|
||||
DocumentStatus,
|
||||
IndexJob,
|
||||
IndexJobStatus,
|
||||
KnowledgeBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KBService:
    """
    [AC-ASA-01, AC-ASA-02, AC-ASA-08] Knowledge Base service.
    Handles document upload, indexing jobs, and document listing.

    All queries are scoped by ``tenant_id`` for multi-tenant isolation.
    Writes use ``session.flush()`` (not commit); the caller owns the
    transaction boundary.
    """

    def __init__(self, session: AsyncSession, upload_dir: str = "./uploads"):
        # NOTE(review): os.makedirs is a blocking call inside what is
        # otherwise an async service — acceptable at construction time,
        # but confirm it is not created per-request on a hot path.
        self._session = session
        self._upload_dir = upload_dir
        os.makedirs(upload_dir, exist_ok=True)

    async def get_or_create_kb(
        self,
        tenant_id: str,
        kb_id: str | None = None,
        name: str = "Default KB",
    ) -> KnowledgeBase:
        """
        Get existing KB or create default one.

        Resolution order: the explicitly requested ``kb_id`` (malformed
        UUIDs are silently ignored), then any existing KB for the tenant,
        then a newly created KB named ``name``.
        """
        if kb_id:
            try:
                stmt = select(KnowledgeBase).where(
                    KnowledgeBase.tenant_id == tenant_id,
                    KnowledgeBase.id == uuid.UUID(kb_id),
                )
                result = await self._session.execute(stmt)
                existing_kb = result.scalar_one_or_none()
                if existing_kb:
                    return existing_kb
            except ValueError:
                # Malformed UUID string: fall through to the defaults below.
                pass

        # Fall back to the tenant's first KB, if any exists.
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.tenant_id == tenant_id,
        ).limit(1)
        result = await self._session.execute(stmt)
        existing_kb = result.scalar_one_or_none()

        if existing_kb:
            return existing_kb

        new_kb = KnowledgeBase(
            tenant_id=tenant_id,
            name=name,
        )
        self._session.add(new_kb)
        await self._session.flush()

        logger.info(f"[AC-ASA-01] Created knowledge base: tenant={tenant_id}, kb_id={new_kb.id}")
        return new_kb

    async def upload_document(
        self,
        tenant_id: str,
        kb_id: str,
        file_name: str,
        file_content: bytes,
        file_type: str | None = None,
    ) -> tuple[Document, IndexJob]:
        """
        [AC-ASA-01] Upload document and create indexing job.

        Persists the raw bytes to disk first, then records the Document
        and a PENDING IndexJob in the same session flush.

        NOTE(review): the file is written before the DB flush — if the
        flush fails, an orphan file remains on disk; confirm whether a
        cleanup pass exists. ``open``/``write`` are blocking calls inside
        an async method.
        """
        doc_id = uuid.uuid4()
        # File name embeds tenant and doc id to avoid collisions across uploads.
        file_path = os.path.join(self._upload_dir, f"{tenant_id}_{doc_id}_{file_name}")

        with open(file_path, "wb") as f:
            f.write(file_content)

        document = Document(
            id=doc_id,
            tenant_id=tenant_id,
            kb_id=kb_id,
            file_name=file_name,
            file_path=file_path,
            file_size=len(file_content),
            file_type=file_type,
            status=DocumentStatus.PENDING.value,
        )
        self._session.add(document)

        job = IndexJob(
            tenant_id=tenant_id,
            doc_id=doc_id,
            status=IndexJobStatus.PENDING.value,
            progress=0,
        )
        self._session.add(job)

        await self._session.flush()

        logger.info(
            f"[AC-ASA-01] Uploaded document: tenant={tenant_id}, doc_id={doc_id}, "
            f"file_name={file_name}, size={len(file_content)}"
        )

        return document, job

    async def list_documents(
        self,
        tenant_id: str,
        kb_id: str | None = None,
        status: str | None = None,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[Sequence[Document], int]:
        """
        [AC-ASA-08] List documents with filtering and pagination.

        Returns ``(documents, total)`` where ``total`` is the count before
        pagination. Results are ordered newest-first by ``created_at``;
        ``page`` is 1-based.
        """
        stmt = select(Document).where(Document.tenant_id == tenant_id)

        if kb_id:
            stmt = stmt.where(Document.kb_id == kb_id)
        if status:
            stmt = stmt.where(Document.status == status)

        # Count the filtered set before applying ordering/pagination.
        count_stmt = select(func.count()).select_from(stmt.subquery())
        total_result = await self._session.execute(count_stmt)
        total = total_result.scalar() or 0

        stmt = stmt.order_by(col(Document.created_at).desc())
        stmt = stmt.offset((page - 1) * page_size).limit(page_size)

        result = await self._session.execute(stmt)
        documents = result.scalars().all()

        logger.info(
            f"[AC-ASA-08] Listed documents: tenant={tenant_id}, "
            f"kb_id={kb_id}, status={status}, total={total}"
        )

        return documents, total

    async def get_document(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> Document | None:
        """
        Get document by ID.

        NOTE(review): unlike get_or_create_kb, a malformed ``doc_id``
        raises ValueError from ``uuid.UUID`` here rather than returning
        None — confirm callers expect that.
        """
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.id == uuid.UUID(doc_id),
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_index_job(
        self,
        tenant_id: str,
        job_id: str,
    ) -> IndexJob | None:
        """
        [AC-ASA-02] Get index job status.

        Returns None when no job matches; logs status/progress when found.
        """
        stmt = select(IndexJob).where(
            IndexJob.tenant_id == tenant_id,
            IndexJob.id == uuid.UUID(job_id),
        )
        result = await self._session.execute(stmt)
        job = result.scalar_one_or_none()

        if job:
            logger.info(
                f"[AC-ASA-02] Got job status: tenant={tenant_id}, "
                f"job_id={job_id}, status={job.status}, progress={job.progress}"
            )

        return job

    async def get_index_job_by_doc(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> IndexJob | None:
        """
        Get index job by document ID.

        Returns the most recently created job for the document (a document
        may have been re-indexed more than once).
        """
        stmt = select(IndexJob).where(
            IndexJob.tenant_id == tenant_id,
            IndexJob.doc_id == uuid.UUID(doc_id),
        ).order_by(col(IndexJob.created_at).desc())
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def update_job_status(
        self,
        tenant_id: str,
        job_id: str,
        status: str,
        progress: int | None = None,
        error_msg: str | None = None,
    ) -> IndexJob | None:
        """
        Update index job status.

        Also mirrors the status (and error message, if any) onto the
        associated Document so the two stay in sync. Returns the updated
        job, or None when no job matches.

        NOTE(review): ``datetime.utcnow()`` is deprecated in Python 3.12+
        and produces naive timestamps — confirm the DB columns expect
        naive UTC before migrating to ``datetime.now(timezone.utc)``.
        """
        stmt = select(IndexJob).where(
            IndexJob.tenant_id == tenant_id,
            IndexJob.id == uuid.UUID(job_id),
        )
        result = await self._session.execute(stmt)
        job = result.scalar_one_or_none()

        if job:
            job.status = status
            job.updated_at = datetime.utcnow()
            if progress is not None:
                job.progress = progress
            if error_msg is not None:
                job.error_msg = error_msg
            await self._session.flush()

            # Mirror the job's status onto the owning document, if any.
            if job.doc_id:
                doc_stmt = select(Document).where(
                    Document.tenant_id == tenant_id,
                    Document.id == job.doc_id,
                )
                doc_result = await self._session.execute(doc_stmt)
                doc = doc_result.scalar_one_or_none()
                if doc:
                    doc.status = status
                    doc.updated_at = datetime.utcnow()
                    if error_msg:
                        doc.error_msg = error_msg
                    await self._session.flush()

        return job

    async def delete_document(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> bool:
        """
        Delete document and associated files.

        Removes the on-disk file first, then the DB row. Returns True on
        success, False when the document does not exist.
        """
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.id == uuid.UUID(doc_id),
        )
        result = await self._session.execute(stmt)
        document = result.scalar_one_or_none()

        if not document:
            return False

        # Best-effort file removal; only attempted when the path still exists.
        if document.file_path and os.path.exists(document.file_path):
            os.remove(document.file_path)

        await self._session.delete(document)
        await self._session.flush()

        logger.info(f"[AC-ASA-08] Deleted document: tenant={tenant_id}, doc_id={doc_id}")
        return True

    async def list_knowledge_bases(
        self,
        tenant_id: str,
    ) -> Sequence[KnowledgeBase]:
        """
        List all knowledge bases for a tenant.

        Ordered newest-first by ``created_at``.
        """
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.tenant_id == tenant_id
        ).order_by(col(KnowledgeBase.created_at).desc())
        result = await self._session.execute(stmt)
        return result.scalars().all()
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
"""
|
||||
LLM Adapter module for AI Service.
|
||||
[AC-AISVC-02, AC-AISVC-06] Provides unified interface for LLM providers.
|
||||
"""
|
||||
|
||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
|
||||
from app.services.llm.openai_client import OpenAIClient
|
||||
|
||||
# Public surface of the LLM adapter package: the abstract client interface,
# its config/response dataclasses, and the OpenAI-compatible implementation.
__all__ = [
    "LLMClient",
    "LLMConfig",
    "LLMResponse",
    "LLMStreamChunk",
    "OpenAIClient",
]
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
"""
|
||||
Base LLM client interface.
|
||||
[AC-AISVC-02, AC-AISVC-06] Abstract interface for LLM providers.
|
||||
|
||||
Design reference: design.md Section 8.1 - LLMClient interface
|
||||
- generate(prompt, params) -> text
|
||||
- stream_generate(prompt, params) -> iterator[delta]
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
|
||||
@dataclass
class LLMConfig:
    """
    Configuration for LLM client.
    [AC-AISVC-02] Supports configurable model parameters.
    """
    model: str = "gpt-4o-mini"  # Target model identifier.
    max_tokens: int = 2048  # Upper bound on generated tokens.
    temperature: float = 0.7  # Sampling temperature.
    top_p: float = 1.0  # Nucleus-sampling cutoff.
    timeout_seconds: int = 30  # Per-request timeout.
    max_retries: int = 3  # Retry budget for failed requests.
    extra_params: dict[str, Any] = field(default_factory=dict)  # Provider-specific passthrough.
|
||||
|
||||
|
||||
@dataclass
class LLMResponse:
    """
    Response from LLM generation.
    [AC-AISVC-02] Contains generated content and metadata.
    """
    content: str  # Full generated text.
    model: str  # Model that produced the response.
    usage: dict[str, int] = field(default_factory=dict)  # Token counts (e.g. prompt/completion/total).
    finish_reason: str = "stop"  # Why generation ended.
    metadata: dict[str, Any] = field(default_factory=dict)  # Provider-specific extras.
|
||||
|
||||
|
||||
@dataclass
class LLMStreamChunk:
    """
    Streaming chunk from LLM.
    [AC-AISVC-06, AC-AISVC-07] Incremental output for SSE streaming.
    """
    delta: str  # Incremental text since the previous chunk.
    model: str  # Model producing the stream.
    finish_reason: str | None = None  # Set on the final chunk; None while streaming.
    metadata: dict[str, Any] = field(default_factory=dict)  # Provider-specific extras.
|
||||
|
||||
|
||||
class LLMClient(ABC):
    """
    Abstract base class for LLM clients.
    [AC-AISVC-02, AC-AISVC-06] Provides unified interface for different LLM providers.

    Design reference: design.md Section 8.2 - Plugin points
    - OpenAICompatibleClient / LocalModelClient can be swapped
    """

    @abstractmethod
    async def generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Generate a non-streaming response.
        [AC-AISVC-02] Returns complete response for ChatResponse.

        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.

        Returns:
            LLMResponse with generated content and metadata.

        Raises:
            LLMException: If generation fails.
        """
        pass

    # NOTE(review): despite the coroutine-style annotation, implementations
    # are presumably async generators (``yield LLMStreamChunk``) — confirm
    # against OpenAIClient before tightening the signature.
    @abstractmethod
    async def stream_generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[LLMStreamChunk, None]:
        """
        Generate a streaming response.
        [AC-AISVC-06, AC-AISVC-07] Yields incremental chunks for SSE.

        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.

        Yields:
            LLMStreamChunk with incremental content.

        Raises:
            LLMException: If generation fails.
        """
        pass

    @abstractmethod
    async def close(self) -> None:
        """Close the client and release resources."""
        pass
|
||||
|
|
@ -0,0 +1,421 @@
|
|||
"""
|
||||
LLM Provider Factory and Configuration Management.
|
||||
[AC-ASA-14, AC-ASA-15, AC-ASA-16, AC-ASA-17, AC-ASA-18] LLM provider management.
|
||||
|
||||
Design pattern: Factory pattern for pluggable LLM providers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.services.llm.base import LLMClient, LLMConfig
|
||||
from app.services.llm.openai_client import OpenAIClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class LLMProviderInfo:
    """Information about an LLM provider."""
    name: str  # Registry key / machine name.
    display_name: str  # Human-readable name shown in the admin UI.
    description: str  # Short description of the provider.
    config_schema: dict[str, Any]  # JSON-Schema-like description of config fields.
|
||||
|
||||
|
||||
def _provider_schema(specific: dict[str, Any], required: list[str]) -> dict[str, Any]:
    """
    Assemble a provider config schema: the provider-specific properties
    followed by the generation parameters (max_tokens, temperature) that
    every provider shares. Fresh dicts are returned on each call so the
    schemas never alias each other.
    """
    props: dict[str, Any] = dict(specific)
    props["max_tokens"] = {
        "type": "integer",
        "title": "最大输出 Token 数",
        "description": "最大输出 Token 数",
        "default": 2048,
    }
    props["temperature"] = {
        "type": "number",
        "title": "温度参数",
        "description": "温度参数 (0-2)",
        "default": 0.7,
        "minimum": 0,
        "maximum": 2,
    }
    return {
        "type": "object",
        "properties": props,
        "required": required,
    }


# Registry of all supported LLM providers, keyed by provider name.
LLM_PROVIDERS: dict[str, LLMProviderInfo] = {
    "openai": LLMProviderInfo(
        name="openai",
        display_name="OpenAI",
        description="OpenAI GPT 系列模型 (GPT-4, GPT-3.5 等)",
        config_schema=_provider_schema(
            {
                "api_key": {
                    "type": "string",
                    "title": "API Key",
                    "description": "API Key",
                    "required": True,
                },
                "base_url": {
                    "type": "string",
                    "title": "API Base URL",
                    "description": "API Base URL",
                    "default": "https://api.openai.com/v1",
                },
                "model": {
                    "type": "string",
                    "title": "模型名称",
                    "description": "模型名称",
                    "default": "gpt-4o-mini",
                },
            },
            ["api_key"],
        ),
    ),
    "ollama": LLMProviderInfo(
        name="ollama",
        display_name="Ollama",
        description="Ollama 本地模型 (Llama, Qwen 等)",
        config_schema=_provider_schema(
            {
                "base_url": {
                    "type": "string",
                    "title": "Ollama API 地址",
                    "description": "Ollama API 地址",
                    "default": "http://localhost:11434/v1",
                },
                "model": {
                    "type": "string",
                    "title": "模型名称",
                    "description": "模型名称",
                    "default": "llama3.2",
                },
            },
            [],
        ),
    ),
    "deepseek": LLMProviderInfo(
        name="deepseek",
        display_name="DeepSeek",
        description="DeepSeek 大模型 (deepseek-chat, deepseek-coder)",
        config_schema=_provider_schema(
            {
                "api_key": {
                    "type": "string",
                    "title": "API Key",
                    "description": "DeepSeek API Key",
                    "required": True,
                },
                "base_url": {
                    "type": "string",
                    "title": "API Base URL",
                    "description": "API Base URL",
                    "default": "https://api.deepseek.com/v1",
                },
                "model": {
                    "type": "string",
                    "title": "模型名称",
                    "description": "模型名称 (deepseek-chat, deepseek-coder)",
                    "default": "deepseek-chat",
                },
            },
            ["api_key"],
        ),
    ),
    "azure": LLMProviderInfo(
        name="azure",
        display_name="Azure OpenAI",
        description="Azure OpenAI 服务",
        config_schema=_provider_schema(
            {
                "api_key": {
                    "type": "string",
                    "title": "API Key",
                    "description": "API Key",
                    "required": True,
                },
                "base_url": {
                    "type": "string",
                    "title": "Azure Endpoint",
                    "description": "Azure Endpoint",
                    "required": True,
                },
                "model": {
                    "type": "string",
                    "title": "部署名称",
                    "description": "部署名称",
                    "required": True,
                },
                "api_version": {
                    "type": "string",
                    "title": "API 版本",
                    "description": "API 版本",
                    "default": "2024-02-15-preview",
                },
            },
            ["api_key", "base_url", "model"],
        ),
    ),
}
|
||||
|
||||
|
||||
class LLMProviderFactory:
    """
    Factory for creating LLM clients.
    [AC-ASA-14, AC-ASA-15] Dynamic provider creation.
    """

    # Providers whose wire protocol is OpenAI-compatible and are therefore
    # all served by the same OpenAIClient implementation.
    _OPENAI_COMPATIBLE = ("openai", "ollama", "azure", "deepseek")

    @classmethod
    def get_providers(cls) -> list[LLMProviderInfo]:
        """Get all registered LLM providers."""
        return list(LLM_PROVIDERS.values())

    @classmethod
    def get_provider_info(cls, name: str) -> LLMProviderInfo | None:
        """Get provider info by name."""
        return LLM_PROVIDERS.get(name)

    @classmethod
    def create_client(
        cls,
        provider: str,
        config: dict[str, Any],
    ) -> LLMClient:
        """
        Create an LLM client for the specified provider.
        [AC-ASA-15] Factory method for client creation.

        Args:
            provider: Provider name (openai, ollama, azure)
            config: Provider configuration

        Returns:
            LLMClient instance

        Raises:
            ValueError: If provider is not supported
        """
        # Single guard: the provider must be registered AND handled by an
        # available client implementation.
        if provider not in LLM_PROVIDERS or provider not in cls._OPENAI_COMPATIBLE:
            raise ValueError(f"Unsupported LLM provider: {provider}")

        defaults = LLMConfig(
            model=config.get("model", "gpt-4o-mini"),
            max_tokens=config.get("max_tokens", 2048),
            temperature=config.get("temperature", 0.7),
        )
        return OpenAIClient(
            api_key=config.get("api_key"),
            base_url=config.get("base_url"),
            model=config.get("model"),
            default_config=defaults,
        )
|
||||
|
||||
|
||||
class LLMConfigManager:
|
||||
"""
|
||||
Manager for LLM configuration.
|
||||
[AC-ASA-16, AC-ASA-17, AC-ASA-18] Configuration management with hot-reload.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
from app.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
self._current_provider: str = settings.llm_provider
|
||||
self._current_config: dict[str, Any] = {
|
||||
"api_key": settings.llm_api_key,
|
||||
"base_url": settings.llm_base_url,
|
||||
"model": settings.llm_model,
|
||||
"max_tokens": settings.llm_max_tokens,
|
||||
"temperature": settings.llm_temperature,
|
||||
}
|
||||
self._client: LLMClient | None = None
|
||||
|
||||
def get_current_config(self) -> dict[str, Any]:
|
||||
"""Get current LLM configuration."""
|
||||
return {
|
||||
"provider": self._current_provider,
|
||||
"config": self._current_config,
|
||||
}
|
||||
|
||||
async def update_config(
|
||||
self,
|
||||
provider: str,
|
||||
config: dict[str, Any],
|
||||
) -> bool:
|
||||
"""
|
||||
Update LLM configuration.
|
||||
[AC-ASA-16] Hot-reload configuration.
|
||||
|
||||
Args:
|
||||
provider: Provider name
|
||||
config: New configuration
|
||||
|
||||
Returns:
|
||||
True if update successful
|
||||
"""
|
||||
if provider not in LLM_PROVIDERS:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
provider_info = LLM_PROVIDERS[provider]
|
||||
validated_config = self._validate_config(provider_info, config)
|
||||
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
|
||||
self._current_provider = provider
|
||||
self._current_config = validated_config
|
||||
|
||||
logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")
|
||||
return True
|
||||
|
||||
def _validate_config(
|
||||
self,
|
||||
provider_info: LLMProviderInfo,
|
||||
config: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Validate configuration against provider schema."""
|
||||
schema_props = provider_info.config_schema.get("properties", {})
|
||||
required_fields = provider_info.config_schema.get("required", [])
|
||||
|
||||
validated = {}
|
||||
for key, prop_schema in schema_props.items():
|
||||
if key in config:
|
||||
validated[key] = config[key]
|
||||
elif "default" in prop_schema:
|
||||
validated[key] = prop_schema["default"]
|
||||
elif key in required_fields:
|
||||
raise ValueError(f"Missing required config: {key}")
|
||||
return validated
|
||||
|
||||
def get_client(self) -> LLMClient:
|
||||
"""Get or create LLM client with current config."""
|
||||
if self._client is None:
|
||||
self._client = LLMProviderFactory.create_client(
|
||||
self._current_provider,
|
||||
self._current_config,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def test_connection(
    self,
    test_prompt: str = "你好,请简单介绍一下自己。",
    provider: str | None = None,
    config: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """
    Run a one-shot round trip against an LLM provider to verify connectivity.
    [AC-ASA-17, AC-ASA-18] Connection testing.

    A throwaway client is created, a single prompt is sent, and the client
    is closed again; latency and token usage are reported back.

    Args:
        test_prompt: Prompt sent for the probe request.
        provider: Provider to probe (falls back to the active provider).
        config: Config to probe with (falls back to the active config; note
            an empty dict also falls back, since truthiness is used).

    Returns:
        Dict with ``success`` plus either response/latency/token metrics or
        an error description.
    """
    import time

    chosen_provider = provider or self._current_provider
    chosen_config = config if config else self._current_config

    logger.info(f"[AC-ASA-17] Test connection: provider={chosen_provider}, config={chosen_config}")

    if chosen_provider not in LLM_PROVIDERS:
        return {
            "success": False,
            "error": f"Unsupported provider: {chosen_provider}",
        }

    try:
        probe_client = LLMProviderFactory.create_client(chosen_provider, chosen_config)

        started = time.time()
        result = await probe_client.generate(
            messages=[{"role": "user", "content": test_prompt}],
        )
        elapsed_ms = (time.time() - started) * 1000

        await probe_client.close()

        usage = result.usage
        return {
            "success": True,
            "response": result.content,
            "latency_ms": round(elapsed_ms, 2),
            "prompt_tokens": usage.get("prompt_tokens", 0),
            "completion_tokens": usage.get("completion_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
            "model": result.model,
            "message": f"连接成功,模型: {result.model}",
        }

    except Exception as exc:
        logger.error(f"[AC-ASA-18] LLM test failed: {exc}")
        return {
            "success": False,
            "error": str(exc),
            "message": f"连接失败: {str(exc)}",
        }
|
||||
|
||||
async def close(self) -> None:
    """Release the currently cached LLM client, if one exists."""
    client = self._client
    if client:
        await client.close()
        self._client = None
|
||||
|
||||
|
||||
# Module-level singleton; built lazily on first access.
_llm_config_manager: LLMConfigManager | None = None


def get_llm_config_manager() -> LLMConfigManager:
    """Return the process-wide LLMConfigManager, creating it on first call.

    NOTE(review): not lock-guarded — fine for single-threaded/asyncio use,
    but concurrent first calls from threads could race; confirm usage.
    """
    global _llm_config_manager
    manager = _llm_config_manager
    if manager is None:
        manager = LLMConfigManager()
        _llm_config_manager = manager
    return manager
|
||||
|
|
@ -0,0 +1,333 @@
|
|||
"""
|
||||
OpenAI-compatible LLM client implementation.
|
||||
[AC-AISVC-02, AC-AISVC-06] Concrete implementation using httpx for OpenAI API.
|
||||
|
||||
Design reference: design.md Section 8.1 - LLMClient interface
|
||||
- Uses langchain-openai or official SDK pattern
|
||||
- Supports generate and stream_generate
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.exceptions import AIServiceException, ErrorCode, ServiceUnavailableException, TimeoutException
|
||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMException(AIServiceException):
    """Exception raised when LLM operations fail.

    Every instance carries ErrorCode.LLM_ERROR and HTTP status 503, so all
    LLM failures surface uniformly to API callers.
    """

    def __init__(self, message: str, details: list[dict] | None = None):
        # Pin code/status here so call sites only supply message + details.
        super().__init__(
            code=ErrorCode.LLM_ERROR,
            message=message,
            status_code=503,
            details=details,
        )
|
||||
|
||||
|
||||
class OpenAIClient(LLMClient):
    """
    OpenAI-compatible LLM client.
    [AC-AISVC-02, AC-AISVC-06] Implements LLMClient interface for OpenAI API.

    Supports:
    - OpenAI API (official)
    - OpenAI-compatible endpoints (Azure, local models, etc.)
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model: str | None = None,
        default_config: LLMConfig | None = None,
    ):
        """Initialize from explicit arguments, falling back to app settings.

        Args:
            api_key: Bearer token; defaults to ``settings.llm_api_key``.
            base_url: API root (trailing slash stripped); defaults to settings.
            model: Default model name; defaults to ``settings.llm_model``.
            default_config: LLMConfig used when a call passes no config.
        """
        settings = get_settings()
        self._api_key = api_key or settings.llm_api_key
        self._base_url = (base_url or settings.llm_base_url).rstrip("/")
        self._model = model or settings.llm_model
        self._default_config = default_config or LLMConfig(
            model=self._model,
            max_tokens=settings.llm_max_tokens,
            temperature=settings.llm_temperature,
            timeout_seconds=settings.llm_timeout_seconds,
            max_retries=settings.llm_max_retries,
        )
        # Lazily-created shared HTTP client (see _get_client).
        self._client: httpx.AsyncClient | None = None

    def _get_client(self, timeout_seconds: int) -> httpx.AsyncClient:
        """Get or create the shared HTTP client.

        NOTE(review): the timeout is baked in when the client is first
        created; later calls with a different ``timeout_seconds`` reuse the
        cached client unchanged — confirm this is intended.
        """
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(timeout_seconds),
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
            )
        return self._client

    def _build_request_body(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig,
        stream: bool = False,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the JSON body for /chat/completions.

        ``config.extra_params`` and per-call ``kwargs`` are merged last, so
        they may override the standard fields.
        """
        body: dict[str, Any] = {
            "model": config.model,
            "messages": messages,
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "stream": stream,
        }
        body.update(config.extra_params)
        body.update(kwargs)
        return body

    @retry(
        # Bug fix: the body converts httpx.TimeoutException into the
        # app-level TimeoutException before tenacity can observe it, so
        # retrying only on httpx.TimeoutException never actually fired.
        # Retry on both, and use reraise=True so the final failure surfaces
        # as TimeoutException rather than tenacity.RetryError (keeps the
        # exception type callers catch unchanged).
        retry=retry_if_exception_type((httpx.TimeoutException, TimeoutException)),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        reraise=True,
    )
    async def generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Generate a non-streaming response.
        [AC-AISVC-02] Returns complete response for ChatResponse.

        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.

        Returns:
            LLMResponse with generated content and metadata.

        Raises:
            LLMException: If generation fails.
            TimeoutException: If the request still times out after retries.
        """
        effective_config = config or self._default_config
        client = self._get_client(effective_config.timeout_seconds)

        body = self._build_request_body(messages, effective_config, stream=False, **kwargs)

        logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}")
        # Full-prompt logging requirement. NOTE(review): this writes complete
        # user/system content at INFO level — may contain sensitive data;
        # consider DEBUG level or redaction if logs are shared.
        logger.info(f"[AC-AISVC-02] ========== FULL PROMPT TO AI ==========")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            logger.info(f"[AC-AISVC-02] [{i}] role={role}, content_length={len(content)}")
            logger.info(f"[AC-AISVC-02] [{i}] content:\n{content}")
        logger.info(f"[AC-AISVC-02] ======================================")

        try:
            response = await client.post(
                f"{self._base_url}/chat/completions",
                json=body,
            )
            response.raise_for_status()
            data = response.json()

        except httpx.TimeoutException as e:
            logger.error(f"[AC-AISVC-02] LLM request timeout: {e}")
            # Re-raised as the app-level TimeoutException; the @retry
            # decorator above retries on it (see bug-fix note there).
            raise TimeoutException(message=f"LLM request timed out: {e}")

        except httpx.HTTPStatusError as e:
            logger.error(f"[AC-AISVC-02] LLM API error: {e}")
            error_detail = self._parse_error_response(e.response)
            raise LLMException(
                message=f"LLM API error: {error_detail}",
                details=[{"status_code": e.response.status_code, "response": error_detail}],
            )

        except json.JSONDecodeError as e:
            logger.error(f"[AC-AISVC-02] Failed to parse LLM response: {e}")
            raise LLMException(message=f"Failed to parse LLM response: {e}")

        # Defensive unpacking: tolerate a missing usage block, but treat a
        # missing choices/message structure as a provider contract violation.
        try:
            choice = data["choices"][0]
            content = choice["message"]["content"]
            usage = data.get("usage", {})
            finish_reason = choice.get("finish_reason", "stop")

            logger.info(
                f"[AC-AISVC-02] Generated response: "
                f"tokens={usage.get('total_tokens', 'N/A')}, "
                f"finish_reason={finish_reason}"
            )

            return LLMResponse(
                content=content,
                model=data.get("model", effective_config.model),
                usage=usage,
                finish_reason=finish_reason,
                metadata={"raw_response": data},
            )

        except (KeyError, IndexError) as e:
            logger.error(f"[AC-AISVC-02] Unexpected LLM response format: {e}")
            raise LLMException(
                message=f"Unexpected LLM response format: {e}",
                details=[{"response": str(data)}],
            )

    async def stream_generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[LLMStreamChunk, None]:
        """
        Generate a streaming response.
        [AC-AISVC-06, AC-AISVC-07] Yields incremental chunks for SSE.

        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.

        Yields:
            LLMStreamChunk with incremental content.

        Raises:
            LLMException: If generation fails.
            TimeoutException: If request times out. (NOTE(review): unlike
            ``generate``, this async generator is not retried — tenacity
            cannot wrap generators — so a timeout here fails immediately.)
        """
        effective_config = config or self._default_config
        client = self._get_client(effective_config.timeout_seconds)

        body = self._build_request_body(messages, effective_config, stream=True, **kwargs)

        logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}")
        # Same full-prompt logging caveat as in generate().
        logger.info(f"[AC-AISVC-06] ========== FULL PROMPT TO AI (STREAMING) ==========")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            logger.info(f"[AC-AISVC-06] [{i}] role={role}, content_length={len(content)}")
            logger.info(f"[AC-AISVC-06] [{i}] content:\n{content}")
        logger.info(f"[AC-AISVC-06] ======================================")

        try:
            async with client.stream(
                "POST",
                f"{self._base_url}/chat/completions",
                json=body,
            ) as response:
                response.raise_for_status()

                # OpenAI SSE framing: each payload line is "data: {json}",
                # terminated by the sentinel "data: [DONE]".
                async for line in response.aiter_lines():
                    if not line or line == "data: [DONE]":
                        continue

                    if line.startswith("data: "):
                        json_str = line[6:]
                        try:
                            chunk_data = json.loads(json_str)
                            chunk = self._parse_stream_chunk(chunk_data, effective_config.model)
                            if chunk:
                                yield chunk
                        except json.JSONDecodeError as e:
                            # A malformed chunk is skipped, not fatal.
                            logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}")
                            continue

        except httpx.TimeoutException as e:
            logger.error(f"[AC-AISVC-06] LLM streaming request timeout: {e}")
            raise TimeoutException(message=f"LLM streaming request timed out: {e}")

        except httpx.HTTPStatusError as e:
            logger.error(f"[AC-AISVC-06] LLM streaming API error: {e}")
            error_detail = self._parse_error_response(e.response)
            raise LLMException(
                message=f"LLM streaming API error: {error_detail}",
                details=[{"status_code": e.response.status_code, "response": error_detail}],
            )

        logger.info(f"[AC-AISVC-06] Streaming generation completed")

    def _parse_stream_chunk(
        self,
        data: dict[str, Any],
        model: str,
    ) -> LLMStreamChunk | None:
        """Parse one streaming chunk from the OpenAI API.

        Returns None for keep-alive/empty chunks (no delta content and no
        finish_reason) and for structurally malformed chunks.
        """
        try:
            choices = data.get("choices", [])
            if not choices:
                return None

            delta = choices[0].get("delta", {})
            content = delta.get("content", "")
            finish_reason = choices[0].get("finish_reason")

            if not content and not finish_reason:
                return None

            return LLMStreamChunk(
                delta=content,
                model=data.get("model", model),
                finish_reason=finish_reason,
                metadata={"raw_chunk": data},
            )

        except (KeyError, IndexError) as e:
            logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}")
            return None

    def _parse_error_response(self, response: httpx.Response) -> str:
        """Extract a human-readable message from an API error response.

        Prefers the OpenAI-style {"error": {"message": ...}} shape; falls
        back to the raw response text for anything else.
        """
        try:
            data = response.json()
            if "error" in data:
                error = data["error"]
                if isinstance(error, dict):
                    return error.get("message", str(error))
                return str(error)
            return response.text
        except Exception:
            return response.text

    async def close(self) -> None:
        """Close the shared HTTP client and drop the cached reference."""
        if self._client:
            await self._client.aclose()
            self._client = None
|
||||
|
||||
|
||||
# Lazily-created module-level singleton client.
_llm_client: OpenAIClient | None = None


def get_llm_client() -> OpenAIClient:
    """Return the process-wide OpenAIClient, building it on first use."""
    global _llm_client
    client = _llm_client
    if client is None:
        client = OpenAIClient()
        _llm_client = client
    return client


async def close_llm_client() -> None:
    """Tear down the global client (if any) and reset the singleton."""
    global _llm_client
    client = _llm_client
    if client:
        await client.close()
        _llm_client = None
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
"""
|
||||
Memory service for AI Service.
|
||||
[AC-AISVC-13] Session-based memory management with tenant isolation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import col
|
||||
|
||||
from app.models.entities import ChatMessage, ChatMessageCreate, ChatSession, ChatSessionCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryService:
    """
    [AC-AISVC-13] Session-scoped conversation memory backed by the database.

    Every query and mutation is keyed on (tenant_id, session_id), so one
    tenant can never read or modify another tenant's sessions.
    """

    def __init__(self, session: AsyncSession):
        # The service flushes but never commits; transaction boundaries are
        # owned by the caller that supplied the AsyncSession.
        self._session = session

    async def get_or_create_session(
        self,
        tenant_id: str,
        session_id: str,
        channel_type: str | None = None,
        metadata: dict | None = None,
    ) -> ChatSession:
        """
        [AC-AISVC-13] Return the tenant's session row, creating it if absent.

        The lookup filters on tenant_id as well as session_id, so distinct
        tenants may reuse the same session_id without collision.
        """
        lookup = select(ChatSession).where(
            ChatSession.tenant_id == tenant_id,
            ChatSession.session_id == session_id,
        )
        found = (await self._session.execute(lookup)).scalar_one_or_none()

        if found:
            logger.info(
                f"[AC-AISVC-13] Found existing session: tenant={tenant_id}, session={session_id}"
            )
            return found

        created = ChatSession(
            tenant_id=tenant_id,
            session_id=session_id,
            channel_type=channel_type,
            metadata_=metadata,
        )
        self._session.add(created)
        await self._session.flush()

        logger.info(
            f"[AC-AISVC-13] Created new session: tenant={tenant_id}, session={session_id}"
        )
        return created

    async def load_history(
        self,
        tenant_id: str,
        session_id: str,
        limit: int | None = None,
    ) -> Sequence[ChatMessage]:
        """
        [AC-AISVC-13] Fetch a session's messages in chronological order.

        NOTE(review): with ascending order, ``limit`` returns the OLDEST n
        messages rather than the most recent — confirm that is intended.
        """
        query = (
            select(ChatMessage)
            .where(
                ChatMessage.tenant_id == tenant_id,
                ChatMessage.session_id == session_id,
            )
            .order_by(col(ChatMessage.created_at).asc())
        )
        if limit:
            query = query.limit(limit)

        rows = (await self._session.execute(query)).scalars().all()

        logger.info(
            f"[AC-AISVC-13] Loaded {len(rows)} messages for tenant={tenant_id}, session={session_id}"
        )
        return rows

    async def append_message(
        self,
        tenant_id: str,
        session_id: str,
        role: str,
        content: str,
    ) -> ChatMessage:
        """
        [AC-AISVC-13] Persist a single message under the tenant's session.
        """
        row = ChatMessage(
            tenant_id=tenant_id,
            session_id=session_id,
            role=role,
            content=content,
        )
        self._session.add(row)
        await self._session.flush()

        logger.info(
            f"[AC-AISVC-13] Appended message: tenant={tenant_id}, session={session_id}, role={role}"
        )
        return row

    async def append_messages(
        self,
        tenant_id: str,
        session_id: str,
        messages: list[dict[str, str]],
    ) -> list[ChatMessage]:
        """
        [AC-AISVC-13] Persist a batch of {"role", "content"} dicts in order.
        Used for inserting a whole conversation turn at once.
        """
        rows = [
            ChatMessage(
                tenant_id=tenant_id,
                session_id=session_id,
                role=entry["role"],
                content=entry["content"],
            )
            for entry in messages
        ]
        for row in rows:
            self._session.add(row)

        await self._session.flush()

        logger.info(
            f"[AC-AISVC-13] Appended {len(rows)} messages for tenant={tenant_id}, session={session_id}"
        )
        return rows

    async def clear_history(self, tenant_id: str, session_id: str) -> int:
        """
        [AC-AISVC-13] Delete every message in the tenant's session.

        Returns the number of rows removed. Rows are deleted one at a time
        through the ORM (preserves cascade/identity-map semantics).
        """
        query = select(ChatMessage).where(
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.session_id == session_id,
        )
        rows = (await self._session.execute(query)).scalars().all()

        for row in rows:
            await self._session.delete(row)
        removed = len(rows)

        await self._session.flush()

        logger.info(
            f"[AC-AISVC-13] Cleared {removed} messages for tenant={tenant_id}, session={session_id}"
        )
        return removed
|
||||
|
|
@ -0,0 +1,689 @@
|
|||
"""
|
||||
Orchestrator service for AI Service.
|
||||
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Core orchestration logic for chat generation.
|
||||
|
||||
Design reference: design.md Section 2.2 - 关键数据流
|
||||
1. Memory.load(tenantId, sessionId)
|
||||
2. merge_context(local_history, external_history)
|
||||
3. Retrieval.retrieve(query, tenantId, channelType, metadata)
|
||||
4. build_prompt(merged_history, retrieved_docs, currentMessage)
|
||||
5. LLM.generate(...) (non-streaming) or LLM.stream_generate(...) (streaming)
|
||||
6. compute_confidence(...)
|
||||
7. Memory.append(tenantId, sessionId, user/assistant messages)
|
||||
8. Return ChatResponse (or output via SSE)
|
||||
|
||||
RAG Optimization (rag-optimization/spec.md):
|
||||
- Two-stage retrieval with Matryoshka dimensions
|
||||
- RRF hybrid ranking
|
||||
- Optimized prompt engineering
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from sse_starlette.sse import ServerSentEvent
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.prompts import SYSTEM_PROMPT, format_evidence_for_prompt
|
||||
from app.core.sse import (
|
||||
create_error_event,
|
||||
create_final_event,
|
||||
create_message_event,
|
||||
SSEStateMachine,
|
||||
)
|
||||
from app.models import ChatRequest, ChatResponse
|
||||
from app.services.confidence import ConfidenceCalculator, ConfidenceResult
|
||||
from app.services.context import ContextMerger, MergedContext
|
||||
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse
|
||||
from app.services.memory import MemoryService
|
||||
from app.services.retrieval.base import BaseRetriever, RetrievalContext, RetrievalResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class OrchestratorConfig:
    """
    Configuration for OrchestratorService.
    [AC-AISVC-01] Centralized configuration for orchestration.
    """
    # Token budget for the merged conversation history passed to the LLM.
    max_history_tokens: int = 4000
    # Token budget for formatted RAG evidence.
    max_evidence_tokens: int = 2000
    # Base system prompt; knowledge-base evidence is appended to it when
    # retrieval produces hits (see _build_llm_messages).
    system_prompt: str = SYSTEM_PROMPT
    # Master switch for the retrieval step; retrieval also requires a
    # retriever instance to be wired in.
    enable_rag: bool = True
    # NOTE(review): within this file this flag is only read for logging —
    # confirm the optimized retriever is actually selected on it elsewhere.
    use_optimized_retriever: bool = True
|
||||
|
||||
|
||||
@dataclass
class GenerationContext:
    """
    [AC-AISVC-01, AC-AISVC-02] Context accumulated during generation pipeline.
    Contains all intermediate results for diagnostics and response building.
    """
    # --- Request identity / input (set up front by the orchestrator) ---
    tenant_id: str
    session_id: str
    current_message: str
    channel_type: str
    request_metadata: dict[str, Any] | None = None

    # --- Intermediate pipeline results, populated step by step ---
    # Step 1: history loaded from the memory service.
    local_history: list[dict[str, str]] = field(default_factory=list)
    # Step 2: local + external history after dedup/truncation.
    merged_context: MergedContext | None = None
    # Step 3: RAG hits (None when retrieval is disabled or not wired in).
    retrieval_result: RetrievalResult | None = None
    # Steps 4-5: raw LLM output (or a fallback response on failure).
    llm_response: LLMResponse | None = None
    # Step 6: confidence scoring of the reply.
    confidence_result: ConfidenceResult | None = None

    # Free-form debug info surfaced in the final response metadata.
    diagnostics: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class OrchestratorService:
|
||||
"""
|
||||
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Orchestrator for chat generation.
|
||||
Coordinates memory, retrieval, and LLM components.
|
||||
|
||||
SSE Event Flow (per design.md Section 6.2):
|
||||
- message* (0 or more) -> final (exactly 1) -> close
|
||||
- OR message* (0 or more) -> error (exactly 1) -> close
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    llm_client: LLMClient | None = None,
    memory_service: MemoryService | None = None,
    retriever: BaseRetriever | None = None,
    context_merger: ContextMerger | None = None,
    confidence_calculator: ConfidenceCalculator | None = None,
    config: OrchestratorConfig | None = None,
):
    """
    Wire up the orchestrator's collaborators; any omitted dependency gets a
    default built from application settings.

    Args:
        llm_client: LLM client for generation (fallback replies when None).
        memory_service: Session history store (history step skipped when None).
        retriever: Retriever for RAG (retrieval step skipped when None).
        context_merger: History merger/deduper; defaulted from settings.
        confidence_calculator: Response scorer; defaulted.
        config: Orchestrator configuration; defaulted from settings.
    """
    settings = get_settings()
    history_budget = getattr(settings, "max_history_tokens", 4000)

    self._llm_client = llm_client
    self._memory_service = memory_service
    self._retriever = retriever
    self._context_merger = context_merger or ContextMerger(
        max_history_tokens=history_budget
    )
    self._confidence_calculator = confidence_calculator or ConfidenceCalculator()
    self._config = config or OrchestratorConfig(
        max_history_tokens=history_budget,
        max_evidence_tokens=getattr(settings, "rag_max_evidence_tokens", 2000),
        enable_rag=True,
    )
    self._llm_config = LLMConfig(
        model=getattr(settings, "llm_model", "gpt-4o-mini"),
        max_tokens=getattr(settings, "llm_max_tokens", 2048),
        temperature=getattr(settings, "llm_temperature", 0.7),
        timeout_seconds=getattr(settings, "llm_timeout_seconds", 30),
        max_retries=getattr(settings, "llm_max_retries", 3),
    )
|
||||
|
||||
async def generate(
    self,
    tenant_id: str,
    request: ChatRequest,
) -> ChatResponse:
    """
    Generate a non-streaming response.
    [AC-AISVC-01, AC-AISVC-02] Complete generation pipeline.

    Pipeline (per design.md Section 2.2):
    1. Load local history from Memory
    2. Merge with external history (dedup + truncate)
    3. RAG retrieval (optional)
    4. Build prompt with context and evidence
    5. LLM generation
    6. Calculate confidence
    7. Save messages to Memory
    8. Return ChatResponse

    Never raises: any pipeline failure is converted into a canned
    human-transfer ChatResponse with the error recorded in metadata.
    """
    logger.info(
        f"[AC-AISVC-01] Starting generation for tenant={tenant_id}, "
        f"session={request.session_id}, channel_type={request.channel_type}, "
        f"current_message={request.current_message[:100]}..."
    )
    logger.info(
        f"[AC-AISVC-01] Config: enable_rag={self._config.enable_rag}, "
        f"use_optimized_retriever={self._config.use_optimized_retriever}, "
        f"llm_client={'configured' if self._llm_client else 'NOT configured'}, "
        f"retriever={'configured' if self._retriever else 'NOT configured'}"
    )

    # Context object threaded through every pipeline step; also collects
    # diagnostics that end up in the response metadata.
    ctx = GenerationContext(
        tenant_id=tenant_id,
        session_id=request.session_id,
        current_message=request.current_message,
        channel_type=request.channel_type.value,
        request_metadata=request.metadata,
    )

    try:
        # Step 1: local session history.
        await self._load_local_history(ctx)

        # Step 2: merge with caller-supplied external history.
        await self._merge_context(ctx, request.history)

        # Step 3: RAG retrieval — only when enabled AND a retriever exists.
        if self._config.enable_rag and self._retriever:
            await self._retrieve_evidence(ctx)

        # Steps 4-5: prompt assembly + LLM call (falls back internally).
        await self._generate_response(ctx)

        # Step 6: score the reply.
        self._calculate_confidence(ctx)

        # Step 7: persist the user/assistant turn.
        await self._save_messages(ctx)

        # Step 8: assemble the final ChatResponse.
        return self._build_response(ctx)

    except Exception as e:
        # Last-resort guard: degrade to a transfer-to-human reply rather
        # than surfacing a 500 to the channel.
        logger.error(f"[AC-AISVC-01] Generation failed: {e}")
        return ChatResponse(
            reply="抱歉,服务暂时不可用,请稍后重试或联系人工客服。",
            confidence=0.0,
            should_transfer=True,
            transfer_reason=f"服务异常: {str(e)}",
            metadata={"error": str(e), "diagnostics": ctx.diagnostics},
        )
|
||||
|
||||
async def _load_local_history(self, ctx: GenerationContext) -> None:
    """
    [AC-AISVC-13] Pull the session's stored history into ctx.local_history.
    Step 1 of the generation pipeline. Failures are logged and recorded in
    diagnostics but never abort the pipeline.
    """
    if not self._memory_service:
        logger.info("[AC-AISVC-13] No memory service configured, skipping history load")
        ctx.diagnostics["memory_enabled"] = False
        return

    try:
        rows = await self._memory_service.load_history(
            tenant_id=ctx.tenant_id,
            session_id=ctx.session_id,
        )
        ctx.local_history = [
            {"role": row.role, "content": row.content} for row in rows
        ]

        ctx.diagnostics["memory_enabled"] = True
        ctx.diagnostics["local_history_count"] = len(ctx.local_history)

        logger.info(
            f"[AC-AISVC-13] Loaded {len(ctx.local_history)} messages from memory "
            f"for tenant={ctx.tenant_id}, session={ctx.session_id}"
        )

    except Exception as exc:
        logger.warning(f"[AC-AISVC-13] Failed to load history: {exc}")
        ctx.diagnostics["memory_error"] = str(exc)
|
||||
|
||||
async def _merge_context(
    self,
    ctx: GenerationContext,
    external_history: list | None,
) -> None:
    """
    [AC-AISVC-14, AC-AISVC-15] Merge local and external history.
    Step 2 of the generation pipeline.

    Design reference: design.md Section 7
    - Deduplication based on fingerprint
    - Truncation to fit token budget
    """
    normalized = (
        [{"role": item.role.value, "content": item.content} for item in external_history]
        if external_history
        else None
    )

    merged = self._context_merger.merge_and_truncate(
        local_history=ctx.local_history,
        external_history=normalized,
        max_tokens=self._config.max_history_tokens,
    )
    ctx.merged_context = merged

    ctx.diagnostics["merged_context"] = {
        "local_count": merged.local_count,
        "external_count": merged.external_count,
        "duplicates_skipped": merged.duplicates_skipped,
        "truncated_count": merged.truncated_count,
        "total_tokens": merged.total_tokens,
    }

    logger.info(
        f"[AC-AISVC-14, AC-AISVC-15] Context merged: "
        f"local={merged.local_count}, "
        f"external={merged.external_count}, "
        f"tokens={merged.total_tokens}"
    )
|
||||
|
||||
async def _retrieve_evidence(self, ctx: GenerationContext) -> None:
    """
    [AC-AISVC-16, AC-AISVC-17] Run RAG retrieval and stash hits on the context.
    Step 3 of the generation pipeline. A retrieval failure degrades to an
    empty result (recorded in diagnostics) instead of failing the request.
    """
    logger.info(
        f"[AC-AISVC-16] Starting retrieval: tenant={ctx.tenant_id}, "
        f"query={ctx.current_message[:100]}..., retriever={type(self._retriever).__name__ if self._retriever else 'None'}"
    )
    try:
        query_ctx = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.current_message,
            session_id=ctx.session_id,
            channel_type=ctx.channel_type,
            metadata=ctx.request_metadata,
        )

        outcome = await self._retriever.retrieve(query_ctx)
        ctx.retrieval_result = outcome

        ctx.diagnostics["retrieval"] = {
            "hit_count": outcome.hit_count,
            "max_score": outcome.max_score,
            "is_empty": outcome.is_empty,
        }

        logger.info(
            f"[AC-AISVC-16, AC-AISVC-17] Retrieval complete: "
            f"hits={outcome.hit_count}, "
            f"max_score={outcome.max_score:.3f}, "
            f"is_empty={outcome.is_empty}"
        )

        # Log a short preview of the top hits for debugging.
        if outcome.hit_count > 0:
            for i, hit in enumerate(outcome.hits[:3]):
                logger.info(
                    f"[AC-AISVC-16] Hit {i+1}: score={hit.score:.3f}, "
                    f"text_preview={hit.text[:100]}..."
                )

    except Exception as e:
        logger.error(f"[AC-AISVC-16] Retrieval failed with exception: {e}", exc_info=True)
        ctx.retrieval_result = RetrievalResult(
            hits=[],
            diagnostics={"error": str(e)},
        )
        ctx.diagnostics["retrieval_error"] = str(e)
|
||||
|
||||
async def _generate_response(self, ctx: GenerationContext) -> None:
    """
    [AC-AISVC-02] Generate response using LLM.
    Step 4-5 of the generation pipeline.

    Populates ctx.llm_response; never raises. When no LLM client is wired
    in, or the LLM call fails, a canned fallback reply is substituted and
    the reason is recorded in ctx.diagnostics ("llm_mode"/"fallback_reason").
    """
    messages = self._build_llm_messages(ctx)
    logger.info(
        f"[AC-AISVC-02] Building LLM messages: count={len(messages)}, "
        f"has_retrieval_result={ctx.retrieval_result is not None}, "
        f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else 'N/A'}, "
        f"llm_client={'configured' if self._llm_client else 'NOT configured'}"
    )

    # No client configured (e.g. local dev): degrade to the static fallback.
    if not self._llm_client:
        logger.warning(
            f"[AC-AISVC-02] No LLM client configured, using fallback. "
            f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}"
        )
        ctx.llm_response = LLMResponse(
            content=self._fallback_response(ctx),
            model="fallback",
            usage={},
            finish_reason="fallback",
        )
        ctx.diagnostics["llm_mode"] = "fallback"
        ctx.diagnostics["fallback_reason"] = "no_llm_client"
        return

    try:
        ctx.llm_response = await self._llm_client.generate(
            messages=messages,
            config=self._llm_config,
        )
        # Record provenance so diagnostics can distinguish live vs fallback.
        ctx.diagnostics["llm_mode"] = "live"
        ctx.diagnostics["llm_model"] = ctx.llm_response.model
        ctx.diagnostics["llm_usage"] = ctx.llm_response.usage

        logger.info(
            f"[AC-AISVC-02] LLM response generated: "
            f"model={ctx.llm_response.model}, "
            f"tokens={ctx.llm_response.usage}, "
            f"content_preview={ctx.llm_response.content[:100]}..."
        )

    except Exception as e:
        # Any LLM failure is absorbed here; the pipeline continues with a
        # fallback answer, marked finish_reason="error" so callers can tell.
        logger.error(
            f"[AC-AISVC-02] LLM generation failed: {e}, "
            f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}",
            exc_info=True
        )
        ctx.llm_response = LLMResponse(
            content=self._fallback_response(ctx),
            model="fallback",
            usage={},
            finish_reason="error",
            metadata={"error": str(e)},
        )
        ctx.diagnostics["llm_error"] = str(e)
        ctx.diagnostics["llm_mode"] = "fallback"
        ctx.diagnostics["fallback_reason"] = f"llm_error: {str(e)}"
|
||||
|
||||
def _build_llm_messages(self, ctx: GenerationContext) -> list[dict[str, str]]:
    """
    [AC-AISVC-02] Build the message list sent to the LLM.

    Order: system prompt (optionally augmented with retrieval evidence),
    then merged conversation history, then the current user message.

    Args:
        ctx: Generation context holding the retrieval result, merged
            history, and the current user message.

    Returns:
        Messages as a list of {"role": ..., "content": ...} dicts.
    """
    messages: list[dict[str, str]] = []

    system_content = self._config.system_prompt

    # Inject knowledge-base evidence into the system prompt only when
    # retrieval actually produced hits.
    if ctx.retrieval_result and not ctx.retrieval_result.is_empty:
        evidence_text = self._format_evidence(ctx.retrieval_result)
        system_content += f"\n\n知识库参考内容:\n{evidence_text}"

    messages.append({"role": "system", "content": system_content})

    # Guard both the merged context and its messages consistently: the
    # original log line only checked ctx.merged_context and could raise
    # TypeError (len(None)) when messages was None.
    history_count = 0
    if ctx.merged_context and ctx.merged_context.messages:
        messages.extend(ctx.merged_context.messages)
        history_count = len(ctx.merged_context.messages)

    messages.append({"role": "user", "content": ctx.current_message})

    logger.info(
        f"[AC-AISVC-02] Built {len(messages)} messages for LLM: "
        f"system_len={len(system_content)}, history_count={history_count}"
    )
    logger.debug(f"[AC-AISVC-02] System prompt preview: {system_content[:500]}...")

    # NOTE(review): the full prompt, including user content, is logged at
    # INFO for [AC-AISVC-02] traceability — confirm this is acceptable for
    # production log retention / privacy policy.
    logger.info(f"[AC-AISVC-02] ========== ORCHESTRATOR FULL PROMPT ==========")
    for i, msg in enumerate(messages):
        role = msg.get("role", "unknown")
        content = msg.get("content", "")
        logger.info(f"[AC-AISVC-02] [{i}] role={role}, content_length={len(content)}")
        logger.info(f"[AC-AISVC-02] [{i}] content:\n{content}")
    logger.info(f"[AC-AISVC-02] ==============================================")

    return messages
|
||||
|
||||
def _format_evidence(self, retrieval_result: RetrievalResult) -> str:
    """
    [AC-AISVC-17] Format retrieval hits as evidence text for the prompt.

    Delegates to the shared prompt-formatting helper so every call site
    renders evidence identically: at most 5 hits, each truncated to
    500 characters.
    """
    return format_evidence_for_prompt(retrieval_result.hits, max_results=5, max_content_length=500)
|
||||
|
||||
def _fallback_response(self, ctx: GenerationContext) -> str:
    """
    [AC-AISVC-17] Canned reply used when no LLM answer is available.

    Returns a slightly more informative message when retrieval found
    evidence, otherwise a generic apology pointing to human support.
    """
    has_evidence = bool(ctx.retrieval_result) and not ctx.retrieval_result.is_empty
    if has_evidence:
        return (
            "根据知识库信息,我找到了一些相关内容,"
            "但暂时无法生成完整回复。建议您稍后重试或联系人工客服。"
        )
    return (
        "抱歉,我暂时无法处理您的请求。"
        "请稍后重试或联系人工客服获取帮助。"
    )
|
||||
|
||||
def _calculate_confidence(self, ctx: GenerationContext) -> None:
    """
    [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Compute the confidence score.

    Step 6 of the generation pipeline. When retrieval ran, feed the
    calculator a rough token estimate of the evidence (word count * 2);
    otherwise use the no-retrieval path. The result is stored on the
    context and summarized in diagnostics.
    """
    if not ctx.retrieval_result:
        ctx.confidence_result = self._confidence_calculator.calculate_confidence_no_retrieval()
    else:
        # Rough token estimate: whitespace-split word count, doubled.
        evidence_tokens = (
            0
            if ctx.retrieval_result.is_empty
            else sum(2 * len(hit.text.split()) for hit in ctx.retrieval_result.hits)
        )
        ctx.confidence_result = self._confidence_calculator.calculate_confidence(
            retrieval_result=ctx.retrieval_result,
            evidence_tokens=evidence_tokens,
        )

    ctx.diagnostics["confidence"] = {
        "score": ctx.confidence_result.confidence,
        "should_transfer": ctx.confidence_result.should_transfer,
        "is_insufficient": ctx.confidence_result.is_retrieval_insufficient,
    }

    logger.info(
        f"[AC-AISVC-17, AC-AISVC-18] Confidence calculated: "
        f"{ctx.confidence_result.confidence:.3f}, "
        f"should_transfer={ctx.confidence_result.should_transfer}"
    )
|
||||
|
||||
async def _save_messages(self, ctx: GenerationContext) -> None:
    """
    [AC-AISVC-13] Persist the user turn and assistant reply to Memory.

    Step 7 of the generation pipeline. Ensures the session exists, then
    appends the current user message plus the assistant reply (when one
    was produced). Persistence failures are logged and recorded in
    diagnostics but never abort the pipeline.
    """
    if not self._memory_service:
        logger.info("[AC-AISVC-13] No memory service configured, skipping save")
        return

    try:
        # Make sure the session row exists before appending turns.
        await self._memory_service.get_or_create_session(
            tenant_id=ctx.tenant_id,
            session_id=ctx.session_id,
            channel_type=ctx.channel_type,
            metadata=ctx.request_metadata,
        )

        pending = [{"role": "user", "content": ctx.current_message}]
        if ctx.llm_response:
            pending.append({
                "role": "assistant",
                "content": ctx.llm_response.content,
            })

        await self._memory_service.append_messages(
            tenant_id=ctx.tenant_id,
            session_id=ctx.session_id,
            messages=pending,
        )

        ctx.diagnostics["messages_saved"] = len(pending)

        logger.info(
            f"[AC-AISVC-13] Saved {len(pending)} messages "
            f"for tenant={ctx.tenant_id}, session={ctx.session_id}"
        )

    except Exception as e:
        # Best-effort persistence: record the failure but keep serving.
        logger.warning(f"[AC-AISVC-13] Failed to save messages: {e}")
        ctx.diagnostics["save_error"] = str(e)
|
||||
|
||||
def _build_response(self, ctx: GenerationContext) -> ChatResponse:
    """
    [AC-AISVC-02] Assemble the final ChatResponse from the context.

    Step 8 of the generation pipeline. Falls back to the canned reply and
    conservative defaults (confidence 0.5, transfer to human) when the
    LLM or confidence steps did not produce results.
    """
    if ctx.llm_response:
        reply = ctx.llm_response.content
    else:
        reply = self._fallback_response(ctx)

    conf_result = ctx.confidence_result
    if conf_result:
        confidence = conf_result.confidence
        should_transfer = conf_result.should_transfer
        transfer_reason = conf_result.transfer_reason
    else:
        # No confidence computed: be conservative and hand off to a human.
        confidence = 0.5
        should_transfer = True
        transfer_reason = None

    return ChatResponse(
        reply=reply,
        confidence=confidence,
        should_transfer=should_transfer,
        transfer_reason=transfer_reason,
        metadata={
            "session_id": ctx.session_id,
            "channel_type": ctx.channel_type,
            "diagnostics": ctx.diagnostics,
        },
    )
|
||||
|
||||
async def generate_stream(
    self,
    tenant_id: str,
    request: ChatRequest,
) -> AsyncGenerator[ServerSentEvent, None]:
    """
    Generate a streaming response.
    [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence.

    SSE Event Sequence (per design.md Section 6.2):
    1. message events (multiple) - each with incremental delta
    2. final event (exactly 1) - with complete response
    3. connection close

    OR on error:
    1. message events (0 or more)
    2. error event (exactly 1)
    3. connection close
    """
    logger.info(
        f"[AC-AISVC-06] Starting streaming generation for tenant={tenant_id}, "
        f"session={request.session_id}"
    )

    # State machine enforces the legal event ordering (message* -> final|error).
    state_machine = SSEStateMachine()
    await state_machine.transition_to_streaming()

    ctx = GenerationContext(
        tenant_id=tenant_id,
        session_id=request.session_id,
        current_message=request.current_message,
        channel_type=request.channel_type.value,
        request_metadata=request.metadata,
    )

    try:
        # Pipeline front half: load history, merge contexts, optional RAG.
        await self._load_local_history(ctx)
        await self._merge_context(ctx, request.history)

        if self._config.enable_rag and self._retriever:
            await self._retrieve_evidence(ctx)

        # Accumulate deltas so the final event can carry the complete reply.
        full_reply = ""

        if self._llm_client:
            async for event in self._stream_from_llm(ctx, state_machine):
                if event.event == "message":
                    full_reply += self._extract_delta_from_event(event)
                yield event
        else:
            # No LLM configured: emit the mock stream (demo/testing path).
            async for event in self._stream_mock_response(ctx, state_machine):
                if event.event == "message":
                    full_reply += self._extract_delta_from_event(event)
                yield event

        # Synthesize an LLMResponse from the accumulated text so the
        # confidence and persistence steps see a uniform shape.
        if ctx.llm_response is None:
            ctx.llm_response = LLMResponse(
                content=full_reply,
                model="streaming",
                usage={},
                finish_reason="stop",
            )

        self._calculate_confidence(ctx)

        await self._save_messages(ctx)

        # transition_to_final() returns False when the stream already
        # terminated, in which case no final event may be emitted.
        if await state_machine.transition_to_final():
            yield create_final_event(
                reply=full_reply,
                confidence=ctx.confidence_result.confidence if ctx.confidence_result else 0.5,
                should_transfer=ctx.confidence_result.should_transfer if ctx.confidence_result else False,
                transfer_reason=ctx.confidence_result.transfer_reason if ctx.confidence_result else None,
            )

    except Exception as e:
        logger.error(f"[AC-AISVC-09] Error during streaming: {e}")
        # Emit exactly one error event if the state machine still allows it.
        if await state_machine.transition_to_error():
            yield create_error_event(
                code="GENERATION_ERROR",
                message=str(e),
            )
    finally:
        # Always close the state machine, regardless of outcome.
        await state_machine.close()
|
||||
|
||||
async def _stream_from_llm(
    self,
    ctx: GenerationContext,
    state_machine: SSEStateMachine,
) -> AsyncGenerator[ServerSentEvent, None]:
    """
    [AC-AISVC-07] Relay LLM stream chunks as SSE message events.

    Stops early when the state machine no longer permits message events,
    or as soon as the LLM reports a finish reason.
    """
    prompt_messages = self._build_llm_messages(ctx)

    async for chunk in self._llm_client.stream_generate(prompt_messages, self._llm_config):
        # Respect the SSE state machine: no messages after final/error.
        if not state_machine.can_send_message():
            break

        if chunk.delta:
            logger.debug(f"[AC-AISVC-07] Yielding message event with delta: {chunk.delta[:50]}...")
            yield create_message_event(delta=chunk.delta)

        if chunk.finish_reason:
            logger.info(f"[AC-AISVC-07] LLM stream finished with reason: {chunk.finish_reason}")
            break
|
||||
|
||||
async def _stream_mock_response(
    self,
    ctx: GenerationContext,
    state_machine: SSEStateMachine,
) -> AsyncGenerator[ServerSentEvent, None]:
    """
    [AC-AISVC-07] Simulated streaming reply used when no LLM is configured.

    Emits a fixed sequence of deltas echoing the user's message, pausing
    briefly between chunks to mimic incremental LLM output.
    """
    import asyncio

    parts = ["收到", "您的", "消息:", f" {ctx.current_message}"]

    for delta in parts:
        # Stop if the stream has already been finalized or errored.
        if not state_machine.can_send_message():
            break

        logger.debug(f"[AC-AISVC-07] Yielding mock message event with delta: {delta}")
        yield create_message_event(delta=delta)
        # Small pause to simulate token-by-token generation.
        await asyncio.sleep(0.05)
|
||||
|
||||
def _extract_delta_from_event(self, event: ServerSentEvent) -> str:
    """Pull the incremental "delta" text out of a message event's JSON payload.

    Returns an empty string when the payload is absent or not valid JSON.
    """
    import json

    raw = event.data
    if not raw:
        return ""
    try:
        payload = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return ""
    return payload.get("delta", "")
|
||||
|
||||
|
||||
# Process-wide singleton; swap via set_orchestrator_service() in tests.
_orchestrator_service: OrchestratorService | None = None


def get_orchestrator_service() -> OrchestratorService:
    """Get or create orchestrator service instance (lazy singleton)."""
    global _orchestrator_service
    if _orchestrator_service is None:
        _orchestrator_service = OrchestratorService()
    return _orchestrator_service


def set_orchestrator_service(service: OrchestratorService) -> None:
    """Set orchestrator service instance for testing."""
    global _orchestrator_service
    _orchestrator_service = service
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
"""
|
||||
Retrieval module for AI Service.
|
||||
[AC-AISVC-16] Provides retriever implementations with plugin architecture.
|
||||
RAG Optimization: Two-stage retrieval, RRF hybrid ranking, metadata filtering.
|
||||
"""
|
||||
|
||||
from app.services.retrieval.base import (
|
||||
BaseRetriever,
|
||||
RetrievalContext,
|
||||
RetrievalHit,
|
||||
RetrievalResult,
|
||||
)
|
||||
from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever
|
||||
from app.services.retrieval.metadata import (
|
||||
ChunkMetadata,
|
||||
ChunkMetadataModel,
|
||||
MetadataFilter,
|
||||
KnowledgeChunk,
|
||||
RetrieveRequest,
|
||||
RetrieveResult,
|
||||
RetrievalStrategy,
|
||||
)
|
||||
from app.services.retrieval.optimized_retriever import (
|
||||
OptimizedRetriever,
|
||||
get_optimized_retriever,
|
||||
TwoStageResult,
|
||||
RRFCombiner,
|
||||
)
|
||||
from app.services.retrieval.indexer import (
|
||||
KnowledgeIndexer,
|
||||
get_knowledge_indexer,
|
||||
IndexingProgress,
|
||||
IndexingResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseRetriever",
|
||||
"RetrievalContext",
|
||||
"RetrievalHit",
|
||||
"RetrievalResult",
|
||||
"VectorRetriever",
|
||||
"get_vector_retriever",
|
||||
"ChunkMetadata",
|
||||
"MetadataFilter",
|
||||
"KnowledgeChunk",
|
||||
"RetrieveRequest",
|
||||
"RetrieveResult",
|
||||
"RetrievalStrategy",
|
||||
"OptimizedRetriever",
|
||||
"get_optimized_retriever",
|
||||
"TwoStageResult",
|
||||
"RRFCombiner",
|
||||
"KnowledgeIndexer",
|
||||
"get_knowledge_indexer",
|
||||
"IndexingProgress",
|
||||
"IndexingResult",
|
||||
]
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
"""
|
||||
Retrieval layer for AI Service.
|
||||
[AC-AISVC-16] Abstract base class for retrievers with plugin point support.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RetrievalContext:
    """
    [AC-AISVC-16] Context for retrieval operations.
    Contains all necessary information for retrieval plugins.
    """

    # Tenant whose knowledge collection should be searched.
    tenant_id: str
    # Natural-language query text.
    query: str
    # Optional conversational scope for retrievers that use it.
    session_id: str | None = None
    channel_type: str | None = None
    # Free-form extra data passed through to retriever plugins.
    metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalHit:
|
||||
"""
|
||||
[AC-AISVC-16] Single retrieval result hit.
|
||||
Unified structure for all retriever types.
|
||||
"""
|
||||
|
||||
text: str
|
||||
score: float
|
||||
source: str
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalResult:
|
||||
"""
|
||||
[AC-AISVC-16] Result from retrieval operation.
|
||||
Contains hits and optional diagnostics.
|
||||
"""
|
||||
|
||||
hits: list[RetrievalHit] = field(default_factory=list)
|
||||
diagnostics: dict[str, Any] | None = None
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if no hits were found."""
|
||||
return len(self.hits) == 0
|
||||
|
||||
@property
|
||||
def max_score(self) -> float:
|
||||
"""Get the maximum score among hits."""
|
||||
if not self.hits:
|
||||
return 0.0
|
||||
return max(hit.score for hit in self.hits)
|
||||
|
||||
@property
|
||||
def hit_count(self) -> int:
|
||||
"""Get the number of hits."""
|
||||
return len(self.hits)
|
||||
|
||||
|
||||
class BaseRetriever(ABC):
    """
    [AC-AISVC-16] Abstract base class for retrievers.
    Provides plugin point for different retrieval strategies (Vector, Graph, Hybrid).

    Concrete implementations must override both coroutine methods.
    """

    @abstractmethod
    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """
        [AC-AISVC-16] Retrieve relevant documents for the given context.

        Args:
            ctx: Retrieval context containing tenant_id, query, and optional metadata.

        Returns:
            RetrievalResult with hits and optional diagnostics.
        """
        pass

    @abstractmethod
    async def health_check(self) -> bool:
        """
        Check if the retriever is healthy and ready to serve requests.

        Returns:
            True if healthy, False otherwise.
        """
        pass
|
||||
|
|
@ -0,0 +1,339 @@
|
|||
"""
|
||||
Knowledge base indexing service with optimized embedding.
|
||||
Reference: rag-optimization/spec.md Section 5.1
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
||||
from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult
|
||||
from app.services.retrieval.metadata import ChunkMetadata, KnowledgeChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@dataclass
class IndexingProgress:
    """Mutable progress snapshot for a running indexing job."""

    total_chunks: int = 0
    processed_chunks: int = 0
    failed_chunks: int = 0
    # Document currently being embedded/indexed.
    current_document: str = ""
    started_at: datetime = field(default_factory=datetime.utcnow)

    @property
    def progress_percent(self) -> int:
        """Completion percentage, truncated to an int (0 when nothing queued)."""
        if not self.total_chunks:
            return 0
        return int((self.processed_chunks / self.total_chunks) * 100)

    @property
    def elapsed_seconds(self) -> float:
        """Wall-clock seconds since the job started."""
        return (datetime.utcnow() - self.started_at).total_seconds()
|
||||
|
||||
|
||||
@dataclass
class IndexingResult:
    """Result of an indexing operation."""
    # True when the document was indexed without a fatal error
    # (per-chunk embedding failures do not flip this to False).
    success: bool
    total_chunks: int
    indexed_chunks: int
    failed_chunks: int
    elapsed_seconds: float
    # Populated only when success is False.
    error_message: str | None = None
|
||||
|
||||
|
||||
class KnowledgeIndexer:
    """
    Knowledge base indexer with optimized embedding.

    Features:
    - Task prefixes (search_document:) for document embedding
    - Multi-dimensional vectors (256/512/768)
    - Metadata support
    - Batch processing
    """

    def __init__(
        self,
        qdrant_client: QdrantClient | None = None,
        embedding_provider: NomicEmbeddingProvider | None = None,
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        batch_size: int = 10,
    ):
        """
        Args:
            qdrant_client: Injected Qdrant client; resolved lazily when None.
            embedding_provider: Injected embedding provider; built lazily when None.
            chunk_size: Target chunk size in characters.
            chunk_overlap: Overlap between consecutive chunks.
            batch_size: Embedding batch size.

        NOTE(review): chunk_size, chunk_overlap, and batch_size are stored
        but not used by the current line-based chunking or indexing paths.
        """
        self._qdrant_client = qdrant_client
        self._embedding_provider = embedding_provider
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._batch_size = batch_size
        self._progress: IndexingProgress | None = None

    async def _get_client(self) -> QdrantClient:
        """Lazily resolve and cache the Qdrant client."""
        if self._qdrant_client is None:
            self._qdrant_client = await get_qdrant_client()
        return self._qdrant_client

    async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
        """Lazily build and cache the embedding provider from settings."""
        if self._embedding_provider is None:
            self._embedding_provider = NomicEmbeddingProvider(
                base_url=settings.ollama_base_url,
                model=settings.ollama_embedding_model,
                dimension=settings.qdrant_vector_size,
            )
        return self._embedding_provider

    def chunk_text(self, text: str, metadata: ChunkMetadata | None = None) -> list[KnowledgeChunk]:
        """
        Split text into chunks for indexing.
        Each line becomes a separate chunk for better retrieval granularity.

        Args:
            text: Full text to chunk
            metadata: Metadata to attach to each chunk

        Returns:
            List of KnowledgeChunk objects
        """
        # Deduplicated: delegate to chunk_text_by_lines with the historical
        # defaults (min length 10, no merging) — the previous inline copy was
        # byte-for-byte identical logic.
        return self.chunk_text_by_lines(text, metadata, min_line_length=10, merge_short_lines=False)

    def chunk_text_by_lines(
        self,
        text: str,
        metadata: ChunkMetadata | None = None,
        min_line_length: int = 10,
        merge_short_lines: bool = False,
    ) -> list[KnowledgeChunk]:
        """
        Split text by lines, each line is a separate chunk.

        Args:
            text: Full text to chunk
            metadata: Metadata to attach to each chunk
            min_line_length: Minimum line length to be indexed
            merge_short_lines: Whether to merge consecutive short lines

        Returns:
            List of KnowledgeChunk objects
        """
        chunks = []
        doc_id = str(uuid.uuid4())

        lines = text.split('\n')

        if merge_short_lines:
            # Accumulate consecutive non-empty lines until the buffer is
            # comfortably long (2x min length); blank lines flush the buffer.
            merged_lines = []
            current_line = ""

            for line in lines:
                line = line.strip()
                if not line:
                    if current_line:
                        merged_lines.append(current_line)
                        current_line = ""
                    continue

                if current_line:
                    current_line += " " + line
                else:
                    current_line = line

                if len(current_line) >= min_line_length * 2:
                    merged_lines.append(current_line)
                    current_line = ""

            if current_line:
                merged_lines.append(current_line)

            lines = merged_lines

        for i, line in enumerate(lines):
            line = line.strip()

            # Very short lines carry little retrievable signal; skip them.
            if len(line) < min_line_length:
                continue

            chunk = KnowledgeChunk(
                chunk_id=f"{doc_id}_{i}",
                document_id=doc_id,
                content=line,
                metadata=metadata or ChunkMetadata(),
            )
            chunks.append(chunk)

        return chunks

    async def index_document(
        self,
        tenant_id: str,
        document_id: str,
        text: str,
        metadata: ChunkMetadata | None = None,
    ) -> IndexingResult:
        """
        Index a single document with optimized embedding.

        Chunks the text, embeds each chunk (full plus 256/512-dim vectors),
        and upserts all successful points into the tenant's multi-vector
        collection. Per-chunk embedding failures are counted but do not
        abort the job; any other failure yields a success=False result.

        Args:
            tenant_id: Tenant identifier
            document_id: Document identifier
            text: Document text content
            metadata: Optional metadata for the document

        Returns:
            IndexingResult with status and statistics
        """
        start_time = datetime.utcnow()

        try:
            client = await self._get_client()
            provider = await self._get_embedding_provider()

            await client.ensure_collection_exists(tenant_id, use_multi_vector=True)

            chunks = self.chunk_text(text, metadata)

            self._progress = IndexingProgress(
                total_chunks=len(chunks),
                current_document=document_id,
            )

            points = []
            for i, chunk in enumerate(chunks):
                try:
                    embedding_result = await provider.embed_document(chunk.content)

                    chunk.embedding_full = embedding_result.embedding_full
                    chunk.embedding_256 = embedding_result.embedding_256
                    chunk.embedding_512 = embedding_result.embedding_512

                    point = {
                        "id": str(uuid.uuid4()),  # Generate a valid UUID for Qdrant
                        "vector": {
                            "full": chunk.embedding_full,
                            "dim_256": chunk.embedding_256,
                            "dim_512": chunk.embedding_512,
                        },
                        "payload": {
                            "chunk_id": chunk.chunk_id,
                            "document_id": document_id,
                            "text": chunk.content,
                            "metadata": chunk.metadata.to_dict(),
                            "created_at": chunk.created_at.isoformat(),
                        }
                    }
                    points.append(point)

                    self._progress.processed_chunks += 1

                    logger.debug(
                        f"[RAG-OPT] Indexed chunk {i+1}/{len(chunks)} for doc={document_id}"
                    )

                except Exception as e:
                    # Best-effort: count the failure and continue with the rest.
                    logger.warning(f"[RAG-OPT] Failed to index chunk {i}: {e}")
                    self._progress.failed_chunks += 1

            if points:
                await client.upsert_multi_vector(tenant_id, points)

            elapsed = (datetime.utcnow() - start_time).total_seconds()

            logger.info(
                f"[RAG-OPT] Indexed document {document_id}: "
                f"{len(points)} chunks in {elapsed:.2f}s"
            )

            return IndexingResult(
                success=True,
                total_chunks=len(chunks),
                indexed_chunks=len(points),
                failed_chunks=self._progress.failed_chunks,
                elapsed_seconds=elapsed,
            )

        except Exception as e:
            elapsed = (datetime.utcnow() - start_time).total_seconds()
            logger.error(f"[RAG-OPT] Failed to index document {document_id}: {e}")

            return IndexingResult(
                success=False,
                total_chunks=0,
                indexed_chunks=0,
                failed_chunks=0,
                elapsed_seconds=elapsed,
                error_message=str(e),
            )

    async def index_documents_batch(
        self,
        tenant_id: str,
        documents: list[dict[str, Any]],
    ) -> list[IndexingResult]:
        """
        Index multiple documents in batch.

        Args:
            tenant_id: Tenant identifier
            documents: List of documents with format:
                {
                    "document_id": str,
                    "text": str,
                    "metadata": ChunkMetadata (optional)
                }

        Returns:
            List of IndexingResult for each document
        """
        results = []

        # Sequential on purpose: self._progress tracks one document at a time.
        for doc in documents:
            result = await self.index_document(
                tenant_id=tenant_id,
                document_id=doc["document_id"],
                text=doc["text"],
                metadata=doc.get("metadata"),
            )
            results.append(result)

        return results

    def get_progress(self) -> IndexingProgress | None:
        """Return progress of the most recent indexing job, if any."""
        return self._progress
|
||||
|
||||
|
||||
# Process-wide lazy singleton indexer.
_knowledge_indexer: KnowledgeIndexer | None = None


def get_knowledge_indexer() -> KnowledgeIndexer:
    """Get or create KnowledgeIndexer instance (lazy module-level singleton)."""
    global _knowledge_indexer
    if _knowledge_indexer is None:
        _knowledge_indexer = KnowledgeIndexer()
    return _knowledge_indexer
|
||||
|
|
@ -0,0 +1,210 @@
|
|||
"""
|
||||
Metadata models for RAG optimization.
|
||||
Implements structured metadata for knowledge chunks.
|
||||
Reference: rag-optimization/spec.md Section 3.2
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RetrievalStrategy(str, Enum):
    """Retrieval strategy options.

    str-mixin so values serialize directly in JSON/API payloads.
    """
    VECTOR_ONLY = "vector"      # dense vector search only
    BM25_ONLY = "bm25"          # lexical BM25 only
    HYBRID = "hybrid"           # vector + BM25 fused (see RRFCombiner)
    TWO_STAGE = "two_stage"     # coarse recall then rerank (OptimizedRetriever)
|
||||
|
||||
|
||||
class ChunkMetadataModel(BaseModel):
    """Pydantic model for API serialization.

    Mirrors the ChunkMetadata dataclass, but carries dates as ISO strings.
    Mutable defaults ([]) are safe on pydantic models: pydantic copies the
    default per instance, unlike plain class attributes.
    """
    category: str = ""
    subcategory: str = ""
    target_audience: list[str] = []
    source_doc: str = ""
    source_url: str = ""
    department: str = ""
    # ISO date strings; None = no validity bound.
    valid_from: str | None = None
    valid_until: str | None = None
    # NOTE(review): MetadataFilter applies priority <= min_priority, which
    # implies lower numbers rank higher — confirm intended ordering.
    priority: int = 5
    keywords: list[str] = []
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkMetadata:
|
||||
"""
|
||||
Metadata for knowledge chunks.
|
||||
Reference: rag-optimization/spec.md Section 3.2.2
|
||||
"""
|
||||
category: str = ""
|
||||
subcategory: str = ""
|
||||
target_audience: list[str] = field(default_factory=list)
|
||||
source_doc: str = ""
|
||||
source_url: str = ""
|
||||
department: str = ""
|
||||
valid_from: date | None = None
|
||||
valid_until: date | None = None
|
||||
priority: int = 5
|
||||
keywords: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for storage."""
|
||||
return {
|
||||
"category": self.category,
|
||||
"subcategory": self.subcategory,
|
||||
"target_audience": self.target_audience,
|
||||
"source_doc": self.source_doc,
|
||||
"source_url": self.source_url,
|
||||
"department": self.department,
|
||||
"valid_from": self.valid_from.isoformat() if self.valid_from else None,
|
||||
"valid_until": self.valid_until.isoformat() if self.valid_until else None,
|
||||
"priority": self.priority,
|
||||
"keywords": self.keywords,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ChunkMetadata":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
category=data.get("category", ""),
|
||||
subcategory=data.get("subcategory", ""),
|
||||
target_audience=data.get("target_audience", []),
|
||||
source_doc=data.get("source_doc", ""),
|
||||
source_url=data.get("source_url", ""),
|
||||
department=data.get("department", ""),
|
||||
valid_from=date.fromisoformat(data["valid_from"]) if data.get("valid_from") else None,
|
||||
valid_until=date.fromisoformat(data["valid_until"]) if data.get("valid_until") else None,
|
||||
priority=data.get("priority", 5),
|
||||
keywords=data.get("keywords", []),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataFilter:
|
||||
"""
|
||||
Filter conditions for metadata-based retrieval.
|
||||
Reference: rag-optimization/spec.md Section 4.1
|
||||
"""
|
||||
categories: list[str] | None = None
|
||||
target_audiences: list[str] | None = None
|
||||
departments: list[str] | None = None
|
||||
valid_only: bool = True
|
||||
min_priority: int | None = None
|
||||
keywords: list[str] | None = None
|
||||
|
||||
def to_qdrant_filter(self) -> dict[str, Any] | None:
|
||||
"""Convert to Qdrant filter format."""
|
||||
conditions = []
|
||||
|
||||
if self.categories:
|
||||
conditions.append({
|
||||
"key": "metadata.category",
|
||||
"match": {"any": self.categories}
|
||||
})
|
||||
|
||||
if self.departments:
|
||||
conditions.append({
|
||||
"key": "metadata.department",
|
||||
"match": {"any": self.departments}
|
||||
})
|
||||
|
||||
if self.target_audiences:
|
||||
conditions.append({
|
||||
"key": "metadata.target_audience",
|
||||
"match": {"any": self.target_audiences}
|
||||
})
|
||||
|
||||
if self.valid_only:
|
||||
today = date.today().isoformat()
|
||||
conditions.append({
|
||||
"should": [
|
||||
{"key": "metadata.valid_until", "match": {"value": None}},
|
||||
{"key": "metadata.valid_until", "range": {"gte": today}}
|
||||
]
|
||||
})
|
||||
|
||||
if self.min_priority is not None:
|
||||
conditions.append({
|
||||
"key": "metadata.priority",
|
||||
"range": {"lte": self.min_priority}
|
||||
})
|
||||
|
||||
if not conditions:
|
||||
return None
|
||||
|
||||
if len(conditions) == 1:
|
||||
return {"must": conditions}
|
||||
|
||||
return {"must": conditions}
|
||||
|
||||
|
||||
@dataclass
class KnowledgeChunk:
    """
    Knowledge chunk with multi-dimensional embeddings.
    Reference: rag-optimization/spec.md Section 3.2.1
    """
    chunk_id: str
    document_id: str
    content: str
    # Full-dimension embedding plus 256/512-dim variants used for coarse
    # first-stage search (dims per spec section 3.2.1).
    embedding_full: list[float] = field(default_factory=list)
    embedding_256: list[float] = field(default_factory=list)
    embedding_512: list[float] = field(default_factory=list)
    metadata: ChunkMetadata = field(default_factory=ChunkMetadata)
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    def to_qdrant_point(self, point_id: int | str) -> dict[str, Any]:
        """Convert to Qdrant point format (named vectors + payload)."""
        return {
            "id": point_id,
            "vector": {
                "full": self.embedding_full,
                "dim_256": self.embedding_256,
                "dim_512": self.embedding_512,
            },
            "payload": {
                "chunk_id": self.chunk_id,
                "document_id": self.document_id,
                "text": self.content,
                "metadata": self.metadata.to_dict(),
                "created_at": self.created_at.isoformat(),
                "updated_at": self.updated_at.isoformat(),
            }
        }
|
||||
|
||||
|
||||
@dataclass
class RetrieveRequest:
    """
    Request for knowledge retrieval.
    Reference: rag-optimization/spec.md Section 4.1
    """
    query: str
    # Query with the embedding task prefix; auto-filled in __post_init__.
    query_with_prefix: str = ""
    top_k: int = 10
    filters: MetadataFilter | None = None
    strategy: RetrievalStrategy = RetrievalStrategy.HYBRID

    def __post_init__(self):
        # Add the task prefix for query embeddings when the caller did not
        # supply a pre-prefixed query.
        # NOTE(review): indexing uses "search_document:"; nomic convention is
        # often "search_query: " with a trailing space — confirm the embedding
        # provider expects the no-space form used here.
        if not self.query_with_prefix:
            self.query_with_prefix = f"search_query:{self.query}"
|
||||
|
||||
|
||||
@dataclass
class RetrieveResult:
    """
    Result from knowledge retrieval.
    Reference: rag-optimization/spec.md Section 4.1
    """
    chunk_id: str
    content: str
    # Final (possibly fused) relevance score.
    score: float
    # Per-channel scores before fusion.
    vector_score: float = 0.0
    bm25_score: float = 0.0
    metadata: ChunkMetadata = field(default_factory=ChunkMetadata)
    # Position in the ranked result list (0 appears to mean unset) —
    # TODO confirm whether ranks are 0- or 1-based.
    rank: int = 0
|
||||
|
|
@ -0,0 +1,509 @@
|
|||
"""
|
||||
Optimized RAG retriever with two-stage retrieval and RRF hybrid ranking.
|
||||
Reference: rag-optimization/spec.md Section 2.2, 2.4, 2.5
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
||||
from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult
|
||||
from app.services.retrieval.base import (
|
||||
BaseRetriever,
|
||||
RetrievalContext,
|
||||
RetrievalHit,
|
||||
RetrievalResult,
|
||||
)
|
||||
from app.services.retrieval.metadata import (
|
||||
ChunkMetadata,
|
||||
MetadataFilter,
|
||||
RetrieveResult,
|
||||
RetrievalStrategy,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@dataclass
class TwoStageResult:
    """Result from two-stage retrieval (coarse low-dim pass, then full-dim rerank)."""
    # Stage-1 candidate hits (raw search results before reranking).
    candidates: list[dict[str, Any]]
    # Stage-2 reranked and truncated results.
    final_results: list[RetrieveResult]
    # Wall-clock latency of stage 1, in milliseconds.
    stage1_latency_ms: float = 0.0
    # Wall-clock latency of stage 2, in milliseconds.
    stage2_latency_ms: float = 0.0
|
||||
|
||||
|
||||
class RRFCombiner:
    """
    Reciprocal Rank Fusion for combining multiple retrieval results.
    Reference: rag-optimization/spec.md Section 2.5

    Formula: score = Σ(weight_i / (k + rank_i + 1)), ranks 0-based.
    Default k = 60
    """

    def __init__(self, k: int = 60):
        # RRF smoothing constant; larger k flattens the rank contribution.
        self._k = k

    def combine(
        self,
        vector_results: list[dict[str, Any]],
        bm25_results: list[dict[str, Any]],
        vector_weight: float = 0.7,
        bm25_weight: float = 0.3,
    ) -> list[dict[str, Any]]:
        """
        Combine vector and BM25 results using RRF.

        Args:
            vector_results: Results from vector search
            bm25_results: Results from BM25 search
            vector_weight: Weight for vector results
            bm25_weight: Weight for BM25 results

        Returns:
            Combined and sorted results
        """
        fused: dict[str, dict[str, Any]] = {}

        # Pass 1: seed fusion records from the dense-vector ranking.
        for position, hit in enumerate(vector_results):
            key = hit.get("chunk_id") or hit.get("id", str(position))
            record = fused.get(key)
            if record is None:
                record = {
                    "score": 0.0,
                    "vector_score": hit.get("score", 0.0),
                    "bm25_score": 0.0,
                    "vector_rank": position,
                    "bm25_rank": -1,
                    "payload": hit.get("payload", {}),
                    "id": key,
                }
                fused[key] = record
            record["score"] += vector_weight / (self._k + position + 1)

        # Pass 2: fold in the lexical ranking, updating records seen above.
        for position, hit in enumerate(bm25_results):
            key = hit.get("chunk_id") or hit.get("id", str(position))
            record = fused.get(key)
            if record is None:
                record = {
                    "score": 0.0,
                    "vector_score": 0.0,
                    "bm25_score": hit.get("score", 0.0),
                    "vector_rank": -1,
                    "bm25_rank": position,
                    "payload": hit.get("payload", {}),
                    "id": key,
                }
                fused[key] = record
            else:
                record["bm25_score"] = hit.get("score", 0.0)
                record["bm25_rank"] = position
            record["score"] += bm25_weight / (self._k + position + 1)

        # Highest fused score first.
        return sorted(fused.values(), key=lambda rec: rec["score"], reverse=True)
|
||||
|
||||
|
||||
class OptimizedRetriever(BaseRetriever):
    """
    Optimized retriever with:
    - Task prefixes (search_document/search_query)
    - Two-stage retrieval (256 dim -> 768 dim)
    - RRF hybrid ranking (vector + BM25)
    - Metadata filtering

    All configuration falls back to application settings when not supplied.
    retrieve() never raises: failures are logged and surfaced through an
    empty RetrievalResult with the error in diagnostics.

    Reference: rag-optimization/spec.md Section 2, 3, 4
    """

    def __init__(
        self,
        qdrant_client: QdrantClient | None = None,
        embedding_provider: NomicEmbeddingProvider | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
        min_hits: int | None = None,
        two_stage_enabled: bool | None = None,
        two_stage_expand_factor: int | None = None,
        hybrid_enabled: bool | None = None,
        rrf_k: int | None = None,
    ):
        # NOTE(review): the `or` fallbacks treat falsy values (0, 0.0) as
        # unset, so e.g. score_threshold=0.0 silently becomes the settings
        # default; only the two boolean flags use an explicit None check.
        self._qdrant_client = qdrant_client
        self._embedding_provider = embedding_provider
        self._top_k = top_k or settings.rag_top_k
        self._score_threshold = score_threshold or settings.rag_score_threshold
        self._min_hits = min_hits or settings.rag_min_hits
        self._two_stage_enabled = two_stage_enabled if two_stage_enabled is not None else settings.rag_two_stage_enabled
        self._two_stage_expand_factor = two_stage_expand_factor or settings.rag_two_stage_expand_factor
        self._hybrid_enabled = hybrid_enabled if hybrid_enabled is not None else settings.rag_hybrid_enabled
        self._rrf_k = rrf_k or settings.rag_rrf_k
        self._rrf_combiner = RRFCombiner(k=self._rrf_k)

    async def _get_client(self) -> QdrantClient:
        # Lazily resolve and cache the shared Qdrant client wrapper.
        if self._qdrant_client is None:
            self._qdrant_client = await get_qdrant_client()
        return self._qdrant_client

    async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
        # Lazily resolve the embedding provider. Prefers the app-configured
        # provider when it is already a NomicEmbeddingProvider; otherwise
        # builds one directly from settings so Matryoshka dims are available.
        if self._embedding_provider is None:
            # Imported here to avoid a circular import at module load time
            # — presumably; TODO confirm against app.services.embedding.factory.
            from app.services.embedding.factory import get_embedding_config_manager
            manager = get_embedding_config_manager()
            provider = await manager.get_provider()
            if isinstance(provider, NomicEmbeddingProvider):
                self._embedding_provider = provider
            else:
                self._embedding_provider = NomicEmbeddingProvider(
                    base_url=settings.ollama_base_url,
                    model=settings.ollama_embedding_model,
                    dimension=settings.qdrant_vector_size,
                )
        return self._embedding_provider

    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """
        Retrieve documents using optimized strategy.

        Strategy selection:
        1. If two_stage_enabled: use two-stage retrieval
        2. If hybrid_enabled: use RRF hybrid ranking
        3. Otherwise: simple vector search

        Hits below the configured score threshold are dropped; the result is
        flagged insufficient when fewer than min_hits survive. Any exception
        is caught and reported via diagnostics rather than raised.
        """
        logger.info(
            f"[RAG-OPT] Starting retrieval for tenant={ctx.tenant_id}, "
            f"query={ctx.query[:50]}..., two_stage={self._two_stage_enabled}, hybrid={self._hybrid_enabled}"
        )
        logger.info(
            f"[RAG-OPT] Retrieval config: top_k={self._top_k}, "
            f"score_threshold={self._score_threshold}, min_hits={self._min_hits}"
        )

        try:
            provider = await self._get_embedding_provider()
            logger.info(f"[RAG-OPT] Using embedding provider: {type(provider).__name__}")

            embedding_result = await provider.embed_query(ctx.query)
            logger.info(
                f"[RAG-OPT] Embedding generated: full_dim={len(embedding_result.embedding_full)}, "
                f"dim_256={'available' if embedding_result.embedding_256 else 'not available'}"
            )

            # Strategy dispatch: two-stage takes priority over hybrid.
            if self._two_stage_enabled:
                logger.info("[RAG-OPT] Using two-stage retrieval strategy")
                results = await self._two_stage_retrieve(
                    ctx.tenant_id,
                    embedding_result,
                    self._top_k,
                )
            elif self._hybrid_enabled:
                logger.info("[RAG-OPT] Using hybrid retrieval strategy")
                results = await self._hybrid_retrieve(
                    ctx.tenant_id,
                    embedding_result,
                    ctx.query,
                    self._top_k,
                )
            else:
                logger.info("[RAG-OPT] Using simple vector retrieval strategy")
                results = await self._vector_retrieve(
                    ctx.tenant_id,
                    embedding_result.embedding_full,
                    self._top_k,
                )

            logger.info(f"[RAG-OPT] Raw results count: {len(results)}")

            # Convert raw {"id","score","payload"} dicts into RetrievalHit,
            # dropping anything below the score threshold.
            retrieval_hits = [
                RetrievalHit(
                    text=result.get("payload", {}).get("text", ""),
                    score=result.get("score", 0.0),
                    source="optimized_rag",
                    metadata=result.get("payload", {}),
                )
                for result in results
                if result.get("score", 0.0) >= self._score_threshold
            ]

            filtered_count = len(results) - len(retrieval_hits)
            if filtered_count > 0:
                logger.info(
                    f"[RAG-OPT] Filtered out {filtered_count} results below threshold {self._score_threshold}"
                )

            # Callers use this flag to decide whether the context is usable.
            is_insufficient = len(retrieval_hits) < self._min_hits

            diagnostics = {
                "query_length": len(ctx.query),
                "top_k": self._top_k,
                "score_threshold": self._score_threshold,
                "two_stage_enabled": self._two_stage_enabled,
                "hybrid_enabled": self._hybrid_enabled,
                "total_hits": len(retrieval_hits),
                "is_insufficient": is_insufficient,
                "max_score": max((h.score for h in retrieval_hits), default=0.0),
                "raw_results_count": len(results),
                "filtered_below_threshold": filtered_count,
            }

            logger.info(
                f"[RAG-OPT] Retrieval complete: {len(retrieval_hits)} hits, "
                f"insufficient={is_insufficient}, max_score={diagnostics['max_score']:.3f}"
            )

            if len(retrieval_hits) == 0:
                logger.warning(
                    f"[RAG-OPT] No hits found! tenant={ctx.tenant_id}, query={ctx.query[:50]}..., "
                    f"raw_results={len(results)}, threshold={self._score_threshold}"
                )

            return RetrievalResult(
                hits=retrieval_hits,
                diagnostics=diagnostics,
            )

        except Exception as e:
            # Swallow-and-report: retrieval failure must not break the caller.
            logger.error(f"[RAG-OPT] Retrieval error: {e}", exc_info=True)
            return RetrievalResult(
                hits=[],
                diagnostics={"error": str(e), "is_insufficient": True},
            )

    async def _two_stage_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Two-stage retrieval using Matryoshka dimensions.

        Stage 1: Fast retrieval with 256-dim vectors
        Stage 2: Precise reranking with 768-dim vectors

        NOTE(review): stage 2 only rescores candidates whose payload carries
        an "embedding_full" key; to_qdrant_point as written does not store
        one, so candidates may pass through with their stage-1 scores —
        confirm against the indexing pipeline.

        Reference: rag-optimization/spec.md Section 2.4
        """
        import time

        client = await self._get_client()

        # Stage 1: broad candidate fetch using the cheap 256-dim vectors.
        stage1_start = time.perf_counter()
        candidates = await self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_256, "dim_256",
            top_k * self._two_stage_expand_factor
        )
        stage1_latency = (time.perf_counter() - stage1_start) * 1000

        logger.debug(
            f"[RAG-OPT] Stage 1: {len(candidates)} candidates in {stage1_latency:.2f}ms"
        )

        # Stage 2: rescore with full-dimension cosine similarity where the
        # stored full embedding is available.
        stage2_start = time.perf_counter()
        reranked = []
        for candidate in candidates:
            stored_full_embedding = candidate.get("payload", {}).get("embedding_full", [])
            if stored_full_embedding:
                import numpy as np  # NOTE(review): unused here — similarity comes from _cosine_similarity
                similarity = self._cosine_similarity(
                    embedding_result.embedding_full,
                    stored_full_embedding
                )
                candidate["score"] = similarity
                candidate["stage"] = "reranked"
            # Candidates without a stored full embedding keep their stage-1 score.
            reranked.append(candidate)

        reranked.sort(key=lambda x: x.get("score", 0), reverse=True)
        results = reranked[:top_k]
        stage2_latency = (time.perf_counter() - stage2_start) * 1000

        logger.debug(
            f"[RAG-OPT] Stage 2: {len(results)} final results in {stage2_latency:.2f}ms"
        )

        return results

    async def _hybrid_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        query: str,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Hybrid retrieval using RRF to combine vector and BM25 results.

        Both searches run concurrently; either failing degrades gracefully
        to the other source rather than failing the whole retrieval.

        Reference: rag-optimization/spec.md Section 2.5
        """
        client = await self._get_client()

        # Over-fetch (2x) from each source so RRF has enough overlap to fuse.
        vector_task = self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_full, "full",
            top_k * 2
        )

        bm25_task = self._bm25_search(client, tenant_id, query, top_k * 2)

        # return_exceptions=True: a failure in one search is delivered as an
        # Exception value instead of cancelling the other.
        vector_results, bm25_results = await asyncio.gather(
            vector_task, bm25_task, return_exceptions=True
        )

        if isinstance(vector_results, Exception):
            logger.warning(f"[RAG-OPT] Vector search failed: {vector_results}")
            vector_results = []

        if isinstance(bm25_results, Exception):
            logger.warning(f"[RAG-OPT] BM25 search failed: {bm25_results}")
            bm25_results = []

        combined = self._rrf_combiner.combine(
            vector_results,
            bm25_results,
            vector_weight=settings.rag_vector_weight,
            bm25_weight=settings.rag_bm25_weight,
        )

        return combined[:top_k]

    async def _vector_retrieve(
        self,
        tenant_id: str,
        embedding: list[float],
        top_k: int,
    ) -> list[dict[str, Any]]:
        """Simple vector retrieval against the full-dimension named vector."""
        client = await self._get_client()
        return await self._search_with_dimension(
            client, tenant_id, embedding, "full", top_k
        )

    async def _search_with_dimension(
        self,
        client: QdrantClient,
        tenant_id: str,
        query_vector: list[float],
        vector_name: str,
        limit: int,
    ) -> list[dict[str, Any]]:
        """Search using specified vector dimension.

        Returns a list of {"id", "score", "payload"} dicts. Never raises:
        any failure is logged and an empty list is returned.
        """
        try:
            qdrant = await client.get_client()
            collection_name = client.get_collection_name(tenant_id)

            logger.info(
                f"[RAG-OPT] Searching collection={collection_name}, "
                f"vector_name={vector_name}, limit={limit}, vector_dim={len(query_vector)}"
            )

            results = await qdrant.search(
                collection_name=collection_name,
                # (name, vector) tuple selects a named vector in the collection.
                query_vector=(vector_name, query_vector),
                limit=limit,
            )

            logger.info(
                f"[RAG-OPT] Search returned {len(results)} results from collection={collection_name}"
            )

            # Log a small sample of top hits for debugging.
            if len(results) > 0:
                for i, r in enumerate(results[:3]):
                    logger.debug(
                        f"[RAG-OPT] Result {i+1}: id={r.id}, score={r.score:.4f}"
                    )

            return [
                {
                    "id": str(result.id),
                    "score": result.score,
                    "payload": result.payload or {},
                }
                for result in results
            ]
        except Exception as e:
            logger.error(
                f"[RAG-OPT] Search with {vector_name} failed: {e}, "
                f"collection_name={client.get_collection_name(tenant_id)}",
                exc_info=True
            )
            return []

    async def _bm25_search(
        self,
        client: QdrantClient,
        tenant_id: str,
        query: str,
        limit: int,
    ) -> list[dict[str, Any]]:
        """
        BM25-like search using Qdrant's sparse vectors or fallback to text matching.
        This is a simplified implementation; for production, use Elasticsearch.

        NOTE(review): the score computed below is Jaccard overlap of term
        sets (|Q∩T| / |Q∪T|), not true BM25 — no term frequency or document
        length normalization. It also only scans the first scroll page
        (limit * 3 points), so recall is bounded by that page size.
        """
        try:
            qdrant = await client.get_client()
            collection_name = client.get_collection_name(tenant_id)

            # Tokenize the query into a set of lowercase word terms.
            query_terms = set(re.findall(r'\w+', query.lower()))

            # scroll() returns (points, next_page_offset); only the first
            # page is examined here.
            results = await qdrant.scroll(
                collection_name=collection_name,
                limit=limit * 3,
                with_payload=True,
            )

            scored_results = []
            for point in results[0]:
                text = point.payload.get("text", "").lower()
                text_terms = set(re.findall(r'\w+', text))
                overlap = len(query_terms & text_terms)
                if overlap > 0:
                    # Jaccard similarity: intersection over union.
                    score = overlap / (len(query_terms) + len(text_terms) - overlap)
                    scored_results.append({
                        "id": str(point.id),
                        "score": score,
                        "payload": point.payload or {},
                    })

            scored_results.sort(key=lambda x: x["score"], reverse=True)
            return scored_results[:limit]

        except Exception as e:
            # Best-effort: hybrid retrieval degrades to vector-only.
            logger.debug(f"[RAG-OPT] BM25 search failed: {e}")
            return []

    def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
        """Calculate cosine similarity between two vectors.

        NOTE(review): divides by the product of norms with no zero-vector
        guard — an all-zero input yields NaN/inf.
        """
        import numpy as np
        a = np.array(vec1)
        b = np.array(vec2)
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    async def health_check(self) -> bool:
        """Check if retriever is healthy.

        Returns True when the Qdrant server answers a collections listing;
        False (with an error log) on any failure.
        """
        try:
            client = await self._get_client()
            qdrant = await client.get_client()
            await qdrant.get_collections()
            return True
        except Exception as e:
            logger.error(f"[RAG-OPT] Health check failed: {e}")
            return False
|
||||
|
||||
|
||||
# Process-wide singleton, created lazily by get_optimized_retriever().
_optimized_retriever: OptimizedRetriever | None = None


async def get_optimized_retriever() -> OptimizedRetriever:
    """Return the shared OptimizedRetriever, constructing it on first call.

    No lock is taken: concurrent first calls may each construct an instance,
    with the last assignment winning.
    """
    global _optimized_retriever
    retriever = _optimized_retriever
    if retriever is None:
        retriever = OptimizedRetriever()
        _optimized_retriever = retriever
    return retriever
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue