[AC-TEST] test: 新增单元测试和集成测试
- 新增 test_image_parser 图片解析器测试
- 新增 test_llm_multi_usage_config LLM 多用途配置测试
- 新增 test_markdown_chunker Markdown 分块测试
- 新增 test_metadata_auto_inference 元数据推断测试
- 新增 test_mid_dialogue_integration 对话集成测试
- 新增 test_retrieval_strategy 检索策略测试
- 新增 test_retrieval_strategy_integration 检索策略集成测试
This commit is contained in:
parent
1490235b8f
commit
a6276522c8
|
|
@ -0,0 +1,375 @@
|
|||
"""
|
||||
Unit tests for ImageParser.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from app.services.document.image_parser import (
|
||||
ImageChunk,
|
||||
ImageParseResult,
|
||||
ImageParser,
|
||||
)
|
||||
|
||||
|
||||
class TestImageParserBasics:
    """Basic behaviour of ImageParser: supported extensions and MIME lookup."""

    def test_supported_extensions(self):
        """The parser advertises exactly the known image extensions, in order."""
        assert ImageParser().get_supported_extensions() == [
            ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif",
        ]

    def test_get_mime_type(self):
        """Extension-to-MIME mapping, including the JPEG fallback for unknowns."""
        parser = ImageParser()
        expected = {
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".png": "image/png",
            ".gif": "image/gif",
            ".webp": "image/webp",
            ".bmp": "image/bmp",
            ".tiff": "image/tiff",
            ".tif": "image/tiff",
            # Unknown extensions fall back to JPEG.
            ".unknown": "image/jpeg",
        }
        for extension, mime in expected.items():
            assert parser._get_mime_type(extension) == mime
|
||||
|
||||
|
||||
class TestImageChunkParsing:
    """Test LLM response parsing functionality.

    Covers ``_extract_json`` (locating JSON in raw / markdown / mixed text)
    and ``_parse_llm_response`` (structured parsing plus its fallback path,
    which wraps the whole raw response into a single chunk).
    """

    def test_extract_json_from_plain_json(self):
        """Test extracting JSON from plain JSON response."""
        parser = ImageParser()

        # Already-valid JSON should be returned unchanged.
        json_str = '{"image_summary": "test", "chunks": [{"chunk_index": 0, "content": "hello", "chunk_type": "text", "keywords": ["key"]}]}'
        result = parser._extract_json(json_str)

        assert result == json_str

    def test_extract_json_from_markdown(self):
        """Test extracting JSON from markdown code block."""
        parser = ImageParser()

        # JSON wrapped in a ```json fence with prose before and after.
        markdown = """Here is the analysis:

```json
{"image_summary": "test", "chunks": [{"chunk_index": 0, "content": "hello"}]}
```

Hope this helps!"""

        result = parser._extract_json(markdown)
        assert "image_summary" in result
        assert "test" in result

    def test_extract_json_from_text_with_json(self):
        """Test extracting JSON from text with JSON embedded."""
        parser = ImageParser()

        # Note: single-quoted, i.e. not strictly valid JSON — the extractor
        # only locates the brace-delimited span, it does not validate it.
        text = "The result is: {'image_summary': 'summary', 'chunks': []}"

        result = parser._extract_json(text)
        assert "image_summary" in result
        assert "chunks" in result

    def test_parse_llm_response_valid_json(self):
        """Test parsing valid JSON response from LLM."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "测试图片",
            "total_chunks": 2,
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "这是第一块内容",
                    "chunk_type": "text",
                    "keywords": ["测试", "内容"]
                },
                {
                    "chunk_index": 1,
                    "content": "这是第二块内容,包含表格数据",
                    "chunk_type": "table",
                    "keywords": ["表格", "数据"]
                }
            ]
        })

        result = parser._parse_llm_response(response)

        assert result.image_summary == "测试图片"
        assert len(result.chunks) == 2
        assert result.chunks[0].content == "这是第一块内容"
        assert result.chunks[0].chunk_type == "text"
        assert result.chunks[0].keywords == ["测试", "内容"]
        assert result.chunks[1].chunk_type == "table"
        assert result.chunks[1].keywords == ["表格", "数据"]

    def test_parse_llm_response_empty_chunks(self):
        """Test handling response with empty chunks."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "测试",
            "chunks": []
        })

        result = parser._parse_llm_response(response)

        # An empty chunk list triggers the fallback: the entire raw response
        # string becomes the content of a single chunk.
        assert len(result.chunks) == 1
        assert result.chunks[0].content == response

    def test_parse_llm_response_invalid_json(self):
        """Test handling invalid JSON response with fallback."""
        parser = ImageParser()

        response = "This is not JSON at all"

        result = parser._parse_llm_response(response)

        # Non-JSON text also falls back to a single raw-content chunk.
        assert len(result.chunks) == 1
        assert result.chunks[0].content == "This is not JSON at all"

    def test_parse_llm_response_partial_json(self):
        """Test handling response with partial/invalid JSON uses fallback."""
        parser = ImageParser()

        # Malformed JSON (unbalanced / interleaved text) must not raise.
        response = '{"image_summary": "test" some text here {"chunks": []}'

        result = parser._parse_llm_response(response)

        assert len(result.chunks) == 1
        assert result.chunks[0].content == response
|
||||
|
||||
|
||||
class TestImageChunkDataClass:
    """Construction and defaults of ImageChunk / ImageParseResult."""

    def test_image_chunk_creation(self):
        """All explicitly supplied fields are stored verbatim."""
        piece = ImageChunk(
            chunk_index=0,
            content="Test content",
            chunk_type="text",
            keywords=["test", "content"],
        )

        assert piece.chunk_index == 0
        assert piece.content == "Test content"
        assert piece.chunk_type == "text"
        assert piece.keywords == ["test", "content"]

    def test_image_chunk_default_values(self):
        """Omitted fields default to type 'text' and an empty keyword list."""
        piece = ImageChunk(chunk_index=0, content="Test")

        assert piece.chunk_type == "text"
        assert piece.keywords == []

    def test_image_parse_result_creation(self):
        """ImageParseResult carries its chunks and file-level metadata unchanged."""
        parts = [
            ImageChunk(chunk_index=0, content="Chunk 1"),
            ImageChunk(chunk_index=1, content="Chunk 2"),
        ]

        parsed = ImageParseResult(
            image_summary="Test summary",
            chunks=parts,
            raw_text="Chunk 1\n\nChunk 2",
            source_path="/path/to/image.png",
            file_size=1024,
            metadata={"format": "png"},
        )

        assert parsed.image_summary == "Test summary"
        assert len(parsed.chunks) == 2
        assert parsed.raw_text == "Chunk 1\n\nChunk 2"
        assert parsed.file_size == 1024
        assert parsed.metadata["format"] == "png"
|
||||
|
||||
|
||||
class TestChunkTypes:
    """Test different chunk types.

    The four tests below differed only in their literal payload; the shared
    build-and-parse plumbing is factored into ``_parse_single_chunk`` so each
    test states just its data and expectation.
    """

    def _parse_single_chunk(self, summary, content, chunk_type, keywords):
        """Build a one-chunk LLM JSON response and return the parsed result."""
        parser = ImageParser()
        response = json.dumps({
            "image_summary": summary,
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": content,
                    "chunk_type": chunk_type,
                    "keywords": keywords,
                }
            ],
        })
        return parser._parse_llm_response(response)

    def test_text_chunk_type(self):
        """Test text chunk type."""
        result = self._parse_single_chunk(
            "Text content", "Plain text content", "text", ["text"]
        )
        assert result.chunks[0].chunk_type == "text"

    def test_table_chunk_type(self):
        """Test table chunk type."""
        result = self._parse_single_chunk(
            "Table content",
            "Name | Age\n---- | ---\nJohn | 30",
            "table",
            ["table", "data"],
        )
        assert result.chunks[0].chunk_type == "table"

    def test_chart_chunk_type(self):
        """Test chart chunk type."""
        result = self._parse_single_chunk(
            "Chart content", "Bar chart showing sales data", "chart", ["chart", "sales"]
        )
        assert result.chunks[0].chunk_type == "chart"

    def test_list_chunk_type(self):
        """Test list chunk type."""
        result = self._parse_single_chunk(
            "List content",
            "1. First item\n2. Second item\n3. Third item",
            "list",
            ["list", "items"],
        )
        assert result.chunks[0].chunk_type == "list"
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
    """Test integration scenarios.

    End-to-end shapes of ``_parse_llm_response``: single chunk, multiple
    chunks joined into ``raw_text``, and mixed text/table/list content.
    """

    def test_single_chunk_scenario(self):
        """Test single chunk scenario - simple image with one main content."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "简单文档截图",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "这是一段完整的文档内容,包含所有的信息要点。",
                    "chunk_type": "text",
                    "keywords": ["文档", "信息"]
                }
            ]
        })

        result = parser._parse_llm_response(response)

        assert len(result.chunks) == 1
        assert result.image_summary == "简单文档截图"
        # With a single chunk, raw_text is exactly that chunk's content.
        assert result.raw_text == "这是一段完整的文档内容,包含所有的信息要点。"

    def test_multi_chunk_scenario(self):
        """Test multi-chunk scenario - complex image with multiple sections."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "多段落文档",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "第一章:介绍部分,介绍项目的背景和目标。",
                    "chunk_type": "text",
                    "keywords": ["第一章", "介绍"]
                },
                {
                    "chunk_index": 1,
                    "content": "第二章:技术架构,包括前端、后端和数据库设计。",
                    "chunk_type": "text",
                    "keywords": ["第二章", "架构"]
                },
                {
                    "chunk_index": 2,
                    "content": "第三章:部署流程,包含开发环境和生产环境配置。",
                    "chunk_type": "text",
                    "keywords": ["第三章", "部署"]
                }
            ]
        })

        result = parser._parse_llm_response(response)

        assert len(result.chunks) == 3
        assert "第一章" in result.chunks[0].content
        assert "第二章" in result.chunks[1].content
        assert "第三章" in result.chunks[2].content
        # raw_text joins chunks with blank lines: 3 chunks -> 2 separators.
        assert result.raw_text.count("\n\n") == 2

    def test_mixed_content_scenario(self):
        """Test mixed content scenario - text and table."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "混合内容图片",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "产品介绍:本文档介绍我们的核心产品功能。",
                    "chunk_type": "text",
                    "keywords": ["产品", "功能"]
                },
                {
                    "chunk_index": 1,
                    "content": "产品规格表:\n| 型号 | 价格 | 库存 |\n| --- | --- | --- |\n| A1 | 100 | 50 |",
                    "chunk_type": "table",
                    "keywords": ["规格", "价格", "库存"]
                },
                {
                    "chunk_index": 2,
                    "content": "使用说明:\n1. 打开包装\n2. 连接电源\n3. 按下启动按钮",
                    "chunk_type": "list",
                    "keywords": ["说明", "步骤"]
                }
            ]
        })

        result = parser._parse_llm_response(response)

        # Per-chunk types survive parsing untouched.
        assert len(result.chunks) == 3
        assert result.chunks[0].chunk_type == "text"
        assert result.chunks[1].chunk_type == "table"
        assert result.chunks[2].chunk_type == "list"
