"""
Embedding provider factory and configuration manager.
[AC-AISVC-30, AC-AISVC-31] Factory pattern for dynamic provider loading.

Design reference: progress.md Section 7.1 - Architecture
- EmbeddingProviderFactory: creates providers based on config
- EmbeddingConfigManager: manages configuration with hot-reload support
"""

import json
import logging
import time
from pathlib import Path
from typing import Any

import redis

from app.core.config import get_settings
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider

logger = logging.getLogger(__name__)

# Persistence locations for the active embedding configuration.
EMBEDDING_CONFIG_FILE = Path("config/embedding_config.json")
EMBEDDING_CONFIG_REDIS_KEY = "ai_service:config:embedding"


class EmbeddingProviderFactory:
    """
    Factory for creating embedding providers.
    [AC-AISVC-30] Supports dynamic loading based on configuration.
    """

    # Registry mapping a provider identifier to its implementing class.
    # Shared by all callers; extended at runtime via register_provider().
    _providers: dict[str, type[EmbeddingProvider]] = {
        "ollama": OllamaEmbeddingProvider,
        "openai": OpenAIEmbeddingProvider,
        "nomic": NomicEmbeddingProvider,
    }

    @classmethod
    def register_provider(cls, name: str, provider_class: type[EmbeddingProvider]) -> None:
        """
        Register a new embedding provider.
        [AC-AISVC-30] Allows runtime registration of providers.

        An existing registration under the same name is replaced.

        Args:
            name: Provider identifier used for later lookups.
            provider_class: Class to instantiate for this identifier.
        """
        cls._providers[name] = provider_class
        logger.info(f"Registered embedding provider: {name}")

    @classmethod
    def get_available_providers(cls) -> list[str]:
        """
        Get list of available provider names.
        [AC-AISVC-38] Returns registered provider identifiers.

        Returns:
            A new list of the identifiers currently in the registry.
        """
        return list(cls._providers.keys())

    @classmethod
    def get_provider_info(cls, name: str) -> dict[str, Any]:
        """
        Get provider information including config schema.
        [AC-AISVC-38] Returns provider metadata.

        The provider's raw schema (a mapping of field name to field
        description) is converted into a JSON-Schema-like object with
        "type", "properties" and, when any field is marked required,
        "required".

        Args:
            name: Provider identifier.

        Returns:
            Dict with "name", "display_name", "description" and
            "config_schema" keys.

        Raises:
            EmbeddingException: If the provider is not registered.
        """
        if name not in cls._providers:
            raise EmbeddingException(
                f"Unknown provider: {name}",
                provider="factory"
            )

        provider_class = cls._providers[name]
        # Allocate without running __init__ so the schema can be read without
        # supplying any configuration.
        # NOTE(review): this assumes get_config_schema() does not read
        # instance state set up by __init__ - confirm against the providers.
        temp_instance = provider_class.__new__(provider_class)

        # Human-readable labels for built-in providers; providers registered
        # at runtime fall back to their identifier / an empty description.
        display_names = {
            "ollama": "Ollama 本地模型",
            "openai": "OpenAI Embedding",
            "nomic": "Nomic Embed (优化版)",
        }

        descriptions = {
            "ollama": "使用 Ollama 运行的本地嵌入模型,支持 nomic-embed-text 等开源模型",
            "openai": "使用 OpenAI 官方 Embedding API,支持 text-embedding-3 系列模型",
            "nomic": "Nomic-embed-text v1.5 优化版,支持任务前缀和 Matryoshka 维度截断,专为RAG优化",
        }

        raw_schema = temp_instance.get_config_schema()

        properties: dict[str, dict[str, Any]] = {}
        required: list[str] = []
        for key, field in raw_schema.items():
            properties[key] = {
                "type": field.get("type", "string"),
                "title": field.get("title", key),
                "description": field.get("description", ""),
                "default": field.get("default"),
            }
            # Optional constraints are only emitted when the field declares
            # them; minimum/maximum use `is not None` so that 0 is preserved.
            if field.get("enum"):
                properties[key]["enum"] = field.get("enum")
            if field.get("minimum") is not None:
                properties[key]["minimum"] = field.get("minimum")
            if field.get("maximum") is not None:
                properties[key]["maximum"] = field.get("maximum")
            if field.get("required"):
                required.append(key)

        config_schema: dict[str, Any] = {
            "type": "object",
            "properties": properties,
        }
        if required:
            config_schema["required"] = required

        return {
            "name": name,
            "display_name": display_names.get(name, name),
            "description": descriptions.get(name, ""),
            "config_schema": config_schema,
        }

    @classmethod
    def create_provider(
        cls,
        name: str,
        config: dict[str, Any],
    ) -> EmbeddingProvider:
        """
        Create an embedding provider instance.
        [AC-AISVC-30] Creates provider based on configuration.

        Args:
            name: Provider identifier (e.g., "ollama", "openai")
            config: Provider-specific configuration, passed to the provider
                class as keyword arguments.

        Returns:
            Configured EmbeddingProvider instance

        Raises:
            EmbeddingException: If provider is unknown or configuration is invalid
        """
        if name not in cls._providers:
            raise EmbeddingException(
                f"Unknown embedding provider: {name}. "
                f"Available: {cls.get_available_providers()}",
                provider="factory"
            )

        provider_class = cls._providers[name]

        try:
            instance = provider_class(**config)
        except Exception as e:
            # NOTE(review): `config` is attached verbatim to the exception
            # details; if it can carry credentials (cannot be verified from
            # this file), consider redacting it before it is logged or
            # returned to a client.
            raise EmbeddingException(
                f"Failed to create provider '{name}': {e}",
                provider="factory",
                details={"config": config}
            ) from e  # keep the original error as the explicit cause

        logger.info(f"Created embedding provider: {name}")
        return instance


