fix: resolve multiple RAG retrieval issues and update embedding model configuration [AC-AISVC-50]
Key fixes:
1. Fix a watch infinite loop in the ConfigForm and EmbeddingConfigForm components that caused memory exhaustion
2. Fix the mismatch between the vector storage format and the retrieval format
3. Fix two-stage and hybrid retrieval being mutually exclusive
4. Fix the vector field being dropped during RRF fusion
5. Fix incorrect similarity computation caused by an unnormalized embedding_full
6. Fix the embedding model configuration form not displaying parameters

Enhancements:
- Add a with_vectors parameter to return vectors for reranking
- Add a combined two-stage + hybrid retrieval strategy
- Update the README embedding model configuration notes, recommending nomic-embed-text-v2-moe
- Add a cleanup_qdrant.py script for clearing vector data
This commit is contained in:
parent 6150fc0dd2
commit fd04ed2cef
README.md | 30
@@ -78,14 +78,36 @@ docker-compose up -d --build

 #### 4. Pull the embedding model

-After the services start, pull the nomic-embed-text model in the Ollama container:
+After the services start, pull an embedding model in the Ollama container. `nomic-embed-text-v2-moe` is recommended for better Chinese support:

 ```bash
 # Enter the Ollama container and pull the model
-docker exec -it ai-ollama ollama pull nomic-embed-text
+docker exec -it ai-ollama ollama pull toshk0/nomic-embed-text-v2-moe:Q6_K
 ```

-#### 5. Verify the services
+**Available models**:
+
+| Model | Dimension | Notes |
+|------|------|------|
+| `toshk0/nomic-embed-text-v2-moe:Q6_K` | 768 | Recommended; good Chinese support, supports task prefixes |
+| `nomic-embed-text:v1.5` | 768 | Original; supports task prefixes and Matryoshka |
+| `bge-large-zh` | 1024 | Chinese-specific; best quality |
+
+#### 5. Configure the embedding model
+
+Open the frontend admin UI and go to the **Embedding Model Configuration** page:
+
+1. Select the provider: **Nomic Embed (optimized)**
+2. Configure the parameters:
+   - **API URL**: `http://ollama:11434` (Docker) or `http://localhost:11434` (local development)
+   - **Model name**: `toshk0/nomic-embed-text-v2-moe:Q6_K`
+   - **Vector dimension**: `768`
+   - **Matryoshka truncation**: `true`
+3. Click **Save configuration**
+
+> **Note**: The Nomic Embed (optimized) provider enables the full set of RAG optimizations: task prefixes, Matryoshka multi-vectors, and two-stage retrieval.
+
+#### 6. Verify the services

 ```bash
 # Check service status
@@ -97,7 +119,7 @@ docker compose logs -f ai-service | grep "Default API Key"

 > **Important**: On first startup the backend automatically generates a default API Key. Copy it from the logs; it is needed for the frontend configuration.

-#### 6. Configure the frontend API Key
+#### 7. Configure the frontend API Key

 ```bash
 # Create the frontend environment variable file
@@ -14,7 +14,7 @@ export default defineConfig({
     port: 3000,
     proxy: {
       '/api': {
-        target: 'http://localhost:8000',
+        target: 'http://localhost:8088',
         changeOrigin: true,
         rewrite: (path) => path.replace(/^\/api/, ''),
       },
@@ -176,6 +176,7 @@ class QdrantClient:
         limit: int = 5,
         score_threshold: float | None = None,
         vector_name: str = "full",
+        with_vectors: bool = False,
    ) -> list[dict[str, Any]]:
        """
        [AC-AISVC-10] Search vectors in tenant's collection.
@@ -189,6 +190,7 @@ class QdrantClient:
            score_threshold: Minimum score threshold for results
            vector_name: Name of the vector to search (for multi-vector collections)
                Default is "full" for 768-dim vectors in Matryoshka setup.
+           with_vectors: Whether to return vectors in results (for two-stage reranking)
        """
        client = await self.get_client()
@@ -216,6 +218,7 @@ class QdrantClient:
                collection_name=collection_name,
                query_vector=(vector_name, query_vector),
                limit=limit,
+               with_vectors=with_vectors,
            )
        except Exception as e:
            if "vector name" in str(e).lower() or "Not existing vector" in str(e):
@@ -227,6 +230,7 @@ class QdrantClient:
                collection_name=collection_name,
                query_vector=query_vector,
                limit=limit,
+               with_vectors=with_vectors,
            )
        else:
            raise
@@ -235,15 +239,18 @@ class QdrantClient:
            f"[AC-AISVC-10] Collection {collection_name} returned {len(results)} raw results"
        )

-       hits = [
-           {
-               "id": str(result.id),
-               "score": result.score,
-               "payload": result.payload or {},
-           }
-           for result in results
-           if score_threshold is None or result.score >= score_threshold
-       ]
+       hits = []
+       for result in results:
+           if score_threshold is not None and result.score < score_threshold:
+               continue
+           hit = {
+               "id": str(result.id),
+               "score": result.score,
+               "payload": result.payload or {},
+           }
+           if with_vectors and result.vector:
+               hit["vector"] = result.vector
+           hits.append(hit)
        all_hits.extend(hits)

        if hits:
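The rewritten hit-building loop in this hunk can be sketched in isolation. Below, `FakeResult` is a hypothetical stand-in for a Qdrant scored point, not part of the project's codebase; the loop itself mirrors the new logic (threshold filter first, vector kept only when requested):

```python
from dataclasses import dataclass, field


@dataclass
class FakeResult:
    # Hypothetical stand-in for a Qdrant ScoredPoint, for illustration only
    id: str
    score: float
    payload: dict = field(default_factory=dict)
    vector: list = field(default_factory=list)


def build_hits(results, score_threshold=None, with_vectors=False):
    """Mirror of the new hit-building loop: filter by threshold, optionally keep vectors."""
    hits = []
    for result in results:
        if score_threshold is not None and result.score < score_threshold:
            continue
        hit = {"id": str(result.id), "score": result.score, "payload": result.payload or {}}
        if with_vectors and result.vector:
            hit["vector"] = result.vector
        hits.append(hit)
    return hits


results = [FakeResult("a", 0.9, vector=[0.1, 0.2]), FakeResult("b", 0.3)]
hits = build_hits(results, score_threshold=0.5, with_vectors=True)
# Only "a" survives the threshold, and its vector is kept for reranking
```

The explicit loop replaces the original list comprehension because the vector key is conditional per hit, which a comprehension cannot express cleanly.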
@@ -74,11 +74,38 @@ class EmbeddingProviderFactory:
            "nomic": "Nomic-embed-text v1.5 优化版,支持任务前缀和 Matryoshka 维度截断,专为RAG优化",
        }

+       raw_schema = temp_instance.get_config_schema()
+
+       properties = {}
+       required = []
+       for key, field in raw_schema.items():
+           properties[key] = {
+               "type": field.get("type", "string"),
+               "title": field.get("title", key),
+               "description": field.get("description", ""),
+               "default": field.get("default"),
+           }
+           if field.get("enum"):
+               properties[key]["enum"] = field.get("enum")
+           if field.get("minimum") is not None:
+               properties[key]["minimum"] = field.get("minimum")
+           if field.get("maximum") is not None:
+               properties[key]["maximum"] = field.get("maximum")
+           if field.get("required"):
+               required.append(key)
+
+       config_schema = {
+           "type": "object",
+           "properties": properties,
+       }
+       if required:
+           config_schema["required"] = required
+
        return {
            "name": name,
            "display_name": display_names.get(name, name),
            "description": descriptions.get(name, ""),
-           "config_schema": temp_instance.get_config_schema(),
+           "config_schema": config_schema,
        }

    @classmethod
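The hunk above converts a flat per-field metadata map into a JSON-Schema-style object so the frontend form can render the parameters. A standalone sketch of that conversion (the `raw` input below is a made-up provider schema, not one from the repository):

```python
def to_json_schema(raw_schema):
    """Convert a flat {field: metadata} map into a JSON-Schema-like object,
    collecting required fields and optional enum/minimum/maximum constraints."""
    properties, required = {}, []
    for key, f in raw_schema.items():
        properties[key] = {
            "type": f.get("type", "string"),
            "title": f.get("title", key),
            "description": f.get("description", ""),
            "default": f.get("default"),
        }
        # Copy optional constraints only when present
        for extra in ("enum", "minimum", "maximum"):
            if f.get(extra) is not None:
                properties[key][extra] = f[extra]
        if f.get("required"):
            required.append(key)
    schema = {"type": "object", "properties": properties}
    if required:
        schema["required"] = required
    return schema


raw = {
    "api_key": {"type": "string", "required": True},
    "dimension": {"type": "integer", "default": 768, "minimum": 64},
}
schema = to_json_schema(raw)
```

Wrapping everything under `"type": "object"` with a `"properties"` map is what lets a generic schema-driven form component display each parameter, which was the bug fixed by item 6 of the commit message.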
@@ -286,7 +313,7 @@ def get_embedding_config_manager() -> EmbeddingConfigManager:
        settings = get_settings()

        _embedding_config_manager = EmbeddingConfigManager(
-           default_provider="ollama",
+           default_provider="nomic",
            default_config={
                "base_url": settings.ollama_base_url,
                "model": settings.ollama_embedding_model,
@@ -149,6 +149,7 @@ class NomicEmbeddingProvider(EmbeddingProvider):

        embedding_256 = self._truncate_and_normalize(embedding, 256)
        embedding_512 = self._truncate_and_normalize(embedding, 512)
+       embedding_full = self._truncate_and_normalize(embedding, len(embedding))

        logger.debug(
            f"Generated Nomic embedding: task={task.value}, "
@@ -156,7 +157,7 @@ class NomicEmbeddingProvider(EmbeddingProvider):
        )

        return NomicEmbeddingResult(
-           embedding_full=embedding,
+           embedding_full=embedding_full,
            embedding_256=embedding_256,
            embedding_512=embedding_512,
            dimension=len(embedding),
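The fix above matters because comparing a normalized query vector against an unnormalized stored vector skews similarity scores. The diff does not show `_truncate_and_normalize` itself; a minimal sketch of the assumed behavior (truncate a Matryoshka embedding, then rescale to unit L2 norm):

```python
import math


def truncate_and_normalize(embedding, dim):
    """Truncate a Matryoshka embedding to `dim` dimensions and rescale to unit L2 norm."""
    truncated = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in truncated))
    if norm == 0.0:
        return truncated  # avoid division by zero for degenerate vectors
    return [x / norm for x in truncated]


vec = [3.0, 4.0, 0.0, 0.0]
full = truncate_and_normalize(vec, len(vec))
# After normalization the L2 norm is 1, so dot products are true cosine scores
```

With `len(embedding)` passed as the target dimension, nothing is truncated and only the normalization applies, which is exactly what `embedding_full` was missing.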
@@ -259,26 +260,31 @@ class NomicEmbeddingProvider(EmbeddingProvider):
        return {
            "base_url": {
                "type": "string",
+               "title": "API 地址",
                "description": "Ollama API 地址",
                "default": "http://localhost:11434",
            },
            "model": {
                "type": "string",
+               "title": "模型名称",
                "description": "嵌入模型名称(推荐 nomic-embed-text v1.5)",
                "default": "nomic-embed-text",
            },
            "dimension": {
                "type": "integer",
+               "title": "向量维度",
                "description": "向量维度(支持 256/512/768)",
                "default": 768,
            },
            "timeout_seconds": {
                "type": "integer",
+               "title": "超时时间",
                "description": "请求超时时间(秒)",
                "default": 60,
            },
            "enable_matryoshka": {
                "type": "boolean",
+               "title": "Matryoshka 截断",
                "description": "启用 Matryoshka 维度截断",
                "default": True,
            },
@@ -130,21 +130,25 @@ class OllamaEmbeddingProvider(EmbeddingProvider):
        return {
            "base_url": {
                "type": "string",
+               "title": "API 地址",
                "description": "Ollama API 地址",
                "default": "http://localhost:11434",
            },
            "model": {
                "type": "string",
+               "title": "模型名称",
                "description": "嵌入模型名称",
                "default": "nomic-embed-text",
            },
            "dimension": {
                "type": "integer",
+               "title": "向量维度",
                "description": "向量维度",
                "default": 768,
            },
            "timeout_seconds": {
                "type": "integer",
+               "title": "超时时间",
                "description": "请求超时时间(秒)",
                "default": 60,
            },
@@ -159,28 +159,33 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
        return {
            "api_key": {
                "type": "string",
+               "title": "API 密钥",
                "description": "OpenAI API 密钥",
                "required": True,
                "secret": True,
            },
            "model": {
                "type": "string",
+               "title": "模型名称",
                "description": "嵌入模型名称",
                "default": "text-embedding-3-small",
                "enum": list(self.MODEL_DIMENSIONS.keys()),
            },
            "base_url": {
                "type": "string",
+               "title": "API 地址",
                "description": "OpenAI API 地址(支持兼容接口)",
                "default": "https://api.openai.com/v1",
            },
            "dimension": {
                "type": "integer",
+               "title": "向量维度",
                "description": "向量维度(仅 text-embedding-3 系列支持自定义)",
                "default": 1536,
            },
            "timeout_seconds": {
                "type": "integer",
+               "title": "超时时间",
                "description": "请求超时时间(秒)",
                "default": 60,
            },
@@ -84,7 +84,13 @@ class RRFCombiner:
                "bm25_rank": -1,
                "payload": result.get("payload", {}),
                "id": chunk_id,
+               "vector": result.get("vector"),
            }
+       else:
+           combined_scores[chunk_id]["vector_score"] = result.get("score", 0.0)
+           combined_scores[chunk_id]["vector_rank"] = rank
+           if result.get("vector"):
+               combined_scores[chunk_id]["vector"] = result.get("vector")

        combined_scores[chunk_id]["score"] += rrf_score
@@ -101,6 +107,7 @@ class RRFCombiner:
                "bm25_rank": rank,
                "payload": result.get("payload", {}),
                "id": chunk_id,
+               "vector": result.get("vector"),
            }
        else:
            combined_scores[chunk_id]["bm25_score"] = result.get("score", 0.0)
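For context on the combiner these hunks patch: Reciprocal Rank Fusion scores each document as a weighted sum of `1 / (k + rank)` over the result lists it appears in, with `k` commonly set to 60. A minimal sketch, independent of the project's `RRFCombiner` class:

```python
def rrf_fuse(vector_ids, bm25_ids, k=60, vector_weight=1.0, bm25_weight=1.0):
    """Fuse two ranked id lists with Reciprocal Rank Fusion.

    Documents appearing high in both lists accumulate the largest scores.
    """
    scores = {}
    for rank, doc_id in enumerate(vector_ids, start=1):
        scores[doc_id] = scores.get(doc_id, 0.0) + vector_weight / (k + rank)
    for rank, doc_id in enumerate(bm25_ids, start=1):
        scores[doc_id] = scores.get(doc_id, 0.0) + bm25_weight / (k + rank)
    # Highest fused score first
    return sorted(scores, key=scores.get, reverse=True)


fused = rrf_fuse(["a", "b", "c"], ["b", "d"])
# "b" ranks first: it appears near the top of both lists
```

Because fusion only merges per-id score dictionaries, any payload carried alongside (like the `vector` field restored by this commit) must be copied explicitly, which is exactly what the added `"vector": result.get("vector")` lines do.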
@@ -199,7 +206,15 @@ class OptimizedRetriever(BaseRetriever):
            f"dim_256={'available' if embedding_result.embedding_256 else 'not available'}"
        )

-       if self._two_stage_enabled:
+       if self._two_stage_enabled and self._hybrid_enabled:
+           logger.info("[RAG-OPT] Using two-stage + hybrid retrieval strategy")
+           results = await self._two_stage_hybrid_retrieve(
+               ctx.tenant_id,
+               embedding_result,
+               ctx.query,
+               self._top_k,
+           )
+       elif self._two_stage_enabled:
            logger.info("[RAG-OPT] Using two-stage retrieval strategy")
            results = await self._two_stage_retrieve(
                ctx.tenant_id,
@@ -300,20 +315,27 @@ class OptimizedRetriever(BaseRetriever):
        stage1_start = time.perf_counter()
        candidates = await self._search_with_dimension(
            client, tenant_id, embedding_result.embedding_256, "dim_256",
-           top_k * self._two_stage_expand_factor
+           top_k * self._two_stage_expand_factor,
+           with_vectors=True,
        )
        stage1_latency = (time.perf_counter() - stage1_start) * 1000

-       logger.debug(
+       logger.info(
            f"[RAG-OPT] Stage 1: {len(candidates)} candidates in {stage1_latency:.2f}ms"
        )

        stage2_start = time.perf_counter()
        reranked = []
        for candidate in candidates:
-           stored_full_embedding = candidate.get("payload", {}).get("embedding_full", [])
-           if stored_full_embedding:
-               import numpy as np
+           vector_data = candidate.get("vector", {})
+           stored_full_embedding = None
+
+           if isinstance(vector_data, dict):
+               stored_full_embedding = vector_data.get("full", [])
+           elif isinstance(vector_data, list):
+               stored_full_embedding = vector_data
+
+           if stored_full_embedding and len(stored_full_embedding) > 0:
                similarity = self._cosine_similarity(
                    embedding_result.embedding_full,
                    stored_full_embedding
@@ -326,7 +348,7 @@ class OptimizedRetriever(BaseRetriever):
        results = reranked[:top_k]
        stage2_latency = (time.perf_counter() - stage2_start) * 1000

-       logger.debug(
+       logger.info(
            f"[RAG-OPT] Stage 2: {len(results)} final results in {stage2_latency:.2f}ms"
        )
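Stage 2 above reranks candidates with plain cosine similarity between the full query embedding and the stored 768-dim vector. The diff does not show `_cosine_similarity` itself; a minimal sketch of the assumed behavior:

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors; 1.0 means same direction, 0.0 orthogonal."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


sim_same = cosine_similarity([1.0, 0.0], [2.0, 0.0])  # parallel vectors
sim_orth = cosine_similarity([1.0, 0.0], [0.0, 1.0])  # orthogonal vectors
```

When both vectors are already unit-normalized (as the `embedding_full` normalization fix ensures), the norms are 1 and cosine similarity reduces to the dot product.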
@@ -374,6 +396,92 @@ class OptimizedRetriever(BaseRetriever):

        return combined[:top_k]

+   async def _two_stage_hybrid_retrieve(
+       self,
+       tenant_id: str,
+       embedding_result: NomicEmbeddingResult,
+       query: str,
+       top_k: int,
+   ) -> list[dict[str, Any]]:
+       """
+       Two-stage + Hybrid retrieval strategy.
+
+       Stage 1: Fast retrieval with 256-dim vectors + BM25 in parallel
+       Stage 2: RRF fusion + Precise reranking with 768-dim vectors
+
+       This combines the best of both worlds:
+       - Two-stage: Speed from 256-dim, precision from 768-dim reranking
+       - Hybrid: Semantic matching from vectors, keyword matching from BM25
+       """
+       import time
+
+       client = await self._get_client()
+
+       stage1_start = time.perf_counter()
+
+       vector_task = self._search_with_dimension(
+           client, tenant_id, embedding_result.embedding_256, "dim_256",
+           top_k * self._two_stage_expand_factor,
+           with_vectors=True,
+       )
+       bm25_task = self._bm25_search(client, tenant_id, query, top_k * self._two_stage_expand_factor)
+
+       vector_results, bm25_results = await asyncio.gather(
+           vector_task, bm25_task, return_exceptions=True
+       )
+
+       if isinstance(vector_results, Exception):
+           logger.warning(f"[RAG-OPT] Vector search failed: {vector_results}")
+           vector_results = []
+       if isinstance(bm25_results, Exception):
+           logger.warning(f"[RAG-OPT] BM25 search failed: {bm25_results}")
+           bm25_results = []
+
+       stage1_latency = (time.perf_counter() - stage1_start) * 1000
+       logger.info(
+           f"[RAG-OPT] Two-stage Hybrid Stage 1: vector={len(vector_results)}, bm25={len(bm25_results)}, latency={stage1_latency:.2f}ms"
+       )
+
+       stage2_start = time.perf_counter()
+
+       combined = self._rrf_combiner.combine(
+           vector_results,
+           bm25_results,
+           vector_weight=settings.rag_vector_weight,
+           bm25_weight=settings.rag_bm25_weight,
+       )
+
+       reranked = []
+       for candidate in combined[:top_k * 2]:
+           vector_data = candidate.get("vector", {})
+           stored_full_embedding = None
+
+           if isinstance(vector_data, dict):
+               stored_full_embedding = vector_data.get("full", [])
+           elif isinstance(vector_data, list):
+               stored_full_embedding = vector_data
+
+           if stored_full_embedding and len(stored_full_embedding) > 0:
+               similarity = self._cosine_similarity(
+                   embedding_result.embedding_full,
+                   stored_full_embedding
+               )
+               candidate["score"] = similarity
+               candidate["stage"] = "two_stage_hybrid_reranked"
+           reranked.append(candidate)
+
+       reranked.sort(key=lambda x: x.get("score", 0), reverse=True)
+       results = reranked[:top_k]
+       stage2_latency = (time.perf_counter() - stage2_start) * 1000
+
+       logger.info(
+           f"[RAG-OPT] Two-stage Hybrid Stage 2 (reranking): {len(results)} final results in {stage2_latency:.2f}ms"
+       )
+
+       return results
+
    async def _vector_retrieve(
        self,
        tenant_id: str,
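The parallel stage-1 fan-out in the new method relies on `asyncio.gather(..., return_exceptions=True)`, which returns exceptions as values instead of raising, so one failing search degrades gracefully rather than aborting the whole request. A standalone sketch of the pattern (the two coroutines are illustrative stand-ins, not project code):

```python
import asyncio


async def ok_search():
    # Stand-in for the vector search task
    return ["hit-1", "hit-2"]


async def broken_search():
    # Stand-in for a failing BM25 search task
    raise RuntimeError("bm25 index missing")


async def main():
    vector_results, bm25_results = await asyncio.gather(
        ok_search(), broken_search(), return_exceptions=True
    )
    # A failed task arrives as the exception object, not a raised error
    if isinstance(bm25_results, Exception):
        bm25_results = []
    return vector_results, bm25_results


vector_results, bm25_results = asyncio.run(main())
```

The `isinstance(..., Exception)` checks in the diff serve exactly this purpose: a failed BM25 search is logged and replaced with an empty list, so RRF fusion still runs over the surviving result set.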
@@ -393,12 +501,13 @@ class OptimizedRetriever(BaseRetriever):
        query_vector: list[float],
        vector_name: str,
        limit: int,
+       with_vectors: bool = False,
    ) -> list[dict[str, Any]]:
        """Search using specified vector dimension."""
        try:
            logger.info(
                f"[RAG-OPT] Searching with vector_name={vector_name}, "
-               f"limit={limit}, vector_dim={len(query_vector)}"
+               f"limit={limit}, vector_dim={len(query_vector)}, with_vectors={with_vectors}"
            )

            results = await client.search(
@@ -406,6 +515,7 @@ class OptimizedRetriever(BaseRetriever):
                query_vector=query_vector,
                limit=limit,
                vector_name=vector_name,
+               with_vectors=with_vectors,
            )

            logger.info(
@@ -0,0 +1,89 @@
+"""
+Script to cleanup Qdrant collections and data.
+"""
+
+import asyncio
+import logging
+import sys
+
+sys.path.insert(0, "Q:\\agentProject\\ai-robot-core\\ai-service")
+
+from app.core.config import get_settings
+from app.core.qdrant_client import get_qdrant_client
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+async def list_collections():
+    """List all collections in Qdrant."""
+    client = await get_qdrant_client()
+    qdrant = await client.get_client()
+
+    collections = await qdrant.get_collections()
+    return [c.name for c in collections.collections]
+
+
+async def delete_collection(collection_name: str):
+    """Delete a specific collection."""
+    client = await get_qdrant_client()
+    qdrant = await client.get_client()
+
+    try:
+        await qdrant.delete_collection(collection_name)
+        logger.info(f"Deleted collection: {collection_name}")
+        return True
+    except Exception as e:
+        logger.error(f"Failed to delete collection {collection_name}: {e}")
+        return False
+
+
+async def delete_all_collections():
+    """Delete all collections."""
+    collections = await list_collections()
+    logger.info(f"Found {len(collections)} collections: {collections}")
+
+    for name in collections:
+        await delete_collection(name)
+
+    logger.info("All collections deleted")
+
+
+async def delete_tenant_collection(tenant_id: str):
+    """Delete collection for a specific tenant."""
+    client = await get_qdrant_client()
+    collection_name = client.get_collection_name(tenant_id)
+
+    success = await delete_collection(collection_name)
+    if success:
+        logger.info(f"Deleted collection for tenant: {tenant_id}")
+    return success
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Cleanup Qdrant data")
+    parser.add_argument("--all", action="store_true", help="Delete all collections")
+    parser.add_argument("--tenant", type=str, help="Delete collection for specific tenant")
+    parser.add_argument("--list", action="store_true", help="List all collections")
+
+    args = parser.parse_args()
+
+    if args.list:
+        collections = asyncio.run(list_collections())
+        print(f"Collections: {collections}")
+    elif args.all:
+        confirm = input("Are you sure you want to delete ALL collections? (yes/no): ")
+        if confirm.lower() == "yes":
+            asyncio.run(delete_all_collections())
+        else:
+            print("Cancelled")
+    elif args.tenant:
+        confirm = input(f"Delete collection for tenant '{args.tenant}'? (yes/no): ")
+        if confirm.lower() == "yes":
+            asyncio.run(delete_tenant_collection(args.tenant))
+        else:
+            print("Cancelled")
+    else:
+        parser.print_help()