# ai-robot-core/ai-service/app/services/llm/openai_client.py
# (404 lines, 14 KiB, Python)

"""
OpenAI-compatible LLM client implementation.
[AC-AISVC-02, AC-AISVC-06] Concrete implementation using httpx for OpenAI API.
Design reference: design.md Section 8.1 - LLMClient interface
- Uses langchain-openai or official SDK pattern
- Supports generate and stream_generate
"""
import json
import logging
from collections.abc import AsyncGenerator
from typing import Any
import httpx
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.core.config import get_settings
from app.core.exceptions import AIServiceException, ErrorCode, TimeoutException
from app.services.llm.base import (
LLMClient,
LLMConfig,
LLMResponse,
LLMStreamChunk,
ToolCall,
ToolDefinition,
)
logger = logging.getLogger(__name__)
class LLMException(AIServiceException):
    """Raised when an LLM operation fails; surfaces as HTTP 503 upstream."""

    def __init__(self, message: str, details: list[dict] | None = None):
        """Wrap *message* (and optional structured *details*) as an LLM error."""
        super().__init__(
            code=ErrorCode.LLM_ERROR,
            status_code=503,
            message=message,
            details=details,
        )
class OpenAIClient(LLMClient):
    """
    OpenAI-compatible LLM client.
    [AC-AISVC-02, AC-AISVC-06] Implements LLMClient interface for OpenAI API.
    Supports:
    - OpenAI API (official)
    - OpenAI-compatible endpoints (Azure, local models, etc.)
    """
    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model: str | None = None,
        default_config: LLMConfig | None = None,
    ):
        """Initialize from explicit arguments, falling back to app settings.

        Args:
            api_key: API key; defaults to settings.llm_api_key.
            base_url: Endpoint base URL; trailing "/" is stripped so the
                "/chat/completions" path joins stay well-formed.
            model: Model identifier; defaults to settings.llm_model.
            default_config: Generation defaults; built from settings if None.
        """
        settings = get_settings()
        self._api_key = api_key or settings.llm_api_key
        self._base_url = (base_url or settings.llm_base_url).rstrip("/")
        self._model = model or settings.llm_model
        self._default_config = default_config or LLMConfig(
            model=self._model,
            max_tokens=settings.llm_max_tokens,
            temperature=settings.llm_temperature,
            timeout_seconds=settings.llm_timeout_seconds,
            max_retries=settings.llm_max_retries,
        )
        self._client: httpx.AsyncClient | None = None
        # Timeout the client was created with (informational; runtime timeout
        # overrides are now applied per-request, not by recreating the client).
        self._client_timeout_seconds: int | None = None

    def _get_client(self, timeout_seconds: int) -> httpx.AsyncClient:
        """Get or lazily create the shared HTTP client.

        Bug fix: the previous implementation recreated the AsyncClient
        whenever the timeout changed and leaked the old one — it was never
        closed, because this method is synchronous and cannot await aclose().
        We now keep a single long-lived client (preserving connection
        pooling) and pass runtime timeout overrides per-request via the
        `timeout=` argument in generate()/stream_generate().

        Args:
            timeout_seconds: Default timeout, used only when the client is
                first created.
        """
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(timeout_seconds),
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
            )
            self._client_timeout_seconds = timeout_seconds
        return self._client

    def _build_request_body(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig,
        stream: bool = False,
        tools: list[ToolDefinition] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the request body for the OpenAI /chat/completions API.

        config.extra_params and **kwargs are merged last so they may override
        the standard fields (kwargs wins over extra_params).
        """
        body: dict[str, Any] = {
            "model": config.model,
            "messages": messages,
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "stream": stream,
        }
        if tools:
            body["tools"] = [tool.to_openai_format() for tool in tools]
        if tool_choice:
            body["tool_choice"] = tool_choice
        body.update(config.extra_params)
        body.update(kwargs)
        return body

    @retry(
        # Bug fix: generate() converts httpx.TimeoutException into the
        # app-level TimeoutException before it propagates, so retrying only
        # on httpx.TimeoutException never fired. Retry on both, and reraise
        # the original exception (not tenacity's RetryError) after the final
        # attempt so the documented "Raises: TimeoutException" still holds.
        retry=retry_if_exception_type((httpx.TimeoutException, TimeoutException)),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        reraise=True,
    )
    async def generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Generate a non-streaming response.
        [AC-AISVC-02] Returns complete response for ChatResponse.
        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            tools: Optional list of tools for function calling.
            tool_choice: Tool choice strategy ("auto", "none", or specific tool).
            **kwargs: Additional provider-specific parameters.
        Returns:
            LLMResponse with generated content, tool_calls, and metadata.
        Raises:
            LLMException: If generation fails.
            TimeoutException: If request times out (after retries).
        """
        effective_config = config or self._default_config
        client = self._get_client(effective_config.timeout_seconds)
        body = self._build_request_body(
            messages, effective_config, stream=False,
            tools=tools, tool_choice=tool_choice, **kwargs
        )
        logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}")
        if tools:
            logger.info(f"[AC-AISVC-02] Function calling enabled with {len(tools)} tools")
        # Prompt contents are dumped at DEBUG (previously INFO) so user
        # content does not flood or leak into production INFO logs.
        logger.debug("[AC-AISVC-02] ========== FULL PROMPT TO AI ==========")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            logger.debug(f"[AC-AISVC-02] [{i}] role={role}, content_length={len(content)}")
            logger.debug(f"[AC-AISVC-02] [{i}] content:\n{content}")
        logger.debug("[AC-AISVC-02] ======================================")
        try:
            response = await client.post(
                f"{self._base_url}/chat/completions",
                json=body,
                # Per-request timeout so runtime config changes take effect
                # without recreating (and leaking) the shared client.
                timeout=httpx.Timeout(effective_config.timeout_seconds),
            )
            response.raise_for_status()
            data = response.json()
        except httpx.TimeoutException as e:
            logger.error(f"[AC-AISVC-02] LLM request timeout: {e}")
            raise TimeoutException(message=f"LLM request timed out: {e}") from e
        except httpx.HTTPStatusError as e:
            logger.error(f"[AC-AISVC-02] LLM API error: {e}")
            error_detail = self._parse_error_response(e.response)
            raise LLMException(
                message=f"LLM API error: {error_detail}",
                details=[{"status_code": e.response.status_code, "response": error_detail}],
            ) from e
        except json.JSONDecodeError as e:
            logger.error(f"[AC-AISVC-02] Failed to parse LLM response: {e}")
            raise LLMException(message=f"Failed to parse LLM response: {e}") from e
        try:
            choice = data["choices"][0]
            message = choice["message"]
            # content may be None when the model responds with tool calls only.
            content = message.get("content")
            usage = data.get("usage", {})
            finish_reason = choice.get("finish_reason", "stop")
            tool_calls = self._parse_tool_calls(message)
            logger.info(
                f"[AC-AISVC-02] Generated response: "
                f"tokens={usage.get('total_tokens', 'N/A')}, "
                f"finish_reason={finish_reason}, "
                f"tool_calls={len(tool_calls)}"
            )
            return LLMResponse(
                content=content,
                model=data.get("model", effective_config.model),
                usage=usage,
                finish_reason=finish_reason,
                tool_calls=tool_calls,
                metadata={"raw_response": data},
            )
        except (KeyError, IndexError) as e:
            logger.error(f"[AC-AISVC-02] Unexpected LLM response format: {e}")
            raise LLMException(
                message=f"Unexpected LLM response format: {e}",
                details=[{"response": str(data)}],
            ) from e

    def _parse_tool_calls(self, message: dict[str, Any]) -> list[ToolCall]:
        """Parse tool calls from an LLM response message.

        Malformed JSON in a tool call's arguments degrades to an empty dict
        rather than failing the whole response.
        """
        tool_calls = []
        raw_tool_calls = message.get("tool_calls", [])
        for tc in raw_tool_calls:
            if tc.get("type") == "function":
                func = tc.get("function", {})
                try:
                    arguments = json.loads(func.get("arguments", "{}"))
                except json.JSONDecodeError:
                    arguments = {}
                tool_calls.append(ToolCall(
                    id=tc.get("id", ""),
                    name=func.get("name", ""),
                    arguments=arguments,
                ))
        return tool_calls

    async def stream_generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[LLMStreamChunk, None]:
        """
        Generate a streaming response.
        [AC-AISVC-06, AC-AISVC-07] Yields incremental chunks for SSE.
        Args:
            messages: List of chat messages with 'role' and 'content'.
            config: Optional LLM configuration overrides.
            tools: Optional list of tools for function calling.
            tool_choice: Tool choice strategy ("auto", "none", or specific tool).
            **kwargs: Additional provider-specific parameters.
        Yields:
            LLMStreamChunk with incremental content.
        Raises:
            LLMException: If generation fails.
            TimeoutException: If request times out.
        """
        effective_config = config or self._default_config
        client = self._get_client(effective_config.timeout_seconds)
        body = self._build_request_body(
            messages, effective_config, stream=True,
            tools=tools, tool_choice=tool_choice, **kwargs
        )
        logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}")
        if tools:
            logger.info(f"[AC-AISVC-06] Function calling enabled with {len(tools)} tools")
        # Prompt dump at DEBUG, matching generate() — see note there.
        logger.debug("[AC-AISVC-06] ========== FULL PROMPT TO AI (STREAMING) ==========")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            logger.debug(f"[AC-AISVC-06] [{i}] role={role}, content_length={len(content)}")
            logger.debug(f"[AC-AISVC-06] [{i}] content:\n{content}")
        logger.debug("[AC-AISVC-06] ======================================")
        try:
            async with client.stream(
                "POST",
                f"{self._base_url}/chat/completions",
                json=body,
                # Per-request timeout override (see _get_client docstring).
                timeout=httpx.Timeout(effective_config.timeout_seconds),
            ) as response:
                if response.status_code >= 400:
                    # Bug fix: buffer the error body while the stream is still
                    # open, otherwise _parse_error_response would hit
                    # httpx.ResponseNotRead when reading e.response after the
                    # context has closed.
                    await response.aread()
                response.raise_for_status()
                async for raw_line in response.aiter_lines():
                    # Tolerate trailing whitespace/CR from SSE servers.
                    line = raw_line.strip()
                    if not line or line == "data: [DONE]":
                        continue
                    if line.startswith("data: "):
                        json_str = line[6:]
                        try:
                            chunk_data = json.loads(json_str)
                        except json.JSONDecodeError as e:
                            # Skip malformed chunks; don't kill the stream.
                            logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}")
                            continue
                        chunk = self._parse_stream_chunk(chunk_data, effective_config.model)
                        if chunk:
                            yield chunk
        except httpx.TimeoutException as e:
            logger.error(f"[AC-AISVC-06] LLM streaming request timeout: {e}")
            raise TimeoutException(message=f"LLM streaming request timed out: {e}") from e
        except httpx.HTTPStatusError as e:
            logger.error(f"[AC-AISVC-06] LLM streaming API error: {e}")
            error_detail = self._parse_error_response(e.response)
            raise LLMException(
                message=f"LLM streaming API error: {error_detail}",
                details=[{"status_code": e.response.status_code, "response": error_detail}],
            ) from e
        logger.info("[AC-AISVC-06] Streaming generation completed")

    def _parse_stream_chunk(
        self,
        data: dict[str, Any],
        model: str,
    ) -> LLMStreamChunk | None:
        """Parse a streaming chunk from the OpenAI API.

        Returns None for keep-alive/empty chunks (no content and no
        finish_reason) so callers can skip them.
        """
        try:
            choices = data.get("choices", [])
            if not choices:
                return None
            delta = choices[0].get("delta", {})
            content = delta.get("content", "")
            finish_reason = choices[0].get("finish_reason")
            if not content and not finish_reason:
                return None
            return LLMStreamChunk(
                delta=content,
                model=data.get("model", model),
                finish_reason=finish_reason,
                metadata={"raw_chunk": data},
            )
        except (KeyError, IndexError) as e:
            logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}")
            return None

    def _parse_error_response(self, response: httpx.Response) -> str:
        """Extract a human-readable error message from an API error response.

        Bug fix: the previous fallback returned response.text inside the
        except handler, which itself raises httpx.ResponseNotRead for
        streaming responses whose body was never read — masking the real
        error. The fallback is now guarded and degrades to the status code.
        """
        try:
            data = response.json()
        except Exception:
            data = None
        if isinstance(data, dict) and "error" in data:
            error = data["error"]
            if isinstance(error, dict):
                return error.get("message", str(error))
            return str(error)
        try:
            return response.text
        except Exception:
            # Body unavailable (e.g. unread streaming response).
            return f"HTTP {response.status_code}"

    async def close(self) -> None:
        """Close the HTTP client. Safe to call multiple times."""
        if self._client:
            await self._client.aclose()
            self._client = None
# Module-level singleton holder, populated lazily by get_llm_client().
_llm_client: OpenAIClient | None = None


def get_llm_client() -> OpenAIClient:
    """Return the process-wide LLM client, constructing it on first use."""
    global _llm_client
    if _llm_client is not None:
        return _llm_client
    _llm_client = OpenAIClient()
    return _llm_client
async def close_llm_client() -> None:
    """Dispose of the cached global LLM client, if one was ever created."""
    global _llm_client
    if _llm_client is not None:
        await _llm_client.close()
        _llm_client = None