Merge pull request '[AC-AISVC-02, AC-AISVC-16] 多个需求合并' (#1) from feature/prompt-unification-and-logging into main

Reviewed-on: http://ashai.com.cn:3005/MerCry/ai-robot-core/pulls/1
This commit is contained in:
MerCry 2026-02-25 17:17:35 +00:00
commit 1e3fe808e8
136 changed files with 28460 additions and 1 deletions

6
.gitignore vendored
View File

@ -158,5 +158,9 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
# Project specific
ai-service/uploads/
*.local

1
ai-service-admin/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
node_modules/

View File

@ -0,0 +1,13 @@
<!DOCTYPE html>
<!-- Vite entry page for the AI Service admin console. -->
<html lang="en">
<head>
  <meta charset="UTF-8" />
  <link rel="icon" href="/favicon.ico" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>AI Service Admin</title>
</head>
<body>
  <!-- Vue application mount point; bootstrapped by /src/main.ts below. -->
  <div id="app"></div>
  <script type="module" src="/src/main.ts"></script>
</body>
</html>

1935
ai-service-admin/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,24 @@
{
"name": "ai-service-admin",
"version": "0.1.0",
"private": true,
"scripts": {
"dev": "vite",
"build": "vue-tsc --noEmit && vite build",
"preview": "vite preview"
},
"dependencies": {
"@element-plus/icons-vue": "^2.3.1",
"axios": "^1.6.7",
"element-plus": "^2.6.1",
"pinia": "^2.1.7",
"vue": "^3.4.21",
"vue-router": "^4.3.0"
},
"devDependencies": {
"@vitejs/plugin-vue": "^5.0.4",
"typescript": "^5.2.2",
"vite": "^5.1.4",
"vue-tsc": "^1.8.27"
}
}

View File

@ -0,0 +1,279 @@
<template>
<div class="app-wrapper">
<header class="app-header">
<div class="header-left">
<div class="logo">
<div class="logo-icon">
<svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M12 2L2 7L12 12L22 7L12 2Z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M2 17L12 22L22 17" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M2 12L12 17L22 12" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
</div>
<span class="logo-text">AI Robot</span>
</div>
<nav class="main-nav">
<router-link to="/dashboard" class="nav-item" :class="{ active: isActive('/dashboard') }">
<el-icon><Odometer /></el-icon>
<span>控制台</span>
</router-link>
<router-link to="/kb" class="nav-item" :class="{ active: isActive('/kb') }">
<el-icon><FolderOpened /></el-icon>
<span>知识库</span>
</router-link>
<router-link to="/rag-lab" class="nav-item" :class="{ active: isActive('/rag-lab') }">
<el-icon><Cpu /></el-icon>
<span>RAG 实验室</span>
</router-link>
<router-link to="/monitoring" class="nav-item" :class="{ active: isActive('/monitoring') }">
<el-icon><Monitor /></el-icon>
<span>会话监控</span>
</router-link>
<div class="nav-divider"></div>
<router-link to="/admin/embedding" class="nav-item" :class="{ active: isActive('/admin/embedding') }">
<el-icon><Connection /></el-icon>
<span>嵌入模型</span>
</router-link>
<router-link to="/admin/llm" class="nav-item" :class="{ active: isActive('/admin/llm') }">
<el-icon><ChatDotSquare /></el-icon>
<span>LLM 配置</span>
</router-link>
</nav>
</div>
<div class="header-right">
<div class="tenant-selector">
<el-select
v-model="currentTenantId"
placeholder="选择租户"
size="default"
:loading="loading"
@change="handleTenantChange"
>
<el-option
v-for="tenant in tenantList"
:key="tenant.id"
:label="tenant.name"
:value="tenant.id"
/>
</el-select>
</div>
</div>
</header>
<main class="app-main">
<router-view />
</main>
</div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { useRoute } from 'vue-router'
import { useTenantStore } from '@/stores/tenant'
import { getTenantList, type Tenant } from '@/api/tenant'
import { Odometer, FolderOpened, Cpu, Monitor, Connection, ChatDotSquare } from '@element-plus/icons-vue'
import { ElMessage } from 'element-plus'
const route = useRoute()
const tenantStore = useTenantStore()
const currentTenantId = ref(tenantStore.currentTenantId)
const tenantList = ref<Tenant[]>([])
const loading = ref(false)
// A nav entry is highlighted for its exact path and for any nested child route.
const isActive = (path: string) => {
  const current = route.path
  return current === path || current.startsWith(`${path}/`)
}
// Persist the newly selected tenant, then hard-reload the page so every
// view refetches its data under the new tenant scope.
const handleTenantChange = (val: string) => {
  tenantStore.setTenant(val)
  // Full reload is the simplest way to drop all tenant-scoped state.
  window.location.reload()
}
// A tenant id must look like "<name>@ash@<4-digit-year>".
const TENANT_ID_PATTERN = /^[^@]+@ash@\d{4}$/
const isValidTenantId = (tenantId: string): boolean =>
  TENANT_ID_PATTERN.test(tenantId)
// Load the tenant list for the header selector.
// Also self-heals the persisted tenant id: a malformed or no-longer-existing
// id is replaced so the app never starts unscoped.
const fetchTenantList = async () => {
  loading.value = true
  try {
    // Reset a malformed persisted tenant id (expected shape: name@ash@year).
    if (!isValidTenantId(currentTenantId.value)) {
      console.warn('Invalid tenant ID format, resetting to default:', currentTenantId.value)
      currentTenantId.value = 'default@ash@2026'
      tenantStore.setTenant(currentTenantId.value)
    }
    const response = await getTenantList()
    tenantList.value = response.tenants || []
    // Fall back to the first tenant when the current id is unknown to the server.
    if (tenantList.value.length > 0 && !tenantList.value.find(t => t.id === currentTenantId.value)) {
      const firstTenant = tenantList.value[0]
      currentTenantId.value = firstTenant.id
      tenantStore.setTenant(firstTenant.id)
    }
  } catch (error) {
    ElMessage.error('获取租户列表失败')
    console.error('Failed to fetch tenant list:', error)
    // Offline fallback so the selector still renders something usable.
    // NOTE(review): this literal only sets id/name while the Tenant interface
    // declares displayName/year/createdAt as required — confirm the type.
    tenantList.value = [{ id: 'default@ash@2026', name: 'default (2026)' }]
  } finally {
    loading.value = false
  }
}
// Load tenants as soon as the header mounts.
onMounted(() => {
  fetchTenantList()
})
</script>
<style scoped>
.app-wrapper {
min-height: 100vh;
background-color: var(--bg-primary, #F8FAFC);
}
.app-header {
position: sticky;
top: 0;
z-index: 100;
display: flex;
align-items: center;
justify-content: space-between;
padding: 0 24px;
height: 60px;
background-color: var(--bg-secondary, #FFFFFF);
border-bottom: 1px solid var(--border-color, #E2E8F0);
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.04);
}
.header-left {
display: flex;
align-items: center;
gap: 32px;
}
.logo {
display: flex;
align-items: center;
gap: 10px;
}
.logo-icon {
width: 32px;
height: 32px;
display: flex;
align-items: center;
justify-content: center;
color: var(--primary-color, #4F7CFF);
}
.logo-icon svg {
width: 28px;
height: 28px;
}
.logo-text {
font-size: 18px;
font-weight: 700;
color: var(--text-primary, #1E293B);
letter-spacing: -0.5px;
}
.main-nav {
display: flex;
align-items: center;
gap: 4px;
}
.nav-item {
display: flex;
align-items: center;
gap: 6px;
padding: 8px 14px;
font-size: 14px;
font-weight: 500;
color: var(--text-secondary, #64748B);
text-decoration: none;
border-radius: 8px;
transition: all 0.2s ease;
}
.nav-item:hover {
color: var(--primary-color, #4F7CFF);
background-color: var(--primary-lighter, #E8EEFF);
}
.nav-item.active {
color: var(--primary-color, #4F7CFF);
background-color: var(--primary-lighter, #E8EEFF);
}
.nav-item .el-icon {
font-size: 16px;
}
.nav-divider {
width: 1px;
height: 20px;
margin: 0 8px;
background-color: var(--border-color, #E2E8F0);
}
.header-right {
display: flex;
align-items: center;
gap: 16px;
}
.tenant-selector {
min-width: 140px;
}
.tenant-selector :deep(.el-input__wrapper) {
background-color: var(--bg-tertiary, #F1F5F9);
border-color: transparent;
}
.tenant-selector :deep(.el-input__wrapper:hover) {
background-color: var(--bg-hover, #E2E8F0);
}
.app-main {
min-height: calc(100vh - 60px);
}
@media (max-width: 1024px) {
.app-header {
padding: 0 16px;
}
.header-left {
gap: 16px;
}
.main-nav {
gap: 2px;
}
.nav-item {
padding: 8px 10px;
}
.nav-item span {
display: none;
}
.nav-divider {
display: none;
}
}
@media (max-width: 640px) {
.logo-text {
display: none;
}
}
</style>

View File

@ -0,0 +1,8 @@
import request from '@/utils/request'
/** Fetch the aggregate statistics shown on the admin dashboard. */
export function getDashboardStats() {
  return request({ method: 'get', url: '/admin/dashboard/stats' })
}

View File

@ -0,0 +1,88 @@
import request from '@/utils/request'
/** One embedding backend registered on the server. */
export interface EmbeddingProviderInfo {
  name: string
  display_name: string
  description?: string
  // JSON-schema-like description of the provider's configurable fields.
  config_schema: Record<string, any>
}
/** The embedding configuration currently persisted on the server. */
export interface EmbeddingConfig {
  provider: string
  config: Record<string, any>
  updated_at?: string
}
/** Payload for saving a new embedding configuration. */
export interface EmbeddingConfigUpdate {
  provider: string
  config?: Record<string, any>
}
/** Outcome of an embedding connectivity test. */
export interface EmbeddingTestResult {
  success: boolean
  dimension: number
  latency_ms?: number
  message?: string
  error?: string
}
/** Request body for the embedding test endpoint. */
export interface EmbeddingTestRequest {
  test_text?: string
  // Presumably tests this candidate config instead of the saved one — confirm.
  config?: EmbeddingConfigUpdate
}
/** A document file format accepted by the ingestion pipeline. */
export interface DocumentFormat {
  extension: string
  name: string
  description?: string
}
/** Response envelope for the providers listing. */
export interface EmbeddingProvidersResponse {
  providers: EmbeddingProviderInfo[]
}
/** Response envelope for a config update. */
export interface EmbeddingConfigUpdateResponse {
  success: boolean
  message: string
}
/** Response envelope for the supported-formats listing. */
export interface SupportedFormatsResponse {
  formats: DocumentFormat[]
}
export function getProviders() {
return request({
url: '/embedding/providers',
method: 'get'
})
}
export function getConfig() {
return request({
url: '/embedding/config',
method: 'get'
})
}
export function saveConfig(data: EmbeddingConfigUpdate) {
return request({
url: '/embedding/config',
method: 'put',
data
})
}
/**
 * Run a round-trip test against the embedding backend with the given
 * (optionally overridden) configuration.
 */
export function testEmbedding(data: EmbeddingTestRequest): Promise<EmbeddingTestResult> {
  return request({
    url: '/embedding/test',
    method: 'post',
    data
  })
}
export function getSupportedFormats() {
return request({
url: '/embedding/formats',
method: 'get'
})
}

View File

@ -0,0 +1,38 @@
import request from '@/utils/request'
/** List all knowledge bases visible to the current tenant. */
export function listKnowledgeBases() {
  return request({
    url: '/admin/kb/knowledge-bases',
    method: 'get'
  })
}
/**
 * List documents.
 * NOTE(review): `params` is untyped — presumably kb id + pagination; consider
 * a dedicated query interface.
 */
export function listDocuments(params: any) {
  return request({
    url: '/admin/kb/documents',
    method: 'get',
    params
  })
}
/** Upload a document (multipart form data) for indexing. */
export function uploadDocument(data: FormData) {
  return request({
    url: '/admin/kb/documents',
    method: 'post',
    data
  })
}
/** Poll the status of an asynchronous indexing job. */
export function getIndexJob(jobId: string) {
  return request({
    url: `/admin/kb/index/jobs/${jobId}`,
    method: 'get'
  })
}
/** Delete a document by id. */
export function deleteDocument(docId: string) {
  return request({
    url: `/admin/kb/documents/${docId}`,
    method: 'delete'
  })
}

View File

@ -0,0 +1,50 @@
import request from '@/utils/request'
import type {
LLMProviderInfo,
LLMConfig,
LLMConfigUpdate,
LLMTestResult,
LLMTestRequest,
LLMProvidersResponse,
LLMConfigUpdateResponse
} from '@/types/llm'
/** List available LLM providers and their configuration schemas. */
export function getLLMProviders(): Promise<LLMProvidersResponse> {
  return request({
    url: '/admin/llm/providers',
    method: 'get'
  })
}
/** Read the currently saved LLM configuration. */
export function getLLMConfig(): Promise<LLMConfig> {
  return request({
    url: '/admin/llm/config',
    method: 'get'
  })
}
/** Persist a new LLM configuration. */
export function updateLLMConfig(data: LLMConfigUpdate): Promise<LLMConfigUpdateResponse> {
  return request({
    url: '/admin/llm/config',
    method: 'put',
    data
  })
}
/** Run a test request against the LLM backend. */
export function testLLM(data: LLMTestRequest): Promise<LLMTestResult> {
  return request({
    url: '/admin/llm/test',
    method: 'post',
    data
  })
}
// Re-export the types so callers can import everything from this module.
export type {
  LLMProviderInfo,
  LLMConfig,
  LLMConfigUpdate,
  LLMTestResult,
  LLMTestRequest,
  LLMProvidersResponse,
  LLMConfigUpdateResponse
}

View File

@ -0,0 +1,16 @@
import request from '@/utils/request'
/**
 * List chat sessions for the monitoring view.
 * NOTE(review): `params` is untyped — presumably pagination/filter fields.
 */
export function listSessions(params: any) {
  return request({
    url: '/admin/sessions',
    method: 'get',
    params
  })
}
/** Fetch the detail record for a single session. */
export function getSessionDetail(sessionId: string) {
  return request({
    url: `/admin/sessions/${sessionId}`,
    method: 'get'
  })
}

View File

@ -0,0 +1,135 @@
import request from '@/utils/request'
import { useTenantStore } from '@/stores/tenant'
/** LLM generation result plus token/latency metadata. */
export interface AIResponse {
  content: string
  prompt_tokens?: number
  completion_tokens?: number
  total_tokens?: number
  latency_ms?: number
  model?: string
}
/** One retrieved chunk with its similarity score and origin. */
export interface RetrievalResult {
  content: string
  score: number
  source: string
  metadata?: Record<string, any>
}
/** Parameters for a single RAG experiment run. */
export interface RagExperimentRequest {
  query: string
  kb_ids?: string[]
  top_k?: number
  score_threshold?: number
  llm_provider?: string
  // Presumably toggles retrieval-only vs full generation — confirm default.
  generate_response?: boolean
}
/** Full trace of an experiment: retrieval, final prompt, answer, timing. */
export interface RagExperimentResult {
  query: string
  retrieval_results?: RetrievalResult[]
  final_prompt?: string
  ai_response?: AIResponse
  total_latency_ms?: number
}
/** Run a RAG experiment synchronously and return the full trace. */
export function runRagExperiment(data: RagExperimentRequest): Promise<RagExperimentResult> {
  return request({
    url: '/admin/rag/experiments/run',
    method: 'post',
    data
  })
}
// Open an SSE stream for a RAG experiment via EventSource.
// NOTE(review): `data` and `onComplete` are accepted but never used, and
// EventSource always issues a GET with no request body, so the experiment
// payload cannot reach this endpoint — confirm whether createSSEConnection
// below supersedes this function and whether it can be removed.
export function runRagExperimentStream(
  data: RagExperimentRequest,
  onMessage: (event: MessageEvent) => void,
  onError?: (error: Event) => void,
  onComplete?: () => void
): EventSource {
  const baseUrl = import.meta.env.VITE_APP_BASE_API || '/api'
  const url = `${baseUrl}/admin/rag/experiments/stream`
  const eventSource = new EventSource(url, {
    withCredentials: true
  })
  eventSource.onmessage = onMessage
  // Close on error; EventSource would otherwise auto-reconnect indefinitely.
  eventSource.onerror = (error) => {
    eventSource.close()
    onError?.(error)
  }
  return eventSource
}
// POST `body` to an SSE endpoint and stream back each `data:` payload.
// Implemented with fetch + manual line parsing because EventSource cannot
// send a POST body. Returns a cancel function that aborts the request.
export function createSSEConnection(
  url: string,
  body: RagExperimentRequest,
  onMessage: (data: string) => void,
  onError?: (error: Error) => void,
  onComplete?: () => void
): () => void {
  const baseUrl = import.meta.env.VITE_APP_BASE_API || '/api'
  const fullUrl = `${baseUrl}${url}`
  const tenantStore = useTenantStore()
  const controller = new AbortController()
  fetch(fullUrl, {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      'Accept': 'text/event-stream',
      // Scope the experiment to the currently selected tenant.
      'X-Tenant-Id': tenantStore.currentTenantId || '',
    },
    body: JSON.stringify(body),
    signal: controller.signal
  })
    .then(async (response) => {
      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`)
      }
      const reader = response.body?.getReader()
      if (!reader) {
        throw new Error('No response body')
      }
      const decoder = new TextDecoder()
      // Holds the trailing partial line carried over between reads.
      let buffer = ''
      while (true) {
        const { done, value } = await reader.read()
        if (done) {
          onComplete?.()
          break
        }
        buffer += decoder.decode(value, { stream: true })
        const lines = buffer.split('\n')
        // The last element may be an incomplete line; keep it for the next read.
        buffer = lines.pop() || ''
        for (const line of lines) {
          if (line.startsWith('data: ')) {
            // Strip the 'data: ' prefix (6 characters).
            const data = line.slice(6)
            // '[DONE]' is the server's explicit end-of-stream sentinel.
            if (data === '[DONE]') {
              onComplete?.()
              return
            }
            onMessage(data)
          }
        }
      }
    })
    .catch((error) => {
      // A caller-initiated abort is not an error.
      if (error.name !== 'AbortError') {
        onError?.(error)
      }
    })
  return () => controller.abort()
}

View File

@ -0,0 +1,21 @@
import request from '@/utils/request'
/**
 * A tenant record as returned by the admin API.
 * NOTE(review): App.vue's error fallback builds a tenant with only id/name;
 * either that call site is wrong or displayName/year/createdAt should be
 * optional here — confirm against the backend response.
 */
export interface Tenant {
  id: string
  name: string
  displayName: string
  year: string
  createdAt: string
}
/** Response envelope for the tenant listing. */
export interface TenantListResponse {
  tenants: Tenant[]
  total: number
}
/** Fetch the tenants available to the admin console. */
export function getTenantList() {
  return request<TenantListResponse>({ method: 'get', url: '/admin/tenants' })
}

View File

@ -0,0 +1,43 @@
<template>
<el-form :model="model" v-bind="$attrs" ref="formRef">
<slot />
<el-form-item v-if="showFooter">
<el-button type="primary" @click="handleSubmit">提交</el-button>
<el-button @click="handleReset">重置</el-button>
</el-form-item>
</el-form>
</template>
<script setup lang="ts">
import { ref } from 'vue'
import type { FormInstance } from 'element-plus'
const props = defineProps<{
model: any
showFooter?: boolean
}>()
const emit = defineEmits(['submit', 'reset'])
const formRef = ref<FormInstance>()
// Validate, then emit the bound model on success.
// NOTE(review): element-plus validate() is called with a callback while its
// return value is awaited — confirm it does not also return a rejecting
// promise on invalid input, which would surface an unhandled rejection.
const handleSubmit = async () => {
  if (!formRef.value) return
  await formRef.value.validate((valid) => {
    if (valid) {
      emit('submit', props.model)
    }
  })
}
// Reset all fields to their initial values and notify the parent.
const handleReset = () => {
  if (!formRef.value) return
  formRef.value.resetFields()
  emit('reset')
}
// Expose the underlying form's validation API to parent components.
defineExpose({
  validate: () => formRef.value?.validate(),
  resetFields: () => formRef.value?.resetFields()
})
</script>

View File

@ -0,0 +1,55 @@
<template>
<el-table :data="data" v-bind="$attrs" style="width: 100%">
<slot />
</el-table>
<div v-if="total > 0" class="pagination-container">
<el-pagination
v-model:current-page="currentPage"
v-model:page-size="pageSize"
:total="total"
:page-sizes="[10, 20, 50, 100]"
layout="total, sizes, prev, pager, next, jumper"
@size-change="handleSizeChange"
@current-change="handleCurrentChange"
/>
</div>
</template>
<script setup lang="ts">
import { computed } from 'vue'
const props = defineProps<{
data: any[]
total: number
pageNum: number
pageSize: number
}>()
const emit = defineEmits(['update:pageNum', 'update:pageSize', 'pagination'])
// Proxy pagination state to the parent via v-model:pageNum / v-model:pageSize.
const currentPage = computed({
  get: () => props.pageNum,
  set: (val) => emit('update:pageNum', val)
})
const pageSize = computed({
  get: () => props.pageSize,
  set: (val) => emit('update:pageSize', val)
})
// Emit a single 'pagination' event carrying the effective page/limit so the
// parent can refetch with one handler for both kinds of change.
const handleSizeChange = (val: number) => {
  emit('pagination', { page: currentPage.value, limit: val })
}
const handleCurrentChange = (val: number) => {
  emit('pagination', { page: val, limit: pageSize.value })
}
</script>
<style scoped>
.pagination-container {
padding: 32px 16px;
display: flex;
justify-content: flex-end;
}
</style>

View File

@ -0,0 +1,219 @@
<template>
<el-form
ref="formRef"
:model="formData"
:rules="formRules"
:label-width="labelWidth"
v-bind="$attrs"
>
<el-form-item
v-for="(field, key) in schemaProperties"
:key="key"
:label="field.title || key"
:prop="key"
>
<template #label>
<span>{{ field.title || key }}</span>
<el-tooltip v-if="field.description" :content="field.description" placement="top">
<el-icon class="ml-1 cursor-help"><QuestionFilled /></el-icon>
</el-tooltip>
</template>
<el-input
v-if="field.type === 'string'"
v-model="formData[key]"
:placeholder="`请输入${field.title || key}`"
clearable
:show-password="isPasswordField(key)"
/>
<el-input-number
v-else-if="field.type === 'integer' || field.type === 'number'"
v-model="formData[key]"
:placeholder="`请输入${field.title || key}`"
:min="field.minimum"
:max="field.maximum"
:step="field.type === 'number' ? 0.1 : 1"
:precision="field.type === 'number' ? 2 : 0"
controls-position="right"
class="w-full"
/>
<el-switch
v-else-if="field.type === 'boolean'"
v-model="formData[key]"
/>
<el-select
v-else-if="field.enum && field.enum.length > 0"
v-model="formData[key]"
:placeholder="`请选择${field.title || key}`"
clearable
class="w-full"
>
<el-option
v-for="option in field.enum"
:key="option"
:label="option"
:value="option"
/>
</el-select>
</el-form-item>
</el-form>
</template>
<script setup lang="ts">
import { ref, computed, watch, onMounted } from 'vue'
import { QuestionFilled } from '@element-plus/icons-vue'
import type { FormInstance, FormRules } from 'element-plus'
export interface SchemaProperty {
type: string
title?: string
description?: string
default?: any
enum?: string[]
minimum?: number
maximum?: number
required?: boolean
}
export interface ConfigSchema {
type?: string
properties?: Record<string, SchemaProperty>
required?: string[]
}
const props = defineProps<{
schema: ConfigSchema
modelValue: Record<string, any>
labelWidth?: string
}>()
const emit = defineEmits<{
(e: 'update:modelValue', value: Record<string, any>): void
}>()
const formRef = ref<FormInstance>()
// Working copy of the form values; kept in sync with v-model via watches below.
const formData = ref<Record<string, any>>({})
const schemaProperties = computed(() => {
  return props.schema?.properties || {}
})
// Union of the schema-level `required` array and per-field `required` flags.
const requiredFields = computed(() => {
  const required = props.schema?.required || []
  const propsRequired = Object.entries(schemaProperties.value)
    .filter(([, field]) => field.required)
    .map(([key]) => key)
  return [...new Set([...required, ...propsRequired])]
})
// Build element-plus validation rules from the schema: required-field rules
// plus length bounds for strings.
// NOTE(review): JSON Schema uses minLength/maxLength for string length;
// minimum/maximum are numeric bounds — confirm the backend schema shape.
const formRules = computed<FormRules>(() => {
  const rules: FormRules = {}
  Object.entries(schemaProperties.value).forEach(([key, field]) => {
    const fieldRules: any[] = []
    if (requiredFields.value.includes(key)) {
      fieldRules.push({
        required: true,
        message: `${field.title || key}不能为空`,
        trigger: ['blur', 'change']
      })
    }
    if (field.type === 'string' && field.minimum !== undefined) {
      fieldRules.push({
        min: field.minimum,
        message: `${field.title || key}长度不能小于${field.minimum}`,
        trigger: ['blur']
      })
    }
    if (field.type === 'string' && field.maximum !== undefined) {
      fieldRules.push({
        max: field.maximum,
        message: `${field.title || key}长度不能大于${field.maximum}`,
        trigger: ['blur']
      })
    }
    // Each key is visited exactly once and `rules` starts empty, so the
    // original `if (rules[key])` pre-check was dead code; assign directly.
    if (fieldRules.length > 0) {
      rules[key] = fieldRules
    }
  })
  return rules
})
// Render a masked input for credential-looking field names.
const SENSITIVE_HINTS = ['password', 'secret', 'key', 'token'] as const
const isPasswordField = (key: string): boolean => {
  const normalized = key.toLowerCase()
  return SENSITIVE_HINTS.some((hint) => normalized.includes(hint))
}
// Seed formData per field: an explicit modelValue wins, then the schema
// default, then a type-appropriate zero value.
const initFormData = () => {
  const data: Record<string, any> = {}
  Object.entries(schemaProperties.value).forEach(([key, field]) => {
    if (props.modelValue && props.modelValue[key] !== undefined) {
      data[key] = props.modelValue[key]
    } else if (field.default !== undefined) {
      data[key] = field.default
    } else {
      switch (field.type) {
        case 'string':
          data[key] = ''
          break
        case 'integer':
        case 'number':
          // Respect the lower bound when the schema declares one.
          data[key] = field.minimum ?? 0
          break
        case 'boolean':
          data[key] = false
          break
        default:
          data[key] = null
      }
    }
  })
  formData.value = data
}
// Re-seed the form whenever the bound value or the schema changes.
// NOTE(review): the formData watch below re-emits update:modelValue, which
// re-triggers this watch via the parent's v-model; the cycle appears to
// settle only because the values converge — consider an equality guard.
watch(
  () => props.modelValue,
  () => {
    initFormData()
  },
  { deep: true }
)
watch(
  () => props.schema,
  () => {
    initFormData()
  },
  { deep: true }
)
// Push edits back to the parent (v-model contract).
watch(
  formData,
  (val) => {
    emit('update:modelValue', val)
  },
  { deep: true }
)
onMounted(() => {
  initFormData()
})
// Expose imperative validation helpers to parent components.
defineExpose({
  validate: () => formRef.value?.validate(),
  resetFields: () => formRef.value?.resetFields(),
  clearValidate: () => formRef.value?.clearValidate()
})
</script>
<style scoped>
.w-full {
width: 100%;
}
.ml-1 {
margin-left: 4px;
}
.cursor-help {
cursor: help;
}
</style>

View File

@ -0,0 +1,83 @@
<template>
<el-select
:model-value="modelValue"
:loading="loading"
:placeholder="placeholder"
:disabled="disabled"
:clearable="clearable"
:teleported="true"
:popper-options="{ modifiers: [{ name: 'flip', enabled: true }, { name: 'preventOverflow', enabled: true }] }"
@update:model-value="handleChange"
>
<el-option
v-for="provider in providers"
:key="provider.name"
:label="provider.display_name"
:value="provider.name"
>
<div class="provider-option">
<span class="provider-name">{{ provider.display_name }}</span>
<span v-if="provider.description" class="provider-desc">{{ provider.description }}</span>
</div>
</el-option>
</el-select>
</template>
<script setup lang="ts">
export interface ProviderInfo {
name: string
display_name: string
description?: string
config_schema: Record<string, any>
}
const props = withDefaults(
defineProps<{
modelValue?: string
providers: ProviderInfo[]
loading?: boolean
disabled?: boolean
clearable?: boolean
placeholder?: string
}>(),
{
modelValue: '',
loading: false,
disabled: false,
clearable: false,
placeholder: '请选择提供者'
}
)
const emit = defineEmits<{
'update:modelValue': [value: string]
change: [provider: ProviderInfo | undefined]
}>()
// Keep v-model in sync and also surface the full provider object so the
// parent can read its config_schema without a second lookup.
const handleChange = (value: string) => {
  emit('update:modelValue', value)
  const selectedProvider = props.providers.find((p) => p.name === value)
  emit('change', selectedProvider)
}
</script>
<style scoped>
.provider-option {
display: flex;
flex-direction: column;
line-height: 1.5;
padding: 4px 0;
}
.provider-name {
font-weight: 500;
color: var(--text-primary);
}
.provider-desc {
font-size: 12px;
color: var(--text-secondary);
margin-top: 2px;
line-height: 1.4;
}
</style>

View File

@ -0,0 +1,523 @@
<template>
<el-card shadow="hover" class="test-panel">
<template #header>
<div class="card-header">
<div class="header-left">
<div class="icon-wrapper">
<el-icon><Connection /></el-icon>
</div>
<span class="header-title">{{ title }}</span>
</div>
<el-tag v-if="testResult" :type="testResult.success ? 'success' : 'danger'" size="small" effect="dark">
{{ testResult.success ? '连接成功' : '连接失败' }}
</el-tag>
</div>
</template>
<div class="test-content">
<div class="test-form-section">
<div class="section-label">
<el-icon><Edit /></el-icon>
<span>{{ inputLabel }}</span>
</div>
<el-input
v-model="testInputValue"
type="textarea"
:rows="3"
:placeholder="inputPlaceholder"
clearable
class="test-textarea"
/>
<el-button
type="primary"
size="large"
:loading="loading"
:disabled="!canTest"
class="test-button"
@click="handleTest"
>
<el-icon v-if="!loading"><Connection /></el-icon>
{{ loading ? '测试中...' : '测试连接' }}
</el-button>
</div>
<transition name="result-fade">
<div v-if="testResult" class="test-result">
<el-divider />
<div v-if="testResult.success" class="success-result">
<div class="result-header">
<div class="success-icon">
<el-icon><CircleCheck /></el-icon>
</div>
<span class="result-title">{{ testResult.message || '连接成功' }}</span>
</div>
<div class="success-details">
<div v-if="testResult.dimension !== undefined" class="detail-card">
<div class="detail-icon">
<el-icon><Grid /></el-icon>
</div>
<div class="detail-info">
<span class="detail-label">向量维度</span>
<span class="detail-value">{{ testResult.dimension }}</span>
</div>
</div>
<div v-if="testResult.response" class="detail-card response-card">
<div class="detail-icon">
<el-icon><ChatDotRound /></el-icon>
</div>
<div class="detail-info">
<span class="detail-label">模型响应</span>
<span class="detail-value response-text">{{ testResult.response }}</span>
</div>
</div>
<div v-if="testResult.latency_ms" class="detail-card">
<div class="detail-icon">
<el-icon><Timer /></el-icon>
</div>
<div class="detail-info">
<span class="detail-label">响应延迟</span>
<span class="detail-value">{{ testResult.latency_ms.toFixed(2) }} ms</span>
</div>
</div>
<template v-if="showTokenStats && testResult.total_tokens !== undefined">
<div class="detail-card">
<div class="detail-icon">
<el-icon><Document /></el-icon>
</div>
<div class="detail-info">
<span class="detail-label">输入 Token</span>
<span class="detail-value">{{ testResult.prompt_tokens || 0 }}</span>
</div>
</div>
<div class="detail-card">
<div class="detail-icon">
<el-icon><EditPen /></el-icon>
</div>
<div class="detail-info">
<span class="detail-label">输出 Token</span>
<span class="detail-value">{{ testResult.completion_tokens || 0 }}</span>
</div>
</div>
<div class="detail-card">
<div class="detail-icon">
<el-icon><DataAnalysis /></el-icon>
</div>
<div class="detail-info">
<span class="detail-label"> Token</span>
<span class="detail-value">{{ testResult.total_tokens }}</span>
</div>
</div>
</template>
</div>
</div>
<div v-else class="error-result">
<div class="result-header">
<div class="error-icon">
<el-icon><CircleClose /></el-icon>
</div>
<span class="result-title error">连接失败</span>
</div>
<div class="error-message-box">
<p class="error-text">{{ testResult.error || '未知错误' }}</p>
</div>
<div class="troubleshooting">
<div class="troubleshoot-header">
<el-icon><Warning /></el-icon>
<span>排查建议</span>
</div>
<ul class="troubleshoot-list">
<li v-for="(tip, index) in troubleshootingTips" :key="index">
<el-icon class="list-icon"><Right /></el-icon>
{{ tip }}
</li>
</ul>
</div>
</div>
</div>
</transition>
</div>
</el-card>
</template>
<script setup lang="ts">
import { ref, computed } from 'vue'
import { Connection, Edit, CircleCheck, CircleClose, Timer, Grid, Warning, Right, ChatDotRound, Document, EditPen, DataAnalysis } from '@element-plus/icons-vue'
export interface TestResult {
success: boolean
dimension?: number
latency_ms?: number
message?: string
error?: string
response?: string
prompt_tokens?: number
completion_tokens?: number
total_tokens?: number
}
const props = withDefaults(
defineProps<{
testFn: (input?: string) => Promise<TestResult>
canTest?: boolean
title?: string
inputLabel?: string
inputPlaceholder?: string
showTokenStats?: boolean
}>(),
{
canTest: true,
title: '连接测试',
inputLabel: '测试输入',
inputPlaceholder: '请输入测试内容(可选,默认使用系统预设内容)',
showTokenStats: false
}
)
const loading = ref(false)
const testResult = ref<TestResult | null>(null)
const testInputValue = ref('')
// Derive context-sensitive troubleshooting hints from the error text.
// Matching is keyword-based over the lowercased message (both English and
// Chinese keywords), falling back to a generic checklist.
const troubleshootingTips = computed(() => {
  const tips: string[] = []
  const error = testResult.value?.error?.toLowerCase() || ''
  if (error.includes('timeout') || error.includes('超时')) {
    tips.push('检查网络连接是否正常')
    tips.push('确认服务地址是否正确且可访问')
    tips.push('尝试增加请求超时时间')
  } else if (error.includes('auth') || error.includes('unauthorized') || error.includes('认证') || error.includes('api key')) {
    tips.push('检查 API Key 是否正确')
    tips.push('确认 API Key 是否已过期或被禁用')
    tips.push('验证 API Key 是否具有足够的权限')
  } else if (error.includes('connection') || error.includes('连接') || error.includes('refused')) {
    tips.push('确认服务地址host/port配置正确')
    tips.push('检查目标服务是否正在运行')
    tips.push('验证防火墙是否允许访问')
  } else if (error.includes('model') || error.includes('模型')) {
    tips.push('确认模型名称是否正确')
    tips.push('检查模型是否已部署或可用')
    tips.push('验证模型配置参数是否符合要求')
  } else {
    tips.push('检查所有配置参数是否正确')
    tips.push('确认服务是否正常运行')
    tips.push('查看服务端日志获取详细错误信息')
  }
  return tips
})
// Run the injected test function and normalize any thrown error into a
// failed TestResult so the template has a single rendering path.
const handleTest = async () => {
  loading.value = true
  testResult.value = null
  try {
    // Empty input is sent as undefined so the backend applies its default.
    const input = testInputValue.value?.trim() || undefined
    const result = await props.testFn(input)
    testResult.value = result
  } catch (error: any) {
    testResult.value = {
      success: false,
      error: error?.message || '请求失败,请检查网络连接'
    }
  } finally {
    loading.value = false
  }
}
</script>
<style scoped>
.test-panel {
border-radius: 16px;
border: none;
background: rgba(255, 255, 255, 0.98);
backdrop-filter: blur(10px);
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
transition: all 0.3s ease;
}
.test-panel:hover {
box-shadow: 0 12px 48px rgba(0, 0, 0, 0.15);
transform: translateY(-4px);
}
.card-header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0;
}
.header-left {
display: flex;
align-items: center;
gap: 12px;
}
.icon-wrapper {
width: 40px;
height: 40px;
display: flex;
align-items: center;
justify-content: center;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 10px;
color: #ffffff;
font-size: 20px;
}
.header-title {
font-size: 16px;
font-weight: 600;
color: #303133;
}
.test-content {
padding: 8px 0;
}
.test-form-section {
display: flex;
flex-direction: column;
gap: 16px;
}
.section-label {
display: flex;
align-items: center;
gap: 8px;
font-size: 14px;
font-weight: 600;
color: #606266;
}
.section-label .el-icon {
color: #667eea;
}
.test-textarea {
border-radius: 10px;
}
.test-textarea :deep(.el-textarea__inner) {
border-radius: 10px;
border: 1px solid #dcdfe6;
transition: all 0.3s ease;
}
.test-textarea :deep(.el-textarea__inner:focus) {
border-color: #667eea;
box-shadow: 0 0 0 2px rgba(102, 126, 234, 0.2);
}
.test-button {
align-self: flex-start;
border-radius: 10px;
padding: 12px 24px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: none;
transition: all 0.3s ease;
}
.test-button:hover:not(:disabled) {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
}
.test-button:disabled {
opacity: 0.6;
}
.test-result {
animation: fadeIn 0.4s ease-out;
}
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.result-header {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 16px;
}
.success-icon,
.error-icon {
width: 36px;
height: 36px;
display: flex;
align-items: center;
justify-content: center;
border-radius: 50%;
font-size: 20px;
}
.success-icon {
background: linear-gradient(135deg, #67c23a 0%, #85ce61 100%);
color: #ffffff;
}
.error-icon {
background: linear-gradient(135deg, #f56c6c 0%, #f89898 100%);
color: #ffffff;
}
.result-title {
font-size: 16px;
font-weight: 600;
color: #67c23a;
}
.result-title.error {
color: #f56c6c;
}
.success-details {
display: flex;
gap: 16px;
flex-wrap: wrap;
}
.detail-card {
display: flex;
align-items: center;
gap: 12px;
padding: 14px 18px;
background: linear-gradient(135deg, #f0f9eb 0%, #e1f3d8 100%);
border-radius: 12px;
border: 1px solid #e1f3d8;
}
.response-card {
flex: 1;
min-width: 200px;
}
.detail-icon {
width: 40px;
height: 40px;
display: flex;
align-items: center;
justify-content: center;
background: linear-gradient(135deg, #67c23a 0%, #85ce61 100%);
border-radius: 10px;
color: #ffffff;
font-size: 18px;
}
.detail-info {
display: flex;
flex-direction: column;
}
.detail-label {
font-size: 12px;
color: #909399;
}
.detail-value {
font-size: 18px;
font-weight: 700;
color: #303133;
}
.response-text {
font-size: 14px;
font-weight: 500;
max-width: 300px;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.error-result {
animation: shake 0.5s ease-out;
}
@keyframes shake {
0%, 100% { transform: translateX(0); }
25% { transform: translateX(-5px); }
75% { transform: translateX(5px); }
}
.error-message-box {
padding: 14px 16px;
background: linear-gradient(135deg, #fef0f0 0%, #fde2e2 100%);
border-radius: 10px;
border-left: 3px solid #f56c6c;
margin-bottom: 16px;
}
.error-text {
margin: 0;
color: #f56c6c;
font-size: 14px;
line-height: 1.6;
}
.troubleshooting {
padding: 16px;
background: linear-gradient(135deg, #fdf6ec 0%, #faecd8 100%);
border-radius: 12px;
border: 1px solid #faecd8;
}
.troubleshoot-header {
display: flex;
align-items: center;
gap: 8px;
margin-bottom: 12px;
font-weight: 600;
color: #e6a23c;
}
.troubleshoot-list {
margin: 0;
padding: 0;
list-style: none;
}
.troubleshoot-list li {
display: flex;
align-items: flex-start;
gap: 8px;
margin-bottom: 8px;
color: #606266;
font-size: 13px;
line-height: 1.6;
}
.list-icon {
margin-top: 4px;
color: #e6a23c;
font-size: 12px;
}
.result-fade-enter-active {
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
}
.result-fade-leave-active {
transition: all 0.3s cubic-bezier(1, 0.5, 0.8, 1);
}
.result-fade-enter-from {
opacity: 0;
transform: translateY(20px);
}
.result-fade-leave-to {
opacity: 0;
transform: translateY(-10px);
}
</style>

View File

@ -0,0 +1,219 @@
<template>
<el-form
ref="formRef"
:model="formData"
:rules="formRules"
:label-width="labelWidth"
v-bind="$attrs"
>
<el-form-item
v-for="(field, key) in schemaProperties"
:key="key"
:label="field.title || key"
:prop="key"
>
<template #label>
<span>{{ field.title || key }}</span>
<el-tooltip v-if="field.description" :content="field.description" placement="top">
<el-icon class="ml-1 cursor-help"><QuestionFilled /></el-icon>
</el-tooltip>
</template>
<el-input
v-if="field.type === 'string'"
v-model="formData[key]"
:placeholder="`请输入${field.title || key}`"
clearable
:show-password="isPasswordField(key)"
/>
<el-input-number
v-else-if="field.type === 'integer' || field.type === 'number'"
v-model="formData[key]"
:placeholder="`请输入${field.title || key}`"
:min="field.minimum"
:max="field.maximum"
:step="field.type === 'number' ? 0.1 : 1"
:precision="field.type === 'number' ? 2 : 0"
controls-position="right"
class="w-full"
/>
<el-switch
v-else-if="field.type === 'boolean'"
v-model="formData[key]"
/>
<el-select
v-else-if="field.enum && field.enum.length > 0"
v-model="formData[key]"
:placeholder="`请选择${field.title || key}`"
clearable
class="w-full"
>
<el-option
v-for="option in field.enum"
:key="option"
:label="option"
:value="option"
/>
</el-select>
</el-form-item>
</el-form>
</template>
<script setup lang="ts">
import { ref, computed, watch, onMounted } from 'vue'
import { QuestionFilled } from '@element-plus/icons-vue'
import type { FormInstance, FormRules } from 'element-plus'
interface SchemaProperty {
type: string
title?: string
description?: string
default?: any
enum?: string[]
minimum?: number
maximum?: number
required?: boolean
}
interface ConfigSchema {
type?: string
properties?: Record<string, SchemaProperty>
required?: string[]
}
const props = defineProps<{
schema: ConfigSchema
modelValue: Record<string, any>
labelWidth?: string
}>()
const emit = defineEmits<{
(e: 'update:modelValue', value: Record<string, any>): void
}>()
const formRef = ref<FormInstance>()
const formData = ref<Record<string, any>>({})
const schemaProperties = computed(() => {
return props.schema?.properties || {}
})
const requiredFields = computed(() => {
const required = props.schema?.required || []
const propsRequired = Object.entries(schemaProperties.value)
.filter(([, field]) => field.required)
.map(([key]) => key)
return [...new Set([...required, ...propsRequired])]
})
/**
 * Element Plus validation rules derived from the schema properties.
 * - required: field listed in schema.required OR flagged field.required
 * - string length bounds reuse the schema's minimum/maximum values
 *   (NOTE(review): JSON Schema uses minLength/maxLength for strings —
 *   confirm the backend really encodes string lengths as minimum/maximum).
 */
const formRules = computed<FormRules>(() => {
  const rules: FormRules = {}
  Object.entries(schemaProperties.value).forEach(([key, field]) => {
    const fieldRules: any[] = []
    if (requiredFields.value.includes(key)) {
      fieldRules.push({
        required: true,
        message: `${field.title || key}不能为空`,
        trigger: ['blur', 'change']
      })
    }
    if (field.type === 'string' && field.minimum !== undefined) {
      fieldRules.push({
        min: field.minimum,
        message: `${field.title || key}长度不能小于${field.minimum}`,
        trigger: ['blur']
      })
    }
    if (field.type === 'string' && field.maximum !== undefined) {
      fieldRules.push({
        max: field.maximum,
        message: `${field.title || key}长度不能大于${field.maximum}`,
        trigger: ['blur']
      })
    }
    // The original `if (rules[key])` branch was dead code: rules[key] is
    // never assigned before this point, so only the length check matters.
    if (fieldRules.length > 0) {
      rules[key] = fieldRules
    }
  })
  return rules
})
/** Heuristic: a key mentioning password/secret/key/token is rendered masked. */
const isPasswordField = (key: string): boolean => {
  const sensitiveNeedles = ['password', 'secret', 'key', 'token']
  const normalized = key.toLowerCase()
  return sensitiveNeedles.some((needle) => normalized.includes(needle))
}
// Build the form model. Per-field precedence: caller-provided modelValue,
// then the schema's declared default, then a per-type zero value.
const initFormData = () => {
  const data: Record<string, any> = {}
  Object.entries(schemaProperties.value).forEach(([key, field]) => {
    if (props.modelValue && props.modelValue[key] !== undefined) {
      // Caller already supplied a value for this field — keep it.
      data[key] = props.modelValue[key]
    } else if (field.default !== undefined) {
      // Fall back to the schema-declared default.
      data[key] = field.default
    } else {
      // No value anywhere: synthesise an empty value appropriate to the type.
      switch (field.type) {
        case 'string':
          data[key] = ''
          break
        case 'integer':
        case 'number':
          // Start at the declared lower bound so el-input-number begins valid.
          data[key] = field.minimum ?? 0
          break
        case 'boolean':
          data[key] = false
          break
        default:
          data[key] = null
      }
    }
  })
  formData.value = data
}
// Re-initialise when the parent pushes a genuinely different value.
// Compare serialised content first: the emit below hands the parent an
// equal-but-new object, and re-running initFormData unconditionally on
// every modelValue replacement creates a modelValue -> formData -> emit
// feedback loop that re-triggers itself indefinitely.
watch(
  () => props.modelValue,
  (val) => {
    if (JSON.stringify(val ?? {}) !== JSON.stringify(formData.value)) {
      initFormData()
    }
  },
  { deep: true }
)
// A schema change always rebuilds the model (the field set may differ).
watch(
  () => props.schema,
  () => {
    initFormData()
  },
  { deep: true }
)
// Propagate every local edit up to the parent (v-model contract).
watch(
  formData,
  (val) => {
    emit('update:modelValue', val)
  },
  { deep: true }
)
// Populate the model on first mount (watchers cover later prop changes).
onMounted(() => {
  initFormData()
})
// Surface the underlying el-form methods so parents can trigger
// validation / reset without reaching into component internals.
defineExpose({
  validate: () => formRef.value?.validate(),
  resetFields: () => formRef.value?.resetFields(),
  clearValidate: () => formRef.value?.clearValidate()
})
</script>
<style scoped>
.w-full {
width: 100%;
}
.ml-1 {
margin-left: 4px;
}
.cursor-help {
cursor: help;
}
</style>

View File

@ -0,0 +1,78 @@
<template>
<el-select
:model-value="modelValue"
:loading="loading"
:placeholder="placeholder"
:disabled="disabled"
:clearable="clearable"
:teleported="true"
:popper-options="{ modifiers: [{ name: 'flip', enabled: true }, { name: 'preventOverflow', enabled: true }] }"
@update:model-value="handleChange"
>
<el-option
v-for="provider in providers"
:key="provider.name"
:label="provider.display_name"
:value="provider.name"
>
<div class="provider-option">
<span class="provider-name">{{ provider.display_name }}</span>
<span v-if="provider.description" class="provider-desc">{{ provider.description }}</span>
</div>
</el-option>
</el-select>
</template>
<script setup lang="ts">
import type { EmbeddingProviderInfo } from '@/types/embedding'
const props = withDefaults(
defineProps<{
modelValue?: string
providers: EmbeddingProviderInfo[]
loading?: boolean
disabled?: boolean
clearable?: boolean
placeholder?: string
}>(),
{
modelValue: '',
loading: false,
disabled: false,
clearable: false,
placeholder: '请选择嵌入模型提供者'
}
)
const emit = defineEmits<{
'update:modelValue': [value: string]
change: [provider: EmbeddingProviderInfo | undefined]
}>()
/** Sync v-model and hand the resolved provider record to listeners. */
const handleChange = (value: string) => {
  const resolved = props.providers.find((entry) => entry.name === value)
  emit('update:modelValue', value)
  emit('change', resolved)
}
</script>
<style scoped>
.provider-option {
display: flex;
flex-direction: column;
line-height: 1.5;
padding: 4px 0;
}
.provider-name {
font-weight: 500;
color: var(--text-primary);
}
.provider-desc {
font-size: 12px;
color: var(--text-secondary);
margin-top: 2px;
line-height: 1.4;
}
</style>

View File

@ -0,0 +1,460 @@
<template>
<el-card shadow="hover" class="test-panel">
<template #header>
<div class="card-header">
<div class="header-left">
<div class="icon-wrapper">
<el-icon><Connection /></el-icon>
</div>
<span class="header-title">连接测试</span>
</div>
<el-tag v-if="testResult" :type="testResult.success ? 'success' : 'danger'" size="small" effect="dark">
{{ testResult.success ? '连接成功' : '连接失败' }}
</el-tag>
</div>
</template>
<div class="test-content">
<div class="test-form-section">
<div class="section-label">
<el-icon><Edit /></el-icon>
<span>测试文本</span>
</div>
<el-input
v-model="testForm.test_text"
type="textarea"
:rows="3"
placeholder="请输入测试文本(可选,默认使用系统预设文本)"
clearable
class="test-textarea"
/>
<el-button
type="primary"
size="large"
:loading="loading"
:disabled="!config?.provider"
class="test-button"
@click="handleTest"
>
<el-icon v-if="!loading"><Connection /></el-icon>
{{ loading ? '测试中...' : '测试连接' }}
</el-button>
</div>
<transition name="result-fade">
<div v-if="testResult" class="test-result">
<el-divider />
<div v-if="testResult.success" class="success-result">
<div class="result-header">
<div class="success-icon">
<el-icon><CircleCheck /></el-icon>
</div>
<span class="result-title">{{ testResult.message || '连接成功' }}</span>
</div>
<div class="success-details">
<div class="detail-card">
<div class="detail-icon">
<el-icon><Grid /></el-icon>
</div>
<div class="detail-info">
<span class="detail-label">向量维度</span>
<span class="detail-value">{{ testResult.dimension }}</span>
</div>
</div>
<div v-if="testResult.latency_ms" class="detail-card">
<div class="detail-icon">
<el-icon><Timer /></el-icon>
</div>
<div class="detail-info">
<span class="detail-label">响应延迟</span>
<span class="detail-value">{{ testResult.latency_ms.toFixed(2) }} ms</span>
</div>
</div>
</div>
</div>
<div v-else class="error-result">
<div class="result-header">
<div class="error-icon">
<el-icon><CircleClose /></el-icon>
</div>
<span class="result-title error">连接失败</span>
</div>
<div class="error-message-box">
<p class="error-text">{{ testResult.error || '未知错误' }}</p>
</div>
<div class="troubleshooting">
<div class="troubleshoot-header">
<el-icon><Warning /></el-icon>
<span>排查建议</span>
</div>
<ul class="troubleshoot-list">
<li v-for="(tip, index) in troubleshootingTips" :key="index">
<el-icon class="list-icon"><Right /></el-icon>
{{ tip }}
</li>
</ul>
</div>
</div>
</div>
</transition>
</div>
</el-card>
</template>
<script setup lang="ts">
import { ref, computed } from 'vue'
import { Connection, Edit, CircleCheck, CircleClose, Timer, Grid, Warning, Right } from '@element-plus/icons-vue'
import { testEmbedding, type EmbeddingConfigUpdate, type EmbeddingTestResult } from '@/api/embedding'
const props = defineProps<{
config: EmbeddingConfigUpdate | null
}>()
const loading = ref(false)
const testResult = ref<EmbeddingTestResult | null>(null)
const testForm = ref({
test_text: ''
})
/**
 * Map the last test error to a short remediation checklist.
 * Categories are matched in priority order — timeout, auth, connection,
 * model — with a generic fallback when nothing matches.
 */
const troubleshootingTips = computed(() => {
  const error = testResult.value?.error?.toLowerCase() || ''
  const matches = (...needles: string[]) => needles.some((n) => error.includes(n))
  if (matches('timeout', '超时')) {
    return ['检查网络连接是否正常', '确认服务地址是否正确且可访问', '尝试增加请求超时时间']
  }
  if (matches('auth', 'unauthorized', '认证', 'api key')) {
    return ['检查 API Key 是否正确', '确认 API Key 是否已过期或被禁用', '验证 API Key 是否具有足够的权限']
  }
  if (matches('connection', '连接', 'refused')) {
    return ['确认服务地址host/port配置正确', '检查目标服务是否正在运行', '验证防火墙是否允许访问']
  }
  if (matches('model', '模型')) {
    return ['确认模型名称是否正确', '检查模型是否已部署或可用', '验证模型配置参数是否符合要求']
  }
  return ['检查所有配置参数是否正确', '确认服务是否正常运行', '查看服务端日志获取详细错误信息']
})
// Run a one-shot embedding call against the currently edited (unsaved)
// config. Results land in testResult; transport failures are normalised
// into the same shape so the panel can render troubleshooting hints
// instead of throwing.
const handleTest = async () => {
  if (!props.config?.provider) {
    // No provider selected — the button should be disabled, but guard anyway.
    return
  }
  loading.value = true
  testResult.value = null
  try {
    const requestData: any = {
      config: props.config
    }
    if (testForm.value.test_text?.trim()) {
      // Only send the override text when the user typed something;
      // otherwise the backend falls back to its preset sample.
      requestData.test_text = testForm.value.test_text.trim()
    }
    const result = await testEmbedding(requestData)
    testResult.value = result
  } catch (error: any) {
    // Normalise request-level failures into a failed test result.
    testResult.value = {
      success: false,
      dimension: 0,
      error: error?.message || '请求失败,请检查网络连接'
    }
  } finally {
    loading.value = false
  }
}
</script>
<style scoped>
.test-panel {
border-radius: 16px;
border: none;
background: rgba(255, 255, 255, 0.98);
backdrop-filter: blur(10px);
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
transition: all 0.3s ease;
}
.test-panel:hover {
box-shadow: 0 12px 48px rgba(0, 0, 0, 0.15);
transform: translateY(-4px);
}
.card-header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0;
}
.header-left {
display: flex;
align-items: center;
gap: 12px;
}
.icon-wrapper {
width: 40px;
height: 40px;
display: flex;
align-items: center;
justify-content: center;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 10px;
color: #ffffff;
font-size: 20px;
}
.header-title {
font-size: 16px;
font-weight: 600;
color: #303133;
}
.test-content {
padding: 8px 0;
}
.test-form-section {
display: flex;
flex-direction: column;
gap: 16px;
}
.section-label {
display: flex;
align-items: center;
gap: 8px;
font-size: 14px;
font-weight: 600;
color: #606266;
}
.section-label .el-icon {
color: #667eea;
}
.test-textarea {
border-radius: 10px;
}
.test-textarea :deep(.el-textarea__inner) {
border-radius: 10px;
border: 1px solid #dcdfe6;
transition: all 0.3s ease;
}
.test-textarea :deep(.el-textarea__inner:focus) {
border-color: #667eea;
box-shadow: 0 0 0 2px rgba(102, 126, 234, 0.2);
}
.test-button {
align-self: flex-start;
border-radius: 10px;
padding: 12px 24px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: none;
transition: all 0.3s ease;
}
.test-button:hover:not(:disabled) {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
}
.test-button:disabled {
opacity: 0.6;
}
.test-result {
animation: fadeIn 0.4s ease-out;
}
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.result-header {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 16px;
}
.success-icon,
.error-icon {
width: 36px;
height: 36px;
display: flex;
align-items: center;
justify-content: center;
border-radius: 50%;
font-size: 20px;
}
.success-icon {
background: linear-gradient(135deg, #67c23a 0%, #85ce61 100%);
color: #ffffff;
}
.error-icon {
background: linear-gradient(135deg, #f56c6c 0%, #f89898 100%);
color: #ffffff;
}
.result-title {
font-size: 16px;
font-weight: 600;
color: #67c23a;
}
.result-title.error {
color: #f56c6c;
}
.success-details {
display: flex;
gap: 16px;
flex-wrap: wrap;
}
.detail-card {
display: flex;
align-items: center;
gap: 12px;
padding: 14px 18px;
background: linear-gradient(135deg, #f0f9eb 0%, #e1f3d8 100%);
border-radius: 12px;
border: 1px solid #e1f3d8;
}
.detail-icon {
width: 40px;
height: 40px;
display: flex;
align-items: center;
justify-content: center;
background: linear-gradient(135deg, #67c23a 0%, #85ce61 100%);
border-radius: 10px;
color: #ffffff;
font-size: 18px;
}
.detail-info {
display: flex;
flex-direction: column;
}
.detail-label {
font-size: 12px;
color: #909399;
}
.detail-value {
font-size: 18px;
font-weight: 700;
color: #303133;
}
.error-result {
animation: shake 0.5s ease-out;
}
@keyframes shake {
0%, 100% { transform: translateX(0); }
25% { transform: translateX(-5px); }
75% { transform: translateX(5px); }
}
.error-message-box {
padding: 14px 16px;
background: linear-gradient(135deg, #fef0f0 0%, #fde2e2 100%);
border-radius: 10px;
border-left: 3px solid #f56c6c;
margin-bottom: 16px;
}
.error-text {
margin: 0;
color: #f56c6c;
font-size: 14px;
line-height: 1.6;
}
.troubleshooting {
padding: 16px;
background: linear-gradient(135deg, #fdf6ec 0%, #faecd8 100%);
border-radius: 12px;
border: 1px solid #faecd8;
}
.troubleshoot-header {
display: flex;
align-items: center;
gap: 8px;
margin-bottom: 12px;
font-weight: 600;
color: #e6a23c;
}
.troubleshoot-list {
margin: 0;
padding: 0;
list-style: none;
}
.troubleshoot-list li {
display: flex;
align-items: flex-start;
gap: 8px;
margin-bottom: 8px;
color: #606266;
font-size: 13px;
line-height: 1.6;
}
.list-icon {
margin-top: 4px;
color: #e6a23c;
font-size: 12px;
}
.result-fade-enter-active {
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
}
.result-fade-leave-active {
transition: all 0.3s cubic-bezier(1, 0.5, 0.8, 1);
}
.result-fade-enter-from {
opacity: 0;
transform: translateY(20px);
}
.result-fade-leave-to {
opacity: 0;
transform: translateY(-10px);
}
</style>

View File

@ -0,0 +1,161 @@
<template>
<div class="supported-formats">
<div v-loading="loading" class="formats-content">
<transition-group name="tag-fade" tag="div" class="formats-grid">
<el-tooltip
v-for="format in formats"
:key="format.extension"
:content="format.description"
placement="top"
:disabled="!format.description"
effect="light"
>
<div class="format-item">
<div class="format-icon">
<span class="extension">{{ format.extension }}</span>
</div>
<div class="format-info">
<span class="format-name">{{ format.name }}</span>
</div>
</div>
</el-tooltip>
</transition-group>
<el-empty v-if="!loading && formats.length === 0" description="暂无支持的格式" :image-size="80">
<template #image>
<div class="empty-icon">
<el-icon><Document /></el-icon>
</div>
</template>
</el-empty>
</div>
</div>
</template>
<script setup lang="ts">
import { computed, onMounted } from 'vue'
import { Document } from '@element-plus/icons-vue'
import { useEmbeddingStore } from '@/stores/embedding'
const embeddingStore = useEmbeddingStore()
// Store-backed state: the shared store caches formats across mounts.
const formats = computed(() => embeddingStore.formats)
const loading = computed(() => embeddingStore.formatsLoading)
onMounted(() => {
  // Lazy-load once; a previous mount may already have populated the cache.
  if (formats.value.length === 0) {
    embeddingStore.loadFormats()
  }
})
</script>
<style scoped>
.supported-formats {
padding: 8px 0;
}
.formats-content {
min-height: 60px;
}
.formats-grid {
display: flex;
flex-wrap: wrap;
gap: 12px;
}
.format-item {
display: flex;
align-items: center;
gap: 10px;
padding: 10px 14px;
background: linear-gradient(135deg, #f8f9fc 0%, #eef1f5 100%);
border-radius: 10px;
border: 1px solid #e4e7ed;
cursor: default;
transition: all 0.3s ease;
}
.format-item:hover {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-color: transparent;
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
}
.format-item:hover .extension,
.format-item:hover .format-name {
color: #ffffff;
}
.format-icon {
width: 36px;
height: 36px;
display: flex;
align-items: center;
justify-content: center;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 8px;
}
.format-item:hover .format-icon {
background: rgba(255, 255, 255, 0.2);
}
.extension {
font-size: 11px;
font-weight: 700;
color: #ffffff;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.format-info {
display: flex;
flex-direction: column;
}
.format-name {
font-size: 13px;
font-weight: 600;
color: #303133;
transition: color 0.3s ease;
}
.empty-icon {
width: 80px;
height: 80px;
display: flex;
align-items: center;
justify-content: center;
background: linear-gradient(135deg, #f5f7fa 0%, #e8ecf1 100%);
border-radius: 50%;
margin: 0 auto;
}
.empty-icon .el-icon {
font-size: 40px;
color: #c0c4cc;
}
.tag-fade-enter-active {
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
}
.tag-fade-leave-active {
transition: all 0.3s cubic-bezier(1, 0.5, 0.8, 1);
}
.tag-fade-enter-from {
opacity: 0;
transform: scale(0.8);
}
.tag-fade-leave-to {
opacity: 0;
transform: scale(0.8);
}
.tag-fade-move {
transition: transform 0.3s ease;
}
</style>

View File

@ -0,0 +1,351 @@
<template>
<el-card shadow="hover" class="ai-response-viewer">
<template #header>
<div class="card-header">
<div class="header-left">
<div class="icon-wrapper">
<el-icon><ChatDotRound /></el-icon>
</div>
<span class="header-title">AI 回复</span>
</div>
<el-tag v-if="response" type="success" size="small" effect="dark">
已生成
</el-tag>
</div>
</template>
<div class="response-content">
<div v-if="!response" class="placeholder-text">
<el-icon class="placeholder-icon"><Document /></el-icon>
<p>运行实验后将在此显示 AI 回复</p>
</div>
<template v-else>
<div class="markdown-content" v-html="renderedContent"></div>
<el-divider />
<div class="stats-section">
<div class="section-label">
<el-icon><DataAnalysis /></el-icon>
<span>统计信息</span>
</div>
<div class="stats-grid">
<div v-if="response.model" class="stat-card">
<div class="stat-icon model-icon">
<el-icon><Cpu /></el-icon>
</div>
<div class="stat-info">
<span class="stat-label">模型</span>
<span class="stat-value">{{ response.model }}</span>
</div>
</div>
<div v-if="response.latency_ms" class="stat-card">
<div class="stat-icon latency-icon">
<el-icon><Timer /></el-icon>
</div>
<div class="stat-info">
<span class="stat-label">响应耗时</span>
<span class="stat-value">{{ response.latency_ms.toFixed(2) }} ms</span>
</div>
</div>
<div v-if="response.prompt_tokens" class="stat-card">
<div class="stat-icon prompt-icon">
<el-icon><EditPen /></el-icon>
</div>
<div class="stat-info">
<span class="stat-label">Prompt Tokens</span>
<span class="stat-value">{{ response.prompt_tokens }}</span>
</div>
</div>
<div v-if="response.completion_tokens" class="stat-card">
<div class="stat-icon completion-icon">
<el-icon><DocumentCopy /></el-icon>
</div>
<div class="stat-info">
<span class="stat-label">Completion Tokens</span>
<span class="stat-value">{{ response.completion_tokens }}</span>
</div>
</div>
<div v-if="response.total_tokens" class="stat-card highlight">
<div class="stat-icon total-icon">
<el-icon><Coin /></el-icon>
</div>
<div class="stat-info">
<span class="stat-label">Total Tokens</span>
<span class="stat-value">{{ response.total_tokens }}</span>
</div>
</div>
</div>
</div>
</template>
</div>
</el-card>
</template>
<script setup lang="ts">
import { computed } from 'vue'
import { ChatDotRound, Document, DataAnalysis, Timer, EditPen, DocumentCopy, Coin, Cpu } from '@element-plus/icons-vue'
import type { AIResponse } from '@/api/rag'
const props = defineProps<{
  response: AIResponse | null
}>()
// Pre-render the markdown once per response change; consumed by v-html.
const renderedContent = computed(() => {
  if (!props.response?.content) return ''
  return renderMarkdown(props.response.content)
})
/**
 * Minimal markdown-to-HTML renderer for AI responses.
 *
 * Escapes HTML special characters FIRST so model output cannot inject
 * markup or scripts through the v-html sink (XSS hardening) — only the
 * tags generated by the substitutions below end up as real HTML. Then
 * applies regex substitutions for code fences, inline code, headings,
 * emphasis, list items and paragraphs.
 */
const renderMarkdown = (text: string): string => {
  // Treat everything the model sent as data: escape &, <, > up front.
  // '&' must be escaped first or the later entities get double-escaped.
  let html = text
    .replace(/&/g, '&amp;')
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
  html = html.replace(/```(\w*)\n([\s\S]*?)```/g, '<pre><code class="language-$1">$2</code></pre>')
  html = html.replace(/`([^`]+)`/g, '<code class="inline-code">$1</code>')
  html = html.replace(/^### (.+)$/gm, '<h3>$1</h3>')
  html = html.replace(/^## (.+)$/gm, '<h2>$1</h2>')
  html = html.replace(/^# (.+)$/gm, '<h1>$1</h1>')
  html = html.replace(/\*\*([^*]+)\*\*/g, '<strong>$1</strong>')
  html = html.replace(/\*([^*]+)\*/g, '<em>$1</em>')
  html = html.replace(/^- (.+)$/gm, '<li>$1</li>')
  html = html.replace(/^\d+\. (.+)$/gm, '<li>$1</li>')
  html = html.replace(/\n\n/g, '</p><p>')
  html = html.replace(/\n/g, '<br>')
  return `<p>${html}</p>`
}
</script>
<style scoped>
.ai-response-viewer {
border-radius: 16px;
border: none;
background: rgba(255, 255, 255, 0.98);
backdrop-filter: blur(10px);
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
transition: all 0.3s ease;
}
.ai-response-viewer:hover {
box-shadow: 0 12px 48px rgba(0, 0, 0, 0.15);
}
.card-header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0;
}
.header-left {
display: flex;
align-items: center;
gap: 12px;
}
.icon-wrapper {
width: 40px;
height: 40px;
display: flex;
align-items: center;
justify-content: center;
background: linear-gradient(135deg, #36d1dc 0%, #5b86e5 100%);
border-radius: 10px;
color: #ffffff;
font-size: 20px;
}
.header-title {
font-size: 16px;
font-weight: 600;
color: #303133;
}
.placeholder-text {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 60px 20px;
color: #909399;
}
.placeholder-icon {
font-size: 48px;
margin-bottom: 16px;
opacity: 0.5;
}
.placeholder-text p {
margin: 0;
font-size: 14px;
}
.markdown-content {
padding: 16px;
background: #f8f9fa;
border-radius: 12px;
line-height: 1.8;
color: #303133;
max-height: 400px;
overflow-y: auto;
}
.markdown-content :deep(h1) {
font-size: 24px;
font-weight: 700;
margin: 16px 0 12px;
color: #303133;
}
.markdown-content :deep(h2) {
font-size: 20px;
font-weight: 600;
margin: 14px 0 10px;
color: #303133;
}
.markdown-content :deep(h3) {
font-size: 16px;
font-weight: 600;
margin: 12px 0 8px;
color: #303133;
}
.markdown-content :deep(pre) {
background: #1e1e1e;
border-radius: 8px;
padding: 16px;
overflow-x: auto;
margin: 12px 0;
}
.markdown-content :deep(code) {
font-family: 'Consolas', 'Monaco', monospace;
font-size: 13px;
}
.markdown-content :deep(pre code) {
color: #d4d4d4;
}
.markdown-content :deep(.inline-code) {
background: #e8e8e8;
padding: 2px 6px;
border-radius: 4px;
font-size: 13px;
color: #e83e8c;
}
.markdown-content :deep(strong) {
font-weight: 600;
color: #303133;
}
.markdown-content :deep(em) {
font-style: italic;
color: #606266;
}
.markdown-content :deep(li) {
margin: 4px 0;
padding-left: 8px;
}
.stats-section {
margin-top: 8px;
}
.section-label {
display: flex;
align-items: center;
gap: 8px;
font-size: 14px;
font-weight: 600;
color: #606266;
margin-bottom: 16px;
}
.section-label .el-icon {
color: #5b86e5;
}
.stats-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
gap: 12px;
}
.stat-card {
display: flex;
align-items: center;
gap: 12px;
padding: 14px 16px;
background: linear-gradient(135deg, #f5f7fa 0%, #e8ecf1 100%);
border-radius: 12px;
border: 1px solid #e4e7ed;
transition: all 0.3s ease;
}
.stat-card:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.08);
}
.stat-card.highlight {
background: linear-gradient(135deg, #e8f4fd 0%, #d4e9f7 100%);
border-color: #b8d9f0;
}
.stat-icon {
width: 36px;
height: 36px;
display: flex;
align-items: center;
justify-content: center;
border-radius: 8px;
color: #ffffff;
font-size: 16px;
}
.model-icon {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
.latency-icon {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
}
.prompt-icon {
background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
}
.completion-icon {
background: linear-gradient(135deg, #43e97b 0%, #38f9d7 100%);
}
.total-icon {
background: linear-gradient(135deg, #fa709a 0%, #fee140 100%);
}
.stat-info {
display: flex;
flex-direction: column;
}
.stat-label {
font-size: 12px;
color: #909399;
}
.stat-value {
font-size: 16px;
font-weight: 600;
color: #303133;
}
</style>

View File

@ -0,0 +1,159 @@
<template>
<el-select
:model-value="displayValue"
:loading="loading"
:placeholder="computedPlaceholder"
:disabled="disabled"
:clearable="clearable"
:teleported="true"
:popper-class="popperClass"
:popper-options="{ modifiers: [{ name: 'flip', enabled: true }, { name: 'preventOverflow', enabled: true }, { name: 'computeStyles', options: { adaptive: false, gpuAcceleration: false } }] }"
@update:model-value="handleChange"
@clear="handleClear"
>
<el-option
v-for="provider in providers"
:key="provider.name"
:label="provider.display_name"
:value="provider.name"
>
<div class="provider-option">
<div class="provider-info">
<span class="provider-name">{{ provider.display_name }}</span>
<span v-if="provider.description" class="provider-desc">{{ provider.description }}</span>
</div>
<el-tag v-if="provider.name === currentProvider" type="success" size="small" effect="plain" class="current-tag">
当前配置
</el-tag>
<el-tag v-else-if="provider.name === modelValue" type="primary" size="small" effect="plain" class="selected-tag">
已选择
</el-tag>
</div>
</el-option>
</el-select>
</template>
<script setup lang="ts">
import { computed } from 'vue'
import type { LLMProviderInfo } from '@/api/llm'
const popperClass = 'llm-selector-popper'
const props = withDefaults(
defineProps<{
modelValue?: string
providers: LLMProviderInfo[]
loading?: boolean
disabled?: boolean
clearable?: boolean
placeholder?: string
currentProvider?: string
}>(),
{
modelValue: '',
loading: false,
disabled: false,
clearable: false,
placeholder: '请选择 LLM 提供者',
currentProvider: ''
}
)
const emit = defineEmits<{
'update:modelValue': [value: string]
change: [provider: LLMProviderInfo | undefined]
}>()
// el-select binding: fall back to '' so the placeholder is shown when unset.
const displayValue = computed(() => {
  return props.modelValue || ''
})
/**
 * Placeholder for the select: when nothing is chosen but a current
 * provider is configured, hint at that default instead of the generic text.
 */
const computedPlaceholder = computed(() => {
  if (!props.modelValue && props.currentProvider) {
    const current = props.providers.find(p => p.name === props.currentProvider)
    return `默认: ${current?.display_name || props.currentProvider}`
  }
  return props.placeholder
})
// Selection changed: sync v-model and hand listeners the resolved provider
// record (undefined when the value matches no known provider).
const handleChange = (value: string) => {
  const resolved = props.providers.find((entry) => entry.name === value)
  emit('update:modelValue', value)
  emit('change', resolved)
}
// Clearable "x" pressed — reset to the unselected state.
const handleClear = () => {
  emit('update:modelValue', '')
  emit('change', undefined)
}
</script>
<style>
.llm-selector-popper {
min-width: 320px !important;
z-index: 9999 !important;
}
.llm-selector-popper .el-select-dropdown__wrap {
max-height: 400px;
}
.llm-selector-popper .el-select-dropdown__item {
height: auto;
padding: 8px 12px;
line-height: 1.5;
}
</style>
<style scoped>
.provider-option {
display: flex;
justify-content: space-between;
align-items: center;
width: 100%;
padding: 4px 0;
gap: 12px;
}
.provider-info {
display: flex;
flex-direction: column;
line-height: 1.5;
flex: 1;
min-width: 0;
overflow: hidden;
}
.provider-name {
font-weight: 500;
color: var(--text-primary);
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.provider-desc {
font-size: 12px;
color: var(--text-secondary);
margin-top: 2px;
line-height: 1.4;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.current-tag {
flex-shrink: 0;
margin-left: 8px;
white-space: nowrap;
}
.selected-tag {
flex-shrink: 0;
margin-left: 8px;
white-space: nowrap;
}
</style>

View File

@ -0,0 +1,299 @@
<template>
<el-card shadow="hover" class="stream-output">
<template #header>
<div class="card-header">
<div class="header-left">
<div class="icon-wrapper" :class="{ streaming: isStreaming }">
<el-icon><Promotion /></el-icon>
</div>
<span class="header-title">流式输出</span>
</div>
<div class="header-actions">
<el-tag v-if="isStreaming" type="warning" size="small" effect="dark" class="pulse-tag">
<el-icon class="is-loading"><Loading /></el-icon>
生成中...
</el-tag>
<el-tag v-else-if="hasContent" type="success" size="small" effect="dark">
已完成
</el-tag>
<el-tag v-else type="info" size="small" effect="plain">
等待中
</el-tag>
</div>
</div>
</template>
<div class="stream-content">
<div v-if="!hasContent && !isStreaming" class="placeholder-text">
<el-icon class="placeholder-icon"><ChatLineSquare /></el-icon>
<p>启用流式输出后AI 回复将实时显示</p>
</div>
<div v-else class="output-area">
<div class="stream-text" v-html="renderedContent"></div>
<div v-if="isStreaming" class="typing-indicator">
<span class="dot"></span>
<span class="dot"></span>
<span class="dot"></span>
</div>
</div>
<div v-if="error" class="error-section">
<el-alert
:title="error"
type="error"
:closable="false"
show-icon
/>
</div>
</div>
</el-card>
</template>
<script setup lang="ts">
import { computed } from 'vue'
import { Promotion, Loading, ChatLineSquare } from '@element-plus/icons-vue'
const props = defineProps<{
content: string
isStreaming: boolean
error?: string | null
}>()
// True once any streamed text has arrived. Strict boolean: the original
// `props.content && …` expression leaked '' (a falsy string) out of the
// computed, giving it the type string | boolean.
const hasContent = computed(() => props.content.length > 0)
// Markdown is re-rendered on each streamed chunk; consumed by v-html.
const renderedContent = computed(() => {
  if (!props.content) return ''
  return renderMarkdown(props.content)
})
/**
 * Minimal markdown-to-HTML renderer for the streamed AI output.
 *
 * Escapes HTML special characters FIRST so streamed model output cannot
 * inject markup or scripts through the v-html sink (XSS hardening) —
 * only the tags generated by the substitutions below are real HTML.
 */
const renderMarkdown = (text: string): string => {
  // Treat everything the model sent as data: escape &, <, > up front.
  // '&' must be escaped first or the later entities get double-escaped.
  let html = text
    .replace(/&/g, '&amp;')
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
  html = html.replace(/```(\w*)\n([\s\S]*?)```/g, '<pre><code class="language-$1">$2</code></pre>')
  html = html.replace(/`([^`]+)`/g, '<code class="inline-code">$1</code>')
  html = html.replace(/^### (.+)$/gm, '<h3>$1</h3>')
  html = html.replace(/^## (.+)$/gm, '<h2>$1</h2>')
  html = html.replace(/^# (.+)$/gm, '<h1>$1</h1>')
  html = html.replace(/\*\*([^*]+)\*\*/g, '<strong>$1</strong>')
  html = html.replace(/\*([^*]+)\*/g, '<em>$1</em>')
  html = html.replace(/^- (.+)$/gm, '<li>$1</li>')
  html = html.replace(/^\d+\. (.+)$/gm, '<li>$1</li>')
  html = html.replace(/\n\n/g, '</p><p>')
  html = html.replace(/\n/g, '<br>')
  return `<p>${html}</p>`
}
</script>
<style scoped>
/* Card shell: frosted-glass look with a hover elevation. */
.stream-output {
  border-radius: 16px;
  border: none;
  background: rgba(255, 255, 255, 0.98);
  backdrop-filter: blur(10px);
  box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
  transition: all 0.3s ease;
}
.stream-output:hover {
  box-shadow: 0 12px 48px rgba(0, 0, 0, 0.15);
}
/* Card header: title on the left, status tag on the right. */
.card-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
  padding: 0;
}
.header-left {
  display: flex;
  align-items: center;
  gap: 12px;
}
/* Gradient icon badge; pulses while a stream is active. */
.icon-wrapper {
  width: 40px;
  height: 40px;
  display: flex;
  align-items: center;
  justify-content: center;
  background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);
  border-radius: 10px;
  color: #ffffff;
  font-size: 20px;
  transition: all 0.3s ease;
}
.icon-wrapper.streaming {
  animation: pulse 1.5s ease-in-out infinite;
}
/* Expanding halo used by .icon-wrapper.streaming. */
@keyframes pulse {
  0%, 100% {
    transform: scale(1);
    box-shadow: 0 0 0 0 rgba(17, 153, 142, 0.4);
  }
  50% {
    transform: scale(1.05);
    box-shadow: 0 0 0 10px rgba(17, 153, 142, 0);
  }
}
.header-title {
  font-size: 16px;
  font-weight: 600;
  color: #303133;
}
/* "Streaming" status tag blinks softly while output is in flight. */
.pulse-tag {
  animation: pulse-tag 1.5s ease-in-out infinite;
}
@keyframes pulse-tag {
  0%, 100% {
    opacity: 1;
  }
  50% {
    opacity: 0.7;
  }
}
/* Empty-state placeholder shown before any content arrives. */
.placeholder-text {
  display: flex;
  flex-direction: column;
  align-items: center;
  justify-content: center;
  padding: 60px 20px;
  color: #909399;
}
.placeholder-icon {
  font-size: 48px;
  margin-bottom: 16px;
  opacity: 0.5;
}
.placeholder-text p {
  margin: 0;
  font-size: 14px;
}
/* Scrollable container for the rendered stream. */
.output-area {
  padding: 16px;
  background: #f8f9fa;
  border-radius: 12px;
  min-height: 200px;
  max-height: 400px;
  overflow-y: auto;
}
/* Rendered markdown. :deep() reaches the v-html content
   which lives outside this component's scoped styles. */
.stream-text {
  line-height: 1.8;
  color: #303133;
}
.stream-text :deep(h1) {
  font-size: 24px;
  font-weight: 700;
  margin: 16px 0 12px;
  color: #303133;
}
.stream-text :deep(h2) {
  font-size: 20px;
  font-weight: 600;
  margin: 14px 0 10px;
  color: #303133;
}
.stream-text :deep(h3) {
  font-size: 16px;
  font-weight: 600;
  margin: 12px 0 8px;
  color: #303133;
}
/* Dark code blocks with light text. */
.stream-text :deep(pre) {
  background: #1e1e1e;
  border-radius: 8px;
  padding: 16px;
  overflow-x: auto;
  margin: 12px 0;
}
.stream-text :deep(code) {
  font-family: 'Consolas', 'Monaco', monospace;
  font-size: 13px;
}
.stream-text :deep(pre code) {
  color: #d4d4d4;
}
.stream-text :deep(.inline-code) {
  background: #e8e8e8;
  padding: 2px 6px;
  border-radius: 4px;
  font-size: 13px;
  color: #e83e8c;
}
.stream-text :deep(strong) {
  font-weight: 600;
  color: #303133;
}
.stream-text :deep(em) {
  font-style: italic;
  color: #606266;
}
.stream-text :deep(li) {
  margin: 4px 0;
  padding-left: 8px;
}
/* Three-dot "typing" animation shown while streaming. */
.typing-indicator {
  display: flex;
  gap: 4px;
  margin-top: 12px;
  padding: 8px 12px;
  background: rgba(17, 153, 142, 0.1);
  border-radius: 8px;
  width: fit-content;
}
.typing-indicator .dot {
  width: 8px;
  height: 8px;
  background: #11998e;
  border-radius: 50%;
  animation: typing 1.4s infinite ease-in-out both;
}
/* Stagger each dot so they bounce in sequence. */
.typing-indicator .dot:nth-child(1) {
  animation-delay: -0.32s;
}
.typing-indicator .dot:nth-child(2) {
  animation-delay: -0.16s;
}
@keyframes typing {
  0%, 80%, 100% {
    transform: scale(0.6);
    opacity: 0.5;
  }
  40% {
    transform: scale(1);
    opacity: 1;
  }
}
/* Error alert below the output area. */
.error-section {
  margin-top: 16px;
}
</style>

View File

@ -0,0 +1,16 @@
import { createApp } from 'vue'
import { createPinia } from 'pinia'
import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css'
import './styles/main.css'
import App from './App.vue'
import router from './router'

// Application bootstrap: create the root app, register the Pinia store,
// the router and Element Plus, then mount onto the #app element.
createApp(App)
  .use(createPinia())
  .use(router)
  .use(ElementPlus)
  .mount('#app')

View File

@ -0,0 +1,51 @@
import { createRouter, createWebHistory, RouteRecordRaw } from 'vue-router'

/**
 * Route table for the admin SPA. Every page component is lazily imported so
 * each view compiles into its own chunk; `meta.title` carries the page title
 * shown in navigation.
 */
const routes: Array<RouteRecordRaw> = [
  // Land on the dashboard by default.
  { path: '/', redirect: '/dashboard' },
  { path: '/dashboard', name: 'Dashboard', meta: { title: '控制台' }, component: () => import('@/views/dashboard/index.vue') },
  { path: '/kb', name: 'KBManagement', meta: { title: '知识库管理' }, component: () => import('@/views/kb/index.vue') },
  { path: '/rag-lab', name: 'RagLab', meta: { title: 'RAG 实验室' }, component: () => import('@/views/rag-lab/index.vue') },
  { path: '/monitoring', name: 'Monitoring', meta: { title: '会话监控' }, component: () => import('@/views/monitoring/index.vue') },
  { path: '/admin/embedding', name: 'EmbeddingConfig', meta: { title: '嵌入模型配置' }, component: () => import('@/views/admin/embedding/index.vue') },
  { path: '/admin/llm', name: 'LLMConfig', meta: { title: 'LLM 模型配置' }, component: () => import('@/views/admin/llm/index.vue') }
]

// HTML5 history mode (clean URLs, no hash); base defaults to '/'.
const router = createRouter({
  history: createWebHistory(),
  routes
})

export default router

View File

@ -0,0 +1,164 @@
import { defineStore } from 'pinia'
import { ref, computed } from 'vue'
import {
  getProviders,
  getConfig,
  saveConfig,
  testEmbedding,
  getSupportedFormats,
  type EmbeddingProviderInfo,
  type EmbeddingConfig,
  type EmbeddingConfigUpdate,
  type EmbeddingTestResult,
  type DocumentFormat
} from '@/api/embedding'

/**
 * Pinia store backing the embedding-model admin page.
 * Holds the provider catalogue, the configuration being edited,
 * the supported document formats and connection-test state.
 */
export const useEmbeddingStore = defineStore('embedding', () => {
  // Provider catalogue fetched from the backend.
  const providers = ref<EmbeddingProviderInfo[]>([])
  // Configuration currently being edited on the page.
  const currentConfig = ref<EmbeddingConfig>({
    provider: '',
    config: {}
  })
  // Document formats the ingestion pipeline accepts.
  const formats = ref<DocumentFormat[]>([])
  const loading = ref(false)
  const providersLoading = ref(false)
  const formatsLoading = ref(false)
  const testResult = ref<EmbeddingTestResult | null>(null)
  const testLoading = ref(false)

  // Metadata of the provider currently selected in the form, if any.
  const currentProvider = computed(() => {
    return providers.value.find(p => p.name === currentConfig.value.provider)
  })
  // JSON-schema-like description of the selected provider's config fields.
  const configSchema = computed(() => {
    return currentProvider.value?.config_schema || { properties: {} }
  })

  /** Load the provider catalogue; logs and rethrows so callers can react. */
  const loadProviders = async () => {
    providersLoading.value = true
    try {
      const res: any = await getProviders()
      // Tolerate both flat ({ providers }) and wrapped ({ data: { providers } }) responses.
      providers.value = res?.providers || res?.data?.providers || []
    } catch (error) {
      console.error('Failed to load providers:', error)
      throw error
    } finally {
      providersLoading.value = false
    }
  }

  /** Load the currently persisted configuration from the backend. */
  const loadConfig = async () => {
    loading.value = true
    try {
      const res: any = await getConfig()
      const config = res?.data || res
      if (config) {
        currentConfig.value = {
          provider: config.provider || '',
          config: config.config || {},
          updated_at: config.updated_at
        }
      }
    } catch (error) {
      console.error('Failed to load config:', error)
      throw error
    } finally {
      loading.value = false
    }
  }

  /** Persist the configuration currently held in the store. */
  const saveCurrentConfig = async () => {
    loading.value = true
    try {
      const updateData: EmbeddingConfigUpdate = {
        provider: currentConfig.value.provider,
        config: currentConfig.value.config
      }
      await saveConfig(updateData)
    } catch (error) {
      console.error('Failed to save config:', error)
      throw error
    } finally {
      loading.value = false
    }
  }

  /**
   * Run a connection test with the current (possibly unsaved) settings.
   * Never throws: failures are folded into the result.
   * Returns the result so callers can branch on it (mirrors the LLM store).
   */
  const runTest = async (testText?: string): Promise<EmbeddingTestResult> => {
    testLoading.value = true
    testResult.value = null
    try {
      const result = await testEmbedding({
        test_text: testText,
        config: {
          provider: currentConfig.value.provider,
          config: currentConfig.value.config
        }
      })
      testResult.value = result
      return result
    } catch (error: any) {
      const errorResult: EmbeddingTestResult = {
        success: false,
        dimension: 0,
        error: error?.message || '连接测试失败'
      }
      testResult.value = errorResult
      return errorResult
    } finally {
      testLoading.value = false
    }
  }

  /** Load the list of supported document formats. */
  const loadFormats = async () => {
    formatsLoading.value = true
    try {
      const res: any = await getSupportedFormats()
      formats.value = res?.formats || res?.data?.formats || []
    } catch (error) {
      console.error('Failed to load formats:', error)
      throw error
    } finally {
      formatsLoading.value = false
    }
  }

  /**
   * Select a provider and seed its config from the schema defaults.
   * Fields without an explicit default get a type-appropriate empty value
   * (numbers honor a schema minimum) so number/boolean inputs are not
   * initialized to '' — consistent with the LLM store's setProvider.
   */
  const setProvider = (providerName: string) => {
    currentConfig.value.provider = providerName
    const provider = providers.value.find(p => p.name === providerName)
    if (provider?.config_schema?.properties) {
      const newConfig: Record<string, any> = {}
      Object.entries(provider.config_schema.properties).forEach(([key, field]: [string, any]) => {
        if (field.default !== undefined) {
          newConfig[key] = field.default
        } else {
          switch (field.type) {
            case 'integer':
            case 'number':
              newConfig[key] = field.minimum ?? 0
              break
            case 'boolean':
              newConfig[key] = false
              break
            default:
              newConfig[key] = ''
          }
        }
      })
      currentConfig.value.config = newConfig
    } else {
      currentConfig.value.config = {}
    }
  }

  /** Update a single config field by key. */
  const updateConfigValue = (key: string, value: any) => {
    currentConfig.value.config[key] = value
  }

  /** Discard the last connection-test result. */
  const clearTestResult = () => {
    testResult.value = null
  }

  return {
    providers,
    currentConfig,
    formats,
    loading,
    providersLoading,
    formatsLoading,
    testResult,
    testLoading,
    currentProvider,
    configSchema,
    loadProviders,
    loadConfig,
    saveCurrentConfig,
    runTest,
    loadFormats,
    setProvider,
    updateConfigValue,
    clearTestResult
  }
})

View File

@ -0,0 +1,161 @@
import { defineStore } from 'pinia'
import { ref, computed } from 'vue'
import {
  getLLMProviders,
  getLLMConfig,
  updateLLMConfig,
  testLLM,
  type LLMProviderInfo,
  type LLMConfig,
  type LLMConfigUpdate,
  type LLMTestResult
} from '@/api/llm'

/**
 * Pinia store for the LLM admin page: provider catalogue, the configuration
 * under edit, and connection-test state.
 */
export const useLLMStore = defineStore('llm', () => {
  const providers = ref<LLMProviderInfo[]>([])
  const currentConfig = ref<LLMConfig>({ provider: '', config: {} })
  const loading = ref(false)
  const providersLoading = ref(false)
  const testResult = ref<LLMTestResult | null>(null)
  const testLoading = ref(false)

  // Metadata of the provider currently selected in the form, if any.
  const currentProvider = computed(() =>
    providers.value.find(p => p.name === currentConfig.value.provider)
  )

  // JSON-schema-like description of the selected provider's config fields.
  const configSchema = computed(
    () => currentProvider.value?.config_schema || { properties: {} }
  )

  /** Fetch the provider catalogue; errors are logged and rethrown. */
  const loadProviders = async () => {
    providersLoading.value = true
    try {
      const payload: any = await getLLMProviders()
      // Accept both flat and { data: ... } wrapped response shapes.
      providers.value = payload?.providers || payload?.data?.providers || []
    } catch (error) {
      console.error('Failed to load LLM providers:', error)
      throw error
    } finally {
      providersLoading.value = false
    }
  }

  /** Fetch the persisted LLM configuration from the backend. */
  const loadConfig = async () => {
    loading.value = true
    try {
      const payload: any = await getLLMConfig()
      const stored = payload?.data || payload
      if (stored) {
        currentConfig.value = {
          provider: stored.provider || '',
          config: stored.config || {},
          updated_at: stored.updated_at
        }
      }
    } catch (error) {
      console.error('Failed to load LLM config:', error)
      throw error
    } finally {
      loading.value = false
    }
  }

  /** Persist the configuration currently held in the store. */
  const saveCurrentConfig = async () => {
    loading.value = true
    try {
      const payload: LLMConfigUpdate = {
        provider: currentConfig.value.provider,
        config: currentConfig.value.config
      }
      await updateLLMConfig(payload)
    } catch (error) {
      console.error('Failed to save LLM config:', error)
      throw error
    } finally {
      loading.value = false
    }
  }

  /**
   * Run a connection test with the current (possibly unsaved) settings.
   * Never throws; failures are folded into the returned result.
   */
  const runTest = async (testPrompt?: string): Promise<LLMTestResult> => {
    testLoading.value = true
    testResult.value = null
    try {
      const outcome = await testLLM({
        test_prompt: testPrompt,
        provider: currentConfig.value.provider,
        config: currentConfig.value.config
      })
      testResult.value = outcome
      return outcome
    } catch (error: any) {
      const failure: LLMTestResult = {
        success: false,
        error: error?.message || '连接测试失败'
      }
      testResult.value = failure
      return failure
    } finally {
      testLoading.value = false
    }
  }

  /** Pick a type-appropriate empty value for a field with no explicit default. */
  const zeroValueFor = (field: any): any => {
    switch (field.type) {
      case 'string':
        return ''
      case 'integer':
      case 'number':
        return field.minimum ?? 0
      case 'boolean':
        return false
      default:
        return null
    }
  }

  /** Switch provider and seed its config from the schema defaults. */
  const setProvider = (providerName: string) => {
    currentConfig.value.provider = providerName
    const selected = providers.value.find(p => p.name === providerName)
    const properties = selected?.config_schema?.properties
    if (!properties) {
      currentConfig.value.config = {}
      return
    }
    const seeded: Record<string, any> = {}
    for (const [key, field] of Object.entries(properties) as [string, any][]) {
      seeded[key] = field.default !== undefined ? field.default : zeroValueFor(field)
    }
    currentConfig.value.config = seeded
  }

  /** Update a single config field by key. */
  const updateConfigValue = (key: string, value: any) => {
    currentConfig.value.config[key] = value
  }

  /** Discard the last connection-test result. */
  const clearTestResult = () => {
    testResult.value = null
  }

  return {
    providers,
    currentConfig,
    loading,
    providersLoading,
    testResult,
    testLoading,
    currentProvider,
    configSchema,
    loadProviders,
    loadConfig,
    saveCurrentConfig,
    runTest,
    setProvider,
    updateConfigValue,
    clearTestResult
  }
})

View File

@ -0,0 +1,126 @@
import { defineStore } from 'pinia'
import { ref, computed } from 'vue'
import {
  runRagExperiment,
  createSSEConnection,
  type AIResponse,
  type RetrievalResult,
  type RagExperimentRequest,
  type RagExperimentResult
} from '@/api/rag'

/**
 * Pinia store for the RAG lab: retrieval results, the assembled prompt,
 * the AI response, and streaming state for the SSE channel.
 */
export const useRagStore = defineStore('rag', () => {
  const retrievalResults = ref<RetrievalResult[]>([])
  const finalPrompt = ref('')
  const aiResponse = ref<AIResponse | null>(null)
  const totalLatencyMs = ref<number>(0)
  const loading = ref(false)
  const streaming = ref(false)
  const streamContent = ref('')
  const streamError = ref<string | null>(null)
  // True once either retrieval output or a generated answer exists.
  const hasResults = computed(() => retrievalResults.value.length > 0 || aiResponse.value !== null)
  // Abort callback for the active SSE connection, if any.
  const abortStream = ref<(() => void) | null>(null)

  /** Run a non-streaming experiment; results land in the shared refs. */
  const runExperiment = async (params: RagExperimentRequest) => {
    loading.value = true
    streamError.value = null
    try {
      const result: RagExperimentResult = await runRagExperiment(params)
      retrievalResults.value = result.retrieval_results || []
      finalPrompt.value = result.final_prompt || ''
      aiResponse.value = result.ai_response || null
      totalLatencyMs.value = result.total_latency_ms || 0
      return result
    } catch (error: any) {
      streamError.value = error?.message || '实验运行失败'
      throw error
    } finally {
      loading.value = false
    }
  }

  /**
   * Start a streaming experiment over SSE. Clears all state from any
   * previous run — including retrieval results and the prompt, which would
   * otherwise display stale data until the new events arrive.
   */
  const startStream = (params: RagExperimentRequest) => {
    streaming.value = true
    streamContent.value = ''
    streamError.value = null
    aiResponse.value = null
    // Reset per run so a new experiment never shows the previous one's output.
    retrievalResults.value = []
    finalPrompt.value = ''
    abortStream.value = createSSEConnection(
      '/admin/rag/experiments/stream',
      params,
      (data: string) => {
        try {
          const parsed = JSON.parse(data)
          if (parsed.type === 'content') {
            streamContent.value += parsed.content || ''
          } else if (parsed.type === 'retrieval') {
            retrievalResults.value = parsed.results || []
          } else if (parsed.type === 'prompt') {
            finalPrompt.value = parsed.prompt || ''
          } else if (parsed.type === 'complete') {
            aiResponse.value = {
              content: streamContent.value,
              prompt_tokens: parsed.prompt_tokens,
              completion_tokens: parsed.completion_tokens,
              total_tokens: parsed.total_tokens,
              latency_ms: parsed.latency_ms,
              model: parsed.model
            }
            totalLatencyMs.value = parsed.total_latency_ms || 0
          } else if (parsed.type === 'error') {
            streamError.value = parsed.message || '流式输出错误'
          }
        } catch {
          // Non-JSON payloads are treated as raw content chunks.
          streamContent.value += data
        }
      },
      (error: Error) => {
        streaming.value = false
        streamError.value = error.message
        // The connection is dead: drop the stale abort handle.
        abortStream.value = null
      },
      () => {
        streaming.value = false
        abortStream.value = null
      }
    )
  }

  /** Abort the active stream, if any, and mark streaming stopped. */
  const stopStream = () => {
    if (abortStream.value) {
      abortStream.value()
      abortStream.value = null
    }
    streaming.value = false
  }

  /** Reset all experiment output and error state. */
  const clearResults = () => {
    retrievalResults.value = []
    finalPrompt.value = ''
    aiResponse.value = null
    totalLatencyMs.value = 0
    streamContent.value = ''
    streamError.value = null
  }

  return {
    retrievalResults,
    finalPrompt,
    aiResponse,
    totalLatencyMs,
    loading,
    streaming,
    streamContent,
    streamError,
    hasResults,
    runExperiment,
    startStream,
    stopStream,
    clearResults
  }
})

View File

@ -0,0 +1,41 @@
import { defineStore } from 'pinia'
import { ref, watch } from 'vue'

/** Read a JSON value from localStorage, falling back on missing or corrupt data. */
const storedJSON = <T>(key: string, fallback: T): T => {
  try {
    const raw = localStorage.getItem(key)
    return raw === null ? fallback : (JSON.parse(raw) as T)
  } catch {
    return fallback
  }
}

/** Read an integer from localStorage, falling back when missing or NaN. */
const storedInt = (key: string, fallback: number): number => {
  const value = parseInt(localStorage.getItem(key) ?? '', 10)
  return Number.isFinite(value) ? value : fallback
}

/** Read a float from localStorage, falling back when missing or NaN. */
const storedFloat = (key: string, fallback: number): number => {
  const value = parseFloat(localStorage.getItem(key) ?? '')
  return Number.isFinite(value) ? value : fallback
}

/**
 * Persists the RAG-lab form parameters across reloads via localStorage.
 * Corrupted stored values no longer break store initialization:
 * JSON.parse used to throw out of the setup function, and
 * parseInt/parseFloat could yield NaN that then propagated into the UI.
 */
export const useRagLabStore = defineStore('ragLab', () => {
  const query = ref(localStorage.getItem('ragLab_query') || '')
  const kbIds = ref<string[]>(storedJSON<string[]>('ragLab_kbIds', []))
  const llmProvider = ref(localStorage.getItem('ragLab_llmProvider') || '')
  const topK = ref(storedInt('ragLab_topK', 3))
  const scoreThreshold = ref(storedFloat('ragLab_scoreThreshold', 0.5))
  // Default true: any stored value except the literal 'false' enables it.
  const generateResponse = ref(localStorage.getItem('ragLab_generateResponse') !== 'false')
  // Default false: only the literal 'true' enables streaming.
  const streamOutput = ref(localStorage.getItem('ragLab_streamOutput') === 'true')

  // Mirror every parameter back to localStorage as it changes.
  watch(query, (val) => localStorage.setItem('ragLab_query', val))
  watch(kbIds, (val) => localStorage.setItem('ragLab_kbIds', JSON.stringify(val)), { deep: true })
  watch(llmProvider, (val) => localStorage.setItem('ragLab_llmProvider', val))
  watch(topK, (val) => localStorage.setItem('ragLab_topK', String(val)))
  watch(scoreThreshold, (val) => localStorage.setItem('ragLab_scoreThreshold', String(val)))
  watch(generateResponse, (val) => localStorage.setItem('ragLab_generateResponse', String(val)))
  watch(streamOutput, (val) => localStorage.setItem('ragLab_streamOutput', String(val)))

  /** Restore every parameter to its default value. */
  const clearParams = () => {
    query.value = ''
    kbIds.value = []
    llmProvider.value = ''
    topK.value = 3
    scoreThreshold.value = 0.5
    generateResponse.value = true
    streamOutput.value = false
  }

  return {
    query,
    kbIds,
    llmProvider,
    topK,
    scoreThreshold,
    generateResponse,
    streamOutput,
    clearParams
  }
})

View File

@ -0,0 +1,16 @@
import { defineStore } from 'pinia'

// The storage key doubles as the HTTP header name injected by the request interceptor.
const STORAGE_KEY = 'X-Tenant-Id'
// Default tenant ID format: name@ash@year
const DEFAULT_TENANT_ID = 'default@ash@2026'

/** Holds the active tenant ID, persisted to localStorage across reloads. */
export const useTenantStore = defineStore('tenant', {
  state: () => {
    // Fall back to the default when nothing (or an empty value) is stored.
    const saved = localStorage.getItem(STORAGE_KEY)
    return { currentTenantId: saved || DEFAULT_TENANT_ID }
  },
  actions: {
    /** Switch tenant and persist the choice. */
    setTenant(id: string) {
      this.currentTenantId = id
      localStorage.setItem(STORAGE_KEY, id)
    }
  }
})

View File

@ -0,0 +1,486 @@
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=DM+Sans:wght@400;500;600;700&display=swap');
/* ------------------------------------------------------------------
   Design tokens: colors, shadows, radii, transitions and font stacks.
   All component overrides below reference these variables.
   ------------------------------------------------------------------ */
:root {
  --primary-color: #4F7CFF;
  --primary-light: #6B91FF;
  --primary-lighter: #E8EEFF;
  --primary-dark: #3A5FD9;
  --secondary-color: #6366F1;
  --secondary-light: #818CF8;
  --accent-color: #10B981;
  --accent-light: #34D399;
  --warning-color: #F59E0B;
  --danger-color: #EF4444;
  --success-color: #10B981;
  --info-color: #3B82F6;
  --bg-primary: #F8FAFC;
  --bg-secondary: #FFFFFF;
  --bg-tertiary: #F1F5F9;
  --bg-hover: #F1F5F9;
  --text-primary: #1E293B;
  --text-secondary: #64748B;
  --text-tertiary: #94A3B8;
  --text-placeholder: #CBD5E1;
  --border-color: #E2E8F0;
  --border-light: #F1F5F9;
  --shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.05);
  --shadow-md: 0 4px 6px -1px rgba(0, 0, 0, 0.05), 0 2px 4px -2px rgba(0, 0, 0, 0.05);
  --shadow-lg: 0 10px 15px -3px rgba(0, 0, 0, 0.05), 0 4px 6px -4px rgba(0, 0, 0, 0.05);
  --shadow-xl: 0 20px 25px -5px rgba(0, 0, 0, 0.05), 0 8px 10px -6px rgba(0, 0, 0, 0.05);
  --radius-sm: 6px;
  --radius-md: 10px;
  --radius-lg: 14px;
  --radius-xl: 20px;
  --transition-fast: 0.15s ease;
  --transition-normal: 0.25s ease;
  --transition-slow: 0.35s ease;
  --font-sans: 'DM Sans', 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
  --font-mono: 'JetBrains Mono', 'Fira Code', 'SF Mono', Consolas, monospace;
}
/* Base reset and typography. */
* {
  margin: 0;
  padding: 0;
  box-sizing: border-box;
}
html, body {
  font-family: var(--font-sans);
  font-size: 14px;
  line-height: 1.6;
  color: var(--text-primary);
  background-color: var(--bg-primary);
  -webkit-font-smoothing: antialiased;
  -moz-osx-font-smoothing: grayscale;
}
/* ------------------------------------------------------------------
   Element Plus overrides. !important is used throughout to beat the
   library's own specificity.
   ------------------------------------------------------------------ */
/* Top navigation menu. */
.el-menu {
  font-family: var(--font-sans) !important;
  border-bottom: 1px solid var(--border-color) !important;
  background-color: var(--bg-secondary) !important;
}
.el-menu-item {
  font-weight: 500 !important;
  transition: all var(--transition-fast) !important;
}
.el-menu-item:hover {
  background-color: var(--bg-hover) !important;
}
.el-menu-item.is-active {
  color: var(--primary-color) !important;
  border-bottom-color: var(--primary-color) !important;
  background-color: var(--primary-lighter) !important;
}
/* Cards. */
.el-card {
  border-radius: var(--radius-lg) !important;
  border: 1px solid var(--border-color) !important;
  box-shadow: var(--shadow-sm) !important;
  transition: all var(--transition-normal) !important;
  background-color: var(--bg-secondary) !important;
}
.el-card:hover {
  box-shadow: var(--shadow-md) !important;
}
.el-card__header {
  border-bottom: 1px solid var(--border-light) !important;
  padding: 16px 20px !important;
  font-weight: 600 !important;
  color: var(--text-primary) !important;
}
.el-card__body {
  padding: 20px !important;
}
/* Buttons. */
.el-button--primary {
  background-color: var(--primary-color) !important;
  border-color: var(--primary-color) !important;
  font-weight: 500 !important;
  transition: all var(--transition-fast) !important;
}
.el-button--primary:hover {
  background-color: var(--primary-light) !important;
  border-color: var(--primary-light) !important;
  transform: translateY(-1px);
  box-shadow: var(--shadow-md) !important;
}
.el-button--default {
  font-weight: 500 !important;
  border-color: var(--border-color) !important;
  transition: all var(--transition-fast) !important;
}
.el-button--default:hover {
  border-color: var(--primary-color) !important;
  color: var(--primary-color) !important;
  background-color: var(--primary-lighter) !important;
}
/* Inputs and selects. */
.el-input__wrapper {
  border-radius: var(--radius-md) !important;
  box-shadow: none !important;
  border: 1px solid var(--border-color) !important;
  transition: all var(--transition-fast) !important;
}
.el-input__wrapper:hover {
  border-color: var(--primary-light) !important;
}
.el-input__wrapper.is-focus {
  border-color: var(--primary-color) !important;
  box-shadow: 0 0 0 3px var(--primary-lighter) !important;
}
.el-select {
  --el-select-border-color-hover: var(--primary-light) !important;
}
.el-select .el-input__wrapper {
  border-radius: var(--radius-md) !important;
}
.el-select-dropdown {
  border-radius: var(--radius-lg) !important;
  border: 1px solid var(--border-color) !important;
  box-shadow: var(--shadow-xl) !important;
  margin-top: 8px !important;
}
.el-select-dropdown__wrap {
  max-height: 320px !important;
}
.el-select-dropdown__item {
  padding: 10px 16px !important;
  line-height: 1.5 !important;
  transition: all var(--transition-fast) !important;
}
.el-select-dropdown__item.hover,
.el-select-dropdown__item:hover {
  background-color: var(--primary-lighter) !important;
  color: var(--primary-color) !important;
}
.el-select-dropdown__item.is-selected {
  background-color: var(--primary-lighter) !important;
  color: var(--primary-color) !important;
  font-weight: 600 !important;
}
/* Tags: soft pastel backgrounds per severity. */
.el-tag {
  border-radius: var(--radius-sm) !important;
  font-weight: 500 !important;
  border: none !important;
}
.el-tag--success {
  background-color: #D1FAE5 !important;
  color: #059669 !important;
}
.el-tag--warning {
  background-color: #FEF3C7 !important;
  color: #D97706 !important;
}
.el-tag--danger {
  background-color: #FEE2E2 !important;
  color: #DC2626 !important;
}
.el-tag--info {
  background-color: #E0E7FF !important;
  color: #4F46E5 !important;
}
/* Tables. */
.el-table {
  border-radius: var(--radius-lg) !important;
  overflow: hidden !important;
}
.el-table th.el-table__cell {
  background-color: var(--bg-tertiary) !important;
  font-weight: 600 !important;
  color: var(--text-secondary) !important;
}
.el-table td.el-table__cell {
  border-bottom: 1px solid var(--border-light) !important;
}
.el-table--striped .el-table__body tr.el-table__row--striped td.el-table__cell {
  background-color: var(--bg-tertiary) !important;
}
.el-table__row:hover > td.el-table__cell {
  background-color: var(--primary-lighter) !important;
}
/* Tabs. */
.el-tabs__item {
  font-weight: 500 !important;
  transition: all var(--transition-fast) !important;
}
.el-tabs__item.is-active {
  color: var(--primary-color) !important;
}
.el-tabs__active-bar {
  background-color: var(--primary-color) !important;
}
/* Dialogs. */
.el-dialog {
  border-radius: var(--radius-xl) !important;
  overflow: hidden !important;
}
.el-dialog__header {
  padding: 20px 24px !important;
  border-bottom: 1px solid var(--border-light) !important;
}
.el-dialog__title {
  font-weight: 600 !important;
  font-size: 18px !important;
}
.el-dialog__body {
  padding: 24px !important;
}
/* Forms, sliders, switches. */
.el-form-item__label {
  font-weight: 500 !important;
  color: var(--text-secondary) !important;
}
.el-slider__bar {
  background-color: var(--primary-color) !important;
}
.el-slider__button {
  border-color: var(--primary-color) !important;
}
.el-switch.is-checked .el-switch__core {
  background-color: var(--primary-color) !important;
  border-color: var(--primary-color) !important;
}
/* Alerts, dividers and misc widgets. */
.el-alert {
  border-radius: var(--radius-md) !important;
  border: none !important;
}
.el-alert--info {
  background-color: #EFF6FF !important;
}
.el-alert--success {
  background-color: #ECFDF5 !important;
}
.el-alert--warning {
  background-color: #FFFBEB !important;
}
.el-alert--error {
  background-color: #FEF2F2 !important;
}
.el-divider {
  border-color: var(--border-light) !important;
}
.el-empty__description {
  color: var(--text-tertiary) !important;
}
.el-loading-mask {
  border-radius: var(--radius-lg) !important;
}
.el-descriptions {
  border-radius: var(--radius-md) !important;
  overflow: hidden !important;
}
.el-descriptions__label {
  background-color: var(--bg-tertiary) !important;
  font-weight: 500 !important;
}
/* ------------------------------------------------------------------
   Reusable entrance animations.
   ------------------------------------------------------------------ */
.fade-in-up {
  animation: fadeInUp 0.5s ease-out forwards;
}
@keyframes fadeInUp {
  from {
    opacity: 0;
    transform: translateY(20px);
  }
  to {
    opacity: 1;
    transform: translateY(0);
  }
}
.slide-in-left {
  animation: slideInLeft 0.4s ease-out forwards;
}
@keyframes slideInLeft {
  from {
    opacity: 0;
    transform: translateX(-20px);
  }
  to {
    opacity: 1;
    transform: translateX(0);
  }
}
.scale-in {
  animation: scaleIn 0.3s ease-out forwards;
}
@keyframes scaleIn {
  from {
    opacity: 0;
    transform: scale(0.95);
  }
  to {
    opacity: 1;
    transform: scale(1);
  }
}
/* ------------------------------------------------------------------
   Shared page-layout helpers.
   ------------------------------------------------------------------ */
.page-container {
  padding: 24px;
  min-height: calc(100vh - 60px);
  background-color: var(--bg-primary);
}
.page-header {
  margin-bottom: 24px;
}
.page-title {
  font-size: 24px;
  font-weight: 700;
  color: var(--text-primary);
  margin: 0 0 8px 0;
  letter-spacing: -0.5px;
}
.page-desc {
  font-size: 14px;
  color: var(--text-secondary);
  margin: 0;
  line-height: 1.6;
}
/* Square icon badges used in card headers; color variants below. */
.card-icon {
  width: 40px;
  height: 40px;
  display: flex;
  align-items: center;
  justify-content: center;
  border-radius: var(--radius-md);
  font-size: 18px;
}
.card-icon.primary {
  background-color: var(--primary-lighter);
  color: var(--primary-color);
}
.card-icon.success {
  background-color: #D1FAE5;
  color: #059669;
}
.card-icon.warning {
  background-color: #FEF3C7;
  color: #D97706;
}
.card-icon.info {
  background-color: #E0E7FF;
  color: #4F46E5;
}
/* Stat cards reveal a gradient top bar on hover. */
.stat-card {
  position: relative;
  overflow: hidden;
}
.stat-card::before {
  content: '';
  position: absolute;
  top: 0;
  left: 0;
  right: 0;
  height: 3px;
  background: linear-gradient(90deg, var(--primary-color), var(--secondary-color));
  opacity: 0;
  transition: opacity var(--transition-fast);
}
.stat-card:hover::before {
  opacity: 1;
}
/* Custom scrollbar and text-selection colors (WebKit). */
::-webkit-scrollbar {
  width: 8px;
  height: 8px;
}
::-webkit-scrollbar-track {
  background: var(--bg-tertiary);
  border-radius: 4px;
}
::-webkit-scrollbar-thumb {
  background: var(--text-tertiary);
  border-radius: 4px;
  transition: background var(--transition-fast);
}
::-webkit-scrollbar-thumb:hover {
  background: var(--text-secondary);
}
::selection {
  background-color: var(--primary-lighter);
  color: var(--primary-dark);
}
/* Mobile tweaks. */
@media (max-width: 768px) {
  .page-container {
    padding: 16px;
  }
  .page-title {
    font-size: 20px;
  }
}

View File

@ -0,0 +1,49 @@
/** Metadata for one embedding provider offered by the backend. */
export interface EmbeddingProviderInfo {
  name: string
  display_name: string
  description?: string
  // JSON-schema-like description of the provider's configurable fields.
  config_schema: Record<string, any>
}
/** A persisted (or in-edit) embedding configuration. */
export interface EmbeddingConfig {
  provider: string
  config: Record<string, any>
  updated_at?: string
}
/** Payload for saving a configuration; config may be omitted. */
export interface EmbeddingConfigUpdate {
  provider: string
  config?: Record<string, any>
}
/** Outcome of a connection test against an embedding provider. */
export interface EmbeddingTestResult {
  success: boolean
  // Vector dimension reported by the provider (0 on failure).
  dimension: number
  latency_ms?: number
  message?: string
  error?: string
}
/** One document format accepted by the ingestion pipeline. */
export interface DocumentFormat {
  extension: string
  name: string
  description?: string
}
/** Response wrapper for the provider catalogue endpoint. */
export interface EmbeddingProvidersResponse {
  providers: EmbeddingProviderInfo[]
}
/** Response wrapper for the save-config endpoint. */
export interface EmbeddingConfigUpdateResponse {
  success: boolean
  message: string
}
/** Response wrapper for the supported-formats endpoint. */
export interface SupportedFormatsResponse {
  formats: DocumentFormat[]
}
/** Request body for a connection test; both fields optional (server defaults). */
export interface EmbeddingTestRequest {
  test_text?: string
  config?: EmbeddingConfigUpdate
}

View File

@ -0,0 +1,43 @@
/** Metadata for one LLM provider offered by the backend. */
export interface LLMProviderInfo {
  name: string
  display_name: string
  description?: string
  // JSON-schema-like description of the provider's configurable fields.
  config_schema: Record<string, any>
}
/** A persisted (or in-edit) LLM configuration. */
export interface LLMConfig {
  provider: string
  config: Record<string, any>
  updated_at?: string
}
/** Payload for saving a configuration; config may be omitted. */
export interface LLMConfigUpdate {
  provider: string
  config?: Record<string, any>
}
/** Outcome of a connection test against an LLM provider. */
export interface LLMTestResult {
  success: boolean
  // Text returned by the model for the test prompt, on success.
  response?: string
  latency_ms?: number
  prompt_tokens?: number
  completion_tokens?: number
  total_tokens?: number
  message?: string
  error?: string
}
/** Request body for a connection test; all fields optional (server defaults). */
export interface LLMTestRequest {
  test_prompt?: string
  provider?: string
  config?: Record<string, any>
}
/** Response wrapper for the provider catalogue endpoint. */
export interface LLMProvidersResponse {
  providers: LLMProviderInfo[]
}
/** Response wrapper for the save-config endpoint. */
export interface LLMConfigUpdateResponse {
  success: boolean
  message: string
}

View File

@ -0,0 +1,72 @@
import axios from 'axios'
import { ElMessage, ElMessageBox } from 'element-plus'
import { useTenantStore } from '@/stores/tenant'

// Shared axios instance for all admin API calls.
const service = axios.create({
  baseURL: import.meta.env.VITE_APP_BASE_API || '/api',
  timeout: 60000
})

// Request interceptor: attach the active tenant header to every request.
service.interceptors.request.use(
  (config) => {
    const tenantStore = useTenantStore()
    if (tenantStore.currentTenantId) {
      config.headers['X-Tenant-Id'] = tenantStore.currentTenantId
    }
    // TODO: once auth exists, inject the Authorization token here as well.
    return config
  },
  (error) => {
    // Request-building failures are errors, not plain logs.
    console.error('Request setup failed:', error)
    return Promise.reject(error)
  }
)

// Response interceptor: unwrap the payload and surface errors uniformly.
service.interceptors.response.use(
  (response) => {
    const res = response.data
    // Hook point for unified handling of backend business codes.
    return res
  },
  (error) => {
    console.error('Request failed:', error)
    const { message, response } = error
    if (response) {
      const status = response.status
      if (status === 401) {
        ElMessageBox.confirm('登录状态已过期,您可以继续留在该页面,或者重新登录', '系统提示', {
          confirmButtonText: '重新登录',
          cancelButtonText: '取消',
          type: 'warning'
        })
          .then(() => {
            // TODO: route to the login page or run a proper logout flow.
            location.href = '/login'
          })
          .catch(() => {
            // User chose to stay on the page; without this .catch the
            // dialog cancel produced an unhandled promise rejection.
          })
      } else if (status === 403) {
        ElMessage({
          message: '当前操作无权限',
          type: 'error',
          duration: 5 * 1000
        })
      } else {
        ElMessage({
          message: message || '后端接口未知异常',
          type: 'error',
          duration: 5 * 1000
        })
      }
    } else {
      // No response object at all: network-level failure.
      ElMessage({
        message: '网络连接异常',
        type: 'error',
        duration: 5 * 1000
      })
    }
    return Promise.reject(error)
  }
)

export default service

View File

@ -0,0 +1,483 @@
<template>
  <div class="embedding-config-page">
    <!-- Page header: title, description and last-saved timestamp. -->
    <div class="page-header">
      <div class="header-content">
        <div class="title-section">
          <h1 class="page-title">嵌入模型配置</h1>
          <p class="page-desc">配置和管理系统使用的嵌入模型支持多种提供者切换配置修改后需保存才能生效</p>
        </div>
        <div class="header-actions" v-if="currentConfig.updated_at">
          <div class="update-info">
            <el-icon class="update-icon"><Clock /></el-icon>
            <span>上次更新: {{ formatDate(currentConfig.updated_at) }}</span>
          </div>
        </div>
      </div>
    </div>
    <el-row :gutter="24" v-loading="pageLoading" element-loading-text="加载中...">
      <!-- Left column: provider selection and the schema-driven config form. -->
      <el-col :xs="24" :sm="24" :md="12" :lg="12">
        <el-card shadow="hover" class="config-card">
          <template #header>
            <div class="card-header">
              <div class="header-left">
                <div class="icon-wrapper">
                  <el-icon><Setting /></el-icon>
                </div>
                <span class="header-title">模型配置</span>
              </div>
            </div>
          </template>
          <div class="card-content">
            <div class="provider-select-section">
              <div class="section-label">
                <el-icon><Connection /></el-icon>
                <span>选择提供者</span>
              </div>
              <!-- Selecting a provider seeds the config with schema defaults. -->
              <EmbeddingProviderSelect
                v-model="currentConfig.provider"
                :providers="providers"
                :loading="providersLoading"
                placeholder="请选择嵌入模型提供者"
                @change="handleProviderChange"
              />
              <transition name="fade">
                <div v-if="currentProvider" class="provider-info">
                  <el-icon class="info-icon"><InfoFilled /></el-icon>
                  <span class="info-text">{{ currentProvider.description }}</span>
                </div>
              </transition>
            </div>
            <el-divider />
            <!-- Dynamic form generated from the provider's config schema;
                 empty state shown until a provider is chosen. -->
            <transition name="slide-fade" mode="out-in">
              <div v-if="currentConfig.provider" key="form" class="config-form-section">
                <EmbeddingConfigForm
                  ref="configFormRef"
                  :schema="configSchema"
                  v-model="currentConfig.config"
                  label-width="140px"
                />
              </div>
              <el-empty v-else key="empty" description="请先选择一个嵌入模型提供者" :image-size="120">
                <template #image>
                  <div class="empty-icon">
                    <el-icon><Box /></el-icon>
                  </div>
                </template>
              </el-empty>
            </transition>
          </div>
          <!-- Footer actions: reset to last saved config, or validate and save. -->
          <template #footer>
            <div class="card-footer">
              <el-button size="large" @click="handleReset">
                <el-icon><RefreshLeft /></el-icon>
                重置
              </el-button>
              <el-button type="primary" size="large" :loading="saving" @click="handleSave">
                <el-icon><Check /></el-icon>
                保存配置
              </el-button>
            </div>
          </template>
        </el-card>
      </el-col>
      <!-- Right column: connection-test panel and supported document formats. -->
      <el-col :xs="24" :sm="24" :md="12" :lg="12">
        <div class="right-column">
          <EmbeddingTestPanel
            :config="{ provider: currentConfig.provider, config: currentConfig.config }"
          />
          <el-card shadow="hover" class="formats-card">
            <template #header>
              <div class="card-header">
                <div class="header-left">
                  <div class="icon-wrapper">
                    <el-icon><Document /></el-icon>
                  </div>
                  <span class="header-title">支持的文档格式</span>
                </div>
              </div>
            </template>
            <SupportedFormats />
          </el-card>
        </div>
      </el-col>
    </el-row>
  </div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { Setting, Connection, InfoFilled, Box, RefreshLeft, Check, Clock, Document } from '@element-plus/icons-vue'
import { useEmbeddingStore } from '@/stores/embedding'
import EmbeddingProviderSelect from '@/components/embedding/EmbeddingProviderSelect.vue'
import EmbeddingConfigForm from '@/components/embedding/EmbeddingConfigForm.vue'
import EmbeddingTestPanel from '@/components/embedding/EmbeddingTestPanel.vue'
import SupportedFormats from '@/components/embedding/SupportedFormats.vue'
const embeddingStore = useEmbeddingStore()
const configFormRef = ref<InstanceType<typeof EmbeddingConfigForm>>()
const saving = ref(false)
const pageLoading = ref(false)
const providers = computed(() => embeddingStore.providers)
const currentConfig = computed(() => embeddingStore.currentConfig)
const currentProvider = computed(() => embeddingStore.currentProvider)
const configSchema = computed(() => embeddingStore.configSchema)
const providersLoading = computed(() => embeddingStore.providersLoading)
// Render a timestamp string as a zh-CN "YYYY/MM/DD HH:mm"-style label.
// Returns '' for an empty/absent input.
const formatDate = (dateStr: string) => {
  if (!dateStr) return ''
  const opts: Intl.DateTimeFormatOptions = {
    year: 'numeric',
    month: '2-digit',
    day: '2-digit',
    hour: '2-digit',
    minute: '2-digit'
  }
  return new Date(dateStr).toLocaleString('zh-CN', opts)
}
// Sync the provider chosen in EmbeddingProviderSelect into the store so
// the matching config schema loads. A cleared selection emits a nullish
// value, which is ignored.
const handleProviderChange = (provider: { name?: string } | null) => {
  // Optional chaining matches the sibling LLM page's handler and also
  // guards against a provider object that lacks a name.
  if (provider?.name) {
    embeddingStore.setProvider(provider.name)
  }
}
// Validate the dynamic form, then persist the current embedding config.
const handleSave = async () => {
  // A provider must be selected before anything can be saved.
  if (!currentConfig.value.provider) {
    ElMessage.warning('请先选择嵌入模型提供者')
    return
  }
  // validate() may resolve falsy (soft failure) or reject (required
  // fields missing); both abort the save.
  let formValid = false
  try {
    formValid = Boolean(await configFormRef.value?.validate())
  } catch (error) {
    ElMessage.warning('请检查配置表单中的必填项')
    return
  }
  if (!formValid) return
  saving.value = true
  try {
    await embeddingStore.saveCurrentConfig()
    ElMessage.success('配置保存成功')
  } catch (error) {
    ElMessage.error('配置保存失败')
  } finally {
    saving.value = false
  }
}
// Revert unsaved edits by re-loading the last persisted configuration.
const handleReset = async () => {
  try {
    await ElMessageBox.confirm('确定要重置配置吗?将恢复为当前保存的配置。', '确认重置', {
      confirmButtonText: '确定',
      cancelButtonText: '取消',
      type: 'warning'
    })
  } catch {
    // The confirm dialog rejects when the user dismisses it — not an error.
    return
  }
  try {
    await embeddingStore.loadConfig()
    ElMessage.success('配置已重置')
  } catch (error) {
    // Previously a reload failure was swallowed by the same catch that
    // handled dialog cancellation; surface it to the user instead.
    ElMessage.error('配置重置失败')
  }
}
// Load providers, saved config and supported formats concurrently.
// Any single failure rejects the whole Promise.all and shows one toast.
const initPage = async () => {
  pageLoading.value = true
  try {
    await Promise.all([
      embeddingStore.loadProviders(),
      embeddingStore.loadConfig(),
      embeddingStore.loadFormats()
    ])
  } catch (error) {
    ElMessage.error('初始化页面失败')
  } finally {
    // Always clear the page-level spinner, success or failure.
    pageLoading.value = false
  }
}
// Kick off initial data loading once the component is mounted.
onMounted(() => {
  initPage()
})
</script>
<style scoped>
.embedding-config-page {
padding: 24px;
min-height: calc(100vh - 60px);
}
.page-header {
margin-bottom: 24px;
animation: slideDown 0.4s ease-out;
}
@keyframes slideDown {
from {
opacity: 0;
transform: translateY(-16px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.header-content {
display: flex;
justify-content: space-between;
align-items: flex-start;
gap: 20px;
flex-wrap: wrap;
}
.title-section {
flex: 1;
min-width: 300px;
}
.page-title {
margin: 0 0 8px 0;
font-size: 24px;
font-weight: 700;
color: var(--text-primary);
letter-spacing: -0.5px;
}
.page-desc {
margin: 0;
font-size: 14px;
color: var(--text-secondary);
line-height: 1.6;
}
.header-actions {
display: flex;
align-items: center;
}
.update-info {
display: flex;
align-items: center;
gap: 6px;
padding: 8px 14px;
background-color: var(--bg-tertiary);
border-radius: 8px;
font-size: 13px;
color: var(--text-secondary);
}
.update-icon {
font-size: 14px;
color: var(--text-tertiary);
}
.config-card {
animation: fadeInUp 0.5s ease-out;
}
@keyframes fadeInUp {
from {
opacity: 0;
transform: translateY(20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.card-header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0;
}
.header-left {
display: flex;
align-items: center;
gap: 12px;
}
.icon-wrapper {
width: 36px;
height: 36px;
display: flex;
align-items: center;
justify-content: center;
background-color: var(--primary-lighter);
border-radius: 10px;
color: var(--primary-color);
font-size: 18px;
}
.header-title {
font-size: 15px;
font-weight: 600;
color: var(--text-primary);
}
.card-content {
padding: 8px 0;
}
.provider-select-section {
margin-bottom: 16px;
}
.section-label {
display: flex;
align-items: center;
gap: 8px;
margin-bottom: 12px;
font-size: 13px;
font-weight: 600;
color: var(--text-secondary);
}
.section-label .el-icon {
color: var(--primary-color);
}
.provider-info {
display: flex;
align-items: flex-start;
gap: 10px;
margin-top: 14px;
padding: 14px 16px;
background-color: var(--bg-tertiary);
border-radius: 10px;
font-size: 13px;
color: var(--text-secondary);
line-height: 1.6;
border-left: 3px solid var(--primary-color);
}
.info-icon {
margin-top: 2px;
color: var(--primary-color);
font-size: 16px;
}
.info-text {
flex: 1;
}
.config-form-section {
max-height: 400px;
overflow-y: auto;
padding-right: 8px;
}
.config-form-section::-webkit-scrollbar {
width: 6px;
}
.config-form-section::-webkit-scrollbar-track {
background: var(--bg-tertiary);
border-radius: 3px;
}
.config-form-section::-webkit-scrollbar-thumb {
background: var(--text-tertiary);
border-radius: 3px;
}
.config-form-section::-webkit-scrollbar-thumb:hover {
background: var(--text-secondary);
}
.card-footer {
display: flex;
justify-content: flex-end;
gap: 12px;
padding-top: 8px;
}
.right-column {
display: flex;
flex-direction: column;
gap: 24px;
}
.formats-card {
animation: fadeInUp 0.6s ease-out;
}
.empty-icon {
width: 100px;
height: 100px;
display: flex;
align-items: center;
justify-content: center;
background-color: var(--bg-tertiary);
border-radius: 50%;
margin: 0 auto;
}
.empty-icon .el-icon {
font-size: 48px;
color: var(--text-tertiary);
}
.fade-enter-active,
.fade-leave-active {
transition: opacity 0.25s ease;
}
.fade-enter-from,
.fade-leave-to {
opacity: 0;
}
.slide-fade-enter-active {
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
}
.slide-fade-leave-active {
transition: all 0.2s cubic-bezier(1, 0.5, 0.8, 1);
}
.slide-fade-enter-from {
opacity: 0;
transform: translateX(-16px);
}
.slide-fade-leave-to {
opacity: 0;
transform: translateX(16px);
}
@media (max-width: 768px) {
.embedding-config-page {
padding: 16px;
}
.page-title {
font-size: 20px;
}
.header-content {
flex-direction: column;
}
.title-section {
min-width: 100%;
}
.config-form-section {
max-height: 300px;
}
}
</style>

View File

@ -0,0 +1,470 @@
<template>
<div class="llm-config-page">
<div class="page-header">
<div class="header-content">
<div class="title-section">
<h1 class="page-title">LLM 模型配置</h1>
<p class="page-desc">配置和管理系统使用的大语言模型支持多种提供者切换配置修改后需保存才能生效</p>
</div>
<div class="header-actions" v-if="currentConfig.updated_at">
<div class="update-info">
<el-icon class="update-icon"><Clock /></el-icon>
<span>上次更新: {{ formatDate(currentConfig.updated_at) }}</span>
</div>
</div>
</div>
</div>
<el-row :gutter="24" v-loading="pageLoading" element-loading-text="加载中...">
<el-col :xs="24" :sm="24" :md="12" :lg="12">
<el-card shadow="hover" class="config-card">
<template #header>
<div class="card-header">
<div class="header-left">
<div class="icon-wrapper llm-icon">
<el-icon><Cpu /></el-icon>
</div>
<span class="header-title">模型配置</span>
</div>
</div>
</template>
<div class="card-content">
<div class="provider-select-section">
<div class="section-label">
<el-icon><Connection /></el-icon>
<span>选择提供者</span>
</div>
<ProviderSelect
v-model="currentConfig.provider"
:providers="providers"
:loading="providersLoading"
placeholder="请选择 LLM 提供者"
@change="handleProviderChange"
/>
<transition name="fade">
<div v-if="currentProvider" class="provider-info">
<el-icon class="info-icon"><InfoFilled /></el-icon>
<span class="info-text">{{ currentProvider.description }}</span>
</div>
</transition>
</div>
<el-divider />
<transition name="slide-fade" mode="out-in">
<div v-if="currentConfig.provider" key="form" class="config-form-section">
<ConfigForm
ref="configFormRef"
:schema="configSchema"
v-model="currentConfig.config"
label-width="140px"
/>
</div>
<el-empty v-else key="empty" description="请先选择一个 LLM 提供者" :image-size="120">
<template #image>
<div class="empty-icon">
<el-icon><Box /></el-icon>
</div>
</template>
</el-empty>
</transition>
</div>
<template #footer>
<div class="card-footer">
<el-button size="large" @click="handleReset">
<el-icon><RefreshLeft /></el-icon>
重置
</el-button>
<el-button type="primary" size="large" :loading="saving" @click="handleSave">
<el-icon><Check /></el-icon>
保存配置
</el-button>
</div>
</template>
</el-card>
</el-col>
<el-col :xs="24" :sm="24" :md="12" :lg="12">
<TestPanel
:test-fn="handleTest"
:can-test="!!currentConfig.provider"
title="LLM 连接测试"
input-label="测试提示词"
input-placeholder="请输入测试提示词(可选,默认使用系统预设提示词)"
:show-token-stats="true"
/>
</el-col>
</el-row>
</div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { Cpu, Connection, InfoFilled, Box, RefreshLeft, Check, Clock } from '@element-plus/icons-vue'
import { useLLMStore } from '@/stores/llm'
import ProviderSelect from '@/components/common/ProviderSelect.vue'
import ConfigForm from '@/components/common/ConfigForm.vue'
import TestPanel from '@/components/common/TestPanel.vue'
import type { TestResult } from '@/components/common/TestPanel.vue'
// Pinia store holding LLM providers, current config and load flags.
const llmStore = useLLMStore()
// Ref to the dynamic config form so save can trigger validation first.
const configFormRef = ref<InstanceType<typeof ConfigForm>>()
const saving = ref(false) // true while saveCurrentConfig() is in flight
const pageLoading = ref(false) // true during initial page data loading
// Store-derived state exposed to the template as computeds.
const providers = computed(() => llmStore.providers)
const currentConfig = computed(() => llmStore.currentConfig)
const currentProvider = computed(() => llmStore.currentProvider)
const configSchema = computed(() => llmStore.configSchema)
const providersLoading = computed(() => llmStore.providersLoading)
// Render a timestamp string as a zh-CN "YYYY/MM/DD HH:mm"-style label;
// '' for an empty input. Uses Intl.DateTimeFormat, which is what
// toLocaleString delegates to, so output is identical.
const formatDate = (dateStr: string) => {
  if (!dateStr) return ''
  return new Intl.DateTimeFormat('zh-CN', {
    year: 'numeric',
    month: '2-digit',
    day: '2-digit',
    hour: '2-digit',
    minute: '2-digit'
  }).format(new Date(dateStr))
}
// Push the newly selected provider into the store; a cleared selection
// (nullish event payload or missing name) is ignored.
const handleProviderChange = (provider: any) => {
  const selectedName = provider?.name
  if (selectedName) {
    llmStore.setProvider(selectedName)
  }
}
// Validate the config form, then persist the LLM configuration.
const handleSave = async () => {
  // A provider must be selected before anything can be saved.
  if (!currentConfig.value.provider) {
    ElMessage.warning('请先选择 LLM 提供者')
    return
  }
  try {
    // validate() resolves falsy (or undefined if the form is unmounted)
    // on a soft failure; that aborts silently.
    const valid = await configFormRef.value?.validate()
    if (!valid) {
      return
    }
  } catch (error) {
    // validate() rejects when required fields are invalid.
    ElMessage.warning('请检查配置表单中的必填项')
    return
  }
  saving.value = true
  try {
    await llmStore.saveCurrentConfig()
    ElMessage.success('配置保存成功')
  } catch (error) {
    ElMessage.error('配置保存失败')
  } finally {
    // Always clear the button spinner, success or failure.
    saving.value = false
  }
}
// Revert unsaved edits by re-loading the last persisted configuration.
const handleReset = async () => {
  try {
    await ElMessageBox.confirm('确定要重置配置吗?将恢复为当前保存的配置。', '确认重置', {
      confirmButtonText: '确定',
      cancelButtonText: '取消',
      type: 'warning'
    })
  } catch {
    // The confirm dialog rejects when the user dismisses it — not an error.
    return
  }
  try {
    await llmStore.loadConfig()
    ElMessage.success('配置已重置')
  } catch (error) {
    // Previously a reload failure was swallowed by the same catch that
    // handled dialog cancellation; surface it to the user instead.
    ElMessage.error('配置重置失败')
  }
}
// Delegate a one-off connectivity test to the store; promise semantics
// (resolution and rejection) are unchanged by dropping the await.
const handleTest = (input?: string): Promise<TestResult> => llmStore.runTest(input)
// Load providers and the saved config concurrently.
// Any single failure rejects the whole Promise.all and shows one toast.
const initPage = async () => {
  pageLoading.value = true
  try {
    await Promise.all([
      llmStore.loadProviders(),
      llmStore.loadConfig()
    ])
  } catch (error) {
    ElMessage.error('初始化页面失败')
  } finally {
    // Always clear the page-level spinner, success or failure.
    pageLoading.value = false
  }
}
// Kick off initial data loading once the component is mounted.
onMounted(() => {
  initPage()
})
</script>
<style scoped>
.llm-config-page {
padding: 24px;
min-height: calc(100vh - 60px);
}
.page-header {
margin-bottom: 24px;
animation: slideDown 0.4s ease-out;
}
@keyframes slideDown {
from {
opacity: 0;
transform: translateY(-16px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.header-content {
display: flex;
justify-content: space-between;
align-items: flex-start;
gap: 20px;
flex-wrap: wrap;
}
.title-section {
flex: 1;
min-width: 300px;
}
.page-title {
margin: 0 0 8px 0;
font-size: 24px;
font-weight: 700;
color: var(--text-primary);
letter-spacing: -0.5px;
}
.page-desc {
margin: 0;
font-size: 14px;
color: var(--text-secondary);
line-height: 1.6;
}
.header-actions {
display: flex;
align-items: center;
}
.update-info {
display: flex;
align-items: center;
gap: 6px;
padding: 8px 14px;
background-color: var(--bg-tertiary);
border-radius: 8px;
font-size: 13px;
color: var(--text-secondary);
}
.update-icon {
font-size: 14px;
color: var(--text-tertiary);
}
.config-card {
animation: fadeInUp 0.5s ease-out;
}
@keyframes fadeInUp {
from {
opacity: 0;
transform: translateY(20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.card-header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0;
}
.header-left {
display: flex;
align-items: center;
gap: 12px;
}
.icon-wrapper {
width: 36px;
height: 36px;
display: flex;
align-items: center;
justify-content: center;
background-color: var(--primary-lighter);
border-radius: 10px;
color: var(--primary-color);
font-size: 18px;
}
.icon-wrapper.llm-icon {
background: linear-gradient(135deg, #E0F2FE 0%, #BAE6FD 100%);
color: #0284C7;
}
.header-title {
font-size: 15px;
font-weight: 600;
color: var(--text-primary);
}
.card-content {
padding: 8px 0;
}
.provider-select-section {
margin-bottom: 16px;
}
.section-label {
display: flex;
align-items: center;
gap: 8px;
margin-bottom: 12px;
font-size: 13px;
font-weight: 600;
color: var(--text-secondary);
}
.section-label .el-icon {
color: var(--primary-color);
}
.provider-info {
display: flex;
align-items: flex-start;
gap: 10px;
margin-top: 14px;
padding: 14px 16px;
background-color: var(--bg-tertiary);
border-radius: 10px;
font-size: 13px;
color: var(--text-secondary);
line-height: 1.6;
border-left: 3px solid var(--primary-color);
}
.info-icon {
margin-top: 2px;
color: var(--primary-color);
font-size: 16px;
}
.info-text {
flex: 1;
}
.config-form-section {
max-height: 400px;
overflow-y: auto;
padding-right: 8px;
}
.config-form-section::-webkit-scrollbar {
width: 6px;
}
.config-form-section::-webkit-scrollbar-track {
background: var(--bg-tertiary);
border-radius: 3px;
}
.config-form-section::-webkit-scrollbar-thumb {
background: var(--text-tertiary);
border-radius: 3px;
}
.config-form-section::-webkit-scrollbar-thumb:hover {
background: var(--text-secondary);
}
.card-footer {
display: flex;
justify-content: flex-end;
gap: 12px;
padding-top: 8px;
}
.empty-icon {
width: 100px;
height: 100px;
display: flex;
align-items: center;
justify-content: center;
background-color: var(--bg-tertiary);
border-radius: 50%;
margin: 0 auto;
}
.empty-icon .el-icon {
font-size: 48px;
color: var(--text-tertiary);
}
.fade-enter-active,
.fade-leave-active {
transition: opacity 0.25s ease;
}
.fade-enter-from,
.fade-leave-to {
opacity: 0;
}
.slide-fade-enter-active {
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
}
.slide-fade-leave-active {
transition: all 0.2s cubic-bezier(1, 0.5, 0.8, 1);
}
.slide-fade-enter-from {
opacity: 0;
transform: translateX(-16px);
}
.slide-fade-leave-to {
opacity: 0;
transform: translateX(16px);
}
@media (max-width: 768px) {
.llm-config-page {
padding: 16px;
}
.page-title {
font-size: 20px;
}
.header-content {
flex-direction: column;
}
.title-section {
min-width: 100%;
}
.config-form-section {
max-height: 300px;
}
}
</style>

View File

@ -0,0 +1,720 @@
<template>
<div class="dashboard-page">
<div class="page-header">
<h1 class="page-title">控制台</h1>
<p class="page-desc">系统概览与数据统计</p>
</div>
<el-row :gutter="20" v-loading="loading">
<el-col :xs="12" :sm="12" :md="6" :lg="6">
<el-card shadow="hover" class="stat-card">
<div class="stat-content">
<div class="stat-icon primary">
<el-icon><FolderOpened /></el-icon>
</div>
<div class="stat-info">
<span class="stat-value">{{ stats.knowledgeBases }}</span>
<span class="stat-label">知识库总数</span>
</div>
</div>
</el-card>
</el-col>
<el-col :xs="12" :sm="12" :md="6" :lg="6">
<el-card shadow="hover" class="stat-card">
<div class="stat-content">
<div class="stat-icon success">
<el-icon><Document /></el-icon>
</div>
<div class="stat-info">
<span class="stat-value">{{ stats.totalDocuments }}</span>
<span class="stat-label">文档总数</span>
</div>
</div>
</el-card>
</el-col>
<el-col :xs="12" :sm="12" :md="6" :lg="6">
<el-card shadow="hover" class="stat-card">
<div class="stat-content">
<div class="stat-icon warning">
<el-icon><ChatDotSquare /></el-icon>
</div>
<div class="stat-info">
<span class="stat-value">{{ stats.totalMessages.toLocaleString() }}</span>
<span class="stat-label">总消息数</span>
</div>
</div>
</el-card>
</el-col>
<el-col :xs="12" :sm="12" :md="6" :lg="6">
<el-card shadow="hover" class="stat-card">
<div class="stat-content">
<div class="stat-icon info">
<el-icon><Monitor /></el-icon>
</div>
<div class="stat-info">
<span class="stat-value">{{ stats.totalSessions }}</span>
<span class="stat-label">会话总数</span>
</div>
</div>
</el-card>
</el-col>
</el-row>
<el-row :gutter="20" style="margin-top: 20px;">
<el-col :xs="24" :sm="24" :md="8" :lg="8">
<el-card shadow="hover" class="metric-card">
<template #header>
<div class="card-header">
<div class="header-left">
<div class="icon-wrapper primary">
<el-icon><Cpu /></el-icon>
</div>
<span class="header-title">Token 消耗统计</span>
</div>
<el-tag type="primary" size="small" effect="plain">实时</el-tag>
</div>
</template>
<div class="metric-content">
<div class="metric-main">
<span class="metric-value primary">{{ formatNumber(stats.totalTokens) }}</span>
<span class="metric-label">总消耗</span>
</div>
<div class="metric-divider"></div>
<div class="metric-detail">
<div class="detail-item">
<span class="detail-label">输入</span>
<span class="detail-value">{{ formatNumber(stats.promptTokens) }}</span>
</div>
<div class="detail-item">
<span class="detail-label">输出</span>
<span class="detail-value">{{ formatNumber(stats.completionTokens) }}</span>
</div>
</div>
</div>
</el-card>
</el-col>
<el-col :xs="24" :sm="24" :md="8" :lg="8">
<el-card shadow="hover" class="metric-card">
<template #header>
<div class="card-header">
<div class="header-left">
<div class="icon-wrapper success">
<el-icon><Timer /></el-icon>
</div>
<span class="header-title">响应时间统计</span>
</div>
<el-tag :type="stats.slowRequestsCount > 0 ? 'warning' : 'success'" size="small" effect="plain">
{{ stats.slowRequestsCount }} 次超时
</el-tag>
</div>
</template>
<div class="metric-content">
<div class="latency-grid">
<div class="latency-item">
<span class="metric-value success">{{ formatLatency(stats.avgLatencyMs) }}</span>
<span class="metric-label">平均耗时</span>
</div>
<div class="latency-item">
<span class="metric-value">{{ formatLatency(stats.lastLatencyMs) }}</span>
<span class="metric-label">上次耗时</span>
</div>
</div>
<div class="metric-divider"></div>
<div class="latency-stats">
<div class="stat-row">
<div class="stat-col">
<span class="stat-label">P95</span>
<span class="stat-value">{{ formatLatency(stats.p95LatencyMs) }}</span>
</div>
<div class="stat-col">
<span class="stat-label">P99</span>
<span class="stat-value">{{ formatLatency(stats.p99LatencyMs) }}</span>
</div>
</div>
<div class="stat-row">
<div class="stat-col">
<span class="stat-label">最小</span>
<span class="stat-value">{{ formatLatency(stats.minLatencyMs) }}</span>
</div>
<div class="stat-col">
<span class="stat-label">最大</span>
<span class="stat-value">{{ formatLatency(stats.maxLatencyMs) }}</span>
</div>
</div>
</div>
</div>
</el-card>
</el-col>
<el-col :xs="24" :sm="24" :md="8" :lg="8">
<el-card shadow="hover" class="metric-card">
<template #header>
<div class="card-header">
<div class="header-left">
<div class="icon-wrapper warning">
<el-icon><DataLine /></el-icon>
</div>
<span class="header-title">请求统计</span>
</div>
<el-tag type="info" size="small" effect="plain">阈值 {{ stats.latencyThresholdMs }}ms</el-tag>
</div>
</template>
<div class="metric-content">
<div class="requests-grid">
<div class="request-item">
<span class="metric-value primary">{{ stats.aiRequestsCount }}</span>
<span class="metric-label">AI 请求</span>
</div>
<div class="request-item">
<span class="metric-value" :class="{ danger: stats.errorRequestsCount > 0 }">{{ stats.errorRequestsCount }}</span>
<span class="metric-label">错误</span>
</div>
<div class="request-item">
<span class="metric-value" :class="{ warning: stats.slowRequestsCount > 0 }">{{ stats.slowRequestsCount }}</span>
<span class="metric-label">超时</span>
</div>
</div>
<div class="metric-divider"></div>
<div class="rate-stats">
<div class="rate-item">
<span class="rate-label">错误率</span>
<span class="rate-value" :class="{ danger: parseFloat(errorRate) > 5 }">{{ errorRate }}%</span>
</div>
<div class="rate-item">
<span class="rate-label">超时率</span>
<span class="rate-value" :class="{ warning: parseFloat(timeoutRate) > 10 }">{{ timeoutRate }}%</span>
</div>
</div>
</div>
</el-card>
</el-col>
</el-row>
<el-row :gutter="20" style="margin-top: 20px;">
<el-col :span="24">
<el-card shadow="hover">
<template #header>
<div class="card-header">
<div class="header-left">
<div class="icon-wrapper info">
<el-icon><InfoFilled /></el-icon>
</div>
<span class="header-title">使用说明</span>
</div>
</div>
</template>
<div class="help-content">
<div class="help-item">
<div class="help-icon primary">
<el-icon><FolderOpened /></el-icon>
</div>
<div class="help-text">
<h4>知识库管理</h4>
<p>上传文档并建立向量索引支持 PDFWordTXT 等格式</p>
</div>
</div>
<div class="help-item">
<div class="help-icon success">
<el-icon><Cpu /></el-icon>
</div>
<div class="help-text">
<h4>RAG 实验室</h4>
<p>测试检索增强生成效果查看检索结果和 AI 响应</p>
</div>
</div>
<div class="help-item">
<div class="help-icon warning">
<el-icon><Connection /></el-icon>
</div>
<div class="help-text">
<h4>嵌入模型配置</h4>
<p>配置文本嵌入模型用于文档向量化</p>
</div>
</div>
<div class="help-item">
<div class="help-icon info">
<el-icon><ChatDotSquare /></el-icon>
</div>
<div class="help-text">
<h4>LLM 模型配置</h4>
<p>配置大语言模型支持 OpenAIDeepSeekOllama </p>
</div>
</div>
</div>
</el-card>
</el-col>
</el-row>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, computed, onMounted } from 'vue'
import { FolderOpened, Document, ChatDotSquare, Monitor, Cpu, InfoFilled, Connection, Timer, DataLine } from '@element-plus/icons-vue'
import { getDashboardStats } from '@/api/dashboard'
const loading = ref(false) // true while dashboard stats are being fetched
// All dashboard metrics, zero-initialised until fetchStats() fills them.
const stats = reactive({
  // entity counts
  knowledgeBases: 0,
  totalDocuments: 0,
  totalMessages: 0,
  totalSessions: 0,
  // token usage
  totalTokens: 0,
  promptTokens: 0,
  completionTokens: 0,
  // request/latency metrics (latencies in milliseconds)
  aiRequestsCount: 0,
  avgLatencyMs: 0,
  lastLatencyMs: 0,
  slowRequestsCount: 0,
  errorRequestsCount: 0,
  p95LatencyMs: 0,
  p99LatencyMs: 0,
  minLatencyMs: 0,
  maxLatencyMs: 0,
  latencyThresholdMs: 5000
})
// Percentage of AI requests that errored, as a 2-decimal string
// ('0.00' when there are no requests, avoiding division by zero).
const errorRate = computed(() => {
  if (stats.aiRequestsCount === 0) return '0.00'
  return ((stats.errorRequestsCount / stats.aiRequestsCount) * 100).toFixed(2)
})
// Percentage of AI requests slower than the threshold, same format.
const timeoutRate = computed(() => {
  if (stats.aiRequestsCount === 0) return '0.00'
  return ((stats.slowRequestsCount / stats.aiRequestsCount) * 100).toFixed(2)
})
// Abbreviate a count for display: >=1M -> 'x.xxM', >=1K -> 'x.xK',
// otherwise the plain number.
const formatNumber = (num: number) => {
  const MILLION = 1_000_000
  const THOUSAND = 1_000
  if (num >= MILLION) return `${(num / MILLION).toFixed(2)}M`
  if (num >= THOUSAND) return `${(num / THOUSAND).toFixed(1)}K`
  return `${num}`
}
// Render a latency in ms as '-' when missing, 'x.xxs' at >= 1 second,
// otherwise a whole-millisecond 'Nms' string.
const formatLatency = (ms: number | null | undefined) => {
  if (ms == null) return '-' // covers both null and undefined
  return ms >= 1000 ? `${(ms / 1000).toFixed(2)}s` : `${ms.toFixed(0)}ms`
}
// Stat fields copied verbatim from the response, defaulting to 0.
const ZERO_DEFAULT_KEYS = [
  'knowledgeBases', 'totalDocuments', 'totalMessages', 'totalSessions',
  'totalTokens', 'promptTokens', 'completionTokens', 'aiRequestsCount',
  'avgLatencyMs', 'lastLatencyMs', 'slowRequestsCount', 'errorRequestsCount',
  'p95LatencyMs', 'p99LatencyMs', 'minLatencyMs', 'maxLatencyMs'
] as const
// Load dashboard statistics from the backend into `stats`.
// Failures are logged but non-fatal: the page keeps zeroed values.
const fetchStats = async () => {
  loading.value = true
  try {
    const res: any = await getDashboardStats()
    for (const key of ZERO_DEFAULT_KEYS) {
      stats[key] = res[key] ?? 0
    }
    // `??` instead of `||`: an explicitly configured threshold of 0 must
    // not be clobbered by the 5000ms default.
    stats.latencyThresholdMs = res.latencyThresholdMs ?? 5000
  } catch (error) {
    console.error('Failed to fetch dashboard stats:', error)
  } finally {
    loading.value = false
  }
}
// Fetch once on mount.
onMounted(() => {
  fetchStats()
})
</script>
<style scoped>
.dashboard-page {
padding: 24px;
min-height: calc(100vh - 60px);
}
.page-header {
margin-bottom: 24px;
}
.page-title {
margin: 0 0 8px 0;
font-size: 24px;
font-weight: 700;
color: var(--text-primary);
letter-spacing: -0.5px;
}
.page-desc {
margin: 0;
font-size: 14px;
color: var(--text-secondary);
line-height: 1.6;
}
.stat-card {
animation: fadeInUp 0.5s ease-out;
}
.stat-card:nth-child(1) { animation-delay: 0s; }
.stat-card:nth-child(2) { animation-delay: 0.1s; }
.stat-card:nth-child(3) { animation-delay: 0.2s; }
.stat-card:nth-child(4) { animation-delay: 0.3s; }
@keyframes fadeInUp {
from {
opacity: 0;
transform: translateY(20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.stat-content {
display: flex;
align-items: center;
gap: 16px;
}
.stat-icon {
width: 48px;
height: 48px;
display: flex;
align-items: center;
justify-content: center;
border-radius: 12px;
font-size: 22px;
}
.stat-icon.primary {
background-color: var(--primary-lighter);
color: var(--primary-color);
}
.stat-icon.success {
background-color: #D1FAE5;
color: #059669;
}
.stat-icon.warning {
background-color: #FEF3C7;
color: #D97706;
}
.stat-icon.info {
background-color: #E0E7FF;
color: #4F46E5;
}
.stat-info {
display: flex;
flex-direction: column;
}
.stat-value {
font-size: 28px;
font-weight: 700;
color: var(--text-primary);
line-height: 1.2;
}
.stat-label {
font-size: 13px;
color: var(--text-secondary);
margin-top: 4px;
}
.metric-card {
animation: fadeInUp 0.6s ease-out;
}
.card-header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0;
}
.header-left {
display: flex;
align-items: center;
gap: 12px;
}
.icon-wrapper {
width: 36px;
height: 36px;
display: flex;
align-items: center;
justify-content: center;
border-radius: 10px;
font-size: 18px;
}
.icon-wrapper.primary {
background-color: var(--primary-lighter);
color: var(--primary-color);
}
.icon-wrapper.success {
background-color: #D1FAE5;
color: #059669;
}
.icon-wrapper.warning {
background-color: #FEF3C7;
color: #D97706;
}
.icon-wrapper.info {
background-color: #E0E7FF;
color: #4F46E5;
}
.header-title {
font-size: 15px;
font-weight: 600;
color: var(--text-primary);
}
.metric-content {
display: flex;
flex-direction: column;
gap: 16px;
}
.metric-main {
display: flex;
flex-direction: column;
align-items: center;
padding: 8px 0;
}
.metric-value {
font-size: 32px;
font-weight: 700;
color: var(--text-primary);
line-height: 1.2;
}
.metric-value.primary {
color: var(--primary-color);
}
.metric-value.success {
color: #059669;
}
.metric-value.warning {
color: #D97706;
}
.metric-value.danger {
color: #DC2626;
}
.metric-label {
font-size: 13px;
color: var(--text-secondary);
margin-top: 4px;
}
.metric-divider {
height: 1px;
background-color: var(--border-light);
}
.metric-detail {
display: flex;
justify-content: space-around;
}
.detail-item {
display: flex;
flex-direction: column;
align-items: center;
gap: 4px;
}
.detail-label {
font-size: 12px;
color: var(--text-tertiary);
}
.detail-value {
font-size: 16px;
font-weight: 600;
color: var(--text-primary);
}
.latency-grid {
display: flex;
justify-content: space-around;
padding: 8px 0;
}
.latency-item {
display: flex;
flex-direction: column;
align-items: center;
}
.latency-stats {
display: flex;
flex-direction: column;
gap: 8px;
}
.stat-row {
display: flex;
justify-content: space-around;
}
.stat-col {
display: flex;
flex-direction: column;
align-items: center;
gap: 2px;
}
.stat-label {
font-size: 12px;
color: var(--text-tertiary);
}
.requests-grid {
display: flex;
justify-content: space-around;
padding: 8px 0;
}
.request-item {
display: flex;
flex-direction: column;
align-items: center;
}
.rate-stats {
display: flex;
justify-content: space-around;
}
.rate-item {
display: flex;
flex-direction: column;
align-items: center;
gap: 4px;
}
.rate-label {
font-size: 12px;
color: var(--text-tertiary);
}
.rate-value {
font-size: 18px;
font-weight: 600;
color: var(--text-primary);
}
.rate-value.warning {
color: #D97706;
}
.rate-value.danger {
color: #DC2626;
}
.help-content {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 20px;
}
.help-item {
display: flex;
gap: 14px;
padding: 16px;
background-color: var(--bg-tertiary);
border-radius: 12px;
transition: all 0.2s ease;
}
.help-item:hover {
background-color: var(--primary-lighter);
}
.help-icon {
width: 40px;
height: 40px;
display: flex;
align-items: center;
justify-content: center;
border-radius: 10px;
font-size: 18px;
flex-shrink: 0;
}
.help-icon.primary {
background-color: var(--primary-lighter);
color: var(--primary-color);
}
.help-icon.success {
background-color: #D1FAE5;
color: #059669;
}
.help-icon.warning {
background-color: #FEF3C7;
color: #D97706;
}
.help-icon.info {
background-color: #E0E7FF;
color: #4F46E5;
}
.help-text h4 {
margin: 0 0 4px 0;
font-size: 14px;
font-weight: 600;
color: var(--text-primary);
}
.help-text p {
margin: 0;
font-size: 13px;
color: var(--text-secondary);
line-height: 1.5;
}
@media (max-width: 768px) {
.dashboard-page {
padding: 16px;
}
.page-title {
font-size: 20px;
}
.stat-value {
font-size: 24px;
}
.help-content {
grid-template-columns: 1fr;
}
.metric-value {
font-size: 28px;
}
}
</style>

View File

@ -0,0 +1,379 @@
<template>
<div class="kb-page">
<div class="page-header">
<div class="header-content">
<div class="title-section">
<h1 class="page-title">知识库管理</h1>
<p class="page-desc">上传文档并建立向量索引支持多种文档格式</p>
</div>
<div class="header-actions">
<el-button type="primary" @click="handleUploadClick">
<el-icon><Upload /></el-icon>
上传文档
</el-button>
</div>
</div>
</div>
<el-card shadow="hover" class="table-card">
<el-table v-loading="loading" :data="tableData" style="width: 100%">
<el-table-column prop="name" label="文件名" min-width="200">
<template #default="scope">
<div class="file-name">
<el-icon class="file-icon"><Document /></el-icon>
<span>{{ scope.row.name }}</span>
</div>
</template>
</el-table-column>
<el-table-column prop="status" label="状态" width="120">
<template #default="scope">
<el-tag :type="getStatusType(scope.row.status)" size="small">
{{ getStatusText(scope.row.status) }}
</el-tag>
</template>
</el-table-column>
<el-table-column prop="jobId" label="任务ID" width="180">
<template #default="scope">
<span class="job-id">{{ scope.row.jobId || '-' }}</span>
</template>
</el-table-column>
<el-table-column prop="createTime" label="上传时间" width="180" />
<el-table-column label="操作" width="160" fixed="right">
<template #default="scope">
<el-button link type="primary" @click="handleViewJob(scope.row)">
<el-icon><View /></el-icon>
详情
</el-button>
<el-button link type="danger" @click="handleDelete(scope.row)">
<el-icon><Delete /></el-icon>
删除
</el-button>
</template>
</el-table-column>
</el-table>
</el-card>
<el-dialog v-model="jobDialogVisible" title="索引任务详情" width="500px" class="job-dialog">
<el-descriptions :column="1" border v-if="currentJob">
<el-descriptions-item label="任务ID">
<span class="job-id">{{ currentJob.jobId }}</span>
</el-descriptions-item>
<el-descriptions-item label="状态">
<el-tag :type="getStatusType(currentJob.status)" size="small">
{{ getStatusText(currentJob.status) }}
</el-tag>
</el-descriptions-item>
<el-descriptions-item label="进度">
<div class="progress-wrapper">
<el-progress :percentage="currentJob.progress" :status="getProgressStatus(currentJob.status)" />
</div>
</el-descriptions-item>
<el-descriptions-item label="错误信息" v-if="currentJob.errorMsg">
<el-alert type="error" :closable="false">{{ currentJob.errorMsg }}</el-alert>
</el-descriptions-item>
</el-descriptions>
<template #footer>
<el-button @click="jobDialogVisible = false">关闭</el-button>
<el-button
v-if="currentJob?.status === 'pending' || currentJob?.status === 'processing'"
type="primary"
@click="refreshJobStatus"
>
刷新状态
</el-button>
</template>
</el-dialog>
<input ref="fileInput" type="file" style="display: none" @change="handleFileChange" />
</div>
</template>
<script setup lang="ts">
import { ref, onMounted, onUnmounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { Upload, Document, View, Delete } from '@element-plus/icons-vue'
import { uploadDocument, listDocuments, getIndexJob, deleteDocument } from '@/api/kb'
// One row of the document table, mapped from the backend shape.
interface DocumentItem {
  docId: string
  name: string // original file name
  status: string // pending | processing | completed | failed
  jobId: string // index job id, used for detail/polling lookups
  createTime: string // localized upload timestamp
}
const tableData = ref<DocumentItem[]>([])
const loading = ref(false) // true during upload handling
const jobDialogVisible = ref(false) // job-detail dialog open flag
const currentJob = ref<any>(null) // job shown in the detail dialog
// Job ids currently being polled for completion.
const pollingJobs = ref<Set<string>>(new Set())
// Shared interval handle for the poller; null when no poller is running.
let pollingInterval: number | null = null
// Map a document/job status to an Element Plus tag type ('info' fallback).
const getStatusType = (status: string) => {
  switch (status) {
    case 'completed': return 'success'
    case 'processing': return 'warning'
    case 'pending': return 'info'
    case 'failed': return 'danger'
    default: return 'info'
  }
}
// Map a status code to its Chinese display label; unknown codes pass
// through unchanged.
const getStatusText = (status: string) => {
  const labels = new Map<string, string>([
    ['completed', '已完成'],
    ['processing', '处理中'],
    ['pending', '等待中'],
    ['failed', '失败']
  ])
  return labels.get(status) ?? status
}
// Map a job status to an el-progress status; undefined keeps the
// default (in-progress) appearance.
const getProgressStatus = (status: string) => {
  const terminal: Record<string, 'success' | 'exception'> = {
    completed: 'success',
    failed: 'exception'
  }
  return terminal[status]
}
// Load the document list and map backend fields into table rows.
// Errors are logged only; the table keeps its previous contents.
const fetchDocuments = async () => {
  try {
    const res = await listDocuments({})
    // NOTE(review): assumes res.data is an array of docs with these
    // exact field names — confirm against the kb API response shape.
    tableData.value = res.data.map((doc: any) => ({
      docId: doc.docId,
      name: doc.fileName,
      status: doc.status,
      jobId: doc.jobId,
      createTime: new Date(doc.createdAt).toLocaleString('zh-CN')
    }))
  } catch (error) {
    console.error('Failed to fetch documents:', error)
  }
}
const fetchJobStatus = async (jobId: string) => {
try {
const res = await getIndexJob(jobId)
return res
} catch (error) {
console.error('Failed to fetch job status:', error)
return null
}
}
const handleViewJob = async (row: DocumentItem) => {
if (!row.jobId) {
ElMessage.warning('该文档没有任务ID')
return
}
currentJob.value = await fetchJobStatus(row.jobId)
jobDialogVisible.value = true
}
const handleDelete = async (row: DocumentItem) => {
try {
await ElMessageBox.confirm(`确定要删除文档 "${row.name}" 吗?`, '确认删除', {
confirmButtonText: '确定',
cancelButtonText: '取消',
type: 'warning'
})
await deleteDocument(row.docId)
ElMessage.success('文档删除成功')
fetchDocuments()
} catch (error: any) {
if (error !== 'cancel') {
ElMessage.error(error.message || '删除失败')
}
}
}
const refreshJobStatus = async () => {
if (currentJob.value?.jobId) {
currentJob.value = await fetchJobStatus(currentJob.value.jobId)
}
}
const startPolling = (jobId: string) => {
pollingJobs.value.add(jobId)
if (!pollingInterval) {
pollingInterval = window.setInterval(async () => {
for (const jobId of pollingJobs.value) {
const job = await fetchJobStatus(jobId)
if (job) {
if (job.status === 'completed') {
pollingJobs.value.delete(jobId)
ElMessage.success('文档索引任务已完成')
fetchDocuments()
} else if (job.status === 'failed') {
pollingJobs.value.delete(jobId)
ElMessage.error('文档索引任务失败')
ElMessage.warning(`错误: ${job.errorMsg}`)
}
}
}
if (pollingJobs.value.size === 0 && pollingInterval) {
clearInterval(pollingInterval)
pollingInterval = null
}
}, 3000)
}
}
onMounted(() => {
fetchDocuments()
})
onUnmounted(() => {
if (pollingInterval) {
clearInterval(pollingInterval)
}
})
const fileInput = ref<HTMLInputElement | null>(null)
const handleUploadClick = () => {
fileInput.value?.click()
}
const handleFileChange = async (event: Event) => {
const target = event.target as HTMLInputElement
const file = target.files?.[0]
if (!file) return
const formData = new FormData()
formData.append('file', file)
formData.append('kb_id', 'kb_default')
try {
loading.value = true
const res = await uploadDocument(formData)
ElMessage.success(`文档上传成功任务ID: ${res.jobId}`)
console.log('Upload response:', res)
const newDoc: DocumentItem = {
name: file.name,
status: res.status || 'pending',
jobId: res.jobId,
createTime: new Date().toLocaleString('zh-CN')
}
tableData.value.unshift(newDoc)
startPolling(res.jobId)
} catch (error) {
ElMessage.error('文档上传失败')
console.error('Upload error:', error)
} finally {
loading.value = false
target.value = ''
}
}
</script>
<style scoped>
/* Knowledge-base page layout: header + animated table card. */
.kb-page {
  padding: 24px;
  min-height: calc(100vh - 60px);
}
.page-header {
  margin-bottom: 24px;
}
.header-content {
  display: flex;
  justify-content: space-between;
  align-items: flex-start;
  gap: 20px;
  flex-wrap: wrap;
}
.title-section {
  flex: 1;
  min-width: 200px;
}
.page-title {
  margin: 0 0 8px 0;
  font-size: 24px;
  font-weight: 700;
  color: var(--text-primary);
  letter-spacing: -0.5px;
}
.page-desc {
  margin: 0;
  font-size: 14px;
  color: var(--text-secondary);
  line-height: 1.6;
}
.header-actions {
  display: flex;
  align-items: center;
}
/* Entrance animation for the document table. */
.table-card {
  animation: fadeInUp 0.5s ease-out;
}
@keyframes fadeInUp {
  from {
    opacity: 0;
    transform: translateY(20px);
  }
  to {
    opacity: 1;
    transform: translateY(0);
  }
}
.file-name {
  display: flex;
  align-items: center;
  gap: 10px;
}
.file-icon {
  color: var(--primary-color);
  font-size: 18px;
}
/* Monospace pill for job identifiers. */
.job-id {
  font-family: var(--font-mono);
  font-size: 12px;
  color: var(--text-secondary);
  background-color: var(--bg-tertiary);
  padding: 2px 8px;
  border-radius: 4px;
}
.progress-wrapper {
  width: 100%;
}
/* Deep selectors reach into el-dialog internals (scoped styles can't otherwise). */
.job-dialog :deep(.el-dialog__header) {
  border-bottom: 1px solid var(--border-light);
}
.job-dialog :deep(.el-dialog__body) {
  padding: 24px;
}
/* Mobile: tighter padding, stacked header. */
@media (max-width: 768px) {
  .kb-page {
    padding: 16px;
  }
  .page-title {
    font-size: 20px;
  }
  .header-content {
    flex-direction: column;
  }
  .title-section {
    min-width: 100%;
  }
}
</style>

View File

@ -0,0 +1,191 @@
<template>
  <div class="monitoring-container">
    <!-- Session list with status filter and server-side pagination -->
    <el-card shadow="never">
      <template #header>
        <div class="card-header">
          <span class="title">会话监控 [AC-ASA-09]</span>
          <div class="header-ops">
            <el-form :inline="true" :model="queryParams" class="search-form">
              <el-form-item label="状态">
                <el-select v-model="queryParams.status" placeholder="会话状态" clearable style="width: 120px">
                  <el-option label="活跃" value="active" />
                  <el-option label="已关闭" value="closed" />
                  <el-option label="已过期" value="expired" />
                </el-select>
              </el-form-item>
              <el-form-item>
                <el-button type="primary" @click="handleQuery">查询</el-button>
                <el-button @click="resetQuery">重置</el-button>
              </el-form-item>
            </el-form>
          </div>
        </div>
      </template>
      <!-- Pagination state is two-way bound into queryParams; @pagination re-fetches -->
      <base-table
        :data="tableData"
        :total="total"
        v-model:page-num="queryParams.page"
        v-model:page-size="queryParams.pageSize"
        @pagination="getList"
        v-loading="loading"
      >
        <el-table-column prop="sessionId" label="会话 ID" width="280" show-overflow-tooltip />
        <el-table-column prop="tenantId" label="租户 ID" width="280" show-overflow-tooltip />
        <el-table-column prop="messageCount" label="消息数" width="100" align="center" />
        <el-table-column prop="status" label="状态" width="100" align="center">
          <template #default="scope">
            <!-- statusMap translates raw status codes; unknown codes render as-is -->
            <el-tag :type="statusMap[scope.row.status]?.type" size="small">
              {{ statusMap[scope.row.status]?.label || scope.row.status }}
            </el-tag>
          </template>
        </el-table-column>
        <el-table-column prop="channelType" label="渠道" width="100" />
        <el-table-column prop="startTime" label="开始时间" width="180" />
        <el-table-column label="操作" fixed="right" width="120" align="center">
          <template #default="scope">
            <el-button link type="primary" @click="handleTrace(scope.row)">全链路追踪</el-button>
          </template>
        </el-table-column>
      </base-table>
    </el-card>
    <!-- Full-trace drawer: message timeline + retrieval/tool-call trace -->
    <el-drawer
      v-model="drawerVisible"
      title="会话全链路追踪详情"
      size="50%"
      destroy-on-close
    >
      <div v-loading="detailLoading" class="detail-container">
        <el-empty v-if="!sessionDetail && !detailLoading" description="暂无追踪详情" />
        <div v-else>
          <el-descriptions :column="1" border class="session-info">
            <el-descriptions-item label="会话ID">{{ sessionDetail?.sessionId }}</el-descriptions-item>
            <el-descriptions-item label="消息数">{{ sessionDetail?.messages?.length || 0 }}</el-descriptions-item>
          </el-descriptions>
          <el-divider content-position="left">消息记录</el-divider>
          <el-timeline>
            <el-timeline-item
              v-for="(msg, index) in sessionDetail?.messages"
              :key="index"
              :timestamp="msg.timestamp"
              placement="top"
              :type="msg.role === 'user' ? 'primary' : 'success'"
            >
              <el-card shadow="never" class="msg-card">
                <div class="msg-header">
                  <span class="role-tag" :class="msg.role">{{ msg.role === 'user' ? 'USER' : 'ASSISTANT' }}</span>
                </div>
                <div class="msg-content">{{ msg.content }}</div>
              </el-card>
            </el-timeline-item>
          </el-timeline>
          <!-- Trace sections only render when the backend supplied trace data -->
          <el-divider content-position="left" v-if="sessionDetail?.trace && (sessionDetail.trace.retrieval?.length || sessionDetail.trace.tools?.length)">
            追踪信息
          </el-divider>
          <el-collapse v-if="sessionDetail?.trace">
            <el-collapse-item v-if="sessionDetail.trace.retrieval?.length" title="检索追踪 (Retrieval)" name="retrieval">
              <div v-for="(hit, hIdx) in sessionDetail.trace.retrieval" :key="hIdx" class="hit-item">
                <div class="hit-meta">
                  <el-tag size="small" type="success">Score: {{ hit.score }}</el-tag>
                  <span class="hit-source" v-if="hit.source">来源: {{ hit.source }}</span>
                </div>
                <div class="hit-text">{{ hit.content }}</div>
              </div>
            </el-collapse-item>
            <el-collapse-item v-if="sessionDetail.trace.tools?.length" title="工具调用 (Tool Calls)" name="tools">
              <pre class="code-block"><code>{{ JSON.stringify(sessionDetail.trace.tools, null, 2) }}</code></pre>
            </el-collapse-item>
          </el-collapse>
        </div>
      </div>
    </el-drawer>
  </div>
</template>
<script setup lang="ts">
import { ref, reactive, onMounted } from 'vue'
import BaseTable from '@/components/BaseTable.vue'
import { listSessions, getSessionDetail } from '@/api/monitoring'

// Translates backend session status codes into display labels and tag colours.
const statusMap: Record<string, { label: string, type: string }> = {
  active: { label: '活跃', type: 'success' },
  closed: { label: '已关闭', type: 'info' },
  expired: { label: '已过期', type: 'warning' }
}

// List state.
const loading = ref(false)
const tableData = ref([])
const total = ref(0)
const queryParams = reactive({ page: 1, pageSize: 10, status: '' })

// Trace-drawer state.
const drawerVisible = ref(false)
const detailLoading = ref(false)
const sessionDetail = ref<any>(null)

/** Fetch one page of sessions using the current query parameters. */
async function getList() {
  loading.value = true
  try {
    const res: any = await listSessions(queryParams)
    tableData.value = res.data || []
    total.value = res.pagination?.total || 0
  } catch (error) {
    console.error('Failed to fetch sessions:', error)
  } finally {
    loading.value = false
  }
}

/** Run a search from page 1. */
function handleQuery() {
  queryParams.page = 1
  getList()
}

/** Clear the status filter and re-query. */
function resetQuery() {
  queryParams.status = ''
  handleQuery()
}

/** Open the drawer and load the full trace for the clicked session row. */
async function handleTrace(row: any) {
  drawerVisible.value = true
  detailLoading.value = true
  try {
    sessionDetail.value = await getSessionDetail(row.sessionId)
  } catch (error) {
    console.error('Failed to fetch session detail:', error)
  } finally {
    detailLoading.value = false
  }
}

onMounted(getList)
</script>
<style scoped>
/* Session-monitoring page: list card plus trace-drawer detail styling. */
.monitoring-container { padding: 20px; }
.card-header { display: flex; justify-content: space-between; align-items: center; }
.title { font-size: 16px; font-weight: bold; }
.detail-container { padding: 10px 20px; }
.session-info { margin-bottom: 20px; }
.msg-card { border-radius: 8px; margin-bottom: 10px; }
.msg-header { margin-bottom: 8px; }
/* Role badges: blue for user, green for assistant. */
.role-tag { font-size: 11px; font-weight: bold; padding: 2px 6px; border-radius: 4px; }
.role-tag.user { background-color: #ecf5ff; color: #409eff; }
.role-tag.assistant { background-color: #f0f9eb; color: #67c23a; }
.msg-content { font-size: 14px; line-height: 1.6; color: #333; white-space: pre-wrap; }
/* Retrieval-hit cards inside the trace collapse panel. */
.hit-item { padding: 10px; background-color: #f8f9fa; border-radius: 4px; margin-bottom: 8px; }
.hit-meta { display: flex; justify-content: space-between; align-items: center; margin-bottom: 6px; }
.hit-source { font-size: 11px; color: #999; }
.hit-text { font-size: 12px; color: #666; line-height: 1.5; }
.code-block { background-color: #fafafa; border: 1px solid #eaeaea; padding: 8px; border-radius: 4px; font-size: 12px; overflow-x: auto; margin: 0; }
</style>

View File

@ -0,0 +1,545 @@
<template>
  <div class="rag-lab-page">
    <div class="page-header">
      <h1 class="page-title">RAG 实验室</h1>
      <p class="page-desc">测试检索增强生成效果查看检索结果和 AI 响应</p>
    </div>
    <el-row :gutter="24">
      <!-- Left column: experiment inputs (query, KB scope, LLM, parameters) -->
      <el-col :xs="24" :sm="24" :md="10" :lg="10">
        <el-card shadow="hover" class="input-card">
          <template #header>
            <div class="card-header">
              <div class="header-left">
                <div class="icon-wrapper">
                  <el-icon><Edit /></el-icon>
                </div>
                <span class="header-title">调试输入</span>
              </div>
            </div>
          </template>
          <el-form label-position="top">
            <el-form-item label="查询 Query">
              <el-input
                v-model="query"
                type="textarea"
                :rows="4"
                placeholder="输入测试问题..."
              />
            </el-form-item>
            <el-form-item label="知识库范围">
              <el-select
                v-model="kbIds"
                multiple
                placeholder="请选择知识库"
                style="width: 100%"
                :loading="kbLoading"
                :teleported="true"
                :popper-options="{ modifiers: [{ name: 'flip', enabled: true }, { name: 'preventOverflow', enabled: true }] }"
              >
                <el-option
                  v-for="kb in knowledgeBases"
                  :key="kb.id"
                  :label="`${kb.name} (${kb.documentCount}个文档)`"
                  :value="kb.id"
                />
              </el-select>
            </el-form-item>
            <el-form-item label="LLM 模型">
              <LLMSelector
                v-model="llmProvider"
                :providers="llmProviders"
                :loading="llmLoading"
                :current-provider="currentLLMProvider"
                placeholder="使用默认配置"
                clearable
                @change="handleLLMChange"
              />
            </el-form-item>
            <el-form-item label="参数配置">
              <div class="param-item">
                <span class="label">Top-K</span>
                <el-input-number v-model="topK" :min="1" :max="10" />
              </div>
              <div class="param-item">
                <span class="label">Score Threshold</span>
                <el-slider
                  v-model="scoreThreshold"
                  :min="0"
                  :max="1"
                  :step="0.1"
                  show-input
                />
              </div>
              <div class="param-item">
                <span class="label">生成 AI 回复</span>
                <el-switch v-model="generateResponse" />
              </div>
              <!-- Streaming toggle only makes sense when generation is on -->
              <div class="param-item" v-if="generateResponse">
                <span class="label">流式输出</span>
                <el-switch v-model="streamOutput" />
              </div>
            </el-form-item>
            <!-- NOTE(review): Element Plus el-button has no `block` prop — confirm this attribute is intended -->
            <el-button
              type="primary"
              block
              @click="handleRun"
              :loading="loading || streaming"
            >
              {{ streaming ? '生成中...' : '运行实验' }}
            </el-button>
            <el-button
              v-if="streaming"
              type="danger"
              block
              @click="handleStopStream"
              style="margin-top: 10px;"
            >
              停止生成
            </el-button>
          </el-form>
        </el-card>
      </el-col>
      <!-- Right column: tabbed results (retrieval hits, prompt, AI answer, diagnostics) -->
      <el-col :xs="24" :sm="24" :md="14" :lg="14">
        <el-tabs v-model="activeTab" type="border-card" class="result-tabs">
          <el-tab-pane label="召回片段" name="retrieval">
            <div v-if="retrievalResults.length === 0" class="placeholder-text">
              暂无实验数据
            </div>
            <div v-else class="result-list">
              <el-card
                v-for="(item, index) in retrievalResults"
                :key="index"
                class="result-card"
                shadow="never"
              >
                <div class="result-header">
                  <el-tag size="small" type="primary">Score: {{ item.score.toFixed(4) }}</el-tag>
                  <span class="source">来源: {{ item.source }}</span>
                </div>
                <div class="result-content">{{ item.content }}</div>
              </el-card>
            </div>
          </el-tab-pane>
          <el-tab-pane label="最终 Prompt" name="prompt">
            <div v-if="!finalPrompt" class="placeholder-text">
              等待实验运行...
            </div>
            <div v-else class="prompt-view">
              <pre><code>{{ finalPrompt }}</code></pre>
            </div>
          </el-tab-pane>
          <!-- Streamed and non-streamed answers use different viewer components -->
          <el-tab-pane label="AI 回复" name="ai-response" v-if="generateResponse">
            <StreamOutput
              v-if="streamOutput"
              :content="streamContent"
              :is-streaming="streaming"
              :error="streamError"
            />
            <AIResponseViewer
              v-else
              :response="aiResponse"
            />
          </el-tab-pane>
          <el-tab-pane label="诊断信息" name="diagnostics">
            <div v-if="!diagnostics" class="placeholder-text">
              等待实验运行...
            </div>
            <div v-else class="diagnostics-view">
              <pre><code>{{ JSON.stringify(diagnostics, null, 2) }}</code></pre>
            </div>
          </el-tab-pane>
        </el-tabs>
      </el-col>
    </el-row>
  </div>
</template>
<script setup lang="ts">
// RAG Lab page logic: runs retrieval experiments (optionally with LLM
// generation, optionally streamed over SSE) and exposes results to the tabs.
import { ref, onMounted } from 'vue'
import { ElMessage } from 'element-plus'
import { Edit } from '@element-plus/icons-vue'
import { runRagExperiment, createSSEConnection, type AIResponse, type RetrievalResult } from '@/api/rag'
import { getLLMProviders, getLLMConfig, type LLMProviderInfo } from '@/api/llm'
import { listKnowledgeBases } from '@/api/kb'
import { useRagLabStore } from '@/stores/ragLab'
import { storeToRefs } from 'pinia'
import AIResponseViewer from '@/components/rag/AIResponseViewer.vue'
import StreamOutput from '@/components/rag/StreamOutput.vue'
import LLMSelector from '@/components/rag/LLMSelector.vue'

/** Knowledge-base option shown in the scope selector. */
interface KnowledgeBase {
  id: string
  name: string
  documentCount: number
}

// Experiment inputs live in a Pinia store so they survive navigation;
// storeToRefs keeps them reactive when destructured.
const ragLabStore = useRagLabStore()
const {
  query,
  kbIds,
  llmProvider,
  topK,
  scoreThreshold,
  generateResponse,
  streamOutput
} = storeToRefs(ragLabStore)

// Local UI state (loading flags, selector data, results).
const loading = ref(false)
const kbLoading = ref(false)
const llmLoading = ref(false)
const streaming = ref(false)
const activeTab = ref('retrieval')
const knowledgeBases = ref<KnowledgeBase[]>([])
const llmProviders = ref<LLMProviderInfo[]>([])
const currentLLMProvider = ref('')
const retrievalResults = ref<RetrievalResult[]>([])
const finalPrompt = ref('')
const aiResponse = ref<AIResponse | null>(null)
const diagnostics = ref<any>(null)
const streamContent = ref('')
const streamError = ref<string | null>(null)
const totalLatencyMs = ref(0)
// Abort handle for the active SSE connection, if any.
// NOTE(review): not reset to null on natural completion/error — harmless only
// if aborting a finished connection is a no-op; confirm in createSSEConnection.
let abortStream: (() => void) | null = null

/** Populate the knowledge-base selector. */
const fetchKnowledgeBases = async () => {
  kbLoading.value = true
  try {
    const res: any = await listKnowledgeBases()
    knowledgeBases.value = res.data || []
  } catch (error) {
    console.error('Failed to fetch knowledge bases:', error)
  } finally {
    kbLoading.value = false
  }
}

/** Load available LLM providers and the currently configured default in parallel. */
const fetchLLMProviders = async () => {
  llmLoading.value = true
  try {
    const [providersRes, configRes]: [any, any] = await Promise.all([
      getLLMProviders(),
      getLLMConfig()
    ])
    llmProviders.value = providersRes?.providers || []
    currentLLMProvider.value = configRes?.provider || ''
  } catch (error) {
    console.error('Failed to fetch LLM providers:', error)
  } finally {
    llmLoading.value = false
  }
}

/** Sync the selector's chosen provider name into the store ('' = default). */
const handleLLMChange = (provider: LLMProviderInfo | undefined) => {
  llmProvider.value = provider?.name || ''
}

/** Entry point: validate, clear previous results, then run in the chosen mode. */
const handleRun = async () => {
  if (!query.value.trim()) {
    ElMessage.warning('请输入查询 Query')
    return
  }
  clearResults()
  if (streamOutput.value && generateResponse.value) {
    await runStreamExperiment()
  } else {
    await runNormalExperiment()
  }
}

/** One-shot (non-streaming) experiment; accepts both snake_case and camelCase responses. */
const runNormalExperiment = async () => {
  loading.value = true
  try {
    const res: any = await runRagExperiment({
      query: query.value,
      kb_ids: kbIds.value,
      top_k: topK.value,
      score_threshold: scoreThreshold.value,
      llm_provider: llmProvider.value || undefined,
      generate_response: generateResponse.value
    })
    retrievalResults.value = res.retrieval_results || res.retrievalResults || []
    finalPrompt.value = res.final_prompt || res.finalPrompt || ''
    aiResponse.value = res.ai_response || res.aiResponse || null
    diagnostics.value = res.diagnostics || null
    totalLatencyMs.value = res.total_latency_ms || res.totalLatencyMs || 0
    // Jump to the most relevant tab for what was requested.
    if (generateResponse.value) {
      activeTab.value = 'ai-response'
    } else {
      activeTab.value = 'retrieval'
    }
    ElMessage.success('实验运行成功')
  } catch (err: any) {
    console.error(err)
    ElMessage.error(err?.message || '实验运行失败')
  } finally {
    loading.value = false
  }
}

/**
 * Streaming experiment over SSE. The server sends typed JSON events:
 * 'retrieval' (hits), 'prompt' (final prompt), 'content' (answer delta),
 * 'complete' (token/latency summary), 'error'. Non-JSON payloads are
 * appended to the answer verbatim.
 */
const runStreamExperiment = async () => {
  streaming.value = true
  streamContent.value = ''
  streamError.value = null
  activeTab.value = 'ai-response'
  abortStream = createSSEConnection(
    '/admin/rag/experiments/stream',
    {
      query: query.value,
      kb_ids: kbIds.value,
      top_k: topK.value,
      score_threshold: scoreThreshold.value,
      llm_provider: llmProvider.value || undefined,
      generate_response: true
    },
    (data: string) => {
      try {
        const parsed = JSON.parse(data)
        if (parsed.type === 'content') {
          streamContent.value += parsed.content || ''
        } else if (parsed.type === 'retrieval') {
          retrievalResults.value = parsed.results || []
        } else if (parsed.type === 'prompt') {
          finalPrompt.value = parsed.prompt || ''
        } else if (parsed.type === 'complete') {
          // Assemble the final AIResponse from the accumulated text plus
          // the usage/latency fields carried by the 'complete' event.
          aiResponse.value = {
            content: streamContent.value,
            prompt_tokens: parsed.prompt_tokens,
            completion_tokens: parsed.completion_tokens,
            total_tokens: parsed.total_tokens,
            latency_ms: parsed.latency_ms,
            model: parsed.model
          }
          totalLatencyMs.value = parsed.total_latency_ms || 0
          streaming.value = false
          ElMessage.success('生成完成')
        } else if (parsed.type === 'error') {
          streamError.value = parsed.message || '流式输出错误'
          streaming.value = false
          ElMessage.error(streamError.value)
        }
      } catch {
        // Not JSON: treat the raw chunk as answer text.
        streamContent.value += data
      }
    },
    (error: Error) => {
      // Transport-level error callback.
      streaming.value = false
      streamError.value = error.message
      ElMessage.error(error.message)
    },
    () => {
      // Connection-closed callback.
      streaming.value = false
    }
  )
}

/** Abort the in-flight SSE stream on user request. */
const handleStopStream = () => {
  if (abortStream) {
    abortStream()
    abortStream = null
  }
  streaming.value = false
  ElMessage.info('已停止生成')
}

/** Reset every result pane before a new run. */
const clearResults = () => {
  retrievalResults.value = []
  finalPrompt.value = ''
  aiResponse.value = null
  diagnostics.value = null
  streamContent.value = ''
  streamError.value = null
  totalLatencyMs.value = 0
}

onMounted(() => {
  fetchKnowledgeBases()
  fetchLLMProviders()
})
</script>
<style scoped>
/* RAG Lab page: two-column layout — input card (left) and result tabs (right). */
.rag-lab-page {
  padding: 24px;
  min-height: calc(100vh - 60px);
}
.page-header {
  margin-bottom: 24px;
}
.page-title {
  margin: 0 0 8px 0;
  font-size: 24px;
  font-weight: 700;
  color: var(--text-primary);
  letter-spacing: -0.5px;
}
.page-desc {
  margin: 0;
  font-size: 14px;
  color: var(--text-secondary);
  line-height: 1.6;
}
/* Entrance animations (input card slightly ahead of the tabs). */
.input-card {
  animation: fadeInUp 0.5s ease-out;
}
@keyframes fadeInUp {
  from {
    opacity: 0;
    transform: translateY(20px);
  }
  to {
    opacity: 1;
    transform: translateY(0);
  }
}
.card-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
  padding: 0;
}
.header-left {
  display: flex;
  align-items: center;
  gap: 12px;
}
.icon-wrapper {
  width: 36px;
  height: 36px;
  display: flex;
  align-items: center;
  justify-content: center;
  background-color: var(--primary-lighter);
  border-radius: 10px;
  color: var(--primary-color);
  font-size: 18px;
}
.header-title {
  font-size: 15px;
  font-weight: 600;
  color: var(--text-primary);
}
/* One labelled parameter control per row. */
.param-item {
  display: flex;
  align-items: center;
  margin-bottom: 16px;
  gap: 16px;
}
.param-item .label {
  width: 140px;
  font-size: 13px;
  font-weight: 500;
  color: var(--text-secondary);
  flex-shrink: 0;
}
.param-item :deep(.el-slider) {
  flex: 1;
}
.result-tabs {
  animation: fadeInUp 0.6s ease-out;
}
.result-tabs :deep(.el-tabs__header) {
  border-radius: 12px 12px 0 0;
}
.placeholder-text {
  color: var(--text-tertiary);
  text-align: center;
  padding: 60px 20px;
  font-size: 14px;
}
/* Scrollable list of retrieval hits. */
.result-list {
  max-height: 600px;
  overflow-y: auto;
  padding-right: 8px;
}
.result-card {
  margin-bottom: 16px;
  border: 1px solid var(--border-color);
}
.result-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
  margin-bottom: 12px;
}
.source {
  font-size: 12px;
  color: var(--text-tertiary);
}
.result-content {
  font-size: 14px;
  line-height: 1.7;
  color: var(--text-primary);
}
/* Monospace panels for the final prompt and diagnostics JSON. */
.prompt-view,
.diagnostics-view {
  background-color: var(--bg-tertiary);
  padding: 16px;
  border-radius: 10px;
  max-height: 600px;
  overflow-y: auto;
}
.prompt-view pre,
.diagnostics-view pre {
  margin: 0;
  white-space: pre-wrap;
  word-wrap: break-word;
  font-family: var(--font-mono);
  font-size: 13px;
  line-height: 1.6;
  color: var(--text-primary);
}
/* Mobile: stack parameter labels above their controls. */
@media (max-width: 768px) {
  .rag-lab-page {
    padding: 16px;
  }
  .page-title {
    font-size: 20px;
  }
  .param-item {
    flex-direction: column;
    align-items: flex-start;
    gap: 8px;
  }
  .param-item .label {
    width: 100%;
  }
}
</style>

View File

@ -0,0 +1,22 @@
{
"compilerOptions": {
"target": "ESNext",
"useDefineForClassFields": true,
"module": "ESNext",
"moduleResolution": "Node",
"strict": true,
"jsx": "preserve",
"resolveJsonModule": true,
"isolatedModules": true,
"esModuleInterop": true,
"lib": ["ESNext", "DOM"],
"skipLibCheck": true,
"noEmit": true,
"baseUrl": ".",
"paths": {
"@/*": ["src/*"]
}
},
"include": ["src/**/*.ts", "src/**/*.d.ts", "src/**/*.tsx", "src/**/*.vue"],
"references": [{ "path": "./tsconfig.node.json" }]
}

View File

@ -0,0 +1,10 @@
{
"compilerOptions": {
"composite": true,
"skipLibCheck": true,
"module": "ESNext",
"moduleResolution": "Node",
"allowSyntheticDefaultImports": true
},
"include": ["vite.config.ts"]
}

View File

@ -0,0 +1,23 @@
import { defineConfig } from 'vite'
import vue from '@vitejs/plugin-vue'
import path from 'path'

// https://vitejs.dev/config/
// Dev server runs on :3000 and proxies /api/* to the FastAPI backend on :8000.
export default defineConfig({
  plugins: [vue()],
  resolve: {
    alias: {
      // '@/x' resolves to './src/x' (mirrored in tsconfig "paths").
      '@': path.resolve(__dirname, './src'),
    },
  },
  server: {
    port: 3000,
    proxy: {
      '/api': {
        target: 'http://localhost:8000',
        changeOrigin: true,
        // Strip the /api prefix before forwarding. The parameter is named
        // `requestPath` so it no longer shadows the imported Node `path` module.
        rewrite: (requestPath) => requestPath.replace(/^\/api/, ''),
      },
    },
  },
})

74
ai-service/README.md Normal file
View File

@ -0,0 +1,74 @@
# AI Service
Python AI Service for intelligent chat with RAG support.
## Features
- Multi-tenant isolation via X-Tenant-Id header
- SSE streaming support via Accept: text/event-stream
- RAG-powered responses with confidence scoring
## Prerequisites
- PostgreSQL 12+
- Qdrant vector database
- Python 3.10+
## Installation
```bash
pip install -e ".[dev]"
```
## Database Initialization
### Option 1: Using Python script (Recommended)
```bash
# Create database and tables
python scripts/init_db.py --create-db
# Or just create tables (database must exist)
python scripts/init_db.py
```
### Option 2: Using SQL script
```bash
# Connect to PostgreSQL and run
psql -U postgres -f scripts/init_db.sql
```
## Configuration
Create a `.env` file in the project root:
```env
AI_SERVICE_DATABASE_URL=postgresql+asyncpg://postgres:password@localhost:5432/ai_service
AI_SERVICE_QDRANT_URL=http://localhost:6333
AI_SERVICE_LLM_API_KEY=your-api-key
AI_SERVICE_LLM_BASE_URL=https://api.openai.com/v1
AI_SERVICE_LLM_MODEL=gpt-4o-mini
AI_SERVICE_DEBUG=true
```
## Running
```bash
uvicorn app.main:app --host 0.0.0.0 --port 8000
```
## API Endpoints
### Chat API
- `POST /ai/chat` - Generate AI reply (supports SSE streaming)
- `GET /ai/health` - Health check
### Admin API
- `GET /admin/kb/documents` - List documents
- `POST /admin/kb/documents` - Upload document
- `GET /admin/kb/index/jobs/{jobId}` - Get indexing job status
- `DELETE /admin/kb/documents/{docId}` - Delete document
- `POST /admin/rag/experiments/run` - Run RAG experiment
- `GET /admin/sessions` - List chat sessions
- `GET /admin/sessions/{sessionId}` - Get session details

View File

@ -0,0 +1,4 @@
"""
AI Service - Python AI Middle Platform
[AC-AISVC-01] FastAPI-based AI chat service with multi-tenant support.
"""

View File

@ -0,0 +1,8 @@
"""
API module for AI Service.
"""
from app.api.chat import router as chat_router
from app.api.health import router as health_router
__all__ = ["chat_router", "health_router"]

View File

@ -0,0 +1,14 @@
"""
Admin API routes for AI Service management.
[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08] Admin management endpoints.
"""
from app.api.admin.dashboard import router as dashboard_router
from app.api.admin.embedding import router as embedding_router
from app.api.admin.kb import router as kb_router
from app.api.admin.llm import router as llm_router
from app.api.admin.rag import router as rag_router
from app.api.admin.sessions import router as sessions_router
from app.api.admin.tenants import router as tenants_router
__all__ = ["dashboard_router", "embedding_router", "kb_router", "llm_router", "rag_router", "sessions_router", "tenants_router"]

View File

@ -0,0 +1,202 @@
"""
Dashboard statistics endpoints.
Provides overview statistics for the admin dashboard.
"""
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from fastapi.responses import JSONResponse
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.models.entities import ChatMessage, ChatSession, Document, KnowledgeBase
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/dashboard", tags=["Dashboard"])
LATENCY_THRESHOLD_MS = 5000
def get_current_tenant_id() -> str:
    """Dependency to get current tenant ID or raise exception."""
    current = get_tenant_id()
    if current:
        return current
    raise MissingTenantIdException()
@router.get(
    "/stats",
    operation_id="getDashboardStats",
    summary="Get dashboard statistics",
    description="Get overview statistics for the admin dashboard.",
    responses={
        200: {"description": "Dashboard statistics"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def get_dashboard_stats(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    latency_threshold: int = Query(default=LATENCY_THRESHOLD_MS, description="Latency threshold in ms"),
) -> JSONResponse:
    """
    Get dashboard statistics including knowledge bases, messages, and activity.

    All metrics are scoped to the calling tenant. Aggregates that share a
    filter are fetched in a single round trip (the original issued 13
    sequential queries — one per metric).

    Note: percentile_cont/within_group is an ordered-set aggregate and
    requires PostgreSQL (or a dialect that supports it).
    """
    logger.info(f"Getting dashboard stats: tenant={tenant_id}")

    async def _count(model, *criteria) -> int:
        """Tenant-scoped COUNT(*) over `model` with optional extra filters."""
        stmt = (
            select(func.count())
            .select_from(model)
            .where(model.tenant_id == tenant_id, *criteria)
        )
        return (await session.execute(stmt)).scalar() or 0

    kb_count = await _count(KnowledgeBase)
    msg_count = await _count(ChatMessage)
    doc_count = await _count(Document)
    session_count = await _count(ChatSession)
    ai_requests_count = await _count(ChatMessage, ChatMessage.role == "assistant")
    slow_requests_count = await _count(
        ChatMessage,
        ChatMessage.role == "assistant",
        ChatMessage.latency_ms.isnot(None),
        ChatMessage.latency_ms >= latency_threshold,
    )
    error_requests_count = await _count(
        ChatMessage,
        ChatMessage.role == "assistant",
        ChatMessage.is_error.is_(True),
    )

    # All token sums in one query — they share the same tenant filter.
    token_totals_stmt = select(
        func.coalesce(func.sum(ChatMessage.total_tokens), 0),
        func.coalesce(func.sum(ChatMessage.prompt_tokens), 0),
        func.coalesce(func.sum(ChatMessage.completion_tokens), 0),
    ).where(ChatMessage.tenant_id == tenant_id)
    total_tokens, prompt_tokens, completion_tokens = (
        await session.execute(token_totals_stmt)
    ).one()

    # All latency aggregates in one query — same assistant + non-null filter.
    latency_stmt = select(
        func.coalesce(func.avg(ChatMessage.latency_ms), 0),
        func.coalesce(func.min(ChatMessage.latency_ms), 0),
        func.coalesce(func.max(ChatMessage.latency_ms), 0),
        func.coalesce(func.percentile_cont(0.95).within_group(ChatMessage.latency_ms), 0),
        func.coalesce(func.percentile_cont(0.99).within_group(ChatMessage.latency_ms), 0),
    ).where(
        ChatMessage.tenant_id == tenant_id,
        ChatMessage.role == "assistant",
        ChatMessage.latency_ms.isnot(None),
    )
    avg_latency_ms, min_latency_ms, max_latency_ms, p95_latency_ms, p99_latency_ms = (
        float(value or 0) for value in (await session.execute(latency_stmt)).one()
    )

    # Latency and timestamp of the most recent assistant message, if any.
    last_request_stmt = (
        select(ChatMessage.latency_ms, ChatMessage.created_at)
        .where(
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.role == "assistant",
        )
        .order_by(desc(ChatMessage.created_at))
        .limit(1)
    )
    last_request_row = (await session.execute(last_request_stmt)).fetchone()
    last_latency_ms = last_request_row[0] if last_request_row else None
    last_request_time = (
        last_request_row[1].isoformat()
        if last_request_row and last_request_row[1]
        else None
    )

    return JSONResponse(
        content={
            "knowledgeBases": kb_count,
            "totalMessages": msg_count,
            "totalDocuments": doc_count,
            "totalSessions": session_count,
            "totalTokens": total_tokens,
            "promptTokens": prompt_tokens,
            "completionTokens": completion_tokens,
            "aiRequestsCount": ai_requests_count,
            "avgLatencyMs": round(avg_latency_ms, 2),
            "lastLatencyMs": last_latency_ms,
            "lastRequestTime": last_request_time,
            "slowRequestsCount": slow_requests_count,
            "errorRequestsCount": error_requests_count,
            "p95LatencyMs": round(p95_latency_ms, 2),
            "p99LatencyMs": round(p99_latency_ms, 2),
            "minLatencyMs": round(min_latency_ms, 2),
            "maxLatencyMs": round(max_latency_ms, 2),
            "latencyThresholdMs": latency_threshold,
        }
    )

View File

@ -0,0 +1,132 @@
"""
Embedding management API endpoints.
[AC-AISVC-38, AC-AISVC-39, AC-AISVC-40, AC-AISVC-41] Embedding model management.
"""
import logging
from typing import Any
from fastapi import APIRouter, Depends, Header, HTTPException
from app.core.exceptions import InvalidRequestException
from app.services.embedding import (
EmbeddingException,
EmbeddingProviderFactory,
get_embedding_config_manager,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/embedding", tags=["Embedding Management"])
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
    """FastAPI dependency: read the tenant ID from the ``X-Tenant-Id`` header.

    Raises HTTPException(400) when the header is present but empty.
    """
    if x_tenant_id:
        return x_tenant_id
    raise HTTPException(status_code=400, detail="X-Tenant-Id header is required")
@router.get("/providers")
async def list_embedding_providers(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Get available embedding providers.
    [AC-AISVC-38] Returns all registered providers with config schemas.
    """
    return {
        "providers": [
            EmbeddingProviderFactory.get_provider_info(name)
            for name in EmbeddingProviderFactory.get_available_providers()
        ]
    }
@router.get("/config")
async def get_embedding_config(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Get current embedding configuration.
    [AC-AISVC-39] Returns the active provider name and its config.
    """
    return get_embedding_config_manager().get_full_config()
@router.put("/config")
async def update_embedding_config(
    request: dict[str, Any],
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Update embedding configuration.
    [AC-AISVC-40, AC-AISVC-31] Updates config with hot-reload support.

    Body: {"provider": <name>, "config": {...}}.

    Raises:
        InvalidRequestException: when the provider is missing/unknown or
            the config manager rejects the new configuration.
    """
    provider = request.get("provider")
    config = request.get("config", {})
    if not provider:
        raise InvalidRequestException("provider is required")
    available = EmbeddingProviderFactory.get_available_providers()
    if provider not in available:
        raise InvalidRequestException(
            f"Unknown provider: {provider}. "
            f"Available: {available}"
        )
    manager = get_embedding_config_manager()
    try:
        await manager.update_config(provider, config)
        return {
            "success": True,
            "message": f"Configuration updated to use {provider}",
        }
    except EmbeddingException as e:
        # Chain the original error so the root cause survives in tracebacks
        # instead of being silently replaced by the API-level exception.
        raise InvalidRequestException(str(e)) from e
@router.post("/test")
async def test_embedding(
    request: dict[str, Any] | None = None,
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Test embedding connection.
    [AC-AISVC-41] Probes provider connectivity and reports dimension info.
    """
    body = request or {}
    manager = get_embedding_config_manager()
    return await manager.test_connection(
        test_text=body.get("test_text", "这是一个测试文本"),
        provider=body.get("provider"),
        config=body.get("config"),
    )
@router.get("/formats")
async def get_supported_document_formats(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Get supported document formats for embedding.
    Returns the accepted file extensions plus per-parser details.
    """
    from app.services.document import get_supported_document_formats, DocumentParserFactory
    return {
        "formats": get_supported_document_formats(),
        "parsers": DocumentParserFactory.get_parser_info(),
    }

View File

@ -0,0 +1,593 @@
"""
Knowledge Base management endpoints.
[AC-ASA-01, AC-ASA-02, AC-ASA-08] Document upload, list, and index job status.
"""
import logging
import os
import uuid
from dataclasses import dataclass
from typing import Annotated, Optional
import tiktoken
from fastapi import APIRouter, BackgroundTasks, Depends, Query, UploadFile, File, Form
from fastapi.responses import JSONResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.models.entities import DocumentStatus, IndexJob, IndexJobStatus
from app.services.kb import KBService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/kb", tags=["KB Management"])
@dataclass
class TextChunk:
    """A retrievable unit of text plus provenance metadata."""
    # Raw text content of the chunk.
    text: str
    # For token-based chunking: start/end offsets in the token stream.
    # For line-based chunking these hold the line index and line index + 1.
    start_token: int
    end_token: int
    # Page number when the chunk came from a paginated document (e.g. PDF);
    # None otherwise.
    page: int | None = None
    # Originating file name/path, when known.
    source: str | None = None
def chunk_text_by_lines(
    text: str,
    min_line_length: int = 10,
    source: str | None = None,
) -> list[TextChunk]:
    """Split text into one chunk per line.

    Each sufficiently long line becomes an independent retrieval unit;
    lines whose stripped length is below ``min_line_length`` are skipped.
    ``start_token``/``end_token`` carry the line index and line index + 1.

    Args:
        text: Text to split.
        min_line_length: Lines shorter than this (after stripping) are dropped.
        source: Optional originating file path.

    Returns:
        One TextChunk per kept line.
    """
    return [
        TextChunk(
            text=stripped,
            start_token=line_no,
            end_token=line_no + 1,
            page=None,
            source=source,
        )
        for line_no, raw_line in enumerate(text.split('\n'))
        if len(stripped := raw_line.strip()) >= min_line_length
    ]
def chunk_text_with_tiktoken(
    text: str,
    chunk_size: int = 512,
    overlap: int = 100,
    page: int | None = None,
    source: str | None = None,
) -> list[TextChunk]:
    """Split text into overlapping chunks of at most ``chunk_size`` tokens.

    Token counts use tiktoken's cl100k_base encoding.

    Args:
        text: Text to split.
        chunk_size: Maximum tokens per chunk; must be positive.
        overlap: Tokens shared between consecutive chunks.
        page: Optional page number copied onto every chunk.
        source: Optional originating file path copied onto every chunk.

    Returns:
        Chunks with their start/end token offsets.

    Raises:
        ValueError: If chunk_size is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    # Guard against overlap >= chunk_size: the previous implementation
    # advanced the window by chunk_size - overlap, which is <= 0 in that
    # case and loops forever on any text longer than one chunk.
    step = max(1, chunk_size - overlap)
    encoding = tiktoken.get_encoding("cl100k_base")
    tokens = encoding.encode(text)
    chunks: list[TextChunk] = []
    start = 0
    while start < len(tokens):
        end = min(start + chunk_size, len(tokens))
        chunks.append(TextChunk(
            text=encoding.decode(tokens[start:end]),
            start_token=start,
            end_token=end,
            page=page,
            source=source,
        ))
        if end == len(tokens):
            break
        start += step
    return chunks
def get_current_tenant_id() -> str:
    """Dependency resolving the tenant ID from the request context.

    Raises MissingTenantIdException when no tenant is bound.
    """
    tenant_id = get_tenant_id()
    if tenant_id:
        return tenant_id
    raise MissingTenantIdException()
@router.get(
    "/knowledge-bases",
    operation_id="listKnowledgeBases",
    summary="Query knowledge base list",
    description="Get list of knowledge bases for the current tenant.",
    responses={
        200: {"description": "Knowledge base list"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def list_knowledge_bases(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
) -> JSONResponse:
    """
    List all knowledge bases for the current tenant.

    Returns {"data": [...]} where each entry carries the KB id, name,
    document count, and creation timestamp.
    """
    logger.info(f"Listing knowledge bases: tenant={tenant_id}")
    kb_service = KBService(session)
    knowledge_bases = await kb_service.list_knowledge_bases(tenant_id)
    kb_ids = [str(kb.id) for kb in knowledge_bases]
    # Count documents per KB in a single grouped query instead of one
    # query per knowledge base.
    doc_counts = {}
    if kb_ids:
        from sqlalchemy import func
        from app.models.entities import Document
        count_stmt = (
            select(Document.kb_id, func.count(Document.id).label("count"))
            .where(Document.tenant_id == tenant_id, Document.kb_id.in_(kb_ids))
            .group_by(Document.kb_id)
        )
        count_result = await session.execute(count_stmt)
        for row in count_result:
            doc_counts[row.kb_id] = row.count
    data = []
    for kb in knowledge_bases:
        # Document.kb_id was matched against the stringified KB id above,
        # so the counts lookup must use the string form as well.
        kb_id_str = str(kb.id)
        data.append({
            "id": kb_id_str,
            "name": kb.name,
            "documentCount": doc_counts.get(kb_id_str, 0),
            # created_at has no timezone info; the "Z" suffix presents it
            # as UTC — TODO confirm the column is actually stored in UTC.
            "createdAt": kb.created_at.isoformat() + "Z",
        })
    return JSONResponse(content={"data": data})
@router.get(
    "/documents",
    operation_id="listDocuments",
    summary="Query document list",
    description="[AC-ASA-08] Get list of documents with pagination and filtering.",
    responses={
        200: {"description": "Document list with pagination"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def list_documents(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    kb_id: Annotated[Optional[str], Query()] = None,
    status: Annotated[Optional[str], Query()] = None,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
) -> JSONResponse:
    """
    [AC-ASA-08] List documents with filtering and pagination.

    Each row is enriched with the id of its most recent index job (if any).
    """
    logger.info(
        f"[AC-ASA-08] Listing documents: tenant={tenant_id}, kb_id={kb_id}, "
        f"status={status}, page={page}, page_size={page_size}"
    )
    kb_service = KBService(session)
    documents, total = await kb_service.list_documents(
        tenant_id=tenant_id,
        kb_id=kb_id,
        status=status,
        page=page,
        page_size=page_size,
    )
    total_pages = (total + page_size - 1) // page_size if total > 0 else 0
    # Fetch the newest index job for every document on this page in ONE
    # query; the previous implementation issued one query per document
    # (classic N+1).
    latest_jobs: dict = {}
    doc_ids = [doc.id for doc in documents]
    if doc_ids:
        jobs_stmt = (
            select(IndexJob)
            .where(IndexJob.tenant_id == tenant_id, IndexJob.doc_id.in_(doc_ids))
            .order_by(IndexJob.created_at.desc())
        )
        jobs_result = await session.execute(jobs_stmt)
        for job in jobs_result.scalars():
            # Rows arrive newest-first, so the first job seen per doc wins.
            # Keys are normalized to str so UUID/str id types cannot diverge.
            latest_jobs.setdefault(str(job.doc_id), job)
    data = []
    for doc in documents:
        latest_job = latest_jobs.get(str(doc.id))
        data.append({
            "docId": str(doc.id),
            "kbId": doc.kb_id,
            "fileName": doc.file_name,
            "status": doc.status,
            "jobId": str(latest_job.id) if latest_job else None,
            "createdAt": doc.created_at.isoformat() + "Z",
            "updatedAt": doc.updated_at.isoformat() + "Z",
        })
    return JSONResponse(
        content={
            "data": data,
            "pagination": {
                "page": page,
                "pageSize": page_size,
                "total": total,
                "totalPages": total_pages,
            },
        }
    )
@router.post(
    "/documents",
    operation_id="uploadDocument",
    summary="Upload/import document",
    description="[AC-ASA-01] Upload document and trigger indexing job.",
    responses={
        202: {"description": "Accepted - async indexing job started"},
        400: {"description": "Bad Request - unsupported format"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def upload_document(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    kb_id: str = Form(...),
) -> JSONResponse:
    """
    [AC-ASA-01] Upload document and create indexing job.
    [AC-AISVC-33, AC-AISVC-34, AC-AISVC-35, AC-AISVC-37] Support multiple document formats.

    Flow: validate extension -> persist document + job rows -> commit ->
    schedule background indexing -> respond 202 with job/doc ids.
    """
    # NOTE(review): UnsupportedFormatError is imported here but not used.
    from app.services.document import get_supported_document_formats, UnsupportedFormatError
    from pathlib import Path
    logger.info(
        f"[AC-ASA-01] Uploading document: tenant={tenant_id}, "
        f"kb_id={kb_id}, filename={file.filename}"
    )
    # Reject unsupported extensions up front. Files WITHOUT an extension
    # pass through and are treated as plain text by the indexer.
    file_ext = Path(file.filename or "").suffix.lower()
    supported_formats = get_supported_document_formats()
    if file_ext and file_ext not in supported_formats:
        return JSONResponse(
            status_code=400,
            content={
                "code": "UNSUPPORTED_FORMAT",
                "message": f"Unsupported file format: {file_ext}",
                "details": {
                    "supported_formats": supported_formats,
                },
            },
        )
    kb_service = KBService(session)
    # Auto-creates the KB when the caller-supplied kb_id is unknown.
    kb = await kb_service.get_or_create_kb(tenant_id, kb_id)
    # The whole file is buffered in memory; large uploads scale with RAM.
    file_content = await file.read()
    document, job = await kb_service.upload_document(
        tenant_id=tenant_id,
        kb_id=str(kb.id),
        file_name=file.filename or "unknown",
        file_content=file_content,
        file_type=file.content_type,
    )
    # Commit BEFORE scheduling the background task so the job row is
    # visible to the task's own session.
    await session.commit()
    background_tasks.add_task(
        _index_document, tenant_id, str(job.id), str(document.id), file_content, file.filename
    )
    return JSONResponse(
        status_code=202,
        content={
            "jobId": str(job.id),
            "docId": str(document.id),
            "status": job.status,
        },
    )
async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: bytes, filename: str | None = None):
    """
    Background indexing task.
    [AC-AISVC-33, AC-AISVC-34, AC-AISVC-35] Uses document parsing and pluggable embedding.

    Pipeline: decode/parse the uploaded bytes -> chunk -> embed each chunk
    -> upsert vectors into the tenant's Qdrant collection, updating the
    IndexJob's progress along the way. On any failure the job is marked
    FAILED with the error message using a fresh DB session.
    """
    # Imports are local so the module can be imported without pulling in
    # these heavy dependencies at startup.
    from app.core.database import async_session_maker
    from app.services.kb import KBService
    from app.core.qdrant_client import get_qdrant_client
    from app.services.embedding import get_embedding_provider
    from app.services.document import parse_document, UnsupportedFormatError, DocumentParseException, PageText
    from qdrant_client.models import PointStruct
    import asyncio
    import tempfile
    from pathlib import Path
    logger.info(f"[INDEX] Starting indexing: tenant={tenant_id}, job_id={job_id}, doc_id={doc_id}, filename=(unknown)")
    # Fixed delay before starting; presumably lets the upload commit settle
    # — TODO confirm whether this is still needed.
    await asyncio.sleep(1)
    async with async_session_maker() as session:
        kb_service = KBService(session)
        try:
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=10
            )
            await session.commit()
            parse_result = None
            text = None
            file_ext = Path(filename or "").suffix.lower()
            logger.info(f"[INDEX] File extension: {file_ext}, content size: {len(content)} bytes")
            # Extensions handled as plain text (no binary parser involved).
            text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}
            if file_ext in text_extensions or not file_ext:
                logger.info(f"[INDEX] Treating as text file, trying multiple encodings")
                text = None
                # Try common encodings in priority order; latin-1 decodes any
                # byte sequence, so it acts as a near-catch-all at the end.
                for encoding in ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]:
                    try:
                        text = content.decode(encoding)
                        logger.info(f"[INDEX] Successfully decoded with encoding: {encoding}")
                        break
                    except (UnicodeDecodeError, LookupError):
                        continue
                if text is None:
                    text = content.decode("utf-8", errors="replace")
                    logger.warning(f"[INDEX] Failed to decode with known encodings, using utf-8 with replacement")
            else:
                logger.info(f"[INDEX] Binary file detected, will parse with document parser")
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
                )
                await session.commit()
                # parse_document expects a filesystem path, so spill the
                # bytes to a temp file and remove it afterwards.
                with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
                    tmp_file.write(content)
                    tmp_path = tmp_file.name
                logger.info(f"[INDEX] Temp file created: {tmp_path}")
                try:
                    logger.info(f"[INDEX] Starting document parsing for {file_ext}...")
                    parse_result = parse_document(tmp_path)
                    text = parse_result.text
                    logger.info(
                        f"[INDEX] Parsed document SUCCESS: (unknown), "
                        f"chars={len(text)}, format={parse_result.metadata.get('format')}, "
                        f"pages={len(parse_result.pages) if parse_result.pages else 'N/A'}, "
                        f"metadata={parse_result.metadata}"
                    )
                    if len(text) < 100:
                        logger.warning(f"[INDEX] Parsed text is very short, preview: {text[:200]}")
                # On any parse failure, fall back to best-effort UTF-8
                # decoding rather than failing the job outright.
                except UnsupportedFormatError as e:
                    logger.error(f"[INDEX] UnsupportedFormatError: {e}")
                    text = content.decode("utf-8", errors="ignore")
                except DocumentParseException as e:
                    logger.error(f"[INDEX] DocumentParseException: {e}, details={getattr(e, 'details', {})}")
                    text = content.decode("utf-8", errors="ignore")
                except Exception as e:
                    logger.error(f"[INDEX] Unexpected parsing error: {type(e).__name__}: {e}")
                    text = content.decode("utf-8", errors="ignore")
                finally:
                    Path(tmp_path).unlink(missing_ok=True)
                    logger.info(f"[INDEX] Temp file cleaned up")
            logger.info(f"[INDEX] Final text length: {len(text)} chars")
            if len(text) < 50:
                logger.warning(f"[INDEX] Text too short, preview: {repr(text[:200])}")
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=20
            )
            await session.commit()
            logger.info(f"[INDEX] Getting embedding provider...")
            embedding_provider = await get_embedding_provider()
            logger.info(f"[INDEX] Embedding provider: {type(embedding_provider).__name__}")
            # Chunking: per-page line chunks for paginated parses (keeps the
            # page number in each chunk), plain line chunks otherwise.
            all_chunks: list[TextChunk] = []
            if parse_result and parse_result.pages:
                logger.info(f"[INDEX] PDF with {len(parse_result.pages)} pages, using line-based chunking with page metadata")
                for page in parse_result.pages:
                    page_chunks = chunk_text_by_lines(
                        page.text,
                        min_line_length=10,
                        source=filename,
                    )
                    for pc in page_chunks:
                        pc.page = page.page
                    all_chunks.extend(page_chunks)
                logger.info(f"[INDEX] Total chunks from PDF: {len(all_chunks)}")
            else:
                logger.info(f"[INDEX] Using line-based chunking")
                all_chunks = chunk_text_by_lines(
                    text,
                    min_line_length=10,
                    source=filename,
                )
            logger.info(f"[INDEX] Total chunks: {len(all_chunks)}")
            qdrant = await get_qdrant_client()
            await qdrant.ensure_collection_exists(tenant_id)
            points = []
            total_chunks = len(all_chunks)
            for i, chunk in enumerate(all_chunks):
                # One embedding call per chunk, sequentially.
                embedding = await embedding_provider.embed(chunk.text)
                payload = {
                    "text": chunk.text,
                    "source": doc_id,
                    "chunk_index": i,
                    "start_token": chunk.start_token,
                    "end_token": chunk.end_token,
                }
                if chunk.page is not None:
                    payload["page"] = chunk.page
                if chunk.source:
                    payload["filename"] = chunk.source
                points.append(
                    PointStruct(
                        id=str(uuid.uuid4()),
                        vector=embedding,
                        payload=payload,
                    )
                )
                # Progress spans 20..90 across the embedding loop; persist
                # every 10th chunk and on the final one.
                progress = 20 + int((i + 1) / total_chunks * 70)
                if i % 10 == 0 or i == total_chunks - 1:
                    await kb_service.update_job_status(
                        tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=progress
                    )
                    await session.commit()
            if points:
                logger.info(f"[INDEX] Upserting {len(points)} vectors to Qdrant...")
                await qdrant.upsert_vectors(tenant_id, points)
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.COMPLETED.value, progress=100
            )
            await session.commit()
            logger.info(
                f"[INDEX] COMPLETED: tenant={tenant_id}, "
                f"job_id={job_id}, chunks={len(all_chunks)}, text_len={len(text)}"
            )
        except Exception as e:
            import traceback
            logger.error(f"[INDEX] FAILED: {e}\n{traceback.format_exc()}")
            await session.rollback()
            # Mark the job FAILED on a fresh session: the original one may
            # be unusable after the rollback.
            async with async_session_maker() as error_session:
                kb_service = KBService(error_session)
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.FAILED.value,
                    progress=0, error_msg=str(e)
                )
                await error_session.commit()
@router.get(
    "/index/jobs/{job_id}",
    operation_id="getIndexJob",
    summary="Query index job status",
    description="[AC-ASA-02] Get indexing job status and progress.",
    responses={
        200: {"description": "Job status details"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def get_index_job(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    job_id: str,
) -> JSONResponse:
    """
    [AC-ASA-02] Return status and progress for one indexing job.

    Responds 404 with code JOB_NOT_FOUND when the job does not exist for
    this tenant.
    """
    logger.info(
        f"[AC-ASA-02] Getting job status: tenant={tenant_id}, job_id={job_id}"
    )
    job = await KBService(session).get_index_job(tenant_id, job_id)
    if job is None:
        not_found = {
            "code": "JOB_NOT_FOUND",
            "message": f"Job {job_id} not found",
        }
        return JSONResponse(status_code=404, content=not_found)
    payload = {
        "jobId": str(job.id),
        "docId": str(job.doc_id),
        "status": job.status,
        "progress": job.progress,
        "errorMsg": job.error_msg,
    }
    return JSONResponse(content=payload)
@router.delete(
    "/documents/{doc_id}",
    operation_id="deleteDocument",
    summary="Delete document",
    description="[AC-ASA-08] Delete a document and its associated files.",
    responses={
        200: {"description": "Document deleted"},
        404: {"description": "Document not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def delete_document(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    doc_id: str,
) -> JSONResponse:
    """
    [AC-ASA-08] Delete one document belonging to the current tenant.

    Responds 404 with code DOCUMENT_NOT_FOUND when no such document exists.
    """
    logger.info(
        f"[AC-ASA-08] Deleting document: tenant={tenant_id}, doc_id={doc_id}"
    )
    was_deleted = await KBService(session).delete_document(tenant_id, doc_id)
    if was_deleted:
        return JSONResponse(
            content={
                "success": True,
                "message": "Document deleted",
            }
        )
    return JSONResponse(
        status_code=404,
        content={
            "code": "DOCUMENT_NOT_FOUND",
            "message": f"Document {doc_id} not found",
        },
    )

View File

@ -0,0 +1,330 @@
"""
Knowledge base management API with RAG optimization features.
Reference: rag-optimization/spec.md Section 4.2
"""
import logging
from datetime import date
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.services.retrieval import (
ChunkMetadata,
ChunkMetadataModel,
IndexingProgress,
IndexingResult,
KnowledgeIndexer,
MetadataFilter,
RetrievalStrategy,
get_knowledge_indexer,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/kb", tags=["Knowledge Base"])
class IndexDocumentRequest(BaseModel):
    """Request to index a document."""
    tenant_id: str = Field(..., description="Tenant ID")
    document_id: str = Field(..., description="Document ID")
    text: str = Field(..., description="Document text content")
    metadata: ChunkMetadataModel | None = Field(default=None, description="Document metadata")
class IndexDocumentResponse(BaseModel):
    """Response from document indexing."""
    success: bool
    # Chunk accounting: total = indexed + failed.
    total_chunks: int
    indexed_chunks: int
    failed_chunks: int
    elapsed_seconds: float
    # Populated only when indexing failed.
    error_message: str | None = None
class IndexingProgressResponse(BaseModel):
    """Response with current indexing progress."""
    total_chunks: int
    processed_chunks: int
    failed_chunks: int
    # Integer percentage in 0..100.
    progress_percent: int
    elapsed_seconds: float
    # Identifier of the document currently being processed.
    current_document: str
class MetadataFilterRequest(BaseModel):
    """Request for metadata filtering. All criteria are optional and ANDed."""
    categories: list[str] | None = None
    target_audiences: list[str] | None = None
    departments: list[str] | None = None
    # When True, only chunks whose validity window includes "now" match.
    valid_only: bool = True
    min_priority: int | None = None
    keywords: list[str] | None = None
class RetrieveRequest(BaseModel):
    """Request for knowledge retrieval."""
    tenant_id: str = Field(..., description="Tenant ID")
    query: str = Field(..., description="Search query")
    top_k: int = Field(default=10, ge=1, le=50, description="Number of results")
    filters: MetadataFilterRequest | None = Field(default=None, description="Metadata filters")
    strategy: RetrievalStrategy = Field(default=RetrievalStrategy.HYBRID, description="Retrieval strategy")
class RetrieveResponse(BaseModel):
    """Response from knowledge retrieval."""
    # Each hit dict carries text, score, source and metadata.
    hits: list[dict[str, Any]]
    total_hits: int
    max_score: float
    # True when the retriever judged the evidence insufficient to answer.
    is_insufficient: bool
    diagnostics: dict[str, Any]
class MetadataOptionsResponse(BaseModel):
    """Response with available metadata options."""
    categories: list[str]
    departments: list[str]
    target_audiences: list[str]
    priorities: list[int]
@router.post("/index", response_model=IndexDocumentResponse)
async def index_document(
    request: IndexDocumentRequest,
    session: AsyncSession = Depends(get_session),
):
    """
    Index a document with optimized embedding.

    Features:
    - Task prefixes (search_document:) for document embedding
    - Multi-dimensional vectors (256/512/768)
    - Metadata support

    Returns an IndexDocumentResponse summarizing chunk counts and timing;
    responds 500 when indexing fails.
    """
    try:
        indexer = get_knowledge_indexer()
        # Convert the transport-level metadata model into the internal
        # ChunkMetadata value object expected by the indexer.
        chunk_metadata = None
        if request.metadata:
            chunk_metadata = ChunkMetadata(
                category=request.metadata.category,
                subcategory=request.metadata.subcategory,
                target_audience=request.metadata.target_audience,
                source_doc=request.metadata.source_doc,
                source_url=request.metadata.source_url,
                department=request.metadata.department,
                priority=request.metadata.priority,
                keywords=request.metadata.keywords,
            )
        result = await indexer.index_document(
            tenant_id=request.tenant_id,
            document_id=request.document_id,
            text=request.text,
            metadata=chunk_metadata,
        )
        return IndexDocumentResponse(
            success=result.success,
            total_chunks=result.total_chunks,
            indexed_chunks=result.indexed_chunks,
            failed_chunks=result.failed_chunks,
            elapsed_seconds=result.elapsed_seconds,
            error_message=result.error_message,
        )
    except Exception as e:
        logger.error(f"[KB-API] Failed to index document: {e}")
        # Chain the original exception so the root cause is preserved in
        # server-side tracebacks (previously it was discarded).
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"索引失败: {str(e)}"
        ) from e
@router.get("/index/progress", response_model=IndexingProgressResponse | None)
async def get_indexing_progress():
    """Return the indexer's current progress, or None when nothing is running."""
    try:
        progress = get_knowledge_indexer().get_progress()
        if progress is None:
            return None
        return IndexingProgressResponse(
            total_chunks=progress.total_chunks,
            processed_chunks=progress.processed_chunks,
            failed_chunks=progress.failed_chunks,
            progress_percent=progress.progress_percent,
            elapsed_seconds=progress.elapsed_seconds,
            current_document=progress.current_document,
        )
    except Exception as e:
        logger.error(f"[KB-API] Failed to get progress: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取进度失败: {str(e)}"
        )
@router.post("/retrieve", response_model=RetrieveResponse)
async def retrieve_knowledge(request: RetrieveRequest):
    """
    Retrieve knowledge using optimized RAG.

    Strategies:
    - vector: Simple vector search
    - bm25: BM25 keyword search
    - hybrid: RRF combination of vector + BM25 (default)
    - two_stage: Two-stage retrieval with Matryoshka dimensions
    """
    try:
        from app.services.retrieval.optimized_retriever import get_optimized_retriever
        from app.services.retrieval.base import RetrievalContext
        retriever = await get_optimized_retriever()
        # Translate the optional transport-level filters into the internal
        # MetadataFilter value object.
        metadata_filter = None
        if request.filters:
            filter_dict = request.filters.model_dump(exclude_none=True)
            metadata_filter = MetadataFilter(**filter_dict)
        # NOTE(review): request.top_k and request.strategy are accepted but
        # never forwarded to the retriever — confirm whether the retriever
        # is meant to apply its own defaults here.
        ctx = RetrievalContext(
            tenant_id=request.tenant_id,
            query=request.query,
        )
        if metadata_filter:
            # The retriever reads the Qdrant-native filter from ctx.metadata.
            ctx.metadata = {"filter": metadata_filter.to_qdrant_filter()}
        result = await retriever.retrieve(ctx)
        return RetrieveResponse(
            hits=[
                {
                    "text": hit.text,
                    "score": hit.score,
                    "source": hit.source,
                    "metadata": hit.metadata,
                }
                for hit in result.hits
            ],
            total_hits=result.hit_count,
            max_score=result.max_score,
            is_insufficient=result.diagnostics.get("is_insufficient", False),
            diagnostics=result.diagnostics or {},
        )
    except Exception as e:
        logger.error(f"[KB-API] Failed to retrieve: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"检索失败: {str(e)}"
        )
@router.get("/metadata/options", response_model=MetadataOptionsResponse)
async def get_metadata_options():
    """
    Get available metadata options for filtering.

    The option lists are static for now; they would typically be loaded
    from a database. Building the response from constant literals cannot
    fail, so the previous try/except (logging + 500) was unreachable dead
    code and has been removed.
    """
    return MetadataOptionsResponse(
        categories=[
            "课程咨询",
            "考试政策",
            "学籍管理",
            "奖助学金",
            "宿舍管理",
            "校园服务",
            "就业指导",
            "其他",
        ],
        departments=[
            "教务处",
            "学生处",
            "财务处",
            "后勤处",
            "就业指导中心",
            "图书馆",
            "信息中心",
        ],
        target_audiences=[
            "本科生",
            "研究生",
            "留学生",
            "新生",
            "毕业生",
            "教职工",
        ],
        # Priority scale 1..10 inclusive.
        priorities=list(range(1, 11)),
    )
@router.post("/reindex")
async def reindex_all(
    tenant_id: str,
    session: AsyncSession = Depends(get_session),
):
    """
    Reindex all COMPLETED documents for a tenant with optimized embedding.

    Reads each document's persisted file from disk and feeds it through the
    knowledge indexer. Responds 500 on any failure.

    Fixes over the previous version: file reads are pushed off the event
    loop with asyncio.to_thread (open()/read() block an async handler),
    the inner loop no longer shadows the outer SQLAlchemy ``result``
    variable, and ``import os`` is hoisted out of the loop.
    """
    try:
        import asyncio
        import os
        from app.models.entities import Document, DocumentStatus
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.status == DocumentStatus.COMPLETED.value,
        )
        doc_result = await session.execute(stmt)
        documents = doc_result.scalars().all()
        index = get_knowledge_indexer()
        total_indexed = 0
        total_failed = 0
        for doc in documents:
            if not doc.file_path or not os.path.exists(doc.file_path):
                continue

            def _read_file(path: str = doc.file_path) -> str:
                # Blocking I/O, executed in a worker thread below.
                with open(path, 'r', encoding='utf-8') as f:
                    return f.read()

            text = await asyncio.to_thread(_read_file)
            index_result = await index.index_document(
                tenant_id=tenant_id,
                document_id=str(doc.id),
                text=text,
            )
            total_indexed += index_result.indexed_chunks
            total_failed += index_result.failed_chunks
        return {
            "success": True,
            "total_documents": len(documents),
            "total_indexed": total_indexed,
            "total_failed": total_failed,
        }
    except Exception as e:
        logger.error(f"[KB-API] Failed to reindex: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"重新索引失败: {str(e)}"
        ) from e

View File

@ -0,0 +1,152 @@
"""
LLM Configuration Management API.
[AC-ASA-14, AC-ASA-15, AC-ASA-16, AC-ASA-17, AC-ASA-18] LLM provider management endpoints.
"""
import logging
from typing import Any
from fastapi import APIRouter, Depends, Header, HTTPException
from app.services.llm.factory import (
LLMConfigManager,
LLMProviderFactory,
get_llm_config_manager,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/llm", tags=["LLM Management"])
def get_tenant_id(x_tenant_id: str = Header(..., alias="X-Tenant-Id")) -> str:
    """FastAPI dependency: read the tenant ID from the ``X-Tenant-Id`` header.

    Raises HTTPException(400) when the header is present but empty.
    """
    if x_tenant_id:
        return x_tenant_id
    raise HTTPException(status_code=400, detail="X-Tenant-Id header is required")
@router.get("/providers")
async def list_providers(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    List all available LLM providers.
    [AC-ASA-15] Each entry carries name, display name, description and a
    config schema for the admin UI.
    """
    logger.info(f"[AC-ASA-15] Listing LLM providers for tenant={tenant_id}")
    entries = []
    for provider in LLMProviderFactory.get_providers():
        entries.append({
            "name": provider.name,
            "display_name": provider.display_name,
            "description": provider.description,
            "config_schema": provider.config_schema,
        })
    return {"providers": entries}
@router.get("/config")
async def get_config(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Get current LLM configuration.
    [AC-ASA-14] Secret values in the config are masked before returning.
    """
    logger.info(f"[AC-ASA-14] Getting LLM config for tenant={tenant_id}")
    current = get_llm_config_manager().get_current_config()
    return {
        "provider": current["provider"],
        "config": _mask_secrets(current.get("config", {})),
    }
@router.put("/config")
async def update_config(
    body: dict[str, Any],
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Update LLM configuration.
    [AC-ASA-16] Validates and applies the new provider/config; signals bad
    input via a success=False payload rather than an HTTP error.
    """
    provider = body.get("provider")
    config = body.get("config", {})
    logger.info(f"[AC-ASA-16] Updating LLM config for tenant={tenant_id}, provider={provider}")
    if not provider:
        return {"success": False, "message": "Provider is required"}
    manager = get_llm_config_manager()
    try:
        await manager.update_config(provider, config)
    except ValueError as e:
        logger.error(f"[AC-ASA-16] Invalid LLM config: {e}")
        return {"success": False, "message": str(e)}
    return {
        "success": True,
        "message": f"LLM configuration updated to {provider}",
    }
@router.post("/test")
async def test_connection(
    body: dict[str, Any] | None = None,
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Test LLM connection.
    [AC-ASA-17, AC-ASA-18] Sends a probe prompt through the configured (or
    explicitly supplied) provider and returns the outcome.
    """
    params = body or {}
    provider = params.get("provider")
    logger.info(
        f"[AC-ASA-17] Testing LLM connection for tenant={tenant_id}, "
        f"provider={provider or 'current'}"
    )
    manager = get_llm_config_manager()
    return await manager.test_connection(
        test_prompt=params.get("test_prompt", "你好,请简单介绍一下自己。"),
        provider=provider,
        config=params.get("config"),
    )
def _mask_secrets(config: dict[str, Any]) -> dict[str, Any]:
"""Mask secret fields in config for display."""
masked = {}
for key, value in config.items():
if key in ("api_key", "password", "secret"):
if value:
masked[key] = f"{str(value)[:4]}***"
else:
masked[key] = ""
else:
masked[key] = value
return masked

View File

@ -0,0 +1,330 @@
"""
RAG Lab endpoints for debugging and experimentation.
[AC-ASA-05, AC-ASA-19, AC-ASA-20, AC-ASA-21, AC-ASA-22] RAG experiment with AI output.
"""
import json
import logging
import time
from typing import Annotated, Any, List
from fastapi import APIRouter, Depends, Body
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from app.core.config import get_settings
from app.core.exceptions import MissingTenantIdException
from app.core.prompts import format_evidence_for_prompt, build_user_prompt_with_evidence
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.services.retrieval.vector_retriever import get_vector_retriever
from app.services.retrieval.optimized_retriever import get_optimized_retriever
from app.services.retrieval.base import RetrievalContext
from app.services.llm.factory import get_llm_config_manager
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/rag", tags=["RAG Lab"])
def get_current_tenant_id() -> str:
    """Resolve the tenant ID from the request context.

    Raises MissingTenantIdException when no tenant is bound.
    """
    current = get_tenant_id()
    if not current:
        raise MissingTenantIdException()
    return current
class RAGExperimentRequest(BaseModel):
query: str = Field(..., description="Query text for retrieval")
kb_ids: List[str] | None = Field(default=None, description="Knowledge base IDs to search")
top_k: int = Field(default=5, description="Number of results to retrieve")
score_threshold: float = Field(default=0.5, description="Minimum similarity score")
generate_response: bool = Field(default=True, description="Whether to generate AI response")
llm_provider: str | None = Field(default=None, description="Specific LLM provider to use")
class AIResponse(BaseModel):
content: str
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
latency_ms: float = 0
model: str = ""
class RAGExperimentResult(BaseModel):
    # Schema describing the experiment response shape.
    # NOTE(review): the endpoints below build raw dicts instead of this model —
    # confirm it is used elsewhere or wire it into the responses for validation.
    query: str
    retrieval_results: List[dict] = []
    final_prompt: str = ""
    ai_response: AIResponse | None = None
    total_latency_ms: float = 0
    diagnostics: dict[str, Any] = {}
@router.post(
    "/experiments/run",
    operation_id="runRagExperiment",
    summary="Run RAG debugging experiment with AI output",
    description="[AC-ASA-05, AC-ASA-19, AC-ASA-21, AC-ASA-22] Trigger RAG experiment with retrieval, prompt generation, and AI response.",
    responses={
        200: {"description": "Experiment results with retrieval, prompt, and AI response"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> JSONResponse:
    """
    [AC-ASA-05, AC-ASA-19, AC-ASA-21, AC-ASA-22] Run RAG experiment and return retrieval results with AI response.

    Flow: retrieve evidence -> build the final prompt -> optionally call the LLM.
    Any retrieval failure degrades to simulated fallback results and still
    returns HTTP 200 with diagnostics.fallback = True.
    """
    start_time = time.time()
    logger.info(
        f"[AC-ASA-05] Running RAG experiment: tenant={tenant_id}, "
        f"query={request.query[:50]}..., kb_ids={request.kb_ids}, "
        f"generate_response={request.generate_response}"
    )
    settings = get_settings()
    # NOTE(review): top_k and threshold are computed but never handed to the
    # retriever below — confirm whether RetrievalContext should carry them.
    top_k = request.top_k or settings.rag_top_k
    threshold = request.score_threshold or settings.rag_score_threshold
    try:
        # Use optimized retriever with RAG enhancements
        retriever = await get_optimized_retriever()
        retrieval_ctx = RetrievalContext(
            tenant_id=tenant_id,
            query=request.query,
            session_id="rag_experiment",
            channel_type="admin",
            metadata={"kb_ids": request.kb_ids},
        )
        result = await retriever.retrieve(retrieval_ctx)
        # Flatten hits into plain dicts for the JSON body and prompt builder.
        retrieval_results = [
            {
                "content": hit.text,
                "score": hit.score,
                "source": hit.source,
                "metadata": hit.metadata,
            }
            for hit in result.hits
        ]
        final_prompt = _build_final_prompt(request.query, retrieval_results)
        logger.info(
            f"[AC-ASA-05] RAG retrieval complete: hits={len(retrieval_results)}, "
            f"max_score={result.max_score:.3f}"
        )
        ai_response = None
        if request.generate_response:
            # _generate_ai_response never raises; failures come back as content.
            ai_response = await _generate_ai_response(
                final_prompt,
                provider=request.llm_provider,
            )
        total_latency_ms = (time.time() - start_time) * 1000
        return JSONResponse(
            content={
                "query": request.query,
                "retrieval_results": retrieval_results,
                "final_prompt": final_prompt,
                "ai_response": ai_response.model_dump() if ai_response else None,
                "total_latency_ms": round(total_latency_ms, 2),
                "diagnostics": result.diagnostics,
            }
        )
    except Exception as e:
        logger.error(f"[AC-ASA-05] RAG experiment failed: {e}")
        # Fallback path: simulated hit so the lab UI still shows a full cycle.
        fallback_results = _get_fallback_results(request.query)
        fallback_prompt = _build_final_prompt(request.query, fallback_results)
        ai_response = None
        if request.generate_response:
            ai_response = await _generate_ai_response(
                fallback_prompt,
                provider=request.llm_provider,
            )
        total_latency_ms = (time.time() - start_time) * 1000
        return JSONResponse(
            content={
                "query": request.query,
                "retrieval_results": fallback_results,
                "final_prompt": fallback_prompt,
                "ai_response": ai_response.model_dump() if ai_response else None,
                "total_latency_ms": round(total_latency_ms, 2),
                "diagnostics": {
                    "error": str(e),
                    "fallback": True,
                },
            }
        )
@router.post(
    "/experiments/stream",
    operation_id="runRagExperimentStream",
    summary="Run RAG experiment with streaming AI output",
    description="[AC-ASA-20] Trigger RAG experiment with SSE streaming for AI response.",
    responses={
        200: {"description": "SSE stream with retrieval results and AI response"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def run_rag_experiment_stream(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    request: RAGExperimentRequest = Body(...),
) -> StreamingResponse:
    """
    [AC-ASA-20] Run RAG experiment with SSE streaming for AI response.

    Event order: retrieval -> prompt -> message* -> final, or error on failure.
    """
    logger.info(
        f"[AC-ASA-20] Running RAG experiment stream: tenant={tenant_id}, "
        f"query={request.query[:50]}..."
    )
    settings = get_settings()
    # NOTE(review): top_k and threshold are unused below — same issue as the
    # non-streaming endpoint; confirm whether the retriever should receive them.
    top_k = request.top_k or settings.rag_top_k
    threshold = request.score_threshold or settings.rag_score_threshold
    async def event_generator():
        try:
            # Use optimized retriever with RAG enhancements
            retriever = await get_optimized_retriever()
            retrieval_ctx = RetrievalContext(
                tenant_id=tenant_id,
                query=request.query,
                session_id="rag_experiment_stream",
                channel_type="admin",
                metadata={"kb_ids": request.kb_ids},
            )
            result = await retriever.retrieve(retrieval_ctx)
            retrieval_results = [
                {
                    "content": hit.text,
                    "score": hit.score,
                    "source": hit.source,
                    "metadata": hit.metadata,
                }
                for hit in result.hits
            ]
            final_prompt = _build_final_prompt(request.query, retrieval_results)
            # NOTE(review): logs the entire prompt at INFO — may bloat logs and
            # leak knowledge-base content; consider DEBUG level.
            logger.info(f"[AC-ASA-20] ========== RAG LAB STREAM FULL PROMPT ==========")
            logger.info(f"[AC-ASA-20] Prompt length: {len(final_prompt)}")
            logger.info(f"[AC-ASA-20] Prompt content:\n{final_prompt}")
            logger.info(f"[AC-ASA-20] ==============================================")
            yield f"event: retrieval\ndata: {json.dumps({'results': retrieval_results, 'count': len(retrieval_results)})}\n\n"
            yield f"event: prompt\ndata: {json.dumps({'prompt': final_prompt})}\n\n"
            if request.generate_response:
                manager = get_llm_config_manager()
                client = manager.get_client()
                full_content = ""
                # Stream deltas as message events, then emit the accumulated text.
                async for chunk in client.stream_generate(
                    messages=[{"role": "user", "content": final_prompt}],
                ):
                    if chunk.delta:
                        full_content += chunk.delta
                        yield f"event: message\ndata: {json.dumps({'delta': chunk.delta})}\n\n"
                yield f"event: final\ndata: {json.dumps({'content': full_content, 'finish_reason': 'stop'})}\n\n"
            else:
                yield f"event: final\ndata: {json.dumps({'content': '', 'finish_reason': 'skipped'})}\n\n"
        except Exception as e:
            logger.error(f"[AC-ASA-20] RAG experiment stream failed: {e}")
            yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable nginx proxy buffering for SSE
        },
    )
async def _generate_ai_response(
    prompt: str,
    provider: str | None = None,
) -> AIResponse | None:
    """
    [AC-ASA-19, AC-ASA-21] Generate a single-turn AI response from a rendered prompt.

    Args:
        prompt: Fully rendered prompt, sent as one user message.
        provider: Requested LLM provider. NOTE(review): currently unused —
            the factory's default client is always used; confirm whether it
            should be forwarded to get_llm_config_manager()/get_client().

    Returns:
        AIResponse with content, token usage, latency, and model name. On
        failure returns an AIResponse whose content carries the error text;
        this function never raises (the signature's None is never produced
        in practice but is kept for interface compatibility).
    """
    # `time` is already imported at module level; the previous redundant
    # function-local `import time` has been removed.
    logger.info("[AC-ASA-19] ========== RAG LAB FULL PROMPT ==========")
    logger.info(f"[AC-ASA-19] Prompt length: {len(prompt)}")
    logger.info(f"[AC-ASA-19] Prompt content:\n{prompt}")
    logger.info("[AC-ASA-19] ==========================================")
    try:
        manager = get_llm_config_manager()
        client = manager.get_client()
        start_time = time.time()
        response = await client.generate(
            messages=[{"role": "user", "content": prompt}],
        )
        latency_ms = (time.time() - start_time) * 1000
        return AIResponse(
            content=response.content,
            prompt_tokens=response.usage.get("prompt_tokens", 0),
            completion_tokens=response.usage.get("completion_tokens", 0),
            total_tokens=response.usage.get("total_tokens", 0),
            latency_ms=round(latency_ms, 2),
            model=response.model,
        )
    except Exception as e:
        # Degrade gracefully: surface the error as response content so the
        # lab UI can still render a complete experiment cycle.
        logger.error(f"[AC-ASA-19] AI response generation failed: {e}")
        return AIResponse(
            content=f"AI 响应生成失败: {str(e)}",
            latency_ms=0,
        )
def _build_final_prompt(query: str, retrieval_results: list[dict]) -> str:
    """Render the final LLM prompt for a query plus retrieved evidence.

    Delegates to the shared prompt helpers so the lab output stays
    consistent with the orchestrator's prompt construction.
    """
    return build_user_prompt_with_evidence(
        query,
        format_evidence_for_prompt(retrieval_results, max_results=5, max_content_length=500),
    )
def _get_fallback_results(query: str) -> list[dict]:
"""
Provide fallback results when retrieval fails.
"""
return [
{
"content": "检索服务暂时不可用,这是模拟结果。",
"score": 0.5,
"source": "fallback",
}
]

View File

@ -0,0 +1,293 @@
"""
Session monitoring and management endpoints.
[AC-ASA-07, AC-ASA-09] Session list and detail monitoring.
"""
import logging
from typing import Annotated, Optional, Sequence
from datetime import datetime
from fastapi import APIRouter, Depends, Query
from fastapi.responses import JSONResponse
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col
from app.core.database import get_session
from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.models.entities import ChatSession, ChatMessage, SessionStatus
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/sessions", tags=["Session Monitoring"])
def get_current_tenant_id() -> str:
    """FastAPI dependency that resolves the tenant ID or raises MissingTenantIdException."""
    if tenant := get_tenant_id():
        return tenant
    raise MissingTenantIdException()
@router.get(
    "",
    operation_id="listSessions",
    summary="Query session list",
    description="[AC-ASA-09] Get list of sessions with pagination and filtering.",
    responses={
        200: {"description": "Session list with pagination"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def list_sessions(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    status: Annotated[Optional[str], Query()] = None,
    start_time: Annotated[Optional[str], Query(alias="startTime")] = None,
    end_time: Annotated[Optional[str], Query(alias="endTime")] = None,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
) -> JSONResponse:
    """
    [AC-ASA-09] List sessions with filtering and pagination.

    Filters: optional status (stored in session JSON metadata) and an
    ISO-8601 created_at time window. Malformed timestamps are ignored.
    """
    logger.info(
        f"[AC-ASA-09] Listing sessions: tenant={tenant_id}, status={status}, "
        f"start_time={start_time}, end_time={end_time}, page={page}, page_size={page_size}"
    )
    stmt = select(ChatSession).where(ChatSession.tenant_id == tenant_id)
    if status:
        # NOTE(review): sessions without a metadata "status" key (treated as
        # "active" in the output below) will not match this JSON filter —
        # confirm filtering by "active" behaves as expected.
        stmt = stmt.where(ChatSession.metadata_["status"].as_string() == status)
    if start_time:
        try:
            start_dt = datetime.fromisoformat(start_time.replace("Z", "+00:00"))
            stmt = stmt.where(ChatSession.created_at >= start_dt)
        except ValueError:
            pass  # invalid startTime: silently skip the filter
    if end_time:
        try:
            end_dt = datetime.fromisoformat(end_time.replace("Z", "+00:00"))
            stmt = stmt.where(ChatSession.created_at <= end_dt)
        except ValueError:
            pass  # invalid endTime: silently skip the filter
    # Total count over the filtered (un-paginated) query.
    count_stmt = select(func.count()).select_from(stmt.subquery())
    total_result = await session.execute(count_stmt)
    total = total_result.scalar() or 0
    stmt = stmt.order_by(col(ChatSession.created_at).desc())
    stmt = stmt.offset((page - 1) * page_size).limit(page_size)
    result = await session.execute(stmt)
    sessions = result.scalars().all()
    # Batch message counts for the current page in one grouped query.
    session_ids = [s.session_id for s in sessions]
    if session_ids:
        msg_count_stmt = (
            select(
                ChatMessage.session_id,
                func.count(ChatMessage.id).label("count")
            )
            .where(
                ChatMessage.tenant_id == tenant_id,
                ChatMessage.session_id.in_(session_ids)
            )
            .group_by(ChatMessage.session_id)
        )
        msg_count_result = await session.execute(msg_count_stmt)
        msg_counts = {row.session_id: row.count for row in msg_count_result}
    else:
        msg_counts = {}
    data = []
    for s in sessions:
        # Status and endTime live in JSON metadata; default to ACTIVE/None.
        session_status = SessionStatus.ACTIVE.value
        if s.metadata_ and "status" in s.metadata_:
            session_status = s.metadata_["status"]
        end_time_val = None
        if s.metadata_ and "endTime" in s.metadata_:
            end_time_val = s.metadata_["endTime"]
        data.append({
            "sessionId": s.session_id,
            "tenantId": tenant_id,
            "status": session_status,
            "startTime": s.created_at.isoformat() + "Z",
            "endTime": end_time_val,
            "messageCount": msg_counts.get(s.session_id, 0),
            "channelType": s.channel_type,
        })
    total_pages = (total + page_size - 1) // page_size if total > 0 else 0
    return JSONResponse(
        content={
            "data": data,
            "pagination": {
                "page": page,
                "pageSize": page_size,
                "total": total,
                "totalPages": total_pages,
            },
        }
    )
@router.get(
    "/{session_id}",
    operation_id="getSessionDetail",
    summary="Get session details",
    description="[AC-ASA-07] Get full session details with messages and trace.",
    responses={
        200: {"description": "Full session details with messages and trace"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def get_session_detail(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    session_id: str,
) -> JSONResponse:
    """
    [AC-ASA-07] Get session detail with messages and trace information.

    Returns 404 with SESSION_NOT_FOUND when the session does not exist
    for this tenant. Messages are ordered oldest-first.
    """
    logger.info(
        f"[AC-ASA-07] Getting session detail: tenant={tenant_id}, session_id={session_id}"
    )
    # Tenant-scoped lookup: a session belonging to another tenant yields 404.
    session_stmt = select(ChatSession).where(
        ChatSession.tenant_id == tenant_id,
        ChatSession.session_id == session_id,
    )
    session_result = await session.execute(session_stmt)
    chat_session = session_result.scalar_one_or_none()
    if not chat_session:
        return JSONResponse(
            status_code=404,
            content={
                "code": "SESSION_NOT_FOUND",
                "message": f"Session {session_id} not found",
            },
        )
    messages_stmt = (
        select(ChatMessage)
        .where(
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.session_id == session_id,
        )
        .order_by(col(ChatMessage.created_at).asc())
    )
    messages_result = await session.execute(messages_stmt)
    messages = messages_result.scalars().all()
    messages_data = []
    for msg in messages:
        msg_data = {
            "role": msg.role,
            "content": msg.content,
            "timestamp": msg.created_at.isoformat() + "Z",
        }
        messages_data.append(msg_data)
    # Trace extraction is currently a stub (see _build_trace_info).
    trace = _build_trace_info(messages)
    return JSONResponse(
        content={
            "sessionId": session_id,
            "messages": messages_data,
            "trace": trace,
            "metadata": chat_session.metadata_ or {},
        }
    )
def _build_trace_info(messages: Sequence[ChatMessage]) -> dict:
"""
Build trace information from messages.
This extracts retrieval and tool call information from message metadata.
"""
trace = {
"retrieval": [],
"tools": [],
"errors": [],
}
for msg in messages:
if msg.role == "assistant":
pass
return trace
@router.put(
    "/{session_id}/status",
    operation_id="updateSessionStatus",
    summary="Update session status",
    description="[AC-ASA-09] Update session status (active, closed, expired).",
    responses={
        200: {"description": "Session status updated"},
        404: {"description": "Session not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def update_session_status(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    db_session: Annotated[AsyncSession, Depends(get_session)],
    session_id: str,
    status: str = Query(..., description="New status: active, closed, expired"),
) -> JSONResponse:
    """
    [AC-ASA-09] Update session status.

    Stores the status in the session's JSON metadata; closing or expiring
    also records an endTime. The commit happens in the get_session
    dependency after this handler returns (only flush is called here).
    NOTE(review): status is not validated against SessionStatus — an
    arbitrary string is accepted; confirm whether that is intended.
    """
    logger.info(
        f"[AC-ASA-09] Updating session status: tenant={tenant_id}, "
        f"session_id={session_id}, status={status}"
    )
    stmt = select(ChatSession).where(
        ChatSession.tenant_id == tenant_id,
        ChatSession.session_id == session_id,
    )
    result = await db_session.execute(stmt)
    chat_session = result.scalar_one_or_none()
    if not chat_session:
        return JSONResponse(
            status_code=404,
            content={
                "code": "SESSION_NOT_FOUND",
                "message": f"Session {session_id} not found",
            },
        )
    # Reassign the whole dict so the JSON column change is detected.
    metadata = chat_session.metadata_ or {}
    metadata["status"] = status
    if status == SessionStatus.CLOSED.value or status == SessionStatus.EXPIRED.value:
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # consider datetime.now(timezone.utc).
        metadata["endTime"] = datetime.utcnow().isoformat() + "Z"
    chat_session.metadata_ = metadata
    chat_session.updated_at = datetime.utcnow()
    await db_session.flush()
    return JSONResponse(
        content={
            "success": True,
            "sessionId": session_id,
            "status": status,
        }
    )

View File

@ -0,0 +1,78 @@
"""
Tenant management endpoints.
Provides tenant list and management functionality.
"""
import logging
from typing import Annotated
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.exceptions import MissingTenantIdException
from app.core.middleware import parse_tenant_id
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.models.entities import Tenant
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/tenants", tags=["Tenants"])
def get_current_tenant_id() -> str:
    """FastAPI dependency: fetch the tenant ID from context; absence is a client error."""
    resolved = get_tenant_id()
    if not resolved:
        raise MissingTenantIdException()
    return resolved
@router.get(
    "",
    operation_id="listTenants",
    summary="List all tenants",
    description="Get a list of all tenants from the system.",
    responses={
        200: {"description": "List of tenants"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def list_tenants(
    session: Annotated[AsyncSession, Depends(get_session)],
) -> JSONResponse:
    """
    Get a list of all tenants from the tenants table.
    Returns tenant ID and display name (first part of tenant_id).

    NOTE(review): this endpoint returns every tenant regardless of the
    caller's own tenant — confirm cross-tenant visibility is intended
    for this admin route.
    """
    logger.info("Getting all tenants")
    # Get all tenants from tenants table
    stmt = select(Tenant).order_by(Tenant.created_at.desc())
    result = await session.execute(stmt)
    tenants = result.scalars().all()
    # Format tenant list with display name
    tenant_list = []
    for tenant in tenants:
        # parse_tenant_id splits the name@ash@year identifier.
        name, year = parse_tenant_id(tenant.tenant_id)
        tenant_list.append({
            "id": tenant.tenant_id,
            "name": f"{name} ({year})",
            "displayName": name,
            "year": year,
            "createdAt": tenant.created_at.isoformat() if tenant.created_at else None,
        })
    logger.info(f"Found {len(tenant_list)} tenants: {[t['id'] for t in tenant_list]}")
    return JSONResponse(
        content={
            "tenants": tenant_list,
            "total": len(tenant_list)
        }
    )

191
ai-service/app/api/chat.py Normal file
View File

@ -0,0 +1,191 @@
"""
Chat endpoint for AI Service.
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-08, AC-AISVC-09] Main chat endpoint with streaming/non-streaming modes.
"""
import logging
from typing import Annotated, Any
from fastapi import APIRouter, Depends, Header, Request
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.middleware import get_response_mode, is_sse_request
from app.core.sse import SSEStateMachine, create_error_event
from app.core.tenant import get_tenant_id
from app.models import ChatRequest, ChatResponse, ErrorResponse
from app.services.memory import MemoryService
from app.services.orchestrator import OrchestratorService
logger = logging.getLogger(__name__)
router = APIRouter(tags=["AI Chat"])
async def get_orchestrator_service_with_memory(
    session: Annotated[AsyncSession, Depends(get_session)]
) -> OrchestratorService:
    """
    [AC-AISVC-13] Build a per-request OrchestratorService.

    Wires a fresh MemoryService bound to this request's DB session together
    with the configured LLM client and the optimized retriever.
    """
    from app.services.llm.factory import get_llm_config_manager
    from app.services.retrieval.optimized_retriever import get_optimized_retriever
    retriever = await get_optimized_retriever()
    llm_client = get_llm_config_manager().get_client()
    return OrchestratorService(
        llm_client=llm_client,
        memory_service=MemoryService(session),
        retriever=retriever,
    )
@router.post(
    "/ai/chat",
    operation_id="generateReply",
    summary="Generate AI reply",
    description="""
    [AC-AISVC-01, AC-AISVC-02, AC-AISVC-06] Generate AI reply based on user message.
    Response mode is determined by Accept header:
    - Accept: text/event-stream -> SSE streaming response
    - Other -> JSON response
    """,
    responses={
        200: {
            "description": "Success - JSON or SSE stream",
            "content": {
                "application/json": {"schema": {"$ref": "#/components/schemas/ChatResponse"}},
                "text/event-stream": {"schema": {"type": "string"}},
            },
        },
        400: {"description": "Invalid request", "model": ErrorResponse},
        500: {"description": "Internal error", "model": ErrorResponse},
        503: {"description": "Service unavailable", "model": ErrorResponse},
    },
)
async def generate_reply(
    request: Request,
    chat_request: ChatRequest,
    accept: Annotated[str | None, Header()] = None,
    orchestrator: OrchestratorService = Depends(get_orchestrator_service_with_memory),
) -> Any:
    """
    [AC-AISVC-06] Generate AI reply with automatic response mode switching.
    Based on Accept header:
    - text/event-stream: Returns SSE stream with message/final/error events
    - Other: Returns JSON ChatResponse
    """
    # Tenant is normally set by TenantContextMiddleware; this is a safety net.
    tenant_id = get_tenant_id()
    if not tenant_id:
        from app.core.exceptions import MissingTenantIdException
        raise MissingTenantIdException()
    logger.info(
        f"[AC-AISVC-06] Processing chat request: tenant={tenant_id}, "
        f"session={chat_request.session_id}, mode={get_response_mode(request)}"
    )
    # Dispatch on the Accept header (is_sse_request inspects the Request).
    if is_sse_request(request):
        return await _handle_streaming_request(tenant_id, chat_request, orchestrator)
    else:
        return await _handle_json_request(tenant_id, chat_request, orchestrator)
async def _handle_json_request(
    tenant_id: str,
    chat_request: ChatRequest,
    orchestrator: OrchestratorService,
) -> JSONResponse:
    """
    [AC-AISVC-02] Handle non-streaming JSON request.

    Runs the orchestrator once and serializes the ChatResponse (camelCase
    aliases, None fields dropped). Known AIServiceException errors are
    re-raised unchanged for the global handler; anything else is wrapped
    in an INTERNAL_ERROR AIServiceException with the cause chained.
    """
    logger.info(f"[AC-AISVC-02] Processing JSON request for tenant={tenant_id}")
    try:
        response = await orchestrator.generate(tenant_id, chat_request)
        return JSONResponse(
            content=response.model_dump(exclude_none=True, by_alias=True),
        )
    except Exception as e:
        logger.error(f"[AC-AISVC-04] Error generating response: {e}")
        # Local import to avoid a circular import at module load time.
        # (The previous version imported AIServiceException twice.)
        from app.core.exceptions import AIServiceException, ErrorCode
        if isinstance(e, AIServiceException):
            raise  # re-raise as-is, preserving the original traceback
        raise AIServiceException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
        ) from e
async def _handle_streaming_request(
    tenant_id: str,
    chat_request: ChatRequest,
    orchestrator: OrchestratorService,
) -> EventSourceResponse:
    """
    [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Handle SSE streaming request.
    SSE Event Sequence (per design.md Section 6.2):
    - message* (0 or more) -> final (exactly 1) -> close
    - OR message* (0 or more) -> error (exactly 1) -> close
    State machine ensures:
    - No events after final/error
    - Only one final OR one error event
    - Proper connection close
    """
    logger.info(f"[AC-AISVC-06] Processing SSE request for tenant={tenant_id}")
    state_machine = SSEStateMachine()
    async def event_generator():
        """
        [AC-AISVC-08, AC-AISVC-09] Event generator with state machine enforcement.
        Ensures proper event sequence and error handling.
        """
        await state_machine.transition_to_streaming()
        try:
            async for event in orchestrator.generate_stream(tenant_id, chat_request):
                # Drop any events that arrive after a terminal transition.
                if not state_machine.can_send_message():
                    logger.warning("[AC-AISVC-08] Received event after state machine closed, ignoring")
                    break
                if event.event == "final":
                    # transition_to_final returns False if already terminal,
                    # guaranteeing at most one final event.
                    if await state_machine.transition_to_final():
                        logger.info("[AC-AISVC-08] Sending final event and closing stream")
                        yield event
                    break
                elif event.event == "error":
                    if await state_machine.transition_to_error():
                        logger.info("[AC-AISVC-09] Sending error event and closing stream")
                        yield event
                    break
                elif event.event == "message":
                    yield event
        except Exception as e:
            # Mid-stream failure: emit a single error event if still allowed.
            logger.error(f"[AC-AISVC-09] Streaming error: {e}")
            if await state_machine.transition_to_error():
                yield create_error_event(
                    code="STREAMING_ERROR",
                    message=str(e),
                )
        finally:
            await state_machine.close()
            logger.debug("[AC-AISVC-08] SSE connection closed")
    # ping=15 keeps intermediaries from timing out idle SSE connections.
    return EventSourceResponse(event_generator(), ping=15)

View File

@ -0,0 +1,30 @@
"""
Health check endpoint.
[AC-AISVC-20] Health check for service monitoring.
"""
from fastapi import APIRouter, status
from fastapi.responses import JSONResponse
router = APIRouter(tags=["Health"])
@router.get(
    "/ai/health",
    operation_id="healthCheck",
    summary="Health check",
    description="[AC-AISVC-20] Check if AI service is healthy",
    responses={
        200: {"description": "Service is healthy"},
        503: {"description": "Service is unhealthy"},
    },
)
async def health_check() -> JSONResponse:
    """
    [AC-AISVC-20] Liveness probe.

    Reports healthy with HTTP 200 whenever the process can serve requests.
    """
    payload = {"status": "healthy"}
    return JSONResponse(status_code=status.HTTP_200_OK, content=payload)

View File

@ -0,0 +1,19 @@
"""
Core module - Configuration, dependencies, and utilities.
[AC-AISVC-01, AC-AISVC-10, AC-AISVC-11] Core infrastructure components.
"""
from app.core.config import Settings, get_settings
from app.core.database import async_session_maker, get_session, init_db, close_db
from app.core.qdrant_client import QdrantClient, get_qdrant_client
__all__ = [
"Settings",
"get_settings",
"async_session_maker",
"get_session",
"init_db",
"close_db",
"QdrantClient",
"get_qdrant_client",
]

View File

@ -0,0 +1,66 @@
"""
Configuration management for AI Service.
[AC-AISVC-01] Centralized configuration with environment variable support.
"""
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Service configuration; every field is overridable via AI_SERVICE_* env vars or .env."""
    model_config = SettingsConfigDict(env_prefix="AI_SERVICE_", env_file=".env", extra="ignore")
    # --- Application ---
    app_name: str = "AI Service"
    app_version: str = "0.1.0"
    debug: bool = False
    host: str = "0.0.0.0"
    port: int = 8080
    request_timeout_seconds: int = 20
    sse_ping_interval_seconds: int = 15
    log_level: str = "INFO"
    # --- LLM provider ---
    llm_provider: str = "openai"
    llm_api_key: str = ""
    llm_base_url: str = "https://api.openai.com/v1"
    llm_model: str = "gpt-4o-mini"
    llm_max_tokens: int = 2048
    llm_temperature: float = 0.7
    llm_timeout_seconds: int = 30
    llm_max_retries: int = 3
    # --- Database ---
    database_url: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/ai_service"
    database_pool_size: int = 10
    database_max_overflow: int = 20
    # --- Vector store / embeddings ---
    qdrant_url: str = "http://localhost:6333"
    qdrant_collection_prefix: str = "kb_"
    qdrant_vector_size: int = 768
    ollama_base_url: str = "http://localhost:11434"
    ollama_embedding_model: str = "nomic-embed-text"
    # --- RAG retrieval tuning ---
    rag_top_k: int = 5
    rag_score_threshold: float = 0.01
    rag_min_hits: int = 1
    rag_max_evidence_tokens: int = 2000
    rag_two_stage_enabled: bool = True
    rag_two_stage_expand_factor: int = 10
    rag_hybrid_enabled: bool = True
    rag_rrf_k: int = 60
    rag_vector_weight: float = 0.7
    rag_bm25_weight: float = 0.3
    # --- Answer confidence / history ---
    confidence_low_threshold: float = 0.5
    confidence_high_threshold: float = 0.8
    confidence_insufficient_penalty: float = 0.3
    max_history_tokens: int = 4000
@lru_cache
def get_settings() -> Settings:
    # Cached singleton: the environment/.env is read once per process.
    return Settings()

View File

@ -0,0 +1,67 @@
"""
Database client for AI Service.
[AC-AISVC-11] PostgreSQL database with SQLModel for multi-tenant data isolation.
"""
import logging
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from sqlmodel import SQLModel
from app.core.config import get_settings
logger = logging.getLogger(__name__)
# Settings are resolved at import time; the engine below is module-global.
settings = get_settings()
# NOTE(review): NullPool is imported above but unused — a sized QueuePool
# is configured here instead; remove the import or switch deliberately.
engine = create_async_engine(
    settings.database_url,
    pool_size=settings.database_pool_size,
    max_overflow=settings.database_max_overflow,
    echo=settings.debug,  # SQL echo only in debug mode
    pool_pre_ping=True,  # validate connections before use
)
# Session factory used both by the get_session dependency and ad-hoc sessions.
async_session_maker = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)
async def init_db() -> None:
    """
    [AC-AISVC-11] Create every table registered on SQLModel metadata.

    Intended for startup; create_all leaves existing tables untouched.
    """
    async with engine.begin() as connection:
        await connection.run_sync(SQLModel.metadata.create_all)
    logger.info("[AC-AISVC-11] Database tables initialized")
async def close_db() -> None:
    """Dispose of the engine's connection pool during application shutdown."""
    await engine.dispose()
    logger.info("Database connections closed")
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    [AC-AISVC-11] Dependency injection for database session.
    Ensures proper session lifecycle management.

    Commits after the request handler finishes successfully; rolls back
    and re-raises on any exception so FastAPI's handlers see it.
    """
    async with async_session_maker() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            # NOTE(review): the async context manager already closes the
            # session on exit; this explicit close is redundant but harmless.
            await session.close()

View File

@ -0,0 +1,99 @@
"""
Exception handling for AI Service.
[AC-AISVC-03, AC-AISVC-04, AC-AISVC-05] Structured error responses.
"""
from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse
from app.models import ErrorCode, ErrorResponse
class AIServiceException(Exception):
    """Base service exception carrying a structured error code, message, HTTP status, and optional detail records."""
    def __init__(
        self,
        code: ErrorCode,
        message: str,
        status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
        details: list[dict] | None = None,
    ):
        # These attributes map directly onto ErrorResponse fields in the
        # exception handlers below.
        self.code = code
        self.message = message
        self.status_code = status_code
        self.details = details
        super().__init__(message)
class MissingTenantIdException(AIServiceException):
    """Raised when the mandatory X-Tenant-Id header is absent (HTTP 400)."""

    def __init__(self, message: str = "Missing required header: X-Tenant-Id"):
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            code=ErrorCode.MISSING_TENANT_ID,
            message=message,
        )
class InvalidRequestException(AIServiceException):
    """Raised for malformed or invalid client input (HTTP 400), with optional field details."""

    def __init__(self, message: str, details: list[dict] | None = None):
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            code=ErrorCode.INVALID_REQUEST,
            message=message,
            details=details,
        )
class ServiceUnavailableException(AIServiceException):
    """Raised when a downstream dependency is temporarily down (HTTP 503)."""

    def __init__(self, message: str = "Service temporarily unavailable"):
        super().__init__(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            code=ErrorCode.SERVICE_UNAVAILABLE,
            message=message,
        )
class TimeoutException(AIServiceException):
    """Raised when a request exceeds its time budget; surfaced as HTTP 503."""

    def __init__(self, message: str = "Request timeout"):
        super().__init__(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            code=ErrorCode.TIMEOUT,
            message=message,
        )
async def ai_service_exception_handler(request: Request, exc: AIServiceException) -> JSONResponse:
    """Render an AIServiceException as a structured ErrorResponse body with its own status code."""
    body = ErrorResponse(
        code=exc.code.value,
        message=exc.message,
        details=exc.details,
    ).model_dump(exclude_none=True)
    return JSONResponse(status_code=exc.status_code, content=body)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Map FastAPI HTTPExceptions onto the service's ErrorResponse schema."""
    # Known status codes map to specific error codes; everything else is internal.
    status_to_code = {
        status.HTTP_400_BAD_REQUEST: ErrorCode.INVALID_REQUEST,
        status.HTTP_503_SERVICE_UNAVAILABLE: ErrorCode.SERVICE_UNAVAILABLE,
    }
    code = status_to_code.get(exc.status_code, ErrorCode.INTERNAL_ERROR)
    body = ErrorResponse(
        code=code.value,
        message=exc.detail or "An error occurred",
    ).model_dump(exclude_none=True)
    return JSONResponse(status_code=exc.status_code, content=body)
async def generic_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Last-resort handler: hide internal errors behind a generic INTERNAL_ERROR body."""
    body = ErrorResponse(
        code=ErrorCode.INTERNAL_ERROR.value,
        message="An unexpected error occurred",
    ).model_dump(exclude_none=True)
    return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=body)

View File

@ -0,0 +1,150 @@
"""
Middleware for AI Service.
[AC-AISVC-10, AC-AISVC-12] X-Tenant-Id header validation and tenant context injection.
"""
import logging
import re
from typing import Callable
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.core.exceptions import ErrorCode, ErrorResponse, MissingTenantIdException
from app.core.tenant import clear_tenant_context, set_tenant_context
logger = logging.getLogger(__name__)
TENANT_ID_HEADER = "X-Tenant-Id"
ACCEPT_HEADER = "Accept"
SSE_CONTENT_TYPE = "text/event-stream"
# Tenant ID format: name@ash@year (e.g., szmp@ash@2026)
TENANT_ID_PATTERN = re.compile(r'^[^@]+@ash@\d{4}$')
def validate_tenant_id_format(tenant_id: str) -> bool:
    """
    [AC-AISVC-10] Validate tenant ID format: name@ash@year
    Examples: szmp@ash@2026, abc123@ash@2025
    """
    return TENANT_ID_PATTERN.match(tenant_id) is not None
def parse_tenant_id(tenant_id: str) -> tuple[str, str]:
    """
    [AC-AISVC-10] Extract (name, year) from a name@ash@year tenant ID.

    Assumes the ID already passed validate_tenant_id_format.
    """
    segments = tenant_id.split('@')
    return (segments[0], segments[2])
class TenantContextMiddleware(BaseHTTPMiddleware):
    """
    [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header.

    Injects tenant context into request state for downstream processing.
    Validates tenant ID format and auto-creates the tenant row when it does
    not exist yet. The /ai/health endpoint is exempt from tenant checks.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Start every request from a clean slate so tenant state never leaks
        # between requests that reuse the same task context.
        clear_tenant_context()
        if request.url.path == "/ai/health":
            return await call_next(request)
        tenant_id = request.headers.get(TENANT_ID_HEADER)
        if not tenant_id or not tenant_id.strip():
            logger.warning("[AC-AISVC-12] Missing or empty X-Tenant-Id header")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=ErrorResponse(
                    code=ErrorCode.MISSING_TENANT_ID.value,
                    message="Missing required header: X-Tenant-Id",
                ).model_dump(exclude_none=True),
            )
        tenant_id = tenant_id.strip()
        # Validate tenant ID format (name@ash@year)
        if not validate_tenant_id_format(tenant_id):
            logger.warning(f"[AC-AISVC-10] Invalid tenant ID format: {tenant_id}")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=ErrorResponse(
                    code=ErrorCode.INVALID_TENANT_ID.value,
                    message="Invalid tenant ID format. Expected: name@ash@year (e.g., szmp@ash@2026)",
                ).model_dump(exclude_none=True),
            )
        # Auto-create tenant if not exists (applies to both /admin/ and /ai/ routes)
        if request.url.path.startswith("/admin/") or request.url.path.startswith("/ai/"):
            try:
                await self._ensure_tenant_exists(request, tenant_id)
            except Exception as e:
                # Best-effort: failure to persist the tenant must not block the request.
                logger.error(f"[AC-AISVC-10] Failed to ensure tenant exists: {e}")
                # Continue processing even if tenant creation fails
        set_tenant_context(tenant_id)
        request.state.tenant_id = tenant_id
        logger.info(f"[AC-AISVC-10] Tenant context set: tenant_id={tenant_id}")
        try:
            response = await call_next(request)
        finally:
            # Always clear, even when the downstream app raises.
            clear_tenant_context()
        return response

    async def _ensure_tenant_exists(self, request: Request, tenant_id: str) -> None:
        """
        [AC-AISVC-10] Ensure tenant exists in database, create if not.

        Opens its own short-lived session; commits only when a new row is inserted.
        """
        # Imported lazily to avoid a hard DB dependency at middleware import time.
        # (Removed unused `AsyncSession` import from the original.)
        from sqlalchemy import select
        from app.core.database import async_session_maker
        from app.models.entities import Tenant
        name, year = parse_tenant_id(tenant_id)
        async with async_session_maker() as session:
            # Check if tenant exists
            stmt = select(Tenant).where(Tenant.tenant_id == tenant_id)
            result = await session.execute(stmt)
            existing_tenant = result.scalar_one_or_none()
            if existing_tenant:
                logger.debug(f"[AC-AISVC-10] Tenant already exists: {tenant_id}")
                return
            # Create new tenant
            new_tenant = Tenant(
                tenant_id=tenant_id,
                name=name,
                year=year,
            )
            session.add(new_tenant)
            await session.commit()
            logger.info(f"[AC-AISVC-10] Auto-created new tenant: {tenant_id} (name={name}, year={year})")
def is_sse_request(request: Request) -> bool:
    """
    [AC-AISVC-06] Report whether the client asked for SSE streaming.
    An Accept header containing text/event-stream selects SSE mode.
    """
    return SSE_CONTENT_TYPE in request.headers.get(ACCEPT_HEADER, "")
def get_response_mode(request: Request) -> str:
    """
    [AC-AISVC-06] Map the Accept header to a response-mode label.
    Returns 'streaming' when the client wants SSE, otherwise 'json'.
    """
    if is_sse_request(request):
        return "streaming"
    return "json"

View File

@ -0,0 +1,114 @@
"""
Shared prompt templates for AI Service.
Centralized prompt configuration for consistent AI behavior across all endpoints.
"""
# Shared customer-service persona prompt used by every chat endpoint.
# Runtime string (sent to the LLM) — do not translate or reformat.
# NOTE(review): punctuation in some lines looks garbled relative to the first
# line's style — confirm against the original commit before relying on exact text.
SYSTEM_PROMPT = """你是一名经验丰富的客服专员,名字叫"小N"。请遵循以下准则:
像朋友一样聊天使用自然口语化的中文语气亲切有礼貌就像在和一位老朋友聊天多使用"咱们""您看""是不是"这样拉近距离的表达
避免生硬罗列介绍多个信息时不要用"第一、第二"或分点列举而是用流畅的段落把它们串联起来比如"我们这边有几类课程挺受欢迎的,一个是……另外还有……您对哪方面更感兴趣呀?"
主动引导对话不要一次性把知识库里的所有内容都倒出来可以先简单概括然后问问客户具体想了解什么比如"您是想提升职场英语,还是有留学方面的打算呢?"
委婉处理未知如果知识库里没有答案别直接说"不知道"可以说"您稍等,我帮您确认一下""这个问题我需要问问同事,马上给您回复"
隐藏机器身份绝不要提及"作为AI""根据算法"等词也避免使用过于官方的术语就像普通人一样用"""我们"
多用语气词和情感表达适当加入""""""""等语气词根据情况使用表情符号😊表达理解和关心比如"我明白您的意思""别着急,我们一起看看"
保持简洁但有温度回答一般控制在3-5句话如果需要详细说明可以分成小段落但每段都要口语化读起来不累
隐私与安全提醒如果客户问到敏感信息如密码转账要温和地引导至人工渠道"为了您的信息安全建议您拨打官方电话400-xxx-xxxx咨询会更稳妥哦。"""
def format_evidence_for_prompt(
    retrieval_results: list,
    max_results: int = 5,
    max_content_length: int = 500
) -> str:
    """
    Render retrieval hits as a numbered evidence section for prompt building.

    Args:
        retrieval_results: Retrieval hits; each item is either a dict with
            'content'/'score'/'source'/'metadata' keys or an object exposing
            .text/.score/.source/.metadata attributes (e.g. a RetrievalHit).
        max_results: Upper bound on how many hits are rendered.
        max_content_length: Per-hit content snippet length cap (longer
            snippets are truncated with a trailing '...').

    Returns:
        Formatted evidence text, or "" when there are no hits.
    """
    if not retrieval_results:
        return ""
    sections = []
    for index, hit in enumerate(retrieval_results[:max_results]):
        # Object-style hits carry a .text attribute; everything else is
        # treated as the dict shape.
        if hasattr(hit, 'text'):
            content = hit.text
            score = hit.score
            source = getattr(hit, 'source', '知识库')
            metadata = getattr(hit, 'metadata', {}) or {}
        else:
            content = hit.get('content', '')
            score = hit.get('score', 0)
            source = hit.get('source', '知识库')
            metadata = hit.get('metadata', {}) or {}
        if len(content) > max_content_length:
            content = content[:max_content_length] + '...'
        # Document-level attributes live one level down under metadata['metadata'].
        nested_meta = metadata.get('metadata', {})
        if nested_meta:
            source_doc = nested_meta.get('source_doc', source)
            category = nested_meta.get('category', '')
            department = nested_meta.get('department', '')
        else:
            source_doc = source
            category = ''
            department = ''
        header = f"[文档{index + 1}]"
        # The default source label is suppressed to keep headers short.
        if source_doc and source_doc != "知识库":
            header += f" 来源:{source_doc}"
        if category:
            header += f" | 类别:{category}"
        if department:
            header += f" | 部门:{department}"
        sections.append(f"{header}\n相关度:{score:.2f}\n内容:{content}")
    return "\n\n".join(sections)
def build_system_prompt_with_evidence(evidence_text: str) -> str:
    """
    Build system prompt with knowledge base evidence appended.

    Args:
        evidence_text: Formatted evidence from retrieval results
            (output of format_evidence_for_prompt; may be "").
    Returns:
        Complete system prompt; falls back to the bare SYSTEM_PROMPT
        when there is no evidence.
    """
    if not evidence_text:
        return SYSTEM_PROMPT
    # Runtime string — the evidence section header below is part of the prompt.
    return f"""{SYSTEM_PROMPT}
知识库参考内容
{evidence_text}"""
def build_user_prompt_with_evidence(query: str, evidence_text: str) -> str:
    """
    Build user prompt with knowledge base evidence (single-message format,
    i.e. when system and user content are merged into one message).

    Args:
        query: User's question
        evidence_text: Formatted evidence from retrieval results (may be "")
    Returns:
        Complete user prompt; when no evidence is available the prompt
        explicitly tells the model to answer from general knowledge.
    """
    if not evidence_text:
        return f"""用户问题:{query}
未找到相关检索结果请基于通用知识回答用户问题"""
    # Runtime string — embeds SYSTEM_PROMPT plus the evidence and question sections.
    return f"""【系统指令】
{SYSTEM_PROMPT}
知识库内容
{evidence_text}
用户问题
{query}"""

View File

@ -0,0 +1,314 @@
"""
Qdrant client for AI Service.
[AC-AISVC-10] Vector database client with tenant-isolated collection management.
Supports multi-dimensional vectors for Matryoshka representation learning.
"""
import logging
from typing import Any
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams, MultiVectorConfig
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class QdrantClient:
    """
    [AC-AISVC-10] Qdrant client with tenant-isolated collection management.

    Collection naming: kb_{tenantId} for tenant isolation.
    Supports multi-dimensional vectors (256/512/768) for Matryoshka retrieval.
    Thin async wrapper around AsyncQdrantClient; the underlying connection is
    created lazily on first use and shared across calls.
    """

    def __init__(self):
        # Lazily-created underlying client; see get_client().
        self._client: AsyncQdrantClient | None = None
        self._collection_prefix = settings.qdrant_collection_prefix
        self._vector_size = settings.qdrant_vector_size

    async def get_client(self) -> AsyncQdrantClient:
        """Get or create the shared AsyncQdrantClient instance."""
        if self._client is None:
            self._client = AsyncQdrantClient(url=settings.qdrant_url)
            logger.info(f"[AC-AISVC-10] Qdrant client initialized: {settings.qdrant_url}")
        return self._client

    async def close(self) -> None:
        """Close Qdrant client connection (safe to call when never opened)."""
        if self._client:
            await self._client.close()
            self._client = None
            logger.info("Qdrant client connection closed")

    def get_collection_name(self, tenant_id: str) -> str:
        """
        [AC-AISVC-10] Get collection name for a tenant.
        Naming convention: kb_{tenantId}
        Replaces @ with _ to ensure valid collection names.
        """
        safe_tenant_id = tenant_id.replace('@', '_')
        return f"{self._collection_prefix}{safe_tenant_id}"

    async def ensure_collection_exists(self, tenant_id: str, use_multi_vector: bool = True) -> bool:
        """
        [AC-AISVC-10] Ensure collection exists for tenant, creating it if missing.
        Supports multi-dimensional vectors for Matryoshka retrieval.

        Returns True on success (whether the collection already existed or was
        just created); False when lookup/creation raised.
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)
        try:
            collections = await client.get_collections()
            exists = any(c.name == collection_name for c in collections.collections)
            if not exists:
                if use_multi_vector:
                    # Named vectors: one full-resolution 768-dim space plus
                    # truncated Matryoshka projections for cheaper search.
                    vectors_config = {
                        "full": VectorParams(
                            size=768,
                            distance=Distance.COSINE,
                        ),
                        "dim_256": VectorParams(
                            size=256,
                            distance=Distance.COSINE,
                        ),
                        "dim_512": VectorParams(
                            size=512,
                            distance=Distance.COSINE,
                        ),
                    }
                else:
                    # Single unnamed vector sized from settings.
                    vectors_config = VectorParams(
                        size=self._vector_size,
                        distance=Distance.COSINE,
                    )
                await client.create_collection(
                    collection_name=collection_name,
                    vectors_config=vectors_config,
                )
                logger.info(
                    f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id} "
                    f"with multi_vector={use_multi_vector}"
                )
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error ensuring collection: {e}")
            return False

    async def upsert_vectors(
        self,
        tenant_id: str,
        points: list[PointStruct],
    ) -> bool:
        """
        [AC-AISVC-10] Upsert pre-built PointStructs into the tenant's collection.
        Returns True on success, False when the upsert raised.
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)
        try:
            await client.upsert(
                collection_name=collection_name,
                points=points,
            )
            logger.info(
                f"[AC-AISVC-10] Upserted {len(points)} vectors for tenant={tenant_id}"
            )
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error upserting vectors: {e}")
            return False

    async def upsert_multi_vector(
        self,
        tenant_id: str,
        points: list[dict[str, Any]],
    ) -> bool:
        """
        Upsert points with multi-dimensional (named) vectors.

        Args:
            tenant_id: Tenant identifier
            points: List of points with format:
                {
                    "id": str | int,
                    "vector": {
                        "full": [768 floats],
                        "dim_256": [256 floats],
                        "dim_512": [512 floats],
                    },
                    "payload": dict
                }
        Returns:
            True on success, False when the upsert raised.
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)
        try:
            # Convert plain dicts to PointStructs; the vector dict maps
            # named vector spaces to their embeddings.
            qdrant_points = []
            for p in points:
                point = PointStruct(
                    id=p["id"],
                    vector=p["vector"],
                    payload=p.get("payload", {}),
                )
                qdrant_points.append(point)
            await client.upsert(
                collection_name=collection_name,
                points=qdrant_points,
            )
            logger.info(
                f"[RAG-OPT] Upserted {len(points)} multi-vector points for tenant={tenant_id}"
            )
            return True
        except Exception as e:
            logger.error(f"[RAG-OPT] Error upserting multi-vectors: {e}")
            return False

    async def search(
        self,
        tenant_id: str,
        query_vector: list[float],
        limit: int = 5,
        score_threshold: float | None = None,
        vector_name: str = "full",
    ) -> list[dict[str, Any]]:
        """
        [AC-AISVC-10] Search vectors in tenant's collection.
        Returns results with score >= score_threshold if specified.
        Searches both old format (with @) and new format (with _) for backward compatibility.

        Args:
            tenant_id: Tenant identifier
            query_vector: Query vector for similarity search
            limit: Maximum number of results
            score_threshold: Minimum score threshold for results
            vector_name: Name of the vector to search (for multi-vector collections)
                Default is "full" for 768-dim vectors in Matryoshka setup.

        Returns:
            List of {"id": str, "score": float, "payload": dict}, sorted by
            score descending, at most `limit` entries across all collections.
        """
        client = await self.get_client()
        logger.info(
            f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, "
            f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}"
        )
        collection_names = [self.get_collection_name(tenant_id)]
        # Backward compatibility: older collections were named with the raw
        # tenant id ('@' included); try the new '_' form first, then the old.
        if '@' in tenant_id:
            old_format = f"{self._collection_prefix}{tenant_id}"
            new_format = f"{self._collection_prefix}{tenant_id.replace('@', '_')}"
            collection_names = [new_format, old_format]
        logger.info(f"[AC-AISVC-10] Will search in collections: {collection_names}")
        all_hits = []
        for collection_name in collection_names:
            try:
                logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}")
                try:
                    # First attempt assumes a multi-vector collection with a
                    # named vector space.
                    results = await client.search(
                        collection_name=collection_name,
                        query_vector=(vector_name, query_vector),
                        limit=limit,
                    )
                except Exception as e:
                    # Fall back to single-vector mode when the collection has
                    # no named vectors (error text match is Qdrant-specific).
                    if "vector name" in str(e).lower() or "Not existing vector" in str(e):
                        logger.info(
                            f"[AC-AISVC-10] Collection {collection_name} doesn't have vector named '{vector_name}', "
                            f"trying without vector name (single-vector mode)"
                        )
                        results = await client.search(
                            collection_name=collection_name,
                            query_vector=query_vector,
                            limit=limit,
                        )
                    else:
                        raise
                logger.info(
                    f"[AC-AISVC-10] Collection {collection_name} returned {len(results)} raw results"
                )
                # Apply the score threshold client-side.
                hits = [
                    {
                        "id": str(result.id),
                        "score": result.score,
                        "payload": result.payload or {},
                    }
                    for result in results
                    if score_threshold is None or result.score >= score_threshold
                ]
                all_hits.extend(hits)
                if hits:
                    logger.info(
                        f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}"
                    )
                    for i, h in enumerate(hits[:3]):
                        logger.debug(
                            f"[AC-AISVC-10] Hit {i+1}: id={h['id']}, score={h['score']:.4f}"
                        )
                else:
                    logger.warning(
                        f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)"
                    )
            except Exception as e:
                # A missing collection is expected for one of the two name
                # formats; log and keep going.
                logger.warning(
                    f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}"
                )
                continue
        # Merge hits from both collections, best-scoring first, capped at limit.
        all_hits = sorted(all_hits, key=lambda x: x["score"], reverse=True)[:limit]
        logger.info(
            f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}"
        )
        if len(all_hits) == 0:
            logger.warning(
                f"[AC-AISVC-10] No results found! tenant={tenant_id}, "
                f"collections_tried={collection_names}, limit={limit}"
            )
        return all_hits

    async def delete_collection(self, tenant_id: str) -> bool:
        """
        [AC-AISVC-10] Delete tenant's collection.
        Used when tenant is removed. Returns True on success.
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)
        try:
            await client.delete_collection(collection_name=collection_name)
            logger.info(f"[AC-AISVC-10] Deleted collection: {collection_name}")
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error deleting collection: {e}")
            return False
# Process-wide singleton wrapper instance (lazily created).
_qdrant_client: QdrantClient | None = None


async def get_qdrant_client() -> QdrantClient:
    """Return the shared QdrantClient wrapper, creating it on first use."""
    global _qdrant_client
    if _qdrant_client is not None:
        return _qdrant_client
    _qdrant_client = QdrantClient()
    return _qdrant_client


async def close_qdrant_client() -> None:
    """Tear down the shared wrapper (and its connection) if one exists."""
    global _qdrant_client
    if _qdrant_client is not None:
        await _qdrant_client.close()
        _qdrant_client = None

173
ai-service/app/core/sse.py Normal file
View File

@ -0,0 +1,173 @@
"""
SSE utilities for AI Service.
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] SSE event generation and state machine.
"""
import asyncio
import json
import logging
from enum import Enum
from typing import Any, AsyncGenerator
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from app.core.config import get_settings
from app.models import SSEErrorEvent, SSEEventType, SSEFinalEvent, SSEMessageEvent
logger = logging.getLogger(__name__)


class SSEState(str, Enum):
    """[AC-AISVC-08, AC-AISVC-09] Lifecycle states of one SSE response stream."""
    INIT = "INIT"
    STREAMING = "STREAMING"
    FINAL_SENT = "FINAL_SENT"
    ERROR_SENT = "ERROR_SENT"
    CLOSED = "CLOSED"


class SSEStateMachine:
    """
    [AC-AISVC-08, AC-AISVC-09] SSE state machine ensuring proper event sequence.
    State transitions: INIT -> STREAMING -> FINAL_SENT/ERROR_SENT -> CLOSED
    All transitions are serialized through an asyncio.Lock so concurrent
    producers cannot double-send a terminal event.
    """

    def __init__(self):
        self._state = SSEState.INIT
        self._lock = asyncio.Lock()

    @property
    def state(self) -> SSEState:
        """Current state (unsynchronized snapshot)."""
        return self._state

    async def transition_to_streaming(self) -> bool:
        """INIT -> STREAMING. Returns True iff the transition happened."""
        async with self._lock:
            if self._state == SSEState.INIT:
                self._state = SSEState.STREAMING
                logger.debug("[AC-AISVC-07] SSE state transition: INIT -> STREAMING")
                return True
            return False

    async def transition_to_final(self) -> bool:
        """STREAMING -> FINAL_SENT. Returns True iff the transition happened."""
        async with self._lock:
            if self._state == SSEState.STREAMING:
                self._state = SSEState.FINAL_SENT
                logger.debug("[AC-AISVC-08] SSE state transition: STREAMING -> FINAL_SENT")
                return True
            return False

    async def transition_to_error(self) -> bool:
        """INIT or STREAMING -> ERROR_SENT. Returns True iff the transition happened."""
        async with self._lock:
            if self._state in (SSEState.INIT, SSEState.STREAMING):
                # Bug fix: capture the state *before* overwriting it; the
                # original logged after assignment, so the message always read
                # "ERROR_SENT -> ERROR_SENT".
                previous = self._state
                self._state = SSEState.ERROR_SENT
                logger.debug(f"[AC-AISVC-09] SSE state transition: {previous} -> ERROR_SENT")
                return True
            return False

    async def close(self) -> None:
        """Force CLOSED regardless of the current state (idempotent)."""
        async with self._lock:
            self._state = SSEState.CLOSED
            logger.debug("SSE state transition: -> CLOSED")

    def can_send_message(self) -> bool:
        """Incremental message events are allowed only while STREAMING."""
        return self._state == SSEState.STREAMING
def format_sse_event(event_type: SSEEventType, data: dict[str, Any]) -> ServerSentEvent:
    """Serialize a payload dict as a named SSE event (UTF-8 JSON, not ASCII-escaped)."""
    payload = json.dumps(data, ensure_ascii=False)
    return ServerSentEvent(event=event_type.value, data=payload)
def create_message_event(delta: str) -> ServerSentEvent:
    """[AC-AISVC-07] Wrap an incremental text chunk in a 'message' SSE event."""
    payload = SSEMessageEvent(delta=delta).model_dump()
    return format_sse_event(SSEEventType.MESSAGE, payload)
def create_final_event(
    reply: str,
    confidence: float,
    should_transfer: bool,
    transfer_reason: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> ServerSentEvent:
    """[AC-AISVC-08] Wrap the complete response in a 'final' SSE event.

    Uses camelCase aliases (by_alias) and drops None fields to match the
    public contract.
    """
    payload = SSEFinalEvent(
        reply=reply,
        confidence=confidence,
        should_transfer=should_transfer,
        transfer_reason=transfer_reason,
        metadata=metadata,
    ).model_dump(exclude_none=True, by_alias=True)
    return format_sse_event(SSEEventType.FINAL, payload)
def create_error_event(
    code: str,
    message: str,
    details: list[dict[str, Any]] | None = None,
) -> ServerSentEvent:
    """[AC-AISVC-09] Wrap an error payload in an 'error' SSE event (None fields omitted)."""
    payload = SSEErrorEvent(
        code=code,
        message=message,
        details=details,
    ).model_dump(exclude_none=True)
    return format_sse_event(SSEEventType.ERROR, payload)
async def ping_generator(interval_seconds: int) -> AsyncGenerator[str, None]:
    """
    [AC-AISVC-06] Yield SSE keep-alive comment frames forever.
    Emits ': ping' comment lines (ignored by clients, not events) on each
    interval so idle proxies/load balancers do not drop the connection.
    """
    frame = ": ping\n\n"
    while True:
        await asyncio.sleep(interval_seconds)
        yield frame
class SSEResponseBuilder:
    """
    Builder for SSE response with proper event sequencing and ping keep-alive.

    Wraps a caller-supplied event generator in state-machine checks so events
    stop flowing once a terminal state (error/closed) is reached. One builder
    instance owns one state machine, so it should not be reused across responses.
    """

    def __init__(self):
        self._state_machine = SSEStateMachine()
        self._settings = get_settings()

    async def build_response(
        self,
        content_generator: AsyncGenerator[ServerSentEvent, None],
    ) -> EventSourceResponse:
        """
        Build SSE response with ping keep-alive mechanism.
        [AC-AISVC-06] Implements ping keep-alive to prevent connection timeout.

        The caller's generator is expected to yield message events and the
        final event itself; this wrapper only adds error handling and cleanup.
        """
        async def event_generator() -> AsyncGenerator[ServerSentEvent, None]:
            await self._state_machine.transition_to_streaming()
            try:
                async for event in content_generator:
                    # Forward events only while the machine is still STREAMING;
                    # a terminal transition made elsewhere stops the stream.
                    if self._state_machine.can_send_message():
                        yield event
                    else:
                        break
            except Exception as e:
                logger.error(f"[AC-AISVC-09] Error during SSE streaming: {e}")
                # Only one terminal event is ever sent: the transition returns
                # False if a terminal state was already reached.
                if await self._state_machine.transition_to_error():
                    yield create_error_event(
                        code="STREAMING_ERROR",
                        message=str(e),
                    )
            finally:
                await self._state_machine.close()
        # ping= delegates periodic keep-alive comments to sse-starlette.
        return EventSourceResponse(
            event_generator(),
            ping=self._settings.sse_ping_interval_seconds,
        )

View File

@ -0,0 +1,31 @@
"""
Tenant context management.
[AC-AISVC-10, AC-AISVC-12] Multi-tenant isolation via X-Tenant-Id header.
"""
from contextvars import ContextVar
from dataclasses import dataclass
tenant_context: ContextVar["TenantContext | None"] = ContextVar("tenant_context", default=None)
@dataclass
class TenantContext:
tenant_id: str
def set_tenant_context(tenant_id: str) -> None:
tenant_context.set(TenantContext(tenant_id=tenant_id))
def get_tenant_context() -> TenantContext | None:
return tenant_context.get()
def get_tenant_id() -> str | None:
ctx = get_tenant_context()
return ctx.tenant_id if ctx else None
def clear_tenant_context() -> None:
tenant_context.set(None)

134
ai-service/app/main.py Normal file
View File

@ -0,0 +1,134 @@
"""
Main FastAPI application for AI Service.
[AC-AISVC-01] Entry point with middleware and exception handlers.
"""
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, status
from fastapi.exceptions import HTTPException, RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.api import chat_router, health_router
from app.api.admin import dashboard_router, embedding_router, kb_router, llm_router, rag_router, sessions_router, tenants_router
from app.api.admin.kb_optimized import router as kb_optimized_router
from app.core.config import get_settings
from app.core.database import close_db, init_db
from app.core.exceptions import (
AIServiceException,
ErrorCode,
ErrorResponse,
ai_service_exception_handler,
generic_exception_handler,
http_exception_handler,
)
from app.core.middleware import TenantContextMiddleware
from app.core.qdrant_client import close_qdrant_client
settings = get_settings()
# Root logging config applied process-wide; level string comes from settings
# (e.g. "INFO") and is resolved to the logging module's constant.
logging.basicConfig(
    level=getattr(logging, settings.log_level.upper()),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    [AC-AISVC-01, AC-AISVC-11] Application lifespan manager.
    Handles startup and shutdown of database and external connections.

    DB init failures are logged and tolerated so the service can still boot
    (e.g. for health checks) without a reachable database.
    """
    logger.info(f"[AC-AISVC-01] Starting {settings.app_name} v{settings.app_version}")
    try:
        await init_db()
        logger.info("[AC-AISVC-11] Database initialized successfully")
    except Exception as e:
        logger.warning(f"[AC-AISVC-11] Database initialization skipped: {e}")
    yield
    # Shutdown: release DB and vector-store connections.
    await close_db()
    await close_qdrant_client()
    logger.info(f"Shutting down {settings.app_name}")
# [AC-AISVC-01] Application instance; the description below is rendered in
# the OpenAPI docs UI (runtime string — keep as-is).
app = FastAPI(
    title=settings.app_name,
    version=settings.app_version,
    description="""
Python AI Service for intelligent chat with RAG support.
## Features
- Multi-tenant isolation via X-Tenant-Id header
- SSE streaming support via Accept: text/event-stream
- RAG-powered responses with confidence scoring
## Response Modes
- **JSON**: Default response mode (Accept: application/json or no Accept header)
- **SSE Streaming**: Set Accept: text/event-stream for streaming responses
""",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)
# Middleware stack. In Starlette the last-added middleware runs first, so
# TenantContextMiddleware (added below) executes before CORS on each request.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is a
# very permissive CORS policy — confirm this is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(TenantContextMiddleware)
# Exception handlers, most specific first; the bare Exception handler is the
# catch-all that maps anything uncaught to a generic 500.
app.add_exception_handler(AIServiceException, ai_service_exception_handler)
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(Exception, generic_exception_handler)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """
    [AC-AISVC-03] Handle request validation errors with structured response.
    Converts FastAPI/pydantic validation errors into the shared ErrorResponse shape.
    """
    errors = exc.errors()
    logger.warning(f"[AC-AISVC-03] Request validation error: {errors}")
    details = []
    for err in errors:
        details.append({"loc": list(err["loc"]), "msg": err["msg"], "type": err["type"]})
    payload = ErrorResponse(
        code=ErrorCode.INVALID_REQUEST.value,
        message="Request validation failed",
        details=details,
    )
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content=payload.model_dump(exclude_none=True),
    )
# Route registration: public health/chat endpoints first, then admin routers.
app.include_router(health_router)
app.include_router(chat_router)
app.include_router(dashboard_router)
app.include_router(embedding_router)
app.include_router(kb_router)
app.include_router(kb_optimized_router)
app.include_router(llm_router)
app.include_router(rag_router)
app.include_router(sessions_router)
app.include_router(tenants_router)
if __name__ == "__main__":
    # Dev entry point; production deployments run uvicorn/gunicorn externally.
    import uvicorn
    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        reload=settings.debug,
    )

View File

@ -0,0 +1,89 @@
"""
Data models for AI Service.
[AC-AISVC-02] Request/Response models aligned with OpenAPI contract.
[AC-AISVC-13] Entity models for database persistence.
"""
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class ChannelType(str, Enum):
    """Supported customer-chat channels."""
    WECHAT = "wechat"
    DOUYIN = "douyin"
    JD = "jd"


class Role(str, Enum):
    """Speaker role for a chat message."""
    USER = "user"
    ASSISTANT = "assistant"


class ChatMessage(BaseModel):
    """One turn of conversation history."""
    role: Role = Field(..., description="Message role: user or assistant")
    content: str = Field(..., description="Message content")


class ChatRequest(BaseModel):
    """[AC-AISVC-02] Inbound chat request; wire format uses camelCase aliases."""
    session_id: str = Field(..., alias="sessionId", description="Session ID for conversation tracking")
    current_message: str = Field(..., alias="currentMessage", description="Current user message")
    channel_type: ChannelType = Field(..., alias="channelType", description="Channel type: wechat, douyin, jd")
    history: list[ChatMessage] | None = Field(default=None, description="Optional conversation history")
    metadata: dict[str, Any] | None = Field(default=None, description="Optional metadata")
    # Accept both snake_case field names and camelCase aliases on input.
    model_config = {"populate_by_name": True}


class ChatResponse(BaseModel):
    """[AC-AISVC-02] Non-streaming chat response; wire format uses camelCase aliases."""
    reply: str = Field(..., description="AI generated reply content")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score between 0.0 and 1.0")
    should_transfer: bool = Field(..., alias="shouldTransfer", description="Whether to suggest transfer to human agent")
    transfer_reason: str | None = Field(default=None, alias="transferReason", description="Reason for transfer suggestion")
    metadata: dict[str, Any] | None = Field(default=None, description="Response metadata")
    model_config = {"populate_by_name": True}
class ErrorCode(str, Enum):
    """Stable machine-readable error codes shared by JSON and SSE error payloads."""
    INVALID_REQUEST = "INVALID_REQUEST"
    MISSING_TENANT_ID = "MISSING_TENANT_ID"
    INVALID_TENANT_ID = "INVALID_TENANT_ID"
    INTERNAL_ERROR = "INTERNAL_ERROR"
    SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
    TIMEOUT = "TIMEOUT"
    LLM_ERROR = "LLM_ERROR"
    RETRIEVAL_ERROR = "RETRIEVAL_ERROR"


class ErrorResponse(BaseModel):
    """Structured error body returned by all error handlers."""
    code: str = Field(..., description="Error code")
    message: str = Field(..., description="Error message")
    details: list[dict[str, Any]] | None = Field(default=None, description="Detailed error information")


class SSEEventType(str, Enum):
    """SSE event names emitted by the streaming chat endpoint."""
    MESSAGE = "message"
    FINAL = "final"
    ERROR = "error"


class SSEMessageEvent(BaseModel):
    """[AC-AISVC-07] Payload of a 'message' event: one incremental chunk."""
    delta: str = Field(..., description="Incremental text content")


class SSEFinalEvent(BaseModel):
    """[AC-AISVC-08] Payload of the terminal 'final' event (camelCase aliases on the wire)."""
    reply: str = Field(..., description="Complete AI reply")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    should_transfer: bool = Field(..., alias="shouldTransfer", description="Transfer suggestion")
    transfer_reason: str | None = Field(default=None, alias="transferReason", description="Transfer reason")
    metadata: dict[str, Any] | None = Field(default=None, description="Response metadata")
    model_config = {"populate_by_name": True}


class SSEErrorEvent(BaseModel):
    """[AC-AISVC-09] Payload of the terminal 'error' event."""
    code: str = Field(..., description="Error code")
    message: str = Field(..., description="Error message")
    details: list[dict[str, Any]] | None = Field(default=None, description="Error details")

View File

@ -0,0 +1,200 @@
"""
Memory layer entities for AI Service.
[AC-AISVC-13] SQLModel entities for chat sessions and messages with tenant isolation.
"""
import uuid
from datetime import datetime
from enum import Enum
from typing import Any
from sqlalchemy import Column, JSON
from sqlmodel import Field, Index, SQLModel
class ChatSession(SQLModel, table=True):
    """
    [AC-AISVC-13] Chat session entity with tenant isolation.
    Primary key: (tenant_id, session_id) composite unique constraint.
    """
    __tablename__ = "chat_sessions"
    __table_args__ = (
        Index("ix_chat_sessions_tenant_session", "tenant_id", "session_id", unique=True),
        Index("ix_chat_sessions_tenant_id", "tenant_id"),
    )
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
    session_id: str = Field(..., description="Session ID for conversation tracking")
    channel_type: str | None = Field(default=None, description="Channel type: wechat, douyin, jd")
    # Attribute is suffixed with '_' because SQLModel/SQLAlchemy reserve
    # `.metadata` on declarative classes; the DB column is still "metadata".
    metadata_: dict[str, Any] | None = Field(
        default=None,
        sa_column=Column("metadata", JSON, nullable=True),
        description="Session metadata"
    )
    # NOTE(review): datetime.utcnow is deprecated in Python 3.12 and yields a
    # naive timestamp; consider datetime.now(timezone.utc) after confirming
    # column timezone handling.
    created_at: datetime = Field(default_factory=datetime.utcnow, description="Session creation time")
    updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")


class ChatMessage(SQLModel, table=True):
    """
    [AC-AISVC-13] Chat message entity with tenant isolation.
    Messages are scoped by (tenant_id, session_id) for multi-tenant security.
    Also records token usage and latency metrics per assistant turn.
    """
    __tablename__ = "chat_messages"
    __table_args__ = (
        Index("ix_chat_messages_tenant_session", "tenant_id", "session_id"),
        Index("ix_chat_messages_tenant_session_created", "tenant_id", "session_id", "created_at"),
    )
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
    session_id: str = Field(..., description="Session ID for conversation tracking", index=True)
    role: str = Field(..., description="Message role: user or assistant")
    content: str = Field(..., description="Message content")
    prompt_tokens: int | None = Field(default=None, description="Number of prompt tokens used")
    completion_tokens: int | None = Field(default=None, description="Number of completion tokens used")
    total_tokens: int | None = Field(default=None, description="Total tokens used")
    latency_ms: int | None = Field(default=None, description="Response latency in milliseconds")
    first_token_ms: int | None = Field(default=None, description="Time to first token in milliseconds (for streaming)")
    is_error: bool = Field(default=False, description="Whether this message is an error response")
    error_message: str | None = Field(default=None, description="Error message if any")
    created_at: datetime = Field(default_factory=datetime.utcnow, description="Message creation time")


class ChatSessionCreate(SQLModel):
    """Schema for creating a new chat session."""
    tenant_id: str
    session_id: str
    channel_type: str | None = None
    metadata_: dict[str, Any] | None = None


class ChatMessageCreate(SQLModel):
    """Schema for creating a new chat message."""
    tenant_id: str
    session_id: str
    role: str
    content: str
class DocumentStatus(str, Enum):
    """Lifecycle of an uploaded document."""
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class IndexJobStatus(str, Enum):
    """Lifecycle of a document indexing job."""
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class SessionStatus(str, Enum):
    """Lifecycle of a chat session."""
    ACTIVE = "active"
    CLOSED = "closed"
    EXPIRED = "expired"


class Tenant(SQLModel, table=True):
    """
    [AC-AISVC-10] Tenant entity for storing tenant information.
    Tenant ID format: name@ash@year (e.g., szmp@ash@2026); name and year are
    denormalized from tenant_id for easier querying.
    """
    __tablename__ = "tenants"
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    tenant_id: str = Field(..., description="Full tenant ID (format: name@ash@year)", unique=True, index=True)
    name: str = Field(..., description="Tenant display name (first part of tenant_id)")
    year: str = Field(..., description="Year part from tenant_id")
    created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
    updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
class KnowledgeBase(SQLModel, table=True):
"""
[AC-ASA-01] Knowledge base entity with tenant isolation.
"""
__tablename__ = "knowledge_bases"
__table_args__ = (
Index("ix_knowledge_bases_tenant_id", "tenant_id"),
)
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
name: str = Field(..., description="Knowledge base name")
description: str | None = Field(default=None, description="Knowledge base description")
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
class Document(SQLModel, table=True):
"""
[AC-ASA-01, AC-ASA-08] Document entity with tenant isolation.
"""
__tablename__ = "documents"
__table_args__ = (
Index("ix_documents_tenant_kb", "tenant_id", "kb_id"),
Index("ix_documents_tenant_status", "tenant_id", "status"),
)
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
kb_id: str = Field(..., description="Knowledge base ID")
file_name: str = Field(..., description="Original file name")
file_path: str | None = Field(default=None, description="Storage path")
file_size: int | None = Field(default=None, description="File size in bytes")
file_type: str | None = Field(default=None, description="File MIME type")
status: str = Field(default=DocumentStatus.PENDING.value, description="Document status")
error_msg: str | None = Field(default=None, description="Error message if failed")
created_at: datetime = Field(default_factory=datetime.utcnow, description="Upload time")
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
class IndexJob(SQLModel, table=True):
    """
    [AC-ASA-02] Index job entity for tracking document indexing progress.

    One row per indexing attempt of a document; composite indexes back the
    per-tenant "jobs for document" and "jobs by status" queries.
    """
    __tablename__ = "index_jobs"
    __table_args__ = (
        Index("ix_index_jobs_tenant_doc", "tenant_id", "doc_id"),
        Index("ix_index_jobs_tenant_status", "tenant_id", "status"),
    )
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
    # NOTE(review): doc_id is a UUID while Document.kb_id is a str — confirm
    # key types are intentionally mixed across models.
    doc_id: uuid.UUID = Field(..., description="Document ID being indexed")
    status: str = Field(default=IndexJobStatus.PENDING.value, description="Job status")
    # Percentage constrained to 0..100 via ge/le validators.
    progress: int = Field(default=0, ge=0, le=100, description="Progress percentage")
    error_msg: str | None = Field(default=None, description="Error message if failed")
    created_at: datetime = Field(default_factory=datetime.utcnow, description="Job creation time")
    updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
class KnowledgeBaseCreate(SQLModel):
    """Request schema for creating a new knowledge base (not a table model)."""
    tenant_id: str  # owning tenant
    name: str  # display name
    description: str | None = None  # optional free-text description
class DocumentCreate(SQLModel):
    """Request schema for registering a new document (not a table model)."""
    tenant_id: str  # owning tenant
    kb_id: str  # target knowledge base
    file_name: str  # original upload name
    file_path: str | None = None  # storage path, if already persisted
    file_size: int | None = None  # size in bytes, if known
    file_type: str | None = None  # MIME type, if known

View File

@ -0,0 +1,9 @@
"""
Services module for AI Service.
[AC-AISVC-13, AC-AISVC-16] Core services for memory and retrieval.
"""
from app.services.memory import MemoryService
from app.services.orchestrator import OrchestratorService, get_orchestrator_service
__all__ = ["MemoryService", "OrchestratorService", "get_orchestrator_service"]

View File

@ -0,0 +1,224 @@
"""
Confidence calculation for AI Service.
[AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Confidence scoring and transfer suggestion logic.
Design reference: design.md Section 4.3 - 检索不中兜底与置信度策略
- Retrieval insufficiency detection
- Confidence calculation based on retrieval scores
- shouldTransfer logic with threshold T_low
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from app.core.config import get_settings
from app.services.retrieval.base import RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class ConfidenceConfig:
    """
    Configuration for confidence calculation.
    [AC-AISVC-17, AC-AISVC-18] Configurable thresholds.
    """
    score_threshold: float = 0.7  # min max-retrieval-score considered sufficient
    min_hits: int = 1  # min number of retrieval hits considered sufficient
    confidence_low_threshold: float = 0.5  # T_low: below this, suggest transfer
    confidence_high_threshold: float = 0.8  # T_high: warn zone upper bound
    insufficient_penalty: float = 0.3  # subtracted when retrieval is insufficient
    max_evidence_tokens: int = 2000  # evidence token budget before flagged insufficient
@dataclass
class ConfidenceResult:
    """
    Result of confidence calculation.
    [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Contains confidence and transfer suggestion.
    """
    confidence: float  # final score clamped to [0, 1], rounded to 3 decimals
    should_transfer: bool  # True when confidence < T_low
    transfer_reason: str | None = None  # human-readable reason (Chinese), if any
    is_retrieval_insufficient: bool = False  # retrieval hit/score checks failed
    diagnostics: dict[str, Any] = field(default_factory=dict)  # scoring breakdown for logs/debug
class ConfidenceCalculator:
    """
    [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculator for response confidence.

    Design reference: design.md Section 4.3
    - MVP: confidence based on RAG retrieval scores
    - Insufficient retrieval triggers confidence downgrade
    - shouldTransfer when confidence < T_low
    """

    def __init__(self, config: ConfidenceConfig | None = None):
        """
        Args:
            config: Explicit configuration; when omitted, thresholds are read
                from application settings with the ConfidenceConfig defaults
                as fallbacks.
        """
        settings = get_settings()
        self._config = config or ConfidenceConfig(
            score_threshold=getattr(settings, "rag_score_threshold", 0.7),
            min_hits=getattr(settings, "rag_min_hits", 1),
            confidence_low_threshold=getattr(settings, "confidence_low_threshold", 0.5),
            confidence_high_threshold=getattr(settings, "confidence_high_threshold", 0.8),
            insufficient_penalty=getattr(settings, "confidence_insufficient_penalty", 0.3),
            max_evidence_tokens=getattr(settings, "rag_max_evidence_tokens", 2000),
        )

    def is_retrieval_insufficient(
        self,
        retrieval_result: RetrievalResult,
        evidence_tokens: int | None = None,
    ) -> tuple[bool, str]:
        """
        [AC-AISVC-17] Determine if retrieval results are insufficient.

        Conditions for insufficiency (any single one triggers it):
        1. hits.size < min_hits
        2. max(score) < score_threshold
        3. evidence tokens exceed limit (only checked when provided)

        Args:
            retrieval_result: Result from retrieval operation
            evidence_tokens: Optional token count for evidence

        Returns:
            Tuple of (is_insufficient, reason); reason is "sufficient" or a
            "; "-joined list of every failed check.
        """
        reasons = []
        if retrieval_result.hit_count < self._config.min_hits:
            reasons.append(
                f"hit_count({retrieval_result.hit_count}) < min_hits({self._config.min_hits})"
            )
        if retrieval_result.max_score < self._config.score_threshold:
            reasons.append(
                f"max_score({retrieval_result.max_score:.3f}) < threshold({self._config.score_threshold})"
            )
        if evidence_tokens is not None and evidence_tokens > self._config.max_evidence_tokens:
            reasons.append(
                f"evidence_tokens({evidence_tokens}) > max({self._config.max_evidence_tokens})"
            )
        is_insufficient = len(reasons) > 0
        reason = "; ".join(reasons) if reasons else "sufficient"
        return is_insufficient, reason

    def calculate_confidence(
        self,
        retrieval_result: RetrievalResult,
        evidence_tokens: int | None = None,
        additional_factors: dict[str, float] | None = None,
    ) -> ConfidenceResult:
        """
        [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculate confidence and transfer suggestion.

        MVP Strategy:
        1. Base confidence from max retrieval score (weight 0.7)
        2. Adjust for hit count, saturating at 5 hits (weight 0.3)
        3. Penalize if retrieval is insufficient
        4. Determine shouldTransfer based on T_low threshold

        Args:
            retrieval_result: Result from retrieval operation
            evidence_tokens: Optional token count for evidence
            additional_factors: Optional additional confidence factors; each
                value is added with a fixed 0.1 weight (keys are informational)

        Returns:
            ConfidenceResult with confidence clamped to [0, 1] and rounded to
            3 decimals, transfer suggestion, and scoring diagnostics
        """
        is_insufficient, insufficiency_reason = self.is_retrieval_insufficient(
            retrieval_result, evidence_tokens
        )
        base_confidence = retrieval_result.max_score
        # Saturates at 1.0 once five or more hits were retrieved.
        hit_count_factor = min(1.0, retrieval_result.hit_count / 5.0)
        confidence = base_confidence * 0.7 + hit_count_factor * 0.3
        if is_insufficient:
            confidence -= self._config.insufficient_penalty
            logger.info(
                f"[AC-AISVC-17] Retrieval insufficient: {insufficiency_reason}, "
                f"applying penalty -{self._config.insufficient_penalty}"
            )
        if additional_factors:
            # [idiom] Only the values participate in scoring; the previous
            # loop also bound an unused `factor_name` variable.
            for factor_value in additional_factors.values():
                confidence += factor_value * 0.1
        confidence = max(0.0, min(1.0, confidence))
        should_transfer = confidence < self._config.confidence_low_threshold
        transfer_reason = None
        if should_transfer:
            if is_insufficient:
                transfer_reason = "检索结果不足,无法提供高置信度回答"
            else:
                transfer_reason = "置信度低于阈值,建议转人工"
        elif confidence < self._config.confidence_high_threshold and is_insufficient:
            # Mid-confidence but thin evidence: warn without forcing transfer.
            transfer_reason = "检索结果有限,回答可能不够准确"
        diagnostics = {
            "base_confidence": base_confidence,
            "hit_count": retrieval_result.hit_count,
            "max_score": retrieval_result.max_score,
            "is_insufficient": is_insufficient,
            "insufficiency_reason": insufficiency_reason if is_insufficient else None,
            "penalty_applied": self._config.insufficient_penalty if is_insufficient else 0.0,
            "threshold_low": self._config.confidence_low_threshold,
            "threshold_high": self._config.confidence_high_threshold,
        }
        logger.info(
            f"[AC-AISVC-17, AC-AISVC-18] Confidence calculated: "
            f"{confidence:.3f}, should_transfer={should_transfer}, "
            f"insufficient={is_insufficient}"
        )
        return ConfidenceResult(
            confidence=round(confidence, 3),
            should_transfer=should_transfer,
            transfer_reason=transfer_reason,
            is_retrieval_insufficient=is_insufficient,
            diagnostics=diagnostics,
        )

    def calculate_confidence_no_retrieval(self) -> ConfidenceResult:
        """
        [AC-AISVC-17] Calculate confidence when no retrieval was performed.
        Returns a fixed low-confidence (0.3) result suggesting transfer.
        """
        return ConfidenceResult(
            confidence=0.3,
            should_transfer=True,
            transfer_reason="未进行知识库检索,建议转人工",
            is_retrieval_insufficient=True,
            diagnostics={
                "base_confidence": 0.0,
                "hit_count": 0,
                "max_score": 0.0,
                "is_insufficient": True,
                "insufficiency_reason": "no_retrieval",
                "penalty_applied": 0.0,
                "threshold_low": self._config.confidence_low_threshold,
                "threshold_high": self._config.confidence_high_threshold,
            },
        )
# Process-wide, lazily created singleton instance.
_confidence_calculator: ConfidenceCalculator | None = None


def get_confidence_calculator() -> ConfidenceCalculator:
    """Get or create the shared ConfidenceCalculator instance.

    Lazy initialization is unsynchronized; a race could construct the
    calculator twice, which is benign since it holds only configuration.
    """
    global _confidence_calculator
    if _confidence_calculator is None:
        _confidence_calculator = ConfidenceCalculator()
    return _confidence_calculator

View File

@ -0,0 +1,245 @@
"""
Context management utilities for AI Service.
[AC-AISVC-14, AC-AISVC-15] Context merging and truncation strategies.
Design reference: design.md Section 7 - 上下文合并规则
- H_local: Memory layer history (sorted by time)
- H_ext: External history from Java request (in passed order)
- Deduplication: fingerprint = hash(role + "|" + normalized(content))
- Truncation: Keep most recent N messages within token budget
"""
import hashlib
import logging
from dataclasses import dataclass, field
from typing import Any
import tiktoken
from app.core.config import get_settings
from app.models import ChatMessage, Role
logger = logging.getLogger(__name__)
@dataclass
class MergedContext:
    """
    Result of context merging.
    [AC-AISVC-14, AC-AISVC-15] Contains merged messages and diagnostics.
    """
    messages: list[dict[str, str]] = field(default_factory=list)  # merged {"role", "content"} dicts
    total_tokens: int = 0  # tiktoken count of messages (incl. per-message overhead)
    local_count: int = 0  # messages taken from the Memory layer (H_local)
    external_count: int = 0  # messages accepted from the request (H_ext)
    duplicates_skipped: int = 0  # H_ext messages dropped as fingerprint duplicates
    truncated_count: int = 0  # messages removed to fit the token budget
    diagnostics: list[dict[str, Any]] = field(default_factory=list)  # per-event merge diagnostics
class ContextMerger:
    """
    [AC-AISVC-14, AC-AISVC-15] Context merger for combining local and external history.

    Design reference: design.md Section 7
    - Deduplication based on message fingerprint
    - Priority: local history takes precedence
    - Token-based truncation using tiktoken
    """

    def __init__(
        self,
        max_history_tokens: int | None = None,
        encoding_name: str = "cl100k_base",
    ):
        """
        Args:
            max_history_tokens: Token budget for history; falls back to the
                `max_history_tokens` setting, then 4096.
            encoding_name: tiktoken encoding name used for token counting.
        """
        settings = get_settings()
        # [fix] `settings` was previously fetched but never read, so the
        # documented "configurable" budget could only come from the argument.
        # Honor a `max_history_tokens` setting before the 4096 fallback.
        self._max_history_tokens = max_history_tokens or getattr(
            settings, "max_history_tokens", 4096
        )
        self._encoding = tiktoken.get_encoding(encoding_name)

    def compute_fingerprint(self, role: str, content: str) -> str:
        """
        Compute message fingerprint for deduplication.
        [AC-AISVC-15] fingerprint = hash(role + "|" + normalized(content))

        Normalization is whitespace stripping only; case is preserved.

        Args:
            role: Message role (user/assistant)
            content: Message content

        Returns:
            SHA256 hex digest of the normalized message
        """
        normalized_content = content.strip()
        fingerprint_input = f"{role}|{normalized_content}"
        return hashlib.sha256(fingerprint_input.encode("utf-8")).hexdigest()

    def _message_to_dict(self, message: ChatMessage | dict[str, str]) -> dict[str, str]:
        """Convert ChatMessage or dict to the standard {"role", "content"} dict."""
        if isinstance(message, ChatMessage):
            return {"role": message.role.value, "content": message.content}
        return message

    def _count_tokens(self, messages: list[dict[str, str]]) -> int:
        """
        Count total tokens in messages using tiktoken.
        [AC-AISVC-14] Token counting for history truncation.
        Each message adds a flat +4 approximating chat framing overhead.
        """
        total = 0
        for msg in messages:
            total += len(self._encoding.encode(msg.get("role", "")))
            total += len(self._encoding.encode(msg.get("content", "")))
            total += 4  # Approximate overhead for message structure
        return total

    def merge_context(
        self,
        local_history: list[ChatMessage] | list[dict[str, str]] | None,
        external_history: list[ChatMessage] | list[dict[str, str]] | None,
    ) -> MergedContext:
        """
        Merge local and external history with deduplication.
        [AC-AISVC-14, AC-AISVC-15] Implements context merging strategy.

        Design reference: design.md Section 7.2
        1. Build seen set from H_local
        2. Traverse H_ext, append if fingerprint not seen
        3. Local history takes priority

        Args:
            local_history: History from Memory layer (H_local)
            external_history: History from Java request (H_ext)

        Returns:
            MergedContext with merged messages and diagnostics
        """
        result = MergedContext()
        seen_fingerprints: set[str] = set()
        merged_messages: list[dict[str, str]] = []
        diagnostics: list[dict[str, Any]] = []
        local_messages = [self._message_to_dict(m) for m in (local_history or [])]
        external_messages = [self._message_to_dict(m) for m in (external_history or [])]
        # All local messages are kept unconditionally and seed the seen set.
        for msg in local_messages:
            fingerprint = self.compute_fingerprint(msg["role"], msg["content"])
            seen_fingerprints.add(fingerprint)
            merged_messages.append(msg)
            result.local_count += 1
        # External messages are appended only when not already seen locally.
        for msg in external_messages:
            fingerprint = self.compute_fingerprint(msg["role"], msg["content"])
            if fingerprint not in seen_fingerprints:
                seen_fingerprints.add(fingerprint)
                merged_messages.append(msg)
                result.external_count += 1
            else:
                result.duplicates_skipped += 1
                diagnostics.append({
                    "type": "duplicate_skipped",
                    "role": msg["role"],
                    "content_preview": msg["content"][:50] + "..." if len(msg["content"]) > 50 else msg["content"],
                })
        result.messages = merged_messages
        result.diagnostics = diagnostics
        result.total_tokens = self._count_tokens(merged_messages)
        logger.info(
            f"[AC-AISVC-14, AC-AISVC-15] Context merged: "
            f"local={result.local_count}, external={result.external_count}, "
            f"duplicates_skipped={result.duplicates_skipped}, "
            f"total_tokens={result.total_tokens}"
        )
        return result

    def truncate_context(
        self,
        messages: list[dict[str, str]],
        max_tokens: int | None = None,
    ) -> tuple[list[dict[str, str]], int]:
        """
        Truncate context to fit within token budget.
        [AC-AISVC-14] Keep most recent N messages within budget.

        Design reference: design.md Section 7.4
        - Budget = maxHistoryTokens (configurable)
        - Strategy: Keep most recent messages (from tail backward)

        Args:
            messages: List of messages to truncate
            max_tokens: Maximum token budget (uses default if not provided)

        Returns:
            Tuple of (truncated messages, truncated count)
        """
        budget = max_tokens or self._max_history_tokens
        if not messages:
            return [], 0
        total_tokens = self._count_tokens(messages)
        if total_tokens <= budget:
            return messages, 0
        truncated_messages: list[dict[str, str]] = []
        current_tokens = 0
        truncated_count = 0
        # NOTE(review): the scan continues past a message that does not fit,
        # so an older, smaller message may be kept while a newer, larger one
        # is dropped — confirm this gap-tolerant behavior is intended.
        for msg in reversed(messages):
            msg_tokens = len(self._encoding.encode(msg.get("role", "")))
            msg_tokens += len(self._encoding.encode(msg.get("content", "")))
            msg_tokens += 4
            if current_tokens + msg_tokens <= budget:
                truncated_messages.insert(0, msg)
                current_tokens += msg_tokens
            else:
                truncated_count += 1
        logger.info(
            f"[AC-AISVC-14] Context truncated: "
            f"original={len(messages)}, truncated={len(truncated_messages)}, "
            f"removed={truncated_count}, tokens={current_tokens}/{budget}"
        )
        return truncated_messages, truncated_count

    def merge_and_truncate(
        self,
        local_history: list[ChatMessage] | list[dict[str, str]] | None,
        external_history: list[ChatMessage] | list[dict[str, str]] | None,
        max_tokens: int | None = None,
    ) -> MergedContext:
        """
        Merge and truncate context in one operation.
        [AC-AISVC-14, AC-AISVC-15] Complete context preparation pipeline.

        Args:
            local_history: History from Memory layer (H_local)
            external_history: History from Java request (H_ext)
            max_tokens: Maximum token budget

        Returns:
            MergedContext with final messages after merge and truncate
        """
        merged = self.merge_context(local_history, external_history)
        truncated_messages, truncated_count = self.truncate_context(
            merged.messages, max_tokens
        )
        merged.messages = truncated_messages
        merged.truncated_count = truncated_count
        merged.total_tokens = self._count_tokens(truncated_messages)
        return merged
# Process-wide, lazily created singleton instance.
_context_merger: ContextMerger | None = None


def get_context_merger() -> ContextMerger:
    """Get or create the shared ContextMerger instance.

    Lazy initialization is unsynchronized; a race could construct the merger
    twice, which only wastes one tiktoken encoding load.
    """
    global _context_merger
    if _context_merger is None:
        _context_merger = ContextMerger()
    return _context_merger

View File

@ -0,0 +1,38 @@
"""
Document parsing services package.
[AC-AISVC-33] Provides document parsers for various formats.
"""
from app.services.document.base import (
DocumentParseException,
DocumentParser,
PageText,
ParseResult,
UnsupportedFormatError,
)
from app.services.document.excel_parser import CSVParser, ExcelParser
from app.services.document.factory import (
DocumentParserFactory,
get_supported_document_formats,
parse_document,
)
from app.services.document.pdf_parser import PDFParser, PDFPlumberParser
from app.services.document.text_parser import TextParser
from app.services.document.word_parser import WordParser
__all__ = [
"DocumentParseException",
"DocumentParser",
"PageText",
"ParseResult",
"UnsupportedFormatError",
"DocumentParserFactory",
"get_supported_document_formats",
"parse_document",
"PDFParser",
"PDFPlumberParser",
"WordParser",
"ExcelParser",
"CSVParser",
"TextParser",
]

View File

@ -0,0 +1,116 @@
"""
Base document parser interface.
[AC-AISVC-33] Abstract interface for document parsers.
Design reference: progress.md Section 7.2 - DocumentParser interface
- parse(file_path) -> str
- get_supported_extensions() -> list[str]
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass
class PageText:
    """
    Text content from a single page of a paginated document.
    """
    page: int  # 1-based page number
    text: str  # extracted text for that page
@dataclass
class ParseResult:
    """
    Result from document parsing.
    [AC-AISVC-33] Contains parsed text and metadata.
    """
    text: str  # full extracted text (or JSON string for tabular formats)
    source_path: str  # path of the parsed file
    file_size: int  # size in bytes of the source file
    page_count: int | None = None  # number of pages, when the format has pages
    metadata: dict[str, Any] = field(default_factory=dict)  # parser-specific details
    pages: list[PageText] = field(default_factory=list)  # per-page texts, when available
class DocumentParser(ABC):
    """
    Abstract base class for document parsers.
    [AC-AISVC-33] Provides unified interface for different document formats.
    """

    @abstractmethod
    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Extract text content from the document at ``file_path``.

        [AC-AISVC-33] Returns parsed text content.

        Args:
            file_path: Path to the document file.

        Returns:
            ParseResult with extracted text and metadata.

        Raises:
            DocumentParseException: If parsing fails.
        """
        ...

    @abstractmethod
    def get_supported_extensions(self) -> list[str]:
        """
        List the file extensions this parser can handle.

        [AC-AISVC-37] Returns supported format list.

        Returns:
            List of file extensions (e.g., [".pdf", ".txt"])
        """
        ...

    def supports_extension(self, extension: str) -> bool:
        """
        Report whether ``extension`` is handled by this parser.

        [AC-AISVC-37] Validates file format support. The check is
        case-insensitive and tolerates a missing leading dot.

        Args:
            extension: File extension to check.

        Returns:
            True if extension is supported.
        """
        ext = extension.lower()
        if not ext.startswith("."):
            ext = "." + ext
        return ext in self.get_supported_extensions()
class DocumentParseException(Exception):
    """Raised when a document cannot be parsed.

    Carries the offending file path, the parser name, and free-form details
    for diagnostics; the message is prefixed with "[<parser>]" when a parser
    name is given.
    """

    def __init__(
        self,
        message: str,
        file_path: str = "",
        parser: str = "",
        details: dict[str, Any] | None = None
    ):
        self.file_path = file_path
        self.parser = parser
        self.details = details or {}
        prefixed_message = f"[{parser}] {message}" if parser else message
        super().__init__(prefixed_message)
class UnsupportedFormatError(DocumentParseException):
    """Raised when a file's extension has no registered parser."""

    def __init__(self, extension: str, supported: list[str]):
        supported_display = ", ".join(supported)
        message = (
            f"Unsupported file format: {extension}. "
            f"Supported formats: {supported_display}"
        )
        super().__init__(message, parser="format_checker")
        self.extension = extension
        self.supported_formats = supported

View File

@ -0,0 +1,273 @@
"""
Excel document parser implementation.
[AC-AISVC-35] Excel (.xlsx) parsing using openpyxl.
Extracts text content from Excel spreadsheets and converts to JSON format
to preserve structural relationships for better RAG retrieval.
"""
import json
import logging
from pathlib import Path
from typing import Any
from app.services.document.base import (
DocumentParseException,
DocumentParser,
ParseResult,
)
logger = logging.getLogger(__name__)
class ExcelParser(DocumentParser):
    """
    Parser for Excel documents.

    [AC-AISVC-35] Uses openpyxl for text extraction and serializes the
    worksheet rows to JSON so structural relationships survive chunking
    for RAG retrieval.

    NOTE(review): ".xls" is advertised as supported, but openpyxl reads only
    the modern .xlsx format — confirm legacy .xls handling.
    """

    def __init__(
        self,
        include_empty_cells: bool = False,
        max_rows_per_sheet: int = 10000,
        **kwargs: Any
    ):
        self._include_empty_cells = include_empty_cells
        self._max_rows_per_sheet = max_rows_per_sheet
        self._extra_config = kwargs
        self._openpyxl = None  # cached module after first lazy import

    def _get_openpyxl(self):
        """Import openpyxl on first use so the dependency stays optional."""
        if self._openpyxl is not None:
            return self._openpyxl
        try:
            import openpyxl
        except ImportError:
            raise DocumentParseException(
                "openpyxl not installed. Install with: pip install openpyxl",
                parser="excel"
            )
        self._openpyxl = openpyxl
        return self._openpyxl

    def _sheet_to_records(self, sheet, sheet_name: str) -> list[dict[str, Any]]:
        """
        Convert a worksheet into row dictionaries keyed by the header row.

        The first row supplies the column names; empty header cells fall back
        to positional "column_<i>" keys. Rows beyond max_rows_per_sheet are
        silently ignored.
        """
        all_rows = list(sheet.iter_rows(max_row=self._max_rows_per_sheet, values_only=True))
        if not all_rows:
            return []
        header_row, *data_rows = all_rows
        column_names = [
            str(cell) if cell is not None else f"column_{idx}"
            for idx, cell in enumerate(header_row)
        ]
        records: list[dict[str, Any]] = []
        for data_row in data_rows:
            record: dict[str, Any] = {"_sheet": sheet_name}
            row_has_content = False
            for idx, cell in enumerate(data_row):
                key = column_names[idx] if idx < len(column_names) else f"column_{idx}"
                if cell is not None:
                    row_has_content = True
                    # Numeric/boolean cells keep their native type for JSON.
                    record[key] = cell if isinstance(cell, (int, float, bool)) else str(cell)
                elif self._include_empty_cells:
                    record[key] = None
            if row_has_content or self._include_empty_cells:
                records.append(record)
        return records

    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse an Excel document and extract text content as JSON.
        [AC-AISVC-35] Converts spreadsheet data to JSON format.

        Raises:
            DocumentParseException: If the file is missing, the extension is
                unsupported, openpyxl is absent, or the workbook cannot be read.
        """
        path = Path(file_path)
        if not path.exists():
            raise DocumentParseException(
                f"File not found: {path}",
                file_path=str(path),
                parser="excel"
            )
        if not self.supports_extension(path.suffix):
            raise DocumentParseException(
                f"Unsupported file extension: {path.suffix}",
                file_path=str(path),
                parser="excel"
            )
        openpyxl = self._get_openpyxl()
        try:
            workbook = openpyxl.load_workbook(path, read_only=True, data_only=True)
            sheet_count = len(workbook.sheetnames)
            all_records: list[dict[str, Any]] = []
            for sheet_name in workbook.sheetnames:
                all_records.extend(self._sheet_to_records(workbook[sheet_name], sheet_name))
            workbook.close()
            total_rows = len(all_records)
            json_str = json.dumps(all_records, ensure_ascii=False, indent=2)
            file_size = path.stat().st_size
            logger.info(
                f"Parsed Excel (JSON): {path.name}, sheets={sheet_count}, "
                f"rows={total_rows}, chars={len(json_str)}, size={file_size}"
            )
            return ParseResult(
                text=json_str,
                source_path=str(path),
                file_size=file_size,
                metadata={
                    "format": "xlsx",
                    "output_format": "json",
                    "sheet_count": sheet_count,
                    "total_rows": total_rows,
                }
            )
        except DocumentParseException:
            raise
        except Exception as e:
            raise DocumentParseException(
                f"Failed to parse Excel document: {e}",
                file_path=str(path),
                parser="excel",
                details={"error": str(e)}
            )

    def get_supported_extensions(self) -> list[str]:
        """Supported file extensions for this parser."""
        return [".xlsx", ".xls"]
class CSVParser(DocumentParser):
    """
    Parser for CSV files.

    [AC-AISVC-35] Uses Python's built-in csv module and serializes rows to a
    JSON array of header-keyed records so structure survives chunking. Falls
    back to GBK decoding when the configured encoding fails.
    """

    def __init__(self, delimiter: str = ",", encoding: str = "utf-8", **kwargs: Any):
        self._delimiter = delimiter
        self._encoding = encoding
        self._extra_config = kwargs

    def _parse_csv_to_records(self, path: Path, encoding: str) -> list[dict[str, Any]]:
        """Read the CSV at ``path`` and return one dict per row with content.

        The first row supplies column names; empty header cells fall back to
        positional "column_<i>" keys, and falsy cell values are omitted.
        """
        import csv
        with open(path, "r", encoding=encoding, newline="") as handle:
            all_rows = list(csv.reader(handle, delimiter=self._delimiter))
        if not all_rows:
            return []
        header_row, *data_rows = all_rows
        column_names = [
            str(cell) if cell else f"column_{idx}"
            for idx, cell in enumerate(header_row)
        ]
        records: list[dict[str, Any]] = []
        for data_row in data_rows:
            record: dict[str, Any] = {}
            for idx, cell in enumerate(data_row):
                key = column_names[idx] if idx < len(column_names) else f"column_{idx}"
                if cell:
                    record[key] = cell
            # A record only has keys when at least one cell was non-empty.
            if record:
                records.append(record)
        return records

    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse a CSV file and extract text content as JSON.
        [AC-AISVC-35] Converts CSV data to JSON format.

        Raises:
            DocumentParseException: If the file is missing or both the
                configured encoding and the GBK fallback fail.
        """
        path = Path(file_path)
        if not path.exists():
            raise DocumentParseException(
                f"File not found: {path}",
                file_path=str(path),
                parser="csv"
            )
        try:
            records = self._parse_csv_to_records(path, self._encoding)
            row_count = len(records)
            used_encoding = self._encoding
        except UnicodeDecodeError:
            # Retry once with GBK, common for files exported on zh-CN systems.
            try:
                records = self._parse_csv_to_records(path, "gbk")
                row_count = len(records)
                used_encoding = "gbk"
            except Exception as exc:
                raise DocumentParseException(
                    f"Failed to parse CSV with encoding fallback: {exc}",
                    file_path=str(path),
                    parser="csv",
                    details={"error": str(exc)}
                )
        except Exception as exc:
            raise DocumentParseException(
                f"Failed to parse CSV: {exc}",
                file_path=str(path),
                parser="csv",
                details={"error": str(exc)}
            )
        json_str = json.dumps(records, ensure_ascii=False, indent=2)
        file_size = path.stat().st_size
        logger.info(
            f"Parsed CSV (JSON): {path.name}, rows={row_count}, "
            f"chars={len(json_str)}, size={file_size}"
        )
        return ParseResult(
            text=json_str,
            source_path=str(path),
            file_size=file_size,
            metadata={
                "format": "csv",
                "output_format": "json",
                "row_count": row_count,
                "delimiter": self._delimiter,
                "encoding": used_encoding,
            }
        )

    def get_supported_extensions(self) -> list[str]:
        """Supported file extensions for this parser."""
        return [".csv"]

View File

@ -0,0 +1,215 @@
"""
Document parser factory.
[AC-AISVC-33, AC-AISVC-34, AC-AISVC-35] Factory for document parsers.
Design reference: progress.md Section 7.2 - DocumentParserFactory
"""
import logging
from pathlib import Path
from typing import Any, Type
from app.services.document.base import (
DocumentParser,
DocumentParseException,
ParseResult,
UnsupportedFormatError,
)
from app.services.document.excel_parser import CSVParser, ExcelParser
from app.services.document.pdf_parser import PDFParser, PDFPlumberParser
from app.services.document.text_parser import TextParser
from app.services.document.word_parser import WordParser
logger = logging.getLogger(__name__)
class DocumentParserFactory:
    """
    Factory for creating document parsers.
    [AC-AISVC-33, AC-AISVC-34, AC-AISVC-35] Auto-selects parser based on file extension.
    """

    # Registry: parser name -> parser class (populated lazily by _initialize).
    _parsers: dict[str, Type[DocumentParser]] = {}
    # Map: normalized extension (".pdf") -> parser name ("pdf").
    _extension_map: dict[str, str] = {}

    # [perf] Constant lookup tables for get_parser_info; previously these two
    # dicts were rebuilt on every loop iteration.
    _DISPLAY_NAMES = {
        "pdf": "PDF 文档",
        "pdfplumber": "PDF 文档 (pdfplumber)",
        "word": "Word 文档",
        "excel": "Excel 电子表格",
        "csv": "CSV 文件",
        "text": "文本文件",
    }
    _DESCRIPTIONS = {
        "pdf": "使用 PyMuPDF 解析 PDF 文档,速度快",
        "pdfplumber": "使用 pdfplumber 解析 PDF 文档,表格提取效果更好",
        "word": "解析 Word 文档 (.docx),保留段落结构",
        "excel": "解析 Excel 电子表格,支持多工作表",
        "csv": "解析 CSV 文件,自动检测编码",
        "text": "解析纯文本文件,支持多种编码",
    }

    @classmethod
    def _initialize(cls) -> None:
        """Initialize default parsers (idempotent; no-op once populated)."""
        if cls._parsers:
            return
        cls._parsers = {
            "pdf": PDFParser,
            "pdfplumber": PDFPlumberParser,
            "word": WordParser,
            "excel": ExcelParser,
            "csv": CSVParser,
            "text": TextParser,
        }
        cls._extension_map = {
            ".pdf": "pdf",
            ".docx": "word",
            ".xlsx": "excel",
            ".xls": "excel",
            ".csv": "csv",
            ".txt": "text",
            ".md": "text",
            ".markdown": "text",
            ".rst": "text",
            ".log": "text",
            ".json": "text",
            ".xml": "text",
            ".yaml": "text",
            ".yml": "text",
        }

    @classmethod
    def register_parser(
        cls,
        name: str,
        parser_class: Type[DocumentParser],
        extensions: list[str],
    ) -> None:
        """
        Register a new document parser.
        [AC-AISVC-33] Allows runtime registration of parsers; an existing
        name or extension mapping is overwritten.
        """
        cls._initialize()
        cls._parsers[name] = parser_class
        for ext in extensions:
            cls._extension_map[ext.lower()] = name
        logger.info(f"Registered document parser: {name} for extensions: {extensions}")

    @classmethod
    def get_supported_extensions(cls) -> list[str]:
        """
        Get all supported file extensions.
        [AC-AISVC-37] Returns list of supported formats.
        """
        cls._initialize()
        return list(cls._extension_map.keys())

    @classmethod
    def get_parser_for_extension(cls, extension: str) -> DocumentParser:
        """
        Get a fresh parser instance for a file extension.
        [AC-AISVC-33] Creates appropriate parser based on extension
        (case-insensitive, leading dot optional).

        Raises:
            UnsupportedFormatError: If no parser is registered for the extension.
        """
        cls._initialize()
        normalized = extension.lower()
        if not normalized.startswith("."):
            normalized = f".{normalized}"
        if normalized not in cls._extension_map:
            raise UnsupportedFormatError(normalized, cls.get_supported_extensions())
        parser_name = cls._extension_map[normalized]
        parser_class = cls._parsers[parser_name]
        return parser_class()

    @classmethod
    def parse_file(
        cls,
        file_path: str | Path,
        parser_name: str | None = None,
        parser_config: dict[str, Any] | None = None,
    ) -> ParseResult:
        """
        Parse a document file.
        [AC-AISVC-33, AC-AISVC-34, AC-AISVC-35] Main entry point for parsing.

        Args:
            file_path: Path to the document file
            parser_name: Optional specific parser to use (overrides extension lookup)
            parser_config: Optional configuration for the parser

        Returns:
            ParseResult with extracted text and metadata

        Raises:
            UnsupportedFormatError: If file format is not supported
            DocumentParseException: If parsing fails or parser_name is unknown
        """
        cls._initialize()
        path = Path(file_path)
        extension = path.suffix.lower()
        if parser_name:
            if parser_name not in cls._parsers:
                raise DocumentParseException(
                    f"Unknown parser: {parser_name}",
                    file_path=str(path),
                    parser="factory"
                )
            parser_class = cls._parsers[parser_name]
            parser = parser_class(**(parser_config or {}))
        else:
            parser = cls.get_parser_for_extension(extension)
            if parser_config:
                # Rebuild the same parser type with the caller's configuration.
                parser = type(parser)(**parser_config)
        return parser.parse(path)

    @classmethod
    def get_parser_info(cls) -> list[dict[str, Any]]:
        """
        Get information about available parsers.
        [AC-AISVC-37] Returns parser metadata (name, display name, description,
        supported extensions).
        """
        cls._initialize()
        info = []
        for name, parser_class in cls._parsers.items():
            # __new__ bypasses __init__ (and any lazy-import setup);
            # get_supported_extensions must not depend on instance state.
            temp_instance = parser_class.__new__(parser_class)
            extensions = temp_instance.get_supported_extensions()
            info.append({
                "name": name,
                "display_name": cls._DISPLAY_NAMES.get(name, name),
                "description": cls._DESCRIPTIONS.get(name, ""),
                "extensions": extensions,
            })
        return info
def parse_document(
    file_path: str | Path,
    parser_name: str | None = None,
    parser_config: dict[str, Any] | None = None,
) -> ParseResult:
    """
    Convenience function for parsing documents.
    [AC-AISVC-33] Simple entry point that delegates directly to
    DocumentParserFactory.parse_file with the same arguments and
    exception behavior.
    """
    return DocumentParserFactory.parse_file(file_path, parser_name, parser_config)
def get_supported_document_formats() -> list[str]:
    """
    Get list of supported document formats.
    [AC-AISVC-37] Returns supported format extensions (with leading dots),
    delegating to DocumentParserFactory.get_supported_extensions.
    """
    return DocumentParserFactory.get_supported_extensions()

View File

@ -0,0 +1,229 @@
"""
PDF document parser implementation.
[AC-AISVC-33] PDF parsing using PyMuPDF (fitz).
Extracts text content from PDF files.
"""
import logging
from pathlib import Path
from typing import Any
from app.services.document.base import (
DocumentParseException,
DocumentParser,
PageText,
ParseResult,
)
logger = logging.getLogger(__name__)
class PDFParser(DocumentParser):
    """
    Parser for PDF documents.
    [AC-AISVC-33] Uses PyMuPDF for text extraction.

    PyMuPDF is imported lazily, so the dependency is only required when a
    PDF is actually parsed.
    """

    def __init__(self, extract_images: bool = False, **kwargs: Any):
        # NOTE(review): extract_images is stored but never read in this class —
        # image extraction appears unimplemented; confirm before relying on it.
        self._extract_images = extract_images
        self._extra_config = kwargs  # unused extra options, kept for forward compatibility
        self._fitz = None  # cached PyMuPDF module after first import

    def _get_fitz(self):
        """Lazy import of PyMuPDF; raises DocumentParseException if missing."""
        if self._fitz is None:
            try:
                import fitz
                self._fitz = fitz
            except ImportError:
                raise DocumentParseException(
                    "PyMuPDF (fitz) not installed. Install with: pip install pymupdf",
                    parser="pdf"
                )
        return self._fitz

    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse a PDF document and extract text content.
        [AC-AISVC-33] Extracts text from all pages.

        Args:
            file_path: Path to the .pdf file.

        Returns:
            ParseResult whose text joins the non-empty pages, each prefixed
            with a "[Page N]" marker; per-page texts are kept in `pages`.

        Raises:
            DocumentParseException: If the file is missing, has an unsupported
                extension, PyMuPDF is not installed, or the PDF cannot be read.
        """
        path = Path(file_path)
        if not path.exists():
            raise DocumentParseException(
                f"File not found: {path}",
                file_path=str(path),
                parser="pdf"
            )
        if not self.supports_extension(path.suffix):
            raise DocumentParseException(
                f"Unsupported file extension: {path.suffix}",
                file_path=str(path),
                parser="pdf"
            )
        fitz = self._get_fitz()
        try:
            doc = fitz.open(path)
            pages: list[PageText] = []
            text_parts = []
            page_count = len(doc)
            for page_num in range(page_count):
                page = doc[page_num]
                text = page.get_text().strip()
                # Blank pages are dropped entirely (no page marker emitted).
                if text:
                    pages.append(PageText(page=page_num + 1, text=text))
                    text_parts.append(f"[Page {page_num + 1}]\n{text}")
            doc.close()
            full_text = "\n\n".join(text_parts)
            file_size = path.stat().st_size
            logger.info(
                f"Parsed PDF: {path.name}, pages={page_count}, "
                f"chars={len(full_text)}, size={file_size}"
            )
            return ParseResult(
                text=full_text,
                source_path=str(path),
                file_size=file_size,
                page_count=page_count,
                metadata={
                    "format": "pdf",
                    "page_count": page_count,
                },
                pages=pages,
            )
        except DocumentParseException:
            # Re-raise our own errors untouched so callers keep the context.
            raise
        except Exception as e:
            raise DocumentParseException(
                f"Failed to parse PDF: {e}",
                file_path=str(path),
                parser="pdf",
                details={"error": str(e)}
            )

    def get_supported_extensions(self) -> list[str]:
        """Get supported file extensions."""
        return [".pdf"]
class PDFPlumberParser(DocumentParser):
    """
    Alternative PDF parser using pdfplumber.
    [AC-AISVC-33] Uses pdfplumber for text extraction.
    pdfplumber is better for table extraction but slower than PyMuPDF.
    """

    def __init__(self, extract_tables: bool = True, **kwargs: Any):
        # When True, extracted tables are appended to each page's text.
        self._extract_tables = extract_tables
        self._extra_config = kwargs
        self._pdfplumber = None  # lazily imported pdfplumber module

    def _get_pdfplumber(self):
        """Lazy import of pdfplumber so the dependency stays optional."""
        if self._pdfplumber is None:
            try:
                import pdfplumber
                self._pdfplumber = pdfplumber
            except ImportError as e:
                raise DocumentParseException(
                    "pdfplumber not installed. Install with: pip install pdfplumber",
                    parser="pdfplumber"
                ) from e
        return self._pdfplumber

    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse a PDF document and extract text content.
        [AC-AISVC-33] Extracts text and optionally tables.

        Args:
            file_path: Path to the .pdf file.

        Returns:
            ParseResult with per-page text (tables rendered inline as
            pipe-separated rows) and document metadata.

        Raises:
            DocumentParseException: If the file is missing, has an
                unsupported extension, or pdfplumber fails to parse it.
        """
        path = Path(file_path)
        if not path.exists():
            raise DocumentParseException(
                f"File not found: {path}",
                file_path=str(path),
                parser="pdfplumber"
            )
        # Consistency with PDFParser: reject non-.pdf files up front
        # instead of failing deep inside pdfplumber.
        if not self.supports_extension(path.suffix):
            raise DocumentParseException(
                f"Unsupported file extension: {path.suffix}",
                file_path=str(path),
                parser="pdfplumber"
            )
        pdfplumber = self._get_pdfplumber()
        try:
            pages: list[PageText] = []
            text_parts = []
            page_count = 0
            with pdfplumber.open(path) as pdf:
                page_count = len(pdf.pages)
                for page_num, page in enumerate(pdf.pages):
                    text = page.extract_text() or ""
                    if self._extract_tables:
                        tables = page.extract_tables()
                        for table in tables:
                            table_text = self._format_table(table)
                            text += f"\n\n{table_text}"
                    text = text.strip()
                    if text:
                        pages.append(PageText(page=page_num + 1, text=text))
                        text_parts.append(f"[Page {page_num + 1}]\n{text}")
            full_text = "\n\n".join(text_parts)
            file_size = path.stat().st_size
            logger.info(
                f"Parsed PDF (pdfplumber): {path.name}, pages={page_count}, "
                f"chars={len(full_text)}, size={file_size}"
            )
            return ParseResult(
                text=full_text,
                source_path=str(path),
                file_size=file_size,
                page_count=page_count,
                metadata={
                    "format": "pdf",
                    "parser": "pdfplumber",
                    "page_count": page_count,
                },
                pages=pages,
            )
        except DocumentParseException:
            raise
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise DocumentParseException(
                f"Failed to parse PDF: {e}",
                file_path=str(path),
                parser="pdfplumber",
                details={"error": str(e)}
            ) from e

    def _format_table(self, table: list[list[str | None]]) -> str:
        """Format an extracted table as pipe-separated text rows."""
        if not table:
            return ""
        lines = []
        for row in table:
            # 'is not None' (rather than truthiness) keeps any falsy but
            # valid cell content instead of blanking it.
            cells = [str(cell) if cell is not None else "" for cell in row]
            lines.append(" | ".join(cells))
        return "\n".join(lines)

    def get_supported_extensions(self) -> list[str]:
        """Get supported file extensions."""
        return [".pdf"]

View File

@ -0,0 +1,99 @@
"""
Text file parser implementation.
[AC-AISVC-33] Text file parsing for plain text and markdown.
"""
import logging
from pathlib import Path
from typing import Any
from app.services.document.base import (
DocumentParseException,
DocumentParser,
ParseResult,
)
logger = logging.getLogger(__name__)
ENCODINGS_TO_TRY = ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]
class TextParser(DocumentParser):
    """
    Parser for plain text files.
    [AC-AISVC-33] Direct text extraction with multiple encoding support.
    """

    def __init__(self, encoding: str = "utf-8", **kwargs: Any):
        # Preferred encoding; tried first before the ENCODINGS_TO_TRY
        # fallbacks (previously it was stored but never consulted).
        self._encoding = encoding
        self._extra_config = kwargs

    def _try_encodings(self, path: Path) -> tuple[str, str]:
        """
        Try multiple encodings to read the file.

        The configured preferred encoding is attempted first, then the
        remaining ENCODINGS_TO_TRY candidates in order.

        Returns: (text, encoding_used)

        Raises:
            DocumentParseException: If no candidate encoding decodes the file.
        """
        candidates = [self._encoding] + [e for e in ENCODINGS_TO_TRY if e != self._encoding]
        for enc in candidates:
            try:
                with open(path, "r", encoding=enc) as f:
                    text = f.read()
                logger.info(f"Successfully parsed with encoding: {enc}")
                return text, enc
            except (UnicodeDecodeError, LookupError):
                continue
        raise DocumentParseException(
            "Failed to decode file with any known encoding",
            file_path=str(path),
            parser="text"
        )

    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse a text file and extract content.
        [AC-AISVC-33] Direct file reading.

        Raises:
            DocumentParseException: If the file is missing or cannot be
                decoded/read.
        """
        path = Path(file_path)
        if not path.exists():
            raise DocumentParseException(
                f"File not found: {path}",
                file_path=str(path),
                parser="text"
            )
        try:
            text, encoding_used = self._try_encodings(path)
            file_size = path.stat().st_size
            line_count = text.count("\n") + 1
            logger.info(
                f"Parsed text: {path.name}, lines={line_count}, "
                f"chars={len(text)}, size={file_size}, encoding={encoding_used}"
            )
            return ParseResult(
                text=text,
                source_path=str(path),
                file_size=file_size,
                metadata={
                    "format": "text",
                    "line_count": line_count,
                    "encoding": encoding_used,
                }
            )
        except DocumentParseException:
            raise
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise DocumentParseException(
                f"Failed to parse text file: {e}",
                file_path=str(path),
                parser="text",
                details={"error": str(e)}
            ) from e

    def get_supported_extensions(self) -> list[str]:
        """Get supported file extensions."""
        return [".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"]

View File

@ -0,0 +1,145 @@
"""
Word document parser implementation.
[AC-AISVC-34] Word (.docx) parsing using python-docx.
Extracts text content from Word documents.
"""
import logging
from pathlib import Path
from typing import Any
from app.services.document.base import (
DocumentParseException,
DocumentParser,
ParseResult,
)
logger = logging.getLogger(__name__)
class WordParser(DocumentParser):
    """
    Parser for Word documents.
    [AC-AISVC-34] Uses python-docx for text extraction.
    """

    def __init__(self, include_headers: bool = True, include_footers: bool = True, **kwargs: Any):
        # Toggles for including section headers/footers in the output text.
        self._include_headers = include_headers
        self._include_footers = include_footers
        self._extra_config = kwargs
        self._docx = None  # lazily imported Document factory

    def _get_docx(self):
        """Lazy import of python-docx so the dependency stays optional."""
        if self._docx is None:
            try:
                from docx import Document
                self._docx = Document
            except ImportError as e:
                raise DocumentParseException(
                    "python-docx not installed. Install with: pip install python-docx",
                    parser="word"
                ) from e
        return self._docx

    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse a Word document and extract text content.
        [AC-AISVC-34] Extracts text while preserving paragraph structure.

        Paragraphs styled as "Heading" are prefixed with '##'; tables are
        rendered as pipe-separated rows; section headers/footers are
        included when enabled.

        Raises:
            DocumentParseException: If the file is missing, has an
                unsupported extension, or python-docx fails to parse it.
        """
        path = Path(file_path)
        if not path.exists():
            raise DocumentParseException(
                f"File not found: {path}",
                file_path=str(path),
                parser="word"
            )
        if not self.supports_extension(path.suffix):
            raise DocumentParseException(
                f"Unsupported file extension: {path.suffix}",
                file_path=str(path),
                parser="word"
            )
        Document = self._get_docx()
        try:
            doc = Document(path)
            text_parts = []
            if self._include_headers:
                for section in doc.sections:
                    header = section.header
                    if header and header.paragraphs:
                        header_text = "\n".join(p.text for p in header.paragraphs if p.text.strip())
                        if header_text:
                            text_parts.append(f"[Header]\n{header_text}")
            for para in doc.paragraphs:
                if para.text.strip():
                    style_name = para.style.name if para.style else ""
                    if "Heading" in style_name:
                        text_parts.append(f"\n## {para.text}")
                    else:
                        text_parts.append(para.text)
            for table in doc.tables:
                table_text = self._format_table(table)
                if table_text.strip():
                    text_parts.append(f"\n[Table]\n{table_text}")
            if self._include_footers:
                for section in doc.sections:
                    footer = section.footer
                    if footer and footer.paragraphs:
                        footer_text = "\n".join(p.text for p in footer.paragraphs if p.text.strip())
                        if footer_text:
                            text_parts.append(f"[Footer]\n{footer_text}")
            full_text = "\n\n".join(text_parts)
            file_size = path.stat().st_size
            paragraph_count = len(doc.paragraphs)
            table_count = len(doc.tables)
            logger.info(
                f"Parsed Word: {path.name}, paragraphs={paragraph_count}, "
                f"tables={table_count}, chars={len(full_text)}, size={file_size}"
            )
            return ParseResult(
                text=full_text,
                source_path=str(path),
                file_size=file_size,
                metadata={
                    "format": "docx",
                    "paragraph_count": paragraph_count,
                    "table_count": table_count,
                }
            )
        except DocumentParseException:
            raise
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise DocumentParseException(
                f"Failed to parse Word document: {e}",
                file_path=str(path),
                parser="word",
                details={"error": str(e)}
            ) from e

    def _format_table(self, table) -> str:
        """Format a docx table as pipe-separated text rows."""
        lines = []
        for row in table.rows:
            cells = [cell.text.strip() for cell in row.cells]
            lines.append(" | ".join(cells))
        return "\n".join(lines)

    def get_supported_extensions(self) -> list[str]:
        """Get supported file extensions."""
        return [".docx"]

View File

@ -0,0 +1,40 @@
"""
Embedding services package.
[AC-AISVC-29] Provides pluggable embedding providers.
"""
from app.services.embedding.base import (
EmbeddingConfig,
EmbeddingException,
EmbeddingProvider,
EmbeddingResult,
)
from app.services.embedding.factory import (
EmbeddingConfigManager,
EmbeddingProviderFactory,
get_embedding_config_manager,
get_embedding_provider,
)
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
from app.services.embedding.nomic_provider import (
NomicEmbeddingProvider,
NomicEmbeddingResult,
EmbeddingTask,
)
__all__ = [
    # Base interface and data types
    "EmbeddingConfig",
    "EmbeddingException",
    "EmbeddingProvider",
    "EmbeddingResult",
    # Factory and configuration management
    "EmbeddingConfigManager",
    "EmbeddingProviderFactory",
    "get_embedding_config_manager",
    "get_embedding_provider",
    # Concrete providers
    "OllamaEmbeddingProvider",
    "OpenAIEmbeddingProvider",
    "NomicEmbeddingProvider",
    "NomicEmbeddingResult",
    "EmbeddingTask",
]

View File

@ -0,0 +1,130 @@
"""
Base embedding provider interface.
[AC-AISVC-29] Abstract interface for embedding providers.
Design reference: progress.md Section 7.1 - EmbeddingProvider interface
- embed(text) -> list[float]
- embed_batch(texts) -> list[list[float]]
- get_dimension() -> int
- get_provider_name() -> str
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
@dataclass
class EmbeddingConfig:
    """
    Configuration for embedding provider.
    [AC-AISVC-31] Supports configurable embedding parameters.

    NOTE(review): the concrete providers in this package currently take
    these values as constructor kwargs rather than consuming this
    dataclass directly — confirm intended usage.
    """
    # Expected output vector dimension.
    dimension: int = 768
    # Number of texts per batch request.
    batch_size: int = 32
    # Timeout in seconds for a single embedding call.
    timeout_seconds: int = 60
    # Provider-specific options, passed through untouched.
    extra_params: dict[str, Any] = field(default_factory=dict)
@dataclass
class EmbeddingResult:
    """
    Result from embedding generation.
    [AC-AISVC-29] Contains embedding vector and metadata.
    """
    # The embedding vector itself.
    embedding: list[float]
    # Length of `embedding`.
    dimension: int
    # Name of the model that produced the vector.
    model: str
    # Wall-clock generation time in milliseconds (0.0 if not measured).
    latency_ms: float = 0.0
    # Free-form provider metadata.
    metadata: dict[str, Any] = field(default_factory=dict)
class EmbeddingProvider(ABC):
    """
    Abstract base class for embedding providers.
    [AC-AISVC-29] Provides unified interface for different embedding providers.
    Design reference: progress.md Section 7.1 - Architecture
    - OllamaEmbeddingProvider / OpenAIEmbeddingProvider can be swapped
    - Factory pattern for dynamic loading
    """

    @abstractmethod
    async def embed(self, text: str) -> list[float]:
        """
        Embed a single text.

        [AC-AISVC-29] Returns embedding vector.

        Args:
            text: The text to encode.

        Returns:
            The embedding vector as a list of floats.

        Raises:
            EmbeddingException: If the provider fails to produce a vector.
        """
        ...

    @abstractmethod
    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Embed several texts at once.

        [AC-AISVC-29] Returns list of embedding vectors.

        Args:
            texts: The texts to encode, in order.

        Returns:
            One embedding vector per input text, in the same order.

        Raises:
            EmbeddingException: If the provider fails to produce vectors.
        """
        ...

    @abstractmethod
    def get_dimension(self) -> int:
        """
        Report the dimension of vectors this provider emits.

        [AC-AISVC-29] Returns vector dimension.
        """
        ...

    @abstractmethod
    def get_provider_name(self) -> str:
        """
        Report this provider's string identifier.

        [AC-AISVC-29] Returns provider identifier.
        """
        ...

    @abstractmethod
    def get_config_schema(self) -> dict[str, Any]:
        """
        Describe this provider's configuration parameters.

        [AC-AISVC-38] Returns JSON Schema for configuration parameters.
        """
        ...

    async def close(self) -> None:
        """Release any held resources. Default implementation is a no-op."""
        ...
class EmbeddingException(Exception):
"""Exception raised when embedding generation fails."""
def __init__(self, message: str, provider: str = "", details: dict[str, Any] | None = None):
self.provider = provider
self.details = details or {}
super().__init__(f"[{provider}] {message}" if provider else message)

View File

@ -0,0 +1,305 @@
"""
Embedding provider factory and configuration manager.
[AC-AISVC-30, AC-AISVC-31] Factory pattern for dynamic provider loading.
Design reference: progress.md Section 7.1 - Architecture
- EmbeddingProviderFactory: creates providers based on config
- EmbeddingConfigManager: manages configuration with hot-reload support
"""
import logging
from typing import Any, Type
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
logger = logging.getLogger(__name__)
class EmbeddingProviderFactory:
    """
    Factory for creating embedding providers.
    [AC-AISVC-30] Supports dynamic loading based on configuration.
    """

    # Registry of provider name -> implementation class. Mutated by
    # register_provider(); shared across the process.
    _providers: dict[str, Type[EmbeddingProvider]] = {
        "ollama": OllamaEmbeddingProvider,
        "openai": OpenAIEmbeddingProvider,
        "nomic": NomicEmbeddingProvider,
    }

    @classmethod
    def register_provider(cls, name: str, provider_class: Type[EmbeddingProvider]) -> None:
        """
        Register a new embedding provider.
        [AC-AISVC-30] Allows runtime registration of providers.
        Silently overwrites any existing registration under the same name.
        """
        cls._providers[name] = provider_class
        logger.info(f"Registered embedding provider: {name}")

    @classmethod
    def get_available_providers(cls) -> list[str]:
        """
        Get list of available provider names.
        [AC-AISVC-38] Returns registered provider identifiers.
        """
        return list(cls._providers.keys())

    @classmethod
    def get_provider_info(cls, name: str) -> dict[str, Any]:
        """
        Get provider information including config schema.
        [AC-AISVC-38] Returns provider metadata.

        Raises:
            EmbeddingException: If the provider name is not registered.
        """
        if name not in cls._providers:
            raise EmbeddingException(
                f"Unknown provider: {name}",
                provider="factory"
            )
        provider_class = cls._providers[name]
        # __new__ skips __init__ so no credentials/network config are needed
        # just to read the schema. Assumes get_config_schema() reads no
        # instance state — true for the built-in providers; externally
        # registered providers must uphold the same contract.
        temp_instance = provider_class.__new__(provider_class)
        display_names = {
            "ollama": "Ollama 本地模型",
            "openai": "OpenAI Embedding",
            "nomic": "Nomic Embed (优化版)",
        }
        descriptions = {
            "ollama": "使用 Ollama 运行的本地嵌入模型,支持 nomic-embed-text 等开源模型",
            "openai": "使用 OpenAI 官方 Embedding API支持 text-embedding-3 系列模型",
            "nomic": "Nomic-embed-text v1.5 优化版,支持任务前缀和 Matryoshka 维度截断专为RAG优化",
        }
        return {
            "name": name,
            "display_name": display_names.get(name, name),
            "description": descriptions.get(name, ""),
            "config_schema": temp_instance.get_config_schema(),
        }

    @classmethod
    def create_provider(
        cls,
        name: str,
        config: dict[str, Any],
    ) -> EmbeddingProvider:
        """
        Create an embedding provider instance.
        [AC-AISVC-30] Creates provider based on configuration.

        Args:
            name: Provider identifier (e.g., "ollama", "openai")
            config: Provider-specific configuration

        Returns:
            Configured EmbeddingProvider instance

        Raises:
            EmbeddingException: If provider is unknown or configuration is invalid
        """
        if name not in cls._providers:
            raise EmbeddingException(
                f"Unknown embedding provider: {name}. "
                f"Available: {cls.get_available_providers()}",
                provider="factory"
            )
        provider_class = cls._providers[name]
        try:
            instance = provider_class(**config)
            logger.info(f"Created embedding provider: {name}")
            return instance
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise EmbeddingException(
                f"Failed to create provider '{name}': {e}",
                provider="factory",
                details={"config": config}
            ) from e
class EmbeddingConfigManager:
    """
    Manager for embedding configuration.
    [AC-AISVC-31] Supports hot-reload of configuration.

    Holds the active provider instance and rebuilds it when the
    configuration changes.
    """

    def __init__(self, default_provider: str = "ollama", default_config: dict[str, Any] | None = None):
        self._provider_name = default_provider
        self._config = default_config or {
            "base_url": "http://localhost:11434",
            "model": "nomic-embed-text",
            "dimension": 768,
        }
        self._provider: EmbeddingProvider | None = None  # built lazily

    def get_provider_name(self) -> str:
        """Get current provider name."""
        return self._provider_name

    def get_config(self) -> dict[str, Any]:
        """Get a copy of the current configuration."""
        return self._config.copy()

    def get_full_config(self) -> dict[str, Any]:
        """
        Get full configuration including provider name.
        [AC-AISVC-39] Returns complete configuration for API response.
        """
        return {
            "provider": self._provider_name,
            "config": self._config.copy(),
        }

    async def get_provider(self) -> EmbeddingProvider:
        """
        Get or create the embedding provider.
        [AC-AISVC-29] Returns configured provider instance.
        """
        if self._provider is None:
            self._provider = EmbeddingProviderFactory.create_provider(
                self._provider_name,
                self._config
            )
        return self._provider

    async def update_config(
        self,
        provider: str,
        config: dict[str, Any],
    ) -> bool:
        """
        Update embedding configuration.
        [AC-AISVC-31, AC-AISVC-40] Supports hot-reload.

        The new provider is constructed *before* the old one is closed, so
        a bad configuration leaves the previous provider untouched.

        Args:
            provider: New provider name
            config: New provider configuration

        Returns:
            True if update was successful

        Raises:
            EmbeddingException: If configuration is invalid
        """
        old_provider = self._provider_name
        old_config = self._config.copy()
        try:
            new_provider_instance = EmbeddingProviderFactory.create_provider(
                provider,
                config
            )
            if self._provider:
                await self._provider.close()
            self._provider_name = provider
            # Defensive copy (consistent with get_config's copy semantics):
            # a caller mutating its dict afterwards must not silently
            # change the active configuration.
            self._config = dict(config)
            self._provider = new_provider_instance
            logger.info(f"Updated embedding config: provider={provider}")
            return True
        except Exception as e:
            # Roll back to the previous configuration on any failure.
            self._provider_name = old_provider
            self._config = old_config
            raise EmbeddingException(
                f"Failed to update config: {e}",
                provider="config_manager",
                details={"provider": provider, "config": config}
            ) from e

    async def test_connection(
        self,
        test_text: str = "这是一个测试文本",
        provider: str | None = None,
        config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Test embedding connection.
        [AC-AISVC-41] Tests provider connectivity.

        A throwaway provider is built for the test so the active provider
        is never disturbed. Failures are reported in the result dict, not
        raised.

        Args:
            test_text: Text to embed for testing
            provider: Provider to test (uses current if None)
            config: Config to test (uses current if None)

        Returns:
            Dict with test results including success, dimension, latency
        """
        import time
        test_provider_name = provider or self._provider_name
        test_config = config or self._config
        try:
            test_provider = EmbeddingProviderFactory.create_provider(
                test_provider_name,
                test_config
            )
            start_time = time.perf_counter()
            embedding = await test_provider.embed(test_text)
            latency_ms = (time.perf_counter() - start_time) * 1000
            await test_provider.close()
            return {
                "success": True,
                "dimension": len(embedding),
                "latency_ms": latency_ms,
                "message": f"连接成功,向量维度: {len(embedding)}",
            }
        except Exception as e:
            return {
                "success": False,
                "dimension": 0,
                "latency_ms": 0,
                "error": str(e),
                "message": f"连接失败: {e}",
            }

    async def close(self) -> None:
        """Close the current provider and drop the cached instance."""
        if self._provider:
            await self._provider.close()
            self._provider = None
_embedding_config_manager: EmbeddingConfigManager | None = None


def get_embedding_config_manager() -> EmbeddingConfigManager:
    """
    Get the global embedding config manager.
    [AC-AISVC-31] Singleton pattern for configuration management.

    Lazily builds the singleton from application settings on first call.
    """
    global _embedding_config_manager
    if _embedding_config_manager is not None:
        return _embedding_config_manager
    from app.core.config import get_settings
    app_settings = get_settings()
    ollama_defaults = {
        "base_url": app_settings.ollama_base_url,
        "model": app_settings.ollama_embedding_model,
        "dimension": app_settings.qdrant_vector_size,
    }
    _embedding_config_manager = EmbeddingConfigManager(
        default_provider="ollama",
        default_config=ollama_defaults,
    )
    return _embedding_config_manager
async def get_embedding_provider() -> EmbeddingProvider:
    """
    Get the current embedding provider.
    [AC-AISVC-29] Convenience function for getting provider.

    Delegates to the singleton config manager.
    """
    return await get_embedding_config_manager().get_provider()

View File

@ -0,0 +1,291 @@
"""
Nomic embedding provider with task prefixes and Matryoshka support.
Implements RAG optimization spec:
- Task prefixes: search_document: / search_query:
- Matryoshka dimension truncation: 256/512/768 dimensions
"""
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
import httpx
import numpy as np
from app.services.embedding.base import (
EmbeddingConfig,
EmbeddingException,
EmbeddingProvider,
)
logger = logging.getLogger(__name__)
class EmbeddingTask(str, Enum):
    """Task type for nomic-embed-text v1.5 model."""
    # Used when indexing documents (maps to the "search_document:" prefix).
    DOCUMENT = "search_document"
    # Used when embedding retrieval queries (maps to "search_query:").
    QUERY = "search_query"
@dataclass
class NomicEmbeddingResult:
    """Result from Nomic embedding with multiple dimensions."""
    # Untruncated vector as returned by the model.
    embedding_full: list[float]
    # Matryoshka truncation to 256 dims, re-normalized.
    embedding_256: list[float]
    # Matryoshka truncation to 512 dims, re-normalized.
    embedding_512: list[float]
    # Length of embedding_full.
    dimension: int
    # Model name that produced the vector.
    model: str
    # Task prefix used (DOCUMENT or QUERY).
    task: EmbeddingTask
    # Wall-clock generation time in milliseconds (0.0 if not measured).
    latency_ms: float = 0.0
    # Free-form metadata.
    metadata: dict[str, Any] = field(default_factory=dict)
class NomicEmbeddingProvider(EmbeddingProvider):
    """
    Nomic-embed-text v1.5 embedding provider with task prefixes.
    Key features:
    - Task prefixes: search_document: for documents, search_query: for queries
    - Matryoshka dimension truncation: 256/512/768 dimensions
    - Automatic normalization after truncation
    Reference: rag-optimization/spec.md Section 2.1, 2.3
    """

    PROVIDER_NAME = "nomic"
    DOCUMENT_PREFIX = "search_document:"
    QUERY_PREFIX = "search_query:"
    FULL_DIMENSION = 768

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        model: str = "nomic-embed-text",
        dimension: int = 768,
        timeout_seconds: int = 60,
        enable_matryoshka: bool = True,
        **kwargs: Any,
    ):
        self._base_url = base_url.rstrip("/")
        self._model = model
        self._dimension = dimension
        self._timeout = timeout_seconds
        self._enable_matryoshka = enable_matryoshka
        self._client: httpx.AsyncClient | None = None  # created lazily
        self._extra_config = kwargs

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or lazily create the shared async HTTP client."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self._timeout)
        return self._client

    def _add_prefix(self, text: str, task: EmbeddingTask) -> str:
        """Add task prefix to text (idempotent if already prefixed)."""
        if task == EmbeddingTask.DOCUMENT:
            prefix = self.DOCUMENT_PREFIX
        else:
            prefix = self.QUERY_PREFIX
        if text.startswith(prefix):
            return text
        return f"{prefix}{text}"

    def _truncate_and_normalize(self, embedding: list[float], target_dim: int) -> list[float]:
        """
        Truncate embedding to target dimension and normalize.
        Matryoshka representation learning allows dimension truncation.
        """
        truncated = embedding[:target_dim]
        arr = np.array(truncated, dtype=np.float32)
        norm = np.linalg.norm(arr)
        if norm > 0:
            arr = arr / norm
        return arr.tolist()

    async def embed_with_task(
        self,
        text: str,
        task: EmbeddingTask,
    ) -> NomicEmbeddingResult:
        """
        Generate embedding with specified task prefix.

        Args:
            text: Input text to embed
            task: DOCUMENT for indexing, QUERY for retrieval

        Returns:
            NomicEmbeddingResult with all dimension variants

        Raises:
            EmbeddingException: On API failure, connection error, or an
                empty embedding from the backend.
        """
        start_time = time.perf_counter()
        prefixed_text = self._add_prefix(text, task)
        try:
            client = await self._get_client()
            response = await client.post(
                f"{self._base_url}/api/embeddings",
                json={
                    "model": self._model,
                    "prompt": prefixed_text,
                }
            )
            response.raise_for_status()
            data = response.json()
            embedding = data.get("embedding", [])
            if not embedding:
                raise EmbeddingException(
                    "Empty embedding returned",
                    provider=self.PROVIDER_NAME,
                    details={"text_length": len(text), "task": task.value}
                )
            latency_ms = (time.perf_counter() - start_time) * 1000
            # Matryoshka variants are always computed; callers pick the size.
            embedding_256 = self._truncate_and_normalize(embedding, 256)
            embedding_512 = self._truncate_and_normalize(embedding, 512)
            logger.debug(
                f"Generated Nomic embedding: task={task.value}, "
                f"dim={len(embedding)}, latency={latency_ms:.2f}ms"
            )
            return NomicEmbeddingResult(
                embedding_full=embedding,
                embedding_256=embedding_256,
                embedding_512=embedding_512,
                dimension=len(embedding),
                model=self._model,
                task=task,
                latency_ms=latency_ms,
            )
        except httpx.HTTPStatusError as e:
            # Chain original exceptions so the root cause stays visible.
            raise EmbeddingException(
                f"Ollama API error: {e.response.status_code}",
                provider=self.PROVIDER_NAME,
                details={"status_code": e.response.status_code, "response": e.response.text}
            ) from e
        except httpx.RequestError as e:
            raise EmbeddingException(
                f"Ollama connection error: {e}",
                provider=self.PROVIDER_NAME,
                details={"base_url": self._base_url}
            ) from e
        except EmbeddingException:
            raise
        except Exception as e:
            raise EmbeddingException(
                f"Embedding generation failed: {e}",
                provider=self.PROVIDER_NAME
            ) from e

    async def embed_document(self, text: str) -> NomicEmbeddingResult:
        """
        Generate embedding for document (with search_document: prefix).
        Use this when indexing documents into vector store.
        """
        return await self.embed_with_task(text, EmbeddingTask.DOCUMENT)

    async def embed_query(self, text: str) -> NomicEmbeddingResult:
        """
        Generate embedding for query (with search_query: prefix).
        Use this when searching/retrieving documents.
        """
        return await self.embed_with_task(text, EmbeddingTask.QUERY)

    async def embed(self, text: str) -> list[float]:
        """
        Generate embedding vector for a single text.
        Default uses QUERY task for backward compatibility.
        """
        result = await self.embed_query(text)
        return result.embedding_full

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embedding vectors for multiple texts (sequentially).
        Uses QUERY task by default.
        """
        embeddings = []
        for text in texts:
            embedding = await self.embed(text)
            embeddings.append(embedding)
        return embeddings

    async def embed_documents_batch(
        self,
        texts: list[str],
    ) -> list[NomicEmbeddingResult]:
        """
        Generate embeddings for multiple documents (DOCUMENT task).
        Use this when batch indexing documents.
        """
        results = []
        for text in texts:
            result = await self.embed_document(text)
            results.append(result)
        return results

    async def embed_queries_batch(
        self,
        texts: list[str],
    ) -> list[NomicEmbeddingResult]:
        """
        Generate embeddings for multiple queries (QUERY task).
        Use this when batch processing queries.
        """
        results = []
        for text in texts:
            result = await self.embed_query(text)
            results.append(result)
        return results

    def get_dimension(self) -> int:
        """Get the configured dimension of embedding vectors."""
        return self._dimension

    def get_provider_name(self) -> str:
        """Get the name of this embedding provider."""
        return self.PROVIDER_NAME

    def get_config_schema(self) -> dict[str, Any]:
        """Get the configuration schema for Nomic provider."""
        return {
            "base_url": {
                "type": "string",
                "description": "Ollama API 地址",
                "default": "http://localhost:11434",
            },
            "model": {
                "type": "string",
                "description": "嵌入模型名称(推荐 nomic-embed-text v1.5",
                "default": "nomic-embed-text",
            },
            "dimension": {
                "type": "integer",
                "description": "向量维度(支持 256/512/768",
                "default": 768,
            },
            "timeout_seconds": {
                "type": "integer",
                "description": "请求超时时间(秒)",
                "default": 60,
            },
            "enable_matryoshka": {
                "type": "boolean",
                "description": "启用 Matryoshka 维度截断",
                "default": True,
            },
        }

    async def close(self) -> None:
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None
View File

@ -0,0 +1,58 @@
"""
Ollama embedding service for generating text embeddings.
Uses nomic-embed-text model via Ollama API.
"""
import logging
import httpx
from app.core.config import get_settings
logger = logging.getLogger(__name__)
async def get_embedding(text: str) -> list[float]:
    """
    Generate embedding vector for text using Ollama nomic-embed-text model.

    Falls back to a zero vector of the configured size when the API
    returns an empty embedding; HTTP/connection errors are logged and
    re-raised.
    """
    settings = get_settings()
    payload = {
        "model": settings.ollama_embedding_model,
        "prompt": text,
    }
    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            response = await client.post(
                f"{settings.ollama_base_url}/api/embeddings",
                json=payload,
            )
            response.raise_for_status()
            vector = response.json().get("embedding", [])
            if not vector:
                logger.warning(f"Empty embedding returned for text length={len(text)}")
                return [0.0] * settings.qdrant_vector_size
            logger.debug(f"Generated embedding: dim={len(vector)}")
            return vector
        except httpx.HTTPStatusError as e:
            logger.error(f"Ollama API error: {e.response.status_code} - {e.response.text}")
            raise
        except httpx.RequestError as e:
            logger.error(f"Ollama connection error: {e}")
            raise
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise
async def get_embeddings_batch(texts: list[str]) -> list[list[float]]:
    """
    Generate embedding vectors for multiple texts.

    Texts are embedded one at a time, preserving input order.
    """
    return [await get_embedding(text) for text in texts]

View File

@ -0,0 +1,157 @@
"""
Ollama embedding provider implementation.
[AC-AISVC-29, AC-AISVC-30] Ollama-based embedding provider.
Uses Ollama API for generating text embeddings.
"""
import logging
import time
from typing import Any
import httpx
from app.services.embedding.base import (
EmbeddingConfig,
EmbeddingException,
EmbeddingProvider,
)
logger = logging.getLogger(__name__)
class OllamaEmbeddingProvider(EmbeddingProvider):
    """
    Embedding provider using Ollama API.
    [AC-AISVC-29, AC-AISVC-30] Supports local embedding models via Ollama.
    """

    PROVIDER_NAME = "ollama"

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        model: str = "nomic-embed-text",
        dimension: int = 768,
        timeout_seconds: int = 60,
        **kwargs: Any,
    ):
        self._base_url = base_url.rstrip("/")
        self._model = model
        self._dimension = dimension
        self._timeout = timeout_seconds
        self._client: httpx.AsyncClient | None = None  # created lazily
        self._extra_config = kwargs

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or lazily create the shared async HTTP client."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self._timeout)
        return self._client

    async def embed(self, text: str) -> list[float]:
        """
        Generate embedding vector for a single text using Ollama API.
        [AC-AISVC-29] Returns embedding vector.

        Raises:
            EmbeddingException: On API failure, connection error, or an
                empty embedding from the backend.
        """
        start_time = time.perf_counter()
        try:
            client = await self._get_client()
            response = await client.post(
                f"{self._base_url}/api/embeddings",
                json={
                    "model": self._model,
                    "prompt": text,
                }
            )
            response.raise_for_status()
            data = response.json()
            embedding = data.get("embedding", [])
            if not embedding:
                raise EmbeddingException(
                    "Empty embedding returned",
                    provider=self.PROVIDER_NAME,
                    details={"text_length": len(text)}
                )
            latency_ms = (time.perf_counter() - start_time) * 1000
            logger.debug(
                f"Generated embedding via Ollama: dim={len(embedding)}, "
                f"latency={latency_ms:.2f}ms"
            )
            return embedding
        except httpx.HTTPStatusError as e:
            # Chain original exceptions so the root cause stays visible.
            raise EmbeddingException(
                f"Ollama API error: {e.response.status_code}",
                provider=self.PROVIDER_NAME,
                details={"status_code": e.response.status_code, "response": e.response.text}
            ) from e
        except httpx.RequestError as e:
            raise EmbeddingException(
                f"Ollama connection error: {e}",
                provider=self.PROVIDER_NAME,
                details={"base_url": self._base_url}
            ) from e
        except EmbeddingException:
            raise
        except Exception as e:
            raise EmbeddingException(
                f"Embedding generation failed: {e}",
                provider=self.PROVIDER_NAME
            ) from e

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embedding vectors for multiple texts.
        [AC-AISVC-29] Sequential embedding generation (one request per text).
        """
        embeddings = []
        for text in texts:
            embedding = await self.embed(text)
            embeddings.append(embedding)
        return embeddings

    def get_dimension(self) -> int:
        """Get the configured dimension of embedding vectors."""
        return self._dimension

    def get_provider_name(self) -> str:
        """Get the name of this embedding provider."""
        return self.PROVIDER_NAME

    def get_config_schema(self) -> dict[str, Any]:
        """
        Get the configuration schema for Ollama provider.
        [AC-AISVC-38] Returns JSON Schema for configuration parameters.
        """
        return {
            "base_url": {
                "type": "string",
                "description": "Ollama API 地址",
                "default": "http://localhost:11434",
            },
            "model": {
                "type": "string",
                "description": "嵌入模型名称",
                "default": "nomic-embed-text",
            },
            "dimension": {
                "type": "integer",
                "description": "向量维度",
                "default": 768,
            },
            "timeout_seconds": {
                "type": "integer",
                "description": "请求超时时间(秒)",
                "default": 60,
            },
        }

    async def close(self) -> None:
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None
View File

@ -0,0 +1,193 @@
"""
OpenAI embedding provider implementation.
[AC-AISVC-29, AC-AISVC-30] OpenAI-based embedding provider.
Uses OpenAI API for generating text embeddings.
"""
import logging
import time
from typing import Any
import httpx
from app.services.embedding.base import (
EmbeddingException,
EmbeddingProvider,
)
logger = logging.getLogger(__name__)
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """
    Embedding provider using OpenAI API.
    [AC-AISVC-29, AC-AISVC-30] Supports OpenAI embedding models.
    """

    PROVIDER_NAME = "openai"

    # Native output dimensions for known models, used when the caller does
    # not pass an explicit dimension.
    MODEL_DIMENSIONS = {
        "text-embedding-ada-002": 1536,
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
    }

    def __init__(
        self,
        api_key: str,
        model: str = "text-embedding-3-small",
        base_url: str = "https://api.openai.com/v1",
        dimension: int | None = None,
        timeout_seconds: int = 60,
        **kwargs: Any,
    ):
        """
        Args:
            api_key: OpenAI (or compatible) API key.
            model: Embedding model name.
            base_url: API root URL; trailing slashes are stripped.
            dimension: Optional explicit vector dimension; falls back to the
                model's known native dimension, then 1536.
            timeout_seconds: Per-request HTTP timeout in seconds.
            **kwargs: Extra provider options, stored for forward compatibility.
        """
        self._api_key = api_key
        self._model = model
        self._base_url = base_url.rstrip("/")
        self._timeout = timeout_seconds
        self._client: httpx.AsyncClient | None = None  # created lazily
        self._extra_config = kwargs
        if dimension:
            self._dimension = dimension
        elif model in self.MODEL_DIMENSIONS:
            self._dimension = self.MODEL_DIMENSIONS[model]
        else:
            self._dimension = 1536

    async def _get_client(self) -> httpx.AsyncClient:
        """Lazily create and cache the shared async HTTP client."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self._timeout)
        return self._client

    async def embed(self, text: str) -> list[float]:
        """
        Generate embedding vector for a single text using OpenAI API.
        [AC-AISVC-29] Returns embedding vector.
        """
        embeddings = await self.embed_batch([text])
        return embeddings[0]

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embedding vectors for multiple texts using OpenAI API.
        [AC-AISVC-29] Supports batch embedding for efficiency.

        Returns:
            One embedding per input, in the same order as ``texts``.
        Raises:
            EmbeddingException: on API/connection errors, empty embeddings,
                or a count mismatch with the input.
        """
        start_time = time.perf_counter()
        try:
            client = await self._get_client()
            request_body: dict[str, Any] = {
                "model": self._model,
                "input": texts,
            }
            # Only the text-embedding-3 family accepts a custom "dimensions"
            # parameter; older models reject it.
            if self._dimension and self._model.startswith("text-embedding-3"):
                request_body["dimensions"] = self._dimension
            response = await client.post(
                f"{self._base_url}/embeddings",
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
                json=request_body,
            )
            response.raise_for_status()
            data = response.json()
            # Fix: the API attaches an "index" to each data item; sort on it
            # instead of trusting the response order, so results always align
            # with the order of `texts`.
            items = sorted(data.get("data", []), key=lambda it: it.get("index", 0))
            embeddings = []
            for item in items:
                embedding = item.get("embedding", [])
                if not embedding:
                    raise EmbeddingException(
                        "Empty embedding returned",
                        provider=self.PROVIDER_NAME,
                        details={"index": item.get("index", 0)}
                    )
                embeddings.append(embedding)
            if len(embeddings) != len(texts):
                raise EmbeddingException(
                    f"Embedding count mismatch: expected {len(texts)}, got {len(embeddings)}",
                    provider=self.PROVIDER_NAME
                )
            latency_ms = (time.perf_counter() - start_time) * 1000
            logger.debug(
                f"Generated {len(embeddings)} embeddings via OpenAI: "
                f"dim={len(embeddings[0]) if embeddings else 0}, "
                f"latency={latency_ms:.2f}ms"
            )
            return embeddings
        except httpx.HTTPStatusError as e:
            raise EmbeddingException(
                f"OpenAI API error: {e.response.status_code}",
                provider=self.PROVIDER_NAME,
                details={"status_code": e.response.status_code, "response": e.response.text}
            )
        except httpx.RequestError as e:
            raise EmbeddingException(
                f"OpenAI connection error: {e}",
                provider=self.PROVIDER_NAME,
                details={"base_url": self._base_url}
            )
        except EmbeddingException:
            # Already normalized above; re-raise unchanged.
            raise
        except Exception as e:
            raise EmbeddingException(
                f"Embedding generation failed: {e}",
                provider=self.PROVIDER_NAME
            )

    def get_dimension(self) -> int:
        """Get the dimension of embedding vectors."""
        return self._dimension

    def get_provider_name(self) -> str:
        """Get the name of this embedding provider."""
        return self.PROVIDER_NAME

    def get_config_schema(self) -> dict[str, Any]:
        """
        Get the configuration schema for OpenAI provider.
        [AC-AISVC-38] Returns JSON Schema for configuration parameters.
        """
        return {
            "api_key": {
                "type": "string",
                "description": "OpenAI API 密钥",
                "required": True,
                "secret": True,
            },
            "model": {
                "type": "string",
                "description": "嵌入模型名称",
                "default": "text-embedding-3-small",
                "enum": list(self.MODEL_DIMENSIONS.keys()),
            },
            "base_url": {
                "type": "string",
                "description": "OpenAI API 地址(支持兼容接口)",
                "default": "https://api.openai.com/v1",
            },
            "dimension": {
                "type": "integer",
                "description": "向量维度(仅 text-embedding-3 系列支持自定义)",
                "default": 1536,
            },
            "timeout_seconds": {
                "type": "integer",
                "description": "请求超时时间(秒)",
                "default": 60,
            },
        }

    async def close(self) -> None:
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None

View File

@ -0,0 +1,294 @@
"""
Knowledge Base service for AI Service.
[AC-ASA-01, AC-ASA-02, AC-ASA-08] KB management with document upload, indexing, and listing.
"""
import logging
import os
import uuid
from datetime import datetime
from typing import Sequence
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col
from app.models.entities import (
Document,
DocumentStatus,
IndexJob,
IndexJobStatus,
KnowledgeBase,
)
logger = logging.getLogger(__name__)
class KBService:
    """
    [AC-ASA-01, AC-ASA-02, AC-ASA-08] Knowledge Base service.
    Handles document upload, indexing jobs, and document listing.

    Every query is scoped by tenant_id for multi-tenant isolation. The
    service only flushes pending rows; commit/rollback is left to callers.
    """

    def __init__(self, session: AsyncSession, upload_dir: str = "./uploads"):
        """
        Args:
            session: Async DB session provided by the caller.
            upload_dir: Directory for uploaded files; created if missing.
        """
        self._session = session
        self._upload_dir = upload_dir
        os.makedirs(upload_dir, exist_ok=True)

    async def get_or_create_kb(
        self,
        tenant_id: str,
        kb_id: str | None = None,
        name: str = "Default KB",
    ) -> KnowledgeBase:
        """
        Get existing KB or create default one.

        Resolution order: exact (tenant_id, kb_id) match when kb_id parses as
        a UUID, else any existing KB of the tenant, else a newly created one.
        """
        if kb_id:
            try:
                stmt = select(KnowledgeBase).where(
                    KnowledgeBase.tenant_id == tenant_id,
                    KnowledgeBase.id == uuid.UUID(kb_id),
                )
                result = await self._session.execute(stmt)
                existing_kb = result.scalar_one_or_none()
                if existing_kb:
                    return existing_kb
            except ValueError:
                # Malformed UUID: fall through to the tenant's default KB.
                pass
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.tenant_id == tenant_id,
        ).limit(1)
        result = await self._session.execute(stmt)
        existing_kb = result.scalar_one_or_none()
        if existing_kb:
            return existing_kb
        new_kb = KnowledgeBase(
            tenant_id=tenant_id,
            name=name,
        )
        self._session.add(new_kb)
        await self._session.flush()
        logger.info(f"[AC-ASA-01] Created knowledge base: tenant={tenant_id}, kb_id={new_kb.id}")
        return new_kb

    async def upload_document(
        self,
        tenant_id: str,
        kb_id: str,
        file_name: str,
        file_content: bytes,
        file_type: str | None = None,
    ) -> tuple[Document, IndexJob]:
        """
        [AC-ASA-01] Upload document and create indexing job.

        Returns:
            (document, job): the persisted Document and its pending IndexJob.
        """
        doc_id = uuid.uuid4()
        # Security fix: file_name comes from the client. Strip any directory
        # components so a crafted name like "../../etc/x" cannot escape the
        # upload directory. (Backslashes normalized first for Windows-style
        # names on POSIX hosts.)
        # NOTE(review): assumes tenant_id is validated upstream — confirm.
        safe_name = os.path.basename(file_name.replace("\\", "/")) or "upload"
        file_path = os.path.join(self._upload_dir, f"{tenant_id}_{doc_id}_{safe_name}")
        with open(file_path, "wb") as f:
            f.write(file_content)
        # The original (unsanitized) name is kept as display metadata only.
        document = Document(
            id=doc_id,
            tenant_id=tenant_id,
            kb_id=kb_id,
            file_name=file_name,
            file_path=file_path,
            file_size=len(file_content),
            file_type=file_type,
            status=DocumentStatus.PENDING.value,
        )
        self._session.add(document)
        job = IndexJob(
            tenant_id=tenant_id,
            doc_id=doc_id,
            status=IndexJobStatus.PENDING.value,
            progress=0,
        )
        self._session.add(job)
        await self._session.flush()
        logger.info(
            f"[AC-ASA-01] Uploaded document: tenant={tenant_id}, doc_id={doc_id}, "
            f"file_name={file_name}, size={len(file_content)}"
        )
        return document, job

    async def list_documents(
        self,
        tenant_id: str,
        kb_id: str | None = None,
        status: str | None = None,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[Sequence[Document], int]:
        """
        [AC-ASA-08] List documents with filtering and pagination.

        Returns:
            (documents, total): one page of documents (newest first) and the
            total count matching the filters.
        """
        stmt = select(Document).where(Document.tenant_id == tenant_id)
        if kb_id:
            stmt = stmt.where(Document.kb_id == kb_id)
        if status:
            stmt = stmt.where(Document.status == status)
        # Count over the filtered query before pagination is applied.
        count_stmt = select(func.count()).select_from(stmt.subquery())
        total_result = await self._session.execute(count_stmt)
        total = total_result.scalar() or 0
        stmt = stmt.order_by(col(Document.created_at).desc())
        stmt = stmt.offset((page - 1) * page_size).limit(page_size)
        result = await self._session.execute(stmt)
        documents = result.scalars().all()
        logger.info(
            f"[AC-ASA-08] Listed documents: tenant={tenant_id}, "
            f"kb_id={kb_id}, status={status}, total={total}"
        )
        return documents, total

    async def get_document(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> Document | None:
        """
        Get document by ID.

        Raises:
            ValueError: if doc_id is not a valid UUID string.
        """
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.id == uuid.UUID(doc_id),
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_index_job(
        self,
        tenant_id: str,
        job_id: str,
    ) -> IndexJob | None:
        """
        [AC-ASA-02] Get index job status.

        Raises:
            ValueError: if job_id is not a valid UUID string.
        """
        stmt = select(IndexJob).where(
            IndexJob.tenant_id == tenant_id,
            IndexJob.id == uuid.UUID(job_id),
        )
        result = await self._session.execute(stmt)
        job = result.scalar_one_or_none()
        if job:
            logger.info(
                f"[AC-ASA-02] Got job status: tenant={tenant_id}, "
                f"job_id={job_id}, status={job.status}, progress={job.progress}"
            )
        return job

    async def get_index_job_by_doc(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> IndexJob | None:
        """
        Get the most recent index job for a document (jobs ordered newest first).

        Raises:
            ValueError: if doc_id is not a valid UUID string.
        """
        stmt = select(IndexJob).where(
            IndexJob.tenant_id == tenant_id,
            IndexJob.doc_id == uuid.UUID(doc_id),
        ).order_by(col(IndexJob.created_at).desc())
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def update_job_status(
        self,
        tenant_id: str,
        job_id: str,
        status: str,
        progress: int | None = None,
        error_msg: str | None = None,
    ) -> IndexJob | None:
        """
        Update index job status, mirroring status/error onto the owning
        Document so listings stay consistent.

        Returns:
            The updated job, or None when (tenant_id, job_id) does not match.
        """
        stmt = select(IndexJob).where(
            IndexJob.tenant_id == tenant_id,
            IndexJob.id == uuid.UUID(job_id),
        )
        result = await self._session.execute(stmt)
        job = result.scalar_one_or_none()
        if job:
            job.status = status
            # NOTE(review): datetime.utcnow() is naive and deprecated in
            # Python 3.12+; kept for schema compatibility with existing rows.
            job.updated_at = datetime.utcnow()
            if progress is not None:
                job.progress = progress
            if error_msg is not None:
                job.error_msg = error_msg
            await self._session.flush()
            if job.doc_id:
                doc_stmt = select(Document).where(
                    Document.tenant_id == tenant_id,
                    Document.id == job.doc_id,
                )
                doc_result = await self._session.execute(doc_stmt)
                doc = doc_result.scalar_one_or_none()
                if doc:
                    doc.status = status
                    doc.updated_at = datetime.utcnow()
                    if error_msg:
                        doc.error_msg = error_msg
                    await self._session.flush()
        return job

    async def delete_document(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> bool:
        """
        Delete document row and its stored file.

        Returns:
            True when a document was deleted, False when none matched.
        Raises:
            ValueError: if doc_id is not a valid UUID string.
        """
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.id == uuid.UUID(doc_id),
        )
        result = await self._session.execute(stmt)
        document = result.scalar_one_or_none()
        if not document:
            return False
        if document.file_path and os.path.exists(document.file_path):
            os.remove(document.file_path)
        await self._session.delete(document)
        await self._session.flush()
        logger.info(f"[AC-ASA-08] Deleted document: tenant={tenant_id}, doc_id={doc_id}")
        return True

    async def list_knowledge_bases(
        self,
        tenant_id: str,
    ) -> Sequence[KnowledgeBase]:
        """
        List all knowledge bases for a tenant, newest first.
        """
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.tenant_id == tenant_id
        ).order_by(col(KnowledgeBase.created_at).desc())
        result = await self._session.execute(stmt)
        return result.scalars().all()

View File

@ -0,0 +1,15 @@
"""
LLM Adapter module for AI Service.
[AC-AISVC-02, AC-AISVC-06] Provides unified interface for LLM providers.
"""
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
from app.services.llm.openai_client import OpenAIClient
# Public surface of the llm package: the abstract client interface, its
# data classes, and the default OpenAI-compatible implementation.
__all__ = [
    "LLMClient",
    "LLMConfig",
    "LLMResponse",
    "LLMStreamChunk",
    "OpenAIClient",
]

View File

@ -0,0 +1,115 @@
"""
Base LLM client interface.
[AC-AISVC-02, AC-AISVC-06] Abstract interface for LLM providers.
Design reference: design.md Section 8.1 - LLMClient interface
- generate(prompt, params) -> text
- stream_generate(prompt, params) -> iterator[delta]
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator
@dataclass
class LLMConfig:
    """
    Configuration for LLM client.
    [AC-AISVC-02] Supports configurable model parameters.
    """
    model: str = "gpt-4o-mini"      # target model identifier
    max_tokens: int = 2048          # cap on generated tokens per request
    temperature: float = 0.7        # sampling temperature
    top_p: float = 1.0              # nucleus-sampling cutoff
    timeout_seconds: int = 30       # per-request HTTP timeout
    # Intended retry budget. NOTE(review): the OpenAI client in this package
    # hardcodes its attempt count and does not consult this field — confirm.
    max_retries: int = 3
    # Provider-specific parameters merged into the request body verbatim.
    extra_params: dict[str, Any] = field(default_factory=dict)
@dataclass
class LLMResponse:
    """
    Response from LLM generation.
    [AC-AISVC-02] Contains generated content and metadata.
    """
    content: str                    # full generated text
    model: str                      # model that produced the response
    # Token usage as reported by the provider (e.g. prompt_tokens,
    # completion_tokens, total_tokens); empty when not reported.
    usage: dict[str, int] = field(default_factory=dict)
    finish_reason: str = "stop"     # provider's stop reason
    # Implementation extras (e.g. the raw provider response).
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class LLMStreamChunk:
    """
    Streaming chunk from LLM.
    [AC-AISVC-06, AC-AISVC-07] Incremental output for SSE streaming.
    """
    delta: str                          # incremental text for this chunk
    model: str                          # model producing the stream
    finish_reason: str | None = None    # set only on the terminal chunk
    # Implementation extras (e.g. the raw provider chunk).
    metadata: dict[str, Any] = field(default_factory=dict)
class LLMClient(ABC):
    """
    Abstract base class for LLM clients.
    [AC-AISVC-02, AC-AISVC-06] Provides unified interface for different LLM providers.
    Design reference: design.md Section 8.2 - Plugin points
    - OpenAICompatibleClient / LocalModelClient can be swapped
    """

    @abstractmethod
    async def generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Generate a non-streaming response.
        [AC-AISVC-02] Returns complete response for ChatResponse.
        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.
        Returns:
            LLMResponse with generated content and metadata.
        Raises:
            LLMException: If generation fails.
        """
        pass

    # NOTE: concrete implementations must be async generators (bodies that
    # `yield`); the annotation here only declares the interface.
    @abstractmethod
    async def stream_generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[LLMStreamChunk, None]:
        """
        Generate a streaming response.
        [AC-AISVC-06, AC-AISVC-07] Yields incremental chunks for SSE.
        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.
        Yields:
            LLMStreamChunk with incremental content.
        Raises:
            LLMException: If generation fails.
        """
        pass

    @abstractmethod
    async def close(self) -> None:
        """Close the client and release resources."""
        pass

View File

@ -0,0 +1,421 @@
"""
LLM Provider Factory and Configuration Management.
[AC-ASA-14, AC-ASA-15, AC-ASA-16, AC-ASA-17, AC-ASA-18] LLM provider management.
Design pattern: Factory pattern for pluggable LLM providers.
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from app.services.llm.base import LLMClient, LLMConfig
from app.services.llm.openai_client import OpenAIClient
logger = logging.getLogger(__name__)
@dataclass
class LLMProviderInfo:
    """Information about an LLM provider."""
    name: str                       # stable identifier, used as registry key
    display_name: str               # human-readable provider name
    description: str                # short description of the provider
    # JSON-Schema-style description of the provider's config parameters
    # ("type"/"properties"/"required").
    config_schema: dict[str, Any]
def _sampling_properties() -> dict[str, Any]:
    """
    Schema fragments shared by every provider: the output-token cap and the
    sampling temperature. Centralized so the four provider schemas cannot
    drift apart (previously copy-pasted into each schema).
    """
    return {
        "max_tokens": {
            "type": "integer",
            "title": "最大输出 Token 数",
            "description": "最大输出 Token 数",
            "default": 2048,
        },
        "temperature": {
            "type": "number",
            "title": "温度参数",
            "description": "温度参数 (0-2)",
            "default": 0.7,
            "minimum": 0,
            "maximum": 2,
        },
    }


# [AC-ASA-14] Registry of selectable LLM providers and their config schemas.
# Keys are the stable provider names consumed by LLMProviderFactory.
LLM_PROVIDERS: dict[str, LLMProviderInfo] = {
    "openai": LLMProviderInfo(
        name="openai",
        display_name="OpenAI",
        description="OpenAI GPT 系列模型 (GPT-4, GPT-3.5 等)",
        config_schema={
            "type": "object",
            "properties": {
                "api_key": {
                    "type": "string",
                    "title": "API Key",
                    "description": "API Key",
                    "required": True,
                },
                "base_url": {
                    "type": "string",
                    "title": "API Base URL",
                    "description": "API Base URL",
                    "default": "https://api.openai.com/v1",
                },
                "model": {
                    "type": "string",
                    "title": "模型名称",
                    "description": "模型名称",
                    "default": "gpt-4o-mini",
                },
                **_sampling_properties(),
            },
            "required": ["api_key"],
        },
    ),
    "ollama": LLMProviderInfo(
        name="ollama",
        display_name="Ollama",
        description="Ollama 本地模型 (Llama, Qwen 等)",
        config_schema={
            "type": "object",
            "properties": {
                "base_url": {
                    "type": "string",
                    "title": "Ollama API 地址",
                    "description": "Ollama API 地址",
                    "default": "http://localhost:11434/v1",
                },
                "model": {
                    "type": "string",
                    "title": "模型名称",
                    "description": "模型名称",
                    "default": "llama3.2",
                },
                **_sampling_properties(),
            },
            # Local deployments need no credentials.
            "required": [],
        },
    ),
    "deepseek": LLMProviderInfo(
        name="deepseek",
        display_name="DeepSeek",
        description="DeepSeek 大模型 (deepseek-chat, deepseek-coder)",
        config_schema={
            "type": "object",
            "properties": {
                "api_key": {
                    "type": "string",
                    "title": "API Key",
                    "description": "DeepSeek API Key",
                    "required": True,
                },
                "base_url": {
                    "type": "string",
                    "title": "API Base URL",
                    "description": "API Base URL",
                    "default": "https://api.deepseek.com/v1",
                },
                "model": {
                    "type": "string",
                    "title": "模型名称",
                    "description": "模型名称 (deepseek-chat, deepseek-coder)",
                    "default": "deepseek-chat",
                },
                **_sampling_properties(),
            },
            "required": ["api_key"],
        },
    ),
    "azure": LLMProviderInfo(
        name="azure",
        display_name="Azure OpenAI",
        description="Azure OpenAI 服务",
        config_schema={
            "type": "object",
            "properties": {
                "api_key": {
                    "type": "string",
                    "title": "API Key",
                    "description": "API Key",
                    "required": True,
                },
                "base_url": {
                    "type": "string",
                    "title": "Azure Endpoint",
                    "description": "Azure Endpoint",
                    "required": True,
                },
                "model": {
                    "type": "string",
                    "title": "部署名称",
                    "description": "部署名称",
                    "required": True,
                },
                "api_version": {
                    "type": "string",
                    "title": "API 版本",
                    "description": "API 版本",
                    "default": "2024-02-15-preview",
                },
                **_sampling_properties(),
            },
            "required": ["api_key", "base_url", "model"],
        },
    ),
}
class LLMProviderFactory:
    """
    Factory for creating LLM clients.
    [AC-ASA-14, AC-ASA-15] Dynamic provider creation.
    """

    # Providers whose wire protocol is OpenAI-compatible and can therefore
    # share the OpenAIClient implementation.
    _OPENAI_COMPATIBLE = ("openai", "ollama", "azure", "deepseek")

    @classmethod
    def get_providers(cls) -> list[LLMProviderInfo]:
        """Return descriptors for every registered LLM provider."""
        return [info for info in LLM_PROVIDERS.values()]

    @classmethod
    def get_provider_info(cls, name: str) -> LLMProviderInfo | None:
        """Look up a provider descriptor by name, or None when unknown."""
        return LLM_PROVIDERS.get(name)

    @classmethod
    def create_client(
        cls,
        provider: str,
        config: dict[str, Any],
    ) -> LLMClient:
        """
        Create an LLM client for the specified provider.
        [AC-ASA-15] Factory method for client creation.
        Args:
            provider: Provider name (openai, ollama, azure)
            config: Provider configuration
        Returns:
            LLMClient instance
        Raises:
            ValueError: If provider is not supported
        """
        # NOTE(review): azure's api_version is not forwarded to the client.
        if provider not in LLM_PROVIDERS or provider not in cls._OPENAI_COMPATIBLE:
            raise ValueError(f"Unsupported LLM provider: {provider}")
        defaults = LLMConfig(
            model=config.get("model", "gpt-4o-mini"),
            max_tokens=config.get("max_tokens", 2048),
            temperature=config.get("temperature", 0.7),
        )
        return OpenAIClient(
            api_key=config.get("api_key"),
            base_url=config.get("base_url"),
            model=config.get("model"),
            default_config=defaults,
        )
class LLMConfigManager:
    """
    Manager for LLM configuration.
    [AC-ASA-16, AC-ASA-17, AC-ASA-18] Configuration management with hot-reload.
    """

    # Substrings identifying config keys whose values are secrets and must
    # never be written to logs.
    _SENSITIVE_KEY_PARTS = ("api_key", "secret", "token", "password")

    def __init__(self):
        # Imported lazily (not at module top) to avoid an import cycle with
        # app.core.config at startup.
        from app.core.config import get_settings
        settings = get_settings()
        self._current_provider: str = settings.llm_provider
        self._current_config: dict[str, Any] = {
            "api_key": settings.llm_api_key,
            "base_url": settings.llm_base_url,
            "model": settings.llm_model,
            "max_tokens": settings.llm_max_tokens,
            "temperature": settings.llm_temperature,
        }
        self._client: LLMClient | None = None

    @classmethod
    def _redact(cls, config: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of *config* with secret values masked, safe for logs."""
        return {
            key: "***"
            if value and any(part in key.lower() for part in cls._SENSITIVE_KEY_PARTS)
            else value
            for key, value in config.items()
        }

    def get_current_config(self) -> dict[str, Any]:
        """
        Get current LLM configuration.
        NOTE(review): includes raw secrets (api_key); callers exposing this
        over an API should mask them — confirm against the admin endpoint.
        """
        return {
            "provider": self._current_provider,
            "config": self._current_config,
        }

    async def update_config(
        self,
        provider: str,
        config: dict[str, Any],
    ) -> bool:
        """
        Update LLM configuration.
        [AC-ASA-16] Hot-reload configuration.
        Args:
            provider: Provider name
            config: New configuration
        Returns:
            True if update successful
        Raises:
            ValueError: If the provider is unknown or a required field is missing.
        """
        if provider not in LLM_PROVIDERS:
            raise ValueError(f"Unsupported LLM provider: {provider}")
        provider_info = LLM_PROVIDERS[provider]
        validated_config = self._validate_config(provider_info, config)
        # Close the old client before swapping config so the next get_client()
        # builds a client from the new settings.
        if self._client:
            await self._client.close()
            self._client = None
        self._current_provider = provider
        self._current_config = validated_config
        logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")
        return True

    def _validate_config(
        self,
        provider_info: LLMProviderInfo,
        config: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Validate configuration against provider schema.

        Keys not declared in the schema are silently dropped; missing optional
        keys are filled from schema defaults; missing required keys raise.
        """
        schema_props = provider_info.config_schema.get("properties", {})
        required_fields = provider_info.config_schema.get("required", [])
        validated = {}
        for key, prop_schema in schema_props.items():
            if key in config:
                validated[key] = config[key]
            elif "default" in prop_schema:
                validated[key] = prop_schema["default"]
            elif key in required_fields:
                raise ValueError(f"Missing required config: {key}")
        return validated

    def get_client(self) -> LLMClient:
        """Get or create LLM client with current config (lazy, cached)."""
        if self._client is None:
            self._client = LLMProviderFactory.create_client(
                self._current_provider,
                self._current_config,
            )
        return self._client

    async def test_connection(
        self,
        test_prompt: str = "你好,请简单介绍一下自己。",
        provider: str | None = None,
        config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Test LLM connection.
        [AC-ASA-17, AC-ASA-18] Connection testing.
        Args:
            test_prompt: Test prompt to send
            provider: Optional provider to test (uses current if not specified)
            config: Optional config to test (uses current if not specified)
        Returns:
            Test result with success status, response, and metrics
        """
        import time
        test_provider = provider or self._current_provider
        # An empty dict also falls back to the current config (original
        # truthiness semantics preserved).
        test_config = config if config else self._current_config
        # Security fix: the previous log line wrote the full config --
        # including API keys -- to the log at INFO; mask secrets first.
        logger.info(
            f"[AC-ASA-17] Test connection: provider={test_provider}, "
            f"config={self._redact(test_config)}"
        )
        if test_provider not in LLM_PROVIDERS:
            return {
                "success": False,
                "error": f"Unsupported provider: {test_provider}",
            }
        try:
            client = LLMProviderFactory.create_client(test_provider, test_config)
            start_time = time.time()
            response = await client.generate(
                messages=[{"role": "user", "content": test_prompt}],
            )
            latency_ms = (time.time() - start_time) * 1000
            await client.close()
            return {
                "success": True,
                "response": response.content,
                "latency_ms": round(latency_ms, 2),
                "prompt_tokens": response.usage.get("prompt_tokens", 0),
                "completion_tokens": response.usage.get("completion_tokens", 0),
                "total_tokens": response.usage.get("total_tokens", 0),
                "model": response.model,
                "message": f"连接成功,模型: {response.model}",
            }
        except Exception as e:
            logger.error(f"[AC-ASA-18] LLM test failed: {e}")
            return {
                "success": False,
                "error": str(e),
                "message": f"连接失败: {str(e)}",
            }

    async def close(self) -> None:
        """Close the current client, if any."""
        if self._client:
            await self._client.close()
            self._client = None
# Process-wide singleton, created lazily on first access.
_llm_config_manager: LLMConfigManager | None = None


def get_llm_config_manager() -> LLMConfigManager:
    """Get or create LLM config manager instance."""
    global _llm_config_manager
    if _llm_config_manager is None:
        _llm_config_manager = LLMConfigManager()
    return _llm_config_manager

View File

@ -0,0 +1,333 @@
"""
OpenAI-compatible LLM client implementation.
[AC-AISVC-02, AC-AISVC-06] Concrete implementation using httpx for OpenAI API.
Design reference: design.md Section 8.1 - LLMClient interface
- Uses langchain-openai or official SDK pattern
- Supports generate and stream_generate
"""
import json
import logging
from typing import Any, AsyncGenerator
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.config import get_settings
from app.core.exceptions import AIServiceException, ErrorCode, ServiceUnavailableException, TimeoutException
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
logger = logging.getLogger(__name__)
class LLMException(AIServiceException):
    """Exception raised when LLM operations fail."""

    def __init__(self, message: str, details: list[dict] | None = None):
        # 503: the upstream model service failed or is unreachable, so the
        # caller may retry the request.
        super().__init__(
            code=ErrorCode.LLM_ERROR,
            message=message,
            status_code=503,
            details=details,
        )
class OpenAIClient(LLMClient):
    """
    OpenAI-compatible LLM client.
    [AC-AISVC-02, AC-AISVC-06] Implements LLMClient interface for OpenAI API.
    Supports:
    - OpenAI API (official)
    - OpenAI-compatible endpoints (Azure, local models, etc.)
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model: str | None = None,
        default_config: LLMConfig | None = None,
    ):
        """
        Args:
            api_key: API key; falls back to settings.llm_api_key.
            base_url: API root URL (trailing "/" stripped); falls back to settings.
            model: Model name; falls back to settings.llm_model.
            default_config: Per-call defaults; built from settings when omitted.
        """
        settings = get_settings()
        self._api_key = api_key or settings.llm_api_key
        self._base_url = (base_url or settings.llm_base_url).rstrip("/")
        self._model = model or settings.llm_model
        self._default_config = default_config or LLMConfig(
            model=self._model,
            max_tokens=settings.llm_max_tokens,
            temperature=settings.llm_temperature,
            timeout_seconds=settings.llm_timeout_seconds,
            max_retries=settings.llm_max_retries,
        )
        self._client: httpx.AsyncClient | None = None  # lazily created, reused

    def _get_client(self, timeout_seconds: int) -> httpx.AsyncClient:
        """
        Get or create HTTP client.
        NOTE(review): the client is cached with the first caller's timeout;
        a later call with a different timeout reuses the cached client.
        """
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(timeout_seconds),
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
            )
        return self._client

    def _build_request_body(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig,
        stream: bool = False,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build request body for OpenAI API; extra_params then kwargs override."""
        body: dict[str, Any] = {
            "model": config.model,
            "messages": messages,
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "stream": stream,
        }
        body.update(config.extra_params)
        body.update(kwargs)
        return body

    def _log_full_prompt(
        self,
        tag: str,
        messages: list[dict[str, str]],
        banner: str,
    ) -> None:
        """
        [AC-AISVC-02, AC-AISVC-06] Log the complete prompt sent to the model.
        Shared by generate/stream_generate so both paths log identically.
        Deliberately verbose at INFO per requirements; prompts may contain
        user-provided content.
        """
        logger.info(f"[{tag}] ========== {banner} ==========")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            logger.info(f"[{tag}] [{i}] role={role}, content_length={len(content)}")
            logger.info(f"[{tag}] [{i}] content:\n{content}")
        logger.info(f"[{tag}] ======================================")

    @retry(
        # Bug fix: the body converts httpx.TimeoutException into the
        # app-level TimeoutException before it can propagate, so the previous
        # predicate (httpx.TimeoutException only) never matched and no retry
        # ever ran. Retry on both; reraise=True surfaces TimeoutException
        # (not tenacity.RetryError) to callers once attempts are exhausted.
        # NOTE(review): the attempt count is fixed at 3; LLMConfig.max_retries
        # is not consulted here.
        retry=retry_if_exception_type((httpx.TimeoutException, TimeoutException)),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        reraise=True,
    )
    async def generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Generate a non-streaming response.
        [AC-AISVC-02] Returns complete response for ChatResponse.
        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.
        Returns:
            LLMResponse with generated content and metadata.
        Raises:
            LLMException: If generation fails.
            TimeoutException: If request times out (after retries).
        """
        effective_config = config or self._default_config
        client = self._get_client(effective_config.timeout_seconds)
        body = self._build_request_body(messages, effective_config, stream=False, **kwargs)
        logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}")
        self._log_full_prompt("AC-AISVC-02", messages, "FULL PROMPT TO AI")
        try:
            response = await client.post(
                f"{self._base_url}/chat/completions",
                json=body,
            )
            response.raise_for_status()
            data = response.json()
        except httpx.TimeoutException as e:
            logger.error(f"[AC-AISVC-02] LLM request timeout: {e}")
            raise TimeoutException(message=f"LLM request timed out: {e}")
        except httpx.HTTPStatusError as e:
            logger.error(f"[AC-AISVC-02] LLM API error: {e}")
            error_detail = self._parse_error_response(e.response)
            raise LLMException(
                message=f"LLM API error: {error_detail}",
                details=[{"status_code": e.response.status_code, "response": error_detail}],
            )
        except json.JSONDecodeError as e:
            logger.error(f"[AC-AISVC-02] Failed to parse LLM response: {e}")
            raise LLMException(message=f"Failed to parse LLM response: {e}")
        try:
            choice = data["choices"][0]
            content = choice["message"]["content"]
            usage = data.get("usage", {})
            finish_reason = choice.get("finish_reason", "stop")
            logger.info(
                f"[AC-AISVC-02] Generated response: "
                f"tokens={usage.get('total_tokens', 'N/A')}, "
                f"finish_reason={finish_reason}"
            )
            return LLMResponse(
                content=content,
                model=data.get("model", effective_config.model),
                usage=usage,
                finish_reason=finish_reason,
                metadata={"raw_response": data},
            )
        except (KeyError, IndexError) as e:
            logger.error(f"[AC-AISVC-02] Unexpected LLM response format: {e}")
            raise LLMException(
                message=f"Unexpected LLM response format: {e}",
                details=[{"response": str(data)}],
            )

    async def stream_generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[LLMStreamChunk, None]:
        """
        Generate a streaming response.
        [AC-AISVC-06, AC-AISVC-07] Yields incremental chunks for SSE.
        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            **kwargs: Additional provider-specific parameters.
        Yields:
            LLMStreamChunk with incremental content.
        Raises:
            LLMException: If generation fails.
            TimeoutException: If request times out.
        """
        effective_config = config or self._default_config
        client = self._get_client(effective_config.timeout_seconds)
        body = self._build_request_body(messages, effective_config, stream=True, **kwargs)
        logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}")
        self._log_full_prompt("AC-AISVC-06", messages, "FULL PROMPT TO AI (STREAMING)")
        try:
            async with client.stream(
                "POST",
                f"{self._base_url}/chat/completions",
                json=body,
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    # SSE frames: skip keep-alive blanks and the terminator.
                    if not line or line == "data: [DONE]":
                        continue
                    if line.startswith("data: "):
                        json_str = line[6:]
                        try:
                            chunk_data = json.loads(json_str)
                            chunk = self._parse_stream_chunk(chunk_data, effective_config.model)
                            if chunk:
                                yield chunk
                        except json.JSONDecodeError as e:
                            # A malformed frame should not abort the stream.
                            logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}")
                            continue
        except httpx.TimeoutException as e:
            logger.error(f"[AC-AISVC-06] LLM streaming request timeout: {e}")
            raise TimeoutException(message=f"LLM streaming request timed out: {e}")
        except httpx.HTTPStatusError as e:
            logger.error(f"[AC-AISVC-06] LLM streaming API error: {e}")
            error_detail = self._parse_error_response(e.response)
            raise LLMException(
                message=f"LLM streaming API error: {error_detail}",
                details=[{"status_code": e.response.status_code, "response": error_detail}],
            )
        logger.info(f"[AC-AISVC-06] Streaming generation completed")

    def _parse_stream_chunk(
        self,
        data: dict[str, Any],
        model: str,
    ) -> LLMStreamChunk | None:
        """
        Parse a streaming chunk from OpenAI API.

        Returns None for frames carrying neither content nor a finish reason
        (e.g. role-only deltas), which callers should simply skip.
        """
        try:
            choices = data.get("choices", [])
            if not choices:
                return None
            delta = choices[0].get("delta", {})
            content = delta.get("content", "")
            finish_reason = choices[0].get("finish_reason")
            if not content and not finish_reason:
                return None
            return LLMStreamChunk(
                delta=content,
                model=data.get("model", model),
                finish_reason=finish_reason,
                metadata={"raw_chunk": data},
            )
        except (KeyError, IndexError) as e:
            logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}")
            return None

    def _parse_error_response(self, response: httpx.Response) -> str:
        """Extract a human-readable error message from an API error body."""
        try:
            data = response.json()
            if "error" in data:
                error = data["error"]
                if isinstance(error, dict):
                    return error.get("message", str(error))
                return str(error)
            return response.text
        except Exception:
            # Non-JSON error body: fall back to the raw text.
            return response.text

    async def close(self) -> None:
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None
# Process-wide default client, created lazily on first access.
_llm_client: OpenAIClient | None = None


def get_llm_client() -> OpenAIClient:
    """Get or create LLM client instance."""
    global _llm_client
    if _llm_client is None:
        _llm_client = OpenAIClient()
    return _llm_client


async def close_llm_client() -> None:
    """Close the global LLM client."""
    # Called on application shutdown to release the pooled HTTP connections.
    global _llm_client
    if _llm_client:
        await _llm_client.close()
        _llm_client = None

View File

@ -0,0 +1,170 @@
"""
Memory service for AI Service.
[AC-AISVC-13] Session-based memory management with tenant isolation.
"""
import logging
from typing import Sequence
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col
from app.models.entities import ChatMessage, ChatMessageCreate, ChatSession, ChatSessionCreate
logger = logging.getLogger(__name__)
class MemoryService:
"""
[AC-AISVC-13] Memory service for session-based conversation history.
All operations are scoped by (tenant_id, session_id) for multi-tenant isolation.
"""
def __init__(self, session: AsyncSession):
    # The caller owns the session and its transaction: this service only
    # queries and flushes pending rows, never commits.
    self._session = session
async def get_or_create_session(
self,
tenant_id: str,
session_id: str,
channel_type: str | None = None,
metadata: dict | None = None,
) -> ChatSession:
"""
[AC-AISVC-13] Get existing session or create a new one.
Ensures tenant isolation by querying with tenant_id.
"""
stmt = select(ChatSession).where(
ChatSession.tenant_id == tenant_id,
ChatSession.session_id == session_id,
)
result = await self._session.execute(stmt)
existing_session = result.scalar_one_or_none()
if existing_session:
logger.info(
f"[AC-AISVC-13] Found existing session: tenant={tenant_id}, session={session_id}"
)
return existing_session
new_session = ChatSession(
tenant_id=tenant_id,
session_id=session_id,
channel_type=channel_type,
metadata_=metadata,
)
self._session.add(new_session)
await self._session.flush()
logger.info(
f"[AC-AISVC-13] Created new session: tenant={tenant_id}, session={session_id}"
)
return new_session
async def load_history(
self,
tenant_id: str,
session_id: str,
limit: int | None = None,
) -> Sequence[ChatMessage]:
"""
[AC-AISVC-13] Load conversation history for a session.
All queries are filtered by tenant_id to ensure isolation.
"""
stmt = (
select(ChatMessage)
.where(
ChatMessage.tenant_id == tenant_id,
ChatMessage.session_id == session_id,
)
.order_by(col(ChatMessage.created_at).asc())
)
if limit:
stmt = stmt.limit(limit)
result = await self._session.execute(stmt)
messages = result.scalars().all()
logger.info(
f"[AC-AISVC-13] Loaded {len(messages)} messages for tenant={tenant_id}, session={session_id}"
)
return messages
async def append_message(
self,
tenant_id: str,
session_id: str,
role: str,
content: str,
) -> ChatMessage:
"""
[AC-AISVC-13] Append a message to the session history.
Message is scoped by tenant_id for isolation.
"""
message = ChatMessage(
tenant_id=tenant_id,
session_id=session_id,
role=role,
content=content,
)
self._session.add(message)
await self._session.flush()
logger.info(
f"[AC-AISVC-13] Appended message: tenant={tenant_id}, session={session_id}, role={role}"
)
return message
async def append_messages(
self,
tenant_id: str,
session_id: str,
messages: list[dict[str, str]],
) -> list[ChatMessage]:
"""
[AC-AISVC-13] Append multiple messages to the session history.
Used for batch insertion of conversation turns.
"""
chat_messages = []
for msg in messages:
message = ChatMessage(
tenant_id=tenant_id,
session_id=session_id,
role=msg["role"],
content=msg["content"],
)
self._session.add(message)
chat_messages.append(message)
await self._session.flush()
logger.info(
f"[AC-AISVC-13] Appended {len(chat_messages)} messages for tenant={tenant_id}, session={session_id}"
)
return chat_messages
async def clear_history(self, tenant_id: str, session_id: str) -> int:
"""
[AC-AISVC-13] Clear all messages for a session.
Only affects messages within the tenant's scope.
"""
stmt = select(ChatMessage).where(
ChatMessage.tenant_id == tenant_id,
ChatMessage.session_id == session_id,
)
result = await self._session.execute(stmt)
messages = result.scalars().all()
count = 0
for message in messages:
await self._session.delete(message)
count += 1
await self._session.flush()
logger.info(
f"[AC-AISVC-13] Cleared {count} messages for tenant={tenant_id}, session={session_id}"
)
return count

View File

@ -0,0 +1,689 @@
"""
Orchestrator service for AI Service.
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Core orchestration logic for chat generation.
Design reference: design.md Section 2.2 - 关键数据流
1. Memory.load(tenantId, sessionId)
2. merge_context(local_history, external_history)
3. Retrieval.retrieve(query, tenantId, channelType, metadata)
4. build_prompt(merged_history, retrieved_docs, currentMessage)
5. LLM.generate(...) (non-streaming) or LLM.stream_generate(...) (streaming)
6. compute_confidence(...)
7. Memory.append(tenantId, sessionId, user/assistant messages)
8. Return ChatResponse (or output via SSE)
RAG Optimization (rag-optimization/spec.md):
- Two-stage retrieval with Matryoshka dimensions
- RRF hybrid ranking
- Optimized prompt engineering
"""
import logging
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator
from sse_starlette.sse import ServerSentEvent
from app.core.config import get_settings
from app.core.prompts import SYSTEM_PROMPT, format_evidence_for_prompt
from app.core.sse import (
create_error_event,
create_final_event,
create_message_event,
SSEStateMachine,
)
from app.models import ChatRequest, ChatResponse
from app.services.confidence import ConfidenceCalculator, ConfidenceResult
from app.services.context import ContextMerger, MergedContext
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse
from app.services.memory import MemoryService
from app.services.retrieval.base import BaseRetriever, RetrievalContext, RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class OrchestratorConfig:
    """
    Configuration for OrchestratorService.
    [AC-AISVC-01] Centralized configuration for orchestration.
    """

    # Token budget for the merged (local + external) conversation history.
    max_history_tokens: int = 4000
    # Token budget for retrieved evidence appended to the system prompt.
    max_evidence_tokens: int = 2000
    # Base system prompt; evidence text is appended when retrieval has hits.
    system_prompt: str = SYSTEM_PROMPT
    # Master switch for the RAG retrieval step (step 3 of the pipeline).
    enable_rag: bool = True
    # Presumably selects the two-stage/RRF OptimizedRetriever at wiring
    # time — not read within this module's visible code; TODO confirm.
    use_optimized_retriever: bool = True
@dataclass
class GenerationContext:
    """
    [AC-AISVC-01, AC-AISVC-02] Context accumulated during generation pipeline.
    Contains all intermediate results for diagnostics and response building.
    """

    # Request identity / input — set once at pipeline start.
    tenant_id: str
    session_id: str
    current_message: str
    channel_type: str
    request_metadata: dict[str, Any] | None = None
    # Step 1 output: history loaded from Memory as {"role", "content"} dicts.
    local_history: list[dict[str, str]] = field(default_factory=list)
    # Step 2 output: deduplicated + truncated history.
    merged_context: MergedContext | None = None
    # Step 3 output: RAG hits (None when RAG disabled or not configured).
    retrieval_result: RetrievalResult | None = None
    # Step 5 output: LLM (or fallback) response.
    llm_response: LLMResponse | None = None
    # Step 6 output: confidence / transfer decision.
    confidence_result: ConfidenceResult | None = None
    # Free-form per-step diagnostics, echoed back in response metadata.
    diagnostics: dict[str, Any] = field(default_factory=dict)
class OrchestratorService:
    """
    [AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Orchestrator for chat generation.
    Coordinates memory, retrieval, and LLM components.
    SSE Event Flow (per design.md Section 6.2):
    - message* (0 or more) -> final (exactly 1) -> close
    - OR message* (0 or more) -> error (exactly 1) -> close
    """

    def __init__(
        self,
        llm_client: LLMClient | None = None,
        memory_service: MemoryService | None = None,
        retriever: BaseRetriever | None = None,
        context_merger: ContextMerger | None = None,
        confidence_calculator: ConfidenceCalculator | None = None,
        config: OrchestratorConfig | None = None,
    ):
        """
        Initialize orchestrator with optional dependencies for DI.

        Any dependency left as None is either built from settings here
        (merger, confidence, config) or disables its pipeline step at
        runtime (llm_client, memory_service, retriever).

        Args:
            llm_client: LLM client for generation
            memory_service: Memory service for session history
            retriever: Retriever for RAG
            context_merger: Context merger for history deduplication
            confidence_calculator: Confidence calculator for response scoring
            config: Orchestrator configuration
        """
        settings = get_settings()
        self._llm_client = llm_client
        self._memory_service = memory_service
        self._retriever = retriever
        # getattr guards keep construction working even when a setting is
        # absent from the settings model; defaults mirror OrchestratorConfig.
        self._context_merger = context_merger or ContextMerger(
            max_history_tokens=getattr(settings, "max_history_tokens", 4000)
        )
        self._confidence_calculator = confidence_calculator or ConfidenceCalculator()
        self._config = config or OrchestratorConfig(
            max_history_tokens=getattr(settings, "max_history_tokens", 4000),
            max_evidence_tokens=getattr(settings, "rag_max_evidence_tokens", 2000),
            enable_rag=True,
        )
        self._llm_config = LLMConfig(
            model=getattr(settings, "llm_model", "gpt-4o-mini"),
            max_tokens=getattr(settings, "llm_max_tokens", 2048),
            temperature=getattr(settings, "llm_temperature", 0.7),
            timeout_seconds=getattr(settings, "llm_timeout_seconds", 30),
            max_retries=getattr(settings, "llm_max_retries", 3),
        )

    async def generate(
        self,
        tenant_id: str,
        request: ChatRequest,
    ) -> ChatResponse:
        """
        Generate a non-streaming response.
        [AC-AISVC-01, AC-AISVC-02] Complete generation pipeline.
        Pipeline (per design.md Section 2.2):
        1. Load local history from Memory
        2. Merge with external history (dedup + truncate)
        3. RAG retrieval (optional)
        4. Build prompt with context and evidence
        5. LLM generation
        6. Calculate confidence
        7. Save messages to Memory
        8. Return ChatResponse

        On any exception the method degrades gracefully: it returns a
        zero-confidence transfer response instead of raising.
        """
        logger.info(
            f"[AC-AISVC-01] Starting generation for tenant={tenant_id}, "
            f"session={request.session_id}, channel_type={request.channel_type}, "
            f"current_message={request.current_message[:100]}..."
        )
        logger.info(
            f"[AC-AISVC-01] Config: enable_rag={self._config.enable_rag}, "
            f"use_optimized_retriever={self._config.use_optimized_retriever}, "
            f"llm_client={'configured' if self._llm_client else 'NOT configured'}, "
            f"retriever={'configured' if self._retriever else 'NOT configured'}"
        )
        ctx = GenerationContext(
            tenant_id=tenant_id,
            session_id=request.session_id,
            current_message=request.current_message,
            channel_type=request.channel_type.value,
            request_metadata=request.metadata,
        )
        try:
            await self._load_local_history(ctx)
            await self._merge_context(ctx, request.history)
            # Retrieval runs only when both the flag is on AND a retriever
            # instance was injected.
            if self._config.enable_rag and self._retriever:
                await self._retrieve_evidence(ctx)
            await self._generate_response(ctx)
            self._calculate_confidence(ctx)
            await self._save_messages(ctx)
            return self._build_response(ctx)
        except Exception as e:
            logger.error(f"[AC-AISVC-01] Generation failed: {e}")
            # Fail-safe: hand off to a human agent with diagnostics attached.
            return ChatResponse(
                reply="抱歉,服务暂时不可用,请稍后重试或联系人工客服。",
                confidence=0.0,
                should_transfer=True,
                transfer_reason=f"服务异常: {str(e)}",
                metadata={"error": str(e), "diagnostics": ctx.diagnostics},
            )

    async def _load_local_history(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-13] Load local history from Memory service.
        Step 1 of the generation pipeline.

        Failures are swallowed (logged + recorded in diagnostics) so that a
        memory outage does not abort generation.
        """
        if not self._memory_service:
            logger.info("[AC-AISVC-13] No memory service configured, skipping history load")
            ctx.diagnostics["memory_enabled"] = False
            return
        try:
            messages = await self._memory_service.load_history(
                tenant_id=ctx.tenant_id,
                session_id=ctx.session_id,
            )
            ctx.local_history = [
                {"role": msg.role, "content": msg.content}
                for msg in messages
            ]
            ctx.diagnostics["memory_enabled"] = True
            ctx.diagnostics["local_history_count"] = len(ctx.local_history)
            logger.info(
                f"[AC-AISVC-13] Loaded {len(ctx.local_history)} messages from memory "
                f"for tenant={ctx.tenant_id}, session={ctx.session_id}"
            )
        except Exception as e:
            logger.warning(f"[AC-AISVC-13] Failed to load history: {e}")
            ctx.diagnostics["memory_error"] = str(e)

    async def _merge_context(
        self,
        ctx: GenerationContext,
        external_history: list | None,
    ) -> None:
        """
        [AC-AISVC-14, AC-AISVC-15] Merge local and external history.
        Step 2 of the generation pipeline.
        Design reference: design.md Section 7
        - Deduplication based on fingerprint
        - Truncation to fit token budget
        """
        external_messages = None
        if external_history:
            # External history arrives as model objects with enum roles;
            # normalize to the same dict shape as local history.
            external_messages = [
                {"role": msg.role.value, "content": msg.content}
                for msg in external_history
            ]
        ctx.merged_context = self._context_merger.merge_and_truncate(
            local_history=ctx.local_history,
            external_history=external_messages,
            max_tokens=self._config.max_history_tokens,
        )
        ctx.diagnostics["merged_context"] = {
            "local_count": ctx.merged_context.local_count,
            "external_count": ctx.merged_context.external_count,
            "duplicates_skipped": ctx.merged_context.duplicates_skipped,
            "truncated_count": ctx.merged_context.truncated_count,
            "total_tokens": ctx.merged_context.total_tokens,
        }
        logger.info(
            f"[AC-AISVC-14, AC-AISVC-15] Context merged: "
            f"local={ctx.merged_context.local_count}, "
            f"external={ctx.merged_context.external_count}, "
            f"tokens={ctx.merged_context.total_tokens}"
        )

    async def _retrieve_evidence(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-16, AC-AISVC-17] RAG retrieval for evidence.
        Step 3 of the generation pipeline.

        On failure, stores an empty RetrievalResult with the error in its
        diagnostics, so downstream steps see "no evidence" rather than crash.
        """
        logger.info(
            f"[AC-AISVC-16] Starting retrieval: tenant={ctx.tenant_id}, "
            f"query={ctx.current_message[:100]}..., retriever={type(self._retriever).__name__ if self._retriever else 'None'}"
        )
        try:
            retrieval_ctx = RetrievalContext(
                tenant_id=ctx.tenant_id,
                query=ctx.current_message,
                session_id=ctx.session_id,
                channel_type=ctx.channel_type,
                metadata=ctx.request_metadata,
            )
            ctx.retrieval_result = await self._retriever.retrieve(retrieval_ctx)
            ctx.diagnostics["retrieval"] = {
                "hit_count": ctx.retrieval_result.hit_count,
                "max_score": ctx.retrieval_result.max_score,
                "is_empty": ctx.retrieval_result.is_empty,
            }
            logger.info(
                f"[AC-AISVC-16, AC-AISVC-17] Retrieval complete: "
                f"hits={ctx.retrieval_result.hit_count}, "
                f"max_score={ctx.retrieval_result.max_score:.3f}, "
                f"is_empty={ctx.retrieval_result.is_empty}"
            )
            if ctx.retrieval_result.hit_count > 0:
                # Log the top-3 hits for observability.
                for i, hit in enumerate(ctx.retrieval_result.hits[:3]):
                    logger.info(
                        f"[AC-AISVC-16] Hit {i+1}: score={hit.score:.3f}, "
                        f"text_preview={hit.text[:100]}..."
                    )
        except Exception as e:
            logger.error(f"[AC-AISVC-16] Retrieval failed with exception: {e}", exc_info=True)
            ctx.retrieval_result = RetrievalResult(
                hits=[],
                diagnostics={"error": str(e)},
            )
            ctx.diagnostics["retrieval_error"] = str(e)

    async def _generate_response(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-02] Generate response using LLM.
        Step 4-5 of the generation pipeline.

        Falls back to a canned reply (model="fallback") when no client is
        configured or the client raises; the reason is recorded in
        ctx.diagnostics["fallback_reason"].
        """
        messages = self._build_llm_messages(ctx)
        logger.info(
            f"[AC-AISVC-02] Building LLM messages: count={len(messages)}, "
            f"has_retrieval_result={ctx.retrieval_result is not None}, "
            f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else 'N/A'}, "
            f"llm_client={'configured' if self._llm_client else 'NOT configured'}"
        )
        if not self._llm_client:
            logger.warning(
                f"[AC-AISVC-02] No LLM client configured, using fallback. "
                f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}"
            )
            ctx.llm_response = LLMResponse(
                content=self._fallback_response(ctx),
                model="fallback",
                usage={},
                finish_reason="fallback",
            )
            ctx.diagnostics["llm_mode"] = "fallback"
            ctx.diagnostics["fallback_reason"] = "no_llm_client"
            return
        try:
            ctx.llm_response = await self._llm_client.generate(
                messages=messages,
                config=self._llm_config,
            )
            ctx.diagnostics["llm_mode"] = "live"
            ctx.diagnostics["llm_model"] = ctx.llm_response.model
            ctx.diagnostics["llm_usage"] = ctx.llm_response.usage
            logger.info(
                f"[AC-AISVC-02] LLM response generated: "
                f"model={ctx.llm_response.model}, "
                f"tokens={ctx.llm_response.usage}, "
                f"content_preview={ctx.llm_response.content[:100]}..."
            )
        except Exception as e:
            logger.error(
                f"[AC-AISVC-02] LLM generation failed: {e}, "
                f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}",
                exc_info=True
            )
            ctx.llm_response = LLMResponse(
                content=self._fallback_response(ctx),
                model="fallback",
                usage={},
                finish_reason="error",
                metadata={"error": str(e)},
            )
            ctx.diagnostics["llm_error"] = str(e)
            ctx.diagnostics["llm_mode"] = "fallback"
            ctx.diagnostics["fallback_reason"] = f"llm_error: {str(e)}"

    def _build_llm_messages(self, ctx: GenerationContext) -> list[dict[str, str]]:
        """
        [AC-AISVC-02] Build messages for LLM including system prompt and evidence.

        Order: system (prompt + optional evidence) -> merged history ->
        current user message.

        NOTE(review): the complete prompt, including full user content, is
        logged at INFO level below — may contain PII and inflate log volume;
        confirm this is intended for [AC-AISVC-02] observability or demote
        to DEBUG.
        """
        messages = []
        system_content = self._config.system_prompt
        if ctx.retrieval_result and not ctx.retrieval_result.is_empty:
            evidence_text = self._format_evidence(ctx.retrieval_result)
            system_content += f"\n\n知识库参考内容:\n{evidence_text}"
        messages.append({"role": "system", "content": system_content})
        if ctx.merged_context and ctx.merged_context.messages:
            messages.extend(ctx.merged_context.messages)
        messages.append({"role": "user", "content": ctx.current_message})
        logger.info(
            f"[AC-AISVC-02] Built {len(messages)} messages for LLM: "
            f"system_len={len(system_content)}, history_count={len(ctx.merged_context.messages) if ctx.merged_context else 0}"
        )
        logger.debug(f"[AC-AISVC-02] System prompt preview: {system_content[:500]}...")
        logger.info(f"[AC-AISVC-02] ========== ORCHESTRATOR FULL PROMPT ==========")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            logger.info(f"[AC-AISVC-02] [{i}] role={role}, content_length={len(content)}")
            logger.info(f"[AC-AISVC-02] [{i}] content:\n{content}")
        logger.info(f"[AC-AISVC-02] ==============================================")
        return messages

    def _format_evidence(self, retrieval_result: RetrievalResult) -> str:
        """
        [AC-AISVC-17] Format retrieval hits as evidence text.
        Uses shared prompt configuration for consistency.
        """
        return format_evidence_for_prompt(retrieval_result.hits, max_results=5, max_content_length=500)

    def _fallback_response(self, ctx: GenerationContext) -> str:
        """
        [AC-AISVC-17] Generate fallback response when LLM is unavailable.

        Wording varies depending on whether retrieval found anything.
        """
        if ctx.retrieval_result and not ctx.retrieval_result.is_empty:
            return (
                "根据知识库信息,我找到了一些相关内容,"
                "但暂时无法生成完整回复。建议您稍后重试或联系人工客服。"
            )
        return (
            "抱歉,我暂时无法处理您的请求。"
            "请稍后重试或联系人工客服获取帮助。"
        )

    def _calculate_confidence(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculate confidence score.
        Step 6 of the generation pipeline.
        """
        if ctx.retrieval_result:
            evidence_tokens = 0
            if not ctx.retrieval_result.is_empty:
                # Heuristic token count: whitespace-split words * 2 —
                # presumably to compensate for sub-word/CJK tokenization;
                # TODO confirm the factor against the real tokenizer.
                evidence_tokens = sum(
                    len(hit.text.split()) * 2
                    for hit in ctx.retrieval_result.hits
                )
            ctx.confidence_result = self._confidence_calculator.calculate_confidence(
                retrieval_result=ctx.retrieval_result,
                evidence_tokens=evidence_tokens,
            )
        else:
            ctx.confidence_result = self._confidence_calculator.calculate_confidence_no_retrieval()
        ctx.diagnostics["confidence"] = {
            "score": ctx.confidence_result.confidence,
            "should_transfer": ctx.confidence_result.should_transfer,
            "is_insufficient": ctx.confidence_result.is_retrieval_insufficient,
        }
        logger.info(
            f"[AC-AISVC-17, AC-AISVC-18] Confidence calculated: "
            f"{ctx.confidence_result.confidence:.3f}, "
            f"should_transfer={ctx.confidence_result.should_transfer}"
        )

    async def _save_messages(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-13] Save user and assistant messages to Memory.
        Step 7 of the generation pipeline.

        Best-effort: save failures are logged and recorded in diagnostics
        but never abort the response.
        """
        if not self._memory_service:
            logger.info("[AC-AISVC-13] No memory service configured, skipping save")
            return
        try:
            await self._memory_service.get_or_create_session(
                tenant_id=ctx.tenant_id,
                session_id=ctx.session_id,
                channel_type=ctx.channel_type,
                metadata=ctx.request_metadata,
            )
            messages_to_save = [
                {"role": "user", "content": ctx.current_message},
            ]
            if ctx.llm_response:
                messages_to_save.append({
                    "role": "assistant",
                    "content": ctx.llm_response.content,
                })
            await self._memory_service.append_messages(
                tenant_id=ctx.tenant_id,
                session_id=ctx.session_id,
                messages=messages_to_save,
            )
            ctx.diagnostics["messages_saved"] = len(messages_to_save)
            logger.info(
                f"[AC-AISVC-13] Saved {len(messages_to_save)} messages "
                f"for tenant={ctx.tenant_id}, session={ctx.session_id}"
            )
        except Exception as e:
            logger.warning(f"[AC-AISVC-13] Failed to save messages: {e}")
            ctx.diagnostics["save_error"] = str(e)

    def _build_response(self, ctx: GenerationContext) -> ChatResponse:
        """
        [AC-AISVC-02] Build final ChatResponse from generation context.
        Step 8 of the generation pipeline.

        Defensive defaults apply when earlier steps left fields unset
        (confidence 0.5, should_transfer True).
        """
        reply = ctx.llm_response.content if ctx.llm_response else self._fallback_response(ctx)
        confidence = ctx.confidence_result.confidence if ctx.confidence_result else 0.5
        should_transfer = ctx.confidence_result.should_transfer if ctx.confidence_result else True
        transfer_reason = ctx.confidence_result.transfer_reason if ctx.confidence_result else None
        response_metadata = {
            "session_id": ctx.session_id,
            "channel_type": ctx.channel_type,
            "diagnostics": ctx.diagnostics,
        }
        return ChatResponse(
            reply=reply,
            confidence=confidence,
            should_transfer=should_transfer,
            transfer_reason=transfer_reason,
            metadata=response_metadata,
        )

    async def generate_stream(
        self,
        tenant_id: str,
        request: ChatRequest,
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        Generate a streaming response.
        [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence.
        SSE Event Sequence (per design.md Section 6.2):
        1. message events (multiple) - each with incremental delta
        2. final event (exactly 1) - with complete response
        3. connection close
        OR on error:
        1. message events (0 or more)
        2. error event (exactly 1)
        3. connection close

        The SSEStateMachine enforces the above ordering; the full reply is
        re-assembled from message deltas so confidence and memory use the
        same text the client received.
        """
        logger.info(
            f"[AC-AISVC-06] Starting streaming generation for tenant={tenant_id}, "
            f"session={request.session_id}"
        )
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()
        ctx = GenerationContext(
            tenant_id=tenant_id,
            session_id=request.session_id,
            current_message=request.current_message,
            channel_type=request.channel_type.value,
            request_metadata=request.metadata,
        )
        try:
            await self._load_local_history(ctx)
            await self._merge_context(ctx, request.history)
            if self._config.enable_rag and self._retriever:
                await self._retrieve_evidence(ctx)
            full_reply = ""
            if self._llm_client:
                async for event in self._stream_from_llm(ctx, state_machine):
                    if event.event == "message":
                        full_reply += self._extract_delta_from_event(event)
                    yield event
            else:
                # Demo/testing path: no LLM client wired in.
                async for event in self._stream_mock_response(ctx, state_machine):
                    if event.event == "message":
                        full_reply += self._extract_delta_from_event(event)
                    yield event
            if ctx.llm_response is None:
                # Synthesize an LLMResponse so confidence/save steps work.
                ctx.llm_response = LLMResponse(
                    content=full_reply,
                    model="streaming",
                    usage={},
                    finish_reason="stop",
                )
            self._calculate_confidence(ctx)
            await self._save_messages(ctx)
            if await state_machine.transition_to_final():
                yield create_final_event(
                    reply=full_reply,
                    confidence=ctx.confidence_result.confidence if ctx.confidence_result else 0.5,
                    should_transfer=ctx.confidence_result.should_transfer if ctx.confidence_result else False,
                    transfer_reason=ctx.confidence_result.transfer_reason if ctx.confidence_result else None,
                )
        except Exception as e:
            logger.error(f"[AC-AISVC-09] Error during streaming: {e}")
            if await state_machine.transition_to_error():
                yield create_error_event(
                    code="GENERATION_ERROR",
                    message=str(e),
                )
        finally:
            await state_machine.close()

    async def _stream_from_llm(
        self,
        ctx: GenerationContext,
        state_machine: SSEStateMachine,
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        [AC-AISVC-07] Stream from LLM client, wrapping each chunk as message event.
        """
        messages = self._build_llm_messages(ctx)
        async for chunk in self._llm_client.stream_generate(messages, self._llm_config):
            if not state_machine.can_send_message():
                break
            if chunk.delta:
                logger.debug(f"[AC-AISVC-07] Yielding message event with delta: {chunk.delta[:50]}...")
                yield create_message_event(delta=chunk.delta)
            if chunk.finish_reason:
                logger.info(f"[AC-AISVC-07] LLM stream finished with reason: {chunk.finish_reason}")
                break

    async def _stream_mock_response(
        self,
        ctx: GenerationContext,
        state_machine: SSEStateMachine,
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        [AC-AISVC-07] Mock streaming response for demo/testing purposes.
        Simulates LLM-style incremental output.
        """
        import asyncio
        reply_parts = ["收到", "您的", "消息:", f" {ctx.current_message}"]
        for part in reply_parts:
            if not state_machine.can_send_message():
                break
            logger.debug(f"[AC-AISVC-07] Yielding mock message event with delta: {part}")
            yield create_message_event(delta=part)
            # Small delay to simulate token-by-token streaming.
            await asyncio.sleep(0.05)

    def _extract_delta_from_event(self, event: ServerSentEvent) -> str:
        """Extract delta content from a message event's JSON payload.

        Returns "" for missing/unparseable data rather than raising.
        """
        import json
        try:
            if event.data:
                data = json.loads(event.data)
                return data.get("delta", "")
        except (json.JSONDecodeError, TypeError):
            pass
        return ""
# Module-level singleton; lazily constructed by get_orchestrator_service().
_orchestrator_service: OrchestratorService | None = None


def get_orchestrator_service() -> OrchestratorService:
    """Return the process-wide orchestrator, creating it on first access."""
    global _orchestrator_service
    service = _orchestrator_service
    if service is None:
        service = OrchestratorService()
        _orchestrator_service = service
    return service
def set_orchestrator_service(service: OrchestratorService) -> None:
    """Set orchestrator service instance for testing.

    Replaces the module-level singleton; subsequent calls to
    get_orchestrator_service() return ``service``.
    """
    global _orchestrator_service
    _orchestrator_service = service

View File

@ -0,0 +1,57 @@
"""
Retrieval module for AI Service.
[AC-AISVC-16] Provides retriever implementations with plugin architecture.
RAG Optimization: Two-stage retrieval, RRF hybrid ranking, metadata filtering.
"""
from app.services.retrieval.base import (
BaseRetriever,
RetrievalContext,
RetrievalHit,
RetrievalResult,
)
from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever
from app.services.retrieval.metadata import (
ChunkMetadata,
ChunkMetadataModel,
MetadataFilter,
KnowledgeChunk,
RetrieveRequest,
RetrieveResult,
RetrievalStrategy,
)
from app.services.retrieval.optimized_retriever import (
OptimizedRetriever,
get_optimized_retriever,
TwoStageResult,
RRFCombiner,
)
from app.services.retrieval.indexer import (
KnowledgeIndexer,
get_knowledge_indexer,
IndexingProgress,
IndexingResult,
)
# Public re-export surface of the retrieval package.
# Fix: ChunkMetadataModel is imported above but was missing from __all__,
# so `from app.services.retrieval import *` (and linters checking the
# export list) did not expose it.
__all__ = [
    # Base abstractions
    "BaseRetriever",
    "RetrievalContext",
    "RetrievalHit",
    "RetrievalResult",
    # Vector retriever
    "VectorRetriever",
    "get_vector_retriever",
    # Metadata / request-response models
    "ChunkMetadata",
    "ChunkMetadataModel",
    "MetadataFilter",
    "KnowledgeChunk",
    "RetrieveRequest",
    "RetrieveResult",
    "RetrievalStrategy",
    # Optimized two-stage retrieval
    "OptimizedRetriever",
    "get_optimized_retriever",
    "TwoStageResult",
    "RRFCombiner",
    # Indexing
    "KnowledgeIndexer",
    "get_knowledge_indexer",
    "IndexingProgress",
    "IndexingResult",
]

View File

@ -0,0 +1,96 @@
"""
Retrieval layer for AI Service.
[AC-AISVC-16] Abstract base class for retrievers with plugin point support.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class RetrievalContext:
    """
    [AC-AISVC-16] Context for retrieval operations.
    Contains all necessary information for retrieval plugins.
    """

    # Tenant scope — retrievers must filter results by this id.
    tenant_id: str
    # The user query text to match against the knowledge base.
    query: str
    # Optional conversational context for retrievers that use it.
    session_id: str | None = None
    channel_type: str | None = None
    # Free-form request metadata forwarded from the chat request.
    metadata: dict[str, Any] | None = None
@dataclass
class RetrievalHit:
"""
[AC-AISVC-16] Single retrieval result hit.
Unified structure for all retriever types.
"""
text: str
score: float
source: str
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class RetrievalResult:
"""
[AC-AISVC-16] Result from retrieval operation.
Contains hits and optional diagnostics.
"""
hits: list[RetrievalHit] = field(default_factory=list)
diagnostics: dict[str, Any] | None = None
@property
def is_empty(self) -> bool:
"""Check if no hits were found."""
return len(self.hits) == 0
@property
def max_score(self) -> float:
"""Get the maximum score among hits."""
if not self.hits:
return 0.0
return max(hit.score for hit in self.hits)
@property
def hit_count(self) -> int:
"""Get the number of hits."""
return len(self.hits)
class BaseRetriever(ABC):
    """
    [AC-AISVC-16] Abstract base class for retrievers.
    Provides plugin point for different retrieval strategies (Vector, Graph, Hybrid).

    Implementations are looked up by the orchestrator; both methods are
    async so retrievers may perform network I/O.
    """

    @abstractmethod
    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """
        [AC-AISVC-16] Retrieve relevant documents for the given context.

        Args:
            ctx: Retrieval context containing tenant_id, query, and optional metadata.

        Returns:
            RetrievalResult with hits and optional diagnostics.
        """
        pass

    @abstractmethod
    async def health_check(self) -> bool:
        """
        Check if the retriever is healthy and ready to serve requests.

        Returns:
            True if healthy, False otherwise.
        """
        pass

View File

@ -0,0 +1,339 @@
"""
Knowledge base indexing service with optimized embedding.
Reference: rag-optimization/spec.md Section 5.1
"""
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient, get_qdrant_client
from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult
from app.services.retrieval.metadata import ChunkMetadata, KnowledgeChunk
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class IndexingProgress:
    """Mutable progress snapshot for a running indexing job."""

    total_chunks: int = 0
    processed_chunks: int = 0
    failed_chunks: int = 0
    current_document: str = ""
    # NOTE(review): naive-UTC timestamp; consistent with elapsed_seconds
    # below, but datetime.utcnow() is deprecated in newer Pythons.
    started_at: datetime = field(default_factory=datetime.utcnow)

    @property
    def progress_percent(self) -> int:
        """Completed fraction as a truncated integer percentage (0-100)."""
        if self.total_chunks == 0:
            return 0
        ratio = self.processed_chunks / self.total_chunks
        return int(ratio * 100)

    @property
    def elapsed_seconds(self) -> float:
        """Wall-clock seconds since this progress object was created."""
        delta = datetime.utcnow() - self.started_at
        return delta.total_seconds()
@dataclass
class IndexingResult:
    """Result of an indexing operation.

    Note: ``success`` describes the operation as a whole — it can be True
    even when some individual chunks failed (see ``failed_chunks``).
    """

    # Whether the overall indexing run completed without a fatal error.
    success: bool
    # Total chunks produced from the source document.
    total_chunks: int
    # Chunks successfully embedded and upserted.
    indexed_chunks: int
    # Chunks skipped because embedding/indexing raised.
    failed_chunks: int
    # Wall-clock duration of the run.
    elapsed_seconds: float
    # Populated only on fatal failure.
    error_message: str | None = None
class KnowledgeIndexer:
"""
Knowledge base indexer with optimized embedding.
Features:
- Task prefixes (search_document:) for document embedding
- Multi-dimensional vectors (256/512/768)
- Metadata support
- Batch processing
"""
    def __init__(
        self,
        qdrant_client: QdrantClient | None = None,
        embedding_provider: NomicEmbeddingProvider | None = None,
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        batch_size: int = 10,
    ):
        """Create an indexer with optionally injected dependencies.

        Args:
            qdrant_client: Pre-built Qdrant client; resolved lazily via
                get_qdrant_client() when None.
            embedding_provider: Pre-built embedding provider; built lazily
                from app settings when None.
            chunk_size: Target chunk size.
                NOTE(review): not referenced by the visible chunking
                methods (they split per line) — confirm it is still used.
            chunk_overlap: Overlap between chunks; same caveat.
            batch_size: Batch size for processing; same caveat.
        """
        self._qdrant_client = qdrant_client
        self._embedding_provider = embedding_provider
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._batch_size = batch_size
        # Progress of the current/most recent indexing job (None when idle).
        self._progress: IndexingProgress | None = None
async def _get_client(self) -> QdrantClient:
if self._qdrant_client is None:
self._qdrant_client = await get_qdrant_client()
return self._qdrant_client
async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
if self._embedding_provider is None:
self._embedding_provider = NomicEmbeddingProvider(
base_url=settings.ollama_base_url,
model=settings.ollama_embedding_model,
dimension=settings.qdrant_vector_size,
)
return self._embedding_provider
def chunk_text(self, text: str, metadata: ChunkMetadata | None = None) -> list[KnowledgeChunk]:
"""
Split text into chunks for indexing.
Each line becomes a separate chunk for better retrieval granularity.
Args:
text: Full text to chunk
metadata: Metadata to attach to each chunk
Returns:
List of KnowledgeChunk objects
"""
chunks = []
doc_id = str(uuid.uuid4())
lines = text.split('\n')
for i, line in enumerate(lines):
line = line.strip()
if len(line) < 10:
continue
chunk = KnowledgeChunk(
chunk_id=f"{doc_id}_{i}",
document_id=doc_id,
content=line,
metadata=metadata or ChunkMetadata(),
)
chunks.append(chunk)
return chunks
    def chunk_text_by_lines(
        self,
        text: str,
        metadata: ChunkMetadata | None = None,
        min_line_length: int = 10,
        merge_short_lines: bool = False,
    ) -> list[KnowledgeChunk]:
        """
        Split text by lines, each line is a separate chunk.

        Args:
            text: Full text to chunk
            metadata: Metadata to attach to each chunk
            min_line_length: Minimum line length to be indexed
            merge_short_lines: Whether to merge consecutive short lines

        Returns:
            List of KnowledgeChunk objects
        """
        chunks = []
        doc_id = str(uuid.uuid4())
        lines = text.split('\n')
        if merge_short_lines:
            # Accumulate consecutive non-empty lines into one buffer.
            # A blank line always flushes the buffer; otherwise the buffer
            # is flushed once it reaches 2x min_line_length.
            merged_lines = []
            current_line = ""
            for line in lines:
                line = line.strip()
                if not line:
                    if current_line:
                        merged_lines.append(current_line)
                    current_line = ""
                    continue
                if current_line:
                    current_line += " " + line
                else:
                    current_line = line
                if len(current_line) >= min_line_length * 2:
                    merged_lines.append(current_line)
                    current_line = ""
            if current_line:
                # Flush the trailing buffer that never hit the threshold.
                merged_lines.append(current_line)
            lines = merged_lines
        for i, line in enumerate(lines):
            line = line.strip()
            if len(line) < min_line_length:
                # Skip short/noise lines; index i is still consumed, so
                # chunk ids are not necessarily consecutive.
                continue
            chunk = KnowledgeChunk(
                chunk_id=f"{doc_id}_{i}",
                document_id=doc_id,
                content=line,
                metadata=metadata or ChunkMetadata(),
            )
            chunks.append(chunk)
        return chunks
async def index_document(
    self,
    tenant_id: str,
    document_id: str,
    text: str,
    metadata: ChunkMetadata | None = None,
) -> IndexingResult:
    """
    Index a single document: chunk it, embed each chunk, upsert to Qdrant.

    Each chunk is embedded at three dimensions (full / 256 / 512) and stored
    as named vectors in a multi-vector collection.  Per-chunk embedding
    failures are logged and counted but do not abort the document.

    Args:
        tenant_id: Tenant identifier (selects the Qdrant collection)
        document_id: Document identifier stored in each point's payload
        text: Document text content
        metadata: Optional metadata attached to every chunk

    Returns:
        IndexingResult with status and statistics.  On an outer failure
        (client/provider/collection errors) success=False and
        error_message carries the exception text.
    """
    start_time = datetime.utcnow()
    try:
        client = await self._get_client()
        provider = await self._get_embedding_provider()
        # Multi-vector collection: named vectors "full", "dim_256", "dim_512".
        await client.ensure_collection_exists(tenant_id, use_multi_vector=True)
        chunks = self.chunk_text(text, metadata)
        # Progress lives on the instance, so concurrent index_document calls
        # would clobber each other's tracker — single-document use assumed.
        self._progress = IndexingProgress(
            total_chunks=len(chunks),
            current_document=document_id,
        )
        points = []
        for i, chunk in enumerate(chunks):
            try:
                # One embedding call per chunk (no batching).
                embedding_result = await provider.embed_document(chunk.content)
                chunk.embedding_full = embedding_result.embedding_full
                chunk.embedding_256 = embedding_result.embedding_256
                chunk.embedding_512 = embedding_result.embedding_512
                point = {
                    "id": str(uuid.uuid4()),  # Generate a valid UUID for Qdrant
                    "vector": {
                        "full": chunk.embedding_full,
                        "dim_256": chunk.embedding_256,
                        "dim_512": chunk.embedding_512,
                    },
                    # NOTE(review): "embedding_full" is NOT stored in this
                    # payload, yet OptimizedRetriever._two_stage_retrieve reads
                    # payload["embedding_full"] for reranking — verify the
                    # two-stage pipeline end to end.
                    "payload": {
                        "chunk_id": chunk.chunk_id,
                        "document_id": document_id,
                        "text": chunk.content,
                        "metadata": chunk.metadata.to_dict(),
                        "created_at": chunk.created_at.isoformat(),
                    }
                }
                points.append(point)
                self._progress.processed_chunks += 1
                logger.debug(
                    f"[RAG-OPT] Indexed chunk {i+1}/{len(chunks)} for doc={document_id}"
                )
            except Exception as e:
                # Best-effort: skip the failing chunk, keep indexing the rest.
                logger.warning(f"[RAG-OPT] Failed to index chunk {i}: {e}")
                self._progress.failed_chunks += 1
        if points:
            await client.upsert_multi_vector(tenant_id, points)
        elapsed = (datetime.utcnow() - start_time).total_seconds()
        logger.info(
            f"[RAG-OPT] Indexed document {document_id}: "
            f"{len(points)} chunks in {elapsed:.2f}s"
        )
        return IndexingResult(
            success=True,
            total_chunks=len(chunks),
            indexed_chunks=len(points),
            failed_chunks=self._progress.failed_chunks,
            elapsed_seconds=elapsed,
        )
    except Exception as e:
        elapsed = (datetime.utcnow() - start_time).total_seconds()
        logger.error(f"[RAG-OPT] Failed to index document {document_id}: {e}")
        # Counts are zeroed in this failure summary even if some chunks had
        # already been processed before the outer error occurred.
        return IndexingResult(
            success=False,
            total_chunks=0,
            indexed_chunks=0,
            failed_chunks=0,
            elapsed_seconds=elapsed,
        )
async def index_documents_batch(
    self,
    tenant_id: str,
    documents: list[dict[str, Any]],
) -> list[IndexingResult]:
    """
    Index multiple documents sequentially.

    Args:
        tenant_id: Tenant identifier
        documents: List of dicts, each with keys:
            - "document_id": str
            - "text": str
            - "metadata": ChunkMetadata (optional)

    Returns:
        One IndexingResult per input document, in input order.
    """
    # Documents are awaited one at a time (not concurrently), so the shared
    # progress tracker reflects a single document at any moment.
    return [
        await self.index_document(
            tenant_id=tenant_id,
            document_id=entry["document_id"],
            text=entry["text"],
            metadata=entry.get("metadata"),
        )
        for entry in documents
    ]
def get_progress(self) -> IndexingProgress | None:
    """Return the progress tracker of the most recent indexing run.

    May be None if no indexing has started yet (presumably initialized to
    None in __init__ — constructor not shown here, confirm).
    """
    return self._progress
# Lazily-created module-level singleton.  Not guarded by a lock — fine for
# asyncio's single thread; revisit if threads are introduced.
_knowledge_indexer: KnowledgeIndexer | None = None


def get_knowledge_indexer() -> KnowledgeIndexer:
    """Get or create the process-wide KnowledgeIndexer instance."""
    global _knowledge_indexer
    if _knowledge_indexer is None:
        _knowledge_indexer = KnowledgeIndexer()
    return _knowledge_indexer

View File

@ -0,0 +1,210 @@
"""
Metadata models for RAG optimization.
Implements structured metadata for knowledge chunks.
Reference: rag-optimization/spec.md Section 3.2
"""
from dataclasses import dataclass, field
from datetime import date, datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel
class RetrievalStrategy(str, Enum):
    """Retrieval strategy options.

    Inherits from str so values serialize directly in JSON payloads.
    """

    VECTOR_ONLY = "vector"    # dense-vector search only
    BM25_ONLY = "bm25"        # lexical/BM25 search only
    HYBRID = "hybrid"         # vector + BM25 fused with RRF
    TWO_STAGE = "two_stage"   # coarse low-dim search, then full-dim rerank
class ChunkMetadataModel(BaseModel):
    """Pydantic model mirroring ChunkMetadata for API (de)serialization.

    Unlike the ChunkMetadata dataclass, validity dates are carried here as
    ISO-format strings (valid_from / valid_until), not date objects.
    """

    category: str = ""
    subcategory: str = ""
    # Mutable defaults are safe on pydantic models (deep-copied per instance),
    # unlike plain class attributes.
    target_audience: list[str] = []
    source_doc: str = ""
    source_url: str = ""
    department: str = ""
    valid_from: str | None = None
    valid_until: str | None = None
    priority: int = 5
    keywords: list[str] = []
@dataclass
class ChunkMetadata:
"""
Metadata for knowledge chunks.
Reference: rag-optimization/spec.md Section 3.2.2
"""
category: str = ""
subcategory: str = ""
target_audience: list[str] = field(default_factory=list)
source_doc: str = ""
source_url: str = ""
department: str = ""
valid_from: date | None = None
valid_until: date | None = None
priority: int = 5
keywords: list[str] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for storage."""
return {
"category": self.category,
"subcategory": self.subcategory,
"target_audience": self.target_audience,
"source_doc": self.source_doc,
"source_url": self.source_url,
"department": self.department,
"valid_from": self.valid_from.isoformat() if self.valid_from else None,
"valid_until": self.valid_until.isoformat() if self.valid_until else None,
"priority": self.priority,
"keywords": self.keywords,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "ChunkMetadata":
"""Create from dictionary."""
return cls(
category=data.get("category", ""),
subcategory=data.get("subcategory", ""),
target_audience=data.get("target_audience", []),
source_doc=data.get("source_doc", ""),
source_url=data.get("source_url", ""),
department=data.get("department", ""),
valid_from=date.fromisoformat(data["valid_from"]) if data.get("valid_from") else None,
valid_until=date.fromisoformat(data["valid_until"]) if data.get("valid_until") else None,
priority=data.get("priority", 5),
keywords=data.get("keywords", []),
)
@dataclass
class MetadataFilter:
"""
Filter conditions for metadata-based retrieval.
Reference: rag-optimization/spec.md Section 4.1
"""
categories: list[str] | None = None
target_audiences: list[str] | None = None
departments: list[str] | None = None
valid_only: bool = True
min_priority: int | None = None
keywords: list[str] | None = None
def to_qdrant_filter(self) -> dict[str, Any] | None:
"""Convert to Qdrant filter format."""
conditions = []
if self.categories:
conditions.append({
"key": "metadata.category",
"match": {"any": self.categories}
})
if self.departments:
conditions.append({
"key": "metadata.department",
"match": {"any": self.departments}
})
if self.target_audiences:
conditions.append({
"key": "metadata.target_audience",
"match": {"any": self.target_audiences}
})
if self.valid_only:
today = date.today().isoformat()
conditions.append({
"should": [
{"key": "metadata.valid_until", "match": {"value": None}},
{"key": "metadata.valid_until", "range": {"gte": today}}
]
})
if self.min_priority is not None:
conditions.append({
"key": "metadata.priority",
"range": {"lte": self.min_priority}
})
if not conditions:
return None
if len(conditions) == 1:
return {"must": conditions}
return {"must": conditions}
@dataclass
class KnowledgeChunk:
    """
    Knowledge chunk carrying multi-dimensional (Matryoshka) embeddings.

    Reference: rag-optimization/spec.md Section 3.2.1
    """

    chunk_id: str
    document_id: str
    content: str
    embedding_full: list[float] = field(default_factory=list)
    embedding_256: list[float] = field(default_factory=list)
    embedding_512: list[float] = field(default_factory=list)
    metadata: ChunkMetadata = field(default_factory=ChunkMetadata)
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    def to_qdrant_point(self, point_id: int | str) -> dict[str, Any]:
        """Build the Qdrant point dict (named vectors + payload) for upsert."""
        named_vectors = {
            "full": self.embedding_full,
            "dim_256": self.embedding_256,
            "dim_512": self.embedding_512,
        }
        payload = {
            "chunk_id": self.chunk_id,
            "document_id": self.document_id,
            "text": self.content,
            "metadata": self.metadata.to_dict(),
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
        }
        return {"id": point_id, "vector": named_vectors, "payload": payload}
@dataclass
class RetrieveRequest:
    """
    Request for knowledge retrieval.

    Reference: rag-optimization/spec.md Section 4.1
    """

    query: str
    # Query with the embedding-model task prefix applied; auto-filled from
    # `query` when left empty (see __post_init__).
    query_with_prefix: str = ""
    top_k: int = 10
    filters: MetadataFilter | None = None
    strategy: RetrievalStrategy = RetrievalStrategy.HYBRID

    def __post_init__(self):
        # NOTE(review): nomic-style task prefixes are conventionally
        # "search_query: " with a space after the colon — confirm the format
        # the embedding provider expects.
        if not self.query_with_prefix:
            self.query_with_prefix = f"search_query:{self.query}"
@dataclass
class RetrieveResult:
    """
    Single ranked hit from knowledge retrieval.

    Reference: rag-optimization/spec.md Section 4.1
    """

    chunk_id: str
    content: str
    score: float               # final (possibly fused) score used for ranking
    vector_score: float = 0.0  # raw dense-vector similarity; 0.0 if absent
    bm25_score: float = 0.0    # raw lexical score; 0.0 if absent
    metadata: ChunkMetadata = field(default_factory=ChunkMetadata)
    rank: int = 0              # position in the final ranked list (0-based, presumably)

View File

@ -0,0 +1,509 @@
"""
Optimized RAG retriever with two-stage retrieval and RRF hybrid ranking.
Reference: rag-optimization/spec.md Section 2.2, 2.4, 2.5
"""
import asyncio
import logging
import re
from dataclasses import dataclass, field
from typing import Any
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient, get_qdrant_client
from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult
from app.services.retrieval.base import (
BaseRetriever,
RetrievalContext,
RetrievalHit,
RetrievalResult,
)
from app.services.retrieval.metadata import (
ChunkMetadata,
MetadataFilter,
RetrieveResult,
RetrievalStrategy,
)
# Module-level logger and settings handle shared by everything below.
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class TwoStageResult:
    """Result bundle from two-stage retrieval; latencies are per stage, in ms."""

    candidates: list[dict[str, Any]]     # stage-1 (coarse, low-dim) candidates
    final_results: list[RetrieveResult]  # stage-2 reranked results
    stage1_latency_ms: float = 0.0
    stage2_latency_ms: float = 0.0
class RRFCombiner:
    """
    Reciprocal Rank Fusion (RRF) for merging vector and BM25 result lists.

    Formula: score = Σ weight_i / (k + rank_i + 1), with 0-based ranks
    (equivalent to the classic 1 / (k + rank) with 1-based ranks).
    Default k = 60.

    Reference: rag-optimization/spec.md Section 2.5
    """

    def __init__(self, k: int = 60):
        self._k = k

    def combine(
        self,
        vector_results: list[dict[str, Any]],
        bm25_results: list[dict[str, Any]],
        vector_weight: float = 0.7,
        bm25_weight: float = 0.3,
    ) -> list[dict[str, Any]]:
        """
        Fuse vector and BM25 result lists into one ranking via weighted RRF.

        Args:
            vector_results: Ranked hits from dense vector search.
            bm25_results: Ranked hits from BM25/lexical search.
            vector_weight: RRF weight applied to vector ranks.
            bm25_weight: RRF weight applied to BM25 ranks.

        Returns:
            Hits sorted by fused score (descending); each carries the raw
            per-source scores and ranks (-1 when absent from that source).
        """
        fused: dict[str, dict[str, Any]] = {}

        def key_of(hit: dict[str, Any], position: int) -> str:
            # Prefer an explicit chunk_id, then the point id, then position.
            return hit.get("chunk_id") or hit.get("id", str(position))

        # Pass 1: dense-vector ranking.
        for position, hit in enumerate(vector_results):
            cid = key_of(hit, position)
            entry = fused.setdefault(cid, {
                "score": 0.0,
                "vector_score": hit.get("score", 0.0),
                "bm25_score": 0.0,
                "vector_rank": position,
                "bm25_rank": -1,
                "payload": hit.get("payload", {}),
                "id": cid,
            })
            entry["score"] += vector_weight / (self._k + position + 1)

        # Pass 2: BM25 ranking, merged into entries created in pass 1.
        for position, hit in enumerate(bm25_results):
            cid = key_of(hit, position)
            if cid in fused:
                entry = fused[cid]
                entry["bm25_score"] = hit.get("score", 0.0)
                entry["bm25_rank"] = position
            else:
                entry = fused[cid] = {
                    "score": 0.0,
                    "vector_score": 0.0,
                    "bm25_score": hit.get("score", 0.0),
                    "vector_rank": -1,
                    "bm25_rank": position,
                    "payload": hit.get("payload", {}),
                    "id": cid,
                }
            entry["score"] += bm25_weight / (self._k + position + 1)

        return sorted(fused.values(), key=lambda e: e["score"], reverse=True)
class OptimizedRetriever(BaseRetriever):
    """
    Optimized retriever implementing:
    - Task prefixes (search_document/search_query) via the Nomic provider
    - Two-stage retrieval (coarse 256-dim search, full-dim rerank)
    - RRF hybrid ranking (vector + BM25-like lexical scoring)
    - Metadata filtering hooks

    Strategy precedence in retrieve(): two-stage > hybrid > plain vector.
    Reference: rag-optimization/spec.md Section 2, 3, 4
    """

    def __init__(
        self,
        qdrant_client: QdrantClient | None = None,
        embedding_provider: NomicEmbeddingProvider | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
        min_hits: int | None = None,
        two_stage_enabled: bool | None = None,
        two_stage_expand_factor: int | None = None,
        hybrid_enabled: bool | None = None,
        rrf_k: int | None = None,
    ):
        """
        Every argument defaults to the corresponding `settings.rag_*` value
        when None.  Explicit `is not None` checks are used so that legitimate
        falsy overrides (0, 0.0, False) are honored — the original `or`
        pattern silently replaced them with the settings value.
        """
        self._qdrant_client = qdrant_client
        self._embedding_provider = embedding_provider
        self._top_k = top_k if top_k is not None else settings.rag_top_k
        self._score_threshold = (
            score_threshold if score_threshold is not None else settings.rag_score_threshold
        )
        self._min_hits = min_hits if min_hits is not None else settings.rag_min_hits
        self._two_stage_enabled = (
            two_stage_enabled if two_stage_enabled is not None else settings.rag_two_stage_enabled
        )
        self._two_stage_expand_factor = (
            two_stage_expand_factor
            if two_stage_expand_factor is not None
            else settings.rag_two_stage_expand_factor
        )
        self._hybrid_enabled = (
            hybrid_enabled if hybrid_enabled is not None else settings.rag_hybrid_enabled
        )
        self._rrf_k = rrf_k if rrf_k is not None else settings.rag_rrf_k
        self._rrf_combiner = RRFCombiner(k=self._rrf_k)

    async def _get_client(self) -> QdrantClient:
        """Lazily create (and cache) the Qdrant client wrapper."""
        if self._qdrant_client is None:
            self._qdrant_client = await get_qdrant_client()
        return self._qdrant_client

    async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
        """Lazily resolve the embedding provider.

        Prefers the configured provider when it is Nomic-capable; otherwise
        builds a NomicEmbeddingProvider from Ollama settings.
        """
        if self._embedding_provider is None:
            from app.services.embedding.factory import get_embedding_config_manager
            manager = get_embedding_config_manager()
            provider = await manager.get_provider()
            if isinstance(provider, NomicEmbeddingProvider):
                self._embedding_provider = provider
            else:
                self._embedding_provider = NomicEmbeddingProvider(
                    base_url=settings.ollama_base_url,
                    model=settings.ollama_embedding_model,
                    dimension=settings.qdrant_vector_size,
                )
        return self._embedding_provider

    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """
        Retrieve documents for ctx.query using the configured strategy.

        Strategy selection (first match wins):
        1. two_stage_enabled  -> two-stage retrieval (256-dim then full-dim)
        2. hybrid_enabled     -> RRF hybrid ranking (vector + BM25)
        3. otherwise          -> simple full-dim vector search

        Hits below the score threshold are dropped; diagnostics carry the
        counts.  Errors never propagate — an empty result with an "error"
        diagnostic is returned instead.
        """
        logger.info(
            f"[RAG-OPT] Starting retrieval for tenant={ctx.tenant_id}, "
            f"query={ctx.query[:50]}..., two_stage={self._two_stage_enabled}, hybrid={self._hybrid_enabled}"
        )
        logger.info(
            f"[RAG-OPT] Retrieval config: top_k={self._top_k}, "
            f"score_threshold={self._score_threshold}, min_hits={self._min_hits}"
        )
        try:
            provider = await self._get_embedding_provider()
            logger.info(f"[RAG-OPT] Using embedding provider: {type(provider).__name__}")
            embedding_result = await provider.embed_query(ctx.query)
            logger.info(
                f"[RAG-OPT] Embedding generated: full_dim={len(embedding_result.embedding_full)}, "
                f"dim_256={'available' if embedding_result.embedding_256 else 'not available'}"
            )
            if self._two_stage_enabled:
                logger.info("[RAG-OPT] Using two-stage retrieval strategy")
                results = await self._two_stage_retrieve(
                    ctx.tenant_id,
                    embedding_result,
                    self._top_k,
                )
            elif self._hybrid_enabled:
                logger.info("[RAG-OPT] Using hybrid retrieval strategy")
                results = await self._hybrid_retrieve(
                    ctx.tenant_id,
                    embedding_result,
                    ctx.query,
                    self._top_k,
                )
            else:
                logger.info("[RAG-OPT] Using simple vector retrieval strategy")
                results = await self._vector_retrieve(
                    ctx.tenant_id,
                    embedding_result.embedding_full,
                    self._top_k,
                )
            logger.info(f"[RAG-OPT] Raw results count: {len(results)}")
            # Score threshold filter; the whole payload rides along as hit
            # metadata for downstream consumers.
            retrieval_hits = [
                RetrievalHit(
                    text=result.get("payload", {}).get("text", ""),
                    score=result.get("score", 0.0),
                    source="optimized_rag",
                    metadata=result.get("payload", {}),
                )
                for result in results
                if result.get("score", 0.0) >= self._score_threshold
            ]
            filtered_count = len(results) - len(retrieval_hits)
            if filtered_count > 0:
                logger.info(
                    f"[RAG-OPT] Filtered out {filtered_count} results below threshold {self._score_threshold}"
                )
            is_insufficient = len(retrieval_hits) < self._min_hits
            diagnostics = {
                "query_length": len(ctx.query),
                "top_k": self._top_k,
                "score_threshold": self._score_threshold,
                "two_stage_enabled": self._two_stage_enabled,
                "hybrid_enabled": self._hybrid_enabled,
                "total_hits": len(retrieval_hits),
                "is_insufficient": is_insufficient,
                "max_score": max((h.score for h in retrieval_hits), default=0.0),
                "raw_results_count": len(results),
                "filtered_below_threshold": filtered_count,
            }
            logger.info(
                f"[RAG-OPT] Retrieval complete: {len(retrieval_hits)} hits, "
                f"insufficient={is_insufficient}, max_score={diagnostics['max_score']:.3f}"
            )
            if len(retrieval_hits) == 0:
                logger.warning(
                    f"[RAG-OPT] No hits found! tenant={ctx.tenant_id}, query={ctx.query[:50]}..., "
                    f"raw_results={len(results)}, threshold={self._score_threshold}"
                )
            return RetrievalResult(
                hits=retrieval_hits,
                diagnostics=diagnostics,
            )
        except Exception as e:
            logger.error(f"[RAG-OPT] Retrieval error: {e}", exc_info=True)
            return RetrievalResult(
                hits=[],
                diagnostics={"error": str(e), "is_insufficient": True},
            )

    async def _two_stage_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Two-stage retrieval using Matryoshka dimensions.

        Stage 1: fast, coarse search with 256-dim vectors over an expanded
        candidate set (top_k * expand_factor).
        Stage 2: precise rerank of the candidates by full-dim cosine
        similarity against a full embedding stored in the payload.

        Candidates whose payload lacks "embedding_full" keep their stage-1
        score instead of being dropped.  (The original dropped them, and the
        visible indexer never stores "embedding_full" in the payload, which
        made the two-stage path return nothing — this fallback degrades
        gracefully to the stage-1 ranking.)
        Reference: rag-optimization/spec.md Section 2.4
        """
        import time

        client = await self._get_client()
        stage1_start = time.perf_counter()
        candidates = await self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_256, "dim_256",
            top_k * self._two_stage_expand_factor
        )
        stage1_latency = (time.perf_counter() - stage1_start) * 1000
        logger.debug(
            f"[RAG-OPT] Stage 1: {len(candidates)} candidates in {stage1_latency:.2f}ms"
        )
        stage2_start = time.perf_counter()
        reranked = []
        for candidate in candidates:
            stored_full_embedding = candidate.get("payload", {}).get("embedding_full", [])
            if stored_full_embedding:
                candidate["score"] = self._cosine_similarity(
                    embedding_result.embedding_full,
                    stored_full_embedding,
                )
                candidate["stage"] = "reranked"
            else:
                # No stored full embedding: fall back to the stage-1 score.
                candidate["stage"] = "stage1"
            reranked.append(candidate)
        reranked.sort(key=lambda x: x.get("score", 0), reverse=True)
        results = reranked[:top_k]
        stage2_latency = (time.perf_counter() - stage2_start) * 1000
        logger.debug(
            f"[RAG-OPT] Stage 2: {len(results)} final results in {stage2_latency:.2f}ms"
        )
        return results

    async def _hybrid_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        query: str,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Hybrid retrieval: run vector and BM25 searches concurrently and fuse
        them with RRF.  A failure in either branch is logged and treated as
        an empty result list rather than aborting the whole retrieval.
        Reference: rag-optimization/spec.md Section 2.5
        """
        client = await self._get_client()
        vector_task = self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_full, "full",
            top_k * 2
        )
        bm25_task = self._bm25_search(client, tenant_id, query, top_k * 2)
        vector_results, bm25_results = await asyncio.gather(
            vector_task, bm25_task, return_exceptions=True
        )
        if isinstance(vector_results, Exception):
            logger.warning(f"[RAG-OPT] Vector search failed: {vector_results}")
            vector_results = []
        if isinstance(bm25_results, Exception):
            logger.warning(f"[RAG-OPT] BM25 search failed: {bm25_results}")
            bm25_results = []
        combined = self._rrf_combiner.combine(
            vector_results,
            bm25_results,
            vector_weight=settings.rag_vector_weight,
            bm25_weight=settings.rag_bm25_weight,
        )
        return combined[:top_k]

    async def _vector_retrieve(
        self,
        tenant_id: str,
        embedding: list[float],
        top_k: int,
    ) -> list[dict[str, Any]]:
        """Simple single-pass vector retrieval over the full-dim vectors."""
        client = await self._get_client()
        return await self._search_with_dimension(
            client, tenant_id, embedding, "full", top_k
        )

    async def _search_with_dimension(
        self,
        client: QdrantClient,
        tenant_id: str,
        query_vector: list[float],
        vector_name: str,
        limit: int,
    ) -> list[dict[str, Any]]:
        """Search the tenant collection with a named vector ("full",
        "dim_256", or "dim_512").  Returns [] on any error.

        NOTE(review): qdrant-client's `search` with a (name, vector) tuple is
        the legacy API (deprecated in favor of query_points in recent
        versions) — verify against the installed client version.
        """
        try:
            qdrant = await client.get_client()
            collection_name = client.get_collection_name(tenant_id)
            logger.info(
                f"[RAG-OPT] Searching collection={collection_name}, "
                f"vector_name={vector_name}, limit={limit}, vector_dim={len(query_vector)}"
            )
            results = await qdrant.search(
                collection_name=collection_name,
                query_vector=(vector_name, query_vector),
                limit=limit,
            )
            logger.info(
                f"[RAG-OPT] Search returned {len(results)} results from collection={collection_name}"
            )
            if len(results) > 0:
                for i, r in enumerate(results[:3]):
                    logger.debug(
                        f"[RAG-OPT] Result {i+1}: id={r.id}, score={r.score:.4f}"
                    )
            return [
                {
                    "id": str(result.id),
                    "score": result.score,
                    "payload": result.payload or {},
                }
                for result in results
            ]
        except Exception as e:
            logger.error(
                f"[RAG-OPT] Search with {vector_name} failed: {e}, "
                f"collection_name={client.get_collection_name(tenant_id)}",
                exc_info=True
            )
            return []

    async def _bm25_search(
        self,
        client: QdrantClient,
        tenant_id: str,
        query: str,
        limit: int,
    ) -> list[dict[str, Any]]:
        """
        BM25-like lexical search fallback.

        This is NOT real BM25: it scrolls the first `limit * 3` points of the
        collection (a single scroll page — large collections are only
        partially scanned) and scores by Jaccard overlap of word sets.
        For production-quality lexical search use Elasticsearch or Qdrant
        sparse vectors.  Returns [] on any error.
        """
        try:
            qdrant = await client.get_client()
            collection_name = client.get_collection_name(tenant_id)
            query_terms = set(re.findall(r'\w+', query.lower()))
            results = await qdrant.scroll(
                collection_name=collection_name,
                limit=limit * 3,
                with_payload=True,
            )
            scored_results = []
            for point in results[0]:
                text = point.payload.get("text", "").lower()
                text_terms = set(re.findall(r'\w+', text))
                overlap = len(query_terms & text_terms)
                if overlap > 0:
                    # Jaccard similarity: |A ∩ B| / |A ∪ B|.
                    score = overlap / (len(query_terms) + len(text_terms) - overlap)
                    scored_results.append({
                        "id": str(point.id),
                        "score": score,
                        "payload": point.payload or {},
                    })
            scored_results.sort(key=lambda x: x["score"], reverse=True)
            return scored_results[:limit]
        except Exception as e:
            logger.debug(f"[RAG-OPT] BM25 search failed: {e}")
            return []

    def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
        """Cosine similarity of two vectors; 0.0 when either has zero norm
        (guards against NaN from division by zero)."""
        import numpy as np

        a = np.array(vec1)
        b = np.array(vec2)
        denom = float(np.linalg.norm(a) * np.linalg.norm(b))
        if denom == 0.0:
            return 0.0
        return float(np.dot(a, b) / denom)

    async def health_check(self) -> bool:
        """Lightweight liveness probe: can we list Qdrant collections?"""
        try:
            client = await self._get_client()
            qdrant = await client.get_client()
            await qdrant.get_collections()
            return True
        except Exception as e:
            logger.error(f"[RAG-OPT] Health check failed: {e}")
            return False
# Lazily-created module-level singleton.
_optimized_retriever: OptimizedRetriever | None = None


async def get_optimized_retriever() -> OptimizedRetriever:
    """Get or create the process-wide OptimizedRetriever instance.

    Declared async for symmetry with the other async getters; it performs
    no awaits itself.
    """
    global _optimized_retriever
    if _optimized_retriever is None:
        _optimized_retriever = OptimizedRetriever()
    return _optimized_retriever

Some files were not shown because too many files have changed in this diff Show More