ai-robot-core/ai-service/app/services/embedding/factory.py

302 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Embedding provider factory and configuration manager.
[AC-AISVC-30, AC-AISVC-31] Factory pattern for dynamic provider loading.
Design reference: progress.md Section 7.1 - Architecture
- EmbeddingProviderFactory: creates providers based on config
- EmbeddingConfigManager: manages configuration with hot-reload support
"""
import logging
from typing import Any, Type
from app.services.embedding.base import EmbeddingException, EmbeddingProvider
from app.services.embedding.ollama_provider import OllamaEmbeddingProvider
from app.services.embedding.openai_provider import OpenAIEmbeddingProvider
logger = logging.getLogger(__name__)
class EmbeddingProviderFactory:
    """
    Factory for creating embedding providers.

    [AC-AISVC-30] Supports dynamic loading based on configuration.

    Provider classes live in a class-level registry keyed by a short string
    identifier. Every method is a classmethod operating on that registry, so
    a provider registered once is visible to all callers of this class.
    """

    # Registry: provider identifier -> provider class.
    _providers: dict[str, Type[EmbeddingProvider]] = {
        "ollama": OllamaEmbeddingProvider,
        "openai": OpenAIEmbeddingProvider,
    }

    # Human-readable labels for the built-in providers. Providers added via
    # register_provider() have no entry here and fall back to their
    # identifier (display name) or an empty string (description).
    _display_names: dict[str, str] = {
        "ollama": "Ollama 本地模型",
        "openai": "OpenAI Embedding",
    }
    _descriptions: dict[str, str] = {
        "ollama": "使用 Ollama 运行的本地嵌入模型,支持 nomic-embed-text 等开源模型",
        "openai": "使用 OpenAI 官方 Embedding API支持 text-embedding-3 系列模型",
    }

    @classmethod
    def register_provider(cls, name: str, provider_class: Type[EmbeddingProvider]) -> None:
        """
        Register a new embedding provider.

        [AC-AISVC-30] Allows runtime registration of providers.

        Args:
            name: Identifier under which the provider is stored. An existing
                entry with the same name is overwritten.
            provider_class: Provider class to instantiate for this name.
        """
        cls._providers[name] = provider_class
        logger.info("Registered embedding provider: %s", name)

    @classmethod
    def get_available_providers(cls) -> list[str]:
        """
        Get list of available provider names.

        [AC-AISVC-38] Returns registered provider identifiers.

        Returns:
            A new list with the registry keys, in registration order.
        """
        return list(cls._providers)

    @classmethod
    def get_provider_info(cls, name: str) -> dict[str, Any]:
        """
        Get provider information including config schema.

        [AC-AISVC-38] Returns provider metadata.

        Args:
            name: Provider identifier to describe.

        Returns:
            Dict with the keys ``name``, ``display_name``, ``description``
            and ``config_schema``.

        Raises:
            EmbeddingException: If no provider is registered under ``name``.
        """
        provider_class = cls._providers.get(name)
        if provider_class is None:
            raise EmbeddingException(
                f"Unknown provider: {name}",
                provider="factory"
            )

        # Allocate an instance without running __init__ so the schema can be
        # read without supplying any provider configuration.
        # NOTE(review): this relies on get_config_schema() not touching
        # instance state set up by __init__ - confirm for each provider.
        temp_instance = provider_class.__new__(provider_class)

        return {
            "name": name,
            "display_name": cls._display_names.get(name, name),
            "description": cls._descriptions.get(name, ""),
            "config_schema": temp_instance.get_config_schema(),
        }

    @classmethod
    def create_provider(
        cls,
        name: str,
        config: dict[str, Any],
    ) -> EmbeddingProvider:
        """
        Create an embedding provider instance.

        [AC-AISVC-30] Creates provider based on configuration.

        Args:
            name: Provider identifier (e.g., "ollama", "openai")
            config: Provider-specific configuration, passed to the provider
                constructor as keyword arguments.

        Returns:
            Configured EmbeddingProvider instance

        Raises:
            EmbeddingException: If provider is unknown or configuration is
                invalid. When the provider constructor fails, the original
                error is attached as ``__cause__``.
        """
        provider_class = cls._providers.get(name)
        if provider_class is None:
            raise EmbeddingException(
                f"Unknown embedding provider: {name}. "
                f"Available: {cls.get_available_providers()}",
                provider="factory"
            )

        # Only the constructor call is guarded; any failure there is reported
        # as a factory error with the root cause chained explicitly.
        try:
            instance = provider_class(**config)
        except Exception as e:
            raise EmbeddingException(
                f"Failed to create provider '{name}': {e}",
                provider="factory",
                details={"config": config}
            ) from e

        logger.info("Created embedding provider: %s", name)
        return instance
class EmbeddingConfigManager:
    """
    Manager for embedding configuration.

    [AC-AISVC-31] Supports hot-reload of configuration.

    Tracks the active provider name, its configuration dict and a provider
    instance that is created on first use by get_provider() and replaced by
    update_config().
    """

    def __init__(self, default_provider: str = "ollama", default_config: dict[str, Any] | None = None):
        """
        Initialise the manager without creating a provider yet.

        Args:
            default_provider: Identifier of the provider to use initially.
            default_config: Configuration for that provider. A falsy value
                (None or an empty dict) selects the built-in local Ollama
                defaults below.
        """
        self._provider_name = default_provider
        self._config = default_config or {
            "base_url": "http://localhost:11434",
            "model": "nomic-embed-text",
            "dimension": 768,
        }
        # Created lazily by get_provider().
        self._provider: EmbeddingProvider | None = None

    def get_provider_name(self) -> str:
        """Get current provider name."""
        return self._provider_name

    def get_config(self) -> dict[str, Any]:
        """Get a shallow copy of the current configuration."""
        return self._config.copy()

    def get_full_config(self) -> dict[str, Any]:
        """
        Get full configuration including provider name.

        [AC-AISVC-39] Returns complete configuration for API response.

        Returns:
            Dict with ``provider`` (the current provider name) and ``config``
            (a shallow copy of the current configuration).
        """
        return {
            "provider": self._provider_name,
            "config": self._config.copy(),
        }

    async def get_provider(self) -> EmbeddingProvider:
        """
        Get or create the embedding provider.

        [AC-AISVC-29] Returns configured provider instance.

        Raises:
            EmbeddingException: If the provider cannot be created from the
                current name and configuration.
        """
        if self._provider is None:
            self._provider = EmbeddingProviderFactory.create_provider(
                self._provider_name,
                self._config
            )
        return self._provider

    async def update_config(
        self,
        provider: str,
        config: dict[str, Any],
    ) -> bool:
        """
        Update embedding configuration.

        [AC-AISVC-31, AC-AISVC-40] Supports hot-reload.

        The new provider is built and the previous one is closed before any
        manager state changes, so a failure leaves the previous name,
        configuration and provider reference in place.

        Args:
            provider: New provider name
            config: New provider configuration

        Returns:
            True if update was successful

        Raises:
            EmbeddingException: If configuration is invalid or the previous
                provider cannot be closed. The original error is attached as
                ``__cause__``.
        """
        new_provider_instance: EmbeddingProvider | None = None
        try:
            new_provider_instance = EmbeddingProviderFactory.create_provider(
                provider,
                config
            )
            if self._provider is not None:
                await self._provider.close()
        except Exception as e:
            # If the new provider was already built (i.e. closing the old one
            # failed), release it so it does not leak.
            if new_provider_instance is not None:
                try:
                    await new_provider_instance.close()
                except Exception as close_error:
                    logger.warning(
                        "Failed to close discarded embedding provider: %s",
                        close_error,
                    )
            raise EmbeddingException(
                f"Failed to update config: {e}",
                provider="config_manager",
                details={"provider": provider, "config": config}
            ) from e

        # Commit only after every fallible step has succeeded. A shallow copy
        # is stored so later mutation of the caller's dict cannot alter the
        # active configuration.
        self._provider_name = provider
        self._config = dict(config)
        self._provider = new_provider_instance
        logger.info("Updated embedding config: provider=%s", provider)
        return True

    async def test_connection(
        self,
        test_text: str = "这是一个测试文本",
        provider: str | None = None,
        config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Test embedding connection.

        [AC-AISVC-41] Tests provider connectivity.

        A temporary provider is created for the test and is always closed
        afterwards, whether or not embedding succeeded. A failure while
        closing it is logged and does not change the reported result. The
        manager's own provider is not touched.

        Args:
            test_text: Text to embed for testing
            provider: Provider to test (uses current if falsy)
            config: Config to test (uses current if falsy)

        Returns:
            Dict with test results. On success: ``success`` (True),
            ``dimension``, ``latency_ms`` and ``message``. On failure:
            ``success`` (False), ``dimension`` and ``latency_ms`` set to 0,
            ``error`` and ``message``. Errors are reported in the result
            rather than raised.
        """
        import time

        test_provider_name = provider or self._provider_name
        test_config = config or self._config
        test_provider: EmbeddingProvider | None = None
        try:
            test_provider = EmbeddingProviderFactory.create_provider(
                test_provider_name,
                test_config
            )
            start_time = time.perf_counter()
            embedding = await test_provider.embed(test_text)
            latency_ms = (time.perf_counter() - start_time) * 1000
            dimension = len(embedding)
            return {
                "success": True,
                "dimension": dimension,
                "latency_ms": latency_ms,
                "message": f"连接成功,向量维度: {dimension}",
            }
        except Exception as e:
            return {
                "success": False,
                "dimension": 0,
                "latency_ms": 0,
                "error": str(e),
                "message": f"连接失败: {e}",
            }
        finally:
            # Release the temporary provider on every path, including a
            # failed embed() call.
            if test_provider is not None:
                try:
                    await test_provider.close()
                except Exception as close_error:
                    logger.warning(
                        "Failed to close test embedding provider: %s",
                        close_error,
                    )

    async def close(self) -> None:
        """Close the current provider, if any, and drop the reference to it."""
        if self._provider is not None:
            await self._provider.close()
            self._provider = None
