ai-robot-core/ai-service/app/core/qdrant_client.py

675 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Qdrant client for AI Service.
[AC-AISVC-10] Vector database client with tenant-isolated collection management.
Supports multi-dimensional vectors for Matryoshka representation learning.
"""
import logging
from typing import Any
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
from app.core.config import get_settings
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Settings object obtained once at import time; this module reads the Qdrant URL,
# the collection-name prefix and the single-vector size from it.
settings = get_settings()
class QdrantClient:
    """
    [AC-AISVC-10, AC-AISVC-59] Qdrant client with tenant-isolated collection management.

    Collection naming conventions (the leading part is
    ``settings.qdrant_collection_prefix``, shown here as ``kb_``):
    - Legacy (single KB): kb_{tenantId}
    - Multi-KB: kb_{tenantId}_{kbId}

    Supports multi-dimensional vectors (256/512/768) for Matryoshka retrieval.
    """

    # Named vectors created for multi-vector collections: vector name -> dimension.
    _MULTI_VECTOR_SIZES = {"full": 768, "dim_256": 256, "dim_512": 512}
    # Lifetime (seconds) of the cached per-tenant collection list.
    _COLLECTIONS_CACHE_TTL = 300

    def __init__(self):
        """Read naming/vector settings; the Qdrant connection itself is created lazily."""
        self._client: AsyncQdrantClient | None = None
        self._collection_prefix = settings.qdrant_collection_prefix
        self._vector_size = settings.qdrant_vector_size

    async def get_client(self) -> AsyncQdrantClient:
        """Get or create Qdrant client instance."""
        if self._client is None:
            self._client = AsyncQdrantClient(url=settings.qdrant_url)
            logger.info(f"[AC-AISVC-10] Qdrant client initialized: {settings.qdrant_url}")
        return self._client

    async def close(self) -> None:
        """Close Qdrant client connection (no-op when none was opened)."""
        if self._client is not None:
            await self._client.close()
            self._client = None
            logger.info("Qdrant client connection closed")

    # ------------------------------------------------------------------
    # Collection naming
    # ------------------------------------------------------------------

    def _tenant_prefix(self, tenant_id: str) -> str:
        """Return the collection-name stem shared by all collections of a tenant."""
        # '@' is replaced so that e-mail style tenant ids yield valid collection names.
        return f"{self._collection_prefix}{tenant_id.replace('@', '_')}"

    @staticmethod
    def _is_tenant_collection(name: str, prefix: str) -> bool:
        """
        Tell whether collection ``name`` follows the naming scheme for ``prefix``.

        Only the legacy name (exactly ``prefix``) or a KB name (``prefix`` + "_" + kbId)
        qualifies. A bare ``startswith(prefix)`` would also match other tenants whose
        id merely begins with the same characters (e.g. "kb_abcd" for tenant "abc").
        Tenant ids that themselves contain "_" can still be ambiguous with KB
        suffixes; that cannot be resolved from the name alone.
        """
        return name == prefix or name.startswith(f"{prefix}_")

    def get_collection_name(self, tenant_id: str) -> str:
        """
        [AC-AISVC-10] Get legacy collection name for a tenant.

        Naming convention: kb_{tenantId}
        Replaces @ with _ to ensure valid collection names.

        Note: This is kept for backward compatibility.
        For multi-KB, use get_kb_collection_name() instead.
        """
        return self._tenant_prefix(tenant_id)

    def get_kb_collection_name(self, tenant_id: str, kb_id: str | None = None) -> str:
        """
        [AC-AISVC-59, AC-AISVC-63] Get collection name for a specific knowledge base.

        Naming convention:
        - If kb_id is None, "" or "default": kb_{tenantId} (legacy format for backward compatibility)
        - Otherwise: kb_{tenantId}_{kbId}, where kbId is the first 8 characters of
          kb_id after replacing "-" with "_"

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID (optional, defaults to legacy naming)

        Returns:
            Collection name for the knowledge base
        """
        stem = self._tenant_prefix(tenant_id)
        if not kb_id or kb_id == "default":
            return stem
        # NOTE: only the first 8 characters are kept, so KB ids sharing that prefix
        # map to the same collection. Kept as-is because existing collections
        # already use these names.
        safe_kb_id = kb_id.replace('-', '_')[:8]
        return f"{stem}_{safe_kb_id}"

    # ------------------------------------------------------------------
    # Collection-list cache helpers
    # ------------------------------------------------------------------

    async def _get_cache_redis(self) -> Any:
        """
        Return the metadata cache service's Redis client, or None when the cache
        is unavailable or disabled. May raise; callers treat the cache as best-effort.
        """
        # Imported lazily, as in the original code (presumably to avoid an import cycle).
        from app.services.metadata_cache_service import get_metadata_cache_service

        cache_service = await get_metadata_cache_service()
        redis_client = await cache_service._get_redis()
        if redis_client and cache_service._enabled:
            return redis_client
        return None

    async def _invalidate_collections_cache(self, tenant_id: str) -> None:
        """Best-effort removal of the cached collection list after a create/delete."""
        try:
            redis_client = await self._get_cache_redis()
            if redis_client is not None:
                await redis_client.delete(f"collections:{tenant_id}")
        except Exception as e:
            # Never let cache trouble fail the actual collection operation.
            logger.warning(f"[AC-AISVC-10] Cache invalidate error: {e}")

    # ------------------------------------------------------------------
    # Collection lifecycle
    # ------------------------------------------------------------------

    def _build_vectors_config(self, use_multi_vector: bool) -> Any:
        """Build the ``vectors_config`` for a new collection (named vectors or a single vector)."""
        if use_multi_vector:
            return {
                name: VectorParams(size=size, distance=Distance.COSINE)
                for name, size in self._MULTI_VECTOR_SIZES.items()
            }
        return VectorParams(size=self._vector_size, distance=Distance.COSINE)

    async def ensure_collection_exists(self, tenant_id: str, use_multi_vector: bool = True) -> bool:
        """
        [AC-AISVC-10] Ensure collection exists for tenant (legacy single-KB mode).
        Supports multi-dimensional vectors for Matryoshka retrieval.

        Args:
            tenant_id: Tenant identifier
            use_multi_vector: Create named vectors (full/dim_256/dim_512) instead of
                a single vector of the configured size

        Returns:
            True if the collection exists or was created, False on error
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)
        try:
            if not await client.collection_exists(collection_name):
                await client.create_collection(
                    collection_name=collection_name,
                    vectors_config=self._build_vectors_config(use_multi_vector),
                )
                logger.info(
                    f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id} "
                    f"with multi_vector={use_multi_vector}"
                )
                # The cached collection list no longer reflects reality.
                await self._invalidate_collections_cache(tenant_id)
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error ensuring collection: {e}")
            return False

    async def ensure_kb_collection_exists(
        self,
        tenant_id: str,
        kb_id: str | None = None,
        use_multi_vector: bool = True,
    ) -> bool:
        """
        [AC-AISVC-59] Ensure collection exists for a specific knowledge base.

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID (optional, defaults to legacy naming)
            use_multi_vector: Whether to use multi-dimensional vectors

        Returns:
            True if collection exists or was created successfully, False on error
        """
        client = await self.get_client()
        collection_name = self.get_kb_collection_name(tenant_id, kb_id)
        try:
            if not await client.collection_exists(collection_name):
                await client.create_collection(
                    collection_name=collection_name,
                    vectors_config=self._build_vectors_config(use_multi_vector),
                )
                logger.info(
                    f"[AC-AISVC-59] Created KB collection: {collection_name} for tenant={tenant_id}, kb_id={kb_id} "
                    f"with multi_vector={use_multi_vector}"
                )
                # The cached collection list no longer reflects reality.
                await self._invalidate_collections_cache(tenant_id)
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-59] Error ensuring KB collection: {e}")
            return False

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    async def upsert_vectors(
        self,
        tenant_id: str,
        points: list[PointStruct],
        kb_id: str | None = None,
    ) -> bool:
        """
        [AC-AISVC-10, AC-AISVC-63] Upsert vectors into tenant's collection.

        Args:
            tenant_id: Tenant identifier
            points: List of PointStruct to upsert
            kb_id: Knowledge base ID (optional, uses legacy naming if not provided)

        Returns:
            True on success, False if the upsert raised
        """
        client = await self.get_client()
        collection_name = self.get_kb_collection_name(tenant_id, kb_id)
        try:
            await client.upsert(
                collection_name=collection_name,
                points=points,
            )
            logger.info(
                f"[AC-AISVC-10] Upserted {len(points)} vectors for tenant={tenant_id}, kb_id={kb_id}"
            )
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error upserting vectors: {e}")
            return False

    async def upsert_multi_vector(
        self,
        tenant_id: str,
        points: list[dict[str, Any]],
        kb_id: str | None = None,
    ) -> bool:
        """
        Upsert points with multi-dimensional vectors.

        Args:
            tenant_id: Tenant identifier
            points: List of points with format:
                {
                    "id": str | int,
                    "vector": {
                        "full": [768 floats],
                        "dim_256": [256 floats],
                        "dim_512": [512 floats],
                    },
                    "payload": dict
                }
            kb_id: Knowledge base ID (optional, uses legacy naming if not provided)

        Returns:
            True on success, False if conversion or the upsert raised
        """
        client = await self.get_client()
        collection_name = self.get_kb_collection_name(tenant_id, kb_id)
        try:
            # Built inside the try so a malformed point (e.g. missing "id") yields False.
            qdrant_points = [
                PointStruct(
                    id=p["id"],
                    vector=p["vector"],
                    payload=p.get("payload", {}),
                )
                for p in points
            ]
            await client.upsert(
                collection_name=collection_name,
                points=qdrant_points,
            )
            logger.info(
                f"[RAG-OPT] Upserted {len(points)} multi-vector points for tenant={tenant_id}, kb_id={kb_id}"
            )
            return True
        except Exception as e:
            logger.error(f"[RAG-OPT] Error upserting multi-vectors: {e}")
            return False

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    async def _query_points_with_fallback(
        self,
        client: AsyncQdrantClient,
        collection_name: str,
        query_vector: list[float],
        *,
        vector_name: str,
        limit: int,
        with_vectors: bool,
        score_threshold: float | None,
        query_filter: Any = None,
        log_tag: str = "[AC-AISVC-10]",
    ) -> Any:
        """
        Query one collection by named vector, retrying without a vector name when
        the error message suggests the collection has no such named vector
        (single-vector collection). Any other error is re-raised.
        """
        common = {
            "collection_name": collection_name,
            "query": query_vector,
            "limit": limit,
            "with_vectors": with_vectors,
            "score_threshold": score_threshold,
            "query_filter": query_filter,
        }
        try:
            return await client.query_points(using=vector_name, **common)
        except Exception as e:
            message = str(e)
            lowered = message.lower()
            # Heuristic on the error text; there is no dedicated exception type used here.
            if not (
                "vector name" in lowered
                or "Not existing vector" in message
                or "using" in lowered
            ):
                raise
            logger.info(
                f"{log_tag} Collection {collection_name} doesn't have vector named '{vector_name}', "
                f"trying without vector name (single-vector mode)"
            )
            return await client.query_points(**common)

    @staticmethod
    def _to_hit(result: Any, with_vectors: bool, **extra: Any) -> dict[str, Any]:
        """Convert a Qdrant scored point into the plain dict returned to callers."""
        hit = {
            "id": str(result.id),
            "score": result.score,
            "payload": result.payload or {},
            **extra,
        }
        if with_vectors and result.vector:
            hit["vector"] = result.vector
        return hit

    async def search(
        self,
        tenant_id: str,
        query_vector: list[float],
        limit: int = 5,
        score_threshold: float | None = None,
        vector_name: str = "full",
        with_vectors: bool = False,
        metadata_filter: dict[str, Any] | None = None,
        kb_ids: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """
        [AC-AISVC-10] Search vectors in tenant's collections.
        Returns results with score >= score_threshold if specified.
        Searches all collections for the tenant (multi-KB support) unless kb_ids
        restricts the search.

        Args:
            tenant_id: Tenant identifier
            query_vector: Query vector for similarity search
            limit: Maximum number of results per collection; the merged result
                   list is truncated to this size as well
            score_threshold: Minimum score threshold for results
            vector_name: Name of the vector to search (for multi-vector collections)
                        Default is "full" for 768-dim vectors in Matryoshka setup.
            with_vectors: Whether to return vectors in results (for two-stage reranking)
            metadata_filter: Optional metadata filter to apply during search
            kb_ids: Optional list of knowledge base IDs to restrict search to specific KBs

        Returns:
            Hits as dicts with "id", "score", "payload", "collection" (and "vector"
            when requested), sorted by descending score. Per-collection errors are
            logged and skipped.
        """
        client = await self.get_client()
        logger.info(
            f"[AC-AISVC-10] Starting search: tenant_id={tenant_id}, "
            f"limit={limit}, score_threshold={score_threshold}, vector_dim={len(query_vector)}, vector_name={vector_name}"
        )

        # Build the Qdrant filter.
        qdrant_filter = None
        if metadata_filter:
            logger.info(f"[AC-AISVC-10] Metadata filter: {metadata_filter}")
            qdrant_filter = self._build_qdrant_filter(metadata_filter)
            logger.info(f"[AC-AISVC-10] Qdrant filter: {qdrant_filter}")

        if kb_ids:
            # Explicit KBs: resolve names directly instead of filtering against the
            # cached collection list, which may be up to 5 minutes stale and would
            # hide a freshly created KB. Existence is verified per collection below.
            # dict.fromkeys de-duplicates ids that map to the same collection.
            collection_names = list(
                dict.fromkeys(self.get_kb_collection_name(tenant_id, kb_id) for kb_id in kb_ids)
            )
            logger.info(f"[AC-AISVC-10] Restricted to {len(collection_names)} KB collections: {collection_names}")
        else:
            # All collections of this tenant.
            collection_names = await self._get_tenant_collections(client, tenant_id)
            logger.info(f"[AC-AISVC-10] Will search in {len(collection_names)} collections: {collection_names}")

        all_hits: list[dict[str, Any]] = []
        for collection_name in collection_names:
            try:
                if not await client.collection_exists(collection_name):
                    logger.warning(f"[AC-AISVC-10] Collection {collection_name} does not exist")
                    continue

                results = await self._query_points_with_fallback(
                    client,
                    collection_name,
                    query_vector,
                    vector_name=vector_name,
                    limit=limit,
                    with_vectors=with_vectors,
                    score_threshold=score_threshold,
                    query_filter=qdrant_filter,
                    log_tag="[AC-AISVC-10]",
                )
                logger.info(
                    f"[AC-AISVC-10] Collection {collection_name} returned {len(results.points)} raw results"
                )

                # Tag every hit with the collection it came from.
                hits = [
                    self._to_hit(result, with_vectors, collection=collection_name)
                    for result in results.points
                ]
                all_hits.extend(hits)

                if hits:
                    logger.info(
                        f"[AC-AISVC-10] Search in collection {collection_name}: {len(hits)} results for tenant={tenant_id}"
                    )
                else:
                    logger.warning(
                        f"[AC-AISVC-10] Collection {collection_name} returned no hits (filtered or empty)"
                    )
            except Exception as e:
                # One failing collection must not abort the whole search.
                logger.warning(
                    f"[AC-AISVC-10] Collection {collection_name} not found or error: {e}"
                )

        # Sort by score and keep the top results.
        all_hits.sort(key=lambda hit: hit["score"], reverse=True)
        all_hits = all_hits[:limit]

        logger.info(
            f"[AC-AISVC-10] Search returned {len(all_hits)} total results for tenant={tenant_id}"
        )
        if not all_hits:
            logger.warning(
                f"[AC-AISVC-10] No results found! tenant={tenant_id}, "
                f"collections_tried={collection_names}, limit={limit}"
            )
        return all_hits

    async def _get_tenant_collections(
        self,
        client: AsyncQdrantClient,
        tenant_id: str,
    ) -> list[str]:
        """
        Get all collections that belong to the given tenant.

        The Redis cache is consulted first; on a miss the list is read from Qdrant
        and cached for 5 minutes. Cache problems are logged and ignored.

        Args:
            client: Qdrant client
            tenant_id: Tenant ID

        Returns:
            List of collection names. If listing from Qdrant fails, a single-element
            list with the tenant's legacy collection name is returned.
        """
        import json
        import time

        start_time = time.time()
        cache_key = f"collections:{tenant_id}"
        prefix = self._tenant_prefix(tenant_id)

        # 1. Try the cache. Everything is inside the try so that an unavailable
        #    cache service degrades to the Qdrant lookup instead of failing the search.
        try:
            redis_client = await self._get_cache_redis()
            if redis_client is not None:
                cached = await redis_client.get(cache_key)
                if cached:
                    # Re-apply the ownership rule so entries cached under a looser
                    # rule cannot leak another tenant's collections.
                    collections = [
                        name for name in json.loads(cached)
                        if isinstance(name, str) and self._is_tenant_collection(name, prefix)
                    ]
                    logger.info(
                        f"[AC-AISVC-10] Cache hit: Found {len(collections)} collections "
                        f"for tenant={tenant_id} in {(time.time() - start_time)*1000:.2f}ms"
                    )
                    return collections
        except Exception as e:
            logger.warning(f"[AC-AISVC-10] Cache get error: {e}")

        # 2. Query Qdrant.
        try:
            response = await client.get_collections()
            # Sorted by name for a stable search order.
            tenant_collections = sorted(
                c.name for c in response.collections
                if self._is_tenant_collection(c.name, prefix)
            )

            db_time = (time.time() - start_time) * 1000
            logger.info(
                f"[AC-AISVC-10] Found {len(tenant_collections)} collections from Qdrant "
                f"for tenant={tenant_id} in {db_time:.2f}ms: {tenant_collections}"
            )

            # 3. Cache the result with a 5-minute TTL.
            try:
                redis_client = await self._get_cache_redis()
                if redis_client is not None:
                    await redis_client.setex(
                        cache_key,
                        self._COLLECTIONS_CACHE_TTL,
                        json.dumps(tenant_collections),
                    )
                    logger.info(f"[AC-AISVC-10] Cached collections for tenant={tenant_id}")
            except Exception as e:
                logger.warning(f"[AC-AISVC-10] Cache set error: {e}")

            return tenant_collections
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Failed to get collections for tenant={tenant_id}: {e}")
            return [self.get_collection_name(tenant_id)]

    def _build_qdrant_filter(
        self,
        metadata_filter: dict[str, Any],
    ) -> Any:
        """
        Build a Qdrant filter from metadata conditions.

        Each key is matched under the payload's nested ``metadata`` object
        (key "grade" becomes field "metadata.grade"); all conditions must hold.

        Args:
            metadata_filter: Metadata conditions, e.g. {"grade": "grade 3", "subject": "math"}

        Returns:
            Qdrant Filter object, or None when there are no conditions
        """
        from qdrant_client.models import FieldCondition, Filter, MatchValue

        must_conditions = [
            FieldCondition(
                key=f"metadata.{key}",
                match=MatchValue(value=value),
            )
            for key, value in metadata_filter.items()
        ]
        return Filter(must=must_conditions) if must_conditions else None

    async def delete_collection(self, tenant_id: str) -> bool:
        """
        [AC-AISVC-10] Delete tenant's (legacy) collection.
        Used when tenant is removed.

        Returns:
            True on success, False if the deletion raised
        """
        client = await self.get_client()
        collection_name = self.get_collection_name(tenant_id)
        try:
            await client.delete_collection(collection_name=collection_name)
            logger.info(f"[AC-AISVC-10] Deleted collection: {collection_name}")
            # The cached collection list no longer reflects reality.
            await self._invalidate_collections_cache(tenant_id)
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-10] Error deleting collection: {e}")
            return False

    async def delete_kb_collection(self, tenant_id: str, kb_id: str) -> bool:
        """
        [AC-AISVC-62] Delete a specific knowledge base's collection.

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID

        Returns:
            True if the collection was deleted or did not exist, False on error
        """
        client = await self.get_client()
        collection_name = self.get_kb_collection_name(tenant_id, kb_id)
        try:
            if await client.collection_exists(collection_name):
                await client.delete_collection(collection_name=collection_name)
                logger.info(f"[AC-AISVC-62] Deleted KB collection: {collection_name} for kb_id={kb_id}")
                # The cached collection list no longer reflects reality.
                await self._invalidate_collections_cache(tenant_id)
            else:
                logger.info(f"[AC-AISVC-62] KB collection {collection_name} does not exist, nothing to delete")
            return True
        except Exception as e:
            logger.error(f"[AC-AISVC-62] Error deleting KB collection: {e}")
            return False

    async def search_kb(
        self,
        tenant_id: str,
        query_vector: list[float],
        kb_ids: list[str] | None = None,
        limit: int = 5,
        score_threshold: float | None = None,
        vector_name: str = "full",
        with_vectors: bool = False,
    ) -> list[dict[str, Any]]:
        """
        [AC-AISVC-64] Search vectors across multiple knowledge base collections.

        Args:
            tenant_id: Tenant identifier
            query_vector: Query vector for similarity search
            kb_ids: List of knowledge base IDs to search. If None or empty, the call
                is delegated to search(), which covers all of the tenant's collections.
            limit: Maximum number of results per collection; the merged result
                   list is truncated to this size as well
            score_threshold: Minimum score threshold for results
            vector_name: Name of the vector to search
            with_vectors: Whether to return vectors in results

        Returns:
            Combined results from all collections sorted by descending score. Hits
            carry "kb_id" (hits from the delegated search() carry "collection" instead).
        """
        if not kb_ids:
            return await self.search(
                tenant_id=tenant_id,
                query_vector=query_vector,
                limit=limit,
                score_threshold=score_threshold,
                vector_name=vector_name,
                with_vectors=with_vectors,
            )

        client = await self.get_client()
        logger.info(
            f"[AC-AISVC-64] Starting multi-KB search: tenant_id={tenant_id}, "
            f"kb_ids={kb_ids}, limit={limit}, score_threshold={score_threshold}"
        )

        all_hits: list[dict[str, Any]] = []
        searched: set[str] = set()
        for kb_id in kb_ids:
            collection_name = self.get_kb_collection_name(tenant_id, kb_id)
            # Different ids can resolve to the same collection; searching it twice
            # would only produce duplicate hits.
            if collection_name in searched:
                continue
            searched.add(collection_name)

            try:
                if not await client.collection_exists(collection_name):
                    logger.warning(f"[AC-AISVC-64] Collection {collection_name} does not exist")
                    continue

                results = await self._query_points_with_fallback(
                    client,
                    collection_name,
                    query_vector,
                    vector_name=vector_name,
                    limit=limit,
                    with_vectors=with_vectors,
                    score_threshold=score_threshold,
                    log_tag="[AC-AISVC-64]",
                )
                all_hits.extend(
                    self._to_hit(result, with_vectors, kb_id=kb_id)
                    for result in results.points
                )
                logger.info(
                    f"[AC-AISVC-64] Collection {collection_name} returned {len(results.points)} results"
                )
            except Exception as e:
                # One failing collection must not abort the whole search.
                logger.warning(f"[AC-AISVC-64] Error searching collection {collection_name}: {e}")
                continue

        all_hits = sorted(all_hits, key=lambda hit: hit["score"], reverse=True)[:limit]
        logger.info(
            f"[AC-AISVC-64] Multi-KB search returned {len(all_hits)} total results"
        )
        return all_hits
# Module-wide QdrantClient wrapper; stays None until get_qdrant_client() is first called.
_qdrant_client: QdrantClient | None = None


async def get_qdrant_client() -> QdrantClient:
    """Return the shared QdrantClient wrapper, constructing it on first use."""
    global _qdrant_client
    if _qdrant_client is not None:
        return _qdrant_client
    _qdrant_client = QdrantClient()
    return _qdrant_client
async def close_qdrant_client() -> None:
    """Close the shared QdrantClient wrapper, if one exists, and forget it."""
    global _qdrant_client
    if not _qdrant_client:
        # Nothing was ever created (or it was already closed).
        return
    await _qdrant_client.close()
    _qdrant_client = None