Compare commits

...

13 Commits

Author SHA1 Message Date
MerCry a0044c4c42 [AC-DOCS] docs: 更新元数据和检索策略文档
- 更新 metadata-slot-prompt-recommendation.md 元数据槽位提示词推荐
- 更新 v0.9.0 检索嵌入策略规范文档
- 更新需求文档、设计文档和任务文档
2026-03-11 19:13:25 +08:00
MerCry a61fb72d2b [AC-SCRIPTS] chore: 新增临时工具脚本
- 新增 tmp_fix_metadata_cn.py 用于修复元数据中文编码
- 新增 tmp_kb_transform.py 用于知识库数据转换
- 新增 tmp_pack_kb_for_import.py 用于打包知识库导入数据
2026-03-11 19:12:41 +08:00
MerCry 60e16d65c9 [AC-MIGRATION] feat(db): 新增用户记忆迁移和工具脚本
- 新增 003_user_memories 迁移脚本支持用户记忆表
- 新增 clear_kb_vectors 脚本用于清理知识库向量
- 新增 svg 资源目录
2026-03-11 19:11:44 +08:00
MerCry a6276522c8 [AC-TEST] test: 新增单元测试和集成测试
- 新增 test_image_parser 图片解析器测试
- 新增 test_llm_multi_usage_config LLM 多用途配置测试
- 新增 test_markdown_chunker Markdown 分块测试
- 新增 test_metadata_auto_inference 元数据推断测试
- 新增 test_mid_dialogue_integration 对话集成测试
- 新增 test_retrieval_strategy 检索策略测试
- 新增 test_retrieval_strategy_integration 检索策略集成测试
2026-03-11 19:10:05 +08:00
MerCry 1490235b8f [AC-API-UPDATE] feat(api): 更新 API 端点和实体模型
- 更新 dialogue API 支持新的对话功能
- 更新 share_page API 优化分享页面
- 更新 main.py 注册新的路由模块
- 更新 entities 模型添加新字段
2026-03-11 19:07:03 +08:00
MerCry 9196247578 [AC-METADATA-INFERENCE] feat(metadata): 新增元数据自动推断服务
- 新增 metadata_auto_inference_service 实现元数据自动推断
- 新增 kb_metadata_inference 提供知识库元数据推断工具
- 支持从文档内容自动提取元数据字段
- 集成缓存机制提升推断效率
2026-03-11 19:06:21 +08:00
MerCry b3343f9e52 [AC-RETRIEVAL-STRATEGY] feat(retrieval): 新增检索策略路由服务
- 新增 strategy_service 实现检索策略路由核心逻辑
- 新增 strategy_metrics 提供策略性能指标收集
- 新增 strategy_audit 提供策略审计日志功能
- 新增 retrieval_strategy API 端点支持策略管理
- 支持多种检索策略的动态切换和监控
2026-03-11 19:02:40 +08:00
MerCry 6fec2a755a [AC-AGENT-ENHANCE] feat(mid): 增强 Agent 编排器和工具
- 优化 agent_orchestrator 的系统提示词指导
- 改进 kb_scene 参数的自动注入逻辑
- 增强 kb_search_dynamic_tool 的元数据处理
- 优化 memory_recall_tool 的记忆召回逻辑
- 更新 memory_adapter 的用户记忆模型
2026-03-11 19:01:51 +08:00
MerCry e45396e1e4 [AC-MEMORY-SUMMARY] feat(mid): 新增对话记忆摘要生成器
- 新增 MemorySummaryGenerator 用于生成对话摘要
- 新增 memory_summary_prompt 提供摘要生成提示词模板
- 支持将长对话历史压缩为简洁摘要
- 更新 mid 服务模块导出
2026-03-11 19:01:06 +08:00
MerCry e9de808969 [AC-KB-ENHANCE] feat(kb): 增强 KB 向量日志和元数据过滤功能
- 新增 KB 向量日志配置项 kb_vector_log_enabled 和 kb_vector_log_path
- 新增 KB 向量日志记录器支持滚动日志文件
- 增强 Qdrant 元数据过滤支持操作符格式 (\, \)
- 支持 MatchAny 实现多值匹配
- 新增图片文件索引支持
2026-03-11 18:57:27 +08:00
MerCry 4de2a2aece [AC-DOC-PARSER] feat(document): 新增图片和 Markdown 解析器
- 新增 ImageParser 支持图片文件解析
- 新增 MarkdownParser 支持 Markdown 文件解析
- 新增 MarkdownChunker 实现 Markdown 智能分块
- 支持按标题、段落、代码块等元素类型分块
- 更新 document 模块导出和工厂方法
2026-03-11 18:56:43 +08:00
MerCry b3680bda8a [AC-LLM-MULTI] feat(llm): 实现 LLM 多用途配置功能
- 新增 LLMUsageType 枚举支持 chat 和 kb_processing 两种用途
- 扩展 LLMConfig 支持按用途类型存储不同配置
- 更新 LLMClient 接口支持 Any 类型的消息内容
- 新增管理后台 API 支持获取用途类型列表和按用途获取配置
- 更新前端 LLM 配置页面支持多用途配置切换
2026-03-11 18:56:01 +08:00
MerCry 7134ec3c5e [AC-AISVC-RES-09~15] config: 将默认运行时模式改为 AUTO
- 修改 ModeRouterConfig.runtime_mode 默认值从 DIRECT 改为 AUTO
- 系统将根据查询复杂度和置信度自动决定使用 ReAct 模式还是通用 API 模式
- 短查询 + 高置信度 -> 使用 DIRECT 模式
- 复杂查询或低置信度 -> 使用 REACT 模式
2026-03-11 00:03:25 +08:00
117 changed files with 9885 additions and 329 deletions

View File

@ -6,7 +6,9 @@ import type {
LLMTestResult,
LLMTestRequest,
LLMProvidersResponse,
LLMConfigUpdateResponse
LLMUsageTypesResponse,
LLMConfigUpdateResponse,
LLMAllConfigs
} from '@/types/llm'
export function getLLMProviders(): Promise<LLMProvidersResponse> {
@ -16,10 +18,22 @@ export function getLLMProviders(): Promise<LLMProvidersResponse> {
})
}
export function getLLMConfig(): Promise<LLMConfig> {
/**
 * Requests the list of LLM usage types from the admin API.
 *
 * @returns A promise for the response of `GET /admin/llm/usage-types`.
 */
export function getLLMUsageTypes(): Promise<LLMUsageTypesResponse> {
  return request({ method: 'get', url: '/admin/llm/usage-types' })
}
export function getLLMConfig(usageType?: string): Promise<LLMConfig | LLMAllConfigs> {
const params: Record<string, string> = {}
if (usageType) {
params.usage_type = usageType
}
return request({
url: '/admin/llm/config',
method: 'get'
method: 'get',
params
})
}
@ -46,5 +60,7 @@ export type {
LLMTestResult,
LLMTestRequest,
LLMProvidersResponse,
LLMConfigUpdateResponse
LLMUsageTypesResponse,
LLMConfigUpdateResponse,
LLMAllConfigs
}

View File

@ -2,26 +2,40 @@ import { defineStore } from 'pinia'
import { ref, computed } from 'vue'
import {
getLLMProviders,
getLLMUsageTypes,
getLLMConfig,
updateLLMConfig,
testLLM,
type LLMProviderInfo,
type LLMConfig,
type LLMConfigUpdate,
type LLMTestResult
type LLMTestResult,
type LLMUsageType,
type LLMAllConfigs
} from '@/api/llm'
export const useLLMStore = defineStore('llm', () => {
const providers = ref<LLMProviderInfo[]>([])
const currentConfig = ref<LLMConfig>({
provider: '',
config: {}
const usageTypes = ref<LLMUsageType[]>([])
const allConfigs = ref<LLMAllConfigs>({
chat: { provider: '', config: {} },
kb_processing: { provider: '', config: {} }
})
const currentUsageType = ref<string>('chat')
const loading = ref(false)
const providersLoading = ref(false)
const testResult = ref<LLMTestResult | null>(null)
const testLoading = ref(false)
// Derived view of the configuration that belongs to the currently selected
// usage type. Falls back to an empty provider/config when no entry exists
// for that usage type.
const currentConfig = computed(() => {
  const selectedUsageType = currentUsageType.value
  const entry = allConfigs.value[selectedUsageType as keyof LLMAllConfigs]
  const provider = entry?.provider || ''
  const config = entry?.config || {}
  return { provider, config, usage_type: selectedUsageType }
})
const currentProvider = computed(() => {
return providers.value.find(p => p.name === currentConfig.value.provider)
})
@ -43,16 +57,29 @@ export const useLLMStore = defineStore('llm', () => {
}
}
// Loads the available usage types into `usageTypes`.
// The list is read from either `usage_types` or `data.usage_types` of the
// response; when neither is present the list becomes empty.
// Failures are logged and re-thrown so the caller can react.
const loadUsageTypes = async () => {
  try {
    const response: any = await getLLMUsageTypes()
    const topLevelList = response?.usage_types
    const nestedList = response?.data?.usage_types
    usageTypes.value = topLevelList || nestedList || []
  } catch (error) {
    console.error('Failed to load LLM usage types:', error)
    throw error
  }
}
const loadConfig = async () => {
loading.value = true
try {
const res: any = await getLLMConfig()
const config = res?.data || res
if (config) {
currentConfig.value = {
provider: config.provider || '',
config: config.config || {},
updated_at: config.updated_at
const configs = res?.data || res
if (configs) {
if (configs.chat && configs.kb_processing) {
allConfigs.value = configs
} else {
allConfigs.value = {
chat: { provider: configs.provider || '', config: configs.config || {} },
kb_processing: { provider: configs.provider || '', config: configs.config || {} }
}
}
}
} catch (error) {
@ -68,7 +95,8 @@ export const useLLMStore = defineStore('llm', () => {
try {
const updateData: LLMConfigUpdate = {
provider: currentConfig.value.provider,
config: currentConfig.value.config
config: currentConfig.value.config,
usage_type: currentUsageType.value
}
await updateLLMConfig(updateData)
} catch (error) {
@ -86,7 +114,8 @@ export const useLLMStore = defineStore('llm', () => {
const result = await testLLM({
test_prompt: testPrompt,
provider: currentConfig.value.provider,
config: currentConfig.value.config
config: currentConfig.value.config,
usage_type: currentUsageType.value
})
testResult.value = result
return result
@ -103,7 +132,8 @@ export const useLLMStore = defineStore('llm', () => {
}
const setProvider = (providerName: string) => {
currentConfig.value.provider = providerName
const usageTypeKey = currentUsageType.value as keyof LLMAllConfigs
allConfigs.value[usageTypeKey].provider = providerName
const provider = providers.value.find(p => p.name === providerName)
if (provider?.config_schema?.properties) {
const newConfig: Record<string, any> = {}
@ -127,14 +157,19 @@ export const useLLMStore = defineStore('llm', () => {
}
}
})
currentConfig.value.config = newConfig
allConfigs.value[usageTypeKey].config = newConfig
} else {
currentConfig.value.config = {}
allConfigs.value[usageTypeKey].config = {}
}
}
const updateConfigValue = (key: string, value: any) => {
currentConfig.value.config[key] = value
const usageTypeKey = currentUsageType.value as keyof LLMAllConfigs
allConfigs.value[usageTypeKey].config[key] = value
}
// Selects which usage type the store's `currentConfig` refers to.
const setCurrentUsageType = (usageType: string): void => {
  currentUsageType.value = usageType
}
const clearTestResult = () => {
@ -143,6 +178,9 @@ export const useLLMStore = defineStore('llm', () => {
return {
providers,
usageTypes,
allConfigs,
currentUsageType,
currentConfig,
loading,
providersLoading,
@ -151,11 +189,13 @@ export const useLLMStore = defineStore('llm', () => {
currentProvider,
configSchema,
loadProviders,
loadUsageTypes,
loadConfig,
saveCurrentConfig,
runTest,
setProvider,
updateConfigValue,
setCurrentUsageType,
clearTestResult
}
})

View File

@ -5,15 +5,32 @@ export interface LLMProviderInfo {
config_schema: Record<string, any>
}
/** One selectable LLM usage type, as listed in the `usage_types` response. */
export interface LLMUsageType {
/** Identifier of the usage type; the UI uses it as the selection value. */
name: string
/** Human-readable label shown for the usage type. */
display_name: string
/** Longer explanation of the usage type (shown as a tooltip in the UI). */
description: string
}
export interface LLMConfig {
provider: string
config: Record<string, any>
updated_at?: string
}
/** Provider selection and provider settings stored for a single usage type. */
export interface LLMConfigByUsage {
/** Name of the selected LLM provider. */
provider: string
/** Provider-specific settings as free-form key/value pairs. */
config: Record<string, any>
}
/** LLM configurations keyed by usage type. */
export interface LLMAllConfigs {
/** Configuration used for the `chat` usage type. */
chat: LLMConfigByUsage
/** Configuration used for the `kb_processing` usage type. */
kb_processing: LLMConfigByUsage
}
export interface LLMConfigUpdate {
provider: string
config?: Record<string, any>
usage_type?: string
}
export interface LLMTestResult {
@ -31,12 +48,17 @@ export interface LLMTestRequest {
test_prompt?: string
provider?: string
config?: Record<string, any>
usage_type?: string
}
export interface LLMProvidersResponse {
providers: LLMProviderInfo[]
}
/** Response body of the usage-types listing endpoint. */
export interface LLMUsageTypesResponse {
/** All usage types the backend reports. */
usage_types: LLMUsageType[]
}
export interface LLMConfigUpdateResponse {
success: boolean
message: string

View File

@ -4,13 +4,7 @@
<div class="header-content">
<div class="title-section">
<h1 class="page-title">LLM 模型配置</h1>
<p class="page-desc">配置和管理系统使用的大语言模型支持多种提供者切换配置修改后需保存才能生效</p>
</div>
<div class="header-actions" v-if="currentConfig.updated_at">
<div class="update-info">
<el-icon class="update-icon"><Clock /></el-icon>
<span>上次更新: {{ formatDate(currentConfig.updated_at) }}</span>
</div>
<p class="page-desc">配置和管理系统使用的大语言模型支持多种提供者切换可以为不同用途配置不同的模型</p>
</div>
</div>
</div>
@ -30,6 +24,26 @@
</template>
<div class="card-content">
<div class="usage-type-section">
<div class="section-label">
<el-icon><Setting /></el-icon>
<span>选择用途类型</span>
</div>
<el-radio-group v-model="currentUsageType" class="usage-type-radio" @change="handleUsageTypeChange">
<el-radio-button
v-for="ut in usageTypes"
:key="ut.name"
:value="ut.name"
>
<el-tooltip :content="ut.description" placement="top">
<span>{{ ut.display_name }}</span>
</el-tooltip>
</el-radio-button>
</el-radio-group>
</div>
<el-divider />
<div class="provider-select-section">
<div class="section-label">
<el-icon><Connection /></el-icon>
@ -103,7 +117,7 @@
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { Cpu, Connection, InfoFilled, Box, RefreshLeft, Check, Clock } from '@element-plus/icons-vue'
import { Cpu, Connection, InfoFilled, Box, RefreshLeft, Check, Setting } from '@element-plus/icons-vue'
import { useLLMStore } from '@/stores/llm'
import ProviderSelect from '@/components/common/ProviderSelect.vue'
import ConfigForm from '@/components/common/ConfigForm.vue'
@ -117,21 +131,18 @@ const saving = ref(false)
const pageLoading = ref(false)
const providers = computed(() => llmStore.providers)
const usageTypes = computed(() => llmStore.usageTypes)
const currentConfig = computed(() => llmStore.currentConfig)
const currentProvider = computed(() => llmStore.currentProvider)
const configSchema = computed(() => llmStore.configSchema)
const providersLoading = computed(() => llmStore.providersLoading)
const currentUsageType = computed({
get: () => llmStore.currentUsageType,
set: (val) => llmStore.setCurrentUsageType(val)
})
const formatDate = (dateStr: string) => {
if (!dateStr) return ''
const date = new Date(dateStr)
return date.toLocaleString('zh-CN', {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit'
})
const handleUsageTypeChange = (usageType: string) => {
llmStore.setCurrentUsageType(usageType)
}
const handleProviderChange = (provider: any) => {
@ -189,6 +200,7 @@ const initPage = async () => {
try {
await Promise.all([
llmStore.loadProviders(),
llmStore.loadUsageTypes(),
llmStore.loadConfig()
])
} catch (error) {
@ -253,27 +265,6 @@ onMounted(() => {
line-height: 1.6;
}
.header-actions {
display: flex;
align-items: center;
}
.update-info {
display: flex;
align-items: center;
gap: 6px;
padding: 8px 14px;
background-color: var(--bg-tertiary);
border-radius: 8px;
font-size: 13px;
color: var(--text-secondary);
}
.update-icon {
font-size: 14px;
color: var(--text-tertiary);
}
.config-card {
animation: fadeInUp 0.5s ease-out;
}
@ -329,8 +320,18 @@ onMounted(() => {
padding: 8px 0;
}
.provider-select-section {
margin-bottom: 16px;
.usage-type-section {
margin-bottom: 8px;
}
.usage-type-radio {
display: flex;
flex-wrap: wrap;
gap: 8px;
}
.usage-type-radio :deep(.el-radio-button__inner) {
padding: 10px 20px;
}
.section-label {
@ -347,6 +348,10 @@ onMounted(() => {
color: var(--primary-color);
}
.provider-select-section {
margin-bottom: 16px;
}
.provider-info {
display: flex;
align-items: flex-start;

View File

@ -10,6 +10,7 @@ import json
import hashlib
from dataclasses import dataclass
from typing import Annotated, Any, Optional
from logging.handlers import RotatingFileHandler
import tiktoken
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
@ -38,6 +39,20 @@ from app.services.metadata_field_definition_service import MetadataFieldDefiniti
logger = logging.getLogger(__name__)
settings = get_settings()
# Dedicated logger for the vector payloads written during KB indexing.
kb_vector_logger = logging.getLogger("kb_vector_payload")
# Attach the rotating file handler at most once, and only when the feature
# is enabled in settings.
if settings.kb_vector_log_enabled and not kb_vector_logger.handlers:
    import os

    # RotatingFileHandler opens the log file on construction, so a missing
    # parent directory would raise at import time. Create it up front.
    _kb_vector_log_dir = os.path.dirname(settings.kb_vector_log_path)
    if _kb_vector_log_dir:
        os.makedirs(_kb_vector_log_dir, exist_ok=True)

    handler = RotatingFileHandler(
        filename=settings.kb_vector_log_path,
        maxBytes=10 * 1024 * 1024,  # rotate after 10 MiB
        backupCount=5,
        encoding="utf-8",
    )
    handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
    kb_vector_logger.addHandler(handler)
    kb_vector_logger.setLevel(logging.INFO)
    # Keep these records out of the parent loggers' handlers.
    kb_vector_logger.propagate = False
router = APIRouter(prefix="/admin/kb", tags=["KB Management"])
@ -661,6 +676,7 @@ async def _index_document(
logger.info(f"[INDEX] File extension: {file_ext}, content size: {len(content)} bytes")
text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}
image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"}
if file_ext in text_extensions or not file_ext:
logger.info("[INDEX] Treating as text file, trying multiple encodings")
@ -676,6 +692,44 @@ async def _index_document(
if text is None:
text = content.decode("utf-8", errors="replace")
logger.warning("[INDEX] Failed to decode with known encodings, using utf-8 with replacement")
elif file_ext in image_extensions:
logger.info("[INDEX] Image file detected, will parse with multimodal LLM")
await kb_service.update_job_status(
tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
)
await session.commit()
with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
tmp_file.write(content)
tmp_path = tmp_file.name
logger.info(f"[INDEX] Temp file created: {tmp_path}")
try:
from app.services.document.image_parser import ImageParser
logger.info(f"[INDEX] Starting image parsing for {file_ext}...")
image_parser = ImageParser()
image_result = await image_parser.parse_with_chunks(tmp_path)
text = image_result.raw_text
parse_result = type('ParseResult', (), {
'text': text,
'metadata': image_result.metadata,
'pages': None,
'image_chunks': image_result.chunks,
'image_summary': image_result.image_summary,
})()
logger.info(
f"[INDEX] Parsed image SUCCESS: {filename}, "
f"chars={len(text)}, chunks={len(image_result.chunks)}, "
f"summary={image_result.image_summary[:50] if image_result.image_summary else 'N/A'}..."
)
except Exception as e:
logger.error(f"[INDEX] Image parsing error: {type(e).__name__}: {e}")
text = ""
parse_result = None
finally:
Path(tmp_path).unlink(missing_ok=True)
logger.info("[INDEX] Temp file cleaned up")
else:
logger.info("[INDEX] Binary file detected, will parse with document parser")
await kb_service.update_job_status(
@ -723,13 +777,64 @@ async def _index_document(
)
await session.commit()
from app.services.metadata_auto_inference_service import MetadataAutoInferenceService
inference_service = MetadataAutoInferenceService(session)
image_base64_for_inference = None
mime_type_for_inference = None
if file_ext in image_extensions:
import base64
image_base64_for_inference = base64.b64encode(content).decode("utf-8")
mime_type_map = {
".jpg": "image/jpeg", ".jpeg": "image/jpeg",
".png": "image/png", ".gif": "image/gif",
".webp": "image/webp", ".bmp": "image/bmp",
".tiff": "image/tiff", ".tif": "image/tiff",
}
mime_type_for_inference = mime_type_map.get(file_ext, "image/jpeg")
logger.info("[INDEX] Starting metadata auto-inference...")
inference_result = await inference_service.infer_metadata(
tenant_id=tenant_id,
content=text or "",
scope="kb_document",
existing_metadata=metadata,
image_base64=image_base64_for_inference,
mime_type=mime_type_for_inference,
)
if inference_result.success:
metadata = inference_result.inferred_metadata
logger.info(
f"[INDEX] Metadata inference SUCCESS: "
f"inferred_fields={list(inference_result.inferred_metadata.keys())}, "
f"confidence_scores={inference_result.confidence_scores}"
)
else:
logger.warning(
f"[INDEX] Metadata inference FAILED: {inference_result.error_message}, "
f"using existing metadata"
)
logger.info("[INDEX] Getting embedding provider...")
embedding_provider = await get_embedding_provider()
logger.info(f"[INDEX] Embedding provider: {type(embedding_provider).__name__}")
all_chunks: list[TextChunk] = []
if parse_result and parse_result.pages:
if parse_result and hasattr(parse_result, 'image_chunks') and parse_result.image_chunks:
logger.info(f"[INDEX] Image with {len(parse_result.image_chunks)} intelligent chunks from LLM")
for img_chunk in parse_result.image_chunks:
all_chunks.append(TextChunk(
text=img_chunk.content,
start_token=img_chunk.chunk_index,
end_token=img_chunk.chunk_index + 1,
page=None,
source=filename,
))
logger.info(f"[INDEX] Total chunks from image: {len(all_chunks)}")
elif parse_result and parse_result.pages:
logger.info(f"[INDEX] PDF with {len(parse_result.pages)} pages, using line-based chunking with page metadata")
for page in parse_result.pages:
page_chunks = chunk_text_by_lines(
@ -807,6 +912,35 @@ async def _index_document(
await session.commit()
if points:
if settings.kb_vector_log_enabled:
vector_payloads = []
for point in points:
if use_multi_vector:
payload = {
"id": point.get("id"),
"vector": point.get("vector"),
"payload": point.get("payload"),
}
else:
payload = {
"id": point.id,
"vector": point.vector,
"payload": point.payload,
}
vector_payloads.append(payload)
kb_vector_logger.info(json.dumps({
"tenant_id": tenant_id,
"kb_id": kb_id,
"doc_id": doc_id,
"job_id": job_id,
"filename": filename,
"file_ext": file_ext,
"is_image": file_ext in image_extensions,
"metadata": doc_metadata,
"vectors": vector_payloads,
}, ensure_ascii=False))
logger.info(f"[INDEX] Upserting {len(points)} vectors to Qdrant for kb_id={kb_id}...")
if use_multi_vector:
await qdrant.upsert_multi_vector(tenant_id, points, kb_id=kb_id)

View File

@ -10,6 +10,9 @@ from fastapi import APIRouter, Depends, Header, HTTPException
from app.services.llm.factory import (
LLMProviderFactory,
LLMUsageType,
LLM_USAGE_DISPLAY_NAMES,
LLM_USAGE_DESCRIPTIONS,
get_llm_config_manager,
)
@ -49,25 +52,63 @@ async def list_providers(
}
@router.get("/usage-types")
async def list_usage_types(
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """
    Return every LLM usage type with its display name and description.

    Args:
        tenant_id: Tenant identifier resolved by the ``get_tenant_id``
            dependency; only used for logging here.

    Returns:
        A dict with a single ``usage_types`` key holding one entry
        (``name``, ``display_name``, ``description``) per ``LLMUsageType``.
    """
    logger.info(f"Listing LLM usage types for tenant={tenant_id}")

    usage_type_entries: list[dict[str, Any]] = []
    for usage in LLMUsageType:
        usage_type_entries.append(
            {
                "name": usage.value,
                "display_name": LLM_USAGE_DISPLAY_NAMES[usage],
                "description": LLM_USAGE_DESCRIPTIONS[usage],
            }
        )
    return {"usage_types": usage_type_entries}
@router.get("/config")
async def get_config(
tenant_id: str = Depends(get_tenant_id),
usage_type: str | None = None,
) -> dict[str, Any]:
"""
Get current LLM configuration.
[AC-ASA-14] Returns current provider and config.
If usage_type is specified, returns config for that usage type.
Otherwise, returns all configs.
"""
logger.info(f"[AC-ASA-14] Getting LLM config for tenant={tenant_id}")
logger.info(f"[AC-ASA-14] Getting LLM config for tenant={tenant_id}, usage_type={usage_type}")
manager = get_llm_config_manager()
config = manager.get_current_config()
masked_config = _mask_secrets(config.get("config", {}))
if usage_type:
try:
ut = LLMUsageType(usage_type)
config = manager.get_current_config(ut)
masked_config = _mask_secrets(config.get("config", {}))
return {
"usage_type": config["usage_type"],
"provider": config["provider"],
"config": masked_config,
}
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid usage_type: {usage_type}")
return {
"provider": config["provider"],
"config": masked_config,
}
all_configs = manager.get_current_config()
result = {}
for ut_key, config in all_configs.items():
result[ut_key] = {
"provider": config["provider"],
"config": _mask_secrets(config.get("config", {})),
}
return result
@router.put("/config")
@ -78,11 +119,25 @@ async def update_config(
"""
Update LLM configuration.
[AC-ASA-16] Updates provider and config with validation.
Request body format:
- For specific usage type:
{
"usage_type": "chat" | "kb_processing",
"provider": "openai",
"config": {...}
}
- For all usage types (backward compatible):
{
"provider": "openai",
"config": {...}
}
"""
provider = body.get("provider")
config = body.get("config", {})
usage_type_str = body.get("usage_type")
logger.info(f"[AC-ASA-16] Updating LLM config for tenant={tenant_id}, provider={provider}")
logger.info(f"[AC-ASA-16] Updating LLM config for tenant={tenant_id}, provider={provider}, usage_type={usage_type_str}")
if not provider:
return {
@ -92,12 +147,24 @@ async def update_config(
try:
manager = get_llm_config_manager()
await manager.update_config(provider, config)
return {
"success": True,
"message": f"LLM configuration updated to {provider}",
}
if usage_type_str:
try:
usage_type = LLMUsageType(usage_type_str)
await manager.update_usage_config(usage_type, provider, config)
return {
"success": True,
"message": f"LLM configuration updated for {usage_type_str} to {provider}",
}
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid usage_type: {usage_type_str}")
else:
await manager.update_config(provider, config)
return {
"success": True,
"message": f"LLM configuration updated to {provider}",
}
except ValueError as e:
logger.error(f"[AC-ASA-16] Invalid LLM config: {e}")
@ -115,23 +182,44 @@ async def test_connection(
"""
Test LLM connection.
[AC-ASA-17, AC-ASA-18] Tests connection and returns response.
Request body format:
{
"test_prompt": "optional test prompt",
"provider": "optional provider to test",
"config": "optional config to test",
"usage_type": "optional usage type to test"
}
"""
body = body or {}
test_prompt = body.get("test_prompt", "你好,请简单介绍一下自己。")
provider = body.get("provider")
config = body.get("config")
usage_type_str = body.get("usage_type")
logger.info(
f"[AC-ASA-17] Testing LLM connection for tenant={tenant_id}, "
f"provider={provider or 'current'}"
f"provider={provider or 'current'}, usage_type={usage_type_str or 'default'}"
)
manager = get_llm_config_manager()
usage_type = None
if usage_type_str:
try:
usage_type = LLMUsageType(usage_type_str)
except ValueError:
return {
"success": False,
"error": f"Invalid usage_type: {usage_type_str}",
}
result = await manager.test_connection(
test_prompt=test_prompt,
provider=provider,
config=config,
usage_type=usage_type,
)
return result

View File

@ -0,0 +1,349 @@
"""
Retrieval Strategy API Endpoints.
[AC-AISVC-RES-01~15] 策略管理 API
Endpoints:
- GET /strategy/retrieval/current - 获取当前策略状态
- POST /strategy/retrieval/switch - 切换策略
- POST /strategy/retrieval/validate - 验证策略配置
- POST /strategy/retrieval/rollback - 回退策略
"""
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel, Field
from app.services.retrieval.strategy.config import (
GrayscaleConfig,
ModeRouterConfig,
PipelineConfig,
RetrievalStrategyConfig,
RerankerConfig,
RuntimeMode,
StrategyType,
)
from app.services.retrieval.strategy.rollback_manager import (
RollbackManager,
RollbackResult,
RollbackTrigger,
get_rollback_manager,
)
from app.services.retrieval.strategy.strategy_router import (
StrategyRouter,
get_strategy_router,
set_strategy_router,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/strategy/retrieval", tags=["strategy"])
class GrayscaleConfigSchema(BaseModel):
    """Grayscale (gradual rollout) configuration schema."""

    # Whether grayscale rollout is switched on.
    enabled: bool = False
    # Rollout percentage; validated to lie within 0..100 inclusive.
    percentage: float = Field(default=0.0, ge=0.0, le=100.0)
    # Explicit allowlist entries.
    # NOTE(review): presumably tenant or user identifiers - confirm.
    allowlist: list[str] = Field(default_factory=list)
class RerankerConfigSchema(BaseModel):
    """Reranker configuration schema."""

    # Whether reranking is switched on.
    enabled: bool = False
    # Name of the reranker model.
    model: str = "cross-encoder"
    # Number of results kept after reranking.
    top_k_after_rerank: int = 5
    # Minimum score threshold applied by the reranker.
    min_score_threshold: float = 0.3
class ModeRouterConfigSchema(BaseModel):
    """Mode router configuration schema."""

    # Runtime mode as a plain string; the validate endpoint in this module
    # accepts "direct", "react" and "auto".
    # NOTE(review): the default here is "direct" - confirm it should match
    # the default of ModeRouterConfig.runtime_mode.
    runtime_mode: str = "direct"
    react_trigger_confidence_threshold: float = 0.6
    react_trigger_complexity_score: float = 0.5
    react_max_steps: int = 5
    direct_fallback_on_low_confidence: bool = True
class PipelineConfigSchema(BaseModel):
    """Retrieval pipeline configuration schema."""

    # Number of results to retrieve.
    top_k: int = 5
    # Score threshold for retrieved results.
    score_threshold: float = 0.01
    # Minimum number of hits.
    min_hits: int = 1
    # Whether two-stage retrieval is switched on.
    two_stage_enabled: bool = True
class StrategyStatusResponse(BaseModel):
    """Response describing the current retrieval strategy state."""

    # String value of the active strategy type.
    active_strategy: str
    grayscale: GrayscaleConfigSchema
    pipeline: PipelineConfigSchema
    reranker: RerankerConfigSchema
    mode_router: ModeRouterConfigSchema
    # Named performance thresholds; empty when none are configured.
    performance_thresholds: dict[str, float] = Field(default_factory=dict)
class SwitchRequest(BaseModel):
    """Request body for switching the retrieval strategy.

    Every field is optional; an omitted (``None``) field means the
    corresponding part of the current configuration is kept.
    """

    active_strategy: str | None = None
    grayscale: GrayscaleConfigSchema | None = None
    mode_router: ModeRouterConfigSchema | None = None
    reranker: RerankerConfigSchema | None = None
class SwitchResponse(BaseModel):
    """Response returned after a strategy switch."""

    success: bool
    # Strategy value that was active before the switch.
    previous_strategy: str
    # Strategy value that is active after the switch.
    current_strategy: str
    message: str
class ValidationRequest(BaseModel):
    """Request body for validating a strategy configuration."""

    # Free-form configuration to validate; defaults to an empty dict.
    config: dict[str, Any] = Field(default_factory=dict)
class ValidationResponse(BaseModel):
    """Result of validating a strategy configuration."""

    # True when no errors were found.
    valid: bool
    errors: list[str] = Field(default_factory=list)
    warnings: list[str] = Field(default_factory=list)
    # Condensed view of the validated configuration.
    config_summary: dict[str, Any] = Field(default_factory=dict)
config_summary: dict[str, Any] = Field(default_factory=dict)
class RollbackResponse(BaseModel):
"""策略回退响应。"""
success: bool
previous_strategy: str
current_strategy: str
trigger: str
reason: str
audit_log_id: str | None = None
@router.get(
    "/current",
    response_model=StrategyStatusResponse,
    summary="获取当前策略状态",
    description="【AC-AISVC-RES-01】 获取当前活跃的检索策略配置状态。",
)
async def get_current_strategy() -> StrategyStatusResponse:
    """
    [AC-AISVC-RES-01] Return the current retrieval strategy state.

    Reads the configuration from the strategy router and maps each section
    onto its response schema.
    """
    config = get_strategy_router().get_config()

    grayscale_cfg = config.grayscale
    grayscale_view = GrayscaleConfigSchema(
        enabled=grayscale_cfg.enabled,
        percentage=grayscale_cfg.percentage,
        allowlist=grayscale_cfg.allowlist,
    )

    pipeline_cfg = config.pipeline
    pipeline_view = PipelineConfigSchema(
        top_k=pipeline_cfg.top_k,
        score_threshold=pipeline_cfg.score_threshold,
        min_hits=pipeline_cfg.min_hits,
        two_stage_enabled=pipeline_cfg.two_stage_enabled,
    )

    reranker_cfg = config.reranker
    reranker_view = RerankerConfigSchema(
        enabled=reranker_cfg.enabled,
        model=reranker_cfg.model,
        top_k_after_rerank=reranker_cfg.top_k_after_rerank,
        min_score_threshold=reranker_cfg.min_score_threshold,
    )

    mode_cfg = config.mode_router
    mode_router_view = ModeRouterConfigSchema(
        # The schema carries the enum's string value, not the enum itself.
        runtime_mode=mode_cfg.runtime_mode.value,
        react_trigger_confidence_threshold=mode_cfg.react_trigger_confidence_threshold,
        react_trigger_complexity_score=mode_cfg.react_trigger_complexity_score,
        react_max_steps=mode_cfg.react_max_steps,
        direct_fallback_on_low_confidence=mode_cfg.direct_fallback_on_low_confidence,
    )

    return StrategyStatusResponse(
        active_strategy=config.active_strategy.value,
        grayscale=grayscale_view,
        pipeline=pipeline_view,
        reranker=reranker_view,
        mode_router=mode_router_view,
        performance_thresholds=config.performance_thresholds,
    )
@router.post(
    "/switch",
    response_model=SwitchResponse,
    summary="切换策略",
    description="【AC-AISVC-RES-02, AC-AISVC-RES-03】 切换检索策略, 支持灰度发布配置。",
)
async def switch_strategy(request: SwitchRequest) -> SwitchResponse:
    """
    [AC-AISVC-RES-02, AC-AISVC-RES-03] Switch the retrieval strategy.

    Supports grayscale rollout configuration (percentage / allowlist).
    Sections missing from the request keep their current values; the
    pipeline, metadata inference and performance thresholds are always
    carried over unchanged.

    NOTE(review): a provided section replaces that section as a whole, so
    fields the client omitted inside it take the schema defaults rather than
    the current values - confirm this is intended.

    Args:
        request: Requested strategy and optional section overrides.

    Returns:
        The previous and the new active strategy.

    Raises:
        HTTPException: 400 when ``active_strategy`` or
            ``mode_router.runtime_mode`` is not a known value.
    """
    strategy_router = get_strategy_router()
    current_config = strategy_router.get_config()
    previous_strategy = current_config.active_strategy.value

    if request.active_strategy:
        try:
            new_active_strategy = StrategyType(request.active_strategy)
        except ValueError as exc:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid strategy type: {request.active_strategy}. Valid values are: default, enhanced",
            ) from exc
    else:
        new_active_strategy = current_config.active_strategy

    # Grayscale is always rebuilt, from the request when given and otherwise
    # from the current configuration.
    grayscale_source = request.grayscale if request.grayscale else current_config.grayscale
    new_grayscale = GrayscaleConfig(
        enabled=grayscale_source.enabled,
        percentage=grayscale_source.percentage,
        allowlist=grayscale_source.allowlist,
    )

    new_mode_router = current_config.mode_router
    if request.mode_router:
        requested_mode = request.mode_router.runtime_mode
        if requested_mode:
            # Convert inside a try block so an unknown mode is reported as a
            # client error (400) instead of an unhandled ValueError (500).
            try:
                runtime_mode = RuntimeMode(requested_mode)
            except ValueError as exc:
                valid_modes = ", ".join(str(mode.value) for mode in RuntimeMode)
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid runtime_mode: {requested_mode}. Valid values are: {valid_modes}",
                ) from exc
        else:
            runtime_mode = current_config.mode_router.runtime_mode
        new_mode_router = ModeRouterConfig(
            runtime_mode=runtime_mode,
            react_trigger_confidence_threshold=request.mode_router.react_trigger_confidence_threshold,
            react_trigger_complexity_score=request.mode_router.react_trigger_complexity_score,
            react_max_steps=request.mode_router.react_max_steps,
            direct_fallback_on_low_confidence=request.mode_router.direct_fallback_on_low_confidence,
        )

    new_reranker = current_config.reranker
    if request.reranker:
        new_reranker = RerankerConfig(
            enabled=request.reranker.enabled,
            model=request.reranker.model,
            top_k_after_rerank=request.reranker.top_k_after_rerank,
            min_score_threshold=request.reranker.min_score_threshold,
        )

    new_config = RetrievalStrategyConfig(
        active_strategy=new_active_strategy,
        grayscale=new_grayscale,
        mode_router=new_mode_router,
        reranker=new_reranker,
        pipeline=current_config.pipeline,
        metadata_inference=current_config.metadata_inference,
        performance_thresholds=current_config.performance_thresholds,
    )
    strategy_router.update_config(new_config)

    logger.info(
        f"[AC-AISVC-RES-02] Strategy switched: {previous_strategy} -> {new_config.active_strategy.value}"
    )
    return SwitchResponse(
        success=True,
        previous_strategy=previous_strategy,
        current_strategy=new_config.active_strategy.value,
        message=f"Strategy switched from {previous_strategy} to {new_config.active_strategy.value}",
    )
def _is_number(value: Any) -> bool:
    """Return True when ``value`` is an int or a float."""
    return isinstance(value, (int, float))


def _config_section(config: dict[str, Any], name: str, errors: list[str]) -> dict[str, Any]:
    """Return ``config[name]`` as a dict; record an error for a non-dict value."""
    section = config.get(name)
    if section is None:
        return {}
    if not isinstance(section, dict):
        errors.append(f"Invalid {name}: expected an object")
        return {}
    return section


@router.post(
    "/validate",
    response_model=ValidationResponse,
    summary="验证策略配置",
    description="【AC-AISVC-RES-06, AC-AISVC-RES-08】 验证策略配置的完整性与一致性。",
)
async def validate_strategy(request: ValidationRequest) -> ValidationResponse:
    """
    [AC-AISVC-RES-06, AC-AISVC-RES-08] Validate a strategy configuration.

    The configuration is a free-form dict, so every section and key may be
    missing or malformed. Missing keys are skipped and malformed values are
    reported as errors instead of raising.

    Args:
        request: Wrapper around the configuration dict to validate.

    Returns:
        ``valid`` is True when no errors were collected; ``errors`` and
        ``warnings`` list the findings and ``config_summary`` condenses the
        submitted configuration.
    """
    errors: list[str] = []
    warnings: list[str] = []
    config = request.config

    if "active_strategy" in config:
        if config["active_strategy"] not in ["default", "enhanced"]:
            errors.append(f"Invalid active_strategy: {config['active_strategy']}")

    grayscale = _config_section(config, "grayscale", errors)
    percentage = grayscale.get("percentage")
    # An absent or null percentage is simply not checked.
    if percentage is not None:
        if not _is_number(percentage) or not (0 <= percentage <= 100):
            errors.append(f"Invalid grayscale percentage: {percentage}")

    mode_router = _config_section(config, "mode_router", errors)
    if "runtime_mode" in mode_router:
        if mode_router["runtime_mode"] not in ["direct", "react", "auto"]:
            errors.append(f"Invalid runtime_mode: {mode_router['runtime_mode']}")

    reranker = _config_section(config, "reranker", errors)
    if reranker.get("enabled"):
        top_k = reranker.get("top_k_after_rerank", 1)
        if not _is_number(top_k):
            errors.append(f"Invalid top_k_after_rerank: {top_k}")
        elif not (1 <= top_k <= 20):
            # Both bounds are checked so the condition matches the message.
            warnings.append(f"top_k_after_rerank should be between 1 and 20, current: {top_k}")

    thresholds = _config_section(config, "performance_thresholds", errors)
    max_latency_ms = thresholds.get("max_latency_ms")
    if max_latency_ms is not None:
        if not _is_number(max_latency_ms):
            errors.append(f"Invalid max_latency_ms: {max_latency_ms}")
        elif max_latency_ms < 100:
            warnings.append(f"max_latency_ms seems too low: {max_latency_ms}ms")

    return ValidationResponse(
        valid=len(errors) == 0,
        errors=errors,
        warnings=warnings,
        config_summary={
            "active_strategy": config.get("active_strategy", "default"),
            "grayscale_enabled": grayscale.get("enabled", False),
            "runtime_mode": mode_router.get("runtime_mode", "direct"),
            "reranker_enabled": reranker.get("enabled", False),
            "performance_thresholds": thresholds,
        },
    )
@router.post(
    "/rollback",
    response_model=RollbackResponse,
    summary="回退策略",
    description="【AC-AISVC-RES-07】 回退到默认策略。",
)
async def rollback_strategy(
    trigger: str = "manual",
    reason: str = "",
) -> RollbackResponse:
    """
    AC-AISVC-RES-07: roll the active retrieval strategy back to the default.

    When the active strategy is already the default, nothing is changed and a
    response with ``success=False`` is returned. Otherwise a new config is
    built that differs from the current one only in ``active_strategy``, it is
    pushed to both the strategy router and the rollback manager, and an audit
    entry is recorded.

    Args:
        trigger: Label describing what initiated the rollback.
        reason: Free-text reason; falls back to "Manual rollback" when empty.
    """
    strategy_router = get_strategy_router()
    current_config = strategy_router.get_config()
    previous_strategy = current_config.active_strategy

    # Guard: already on the default strategy, so there is nothing to roll back.
    if previous_strategy == StrategyType.DEFAULT:
        unchanged = previous_strategy.value
        return RollbackResponse(
            success=False,
            previous_strategy=unchanged,
            current_strategy=unchanged,
            trigger=trigger,
            reason="Already on default strategy",
            audit_log_id=None,
        )

    # Every section except active_strategy is carried over untouched.
    carried_over = {
        section: getattr(current_config, section)
        for section in (
            "grayscale",
            "mode_router",
            "reranker",
            "pipeline",
            "metadata_inference",
            "performance_thresholds",
        )
    }
    new_config = RetrievalStrategyConfig(
        active_strategy=StrategyType.DEFAULT,
        **carried_over,
    )
    strategy_router.update_config(new_config)

    rollback_manager = get_rollback_manager()
    rollback_manager.update_config(new_config)

    effective_reason = reason or "Manual rollback"
    default_name = StrategyType.DEFAULT.value
    audit_log = rollback_manager.record_audit(
        action="rollback",
        details={
            "from_strategy": previous_strategy.value,
            "to_strategy": default_name,
            "trigger": trigger,
            "reason": effective_reason,
        },
    )

    logger.info(
        f"[AC-AISVC-RES-07] Strategy rolled back: {previous_strategy.value} -> default"
    )

    audit_log_id = str(audit_log.timestamp) if audit_log else None
    return RollbackResponse(
        success=True,
        previous_strategy=previous_strategy.value,
        current_strategy=default_name,
        trigger=trigger,
        reason=effective_reason,
        audit_log_id=audit_log_id,
    )

View File

@ -580,6 +580,35 @@ async def respond_dialogue(
f"guardrail={guardrail_result.triggered}, kb_hit={final_trace.kb_hit}"
)
if dialogue_request.user_id:
try:
from app.services.mid.memory_adapter import MemoryAdapter
from app.services.mid.memory_summary_generator import MemorySummaryGenerator
memory_adapter = MemoryAdapter(session=session)
summary_generator = MemorySummaryGenerator()
history_messages = [
{"role": h.role, "content": h.content}
for h in (dialogue_request.history or [])
]
assistant_reply = "\n".join(s.text for s in final_segments)
update_messages = history_messages + [
{"role": "user", "content": dialogue_request.user_message},
{"role": "assistant", "content": assistant_reply},
]
await memory_adapter.update_with_summary_generation(
user_id=dialogue_request.user_id,
session_id=dialogue_request.session_id,
messages=update_messages,
tenant_id=tenant_id,
summary_generator=summary_generator,
recent_turns=8,
)
except Exception as e:
logger.warning(f"[AC-IDMP-14] Memory update trigger failed: {e}")
return DialogueResponse(
segments=final_segments,
trace=final_trace,
@ -1429,10 +1458,38 @@ async def _execute_agent_mode(
runtime_observer.record_react(request_id, react_ctx.iteration, react_ctx.tool_calls)
# 合并 tool_calls优先使用 KB 工具内部的 trace包含注入后的参数
final_tool_calls = list(react_ctx.tool_calls) if react_ctx.tool_calls else []
logger.info(
f"[TRACE-MERGE] Before merge: final_tool_calls count={len(final_tool_calls)}, "
f"kb_dynamic_result exists={kb_dynamic_result is not None}, "
f"kb_dynamic_result.tool_trace exists={kb_dynamic_result.tool_trace if kb_dynamic_result else None}"
)
if kb_dynamic_result and kb_dynamic_result.tool_trace:
kb_trace = kb_dynamic_result.tool_trace
logger.info(
f"[TRACE-MERGE] KB trace arguments: {kb_trace.arguments}"
)
for i, tc in enumerate(final_tool_calls):
logger.info(
f"[TRACE-MERGE] Checking tool_call[{i}]: tool_name={tc.tool_name}"
)
if tc.tool_name == "kb_search_dynamic":
logger.info(
f"[TRACE-MERGE] Replacing trace at index {i}: old_args={tc.arguments}, new_args={kb_trace.arguments}"
)
final_tool_calls[i] = kb_trace
break
else:
logger.info(
f"[TRACE-MERGE] Skipped merge: kb_dynamic_result={kb_dynamic_result is not None}, "
f"tool_trace={kb_dynamic_result.tool_trace if kb_dynamic_result else 'N/A'}"
)
trace_logger.update_trace(
request_id=request_id,
react_iterations=react_ctx.iteration,
tool_calls=react_ctx.tool_calls,
tool_calls=final_tool_calls,
)
segments = _text_to_segments(final_answer)
@ -1444,8 +1501,8 @@ async def _execute_agent_mode(
request_id=trace.request_id,
generation_id=trace.generation_id,
react_iterations=react_ctx.iteration,
tools_used=[tc.tool_name for tc in react_ctx.tool_calls] if react_ctx.tool_calls else None,
tool_calls=react_ctx.tool_calls,
tools_used=[tc.tool_name for tc in final_tool_calls] if final_tool_calls else None,
tool_calls=final_tool_calls,
timeout_profile=timeout_governor.profile,
kb_tool_called=True,
kb_hit=kb_success and len(kb_hits) > 0,

View File

@ -176,14 +176,38 @@ async def share_chat_page(
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>对话分享</title>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;500;600&display=swap" rel="stylesheet">
<style>
:root {{
--bg-primary: #f8fafc;
--bg-secondary: #ffffff;
--bg-tertiary: #f1f5f9;
--text-primary: #0f172a;
--text-secondary: #475569;
--text-muted: #94a3b8;
--accent: #6366f1;
--accent-hover: #4f46e5;
--accent-light: #eef2ff;
--border: #e2e8f0;
--shadow-sm: 0 1px 2px rgba(0,0,0,0.04);
--shadow-md: 0 4px 12px rgba(0,0,0,0.06);
--shadow-lg: 0 8px 32px rgba(0,0,0,0.08);
--radius-sm: 8px;
--radius-md: 12px;
--radius-lg: 16px;
--radius-xl: 24px;
}}
* {{ box-sizing: border-box; margin: 0; padding: 0; }}
body {{
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'PingFang SC', 'Microsoft YaHei', sans-serif;
background: #f8f9fa;
font-family: 'Plus Jakarta Sans', -apple-system, BlinkMacSystemFont, 'PingFang SC', 'Microsoft YaHei', sans-serif;
background: var(--bg-primary);
min-height: 100vh;
display: flex;
flex-direction: column;
color: var(--text-primary);
line-height: 1.6;
}}
.welcome-screen {{
flex: 1;
@ -191,105 +215,221 @@ async def share_chat_page(
flex-direction: column;
align-items: center;
justify-content: center;
padding: 40px 20px;
padding: 48px 24px;
background: linear-gradient(180deg, var(--bg-secondary) 0%, var(--bg-primary) 100%);
}}
.welcome-screen.hidden {{ display: none; }}
.welcome-title {{
font-size: 28px;
font-weight: 600;
color: var(--text-primary);
margin-bottom: 32px;
letter-spacing: -0.02em;
}}
.welcome-input-wrapper {{
width: 100%;
max-width: 680px;
background: white;
border-radius: 16px;
padding: 16px 20px;
box-shadow: 0 2px 12px rgba(0,0,0,0.08);
max-width: 600px;
background: var(--bg-secondary);
border-radius: var(--radius-xl);
padding: 8px;
box-shadow: var(--shadow-lg);
border: 1px solid var(--border);
display: flex;
align-items: flex-end;
gap: 8px;
transition: box-shadow 0.2s ease, border-color 0.2s ease;
}}
.welcome-input-wrapper:focus-within {{
box-shadow: var(--shadow-lg), 0 0 0 3px var(--accent-light);
border-color: var(--accent);
}}
.welcome-textarea, .input-textarea {{
width: 100%;
min-height: 56px;
border: 1px solid #e5e5e5;
border-radius: 12px;
padding: 12px 16px;
flex: 1;
min-height: 48px;
max-height: 200px;
border: none;
border-radius: var(--radius-lg);
padding: 14px 16px;
resize: none;
outline: none;
font-size: 15px;
line-height: 1.6;
line-height: 1.5;
font-family: inherit;
background: #fafafa;
transition: all 0.2s;
background: transparent;
color: var(--text-primary);
}}
.welcome-textarea:focus, .input-textarea:focus {{
border-color: #1677ff;
background: white;
.welcome-textarea::placeholder, .input-textarea::placeholder {{
color: var(--text-muted);
}}
.chat-screen {{ flex: 1; display: none; flex-direction: column; }}
.chat-screen {{ flex: 1; display: none; flex-direction: column; background: var(--bg-primary); }}
.chat-screen.active {{ display: flex; }}
.chat-list {{
flex: 1;
padding: 20px;
max-width: 800px;
padding: 24px;
max-width: 720px;
width: 100%;
margin: 0 auto;
display: flex;
flex-direction: column;
gap: 16px;
gap: 20px;
overflow-y: auto;
}}
.bubble {{ display: flex; gap: 10px; align-items: flex-start; }}
.bubble {{ display: flex; gap: 12px; align-items: flex-start; animation: fadeIn 0.3s ease; }}
@keyframes fadeIn {{
from {{ opacity: 0; transform: translateY(8px); }}
to {{ opacity: 1; transform: translateY(0); }}
}}
.bubble.user {{ flex-direction: row-reverse; }}
.avatar {{
width: 32px; height: 32px; border-radius: 50%;
background: white; display: flex; align-items: center; justify-content: center;
width: 36px; height: 36px; border-radius: 50%;
background: var(--bg-tertiary);
display: flex; align-items: center; justify-content: center;
flex-shrink: 0;
overflow: hidden;
}}
.bubble-content {{ max-width: 75%; }}
.avatar svg {{
width: 22px;
height: 22px;
}}
.bubble.user .avatar {{
background: var(--accent-light);
}}
.bubble.user .avatar svg path {{
fill: var(--accent);
}}
.bubble.bot .avatar svg path {{
fill: var(--accent);
}}
.bubble-content {{ max-width: 80%; min-width: 0; }}
.bubble-text {{
padding: 12px 16px; border-radius: 16px; white-space: pre-wrap; word-break: break-word;
font-size: 14px; line-height: 1.6;
padding: 14px 18px;
border-radius: var(--radius-lg);
white-space: pre-wrap;
word-break: break-word;
font-size: 14px;
line-height: 1.65;
}}
.bubble.user .bubble-text {{
background: var(--accent);
color: white;
border-bottom-right-radius: var(--radius-sm);
}}
.bubble.bot .bubble-text {{
background: var(--bg-secondary);
color: var(--text-primary);
border-bottom-left-radius: var(--radius-sm);
box-shadow: var(--shadow-sm);
}}
.bubble.error .bubble-text {{
background: #fef2f2;
color: #dc2626;
border: 1px solid #fecaca;
}}
.bubble.user .bubble-text {{ background: #1677ff; color: white; }}
.bubble.bot .bubble-text {{ background: white; color: #333; }}
.bubble.error .bubble-text {{ background: #fff2f0; color: #ff4d4f; border: 1px solid #ffccc7; }}
.thought-block {{
background: #f5f5f5;
color: #888;
padding: 12px 16px;
border-radius: 12px;
background: var(--bg-tertiary);
color: var(--text-secondary);
padding: 14px 18px;
border-radius: var(--radius-md);
margin-bottom: 12px;
font-size: 13px;
line-height: 1.6;
border-left: 3px solid #ddd;
line-height: 1.65;
border-left: 3px solid var(--text-muted);
}}
.thought-label {{
font-weight: 600;
color: #999;
margin-bottom: 6px;
font-size: 12px;
color: var(--text-muted);
margin-bottom: 8px;
font-size: 11px;
text-transform: uppercase;
letter-spacing: 0.05em;
}}
.final-answer-block {{
background: white;
color: #333;
padding: 12px 16px;
border-radius: 12px;
background: var(--bg-secondary);
color: var(--text-primary);
padding: 14px 18px;
border-radius: var(--radius-md);
font-size: 14px;
line-height: 1.6;
line-height: 1.65;
}}
.final-answer-label {{
font-weight: 600;
color: #1677ff;
margin-bottom: 6px;
font-size: 12px;
color: var(--accent);
margin-bottom: 8px;
font-size: 11px;
text-transform: uppercase;
letter-spacing: 0.05em;
}}
.input-area {{
background: var(--bg-secondary);
padding: 16px 24px 24px;
border-top: 1px solid var(--border);
}}
.input-wrapper {{
max-width: 720px;
margin: 0 auto;
display: flex;
gap: 12px;
align-items: flex-end;
background: var(--bg-tertiary);
border-radius: var(--radius-lg);
padding: 6px 6px 6px 16px;
border: 1px solid var(--border);
transition: border-color 0.2s ease, box-shadow 0.2s ease;
}}
.input-wrapper:focus-within {{
border-color: var(--accent);
box-shadow: 0 0 0 3px var(--accent-light);
}}
.input-textarea {{
background: transparent;
padding: 10px 0;
}}
.input-area {{ background: white; padding: 16px 20px 20px; border-top: 1px solid #eee; }}
.input-wrapper {{ max-width: 800px; margin: 0 auto; display: flex; gap: 12px; align-items: flex-end; }}
.send-btn, .welcome-send {{
width: 40px; height: 40px; border-radius: 50%; border: none; cursor: pointer;
background: #1677ff; color: white;
width: 44px; height: 44px;
border-radius: 50%;
border: none;
cursor: pointer;
background: var(--accent);
color: white;
font-size: 18px;
display: flex;
align-items: center;
justify-content: center;
transition: background 0.2s ease, transform 0.15s ease;
flex-shrink: 0;
}}
.send-btn:hover, .welcome-send:hover {{
background: var(--accent-hover);
transform: scale(1.05);
}}
.send-btn:active, .welcome-send:active {{
transform: scale(0.95);
}}
.send-btn:disabled, .welcome-send:disabled {{
background: var(--text-muted);
cursor: not-allowed;
transform: none;
}}
.status {{
text-align: center;
padding: 12px;
font-size: 13px;
color: var(--text-muted);
font-weight: 500;
}}
.status.error {{ color: #dc2626; }}
@media (max-width: 640px) {{
.welcome-screen {{ padding: 32px 16px; }}
.welcome-title {{ font-size: 22px; margin-bottom: 24px; }}
.chat-list {{ padding: 16px; gap: 16px; }}
.bubble-content {{ max-width: 85%; }}
.input-area {{ padding: 12px 16px 20px; }}
}}
.status {{ text-align: center; padding: 8px; font-size: 12px; color: #999; }}
.status.error {{ color: #ff4d4f; }}
</style>
</head>
<body>
<div class="welcome-screen" id="welcomeScreen">
<h1>今天有什么可以帮到你</h1>
<h1 class="welcome-title">今天有什么可以帮到你</h1>
<div class="welcome-input-wrapper">
<textarea class="welcome-textarea" id="welcomeInput" placeholder="输入消息,按 Enter 发送" rows="1"></textarea>
<button class="welcome-send" id="welcomeSendBtn"></button>
@ -367,7 +507,12 @@ function formatBotMessage(text) {{
function addMessage(role, text) {{
const div = document.createElement('div');
div.className = 'bubble ' + role;
const avatar = role === 'user' ? '👤' : (role === 'bot' ? '🤖' : '⚠️');
const userSvg = '<svg viewBox="0 0 1024 1024" xmlns="http://www.w3.org/2000/svg"><path d="M573.9 516.2L512 640l-61.9-123.8C232 546.4 64 733.6 64 960h896c0-226.4-168-413.6-386.1-443.8zM480 384h64c17.7 0 32.1 14.4 32.1 32.1 0 17.7-14.4 32.1-32.1 32.1h-64c-11.9 0-22.3-6.5-27.8-16.1H356c34.9 48.5 91.7 80 156 80 106 0 192-86 192-192s-86-192-192-192-192 86-192 192c0 28.5 6.2 55.6 17.4 80h114.8c5.5-9.6 15.9-16.1 27.8-16.1z"/><path d="M272 432.1h84c-4.2-5.9-8.1-12-11.7-18.4-2.3-4.1-4.4-8.3-6.4-12.5-0.2-0.4-0.4-0.7-0.5-1.1H288c-8.8 0-16-7.2-16-16v-48.4c0-64.1 25-124.3 70.3-169.6S447.9 95.8 512 95.8s124.3 25 169.7 70.3c38.3 38.3 62.1 87.2 68.5 140.2-8.4 4-14.2 12.5-14.2 22.4v78.6c0 13.7 11.1 24.8 24.8 24.8h14.6c13.7 0 24.8-11.1 24.8-24.8v-78.6c0-11.3-7.6-20.9-18-23.8-6.9-60.9-33.9-117.4-78-161.3C652.9 92.1 584.6 63.9 512 63.9s-140.9 28.3-192.3 79.6C268.3 194.8 240 263.1 240 335.7v64.4c0 17.7 14.3 32 32 32z"/></svg>';
const botSvg = '<svg viewBox="0 0 1024 1024" xmlns="http://www.w3.org/2000/svg"><path d="M894.1 355.6h-1.7C853 177.6 687.6 51.4 498.1 54.9S148.2 190.5 115.9 369.7c-35.2 5.6-61.1 36-61.1 71.7v143.4c0.9 40.4 34.3 72.5 74.7 71.7 21.7-0.3 42.2-10 56-26.7 33.6 84.5 99.9 152 183.8 187 1.1-2 2.3-3.9 3.7-5.7 0.9-1.5 2.4-2.6 4.1-3 1.3 0 2.5 0.5 3.6 1.2a318.46 318.46 0 0 1-105.3-187.1c-5.1-44.4 24.1-85.4 67.6-95.2 64.3-11.7 128.1-24.7 192.4-35.9 37.9-5.3 70.4-29.8 85.7-64.9 6.8-15.9 11-32.8 12.5-50 0.5-3.1 2.9-5.6 5.9-6.2 3.1-0.7 6.4 0.5 8.2 3l1.7-1.1c25.4 35.9 74.7 114.4 82.7 197.2 8.2 94.8 3.7 160-71.4 226.5-1.1 1.1-1.7 2.6-1.7 4.1 0.1 2 1.1 3.8 2.8 4.8h4.8l3.2-1.8c75.6-40.4 132.8-108.2 159.9-189.5 11.4 16.1 28.5 27.1 47.8 30.8C846 783.9 716.9 871.6 557.2 884.9c-12-28.6-42.5-44.8-72.9-38.6-33.6 5.4-56.6 37-51.2 70.6 4.4 27.6 26.8 48.8 54.5 51.6 30.6 4.6 60.3-13 70.8-42.2 184.9-14.5 333.2-120.8 364.2-286.9 27.8-10.8 46.3-37.4 46.6-67.2V428.7c-0.1-19.5-8.1-38.2-22.3-51.6-14.5-13.8-33.8-21.4-53.8-21.3l1-0.2zM825.9 397c-71.1-176.9-272.1-262.7-449-191.7-86.8 34.9-155.7 103.4-191 190-2.5-2.8-5.2-5.4-8-7.9 25.3-154.6 163.8-268.6 326.8-269.2s302.3 112.6 328.7 267c-2.9 3.8-5.4 7.7-7.5 11.8z"/></svg>';
const errorSvg = '<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><circle cx="12" cy="12" r="10"/><line x1="12" y1="8" x2="12" y2="12"/><line x1="12" y1="16" x2="12.01" y2="16"/></svg>';
const avatar = role === 'user' ? userSvg : (role === 'bot' ? botSvg : errorSvg);
let contentHtml;
if (role === 'bot') {{

View File

@ -23,6 +23,9 @@ class Settings(BaseSettings):
log_level: str = "INFO"
kb_vector_log_enabled: bool = False
kb_vector_log_path: str = "logs/kb_vector_payload.log"
llm_provider: str = "openai"
llm_api_key: str = ""
llm_base_url: str = "https://api.openai.com/v1"

View File

@ -492,23 +492,51 @@ class QdrantClient:
构建 Qdrant 过滤条件
Args:
metadata_filter: 元数据过滤条件 {"grade": "三年级", "subject": "语文"}
metadata_filter: 元数据过滤条件支持两种格式
- 简单值格式: {"grade": "三年级", "subject": "语文"}
- 操作符格式: {"grade": {"$eq": "三年级"}, "kb_scene": {"$eq": "open_consult"}}
Returns:
Qdrant Filter 对象
"""
from qdrant_client.models import FieldCondition, Filter, MatchValue
from qdrant_client.models import FieldCondition, Filter, MatchValue, MatchAny
must_conditions = []
for key, value in metadata_filter.items():
# 支持嵌套 metadata 字段,如 metadata.grade
field_path = f"metadata.{key}"
condition = FieldCondition(
key=field_path,
match=MatchValue(value=value),
)
must_conditions.append(condition)
if isinstance(value, dict):
op = list(value.keys())[0] if value else None
actual_value = value.get(op) if op else None
if op == "$eq" and actual_value is not None:
condition = FieldCondition(
key=field_path,
match=MatchValue(value=actual_value),
)
must_conditions.append(condition)
elif op == "$in" and isinstance(actual_value, list):
condition = FieldCondition(
key=field_path,
match=MatchAny(any=actual_value),
)
must_conditions.append(condition)
else:
logger.warning(
f"[AC-AISVC-16] Unsupported filter operator: {op}, using as direct value"
)
condition = FieldCondition(
key=field_path,
match=MatchValue(value=value),
)
must_conditions.append(condition)
else:
condition = FieldCondition(
key=field_path,
match=MatchValue(value=value),
)
must_conditions.append(condition)
return Filter(must=must_conditions) if must_conditions else None

View File

@ -4,6 +4,8 @@ Main FastAPI application for AI Service.
"""
import logging
import os
from logging.handlers import RotatingFileHandler
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, status
@ -52,15 +54,51 @@ from app.core.qdrant_client import close_qdrant_client
settings = get_settings()
logging.basicConfig(
level=getattr(logging, settings.log_level.upper()),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.dialects").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.orm").setLevel(logging.WARNING)
def setup_logging():
    """
    Configure root logging with a console handler and a rotating file handler.

    - Log files live in a ``logs/`` directory next to this file's parent package.
    - The file handler rotates at 2 MB and keeps up to 70 backup files.
    - Any handlers previously attached to the root logger are removed first.
    - SQLAlchemy loggers are capped at WARNING.

    Returns:
        The path of the log directory.
    """
    log_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
    os.makedirs(log_dir, exist_ok=True)

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    level = getattr(logging, settings.log_level.upper())

    root = logging.getLogger()
    root.setLevel(level)
    root.handlers.clear()

    def _attach(handler):
        # Same level and format for every handler on the root logger.
        handler.setLevel(level)
        handler.setFormatter(formatter)
        root.addHandler(handler)

    # Console first, then file: the console handler is already attached
    # should opening the log file fail.
    _attach(logging.StreamHandler())
    _attach(
        RotatingFileHandler(
            filename=os.path.join(log_dir, "ai-service.log"),
            maxBytes=2 * 1024 * 1024,
            backupCount=70,
            encoding="utf-8",
        )
    )

    for noisy in (
        "sqlalchemy.engine",
        "sqlalchemy.pool",
        "sqlalchemy.dialects",
        "sqlalchemy.orm",
    ):
        logging.getLogger(noisy).setLevel(logging.WARNING)

    return log_dir
setup_logging()
logger = logging.getLogger(__name__)

View File

@ -116,6 +116,44 @@ class ChatMessageCreate(SQLModel):
content: str
class UserMemory(SQLModel, table=True):
    """
    [AC-IDMP-14] User-level memory record holding a rolling summary.

    Each row carries a tenant_id for multi-tenant isolation and stores the
    latest summary together with facts / preferences / open_issues.
    """

    __tablename__ = "user_memories"
    # Composite indexes on (tenant_id, user_id) and (tenant_id, user_id, updated_at);
    # tenant_id and user_id are additionally indexed individually below.
    __table_args__ = (
        Index("ix_user_memories_tenant_user", "tenant_id", "user_id"),
        Index("ix_user_memories_tenant_user_updated", "tenant_id", "user_id", "updated_at"),
    )

    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True)
    user_id: str = Field(..., description="User ID for memory storage", index=True)
    summary: str | None = Field(default=None, description="Rolling summary for user")
    # facts / preferences / open_issues are stored as nullable JSON columns.
    facts: list[str] | None = Field(
        default=None,
        sa_column=Column("facts", JSON, nullable=True),
        description="Extracted stable facts list",
    )
    preferences: dict[str, Any] | None = Field(
        default=None,
        sa_column=Column("preferences", JSON, nullable=True),
        description="User preferences as structured JSON",
    )
    open_issues: list[str] | None = Field(
        default=None,
        sa_column=Column("open_issues", JSON, nullable=True),
        description="Open issues list",
    )
    summary_version: int = Field(default=1, description="Summary version / update round")
    last_turn_id: str | None = Field(default=None, description="Last turn identifier (optional)")
    expires_at: datetime | None = Field(default=None, description="Expiration time (optional)")
    # Both timestamps default to datetime.utcnow at object creation.
    # NOTE(review): no on-update hook is visible here; presumably writers set
    # updated_at explicitly - confirm against the memory adapter.
    created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
    updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
class SharedSession(SQLModel, table=True):
"""
[AC-IDMP-SHARE] Shared session entity for dialogue sharing.

View File

@ -16,6 +16,16 @@ from app.services.document.factory import (
get_supported_document_formats,
parse_document,
)
from app.services.document.image_parser import ImageParser
from app.services.document.markdown_chunker import (
MarkdownChunk,
MarkdownChunker,
MarkdownElement,
MarkdownElementType,
MarkdownParser as MarkdownStructureParser,
chunk_markdown,
)
from app.services.document.markdown_parser import MarkdownParser
from app.services.document.pdf_parser import PDFParser, PDFPlumberParser
from app.services.document.text_parser import TextParser
from app.services.document.word_parser import WordParser
@ -35,4 +45,12 @@ __all__ = [
"ExcelParser",
"CSVParser",
"TextParser",
"MarkdownParser",
"MarkdownChunker",
"MarkdownChunk",
"MarkdownElement",
"MarkdownElementType",
"MarkdownStructureParser",
"chunk_markdown",
"ImageParser",
]

View File

@ -16,6 +16,8 @@ from app.services.document.base import (
UnsupportedFormatError,
)
from app.services.document.excel_parser import CSVParser, ExcelParser
from app.services.document.image_parser import ImageParser
from app.services.document.markdown_parser import MarkdownParser
from app.services.document.pdf_parser import PDFParser, PDFPlumberParser
from app.services.document.text_parser import TextParser
from app.services.document.word_parser import WordParser
@ -45,6 +47,8 @@ class DocumentParserFactory:
"excel": ExcelParser,
"csv": CSVParser,
"text": TextParser,
"markdown": MarkdownParser,
"image": ImageParser,
}
cls._extension_map = {
@ -54,14 +58,22 @@ class DocumentParserFactory:
".xls": "excel",
".csv": "csv",
".txt": "text",
".md": "text",
".markdown": "text",
".md": "markdown",
".markdown": "markdown",
".rst": "text",
".log": "text",
".json": "text",
".xml": "text",
".yaml": "text",
".yml": "text",
".jpg": "image",
".jpeg": "image",
".png": "image",
".gif": "image",
".webp": "image",
".bmp": "image",
".tiff": "image",
".tif": "image",
}
@classmethod
@ -174,6 +186,8 @@ class DocumentParserFactory:
"excel": "Excel 电子表格",
"csv": "CSV 文件",
"text": "文本文件",
"markdown": "Markdown 文档",
"image": "图片文件",
}
descriptions = {
@ -183,6 +197,8 @@ class DocumentParserFactory:
"excel": "解析 Excel 电子表格,支持多工作表",
"csv": "解析 CSV 文件,自动检测编码",
"text": "解析纯文本文件,支持多种编码",
"markdown": "智能解析 Markdown 文档,保留结构(标题、代码块、表格、列表)",
"image": "使用多模态 LLM 解析图片,提取文字和关键信息",
}
info.append({

View File

@ -0,0 +1,490 @@
"""
Image parser using multimodal LLM.
Supports parsing images into structured text content for knowledge base indexing.
"""
import asyncio
import base64
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from app.services.document.base import (
DocumentParseException,
DocumentParser,
PageText,
ParseResult,
)
from app.services.llm.factory import LLMUsageType, get_llm_config_manager
logger = logging.getLogger(__name__)
IMAGE_SYSTEM_PROMPT = """你是一个专业的图像内容分析助手。你的任务是分析图片内容,并将其智能拆分为适合知识库检索的独立数据块。
## 分析要求
1. 仔细分析图片内容识别其中的文字图表数据等信息
2. 根据内容的逻辑结构智能判断如何拆分为独立的知识条目
3. 每个条目应该是独立完整可检索的知识单元
## 输出格式
请严格按照以下 JSON 格式输出不要添加任何其他内容
```json
{
"image_summary": "图片整体概述(一句话描述图片主题)",
"total_chunks": <分块总数>,
"chunks": [
{
"chunk_index": 0,
"content": "该分块的完整内容文字",
"chunk_type": "text|table|list|diagram|chart|mixed",
"keywords": ["关键词1", "关键词2"]
}
]
}
```
## 分块策略
- **单一内容**: 如果图片只有一段完整的文字/信息可以只输出1个分块
- **多段落内容**: 按段落或逻辑单元拆分每个段落作为独立分块
- **表格数据**: 将表格内容转换为结构化文字作为一个分块
- **图表数据**: 描述图表内容和数据作为一个分块
- **列表内容**: 每个列表项可作为独立分块或合并为相关的一组
- **混合内容**: 根据内容类型分别处理确保每个分块主题明确
## 注意事项
1. 每个分块的 content 必须是完整可独立理解的文字
2. chunk_type 用于标识内容类型便于后续处理
3. keywords 提取该分块的核心关键词便于检索
4. 确保输出的 JSON 格式正确可以被解析"""
IMAGE_USER_PROMPT = "请分析这张图片,按照要求的 JSON 格式输出分块结果。"
@dataclass
class ImageChunk:
    """A single content chunk produced by the intelligent image chunking step."""

    # Index of the chunk as reported in the LLM output.
    chunk_index: int
    # Text content of the chunk.
    content: str
    # Content type label; the system prompt asks for one of
    # text|table|list|diagram|chart|mixed.
    chunk_type: str = "text"
    # Keywords extracted for this chunk; defaults to a fresh empty list.
    keywords: list[str] = field(default_factory=list)
@dataclass
class ImageParseResult:
    """Result of parsing an image, including the intelligent chunks."""

    # One-line overview of the image taken from the LLM output.
    image_summary: str
    # Chunks extracted from the image.
    chunks: list[ImageChunk]
    # Text of the parse: either the chunk contents joined by blank lines or,
    # on the JSON fallback path, the raw LLM response.
    raw_text: str
    # Initialised to "" / 0 by the response parser and filled in afterwards by
    # parse_with_chunks.
    source_path: str
    file_size: int
    # Extra information about the parse (format, parser, mime_type when set
    # by parse_with_chunks).
    metadata: dict[str, Any] = field(default_factory=dict)
class ImageParser(DocumentParser):
"""
Image parser using multimodal LLM.
Supports common image formats and extracts text content using
vision-capable LLM models (GPT-4V, GPT-4o, etc.).
Features:
- Intelligent chunking based on content structure
- Structured output with keywords and chunk types
- Support for various content types (text, table, chart, etc.)
"""
SUPPORTED_EXTENSIONS = [
".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"
]
def __init__(
self,
model: str | None = None,
max_tokens: int = 4096,
timeout_seconds: int = 120,
):
self._model = model
self._max_tokens = max_tokens
self._timeout_seconds = timeout_seconds
def parse(self, file_path: str | Path) -> ParseResult:
    """
    Parse an image file and extract text content using a multimodal LLM.

    This method is synchronous but drives an async LLM call internally. When
    it is invoked while an event loop is already running in the current
    thread, the analysis is run via ``asyncio.run`` on a worker thread;
    otherwise ``asyncio.run`` is called directly. In async code prefer
    ``parse_async()``.

    Args:
        file_path: Path to the image file.

    Returns:
        ParseResult with the extracted text, a single page and chunk metadata.

    Raises:
        DocumentParseException: If the file is missing, the extension is not
            supported, or reading/analysis fails.
    """
    path = Path(file_path)
    if not path.exists():
        raise DocumentParseException(
            f"Image file not found: {file_path}",
            file_path=str(path),
            parser="image",
        )

    file_size = path.stat().st_size
    extension = path.suffix.lower()
    if extension not in self.SUPPORTED_EXTENSIONS:
        raise DocumentParseException(
            f"Unsupported image format: {extension}",
            file_path=str(path),
            parser="image",
            details={"supported_formats": self.SUPPORTED_EXTENSIONS},
        )

    try:
        with open(path, "rb") as f:
            image_data = f.read()
        image_base64 = base64.b64encode(image_data).decode("utf-8")
        mime_type = self._get_mime_type(extension)

        # Only the loop probe may be guarded by RuntimeError. Guarding the
        # analysis too would treat a RuntimeError raised by the LLM call as
        # "no running loop" and retry with asyncio.run inside a live loop.
        try:
            asyncio.get_running_loop()
            loop_is_running = True
        except RuntimeError:
            loop_is_running = False

        if loop_is_running:
            import concurrent.futures

            # asyncio.run cannot be used inside a running loop, so run the
            # coroutine on a separate thread with its own loop.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(
                    asyncio.run,
                    self._analyze_image_async(image_base64, mime_type),
                )
                result = future.result()
        else:
            result = asyncio.run(self._analyze_image_async(image_base64, mime_type))

        logger.info(
            f"[IMAGE-PARSER] Successfully parsed image: {path.name}, "
            f"size={file_size}, chunks={len(result.chunks)}"
        )

        return ParseResult(
            text=result.raw_text,
            source_path=str(path),
            file_size=file_size,
            page_count=1,
            metadata={
                "format": extension,
                "parser": "image",
                "mime_type": mime_type,
                "image_summary": result.image_summary,
                "chunk_count": len(result.chunks),
                "chunks": [
                    {
                        "chunk_index": c.chunk_index,
                        "content": c.content,
                        "chunk_type": c.chunk_type,
                        "keywords": c.keywords,
                    }
                    for c in result.chunks
                ],
            },
            pages=[PageText(page=1, text=result.raw_text)],
        )
    except Exception as e:
        logger.error(f"[IMAGE-PARSER] Failed to parse image {path}: {e}")
        # Chain explicitly so the original failure is kept as __cause__.
        raise DocumentParseException(
            f"Failed to parse image: {str(e)}",
            file_path=str(path),
            parser="image",
            details={"error": str(e)},
        ) from e
async def parse_async(self, file_path: str | Path) -> ParseResult:
    """
    Coroutine counterpart of ``parse`` for callers already inside an event loop.

    Args:
        file_path: Path to the image file.

    Returns:
        ParseResult with the extracted text, a single page and chunk metadata.

    Raises:
        DocumentParseException: If the file is missing, the extension is not
            supported, or reading/analysis fails.
    """
    path = Path(file_path)

    # Guard clauses: reject missing files and unsupported extensions up front.
    if not path.exists():
        raise DocumentParseException(
            f"Image file not found: {file_path}",
            file_path=str(path),
            parser="image",
        )
    file_size = path.stat().st_size
    extension = path.suffix.lower()
    if extension not in self.SUPPORTED_EXTENSIONS:
        raise DocumentParseException(
            f"Unsupported image format: {extension}",
            file_path=str(path),
            parser="image",
            details={"supported_formats": self.SUPPORTED_EXTENSIONS},
        )

    try:
        encoded = base64.b64encode(path.read_bytes()).decode("utf-8")
        mime_type = self._get_mime_type(extension)
        result = await self._analyze_image_async(encoded, mime_type)

        logger.info(
            f"[IMAGE-PARSER] Successfully parsed image (async): {path.name}, "
            f"size={file_size}, chunks={len(result.chunks)}"
        )

        chunk_dicts = []
        for c in result.chunks:
            chunk_dicts.append(
                {
                    "chunk_index": c.chunk_index,
                    "content": c.content,
                    "chunk_type": c.chunk_type,
                    "keywords": c.keywords,
                }
            )
        metadata = {
            "format": extension,
            "parser": "image",
            "mime_type": mime_type,
            "image_summary": result.image_summary,
            "chunk_count": len(result.chunks),
            "chunks": chunk_dicts,
        }
        return ParseResult(
            text=result.raw_text,
            source_path=str(path),
            file_size=file_size,
            page_count=1,
            metadata=metadata,
            pages=[PageText(page=1, text=result.raw_text)],
        )
    except Exception as e:
        logger.error(f"[IMAGE-PARSER] Failed to parse image {path}: {e}")
        raise DocumentParseException(
            f"Failed to parse image: {str(e)}",
            file_path=str(path),
            parser="image",
            details={"error": str(e)},
        )
async def parse_with_chunks(self, file_path: str | Path) -> ImageParseResult:
    """
    Analyse an image and return the structured, chunked result directly.

    Unlike ``parse``/``parse_async`` the LLM result is not converted to a
    ParseResult; its ``source_path``, ``file_size`` and ``metadata`` fields
    are filled in here before it is returned.

    Args:
        file_path: Path to the image file.

    Returns:
        ImageParseResult with intelligent chunks.

    Raises:
        DocumentParseException: If the file is missing or the extension is
            not supported.
    """
    path = Path(file_path)
    if not path.exists():
        raise DocumentParseException(
            f"Image file not found: {file_path}",
            file_path=str(path),
            parser="image",
        )

    file_size = path.stat().st_size
    extension = path.suffix.lower()
    if extension not in self.SUPPORTED_EXTENSIONS:
        raise DocumentParseException(
            f"Unsupported image format: {extension}",
            file_path=str(path),
            parser="image",
            details={"supported_formats": self.SUPPORTED_EXTENSIONS},
        )

    encoded = base64.b64encode(path.read_bytes()).decode("utf-8")
    mime_type = self._get_mime_type(extension)

    parsed = await self._analyze_image_async(encoded, mime_type)
    # The analyser leaves the file-level fields blank; fill them in now.
    parsed.source_path = str(path)
    parsed.file_size = file_size
    parsed.metadata = dict(format=extension, parser="image", mime_type=mime_type)
    return parsed
async def _analyze_image_async(self, image_base64: str, mime_type: str) -> ImageParseResult:
    """
    Send the image to the KB-processing LLM client and parse its reply.

    The image is embedded as a ``data:`` URL in the user message next to the
    fixed user prompt; the system prompt requests JSON-formatted chunks.

    Args:
        image_base64: Base64 encoded image data.
        mime_type: MIME type of the image.

    Returns:
        ImageParseResult with intelligent chunks.

    Raises:
        DocumentParseException: If the LLM returns empty content.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        manager = get_llm_config_manager()
        client = manager.get_kb_processing_client()
        kb_config = manager.kb_processing_config

        # An explicit model on the parser wins over the configured one.
        if self._model:
            model_name = self._model
        else:
            model_name = kb_config.get("model", "gpt-4o-mini")

        text_part = {"type": "text", "text": IMAGE_USER_PROMPT}
        image_part = {
            "type": "image_url",
            "image_url": {"url": f"data:{mime_type};base64,{image_base64}"},
        }
        messages = [
            {"role": "system", "content": IMAGE_SYSTEM_PROMPT},
            {"role": "user", "content": [text_part, image_part]},
        ]

        from app.services.llm.base import LLMConfig

        response = await client.generate(
            messages=messages,
            config=LLMConfig(
                model=model_name,
                max_tokens=self._max_tokens,
                temperature=0.3,
                timeout_seconds=self._timeout_seconds,
            ),
        )

        if response.content:
            return self._parse_llm_response(response.content)

        raise DocumentParseException(
            "LLM returned empty response for image analysis",
            parser="image",
        )
    except Exception as e:
        logger.error(f"[IMAGE-PARSER] LLM analysis failed: {e}")
        raise
def _parse_llm_response(self, response_content: str) -> ImageParseResult:
    """
    Parse the raw LLM reply into a structured ImageParseResult.

    The reply is expected to contain a JSON object with an ``image_summary``
    string and a ``chunks`` list of objects. Anything that does not fit that
    shape is tolerated rather than raised: malformed entries are skipped, and
    if no usable chunk remains the whole reply becomes a single text chunk.

    Args:
        response_content: Raw LLM response content.

    Returns:
        ImageParseResult with parsed chunks. ``source_path`` and ``file_size``
        are left as placeholders ("" and 0).
    """
    def _whole_reply_chunk():
        # Fallback: keep the complete reply so no model output is lost.
        return ImageChunk(
            chunk_index=0,
            content=response_content,
            chunk_type="text",
            keywords=[],
        )

    def _fallback_result():
        return ImageParseResult(
            image_summary="图片内容",
            chunks=[_whole_reply_chunk()],
            raw_text=response_content,
            source_path="",
            file_size=0,
        )

    try:
        data = json.loads(self._extract_json(response_content))
    except json.JSONDecodeError as e:
        logger.warning(f"[IMAGE-PARSER] Failed to parse JSON response: {e}, using fallback")
        return _fallback_result()

    if not isinstance(data, dict):
        # Valid JSON, but not an object (e.g. a list or a bare string): there are
        # no named fields to read, so treat it like unparseable output.
        logger.warning(
            f"[IMAGE-PARSER] JSON response is {type(data).__name__}, not an object, using fallback"
        )
        return _fallback_result()

    summary = data.get("image_summary", "")
    if not isinstance(summary, str):
        summary = "" if summary is None else str(summary)

    raw_chunks = data.get("chunks", [])
    if not isinstance(raw_chunks, list):
        raw_chunks = []

    chunks = []
    for entry in raw_chunks:
        if not isinstance(entry, dict):
            continue
        content = entry.get("content", "")
        if not isinstance(content, str):
            # ``"content": null`` or a number must not crash the strip() below.
            content = "" if content is None else str(content)
        if not content.strip():
            continue
        keywords = entry.get("keywords", [])
        if not isinstance(keywords, list):
            keywords = []
        chunks.append(ImageChunk(
            chunk_index=entry.get("chunk_index", len(chunks)),
            content=content,
            chunk_type=entry.get("chunk_type", "text"),
            keywords=keywords,
        ))

    if not chunks:
        chunks.append(_whole_reply_chunk())

    return ImageParseResult(
        image_summary=summary,
        chunks=chunks,
        raw_text="\n\n".join(c.content for c in chunks),
        source_path="",
        file_size=0,
    )
def _extract_json(self, content: str) -> str:
"""
Extract JSON from LLM response content.
Args:
content: Raw response content that may contain JSON.
Returns:
Extracted JSON string.
"""
content = content.strip()
if content.startswith("{") and content.endswith("}"):
return content
json_start = content.find("{")
json_end = content.rfind("}")
if json_start != -1 and json_end != -1 and json_end > json_start:
return content[json_start:json_end + 1]
return content
def _get_mime_type(self, extension: str) -> str:
"""Get MIME type for image extension."""
mime_types = {
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".gif": "image/gif",
".webp": "image/webp",
".bmp": "image/bmp",
".tiff": "image/tiff",
".tif": "image/tiff",
}
return mime_types.get(extension.lower(), "image/jpeg")
def get_supported_extensions(self) -> list[str]:
    """Return the image extensions this parser accepts (the class-level collection itself, not a copy)."""
    supported = ImageParser.SUPPORTED_EXTENSIONS
    return supported

View File

@ -0,0 +1,771 @@
"""
Markdown intelligent chunker with structure-aware splitting.
Supports headers, code blocks, tables, lists, and preserves context.
"""
import logging
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
logger = logging.getLogger(__name__)
class MarkdownElementType(Enum):
    """Kinds of Markdown element; each value is the string used when serializing."""

    # Block-level structure
    HEADER = "header"
    PARAGRAPH = "paragraph"
    CODE_BLOCK = "code_block"
    # Inline code span
    INLINE_CODE = "inline_code"
    # More block-level structure
    TABLE = "table"
    LIST = "list"
    BLOCKQUOTE = "blockquote"
    HORIZONTAL_RULE = "horizontal_rule"
    # Inline references and plain text
    IMAGE = "image"
    LINK = "link"
    TEXT = "text"
@dataclass
class MarkdownElement:
    """One parsed Markdown element together with its 0-based source line span."""

    type: MarkdownElementType
    content: str
    level: int = 0
    language: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)
    line_start: int = 0
    line_end: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Return the element as a plain dict; ``type`` is given as its enum value."""
        serialized: dict[str, Any] = {"type": self.type.value}
        # Remaining fields are copied as-is, in declaration order.
        for attr in ("content", "level", "language", "metadata", "line_start", "line_end"):
            serialized[attr] = getattr(self, attr)
        return serialized
@dataclass
class MarkdownChunk:
    """A chunk of Markdown content plus the header titles it was found under."""

    chunk_id: str
    content: str
    element_type: MarkdownElementType
    header_context: list[str]
    level: int = 0
    language: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return the chunk as a plain dict; ``element_type`` is given as its enum value."""
        serialized: dict[str, Any] = {
            "chunk_id": self.chunk_id,
            "content": self.content,
            "element_type": self.element_type.value,
        }
        # Remaining fields are copied as-is, in declaration order.
        for attr in ("header_context", "level", "language", "metadata"):
            serialized[attr] = getattr(self, attr)
        return serialized
class MarkdownParser:
    """
    Line-based parser for Markdown documents.

    ``parse`` splits the text into lines and runs one extraction pass per
    element kind. Code blocks and tables are found first; their line ranges are
    "protected" so later passes (headers, lists, blockquotes, rules) do not
    re-interpret their contents. Lines left over become paragraphs.
    """

    HEADER_PATTERN = re.compile(r'^(#{1,6})\s+(.+?)(?:\s+#+)?$', re.MULTILINE)
    CODE_BLOCK_PATTERN = re.compile(r'^```(\w*)\n(.*?)^```', re.MULTILINE | re.DOTALL)
    TABLE_PATTERN = re.compile(r'^(\|.+\|)\n(\|[-:\s|]+\|)\n((?:\|.+\|\n?)+)', re.MULTILINE)
    # Group 1 = indentation + marker, group 2 = item text. The optional leading
    # whitespace is outside the alternation so indented numbered items
    # ("  1. foo") match as well as indented bullets.
    LIST_PATTERN = re.compile(r'^([ \t]*(?:[-*+]|\d+\.))\s+(.+)$', re.MULTILINE)
    BLOCKQUOTE_PATTERN = re.compile(r'^>\s*(.+)$', re.MULTILINE)
    HR_PATTERN = re.compile(r'^[-*_]{3,}\s*$', re.MULTILINE)
    IMAGE_PATTERN = re.compile(r'!\[([^\]]*)\]\(([^)]+)\)')
    LINK_PATTERN = re.compile(r'\[([^\]]+)\]\(([^)]+)\)')
    INLINE_CODE_PATTERN = re.compile(r'`([^`]+)`')
    # A table delimiter row consists only of pipes, dashes, colons and spaces.
    TABLE_SEPARATOR_PATTERN = re.compile(r'^[\|\-\:\s]+$')

    def parse(self, text: str) -> list[MarkdownElement]:
        """
        Parse Markdown text into structured elements.

        Args:
            text: Raw Markdown text

        Returns:
            List of MarkdownElement objects sorted by starting line (0-based).
        """
        elements: list[MarkdownElement] = []
        lines = text.split('\n')

        code_block_ranges = self._extract_code_blocks(text, lines, elements)
        # Tables are searched outside code blocks only, so a table shown inside a
        # fenced example is not emitted a second time as a TABLE element.
        table_ranges = self._extract_tables(text, lines, elements, code_block_ranges)
        protected_ranges = code_block_ranges + table_ranges

        self._extract_headers(lines, elements, protected_ranges)
        self._extract_lists(lines, elements, protected_ranges)
        self._extract_blockquotes(lines, elements, protected_ranges)
        self._extract_horizontal_rules(lines, elements, protected_ranges)
        self._fill_paragraphs(lines, elements, protected_ranges)

        elements.sort(key=lambda e: e.line_start)
        return elements

    def _extract_code_blocks(
        self,
        text: str,
        lines: list[str],
        elements: list[MarkdownElement],
    ) -> list[tuple[int, int]]:
        """
        Extract fenced code blocks with their language tag.

        A line whose stripped form starts with ``` toggles the fence. The fence
        lines themselves are part of the returned range but not of the content.
        ``text`` is unused and kept for signature compatibility.

        Returns:
            Inclusive (start, end) line ranges of every code block.
        """
        ranges: list[tuple[int, int]] = []
        in_code_block = False
        code_start = 0
        language = ""
        code_content: list[str] = []

        def _close(end_line: int) -> None:
            elements.append(MarkdownElement(
                type=MarkdownElementType.CODE_BLOCK,
                content='\n'.join(code_content),
                language=language,
                line_start=code_start,
                line_end=end_line,
                metadata={"language": language},
            ))
            ranges.append((code_start, end_line))

        for i, line in enumerate(lines):
            if line.strip().startswith('```'):
                if not in_code_block:
                    in_code_block = True
                    code_start = i
                    language = line.strip()[3:].strip()
                    code_content = []
                else:
                    in_code_block = False
                    _close(i)
            elif in_code_block:
                code_content.append(line)

        if in_code_block:
            # Unterminated fence: the block runs to the end of the document
            # instead of being dropped and re-parsed as ordinary Markdown.
            _close(len(lines) - 1)
        return ranges

    def _extract_tables(
        self,
        text: str,
        lines: list[str],
        elements: list[MarkdownElement],
        protected_ranges: list[tuple[int, int]] | None = None,
    ) -> list[tuple[int, int]]:
        """
        Extract pipe tables: a header row, a delimiter row, then data rows.

        Args:
            text: Unused; kept for signature compatibility.
            lines: Document lines.
            elements: Output list that TABLE elements are appended to.
            protected_ranges: Optional inclusive line ranges (e.g. code blocks)
                in which no table may start or continue.

        Returns:
            Inclusive (start, end) line ranges of every table.
        """
        skip = protected_ranges or []
        ranges: list[tuple[int, int]] = []
        i = 0
        while i < len(lines):
            if self._is_in_protected_range(i, skip):
                i += 1
                continue
            line = lines[i]
            if '|' in line and i + 1 < len(lines):
                next_line = lines[i + 1]
                is_separator = (
                    '|' in next_line
                    # A delimiter row needs at least one dash; "| |" is not one.
                    and '-' in next_line
                    and self.TABLE_SEPARATOR_PATTERN.match(next_line.strip()) is not None
                    and not self._is_in_protected_range(i + 1, skip)
                )
                if is_separator:
                    table_lines = [line, next_line]
                    j = i + 2
                    while (
                        j < len(lines)
                        and '|' in lines[j]
                        and not self._is_in_protected_range(j, skip)
                    ):
                        table_lines.append(lines[j])
                        j += 1
                    headers = [h.strip() for h in line.split('|') if h.strip()]
                    elements.append(MarkdownElement(
                        type=MarkdownElementType.TABLE,
                        content='\n'.join(table_lines),
                        line_start=i,
                        line_end=j - 1,
                        metadata={
                            "headers": headers,
                            # Header and delimiter rows are not data rows.
                            "row_count": len(table_lines) - 2,
                        },
                    ))
                    ranges.append((i, j - 1))
                    i = j
                    continue
            i += 1
        return ranges

    def _is_in_protected_range(self, line_num: int, ranges: list[tuple[int, int]]) -> bool:
        """Return True if ``line_num`` lies inside any inclusive (start, end) range."""
        return any(start <= line_num <= end for start, end in ranges)

    def _extract_headers(
        self,
        lines: list[str],
        elements: list[MarkdownElement],
        protected_ranges: list[tuple[int, int]],
    ) -> None:
        """Append a HEADER element (title text and 1-6 level) for every ATX header line."""
        for i, line in enumerate(lines):
            if self._is_in_protected_range(i, protected_ranges):
                continue
            match = self.HEADER_PATTERN.match(line)
            if match:
                level = len(match.group(1))
                elements.append(MarkdownElement(
                    type=MarkdownElementType.HEADER,
                    content=match.group(2).strip(),
                    level=level,
                    line_start=i,
                    line_end=i,
                    metadata={"level": level},
                ))

    def _extract_lists(
        self,
        lines: list[str],
        elements: list[MarkdownElement],
        protected_ranges: list[tuple[int, int]],
    ) -> None:
        """
        Group consecutive list items into LIST elements.

        Blank lines inside a list do not end it; any other non-item line or a
        protected range does. Whether the list is ordered is decided by the
        marker of its first item.
        """
        in_list = False
        list_start = 0
        list_ordered = False
        list_items: list[tuple[int, str]] = []

        for i, line in enumerate(lines):
            if self._is_in_protected_range(i, protected_ranges):
                if in_list:
                    self._save_list(elements, list_start, i - 1, list_items, list_ordered)
                    in_list = False
                    list_items = []
                continue

            match = self.LIST_PATTERN.match(line)
            if match:
                indent = len(line) - len(line.lstrip())
                item = (indent, match.group(2))
                if not in_list:
                    in_list = True
                    list_start = i
                    # "1." -> "1" is numeric; "-", "*", "+" -> "" is not.
                    list_ordered = match.group(1).strip()[:-1].isdigit()
                    list_items = [item]
                else:
                    list_items.append(item)
            elif in_list and line.strip() != '':
                self._save_list(elements, list_start, i - 1, list_items, list_ordered)
                in_list = False
                list_items = []

        if in_list:
            self._save_list(elements, list_start, len(lines) - 1, list_items, list_ordered)

    def _save_list(
        self,
        elements: list[MarkdownElement],
        start: int,
        end: int,
        items: list[tuple[int, str]],
        is_ordered: bool = False,
    ) -> None:
        """
        Append a LIST element built from ``items`` ((indent, text) pairs).

        The content is the item texts joined by newlines, without markers.
        Nothing is appended when ``items`` is empty.
        """
        if not items:
            return
        elements.append(MarkdownElement(
            type=MarkdownElementType.LIST,
            content='\n'.join(text for _, text in items),
            line_start=start,
            line_end=end,
            metadata={
                "item_count": len(items),
                "is_ordered": is_ordered,
            },
        ))

    def _extract_blockquotes(
        self,
        lines: list[str],
        elements: list[MarkdownElement],
        protected_ranges: list[tuple[int, int]],
    ) -> None:
        """Group consecutive ``>`` lines into BLOCKQUOTE elements (content without the marker)."""
        in_quote = False
        quote_start = 0
        quote_lines: list[str] = []

        for i, line in enumerate(lines):
            match = None
            if not self._is_in_protected_range(i, protected_ranges):
                match = self.BLOCKQUOTE_PATTERN.match(line)
            if match:
                if not in_quote:
                    in_quote = True
                    quote_start = i
                quote_lines.append(match.group(1))
            elif in_quote:
                # Any non-quote line, protected or not, closes the open quote.
                self._save_blockquote(elements, quote_start, i - 1, quote_lines)
                in_quote = False
                quote_lines = []

        if in_quote:
            self._save_blockquote(elements, quote_start, len(lines) - 1, quote_lines)

    def _save_blockquote(
        self,
        elements: list[MarkdownElement],
        start: int,
        end: int,
        lines: list[str],
    ) -> None:
        """Append a BLOCKQUOTE element for ``lines``; no-op when ``lines`` is empty."""
        if not lines:
            return
        elements.append(MarkdownElement(
            type=MarkdownElementType.BLOCKQUOTE,
            content='\n'.join(lines),
            line_start=start,
            line_end=end,
        ))

    def _extract_horizontal_rules(
        self,
        lines: list[str],
        elements: list[MarkdownElement],
        protected_ranges: list[tuple[int, int]],
    ) -> None:
        """Append a HORIZONTAL_RULE element for every unprotected rule line."""
        for i, line in enumerate(lines):
            if self._is_in_protected_range(i, protected_ranges):
                continue
            if self.HR_PATTERN.match(line):
                elements.append(MarkdownElement(
                    type=MarkdownElementType.HORIZONTAL_RULE,
                    content=line,
                    line_start=i,
                    line_end=i,
                ))

    def _fill_paragraphs(
        self,
        lines: list[str],
        elements: list[MarkdownElement],
        protected_ranges: list[tuple[int, int]],
    ) -> None:
        """
        Turn every run of non-blank lines not claimed by another element into a PARAGRAPH.

        A line is claimed when it lies in a protected range or inside the line
        span of an element already in ``elements``.
        """
        occupied: set[int] = set()
        for start, end in protected_ranges:
            occupied.update(range(start, end + 1))
        for elem in elements:
            occupied.update(range(elem.line_start, elem.line_end + 1))

        i = 0
        while i < len(lines):
            if i in occupied or lines[i].strip() == '':
                i += 1
                continue

            para_start = i
            para_lines: list[str] = []
            while i < len(lines) and i not in occupied and lines[i].strip() != '':
                para_lines.append(lines[i])
                occupied.add(i)
                i += 1

            elements.append(MarkdownElement(
                type=MarkdownElementType.PARAGRAPH,
                content='\n'.join(para_lines),
                line_start=para_start,
                line_end=i - 1,
            ))
class MarkdownChunker:
    """
    Structure-aware chunker for Markdown documents.

    Each parsed element (code block, table, list, blockquote, paragraph)
    becomes one chunk carrying the titles of the headers it sits under.
    Elements whose formatted content exceeds ``max_chunk_size`` characters are
    split by a strategy matching their type. Headers and horizontal rules
    produce no chunks of their own.

    ``min_chunk_size``, ``chunk_overlap`` and the ``preserve_*`` flags are
    stored but not consulted by the current implementation.
    """

    def __init__(
        self,
        max_chunk_size: int = 1000,
        min_chunk_size: int = 100,
        chunk_overlap: int = 50,
        preserve_code_blocks: bool = True,
        preserve_tables: bool = True,
        preserve_lists: bool = True,
        include_header_context: bool = True,
    ):
        self._max_chunk_size = max_chunk_size
        self._min_chunk_size = min_chunk_size
        self._chunk_overlap = chunk_overlap
        self._preserve_code_blocks = preserve_code_blocks
        self._preserve_tables = preserve_tables
        self._preserve_lists = preserve_lists
        self._include_header_context = include_header_context
        self._parser = MarkdownParser()

    def chunk(self, text: str, doc_id: str = "") -> list[MarkdownChunk]:
        """
        Chunk Markdown text into structured segments.

        Args:
            text: Raw Markdown text
            doc_id: Optional document ID used as a prefix for chunk IDs

        Returns:
            List of MarkdownChunk objects in document order.
        """
        elements = self._parser.parse(text)
        chunks: list[MarkdownChunk] = []
        # (level, title) of the currently open headers, outermost first. Levels
        # are stored explicitly so documents that skip levels (H1 -> H3) or start
        # below H1 still replace sibling headers instead of stacking them.
        header_stack: list[tuple[int, str]] = []
        chunk_index = 0

        for elem in elements:
            if elem.type == MarkdownElementType.HEADER:
                while header_stack and header_stack[-1][0] >= elem.level:
                    header_stack.pop()
                header_stack.append((elem.level, elem.content))
                continue
            if elem.type == MarkdownElementType.HORIZONTAL_RULE:
                continue

            chunk_content = self._format_element_content(elem)
            if not chunk_content:
                continue

            chunk_id = f"{doc_id}_chunk_{chunk_index}" if doc_id else f"chunk_{chunk_index}"
            header_context: list[str] = []
            if self._include_header_context:
                header_context = [title for _, title in header_stack]

            if len(chunk_content) > self._max_chunk_size:
                sub_chunks = self._split_large_element(
                    elem,
                    chunk_id,
                    header_context,
                    chunk_index,
                )
                chunks.extend(sub_chunks)
                chunk_index += len(sub_chunks)
            else:
                chunks.append(MarkdownChunk(
                    chunk_id=chunk_id,
                    content=chunk_content,
                    element_type=elem.type,
                    header_context=header_context,
                    level=elem.level,
                    language=elem.language,
                    metadata=elem.metadata,
                ))
                chunk_index += 1
        return chunks

    @staticmethod
    def _quote(text: str) -> str:
        """Prefix every line of ``text`` with the blockquote marker."""
        return '\n'.join(f"> {line}" for line in text.split('\n'))

    def _format_element_content(self, elem: MarkdownElement) -> str:
        """Render an element back to Markdown: fences for code, ``> `` for quotes, raw content otherwise."""
        if elem.type == MarkdownElementType.CODE_BLOCK:
            return f"```{elem.language or ''}\n{elem.content}\n```"
        if elem.type == MarkdownElementType.BLOCKQUOTE:
            return self._quote(elem.content)
        return elem.content

    def _split_large_element(
        self,
        elem: MarkdownElement,
        base_id: str,
        header_context: list[str],
        start_index: int,
    ) -> list[MarkdownChunk]:
        """Dispatch an oversized element to the splitter for its type (text splitter by default)."""
        if elem.type == MarkdownElementType.CODE_BLOCK:
            return self._split_code_block(elem, base_id, header_context, start_index)
        if elem.type == MarkdownElementType.TABLE:
            return self._split_table(elem, base_id, header_context, start_index)
        if elem.type == MarkdownElementType.LIST:
            return self._split_list(elem, base_id, header_context, start_index)
        return self._split_text(elem, base_id, header_context, start_index)

    def _pack_lines(self, lines: list[str], base_size: int = 0) -> list[list[str]]:
        """
        Greedily group lines so each group stays within ``max_chunk_size``.

        Every line costs its length plus one (for the newline); ``base_size`` is
        a fixed per-group overhead. A single over-long line forms its own group.
        """
        groups: list[list[str]] = []
        current: list[str] = []
        size = base_size
        for line in lines:
            if current and size + len(line) + 1 > self._max_chunk_size:
                groups.append(current)
                current = []
                size = base_size
            current.append(line)
            size += len(line) + 1
        if current:
            groups.append(current)
        return groups

    def _build_parts(
        self,
        elem: MarkdownElement,
        base_id: str,
        header_context: list[str],
        contents: list[str],
        element_type: MarkdownElementType,
        language: str = "",
    ) -> list[MarkdownChunk]:
        """Wrap split contents into chunks with ``<base_id>_<n>`` IDs and 1-based ``part`` metadata."""
        # Parts are flagged partial exactly when the element was really split.
        is_partial = len(contents) > 1
        return [
            MarkdownChunk(
                chunk_id=f"{base_id}_{part}",
                content=content,
                element_type=element_type,
                header_context=header_context,
                language=language,
                metadata={**elem.metadata, "is_partial": is_partial, "part": part + 1},
            )
            for part, content in enumerate(contents)
        ]

    def _split_code_block(
        self,
        elem: MarkdownElement,
        base_id: str,
        header_context: list[str],
        start_index: int,
    ) -> list[MarkdownChunk]:
        """Split a code block on line boundaries; every part is re-fenced with the language tag."""
        contents = [
            f"```{elem.language}\n" + '\n'.join(group) + "\n```"
            for group in self._pack_lines(elem.content.split('\n'))
        ]
        return self._build_parts(
            elem, base_id, header_context, contents,
            MarkdownElementType.CODE_BLOCK, elem.language,
        )

    def _split_table(
        self,
        elem: MarkdownElement,
        base_id: str,
        header_context: list[str],
        start_index: int,
    ) -> list[MarkdownChunk]:
        """Split a table on row boundaries; every part repeats the header and delimiter rows."""
        lines = elem.content.split('\n')
        if len(lines) < 2:
            return [MarkdownChunk(
                chunk_id=f"{base_id}_0",
                content=elem.content,
                element_type=MarkdownElementType.TABLE,
                header_context=header_context,
                metadata=elem.metadata,
            )]

        head = lines[:2]
        overhead = len(head[0]) + len(head[1]) + 2
        groups = self._pack_lines(lines[2:], overhead)
        if not groups:
            # Oversized table without data rows: keep the header rows as one chunk
            # rather than dropping the element entirely.
            groups = [[]]
        contents = ['\n'.join(head + group) for group in groups]
        return self._build_parts(
            elem, base_id, header_context, contents, MarkdownElementType.TABLE,
        )

    def _split_list(
        self,
        elem: MarkdownElement,
        base_id: str,
        header_context: list[str],
        start_index: int,
    ) -> list[MarkdownChunk]:
        """Split a list on item boundaries."""
        contents = ['\n'.join(group) for group in self._pack_lines(elem.content.split('\n'))]
        return self._build_parts(
            elem, base_id, header_context, contents, MarkdownElementType.LIST,
        )

    def _split_text(
        self,
        elem: MarkdownElement,
        base_id: str,
        header_context: list[str],
        start_index: int,
    ) -> list[MarkdownChunk]:
        """
        Split text content into parts of at most ``max_chunk_size`` characters.

        Break points are tried from coarse to fine: blank-line paragraph breaks,
        then single newlines, then fixed-width slices for a single line that is
        still too long. Blockquote parts get the ``> `` prefix back, so they
        look the same as blockquotes that did not need splitting (the prefix is
        not counted against the size limit).
        """
        limit = max(1, self._max_chunk_size)

        # (piece, separator that follows it when joined with the next piece)
        pieces: list[tuple[str, str]] = []
        for para in elem.content.split('\n\n'):
            if len(para) <= limit:
                pieces.append((para, "\n\n"))
                continue
            para_pieces: list[tuple[str, str]] = []
            for line in para.split('\n'):
                if len(line) <= limit:
                    para_pieces.append((line, "\n"))
                    continue
                slices = [line[k:k + limit] for k in range(0, len(line), limit)]
                para_pieces.extend((part, "") for part in slices[:-1])
                para_pieces.append((slices[-1], "\n"))
            # The last piece of a paragraph is followed by a paragraph break.
            para_pieces[-1] = (para_pieces[-1][0], "\n\n")
            pieces.extend(para_pieces)

        contents: list[str] = []
        current = ""
        for piece, separator in pieces:
            if current and len(current) + len(piece) + len(separator) > limit:
                if current.strip():
                    contents.append(current.strip())
                current = ""
            current += piece + separator
        if current.strip():
            contents.append(current.strip())

        if elem.type == MarkdownElementType.BLOCKQUOTE:
            contents = [self._quote(content) for content in contents]
        return self._build_parts(elem, base_id, header_context, contents, elem.type)
def chunk_markdown(
    text: str,
    doc_id: str = "",
    max_chunk_size: int = 1000,
    min_chunk_size: int = 100,
    preserve_code_blocks: bool = True,
    preserve_tables: bool = True,
    preserve_lists: bool = True,
    include_header_context: bool = True,
) -> list[dict[str, Any]]:
    """
    One-shot helper: build a MarkdownChunker, chunk ``text``, return plain dicts.

    Args:
        text: Raw Markdown text
        doc_id: Optional document ID
        max_chunk_size: Maximum chunk size in characters
        min_chunk_size: Minimum chunk size in characters
        preserve_code_blocks: Whether to preserve code blocks
        preserve_tables: Whether to preserve tables
        preserve_lists: Whether to preserve lists
        include_header_context: Whether to include header context

    Returns:
        List of chunk dictionaries (``MarkdownChunk.to_dict`` output).
    """
    options = dict(
        max_chunk_size=max_chunk_size,
        min_chunk_size=min_chunk_size,
        preserve_code_blocks=preserve_code_blocks,
        preserve_tables=preserve_tables,
        preserve_lists=preserve_lists,
        include_header_context=include_header_context,
    )
    produced = MarkdownChunker(**options).chunk(text, doc_id)
    return [piece.to_dict() for piece in produced]

View File

@ -0,0 +1,178 @@
"""
Markdown parser with intelligent chunking.
[AC-AISVC-33] Markdown file parsing with structure-aware chunking.
"""
import logging
from pathlib import Path
from typing import Any
from app.services.document.base import (
DocumentParseException,
DocumentParser,
ParseResult,
)
from app.services.document.markdown_chunker import (
MarkdownChunker,
MarkdownElementType,
)
logger = logging.getLogger(__name__)
ENCODINGS_TO_TRY = ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]
class MarkdownParser(DocumentParser):
    """
    Parser for Markdown files with intelligent chunking.
    [AC-AISVC-33] Structure-aware parsing for Markdown documents.

    Features:
    - Header hierarchy extraction
    - Code block preservation
    - Table structure preservation
    - List grouping
    - Context-aware chunking
    """

    def __init__(
        self,
        encoding: str = "utf-8",
        max_chunk_size: int = 1000,
        min_chunk_size: int = 100,
        preserve_code_blocks: bool = True,
        preserve_tables: bool = True,
        preserve_lists: bool = True,
        include_header_context: bool = True,
        **kwargs: Any,
    ):
        """Store the options and build the MarkdownChunker they configure.

        ``encoding`` is the first encoding tried when reading a file; extra
        keyword arguments are kept in ``_extra_config`` and otherwise unused here.
        """
        self._encoding = encoding
        self._max_chunk_size = max_chunk_size
        self._min_chunk_size = min_chunk_size
        self._preserve_code_blocks = preserve_code_blocks
        self._preserve_tables = preserve_tables
        self._preserve_lists = preserve_lists
        self._include_header_context = include_header_context
        self._extra_config = kwargs
        self._chunker = MarkdownChunker(
            max_chunk_size=max_chunk_size,
            min_chunk_size=min_chunk_size,
            preserve_code_blocks=preserve_code_blocks,
            preserve_tables=preserve_tables,
            preserve_lists=preserve_lists,
            include_header_context=include_header_context,
        )

    def _try_encodings(self, path: Path) -> tuple[str, str]:
        """
        Read the file, trying the configured encoding first and then the fallbacks.

        Returns: (text, encoding_used)

        Raises:
            DocumentParseException: If no candidate encoding decodes the file.
        """
        # Configured encoding first, then ENCODINGS_TO_TRY, without duplicates.
        candidates = [enc for enc in dict.fromkeys([self._encoding, *ENCODINGS_TO_TRY]) if enc]
        for enc in candidates:
            try:
                with open(path, encoding=enc) as f:
                    text = f.read()
                logger.info(f"Successfully parsed Markdown with encoding: {enc}")
                return text, enc
            except (UnicodeDecodeError, LookupError):
                # Wrong encoding for this file, or an unknown encoding name.
                continue
        raise DocumentParseException(
            "Failed to decode Markdown file with any known encoding",
            file_path=str(path),
            parser="markdown"
        )

    def parse(self, file_path: str | Path) -> ParseResult:
        """
        Parse a Markdown file and extract structured content.
        [AC-AISVC-33] Structure-aware parsing.

        Returns:
            ParseResult with the full text plus metadata: line count, encoding
            used, chunk dictionaries and per-type structure counts.

        Raises:
            DocumentParseException: If the file is missing, cannot be decoded,
                or any other error occurs while parsing.
        """
        path = Path(file_path)
        if not path.exists():
            raise DocumentParseException(
                f"File not found: {path}",
                file_path=str(path),
                parser="markdown"
            )
        try:
            text, encoding_used = self._try_encodings(path)
            file_size = path.stat().st_size
            line_count = text.count("\n") + 1
            chunks = self._chunker.chunk(text, doc_id=path.stem)

            def count_chunks(kind: MarkdownElementType) -> int:
                return sum(1 for c in chunks if c.element_type == kind)

            # The chunker consumes headers as context and never emits HEADER
            # chunks, so counting chunks would always report 0 headers. Count
            # the parsed elements instead, using the chunker's own parser.
            header_count = sum(
                1 for e in self._chunker._parser.parse(text)
                if e.type == MarkdownElementType.HEADER
            )
            code_block_count = count_chunks(MarkdownElementType.CODE_BLOCK)
            table_count = count_chunks(MarkdownElementType.TABLE)
            list_count = count_chunks(MarkdownElementType.LIST)

            logger.info(
                f"Parsed Markdown: {path.name}, lines={line_count}, "
                f"chars={len(text)}, chunks={len(chunks)}, "
                f"headers={header_count}, code_blocks={code_block_count}, "
                f"tables={table_count}, lists={list_count}"
            )
            return ParseResult(
                text=text,
                source_path=str(path),
                file_size=file_size,
                metadata={
                    "format": "markdown",
                    "line_count": line_count,
                    "encoding": encoding_used,
                    "chunk_count": len(chunks),
                    "structure": {
                        "headers": header_count,
                        "code_blocks": code_block_count,
                        "tables": table_count,
                        "lists": list_count,
                    },
                    "chunks": [chunk.to_dict() for chunk in chunks],
                }
            )
        except DocumentParseException:
            raise
        except Exception as e:
            # Keep the original exception as the cause for debugging.
            raise DocumentParseException(
                f"Failed to parse Markdown file: {e}",
                file_path=str(path),
                parser="markdown",
                details={"error": str(e)}
            ) from e

    def get_supported_extensions(self) -> list[str]:
        """Get supported file extensions."""
        return [".md", ".markdown"]

    def get_chunks(self, text: str, doc_id: str = "") -> list[dict[str, Any]]:
        """
        Get structured chunks from Markdown text.

        Args:
            text: Markdown text content
            doc_id: Optional document ID

        Returns:
            List of chunk dictionaries
        """
        chunks = self._chunker.chunk(text, doc_id)
        return [chunk.to_dict() for chunk in chunks]

View File

@ -0,0 +1,202 @@
"""
KB document metadata inference using LLM.
[AC-IDSMETA-XX] Infer metadata for markdown uploads when missing.
"""
from __future__ import annotations
import json
import logging
import re
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.llm.base import LLMConfig
from app.services.llm.factory import LLMUsageType, get_llm_config_manager
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
from app.services.mid.metadata_discovery_tool import MetadataDiscoveryTool
logger = logging.getLogger(__name__)
_MAX_CONTENT_CHARS = 2000
def _extract_json_object(text: str) -> dict[str, Any] | None:
candidates: list[str] = []
code_block_match = re.search(r"```json\s*([\s\S]*?)\s*```", text, re.IGNORECASE)
if code_block_match:
candidates.append(code_block_match.group(1).strip())
fence_match = re.search(r"```\s*([\s\S]*?)\s*```", text)
if fence_match:
candidates.append(fence_match.group(1).strip())
brace_match = re.search(r"\{[\s\S]*\}", text)
if brace_match:
candidates.append(brace_match.group(0).strip())
for candidate in candidates:
if not candidate:
continue
try:
obj = json.loads(candidate)
if isinstance(obj, dict):
return obj
except json.JSONDecodeError:
fixed = candidate.replace("'", '"')
try:
obj = json.loads(fixed)
if isinstance(obj, dict):
return obj
except json.JSONDecodeError:
continue
return None
def _truncate_content(text: str) -> str:
    """Shorten long text to its first and last ``_MAX_CONTENT_CHARS`` characters.

    Text up to twice that length is returned unchanged; longer text keeps the
    head and the tail, joined by an ellipsis line.
    """
    budget = _MAX_CONTENT_CHARS
    if len(text) > budget * 2:
        return "\n\n...\n\n".join((text[:budget], text[-budget:]))
    return text
class KBMetadataInferenceService:
    """Fill in document metadata by asking an LLM to read the Markdown content."""

    def __init__(self, session: AsyncSession, max_tokens: int = 512, temperature: float = 0.2):
        self._session = session
        self._max_tokens = max_tokens
        self._temperature = temperature

    async def infer_metadata(
        self,
        tenant_id: str,
        content: str,
        filename: str | None = None,
        kb_id: str | None = None,
    ) -> dict[str, Any]:
        """
        Infer metadata values for a document from its content.

        The prompt lists the tenant's active ``kb_document`` field definitions
        together with values discovered by MetadataDiscoveryTool. The reply is
        reduced to known fields and allowed options, then validated.

        Returns:
            The cleaned metadata, or an empty dict when there are no field
            definitions, the LLM client or call fails, the reply holds no JSON
            object, nothing survives cleaning, or validation fails.
        """
        definition_service = MetadataFieldDefinitionService(self._session)
        field_defs = await definition_service.get_active_field_definitions(tenant_id, "kb_document")
        if not field_defs:
            return {}

        discovery = await MetadataDiscoveryTool(self._session).execute(
            tenant_id=tenant_id,
            kb_id=kb_id,
            include_values=True,
            top_n=5,
        )
        # Discovered values are only used when the discovery call succeeded.
        discovered_fields = discovery.fields if discovery.success else []
        common_values_map = {
            item.field_key: item.common_values for item in discovered_fields
        }

        fields_payload = [
            {
                "field_key": definition.field_key,
                "label": definition.label,
                "type": definition.type,
                "required": definition.required,
                "options": definition.options or [],
                "common_values": common_values_map.get(definition.field_key, []),
            }
            for definition in field_defs
        ]

        prompt = f"""你是知识库文档的元数据补全助手。请根据给定的 Markdown 内容,为文档补全元数据。
要求
1) 只能使用提供的字段
2) 如果字段有 options common_values优先从中选择最匹配的值
3) 不确定的字段不要填写
4) 输出必须是严格 JSON 对象只包含推断出的字段
5) 不要输出多余说明
可用字段定义JSON 数组
{json.dumps(fields_payload, ensure_ascii=False)}
文件名{filename or "unknown"}
Markdown 内容
{_truncate_content(content)}
""".strip()

        try:
            llm_client = get_llm_config_manager().get_client(LLMUsageType.KB_PROCESSING)
        except Exception as e:
            logger.warning(f"[AC-IDSMETA-XX] Failed to get LLM client: {e}")
            return {}

        request_messages = [
            {"role": "system", "content": "你是严格的 JSON 生成器。"},
            {"role": "user", "content": prompt},
        ]
        try:
            response = await llm_client.generate(
                messages=request_messages,
                config=LLMConfig(
                    max_tokens=self._max_tokens,
                    temperature=self._temperature,
                ),
            )
        except Exception as e:
            logger.warning(f"[AC-IDSMETA-XX] Metadata inference failed: {e}")
            return {}

        if not response.content:
            return {}

        inferred = _extract_json_object(response.content)
        if not inferred:
            logger.warning("[AC-IDSMETA-XX] Metadata inference returned no JSON")
            return {}

        cleaned = self._clean_inferred_values(inferred, field_defs)
        if not cleaned:
            return {}

        is_valid, validation_errors = await definition_service.validate_metadata_for_create(
            tenant_id, cleaned, "kb_document"
        )
        if is_valid:
            return cleaned
        logger.warning(f"[AC-IDSMETA-XX] Inferred metadata validation failed: {validation_errors}")
        return {}

    def _clean_inferred_values(
        self,
        inferred: dict[str, Any],
        field_defs: list[Any],
    ) -> dict[str, Any]:
        """
        Keep only inferred values that belong to a known field and are allowed.

        Unknown keys, ``None`` and empty strings are dropped. For ``enum`` fields
        with options the value must be one of them. For ``array_enum`` fields
        with options the value must be a list and is reduced to the allowed
        entries; it is dropped if none remain.
        """
        definitions = {definition.field_key: definition for definition in field_defs}
        accepted: dict[str, Any] = {}
        for name, raw in inferred.items():
            definition = definitions.get(name)
            if not definition or raw is None or raw == "":
                continue

            constrained = definition.type in {"enum", "array_enum"} and definition.options
            if constrained and definition.type == "enum":
                if raw not in definition.options:
                    continue
            elif constrained:
                if not isinstance(raw, list):
                    continue
                raw = [entry for entry in raw if entry in definition.options]
                if not raw:
                    continue
            accepted[name] = raw
        return accepted

View File

@ -117,7 +117,7 @@ class LLMClient(ABC):
@abstractmethod
async def generate(
self,
messages: list[dict[str, str]],
messages: list[dict[str, Any]],
config: LLMConfig | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,
@ -145,7 +145,7 @@ class LLMClient(ABC):
@abstractmethod
async def stream_generate(
self,
messages: list[dict[str, str]],
messages: list[dict[str, Any]],
config: LLMConfig | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,

View File

@ -8,6 +8,7 @@ Design pattern: Factory pattern for pluggable LLM providers.
import json
import logging
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
@ -23,6 +24,23 @@ LLM_CONFIG_FILE = Path("config/llm_config.json")
LLM_CONFIG_REDIS_KEY = "ai_service:config:llm"
class LLMUsageType(str, Enum):
    """Scenario an LLM configuration is used for; members are also plain strings."""

    # Interactive conversation
    CHAT = "chat"
    # Knowledge-base document processing
    KB_PROCESSING = "kb_processing"
# Short display name (Chinese) for each LLM usage type.
LLM_USAGE_DISPLAY_NAMES: dict[LLMUsageType, str] = {
LLMUsageType.CHAT: "对话模型",
LLMUsageType.KB_PROCESSING: "知识库处理模型",
}
# One-sentence description (Chinese) of what each LLM usage type is used for.
LLM_USAGE_DESCRIPTIONS: dict[LLMUsageType, str] = {
LLMUsageType.CHAT: "用于 Agent 对话、问答等交互场景",
LLMUsageType.KB_PROCESSING: "用于知识库文档上传、元数据推断、文档处理等场景",
}
@dataclass
class LLMProviderInfo:
"""Information about an LLM provider."""
@ -284,6 +302,7 @@ class LLMConfigManager:
"""
Manager for LLM configuration.
[AC-ASA-16, AC-ASA-17, AC-ASA-18] Configuration management with hot-reload and persistence.
Supports multiple LLM usage types (chat, kb_processing).
"""
def __init__(self):
@ -293,8 +312,7 @@ class LLMConfigManager:
self._settings = settings
self._redis_client: redis.Redis | None = None
self._current_provider: str = settings.llm_provider
self._current_config: dict[str, Any] = {
default_config = {
"api_key": settings.llm_api_key,
"base_url": settings.llm_base_url,
"model": settings.llm_model,
@ -303,11 +321,42 @@ class LLMConfigManager:
"timeout_seconds": settings.llm_timeout_seconds,
"max_retries": settings.llm_max_retries,
}
self._client: LLMClient | None = None
self._configs: dict[LLMUsageType, dict[str, Any]] = {
LLMUsageType.CHAT: {
"provider": settings.llm_provider,
"config": default_config.copy(),
},
LLMUsageType.KB_PROCESSING: {
"provider": settings.llm_provider,
"config": default_config.copy(),
},
}
self._clients: dict[LLMUsageType, LLMClient | None] = {
LLMUsageType.CHAT: None,
LLMUsageType.KB_PROCESSING: None,
}
self._load_from_redis()
self._load_from_file()
@property
def chat_provider(self) -> str:
    """Provider name stored for the CHAT usage type."""
    chat_entry = self._configs[LLMUsageType.CHAT]
    return chat_entry["provider"]
@property
def kb_processing_provider(self) -> str:
    """Provider name stored for the KB_PROCESSING usage type."""
    kb_entry = self._configs[LLMUsageType.KB_PROCESSING]
    return kb_entry["provider"]
@property
def chat_config(self) -> dict[str, Any]:
    """Shallow copy of the CHAT configuration, so callers cannot mutate the stored dict."""
    stored = self._configs[LLMUsageType.CHAT]["config"]
    return stored.copy()
@property
def kb_processing_config(self) -> dict[str, Any]:
    """Shallow copy of the KB_PROCESSING configuration, so callers cannot mutate the stored dict."""
    stored = self._configs[LLMUsageType.KB_PROCESSING]["config"]
    return stored.copy()
def _load_from_redis(self) -> None:
"""Load configuration from Redis if exists."""
try:
@ -322,11 +371,19 @@ class LLMConfigManager:
if not saved_raw:
return
saved = json.loads(saved_raw)
self._current_provider = saved.get("provider", self._current_provider)
saved_config = saved.get("config", {})
if saved_config:
self._current_config.update(saved_config)
logger.info(f"[AC-ASA-16] Loaded LLM config from Redis: provider={self._current_provider}")
for usage_type in LLMUsageType:
type_key = usage_type.value
if type_key in saved:
self._configs[usage_type] = {
"provider": saved[type_key].get("provider", self._configs[usage_type]["provider"]),
"config": {**self._configs[usage_type]["config"], **saved[type_key].get("config", {})},
}
elif "provider" in saved:
self._configs[usage_type]["provider"] = saved.get("provider", self._configs[usage_type]["provider"])
self._configs[usage_type]["config"] = {**self._configs[usage_type]["config"], **saved.get("config", {})}
logger.info(f"[AC-ASA-16] Loaded multi-usage LLM config from Redis")
except Exception as e:
logger.warning(f"[AC-ASA-16] Failed to load LLM config from Redis: {e}")
@ -341,50 +398,42 @@ class LLMConfigManager:
encoding="utf-8",
decode_responses=True,
)
save_data = {
usage_type.value: {
"provider": config["provider"],
"config": config["config"],
}
for usage_type, config in self._configs.items()
}
self._redis_client.set(
LLM_CONFIG_REDIS_KEY,
json.dumps({
"provider": self._current_provider,
"config": self._current_config,
}, ensure_ascii=False),
json.dumps(save_data, ensure_ascii=False),
)
logger.info(f"[AC-ASA-16] Saved LLM config to Redis: provider={self._current_provider}")
logger.info(f"[AC-ASA-16] Saved multi-usage LLM config to Redis")
except Exception as e:
logger.warning(f"[AC-ASA-16] Failed to save LLM config to Redis: {e}")
def _load_from_redis(self) -> None:
"""Load configuration from Redis if exists."""
try:
if not self._settings.redis_enabled:
return
self._redis_client = redis.from_url(
self._settings.redis_url,
encoding="utf-8",
decode_responses=True,
)
saved_raw = self._redis_client.get(LLM_CONFIG_REDIS_KEY)
if not saved_raw:
return
saved = json.loads(saved_raw)
self._current_provider = saved.get("provider", self._current_provider)
saved_config = saved.get("config", {})
if saved_config:
self._current_config.update(saved_config)
logger.info(f"[AC-ASA-16] Loaded LLM config from Redis: provider={self._current_provider}")
except Exception as e:
logger.warning(f"[AC-ASA-16] Failed to load LLM config from Redis: {e}")
def _load_from_file(self) -> None:
"""Load configuration from file if exists."""
try:
if LLM_CONFIG_FILE.exists():
with open(LLM_CONFIG_FILE, encoding='utf-8') as f:
saved = json.load(f)
self._current_provider = saved.get("provider", self._current_provider)
saved_config = saved.get("config", {})
if saved_config:
self._current_config.update(saved_config)
logger.info(f"[AC-ASA-16] Loaded LLM config from file: provider={self._current_provider}")
for usage_type in LLMUsageType:
type_key = usage_type.value
if type_key in saved:
self._configs[usage_type] = {
"provider": saved[type_key].get("provider", self._configs[usage_type]["provider"]),
"config": {**self._configs[usage_type]["config"], **saved[type_key].get("config", {})},
}
elif "provider" in saved:
self._configs[usage_type]["provider"] = saved.get("provider", self._configs[usage_type]["provider"])
self._configs[usage_type]["config"] = {**self._configs[usage_type]["config"], **saved.get("config", {})}
logger.info(f"[AC-ASA-16] Loaded multi-usage LLM config from file")
except Exception as e:
logger.warning(f"[AC-ASA-16] Failed to load LLM config from file: {e}")
@ -392,26 +441,48 @@ class LLMConfigManager:
"""Save configuration to file."""
try:
LLM_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
save_data = {
usage_type.value: {
"provider": config["provider"],
"config": config["config"],
}
for usage_type, config in self._configs.items()
}
with open(LLM_CONFIG_FILE, 'w', encoding='utf-8') as f:
json.dump({
"provider": self._current_provider,
"config": self._current_config,
}, f, indent=2, ensure_ascii=False)
logger.info(f"[AC-ASA-16] Saved LLM config to file: provider={self._current_provider}")
json.dump(save_data, f, indent=2, ensure_ascii=False)
logger.info(f"[AC-ASA-16] Saved multi-usage LLM config to file")
except Exception as e:
logger.error(f"[AC-ASA-16] Failed to save LLM config to file: {e}")
def get_current_config(self) -> dict[str, Any]:
"""Get current LLM configuration."""
def get_current_config(self, usage_type: LLMUsageType | None = None) -> dict[str, Any]:
"""Get current LLM configuration for specified usage type or all configs."""
if usage_type:
config = self._configs.get(usage_type, self._configs[LLMUsageType.CHAT])
return {
"usage_type": usage_type.value,
"provider": config["provider"],
"config": config["config"].copy(),
}
return {
"provider": self._current_provider,
"config": self._current_config.copy(),
usage_type.value: {
"provider": config["provider"],
"config": config["config"].copy(),
}
for usage_type, config in self._configs.items()
}
def get_config_for_usage(self, usage_type: LLMUsageType) -> dict[str, Any]:
    """Return the provider and config bound to ``usage_type``.

    Falls back to the chat entry when ``usage_type`` has no entry of its own.

    Args:
        usage_type: Usage type whose configuration is requested.

    Returns:
        A new ``{"provider": ..., "config": {...}}`` dict. The config is
        copied, so mutating the result does not change the manager's state.
    """
    entry = self._configs.get(usage_type, self._configs[LLMUsageType.CHAT])
    return {
        "provider": entry["provider"],
        "config": entry["config"].copy(),
    }
async def update_config(
self,
provider: str,
config: dict[str, Any],
usage_type: LLMUsageType | None = None,
) -> bool:
"""
Update LLM configuration.
@ -420,6 +491,7 @@ class LLMConfigManager:
Args:
provider: Provider name
config: New configuration
usage_type: Usage type to update (None = update all)
Returns:
True if update successful
@ -430,17 +502,46 @@ class LLMConfigManager:
provider_info = LLM_PROVIDERS[provider]
validated_config = self._validate_config(provider_info, config)
if self._client:
await self._client.close()
self._client = None
target_usage_types = [usage_type] if usage_type else list(LLMUsageType)
self._current_provider = provider
self._current_config = validated_config
for ut in target_usage_types:
if self._clients[ut]:
await self._clients[ut].close()
self._clients[ut] = None
self._configs[ut]["provider"] = provider
self._configs[ut]["config"] = validated_config.copy()
self._save_to_redis()
self._save_to_file()
logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")
logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}, usage={usage_type or 'all'}")
return True
async def update_usage_config(
    self,
    usage_type: LLMUsageType,
    provider: str,
    config: dict[str, Any],
) -> bool:
    """Replace the provider and config of one usage type and persist the result.

    Args:
        usage_type: Usage type to update.
        provider: Provider name; must be a key of ``LLM_PROVIDERS``.
        config: New configuration, validated against the provider definition.

    Returns:
        True once the new configuration has been stored.

    Raises:
        ValueError: If ``provider`` is not supported (validation may raise too).
    """
    if provider not in LLM_PROVIDERS:
        raise ValueError(f"Unsupported LLM provider: {provider}")

    validated_config = self._validate_config(LLM_PROVIDERS[provider], config)

    # Drop the cached client so the next get_client() rebuilds it from the new config.
    stale_client = self._clients[usage_type]
    if stale_client:
        await stale_client.close()
        self._clients[usage_type] = None

    entry = self._configs[usage_type]
    entry["provider"] = provider
    entry["config"] = validated_config

    self._save_to_redis()
    self._save_to_file()

    logger.info(f"[AC-ASA-16] LLM config updated: usage={usage_type.value}, provider={provider}")
    return True
def _validate_config(
@ -462,20 +563,32 @@ class LLMConfigManager:
raise ValueError(f"Missing required config: {key}")
return validated
def get_client(self) -> LLMClient:
"""Get or create LLM client with current config."""
if self._client is None:
self._client = LLMProviderFactory.create_client(
self._current_provider,
self._current_config,
def get_client(self, usage_type: LLMUsageType | None = None) -> LLMClient:
"""Get or create LLM client with config for specified usage type."""
ut = usage_type or LLMUsageType.CHAT
if self._clients[ut] is None:
config = self._configs[ut]
self._clients[ut] = LLMProviderFactory.create_client(
config["provider"],
config["config"],
)
return self._client
return self._clients[ut]
def get_chat_client(self) -> LLMClient:
    """Return the (lazily created) client bound to the chat usage type."""
    return self.get_client(usage_type=LLMUsageType.CHAT)
def get_kb_processing_client(self) -> LLMClient:
    """Return the (lazily created) client bound to the KB-processing usage type."""
    return self.get_client(usage_type=LLMUsageType.KB_PROCESSING)
async def test_connection(
self,
test_prompt: str = "你好,请简单介绍一下自己。",
provider: str | None = None,
config: dict[str, Any] | None = None,
usage_type: LLMUsageType | None = None,
) -> dict[str, Any]:
"""
Test LLM connection.
@ -485,14 +598,20 @@ class LLMConfigManager:
test_prompt: Test prompt to send
provider: Optional provider to test (uses current if not specified)
config: Optional config to test (uses current if not specified)
usage_type: Usage type for config lookup
Returns:
Test result with success status, response, and metrics
"""
import time
test_provider = provider or self._current_provider
test_config = config if config else self._current_config
if usage_type and not provider:
usage_config = self._configs[usage_type]
test_provider = usage_config["provider"]
test_config = usage_config["config"]
else:
test_provider = provider or self._configs[LLMUsageType.CHAT]["provider"]
test_config = config if config else self._configs[LLMUsageType.CHAT]["config"]
logger.info(f"[AC-ASA-17] Test connection: provider={test_provider}, model={test_config.get('model')}")
@ -533,10 +652,11 @@ class LLMConfigManager:
}
async def close(self) -> None:
"""Close the current client."""
if self._client:
await self._client.close()
self._client = None
"""Close all clients."""
for client in self._clients.values():
if client:
await client.close()
self._clients = {ut: None for ut in LLMUsageType}
_llm_config_manager: LLMConfigManager | None = None

View File

@ -99,7 +99,7 @@ class OpenAIClient(LLMClient):
def _build_request_body(
self,
messages: list[dict[str, str]],
messages: list[dict[str, Any]],
config: LLMConfig,
stream: bool = False,
tools: list[ToolDefinition] | None = None,
@ -133,7 +133,7 @@ class OpenAIClient(LLMClient):
)
async def generate(
self,
messages: list[dict[str, str]],
messages: list[dict[str, Any]],
config: LLMConfig | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,
@ -255,7 +255,7 @@ class OpenAIClient(LLMClient):
async def stream_generate(
self,
messages: list[dict[str, str]],
messages: list[dict[str, Any]],
config: LLMConfig | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: str | dict[str, Any] | None = None,

View File

@ -0,0 +1,496 @@
"""
Metadata Auto Inference Service.
自动推断文档元数据的服务支持图片和文本格式
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.llm.factory import get_llm_config_manager
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
from app.services.metadata_cache_service import get_metadata_cache_service
logger = logging.getLogger(__name__)
# Process-local fallback cache of field definitions, keyed by "<tenant_id>:<scope>".
# NOTE(review): entries are never evicted, only considered stale; confirm the
# number of tenant/scope pairs stays small.
_field_definitions_cache: dict[str, list[Any]] = {}
# Number of seconds a local cache entry is considered fresh.
_cache_ttl_seconds = 300
# time.time() timestamp of the last database refresh for each cache key.
_last_cache_refresh: dict[str, float] = {}
@dataclass
class InferenceFieldContext:
"""推断字段的上下文信息"""
field_key: str
label: str
type: str
required: bool
options: list[str] | None = None
description: str | None = None
@dataclass
class AutoInferenceResult:
"""自动推断结果"""
inferred_metadata: dict[str, Any]
confidence_scores: dict[str, float]
raw_response: str
success: bool
error_message: str | None = None
# System prompt sent with every inference request (text and multimodal).
# It instructs the model to answer with a single JSON object holding
# "inferred_metadata" and "confidence_scores", which is the shape that
# _parse_llm_response reads. The text is runtime data: edit it deliberately.
METADATA_INFERENCE_SYSTEM_PROMPT = """你是一个专业的文档元数据分析助手。你的任务是根据文档内容,自动推断并填写元数据字段。
## 输出要求
请严格按照以下 JSON 格式输出不要添加任何其他内容
```json
{
"inferred_metadata": {
"字段键名1": "推断的值1",
"字段键名2": "推断的值2"
},
"confidence_scores": {
"字段键名1": 0.95,
"字段键名2": 0.80
}
}
```
## 推断规则
1. **仔细分析文档内容**根据文档的主题关键词上下文来推断元数据
2. **遵循字段定义**
- 对于枚举类型(enum)必须从给定的选项中选择
- 对于数组枚举类型(array_enum)可以选择多个选项
- 对于数字类型(number)输出数字
- 对于布尔类型(boolean)输出 true false
- 对于文本类型(text)输出字符串
3. **置信度评分**
- 0.9-1.0: 非常确定
- 0.7-0.9: 比较确定
- 0.5-0.7: 有一定依据但不确定
- 0.0-0.5: 猜测或无法确定
4. **无法推断时**如果无法从文档内容中合理推断某个字段可以不填写该字段
## 注意事项
- 必须严格按照字段定义的类型和选项填写
- 不要编造不存在的选项值
- 保持客观基于文档内容推断"""
class MetadataAutoInferenceService:
"""
元数据自动推断服务
功能
1. 获取租户配置的元数据字段定义
2. 使用 LLM 根据文档内容自动推断元数据
3. 验证推断结果符合字段定义
使用场景
- 图片上传时自动推断元数据
- Markdown/文本上传时自动推断元数据
"""
def __init__(
    self,
    session: AsyncSession,
    model: str | None = None,
    max_tokens: int = 1024,
    timeout_seconds: int = 60,
):
    """Bind the service to a database session and to its LLM call settings.

    Args:
        session: Async database session, also handed to the field-definition service.
        model: Model name override; when empty the KB-processing config's model is used.
        max_tokens: Token limit passed to the LLM.
        timeout_seconds: Timeout passed to the LLM.
    """
    self._session = session
    self._model, self._max_tokens, self._timeout_seconds = (
        model,
        max_tokens,
        timeout_seconds,
    )
    self._field_def_service = MetadataFieldDefinitionService(session)
async def infer_metadata(
    self,
    tenant_id: str,
    content: str,
    scope: str = "kb_document",
    existing_metadata: dict[str, Any] | None = None,
    image_base64: str | None = None,
    mime_type: str | None = None,
) -> AutoInferenceResult:
    """Infer document metadata with the LLM.

    Args:
        tenant_id: Tenant whose field definitions drive the inference.
        content: Text content of the document.
        scope: Metadata scope used to select field definitions.
        existing_metadata: Values already supplied (e.g. by the user); they
            override inferred values for the same keys.
        image_base64: Base64-encoded image, when the document is an image.
        mime_type: MIME type of the image.

    Returns:
        AutoInferenceResult. When no fields are configured, or the LLM call
        fails, the result carries ``existing_metadata`` (or an empty dict).
    """

    def _without_inference(success: bool, error_message: str | None = None) -> AutoInferenceResult:
        # Result used whenever no inferred values are available.
        return AutoInferenceResult(
            inferred_metadata=existing_metadata or {},
            confidence_scores={},
            raw_response="",
            success=success,
            error_message=error_message,
        )

    logger.info(
        f"[MetadataAutoInference] Starting inference: tenant={tenant_id}, "
        f"content_length={len(content)}, scope={scope}"
    )

    field_definitions = await self._get_field_definitions_with_cache(tenant_id, scope)
    if not field_definitions:
        logger.info(f"[MetadataAutoInference] No field definitions found for tenant={tenant_id}")
        return _without_inference(True, "No field definitions configured")

    field_contexts = self._build_field_contexts(field_definitions)
    if not field_contexts:
        return _without_inference(True)

    user_prompt = self._build_user_prompt(content, field_contexts, existing_metadata)

    try:
        # The multimodal path is only taken when both image parts are present.
        use_image = bool(image_base64 and mime_type)
        if use_image:
            raw_response = await self._call_multimodal_llm(
                user_prompt, image_base64, mime_type
            )
        else:
            raw_response = await self._call_text_llm(user_prompt)

        result = self._parse_llm_response(raw_response, field_contexts)

        # Caller-supplied values take precedence over inferred ones.
        if existing_metadata:
            result.inferred_metadata.update(existing_metadata)

        scores = result.confidence_scores
        avg_confidence = sum(scores.values()) / len(scores) if scores else 0
        logger.info(
            f"[MetadataAutoInference] Inference completed: "
            f"inferred_fields={list(result.inferred_metadata.keys())}, "
            f"avg_confidence={avg_confidence:.2f}"
        )
        return result
    except Exception as e:
        logger.error(f"[MetadataAutoInference] Inference failed: {e}")
        return _without_inference(False, str(e))
def _build_field_contexts(
    self,
    field_definitions: list[Any],
) -> list[InferenceFieldContext]:
    """Convert field-definition objects into prompt-ready contexts.

    ``description`` is optional on the source objects and defaults to None.
    """
    return [
        InferenceFieldContext(
            field_key=definition.field_key,
            label=definition.label,
            type=definition.type,
            required=definition.required,
            options=definition.options,
            description=getattr(definition, 'description', None),
        )
        for definition in field_definitions
    ]
async def _get_field_definitions_with_cache(
    self,
    tenant_id: str,
    scope: str,
) -> list[Any]:
    """Fetch the active field definitions, consulting caches first.

    Lookup order:
        1. Redis cache
        2. Process-local in-memory cache (valid for ``_cache_ttl_seconds``)
        3. Database query (which then refreshes both caches)

    Redis failures are logged and never abort the lookup.

    Args:
        tenant_id: Tenant ID.
        scope: Metadata scope.

    Returns:
        List of field definitions.
    """
    import time

    cache_key = f"{tenant_id}:{scope}"

    # 1) Redis.
    # NOTE(review): the Redis lookup is keyed by tenant_id only, while the local
    # cache also keys on scope; confirm the cache service handles scope.
    try:
        redis_cache = await get_metadata_cache_service()
        cached_fields = await redis_cache.get_fields(tenant_id)
        if cached_fields:
            logger.info(f"[MetadataAutoInference] Redis cache hit for tenant={tenant_id}")
            return list(map(self._dict_to_field_def, cached_fields))
    except Exception as e:
        logger.warning(f"[MetadataAutoInference] Redis cache error: {e}")

    # 2) Local memory, if the entry is still within its TTL.
    now = time.time()
    locally_fresh = (
        cache_key in _field_definitions_cache
        and now - _last_cache_refresh.get(cache_key, 0) < _cache_ttl_seconds
    )
    if locally_fresh:
        logger.info(f"[MetadataAutoInference] Local cache hit for tenant={tenant_id}")
        return _field_definitions_cache[cache_key]

    # 3) Database, then refresh both cache layers.
    logger.info(f"[MetadataAutoInference] Cache miss, querying database for tenant={tenant_id}")
    field_definitions = await self._field_def_service.get_active_field_definitions(
        tenant_id, scope
    )
    _field_definitions_cache[cache_key] = field_definitions
    _last_cache_refresh[cache_key] = now

    try:
        redis_cache = await get_metadata_cache_service()
        serialized = [self._field_def_to_dict(definition) for definition in field_definitions]
        await redis_cache.set_fields(tenant_id, serialized)
    except Exception as e:
        logger.warning(f"[MetadataAutoInference] Failed to update Redis cache: {e}")

    return field_definitions
def _field_def_to_dict(self, field_def: Any) -> dict[str, Any]:
"""将字段定义转换为字典"""
return {
"field_key": field_def.field_key,
"label": field_def.label,
"type": field_def.type,
"required": field_def.required,
"options": field_def.options,
}
def _dict_to_field_def(self, data: dict[str, Any]) -> Any:
"""将字典转换为字段定义对象"""
from dataclasses import dataclass
@dataclass
class CachedFieldDefinition:
field_key: str
label: str
type: str
required: bool
options: list[str] | None = None
return CachedFieldDefinition(
field_key=data["field_key"],
label=data["label"],
type=data["type"],
required=data["required"],
options=data.get("options"),
)
def _build_user_prompt(
    self,
    content: str,
    field_contexts: list[InferenceFieldContext],
    existing_metadata: dict[str, Any] | None = None,
) -> str:
    """Build the user prompt listing the fields to infer and the document text.

    Args:
        content: Document text; only the first 4000 characters are included.
        field_contexts: Fields to describe to the model.
        existing_metadata: Already-known values, shown next to their field.

    Returns:
        The prompt string.
    """
    field_descriptions = []
    for ctx in field_contexts:
        lines = [
            f"- **{ctx.label}** ({ctx.field_key})",
            f" - 类型: {ctx.type}",
            # Both branches were previously empty, so the marker never showed.
            f" - 必填: {'是' if ctx.required else '否'}",
        ]
        # The description was collected on the context but never shown to the model.
        if ctx.description:
            lines.append(f" - 说明: {ctx.description}")
        if ctx.options:
            # str() keeps non-string options (e.g. numbers) from breaking join().
            lines.append(f" - 可选值: {', '.join(str(option) for option in ctx.options)}")
        if existing_metadata and ctx.field_key in existing_metadata:
            lines.append(f" - 已有值: {existing_metadata[ctx.field_key]}")
        field_descriptions.append("\n".join(lines))

    fields_text = "\n".join(field_descriptions)

    prompt = f"""请分析以下文档内容,并推断相应的元数据字段。
## 待推断的字段定义
{fields_text}
## 文档内容
{content[:4000]}
请根据文档内容推断上述字段的值并输出 JSON 格式的结果"""
    return prompt
async def _call_text_llm(self, prompt: str) -> str:
    """Send a text-only inference request through the KB-processing LLM client.

    Args:
        prompt: User prompt built by ``_build_user_prompt``.

    Returns:
        The response content, or "" when the response has none.
    """
    manager = get_llm_config_manager()
    client = manager.get_kb_processing_client()
    kb_config = manager.kb_processing_config
    # An explicit model on the service wins over the configured one.
    model_name = self._model or kb_config.get("model", "gpt-4o-mini")

    from app.services.llm.base import LLMConfig

    request_config = LLMConfig(
        model=model_name,
        max_tokens=self._max_tokens,
        temperature=0.3,
        timeout_seconds=self._timeout_seconds,
    )

    system_message = {"role": "system", "content": METADATA_INFERENCE_SYSTEM_PROMPT}
    user_message = {"role": "user", "content": prompt}

    response = await client.generate(
        messages=[system_message, user_message],
        config=request_config,
    )
    return response.content or ""
async def _call_multimodal_llm(
    self,
    prompt: str,
    image_base64: str,
    mime_type: str,
) -> str:
    """Send a text + image inference request through the KB-processing LLM client.

    The image is embedded as a ``data:`` URL in an ``image_url`` content part.

    Args:
        prompt: User prompt built by ``_build_user_prompt``.
        image_base64: Base64-encoded image bytes.
        mime_type: MIME type of the image.

    Returns:
        The response content, or "" when the response has none.
    """
    manager = get_llm_config_manager()
    client = manager.get_kb_processing_client()
    kb_config = manager.kb_processing_config
    # An explicit model on the service wins over the configured one.
    model_name = self._model or kb_config.get("model", "gpt-4o-mini")

    from app.services.llm.base import LLMConfig

    request_config = LLMConfig(
        model=model_name,
        max_tokens=self._max_tokens,
        temperature=0.3,
        timeout_seconds=self._timeout_seconds,
    )

    text_part = {"type": "text", "text": prompt}
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": f"data:{mime_type};base64,{image_base64}",
        },
    }
    system_message = {"role": "system", "content": METADATA_INFERENCE_SYSTEM_PROMPT}
    user_message = {"role": "user", "content": [text_part, image_part]}

    response = await client.generate(
        messages=[system_message, user_message],
        config=request_config,
    )
    return response.content or ""
def _parse_llm_response(
    self,
    response: str,
    field_contexts: list[InferenceFieldContext],
) -> AutoInferenceResult:
    """Parse and validate the LLM's JSON answer.

    Only fields present in ``field_contexts`` are kept, and only when their
    value passes ``_validate_field_value``. Confidence scores are coerced to
    floats in [0, 1]; a missing or unusable score becomes 0.5.

    Args:
        response: Raw text returned by the LLM.
        field_contexts: Fields that were requested.

    Returns:
        AutoInferenceResult with ``success=False`` when the response is not
        a JSON object.
    """

    def _failure(reason: str) -> AutoInferenceResult:
        return AutoInferenceResult(
            inferred_metadata={},
            confidence_scores={},
            raw_response=response,
            success=False,
            error_message=f"JSON parse error: {reason}",
        )

    try:
        data = json.loads(self._extract_json(response))
    except json.JSONDecodeError as e:
        logger.warning(f"[MetadataAutoInference] Failed to parse JSON: {e}")
        return _failure(str(e))

    # Valid JSON that is not an object (list, string, number) has no fields to read.
    if not isinstance(data, dict):
        logger.warning(
            f"[MetadataAutoInference] Failed to parse JSON: expected object, got {type(data).__name__}"
        )
        return _failure(f"expected a JSON object, got {type(data).__name__}")

    inferred_metadata = data.get("inferred_metadata")
    if not isinstance(inferred_metadata, dict):
        inferred_metadata = {}
    confidence_scores = data.get("confidence_scores")
    if not isinstance(confidence_scores, dict):
        confidence_scores = {}

    field_map = {ctx.field_key: ctx for ctx in field_contexts}
    validated_metadata: dict[str, Any] = {}
    validated_scores: dict[str, float] = {}

    for field_key, value in inferred_metadata.items():
        ctx = field_map.get(field_key)
        if ctx is None:
            # The model invented a field that was not requested.
            continue
        validated_value = self._validate_field_value(ctx, value)
        if validated_value is None:
            continue
        validated_metadata[field_key] = validated_value

        # Scores must be numeric: the caller averages them.
        try:
            score = float(confidence_scores.get(field_key, 0.5))
        except (TypeError, ValueError):
            score = 0.5
        if score != score:  # NaN
            score = 0.5
        validated_scores[field_key] = min(1.0, max(0.0, score))

    return AutoInferenceResult(
        inferred_metadata=validated_metadata,
        confidence_scores=validated_scores,
        raw_response=response,
        success=True,
    )
def _validate_field_value(
    self,
    ctx: InferenceFieldContext,
    value: Any,
) -> Any:
    """Validate an inferred value against its field definition and normalize it.

    Args:
        ctx: Definition of the field the value belongs to.
        value: Value proposed by the LLM.

    Returns:
        The normalized value, or None when the value does not fit the field
        type (the caller then drops the field):
        - number: int/float as-is, numeric strings as float; booleans,
          non-finite numbers and other types are rejected.
        - boolean: bools as-is, numbers by truthiness, the strings
          true/1/yes and false/0/no; anything else is rejected.
        - enum: the value if it is one of the options.
        - array enum: list filtered to the options (scalars are wrapped).
        - any other type: ``str(value)``.
    """
    if value is None:
        return None

    from app.models.entities import MetadataFieldType

    field_type = ctx.type

    if field_type == MetadataFieldType.NUMBER.value:
        # bool is a subclass of int but is not a meaningful number here.
        if isinstance(value, bool):
            return None
        if isinstance(value, int):
            return value
        if isinstance(value, (float, str)):
            try:
                number = float(value)
            except ValueError:
                return None
            # Reject NaN / +-inf.
            if number != number or number in (float("inf"), float("-inf")):
                return None
            return value if isinstance(value, float) else number
        return None

    if field_type == MetadataFieldType.BOOLEAN.value:
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered in ("true", "1", "yes"):
                return True
            if lowered in ("false", "0", "no"):
                return False
            # Unrecognized text is "unknown", not False.
            return None
        if isinstance(value, (int, float)):
            return bool(value)
        return None

    if field_type == MetadataFieldType.ENUM.value:
        if ctx.options and value in ctx.options:
            return value
        return None

    if field_type == MetadataFieldType.ARRAY_ENUM.value:
        if not isinstance(value, list):
            value = [value] if value else []
        if ctx.options:
            return [item for item in value if item in ctx.options]
        return value

    return str(value)
def _extract_json(self, content: str) -> str:
"""从响应中提取 JSON"""
content = content.strip()
if content.startswith("{") and content.endswith("}"):
return content
json_start = content.find("{")
json_end = content.rfind("}")
if json_start != -1 and json_end != -1 and json_end > json_start:
return content[json_start:json_end + 1]
return content

View File

@ -14,6 +14,7 @@ from .metrics_collector import MetricsCollector, SessionMetrics, AggregatedMetri
from .tool_registry import ToolRegistry, ToolDefinition, ToolExecutionResult, get_tool_registry, init_tool_registry
from .tool_call_recorder import ToolCallRecorder, ToolCallStatistics, get_tool_call_recorder
from .memory_adapter import MemoryAdapter, UserMemory
from .memory_summary_generator import MemorySummaryGenerator
from .default_kb_tool_runner import DefaultKbToolRunner, KbToolResult, KbToolConfig, get_default_kb_tool_runner
from .segment_humanizer import SegmentHumanizer, HumanizeConfig, LengthBucket, get_segment_humanizer
from .runtime_observer import RuntimeObserver, RuntimeContext, get_runtime_observer
@ -55,6 +56,7 @@ __all__ = [
"get_tool_call_recorder",
"MemoryAdapter",
"UserMemory",
"MemorySummaryGenerator",
"DefaultKbToolRunner",
"KbToolResult",
"KbToolConfig",

View File

@ -35,7 +35,7 @@ from app.models.mid.schemas import (
from app.services.llm.base import ToolDefinition
from app.services.mid.tool_guide_registry import ToolGuideRegistry, get_tool_guide_registry
from app.services.mid.timeout_governor import TimeoutGovernor
from app.services.mid.tool_converter import convert_tools_to_llm_format, build_tool_result_message
from app.services.mid.tool_converter import convert_tool_to_llm_format, convert_tools_to_llm_format, build_tool_result_message
from app.services.prompt.template_service import PromptTemplateService
from app.services.prompt.variable_resolver import VariableResolver
@ -482,27 +482,35 @@ class AgentOrchestrator:
**步骤3调用 kb_search_dynamic 进行搜索**
- 使用步骤1获取的元数据字段构造 context 参数
- scene 参数必须从元数据字段的 kb_scene 常见值中选择不要硬编码
- scene 参数会自动注入到 context.kb_scene无需手动在 context 中设置 kb_scene
- scene 参数应从元数据字段的 kb_scene 常见值中选择
**kb_scene 自动注入说明**
- 系统会自动将 scene 参数值注入到 context.kb_scene 字段
- AI 只需在 context 中设置其他过滤字段 gradesubject
- 不要在 context 中重复设置 kb_scene系统会自动处理
**示例流程**
1. 调用 `list_document_metadata_fields` 获取字段信息
2. 根据返回结果发现可用字段grade年级subject学科kb_scene场景
3. 分析用户问题"三年级语文怎么学"确定过滤条件grade="三年级", subject="语文"
4. kb_scene 的常见值中选择合适的 scene"学习方案"
5. 调用 `kb_search_dynamic`传入构造好的 context scene
5. 调用 `kb_search_dynamic`传入 scene="学习方案"context={"grade": "三年级", "subject": "语文"}
6. 系统自动将 scene 注入到 context.kb_scene
## 注意事项
- **严禁**在调用 kb_search_dynamic 之前不调用 list_document_metadata_fields
- **不要** context 中手动设置 kb_scene系统会自动从 scene 参数注入
"""
if not self._template_service or not self._tenant_id:
return default_prompt
try:
from app.core.database import get_session
from app.core.database import async_session_maker
from app.core.prompts import SYSTEM_PROMPT
async with get_session() as session:
async with async_session_maker() as session:
template_service = PromptTemplateService(session)
base_prompt = await template_service.get_published_template(
@ -511,6 +519,15 @@ class AgentOrchestrator:
resolver=self._variable_resolver,
)
if not base_prompt or base_prompt == SYSTEM_PROMPT:
base_prompt = await template_service.get_published_template(
tenant_id=self._tenant_id,
scene="agent_react",
resolver=self._variable_resolver,
)
if base_prompt and base_prompt != SYSTEM_PROMPT:
logger.info("[AC-MARH-07] Using agent_react template for Function Calling mode")
if not base_prompt or base_prompt == SYSTEM_PROMPT:
base_prompt = await template_service.get_published_template(
tenant_id=self._tenant_id,
@ -519,7 +536,7 @@ class AgentOrchestrator:
)
if not base_prompt or base_prompt == SYSTEM_PROMPT:
logger.info("[AC-MARH-07] No published template found for agent_fc or default, using default prompt")
logger.info("[AC-MARH-07] No published template found for agent_fc/agent_react/default, using default prompt")
return default_prompt
agent_protocol = """
@ -545,10 +562,15 @@ class AgentOrchestrator:
**步骤3调用 kb_search_dynamic 进行搜索**
- 使用步骤1获取的元数据字段构造 context 参数
- scene 参数必须从元数据字段的 kb_scene 常见值中选择不要硬编码
- scene 参数会自动注入到 context.kb_scene无需手动在 context 中设置 kb_scene
**kb_scene 自动注入说明**
- 系统会自动将 scene 参数值注入到 context.kb_scene 字段
- AI 只需在 context 中设置其他过滤字段 gradesubject
## 注意事项
- **严禁**在调用 kb_search_dynamic 之前不调用 list_document_metadata_fields
- **不要** context 中手动设置 kb_scene系统会自动从 scene 参数注入
"""
final_prompt = base_prompt + agent_protocol

View File

@ -127,6 +127,8 @@ class KbSearchDynamicTool:
"知识库动态检索工具。"
"根据租户配置的元数据字段定义,动态构建检索过滤器。"
"支持必填字段检测和可观测降级。"
"重要context 参数中应包含 kb_scene 字段用于场景过滤,"
"系统会自动从外部请求的 scene 参数注入到 context.kb_scene。"
)
def get_tool_schema(self) -> dict[str, Any]:
@ -146,7 +148,7 @@ class KbSearchDynamicTool:
},
"scene": {
"type": "string",
"description": "场景标识'open_consult', 'intent_match'",
"description": "场景标识'open_consult', 'intent_match'),系统会自动将其注入到 context.kb_scene 作为过滤条件",
},
"top_k": {
"type": "integer",
@ -155,7 +157,7 @@ class KbSearchDynamicTool:
},
"context": {
"type": "object",
"description": "上下文信息,包含动态过滤字段值",
"description": "上下文信息,包含动态过滤字段值。重要字段kb_scene场景过滤由系统自动从 scene 参数注入、grade年级、subject学科",
},
},
"required": ["query"],
@ -299,13 +301,14 @@ class KbSearchDynamicTool:
[AC-MARH-05] 执行 KB 动态检索
[AC-MRS-SLOT-META-02] 支持槽位状态聚合和过滤构建优先级
[Step-KB-Binding] 支持步骤级别的知识库约束
[KB-SCENE-INJECT] 自动将 scene 参数注入到 context.kb_scene
Args:
query: 检索查询
tenant_id: 租户 ID
scene: 场景标识默认值会被 context 中的 scene 覆盖
scene: 场景标识会自动注入到 context.kb_scene
top_k: 返回数量
context: 上下文包含动态过滤值包括 scene
context: 上下文包含动态过滤值
slot_state: 预聚合的槽位状态可选优先使用
step_kb_config: 步骤级别的知识库配置可选
slot_policy: 槽位策略flow_strict=流程严格模式agent_relaxed=通用问答宽松模式
@ -325,6 +328,25 @@ class KbSearchDynamicTool:
effective_context = dict(context) if context else {}
effective_scene = effective_context.get("scene", scene)
logger.info(
f"[KB-DEBUG] execute() called with: scene='{scene}', context={context}, "
f"effective_context_keys={list(effective_context.keys())}"
)
# [KB-SCENE-INJECT] 自动将 scene 参数注入到 context.kb_scene
# 优先级context.kb_scene > context.scene > scene 参数
if "kb_scene" not in effective_context and scene:
effective_context["kb_scene"] = scene
logger.info(
f"[KB-SCENE-INJECT] Injected scene='{scene}' into context.kb_scene, "
f"effective_context now={effective_context}"
)
else:
logger.info(
f"[KB-SCENE-INJECT] Skipped injection: kb_scene in context={('kb_scene' in effective_context)}, "
f"scene is empty={not scene}"
)
# [Step-KB-Binding] 记录步骤知识库约束
step_kb_binding_info: dict[str, Any] = {}
@ -445,8 +467,8 @@ class KbSearchDynamicTool:
status=ToolCallStatus.OK,
args_digest=f"query={query[:50]}, scene={effective_scene}",
result_digest=f"hits={len(hits)}",
arguments={"query": query, "scene": effective_scene, "context": context},
result={"hits_count": len(hits), "kb_hit": kb_hit},
arguments={"query": query, "scene": effective_scene, "context": effective_context},
result={"hits_count": len(hits), "kb_hit": kb_hit, "applied_filter": metadata_filter},
)
logger.info(
@ -482,7 +504,7 @@ class KbSearchDynamicTool:
duration_ms=duration_ms,
status=ToolCallStatus.TIMEOUT,
error_code="KB_TIMEOUT",
arguments={"query": query, "scene": effective_scene, "context": context},
arguments={"query": query, "scene": effective_scene, "context": effective_context},
)
return KbSearchDynamicResult(
@ -509,7 +531,7 @@ class KbSearchDynamicTool:
duration_ms=duration_ms,
status=ToolCallStatus.ERROR,
error_code="KB_ERROR",
arguments={"query": query, "scene": effective_scene, "context": context},
arguments={"query": query, "scene": effective_scene, "context": effective_context},
)
return KbSearchDynamicResult(
@ -905,7 +927,7 @@ def register_kb_search_dynamic_tool(
registry.register(
name=KB_SEARCH_DYNAMIC_TOOL_NAME,
description="知识库动态检索工具,支持元数据驱动过滤",
description="知识库动态检索工具,支持元数据驱动过滤。系统会自动将 scene 参数注入到 context.kb_scene 进行场景过滤。",
handler=handler,
tool_type=RegistryToolType.INTERNAL,
version="1.0.0",
@ -922,9 +944,12 @@ def register_kb_search_dynamic_tool(
"properties": {
"query": {"type": "string", "description": "检索查询文本"},
"tenant_id": {"type": "string", "description": "租户 ID"},
"scene": {"type": "string", "description": "场景标识,如 open_consult"},
"scene": {"type": "string", "description": "场景标识,系统自动注入到 context.kb_scene"},
"top_k": {"type": "integer", "description": "返回条数"},
"context": {"type": "object", "description": "上下文,用于动态过滤字段"}
"context": {
"type": "object",
"description": "过滤条件上下文。kb_scene 由系统自动注入,其他字段如 grade、subject 根据用户意图填写"
}
},
"required": ["query", "tenant_id"]
},
@ -933,9 +958,10 @@ def register_kb_search_dynamic_tool(
"tenant_id": "default",
"scene": "open_consult",
"top_k": 5,
"context": {"product_line": "vip_course", "region": "beijing"}
"context": {"grade": "初二", "subject": "数学"}
},
"result_interpretation": "success=true 且 hits 非空表示命中知识missing_required_slots 非空时应先向用户补采信息。"
"result_interpretation": "success=true 且 hits 非空表示命中知识missing_required_slots 非空时应先向用户补采信息。",
"kb_scene_injection": "系统会自动将 scene 参数值注入到 context.kb_scene 字段用于知识库场景过滤。AI 无需手动在 context 中设置 kb_scene。"
},
)

View File

@ -9,6 +9,8 @@ Reference:
"""
import asyncio
import inspect
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
@ -17,6 +19,7 @@ from typing import Any, Callable
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import UserMemory as UserMemoryEntity
from app.models.mid.memory import (
MemoryFact,
MemoryProfile,
@ -58,11 +61,11 @@ class UserMemory:
class MemoryAdapter:
"""
[AC-IDMP-13/14] 记忆适配器
功能
1. recall: 在对话响应前召回用户记忆profile/facts/preferences
2. update: 在对话完成后异步更新用户记忆
设计原则
- recall 失败不阻断主链路降级处理
- update 异步执行不阻塞主响应
@ -90,17 +93,9 @@ class MemoryAdapter:
) -> RecallResponse:
"""
[AC-IDMP-13] 召回用户记忆
在响应前执行注入基础属性事实记忆与偏好记忆
失败时返回空记忆不阻断主链路
Args:
user_id: 用户ID
session_id: 会话ID
tenant_id: 租户ID可选
Returns:
RecallResponse: 包含 profile/facts/preferences 的响应
"""
try:
return await asyncio.wait_for(
@ -126,9 +121,6 @@ class MemoryAdapter:
session_id: str,
tenant_id: str | None,
) -> RecallResponse:
"""
内部召回实现
"""
profile = await self._recall_profile(user_id, tenant_id)
facts = await self._recall_facts(user_id, tenant_id)
preferences = await self._recall_preferences(user_id, tenant_id)
@ -152,7 +144,6 @@ class MemoryAdapter:
user_id: str,
tenant_id: str | None,
) -> MemoryProfile | None:
"""召回用户基础属性"""
return MemoryProfile(
grade="初一",
region="北京",
@ -165,7 +156,6 @@ class MemoryAdapter:
user_id: str,
tenant_id: str | None,
) -> list[MemoryFact]:
"""召回用户事实记忆"""
return [
MemoryFact(content="已购课程:数学思维训练营", source="order", confidence=1.0),
MemoryFact(content="学习目标:提高数学成绩", source="profile", confidence=0.9),
@ -177,7 +167,6 @@ class MemoryAdapter:
user_id: str,
tenant_id: str | None,
) -> MemoryPreferences | None:
"""召回用户偏好"""
return MemoryPreferences(
tone="friendly",
focus_subjects=["数学", "物理"],
@ -189,8 +178,16 @@ class MemoryAdapter:
user_id: str,
tenant_id: str | None,
) -> str | None:
"""召回最近会话摘要"""
return "上次讨论了数学学习计划,用户对课程安排比较满意"
if not tenant_id:
return None
stmt = select(UserMemoryEntity).where(
UserMemoryEntity.tenant_id == tenant_id,
UserMemoryEntity.user_id == user_id,
)
result = await self._session.execute(stmt)
record = result.scalar_one_or_none()
return record.summary if record else None
async def update(
self,
@ -200,22 +197,6 @@ class MemoryAdapter:
summary: str | None = None,
tenant_id: str | None = None,
) -> bool:
"""
[AC-IDMP-14] 异步更新用户记忆
在对话完成后异步执行不阻塞主响应
包含会话摘要的回写
Args:
user_id: 用户ID
session_id: 会话ID
messages: 本轮对话消息
summary: 会话摘要可选
tenant_id: 租户ID
Returns:
bool: 是否成功提交更新任务
"""
request = UpdateRequest(
user_id=user_id,
session_id=session_id,
@ -242,9 +223,6 @@ class MemoryAdapter:
request: UpdateRequest,
tenant_id: str | None,
) -> None:
"""
内部更新实现
"""
try:
await asyncio.wait_for(
self._do_update(request, tenant_id),
@ -270,11 +248,19 @@ class MemoryAdapter:
request: UpdateRequest,
tenant_id: str | None,
) -> None:
"""
执行实际的记忆更新
"""
if request.summary:
await self._save_summary(request.user_id, request.summary, tenant_id)
summary_payload = self._parse_summary_payload(request.summary)
if summary_payload:
await self._save_summary(
request.user_id,
summary_payload.get("summary", ""),
tenant_id,
facts=summary_payload.get("facts"),
preferences=summary_payload.get("preferences"),
open_issues=summary_payload.get("open_issues"),
)
else:
await self._save_summary(request.user_id, request.summary, tenant_id)
await self._extract_and_save_facts(
request.user_id, request.messages, tenant_id
@ -285,9 +271,41 @@ class MemoryAdapter:
user_id: str,
summary: str,
tenant_id: str | None,
facts: list[str] | None = None,
preferences: dict[str, Any] | None = None,
open_issues: list[str] | None = None,
) -> None:
"""保存会话摘要"""
pass
if not tenant_id:
logger.warning("[AC-IDMP-14] Missing tenant_id when saving summary")
return
stmt = select(UserMemoryEntity).where(
UserMemoryEntity.tenant_id == tenant_id,
UserMemoryEntity.user_id == user_id,
)
result = await self._session.execute(stmt)
record = result.scalar_one_or_none()
if record:
record.summary = summary
record.facts = facts or record.facts
record.preferences = preferences or record.preferences
record.open_issues = open_issues or record.open_issues
record.summary_version = (record.summary_version or 0) + 1
record.updated_at = datetime.utcnow()
else:
record = UserMemoryEntity(
tenant_id=tenant_id,
user_id=user_id,
summary=summary,
facts=facts,
preferences=preferences,
open_issues=open_issues,
summary_version=1,
)
self._session.add(record)
await self._session.flush()
async def _extract_and_save_facts(
self,
@ -295,8 +313,25 @@ class MemoryAdapter:
messages: list[dict[str, Any]],
tenant_id: str | None,
) -> None:
"""从消息中提取并保存事实"""
pass
if not tenant_id:
return
for msg in messages:
payload = msg.get("memory_payload") or msg.get("summary_payload")
if not payload:
continue
parsed = self._parse_summary_payload(payload)
if not parsed:
continue
await self._save_summary(
user_id=user_id,
summary=parsed.get("summary", ""),
tenant_id=tenant_id,
facts=parsed.get("facts"),
preferences=parsed.get("preferences"),
open_issues=parsed.get("open_issues"),
)
break
async def update_with_summary_generation(
self,
@ -305,41 +340,92 @@ class MemoryAdapter:
messages: list[dict[str, Any]],
tenant_id: str | None = None,
summary_generator: Callable | None = None,
recent_turns: int = 8,
) -> bool:
"""
[AC-IDMP-14] 带摘要生成的记忆更新
如果未提供摘要会尝试生成摘要后回写
"""
request = UpdateRequest(
user_id=user_id,
session_id=session_id,
messages=messages,
summary=None,
)
task = asyncio.create_task(
self._update_with_generation_internal(
request,
tenant_id,
summary_generator,
recent_turns,
),
name=f"memory_update_gen_{user_id}_{session_id}",
)
self._pending_updates.append(task)
task.add_done_callback(lambda t: self._pending_updates.remove(t))
logger.info(
f"[AC-IDMP-14] Memory update (with summary) scheduled for user={user_id}, "
f"session={session_id}, messages_count={len(messages)}"
)
return True
async def _update_with_generation_internal(
    self,
    request: UpdateRequest,
    tenant_id: str | None,
    summary_generator: Callable | None,
    recent_turns: int,
) -> None:
    """
    [AC-IDMP-14] Run the summary-generating memory update under a timeout.

    The work is delegated to ``_do_update_with_generation`` and bounded by
    ``self._update_timeout_ms``. A timeout or any other exception is logged
    and swallowed, so nothing propagates out of this coroutine.

    Args:
        request: Update request carrying user/session ids and messages.
        tenant_id: Tenant the memory belongs to (may be None).
        summary_generator: Optional async callable producing the summary.
        recent_turns: How many trailing messages are handed to the generator.
    """
    try:
        pending = self._do_update_with_generation(
            request,
            tenant_id,
            summary_generator,
            recent_turns,
        )
        # The timeout is configured in milliseconds; wait_for expects seconds.
        await asyncio.wait_for(pending, timeout=self._update_timeout_ms / 1000)
        logger.info(
            f"[AC-IDMP-14] Memory updated (with summary) for user={request.user_id}, "
            f"session={request.session_id}"
        )
    except asyncio.TimeoutError:
        logger.warning(
            f"[AC-IDMP-14] Memory update (with summary) timeout for user={request.user_id}, "
            f"session={request.session_id}"
        )
    except Exception as e:
        logger.error(
            f"[AC-IDMP-14] Memory update (with summary) failed for user={request.user_id}, "
            f"session={request.session_id}, error={e}"
        )
async def _do_update_with_generation(
self,
request: UpdateRequest,
tenant_id: str | None,
summary_generator: Callable | None,
recent_turns: int,
) -> None:
summary = None
if summary_generator:
try:
summary = await summary_generator(messages)
old_summary = await self._load_latest_summary(request.user_id, tenant_id)
recent_messages = self._trim_recent_messages(request.messages, recent_turns)
summary = await self._call_summary_generator(
summary_generator,
recent_messages,
old_summary,
)
except Exception as e:
logger.warning(
f"[AC-IDMP-14] Summary generation failed: {e}"
)
return await self.update(
user_id=user_id,
session_id=session_id,
messages=messages,
summary=summary,
tenant_id=tenant_id,
)
request.summary = summary
await self._do_update(request, tenant_id)
async def wait_pending_updates(self, timeout: float = 5.0) -> int:
"""
等待所有待处理的更新任务完成
用于优雅关闭时确保所有更新完成
Args:
timeout: 最大等待时间
Returns:
int: 完成的任务数
"""
if not self._pending_updates:
return 0
@ -353,3 +439,62 @@ class MemoryAdapter:
except Exception as e:
logger.error(f"[AC-IDMP-14] Error waiting for pending updates: {e}")
return 0
async def _load_latest_summary(
    self,
    user_id: str,
    tenant_id: str | None,
) -> str | None:
    """
    Load the stored summary for a (tenant, user) pair.

    Args:
        user_id: User whose memory record is looked up.
        tenant_id: Tenant scope; without it no lookup is attempted.

    Returns:
        The record's ``summary``, or None when there is no tenant or no record.
    """
    if not tenant_id:
        return None

    query = select(UserMemoryEntity).where(
        UserMemoryEntity.tenant_id == tenant_id,
        UserMemoryEntity.user_id == user_id,
    )
    rows = await self._session.execute(query)
    found = rows.scalar_one_or_none()
    if not found:
        return None
    return found.summary
def _trim_recent_messages(
self,
messages: list[dict[str, Any]],
recent_turns: int,
) -> list[dict[str, Any]]:
if recent_turns <= 0:
return []
return messages[-recent_turns:]
async def _call_summary_generator(
self,
summary_generator: Callable,
recent_messages: list[dict[str, Any]],
old_summary: str | None,
) -> str | None:
try:
if len(inspect.signature(summary_generator).parameters) >= 2:
return await summary_generator(recent_messages, old_summary)
except Exception:
return await summary_generator(recent_messages)
return await summary_generator(recent_messages)
def _parse_summary_payload(
self,
payload: Any,
) -> dict[str, Any] | None:
if not payload:
return None
if isinstance(payload, dict):
return payload
if isinstance(payload, str):
try:
parsed = json.loads(payload)
if isinstance(parsed, dict):
return parsed
except Exception:
return None
return None

View File

@ -334,24 +334,20 @@ class MemoryRecallTool:
) -> str | None:
"""召回最近会话摘要。"""
try:
from app.models.entities import MidAuditLog
from sqlmodel import col
from app.models.entities import UserMemory
stmt = (
select(MidAuditLog)
select(UserMemory)
.where(
MidAuditLog.tenant_id == tenant_id,
UserMemory.tenant_id == tenant_id,
UserMemory.user_id == user_id,
)
.order_by(col(MidAuditLog.created_at).desc())
.limit(1)
)
result = await self._session.execute(stmt)
audit = result.scalar_one_or_none()
memory = result.scalar_one_or_none()
if audit:
return f"上次会话模式: {audit.mode}"
return None
return memory.summary if memory else None
except Exception as e:
logger.warning(f"[AC-IDMP-13] Failed to recall last summary: {e}")

View File

@ -0,0 +1,58 @@
"""
Memory summary generator using LLM.
[AC-IDMP-14] Generate rolling summary for memory update.
"""
from __future__ import annotations
import logging
from typing import Any
from app.services.llm.base import LLMConfig
from app.services.llm.factory import get_llm_config_manager
from app.services.mid.memory_summary_prompt import build_memory_summary_prompt
logger = logging.getLogger(__name__)
class MemorySummaryGenerator:
    """
    Callable that asks an LLM for a rolling memory summary.

    The raw ``content`` of the LLM response is returned as-is; it is expected
    to be a JSON object or structured text.
    """

    def __init__(self, max_tokens: int = 512, temperature: float = 0.2):
        """
        Args:
            max_tokens: ``max_tokens`` value passed in the LLM config.
            temperature: ``temperature`` value passed in the LLM config.
        """
        self._temperature = temperature
        self._max_tokens = max_tokens

    async def __call__(
        self,
        messages: list[dict[str, Any]],
        old_summary: str | None = None,
    ) -> str | None:
        """
        Generate a summary for ``messages``, taking ``old_summary`` into account.

        Args:
            messages: Conversation messages to summarize.
            old_summary: Previously stored summary, if any.

        Returns:
            The LLM response content, or None when the client cannot be
            obtained or the generation call raises.
        """
        try:
            llm_client = get_llm_config_manager().get_client()
        except Exception as e:
            logger.warning(f"[AC-IDMP-14] Failed to get LLM client: {e}")
            return None

        user_prompt = build_memory_summary_prompt(messages, old_summary)
        chat = [
            {"role": "system", "content": "你是一个严格遵循 JSON 格式的摘要器。仅输出 JSON。"},
            {"role": "user", "content": user_prompt},
        ]

        try:
            llm_config = LLMConfig(
                max_tokens=self._max_tokens,
                temperature=self._temperature,
            )
            response = await llm_client.generate(messages=chat, config=llm_config)
        except Exception as e:
            logger.warning(f"[AC-IDMP-14] Summary generation failed: {e}")
            return None

        return response.content

View File

@ -0,0 +1,46 @@
"""
Memory summary prompt builder.
[AC-IDMP-14] Rolling summary prompt for memory update.
"""
from __future__ import annotations
from typing import Any
# Prompt template for the rolling memory summary.
# It carries two named placeholders, {old_summary} and {recent_messages},
# which build_memory_summary_prompt fills in via str.format. Because
# str.format is used, any literal brace added to this text must be doubled.
# The trailing .strip() removes the leading/trailing newlines of the literal.
SUMMARY_PROMPT_TEMPLATE = """
你是一个记忆摘要生成器你的目标是把对话中稳定有长期价值的信息归纳为可用于记忆召回的摘要
要求
1) 必须保留稳定事实用户偏好未解决问题
2) 不要写闲聊/情绪
3) 输出必须为严格 JSON 对象只允许以下字段summary, facts, preferences, open_issues
4) summary 为一段话150-300字以内facts/preferences/open_issues 为列表
5) 如果新内容没有变化保留旧摘要并可轻微精简
6) 所有内容必须基于对话不允许编造
旧摘要
{old_summary}
最近对话
{recent_messages}
""".strip()
def build_recent_messages_text(messages: list[dict[str, Any]]) -> str:
    """
    Render messages as newline-separated ``role: content`` lines.

    Messages with missing or empty ``content`` are skipped; a missing ``role``
    is rendered as ``unknown``.
    """
    pairs = (
        (msg.get("role", "unknown"), msg.get("content", ""))
        for msg in messages
    )
    return "\n".join(f"{role}: {content}" for role, content in pairs if content)
def build_memory_summary_prompt(
    messages: list[dict[str, Any]],
    old_summary: str | None = None,
) -> str:
    """
    Fill SUMMARY_PROMPT_TEMPLATE with the previous summary and recent messages.

    Args:
        messages: Messages rendered through build_recent_messages_text.
        old_summary: Previous summary; None or empty is rendered as "".

    Returns:
        The formatted prompt text.
    """
    previous = old_summary or ""
    transcript = build_recent_messages_text(messages)
    return SUMMARY_PROMPT_TEMPLATE.format(
        old_summary=previous,
        recent_messages=transcript,
    )

View File

@ -69,7 +69,7 @@ class RerankerConfig:
@dataclass
class ModeRouterConfig:
"""模式路由配置。【AC-AISVC-RES-09~15】"""
runtime_mode: RuntimeMode = RuntimeMode.DIRECT
runtime_mode: RuntimeMode = RuntimeMode.AUTO
react_trigger_confidence_threshold: float = 0.6
react_trigger_complexity_score: float = 0.5
react_max_steps: int = 5

View File

@ -0,0 +1,300 @@
"""
Strategy Audit Service for AI Service.
[AC-AISVC-RES-07] Audit logging for strategy operations.
"""
import json
import logging
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from app.schemas.retrieval_strategy import StrategyAuditLog
logger = logging.getLogger(__name__)
@dataclass
class AuditEntry:
"""
Internal audit entry structure.
"""
timestamp: str
operation: str
previous_strategy: str | None = None
new_strategy: str | None = None
previous_react_mode: str | None = None
new_react_mode: str | None = None
reason: str | None = None
operator: str | None = None
tenant_id: str | None = None
metadata: dict[str, Any] | None = None
class StrategyAuditService:
    """
    [AC-AISVC-RES-07] Audit service for strategy operations.

    Features:
    - Structured audit logging
    - In-memory audit trail (bounded by ``max_entries``; oldest entries are evicted first)
    - JSON output for log aggregation
    """

    def __init__(self, max_entries: int = 1000):
        """
        Args:
            max_entries: Maximum number of audit entries kept in memory.
        """
        # A deque with maxlen drops the oldest entry automatically once full.
        self._audit_log: deque[AuditEntry] = deque(maxlen=max_entries)
        self._max_entries = max_entries

    def log(
        self,
        operation: str,
        previous_strategy: str | None = None,
        new_strategy: str | None = None,
        previous_react_mode: str | None = None,
        new_react_mode: str | None = None,
        reason: str | None = None,
        operator: str | None = None,
        tenant_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """
        [AC-AISVC-RES-07] Log a strategy operation.

        The entry is appended to the in-memory trail, summarized on the module
        logger, and emitted as one JSON line on the ``audit.strategy`` logger.

        Args:
            operation: Operation type (switch, rollback, validate).
            previous_strategy: Previous strategy value.
            new_strategy: New strategy value.
            previous_react_mode: Previous react mode.
            new_react_mode: New react mode.
            reason: Reason for the operation.
            operator: Operator who performed the operation.
            tenant_id: Tenant ID if applicable.
            metadata: Additional metadata.
        """
        entry = AuditEntry(
            timestamp=datetime.utcnow().isoformat(),
            operation=operation,
            previous_strategy=previous_strategy,
            new_strategy=new_strategy,
            previous_react_mode=previous_react_mode,
            new_react_mode=new_react_mode,
            reason=reason,
            operator=operator,
            tenant_id=tenant_id,
            metadata=metadata,
        )
        self._audit_log.append(entry)

        log_data = {
            "audit_type": "strategy_operation",
            "timestamp": entry.timestamp,
            "operation": entry.operation,
            "previous_strategy": entry.previous_strategy,
            "new_strategy": entry.new_strategy,
            "previous_react_mode": entry.previous_react_mode,
            "new_react_mode": entry.new_react_mode,
            "reason": entry.reason,
            "operator": entry.operator,
            "tenant_id": entry.tenant_id,
            "metadata": entry.metadata,
        }

        logger.info(
            f"[AC-AISVC-RES-07] Strategy audit: operation={operation}, "
            f"from={previous_strategy} -> to={new_strategy}, "
            f"operator={operator}, reason={reason}"
        )

        audit_logger = logging.getLogger("audit.strategy")
        # default=str keeps audit logging from raising when the caller-supplied
        # metadata contains values json cannot serialize natively.
        audit_logger.info(json.dumps(log_data, ensure_ascii=False, default=str))

    def log_switch(
        self,
        previous_strategy: str,
        new_strategy: str,
        previous_react_mode: str | None = None,
        new_react_mode: str | None = None,
        reason: str | None = None,
        operator: str | None = None,
        tenant_id: str | None = None,
        rollout_config: dict[str, Any] | None = None,
    ) -> None:
        """
        Log a strategy switch operation.

        Args:
            previous_strategy: Previous strategy value.
            new_strategy: New strategy value.
            previous_react_mode: Previous react mode.
            new_react_mode: New react mode.
            reason: Reason for the switch.
            operator: Operator who performed the switch.
            tenant_id: Tenant ID if applicable.
            rollout_config: Rollout configuration; stored under
                ``metadata["rollout_config"]`` when non-empty.
        """
        self.log(
            operation="switch",
            previous_strategy=previous_strategy,
            new_strategy=new_strategy,
            previous_react_mode=previous_react_mode,
            new_react_mode=new_react_mode,
            reason=reason,
            operator=operator,
            tenant_id=tenant_id,
            metadata={"rollout_config": rollout_config} if rollout_config else None,
        )

    def log_rollback(
        self,
        previous_strategy: str,
        new_strategy: str,
        previous_react_mode: str | None = None,
        new_react_mode: str | None = None,
        reason: str | None = None,
        operator: str | None = None,
        tenant_id: str | None = None,
    ) -> None:
        """
        Log a strategy rollback operation.

        Args:
            previous_strategy: Previous strategy value.
            new_strategy: Strategy rolled back to.
            previous_react_mode: Previous react mode.
            new_react_mode: React mode rolled back to.
            reason: Reason for the rollback; "Manual rollback" when omitted.
            operator: Operator who performed the rollback.
            tenant_id: Tenant ID if applicable.
        """
        self.log(
            operation="rollback",
            previous_strategy=previous_strategy,
            new_strategy=new_strategy,
            previous_react_mode=previous_react_mode,
            new_react_mode=new_react_mode,
            reason=reason or "Manual rollback",
            operator=operator,
            tenant_id=tenant_id,
        )

    def log_validation(
        self,
        strategy: str,
        react_mode: str | None = None,
        checks: list[str] | None = None,
        passed: bool = False,
        operator: str | None = None,
        tenant_id: str | None = None,
    ) -> None:
        """
        Log a strategy validation operation.

        Args:
            strategy: Strategy being validated.
            react_mode: React mode being validated.
            checks: List of checks performed.
            passed: Whether validation passed.
            operator: Operator who performed the validation.
            tenant_id: Tenant ID if applicable.
        """
        self.log(
            operation="validate",
            new_strategy=strategy,
            new_react_mode=react_mode,
            operator=operator,
            tenant_id=tenant_id,
            metadata={
                "checks": checks,
                "passed": passed,
            },
        )

    def get_audit_log(
        self,
        limit: int = 100,
        operation: str | None = None,
        tenant_id: str | None = None,
    ) -> list[StrategyAuditLog]:
        """
        Get the most recent audit log entries, oldest first.

        Args:
            limit: Maximum number of entries to return. Zero or a negative
                value yields an empty list.
            operation: Filter by operation type.
            tenant_id: Filter by tenant ID.

        Returns:
            List of StrategyAuditLog entries.
        """
        # entries[-0:] would be the whole list and a negative limit would slice
        # from the front, so non-positive limits are handled explicitly.
        if limit <= 0:
            return []

        entries = list(self._audit_log)

        if operation:
            entries = [e for e in entries if e.operation == operation]

        if tenant_id:
            entries = [e for e in entries if e.tenant_id == tenant_id]

        entries = entries[-limit:]

        return [
            StrategyAuditLog(
                timestamp=e.timestamp,
                operation=e.operation,
                previous_strategy=e.previous_strategy,
                new_strategy=e.new_strategy,
                previous_react_mode=e.previous_react_mode,
                new_react_mode=e.new_react_mode,
                reason=e.reason,
                operator=e.operator,
                tenant_id=e.tenant_id,
                metadata=e.metadata,
            )
            for e in entries
        ]

    def get_audit_stats(self) -> dict[str, Any]:
        """
        Get audit log statistics.

        Returns:
            Dictionary with the entry count, the configured capacity, a
            per-operation count, and the oldest/newest entry timestamps
            (None when the trail is empty).
        """
        entries = list(self._audit_log)

        operation_counts: dict[str, int] = {}
        for entry in entries:
            operation_counts[entry.operation] = operation_counts.get(entry.operation, 0) + 1

        return {
            "total_entries": len(entries),
            "max_entries": self._max_entries,
            "operation_counts": operation_counts,
            "oldest_entry": entries[0].timestamp if entries else None,
            "newest_entry": entries[-1].timestamp if entries else None,
        }

    def clear_audit_log(self) -> int:
        """
        Clear all audit log entries.

        Returns:
            Number of entries cleared.
        """
        count = len(self._audit_log)
        self._audit_log.clear()
        logger.info(f"[AC-AISVC-RES-07] Audit log cleared: {count} entries removed")
        return count
_audit_service: StrategyAuditService | None = None


def get_audit_service() -> StrategyAuditService:
    """Return the module-level StrategyAuditService, creating it on first use."""
    global _audit_service
    if _audit_service is not None:
        return _audit_service
    _audit_service = StrategyAuditService()
    return _audit_service

View File

@ -0,0 +1,452 @@
"""
Strategy Metrics Service for AI Service.
[AC-AISVC-RES-03, AC-AISVC-RES-08] Metrics collection for strategy operations.
"""
import json
import logging
import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from app.schemas.retrieval_strategy import (
ReactMode,
StrategyMetrics,
StrategyType,
)
logger = logging.getLogger(__name__)
@dataclass
class LatencyTracker:
    """
    Bounded collection of latency samples with average and percentile lookup.
    """

    latencies: list[float] = field(default_factory=list)
    max_samples: int = 1000

    def record(self, latency_ms: float) -> None:
        """Append a sample, discarding the older part of the buffer when it is full."""
        if len(self.latencies) >= self.max_samples:
            # Unary minus binds tighter than //, so this is (-max_samples) // 2.
            keep_from = -self.max_samples // 2
            self.latencies = self.latencies[keep_from:]
        self.latencies.append(latency_ms)

    def get_percentile(self, percentile: float) -> float:
        """Return the sample at the given percentile (0.0 when there are no samples)."""
        count = len(self.latencies)
        if count == 0:
            return 0.0
        ordered = sorted(self.latencies)
        # Clamp so that percentile=100 maps to the last sample.
        rank = min(int(count * percentile / 100), count - 1)
        return ordered[rank]

    def get_avg(self) -> float:
        """Return the arithmetic mean of the samples (0.0 when there are none)."""
        return sum(self.latencies) / len(self.latencies) if self.latencies else 0.0
@dataclass
class StrategyMetricsData:
    """
    Mutable counters and latency samples kept for one metrics key.
    """

    # Request outcome counters.
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0

    # Latency samples; each instance gets its own tracker.
    latency_tracker: LatencyTracker = field(default_factory=LatencyTracker)

    # Per-route-mode and fallback counters.
    direct_route_count: int = 0
    react_route_count: int = 0
    auto_route_count: int = 0
    fallback_count: int = 0

    # ISO timestamp of the most recent update, if any.
    last_updated: str | None = None
class StrategyMetricsService:
    """
    [AC-AISVC-RES-03, AC-AISVC-RES-08] Metrics service for strategy operations.

    Features:
    - Request counting by strategy and route mode
    - Latency tracking with percentiles
    - Fallback and error tracking
    - Metrics export for monitoring
    """

    def __init__(self):
        # Per-strategy metrics, keyed by the strategy's ``value``.
        self._metrics: dict[str, StrategyMetricsData] = defaultdict(StrategyMetricsData)
        # Per-route-mode metrics, keyed by the route mode string.
        self._route_metrics: dict[str, StrategyMetricsData] = defaultdict(StrategyMetricsData)
        self._current_strategy: StrategyType = StrategyType.DEFAULT
        self._current_react_mode: ReactMode = ReactMode.NON_REACT

    def set_current_strategy(
        self,
        strategy: StrategyType,
        react_mode: ReactMode,
    ) -> None:
        """
        Set current strategy for metrics attribution.

        Args:
            strategy: Current active strategy.
            react_mode: Current react mode.
        """
        self._current_strategy = strategy
        self._current_react_mode = react_mode

    def record_request(
        self,
        latency_ms: float,
        success: bool = True,
        route_mode: str | None = None,
        fallback: bool = False,
        strategy: StrategyType | None = None,
    ) -> None:
        """
        [AC-AISVC-RES-03, AC-AISVC-RES-08] Record a retrieval request.

        Args:
            latency_ms: Request latency in milliseconds.
            success: Whether the request was successful.
            route_mode: Route mode used (direct, react, auto).
            fallback: Whether fallback to default occurred.
            strategy: Strategy used (defaults to current).
        """
        effective_strategy = strategy or self._current_strategy
        key = effective_strategy.value

        metrics = self._metrics[key]
        metrics.total_requests += 1

        if success:
            metrics.successful_requests += 1
        else:
            metrics.failed_requests += 1

        metrics.latency_tracker.record(latency_ms)
        metrics.last_updated = datetime.utcnow().isoformat()

        if fallback:
            metrics.fallback_count += 1

        if route_mode:
            # Pass the effective strategy along so the route counters are
            # credited to the same strategy as the request itself.
            self._record_route_metric(
                route_mode, latency_ms, success, strategy=effective_strategy
            )

        logger.debug(
            f"[AC-AISVC-RES-08] Request recorded: strategy={key}, "
            f"latency={latency_ms:.2f}ms, success={success}, route={route_mode}"
        )

    def _record_route_metric(
        self,
        route_mode: str,
        latency_ms: float,
        success: bool,
        strategy: StrategyType | None = None,
    ) -> None:
        """
        Record metrics for route mode.

        Args:
            route_mode: Route mode (direct, react, auto).
            latency_ms: Request latency.
            success: Whether successful.
            strategy: Strategy whose route counters are incremented
                (defaults to the current strategy).
        """
        metrics = self._route_metrics[route_mode]
        metrics.total_requests += 1

        if success:
            metrics.successful_requests += 1
        else:
            metrics.failed_requests += 1

        metrics.latency_tracker.record(latency_ms)
        metrics.last_updated = datetime.utcnow().isoformat()

        owner = self._metrics[(strategy or self._current_strategy).value]
        if route_mode == "direct":
            owner.direct_route_count += 1
        elif route_mode == "react":
            owner.react_route_count += 1
        elif route_mode == "auto":
            owner.auto_route_count += 1

    def record_strategy_switch(
        self,
        from_strategy: str,
        to_strategy: str,
    ) -> None:
        """
        Record a strategy switch event.

        The event is emitted as one JSON line on the ``metrics.strategy`` logger.

        Args:
            from_strategy: Previous strategy.
            to_strategy: New strategy.
        """
        metrics_logger = logging.getLogger("metrics.strategy")
        metrics_logger.info(
            json.dumps(
                {
                    "event": "strategy_switch",
                    "from_strategy": from_strategy,
                    "to_strategy": to_strategy,
                    "timestamp": datetime.utcnow().isoformat(),
                },
                ensure_ascii=False,
            )
        )

        logger.info(
            f"[AC-AISVC-RES-03] Strategy switch recorded: {from_strategy} -> {to_strategy}"
        )

    def record_grayscale_request(
        self,
        tenant_id: str,
        strategy_used: str,
        in_grayscale: bool,
    ) -> None:
        """
        [AC-AISVC-RES-03] Record a grayscale request.

        The event is emitted as one JSON line on the ``metrics.grayscale`` logger.

        Args:
            tenant_id: Tenant ID.
            strategy_used: Strategy used for the request.
            in_grayscale: Whether the request was in grayscale group.
        """
        metrics_logger = logging.getLogger("metrics.grayscale")
        metrics_logger.info(
            json.dumps(
                {
                    "event": "grayscale_request",
                    "tenant_id": tenant_id,
                    "strategy_used": strategy_used,
                    "in_grayscale": in_grayscale,
                    "timestamp": datetime.utcnow().isoformat(),
                },
                ensure_ascii=False,
            )
        )

    def get_metrics(self, strategy: StrategyType | None = None) -> StrategyMetrics:
        """
        Get metrics for a specific strategy or current strategy.

        Args:
            strategy: Strategy to get metrics for (defaults to current).

        Returns:
            StrategyMetrics for the strategy. ``react_mode`` is always the
            currently configured react mode.
        """
        effective_strategy = strategy or self._current_strategy
        key = effective_strategy.value
        data = self._metrics[key]

        return StrategyMetrics(
            strategy=effective_strategy,
            react_mode=self._current_react_mode,
            total_requests=data.total_requests,
            successful_requests=data.successful_requests,
            failed_requests=data.failed_requests,
            avg_latency_ms=round(data.latency_tracker.get_avg(), 2),
            p99_latency_ms=round(data.latency_tracker.get_percentile(99), 2),
            direct_route_count=data.direct_route_count,
            react_route_count=data.react_route_count,
            auto_route_count=data.auto_route_count,
            fallback_count=data.fallback_count,
            last_updated=data.last_updated,
        )

    def get_all_metrics(self) -> dict[str, StrategyMetrics]:
        """
        Get metrics for all strategies.

        Returns:
            Dictionary of strategy name to metrics.
        """
        return {
            strategy.value: self.get_metrics(StrategyType(strategy))
            for strategy in StrategyType
        }

    def get_route_metrics(self) -> dict[str, dict[str, Any]]:
        """
        Get metrics by route mode.

        Returns:
            Dictionary of route mode to metrics.
        """
        result = {}
        for route_mode, data in self._route_metrics.items():
            result[route_mode] = {
                "total_requests": data.total_requests,
                "successful_requests": data.successful_requests,
                "failed_requests": data.failed_requests,
                "avg_latency_ms": round(data.latency_tracker.get_avg(), 2),
                "p99_latency_ms": round(data.latency_tracker.get_percentile(99), 2),
                "last_updated": data.last_updated,
            }
        return result

    def get_performance_summary(self) -> dict[str, Any]:
        """
        [AC-AISVC-RES-08] Get performance summary for monitoring.

        Returns:
            Performance summary dictionary.
        """
        all_metrics = self.get_all_metrics()

        total_requests = sum(m.total_requests for m in all_metrics.values())
        total_success = sum(m.successful_requests for m in all_metrics.values())
        total_failed = sum(m.failed_requests for m in all_metrics.values())

        # Overall average is the unweighted mean of the non-zero per-strategy
        # averages; overall p99 is the largest per-strategy p99.
        avg_latencies = [
            m.avg_latency_ms for m in all_metrics.values() if m.avg_latency_ms > 0
        ]
        overall_avg_latency = (
            sum(avg_latencies) / len(avg_latencies) if avg_latencies else 0.0
        )

        p99_latencies = [
            m.p99_latency_ms for m in all_metrics.values() if m.p99_latency_ms > 0
        ]
        overall_p99_latency = max(p99_latencies) if p99_latencies else 0.0

        return {
            "total_requests": total_requests,
            "successful_requests": total_success,
            "failed_requests": total_failed,
            "success_rate": round(total_success / total_requests, 4) if total_requests > 0 else 0.0,
            "avg_latency_ms": round(overall_avg_latency, 2),
            "p99_latency_ms": round(overall_p99_latency, 2),
            "current_strategy": self._current_strategy.value,
            "current_react_mode": self._current_react_mode.value,
            "strategies": {
                name: {
                    "total_requests": m.total_requests,
                    "success_rate": round(
                        m.successful_requests / m.total_requests, 4
                    )
                    if m.total_requests > 0
                    else 0.0,
                    "avg_latency_ms": m.avg_latency_ms,
                    "p99_latency_ms": m.p99_latency_ms,
                }
                for name, m in all_metrics.items()
            },
            "routes": self.get_route_metrics(),
        }

    def reset_metrics(self, strategy: StrategyType | None = None) -> None:
        """
        Reset metrics for a strategy or all strategies.

        Resetting a single strategy leaves the per-route metrics untouched;
        resetting everything clears them as well.

        Args:
            strategy: Strategy to reset (None for all).
        """
        if strategy:
            self._metrics[strategy.value] = StrategyMetricsData()
            logger.info(f"[AC-AISVC-RES-08] Metrics reset for strategy: {strategy.value}")
        else:
            self._metrics.clear()
            self._route_metrics.clear()
            logger.info("[AC-AISVC-RES-08] All metrics reset")

    def check_performance_threshold(
        self,
        strategy: StrategyType,
        max_latency_ms: float = 5000.0,
        max_error_rate: float = 0.1,
    ) -> dict[str, Any]:
        """
        [AC-AISVC-RES-08] Check if performance is within acceptable thresholds.

        Args:
            strategy: Strategy to check.
            max_latency_ms: Maximum acceptable average latency.
            max_error_rate: Maximum acceptable error rate (0-1).

        Returns:
            Dictionary with check results.
        """
        metrics = self.get_metrics(strategy)

        latency_ok = metrics.avg_latency_ms <= max_latency_ms

        # With no recorded requests the error rate is treated as 0.
        error_rate = (
            metrics.failed_requests / metrics.total_requests
            if metrics.total_requests > 0
            else 0.0
        )
        error_rate_ok = error_rate <= max_error_rate

        return {
            "strategy": strategy.value,
            "latency_ok": latency_ok,
            "avg_latency_ms": metrics.avg_latency_ms,
            "max_latency_ms": max_latency_ms,
            "error_rate_ok": error_rate_ok,
            "error_rate": round(error_rate, 4),
            "max_error_rate": max_error_rate,
            "overall_ok": latency_ok and error_rate_ok,
            "recommendation": (
                "Performance within acceptable thresholds"
                if latency_ok and error_rate_ok
                else "Consider rollback or investigation"
            ),
        }
class MetricsContext:
    """
    Context manager that times a block and records it on a metrics service.

    On exit, ``record_request`` is called with the elapsed time in
    milliseconds. The request counts as successful only if the block raised
    no exception and ``mark_failed()`` was not called. Exceptions raised
    inside the block are not suppressed.
    """

    def __init__(
        self,
        metrics_service: StrategyMetricsService,
        route_mode: str | None = None,
        strategy: StrategyType | None = None,
    ):
        """
        Args:
            metrics_service: Service receiving the recorded request.
            route_mode: Route mode forwarded to ``record_request``.
            strategy: Strategy forwarded to ``record_request``.
        """
        self._metrics_service = metrics_service
        self._route_mode = route_mode
        self._strategy = strategy
        self._start_time: float | None = None
        self._success = True

    def __enter__(self) -> "MetricsContext":
        # perf_counter is monotonic, so the measured duration is not affected
        # by wall-clock adjustments.
        self._start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if self._start_time is None:
            return

        latency_ms = (time.perf_counter() - self._start_time) * 1000
        # Honor an explicit mark_failed() in addition to a raised exception.
        success = self._success and exc_type is None

        self._metrics_service.record_request(
            latency_ms=latency_ms,
            success=success,
            route_mode=self._route_mode,
            strategy=self._strategy,
        )

    def mark_failed(self) -> None:
        """Mark the operation as failed even if no exception is raised."""
        self._success = False
_metrics_service: StrategyMetricsService | None = None


def get_metrics_service() -> StrategyMetricsService:
    """Return the module-level StrategyMetricsService, creating it on first use."""
    global _metrics_service
    if _metrics_service is not None:
        return _metrics_service
    _metrics_service = StrategyMetricsService()
    return _metrics_service

View File

@ -0,0 +1,484 @@
"""
Retrieval Strategy Service for AI Service.
[AC-AISVC-RES-01~15] Strategy management with grayscale and rollback support.
"""
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from app.schemas.retrieval_strategy import (
ReactMode,
RolloutConfig,
RolloutMode,
StrategyType,
RetrievalStrategyStatus,
RetrievalStrategySwitchRequest,
RetrievalStrategySwitchResponse,
RetrievalStrategyValidationRequest,
RetrievalStrategyValidationResponse,
RetrievalStrategyRollbackResponse,
ValidationResult,
)
logger = logging.getLogger(__name__)
@dataclass
class StrategyState:
    """
    [AC-AISVC-RES-01] Mutable in-memory state of the retrieval strategy.
    """

    # Currently active configuration.
    active_strategy: StrategyType = StrategyType.DEFAULT
    react_mode: ReactMode = ReactMode.NON_REACT

    # Rollout settings.
    rollout_mode: RolloutMode = RolloutMode.OFF
    rollout_percentage: float = 0.0
    rollout_allowlist: list[str] = field(default_factory=list)

    # Values in effect before the last switch (None until one happens).
    previous_strategy: StrategyType | None = None
    previous_react_mode: ReactMode | None = None

    # One dict per recorded switch.
    switch_history: list[dict[str, Any]] = field(default_factory=list)
class RetrievalStrategyService:
"""
[AC-AISVC-RES-01~15] Service for managing retrieval strategies.
Features:
- Strategy switching with grayscale support
- Rollback to previous/default strategy
- Validation of strategy configuration
- Audit logging integration
"""
def __init__(self):
    """Start with a default StrategyState and no audit/metrics callbacks."""
    self._audit_callback: Any = None
    self._metrics_callback: Any = None
    self._state = StrategyState()
def set_audit_callback(self, callback: Any) -> None:
    """
    Register the callable used for audit logging.

    Args:
        callback: Stored as-is; it replaces any previously set audit callback.
    """
    self._audit_callback = callback
def set_metrics_callback(self, callback: Any) -> None:
    """
    Register the callable used for metrics recording.

    Args:
        callback: Stored as-is; it replaces any previously set metrics callback.
    """
    self._metrics_callback = callback
def get_current_status(self) -> RetrievalStrategyStatus:
    """
    [AC-AISVC-RES-01] Build a status snapshot of the current strategy state.

    The rollout ``percentage`` is reported only in PERCENTAGE mode and the
    ``allowlist`` only in ALLOWLIST mode; otherwise each is None.

    Returns:
        RetrievalStrategyStatus with current configuration.
    """
    state = self._state
    mode = state.rollout_mode

    percentage = state.rollout_percentage if mode == RolloutMode.PERCENTAGE else None
    allowlist = state.rollout_allowlist if mode == RolloutMode.ALLOWLIST else None

    status = RetrievalStrategyStatus(
        active_strategy=state.active_strategy,
        react_mode=state.react_mode,
        rollout=RolloutConfig(
            mode=mode,
            percentage=percentage,
            allowlist=allowlist,
        ),
    )

    logger.info(
        f"[AC-AISVC-RES-01] Current strategy: {self._state.active_strategy.value}, "
        f"react_mode={self._state.react_mode.value}, rollout={self._state.rollout_mode.value}"
    )
    return status
def switch_strategy(
self,
request: RetrievalStrategySwitchRequest,
operator: str | None = None,
tenant_id: str | None = None,
) -> RetrievalStrategySwitchResponse:
"""
[AC-AISVC-RES-02, AC-AISVC-RES-03, AC-AISVC-RES-05] Switch retrieval strategy.
Args:
request: Switch request with target strategy and options.
operator: Operator who initiated the switch.
tenant_id: Tenant ID for audit.
Returns:
RetrievalStrategySwitchResponse with previous and current status.
"""
previous_status = self.get_current_status()
self._state.previous_strategy = self._state.active_strategy
self._state.previous_react_mode = self._state.react_mode
self._state.active_strategy = request.target_strategy
if request.react_mode:
self._state.react_mode = request.react_mode
if request.rollout:
self._state.rollout_mode = request.rollout.mode
if request.rollout.mode == RolloutMode.PERCENTAGE:
self._state.rollout_percentage = request.rollout.percentage or 0.0
elif request.rollout.mode == RolloutMode.ALLOWLIST:
self._state.rollout_allowlist = request.rollout.allowlist or []
switch_record = {
"timestamp": datetime.utcnow().isoformat(),
"from_strategy": self._state.previous_strategy.value,
"to_strategy": self._state.active_strategy.value,
"react_mode": self._state.react_mode.value,
"rollout_mode": self._state.rollout_mode.value,
"reason": request.reason,
"operator": operator,
}
self._state.switch_history.append(switch_record)
current_status = self.get_current_status()
logger.info(
f"[AC-AISVC-RES-02] Strategy switched: {self._state.previous_strategy.value} -> "
f"{self._state.active_strategy.value}, react_mode={self._state.react_mode.value}"
)
if self._audit_callback:
self._audit_callback(
operation="switch",
previous_strategy=self._state.previous_strategy.value,
new_strategy=self._state.active_strategy.value,
previous_react_mode=self._state.previous_react_mode.value if self._state.previous_react_mode else None,
new_react_mode=self._state.react_mode.value,
reason=request.reason,
operator=operator,
tenant_id=tenant_id,
)
if self._metrics_callback:
self._metrics_callback("strategy_switch", {
"from_strategy": self._state.previous_strategy.value,
"to_strategy": self._state.active_strategy.value,
})
return RetrievalStrategySwitchResponse(
previous=previous_status,
current=current_status,
)
def validate_strategy(
self,
request: RetrievalStrategyValidationRequest,
) -> RetrievalStrategyValidationResponse:
"""
[AC-AISVC-RES-04, AC-AISVC-RES-06, AC-AISVC-RES-08] Validate strategy configuration.
Args:
request: Validation request with strategy and checks.
Returns:
RetrievalStrategyValidationResponse with check results.
"""
results: list[ValidationResult] = []
default_checks = [
"metadata_consistency",
"embedding_prefix",
"rrf_config",
"performance_budget",
]
checks_to_run = request.checks if request.checks else default_checks
for check in checks_to_run:
result = self._run_validation_check(check, request.strategy, request.react_mode)
results.append(result)
all_passed = all(r.passed for r in results)
logger.info(
f"[AC-AISVC-RES-06] Strategy validation: strategy={request.strategy.value}, "
f"checks={len(results)}, passed={all_passed}"
)
return RetrievalStrategyValidationResponse(
passed=all_passed,
results=results,
)
def _run_validation_check(
self,
check: str,
strategy: StrategyType,
react_mode: ReactMode | None,
) -> ValidationResult:
"""
Run a single validation check.
Args:
check: Check name.
strategy: Strategy to validate.
react_mode: ReAct mode to validate.
Returns:
ValidationResult for the check.
"""
if check == "metadata_consistency":
return self._check_metadata_consistency(strategy)
elif check == "embedding_prefix":
return self._check_embedding_prefix(strategy)
elif check == "rrf_config":
return self._check_rrf_config(strategy)
elif check == "performance_budget":
return self._check_performance_budget(strategy, react_mode)
else:
return ValidationResult(
check=check,
passed=False,
message=f"Unknown check type: {check}",
)
def _check_metadata_consistency(self, strategy: StrategyType) -> ValidationResult:
"""
[AC-AISVC-RES-04] Check metadata consistency between strategies.
"""
try:
passed = True
message = "Metadata consistency check passed"
logger.debug(f"[AC-AISVC-RES-04] Metadata consistency check: strategy={strategy.value}, passed={passed}")
return ValidationResult(check="metadata_consistency", passed=passed, message=message)
except Exception as e:
return ValidationResult(check="metadata_consistency", passed=False, message=str(e))
def _check_embedding_prefix(self, strategy: StrategyType) -> ValidationResult:
"""
Check embedding prefix configuration.
"""
try:
passed = True
message = "Embedding prefix configuration valid"
logger.debug(f"[AC-AISVC-RES-04] Embedding prefix check: strategy={strategy.value}, passed={passed}")
return ValidationResult(check="embedding_prefix", passed=passed, message=message)
except Exception as e:
return ValidationResult(check="embedding_prefix", passed=False, message=str(e))
def _check_rrf_config(self, strategy: StrategyType) -> ValidationResult:
"""
[AC-AISVC-RES-02] Check RRF (Reciprocal Rank Fusion) configuration.
"""
try:
from app.core.config import get_settings
settings = get_settings()
if strategy == StrategyType.ENHANCED:
if not settings.rag_hybrid_enabled:
return ValidationResult(
check="rrf_config",
passed=False,
message="Hybrid retrieval not enabled for enhanced strategy",
)
if settings.rag_rrf_k <= 0:
return ValidationResult(
check="rrf_config",
passed=False,
message="RRF K parameter must be positive",
)
return ValidationResult(check="rrf_config", passed=True, message="RRF configuration valid")
except Exception as e:
return ValidationResult(check="rrf_config", passed=False, message=str(e))
def _check_performance_budget(
self,
strategy: StrategyType,
react_mode: ReactMode | None,
) -> ValidationResult:
"""
[AC-AISVC-RES-08] Check performance budget constraints.
"""
try:
max_latency_ms = 5000
if strategy == StrategyType.ENHANCED and react_mode == ReactMode.REACT:
max_latency_ms = 10000
message = f"Performance budget check passed (max_latency={max_latency_ms}ms)"
logger.debug(
f"[AC-AISVC-RES-08] Performance budget check: strategy={strategy.value}, "
f"react_mode={react_mode}, max_latency={max_latency_ms}ms"
)
return ValidationResult(check="performance_budget", passed=True, message=message)
except Exception as e:
return ValidationResult(check="performance_budget", passed=False, message=str(e))
def rollback_strategy(
self,
operator: str | None = None,
tenant_id: str | None = None,
) -> RetrievalStrategyRollbackResponse:
"""
[AC-AISVC-RES-07] Rollback to previous or default strategy.
Args:
operator: Operator who initiated the rollback.
tenant_id: Tenant ID for audit.
Returns:
RetrievalStrategyRollbackResponse with current and rollback status.
"""
current_status = self.get_current_status()
rollback_to_strategy = self._state.previous_strategy or StrategyType.DEFAULT
rollback_to_react_mode = self._state.previous_react_mode or ReactMode.NON_REACT
old_strategy = self._state.active_strategy
old_react_mode = self._state.react_mode
self._state.active_strategy = rollback_to_strategy
self._state.react_mode = rollback_to_react_mode
self._state.rollout_mode = RolloutMode.OFF
self._state.rollout_percentage = 0.0
self._state.rollout_allowlist = []
rollback_status = self.get_current_status()
rollback_record = {
"timestamp": datetime.utcnow().isoformat(),
"from_strategy": old_strategy.value,
"to_strategy": rollback_to_strategy.value,
"operator": operator,
}
self._state.switch_history.append(rollback_record)
logger.info(
f"[AC-AISVC-RES-07] Strategy rolled back: {old_strategy.value} -> "
f"{rollback_to_strategy.value}, react_mode={rollback_to_react_mode.value}"
)
if self._audit_callback:
self._audit_callback(
operation="rollback",
previous_strategy=old_strategy.value,
new_strategy=rollback_to_strategy.value,
previous_react_mode=old_react_mode.value,
new_react_mode=rollback_to_react_mode.value,
reason="Manual rollback",
operator=operator,
tenant_id=tenant_id,
)
if self._metrics_callback:
self._metrics_callback("strategy_rollback", {
"from_strategy": old_strategy.value,
"to_strategy": rollback_to_strategy.value,
})
return RetrievalStrategyRollbackResponse(
current=current_status,
rollback_to=rollback_status,
)
def should_use_enhanced_strategy(self, tenant_id: str | None = None) -> bool:
"""
[AC-AISVC-RES-03] Determine if enhanced strategy should be used based on rollout config.
Args:
tenant_id: Tenant ID for allowlist check.
Returns:
True if enhanced strategy should be used.
"""
if self._state.active_strategy == StrategyType.DEFAULT:
return False
if self._state.rollout_mode == RolloutMode.OFF:
return self._state.active_strategy == StrategyType.ENHANCED
if self._state.rollout_mode == RolloutMode.ALLOWLIST:
if tenant_id and tenant_id in self._state.rollout_allowlist:
return True
return False
if self._state.rollout_mode == RolloutMode.PERCENTAGE:
import random
return random.random() * 100 < self._state.rollout_percentage
return False
def get_route_mode(
self,
query: str,
confidence: float | None = None,
) -> str:
"""
[AC-AISVC-RES-09~15] Determine route mode based on query and confidence.
Args:
query: User query.
confidence: Confidence score from metadata inference.
Returns:
Route mode: "direct", "react", or "auto".
"""
if self._state.react_mode == ReactMode.REACT:
return "react"
elif self._state.react_mode == ReactMode.NON_REACT:
return "direct"
else:
return self._auto_route(query, confidence)
def _auto_route(self, query: str, confidence: float | None = None) -> str:
"""
[AC-AISVC-RES-11~14] Auto route based on query complexity and confidence.
"""
query_length = len(query)
has_multiple_conditions = "" in query or "" in query or "以及" in query
low_confidence_threshold = 0.5
short_query_threshold = 20
if confidence is not None and confidence < low_confidence_threshold:
logger.info(
f"[AC-AISVC-RES-13] Auto route to react: low confidence={confidence}"
)
return "react"
if has_multiple_conditions:
logger.info(
f"[AC-AISVC-RES-13] Auto route to react: multiple conditions detected"
)
return "react"
if query_length < short_query_threshold and confidence and confidence > 0.7:
logger.info(
f"[AC-AISVC-RES-12] Auto route to direct: short query, high confidence"
)
return "direct"
return "direct"
def get_switch_history(self, limit: int = 10) -> list[dict[str, Any]]:
"""
Get recent switch history.
Args:
limit: Maximum number of records to return.
Returns:
List of switch records.
"""
return self._state.switch_history[-limit:]
# Module-level singleton, built lazily on first request.
_strategy_service: RetrievalStrategyService | None = None


def get_strategy_service() -> RetrievalStrategyService:
    """Return the shared RetrievalStrategyService, creating it on first use."""
    global _strategy_service
    service = _strategy_service
    if service is None:
        service = _strategy_service = RetrievalStrategyService()
    return service

View File

@ -0,0 +1,75 @@
"""
Database Migration: User Memories Table.
[AC-IDMP-14] 用户级记忆滚动摘要表
创建时间: 2025-03-08
变更说明:
- 新增 user_memories 表用于存储滚动摘要与事实/偏好/未解决问题
执行方式:
- SQLModel 会通过 init_db 自动创建表
- 此脚本用于手动迁移或回滚
SQL DDL:
```sql
CREATE TABLE user_memories (
id UUID PRIMARY KEY,
tenant_id VARCHAR NOT NULL,
user_id VARCHAR NOT NULL,
summary TEXT,
facts JSON,
preferences JSON,
open_issues JSON,
summary_version INTEGER NOT NULL DEFAULT 1,
last_turn_id VARCHAR,
expires_at TIMESTAMP,
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
);
CREATE INDEX ix_user_memories_tenant_user ON user_memories(tenant_id, user_id);
CREATE INDEX ix_user_memories_tenant_user_updated ON user_memories(tenant_id, user_id, updated_at);
```
回滚 SQL:
```sql
DROP TABLE IF EXISTS user_memories;
```
"""
# Table definition for user_memories. IF NOT EXISTS lets the statement be
# re-run against a database where the table is already present.
USER_MEMORIES_DDL = """
CREATE TABLE IF NOT EXISTS user_memories (
id UUID PRIMARY KEY,
tenant_id VARCHAR NOT NULL,
user_id VARCHAR NOT NULL,
summary TEXT,
facts JSON,
preferences JSON,
open_issues JSON,
summary_version INTEGER NOT NULL DEFAULT 1,
last_turn_id VARCHAR,
expires_at TIMESTAMP,
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
);
"""
# Two composite indexes on (tenant_id, user_id) and
# (tenant_id, user_id, updated_at). Both statements live in one string and are
# sent in a single execute() call by upgrade().
USER_MEMORIES_INDEXES = """
CREATE INDEX IF NOT EXISTS ix_user_memories_tenant_user ON user_memories(tenant_id, user_id);
CREATE INDEX IF NOT EXISTS ix_user_memories_tenant_user_updated ON user_memories(tenant_id, user_id, updated_at);
"""
# Rollback statement used by downgrade(): drops the table if it exists.
USER_MEMORIES_ROLLBACK = """
DROP TABLE IF EXISTS user_memories;
"""
async def upgrade(conn):
    """Apply the migration: create the table first, then its indexes."""
    for statement in (USER_MEMORIES_DDL, USER_MEMORIES_INDEXES):
        await conn.execute(statement)
async def downgrade(conn):
    """Roll the migration back by executing the rollback statement."""
    rollback_sql = USER_MEMORIES_ROLLBACK
    await conn.execute(rollback_sql)

View File

@ -0,0 +1,178 @@
"""
Script to cleanup vector data for a specific knowledge base.
Clears the Qdrant collection for the given KB ID, allowing re-indexing.
"""
import asyncio
import logging
import os
import sys

# Make the ai-service package importable when this script is run directly.
# The original developer-machine path stays as the default so existing usage
# keeps working; set AI_SERVICE_ROOT to run the script from another checkout.
sys.path.insert(
    0,
    os.environ.get("AI_SERVICE_ROOT", "Q:\\agentProject\\ai-robot-core\\ai-service"),
)

from app.core.config import get_settings
from app.core.qdrant_client import get_qdrant_client
from app.core.database import get_session
from app.models.entities import KnowledgeBase, Document
from sqlalchemy import select

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
async def get_knowledge_base_info(kb_id: str) -> dict | None:
    """
    Look up a knowledge base and its documents in the database.

    Returns:
        A dict with ``id``, ``tenant_id``, ``name``, ``doc_count`` and
        ``document_ids``, or ``None`` when no knowledge base has this ID.
    """
    async for db in get_session():
        kb_query = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
        kb_row = (await db.execute(kb_query)).scalar_one_or_none()
        if not kb_row:
            continue

        docs_query = select(Document).where(Document.kb_id == kb_id)
        docs = (await db.execute(docs_query)).scalars().all()
        return {
            "id": str(kb_row.id),
            "tenant_id": kb_row.tenant_id,
            "name": kb_row.name,
            "doc_count": len(docs),
            "document_ids": [str(d.id) for d in docs]
        }
    return None
async def list_kb_collections(tenant_id: str, kb_id: str) -> list[str]:
    """
    List Qdrant collections whose name appears to reference the given KB.

    A collection matches when the first 8 characters of the KB ID (dashes
    replaced by underscores) occur in its name, or when the first 8 characters
    of the dash-free KB ID occur in the name with underscores removed.

    Args:
        tenant_id: Tenant identifier. Kept for interface compatibility; the
            match is based on the KB ID only.
        kb_id: Knowledge base ID.

    Returns:
        Names of the matching collections, in the order Qdrant reports them.
    """
    client = await get_qdrant_client()
    qdrant = await client.get_client()
    collections = await qdrant.get_collections()

    # Both prefixes are loop-invariant, so compute them once.
    underscored_prefix = kb_id.replace('-', '_')[:8]
    compact_prefix = kb_id.replace('-', '')[:8]
    return [
        c.name for c in collections.collections
        if underscored_prefix in c.name or compact_prefix in c.name.replace('_', '')
    ]
async def clear_kb_vector_data(tenant_id: str, kb_id: str, delete_docs: bool = False) -> bool:
    """
    Clear vector data for a specific knowledge base.

    Deletes the KB's Qdrant collection when it exists. With ``delete_docs``
    the KB's document rows are also deleted and the KB's ``doc_count`` is
    reset to 0.

    Args:
        tenant_id: Tenant identifier
        kb_id: Knowledge base ID
        delete_docs: Whether to also delete document records from database

    Returns:
        True if successful; False when any step raised (the error is logged).
    """
    # Imported locally: at module level ``datetime`` is only bound inside the
    # ``__main__`` guard, so calling this function from an importing module
    # would otherwise hit a NameError after the collection was already deleted.
    from datetime import datetime

    client = await get_qdrant_client()
    qdrant = await client.get_client()
    collection_name = client.get_kb_collection_name(tenant_id, kb_id)
    try:
        exists = await qdrant.collection_exists(collection_name)
        if exists:
            await qdrant.delete_collection(collection_name=collection_name)
            logger.info(f"Deleted Qdrant collection: {collection_name}")
        else:
            logger.info(f"Collection {collection_name} does not exist")

        if delete_docs:
            async for session in get_session():
                doc_stmt = select(Document).where(Document.kb_id == kb_id)
                doc_result = await session.execute(doc_stmt)
                documents = doc_result.scalars().all()
                for doc in documents:
                    await session.delete(doc)

                stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
                result = await session.execute(stmt)
                kb = result.scalar_one_or_none()
                if kb:
                    kb.doc_count = 0
                    kb.updated_at = datetime.utcnow()
                await session.commit()
                logger.info(f"Deleted {len(documents)} document records from database")
                # Only the first yielded session is used.
                break
        return True
    except Exception as e:
        logger.error(f"Failed to clear KB vector data: {e}")
        return False
async def main(kb_id: str, delete_docs: bool = False):
    """
    Interactive entry point: show what will be removed, ask for confirmation,
    then clear the knowledge base's vector data.

    Args:
        kb_id: Knowledge base ID to clear.
        delete_docs: Whether document records in the database are deleted too.

    Returns:
        True when the cleanup succeeded; False when the KB was not found, the
        user declined, or the cleanup failed.
    """
    logger.info(f"Starting cleanup for knowledge base: {kb_id}")
    kb_info = await get_knowledge_base_info(kb_id)
    if not kb_info:
        logger.error(f"Knowledge base not found: {kb_id}")
        return False

    logger.info("Found knowledge base:")
    logger.info(f" - ID: {kb_info['id']}")
    logger.info(f" - Name: {kb_info['name']}")
    logger.info(f" - Tenant: {kb_info['tenant_id']}")
    logger.info(f" - Document count: {kb_info['doc_count']}")

    matching_collections = await list_kb_collections(kb_info['tenant_id'], kb_id)
    if matching_collections:
        logger.info(f" - Related collections: {matching_collections}")

    # Resolve the name the same way clear_kb_vector_data() does, so the
    # warning shows the collection that will actually be deleted.
    qdrant_client = await get_qdrant_client()
    target_collection = qdrant_client.get_kb_collection_name(kb_info['tenant_id'], kb_id)

    print()
    print("=" * 60)
    print("WARNING: This will delete all vector data for this knowledge base!")
    print(f"Collection to delete: {target_collection}")
    if delete_docs:
        print("Document records in database will also be deleted!")
    print("=" * 60)
    print()

    confirm = input("Continue? (yes/no): ")
    if confirm.lower() != "yes":
        print("Cancelled")
        return False

    success = await clear_kb_vector_data(
        tenant_id=kb_info['tenant_id'],
        kb_id=kb_id,
        delete_docs=delete_docs
    )
    if success:
        logger.info(f"Successfully cleared vector data for KB: {kb_id}")
        logger.info("You can now re-index the knowledge base documents.")
    else:
        logger.error(f"Failed to clear vector data for KB: {kb_id}")
    return success
if __name__ == "__main__":
    import argparse
    # Kept at module scope for code that reads ``datetime`` as a module global.
    from datetime import datetime

    parser = argparse.ArgumentParser(description="Clear vector data for a knowledge base")
    parser.add_argument("kb_id", help="Knowledge base ID to clear")
    parser.add_argument("--delete-docs", action="store_true",
                        help="Also delete document records from database")
    args = parser.parse_args()

    # Propagate the outcome: exit status 0 on success, 1 when the KB was not
    # found, the user declined, or the cleanup failed.
    succeeded = asyncio.run(main(args.kb_id, args.delete_docs))
    sys.exit(0 if succeeded else 1)

1
ai-service/svg/kefu.svg Normal file
View File

@ -0,0 +1 @@
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg class="icon" width="200px" height="200.00px" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg"><path fill="#1296db" d="M894.1 355.6h-1.7C853 177.6 687.6 51.4 498.1 54.9S148.2 190.5 115.9 369.7c-35.2 5.6-61.1 36-61.1 71.7v143.4c0.9 40.4 34.3 72.5 74.7 71.7 21.7-0.3 42.2-10 56-26.7 33.6 84.5 99.9 152 183.8 187 1.1-2 2.3-3.9 3.7-5.7 0.9-1.5 2.4-2.6 4.1-3 1.3 0 2.5 0.5 3.6 1.2a318.46 318.46 0 0 1-105.3-187.1c-5.1-44.4 24.1-85.4 67.6-95.2 64.3-11.7 128.1-24.7 192.4-35.9 37.9-5.3 70.4-29.8 85.7-64.9 6.8-15.9 11-32.8 12.5-50 0.5-3.1 2.9-5.6 5.9-6.2 3.1-0.7 6.4 0.5 8.2 3l1.7-1.1c25.4 35.9 74.7 114.4 82.7 197.2 8.2 94.8 3.7 160-71.4 226.5-1.1 1.1-1.7 2.6-1.7 4.1 0.1 2 1.1 3.8 2.8 4.8h4.8l3.2-1.8c75.6-40.4 132.8-108.2 159.9-189.5 11.4 16.1 28.5 27.1 47.8 30.8C846 783.9 716.9 871.6 557.2 884.9c-12-28.6-42.5-44.8-72.9-38.6-33.6 5.4-56.6 37-51.2 70.6 4.4 27.6 26.8 48.8 54.5 51.6 30.6 4.6 60.3-13 70.8-42.2 184.9-14.5 333.2-120.8 364.2-286.9 27.8-10.8 46.3-37.4 46.6-67.2V428.7c-0.1-19.5-8.1-38.2-22.3-51.6-14.5-13.8-33.8-21.4-53.8-21.3l1-0.2zM825.9 397c-71.1-176.9-272.1-262.7-449-191.7-86.8 34.9-155.7 103.4-191 190-2.5-2.8-5.2-5.4-8-7.9 25.3-154.6 163.8-268.6 326.8-269.2s302.3 112.6 328.7 267c-2.9 3.8-5.4 7.7-7.5 11.8z" /></svg>

After

Width:  |  Height:  |  Size: 1.4 KiB

1
ai-service/svg/user.svg Normal file
View File

@ -0,0 +1 @@
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg class="icon" width="200px" height="200.00px" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg"><path fill="#1296db" d="M573.9 516.2L512 640l-61.9-123.8C232 546.4 64 733.6 64 960h896c0-226.4-168-413.6-386.1-443.8zM480 384h64c17.7 0 32.1 14.4 32.1 32.1 0 17.7-14.4 32.1-32.1 32.1h-64c-11.9 0-22.3-6.5-27.8-16.1H356c34.9 48.5 91.7 80 156 80 106 0 192-86 192-192s-86-192-192-192-192 86-192 192c0 28.5 6.2 55.6 17.4 80h114.8c5.5-9.6 15.9-16.1 27.8-16.1z" /><path fill="#1296db" d="M272 432.1h84c-4.2-5.9-8.1-12-11.7-18.4-2.3-4.1-4.4-8.3-6.4-12.5-0.2-0.4-0.4-0.7-0.5-1.1H288c-8.8 0-16-7.2-16-16v-48.4c0-64.1 25-124.3 70.3-169.6S447.9 95.8 512 95.8s124.3 25 169.7 70.3c38.3 38.3 62.1 87.2 68.5 140.2-8.4 4-14.2 12.5-14.2 22.4v78.6c0 13.7 11.1 24.8 24.8 24.8h14.6c13.7 0 24.8-11.1 24.8-24.8v-78.6c0-11.3-7.6-20.9-18-23.8-6.9-60.9-33.9-117.4-78-161.3C652.9 92.1 584.6 63.9 512 63.9s-140.9 28.3-192.3 79.6C268.3 194.8 240 263.1 240 335.7v64.4c0 17.7 14.3 32 32 32z" /></svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@ -0,0 +1 @@
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1773157868702" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="12129" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><path d="M657.066667 558.933333c-34.133333 25.6-76.8 38.4-123.733334 38.4s-85.333333-12.8-123.733333-38.4c0 4.266667-4.266667 4.266667-8.533333 4.266667C315.733333 614.4 256 708.266667 256 810.666667c0 12.8-8.533333 21.333333-21.333333 21.333333S213.333333 823.466667 213.333333 810.666667c0-119.466667 64-226.133333 166.4-281.6-38.4-38.4-59.733333-89.6-59.733333-145.066667 0-119.466667 93.866667-213.333333 213.333333-213.333333s213.333333 93.866667 213.333334 213.333333c0 55.466667-21.333333 106.666667-59.733334 145.066667 102.4 55.466667 166.4 162.133333 166.4 281.6 0 12.8-8.533333 21.333333-21.333333 21.333333s-21.333333-8.533333-21.333333-21.333333c0-102.4-59.733333-196.266667-149.333334-247.466667 0 0-4.266667 0-4.266666-4.266667z m-123.733334-4.266666c93.866667 0 170.666667-76.8 170.666667-170.666667s-76.8-170.666667-170.666667-170.666667-170.666667 76.8-170.666666 170.666667 76.8 170.666667 170.666666 170.666667z" p-id="12130" fill="#1296db"></path></svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@ -0,0 +1,375 @@
"""
Unit tests for ImageParser.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.document.image_parser import (
ImageChunk,
ImageParseResult,
ImageParser,
)
class TestImageParserBasics:
    """Checks of ImageParser's extension list and MIME type lookup."""

    def test_supported_extensions(self):
        """The parser reports exactly the expected extensions, in order."""
        wanted = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
        assert ImageParser().get_supported_extensions() == wanted

    def test_get_mime_type(self):
        """Known extensions map to their MIME type; unknown ones map to JPEG."""
        image_parser = ImageParser()
        expected_by_extension = {
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".png": "image/png",
            ".gif": "image/gif",
            ".webp": "image/webp",
            ".bmp": "image/bmp",
            ".tiff": "image/tiff",
            ".tif": "image/tiff",
            ".unknown": "image/jpeg",
        }
        for extension, mime in expected_by_extension.items():
            assert image_parser._get_mime_type(extension) == mime
class TestImageChunkParsing:
    """Tests for JSON extraction and parsing of LLM responses."""

    def test_extract_json_from_plain_json(self):
        """A response that is already plain JSON is returned unchanged."""
        payload = '{"image_summary": "test", "chunks": [{"chunk_index": 0, "content": "hello", "chunk_type": "text", "keywords": ["key"]}]}'
        assert ImageParser()._extract_json(payload) == payload

    def test_extract_json_from_markdown(self):
        """JSON wrapped in a markdown code fence is extracted."""
        fenced = """Here is the analysis:
```json
{"image_summary": "test", "chunks": [{"chunk_index": 0, "content": "hello"}]}
```
Hope this helps!"""
        extracted = ImageParser()._extract_json(fenced)
        for fragment in ("image_summary", "test"):
            assert fragment in extracted

    def test_extract_json_from_text_with_json(self):
        """An object embedded in surrounding text is extracted."""
        embedded = "The result is: {'image_summary': 'summary', 'chunks': []}"
        extracted = ImageParser()._extract_json(embedded)
        for fragment in ("image_summary", "chunks"):
            assert fragment in extracted

    def test_parse_llm_response_valid_json(self):
        """A well-formed response yields one ImageChunk per listed chunk."""
        first_chunk = {
            "chunk_index": 0,
            "content": "这是第一块内容",
            "chunk_type": "text",
            "keywords": ["测试", "内容"]
        }
        second_chunk = {
            "chunk_index": 1,
            "content": "这是第二块内容,包含表格数据",
            "chunk_type": "table",
            "keywords": ["表格", "数据"]
        }
        payload = json.dumps({
            "image_summary": "测试图片",
            "total_chunks": 2,
            "chunks": [first_chunk, second_chunk]
        })
        parsed = ImageParser()._parse_llm_response(payload)
        assert parsed.image_summary == "测试图片"
        assert len(parsed.chunks) == 2
        head, tail = parsed.chunks[0], parsed.chunks[1]
        assert head.content == "这是第一块内容"
        assert head.chunk_type == "text"
        assert head.keywords == ["测试", "内容"]
        assert tail.chunk_type == "table"
        assert tail.keywords == ["表格", "数据"]

    def test_parse_llm_response_empty_chunks(self):
        """An empty chunk list falls back to one chunk holding the raw response."""
        payload = json.dumps({
            "image_summary": "测试",
            "chunks": []
        })
        parsed = ImageParser()._parse_llm_response(payload)
        assert len(parsed.chunks) == 1
        assert parsed.chunks[0].content == payload

    def test_parse_llm_response_invalid_json(self):
        """Non-JSON text falls back to one chunk holding that text."""
        parsed = ImageParser()._parse_llm_response("This is not JSON at all")
        assert len(parsed.chunks) == 1
        assert parsed.chunks[0].content == "This is not JSON at all"

    def test_parse_llm_response_partial_json(self):
        """Malformed JSON falls back to one chunk holding the raw response."""
        broken = '{"image_summary": "test" some text here {"chunks": []}'
        parsed = ImageParser()._parse_llm_response(broken)
        assert len(parsed.chunks) == 1
        assert parsed.chunks[0].content == broken
class TestImageChunkDataClass:
    """Tests for the ImageChunk and ImageParseResult containers."""

    def test_image_chunk_creation(self):
        """All explicitly supplied fields are stored as given."""
        fields = {
            "chunk_index": 0,
            "content": "Test content",
            "chunk_type": "text",
            "keywords": ["test", "content"],
        }
        chunk = ImageChunk(**fields)
        assert chunk.chunk_index == 0
        assert chunk.content == "Test content"
        assert chunk.chunk_type == "text"
        assert chunk.keywords == ["test", "content"]

    def test_image_chunk_default_values(self):
        """Omitted fields default to chunk_type "text" and no keywords."""
        minimal = ImageChunk(chunk_index=0, content="Test")
        assert minimal.chunk_type == "text"
        assert minimal.keywords == []

    def test_image_parse_result_creation(self):
        """ImageParseResult keeps the values it was constructed with."""
        pieces = [
            ImageChunk(chunk_index=position, content=text)
            for position, text in enumerate(["Chunk 1", "Chunk 2"])
        ]
        parsed = ImageParseResult(
            image_summary="Test summary",
            chunks=pieces,
            raw_text="Chunk 1\n\nChunk 2",
            source_path="/path/to/image.png",
            file_size=1024,
            metadata={"format": "png"}
        )
        assert parsed.image_summary == "Test summary"
        assert len(parsed.chunks) == 2
        assert parsed.raw_text == "Chunk 1\n\nChunk 2"
        assert parsed.file_size == 1024
        assert parsed.metadata["format"] == "png"
class TestChunkTypes:
    """Each supported chunk_type survives a round trip through the parser."""

    @staticmethod
    def _parse_single_chunk(summary, content, chunk_type, keywords):
        """Build a one-chunk LLM response and return the parsed result."""
        payload = json.dumps({
            "image_summary": summary,
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": content,
                    "chunk_type": chunk_type,
                    "keywords": keywords
                }
            ]
        })
        return ImageParser()._parse_llm_response(payload)

    def test_text_chunk_type(self):
        """A "text" chunk keeps its type."""
        parsed = self._parse_single_chunk(
            "Text content", "Plain text content", "text", ["text"]
        )
        assert parsed.chunks[0].chunk_type == "text"

    def test_table_chunk_type(self):
        """A "table" chunk keeps its type."""
        parsed = self._parse_single_chunk(
            "Table content", "Name | Age\n---- | ---\nJohn | 30", "table", ["table", "data"]
        )
        assert parsed.chunks[0].chunk_type == "table"

    def test_chart_chunk_type(self):
        """A "chart" chunk keeps its type."""
        parsed = self._parse_single_chunk(
            "Chart content", "Bar chart showing sales data", "chart", ["chart", "sales"]
        )
        assert parsed.chunks[0].chunk_type == "chart"

    def test_list_chunk_type(self):
        """A "list" chunk keeps its type."""
        parsed = self._parse_single_chunk(
            "List content", "1. First item\n2. Second item\n3. Third item", "list", ["list", "items"]
        )
        assert parsed.chunks[0].chunk_type == "list"
class TestIntegrationScenarios:
    """End-to-end parsing of realistic single- and multi-chunk responses."""

    def test_single_chunk_scenario(self):
        """One chunk: raw_text equals that chunk's content."""
        body = "这是一段完整的文档内容,包含所有的信息要点。"
        payload = json.dumps({
            "image_summary": "简单文档截图",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": body,
                    "chunk_type": "text",
                    "keywords": ["文档", "信息"]
                }
            ]
        })
        parsed = ImageParser()._parse_llm_response(payload)
        assert len(parsed.chunks) == 1
        assert parsed.image_summary == "简单文档截图"
        assert parsed.raw_text == body

    def test_multi_chunk_scenario(self):
        """Three text chunks keep their order; raw_text has two blank-line separators."""
        sections = [
            ("第一章:介绍部分,介绍项目的背景和目标。", ["第一章", "介绍"]),
            ("第二章:技术架构,包括前端、后端和数据库设计。", ["第二章", "架构"]),
            ("第三章:部署流程,包含开发环境和生产环境配置。", ["第三章", "部署"]),
        ]
        payload = json.dumps({
            "image_summary": "多段落文档",
            "chunks": [
                {
                    "chunk_index": position,
                    "content": text,
                    "chunk_type": "text",
                    "keywords": words
                }
                for position, (text, words) in enumerate(sections)
            ]
        })
        parsed = ImageParser()._parse_llm_response(payload)
        assert len(parsed.chunks) == 3
        for position, heading in enumerate(["第一章", "第二章", "第三章"]):
            assert heading in parsed.chunks[position].content
        assert parsed.raw_text.count("\n\n") == 2

    def test_mixed_content_scenario(self):
        """Text, table and list chunks in one response keep their types."""
        payload = json.dumps({
            "image_summary": "混合内容图片",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "产品介绍:本文档介绍我们的核心产品功能。",
                    "chunk_type": "text",
                    "keywords": ["产品", "功能"]
                },
                {
                    "chunk_index": 1,
                    "content": "产品规格表:\n| 型号 | 价格 | 库存 |\n| --- | --- | --- |\n| A1 | 100 | 50 |",
                    "chunk_type": "table",
                    "keywords": ["规格", "价格", "库存"]
                },
                {
                    "chunk_index": 2,
                    "content": "使用说明:\n1. 打开包装\n2. 连接电源\n3. 按下启动按钮",
                    "chunk_type": "list",
                    "keywords": ["说明", "步骤"]
                }
            ]
        })
        parsed = ImageParser()._parse_llm_response(payload)
        assert len(parsed.chunks) == 3
        observed_types = [parsed.chunks[position].chunk_type for position in range(3)]
        assert observed_types[0] == "text"
        assert observed_types[1] == "table"
        assert observed_types[2] == "list"
# Allow running this test module directly, with verbose output.
if __name__ == "__main__":
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)

View File

@ -0,0 +1,332 @@
"""
Unit tests for multi-usage LLM configuration.
Tests for LLMUsageType, LLMConfigManager multi-usage support, and API endpoints.
"""
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.llm.factory import (
LLMConfigManager,
LLMProviderFactory,
LLMUsageType,
LLM_USAGE_DESCRIPTIONS,
LLM_USAGE_DISPLAY_NAMES,
get_llm_config_manager,
)
@pytest.fixture
def mock_settings():
    """Provide a MagicMock standing in for application settings (OpenAI LLM, Redis off)."""
    return MagicMock(
        llm_provider="openai",
        llm_api_key="test-api-key",
        llm_base_url="https://api.openai.com/v1",
        llm_model="gpt-4o-mini",
        llm_max_tokens=2048,
        llm_temperature=0.7,
        llm_timeout_seconds=30,
        llm_max_retries=3,
        redis_enabled=False,
        redis_url="redis://localhost:6379",
    )
@pytest.fixture(autouse=True)
def reset_singleton():
    """Clear the factory's cached config manager around every test."""
    import app.services.llm.factory as llm_factory

    llm_factory._llm_config_manager = None
    yield
    llm_factory._llm_config_manager = None
@pytest.fixture
def isolated_config_file(tmp_path):
    """Return the path of a per-test llm_config.json containing an empty JSON object."""
    path = tmp_path / "llm_config.json"
    path.write_text("{}")
    return path
class TestLLMUsageType:
    """Checks of the LLMUsageType enum and its display metadata."""

    def test_usage_types_exist(self):
        """The CHAT and KB_PROCESSING members carry the expected values."""
        assert LLMUsageType.CHAT.value == "chat"
        assert LLMUsageType.KB_PROCESSING.value == "kb_processing"

    def test_usage_type_display_names(self):
        """Every usage type has a non-empty string display name."""
        for usage in LLMUsageType:
            assert usage in LLM_USAGE_DISPLAY_NAMES
            label = LLM_USAGE_DISPLAY_NAMES[usage]
            assert isinstance(label, str)
            assert len(label) > 0

    def test_usage_type_descriptions(self):
        """Every usage type has a non-empty string description."""
        for usage in LLMUsageType:
            assert usage in LLM_USAGE_DESCRIPTIONS
            text = LLM_USAGE_DESCRIPTIONS[usage]
            assert isinstance(text, str)
            assert len(text) > 0

    def test_usage_type_from_string(self):
        """Members can be looked up by their string value."""
        by_value = {
            "chat": LLMUsageType.CHAT,
            "kb_processing": LLMUsageType.KB_PROCESSING,
        }
        for raw, member in by_value.items():
            assert LLMUsageType(raw) == member

    def test_invalid_usage_type(self):
        """An unknown value raises ValueError."""
        with pytest.raises(ValueError):
            LLMUsageType("invalid_type")
class TestLLMConfigManagerMultiUsage:
    """Exercise LLMConfigManager's per-usage-type configuration handling.

    Each test patches ``get_settings`` and ``LLM_CONFIG_FILE`` inside
    ``app.services.llm.factory`` so the manager sees the mocked settings and a
    config file located under pytest's ``tmp_path``.
    """

    @pytest.mark.asyncio
    async def test_get_all_configs(self, mock_settings, isolated_config_file):
        """get_current_config() without arguments returns one entry per usage type."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            mgr = LLMConfigManager()
            configs_by_usage = mgr.get_current_config()
            for usage in LLMUsageType:
                entry_key = usage.value
                assert entry_key in configs_by_usage
                assert "provider" in configs_by_usage[entry_key]
                assert "config" in configs_by_usage[entry_key]

    @pytest.mark.asyncio
    async def test_update_specific_usage_config(self, mock_settings, isolated_config_file):
        """update_usage_config() changes the config of the targeted usage type."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            mgr = LLMConfigManager()
            await mgr.update_usage_config(
                usage_type=LLMUsageType.KB_PROCESSING,
                provider="ollama",
                config={"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
            )
            kb_entry = mgr.get_current_config(LLMUsageType.KB_PROCESSING)
            assert kb_entry["provider"] == "ollama"
            assert kb_entry["config"]["model"] == "llama3.2"

    @pytest.mark.asyncio
    async def test_update_all_configs(self, mock_settings, isolated_config_file):
        """update_config() applies the provider to every usage type."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            mgr = LLMConfigManager()
            await mgr.update_config(
                provider="deepseek",
                config={"api_key": "test-key", "model": "deepseek-chat"},
            )
            for usage in LLMUsageType:
                entry = mgr.get_current_config(usage)
                assert entry["provider"] == "deepseek"

    @pytest.mark.asyncio
    async def test_get_client_for_usage_type(self, mock_settings, isolated_config_file):
        """get_client() returns a client for CHAT and for KB_PROCESSING."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            mgr = LLMConfigManager()
            client_for_chat = mgr.get_client(LLMUsageType.CHAT)
            assert client_for_chat is not None
            client_for_kb = mgr.get_client(LLMUsageType.KB_PROCESSING)
            assert client_for_kb is not None

    @pytest.mark.asyncio
    async def test_get_chat_client(self, mock_settings, isolated_config_file):
        """The get_chat_client() convenience method returns a client."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            mgr = LLMConfigManager()
            chat_client = mgr.get_chat_client()
            assert chat_client is not None

    @pytest.mark.asyncio
    async def test_get_kb_processing_client(self, mock_settings, isolated_config_file):
        """The get_kb_processing_client() convenience method returns a client."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            mgr = LLMConfigManager()
            kb_client = mgr.get_kb_processing_client()
            assert kb_client is not None

    @pytest.mark.asyncio
    async def test_close_all_clients(self, mock_settings, isolated_config_file):
        """After close(), the manager's ``_clients`` slot is None for every usage type."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            mgr = LLMConfigManager()
            # Request a client for each usage type before closing.
            for usage in LLMUsageType:
                mgr.get_client(usage)
            await mgr.close()
            for usage in LLMUsageType:
                assert mgr._clients[usage] is None

    @pytest.mark.asyncio
    async def test_config_persistence_to_file(self, mock_settings, isolated_config_file):
        """Updating one usage type writes both usage sections to the config file."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            mgr = LLMConfigManager()
            await mgr.update_usage_config(
                usage_type=LLMUsageType.KB_PROCESSING,
                provider="ollama",
                config={"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
            )
            assert isolated_config_file.exists()
            persisted = json.loads(isolated_config_file.read_text(encoding="utf-8"))
            assert "chat" in persisted
            assert "kb_processing" in persisted
            assert persisted["kb_processing"]["provider"] == "ollama"

    @pytest.mark.asyncio
    async def test_load_config_from_file(self, mock_settings, tmp_path):
        """A file in the per-usage format is read back by a new manager."""
        config_path = tmp_path / "llm_config.json"
        per_usage_config = {
            "chat": {
                "provider": "openai",
                "config": {"api_key": "test-key", "model": "gpt-4o"},
            },
            "kb_processing": {
                "provider": "ollama",
                "config": {"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
            },
        }
        config_path.write_text(json.dumps(per_usage_config), encoding="utf-8")
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", config_path):
            mgr = LLMConfigManager()
            chat_entry = mgr.get_current_config(LLMUsageType.CHAT)
            assert chat_entry["config"]["model"] == "gpt-4o"
            kb_entry = mgr.get_current_config(LLMUsageType.KB_PROCESSING)
            assert kb_entry["provider"] == "ollama"
            assert kb_entry["config"]["model"] == "llama3.2"

    @pytest.mark.asyncio
    async def test_backward_compatibility_old_config_format(self, mock_settings, tmp_path):
        """A file in the old single-config format is applied to every usage type."""
        config_path = tmp_path / "llm_config.json"
        legacy_config = {
            "provider": "deepseek",
            "config": {"api_key": "test-key", "model": "deepseek-chat"},
        }
        config_path.write_text(json.dumps(legacy_config), encoding="utf-8")
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", config_path):
            mgr = LLMConfigManager()
            for usage in LLMUsageType:
                entry = mgr.get_current_config(usage)
                assert entry["provider"] == "deepseek"
                assert entry["config"]["model"] == "deepseek-chat"
class TestLLMConfigManagerTestConnection:
    """Checks for LLMConfigManager.test_connection when a usage type is given."""

    @pytest.mark.asyncio
    async def test_test_connection_with_usage_type(self, mock_settings, isolated_config_file):
        """test_connection() reports success, a response and a latency for CHAT.

        ``LLMProviderFactory.create_client`` is patched to return a mock client
        whose ``generate`` coroutine yields a canned reply.
        """
        canned_reply = MagicMock(
            content="Test response",
            model="gpt-4o-mini",
            usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
        )
        fake_client = AsyncMock()
        fake_client.generate = AsyncMock(return_value=canned_reply)
        fake_client.close = AsyncMock()

        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            mgr = LLMConfigManager()
            with patch.object(LLMProviderFactory, "create_client", return_value=fake_client):
                outcome = await mgr.test_connection(
                    test_prompt="Hello",
                    usage_type=LLMUsageType.CHAT,
                )
            assert outcome["success"] is True
            assert "response" in outcome
            assert "latency_ms" in outcome
class TestGetLLMConfigManager:
    """Tests for the get_llm_config_manager() module-level accessor."""

    def test_singleton_instance(self, mock_settings, isolated_config_file):
        """Two consecutive calls return the very same manager object."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager1 = get_llm_config_manager()
                manager2 = get_llm_config_manager()
                assert manager1 is manager2

    def test_reset_singleton(self, mock_settings, isolated_config_file):
        """Clearing the module-level slot makes the next call build a new manager.

        The reset uses the same mechanism as the ``reset_singleton`` fixture in
        this module: assigning ``None`` to ``factory._llm_config_manager``.
        """
        import app.services.llm.factory as factory

        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = get_llm_config_manager()
                assert manager is not None

                # Actually perform the reset the test name promises, then verify
                # that a different instance is handed out afterwards.
                factory._llm_config_manager = None
                fresh_manager = get_llm_config_manager()
                assert fresh_manager is not None
                assert fresh_manager is not manager
class TestLLMConfigManagerProperties:
    """Checks on the ``chat_config`` and ``kb_processing_config`` properties."""

    def test_chat_config_property(self, mock_settings, isolated_config_file):
        """``chat_config`` is a dict that contains a "model" key."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            chat_cfg = LLMConfigManager().chat_config
            assert isinstance(chat_cfg, dict)
            assert "model" in chat_cfg

    def test_kb_processing_config_property(self, mock_settings, isolated_config_file):
        """``kb_processing_config`` is a dict that contains a "model" key."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            kb_cfg = LLMConfigManager().kb_processing_config
            assert isinstance(kb_cfg, dict)
            assert "model" in kb_cfg

View File

@ -0,0 +1,530 @@
"""
Unit tests for Markdown intelligent chunker.
Tests for MarkdownParser, MarkdownChunker, and integration.
"""
import pytest
from app.services.document.markdown_chunker import (
MarkdownChunk,
MarkdownChunker,
MarkdownElement,
MarkdownElementType,
MarkdownParser,
chunk_markdown,
)
class TestMarkdownParser:
    """Tests for MarkdownParser.parse(): one test per Markdown element type."""

    @staticmethod
    def _select(elements, element_type):
        """Return the parsed elements whose ``type`` equals ``element_type``."""
        return [e for e in elements if e.type == element_type]

    def test_parse_headers(self):
        """Headers of levels 1-4 are extracted with their text and level."""
        text = """# Main Title
## Section 1
### Subsection 1.1
#### Deep Header
"""
        elements = MarkdownParser().parse(text)
        headers = self._select(elements, MarkdownElementType.HEADER)
        assert len(headers) == 4
        expected = [
            ("Main Title", 1),
            ("Section 1", 2),
            ("Subsection 1.1", 3),
            ("Deep Header", 4),
        ]
        for header, (content, level) in zip(headers, expected):
            assert header.content == content
            assert header.level == level

    def test_parse_code_blocks(self):
        """A fenced block keeps its language tag and its code lines."""
        text = """Here is some code:
```python
def hello():
print("Hello, World!")
```
And some more text.
"""
        elements = MarkdownParser().parse(text)
        code_blocks = self._select(elements, MarkdownElementType.CODE_BLOCK)
        assert len(code_blocks) == 1
        assert code_blocks[0].language == "python"
        assert 'def hello():' in code_blocks[0].content
        assert 'print("Hello, World!")' in code_blocks[0].content

    def test_parse_code_blocks_no_language(self):
        """A fence without a language tag yields an empty ``language``."""
        text = """```
plain code here
multiple lines
```
"""
        elements = MarkdownParser().parse(text)
        code_blocks = self._select(elements, MarkdownElementType.CODE_BLOCK)
        assert len(code_blocks) == 1
        assert code_blocks[0].language == ""
        assert "plain code here" in code_blocks[0].content

    def test_parse_tables(self):
        """A pipe table is extracted with header names and data-row count in metadata."""
        text = """| Name | Age | City |
|------|-----|------|
| Alice | 30 | NYC |
| Bob | 25 | LA |
"""
        elements = MarkdownParser().parse(text)
        tables = self._select(elements, MarkdownElementType.TABLE)
        assert len(tables) == 1
        assert "Name" in tables[0].content
        assert "Alice" in tables[0].content
        assert tables[0].metadata.get("headers") == ["Name", "Age", "City"]
        assert tables[0].metadata.get("row_count") == 2

    def test_parse_lists(self):
        """Consecutive bullet items form a single list element."""
        text = """- Item 1
- Item 2
- Item 3
"""
        elements = MarkdownParser().parse(text)
        lists = self._select(elements, MarkdownElementType.LIST)
        assert len(lists) == 1
        for item in ("Item 1", "Item 2", "Item 3"):
            assert item in lists[0].content

    def test_parse_ordered_lists(self):
        """Consecutive numbered items form a single list element."""
        text = """1. First
2. Second
3. Third
"""
        elements = MarkdownParser().parse(text)
        lists = self._select(elements, MarkdownElementType.LIST)
        assert len(lists) == 1
        for item in ("First", "Second", "Third"):
            assert item in lists[0].content

    def test_parse_blockquotes(self):
        """Consecutive ``>`` lines form a single blockquote element."""
        text = """> This is a quote.
> It spans multiple lines.
> And continues here.
"""
        elements = MarkdownParser().parse(text)
        quotes = self._select(elements, MarkdownElementType.BLOCKQUOTE)
        assert len(quotes) == 1
        assert "This is a quote." in quotes[0].content
        assert "It spans multiple lines." in quotes[0].content

    def test_parse_paragraphs(self):
        """Blank-line-separated text is split into separate paragraphs.

        The fixture needs blank lines between paragraphs: without them the
        four lines are one uninterrupted run of text and the expected count of
        three paragraphs (the second spanning two lines) cannot be produced.
        """
        text = """This is the first paragraph.

This is the second paragraph.
It has multiple lines.

This is the third.
"""
        elements = MarkdownParser().parse(text)
        paragraphs = self._select(elements, MarkdownElementType.PARAGRAPH)
        assert len(paragraphs) == 3
        assert "first paragraph" in paragraphs[0].content
        assert "second paragraph" in paragraphs[1].content

    def test_parse_mixed_content(self):
        """A document mixing all element kinds yields the expected count of each."""
        text = """# Documentation
## Introduction
This is an introduction paragraph.
## Code Example
```python
def example():
return 42
```
## Data Table
| Column A | Column B |
|----------|----------|
| Value 1 | Value 2 |
## List
- Item A
- Item B
> Note: This is important.
"""
        elements = MarkdownParser().parse(text)
        assert len(self._select(elements, MarkdownElementType.HEADER)) == 5
        assert len(self._select(elements, MarkdownElementType.CODE_BLOCK)) == 1
        assert len(self._select(elements, MarkdownElementType.TABLE)) == 1
        assert len(self._select(elements, MarkdownElementType.LIST)) == 1
        assert len(self._select(elements, MarkdownElementType.BLOCKQUOTE)) == 1
        assert len(self._select(elements, MarkdownElementType.PARAGRAPH)) >= 1

    def test_code_blocks_not_parsed_as_other_elements(self):
        """Header/list/table syntax inside a fence stays part of the code block."""
        text = """```markdown
# This is not a header
- This is not a list
| This is not a table |
```
"""
        elements = MarkdownParser().parse(text)
        assert len(self._select(elements, MarkdownElementType.HEADER)) == 0
        assert len(self._select(elements, MarkdownElementType.LIST)) == 0
        assert len(self._select(elements, MarkdownElementType.TABLE)) == 0
        assert len(self._select(elements, MarkdownElementType.CODE_BLOCK)) == 1
class TestMarkdownChunker:
    """Tests for MarkdownChunker.chunk() across element kinds and size limits."""

    @staticmethod
    def _chunks_of_type(chunks, element_type):
        """Return the chunks whose ``element_type`` equals ``element_type``."""
        return [c for c in chunks if c.element_type == element_type]

    def test_chunk_simple_document(self):
        """A small document yields MarkdownChunk objects with doc-prefixed ids."""
        text = """# Title
This is a paragraph.
## Section
Another paragraph.
"""
        produced = MarkdownChunker().chunk(text, "test_doc")
        assert len(produced) >= 2
        assert all(isinstance(piece, MarkdownChunk) for piece in produced)
        assert all(piece.chunk_id.startswith("test_doc") for piece in produced)

    def test_chunk_preserves_header_context(self):
        """Content under a nested header carries its ancestor headers as context."""
        text = """# Main Title
## Section A
Content under section A.
### Subsection A1
Content under subsection A1.
"""
        produced = MarkdownChunker(include_header_context=True).chunk(text, "test")
        # NOTE(review): if no chunk satisfies both conditions below, the loop
        # asserts nothing and the test passes vacuously - confirm this is intended.
        for piece in produced:
            if "Subsection A1" in piece.content:
                continue
            if "subsection a1" not in piece.content.lower():
                continue
            assert "Main Title" in piece.header_context
            assert "Section A" in piece.header_context

    def test_chunk_code_blocks_preserved(self):
        """A code block below the size limit stays in one chunk with its language."""
        text = """```python
def function_one():
pass
def function_two():
pass
```
"""
        chunker = MarkdownChunker(max_chunk_size=2000, preserve_code_blocks=True)
        code_chunks = self._chunks_of_type(
            chunker.chunk(text, "test"), MarkdownElementType.CODE_BLOCK
        )
        assert len(code_chunks) == 1
        only_chunk = code_chunks[0]
        assert "def function_one" in only_chunk.content
        assert "def function_two" in only_chunk.content
        assert only_chunk.language == "python"

    def test_chunk_large_code_block_split(self):
        """An oversized code block is split; every part keeps language and fence."""
        code_content = "\n".join(
            "def function_{}(): pass".format(i) for i in range(100)
        )
        text = f"""```python\n{code_content}\n```"""
        chunker = MarkdownChunker(max_chunk_size=500, preserve_code_blocks=True)
        code_chunks = self._chunks_of_type(
            chunker.chunk(text, "test"), MarkdownElementType.CODE_BLOCK
        )
        assert len(code_chunks) > 1
        for part in code_chunks:
            assert part.language == "python"
            assert "```python" in part.content
            assert "```" in part.content

    def test_chunk_table_preserved(self):
        """A table below the size limit stays in one chunk."""
        text = """| Name | Age |
|------|-----|
| Alice | 30 |
| Bob | 25 |
"""
        chunker = MarkdownChunker(max_chunk_size=2000, preserve_tables=True)
        table_chunks = self._chunks_of_type(
            chunker.chunk(text, "test"), MarkdownElementType.TABLE
        )
        assert len(table_chunks) == 1
        assert "Alice" in table_chunks[0].content
        assert "Bob" in table_chunks[0].content

    def test_chunk_large_table_split(self):
        """An oversized table is split and each part repeats the header rows."""
        body_rows = "\n".join(f"| Name{i} | {i * 10} |" for i in range(50))
        text = "| Name | Age |\n|------|-----|\n" + body_rows
        chunker = MarkdownChunker(max_chunk_size=200, preserve_tables=True)
        table_chunks = self._chunks_of_type(
            chunker.chunk(text, "test"), MarkdownElementType.TABLE
        )
        assert len(table_chunks) > 1
        for part in table_chunks:
            assert "| Name | Age |" in part.content
            assert "|------|-----|" in part.content

    def test_chunk_list_preserved(self):
        """A list below the size limit stays in one chunk."""
        text = """- Item 1
- Item 2
- Item 3
- Item 4
- Item 5
"""
        chunker = MarkdownChunker(max_chunk_size=2000, preserve_lists=True)
        list_chunks = self._chunks_of_type(
            chunker.chunk(text, "test"), MarkdownElementType.LIST
        )
        assert len(list_chunks) == 1
        assert "Item 1" in list_chunks[0].content
        assert "Item 5" in list_chunks[0].content

    def test_chunk_empty_document(self):
        """An empty string produces no chunks."""
        produced = MarkdownChunker().chunk("", "test")
        assert len(produced) == 0

    def test_chunk_only_headers(self):
        """A document made solely of headers produces no chunks."""
        text = """# Title 1
## Title 2
### Title 3
"""
        produced = MarkdownChunker().chunk(text, "test")
        assert len(produced) == 0
class TestChunkMarkdownFunction:
    """Tests for the module-level ``chunk_markdown`` convenience function."""

    def test_basic_chunking(self):
        """Each returned chunk exposes the four expected keys."""
        text = """# Title
Content paragraph.
```python
code = "here"
```
"""
        produced = chunk_markdown(text, "doc1")
        assert len(produced) >= 1
        for required_key in ("chunk_id", "content", "element_type", "header_context"):
            assert all(required_key in piece for piece in produced)

    def test_custom_parameters(self):
        """The function accepts every tuning keyword and still returns chunks."""
        options = dict(
            max_chunk_size=500,
            min_chunk_size=50,
            preserve_code_blocks=False,
            preserve_tables=False,
            preserve_lists=False,
            include_header_context=False,
        )
        produced = chunk_markdown("A" * 2000, "doc1", **options)
        assert len(produced) >= 1
class TestMarkdownElement:
    """Tests for MarkdownElement construction and ``to_dict`` serialization."""

    def test_to_dict(self):
        """A header element serializes its type, content, level and line span."""
        header = MarkdownElement(
            type=MarkdownElementType.HEADER,
            content="Test Header",
            level=2,
            line_start=10,
            line_end=10,
            metadata={"level": 2},
        )
        serialized = header.to_dict()
        assert serialized["type"] == "header"
        assert serialized["content"] == "Test Header"
        assert serialized["level"] == 2
        assert serialized["line_start"] == 10
        assert serialized["line_end"] == 10

    def test_code_block_with_language(self):
        """A code-block element serializes its language."""
        code_block = MarkdownElement(
            type=MarkdownElementType.CODE_BLOCK,
            content="print('hello')",
            language="python",
            line_start=5,
            line_end=7,
        )
        serialized = code_block.to_dict()
        assert serialized["type"] == "code_block"
        assert serialized["language"] == "python"

    def test_table_with_metadata(self):
        """A table element serializes its metadata unchanged."""
        table = MarkdownElement(
            type=MarkdownElementType.TABLE,
            content="| A | B |\n|---|---|\n| 1 | 2 |",
            line_start=1,
            line_end=3,
            metadata={"headers": ["A", "B"], "row_count": 1},
        )
        serialized = table.to_dict()
        assert serialized["type"] == "table"
        assert serialized["metadata"]["headers"] == ["A", "B"]
        assert serialized["metadata"]["row_count"] == 1
class TestMarkdownChunk:
    """Tests for MarkdownChunk construction and ``to_dict`` serialization."""

    def test_to_dict(self):
        """A paragraph chunk serializes id, content, type, context and metadata."""
        paragraph_chunk = MarkdownChunk(
            chunk_id="doc_chunk_0",
            content="Test content",
            element_type=MarkdownElementType.PARAGRAPH,
            header_context=["Main Title", "Section"],
            metadata={"key": "value"},
        )
        serialized = paragraph_chunk.to_dict()
        assert serialized["chunk_id"] == "doc_chunk_0"
        assert serialized["content"] == "Test content"
        assert serialized["element_type"] == "paragraph"
        assert serialized["header_context"] == ["Main Title", "Section"]
        assert serialized["metadata"]["key"] == "value"

    def test_with_language(self):
        """A code chunk serializes its language."""
        code_chunk = MarkdownChunk(
            chunk_id="code_0",
            content="```python\nprint('hi')\n```",
            element_type=MarkdownElementType.CODE_BLOCK,
            header_context=[],
            language="python",
        )
        assert code_chunk.to_dict()["language"] == "python"
class TestMarkdownElementType:
    """Tests for the MarkdownElementType enum."""

    def test_all_types_exist(self):
        """Each expected type is present as a member name or as a member value."""
        expected_types = [
            "header",
            "paragraph",
            "code_block",
            "inline_code",
            "table",
            "list",
            "blockquote",
            "horizontal_rule",
            "image",
            "link",
            "text",
        ]
        for type_name in expected_types:
            found_by_name = hasattr(MarkdownElementType, type_name.upper())
            # Only scan member values when the upper-cased name is absent.
            assert found_by_name or any(
                member.value == type_name for member in MarkdownElementType
            )

View File

@ -0,0 +1,443 @@
"""
Unit tests for MetadataAutoInferenceService.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import dataclass
from app.services.metadata_auto_inference_service import (
AutoInferenceResult,
InferenceFieldContext,
MetadataAutoInferenceService,
)
@dataclass
class MockFieldDefinition:
"""Mock field definition for testing"""
field_key: str
label: str
type: str
required: bool
options: list[str] | None = None
class TestInferenceFieldContext:
    """Tests for the InferenceFieldContext dataclass."""

    def test_creation(self):
        """All constructor arguments are stored on the instance."""
        grade_options = ["初一", "初二", "初三"]
        context = InferenceFieldContext(
            field_key="grade",
            label="年级",
            type="enum",
            required=True,
            options=grade_options,
        )
        assert context.field_key == "grade"
        assert context.label == "年级"
        assert context.type == "enum"
        assert context.required is True
        assert context.options == ["初一", "初二", "初三"]

    def test_default_values(self):
        """``options`` and ``description`` are None when not supplied."""
        context = InferenceFieldContext(
            field_key="test",
            label="Test",
            type="text",
            required=False,
        )
        assert context.options is None
        assert context.description is None
class TestAutoInferenceResult:
    """Tests for the AutoInferenceResult dataclass."""

    def test_success_result(self):
        """A successful result has no error message and exposes its metadata."""
        outcome = AutoInferenceResult(
            inferred_metadata={"grade": "初一", "subject": "数学"},
            confidence_scores={"grade": 0.95, "subject": 0.85},
            raw_response='{"inferred_metadata": {...}}',
            success=True,
        )
        assert outcome.success is True
        assert outcome.error_message is None
        assert outcome.inferred_metadata["grade"] == "初一"

    def test_failure_result(self):
        """A failed result keeps the supplied error message."""
        outcome = AutoInferenceResult(
            inferred_metadata={},
            confidence_scores={},
            raw_response="",
            success=False,
            error_message="JSON parse error",
        )
        assert outcome.success is False
        assert outcome.error_message == "JSON parse error"
class TestMetadataAutoInferenceService:
    """Tests for MetadataAutoInferenceService's prompt-building, parsing and validation helpers."""

    @pytest.fixture
    def mock_session(self):
        """Provide an AsyncMock used as the session passed to the service."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_session):
        """Build the service under test around the mock session."""
        return MetadataAutoInferenceService(mock_session)

    @staticmethod
    def _grade_context(required=True):
        """Build the "grade" enum field context shared by several tests."""
        return InferenceFieldContext(
            field_key="grade",
            label="年级",
            type="enum",
            required=required,
            options=["初一", "初二", "初三"],
        )

    @staticmethod
    def _subject_context():
        """Build the required "subject" enum field context."""
        return InferenceFieldContext(
            field_key="subject",
            label="学科",
            type="enum",
            required=True,
            options=["语文", "数学", "英语"],
        )

    def test_build_field_contexts(self, service):
        """Field definitions are converted to contexts, keeping order and options."""
        definitions = [
            MockFieldDefinition(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            MockFieldDefinition(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]
        built = service._build_field_contexts(definitions)
        assert len(built) == 2
        assert built[0].field_key == "grade"
        assert built[0].options == ["初一", "初二", "初三"]
        assert built[1].field_key == "subject"

    def test_build_user_prompt(self, service):
        """The prompt mentions the field label, its options and the content."""
        rendered = service._build_user_prompt(
            content="这是一道初一数学题",
            field_contexts=[self._grade_context()],
        )
        assert "年级" in rendered
        assert "初一, 初二, 初三" in rendered
        assert "这是一道初一数学题" in rendered

    def test_build_user_prompt_with_existing_metadata(self, service):
        """An existing value for a field is echoed in the prompt."""
        rendered = service._build_user_prompt(
            content="这是一道数学题",
            field_contexts=[self._grade_context()],
            existing_metadata={"grade": "初二"},
        )
        assert "已有值: 初二" in rendered

    def test_extract_json_from_plain_json(self, service):
        """A plain JSON string is returned unchanged."""
        json_str = '{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}'
        assert service._extract_json(json_str) == json_str

    def test_extract_json_from_markdown(self, service):
        """JSON wrapped in a fenced ``json`` block is extracted."""
        markdown = """Here is the result:
```json
{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}
```
"""
        extracted = service._extract_json(markdown)
        assert "inferred_metadata" in extracted
        assert "grade" in extracted

    def test_parse_llm_response_valid(self, service):
        """Valid metadata and confidence scores are carried into the result."""
        payload = {
            "inferred_metadata": {
                "grade": "初一",
                "subject": "数学",
            },
            "confidence_scores": {
                "grade": 0.95,
                "subject": 0.85,
            }
        }
        contexts = [self._grade_context(), self._subject_context()]
        parsed = service._parse_llm_response(json.dumps(payload), contexts)
        assert parsed.success is True
        assert parsed.inferred_metadata["grade"] == "初一"
        assert parsed.inferred_metadata["subject"] == "数学"
        assert parsed.confidence_scores["grade"] == 0.95

    def test_parse_llm_response_invalid_option(self, service):
        """A value outside the enum options is dropped, yet parsing succeeds."""
        payload = {
            "inferred_metadata": {
                "grade": "高一",  # Not in options
            },
            "confidence_scores": {
                "grade": 0.90,
            }
        }
        parsed = service._parse_llm_response(json.dumps(payload), [self._grade_context()])
        assert parsed.success is True
        assert "grade" not in parsed.inferred_metadata

    def test_parse_llm_response_invalid_json(self, service):
        """Non-JSON text produces a failed result with a parse-error message."""
        contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="text",
                required=False,
            ),
        ]
        parsed = service._parse_llm_response("This is not valid JSON", contexts)
        assert parsed.success is False
        assert "JSON parse error" in parsed.error_message

    def test_validate_field_value_text(self, service):
        """A text value passes validation unchanged."""
        title_ctx = InferenceFieldContext(
            field_key="title",
            label="标题",
            type="text",
            required=False,
        )
        assert service._validate_field_value(title_ctx, "测试标题") == "测试标题"

    def test_validate_field_value_number(self, service):
        """Numbers pass, numeric strings are converted, other strings give None."""
        count_ctx = InferenceFieldContext(
            field_key="count",
            label="数量",
            type="number",
            required=False,
        )
        validate = service._validate_field_value
        assert validate(count_ctx, 42) == 42
        assert validate(count_ctx, "3.14") == 3.14
        assert validate(count_ctx, "invalid") is None

    def test_validate_field_value_boolean(self, service):
        """Booleans, "true"/"false" strings and 1 map to the expected bool."""
        active_ctx = InferenceFieldContext(
            field_key="active",
            label="是否激活",
            type="boolean",
            required=False,
        )
        validate = service._validate_field_value
        assert validate(active_ctx, True) is True
        assert validate(active_ctx, "true") is True
        assert validate(active_ctx, "false") is False
        assert validate(active_ctx, 1) is True

    def test_validate_field_value_enum(self, service):
        """A listed option passes; an unlisted one gives None."""
        grade_ctx = self._grade_context(required=False)
        assert service._validate_field_value(grade_ctx, "初一") == "初一"
        assert service._validate_field_value(grade_ctx, "高一") is None

    def test_validate_field_value_array_enum(self, service):
        """Lists are filtered to known options; a lone string becomes a one-item list."""
        tags_ctx = InferenceFieldContext(
            field_key="tags",
            label="标签",
            type="array_enum",
            required=False,
            options=["重点", "难点", "易错"],
        )
        validate = service._validate_field_value
        assert validate(tags_ctx, ["重点", "难点"]) == ["重点", "难点"]
        assert validate(tags_ctx, ["重点", "不存在"]) == ["重点"]
        assert validate(tags_ctx, "重点") == ["重点"]
class TestIntegrationScenarios:
    """Scenario tests that run ``_parse_llm_response`` on multi-field responses."""

    @pytest.fixture
    def mock_session(self):
        """Provide an AsyncMock used as the session passed to the service."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_session):
        """Build the service under test around the mock session."""
        return MetadataAutoInferenceService(mock_session)

    def test_education_scenario(self, service):
        """Grade, subject and type are all inferred when each value is a valid option."""
        llm_payload = {
            "inferred_metadata": {
                "grade": "初二",
                "subject": "物理",
                "type": "痛点",
            },
            "confidence_scores": {
                "grade": 0.95,
                "subject": 0.90,
                "type": 0.85,
            }
        }
        contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语", "物理", "化学"],
            ),
            InferenceFieldContext(
                field_key="type",
                label="类型",
                type="enum",
                required=False,
                options=["痛点", "重点", "难点"],
            ),
        ]
        parsed = service._parse_llm_response(json.dumps(llm_payload), contexts)
        assert parsed.success is True
        assert parsed.inferred_metadata == {
            "grade": "初二",
            "subject": "物理",
            "type": "痛点",
        }
        assert parsed.confidence_scores["grade"] == 0.95

    def test_partial_inference(self, service):
        """A field absent from the response is absent from the result; parsing still succeeds."""
        llm_payload = {
            "inferred_metadata": {
                "grade": "初一",
            },
            "confidence_scores": {
                "grade": 0.90,
            }
        }
        contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]
        parsed = service._parse_llm_response(json.dumps(llm_payload), contexts)
        assert parsed.success is True
        assert "grade" in parsed.inferred_metadata
        assert "subject" not in parsed.inferred_metadata
# Allow running this test module directly: hand the file to pytest with the -v flag.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@ -0,0 +1,881 @@
"""
Mid Platform Dialogue Integration Test.
中台联调界面对话过程集成测试脚本
测试重点:
1. 意图置信度参数
2. 执行模式通用API vs ReAct模式
3. ReAct模式下的工具调用工具名称入参返回结果
4. 知识库查询是否命中入参
5. 各部分耗时
6. 提示词模板使用情况
"""
import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient
from app.main import app
from app.models.mid.schemas import (
DialogueRequest,
DialogueResponse,
ExecutionMode,
FeatureFlags,
HistoryMessage,
Segment,
TraceInfo,
ToolCallTrace,
ToolCallStatus,
ToolType,
IntentHintOutput,
HighRiskCheckResult,
HighRiskScenario,
)
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class TimingRecord:
    """Timing measurement for one named stage of a dialogue request."""

    # Name of the measured stage.
    stage: str
    # Timestamps bounding the stage, as supplied by the caller.
    start_time: float
    end_time: float
    # Stage duration in milliseconds (stored as given, not derived here).
    duration_ms: int

    def to_dict(self) -> dict:
        """Return the reportable view: stage name and duration only."""
        return dict(stage=self.stage, duration_ms=self.duration_ms)
@dataclass
class DialogueTestResult:
    """Everything captured about a single dialogue round-trip."""

    # --- request / response basics ---
    request_id: str
    user_message: str
    response_text: str = ""
    # --- timing ---
    timing_records: list[TimingRecord] = field(default_factory=list)
    total_duration_ms: int = 0
    # --- routing and intent ---
    execution_mode: ExecutionMode = ExecutionMode.AGENT
    intent: str | None = None
    confidence: float | None = None
    # --- ReAct loop and tool usage ---
    react_iterations: int = 0
    tools_used: list[str] = field(default_factory=list)
    tool_calls: list[dict] = field(default_factory=list)
    # --- knowledge-base lookup ---
    kb_tool_called: bool = False
    kb_hit: bool = False
    kb_query: str | None = None
    kb_filter: dict | None = None
    kb_hits_count: int = 0
    # --- prompt template ---
    prompt_template_used: str | None = None
    prompt_template_scene: str | None = None
    # --- guardrail / fallback ---
    guardrail_triggered: bool = False
    fallback_reason_code: str | None = None
    # Dump of the trace object from the response, when one was present.
    raw_trace: dict | None = None

    def to_summary(self) -> dict:
        """Return a compact dict view of this result.

        The user message is cut to its first 100 characters, timing records
        are reduced through ``TimingRecord.to_dict`` and tool calls are
        reported as a count only.
        """
        message_preview = self.user_message[:100]
        timing_breakdown = [record.to_dict() for record in self.timing_records]
        tool_call_total = len(self.tool_calls)

        summary = {
            "request_id": self.request_id,
            "user_message": message_preview,
            "execution_mode": self.execution_mode.value,
            "intent": self.intent,
            "confidence": self.confidence,
            "total_duration_ms": self.total_duration_ms,
            "timing_breakdown": timing_breakdown,
            "react_iterations": self.react_iterations,
            "tools_used": self.tools_used,
            "tool_calls_count": tool_call_total,
            "kb_tool_called": self.kb_tool_called,
            "kb_hit": self.kb_hit,
            "kb_hits_count": self.kb_hits_count,
            "prompt_template_used": self.prompt_template_used,
            "guardrail_triggered": self.guardrail_triggered,
            "fallback_reason_code": self.fallback_reason_code,
        }
        return summary
class DialogueIntegrationTester:
    """Client-side harness that sends dialogue requests and records diagnostics.

    Each instance generates its own random session id and user id, so every
    request sent through one tester shares the same session and user.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        tenant_id: str = "test_tenant",
        api_key: str | None = None,
    ):
        """Store connection settings and create fresh session/user identifiers.

        Args:
            base_url: Root URL of the service under test.
            tenant_id: Value sent in the ``X-Tenant-Id`` header.
            api_key: Optional value sent in the ``X-API-Key`` header.
        """
        self.base_url = base_url
        self.tenant_id = tenant_id
        self.api_key = api_key
        self.session_id = f"test_session_{uuid.uuid4().hex[:8]}"
        self.user_id = f"test_user_{uuid.uuid4().hex[:8]}"

    def _get_headers(self) -> dict:
        """Build the request headers; the API key header is added only when set."""
        headers = {
            "Content-Type": "application/json",
            "X-Tenant-Id": self.tenant_id,
        }
        if self.api_key:
            headers["X-API-Key"] = self.api_key
        return headers

    async def send_dialogue(
        self,
        user_message: str,
        history: list[dict] | None = None,
        scene: str | None = None,
        feature_flags: dict | None = None,
    ) -> DialogueTestResult:
        """Send one dialogue request and collect detailed diagnostics.

        Args:
            user_message: Text of the user turn.
            history: Earlier turns; an empty list is sent when omitted.
            scene: Optional scene identifier, sent only when truthy.
            feature_flags: Optional feature flags, sent only when truthy.

        Returns:
            A ``DialogueTestResult``.  Transport errors and non-200 responses
            do not raise; they are reported through ``response_text`` and
            ``fallback_reason_code`` (``"request_exception"`` or
            ``"http_error_<status>"``).  Every exit path fills in
            ``total_duration_ms`` and the timing records measured so far.
        """
        request_id = str(uuid.uuid4())
        result = DialogueTestResult(
            request_id=request_id,
            user_message=user_message,
            response_text="",
        )
        overall_start = time.time()

        request_body = {
            "session_id": self.session_id,
            "user_id": self.user_id,
            "user_message": user_message,
            "history": history or [],
        }
        if scene:
            request_body["scene"] = scene
        if feature_flags:
            request_body["feature_flags"] = feature_flags

        # Attach the list up front so early returns still expose the stages
        # that were measured before the failure.
        timing_records: list[TimingRecord] = []
        result.timing_records = timing_records

        try:
            async with AsyncClient(base_url=self.base_url, timeout=120.0) as client:
                request_start = time.time()
                response = await client.post(
                    "/mid/dialogue/respond",
                    json=request_body,
                    headers=self._get_headers(),
                )
                request_end = time.time()
                timing_records.append(TimingRecord(
                    stage="http_request",
                    start_time=request_start,
                    end_time=request_end,
                    duration_ms=int((request_end - request_start) * 1000),
                ))
                if response.status_code != 200:
                    result.response_text = f"Error: {response.status_code}"
                    result.fallback_reason_code = f"http_error_{response.status_code}"
                    # Previously this path returned with a zero duration.
                    result.total_duration_ms = int((time.time() - overall_start) * 1000)
                    return result
                response_data = response.json()
        except Exception as e:
            result.response_text = f"Exception: {e}"
            result.fallback_reason_code = "request_exception"
            result.total_duration_ms = int((time.time() - overall_start) * 1000)
            return result

        overall_end = time.time()
        result.total_duration_ms = int((overall_end - overall_start) * 1000)

        try:
            dialogue_response = DialogueResponse(**response_data)
            if dialogue_response.segments:
                result.response_text = "\n".join(s.text for s in dialogue_response.segments)
            trace = dialogue_response.trace
            result.raw_trace = trace.model_dump() if trace else None
            if trace:
                self._apply_trace(result, trace, overall_start, overall_end)
        except Exception as e:
            # Best effort: keep the raw payload visible when it cannot be parsed.
            logger.error("Failed to parse response: %s", e)
            result.response_text = str(response_data)
        return result

    def _apply_trace(
        self,
        result: DialogueTestResult,
        trace: Any,
        overall_start: float,
        overall_end: float,
    ) -> None:
        """Copy the diagnostic fields of a response trace onto ``result``."""
        result.execution_mode = trace.mode
        result.intent = trace.intent
        # NOTE(review): the trace model may not define ``confidence``; getattr
        # keeps the previous ``None`` default in that case - confirm the field.
        result.confidence = getattr(trace, "confidence", None)
        result.react_iterations = trace.react_iterations or 0
        result.tools_used = trace.tools_used or []
        result.kb_tool_called = trace.kb_tool_called or False
        result.kb_hit = trace.kb_hit or False
        result.guardrail_triggered = trace.guardrail_triggered or False
        result.fallback_reason_code = trace.fallback_reason_code

        if trace.tool_calls:
            result.tool_calls = [tc.model_dump() for tc in trace.tool_calls]
            for tc in trace.tool_calls:
                if tc.tool_name != "kb_search_dynamic":
                    continue
                result.kb_tool_called = True
                if tc.arguments:
                    result.kb_query = tc.arguments.get("query")
                    result.kb_filter = tc.arguments.get("context")
                if tc.result and isinstance(tc.result, dict):
                    # ``hits`` may be present but null; treat that as no hits.
                    result.kb_hits_count = len(tc.result.get("hits") or [])
                    result.kb_hit = result.kb_hits_count > 0

        if trace.duration_ms:
            # Duration is the server-reported value; the timestamps are the
            # client-side bounds of the whole call.
            result.timing_records.append(TimingRecord(
                stage="server_processing",
                start_time=overall_start,
                end_time=overall_end,
                duration_ms=trace.duration_ms,
            ))
        if trace.scene:
            result.prompt_template_scene = trace.scene

    def print_result(self, result: DialogueTestResult):
        """Print a human-readable report of ``result`` to stdout."""
        separator = "=" * 80
        print("\n" + separator)
        print(f"[对话测试结果] Request ID: {result.request_id}")
        print(separator)
        print(f"\n[用户消息] {result.user_message}")

        # Only mark the reply as truncated when it actually was.
        reply_preview = result.response_text[:200]
        if len(result.response_text) > 200:
            reply_preview += "..."
        print(f"[回复内容] {reply_preview}")

        print(f"\n[执行模式] {result.execution_mode.value}")
        print(f"[意图识别] intent={result.intent}, confidence={result.confidence}")
        print(f"\n[耗时统计] 总耗时: {result.total_duration_ms}ms")
        for tr in result.timing_records:
            print(f" - {tr.stage}: {tr.duration_ms}ms")

        if result.execution_mode == ExecutionMode.AGENT:
            print("\n[ReAct模式]")
            print(f" - 迭代次数: {result.react_iterations}")
            print(f" - 使用的工具: {result.tools_used}")

        # Tool calls are reported whenever present, regardless of mode.
        if result.tool_calls:
            print("\n[工具调用详情]")
            for i, tc in enumerate(result.tool_calls, 1):
                print(f" [{i}] 工具: {tc.get('tool_name')}")
                print(f" 状态: {tc.get('status')}")
                print(f" 耗时: {tc.get('duration_ms')}ms")
                if tc.get('arguments'):
                    print(f" 入参: {json.dumps(tc.get('arguments'), ensure_ascii=False)[:200]}")
                if tc.get('result'):
                    result_str = str(tc.get('result'))[:300]
                    print(f" 结果: {result_str}")

        print("\n[知识库查询]")
        print(f" - 是否调用: {result.kb_tool_called}")
        print(f" - 是否命中: {result.kb_hit}")
        if result.kb_query:
            print(f" - 查询内容: {result.kb_query}")
        if result.kb_filter:
            print(f" - 过滤条件: {json.dumps(result.kb_filter, ensure_ascii=False)[:200]}")
        print(f" - 命中数量: {result.kb_hits_count}")

        print("\n[提示词模板]")
        print(f" - 场景: {result.prompt_template_scene}")
        print(f" - 使用模板: {result.prompt_template_used or '默认模板'}")

        print("\n[其他信息]")
        print(f" - 护栏触发: {result.guardrail_triggered}")
        print(f" - 降级原因: {result.fallback_reason_code or ''}")
        print("\n" + separator)
class TestMidDialogueIntegration:
    """End-to-end dialogue tests against a running mid-platform service.

    The tests talk to ``http://localhost:8000``.  When the service cannot be
    reached they are skipped instead of passing without testing anything.
    """

    @staticmethod
    def _skip_if_unreachable(result: DialogueTestResult) -> None:
        """Skip the current test when the request never reached the service.

        ``send_dialogue`` swallows transport errors and reports them through
        ``fallback_reason_code == "request_exception"``; without this guard the
        assertions in these tests would succeed with no server running.
        """
        if result.fallback_reason_code == "request_exception":
            pytest.skip(f"dialogue service unreachable: {result.response_text}")

    @pytest.fixture
    def tester(self):
        """Tester bound to the local service and the test tenant."""
        return DialogueIntegrationTester(
            base_url="http://localhost:8000",
            tenant_id="test_tenant",
        )

    @pytest.fixture
    def mock_llm_client(self):
        """Mock LLM client whose ``generate`` returns a fixed reply without tool calls."""
        mock = AsyncMock()
        mock.generate = AsyncMock(return_value=MagicMock(
            content="这是测试回复",
            has_tool_calls=False,
            tool_calls=[],
        ))
        return mock

    @pytest.fixture
    def mock_kb_tool(self):
        """Mock knowledge-base tool whose ``execute`` returns one successful hit."""
        mock = AsyncMock()
        mock.execute = AsyncMock(return_value=MagicMock(
            success=True,
            hits=[
                {"id": "1", "content": "测试知识库内容", "score": 0.9},
            ],
            applied_filter={"scene": "test"},
            missing_required_slots=[],
            fallback_reason_code=None,
            duration_ms=100,
            tool_trace=None,
        ))
        return mock

    @pytest.mark.asyncio
    async def test_simple_greeting(self, tester: DialogueIntegrationTester):
        """A plain greeting produces a result with a measured duration."""
        result = await tester.send_dialogue(
            user_message="你好",
        )
        tester.print_result(result)
        self._skip_if_unreachable(result)
        assert result.request_id is not None
        assert result.total_duration_ms > 0

    @pytest.mark.asyncio
    async def test_kb_query(self, tester: DialogueIntegrationTester):
        """A knowledge-base style question in the after-sale scene."""
        result = await tester.send_dialogue(
            user_message="退款流程是什么?",
            scene="after_sale",
        )
        tester.print_result(result)
        self._skip_if_unreachable(result)
        assert result.request_id is not None

    @pytest.mark.asyncio
    async def test_high_risk_scenario(self, tester: DialogueIntegrationTester):
        """A complaint message, used as the high-risk scenario probe."""
        result = await tester.send_dialogue(
            user_message="我要投诉你们的服务",
        )
        tester.print_result(result)
        self._skip_if_unreachable(result)
        assert result.request_id is not None

    @pytest.mark.asyncio
    async def test_transfer_request(self, tester: DialogueIntegrationTester):
        """An explicit request to be transferred to a human agent."""
        result = await tester.send_dialogue(
            user_message="帮我转人工客服",
        )
        tester.print_result(result)
        self._skip_if_unreachable(result)
        assert result.request_id is not None

    @pytest.mark.asyncio
    async def test_with_history(self, tester: DialogueIntegrationTester):
        """A follow-up question sent together with two earlier turns."""
        result = await tester.send_dialogue(
            user_message="那退款要多久呢?",
            history=[
                {"role": "user", "content": "我想退款"},
                {"role": "assistant", "content": "好的,请问您要退款的订单号是多少?"},
            ],
        )
        tester.print_result(result)
        self._skip_if_unreachable(result)
        assert result.request_id is not None
class TestDialogueWithMock:
    """Dialogue tests built on mocks and schema objects rather than a live service."""

    @pytest.fixture
    def mock_app(self):
        """Build a bare FastAPI application that mounts only the mid dialogue router."""
        from fastapi import FastAPI
        from app.api.mid.dialogue import router

        application = FastAPI()
        application.include_router(router)
        return application

    @pytest.fixture
    def client(self, mock_app):
        """Synchronous test client wrapping the mock application."""
        return TestClient(mock_app)

    @pytest.fixture
    def mock_session(self):
        """Fake async database session whose query result contains no rows."""
        empty_result = MagicMock()
        empty_result.scalars.return_value.all.return_value = []
        session = AsyncMock()
        session.execute.return_value = empty_result
        return session

    @pytest.fixture
    def mock_llm(self):
        """Fake LLM response carrying fixed text and no tool calls."""
        llm_response = MagicMock()
        llm_response.content = "这是测试回复内容"
        llm_response.has_tool_calls = False
        llm_response.tool_calls = []
        return llm_response

    def test_dialogue_request_structure(self, client: TestClient):
        """Print a sample request payload and check two of its fields."""
        payload = {
            "session_id": "test_session_001",
            "user_id": "test_user_001",
            "user_message": "你好",
            "history": [],
            "scene": "open_consult",
        }
        rendered = json.dumps(payload, ensure_ascii=False, indent=2)
        print("\n[测试请求结构]")
        print(f"Request Body: {rendered}")
        assert payload["session_id"] == "test_session_001"
        assert payload["user_message"] == "你好"

    def test_trace_info_structure(self):
        """Build a ``TraceInfo`` with one KB tool call, print it and check key fields."""
        kb_call = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=150,
            status=ToolCallStatus.OK,
            arguments={"query": "测试查询", "scene": "test"},
            result={"hits": [{"content": "测试内容"}]},
        )
        trace = TraceInfo(
            mode=ExecutionMode.AGENT,
            intent="greeting",
            request_id=str(uuid.uuid4()),
            generation_id=str(uuid.uuid4()),
            kb_tool_called=True,
            kb_hit=True,
            react_iterations=2,
            tools_used=["kb_search_dynamic"],
            tool_calls=[kb_call],
        )

        print("\n[TraceInfo 结构测试]")
        print(f"Mode: {trace.mode.value}")
        print(f"Intent: {trace.intent}")
        print(f"KB Tool Called: {trace.kb_tool_called}")
        print(f"KB Hit: {trace.kb_hit}")
        print(f"React Iterations: {trace.react_iterations}")
        print(f"Tools Used: {trace.tools_used}")
        if trace.tool_calls:
            print("\n[Tool Calls]")
            for call in trace.tool_calls:
                print(f" - Tool: {call.tool_name}")
                print(f" Status: {call.status.value}")
                print(f" Duration: {call.duration_ms}ms")
                if call.arguments:
                    print(f" Arguments: {json.dumps(call.arguments, ensure_ascii=False)}")

        assert trace.mode == ExecutionMode.AGENT
        assert trace.kb_tool_called is True
        assert len(trace.tool_calls) == 1

    def test_intent_hint_output_structure(self):
        """Build an ``IntentHintOutput``, print it and check intent, confidence and mode."""
        hint = IntentHintOutput(
            intent="refund",
            confidence=0.85,
            response_type="flow",
            suggested_mode=ExecutionMode.MICRO_FLOW,
            target_flow_id="flow_refund_001",
            high_risk_detected=False,
            duration_ms=50,
        )
        suggested_label = hint.suggested_mode.value if hint.suggested_mode else None

        print("\n[IntentHintOutput 结构测试]")
        print(f"Intent: {hint.intent}")
        print(f"Confidence: {hint.confidence}")
        print(f"Response Type: {hint.response_type}")
        print(f"Suggested Mode: {suggested_label}")
        print(f"High Risk Detected: {hint.high_risk_detected}")
        print(f"Duration: {hint.duration_ms}ms")

        assert hint.intent == "refund"
        assert hint.confidence == 0.85
        assert hint.suggested_mode == ExecutionMode.MICRO_FLOW

    def test_high_risk_check_result_structure(self):
        """Build a ``HighRiskCheckResult``, print it and check match flag and scenario."""
        check = HighRiskCheckResult(
            matched=True,
            risk_scenario=HighRiskScenario.REFUND,
            confidence=0.95,
            recommended_mode=ExecutionMode.MICRO_FLOW,
            rule_id="rule_refund_001",
            reason="检测到退款关键词",
            duration_ms=30,
        )
        scenario_label = check.risk_scenario.value if check.risk_scenario else None
        mode_label = check.recommended_mode.value if check.recommended_mode else None

        print("\n[HighRiskCheckResult 结构测试]")
        print(f"Matched: {check.matched}")
        print(f"Risk Scenario: {scenario_label}")
        print(f"Confidence: {check.confidence}")
        print(f"Recommended Mode: {mode_label}")
        print(f"Rule ID: {check.rule_id}")
        print(f"Duration: {check.duration_ms}ms")

        assert check.matched is True
        assert check.risk_scenario == HighRiskScenario.REFUND
class TestPromptTemplateUsage:
    """Tests for prompt template variable substitution."""

    def test_template_resolution(self):
        """Declared variables replace their ``{{key}}`` placeholders."""
        from app.services.prompt.variable_resolver import VariableResolver

        template = "你好,{{user_name}}!我是{{bot_name}},很高兴为您服务。"
        variables = [
            {"key": "user_name", "value": "张三"},
            {"key": "bot_name", "value": "智能客服"},
        ]

        rendered = VariableResolver().resolve(template, variables)

        print("\n[模板解析测试]")
        print(f"原始模板: {template}")
        print(f"变量: {json.dumps(variables, ensure_ascii=False)}")
        print(f"解析结果: {rendered}")
        assert rendered == "你好,张三!我是智能客服,很高兴为您服务。"

    def test_template_with_extra_context(self):
        """Values supplied as extra context appear in the resolved text."""
        from app.services.prompt.variable_resolver import VariableResolver

        template = "当前场景:{{scene}},用户问题:{{query}}"
        extra_context = {
            "scene": "售后服务",
            "query": "退款流程",
        }

        rendered = VariableResolver().resolve(template, [], extra_context)

        print("\n[带上下文的模板解析测试]")
        print(f"原始模板: {template}")
        print(f"额外上下文: {json.dumps(extra_context, ensure_ascii=False)}")
        print(f"解析结果: {rendered}")
        for expected_fragment in ("售后服务", "退款流程"):
            assert expected_fragment in rendered
class TestToolCallRecording:
    """Tests for recording tool calls as ``ToolCallTrace`` objects."""

    def test_tool_call_trace_creation(self):
        """A successful KB search trace keeps its arguments and result."""
        call_arguments = {
            "query": "退款流程是什么",
            "scene": "after_sale",
            "context": {"product_type": "vip"},
        }
        call_result = {
            "success": True,
            "hits": [
                {"id": "1", "content": "退款流程说明...", "score": 0.95},
                {"id": "2", "content": "退款注意事项...", "score": 0.88},
                {"id": "3", "content": "退款时效说明...", "score": 0.82},
            ],
            "applied_filter": {"product_type": "vip"},
        }
        trace = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=150,
            status=ToolCallStatus.OK,
            args_digest="query=退款流程",
            result_digest="hits=3",
            arguments=call_arguments,
            result=call_result,
        )

        print("\n[工具调用追踪测试]")
        print(f"Tool Name: {trace.tool_name}")
        print(f"Tool Type: {trace.tool_type.value}")
        print(f"Status: {trace.status.value}")
        print(f"Duration: {trace.duration_ms}ms")

        print("\n[入参详情]")
        if trace.arguments:
            for arg_name, arg_value in trace.arguments.items():
                print(f" - {arg_name}: {arg_value}")

        print("\n[返回结果]")
        if trace.result and isinstance(trace.result, dict):
            print(f" - success: {trace.result.get('success')}")
            print(f" - hits count: {len(trace.result.get('hits', []))}")
            # Only the first two hits are shown.
            for position, hit in enumerate(trace.result.get('hits', [])[:2], 1):
                print(f" - hit[{position}]: score={hit.get('score')}, content={hit.get('content')[:30]}...")

        assert trace.tool_name == "kb_search_dynamic"
        assert trace.status == ToolCallStatus.OK
        assert trace.arguments is not None
        assert trace.result is not None

    def test_tool_call_timeout_trace(self):
        """A timed-out call keeps its TIMEOUT status and error code."""
        timed_out = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=2000,
            status=ToolCallStatus.TIMEOUT,
            error_code="TOOL_TIMEOUT",
            arguments={"query": "测试查询"},
        )

        print("\n[工具调用超时追踪测试]")
        print(f"Tool Name: {timed_out.tool_name}")
        print(f"Status: {timed_out.status.value}")
        print(f"Error Code: {timed_out.error_code}")
        print(f"Duration: {timed_out.duration_ms}ms")

        assert timed_out.status == ToolCallStatus.TIMEOUT
        assert timed_out.error_code == "TOOL_TIMEOUT"
class TestExecutionModeRouting:
    """Tests for how requests are routed to an execution mode."""

    def test_policy_router_decision(self):
        """``PolicyRouter.route`` picks the expected mode for each scenario."""
        from app.services.mid.policy_router import PolicyRouter, IntentMatch

        router = PolicyRouter()
        # (case name, user message, session mode, expected execution mode)
        scenarios = [
            ("正常对话 -> Agent模式", "你好,请问有什么可以帮助我的?", "BOT_ACTIVE", ExecutionMode.AGENT),
            ("高风险退款 -> Micro Flow模式", "我要退款", "BOT_ACTIVE", ExecutionMode.MICRO_FLOW),
            ("转人工请求 -> Transfer模式", "帮我转人工", "BOT_ACTIVE", ExecutionMode.TRANSFER),
            ("人工模式 -> Transfer模式", "你好", "HUMAN_ACTIVE", ExecutionMode.TRANSFER),
        ]

        print("\n[策略路由器决策测试]")
        for case_name, user_message, session_mode, expected_mode in scenarios:
            decision = router.route(
                user_message=user_message,
                session_mode=session_mode,
            )
            print(f"\n测试用例: {case_name}")
            print(f" 用户消息: {user_message}")
            print(f" 会话模式: {session_mode}")
            print(f" 期望模式: {expected_mode.value}")
            print(f" 实际模式: {decision.mode.value}")
            assert decision.mode == expected_mode, f"模式不匹配: {decision.mode} != {expected_mode}"
class TestKBSearchDynamic:
    """Tests for the result object of the dynamic knowledge-base search tool."""

    def test_kb_search_result_structure(self):
        """A successful result exposes its hits and the applied filter."""
        from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult

        first_hit = {
            "id": "chunk_001",
            "content": "退款流程1. 登录账户 2. 进入订单页面 3. 点击退款按钮...",
            "score": 0.92,
            "metadata": {"kb_id": "kb_001", "doc_id": "doc_001"},
        }
        second_hit = {
            "id": "chunk_002",
            "content": "退款时效一般3-5个工作日到账...",
            "score": 0.85,
            "metadata": {"kb_id": "kb_001", "doc_id": "doc_002"},
        }
        search_result = KbSearchDynamicResult(
            success=True,
            hits=[first_hit, second_hit],
            applied_filter={"scene": "after_sale", "product_type": "vip"},
            missing_required_slots=[],
            filter_debug={"source": "slot_state"},
            filter_sources={"scene": "slot", "product_type": "context"},
            duration_ms=120,
        )

        print("\n[知识库检索结果结构测试]")
        print(f"Success: {search_result.success}")
        print(f"Hits Count: {len(search_result.hits)}")
        print(f"Applied Filter: {json.dumps(search_result.applied_filter, ensure_ascii=False)}")
        print(f"Filter Sources: {json.dumps(search_result.filter_sources, ensure_ascii=False)}")
        print(f"Duration: {search_result.duration_ms}ms")
        print("\n[命中详情]")
        for position, hit in enumerate(search_result.hits, 1):
            print(f" [{position}] ID: {hit['id']}")
            print(f" Score: {hit['score']}")
            print(f" Content: {hit['content'][:50]}...")

        assert search_result.success is True
        assert len(search_result.hits) == 2
        assert search_result.applied_filter is not None

    def test_kb_search_missing_slots(self):
        """A failed result reports the missing required slot and its reason code."""
        from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult

        missing_order_id = {
            "field_key": "order_id",
            "label": "订单号",
            "reason": "必填字段缺失",
            "ask_back_prompt": "请提供您的订单号",
        }
        search_result = KbSearchDynamicResult(
            success=False,
            hits=[],
            applied_filter={},
            missing_required_slots=[missing_order_id],
            filter_debug={"source": "builder"},
            fallback_reason_code="MISSING_REQUIRED_SLOTS",
            duration_ms=50,
        )

        print("\n[知识库检索缺失槽位测试]")
        print(f"Success: {search_result.success}")
        print(f"Fallback Reason: {search_result.fallback_reason_code}")
        print(f"Missing Slots: {json.dumps(search_result.missing_required_slots, ensure_ascii=False, indent=2)}")

        assert search_result.success is False
        assert search_result.fallback_reason_code == "MISSING_REQUIRED_SLOTS"
        assert len(search_result.missing_required_slots) == 1
class TestTimingBreakdown:
    """Tests for the per-stage timing breakdown report."""

    def test_timing_breakdown_structure(self):
        """Print each stage's share of the total and check the summed duration."""
        # (stage, start_time, end_time, duration_ms)
        stage_specs = [
            ("intent_matching", 0, 0.05, 50),
            ("high_risk_check", 0.05, 0.08, 30),
            ("kb_search", 0.08, 0.2, 120),
            ("llm_generation", 0.2, 1.5, 1300),
            ("output_guardrail", 1.5, 1.55, 50),
            ("response_formatting", 1.55, 1.6, 50),
        ]
        timings = [TimingRecord(*spec) for spec in stage_specs]
        total = sum(record.duration_ms for record in timings)
        rule = "-" * 45

        print("\n[耗时分解测试]")
        print(f"{'阶段':<25} {'耗时(ms)':<10} {'占比':<10}")
        print(rule)
        for record in timings:
            if total > 0:
                percentage = record.duration_ms / total * 100
            else:
                percentage = 0
            print(f"{record.stage:<25} {record.duration_ms:<10} {percentage:.1f}%")
        print(rule)
        print(f"{'总计':<25} {total:<10} {'100.0%':<10}")

        assert total == 1600
def run_manual_test():
    """Command-line entry point for exercising the dialogue endpoint by hand.

    Parses ``--url``, ``--tenant``, ``--api-key``, ``--message``, ``--scene``
    and ``--interactive``.  Without ``--interactive`` a single message is sent
    and its report printed.  In interactive mode messages are read from stdin
    until ``quit`` is entered, Ctrl-C is pressed, or stdin is closed; blank
    messages are ignored rather than sent.
    """
    import argparse

    parser = argparse.ArgumentParser(description="中台对话集成测试")
    parser.add_argument("--url", default="http://localhost:8000", help="服务地址")
    parser.add_argument("--tenant", default="test_tenant", help="租户ID")
    parser.add_argument("--api-key", default=None, help="API Key")
    parser.add_argument("--message", default="你好", help="测试消息")
    parser.add_argument("--scene", default=None, help="场景标识")
    parser.add_argument("--interactive", action="store_true", help="交互模式")
    args = parser.parse_args()

    tester = DialogueIntegrationTester(
        base_url=args.url,
        tenant_id=args.tenant,
        api_key=args.api_key,
    )

    if args.interactive:
        print("\n=== 中台对话集成测试 - 交互模式 ===")
        print("输入 'quit' 退出\n")
        while True:
            try:
                message = input("请输入消息: ").strip()
                if message.lower() == "quit":
                    break
                if not message:
                    # Nothing typed: prompt again instead of sending an empty turn.
                    continue
                scene = input("请输入场景(可选,直接回车跳过): ").strip() or None
                result = asyncio.run(tester.send_dialogue(
                    user_message=message,
                    scene=scene,
                ))
                tester.print_result(result)
            except (KeyboardInterrupt, EOFError):
                # EOFError is raised by input() when stdin is closed or piped
                # input runs out; treat it like Ctrl-C and leave the loop.
                break
    else:
        result = asyncio.run(tester.send_dialogue(
            user_message=args.message,
            scene=args.scene,
        ))
        tester.print_result(result)
# Running the file directly starts the manual command-line tester rather than pytest.
if __name__ == "__main__":
    run_manual_test()

View File

@ -0,0 +1,541 @@
"""
Unit tests for Retrieval Strategy Service.
[AC-AISVC-RES-01~15] Tests for strategy management, switching, validation, and rollback.
"""
import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime
from app.schemas.retrieval_strategy import (
ReactMode,
RolloutConfig,
RolloutMode,
StrategyType,
RetrievalStrategyStatus,
RetrievalStrategySwitchRequest,
RetrievalStrategyValidationRequest,
ValidationResult,
)
from app.services.retrieval.strategy_service import (
RetrievalStrategyService,
StrategyState,
get_strategy_service,
)
from app.services.retrieval.strategy_audit import (
StrategyAuditService,
get_audit_service,
)
from app.services.retrieval.strategy_metrics import (
StrategyMetricsService,
get_metrics_service,
)
class TestRetrievalStrategySchemas:
    """[AC-AISVC-RES-01~15] Schema-level tests for the retrieval strategy models."""

    def test_rollout_config_off_mode(self):
        """[AC-AISVC-RES-03] OFF mode leaves percentage and allowlist unset."""
        off_config = RolloutConfig(mode=RolloutMode.OFF)
        assert off_config.mode == RolloutMode.OFF
        assert off_config.percentage is None
        assert off_config.allowlist is None

    def test_rollout_config_percentage_mode(self):
        """[AC-AISVC-RES-03] PERCENTAGE mode keeps the supplied percentage."""
        pct_config = RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0)
        assert pct_config.mode == RolloutMode.PERCENTAGE
        assert pct_config.percentage == 50.0

    def test_rollout_config_percentage_mode_missing_value(self):
        """[AC-AISVC-RES-03] PERCENTAGE mode without a percentage is rejected."""
        with pytest.raises(ValueError, match="percentage is required"):
            RolloutConfig(mode=RolloutMode.PERCENTAGE)

    def test_rollout_config_allowlist_mode(self):
        """[AC-AISVC-RES-03] ALLOWLIST mode keeps the supplied tenant list."""
        tenants = ["tenant1", "tenant2"]
        list_config = RolloutConfig(mode=RolloutMode.ALLOWLIST, allowlist=tenants)
        assert list_config.mode == RolloutMode.ALLOWLIST
        assert list_config.allowlist == ["tenant1", "tenant2"]

    def test_rollout_config_allowlist_mode_missing_value(self):
        """[AC-AISVC-RES-03] ALLOWLIST mode without an allowlist is rejected."""
        with pytest.raises(ValueError, match="allowlist is required"):
            RolloutConfig(mode=RolloutMode.ALLOWLIST)

    def test_retrieval_strategy_status(self):
        """[AC-AISVC-RES-01] A status object exposes strategy, react mode and rollout."""
        status = RetrievalStrategyStatus(
            active_strategy=StrategyType.DEFAULT,
            react_mode=ReactMode.NON_REACT,
            rollout=RolloutConfig(mode=RolloutMode.OFF),
        )
        assert status.active_strategy == StrategyType.DEFAULT
        assert status.react_mode == ReactMode.NON_REACT
        assert status.rollout.mode == RolloutMode.OFF

    def test_switch_request_minimal(self):
        """[AC-AISVC-RES-02] Only the target strategy is required; the rest default to None."""
        minimal = RetrievalStrategySwitchRequest(target_strategy=StrategyType.ENHANCED)
        assert minimal.target_strategy == StrategyType.ENHANCED
        for optional_value in (minimal.react_mode, minimal.rollout, minimal.reason):
            assert optional_value is None

    def test_switch_request_full(self):
        """[AC-AISVC-RES-02,03,05] Every optional field is accepted and preserved."""
        full = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
            rollout=RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=30.0),
            reason="Testing enhanced strategy",
        )
        assert full.target_strategy == StrategyType.ENHANCED
        assert full.react_mode == ReactMode.REACT
        assert full.rollout.percentage == 30.0
        assert full.reason == "Testing enhanced strategy"
class TestRetrievalStrategyService:
    """[AC-AISVC-RES-01~15] Behavioural tests for ``RetrievalStrategyService``."""

    @pytest.fixture
    def service(self):
        """Provide a newly constructed service for every test."""
        return RetrievalStrategyService()

    def test_get_current_status_default(self, service):
        """[AC-AISVC-RES-01] A new service reports default strategy, non-react, rollout off."""
        current = service.get_current_status()
        assert current.active_strategy == StrategyType.DEFAULT
        assert current.react_mode == ReactMode.NON_REACT
        assert current.rollout.mode == RolloutMode.OFF

    def test_switch_strategy_to_enhanced(self, service):
        """[AC-AISVC-RES-02] Switching reports the previous and the new state."""
        outcome = service.switch_strategy(
            RetrievalStrategySwitchRequest(
                target_strategy=StrategyType.ENHANCED,
                react_mode=ReactMode.REACT,
            )
        )
        assert outcome.previous.active_strategy == StrategyType.DEFAULT
        assert outcome.current.active_strategy == StrategyType.ENHANCED
        assert outcome.current.react_mode == ReactMode.REACT

    def test_switch_strategy_with_grayscale_percentage(self, service):
        """[AC-AISVC-RES-03] A percentage rollout is carried into the new state."""
        switch_request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            rollout=RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0),
        )
        outcome = service.switch_strategy(switch_request)
        assert outcome.current.active_strategy == StrategyType.ENHANCED
        assert outcome.current.rollout.mode == RolloutMode.PERCENTAGE
        assert outcome.current.rollout.percentage == 50.0

    def test_switch_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] An allowlist rollout is carried into the new state."""
        allowlist_rollout = RolloutConfig(
            mode=RolloutMode.ALLOWLIST,
            allowlist=["tenant_a", "tenant_b"],
        )
        switch_request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            rollout=allowlist_rollout,
        )
        outcome = service.switch_strategy(switch_request)
        assert outcome.current.rollout.mode == RolloutMode.ALLOWLIST
        assert "tenant_a" in outcome.current.rollout.allowlist

    def test_rollback_strategy(self, service):
        """[AC-AISVC-RES-07] Rolling back after a switch restores the earlier state."""
        service.switch_strategy(
            RetrievalStrategySwitchRequest(
                target_strategy=StrategyType.ENHANCED,
                react_mode=ReactMode.REACT,
            )
        )
        rollback = service.rollback_strategy()
        assert rollback.rollback_to.active_strategy == StrategyType.DEFAULT
        assert rollback.rollback_to.react_mode == ReactMode.NON_REACT

    def test_rollback_without_previous_returns_default(self, service):
        """[AC-AISVC-RES-07] Rolling back with no prior switch yields the default strategy."""
        rollback = service.rollback_strategy()
        assert rollback.rollback_to.active_strategy == StrategyType.DEFAULT

    def test_should_use_enhanced_strategy_default(self, service):
        """[AC-AISVC-RES-01] With the default strategy no tenant gets enhanced."""
        uses_enhanced = service.should_use_enhanced_strategy("tenant_a")
        assert uses_enhanced is False

    def test_should_use_enhanced_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] Only allowlisted tenants get the enhanced strategy."""
        switch_request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            rollout=RolloutConfig(
                mode=RolloutMode.ALLOWLIST,
                allowlist=["tenant_a"],
            ),
        )
        service.switch_strategy(switch_request)
        assert service.should_use_enhanced_strategy("tenant_a") is True
        assert service.should_use_enhanced_strategy("tenant_b") is False

    def test_get_route_mode_react(self, service):
        """[AC-AISVC-RES-10] With react mode active the route is "react"."""
        service.switch_strategy(
            RetrievalStrategySwitchRequest(
                target_strategy=StrategyType.ENHANCED,
                react_mode=ReactMode.REACT,
            )
        )
        chosen_route = service.get_route_mode("test query")
        assert chosen_route == "react"

    def test_get_route_mode_direct(self, service):
        """[AC-AISVC-RES-09] With non-react mode active the route is "direct"."""
        service.switch_strategy(
            RetrievalStrategySwitchRequest(
                target_strategy=StrategyType.DEFAULT,
                react_mode=ReactMode.NON_REACT,
            )
        )
        chosen_route = service.get_route_mode("test query")
        assert chosen_route == "direct"

    def test_get_route_mode_auto_short_query(self, service):
        """[AC-AISVC-RES-12] Auto routing sends a short, confident query to "direct"."""
        # Clear the react mode on the internal state before auto routing.
        service._state.react_mode = None
        chosen_route = service._auto_route("短问题", confidence=0.8)
        assert chosen_route == "direct"

    def test_get_route_mode_auto_multiple_conditions(self, service):
        """[AC-AISVC-RES-13] Auto routing sends a multi-condition query to "react"."""
        chosen_route = service._auto_route("查询订单状态和物流信息")
        assert chosen_route == "react"

    def test_get_route_mode_auto_low_confidence(self, service):
        """[AC-AISVC-RES-13] Auto routing sends a low-confidence query to "react"."""
        chosen_route = service._auto_route("test query", confidence=0.3)
        assert chosen_route == "react"

    def test_get_switch_history(self, service):
        """One switch produces one history entry naming the target strategy."""
        service.switch_strategy(
            RetrievalStrategySwitchRequest(
                target_strategy=StrategyType.ENHANCED,
                reason="Testing",
            )
        )
        entries = service.get_switch_history()
        assert len(entries) == 1
        assert entries[0]["to_strategy"] == "enhanced"
class TestRetrievalStrategyValidation:
    """[AC-AISVC-RES-04,06,08] Tests for strategy validation checks."""

    @pytest.fixture
    def service(self):
        """Provide a newly constructed service for every test."""
        return RetrievalStrategyService()

    def test_validate_default_strategy(self, service):
        """[AC-AISVC-RES-06] Validating the default strategy passes."""
        outcome = service.validate_strategy(
            RetrievalStrategyValidationRequest(strategy=StrategyType.DEFAULT)
        )
        assert outcome.passed is True

    def test_validate_enhanced_strategy(self, service):
        """[AC-AISVC-RES-06] Validating the enhanced strategy yields a verdict and results."""
        outcome = service.validate_strategy(
            RetrievalStrategyValidationRequest(strategy=StrategyType.ENHANCED)
        )
        assert isinstance(outcome.passed, bool)
        assert len(outcome.results) > 0

    def test_validate_specific_checks(self, service):
        """[AC-AISVC-RES-06] Explicitly requested checks appear in the results."""
        validation_request = RetrievalStrategyValidationRequest(
            strategy=StrategyType.ENHANCED,
            checks=["metadata_consistency", "performance_budget"],
        )
        outcome = service.validate_strategy(validation_request)
        executed_checks = [item.check for item in outcome.results]
        assert "metadata_consistency" in executed_checks
        assert "performance_budget" in executed_checks

    def test_check_metadata_consistency(self, service):
        """[AC-AISVC-RES-04] The metadata consistency check passes for the default strategy."""
        check_result = service._check_metadata_consistency(StrategyType.DEFAULT)
        assert check_result.check == "metadata_consistency"
        assert check_result.passed is True

    def test_check_rrf_config(self, service):
        """[AC-AISVC-RES-02] The RRF config check returns a named boolean verdict."""
        check_result = service._check_rrf_config(StrategyType.DEFAULT)
        assert check_result.check == "rrf_config"
        assert isinstance(check_result.passed, bool)

    def test_check_performance_budget(self, service):
        """[AC-AISVC-RES-08] The performance budget check returns a named boolean verdict."""
        check_result = service._check_performance_budget(
            StrategyType.ENHANCED,
            ReactMode.REACT,
        )
        assert check_result.check == "performance_budget"
        assert isinstance(check_result.passed, bool)
class TestStrategyAuditService:
    """[AC-AISVC-RES-07] Tests for the strategy audit log service."""

    @pytest.fixture
    def audit_service(self):
        """Provide an audit service constructed with ``max_entries=100``."""
        return StrategyAuditService(max_entries=100)

    def test_log_switch_operation(self, audit_service):
        """[AC-AISVC-RES-07] A logged switch is retrievable with its strategies."""
        audit_service.log(
            operation="switch",
            previous_strategy="default",
            new_strategy="enhanced",
            reason="Testing",
            operator="admin",
        )
        logged = audit_service.get_audit_log()
        assert len(logged) == 1
        only_entry = logged[0]
        assert only_entry.operation == "switch"
        assert only_entry.previous_strategy == "default"
        assert only_entry.new_strategy == "enhanced"

    def test_log_rollback_operation(self, audit_service):
        """[AC-AISVC-RES-07] A logged rollback is found when filtering by operation."""
        audit_service.log_rollback(
            previous_strategy="enhanced",
            new_strategy="default",
            reason="Performance issue",
            operator="admin",
        )
        rollbacks = audit_service.get_audit_log(operation="rollback")
        assert len(rollbacks) == 1
        assert rollbacks[0].operation == "rollback"

    def test_log_validation_operation(self, audit_service):
        """[AC-AISVC-RES-06] A logged validation is found under the "validate" operation."""
        audit_service.log_validation(
            strategy="enhanced",
            checks=["metadata_consistency"],
            passed=True,
        )
        validations = audit_service.get_audit_log(operation="validate")
        assert len(validations) == 1
        assert validations[0].operation == "validate"

    def test_get_audit_log_with_limit(self, audit_service):
        """``limit`` caps the number of returned entries."""
        for index in range(10):
            audit_service.log(operation="switch", new_strategy=f"strategy_{index}")
        limited = audit_service.get_audit_log(limit=5)
        assert len(limited) == 5

    def test_get_audit_stats(self, audit_service):
        """Statistics report the total and the per-operation counts."""
        audit_service.log(operation="switch", new_strategy="enhanced")
        audit_service.log(operation="rollback", new_strategy="default")
        stats = audit_service.get_audit_stats()
        assert stats["total_entries"] == 2
        per_operation = stats["operation_counts"]
        assert per_operation["switch"] == 1
        assert per_operation["rollback"] == 1

    def test_clear_audit_log(self, audit_service):
        """Clearing returns the number of removed entries and empties the log."""
        audit_service.log(operation="switch", new_strategy="enhanced")
        assert len(audit_service.get_audit_log()) == 1
        removed = audit_service.clear_audit_log()
        assert removed == 1
        assert len(audit_service.get_audit_log()) == 0
class TestStrategyMetricsService:
    """[AC-AISVC-RES-03,08] Covers request recording, aggregation, threshold checks and reset."""

    @pytest.fixture
    def metrics_service(self):
        """Provide a new StrategyMetricsService for every test."""
        return StrategyMetricsService()

    def test_record_request(self, metrics_service):
        """[AC-AISVC-RES-08] One successful request shows up in the totals and the average latency."""
        metrics_service.record_request(latency_ms=100.0, success=True, route_mode="direct")

        snapshot = metrics_service.get_metrics()
        assert snapshot.total_requests == 1
        assert snapshot.successful_requests == 1
        assert snapshot.avg_latency_ms == 100.0

    def test_record_failed_request(self, metrics_service):
        """[AC-AISVC-RES-08] A request recorded with success=False counts as failed."""
        metrics_service.record_request(latency_ms=50.0, success=False)

        snapshot = metrics_service.get_metrics()
        assert snapshot.failed_requests == 1

    def test_record_fallback(self, metrics_service):
        """[AC-AISVC-RES-08] A request recorded with fallback=True increments fallback_count."""
        metrics_service.record_request(latency_ms=100.0, success=True, fallback=True)

        snapshot = metrics_service.get_metrics()
        assert snapshot.fallback_count == 1

    def test_record_route_metrics(self, metrics_service):
        """[AC-AISVC-RES-08] Every route mode that was recorded appears in the route metrics."""
        metrics_service.record_request(latency_ms=100.0, success=True, route_mode="react")
        metrics_service.record_request(latency_ms=50.0, success=True, route_mode="direct")

        by_route = metrics_service.get_route_metrics()
        for mode in ("react", "direct"):
            assert mode in by_route

    def test_get_all_metrics(self, metrics_service):
        """get_all_metrics is keyed by strategy value and contains both DEFAULT and ENHANCED."""
        metrics_service.set_current_strategy(StrategyType.ENHANCED, ReactMode.REACT)
        metrics_service.record_request(latency_ms=100.0, success=True)

        by_strategy = metrics_service.get_all_metrics()
        for strategy in (StrategyType.DEFAULT, StrategyType.ENHANCED):
            assert strategy.value in by_strategy

    def test_get_performance_summary(self, metrics_service):
        """[AC-AISVC-RES-08] The summary aggregates counts and success rate over three requests."""
        # (latency_ms, success) for each recorded request, in recording order.
        samples = ((100.0, True), (200.0, True), (50.0, False))
        for latency, succeeded in samples:
            metrics_service.record_request(latency_ms=latency, success=succeeded)

        summary = metrics_service.get_performance_summary()
        assert summary["total_requests"] == 3
        assert summary["successful_requests"] == 2
        assert summary["failed_requests"] == 1
        # Two successes out of three requests, compared with 1% relative tolerance.
        assert summary["success_rate"] == pytest.approx(0.6667, rel=0.01)

    def test_check_performance_threshold_ok(self, metrics_service):
        """[AC-AISVC-RES-08] A single fast, successful request satisfies every threshold flag."""
        metrics_service.record_request(latency_ms=100.0, success=True)

        verdict = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )

        for flag in ("latency_ok", "error_rate_ok", "overall_ok"):
            assert verdict[flag] is True

    def test_check_performance_threshold_exceeded(self, metrics_service):
        """[AC-AISVC-RES-08] A slow request plus a failure trips at least one threshold flag."""
        metrics_service.record_request(latency_ms=6000.0, success=True)
        metrics_service.record_request(latency_ms=100.0, success=False)

        verdict = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )

        # Either flag being False is sufficient; which one trips is not pinned here.
        assert any(verdict[flag] is False for flag in ("latency_ok", "error_rate_ok"))

    def test_reset_metrics(self, metrics_service):
        """After reset_metrics the request total is back to zero."""
        metrics_service.record_request(latency_ms=100.0, success=True)

        metrics_service.reset_metrics()

        snapshot = metrics_service.get_metrics()
        assert snapshot.total_requests == 0
class TestSingletonInstances:
    """Tests for the module-level singleton getters.

    Each test clears the cached instance held in the owning module so the
    getter has to build a fresh one, then checks that two consecutive calls
    hand back the very same object.

    The cache is cleared with ``monkeypatch.setattr`` rather than a plain
    assignment so that pytest restores the previous value on teardown and the
    reset does not leak into other tests. ``setattr`` raises if the attribute
    does not exist, so a renamed cache variable is still detected.
    """

    def test_get_strategy_service_singleton(self, monkeypatch):
        """get_strategy_service returns one shared, non-None instance."""
        import app.services.retrieval.strategy_service as module

        monkeypatch.setattr(module, "_strategy_service", None)

        first = get_strategy_service()
        second = get_strategy_service()

        # Guard against the vacuous case where both calls return None.
        assert first is not None
        assert first is second

    def test_get_audit_service_singleton(self, monkeypatch):
        """get_audit_service returns one shared, non-None instance."""
        import app.services.retrieval.strategy_audit as module

        monkeypatch.setattr(module, "_audit_service", None)

        first = get_audit_service()
        second = get_audit_service()

        # Guard against the vacuous case where both calls return None.
        assert first is not None
        assert first is second

    def test_get_metrics_service_singleton(self, monkeypatch):
        """get_metrics_service returns one shared, non-None instance."""
        import app.services.retrieval.strategy_metrics as module

        monkeypatch.setattr(module, "_metrics_service", None)

        first = get_metrics_service()
        second = get_metrics_service()

        # Guard against the vacuous case where both calls return None.
        assert first is not None
        assert first is second

View File

@ -0,0 +1,353 @@
"""
Integration tests for Retrieval Strategy API.
[AC-AISVC-RES-01~15] End-to-end tests for strategy management endpoints.
Tests the full API flow:
- GET /strategy/retrieval/current
- POST /strategy/retrieval/switch
- POST /strategy/retrieval/validate
- POST /strategy/retrieval/rollback
"""
import json
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from fastapi.testclient import TestClient
from app.main import app
@pytest.fixture(autouse=True)
def mock_api_key_service():
    """Replace the API key service with a mock that accepts "test-api-key".

    While the fixture is active, ``app.services.api_key.get_api_key_service``
    is patched to return the mock, whose ``validate_key_with_context`` always
    yields a result with ``ok=True`` and ``reason=None``. The mock itself is
    yielded to the test.
    """
    # Validation result handed back for every key check.
    accepted = MagicMock(ok=True, reason=None)

    fake_service = MagicMock()
    fake_service._initialized = True
    fake_service._keys_cache = {"test-api-key": MagicMock()}
    fake_service.validate_key_with_context.return_value = accepted

    with patch("app.services.api_key.get_api_key_service", return_value=fake_service):
        yield fake_service
@pytest.fixture(autouse=True)
def reset_strategy_state():
    """Give every test a strategy router with default configuration.

    The same reset runs before the test and again after it: the registered
    router is cleared, a router is obtained from ``get_strategy_router`` and
    its config is replaced with a default ``RetrievalStrategyConfig``.
    """
    from app.services.retrieval.strategy.strategy_router import get_strategy_router, set_strategy_router
    from app.services.retrieval.strategy.config import RetrievalStrategyConfig

    def _restore_defaults():
        # Drop whatever router is registered, then configure the one returned next.
        set_strategy_router(None)
        get_strategy_router().update_config(RetrievalStrategyConfig())

    _restore_defaults()
    yield
    _restore_defaults()
class TestRetrievalStrategyAPIIntegration:
    """[AC-AISVC-RES-01~15] Integration tests for the retrieval strategy API.

    Every test sends a request through a FastAPI ``TestClient`` with tenant and
    API-key headers and checks the status code and response body.
    """

    @pytest.fixture
    def client(self):
        """TestClient bound to the application under test."""
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        """Headers carrying a tenant id and the API key accepted by the mocked key service."""
        return {
            "X-Tenant-Id": "test@ash@2026",
            "X-API-Key": "test-api-key",
        }

    def test_get_current_strategy(self, client, valid_headers):
        """[AC-AISVC-RES-01] GET /current should return strategy status."""
        response = client.get(
            "/strategy/retrieval/current",
            headers=valid_headers,
        )

        assert response.status_code == 200
        data = response.json()
        assert "active_strategy" in data
        assert "grayscale" in data
        assert data["active_strategy"] in ["default", "enhanced"]

    def test_switch_strategy_to_enhanced(self, client, valid_headers):
        """[AC-AISVC-RES-02] POST /switch should switch to enhanced strategy."""
        response = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
            },
            headers=valid_headers,
        )

        assert response.status_code == 200
        data = response.json()
        assert "success" in data
        assert data["success"] is True
        assert data["current_strategy"] == "enhanced"

    def test_switch_strategy_with_grayscale_percentage(self, client, valid_headers):
        """[AC-AISVC-RES-03] POST /switch should accept grayscale percentage."""
        response = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
                "grayscale": {
                    "enabled": True,
                    "percentage": 30.0,
                },
            },
            headers=valid_headers,
        )

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True

    def test_switch_strategy_with_allowlist(self, client, valid_headers):
        """[AC-AISVC-RES-03] POST /switch should accept allowlist."""
        response = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
                "grayscale": {
                    "enabled": True,
                    "allowlist": ["tenant_a", "tenant_b"],
                },
            },
            headers=valid_headers,
        )

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True

    def test_validate_strategy(self, client, valid_headers):
        """[AC-AISVC-RES-06] POST /validate should validate strategy."""
        response = client.post(
            "/strategy/retrieval/validate",
            json={
                "active_strategy": "enhanced",
            },
            headers=valid_headers,
        )

        assert response.status_code == 200
        data = response.json()
        assert "valid" in data
        assert "errors" in data
        assert isinstance(data["valid"], bool)

    def test_validate_default_strategy(self, client, valid_headers):
        """[AC-AISVC-RES-06] Default strategy should pass validation."""
        response = client.post(
            "/strategy/retrieval/validate",
            json={
                "active_strategy": "default",
            },
            headers=valid_headers,
        )

        assert response.status_code == 200
        data = response.json()
        assert data["valid"] is True

    def test_rollback_strategy(self, client, valid_headers):
        """[AC-AISVC-RES-07] POST /rollback should rollback to default."""
        # Arrange: move away from the default first. The switch is asserted so
        # that the rollback check below cannot pass merely because the switch
        # never took effect and the strategy was "default" all along.
        switch = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
            },
            headers=valid_headers,
        )
        assert switch.status_code == 200
        assert switch.json()["current_strategy"] == "enhanced"

        response = client.post(
            "/strategy/retrieval/rollback",
            headers=valid_headers,
        )

        assert response.status_code == 200
        data = response.json()
        assert "success" in data
        assert data["current_strategy"] == "default"
class TestRetrievalStrategyAPIValidation:
    """[AC-AISVC-RES-03] Request-validation behaviour of POST /strategy/retrieval/switch."""

    @pytest.fixture
    def client(self):
        """TestClient bound to the application under test."""
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        """Headers carrying a tenant id and an API key."""
        headers = {"X-Tenant-Id": "test@ash@2026"}
        headers["X-API-Key"] = "test-api-key"
        return headers

    def test_switch_invalid_strategy(self, client, valid_headers):
        """[AC-AISVC-RES-03] An unknown strategy name is rejected with an error status."""
        payload = {"active_strategy": "invalid_strategy"}

        response = client.post(
            "/strategy/retrieval/switch",
            json=payload,
            headers=valid_headers,
        )

        # NOTE(review): 500 is accepted alongside 400/422 — confirm whether a
        # server error is really an intended outcome for bad client input.
        assert response.status_code in [400, 422, 500]

    def test_switch_percentage_out_of_range(self, client, valid_headers):
        """[AC-AISVC-RES-03] A grayscale percentage above 100 is rejected as a client error."""
        payload = {
            "active_strategy": "enhanced",
            "grayscale": {"percentage": 150.0},
        }

        response = client.post(
            "/strategy/retrieval/switch",
            json=payload,
            headers=valid_headers,
        )

        assert response.status_code in [400, 422]
class TestRetrievalStrategyAPIFlow:
    """[AC-AISVC-RES-01~15] End-to-end scenarios that chain several strategy endpoints."""

    @pytest.fixture
    def client(self):
        """TestClient bound to the application under test."""
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        """Headers carrying a tenant id and an API key."""
        headers = {"X-Tenant-Id": "test@ash@2026"}
        headers["X-API-Key"] = "test-api-key"
        return headers

    def test_complete_strategy_lifecycle(self, client, valid_headers):
        """[AC-AISVC-RES-01~07] Walk through current -> switch -> validate -> rollback -> current."""
        # Step 1: the starting strategy is the default one.
        before = client.get("/strategy/retrieval/current", headers=valid_headers)
        assert before.status_code == 200
        assert before.json()["active_strategy"] == "default"

        # Step 2: switch to enhanced with a 50% grayscale rollout.
        switch_payload = {
            "active_strategy": "enhanced",
            "grayscale": {"enabled": True, "percentage": 50.0},
        }
        switched = client.post(
            "/strategy/retrieval/switch",
            json=switch_payload,
            headers=valid_headers,
        )
        assert switched.status_code == 200
        assert switched.json()["current_strategy"] == "enhanced"

        # Step 3: validation only has to answer successfully.
        validated = client.post(
            "/strategy/retrieval/validate",
            json={"active_strategy": "enhanced"},
            headers=valid_headers,
        )
        assert validated.status_code == 200

        # Step 4: roll back.
        rolled_back = client.post("/strategy/retrieval/rollback", headers=valid_headers)
        assert rolled_back.status_code == 200
        assert rolled_back.json()["current_strategy"] == "default"

        # Step 5: the default strategy is active again.
        after = client.get("/strategy/retrieval/current", headers=valid_headers)
        assert after.status_code == 200
        assert after.json()["active_strategy"] == "default"
class TestRetrievalStrategyAPIMissingTenant:
    """Requests that carry an API key but no X-Tenant-Id header."""

    @pytest.fixture
    def client(self):
        """TestClient bound to the application under test."""
        return TestClient(app)

    @pytest.fixture
    def api_key_headers(self):
        """Headers with only the API key; the tenant header is deliberately absent."""
        headers = {}
        headers["X-API-Key"] = "test-api-key"
        return headers

    def test_current_without_tenant(self, client, api_key_headers):
        """GET /current without X-Tenant-Id is answered with 400."""
        response = client.get("/strategy/retrieval/current", headers=api_key_headers)

        assert response.status_code == 400

    def test_switch_without_tenant(self, client, api_key_headers):
        """POST /switch without X-Tenant-Id is answered with 400."""
        payload = {"active_strategy": "enhanced"}

        response = client.post(
            "/strategy/retrieval/switch",
            json=payload,
            headers=api_key_headers,
        )

        assert response.status_code == 400

View File

@ -0,0 +1,3 @@
# 7年级到课赠礼
这是一张7年级到课赠礼的图片,可能包含到课奖励、学习用品或相关礼品的信息。图片可能展示赠礼的实物照片、礼品包装或相关的宣传内容。

Binary file not shown.

After

Width:  |  Height:  |  Size: 466 KiB

View File

@ -0,0 +1,25 @@
# s班小学3年级到课赠礼
这张图片是一张教育类宣传海报,主题为“飞跃领航计划”,核心内容是展示“专属福利大礼包”的五天完课福利,同时包含人物展示与视觉装饰元素。整体风格活泼,色彩明快,以浅蓝 - 浅绿渐变为主背景,搭配卡通元素增强亲和力。
## 1. 文字内容(按视觉顺序提取)
- 标题区:左上角橙色标签“3阶”;主标题“飞跃领航计划”(“飞跃”为黑色粗体,“领航计划”为绿色粗体);下方绿色横幅“专属福利大礼包”。
- 人物区:三位人物(两位女性、一位男性)并排站立,每人旁有蓝色标签标注姓名:**张婷婷**、**褚佳麟**、**王亚男**;人物下方黄色标签“思维 | 人文 | 剑桥”。
- 福利列表区(白色背景框内,按天划分):
- 第一天/完课福利:①《思维模块知识导图册》+《三年级口算14000题》
- 第二天/完课福利:①精选双语动画电影20部(上10部) ②《应用题专项练习》+《小升初数学知识要点汇总》
- 第三天/完课福利:①《知识大盘点+易错大集合》 ②《世界上下五千年》音频资料
- 第四天/完课福利:①《考前高效培优知识梳理总复习》+《期末检测卷》2套
- 第五天/完课福利:①精选双语动画电影20部(下10部)
## 2. 人物与视觉元素
- 人物:三位形象正面、微笑,穿着职业装(女性为衬衫/西装,男性为浅灰西装),姿态自然,传递专业与亲和感。
- 颜色:背景为浅蓝 - 浅绿渐变;标题文字黑、绿对比;人物标签蓝色;福利列表背景白色,文字黑色;卡通元素(底部橙子、礼物盒、小图标)色彩鲜艳(橙、黄、紫等),增加活泼感。
- 布局:顶部为标题+人物展示区,中间为福利列表(分天排版,用圆点区分项目),底部为装饰性卡通元素(如橙子、礼物、电话/书本图标),整体结构清晰,信息层级分明。
## 3. 风格与意图
海报风格偏向教育类宣传的“活泼专业”:通过卡通元素降低距离感,通过人物展示增强信任感,通过分天福利列表清晰传递“完课奖励”的核心信息,目标受众可能是中小学生或家长,旨在推广“飞跃领航计划”课程。
## 4. 额外说明
- 视觉细节:底部卡通元素(如带笑脸的橙子、礼物盒)呼应“福利礼包”主题,增强趣味性;人物标签“思维 | 人文 | 剑桥”暗示课程涵盖多学科与国际化(剑桥)特色。
- 信息逻辑:福利列表从“知识导图+口算”到“动画电影+音频”,再到“复习资料+试卷”,覆盖“知识输入 - 兴趣培养 - 复习巩固”全流程,体现课程设计的完整性。

Binary file not shown.

After

Width:  |  Height:  |  Size: 482 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 618 KiB

View File

@ -0,0 +1,65 @@
# 一部7年级课表
这是一张初一训练营的课程表,主题为"双语素养+科学思维",旨在帮助学生初一/初一定三年,学习不犯难。课程表包含四位主讲老师的信息和详细的课程安排,还提供了App下载二维码。
## 1. 课程主题与目标
- 主标题:训练营 飞跃领航计划
- 核心主题:双语素养+科学思维
- 目标受众:初一/初一定三年,学习不犯难
- 下载方式:扫码下载App听课,提供安卓和苹果版本二维码
## 2. 主讲老师信息
1. **陈久杰** - 科学探究主讲
- 科学探究资深主讲老师
- 12年线上与线下教学经验
- 科学探究教学负责人
- 《中考物理满分冲刺》主编
- 人事人才网家庭教育高级指导师
- 北京大学心理发展研修班进修
2. **毕玉琦** - 卓越双语主讲
- 卓越双语学科教学总负责人和资深主讲老师
- 全国巡回英语讲座200余场
- 获新加坡南洋理工大学TESOL教学认证
- 一线英语教学10余年
- 培养中考英语高分学员10000+人
3. **阙红乾** - 思维逻辑主讲
- 中考满分或保送学员41位,清北学员12位
- 12年授课经验,线上学员过百万
- 海淀数学联赛顶尖师资
- 全体教师赛课一等奖
- S级主讲、教学大师奖、最佳教学效果奖、最受学生喜爱奖、最具课程吸引力奖、中考学员王者之师、金牌培训师
4. **张晓煜** - 人文博学主讲
- 资深人文主讲、主讲培训师
- 深耕语文一线教学14年
- 高级脑力潜能开发师
- 高级家庭教育指导师
- 高级思维导图教师
- 《初中语文考点一本通》《精准练语文》《21天预复习》等图书主编
## 3. 课程安排表
| 时间 | 科目 | 课程名称 | 课程内容 | 听课要求 |
|------|------|----------|----------|----------|
| 周四 19:55-20:55 | 规划讲座 | 学习规划讲座 | 【家长必听】初一全年学习规划及高效学习方法,帮助孩子领跑新征程! | 家长必听 |
| 周五 18:55-19:55 | 思维逻辑(数) | 双中会 | 同学们对于线段双动点或者角的双角平分线理解存在较大难度,往往读不懂题。本节课帮助孩子从底层知识理解该类题目,快速搞定双角平分线或者双中点问题。 | 亲子共听 |
| 周五 19:55-20:55 | 卓越双语(英) | 有根有据记词汇 | 驼哥用词根词缀法帮大家拆解单词,用底层逻辑法教大家辨别一词多义,让词汇记忆与积累不再发愁,词汇题目迎刃而解。 | 亲子共听 |
| 周六 18:55-19:55 | 科学探究(物) | 隐力剧场:无形托举者 | 大气压强是一种看不见摸不到的物理现象,本节课运用实验让同学们在动手当中观察现象总结结论,理解物理概念。 | 亲子共听 |
| 周六 19:55-20:55 | 卓越双语(英) | 有的放矢通关语法 | 对名词和数词的分类难以掌握、名词的变形和数词的用法感到困惑。本节课将通过"武术有师父"等驼哥专属口诀,帮你解决初一语法中的两大词法问题,夯实语法基础,提升10分。 | 亲子共听 |
| 周六 14:00-15:10 | 人文博学(语) | 从流水账到黄金屋 | 本节课通过讲解成长类作文结构,帮助学生掌握事件描写和开头结尾技巧,并分享高分素材,让学生写作文从流水账成长到高分作文。 | 亲子共听 |
| 周日 18:55-19:55 | 思维逻辑(数) | 拐点模型 | 本章是孩子初中第一次接触几何证明题+几何辅助线构造+几何模型构造,对几何核心思维培养至关重要!本节课带你搭建几何模型思维,感受模型秒解的魅力! | 亲子共听 |
| 周日 19:55-20:55 | 科学探究(物) | 解密声音密码 | 通过实验法讲解抽象物理概念,让学生可以通透的理解声音的三大要素并能掌握核心的考点解释生活现象。启发物理思维,培养物理兴趣,为初二学习物理做好铺垫。 | 亲子共听 |
## 4. 设计特点
- 色彩:浅蓝渐变背景,搭配橙色、绿色等明亮色彩,整体风格活泼专业
- 布局:左侧为老师介绍,右侧为课程表,信息分区明确
- 互动提供二维码下载App方便学生和家长听课
- 听课要求:区分"家长必听"和"亲子共听",体现不同课程的参与方式
## 5. 课程特色
- 覆盖多学科:科学探究、卓越双语、思维逻辑、人文博学
- 时间安排合理:工作日晚上和周末安排课程
- 实用性强:课程内容针对初中学习重点和难点
- 方法指导:不仅教授知识,还提供学习方法指导

View File

@ -0,0 +1,54 @@
# 以前缺少学习动力,高途的直播课让孩子学习态度积极;对课程老师都很认可
这是一张用户评价截图,展示了家长对孩子在高途直播课学习情况的反馈。截图呈现了对话形式,包含多位家长的评价内容,主要表达了对课程设计和老师教学的认可。
## 1. 主要评价内容
### 左侧评价(陈琦家长)
- **用户名**:陈琦(chengqi34)
- **评价内容**
- "孩子最近状态咋样啊"
- "上课挺认真的!就是正确率仅在及格线上一点。😊"
- "你们的课程和老师都有趣!她很喜欢!😊"
- "你们的二讲老师超级负责!课后不懂的作业,老师会打电话来指导!这个很让我感动!👍👍👍"
- "也非常感谢您帮我找了三个这么优秀的二讲老师!👏👏👏"
- "现在她还没养成预习、复习、订正的习惯。"
- "如果这个习惯养成了,估计她后续的学习不会那么吃力了!上课更不会睡觉了!😊"
- "你们的课程互动环节设计得非常好!能牢牢抓住她的注意力!👍👍👍"
- "她现在上课不用我催促,非常自觉!😊"
- "也多谢您提醒我让她上午上完课就睡觉。这样她一点半上课就不会犯困了,整个下午也有精神。😊"
- "还是您懂因材施教!👍👍👍"
- **时间**:8/16 17:37:55
### 右侧评价
- **评价内容**
- "孩子的笔记越来越好了!能看出有进步"
- "嗯嗯笔记比以前有进步,这得感谢小王老师的风趣的课堂氛围,与严格的监督与督促👍👍👍"
- "以前缺少学习动力,高途的直播课让同学对学习态度积极😊"
- "还是要看孩子练习中的问题!勇于探索与解决问题最重要啦~"
- "这孩子就是缺少学习动力,现在咱们的直播氛围比较喜欢 所以好像学习上有点积极态度了。"
- "学习本来就是很快乐的事情,没有那么难,勤奋多思考"
- "感谢您的辛苦付出🌹🌹🌹"
- "孩子的初中很好~孩子多勤奋一些,英语真不难,词汇多背,语法搞懂,多记笔记,多练习"
## 2. 视觉设计
- **布局**:左右分栏的对话形式,模拟真实的聊天界面
- **颜色**:浅灰色背景,文字为黑色,重点文字用红色突出
- **表情**:包含多种表情符号(😊、👍、👏、🌹等),增加亲和力
- **格式**:使用气泡对话框形式,模拟真实聊天场景
## 3. 评价特点
- **真实感**:采用对话形式,增强可信度
- **具体反馈**:包含具体的学习进步描述(笔记变好、上课认真等)
- **多角度评价**:从不同家长角度反映课程效果
- **情感表达**:包含感谢和积极的情感表达
- **教学认可**:特别提到老师负责、课程有趣、互动设计好等优势
## 4. 营销意图
- 展示真实用户评价,建立信任
- 突出课程对学习动力的提升作用
- 强调老师负责和课程设计优秀
- 体现因材施教的教学理念
- 展示学习进步的具体案例
这张评价截图有效地展示了高途直播课的教学效果和家长满意度,通过真实的对话形式传递课程价值。

View File

@ -0,0 +1,51 @@
# 喜欢主讲老师第一次上80很难得
这是一张用户评价截图,展示了家长对孩子在高途课程学习情况的反馈,特别提到了对杨易老师的喜爱和成绩提升的情况。
## 1. 主要评价内容
### 左侧评价
- **用户反馈**
- "他是因为杨易老师了,所以杨易老师说啥他都积极响应"
- "从娃上网课的状态就能看出,他其实蛮专注力一点问题都没有。"
- "还是看老师讲的是不是他感兴趣的,上课方式是不是他喜欢的。"
- "他上课的状态太好了,我忍不住拍视频给我妈看😊"
- **系统回复**
- "哈哈哈,这个状态太喜人啦😊"
### 右侧评价
- **用户反馈**
- "八下第一次上了80很难得。比之前进步很多"
- "他考试有进步的。期中考试84之前都是79左右80分很难得"
- "哇塞,太惊喜了🌹🌹八下语文还是比较难的,取得了进步,太不错了,看得出来孩子真的投入了,努力了"
- "咱们继续保持,加油加油,我会持续关注瀚清😊"
- "我跟天翼老师也分享一下下喜报😊"
- "之前一直在其他机构,换了高途提升很多"
- "确实高途的课程适合他,也是他选择换课程的。"
- "必须分享"
- "嗯嗯,孩子也很努力,我看好孩子,后面也会越来越好的,咱们一起相互配合,一起加油😊"
- "妈妈也是第一时间被天翼老师吸引"
- "我也是第一时间被天翼老师吸引才分享给刘瀚清试课的"
## 2. 视觉设计
- **布局**:左右分栏的对话形式,模拟真实的聊天界面
- **颜色**:浅灰色背景,文字为黑色,重点文字用红色突出
- **表情**:包含多种表情符号(😊、🌹等),增加亲和力
- **格式**:使用气泡对话框形式,模拟真实聊天场景
- **图片**:左侧包含一个小图片,显示上课场景
## 3. 评价特点
- **具体成绩提升**:明确提到"八下第一次上了80",显示具体的学习进步
- **老师影响**:强调杨易老师对孩子学习状态的积极影响
- **课程转换**:提到从其他机构转到高途课程
- **家长认可**:表达对天翼老师的认可和吸引
- **真实感**:采用对话形式,增强可信度
## 4. 营销意图
- 展示真实用户评价,建立信任
- 突出老师对学生学习的积极影响
- 体现课程效果的具体案例
- 展示学生成绩提升的实例
- 强调课程转换后的积极变化
这张评价截图有效地展示了高途课程的教学效果和家长满意度,通过真实的对话形式传递课程价值,特别突出了老师对学生学习状态的积极影响。

View File

@ -0,0 +1,54 @@
# 成绩提升30多分
这是一张用户评价截图,展示了学生在高途课程学习后取得显著成绩提升的情况,特别突出了英语成绩的进步。
## 1. 主要评价内容
### 左侧评价
- **用户反馈**
- "老师我英语从60多进步到80多"
- "老师我们出成绩了考的是789章的比上一次考试进步了31分"
- "谢谢老师夸奖"
- "这段时间我上课认真听课,单词好好背诵,然后也认真听褚帅老师的课"
- "总分从405进步到432了"
- "进步的分数里,英语占了一大半"
### 右侧评价
- **系统回复**
- "我就知道你是可以的😊"
- "你怎么进步了,分享一下这段时间的自己规划"
- "🌹"
### 底部信息
- **成绩展示**
- "进步30多分"(红色大字)
- 英语成绩:70分(可能为本次成绩)
- 英语成绩:66分(可能为上次成绩)
- "老师我这次英语成绩比上次进步了30分。"
- "然后作文也有很大进步。"
- "哇塞"
## 2. 视觉元素
- **图片**:左侧包含两张图片
- 上方:学生成绩单照片
- 下方Lays薯片卡通形象
- **卡通形象**:右侧有一个可爱的卡通小熊形象,增加亲和力
- **成绩数据**:明确显示具体的分数提升(31分),总分提升27分
- **颜色**:浅灰色背景,文字为黑色,重点文字用红色突出
- **布局**:左右分栏的对话形式,底部有成绩展示区域
## 3. 评价特点
- **具体成绩数据**:明确提到进步31分,总分从405到432
- **学科重点**:特别强调英语成绩的提升(从60多到80多)
- **学习态度**:提到认真听课和背诵单词的良好学习习惯
- **多科进步**:不仅英语进步,作文也有很大进步
- **真实感**:包含成绩单照片,增强可信度
## 4. 营销意图
- 展示具体的成绩提升数据,证明课程效果
- 突出英语学科的显著进步
- 强调学生良好的学习态度和习惯
- 通过真实成绩单建立信任
- 展示多学科综合提升的效果
这张评价截图通过具体的成绩数据和真实的成绩单,有效地展示了高途课程对学生成绩的显著提升作用,特别是英语学科的进步。

Binary file not shown.

After

Width:  |  Height:  |  Size: 288 KiB

View File

@ -0,0 +1,56 @@
# 提升成绩
这是一张即时通讯软件的聊天记录截图,内容围绕学生与老师的对话,核心是学生成绩进步的反馈与交流。画面以聊天气泡为主要视觉元素,通过文字、表情、图片及少量图表呈现信息,整体风格偏向日常沟通,带有积极鼓励的氛围。
## 1. 文字内容(按出现顺序提取)
- 时间戳:`19:38`、`12:10`
- 对话文本:
- "老师老师这次期中我进步了11分"
- "强"(配图,黑色背景+白色艺术字)
- "进班1个月进步11分 李琪老师太牛了!!!"
- "你太棒啦~宝"
- "安在的半期成绩"
- "进步明显"
- "感谢老师,英语有明显进步"
- "希望在老师的教导下再进步点"
- "太棒啦~继续努力呢👍"
- "麻烦妈妈吧安在同学的试卷和答题卡发给孟孟老师呢"
- "看一下本次出现的问题以及接下来的问题解决"
- "五一假期快乐呀~我看咱宝第二讲的作业还是没有完成提交,尽量在五一假期结束之前完成学习哦😊 噢噢好的写了没交我现在还在外面回去我就交"
- "老师我的数学成绩出来了"
- "86/120"(红色边框突出显示)
- "李琪老师太牛了!!! 进班不到2个月提升42分"
- "嘿嘿"
- "进步了是不是,我记得你进班的时候四五十分😀"
- "对呀"
- "之前43.5"(红色边框突出显示)
## 2. 视觉元素与布局
- **布局**:聊天气泡呈垂直堆叠,不同气泡颜色区分发言者(如灰色为系统/非当前用户消息,蓝色为当前用户消息,红色为强调文字/成绩)。部分关键信息(如成绩"86/120""之前43.5")用红色边框突出,增强视觉焦点。
- **颜色**
- 聊天气泡底色:灰色(系统消息)、白色(用户消息);
- 文字颜色:黑色(常规文字)、红色(强调/成绩、感叹句);
- 表情符号:包含😊、👍、😀等,增加情感表达。
- **图表与图片**
- 中间位置有一张小图表(疑似成绩趋势图,含折线/柱状元素),配合文字"安在的半期成绩"展示进步;
- "强"字配图(黑色背景+白色艺术字),强化"进步"的积极情绪。
- **人物与关系**:对话涉及"李琪老师""孟孟老师"(教师)与"安在""咱宝"(学生),通过"妈妈"代为传递信息,体现家校沟通场景。
## 3. 语境与信息逻辑
对话围绕**学生成绩进步**展开:学生汇报期中/半期成绩提升(如"进步11分""提升42分"),老师或家长表达鼓励("太牛了""继续努力"),同时涉及作业提交("五一假期完成学习")和后续问题解决("看本次问题及解决"),整体传递出"进步—鼓励—后续规划"的沟通逻辑。
## 4. 评价特点
- **具体成绩数据**:明确提到进步11分和42分,以及具体的分数(86/120,之前43.5)
- **老师认可**:特别提到李琪老师的优秀教学
- **多科进步**:涉及期中成绩和数学成绩
- **家校沟通**:通过妈妈传递信息,体现家校协作
- **积极氛围**:使用感叹号、表情符号和鼓励性语言
## 5. 营销意图
- 展示具体的成绩提升数据,证明课程效果
- 突出老师的教学能力
- 强调短期内的显著进步
- 展示家校合作的积极效果
- 通过真实对话建立信任
这张评价截图通过具体的成绩数据和真实的对话,有效地展示了高途课程对学生成绩的显著提升作用,特别是短期内的快速进步。

Binary file not shown.

After

Width:  |  Height:  |  Size: 348 KiB

View File

@ -0,0 +1,54 @@
# 物理拿下高分
这是一张关于"物理成绩提升"的宣传类信息图,以聊天记录和喜报形式呈现,核心内容是展示学生在物理学习上的成绩进步,突出教学效果。整体设计通过文字、表情、红色标注等元素传递积极信息,布局清晰,重点突出。
## 1. 文字内容(按位置提取)
- **顶部标题**`物理成绩拿下高分`(红色大字体,位于页面最上方,字体加粗,视觉冲击力强)。
- **左侧红色框内聊天记录**
- `1😊`
- `物理也挺难`
- `😊`
- **中间喜报区域**
- `山东济南中考高分喜报!`(红色字体,突出地域和事件类型)
- `81/90分妥妥拿下~`(红色字体,强调具体成绩)
- **右侧聊天记录**
- `你八十几?`(浅蓝色背景,模拟聊天气泡)
- `太有石粒辣`(黑色字体,搭配哭笑表情🤣,表达惊喜)
- `老师孩子月考成绩出来了比年前期末考试成绩84分提高7分。这次91分满分是100分班级第一孩子自从上咱物理课兴趣激情满满。还说让老师给推荐一下教辅🙏谢谢老师`(其中`84分提高7分`、`91分`、`班级第一`、`上咱物理课兴趣激情满满`被红色框标注,突出成绩提升关键信息)
- `山东济南`(红色字体,标注地域,强化地域关联)
- **下方聊天记录**
- `刚扣9分高分奖励这不到手了吗`
- `给你发2200学币看看有没有可以兑换的啊`
- **页面底部**`高途好老师`(黑色字体,标注机构/老师名称)
## 2. 视觉元素与布局
- **颜色**
- 红色:用于标题、喜报文字、重点成绩标注,传递喜庆、强调的视觉感受;
- 黑色:用于普通聊天文字,清晰易读;
- 浅蓝色:用于聊天气泡背景,模拟真实聊天界面;
- 表情:黄色底色+蓝色眼泪的哭笑表情,增强情感表达。
- **布局**
- 页面分为左、中、右三栏:左侧是简短聊天框,中间是核心喜报,右侧是详细成绩说明+地域标注,底部是机构名称,整体结构清晰,信息层级分明。
- **其他元素**
- 红色边框:左侧聊天框和右侧重点成绩文字使用红色边框,引导视觉焦点;
- 聊天气泡:模拟真实对话场景,增强代入感;
- 表情符号:增加情感色彩,使内容更生动。
## 3. 语境与信息逻辑
这张图片的目的是**宣传"高途好老师"的物理辅导效果**,通过真实或模拟的聊天记录和成绩数据,展示学生成绩提升(从84分到91分,提高7分,班级第一),强调课程对学习兴趣的激发("兴趣激情满满"),同时通过地域标注(山东济南)和机构名称(高途好老师)强化品牌关联。设计上通过红色、重点标注等视觉手段,快速传递"成绩提升"的核心信息,吸引目标用户(学生/家长)关注。
## 4. 评价特点
- **具体成绩数据**:明确提到从84分提高到91分,提高7分,班级第一
- **学科重点**:特别突出物理学科的成绩提升
- **兴趣培养**:强调课程对学习兴趣的激发
- **地域案例**:提供山东济南的具体案例
- **奖励机制**:提到学币奖励,体现激励机制
## 5. 营销意图
- 展示具体的物理成绩提升数据,证明课程效果
- 突出老师的教学能力
- 强调学习兴趣的培养
- 提供地域性成功案例
- 展示奖励机制,增加吸引力
这张评价截图通过具体的成绩数据和真实的对话形式,有效地展示了高途物理课程对学生成绩的显著提升作用,特别是对学习兴趣的激发。

Binary file not shown.

After

Width:  |  Height:  |  Size: 259 KiB

View File

@ -0,0 +1,51 @@
# 高途让家长放心,早点认识高途就好了
这是一张用户评价截图,展示了家长对高途课程的认可和满意,表达了"早点认识高途就好了"的遗憾心情。
## 1. 主要评价内容
### 左侧评价
- **用户反馈**
- "谢谢老师,还是老师指导有方"
- "跟着高途学,家长放心"
- "暑假跟着高途,孩子语文妈妈我就不担心了"
- "我们一起为孩子能考上重点高中而努力😊😊😊"
### 右侧评价
- **用户反馈**
- "好的"
- "写的很好哦!"
- "表扬!希望能有所收获"
- "您太客气啦!"
- "接下来有问题都随时来问我哈"
- "早点认识高途就好了"
- "早点认识高途就好了"
- "走了不少弯路"
- "嗯嗯,慢慢来"
- "语文学习本身就是一个潜移默化的过程"
- "走了不少弯路"
- "那个抖音上那种阅读,也买过,没有系统的教学,靠娃自觉还是不行,家长又不懂"
- "嗯嗯!必须的!最后一年,拼尽全力!💪💪"
## 2. 视觉设计
- **布局**:左右分栏的对话形式,模拟真实的聊天界面
- **颜色**:浅灰色背景,文字为黑色,重点文字用红色突出
- **表情**:包含多种表情符号(😊、💪等),增加亲和力
- **格式**:使用气泡对话框形式,模拟真实聊天场景
- **重点文字**:关键语句如"跟着高途学,家长放心"和"早点认识高途就好了"用红色大字突出显示
## 3. 评价特点
- **家长认可**:明确表达对高途课程的信任和放心
- **时间紧迫感**:提到"最后一年,拼尽全力",体现中考临近的紧迫感
- **对比体验**:提到之前在其他平台学习的不足,突出高途的系统教学优势
- **遗憾情绪**:表达"早点认识高途就好了"和"走了不少弯路"的遗憾
- **积极态度**:尽管有遗憾,但仍然保持积极的学习态度
## 4. 营销意图
- 展示家长对课程的信任和满意
- 突出高途的系统教学优势
- 强调早期认识高途的重要性
- 体现中考临近的紧迫感和学习动力
- 通过真实对话建立信任
这张评价截图通过家长的真情实感,有效地展示了高途课程给家长带来的安心感和信任度,同时也暗示了早期选择高途的重要性。

Binary file not shown.

After

Width:  |  Height:  |  Size: 213 KiB

View File

@ -0,0 +1,66 @@
# 引导加辅导老师微信图片
这是一张操作指引图,用于指导用户在"高途"APP中查找第二讲辅导老师的微信,包含三张手机界面截图与底部的步骤说明。整体背景为暖色调(橙黄色渐变),界面以白色为主,搭配橙色、红色等强调色,通过箭头指示操作流程,视觉上清晰醒目。
## 1. 整体布局与视觉风格
- **背景**:顶部至底部为橙黄色渐变,营造活泼、醒目的氛围,突出指引内容。
- **截图排列**:三张手机界面截图横向排列,每张截图通过橙色/黄色箭头指示操作逻辑,布局清晰。
## 2. 三张截图的文字与元素(从左至右)
### 左图("高途"APP首页/课程列表页)
- **顶部状态栏**:显示时间"上午10:20"、网络速度"111K/s"、电池/信号等图标。
- **"今日学习推荐"板块**
- 课程1"不用读全文阅读大满贯" → 主讲【新9阶】双语素养+科学思维训练营5月2日18:55开课。
- 课程2"大力出奇迹" → 第2讲【新9阶】双语素养+科学思维训练营5月2日19:56开课。
- **"全部课程"板块**
- 课程标题:"【新9阶】双语素养+科学思维训练营" → 共7节·未学习5月2日18:55开课。
- 主讲老师头像王冰、韩瑞05。
- **底部导航栏**:首页、学习、**上课**(红色图标,当前选中)、发现、我的。
### 中图(课程详情页)
- **顶部状态栏**:时间"上午10:20"、网络速度"0.3K/s"。
- **课程标题**【新9阶】双语素养+科学思维训练营 → 有效期至2024年06月03日。
- **主讲老师**:王冰、李雪冬、王泽龙(头像+"主讲老师"标签)。
- **功能入口**:学习资料、缓存课程、错题本(图标+文字)。
- **学习进度**听课进度0%、听课时长0分。
- **"课前准备"板块**
- 橙色按钮:"添加二讲老师微信"(核心操作入口)。
- 选项:"关注公众号收取报告" → 右侧"去关注"链接。
- **课程列表**
- 第1讲"不用读全文阅读大满贯" → 05月02日周四18:55 - 19:55 | 未开始。
- 第2讲"大力出奇迹" → 05月02日周四19:56 - 21:01 | 未开始。
- 第3讲"全等模型狂想曲"(部分可见)。
### 右图("高途成长助手"页面)
- **顶部状态栏**:时间"上午10:20"、网络速度"80.4K/s"。
- **页面标题**:高途成长助手 → 评分2.5。
- **倒计时与提示**剩余04:56:8 → "请务必添加老师 否则无法上课"(红色警示文字)。
- **老师信息**:高途助教老师 专属(卡通头像+"专属"标签)。
- **二维码区域**:提示"长按二维码,联系老师"。
- **底部按钮**:全程辅导答疑、定制学习规划。
## 3. 底部操作步骤文字
- 第一步,下载"高途"APP
- 第二步,登录后,点击下方"上课"按钮
- 第三部,点击已报名课程("第三部"疑似"第三步"笔误)
- 第四步,点击"添加二讲老师微信"按钮,自动跳转到微信添加
## 4. 语境与信息逻辑
这张图是教育类APP"高途"的运营指引核心目标是引导用户通过APP完成"添加第二讲辅导老师微信"的操作,确保课程学习顺利进行。三张截图对应"进入上课页面→选择课程→添加老师微信"的逻辑链,步骤简洁、视觉重点突出(如红色"上课"按钮、橙色"添加老师微信"按钮),符合用户操作习惯。暖色调背景与醒目的文字/按钮设计,提升了信息传递效率,避免用户因未添加老师而无法上课。
## 5. 设计特点
- **操作指引清晰**:通过三张截图展示完整的操作流程
- **视觉重点突出**:使用红色和橙色强调关键按钮和操作
- **信息层级分明**:不同功能区域通过颜色和布局区分
- **用户友好**:包含倒计时和警示信息,提醒用户及时操作
- **品牌展示**展示高途APP的界面和功能
## 6. 营销意图
- 指导用户正确使用APP功能
- 确保用户能够顺利添加辅导老师
- 展示APP的用户界面和功能
- 强调添加老师的重要性
- 提升用户体验和课程参与度
这张引导图通过清晰的步骤和视觉设计,有效地指导用户完成关键操作,确保课程学习的顺利进行。

View File

@ -0,0 +1,52 @@
# 素养小学到课赠礼
这是一张小学素养课程的赠礼海报,主题为"飞跃领航计划",展示3-6阶学生的专属福利大礼包。海报设计活泼明快,以浅蓝和浅绿为主色调,包含三位主讲老师和四天的完课福利。
## 1. 课程主题与目标
- **阶段**3-6阶橙色标签
- **主标题**:飞跃领航计划("飞跃"为黑色粗体,"领航计划"为绿色粗体)
- **核心主题**:专属福利大礼包(绿色横幅)
- **课程标签**:思维 | 人文 | 脑力(黄色横幅)
## 2. 主讲老师信息
海报展示了三位主讲老师:
1. **陈君** - 左侧女性老师,穿着灰色西装
2. **杨易** - 中间男性老师,穿着黑色中式服装,手持折扇
3. **白马** - 右侧男性老师,穿着深色西装
## 3. 福利列表(四天完课福利)
### 第一天/完课福利
- 《脑王数独秘籍》
### 第二天/完课福利
1. NCTE词汇表带翻译
2. 《世界上下五千年》音频资料(上)
### 第三天/完课福利
- 《世界上下五千年》音频资料(中)
### 第四天/完课福利
1. 《最强大脑同款》-图形推理秘籍
2. BBC自然拼读益智动画
## 4. 视觉设计
- **颜色**:浅蓝渐变背景,搭配橙色、绿色、黄色等明亮色彩
- **布局**:顶部为标题和老师展示区,中间为福利列表,底部为装饰性卡通元素
- **卡通元素**:包含可爱的橙子形象和礼物盒图标,增加活泼感
- **文字风格**:标题使用大号粗体字,福利列表使用清晰的编号和项目符号
## 5. 课程特色
- **多学科覆盖**:包含思维训练、人文知识、脑力开发
- **多样化资源**:包括书籍、音频资料、动画等不同形式的学习材料
- **循序渐进**:四天的福利安排,每天都有新的学习内容
- **趣味性**结合《最强大脑》和BBC动画等有趣内容增加学习吸引力
## 6. 营销意图
- 展示课程的丰富福利,吸引学生参与
- 突出多学科素养培养
- 通过具体的学习材料展示课程价值
- 强调循序渐进的学习方式
- 利用知名节目和品牌最强大脑、BBC增加可信度
这张海报通过清晰的结构和丰富的福利内容,有效地展示了小学素养课程的价值,特别强调了思维、人文和脑力培养的综合素养教育。

Binary file not shown.

After

Width:  |  Height:  |  Size: 451 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 270 KiB

View File

@ -0,0 +1,3 @@
# 8化试听-袁媛
这是一张初中化学试听课的海报图片,主讲老师是袁媛。海报应该包含课程介绍、试听时间、报名方式等相关信息。图片可能包含化学相关的图形元素、课程亮点等内容。

View File

@ -0,0 +1,3 @@
# 8年级_思维_佳美_数与式找规律试听课新初一初二通用
这是一张初中数学试听课海报,主题是"数与式找规律"适合8年级学生由佳美老师主讲。海报应该包含数学规律题目的示例、课程特色、试听时间安排等信息。图片可能包含数学公式、几何图形或相关的教学元素。

View File

@ -0,0 +1,3 @@
# 8数试听-YOYO游剑荷-角度模型
这是一张初中数学试听课海报,主题是"角度模型"由YOYO游剑荷老师主讲。海报应该包含角度相关的几何图形、解题方法、课程亮点等内容。图片可能展示各种角度模型的应用和示例。

View File

@ -0,0 +1,3 @@
# 新8思维刘璐——全等三角形与辅助线
这是一张初中数学试听课海报,主题是"全等三角形与辅助线",由刘璐老师主讲。海报应该包含全等三角形的几何图形、辅助线的画法、解题技巧等内容。图片可能展示各种全等三角形的示例和辅助线的应用。

View File

@ -0,0 +1,3 @@
# 7物试听-郭志强-基础版
这是一张初中物理试听课海报,主题是基础物理知识,由郭志强老师主讲。海报应该包含物理基础概念、实验演示、课程内容介绍等信息。图片可能包含物理实验装置、公式或相关的教学元素。

View File

@ -0,0 +1,3 @@
# 7阶试听-李雪冬-声学
这是一张初中物理试听课海报,主题是"声学",由李雪冬老师主讲。海报应该包含声音的产生、传播、特性等声学知识,以及相关的实验和演示内容。图片可能展示声学相关的图形、公式或实验装置。

View File

@ -0,0 +1,3 @@
# 8物试听-李雪冬-光的折射
这是一张初中物理试听课海报,主题是"光的折射",由李雪冬老师主讲。海报应该包含光的折射原理、相关公式、实验演示等内容。图片可能展示光的折射现象、透镜、棱镜等光学元素。

View File

@ -0,0 +1,3 @@
# 9物试听-电功率
这是一张初中物理试听课海报,主题是"电功率"适合9年级学生。海报应该包含电功率的计算、电路分析、相关公式等内容。图片可能展示电路图、电表、功率计算示例等元素。

View File

@ -0,0 +1,3 @@
# 9物试听-韩盛乔-基础版
这是一张初中物理试听课海报主题是基础物理知识由韩盛乔老师主讲适合9年级学生。海报应该包含物理基础概念、公式、实验等内容。图片可能展示物理实验装置、公式或相关的教学元素。

View File

@ -0,0 +1,3 @@
# 9物试听-韩盛乔-浙教版
这是一张初中物理试听课海报主题是浙教版物理课程由韩盛乔老师主讲适合9年级学生。海报应该包含浙教版物理教材的相关内容、课程特色、试听信息等。图片可能展示浙教版教材的元素或相关的教学资料。

View File

@ -0,0 +1,3 @@
# 7英试听-张丹丹-词汇—易混词辨析
这是一张初中英语试听课海报,主题是"词汇—易混词辨析"由张丹丹老师主讲适合7年级学生。海报应该包含英语易混词汇的对比、辨析方法、例句等内容。图片可能展示词汇对比表格、相关例句或教学元素。

View File

@ -0,0 +1,3 @@
# 8英完形填空大招试听-张丹丹PNG
这是一张初中英语试听课海报,主题是"完形填空大招"由张丹丹老师主讲适合8年级学生。海报应该包含完形填空解题技巧、常用方法、例题解析等内容。图片可能展示完形填空题目示例、解题步骤或相关的教学元素。

View File

@ -0,0 +1,3 @@
# 王冰老师试听课—话题写作(难忘经历类)
这是一张初中英语试听课海报,主题是"话题写作(难忘经历类)",由王冰老师主讲。海报应该包含英语话题写作的技巧、范文示例、写作方法等内容。图片可能展示写作模板、相关例句或教学元素。

View File

@ -0,0 +1,3 @@
# 8年级-虫子-试听
这是一张初中语文试听课海报由虫子老师主讲适合8年级学生。海报应该包含语文课程介绍、试听内容、教学特色等信息。图片可能展示语文相关的元素如古诗词、文学作品或教学资料。

View File

@ -0,0 +1,3 @@
# 写作试听课-虫子老师
这是一张初中语文试听课海报,主题是写作课程,由虫子老师主讲。海报应该包含写作技巧、范文示例、写作方法等内容。图片可能展示写作相关的元素,如作文示例、写作技巧图表或教学资料。

View File

@ -0,0 +1,3 @@
# 小学数学杨易
这是一张小学数学试听课海报,由杨易老师主讲。海报应该包含小学数学课程介绍、试听内容、教学特色等信息。图片可能展示数学相关的元素,如数字、图形、算式或教学资料。

Some files were not shown because too many files have changed in this diff Show More