|
||||
|
||||
|
||||
# Allow running this test module directly (python test_image_parser.py).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
|
|
@ -0,0 +1,332 @@
|
|||
"""
|
||||
Unit tests for multi-usage LLM configuration.
|
||||
Tests for LLMUsageType, LLMConfigManager multi-usage support, and API endpoints.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.llm.factory import (
|
||||
LLMConfigManager,
|
||||
LLMProviderFactory,
|
||||
LLMUsageType,
|
||||
LLM_USAGE_DESCRIPTIONS,
|
||||
LLM_USAGE_DISPLAY_NAMES,
|
||||
get_llm_config_manager,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_settings():
    """Return a MagicMock standing in for the application settings object."""
    attribute_values = {
        "llm_provider": "openai",
        "llm_api_key": "test-api-key",
        "llm_base_url": "https://api.openai.com/v1",
        "llm_model": "gpt-4o-mini",
        "llm_max_tokens": 2048,
        "llm_temperature": 0.7,
        "llm_timeout_seconds": 30,
        "llm_max_retries": 3,
        "redis_enabled": False,
        "redis_url": "redis://localhost:6379",
    }
    settings = MagicMock()
    for attribute, value in attribute_values.items():
        setattr(settings, attribute, value)
    return settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_singleton():
    """Reset the module-level LLMConfigManager singleton around every test.

    Clearing before the test isolates it from state left by earlier tests;
    clearing after prevents leaking this test's manager into later ones.
    """
    import app.services.llm.factory as factory
    factory._llm_config_manager = None
    yield
    factory._llm_config_manager = None
|
||||
|
||||
|
||||
@pytest.fixture
def isolated_config_file(tmp_path):
    """Provide a per-test LLM config file seeded with an empty JSON object."""
    path = tmp_path / "llm_config.json"
    path.write_text("{}")
    return path
|
||||
|
||||
|
||||
class TestLLMUsageType:
    """Behaviour of the LLMUsageType enum and its metadata mappings."""

    def test_usage_types_exist(self):
        """Both core usage types map to their expected string values."""
        assert LLMUsageType.CHAT.value == "chat"
        assert LLMUsageType.KB_PROCESSING.value == "kb_processing"

    def test_usage_type_display_names(self):
        """Every usage type has a non-empty string display name."""
        for usage in LLMUsageType:
            assert usage in LLM_USAGE_DISPLAY_NAMES
            display = LLM_USAGE_DISPLAY_NAMES[usage]
            assert isinstance(display, str)
            assert display  # non-empty

    def test_usage_type_descriptions(self):
        """Every usage type has a non-empty string description."""
        for usage in LLMUsageType:
            assert usage in LLM_USAGE_DESCRIPTIONS
            description = LLM_USAGE_DESCRIPTIONS[usage]
            assert isinstance(description, str)
            assert description  # non-empty

    def test_usage_type_from_string(self):
        """The enum can be constructed from its raw string value."""
        assert LLMUsageType("chat") == LLMUsageType.CHAT
        assert LLMUsageType("kb_processing") == LLMUsageType.KB_PROCESSING

    def test_invalid_usage_type(self):
        """An unknown raw value raises ValueError."""
        with pytest.raises(ValueError):
            LLMUsageType("invalid_type")
|
||||
|
||||
|
||||
class TestLLMConfigManagerMultiUsage:
    """Tests for LLMConfigManager multi-usage support.

    Each test patches ``get_settings`` and ``LLM_CONFIG_FILE`` so the manager
    reads defaults from the mock settings and persists to a per-test file.
    """

    @pytest.mark.asyncio
    async def test_get_all_configs(self, mock_settings, isolated_config_file):
        """Test getting all configs at once."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                # No argument -> a mapping keyed by usage-type string value.
                all_configs = manager.get_current_config()

                for ut in LLMUsageType:
                    assert ut.value in all_configs
                    assert "provider" in all_configs[ut.value]
                    assert "config" in all_configs[ut.value]

    @pytest.mark.asyncio
    async def test_update_specific_usage_config(self, mock_settings, isolated_config_file):
        """Test updating config for a specific usage type."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                await manager.update_usage_config(
                    usage_type=LLMUsageType.KB_PROCESSING,
                    provider="ollama",
                    config={"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
                )

                # Only the KB_PROCESSING slot should reflect the new provider.
                kb_config = manager.get_current_config(LLMUsageType.KB_PROCESSING)
                assert kb_config["provider"] == "ollama"
                assert kb_config["config"]["model"] == "llama3.2"

    @pytest.mark.asyncio
    async def test_update_all_configs(self, mock_settings, isolated_config_file):
        """Test updating all configs at once."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                # update_config (no usage type) applies to every usage type.
                await manager.update_config(
                    provider="deepseek",
                    config={"api_key": "test-key", "model": "deepseek-chat"},
                )

                for ut in LLMUsageType:
                    config = manager.get_current_config(ut)
                    assert config["provider"] == "deepseek"

    @pytest.mark.asyncio
    async def test_get_client_for_usage_type(self, mock_settings, isolated_config_file):
        """Test getting client for specific usage type."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                chat_client = manager.get_client(LLMUsageType.CHAT)
                assert chat_client is not None

                kb_client = manager.get_client(LLMUsageType.KB_PROCESSING)
                assert kb_client is not None

    @pytest.mark.asyncio
    async def test_get_chat_client(self, mock_settings, isolated_config_file):
        """Test get_chat_client convenience method."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                client = manager.get_chat_client()
                assert client is not None

    @pytest.mark.asyncio
    async def test_get_kb_processing_client(self, mock_settings, isolated_config_file):
        """Test get_kb_processing_client convenience method."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                client = manager.get_kb_processing_client()
                assert client is not None

    @pytest.mark.asyncio
    async def test_close_all_clients(self, mock_settings, isolated_config_file):
        """Test that close() closes all clients."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                # Instantiate a client for every usage type first.
                for ut in LLMUsageType:
                    manager.get_client(ut)

                await manager.close()

                # close() must drop every cached client reference.
                for ut in LLMUsageType:
                    assert manager._clients[ut] is None

    @pytest.mark.asyncio
    async def test_config_persistence_to_file(self, mock_settings, isolated_config_file):
        """Test that configs are persisted to file."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                await manager.update_usage_config(
                    usage_type=LLMUsageType.KB_PROCESSING,
                    provider="ollama",
                    config={"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
                )

                assert isolated_config_file.exists()

                # The on-disk file uses the multi-usage format keyed by
                # usage-type value, even though only one slot was updated.
                with open(isolated_config_file, "r", encoding="utf-8") as f:
                    saved = json.load(f)

                assert "chat" in saved
                assert "kb_processing" in saved
                assert saved["kb_processing"]["provider"] == "ollama"

    @pytest.mark.asyncio
    async def test_load_config_from_file(self, mock_settings, tmp_path):
        """Test loading configs from file."""
        config_file = tmp_path / "llm_config.json"
        # Pre-write a multi-usage config with different providers per slot.
        saved_config = {
            "chat": {
                "provider": "openai",
                "config": {"api_key": "test-key", "model": "gpt-4o"},
            },
            "kb_processing": {
                "provider": "ollama",
                "config": {"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
            },
        }

        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(saved_config, f)

        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", config_file):
                manager = LLMConfigManager()

                chat_config = manager.get_current_config(LLMUsageType.CHAT)
                assert chat_config["config"]["model"] == "gpt-4o"

                kb_config = manager.get_current_config(LLMUsageType.KB_PROCESSING)
                assert kb_config["provider"] == "ollama"
                assert kb_config["config"]["model"] == "llama3.2"

    @pytest.mark.asyncio
    async def test_backward_compatibility_old_config_format(self, mock_settings, tmp_path):
        """Test backward compatibility with old single-config format."""
        config_file = tmp_path / "llm_config.json"
        # Old format: one flat provider/config pair, no usage-type keys.
        old_config = {
            "provider": "deepseek",
            "config": {"api_key": "test-key", "model": "deepseek-chat"},
        }

        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(old_config, f)

        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", config_file):
                manager = LLMConfigManager()

                # The single legacy config is applied to every usage type.
                for ut in LLMUsageType:
                    config = manager.get_current_config(ut)
                    assert config["provider"] == "deepseek"
                    assert config["config"]["model"] == "deepseek-chat"
|
||||
|
||||
|
||||
class TestLLMConfigManagerTestConnection:
    """Tests for test_connection with usage type support."""

    @pytest.mark.asyncio
    async def test_test_connection_with_usage_type(self, mock_settings, isolated_config_file):
        """test_connection reports success, response and latency for a usage type."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                # Fake generation result returned by the stubbed client.
                fake_generation = MagicMock(
                    content="Test response",
                    model="gpt-4o-mini",
                    usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
                )
                fake_client = AsyncMock()
                fake_client.generate = AsyncMock(return_value=fake_generation)
                fake_client.close = AsyncMock()

                with patch.object(
                    LLMProviderFactory, "create_client", return_value=fake_client
                ):
                    outcome = await manager.test_connection(
                        test_prompt="Hello",
                        usage_type=LLMUsageType.CHAT,
                    )

                assert outcome["success"] is True
                assert "response" in outcome
                assert "latency_ms" in outcome
|
||||
|
||||
|
||||
class TestGetLLMConfigManager:
    """Tests for get_llm_config_manager singleton."""

    def test_singleton_instance(self, mock_settings, isolated_config_file):
        """Repeated calls return the same manager instance."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager1 = get_llm_config_manager()
                manager2 = get_llm_config_manager()

                assert manager1 is manager2

    def test_reset_singleton(self, mock_settings, isolated_config_file):
        """Clearing the module-level singleton forces a fresh instance.

        The original test only asserted the manager was not None and never
        exercised the reset path its name promises; now it actually clears
        the singleton and verifies a new instance is created.
        """
        import app.services.llm.factory as factory

        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                first = get_llm_config_manager()
                assert first is not None

                # Simulate a reset of the cached singleton.
                factory._llm_config_manager = None

                second = get_llm_config_manager()
                assert second is not None
                assert second is not first
|
||||
|
||||
|
||||
class TestLLMConfigManagerProperties:
    """Tests for the convenience config properties on LLMConfigManager."""

    def test_chat_config_property(self, mock_settings, isolated_config_file):
        """chat_config yields a dict that includes a model entry."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                cfg = LLMConfigManager().chat_config
                assert isinstance(cfg, dict)
                assert "model" in cfg

    def test_kb_processing_config_property(self, mock_settings, isolated_config_file):
        """kb_processing_config yields a dict that includes a model entry."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                cfg = LLMConfigManager().kb_processing_config
                assert isinstance(cfg, dict)
                assert "model" in cfg
|
||||
|
|
@ -0,0 +1,530 @@
|
|||
"""
|
||||
Unit tests for Markdown intelligent chunker.
|
||||
Tests for MarkdownParser, MarkdownChunker, and integration.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.document.markdown_chunker import (
|
||||
MarkdownChunk,
|
||||
MarkdownChunker,
|
||||
MarkdownElement,
|
||||
MarkdownElementType,
|
||||
MarkdownParser,
|
||||
chunk_markdown,
|
||||
)
|
||||
|
||||
|
||||
class TestMarkdownParser:
    """Tests for MarkdownParser.

    Each test feeds a small Markdown document to ``parse`` and filters the
    returned elements by ``MarkdownElementType``.
    """

    def test_parse_headers(self):
        """Test header extraction."""
        text = """# Main Title

## Section 1

### Subsection 1.1

#### Deep Header
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        headers = [e for e in elements if e.type == MarkdownElementType.HEADER]

        # Header level must mirror the number of leading '#' characters.
        assert len(headers) == 4
        assert headers[0].content == "Main Title"
        assert headers[0].level == 1
        assert headers[1].content == "Section 1"
        assert headers[1].level == 2
        assert headers[2].content == "Subsection 1.1"
        assert headers[2].level == 3
        assert headers[3].content == "Deep Header"
        assert headers[3].level == 4

    def test_parse_code_blocks(self):
        """Test code block extraction with language."""
        text = """Here is some code:

```python
def hello():
    print("Hello, World!")
```

And some more text.
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]

        assert len(code_blocks) == 1
        assert code_blocks[0].language == "python"
        assert 'def hello():' in code_blocks[0].content
        assert 'print("Hello, World!")' in code_blocks[0].content

    def test_parse_code_blocks_no_language(self):
        """Test code block without language specification."""
        text = """```
plain code here
multiple lines
```
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]

        # A bare fence yields an empty-string language, not None.
        assert len(code_blocks) == 1
        assert code_blocks[0].language == ""
        assert "plain code here" in code_blocks[0].content

    def test_parse_tables(self):
        """Test table extraction."""
        text = """| Name | Age | City |
|------|-----|------|
| Alice | 30 | NYC |
| Bob | 25 | LA |
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        tables = [e for e in elements if e.type == MarkdownElementType.TABLE]

        # Table metadata records the header cells and data-row count
        # (the separator row is not counted).
        assert len(tables) == 1
        assert "Name" in tables[0].content
        assert "Alice" in tables[0].content
        assert tables[0].metadata.get("headers") == ["Name", "Age", "City"]
        assert tables[0].metadata.get("row_count") == 2

    def test_parse_lists(self):
        """Test list extraction."""
        text = """- Item 1
- Item 2
- Item 3
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        lists = [e for e in elements if e.type == MarkdownElementType.LIST]

        # Consecutive bullet items collapse into a single LIST element.
        assert len(lists) == 1
        assert "Item 1" in lists[0].content
        assert "Item 2" in lists[0].content
        assert "Item 3" in lists[0].content

    def test_parse_ordered_lists(self):
        """Test ordered list extraction."""
        text = """1. First
2. Second
3. Third
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        lists = [e for e in elements if e.type == MarkdownElementType.LIST]

        assert len(lists) == 1
        assert "First" in lists[0].content
        assert "Second" in lists[0].content
        assert "Third" in lists[0].content

    def test_parse_blockquotes(self):
        """Test blockquote extraction."""
        text = """> This is a quote.
> It spans multiple lines.
> And continues here.
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        quotes = [e for e in elements if e.type == MarkdownElementType.BLOCKQUOTE]

        # Adjacent '>' lines merge into one blockquote element.
        assert len(quotes) == 1
        assert "This is a quote." in quotes[0].content
        assert "It spans multiple lines." in quotes[0].content

    def test_parse_paragraphs(self):
        """Test paragraph extraction."""
        text = """This is the first paragraph.

This is the second paragraph.
It has multiple lines.

