# ai-robot-core/ai-service/tests/test_llm_adapter.py
# (source listing metadata: 320 lines, 11 KiB, Python)

"""
Unit tests for LLM Adapter.
[AC-AISVC-02, AC-AISVC-06] Tests for LLM client interface.
Tests cover:
- Non-streaming generation
- Streaming generation
- Error handling
- Retry logic
"""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from app.services.llm.base import LLMConfig, LLMResponse, LLMStreamChunk
from app.services.llm.openai_client import (
LLMException,
OpenAIClient,
TimeoutException,
)
@pytest.fixture
def mock_settings():
    """Provide a MagicMock stand-in for the application settings object.

    Only the ``llm_*`` attributes read by these tests are populated; any other
    attribute access falls back to an auto-created MagicMock.
    """
    # MagicMock forwards constructor kwargs to configure_mock(), which sets
    # each one as a plain attribute on the mock.
    return MagicMock(
        llm_api_key="test-api-key",
        llm_base_url="https://api.openai.com/v1",
        llm_model="gpt-4o-mini",
        llm_max_tokens=2048,
        llm_temperature=0.7,
        llm_timeout_seconds=30,
        llm_max_retries=3,
    )
@pytest.fixture
def llm_client(mock_settings):
    """Yield an OpenAIClient constructed against ``mock_settings``.

    ``get_settings`` remains patched until the consuming test finishes, so any
    settings lookup the client makes during the test also sees the mock.
    """
    settings_patch = patch(
        "app.services.llm.openai_client.get_settings",
        return_value=mock_settings,
    )
    with settings_patch:
        yield OpenAIClient()
@pytest.fixture
def mock_messages():
    """Provide a two-turn chat history (system prompt + user greeting) as role/content dicts."""
    system_turn = dict(role="system", content="You are a helpful assistant.")
    user_turn = dict(role="user", content="Hello, how are you?")
    return [system_turn, user_turn]
@pytest.fixture
def mock_generate_response():
    """Provide a canned non-streaming chat-completion JSON body.

    Contains a single ``stop``-terminated assistant choice plus token usage
    totals (20 prompt + 15 completion = 35).
    """
    assistant_message = {
        "role": "assistant",
        "content": "Hello! I'm doing well, thank you for asking!",
    }
    only_choice = {"index": 0, "message": assistant_message, "finish_reason": "stop"}
    token_usage = {"prompt_tokens": 20, "completion_tokens": 15, "total_tokens": 35}
    return {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "gpt-4o-mini",
        "choices": [only_choice],
        "usage": token_usage,
    }
@pytest.fixture
def mock_stream_chunks():
    """Provide newline-terminated SSE ``data:`` lines emulating a streamed completion.

    Four content deltas ("Hello", "!", " How", " can I help?") are emitted, the
    last one carrying ``finish_reason: "stop"``, followed by the ``[DONE]``
    sentinel line.
    """
    # Raw event payloads; each is wrapped in the same "data: ...\n" framing below.
    payloads = [
        '{"id":"chatcmpl-123","choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}',
        '{"id":"chatcmpl-123","choices":[{"delta":{"content":"!"},"finish_reason":null}]}',
        '{"id":"chatcmpl-123","choices":[{"delta":{"content":" How"},"finish_reason":null}]}',
        '{"id":"chatcmpl-123","choices":[{"delta":{"content":" can I help?"},"finish_reason":"stop"}]}',
        "[DONE]",
    ]
    return [f"data: {payload}\n" for payload in payloads]
