""" OpenAI embedding provider implementation. [AC-AISVC-29, AC-AISVC-30] OpenAI-based embedding provider. Uses OpenAI API for generating text embeddings. """ import logging import time from typing import Any import httpx from app.services.embedding.base import ( EmbeddingException, EmbeddingProvider, ) logger = logging.getLogger(__name__) class OpenAIEmbeddingProvider(EmbeddingProvider): """ Embedding provider using OpenAI API. [AC-AISVC-29, AC-AISVC-30] Supports OpenAI embedding models. """ PROVIDER_NAME = "openai" MODEL_DIMENSIONS = { "text-embedding-ada-002": 1536, "text-embedding-3-small": 1536, "text-embedding-3-large": 3072, } def __init__( self, api_key: str, model: str = "text-embedding-3-small", base_url: str = "https://api.openai.com/v1", dimension: int | None = None, timeout_seconds: int = 60, **kwargs: Any, ): self._api_key = api_key self._model = model self._base_url = base_url.rstrip("/") self._timeout = timeout_seconds self._client: httpx.AsyncClient | None = None self._extra_config = kwargs if dimension: self._dimension = dimension elif model in self.MODEL_DIMENSIONS: self._dimension = self.MODEL_DIMENSIONS[model] else: self._dimension = 1536 async def _get_client(self) -> httpx.AsyncClient: if self._client is None: self._client = httpx.AsyncClient(timeout=self._timeout) return self._client async def embed(self, text: str) -> list[float]: """ Generate embedding vector for a single text using OpenAI API. [AC-AISVC-29] Returns embedding vector. """ embeddings = await self.embed_batch([text]) return embeddings[0] async def embed_batch(self, texts: list[str]) -> list[list[float]]: """ Generate embedding vectors for multiple texts using OpenAI API. [AC-AISVC-29] Supports batch embedding for efficiency. """ start_time = time.perf_counter() try: client = await self._get_client() request_body: dict[str, Any] = { "model": self._model, "input": texts, } if self._dimension and self._model.startswith("text-embedding-3"): request_body["dimensions"] = self._dimension response = await client.post( f"{self._base_url}/embeddings", headers={ "Authorization": f"Bearer {self._api_key}", "Content-Type": "application/json", }, json=request_body, ) response.raise_for_status() data = response.json() embeddings = [] for item in data.get("data", []): embedding = item.get("embedding", []) if not embedding: raise EmbeddingException( "Empty embedding returned", provider=self.PROVIDER_NAME, details={"index": item.get("index", 0)} ) embeddings.append(embedding) if len(embeddings) != len(texts): raise EmbeddingException( f"Embedding count mismatch: expected {len(texts)}, got {len(embeddings)}", provider=self.PROVIDER_NAME ) latency_ms = (time.perf_counter() - start_time) * 1000 logger.debug( f"Generated {len(embeddings)} embeddings via OpenAI: " f"dim={len(embeddings[0]) if embeddings else 0}, " f"latency={latency_ms:.2f}ms" ) return embeddings except httpx.HTTPStatusError as e: raise EmbeddingException( f"OpenAI API error: {e.response.status_code}", provider=self.PROVIDER_NAME, details={"status_code": e.response.status_code, "response": e.response.text} ) except httpx.RequestError as e: raise EmbeddingException( f"OpenAI connection error: {e}", provider=self.PROVIDER_NAME, details={"base_url": self._base_url} ) except EmbeddingException: raise except Exception as e: raise EmbeddingException( f"Embedding generation failed: {e}", provider=self.PROVIDER_NAME ) def get_dimension(self) -> int: """Get the dimension of embedding vectors.""" return self._dimension def get_provider_name(self) -> str: """Get the name of this embedding provider.""" return self.PROVIDER_NAME def get_config_schema(self) -> dict[str, Any]: """ Get the configuration schema for OpenAI provider. [AC-AISVC-38] Returns JSON Schema for configuration parameters. """ return { "api_key": { "type": "string", "description": "OpenAI API 密钥", "required": True, "secret": True, }, "model": { "type": "string", "description": "嵌入模型名称", "default": "text-embedding-3-small", "enum": list(self.MODEL_DIMENSIONS.keys()), }, "base_url": { "type": "string", "description": "OpenAI API 地址(支持兼容接口)", "default": "https://api.openai.com/v1", }, "dimension": { "type": "integer", "description": "向量维度(仅 text-embedding-3 系列支持自定义)", "default": 1536, }, "timeout_seconds": { "type": "integer", "description": "请求超时时间(秒)", "default": 60, }, } async def close(self) -> None: """Close the HTTP client.""" if self._client: await self._client.aclose() self._client = None