131 lines
3.4 KiB
Python
131 lines
3.4 KiB
Python
|
|
"""
|
||
|
|
Base embedding provider interface.
|
||
|
|
[AC-AISVC-29] Abstract interface for embedding providers.
|
||
|
|
|
||
|
|
Design reference: progress.md Section 7.1 - EmbeddingProvider interface
|
||
|
|
- embed(text) -> list[float]
|
||
|
|
- embed_batch(texts) -> list[list[float]]
|
||
|
|
- get_dimension() -> int
|
||
|
|
- get_provider_name() -> str
|
||
|
|
"""
|
||
|
|
|
||
|
|
from abc import ABC, abstractmethod
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class EmbeddingConfig:
    """
    Tunable settings handed to an embedding provider.
    [AC-AISVC-31] Supports configurable embedding parameters.

    Every field has a default, so ``EmbeddingConfig()`` is a valid
    configuration on its own.
    """

    # Embedding vector dimension; defaults to 768.
    dimension: int = 768
    # Batch size setting; defaults to 32.
    batch_size: int = 32
    # Timeout setting, in seconds per the field name; defaults to 60.
    timeout_seconds: int = 60
    # Free-form additional parameters. Built through a factory so each
    # instance owns its own dict rather than sharing one mutable default.
    extra_params: dict[str, Any] = field(default_factory=dict)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class EmbeddingResult:
    """
    Outcome of one embedding generation call.
    [AC-AISVC-29] Contains embedding vector and metadata.

    ``embedding``, ``dimension`` and ``model`` are required; the remaining
    fields are optional and default to "nothing recorded".
    """

    # The embedding vector itself.
    embedding: list[float]
    # Dimension reported for the vector (stored as given, not derived here).
    dimension: int
    # Model identifier string.
    model: str
    # Latency in milliseconds per the field name; 0.0 when not supplied.
    latency_ms: float = 0.0
    # Free-form extra information. Built through a factory so each
    # instance owns its own dict rather than sharing one mutable default.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||
|
|
|
||
|
|
|
||
|
|
class EmbeddingProvider(ABC):
    """
    Common contract that every embedding backend implements.
    [AC-AISVC-29] Provides unified interface for different embedding providers.

    Design reference: progress.md Section 7.1 - Architecture
    - OllamaEmbeddingProvider / OpenAIEmbeddingProvider can be swapped
    - Factory pattern for dynamic loading

    A subclass must override every ``@abstractmethod`` below before it can
    be instantiated; overriding ``close`` is optional.
    """

    @abstractmethod
    async def embed(self, text: str) -> list[float]:
        """
        Produce the embedding vector for one piece of text.
        [AC-AISVC-29] Returns embedding vector.

        Args:
            text: Input text to embed.

        Returns:
            List of floats representing the embedding vector.

        Raises:
            EmbeddingException: If embedding generation fails.
        """

    @abstractmethod
    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Produce embedding vectors for several texts in one call.
        [AC-AISVC-29] Returns list of embedding vectors.

        Args:
            texts: List of input texts to embed.

        Returns:
            List of embedding vectors.

        Raises:
            EmbeddingException: If embedding generation fails.
        """

    @abstractmethod
    def get_dimension(self) -> int:
        """
        Report the dimension of the vectors this provider produces.
        [AC-AISVC-29] Returns vector dimension.

        Returns:
            Integer dimension of embedding vectors.
        """

    @abstractmethod
    def get_provider_name(self) -> str:
        """
        Report the name identifying this embedding provider.
        [AC-AISVC-29] Returns provider identifier.

        Returns:
            String identifier for this provider.
        """

    @abstractmethod
    def get_config_schema(self) -> dict[str, Any]:
        """
        Describe the configuration parameters this provider accepts.
        [AC-AISVC-38] Returns JSON Schema for configuration parameters.

        Returns:
            Dict describing configuration parameters.
        """

    async def close(self) -> None:
        """Close the provider and release resources. Default no-op."""
        # Nothing to release in the base class.
        return None
|
||
|
|
|
||
|
|
|
||
|
|
class EmbeddingException(Exception):
|
||
|
|
"""Exception raised when embedding generation fails."""
|
||
|
|
|
||
|
|
def __init__(self, message: str, provider: str = "", details: dict[str, Any] | None = None):
|
||
|
|
self.provider = provider
|
||
|
|
self.details = details or {}
|
||
|
|
super().__init__(f"[{provider}] {message}" if provider else message)
|