class TestOpenAIClientGenerate:
    """Tests for non-streaming generation. [AC-AISVC-02]"""

    @staticmethod
    def _make_http_client(**post_behaviour):
        """Return an AsyncMock HTTP client whose ``post`` is an AsyncMock built from ``post_behaviour``.

        ``post_behaviour`` is forwarded to ``AsyncMock`` (e.g. ``return_value=``
        or ``side_effect=``).
        """
        http_client = AsyncMock()
        http_client.post = AsyncMock(**post_behaviour)
        return http_client

    @pytest.mark.asyncio
    async def test_generate_success(self, llm_client, mock_messages, mock_generate_response):
        """[AC-AISVC-02] Test successful non-streaming generation."""
        mock_response = MagicMock()
        mock_response.json.return_value = mock_generate_response
        mock_response.raise_for_status = MagicMock()
        mock_client = self._make_http_client(return_value=mock_response)

        with patch.object(llm_client, "_get_client", return_value=mock_client):
            result = await llm_client.generate(mock_messages)

        # A successful call should reach the API exactly once.
        mock_client.post.assert_awaited_once()
        assert isinstance(result, LLMResponse)
        assert result.content == "Hello! I'm doing well, thank you for asking!"
        assert result.model == "gpt-4o-mini"
        assert result.finish_reason == "stop"
        assert result.usage["total_tokens"] == 35

    @pytest.mark.asyncio
    async def test_generate_with_custom_config(self, llm_client, mock_messages, mock_generate_response):
        """[AC-AISVC-02] Test that a custom LLMConfig overrides the request parameters."""
        custom_config = LLMConfig(
            model="gpt-4",
            max_tokens=1024,
            temperature=0.5,
        )
        mock_response = MagicMock()
        mock_response.json.return_value = {**mock_generate_response, "model": "gpt-4"}
        mock_response.raise_for_status = MagicMock()
        mock_client = self._make_http_client(return_value=mock_response)

        with patch.object(llm_client, "_get_client", return_value=mock_client):
            result = await llm_client.generate(mock_messages, config=custom_config)

        mock_client.post.assert_awaited_once()
        # result.model merely echoes the mocked response above, so on its own it
        # proves nothing about the config; check the outgoing request body too.
        # NOTE(review): assumes OpenAIClient sends the body via
        # ``post(..., json=...)`` -- confirm against openai_client.py.
        request_body = mock_client.post.await_args.kwargs["json"]
        assert request_body["model"] == "gpt-4"
        assert request_body["max_tokens"] == 1024
        assert request_body["temperature"] == 0.5
        assert result.model == "gpt-4"

    @pytest.mark.asyncio
    async def test_generate_timeout_error(self, llm_client, mock_messages):
        """[AC-AISVC-02] Test that an httpx timeout surfaces as the client's TimeoutException."""
        mock_client = self._make_http_client(side_effect=httpx.TimeoutException("Timeout"))

        with patch.object(llm_client, "_get_client", return_value=mock_client):
            with pytest.raises(TimeoutException):
                await llm_client.generate(mock_messages)

    @pytest.mark.asyncio
    async def test_generate_api_error(self, llm_client, mock_messages):
        """[AC-AISVC-02] Test that an HTTP error status surfaces as LLMException with the API message."""
        mock_response = MagicMock()
        mock_response.status_code = 401
        mock_response.text = '{"error": {"message": "Invalid API key"}}'
        mock_response.json.return_value = {"error": {"message": "Invalid API key"}}
        http_error = httpx.HTTPStatusError(
            "Unauthorized",
            request=MagicMock(),
            response=mock_response,
        )
        mock_client = self._make_http_client(side_effect=http_error)

        with patch.object(llm_client, "_get_client", return_value=mock_client):
            with pytest.raises(LLMException) as exc_info:
                await llm_client.generate(mock_messages)

        assert "Invalid API key" in str(exc_info.value.message)

    @pytest.mark.asyncio
    async def test_generate_malformed_response(self, llm_client, mock_messages):
        """[AC-AISVC-02] Test that a response body without the expected fields raises LLMException."""
        mock_response = MagicMock()
        mock_response.json.return_value = {"invalid": "response"}
        mock_response.raise_for_status = MagicMock()
        mock_client = self._make_http_client(return_value=mock_response)

        with patch.object(llm_client, "_get_client", return_value=mock_client):
            with pytest.raises(LLMException):
                await llm_client.generate(mock_messages)
class MockAsyncStreamContext:
    """Async context manager that hands back a pre-built response object.

    Used as the return value of a mocked ``client.stream(...)`` so that
    ``async with client.stream(...) as response`` binds the given response.
    Exceptions raised inside the ``async with`` body are not suppressed.
    """

    def __init__(self, response):
        # The object handed out by __aenter__, stored untouched.
        self._wrapped_response = response

    async def __aenter__(self):
        response = self._wrapped_response
        return response

    async def __aexit__(self, *exc_info):
        # A falsy return lets any in-flight exception propagate.
        return None
class _RaisingStreamContext:
    """Async context manager whose ``__aenter__`` raises a preset exception.

    Stands in for ``client.stream(...)`` failing before any line is produced,
    e.g. on a timeout or an HTTP error status.
    """

    def __init__(self, exc):
        self._exc = exc

    async def __aenter__(self):
        raise self._exc

    async def __aexit__(self, *exc_info):
        # Never reached on the failure path; falsy so nothing is suppressed.
        return None


