368 lines
12 KiB
Python
368 lines
12 KiB
Python
"""
|
||
Embedding provider factory and configuration manager.
|
||
[AC-AISVC-30, AC-AISVC-31] Factory pattern for dynamic provider loading.
|
||
|
||
Design reference: progress.md Section 7.1 - Architecture
|
||
- EmbeddingProviderFactory: creates providers based on config
|
||
- EmbeddingConfigManager: manages configuration with hot-reload support
|
||
"""
|
||
|
||
import json
|
||
import logging
|
||
from pathlib import Path
|
||
from typing import Any, Type
|
||
|
||
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
|
||
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
|
||
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
|
||
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
|
||
|
||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||
|
||
# Location of the persisted provider/config JSON. The path is relative, so it
# resolves against the process's current working directory.
EMBEDDING_CONFIG_FILE = Path("config/embedding_config.json")
|
||
|
||
|
||
class EmbeddingProviderFactory:
    """
    Factory for creating embedding providers.

    [AC-AISVC-30] Supports dynamic loading based on configuration.

    Provider classes live in a class-level registry keyed by a short
    identifier. The registry is a class attribute, so a provider added through
    ``register_provider`` is visible to every user of this class.
    """

    # Registry: provider identifier -> provider class.
    _providers: dict[str, Type[EmbeddingProvider]] = {
        "ollama": OllamaEmbeddingProvider,
        "openai": OpenAIEmbeddingProvider,
        "nomic": NomicEmbeddingProvider,
    }

    # UI labels for the built-in providers. Providers registered at runtime
    # have no entry here and fall back to their identifier / an empty string.
    _display_names: dict[str, str] = {
        "ollama": "Ollama 本地模型",
        "openai": "OpenAI Embedding",
        "nomic": "Nomic Embed (优化版)",
    }

    _descriptions: dict[str, str] = {
        "ollama": "使用 Ollama 运行的本地嵌入模型,支持 nomic-embed-text 等开源模型",
        "openai": "使用 OpenAI 官方 Embedding API,支持 text-embedding-3 系列模型",
        "nomic": "Nomic-embed-text v1.5 优化版,支持任务前缀和 Matryoshka 维度截断,专为RAG优化",
    }

    @classmethod
    def register_provider(cls, name: str, provider_class: Type[EmbeddingProvider]) -> None:
        """
        Register a new embedding provider.

        [AC-AISVC-30] Allows runtime registration of providers.

        Args:
            name: Identifier under which the provider is registered. An
                existing entry with the same name is replaced.
            provider_class: Provider class to instantiate for ``name``.
        """
        cls._providers[name] = provider_class
        logger.info("Registered embedding provider: %s", name)

    @classmethod
    def get_available_providers(cls) -> list[str]:
        """
        Get list of available provider names.

        [AC-AISVC-38] Returns registered provider identifiers.

        Returns:
            Registered identifiers, in registration order.
        """
        return list(cls._providers)

    @staticmethod
    def _build_config_schema(raw_schema: dict[str, Any]) -> dict[str, Any]:
        """Convert a provider's per-field schema into a JSON-Schema-style object."""
        properties: dict[str, Any] = {}
        required: list[str] = []

        for key, field in raw_schema.items():
            prop: dict[str, Any] = {
                "type": field.get("type", "string"),
                "title": field.get("title", key),
                "description": field.get("description", ""),
                "default": field.get("default"),
            }

            # An empty/missing enum is omitted rather than emitted as [].
            enum_values = field.get("enum")
            if enum_values:
                prop["enum"] = enum_values

            # Bounds are compared against None so that a bound of 0 survives.
            for bound in ("minimum", "maximum"):
                bound_value = field.get(bound)
                if bound_value is not None:
                    prop[bound] = bound_value

            if field.get("required"):
                required.append(key)

            properties[key] = prop

        schema: dict[str, Any] = {
            "type": "object",
            "properties": properties,
        }
        # "required" is only present when at least one field is required.
        if required:
            schema["required"] = required
        return schema

    @classmethod
    def get_provider_info(cls, name: str) -> dict[str, Any]:
        """
        Get provider information including config schema.

        [AC-AISVC-38] Returns provider metadata.

        Args:
            name: Registered provider identifier.

        Returns:
            Dict with ``name``, ``display_name``, ``description`` and
            ``config_schema`` keys.

        Raises:
            EmbeddingException: If ``name`` is not registered.
        """
        if name not in cls._providers:
            raise EmbeddingException(
                f"Unknown provider: {name}",
                provider="factory"
            )

        provider_class = cls._providers[name]
        # The schema is read from an instance created without running
        # __init__, so no provider configuration is needed here.
        # NOTE(review): this relies on get_config_schema() not reading
        # instance state set up by __init__ - confirm for new providers.
        temp_instance = provider_class.__new__(provider_class)
        raw_schema = temp_instance.get_config_schema()

        return {
            "name": name,
            "display_name": cls._display_names.get(name, name),
            "description": cls._descriptions.get(name, ""),
            "config_schema": cls._build_config_schema(raw_schema),
        }

    @classmethod
    def create_provider(
        cls,
        name: str,
        config: dict[str, Any],
    ) -> EmbeddingProvider:
        """
        Create an embedding provider instance.

        [AC-AISVC-30] Creates provider based on configuration.

        Args:
            name: Provider identifier (e.g., "ollama", "openai")
            config: Provider-specific configuration, passed to the provider
                class as keyword arguments.

        Returns:
            Configured EmbeddingProvider instance

        Raises:
            EmbeddingException: If provider is unknown or configuration is invalid
        """
        if name not in cls._providers:
            raise EmbeddingException(
                f"Unknown embedding provider: {name}. "
                f"Available: {cls.get_available_providers()}",
                provider="factory"
            )

        provider_class = cls._providers[name]

        try:
            instance = provider_class(**config)
        except Exception as e:
            # NOTE(review): ``details`` carries the raw config, which may
            # include credentials for some providers - confirm it is not
            # exposed verbatim to API clients or logs.
            raise EmbeddingException(
                f"Failed to create provider '{name}': {e}",
                provider="factory",
                details={"config": config}
            ) from e

        logger.info("Created embedding provider: %s", name)
        return instance
