""" Ollama embedding provider implementation. [AC-AISVC-29, AC-AISVC-30] Ollama-based embedding provider. Uses Ollama API for generating text embeddings. """ import logging import time from typing import Any import httpx from app.services.embedding.base import ( EmbeddingConfig, EmbeddingException, EmbeddingProvider, ) logger = logging.getLogger(__name__) class OllamaEmbeddingProvider(EmbeddingProvider): """ Embedding provider using Ollama API. [AC-AISVC-29, AC-AISVC-30] Supports local embedding models via Ollama. """ PROVIDER_NAME = "ollama" def __init__( self, base_url: str = "http://localhost:11434", model: str = "nomic-embed-text", dimension: int = 768, timeout_seconds: int = 60, **kwargs: Any, ): self._base_url = base_url.rstrip("/") self._model = model self._dimension = dimension self._timeout = timeout_seconds self._client: httpx.AsyncClient | None = None self._extra_config = kwargs async def _get_client(self) -> httpx.AsyncClient: if self._client is None: self._client = httpx.AsyncClient(timeout=self._timeout) return self._client async def embed(self, text: str) -> list[float]: """ Generate embedding vector for a single text using Ollama API. [AC-AISVC-29] Returns embedding vector. """ start_time = time.perf_counter() try: client = await self._get_client() response = await client.post( f"{self._base_url}/api/embeddings", json={ "model": self._model, "prompt": text, } ) response.raise_for_status() data = response.json() embedding = data.get("embedding", []) if not embedding: raise EmbeddingException( "Empty embedding returned", provider=self.PROVIDER_NAME, details={"text_length": len(text)} ) latency_ms = (time.perf_counter() - start_time) * 1000 logger.debug( f"Generated embedding via Ollama: dim={len(embedding)}, " f"latency={latency_ms:.2f}ms" ) return embedding except httpx.HTTPStatusError as e: raise EmbeddingException( f"Ollama API error: {e.response.status_code}", provider=self.PROVIDER_NAME, details={"status_code": e.response.status_code, "response": e.response.text} ) except httpx.RequestError as e: raise EmbeddingException( f"Ollama connection error: {e}", provider=self.PROVIDER_NAME, details={"base_url": self._base_url} ) except EmbeddingException: raise except Exception as e: raise EmbeddingException( f"Embedding generation failed: {e}", provider=self.PROVIDER_NAME ) async def embed_batch(self, texts: list[str]) -> list[list[float]]: """ Generate embedding vectors for multiple texts. [AC-AISVC-29] Sequential embedding generation. """ embeddings = [] for text in texts: embedding = await self.embed(text) embeddings.append(embedding) return embeddings def get_dimension(self) -> int: """Get the dimension of embedding vectors.""" return self._dimension def get_provider_name(self) -> str: """Get the name of this embedding provider.""" return self.PROVIDER_NAME def get_config_schema(self) -> dict[str, Any]: """ Get the configuration schema for Ollama provider. [AC-AISVC-38] Returns JSON Schema for configuration parameters. """ return { "base_url": { "type": "string", "description": "Ollama API 地址", "default": "http://localhost:11434", }, "model": { "type": "string", "description": "嵌入模型名称", "default": "nomic-embed-text", }, "dimension": { "type": "integer", "description": "向量维度", "default": 768, }, "timeout_seconds": { "type": "integer", "description": "请求超时时间(秒)", "default": 60, }, } async def close(self) -> None: """Close the HTTP client.""" if self._client: await self._client.aclose() self._client = None