# Module-level singleton; populated on the first get_embedding_config_manager() call.
_embedding_config_manager: EmbeddingConfigManager | None = None


def get_embedding_config_manager() -> EmbeddingConfigManager:
    """
    Return the process-wide embedding config manager, creating it on first use.

    [AC-AISVC-31] Singleton pattern for configuration management.

    The first call builds a manager for the "ollama" provider, seeded with the
    Ollama base URL, the Ollama embedding model and the Qdrant vector size
    read from the application settings. Later calls return that same instance.
    """
    global _embedding_config_manager

    if _embedding_config_manager is not None:
        return _embedding_config_manager

    # Imported at call time rather than at module top -
    # presumably to avoid a circular import; confirm before moving it.
    from app.core.config import get_settings

    app_settings = get_settings()
    initial_config = {
        "base_url": app_settings.ollama_base_url,
        "model": app_settings.ollama_embedding_model,
        "dimension": app_settings.qdrant_vector_size,
    }
    _embedding_config_manager = EmbeddingConfigManager(
        default_provider="ollama",
        default_config=initial_config,
    )
    return _embedding_config_manager
async def get_embedding_provider() -> EmbeddingProvider:
    """
    Return the embedding provider held by the global config manager.

    [AC-AISVC-29] Convenience function for getting provider.

    Shorthand for fetching the global manager and awaiting its
    get_provider() method.
    """
    return await get_embedding_config_manager().get_provider()