"""
Unit tests for LLM Adapter.

[AC-AISVC-02, AC-AISVC-06] Tests for LLM client interface.

Tests cover:
- Non-streaming generation
- Streaming generation
- Error handling
- Retry logic
"""
|
|
|
|
import json
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

from app.services.llm.base import LLMConfig, LLMResponse, LLMStreamChunk
from app.services.llm.openai_client import (
    LLMException,
    OpenAIClient,
    TimeoutException,
)
|
|
|
|
|
|
@pytest.fixture
def mock_settings():
    """Mock settings for testing.

    Returns a MagicMock carrying every ``llm_*`` attribute the client
    reads, so OpenAIClient can be constructed without real config.
    """
    stub = MagicMock()
    stub.llm_api_key = "test-api-key"
    stub.llm_base_url = "https://api.openai.com/v1"
    stub.llm_model = "gpt-4o-mini"
    stub.llm_max_tokens = 2048
    stub.llm_temperature = 0.7
    stub.llm_timeout_seconds = 30
    stub.llm_max_retries = 3
    return stub
|
|
|
|
|
|
@pytest.fixture
def llm_client(mock_settings):
    """Yield an OpenAIClient built against the mocked settings.

    get_settings is patched for the fixture's whole lifetime so any
    lazy settings lookup inside the client also hits the mock.
    """
    settings_patch = patch(
        "app.services.llm.openai_client.get_settings",
        return_value=mock_settings,
    )
    with settings_patch:
        yield OpenAIClient()
|
|
|
|
|
|
@pytest.fixture
def mock_messages():
    """Sample chat messages for testing: one system and one user turn."""
    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    user_turn = {"role": "user", "content": "Hello, how are you?"}
    return [system_turn, user_turn]
|
|
|
|
|
|
@pytest.fixture
def mock_generate_response():
    """Sample non-streaming chat.completion payload from the OpenAI API."""
    assistant_message = {
        "role": "assistant",
        "content": "Hello! I'm doing well, thank you for asking!",
    }
    return {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": assistant_message,
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 20,
            "completion_tokens": 15,
            "total_tokens": 35,
        },
    }
|
|
|
|
|
|
@pytest.fixture
def mock_stream_chunks():
    """Sample SSE lines from a streaming OpenAI API response.

    Four content deltas (the last carrying finish_reason "stop")
    followed by the [DONE] sentinel.
    """
    return [
        'data: {"id":"chatcmpl-123","choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}\n',
        'data: {"id":"chatcmpl-123","choices":[{"delta":{"content":"!"},"finish_reason":null}]}\n',
        'data: {"id":"chatcmpl-123","choices":[{"delta":{"content":" How"},"finish_reason":null}]}\n',
        'data: {"id":"chatcmpl-123","choices":[{"delta":{"content":" can I help?"},"finish_reason":"stop"}]}\n',
        "data: [DONE]\n",
    ]
|
|
|
|
|
|
class TestOpenAIClientGenerate:
    """Tests for non-streaming generation. [AC-AISVC-02]"""

    @staticmethod
    def _ok_response(payload):
        # httpx-like response stub: .json() yields ``payload`` and
        # raise_for_status() is a harmless no-op.
        fake = MagicMock()
        fake.json.return_value = payload
        fake.raise_for_status = MagicMock()
        return fake

    @staticmethod
    def _client_returning(response):
        # Async HTTP client stub whose .post() resolves to ``response``.
        fake = AsyncMock()
        fake.post = AsyncMock(return_value=response)
        return fake

    @pytest.mark.asyncio
    async def test_generate_success(self, llm_client, mock_messages, mock_generate_response):
        """[AC-AISVC-02] Test successful non-streaming generation."""
        http_client = self._client_returning(self._ok_response(mock_generate_response))

        with patch.object(llm_client, "_get_client", return_value=http_client):
            result = await llm_client.generate(mock_messages)

        assert isinstance(result, LLMResponse)
        assert result.content == "Hello! I'm doing well, thank you for asking!"
        assert result.model == "gpt-4o-mini"
        assert result.finish_reason == "stop"
        assert result.usage["total_tokens"] == 35

    @pytest.mark.asyncio
    async def test_generate_with_custom_config(self, llm_client, mock_messages, mock_generate_response):
        """[AC-AISVC-02] Test generation with custom configuration."""
        override = LLMConfig(model="gpt-4", max_tokens=1024, temperature=0.5)
        payload = {**mock_generate_response, "model": "gpt-4"}
        http_client = self._client_returning(self._ok_response(payload))

        with patch.object(llm_client, "_get_client", return_value=http_client):
            result = await llm_client.generate(mock_messages, config=override)

        assert result.model == "gpt-4"

    @pytest.mark.asyncio
    async def test_generate_timeout_error(self, llm_client, mock_messages):
        """[AC-AISVC-02] Test timeout error handling."""
        http_client = AsyncMock()
        http_client.post = AsyncMock(side_effect=httpx.TimeoutException("Timeout"))

        with patch.object(llm_client, "_get_client", return_value=http_client):
            with pytest.raises(TimeoutException):
                await llm_client.generate(mock_messages)

    @pytest.mark.asyncio
    async def test_generate_api_error(self, llm_client, mock_messages):
        """[AC-AISVC-02] Test API error handling."""
        error_response = MagicMock()
        error_response.status_code = 401
        error_response.text = '{"error": {"message": "Invalid API key"}}'
        error_response.json.return_value = {"error": {"message": "Invalid API key"}}
        failure = httpx.HTTPStatusError(
            "Unauthorized",
            request=MagicMock(),
            response=error_response,
        )

        http_client = AsyncMock()
        http_client.post = AsyncMock(side_effect=failure)

        with patch.object(llm_client, "_get_client", return_value=http_client):
            with pytest.raises(LLMException) as exc_info:
                await llm_client.generate(mock_messages)

        # The API's error message must survive into the raised exception.
        assert "Invalid API key" in str(exc_info.value.message)

    @pytest.mark.asyncio
    async def test_generate_malformed_response(self, llm_client, mock_messages):
        """[AC-AISVC-02] Test handling of malformed response."""
        # Payload lacks the expected "choices"/"usage" structure.
        http_client = self._client_returning(self._ok_response({"invalid": "response"}))

        with patch.object(llm_client, "_get_client", return_value=http_client):
            with pytest.raises(LLMException):
                await llm_client.generate(mock_messages)