This is the third.
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        paragraphs = [e for e in elements if e.type == MarkdownElementType.PARAGRAPH]

        # Blank lines delimit paragraphs; wrapped lines stay in one paragraph.
        assert len(paragraphs) == 3
        assert "first paragraph" in paragraphs[0].content
        assert "second paragraph" in paragraphs[1].content

    def test_parse_mixed_content(self):
        """Test parsing mixed Markdown content."""
        text = """# Documentation

## Introduction

This is an introduction paragraph.

## Code Example

```python
def example():
    return 42
```

## Data Table

| Column A | Column B |
|----------|----------|
| Value 1 | Value 2 |

## List

- Item A
- Item B

> Note: This is important.
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        headers = [e for e in elements if e.type == MarkdownElementType.HEADER]
        code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]
        tables = [e for e in elements if e.type == MarkdownElementType.TABLE]
        lists = [e for e in elements if e.type == MarkdownElementType.LIST]
        quotes = [e for e in elements if e.type == MarkdownElementType.BLOCKQUOTE]
        paragraphs = [e for e in elements if e.type == MarkdownElementType.PARAGRAPH]

        assert len(headers) == 5
        assert len(code_blocks) == 1
        assert len(tables) == 1
        assert len(lists) == 1
        assert len(quotes) == 1
        assert len(paragraphs) >= 1

    def test_code_blocks_not_parsed_as_other_elements(self):
        """Test that code blocks don't get parsed as headers or lists."""
        # Everything inside the fence is markdown-looking text, but it must
        # remain part of the code block — not become headers/lists/tables.
        text = """```markdown
# This is not a header
- This is not a list
| This is not a table |
```
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        headers = [e for e in elements if e.type == MarkdownElementType.HEADER]
        lists = [e for e in elements if e.type == MarkdownElementType.LIST]
        tables = [e for e in elements if e.type == MarkdownElementType.TABLE]
        code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]

        assert len(headers) == 0
        assert len(lists) == 0
        assert len(tables) == 0
        assert len(code_blocks) == 1
|
||||
|
||||
|
||||
class TestMarkdownChunker:
    """Tests for MarkdownChunker."""

    def test_chunk_simple_document(self):
        """Test chunking a simple document."""
        text = """# Title

This is a paragraph.

## Section

Another paragraph.
"""
        chunker = MarkdownChunker()
        chunks = chunker.chunk(text, "test_doc")

        # Two sections -> at least two chunks, all tagged with the doc id prefix.
        assert len(chunks) >= 2
        assert all(isinstance(chunk, MarkdownChunk) for chunk in chunks)
        assert all(chunk.chunk_id.startswith("test_doc") for chunk in chunks)

    def test_chunk_preserves_header_context(self):
        """Test that header context is preserved."""
        text = """# Main Title

## Section A

Content under section A.

### Subsection A1

Content under subsection A1.
"""
        chunker = MarkdownChunker(include_header_context=True)
        chunks = chunker.chunk(text, "test")

        # NOTE(review): this filter keeps chunks WITHOUT the exact text
        # "Subsection A1" (e.g. the body "Content under subsection A1."),
        # and the case-insensitive check below then selects them. Confirm
        # the `not in` inversion is intentional and not a typo for `in`.
        subsection_chunks = [c for c in chunks if "Subsection A1" not in c.content]
        for chunk in subsection_chunks:
            if "subsection a1" in chunk.content.lower():
                assert "Main Title" in chunk.header_context
                assert "Section A" in chunk.header_context

    def test_chunk_code_blocks_preserved(self):
        """Test that code blocks are preserved as single chunks when possible."""
        text = """```python
def function_one():
    pass

def function_two():
    pass
```
"""
        chunker = MarkdownChunker(max_chunk_size=2000, preserve_code_blocks=True)
        chunks = chunker.chunk(text, "test")

        code_chunks = [c for c in chunks if c.element_type == MarkdownElementType.CODE_BLOCK]

        # Fits within max_chunk_size, so the whole fence stays together.
        assert len(code_chunks) == 1
        assert "def function_one" in code_chunks[0].content
        assert "def function_two" in code_chunks[0].content
        assert code_chunks[0].language == "python"

    def test_chunk_large_code_block_split(self):
        """Test that large code blocks are split properly."""
        # 100 one-line functions far exceed max_chunk_size=500 below.
        lines = ["def function_{}(): pass".format(i) for i in range(100)]
        code_content = "\n".join(lines)
        text = f"""```python\n{code_content}\n```"""

        chunker = MarkdownChunker(max_chunk_size=500, preserve_code_blocks=True)
        chunks = chunker.chunk(text, "test")

        code_chunks = [c for c in chunks if c.element_type == MarkdownElementType.CODE_BLOCK]

        # Each split piece must be re-fenced and keep the language tag.
        assert len(code_chunks) > 1
        for chunk in code_chunks:
            assert chunk.language == "python"
            assert "```python" in chunk.content
            assert "```" in chunk.content

    def test_chunk_table_preserved(self):
        """Test that tables are preserved."""
        text = """| Name | Age |
|------|-----|
| Alice | 30 |
| Bob | 25 |
"""
        chunker = MarkdownChunker(max_chunk_size=2000, preserve_tables=True)
        chunks = chunker.chunk(text, "test")

        table_chunks = [c for c in chunks if c.element_type == MarkdownElementType.TABLE]

        assert len(table_chunks) == 1
        assert "Alice" in table_chunks[0].content
        assert "Bob" in table_chunks[0].content

    def test_chunk_large_table_split(self):
        """Test that large tables are split with header preserved."""
        rows = [f"| Name{i} | {i * 10} |" for i in range(50)]
        table_content = "| Name | Age |\n|------|-----|\n" + "\n".join(rows)
        text = table_content

        chunker = MarkdownChunker(max_chunk_size=200, preserve_tables=True)
        chunks = chunker.chunk(text, "test")

        table_chunks = [c for c in chunks if c.element_type == MarkdownElementType.TABLE]

        # Every split piece must repeat the header and separator rows.
        assert len(table_chunks) > 1
        for chunk in table_chunks:
            assert "| Name | Age |" in chunk.content
            assert "|------|-----|" in chunk.content

    def test_chunk_list_preserved(self):
        """Test that lists are chunked properly."""
        text = """- Item 1
- Item 2
- Item 3
- Item 4
- Item 5
"""
        chunker = MarkdownChunker(max_chunk_size=2000, preserve_lists=True)
        chunks = chunker.chunk(text, "test")

        list_chunks = [c for c in chunks if c.element_type == MarkdownElementType.LIST]

        assert len(list_chunks) == 1
        assert "Item 1" in list_chunks[0].content
        assert "Item 5" in list_chunks[0].content

    def test_chunk_empty_document(self):
        """Test chunking an empty document."""
        text = ""
        chunker = MarkdownChunker()
        chunks = chunker.chunk(text, "test")

        assert len(chunks) == 0

    def test_chunk_only_headers(self):
        """Test chunking a document with only headers."""
        text = """# Title 1
## Title 2
### Title 3
"""
        chunker = MarkdownChunker()
        chunks = chunker.chunk(text, "test")

        # Headers with no body content produce no chunks at all.
        assert len(chunks) == 0
|
||||
|
||||
|
||||
class TestChunkMarkdownFunction:
    """Tests for the convenience chunk_markdown function."""

    def test_basic_chunking(self):
        """The convenience wrapper returns dict chunks with the full schema."""
        text = """# Title

Content paragraph.

