158 lines
4.8 KiB
Python
158 lines
4.8 KiB
Python
|
|
"""
|
||
|
|
Ollama embedding provider implementation.
|
||
|
|
[AC-AISVC-29, AC-AISVC-30] Ollama-based embedding provider.
|
||
|
|
|
||
|
|
Uses Ollama API for generating text embeddings.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import logging
|
||
|
|
import time
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
import httpx
|
||
|
|
|
||
|
|
from app.services.embedding.base import (
|
||
|
|
EmbeddingConfig,
|
||
|
|
EmbeddingException,
|
||
|
|
EmbeddingProvider,
|
||
|
|
)
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class OllamaEmbeddingProvider(EmbeddingProvider):
    """
    Embedding provider using Ollama API.

    [AC-AISVC-29, AC-AISVC-30] Supports local embedding models via Ollama.
    """

    PROVIDER_NAME = "ollama"

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        model: str = "nomic-embed-text",
        dimension: int = 768,
        timeout_seconds: int = 60,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the Ollama embedding provider.

        Args:
            base_url: Base URL of the Ollama server; trailing slashes are stripped.
            model: Name of the embedding model to request.
            dimension: Dimension reported by ``get_dimension()`` (not validated
                against API responses here).
            timeout_seconds: HTTP request timeout in seconds.
            **kwargs: Extra provider configuration, stored as-is.
        """
        self._base_url = base_url.rstrip("/")
        self._model = model
        self._dimension = dimension
        self._timeout = timeout_seconds
        # Created lazily on first request so construction stays cheap and
        # does not require a running event loop.
        self._client: httpx.AsyncClient | None = None
        self._extra_config = kwargs

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared async HTTP client, creating it on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self._timeout)
        return self._client

    async def embed(self, text: str) -> list[float]:
        """
        Generate embedding vector for a single text using Ollama API.

        [AC-AISVC-29] Returns embedding vector.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector returned by Ollama.

        Raises:
            EmbeddingException: On API errors, connection failures, an empty
                embedding response, or any unexpected failure.
        """
        start_time = time.perf_counter()

        try:
            client = await self._get_client()
            response = await client.post(
                f"{self._base_url}/api/embeddings",
                json={
                    "model": self._model,
                    "prompt": text,
                },
            )
            response.raise_for_status()
            data = response.json()
            embedding = data.get("embedding", [])

            if not embedding:
                raise EmbeddingException(
                    "Empty embedding returned",
                    provider=self.PROVIDER_NAME,
                    details={"text_length": len(text)},
                )

            latency_ms = (time.perf_counter() - start_time) * 1000
            # Lazy %-args: formatting is skipped unless DEBUG logging is on.
            logger.debug(
                "Generated embedding via Ollama: dim=%d, latency=%.2fms",
                len(embedding),
                latency_ms,
            )

            return embedding

        except httpx.HTTPStatusError as e:
            # Chain the cause (`from e`) so the original httpx traceback
            # is preserved for debugging.
            raise EmbeddingException(
                f"Ollama API error: {e.response.status_code}",
                provider=self.PROVIDER_NAME,
                details={"status_code": e.response.status_code, "response": e.response.text},
            ) from e
        except httpx.RequestError as e:
            raise EmbeddingException(
                f"Ollama connection error: {e}",
                provider=self.PROVIDER_NAME,
                details={"base_url": self._base_url},
            ) from e
        except EmbeddingException:
            # Our own exception (e.g. empty embedding): re-raise unchanged
            # so it is not re-wrapped by the generic handler below.
            raise
        except Exception as e:
            raise EmbeddingException(
                f"Embedding generation failed: {e}",
                provider=self.PROVIDER_NAME,
            ) from e

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embedding vectors for multiple texts.

        [AC-AISVC-29] Sequential embedding generation.

        Args:
            texts: Texts to embed, one request per text.

        Returns:
            Embedding vectors in the same order as ``texts``.
        """
        embeddings = []
        for text in texts:
            embeddings.append(await self.embed(text))
        return embeddings

    def get_dimension(self) -> int:
        """Get the dimension of embedding vectors."""
        return self._dimension

    def get_provider_name(self) -> str:
        """Get the name of this embedding provider."""
        return self.PROVIDER_NAME

    def get_config_schema(self) -> dict[str, Any]:
        """
        Get the configuration schema for Ollama provider.

        [AC-AISVC-38] Returns JSON Schema for configuration parameters.
        """
        return {
            "base_url": {
                "type": "string",
                "description": "Ollama API 地址",
                "default": "http://localhost:11434",
            },
            "model": {
                "type": "string",
                "description": "嵌入模型名称",
                "default": "nomic-embed-text",
            },
            "dimension": {
                "type": "integer",
                "description": "向量维度",
                "default": 768,
            },
            "timeout_seconds": {
                "type": "integer",
                "description": "请求超时时间(秒)",
                "default": 60,
            },
        }

    async def close(self) -> None:
        """Close the HTTP client and release its resources."""
        if self._client:
            await self._client.aclose()
            self._client = None
|