# ai-robot-core/ai-service/app/services/llm/factory.py
"""
LLM Provider Factory and Configuration Management.
[AC-ASA-14, AC-ASA-15, AC-ASA-16, AC-ASA-17, AC-ASA-18] LLM provider management.
Design pattern: Factory pattern for pluggable LLM providers.
"""
import json
import logging
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
import redis
from app.core.config import get_settings
from app.services.llm.base import LLMClient, LLMConfig
from app.services.llm.openai_client import OpenAIClient
logger = logging.getLogger(__name__)

# On-disk fallback location for persisted LLM config (relative to the working directory).
LLM_CONFIG_FILE = Path("config/llm_config.json")
# Redis key under which the multi-usage LLM config is persisted.
LLM_CONFIG_REDIS_KEY = "ai_service:config:llm"
class LLMUsageType(str, Enum):
    """Distinct LLM usage scenarios; each can carry its own provider/config."""

    CHAT = "chat"  # interactive agent chat / Q&A
    KB_PROCESSING = "kb_processing"  # knowledge-base document processing
# Human-readable (Chinese) display names per usage type, for the admin UI.
LLM_USAGE_DISPLAY_NAMES: dict[LLMUsageType, str] = {
    LLMUsageType.CHAT: "对话模型",
    LLMUsageType.KB_PROCESSING: "知识库处理模型",
}
# Longer (Chinese) descriptions per usage type, for the admin UI.
LLM_USAGE_DESCRIPTIONS: dict[LLMUsageType, str] = {
    LLMUsageType.CHAT: "用于 Agent 对话、问答等交互场景",
    LLMUsageType.KB_PROCESSING: "用于知识库文档上传、元数据推断、文档处理等场景",
}
@dataclass
class LLMProviderInfo:
    """Static registration record describing one pluggable LLM provider."""

    name: str  # machine name; also the key in the LLM_PROVIDERS registry
    display_name: str  # human-readable name shown in the UI
    description: str  # short description of the provider
    config_schema: dict[str, Any]  # JSON-Schema-like description of config fields