```python
code = "here"
```
"""
        chunks = chunk_markdown(text, "doc1")

        assert len(chunks) >= 1
        # Every chunk dict must expose all public keys.
        for required_key in ("chunk_id", "content", "element_type", "header_context"):
            assert all(required_key in chunk for chunk in chunks)

    def test_custom_parameters(self):
        """All tuning knobs are accepted and output is still produced."""
        text = "A" * 2000

        chunks = chunk_markdown(
            text,
            "doc1",
            max_chunk_size=500,
            min_chunk_size=50,
            preserve_code_blocks=False,
            preserve_tables=False,
            preserve_lists=False,
            include_header_context=False,
        )

        assert len(chunks) >= 1
|
||||
|
||||
|
||||
class TestMarkdownElement:
    """Tests for MarkdownElement dataclass."""

    def test_to_dict(self):
        """A header element serializes all of its fields."""
        element = MarkdownElement(
            type=MarkdownElementType.HEADER,
            content="Test Header",
            level=2,
            line_start=10,
            line_end=10,
            metadata={"level": 2},
        )

        data = element.to_dict()

        assert (data["type"], data["content"]) == ("header", "Test Header")
        assert data["level"] == 2
        assert data["line_start"] == data["line_end"] == 10

    def test_code_block_with_language(self):
        """A code-block element keeps its language through serialization."""
        element = MarkdownElement(
            type=MarkdownElementType.CODE_BLOCK,
            content="print('hello')",
            language="python",
            line_start=5,
            line_end=7,
        )

        data = element.to_dict()

        assert (data["type"], data["language"]) == ("code_block", "python")

    def test_table_with_metadata(self):
        """A table element carries its metadata dict through serialization."""
        element = MarkdownElement(
            type=MarkdownElementType.TABLE,
            content="| A | B |\n|---|---|\n| 1 | 2 |",
            line_start=1,
            line_end=3,
            metadata={"headers": ["A", "B"], "row_count": 1},
        )

        data = element.to_dict()

        assert data["type"] == "table"
        assert data["metadata"]["headers"] == ["A", "B"]
        assert data["metadata"]["row_count"] == 1
|
||||
|
||||
|
||||
class TestMarkdownChunk:
    """Tests for MarkdownChunk dataclass."""

    def test_to_dict(self):
        """All chunk fields round-trip through to_dict()."""
        chunk = MarkdownChunk(
            chunk_id="doc_chunk_0",
            content="Test content",
            element_type=MarkdownElementType.PARAGRAPH,
            header_context=["Main Title", "Section"],
            metadata={"key": "value"},
        )

        data = chunk.to_dict()

        expected = {
            "chunk_id": "doc_chunk_0",
            "content": "Test content",
            "element_type": "paragraph",
            "header_context": ["Main Title", "Section"],
        }
        for key, value in expected.items():
            assert data[key] == value
        assert data["metadata"]["key"] == "value"

    def test_with_language(self):
        """Language info survives serialization."""
        chunk = MarkdownChunk(
            chunk_id="code_0",
            content="```python\nprint('hi')\n```",
            element_type=MarkdownElementType.CODE_BLOCK,
            header_context=[],
            language="python",
        )

        assert chunk.to_dict()["language"] == "python"
|
||||
|
||||
|
||||
class TestMarkdownElementType:
    """Tests for MarkdownElementType enum."""

    def test_all_types_exist(self):
        """Every expected element type is defined on the enum."""
        expected_types = [
            "header",
            "paragraph",
            "code_block",
            "inline_code",
            "table",
            "list",
            "blockquote",
            "horizontal_rule",
            "image",
            "link",
            "text",
        ]

        # Build the value set once instead of scanning the enum per name.
        member_values = {member.value for member in MarkdownElementType}
        for type_name in expected_types:
            # Accept either an UPPER_CASE attribute or a matching member value.
            assert hasattr(MarkdownElementType, type_name.upper()) or type_name in member_values
|
||||
|
|
@ -0,0 +1,443 @@
|
|||
"""
|
||||
Unit tests for MetadataAutoInferenceService.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.services.metadata_auto_inference_service import (
|
||||
AutoInferenceResult,
|
||||
InferenceFieldContext,
|
||||
MetadataAutoInferenceService,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class MockFieldDefinition:
    """Mock field definition for testing"""
    field_key: str  # machine-readable key, e.g. "grade"
    label: str  # human-readable label shown in prompts
    type: str  # field type name, e.g. "enum" / "text" / "number"
    required: bool  # whether the field must be inferred
    options: list[str] | None = None  # allowed values for enum-like fields
|
||||
|
||||
|
||||
class TestInferenceFieldContext:
    """Test InferenceFieldContext dataclass."""

    def test_creation(self):
        """Constructor arguments land on the matching attributes."""
        context = InferenceFieldContext(
            field_key="grade",
            label="年级",
            type="enum",
            required=True,
            options=["初一", "初二", "初三"],
        )

        assert (context.field_key, context.label, context.type) == ("grade", "年级", "enum")
        assert context.required is True
        assert context.options == ["初一", "初二", "初三"]

    def test_default_values(self):
        """Optional attributes default to None."""
        context = InferenceFieldContext(
            field_key="test",
            label="Test",
            type="text",
            required=False,
        )

        assert context.options is None
        assert context.description is None
|
||||
|
||||
|
||||
class TestAutoInferenceResult:
    """Test AutoInferenceResult dataclass."""

    def test_success_result(self):
        """A successful result carries metadata and no error message."""
        outcome = AutoInferenceResult(
            inferred_metadata={"grade": "初一", "subject": "数学"},
            confidence_scores={"grade": 0.95, "subject": 0.85},
            raw_response='{"inferred_metadata": {...}}',
            success=True,
        )

        assert outcome.success is True
        assert outcome.error_message is None
        assert outcome.inferred_metadata["grade"] == "初一"

    def test_failure_result(self):
        """A failed result preserves the error message."""
        outcome = AutoInferenceResult(
            inferred_metadata={},
            confidence_scores={},
            raw_response="",
            success=False,
            error_message="JSON parse error",
        )

        assert outcome.success is False
        assert outcome.error_message == "JSON parse error"
|
||||
|
||||
|
||||
class TestMetadataAutoInferenceService:
    """Test MetadataAutoInferenceService functionality."""

    @pytest.fixture
    def mock_session(self):
        """Create mock session."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_session):
        """Create service instance."""
        return MetadataAutoInferenceService(mock_session)

    def test_build_field_contexts(self, service):
        """Test building field contexts from definitions."""
        fields = [
            MockFieldDefinition(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            MockFieldDefinition(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]

        contexts = service._build_field_contexts(fields)

        # One context per definition, preserving order and options.
        assert len(contexts) == 2
        assert contexts[0].field_key == "grade"
        assert contexts[0].options == ["初一", "初二", "初三"]
        assert contexts[1].field_key == "subject"

    def test_build_user_prompt(self, service):
        """Test building user prompt."""
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]

        prompt = service._build_user_prompt(
            content="这是一道初一数学题",
            field_contexts=field_contexts,
        )

        # Prompt must mention the field label, its options, and the content.
        assert "年级" in prompt
        assert "初一, 初二, 初三" in prompt
        assert "这是一道初一数学题" in prompt

    def test_build_user_prompt_with_existing_metadata(self, service):
        """Test building user prompt with existing metadata."""
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]

        prompt = service._build_user_prompt(
            content="这是一道数学题",
            field_contexts=field_contexts,
            existing_metadata={"grade": "初二"},
        )

        # Existing values are surfaced to the LLM as "已有值: <value>".
        assert "已有值: 初二" in prompt

    def test_extract_json_from_plain_json(self, service):
        """Test extracting JSON from plain JSON response."""
        json_str = '{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}'

        result = service._extract_json(json_str)

        # Plain JSON passes through unchanged.
        assert result == json_str

    def test_extract_json_from_markdown(self, service):
        """Test extracting JSON from markdown code block."""
        markdown = """Here is the result:

```json
{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}
```
"""

        result = service._extract_json(markdown)

        assert "inferred_metadata" in result
        assert "grade" in result

    def test_parse_llm_response_valid(self, service):
        """Test parsing valid LLM response."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "初一",
                "subject": "数学",
            },
            "confidence_scores": {
                "grade": 0.95,
                "subject": 0.85,
            }
        })

        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        assert result.success is True
        assert result.inferred_metadata["grade"] == "初一"
        assert result.inferred_metadata["subject"] == "数学"
        assert result.confidence_scores["grade"] == 0.95

    def test_parse_llm_response_invalid_option(self, service):
        """Test parsing response with invalid enum option."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "高一",  # Not in options
            },
            "confidence_scores": {
                "grade": 0.90,
            }
        })

        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        # Invalid options are dropped silently rather than failing the parse.
        assert result.success is True
        assert "grade" not in result.inferred_metadata

    def test_parse_llm_response_invalid_json(self, service):
        """Test parsing invalid JSON response."""
        response = "This is not valid JSON"

        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="text",
                required=False,
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        assert result.success is False
        assert "JSON parse error" in result.error_message

    def test_validate_field_value_text(self, service):
        """Test validating text field value."""
        ctx = InferenceFieldContext(
            field_key="title",
            label="标题",
            type="text",
            required=False,
        )

        result = service._validate_field_value(ctx, "测试标题")

        assert result == "测试标题"

    def test_validate_field_value_number(self, service):
        """Test validating number field value."""
        ctx = InferenceFieldContext(
            field_key="count",
            label="数量",
            type="number",
            required=False,
        )

        # Numbers pass through; numeric strings are coerced; junk becomes None.
        assert service._validate_field_value(ctx, 42) == 42
        assert service._validate_field_value(ctx, "3.14") == 3.14
        assert service._validate_field_value(ctx, "invalid") is None

    def test_validate_field_value_boolean(self, service):
        """Test validating boolean field value."""
        ctx = InferenceFieldContext(
            field_key="active",
            label="是否激活",
            type="boolean",
            required=False,
        )

        assert service._validate_field_value(ctx, True) is True
        assert service._validate_field_value(ctx, "true") is True
        assert service._validate_field_value(ctx, "false") is False
        assert service._validate_field_value(ctx, 1) is True

    def test_validate_field_value_enum(self, service):
        """Test validating enum field value."""
        ctx = InferenceFieldContext(
            field_key="grade",
            label="年级",
            type="enum",
            required=False,
            options=["初一", "初二", "初三"],
        )

        assert service._validate_field_value(ctx, "初一") == "初一"
        assert service._validate_field_value(ctx, "高一") is None

    def test_validate_field_value_array_enum(self, service):
        """Test validating array_enum field value."""
        ctx = InferenceFieldContext(
            field_key="tags",
            label="标签",
            type="array_enum",
            required=False,
            options=["重点", "难点", "易错"],
        )

        result = service._validate_field_value(ctx, ["重点", "难点"])
        assert result == ["重点", "难点"]

        # Unknown members are filtered out of the list.
        result = service._validate_field_value(ctx, ["重点", "不存在"])
        assert result == ["重点"]

        # A scalar is wrapped into a single-element list.
        result = service._validate_field_value(ctx, "重点")
        assert result == ["重点"]
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
    """Test integration scenarios."""

    @pytest.fixture
    def mock_session(self):
        """Create mock session."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_session):
        """Create service instance."""
        return MetadataAutoInferenceService(mock_session)

    def test_education_scenario(self, service):
        """Test education scenario with grade and subject."""
        # Simulated well-formed LLM output covering all three fields.
        response = json.dumps({
            "inferred_metadata": {
                "grade": "初二",
                "subject": "物理",
                "type": "痛点",
            },
            "confidence_scores": {
                "grade": 0.95,
                "subject": 0.90,
                "type": 0.85,
            }
        })

        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语", "物理", "化学"],
            ),
            InferenceFieldContext(
                field_key="type",
                label="类型",
                type="enum",
                required=False,
                options=["痛点", "重点", "难点"],
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        assert result.success is True
        assert result.inferred_metadata == {
            "grade": "初二",
            "subject": "物理",
            "type": "痛点",
        }
        assert result.confidence_scores["grade"] == 0.95

    def test_partial_inference(self, service):
        """Test partial inference when some fields cannot be inferred."""
        # LLM only returns "grade"; "subject" is absent from the response.
        response = json.dumps({
            "inferred_metadata": {
                "grade": "初一",
            },
            "confidence_scores": {
                "grade": 0.90,
            }
        })

        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        # Partial results still count as success; missing fields are omitted.
        assert result.success is True
        assert "grade" in result.inferred_metadata
        assert "subject" not in result.inferred_metadata
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly (outside the pytest CLI).
    pytest.main([__file__, "-v"])
|
||||
|
|
@ -0,0 +1,881 @@
|
|||
"""
|
||||
Mid Platform Dialogue Integration Test.
|
||||
中台联调界面对话过程集成测试脚本
|
||||
|
||||
测试重点:
|
||||
1. 意图置信度参数
|
||||
2. 执行模式(通用API vs ReAct模式)
|
||||
3. ReAct模式下的工具调用(工具名称、入参、返回结果)
|
||||
4. 知识库查询(是否命中、入参)
|
||||
5. 各部分耗时
|
||||
6. 提示词模板使用情况
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.main import app
|
||||
from app.models.mid.schemas import (
|
||||
DialogueRequest,
|
||||
DialogueResponse,
|
||||
ExecutionMode,
|
||||
FeatureFlags,
|
||||
HistoryMessage,
|
||||
Segment,
|
||||
TraceInfo,
|
||||
ToolCallTrace,
|
||||
ToolCallStatus,
|
||||
ToolType,
|
||||
IntentHintOutput,
|
||||
HighRiskCheckResult,
|
||||
HighRiskScenario,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class TimingRecord:
    """Timing record for a single pipeline stage."""

    stage: str  # stage name, e.g. "http_request"
    start_time: float  # epoch seconds when the stage began
    end_time: float  # epoch seconds when the stage finished
    duration_ms: int  # pre-computed duration in milliseconds

    def to_dict(self) -> dict:
        """Return only the summary fields (stage name and duration)."""
        return dict(stage=self.stage, duration_ms=self.duration_ms)
|
||||
|
||||
|
||||
@dataclass
class DialogueTestResult:
    """Result of one dialogue round-trip, including trace details."""
    request_id: str  # client-generated UUID for this request
    user_message: str
    response_text: str = ""

    # Timing
    timing_records: list[TimingRecord] = field(default_factory=list)
    total_duration_ms: int = 0

    # Intent / execution mode
    execution_mode: ExecutionMode = ExecutionMode.AGENT
    intent: str | None = None
    confidence: float | None = None

    # ReAct loop details
    react_iterations: int = 0
    tools_used: list[str] = field(default_factory=list)
    tool_calls: list[dict] = field(default_factory=list)

    # Knowledge-base query details
    kb_tool_called: bool = False
    kb_hit: bool = False
    kb_query: str | None = None
    kb_filter: dict | None = None
    kb_hits_count: int = 0

    # Prompt template usage
    prompt_template_used: str | None = None
    prompt_template_scene: str | None = None

    # Guardrail / fallback
    guardrail_triggered: bool = False
    fallback_reason_code: str | None = None

    # Raw trace dump for debugging
    raw_trace: dict | None = None

    def to_summary(self) -> dict:
        """Condense the result into a JSON-friendly summary dict."""
        return {
            "request_id": self.request_id,
            "user_message": self.user_message[:100],
            "execution_mode": self.execution_mode.value,
            "intent": self.intent,
            "confidence": self.confidence,
            "total_duration_ms": self.total_duration_ms,
            "timing_breakdown": [t.to_dict() for t in self.timing_records],
            "react_iterations": self.react_iterations,
            "tools_used": self.tools_used,
            "tool_calls_count": len(self.tool_calls),
            "kb_tool_called": self.kb_tool_called,
            "kb_hit": self.kb_hit,
            "kb_hits_count": self.kb_hits_count,
            "prompt_template_used": self.prompt_template_used,
            "guardrail_triggered": self.guardrail_triggered,
            "fallback_reason_code": self.fallback_reason_code,
        }
|
||||
|
||||
|
||||
class DialogueIntegrationTester:
    """Dialogue integration tester: drives the mid-platform dialogue API."""

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        tenant_id: str = "test_tenant",
        api_key: str | None = None,
    ):
        self.base_url = base_url
        self.tenant_id = tenant_id
        self.api_key = api_key
        # Fresh session/user ids per tester instance so runs stay isolated.
        self.session_id = f"test_session_{uuid.uuid4().hex[:8]}"
        self.user_id = f"test_user_{uuid.uuid4().hex[:8]}"

    def _get_headers(self) -> dict:
        """Build request headers; the API key header is optional."""
        headers = {
            "Content-Type": "application/json",
            "X-Tenant-Id": self.tenant_id,
        }
        if self.api_key:
            headers["X-API-Key"] = self.api_key
        return headers

    async def send_dialogue(
        self,
        user_message: str,
        history: list[dict] | None = None,
        scene: str | None = None,
        feature_flags: dict | None = None,
    ) -> DialogueTestResult:
        """Send a dialogue request and record detailed trace information."""
        request_id = str(uuid.uuid4())
        result = DialogueTestResult(
            request_id=request_id,
            user_message=user_message,
            response_text="",
        )

        overall_start = time.time()

        request_body = {
            "session_id": self.session_id,
            "user_id": self.user_id,
            "user_message": user_message,
            "history": history or [],
        }

        # Optional fields are omitted entirely when not provided.
        if scene:
            request_body["scene"] = scene
        if feature_flags:
            request_body["feature_flags"] = feature_flags

        timing_records = []

        try:
            async with AsyncClient(base_url=self.base_url, timeout=120.0) as client:
                request_start = time.time()

                response = await client.post(
                    "/mid/dialogue/respond",
                    json=request_body,
                    headers=self._get_headers(),
                )

                request_end = time.time()
                timing_records.append(TimingRecord(
                    stage="http_request",
                    start_time=request_start,
                    end_time=request_end,
                    duration_ms=int((request_end - request_start) * 1000),
                ))

                # Non-200 responses become a synthetic fallback result.
                if response.status_code != 200:
                    result.response_text = f"Error: {response.status_code}"
                    result.fallback_reason_code = f"http_error_{response.status_code}"
                    return result

                response_data = response.json()

        except Exception as e:
            # Network / timeout failures are captured, not raised.
            result.response_text = f"Exception: {str(e)}"
            result.fallback_reason_code = "request_exception"
            result.total_duration_ms = int((time.time() - overall_start) * 1000)
            return result

        overall_end = time.time()
        result.total_duration_ms = int((overall_end - overall_start) * 1000)
        result.timing_records = timing_records

        try:
            dialogue_response = DialogueResponse(**response_data)

            if dialogue_response.segments:
                result.response_text = "\n".join(s.text for s in dialogue_response.segments)

            trace = dialogue_response.trace
            result.raw_trace = trace.model_dump() if trace else None

            if trace:
                result.execution_mode = trace.mode
                result.intent = trace.intent
                # NOTE(review): result.confidence is never populated from the
                # trace here — confirm whether TraceInfo carries a confidence
                # field that should be copied.
                result.react_iterations = trace.react_iterations or 0
                result.tools_used = trace.tools_used or []
                result.kb_tool_called = trace.kb_tool_called or False
                result.kb_hit = trace.kb_hit or False
                result.guardrail_triggered = trace.guardrail_triggered or False
                result.fallback_reason_code = trace.fallback_reason_code

                if trace.tool_calls:
                    result.tool_calls = [tc.model_dump() for tc in trace.tool_calls]

                    # Pull KB-specific details out of the kb_search_dynamic call.
                    for tc in trace.tool_calls:
                        if tc.tool_name == "kb_search_dynamic":
                            result.kb_tool_called = True
                            if tc.arguments:
                                result.kb_query = tc.arguments.get("query")
                                result.kb_filter = tc.arguments.get("context")
                            if tc.result and isinstance(tc.result, dict):
                                result.kb_hits_count = len(tc.result.get("hits", []))
                                result.kb_hit = result.kb_hits_count > 0

                if trace.duration_ms:
                    timing_records.append(TimingRecord(
                        stage="server_processing",
                        start_time=overall_start,
                        end_time=overall_end,
                        duration_ms=trace.duration_ms,
                    ))

                if trace.scene:
                    result.prompt_template_scene = trace.scene

        except Exception as e:
            # Fall back to the raw payload when the schema does not parse.
            logger.error(f"Failed to parse response: {e}")
            result.response_text = str(response_data)

        return result

    def print_result(self, result: DialogueTestResult):
        """Print a human-readable report of one test result."""
        print("\n" + "=" * 80)
        print(f"[对话测试结果] Request ID: {result.request_id}")
        print("=" * 80)

        print(f"\n[用户消息] {result.user_message}")
        print(f"[回复内容] {result.response_text[:200]}...")

        print(f"\n[执行模式] {result.execution_mode.value}")
        print(f"[意图识别] intent={result.intent}, confidence={result.confidence}")

        print(f"\n[耗时统计] 总耗时: {result.total_duration_ms}ms")
        for tr in result.timing_records:
            print(f"  - {tr.stage}: {tr.duration_ms}ms")

        if result.execution_mode == ExecutionMode.AGENT:
            print(f"\n[ReAct模式]")
            print(f"  - 迭代次数: {result.react_iterations}")
            print(f"  - 使用的工具: {result.tools_used}")

        if result.tool_calls:
            print(f"\n[工具调用详情]")
            for i, tc in enumerate(result.tool_calls, 1):
                print(f"  [{i}] 工具: {tc.get('tool_name')}")
                print(f"      状态: {tc.get('status')}")
                print(f"      耗时: {tc.get('duration_ms')}ms")
                if tc.get('arguments'):
                    print(f"      入参: {json.dumps(tc.get('arguments'), ensure_ascii=False)[:200]}")
                if tc.get('result'):
                    result_str = str(tc.get('result'))[:300]
                    print(f"      结果: {result_str}")

        print(f"\n[知识库查询]")
        print(f"  - 是否调用: {result.kb_tool_called}")
        print(f"  - 是否命中: {result.kb_hit}")
        if result.kb_query:
            print(f"  - 查询内容: {result.kb_query}")
        if result.kb_filter:
            print(f"  - 过滤条件: {json.dumps(result.kb_filter, ensure_ascii=False)[:200]}")
        print(f"  - 命中数量: {result.kb_hits_count}")

        print(f"\n[提示词模板]")
        print(f"  - 场景: {result.prompt_template_scene}")
        print(f"  - 使用模板: {result.prompt_template_used or '默认模板'}")

        print(f"\n[其他信息]")
        print(f"  - 护栏触发: {result.guardrail_triggered}")
        print(f"  - 降级原因: {result.fallback_reason_code or '无'}")

        print("\n" + "=" * 80)
|
||||
|
||||
|
||||
class TestMidDialogueIntegration:
|
||||
"""中台对话集成测试"""
|
||||
|
||||
    @pytest.fixture
    def tester(self):
        """Build a tester pointed at the local development server."""
        return DialogueIntegrationTester(
            base_url="http://localhost:8000",
            tenant_id="test_tenant",
        )
|
||||
|
||||
    @pytest.fixture
    def mock_llm_client(self):
        """Mock LLM client: returns a fixed reply with no tool calls."""
        mock = AsyncMock()
        mock.generate = AsyncMock(return_value=MagicMock(
            content="这是测试回复",
            has_tool_calls=False,
            tool_calls=[],
        ))
        return mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_kb_tool(self):
|
||||
"""模拟知识库工具"""
|
||||
mock = AsyncMock()
|
||||
mock.execute = AsyncMock(return_value=MagicMock(
|
||||
success=True,
|
||||
hits=[
|
||||
{"id": "1", "content": "测试知识库内容", "score": 0.9},
|
||||
],
|
||||
applied_filter={"scene": "test"},
|
||||
missing_required_slots=[],
|
||||
fallback_reason_code=None,
|
||||
duration_ms=100,
|
||||
tool_trace=None,
|
||||
))
|
||||
return mock
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_greeting(self, tester: DialogueIntegrationTester):
|
||||
"""测试简单问候"""
|
||||
result = await tester.send_dialogue(
|
||||
user_message="你好",
|
||||
)
|
||||
|
||||
tester.print_result(result)
|
||||
|
||||
assert result.request_id is not None
|
||||
assert result.total_duration_ms > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kb_query(self, tester: DialogueIntegrationTester):
|
||||
"""测试知识库查询"""
|
||||
result = await tester.send_dialogue(
|
||||
user_message="退款流程是什么?",
|
||||
scene="after_sale",
|
||||
)
|
||||
|
||||
tester.print_result(result)
|
||||
|
||||
assert result.request_id is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_risk_scenario(self, tester: DialogueIntegrationTester):
|
||||
"""测试高风险场景"""
|
||||
result = await tester.send_dialogue(
|
||||
user_message="我要投诉你们的服务",
|
||||
)
|
||||
|
||||
tester.print_result(result)
|
||||
|
||||
assert result.request_id is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transfer_request(self, tester: DialogueIntegrationTester):
|
||||
"""测试转人工请求"""
|
||||
result = await tester.send_dialogue(
|
||||
user_message="帮我转人工客服",
|
||||
)
|
||||
|
||||
tester.print_result(result)
|
||||
|
||||
assert result.request_id is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_history(self, tester: DialogueIntegrationTester):
|
||||
"""测试带历史记录的对话"""
|
||||
result = await tester.send_dialogue(
|
||||
user_message="那退款要多久呢?",
|
||||
history=[
|
||||
{"role": "user", "content": "我想退款"},
|
||||
{"role": "assistant", "content": "好的,请问您要退款的订单号是多少?"},
|
||||
],
|
||||
)
|
||||
|
||||
tester.print_result(result)
|
||||
|
||||
assert result.request_id is not None
|
||||
|
||||
|
||||
class TestDialogueWithMock:
    """Dialogue tests using mocks / locally constructed objects only.

    No running service required: these exercise request/response data
    structures and a FastAPI test app wired with the dialogue router.
    """

    @pytest.fixture
    def mock_app(self):
        """Build a minimal FastAPI app containing only the dialogue router."""
        from fastapi import FastAPI
        from app.api.mid.dialogue import router

        app = FastAPI()
        app.include_router(router)
        return app

    @pytest.fixture
    def client(self, mock_app):
        # Synchronous test client over the mock app.
        return TestClient(mock_app)

    @pytest.fixture
    def mock_session(self):
        """Mocked async DB session whose queries return an empty result set."""
        mock = AsyncMock()
        mock_result = MagicMock()
        # session.execute(...).scalars().all() -> []
        mock_result.scalars.return_value.all.return_value = []
        mock.execute.return_value = mock_result
        return mock

    @pytest.fixture
    def mock_llm(self):
        """Mocked LLM response object with fixed content and no tool calls."""
        mock_response = MagicMock()
        mock_response.content = "这是测试回复内容"
        mock_response.has_tool_calls = False
        mock_response.tool_calls = []
        return mock_response

    def test_dialogue_request_structure(self, client: TestClient):
        """The dialogue request body carries the expected top-level fields."""
        request_body = {
            "session_id": "test_session_001",
            "user_id": "test_user_001",
            "user_message": "你好",
            "history": [],
            "scene": "open_consult",
        }

        print("\n[测试请求结构]")
        print(f"Request Body: {json.dumps(request_body, ensure_ascii=False, indent=2)}")

        assert request_body["session_id"] == "test_session_001"
        assert request_body["user_message"] == "你好"

    def test_trace_info_structure(self):
        """TraceInfo can be built with full agent-mode trace data."""
        trace = TraceInfo(
            mode=ExecutionMode.AGENT,
            intent="greeting",
            request_id=str(uuid.uuid4()),
            generation_id=str(uuid.uuid4()),
            kb_tool_called=True,
            kb_hit=True,
            react_iterations=2,
            tools_used=["kb_search_dynamic"],
            tool_calls=[
                ToolCallTrace(
                    tool_name="kb_search_dynamic",
                    tool_type=ToolType.INTERNAL,
                    duration_ms=150,
                    status=ToolCallStatus.OK,
                    arguments={"query": "测试查询", "scene": "test"},
                    result={"hits": [{"content": "测试内容"}]},
                ),
            ],
        )

        print("\n[TraceInfo 结构测试]")
        print(f"Mode: {trace.mode.value}")
        print(f"Intent: {trace.intent}")
        print(f"KB Tool Called: {trace.kb_tool_called}")
        print(f"KB Hit: {trace.kb_hit}")
        print(f"React Iterations: {trace.react_iterations}")
        print(f"Tools Used: {trace.tools_used}")

        if trace.tool_calls:
            print(f"\n[Tool Calls]")
            for tc in trace.tool_calls:
                print(f" - Tool: {tc.tool_name}")
                print(f" Status: {tc.status.value}")
                print(f" Duration: {tc.duration_ms}ms")
                if tc.arguments:
                    print(f" Arguments: {json.dumps(tc.arguments, ensure_ascii=False)}")

        assert trace.mode == ExecutionMode.AGENT
        assert trace.kb_tool_called is True
        assert len(trace.tool_calls) == 1

    def test_intent_hint_output_structure(self):
        """IntentHintOutput carries intent, confidence and routing hints."""
        hint = IntentHintOutput(
            intent="refund",
            confidence=0.85,
            response_type="flow",
            suggested_mode=ExecutionMode.MICRO_FLOW,
            target_flow_id="flow_refund_001",
            high_risk_detected=False,
            duration_ms=50,
        )

        print("\n[IntentHintOutput 结构测试]")
        print(f"Intent: {hint.intent}")
        print(f"Confidence: {hint.confidence}")
        print(f"Response Type: {hint.response_type}")
        print(f"Suggested Mode: {hint.suggested_mode.value if hint.suggested_mode else None}")
        print(f"High Risk Detected: {hint.high_risk_detected}")
        print(f"Duration: {hint.duration_ms}ms")

        assert hint.intent == "refund"
        assert hint.confidence == 0.85
        assert hint.suggested_mode == ExecutionMode.MICRO_FLOW

    def test_high_risk_check_result_structure(self):
        """HighRiskCheckResult carries match info, scenario and recommendation."""
        result = HighRiskCheckResult(
            matched=True,
            risk_scenario=HighRiskScenario.REFUND,
            confidence=0.95,
            recommended_mode=ExecutionMode.MICRO_FLOW,
            rule_id="rule_refund_001",
            reason="检测到退款关键词",
            duration_ms=30,
        )

        print("\n[HighRiskCheckResult 结构测试]")
        print(f"Matched: {result.matched}")
        print(f"Risk Scenario: {result.risk_scenario.value if result.risk_scenario else None}")
        print(f"Confidence: {result.confidence}")
        print(f"Recommended Mode: {result.recommended_mode.value if result.recommended_mode else None}")
        print(f"Rule ID: {result.rule_id}")
        print(f"Duration: {result.duration_ms}ms")

        assert result.matched is True
        assert result.risk_scenario == HighRiskScenario.REFUND
|
||||
|
||||
|
||||
class TestPromptTemplateUsage:
    """Tests for prompt template variable resolution."""

    def test_template_resolution(self):
        """{{key}} placeholders are substituted from the variable list."""
        from app.services.prompt.variable_resolver import VariableResolver

        resolver = VariableResolver()

        template = "你好,{{user_name}}!我是{{bot_name}},很高兴为您服务。"
        variables = [
            {"key": "user_name", "value": "张三"},
            {"key": "bot_name", "value": "智能客服"},
        ]

        resolved = resolver.resolve(template, variables)

        print("\n[模板解析测试]")
        print(f"原始模板: {template}")
        print(f"变量: {json.dumps(variables, ensure_ascii=False)}")
        print(f"解析结果: {resolved}")

        assert resolved == "你好,张三!我是智能客服,很高兴为您服务。"

    def test_template_with_extra_context(self):
        """Placeholders can also be filled from an extra context mapping."""
        from app.services.prompt.variable_resolver import VariableResolver

        resolver = VariableResolver()

        template = "当前场景:{{scene}},用户问题:{{query}}"
        extra_context = {
            "scene": "售后服务",
            "query": "退款流程",
        }

        # Empty variable list: values come solely from extra_context.
        resolved = resolver.resolve(template, [], extra_context)

        print("\n[带上下文的模板解析测试]")
        print(f"原始模板: {template}")
        print(f"额外上下文: {json.dumps(extra_context, ensure_ascii=False)}")
        print(f"解析结果: {resolved}")

        assert "售后服务" in resolved
        assert "退款流程" in resolved
|
||||
|
||||
|
||||
class TestToolCallRecording:
    """Tests for tool-call trace recording."""

    def test_tool_call_trace_creation(self):
        """A successful trace keeps the full argument and result payloads."""
        trace = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=150,
            status=ToolCallStatus.OK,
            args_digest="query=退款流程",
            result_digest="hits=3",
            arguments={
                "query": "退款流程是什么",
                "scene": "after_sale",
                "context": {"product_type": "vip"},
            },
            result={
                "success": True,
                "hits": [
                    {"id": "1", "content": "退款流程说明...", "score": 0.95},
                    {"id": "2", "content": "退款注意事项...", "score": 0.88},
                    {"id": "3", "content": "退款时效说明...", "score": 0.82},
                ],
                "applied_filter": {"product_type": "vip"},
            },
        )

        print("\n[工具调用追踪测试]")
        print(f"Tool Name: {trace.tool_name}")
        print(f"Tool Type: {trace.tool_type.value}")
        print(f"Status: {trace.status.value}")
        print(f"Duration: {trace.duration_ms}ms")
        print(f"\n[入参详情]")
        if trace.arguments:
            for key, value in trace.arguments.items():
                print(f" - {key}: {value}")
        print(f"\n[返回结果]")
        if trace.result:
            if isinstance(trace.result, dict):
                print(f" - success: {trace.result.get('success')}")
                print(f" - hits count: {len(trace.result.get('hits', []))}")
                # Only the top-2 hits are printed to keep output short.
                for i, hit in enumerate(trace.result.get('hits', [])[:2], 1):
                    print(f" - hit[{i}]: score={hit.get('score')}, content={hit.get('content')[:30]}...")

        assert trace.tool_name == "kb_search_dynamic"
        assert trace.status == ToolCallStatus.OK
        assert trace.arguments is not None
        assert trace.result is not None

    def test_tool_call_timeout_trace(self):
        """A timed-out call records TIMEOUT status and an error code."""
        trace = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=2000,
            status=ToolCallStatus.TIMEOUT,
            error_code="TOOL_TIMEOUT",
            arguments={"query": "测试查询"},
        )

        print("\n[工具调用超时追踪测试]")
        print(f"Tool Name: {trace.tool_name}")
        print(f"Status: {trace.status.value}")
        print(f"Error Code: {trace.error_code}")
        print(f"Duration: {trace.duration_ms}ms")

        assert trace.status == ToolCallStatus.TIMEOUT
        assert trace.error_code == "TOOL_TIMEOUT"
|
||||
|
||||
|
||||
class TestExecutionModeRouting:
    """Tests for execution-mode routing decisions."""

    def test_policy_router_decision(self):
        """Table-driven: each (message, session mode) pair routes to the expected mode."""
        from app.services.mid.policy_router import PolicyRouter, IntentMatch

        router = PolicyRouter()

        # Each case: user message + current session mode -> expected execution mode.
        test_cases = [
            {
                "name": "正常对话 -> Agent模式",
                "user_message": "你好,请问有什么可以帮助我的?",
                "session_mode": "BOT_ACTIVE",
                "expected_mode": ExecutionMode.AGENT,
            },
            {
                "name": "高风险退款 -> Micro Flow模式",
                "user_message": "我要退款",
                "session_mode": "BOT_ACTIVE",
                "expected_mode": ExecutionMode.MICRO_FLOW,
            },
            {
                "name": "转人工请求 -> Transfer模式",
                "user_message": "帮我转人工",
                "session_mode": "BOT_ACTIVE",
                "expected_mode": ExecutionMode.TRANSFER,
            },
            {
                # A human agent already owns the session, so routing must
                # stay in transfer mode regardless of message content.
                "name": "人工模式 -> Transfer模式",
                "user_message": "你好",
                "session_mode": "HUMAN_ACTIVE",
                "expected_mode": ExecutionMode.TRANSFER,
            },
        ]

        print("\n[策略路由器决策测试]")

        for tc in test_cases:
            result = router.route(
                user_message=tc["user_message"],
                session_mode=tc["session_mode"],
            )

            print(f"\n测试用例: {tc['name']}")
            print(f" 用户消息: {tc['user_message']}")
            print(f" 会话模式: {tc['session_mode']}")
            print(f" 期望模式: {tc['expected_mode'].value}")
            print(f" 实际模式: {result.mode.value}")

            assert result.mode == tc["expected_mode"], f"模式不匹配: {result.mode} != {tc['expected_mode']}"
|
||||
|
||||
|
||||
class TestKBSearchDynamic:
    """Tests for the dynamic knowledge-base search result model."""

    def test_kb_search_result_structure(self):
        """A successful result carries hits, the applied filter and debug info."""
        from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult

        result = KbSearchDynamicResult(
            success=True,
            hits=[
                {
                    "id": "chunk_001",
                    "content": "退款流程:1. 登录账户 2. 进入订单页面 3. 点击退款按钮...",
                    "score": 0.92,
                    "metadata": {"kb_id": "kb_001", "doc_id": "doc_001"},
                },
                {
                    "id": "chunk_002",
                    "content": "退款时效:一般3-5个工作日到账...",
                    "score": 0.85,
                    "metadata": {"kb_id": "kb_001", "doc_id": "doc_002"},
                },
            ],
            applied_filter={"scene": "after_sale", "product_type": "vip"},
            missing_required_slots=[],
            filter_debug={"source": "slot_state"},
            # filter_sources records where each filter key originated.
            filter_sources={"scene": "slot", "product_type": "context"},
            duration_ms=120,
        )

        print("\n[知识库检索结果结构测试]")
        print(f"Success: {result.success}")
        print(f"Hits Count: {len(result.hits)}")
        print(f"Applied Filter: {json.dumps(result.applied_filter, ensure_ascii=False)}")
        print(f"Filter Sources: {json.dumps(result.filter_sources, ensure_ascii=False)}")
        print(f"Duration: {result.duration_ms}ms")

        print(f"\n[命中详情]")
        for i, hit in enumerate(result.hits, 1):
            print(f" [{i}] ID: {hit['id']}")
            print(f" Score: {hit['score']}")
            print(f" Content: {hit['content'][:50]}...")

        assert result.success is True
        assert len(result.hits) == 2
        assert result.applied_filter is not None

    def test_kb_search_missing_slots(self):
        """A failed result reports missing required slots and a fallback code."""
        from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult

        result = KbSearchDynamicResult(
            success=False,
            hits=[],
            applied_filter={},
            missing_required_slots=[
                {
                    "field_key": "order_id",
                    "label": "订单号",
                    "reason": "必填字段缺失",
                    "ask_back_prompt": "请提供您的订单号",
                },
            ],
            filter_debug={"source": "builder"},
            fallback_reason_code="MISSING_REQUIRED_SLOTS",
            duration_ms=50,
        )

        print("\n[知识库检索缺失槽位测试]")
        print(f"Success: {result.success}")
        print(f"Fallback Reason: {result.fallback_reason_code}")
        print(f"Missing Slots: {json.dumps(result.missing_required_slots, ensure_ascii=False, indent=2)}")

        assert result.success is False
        assert result.fallback_reason_code == "MISSING_REQUIRED_SLOTS"
        assert len(result.missing_required_slots) == 1
|
||||
|
||||
|
||||
class TestTimingBreakdown:
    """Tests for the per-stage timing breakdown."""

    def test_timing_breakdown_structure(self):
        """Stage durations sum to the expected total and print as a table."""
        # TimingRecord(stage, start, end, duration_ms); the durations below
        # are chosen so the total is exactly 1600 ms.
        timings = [
            TimingRecord("intent_matching", 0, 0.05, 50),
            TimingRecord("high_risk_check", 0.05, 0.08, 30),
            TimingRecord("kb_search", 0.08, 0.2, 120),
            TimingRecord("llm_generation", 0.2, 1.5, 1300),
            TimingRecord("output_guardrail", 1.5, 1.55, 50),
            TimingRecord("response_formatting", 1.55, 1.6, 50),
        ]

        total = sum(t.duration_ms for t in timings)

        print("\n[耗时分解测试]")
        print(f"{'阶段':<25} {'耗时(ms)':<10} {'占比':<10}")
        print("-" * 45)

        for t in timings:
            # Guard against division by zero if all durations were 0.
            percentage = (t.duration_ms / total * 100) if total > 0 else 0
            print(f"{t.stage:<25} {t.duration_ms:<10} {percentage:.1f}%")

        print("-" * 45)
        print(f"{'总计':<25} {total:<10} {'100.0%':<10}")

        assert total == 1600
|
||||
|
||||
|
||||
def run_manual_test():
    """CLI entry point for manually exercising the dialogue API.

    Supports a one-shot mode (send ``--message`` once) and an
    ``--interactive`` REPL that keeps prompting for messages until the
    user types 'quit', presses Ctrl-C, or stdin reaches EOF.
    """
    import argparse

    parser = argparse.ArgumentParser(description="中台对话集成测试")
    parser.add_argument("--url", default="http://localhost:8000", help="服务地址")
    parser.add_argument("--tenant", default="test_tenant", help="租户ID")
    parser.add_argument("--api-key", default=None, help="API Key")
    parser.add_argument("--message", default="你好", help="测试消息")
    parser.add_argument("--scene", default=None, help="场景标识")
    parser.add_argument("--interactive", action="store_true", help="交互模式")

    args = parser.parse_args()

    tester = DialogueIntegrationTester(
        base_url=args.url,
        tenant_id=args.tenant,
        api_key=args.api_key,
    )

    if args.interactive:
        print("\n=== 中台对话集成测试 - 交互模式 ===")
        print("输入 'quit' 退出\n")

        while True:
            try:
                message = input("请输入消息: ").strip()
                if message.lower() == "quit":
                    break

                # Optional scene tag; empty input means "no scene".
                scene = input("请输入场景(可选,直接回车跳过): ").strip() or None

                result = asyncio.run(tester.send_dialogue(
                    user_message=message,
                    scene=scene,
                ))

                tester.print_result(result)

            # BUGFIX: also catch EOFError — input() raises it when stdin is
            # closed (e.g. piped input exhausted); previously this crashed
            # the REPL with a traceback instead of exiting cleanly.
            except (KeyboardInterrupt, EOFError):
                break
    else:
        result = asyncio.run(tester.send_dialogue(
            user_message=args.message,
            scene=args.scene,
        ))

        tester.print_result(result)
|
||||
|
||||
|
||||
# Allow running this module directly as a manual CLI test harness.
if __name__ == "__main__":
    run_manual_test()
|
||||
|
|
@ -0,0 +1,541 @@
|
|||
"""
|
||||
Unit tests for Retrieval Strategy Service.
|
||||
[AC-AISVC-RES-01~15] Tests for strategy management, switching, validation, and rollback.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
from app.schemas.retrieval_strategy import (
|
||||
ReactMode,
|
||||
RolloutConfig,
|
||||
RolloutMode,
|
||||
StrategyType,
|
||||
RetrievalStrategyStatus,
|
||||
RetrievalStrategySwitchRequest,
|
||||
RetrievalStrategyValidationRequest,
|
||||
ValidationResult,
|
||||
)
|
||||
from app.services.retrieval.strategy_service import (
|
||||
RetrievalStrategyService,
|
||||
StrategyState,
|
||||
get_strategy_service,
|
||||
)
|
||||
from app.services.retrieval.strategy_audit import (
|
||||
StrategyAuditService,
|
||||
get_audit_service,
|
||||
)
|
||||
from app.services.retrieval.strategy_metrics import (
|
||||
StrategyMetricsService,
|
||||
get_metrics_service,
|
||||
)
|
||||
|
||||
|
||||
class TestRetrievalStrategySchemas:
    """[AC-AISVC-RES-01~15] Schema-level tests for retrieval strategy models."""

    def test_rollout_config_off_mode(self):
        """[AC-AISVC-RES-03] With mode=off, percentage and allowlist stay unset."""
        cfg = RolloutConfig(mode=RolloutMode.OFF)

        assert cfg.mode == RolloutMode.OFF
        assert cfg.percentage is None
        assert cfg.allowlist is None

    def test_rollout_config_percentage_mode(self):
        """[AC-AISVC-RES-03] Percentage mode stores the supplied percentage."""
        cfg = RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0)

        assert (cfg.mode, cfg.percentage) == (RolloutMode.PERCENTAGE, 50.0)

    def test_rollout_config_percentage_mode_missing_value(self):
        """[AC-AISVC-RES-03] Omitting percentage in percentage mode is rejected."""
        with pytest.raises(ValueError, match="percentage is required"):
            RolloutConfig(mode=RolloutMode.PERCENTAGE)

    def test_rollout_config_allowlist_mode(self):
        """[AC-AISVC-RES-03] Allowlist mode stores the supplied tenant list."""
        tenants = ["tenant1", "tenant2"]
        cfg = RolloutConfig(mode=RolloutMode.ALLOWLIST, allowlist=tenants)

        assert cfg.mode == RolloutMode.ALLOWLIST
        assert cfg.allowlist == ["tenant1", "tenant2"]

    def test_rollout_config_allowlist_mode_missing_value(self):
        """[AC-AISVC-RES-03] Omitting the allowlist in allowlist mode is rejected."""
        with pytest.raises(ValueError, match="allowlist is required"):
            RolloutConfig(mode=RolloutMode.ALLOWLIST)

    def test_retrieval_strategy_status(self):
        """[AC-AISVC-RES-01] Status bundles strategy, react mode and rollout."""
        status = RetrievalStrategyStatus(
            active_strategy=StrategyType.DEFAULT,
            react_mode=ReactMode.NON_REACT,
            rollout=RolloutConfig(mode=RolloutMode.OFF),
        )

        assert status.active_strategy == StrategyType.DEFAULT
        assert status.react_mode == ReactMode.NON_REACT
        assert status.rollout.mode == RolloutMode.OFF

    def test_switch_request_minimal(self):
        """[AC-AISVC-RES-02] Only target_strategy is required; the rest default to None."""
        req = RetrievalStrategySwitchRequest(target_strategy=StrategyType.ENHANCED)

        assert req.target_strategy == StrategyType.ENHANCED
        for optional_field in (req.react_mode, req.rollout, req.reason):
            assert optional_field is None

    def test_switch_request_full(self):
        """[AC-AISVC-RES-02,03,05] Every optional field round-trips when provided."""
        req = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
            rollout=RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=30.0),
            reason="Testing enhanced strategy",
        )

        assert req.target_strategy == StrategyType.ENHANCED
        assert req.react_mode == ReactMode.REACT
        assert req.rollout.percentage == 30.0
        assert req.reason == "Testing enhanced strategy"
