Compare commits
13 Commits
1b71c29ddb
...
a0044c4c42
| Author | SHA1 | Date |
|---|---|---|
|
|
a0044c4c42 | |
|
|
a61fb72d2b | |
|
|
60e16d65c9 | |
|
|
a6276522c8 | |
|
|
1490235b8f | |
|
|
9196247578 | |
|
|
b3343f9e52 | |
|
|
6fec2a755a | |
|
|
e45396e1e4 | |
|
|
e9de808969 | |
|
|
4de2a2aece | |
|
|
b3680bda8a | |
|
|
7134ec3c5e |
|
|
@ -6,7 +6,9 @@ import type {
|
|||
LLMTestResult,
|
||||
LLMTestRequest,
|
||||
LLMProvidersResponse,
|
||||
LLMConfigUpdateResponse
|
||||
LLMUsageTypesResponse,
|
||||
LLMConfigUpdateResponse,
|
||||
LLMAllConfigs
|
||||
} from '@/types/llm'
|
||||
|
||||
export function getLLMProviders(): Promise<LLMProvidersResponse> {
|
||||
|
|
@ -16,10 +18,22 @@ export function getLLMProviders(): Promise<LLMProvidersResponse> {
|
|||
})
|
||||
}
|
||||
|
||||
export function getLLMConfig(): Promise<LLMConfig> {
|
||||
export function getLLMUsageTypes(): Promise<LLMUsageTypesResponse> {
|
||||
return request({
|
||||
url: '/admin/llm/usage-types',
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
||||
export function getLLMConfig(usageType?: string): Promise<LLMConfig | LLMAllConfigs> {
|
||||
const params: Record<string, string> = {}
|
||||
if (usageType) {
|
||||
params.usage_type = usageType
|
||||
}
|
||||
return request({
|
||||
url: '/admin/llm/config',
|
||||
method: 'get'
|
||||
method: 'get',
|
||||
params
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -46,5 +60,7 @@ export type {
|
|||
LLMTestResult,
|
||||
LLMTestRequest,
|
||||
LLMProvidersResponse,
|
||||
LLMConfigUpdateResponse
|
||||
LLMUsageTypesResponse,
|
||||
LLMConfigUpdateResponse,
|
||||
LLMAllConfigs
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,26 +2,40 @@ import { defineStore } from 'pinia'
|
|||
import { ref, computed } from 'vue'
|
||||
import {
|
||||
getLLMProviders,
|
||||
getLLMUsageTypes,
|
||||
getLLMConfig,
|
||||
updateLLMConfig,
|
||||
testLLM,
|
||||
type LLMProviderInfo,
|
||||
type LLMConfig,
|
||||
type LLMConfigUpdate,
|
||||
type LLMTestResult
|
||||
type LLMTestResult,
|
||||
type LLMUsageType,
|
||||
type LLMAllConfigs
|
||||
} from '@/api/llm'
|
||||
|
||||
export const useLLMStore = defineStore('llm', () => {
|
||||
const providers = ref<LLMProviderInfo[]>([])
|
||||
const currentConfig = ref<LLMConfig>({
|
||||
provider: '',
|
||||
config: {}
|
||||
const usageTypes = ref<LLMUsageType[]>([])
|
||||
const allConfigs = ref<LLMAllConfigs>({
|
||||
chat: { provider: '', config: {} },
|
||||
kb_processing: { provider: '', config: {} }
|
||||
})
|
||||
const currentUsageType = ref<string>('chat')
|
||||
const loading = ref(false)
|
||||
const providersLoading = ref(false)
|
||||
const testResult = ref<LLMTestResult | null>(null)
|
||||
const testLoading = ref(false)
|
||||
|
||||
const currentConfig = computed(() => {
|
||||
const config = allConfigs.value[currentUsageType.value as keyof LLMAllConfigs]
|
||||
return {
|
||||
provider: config?.provider || '',
|
||||
config: config?.config || {},
|
||||
usage_type: currentUsageType.value
|
||||
}
|
||||
})
|
||||
|
||||
const currentProvider = computed(() => {
|
||||
return providers.value.find(p => p.name === currentConfig.value.provider)
|
||||
})
|
||||
|
|
@ -43,16 +57,29 @@ export const useLLMStore = defineStore('llm', () => {
|
|||
}
|
||||
}
|
||||
|
||||
const loadUsageTypes = async () => {
|
||||
try {
|
||||
const res: any = await getLLMUsageTypes()
|
||||
usageTypes.value = res?.usage_types || res?.data?.usage_types || []
|
||||
} catch (error) {
|
||||
console.error('Failed to load LLM usage types:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
const loadConfig = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const res: any = await getLLMConfig()
|
||||
const config = res?.data || res
|
||||
if (config) {
|
||||
currentConfig.value = {
|
||||
provider: config.provider || '',
|
||||
config: config.config || {},
|
||||
updated_at: config.updated_at
|
||||
const configs = res?.data || res
|
||||
if (configs) {
|
||||
if (configs.chat && configs.kb_processing) {
|
||||
allConfigs.value = configs
|
||||
} else {
|
||||
allConfigs.value = {
|
||||
chat: { provider: configs.provider || '', config: configs.config || {} },
|
||||
kb_processing: { provider: configs.provider || '', config: configs.config || {} }
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
|
|
@ -68,7 +95,8 @@ export const useLLMStore = defineStore('llm', () => {
|
|||
try {
|
||||
const updateData: LLMConfigUpdate = {
|
||||
provider: currentConfig.value.provider,
|
||||
config: currentConfig.value.config
|
||||
config: currentConfig.value.config,
|
||||
usage_type: currentUsageType.value
|
||||
}
|
||||
await updateLLMConfig(updateData)
|
||||
} catch (error) {
|
||||
|
|
@ -86,7 +114,8 @@ export const useLLMStore = defineStore('llm', () => {
|
|||
const result = await testLLM({
|
||||
test_prompt: testPrompt,
|
||||
provider: currentConfig.value.provider,
|
||||
config: currentConfig.value.config
|
||||
config: currentConfig.value.config,
|
||||
usage_type: currentUsageType.value
|
||||
})
|
||||
testResult.value = result
|
||||
return result
|
||||
|
|
@ -103,7 +132,8 @@ export const useLLMStore = defineStore('llm', () => {
|
|||
}
|
||||
|
||||
const setProvider = (providerName: string) => {
|
||||
currentConfig.value.provider = providerName
|
||||
const usageTypeKey = currentUsageType.value as keyof LLMAllConfigs
|
||||
allConfigs.value[usageTypeKey].provider = providerName
|
||||
const provider = providers.value.find(p => p.name === providerName)
|
||||
if (provider?.config_schema?.properties) {
|
||||
const newConfig: Record<string, any> = {}
|
||||
|
|
@ -127,14 +157,19 @@ export const useLLMStore = defineStore('llm', () => {
|
|||
}
|
||||
}
|
||||
})
|
||||
currentConfig.value.config = newConfig
|
||||
allConfigs.value[usageTypeKey].config = newConfig
|
||||
} else {
|
||||
currentConfig.value.config = {}
|
||||
allConfigs.value[usageTypeKey].config = {}
|
||||
}
|
||||
}
|
||||
|
||||
const updateConfigValue = (key: string, value: any) => {
|
||||
currentConfig.value.config[key] = value
|
||||
const usageTypeKey = currentUsageType.value as keyof LLMAllConfigs
|
||||
allConfigs.value[usageTypeKey].config[key] = value
|
||||
}
|
||||
|
||||
const setCurrentUsageType = (usageType: string) => {
|
||||
currentUsageType.value = usageType
|
||||
}
|
||||
|
||||
const clearTestResult = () => {
|
||||
|
|
@ -143,6 +178,9 @@ export const useLLMStore = defineStore('llm', () => {
|
|||
|
||||
return {
|
||||
providers,
|
||||
usageTypes,
|
||||
allConfigs,
|
||||
currentUsageType,
|
||||
currentConfig,
|
||||
loading,
|
||||
providersLoading,
|
||||
|
|
@ -151,11 +189,13 @@ export const useLLMStore = defineStore('llm', () => {
|
|||
currentProvider,
|
||||
configSchema,
|
||||
loadProviders,
|
||||
loadUsageTypes,
|
||||
loadConfig,
|
||||
saveCurrentConfig,
|
||||
runTest,
|
||||
setProvider,
|
||||
updateConfigValue,
|
||||
setCurrentUsageType,
|
||||
clearTestResult
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -5,15 +5,32 @@ export interface LLMProviderInfo {
|
|||
config_schema: Record<string, any>
|
||||
}
|
||||
|
||||
export interface LLMUsageType {
|
||||
name: string
|
||||
display_name: string
|
||||
description: string
|
||||
}
|
||||
|
||||
export interface LLMConfig {
|
||||
provider: string
|
||||
config: Record<string, any>
|
||||
updated_at?: string
|
||||
}
|
||||
|
||||
export interface LLMConfigByUsage {
|
||||
provider: string
|
||||
config: Record<string, any>
|
||||
}
|
||||
|
||||
export interface LLMAllConfigs {
|
||||
chat: LLMConfigByUsage
|
||||
kb_processing: LLMConfigByUsage
|
||||
}
|
||||
|
||||
export interface LLMConfigUpdate {
|
||||
provider: string
|
||||
config?: Record<string, any>
|
||||
usage_type?: string
|
||||
}
|
||||
|
||||
export interface LLMTestResult {
|
||||
|
|
@ -31,12 +48,17 @@ export interface LLMTestRequest {
|
|||
test_prompt?: string
|
||||
provider?: string
|
||||
config?: Record<string, any>
|
||||
usage_type?: string
|
||||
}
|
||||
|
||||
export interface LLMProvidersResponse {
|
||||
providers: LLMProviderInfo[]
|
||||
}
|
||||
|
||||
export interface LLMUsageTypesResponse {
|
||||
usage_types: LLMUsageType[]
|
||||
}
|
||||
|
||||
export interface LLMConfigUpdateResponse {
|
||||
success: boolean
|
||||
message: string
|
||||
|
|
|
|||
|
|
@ -4,13 +4,7 @@
|
|||
<div class="header-content">
|
||||
<div class="title-section">
|
||||
<h1 class="page-title">LLM 模型配置</h1>
|
||||
<p class="page-desc">配置和管理系统使用的大语言模型,支持多种提供者切换。配置修改后需保存才能生效。</p>
|
||||
</div>
|
||||
<div class="header-actions" v-if="currentConfig.updated_at">
|
||||
<div class="update-info">
|
||||
<el-icon class="update-icon"><Clock /></el-icon>
|
||||
<span>上次更新: {{ formatDate(currentConfig.updated_at) }}</span>
|
||||
</div>
|
||||
<p class="page-desc">配置和管理系统使用的大语言模型,支持多种提供者切换。可以为不同用途配置不同的模型。</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -30,6 +24,26 @@
|
|||
</template>
|
||||
|
||||
<div class="card-content">
|
||||
<div class="usage-type-section">
|
||||
<div class="section-label">
|
||||
<el-icon><Setting /></el-icon>
|
||||
<span>选择用途类型</span>
|
||||
</div>
|
||||
<el-radio-group v-model="currentUsageType" class="usage-type-radio" @change="handleUsageTypeChange">
|
||||
<el-radio-button
|
||||
v-for="ut in usageTypes"
|
||||
:key="ut.name"
|
||||
:value="ut.name"
|
||||
>
|
||||
<el-tooltip :content="ut.description" placement="top">
|
||||
<span>{{ ut.display_name }}</span>
|
||||
</el-tooltip>
|
||||
</el-radio-button>
|
||||
</el-radio-group>
|
||||
</div>
|
||||
|
||||
<el-divider />
|
||||
|
||||
<div class="provider-select-section">
|
||||
<div class="section-label">
|
||||
<el-icon><Connection /></el-icon>
|
||||
|
|
@ -103,7 +117,7 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { ElMessage, ElMessageBox } from 'element-plus'
|
||||
import { Cpu, Connection, InfoFilled, Box, RefreshLeft, Check, Clock } from '@element-plus/icons-vue'
|
||||
import { Cpu, Connection, InfoFilled, Box, RefreshLeft, Check, Setting } from '@element-plus/icons-vue'
|
||||
import { useLLMStore } from '@/stores/llm'
|
||||
import ProviderSelect from '@/components/common/ProviderSelect.vue'
|
||||
import ConfigForm from '@/components/common/ConfigForm.vue'
|
||||
|
|
@ -117,21 +131,18 @@ const saving = ref(false)
|
|||
const pageLoading = ref(false)
|
||||
|
||||
const providers = computed(() => llmStore.providers)
|
||||
const usageTypes = computed(() => llmStore.usageTypes)
|
||||
const currentConfig = computed(() => llmStore.currentConfig)
|
||||
const currentProvider = computed(() => llmStore.currentProvider)
|
||||
const configSchema = computed(() => llmStore.configSchema)
|
||||
const providersLoading = computed(() => llmStore.providersLoading)
|
||||
const currentUsageType = computed({
|
||||
get: () => llmStore.currentUsageType,
|
||||
set: (val) => llmStore.setCurrentUsageType(val)
|
||||
})
|
||||
|
||||
const formatDate = (dateStr: string) => {
|
||||
if (!dateStr) return ''
|
||||
const date = new Date(dateStr)
|
||||
return date.toLocaleString('zh-CN', {
|
||||
year: 'numeric',
|
||||
month: '2-digit',
|
||||
day: '2-digit',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit'
|
||||
})
|
||||
const handleUsageTypeChange = (usageType: string) => {
|
||||
llmStore.setCurrentUsageType(usageType)
|
||||
}
|
||||
|
||||
const handleProviderChange = (provider: any) => {
|
||||
|
|
@ -189,6 +200,7 @@ const initPage = async () => {
|
|||
try {
|
||||
await Promise.all([
|
||||
llmStore.loadProviders(),
|
||||
llmStore.loadUsageTypes(),
|
||||
llmStore.loadConfig()
|
||||
])
|
||||
} catch (error) {
|
||||
|
|
@ -253,27 +265,6 @@ onMounted(() => {
|
|||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.header-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.update-info {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 8px 14px;
|
||||
background-color: var(--bg-tertiary);
|
||||
border-radius: 8px;
|
||||
font-size: 13px;
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
.update-icon {
|
||||
font-size: 14px;
|
||||
color: var(--text-tertiary);
|
||||
}
|
||||
|
||||
.config-card {
|
||||
animation: fadeInUp 0.5s ease-out;
|
||||
}
|
||||
|
|
@ -329,8 +320,18 @@ onMounted(() => {
|
|||
padding: 8px 0;
|
||||
}
|
||||
|
||||
.provider-select-section {
|
||||
margin-bottom: 16px;
|
||||
.usage-type-section {
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.usage-type-radio {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.usage-type-radio :deep(.el-radio-button__inner) {
|
||||
padding: 10px 20px;
|
||||
}
|
||||
|
||||
.section-label {
|
||||
|
|
@ -347,6 +348,10 @@ onMounted(() => {
|
|||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
.provider-select-section {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.provider-info {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import json
|
|||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Optional
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
import tiktoken
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
|
||||
|
|
@ -38,6 +39,20 @@ from app.services.metadata_field_definition_service import MetadataFieldDefiniti
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
settings = get_settings()
|
||||
kb_vector_logger = logging.getLogger("kb_vector_payload")
|
||||
if settings.kb_vector_log_enabled and not kb_vector_logger.handlers:
|
||||
handler = RotatingFileHandler(
|
||||
filename=settings.kb_vector_log_path,
|
||||
maxBytes=10 * 1024 * 1024,
|
||||
backupCount=5,
|
||||
encoding="utf-8",
|
||||
)
|
||||
handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
|
||||
kb_vector_logger.addHandler(handler)
|
||||
kb_vector_logger.setLevel(logging.INFO)
|
||||
kb_vector_logger.propagate = False
|
||||
|
||||
router = APIRouter(prefix="/admin/kb", tags=["KB Management"])
|
||||
|
||||
|
||||
|
|
@ -661,6 +676,7 @@ async def _index_document(
|
|||
logger.info(f"[INDEX] File extension: {file_ext}, content size: {len(content)} bytes")
|
||||
|
||||
text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}
|
||||
image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"}
|
||||
|
||||
if file_ext in text_extensions or not file_ext:
|
||||
logger.info("[INDEX] Treating as text file, trying multiple encodings")
|
||||
|
|
@ -676,6 +692,44 @@ async def _index_document(
|
|||
if text is None:
|
||||
text = content.decode("utf-8", errors="replace")
|
||||
logger.warning("[INDEX] Failed to decode with known encodings, using utf-8 with replacement")
|
||||
elif file_ext in image_extensions:
|
||||
logger.info("[INDEX] Image file detected, will parse with multimodal LLM")
|
||||
await kb_service.update_job_status(
|
||||
tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
|
||||
tmp_file.write(content)
|
||||
tmp_path = tmp_file.name
|
||||
|
||||
logger.info(f"[INDEX] Temp file created: {tmp_path}")
|
||||
|
||||
try:
|
||||
from app.services.document.image_parser import ImageParser
|
||||
logger.info(f"[INDEX] Starting image parsing for {file_ext}...")
|
||||
image_parser = ImageParser()
|
||||
image_result = await image_parser.parse_with_chunks(tmp_path)
|
||||
text = image_result.raw_text
|
||||
parse_result = type('ParseResult', (), {
|
||||
'text': text,
|
||||
'metadata': image_result.metadata,
|
||||
'pages': None,
|
||||
'image_chunks': image_result.chunks,
|
||||
'image_summary': image_result.image_summary,
|
||||
})()
|
||||
logger.info(
|
||||
f"[INDEX] Parsed image SUCCESS: {filename}, "
|
||||
f"chars={len(text)}, chunks={len(image_result.chunks)}, "
|
||||
f"summary={image_result.image_summary[:50] if image_result.image_summary else 'N/A'}..."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[INDEX] Image parsing error: {type(e).__name__}: {e}")
|
||||
text = ""
|
||||
parse_result = None
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
logger.info("[INDEX] Temp file cleaned up")
|
||||
else:
|
||||
logger.info("[INDEX] Binary file detected, will parse with document parser")
|
||||
await kb_service.update_job_status(
|
||||
|
|
@ -723,13 +777,64 @@ async def _index_document(
|
|||
)
|
||||
await session.commit()
|
||||
|
||||
from app.services.metadata_auto_inference_service import MetadataAutoInferenceService
|
||||
|
||||
inference_service = MetadataAutoInferenceService(session)
|
||||
|
||||
image_base64_for_inference = None
|
||||
mime_type_for_inference = None
|
||||
if file_ext in image_extensions:
|
||||
import base64
|
||||
image_base64_for_inference = base64.b64encode(content).decode("utf-8")
|
||||
mime_type_map = {
|
||||
".jpg": "image/jpeg", ".jpeg": "image/jpeg",
|
||||
".png": "image/png", ".gif": "image/gif",
|
||||
".webp": "image/webp", ".bmp": "image/bmp",
|
||||
".tiff": "image/tiff", ".tif": "image/tiff",
|
||||
}
|
||||
mime_type_for_inference = mime_type_map.get(file_ext, "image/jpeg")
|
||||
|
||||
logger.info("[INDEX] Starting metadata auto-inference...")
|
||||
inference_result = await inference_service.infer_metadata(
|
||||
tenant_id=tenant_id,
|
||||
content=text or "",
|
||||
scope="kb_document",
|
||||
existing_metadata=metadata,
|
||||
image_base64=image_base64_for_inference,
|
||||
mime_type=mime_type_for_inference,
|
||||
)
|
||||
|
||||
if inference_result.success:
|
||||
metadata = inference_result.inferred_metadata
|
||||
logger.info(
|
||||
f"[INDEX] Metadata inference SUCCESS: "
|
||||
f"inferred_fields={list(inference_result.inferred_metadata.keys())}, "
|
||||
f"confidence_scores={inference_result.confidence_scores}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[INDEX] Metadata inference FAILED: {inference_result.error_message}, "
|
||||
f"using existing metadata"
|
||||
)
|
||||
|
||||
logger.info("[INDEX] Getting embedding provider...")
|
||||
embedding_provider = await get_embedding_provider()
|
||||
logger.info(f"[INDEX] Embedding provider: {type(embedding_provider).__name__}")
|
||||
|
||||
all_chunks: list[TextChunk] = []
|
||||
|
||||
if parse_result and parse_result.pages:
|
||||
if parse_result and hasattr(parse_result, 'image_chunks') and parse_result.image_chunks:
|
||||
logger.info(f"[INDEX] Image with {len(parse_result.image_chunks)} intelligent chunks from LLM")
|
||||
for img_chunk in parse_result.image_chunks:
|
||||
all_chunks.append(TextChunk(
|
||||
text=img_chunk.content,
|
||||
start_token=img_chunk.chunk_index,
|
||||
end_token=img_chunk.chunk_index + 1,
|
||||
page=None,
|
||||
source=filename,
|
||||
))
|
||||
logger.info(f"[INDEX] Total chunks from image: {len(all_chunks)}")
|
||||
elif parse_result and parse_result.pages:
|
||||
logger.info(f"[INDEX] PDF with {len(parse_result.pages)} pages, using line-based chunking with page metadata")
|
||||
for page in parse_result.pages:
|
||||
page_chunks = chunk_text_by_lines(
|
||||
|
|
@ -807,6 +912,35 @@ async def _index_document(
|
|||
await session.commit()
|
||||
|
||||
if points:
|
||||
if settings.kb_vector_log_enabled:
|
||||
vector_payloads = []
|
||||
for point in points:
|
||||
if use_multi_vector:
|
||||
payload = {
|
||||
"id": point.get("id"),
|
||||
"vector": point.get("vector"),
|
||||
"payload": point.get("payload"),
|
||||
}
|
||||
else:
|
||||
payload = {
|
||||
"id": point.id,
|
||||
"vector": point.vector,
|
||||
"payload": point.payload,
|
||||
}
|
||||
vector_payloads.append(payload)
|
||||
|
||||
kb_vector_logger.info(json.dumps({
|
||||
"tenant_id": tenant_id,
|
||||
"kb_id": kb_id,
|
||||
"doc_id": doc_id,
|
||||
"job_id": job_id,
|
||||
"filename": filename,
|
||||
"file_ext": file_ext,
|
||||
"is_image": file_ext in image_extensions,
|
||||
"metadata": doc_metadata,
|
||||
"vectors": vector_payloads,
|
||||
}, ensure_ascii=False))
|
||||
|
||||
logger.info(f"[INDEX] Upserting {len(points)} vectors to Qdrant for kb_id={kb_id}...")
|
||||
if use_multi_vector:
|
||||
await qdrant.upsert_multi_vector(tenant_id, points, kb_id=kb_id)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ from fastapi import APIRouter, Depends, Header, HTTPException
|
|||
|
||||
from app.services.llm.factory import (
|
||||
LLMProviderFactory,
|
||||
LLMUsageType,
|
||||
LLM_USAGE_DISPLAY_NAMES,
|
||||
LLM_USAGE_DESCRIPTIONS,
|
||||
get_llm_config_manager,
|
||||
)
|
||||
|
||||
|
|
@ -49,25 +52,63 @@ async def list_providers(
|
|||
}
|
||||
|
||||
|
||||
@router.get("/usage-types")
|
||||
async def list_usage_types(
|
||||
tenant_id: str = Depends(get_tenant_id),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
List all available LLM usage types.
|
||||
"""
|
||||
logger.info(f"Listing LLM usage types for tenant={tenant_id}")
|
||||
|
||||
return {
|
||||
"usage_types": [
|
||||
{
|
||||
"name": ut.value,
|
||||
"display_name": LLM_USAGE_DISPLAY_NAMES[ut],
|
||||
"description": LLM_USAGE_DESCRIPTIONS[ut],
|
||||
}
|
||||
for ut in LLMUsageType
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_config(
|
||||
tenant_id: str = Depends(get_tenant_id),
|
||||
usage_type: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get current LLM configuration.
|
||||
[AC-ASA-14] Returns current provider and config.
|
||||
If usage_type is specified, returns config for that usage type.
|
||||
Otherwise, returns all configs.
|
||||
"""
|
||||
logger.info(f"[AC-ASA-14] Getting LLM config for tenant={tenant_id}")
|
||||
logger.info(f"[AC-ASA-14] Getting LLM config for tenant={tenant_id}, usage_type={usage_type}")
|
||||
|
||||
manager = get_llm_config_manager()
|
||||
config = manager.get_current_config()
|
||||
|
||||
if usage_type:
|
||||
try:
|
||||
ut = LLMUsageType(usage_type)
|
||||
config = manager.get_current_config(ut)
|
||||
masked_config = _mask_secrets(config.get("config", {}))
|
||||
|
||||
return {
|
||||
"usage_type": config["usage_type"],
|
||||
"provider": config["provider"],
|
||||
"config": masked_config,
|
||||
}
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid usage_type: {usage_type}")
|
||||
|
||||
all_configs = manager.get_current_config()
|
||||
result = {}
|
||||
for ut_key, config in all_configs.items():
|
||||
result[ut_key] = {
|
||||
"provider": config["provider"],
|
||||
"config": _mask_secrets(config.get("config", {})),
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
@router.put("/config")
|
||||
|
|
@ -78,11 +119,25 @@ async def update_config(
|
|||
"""
|
||||
Update LLM configuration.
|
||||
[AC-ASA-16] Updates provider and config with validation.
|
||||
|
||||
Request body format:
|
||||
- For specific usage type:
|
||||
{
|
||||
"usage_type": "chat" | "kb_processing",
|
||||
"provider": "openai",
|
||||
"config": {...}
|
||||
}
|
||||
- For all usage types (backward compatible):
|
||||
{
|
||||
"provider": "openai",
|
||||
"config": {...}
|
||||
}
|
||||
"""
|
||||
provider = body.get("provider")
|
||||
config = body.get("config", {})
|
||||
usage_type_str = body.get("usage_type")
|
||||
|
||||
logger.info(f"[AC-ASA-16] Updating LLM config for tenant={tenant_id}, provider={provider}")
|
||||
logger.info(f"[AC-ASA-16] Updating LLM config for tenant={tenant_id}, provider={provider}, usage_type={usage_type_str}")
|
||||
|
||||
if not provider:
|
||||
return {
|
||||
|
|
@ -92,6 +147,18 @@ async def update_config(
|
|||
|
||||
try:
|
||||
manager = get_llm_config_manager()
|
||||
|
||||
if usage_type_str:
|
||||
try:
|
||||
usage_type = LLMUsageType(usage_type_str)
|
||||
await manager.update_usage_config(usage_type, provider, config)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"LLM configuration updated for {usage_type_str} to {provider}",
|
||||
}
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid usage_type: {usage_type_str}")
|
||||
else:
|
||||
await manager.update_config(provider, config)
|
||||
|
||||
return {
|
||||
|
|
@ -115,23 +182,44 @@ async def test_connection(
|
|||
"""
|
||||
Test LLM connection.
|
||||
[AC-ASA-17, AC-ASA-18] Tests connection and returns response.
|
||||
|
||||
Request body format:
|
||||
{
|
||||
"test_prompt": "optional test prompt",
|
||||
"provider": "optional provider to test",
|
||||
"config": "optional config to test",
|
||||
"usage_type": "optional usage type to test"
|
||||
}
|
||||
"""
|
||||
body = body or {}
|
||||
|
||||
test_prompt = body.get("test_prompt", "你好,请简单介绍一下自己。")
|
||||
provider = body.get("provider")
|
||||
config = body.get("config")
|
||||
usage_type_str = body.get("usage_type")
|
||||
|
||||
logger.info(
|
||||
f"[AC-ASA-17] Testing LLM connection for tenant={tenant_id}, "
|
||||
f"provider={provider or 'current'}"
|
||||
f"provider={provider or 'current'}, usage_type={usage_type_str or 'default'}"
|
||||
)
|
||||
|
||||
manager = get_llm_config_manager()
|
||||
|
||||
usage_type = None
|
||||
if usage_type_str:
|
||||
try:
|
||||
usage_type = LLMUsageType(usage_type_str)
|
||||
except ValueError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid usage_type: {usage_type_str}",
|
||||
}
|
||||
|
||||
result = await manager.test_connection(
|
||||
test_prompt=test_prompt,
|
||||
provider=provider,
|
||||
config=config,
|
||||
usage_type=usage_type,
|
||||
)
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -0,0 +1,349 @@
|
|||
"""
|
||||
Retrieval Strategy API Endpoints.
|
||||
[AC-AISVC-RES-01~15] 策略管理 API。
|
||||
|
||||
Endpoints:
|
||||
- GET /strategy/retrieval/current - 获取当前策略状态
|
||||
- POST /strategy/retrieval/switch - 切换策略
|
||||
- POST /strategy/retrieval/validate - 验证策略配置
|
||||
- POST /strategy/retrieval/rollback - 回退策略
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.services.retrieval.strategy.config import (
|
||||
GrayscaleConfig,
|
||||
ModeRouterConfig,
|
||||
PipelineConfig,
|
||||
RetrievalStrategyConfig,
|
||||
RerankerConfig,
|
||||
RuntimeMode,
|
||||
StrategyType,
|
||||
)
|
||||
from app.services.retrieval.strategy.rollback_manager import (
|
||||
RollbackManager,
|
||||
RollbackResult,
|
||||
RollbackTrigger,
|
||||
get_rollback_manager,
|
||||
)
|
||||
from app.services.retrieval.strategy.strategy_router import (
|
||||
StrategyRouter,
|
||||
get_strategy_router,
|
||||
set_strategy_router,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/strategy/retrieval", tags=["strategy"])
|
||||
|
||||
|
||||
class GrayscaleConfigSchema(BaseModel):
|
||||
"""灰度配置 Schema。"""
|
||||
enabled: bool = False
|
||||
percentage: float = Field(default=0.0, ge=0.0, le=100.0)
|
||||
allowlist: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class RerankerConfigSchema(BaseModel):
|
||||
"""重排器配置 Schema。"""
|
||||
enabled: bool = False
|
||||
model: str = "cross-encoder"
|
||||
top_k_after_rerank: int = 5
|
||||
min_score_threshold: float = 0.3
|
||||
|
||||
|
||||
class ModeRouterConfigSchema(BaseModel):
|
||||
"""模式路由配置 Schema。"""
|
||||
runtime_mode: str = "direct"
|
||||
react_trigger_confidence_threshold: float = 0.6
|
||||
react_trigger_complexity_score: float = 0.5
|
||||
react_max_steps: int = 5
|
||||
direct_fallback_on_low_confidence: bool = True
|
||||
|
||||
|
||||
class PipelineConfigSchema(BaseModel):
|
||||
"""Pipeline 配置 Schema。"""
|
||||
top_k: int = 5
|
||||
score_threshold: float = 0.01
|
||||
min_hits: int = 1
|
||||
two_stage_enabled: bool = True
|
||||
|
||||
|
||||
class StrategyStatusResponse(BaseModel):
|
||||
"""策略状态响应。"""
|
||||
active_strategy: str
|
||||
grayscale: GrayscaleConfigSchema
|
||||
pipeline: PipelineConfigSchema
|
||||
reranker: RerankerConfigSchema
|
||||
mode_router: ModeRouterConfigSchema
|
||||
performance_thresholds: dict[str, float] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SwitchRequest(BaseModel):
|
||||
"""策略切换请求。"""
|
||||
active_strategy: str | None = None
|
||||
grayscale: GrayscaleConfigSchema | None = None
|
||||
mode_router: ModeRouterConfigSchema | None = None
|
||||
reranker: RerankerConfigSchema | None = None
|
||||
|
||||
|
||||
class SwitchResponse(BaseModel):
|
||||
"""策略切换响应。"""
|
||||
success: bool
|
||||
previous_strategy: str
|
||||
current_strategy: str
|
||||
message: str
|
||||
|
||||
|
||||
class ValidationRequest(BaseModel):
|
||||
"""策略验证请求。"""
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ValidationResponse(BaseModel):
|
||||
"""策略验证响应。"""
|
||||
valid: bool
|
||||
errors: list[str] = Field(default_factory=list)
|
||||
warnings: list[str] = Field(default_factory=list)
|
||||
config_summary: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class RollbackResponse(BaseModel):
|
||||
"""策略回退响应。"""
|
||||
success: bool
|
||||
previous_strategy: str
|
||||
current_strategy: str
|
||||
trigger: str
|
||||
reason: str
|
||||
audit_log_id: str | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/current",
|
||||
response_model=StrategyStatusResponse,
|
||||
summary="获取当前策略状态",
|
||||
description="【AC-AISVC-RES-01】 获取当前活跃的检索策略配置状态。",
|
||||
)
|
||||
async def get_current_strategy() -> StrategyStatusResponse:
|
||||
"""
|
||||
【AC-AISVC-RES-01】 获取当前策略状态。
|
||||
"""
|
||||
strategy_router = get_strategy_router()
|
||||
config = strategy_router.get_config()
|
||||
|
||||
return StrategyStatusResponse(
|
||||
active_strategy=config.active_strategy.value,
|
||||
grayscale=GrayscaleConfigSchema(
|
||||
enabled=config.grayscale.enabled,
|
||||
percentage=config.grayscale.percentage,
|
||||
allowlist=config.grayscale.allowlist,
|
||||
),
|
||||
pipeline=PipelineConfigSchema(
|
||||
top_k=config.pipeline.top_k,
|
||||
score_threshold=config.pipeline.score_threshold,
|
||||
min_hits=config.pipeline.min_hits,
|
||||
two_stage_enabled=config.pipeline.two_stage_enabled,
|
||||
),
|
||||
reranker=RerankerConfigSchema(
|
||||
enabled=config.reranker.enabled,
|
||||
model=config.reranker.model,
|
||||
top_k_after_rerank=config.reranker.top_k_after_rerank,
|
||||
min_score_threshold=config.reranker.min_score_threshold,
|
||||
),
|
||||
mode_router=ModeRouterConfigSchema(
|
||||
runtime_mode=config.mode_router.runtime_mode.value,
|
||||
react_trigger_confidence_threshold=config.mode_router.react_trigger_confidence_threshold,
|
||||
react_trigger_complexity_score=config.mode_router.react_trigger_complexity_score,
|
||||
react_max_steps=config.mode_router.react_max_steps,
|
||||
direct_fallback_on_low_confidence=config.mode_router.direct_fallback_on_low_confidence,
|
||||
),
|
||||
performance_thresholds=config.performance_thresholds,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/switch",
|
||||
response_model=SwitchResponse,
|
||||
summary="切换策略",
|
||||
description="【AC-AISVC-RES-02, AC-AISVC-RES-03】 切换检索策略, 支持灰度发布配置。",
|
||||
)
|
||||
async def switch_strategy(request: SwitchRequest) -> SwitchResponse:
|
||||
"""
|
||||
【AC-AISVC-RES-02, AC-AISVC-RES-03】 切换策略。
|
||||
|
||||
支持灰度发布配置(percentage/allowlist)。
|
||||
"""
|
||||
strategy_router = get_strategy_router()
|
||||
current_config = strategy_router.get_config()
|
||||
previous_strategy = current_config.active_strategy.value
|
||||
|
||||
try:
|
||||
new_active_strategy = StrategyType(request.active_strategy) if request.active_strategy else current_config.active_strategy
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid strategy type: {request.active_strategy}. Valid values are: default, enhanced",
|
||||
)
|
||||
|
||||
new_config = RetrievalStrategyConfig(
|
||||
active_strategy=new_active_strategy,
|
||||
grayscale=GrayscaleConfig(
|
||||
enabled=request.grayscale.enabled if request.grayscale else current_config.grayscale.enabled,
|
||||
percentage=request.grayscale.percentage if request.grayscale else current_config.grayscale.percentage,
|
||||
allowlist=request.grayscale.allowlist if request.grayscale else current_config.grayscale.allowlist,
|
||||
),
|
||||
mode_router=ModeRouterConfig(
|
||||
runtime_mode=RuntimeMode(request.mode_router.runtime_mode) if request.mode_router and request.mode_router.runtime_mode else current_config.mode_router.runtime_mode,
|
||||
react_trigger_confidence_threshold=request.mode_router.react_trigger_confidence_threshold if request.mode_router else current_config.mode_router.react_trigger_confidence_threshold,
|
||||
react_trigger_complexity_score=request.mode_router.react_trigger_complexity_score if request.mode_router else current_config.mode_router.react_trigger_complexity_score,
|
||||
react_max_steps=request.mode_router.react_max_steps if request.mode_router else current_config.mode_router.react_max_steps,
|
||||
direct_fallback_on_low_confidence=request.mode_router.direct_fallback_on_low_confidence if request.mode_router else current_config.mode_router.direct_fallback_on_low_confidence,
|
||||
) if request.mode_router else current_config.mode_router,
|
||||
reranker=RerankerConfig(
|
||||
enabled=request.reranker.enabled if request.reranker else current_config.reranker.enabled,
|
||||
model=request.reranker.model if request.reranker else current_config.reranker.model,
|
||||
top_k_after_rerank=request.reranker.top_k_after_rerank if request.reranker else current_config.reranker.top_k_after_rerank,
|
||||
min_score_threshold=request.reranker.min_score_threshold if request.reranker else current_config.reranker.min_score_threshold,
|
||||
) if request.reranker else current_config.reranker,
|
||||
pipeline=current_config.pipeline,
|
||||
metadata_inference=current_config.metadata_inference,
|
||||
performance_thresholds=current_config.performance_thresholds,
|
||||
)
|
||||
|
||||
strategy_router.update_config(new_config)
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-RES-02] Strategy switched: {previous_strategy} -> {new_config.active_strategy.value}"
|
||||
)
|
||||
|
||||
return SwitchResponse(
|
||||
success=True,
|
||||
previous_strategy=previous_strategy,
|
||||
current_strategy=new_config.active_strategy.value,
|
||||
message=f"Strategy switched from {previous_strategy} to {new_config.active_strategy.value}",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/validate",
|
||||
response_model=ValidationResponse,
|
||||
summary="验证策略配置",
|
||||
description="【AC-AISVC-RES-06, AC-AISVC-RES-08】 验证策略配置的完整性与一致性。",
|
||||
)
|
||||
async def validate_strategy(request: ValidationRequest) -> ValidationResponse:
|
||||
"""
|
||||
【AC-AISVC-RES-06, AC-AISVC-RES-08】 验证策略配置。
|
||||
"""
|
||||
errors: list[str] = []
|
||||
warnings: list[str] = []
|
||||
config = request.config
|
||||
|
||||
if "active_strategy" in config:
|
||||
if config["active_strategy"] not in ["default", "enhanced"]:
|
||||
errors.append(f"Invalid active_strategy: {config['active_strategy']}")
|
||||
|
||||
if "grayscale" in config:
|
||||
grayscale = config["grayscale"]
|
||||
if grayscale.get("percentage", 1.0) is not None:
|
||||
if not (0 <= grayscale["percentage"] <= 100):
|
||||
errors.append(f"Invalid grayscale percentage: {grayscale['percentage']}")
|
||||
|
||||
if "mode_router" in config:
|
||||
mode_router = config["mode_router"]
|
||||
if "runtime_mode" in mode_router:
|
||||
if mode_router["runtime_mode"] not in ["direct", "react", "auto"]:
|
||||
errors.append(f"Invalid runtime_mode: {mode_router['runtime_mode']}")
|
||||
|
||||
if "reranker" in config:
|
||||
reranker = config["reranker"]
|
||||
if reranker.get("enabled") and reranker.get("top_k_after_rerank", 1) > 20:
|
||||
warnings.append(f"top_k_after_rerank should be between 1 and 20, current: {reranker['top_k_after_rerank']}")
|
||||
|
||||
if "performance_thresholds" in config:
|
||||
thresholds = config["performance_thresholds"]
|
||||
if thresholds.get("max_latency_ms", 0) is not None and thresholds["max_latency_ms"] < 100:
|
||||
warnings.append(f"max_latency_ms seems too low: {thresholds['max_latency_ms']}ms")
|
||||
|
||||
return ValidationResponse(
|
||||
valid=len(errors) == 0,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
config_summary={
|
||||
"active_strategy": config.get("active_strategy", "default"),
|
||||
"grayscale_enabled": config.get("grayscale", {}).get("enabled", False),
|
||||
"runtime_mode": config.get("mode_router", {}).get("runtime_mode", "direct"),
|
||||
"reranker_enabled": config.get("reranker", {}).get("enabled", False),
|
||||
"performance_thresholds": config.get("performance_thresholds", {}),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rollback",
|
||||
response_model=RollbackResponse,
|
||||
summary="回退策略",
|
||||
description="【AC-AISVC-RES-07】 回退到默认策略。",
|
||||
)
|
||||
async def rollback_strategy(
|
||||
trigger: str = "manual",
|
||||
reason: str = "",
|
||||
) -> RollbackResponse:
|
||||
"""
|
||||
【AC-AISVC-RES-07】 回退策略。
|
||||
|
||||
支持手动触发和自动触发(性能退化、异常)。
|
||||
"""
|
||||
strategy_router = get_strategy_router()
|
||||
current_config = strategy_router.get_config()
|
||||
previous_strategy = current_config.active_strategy
|
||||
|
||||
if previous_strategy == StrategyType.DEFAULT:
|
||||
return RollbackResponse(
|
||||
success=False,
|
||||
previous_strategy=previous_strategy.value,
|
||||
current_strategy=previous_strategy.value,
|
||||
trigger=trigger,
|
||||
reason="Already on default strategy",
|
||||
audit_log_id=None,
|
||||
)
|
||||
|
||||
new_config = RetrievalStrategyConfig(
|
||||
active_strategy=StrategyType.DEFAULT,
|
||||
grayscale=current_config.grayscale,
|
||||
mode_router=current_config.mode_router,
|
||||
reranker=current_config.reranker,
|
||||
pipeline=current_config.pipeline,
|
||||
metadata_inference=current_config.metadata_inference,
|
||||
performance_thresholds=current_config.performance_thresholds,
|
||||
)
|
||||
strategy_router.update_config(new_config)
|
||||
|
||||
rollback_manager = get_rollback_manager()
|
||||
rollback_manager.update_config(new_config)
|
||||
|
||||
audit_log = rollback_manager.record_audit(
|
||||
action="rollback",
|
||||
details={
|
||||
"from_strategy": previous_strategy.value,
|
||||
"to_strategy": StrategyType.DEFAULT.value,
|
||||
"trigger": trigger,
|
||||
"reason": reason or "Manual rollback",
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-RES-07] Strategy rolled back: {previous_strategy.value} -> default"
|
||||
)
|
||||
|
||||
return RollbackResponse(
|
||||
success=True,
|
||||
previous_strategy=previous_strategy.value,
|
||||
current_strategy=StrategyType.DEFAULT.value,
|
||||
trigger=trigger,
|
||||
reason=reason or "Manual rollback",
|
||||
audit_log_id=str(audit_log.timestamp) if audit_log else None,
|
||||
)
|
||||
|
|
@ -580,6 +580,35 @@ async def respond_dialogue(
|
|||
f"guardrail={guardrail_result.triggered}, kb_hit={final_trace.kb_hit}"
|
||||
)
|
||||
|
||||
if dialogue_request.user_id:
|
||||
try:
|
||||
from app.services.mid.memory_adapter import MemoryAdapter
|
||||
from app.services.mid.memory_summary_generator import MemorySummaryGenerator
|
||||
|
||||
memory_adapter = MemoryAdapter(session=session)
|
||||
summary_generator = MemorySummaryGenerator()
|
||||
|
||||
history_messages = [
|
||||
{"role": h.role, "content": h.content}
|
||||
for h in (dialogue_request.history or [])
|
||||
]
|
||||
assistant_reply = "\n".join(s.text for s in final_segments)
|
||||
update_messages = history_messages + [
|
||||
{"role": "user", "content": dialogue_request.user_message},
|
||||
{"role": "assistant", "content": assistant_reply},
|
||||
]
|
||||
|
||||
await memory_adapter.update_with_summary_generation(
|
||||
user_id=dialogue_request.user_id,
|
||||
session_id=dialogue_request.session_id,
|
||||
messages=update_messages,
|
||||
tenant_id=tenant_id,
|
||||
summary_generator=summary_generator,
|
||||
recent_turns=8,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-IDMP-14] Memory update trigger failed: {e}")
|
||||
|
||||
return DialogueResponse(
|
||||
segments=final_segments,
|
||||
trace=final_trace,
|
||||
|
|
@ -1429,10 +1458,38 @@ async def _execute_agent_mode(
|
|||
|
||||
runtime_observer.record_react(request_id, react_ctx.iteration, react_ctx.tool_calls)
|
||||
|
||||
# 合并 tool_calls:优先使用 KB 工具内部的 trace(包含注入后的参数)
|
||||
final_tool_calls = list(react_ctx.tool_calls) if react_ctx.tool_calls else []
|
||||
logger.info(
|
||||
f"[TRACE-MERGE] Before merge: final_tool_calls count={len(final_tool_calls)}, "
|
||||
f"kb_dynamic_result exists={kb_dynamic_result is not None}, "
|
||||
f"kb_dynamic_result.tool_trace exists={kb_dynamic_result.tool_trace if kb_dynamic_result else None}"
|
||||
)
|
||||
if kb_dynamic_result and kb_dynamic_result.tool_trace:
|
||||
kb_trace = kb_dynamic_result.tool_trace
|
||||
logger.info(
|
||||
f"[TRACE-MERGE] KB trace arguments: {kb_trace.arguments}"
|
||||
)
|
||||
for i, tc in enumerate(final_tool_calls):
|
||||
logger.info(
|
||||
f"[TRACE-MERGE] Checking tool_call[{i}]: tool_name={tc.tool_name}"
|
||||
)
|
||||
if tc.tool_name == "kb_search_dynamic":
|
||||
logger.info(
|
||||
f"[TRACE-MERGE] Replacing trace at index {i}: old_args={tc.arguments}, new_args={kb_trace.arguments}"
|
||||
)
|
||||
final_tool_calls[i] = kb_trace
|
||||
break
|
||||
else:
|
||||
logger.info(
|
||||
f"[TRACE-MERGE] Skipped merge: kb_dynamic_result={kb_dynamic_result is not None}, "
|
||||
f"tool_trace={kb_dynamic_result.tool_trace if kb_dynamic_result else 'N/A'}"
|
||||
)
|
||||
|
||||
trace_logger.update_trace(
|
||||
request_id=request_id,
|
||||
react_iterations=react_ctx.iteration,
|
||||
tool_calls=react_ctx.tool_calls,
|
||||
tool_calls=final_tool_calls,
|
||||
)
|
||||
|
||||
segments = _text_to_segments(final_answer)
|
||||
|
|
@ -1444,8 +1501,8 @@ async def _execute_agent_mode(
|
|||
request_id=trace.request_id,
|
||||
generation_id=trace.generation_id,
|
||||
react_iterations=react_ctx.iteration,
|
||||
tools_used=[tc.tool_name for tc in react_ctx.tool_calls] if react_ctx.tool_calls else None,
|
||||
tool_calls=react_ctx.tool_calls,
|
||||
tools_used=[tc.tool_name for tc in final_tool_calls] if final_tool_calls else None,
|
||||
tool_calls=final_tool_calls,
|
||||
timeout_profile=timeout_governor.profile,
|
||||
kb_tool_called=True,
|
||||
kb_hit=kb_success and len(kb_hits) > 0,
|
||||
|
|
|
|||
|
|
@ -176,14 +176,38 @@ async def share_chat_page(
|
|||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>对话分享</title>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;500;600&display=swap" rel="stylesheet">
|
||||
<style>
|
||||
:root {{
|
||||
--bg-primary: #f8fafc;
|
||||
--bg-secondary: #ffffff;
|
||||
--bg-tertiary: #f1f5f9;
|
||||
--text-primary: #0f172a;
|
||||
--text-secondary: #475569;
|
||||
--text-muted: #94a3b8;
|
||||
--accent: #6366f1;
|
||||
--accent-hover: #4f46e5;
|
||||
--accent-light: #eef2ff;
|
||||
--border: #e2e8f0;
|
||||
--shadow-sm: 0 1px 2px rgba(0,0,0,0.04);
|
||||
--shadow-md: 0 4px 12px rgba(0,0,0,0.06);
|
||||
--shadow-lg: 0 8px 32px rgba(0,0,0,0.08);
|
||||
--radius-sm: 8px;
|
||||
--radius-md: 12px;
|
||||
--radius-lg: 16px;
|
||||
--radius-xl: 24px;
|
||||
}}
|
||||
* {{ box-sizing: border-box; margin: 0; padding: 0; }}
|
||||
body {{
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'PingFang SC', 'Microsoft YaHei', sans-serif;
|
||||
background: #f8f9fa;
|
||||
font-family: 'Plus Jakarta Sans', -apple-system, BlinkMacSystemFont, 'PingFang SC', 'Microsoft YaHei', sans-serif;
|
||||
background: var(--bg-primary);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
color: var(--text-primary);
|
||||
line-height: 1.6;
|
||||
}}
|
||||
.welcome-screen {{
|
||||
flex: 1;
|
||||
|
|
@ -191,105 +215,221 @@ async def share_chat_page(
|
|||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 40px 20px;
|
||||
padding: 48px 24px;
|
||||
background: linear-gradient(180deg, var(--bg-secondary) 0%, var(--bg-primary) 100%);
|
||||
}}
|
||||
.welcome-screen.hidden {{ display: none; }}
|
||||
.welcome-title {{
|
||||
font-size: 28px;
|
||||
font-weight: 600;
|
||||
color: var(--text-primary);
|
||||
margin-bottom: 32px;
|
||||
letter-spacing: -0.02em;
|
||||
}}
|
||||
.welcome-input-wrapper {{
|
||||
width: 100%;
|
||||
max-width: 680px;
|
||||
background: white;
|
||||
border-radius: 16px;
|
||||
padding: 16px 20px;
|
||||
box-shadow: 0 2px 12px rgba(0,0,0,0.08);
|
||||
max-width: 600px;
|
||||
background: var(--bg-secondary);
|
||||
border-radius: var(--radius-xl);
|
||||
padding: 8px;
|
||||
box-shadow: var(--shadow-lg);
|
||||
border: 1px solid var(--border);
|
||||
display: flex;
|
||||
align-items: flex-end;
|
||||
gap: 8px;
|
||||
transition: box-shadow 0.2s ease, border-color 0.2s ease;
|
||||
}}
|
||||
.welcome-input-wrapper:focus-within {{
|
||||
box-shadow: var(--shadow-lg), 0 0 0 3px var(--accent-light);
|
||||
border-color: var(--accent);
|
||||
}}
|
||||
.welcome-textarea, .input-textarea {{
|
||||
width: 100%;
|
||||
min-height: 56px;
|
||||
border: 1px solid #e5e5e5;
|
||||
border-radius: 12px;
|
||||
padding: 12px 16px;
|
||||
flex: 1;
|
||||
min-height: 48px;
|
||||
max-height: 200px;
|
||||
border: none;
|
||||
border-radius: var(--radius-lg);
|
||||
padding: 14px 16px;
|
||||
resize: none;
|
||||
outline: none;
|
||||
font-size: 15px;
|
||||
line-height: 1.6;
|
||||
line-height: 1.5;
|
||||
font-family: inherit;
|
||||
background: #fafafa;
|
||||
transition: all 0.2s;
|
||||
background: transparent;
|
||||
color: var(--text-primary);
|
||||
}}
|
||||
.welcome-textarea:focus, .input-textarea:focus {{
|
||||
border-color: #1677ff;
|
||||
background: white;
|
||||
.welcome-textarea::placeholder, .input-textarea::placeholder {{
|
||||
color: var(--text-muted);
|
||||
}}
|
||||
.chat-screen {{ flex: 1; display: none; flex-direction: column; }}
|
||||
.chat-screen {{ flex: 1; display: none; flex-direction: column; background: var(--bg-primary); }}
|
||||
.chat-screen.active {{ display: flex; }}
|
||||
.chat-list {{
|
||||
flex: 1;
|
||||
padding: 20px;
|
||||
max-width: 800px;
|
||||
padding: 24px;
|
||||
max-width: 720px;
|
||||
width: 100%;
|
||||
margin: 0 auto;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 16px;
|
||||
gap: 20px;
|
||||
overflow-y: auto;
|
||||
}}
|
||||
.bubble {{ display: flex; gap: 10px; align-items: flex-start; }}
|
||||
.bubble {{ display: flex; gap: 12px; align-items: flex-start; animation: fadeIn 0.3s ease; }}
|
||||
@keyframes fadeIn {{
|
||||
from {{ opacity: 0; transform: translateY(8px); }}
|
||||
to {{ opacity: 1; transform: translateY(0); }}
|
||||
}}
|
||||
.bubble.user {{ flex-direction: row-reverse; }}
|
||||
.avatar {{
|
||||
width: 32px; height: 32px; border-radius: 50%;
|
||||
background: white; display: flex; align-items: center; justify-content: center;
|
||||
width: 36px; height: 36px; border-radius: 50%;
|
||||
background: var(--bg-tertiary);
|
||||
display: flex; align-items: center; justify-content: center;
|
||||
flex-shrink: 0;
|
||||
overflow: hidden;
|
||||
}}
|
||||
.bubble-content {{ max-width: 75%; }}
|
||||
.avatar svg {{
|
||||
width: 22px;
|
||||
height: 22px;
|
||||
}}
|
||||
.bubble.user .avatar {{
|
||||
background: var(--accent-light);
|
||||
}}
|
||||
.bubble.user .avatar svg path {{
|
||||
fill: var(--accent);
|
||||
}}
|
||||
.bubble.bot .avatar svg path {{
|
||||
fill: var(--accent);
|
||||
}}
|
||||
.bubble-content {{ max-width: 80%; min-width: 0; }}
|
||||
.bubble-text {{
|
||||
padding: 12px 16px; border-radius: 16px; white-space: pre-wrap; word-break: break-word;
|
||||
font-size: 14px; line-height: 1.6;
|
||||
padding: 14px 18px;
|
||||
border-radius: var(--radius-lg);
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
font-size: 14px;
|
||||
line-height: 1.65;
|
||||
}}
|
||||
.bubble.user .bubble-text {{
|
||||
background: var(--accent);
|
||||
color: white;
|
||||
border-bottom-right-radius: var(--radius-sm);
|
||||
}}
|
||||
.bubble.bot .bubble-text {{
|
||||
background: var(--bg-secondary);
|
||||
color: var(--text-primary);
|
||||
border-bottom-left-radius: var(--radius-sm);
|
||||
box-shadow: var(--shadow-sm);
|
||||
}}
|
||||
.bubble.error .bubble-text {{
|
||||
background: #fef2f2;
|
||||
color: #dc2626;
|
||||
border: 1px solid #fecaca;
|
||||
}}
|
||||
.bubble.user .bubble-text {{ background: #1677ff; color: white; }}
|
||||
.bubble.bot .bubble-text {{ background: white; color: #333; }}
|
||||
.bubble.error .bubble-text {{ background: #fff2f0; color: #ff4d4f; border: 1px solid #ffccc7; }}
|
||||
.thought-block {{
|
||||
background: #f5f5f5;
|
||||
color: #888;
|
||||
padding: 12px 16px;
|
||||
border-radius: 12px;
|
||||
background: var(--bg-tertiary);
|
||||
color: var(--text-secondary);
|
||||
padding: 14px 18px;
|
||||
border-radius: var(--radius-md);
|
||||
margin-bottom: 12px;
|
||||
font-size: 13px;
|
||||
line-height: 1.6;
|
||||
border-left: 3px solid #ddd;
|
||||
line-height: 1.65;
|
||||
border-left: 3px solid var(--text-muted);
|
||||
}}
|
||||
.thought-label {{
|
||||
font-weight: 600;
|
||||
color: #999;
|
||||
margin-bottom: 6px;
|
||||
font-size: 12px;
|
||||
color: var(--text-muted);
|
||||
margin-bottom: 8px;
|
||||
font-size: 11px;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
}}
|
||||
.final-answer-block {{
|
||||
background: white;
|
||||
color: #333;
|
||||
padding: 12px 16px;
|
||||
border-radius: 12px;
|
||||
background: var(--bg-secondary);
|
||||
color: var(--text-primary);
|
||||
padding: 14px 18px;
|
||||
border-radius: var(--radius-md);
|
||||
font-size: 14px;
|
||||
line-height: 1.6;
|
||||
line-height: 1.65;
|
||||
}}
|
||||
.final-answer-label {{
|
||||
font-weight: 600;
|
||||
color: #1677ff;
|
||||
margin-bottom: 6px;
|
||||
font-size: 12px;
|
||||
color: var(--accent);
|
||||
margin-bottom: 8px;
|
||||
font-size: 11px;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
}}
|
||||
.input-area {{
|
||||
background: var(--bg-secondary);
|
||||
padding: 16px 24px 24px;
|
||||
border-top: 1px solid var(--border);
|
||||
}}
|
||||
.input-wrapper {{
|
||||
max-width: 720px;
|
||||
margin: 0 auto;
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
align-items: flex-end;
|
||||
background: var(--bg-tertiary);
|
||||
border-radius: var(--radius-lg);
|
||||
padding: 6px 6px 6px 16px;
|
||||
border: 1px solid var(--border);
|
||||
transition: border-color 0.2s ease, box-shadow 0.2s ease;
|
||||
}}
|
||||
.input-wrapper:focus-within {{
|
||||
border-color: var(--accent);
|
||||
box-shadow: 0 0 0 3px var(--accent-light);
|
||||
}}
|
||||
.input-textarea {{
|
||||
background: transparent;
|
||||
padding: 10px 0;
|
||||
}}
|
||||
.input-area {{ background: white; padding: 16px 20px 20px; border-top: 1px solid #eee; }}
|
||||
.input-wrapper {{ max-width: 800px; margin: 0 auto; display: flex; gap: 12px; align-items: flex-end; }}
|
||||
.send-btn, .welcome-send {{
|
||||
width: 40px; height: 40px; border-radius: 50%; border: none; cursor: pointer;
|
||||
background: #1677ff; color: white;
|
||||
width: 44px; height: 44px;
|
||||
border-radius: 50%;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
background: var(--accent);
|
||||
color: white;
|
||||
font-size: 18px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: background 0.2s ease, transform 0.15s ease;
|
||||
flex-shrink: 0;
|
||||
}}
|
||||
.send-btn:hover, .welcome-send:hover {{
|
||||
background: var(--accent-hover);
|
||||
transform: scale(1.05);
|
||||
}}
|
||||
.send-btn:active, .welcome-send:active {{
|
||||
transform: scale(0.95);
|
||||
}}
|
||||
.send-btn:disabled, .welcome-send:disabled {{
|
||||
background: var(--text-muted);
|
||||
cursor: not-allowed;
|
||||
transform: none;
|
||||
}}
|
||||
.status {{
|
||||
text-align: center;
|
||||
padding: 12px;
|
||||
font-size: 13px;
|
||||
color: var(--text-muted);
|
||||
font-weight: 500;
|
||||
}}
|
||||
.status.error {{ color: #dc2626; }}
|
||||
@media (max-width: 640px) {{
|
||||
.welcome-screen {{ padding: 32px 16px; }}
|
||||
.welcome-title {{ font-size: 22px; margin-bottom: 24px; }}
|
||||
.chat-list {{ padding: 16px; gap: 16px; }}
|
||||
.bubble-content {{ max-width: 85%; }}
|
||||
.input-area {{ padding: 12px 16px 20px; }}
|
||||
}}
|
||||
.status {{ text-align: center; padding: 8px; font-size: 12px; color: #999; }}
|
||||
.status.error {{ color: #ff4d4f; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="welcome-screen" id="welcomeScreen">
|
||||
<h1>今天有什么可以帮到你?</h1>
|
||||
<h1 class="welcome-title">今天有什么可以帮到你?</h1>
|
||||
<div class="welcome-input-wrapper">
|
||||
<textarea class="welcome-textarea" id="welcomeInput" placeholder="输入消息,按 Enter 发送" rows="1"></textarea>
|
||||
<button class="welcome-send" id="welcomeSendBtn">➤</button>
|
||||
|
|
@ -367,7 +507,12 @@ function formatBotMessage(text) {{
|
|||
function addMessage(role, text) {{
|
||||
const div = document.createElement('div');
|
||||
div.className = 'bubble ' + role;
|
||||
const avatar = role === 'user' ? '👤' : (role === 'bot' ? '🤖' : '⚠️');
|
||||
|
||||
const userSvg = '<svg viewBox="0 0 1024 1024" xmlns="http://www.w3.org/2000/svg"><path d="M573.9 516.2L512 640l-61.9-123.8C232 546.4 64 733.6 64 960h896c0-226.4-168-413.6-386.1-443.8zM480 384h64c17.7 0 32.1 14.4 32.1 32.1 0 17.7-14.4 32.1-32.1 32.1h-64c-11.9 0-22.3-6.5-27.8-16.1H356c34.9 48.5 91.7 80 156 80 106 0 192-86 192-192s-86-192-192-192-192 86-192 192c0 28.5 6.2 55.6 17.4 80h114.8c5.5-9.6 15.9-16.1 27.8-16.1z"/><path d="M272 432.1h84c-4.2-5.9-8.1-12-11.7-18.4-2.3-4.1-4.4-8.3-6.4-12.5-0.2-0.4-0.4-0.7-0.5-1.1H288c-8.8 0-16-7.2-16-16v-48.4c0-64.1 25-124.3 70.3-169.6S447.9 95.8 512 95.8s124.3 25 169.7 70.3c38.3 38.3 62.1 87.2 68.5 140.2-8.4 4-14.2 12.5-14.2 22.4v78.6c0 13.7 11.1 24.8 24.8 24.8h14.6c13.7 0 24.8-11.1 24.8-24.8v-78.6c0-11.3-7.6-20.9-18-23.8-6.9-60.9-33.9-117.4-78-161.3C652.9 92.1 584.6 63.9 512 63.9s-140.9 28.3-192.3 79.6C268.3 194.8 240 263.1 240 335.7v64.4c0 17.7 14.3 32 32 32z"/></svg>';
|
||||
const botSvg = '<svg viewBox="0 0 1024 1024" xmlns="http://www.w3.org/2000/svg"><path d="M894.1 355.6h-1.7C853 177.6 687.6 51.4 498.1 54.9S148.2 190.5 115.9 369.7c-35.2 5.6-61.1 36-61.1 71.7v143.4c0.9 40.4 34.3 72.5 74.7 71.7 21.7-0.3 42.2-10 56-26.7 33.6 84.5 99.9 152 183.8 187 1.1-2 2.3-3.9 3.7-5.7 0.9-1.5 2.4-2.6 4.1-3 1.3 0 2.5 0.5 3.6 1.2a318.46 318.46 0 0 1-105.3-187.1c-5.1-44.4 24.1-85.4 67.6-95.2 64.3-11.7 128.1-24.7 192.4-35.9 37.9-5.3 70.4-29.8 85.7-64.9 6.8-15.9 11-32.8 12.5-50 0.5-3.1 2.9-5.6 5.9-6.2 3.1-0.7 6.4 0.5 8.2 3l1.7-1.1c25.4 35.9 74.7 114.4 82.7 197.2 8.2 94.8 3.7 160-71.4 226.5-1.1 1.1-1.7 2.6-1.7 4.1 0.1 2 1.1 3.8 2.8 4.8h4.8l3.2-1.8c75.6-40.4 132.8-108.2 159.9-189.5 11.4 16.1 28.5 27.1 47.8 30.8C846 783.9 716.9 871.6 557.2 884.9c-12-28.6-42.5-44.8-72.9-38.6-33.6 5.4-56.6 37-51.2 70.6 4.4 27.6 26.8 48.8 54.5 51.6 30.6 4.6 60.3-13 70.8-42.2 184.9-14.5 333.2-120.8 364.2-286.9 27.8-10.8 46.3-37.4 46.6-67.2V428.7c-0.1-19.5-8.1-38.2-22.3-51.6-14.5-13.8-33.8-21.4-53.8-21.3l1-0.2zM825.9 397c-71.1-176.9-272.1-262.7-449-191.7-86.8 34.9-155.7 103.4-191 190-2.5-2.8-5.2-5.4-8-7.9 25.3-154.6 163.8-268.6 326.8-269.2s302.3 112.6 328.7 267c-2.9 3.8-5.4 7.7-7.5 11.8z"/></svg>';
|
||||
const errorSvg = '<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><circle cx="12" cy="12" r="10"/><line x1="12" y1="8" x2="12" y2="12"/><line x1="12" y1="16" x2="12.01" y2="16"/></svg>';
|
||||
|
||||
const avatar = role === 'user' ? userSvg : (role === 'bot' ? botSvg : errorSvg);
|
||||
|
||||
let contentHtml;
|
||||
if (role === 'bot') {{
|
||||
|
|
|
|||
|
|
@ -23,6 +23,9 @@ class Settings(BaseSettings):
|
|||
|
||||
log_level: str = "INFO"
|
||||
|
||||
kb_vector_log_enabled: bool = False
|
||||
kb_vector_log_path: str = "logs/kb_vector_payload.log"
|
||||
|
||||
llm_provider: str = "openai"
|
||||
llm_api_key: str = ""
|
||||
llm_base_url: str = "https://api.openai.com/v1"
|
||||
|
|
|
|||
|
|
@ -492,18 +492,46 @@ class QdrantClient:
|
|||
构建 Qdrant 过滤条件。
|
||||
|
||||
Args:
|
||||
metadata_filter: 元数据过滤条件,如 {"grade": "三年级", "subject": "语文"}
|
||||
metadata_filter: 元数据过滤条件,支持两种格式:
|
||||
- 简单值格式: {"grade": "三年级", "subject": "语文"}
|
||||
- 操作符格式: {"grade": {"$eq": "三年级"}, "kb_scene": {"$eq": "open_consult"}}
|
||||
|
||||
Returns:
|
||||
Qdrant Filter 对象
|
||||
"""
|
||||
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
||||
from qdrant_client.models import FieldCondition, Filter, MatchValue, MatchAny
|
||||
|
||||
must_conditions = []
|
||||
|
||||
for key, value in metadata_filter.items():
|
||||
# 支持嵌套 metadata 字段,如 metadata.grade
|
||||
field_path = f"metadata.{key}"
|
||||
|
||||
if isinstance(value, dict):
|
||||
op = list(value.keys())[0] if value else None
|
||||
actual_value = value.get(op) if op else None
|
||||
|
||||
if op == "$eq" and actual_value is not None:
|
||||
condition = FieldCondition(
|
||||
key=field_path,
|
||||
match=MatchValue(value=actual_value),
|
||||
)
|
||||
must_conditions.append(condition)
|
||||
elif op == "$in" and isinstance(actual_value, list):
|
||||
condition = FieldCondition(
|
||||
key=field_path,
|
||||
match=MatchAny(any=actual_value),
|
||||
)
|
||||
must_conditions.append(condition)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[AC-AISVC-16] Unsupported filter operator: {op}, using as direct value"
|
||||
)
|
||||
condition = FieldCondition(
|
||||
key=field_path,
|
||||
match=MatchValue(value=value),
|
||||
)
|
||||
must_conditions.append(condition)
|
||||
else:
|
||||
condition = FieldCondition(
|
||||
key=field_path,
|
||||
match=MatchValue(value=value),
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ Main FastAPI application for AI Service.
|
|||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
|
|
@ -52,15 +54,51 @@ from app.core.qdrant_client import close_qdrant_client
|
|||
|
||||
settings = get_settings()
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, settings.log_level.upper()),
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.dialects").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.orm").setLevel(logging.WARNING)
|
||||
def setup_logging():
|
||||
"""
|
||||
配置滚动日志文件。
|
||||
- 日志文件存储在 logs/ 目录
|
||||
- 单文件最大 2MB,超过则切分
|
||||
- 保留最近 7 天的日志(约 70 个备份文件)
|
||||
- 同时输出到控制台
|
||||
"""
|
||||
log_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
formatter = logging.Formatter(log_format)
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(getattr(logging, settings.log_level.upper()))
|
||||
|
||||
root_logger.handlers.clear()
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(getattr(logging, settings.log_level.upper()))
|
||||
console_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
log_file = os.path.join(log_dir, "ai-service.log")
|
||||
file_handler = RotatingFileHandler(
|
||||
filename=log_file,
|
||||
maxBytes=2 * 1024 * 1024,
|
||||
backupCount=70,
|
||||
encoding="utf-8",
|
||||
)
|
||||
file_handler.setLevel(getattr(logging, settings.log_level.upper()))
|
||||
file_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.dialects").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.orm").setLevel(logging.WARNING)
|
||||
|
||||
return log_dir
|
||||
|
||||
|
||||
setup_logging()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -116,6 +116,44 @@ class ChatMessageCreate(SQLModel):
|
|||
content: str
|
||||
|
||||
|
||||
class UserMemory(SQLModel, table=True):
|
||||
"""
|
||||
[AC-IDMP-14] 用户级记忆存储(滚动摘要)
|
||||
支持多租户隔离,存储最新 summary + facts/preferences/open_issues
|
||||
"""
|
||||
|
||||
__tablename__ = "user_memories"
|
||||
__table_args__ = (
|
||||
Index("ix_user_memories_tenant_user", "tenant_id", "user_id"),
|
||||
Index("ix_user_memories_tenant_user_updated", "tenant_id", "user_id", "updated_at"),
|
||||
)
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
|
||||
user_id: str = Field(..., description="User ID for memory storage", index=True)
|
||||
summary: str | None = Field(default=None, description="Rolling summary for user")
|
||||
facts: list[str] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("facts", JSON, nullable=True),
|
||||
description="Extracted stable facts list",
|
||||
)
|
||||
preferences: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("preferences", JSON, nullable=True),
|
||||
description="User preferences as structured JSON",
|
||||
)
|
||||
open_issues: list[str] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column("open_issues", JSON, nullable=True),
|
||||
description="Open issues list",
|
||||
)
|
||||
summary_version: int = Field(default=1, description="Summary version / update round")
|
||||
last_turn_id: str | None = Field(default=None, description="Last turn identifier (optional)")
|
||||
expires_at: datetime | None = Field(default=None, description="Expiration time (optional)")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
|
||||
|
||||
class SharedSession(SQLModel, table=True):
|
||||
"""
|
||||
[AC-IDMP-SHARE] Shared session entity for dialogue sharing.
|
||||
|
|
|
|||
|
|
@ -16,6 +16,16 @@ from app.services.document.factory import (
|
|||
get_supported_document_formats,
|
||||
parse_document,
|
||||
)
|
||||
from app.services.document.image_parser import ImageParser
|
||||
from app.services.document.markdown_chunker import (
|
||||
MarkdownChunk,
|
||||
MarkdownChunker,
|
||||
MarkdownElement,
|
||||
MarkdownElementType,
|
||||
MarkdownParser as MarkdownStructureParser,
|
||||
chunk_markdown,
|
||||
)
|
||||
from app.services.document.markdown_parser import MarkdownParser
|
||||
from app.services.document.pdf_parser import PDFParser, PDFPlumberParser
|
||||
from app.services.document.text_parser import TextParser
|
||||
from app.services.document.word_parser import WordParser
|
||||
|
|
@ -35,4 +45,12 @@ __all__ = [
|
|||
"ExcelParser",
|
||||
"CSVParser",
|
||||
"TextParser",
|
||||
"MarkdownParser",
|
||||
"MarkdownChunker",
|
||||
"MarkdownChunk",
|
||||
"MarkdownElement",
|
||||
"MarkdownElementType",
|
||||
"MarkdownStructureParser",
|
||||
"chunk_markdown",
|
||||
"ImageParser",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ from app.services.document.base import (
|
|||
UnsupportedFormatError,
|
||||
)
|
||||
from app.services.document.excel_parser import CSVParser, ExcelParser
|
||||
from app.services.document.image_parser import ImageParser
|
||||
from app.services.document.markdown_parser import MarkdownParser
|
||||
from app.services.document.pdf_parser import PDFParser, PDFPlumberParser
|
||||
from app.services.document.text_parser import TextParser
|
||||
from app.services.document.word_parser import WordParser
|
||||
|
|
@ -45,6 +47,8 @@ class DocumentParserFactory:
|
|||
"excel": ExcelParser,
|
||||
"csv": CSVParser,
|
||||
"text": TextParser,
|
||||
"markdown": MarkdownParser,
|
||||
"image": ImageParser,
|
||||
}
|
||||
|
||||
cls._extension_map = {
|
||||
|
|
@ -54,14 +58,22 @@ class DocumentParserFactory:
|
|||
".xls": "excel",
|
||||
".csv": "csv",
|
||||
".txt": "text",
|
||||
".md": "text",
|
||||
".markdown": "text",
|
||||
".md": "markdown",
|
||||
".markdown": "markdown",
|
||||
".rst": "text",
|
||||
".log": "text",
|
||||
".json": "text",
|
||||
".xml": "text",
|
||||
".yaml": "text",
|
||||
".yml": "text",
|
||||
".jpg": "image",
|
||||
".jpeg": "image",
|
||||
".png": "image",
|
||||
".gif": "image",
|
||||
".webp": "image",
|
||||
".bmp": "image",
|
||||
".tiff": "image",
|
||||
".tif": "image",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -174,6 +186,8 @@ class DocumentParserFactory:
|
|||
"excel": "Excel 电子表格",
|
||||
"csv": "CSV 文件",
|
||||
"text": "文本文件",
|
||||
"markdown": "Markdown 文档",
|
||||
"image": "图片文件",
|
||||
}
|
||||
|
||||
descriptions = {
|
||||
|
|
@ -183,6 +197,8 @@ class DocumentParserFactory:
|
|||
"excel": "解析 Excel 电子表格,支持多工作表",
|
||||
"csv": "解析 CSV 文件,自动检测编码",
|
||||
"text": "解析纯文本文件,支持多种编码",
|
||||
"markdown": "智能解析 Markdown 文档,保留结构(标题、代码块、表格、列表)",
|
||||
"image": "使用多模态 LLM 解析图片,提取文字和关键信息",
|
||||
}
|
||||
|
||||
info.append({
|
||||
|
|
|
|||
|
|
@ -0,0 +1,490 @@
|
|||
"""
|
||||
Image parser using multimodal LLM.
|
||||
Supports parsing images into structured text content for knowledge base indexing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.services.document.base import (
|
||||
DocumentParseException,
|
||||
DocumentParser,
|
||||
PageText,
|
||||
ParseResult,
|
||||
)
|
||||
from app.services.llm.factory import LLMUsageType, get_llm_config_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IMAGE_SYSTEM_PROMPT = """你是一个专业的图像内容分析助手。你的任务是分析图片内容,并将其智能拆分为适合知识库检索的独立数据块。
|
||||
|
||||
## 分析要求
|
||||
1. 仔细分析图片内容,识别其中的文字、图表、数据等信息
|
||||
2. 根据内容的逻辑结构,智能判断如何拆分为独立的知识条目
|
||||
3. 每个条目应该是独立、完整、可检索的知识单元
|
||||
|
||||
## 输出格式
|
||||
请严格按照以下 JSON 格式输出,不要添加任何其他内容:
|
||||
|
||||
```json
|
||||
{
|
||||
"image_summary": "图片整体概述(一句话描述图片主题)",
|
||||
"total_chunks": <分块总数>,
|
||||
"chunks": [
|
||||
{
|
||||
"chunk_index": 0,
|
||||
"content": "该分块的完整内容文字",
|
||||
"chunk_type": "text|table|list|diagram|chart|mixed",
|
||||
"keywords": ["关键词1", "关键词2"]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## 分块策略
|
||||
- **单一内容**: 如果图片只有一段完整的文字/信息,可以只输出1个分块
|
||||
- **多段落内容**: 按段落或逻辑单元拆分,每个段落作为独立分块
|
||||
- **表格数据**: 将表格内容转换为结构化文字,作为一个分块
|
||||
- **图表数据**: 描述图表内容和数据,作为一个分块
|
||||
- **列表内容**: 每个列表项可作为独立分块,或合并为相关的一组
|
||||
- **混合内容**: 根据内容类型分别处理,确保每个分块主题明确
|
||||
|
||||
## 注意事项
|
||||
1. 每个分块的 content 必须是完整、可独立理解的文字
|
||||
2. chunk_type 用于标识内容类型,便于后续处理
|
||||
3. keywords 提取该分块的核心关键词,便于检索
|
||||
4. 确保输出的 JSON 格式正确,可以被解析"""
|
||||
|
||||
IMAGE_USER_PROMPT = "请分析这张图片,按照要求的 JSON 格式输出分块结果。"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageChunk:
|
||||
"""智能分块结果"""
|
||||
chunk_index: int
|
||||
content: str
|
||||
chunk_type: str = "text"
|
||||
keywords: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageParseResult:
|
||||
"""图片解析结果(包含智能分块)"""
|
||||
image_summary: str
|
||||
chunks: list[ImageChunk]
|
||||
raw_text: str
|
||||
source_path: str
|
||||
file_size: int
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ImageParser(DocumentParser):
|
||||
"""
|
||||
Image parser using multimodal LLM.
|
||||
|
||||
Supports common image formats and extracts text content using
|
||||
vision-capable LLM models (GPT-4V, GPT-4o, etc.).
|
||||
|
||||
Features:
|
||||
- Intelligent chunking based on content structure
|
||||
- Structured output with keywords and chunk types
|
||||
- Support for various content types (text, table, chart, etc.)
|
||||
"""
|
||||
|
||||
SUPPORTED_EXTENSIONS = [
|
||||
".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
timeout_seconds: int = 120,
|
||||
):
|
||||
self._model = model
|
||||
self._max_tokens = max_tokens
|
||||
self._timeout_seconds = timeout_seconds
|
||||
|
||||
def parse(self, file_path: str | Path) -> ParseResult:
|
||||
"""
|
||||
Parse an image file and extract text content using multimodal LLM.
|
||||
|
||||
Note: This method is synchronous but internally uses async operations.
|
||||
For async contexts, use parse_async() instead.
|
||||
|
||||
Args:
|
||||
file_path: Path to the image file.
|
||||
|
||||
Returns:
|
||||
ParseResult with extracted text content.
|
||||
|
||||
Raises:
|
||||
DocumentParseException: If parsing fails.
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise DocumentParseException(
|
||||
f"Image file not found: {file_path}",
|
||||
file_path=str(path),
|
||||
parser="image",
|
||||
)
|
||||
|
||||
file_size = path.stat().st_size
|
||||
extension = path.suffix.lower()
|
||||
|
||||
if extension not in self.SUPPORTED_EXTENSIONS:
|
||||
raise DocumentParseException(
|
||||
f"Unsupported image format: {extension}",
|
||||
file_path=str(path),
|
||||
parser="image",
|
||||
details={"supported_formats": self.SUPPORTED_EXTENSIONS},
|
||||
)
|
||||
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
image_data = f.read()
|
||||
|
||||
image_base64 = base64.b64encode(image_data).decode("utf-8")
|
||||
mime_type = self._get_mime_type(extension)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run,
|
||||
self._analyze_image_async(image_base64, mime_type)
|
||||
)
|
||||
result = future.result()
|
||||
except RuntimeError:
|
||||
result = asyncio.run(self._analyze_image_async(image_base64, mime_type))
|
||||
|
||||
logger.info(
|
||||
f"[IMAGE-PARSER] Successfully parsed image: {path.name}, "
|
||||
f"size={file_size}, chunks={len(result.chunks)}"
|
||||
)
|
||||
|
||||
return ParseResult(
|
||||
text=result.raw_text,
|
||||
source_path=str(path),
|
||||
file_size=file_size,
|
||||
page_count=1,
|
||||
metadata={
|
||||
"format": extension,
|
||||
"parser": "image",
|
||||
"mime_type": mime_type,
|
||||
"image_summary": result.image_summary,
|
||||
"chunk_count": len(result.chunks),
|
||||
"chunks": [
|
||||
{
|
||||
"chunk_index": c.chunk_index,
|
||||
"content": c.content,
|
||||
"chunk_type": c.chunk_type,
|
||||
"keywords": c.keywords,
|
||||
}
|
||||
for c in result.chunks
|
||||
],
|
||||
},
|
||||
pages=[PageText(page=1, text=result.raw_text)],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[IMAGE-PARSER] Failed to parse image {path}: {e}")
|
||||
raise DocumentParseException(
|
||||
f"Failed to parse image: {str(e)}",
|
||||
file_path=str(path),
|
||||
parser="image",
|
||||
details={"error": str(e)},
|
||||
)
|
||||
|
||||
async def parse_async(self, file_path: str | Path) -> ParseResult:
|
||||
"""
|
||||
Async version of parse method for use in async contexts.
|
||||
|
||||
Args:
|
||||
file_path: Path to the image file.
|
||||
|
||||
Returns:
|
||||
ParseResult with extracted text content.
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise DocumentParseException(
|
||||
f"Image file not found: {file_path}",
|
||||
file_path=str(path),
|
||||
parser="image",
|
||||
)
|
||||
|
||||
file_size = path.stat().st_size
|
||||
extension = path.suffix.lower()
|
||||
|
||||
if extension not in self.SUPPORTED_EXTENSIONS:
|
||||
raise DocumentParseException(
|
||||
f"Unsupported image format: {extension}",
|
||||
file_path=str(path),
|
||||
parser="image",
|
||||
details={"supported_formats": self.SUPPORTED_EXTENSIONS},
|
||||
)
|
||||
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
image_data = f.read()
|
||||
|
||||
image_base64 = base64.b64encode(image_data).decode("utf-8")
|
||||
mime_type = self._get_mime_type(extension)
|
||||
|
||||
result = await self._analyze_image_async(image_base64, mime_type)
|
||||
|
||||
logger.info(
|
||||
f"[IMAGE-PARSER] Successfully parsed image (async): {path.name}, "
|
||||
f"size={file_size}, chunks={len(result.chunks)}"
|
||||
)
|
||||
|
||||
return ParseResult(
|
||||
text=result.raw_text,
|
||||
source_path=str(path),
|
||||
file_size=file_size,
|
||||
page_count=1,
|
||||
metadata={
|
||||
"format": extension,
|
||||
"parser": "image",
|
||||
"mime_type": mime_type,
|
||||
"image_summary": result.image_summary,
|
||||
"chunk_count": len(result.chunks),
|
||||
"chunks": [
|
||||
{
|
||||
"chunk_index": c.chunk_index,
|
||||
"content": c.content,
|
||||
"chunk_type": c.chunk_type,
|
||||
"keywords": c.keywords,
|
||||
}
|
||||
for c in result.chunks
|
||||
],
|
||||
},
|
||||
pages=[PageText(page=1, text=result.raw_text)],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[IMAGE-PARSER] Failed to parse image {path}: {e}")
|
||||
raise DocumentParseException(
|
||||
f"Failed to parse image: {str(e)}",
|
||||
file_path=str(path),
|
||||
parser="image",
|
||||
details={"error": str(e)},
|
||||
)
|
||||
|
||||
async def parse_with_chunks(self, file_path: str | Path) -> ImageParseResult:
|
||||
"""
|
||||
Parse image and return structured result with intelligent chunks.
|
||||
|
||||
Args:
|
||||
file_path: Path to the image file.
|
||||
|
||||
Returns:
|
||||
ImageParseResult with intelligent chunks.
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise DocumentParseException(
|
||||
f"Image file not found: {file_path}",
|
||||
file_path=str(path),
|
||||
parser="image",
|
||||
)
|
||||
|
||||
file_size = path.stat().st_size
|
||||
extension = path.suffix.lower()
|
||||
|
||||
if extension not in self.SUPPORTED_EXTENSIONS:
|
||||
raise DocumentParseException(
|
||||
f"Unsupported image format: {extension}",
|
||||
file_path=str(path),
|
||||
parser="image",
|
||||
details={"supported_formats": self.SUPPORTED_EXTENSIONS},
|
||||
)
|
||||
|
||||
with open(path, "rb") as f:
|
||||
image_data = f.read()
|
||||
|
||||
image_base64 = base64.b64encode(image_data).decode("utf-8")
|
||||
mime_type = self._get_mime_type(extension)
|
||||
|
||||
result = await self._analyze_image_async(image_base64, mime_type)
|
||||
result.source_path = str(path)
|
||||
result.file_size = file_size
|
||||
result.metadata = {
|
||||
"format": extension,
|
||||
"parser": "image",
|
||||
"mime_type": mime_type,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
async def _analyze_image_async(self, image_base64: str, mime_type: str) -> ImageParseResult:
|
||||
"""
|
||||
Analyze image using multimodal LLM and return structured chunks.
|
||||
|
||||
Args:
|
||||
image_base64: Base64 encoded image data.
|
||||
mime_type: MIME type of the image.
|
||||
|
||||
Returns:
|
||||
ImageParseResult with intelligent chunks.
|
||||
"""
|
||||
try:
|
||||
manager = get_llm_config_manager()
|
||||
client = manager.get_kb_processing_client()
|
||||
|
||||
config = manager.kb_processing_config
|
||||
model = self._model or config.get("model", "gpt-4o-mini")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": IMAGE_SYSTEM_PROMPT,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": IMAGE_USER_PROMPT,
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{mime_type};base64,{image_base64}",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
from app.services.llm.base import LLMConfig
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model=model,
|
||||
max_tokens=self._max_tokens,
|
||||
temperature=0.3,
|
||||
timeout_seconds=self._timeout_seconds,
|
||||
)
|
||||
|
||||
response = await client.generate(messages=messages, config=llm_config)
|
||||
|
||||
if not response.content:
|
||||
raise DocumentParseException(
|
||||
"LLM returned empty response for image analysis",
|
||||
parser="image",
|
||||
)
|
||||
|
||||
return self._parse_llm_response(response.content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[IMAGE-PARSER] LLM analysis failed: {e}")
|
||||
raise
|
||||
|
||||
def _parse_llm_response(self, response_content: str) -> ImageParseResult:
|
||||
"""
|
||||
Parse LLM response into structured ImageParseResult.
|
||||
|
||||
Args:
|
||||
response_content: Raw LLM response content.
|
||||
|
||||
Returns:
|
||||
ImageParseResult with parsed chunks.
|
||||
"""
|
||||
try:
|
||||
json_str = self._extract_json(response_content)
|
||||
data = json.loads(json_str)
|
||||
|
||||
image_summary = data.get("image_summary", "")
|
||||
chunks_data = data.get("chunks", [])
|
||||
|
||||
chunks = []
|
||||
for chunk_data in chunks_data:
|
||||
chunk = ImageChunk(
|
||||
chunk_index=chunk_data.get("chunk_index", len(chunks)),
|
||||
content=chunk_data.get("content", ""),
|
||||
chunk_type=chunk_data.get("chunk_type", "text"),
|
||||
keywords=chunk_data.get("keywords", []),
|
||||
)
|
||||
if chunk.content.strip():
|
||||
chunks.append(chunk)
|
||||
|
||||
if not chunks:
|
||||
chunks.append(ImageChunk(
|
||||
chunk_index=0,
|
||||
content=response_content,
|
||||
chunk_type="text",
|
||||
keywords=[],
|
||||
))
|
||||
|
||||
raw_text = "\n\n".join([c.content for c in chunks])
|
||||
|
||||
return ImageParseResult(
|
||||
image_summary=image_summary,
|
||||
chunks=chunks,
|
||||
raw_text=raw_text,
|
||||
source_path="",
|
||||
file_size=0,
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"[IMAGE-PARSER] Failed to parse JSON response: {e}, using fallback")
|
||||
return ImageParseResult(
|
||||
image_summary="图片内容",
|
||||
chunks=[ImageChunk(
|
||||
chunk_index=0,
|
||||
content=response_content,
|
||||
chunk_type="text",
|
||||
keywords=[],
|
||||
)],
|
||||
raw_text=response_content,
|
||||
source_path="",
|
||||
file_size=0,
|
||||
)
|
||||
|
||||
def _extract_json(self, content: str) -> str:
|
||||
"""
|
||||
Extract JSON from LLM response content.
|
||||
|
||||
Args:
|
||||
content: Raw response content that may contain JSON.
|
||||
|
||||
Returns:
|
||||
Extracted JSON string.
|
||||
"""
|
||||
content = content.strip()
|
||||
|
||||
if content.startswith("{") and content.endswith("}"):
|
||||
return content
|
||||
|
||||
json_start = content.find("{")
|
||||
json_end = content.rfind("}")
|
||||
|
||||
if json_start != -1 and json_end != -1 and json_end > json_start:
|
||||
return content[json_start:json_end + 1]
|
||||
|
||||
return content
|
||||
|
||||
def _get_mime_type(self, extension: str) -> str:
|
||||
"""Get MIME type for image extension."""
|
||||
mime_types = {
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".png": "image/png",
|
||||
".gif": "image/gif",
|
||||
".webp": "image/webp",
|
||||
".bmp": "image/bmp",
|
||||
".tiff": "image/tiff",
|
||||
".tif": "image/tiff",
|
||||
}
|
||||
return mime_types.get(extension.lower(), "image/jpeg")
|
||||
|
||||
def get_supported_extensions(self) -> list[str]:
|
||||
"""Get list of supported image extensions."""
|
||||
return ImageParser.SUPPORTED_EXTENSIONS
|
||||
|
|
@ -0,0 +1,771 @@
|
|||
"""
|
||||
Markdown intelligent chunker with structure-aware splitting.
|
||||
Supports headers, code blocks, tables, lists, and preserves context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarkdownElementType(Enum):
|
||||
"""Types of Markdown elements."""
|
||||
HEADER = "header"
|
||||
PARAGRAPH = "paragraph"
|
||||
CODE_BLOCK = "code_block"
|
||||
INLINE_CODE = "inline_code"
|
||||
TABLE = "table"
|
||||
LIST = "list"
|
||||
BLOCKQUOTE = "blockquote"
|
||||
HORIZONTAL_RULE = "horizontal_rule"
|
||||
IMAGE = "image"
|
||||
LINK = "link"
|
||||
TEXT = "text"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarkdownElement:
|
||||
"""Represents a parsed Markdown element."""
|
||||
type: MarkdownElementType
|
||||
content: str
|
||||
level: int = 0
|
||||
language: str = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
line_start: int = 0
|
||||
line_end: int = 0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": self.type.value,
|
||||
"content": self.content,
|
||||
"level": self.level,
|
||||
"language": self.language,
|
||||
"metadata": self.metadata,
|
||||
"line_start": self.line_start,
|
||||
"line_end": self.line_end,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarkdownChunk:
|
||||
"""Represents a chunk of Markdown content with context."""
|
||||
chunk_id: str
|
||||
content: str
|
||||
element_type: MarkdownElementType
|
||||
header_context: list[str]
|
||||
level: int = 0
|
||||
language: str = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"chunk_id": self.chunk_id,
|
||||
"content": self.content,
|
||||
"element_type": self.element_type.value,
|
||||
"header_context": self.header_context,
|
||||
"level": self.level,
|
||||
"language": self.language,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
class MarkdownParser:
|
||||
"""
|
||||
Parser for Markdown documents.
|
||||
Extracts structured elements from Markdown text.
|
||||
"""
|
||||
|
||||
HEADER_PATTERN = re.compile(r'^(#{1,6})\s+(.+?)(?:\s+#+)?$', re.MULTILINE)
|
||||
CODE_BLOCK_PATTERN = re.compile(r'^```(\w*)\n(.*?)^```', re.MULTILINE | re.DOTALL)
|
||||
TABLE_PATTERN = re.compile(r'^(\|.+\|)\n(\|[-:\s|]+\|)\n((?:\|.+\|\n?)+)', re.MULTILINE)
|
||||
LIST_PATTERN = re.compile(r'^([ \t]*[-*+]|\d+\.)\s+(.+)$', re.MULTILINE)
|
||||
BLOCKQUOTE_PATTERN = re.compile(r'^>\s*(.+)$', re.MULTILINE)
|
||||
HR_PATTERN = re.compile(r'^[-*_]{3,}\s*$', re.MULTILINE)
|
||||
IMAGE_PATTERN = re.compile(r'!\[([^\]]*)\]\(([^)]+)\)')
|
||||
LINK_PATTERN = re.compile(r'\[([^\]]+)\]\(([^)]+)\)')
|
||||
INLINE_CODE_PATTERN = re.compile(r'`([^`]+)`')
|
||||
|
||||
def parse(self, text: str) -> list[MarkdownElement]:
|
||||
"""
|
||||
Parse Markdown text into structured elements.
|
||||
|
||||
Args:
|
||||
text: Raw Markdown text
|
||||
|
||||
Returns:
|
||||
List of MarkdownElement objects
|
||||
"""
|
||||
elements = []
|
||||
lines = text.split('\n')
|
||||
current_pos = 0
|
||||
|
||||
code_block_ranges = self._extract_code_blocks(text, lines, elements)
|
||||
table_ranges = self._extract_tables(text, lines, elements)
|
||||
protected_ranges = code_block_ranges + table_ranges
|
||||
|
||||
self._extract_headers(lines, elements, protected_ranges)
|
||||
self._extract_lists(lines, elements, protected_ranges)
|
||||
self._extract_blockquotes(lines, elements, protected_ranges)
|
||||
self._extract_horizontal_rules(lines, elements, protected_ranges)
|
||||
|
||||
self._fill_paragraphs(lines, elements, protected_ranges)
|
||||
|
||||
elements.sort(key=lambda e: e.line_start)
|
||||
|
||||
return elements
|
||||
|
||||
def _extract_code_blocks(
|
||||
self,
|
||||
text: str,
|
||||
lines: list[str],
|
||||
elements: list[MarkdownElement],
|
||||
) -> list[tuple[int, int]]:
|
||||
"""Extract code blocks with language info."""
|
||||
ranges = []
|
||||
in_code_block = False
|
||||
code_start = 0
|
||||
language = ""
|
||||
code_content = []
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip().startswith('```'):
|
||||
if not in_code_block:
|
||||
in_code_block = True
|
||||
code_start = i
|
||||
language = line.strip()[3:].strip()
|
||||
code_content = []
|
||||
else:
|
||||
in_code_block = False
|
||||
elements.append(MarkdownElement(
|
||||
type=MarkdownElementType.CODE_BLOCK,
|
||||
content='\n'.join(code_content),
|
||||
language=language,
|
||||
line_start=code_start,
|
||||
line_end=i,
|
||||
metadata={"language": language},
|
||||
))
|
||||
ranges.append((code_start, i))
|
||||
|
||||
elif in_code_block:
|
||||
code_content.append(line)
|
||||
|
||||
return ranges
|
||||
|
||||
def _extract_tables(
|
||||
self,
|
||||
text: str,
|
||||
lines: list[str],
|
||||
elements: list[MarkdownElement],
|
||||
) -> list[tuple[int, int]]:
|
||||
"""Extract Markdown tables."""
|
||||
ranges = []
|
||||
i = 0
|
||||
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
|
||||
if '|' in line and i + 1 < len(lines):
|
||||
next_line = lines[i + 1]
|
||||
if '|' in next_line and re.match(r'^[\|\-\:\s]+$', next_line.strip()):
|
||||
table_lines = [line, next_line]
|
||||
j = i + 2
|
||||
|
||||
while j < len(lines) and '|' in lines[j]:
|
||||
table_lines.append(lines[j])
|
||||
j += 1
|
||||
|
||||
table_content = '\n'.join(table_lines)
|
||||
headers = [h.strip() for h in line.split('|') if h.strip()]
|
||||
row_count = len(table_lines) - 2
|
||||
|
||||
elements.append(MarkdownElement(
|
||||
type=MarkdownElementType.TABLE,
|
||||
content=table_content,
|
||||
line_start=i,
|
||||
line_end=j - 1,
|
||||
metadata={
|
||||
"headers": headers,
|
||||
"row_count": row_count,
|
||||
},
|
||||
))
|
||||
ranges.append((i, j - 1))
|
||||
i = j
|
||||
continue
|
||||
|
||||
i += 1
|
||||
|
||||
return ranges
|
||||
|
||||
def _is_in_protected_range(self, line_num: int, ranges: list[tuple[int, int]]) -> bool:
|
||||
"""Check if a line is within a protected range."""
|
||||
for start, end in ranges:
|
||||
if start <= line_num <= end:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_headers(
|
||||
self,
|
||||
lines: list[str],
|
||||
elements: list[MarkdownElement],
|
||||
protected_ranges: list[tuple[int, int]],
|
||||
) -> None:
|
||||
"""Extract headers with level info."""
|
||||
for i, line in enumerate(lines):
|
||||
if self._is_in_protected_range(i, protected_ranges):
|
||||
continue
|
||||
|
||||
match = self.HEADER_PATTERN.match(line)
|
||||
if match:
|
||||
level = len(match.group(1))
|
||||
title = match.group(2).strip()
|
||||
|
||||
elements.append(MarkdownElement(
|
||||
type=MarkdownElementType.HEADER,
|
||||
content=title,
|
||||
level=level,
|
||||
line_start=i,
|
||||
line_end=i,
|
||||
metadata={"level": level},
|
||||
))
|
||||
|
||||
def _extract_lists(
|
||||
self,
|
||||
lines: list[str],
|
||||
elements: list[MarkdownElement],
|
||||
protected_ranges: list[tuple[int, int]],
|
||||
) -> None:
|
||||
"""Extract list items."""
|
||||
in_list = False
|
||||
list_start = 0
|
||||
list_items = []
|
||||
list_indent = 0
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if self._is_in_protected_range(i, protected_ranges):
|
||||
if in_list:
|
||||
self._save_list(elements, list_start, i - 1, list_items)
|
||||
in_list = False
|
||||
list_items = []
|
||||
continue
|
||||
|
||||
match = self.LIST_PATTERN.match(line)
|
||||
if match:
|
||||
indent = len(line) - len(line.lstrip())
|
||||
item_content = match.group(2)
|
||||
|
||||
if not in_list:
|
||||
in_list = True
|
||||
list_start = i
|
||||
list_indent = indent
|
||||
list_items = [(indent, item_content)]
|
||||
else:
|
||||
list_items.append((indent, item_content))
|
||||
else:
|
||||
if in_list:
|
||||
if line.strip() == '':
|
||||
continue
|
||||
else:
|
||||
self._save_list(elements, list_start, i - 1, list_items)
|
||||
in_list = False
|
||||
list_items = []
|
||||
|
||||
if in_list:
|
||||
self._save_list(elements, list_start, len(lines) - 1, list_items)
|
||||
|
||||
def _save_list(
|
||||
self,
|
||||
elements: list[MarkdownElement],
|
||||
start: int,
|
||||
end: int,
|
||||
items: list[tuple[int, str]],
|
||||
) -> None:
|
||||
"""Save a list element."""
|
||||
if not items:
|
||||
return
|
||||
|
||||
content = '\n'.join([item[1] for item in items])
|
||||
elements.append(MarkdownElement(
|
||||
type=MarkdownElementType.LIST,
|
||||
content=content,
|
||||
line_start=start,
|
||||
line_end=end,
|
||||
metadata={
|
||||
"item_count": len(items),
|
||||
"is_ordered": False,
|
||||
},
|
||||
))
|
||||
|
||||
def _extract_blockquotes(
|
||||
self,
|
||||
lines: list[str],
|
||||
elements: list[MarkdownElement],
|
||||
protected_ranges: list[tuple[int, int]],
|
||||
) -> None:
|
||||
"""Extract blockquotes."""
|
||||
in_quote = False
|
||||
quote_start = 0
|
||||
quote_lines = []
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if self._is_in_protected_range(i, protected_ranges):
|
||||
if in_quote:
|
||||
self._save_blockquote(elements, quote_start, i - 1, quote_lines)
|
||||
in_quote = False
|
||||
quote_lines = []
|
||||
continue
|
||||
|
||||
match = self.BLOCKQUOTE_PATTERN.match(line)
|
||||
if match:
|
||||
if not in_quote:
|
||||
in_quote = True
|
||||
quote_start = i
|
||||
quote_lines.append(match.group(1))
|
||||
else:
|
||||
if in_quote:
|
||||
self._save_blockquote(elements, quote_start, i - 1, quote_lines)
|
||||
in_quote = False
|
||||
quote_lines = []
|
||||
|
||||
if in_quote:
|
||||
self._save_blockquote(elements, quote_start, len(lines) - 1, quote_lines)
|
||||
|
||||
def _save_blockquote(
|
||||
self,
|
||||
elements: list[MarkdownElement],
|
||||
start: int,
|
||||
end: int,
|
||||
lines: list[str],
|
||||
) -> None:
|
||||
"""Save a blockquote element."""
|
||||
if not lines:
|
||||
return
|
||||
|
||||
elements.append(MarkdownElement(
|
||||
type=MarkdownElementType.BLOCKQUOTE,
|
||||
content='\n'.join(lines),
|
||||
line_start=start,
|
||||
line_end=end,
|
||||
))
|
||||
|
||||
def _extract_horizontal_rules(
|
||||
self,
|
||||
lines: list[str],
|
||||
elements: list[MarkdownElement],
|
||||
protected_ranges: list[tuple[int, int]],
|
||||
) -> None:
|
||||
"""Extract horizontal rules."""
|
||||
for i, line in enumerate(lines):
|
||||
if self._is_in_protected_range(i, protected_ranges):
|
||||
continue
|
||||
|
||||
if self.HR_PATTERN.match(line):
|
||||
elements.append(MarkdownElement(
|
||||
type=MarkdownElementType.HORIZONTAL_RULE,
|
||||
content=line,
|
||||
line_start=i,
|
||||
line_end=i,
|
||||
))
|
||||
|
||||
def _fill_paragraphs(
|
||||
self,
|
||||
lines: list[str],
|
||||
elements: list[MarkdownElement],
|
||||
protected_ranges: list[tuple[int, int]],
|
||||
) -> None:
|
||||
"""Fill in paragraphs for remaining content."""
|
||||
occupied = set()
|
||||
for start, end in protected_ranges:
|
||||
for i in range(start, end + 1):
|
||||
occupied.add(i)
|
||||
|
||||
for elem in elements:
|
||||
for i in range(elem.line_start, elem.line_end + 1):
|
||||
occupied.add(i)
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
if i in occupied:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if lines[i].strip() == '':
|
||||
i += 1
|
||||
continue
|
||||
|
||||
para_start = i
|
||||
para_lines = []
|
||||
|
||||
while i < len(lines) and i not in occupied and lines[i].strip() != '':
|
||||
para_lines.append(lines[i])
|
||||
occupied.add(i)
|
||||
i += 1
|
||||
|
||||
if para_lines:
|
||||
elements.append(MarkdownElement(
|
||||
type=MarkdownElementType.PARAGRAPH,
|
||||
content='\n'.join(para_lines),
|
||||
line_start=para_start,
|
||||
line_end=i - 1,
|
||||
))
|
||||
|
||||
|
||||
class MarkdownChunker:
|
||||
"""
|
||||
Intelligent chunker for Markdown documents.
|
||||
|
||||
Features:
|
||||
- Structure-aware splitting (headers, code blocks, tables, lists)
|
||||
- Context preservation (header hierarchy)
|
||||
- Configurable chunk size and overlap
|
||||
- Metadata extraction
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_chunk_size: int = 1000,
|
||||
min_chunk_size: int = 100,
|
||||
chunk_overlap: int = 50,
|
||||
preserve_code_blocks: bool = True,
|
||||
preserve_tables: bool = True,
|
||||
preserve_lists: bool = True,
|
||||
include_header_context: bool = True,
|
||||
):
|
||||
self._max_chunk_size = max_chunk_size
|
||||
self._min_chunk_size = min_chunk_size
|
||||
self._chunk_overlap = chunk_overlap
|
||||
self._preserve_code_blocks = preserve_code_blocks
|
||||
self._preserve_tables = preserve_tables
|
||||
self._preserve_lists = preserve_lists
|
||||
self._include_header_context = include_header_context
|
||||
self._parser = MarkdownParser()
|
||||
|
||||
def chunk(self, text: str, doc_id: str = "") -> list[MarkdownChunk]:
|
||||
"""
|
||||
Chunk Markdown text into structured segments.
|
||||
|
||||
Args:
|
||||
text: Raw Markdown text
|
||||
doc_id: Optional document ID for chunk IDs
|
||||
|
||||
Returns:
|
||||
List of MarkdownChunk objects
|
||||
"""
|
||||
elements = self._parser.parse(text)
|
||||
chunks = []
|
||||
header_stack: list[str] = []
|
||||
chunk_index = 0
|
||||
|
||||
for elem in elements:
|
||||
if elem.type == MarkdownElementType.HEADER:
|
||||
level = elem.level
|
||||
while len(header_stack) >= level:
|
||||
if header_stack:
|
||||
header_stack.pop()
|
||||
header_stack.append(elem.content)
|
||||
continue
|
||||
|
||||
if elem.type == MarkdownElementType.HORIZONTAL_RULE:
|
||||
continue
|
||||
|
||||
chunk_content = self._format_element_content(elem)
|
||||
if not chunk_content:
|
||||
continue
|
||||
|
||||
chunk_id = f"{doc_id}_chunk_{chunk_index}" if doc_id else f"chunk_{chunk_index}"
|
||||
|
||||
header_context = []
|
||||
if self._include_header_context:
|
||||
header_context = header_stack.copy()
|
||||
|
||||
if len(chunk_content) > self._max_chunk_size:
|
||||
sub_chunks = self._split_large_element(
|
||||
elem,
|
||||
chunk_id,
|
||||
header_context,
|
||||
chunk_index,
|
||||
)
|
||||
chunks.extend(sub_chunks)
|
||||
chunk_index += len(sub_chunks)
|
||||
else:
|
||||
chunks.append(MarkdownChunk(
|
||||
chunk_id=chunk_id,
|
||||
content=chunk_content,
|
||||
element_type=elem.type,
|
||||
header_context=header_context,
|
||||
level=elem.level,
|
||||
language=elem.language,
|
||||
metadata=elem.metadata,
|
||||
))
|
||||
chunk_index += 1
|
||||
|
||||
return chunks
|
||||
|
||||
def _format_element_content(self, elem: MarkdownElement) -> str:
|
||||
"""Format element content based on type."""
|
||||
if elem.type == MarkdownElementType.CODE_BLOCK:
|
||||
lang = elem.language or ""
|
||||
return f"```{lang}\n{elem.content}\n```"
|
||||
|
||||
elif elem.type == MarkdownElementType.TABLE:
|
||||
return elem.content
|
||||
|
||||
elif elem.type == MarkdownElementType.LIST:
|
||||
return elem.content
|
||||
|
||||
elif elem.type == MarkdownElementType.BLOCKQUOTE:
|
||||
lines = elem.content.split('\n')
|
||||
return '\n'.join([f"> {line}" for line in lines])
|
||||
|
||||
elif elem.type == MarkdownElementType.PARAGRAPH:
|
||||
return elem.content
|
||||
|
||||
return elem.content
|
||||
|
||||
def _split_large_element(
|
||||
self,
|
||||
elem: MarkdownElement,
|
||||
base_id: str,
|
||||
header_context: list[str],
|
||||
start_index: int,
|
||||
) -> list[MarkdownChunk]:
|
||||
"""Split a large element into smaller chunks."""
|
||||
chunks = []
|
||||
|
||||
if elem.type == MarkdownElementType.CODE_BLOCK:
|
||||
chunks = self._split_code_block(elem, base_id, header_context, start_index)
|
||||
elif elem.type == MarkdownElementType.TABLE:
|
||||
chunks = self._split_table(elem, base_id, header_context, start_index)
|
||||
elif elem.type == MarkdownElementType.LIST:
|
||||
chunks = self._split_list(elem, base_id, header_context, start_index)
|
||||
else:
|
||||
chunks = self._split_text(elem, base_id, header_context, start_index)
|
||||
|
||||
return chunks
|
||||
|
||||
def _split_code_block(
|
||||
self,
|
||||
elem: MarkdownElement,
|
||||
base_id: str,
|
||||
header_context: list[str],
|
||||
start_index: int,
|
||||
) -> list[MarkdownChunk]:
|
||||
"""Split code block while preserving language marker."""
|
||||
chunks = []
|
||||
lines = elem.content.split('\n')
|
||||
current_lines = []
|
||||
current_size = 0
|
||||
sub_index = 0
|
||||
|
||||
for line in lines:
|
||||
if current_size + len(line) + 1 > self._max_chunk_size and current_lines:
|
||||
chunk_content = f"```{elem.language}\n" + '\n'.join(current_lines) + "\n```"
|
||||
chunks.append(MarkdownChunk(
|
||||
chunk_id=f"{base_id}_{sub_index}",
|
||||
content=chunk_content,
|
||||
element_type=MarkdownElementType.CODE_BLOCK,
|
||||
header_context=header_context,
|
||||
language=elem.language,
|
||||
metadata={**elem.metadata, "is_partial": True, "part": sub_index + 1},
|
||||
))
|
||||
sub_index += 1
|
||||
current_lines = []
|
||||
current_size = 0
|
||||
|
||||
current_lines.append(line)
|
||||
current_size += len(line) + 1
|
||||
|
||||
if current_lines:
|
||||
chunk_content = f"```{elem.language}\n" + '\n'.join(current_lines) + "\n```"
|
||||
chunks.append(MarkdownChunk(
|
||||
chunk_id=f"{base_id}_{sub_index}",
|
||||
content=chunk_content,
|
||||
element_type=MarkdownElementType.CODE_BLOCK,
|
||||
header_context=header_context,
|
||||
language=elem.language,
|
||||
metadata={**elem.metadata, "is_partial": sub_index > 0, "part": sub_index + 1},
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
def _split_table(
|
||||
self,
|
||||
elem: MarkdownElement,
|
||||
base_id: str,
|
||||
header_context: list[str],
|
||||
start_index: int,
|
||||
) -> list[MarkdownChunk]:
|
||||
"""Split table while preserving header row."""
|
||||
chunks = []
|
||||
lines = elem.content.split('\n')
|
||||
|
||||
if len(lines) < 2:
|
||||
return [MarkdownChunk(
|
||||
chunk_id=f"{base_id}_0",
|
||||
content=elem.content,
|
||||
element_type=MarkdownElementType.TABLE,
|
||||
header_context=header_context,
|
||||
metadata=elem.metadata,
|
||||
)]
|
||||
|
||||
header_line = lines[0]
|
||||
separator_line = lines[1]
|
||||
data_lines = lines[2:]
|
||||
|
||||
current_lines = [header_line, separator_line]
|
||||
current_size = len(header_line) + len(separator_line) + 2
|
||||
sub_index = 0
|
||||
|
||||
for line in data_lines:
|
||||
if current_size + len(line) + 1 > self._max_chunk_size and len(current_lines) > 2:
|
||||
chunks.append(MarkdownChunk(
|
||||
chunk_id=f"{base_id}_{sub_index}",
|
||||
content='\n'.join(current_lines),
|
||||
element_type=MarkdownElementType.TABLE,
|
||||
header_context=header_context,
|
||||
metadata={**elem.metadata, "is_partial": True, "part": sub_index + 1},
|
||||
))
|
||||
sub_index += 1
|
||||
current_lines = [header_line, separator_line]
|
||||
current_size = len(header_line) + len(separator_line) + 2
|
||||
|
||||
current_lines.append(line)
|
||||
current_size += len(line) + 1
|
||||
|
||||
if len(current_lines) > 2:
|
||||
chunks.append(MarkdownChunk(
|
||||
chunk_id=f"{base_id}_{sub_index}",
|
||||
content='\n'.join(current_lines),
|
||||
element_type=MarkdownElementType.TABLE,
|
||||
header_context=header_context,
|
||||
metadata={**elem.metadata, "is_partial": sub_index > 0, "part": sub_index + 1},
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
def _split_list(
|
||||
self,
|
||||
elem: MarkdownElement,
|
||||
base_id: str,
|
||||
header_context: list[str],
|
||||
start_index: int,
|
||||
) -> list[MarkdownChunk]:
|
||||
"""Split list into smaller chunks."""
|
||||
chunks = []
|
||||
items = elem.content.split('\n')
|
||||
current_items = []
|
||||
current_size = 0
|
||||
sub_index = 0
|
||||
|
||||
for item in items:
|
||||
if current_size + len(item) + 1 > self._max_chunk_size and current_items:
|
||||
chunks.append(MarkdownChunk(
|
||||
chunk_id=f"{base_id}_{sub_index}",
|
||||
content='\n'.join(current_items),
|
||||
element_type=MarkdownElementType.LIST,
|
||||
header_context=header_context,
|
||||
metadata={**elem.metadata, "is_partial": True, "part": sub_index + 1},
|
||||
))
|
||||
sub_index += 1
|
||||
current_items = []
|
||||
current_size = 0
|
||||
|
||||
current_items.append(item)
|
||||
current_size += len(item) + 1
|
||||
|
||||
if current_items:
|
||||
chunks.append(MarkdownChunk(
|
||||
chunk_id=f"{base_id}_{sub_index}",
|
||||
content='\n'.join(current_items),
|
||||
element_type=MarkdownElementType.LIST,
|
||||
header_context=header_context,
|
||||
metadata={**elem.metadata, "is_partial": sub_index > 0, "part": sub_index + 1},
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
def _split_text(
|
||||
self,
|
||||
elem: MarkdownElement,
|
||||
base_id: str,
|
||||
header_context: list[str],
|
||||
start_index: int,
|
||||
) -> list[MarkdownChunk]:
|
||||
"""Split text content by sentences or paragraphs."""
|
||||
chunks = []
|
||||
text = elem.content
|
||||
sub_index = 0
|
||||
|
||||
paragraphs = text.split('\n\n')
|
||||
|
||||
current_content = ""
|
||||
current_size = 0
|
||||
|
||||
for para in paragraphs:
|
||||
if current_size + len(para) + 2 > self._max_chunk_size and current_content:
|
||||
chunks.append(MarkdownChunk(
|
||||
chunk_id=f"{base_id}_{sub_index}",
|
||||
content=current_content.strip(),
|
||||
element_type=elem.type,
|
||||
header_context=header_context,
|
||||
metadata={**elem.metadata, "is_partial": True, "part": sub_index + 1},
|
||||
))
|
||||
sub_index += 1
|
||||
current_content = ""
|
||||
current_size = 0
|
||||
|
||||
current_content += para + "\n\n"
|
||||
current_size += len(para) + 2
|
||||
|
||||
if current_content.strip():
|
||||
chunks.append(MarkdownChunk(
|
||||
chunk_id=f"{base_id}_{sub_index}",
|
||||
content=current_content.strip(),
|
||||
element_type=elem.type,
|
||||
header_context=header_context,
|
||||
metadata={**elem.metadata, "is_partial": sub_index > 0, "part": sub_index + 1},
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def chunk_markdown(
|
||||
text: str,
|
||||
doc_id: str = "",
|
||||
max_chunk_size: int = 1000,
|
||||
min_chunk_size: int = 100,
|
||||
preserve_code_blocks: bool = True,
|
||||
preserve_tables: bool = True,
|
||||
preserve_lists: bool = True,
|
||||
include_header_context: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convenience function to chunk Markdown text.
|
||||
|
||||
Args:
|
||||
text: Raw Markdown text
|
||||
doc_id: Optional document ID
|
||||
max_chunk_size: Maximum chunk size in characters
|
||||
min_chunk_size: Minimum chunk size in characters
|
||||
preserve_code_blocks: Whether to preserve code blocks
|
||||
preserve_tables: Whether to preserve tables
|
||||
preserve_lists: Whether to preserve lists
|
||||
include_header_context: Whether to include header context
|
||||
|
||||
Returns:
|
||||
List of chunk dictionaries
|
||||
"""
|
||||
chunker = MarkdownChunker(
|
||||
max_chunk_size=max_chunk_size,
|
||||
min_chunk_size=min_chunk_size,
|
||||
preserve_code_blocks=preserve_code_blocks,
|
||||
preserve_tables=preserve_tables,
|
||||
preserve_lists=preserve_lists,
|
||||
include_header_context=include_header_context,
|
||||
)
|
||||
|
||||
chunks = chunker.chunk(text, doc_id)
|
||||
return [chunk.to_dict() for chunk in chunks]
|
||||
|
|
@ -0,0 +1,178 @@
|
|||
"""
|
||||
Markdown parser with intelligent chunking.
|
||||
[AC-AISVC-33] Markdown file parsing with structure-aware chunking.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.services.document.base import (
|
||||
DocumentParseException,
|
||||
DocumentParser,
|
||||
ParseResult,
|
||||
)
|
||||
from app.services.document.markdown_chunker import (
|
||||
MarkdownChunker,
|
||||
MarkdownElementType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ENCODINGS_TO_TRY = ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]
|
||||
|
||||
|
||||
class MarkdownParser(DocumentParser):
|
||||
"""
|
||||
Parser for Markdown files with intelligent chunking.
|
||||
[AC-AISVC-33] Structure-aware parsing for Markdown documents.
|
||||
|
||||
Features:
|
||||
- Header hierarchy extraction
|
||||
- Code block preservation
|
||||
- Table structure preservation
|
||||
- List grouping
|
||||
- Context-aware chunking
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoding: str = "utf-8",
|
||||
max_chunk_size: int = 1000,
|
||||
min_chunk_size: int = 100,
|
||||
preserve_code_blocks: bool = True,
|
||||
preserve_tables: bool = True,
|
||||
preserve_lists: bool = True,
|
||||
include_header_context: bool = True,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self._encoding = encoding
|
||||
self._max_chunk_size = max_chunk_size
|
||||
self._min_chunk_size = min_chunk_size
|
||||
self._preserve_code_blocks = preserve_code_blocks
|
||||
self._preserve_tables = preserve_tables
|
||||
self._preserve_lists = preserve_lists
|
||||
self._include_header_context = include_header_context
|
||||
self._extra_config = kwargs
|
||||
|
||||
self._chunker = MarkdownChunker(
|
||||
max_chunk_size=max_chunk_size,
|
||||
min_chunk_size=min_chunk_size,
|
||||
preserve_code_blocks=preserve_code_blocks,
|
||||
preserve_tables=preserve_tables,
|
||||
preserve_lists=preserve_lists,
|
||||
include_header_context=include_header_context,
|
||||
)
|
||||
|
||||
def _try_encodings(self, path: Path) -> tuple[str, str]:
|
||||
"""
|
||||
Try multiple encodings to read the file.
|
||||
Returns: (text, encoding_used)
|
||||
"""
|
||||
for enc in ENCODINGS_TO_TRY:
|
||||
try:
|
||||
with open(path, encoding=enc) as f:
|
||||
text = f.read()
|
||||
logger.info(f"Successfully parsed Markdown with encoding: {enc}")
|
||||
return text, enc
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
continue
|
||||
|
||||
raise DocumentParseException(
|
||||
"Failed to decode Markdown file with any known encoding",
|
||||
file_path=str(path),
|
||||
parser="markdown"
|
||||
)
|
||||
|
||||
def parse(self, file_path: str | Path) -> ParseResult:
|
||||
"""
|
||||
Parse a Markdown file and extract structured content.
|
||||
[AC-AISVC-33] Structure-aware parsing.
|
||||
"""
|
||||
path = Path(file_path)
|
||||
|
||||
if not path.exists():
|
||||
raise DocumentParseException(
|
||||
f"File not found: {path}",
|
||||
file_path=str(path),
|
||||
parser="markdown"
|
||||
)
|
||||
|
||||
try:
|
||||
text, encoding_used = self._try_encodings(path)
|
||||
|
||||
file_size = path.stat().st_size
|
||||
line_count = text.count("\n") + 1
|
||||
|
||||
chunks = self._chunker.chunk(text, doc_id=path.stem)
|
||||
|
||||
header_count = sum(
|
||||
1 for c in chunks
|
||||
if c.element_type == MarkdownElementType.HEADER
|
||||
)
|
||||
code_block_count = sum(
|
||||
1 for c in chunks
|
||||
if c.element_type == MarkdownElementType.CODE_BLOCK
|
||||
)
|
||||
table_count = sum(
|
||||
1 for c in chunks
|
||||
if c.element_type == MarkdownElementType.TABLE
|
||||
)
|
||||
list_count = sum(
|
||||
1 for c in chunks
|
||||
if c.element_type == MarkdownElementType.LIST
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Parsed Markdown: {path.name}, lines={line_count}, "
|
||||
f"chars={len(text)}, chunks={len(chunks)}, "
|
||||
f"headers={header_count}, code_blocks={code_block_count}, "
|
||||
f"tables={table_count}, lists={list_count}"
|
||||
)
|
||||
|
||||
return ParseResult(
|
||||
text=text,
|
||||
source_path=str(path),
|
||||
file_size=file_size,
|
||||
metadata={
|
||||
"format": "markdown",
|
||||
"line_count": line_count,
|
||||
"encoding": encoding_used,
|
||||
"chunk_count": len(chunks),
|
||||
"structure": {
|
||||
"headers": header_count,
|
||||
"code_blocks": code_block_count,
|
||||
"tables": table_count,
|
||||
"lists": list_count,
|
||||
},
|
||||
"chunks": [chunk.to_dict() for chunk in chunks],
|
||||
}
|
||||
)
|
||||
|
||||
except DocumentParseException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise DocumentParseException(
|
||||
f"Failed to parse Markdown file: {e}",
|
||||
file_path=str(path),
|
||||
parser="markdown",
|
||||
details={"error": str(e)}
|
||||
)
|
||||
|
||||
def get_supported_extensions(self) -> list[str]:
|
||||
"""Get supported file extensions."""
|
||||
return [".md", ".markdown"]
|
||||
|
||||
def get_chunks(self, text: str, doc_id: str = "") -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get structured chunks from Markdown text.
|
||||
|
||||
Args:
|
||||
text: Markdown text content
|
||||
doc_id: Optional document ID
|
||||
|
||||
Returns:
|
||||
List of chunk dictionaries
|
||||
"""
|
||||
chunks = self._chunker.chunk(text, doc_id)
|
||||
return [chunk.to_dict() for chunk in chunks]
|
||||
|
|
@ -0,0 +1,202 @@
|
|||
"""
|
||||
KB document metadata inference using LLM.
|
||||
[AC-IDSMETA-XX] Infer metadata for markdown uploads when missing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.llm.base import LLMConfig
|
||||
from app.services.llm.factory import LLMUsageType, get_llm_config_manager
|
||||
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
|
||||
from app.services.mid.metadata_discovery_tool import MetadataDiscoveryTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_MAX_CONTENT_CHARS = 2000
|
||||
|
||||
|
||||
def _extract_json_object(text: str) -> dict[str, Any] | None:
|
||||
candidates: list[str] = []
|
||||
|
||||
code_block_match = re.search(r"```json\s*([\s\S]*?)\s*```", text, re.IGNORECASE)
|
||||
if code_block_match:
|
||||
candidates.append(code_block_match.group(1).strip())
|
||||
|
||||
fence_match = re.search(r"```\s*([\s\S]*?)\s*```", text)
|
||||
if fence_match:
|
||||
candidates.append(fence_match.group(1).strip())
|
||||
|
||||
brace_match = re.search(r"\{[\s\S]*\}", text)
|
||||
if brace_match:
|
||||
candidates.append(brace_match.group(0).strip())
|
||||
|
||||
for candidate in candidates:
|
||||
if not candidate:
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(candidate)
|
||||
if isinstance(obj, dict):
|
||||
return obj
|
||||
except json.JSONDecodeError:
|
||||
fixed = candidate.replace("'", '"')
|
||||
try:
|
||||
obj = json.loads(fixed)
|
||||
if isinstance(obj, dict):
|
||||
return obj
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _truncate_content(text: str) -> str:
|
||||
if len(text) <= _MAX_CONTENT_CHARS * 2:
|
||||
return text
|
||||
head = text[:_MAX_CONTENT_CHARS]
|
||||
tail = text[-_MAX_CONTENT_CHARS:]
|
||||
return f"{head}\n\n...\n\n{tail}"
|
||||
|
||||
|
||||
class KBMetadataInferenceService:
|
||||
"""Infer document metadata based on markdown content."""
|
||||
|
||||
def __init__(self, session: AsyncSession, max_tokens: int = 512, temperature: float = 0.2):
|
||||
self._session = session
|
||||
self._max_tokens = max_tokens
|
||||
self._temperature = temperature
|
||||
|
||||
async def infer_metadata(
|
||||
self,
|
||||
tenant_id: str,
|
||||
content: str,
|
||||
filename: str | None = None,
|
||||
kb_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
field_def_service = MetadataFieldDefinitionService(self._session)
|
||||
field_defs = await field_def_service.get_active_field_definitions(tenant_id, "kb_document")
|
||||
|
||||
if not field_defs:
|
||||
return {}
|
||||
|
||||
discovery_tool = MetadataDiscoveryTool(self._session)
|
||||
discovery_result = await discovery_tool.execute(
|
||||
tenant_id=tenant_id,
|
||||
kb_id=kb_id,
|
||||
include_values=True,
|
||||
top_n=5,
|
||||
)
|
||||
common_values_map = {
|
||||
field.field_key: field.common_values
|
||||
for field in (discovery_result.fields if discovery_result.success else [])
|
||||
}
|
||||
|
||||
fields_payload = []
|
||||
for field_def in field_defs:
|
||||
fields_payload.append({
|
||||
"field_key": field_def.field_key,
|
||||
"label": field_def.label,
|
||||
"type": field_def.type,
|
||||
"required": field_def.required,
|
||||
"options": field_def.options or [],
|
||||
"common_values": common_values_map.get(field_def.field_key, []),
|
||||
})
|
||||
|
||||
prompt = f"""你是知识库文档的元数据补全助手。请根据给定的 Markdown 内容,为文档补全元数据。
|
||||
|
||||
要求:
|
||||
1) 只能使用提供的字段;
|
||||
2) 如果字段有 options 或 common_values,优先从中选择最匹配的值;
|
||||
3) 不确定的字段不要填写;
|
||||
4) 输出必须是严格 JSON 对象,只包含推断出的字段;
|
||||
5) 不要输出多余说明。
|
||||
|
||||
可用字段定义(JSON 数组):
|
||||
{json.dumps(fields_payload, ensure_ascii=False)}
|
||||
|
||||
文件名:{filename or "unknown"}
|
||||
|
||||
Markdown 内容:
|
||||
{_truncate_content(content)}
|
||||
""".strip()
|
||||
|
||||
try:
|
||||
llm_manager = get_llm_config_manager()
|
||||
llm_client = llm_manager.get_client(LLMUsageType.KB_PROCESSING)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-IDSMETA-XX] Failed to get LLM client: {e}")
|
||||
return {}
|
||||
|
||||
try:
|
||||
response = await llm_client.generate(
|
||||
messages=[
|
||||
{"role": "system", "content": "你是严格的 JSON 生成器。"},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
config=LLMConfig(
|
||||
max_tokens=self._max_tokens,
|
||||
temperature=self._temperature,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-IDSMETA-XX] Metadata inference failed: {e}")
|
||||
return {}
|
||||
|
||||
if not response.content:
|
||||
return {}
|
||||
|
||||
inferred = _extract_json_object(response.content)
|
||||
if not inferred:
|
||||
logger.warning("[AC-IDSMETA-XX] Metadata inference returned no JSON")
|
||||
return {}
|
||||
|
||||
cleaned = self._clean_inferred_values(inferred, field_defs)
|
||||
if not cleaned:
|
||||
return {}
|
||||
|
||||
is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
|
||||
tenant_id, cleaned, "kb_document"
|
||||
)
|
||||
if not is_valid:
|
||||
logger.warning(f"[AC-IDSMETA-XX] Inferred metadata validation failed: {validation_errors}")
|
||||
return {}
|
||||
|
||||
return cleaned
|
||||
|
||||
def _clean_inferred_values(
|
||||
self,
|
||||
inferred: dict[str, Any],
|
||||
field_defs: list[Any],
|
||||
) -> dict[str, Any]:
|
||||
field_map = {f.field_key: f for f in field_defs}
|
||||
cleaned: dict[str, Any] = {}
|
||||
|
||||
for key, value in inferred.items():
|
||||
field_def = field_map.get(key)
|
||||
if not field_def:
|
||||
continue
|
||||
if value is None or value == "":
|
||||
continue
|
||||
|
||||
if field_def.type in {"enum", "array_enum"} and field_def.options:
|
||||
if field_def.type == "enum":
|
||||
if value not in field_def.options:
|
||||
continue
|
||||
else:
|
||||
if not isinstance(value, list):
|
||||
continue
|
||||
filtered = [v for v in value if v in field_def.options]
|
||||
if not filtered:
|
||||
continue
|
||||
value = filtered
|
||||
|
||||
cleaned[key] = value
|
||||
|
||||
return cleaned
|
||||
|
|
@ -117,7 +117,7 @@ class LLMClient(ABC):
|
|||
@abstractmethod
|
||||
async def generate(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
messages: list[dict[str, Any]],
|
||||
config: LLMConfig | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
|
|
@ -145,7 +145,7 @@ class LLMClient(ABC):
|
|||
@abstractmethod
|
||||
async def stream_generate(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
messages: list[dict[str, Any]],
|
||||
config: LLMConfig | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ Design pattern: Factory pattern for pluggable LLM providers.
|
|||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -23,6 +24,23 @@ LLM_CONFIG_FILE = Path("config/llm_config.json")
|
|||
LLM_CONFIG_REDIS_KEY = "ai_service:config:llm"
|
||||
|
||||
|
||||
class LLMUsageType(str, Enum):
|
||||
"""LLM usage type for different scenarios."""
|
||||
CHAT = "chat"
|
||||
KB_PROCESSING = "kb_processing"
|
||||
|
||||
|
||||
LLM_USAGE_DISPLAY_NAMES: dict[LLMUsageType, str] = {
|
||||
LLMUsageType.CHAT: "对话模型",
|
||||
LLMUsageType.KB_PROCESSING: "知识库处理模型",
|
||||
}
|
||||
|
||||
LLM_USAGE_DESCRIPTIONS: dict[LLMUsageType, str] = {
|
||||
LLMUsageType.CHAT: "用于 Agent 对话、问答等交互场景",
|
||||
LLMUsageType.KB_PROCESSING: "用于知识库文档上传、元数据推断、文档处理等场景",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMProviderInfo:
|
||||
"""Information about an LLM provider."""
|
||||
|
|
@ -284,6 +302,7 @@ class LLMConfigManager:
|
|||
"""
|
||||
Manager for LLM configuration.
|
||||
[AC-ASA-16, AC-ASA-17, AC-ASA-18] Configuration management with hot-reload and persistence.
|
||||
Supports multiple LLM usage types (chat, kb_processing).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
|
@ -293,8 +312,7 @@ class LLMConfigManager:
|
|||
self._settings = settings
|
||||
self._redis_client: redis.Redis | None = None
|
||||
|
||||
self._current_provider: str = settings.llm_provider
|
||||
self._current_config: dict[str, Any] = {
|
||||
default_config = {
|
||||
"api_key": settings.llm_api_key,
|
||||
"base_url": settings.llm_base_url,
|
||||
"model": settings.llm_model,
|
||||
|
|
@ -303,11 +321,42 @@ class LLMConfigManager:
|
|||
"timeout_seconds": settings.llm_timeout_seconds,
|
||||
"max_retries": settings.llm_max_retries,
|
||||
}
|
||||
self._client: LLMClient | None = None
|
||||
|
||||
self._configs: dict[LLMUsageType, dict[str, Any]] = {
|
||||
LLMUsageType.CHAT: {
|
||||
"provider": settings.llm_provider,
|
||||
"config": default_config.copy(),
|
||||
},
|
||||
LLMUsageType.KB_PROCESSING: {
|
||||
"provider": settings.llm_provider,
|
||||
"config": default_config.copy(),
|
||||
},
|
||||
}
|
||||
|
||||
self._clients: dict[LLMUsageType, LLMClient | None] = {
|
||||
LLMUsageType.CHAT: None,
|
||||
LLMUsageType.KB_PROCESSING: None,
|
||||
}
|
||||
|
||||
self._load_from_redis()
|
||||
self._load_from_file()
|
||||
|
||||
@property
|
||||
def chat_provider(self) -> str:
|
||||
return self._configs[LLMUsageType.CHAT]["provider"]
|
||||
|
||||
@property
|
||||
def kb_processing_provider(self) -> str:
|
||||
return self._configs[LLMUsageType.KB_PROCESSING]["provider"]
|
||||
|
||||
@property
|
||||
def chat_config(self) -> dict[str, Any]:
|
||||
return self._configs[LLMUsageType.CHAT]["config"].copy()
|
||||
|
||||
@property
|
||||
def kb_processing_config(self) -> dict[str, Any]:
|
||||
return self._configs[LLMUsageType.KB_PROCESSING]["config"].copy()
|
||||
|
||||
def _load_from_redis(self) -> None:
|
||||
"""Load configuration from Redis if exists."""
|
||||
try:
|
||||
|
|
@ -322,11 +371,19 @@ class LLMConfigManager:
|
|||
if not saved_raw:
|
||||
return
|
||||
saved = json.loads(saved_raw)
|
||||
self._current_provider = saved.get("provider", self._current_provider)
|
||||
saved_config = saved.get("config", {})
|
||||
if saved_config:
|
||||
self._current_config.update(saved_config)
|
||||
logger.info(f"[AC-ASA-16] Loaded LLM config from Redis: provider={self._current_provider}")
|
||||
|
||||
for usage_type in LLMUsageType:
|
||||
type_key = usage_type.value
|
||||
if type_key in saved:
|
||||
self._configs[usage_type] = {
|
||||
"provider": saved[type_key].get("provider", self._configs[usage_type]["provider"]),
|
||||
"config": {**self._configs[usage_type]["config"], **saved[type_key].get("config", {})},
|
||||
}
|
||||
elif "provider" in saved:
|
||||
self._configs[usage_type]["provider"] = saved.get("provider", self._configs[usage_type]["provider"])
|
||||
self._configs[usage_type]["config"] = {**self._configs[usage_type]["config"], **saved.get("config", {})}
|
||||
|
||||
logger.info(f"[AC-ASA-16] Loaded multi-usage LLM config from Redis")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-ASA-16] Failed to load LLM config from Redis: {e}")
|
||||
|
||||
|
|
@ -341,50 +398,42 @@ class LLMConfigManager:
|
|||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
save_data = {
|
||||
usage_type.value: {
|
||||
"provider": config["provider"],
|
||||
"config": config["config"],
|
||||
}
|
||||
for usage_type, config in self._configs.items()
|
||||
}
|
||||
|
||||
self._redis_client.set(
|
||||
LLM_CONFIG_REDIS_KEY,
|
||||
json.dumps({
|
||||
"provider": self._current_provider,
|
||||
"config": self._current_config,
|
||||
}, ensure_ascii=False),
|
||||
json.dumps(save_data, ensure_ascii=False),
|
||||
)
|
||||
logger.info(f"[AC-ASA-16] Saved LLM config to Redis: provider={self._current_provider}")
|
||||
logger.info(f"[AC-ASA-16] Saved multi-usage LLM config to Redis")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-ASA-16] Failed to save LLM config to Redis: {e}")
|
||||
|
||||
def _load_from_redis(self) -> None:
|
||||
"""Load configuration from Redis if exists."""
|
||||
try:
|
||||
if not self._settings.redis_enabled:
|
||||
return
|
||||
self._redis_client = redis.from_url(
|
||||
self._settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
saved_raw = self._redis_client.get(LLM_CONFIG_REDIS_KEY)
|
||||
if not saved_raw:
|
||||
return
|
||||
saved = json.loads(saved_raw)
|
||||
self._current_provider = saved.get("provider", self._current_provider)
|
||||
saved_config = saved.get("config", {})
|
||||
if saved_config:
|
||||
self._current_config.update(saved_config)
|
||||
logger.info(f"[AC-ASA-16] Loaded LLM config from Redis: provider={self._current_provider}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-ASA-16] Failed to load LLM config from Redis: {e}")
|
||||
|
||||
def _load_from_file(self) -> None:
|
||||
"""Load configuration from file if exists."""
|
||||
try:
|
||||
if LLM_CONFIG_FILE.exists():
|
||||
with open(LLM_CONFIG_FILE, encoding='utf-8') as f:
|
||||
saved = json.load(f)
|
||||
self._current_provider = saved.get("provider", self._current_provider)
|
||||
saved_config = saved.get("config", {})
|
||||
if saved_config:
|
||||
self._current_config.update(saved_config)
|
||||
logger.info(f"[AC-ASA-16] Loaded LLM config from file: provider={self._current_provider}")
|
||||
|
||||
for usage_type in LLMUsageType:
|
||||
type_key = usage_type.value
|
||||
if type_key in saved:
|
||||
self._configs[usage_type] = {
|
||||
"provider": saved[type_key].get("provider", self._configs[usage_type]["provider"]),
|
||||
"config": {**self._configs[usage_type]["config"], **saved[type_key].get("config", {})},
|
||||
}
|
||||
elif "provider" in saved:
|
||||
self._configs[usage_type]["provider"] = saved.get("provider", self._configs[usage_type]["provider"])
|
||||
self._configs[usage_type]["config"] = {**self._configs[usage_type]["config"], **saved.get("config", {})}
|
||||
|
||||
logger.info(f"[AC-ASA-16] Loaded multi-usage LLM config from file")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-ASA-16] Failed to load LLM config from file: {e}")
|
||||
|
||||
|
|
@ -392,26 +441,48 @@ class LLMConfigManager:
|
|||
"""Save configuration to file."""
|
||||
try:
|
||||
LLM_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_data = {
|
||||
usage_type.value: {
|
||||
"provider": config["provider"],
|
||||
"config": config["config"],
|
||||
}
|
||||
for usage_type, config in self._configs.items()
|
||||
}
|
||||
|
||||
with open(LLM_CONFIG_FILE, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
"provider": self._current_provider,
|
||||
"config": self._current_config,
|
||||
}, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"[AC-ASA-16] Saved LLM config to file: provider={self._current_provider}")
|
||||
json.dump(save_data, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"[AC-ASA-16] Saved multi-usage LLM config to file")
|
||||
except Exception as e:
|
||||
logger.error(f"[AC-ASA-16] Failed to save LLM config to file: {e}")
|
||||
|
||||
def get_current_config(self) -> dict[str, Any]:
|
||||
"""Get current LLM configuration."""
|
||||
def get_current_config(self, usage_type: LLMUsageType | None = None) -> dict[str, Any]:
|
||||
"""Get current LLM configuration for specified usage type or all configs."""
|
||||
if usage_type:
|
||||
config = self._configs.get(usage_type, self._configs[LLMUsageType.CHAT])
|
||||
return {
|
||||
"provider": self._current_provider,
|
||||
"config": self._current_config.copy(),
|
||||
"usage_type": usage_type.value,
|
||||
"provider": config["provider"],
|
||||
"config": config["config"].copy(),
|
||||
}
|
||||
|
||||
return {
|
||||
usage_type.value: {
|
||||
"provider": config["provider"],
|
||||
"config": config["config"].copy(),
|
||||
}
|
||||
for usage_type, config in self._configs.items()
|
||||
}
|
||||
|
||||
def get_config_for_usage(self, usage_type: LLMUsageType) -> dict[str, Any]:
|
||||
"""Get configuration for a specific usage type."""
|
||||
return self._configs.get(usage_type, self._configs[LLMUsageType.CHAT])
|
||||
|
||||
async def update_config(
|
||||
self,
|
||||
provider: str,
|
||||
config: dict[str, Any],
|
||||
usage_type: LLMUsageType | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Update LLM configuration.
|
||||
|
|
@ -420,6 +491,7 @@ class LLMConfigManager:
|
|||
Args:
|
||||
provider: Provider name
|
||||
config: New configuration
|
||||
usage_type: Usage type to update (None = update all)
|
||||
|
||||
Returns:
|
||||
True if update successful
|
||||
|
|
@ -430,17 +502,46 @@ class LLMConfigManager:
|
|||
provider_info = LLM_PROVIDERS[provider]
|
||||
validated_config = self._validate_config(provider_info, config)
|
||||
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
target_usage_types = [usage_type] if usage_type else list(LLMUsageType)
|
||||
|
||||
self._current_provider = provider
|
||||
self._current_config = validated_config
|
||||
for ut in target_usage_types:
|
||||
if self._clients[ut]:
|
||||
await self._clients[ut].close()
|
||||
self._clients[ut] = None
|
||||
|
||||
self._configs[ut]["provider"] = provider
|
||||
self._configs[ut]["config"] = validated_config.copy()
|
||||
|
||||
self._save_to_redis()
|
||||
self._save_to_file()
|
||||
|
||||
logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")
|
||||
logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}, usage={usage_type or 'all'}")
|
||||
return True
|
||||
|
||||
async def update_usage_config(
|
||||
self,
|
||||
usage_type: LLMUsageType,
|
||||
provider: str,
|
||||
config: dict[str, Any],
|
||||
) -> bool:
|
||||
"""Update configuration for a specific usage type."""
|
||||
if provider not in LLM_PROVIDERS:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
provider_info = LLM_PROVIDERS[provider]
|
||||
validated_config = self._validate_config(provider_info, config)
|
||||
|
||||
if self._clients[usage_type]:
|
||||
await self._clients[usage_type].close()
|
||||
self._clients[usage_type] = None
|
||||
|
||||
self._configs[usage_type]["provider"] = provider
|
||||
self._configs[usage_type]["config"] = validated_config
|
||||
|
||||
self._save_to_redis()
|
||||
self._save_to_file()
|
||||
|
||||
logger.info(f"[AC-ASA-16] LLM config updated: usage={usage_type.value}, provider={provider}")
|
||||
return True
|
||||
|
||||
def _validate_config(
|
||||
|
|
@ -462,20 +563,32 @@ class LLMConfigManager:
|
|||
raise ValueError(f"Missing required config: {key}")
|
||||
return validated
|
||||
|
||||
def get_client(self) -> LLMClient:
|
||||
"""Get or create LLM client with current config."""
|
||||
if self._client is None:
|
||||
self._client = LLMProviderFactory.create_client(
|
||||
self._current_provider,
|
||||
self._current_config,
|
||||
def get_client(self, usage_type: LLMUsageType | None = None) -> LLMClient:
|
||||
"""Get or create LLM client with config for specified usage type."""
|
||||
ut = usage_type or LLMUsageType.CHAT
|
||||
|
||||
if self._clients[ut] is None:
|
||||
config = self._configs[ut]
|
||||
self._clients[ut] = LLMProviderFactory.create_client(
|
||||
config["provider"],
|
||||
config["config"],
|
||||
)
|
||||
return self._client
|
||||
return self._clients[ut]
|
||||
|
||||
def get_chat_client(self) -> LLMClient:
|
||||
"""Get LLM client for chat/dialogue."""
|
||||
return self.get_client(LLMUsageType.CHAT)
|
||||
|
||||
def get_kb_processing_client(self) -> LLMClient:
|
||||
"""Get LLM client for KB processing."""
|
||||
return self.get_client(LLMUsageType.KB_PROCESSING)
|
||||
|
||||
async def test_connection(
|
||||
self,
|
||||
test_prompt: str = "你好,请简单介绍一下自己。",
|
||||
provider: str | None = None,
|
||||
config: dict[str, Any] | None = None,
|
||||
usage_type: LLMUsageType | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Test LLM connection.
|
||||
|
|
@ -485,14 +598,20 @@ class LLMConfigManager:
|
|||
test_prompt: Test prompt to send
|
||||
provider: Optional provider to test (uses current if not specified)
|
||||
config: Optional config to test (uses current if not specified)
|
||||
usage_type: Usage type for config lookup
|
||||
|
||||
Returns:
|
||||
Test result with success status, response, and metrics
|
||||
"""
|
||||
import time
|
||||
|
||||
test_provider = provider or self._current_provider
|
||||
test_config = config if config else self._current_config
|
||||
if usage_type and not provider:
|
||||
usage_config = self._configs[usage_type]
|
||||
test_provider = usage_config["provider"]
|
||||
test_config = usage_config["config"]
|
||||
else:
|
||||
test_provider = provider or self._configs[LLMUsageType.CHAT]["provider"]
|
||||
test_config = config if config else self._configs[LLMUsageType.CHAT]["config"]
|
||||
|
||||
logger.info(f"[AC-ASA-17] Test connection: provider={test_provider}, model={test_config.get('model')}")
|
||||
|
||||
|
|
@ -533,10 +652,11 @@ class LLMConfigManager:
|
|||
}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the current client."""
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
"""Close all clients."""
|
||||
for client in self._clients.values():
|
||||
if client:
|
||||
await client.close()
|
||||
self._clients = {ut: None for ut in LLMUsageType}
|
||||
|
||||
|
||||
_llm_config_manager: LLMConfigManager | None = None
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ class OpenAIClient(LLMClient):
|
|||
|
||||
def _build_request_body(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
messages: list[dict[str, Any]],
|
||||
config: LLMConfig,
|
||||
stream: bool = False,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
|
|
@ -133,7 +133,7 @@ class OpenAIClient(LLMClient):
|
|||
)
|
||||
async def generate(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
messages: list[dict[str, Any]],
|
||||
config: LLMConfig | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
|
|
@ -255,7 +255,7 @@ class OpenAIClient(LLMClient):
|
|||
|
||||
async def stream_generate(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
messages: list[dict[str, Any]],
|
||||
config: LLMConfig | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,496 @@
|
|||
"""
|
||||
Metadata Auto Inference Service.
|
||||
自动推断文档元数据的服务,支持图片和文本格式。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.llm.factory import get_llm_config_manager
|
||||
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
|
||||
from app.services.metadata_cache_service import get_metadata_cache_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_field_definitions_cache: dict[str, list[Any]] = {}
|
||||
_cache_ttl_seconds = 300
|
||||
_last_cache_refresh: dict[str, float] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceFieldContext:
|
||||
"""推断字段的上下文信息"""
|
||||
field_key: str
|
||||
label: str
|
||||
type: str
|
||||
required: bool
|
||||
options: list[str] | None = None
|
||||
description: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoInferenceResult:
|
||||
"""自动推断结果"""
|
||||
inferred_metadata: dict[str, Any]
|
||||
confidence_scores: dict[str, float]
|
||||
raw_response: str
|
||||
success: bool
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
METADATA_INFERENCE_SYSTEM_PROMPT = """你是一个专业的文档元数据分析助手。你的任务是根据文档内容,自动推断并填写元数据字段。
|
||||
|
||||
## 输出要求
|
||||
请严格按照以下 JSON 格式输出,不要添加任何其他内容:
|
||||
|
||||
```json
|
||||
{
|
||||
"inferred_metadata": {
|
||||
"字段键名1": "推断的值1",
|
||||
"字段键名2": "推断的值2"
|
||||
},
|
||||
"confidence_scores": {
|
||||
"字段键名1": 0.95,
|
||||
"字段键名2": 0.80
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 推断规则
|
||||
1. **仔细分析文档内容**:根据文档的主题、关键词、上下文来推断元数据
|
||||
2. **遵循字段定义**:
|
||||
- 对于枚举类型(enum),必须从给定的选项中选择
|
||||
- 对于数组枚举类型(array_enum),可以选择多个选项
|
||||
- 对于数字类型(number),输出数字
|
||||
- 对于布尔类型(boolean),输出 true 或 false
|
||||
- 对于文本类型(text),输出字符串
|
||||
3. **置信度评分**:
|
||||
- 0.9-1.0: 非常确定
|
||||
- 0.7-0.9: 比较确定
|
||||
- 0.5-0.7: 有一定依据但不确定
|
||||
- 0.0-0.5: 猜测或无法确定
|
||||
4. **无法推断时**:如果无法从文档内容中合理推断某个字段,可以不填写该字段
|
||||
|
||||
## 注意事项
|
||||
- 必须严格按照字段定义的类型和选项填写
|
||||
- 不要编造不存在的选项值
|
||||
- 保持客观,基于文档内容推断"""
|
||||
|
||||
|
||||
class MetadataAutoInferenceService:
|
||||
"""
|
||||
元数据自动推断服务。
|
||||
|
||||
功能:
|
||||
1. 获取租户配置的元数据字段定义
|
||||
2. 使用 LLM 根据文档内容自动推断元数据
|
||||
3. 验证推断结果符合字段定义
|
||||
|
||||
使用场景:
|
||||
- 图片上传时自动推断元数据
|
||||
- Markdown/文本上传时自动推断元数据
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 1024,
|
||||
timeout_seconds: int = 60,
|
||||
):
|
||||
self._session = session
|
||||
self._model = model
|
||||
self._max_tokens = max_tokens
|
||||
self._timeout_seconds = timeout_seconds
|
||||
self._field_def_service = MetadataFieldDefinitionService(session)
|
||||
|
||||
async def infer_metadata(
|
||||
self,
|
||||
tenant_id: str,
|
||||
content: str,
|
||||
scope: str = "kb_document",
|
||||
existing_metadata: dict[str, Any] | None = None,
|
||||
image_base64: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
) -> AutoInferenceResult:
|
||||
"""
|
||||
自动推断文档元数据。
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
content: 文档文本内容
|
||||
scope: 元数据作用范围
|
||||
existing_metadata: 已有的元数据(用户手动填写的,会覆盖推断结果)
|
||||
image_base64: 图片的 base64 编码(如果是图片)
|
||||
mime_type: 图片的 MIME 类型
|
||||
|
||||
Returns:
|
||||
AutoInferenceResult 包含推断的元数据
|
||||
"""
|
||||
logger.info(
|
||||
f"[MetadataAutoInference] Starting inference: tenant={tenant_id}, "
|
||||
f"content_length={len(content)}, scope={scope}"
|
||||
)
|
||||
|
||||
field_definitions = await self._get_field_definitions_with_cache(tenant_id, scope)
|
||||
|
||||
if not field_definitions:
|
||||
logger.info(f"[MetadataAutoInference] No field definitions found for tenant={tenant_id}")
|
||||
return AutoInferenceResult(
|
||||
inferred_metadata=existing_metadata or {},
|
||||
confidence_scores={},
|
||||
raw_response="",
|
||||
success=True,
|
||||
error_message="No field definitions configured",
|
||||
)
|
||||
|
||||
field_contexts = self._build_field_contexts(field_definitions)
|
||||
|
||||
if not field_contexts:
|
||||
return AutoInferenceResult(
|
||||
inferred_metadata=existing_metadata or {},
|
||||
confidence_scores={},
|
||||
raw_response="",
|
||||
success=True,
|
||||
)
|
||||
|
||||
user_prompt = self._build_user_prompt(content, field_contexts, existing_metadata)
|
||||
|
||||
try:
|
||||
if image_base64 and mime_type:
|
||||
raw_response = await self._call_multimodal_llm(
|
||||
user_prompt, image_base64, mime_type
|
||||
)
|
||||
else:
|
||||
raw_response = await self._call_text_llm(user_prompt)
|
||||
|
||||
result = self._parse_llm_response(raw_response, field_contexts)
|
||||
|
||||
if existing_metadata:
|
||||
result.inferred_metadata.update(existing_metadata)
|
||||
|
||||
logger.info(
|
||||
f"[MetadataAutoInference] Inference completed: "
|
||||
f"inferred_fields={list(result.inferred_metadata.keys())}, "
|
||||
f"avg_confidence={sum(result.confidence_scores.values()) / len(result.confidence_scores) if result.confidence_scores else 0:.2f}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MetadataAutoInference] Inference failed: {e}")
|
||||
return AutoInferenceResult(
|
||||
inferred_metadata=existing_metadata or {},
|
||||
confidence_scores={},
|
||||
raw_response="",
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
def _build_field_contexts(
|
||||
self,
|
||||
field_definitions: list[Any],
|
||||
) -> list[InferenceFieldContext]:
|
||||
"""构建字段上下文列表"""
|
||||
contexts = []
|
||||
for f in field_definitions:
|
||||
ctx = InferenceFieldContext(
|
||||
field_key=f.field_key,
|
||||
label=f.label,
|
||||
type=f.type,
|
||||
required=f.required,
|
||||
options=f.options,
|
||||
description=getattr(f, 'description', None),
|
||||
)
|
||||
contexts.append(ctx)
|
||||
return contexts
|
||||
|
||||
async def _get_field_definitions_with_cache(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scope: str,
|
||||
) -> list[Any]:
|
||||
"""
|
||||
获取字段定义(带缓存)。
|
||||
|
||||
优先级:
|
||||
1. Redis 缓存
|
||||
2. 本地内存缓存
|
||||
3. 数据库查询
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scope: 作用范围
|
||||
|
||||
Returns:
|
||||
字段定义列表
|
||||
"""
|
||||
import time
|
||||
cache_key = f"{tenant_id}:{scope}"
|
||||
|
||||
try:
|
||||
redis_cache = await get_metadata_cache_service()
|
||||
cached_fields = await redis_cache.get_fields(tenant_id)
|
||||
|
||||
if cached_fields:
|
||||
logger.info(f"[MetadataAutoInference] Redis cache hit for tenant={tenant_id}")
|
||||
return [self._dict_to_field_def(f) for f in cached_fields]
|
||||
except Exception as e:
|
||||
logger.warning(f"[MetadataAutoInference] Redis cache error: {e}")
|
||||
|
||||
current_time = time.time()
|
||||
if cache_key in _field_definitions_cache:
|
||||
last_refresh = _last_cache_refresh.get(cache_key, 0)
|
||||
if current_time - last_refresh < _cache_ttl_seconds:
|
||||
logger.info(f"[MetadataAutoInference] Local cache hit for tenant={tenant_id}")
|
||||
return _field_definitions_cache[cache_key]
|
||||
|
||||
logger.info(f"[MetadataAutoInference] Cache miss, querying database for tenant={tenant_id}")
|
||||
field_definitions = await self._field_def_service.get_active_field_definitions(
|
||||
tenant_id, scope
|
||||
)
|
||||
|
||||
_field_definitions_cache[cache_key] = field_definitions
|
||||
_last_cache_refresh[cache_key] = current_time
|
||||
|
||||
try:
|
||||
redis_cache = await get_metadata_cache_service()
|
||||
await redis_cache.set_fields(
|
||||
tenant_id,
|
||||
[self._field_def_to_dict(f) for f in field_definitions]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[MetadataAutoInference] Failed to update Redis cache: {e}")
|
||||
|
||||
return field_definitions
|
||||
|
||||
def _field_def_to_dict(self, field_def: Any) -> dict[str, Any]:
|
||||
"""将字段定义转换为字典"""
|
||||
return {
|
||||
"field_key": field_def.field_key,
|
||||
"label": field_def.label,
|
||||
"type": field_def.type,
|
||||
"required": field_def.required,
|
||||
"options": field_def.options,
|
||||
}
|
||||
|
||||
def _dict_to_field_def(self, data: dict[str, Any]) -> Any:
|
||||
"""将字典转换为字段定义对象"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class CachedFieldDefinition:
|
||||
field_key: str
|
||||
label: str
|
||||
type: str
|
||||
required: bool
|
||||
options: list[str] | None = None
|
||||
|
||||
return CachedFieldDefinition(
|
||||
field_key=data["field_key"],
|
||||
label=data["label"],
|
||||
type=data["type"],
|
||||
required=data["required"],
|
||||
options=data.get("options"),
|
||||
)
|
||||
|
||||
def _build_user_prompt(
|
||||
self,
|
||||
content: str,
|
||||
field_contexts: list[InferenceFieldContext],
|
||||
existing_metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""构建用户提示词"""
|
||||
field_descriptions = []
|
||||
for ctx in field_contexts:
|
||||
desc = f"- **{ctx.label}** ({ctx.field_key})"
|
||||
desc += f"\n - 类型: {ctx.type}"
|
||||
desc += f"\n - 必填: {'是' if ctx.required else '否'}"
|
||||
if ctx.options:
|
||||
desc += f"\n - 可选值: {', '.join(ctx.options)}"
|
||||
if existing_metadata and ctx.field_key in existing_metadata:
|
||||
desc += f"\n - 已有值: {existing_metadata[ctx.field_key]}"
|
||||
field_descriptions.append(desc)
|
||||
|
||||
fields_text = "\n".join(field_descriptions)
|
||||
|
||||
prompt = f"""请分析以下文档内容,并推断相应的元数据字段。
|
||||
|
||||
## 待推断的字段定义
|
||||
{fields_text}
|
||||
|
||||
## 文档内容
|
||||
{content[:4000]}
|
||||
|
||||
请根据文档内容推断上述字段的值,并输出 JSON 格式的结果。"""
|
||||
|
||||
return prompt
|
||||
|
||||
async def _call_text_llm(self, prompt: str) -> str:
|
||||
"""调用文本 LLM"""
|
||||
manager = get_llm_config_manager()
|
||||
client = manager.get_kb_processing_client()
|
||||
|
||||
config = manager.kb_processing_config
|
||||
model = self._model or config.get("model", "gpt-4o-mini")
|
||||
|
||||
from app.services.llm.base import LLMConfig
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model=model,
|
||||
max_tokens=self._max_tokens,
|
||||
temperature=0.3,
|
||||
timeout_seconds=self._timeout_seconds,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": METADATA_INFERENCE_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = await client.generate(messages=messages, config=llm_config)
|
||||
return response.content or ""
|
||||
|
||||
async def _call_multimodal_llm(
|
||||
self,
|
||||
prompt: str,
|
||||
image_base64: str,
|
||||
mime_type: str,
|
||||
) -> str:
|
||||
"""调用多模态 LLM"""
|
||||
manager = get_llm_config_manager()
|
||||
client = manager.get_kb_processing_client()
|
||||
|
||||
config = manager.kb_processing_config
|
||||
model = self._model or config.get("model", "gpt-4o-mini")
|
||||
|
||||
from app.services.llm.base import LLMConfig
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model=model,
|
||||
max_tokens=self._max_tokens,
|
||||
temperature=0.3,
|
||||
timeout_seconds=self._timeout_seconds,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": METADATA_INFERENCE_SYSTEM_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{mime_type};base64,{image_base64}",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
response = await client.generate(messages=messages, config=llm_config)
|
||||
return response.content or ""
|
||||
|
||||
def _parse_llm_response(
|
||||
self,
|
||||
response: str,
|
||||
field_contexts: list[InferenceFieldContext],
|
||||
) -> AutoInferenceResult:
|
||||
"""解析 LLM 响应"""
|
||||
try:
|
||||
json_str = self._extract_json(response)
|
||||
data = json.loads(json_str)
|
||||
|
||||
inferred_metadata = data.get("inferred_metadata", {})
|
||||
confidence_scores = data.get("confidence_scores", {})
|
||||
|
||||
field_map = {ctx.field_key: ctx for ctx in field_contexts}
|
||||
validated_metadata = {}
|
||||
validated_scores = {}
|
||||
|
||||
for field_key, value in inferred_metadata.items():
|
||||
if field_key not in field_map:
|
||||
continue
|
||||
|
||||
ctx = field_map[field_key]
|
||||
validated_value = self._validate_field_value(ctx, value)
|
||||
|
||||
if validated_value is not None:
|
||||
validated_metadata[field_key] = validated_value
|
||||
validated_scores[field_key] = confidence_scores.get(field_key, 0.5)
|
||||
|
||||
return AutoInferenceResult(
|
||||
inferred_metadata=validated_metadata,
|
||||
confidence_scores=validated_scores,
|
||||
raw_response=response,
|
||||
success=True,
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"[MetadataAutoInference] Failed to parse JSON: {e}")
|
||||
return AutoInferenceResult(
|
||||
inferred_metadata={},
|
||||
confidence_scores={},
|
||||
raw_response=response,
|
||||
success=False,
|
||||
error_message=f"JSON parse error: {e}",
|
||||
)
|
||||
|
||||
def _validate_field_value(
|
||||
self,
|
||||
ctx: InferenceFieldContext,
|
||||
value: Any,
|
||||
) -> Any:
|
||||
"""验证并转换字段值"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
from app.models.entities import MetadataFieldType
|
||||
|
||||
if ctx.type == MetadataFieldType.NUMBER.value:
|
||||
try:
|
||||
return float(value) if isinstance(value, str) else value
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
elif ctx.type == MetadataFieldType.BOOLEAN.value:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ("true", "1", "yes")
|
||||
return bool(value)
|
||||
|
||||
elif ctx.type == MetadataFieldType.ENUM.value:
|
||||
if ctx.options and value in ctx.options:
|
||||
return value
|
||||
return None
|
||||
|
||||
elif ctx.type == MetadataFieldType.ARRAY_ENUM.value:
|
||||
if not isinstance(value, list):
|
||||
value = [value] if value else []
|
||||
if ctx.options:
|
||||
return [v for v in value if v in ctx.options]
|
||||
return value
|
||||
|
||||
else:
|
||||
return str(value) if value is not None else None
|
||||
|
||||
def _extract_json(self, content: str) -> str:
|
||||
"""从响应中提取 JSON"""
|
||||
content = content.strip()
|
||||
|
||||
if content.startswith("{") and content.endswith("}"):
|
||||
return content
|
||||
|
||||
json_start = content.find("{")
|
||||
json_end = content.rfind("}")
|
||||
|
||||
if json_start != -1 and json_end != -1 and json_end > json_start:
|
||||
return content[json_start:json_end + 1]
|
||||
|
||||
return content
|
||||
|
|
@ -14,6 +14,7 @@ from .metrics_collector import MetricsCollector, SessionMetrics, AggregatedMetri
|
|||
from .tool_registry import ToolRegistry, ToolDefinition, ToolExecutionResult, get_tool_registry, init_tool_registry
|
||||
from .tool_call_recorder import ToolCallRecorder, ToolCallStatistics, get_tool_call_recorder
|
||||
from .memory_adapter import MemoryAdapter, UserMemory
|
||||
from .memory_summary_generator import MemorySummaryGenerator
|
||||
from .default_kb_tool_runner import DefaultKbToolRunner, KbToolResult, KbToolConfig, get_default_kb_tool_runner
|
||||
from .segment_humanizer import SegmentHumanizer, HumanizeConfig, LengthBucket, get_segment_humanizer
|
||||
from .runtime_observer import RuntimeObserver, RuntimeContext, get_runtime_observer
|
||||
|
|
@ -55,6 +56,7 @@ __all__ = [
|
|||
"get_tool_call_recorder",
|
||||
"MemoryAdapter",
|
||||
"UserMemory",
|
||||
"MemorySummaryGenerator",
|
||||
"DefaultKbToolRunner",
|
||||
"KbToolResult",
|
||||
"KbToolConfig",
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ from app.models.mid.schemas import (
|
|||
from app.services.llm.base import ToolDefinition
|
||||
from app.services.mid.tool_guide_registry import ToolGuideRegistry, get_tool_guide_registry
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
from app.services.mid.tool_converter import convert_tools_to_llm_format, build_tool_result_message
|
||||
from app.services.mid.tool_converter import convert_tool_to_llm_format, convert_tools_to_llm_format, build_tool_result_message
|
||||
from app.services.prompt.template_service import PromptTemplateService
|
||||
from app.services.prompt.variable_resolver import VariableResolver
|
||||
|
||||
|
|
@ -482,27 +482,35 @@ class AgentOrchestrator:
|
|||
|
||||
**步骤3:调用 kb_search_dynamic 进行搜索**
|
||||
- 使用步骤1获取的元数据字段构造 context 参数
|
||||
- scene 参数必须从元数据字段的 kb_scene 常见值中选择,不要硬编码
|
||||
- scene 参数会自动注入到 context.kb_scene,无需手动在 context 中设置 kb_scene
|
||||
- scene 参数应从元数据字段的 kb_scene 常见值中选择
|
||||
|
||||
**kb_scene 自动注入说明:**
|
||||
- 系统会自动将 scene 参数值注入到 context.kb_scene 字段
|
||||
- AI 只需在 context 中设置其他过滤字段(如 grade、subject)
|
||||
- 不要在 context 中重复设置 kb_scene,系统会自动处理
|
||||
|
||||
**示例流程:**
|
||||
1. 调用 `list_document_metadata_fields` 获取字段信息
|
||||
2. 根据返回结果,发现可用字段:grade(年级)、subject(学科)、kb_scene(场景)
|
||||
3. 分析用户问题"三年级语文怎么学",确定过滤条件:grade="三年级", subject="语文"
|
||||
4. 从 kb_scene 的常见值中选择合适的 scene(如"学习方案")
|
||||
5. 调用 `kb_search_dynamic`,传入构造好的 context 和 scene
|
||||
5. 调用 `kb_search_dynamic`,传入 scene="学习方案",context={"grade": "三年级", "subject": "语文"}
|
||||
6. 系统自动将 scene 注入到 context.kb_scene
|
||||
|
||||
## 注意事项
|
||||
- **严禁**在调用 kb_search_dynamic 之前不调用 list_document_metadata_fields。
|
||||
- **不要**在 context 中手动设置 kb_scene,系统会自动从 scene 参数注入。
|
||||
"""
|
||||
|
||||
if not self._template_service or not self._tenant_id:
|
||||
return default_prompt
|
||||
|
||||
try:
|
||||
from app.core.database import get_session
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.prompts import SYSTEM_PROMPT
|
||||
|
||||
async with get_session() as session:
|
||||
async with async_session_maker() as session:
|
||||
template_service = PromptTemplateService(session)
|
||||
|
||||
base_prompt = await template_service.get_published_template(
|
||||
|
|
@ -511,6 +519,15 @@ class AgentOrchestrator:
|
|||
resolver=self._variable_resolver,
|
||||
)
|
||||
|
||||
if not base_prompt or base_prompt == SYSTEM_PROMPT:
|
||||
base_prompt = await template_service.get_published_template(
|
||||
tenant_id=self._tenant_id,
|
||||
scene="agent_react",
|
||||
resolver=self._variable_resolver,
|
||||
)
|
||||
if base_prompt and base_prompt != SYSTEM_PROMPT:
|
||||
logger.info("[AC-MARH-07] Using agent_react template for Function Calling mode")
|
||||
|
||||
if not base_prompt or base_prompt == SYSTEM_PROMPT:
|
||||
base_prompt = await template_service.get_published_template(
|
||||
tenant_id=self._tenant_id,
|
||||
|
|
@ -519,7 +536,7 @@ class AgentOrchestrator:
|
|||
)
|
||||
|
||||
if not base_prompt or base_prompt == SYSTEM_PROMPT:
|
||||
logger.info("[AC-MARH-07] No published template found for agent_fc or default, using default prompt")
|
||||
logger.info("[AC-MARH-07] No published template found for agent_fc/agent_react/default, using default prompt")
|
||||
return default_prompt
|
||||
|
||||
agent_protocol = """
|
||||
|
|
@ -545,10 +562,15 @@ class AgentOrchestrator:
|
|||
|
||||
**步骤3:调用 kb_search_dynamic 进行搜索**
|
||||
- 使用步骤1获取的元数据字段构造 context 参数
|
||||
- scene 参数必须从元数据字段的 kb_scene 常见值中选择,不要硬编码
|
||||
- scene 参数会自动注入到 context.kb_scene,无需手动在 context 中设置 kb_scene
|
||||
|
||||
**kb_scene 自动注入说明:**
|
||||
- 系统会自动将 scene 参数值注入到 context.kb_scene 字段
|
||||
- AI 只需在 context 中设置其他过滤字段(如 grade、subject)
|
||||
|
||||
## 注意事项
|
||||
- **严禁**在调用 kb_search_dynamic 之前不调用 list_document_metadata_fields。
|
||||
- **不要**在 context 中手动设置 kb_scene,系统会自动从 scene 参数注入。
|
||||
"""
|
||||
|
||||
final_prompt = base_prompt + agent_protocol
|
||||
|
|
|
|||
|
|
@ -127,6 +127,8 @@ class KbSearchDynamicTool:
|
|||
"知识库动态检索工具。"
|
||||
"根据租户配置的元数据字段定义,动态构建检索过滤器。"
|
||||
"支持必填字段检测和可观测降级。"
|
||||
"重要:context 参数中应包含 kb_scene 字段用于场景过滤,"
|
||||
"系统会自动从外部请求的 scene 参数注入到 context.kb_scene。"
|
||||
)
|
||||
|
||||
def get_tool_schema(self) -> dict[str, Any]:
|
||||
|
|
@ -146,7 +148,7 @@ class KbSearchDynamicTool:
|
|||
},
|
||||
"scene": {
|
||||
"type": "string",
|
||||
"description": "场景标识,如 'open_consult', 'intent_match'",
|
||||
"description": "场景标识(如 'open_consult', 'intent_match'),系统会自动将其注入到 context.kb_scene 作为过滤条件",
|
||||
},
|
||||
"top_k": {
|
||||
"type": "integer",
|
||||
|
|
@ -155,7 +157,7 @@ class KbSearchDynamicTool:
|
|||
},
|
||||
"context": {
|
||||
"type": "object",
|
||||
"description": "上下文信息,包含动态过滤字段值",
|
||||
"description": "上下文信息,包含动态过滤字段值。重要字段:kb_scene(场景过滤,由系统自动从 scene 参数注入)、grade(年级)、subject(学科)等",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
|
|
@ -299,13 +301,14 @@ class KbSearchDynamicTool:
|
|||
[AC-MARH-05] 执行 KB 动态检索。
|
||||
[AC-MRS-SLOT-META-02] 支持槽位状态聚合和过滤构建优先级
|
||||
[Step-KB-Binding] 支持步骤级别的知识库约束
|
||||
[KB-SCENE-INJECT] 自动将 scene 参数注入到 context.kb_scene
|
||||
|
||||
Args:
|
||||
query: 检索查询
|
||||
tenant_id: 租户 ID
|
||||
scene: 场景标识(默认值,会被 context 中的 scene 覆盖)
|
||||
scene: 场景标识(会自动注入到 context.kb_scene)
|
||||
top_k: 返回数量
|
||||
context: 上下文(包含动态过滤值,包括 scene)
|
||||
context: 上下文(包含动态过滤值)
|
||||
slot_state: 预聚合的槽位状态(可选,优先使用)
|
||||
step_kb_config: 步骤级别的知识库配置(可选)
|
||||
slot_policy: 槽位策略(flow_strict=流程严格模式,agent_relaxed=通用问答宽松模式)
|
||||
|
|
@ -326,6 +329,25 @@ class KbSearchDynamicTool:
|
|||
effective_context = dict(context) if context else {}
|
||||
effective_scene = effective_context.get("scene", scene)
|
||||
|
||||
logger.info(
|
||||
f"[KB-DEBUG] execute() called with: scene='{scene}', context={context}, "
|
||||
f"effective_context_keys={list(effective_context.keys())}"
|
||||
)
|
||||
|
||||
# [KB-SCENE-INJECT] 自动将 scene 参数注入到 context.kb_scene
|
||||
# 优先级:context.kb_scene > context.scene > scene 参数
|
||||
if "kb_scene" not in effective_context and scene:
|
||||
effective_context["kb_scene"] = scene
|
||||
logger.info(
|
||||
f"[KB-SCENE-INJECT] Injected scene='{scene}' into context.kb_scene, "
|
||||
f"effective_context now={effective_context}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[KB-SCENE-INJECT] Skipped injection: kb_scene in context={('kb_scene' in effective_context)}, "
|
||||
f"scene is empty={not scene}"
|
||||
)
|
||||
|
||||
# [Step-KB-Binding] 记录步骤知识库约束
|
||||
step_kb_binding_info: dict[str, Any] = {}
|
||||
if step_kb_config:
|
||||
|
|
@ -445,8 +467,8 @@ class KbSearchDynamicTool:
|
|||
status=ToolCallStatus.OK,
|
||||
args_digest=f"query={query[:50]}, scene={effective_scene}",
|
||||
result_digest=f"hits={len(hits)}",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
result={"hits_count": len(hits), "kb_hit": kb_hit},
|
||||
arguments={"query": query, "scene": effective_scene, "context": effective_context},
|
||||
result={"hits_count": len(hits), "kb_hit": kb_hit, "applied_filter": metadata_filter},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
|
@ -482,7 +504,7 @@ class KbSearchDynamicTool:
|
|||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="KB_TIMEOUT",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
arguments={"query": query, "scene": effective_scene, "context": effective_context},
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
|
|
@ -509,7 +531,7 @@ class KbSearchDynamicTool:
|
|||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="KB_ERROR",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
arguments={"query": query, "scene": effective_scene, "context": effective_context},
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
|
|
@ -905,7 +927,7 @@ def register_kb_search_dynamic_tool(
|
|||
|
||||
registry.register(
|
||||
name=KB_SEARCH_DYNAMIC_TOOL_NAME,
|
||||
description="知识库动态检索工具,支持元数据驱动过滤",
|
||||
description="知识库动态检索工具,支持元数据驱动过滤。系统会自动将 scene 参数注入到 context.kb_scene 进行场景过滤。",
|
||||
handler=handler,
|
||||
tool_type=RegistryToolType.INTERNAL,
|
||||
version="1.0.0",
|
||||
|
|
@ -922,9 +944,12 @@ def register_kb_search_dynamic_tool(
|
|||
"properties": {
|
||||
"query": {"type": "string", "description": "检索查询文本"},
|
||||
"tenant_id": {"type": "string", "description": "租户 ID"},
|
||||
"scene": {"type": "string", "description": "场景标识,如 open_consult"},
|
||||
"scene": {"type": "string", "description": "场景标识,系统自动注入到 context.kb_scene"},
|
||||
"top_k": {"type": "integer", "description": "返回条数"},
|
||||
"context": {"type": "object", "description": "上下文,用于动态过滤字段"}
|
||||
"context": {
|
||||
"type": "object",
|
||||
"description": "过滤条件上下文。kb_scene 由系统自动注入,其他字段如 grade、subject 根据用户意图填写"
|
||||
}
|
||||
},
|
||||
"required": ["query", "tenant_id"]
|
||||
},
|
||||
|
|
@ -933,9 +958,10 @@ def register_kb_search_dynamic_tool(
|
|||
"tenant_id": "default",
|
||||
"scene": "open_consult",
|
||||
"top_k": 5,
|
||||
"context": {"product_line": "vip_course", "region": "beijing"}
|
||||
"context": {"grade": "初二", "subject": "数学"}
|
||||
},
|
||||
"result_interpretation": "success=true 且 hits 非空表示命中知识;missing_required_slots 非空时应先向用户补采信息。"
|
||||
"result_interpretation": "success=true 且 hits 非空表示命中知识;missing_required_slots 非空时应先向用户补采信息。",
|
||||
"kb_scene_injection": "系统会自动将 scene 参数值注入到 context.kb_scene 字段,用于知识库场景过滤。AI 无需手动在 context 中设置 kb_scene。"
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ Reference:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
|
@ -17,6 +19,7 @@ from typing import Any, Callable
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.entities import UserMemory as UserMemoryEntity
|
||||
from app.models.mid.memory import (
|
||||
MemoryFact,
|
||||
MemoryProfile,
|
||||
|
|
@ -93,14 +96,6 @@ class MemoryAdapter:
|
|||
|
||||
在响应前执行,注入基础属性、事实记忆与偏好记忆。
|
||||
失败时返回空记忆,不阻断主链路。
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
tenant_id: 租户ID(可选)
|
||||
|
||||
Returns:
|
||||
RecallResponse: 包含 profile/facts/preferences 的响应
|
||||
"""
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
|
|
@ -126,9 +121,6 @@ class MemoryAdapter:
|
|||
session_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> RecallResponse:
|
||||
"""
|
||||
内部召回实现
|
||||
"""
|
||||
profile = await self._recall_profile(user_id, tenant_id)
|
||||
facts = await self._recall_facts(user_id, tenant_id)
|
||||
preferences = await self._recall_preferences(user_id, tenant_id)
|
||||
|
|
@ -152,7 +144,6 @@ class MemoryAdapter:
|
|||
user_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> MemoryProfile | None:
|
||||
"""召回用户基础属性"""
|
||||
return MemoryProfile(
|
||||
grade="初一",
|
||||
region="北京",
|
||||
|
|
@ -165,7 +156,6 @@ class MemoryAdapter:
|
|||
user_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> list[MemoryFact]:
|
||||
"""召回用户事实记忆"""
|
||||
return [
|
||||
MemoryFact(content="已购课程:数学思维训练营", source="order", confidence=1.0),
|
||||
MemoryFact(content="学习目标:提高数学成绩", source="profile", confidence=0.9),
|
||||
|
|
@ -177,7 +167,6 @@ class MemoryAdapter:
|
|||
user_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> MemoryPreferences | None:
|
||||
"""召回用户偏好"""
|
||||
return MemoryPreferences(
|
||||
tone="friendly",
|
||||
focus_subjects=["数学", "物理"],
|
||||
|
|
@ -189,8 +178,16 @@ class MemoryAdapter:
|
|||
user_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> str | None:
|
||||
"""召回最近会话摘要"""
|
||||
return "上次讨论了数学学习计划,用户对课程安排比较满意"
|
||||
if not tenant_id:
|
||||
return None
|
||||
|
||||
stmt = select(UserMemoryEntity).where(
|
||||
UserMemoryEntity.tenant_id == tenant_id,
|
||||
UserMemoryEntity.user_id == user_id,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
return record.summary if record else None
|
||||
|
||||
async def update(
|
||||
self,
|
||||
|
|
@ -200,22 +197,6 @@ class MemoryAdapter:
|
|||
summary: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
[AC-IDMP-14] 异步更新用户记忆
|
||||
|
||||
在对话完成后异步执行,不阻塞主响应。
|
||||
包含会话摘要的回写。
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
messages: 本轮对话消息
|
||||
summary: 会话摘要(可选)
|
||||
tenant_id: 租户ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功提交更新任务
|
||||
"""
|
||||
request = UpdateRequest(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
|
|
@ -242,9 +223,6 @@ class MemoryAdapter:
|
|||
request: UpdateRequest,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
内部更新实现
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._do_update(request, tenant_id),
|
||||
|
|
@ -270,10 +248,18 @@ class MemoryAdapter:
|
|||
request: UpdateRequest,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
执行实际的记忆更新
|
||||
"""
|
||||
if request.summary:
|
||||
summary_payload = self._parse_summary_payload(request.summary)
|
||||
if summary_payload:
|
||||
await self._save_summary(
|
||||
request.user_id,
|
||||
summary_payload.get("summary", ""),
|
||||
tenant_id,
|
||||
facts=summary_payload.get("facts"),
|
||||
preferences=summary_payload.get("preferences"),
|
||||
open_issues=summary_payload.get("open_issues"),
|
||||
)
|
||||
else:
|
||||
await self._save_summary(request.user_id, request.summary, tenant_id)
|
||||
|
||||
await self._extract_and_save_facts(
|
||||
|
|
@ -285,9 +271,41 @@ class MemoryAdapter:
|
|||
user_id: str,
|
||||
summary: str,
|
||||
tenant_id: str | None,
|
||||
facts: list[str] | None = None,
|
||||
preferences: dict[str, Any] | None = None,
|
||||
open_issues: list[str] | None = None,
|
||||
) -> None:
|
||||
"""保存会话摘要"""
|
||||
pass
|
||||
if not tenant_id:
|
||||
logger.warning("[AC-IDMP-14] Missing tenant_id when saving summary")
|
||||
return
|
||||
|
||||
stmt = select(UserMemoryEntity).where(
|
||||
UserMemoryEntity.tenant_id == tenant_id,
|
||||
UserMemoryEntity.user_id == user_id,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
|
||||
if record:
|
||||
record.summary = summary
|
||||
record.facts = facts or record.facts
|
||||
record.preferences = preferences or record.preferences
|
||||
record.open_issues = open_issues or record.open_issues
|
||||
record.summary_version = (record.summary_version or 0) + 1
|
||||
record.updated_at = datetime.utcnow()
|
||||
else:
|
||||
record = UserMemoryEntity(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
summary=summary,
|
||||
facts=facts,
|
||||
preferences=preferences,
|
||||
open_issues=open_issues,
|
||||
summary_version=1,
|
||||
)
|
||||
self._session.add(record)
|
||||
|
||||
await self._session.flush()
|
||||
|
||||
async def _extract_and_save_facts(
|
||||
self,
|
||||
|
|
@ -295,8 +313,25 @@ class MemoryAdapter:
|
|||
messages: list[dict[str, Any]],
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""从消息中提取并保存事实"""
|
||||
pass
|
||||
if not tenant_id:
|
||||
return
|
||||
|
||||
for msg in messages:
|
||||
payload = msg.get("memory_payload") or msg.get("summary_payload")
|
||||
if not payload:
|
||||
continue
|
||||
parsed = self._parse_summary_payload(payload)
|
||||
if not parsed:
|
||||
continue
|
||||
await self._save_summary(
|
||||
user_id=user_id,
|
||||
summary=parsed.get("summary", ""),
|
||||
tenant_id=tenant_id,
|
||||
facts=parsed.get("facts"),
|
||||
preferences=parsed.get("preferences"),
|
||||
open_issues=parsed.get("open_issues"),
|
||||
)
|
||||
break
|
||||
|
||||
async def update_with_summary_generation(
|
||||
self,
|
||||
|
|
@ -305,41 +340,92 @@ class MemoryAdapter:
|
|||
messages: list[dict[str, Any]],
|
||||
tenant_id: str | None = None,
|
||||
summary_generator: Callable | None = None,
|
||||
recent_turns: int = 8,
|
||||
) -> bool:
|
||||
"""
|
||||
[AC-IDMP-14] 带摘要生成的记忆更新
|
||||
request = UpdateRequest(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
messages=messages,
|
||||
summary=None,
|
||||
)
|
||||
|
||||
如果未提供摘要,会尝试生成摘要后回写
|
||||
"""
|
||||
task = asyncio.create_task(
|
||||
self._update_with_generation_internal(
|
||||
request,
|
||||
tenant_id,
|
||||
summary_generator,
|
||||
recent_turns,
|
||||
),
|
||||
name=f"memory_update_gen_{user_id}_{session_id}",
|
||||
)
|
||||
self._pending_updates.append(task)
|
||||
task.add_done_callback(lambda t: self._pending_updates.remove(t))
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-14] Memory update (with summary) scheduled for user={user_id}, "
|
||||
f"session={session_id}, messages_count={len(messages)}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def _update_with_generation_internal(
|
||||
self,
|
||||
request: UpdateRequest,
|
||||
tenant_id: str | None,
|
||||
summary_generator: Callable | None,
|
||||
recent_turns: int,
|
||||
) -> None:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._do_update_with_generation(
|
||||
request,
|
||||
tenant_id,
|
||||
summary_generator,
|
||||
recent_turns,
|
||||
),
|
||||
timeout=self._update_timeout_ms / 1000,
|
||||
)
|
||||
logger.info(
|
||||
f"[AC-IDMP-14] Memory updated (with summary) for user={request.user_id}, "
|
||||
f"session={request.session_id}"
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"[AC-IDMP-14] Memory update (with summary) timeout for user={request.user_id}, "
|
||||
f"session={request.session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[AC-IDMP-14] Memory update (with summary) failed for user={request.user_id}, "
|
||||
f"session={request.session_id}, error={e}"
|
||||
)
|
||||
|
||||
async def _do_update_with_generation(
|
||||
self,
|
||||
request: UpdateRequest,
|
||||
tenant_id: str | None,
|
||||
summary_generator: Callable | None,
|
||||
recent_turns: int,
|
||||
) -> None:
|
||||
summary = None
|
||||
if summary_generator:
|
||||
try:
|
||||
summary = await summary_generator(messages)
|
||||
old_summary = await self._load_latest_summary(request.user_id, tenant_id)
|
||||
recent_messages = self._trim_recent_messages(request.messages, recent_turns)
|
||||
summary = await self._call_summary_generator(
|
||||
summary_generator,
|
||||
recent_messages,
|
||||
old_summary,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-IDMP-14] Summary generation failed: {e}"
|
||||
)
|
||||
|
||||
return await self.update(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
messages=messages,
|
||||
summary=summary,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
request.summary = summary
|
||||
await self._do_update(request, tenant_id)
|
||||
|
||||
async def wait_pending_updates(self, timeout: float = 5.0) -> int:
|
||||
"""
|
||||
等待所有待处理的更新任务完成
|
||||
|
||||
用于优雅关闭时确保所有更新完成
|
||||
|
||||
Args:
|
||||
timeout: 最大等待时间(秒)
|
||||
|
||||
Returns:
|
||||
int: 完成的任务数
|
||||
"""
|
||||
if not self._pending_updates:
|
||||
return 0
|
||||
|
||||
|
|
@ -353,3 +439,62 @@ class MemoryAdapter:
|
|||
except Exception as e:
|
||||
logger.error(f"[AC-IDMP-14] Error waiting for pending updates: {e}")
|
||||
return 0
|
||||
|
||||
async def _load_latest_summary(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> str | None:
|
||||
if not tenant_id:
|
||||
return None
|
||||
|
||||
stmt = select(UserMemoryEntity).where(
|
||||
UserMemoryEntity.tenant_id == tenant_id,
|
||||
UserMemoryEntity.user_id == user_id,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
return record.summary if record else None
|
||||
|
||||
def _trim_recent_messages(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
recent_turns: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
if recent_turns <= 0:
|
||||
return []
|
||||
return messages[-recent_turns:]
|
||||
|
||||
async def _call_summary_generator(
|
||||
self,
|
||||
summary_generator: Callable,
|
||||
recent_messages: list[dict[str, Any]],
|
||||
old_summary: str | None,
|
||||
) -> str | None:
|
||||
try:
|
||||
if len(inspect.signature(summary_generator).parameters) >= 2:
|
||||
return await summary_generator(recent_messages, old_summary)
|
||||
except Exception:
|
||||
return await summary_generator(recent_messages)
|
||||
|
||||
return await summary_generator(recent_messages)
|
||||
|
||||
def _parse_summary_payload(
|
||||
self,
|
||||
payload: Any,
|
||||
) -> dict[str, Any] | None:
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
if isinstance(payload, dict):
|
||||
return payload
|
||||
|
||||
if isinstance(payload, str):
|
||||
try:
|
||||
parsed = json.loads(payload)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -334,24 +334,20 @@ class MemoryRecallTool:
|
|||
) -> str | None:
|
||||
"""召回最近会话摘要。"""
|
||||
try:
|
||||
from app.models.entities import MidAuditLog
|
||||
from sqlmodel import col
|
||||
from app.models.entities import UserMemory
|
||||
|
||||
stmt = (
|
||||
select(MidAuditLog)
|
||||
select(UserMemory)
|
||||
.where(
|
||||
MidAuditLog.tenant_id == tenant_id,
|
||||
UserMemory.tenant_id == tenant_id,
|
||||
UserMemory.user_id == user_id,
|
||||
)
|
||||
.order_by(col(MidAuditLog.created_at).desc())
|
||||
.limit(1)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
audit = result.scalar_one_or_none()
|
||||
memory = result.scalar_one_or_none()
|
||||
|
||||
if audit:
|
||||
return f"上次会话模式: {audit.mode}"
|
||||
|
||||
return None
|
||||
return memory.summary if memory else None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-IDMP-13] Failed to recall last summary: {e}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,58 @@
|
|||
"""
|
||||
Memory summary generator using LLM.
|
||||
[AC-IDMP-14] Generate rolling summary for memory update.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.services.llm.base import LLMConfig
|
||||
from app.services.llm.factory import get_llm_config_manager
|
||||
from app.services.mid.memory_summary_prompt import build_memory_summary_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemorySummaryGenerator:
|
||||
"""
|
||||
LLM-based memory summary generator.
|
||||
|
||||
Output expected to be a JSON object or structured text.
|
||||
"""
|
||||
|
||||
def __init__(self, max_tokens: int = 512, temperature: float = 0.2):
|
||||
self._max_tokens = max_tokens
|
||||
self._temperature = temperature
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
old_summary: str | None = None,
|
||||
) -> str | None:
|
||||
try:
|
||||
llm_manager = get_llm_config_manager()
|
||||
llm_client = llm_manager.get_client()
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-IDMP-14] Failed to get LLM client: {e}")
|
||||
return None
|
||||
|
||||
prompt = build_memory_summary_prompt(messages, old_summary)
|
||||
|
||||
try:
|
||||
response = await llm_client.generate(
|
||||
messages=[
|
||||
{"role": "system", "content": "你是一个严格遵循 JSON 格式的摘要器。仅输出 JSON。"},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
config=LLMConfig(
|
||||
max_tokens=self._max_tokens,
|
||||
temperature=self._temperature,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-IDMP-14] Summary generation failed: {e}")
|
||||
return None
|
||||
|
||||
return response.content
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
"""
|
||||
Memory summary prompt builder.
|
||||
[AC-IDMP-14] Rolling summary prompt for memory update.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
SUMMARY_PROMPT_TEMPLATE = """
|
||||
你是一个记忆摘要生成器。你的目标是把对话中稳定、有长期价值的信息归纳为“可用于记忆召回”的摘要。
|
||||
要求:
|
||||
1) 必须保留:稳定事实、用户偏好、未解决问题;
|
||||
2) 不要写闲聊/情绪;
|
||||
3) 输出必须为严格 JSON 对象,只允许以下字段:summary, facts, preferences, open_issues;
|
||||
4) summary 为一段话(150-300字以内);facts/preferences/open_issues 为列表;
|
||||
5) 如果新内容没有变化,保留旧摘要并可轻微精简;
|
||||
6) 所有内容必须基于对话,不允许编造。
|
||||
|
||||
【旧摘要】
|
||||
{old_summary}
|
||||
|
||||
【最近对话】
|
||||
{recent_messages}
|
||||
""".strip()
|
||||
|
||||
|
||||
def build_recent_messages_text(messages: list[dict[str, Any]]) -> str:
|
||||
lines: list[str] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
if not content:
|
||||
continue
|
||||
lines.append(f"{role}: {content}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def build_memory_summary_prompt(
|
||||
messages: list[dict[str, Any]],
|
||||
old_summary: str | None = None,
|
||||
) -> str:
|
||||
return SUMMARY_PROMPT_TEMPLATE.format(
|
||||
old_summary=old_summary or "无",
|
||||
recent_messages=build_recent_messages_text(messages),
|
||||
)
|
||||
|
|
@ -69,7 +69,7 @@ class RerankerConfig:
|
|||
@dataclass
|
||||
class ModeRouterConfig:
|
||||
"""模式路由配置。【AC-AISVC-RES-09~15】"""
|
||||
runtime_mode: RuntimeMode = RuntimeMode.DIRECT
|
||||
runtime_mode: RuntimeMode = RuntimeMode.AUTO
|
||||
react_trigger_confidence_threshold: float = 0.6
|
||||
react_trigger_complexity_score: float = 0.5
|
||||
react_max_steps: int = 5
|
||||
|
|
|
|||
|
|
@ -0,0 +1,300 @@
|
|||
"""
|
||||
Strategy Audit Service for AI Service.
|
||||
[AC-AISVC-RES-07] Audit logging for strategy operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from app.schemas.retrieval_strategy import StrategyAuditLog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuditEntry:
|
||||
"""
|
||||
Internal audit entry structure.
|
||||
"""
|
||||
|
||||
timestamp: str
|
||||
operation: str
|
||||
previous_strategy: str | None = None
|
||||
new_strategy: str | None = None
|
||||
previous_react_mode: str | None = None
|
||||
new_react_mode: str | None = None
|
||||
reason: str | None = None
|
||||
operator: str | None = None
|
||||
tenant_id: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class StrategyAuditService:
|
||||
"""
|
||||
[AC-AISVC-RES-07] Audit service for strategy operations.
|
||||
|
||||
Features:
|
||||
- Structured audit logging
|
||||
- In-memory audit trail (configurable retention)
|
||||
- JSON output for log aggregation
|
||||
"""
|
||||
|
||||
def __init__(self, max_entries: int = 1000):
|
||||
self._audit_log: deque[AuditEntry] = deque(maxlen=max_entries)
|
||||
self._max_entries = max_entries
|
||||
|
||||
def log(
|
||||
self,
|
||||
operation: str,
|
||||
previous_strategy: str | None = None,
|
||||
new_strategy: str | None = None,
|
||||
previous_react_mode: str | None = None,
|
||||
new_react_mode: str | None = None,
|
||||
reason: str | None = None,
|
||||
operator: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
[AC-AISVC-RES-07] Log a strategy operation.
|
||||
|
||||
Args:
|
||||
operation: Operation type (switch, rollback, validate).
|
||||
previous_strategy: Previous strategy value.
|
||||
new_strategy: New strategy value.
|
||||
previous_react_mode: Previous react mode.
|
||||
new_react_mode: New react mode.
|
||||
reason: Reason for the operation.
|
||||
operator: Operator who performed the operation.
|
||||
tenant_id: Tenant ID if applicable.
|
||||
metadata: Additional metadata.
|
||||
"""
|
||||
entry = AuditEntry(
|
||||
timestamp=datetime.utcnow().isoformat(),
|
||||
operation=operation,
|
||||
previous_strategy=previous_strategy,
|
||||
new_strategy=new_strategy,
|
||||
previous_react_mode=previous_react_mode,
|
||||
new_react_mode=new_react_mode,
|
||||
reason=reason,
|
||||
operator=operator,
|
||||
tenant_id=tenant_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
self._audit_log.append(entry)
|
||||
|
||||
log_data = {
|
||||
"audit_type": "strategy_operation",
|
||||
"timestamp": entry.timestamp,
|
||||
"operation": entry.operation,
|
||||
"previous_strategy": entry.previous_strategy,
|
||||
"new_strategy": entry.new_strategy,
|
||||
"previous_react_mode": entry.previous_react_mode,
|
||||
"new_react_mode": entry.new_react_mode,
|
||||
"reason": entry.reason,
|
||||
"operator": entry.operator,
|
||||
"tenant_id": entry.tenant_id,
|
||||
"metadata": entry.metadata,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-RES-07] Strategy audit: operation={operation}, "
|
||||
f"from={previous_strategy} -> to={new_strategy}, "
|
||||
f"operator={operator}, reason={reason}"
|
||||
)
|
||||
|
||||
audit_logger = logging.getLogger("audit.strategy")
|
||||
audit_logger.info(json.dumps(log_data, ensure_ascii=False))
|
||||
|
||||
def log_switch(
|
||||
self,
|
||||
previous_strategy: str,
|
||||
new_strategy: str,
|
||||
previous_react_mode: str | None = None,
|
||||
new_react_mode: str | None = None,
|
||||
reason: str | None = None,
|
||||
operator: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
rollout_config: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Log a strategy switch operation.
|
||||
|
||||
Args:
|
||||
previous_strategy: Previous strategy value.
|
||||
new_strategy: New strategy value.
|
||||
previous_react_mode: Previous react mode.
|
||||
new_react_mode: New react mode.
|
||||
reason: Reason for the switch.
|
||||
operator: Operator who performed the switch.
|
||||
tenant_id: Tenant ID if applicable.
|
||||
rollout_config: Rollout configuration.
|
||||
"""
|
||||
self.log(
|
||||
operation="switch",
|
||||
previous_strategy=previous_strategy,
|
||||
new_strategy=new_strategy,
|
||||
previous_react_mode=previous_react_mode,
|
||||
new_react_mode=new_react_mode,
|
||||
reason=reason,
|
||||
operator=operator,
|
||||
tenant_id=tenant_id,
|
||||
metadata={"rollout_config": rollout_config} if rollout_config else None,
|
||||
)
|
||||
|
||||
def log_rollback(
|
||||
self,
|
||||
previous_strategy: str,
|
||||
new_strategy: str,
|
||||
previous_react_mode: str | None = None,
|
||||
new_react_mode: str | None = None,
|
||||
reason: str | None = None,
|
||||
operator: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Log a strategy rollback operation.
|
||||
|
||||
Args:
|
||||
previous_strategy: Previous strategy value.
|
||||
new_strategy: Strategy rolled back to.
|
||||
previous_react_mode: Previous react mode.
|
||||
new_react_mode: React mode rolled back to.
|
||||
reason: Reason for the rollback.
|
||||
operator: Operator who performed the rollback.
|
||||
tenant_id: Tenant ID if applicable.
|
||||
"""
|
||||
self.log(
|
||||
operation="rollback",
|
||||
previous_strategy=previous_strategy,
|
||||
new_strategy=new_strategy,
|
||||
previous_react_mode=previous_react_mode,
|
||||
new_react_mode=new_react_mode,
|
||||
reason=reason or "Manual rollback",
|
||||
operator=operator,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
def log_validation(
|
||||
self,
|
||||
strategy: str,
|
||||
react_mode: str | None = None,
|
||||
checks: list[str] | None = None,
|
||||
passed: bool = False,
|
||||
operator: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Log a strategy validation operation.
|
||||
|
||||
Args:
|
||||
strategy: Strategy being validated.
|
||||
react_mode: React mode being validated.
|
||||
checks: List of checks performed.
|
||||
passed: Whether validation passed.
|
||||
operator: Operator who performed the validation.
|
||||
tenant_id: Tenant ID if applicable.
|
||||
"""
|
||||
self.log(
|
||||
operation="validate",
|
||||
new_strategy=strategy,
|
||||
new_react_mode=react_mode,
|
||||
operator=operator,
|
||||
tenant_id=tenant_id,
|
||||
metadata={
|
||||
"checks": checks,
|
||||
"passed": passed,
|
||||
},
|
||||
)
|
||||
|
||||
def get_audit_log(
|
||||
self,
|
||||
limit: int = 100,
|
||||
operation: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> list[StrategyAuditLog]:
|
||||
"""
|
||||
Get audit log entries.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of entries to return.
|
||||
operation: Filter by operation type.
|
||||
tenant_id: Filter by tenant ID.
|
||||
|
||||
Returns:
|
||||
List of StrategyAuditLog entries.
|
||||
"""
|
||||
entries = list(self._audit_log)
|
||||
|
||||
if operation:
|
||||
entries = [e for e in entries if e.operation == operation]
|
||||
|
||||
if tenant_id:
|
||||
entries = [e for e in entries if e.tenant_id == tenant_id]
|
||||
|
||||
entries = entries[-limit:]
|
||||
|
||||
return [
|
||||
StrategyAuditLog(
|
||||
timestamp=e.timestamp,
|
||||
operation=e.operation,
|
||||
previous_strategy=e.previous_strategy,
|
||||
new_strategy=e.new_strategy,
|
||||
previous_react_mode=e.previous_react_mode,
|
||||
new_react_mode=e.new_react_mode,
|
||||
reason=e.reason,
|
||||
operator=e.operator,
|
||||
tenant_id=e.tenant_id,
|
||||
metadata=e.metadata,
|
||||
)
|
||||
for e in entries
|
||||
]
|
||||
|
||||
def get_audit_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get audit log statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with audit statistics.
|
||||
"""
|
||||
entries = list(self._audit_log)
|
||||
|
||||
operation_counts: dict[str, int] = {}
|
||||
for entry in entries:
|
||||
operation_counts[entry.operation] = operation_counts.get(entry.operation, 0) + 1
|
||||
|
||||
return {
|
||||
"total_entries": len(entries),
|
||||
"max_entries": self._max_entries,
|
||||
"operation_counts": operation_counts,
|
||||
"oldest_entry": entries[0].timestamp if entries else None,
|
||||
"newest_entry": entries[-1].timestamp if entries else None,
|
||||
}
|
||||
|
||||
def clear_audit_log(self) -> int:
|
||||
"""
|
||||
Clear all audit log entries.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared.
|
||||
"""
|
||||
count = len(self._audit_log)
|
||||
self._audit_log.clear()
|
||||
logger.info(f"[AC-AISVC-RES-07] Audit log cleared: {count} entries removed")
|
||||
return count
|
||||
|
||||
|
||||
_audit_service: StrategyAuditService | None = None
|
||||
|
||||
|
||||
def get_audit_service() -> StrategyAuditService:
|
||||
"""Get or create StrategyAuditService instance."""
|
||||
global _audit_service
|
||||
if _audit_service is None:
|
||||
_audit_service = StrategyAuditService()
|
||||
return _audit_service
|
||||
|
|
@ -0,0 +1,452 @@
|
|||
"""
|
||||
Strategy Metrics Service for AI Service.
|
||||
[AC-AISVC-RES-03, AC-AISVC-RES-08] Metrics collection for strategy operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from app.schemas.retrieval_strategy import (
|
||||
ReactMode,
|
||||
StrategyMetrics,
|
||||
StrategyType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LatencyTracker:
|
||||
"""
|
||||
Latency tracking for a single operation.
|
||||
"""
|
||||
|
||||
latencies: list[float] = field(default_factory=list)
|
||||
max_samples: int = 1000
|
||||
|
||||
def record(self, latency_ms: float) -> None:
|
||||
"""Record a latency sample."""
|
||||
if len(self.latencies) >= self.max_samples:
|
||||
self.latencies = self.latencies[-self.max_samples // 2 :]
|
||||
self.latencies.append(latency_ms)
|
||||
|
||||
def get_percentile(self, percentile: float) -> float:
|
||||
"""Get latency at given percentile."""
|
||||
if not self.latencies:
|
||||
return 0.0
|
||||
sorted_latencies = sorted(self.latencies)
|
||||
index = int(len(sorted_latencies) * percentile / 100)
|
||||
index = min(index, len(sorted_latencies) - 1)
|
||||
return sorted_latencies[index]
|
||||
|
||||
def get_avg(self) -> float:
|
||||
"""Get average latency."""
|
||||
if not self.latencies:
|
||||
return 0.0
|
||||
return sum(self.latencies) / len(self.latencies)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyMetricsData:
|
||||
"""
|
||||
Internal metrics data structure.
|
||||
"""
|
||||
|
||||
total_requests: int = 0
|
||||
successful_requests: int = 0
|
||||
failed_requests: int = 0
|
||||
latency_tracker: LatencyTracker = field(default_factory=LatencyTracker)
|
||||
direct_route_count: int = 0
|
||||
react_route_count: int = 0
|
||||
auto_route_count: int = 0
|
||||
fallback_count: int = 0
|
||||
last_updated: str | None = None
|
||||
|
||||
|
||||
class StrategyMetricsService:
|
||||
"""
|
||||
[AC-AISVC-RES-03, AC-AISVC-RES-08] Metrics service for strategy operations.
|
||||
|
||||
Features:
|
||||
- Request counting by strategy and route mode
|
||||
- Latency tracking with percentiles
|
||||
- Fallback and error tracking
|
||||
- Metrics export for monitoring
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._metrics: dict[str, StrategyMetricsData] = defaultdict(StrategyMetricsData)
|
||||
self._route_metrics: dict[str, StrategyMetricsData] = defaultdict(StrategyMetricsData)
|
||||
self._current_strategy: StrategyType = StrategyType.DEFAULT
|
||||
self._current_react_mode: ReactMode = ReactMode.NON_REACT
|
||||
|
||||
def set_current_strategy(
|
||||
self,
|
||||
strategy: StrategyType,
|
||||
react_mode: ReactMode,
|
||||
) -> None:
|
||||
"""
|
||||
Set current strategy for metrics attribution.
|
||||
|
||||
Args:
|
||||
strategy: Current active strategy.
|
||||
react_mode: Current react mode.
|
||||
"""
|
||||
self._current_strategy = strategy
|
||||
self._current_react_mode = react_mode
|
||||
|
||||
def record_request(
|
||||
self,
|
||||
latency_ms: float,
|
||||
success: bool = True,
|
||||
route_mode: str | None = None,
|
||||
fallback: bool = False,
|
||||
strategy: StrategyType | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
[AC-AISVC-RES-03, AC-AISVC-RES-08] Record a retrieval request.
|
||||
|
||||
Args:
|
||||
latency_ms: Request latency in milliseconds.
|
||||
success: Whether the request was successful.
|
||||
route_mode: Route mode used (direct, react, auto).
|
||||
fallback: Whether fallback to default occurred.
|
||||
strategy: Strategy used (defaults to current).
|
||||
"""
|
||||
effective_strategy = strategy or self._current_strategy
|
||||
key = effective_strategy.value
|
||||
|
||||
metrics = self._metrics[key]
|
||||
metrics.total_requests += 1
|
||||
|
||||
if success:
|
||||
metrics.successful_requests += 1
|
||||
else:
|
||||
metrics.failed_requests += 1
|
||||
|
||||
metrics.latency_tracker.record(latency_ms)
|
||||
metrics.last_updated = datetime.utcnow().isoformat()
|
||||
|
||||
if fallback:
|
||||
metrics.fallback_count += 1
|
||||
|
||||
if route_mode:
|
||||
self._record_route_metric(route_mode, latency_ms, success)
|
||||
|
||||
logger.debug(
|
||||
f"[AC-AISVC-RES-08] Request recorded: strategy={key}, "
|
||||
f"latency={latency_ms:.2f}ms, success={success}, route={route_mode}"
|
||||
)
|
||||
|
||||
def _record_route_metric(
|
||||
self,
|
||||
route_mode: str,
|
||||
latency_ms: float,
|
||||
success: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Record metrics for route mode.
|
||||
|
||||
Args:
|
||||
route_mode: Route mode (direct, react, auto).
|
||||
latency_ms: Request latency.
|
||||
success: Whether successful.
|
||||
"""
|
||||
metrics = self._route_metrics[route_mode]
|
||||
metrics.total_requests += 1
|
||||
|
||||
if success:
|
||||
metrics.successful_requests += 1
|
||||
else:
|
||||
metrics.failed_requests += 1
|
||||
|
||||
metrics.latency_tracker.record(latency_ms)
|
||||
metrics.last_updated = datetime.utcnow().isoformat()
|
||||
|
||||
if route_mode == "direct":
|
||||
self._metrics[self._current_strategy.value].direct_route_count += 1
|
||||
elif route_mode == "react":
|
||||
self._metrics[self._current_strategy.value].react_route_count += 1
|
||||
elif route_mode == "auto":
|
||||
self._metrics[self._current_strategy.value].auto_route_count += 1
|
||||
|
||||
def record_strategy_switch(
|
||||
self,
|
||||
from_strategy: str,
|
||||
to_strategy: str,
|
||||
) -> None:
|
||||
"""
|
||||
Record a strategy switch event.
|
||||
|
||||
Args:
|
||||
from_strategy: Previous strategy.
|
||||
to_strategy: New strategy.
|
||||
"""
|
||||
metrics_logger = logging.getLogger("metrics.strategy")
|
||||
metrics_logger.info(
|
||||
json.dumps(
|
||||
{
|
||||
"event": "strategy_switch",
|
||||
"from_strategy": from_strategy,
|
||||
"to_strategy": to_strategy,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-RES-03] Strategy switch recorded: {from_strategy} -> {to_strategy}"
|
||||
)
|
||||
|
||||
def record_grayscale_request(
|
||||
self,
|
||||
tenant_id: str,
|
||||
strategy_used: str,
|
||||
in_grayscale: bool,
|
||||
) -> None:
|
||||
"""
|
||||
[AC-AISVC-RES-03] Record a grayscale request.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID.
|
||||
strategy_used: Strategy used for the request.
|
||||
in_grayscale: Whether the request was in grayscale group.
|
||||
"""
|
||||
metrics_logger = logging.getLogger("metrics.grayscale")
|
||||
metrics_logger.info(
|
||||
json.dumps(
|
||||
{
|
||||
"event": "grayscale_request",
|
||||
"tenant_id": tenant_id,
|
||||
"strategy_used": strategy_used,
|
||||
"in_grayscale": in_grayscale,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
|
||||
def get_metrics(self, strategy: StrategyType | None = None) -> StrategyMetrics:
|
||||
"""
|
||||
Get metrics for a specific strategy or current strategy.
|
||||
|
||||
Args:
|
||||
strategy: Strategy to get metrics for (defaults to current).
|
||||
|
||||
Returns:
|
||||
StrategyMetrics for the strategy.
|
||||
"""
|
||||
effective_strategy = strategy or self._current_strategy
|
||||
key = effective_strategy.value
|
||||
data = self._metrics[key]
|
||||
|
||||
return StrategyMetrics(
|
||||
strategy=effective_strategy,
|
||||
react_mode=self._current_react_mode,
|
||||
total_requests=data.total_requests,
|
||||
successful_requests=data.successful_requests,
|
||||
failed_requests=data.failed_requests,
|
||||
avg_latency_ms=round(data.latency_tracker.get_avg(), 2),
|
||||
p99_latency_ms=round(data.latency_tracker.get_percentile(99), 2),
|
||||
direct_route_count=data.direct_route_count,
|
||||
react_route_count=data.react_route_count,
|
||||
auto_route_count=data.auto_route_count,
|
||||
fallback_count=data.fallback_count,
|
||||
last_updated=data.last_updated,
|
||||
)
|
||||
|
||||
def get_all_metrics(self) -> dict[str, StrategyMetrics]:
|
||||
"""
|
||||
Get metrics for all strategies.
|
||||
|
||||
Returns:
|
||||
Dictionary of strategy name to metrics.
|
||||
"""
|
||||
return {
|
||||
strategy.value: self.get_metrics(StrategyType(strategy))
|
||||
for strategy in StrategyType
|
||||
}
|
||||
|
||||
def get_route_metrics(self) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
Get metrics by route mode.
|
||||
|
||||
Returns:
|
||||
Dictionary of route mode to metrics.
|
||||
"""
|
||||
result = {}
|
||||
for route_mode, data in self._route_metrics.items():
|
||||
result[route_mode] = {
|
||||
"total_requests": data.total_requests,
|
||||
"successful_requests": data.successful_requests,
|
||||
"failed_requests": data.failed_requests,
|
||||
"avg_latency_ms": round(data.latency_tracker.get_avg(), 2),
|
||||
"p99_latency_ms": round(data.latency_tracker.get_percentile(99), 2),
|
||||
"last_updated": data.last_updated,
|
||||
}
|
||||
return result
|
||||
|
||||
def get_performance_summary(self) -> dict[str, Any]:
|
||||
"""
|
||||
[AC-AISVC-RES-08] Get performance summary for monitoring.
|
||||
|
||||
Returns:
|
||||
Performance summary dictionary.
|
||||
"""
|
||||
all_metrics = self.get_all_metrics()
|
||||
|
||||
total_requests = sum(m.total_requests for m in all_metrics.values())
|
||||
total_success = sum(m.successful_requests for m in all_metrics.values())
|
||||
total_failed = sum(m.failed_requests for m in all_metrics.values())
|
||||
|
||||
avg_latencies = [
|
||||
m.avg_latency_ms for m in all_metrics.values() if m.avg_latency_ms > 0
|
||||
]
|
||||
overall_avg_latency = (
|
||||
sum(avg_latencies) / len(avg_latencies) if avg_latencies else 0.0
|
||||
)
|
||||
|
||||
p99_latencies = [
|
||||
m.p99_latency_ms for m in all_metrics.values() if m.p99_latency_ms > 0
|
||||
]
|
||||
overall_p99_latency = max(p99_latencies) if p99_latencies else 0.0
|
||||
|
||||
return {
|
||||
"total_requests": total_requests,
|
||||
"successful_requests": total_success,
|
||||
"failed_requests": total_failed,
|
||||
"success_rate": round(total_success / total_requests, 4) if total_requests > 0 else 0.0,
|
||||
"avg_latency_ms": round(overall_avg_latency, 2),
|
||||
"p99_latency_ms": round(overall_p99_latency, 2),
|
||||
"current_strategy": self._current_strategy.value,
|
||||
"current_react_mode": self._current_react_mode.value,
|
||||
"strategies": {
|
||||
name: {
|
||||
"total_requests": m.total_requests,
|
||||
"success_rate": round(
|
||||
m.successful_requests / m.total_requests, 4
|
||||
)
|
||||
if m.total_requests > 0
|
||||
else 0.0,
|
||||
"avg_latency_ms": m.avg_latency_ms,
|
||||
"p99_latency_ms": m.p99_latency_ms,
|
||||
}
|
||||
for name, m in all_metrics.items()
|
||||
},
|
||||
"routes": self.get_route_metrics(),
|
||||
}
|
||||
|
||||
def reset_metrics(self, strategy: StrategyType | None = None) -> None:
|
||||
"""
|
||||
Reset metrics for a strategy or all strategies.
|
||||
|
||||
Args:
|
||||
strategy: Strategy to reset (None for all).
|
||||
"""
|
||||
if strategy:
|
||||
self._metrics[strategy.value] = StrategyMetricsData()
|
||||
logger.info(f"[AC-AISVC-RES-08] Metrics reset for strategy: {strategy.value}")
|
||||
else:
|
||||
self._metrics.clear()
|
||||
self._route_metrics.clear()
|
||||
logger.info("[AC-AISVC-RES-08] All metrics reset")
|
||||
|
||||
def check_performance_threshold(
|
||||
self,
|
||||
strategy: StrategyType,
|
||||
max_latency_ms: float = 5000.0,
|
||||
max_error_rate: float = 0.1,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
[AC-AISVC-RES-08] Check if performance is within acceptable thresholds.
|
||||
|
||||
Args:
|
||||
strategy: Strategy to check.
|
||||
max_latency_ms: Maximum acceptable average latency.
|
||||
max_error_rate: Maximum acceptable error rate (0-1).
|
||||
|
||||
Returns:
|
||||
Dictionary with check results.
|
||||
"""
|
||||
metrics = self.get_metrics(strategy)
|
||||
|
||||
latency_ok = metrics.avg_latency_ms <= max_latency_ms
|
||||
error_rate = (
|
||||
metrics.failed_requests / metrics.total_requests
|
||||
if metrics.total_requests > 0
|
||||
else 0.0
|
||||
)
|
||||
error_rate_ok = error_rate <= max_error_rate
|
||||
|
||||
return {
|
||||
"strategy": strategy.value,
|
||||
"latency_ok": latency_ok,
|
||||
"avg_latency_ms": metrics.avg_latency_ms,
|
||||
"max_latency_ms": max_latency_ms,
|
||||
"error_rate_ok": error_rate_ok,
|
||||
"error_rate": round(error_rate, 4),
|
||||
"max_error_rate": max_error_rate,
|
||||
"overall_ok": latency_ok and error_rate_ok,
|
||||
"recommendation": (
|
||||
"Performance within acceptable thresholds"
|
||||
if latency_ok and error_rate_ok
|
||||
else "Consider rollback or investigation"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class MetricsContext:
|
||||
"""
|
||||
Context manager for timing operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metrics_service: StrategyMetricsService,
|
||||
route_mode: str | None = None,
|
||||
strategy: StrategyType | None = None,
|
||||
):
|
||||
self._metrics_service = metrics_service
|
||||
self._route_mode = route_mode
|
||||
self._strategy = strategy
|
||||
self._start_time: float | None = None
|
||||
self._success = True
|
||||
|
||||
def __enter__(self) -> "MetricsContext":
|
||||
self._start_time = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
if self._start_time is None:
|
||||
return
|
||||
|
||||
latency_ms = (time.time() - self._start_time) * 1000
|
||||
success = exc_type is None
|
||||
|
||||
self._metrics_service.record_request(
|
||||
latency_ms=latency_ms,
|
||||
success=success,
|
||||
route_mode=self._route_mode,
|
||||
strategy=self._strategy,
|
||||
)
|
||||
|
||||
def mark_failed(self) -> None:
|
||||
"""Mark the operation as failed."""
|
||||
self._success = False
|
||||
|
||||
|
||||
_metrics_service: StrategyMetricsService | None = None
|
||||
|
||||
|
||||
def get_metrics_service() -> StrategyMetricsService:
|
||||
"""Get or create StrategyMetricsService instance."""
|
||||
global _metrics_service
|
||||
if _metrics_service is None:
|
||||
_metrics_service = StrategyMetricsService()
|
||||
return _metrics_service
|
||||
|
|
@ -0,0 +1,484 @@
|
|||
"""
|
||||
Retrieval Strategy Service for AI Service.
|
||||
[AC-AISVC-RES-01~15] Strategy management with grayscale and rollback support.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from app.schemas.retrieval_strategy import (
|
||||
ReactMode,
|
||||
RolloutConfig,
|
||||
RolloutMode,
|
||||
StrategyType,
|
||||
RetrievalStrategyStatus,
|
||||
RetrievalStrategySwitchRequest,
|
||||
RetrievalStrategySwitchResponse,
|
||||
RetrievalStrategyValidationRequest,
|
||||
RetrievalStrategyValidationResponse,
|
||||
RetrievalStrategyRollbackResponse,
|
||||
ValidationResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyState:
|
||||
"""
|
||||
[AC-AISVC-RES-01] Internal state for retrieval strategy.
|
||||
"""
|
||||
|
||||
active_strategy: StrategyType = StrategyType.DEFAULT
|
||||
react_mode: ReactMode = ReactMode.NON_REACT
|
||||
rollout_mode: RolloutMode = RolloutMode.OFF
|
||||
rollout_percentage: float = 0.0
|
||||
rollout_allowlist: list[str] = field(default_factory=list)
|
||||
previous_strategy: StrategyType | None = None
|
||||
previous_react_mode: ReactMode | None = None
|
||||
switch_history: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class RetrievalStrategyService:
|
||||
"""
|
||||
[AC-AISVC-RES-01~15] Service for managing retrieval strategies.
|
||||
|
||||
Features:
|
||||
- Strategy switching with grayscale support
|
||||
- Rollback to previous/default strategy
|
||||
- Validation of strategy configuration
|
||||
- Audit logging integration
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._state = StrategyState()
|
||||
self._audit_callback: Any = None
|
||||
self._metrics_callback: Any = None
|
||||
|
||||
def set_audit_callback(self, callback: Any) -> None:
|
||||
"""Set callback for audit logging."""
|
||||
self._audit_callback = callback
|
||||
|
||||
def set_metrics_callback(self, callback: Any) -> None:
|
||||
"""Set callback for metrics recording."""
|
||||
self._metrics_callback = callback
|
||||
|
||||
def get_current_status(self) -> RetrievalStrategyStatus:
|
||||
"""
|
||||
[AC-AISVC-RES-01] Get current retrieval strategy status.
|
||||
|
||||
Returns:
|
||||
RetrievalStrategyStatus with current configuration.
|
||||
"""
|
||||
rollout = RolloutConfig(
|
||||
mode=self._state.rollout_mode,
|
||||
percentage=self._state.rollout_percentage if self._state.rollout_mode == RolloutMode.PERCENTAGE else None,
|
||||
allowlist=self._state.rollout_allowlist if self._state.rollout_mode == RolloutMode.ALLOWLIST else None,
|
||||
)
|
||||
|
||||
status = RetrievalStrategyStatus(
|
||||
active_strategy=self._state.active_strategy,
|
||||
react_mode=self._state.react_mode,
|
||||
rollout=rollout,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-RES-01] Current strategy: {self._state.active_strategy.value}, "
|
||||
f"react_mode={self._state.react_mode.value}, rollout={self._state.rollout_mode.value}"
|
||||
)
|
||||
|
||||
return status
|
||||
|
||||
def switch_strategy(
|
||||
self,
|
||||
request: RetrievalStrategySwitchRequest,
|
||||
operator: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> RetrievalStrategySwitchResponse:
|
||||
"""
|
||||
[AC-AISVC-RES-02, AC-AISVC-RES-03, AC-AISVC-RES-05] Switch retrieval strategy.
|
||||
|
||||
Args:
|
||||
request: Switch request with target strategy and options.
|
||||
operator: Operator who initiated the switch.
|
||||
tenant_id: Tenant ID for audit.
|
||||
|
||||
Returns:
|
||||
RetrievalStrategySwitchResponse with previous and current status.
|
||||
"""
|
||||
previous_status = self.get_current_status()
|
||||
|
||||
self._state.previous_strategy = self._state.active_strategy
|
||||
self._state.previous_react_mode = self._state.react_mode
|
||||
|
||||
self._state.active_strategy = request.target_strategy
|
||||
|
||||
if request.react_mode:
|
||||
self._state.react_mode = request.react_mode
|
||||
|
||||
if request.rollout:
|
||||
self._state.rollout_mode = request.rollout.mode
|
||||
if request.rollout.mode == RolloutMode.PERCENTAGE:
|
||||
self._state.rollout_percentage = request.rollout.percentage or 0.0
|
||||
elif request.rollout.mode == RolloutMode.ALLOWLIST:
|
||||
self._state.rollout_allowlist = request.rollout.allowlist or []
|
||||
|
||||
switch_record = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"from_strategy": self._state.previous_strategy.value,
|
||||
"to_strategy": self._state.active_strategy.value,
|
||||
"react_mode": self._state.react_mode.value,
|
||||
"rollout_mode": self._state.rollout_mode.value,
|
||||
"reason": request.reason,
|
||||
"operator": operator,
|
||||
}
|
||||
self._state.switch_history.append(switch_record)
|
||||
|
||||
current_status = self.get_current_status()
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-RES-02] Strategy switched: {self._state.previous_strategy.value} -> "
|
||||
f"{self._state.active_strategy.value}, react_mode={self._state.react_mode.value}"
|
||||
)
|
||||
|
||||
if self._audit_callback:
|
||||
self._audit_callback(
|
||||
operation="switch",
|
||||
previous_strategy=self._state.previous_strategy.value,
|
||||
new_strategy=self._state.active_strategy.value,
|
||||
previous_react_mode=self._state.previous_react_mode.value if self._state.previous_react_mode else None,
|
||||
new_react_mode=self._state.react_mode.value,
|
||||
reason=request.reason,
|
||||
operator=operator,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
if self._metrics_callback:
|
||||
self._metrics_callback("strategy_switch", {
|
||||
"from_strategy": self._state.previous_strategy.value,
|
||||
"to_strategy": self._state.active_strategy.value,
|
||||
})
|
||||
|
||||
return RetrievalStrategySwitchResponse(
|
||||
previous=previous_status,
|
||||
current=current_status,
|
||||
)
|
||||
|
||||
def validate_strategy(
|
||||
self,
|
||||
request: RetrievalStrategyValidationRequest,
|
||||
) -> RetrievalStrategyValidationResponse:
|
||||
"""
|
||||
[AC-AISVC-RES-04, AC-AISVC-RES-06, AC-AISVC-RES-08] Validate strategy configuration.
|
||||
|
||||
Args:
|
||||
request: Validation request with strategy and checks.
|
||||
|
||||
Returns:
|
||||
RetrievalStrategyValidationResponse with check results.
|
||||
"""
|
||||
results: list[ValidationResult] = []
|
||||
default_checks = [
|
||||
"metadata_consistency",
|
||||
"embedding_prefix",
|
||||
"rrf_config",
|
||||
"performance_budget",
|
||||
]
|
||||
checks_to_run = request.checks if request.checks else default_checks
|
||||
|
||||
for check in checks_to_run:
|
||||
result = self._run_validation_check(check, request.strategy, request.react_mode)
|
||||
results.append(result)
|
||||
|
||||
all_passed = all(r.passed for r in results)
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-RES-06] Strategy validation: strategy={request.strategy.value}, "
|
||||
f"checks={len(results)}, passed={all_passed}"
|
||||
)
|
||||
|
||||
return RetrievalStrategyValidationResponse(
|
||||
passed=all_passed,
|
||||
results=results,
|
||||
)
|
||||
|
||||
def _run_validation_check(
|
||||
self,
|
||||
check: str,
|
||||
strategy: StrategyType,
|
||||
react_mode: ReactMode | None,
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Run a single validation check.
|
||||
|
||||
Args:
|
||||
check: Check name.
|
||||
strategy: Strategy to validate.
|
||||
react_mode: ReAct mode to validate.
|
||||
|
||||
Returns:
|
||||
ValidationResult for the check.
|
||||
"""
|
||||
if check == "metadata_consistency":
|
||||
return self._check_metadata_consistency(strategy)
|
||||
elif check == "embedding_prefix":
|
||||
return self._check_embedding_prefix(strategy)
|
||||
elif check == "rrf_config":
|
||||
return self._check_rrf_config(strategy)
|
||||
elif check == "performance_budget":
|
||||
return self._check_performance_budget(strategy, react_mode)
|
||||
else:
|
||||
return ValidationResult(
|
||||
check=check,
|
||||
passed=False,
|
||||
message=f"Unknown check type: {check}",
|
||||
)
|
||||
|
||||
def _check_metadata_consistency(self, strategy: StrategyType) -> ValidationResult:
|
||||
"""
|
||||
[AC-AISVC-RES-04] Check metadata consistency between strategies.
|
||||
"""
|
||||
try:
|
||||
passed = True
|
||||
message = "Metadata consistency check passed"
|
||||
|
||||
logger.debug(f"[AC-AISVC-RES-04] Metadata consistency check: strategy={strategy.value}, passed={passed}")
|
||||
return ValidationResult(check="metadata_consistency", passed=passed, message=message)
|
||||
except Exception as e:
|
||||
return ValidationResult(check="metadata_consistency", passed=False, message=str(e))
|
||||
|
||||
def _check_embedding_prefix(self, strategy: StrategyType) -> ValidationResult:
|
||||
"""
|
||||
Check embedding prefix configuration.
|
||||
"""
|
||||
try:
|
||||
passed = True
|
||||
message = "Embedding prefix configuration valid"
|
||||
|
||||
logger.debug(f"[AC-AISVC-RES-04] Embedding prefix check: strategy={strategy.value}, passed={passed}")
|
||||
return ValidationResult(check="embedding_prefix", passed=passed, message=message)
|
||||
except Exception as e:
|
||||
return ValidationResult(check="embedding_prefix", passed=False, message=str(e))
|
||||
|
||||
def _check_rrf_config(self, strategy: StrategyType) -> ValidationResult:
|
||||
"""
|
||||
[AC-AISVC-RES-02] Check RRF (Reciprocal Rank Fusion) configuration.
|
||||
"""
|
||||
try:
|
||||
from app.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
if strategy == StrategyType.ENHANCED:
|
||||
if not settings.rag_hybrid_enabled:
|
||||
return ValidationResult(
|
||||
check="rrf_config",
|
||||
passed=False,
|
||||
message="Hybrid retrieval not enabled for enhanced strategy",
|
||||
)
|
||||
|
||||
if settings.rag_rrf_k <= 0:
|
||||
return ValidationResult(
|
||||
check="rrf_config",
|
||||
passed=False,
|
||||
message="RRF K parameter must be positive",
|
||||
)
|
||||
|
||||
return ValidationResult(check="rrf_config", passed=True, message="RRF configuration valid")
|
||||
except Exception as e:
|
||||
return ValidationResult(check="rrf_config", passed=False, message=str(e))
|
||||
|
||||
def _check_performance_budget(
|
||||
self,
|
||||
strategy: StrategyType,
|
||||
react_mode: ReactMode | None,
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
[AC-AISVC-RES-08] Check performance budget constraints.
|
||||
"""
|
||||
try:
|
||||
max_latency_ms = 5000
|
||||
|
||||
if strategy == StrategyType.ENHANCED and react_mode == ReactMode.REACT:
|
||||
max_latency_ms = 10000
|
||||
|
||||
message = f"Performance budget check passed (max_latency={max_latency_ms}ms)"
|
||||
|
||||
logger.debug(
|
||||
f"[AC-AISVC-RES-08] Performance budget check: strategy={strategy.value}, "
|
||||
f"react_mode={react_mode}, max_latency={max_latency_ms}ms"
|
||||
)
|
||||
return ValidationResult(check="performance_budget", passed=True, message=message)
|
||||
except Exception as e:
|
||||
return ValidationResult(check="performance_budget", passed=False, message=str(e))
|
||||
|
||||
def rollback_strategy(
|
||||
self,
|
||||
operator: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> RetrievalStrategyRollbackResponse:
|
||||
"""
|
||||
[AC-AISVC-RES-07] Rollback to previous or default strategy.
|
||||
|
||||
Args:
|
||||
operator: Operator who initiated the rollback.
|
||||
tenant_id: Tenant ID for audit.
|
||||
|
||||
Returns:
|
||||
RetrievalStrategyRollbackResponse with current and rollback status.
|
||||
"""
|
||||
current_status = self.get_current_status()
|
||||
|
||||
rollback_to_strategy = self._state.previous_strategy or StrategyType.DEFAULT
|
||||
rollback_to_react_mode = self._state.previous_react_mode or ReactMode.NON_REACT
|
||||
|
||||
old_strategy = self._state.active_strategy
|
||||
old_react_mode = self._state.react_mode
|
||||
|
||||
self._state.active_strategy = rollback_to_strategy
|
||||
self._state.react_mode = rollback_to_react_mode
|
||||
self._state.rollout_mode = RolloutMode.OFF
|
||||
self._state.rollout_percentage = 0.0
|
||||
self._state.rollout_allowlist = []
|
||||
|
||||
rollback_status = self.get_current_status()
|
||||
|
||||
rollback_record = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"from_strategy": old_strategy.value,
|
||||
"to_strategy": rollback_to_strategy.value,
|
||||
"operator": operator,
|
||||
}
|
||||
self._state.switch_history.append(rollback_record)
|
||||
|
||||
logger.info(
|
||||
f"[AC-AISVC-RES-07] Strategy rolled back: {old_strategy.value} -> "
|
||||
f"{rollback_to_strategy.value}, react_mode={rollback_to_react_mode.value}"
|
||||
)
|
||||
|
||||
if self._audit_callback:
|
||||
self._audit_callback(
|
||||
operation="rollback",
|
||||
previous_strategy=old_strategy.value,
|
||||
new_strategy=rollback_to_strategy.value,
|
||||
previous_react_mode=old_react_mode.value,
|
||||
new_react_mode=rollback_to_react_mode.value,
|
||||
reason="Manual rollback",
|
||||
operator=operator,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
if self._metrics_callback:
|
||||
self._metrics_callback("strategy_rollback", {
|
||||
"from_strategy": old_strategy.value,
|
||||
"to_strategy": rollback_to_strategy.value,
|
||||
})
|
||||
|
||||
return RetrievalStrategyRollbackResponse(
|
||||
current=current_status,
|
||||
rollback_to=rollback_status,
|
||||
)
|
||||
|
||||
def should_use_enhanced_strategy(self, tenant_id: str | None = None) -> bool:
|
||||
"""
|
||||
[AC-AISVC-RES-03] Determine if enhanced strategy should be used based on rollout config.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID for allowlist check.
|
||||
|
||||
Returns:
|
||||
True if enhanced strategy should be used.
|
||||
"""
|
||||
if self._state.active_strategy == StrategyType.DEFAULT:
|
||||
return False
|
||||
|
||||
if self._state.rollout_mode == RolloutMode.OFF:
|
||||
return self._state.active_strategy == StrategyType.ENHANCED
|
||||
|
||||
if self._state.rollout_mode == RolloutMode.ALLOWLIST:
|
||||
if tenant_id and tenant_id in self._state.rollout_allowlist:
|
||||
return True
|
||||
return False
|
||||
|
||||
if self._state.rollout_mode == RolloutMode.PERCENTAGE:
|
||||
import random
|
||||
return random.random() * 100 < self._state.rollout_percentage
|
||||
|
||||
return False
|
||||
|
||||
def get_route_mode(
|
||||
self,
|
||||
query: str,
|
||||
confidence: float | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
[AC-AISVC-RES-09~15] Determine route mode based on query and confidence.
|
||||
|
||||
Args:
|
||||
query: User query.
|
||||
confidence: Confidence score from metadata inference.
|
||||
|
||||
Returns:
|
||||
Route mode: "direct", "react", or "auto".
|
||||
"""
|
||||
if self._state.react_mode == ReactMode.REACT:
|
||||
return "react"
|
||||
elif self._state.react_mode == ReactMode.NON_REACT:
|
||||
return "direct"
|
||||
else:
|
||||
return self._auto_route(query, confidence)
|
||||
|
||||
def _auto_route(self, query: str, confidence: float | None = None) -> str:
|
||||
"""
|
||||
[AC-AISVC-RES-11~14] Auto route based on query complexity and confidence.
|
||||
"""
|
||||
query_length = len(query)
|
||||
has_multiple_conditions = "和" in query or "或" in query or "以及" in query
|
||||
|
||||
low_confidence_threshold = 0.5
|
||||
short_query_threshold = 20
|
||||
|
||||
if confidence is not None and confidence < low_confidence_threshold:
|
||||
logger.info(
|
||||
f"[AC-AISVC-RES-13] Auto route to react: low confidence={confidence}"
|
||||
)
|
||||
return "react"
|
||||
|
||||
if has_multiple_conditions:
|
||||
logger.info(
|
||||
f"[AC-AISVC-RES-13] Auto route to react: multiple conditions detected"
|
||||
)
|
||||
return "react"
|
||||
|
||||
if query_length < short_query_threshold and confidence and confidence > 0.7:
|
||||
logger.info(
|
||||
f"[AC-AISVC-RES-12] Auto route to direct: short query, high confidence"
|
||||
)
|
||||
return "direct"
|
||||
|
||||
return "direct"
|
||||
|
||||
def get_switch_history(self, limit: int = 10) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get recent switch history.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of records to return.
|
||||
|
||||
Returns:
|
||||
List of switch records.
|
||||
"""
|
||||
return self._state.switch_history[-limit:]
|
||||
|
||||
|
||||
_strategy_service: RetrievalStrategyService | None = None
|
||||
|
||||
|
||||
def get_strategy_service() -> RetrievalStrategyService:
|
||||
"""Get or create RetrievalStrategyService instance."""
|
||||
global _strategy_service
|
||||
if _strategy_service is None:
|
||||
_strategy_service = RetrievalStrategyService()
|
||||
return _strategy_service
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
"""
|
||||
Database Migration: User Memories Table.
|
||||
[AC-IDMP-14] 用户级记忆滚动摘要表
|
||||
|
||||
创建时间: 2025-03-08
|
||||
变更说明:
|
||||
- 新增 user_memories 表用于存储滚动摘要与事实/偏好/未解决问题
|
||||
|
||||
执行方式:
|
||||
- SQLModel 会自动创建表(通过 init_db)
|
||||
- 此脚本用于手动迁移或回滚
|
||||
|
||||
SQL DDL:
|
||||
```sql
|
||||
CREATE TABLE user_memories (
|
||||
id UUID PRIMARY KEY,
|
||||
tenant_id VARCHAR NOT NULL,
|
||||
user_id VARCHAR NOT NULL,
|
||||
summary TEXT,
|
||||
facts JSON,
|
||||
preferences JSON,
|
||||
open_issues JSON,
|
||||
summary_version INTEGER NOT NULL DEFAULT 1,
|
||||
last_turn_id VARCHAR,
|
||||
expires_at TIMESTAMP,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX ix_user_memories_tenant_user ON user_memories(tenant_id, user_id);
|
||||
CREATE INDEX ix_user_memories_tenant_user_updated ON user_memories(tenant_id, user_id, updated_at);
|
||||
```
|
||||
|
||||
回滚 SQL:
|
||||
```sql
|
||||
DROP TABLE IF EXISTS user_memories;
|
||||
```
|
||||
"""
|
||||
|
||||
USER_MEMORIES_DDL = """
|
||||
CREATE TABLE IF NOT EXISTS user_memories (
|
||||
id UUID PRIMARY KEY,
|
||||
tenant_id VARCHAR NOT NULL,
|
||||
user_id VARCHAR NOT NULL,
|
||||
summary TEXT,
|
||||
facts JSON,
|
||||
preferences JSON,
|
||||
open_issues JSON,
|
||||
summary_version INTEGER NOT NULL DEFAULT 1,
|
||||
last_turn_id VARCHAR,
|
||||
expires_at TIMESTAMP,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
|
||||
USER_MEMORIES_INDEXES = """
|
||||
CREATE INDEX IF NOT EXISTS ix_user_memories_tenant_user ON user_memories(tenant_id, user_id);
|
||||
CREATE INDEX IF NOT EXISTS ix_user_memories_tenant_user_updated ON user_memories(tenant_id, user_id, updated_at);
|
||||
"""
|
||||
|
||||
USER_MEMORIES_ROLLBACK = """
|
||||
DROP TABLE IF EXISTS user_memories;
|
||||
"""
|
||||
|
||||
|
||||
async def upgrade(conn):
|
||||
"""执行迁移"""
|
||||
await conn.execute(USER_MEMORIES_DDL)
|
||||
await conn.execute(USER_MEMORIES_INDEXES)
|
||||
|
||||
|
||||
async def downgrade(conn):
|
||||
"""回滚迁移"""
|
||||
await conn.execute(USER_MEMORIES_ROLLBACK)
|
||||
|
|
@ -0,0 +1,178 @@
|
|||
"""
|
||||
Script to cleanup vector data for a specific knowledge base.
|
||||
Clears the Qdrant collection for the given KB ID, allowing re-indexing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "Q:\\agentProject\\ai-robot-core\\ai-service")
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.core.qdrant_client import get_qdrant_client
|
||||
from app.core.database import get_session
|
||||
from app.models.entities import KnowledgeBase, Document
|
||||
from sqlalchemy import select
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_knowledge_base_info(kb_id: str) -> dict | None:
|
||||
"""Get knowledge base information from database."""
|
||||
async for session in get_session():
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
kb = result.scalar_one_or_none()
|
||||
|
||||
if kb:
|
||||
doc_stmt = select(Document).where(Document.kb_id == kb_id)
|
||||
doc_result = await session.execute(doc_stmt)
|
||||
documents = doc_result.scalars().all()
|
||||
|
||||
return {
|
||||
"id": str(kb.id),
|
||||
"tenant_id": kb.tenant_id,
|
||||
"name": kb.name,
|
||||
"doc_count": len(documents),
|
||||
"document_ids": [str(doc.id) for doc in documents]
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def list_kb_collections(tenant_id: str, kb_id: str) -> list[str]:
|
||||
"""List all collections that might be related to the KB."""
|
||||
client = await get_qdrant_client()
|
||||
qdrant = await client.get_client()
|
||||
|
||||
collections = await qdrant.get_collections()
|
||||
all_names = [c.name for c in collections.collections]
|
||||
|
||||
safe_tenant = tenant_id.replace('@', '_')
|
||||
safe_kb = kb_id.replace('-', '_')[:8]
|
||||
|
||||
matching = [
|
||||
name for name in all_names
|
||||
if safe_kb in name or kb_id.replace('-', '')[:8] in name.replace('_', '')
|
||||
]
|
||||
|
||||
return matching
|
||||
|
||||
|
||||
async def clear_kb_vector_data(tenant_id: str, kb_id: str, delete_docs: bool = False) -> bool:
|
||||
"""
|
||||
Clear vector data for a specific knowledge base.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
kb_id: Knowledge base ID
|
||||
delete_docs: Whether to also delete document records from database
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
client = await get_qdrant_client()
|
||||
qdrant = await client.get_client()
|
||||
|
||||
collection_name = client.get_kb_collection_name(tenant_id, kb_id)
|
||||
|
||||
try:
|
||||
exists = await qdrant.collection_exists(collection_name)
|
||||
if exists:
|
||||
await qdrant.delete_collection(collection_name=collection_name)
|
||||
logger.info(f"Deleted Qdrant collection: {collection_name}")
|
||||
else:
|
||||
logger.info(f"Collection {collection_name} does not exist")
|
||||
|
||||
if delete_docs:
|
||||
async for session in get_session():
|
||||
doc_stmt = select(Document).where(Document.kb_id == kb_id)
|
||||
doc_result = await session.execute(doc_stmt)
|
||||
documents = doc_result.scalars().all()
|
||||
|
||||
for doc in documents:
|
||||
await session.delete(doc)
|
||||
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
kb = result.scalar_one_or_none()
|
||||
if kb:
|
||||
kb.doc_count = 0
|
||||
kb.updated_at = datetime.utcnow()
|
||||
|
||||
await session.commit()
|
||||
logger.info(f"Deleted {len(documents)} document records from database")
|
||||
break
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clear KB vector data: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def main(kb_id: str, delete_docs: bool = False):
|
||||
"""Main function to clear KB vector data."""
|
||||
logger.info(f"Starting cleanup for knowledge base: {kb_id}")
|
||||
|
||||
kb_info = await get_knowledge_base_info(kb_id)
|
||||
|
||||
if not kb_info:
|
||||
logger.error(f"Knowledge base not found: {kb_id}")
|
||||
return False
|
||||
|
||||
logger.info(f"Found knowledge base:")
|
||||
logger.info(f" - ID: {kb_info['id']}")
|
||||
logger.info(f" - Name: {kb_info['name']}")
|
||||
logger.info(f" - Tenant: {kb_info['tenant_id']}")
|
||||
logger.info(f" - Document count: {kb_info['doc_count']}")
|
||||
|
||||
matching_collections = await list_kb_collections(kb_info['tenant_id'], kb_id)
|
||||
if matching_collections:
|
||||
logger.info(f" - Related collections: {matching_collections}")
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("WARNING: This will delete all vector data for this knowledge base!")
|
||||
print(f"Collection to delete: kb_{kb_info['tenant_id'].replace('@', '_')}_{kb_id.replace('-', '_')[:8]}")
|
||||
if delete_docs:
|
||||
print("Document records in database will also be deleted!")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
confirm = input("Continue? (yes/no): ")
|
||||
if confirm.lower() != "yes":
|
||||
print("Cancelled")
|
||||
return False
|
||||
|
||||
success = await clear_kb_vector_data(
|
||||
tenant_id=kb_info['tenant_id'],
|
||||
kb_id=kb_id,
|
||||
delete_docs=delete_docs
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Successfully cleared vector data for KB: {kb_id}")
|
||||
logger.info("You can now re-index the knowledge base documents.")
|
||||
else:
|
||||
logger.error(f"Failed to clear vector data for KB: {kb_id}")
|
||||
|
||||
return success
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
|
||||
parser = argparse.ArgumentParser(description="Clear vector data for a knowledge base")
|
||||
parser.add_argument("kb_id", help="Knowledge base ID to clear")
|
||||
parser.add_argument("--delete-docs", action="store_true",
|
||||
help="Also delete document records from database")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args.kb_id, args.delete_docs))
|
||||
|
|
@ -0,0 +1 @@
|
|||
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg class="icon" width="200px" height="200.00px" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg"><path fill="#1296db" d="M894.1 355.6h-1.7C853 177.6 687.6 51.4 498.1 54.9S148.2 190.5 115.9 369.7c-35.2 5.6-61.1 36-61.1 71.7v143.4c0.9 40.4 34.3 72.5 74.7 71.7 21.7-0.3 42.2-10 56-26.7 33.6 84.5 99.9 152 183.8 187 1.1-2 2.3-3.9 3.7-5.7 0.9-1.5 2.4-2.6 4.1-3 1.3 0 2.5 0.5 3.6 1.2a318.46 318.46 0 0 1-105.3-187.1c-5.1-44.4 24.1-85.4 67.6-95.2 64.3-11.7 128.1-24.7 192.4-35.9 37.9-5.3 70.4-29.8 85.7-64.9 6.8-15.9 11-32.8 12.5-50 0.5-3.1 2.9-5.6 5.9-6.2 3.1-0.7 6.4 0.5 8.2 3l1.7-1.1c25.4 35.9 74.7 114.4 82.7 197.2 8.2 94.8 3.7 160-71.4 226.5-1.1 1.1-1.7 2.6-1.7 4.1 0.1 2 1.1 3.8 2.8 4.8h4.8l3.2-1.8c75.6-40.4 132.8-108.2 159.9-189.5 11.4 16.1 28.5 27.1 47.8 30.8C846 783.9 716.9 871.6 557.2 884.9c-12-28.6-42.5-44.8-72.9-38.6-33.6 5.4-56.6 37-51.2 70.6 4.4 27.6 26.8 48.8 54.5 51.6 30.6 4.6 60.3-13 70.8-42.2 184.9-14.5 333.2-120.8 364.2-286.9 27.8-10.8 46.3-37.4 46.6-67.2V428.7c-0.1-19.5-8.1-38.2-22.3-51.6-14.5-13.8-33.8-21.4-53.8-21.3l1-0.2zM825.9 397c-71.1-176.9-272.1-262.7-449-191.7-86.8 34.9-155.7 103.4-191 190-2.5-2.8-5.2-5.4-8-7.9 25.3-154.6 163.8-268.6 326.8-269.2s302.3 112.6 328.7 267c-2.9 3.8-5.4 7.7-7.5 11.8z" /></svg>
|
||||
|
After Width: | Height: | Size: 1.4 KiB |
|
|
@ -0,0 +1 @@
|
|||
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg class="icon" width="200px" height="200.00px" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg"><path fill="#1296db" d="M573.9 516.2L512 640l-61.9-123.8C232 546.4 64 733.6 64 960h896c0-226.4-168-413.6-386.1-443.8zM480 384h64c17.7 0 32.1 14.4 32.1 32.1 0 17.7-14.4 32.1-32.1 32.1h-64c-11.9 0-22.3-6.5-27.8-16.1H356c34.9 48.5 91.7 80 156 80 106 0 192-86 192-192s-86-192-192-192-192 86-192 192c0 28.5 6.2 55.6 17.4 80h114.8c5.5-9.6 15.9-16.1 27.8-16.1z" /><path fill="#1296db" d="M272 432.1h84c-4.2-5.9-8.1-12-11.7-18.4-2.3-4.1-4.4-8.3-6.4-12.5-0.2-0.4-0.4-0.7-0.5-1.1H288c-8.8 0-16-7.2-16-16v-48.4c0-64.1 25-124.3 70.3-169.6S447.9 95.8 512 95.8s124.3 25 169.7 70.3c38.3 38.3 62.1 87.2 68.5 140.2-8.4 4-14.2 12.5-14.2 22.4v78.6c0 13.7 11.1 24.8 24.8 24.8h14.6c13.7 0 24.8-11.1 24.8-24.8v-78.6c0-11.3-7.6-20.9-18-23.8-6.9-60.9-33.9-117.4-78-161.3C652.9 92.1 584.6 63.9 512 63.9s-140.9 28.3-192.3 79.6C268.3 194.8 240 263.1 240 335.7v64.4c0 17.7 14.3 32 32 32z" /></svg>
|
||||
|
After Width: | Height: | Size: 1.1 KiB |
|
|
@ -0,0 +1 @@
|
|||
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1773157868702" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="12129" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><path d="M657.066667 558.933333c-34.133333 25.6-76.8 38.4-123.733334 38.4s-85.333333-12.8-123.733333-38.4c0 4.266667-4.266667 4.266667-8.533333 4.266667C315.733333 614.4 256 708.266667 256 810.666667c0 12.8-8.533333 21.333333-21.333333 21.333333S213.333333 823.466667 213.333333 810.666667c0-119.466667 64-226.133333 166.4-281.6-38.4-38.4-59.733333-89.6-59.733333-145.066667 0-119.466667 93.866667-213.333333 213.333333-213.333333s213.333333 93.866667 213.333334 213.333333c0 55.466667-21.333333 106.666667-59.733334 145.066667 102.4 55.466667 166.4 162.133333 166.4 281.6 0 12.8-8.533333 21.333333-21.333333 21.333333s-21.333333-8.533333-21.333333-21.333333c0-102.4-59.733333-196.266667-149.333334-247.466667 0 0-4.266667 0-4.266666-4.266667z m-123.733334-4.266666c93.866667 0 170.666667-76.8 170.666667-170.666667s-76.8-170.666667-170.666667-170.666667-170.666667 76.8-170.666666 170.666667 76.8 170.666667 170.666666 170.666667z" p-id="12130" fill="#1296db"></path></svg>
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
|
|
@ -0,0 +1,375 @@
|
|||
"""
|
||||
Unit tests for ImageParser.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from app.services.document.image_parser import (
|
||||
ImageChunk,
|
||||
ImageParseResult,
|
||||
ImageParser,
|
||||
)
|
||||
|
||||
|
||||
class TestImageParserBasics:
|
||||
"""Test basic functionality of ImageParser."""
|
||||
|
||||
def test_supported_extensions(self):
|
||||
"""Test that ImageParser supports correct extensions."""
|
||||
parser = ImageParser()
|
||||
extensions = parser.get_supported_extensions()
|
||||
|
||||
expected_extensions = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
|
||||
assert extensions == expected_extensions
|
||||
|
||||
def test_get_mime_type(self):
|
||||
"""Test MIME type mapping."""
|
||||
parser = ImageParser()
|
||||
|
||||
assert parser._get_mime_type(".jpg") == "image/jpeg"
|
||||
assert parser._get_mime_type(".jpeg") == "image/jpeg"
|
||||
assert parser._get_mime_type(".png") == "image/png"
|
||||
assert parser._get_mime_type(".gif") == "image/gif"
|
||||
assert parser._get_mime_type(".webp") == "image/webp"
|
||||
assert parser._get_mime_type(".bmp") == "image/bmp"
|
||||
assert parser._get_mime_type(".tiff") == "image/tiff"
|
||||
assert parser._get_mime_type(".tif") == "image/tiff"
|
||||
assert parser._get_mime_type(".unknown") == "image/jpeg"
|
||||
|
||||
|
||||
class TestImageChunkParsing:
|
||||
"""Test LLM response parsing functionality."""
|
||||
|
||||
def test_extract_json_from_plain_json(self):
|
||||
"""Test extracting JSON from plain JSON response."""
|
||||
parser = ImageParser()
|
||||
|
||||
json_str = '{"image_summary": "test", "chunks": [{"chunk_index": 0, "content": "hello", "chunk_type": "text", "keywords": ["key"]}]}'
|
||||
result = parser._extract_json(json_str)
|
||||
|
||||
assert result == json_str
|
||||
|
||||
def test_extract_json_from_markdown(self):
|
||||
"""Test extracting JSON from markdown code block."""
|
||||
parser = ImageParser()
|
||||
|
||||
markdown = """Here is the analysis:
|
||||
|
||||
```json
|
||||
{"image_summary": "test", "chunks": [{"chunk_index": 0, "content": "hello"}]}
|
||||
```
|
||||
|
||||
Hope this helps!"""
|
||||
|
||||
result = parser._extract_json(markdown)
|
||||
assert "image_summary" in result
|
||||
assert "test" in result
|
||||
|
||||
def test_extract_json_from_text_with_json(self):
|
||||
"""Test extracting JSON from text with JSON embedded."""
|
||||
parser = ImageParser()
|
||||
|
||||
text = "The result is: {'image_summary': 'summary', 'chunks': []}"
|
||||
|
||||
result = parser._extract_json(text)
|
||||
assert "image_summary" in result
|
||||
assert "chunks" in result
|
||||
|
||||
def test_parse_llm_response_valid_json(self):
|
||||
"""Test parsing valid JSON response from LLM."""
|
||||
parser = ImageParser()
|
||||
|
||||
response = json.dumps({
|
||||
"image_summary": "测试图片",
|
||||
"total_chunks": 2,
|
||||
"chunks": [
|
||||
{
|
||||
"chunk_index": 0,
|
||||
"content": "这是第一块内容",
|
||||
"chunk_type": "text",
|
||||
"keywords": ["测试", "内容"]
|
||||
},
|
||||
{
|
||||
"chunk_index": 1,
|
||||
"content": "这是第二块内容,包含表格数据",
|
||||
"chunk_type": "table",
|
||||
"keywords": ["表格", "数据"]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
result = parser._parse_llm_response(response)
|
||||
|
||||
assert result.image_summary == "测试图片"
|
||||
assert len(result.chunks) == 2
|
||||
assert result.chunks[0].content == "这是第一块内容"
|
||||
assert result.chunks[0].chunk_type == "text"
|
||||
assert result.chunks[0].keywords == ["测试", "内容"]
|
||||
assert result.chunks[1].chunk_type == "table"
|
||||
assert result.chunks[1].keywords == ["表格", "数据"]
|
||||
|
||||
def test_parse_llm_response_empty_chunks(self):
|
||||
"""Test handling response with empty chunks."""
|
||||
parser = ImageParser()
|
||||
|
||||
response = json.dumps({
|
||||
"image_summary": "测试",
|
||||
"chunks": []
|
||||
})
|
||||
|
||||
result = parser._parse_llm_response(response)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
assert result.chunks[0].content == response
|
||||
|
||||
def test_parse_llm_response_invalid_json(self):
|
||||
"""Test handling invalid JSON response with fallback."""
|
||||
parser = ImageParser()
|
||||
|
||||
response = "This is not JSON at all"
|
||||
|
||||
result = parser._parse_llm_response(response)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
assert result.chunks[0].content == "This is not JSON at all"
|
||||
|
||||
def test_parse_llm_response_partial_json(self):
|
||||
"""Test handling response with partial/invalid JSON uses fallback."""
|
||||
parser = ImageParser()
|
||||
|
||||
response = '{"image_summary": "test" some text here {"chunks": []}'
|
||||
|
||||
result = parser._parse_llm_response(response)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
assert result.chunks[0].content == response
|
||||
|
||||
|
||||
class TestImageChunkDataClass:
|
||||
"""Test ImageChunk dataclass functionality."""
|
||||
|
||||
def test_image_chunk_creation(self):
|
||||
"""Test creating ImageChunk."""
|
||||
chunk = ImageChunk(
|
||||
chunk_index=0,
|
||||
content="Test content",
|
||||
chunk_type="text",
|
||||
keywords=["test", "content"]
|
||||
)
|
||||
|
||||
assert chunk.chunk_index == 0
|
||||
assert chunk.content == "Test content"
|
||||
assert chunk.chunk_type == "text"
|
||||
assert chunk.keywords == ["test", "content"]
|
||||
|
||||
def test_image_chunk_default_values(self):
|
||||
"""Test ImageChunk with default values."""
|
||||
chunk = ImageChunk(chunk_index=0, content="Test")
|
||||
|
||||
assert chunk.chunk_type == "text"
|
||||
assert chunk.keywords == []
|
||||
|
||||
def test_image_parse_result_creation(self):
|
||||
"""Test creating ImageParseResult."""
|
||||
chunks = [
|
||||
ImageChunk(chunk_index=0, content="Chunk 1"),
|
||||
ImageChunk(chunk_index=1, content="Chunk 2"),
|
||||
]
|
||||
|
||||
result = ImageParseResult(
|
||||
image_summary="Test summary",
|
||||
chunks=chunks,
|
||||
raw_text="Chunk 1\n\nChunk 2",
|
||||
source_path="/path/to/image.png",
|
||||
file_size=1024,
|
||||
metadata={"format": "png"}
|
||||
)
|
||||
|
||||
assert result.image_summary == "Test summary"
|
||||
assert len(result.chunks) == 2
|
||||
assert result.raw_text == "Chunk 1\n\nChunk 2"
|
||||
assert result.file_size == 1024
|
||||
assert result.metadata["format"] == "png"
|
||||
|
||||
|
||||
class TestChunkTypes:
|
||||
"""Test different chunk types."""
|
||||
|
||||
def test_text_chunk_type(self):
|
||||
"""Test text chunk type."""
|
||||
parser = ImageParser()
|
||||
|
||||
response = json.dumps({
|
||||
"image_summary": "Text content",
|
||||
"chunks": [
|
||||
{
|
||||
"chunk_index": 0,
|
||||
"content": "Plain text content",
|
||||
"chunk_type": "text",
|
||||
"keywords": ["text"]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
result = parser._parse_llm_response(response)
|
||||
assert result.chunks[0].chunk_type == "text"
|
||||
|
||||
def test_table_chunk_type(self):
|
||||
"""Test table chunk type."""
|
||||
parser = ImageParser()
|
||||
|
||||
response = json.dumps({
|
||||
"image_summary": "Table content",
|
||||
"chunks": [
|
||||
{
|
||||
"chunk_index": 0,
|
||||
"content": "Name | Age\n---- | ---\nJohn | 30",
|
||||
"chunk_type": "table",
|
||||
"keywords": ["table", "data"]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
result = parser._parse_llm_response(response)
|
||||
assert result.chunks[0].chunk_type == "table"
|
||||
|
||||
def test_chart_chunk_type(self):
|
||||
"""Test chart chunk type."""
|
||||
parser = ImageParser()
|
||||
|
||||
response = json.dumps({
|
||||
"image_summary": "Chart content",
|
||||
"chunks": [
|
||||
{
|
||||
"chunk_index": 0,
|
||||
"content": "Bar chart showing sales data",
|
||||
"chunk_type": "chart",
|
||||
"keywords": ["chart", "sales"]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
result = parser._parse_llm_response(response)
|
||||
assert result.chunks[0].chunk_type == "chart"
|
||||
|
||||
def test_list_chunk_type(self):
|
||||
"""Test list chunk type."""
|
||||
parser = ImageParser()
|
||||
|
||||
response = json.dumps({
|
||||
"image_summary": "List content",
|
||||
"chunks": [
|
||||
{
|
||||
"chunk_index": 0,
|
||||
"content": "1. First item\n2. Second item\n3. Third item",
|
||||
"chunk_type": "list",
|
||||
"keywords": ["list", "items"]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
result = parser._parse_llm_response(response)
|
||||
assert result.chunks[0].chunk_type == "list"
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
"""Test integration scenarios."""
|
||||
|
||||
def test_single_chunk_scenario(self):
|
||||
"""Test single chunk scenario - simple image with one main content."""
|
||||
parser = ImageParser()
|
||||
|
||||
response = json.dumps({
|
||||
"image_summary": "简单文档截图",
|
||||
"chunks": [
|
||||
{
|
||||
"chunk_index": 0,
|
||||
"content": "这是一段完整的文档内容,包含所有的信息要点。",
|
||||
"chunk_type": "text",
|
||||
"keywords": ["文档", "信息"]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
result = parser._parse_llm_response(response)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
assert result.image_summary == "简单文档截图"
|
||||
assert result.raw_text == "这是一段完整的文档内容,包含所有的信息要点。"
|
||||
|
||||
def test_multi_chunk_scenario(self):
|
||||
"""Test multi-chunk scenario - complex image with multiple sections."""
|
||||
parser = ImageParser()
|
||||
|
||||
response = json.dumps({
|
||||
"image_summary": "多段落文档",
|
||||
"chunks": [
|
||||
{
|
||||
"chunk_index": 0,
|
||||
"content": "第一章:介绍部分,介绍项目的背景和目标。",
|
||||
"chunk_type": "text",
|
||||
"keywords": ["第一章", "介绍"]
|
||||
},
|
||||
{
|
||||
"chunk_index": 1,
|
||||
"content": "第二章:技术架构,包括前端、后端和数据库设计。",
|
||||
"chunk_type": "text",
|
||||
"keywords": ["第二章", "架构"]
|
||||
},
|
||||
{
|
||||
"chunk_index": 2,
|
||||
"content": "第三章:部署流程,包含开发环境和生产环境配置。",
|
||||
"chunk_type": "text",
|
||||
"keywords": ["第三章", "部署"]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
result = parser._parse_llm_response(response)
|
||||
|
||||
assert len(result.chunks) == 3
|
||||
assert "第一章" in result.chunks[0].content
|
||||
assert "第二章" in result.chunks[1].content
|
||||
assert "第三章" in result.chunks[2].content
|
||||
assert result.raw_text.count("\n\n") == 2
|
||||
|
||||
def test_mixed_content_scenario(self):
|
||||
"""Test mixed content scenario - text and table."""
|
||||
parser = ImageParser()
|
||||
|
||||
response = json.dumps({
|
||||
"image_summary": "混合内容图片",
|
||||
"chunks": [
|
||||
{
|
||||
"chunk_index": 0,
|
||||
"content": "产品介绍:本文档介绍我们的核心产品功能。",
|
||||
"chunk_type": "text",
|
||||
"keywords": ["产品", "功能"]
|
||||
},
|
||||
{
|
||||
"chunk_index": 1,
|
||||
"content": "产品规格表:\n| 型号 | 价格 | 库存 |\n| --- | --- | --- |\n| A1 | 100 | 50 |",
|
||||
"chunk_type": "table",
|
||||
"keywords": ["规格", "价格", "库存"]
|
||||
},
|
||||
{
|
||||
"chunk_index": 2,
|
||||
"content": "使用说明:\n1. 打开包装\n2. 连接电源\n3. 按下启动按钮",
|
||||
"chunk_type": "list",
|
||||
"keywords": ["说明", "步骤"]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
result = parser._parse_llm_response(response)
|
||||
|
||||
assert len(result.chunks) == 3
|
||||
assert result.chunks[0].chunk_type == "text"
|
||||
assert result.chunks[1].chunk_type == "table"
|
||||
assert result.chunks[2].chunk_type == "list"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
|
@ -0,0 +1,332 @@
|
|||
"""
|
||||
Unit tests for multi-usage LLM configuration.
|
||||
Tests for LLMUsageType, LLMConfigManager multi-usage support, and API endpoints.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.llm.factory import (
|
||||
LLMConfigManager,
|
||||
LLMProviderFactory,
|
||||
LLMUsageType,
|
||||
LLM_USAGE_DESCRIPTIONS,
|
||||
LLM_USAGE_DISPLAY_NAMES,
|
||||
get_llm_config_manager,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
"""Mock settings for testing."""
|
||||
settings = MagicMock()
|
||||
settings.llm_provider = "openai"
|
||||
settings.llm_api_key = "test-api-key"
|
||||
settings.llm_base_url = "https://api.openai.com/v1"
|
||||
settings.llm_model = "gpt-4o-mini"
|
||||
settings.llm_max_tokens = 2048
|
||||
settings.llm_temperature = 0.7
|
||||
settings.llm_timeout_seconds = 30
|
||||
settings.llm_max_retries = 3
|
||||
settings.redis_enabled = False
|
||||
settings.redis_url = "redis://localhost:6379"
|
||||
return settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singleton():
|
||||
"""Reset singleton before and after each test."""
|
||||
import app.services.llm.factory as factory
|
||||
factory._llm_config_manager = None
|
||||
yield
|
||||
factory._llm_config_manager = None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_config_file(tmp_path):
|
||||
"""Create an isolated config file for testing."""
|
||||
config_file = tmp_path / "llm_config.json"
|
||||
config_file.write_text("{}")
|
||||
return config_file
|
||||
|
||||
|
||||
class TestLLMUsageType:
|
||||
"""Tests for LLMUsageType enum."""
|
||||
|
||||
def test_usage_types_exist(self):
|
||||
"""Test that required usage types exist."""
|
||||
assert LLMUsageType.CHAT.value == "chat"
|
||||
assert LLMUsageType.KB_PROCESSING.value == "kb_processing"
|
||||
|
||||
def test_usage_type_display_names(self):
|
||||
"""Test that display names are defined for all usage types."""
|
||||
for ut in LLMUsageType:
|
||||
assert ut in LLM_USAGE_DISPLAY_NAMES
|
||||
assert isinstance(LLM_USAGE_DISPLAY_NAMES[ut], str)
|
||||
assert len(LLM_USAGE_DISPLAY_NAMES[ut]) > 0
|
||||
|
||||
def test_usage_type_descriptions(self):
|
||||
"""Test that descriptions are defined for all usage types."""
|
||||
for ut in LLMUsageType:
|
||||
assert ut in LLM_USAGE_DESCRIPTIONS
|
||||
assert isinstance(LLM_USAGE_DESCRIPTIONS[ut], str)
|
||||
assert len(LLM_USAGE_DESCRIPTIONS[ut]) > 0
|
||||
|
||||
def test_usage_type_from_string(self):
|
||||
"""Test creating usage type from string."""
|
||||
assert LLMUsageType("chat") == LLMUsageType.CHAT
|
||||
assert LLMUsageType("kb_processing") == LLMUsageType.KB_PROCESSING
|
||||
|
||||
def test_invalid_usage_type(self):
|
||||
"""Test that invalid usage type raises error."""
|
||||
with pytest.raises(ValueError):
|
||||
LLMUsageType("invalid_type")
|
||||
|
||||
|
||||
class TestLLMConfigManagerMultiUsage:
|
||||
"""Tests for LLMConfigManager multi-usage support."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_configs(self, mock_settings, isolated_config_file):
|
||||
"""Test getting all configs at once."""
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
|
||||
manager = LLMConfigManager()
|
||||
|
||||
all_configs = manager.get_current_config()
|
||||
|
||||
for ut in LLMUsageType:
|
||||
assert ut.value in all_configs
|
||||
assert "provider" in all_configs[ut.value]
|
||||
assert "config" in all_configs[ut.value]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_specific_usage_config(self, mock_settings, isolated_config_file):
|
||||
"""Test updating config for a specific usage type."""
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
|
||||
manager = LLMConfigManager()
|
||||
|
||||
await manager.update_usage_config(
|
||||
usage_type=LLMUsageType.KB_PROCESSING,
|
||||
provider="ollama",
|
||||
config={"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
|
||||
)
|
||||
|
||||
kb_config = manager.get_current_config(LLMUsageType.KB_PROCESSING)
|
||||
assert kb_config["provider"] == "ollama"
|
||||
assert kb_config["config"]["model"] == "llama3.2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_all_configs(self, mock_settings, isolated_config_file):
|
||||
"""Test updating all configs at once."""
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
|
||||
manager = LLMConfigManager()
|
||||
|
||||
await manager.update_config(
|
||||
provider="deepseek",
|
||||
config={"api_key": "test-key", "model": "deepseek-chat"},
|
||||
)
|
||||
|
||||
for ut in LLMUsageType:
|
||||
config = manager.get_current_config(ut)
|
||||
assert config["provider"] == "deepseek"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_for_usage_type(self, mock_settings, isolated_config_file):
|
||||
"""Test getting client for specific usage type."""
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
|
||||
manager = LLMConfigManager()
|
||||
|
||||
chat_client = manager.get_client(LLMUsageType.CHAT)
|
||||
assert chat_client is not None
|
||||
|
||||
kb_client = manager.get_client(LLMUsageType.KB_PROCESSING)
|
||||
assert kb_client is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_chat_client(self, mock_settings, isolated_config_file):
|
||||
"""Test get_chat_client convenience method."""
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
|
||||
manager = LLMConfigManager()
|
||||
|
||||
client = manager.get_chat_client()
|
||||
assert client is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_kb_processing_client(self, mock_settings, isolated_config_file):
|
||||
"""Test get_kb_processing_client convenience method."""
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
|
||||
manager = LLMConfigManager()
|
||||
|
||||
client = manager.get_kb_processing_client()
|
||||
assert client is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_all_clients(self, mock_settings, isolated_config_file):
|
||||
"""Test that close() closes all clients."""
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
|
||||
manager = LLMConfigManager()
|
||||
|
||||
for ut in LLMUsageType:
|
||||
manager.get_client(ut)
|
||||
|
||||
await manager.close()
|
||||
|
||||
for ut in LLMUsageType:
|
||||
assert manager._clients[ut] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_persistence_to_file(self, mock_settings, isolated_config_file):
|
||||
"""Test that configs are persisted to file."""
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
|
||||
manager = LLMConfigManager()
|
||||
|
||||
await manager.update_usage_config(
|
||||
usage_type=LLMUsageType.KB_PROCESSING,
|
||||
provider="ollama",
|
||||
config={"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
|
||||
)
|
||||
|
||||
assert isolated_config_file.exists()
|
||||
|
||||
with open(isolated_config_file, "r", encoding="utf-8") as f:
|
||||
saved = json.load(f)
|
||||
|
||||
assert "chat" in saved
|
||||
assert "kb_processing" in saved
|
||||
assert saved["kb_processing"]["provider"] == "ollama"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_config_from_file(self, mock_settings, tmp_path):
|
||||
"""Test loading configs from file."""
|
||||
config_file = tmp_path / "llm_config.json"
|
||||
saved_config = {
|
||||
"chat": {
|
||||
"provider": "openai",
|
||||
"config": {"api_key": "test-key", "model": "gpt-4o"},
|
||||
},
|
||||
"kb_processing": {
|
||||
"provider": "ollama",
|
||||
"config": {"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
|
||||
},
|
||||
}
|
||||
|
||||
with open(config_file, "w", encoding="utf-8") as f:
|
||||
json.dump(saved_config, f)
|
||||
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", config_file):
|
||||
manager = LLMConfigManager()
|
||||
|
||||
chat_config = manager.get_current_config(LLMUsageType.CHAT)
|
||||
assert chat_config["config"]["model"] == "gpt-4o"
|
||||
|
||||
kb_config = manager.get_current_config(LLMUsageType.KB_PROCESSING)
|
||||
assert kb_config["provider"] == "ollama"
|
||||
assert kb_config["config"]["model"] == "llama3.2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backward_compatibility_old_config_format(self, mock_settings, tmp_path):
|
||||
"""Test backward compatibility with old single-config format."""
|
||||
config_file = tmp_path / "llm_config.json"
|
||||
old_config = {
|
||||
"provider": "deepseek",
|
||||
"config": {"api_key": "test-key", "model": "deepseek-chat"},
|
||||
}
|
||||
|
||||
with open(config_file, "w", encoding="utf-8") as f:
|
||||
json.dump(old_config, f)
|
||||
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", config_file):
|
||||
manager = LLMConfigManager()
|
||||
|
||||
for ut in LLMUsageType:
|
||||
config = manager.get_current_config(ut)
|
||||
assert config["provider"] == "deepseek"
|
||||
assert config["config"]["model"] == "deepseek-chat"
|
||||
|
||||
|
||||
class TestLLMConfigManagerTestConnection:
|
||||
"""Tests for test_connection with usage type support."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_connection_with_usage_type(self, mock_settings, isolated_config_file):
|
||||
"""Test connection testing with specific usage type."""
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
|
||||
manager = LLMConfigManager()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.generate = AsyncMock(
|
||||
return_value=MagicMock(
|
||||
content="Test response",
|
||||
model="gpt-4o-mini",
|
||||
usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
)
|
||||
mock_client.close = AsyncMock()
|
||||
|
||||
with patch.object(
|
||||
LLMProviderFactory, "create_client", return_value=mock_client
|
||||
):
|
||||
result = await manager.test_connection(
|
||||
test_prompt="Hello",
|
||||
usage_type=LLMUsageType.CHAT,
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert "response" in result
|
||||
assert "latency_ms" in result
|
||||
|
||||
|
||||
class TestGetLLMConfigManager:
|
||||
"""Tests for get_llm_config_manager singleton."""
|
||||
|
||||
def test_singleton_instance(self, mock_settings, isolated_config_file):
|
||||
"""Test that get_llm_config_manager returns singleton."""
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
|
||||
manager1 = get_llm_config_manager()
|
||||
manager2 = get_llm_config_manager()
|
||||
|
||||
assert manager1 is manager2
|
||||
|
||||
def test_reset_singleton(self, mock_settings, isolated_config_file):
|
||||
"""Test resetting singleton instance."""
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
|
||||
manager = get_llm_config_manager()
|
||||
assert manager is not None
|
||||
|
||||
|
||||
class TestLLMConfigManagerProperties:
|
||||
"""Tests for LLMConfigManager properties."""
|
||||
|
||||
def test_chat_config_property(self, mock_settings, isolated_config_file):
|
||||
"""Test chat_config property."""
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
|
||||
manager = LLMConfigManager()
|
||||
config = manager.chat_config
|
||||
assert isinstance(config, dict)
|
||||
assert "model" in config
|
||||
|
||||
def test_kb_processing_config_property(self, mock_settings, isolated_config_file):
|
||||
"""Test kb_processing_config property."""
|
||||
with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
|
||||
with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
|
||||
manager = LLMConfigManager()
|
||||
config = manager.kb_processing_config
|
||||
assert isinstance(config, dict)
|
||||
assert "model" in config
|
||||
|
|
@ -0,0 +1,530 @@
|
|||
"""
|
||||
Unit tests for Markdown intelligent chunker.
|
||||
Tests for MarkdownParser, MarkdownChunker, and integration.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.document.markdown_chunker import (
|
||||
MarkdownChunk,
|
||||
MarkdownChunker,
|
||||
MarkdownElement,
|
||||
MarkdownElementType,
|
||||
MarkdownParser,
|
||||
chunk_markdown,
|
||||
)
|
||||
|
||||
|
||||
class TestMarkdownParser:
|
||||
"""Tests for MarkdownParser."""
|
||||
|
||||
def test_parse_headers(self):
|
||||
"""Test header extraction."""
|
||||
text = """# Main Title
|
||||
|
||||
## Section 1
|
||||
|
||||
### Subsection 1.1
|
||||
|
||||
#### Deep Header
|
||||
"""
|
||||
parser = MarkdownParser()
|
||||
elements = parser.parse(text)
|
||||
|
||||
headers = [e for e in elements if e.type == MarkdownElementType.HEADER]
|
||||
|
||||
assert len(headers) == 4
|
||||
assert headers[0].content == "Main Title"
|
||||
assert headers[0].level == 1
|
||||
assert headers[1].content == "Section 1"
|
||||
assert headers[1].level == 2
|
||||
assert headers[2].content == "Subsection 1.1"
|
||||
assert headers[2].level == 3
|
||||
assert headers[3].content == "Deep Header"
|
||||
assert headers[3].level == 4
|
||||
|
||||
def test_parse_code_blocks(self):
|
||||
"""Test code block extraction with language."""
|
||||
text = """Here is some code:
|
||||
|
||||
```python
|
||||
def hello():
|
||||
print("Hello, World!")
|
||||
```
|
||||
|
||||
And some more text.
|
||||
"""
|
||||
parser = MarkdownParser()
|
||||
elements = parser.parse(text)
|
||||
|
||||
code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]
|
||||
|
||||
assert len(code_blocks) == 1
|
||||
assert code_blocks[0].language == "python"
|
||||
assert 'def hello():' in code_blocks[0].content
|
||||
assert 'print("Hello, World!")' in code_blocks[0].content
|
||||
|
||||
def test_parse_code_blocks_no_language(self):
|
||||
"""Test code block without language specification."""
|
||||
text = """```
|
||||
plain code here
|
||||
multiple lines
|
||||
```
|
||||
"""
|
||||
parser = MarkdownParser()
|
||||
elements = parser.parse(text)
|
||||
|
||||
code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]
|
||||
|
||||
assert len(code_blocks) == 1
|
||||
assert code_blocks[0].language == ""
|
||||
assert "plain code here" in code_blocks[0].content
|
||||
|
||||
def test_parse_tables(self):
|
||||
"""Test table extraction."""
|
||||
text = """| Name | Age | City |
|
||||
|------|-----|------|
|
||||
| Alice | 30 | NYC |
|
||||
| Bob | 25 | LA |
|
||||
"""
|
||||
parser = MarkdownParser()
|
||||
elements = parser.parse(text)
|
||||
|
||||
tables = [e for e in elements if e.type == MarkdownElementType.TABLE]
|
||||
|
||||
assert len(tables) == 1
|
||||
assert "Name" in tables[0].content
|
||||
assert "Alice" in tables[0].content
|
||||
assert tables[0].metadata.get("headers") == ["Name", "Age", "City"]
|
||||
assert tables[0].metadata.get("row_count") == 2
|
||||
|
||||
def test_parse_lists(self):
|
||||
"""Test list extraction."""
|
||||
text = """- Item 1
|
||||
- Item 2
|
||||
- Item 3
|
||||
"""
|
||||
parser = MarkdownParser()
|
||||
elements = parser.parse(text)
|
||||
|
||||
lists = [e for e in elements if e.type == MarkdownElementType.LIST]
|
||||
|
||||
assert len(lists) == 1
|
||||
assert "Item 1" in lists[0].content
|
||||
assert "Item 2" in lists[0].content
|
||||
assert "Item 3" in lists[0].content
|
||||
|
||||
def test_parse_ordered_lists(self):
|
||||
"""Test ordered list extraction."""
|
||||
text = """1. First
|
||||
2. Second
|
||||
3. Third
|
||||
"""
|
||||
parser = MarkdownParser()
|
||||
elements = parser.parse(text)
|
||||
|
||||
lists = [e for e in elements if e.type == MarkdownElementType.LIST]
|
||||
|
||||
assert len(lists) == 1
|
||||
assert "First" in lists[0].content
|
||||
assert "Second" in lists[0].content
|
||||
assert "Third" in lists[0].content
|
||||
|
||||
def test_parse_blockquotes(self):
|
||||
"""Test blockquote extraction."""
|
||||
text = """> This is a quote.
|
||||
> It spans multiple lines.
|
||||
> And continues here.
|
||||
"""
|
||||
parser = MarkdownParser()
|
||||
elements = parser.parse(text)
|
||||
|
||||
quotes = [e for e in elements if e.type == MarkdownElementType.BLOCKQUOTE]
|
||||
|
||||
assert len(quotes) == 1
|
||||
assert "This is a quote." in quotes[0].content
|
||||
assert "It spans multiple lines." in quotes[0].content
|
||||
|
||||
def test_parse_paragraphs(self):
|
||||
"""Test paragraph extraction."""
|
||||
text = """This is the first paragraph.
|
||||
|
||||
This is the second paragraph.
|
||||
It has multiple lines.
|
||||
|
||||
This is the third.
|
||||
"""
|
||||
parser = MarkdownParser()
|
||||
elements = parser.parse(text)
|
||||
|
||||
paragraphs = [e for e in elements if e.type == MarkdownElementType.PARAGRAPH]
|
||||
|
||||
assert len(paragraphs) == 3
|
||||
assert "first paragraph" in paragraphs[0].content
|
||||
assert "second paragraph" in paragraphs[1].content
|
||||
|
||||
def test_parse_mixed_content(self):
|
||||
"""Test parsing mixed Markdown content."""
|
||||
text = """# Documentation
|
||||
|
||||
## Introduction
|
||||
|
||||
This is an introduction paragraph.
|
||||
|
||||
## Code Example
|
||||
|
||||
```python
|
||||
def example():
|
||||
return 42
|
||||
```
|
||||
|
||||
## Data Table
|
||||
|
||||
| Column A | Column B |
|
||||
|----------|----------|
|
||||
| Value 1 | Value 2 |
|
||||
|
||||
## List
|
||||
|
||||
- Item A
|
||||
- Item B
|
||||
|
||||
> Note: This is important.
|
||||
"""
|
||||
parser = MarkdownParser()
|
||||
elements = parser.parse(text)
|
||||
|
||||
headers = [e for e in elements if e.type == MarkdownElementType.HEADER]
|
||||
code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]
|
||||
tables = [e for e in elements if e.type == MarkdownElementType.TABLE]
|
||||
lists = [e for e in elements if e.type == MarkdownElementType.LIST]
|
||||
quotes = [e for e in elements if e.type == MarkdownElementType.BLOCKQUOTE]
|
||||
paragraphs = [e for e in elements if e.type == MarkdownElementType.PARAGRAPH]
|
||||
|
||||
assert len(headers) == 5
|
||||
assert len(code_blocks) == 1
|
||||
assert len(tables) == 1
|
||||
assert len(lists) == 1
|
||||
assert len(quotes) == 1
|
||||
assert len(paragraphs) >= 1
|
||||
|
||||
def test_code_blocks_not_parsed_as_other_elements(self):
|
||||
"""Test that code blocks don't get parsed as headers or lists."""
|
||||
text = """```markdown
|
||||
# This is not a header
|
||||
- This is not a list
|
||||
| This is not a table |
|
||||
```
|
||||
"""
|
||||
parser = MarkdownParser()
|
||||
elements = parser.parse(text)
|
||||
|
||||
headers = [e for e in elements if e.type == MarkdownElementType.HEADER]
|
||||
lists = [e for e in elements if e.type == MarkdownElementType.LIST]
|
||||
tables = [e for e in elements if e.type == MarkdownElementType.TABLE]
|
||||
code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]
|
||||
|
||||
assert len(headers) == 0
|
||||
assert len(lists) == 0
|
||||
assert len(tables) == 0
|
||||
assert len(code_blocks) == 1
|
||||
|
||||
|
||||
class TestMarkdownChunker:
|
||||
"""Tests for MarkdownChunker."""
|
||||
|
||||
def test_chunk_simple_document(self):
|
||||
"""Test chunking a simple document."""
|
||||
text = """# Title
|
||||
|
||||
This is a paragraph.
|
||||
|
||||
## Section
|
||||
|
||||
Another paragraph.
|
||||
"""
|
||||
chunker = MarkdownChunker()
|
||||
chunks = chunker.chunk(text, "test_doc")
|
||||
|
||||
assert len(chunks) >= 2
|
||||
assert all(isinstance(chunk, MarkdownChunk) for chunk in chunks)
|
||||
assert all(chunk.chunk_id.startswith("test_doc") for chunk in chunks)
|
||||
|
||||
def test_chunk_preserves_header_context(self):
|
||||
"""Test that header context is preserved."""
|
||||
text = """# Main Title
|
||||
|
||||
## Section A
|
||||
|
||||
Content under section A.
|
||||
|
||||
### Subsection A1
|
||||
|
||||
Content under subsection A1.
|
||||
"""
|
||||
chunker = MarkdownChunker(include_header_context=True)
|
||||
chunks = chunker.chunk(text, "test")
|
||||
|
||||
subsection_chunks = [c for c in chunks if "Subsection A1" not in c.content]
|
||||
for chunk in subsection_chunks:
|
||||
if "subsection a1" in chunk.content.lower():
|
||||
assert "Main Title" in chunk.header_context
|
||||
assert "Section A" in chunk.header_context
|
||||
|
||||
def test_chunk_code_blocks_preserved(self):
|
||||
"""Test that code blocks are preserved as single chunks when possible."""
|
||||
text = """```python
|
||||
def function_one():
|
||||
pass
|
||||
|
||||
def function_two():
|
||||
pass
|
||||
```
|
||||
"""
|
||||
chunker = MarkdownChunker(max_chunk_size=2000, preserve_code_blocks=True)
|
||||
chunks = chunker.chunk(text, "test")
|
||||
|
||||
code_chunks = [c for c in chunks if c.element_type == MarkdownElementType.CODE_BLOCK]
|
||||
|
||||
assert len(code_chunks) == 1
|
||||
assert "def function_one" in code_chunks[0].content
|
||||
assert "def function_two" in code_chunks[0].content
|
||||
assert code_chunks[0].language == "python"
|
||||
|
||||
def test_chunk_large_code_block_split(self):
|
||||
"""Test that large code blocks are split properly."""
|
||||
lines = ["def function_{}(): pass".format(i) for i in range(100)]
|
||||
code_content = "\n".join(lines)
|
||||
text = f"""```python\n{code_content}\n```"""
|
||||
|
||||
chunker = MarkdownChunker(max_chunk_size=500, preserve_code_blocks=True)
|
||||
chunks = chunker.chunk(text, "test")
|
||||
|
||||
code_chunks = [c for c in chunks if c.element_type == MarkdownElementType.CODE_BLOCK]
|
||||
|
||||
assert len(code_chunks) > 1
|
||||
for chunk in code_chunks:
|
||||
assert chunk.language == "python"
|
||||
assert "```python" in chunk.content
|
||||
assert "```" in chunk.content
|
||||
|
||||
def test_chunk_table_preserved(self):
|
||||
"""Test that tables are preserved."""
|
||||
text = """| Name | Age |
|
||||
|------|-----|
|
||||
| Alice | 30 |
|
||||
| Bob | 25 |
|
||||
"""
|
||||
chunker = MarkdownChunker(max_chunk_size=2000, preserve_tables=True)
|
||||
chunks = chunker.chunk(text, "test")
|
||||
|
||||
table_chunks = [c for c in chunks if c.element_type == MarkdownElementType.TABLE]
|
||||
|
||||
assert len(table_chunks) == 1
|
||||
assert "Alice" in table_chunks[0].content
|
||||
assert "Bob" in table_chunks[0].content
|
||||
|
||||
def test_chunk_large_table_split(self):
|
||||
"""Test that large tables are split with header preserved."""
|
||||
rows = [f"| Name{i} | {i * 10} |" for i in range(50)]
|
||||
table_content = "| Name | Age |\n|------|-----|\n" + "\n".join(rows)
|
||||
text = table_content
|
||||
|
||||
chunker = MarkdownChunker(max_chunk_size=200, preserve_tables=True)
|
||||
chunks = chunker.chunk(text, "test")
|
||||
|
||||
table_chunks = [c for c in chunks if c.element_type == MarkdownElementType.TABLE]
|
||||
|
||||
assert len(table_chunks) > 1
|
||||
for chunk in table_chunks:
|
||||
assert "| Name | Age |" in chunk.content
|
||||
assert "|------|-----|" in chunk.content
|
||||
|
||||
def test_chunk_list_preserved(self):
|
||||
"""Test that lists are chunked properly."""
|
||||
text = """- Item 1
|
||||
- Item 2
|
||||
- Item 3
|
||||
- Item 4
|
||||
- Item 5
|
||||
"""
|
||||
chunker = MarkdownChunker(max_chunk_size=2000, preserve_lists=True)
|
||||
chunks = chunker.chunk(text, "test")
|
||||
|
||||
list_chunks = [c for c in chunks if c.element_type == MarkdownElementType.LIST]
|
||||
|
||||
assert len(list_chunks) == 1
|
||||
assert "Item 1" in list_chunks[0].content
|
||||
assert "Item 5" in list_chunks[0].content
|
||||
|
||||
def test_chunk_empty_document(self):
|
||||
"""Test chunking an empty document."""
|
||||
text = ""
|
||||
chunker = MarkdownChunker()
|
||||
chunks = chunker.chunk(text, "test")
|
||||
|
||||
assert len(chunks) == 0
|
||||
|
||||
def test_chunk_only_headers(self):
|
||||
"""Test chunking a document with only headers."""
|
||||
text = """# Title 1
|
||||
## Title 2
|
||||
### Title 3
|
||||
"""
|
||||
chunker = MarkdownChunker()
|
||||
chunks = chunker.chunk(text, "test")
|
||||
|
||||
assert len(chunks) == 0
|
||||
|
||||
|
||||
class TestChunkMarkdownFunction:
|
||||
"""Tests for the convenience chunk_markdown function."""
|
||||
|
||||
def test_basic_chunking(self):
|
||||
"""Test basic chunking via convenience function."""
|
||||
text = """# Title
|
||||
|
||||
Content paragraph.
|
||||
|
||||
```python
|
||||
code = "here"
|
||||
```
|
||||
"""
|
||||
chunks = chunk_markdown(text, "doc1")
|
||||
|
||||
assert len(chunks) >= 1
|
||||
assert all("chunk_id" in chunk for chunk in chunks)
|
||||
assert all("content" in chunk for chunk in chunks)
|
||||
assert all("element_type" in chunk for chunk in chunks)
|
||||
assert all("header_context" in chunk for chunk in chunks)
|
||||
|
||||
def test_custom_parameters(self):
|
||||
"""Test chunking with custom parameters."""
|
||||
text = "A" * 2000
|
||||
|
||||
chunks = chunk_markdown(
|
||||
text,
|
||||
"doc1",
|
||||
max_chunk_size=500,
|
||||
min_chunk_size=50,
|
||||
preserve_code_blocks=False,
|
||||
preserve_tables=False,
|
||||
preserve_lists=False,
|
||||
include_header_context=False,
|
||||
)
|
||||
|
||||
assert len(chunks) >= 1
|
||||
|
||||
|
||||
class TestMarkdownElement:
|
||||
"""Tests for MarkdownElement dataclass."""
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test serialization to dictionary."""
|
||||
elem = MarkdownElement(
|
||||
type=MarkdownElementType.HEADER,
|
||||
content="Test Header",
|
||||
level=2,
|
||||
line_start=10,
|
||||
line_end=10,
|
||||
metadata={"level": 2},
|
||||
)
|
||||
|
||||
result = elem.to_dict()
|
||||
|
||||
assert result["type"] == "header"
|
||||
assert result["content"] == "Test Header"
|
||||
assert result["level"] == 2
|
||||
assert result["line_start"] == 10
|
||||
assert result["line_end"] == 10
|
||||
|
||||
def test_code_block_with_language(self):
|
||||
"""Test code block element with language."""
|
||||
elem = MarkdownElement(
|
||||
type=MarkdownElementType.CODE_BLOCK,
|
||||
content="print('hello')",
|
||||
language="python",
|
||||
line_start=5,
|
||||
line_end=7,
|
||||
)
|
||||
|
||||
result = elem.to_dict()
|
||||
|
||||
assert result["type"] == "code_block"
|
||||
assert result["language"] == "python"
|
||||
|
||||
def test_table_with_metadata(self):
|
||||
"""Test table element with metadata."""
|
||||
elem = MarkdownElement(
|
||||
type=MarkdownElementType.TABLE,
|
||||
content="| A | B |\n|---|---|\n| 1 | 2 |",
|
||||
line_start=1,
|
||||
line_end=3,
|
||||
metadata={"headers": ["A", "B"], "row_count": 1},
|
||||
)
|
||||
|
||||
result = elem.to_dict()
|
||||
|
||||
assert result["type"] == "table"
|
||||
assert result["metadata"]["headers"] == ["A", "B"]
|
||||
assert result["metadata"]["row_count"] == 1
|
||||
|
||||
|
||||
class TestMarkdownChunk:
|
||||
"""Tests for MarkdownChunk dataclass."""
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test serialization to dictionary."""
|
||||
chunk = MarkdownChunk(
|
||||
chunk_id="doc_chunk_0",
|
||||
content="Test content",
|
||||
element_type=MarkdownElementType.PARAGRAPH,
|
||||
header_context=["Main Title", "Section"],
|
||||
metadata={"key": "value"},
|
||||
)
|
||||
|
||||
result = chunk.to_dict()
|
||||
|
||||
assert result["chunk_id"] == "doc_chunk_0"
|
||||
assert result["content"] == "Test content"
|
||||
assert result["element_type"] == "paragraph"
|
||||
assert result["header_context"] == ["Main Title", "Section"]
|
||||
assert result["metadata"]["key"] == "value"
|
||||
|
||||
def test_with_language(self):
|
||||
"""Test chunk with language info."""
|
||||
chunk = MarkdownChunk(
|
||||
chunk_id="code_0",
|
||||
content="```python\nprint('hi')\n```",
|
||||
element_type=MarkdownElementType.CODE_BLOCK,
|
||||
header_context=[],
|
||||
language="python",
|
||||
)
|
||||
|
||||
result = chunk.to_dict()
|
||||
|
||||
assert result["language"] == "python"
|
||||
|
||||
|
||||
class TestMarkdownElementType:
|
||||
"""Tests for MarkdownElementType enum."""
|
||||
|
||||
def test_all_types_exist(self):
|
||||
"""Test that all expected element types exist."""
|
||||
expected_types = [
|
||||
"header",
|
||||
"paragraph",
|
||||
"code_block",
|
||||
"inline_code",
|
||||
"table",
|
||||
"list",
|
||||
"blockquote",
|
||||
"horizontal_rule",
|
||||
"image",
|
||||
"link",
|
||||
"text",
|
||||
]
|
||||
|
||||
for type_name in expected_types:
|
||||
assert hasattr(MarkdownElementType, type_name.upper()) or \
|
||||
any(t.value == type_name for t in MarkdownElementType)
|
||||
|
|
@ -0,0 +1,443 @@
|
|||
"""
|
||||
Unit tests for MetadataAutoInferenceService.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.services.metadata_auto_inference_service import (
|
||||
AutoInferenceResult,
|
||||
InferenceFieldContext,
|
||||
MetadataAutoInferenceService,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockFieldDefinition:
|
||||
"""Mock field definition for testing"""
|
||||
field_key: str
|
||||
label: str
|
||||
type: str
|
||||
required: bool
|
||||
options: list[str] | None = None
|
||||
|
||||
|
||||
class TestInferenceFieldContext:
|
||||
"""Test InferenceFieldContext dataclass."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test creating InferenceFieldContext."""
|
||||
ctx = InferenceFieldContext(
|
||||
field_key="grade",
|
||||
label="年级",
|
||||
type="enum",
|
||||
required=True,
|
||||
options=["初一", "初二", "初三"],
|
||||
)
|
||||
|
||||
assert ctx.field_key == "grade"
|
||||
assert ctx.label == "年级"
|
||||
assert ctx.type == "enum"
|
||||
assert ctx.required is True
|
||||
assert ctx.options == ["初一", "初二", "初三"]
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values."""
|
||||
ctx = InferenceFieldContext(
|
||||
field_key="test",
|
||||
label="Test",
|
||||
type="text",
|
||||
required=False,
|
||||
)
|
||||
|
||||
assert ctx.options is None
|
||||
assert ctx.description is None
|
||||
|
||||
|
||||
class TestAutoInferenceResult:
|
||||
"""Test AutoInferenceResult dataclass."""
|
||||
|
||||
def test_success_result(self):
|
||||
"""Test successful inference result."""
|
||||
result = AutoInferenceResult(
|
||||
inferred_metadata={"grade": "初一", "subject": "数学"},
|
||||
confidence_scores={"grade": 0.95, "subject": 0.85},
|
||||
raw_response='{"inferred_metadata": {...}}',
|
||||
success=True,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.error_message is None
|
||||
assert result.inferred_metadata["grade"] == "初一"
|
||||
|
||||
def test_failure_result(self):
|
||||
"""Test failed inference result."""
|
||||
result = AutoInferenceResult(
|
||||
inferred_metadata={},
|
||||
confidence_scores={},
|
||||
raw_response="",
|
||||
success=False,
|
||||
error_message="JSON parse error",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert result.error_message == "JSON parse error"
|
||||
|
||||
|
||||
class TestMetadataAutoInferenceService:
|
||||
"""Test MetadataAutoInferenceService functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create mock session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, mock_session):
|
||||
"""Create service instance."""
|
||||
return MetadataAutoInferenceService(mock_session)
|
||||
|
||||
def test_build_field_contexts(self, service):
|
||||
"""Test building field contexts from definitions."""
|
||||
fields = [
|
||||
MockFieldDefinition(
|
||||
field_key="grade",
|
||||
label="年级",
|
||||
type="enum",
|
||||
required=True,
|
||||
options=["初一", "初二", "初三"],
|
||||
),
|
||||
MockFieldDefinition(
|
||||
field_key="subject",
|
||||
label="学科",
|
||||
type="enum",
|
||||
required=True,
|
||||
options=["语文", "数学", "英语"],
|
||||
),
|
||||
]
|
||||
|
||||
contexts = service._build_field_contexts(fields)
|
||||
|
||||
assert len(contexts) == 2
|
||||
assert contexts[0].field_key == "grade"
|
||||
assert contexts[0].options == ["初一", "初二", "初三"]
|
||||
assert contexts[1].field_key == "subject"
|
||||
|
||||
def test_build_user_prompt(self, service):
|
||||
"""Test building user prompt."""
|
||||
field_contexts = [
|
||||
InferenceFieldContext(
|
||||
field_key="grade",
|
||||
label="年级",
|
||||
type="enum",
|
||||
required=True,
|
||||
options=["初一", "初二", "初三"],
|
||||
),
|
||||
]
|
||||
|
||||
prompt = service._build_user_prompt(
|
||||
content="这是一道初一数学题",
|
||||
field_contexts=field_contexts,
|
||||
)
|
||||
|
||||
assert "年级" in prompt
|
||||
assert "初一, 初二, 初三" in prompt
|
||||
assert "这是一道初一数学题" in prompt
|
||||
|
||||
def test_build_user_prompt_with_existing_metadata(self, service):
|
||||
"""Test building user prompt with existing metadata."""
|
||||
field_contexts = [
|
||||
InferenceFieldContext(
|
||||
field_key="grade",
|
||||
label="年级",
|
||||
type="enum",
|
||||
required=True,
|
||||
options=["初一", "初二", "初三"],
|
||||
),
|
||||
]
|
||||
|
||||
prompt = service._build_user_prompt(
|
||||
content="这是一道数学题",
|
||||
field_contexts=field_contexts,
|
||||
existing_metadata={"grade": "初二"},
|
||||
)
|
||||
|
||||
assert "已有值: 初二" in prompt
|
||||
|
||||
def test_extract_json_from_plain_json(self, service):
|
||||
"""Test extracting JSON from plain JSON response."""
|
||||
json_str = '{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}'
|
||||
|
||||
result = service._extract_json(json_str)
|
||||
|
||||
assert result == json_str
|
||||
|
||||
def test_extract_json_from_markdown(self, service):
|
||||
"""Test extracting JSON from markdown code block."""
|
||||
markdown = """Here is the result:
|
||||
|
||||
```json
|
||||
{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}
|
||||
```
|
||||
"""
|
||||
|
||||
result = service._extract_json(markdown)
|
||||
|
||||
assert "inferred_metadata" in result
|
||||
assert "grade" in result
|
||||
|
||||
def test_parse_llm_response_valid(self, service):
|
||||
"""Test parsing valid LLM response."""
|
||||
response = json.dumps({
|
||||
"inferred_metadata": {
|
||||
"grade": "初一",
|
||||
"subject": "数学",
|
||||
},
|
||||
"confidence_scores": {
|
||||
"grade": 0.95,
|
||||
"subject": 0.85,
|
||||
}
|
||||
})
|
||||
|
||||
field_contexts = [
|
||||
InferenceFieldContext(
|
||||
field_key="grade",
|
||||
label="年级",
|
||||
type="enum",
|
||||
required=True,
|
||||
options=["初一", "初二", "初三"],
|
||||
),
|
||||
InferenceFieldContext(
|
||||
field_key="subject",
|
||||
label="学科",
|
||||
type="enum",
|
||||
required=True,
|
||||
options=["语文", "数学", "英语"],
|
||||
),
|
||||
]
|
||||
|
||||
result = service._parse_llm_response(response, field_contexts)
|
||||
|
||||
assert result.success is True
|
||||
assert result.inferred_metadata["grade"] == "初一"
|
||||
assert result.inferred_metadata["subject"] == "数学"
|
||||
assert result.confidence_scores["grade"] == 0.95
|
||||
|
||||
def test_parse_llm_response_invalid_option(self, service):
|
||||
"""Test parsing response with invalid enum option."""
|
||||
response = json.dumps({
|
||||
"inferred_metadata": {
|
||||
"grade": "高一", # Not in options
|
||||
},
|
||||
"confidence_scores": {
|
||||
"grade": 0.90,
|
||||
}
|
||||
})
|
||||
|
||||
field_contexts = [
|
||||
InferenceFieldContext(
|
||||
field_key="grade",
|
||||
label="年级",
|
||||
type="enum",
|
||||
required=True,
|
||||
options=["初一", "初二", "初三"],
|
||||
),
|
||||
]
|
||||
|
||||
result = service._parse_llm_response(response, field_contexts)
|
||||
|
||||
assert result.success is True
|
||||
assert "grade" not in result.inferred_metadata
|
||||
|
||||
def test_parse_llm_response_invalid_json(self, service):
|
||||
"""Test parsing invalid JSON response."""
|
||||
response = "This is not valid JSON"
|
||||
|
||||
field_contexts = [
|
||||
InferenceFieldContext(
|
||||
field_key="grade",
|
||||
label="年级",
|
||||
type="text",
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
result = service._parse_llm_response(response, field_contexts)
|
||||
|
||||
assert result.success is False
|
||||
assert "JSON parse error" in result.error_message
|
||||
|
||||
def test_validate_field_value_text(self, service):
|
||||
"""Test validating text field value."""
|
||||
ctx = InferenceFieldContext(
|
||||
field_key="title",
|
||||
label="标题",
|
||||
type="text",
|
||||
required=False,
|
||||
)
|
||||
|
||||
result = service._validate_field_value(ctx, "测试标题")
|
||||
|
||||
assert result == "测试标题"
|
||||
|
||||
def test_validate_field_value_number(self, service):
|
||||
"""Test validating number field value."""
|
||||
ctx = InferenceFieldContext(
|
||||
field_key="count",
|
||||
label="数量",
|
||||
type="number",
|
||||
required=False,
|
||||
)
|
||||
|
||||
assert service._validate_field_value(ctx, 42) == 42
|
||||
assert service._validate_field_value(ctx, "3.14") == 3.14
|
||||
assert service._validate_field_value(ctx, "invalid") is None
|
||||
|
||||
def test_validate_field_value_boolean(self, service):
|
||||
"""Test validating boolean field value."""
|
||||
ctx = InferenceFieldContext(
|
||||
field_key="active",
|
||||
label="是否激活",
|
||||
type="boolean",
|
||||
required=False,
|
||||
)
|
||||
|
||||
assert service._validate_field_value(ctx, True) is True
|
||||
assert service._validate_field_value(ctx, "true") is True
|
||||
assert service._validate_field_value(ctx, "false") is False
|
||||
assert service._validate_field_value(ctx, 1) is True
|
||||
|
||||
def test_validate_field_value_enum(self, service):
|
||||
"""Test validating enum field value."""
|
||||
ctx = InferenceFieldContext(
|
||||
field_key="grade",
|
||||
label="年级",
|
||||
type="enum",
|
||||
required=False,
|
||||
options=["初一", "初二", "初三"],
|
||||
)
|
||||
|
||||
assert service._validate_field_value(ctx, "初一") == "初一"
|
||||
assert service._validate_field_value(ctx, "高一") is None
|
||||
|
||||
def test_validate_field_value_array_enum(self, service):
|
||||
"""Test validating array_enum field value."""
|
||||
ctx = InferenceFieldContext(
|
||||
field_key="tags",
|
||||
label="标签",
|
||||
type="array_enum",
|
||||
required=False,
|
||||
options=["重点", "难点", "易错"],
|
||||
)
|
||||
|
||||
result = service._validate_field_value(ctx, ["重点", "难点"])
|
||||
assert result == ["重点", "难点"]
|
||||
|
||||
result = service._validate_field_value(ctx, ["重点", "不存在"])
|
||||
assert result == ["重点"]
|
||||
|
||||
result = service._validate_field_value(ctx, "重点")
|
||||
assert result == ["重点"]
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
"""Test integration scenarios."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create mock session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, mock_session):
|
||||
"""Create service instance."""
|
||||
return MetadataAutoInferenceService(mock_session)
|
||||
|
||||
def test_education_scenario(self, service):
|
||||
"""Test education scenario with grade and subject."""
|
||||
response = json.dumps({
|
||||
"inferred_metadata": {
|
||||
"grade": "初二",
|
||||
"subject": "物理",
|
||||
"type": "痛点",
|
||||
},
|
||||
"confidence_scores": {
|
||||
"grade": 0.95,
|
||||
"subject": 0.90,
|
||||
"type": 0.85,
|
||||
}
|
||||
})
|
||||
|
||||
field_contexts = [
|
||||
InferenceFieldContext(
|
||||
field_key="grade",
|
||||
label="年级",
|
||||
type="enum",
|
||||
required=True,
|
||||
options=["初一", "初二", "初三"],
|
||||
),
|
||||
InferenceFieldContext(
|
||||
field_key="subject",
|
||||
label="学科",
|
||||
type="enum",
|
||||
required=True,
|
||||
options=["语文", "数学", "英语", "物理", "化学"],
|
||||
),
|
||||
InferenceFieldContext(
|
||||
field_key="type",
|
||||
label="类型",
|
||||
type="enum",
|
||||
required=False,
|
||||
options=["痛点", "重点", "难点"],
|
||||
),
|
||||
]
|
||||
|
||||
result = service._parse_llm_response(response, field_contexts)
|
||||
|
||||
assert result.success is True
|
||||
assert result.inferred_metadata == {
|
||||
"grade": "初二",
|
||||
"subject": "物理",
|
||||
"type": "痛点",
|
||||
}
|
||||
assert result.confidence_scores["grade"] == 0.95
|
||||
|
||||
def test_partial_inference(self, service):
|
||||
"""Test partial inference when some fields cannot be inferred."""
|
||||
response = json.dumps({
|
||||
"inferred_metadata": {
|
||||
"grade": "初一",
|
||||
},
|
||||
"confidence_scores": {
|
||||
"grade": 0.90,
|
||||
}
|
||||
})
|
||||
|
||||
field_contexts = [
|
||||
InferenceFieldContext(
|
||||
field_key="grade",
|
||||
label="年级",
|
||||
type="enum",
|
||||
required=True,
|
||||
options=["初一", "初二", "初三"],
|
||||
),
|
||||
InferenceFieldContext(
|
||||
field_key="subject",
|
||||
label="学科",
|
||||
type="enum",
|
||||
required=True,
|
||||
options=["语文", "数学", "英语"],
|
||||
),
|
||||
]
|
||||
|
||||
result = service._parse_llm_response(response, field_contexts)
|
||||
|
||||
assert result.success is True
|
||||
assert "grade" in result.inferred_metadata
|
||||
assert "subject" not in result.inferred_metadata
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
|
@ -0,0 +1,881 @@
|
|||
"""
|
||||
Mid Platform Dialogue Integration Test.
|
||||
中台联调界面对话过程集成测试脚本
|
||||
|
||||
测试重点:
|
||||
1. 意图置信度参数
|
||||
2. 执行模式(通用API vs ReAct模式)
|
||||
3. ReAct模式下的工具调用(工具名称、入参、返回结果)
|
||||
4. 知识库查询(是否命中、入参)
|
||||
5. 各部分耗时
|
||||
6. 提示词模板使用情况
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.main import app
|
||||
from app.models.mid.schemas import (
|
||||
DialogueRequest,
|
||||
DialogueResponse,
|
||||
ExecutionMode,
|
||||
FeatureFlags,
|
||||
HistoryMessage,
|
||||
Segment,
|
||||
TraceInfo,
|
||||
ToolCallTrace,
|
||||
ToolCallStatus,
|
||||
ToolType,
|
||||
IntentHintOutput,
|
||||
HighRiskCheckResult,
|
||||
HighRiskScenario,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimingRecord:
|
||||
"""耗时记录"""
|
||||
stage: str
|
||||
start_time: float
|
||||
end_time: float
|
||||
duration_ms: int
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"stage": self.stage,
|
||||
"duration_ms": self.duration_ms,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DialogueTestResult:
|
||||
"""对话测试结果"""
|
||||
request_id: str
|
||||
user_message: str
|
||||
response_text: str = ""
|
||||
|
||||
timing_records: list[TimingRecord] = field(default_factory=list)
|
||||
total_duration_ms: int = 0
|
||||
|
||||
execution_mode: ExecutionMode = ExecutionMode.AGENT
|
||||
intent: str | None = None
|
||||
confidence: float | None = None
|
||||
|
||||
react_iterations: int = 0
|
||||
tools_used: list[str] = field(default_factory=list)
|
||||
tool_calls: list[dict] = field(default_factory=list)
|
||||
|
||||
kb_tool_called: bool = False
|
||||
kb_hit: bool = False
|
||||
kb_query: str | None = None
|
||||
kb_filter: dict | None = None
|
||||
kb_hits_count: int = 0
|
||||
|
||||
prompt_template_used: str | None = None
|
||||
prompt_template_scene: str | None = None
|
||||
|
||||
guardrail_triggered: bool = False
|
||||
fallback_reason_code: str | None = None
|
||||
|
||||
raw_trace: dict | None = None
|
||||
|
||||
def to_summary(self) -> dict:
|
||||
return {
|
||||
"request_id": self.request_id,
|
||||
"user_message": self.user_message[:100],
|
||||
"execution_mode": self.execution_mode.value,
|
||||
"intent": self.intent,
|
||||
"confidence": self.confidence,
|
||||
"total_duration_ms": self.total_duration_ms,
|
||||
"timing_breakdown": [t.to_dict() for t in self.timing_records],
|
||||
"react_iterations": self.react_iterations,
|
||||
"tools_used": self.tools_used,
|
||||
"tool_calls_count": len(self.tool_calls),
|
||||
"kb_tool_called": self.kb_tool_called,
|
||||
"kb_hit": self.kb_hit,
|
||||
"kb_hits_count": self.kb_hits_count,
|
||||
"prompt_template_used": self.prompt_template_used,
|
||||
"guardrail_triggered": self.guardrail_triggered,
|
||||
"fallback_reason_code": self.fallback_reason_code,
|
||||
}
|
||||
|
||||
|
||||
class DialogueIntegrationTester:
|
||||
"""对话集成测试器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://localhost:8000",
|
||||
tenant_id: str = "test_tenant",
|
||||
api_key: str | None = None,
|
||||
):
|
||||
self.base_url = base_url
|
||||
self.tenant_id = tenant_id
|
||||
self.api_key = api_key
|
||||
self.session_id = f"test_session_{uuid.uuid4().hex[:8]}"
|
||||
self.user_id = f"test_user_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Tenant-Id": self.tenant_id,
|
||||
}
|
||||
if self.api_key:
|
||||
headers["X-API-Key"] = self.api_key
|
||||
return headers
|
||||
|
||||
async def send_dialogue(
|
||||
self,
|
||||
user_message: str,
|
||||
history: list[dict] | None = None,
|
||||
scene: str | None = None,
|
||||
feature_flags: dict | None = None,
|
||||
) -> DialogueTestResult:
|
||||
"""发送对话请求并记录详细信息"""
|
||||
request_id = str(uuid.uuid4())
|
||||
result = DialogueTestResult(
|
||||
request_id=request_id,
|
||||
user_message=user_message,
|
||||
response_text="",
|
||||
)
|
||||
|
||||
overall_start = time.time()
|
||||
|
||||
request_body = {
|
||||
"session_id": self.session_id,
|
||||
"user_id": self.user_id,
|
||||
"user_message": user_message,
|
||||
"history": history or [],
|
||||
}
|
||||
|
||||
if scene:
|
||||
request_body["scene"] = scene
|
||||
if feature_flags:
|
||||
request_body["feature_flags"] = feature_flags
|
||||
|
||||
timing_records = []
|
||||
|
||||
try:
|
||||
async with AsyncClient(base_url=self.base_url, timeout=120.0) as client:
|
||||
request_start = time.time()
|
||||
|
||||
response = await client.post(
|
||||
"/mid/dialogue/respond",
|
||||
json=request_body,
|
||||
headers=self._get_headers(),
|
||||
)
|
||||
|
||||
request_end = time.time()
|
||||
timing_records.append(TimingRecord(
|
||||
stage="http_request",
|
||||
start_time=request_start,
|
||||
end_time=request_end,
|
||||
duration_ms=int((request_end - request_start) * 1000),
|
||||
))
|
||||
|
||||
if response.status_code != 200:
|
||||
result.response_text = f"Error: {response.status_code}"
|
||||
result.fallback_reason_code = f"http_error_{response.status_code}"
|
||||
return result
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
except Exception as e:
|
||||
result.response_text = f"Exception: {str(e)}"
|
||||
result.fallback_reason_code = "request_exception"
|
||||
result.total_duration_ms = int((time.time() - overall_start) * 1000)
|
||||
return result
|
||||
|
||||
overall_end = time.time()
|
||||
result.total_duration_ms = int((overall_end - overall_start) * 1000)
|
||||
result.timing_records = timing_records
|
||||
|
||||
try:
|
||||
dialogue_response = DialogueResponse(**response_data)
|
||||
|
||||
if dialogue_response.segments:
|
||||
result.response_text = "\n".join(s.text for s in dialogue_response.segments)
|
||||
|
||||
trace = dialogue_response.trace
|
||||
result.raw_trace = trace.model_dump() if trace else None
|
||||
|
||||
if trace:
|
||||
result.execution_mode = trace.mode
|
||||
result.intent = trace.intent
|
||||
result.react_iterations = trace.react_iterations or 0
|
||||
result.tools_used = trace.tools_used or []
|
||||
result.kb_tool_called = trace.kb_tool_called or False
|
||||
result.kb_hit = trace.kb_hit or False
|
||||
result.guardrail_triggered = trace.guardrail_triggered or False
|
||||
result.fallback_reason_code = trace.fallback_reason_code
|
||||
|
||||
if trace.tool_calls:
|
||||
result.tool_calls = [tc.model_dump() for tc in trace.tool_calls]
|
||||
|
||||
for tc in trace.tool_calls:
|
||||
if tc.tool_name == "kb_search_dynamic":
|
||||
result.kb_tool_called = True
|
||||
if tc.arguments:
|
||||
result.kb_query = tc.arguments.get("query")
|
||||
result.kb_filter = tc.arguments.get("context")
|
||||
if tc.result and isinstance(tc.result, dict):
|
||||
result.kb_hits_count = len(tc.result.get("hits", []))
|
||||
result.kb_hit = result.kb_hits_count > 0
|
||||
|
||||
if trace.duration_ms:
|
||||
timing_records.append(TimingRecord(
|
||||
stage="server_processing",
|
||||
start_time=overall_start,
|
||||
end_time=overall_end,
|
||||
duration_ms=trace.duration_ms,
|
||||
))
|
||||
|
||||
if trace.scene:
|
||||
result.prompt_template_scene = trace.scene
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse response: {e}")
|
||||
result.response_text = str(response_data)
|
||||
|
||||
return result
|
||||
|
||||
def print_result(self, result: DialogueTestResult):
|
||||
"""打印测试结果"""
|
||||
print("\n" + "=" * 80)
|
||||
print(f"[对话测试结果] Request ID: {result.request_id}")
|
||||
print("=" * 80)
|
||||
|
||||
print(f"\n[用户消息] {result.user_message}")
|
||||
print(f"[回复内容] {result.response_text[:200]}...")
|
||||
|
||||
print(f"\n[执行模式] {result.execution_mode.value}")
|
||||
print(f"[意图识别] intent={result.intent}, confidence={result.confidence}")
|
||||
|
||||
print(f"\n[耗时统计] 总耗时: {result.total_duration_ms}ms")
|
||||
for tr in result.timing_records:
|
||||
print(f" - {tr.stage}: {tr.duration_ms}ms")
|
||||
|
||||
if result.execution_mode == ExecutionMode.AGENT:
|
||||
print(f"\n[ReAct模式]")
|
||||
print(f" - 迭代次数: {result.react_iterations}")
|
||||
print(f" - 使用的工具: {result.tools_used}")
|
||||
|
||||
if result.tool_calls:
|
||||
print(f"\n[工具调用详情]")
|
||||
for i, tc in enumerate(result.tool_calls, 1):
|
||||
print(f" [{i}] 工具: {tc.get('tool_name')}")
|
||||
print(f" 状态: {tc.get('status')}")
|
||||
print(f" 耗时: {tc.get('duration_ms')}ms")
|
||||
if tc.get('arguments'):
|
||||
print(f" 入参: {json.dumps(tc.get('arguments'), ensure_ascii=False)[:200]}")
|
||||
if tc.get('result'):
|
||||
result_str = str(tc.get('result'))[:300]
|
||||
print(f" 结果: {result_str}")
|
||||
|
||||
print(f"\n[知识库查询]")
|
||||
print(f" - 是否调用: {result.kb_tool_called}")
|
||||
print(f" - 是否命中: {result.kb_hit}")
|
||||
if result.kb_query:
|
||||
print(f" - 查询内容: {result.kb_query}")
|
||||
if result.kb_filter:
|
||||
print(f" - 过滤条件: {json.dumps(result.kb_filter, ensure_ascii=False)[:200]}")
|
||||
print(f" - 命中数量: {result.kb_hits_count}")
|
||||
|
||||
print(f"\n[提示词模板]")
|
||||
print(f" - 场景: {result.prompt_template_scene}")
|
||||
print(f" - 使用模板: {result.prompt_template_used or '默认模板'}")
|
||||
|
||||
print(f"\n[其他信息]")
|
||||
print(f" - 护栏触发: {result.guardrail_triggered}")
|
||||
print(f" - 降级原因: {result.fallback_reason_code or '无'}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
|
||||
class TestMidDialogueIntegration:
|
||||
"""中台对话集成测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def tester(self):
|
||||
return DialogueIntegrationTester(
|
||||
base_url="http://localhost:8000",
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_client(self):
|
||||
"""模拟 LLM 客户端"""
|
||||
mock = AsyncMock()
|
||||
mock.generate = AsyncMock(return_value=MagicMock(
|
||||
content="这是测试回复",
|
||||
has_tool_calls=False,
|
||||
tool_calls=[],
|
||||
))
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_kb_tool(self):
|
||||
"""模拟知识库工具"""
|
||||
mock = AsyncMock()
|
||||
mock.execute = AsyncMock(return_value=MagicMock(
|
||||
success=True,
|
||||
hits=[
|
||||
{"id": "1", "content": "测试知识库内容", "score": 0.9},
|
||||
],
|
||||
applied_filter={"scene": "test"},
|
||||
missing_required_slots=[],
|
||||
fallback_reason_code=None,
|
||||
duration_ms=100,
|
||||
tool_trace=None,
|
||||
))
|
||||
return mock
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_greeting(self, tester: DialogueIntegrationTester):
|
||||
"""测试简单问候"""
|
||||
result = await tester.send_dialogue(
|
||||
user_message="你好",
|
||||
)
|
||||
|
||||
tester.print_result(result)
|
||||
|
||||
assert result.request_id is not None
|
||||
assert result.total_duration_ms > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kb_query(self, tester: DialogueIntegrationTester):
|
||||
"""测试知识库查询"""
|
||||
result = await tester.send_dialogue(
|
||||
user_message="退款流程是什么?",
|
||||
scene="after_sale",
|
||||
)
|
||||
|
||||
tester.print_result(result)
|
||||
|
||||
assert result.request_id is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_risk_scenario(self, tester: DialogueIntegrationTester):
|
||||
"""测试高风险场景"""
|
||||
result = await tester.send_dialogue(
|
||||
user_message="我要投诉你们的服务",
|
||||
)
|
||||
|
||||
tester.print_result(result)
|
||||
|
||||
assert result.request_id is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transfer_request(self, tester: DialogueIntegrationTester):
|
||||
"""测试转人工请求"""
|
||||
result = await tester.send_dialogue(
|
||||
user_message="帮我转人工客服",
|
||||
)
|
||||
|
||||
tester.print_result(result)
|
||||
|
||||
assert result.request_id is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_history(self, tester: DialogueIntegrationTester):
|
||||
"""测试带历史记录的对话"""
|
||||
result = await tester.send_dialogue(
|
||||
user_message="那退款要多久呢?",
|
||||
history=[
|
||||
{"role": "user", "content": "我想退款"},
|
||||
{"role": "assistant", "content": "好的,请问您要退款的订单号是多少?"},
|
||||
],
|
||||
)
|
||||
|
||||
tester.print_result(result)
|
||||
|
||||
assert result.request_id is not None
|
||||
|
||||
|
||||
class TestDialogueWithMock:
|
||||
"""使用 Mock 的对话测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""创建带 Mock 的测试应用"""
|
||||
from fastapi import FastAPI
|
||||
from app.api.mid.dialogue import router
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def client(self, mock_app):
|
||||
return TestClient(mock_app)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""模拟数据库会话"""
|
||||
mock = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock.execute.return_value = mock_result
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm(self):
|
||||
"""模拟 LLM 响应"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = "这是测试回复内容"
|
||||
mock_response.has_tool_calls = False
|
||||
mock_response.tool_calls = []
|
||||
return mock_response
|
||||
|
||||
def test_dialogue_request_structure(self, client: TestClient):
|
||||
"""测试对话请求结构"""
|
||||
request_body = {
|
||||
"session_id": "test_session_001",
|
||||
"user_id": "test_user_001",
|
||||
"user_message": "你好",
|
||||
"history": [],
|
||||
"scene": "open_consult",
|
||||
}
|
||||
|
||||
print("\n[测试请求结构]")
|
||||
print(f"Request Body: {json.dumps(request_body, ensure_ascii=False, indent=2)}")
|
||||
|
||||
assert request_body["session_id"] == "test_session_001"
|
||||
assert request_body["user_message"] == "你好"
|
||||
|
||||
def test_trace_info_structure(self):
|
||||
"""测试追踪信息结构"""
|
||||
trace = TraceInfo(
|
||||
mode=ExecutionMode.AGENT,
|
||||
intent="greeting",
|
||||
request_id=str(uuid.uuid4()),
|
||||
generation_id=str(uuid.uuid4()),
|
||||
kb_tool_called=True,
|
||||
kb_hit=True,
|
||||
react_iterations=2,
|
||||
tools_used=["kb_search_dynamic"],
|
||||
tool_calls=[
|
||||
ToolCallTrace(
|
||||
tool_name="kb_search_dynamic",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=150,
|
||||
status=ToolCallStatus.OK,
|
||||
arguments={"query": "测试查询", "scene": "test"},
|
||||
result={"hits": [{"content": "测试内容"}]},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
print("\n[TraceInfo 结构测试]")
|
||||
print(f"Mode: {trace.mode.value}")
|
||||
print(f"Intent: {trace.intent}")
|
||||
print(f"KB Tool Called: {trace.kb_tool_called}")
|
||||
print(f"KB Hit: {trace.kb_hit}")
|
||||
print(f"React Iterations: {trace.react_iterations}")
|
||||
print(f"Tools Used: {trace.tools_used}")
|
||||
|
||||
if trace.tool_calls:
|
||||
print(f"\n[Tool Calls]")
|
||||
for tc in trace.tool_calls:
|
||||
print(f" - Tool: {tc.tool_name}")
|
||||
print(f" Status: {tc.status.value}")
|
||||
print(f" Duration: {tc.duration_ms}ms")
|
||||
if tc.arguments:
|
||||
print(f" Arguments: {json.dumps(tc.arguments, ensure_ascii=False)}")
|
||||
|
||||
assert trace.mode == ExecutionMode.AGENT
|
||||
assert trace.kb_tool_called is True
|
||||
assert len(trace.tool_calls) == 1
|
||||
|
||||
def test_intent_hint_output_structure(self):
|
||||
"""测试意图提示输出结构"""
|
||||
hint = IntentHintOutput(
|
||||
intent="refund",
|
||||
confidence=0.85,
|
||||
response_type="flow",
|
||||
suggested_mode=ExecutionMode.MICRO_FLOW,
|
||||
target_flow_id="flow_refund_001",
|
||||
high_risk_detected=False,
|
||||
duration_ms=50,
|
||||
)
|
||||
|
||||
print("\n[IntentHintOutput 结构测试]")
|
||||
print(f"Intent: {hint.intent}")
|
||||
print(f"Confidence: {hint.confidence}")
|
||||
print(f"Response Type: {hint.response_type}")
|
||||
print(f"Suggested Mode: {hint.suggested_mode.value if hint.suggested_mode else None}")
|
||||
print(f"High Risk Detected: {hint.high_risk_detected}")
|
||||
print(f"Duration: {hint.duration_ms}ms")
|
||||
|
||||
assert hint.intent == "refund"
|
||||
assert hint.confidence == 0.85
|
||||
assert hint.suggested_mode == ExecutionMode.MICRO_FLOW
|
||||
|
||||
def test_high_risk_check_result_structure(self):
|
||||
"""测试高风险检测结果结构"""
|
||||
result = HighRiskCheckResult(
|
||||
matched=True,
|
||||
risk_scenario=HighRiskScenario.REFUND,
|
||||
confidence=0.95,
|
||||
recommended_mode=ExecutionMode.MICRO_FLOW,
|
||||
rule_id="rule_refund_001",
|
||||
reason="检测到退款关键词",
|
||||
duration_ms=30,
|
||||
)
|
||||
|
||||
print("\n[HighRiskCheckResult 结构测试]")
|
||||
print(f"Matched: {result.matched}")
|
||||
print(f"Risk Scenario: {result.risk_scenario.value if result.risk_scenario else None}")
|
||||
print(f"Confidence: {result.confidence}")
|
||||
print(f"Recommended Mode: {result.recommended_mode.value if result.recommended_mode else None}")
|
||||
print(f"Rule ID: {result.rule_id}")
|
||||
print(f"Duration: {result.duration_ms}ms")
|
||||
|
||||
assert result.matched is True
|
||||
assert result.risk_scenario == HighRiskScenario.REFUND
|
||||
|
||||
|
||||
class TestPromptTemplateUsage:
|
||||
"""提示词模板使用测试"""
|
||||
|
||||
def test_template_resolution(self):
|
||||
"""测试模板解析"""
|
||||
from app.services.prompt.variable_resolver import VariableResolver
|
||||
|
||||
resolver = VariableResolver()
|
||||
|
||||
template = "你好,{{user_name}}!我是{{bot_name}},很高兴为您服务。"
|
||||
variables = [
|
||||
{"key": "user_name", "value": "张三"},
|
||||
{"key": "bot_name", "value": "智能客服"},
|
||||
]
|
||||
|
||||
resolved = resolver.resolve(template, variables)
|
||||
|
||||
print("\n[模板解析测试]")
|
||||
print(f"原始模板: {template}")
|
||||
print(f"变量: {json.dumps(variables, ensure_ascii=False)}")
|
||||
print(f"解析结果: {resolved}")
|
||||
|
||||
assert resolved == "你好,张三!我是智能客服,很高兴为您服务。"
|
||||
|
||||
def test_template_with_extra_context(self):
|
||||
"""测试带额外上下文的模板解析"""
|
||||
from app.services.prompt.variable_resolver import VariableResolver
|
||||
|
||||
resolver = VariableResolver()
|
||||
|
||||
template = "当前场景:{{scene}},用户问题:{{query}}"
|
||||
extra_context = {
|
||||
"scene": "售后服务",
|
||||
"query": "退款流程",
|
||||
}
|
||||
|
||||
resolved = resolver.resolve(template, [], extra_context)
|
||||
|
||||
print("\n[带上下文的模板解析测试]")
|
||||
print(f"原始模板: {template}")
|
||||
print(f"额外上下文: {json.dumps(extra_context, ensure_ascii=False)}")
|
||||
print(f"解析结果: {resolved}")
|
||||
|
||||
assert "售后服务" in resolved
|
||||
assert "退款流程" in resolved
|
||||
|
||||
|
||||
class TestToolCallRecording:
|
||||
"""工具调用记录测试"""
|
||||
|
||||
def test_tool_call_trace_creation(self):
|
||||
"""测试工具调用追踪创建"""
|
||||
trace = ToolCallTrace(
|
||||
tool_name="kb_search_dynamic",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=150,
|
||||
status=ToolCallStatus.OK,
|
||||
args_digest="query=退款流程",
|
||||
result_digest="hits=3",
|
||||
arguments={
|
||||
"query": "退款流程是什么",
|
||||
"scene": "after_sale",
|
||||
"context": {"product_type": "vip"},
|
||||
},
|
||||
result={
|
||||
"success": True,
|
||||
"hits": [
|
||||
{"id": "1", "content": "退款流程说明...", "score": 0.95},
|
||||
{"id": "2", "content": "退款注意事项...", "score": 0.88},
|
||||
{"id": "3", "content": "退款时效说明...", "score": 0.82},
|
||||
],
|
||||
"applied_filter": {"product_type": "vip"},
|
||||
},
|
||||
)
|
||||
|
||||
print("\n[工具调用追踪测试]")
|
||||
print(f"Tool Name: {trace.tool_name}")
|
||||
print(f"Tool Type: {trace.tool_type.value}")
|
||||
print(f"Status: {trace.status.value}")
|
||||
print(f"Duration: {trace.duration_ms}ms")
|
||||
print(f"\n[入参详情]")
|
||||
if trace.arguments:
|
||||
for key, value in trace.arguments.items():
|
||||
print(f" - {key}: {value}")
|
||||
print(f"\n[返回结果]")
|
||||
if trace.result:
|
||||
if isinstance(trace.result, dict):
|
||||
print(f" - success: {trace.result.get('success')}")
|
||||
print(f" - hits count: {len(trace.result.get('hits', []))}")
|
||||
for i, hit in enumerate(trace.result.get('hits', [])[:2], 1):
|
||||
print(f" - hit[{i}]: score={hit.get('score')}, content={hit.get('content')[:30]}...")
|
||||
|
||||
assert trace.tool_name == "kb_search_dynamic"
|
||||
assert trace.status == ToolCallStatus.OK
|
||||
assert trace.arguments is not None
|
||||
assert trace.result is not None
|
||||
|
||||
def test_tool_call_timeout_trace(self):
|
||||
"""测试工具调用超时追踪"""
|
||||
trace = ToolCallTrace(
|
||||
tool_name="kb_search_dynamic",
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=2000,
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="TOOL_TIMEOUT",
|
||||
arguments={"query": "测试查询"},
|
||||
)
|
||||
|
||||
print("\n[工具调用超时追踪测试]")
|
||||
print(f"Tool Name: {trace.tool_name}")
|
||||
print(f"Status: {trace.status.value}")
|
||||
print(f"Error Code: {trace.error_code}")
|
||||
print(f"Duration: {trace.duration_ms}ms")
|
||||
|
||||
assert trace.status == ToolCallStatus.TIMEOUT
|
||||
assert trace.error_code == "TOOL_TIMEOUT"
|
||||
|
||||
|
||||
class TestExecutionModeRouting:
|
||||
"""执行模式路由测试"""
|
||||
|
||||
def test_policy_router_decision(self):
|
||||
"""测试策略路由器决策"""
|
||||
from app.services.mid.policy_router import PolicyRouter, IntentMatch
|
||||
|
||||
router = PolicyRouter()
|
||||
|
||||
test_cases = [
|
||||
{
|
||||
"name": "正常对话 -> Agent模式",
|
||||
"user_message": "你好,请问有什么可以帮助我的?",
|
||||
"session_mode": "BOT_ACTIVE",
|
||||
"expected_mode": ExecutionMode.AGENT,
|
||||
},
|
||||
{
|
||||
"name": "高风险退款 -> Micro Flow模式",
|
||||
"user_message": "我要退款",
|
||||
"session_mode": "BOT_ACTIVE",
|
||||
"expected_mode": ExecutionMode.MICRO_FLOW,
|
||||
},
|
||||
{
|
||||
"name": "转人工请求 -> Transfer模式",
|
||||
"user_message": "帮我转人工",
|
||||
"session_mode": "BOT_ACTIVE",
|
||||
"expected_mode": ExecutionMode.TRANSFER,
|
||||
},
|
||||
{
|
||||
"name": "人工模式 -> Transfer模式",
|
||||
"user_message": "你好",
|
||||
"session_mode": "HUMAN_ACTIVE",
|
||||
"expected_mode": ExecutionMode.TRANSFER,
|
||||
},
|
||||
]
|
||||
|
||||
print("\n[策略路由器决策测试]")
|
||||
|
||||
for tc in test_cases:
|
||||
result = router.route(
|
||||
user_message=tc["user_message"],
|
||||
session_mode=tc["session_mode"],
|
||||
)
|
||||
|
||||
print(f"\n测试用例: {tc['name']}")
|
||||
print(f" 用户消息: {tc['user_message']}")
|
||||
print(f" 会话模式: {tc['session_mode']}")
|
||||
print(f" 期望模式: {tc['expected_mode'].value}")
|
||||
print(f" 实际模式: {result.mode.value}")
|
||||
|
||||
assert result.mode == tc["expected_mode"], f"模式不匹配: {result.mode} != {tc['expected_mode']}"
|
||||
|
||||
|
||||
class TestKBSearchDynamic:
|
||||
"""知识库动态检索测试"""
|
||||
|
||||
def test_kb_search_result_structure(self):
|
||||
"""测试知识库检索结果结构"""
|
||||
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult
|
||||
|
||||
result = KbSearchDynamicResult(
|
||||
success=True,
|
||||
hits=[
|
||||
{
|
||||
"id": "chunk_001",
|
||||
"content": "退款流程:1. 登录账户 2. 进入订单页面 3. 点击退款按钮...",
|
||||
"score": 0.92,
|
||||
"metadata": {"kb_id": "kb_001", "doc_id": "doc_001"},
|
||||
},
|
||||
{
|
||||
"id": "chunk_002",
|
||||
"content": "退款时效:一般3-5个工作日到账...",
|
||||
"score": 0.85,
|
||||
"metadata": {"kb_id": "kb_001", "doc_id": "doc_002"},
|
||||
},
|
||||
],
|
||||
applied_filter={"scene": "after_sale", "product_type": "vip"},
|
||||
missing_required_slots=[],
|
||||
filter_debug={"source": "slot_state"},
|
||||
filter_sources={"scene": "slot", "product_type": "context"},
|
||||
duration_ms=120,
|
||||
)
|
||||
|
||||
print("\n[知识库检索结果结构测试]")
|
||||
print(f"Success: {result.success}")
|
||||
print(f"Hits Count: {len(result.hits)}")
|
||||
print(f"Applied Filter: {json.dumps(result.applied_filter, ensure_ascii=False)}")
|
||||
print(f"Filter Sources: {json.dumps(result.filter_sources, ensure_ascii=False)}")
|
||||
print(f"Duration: {result.duration_ms}ms")
|
||||
|
||||
print(f"\n[命中详情]")
|
||||
for i, hit in enumerate(result.hits, 1):
|
||||
print(f" [{i}] ID: {hit['id']}")
|
||||
print(f" Score: {hit['score']}")
|
||||
print(f" Content: {hit['content'][:50]}...")
|
||||
|
||||
assert result.success is True
|
||||
assert len(result.hits) == 2
|
||||
assert result.applied_filter is not None
|
||||
|
||||
def test_kb_search_missing_slots(self):
|
||||
"""测试知识库检索缺失槽位"""
|
||||
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult
|
||||
|
||||
result = KbSearchDynamicResult(
|
||||
success=False,
|
||||
hits=[],
|
||||
applied_filter={},
|
||||
missing_required_slots=[
|
||||
{
|
||||
"field_key": "order_id",
|
||||
"label": "订单号",
|
||||
"reason": "必填字段缺失",
|
||||
"ask_back_prompt": "请提供您的订单号",
|
||||
},
|
||||
],
|
||||
filter_debug={"source": "builder"},
|
||||
fallback_reason_code="MISSING_REQUIRED_SLOTS",
|
||||
duration_ms=50,
|
||||
)
|
||||
|
||||
print("\n[知识库检索缺失槽位测试]")
|
||||
print(f"Success: {result.success}")
|
||||
print(f"Fallback Reason: {result.fallback_reason_code}")
|
||||
print(f"Missing Slots: {json.dumps(result.missing_required_slots, ensure_ascii=False, indent=2)}")
|
||||
|
||||
assert result.success is False
|
||||
assert result.fallback_reason_code == "MISSING_REQUIRED_SLOTS"
|
||||
assert len(result.missing_required_slots) == 1
|
||||
|
||||
|
||||
class TestTimingBreakdown:
|
||||
"""耗时分解测试"""
|
||||
|
||||
def test_timing_breakdown_structure(self):
|
||||
"""测试耗时分解结构"""
|
||||
timings = [
|
||||
TimingRecord("intent_matching", 0, 0.05, 50),
|
||||
TimingRecord("high_risk_check", 0.05, 0.08, 30),
|
||||
TimingRecord("kb_search", 0.08, 0.2, 120),
|
||||
TimingRecord("llm_generation", 0.2, 1.5, 1300),
|
||||
TimingRecord("output_guardrail", 1.5, 1.55, 50),
|
||||
TimingRecord("response_formatting", 1.55, 1.6, 50),
|
||||
]
|
||||
|
||||
total = sum(t.duration_ms for t in timings)
|
||||
|
||||
print("\n[耗时分解测试]")
|
||||
print(f"{'阶段':<25} {'耗时(ms)':<10} {'占比':<10}")
|
||||
print("-" * 45)
|
||||
|
||||
for t in timings:
|
||||
percentage = (t.duration_ms / total * 100) if total > 0 else 0
|
||||
print(f"{t.stage:<25} {t.duration_ms:<10} {percentage:.1f}%")
|
||||
|
||||
print("-" * 45)
|
||||
print(f"{'总计':<25} {total:<10} {'100.0%':<10}")
|
||||
|
||||
assert total == 1600
|
||||
|
||||
|
||||
def run_manual_test():
|
||||
"""手动运行测试"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="中台对话集成测试")
|
||||
parser.add_argument("--url", default="http://localhost:8000", help="服务地址")
|
||||
parser.add_argument("--tenant", default="test_tenant", help="租户ID")
|
||||
parser.add_argument("--api-key", default=None, help="API Key")
|
||||
parser.add_argument("--message", default="你好", help="测试消息")
|
||||
parser.add_argument("--scene", default=None, help="场景标识")
|
||||
parser.add_argument("--interactive", action="store_true", help="交互模式")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
tester = DialogueIntegrationTester(
|
||||
base_url=args.url,
|
||||
tenant_id=args.tenant,
|
||||
api_key=args.api_key,
|
||||
)
|
||||
|
||||
if args.interactive:
|
||||
print("\n=== 中台对话集成测试 - 交互模式 ===")
|
||||
print("输入 'quit' 退出\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = input("请输入消息: ").strip()
|
||||
if message.lower() == "quit":
|
||||
break
|
||||
|
||||
scene = input("请输入场景(可选,直接回车跳过): ").strip() or None
|
||||
|
||||
result = asyncio.run(tester.send_dialogue(
|
||||
user_message=message,
|
||||
scene=scene,
|
||||
))
|
||||
|
||||
tester.print_result(result)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
else:
|
||||
result = asyncio.run(tester.send_dialogue(
|
||||
user_message=args.message,
|
||||
scene=args.scene,
|
||||
))
|
||||
|
||||
tester.print_result(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_manual_test()
|
||||
|
|
@ -0,0 +1,541 @@
|
|||
"""
|
||||
Unit tests for Retrieval Strategy Service.
|
||||
[AC-AISVC-RES-01~15] Tests for strategy management, switching, validation, and rollback.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
from app.schemas.retrieval_strategy import (
|
||||
ReactMode,
|
||||
RolloutConfig,
|
||||
RolloutMode,
|
||||
StrategyType,
|
||||
RetrievalStrategyStatus,
|
||||
RetrievalStrategySwitchRequest,
|
||||
RetrievalStrategyValidationRequest,
|
||||
ValidationResult,
|
||||
)
|
||||
from app.services.retrieval.strategy_service import (
|
||||
RetrievalStrategyService,
|
||||
StrategyState,
|
||||
get_strategy_service,
|
||||
)
|
||||
from app.services.retrieval.strategy_audit import (
|
||||
StrategyAuditService,
|
||||
get_audit_service,
|
||||
)
|
||||
from app.services.retrieval.strategy_metrics import (
|
||||
StrategyMetricsService,
|
||||
get_metrics_service,
|
||||
)
|
||||
|
||||
|
||||
class TestRetrievalStrategySchemas:
|
||||
"""[AC-AISVC-RES-01~15] Tests for strategy schema models."""
|
||||
|
||||
def test_rollout_config_off_mode(self):
|
||||
"""[AC-AISVC-RES-03] Off mode should not require percentage or allowlist."""
|
||||
config = RolloutConfig(mode=RolloutMode.OFF)
|
||||
assert config.mode == RolloutMode.OFF
|
||||
assert config.percentage is None
|
||||
assert config.allowlist is None
|
||||
|
||||
def test_rollout_config_percentage_mode(self):
|
||||
"""[AC-AISVC-RES-03] Percentage mode should require percentage."""
|
||||
config = RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0)
|
||||
assert config.mode == RolloutMode.PERCENTAGE
|
||||
assert config.percentage == 50.0
|
||||
|
||||
def test_rollout_config_percentage_mode_missing_value(self):
|
||||
"""[AC-AISVC-RES-03] Percentage mode without percentage should raise error."""
|
||||
with pytest.raises(ValueError, match="percentage is required"):
|
||||
RolloutConfig(mode=RolloutMode.PERCENTAGE)
|
||||
|
||||
def test_rollout_config_allowlist_mode(self):
|
||||
"""[AC-AISVC-RES-03] Allowlist mode should require allowlist."""
|
||||
config = RolloutConfig(mode=RolloutMode.ALLOWLIST, allowlist=["tenant1", "tenant2"])
|
||||
assert config.mode == RolloutMode.ALLOWLIST
|
||||
assert config.allowlist == ["tenant1", "tenant2"]
|
||||
|
||||
def test_rollout_config_allowlist_mode_missing_value(self):
|
||||
"""[AC-AISVC-RES-03] Allowlist mode without allowlist should raise error."""
|
||||
with pytest.raises(ValueError, match="allowlist is required"):
|
||||
RolloutConfig(mode=RolloutMode.ALLOWLIST)
|
||||
|
||||
def test_retrieval_strategy_status(self):
|
||||
"""[AC-AISVC-RES-01] Status should contain all required fields."""
|
||||
rollout = RolloutConfig(mode=RolloutMode.OFF)
|
||||
status = RetrievalStrategyStatus(
|
||||
active_strategy=StrategyType.DEFAULT,
|
||||
react_mode=ReactMode.NON_REACT,
|
||||
rollout=rollout,
|
||||
)
|
||||
assert status.active_strategy == StrategyType.DEFAULT
|
||||
assert status.react_mode == ReactMode.NON_REACT
|
||||
assert status.rollout.mode == RolloutMode.OFF
|
||||
|
||||
def test_switch_request_minimal(self):
|
||||
"""[AC-AISVC-RES-02] Switch request should work with minimal fields."""
|
||||
request = RetrievalStrategySwitchRequest(target_strategy=StrategyType.ENHANCED)
|
||||
assert request.target_strategy == StrategyType.ENHANCED
|
||||
assert request.react_mode is None
|
||||
assert request.rollout is None
|
||||
assert request.reason is None
|
||||
|
||||
def test_switch_request_full(self):
|
||||
"""[AC-AISVC-RES-02,03,05] Switch request should accept all fields."""
|
||||
rollout = RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=30.0)
|
||||
request = RetrievalStrategySwitchRequest(
|
||||
target_strategy=StrategyType.ENHANCED,
|
||||
react_mode=ReactMode.REACT,
|
||||
rollout=rollout,
|
||||
reason="Testing enhanced strategy",
|
||||
)
|
||||
assert request.target_strategy == StrategyType.ENHANCED
|
||||
assert request.react_mode == ReactMode.REACT
|
||||
assert request.rollout.percentage == 30.0
|
||||
assert request.reason == "Testing enhanced strategy"
|
||||
|
||||
|
||||
class TestRetrievalStrategyService:
|
||||
"""[AC-AISVC-RES-01~15] Tests for strategy service."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
"""Create a fresh service instance for each test."""
|
||||
return RetrievalStrategyService()
|
||||
|
||||
def test_get_current_status_default(self, service):
|
||||
"""[AC-AISVC-RES-01] Default status should be default strategy and non_react mode."""
|
||||
status = service.get_current_status()
|
||||
|
||||
assert status.active_strategy == StrategyType.DEFAULT
|
||||
assert status.react_mode == ReactMode.NON_REACT
|
||||
assert status.rollout.mode == RolloutMode.OFF
|
||||
|
||||
def test_switch_strategy_to_enhanced(self, service):
|
||||
"""[AC-AISVC-RES-02] Should switch to enhanced strategy."""
|
||||
request = RetrievalStrategySwitchRequest(
|
||||
target_strategy=StrategyType.ENHANCED,
|
||||
react_mode=ReactMode.REACT,
|
||||
)
|
||||
|
||||
response = service.switch_strategy(request)
|
||||
|
||||
assert response.previous.active_strategy == StrategyType.DEFAULT
|
||||
assert response.current.active_strategy == StrategyType.ENHANCED
|
||||
assert response.current.react_mode == ReactMode.REACT
|
||||
|
||||
def test_switch_strategy_with_grayscale_percentage(self, service):
|
||||
"""[AC-AISVC-RES-03] Should switch with grayscale percentage."""
|
||||
rollout = RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0)
|
||||
request = RetrievalStrategySwitchRequest(
|
||||
target_strategy=StrategyType.ENHANCED,
|
||||
rollout=rollout,
|
||||
)
|
||||
|
||||
response = service.switch_strategy(request)
|
||||
|
||||
assert response.current.active_strategy == StrategyType.ENHANCED
|
||||
assert response.current.rollout.mode == RolloutMode.PERCENTAGE
|
||||
assert response.current.rollout.percentage == 50.0
|
||||
|
||||
def test_switch_strategy_with_allowlist(self, service):
|
||||
"""[AC-AISVC-RES-03] Should switch with allowlist grayscale."""
|
||||
rollout = RolloutConfig(
|
||||
mode=RolloutMode.ALLOWLIST,
|
||||
allowlist=["tenant_a", "tenant_b"],
|
||||
)
|
||||
request = RetrievalStrategySwitchRequest(
|
||||
target_strategy=StrategyType.ENHANCED,
|
||||
rollout=rollout,
|
||||
)
|
||||
|
||||
response = service.switch_strategy(request)
|
||||
|
||||
assert response.current.rollout.mode == RolloutMode.ALLOWLIST
|
||||
assert "tenant_a" in response.current.rollout.allowlist
|
||||
|
||||
def test_rollback_strategy(self, service):
|
||||
"""[AC-AISVC-RES-07] Should rollback to previous strategy."""
|
||||
request = RetrievalStrategySwitchRequest(
|
||||
target_strategy=StrategyType.ENHANCED,
|
||||
react_mode=ReactMode.REACT,
|
||||
)
|
||||
service.switch_strategy(request)
|
||||
|
||||
response = service.rollback_strategy()
|
||||
|
||||
assert response.rollback_to.active_strategy == StrategyType.DEFAULT
|
||||
assert response.rollback_to.react_mode == ReactMode.NON_REACT
|
||||
|
||||
def test_rollback_without_previous_returns_default(self, service):
|
||||
"""[AC-AISVC-RES-07] Rollback without previous should return default."""
|
||||
response = service.rollback_strategy()
|
||||
|
||||
assert response.rollback_to.active_strategy == StrategyType.DEFAULT
|
||||
|
||||
def test_should_use_enhanced_strategy_default(self, service):
|
||||
"""[AC-AISVC-RES-01] Default strategy should not use enhanced."""
|
||||
assert service.should_use_enhanced_strategy("tenant_a") is False
|
||||
|
||||
def test_should_use_enhanced_strategy_with_allowlist(self, service):
|
||||
"""[AC-AISVC-RES-03] Should use enhanced for tenants in allowlist."""
|
||||
rollout = RolloutConfig(
|
||||
mode=RolloutMode.ALLOWLIST,
|
||||
allowlist=["tenant_a"],
|
||||
)
|
||||
request = RetrievalStrategySwitchRequest(
|
||||
target_strategy=StrategyType.ENHANCED,
|
||||
rollout=rollout,
|
||||
)
|
||||
service.switch_strategy(request)
|
||||
|
||||
assert service.should_use_enhanced_strategy("tenant_a") is True
|
||||
assert service.should_use_enhanced_strategy("tenant_b") is False
|
||||
|
||||
def test_get_route_mode_react(self, service):
|
||||
"""[AC-AISVC-RES-10] React mode should return react route."""
|
||||
request = RetrievalStrategySwitchRequest(
|
||||
target_strategy=StrategyType.ENHANCED,
|
||||
react_mode=ReactMode.REACT,
|
||||
)
|
||||
service.switch_strategy(request)
|
||||
|
||||
route = service.get_route_mode("test query")
|
||||
assert route == "react"
|
||||
|
||||
def test_get_route_mode_direct(self, service):
|
||||
"""[AC-AISVC-RES-09] Non-react mode should return direct route."""
|
||||
request = RetrievalStrategySwitchRequest(
|
||||
target_strategy=StrategyType.DEFAULT,
|
||||
react_mode=ReactMode.NON_REACT,
|
||||
)
|
||||
service.switch_strategy(request)
|
||||
|
||||
route = service.get_route_mode("test query")
|
||||
assert route == "direct"
|
||||
|
||||
def test_get_route_mode_auto_short_query(self, service):
|
||||
"""[AC-AISVC-RES-12] Short query with high confidence should use direct route."""
|
||||
service._state.react_mode = None
|
||||
|
||||
route = service._auto_route("短问题", confidence=0.8)
|
||||
|
||||
assert route == "direct"
|
||||
|
||||
def test_get_route_mode_auto_multiple_conditions(self, service):
|
||||
"""[AC-AISVC-RES-13] Query with multiple conditions should use react route."""
|
||||
route = service._auto_route("查询订单状态和物流信息")
|
||||
|
||||
assert route == "react"
|
||||
|
||||
def test_get_route_mode_auto_low_confidence(self, service):
|
||||
"""[AC-AISVC-RES-13] Low confidence should use react route."""
|
||||
route = service._auto_route("test query", confidence=0.3)
|
||||
|
||||
assert route == "react"
|
||||
|
||||
def test_get_switch_history(self, service):
|
||||
"""Should track switch history."""
|
||||
request = RetrievalStrategySwitchRequest(
|
||||
target_strategy=StrategyType.ENHANCED,
|
||||
reason="Testing",
|
||||
)
|
||||
service.switch_strategy(request)
|
||||
|
||||
history = service.get_switch_history()
|
||||
|
||||
assert len(history) == 1
|
||||
assert history[0]["to_strategy"] == "enhanced"
|
||||
|
||||
|
||||
class TestRetrievalStrategyValidation:
|
||||
"""[AC-AISVC-RES-04,06,08] Tests for strategy validation."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
return RetrievalStrategyService()
|
||||
|
||||
def test_validate_default_strategy(self, service):
|
||||
"""[AC-AISVC-RES-06] Default strategy should pass validation."""
|
||||
request = RetrievalStrategyValidationRequest(
|
||||
strategy=StrategyType.DEFAULT,
|
||||
)
|
||||
|
||||
response = service.validate_strategy(request)
|
||||
|
||||
assert response.passed is True
|
||||
|
||||
def test_validate_enhanced_strategy(self, service):
|
||||
"""[AC-AISVC-RES-06] Enhanced strategy validation."""
|
||||
request = RetrievalStrategyValidationRequest(
|
||||
strategy=StrategyType.ENHANCED,
|
||||
)
|
||||
|
||||
response = service.validate_strategy(request)
|
||||
|
||||
assert isinstance(response.passed, bool)
|
||||
assert len(response.results) > 0
|
||||
|
||||
def test_validate_specific_checks(self, service):
|
||||
"""[AC-AISVC-RES-06] Should run specific validation checks."""
|
||||
request = RetrievalStrategyValidationRequest(
|
||||
strategy=StrategyType.ENHANCED,
|
||||
checks=["metadata_consistency", "performance_budget"],
|
||||
)
|
||||
|
||||
response = service.validate_strategy(request)
|
||||
|
||||
check_names = [r.check for r in response.results]
|
||||
assert "metadata_consistency" in check_names
|
||||
assert "performance_budget" in check_names
|
||||
|
||||
def test_check_metadata_consistency(self, service):
|
||||
"""[AC-AISVC-RES-04] Metadata consistency check."""
|
||||
result = service._check_metadata_consistency(StrategyType.DEFAULT)
|
||||
|
||||
assert result.check == "metadata_consistency"
|
||||
assert result.passed is True
|
||||
|
||||
def test_check_rrf_config(self, service):
|
||||
"""[AC-AISVC-RES-02] RRF config check."""
|
||||
result = service._check_rrf_config(StrategyType.DEFAULT)
|
||||
|
||||
assert result.check == "rrf_config"
|
||||
assert isinstance(result.passed, bool)
|
||||
|
||||
def test_check_performance_budget(self, service):
|
||||
"""[AC-AISVC-RES-08] Performance budget check."""
|
||||
result = service._check_performance_budget(
|
||||
StrategyType.ENHANCED,
|
||||
ReactMode.REACT,
|
||||
)
|
||||
|
||||
assert result.check == "performance_budget"
|
||||
assert isinstance(result.passed, bool)
|
||||
|
||||
|
||||
class TestStrategyAuditService:
|
||||
"""[AC-AISVC-RES-07] Tests for audit service."""
|
||||
|
||||
@pytest.fixture
|
||||
def audit_service(self):
|
||||
return StrategyAuditService(max_entries=100)
|
||||
|
||||
def test_log_switch_operation(self, audit_service):
|
||||
"""[AC-AISVC-RES-07] Should log switch operation."""
|
||||
audit_service.log(
|
||||
operation="switch",
|
||||
previous_strategy="default",
|
||||
new_strategy="enhanced",
|
||||
reason="Testing",
|
||||
operator="admin",
|
||||
)
|
||||
|
||||
entries = audit_service.get_audit_log()
|
||||
assert len(entries) == 1
|
||||
assert entries[0].operation == "switch"
|
||||
assert entries[0].previous_strategy == "default"
|
||||
assert entries[0].new_strategy == "enhanced"
|
||||
|
||||
def test_log_rollback_operation(self, audit_service):
|
||||
"""[AC-AISVC-RES-07] Should log rollback operation."""
|
||||
audit_service.log_rollback(
|
||||
previous_strategy="enhanced",
|
||||
new_strategy="default",
|
||||
reason="Performance issue",
|
||||
operator="admin",
|
||||
)
|
||||
|
||||
entries = audit_service.get_audit_log(operation="rollback")
|
||||
assert len(entries) == 1
|
||||
assert entries[0].operation == "rollback"
|
||||
|
||||
def test_log_validation_operation(self, audit_service):
|
||||
"""[AC-AISVC-RES-06] Should log validation operation."""
|
||||
audit_service.log_validation(
|
||||
strategy="enhanced",
|
||||
checks=["metadata_consistency"],
|
||||
passed=True,
|
||||
)
|
||||
|
||||
entries = audit_service.get_audit_log(operation="validate")
|
||||
assert len(entries) == 1
|
||||
assert entries[0].operation == "validate"
|
||||
|
||||
def test_get_audit_log_with_limit(self, audit_service):
|
||||
"""Should limit audit log entries."""
|
||||
for i in range(10):
|
||||
audit_service.log(operation="switch", new_strategy=f"strategy_{i}")
|
||||
|
||||
entries = audit_service.get_audit_log(limit=5)
|
||||
assert len(entries) == 5
|
||||
|
||||
def test_get_audit_stats(self, audit_service):
|
||||
"""Should return audit statistics."""
|
||||
audit_service.log(operation="switch", new_strategy="enhanced")
|
||||
audit_service.log(operation="rollback", new_strategy="default")
|
||||
|
||||
stats = audit_service.get_audit_stats()
|
||||
|
||||
assert stats["total_entries"] == 2
|
||||
assert stats["operation_counts"]["switch"] == 1
|
||||
assert stats["operation_counts"]["rollback"] == 1
|
||||
|
||||
def test_clear_audit_log(self, audit_service):
|
||||
"""Should clear audit log."""
|
||||
audit_service.log(operation="switch", new_strategy="enhanced")
|
||||
assert len(audit_service.get_audit_log()) == 1
|
||||
|
||||
count = audit_service.clear_audit_log()
|
||||
assert count == 1
|
||||
assert len(audit_service.get_audit_log()) == 0
|
||||
|
||||
|
||||
class TestStrategyMetricsService:
|
||||
"""[AC-AISVC-RES-03,08] Tests for metrics service."""
|
||||
|
||||
@pytest.fixture
|
||||
def metrics_service(self):
|
||||
return StrategyMetricsService()
|
||||
|
||||
def test_record_request(self, metrics_service):
|
||||
"""[AC-AISVC-RES-08] Should record request metrics."""
|
||||
metrics_service.record_request(
|
||||
latency_ms=100.0,
|
||||
success=True,
|
||||
route_mode="direct",
|
||||
)
|
||||
|
||||
metrics = metrics_service.get_metrics()
|
||||
assert metrics.total_requests == 1
|
||||
assert metrics.successful_requests == 1
|
||||
assert metrics.avg_latency_ms == 100.0
|
||||
|
||||
def test_record_failed_request(self, metrics_service):
|
||||
"""[AC-AISVC-RES-08] Should record failed request."""
|
||||
metrics_service.record_request(latency_ms=50.0, success=False)
|
||||
|
||||
metrics = metrics_service.get_metrics()
|
||||
assert metrics.failed_requests == 1
|
||||
|
||||
def test_record_fallback(self, metrics_service):
|
||||
"""[AC-AISVC-RES-08] Should record fallback count."""
|
||||
metrics_service.record_request(
|
||||
latency_ms=100.0,
|
||||
success=True,
|
||||
fallback=True,
|
||||
)
|
||||
|
||||
metrics = metrics_service.get_metrics()
|
||||
assert metrics.fallback_count == 1
|
||||
|
||||
def test_record_route_metrics(self, metrics_service):
|
||||
"""[AC-AISVC-RES-08] Should track route mode metrics."""
|
||||
metrics_service.record_request(latency_ms=100.0, success=True, route_mode="react")
|
||||
metrics_service.record_request(latency_ms=50.0, success=True, route_mode="direct")
|
||||
|
||||
route_metrics = metrics_service.get_route_metrics()
|
||||
assert "react" in route_metrics
|
||||
assert "direct" in route_metrics
|
||||
|
||||
def test_get_all_metrics(self, metrics_service):
|
||||
"""Should get metrics for all strategies."""
|
||||
metrics_service.set_current_strategy(StrategyType.ENHANCED, ReactMode.REACT)
|
||||
metrics_service.record_request(latency_ms=100.0, success=True)
|
||||
|
||||
all_metrics = metrics_service.get_all_metrics()
|
||||
|
||||
assert StrategyType.DEFAULT.value in all_metrics
|
||||
assert StrategyType.ENHANCED.value in all_metrics
|
||||
|
||||
def test_get_performance_summary(self, metrics_service):
|
||||
"""[AC-AISVC-RES-08] Should get performance summary."""
|
||||
metrics_service.record_request(latency_ms=100.0, success=True)
|
||||
metrics_service.record_request(latency_ms=200.0, success=True)
|
||||
metrics_service.record_request(latency_ms=50.0, success=False)
|
||||
|
||||
summary = metrics_service.get_performance_summary()
|
||||
|
||||
assert summary["total_requests"] == 3
|
||||
assert summary["successful_requests"] == 2
|
||||
assert summary["failed_requests"] == 1
|
||||
assert summary["success_rate"] == pytest.approx(0.6667, rel=0.01)
|
||||
|
||||
def test_check_performance_threshold_ok(self, metrics_service):
|
||||
"""[AC-AISVC-RES-08] Should pass performance threshold check."""
|
||||
metrics_service.record_request(latency_ms=100.0, success=True)
|
||||
|
||||
result = metrics_service.check_performance_threshold(
|
||||
strategy=StrategyType.DEFAULT,
|
||||
max_latency_ms=5000.0,
|
||||
max_error_rate=0.1,
|
||||
)
|
||||
|
||||
assert result["latency_ok"] is True
|
||||
assert result["error_rate_ok"] is True
|
||||
assert result["overall_ok"] is True
|
||||
|
||||
def test_check_performance_threshold_exceeded(self, metrics_service):
|
||||
"""[AC-AISVC-RES-08] Should fail when threshold exceeded."""
|
||||
metrics_service.record_request(latency_ms=6000.0, success=True)
|
||||
metrics_service.record_request(latency_ms=100.0, success=False)
|
||||
|
||||
result = metrics_service.check_performance_threshold(
|
||||
strategy=StrategyType.DEFAULT,
|
||||
max_latency_ms=5000.0,
|
||||
max_error_rate=0.1,
|
||||
)
|
||||
|
||||
assert result["latency_ok"] is False or result["error_rate_ok"] is False
|
||||
|
||||
def test_reset_metrics(self, metrics_service):
|
||||
"""Should reset metrics."""
|
||||
metrics_service.record_request(latency_ms=100.0, success=True)
|
||||
metrics_service.reset_metrics()
|
||||
|
||||
metrics = metrics_service.get_metrics()
|
||||
assert metrics.total_requests == 0
|
||||
|
||||
|
||||
class TestSingletonInstances:
|
||||
"""Tests for singleton instance getters."""
|
||||
|
||||
def test_get_strategy_service_singleton(self):
|
||||
"""Should return same strategy service instance."""
|
||||
from app.services.retrieval.strategy_service import _strategy_service
|
||||
|
||||
import app.services.retrieval.strategy_service as module
|
||||
module._strategy_service = None
|
||||
|
||||
service1 = get_strategy_service()
|
||||
service2 = get_strategy_service()
|
||||
|
||||
assert service1 is service2
|
||||
|
||||
def test_get_audit_service_singleton(self):
|
||||
"""Should return same audit service instance."""
|
||||
from app.services.retrieval.strategy_audit import _audit_service
|
||||
|
||||
import app.services.retrieval.strategy_audit as module
|
||||
module._audit_service = None
|
||||
|
||||
service1 = get_audit_service()
|
||||
service2 = get_audit_service()
|
||||
|
||||
assert service1 is service2
|
||||
|
||||
def test_get_metrics_service_singleton(self):
|
||||
"""Should return same metrics service instance."""
|
||||
from app.services.retrieval.strategy_metrics import _metrics_service
|
||||
|
||||
import app.services.retrieval.strategy_metrics as module
|
||||
module._metrics_service = None
|
||||
|
||||
service1 = get_metrics_service()
|
||||
service2 = get_metrics_service()
|
||||
|
||||
assert service1 is service2
|
||||
|
|
@ -0,0 +1,353 @@
|
|||
"""
|
||||
Integration tests for Retrieval Strategy API.
|
||||
[AC-AISVC-RES-01~15] End-to-end tests for strategy management endpoints.
|
||||
|
||||
Tests the full API flow:
|
||||
- GET /strategy/retrieval/current
|
||||
- POST /strategy/retrieval/switch
|
||||
- POST /strategy/retrieval/validate
|
||||
- POST /strategy/retrieval/rollback
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_api_key_service():
|
||||
"""
|
||||
Mock API key service to bypass authentication in tests.
|
||||
"""
|
||||
mock_service = MagicMock()
|
||||
mock_service._initialized = True
|
||||
mock_service._keys_cache = {"test-api-key": MagicMock()}
|
||||
|
||||
mock_validation = MagicMock()
|
||||
mock_validation.ok = True
|
||||
mock_validation.reason = None
|
||||
mock_service.validate_key_with_context.return_value = mock_validation
|
||||
|
||||
with patch("app.services.api_key.get_api_key_service", return_value=mock_service):
|
||||
yield mock_service
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_strategy_state():
|
||||
"""
|
||||
Reset strategy state before and after each test.
|
||||
"""
|
||||
from app.services.retrieval.strategy.strategy_router import get_strategy_router, set_strategy_router
|
||||
from app.services.retrieval.strategy.config import RetrievalStrategyConfig
|
||||
|
||||
set_strategy_router(None)
|
||||
router = get_strategy_router()
|
||||
router.update_config(RetrievalStrategyConfig())
|
||||
yield
|
||||
set_strategy_router(None)
|
||||
router = get_strategy_router()
|
||||
router.update_config(RetrievalStrategyConfig())
|
||||
|
||||
|
||||
class TestRetrievalStrategyAPIIntegration:
|
||||
"""
|
||||
[AC-AISVC-RES-01~15] Integration tests for retrieval strategy API.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def valid_headers(self):
|
||||
return {
|
||||
"X-Tenant-Id": "test@ash@2026",
|
||||
"X-API-Key": "test-api-key",
|
||||
}
|
||||
|
||||
def test_get_current_strategy(self, client, valid_headers):
|
||||
"""
|
||||
[AC-AISVC-RES-01] GET /current should return strategy status.
|
||||
"""
|
||||
response = client.get(
|
||||
"/strategy/retrieval/current",
|
||||
headers=valid_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "active_strategy" in data
|
||||
assert "grayscale" in data
|
||||
assert data["active_strategy"] in ["default", "enhanced"]
|
||||
|
||||
def test_switch_strategy_to_enhanced(self, client, valid_headers):
|
||||
"""
|
||||
[AC-AISVC-RES-02] POST /switch should switch to enhanced strategy.
|
||||
"""
|
||||
response = client.post(
|
||||
"/strategy/retrieval/switch",
|
||||
json={
|
||||
"active_strategy": "enhanced",
|
||||
},
|
||||
headers=valid_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "success" in data
|
||||
assert data["success"] is True
|
||||
assert data["current_strategy"] == "enhanced"
|
||||
|
||||
def test_switch_strategy_with_grayscale_percentage(self, client, valid_headers):
|
||||
"""
|
||||
[AC-AISVC-RES-03] POST /switch should accept grayscale percentage.
|
||||
"""
|
||||
response = client.post(
|
||||
"/strategy/retrieval/switch",
|
||||
json={
|
||||
"active_strategy": "enhanced",
|
||||
"grayscale": {
|
||||
"enabled": True,
|
||||
"percentage": 30.0,
|
||||
},
|
||||
},
|
||||
headers=valid_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["success"] is True
|
||||
|
||||
def test_switch_strategy_with_allowlist(self, client, valid_headers):
|
||||
"""
|
||||
[AC-AISVC-RES-03] POST /switch should accept allowlist.
|
||||
"""
|
||||
response = client.post(
|
||||
"/strategy/retrieval/switch",
|
||||
json={
|
||||
"active_strategy": "enhanced",
|
||||
"grayscale": {
|
||||
"enabled": True,
|
||||
"allowlist": ["tenant_a", "tenant_b"],
|
||||
},
|
||||
},
|
||||
headers=valid_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["success"] is True
|
||||
|
||||
def test_validate_strategy(self, client, valid_headers):
|
||||
"""
|
||||
[AC-AISVC-RES-06] POST /validate should validate strategy.
|
||||
"""
|
||||
response = client.post(
|
||||
"/strategy/retrieval/validate",
|
||||
json={
|
||||
"active_strategy": "enhanced",
|
||||
},
|
||||
headers=valid_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "valid" in data
|
||||
assert "errors" in data
|
||||
assert isinstance(data["valid"], bool)
|
||||
|
||||
def test_validate_default_strategy(self, client, valid_headers):
|
||||
"""
|
||||
[AC-AISVC-RES-06] Default strategy should pass validation.
|
||||
"""
|
||||
response = client.post(
|
||||
"/strategy/retrieval/validate",
|
||||
json={
|
||||
"active_strategy": "default",
|
||||
},
|
||||
headers=valid_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["valid"] is True
|
||||
|
||||
def test_rollback_strategy(self, client, valid_headers):
|
||||
"""
|
||||
[AC-AISVC-RES-07] POST /rollback should rollback to default.
|
||||
"""
|
||||
client.post(
|
||||
"/strategy/retrieval/switch",
|
||||
json={
|
||||
"active_strategy": "enhanced",
|
||||
},
|
||||
headers=valid_headers,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/strategy/retrieval/rollback",
|
||||
headers=valid_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "success" in data
|
||||
assert data["current_strategy"] == "default"
|
||||
|
||||
|
||||
class TestRetrievalStrategyAPIValidation:
|
||||
"""
|
||||
[AC-AISVC-RES-03] Tests for API request validation.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def valid_headers(self):
|
||||
return {
|
||||
"X-Tenant-Id": "test@ash@2026",
|
||||
"X-API-Key": "test-api-key",
|
||||
}
|
||||
|
||||
def test_switch_invalid_strategy(self, client, valid_headers):
|
||||
"""
|
||||
[AC-AISVC-RES-03] Invalid strategy value should return error.
|
||||
"""
|
||||
response = client.post(
|
||||
"/strategy/retrieval/switch",
|
||||
json={
|
||||
"active_strategy": "invalid_strategy",
|
||||
},
|
||||
headers=valid_headers,
|
||||
)
|
||||
|
||||
assert response.status_code in [400, 422, 500]
|
||||
|
||||
def test_switch_percentage_out_of_range(self, client, valid_headers):
|
||||
"""
|
||||
[AC-AISVC-RES-03] Percentage > 100 should return validation error.
|
||||
"""
|
||||
response = client.post(
|
||||
"/strategy/retrieval/switch",
|
||||
json={
|
||||
"active_strategy": "enhanced",
|
||||
"grayscale": {
|
||||
"percentage": 150.0,
|
||||
},
|
||||
},
|
||||
headers=valid_headers,
|
||||
)
|
||||
|
||||
assert response.status_code in [400, 422]
|
||||
|
||||
|
||||
class TestRetrievalStrategyAPIFlow:
|
||||
"""
|
||||
[AC-AISVC-RES-01~15] Tests for complete API flow scenarios.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def valid_headers(self):
|
||||
return {
|
||||
"X-Tenant-Id": "test@ash@2026",
|
||||
"X-API-Key": "test-api-key",
|
||||
}
|
||||
|
||||
def test_complete_strategy_lifecycle(self, client, valid_headers):
|
||||
"""
|
||||
[AC-AISVC-RES-01~07] Test complete strategy lifecycle:
|
||||
1. Get current strategy
|
||||
2. Switch to enhanced
|
||||
3. Validate
|
||||
4. Rollback
|
||||
5. Verify back to default
|
||||
"""
|
||||
current = client.get(
|
||||
"/strategy/retrieval/current",
|
||||
headers=valid_headers,
|
||||
)
|
||||
assert current.status_code == 200
|
||||
assert current.json()["active_strategy"] == "default"
|
||||
|
||||
switch = client.post(
|
||||
"/strategy/retrieval/switch",
|
||||
json={
|
||||
"active_strategy": "enhanced",
|
||||
"grayscale": {"enabled": True, "percentage": 50.0},
|
||||
},
|
||||
headers=valid_headers,
|
||||
)
|
||||
assert switch.status_code == 200
|
||||
assert switch.json()["current_strategy"] == "enhanced"
|
||||
|
||||
validate = client.post(
|
||||
"/strategy/retrieval/validate",
|
||||
json={"active_strategy": "enhanced"},
|
||||
headers=valid_headers,
|
||||
)
|
||||
assert validate.status_code == 200
|
||||
|
||||
rollback = client.post(
|
||||
"/strategy/retrieval/rollback",
|
||||
headers=valid_headers,
|
||||
)
|
||||
assert rollback.status_code == 200
|
||||
assert rollback.json()["current_strategy"] == "default"
|
||||
|
||||
final = client.get(
|
||||
"/strategy/retrieval/current",
|
||||
headers=valid_headers,
|
||||
)
|
||||
assert final.status_code == 200
|
||||
assert final.json()["active_strategy"] == "default"
|
||||
|
||||
|
||||
class TestRetrievalStrategyAPIMissingTenant:
|
||||
"""
|
||||
Tests for API behavior without tenant ID.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_headers(self):
|
||||
return {"X-API-Key": "test-api-key"}
|
||||
|
||||
def test_current_without_tenant(self, client, api_key_headers):
|
||||
"""
|
||||
Missing X-Tenant-Id should return 400.
|
||||
"""
|
||||
response = client.get(
|
||||
"/strategy/retrieval/current",
|
||||
headers=api_key_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_switch_without_tenant(self, client, api_key_headers):
|
||||
"""
|
||||
Missing X-Tenant-Id should return 400.
|
||||
"""
|
||||
response = client.post(
|
||||
"/strategy/retrieval/switch",
|
||||
json={"active_strategy": "enhanced"},
|
||||
headers=api_key_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# 7年级到课赠礼
|
||||
|
||||
这是一张7年级到课赠礼的图片,可能包含到课奖励、学习用品或相关礼品的信息。图片可能展示赠礼的实物照片、礼品包装或相关的宣传内容。
|
||||
|
After Width: | Height: | Size: 466 KiB |
|
|
@ -0,0 +1,25 @@
|
|||
# s班小学3年级到课赠礼
|
||||
|
||||
这张图片是一张教育类宣传海报,主题为“飞跃领航计划”,核心内容是展示“专属福利大礼包”的五天完课福利,同时包含人物展示与视觉装饰元素。整体风格活泼,色彩明快,以浅蓝 - 浅绿渐变为主背景,搭配卡通元素增强亲和力。
|
||||
|
||||
## 1. 文字内容(按视觉顺序提取)
|
||||
- 标题区:左上角橙色标签“3阶”,主标题“飞跃领航计划”(“飞跃”为黑色粗体,“领航计划”为绿色粗体);下方绿色横幅“专属福利大礼包”。
|
||||
- 人物区:三位人物(两位女性、一位男性)并排站立,每人旁有蓝色标签标注姓名:**张婷婷**、**褚佳麟**、**王亚男**;人物下方黄色标签“思维 | 人文 | 剑桥”。
|
||||
- 福利列表区(白色背景框内,按天划分):
|
||||
- 第一天/完课福利:①《思维模块知识导图册》+《三年级口算14000题》
|
||||
- 第二天/完课福利:①精选双语动画电影20部(上10部) ②《应用题专项练习》+《小升初数学知识要点汇总》
|
||||
- 第三天/完课福利:①《知识大盘点+易错大集合》 ②《世界上下五千年》音频资料
|
||||
- 第四天/完课福利:①《考前高效培优知识梳理总复习》+《期末检测卷》2套
|
||||
- 第五天/完课福利:①精选双语动画电影20部(下10部)
|
||||
|
||||
## 2. 人物与视觉元素
|
||||
- 人物:三位形象正面、微笑,穿着职业装(女性为衬衫/西装,男性为浅灰西装),姿态自然,传递专业与亲和感。
|
||||
- 颜色:背景为浅蓝 - 浅绿渐变;标题文字黑、绿对比;人物标签蓝色;福利列表背景白色,文字黑色;卡通元素(底部橙子、礼物盒、小图标)色彩鲜艳(橙、黄、紫等),增加活泼感。
|
||||
- 布局:顶部为标题+人物展示区,中间为福利列表(分天排版,用圆点区分项目),底部为装饰性卡通元素(如橙子、礼物、电话/书本图标),整体结构清晰,信息层级分明。
|
||||
|
||||
## 3. 风格与意图
|
||||
海报风格偏向教育类宣传的“活泼专业”:通过卡通元素降低距离感,通过人物展示增强信任感,通过分天福利列表清晰传递“完课奖励”的核心信息,目标受众可能是中小学生或家长,旨在推广“飞跃领航计划”课程。
|
||||
|
||||
## 4. 额外说明
|
||||
- 视觉细节:底部卡通元素(如带笑脸的橙子、礼物盒)呼应“福利礼包”主题,增强趣味性;人物标签“思维 | 人文 | 剑桥”暗示课程涵盖多学科与国际化(剑桥)特色。
|
||||
- 信息逻辑:福利列表从“知识导图+口算”到“动画电影+音频”,再到“复习资料+试卷”,覆盖“知识输入 - 兴趣培养 - 复习巩固”全流程,体现课程设计的完整性。
|
||||
|
After Width: | Height: | Size: 482 KiB |
|
After Width: | Height: | Size: 618 KiB |
|
|
@ -0,0 +1,65 @@
|
|||
# 一部7年级课表
|
||||
|
||||
这是一张初一训练营的课程表,主题为"双语素养+科学思维",旨在帮助学生初一/初一定三年,学习不犯难。课程表包含四位主讲老师的信息和详细的课程安排,还提供了App下载二维码。
|
||||
|
||||
## 1. 课程主题与目标
|
||||
- 主标题:训练营 飞跃领航计划
|
||||
- 核心主题:双语素养+科学思维
|
||||
- 目标受众:初一/初一定三年,学习不犯难
|
||||
- 下载方式:扫码下载App听课(提供安卓和苹果版本二维码)
|
||||
|
||||
## 2. 主讲老师信息
|
||||
1. **陈久杰** - 科学探究主讲
|
||||
- 科学探究资深主讲老师
|
||||
- 12年线上与线下教学经验
|
||||
- 科学探究教学负责人
|
||||
- 《中考物理满分冲刺》主编
|
||||
- 人事人才网家庭教育高级指导师
|
||||
- 北京大学心理发展研修班进修
|
||||
|
||||
2. **毕玉琦** - 卓越双语主讲
|
||||
- 卓越双语学科教学总负责人和资深主讲老师
|
||||
- 全国巡回英语讲座200余场
|
||||
- 获新加坡南洋理工大学TESOL教学认证
|
||||
- 一线英语教学10余年
|
||||
- 培养中考英语高分学员10000+人
|
||||
|
||||
3. **阙红乾** - 思维逻辑主讲
|
||||
- 中考满分或保送学员41位清北学员12位
|
||||
- 12年授课经验,线上学员过百万
|
||||
- 海淀数学联赛顶尖师资
|
||||
- 全体教师赛课一等奖
|
||||
- S级主讲、教学大师奖、最佳教学效果奖、最受学生喜爱奖、最具课程吸引力奖、中考学员王者之师、金牌培训师
|
||||
|
||||
4. **张晓煜** - 人文博学主讲
|
||||
- 资深人文主讲、主讲培训师
|
||||
- 深耕语文一线教学14年
|
||||
- 高级脑力潜能开发师
|
||||
- 高级家庭教育指导师
|
||||
- 高级思维导图教师
|
||||
- 《初中语文考点一本通》《精准练语文》《21天预复习》等图书主编
|
||||
|
||||
## 3. 课程安排表
|
||||
|
||||
| 时间 | 科目 | 课程名称 | 课程内容 | 听课要求 |
|
||||
|------|------|----------|----------|----------|
|
||||
| 周四 19:55-20:55 | 规划讲座 | 学习规划讲座 | 【家长必听】初一全年学习规划及高效学习方法,帮助孩子领跑新征程! | 家长必听 |
|
||||
| 周五 18:55-19:55 | 思维逻辑(数) | 双中会 | 同学们对于线段双动点或者角的双角平分线理解存在较大难度,往往读不懂题。本节课帮助孩子从底层知识理解该类题目,快速搞定双角平分线或者双中点问题。 | 亲子共听 |
|
||||
| 周五 19:55-20:55 | 卓越双语(英) | 有根有据记词汇 | 驼哥用词根词缀法帮大家拆解单词,用底层逻辑法教大家辨别一词多义,让词汇记忆与积累不再发愁,词汇题目迎刃而解。 | 亲子共听 |
|
||||
| 周六 18:55-19:55 | 科学探究(物) | 隐力剧场:无形托举者 | 大气压强是一种看不见摸不到的物理现象,本节课运用实验让同学们在动手当中观察现象总结结论,理解物理概念。 | 亲子共听 |
|
||||
| 周六 19:55-20:55 | 卓越双语(英) | 有的放矢通关语法 | 对名词和数词的分类难以掌握、名词的变形和数词的用法感到困惑。本节课将通过"武术有师父"等驼哥专属口诀帮你解决初一语法中的两大词法问题,夯实语法基础,提升10分。本节课通过讲解成长类作文结构,帮助学生掌握事件描写和开头结尾技巧,并分享高分素材,让学生写作文从流水账成长到高分作文。 | 亲子共听 |
|
||||
| 周六 14:00-15:10 | 人文博学(语) | 从流水账到黄金屋 | 本章是孩子初中第一次接触几何证明题+几何辅助线构造+几何模型构造,对几何核心思维培养至关重要!本节课带你搭建几何模型思维,感受模型秒解的魅力! | 亲子共听 |
|
||||
| 周日 18:55-19:55 | 思维逻辑(数) | 拐点模型 | 本章是孩子初中第一次接触几何证明题+几何辅助线构造+几何模型构造,对几何核心思维培养至关重要!本节课带你搭建几何模型思维,感受模型秒解的魅力! | 亲子共听 |
|
||||
| 周日 19:55-20:55 | 科学探究(物) | 解密声音密码 | 通过实验法讲解抽象物理概念,让学生可以通透的理解声音的三大要素并能掌握核心的考点解释生活现象。启发物理思维,培养物理兴趣,为初二学习物理做好铺垫。 | 亲子共听 |
|
||||
|
||||
## 4. 设计特点
|
||||
- 色彩:浅蓝渐变背景,搭配橙色、绿色等明亮色彩,整体风格活泼专业
|
||||
- 布局:左侧为老师介绍,右侧为课程表,信息分区明确
|
||||
- 互动:提供二维码下载App,方便学生和家长听课
|
||||
- 听课要求:区分"家长必听"和"亲子共听",体现不同课程的参与方式
|
||||
|
||||
## 5. 课程特色
|
||||
- 覆盖多学科:科学探究、卓越双语、思维逻辑、人文博学
|
||||
- 时间安排合理:工作日晚上和周末安排课程
|
||||
- 实用性强:课程内容针对初中学习重点和难点
|
||||
- 方法指导:不仅教授知识,还提供学习方法指导
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
# 以前缺少学习动力,高途的直播课让孩子学习态度积极;对课程老师都很认可
|
||||
|
||||
这是一张用户评价截图,展示了家长对孩子在高途直播课学习情况的反馈。截图呈现了对话形式,包含多位家长的评价内容,主要表达了对课程设计和老师教学的认可。
|
||||
|
||||
## 1. 主要评价内容
|
||||
|
||||
### 左侧评价(陈琦家长)
|
||||
- **用户名**:陈琦(chengqi34)
|
||||
- **评价内容**:
|
||||
- "孩子最近状态咋样啊"
|
||||
- "上课挺认真的!就是正确率仅在及格线上一点。😊"
|
||||
- "你们的课程和老师都有趣!她很喜欢!😊"
|
||||
- "你们的二讲老师超级负责!课后不懂的作业,老师会打电话来指导!这个很让我感动!👍👍👍"
|
||||
- "也非常感谢您帮我找了三个这么优秀的二讲老师!👏👏👏"
|
||||
- "现在她还没养成预习、复习、订正的习惯。"
|
||||
- "如果这个习惯养成了,估计她后续的学习不会那么吃力了!上课更不会睡觉了!😊"
|
||||
- "你们的课程互动环节设计得非常好!能牢牢抓住她的注意力!👍👍👍"
|
||||
- "她现在上课不用我催促,非常自觉!😊"
|
||||
- "也多谢您提醒我让她上午上完课就睡觉。这样她一点半上课就不会犯困了,整个下午也有精神。😊"
|
||||
- "还是您懂因材施教!👍👍👍"
|
||||
- **时间**:8/16 17:37:55
|
||||
|
||||
### 右侧评价
|
||||
- **评价内容**:
|
||||
- "孩子的笔记越来越好了!能看出有进步"
|
||||
- "嗯嗯笔记比以前有进步,这得感谢小王老师的风趣的课堂氛围,与严格的监督与督促👍👍👍"
|
||||
- "以前缺少学习动力,高途的直播课让同学对学习态度积极😊"
|
||||
- "还是要看孩子练习中的问题!勇于探索与解决问题最重要啦~"
|
||||
- "这孩子就是缺少学习动力,现在咱们的直播氛围比较喜欢 所以好像学习上有点积极态度了。"
|
||||
- "学习本来就是很快乐的事情,没有那么难,勤奋多思考"
|
||||
- "感谢您的辛苦付出🌹🌹🌹"
|
||||
- "孩子的初中很好~孩子多勤奋一些,英语真不难,词汇多背,语法搞懂,多记笔记,多练习"
|
||||
|
||||
## 2. 视觉设计
|
||||
- **布局**:左右分栏的对话形式,模拟真实的聊天界面
|
||||
- **颜色**:浅灰色背景,文字为黑色,重点文字用红色突出
|
||||
- **表情**:包含多种表情符号(😊、👍、👏、🌹等),增加亲和力
|
||||
- **格式**:使用气泡对话框形式,模拟真实聊天场景
|
||||
|
||||
## 3. 评价特点
|
||||
- **真实感**:采用对话形式,增强可信度
|
||||
- **具体反馈**:包含具体的学习进步描述(笔记变好、上课认真等)
|
||||
- **多角度评价**:从不同家长角度反映课程效果
|
||||
- **情感表达**:包含感谢和积极的情感表达
|
||||
- **教学认可**:特别提到老师负责、课程有趣、互动设计好等优势
|
||||
|
||||
## 4. 营销意图
|
||||
- 展示真实用户评价,建立信任
|
||||
- 突出课程对学习动力的提升作用
|
||||
- 强调老师负责和课程设计优秀
|
||||
- 体现因材施教的教学理念
|
||||
- 展示学习进步的具体案例
|
||||
|
||||
这张评价截图有效地展示了高途直播课的教学效果和家长满意度,通过真实的对话形式传递课程价值。
|
||||
|
After Width: | Height: | Size: 602 KiB |
|
|
@ -0,0 +1,51 @@
|
|||
# 喜欢主讲老师,第一次上80很难得
|
||||
|
||||
这是一张用户评价截图,展示了家长对孩子在高途课程学习情况的反馈,特别提到了对杨易老师的喜爱和成绩提升的情况。
|
||||
|
||||
## 1. 主要评价内容
|
||||
|
||||
### 左侧评价
|
||||
- **用户反馈**:
|
||||
- "他是因为杨易老师了,所以杨易老师说啥他都积极响应"
|
||||
- "从娃上网课的状态就能看出,他其实蛮专注力一点问题都没有。"
|
||||
- "还是看老师讲的是不是他感兴趣的,上课方式是不是他喜欢的。"
|
||||
- "他上课的状态太好了,我忍不住拍视频给我妈看😊"
|
||||
- **系统回复**:
|
||||
- "哈哈哈,这个状态太喜人啦😊"
|
||||
|
||||
### 右侧评价
|
||||
- **用户反馈**:
|
||||
- "八下第一次上了80,很难得。比之前进步很多"
|
||||
- "他考试有进步的。期中考试84之前都是79左右,80分很难得"
|
||||
- "哇塞,太惊喜了🌹🌹八下语文还是比较难的,取得了进步,太不错了,看得出来孩子真的投入了,努力了"
|
||||
- "咱们继续保持,加油加油,我会持续关注瀚清😊"
|
||||
- "我跟天翼老师也分享一下下喜报😊"
|
||||
- "之前一直在其他机构,换了高途提升很多"
|
||||
- "确实高途的课程适合他,也是他选择换课程的。"
|
||||
- "必须分享"
|
||||
- "嗯嗯,孩子也很努力,我看好孩子,后面也会越来越好的,咱们一起相互配合,一起加油😊"
|
||||
- "妈妈也是第一时间被天翼老师吸引"
|
||||
- "我也是第一时间被天翼老师吸引才分享给刘瀚清试课的"
|
||||
|
||||
## 2. 视觉设计
|
||||
- **布局**:左右分栏的对话形式,模拟真实的聊天界面
|
||||
- **颜色**:浅灰色背景,文字为黑色,重点文字用红色突出
|
||||
- **表情**:包含多种表情符号(😊、🌹等),增加亲和力
|
||||
- **格式**:使用气泡对话框形式,模拟真实聊天场景
|
||||
- **图片**:左侧包含一个小图片,显示上课场景
|
||||
|
||||
## 3. 评价特点
|
||||
- **具体成绩提升**:明确提到"八下第一次上了80",显示具体的学习进步
|
||||
- **老师影响**:强调杨易老师对孩子学习状态的积极影响
|
||||
- **课程转换**:提到从其他机构转到高途课程
|
||||
- **家长认可**:表达对天翼老师的认可和吸引
|
||||
- **真实感**:采用对话形式,增强可信度
|
||||
|
||||
## 4. 营销意图
|
||||
- 展示真实用户评价,建立信任
|
||||
- 突出老师对学生学习的积极影响
|
||||
- 体现课程效果的具体案例
|
||||
- 展示学生成绩提升的实例
|
||||
- 强调课程转换后的积极变化
|
||||
|
||||
这张评价截图有效地展示了高途课程的教学效果和家长满意度,通过真实的对话形式传递课程价值,特别突出了老师对学生学习状态的积极影响。
|
||||
|
After Width: | Height: | Size: 296 KiB |
|
|
@ -0,0 +1,54 @@
|
|||
# 成绩提升30多分
|
||||
|
||||
这是一张用户评价截图,展示了学生在高途课程学习后取得显著成绩提升的情况,特别突出了英语成绩的进步。
|
||||
|
||||
## 1. 主要评价内容
|
||||
|
||||
### 左侧评价
|
||||
- **用户反馈**:
|
||||
- "老师我英语从60多进步到80多"
|
||||
- "老师我们出成绩了,考的是789章的,比上一次考试进步了31分!"
|
||||
- "谢谢老师夸奖"
|
||||
- "这段时间我上课认真听课,单词好好背诵,然后也认真听褚帅老师的课"
|
||||
- "总分从405进步到432了"
|
||||
- "进步的分数里,英语占了一大半"
|
||||
|
||||
### 右侧评价
|
||||
- **系统回复**:
|
||||
- "我就知道你是可以的😊"
|
||||
- "你怎么进步了,分享一下这段时间的自己规划"
|
||||
- "🌹"
|
||||
|
||||
### 底部信息
|
||||
- **成绩展示**:
|
||||
- "进步30多分"(红色大字)
|
||||
- 英语成绩:70分(可能为本次成绩)
|
||||
- 英语成绩:66分(可能为上次成绩)
|
||||
- "老师,我这次英语成绩比上次进步了30分。"
|
||||
- "然后作文也有很大进步。"
|
||||
- "哇塞"
|
||||
|
||||
## 2. 视觉元素
|
||||
- **图片**:左侧包含两张图片
|
||||
- 上方:学生成绩单照片
|
||||
- 下方:Lays薯片卡通形象
|
||||
- **卡通形象**:右侧有一个可爱的卡通小熊形象,增加亲和力
|
||||
- **成绩数据**:明确显示具体的分数提升(31分,总分提升27分)
|
||||
- **颜色**:浅灰色背景,文字为黑色,重点文字用红色突出
|
||||
- **布局**:左右分栏的对话形式,底部有成绩展示区域
|
||||
|
||||
## 3. 评价特点
|
||||
- **具体成绩数据**:明确提到进步31分,总分从405到432
|
||||
- **学科重点**:特别强调英语成绩的提升(从60多到80多)
|
||||
- **学习态度**:提到认真听课和背诵单词的良好学习习惯
|
||||
- **多科进步**:不仅英语进步,作文也有很大进步
|
||||
- **真实感**:包含成绩单照片,增强可信度
|
||||
|
||||
## 4. 营销意图
|
||||
- 展示具体的成绩提升数据,证明课程效果
|
||||
- 突出英语学科的显著进步
|
||||
- 强调学生良好的学习态度和习惯
|
||||
- 通过真实成绩单建立信任
|
||||
- 展示多学科综合提升的效果
|
||||
|
||||
这张评价截图通过具体的成绩数据和真实的成绩单,有效地展示了高途课程对学生成绩的显著提升作用,特别是英语学科的进步。
|
||||
|
After Width: | Height: | Size: 288 KiB |
|
|
@ -0,0 +1,56 @@
|
|||
# 提升成绩
|
||||
|
||||
这是一张即时通讯软件的聊天记录截图,内容围绕学生与老师的对话,核心是学生成绩进步的反馈与交流。画面以聊天气泡为主要视觉元素,通过文字、表情、图片及少量图表呈现信息,整体风格偏向日常沟通,带有积极鼓励的氛围。
|
||||
|
||||
## 1. 文字内容(按出现顺序提取)
|
||||
- 时间戳:`19:38`、`12:10`
|
||||
- 对话文本:
|
||||
- "老师老师,这次期中我进步了11分"
|
||||
- "强"(配图,黑色背景+白色艺术字)
|
||||
- "进班1个月进步11分 李琪老师太牛了!!!"
|
||||
- "你太棒啦~宝"
|
||||
- "安在的半期成绩"
|
||||
- "进步明显"
|
||||
- "感谢老师,英语有明显进步"
|
||||
- "希望在老师的教导下再进步点"
|
||||
- "太棒啦~继续努力呢👍"
|
||||
- "麻烦妈妈吧安在同学的试卷和答题卡发给孟孟老师呢"
|
||||
- "看一下本次出现的问题以及接下来的问题解决"
|
||||
- "五一假期快乐呀~我看咱宝第二讲的作业还是没有完成提交,尽量在五一假期结束之前完成学习哦😊 噢噢好的写了没交我现在还在外面回去我就交"
|
||||
- "老师我的数学成绩出来了"
|
||||
- "86/120"(红色边框突出显示)
|
||||
- "李琪老师太牛了!!! 进班不到2个月,提升42分!"
|
||||
- "嘿嘿"
|
||||
- "进步了是不是,我记得你进班的时候四五十分😀"
|
||||
- "对呀"
|
||||
- "之前43.5"(红色边框突出显示)
|
||||
|
||||
## 2. 视觉元素与布局
|
||||
- **布局**:聊天气泡呈垂直堆叠,不同气泡颜色区分发言者(如灰色为系统/非当前用户消息,蓝色为当前用户消息,红色为强调文字/成绩)。部分关键信息(如成绩"86/120""之前43.5")用红色边框突出,增强视觉焦点。
|
||||
- **颜色**:
|
||||
- 聊天气泡底色:灰色(系统消息)、白色(用户消息);
|
||||
- 文字颜色:黑色(常规文字)、红色(强调/成绩、感叹句);
|
||||
- 表情符号:包含😊、👍、😀等,增加情感表达。
|
||||
- **图表与图片**:
|
||||
- 中间位置有一张小图表(疑似成绩趋势图,含折线/柱状元素),配合文字"安在的半期成绩"展示进步;
|
||||
- "强"字配图(黑色背景+白色艺术字),强化"进步"的积极情绪。
|
||||
- **人物与关系**:对话涉及"李琪老师""孟孟老师"(教师)与"安在""咱宝"(学生),通过"妈妈"代为传递信息,体现家校沟通场景。
|
||||
|
||||
## 3. 语境与信息逻辑
|
||||
对话围绕**学生成绩进步**展开:学生汇报期中/半期成绩提升(如"进步11分""提升42分"),老师或家长表达鼓励("太牛了""继续努力"),同时涉及作业提交("五一假期完成学习")和后续问题解决("看本次问题及解决"),整体传递出"进步—鼓励—后续规划"的沟通逻辑。
|
||||
|
||||
## 4. 评价特点
|
||||
- **具体成绩数据**:明确提到进步11分和42分,以及具体的分数(86/120,之前43.5)
|
||||
- **老师认可**:特别提到李琪老师的优秀教学
|
||||
- **多科进步**:涉及期中成绩和数学成绩
|
||||
- **家校沟通**:通过妈妈传递信息,体现家校协作
|
||||
- **积极氛围**:使用感叹号、表情符号和鼓励性语言
|
||||
|
||||
## 5. 营销意图
|
||||
- 展示具体的成绩提升数据,证明课程效果
|
||||
- 突出老师的教学能力
|
||||
- 强调短期内的显著进步
|
||||
- 展示家校合作的积极效果
|
||||
- 通过真实对话建立信任
|
||||
|
||||
这张评价截图通过具体的成绩数据和真实的对话,有效地展示了高途课程对学生成绩的显著提升作用,特别是短期内的快速进步。
|
||||
|
After Width: | Height: | Size: 348 KiB |
|
|
@ -0,0 +1,54 @@
|
|||
# 物理拿下高分
|
||||
|
||||
这是一张关于"物理成绩提升"的宣传类信息图,以聊天记录和喜报形式呈现,核心内容是展示学生在物理学习上的成绩进步,突出教学效果。整体设计通过文字、表情、红色标注等元素传递积极信息,布局清晰,重点突出。
|
||||
|
||||
## 1. 文字内容(按位置提取)
|
||||
- **顶部标题**:`物理成绩拿下高分`(红色大字体,位于页面最上方,字体加粗,视觉冲击力强)。
|
||||
- **左侧红色框内聊天记录**:
|
||||
- `1😊`
|
||||
- `物理也挺难`
|
||||
- `😊`
|
||||
- **中间喜报区域**:
|
||||
- `山东济南中考高分喜报!`(红色字体,突出地域和事件类型)
|
||||
- `81/90分,妥妥拿下~`(红色字体,强调具体成绩)
|
||||
- **右侧聊天记录**:
|
||||
- `你八十几?`(浅蓝色背景,模拟聊天气泡)
|
||||
- `太有石粒辣`(黑色字体,搭配哭笑表情🤣,表达惊喜)
|
||||
- `老师孩子月考成绩出来了比年前期末考试成绩84分提高7分。这次91分满分是100分班级第一!孩子自从上咱物理课兴趣激情满满。还说让老师给推荐一下教辅🙏谢谢老师`(其中`84分提高7分`、`91分`、`班级第一`、`上咱物理课兴趣激情满满`被红色框标注,突出成绩提升关键信息)
|
||||
- `山东济南`(红色字体,标注地域,强化地域关联)
|
||||
- **下方聊天记录**:
|
||||
- `刚扣9分,高分奖励这不到手了吗`
|
||||
- `给你发2200学币,看看有没有可以兑换的啊`
|
||||
- **页面底部**:`高途好老师`(黑色字体,标注机构/老师名称)
|
||||
|
||||
## 2. 视觉元素与布局
|
||||
- **颜色**:
|
||||
- 红色:用于标题、喜报文字、重点成绩标注,传递喜庆、强调的视觉感受;
|
||||
- 黑色:用于普通聊天文字,清晰易读;
|
||||
- 浅蓝色:用于聊天气泡背景,模拟真实聊天界面;
|
||||
- 表情:黄色底色+蓝色眼泪的哭笑表情,增强情感表达。
|
||||
- **布局**:
|
||||
- 页面分为左、中、右三栏:左侧是简短聊天框,中间是核心喜报,右侧是详细成绩说明+地域标注,底部是机构名称,整体结构清晰,信息层级分明。
|
||||
- **其他元素**:
|
||||
- 红色边框:左侧聊天框和右侧重点成绩文字使用红色边框,引导视觉焦点;
|
||||
- 聊天气泡:模拟真实对话场景,增强代入感;
|
||||
- 表情符号:增加情感色彩,使内容更生动。
|
||||
|
||||
## 3. 语境与信息逻辑
|
||||
这张图片的目的是**宣传"高途好老师"的物理辅导效果**,通过真实(或模拟)的聊天记录和成绩数据,展示学生成绩提升(从84分到91分,提高7分,班级第一),强调课程对学习兴趣的激发("兴趣激情满满"),同时通过地域标注(山东济南)和机构名称(高途好老师)强化品牌关联。设计上通过红色、重点标注等视觉手段,快速传递"成绩提升"的核心信息,吸引目标用户(学生/家长)关注。
|
||||
|
||||
## 4. 评价特点
|
||||
- **具体成绩数据**:明确提到从84分提高到91分,提高7分,班级第一
|
||||
- **学科重点**:特别突出物理学科的成绩提升
|
||||
- **兴趣培养**:强调课程对学习兴趣的激发
|
||||
- **地域案例**:提供山东济南的具体案例
|
||||
- **奖励机制**:提到学币奖励,体现激励机制
|
||||
|
||||
## 5. 营销意图
|
||||
- 展示具体的物理成绩提升数据,证明课程效果
|
||||
- 突出老师的教学能力
|
||||
- 强调学习兴趣的培养
|
||||
- 提供地域性成功案例
|
||||
- 展示奖励机制,增加吸引力
|
||||
|
||||
这张评价截图通过具体的成绩数据和真实的对话形式,有效地展示了高途物理课程对学生成绩的显著提升作用,特别是对学习兴趣的激发。
|
||||
|
After Width: | Height: | Size: 259 KiB |
|
|
@ -0,0 +1,51 @@
|
|||
# 高途让家长放心,早点认识高途就好了
|
||||
|
||||
这是一张用户评价截图,展示了家长对高途课程的认可和满意,表达了"早点认识高途就好了"的遗憾心情。
|
||||
|
||||
## 1. 主要评价内容
|
||||
|
||||
### 左侧评价
|
||||
- **用户反馈**:
|
||||
- "谢谢老师,还是老师指导有方"
|
||||
- "跟着高途学,家长放心"
|
||||
- "暑假跟着高途,孩子语文妈妈我就不担心了"
|
||||
- "我们一起为孩子能考上重点高中而努力😊😊😊"
|
||||
|
||||
### 右侧评价
|
||||
- **用户反馈**:
|
||||
- "好的"
|
||||
- "写的很好哦!"
|
||||
- "表扬!希望能有所收获"
|
||||
- "您太客气啦!"
|
||||
- "接下来有问题都随时来问我哈"
|
||||
- "早点认识高途就好了"
|
||||
- "早点认识高途就好了"
|
||||
- "走了不少弯路"
|
||||
- "嗯嗯,慢慢来"
|
||||
- "语文学习本身就是一个潜移默化的过程"
|
||||
- "走了不少弯路"
|
||||
- "那个抖音上那种阅读,也买过,没有系统的教学,靠娃自觉还是不行,家长又不懂"
|
||||
- "嗯嗯!必须的!最后一年,拼尽全力!💪💪"
|
||||
|
||||
## 2. 视觉设计
|
||||
- **布局**:左右分栏的对话形式,模拟真实的聊天界面
|
||||
- **颜色**:浅灰色背景,文字为黑色,重点文字用红色突出
|
||||
- **表情**:包含多种表情符号(😊、💪等),增加亲和力
|
||||
- **格式**:使用气泡对话框形式,模拟真实聊天场景
|
||||
- **重点文字**:关键语句如"跟着高途学,家长放心"和"早点认识高途就好了"用红色大字突出显示
|
||||
|
||||
## 3. 评价特点
|
||||
- **家长认可**:明确表达对高途课程的信任和放心
|
||||
- **时间紧迫感**:提到"最后一年,拼尽全力",体现中考临近的紧迫感
|
||||
- **对比体验**:提到之前在其他平台学习的不足,突出高途的系统教学优势
|
||||
- **遗憾情绪**:表达"早点认识高途就好了"和"走了不少弯路"的遗憾
|
||||
- **积极态度**:尽管有遗憾,但仍然保持积极的学习态度
|
||||
|
||||
## 4. 营销意图
|
||||
- 展示家长对课程的信任和满意
|
||||
- 突出高途的系统教学优势
|
||||
- 强调早期认识高途的重要性
|
||||
- 体现中考临近的紧迫感和学习动力
|
||||
- 通过真实对话建立信任
|
||||
|
||||
这张评价截图通过家长的真情实感,有效地展示了高途课程给家长带来的安心感和信任度,同时也暗示了早期选择高途的重要性。
|
||||
|
After Width: | Height: | Size: 172 KiB |
|
After Width: | Height: | Size: 213 KiB |
|
|
@ -0,0 +1,66 @@
|
|||
# 引导加辅导老师微信图片
|
||||
|
||||
这是一张操作指引图,用于指导用户在"高途"APP中查找第二讲辅导老师的微信,包含三张手机界面截图与底部的步骤说明。整体背景为暖色调橙黄色渐变,界面以白色为主,搭配橙色、红色等强调色,通过箭头指示操作流程,视觉上清晰醒目。
|
||||
|
||||
## 1. 整体布局与视觉风格
|
||||
- **背景**:顶部至底部为橙黄色渐变,营造活泼、醒目的氛围,突出指引内容。
|
||||
- **截图排列**:三张手机界面截图横向排列,每张截图通过橙色/黄色箭头指示操作逻辑,布局清晰。
|
||||
|
||||
## 2. 三张截图的文字与元素(从左至右)
|
||||
|
||||
### 左图("高途"APP首页/课程列表页)
|
||||
- **顶部状态栏**:显示时间"上午10:20"、网络速度"111K/s"、电池/信号等图标。
|
||||
- **"今日学习推荐"板块**:
|
||||
- 课程1:"不用读全文阅读大满贯" → 主讲【新9阶】双语素养+科学思维训练营,5月2日18:55开课。
|
||||
- 课程2:"大力出奇迹" → 第2讲,【新9阶】双语素养+科学思维训练营,5月2日19:56开课。
|
||||
- **"全部课程"板块**:
|
||||
- 课程标题:"【新9阶】双语素养+科学思维训练营" → 共7节·未学习,5月2日18:55开课。
|
||||
- 主讲老师头像:王冰、韩瑞05。
|
||||
- **底部导航栏**:首页、学习、**上课**(红色图标,当前选中)、发现、我的。
|
||||
|
||||
### 中图(课程详情页)
|
||||
- **顶部状态栏**:时间"上午10:20"、网络速度"0.3K/s"。
|
||||
- **课程标题**:【新9阶】双语素养+科学思维训练营 → 有效期至2024年06月03日。
|
||||
- **主讲老师**:王冰、李雪冬、王泽龙(头像+"主讲老师"标签)。
|
||||
- **功能入口**:学习资料、缓存课程、错题本(图标+文字)。
|
||||
- **学习进度**:听课进度0%、听课时长0分。
|
||||
- **"课前准备"板块**:
|
||||
- 橙色按钮:"添加二讲老师微信"(核心操作入口)。
|
||||
- 选项:"关注公众号收取报告" → 右侧"去关注"链接。
|
||||
- **课程列表**:
|
||||
- 第1讲:"不用读全文阅读大满贯" → 05月02日周四18:55 - 19:55 | 未开始。
|
||||
- 第2讲:"大力出奇迹" → 05月02日周四19:56 - 21:01 | 未开始。
|
||||
- 第3讲:"全等模型狂想曲"(部分可见)。
|
||||
|
||||
### 右图("高途成长助手"页面)
|
||||
- **顶部状态栏**:时间"上午10:20"、网络速度"80.4K/s"。
|
||||
- **页面标题**:高途成长助手 → 评分2.5。
|
||||
- **倒计时与提示**:剩余04:56:8 → "请务必添加老师 否则无法上课"(红色警示文字)。
|
||||
- **老师信息**:高途助教老师 专属(卡通头像+"专属"标签)。
|
||||
- **二维码区域**:提示"长按二维码,联系老师"。
|
||||
- **底部按钮**:全程辅导答疑、定制学习规划。
|
||||
|
||||
## 3. 底部操作步骤文字
|
||||
- 第一步,下载"高途"APP
|
||||
- 第二步,登录后,点击下方"上课"按钮
|
||||
- 第三部,点击已报名课程("第三部"疑似"第三步"笔误)
|
||||
- 第四步,点击"添加二讲老师微信"按钮,自动跳转到微信添加
|
||||
|
||||
## 4. 语境与信息逻辑
|
||||
这张图是教育类APP"高途"的运营指引,核心目标是引导用户通过APP完成"添加第二讲辅导老师微信"的操作,确保课程学习顺利进行。三张截图对应"进入上课页面→选择课程→添加老师微信"的逻辑链,步骤简洁、视觉重点突出(如红色"上课"按钮、橙色"添加老师微信"按钮),符合用户操作习惯。暖色调背景与醒目的文字/按钮设计,提升了信息传递效率,避免用户因未添加老师而无法上课。
|
||||
|
||||
## 5. 设计特点
|
||||
- **操作指引清晰**:通过三张截图展示完整的操作流程
|
||||
- **视觉重点突出**:使用红色和橙色强调关键按钮和操作
|
||||
- **信息层级分明**:不同功能区域通过颜色和布局区分
|
||||
- **用户友好**:包含倒计时和警示信息,提醒用户及时操作
|
||||
- **品牌展示**:展示高途APP的界面和功能
|
||||
|
||||
## 6. 营销意图
|
||||
- 指导用户正确使用APP功能
|
||||
- 确保用户能够顺利添加辅导老师
|
||||
- 展示APP的用户界面和功能
|
||||
- 强调添加老师的重要性
|
||||
- 提升用户体验和课程参与度
|
||||
|
||||
这张引导图通过清晰的步骤和视觉设计,有效地指导用户完成关键操作,确保课程学习的顺利进行。
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
# 素养小学到课赠礼
|
||||
|
||||
这是一张小学素养课程的赠礼海报,主题为"飞跃领航计划",展示3-6阶学生的专属福利大礼包。海报设计活泼明快,以浅蓝和浅绿为主色调,包含三位主讲老师和四天的完课福利。
|
||||
|
||||
## 1. 课程主题与目标
|
||||
- **阶段**:3-6阶(橙色标签)
|
||||
- **主标题**:飞跃领航计划("飞跃"为黑色粗体,"领航计划"为绿色粗体)
|
||||
- **核心主题**:专属福利大礼包(绿色横幅)
|
||||
- **课程标签**:思维 | 人文 | 脑力(黄色横幅)
|
||||
|
||||
## 2. 主讲老师信息
|
||||
海报展示了三位主讲老师:
|
||||
1. **陈君** - 左侧女性老师,穿着灰色西装
|
||||
2. **杨易** - 中间男性老师,穿着黑色中式服装,手持折扇
|
||||
3. **白马** - 右侧男性老师,穿着深色西装
|
||||
|
||||
## 3. 福利列表(四天完课福利)
|
||||
|
||||
### 第一天/完课福利
|
||||
- 《脑王数独秘籍》
|
||||
|
||||
### 第二天/完课福利
|
||||
1. NCTE词汇表(带翻译)
|
||||
2. 《世界上下五千年》音频资料(上)
|
||||
|
||||
### 第三天/完课福利
|
||||
- 《世界上下五千年》音频资料(中)
|
||||
|
||||
### 第四天/完课福利
|
||||
1. 《最强大脑同款》-图形推理秘籍
|
||||
2. BBC自然拼读益智动画
|
||||
|
||||
## 4. 视觉设计
|
||||
- **颜色**:浅蓝渐变背景,搭配橙色、绿色、黄色等明亮色彩
|
||||
- **布局**:顶部为标题和老师展示区,中间为福利列表,底部为装饰性卡通元素
|
||||
- **卡通元素**:包含可爱的橙子形象和礼物盒图标,增加活泼感
|
||||
- **文字风格**:标题使用大号粗体字,福利列表使用清晰的编号和项目符号
|
||||
|
||||
## 5. 课程特色
|
||||
- **多学科覆盖**:包含思维训练、人文知识、脑力开发
|
||||
- **多样化资源**:包括书籍、音频资料、动画等不同形式的学习材料
|
||||
- **循序渐进**:四天的福利安排,每天都有新的学习内容
|
||||
- **趣味性**:结合《最强大脑》和BBC动画等有趣内容,增加学习吸引力
|
||||
|
||||
## 6. 营销意图
|
||||
- 展示课程的丰富福利,吸引学生参与
|
||||
- 突出多学科素养培养
|
||||
- 通过具体的学习材料展示课程价值
|
||||
- 强调循序渐进的学习方式
|
||||
- 利用知名节目和品牌(最强大脑、BBC)增加可信度
|
||||
|
||||
这张海报通过清晰的结构和丰富的福利内容,有效地展示了小学素养课程的价值,特别强调了思维、人文和脑力培养的综合素养教育。
|
||||
|
After Width: | Height: | Size: 451 KiB |
|
After Width: | Height: | Size: 270 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 8化试听-袁媛
|
||||
|
||||
这是一张初中化学试听课的海报图片,主讲老师是袁媛。海报应该包含课程介绍、试听时间、报名方式等相关信息。图片可能包含化学相关的图形元素、课程亮点等内容。
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# 8年级_思维_佳美_数与式找规律试听课(新初一初二通用)
|
||||
|
||||
这是一张初中数学试听课海报,主题是"数与式找规律",适合8年级学生,由佳美老师主讲。海报应该包含数学规律题目的示例、课程特色、试听时间安排等信息。图片可能包含数学公式、几何图形或相关的教学元素。
|
||||
|
After Width: | Height: | Size: 599 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 8数试听-YOYO游剑荷-角度模型
|
||||
|
||||
这是一张初中数学试听课海报,主题是"角度模型",由YOYO游剑荷老师主讲。海报应该包含角度相关的几何图形、解题方法、课程亮点等内容。图片可能展示各种角度模型的应用和示例。
|
||||
|
After Width: | Height: | Size: 358 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 新8思维刘璐——全等三角形与辅助线
|
||||
|
||||
这是一张初中数学试听课海报,主题是"全等三角形与辅助线",由刘璐老师主讲。海报应该包含全等三角形的几何图形、辅助线的画法、解题技巧等内容。图片可能展示各种全等三角形的示例和辅助线的应用。
|
||||
|
After Width: | Height: | Size: 1.2 MiB |
|
After Width: | Height: | Size: 250 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 7物试听-郭志强-基础版
|
||||
|
||||
这是一张初中物理试听课海报,主题是基础物理知识,由郭志强老师主讲。海报应该包含物理基础概念、实验演示、课程内容介绍等信息。图片可能包含物理实验装置、公式或相关的教学元素。
|
||||
|
After Width: | Height: | Size: 243 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 7阶试听-李雪冬-声学
|
||||
|
||||
这是一张初中物理试听课海报,主题是"声学",由李雪冬老师主讲。海报应该包含声音的产生、传播、特性等声学知识,以及相关的实验和演示内容。图片可能展示声学相关的图形、公式或实验装置。
|
||||
|
After Width: | Height: | Size: 246 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 8物试听-李雪冬-光的折射
|
||||
|
||||
这是一张初中物理试听课海报,主题是"光的折射",由李雪冬老师主讲。海报应该包含光的折射原理、相关公式、实验演示等内容。图片可能展示光的折射现象、透镜、棱镜等光学元素。
|
||||
|
After Width: | Height: | Size: 273 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 9物试听-电功率
|
||||
|
||||
这是一张初中物理试听课海报,主题是"电功率",适合9年级学生。海报应该包含电功率的计算、电路分析、相关公式等内容。图片可能展示电路图、电表、功率计算示例等元素。
|
||||
|
After Width: | Height: | Size: 275 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 9物试听-韩盛乔-基础版
|
||||
|
||||
这是一张初中物理试听课海报,主题是基础物理知识,由韩盛乔老师主讲,适合9年级学生。海报应该包含物理基础概念、公式、实验等内容。图片可能展示物理实验装置、公式或相关的教学元素。
|
||||
|
After Width: | Height: | Size: 277 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 9物试听-韩盛乔-浙教版
|
||||
|
||||
这是一张初中物理试听课海报,主题是浙教版物理课程,由韩盛乔老师主讲,适合9年级学生。海报应该包含浙教版物理教材的相关内容、课程特色、试听信息等。图片可能展示浙教版教材的元素或相关的教学资料。
|
||||
|
After Width: | Height: | Size: 627 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 7英试听-张丹丹-词汇—易混词辨析
|
||||
|
||||
这是一张初中英语试听课海报,主题是"词汇—易混词辨析",由张丹丹老师主讲,适合7年级学生。海报应该包含英语易混词汇的对比、辨析方法、例句等内容。图片可能展示词汇对比表格、相关例句或教学元素。
|
||||
|
After Width: | Height: | Size: 669 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 8英完形填空大招试听-张丹丹PNG
|
||||
|
||||
这是一张初中英语试听课海报,主题是"完形填空大招",由张丹丹老师主讲,适合8年级学生。海报应该包含完形填空解题技巧、常用方法、例题解析等内容。图片可能展示完形填空题目示例、解题步骤或相关的教学元素。
|
||||
|
After Width: | Height: | Size: 681 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 王冰老师试听课—话题写作(难忘经历类)
|
||||
|
||||
这是一张初中英语试听课海报,主题是"话题写作(难忘经历类)",由王冰老师主讲。海报应该包含英语话题写作的技巧、范文示例、写作方法等内容。图片可能展示写作模板、相关例句或教学元素。
|
||||
|
After Width: | Height: | Size: 411 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 8年级-虫子-试听
|
||||
|
||||
这是一张初中语文试听课海报,由虫子老师主讲,适合8年级学生。海报应该包含语文课程介绍、试听内容、教学特色等信息。图片可能展示语文相关的元素,如古诗词、文学作品或教学资料。
|
||||
|
After Width: | Height: | Size: 398 KiB |
|
|
@ -0,0 +1,3 @@
|
|||
# 写作试听课-虫子老师
|
||||
|
||||
这是一张初中语文试听课海报,主题是写作课程,由虫子老师主讲。海报应该包含写作技巧、范文示例、写作方法等内容。图片可能展示写作相关的元素,如作文示例、写作技巧图表或教学资料。
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# 小学数学杨易
|
||||
|
||||
这是一张小学数学试听课海报,由杨易老师主讲。海报应该包含小学数学课程介绍、试听内容、教学特色等信息。图片可能展示数学相关的元素,如数字、图形、算式或教学资料。
|
||||