# Registry of supported LLM providers, keyed by provider name.
# Each entry's config_schema drives validation and default-filling of
# user-supplied configuration (see LLMConfigManager._validate_config).
# NOTE(review): the per-property "required": True flags are non-standard
# JSON Schema; only the object-level "required" list is honored by
# _validate_config — confirm nothing else reads the per-property flags.
LLM_PROVIDERS: dict[str, LLMProviderInfo] = {
    "openai": LLMProviderInfo(
        name="openai",
        display_name="OpenAI",
        description="OpenAI GPT 系列模型 (GPT-4, GPT-3.5 等)",
        config_schema={
            "type": "object",
            "properties": {
                "api_key": {
                    "type": "string",
                    "title": "API Key",
                    "description": "API Key",
                    "required": True,
                },
                "base_url": {
                    "type": "string",
                    "title": "API Base URL",
                    "description": "API Base URL",
                    "default": "https://api.openai.com/v1",
                },
                "model": {
                    "type": "string",
                    "title": "模型名称",
                    "description": "模型名称",
                    "default": "gpt-4o-mini",
                },
                "max_tokens": {
                    "type": "integer",
                    "title": "最大输出 Token 数",
                    "description": "最大输出 Token 数",
                    "default": 2048,
                },
                "temperature": {
                    "type": "number",
                    "title": "温度参数",
                    "description": "温度参数 (0-2)",
                    "default": 0.7,
                    "minimum": 0,
                    "maximum": 2,
                },
                "timeout_seconds": {
                    "type": "integer",
                    "title": "请求超时(秒)",
                    "description": "LLM 请求超时时间(秒)",
                    "default": 60,
                    "minimum": 5,
                    "maximum": 180,
                },
                "max_retries": {
                    "type": "integer",
                    "title": "最大重试次数",
                    "description": "请求失败后的最大重试次数",
                    "default": 3,
                    "minimum": 0,
                    "maximum": 10,
                },
            },
            "required": ["api_key"],
        },
    ),
    # Local models via Ollama's OpenAI-compatible endpoint; no API key needed.
    "ollama": LLMProviderInfo(
        name="ollama",
        display_name="Ollama",
        description="Ollama 本地模型 (Llama, Qwen 等)",
        config_schema={
            "type": "object",
            "properties": {
                "base_url": {
                    "type": "string",
                    "title": "Ollama API 地址",
                    "description": "Ollama API 地址",
                    "default": "http://localhost:11434/v1",
                },
                "model": {
                    "type": "string",
                    "title": "模型名称",
                    "description": "模型名称",
                    "default": "llama3.2",
                },
                "max_tokens": {
                    "type": "integer",
                    "title": "最大输出 Token 数",
                    "description": "最大输出 Token 数",
                    "default": 2048,
                },
                "temperature": {
                    "type": "number",
                    "title": "温度参数",
                    "description": "温度参数 (0-2)",
                    "default": 0.7,
                    "minimum": 0,
                    "maximum": 2,
                },
            },
            "required": [],
        },
    ),
    "deepseek": LLMProviderInfo(
        name="deepseek",
        display_name="DeepSeek",
        description="DeepSeek 大模型 (deepseek-chat, deepseek-coder)",
        config_schema={
            "type": "object",
            "properties": {
                "api_key": {
                    "type": "string",
                    "title": "API Key",
                    "description": "DeepSeek API Key",
                    "required": True,
                },
                "base_url": {
                    "type": "string",
                    "title": "API Base URL",
                    "description": "API Base URL",
                    "default": "https://api.deepseek.com/v1",
                },
                "model": {
                    "type": "string",
                    "title": "模型名称",
                    "description": "模型名称 (deepseek-chat, deepseek-coder)",
                    "default": "deepseek-chat",
                },
                "max_tokens": {
                    "type": "integer",
                    "title": "最大输出 Token 数",
                    "description": "最大输出 Token 数",
                    "default": 2048,
                },
                "temperature": {
                    "type": "number",
                    "title": "温度参数",
                    "description": "温度参数 (0-2)",
                    "default": 0.7,
                    "minimum": 0,
                    "maximum": 2,
                },
            },
            "required": ["api_key"],
        },
    ),
    # Azure requires endpoint + deployment name; "model" holds the deployment name.
    "azure": LLMProviderInfo(
        name="azure",
        display_name="Azure OpenAI",
        description="Azure OpenAI 服务",
        config_schema={
            "type": "object",
            "properties": {
                "api_key": {
                    "type": "string",
                    "title": "API Key",
                    "description": "API Key",
                    "required": True,
                },
                "base_url": {
                    "type": "string",
                    "title": "Azure Endpoint",
                    "description": "Azure Endpoint",
                    "required": True,
                },
                "model": {
                    "type": "string",
                    "title": "部署名称",
                    "description": "部署名称",
                    "required": True,
                },
                "api_version": {
                    "type": "string",
                    "title": "API 版本",
                    "description": "API 版本",
                    "default": "2024-02-15-preview",
                },
                "max_tokens": {
                    "type": "integer",
                    "title": "最大输出 Token 数",
                    "description": "最大输出 Token 数",
                    "default": 2048,
                },
                "temperature": {
                    "type": "number",
                    "title": "温度参数",
                    "description": "温度参数 (0-2)",
                    "default": 0.7,
                    "minimum": 0,
                    "maximum": 2,
                },
            },
            "required": ["api_key", "base_url", "model"],
        },
    ),
}
class LLMProviderFactory:
    """
    Factory for creating LLM clients.
    [AC-ASA-14, AC-ASA-15] Dynamic provider creation.

    Providers are looked up in the module-level LLM_PROVIDERS registry; all
    currently wired providers speak an OpenAI-compatible API and are served
    by OpenAIClient.
    """

    @classmethod
    def get_providers(cls) -> list[LLMProviderInfo]:
        """Return the registration records of all registered LLM providers."""
        return list(LLM_PROVIDERS.values())

    @classmethod
    def get_provider_info(cls, name: str) -> LLMProviderInfo | None:
        """Return the registration record for *name*, or None if unknown."""
        return LLM_PROVIDERS.get(name)

    @classmethod
    def create_client(
        cls,
        provider: str,
        config: dict[str, Any],
    ) -> LLMClient:
        """
        Create an LLM client for the specified provider.
        [AC-ASA-15] Factory method for client creation.

        Args:
            provider: Provider name (openai, ollama, azure, deepseek)
            config: Provider configuration; keys missing from *config* fall
                back to the defaults declared in the provider's config_schema.

        Returns:
            LLMClient instance

        Raises:
            ValueError: If provider is not supported
        """
        info = LLM_PROVIDERS.get(provider)
        if info is None:
            raise ValueError(f"Unsupported LLM provider: {provider}")
        if provider not in ("openai", "ollama", "azure", "deepseek"):
            # Registered in LLM_PROVIDERS but not yet wired to a client class.
            raise ValueError(f"Unsupported LLM provider: {provider}")
        # Fix: fall back to the provider's own schema defaults instead of
        # hard-coding OpenAI values for every provider (e.g. ollama's default
        # model is "llama3.2", not "gpt-4o-mini").
        schema_defaults = {
            key: prop["default"]
            for key, prop in info.config_schema.get("properties", {}).items()
            if "default" in prop
        }

        def _value(key: str, fallback: Any = None) -> Any:
            # Precedence: explicit config > schema default > hard fallback.
            return config.get(key, schema_defaults.get(key, fallback))

        # All wired providers expose an OpenAI-compatible API.
        return OpenAIClient(
            api_key=config.get("api_key"),
            base_url=config.get("base_url"),
            model=config.get("model"),
            default_config=LLMConfig(
                model=_value("model", "gpt-4o-mini"),
                max_tokens=_value("max_tokens", 2048),
                temperature=_value("temperature", 0.7),
                timeout_seconds=_value("timeout_seconds", 60),
                max_retries=_value("max_retries", 3),
            ),
        )