|
||||
|
||||
|
||||
class TestRetrievalStrategyService:
    """[AC-AISVC-RES-01~15] Tests for strategy service."""

    @pytest.fixture
    def service(self):
        """Create a fresh service instance for each test."""
        return RetrievalStrategyService()

    def test_get_current_status_default(self, service):
        """[AC-AISVC-RES-01] Default status should be default strategy and non_react mode."""
        status = service.get_current_status()

        assert status.active_strategy == StrategyType.DEFAULT
        assert status.react_mode == ReactMode.NON_REACT
        assert status.rollout.mode == RolloutMode.OFF

    def test_switch_strategy_to_enhanced(self, service):
        """[AC-AISVC-RES-02] Should switch to enhanced strategy."""
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )

        response = service.switch_strategy(request)

        # The response reports both the pre-switch and post-switch state.
        assert response.previous.active_strategy == StrategyType.DEFAULT
        assert response.current.active_strategy == StrategyType.ENHANCED
        assert response.current.react_mode == ReactMode.REACT

    def test_switch_strategy_with_grayscale_percentage(self, service):
        """[AC-AISVC-RES-03] Should switch with grayscale percentage."""
        rollout = RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0)
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            rollout=rollout,
        )

        response = service.switch_strategy(request)

        assert response.current.active_strategy == StrategyType.ENHANCED
        assert response.current.rollout.mode == RolloutMode.PERCENTAGE
        assert response.current.rollout.percentage == 50.0

    def test_switch_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] Should switch with allowlist grayscale."""
        rollout = RolloutConfig(
            mode=RolloutMode.ALLOWLIST,
            allowlist=["tenant_a", "tenant_b"],
        )
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            rollout=rollout,
        )

        response = service.switch_strategy(request)

        assert response.current.rollout.mode == RolloutMode.ALLOWLIST
        assert "tenant_a" in response.current.rollout.allowlist

    def test_rollback_strategy(self, service):
        """[AC-AISVC-RES-07] Should rollback to previous strategy."""
        # First switch away from the default so there is something to roll back.
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )
        service.switch_strategy(request)

        response = service.rollback_strategy()

        assert response.rollback_to.active_strategy == StrategyType.DEFAULT
        assert response.rollback_to.react_mode == ReactMode.NON_REACT

    def test_rollback_without_previous_returns_default(self, service):
        """[AC-AISVC-RES-07] Rollback without previous should return default."""
        response = service.rollback_strategy()

        assert response.rollback_to.active_strategy == StrategyType.DEFAULT

    def test_should_use_enhanced_strategy_default(self, service):
        """[AC-AISVC-RES-01] Default strategy should not use enhanced."""
        assert service.should_use_enhanced_strategy("tenant_a") is False

    def test_should_use_enhanced_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] Should use enhanced for tenants in allowlist."""
        rollout = RolloutConfig(
            mode=RolloutMode.ALLOWLIST,
            allowlist=["tenant_a"],
        )
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            rollout=rollout,
        )
        service.switch_strategy(request)

        # Only allowlisted tenants get the enhanced strategy.
        assert service.should_use_enhanced_strategy("tenant_a") is True
        assert service.should_use_enhanced_strategy("tenant_b") is False

    def test_get_route_mode_react(self, service):
        """[AC-AISVC-RES-10] React mode should return react route."""
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )
        service.switch_strategy(request)

        route = service.get_route_mode("test query")
        assert route == "react"

    def test_get_route_mode_direct(self, service):
        """[AC-AISVC-RES-09] Non-react mode should return direct route."""
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.DEFAULT,
            react_mode=ReactMode.NON_REACT,
        )
        service.switch_strategy(request)

        route = service.get_route_mode("test query")
        assert route == "direct"

    def test_get_route_mode_auto_short_query(self, service):
        """[AC-AISVC-RES-12] Short query with high confidence should use direct route."""
        # NOTE(review): reaches into private state to force auto-routing;
        # consider exposing a public way to enable auto mode.
        service._state.react_mode = None

        route = service._auto_route("短问题", confidence=0.8)

        assert route == "direct"

    def test_get_route_mode_auto_multiple_conditions(self, service):
        """[AC-AISVC-RES-13] Query with multiple conditions should use react route."""
        route = service._auto_route("查询订单状态和物流信息")

        assert route == "react"

    def test_get_route_mode_auto_low_confidence(self, service):
        """[AC-AISVC-RES-13] Low confidence should use react route."""
        route = service._auto_route("test query", confidence=0.3)

        assert route == "react"

    def test_get_switch_history(self, service):
        """Should track switch history."""
        request = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            reason="Testing",
        )
        service.switch_strategy(request)

        history = service.get_switch_history()

        assert len(history) == 1
        assert history[0]["to_strategy"] == "enhanced"