class TestOpenAIClientStreamGenerate:
    """Tests for streaming generation. [AC-AISVC-06, AC-AISVC-07]"""

    @pytest.mark.asyncio
    async def test_stream_generate_success(self, llm_client, mock_messages, mock_stream_chunks):
        """[AC-AISVC-06, AC-AISVC-07] Test successful streaming generation."""

        async def mock_aiter_lines():
            for line in mock_stream_chunks:
                yield line

        mock_response = MagicMock()
        mock_response.raise_for_status = MagicMock()
        mock_response.aiter_lines = mock_aiter_lines
        mock_client = AsyncMock()
        mock_client.stream = MagicMock(return_value=MockAsyncStreamContext(mock_response))

        with patch.object(llm_client, "_get_client", return_value=mock_client):
            chunks = [chunk async for chunk in llm_client.stream_generate(mock_messages)]

        mock_client.stream.assert_called_once()
        # Four content deltas in, four chunks out: the "[DONE]" sentinel line
        # must not surface as a chunk.
        assert len(chunks) == 4
        assert chunks[0].delta == "Hello"
        # Reassembling the deltas must reproduce the full streamed text.
        assert "".join(chunk.delta for chunk in chunks) == "Hello! How can I help?"
        assert chunks[-1].finish_reason == "stop"

    @pytest.mark.asyncio
    async def test_stream_generate_timeout_error(self, llm_client, mock_messages):
        """[AC-AISVC-06] Test that a streaming timeout surfaces as the client's TimeoutException."""
        mock_client = AsyncMock()
        mock_client.stream = MagicMock(
            return_value=_RaisingStreamContext(httpx.TimeoutException("Timeout"))
        )

        with patch.object(llm_client, "_get_client", return_value=mock_client):
            with pytest.raises(TimeoutException):
                async for _ in llm_client.stream_generate(mock_messages):
                    pass

    @pytest.mark.asyncio
    async def test_stream_generate_api_error(self, llm_client, mock_messages):
        """[AC-AISVC-06] Test that a streaming HTTP error status surfaces as LLMException."""
        mock_response = MagicMock()
        mock_response.status_code = 500
        mock_response.text = "Internal Server Error"
        mock_response.json.return_value = {"error": {"message": "Internal Server Error"}}
        http_error = httpx.HTTPStatusError(
            "Internal Server Error",
            request=MagicMock(),
            response=mock_response,
        )
        mock_client = AsyncMock()
        mock_client.stream = MagicMock(return_value=_RaisingStreamContext(http_error))

        with patch.object(llm_client, "_get_client", return_value=mock_client):
            with pytest.raises(LLMException):
                async for _ in llm_client.stream_generate(mock_messages):
                    pass
class TestOpenAIClientConfig:
    """Tests for how OpenAIClient resolves its configuration."""

    # Dotted path of the settings accessor patched in each test below.
    _SETTINGS_TARGET = "app.services.llm.openai_client.get_settings"

    def test_default_config(self, mock_settings):
        """Without constructor overrides, model and generation defaults come from settings."""
        with patch(self._SETTINGS_TARGET, return_value=mock_settings):
            client = OpenAIClient()
            assert client._model == "gpt-4o-mini"
            defaults = client._default_config
            assert (defaults.max_tokens, defaults.temperature) == (2048, 0.7)

    def test_custom_config_override(self, mock_settings):
        """Constructor arguments take precedence over the settings values."""
        overrides = {
            "api_key": "custom-key",
            "base_url": "https://custom.api.com/v1",
            "model": "gpt-4",
        }
        with patch(self._SETTINGS_TARGET, return_value=mock_settings):
            client = OpenAIClient(**overrides)
            assert client._api_key == overrides["api_key"]
            assert client._base_url == overrides["base_url"]
            assert client._model == overrides["model"]
class TestOpenAIClientClose:
    """Tests for client cleanup."""

    @pytest.mark.asyncio
    async def test_close_client(self, llm_client):
        """close() must await the HTTP client's aclose() and drop the reference to it."""
        mock_client = AsyncMock()
        mock_client.aclose = AsyncMock()
        llm_client._client = mock_client

        await llm_client.close()

        # assert_awaited_once rather than assert_called_once: a close() that
        # calls aclose() without awaiting it would never actually close the
        # underlying client, yet would still count as "called".
        mock_client.aclose.assert_awaited_once()
        assert llm_client._client is None