class LLMConfigManager:
    """
    Manager for LLM configuration.
    [AC-ASA-16, AC-ASA-17, AC-ASA-18] Configuration management with hot-reload and persistence.
    Supports multiple LLM usage types (chat, kb_processing).

    Startup precedence: env-derived defaults, then config saved in Redis,
    then the config file (loaded last, so its values win on conflicts).
    NOTE(review): the file overriding Redis follows from the load order in
    __init__ — confirm this precedence is intentional.
    """

    def __init__(self):
        # Uses the module-level get_settings import (the previous local
        # re-import was redundant).
        settings = get_settings()
        self._settings = settings
        self._redis_client: redis.Redis | None = None
        # Baseline config derived from environment settings.
        default_config = {
            "api_key": settings.llm_api_key,
            "base_url": settings.llm_base_url,
            "model": settings.llm_model,
            "max_tokens": settings.llm_max_tokens,
            "temperature": settings.llm_temperature,
            "timeout_seconds": settings.llm_timeout_seconds,
            "max_retries": settings.llm_max_retries,
        }
        # Every usage type starts from the same env defaults; persisted
        # overrides are merged in below.
        self._configs: dict[LLMUsageType, dict[str, Any]] = {
            usage_type: {
                "provider": settings.llm_provider,
                "config": default_config.copy(),
            }
            for usage_type in LLMUsageType
        }
        # Lazily created clients, one per usage type; reset on config update.
        self._clients: dict[LLMUsageType, LLMClient | None] = {
            usage_type: None for usage_type in LLMUsageType
        }
        self._load_from_redis()
        self._load_from_file()

    @property
    def chat_provider(self) -> str:
        """Provider name currently configured for chat."""
        return self._configs[LLMUsageType.CHAT]["provider"]

    @property
    def kb_processing_provider(self) -> str:
        """Provider name currently configured for KB processing."""
        return self._configs[LLMUsageType.KB_PROCESSING]["provider"]

    @property
    def chat_config(self) -> dict[str, Any]:
        """Copy of the chat configuration (safe for callers to mutate)."""
        return self._configs[LLMUsageType.CHAT]["config"].copy()

    @property
    def kb_processing_config(self) -> dict[str, Any]:
        """Copy of the KB-processing configuration (safe for callers to mutate)."""
        return self._configs[LLMUsageType.KB_PROCESSING]["config"].copy()

    def _apply_saved_config(self, saved: dict[str, Any]) -> None:
        """Merge a persisted payload into the in-memory configs.

        Accepts both the multi-usage format ({"chat": {...}, "kb_processing":
        {...}}) and the legacy flat format ({"provider": ..., "config": ...}),
        which is applied to every usage type.  Shared by the Redis and file
        loaders (previously duplicated in both).
        """
        for usage_type in LLMUsageType:
            type_key = usage_type.value
            if type_key in saved:
                entry = saved[type_key]
                self._configs[usage_type] = {
                    "provider": entry.get("provider", self._configs[usage_type]["provider"]),
                    "config": {**self._configs[usage_type]["config"], **entry.get("config", {})},
                }
            elif "provider" in saved:
                # Legacy single-provider payload.
                self._configs[usage_type]["provider"] = saved.get("provider", self._configs[usage_type]["provider"])
                self._configs[usage_type]["config"] = {**self._configs[usage_type]["config"], **saved.get("config", {})}

    def _snapshot(self) -> dict[str, Any]:
        """JSON-serializable snapshot of all per-usage configs (shared by both save paths)."""
        return {
            usage_type.value: {
                "provider": entry["provider"],
                "config": entry["config"],
            }
            for usage_type, entry in self._configs.items()
        }

    def _load_from_redis(self) -> None:
        """Load configuration from Redis if present (best-effort)."""
        try:
            if not self._settings.redis_enabled:
                return
            self._redis_client = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            saved_raw = self._redis_client.get(LLM_CONFIG_REDIS_KEY)
            if not saved_raw:
                return
            self._apply_saved_config(json.loads(saved_raw))
            logger.info("[AC-ASA-16] Loaded multi-usage LLM config from Redis")
        except Exception as e:
            # Best-effort: keep env/file config if Redis is unavailable.
            logger.warning(f"[AC-ASA-16] Failed to load LLM config from Redis: {e}")

    def _save_to_redis(self) -> None:
        """Persist the current configuration to Redis (best-effort)."""
        try:
            if not self._settings.redis_enabled:
                return
            if self._redis_client is None:
                self._redis_client = redis.from_url(
                    self._settings.redis_url,
                    encoding="utf-8",
                    decode_responses=True,
                )
            self._redis_client.set(
                LLM_CONFIG_REDIS_KEY,
                json.dumps(self._snapshot(), ensure_ascii=False),
            )
            logger.info("[AC-ASA-16] Saved multi-usage LLM config to Redis")
        except Exception as e:
            logger.warning(f"[AC-ASA-16] Failed to save LLM config to Redis: {e}")

    def _load_from_file(self) -> None:
        """Load configuration from the config file if it exists (best-effort)."""
        try:
            if LLM_CONFIG_FILE.exists():
                with open(LLM_CONFIG_FILE, encoding='utf-8') as f:
                    saved = json.load(f)
                self._apply_saved_config(saved)
                logger.info("[AC-ASA-16] Loaded multi-usage LLM config from file")
        except Exception as e:
            logger.warning(f"[AC-ASA-16] Failed to load LLM config from file: {e}")

    def _save_to_file(self) -> None:
        """Persist the current configuration to the config file (best-effort)."""
        try:
            LLM_CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)
            with open(LLM_CONFIG_FILE, 'w', encoding='utf-8') as f:
                json.dump(self._snapshot(), f, indent=2, ensure_ascii=False)
            logger.info("[AC-ASA-16] Saved multi-usage LLM config to file")
        except Exception as e:
            logger.error(f"[AC-ASA-16] Failed to save LLM config to file: {e}")

    def get_current_config(self, usage_type: LLMUsageType | None = None) -> dict[str, Any]:
        """Get the current LLM configuration.

        Args:
            usage_type: When given, return that usage type's config (an
                unknown type falls back to the chat config); when None,
                return all configs keyed by usage-type value.
        """
        if usage_type:
            config = self._configs.get(usage_type, self._configs[LLMUsageType.CHAT])
            return {
                "usage_type": usage_type.value,
                "provider": config["provider"],
                "config": config["config"].copy(),
            }
        # Renamed the loop variable so it no longer shadows the parameter.
        return {
            ut.value: {
                "provider": entry["provider"],
                "config": entry["config"].copy(),
            }
            for ut, entry in self._configs.items()
        }

    def get_config_for_usage(self, usage_type: LLMUsageType) -> dict[str, Any]:
        """Get the configuration for a specific usage type.

        NOTE: returns the live internal dict (not a copy); callers must not
        mutate it.
        """
        return self._configs.get(usage_type, self._configs[LLMUsageType.CHAT])

    async def update_config(
        self,
        provider: str,
        config: dict[str, Any],
        usage_type: LLMUsageType | None = None,
    ) -> bool:
        """
        Update LLM configuration.
        [AC-ASA-16] Hot-reload configuration with persistence.

        Args:
            provider: Provider name
            config: New configuration (validated against the provider schema)
            usage_type: Usage type to update (None = update all)

        Returns:
            True if update successful

        Raises:
            ValueError: If the provider is unknown or a required field is missing
        """
        if provider not in LLM_PROVIDERS:
            raise ValueError(f"Unsupported LLM provider: {provider}")
        provider_info = LLM_PROVIDERS[provider]
        validated_config = self._validate_config(provider_info, config)
        target_usage_types = [usage_type] if usage_type else list(LLMUsageType)
        for ut in target_usage_types:
            # Drop any cached client so the next get_client() rebuilds it
            # with the new configuration (hot reload).
            if self._clients[ut]:
                await self._clients[ut].close()
                self._clients[ut] = None
            self._configs[ut]["provider"] = provider
            self._configs[ut]["config"] = validated_config.copy()
        self._save_to_redis()
        self._save_to_file()
        logger.info(f"[AC-ASA-16] LLM config updated: provider={provider}, usage={usage_type or 'all'}")
        return True

    async def update_usage_config(
        self,
        usage_type: LLMUsageType,
        provider: str,
        config: dict[str, Any],
    ) -> bool:
        """Update configuration for a specific usage type.

        Raises:
            ValueError: If the provider is unknown or a required field is missing
        """
        if provider not in LLM_PROVIDERS:
            raise ValueError(f"Unsupported LLM provider: {provider}")
        provider_info = LLM_PROVIDERS[provider]
        validated_config = self._validate_config(provider_info, config)
        # Close and drop the cached client so it is rebuilt with the new config.
        if self._clients[usage_type]:
            await self._clients[usage_type].close()
            self._clients[usage_type] = None
        self._configs[usage_type]["provider"] = provider
        # Store a copy so caller-side mutation cannot corrupt internal state
        # (consistent with update_config; the original stored the dict itself).
        self._configs[usage_type]["config"] = validated_config.copy()
        self._save_to_redis()
        self._save_to_file()
        logger.info(f"[AC-ASA-16] LLM config updated: usage={usage_type.value}, provider={provider}")
        return True

    def _validate_config(
        self,
        provider_info: LLMProviderInfo,
        config: dict[str, Any],
    ) -> dict[str, Any]:
        """Validate *config* against the provider schema.

        Keeps only keys declared in the schema's "properties", fills declared
        defaults for missing keys, and raises ValueError when a key in the
        object-level "required" list is missing with no default.
        """
        schema_props = provider_info.config_schema.get("properties", {})
        required_fields = provider_info.config_schema.get("required", [])
        validated = {}
        for key, prop_schema in schema_props.items():
            if key in config:
                validated[key] = config[key]
            elif "default" in prop_schema:
                validated[key] = prop_schema["default"]
            elif key in required_fields:
                raise ValueError(f"Missing required config: {key}")
        return validated

    def get_client(self, usage_type: LLMUsageType | None = None) -> LLMClient:
        """Get (or lazily create) the LLM client for a usage type (default: chat)."""
        ut = usage_type or LLMUsageType.CHAT
        if self._clients[ut] is None:
            config = self._configs[ut]
            self._clients[ut] = LLMProviderFactory.create_client(
                config["provider"],
                config["config"],
            )
        return self._clients[ut]

    def get_chat_client(self) -> LLMClient:
        """Get the LLM client for chat/dialogue."""
        return self.get_client(LLMUsageType.CHAT)

    def get_kb_processing_client(self) -> LLMClient:
        """Get the LLM client for KB processing."""
        return self.get_client(LLMUsageType.KB_PROCESSING)

    async def test_connection(
        self,
        test_prompt: str = "你好,请简单介绍一下自己。",
        provider: str | None = None,
        config: dict[str, Any] | None = None,
        usage_type: LLMUsageType | None = None,
    ) -> dict[str, Any]:
        """
        Test LLM connection.
        [AC-ASA-17, AC-ASA-18] Connection testing.

        Args:
            test_prompt: Test prompt to send
            provider: Optional provider to test (uses current if not specified)
            config: Optional config to test (uses current if not specified)
            usage_type: Usage type for config lookup (ignored when provider is given)

        Returns:
            Test result with success status, response, and metrics
        """
        import time
        if usage_type and not provider:
            usage_config = self._configs[usage_type]
            test_provider = usage_config["provider"]
            test_config = usage_config["config"]
        else:
            # NOTE(review): when `provider` is given without `config`, the
            # chat config is used as-is and may not match that provider.
            test_provider = provider or self._configs[LLMUsageType.CHAT]["provider"]
            test_config = config if config else self._configs[LLMUsageType.CHAT]["config"]
        logger.info(f"[AC-ASA-17] Test connection: provider={test_provider}, model={test_config.get('model')}")
        if test_provider not in LLM_PROVIDERS:
            return {
                "success": False,
                "error": f"Unsupported provider: {test_provider}",
            }
        try:
            client = LLMProviderFactory.create_client(test_provider, test_config)
            try:
                start_time = time.time()
                response = await client.generate(
                    messages=[{"role": "user", "content": test_prompt}],
                )
                latency_ms = (time.time() - start_time) * 1000
            finally:
                # Fix: always release the throwaway client — the original
                # leaked it when generate() raised.
                await client.close()
            return {
                "success": True,
                "response": response.content,
                "latency_ms": round(latency_ms, 2),
                "prompt_tokens": response.usage.get("prompt_tokens", 0),
                "completion_tokens": response.usage.get("completion_tokens", 0),
                "total_tokens": response.usage.get("total_tokens", 0),
                "model": response.model,
                "message": f"连接成功,模型: {response.model}",
            }
        except Exception as e:
            logger.error(f"[AC-ASA-18] LLM test failed: {e}")
            return {
                "success": False,
                "error": str(e),
                "message": f"连接失败: {str(e)}",
            }

    async def close(self) -> None:
        """Close all cached clients and clear the per-usage client cache."""
        for client in self._clients.values():
            if client:
                await client.close()
        self._clients = {ut: None for ut in LLMUsageType}
# Module-level singleton; created lazily on first access.
_llm_config_manager: LLMConfigManager | None = None


def get_llm_config_manager() -> LLMConfigManager:
    """Get or create the process-wide LLM config manager instance.

    NOTE(review): no locking here — concurrent first calls from multiple
    threads could construct two managers; confirm single-threaded startup.
    """
    global _llm_config_manager
    if _llm_config_manager is None:
        _llm_config_manager = LLMConfigManager()
    return _llm_config_manager