|
||||
|
||||
|
||||
class TestRetrievalStrategyValidation:
    """[AC-AISVC-RES-04,06,08] Validation checks run before strategy switching."""

    @pytest.fixture
    def service(self):
        return RetrievalStrategyService()

    def test_validate_default_strategy(self, service):
        """[AC-AISVC-RES-06] Validating the default strategy always passes."""
        response = service.validate_strategy(
            RetrievalStrategyValidationRequest(strategy=StrategyType.DEFAULT)
        )

        assert response.passed is True

    def test_validate_enhanced_strategy(self, service):
        """[AC-AISVC-RES-06] Validating the enhanced strategy yields check results."""
        response = service.validate_strategy(
            RetrievalStrategyValidationRequest(strategy=StrategyType.ENHANCED)
        )

        assert isinstance(response.passed, bool)
        assert len(response.results) > 0

    def test_validate_specific_checks(self, service):
        """[AC-AISVC-RES-06] Every requested check name appears in the results."""
        wanted = ["metadata_consistency", "performance_budget"]
        response = service.validate_strategy(
            RetrievalStrategyValidationRequest(
                strategy=StrategyType.ENHANCED,
                checks=wanted,
            )
        )

        ran = {r.check for r in response.results}
        assert ran >= set(wanted)

    def test_check_metadata_consistency(self, service):
        """[AC-AISVC-RES-04] The metadata-consistency check passes for default."""
        outcome = service._check_metadata_consistency(StrategyType.DEFAULT)

        assert (outcome.check, outcome.passed) == ("metadata_consistency", True)

    def test_check_rrf_config(self, service):
        """[AC-AISVC-RES-02] The RRF-config check reports a boolean verdict."""
        outcome = service._check_rrf_config(StrategyType.DEFAULT)

        assert outcome.check == "rrf_config"
        assert isinstance(outcome.passed, bool)

    def test_check_performance_budget(self, service):
        """[AC-AISVC-RES-08] The performance-budget check reports a boolean verdict."""
        outcome = service._check_performance_budget(
            StrategyType.ENHANCED,
            ReactMode.REACT,
        )

        assert outcome.check == "performance_budget"
        assert isinstance(outcome.passed, bool)