|
|
|
|
|
|
class MockAsyncStreamContext:
    """Async context manager that hands back a pre-built response.

    Stands in for ``httpx.AsyncClient.stream(...)`` in tests: entering
    the context yields the wrapped response object; exiting is a no-op.
    """

    def __init__(self, response):
        # The object to hand out from __aenter__.
        self._response = response

    async def __aenter__(self):
        return self._response

    async def __aexit__(self, *args):
        # Nothing to clean up; exceptions are never suppressed.
        return None
|
|
|
|
|
|
class TestOpenAIClientStreamGenerate:
    """Tests for streaming generation. [AC-AISVC-06, AC-AISVC-07]"""

    @staticmethod
    def _raising_stream_ctx(exc):
        # Async context manager that raises ``exc`` on entry, mimicking
        # a transport failure at the moment the stream is opened.
        class _Raising:
            async def __aenter__(self):
                raise exc

            async def __aexit__(self, *args):
                return None

        return _Raising()

    @pytest.mark.asyncio
    async def test_stream_generate_success(self, llm_client, mock_messages, mock_stream_chunks):
        """[AC-AISVC-06, AC-AISVC-07] Test successful streaming generation."""

        async def fake_lines():
            for line in mock_stream_chunks:
                yield line

        stream_response = MagicMock()
        stream_response.raise_for_status = MagicMock()
        stream_response.aiter_lines = fake_lines

        http_client = AsyncMock()
        http_client.stream = MagicMock(return_value=MockAsyncStreamContext(stream_response))

        with patch.object(llm_client, "_get_client", return_value=http_client):
            received = [chunk async for chunk in llm_client.stream_generate(mock_messages)]

        # Four data chunks; the [DONE] sentinel is consumed, not yielded.
        assert len(received) == 4
        assert received[0].delta == "Hello"
        assert received[-1].finish_reason == "stop"

    @pytest.mark.asyncio
    async def test_stream_generate_timeout_error(self, llm_client, mock_messages):
        """[AC-AISVC-06] Test streaming timeout error handling."""
        http_client = AsyncMock()
        http_client.stream = MagicMock(
            return_value=self._raising_stream_ctx(httpx.TimeoutException("Timeout"))
        )

        with patch.object(llm_client, "_get_client", return_value=http_client):
            with pytest.raises(TimeoutException):
                async for _ in llm_client.stream_generate(mock_messages):
                    pass

    @pytest.mark.asyncio
    async def test_stream_generate_api_error(self, llm_client, mock_messages):
        """[AC-AISVC-06] Test streaming API error handling."""
        error_response = MagicMock()
        error_response.status_code = 500
        error_response.text = "Internal Server Error"
        error_response.json.return_value = {"error": {"message": "Internal Server Error"}}
        failure = httpx.HTTPStatusError(
            "Internal Server Error",
            request=MagicMock(),
            response=error_response,
        )

        http_client = AsyncMock()
        http_client.stream = MagicMock(return_value=self._raising_stream_ctx(failure))

        with patch.object(llm_client, "_get_client", return_value=http_client):
            with pytest.raises(LLMException):
                async for _ in llm_client.stream_generate(mock_messages):
                    pass
|
|
|
|
|
|
class TestOpenAIClientConfig:
    """Tests for LLM configuration."""

    def test_default_config(self, mock_settings):
        """Defaults must be pulled from the application settings."""
        settings_patch = patch(
            "app.services.llm.openai_client.get_settings",
            return_value=mock_settings,
        )
        with settings_patch:
            built = OpenAIClient()

        assert built._model == "gpt-4o-mini"
        assert built._default_config.max_tokens == 2048
        assert built._default_config.temperature == 0.7

    def test_custom_config_override(self, mock_settings):
        """Constructor arguments must override values from settings."""
        settings_patch = patch(
            "app.services.llm.openai_client.get_settings",
            return_value=mock_settings,
        )
        with settings_patch:
            built = OpenAIClient(
                api_key="custom-key",
                base_url="https://custom.api.com/v1",
                model="gpt-4",
            )

        assert built._api_key == "custom-key"
        assert built._base_url == "https://custom.api.com/v1"
        assert built._model == "gpt-4"
|
|
|
|
|
|
class TestOpenAIClientClose:
    """Tests for client cleanup."""

    @pytest.mark.asyncio
    async def test_close_client(self, llm_client):
        """Test client close releases resources."""
        underlying = AsyncMock()
        underlying.aclose = AsyncMock()
        llm_client._client = underlying

        await llm_client.close()

        # close() must both call aclose() and drop the cached reference.
        underlying.aclose.assert_called_once()
        assert llm_client._client is None
|