# ai-robot-core/ai-service/app/services/llm/factory.py
"""
LLM Provider Factory and Configuration Management.
[AC-ASA-14, AC-ASA-15, AC-ASA-16, AC-ASA-17, AC-ASA-18] LLM provider management.
Design pattern: Factory pattern for pluggable LLM providers.
"""
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from app.services.llm.base import LLMClient, LLMConfig
from app.services.llm.openai_client import OpenAIClient
logger = logging.getLogger(__name__)
LLM_CONFIG_FILE = Path("config/llm_config.json")
@dataclass
class LLMProviderInfo:
"""Information about an LLM provider."""
name: str
display_name: str
description: str
config_schema: dict[str, Any]
LLM_PROVIDERS: dict[str, LLMProviderInfo] = {
"openai": LLMProviderInfo(
name="openai",
display_name="OpenAI",
description="OpenAI GPT 系列模型 (GPT-4, GPT-3.5 等)",
config_schema={
"type": "object",
"properties": {
"api_key": {
"type": "string",
"title": "API Key",
"description": "API Key",
"required": True,
},
"base_url": {
"type": "string",
"title": "API Base URL",
"description": "API Base URL",
"default": "https://api.openai.com/v1",
},
"model": {
"type": "string",
"title": "模型名称",
"description": "模型名称",
"default": "gpt-4o-mini",
},
"max_tokens": {
"type": "integer",
"title": "最大输出 Token 数",
"description": "最大输出 Token 数",
"default": 2048,
},
"temperature": {
"type": "number",
"title": "温度参数",
"description": "温度参数 (0-2)",
"default": 0.7,
"minimum": 0,
"maximum": 2,
},
},
"required": ["api_key"],
},
),
"ollama": LLMProviderInfo(
name="ollama",
display_name="Ollama",
description="Ollama 本地模型 (Llama, Qwen 等)",
config_schema={
"type": "object",
"properties": {
"base_url": {
"type": "string",
"title": "Ollama API 地址",
"description": "Ollama API 地址",
"default": "http://localhost:11434/v1",
},
"model": {
"type": "string",
"title": "模型名称",
"description": "模型名称",
"default": "llama3.2",
},
"max_tokens": {
"type": "integer",
"title": "最大输出 Token 数",
"description": "最大输出 Token 数",
"default": 2048,
},
"temperature": {
"type": "number",
"title": "温度参数",
"description": "温度参数 (0-2)",
"default": 0.7,
"minimum": 0,
"maximum": 2,
},
},
"required": [],
},
),
"deepseek": LLMProviderInfo(
name="deepseek",
display_name="DeepSeek",
description="DeepSeek 大模型 (deepseek-chat, deepseek-coder)",
config_schema={
"type": "object",
"properties": {
"api_key": {
"type": "string",
"title": "API Key",
"description": "DeepSeek API Key",
"required": True,
},
"base_url": {
"type": "string",
"title": "API Base URL",
"description": "API Base URL",
"default": "https://api.deepseek.com/v1",
},
"model": {
"type": "string",
"title": "模型名称",
"description": "模型名称 (deepseek-chat, deepseek-coder)",
"default": "deepseek-chat",
},
"max_tokens": {
"type": "integer",
"title": "最大输出 Token 数",
"description": "最大输出 Token 数",
"default": 2048,
},
"temperature": {
"type": "number",
"title": "温度参数",
"description": "温度参数 (0-2)",
"default": 0.7,
"minimum": 0,
"maximum": 2,
},
},
"required": ["api_key"],
},
),
"azure": LLMProviderInfo(
name="azure",
display_name="Azure OpenAI",
description="Azure OpenAI 服务",
config_schema={
"type": "object",
"properties": {
"api_key": {
"type": "string",
"title": "API Key",
"description": "API Key",
"required": True,
},
"base_url": {
"type": "string",
"title": "Azure Endpoint",
"description": "Azure Endpoint",
"required": True,
},
"model": {
"type": "string",
"title": "部署名称",
"description": "部署名称",
"required": True,
},
"api_version": {
"type": "string",
"title": "API 版本",
"description": "API 版本",
"default": "2024-02-15-preview",
},
"max_tokens": {
"type": "integer",
"title": "最大输出 Token 数",
"description": "最大输出 Token 数",
"default": 2048,
},
"temperature": {
"type": "number",
"title": "温度参数",
"description": "温度参数 (0-2)",
"default": 0.7,
"minimum": 0,
"maximum": 2,
},
},
"required": ["api_key", "base_url", "model"],
},
),
}
class LLMProviderFactory:
"""
Factory for creating LLM clients.
[AC-ASA-14, AC-ASA-15] Dynamic provider creation.
"""
@classmethod
def get_providers(cls) -> list[LLMProviderInfo]:
"""Get all registered LLM providers."""
return list(LLM_PROVIDERS.values())
@classmethod
def get_provider_info(cls, name: str) -> LLMProviderInfo | None:
"""Get provider info by name."""
return LLM_PROVIDERS.get(name)
@classmethod
def create_client(
cls,
provider: str,
config: dict[str, Any],
) -> LLMClient:
"""
Create an LLM client for the specified provider.
[AC-ASA-15] Factory method for client creation.
Args:
provider: Provider name (openai, ollama, azure)
config: Provider configuration
Returns:
LLMClient instance
Raises:
ValueError: If provider is not supported
"""
if provider not in LLM_PROVIDERS:
raise ValueError(f"Unsupported LLM provider: {provider}")
if provider in ("openai", "ollama", "azure", "deepseek"):
return OpenAIClient(
api_key=config.get("api_key"),
base_url=config.get("base_url"),
model=config.get("model"),
default_config=LLMConfig(
model=config.get("model", "gpt-4o-mini"),
max_tokens=config.get("max_tokens", 2048),
temperature=config.get("temperature", 0.7),
),
)
raise ValueError(f"Unsupported LLM provider: {provider}")
class LLMConfigManager:
"""
Manager for LLM configuration.
[AC-ASA-16, AC-ASA-17, AC-ASA-18] Configuration management with hot-reload and persistence.
"""
def __init__(self):
from app.core.config import get_settings
settings = get_settings()
self._current_provider: str = settings.llm_provider
self._current_config: dict[str, Any] = {
"api_key": settings.llm_api_key,
"base_url": settings.llm_base_url,
"model": settings.llm_model,
"max_tokens": settings.llm_max_tokens,
"temperature": settings.llm_temperature,
}
self._client: LLMClient | None = None
self._load_from_file()
def _load_from_file(self) -> None:
"""Load configuration from file if exists."""
try:
if LLM_CONFIG_FILE.exists():
with open(LLM_CONFIG_FILE, 'r', encoding='utf-8') as f:
saved = json.load(f)
self._current_provider = saved.get("provider", self._current_provider)
saved_config = saved.get("config", {})
if saved_config:
self._current_config.update(saved_config)
logger.info(f"[AC-ASA-16] Loaded LLM config from file: provider={self._current_provider}")
except Exception as e:
logger.warning(f"[AC-ASA-16] Failed to load LLM config from file: {e}")
def _save_to_file(self) -> None:
"""Save configuration to file."""
try:
LLM_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
with open(LLM_CONFIG_FILE, 'w', encoding='utf-8') as f:
json.dump({
"provider": self._current_provider,
"config": self._current_config,
}, f, indent=2, ensure_ascii=False)
logger.info(f"[AC-ASA-16] Saved LLM config to file: provider={self._current_provider}")
except Exception as e:
logger.error(f"[AC-ASA-16] Failed to save LLM config to file: {e}")
def get_current_config(self) -> dict[str, Any]:
"""Get current LLM configuration."""
return {
"provider": self._current_provider,
"config": self._current_config.copy(),
}
async def update_config(
self,
provider: str,
config: dict[str, Any],
) -> bool:
"""
Update LLM configuration.
[AC-ASA-16] Hot-reload configuration with persistence.
Args:
provider: Provider name
config: New configuration
Returns:
True if update successful
"""
if provider not in LLM_PROVIDERS:
raise ValueError(f"Unsupported LLM provider: {provider}")
provider_info = LLM_PROVIDERS[provider]
validated_config = self._validate_config(provider_info, config)
if self._client:
await self._client.close()
self._client = None
self._current_provider = provider
self._current_config = validated_config
self._save_to_file()
logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}")
return True
def _validate_config(
self,
provider_info: LLMProviderInfo,
config: dict[str, Any],
) -> dict[str, Any]:
"""Validate configuration against provider schema."""
schema_props = provider_info.config_schema.get("properties", {})
required_fields = provider_info.config_schema.get("required", [])
validated = {}
for key, prop_schema in schema_props.items():
if key in config:
validated[key] = config[key]
elif "default" in prop_schema:
validated[key] = prop_schema["default"]
elif key in required_fields:
raise ValueError(f"Missing required config: {key}")
return validated
def get_client(self) -> LLMClient:
"""Get or create LLM client with current config."""
if self._client is None:
self._client = LLMProviderFactory.create_client(
self._current_provider,
self._current_config,
)
return self._client
async def test_connection(
self,
test_prompt: str = "你好,请简单介绍一下自己。",
provider: str | None = None,
config: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Test LLM connection.
[AC-ASA-17, AC-ASA-18] Connection testing.
Args:
test_prompt: Test prompt to send
provider: Optional provider to test (uses current if not specified)
config: Optional config to test (uses current if not specified)
Returns:
Test result with success status, response, and metrics
"""
import time
test_provider = provider or self._current_provider
test_config = config if config else self._current_config
logger.info(f"[AC-ASA-17] Test connection: provider={test_provider}, model={test_config.get('model')}")
if test_provider not in LLM_PROVIDERS:
return {
"success": False,
"error": f"Unsupported provider: {test_provider}",
}
try:
client = LLMProviderFactory.create_client(test_provider, test_config)
start_time = time.time()
response = await client.generate(
messages=[{"role": "user", "content": test_prompt}],
)
latency_ms = (time.time() - start_time) * 1000
await client.close()
return {
"success": True,
"response": response.content,
"latency_ms": round(latency_ms, 2),
"prompt_tokens": response.usage.get("prompt_tokens", 0),
"completion_tokens": response.usage.get("completion_tokens", 0),
"total_tokens": response.usage.get("total_tokens", 0),
"model": response.model,
"message": f"连接成功,模型: {response.model}",
}
except Exception as e:
logger.error(f"[AC-ASA-18] LLM test failed: {e}")
return {
"success": False,
"error": str(e),
"message": f"连接失败: {str(e)}",
}
async def close(self) -> None:
"""Close the current client."""
if self._client:
await self._client.close()
self._client = None
_llm_config_manager: LLMConfigManager | None = None
def get_llm_config_manager() -> LLMConfigManager:
"""Get or create LLM config manager instance."""
global _llm_config_manager
if _llm_config_manager is None:
_llm_config_manager = LLMConfigManager()
return _llm_config_manager