|
||||
|
||||
|
||||
class TestStrategyAuditService:
    """[AC-AISVC-RES-07] Tests for audit service."""

    @pytest.fixture
    def audit_service(self):
        # Bounded log so tests never grow the in-memory buffer unbounded.
        return StrategyAuditService(max_entries=100)

    def test_log_switch_operation(self, audit_service):
        """[AC-AISVC-RES-07] Should log switch operation."""
        audit_service.log(
            operation="switch",
            previous_strategy="default",
            new_strategy="enhanced",
            reason="Testing",
            operator="admin",
        )

        entries = audit_service.get_audit_log()
        assert len(entries) == 1
        assert entries[0].operation == "switch"
        assert entries[0].previous_strategy == "default"
        assert entries[0].new_strategy == "enhanced"

    def test_log_rollback_operation(self, audit_service):
        """[AC-AISVC-RES-07] Should log rollback operation."""
        audit_service.log_rollback(
            previous_strategy="enhanced",
            new_strategy="default",
            reason="Performance issue",
            operator="admin",
        )

        # Filter by operation to confirm the entry was tagged as a rollback.
        entries = audit_service.get_audit_log(operation="rollback")
        assert len(entries) == 1
        assert entries[0].operation == "rollback"

    def test_log_validation_operation(self, audit_service):
        """[AC-AISVC-RES-06] Should log validation operation."""
        audit_service.log_validation(
            strategy="enhanced",
            checks=["metadata_consistency"],
            passed=True,
        )

        entries = audit_service.get_audit_log(operation="validate")
        assert len(entries) == 1
        assert entries[0].operation == "validate"

    def test_get_audit_log_with_limit(self, audit_service):
        """Should limit audit log entries."""
        for i in range(10):
            audit_service.log(operation="switch", new_strategy=f"strategy_{i}")

        entries = audit_service.get_audit_log(limit=5)
        assert len(entries) == 5

    def test_get_audit_stats(self, audit_service):
        """Should return audit statistics."""
        audit_service.log(operation="switch", new_strategy="enhanced")
        audit_service.log(operation="rollback", new_strategy="default")

        stats = audit_service.get_audit_stats()

        assert stats["total_entries"] == 2
        assert stats["operation_counts"]["switch"] == 1
        assert stats["operation_counts"]["rollback"] == 1

    def test_clear_audit_log(self, audit_service):
        """Should clear audit log."""
        audit_service.log(operation="switch", new_strategy="enhanced")
        assert len(audit_service.get_audit_log()) == 1

        # clear_audit_log returns the number of entries removed.
        count = audit_service.clear_audit_log()
        assert count == 1
        assert len(audit_service.get_audit_log()) == 0
