Compare commits
10 Commits
08e84d194f
...
f631f1dea0
| Author | SHA1 | Date |
|---|---|---|
|
|
f631f1dea0 | |
|
|
dd74ae2585 | |
|
|
717d5328cf | |
|
|
02f03a3a12 | |
|
|
cee884d9a0 | |
|
|
774744d534 | |
|
|
ac8c33cf94 | |
|
|
a23f1a2089 | |
|
|
eb45629b67 | |
|
|
10591ea8fd |
|
|
@ -42,15 +42,19 @@
|
|||
</div>
|
||||
<div class="header-right">
|
||||
<div class="tenant-selector">
|
||||
<el-select
|
||||
v-model="currentTenantId"
|
||||
placeholder="选择租户"
|
||||
<el-select
|
||||
v-model="currentTenantId"
|
||||
placeholder="选择租户"
|
||||
size="default"
|
||||
:loading="loading"
|
||||
@change="handleTenantChange"
|
||||
>
|
||||
<el-option label="默认租户" value="default" />
|
||||
<el-option label="租户 A" value="tenant_a" />
|
||||
<el-option label="租户 B" value="tenant_b" />
|
||||
<el-option
|
||||
v-for="tenant in tenantList"
|
||||
:key="tenant.id"
|
||||
:label="tenant.name"
|
||||
:value="tenant.id"
|
||||
/>
|
||||
</el-select>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -62,15 +66,19 @@
|
|||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { useRoute } from 'vue-router'
|
||||
import { useTenantStore } from '@/stores/tenant'
|
||||
import { getTenantList, type Tenant } from '@/api/tenant'
|
||||
import { Odometer, FolderOpened, Cpu, Monitor, Connection, ChatDotSquare } from '@element-plus/icons-vue'
|
||||
import { ElMessage } from 'element-plus'
|
||||
|
||||
const route = useRoute()
|
||||
const tenantStore = useTenantStore()
|
||||
|
||||
const currentTenantId = ref(tenantStore.currentTenantId)
|
||||
const tenantList = ref<Tenant[]>([])
|
||||
const loading = ref(false)
|
||||
|
||||
const isActive = (path: string) => {
|
||||
return route.path === path || route.path.startsWith(path + '/')
|
||||
|
|
@ -78,7 +86,47 @@ const isActive = (path: string) => {
|
|||
|
||||
const handleTenantChange = (val: string) => {
|
||||
tenantStore.setTenant(val)
|
||||
// 刷新页面以加载新租户的数据
|
||||
window.location.reload()
|
||||
}
|
||||
|
||||
// Validate tenant ID format: name@ash@year
|
||||
const isValidTenantId = (tenantId: string): boolean => {
|
||||
return /^[^@]+@ash@\d{4}$/.test(tenantId)
|
||||
}
|
||||
|
||||
const fetchTenantList = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
// 检查当前租户ID格式是否有效
|
||||
if (!isValidTenantId(currentTenantId.value)) {
|
||||
console.warn('Invalid tenant ID format, resetting to default:', currentTenantId.value)
|
||||
currentTenantId.value = 'default@ash@2026'
|
||||
tenantStore.setTenant(currentTenantId.value)
|
||||
}
|
||||
|
||||
const response = await getTenantList()
|
||||
tenantList.value = response.tenants || []
|
||||
|
||||
// 如果当前租户不在列表中,默认选择第一个
|
||||
if (tenantList.value.length > 0 && !tenantList.value.find(t => t.id === currentTenantId.value)) {
|
||||
const firstTenant = tenantList.value[0]
|
||||
currentTenantId.value = firstTenant.id
|
||||
tenantStore.setTenant(firstTenant.id)
|
||||
}
|
||||
} catch (error) {
|
||||
ElMessage.error('获取租户列表失败')
|
||||
console.error('Failed to fetch tenant list:', error)
|
||||
// 失败时使用默认租户
|
||||
tenantList.value = [{ id: 'default@ash@2026', name: 'default (2026)' }]
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
fetchTenantList()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import request from '@/utils/request'
|
||||
import { useTenantStore } from '@/stores/tenant'
|
||||
|
||||
export interface AIResponse {
|
||||
content: string
|
||||
|
|
@ -73,6 +74,8 @@ export function createSSEConnection(
|
|||
const baseUrl = import.meta.env.VITE_APP_BASE_API || '/api'
|
||||
const fullUrl = `${baseUrl}${url}`
|
||||
|
||||
const tenantStore = useTenantStore()
|
||||
|
||||
const controller = new AbortController()
|
||||
|
||||
fetch(fullUrl, {
|
||||
|
|
@ -80,6 +83,7 @@ export function createSSEConnection(
|
|||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'text/event-stream',
|
||||
'X-Tenant-Id': tenantStore.currentTenantId || '',
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
signal: controller.signal
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
import request from '@/utils/request'
|
||||
|
||||
export interface Tenant {
|
||||
id: string
|
||||
name: string
|
||||
displayName: string
|
||||
year: string
|
||||
createdAt: string
|
||||
}
|
||||
|
||||
export interface TenantListResponse {
|
||||
tenants: Tenant[]
|
||||
total: number
|
||||
}
|
||||
|
||||
export function getTenantList() {
|
||||
return request<TenantListResponse>({
|
||||
url: '/admin/tenants',
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
|
@ -1,14 +1,15 @@
|
|||
<template>
|
||||
<el-select
|
||||
:model-value="modelValue"
|
||||
:model-value="displayValue"
|
||||
:loading="loading"
|
||||
:placeholder="placeholder"
|
||||
:placeholder="computedPlaceholder"
|
||||
:disabled="disabled"
|
||||
:clearable="clearable"
|
||||
:teleported="true"
|
||||
:popper-class="popperClass"
|
||||
:popper-options="{ modifiers: [{ name: 'flip', enabled: true }, { name: 'preventOverflow', enabled: true }, { name: 'computeStyles', options: { adaptive: false, gpuAcceleration: false } }] }"
|
||||
@update:model-value="handleChange"
|
||||
@clear="handleClear"
|
||||
>
|
||||
<el-option
|
||||
v-for="provider in providers"
|
||||
|
|
@ -24,12 +25,16 @@
|
|||
<el-tag v-if="provider.name === currentProvider" type="success" size="small" effect="plain" class="current-tag">
|
||||
当前配置
|
||||
</el-tag>
|
||||
<el-tag v-else-if="provider.name === modelValue" type="primary" size="small" effect="plain" class="selected-tag">
|
||||
已选择
|
||||
</el-tag>
|
||||
</div>
|
||||
</el-option>
|
||||
</el-select>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import type { LLMProviderInfo } from '@/api/llm'
|
||||
|
||||
const popperClass = 'llm-selector-popper'
|
||||
|
|
@ -59,22 +64,48 @@ const emit = defineEmits<{
|
|||
change: [provider: LLMProviderInfo | undefined]
|
||||
}>()
|
||||
|
||||
const displayValue = computed(() => {
|
||||
return props.modelValue || ''
|
||||
})
|
||||
|
||||
const computedPlaceholder = computed(() => {
|
||||
if (props.modelValue) {
|
||||
return props.placeholder
|
||||
}
|
||||
if (props.currentProvider) {
|
||||
const current = props.providers.find(p => p.name === props.currentProvider)
|
||||
return `默认: ${current?.display_name || props.currentProvider}`
|
||||
}
|
||||
return props.placeholder
|
||||
})
|
||||
|
||||
const handleChange = (value: string) => {
|
||||
emit('update:modelValue', value)
|
||||
const selectedProvider = props.providers.find((p) => p.name === value)
|
||||
emit('change', selectedProvider)
|
||||
}
|
||||
|
||||
const handleClear = () => {
|
||||
emit('update:modelValue', '')
|
||||
emit('change', undefined)
|
||||
}
|
||||
</script>
|
||||
|
||||
<style>
|
||||
.llm-selector-popper {
|
||||
min-width: 300px !important;
|
||||
min-width: 320px !important;
|
||||
z-index: 9999 !important;
|
||||
}
|
||||
|
||||
.llm-selector-popper .el-select-dropdown__wrap {
|
||||
max-height: 400px;
|
||||
}
|
||||
|
||||
.llm-selector-popper .el-select-dropdown__item {
|
||||
height: auto;
|
||||
padding: 8px 12px;
|
||||
line-height: 1.5;
|
||||
}
|
||||
</style>
|
||||
|
||||
<style scoped>
|
||||
|
|
@ -93,6 +124,7 @@ const handleChange = (value: string) => {
|
|||
line-height: 1.5;
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.provider-name {
|
||||
|
|
@ -116,5 +148,12 @@ const handleChange = (value: string) => {
|
|||
.current-tag {
|
||||
flex-shrink: 0;
|
||||
margin-left: 8px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.selected-tag {
|
||||
flex-shrink: 0;
|
||||
margin-left: 8px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
</style>
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import { defineStore } from 'pinia'
|
||||
|
||||
// Default tenant ID format: name@ash@year
|
||||
const DEFAULT_TENANT_ID = 'default@ash@2026'
|
||||
|
||||
export const useTenantStore = defineStore('tenant', {
|
||||
state: () => ({
|
||||
currentTenantId: localStorage.getItem('X-Tenant-Id') || 'default'
|
||||
currentTenantId: localStorage.getItem('X-Tenant-Id') || DEFAULT_TENANT_ID
|
||||
}),
|
||||
actions: {
|
||||
setTenant(id: string) {
|
||||
|
|
|
|||
|
|
@ -9,5 +9,6 @@ from app.api.admin.kb import router as kb_router
|
|||
from app.api.admin.llm import router as llm_router
|
||||
from app.api.admin.rag import router as rag_router
|
||||
from app.api.admin.sessions import router as sessions_router
|
||||
from app.api.admin.tenants import router as tenants_router
|
||||
|
||||
__all__ = ["dashboard_router", "embedding_router", "kb_router", "llm_router", "rag_router", "sessions_router"]
|
||||
__all__ = ["dashboard_router", "embedding_router", "kb_router", "llm_router", "rag_router", "sessions_router", "tenants_router"]
|
||||
|
|
|
|||
|
|
@ -37,6 +37,42 @@ class TextChunk:
|
|||
source: str | None = None
|
||||
|
||||
|
||||
def chunk_text_by_lines(
|
||||
text: str,
|
||||
min_line_length: int = 10,
|
||||
source: str | None = None,
|
||||
) -> list[TextChunk]:
|
||||
"""
|
||||
按行分块,每行作为一个独立的检索单元。
|
||||
|
||||
Args:
|
||||
text: 要分块的文本
|
||||
min_line_length: 最小行长度,低于此长度的行会被跳过
|
||||
source: 来源文件路径(可选)
|
||||
|
||||
Returns:
|
||||
分块列表,每个块对应一行文本
|
||||
"""
|
||||
lines = text.split('\n')
|
||||
chunks: list[TextChunk] = []
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
|
||||
if len(line) < min_line_length:
|
||||
continue
|
||||
|
||||
chunks.append(TextChunk(
|
||||
text=line,
|
||||
start_token=i,
|
||||
end_token=i + 1,
|
||||
page=None,
|
||||
source=source,
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def chunk_text_with_tiktoken(
|
||||
text: str,
|
||||
chunk_size: int = 512,
|
||||
|
|
@ -318,8 +354,19 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
|
|||
text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}
|
||||
|
||||
if file_ext in text_extensions or not file_ext:
|
||||
logger.info(f"[INDEX] Treating as text file, decoding with UTF-8")
|
||||
text = content.decode("utf-8", errors="ignore")
|
||||
logger.info(f"[INDEX] Treating as text file, trying multiple encodings")
|
||||
text = None
|
||||
for encoding in ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]:
|
||||
try:
|
||||
text = content.decode(encoding)
|
||||
logger.info(f"[INDEX] Successfully decoded with encoding: {encoding}")
|
||||
break
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
continue
|
||||
|
||||
if text is None:
|
||||
text = content.decode("utf-8", errors="replace")
|
||||
logger.warning(f"[INDEX] Failed to decode with known encodings, using utf-8 with replacement")
|
||||
else:
|
||||
logger.info(f"[INDEX] Binary file detected, will parse with document parser")
|
||||
await kb_service.update_job_status(
|
||||
|
|
@ -374,23 +421,22 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
|
|||
all_chunks: list[TextChunk] = []
|
||||
|
||||
if parse_result and parse_result.pages:
|
||||
logger.info(f"[INDEX] PDF with {len(parse_result.pages)} pages, using tiktoken chunking with page metadata")
|
||||
logger.info(f"[INDEX] PDF with {len(parse_result.pages)} pages, using line-based chunking with page metadata")
|
||||
for page in parse_result.pages:
|
||||
page_chunks = chunk_text_with_tiktoken(
|
||||
page_chunks = chunk_text_by_lines(
|
||||
page.text,
|
||||
chunk_size=512,
|
||||
overlap=100,
|
||||
page=page.page,
|
||||
min_line_length=10,
|
||||
source=filename,
|
||||
)
|
||||
for pc in page_chunks:
|
||||
pc.page = page.page
|
||||
all_chunks.extend(page_chunks)
|
||||
logger.info(f"[INDEX] Total chunks from PDF: {len(all_chunks)}")
|
||||
else:
|
||||
logger.info(f"[INDEX] Using tiktoken chunking without page metadata")
|
||||
all_chunks = chunk_text_with_tiktoken(
|
||||
logger.info(f"[INDEX] Using line-based chunking")
|
||||
all_chunks = chunk_text_by_lines(
|
||||
text,
|
||||
chunk_size=512,
|
||||
overlap=100,
|
||||
min_line_length=10,
|
||||
source=filename,
|
||||
)
|
||||
logger.info(f"[INDEX] Total chunks: {len(all_chunks)}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,330 @@
|
|||
"""
|
||||
Knowledge base management API with RAG optimization features.
|
||||
Reference: rag-optimization/spec.md Section 4.2
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import date
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.services.retrieval import (
|
||||
ChunkMetadata,
|
||||
ChunkMetadataModel,
|
||||
IndexingProgress,
|
||||
IndexingResult,
|
||||
KnowledgeIndexer,
|
||||
MetadataFilter,
|
||||
RetrievalStrategy,
|
||||
get_knowledge_indexer,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/kb", tags=["Knowledge Base"])
|
||||
|
||||
|
||||
class IndexDocumentRequest(BaseModel):
|
||||
"""Request to index a document."""
|
||||
tenant_id: str = Field(..., description="Tenant ID")
|
||||
document_id: str = Field(..., description="Document ID")
|
||||
text: str = Field(..., description="Document text content")
|
||||
metadata: ChunkMetadataModel | None = Field(default=None, description="Document metadata")
|
||||
|
||||
|
||||
class IndexDocumentResponse(BaseModel):
|
||||
"""Response from document indexing."""
|
||||
success: bool
|
||||
total_chunks: int
|
||||
indexed_chunks: int
|
||||
failed_chunks: int
|
||||
elapsed_seconds: float
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class IndexingProgressResponse(BaseModel):
|
||||
"""Response with current indexing progress."""
|
||||
total_chunks: int
|
||||
processed_chunks: int
|
||||
failed_chunks: int
|
||||
progress_percent: int
|
||||
elapsed_seconds: float
|
||||
current_document: str
|
||||
|
||||
|
||||
class MetadataFilterRequest(BaseModel):
|
||||
"""Request for metadata filtering."""
|
||||
categories: list[str] | None = None
|
||||
target_audiences: list[str] | None = None
|
||||
departments: list[str] | None = None
|
||||
valid_only: bool = True
|
||||
min_priority: int | None = None
|
||||
keywords: list[str] | None = None
|
||||
|
||||
|
||||
class RetrieveRequest(BaseModel):
|
||||
"""Request for knowledge retrieval."""
|
||||
tenant_id: str = Field(..., description="Tenant ID")
|
||||
query: str = Field(..., description="Search query")
|
||||
top_k: int = Field(default=10, ge=1, le=50, description="Number of results")
|
||||
filters: MetadataFilterRequest | None = Field(default=None, description="Metadata filters")
|
||||
strategy: RetrievalStrategy = Field(default=RetrievalStrategy.HYBRID, description="Retrieval strategy")
|
||||
|
||||
|
||||
class RetrieveResponse(BaseModel):
|
||||
"""Response from knowledge retrieval."""
|
||||
hits: list[dict[str, Any]]
|
||||
total_hits: int
|
||||
max_score: float
|
||||
is_insufficient: bool
|
||||
diagnostics: dict[str, Any]
|
||||
|
||||
|
||||
class MetadataOptionsResponse(BaseModel):
|
||||
"""Response with available metadata options."""
|
||||
categories: list[str]
|
||||
departments: list[str]
|
||||
target_audiences: list[str]
|
||||
priorities: list[int]
|
||||
|
||||
|
||||
@router.post("/index", response_model=IndexDocumentResponse)
|
||||
async def index_document(
|
||||
request: IndexDocumentRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""
|
||||
Index a document with optimized embedding.
|
||||
|
||||
Features:
|
||||
- Task prefixes (search_document:) for document embedding
|
||||
- Multi-dimensional vectors (256/512/768)
|
||||
- Metadata support
|
||||
"""
|
||||
try:
|
||||
index = get_knowledge_indexer()
|
||||
|
||||
chunk_metadata = None
|
||||
if request.metadata:
|
||||
chunk_metadata = ChunkMetadata(
|
||||
category=request.metadata.category,
|
||||
subcategory=request.metadata.subcategory,
|
||||
target_audience=request.metadata.target_audience,
|
||||
source_doc=request.metadata.source_doc,
|
||||
source_url=request.metadata.source_url,
|
||||
department=request.metadata.department,
|
||||
priority=request.metadata.priority,
|
||||
keywords=request.metadata.keywords,
|
||||
)
|
||||
|
||||
result = await index.index_document(
|
||||
tenant_id=request.tenant_id,
|
||||
document_id=request.document_id,
|
||||
text=request.text,
|
||||
metadata=chunk_metadata,
|
||||
)
|
||||
|
||||
return IndexDocumentResponse(
|
||||
success=result.success,
|
||||
total_chunks=result.total_chunks,
|
||||
indexed_chunks=result.indexed_chunks,
|
||||
failed_chunks=result.failed_chunks,
|
||||
elapsed_seconds=result.elapsed_seconds,
|
||||
error_message=result.error_message,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[KB-API] Failed to index document: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"索引失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/index/progress", response_model=IndexingProgressResponse | None)
|
||||
async def get_indexing_progress():
|
||||
"""Get current indexing progress."""
|
||||
try:
|
||||
index = get_knowledge_indexer()
|
||||
progress = index.get_progress()
|
||||
|
||||
if progress is None:
|
||||
return None
|
||||
|
||||
return IndexingProgressResponse(
|
||||
total_chunks=progress.total_chunks,
|
||||
processed_chunks=progress.processed_chunks,
|
||||
failed_chunks=progress.failed_chunks,
|
||||
progress_percent=progress.progress_percent,
|
||||
elapsed_seconds=progress.elapsed_seconds,
|
||||
current_document=progress.current_document,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[KB-API] Failed to get progress: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取进度失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/retrieve", response_model=RetrieveResponse)
|
||||
async def retrieve_knowledge(request: RetrieveRequest):
|
||||
"""
|
||||
Retrieve knowledge using optimized RAG.
|
||||
|
||||
Strategies:
|
||||
- vector: Simple vector search
|
||||
- bm25: BM25 keyword search
|
||||
- hybrid: RRF combination of vector + BM25 (default)
|
||||
- two_stage: Two-stage retrieval with Matryoshka dimensions
|
||||
"""
|
||||
try:
|
||||
from app.services.retrieval.optimized_retriever import get_optimized_retriever
|
||||
from app.services.retrieval.base import RetrievalContext
|
||||
|
||||
retriever = await get_optimized_retriever()
|
||||
|
||||
metadata_filter = None
|
||||
if request.filters:
|
||||
filter_dict = request.filters.model_dump(exclude_none=True)
|
||||
metadata_filter = MetadataFilter(**filter_dict)
|
||||
|
||||
ctx = RetrievalContext(
|
||||
tenant_id=request.tenant_id,
|
||||
query=request.query,
|
||||
)
|
||||
|
||||
if metadata_filter:
|
||||
ctx.metadata = {"filter": metadata_filter.to_qdrant_filter()}
|
||||
|
||||
result = await retriever.retrieve(ctx)
|
||||
|
||||
return RetrieveResponse(
|
||||
hits=[
|
||||
{
|
||||
"text": hit.text,
|
||||
"score": hit.score,
|
||||
"source": hit.source,
|
||||
"metadata": hit.metadata,
|
||||
}
|
||||
for hit in result.hits
|
||||
],
|
||||
total_hits=result.hit_count,
|
||||
max_score=result.max_score,
|
||||
is_insufficient=result.diagnostics.get("is_insufficient", False),
|
||||
diagnostics=result.diagnostics or {},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[KB-API] Failed to retrieve: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"检索失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metadata/options", response_model=MetadataOptionsResponse)
|
||||
async def get_metadata_options():
|
||||
"""
|
||||
Get available metadata options for filtering.
|
||||
These would typically be loaded from a database.
|
||||
"""
|
||||
try:
|
||||
return MetadataOptionsResponse(
|
||||
categories=[
|
||||
"课程咨询",
|
||||
"考试政策",
|
||||
"学籍管理",
|
||||
"奖助学金",
|
||||
"宿舍管理",
|
||||
"校园服务",
|
||||
"就业指导",
|
||||
"其他",
|
||||
],
|
||||
departments=[
|
||||
"教务处",
|
||||
"学生处",
|
||||
"财务处",
|
||||
"后勤处",
|
||||
"就业指导中心",
|
||||
"图书馆",
|
||||
"信息中心",
|
||||
],
|
||||
target_audiences=[
|
||||
"本科生",
|
||||
"研究生",
|
||||
"留学生",
|
||||
"新生",
|
||||
"毕业生",
|
||||
"教职工",
|
||||
],
|
||||
priorities=list(range(1, 11)),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[KB-API] Failed to get metadata options: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取选项失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reindex")
|
||||
async def reindex_all(
|
||||
tenant_id: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""
|
||||
Reindex all documents for a tenant with optimized embedding.
|
||||
This would typically read from the documents table and reindex.
|
||||
"""
|
||||
try:
|
||||
from app.models.entities import Document, DocumentStatus
|
||||
|
||||
stmt = select(Document).where(
|
||||
Document.tenant_id == tenant_id,
|
||||
Document.status == DocumentStatus.COMPLETED.value,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
documents = result.scalars().all()
|
||||
|
||||
index = get_knowledge_indexer()
|
||||
|
||||
total_indexed = 0
|
||||
total_failed = 0
|
||||
|
||||
for doc in documents:
|
||||
if doc.file_path:
|
||||
import os
|
||||
if os.path.exists(doc.file_path):
|
||||
with open(doc.file_path, 'r', encoding='utf-8') as f:
|
||||
text = f.read()
|
||||
|
||||
result = await index.index_document(
|
||||
tenant_id=tenant_id,
|
||||
document_id=str(doc.id),
|
||||
text=text,
|
||||
)
|
||||
|
||||
total_indexed += result.indexed_chunks
|
||||
total_failed += result.failed_chunks
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"total_documents": len(documents),
|
||||
"total_indexed": total_indexed,
|
||||
"total_failed": total_failed,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[KB-API] Failed to reindex: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"重新索引失败: {str(e)}"
|
||||
)
|
||||
|
|
@ -14,9 +14,11 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from app.core.config import get_settings
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
from app.core.prompts import format_evidence_for_prompt, build_user_prompt_with_evidence
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models import ErrorResponse
|
||||
from app.services.retrieval.vector_retriever import get_vector_retriever
|
||||
from app.services.retrieval.optimized_retriever import get_optimized_retriever
|
||||
from app.services.retrieval.base import RetrievalContext
|
||||
from app.services.llm.factory import get_llm_config_manager
|
||||
|
||||
|
|
@ -91,7 +93,8 @@ async def run_rag_experiment(
|
|||
threshold = request.score_threshold or settings.rag_score_threshold
|
||||
|
||||
try:
|
||||
retriever = await get_vector_retriever()
|
||||
# Use optimized retriever with RAG enhancements
|
||||
retriever = await get_optimized_retriever()
|
||||
|
||||
retrieval_ctx = RetrievalContext(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -199,7 +202,8 @@ async def run_rag_experiment_stream(
|
|||
|
||||
async def event_generator():
|
||||
try:
|
||||
retriever = await get_vector_retriever()
|
||||
# Use optimized retriever with RAG enhancements
|
||||
retriever = await get_optimized_retriever()
|
||||
|
||||
retrieval_ctx = RetrievalContext(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -223,6 +227,11 @@ async def run_rag_experiment_stream(
|
|||
|
||||
final_prompt = _build_final_prompt(request.query, retrieval_results)
|
||||
|
||||
logger.info(f"[AC-ASA-20] ========== RAG LAB STREAM FULL PROMPT ==========")
|
||||
logger.info(f"[AC-ASA-20] Prompt length: {len(final_prompt)}")
|
||||
logger.info(f"[AC-ASA-20] Prompt content:\n{final_prompt}")
|
||||
logger.info(f"[AC-ASA-20] ==============================================")
|
||||
|
||||
yield f"event: retrieval\ndata: {json.dumps({'results': retrieval_results, 'count': len(retrieval_results)})}\n\n"
|
||||
|
||||
yield f"event: prompt\ndata: {json.dumps({'prompt': final_prompt})}\n\n"
|
||||
|
|
@ -267,6 +276,11 @@ async def _generate_ai_response(
|
|||
"""
|
||||
import time
|
||||
|
||||
logger.info(f"[AC-ASA-19] ========== RAG LAB FULL PROMPT ==========")
|
||||
logger.info(f"[AC-ASA-19] Prompt length: {len(prompt)}")
|
||||
logger.info(f"[AC-ASA-19] Prompt content:\n{prompt}")
|
||||
logger.info(f"[AC-ASA-19] ==========================================")
|
||||
|
||||
try:
|
||||
manager = get_llm_config_manager()
|
||||
client = manager.get_client()
|
||||
|
|
@ -297,25 +311,10 @@ async def _generate_ai_response(
|
|||
def _build_final_prompt(query: str, retrieval_results: list[dict]) -> str:
|
||||
"""
|
||||
Build the final prompt from query and retrieval results.
|
||||
Uses shared prompt configuration for consistency with orchestrator.
|
||||
"""
|
||||
if not retrieval_results:
|
||||
return f"""用户问题:{query}
|
||||
|
||||
未找到相关检索结果,请基于通用知识回答用户问题。"""
|
||||
|
||||
evidence_text = "\n".join([
|
||||
f"{i+1}. [Score: {hit['score']:.2f}] {hit['content'][:200]}{'...' if len(hit['content']) > 200 else ''}"
|
||||
for i, hit in enumerate(retrieval_results[:5])
|
||||
])
|
||||
|
||||
return f"""基于以下检索到的信息,作为一个回答简洁精准的客服,回答用户问题:
|
||||
|
||||
用户问题:{query}
|
||||
|
||||
检索结果:
|
||||
{evidence_text}
|
||||
|
||||
请基于以上信息生成专业、准确的回答,注意输出内容应该格式整齐,不包含json符号等。"""
|
||||
evidence_text = format_evidence_for_prompt(retrieval_results, max_results=5, max_content_length=500)
|
||||
return build_user_prompt_with_evidence(query, evidence_text)
|
||||
|
||||
|
||||
def _get_fallback_results(query: str) -> list[dict]:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,78 @@
|
|||
"""
|
||||
Tenant management endpoints.
|
||||
Provides tenant list and management functionality.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
from app.core.middleware import parse_tenant_id
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models import ErrorResponse
|
||||
from app.models.entities import Tenant
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/tenants", tags=["Tenants"])
|
||||
|
||||
|
||||
def get_current_tenant_id() -> str:
|
||||
"""Dependency to get current tenant ID or raise exception."""
|
||||
tenant_id = get_tenant_id()
|
||||
if not tenant_id:
|
||||
raise MissingTenantIdException()
|
||||
return tenant_id
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
operation_id="listTenants",
|
||||
summary="List all tenants",
|
||||
description="Get a list of all tenants from the system.",
|
||||
responses={
|
||||
200: {"description": "List of tenants"},
|
||||
401: {"description": "Unauthorized", "model": ErrorResponse},
|
||||
403: {"description": "Forbidden", "model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
async def list_tenants(
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
Get a list of all tenants from the tenants table.
|
||||
Returns tenant ID and display name (first part of tenant_id).
|
||||
"""
|
||||
logger.info("Getting all tenants")
|
||||
|
||||
# Get all tenants from tenants table
|
||||
stmt = select(Tenant).order_by(Tenant.created_at.desc())
|
||||
result = await session.execute(stmt)
|
||||
tenants = result.scalars().all()
|
||||
|
||||
# Format tenant list with display name
|
||||
tenant_list = []
|
||||
for tenant in tenants:
|
||||
name, year = parse_tenant_id(tenant.tenant_id)
|
||||
tenant_list.append({
|
||||
"id": tenant.tenant_id,
|
||||
"name": f"{name} ({year})",
|
||||
"displayName": name,
|
||||
"year": year,
|
||||
"createdAt": tenant.created_at.isoformat() if tenant.created_at else None,
|
||||
})
|
||||
|
||||
logger.info(f"Found {len(tenant_list)} tenants: {[t['id'] for t in tenant_list]}")
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"tenants": tenant_list,
|
||||
"total": len(tenant_list)
|
||||
}
|
||||
)
|
||||
|
|
@ -9,18 +9,43 @@ from typing import Annotated, Any
|
|||
from fastapi import APIRouter, Depends, Header, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.middleware import get_response_mode, is_sse_request
|
||||
from app.core.sse import SSEStateMachine, create_error_event
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models import ChatRequest, ChatResponse, ErrorResponse
|
||||
from app.services.orchestrator import OrchestratorService, get_orchestrator_service
|
||||
from app.services.memory import MemoryService
|
||||
from app.services.orchestrator import OrchestratorService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["AI Chat"])
|
||||
|
||||
|
||||
async def get_orchestrator_service_with_memory(
|
||||
session: Annotated[AsyncSession, Depends(get_session)]
|
||||
) -> OrchestratorService:
|
||||
"""
|
||||
[AC-AISVC-13] Create orchestrator service with memory service and LLM client.
|
||||
Ensures each request has a fresh MemoryService with database session.
|
||||
"""
|
||||
from app.services.llm.factory import get_llm_config_manager
|
||||
from app.services.retrieval.optimized_retriever import get_optimized_retriever
|
||||
|
||||
memory_service = MemoryService(session)
|
||||
llm_config_manager = get_llm_config_manager()
|
||||
llm_client = llm_config_manager.get_client()
|
||||
retriever = await get_optimized_retriever()
|
||||
|
||||
return OrchestratorService(
|
||||
llm_client=llm_client,
|
||||
memory_service=memory_service,
|
||||
retriever=retriever,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/ai/chat",
|
||||
operation_id="generateReply",
|
||||
|
|
@ -49,7 +74,7 @@ async def generate_reply(
|
|||
request: Request,
|
||||
chat_request: ChatRequest,
|
||||
accept: Annotated[str | None, Header()] = None,
|
||||
orchestrator: OrchestratorService = Depends(get_orchestrator_service),
|
||||
orchestrator: OrchestratorService = Depends(get_orchestrator_service_with_memory),
|
||||
) -> Any:
|
||||
"""
|
||||
[AC-AISVC-06] Generate AI reply with automatic response mode switching.
|
||||
|
|
|
|||
|
|
@ -44,9 +44,16 @@ class Settings(BaseSettings):
|
|||
ollama_embedding_model: str = "nomic-embed-text"
|
||||
|
||||
rag_top_k: int = 5
|
||||
rag_score_threshold: float = 0.3
|
||||
rag_score_threshold: float = 0.01
|
||||
rag_min_hits: int = 1
|
||||
rag_max_evidence_tokens: int = 2000
|
||||
|
||||
rag_two_stage_enabled: bool = True
|
||||
rag_two_stage_expand_factor: int = 10
|
||||
rag_hybrid_enabled: bool = True
|
||||
rag_rrf_k: int = 60
|
||||
rag_vector_weight: float = 0.7
|
||||
rag_bm25_weight: float = 0.3
|
||||
|
||||
confidence_low_threshold: float = 0.5
|
||||
confidence_high_threshold: float = 0.8
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ Middleware for AI Service.
|
|||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Request, Response, status
|
||||
|
|
@ -19,11 +20,32 @@ TENANT_ID_HEADER = "X-Tenant-Id"
|
|||
ACCEPT_HEADER = "Accept"
|
||||
SSE_CONTENT_TYPE = "text/event-stream"
|
||||
|
||||
# Tenant ID format: name@ash@year (e.g., szmp@ash@2026)
|
||||
TENANT_ID_PATTERN = re.compile(r'^[^@]+@ash@\d{4}$')
|
||||
|
||||
|
||||
def validate_tenant_id_format(tenant_id: str) -> bool:
|
||||
"""
|
||||
[AC-AISVC-10] Validate tenant ID format: name@ash@year
|
||||
Examples: szmp@ash@2026, abc123@ash@2025
|
||||
"""
|
||||
return bool(TENANT_ID_PATTERN.match(tenant_id))
|
||||
|
||||
|
||||
def parse_tenant_id(tenant_id: str) -> tuple[str, str]:
|
||||
"""
|
||||
[AC-AISVC-10] Parse tenant ID into name and year.
|
||||
Returns: (name, year)
|
||||
"""
|
||||
parts = tenant_id.split('@')
|
||||
return parts[0], parts[2]
|
||||
|
||||
|
||||
class TenantContextMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
[AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header.
|
||||
Injects tenant context into request state for downstream processing.
|
||||
Validates tenant ID format and auto-creates tenant if not exists.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
|
|
@ -44,10 +66,31 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
|
|||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
set_tenant_context(tenant_id.strip())
|
||||
request.state.tenant_id = tenant_id.strip()
|
||||
tenant_id = tenant_id.strip()
|
||||
|
||||
logger.info(f"[AC-AISVC-10] Tenant context set: tenant_id={tenant_id.strip()}")
|
||||
# Validate tenant ID format
|
||||
if not validate_tenant_id_format(tenant_id):
|
||||
logger.warning(f"[AC-AISVC-10] Invalid tenant ID format: {tenant_id}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content=ErrorResponse(
|
||||
code=ErrorCode.INVALID_TENANT_ID.value,
|
||||
message="Invalid tenant ID format. Expected: name@ash@year (e.g., szmp@ash@2026)",
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
# Auto-create tenant if not exists (for admin endpoints)
|
||||
if request.url.path.startswith("/admin/") or request.url.path.startswith("/ai/"):
|
||||
try:
|
||||
await self._ensure_tenant_exists(request, tenant_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-10] Failed to ensure tenant exists: {e}")
|
||||
# Continue processing even if tenant creation fails
|
||||
|
||||
set_tenant_context(tenant_id)
|
||||
request.state.tenant_id = tenant_id
|
||||
|
||||
logger.info(f"[AC-AISVC-10] Tenant context set: tenant_id={tenant_id}")
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
|
|
@ -56,6 +99,39 @@ class TenantContextMiddleware(BaseHTTPMiddleware):
|
|||
|
||||
return response
|
||||
|
||||
async def _ensure_tenant_exists(self, request: Request, tenant_id: str) -> None:
|
||||
"""
|
||||
[AC-AISVC-10] Ensure tenant exists in database, create if not.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import async_session_maker
|
||||
from app.models.entities import Tenant
|
||||
|
||||
name, year = parse_tenant_id(tenant_id)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
# Check if tenant exists
|
||||
stmt = select(Tenant).where(Tenant.tenant_id == tenant_id)
|
||||
result = await session.execute(stmt)
|
||||
existing_tenant = result.scalar_one_or_none()
|
||||
|
||||
if existing_tenant:
|
||||
logger.debug(f"[AC-AISVC-10] Tenant already exists: {tenant_id}")
|
||||
return
|
||||
|
||||
# Create new tenant
|
||||
new_tenant = Tenant(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
year=year,
|
||||
)
|
||||
session.add(new_tenant)
|
||||
await session.commit()
|
||||
|
||||
logger.info(f"[AC-AISVC-10] Auto-created new tenant: {tenant_id} (name={name}, year={year})")
|
||||
|
||||
|
||||
def is_sse_request(request: Request) -> bool:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,114 @@
|
|||
"""
|
||||
Shared prompt templates for AI Service.
|
||||
Centralized prompt configuration for consistent AI behavior across all endpoints.
|
||||
"""
|
||||
|
||||
SYSTEM_PROMPT = """你是一名经验丰富的客服专员,名字叫"小N"。请遵循以下准则:
|
||||
像朋友一样聊天:使用自然、口语化的中文,语气亲切有礼貌,就像在和一位老朋友聊天。多使用"咱们"、"您看"、"是不是"这样拉近距离的表达。
|
||||
避免生硬罗列:介绍多个信息时,不要用"第一、第二"或分点列举,而是用流畅的段落把它们串联起来。比如:"我们这边有几类课程挺受欢迎的,一个是……另外还有……您对哪方面更感兴趣呀?"
|
||||
主动引导对话:不要一次性把知识库里的所有内容都倒出来。可以先简单概括,然后问问客户具体想了解什么,比如:"您是想提升职场英语,还是有留学方面的打算呢?"
|
||||
委婉处理未知:如果知识库里没有答案,别直接说"不知道",可以说:"您稍等,我帮您确认一下"、"这个问题我需要问问同事,马上给您回复"。
|
||||
隐藏机器身份:绝不要提及"作为AI"、"根据算法"等词,也避免使用过于官方的术语。就像普通人一样用"我"、"我们"。
|
||||
多用语气词和情感表达:适当加入"呢"、"哦"、"哈"、"呀"等语气词,根据情况使用表情符号(😊)。表达理解和关心,比如"我明白您的意思"、"别着急,我们一起看看"。
|
||||
保持简洁但有温度:回答一般控制在3-5句话,如果需要详细说明,可以分成小段落,但每段都要口语化,读起来不累。
|
||||
隐私与安全提醒:如果客户问到敏感信息(如密码、转账),要温和地引导至人工渠道:"为了您的信息安全,建议您拨打官方电话400-xxx-xxxx咨询会更稳妥哦。"""
|
||||
|
||||
|
||||
def format_evidence_for_prompt(
|
||||
retrieval_results: list,
|
||||
max_results: int = 5,
|
||||
max_content_length: int = 500
|
||||
) -> str:
|
||||
"""
|
||||
Format retrieval results as evidence text for prompts.
|
||||
|
||||
Args:
|
||||
retrieval_results: List of retrieval hits. Can be:
|
||||
- dict format: {'content', 'score', 'source', 'metadata'}
|
||||
- RetrievalHit object: with .text, .score, .source, .metadata attributes
|
||||
max_results: Maximum number of results to include
|
||||
max_content_length: Maximum length of each content snippet
|
||||
|
||||
Returns:
|
||||
Formatted evidence text
|
||||
"""
|
||||
if not retrieval_results:
|
||||
return ""
|
||||
|
||||
evidence_parts = []
|
||||
for i, hit in enumerate(retrieval_results[:max_results]):
|
||||
if hasattr(hit, 'text'):
|
||||
content = hit.text
|
||||
score = hit.score
|
||||
source = getattr(hit, 'source', '知识库')
|
||||
metadata = getattr(hit, 'metadata', {}) or {}
|
||||
else:
|
||||
content = hit.get('content', '')
|
||||
score = hit.get('score', 0)
|
||||
source = hit.get('source', '知识库')
|
||||
metadata = hit.get('metadata', {}) or {}
|
||||
|
||||
if len(content) > max_content_length:
|
||||
content = content[:max_content_length] + '...'
|
||||
|
||||
nested_meta = metadata.get('metadata', {})
|
||||
source_doc = nested_meta.get('source_doc', source) if nested_meta else source
|
||||
category = nested_meta.get('category', '') if nested_meta else ''
|
||||
department = nested_meta.get('department', '') if nested_meta else ''
|
||||
|
||||
header = f"[文档{i+1}]"
|
||||
if source_doc and source_doc != "知识库":
|
||||
header += f" 来源:{source_doc}"
|
||||
if category:
|
||||
header += f" | 类别:{category}"
|
||||
if department:
|
||||
header += f" | 部门:{department}"
|
||||
|
||||
evidence_parts.append(f"{header}\n相关度:{score:.2f}\n内容:{content}")
|
||||
|
||||
return "\n\n".join(evidence_parts)
|
||||
|
||||
|
||||
def build_system_prompt_with_evidence(evidence_text: str) -> str:
|
||||
"""
|
||||
Build system prompt with knowledge base evidence.
|
||||
|
||||
Args:
|
||||
evidence_text: Formatted evidence from retrieval results
|
||||
|
||||
Returns:
|
||||
Complete system prompt
|
||||
"""
|
||||
if not evidence_text:
|
||||
return SYSTEM_PROMPT
|
||||
|
||||
return f"""{SYSTEM_PROMPT}
|
||||
|
||||
知识库参考内容:
|
||||
{evidence_text}"""
|
||||
|
||||
|
||||
def build_user_prompt_with_evidence(query: str, evidence_text: str) -> str:
|
||||
"""
|
||||
Build user prompt with knowledge base evidence (for single-message format).
|
||||
|
||||
Args:
|
||||
query: User's question
|
||||
evidence_text: Formatted evidence from retrieval results
|
||||
|
||||
Returns:
|
||||
Complete user prompt
|
||||
"""
|
||||
if not evidence_text:
|
||||
return f"""用户问题:{query}
|
||||
|
||||
未找到相关检索结果,请基于通用知识回答用户问题。"""
|
||||
|
||||
return f"""【系统指令】
|
||||
{SYSTEM_PROMPT}
|
||||
|
||||
【知识库内容】
|
||||
{evidence_text}
|
||||
|
||||
【用户问题】
|
||||
{query}"""
|
||||
|
|
@ -1,13 +1,14 @@
|
|||
"""
|
||||
Qdrant client for AI Service.
|
||||
[AC-AISVC-10] Vector database client with tenant-isolated collection management.
|
||||
Supports multi-dimensional vectors for Matryoshka representation learning.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from qdrant_client.models import Distance, PointStruct, VectorParams
|
||||
from qdrant_client.models import Distance, PointStruct, VectorParams, MultiVectorConfig
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ class QdrantClient:
|
|||
"""
|
||||
[AC-AISVC-10] Qdrant client with tenant-isolated collection management.
|
||||
Collection naming: kb_{tenantId} for tenant isolation.
|
||||
Supports multi-dimensional vectors (256/512/768) for Matryoshka retrieval.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
|
@ -45,13 +47,15 @@ class QdrantClient:
|
|||
"""
|
||||
[AC-AISVC-10] Get collection name for a tenant.
|
||||
Naming convention: kb_{tenantId}
|
||||
Replaces @ with _ to ensure valid collection names.
|
||||
"""
|
||||
return f"{self._collection_prefix}{tenant_id}"
|
||||
safe_tenant_id = tenant_id.replace('@', '_')
|
||||
return f"{self._collection_prefix}{safe_tenant_id}"
|
||||
|
||||
async def ensure_collection_exists(self, tenant_id: str) -> bool:
|
||||
async def ensure_collection_exists(self, tenant_id: str, use_multi_vector: bool = True) -> bool:
|
||||
"""
|
||||
[AC-AISVC-10] Ensure collection exists for tenant.
|
||||
Note: MVP uses pre-provisioned collections, this is for development/testing.
|
||||
Supports multi-dimensional vectors for Matryoshka retrieval.
|
||||
"""
|
||||
client = await self.get_client()
|
||||
collection_name = self.get_collection_name(tenant_id)
|
||||
|
|
@ -61,15 +65,34 @@ class QdrantClient:
|
|||
exists = any(c.name == collection_name for c in collections.collections)
|
||||
|
||||
if not exists:
|
||||
await client.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=VectorParams(
|
||||
if use_multi_vector:
|
||||
vectors_config = {
|
||||
"full": VectorParams(
|
||||
size=768,
|
||||
distance=Distance.COSINE,
|
||||
),
|
||||
"dim_256": VectorParams(
|
||||
size=256,
|
||||
distance=Distance.COSINE,
|
||||
),
|
||||
"dim_512": VectorParams(
|
||||
size=512,
|
||||
distance=Distance.COSINE,
|
||||
),
|
||||
}
|
||||
else:
|
||||
vectors_config = VectorParams(
|
||||
size=self._vector_size,
|
||||
distance=Distance.COSINE,
|
||||
),
|
||||
)
|
||||
|
||||
await client.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=vectors_config,
|
||||
)
|
||||
logger.info(
|
||||
f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id}"
|
||||
f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id} "
|
||||
f"with multi_vector={use_multi_vector}"
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
|
|
@ -100,44 +123,160 @@ class QdrantClient:
|
|||
logger.error(f"[AC-AISVC-10] Error upserting vectors: {e}")
|
||||
return False
|
||||
|
||||
async def upsert_multi_vector(
|
||||
self,
|
||||
tenant_id: str,
|
||||
points: list[dict[str, Any]],
|
||||
) -> bool:
|
||||
"""
|
||||
Upsert points with multi-dimensional vectors.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
points: List of points with format:
|
||||
{
|
||||
"id": str | int,
|
||||
"vector": {
|
||||
"full": [768 floats],
|
||||
"dim_256": [256 floats],
|
||||
"dim_512": [512 floats],
|
||||
},
|
||||
"payload": dict
|
||||
}
|
||||
"""
|
||||
client = await self.get_client()
|
||||
collection_name = self.get_collection_name(tenant_id)
|
||||
|
||||
try:
|
||||
qdrant_points = []
|
||||
for p in points:
|
||||
point = PointStruct(
|
||||
id=p["id"],
|
||||
vector=p["vector"],
|
||||
payload=p.get("payload", {}),
|
||||
)
|
||||
qdrant_points.append(point)
|
||||
|
||||
await client.upsert(
|
||||
collection_name=collection_name,
|
||||
points=qdrant_points,
|
||||
)
|
||||
logger.info(
|
||||
f"[RAG-OPT] Upserted {len(points)} multi-vector points for tenant={tenant_id}"
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[RAG-OPT] Error upserting multi-vectors: {e}")
|
||||
return False
|
||||
|
||||
async def search(
|
||||
self,
|
||||
tenant_id: str,
|
||||
query_vector: list[float],
|
||||
limit: int = 5,
|
||||
score_threshold: float | None = None,
|
||||
vector_name: str = "full",
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
[AC-AISVC-10] Search vectors in tenant's collection.
|
||||
Returns results with score >= score_threshold if specified.
|
||||
Searches both old format (with @) and new format (with _) for backward compatibility.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
query_vector: Query vector for similarity search
|
||||
limit: Maximum number of results
|
||||
score_threshold: Minimum score threshold for results
|
||||
vector_name: Name of the vector to search (for multi-vector collections)
|
||||
Default is "full" for 768-dim vectors in Matryoshka setup.
|
||||
"""
|
||||
client = await self.get_client()
|
||||
collection_name = self.get_collection_name(tenant_id)
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, "
|
||||
f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}"
|
||||
)
|
||||
|
||||
collection_names = [self.get_collection_name(tenant_id)]
|
||||
if '@' in tenant_id:
|
||||
old_format = f"{self._collection_prefix}{tenant_id}"
|
||||
new_format = f"{self._collection_prefix}{tenant_id.replace('@', '_')}"
|
||||
collection_names = [new_format, old_format]
|
||||
|
||||
logger.info(f"[AC-AISVC-10] Will search in collections: {collection_names}")
|
||||
|
||||
all_hits = []
|
||||
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
logger.info(f"[AC-AISVC-10] Searching in collection: {collection_name}")
|
||||
|
||||
try:
|
||||
results = await client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=(vector_name, query_vector),
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as e:
|
||||
if "vector name" in str(e).lower() or "Not existing vector" in str(e):
|
||||
logger.info(
|
||||
f"[AC-AISVC-10] Collection {collection_name} doesn't have vector named '{vector_name}', "
|
||||
f"trying without vector name (single-vector mode)"
|
||||
)
|
||||
results = await client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_vector,
|
||||
limit=limit,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-10] Collection {collection_name} returned {len(results)} raw results"
|
||||
)
|
||||
|
||||
try:
|
||||
results = await client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_vector,
|
||||
limit=limit,
|
||||
hits = [
|
||||
{
|
||||
"id": str(result.id),
|
||||
"score": result.score,
|
||||
"payload": result.payload or {},
|
||||
}
|
||||
for result in results
|
||||
if score_threshold is None or result.score >= score_threshold
|
||||
]
|
||||
all_hits.extend(hits)
|
||||
|
||||
if hits:
|
||||
logger.info(
|
||||
f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}"
|
||||
)
|
||||
for i, h in enumerate(hits[:3]):
|
||||
logger.debug(
|
||||
f"[AC-AISVC-10] Hit {i+1}: id={h['id']}, score={h['score']:.4f}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
all_hits = sorted(all_hits, key=lambda x: x["score"], reverse=True)[:limit]
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
if len(all_hits) == 0:
|
||||
logger.warning(
|
||||
f"[AC-AISVC-10] No results found! tenant={tenant_id}, "
|
||||
f"collections_tried={collection_names}, limit={limit}"
|
||||
)
|
||||
|
||||
hits = [
|
||||
{
|
||||
"id": str(result.id),
|
||||
"score": result.score,
|
||||
"payload": result.payload or {},
|
||||
}
|
||||
for result in results
|
||||
if score_threshold is None or result.score >= score_threshold
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-10] Search returned {len(hits)} results for tenant={tenant_id}"
|
||||
)
|
||||
return hits
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-10] Error searching vectors: {e}")
|
||||
return []
|
||||
|
||||
return all_hits
|
||||
|
||||
async def delete_collection(self, tenant_id: str) -> bool:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.api import chat_router, health_router
|
||||
from app.api.admin import dashboard_router, embedding_router, kb_router, llm_router, rag_router, sessions_router
|
||||
from app.api.admin import dashboard_router, embedding_router, kb_router, llm_router, rag_router, sessions_router, tenants_router
|
||||
from app.api.admin.kb_optimized import router as kb_optimized_router
|
||||
from app.core.config import get_settings
|
||||
from app.core.database import close_db, init_db
|
||||
from app.core.exceptions import (
|
||||
|
|
@ -115,9 +116,11 @@ app.include_router(chat_router)
|
|||
app.include_router(dashboard_router)
|
||||
app.include_router(embedding_router)
|
||||
app.include_router(kb_router)
|
||||
app.include_router(kb_optimized_router)
|
||||
app.include_router(llm_router)
|
||||
app.include_router(rag_router)
|
||||
app.include_router(sessions_router)
|
||||
app.include_router(tenants_router)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ class ChatResponse(BaseModel):
|
|||
class ErrorCode(str, Enum):
|
||||
INVALID_REQUEST = "INVALID_REQUEST"
|
||||
MISSING_TENANT_ID = "MISSING_TENANT_ID"
|
||||
INVALID_TENANT_ID = "INVALID_TENANT_ID"
|
||||
INTERNAL_ERROR = "INTERNAL_ERROR"
|
||||
SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
|
||||
TIMEOUT = "TIMEOUT"
|
||||
|
|
|
|||
|
|
@ -102,6 +102,22 @@ class SessionStatus(str, Enum):
|
|||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class Tenant(SQLModel, table=True):
|
||||
"""
|
||||
[AC-AISVC-10] Tenant entity for storing tenant information.
|
||||
Tenant ID format: name@ash@year (e.g., szmp@ash@2026)
|
||||
"""
|
||||
|
||||
__tablename__ = "tenants"
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
tenant_id: str = Field(..., description="Full tenant ID (format: name@ash@year)", unique=True, index=True)
|
||||
name: str = Field(..., description="Tenant display name (first part of tenant_id)")
|
||||
year: str = Field(..., description="Year part from tenant_id")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
|
||||
|
||||
class KnowledgeBase(SQLModel, table=True):
|
||||
"""
|
||||
[AC-ASA-01] Knowledge base entity with tenant isolation.
|
||||
|
|
|
|||
|
|
@ -15,17 +15,39 @@ from app.services.document.base import (
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ENCODINGS_TO_TRY = ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]
|
||||
|
||||
|
||||
class TextParser(DocumentParser):
|
||||
"""
|
||||
Parser for plain text files.
|
||||
[AC-AISVC-33] Direct text extraction.
|
||||
[AC-AISVC-33] Direct text extraction with multiple encoding support.
|
||||
"""
|
||||
|
||||
def __init__(self, encoding: str = "utf-8", **kwargs: Any):
|
||||
self._encoding = encoding
|
||||
self._extra_config = kwargs
|
||||
|
||||
def _try_encodings(self, path: Path) -> tuple[str, str]:
|
||||
"""
|
||||
Try multiple encodings to read the file.
|
||||
Returns: (text, encoding_used)
|
||||
"""
|
||||
for enc in ENCODINGS_TO_TRY:
|
||||
try:
|
||||
with open(path, "r", encoding=enc) as f:
|
||||
text = f.read()
|
||||
logger.info(f"Successfully parsed with encoding: {enc}")
|
||||
return text, enc
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
continue
|
||||
|
||||
raise DocumentParseException(
|
||||
f"Failed to decode file with any known encoding",
|
||||
file_path=str(path),
|
||||
parser="text"
|
||||
)
|
||||
|
||||
def parse(self, file_path: str | Path) -> ParseResult:
|
||||
"""
|
||||
Parse a text file and extract content.
|
||||
|
|
@ -41,15 +63,14 @@ class TextParser(DocumentParser):
|
|||
)
|
||||
|
||||
try:
|
||||
with open(path, "r", encoding=self._encoding) as f:
|
||||
text = f.read()
|
||||
text, encoding_used = self._try_encodings(path)
|
||||
|
||||
file_size = path.stat().st_size
|
||||
line_count = text.count("\n") + 1
|
||||
|
||||
logger.info(
|
||||
f"Parsed text: {path.name}, lines={line_count}, "
|
||||
f"chars={len(text)}, size={file_size}"
|
||||
f"chars={len(text)}, size={file_size}, encoding={encoding_used}"
|
||||
)
|
||||
|
||||
return ParseResult(
|
||||
|
|
@ -59,35 +80,12 @@ class TextParser(DocumentParser):
|
|||
metadata={
|
||||
"format": "text",
|
||||
"line_count": line_count,
|
||||
"encoding": self._encoding,
|
||||
"encoding": encoding_used,
|
||||
}
|
||||
)
|
||||
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
with open(path, "r", encoding="gbk") as f:
|
||||
text = f.read()
|
||||
|
||||
file_size = path.stat().st_size
|
||||
line_count = text.count("\n") + 1
|
||||
|
||||
return ParseResult(
|
||||
text=text,
|
||||
source_path=str(path),
|
||||
file_size=file_size,
|
||||
metadata={
|
||||
"format": "text",
|
||||
"line_count": line_count,
|
||||
"encoding": "gbk",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise DocumentParseException(
|
||||
f"Failed to parse text file with encoding fallback: {e}",
|
||||
file_path=str(path),
|
||||
parser="text",
|
||||
details={"error": str(e)}
|
||||
)
|
||||
except DocumentParseException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise DocumentParseException(
|
||||
f"Failed to parse text file: {e}",
|
||||
|
|
|
|||
|
|
@ -17,6 +17,11 @@ from app.services.embedding.factory import (
|
|||
)
|
||||
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
|
||||
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
|
||||
from app.services.embedding.nomic_provider import (
|
||||
NomicEmbeddingProvider,
|
||||
NomicEmbeddingResult,
|
||||
EmbeddingTask,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EmbeddingConfig",
|
||||
|
|
@ -29,4 +34,7 @@ __all__ = [
|
|||
"get_embedding_provider",
|
||||
"OllamaEmbeddingProvider",
|
||||
"OpenAIEmbeddingProvider",
|
||||
"NomicEmbeddingProvider",
|
||||
"NomicEmbeddingResult",
|
||||
"EmbeddingTask",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from typing import Any, Type
|
|||
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
|
||||
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
|
||||
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
|
||||
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -26,6 +27,7 @@ class EmbeddingProviderFactory:
|
|||
_providers: dict[str, Type[EmbeddingProvider]] = {
|
||||
"ollama": OllamaEmbeddingProvider,
|
||||
"openai": OpenAIEmbeddingProvider,
|
||||
"nomic": NomicEmbeddingProvider,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -63,11 +65,13 @@ class EmbeddingProviderFactory:
|
|||
display_names = {
|
||||
"ollama": "Ollama 本地模型",
|
||||
"openai": "OpenAI Embedding",
|
||||
"nomic": "Nomic Embed (优化版)",
|
||||
}
|
||||
|
||||
descriptions = {
|
||||
"ollama": "使用 Ollama 运行的本地嵌入模型,支持 nomic-embed-text 等开源模型",
|
||||
"openai": "使用 OpenAI 官方 Embedding API,支持 text-embedding-3 系列模型",
|
||||
"nomic": "Nomic-embed-text v1.5 优化版,支持任务前缀和 Matryoshka 维度截断,专为RAG优化",
|
||||
}
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,291 @@
|
|||
"""
|
||||
Nomic embedding provider with task prefixes and Matryoshka support.
|
||||
Implements RAG optimization spec:
|
||||
- Task prefixes: search_document: / search_query:
|
||||
- Matryoshka dimension truncation: 256/512/768 dimensions
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
|
||||
from app.services.embedding.base import (
|
||||
EmbeddingConfig,
|
||||
EmbeddingException,
|
||||
EmbeddingProvider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingTask(str, Enum):
|
||||
"""Task type for nomic-embed-text v1.5 model."""
|
||||
DOCUMENT = "search_document"
|
||||
QUERY = "search_query"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NomicEmbeddingResult:
|
||||
"""Result from Nomic embedding with multiple dimensions."""
|
||||
embedding_full: list[float]
|
||||
embedding_256: list[float]
|
||||
embedding_512: list[float]
|
||||
dimension: int
|
||||
model: str
|
||||
task: EmbeddingTask
|
||||
latency_ms: float = 0.0
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class NomicEmbeddingProvider(EmbeddingProvider):
|
||||
"""
|
||||
Nomic-embed-text v1.5 embedding provider with task prefixes.
|
||||
|
||||
Key features:
|
||||
- Task prefixes: search_document: for documents, search_query: for queries
|
||||
- Matryoshka dimension truncation: 256/512/768 dimensions
|
||||
- Automatic normalization after truncation
|
||||
|
||||
Reference: rag-optimization/spec.md Section 2.1, 2.3
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = "nomic"
|
||||
DOCUMENT_PREFIX = "search_document:"
|
||||
QUERY_PREFIX = "search_query:"
|
||||
FULL_DIMENSION = 768
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://localhost:11434",
|
||||
model: str = "nomic-embed-text",
|
||||
dimension: int = 768,
|
||||
timeout_seconds: int = 60,
|
||||
enable_matryoshka: bool = True,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._model = model
|
||||
self._dimension = dimension
|
||||
self._timeout = timeout_seconds
|
||||
self._enable_matryoshka = enable_matryoshka
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._extra_config = kwargs
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=self._timeout)
|
||||
return self._client
|
||||
|
||||
def _add_prefix(self, text: str, task: EmbeddingTask) -> str:
|
||||
"""Add task prefix to text."""
|
||||
if task == EmbeddingTask.DOCUMENT:
|
||||
prefix = self.DOCUMENT_PREFIX
|
||||
else:
|
||||
prefix = self.QUERY_PREFIX
|
||||
|
||||
if text.startswith(prefix):
|
||||
return text
|
||||
return f"{prefix}{text}"
|
||||
|
||||
def _truncate_and_normalize(self, embedding: list[float], target_dim: int) -> list[float]:
|
||||
"""
|
||||
Truncate embedding to target dimension and normalize.
|
||||
Matryoshka representation learning allows dimension truncation.
|
||||
"""
|
||||
truncated = embedding[:target_dim]
|
||||
|
||||
arr = np.array(truncated, dtype=np.float32)
|
||||
norm = np.linalg.norm(arr)
|
||||
if norm > 0:
|
||||
arr = arr / norm
|
||||
|
||||
return arr.tolist()
|
||||
|
||||
async def embed_with_task(
|
||||
self,
|
||||
text: str,
|
||||
task: EmbeddingTask,
|
||||
) -> NomicEmbeddingResult:
|
||||
"""
|
||||
Generate embedding with specified task prefix.
|
||||
|
||||
Args:
|
||||
text: Input text to embed
|
||||
task: DOCUMENT for indexing, QUERY for retrieval
|
||||
|
||||
Returns:
|
||||
NomicEmbeddingResult with all dimension variants
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
prefixed_text = self._add_prefix(text, task)
|
||||
|
||||
try:
|
||||
client = await self._get_client()
|
||||
response = await client.post(
|
||||
f"{self._base_url}/api/embeddings",
|
||||
json={
|
||||
"model": self._model,
|
||||
"prompt": prefixed_text,
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
embedding = data.get("embedding", [])
|
||||
|
||||
if not embedding:
|
||||
raise EmbeddingException(
|
||||
"Empty embedding returned",
|
||||
provider=self.PROVIDER_NAME,
|
||||
details={"text_length": len(text), "task": task.value}
|
||||
)
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
embedding_256 = self._truncate_and_normalize(embedding, 256)
|
||||
embedding_512 = self._truncate_and_normalize(embedding, 512)
|
||||
|
||||
logger.debug(
|
||||
f"Generated Nomic embedding: task={task.value}, "
|
||||
f"dim={len(embedding)}, latency={latency_ms:.2f}ms"
|
||||
)
|
||||
|
||||
return NomicEmbeddingResult(
|
||||
embedding_full=embedding,
|
||||
embedding_256=embedding_256,
|
||||
embedding_512=embedding_512,
|
||||
dimension=len(embedding),
|
||||
model=self._model,
|
||||
task=task,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise EmbeddingException(
|
||||
f"Ollama API error: {e.response.status_code}",
|
||||
provider=self.PROVIDER_NAME,
|
||||
details={"status_code": e.response.status_code, "response": e.response.text}
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
raise EmbeddingException(
|
||||
f"Ollama connection error: {e}",
|
||||
provider=self.PROVIDER_NAME,
|
||||
details={"base_url": self._base_url}
|
||||
)
|
||||
except EmbeddingException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise EmbeddingException(
|
||||
f"Embedding generation failed: {e}",
|
||||
provider=self.PROVIDER_NAME
|
||||
)
|
||||
|
||||
async def embed_document(self, text: str) -> NomicEmbeddingResult:
|
||||
"""
|
||||
Generate embedding for document (with search_document: prefix).
|
||||
Use this when indexing documents into vector store.
|
||||
"""
|
||||
return await self.embed_with_task(text, EmbeddingTask.DOCUMENT)
|
||||
|
||||
async def embed_query(self, text: str) -> NomicEmbeddingResult:
|
||||
"""
|
||||
Generate embedding for query (with search_query: prefix).
|
||||
Use this when searching/retrieving documents.
|
||||
"""
|
||||
return await self.embed_with_task(text, EmbeddingTask.QUERY)
|
||||
|
||||
async def embed(self, text: str) -> list[float]:
|
||||
"""
|
||||
Generate embedding vector for a single text.
|
||||
Default uses QUERY task for backward compatibility.
|
||||
"""
|
||||
result = await self.embed_query(text)
|
||||
return result.embedding_full
|
||||
|
||||
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
Generate embedding vectors for multiple texts.
|
||||
Uses QUERY task by default.
|
||||
"""
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
embedding = await self.embed(text)
|
||||
embeddings.append(embedding)
|
||||
return embeddings
|
||||
|
||||
async def embed_documents_batch(
|
||||
self,
|
||||
texts: list[str],
|
||||
) -> list[NomicEmbeddingResult]:
|
||||
"""
|
||||
Generate embeddings for multiple documents (DOCUMENT task).
|
||||
Use this when batch indexing documents.
|
||||
"""
|
||||
results = []
|
||||
for text in texts:
|
||||
result = await self.embed_document(text)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
async def embed_queries_batch(
|
||||
self,
|
||||
texts: list[str],
|
||||
) -> list[NomicEmbeddingResult]:
|
||||
"""
|
||||
Generate embeddings for multiple queries (QUERY task).
|
||||
Use this when batch processing queries.
|
||||
"""
|
||||
results = []
|
||||
for text in texts:
|
||||
result = await self.embed_query(text)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
def get_dimension(self) -> int:
|
||||
"""Get the dimension of embedding vectors."""
|
||||
return self._dimension
|
||||
|
||||
def get_provider_name(self) -> str:
|
||||
"""Get the name of this embedding provider."""
|
||||
return self.PROVIDER_NAME
|
||||
|
||||
def get_config_schema(self) -> dict[str, Any]:
|
||||
"""Get the configuration schema for Nomic provider."""
|
||||
return {
|
||||
"base_url": {
|
||||
"type": "string",
|
||||
"description": "Ollama API 地址",
|
||||
"default": "http://localhost:11434",
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "嵌入模型名称(推荐 nomic-embed-text v1.5)",
|
||||
"default": "nomic-embed-text",
|
||||
},
|
||||
"dimension": {
|
||||
"type": "integer",
|
||||
"description": "向量维度(支持 256/512/768)",
|
||||
"default": 768,
|
||||
},
|
||||
"timeout_seconds": {
|
||||
"type": "integer",
|
||||
"description": "请求超时时间(秒)",
|
||||
"default": 60,
|
||||
},
|
||||
"enable_matryoshka": {
|
||||
"type": "boolean",
|
||||
"description": "启用 Matryoshka 维度截断",
|
||||
"default": True,
|
||||
},
|
||||
}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
|
@ -133,6 +133,13 @@ class OpenAIClient(LLMClient):
|
|||
body = self._build_request_body(messages, effective_config, stream=False, **kwargs)
|
||||
|
||||
logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}")
|
||||
logger.info(f"[AC-AISVC-02] ========== FULL PROMPT TO AI ==========")
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
logger.info(f"[AC-AISVC-02] [{i}] role={role}, content_length={len(content)}")
|
||||
logger.info(f"[AC-AISVC-02] [{i}] content:\n{content}")
|
||||
logger.info(f"[AC-AISVC-02] ======================================")
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
|
|
@ -213,6 +220,13 @@ class OpenAIClient(LLMClient):
|
|||
body = self._build_request_body(messages, effective_config, stream=True, **kwargs)
|
||||
|
||||
logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}")
|
||||
logger.info(f"[AC-AISVC-06] ========== FULL PROMPT TO AI (STREAMING) ==========")
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
logger.info(f"[AC-AISVC-06] [{i}] role={role}, content_length={len(content)}")
|
||||
logger.info(f"[AC-AISVC-06] [{i}] content:\n{content}")
|
||||
logger.info(f"[AC-AISVC-06] ======================================")
|
||||
|
||||
try:
|
||||
async with client.stream(
|
||||
|
|
|
|||
|
|
@ -11,6 +11,11 @@ Design reference: design.md Section 2.2 - 关键数据流
|
|||
6. compute_confidence(...)
|
||||
7. Memory.append(tenantId, sessionId, user/assistant messages)
|
||||
8. Return ChatResponse (or output via SSE)
|
||||
|
||||
RAG Optimization (rag-optimization/spec.md):
|
||||
- Two-stage retrieval with Matryoshka dimensions
|
||||
- RRF hybrid ranking
|
||||
- Optimized prompt engineering
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -20,6 +25,7 @@ from typing import Any, AsyncGenerator
|
|||
from sse_starlette.sse import ServerSentEvent
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.prompts import SYSTEM_PROMPT, format_evidence_for_prompt
|
||||
from app.core.sse import (
|
||||
create_error_event,
|
||||
create_final_event,
|
||||
|
|
@ -44,8 +50,9 @@ class OrchestratorConfig:
|
|||
"""
|
||||
max_history_tokens: int = 4000
|
||||
max_evidence_tokens: int = 2000
|
||||
system_prompt: str = "你是一个智能客服助手,请根据提供的知识库内容回答用户问题。"
|
||||
system_prompt: str = SYSTEM_PROMPT
|
||||
enable_rag: bool = True
|
||||
use_optimized_retriever: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -141,7 +148,14 @@ class OrchestratorService:
|
|||
"""
|
||||
logger.info(
|
||||
f"[AC-AISVC-01] Starting generation for tenant={tenant_id}, "
|
||||
f"session={request.session_id}"
|
||||
f"session={request.session_id}, channel_type={request.channel_type}, "
|
||||
f"current_message={request.current_message[:100]}..."
|
||||
)
|
||||
logger.info(
|
||||
f"[AC-AISVC-01] Config: enable_rag={self._config.enable_rag}, "
|
||||
f"use_optimized_retriever={self._config.use_optimized_retriever}, "
|
||||
f"llm_client={'configured' if self._llm_client else 'NOT configured'}, "
|
||||
f"retriever={'configured' if self._retriever else 'NOT configured'}"
|
||||
)
|
||||
|
||||
ctx = GenerationContext(
|
||||
|
|
@ -257,6 +271,10 @@ class OrchestratorService:
|
|||
[AC-AISVC-16, AC-AISVC-17] RAG retrieval for evidence.
|
||||
Step 3 of the generation pipeline.
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-AISVC-16] Starting retrieval: tenant={ctx.tenant_id}, "
|
||||
f"query={ctx.current_message[:100]}..., retriever={type(self._retriever).__name__ if self._retriever else 'None'}"
|
||||
)
|
||||
try:
|
||||
retrieval_ctx = RetrievalContext(
|
||||
tenant_id=ctx.tenant_id,
|
||||
|
|
@ -277,11 +295,19 @@ class OrchestratorService:
|
|||
logger.info(
|
||||
f"[AC-AISVC-16, AC-AISVC-17] Retrieval complete: "
|
||||
f"hits={ctx.retrieval_result.hit_count}, "
|
||||
f"max_score={ctx.retrieval_result.max_score:.3f}"
|
||||
f"max_score={ctx.retrieval_result.max_score:.3f}, "
|
||||
f"is_empty={ctx.retrieval_result.is_empty}"
|
||||
)
|
||||
|
||||
if ctx.retrieval_result.hit_count > 0:
|
||||
for i, hit in enumerate(ctx.retrieval_result.hits[:3]):
|
||||
logger.info(
|
||||
f"[AC-AISVC-16] Hit {i+1}: score={hit.score:.3f}, "
|
||||
f"text_preview={hit.text[:100]}..."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-AISVC-16] Retrieval failed: {e}")
|
||||
logger.error(f"[AC-AISVC-16] Retrieval failed with exception: {e}", exc_info=True)
|
||||
ctx.retrieval_result = RetrievalResult(
|
||||
hits=[],
|
||||
diagnostics={"error": str(e)},
|
||||
|
|
@ -294,9 +320,18 @@ class OrchestratorService:
|
|||
Step 4-5 of the generation pipeline.
|
||||
"""
|
||||
messages = self._build_llm_messages(ctx)
|
||||
logger.info(
|
||||
f"[AC-AISVC-02] Building LLM messages: count={len(messages)}, "
|
||||
f"has_retrieval_result={ctx.retrieval_result is not None}, "
|
||||
f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else 'N/A'}, "
|
||||
f"llm_client={'configured' if self._llm_client else 'NOT configured'}"
|
||||
)
|
||||
|
||||
if not self._llm_client:
|
||||
logger.warning("[AC-AISVC-02] No LLM client configured, using fallback")
|
||||
logger.warning(
|
||||
f"[AC-AISVC-02] No LLM client configured, using fallback. "
|
||||
f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}"
|
||||
)
|
||||
ctx.llm_response = LLMResponse(
|
||||
content=self._fallback_response(ctx),
|
||||
model="fallback",
|
||||
|
|
@ -304,6 +339,7 @@ class OrchestratorService:
|
|||
finish_reason="fallback",
|
||||
)
|
||||
ctx.diagnostics["llm_mode"] = "fallback"
|
||||
ctx.diagnostics["fallback_reason"] = "no_llm_client"
|
||||
return
|
||||
|
||||
try:
|
||||
|
|
@ -318,11 +354,16 @@ class OrchestratorService:
|
|||
logger.info(
|
||||
f"[AC-AISVC-02] LLM response generated: "
|
||||
f"model={ctx.llm_response.model}, "
|
||||
f"tokens={ctx.llm_response.usage}"
|
||||
f"tokens={ctx.llm_response.usage}, "
|
||||
f"content_preview={ctx.llm_response.content[:100]}..."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-02] LLM generation failed: {e}")
|
||||
logger.error(
|
||||
f"[AC-AISVC-02] LLM generation failed: {e}, "
|
||||
f"retrieval_is_empty={ctx.retrieval_result.is_empty if ctx.retrieval_result else True}",
|
||||
exc_info=True
|
||||
)
|
||||
ctx.llm_response = LLMResponse(
|
||||
content=self._fallback_response(ctx),
|
||||
model="fallback",
|
||||
|
|
@ -331,6 +372,8 @@ class OrchestratorService:
|
|||
metadata={"error": str(e)},
|
||||
)
|
||||
ctx.diagnostics["llm_error"] = str(e)
|
||||
ctx.diagnostics["llm_mode"] = "fallback"
|
||||
ctx.diagnostics["fallback_reason"] = f"llm_error: {str(e)}"
|
||||
|
||||
def _build_llm_messages(self, ctx: GenerationContext) -> list[dict[str, str]]:
|
||||
"""
|
||||
|
|
@ -350,18 +393,29 @@ class OrchestratorService:
|
|||
messages.extend(ctx.merged_context.messages)
|
||||
|
||||
messages.append({"role": "user", "content": ctx.current_message})
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-02] Built {len(messages)} messages for LLM: "
|
||||
f"system_len={len(system_content)}, history_count={len(ctx.merged_context.messages) if ctx.merged_context else 0}"
|
||||
)
|
||||
logger.debug(f"[AC-AISVC-02] System prompt preview: {system_content[:500]}...")
|
||||
|
||||
logger.info(f"[AC-AISVC-02] ========== ORCHESTRATOR FULL PROMPT ==========")
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
logger.info(f"[AC-AISVC-02] [{i}] role={role}, content_length={len(content)}")
|
||||
logger.info(f"[AC-AISVC-02] [{i}] content:\n{content}")
|
||||
logger.info(f"[AC-AISVC-02] ==============================================")
|
||||
|
||||
return messages
|
||||
|
||||
def _format_evidence(self, retrieval_result: RetrievalResult) -> str:
|
||||
"""
|
||||
[AC-AISVC-17] Format retrieval hits as evidence text.
|
||||
Uses shared prompt configuration for consistency.
|
||||
"""
|
||||
evidence_parts = []
|
||||
for i, hit in enumerate(retrieval_result.hits[:5], 1):
|
||||
evidence_parts.append(f"[{i}] (相关度: {hit.score:.2f}) {hit.text}")
|
||||
|
||||
return "\n".join(evidence_parts)
|
||||
return format_evidence_for_prompt(retrieval_result.hits, max_results=5, max_content_length=500)
|
||||
|
||||
def _fallback_response(self, ctx: GenerationContext) -> str:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Retrieval module for AI Service.
|
||||
[AC-AISVC-16] Provides retriever implementations with plugin architecture.
|
||||
RAG Optimization: Two-stage retrieval, RRF hybrid ranking, metadata filtering.
|
||||
"""
|
||||
|
||||
from app.services.retrieval.base import (
|
||||
|
|
@ -10,6 +11,27 @@ from app.services.retrieval.base import (
|
|||
RetrievalResult,
|
||||
)
|
||||
from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever
|
||||
from app.services.retrieval.metadata import (
|
||||
ChunkMetadata,
|
||||
ChunkMetadataModel,
|
||||
MetadataFilter,
|
||||
KnowledgeChunk,
|
||||
RetrieveRequest,
|
||||
RetrieveResult,
|
||||
RetrievalStrategy,
|
||||
)
|
||||
from app.services.retrieval.optimized_retriever import (
|
||||
OptimizedRetriever,
|
||||
get_optimized_retriever,
|
||||
TwoStageResult,
|
||||
RRFCombiner,
|
||||
)
|
||||
from app.services.retrieval.indexer import (
|
||||
KnowledgeIndexer,
|
||||
get_knowledge_indexer,
|
||||
IndexingProgress,
|
||||
IndexingResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseRetriever",
|
||||
|
|
@ -18,4 +40,18 @@ __all__ = [
|
|||
"RetrievalResult",
|
||||
"VectorRetriever",
|
||||
"get_vector_retriever",
|
||||
"ChunkMetadata",
|
||||
"MetadataFilter",
|
||||
"KnowledgeChunk",
|
||||
"RetrieveRequest",
|
||||
"RetrieveResult",
|
||||
"RetrievalStrategy",
|
||||
"OptimizedRetriever",
|
||||
"get_optimized_retriever",
|
||||
"TwoStageResult",
|
||||
"RRFCombiner",
|
||||
"KnowledgeIndexer",
|
||||
"get_knowledge_indexer",
|
||||
"IndexingProgress",
|
||||
"IndexingResult",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,339 @@
|
|||
"""
|
||||
Knowledge base indexing service with optimized embedding.
|
||||
Reference: rag-optimization/spec.md Section 5.1
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
||||
from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult
|
||||
from app.services.retrieval.metadata import ChunkMetadata, KnowledgeChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexingProgress:
|
||||
"""Progress tracking for indexing jobs."""
|
||||
total_chunks: int = 0
|
||||
processed_chunks: int = 0
|
||||
failed_chunks: int = 0
|
||||
current_document: str = ""
|
||||
started_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
@property
|
||||
def progress_percent(self) -> int:
|
||||
if self.total_chunks == 0:
|
||||
return 0
|
||||
return int((self.processed_chunks / self.total_chunks) * 100)
|
||||
|
||||
@property
|
||||
def elapsed_seconds(self) -> float:
|
||||
return (datetime.utcnow() - self.started_at).total_seconds()
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexingResult:
|
||||
"""Result of an indexing operation."""
|
||||
success: bool
|
||||
total_chunks: int
|
||||
indexed_chunks: int
|
||||
failed_chunks: int
|
||||
elapsed_seconds: float
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class KnowledgeIndexer:
|
||||
"""
|
||||
Knowledge base indexer with optimized embedding.
|
||||
|
||||
Features:
|
||||
- Task prefixes (search_document:) for document embedding
|
||||
- Multi-dimensional vectors (256/512/768)
|
||||
- Metadata support
|
||||
- Batch processing
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qdrant_client: QdrantClient | None = None,
|
||||
embedding_provider: NomicEmbeddingProvider | None = None,
|
||||
chunk_size: int = 500,
|
||||
chunk_overlap: int = 50,
|
||||
batch_size: int = 10,
|
||||
):
|
||||
self._qdrant_client = qdrant_client
|
||||
self._embedding_provider = embedding_provider
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_overlap = chunk_overlap
|
||||
self._batch_size = batch_size
|
||||
self._progress: IndexingProgress | None = None
|
||||
|
||||
async def _get_client(self) -> QdrantClient:
|
||||
if self._qdrant_client is None:
|
||||
self._qdrant_client = await get_qdrant_client()
|
||||
return self._qdrant_client
|
||||
|
||||
async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
|
||||
if self._embedding_provider is None:
|
||||
self._embedding_provider = NomicEmbeddingProvider(
|
||||
base_url=settings.ollama_base_url,
|
||||
model=settings.ollama_embedding_model,
|
||||
dimension=settings.qdrant_vector_size,
|
||||
)
|
||||
return self._embedding_provider
|
||||
|
||||
def chunk_text(self, text: str, metadata: ChunkMetadata | None = None) -> list[KnowledgeChunk]:
|
||||
"""
|
||||
Split text into chunks for indexing.
|
||||
Each line becomes a separate chunk for better retrieval granularity.
|
||||
|
||||
Args:
|
||||
text: Full text to chunk
|
||||
metadata: Metadata to attach to each chunk
|
||||
|
||||
Returns:
|
||||
List of KnowledgeChunk objects
|
||||
"""
|
||||
chunks = []
|
||||
doc_id = str(uuid.uuid4())
|
||||
|
||||
lines = text.split('\n')
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
|
||||
if len(line) < 10:
|
||||
continue
|
||||
|
||||
chunk = KnowledgeChunk(
|
||||
chunk_id=f"{doc_id}_{i}",
|
||||
document_id=doc_id,
|
||||
content=line,
|
||||
metadata=metadata or ChunkMetadata(),
|
||||
)
|
||||
chunks.append(chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
def chunk_text_by_lines(
|
||||
self,
|
||||
text: str,
|
||||
metadata: ChunkMetadata | None = None,
|
||||
min_line_length: int = 10,
|
||||
merge_short_lines: bool = False,
|
||||
) -> list[KnowledgeChunk]:
|
||||
"""
|
||||
Split text by lines, each line is a separate chunk.
|
||||
|
||||
Args:
|
||||
text: Full text to chunk
|
||||
metadata: Metadata to attach to each chunk
|
||||
min_line_length: Minimum line length to be indexed
|
||||
merge_short_lines: Whether to merge consecutive short lines
|
||||
|
||||
Returns:
|
||||
List of KnowledgeChunk objects
|
||||
"""
|
||||
chunks = []
|
||||
doc_id = str(uuid.uuid4())
|
||||
|
||||
lines = text.split('\n')
|
||||
|
||||
if merge_short_lines:
|
||||
merged_lines = []
|
||||
current_line = ""
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
if current_line:
|
||||
merged_lines.append(current_line)
|
||||
current_line = ""
|
||||
continue
|
||||
|
||||
if current_line:
|
||||
current_line += " " + line
|
||||
else:
|
||||
current_line = line
|
||||
|
||||
if len(current_line) >= min_line_length * 2:
|
||||
merged_lines.append(current_line)
|
||||
current_line = ""
|
||||
|
||||
if current_line:
|
||||
merged_lines.append(current_line)
|
||||
|
||||
lines = merged_lines
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
|
||||
if len(line) < min_line_length:
|
||||
continue
|
||||
|
||||
chunk = KnowledgeChunk(
|
||||
chunk_id=f"{doc_id}_{i}",
|
||||
document_id=doc_id,
|
||||
content=line,
|
||||
metadata=metadata or ChunkMetadata(),
|
||||
)
|
||||
chunks.append(chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
async def index_document(
|
||||
self,
|
||||
tenant_id: str,
|
||||
document_id: str,
|
||||
text: str,
|
||||
metadata: ChunkMetadata | None = None,
|
||||
) -> IndexingResult:
|
||||
"""
|
||||
Index a single document with optimized embedding.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
document_id: Document identifier
|
||||
text: Document text content
|
||||
metadata: Optional metadata for the document
|
||||
|
||||
Returns:
|
||||
IndexingResult with status and statistics
|
||||
"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
client = await self._get_client()
|
||||
provider = await self._get_embedding_provider()
|
||||
|
||||
await client.ensure_collection_exists(tenant_id, use_multi_vector=True)
|
||||
|
||||
chunks = self.chunk_text(text, metadata)
|
||||
|
||||
self._progress = IndexingProgress(
|
||||
total_chunks=len(chunks),
|
||||
current_document=document_id,
|
||||
)
|
||||
|
||||
points = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
try:
|
||||
embedding_result = await provider.embed_document(chunk.content)
|
||||
|
||||
chunk.embedding_full = embedding_result.embedding_full
|
||||
chunk.embedding_256 = embedding_result.embedding_256
|
||||
chunk.embedding_512 = embedding_result.embedding_512
|
||||
|
||||
point = {
|
||||
"id": str(uuid.uuid4()), # Generate a valid UUID for Qdrant
|
||||
"vector": {
|
||||
"full": chunk.embedding_full,
|
||||
"dim_256": chunk.embedding_256,
|
||||
"dim_512": chunk.embedding_512,
|
||||
},
|
||||
"payload": {
|
||||
"chunk_id": chunk.chunk_id,
|
||||
"document_id": document_id,
|
||||
"text": chunk.content,
|
||||
"metadata": chunk.metadata.to_dict(),
|
||||
"created_at": chunk.created_at.isoformat(),
|
||||
}
|
||||
}
|
||||
points.append(point)
|
||||
|
||||
self._progress.processed_chunks += 1
|
||||
|
||||
logger.debug(
|
||||
f"[RAG-OPT] Indexed chunk {i+1}/{len(chunks)} for doc={document_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[RAG-OPT] Failed to index chunk {i}: {e}")
|
||||
self._progress.failed_chunks += 1
|
||||
|
||||
if points:
|
||||
await client.upsert_multi_vector(tenant_id, points)
|
||||
|
||||
elapsed = (datetime.utcnow() - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"[RAG-OPT] Indexed document {document_id}: "
|
||||
f"{len(points)} chunks in {elapsed:.2f}s"
|
||||
)
|
||||
|
||||
return IndexingResult(
|
||||
success=True,
|
||||
total_chunks=len(chunks),
|
||||
indexed_chunks=len(points),
|
||||
failed_chunks=self._progress.failed_chunks,
|
||||
elapsed_seconds=elapsed,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
elapsed = (datetime.utcnow() - start_time).total_seconds()
|
||||
logger.error(f"[RAG-OPT] Failed to index document {document_id}: {e}")
|
||||
|
||||
return IndexingResult(
|
||||
success=False,
|
||||
total_chunks=0,
|
||||
indexed_chunks=0,
|
||||
failed_chunks=0,
|
||||
elapsed_seconds=elapsed,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
async def index_documents_batch(
|
||||
self,
|
||||
tenant_id: str,
|
||||
documents: list[dict[str, Any]],
|
||||
) -> list[IndexingResult]:
|
||||
"""
|
||||
Index multiple documents in batch.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
documents: List of documents with format:
|
||||
{
|
||||
"document_id": str,
|
||||
"text": str,
|
||||
"metadata": ChunkMetadata (optional)
|
||||
}
|
||||
|
||||
Returns:
|
||||
List of IndexingResult for each document
|
||||
"""
|
||||
results = []
|
||||
|
||||
for doc in documents:
|
||||
result = await self.index_document(
|
||||
tenant_id=tenant_id,
|
||||
document_id=doc["document_id"],
|
||||
text=doc["text"],
|
||||
metadata=doc.get("metadata"),
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
def get_progress(self) -> IndexingProgress | None:
|
||||
"""Get current indexing progress."""
|
||||
return self._progress
|
||||
|
||||
|
||||
_knowledge_indexer: KnowledgeIndexer | None = None
|
||||
|
||||
|
||||
def get_knowledge_indexer() -> KnowledgeIndexer:
|
||||
"""Get or create KnowledgeIndexer instance."""
|
||||
global _knowledge_indexer
|
||||
if _knowledge_indexer is None:
|
||||
_knowledge_indexer = KnowledgeIndexer()
|
||||
return _knowledge_indexer
|
||||
|
|
@ -0,0 +1,210 @@
|
|||
"""
|
||||
Metadata models for RAG optimization.
|
||||
Implements structured metadata for knowledge chunks.
|
||||
Reference: rag-optimization/spec.md Section 3.2
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RetrievalStrategy(str, Enum):
|
||||
"""Retrieval strategy options."""
|
||||
VECTOR_ONLY = "vector"
|
||||
BM25_ONLY = "bm25"
|
||||
HYBRID = "hybrid"
|
||||
TWO_STAGE = "two_stage"
|
||||
|
||||
|
||||
class ChunkMetadataModel(BaseModel):
|
||||
"""Pydantic model for API serialization."""
|
||||
category: str = ""
|
||||
subcategory: str = ""
|
||||
target_audience: list[str] = []
|
||||
source_doc: str = ""
|
||||
source_url: str = ""
|
||||
department: str = ""
|
||||
valid_from: str | None = None
|
||||
valid_until: str | None = None
|
||||
priority: int = 5
|
||||
keywords: list[str] = []
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkMetadata:
|
||||
"""
|
||||
Metadata for knowledge chunks.
|
||||
Reference: rag-optimization/spec.md Section 3.2.2
|
||||
"""
|
||||
category: str = ""
|
||||
subcategory: str = ""
|
||||
target_audience: list[str] = field(default_factory=list)
|
||||
source_doc: str = ""
|
||||
source_url: str = ""
|
||||
department: str = ""
|
||||
valid_from: date | None = None
|
||||
valid_until: date | None = None
|
||||
priority: int = 5
|
||||
keywords: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for storage."""
|
||||
return {
|
||||
"category": self.category,
|
||||
"subcategory": self.subcategory,
|
||||
"target_audience": self.target_audience,
|
||||
"source_doc": self.source_doc,
|
||||
"source_url": self.source_url,
|
||||
"department": self.department,
|
||||
"valid_from": self.valid_from.isoformat() if self.valid_from else None,
|
||||
"valid_until": self.valid_until.isoformat() if self.valid_until else None,
|
||||
"priority": self.priority,
|
||||
"keywords": self.keywords,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ChunkMetadata":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
category=data.get("category", ""),
|
||||
subcategory=data.get("subcategory", ""),
|
||||
target_audience=data.get("target_audience", []),
|
||||
source_doc=data.get("source_doc", ""),
|
||||
source_url=data.get("source_url", ""),
|
||||
department=data.get("department", ""),
|
||||
valid_from=date.fromisoformat(data["valid_from"]) if data.get("valid_from") else None,
|
||||
valid_until=date.fromisoformat(data["valid_until"]) if data.get("valid_until") else None,
|
||||
priority=data.get("priority", 5),
|
||||
keywords=data.get("keywords", []),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataFilter:
|
||||
"""
|
||||
Filter conditions for metadata-based retrieval.
|
||||
Reference: rag-optimization/spec.md Section 4.1
|
||||
"""
|
||||
categories: list[str] | None = None
|
||||
target_audiences: list[str] | None = None
|
||||
departments: list[str] | None = None
|
||||
valid_only: bool = True
|
||||
min_priority: int | None = None
|
||||
keywords: list[str] | None = None
|
||||
|
||||
def to_qdrant_filter(self) -> dict[str, Any] | None:
|
||||
"""Convert to Qdrant filter format."""
|
||||
conditions = []
|
||||
|
||||
if self.categories:
|
||||
conditions.append({
|
||||
"key": "metadata.category",
|
||||
"match": {"any": self.categories}
|
||||
})
|
||||
|
||||
if self.departments:
|
||||
conditions.append({
|
||||
"key": "metadata.department",
|
||||
"match": {"any": self.departments}
|
||||
})
|
||||
|
||||
if self.target_audiences:
|
||||
conditions.append({
|
||||
"key": "metadata.target_audience",
|
||||
"match": {"any": self.target_audiences}
|
||||
})
|
||||
|
||||
if self.valid_only:
|
||||
today = date.today().isoformat()
|
||||
conditions.append({
|
||||
"should": [
|
||||
{"key": "metadata.valid_until", "match": {"value": None}},
|
||||
{"key": "metadata.valid_until", "range": {"gte": today}}
|
||||
]
|
||||
})
|
||||
|
||||
if self.min_priority is not None:
|
||||
conditions.append({
|
||||
"key": "metadata.priority",
|
||||
"range": {"lte": self.min_priority}
|
||||
})
|
||||
|
||||
if not conditions:
|
||||
return None
|
||||
|
||||
if len(conditions) == 1:
|
||||
return {"must": conditions}
|
||||
|
||||
return {"must": conditions}
|
||||
|
||||
|
||||
@dataclass
|
||||
class KnowledgeChunk:
|
||||
"""
|
||||
Knowledge chunk with multi-dimensional embeddings.
|
||||
Reference: rag-optimization/spec.md Section 3.2.1
|
||||
"""
|
||||
chunk_id: str
|
||||
document_id: str
|
||||
content: str
|
||||
embedding_full: list[float] = field(default_factory=list)
|
||||
embedding_256: list[float] = field(default_factory=list)
|
||||
embedding_512: list[float] = field(default_factory=list)
|
||||
metadata: ChunkMetadata = field(default_factory=ChunkMetadata)
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
def to_qdrant_point(self, point_id: int | str) -> dict[str, Any]:
|
||||
"""Convert to Qdrant point format."""
|
||||
return {
|
||||
"id": point_id,
|
||||
"vector": {
|
||||
"full": self.embedding_full,
|
||||
"dim_256": self.embedding_256,
|
||||
"dim_512": self.embedding_512,
|
||||
},
|
||||
"payload": {
|
||||
"chunk_id": self.chunk_id,
|
||||
"document_id": self.document_id,
|
||||
"text": self.content,
|
||||
"metadata": self.metadata.to_dict(),
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrieveRequest:
|
||||
"""
|
||||
Request for knowledge retrieval.
|
||||
Reference: rag-optimization/spec.md Section 4.1
|
||||
"""
|
||||
query: str
|
||||
query_with_prefix: str = ""
|
||||
top_k: int = 10
|
||||
filters: MetadataFilter | None = None
|
||||
strategy: RetrievalStrategy = RetrievalStrategy.HYBRID
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.query_with_prefix:
|
||||
self.query_with_prefix = f"search_query:{self.query}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrieveResult:
|
||||
"""
|
||||
Result from knowledge retrieval.
|
||||
Reference: rag-optimization/spec.md Section 4.1
|
||||
"""
|
||||
chunk_id: str
|
||||
content: str
|
||||
score: float
|
||||
vector_score: float = 0.0
|
||||
bm25_score: float = 0.0
|
||||
metadata: ChunkMetadata = field(default_factory=ChunkMetadata)
|
||||
rank: int = 0
|
||||
|
|
@ -0,0 +1,509 @@
|
|||
"""
|
||||
Optimized RAG retriever with two-stage retrieval and RRF hybrid ranking.
|
||||
Reference: rag-optimization/spec.md Section 2.2, 2.4, 2.5
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
||||
from app.services.embedding.nomic_provider import NomicEmbeddingProvider, NomicEmbeddingResult
|
||||
from app.services.retrieval.base import (
|
||||
BaseRetriever,
|
||||
RetrievalContext,
|
||||
RetrievalHit,
|
||||
RetrievalResult,
|
||||
)
|
||||
from app.services.retrieval.metadata import (
|
||||
ChunkMetadata,
|
||||
MetadataFilter,
|
||||
RetrieveResult,
|
||||
RetrievalStrategy,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TwoStageResult:
|
||||
"""Result from two-stage retrieval."""
|
||||
candidates: list[dict[str, Any]]
|
||||
final_results: list[RetrieveResult]
|
||||
stage1_latency_ms: float = 0.0
|
||||
stage2_latency_ms: float = 0.0
|
||||
|
||||
|
||||
class RRFCombiner:
|
||||
"""
|
||||
Reciprocal Rank Fusion for combining multiple retrieval results.
|
||||
Reference: rag-optimization/spec.md Section 2.5
|
||||
|
||||
Formula: score = Σ(1 / (k + rank_i))
|
||||
Default k = 60
|
||||
"""
|
||||
|
||||
def __init__(self, k: int = 60):
|
||||
self._k = k
|
||||
|
||||
def combine(
|
||||
self,
|
||||
vector_results: list[dict[str, Any]],
|
||||
bm25_results: list[dict[str, Any]],
|
||||
vector_weight: float = 0.7,
|
||||
bm25_weight: float = 0.3,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Combine vector and BM25 results using RRF.
|
||||
|
||||
Args:
|
||||
vector_results: Results from vector search
|
||||
bm25_results: Results from BM25 search
|
||||
vector_weight: Weight for vector results
|
||||
bm25_weight: Weight for BM25 results
|
||||
|
||||
Returns:
|
||||
Combined and sorted results
|
||||
"""
|
||||
combined_scores: dict[str, dict[str, Any]] = {}
|
||||
|
||||
for rank, result in enumerate(vector_results):
|
||||
chunk_id = result.get("chunk_id") or result.get("id", str(rank))
|
||||
rrf_score = vector_weight / (self._k + rank + 1)
|
||||
|
||||
if chunk_id not in combined_scores:
|
||||
combined_scores[chunk_id] = {
|
||||
"score": 0.0,
|
||||
"vector_score": result.get("score", 0.0),
|
||||
"bm25_score": 0.0,
|
||||
"vector_rank": rank,
|
||||
"bm25_rank": -1,
|
||||
"payload": result.get("payload", {}),
|
||||
"id": chunk_id,
|
||||
}
|
||||
|
||||
combined_scores[chunk_id]["score"] += rrf_score
|
||||
|
||||
for rank, result in enumerate(bm25_results):
|
||||
chunk_id = result.get("chunk_id") or result.get("id", str(rank))
|
||||
rrf_score = bm25_weight / (self._k + rank + 1)
|
||||
|
||||
if chunk_id not in combined_scores:
|
||||
combined_scores[chunk_id] = {
|
||||
"score": 0.0,
|
||||
"vector_score": 0.0,
|
||||
"bm25_score": result.get("score", 0.0),
|
||||
"vector_rank": -1,
|
||||
"bm25_rank": rank,
|
||||
"payload": result.get("payload", {}),
|
||||
"id": chunk_id,
|
||||
}
|
||||
else:
|
||||
combined_scores[chunk_id]["bm25_score"] = result.get("score", 0.0)
|
||||
combined_scores[chunk_id]["bm25_rank"] = rank
|
||||
|
||||
combined_scores[chunk_id]["score"] += rrf_score
|
||||
|
||||
sorted_results = sorted(
|
||||
combined_scores.values(),
|
||||
key=lambda x: x["score"],
|
||||
reverse=True
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
class OptimizedRetriever(BaseRetriever):
    """
    Optimized retriever with:
    - Task prefixes (search_document/search_query)
    - Two-stage retrieval (256 dim -> 768 dim)
    - RRF hybrid ranking (vector + BM25)
    - Metadata filtering

    Reference: rag-optimization/spec.md Section 2, 3, 4
    """

    def __init__(
        self,
        qdrant_client: QdrantClient | None = None,
        embedding_provider: NomicEmbeddingProvider | None = None,
        top_k: int | None = None,
        score_threshold: float | None = None,
        min_hits: int | None = None,
        two_stage_enabled: bool | None = None,
        two_stage_expand_factor: int | None = None,
        hybrid_enabled: bool | None = None,
        rrf_k: int | None = None,
    ):
        """
        Args:
            qdrant_client: Injected Qdrant client; created lazily when None.
            embedding_provider: Injected embedding provider; created lazily when None.
            top_k: Max results to return; None falls back to settings.rag_top_k.
            score_threshold: Min score to keep a hit; None falls back to settings.
            min_hits: Hits below this count mark retrieval as insufficient.
            two_stage_enabled: Enable 256-dim -> full-dim two-stage retrieval.
            two_stage_expand_factor: Candidate multiplier for stage 1.
            hybrid_enabled: Enable RRF vector+BM25 hybrid ranking.
            rrf_k: RRF rank-smoothing constant.
        """
        self._qdrant_client = qdrant_client
        self._embedding_provider = embedding_provider
        # Use `is not None` checks (not `or`) so legitimate falsy overrides
        # such as score_threshold=0.0 or min_hits=0 are honored instead of
        # being silently replaced by settings defaults.  This matches the
        # pattern already used for the boolean flags below.
        self._top_k = top_k if top_k is not None else settings.rag_top_k
        self._score_threshold = (
            score_threshold if score_threshold is not None else settings.rag_score_threshold
        )
        self._min_hits = min_hits if min_hits is not None else settings.rag_min_hits
        self._two_stage_enabled = (
            two_stage_enabled if two_stage_enabled is not None else settings.rag_two_stage_enabled
        )
        self._two_stage_expand_factor = (
            two_stage_expand_factor
            if two_stage_expand_factor is not None
            else settings.rag_two_stage_expand_factor
        )
        self._hybrid_enabled = (
            hybrid_enabled if hybrid_enabled is not None else settings.rag_hybrid_enabled
        )
        self._rrf_k = rrf_k if rrf_k is not None else settings.rag_rrf_k
        self._rrf_combiner = RRFCombiner(k=self._rrf_k)

    async def _get_client(self) -> QdrantClient:
        # Lazily resolve the shared Qdrant client on first use.
        if self._qdrant_client is None:
            self._qdrant_client = await get_qdrant_client()
        return self._qdrant_client

    async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
        # Lazily resolve the embedding provider.  Prefer the configured
        # provider when it is already a NomicEmbeddingProvider; otherwise
        # build one from settings (Nomic multi-dim embeddings are required
        # for two-stage retrieval).
        if self._embedding_provider is None:
            from app.services.embedding.factory import get_embedding_config_manager
            manager = get_embedding_config_manager()
            provider = await manager.get_provider()
            if isinstance(provider, NomicEmbeddingProvider):
                self._embedding_provider = provider
            else:
                self._embedding_provider = NomicEmbeddingProvider(
                    base_url=settings.ollama_base_url,
                    model=settings.ollama_embedding_model,
                    dimension=settings.qdrant_vector_size,
                )
        return self._embedding_provider

    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """
        Retrieve documents using optimized strategy.

        Strategy selection:
        1. If two_stage_enabled: use two-stage retrieval
        2. If hybrid_enabled: use RRF hybrid ranking
        3. Otherwise: simple vector search

        Args:
            ctx: Retrieval context carrying tenant_id and the user query.

        Returns:
            RetrievalResult with threshold-filtered hits and diagnostics.
            On any error, returns an empty result with the error recorded
            in diagnostics (never raises to the caller).
        """
        logger.info(
            f"[RAG-OPT] Starting retrieval for tenant={ctx.tenant_id}, "
            f"query={ctx.query[:50]}..., two_stage={self._two_stage_enabled}, hybrid={self._hybrid_enabled}"
        )
        logger.info(
            f"[RAG-OPT] Retrieval config: top_k={self._top_k}, "
            f"score_threshold={self._score_threshold}, min_hits={self._min_hits}"
        )

        try:
            provider = await self._get_embedding_provider()
            logger.info(f"[RAG-OPT] Using embedding provider: {type(provider).__name__}")

            embedding_result = await provider.embed_query(ctx.query)
            logger.info(
                f"[RAG-OPT] Embedding generated: full_dim={len(embedding_result.embedding_full)}, "
                f"dim_256={'available' if embedding_result.embedding_256 else 'not available'}"
            )

            # Strategy dispatch: two-stage takes precedence over hybrid.
            if self._two_stage_enabled:
                logger.info("[RAG-OPT] Using two-stage retrieval strategy")
                results = await self._two_stage_retrieve(
                    ctx.tenant_id,
                    embedding_result,
                    self._top_k,
                )
            elif self._hybrid_enabled:
                logger.info("[RAG-OPT] Using hybrid retrieval strategy")
                results = await self._hybrid_retrieve(
                    ctx.tenant_id,
                    embedding_result,
                    ctx.query,
                    self._top_k,
                )
            else:
                logger.info("[RAG-OPT] Using simple vector retrieval strategy")
                results = await self._vector_retrieve(
                    ctx.tenant_id,
                    embedding_result.embedding_full,
                    self._top_k,
                )

            logger.info(f"[RAG-OPT] Raw results count: {len(results)}")

            # Keep only results at or above the score threshold.
            retrieval_hits = [
                RetrievalHit(
                    text=result.get("payload", {}).get("text", ""),
                    score=result.get("score", 0.0),
                    source="optimized_rag",
                    metadata=result.get("payload", {}),
                )
                for result in results
                if result.get("score", 0.0) >= self._score_threshold
            ]

            filtered_count = len(results) - len(retrieval_hits)
            if filtered_count > 0:
                logger.info(
                    f"[RAG-OPT] Filtered out {filtered_count} results below threshold {self._score_threshold}"
                )

            # Too few hits signals the caller that the answer may be ungrounded.
            is_insufficient = len(retrieval_hits) < self._min_hits

            diagnostics = {
                "query_length": len(ctx.query),
                "top_k": self._top_k,
                "score_threshold": self._score_threshold,
                "two_stage_enabled": self._two_stage_enabled,
                "hybrid_enabled": self._hybrid_enabled,
                "total_hits": len(retrieval_hits),
                "is_insufficient": is_insufficient,
                "max_score": max((h.score for h in retrieval_hits), default=0.0),
                "raw_results_count": len(results),
                "filtered_below_threshold": filtered_count,
            }

            logger.info(
                f"[RAG-OPT] Retrieval complete: {len(retrieval_hits)} hits, "
                f"insufficient={is_insufficient}, max_score={diagnostics['max_score']:.3f}"
            )

            if len(retrieval_hits) == 0:
                logger.warning(
                    f"[RAG-OPT] No hits found! tenant={ctx.tenant_id}, query={ctx.query[:50]}..., "
                    f"raw_results={len(results)}, threshold={self._score_threshold}"
                )

            return RetrievalResult(
                hits=retrieval_hits,
                diagnostics=diagnostics,
            )

        except Exception as e:
            logger.error(f"[RAG-OPT] Retrieval error: {e}", exc_info=True)
            return RetrievalResult(
                hits=[],
                diagnostics={"error": str(e), "is_insufficient": True},
            )

    async def _two_stage_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Two-stage retrieval using Matryoshka dimensions.

        Stage 1: Fast retrieval with 256-dim vectors
        Stage 2: Precise reranking with 768-dim vectors

        Falls back to a single-stage full-dim search when the 256-dim
        embedding is unavailable (retrieve() logs show this can happen).

        Reference: rag-optimization/spec.md Section 2.4
        """
        import time

        client = await self._get_client()

        # Fallback: without a 256-dim embedding, stage 1 cannot run.
        if not embedding_result.embedding_256:
            logger.warning(
                "[RAG-OPT] 256-dim embedding unavailable; falling back to full-dim search"
            )
            return await self._search_with_dimension(
                client, tenant_id, embedding_result.embedding_full, "full", top_k
            )

        # Stage 1: cheap, wide candidate retrieval on the small vector.
        stage1_start = time.perf_counter()
        candidates = await self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_256, "dim_256",
            top_k * self._two_stage_expand_factor
        )
        stage1_latency = (time.perf_counter() - stage1_start) * 1000

        logger.debug(
            f"[RAG-OPT] Stage 1: {len(candidates)} candidates in {stage1_latency:.2f}ms"
        )

        # Stage 2: rerank candidates by exact cosine similarity against the
        # full-dim embedding stored in each payload.  Candidates without a
        # stored full embedding are dropped (they cannot be reranked).
        stage2_start = time.perf_counter()
        reranked = []
        for candidate in candidates:
            stored_full_embedding = candidate.get("payload", {}).get("embedding_full", [])
            if stored_full_embedding:
                similarity = self._cosine_similarity(
                    embedding_result.embedding_full,
                    stored_full_embedding
                )
                candidate["score"] = similarity
                candidate["stage"] = "reranked"
                reranked.append(candidate)

        reranked.sort(key=lambda x: x.get("score", 0), reverse=True)
        results = reranked[:top_k]
        stage2_latency = (time.perf_counter() - stage2_start) * 1000

        logger.debug(
            f"[RAG-OPT] Stage 2: {len(results)} final results in {stage2_latency:.2f}ms"
        )

        return results

    async def _hybrid_retrieve(
        self,
        tenant_id: str,
        embedding_result: NomicEmbeddingResult,
        query: str,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """
        Hybrid retrieval using RRF to combine vector and BM25 results.

        Both searches run concurrently; a failure on either side degrades
        gracefully to the other side's results instead of failing the call.

        Reference: rag-optimization/spec.md Section 2.5
        """
        client = await self._get_client()

        # Over-fetch (top_k * 2) on each side so RRF has enough overlap
        # candidates to merge before the final truncation.
        vector_task = self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_full, "full",
            top_k * 2
        )
        bm25_task = self._bm25_search(client, tenant_id, query, top_k * 2)

        vector_results, bm25_results = await asyncio.gather(
            vector_task, bm25_task, return_exceptions=True
        )

        if isinstance(vector_results, Exception):
            logger.warning(f"[RAG-OPT] Vector search failed: {vector_results}")
            vector_results = []

        if isinstance(bm25_results, Exception):
            logger.warning(f"[RAG-OPT] BM25 search failed: {bm25_results}")
            bm25_results = []

        combined = self._rrf_combiner.combine(
            vector_results,
            bm25_results,
            vector_weight=settings.rag_vector_weight,
            bm25_weight=settings.rag_bm25_weight,
        )

        return combined[:top_k]

    async def _vector_retrieve(
        self,
        tenant_id: str,
        embedding: list[float],
        top_k: int,
    ) -> list[dict[str, Any]]:
        """Simple vector retrieval."""
        client = await self._get_client()
        return await self._search_with_dimension(
            client, tenant_id, embedding, "full", top_k
        )

    async def _search_with_dimension(
        self,
        client: QdrantClient,
        tenant_id: str,
        query_vector: list[float],
        vector_name: str,
        limit: int,
    ) -> list[dict[str, Any]]:
        """
        Search using specified vector dimension (Qdrant named vector).

        Returns a list of {"id", "score", "payload"} dicts; returns an
        empty list (after logging) on any search error.
        """
        try:
            qdrant = await client.get_client()
            collection_name = client.get_collection_name(tenant_id)

            logger.info(
                f"[RAG-OPT] Searching collection={collection_name}, "
                f"vector_name={vector_name}, limit={limit}, vector_dim={len(query_vector)}"
            )

            results = await qdrant.search(
                collection_name=collection_name,
                query_vector=(vector_name, query_vector),
                limit=limit,
            )

            logger.info(
                f"[RAG-OPT] Search returned {len(results)} results from collection={collection_name}"
            )

            if len(results) > 0:
                # Sample the top few hits at debug level for diagnosis.
                for i, r in enumerate(results[:3]):
                    logger.debug(
                        f"[RAG-OPT] Result {i+1}: id={r.id}, score={r.score:.4f}"
                    )

            return [
                {
                    "id": str(result.id),
                    "score": result.score,
                    "payload": result.payload or {},
                }
                for result in results
            ]
        except Exception as e:
            logger.error(
                f"[RAG-OPT] Search with {vector_name} failed: {e}, "
                f"collection_name={client.get_collection_name(tenant_id)}",
                exc_info=True
            )
            return []

    async def _bm25_search(
        self,
        client: QdrantClient,
        tenant_id: str,
        query: str,
        limit: int,
    ) -> list[dict[str, Any]]:
        """
        BM25-like search using Qdrant's sparse vectors or fallback to text matching.
        This is a simplified implementation; for production, use Elasticsearch.

        Scores each scrolled point by Jaccard overlap between query terms
        and payload text terms; errors degrade to an empty result.
        """
        try:
            qdrant = await client.get_client()
            collection_name = client.get_collection_name(tenant_id)

            query_terms = set(re.findall(r'\w+', query.lower()))

            # Over-scroll (limit * 3) to compensate for points with no overlap.
            results = await qdrant.scroll(
                collection_name=collection_name,
                limit=limit * 3,
                with_payload=True,
            )

            scored_results = []
            for point in results[0]:
                # Guard against a None payload; the output dict below already
                # does the same with `point.payload or {}`.
                payload = point.payload or {}
                text = payload.get("text", "").lower()
                text_terms = set(re.findall(r'\w+', text))
                overlap = len(query_terms & text_terms)
                if overlap > 0:
                    # Jaccard similarity: |A ∩ B| / |A ∪ B|.
                    score = overlap / (len(query_terms) + len(text_terms) - overlap)
                    scored_results.append({
                        "id": str(point.id),
                        "score": score,
                        "payload": payload,
                    })

            scored_results.sort(key=lambda x: x["score"], reverse=True)
            return scored_results[:limit]

        except Exception as e:
            logger.debug(f"[RAG-OPT] BM25 search failed: {e}")
            return []

    def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
        """Calculate cosine similarity between two vectors (0.0 for zero-norm input)."""
        import numpy as np
        a = np.array(vec1)
        b = np.array(vec2)
        norm_product = np.linalg.norm(a) * np.linalg.norm(b)
        # Avoid division by zero (would yield nan) for degenerate vectors.
        if norm_product == 0.0:
            return 0.0
        return float(np.dot(a, b) / norm_product)

    async def health_check(self) -> bool:
        """Check if retriever is healthy (Qdrant reachable and responsive)."""
        try:
            client = await self._get_client()
            qdrant = await client.get_client()
            await qdrant.get_collections()
            return True
        except Exception as e:
            logger.error(f"[RAG-OPT] Health check failed: {e}")
            return False
|
||||
|
||||
|
||||
# Module-level cache for the shared retriever instance.
_optimized_retriever: OptimizedRetriever | None = None


async def get_optimized_retriever() -> OptimizedRetriever:
    """Return the process-wide OptimizedRetriever, creating it lazily on first call."""
    global _optimized_retriever
    retriever = _optimized_retriever
    if retriever is None:
        retriever = OptimizedRetriever()
        _optimized_retriever = retriever
    return retriever
|
||||
|
|
@ -61,20 +61,31 @@ class VectorRetriever(BaseRetriever):
|
|||
RetrievalResult with filtered hits.
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-AISVC-16] Starting vector retrieval for tenant={ctx.tenant_id}, query={ctx.query[:50]}..."
|
||||
f"[AC-AISVC-16] Starting vector retrieval for tenant={ctx.tenant_id}, "
|
||||
f"query={ctx.query[:50]}..."
|
||||
)
|
||||
logger.info(
|
||||
f"[AC-AISVC-16] Retrieval config: top_k={self._top_k}, "
|
||||
f"score_threshold={self._score_threshold}, min_hits={self._min_hits}"
|
||||
)
|
||||
|
||||
try:
|
||||
client = await self._get_client()
|
||||
logger.info(f"[AC-AISVC-16] Got Qdrant client: {type(client).__name__}")
|
||||
|
||||
logger.info("[AC-AISVC-16] Generating embedding for query...")
|
||||
query_vector = await self._get_embedding(ctx.query)
|
||||
logger.info(f"[AC-AISVC-16] Embedding generated: dim={len(query_vector)}")
|
||||
|
||||
logger.info(f"[AC-AISVC-16] Searching in tenant collection: tenant_id={ctx.tenant_id}")
|
||||
hits = await client.search(
|
||||
tenant_id=ctx.tenant_id,
|
||||
query_vector=query_vector,
|
||||
limit=self._top_k,
|
||||
score_threshold=self._score_threshold,
|
||||
)
|
||||
|
||||
logger.info(f"[AC-AISVC-16] Search returned {len(hits)} raw hits")
|
||||
|
||||
retrieval_hits = [
|
||||
RetrievalHit(
|
||||
|
|
@ -104,6 +115,12 @@ class VectorRetriever(BaseRetriever):
|
|||
f"[AC-AISVC-17] Retrieval complete: {len(retrieval_hits)} hits, "
|
||||
f"insufficient={is_insufficient}, max_score={diagnostics['max_score']:.3f}"
|
||||
)
|
||||
|
||||
if len(retrieval_hits) == 0:
|
||||
logger.warning(
|
||||
f"[AC-AISVC-17] No hits found! tenant={ctx.tenant_id}, "
|
||||
f"query={ctx.query[:50]}..., raw_hits={len(hits)}, threshold={self._score_threshold}"
|
||||
)
|
||||
|
||||
return RetrievalResult(
|
||||
hits=retrieval_hits,
|
||||
|
|
@ -111,7 +128,7 @@ class VectorRetriever(BaseRetriever):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-AISVC-16] Retrieval error: {e}")
|
||||
logger.error(f"[AC-AISVC-16] Retrieval error: {e}", exc_info=True)
|
||||
return RetrievalResult(
|
||||
hits=[],
|
||||
diagnostics={"error": str(e), "is_insufficient": True},
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
- module: `ai-service`
|
||||
- feature: `AISVC` (Python AI 中台)
|
||||
- status: 🔄 进行中
|
||||
- status: ✅ 已完成
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -26,35 +26,32 @@
|
|||
- [x] Phase 1: 基础设施(FastAPI 框架与多租户基础) (100%) ✅
|
||||
- [x] Phase 2: 存储与检索实现(Memory & Retrieval) (100%) ✅
|
||||
- [x] Phase 3: 核心编排(Orchestrator & LLM Adapter) (100%) ✅
|
||||
- [ ] Phase 4: 流式响应(SSE 实现与状态机) (0%) ⏳
|
||||
- [ ] Phase 5: 集成与冒烟测试(Quality Assurance) (0%) ⏳
|
||||
- [x] Phase 4: 流式响应(SSE 实现与状态机) (100%) ✅
|
||||
- [x] Phase 5: 集成与冒烟测试(Quality Assurance) (100%) ✅
|
||||
- [x] Phase 6: 前后端联调真实对接 (100%) ✅
|
||||
- [x] Phase 7: 嵌入模型可插拔与文档解析 (100%) ✅
|
||||
- [x] Phase 8: LLM 配置与 RAG 调试输出 (100%) ✅
|
||||
- [x] Phase 9: 租户管理与 RAG 优化 (100%) ✅
|
||||
|
||||
---
|
||||
|
||||
## 🔄 Current Phase
|
||||
|
||||
### Goal
|
||||
实现 SSE 流式响应,包括 Accept 头切换、事件生成和状态机管理。
|
||||
Phase 9 已完成!项目进入稳定迭代阶段。
|
||||
|
||||
### Sub Tasks
|
||||
### Completed Tasks (Phase 9)
|
||||
|
||||
#### Phase 4: 流式响应(SSE 实现与状态机)
|
||||
- [ ] T4.1 在 API 层实现基于 `Accept` 头的响应模式自动切换逻辑 `[AC-AISVC-06]`
|
||||
- [ ] T4.2 实现 SSE 事件生成器:根据 Orchestrator 的增量输出包装 `message` 事件 `[AC-AISVC-07]`
|
||||
- [ ] T4.3 实现 SSE 状态机:确保 `final` 或 `error` 事件后连接正确关闭,且顺序不乱 `[AC-AISVC-08, AC-AISVC-09]`
|
||||
- [ ] T4.4 实现流式输出过程中的异常捕获,并转化为 `event: error` 输出 `[AC-AISVC-09]`
|
||||
|
||||
### Next Action (Must be Specific)
|
||||
|
||||
**Immediate**: Phase 3 已完成!准备执行 Phase 4。
|
||||
|
||||
**Note**: Phase 4 的 SSE 功能大部分已在 Phase 1-3 中提前实现:
|
||||
- Accept 头切换已在 `test_accept_switching.py` 测试
|
||||
- SSE 状态机已在 `app/core/sse.py` 实现
|
||||
- SSE 事件生成器已实现
|
||||
- Orchestrator 流式生成已实现
|
||||
|
||||
**建议**: 跳过 Phase 4,直接执行 Phase 5 集成测试。
|
||||
- [x] T9.1 实现 `Tenant` 实体:定义租户数据模型 `[AC-AISVC-10]` ✅
|
||||
- [x] T9.2 实现租户 ID 格式校验:`name@ash@year` 格式验证 `[AC-AISVC-10, AC-AISVC-12]` ✅
|
||||
- [x] T9.3 实现租户自动创建:请求时自动创建不存在的租户 `[AC-AISVC-10]` ✅
|
||||
- [x] T9.4 实现 `GET /admin/tenants` API:返回租户列表 `[AC-AISVC-10]` ✅
|
||||
- [x] T9.5 前端租户选择器:实现租户切换功能 `[AC-ASA-01]` ✅
|
||||
- [x] T9.6 文档多编码支持:支持 UTF-8、GBK、GB2312 等编码解码 `[AC-AISVC-21]` ✅
|
||||
- [x] T9.7 按行分块功能:实现 `chunk_text_by_lines` 函数 `[AC-AISVC-22]` ✅
|
||||
- [x] T9.8 实现 `NomicEmbeddingProvider`:支持多维度向量 `[AC-AISVC-29]` ✅
|
||||
- [x] T9.9 实现多向量存储:支持 full/256/512 三种维度 `[AC-AISVC-16]` ✅
|
||||
- [x] T9.10 实现 `KnowledgeIndexer`:优化的知识库索引服务 `[AC-AISVC-22]` ✅
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -65,121 +62,69 @@
|
|||
- `ai-service/`
|
||||
- `app/`
|
||||
- `api/` - FastAPI 路由层
|
||||
- `admin/tenants.py` - 租户管理 API ✅
|
||||
- `core/` - 配置、异常、中间件、SSE
|
||||
- `middleware.py` - 租户 ID 格式校验与自动创建 ✅
|
||||
- `models/` - Pydantic 模型和 SQLModel 实体
|
||||
- `entities.py` - Tenant 实体 ✅
|
||||
- `services/`
|
||||
- `llm/` - LLM Adapter 实现 ✅
|
||||
- `base.py` - LLMClient 抽象接口
|
||||
- `openai_client.py` - OpenAI 兼容客户端
|
||||
- `memory.py` - Memory 服务
|
||||
- `orchestrator.py` - 编排服务 ✅ (完整实现)
|
||||
- `context.py` - 上下文合并 ✅
|
||||
- `confidence.py` - 置信度计算 ✅
|
||||
- `embedding/nomic_provider.py` - Nomic 嵌入提供者 ✅
|
||||
- `retrieval/` - 检索层
|
||||
- `tests/` - 单元测试 (184 tests)
|
||||
- `indexer.py` - 知识库索引服务 ✅
|
||||
- `metadata.py` - 元数据模型 ✅
|
||||
- `optimized_retriever.py` - 优化检索器 ✅
|
||||
- `tests/` - 单元测试
|
||||
|
||||
### Key Decisions (Why / Impact)
|
||||
|
||||
- decision: LLM Adapter 使用 httpx 而非 langchain-openai
|
||||
reason: 更轻量、更可控、减少依赖
|
||||
impact: 需要手动处理 OpenAI API 响应解析
|
||||
- decision: 租户 ID 格式采用 `name@ash@year` 格式
|
||||
reason: 便于解析和展示租户信息
|
||||
impact: 中间件自动校验格式并解析
|
||||
|
||||
- decision: 使用 tenacity 实现重试逻辑
|
||||
reason: 简单可靠的重试机制
|
||||
impact: 提高服务稳定性
|
||||
- decision: 租户自动创建策略
|
||||
reason: 简化租户管理流程,无需预先创建
|
||||
impact: 首次请求时自动创建租户记录
|
||||
|
||||
- decision: Orchestrator 使用依赖注入模式
|
||||
reason: 便于测试和组件替换
|
||||
impact: 所有组件可通过构造函数注入
|
||||
- decision: 多维度向量存储(full/256/512)
|
||||
reason: 支持不同检索场景的性能优化
|
||||
impact: Qdrant 使用 named vector 存储
|
||||
|
||||
- decision: 使用 GenerationContext 数据类追踪生成流程
|
||||
reason: 清晰追踪中间结果和诊断信息
|
||||
impact: 便于调试和问题排查
|
||||
|
||||
- decision: Pydantic 模型使用 alias 实现驼峰命名
|
||||
reason: 符合 OpenAPI 契约的 camelCase 要求
|
||||
impact: JSON 序列化时自动转换字段名
|
||||
|
||||
### Code Snippets
|
||||
|
||||
```python
|
||||
# [AC-AISVC-02] ChatResponse with contract-compliant field names
|
||||
response = ChatResponse(
|
||||
reply="AI response",
|
||||
confidence=0.85,
|
||||
should_transfer=False,
|
||||
)
|
||||
json_str = response.model_dump_json(by_alias=True)
|
||||
# Output: {"reply": "AI response", "confidence": 0.85, "shouldTransfer": false, ...}
|
||||
```
|
||||
- decision: 文档多编码支持
|
||||
reason: 兼容中文文档的各种编码格式
|
||||
impact: 按优先级尝试多种编码解码
|
||||
|
||||
---
|
||||
|
||||
## 🧾 Session History
|
||||
|
||||
### Session #1 (2026-02-24)
|
||||
### Session #6 (2026-02-25)
|
||||
- completed:
|
||||
- T3.1 实现 LLM Adapter
|
||||
- 创建 LLMClient 抽象接口 (base.py)
|
||||
- 实现 OpenAIClient (openai_client.py)
|
||||
- 编写单元测试 (test_llm_adapter.py)
|
||||
- 修复 entities.py JSON 类型问题
|
||||
- T9.1-T9.10 租户管理与 RAG 优化功能
|
||||
- 实现 Tenant 实体和租户管理 API
|
||||
- 实现租户 ID 格式校验与自动创建
|
||||
- 实现前端租户选择器
|
||||
- 实现文档多编码支持
|
||||
- 实现按行分块功能
|
||||
- 实现 NomicEmbeddingProvider
|
||||
- 实现多维度向量存储
|
||||
- 实现 KnowledgeIndexer
|
||||
- changes:
|
||||
- 新增 `app/services/llm/__init__.py`
|
||||
- 新增 `app/services/llm/base.py`
|
||||
- 新增 `app/services/llm/openai_client.py`
|
||||
- 新增 `tests/test_llm_adapter.py`
|
||||
- 更新 `app/core/config.py` 添加 LLM 配置
|
||||
- 修复 `app/models/entities.py` JSON 列类型
|
||||
|
||||
### Session #2 (2026-02-24)
|
||||
- completed:
|
||||
- T3.2 实现上下文合并逻辑
|
||||
- 创建 ContextMerger 类 (context.py)
|
||||
- 实现消息指纹计算 (SHA256)
|
||||
- 实现去重和截断策略
|
||||
- 编写单元测试 (test_context.py)
|
||||
- changes:
|
||||
- 新增 `app/services/context.py`
|
||||
- 新增 `tests/test_context.py`
|
||||
|
||||
### Session #3 (2026-02-24)
|
||||
- completed:
|
||||
- T3.3 实现置信度计算与转人工逻辑
|
||||
- 创建 ConfidenceCalculator 类 (confidence.py)
|
||||
- 实现检索不足判定
|
||||
- 实现置信度计算策略
|
||||
- 实现 shouldTransfer 逻辑
|
||||
- 编写单元测试 (test_confidence.py)
|
||||
- changes:
|
||||
- 新增 `app/services/confidence.py`
|
||||
- 新增 `tests/test_confidence.py`
|
||||
- 更新 `app/core/config.py` 添加置信度配置
|
||||
|
||||
### Session #4 (2026-02-24)
|
||||
- completed:
|
||||
- T3.4 实现 Orchestrator 完整生成闭环
|
||||
- 整合 Memory、ContextMerger、Retriever、LLMClient、ConfidenceCalculator
|
||||
- 实现 generate() 方法完整流程 (8 步)
|
||||
- 创建 GenerationContext 数据类追踪生成流程
|
||||
- 实现 fallback 响应机制
|
||||
- 编写单元测试 (test_orchestrator.py, 21 tests)
|
||||
- changes:
|
||||
- 更新 `app/services/orchestrator.py` 完整实现
|
||||
- 新增 `tests/test_orchestrator.py`
|
||||
- tests_passed: 138 tests (all passing)
|
||||
|
||||
### Session #5 (2026-02-24)
|
||||
- completed:
|
||||
- T3.5 验证 non-streaming 响应字段符合 OpenAPI 契约
|
||||
- 验证 ChatResponse 字段与契约一致性
|
||||
- 验证 JSON 序列化使用 camelCase
|
||||
- 验证必填字段和可选字段
|
||||
- 验证 confidence 范围约束
|
||||
- 编写契约验证测试 (test_contract.py, 23 tests)
|
||||
- changes:
|
||||
- 新增 `tests/test_contract.py`
|
||||
- tests_passed: 184 tests (all passing)
|
||||
- 新增 `app/models/entities.py` Tenant 实体
|
||||
- 更新 `app/core/middleware.py` 租户校验逻辑
|
||||
- 新增 `app/api/admin/tenants.py` 租户管理 API
|
||||
- 新增 `ai-service-admin/src/api/tenant.ts` 前端 API
|
||||
- 更新 `ai-service-admin/src/App.vue` 租户选择器
|
||||
- 更新 `ai-service/app/api/admin/kb.py` 多编码支持
|
||||
- 新增 `app/services/embedding/nomic_provider.py`
|
||||
- 新增 `app/services/retrieval/indexer.py`
|
||||
- 新增 `app/services/retrieval/metadata.py`
|
||||
- 新增 `app/services/retrieval/optimized_retriever.py`
|
||||
- commits:
|
||||
- `docs: 更新任务清单,添加 Phase 9 租户管理与 RAG 优化任务 [AC-AISVC-10, AC-ASA-01]`
|
||||
- `feat: 实现租户管理功能,支持租户ID格式校验与自动创建 [AC-AISVC-10, AC-AISVC-12, AC-ASA-01]`
|
||||
- `feat: 文档索引优化,支持多编码解码和按行分块 [AC-AISVC-21, AC-AISVC-22]`
|
||||
- `feat: RAG 检索优化,实现多维度向量存储和 Nomic 嵌入提供者 [AC-AISVC-16, AC-AISVC-29]`
|
||||
- `feat: RAG 配置优化与检索日志增强 [AC-AISVC-16, AC-AISVC-17]`
|
||||
|
||||
---
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
---
|
||||
module: ai-service-admin
|
||||
title: "AI 中台管理界面(ai-service-admin)任务清单"
|
||||
status: "draft"
|
||||
version: "0.2.0"
|
||||
status: "completed"
|
||||
version: "0.4.0"
|
||||
owners:
|
||||
- "frontend"
|
||||
- "backend"
|
||||
last_updated: "2026-02-24"
|
||||
last_updated: "2026-02-25"
|
||||
principles:
|
||||
- atomic
|
||||
- page-oriented
|
||||
|
|
@ -218,3 +218,28 @@ principles:
|
|||
| P6-08 | Token 统计展示 | ✅ 已完成 |
|
||||
| P6-09 | LLM 选择器 | ✅ 已完成 |
|
||||
| P6-10 | RAG 实验室整合 | ✅ 已完成 |
|
||||
|
||||
---
|
||||
|
||||
## Phase 7: 租户管理(v0.4.0)
|
||||
|
||||
> 页面导向:租户选择器与租户管理功能。
|
||||
|
||||
- [x] (P7-01) 租户 API 服务层:创建 src/api/tenant.ts 和 src/types/tenant.ts
|
||||
- AC: [AC-ASA-01]
|
||||
|
||||
- [x] (P7-02) 租户选择器组件:实现 `TenantSelector` 下拉组件,支持租户切换
|
||||
- AC: [AC-ASA-01]
|
||||
|
||||
- [x] (P7-03) 租户持久化:租户选择持久化到 localStorage
|
||||
- AC: [AC-ASA-01]
|
||||
|
||||
---
|
||||
|
||||
## Phase 7 任务进度追踪
|
||||
|
||||
| 任务 | 描述 | 状态 |
|
||||
|------|------|------|
|
||||
| P7-01 | 租户 API 服务层 | ✅ 已完成 |
|
||||
| P7-02 | 租户选择器组件 | ✅ 已完成 |
|
||||
| P7-03 | 租户持久化 | ✅ 已完成 |
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@
|
|||
feature_id: "AISVC"
|
||||
title: "Python AI 中台(ai-service)任务清单"
|
||||
status: "completed"
|
||||
version: "0.4.0"
|
||||
last_updated: "2026-02-24"
|
||||
version: "0.5.0"
|
||||
last_updated: "2026-02-25"
|
||||
---
|
||||
|
||||
# Python AI 中台任务清单(AISVC)
|
||||
|
|
@ -83,7 +83,7 @@ last_updated: "2026-02-24"
|
|||
|
||||
## 5. 完成总结
|
||||
|
||||
**Phase 1-7 已全部完成,Phase 8 进行中**
|
||||
**Phase 1-9 已全部完成**
|
||||
|
||||
| Phase | 描述 | 任务数 | 状态 |
|
||||
|-------|------|--------|------|
|
||||
|
|
@ -94,9 +94,10 @@ last_updated: "2026-02-24"
|
|||
| Phase 5 | 集成测试 | 3 | ✅ 完成 |
|
||||
| Phase 6 | 前后端联调真实对接 | 9 | ✅ 完成 |
|
||||
| Phase 7 | 嵌入模型可插拔与文档解析 | 21 | ✅ 完成 |
|
||||
| Phase 8 | LLM 配置与 RAG 调试输出 | 10 | ⏳ 进行中 |
|
||||
| Phase 8 | LLM 配置与 RAG 调试输出 | 10 | ✅ 完成 |
|
||||
| Phase 9 | 租户管理与 RAG 优化 | 10 | ✅ 完成 |
|
||||
|
||||
**已完成: 53 个任务 | 进行中: 10 个任务**
|
||||
**已完成: 73 个任务**
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -136,3 +137,17 @@ last_updated: "2026-02-24"
|
|||
- [x] T8.8 实现 RAG 实验流式输出:SSE 流式 AI 回复 `[AC-AISVC-48]` ✅
|
||||
- [x] T8.9 支持指定 LLM 提供者:RAG 实验可选择不同 LLM `[AC-AISVC-50]` ✅
|
||||
- [x] T8.10 更新 OpenAPI 契约:添加 LLM 管理和 RAG 实验增强接口 ✅
|
||||
|
||||
---
|
||||
|
||||
### Phase 9: 租户管理与 RAG 优化(v0.5.0 迭代)
|
||||
- [x] T9.1 实现 `Tenant` 实体:定义租户数据模型 `[AC-AISVC-10]` ✅
|
||||
- [x] T9.2 实现租户 ID 格式校验:`name@ash@year` 格式验证 `[AC-AISVC-10, AC-AISVC-12]` ✅
|
||||
- [x] T9.3 实现租户自动创建:请求时自动创建不存在的租户 `[AC-AISVC-10]` ✅
|
||||
- [x] T9.4 实现 `GET /admin/tenants` API:返回租户列表 `[AC-AISVC-10]` ✅
|
||||
- [x] T9.5 前端租户选择器:实现租户切换功能 `[AC-ASA-01]` ✅
|
||||
- [x] T9.6 文档多编码支持:支持 UTF-8、GBK、GB2312 等编码解码 `[AC-AISVC-21]` ✅
|
||||
- [x] T9.7 按行分块功能:实现 `chunk_text_by_lines` 函数 `[AC-AISVC-22]` ✅
|
||||
- [x] T9.8 实现 `NomicEmbeddingProvider`:支持多维度向量 `[AC-AISVC-29]` ✅
|
||||
- [x] T9.9 实现多向量存储:支持 full/256/512 三种维度 `[AC-AISVC-16]` ✅
|
||||
- [x] T9.10 实现 `KnowledgeIndexer`:优化的知识库索引服务 `[AC-AISVC-22]` ✅
|
||||
|
|
|
|||
Loading…
Reference in New Issue