# ai-robot-core/ai-service/app/services/embedding/openai_provider.py
#
# 194 lines
# 6.3 KiB
# Python

"""
OpenAI embedding provider implementation.
[AC-AISVC-29, AC-AISVC-30] OpenAI-based embedding provider.
Uses OpenAI API for generating text embeddings.
"""
import logging
import time
from typing import Any
import httpx
from app.services.embedding.base import (
EmbeddingException,
EmbeddingProvider,
)
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """
    Embedding provider using OpenAI API.

    [AC-AISVC-29, AC-AISVC-30] Supports OpenAI embedding models.

    Requests are POSTed to ``{base_url}/embeddings`` through a lazily
    created ``httpx.AsyncClient`` that is reused until :meth:`close` is
    called. Failures raised while embedding are surfaced as
    ``EmbeddingException``.
    """

    PROVIDER_NAME = "openai"

    # Default vector dimension per known model; consulted when the caller
    # does not pass an explicit ``dimension``.
    MODEL_DIMENSIONS = {
        "text-embedding-ada-002": 1536,
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
    }

    def __init__(
        self,
        api_key: str,
        model: str = "text-embedding-3-small",
        base_url: str = "https://api.openai.com/v1",
        dimension: int | None = None,
        timeout_seconds: int = 60,
        **kwargs: Any,
    ):
        """
        Store the provider configuration; no network I/O happens here.

        Args:
            api_key: Key sent as the ``Authorization: Bearer`` header.
            model: Embedding model name placed in each request body.
            base_url: API root; trailing slashes are stripped.
            dimension: Explicit vector dimension. When falsy, the value is
                looked up in ``MODEL_DIMENSIONS`` and falls back to 1536
                for models not listed there.
            timeout_seconds: Timeout handed to the ``httpx.AsyncClient``.
            **kwargs: Extra options, kept as-is in ``_extra_config``.
        """
        self._api_key = api_key
        self._model = model
        self._base_url = base_url.rstrip("/")
        self._timeout = timeout_seconds
        # Created on first use by _get_client().
        self._client: httpx.AsyncClient | None = None
        self._extra_config = kwargs

        # Truthiness check on purpose: both None and 0 mean "not specified".
        if dimension:
            self._dimension = dimension
        elif model in self.MODEL_DIMENSIONS:
            self._dimension = self.MODEL_DIMENSIONS[model]
        else:
            self._dimension = 1536

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self._timeout)
        return self._client

    async def embed(self, text: str) -> list[float]:
        """
        Generate embedding vector for a single text using OpenAI API.

        [AC-AISVC-29] Returns embedding vector.

        Delegates to :meth:`embed_batch` with a one-element list.

        Raises:
            EmbeddingException: Propagated from :meth:`embed_batch`.
        """
        embeddings = await self.embed_batch([text])
        return embeddings[0]

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embedding vectors for multiple texts using OpenAI API.

        [AC-AISVC-29] Supports batch embedding for efficiency.

        Args:
            texts: Texts to embed in a single request.

        Returns:
            One embedding per input text. An empty ``texts`` yields ``[]``
            without contacting the API.

        Raises:
            EmbeddingException: On a non-success HTTP status, a transport
                error, an empty embedding, a count mismatch, or any other
                failure while building the result. The original exception
                is attached as ``__cause__``.
        """
        # Nothing to embed: skip the round trip entirely.
        if not texts:
            return []

        start_time = time.perf_counter()
        try:
            client = await self._get_client()

            request_body: dict[str, Any] = {
                "model": self._model,
                "input": texts,
            }
            # Only the text-embedding-3 family is sent a custom output size.
            if self._dimension and self._model.startswith("text-embedding-3"):
                request_body["dimensions"] = self._dimension

            response = await client.post(
                f"{self._base_url}/embeddings",
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
                json=request_body,
            )
            response.raise_for_status()

            embeddings = self._parse_embeddings(response.json(), len(texts))

            latency_ms = (time.perf_counter() - start_time) * 1000
            logger.debug(
                "Generated %d embeddings via OpenAI: dim=%d, latency=%.2fms",
                len(embeddings),
                len(embeddings[0]) if embeddings else 0,
                latency_ms,
            )
            return embeddings
        except httpx.HTTPStatusError as e:
            raise EmbeddingException(
                f"OpenAI API error: {e.response.status_code}",
                provider=self.PROVIDER_NAME,
                details={
                    "status_code": e.response.status_code,
                    "response": e.response.text,
                },
            ) from e
        except httpx.RequestError as e:
            raise EmbeddingException(
                f"OpenAI connection error: {e}",
                provider=self.PROVIDER_NAME,
                details={"base_url": self._base_url},
            ) from e
        except EmbeddingException:
            # Already in the provider's error type; pass through untouched.
            raise
        except Exception as e:
            raise EmbeddingException(
                f"Embedding generation failed: {e}",
                provider=self.PROVIDER_NAME,
            ) from e

    def _parse_embeddings(
        self, data: Any, expected_count: int
    ) -> list[list[float]]:
        """Extract the embedding vectors from a decoded response payload.

        Args:
            data: Decoded JSON body; its ``"data"`` entry is read as the
                list of result items.
            expected_count: Number of texts that were submitted.

        Returns:
            The ``"embedding"`` value of every item, ordered by the items'
            ``"index"`` when each item carries an integer one, otherwise in
            response order.

        Raises:
            EmbeddingException: If an item has no embedding or the number
                of embeddings differs from ``expected_count``.
        """
        items = data.get("data", [])

        # Align results with the inputs. Sorting is stable, and is skipped
        # unless every item has an integer index, so payloads without
        # usable indices keep their original order.
        if all(isinstance(item.get("index"), int) for item in items):
            items = sorted(items, key=lambda item: item["index"])

        embeddings = []
        for item in items:
            embedding = item.get("embedding", [])
            if not embedding:
                raise EmbeddingException(
                    "Empty embedding returned",
                    provider=self.PROVIDER_NAME,
                    details={"index": item.get("index", 0)},
                )
            embeddings.append(embedding)

        if len(embeddings) != expected_count:
            raise EmbeddingException(
                f"Embedding count mismatch: expected {expected_count}, got {len(embeddings)}",
                provider=self.PROVIDER_NAME,
            )
        return embeddings

    def get_dimension(self) -> int:
        """Get the dimension of embedding vectors."""
        return self._dimension

    def get_provider_name(self) -> str:
        """Get the name of this embedding provider."""
        return self.PROVIDER_NAME

    def get_config_schema(self) -> dict[str, Any]:
        """
        Get the configuration schema for OpenAI provider.

        [AC-AISVC-38] Returns JSON Schema for configuration parameters.

        Returns:
            A mapping from each constructor option name to a descriptor
            with its type, description and, where present, default,
            allowed values, and required/secret flags.
        """
        return {
            "api_key": {
                "type": "string",
                "description": "OpenAI API 密钥",
                "required": True,
                "secret": True,
            },
            "model": {
                "type": "string",
                "description": "嵌入模型名称",
                "default": "text-embedding-3-small",
                "enum": list(self.MODEL_DIMENSIONS.keys()),
            },
            "base_url": {
                "type": "string",
                "description": "OpenAI API 地址(支持兼容接口)",
                "default": "https://api.openai.com/v1",
            },
            "dimension": {
                "type": "integer",
                "description": "向量维度(仅 text-embedding-3 系列支持自定义)",
                "default": 1536,
            },
            "timeout_seconds": {
                "type": "integer",
                "description": "请求超时时间(秒)",
                "default": 60,
            },
        }

    async def close(self) -> None:
        """Close the HTTP client.

        The stored reference is cleared before closing, so a failure in
        ``aclose()`` cannot leave a half-closed client behind for reuse;
        the next request creates a fresh one.
        """
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()