|
||||
|
||||
|
||||
class TestStrategyMetricsService:
    """[AC-AISVC-RES-03,08] Tests for metrics service."""

    @pytest.fixture
    def metrics_service(self):
        return StrategyMetricsService()

    def test_record_request(self, metrics_service):
        """[AC-AISVC-RES-08] Should record request metrics."""
        metrics_service.record_request(latency_ms=100.0, success=True, route_mode="direct")

        snapshot = metrics_service.get_metrics()
        assert snapshot.total_requests == 1
        assert snapshot.successful_requests == 1
        assert snapshot.avg_latency_ms == 100.0

    def test_record_failed_request(self, metrics_service):
        """[AC-AISVC-RES-08] Should record failed request."""
        metrics_service.record_request(latency_ms=50.0, success=False)

        assert metrics_service.get_metrics().failed_requests == 1

    def test_record_fallback(self, metrics_service):
        """[AC-AISVC-RES-08] Should record fallback count."""
        metrics_service.record_request(latency_ms=100.0, success=True, fallback=True)

        assert metrics_service.get_metrics().fallback_count == 1

    def test_record_route_metrics(self, metrics_service):
        """[AC-AISVC-RES-08] Should track route mode metrics."""
        metrics_service.record_request(latency_ms=100.0, success=True, route_mode="react")
        metrics_service.record_request(latency_ms=50.0, success=True, route_mode="direct")

        per_route = metrics_service.get_route_metrics()
        assert "react" in per_route
        assert "direct" in per_route

    def test_get_all_metrics(self, metrics_service):
        """Should get metrics for all strategies."""
        metrics_service.set_current_strategy(StrategyType.ENHANCED, ReactMode.REACT)
        metrics_service.record_request(latency_ms=100.0, success=True)

        all_metrics = metrics_service.get_all_metrics()

        # Both strategies are keyed by their enum value, even when only
        # one has recorded traffic.
        assert StrategyType.DEFAULT.value in all_metrics
        assert StrategyType.ENHANCED.value in all_metrics

    def test_get_performance_summary(self, metrics_service):
        """[AC-AISVC-RES-08] Should get performance summary."""
        metrics_service.record_request(latency_ms=100.0, success=True)
        metrics_service.record_request(latency_ms=200.0, success=True)
        metrics_service.record_request(latency_ms=50.0, success=False)

        summary = metrics_service.get_performance_summary()

        assert summary["total_requests"] == 3
        assert summary["successful_requests"] == 2
        assert summary["failed_requests"] == 1
        # 2 successes out of 3 requests.
        assert summary["success_rate"] == pytest.approx(0.6667, rel=0.01)

    def test_check_performance_threshold_ok(self, metrics_service):
        """[AC-AISVC-RES-08] Should pass performance threshold check."""
        metrics_service.record_request(latency_ms=100.0, success=True)

        verdict = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )

        assert verdict["latency_ok"] is True
        assert verdict["error_rate_ok"] is True
        assert verdict["overall_ok"] is True

    def test_check_performance_threshold_exceeded(self, metrics_service):
        """[AC-AISVC-RES-08] Should fail when threshold exceeded."""
        # One slow request and one failure: at least one limit is breached.
        metrics_service.record_request(latency_ms=6000.0, success=True)
        metrics_service.record_request(latency_ms=100.0, success=False)

        verdict = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )

        assert verdict["latency_ok"] is False or verdict["error_rate_ok"] is False

    def test_reset_metrics(self, metrics_service):
        """Should reset metrics."""
        metrics_service.record_request(latency_ms=100.0, success=True)
        metrics_service.reset_metrics()

        assert metrics_service.get_metrics().total_requests == 0
|
||||
|
||||
|
||||
class TestSingletonInstances:
    """Tests for singleton instance getters.

    Each test clears the module-level singleton slot, then calls the
    public getter twice and asserts both calls return the same object.
    """

    def test_get_strategy_service_singleton(self):
        """Should return same strategy service instance."""
        import app.services.retrieval.strategy_service as module
        # Fix: import the getter that is actually called (the original
        # imported the private ``_strategy_service`` and never used it).
        from app.services.retrieval.strategy_service import get_strategy_service

        module._strategy_service = None  # force lazy re-creation

        service1 = get_strategy_service()
        service2 = get_strategy_service()

        assert service1 is service2

    def test_get_audit_service_singleton(self):
        """Should return same audit service instance."""
        import app.services.retrieval.strategy_audit as module
        from app.services.retrieval.strategy_audit import get_audit_service

        module._audit_service = None  # force lazy re-creation

        service1 = get_audit_service()
        service2 = get_audit_service()

        assert service1 is service2

    def test_get_metrics_service_singleton(self):
        """Should return same metrics service instance."""
        import app.services.retrieval.strategy_metrics as module
        from app.services.retrieval.strategy_metrics import get_metrics_service

        module._metrics_service = None  # force lazy re-creation

        service1 = get_metrics_service()
        service2 = get_metrics_service()

        assert service1 is service2
|
||||
|
|
@ -0,0 +1,353 @@
|
|||
"""
|
||||
Integration tests for Retrieval Strategy API.
|
||||
[AC-AISVC-RES-01~15] End-to-end tests for strategy management endpoints.
|
||||
|
||||
Tests the full API flow:
|
||||
- GET /strategy/retrieval/current
|
||||
- POST /strategy/retrieval/switch
|
||||
- POST /strategy/retrieval/validate
|
||||
- POST /strategy/retrieval/rollback
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_api_key_service():
    """Mock API key service to bypass authentication in tests.

    Patches the service factory so every request carrying the
    ``test-api-key`` header validates successfully.
    """
    fake_service = MagicMock()
    fake_service._initialized = True
    fake_service._keys_cache = {"test-api-key": MagicMock()}

    validation_result = MagicMock()
    validation_result.ok = True
    validation_result.reason = None
    fake_service.validate_key_with_context.return_value = validation_result

    with patch("app.services.api_key.get_api_key_service", return_value=fake_service):
        yield fake_service
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_strategy_state():
    """Reset strategy state before and after each test.

    Drops the cached router singleton and reinstalls a default config so
    strategy switches made by one test cannot leak into the next.
    """
    from app.services.retrieval.strategy.strategy_router import get_strategy_router, set_strategy_router
    from app.services.retrieval.strategy.config import RetrievalStrategyConfig

    def _restore_defaults():
        set_strategy_router(None)
        get_strategy_router().update_config(RetrievalStrategyConfig())

    _restore_defaults()
    yield
    _restore_defaults()
|
||||
|
||||
|
||||
class TestRetrievalStrategyAPIIntegration:
    """
    [AC-AISVC-RES-01~15] Integration tests for retrieval strategy API.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        # Tenant + API key pair accepted by the mocked auth service.
        return {
            "X-Tenant-Id": "test@ash@2026",
            "X-API-Key": "test-api-key",
        }

    def test_get_current_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-01] GET /current should return strategy status.
        """
        response = client.get("/strategy/retrieval/current", headers=valid_headers)

        assert response.status_code == 200
        body = response.json()
        assert "active_strategy" in body
        assert "grayscale" in body
        assert body["active_strategy"] in ["default", "enhanced"]

    def test_switch_strategy_to_enhanced(self, client, valid_headers):
        """
        [AC-AISVC-RES-02] POST /switch should switch to enhanced strategy.
        """
        payload = {"active_strategy": "enhanced"}
        response = client.post(
            "/strategy/retrieval/switch", json=payload, headers=valid_headers
        )

        assert response.status_code == 200
        body = response.json()
        assert "success" in body
        assert body["success"] is True
        assert body["current_strategy"] == "enhanced"

    def test_switch_strategy_with_grayscale_percentage(self, client, valid_headers):
        """
        [AC-AISVC-RES-03] POST /switch should accept grayscale percentage.
        """
        payload = {
            "active_strategy": "enhanced",
            "grayscale": {"enabled": True, "percentage": 30.0},
        }
        response = client.post(
            "/strategy/retrieval/switch", json=payload, headers=valid_headers
        )

        assert response.status_code == 200
        assert response.json()["success"] is True

    def test_switch_strategy_with_allowlist(self, client, valid_headers):
        """
        [AC-AISVC-RES-03] POST /switch should accept allowlist.
        """
        payload = {
            "active_strategy": "enhanced",
            "grayscale": {"enabled": True, "allowlist": ["tenant_a", "tenant_b"]},
        }
        response = client.post(
            "/strategy/retrieval/switch", json=payload, headers=valid_headers
        )

        assert response.status_code == 200
        assert response.json()["success"] is True

    def test_validate_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-06] POST /validate should validate strategy.
        """
        response = client.post(
            "/strategy/retrieval/validate",
            json={"active_strategy": "enhanced"},
            headers=valid_headers,
        )

        assert response.status_code == 200
        body = response.json()
        assert "valid" in body
        assert "errors" in body
        assert isinstance(body["valid"], bool)

    def test_validate_default_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-06] Default strategy should pass validation.
        """
        response = client.post(
            "/strategy/retrieval/validate",
            json={"active_strategy": "default"},
            headers=valid_headers,
        )

        assert response.status_code == 200
        assert response.json()["valid"] is True

    def test_rollback_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-07] POST /rollback should rollback to default.
        """
        # Move to enhanced first so the rollback has something to undo.
        client.post(
            "/strategy/retrieval/switch",
            json={"active_strategy": "enhanced"},
            headers=valid_headers,
        )

        response = client.post("/strategy/retrieval/rollback", headers=valid_headers)

        assert response.status_code == 200
        body = response.json()
        assert "success" in body
        assert body["current_strategy"] == "default"
|
||||
|
||||
|
||||
class TestRetrievalStrategyAPIValidation:
    """
    [AC-AISVC-RES-03] Tests for API request validation.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        return {
            "X-Tenant-Id": "test@ash@2026",
            "X-API-Key": "test-api-key",
        }

    def test_switch_invalid_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-03] Invalid strategy value should return error.
        """
        response = client.post(
            "/strategy/retrieval/switch",
            json={"active_strategy": "invalid_strategy"},
            headers=valid_headers,
        )

        # Accept any of the plausible rejection codes; the exact status
        # depends on where validation fails in the stack.
        assert response.status_code in [400, 422, 500]

    def test_switch_percentage_out_of_range(self, client, valid_headers):
        """
        [AC-AISVC-RES-03] Percentage > 100 should return validation error.
        """
        payload = {
            "active_strategy": "enhanced",
            "grayscale": {"percentage": 150.0},
        }
        response = client.post(
            "/strategy/retrieval/switch", json=payload, headers=valid_headers
        )

        assert response.status_code in [400, 422]
|
||||
|
||||
|
||||
class TestRetrievalStrategyAPIFlow:
    """
    [AC-AISVC-RES-01~15] Tests for complete API flow scenarios.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        return {
            "X-Tenant-Id": "test@ash@2026",
            "X-API-Key": "test-api-key",
        }

    def test_complete_strategy_lifecycle(self, client, valid_headers):
        """
        [AC-AISVC-RES-01~07] Test complete strategy lifecycle:
        1. Get current strategy
        2. Switch to enhanced
        3. Validate
        4. Rollback
        5. Verify back to default
        """
        # Step 1: baseline state is the default strategy.
        current = client.get("/strategy/retrieval/current", headers=valid_headers)
        assert current.status_code == 200
        assert current.json()["active_strategy"] == "default"

        # Step 2: switch to enhanced with a 50% grayscale rollout.
        switch = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
                "grayscale": {"enabled": True, "percentage": 50.0},
            },
            headers=valid_headers,
        )
        assert switch.status_code == 200
        assert switch.json()["current_strategy"] == "enhanced"

        # Step 3: validation of the now-active strategy succeeds.
        validate = client.post(
            "/strategy/retrieval/validate",
            json={"active_strategy": "enhanced"},
            headers=valid_headers,
        )
        assert validate.status_code == 200

        # Step 4: roll back to default.
        rollback = client.post("/strategy/retrieval/rollback", headers=valid_headers)
        assert rollback.status_code == 200
        assert rollback.json()["current_strategy"] == "default"

        # Step 5: confirm the rollback actually took effect.
        final = client.get("/strategy/retrieval/current", headers=valid_headers)
        assert final.status_code == 200
        assert final.json()["active_strategy"] == "default"
|
||||
|
||||
|
||||
class TestRetrievalStrategyAPIMissingTenant:
    """
    Tests for API behavior without tenant ID.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def api_key_headers(self):
        # API key only — the X-Tenant-Id header is intentionally absent.
        return {"X-API-Key": "test-api-key"}

    def test_current_without_tenant(self, client, api_key_headers):
        """
        Missing X-Tenant-Id should return 400.
        """
        response = client.get(
            "/strategy/retrieval/current", headers=api_key_headers
        )
        assert response.status_code == 400

    def test_switch_without_tenant(self, client, api_key_headers):
        """
        Missing X-Tenant-Id should return 400.
        """
        response = client.post(
            "/strategy/retrieval/switch",
            json={"active_strategy": "enhanced"},
            headers=api_key_headers,
        )
        assert response.status_code == 400
|
||||
Loading…
Reference in New Issue