class EmbeddingConfigManager:
    """
    Manager for embedding configuration.
    [AC-AISVC-31] Supports hot-reload of configuration with persistence.

    The active provider name and config are persisted to Redis (when enabled
    in settings) and to EMBEDDING_CONFIG_FILE, and restored from them on
    construction.
    """

    def __init__(self, default_provider: str = "ollama", default_config: dict[str, Any] | None = None):
        """
        Initialize the manager and restore any persisted configuration.

        Redis is consulted first and the config file second, so values found
        in the file replace values loaded from Redis.

        Args:
            default_provider: Provider name used when nothing is persisted.
            default_config: Provider config used when nothing is persisted;
                a local Ollama configuration is used when omitted or empty.
        """
        self._default_provider = default_provider
        self._default_config = default_config or {
            "base_url": "http://localhost:11434",
            "model": "nomic-embed-text",
            "dimension": 768,
        }
        self._provider_name = default_provider
        self._config = self._default_config.copy()
        # Provider instance is created lazily by get_provider().
        self._provider: EmbeddingProvider | None = None

        self._settings = get_settings()
        self._redis_client: redis.Redis | None = None

        self._load_from_redis()
        self._load_from_file()

    def _load_from_redis(self) -> None:
        """
        Load configuration from Redis if exists.

        Does nothing when Redis is disabled in settings or the key is absent.
        Any failure is logged as a warning and otherwise ignored, leaving the
        current in-memory configuration in place.
        """
        try:
            if not self._settings.redis_enabled:
                return
            self._redis_client = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            saved_raw = self._redis_client.get(EMBEDDING_CONFIG_REDIS_KEY)
            if not saved_raw:
                return
            saved = json.loads(saved_raw)
            self._provider_name = saved.get("provider", self._default_provider)
            self._config = saved.get("config", self._default_config.copy())
            logger.info(f"Loaded embedding config from Redis: provider={self._provider_name}")
        except Exception as e:
            # Best effort: persistence problems must not prevent startup.
            logger.warning(f"Failed to load embedding config from Redis: {e}")

    def _load_from_file(self) -> None:
        """
        Load configuration from file if exists.

        Any failure (unreadable file, invalid JSON, unexpected structure) is
        logged as a warning and otherwise ignored.
        """
        try:
            if EMBEDDING_CONFIG_FILE.exists():
                with open(EMBEDDING_CONFIG_FILE, encoding='utf-8') as f:
                    saved = json.load(f)
                    self._provider_name = saved.get("provider", self._default_provider)
                    self._config = saved.get("config", self._default_config.copy())
                    logger.info(f"Loaded embedding config from file: provider={self._provider_name}")
        except Exception as e:
            # Best effort: persistence problems must not prevent startup.
            logger.warning(f"Failed to load embedding config from file: {e}")

    def _save_to_redis(self) -> None:
        """
        Save configuration to Redis.

        Does nothing when Redis is disabled in settings. The client is
        created on first use if loading did not already create it. Failures
        are logged as warnings and never raised.
        """
        try:
            if not self._settings.redis_enabled:
                return
            if self._redis_client is None:
                self._redis_client = redis.from_url(
                    self._settings.redis_url,
                    encoding="utf-8",
                    decode_responses=True,
                )
            self._redis_client.set(
                EMBEDDING_CONFIG_REDIS_KEY,
                json.dumps({
                    "provider": self._provider_name,
                    "config": self._config,
                }, ensure_ascii=False),
            )
            logger.info(f"Saved embedding config to Redis: provider={self._provider_name}")
        except Exception as e:
            logger.warning(f"Failed to save embedding config to Redis: {e}")

    def _save_to_file(self) -> None:
        """
        Save configuration to file.

        Creates the parent directory when missing. Failures are logged as
        errors and never raised.
        """
        try:
            EMBEDDING_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
            with open(EMBEDDING_CONFIG_FILE, 'w', encoding='utf-8') as f:
                json.dump({
                    "provider": self._provider_name,
                    "config": self._config,
                }, f, indent=2, ensure_ascii=False)
            logger.info(f"Saved embedding config to file: provider={self._provider_name}")
        except Exception as e:
            logger.error(f"Failed to save embedding config to file: {e}")

    def get_provider_name(self) -> str:
        """Get current provider name."""
        return self._provider_name

    def get_config(self) -> dict[str, Any]:
        """Get a shallow copy of the current configuration."""
        return self._config.copy()

    def get_full_config(self) -> dict[str, Any]:
        """
        Get full configuration including provider name.
        [AC-AISVC-39] Returns complete configuration for API response.

        Returns:
            Dict with "provider" (name) and "config" (shallow copy) keys.
        """
        return {
            "provider": self._provider_name,
            "config": self._config.copy(),
        }

    async def get_provider(self) -> EmbeddingProvider:
        """
        Get or create the embedding provider.
        [AC-AISVC-29] Returns configured provider instance.

        The provider is created from the current name and config on first
        use and cached until update_config() or close() replaces/clears it.

        Raises:
            EmbeddingException: If the provider cannot be created.
        """
        if self._provider is None:
            self._provider = EmbeddingProviderFactory.create_provider(
                self._provider_name,
                self._config
            )
        return self._provider

    async def update_config(
        self,
        provider: str,
        config: dict[str, Any],
    ) -> bool:
        """
        Update embedding configuration.
        [AC-AISVC-31, AC-AISVC-40] Supports hot-reload with persistence.

        The new provider is created first; only if that succeeds is the
        current provider closed and the new configuration installed and
        persisted (Redis, then file).

        Args:
            provider: New provider name
            config: New provider configuration

        Returns:
            True if update was successful

        Raises:
            EmbeddingException: If configuration is invalid
        """
        old_provider = self._provider_name
        old_config = self._config.copy()

        try:
            new_provider_instance = EmbeddingProviderFactory.create_provider(
                provider,
                config
            )

            if self._provider:
                await self._provider.close()

            self._provider_name = provider
            # Store a shallow copy so later mutation of the caller's dict
            # cannot silently change the manager's (and persisted) state.
            self._config = dict(config)
            self._provider = new_provider_instance

            self._save_to_redis()
            self._save_to_file()

            logger.info(f"Updated embedding config: provider={provider}")
            return True

        except Exception as e:
            # Restore the previous name/config so the manager stays coherent.
            self._provider_name = old_provider
            self._config = old_config
            raise EmbeddingException(
                f"Failed to update config: {e}",
                provider="config_manager",
                details={"provider": provider, "config": config}
            ) from e  # keep the original error as the explicit cause

    async def test_connection(
        self,
        test_text: str = "这是一个测试文本",
        provider: str | None = None,
        config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Test embedding connection.
        [AC-AISVC-41] Tests provider connectivity.

        A temporary provider is created, asked to embed `test_text`, and
        always closed afterwards. This method does not raise: failures are
        reported in the returned dict.

        Args:
            test_text: Text to embed for testing
            provider: Provider to test (uses current if None or empty)
            config: Config to test (uses current if None or empty)

        Returns:
            Dict with test results including success, dimension, latency.
            On failure it also carries an "error" key with the error text.
        """
        test_provider_name = provider or self._provider_name
        test_config = config or self._config

        test_provider: EmbeddingProvider | None = None
        try:
            test_provider = EmbeddingProviderFactory.create_provider(
                test_provider_name,
                test_config
            )

            start_time = time.perf_counter()
            embedding = await test_provider.embed(test_text)
            latency_ms = (time.perf_counter() - start_time) * 1000

            return {
                "success": True,
                "dimension": len(embedding),
                "latency_ms": latency_ms,
                "message": f"连接成功,向量维度: {len(embedding)}",
            }

        except Exception as e:
            return {
                "success": False,
                "dimension": 0,
                "latency_ms": 0,
                "error": str(e),
                "message": f"连接失败: {e}",
            }

        finally:
            # Release the temporary provider on every path, including when
            # embed() failed. Cleanup is best effort: a close error is logged
            # and does not change the probe result.
            if test_provider is not None:
                try:
                    await test_provider.close()
                except Exception as close_error:
                    logger.warning(
                        f"Failed to close test embedding provider: {close_error}"
                    )

    async def close(self) -> None:
        """Close the current provider, if any, and clear the cached instance."""
        if self._provider:
            await self._provider.close()
            self._provider = None


# Process-wide manager instance, created lazily by get_embedding_config_manager().
_embedding_config_manager: EmbeddingConfigManager | None = None


def get_embedding_config_manager() -> EmbeddingConfigManager:
    """
    Get the global embedding config manager.
    [AC-AISVC-31] Singleton pattern for configuration management.

    On first call the manager is built with the "nomic" provider as default
    and a default config taken from application settings (Ollama base URL,
    Ollama embedding model, Qdrant vector size).

    Returns:
        The shared EmbeddingConfigManager instance.
    """
    global _embedding_config_manager
    if _embedding_config_manager is None:
        settings = get_settings()

        _embedding_config_manager = EmbeddingConfigManager(
            default_provider="nomic",
            default_config={
                "base_url": settings.ollama_base_url,
                "model": settings.ollama_embedding_model,
                "dimension": settings.qdrant_vector_size,
            }
        )
    return _embedding_config_manager


async def get_embedding_provider() -> EmbeddingProvider:
    """
    Get the current embedding provider.
    [AC-AISVC-29] Convenience function for getting provider.

    Returns:
        The provider held by the global config manager, created on demand.
    """
    manager = get_embedding_config_manager()
    return await manager.get_provider()