|
||
|
||
|
||
class EmbeddingConfigManager:
    """
    Manager for embedding configuration.

    [AC-AISVC-31] Supports hot-reload of configuration with persistence.

    Holds the active provider name and its configuration, builds the provider
    instance lazily on first use, and writes every successful update to
    ``EMBEDDING_CONFIG_FILE``.
    """

    def __init__(self, default_provider: str = "ollama", default_config: dict[str, Any] | None = None):
        """
        Initialise with defaults, then overlay any configuration saved on disk.

        Args:
            default_provider: Provider used when no saved configuration exists.
            default_config: Config for the default provider. When omitted or
                empty, a local Ollama ``nomic-embed-text`` config is used.
        """
        self._default_provider = default_provider
        # Copy so that later mutation of the caller's dict cannot alter the defaults.
        self._default_config = dict(default_config) if default_config else {
            "base_url": "http://localhost:11434",
            "model": "nomic-embed-text",
            "dimension": 768,
        }
        self._provider_name = default_provider
        self._config = self._default_config.copy()
        self._provider: EmbeddingProvider | None = None

        self._load_from_file()

    def _load_from_file(self) -> None:
        """
        Load configuration from file if exists.

        Loading is best-effort: a missing file is ignored silently, and any
        other problem (unreadable file, invalid JSON, unexpected structure)
        is logged as a warning while the current values are kept.
        """
        try:
            with open(EMBEDDING_CONFIG_FILE, 'r', encoding='utf-8') as f:
                saved = json.load(f)
        except FileNotFoundError:
            return
        except Exception as e:
            logger.warning("Failed to load embedding config from file: %s", e)
            return

        if not isinstance(saved, dict):
            logger.warning(
                "Failed to load embedding config from file: %s",
                "top-level JSON value is not an object",
            )
            return

        provider_name = saved.get("provider", self._default_provider)
        config = saved.get("config", self._default_config.copy())

        # Validate before assigning so a malformed file cannot leave the
        # manager with a config that breaks later .copy() / **config usage.
        if not isinstance(provider_name, str) or not isinstance(config, dict):
            logger.warning(
                "Failed to load embedding config from file: %s",
                "'provider' must be a string and 'config' must be an object",
            )
            return

        self._provider_name = provider_name
        self._config = config
        logger.info("Loaded embedding config from file: provider=%s", self._provider_name)

    def _save_to_file(self) -> None:
        """
        Save configuration to file.

        Failures are logged and not raised, so persistence problems never
        fail the in-memory update that triggered the save.
        """
        try:
            # Serialise before touching the file: a config that is not
            # JSON-serialisable must not truncate the existing file.
            payload = json.dumps(
                {
                    "provider": self._provider_name,
                    "config": self._config,
                },
                indent=2,
                ensure_ascii=False,
            )
            EMBEDDING_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
            EMBEDDING_CONFIG_FILE.write_text(payload, encoding='utf-8')
            logger.info("Saved embedding config to file: provider=%s", self._provider_name)
        except Exception as e:
            logger.error("Failed to save embedding config to file: %s", e)

    def get_provider_name(self) -> str:
        """Get current provider name."""
        return self._provider_name

    def get_config(self) -> dict[str, Any]:
        """Get current configuration (a shallow copy)."""
        return self._config.copy()

    def get_full_config(self) -> dict[str, Any]:
        """
        Get full configuration including provider name.

        [AC-AISVC-39] Returns complete configuration for API response.

        Returns:
            Dict with ``provider`` (name) and ``config`` (shallow copy).
        """
        return {
            "provider": self._provider_name,
            "config": self._config.copy(),
        }

    async def get_provider(self) -> EmbeddingProvider:
        """
        Get or create the embedding provider.

        [AC-AISVC-29] Returns configured provider instance.

        The instance is created on first call and reused afterwards until the
        configuration is updated or the manager is closed.

        Raises:
            EmbeddingException: If the provider cannot be created.
        """
        if self._provider is None:
            self._provider = EmbeddingProviderFactory.create_provider(
                self._provider_name,
                self._config
            )
        return self._provider

    async def update_config(
        self,
        provider: str,
        config: dict[str, Any],
    ) -> bool:
        """
        Update embedding configuration.

        [AC-AISVC-31, AC-AISVC-40] Supports hot-reload with persistence.

        The new provider is built first; manager state is only changed once
        that succeeded and the previous provider was closed, so a failure
        leaves the previous name and config in effect.

        Args:
            provider: New provider name
            config: New provider configuration

        Returns:
            True if update was successful

        Raises:
            EmbeddingException: If configuration is invalid
        """
        new_provider_instance: EmbeddingProvider | None = None

        try:
            new_provider_instance = EmbeddingProviderFactory.create_provider(
                provider,
                config
            )

            if self._provider:
                await self._provider.close()

        except Exception as e:
            # The replacement was created but will never be used: release it
            # so its resources are not leaked.
            if new_provider_instance is not None:
                try:
                    await new_provider_instance.close()
                except Exception as close_err:
                    logger.warning("Failed to close unused embedding provider: %s", close_err)

            raise EmbeddingException(
                f"Failed to update config: {e}",
                provider="config_manager",
                details={"provider": provider, "config": config}
            ) from e

        self._provider_name = provider
        # Shallow copy: later mutation of the caller's dict must not change
        # the live configuration behind the manager's back.
        self._config = dict(config)
        self._provider = new_provider_instance

        # Persistence is best-effort; _save_to_file logs instead of raising.
        self._save_to_file()

        logger.info("Updated embedding config: provider=%s", provider)
        return True

    async def test_connection(
        self,
        test_text: str = "这是一个测试文本",
        provider: str | None = None,
        config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Test embedding connection.

        [AC-AISVC-41] Tests provider connectivity.

        A throwaway provider is created for the test and always closed
        afterwards; the manager's active provider is not touched. Errors are
        reported in the returned dict rather than raised.

        Args:
            test_text: Text to embed for testing
            provider: Provider to test (uses current if None or empty)
            config: Config to test (uses current if None or empty)

        Returns:
            Dict with test results including success, dimension, latency
        """
        import time

        test_provider_name = provider or self._provider_name
        test_config = config or self._config

        test_provider: EmbeddingProvider | None = None
        try:
            test_provider = EmbeddingProviderFactory.create_provider(
                test_provider_name,
                test_config
            )

            start_time = time.perf_counter()
            embedding = await test_provider.embed(test_text)
            latency_ms = (time.perf_counter() - start_time) * 1000

            return {
                "success": True,
                "dimension": len(embedding),
                "latency_ms": latency_ms,
                "message": f"连接成功,向量维度: {len(embedding)}",
            }

        except Exception as e:
            return {
                "success": False,
                "dimension": 0,
                "latency_ms": 0,
                "error": str(e),
                "message": f"连接失败: {e}",
            }

        finally:
            # Close on every path so a failed embed() does not leak the
            # temporary provider. A close failure is only logged: it must
            # not mask the test result computed above.
            if test_provider is not None:
                try:
                    await test_provider.close()
                except Exception as close_err:
                    logger.warning("Failed to close test embedding provider: %s", close_err)

    async def close(self) -> None:
        """Close the current provider, if one has been created."""
        if self._provider:
            await self._provider.close()
            self._provider = None
|
||
|
||
|
||
# Lazily created singleton; populated on the first call to
# get_embedding_config_manager().
_embedding_config_manager: EmbeddingConfigManager | None = None
|
||
|
||
|
||
def get_embedding_config_manager() -> EmbeddingConfigManager:
    """
    Return the process-wide embedding config manager, creating it on first use.

    [AC-AISVC-31] Singleton pattern for configuration management.

    On first call the manager is built with "nomic" as default provider and a
    default config taken from the application settings.
    """
    global _embedding_config_manager

    # Fast path: the singleton already exists.
    if _embedding_config_manager is not None:
        return _embedding_config_manager

    # Imported here rather than at module level, as in the original code.
    from app.core.config import get_settings

    app_settings = get_settings()
    fallback_config = {
        "base_url": app_settings.ollama_base_url,
        "model": app_settings.ollama_embedding_model,
        "dimension": app_settings.qdrant_vector_size,
    }

    _embedding_config_manager = EmbeddingConfigManager(
        default_provider="nomic",
        default_config=fallback_config,
    )
    return _embedding_config_manager
|
||
|
||
|
||
async def get_embedding_provider() -> EmbeddingProvider:
    """
    Return the provider held by the global embedding config manager.

    [AC-AISVC-29] Convenience function for getting provider.
    """
    return await get_embedding_config_manager().get_provider()
|