[AC-TEST] test: 新增单元测试和集成测试

- 新增 test_image_parser 图片解析器测试
- 新增 test_llm_multi_usage_config LLM 多用途配置测试
- 新增 test_markdown_chunker Markdown 分块测试
- 新增 test_metadata_auto_inference 元数据推断测试
- 新增 test_mid_dialogue_integration 对话集成测试
- 新增 test_retrieval_strategy 检索策略测试
- 新增 test_retrieval_strategy_integration 检索策略集成测试
This commit is contained in:
MerCry 2026-03-11 19:10:05 +08:00
parent 1490235b8f
commit a6276522c8
7 changed files with 3455 additions and 0 deletions

View File

@ -0,0 +1,375 @@
"""
Unit tests for ImageParser.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.document.image_parser import (
ImageChunk,
ImageParseResult,
ImageParser,
)
class TestImageParserBasics:
    """Basic ImageParser behaviour: advertised extensions and MIME lookup."""

    def test_supported_extensions(self):
        """The parser reports exactly the known raster-image extensions, in order."""
        expected = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
        assert ImageParser().get_supported_extensions() == expected

    def test_get_mime_type(self):
        """Every extension maps to its MIME type; unknown ones fall back to JPEG."""
        parser = ImageParser()
        cases = {
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".png": "image/png",
            ".gif": "image/gif",
            ".webp": "image/webp",
            ".bmp": "image/bmp",
            ".tiff": "image/tiff",
            ".tif": "image/tiff",
            ".unknown": "image/jpeg",
        }
        for extension, mime in cases.items():
            assert parser._get_mime_type(extension) == mime
class TestImageChunkParsing:
    """Test LLM response parsing functionality.

    Exercises ImageParser._extract_json (pulling a JSON payload out of raw
    LLM output) and ImageParser._parse_llm_response (building an
    ImageParseResult, falling back to a single raw-text chunk on failure).
    """
    def test_extract_json_from_plain_json(self):
        """Test extracting JSON from plain JSON response."""
        parser = ImageParser()
        json_str = '{"image_summary": "test", "chunks": [{"chunk_index": 0, "content": "hello", "chunk_type": "text", "keywords": ["key"]}]}'
        result = parser._extract_json(json_str)
        # Already-valid JSON should pass through unchanged.
        assert result == json_str
    def test_extract_json_from_markdown(self):
        """Test extracting JSON from markdown code block."""
        parser = ImageParser()
        markdown = """Here is the analysis:
```json
{"image_summary": "test", "chunks": [{"chunk_index": 0, "content": "hello"}]}
```
Hope this helps!"""
        result = parser._extract_json(markdown)
        # The payload inside the fenced ```json block must be recovered.
        assert "image_summary" in result
        assert "test" in result
    def test_extract_json_from_text_with_json(self):
        """Test extracting JSON from text with JSON embedded."""
        parser = ImageParser()
        # NOTE(review): the embedded payload uses single quotes, so it is not
        # strictly valid JSON; the assertions only check substring recovery.
        text = "The result is: {'image_summary': 'summary', 'chunks': []}"
        result = parser._extract_json(text)
        assert "image_summary" in result
        assert "chunks" in result
    def test_parse_llm_response_valid_json(self):
        """Test parsing valid JSON response from LLM."""
        parser = ImageParser()
        response = json.dumps({
            "image_summary": "测试图片",
            "total_chunks": 2,
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "这是第一块内容",
                    "chunk_type": "text",
                    "keywords": ["测试", "内容"]
                },
                {
                    "chunk_index": 1,
                    "content": "这是第二块内容,包含表格数据",
                    "chunk_type": "table",
                    "keywords": ["表格", "数据"]
                }
            ]
        })
        result = parser._parse_llm_response(response)
        # Summary and both chunks round-trip with their types and keywords.
        assert result.image_summary == "测试图片"
        assert len(result.chunks) == 2
        assert result.chunks[0].content == "这是第一块内容"
        assert result.chunks[0].chunk_type == "text"
        assert result.chunks[0].keywords == ["测试", "内容"]
        assert result.chunks[1].chunk_type == "table"
        assert result.chunks[1].keywords == ["表格", "数据"]
    def test_parse_llm_response_empty_chunks(self):
        """Test handling response with empty chunks."""
        parser = ImageParser()
        response = json.dumps({
            "image_summary": "测试",
            "chunks": []
        })
        result = parser._parse_llm_response(response)
        # An empty chunk list triggers the fallback: a single chunk holding
        # the entire raw response text.
        assert len(result.chunks) == 1
        assert result.chunks[0].content == response
    def test_parse_llm_response_invalid_json(self):
        """Test handling invalid JSON response with fallback."""
        parser = ImageParser()
        response = "This is not JSON at all"
        result = parser._parse_llm_response(response)
        # Unparseable output degrades to one raw-text chunk, not an error.
        assert len(result.chunks) == 1
        assert result.chunks[0].content == "This is not JSON at all"
    def test_parse_llm_response_partial_json(self):
        """Test handling response with partial/invalid JSON uses fallback."""
        parser = ImageParser()
        response = '{"image_summary": "test" some text here {"chunks": []}'
        result = parser._parse_llm_response(response)
        # Truncated/mangled JSON also falls back to a single raw-text chunk.
        assert len(result.chunks) == 1
        assert result.chunks[0].content == response
class TestImageChunkDataClass:
    """Construction semantics of ImageChunk and ImageParseResult."""

    def test_image_chunk_creation(self):
        """Explicitly supplied field values are stored unchanged."""
        chunk = ImageChunk(
            chunk_index=0,
            content="Test content",
            chunk_type="text",
            keywords=["test", "content"],
        )
        observed = (chunk.chunk_index, chunk.content, chunk.chunk_type, chunk.keywords)
        assert observed == (0, "Test content", "text", ["test", "content"])

    def test_image_chunk_default_values(self):
        """Omitted fields default to a "text" type and an empty keyword list."""
        minimal = ImageChunk(chunk_index=0, content="Test")
        assert (minimal.chunk_type, minimal.keywords) == ("text", [])

    def test_image_parse_result_creation(self):
        """All ImageParseResult fields are retained as provided."""
        parts = [ImageChunk(chunk_index=i, content=f"Chunk {i + 1}") for i in range(2)]
        parsed = ImageParseResult(
            image_summary="Test summary",
            chunks=parts,
            raw_text="Chunk 1\n\nChunk 2",
            source_path="/path/to/image.png",
            file_size=1024,
            metadata={"format": "png"},
        )
        assert parsed.image_summary == "Test summary"
        assert len(parsed.chunks) == 2
        assert parsed.raw_text == "Chunk 1\n\nChunk 2"
        assert parsed.file_size == 1024
        assert parsed.metadata["format"] == "png"
class TestChunkTypes:
"""Test different chunk types."""
def test_text_chunk_type(self):
"""Test text chunk type."""
parser = ImageParser()
response = json.dumps({
"image_summary": "Text content",
"chunks": [
{
"chunk_index": 0,
"content": "Plain text content",
"chunk_type": "text",
"keywords": ["text"]
}
]
})
result = parser._parse_llm_response(response)
assert result.chunks[0].chunk_type == "text"
def test_table_chunk_type(self):
"""Test table chunk type."""
parser = ImageParser()
response = json.dumps({
"image_summary": "Table content",
"chunks": [
{
"chunk_index": 0,
"content": "Name | Age\n---- | ---\nJohn | 30",
"chunk_type": "table",
"keywords": ["table", "data"]
}
]
})
result = parser._parse_llm_response(response)
assert result.chunks[0].chunk_type == "table"
def test_chart_chunk_type(self):
"""Test chart chunk type."""
parser = ImageParser()
response = json.dumps({
"image_summary": "Chart content",
"chunks": [
{
"chunk_index": 0,
"content": "Bar chart showing sales data",
"chunk_type": "chart",
"keywords": ["chart", "sales"]
}
]
})
result = parser._parse_llm_response(response)
assert result.chunks[0].chunk_type == "chart"
def test_list_chunk_type(self):
"""Test list chunk type."""
parser = ImageParser()
response = json.dumps({
"image_summary": "List content",
"chunks": [
{
"chunk_index": 0,
"content": "1. First item\n2. Second item\n3. Third item",
"chunk_type": "list",
"keywords": ["list", "items"]
}
]
})
result = parser._parse_llm_response(response)
assert result.chunks[0].chunk_type == "list"
class TestIntegrationScenarios:
    """Test integration scenarios.

    End-to-end checks of _parse_llm_response on realistic (Chinese-language)
    payloads: single-chunk, multi-chunk, and mixed text/table/list content.
    """
    def test_single_chunk_scenario(self):
        """Test single chunk scenario - simple image with one main content."""
        parser = ImageParser()
        response = json.dumps({
            "image_summary": "简单文档截图",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "这是一段完整的文档内容,包含所有的信息要点。",
                    "chunk_type": "text",
                    "keywords": ["文档", "信息"]
                }
            ]
        })
        result = parser._parse_llm_response(response)
        assert len(result.chunks) == 1
        assert result.image_summary == "简单文档截图"
        # With one chunk, raw_text equals that chunk's content verbatim.
        assert result.raw_text == "这是一段完整的文档内容,包含所有的信息要点。"
    def test_multi_chunk_scenario(self):
        """Test multi-chunk scenario - complex image with multiple sections."""
        parser = ImageParser()
        response = json.dumps({
            "image_summary": "多段落文档",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "第一章:介绍部分,介绍项目的背景和目标。",
                    "chunk_type": "text",
                    "keywords": ["第一章", "介绍"]
                },
                {
                    "chunk_index": 1,
                    "content": "第二章:技术架构,包括前端、后端和数据库设计。",
                    "chunk_type": "text",
                    "keywords": ["第二章", "架构"]
                },
                {
                    "chunk_index": 2,
                    "content": "第三章:部署流程,包含开发环境和生产环境配置。",
                    "chunk_type": "text",
                    "keywords": ["第三章", "部署"]
                }
            ]
        })
        result = parser._parse_llm_response(response)
        assert len(result.chunks) == 3
        assert "第一章" in result.chunks[0].content
        assert "第二章" in result.chunks[1].content
        assert "第三章" in result.chunks[2].content
        # raw_text joins the three chunks with blank-line separators, hence
        # exactly two "\n\n" occurrences.
        assert result.raw_text.count("\n\n") == 2
    def test_mixed_content_scenario(self):
        """Test mixed content scenario - text and table."""
        parser = ImageParser()
        response = json.dumps({
            "image_summary": "混合内容图片",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "产品介绍:本文档介绍我们的核心产品功能。",
                    "chunk_type": "text",
                    "keywords": ["产品", "功能"]
                },
                {
                    "chunk_index": 1,
                    "content": "产品规格表:\n| 型号 | 价格 | 库存 |\n| --- | --- | --- |\n| A1 | 100 | 50 |",
                    "chunk_type": "table",
                    "keywords": ["规格", "价格", "库存"]
                },
                {
                    "chunk_index": 2,
                    "content": "使用说明:\n1. 打开包装\n2. 连接电源\n3. 按下启动按钮",
                    "chunk_type": "list",
                    "keywords": ["说明", "步骤"]
                }
            ]
        })
        result = parser._parse_llm_response(response)
        assert len(result.chunks) == 3
        # Heterogeneous chunk types survive parsing, in declared order.
        assert result.chunks[0].chunk_type == "text"
        assert result.chunks[1].chunk_type == "table"
        assert result.chunks[2].chunk_type == "list"
# Allow running this test module directly (python test_image_parser.py)
# instead of through the pytest CLI.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@ -0,0 +1,332 @@
"""
Unit tests for multi-usage LLM configuration.
Tests for LLMUsageType, LLMConfigManager multi-usage support, and API endpoints.
"""
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.llm.factory import (
LLMConfigManager,
LLMProviderFactory,
LLMUsageType,
LLM_USAGE_DESCRIPTIONS,
LLM_USAGE_DISPLAY_NAMES,
get_llm_config_manager,
)
@pytest.fixture
def mock_settings():
    """Provide a MagicMock standing in for the application settings object."""
    defaults = {
        "llm_provider": "openai",
        "llm_api_key": "test-api-key",
        "llm_base_url": "https://api.openai.com/v1",
        "llm_model": "gpt-4o-mini",
        "llm_max_tokens": 2048,
        "llm_temperature": 0.7,
        "llm_timeout_seconds": 30,
        "llm_max_retries": 3,
        "redis_enabled": False,
        "redis_url": "redis://localhost:6379",
    }
    settings = MagicMock()
    settings.configure_mock(**defaults)
    return settings
@pytest.fixture(autouse=True)
def reset_singleton():
    """Reset singleton before and after each test."""
    # Clear the module-level cached manager so every test constructs a
    # fresh LLMConfigManager instead of reusing state from a prior test.
    import app.services.llm.factory as factory
    factory._llm_config_manager = None
    yield
    # Tear down again so later tests/modules never see our instance.
    factory._llm_config_manager = None
@pytest.fixture
def isolated_config_file(tmp_path):
    """Create a fresh, initially-empty JSON config file in pytest's tmp dir."""
    cfg_path = tmp_path.joinpath("llm_config.json")
    cfg_path.write_text("{}")
    return cfg_path
class TestLLMUsageType:
    """Behaviour of the LLMUsageType enum and its display/description tables."""

    def test_usage_types_exist(self):
        """The two required usage types carry their expected string values."""
        assert LLMUsageType.CHAT.value == "chat"
        assert LLMUsageType.KB_PROCESSING.value == "kb_processing"

    def test_usage_type_display_names(self):
        """Every usage type has a non-empty display-name string."""
        for usage in LLMUsageType:
            assert usage in LLM_USAGE_DISPLAY_NAMES
            display = LLM_USAGE_DISPLAY_NAMES[usage]
            assert isinstance(display, str)
            assert display != ""

    def test_usage_type_descriptions(self):
        """Every usage type has a non-empty description string."""
        for usage in LLMUsageType:
            assert usage in LLM_USAGE_DESCRIPTIONS
            description = LLM_USAGE_DESCRIPTIONS[usage]
            assert isinstance(description, str)
            assert description != ""

    def test_usage_type_from_string(self):
        """Enum members are constructible from their raw string values."""
        assert LLMUsageType("chat") is LLMUsageType.CHAT
        assert LLMUsageType("kb_processing") is LLMUsageType.KB_PROCESSING

    def test_invalid_usage_type(self):
        """Unknown value strings are rejected with ValueError."""
        with pytest.raises(ValueError):
            LLMUsageType("invalid_type")
class TestLLMConfigManagerMultiUsage:
    """Tests for LLMConfigManager multi-usage support.

    Every test patches ``get_settings`` and ``LLM_CONFIG_FILE`` so the
    manager seeds defaults from the mock settings and persists to an
    isolated temp file instead of the real config location.
    """
    @pytest.mark.asyncio
    async def test_get_all_configs(self, mock_settings, isolated_config_file):
        """Test getting all configs at once."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()
                # No argument -> mapping keyed by usage-type value, each entry
                # carrying a provider name and a config dict.
                all_configs = manager.get_current_config()
                for ut in LLMUsageType:
                    assert ut.value in all_configs
                    assert "provider" in all_configs[ut.value]
                    assert "config" in all_configs[ut.value]
    @pytest.mark.asyncio
    async def test_update_specific_usage_config(self, mock_settings, isolated_config_file):
        """Test updating config for a specific usage type."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()
                await manager.update_usage_config(
                    usage_type=LLMUsageType.KB_PROCESSING,
                    provider="ollama",
                    config={"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
                )
                # Only the targeted usage type should reflect the update.
                kb_config = manager.get_current_config(LLMUsageType.KB_PROCESSING)
                assert kb_config["provider"] == "ollama"
                assert kb_config["config"]["model"] == "llama3.2"
    @pytest.mark.asyncio
    async def test_update_all_configs(self, mock_settings, isolated_config_file):
        """Test updating all configs at once."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()
                # update_config (no usage type) fans out to every usage type.
                await manager.update_config(
                    provider="deepseek",
                    config={"api_key": "test-key", "model": "deepseek-chat"},
                )
                for ut in LLMUsageType:
                    config = manager.get_current_config(ut)
                    assert config["provider"] == "deepseek"
    @pytest.mark.asyncio
    async def test_get_client_for_usage_type(self, mock_settings, isolated_config_file):
        """Test getting client for specific usage type."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()
                chat_client = manager.get_client(LLMUsageType.CHAT)
                assert chat_client is not None
                kb_client = manager.get_client(LLMUsageType.KB_PROCESSING)
                assert kb_client is not None
    @pytest.mark.asyncio
    async def test_get_chat_client(self, mock_settings, isolated_config_file):
        """Test get_chat_client convenience method."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()
                client = manager.get_chat_client()
                assert client is not None
    @pytest.mark.asyncio
    async def test_get_kb_processing_client(self, mock_settings, isolated_config_file):
        """Test get_kb_processing_client convenience method."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()
                client = manager.get_kb_processing_client()
                assert client is not None
    @pytest.mark.asyncio
    async def test_close_all_clients(self, mock_settings, isolated_config_file):
        """Test that close() closes all clients."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()
                # Instantiate one client per usage type so close() has work.
                for ut in LLMUsageType:
                    manager.get_client(ut)
                await manager.close()
                # After close(), the internal client cache is cleared.
                for ut in LLMUsageType:
                    assert manager._clients[ut] is None
    @pytest.mark.asyncio
    async def test_config_persistence_to_file(self, mock_settings, isolated_config_file):
        """Test that configs are persisted to file."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()
                await manager.update_usage_config(
                    usage_type=LLMUsageType.KB_PROCESSING,
                    provider="ollama",
                    config={"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
                )
                # The on-disk JSON is keyed per usage type, even though only
                # one usage type was updated.
                assert isolated_config_file.exists()
                with open(isolated_config_file, "r", encoding="utf-8") as f:
                    saved = json.load(f)
                assert "chat" in saved
                assert "kb_processing" in saved
                assert saved["kb_processing"]["provider"] == "ollama"
    @pytest.mark.asyncio
    async def test_load_config_from_file(self, mock_settings, tmp_path):
        """Test loading configs from file."""
        # Pre-write a multi-usage config file, then verify the manager
        # picks it up on construction.
        config_file = tmp_path / "llm_config.json"
        saved_config = {
            "chat": {
                "provider": "openai",
                "config": {"api_key": "test-key", "model": "gpt-4o"},
            },
            "kb_processing": {
                "provider": "ollama",
                "config": {"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
            },
        }
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(saved_config, f)
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", config_file):
                manager = LLMConfigManager()
                chat_config = manager.get_current_config(LLMUsageType.CHAT)
                assert chat_config["config"]["model"] == "gpt-4o"
                kb_config = manager.get_current_config(LLMUsageType.KB_PROCESSING)
                assert kb_config["provider"] == "ollama"
                assert kb_config["config"]["model"] == "llama3.2"
    @pytest.mark.asyncio
    async def test_backward_compatibility_old_config_format(self, mock_settings, tmp_path):
        """Test backward compatibility with old single-config format."""
        # The legacy file format has a single top-level provider/config pair;
        # it should be applied to every usage type.
        config_file = tmp_path / "llm_config.json"
        old_config = {
            "provider": "deepseek",
            "config": {"api_key": "test-key", "model": "deepseek-chat"},
        }
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(old_config, f)
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", config_file):
                manager = LLMConfigManager()
                for ut in LLMUsageType:
                    config = manager.get_current_config(ut)
                    assert config["provider"] == "deepseek"
                    assert config["config"]["model"] == "deepseek-chat"
class TestLLMConfigManagerTestConnection:
    """Tests for test_connection with usage type support."""
    @pytest.mark.asyncio
    async def test_test_connection_with_usage_type(self, mock_settings, isolated_config_file):
        """Test connection testing with specific usage type."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()
                # Fake client whose generate() returns a canned response
                # object, so no real provider call is made.
                mock_client = AsyncMock()
                mock_client.generate = AsyncMock(
                    return_value=MagicMock(
                        content="Test response",
                        model="gpt-4o-mini",
                        usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
                    )
                )
                mock_client.close = AsyncMock()
                # Force the factory to hand out the fake client.
                with patch.object(
                    LLMProviderFactory, "create_client", return_value=mock_client
                ):
                    result = await manager.test_connection(
                        test_prompt="Hello",
                        usage_type=LLMUsageType.CHAT,
                    )
                    # A successful probe reports the response text and its
                    # measured latency.
                    assert result["success"] is True
                    assert "response" in result
                    assert "latency_ms" in result
class TestGetLLMConfigManager:
    """Singleton semantics of get_llm_config_manager()."""

    def test_singleton_instance(self, mock_settings, isolated_config_file):
        """Repeated calls hand back the very same manager object."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            first = get_llm_config_manager()
            second = get_llm_config_manager()
            assert first is second

    def test_reset_singleton(self, mock_settings, isolated_config_file):
        """After the autouse reset fixture runs, a manager is still obtainable."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            assert get_llm_config_manager() is not None
class TestLLMConfigManagerProperties:
    """The chat_config / kb_processing_config convenience properties."""

    def test_chat_config_property(self, mock_settings, isolated_config_file):
        """chat_config yields a dict that includes a model entry."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            cfg = LLMConfigManager().chat_config
            assert isinstance(cfg, dict)
            assert "model" in cfg

    def test_kb_processing_config_property(self, mock_settings, isolated_config_file):
        """kb_processing_config yields a dict that includes a model entry."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings), \
                patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
            cfg = LLMConfigManager().kb_processing_config
            assert isinstance(cfg, dict)
            assert "model" in cfg

View File

@ -0,0 +1,530 @@
"""
Unit tests for Markdown intelligent chunker.
Tests for MarkdownParser, MarkdownChunker, and integration.
"""
import pytest
from app.services.document.markdown_chunker import (
MarkdownChunk,
MarkdownChunker,
MarkdownElement,
MarkdownElementType,
MarkdownParser,
chunk_markdown,
)
class TestMarkdownParser:
    """Tests for MarkdownParser.

    NOTE(review): blank lines inside the triple-quoted fixtures appear to
    have been stripped by formatting in this view; where a test depends on
    paragraph separation, confirm the fixture against the original file.
    """
    def test_parse_headers(self):
        """Test header extraction."""
        text = """# Main Title
## Section 1
### Subsection 1.1
#### Deep Header
"""
        parser = MarkdownParser()
        elements = parser.parse(text)
        headers = [e for e in elements if e.type == MarkdownElementType.HEADER]
        # Four headers, one per nesting level 1-4, in document order; the
        # leading '#' markers are stripped from content.
        assert len(headers) == 4
        assert headers[0].content == "Main Title"
        assert headers[0].level == 1
        assert headers[1].content == "Section 1"
        assert headers[1].level == 2
        assert headers[2].content == "Subsection 1.1"
        assert headers[2].level == 3
        assert headers[3].content == "Deep Header"
        assert headers[3].level == 4
    def test_parse_code_blocks(self):
        """Test code block extraction with language."""
        text = """Here is some code:
```python
def hello():
    print("Hello, World!")
```
And some more text.
"""
        parser = MarkdownParser()
        elements = parser.parse(text)
        code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]
        assert len(code_blocks) == 1
        # The fence's info string becomes the element's language.
        assert code_blocks[0].language == "python"
        assert 'def hello():' in code_blocks[0].content
        assert 'print("Hello, World!")' in code_blocks[0].content
    def test_parse_code_blocks_no_language(self):
        """Test code block without language specification."""
        text = """```
plain code here
multiple lines
```
"""
        parser = MarkdownParser()
        elements = parser.parse(text)
        code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]
        assert len(code_blocks) == 1
        # No info string -> empty language, content still captured.
        assert code_blocks[0].language == ""
        assert "plain code here" in code_blocks[0].content
    def test_parse_tables(self):
        """Test table extraction."""
        text = """| Name | Age | City |
|------|-----|------|
| Alice | 30 | NYC |
| Bob | 25 | LA |
"""
        parser = MarkdownParser()
        elements = parser.parse(text)
        tables = [e for e in elements if e.type == MarkdownElementType.TABLE]
        assert len(tables) == 1
        assert "Name" in tables[0].content
        assert "Alice" in tables[0].content
        # Header cells and the data-row count are recorded as metadata.
        assert tables[0].metadata.get("headers") == ["Name", "Age", "City"]
        assert tables[0].metadata.get("row_count") == 2
    def test_parse_lists(self):
        """Test list extraction."""
        text = """- Item 1
- Item 2
- Item 3
"""
        parser = MarkdownParser()
        elements = parser.parse(text)
        lists = [e for e in elements if e.type == MarkdownElementType.LIST]
        # Consecutive bullet lines collapse into a single LIST element.
        assert len(lists) == 1
        assert "Item 1" in lists[0].content
        assert "Item 2" in lists[0].content
        assert "Item 3" in lists[0].content
    def test_parse_ordered_lists(self):
        """Test ordered list extraction."""
        text = """1. First
2. Second
3. Third
"""
        parser = MarkdownParser()
        elements = parser.parse(text)
        lists = [e for e in elements if e.type == MarkdownElementType.LIST]
        # Numbered items are recognised as a list too.
        assert len(lists) == 1
        assert "First" in lists[0].content
        assert "Second" in lists[0].content
        assert "Third" in lists[0].content
    def test_parse_blockquotes(self):
        """Test blockquote extraction."""
        text = """> This is a quote.
> It spans multiple lines.
> And continues here.
"""
        parser = MarkdownParser()
        elements = parser.parse(text)
        quotes = [e for e in elements if e.type == MarkdownElementType.BLOCKQUOTE]
        # Adjacent '>' lines merge into one BLOCKQUOTE element.
        assert len(quotes) == 1
        assert "This is a quote." in quotes[0].content
        assert "It spans multiple lines." in quotes[0].content
    def test_parse_paragraphs(self):
        """Test paragraph extraction."""
        # NOTE(review): three paragraphs are asserted below, which implies
        # blank-line separators that this view does not show — confirm the
        # fixture against the original source file.
        text = """This is the first paragraph.
This is the second paragraph.
It has multiple lines.
This is the third.
"""
        parser = MarkdownParser()
        elements = parser.parse(text)
        paragraphs = [e for e in elements if e.type == MarkdownElementType.PARAGRAPH]
        assert len(paragraphs) == 3
        assert "first paragraph" in paragraphs[0].content
        assert "second paragraph" in paragraphs[1].content
    def test_parse_mixed_content(self):
        """Test parsing mixed Markdown content."""
        text = """# Documentation
## Introduction
This is an introduction paragraph.
## Code Example
```python
def example():
    return 42
```
## Data Table
| Column A | Column B |
|----------|----------|
| Value 1 | Value 2 |
## List
- Item A
- Item B
> Note: This is important.
"""
        parser = MarkdownParser()
        elements = parser.parse(text)
        headers = [e for e in elements if e.type == MarkdownElementType.HEADER]
        code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]
        tables = [e for e in elements if e.type == MarkdownElementType.TABLE]
        lists = [e for e in elements if e.type == MarkdownElementType.LIST]
        quotes = [e for e in elements if e.type == MarkdownElementType.BLOCKQUOTE]
        paragraphs = [e for e in elements if e.type == MarkdownElementType.PARAGRAPH]
        # One of each specialised element; the header count covers the
        # title plus the four section headers.
        assert len(headers) == 5
        assert len(code_blocks) == 1
        assert len(tables) == 1
        assert len(lists) == 1
        assert len(quotes) == 1
        assert len(paragraphs) >= 1
    def test_code_blocks_not_parsed_as_other_elements(self):
        """Test that code blocks don't get parsed as headers or lists."""
        text = """```markdown
# This is not a header
- This is not a list
| This is not a table |
```
"""
        parser = MarkdownParser()
        elements = parser.parse(text)
        headers = [e for e in elements if e.type == MarkdownElementType.HEADER]
        lists = [e for e in elements if e.type == MarkdownElementType.LIST]
        tables = [e for e in elements if e.type == MarkdownElementType.TABLE]
        code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]
        # Everything inside the fence is opaque code, not structural markdown.
        assert len(headers) == 0
        assert len(lists) == 0
        assert len(tables) == 0
        assert len(code_blocks) == 1
class TestMarkdownChunker:
    """Tests for MarkdownChunker."""
    def test_chunk_simple_document(self):
        """Test chunking a simple document."""
        text = """# Title
This is a paragraph.
## Section
Another paragraph.
"""
        chunker = MarkdownChunker()
        chunks = chunker.chunk(text, "test_doc")
        assert len(chunks) >= 2
        assert all(isinstance(chunk, MarkdownChunk) for chunk in chunks)
        # Chunk ids are namespaced by the document id passed to chunk().
        assert all(chunk.chunk_id.startswith("test_doc") for chunk in chunks)
    def test_chunk_preserves_header_context(self):
        """Test that header context is preserved."""
        text = """# Main Title
## Section A
Content under section A.
### Subsection A1
Content under subsection A1.
"""
        chunker = MarkdownChunker(include_header_context=True)
        chunks = chunker.chunk(text, "test")
        # Drop chunks embedding the literal header text "Subsection A1"
        # (the header chunk itself); the body chunk "Content under
        # subsection A1." survives and matches case-insensitively below.
        subsection_chunks = [c for c in chunks if "Subsection A1" not in c.content]
        for chunk in subsection_chunks:
            if "subsection a1" in chunk.content.lower():
                # Ancestor headers must be carried as context.
                assert "Main Title" in chunk.header_context
                assert "Section A" in chunk.header_context
    def test_chunk_code_blocks_preserved(self):
        """Test that code blocks are preserved as single chunks when possible."""
        text = """```python
def function_one():
    pass
def function_two():
    pass
```
"""
        chunker = MarkdownChunker(max_chunk_size=2000, preserve_code_blocks=True)
        chunks = chunker.chunk(text, "test")
        code_chunks = [c for c in chunks if c.element_type == MarkdownElementType.CODE_BLOCK]
        # Small enough to fit one chunk: both functions stay together.
        assert len(code_chunks) == 1
        assert "def function_one" in code_chunks[0].content
        assert "def function_two" in code_chunks[0].content
        assert code_chunks[0].language == "python"
    def test_chunk_large_code_block_split(self):
        """Test that large code blocks are split properly."""
        lines = ["def function_{}(): pass".format(i) for i in range(100)]
        code_content = "\n".join(lines)
        text = f"""```python\n{code_content}\n```"""
        chunker = MarkdownChunker(max_chunk_size=500, preserve_code_blocks=True)
        chunks = chunker.chunk(text, "test")
        code_chunks = [c for c in chunks if c.element_type == MarkdownElementType.CODE_BLOCK]
        # 100 lines exceed max_chunk_size=500, forcing a split.
        assert len(code_chunks) > 1
        for chunk in code_chunks:
            # Every fragment keeps the language and is re-fenced so it is
            # still valid standalone markdown.
            assert chunk.language == "python"
            assert "```python" in chunk.content
            assert "```" in chunk.content
    def test_chunk_table_preserved(self):
        """Test that tables are preserved."""
        text = """| Name | Age |
|------|-----|
| Alice | 30 |
| Bob | 25 |
"""
        chunker = MarkdownChunker(max_chunk_size=2000, preserve_tables=True)
        chunks = chunker.chunk(text, "test")
        table_chunks = [c for c in chunks if c.element_type == MarkdownElementType.TABLE]
        assert len(table_chunks) == 1
        assert "Alice" in table_chunks[0].content
        assert "Bob" in table_chunks[0].content
    def test_chunk_large_table_split(self):
        """Test that large tables are split with header preserved."""
        rows = [f"| Name{i} | {i * 10} |" for i in range(50)]
        table_content = "| Name | Age |\n|------|-----|\n" + "\n".join(rows)
        text = table_content
        chunker = MarkdownChunker(max_chunk_size=200, preserve_tables=True)
        chunks = chunker.chunk(text, "test")
        table_chunks = [c for c in chunks if c.element_type == MarkdownElementType.TABLE]
        assert len(table_chunks) > 1
        for chunk in table_chunks:
            # Each fragment repeats the header and separator rows so it
            # remains a well-formed table on its own.
            assert "| Name | Age |" in chunk.content
            assert "|------|-----|" in chunk.content
    def test_chunk_list_preserved(self):
        """Test that lists are chunked properly."""
        text = """- Item 1
- Item 2
- Item 3
- Item 4
- Item 5
"""
        chunker = MarkdownChunker(max_chunk_size=2000, preserve_lists=True)
        chunks = chunker.chunk(text, "test")
        list_chunks = [c for c in chunks if c.element_type == MarkdownElementType.LIST]
        assert len(list_chunks) == 1
        assert "Item 1" in list_chunks[0].content
        assert "Item 5" in list_chunks[0].content
    def test_chunk_empty_document(self):
        """Test chunking an empty document."""
        text = ""
        chunker = MarkdownChunker()
        chunks = chunker.chunk(text, "test")
        # Nothing to chunk -> empty result, not an error.
        assert len(chunks) == 0
    def test_chunk_only_headers(self):
        """Test chunking a document with only headers."""
        text = """# Title 1
## Title 2
### Title 3
"""
        chunker = MarkdownChunker()
        chunks = chunker.chunk(text, "test")
        # Headers alone carry no body content, so no chunks are emitted.
        assert len(chunks) == 0
class TestChunkMarkdownFunction:
    """The module-level chunk_markdown convenience wrapper."""

    def test_basic_chunking(self):
        """Default chunking returns dicts carrying the expected keys."""
        text = """# Title
Content paragraph.
```python
code = "here"
```
"""
        chunks = chunk_markdown(text, "doc1")
        assert len(chunks) >= 1
        required_keys = ("chunk_id", "content", "element_type", "header_context")
        for chunk in chunks:
            for key in required_keys:
                assert key in chunk

    def test_custom_parameters(self):
        """All tuning knobs are accepted and chunking still produces output."""
        payload = "A" * 2000
        chunks = chunk_markdown(
            payload,
            "doc1",
            max_chunk_size=500,
            min_chunk_size=50,
            preserve_code_blocks=False,
            preserve_tables=False,
            preserve_lists=False,
            include_header_context=False,
        )
        assert len(chunks) >= 1
class TestMarkdownElement:
    """Serialization behaviour of the MarkdownElement dataclass."""

    def test_to_dict(self):
        """A header element serializes its type, content, level and line span."""
        element = MarkdownElement(
            type=MarkdownElementType.HEADER,
            content="Test Header",
            level=2,
            line_start=10,
            line_end=10,
            metadata={"level": 2},
        )
        dumped = element.to_dict()
        expected = {
            "type": "header",
            "content": "Test Header",
            "level": 2,
            "line_start": 10,
            "line_end": 10,
        }
        for key, value in expected.items():
            assert dumped[key] == value

    def test_code_block_with_language(self):
        """A code-block element carries its language through serialization."""
        element = MarkdownElement(
            type=MarkdownElementType.CODE_BLOCK,
            content="print('hello')",
            language="python",
            line_start=5,
            line_end=7,
        )
        dumped = element.to_dict()
        assert dumped["type"] == "code_block"
        assert dumped["language"] == "python"

    def test_table_with_metadata(self):
        """A table element preserves its header/row-count metadata."""
        element = MarkdownElement(
            type=MarkdownElementType.TABLE,
            content="| A | B |\n|---|---|\n| 1 | 2 |",
            line_start=1,
            line_end=3,
            metadata={"headers": ["A", "B"], "row_count": 1},
        )
        dumped = element.to_dict()
        assert dumped["type"] == "table"
        assert dumped["metadata"]["headers"] == ["A", "B"]
        assert dumped["metadata"]["row_count"] == 1
class TestMarkdownChunk:
    """Serialization behaviour of the MarkdownChunk dataclass."""

    def test_to_dict(self):
        """All chunk fields round-trip through to_dict()."""
        chunk = MarkdownChunk(
            chunk_id="doc_chunk_0",
            content="Test content",
            element_type=MarkdownElementType.PARAGRAPH,
            header_context=["Main Title", "Section"],
            metadata={"key": "value"},
        )
        dumped = chunk.to_dict()
        expected = {
            "chunk_id": "doc_chunk_0",
            "content": "Test content",
            "element_type": "paragraph",
            "header_context": ["Main Title", "Section"],
        }
        for key, value in expected.items():
            assert dumped[key] == value
        assert dumped["metadata"]["key"] == "value"

    def test_with_language(self):
        """Code chunks expose their language in the serialized form."""
        chunk = MarkdownChunk(
            chunk_id="code_0",
            content="```python\nprint('hi')\n```",
            element_type=MarkdownElementType.CODE_BLOCK,
            header_context=[],
            language="python",
        )
        assert chunk.to_dict()["language"] == "python"
class TestMarkdownElementType:
    """Tests for the MarkdownElementType enum."""

    def test_all_types_exist(self):
        """Every expected element type is reachable by attribute or by value."""
        expected = (
            "header",
            "paragraph",
            "code_block",
            "inline_code",
            "table",
            "list",
            "blockquote",
            "horizontal_rule",
            "image",
            "link",
            "text",
        )
        # Membership by value, computed once instead of per-name scans.
        enum_values = {member.value for member in MarkdownElementType}
        for name in expected:
            assert hasattr(MarkdownElementType, name.upper()) or name in enum_values

View File

@ -0,0 +1,443 @@
"""
Unit tests for MetadataAutoInferenceService.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import dataclass
from app.services.metadata_auto_inference_service import (
AutoInferenceResult,
InferenceFieldContext,
MetadataAutoInferenceService,
)
@dataclass
class MockFieldDefinition:
    """Mock stand-in for a metadata field definition.

    Mirrors only the attributes that MetadataAutoInferenceService reads
    when building inference field contexts (see test_build_field_contexts).
    """
    field_key: str  # machine-readable key, e.g. "grade"
    label: str  # human-readable display name, e.g. "年级"
    type: str  # field type used in tests: "text" | "number" | "boolean" | "enum" | "array_enum"
    required: bool  # whether the field is mandatory
    options: list[str] | None = None  # allowed values for enum-like types
class TestInferenceFieldContext:
    """Tests for the InferenceFieldContext dataclass."""

    def test_creation(self):
        """Explicitly supplied attributes are stored unchanged."""
        context = InferenceFieldContext(
            field_key="grade",
            label="年级",
            type="enum",
            required=True,
            options=["初一", "初二", "初三"],
        )
        assert context.field_key == "grade"
        assert context.label == "年级"
        assert context.type == "enum"
        assert context.required is True
        assert context.options == ["初一", "初二", "初三"]

    def test_default_values(self):
        """Optional attributes default to None when omitted."""
        context = InferenceFieldContext(
            field_key="test",
            label="Test",
            type="text",
            required=False,
        )
        assert context.options is None
        assert context.description is None
class TestAutoInferenceResult:
    """Tests for the AutoInferenceResult dataclass."""

    def test_success_result(self):
        """A successful result carries metadata and scores with no error."""
        outcome = AutoInferenceResult(
            inferred_metadata={"grade": "初一", "subject": "数学"},
            confidence_scores={"grade": 0.95, "subject": 0.85},
            raw_response='{"inferred_metadata": {...}}',
            success=True,
        )
        assert outcome.success is True
        assert outcome.error_message is None
        assert outcome.inferred_metadata["grade"] == "初一"

    def test_failure_result(self):
        """A failed result keeps empty payloads and exposes its error message."""
        outcome = AutoInferenceResult(
            inferred_metadata={},
            confidence_scores={},
            raw_response="",
            success=False,
            error_message="JSON parse error",
        )
        assert outcome.success is False
        assert outcome.error_message == "JSON parse error"
class TestMetadataAutoInferenceService:
    """Test MetadataAutoInferenceService functionality.

    Exercises the service's private helpers (prompt building, JSON
    extraction, response parsing, value validation) directly; no real
    LLM call is made — the DB session is an AsyncMock.
    """

    @pytest.fixture
    def mock_session(self):
        """Create mock session."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_session):
        """Create service instance."""
        return MetadataAutoInferenceService(mock_session)

    def test_build_field_contexts(self, service):
        """Test building field contexts from definitions."""
        fields = [
            MockFieldDefinition(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            MockFieldDefinition(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]
        contexts = service._build_field_contexts(fields)
        # Order and per-field attributes are preserved one-to-one.
        assert len(contexts) == 2
        assert contexts[0].field_key == "grade"
        assert contexts[0].options == ["初一", "初二", "初三"]
        assert contexts[1].field_key == "subject"

    def test_build_user_prompt(self, service):
        """Test building user prompt."""
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]
        prompt = service._build_user_prompt(
            content="这是一道初一数学题",
            field_contexts=field_contexts,
        )
        # The prompt must surface the field label, its options, and the content.
        assert "年级" in prompt
        assert "初一, 初二, 初三" in prompt
        assert "这是一道初一数学题" in prompt

    def test_build_user_prompt_with_existing_metadata(self, service):
        """Test building user prompt with existing metadata."""
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]
        prompt = service._build_user_prompt(
            content="这是一道数学题",
            field_contexts=field_contexts,
            existing_metadata={"grade": "初二"},
        )
        # Pre-existing values are echoed into the prompt.
        assert "已有值: 初二" in prompt

    def test_extract_json_from_plain_json(self, service):
        """Test extracting JSON from plain JSON response."""
        json_str = '{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}'
        result = service._extract_json(json_str)
        # Plain JSON should pass through unchanged.
        assert result == json_str

    def test_extract_json_from_markdown(self, service):
        """Test extracting JSON from markdown code block."""
        markdown = """Here is the result:
```json
{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}
```
"""
        result = service._extract_json(markdown)
        assert "inferred_metadata" in result
        assert "grade" in result

    def test_parse_llm_response_valid(self, service):
        """Test parsing valid LLM response."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "初一",
                "subject": "数学",
            },
            "confidence_scores": {
                "grade": 0.95,
                "subject": 0.85,
            }
        })
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]
        result = service._parse_llm_response(response, field_contexts)
        assert result.success is True
        assert result.inferred_metadata["grade"] == "初一"
        assert result.inferred_metadata["subject"] == "数学"
        assert result.confidence_scores["grade"] == 0.95

    def test_parse_llm_response_invalid_option(self, service):
        """Test parsing response with invalid enum option.

        An out-of-vocabulary enum value is silently dropped, not treated
        as a parse failure.
        """
        response = json.dumps({
            "inferred_metadata": {
                "grade": "高一",  # Not in options
            },
            "confidence_scores": {
                "grade": 0.90,
            }
        })
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]
        result = service._parse_llm_response(response, field_contexts)
        assert result.success is True
        assert "grade" not in result.inferred_metadata

    def test_parse_llm_response_invalid_json(self, service):
        """Test parsing invalid JSON response."""
        response = "This is not valid JSON"
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="text",
                required=False,
            ),
        ]
        result = service._parse_llm_response(response, field_contexts)
        assert result.success is False
        assert "JSON parse error" in result.error_message

    def test_validate_field_value_text(self, service):
        """Test validating text field value."""
        ctx = InferenceFieldContext(
            field_key="title",
            label="标题",
            type="text",
            required=False,
        )
        result = service._validate_field_value(ctx, "测试标题")
        assert result == "测试标题"

    def test_validate_field_value_number(self, service):
        """Test validating number field value."""
        ctx = InferenceFieldContext(
            field_key="count",
            label="数量",
            type="number",
            required=False,
        )
        # Numbers pass through, numeric strings are coerced, junk becomes None.
        assert service._validate_field_value(ctx, 42) == 42
        assert service._validate_field_value(ctx, "3.14") == 3.14
        assert service._validate_field_value(ctx, "invalid") is None

    def test_validate_field_value_boolean(self, service):
        """Test validating boolean field value."""
        ctx = InferenceFieldContext(
            field_key="active",
            label="是否激活",
            type="boolean",
            required=False,
        )
        # Accepts native bools, "true"/"false" strings, and truthy ints.
        assert service._validate_field_value(ctx, True) is True
        assert service._validate_field_value(ctx, "true") is True
        assert service._validate_field_value(ctx, "false") is False
        assert service._validate_field_value(ctx, 1) is True

    def test_validate_field_value_enum(self, service):
        """Test validating enum field value."""
        ctx = InferenceFieldContext(
            field_key="grade",
            label="年级",
            type="enum",
            required=False,
            options=["初一", "初二", "初三"],
        )
        assert service._validate_field_value(ctx, "初一") == "初一"
        assert service._validate_field_value(ctx, "高一") is None

    def test_validate_field_value_array_enum(self, service):
        """Test validating array_enum field value.

        Invalid members are filtered out; a bare scalar is wrapped in a list.
        """
        ctx = InferenceFieldContext(
            field_key="tags",
            label="标签",
            type="array_enum",
            required=False,
            options=["重点", "难点", "易错"],
        )
        result = service._validate_field_value(ctx, ["重点", "难点"])
        assert result == ["重点", "难点"]
        result = service._validate_field_value(ctx, ["重点", "不存在"])
        assert result == ["重点"]
        result = service._validate_field_value(ctx, "重点")
        assert result == ["重点"]
class TestIntegrationScenarios:
    """Test integration scenarios.

    End-to-end parsing scenarios over _parse_llm_response using realistic
    multi-field education payloads (no real LLM involved).
    """

    @pytest.fixture
    def mock_session(self):
        """Create mock session."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_session):
        """Create service instance."""
        return MetadataAutoInferenceService(mock_session)

    def test_education_scenario(self, service):
        """Test education scenario with grade and subject."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "初二",
                "subject": "物理",
                "type": "痛点",
            },
            "confidence_scores": {
                "grade": 0.95,
                "subject": 0.90,
                "type": 0.85,
            }
        })
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语", "物理", "化学"],
            ),
            InferenceFieldContext(
                field_key="type",
                label="类型",
                type="enum",
                required=False,
                options=["痛点", "重点", "难点"],
            ),
        ]
        result = service._parse_llm_response(response, field_contexts)
        assert result.success is True
        # All three fields are valid options, so the full dict comes through.
        assert result.inferred_metadata == {
            "grade": "初二",
            "subject": "物理",
            "type": "痛点",
        }
        assert result.confidence_scores["grade"] == 0.95

    def test_partial_inference(self, service):
        """Test partial inference when some fields cannot be inferred.

        Fields missing from the LLM payload are simply absent from the
        result; their absence does not mark the inference as failed.
        """
        response = json.dumps({
            "inferred_metadata": {
                "grade": "初一",
            },
            "confidence_scores": {
                "grade": 0.90,
            }
        })
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]
        result = service._parse_llm_response(response, field_contexts)
        assert result.success is True
        assert "grade" in result.inferred_metadata
        assert "subject" not in result.inferred_metadata
if __name__ == "__main__":
    # Allow running this test module directly with verbose pytest output.
    pytest.main([__file__, "-v"])

View File

@ -0,0 +1,881 @@
"""
Mid Platform Dialogue Integration Test.
中台联调界面对话过程集成测试脚本
测试重点:
1. 意图置信度参数
2. 执行模式通用API vs ReAct模式
3. ReAct模式下的工具调用工具名称入参返回结果
4. 知识库查询是否命中入参
5. 各部分耗时
6. 提示词模板使用情况
"""
import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient
from app.main import app
from app.models.mid.schemas import (
DialogueRequest,
DialogueResponse,
ExecutionMode,
FeatureFlags,
HistoryMessage,
Segment,
TraceInfo,
ToolCallTrace,
ToolCallStatus,
ToolType,
IntentHintOutput,
HighRiskCheckResult,
HighRiskScenario,
)
logger = logging.getLogger(__name__)
@dataclass
class TimingRecord:
    """Wall-clock timing captured for a single pipeline stage."""

    stage: str  # stage name, e.g. "http_request"
    start_time: float  # epoch seconds at stage start
    end_time: float  # epoch seconds at stage end
    duration_ms: int  # elapsed time in milliseconds

    def to_dict(self) -> dict:
        """Serialize only the fields shown in a timing summary."""
        return {"stage": self.stage, "duration_ms": self.duration_ms}
@dataclass
class DialogueTestResult:
    """Aggregated observations from one dialogue round-trip."""
    request_id: str
    user_message: str
    response_text: str = ""
    # --- timing ---
    timing_records: list[TimingRecord] = field(default_factory=list)
    total_duration_ms: int = 0
    # --- routing / intent ---
    execution_mode: ExecutionMode = ExecutionMode.AGENT
    intent: str | None = None
    # NOTE(review): never populated by send_dialogue, only printed — confirm
    # whether TraceInfo exposes a confidence value to copy from.
    confidence: float | None = None
    # --- ReAct / tool usage ---
    react_iterations: int = 0
    tools_used: list[str] = field(default_factory=list)
    tool_calls: list[dict] = field(default_factory=list)
    # --- knowledge-base search ---
    kb_tool_called: bool = False
    kb_hit: bool = False
    kb_query: str | None = None
    kb_filter: dict | None = None
    kb_hits_count: int = 0
    # --- prompt template ---
    # NOTE(review): prompt_template_used is never assigned in this module.
    prompt_template_used: str | None = None
    prompt_template_scene: str | None = None
    # --- guardrails / fallback ---
    guardrail_triggered: bool = False
    fallback_reason_code: str | None = None
    raw_trace: dict | None = None

    def to_summary(self) -> dict:
        """Compact, JSON-friendly summary (message truncated to 100 chars)."""
        return {
            "request_id": self.request_id,
            "user_message": self.user_message[:100],
            "execution_mode": self.execution_mode.value,
            "intent": self.intent,
            "confidence": self.confidence,
            "total_duration_ms": self.total_duration_ms,
            "timing_breakdown": [t.to_dict() for t in self.timing_records],
            "react_iterations": self.react_iterations,
            "tools_used": self.tools_used,
            "tool_calls_count": len(self.tool_calls),
            "kb_tool_called": self.kb_tool_called,
            "kb_hit": self.kb_hit,
            "kb_hits_count": self.kb_hits_count,
            "prompt_template_used": self.prompt_template_used,
            "guardrail_triggered": self.guardrail_triggered,
            "fallback_reason_code": self.fallback_reason_code,
        }
class DialogueIntegrationTester:
    """Driver for dialogue integration tests against a running service."""

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        tenant_id: str = "test_tenant",
        api_key: str | None = None,
    ):
        self.base_url = base_url
        self.tenant_id = tenant_id
        self.api_key = api_key
        # Fresh session/user ids per tester instance so runs don't share state.
        self.session_id = f"test_session_{uuid.uuid4().hex[:8]}"
        self.user_id = f"test_user_{uuid.uuid4().hex[:8]}"

    def _get_headers(self) -> dict:
        """Build request headers; the API-key header is sent only when configured."""
        headers = {
            "Content-Type": "application/json",
            "X-Tenant-Id": self.tenant_id,
        }
        if self.api_key:
            headers["X-API-Key"] = self.api_key
        return headers

    async def send_dialogue(
        self,
        user_message: str,
        history: list[dict] | None = None,
        scene: str | None = None,
        feature_flags: dict | None = None,
    ) -> DialogueTestResult:
        """Send one dialogue request and collect detailed trace/timing info.

        Never raises: HTTP errors, transport exceptions, and response-parsing
        failures are all folded into the returned DialogueTestResult.
        """
        request_id = str(uuid.uuid4())
        result = DialogueTestResult(
            request_id=request_id,
            user_message=user_message,
            response_text="",
        )
        overall_start = time.time()
        request_body = {
            "session_id": self.session_id,
            "user_id": self.user_id,
            "user_message": user_message,
            "history": history or [],
        }
        # Optional fields are included only when provided.
        if scene:
            request_body["scene"] = scene
        if feature_flags:
            request_body["feature_flags"] = feature_flags
        timing_records = []
        try:
            async with AsyncClient(base_url=self.base_url, timeout=120.0) as client:
                request_start = time.time()
                response = await client.post(
                    "/mid/dialogue/respond",
                    json=request_body,
                    headers=self._get_headers(),
                )
                request_end = time.time()
                timing_records.append(TimingRecord(
                    stage="http_request",
                    start_time=request_start,
                    end_time=request_end,
                    duration_ms=int((request_end - request_start) * 1000),
                ))
                if response.status_code != 200:
                    # Non-200: report the status as the fallback reason and bail.
                    result.response_text = f"Error: {response.status_code}"
                    result.fallback_reason_code = f"http_error_{response.status_code}"
                    return result
                response_data = response.json()
        except Exception as e:
            # Transport-level failure (connect/timeout/etc.).
            result.response_text = f"Exception: {str(e)}"
            result.fallback_reason_code = "request_exception"
            result.total_duration_ms = int((time.time() - overall_start) * 1000)
            return result
        overall_end = time.time()
        result.total_duration_ms = int((overall_end - overall_start) * 1000)
        # Same list object: later appends below remain visible on the result.
        result.timing_records = timing_records
        try:
            dialogue_response = DialogueResponse(**response_data)
            if dialogue_response.segments:
                result.response_text = "\n".join(s.text for s in dialogue_response.segments)
            trace = dialogue_response.trace
            result.raw_trace = trace.model_dump() if trace else None
            if trace:
                result.execution_mode = trace.mode
                result.intent = trace.intent
                # NOTE(review): result.confidence is never copied from the trace
                # here although print_result displays it — confirm intent.
                result.react_iterations = trace.react_iterations or 0
                result.tools_used = trace.tools_used or []
                result.kb_tool_called = trace.kb_tool_called or False
                result.kb_hit = trace.kb_hit or False
                result.guardrail_triggered = trace.guardrail_triggered or False
                result.fallback_reason_code = trace.fallback_reason_code
                if trace.tool_calls:
                    result.tool_calls = [tc.model_dump() for tc in trace.tool_calls]
                    # KB details come from the kb_search_dynamic tool call itself,
                    # which may also overwrite the coarse trace-level kb flags.
                    for tc in trace.tool_calls:
                        if tc.tool_name == "kb_search_dynamic":
                            result.kb_tool_called = True
                            if tc.arguments:
                                result.kb_query = tc.arguments.get("query")
                                result.kb_filter = tc.arguments.get("context")
                            if tc.result and isinstance(tc.result, dict):
                                result.kb_hits_count = len(tc.result.get("hits", []))
                                result.kb_hit = result.kb_hits_count > 0
                if trace.duration_ms:
                    timing_records.append(TimingRecord(
                        stage="server_processing",
                        start_time=overall_start,
                        end_time=overall_end,
                        duration_ms=trace.duration_ms,
                    ))
                if trace.scene:
                    result.prompt_template_scene = trace.scene
        except Exception as e:
            # Parsing failed: keep the raw payload so the test output is useful.
            logger.error(f"Failed to parse response: {e}")
            result.response_text = str(response_data)
        return result

    def print_result(self, result: DialogueTestResult):
        """Pretty-print a DialogueTestResult to stdout for manual inspection."""
        print("\n" + "=" * 80)
        print(f"[对话测试结果] Request ID: {result.request_id}")
        print("=" * 80)
        print(f"\n[用户消息] {result.user_message}")
        print(f"[回复内容] {result.response_text[:200]}...")
        print(f"\n[执行模式] {result.execution_mode.value}")
        print(f"[意图识别] intent={result.intent}, confidence={result.confidence}")
        print(f"\n[耗时统计] 总耗时: {result.total_duration_ms}ms")
        for tr in result.timing_records:
            print(f" - {tr.stage}: {tr.duration_ms}ms")
        if result.execution_mode == ExecutionMode.AGENT:
            print(f"\n[ReAct模式]")
            print(f" - 迭代次数: {result.react_iterations}")
            print(f" - 使用的工具: {result.tools_used}")
        if result.tool_calls:
            print(f"\n[工具调用详情]")
            for i, tc in enumerate(result.tool_calls, 1):
                print(f" [{i}] 工具: {tc.get('tool_name')}")
                print(f" 状态: {tc.get('status')}")
                print(f" 耗时: {tc.get('duration_ms')}ms")
                if tc.get('arguments'):
                    print(f" 入参: {json.dumps(tc.get('arguments'), ensure_ascii=False)[:200]}")
                if tc.get('result'):
                    result_str = str(tc.get('result'))[:300]
                    print(f" 结果: {result_str}")
        print(f"\n[知识库查询]")
        print(f" - 是否调用: {result.kb_tool_called}")
        print(f" - 是否命中: {result.kb_hit}")
        if result.kb_query:
            print(f" - 查询内容: {result.kb_query}")
        if result.kb_filter:
            print(f" - 过滤条件: {json.dumps(result.kb_filter, ensure_ascii=False)[:200]}")
        print(f" - 命中数量: {result.kb_hits_count}")
        print(f"\n[提示词模板]")
        print(f" - 场景: {result.prompt_template_scene}")
        print(f" - 使用模板: {result.prompt_template_used or '默认模板'}")
        print(f"\n[其他信息]")
        print(f" - 护栏触发: {result.guardrail_triggered}")
        print(f" - 降级原因: {result.fallback_reason_code or ''}")
        print("\n" + "=" * 80)
class TestMidDialogueIntegration:
    """Live integration tests for the mid-platform dialogue endpoint.

    NOTE: these tests require a service listening on http://localhost:8000;
    they only assert that a result object comes back (request errors are
    folded into the result, so they do not raise).
    """

    @pytest.fixture
    def tester(self):
        # One tester per test function (fresh session/user ids).
        return DialogueIntegrationTester(
            base_url="http://localhost:8000",
            tenant_id="test_tenant",
        )

    @pytest.fixture
    def mock_llm_client(self):
        """Mock LLM client returning a fixed reply with no tool calls."""
        mock = AsyncMock()
        mock.generate = AsyncMock(return_value=MagicMock(
            content="这是测试回复",
            has_tool_calls=False,
            tool_calls=[],
        ))
        return mock

    @pytest.fixture
    def mock_kb_tool(self):
        """Mock knowledge-base tool returning a single high-score hit."""
        mock = AsyncMock()
        mock.execute = AsyncMock(return_value=MagicMock(
            success=True,
            hits=[
                {"id": "1", "content": "测试知识库内容", "score": 0.9},
            ],
            applied_filter={"scene": "test"},
            missing_required_slots=[],
            fallback_reason_code=None,
            duration_ms=100,
            tool_trace=None,
        ))
        return mock

    @pytest.mark.asyncio
    async def test_simple_greeting(self, tester: DialogueIntegrationTester):
        """A plain greeting produces a result with timing recorded."""
        result = await tester.send_dialogue(
            user_message="你好",
        )
        tester.print_result(result)
        assert result.request_id is not None
        assert result.total_duration_ms > 0

    @pytest.mark.asyncio
    async def test_kb_query(self, tester: DialogueIntegrationTester):
        """A knowledge question in an explicit scene still yields a result."""
        result = await tester.send_dialogue(
            user_message="退款流程是什么?",
            scene="after_sale",
        )
        tester.print_result(result)
        assert result.request_id is not None

    @pytest.mark.asyncio
    async def test_high_risk_scenario(self, tester: DialogueIntegrationTester):
        """A complaint message exercises the high-risk path."""
        result = await tester.send_dialogue(
            user_message="我要投诉你们的服务",
        )
        tester.print_result(result)
        assert result.request_id is not None

    @pytest.mark.asyncio
    async def test_transfer_request(self, tester: DialogueIntegrationTester):
        """An explicit human-handoff request exercises the transfer path."""
        result = await tester.send_dialogue(
            user_message="帮我转人工客服",
        )
        tester.print_result(result)
        assert result.request_id is not None

    @pytest.mark.asyncio
    async def test_with_history(self, tester: DialogueIntegrationTester):
        """A follow-up question is sent together with prior conversation turns."""
        result = await tester.send_dialogue(
            user_message="那退款要多久呢?",
            history=[
                {"role": "user", "content": "我想退款"},
                {"role": "assistant", "content": "好的,请问您要退款的订单号是多少?"},
            ],
        )
        tester.print_result(result)
        assert result.request_id is not None
class TestDialogueWithMock:
    """Schema-level dialogue tests that use mocks instead of a live service."""

    @pytest.fixture
    def mock_app(self):
        """Build a minimal FastAPI app mounting only the dialogue router."""
        from fastapi import FastAPI
        from app.api.mid.dialogue import router
        app = FastAPI()
        app.include_router(router)
        return app

    @pytest.fixture
    def client(self, mock_app):
        return TestClient(mock_app)

    @pytest.fixture
    def mock_session(self):
        """Mock DB session whose queries return an empty scalar list."""
        mock = AsyncMock()
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = []
        mock.execute.return_value = mock_result
        return mock

    @pytest.fixture
    def mock_llm(self):
        """Mock LLM response object with fixed content and no tool calls."""
        mock_response = MagicMock()
        mock_response.content = "这是测试回复内容"
        mock_response.has_tool_calls = False
        mock_response.tool_calls = []
        return mock_response

    def test_dialogue_request_structure(self, client: TestClient):
        """Sanity-check the shape of a dialogue request body.

        NOTE(review): this test never issues a request through `client`; it
        only asserts on the literal dict, so the assertions are tautological.
        """
        request_body = {
            "session_id": "test_session_001",
            "user_id": "test_user_001",
            "user_message": "你好",
            "history": [],
            "scene": "open_consult",
        }
        print("\n[测试请求结构]")
        print(f"Request Body: {json.dumps(request_body, ensure_ascii=False, indent=2)}")
        assert request_body["session_id"] == "test_session_001"
        assert request_body["user_message"] == "你好"

    def test_trace_info_structure(self):
        """TraceInfo can be built with a nested kb_search tool-call trace."""
        trace = TraceInfo(
            mode=ExecutionMode.AGENT,
            intent="greeting",
            request_id=str(uuid.uuid4()),
            generation_id=str(uuid.uuid4()),
            kb_tool_called=True,
            kb_hit=True,
            react_iterations=2,
            tools_used=["kb_search_dynamic"],
            tool_calls=[
                ToolCallTrace(
                    tool_name="kb_search_dynamic",
                    tool_type=ToolType.INTERNAL,
                    duration_ms=150,
                    status=ToolCallStatus.OK,
                    arguments={"query": "测试查询", "scene": "test"},
                    result={"hits": [{"content": "测试内容"}]},
                ),
            ],
        )
        print("\n[TraceInfo 结构测试]")
        print(f"Mode: {trace.mode.value}")
        print(f"Intent: {trace.intent}")
        print(f"KB Tool Called: {trace.kb_tool_called}")
        print(f"KB Hit: {trace.kb_hit}")
        print(f"React Iterations: {trace.react_iterations}")
        print(f"Tools Used: {trace.tools_used}")
        if trace.tool_calls:
            print(f"\n[Tool Calls]")
            for tc in trace.tool_calls:
                print(f" - Tool: {tc.tool_name}")
                print(f" Status: {tc.status.value}")
                print(f" Duration: {tc.duration_ms}ms")
                if tc.arguments:
                    print(f" Arguments: {json.dumps(tc.arguments, ensure_ascii=False)}")
        assert trace.mode == ExecutionMode.AGENT
        assert trace.kb_tool_called is True
        assert len(trace.tool_calls) == 1

    def test_intent_hint_output_structure(self):
        """IntentHintOutput stores intent, confidence, and a suggested mode."""
        hint = IntentHintOutput(
            intent="refund",
            confidence=0.85,
            response_type="flow",
            suggested_mode=ExecutionMode.MICRO_FLOW,
            target_flow_id="flow_refund_001",
            high_risk_detected=False,
            duration_ms=50,
        )
        print("\n[IntentHintOutput 结构测试]")
        print(f"Intent: {hint.intent}")
        print(f"Confidence: {hint.confidence}")
        print(f"Response Type: {hint.response_type}")
        print(f"Suggested Mode: {hint.suggested_mode.value if hint.suggested_mode else None}")
        print(f"High Risk Detected: {hint.high_risk_detected}")
        print(f"Duration: {hint.duration_ms}ms")
        assert hint.intent == "refund"
        assert hint.confidence == 0.85
        assert hint.suggested_mode == ExecutionMode.MICRO_FLOW

    def test_high_risk_check_result_structure(self):
        """HighRiskCheckResult stores the matched scenario and rule metadata."""
        result = HighRiskCheckResult(
            matched=True,
            risk_scenario=HighRiskScenario.REFUND,
            confidence=0.95,
            recommended_mode=ExecutionMode.MICRO_FLOW,
            rule_id="rule_refund_001",
            reason="检测到退款关键词",
            duration_ms=30,
        )
        print("\n[HighRiskCheckResult 结构测试]")
        print(f"Matched: {result.matched}")
        print(f"Risk Scenario: {result.risk_scenario.value if result.risk_scenario else None}")
        print(f"Confidence: {result.confidence}")
        print(f"Recommended Mode: {result.recommended_mode.value if result.recommended_mode else None}")
        print(f"Rule ID: {result.rule_id}")
        print(f"Duration: {result.duration_ms}ms")
        assert result.matched is True
        assert result.risk_scenario == HighRiskScenario.REFUND
class TestPromptTemplateUsage:
    """Tests for prompt-template variable resolution."""

    def test_template_resolution(self):
        """{{key}} placeholders are replaced from a variable list."""
        # Imported locally to avoid pulling the prompt service at module import.
        from app.services.prompt.variable_resolver import VariableResolver
        resolver = VariableResolver()
        template = "你好,{{user_name}}!我是{{bot_name}},很高兴为您服务。"
        variables = [
            {"key": "user_name", "value": "张三"},
            {"key": "bot_name", "value": "智能客服"},
        ]
        resolved = resolver.resolve(template, variables)
        print("\n[模板解析测试]")
        print(f"原始模板: {template}")
        print(f"变量: {json.dumps(variables, ensure_ascii=False)}")
        print(f"解析结果: {resolved}")
        assert resolved == "你好,张三!我是智能客服,很高兴为您服务。"

    def test_template_with_extra_context(self):
        """Placeholders can also be satisfied from an extra-context dict."""
        from app.services.prompt.variable_resolver import VariableResolver
        resolver = VariableResolver()
        template = "当前场景:{{scene}},用户问题:{{query}}"
        extra_context = {
            "scene": "售后服务",
            "query": "退款流程",
        }
        # Empty variable list: everything resolves from extra_context.
        resolved = resolver.resolve(template, [], extra_context)
        print("\n[带上下文的模板解析测试]")
        print(f"原始模板: {template}")
        print(f"额外上下文: {json.dumps(extra_context, ensure_ascii=False)}")
        print(f"解析结果: {resolved}")
        assert "售后服务" in resolved
        assert "退款流程" in resolved
class TestToolCallRecording:
    """Tests for ToolCallTrace construction (success and timeout cases)."""

    def test_tool_call_trace_creation(self):
        """A successful kb_search trace stores arguments, result, and digests."""
        trace = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=150,
            status=ToolCallStatus.OK,
            args_digest="query=退款流程",
            result_digest="hits=3",
            arguments={
                "query": "退款流程是什么",
                "scene": "after_sale",
                "context": {"product_type": "vip"},
            },
            result={
                "success": True,
                "hits": [
                    {"id": "1", "content": "退款流程说明...", "score": 0.95},
                    {"id": "2", "content": "退款注意事项...", "score": 0.88},
                    {"id": "3", "content": "退款时效说明...", "score": 0.82},
                ],
                "applied_filter": {"product_type": "vip"},
            },
        )
        print("\n[工具调用追踪测试]")
        print(f"Tool Name: {trace.tool_name}")
        print(f"Tool Type: {trace.tool_type.value}")
        print(f"Status: {trace.status.value}")
        print(f"Duration: {trace.duration_ms}ms")
        print(f"\n[入参详情]")
        if trace.arguments:
            for key, value in trace.arguments.items():
                print(f" - {key}: {value}")
        print(f"\n[返回结果]")
        if trace.result:
            if isinstance(trace.result, dict):
                print(f" - success: {trace.result.get('success')}")
                print(f" - hits count: {len(trace.result.get('hits', []))}")
                # Only the top two hits are echoed to keep output short.
                for i, hit in enumerate(trace.result.get('hits', [])[:2], 1):
                    print(f" - hit[{i}]: score={hit.get('score')}, content={hit.get('content')[:30]}...")
        assert trace.tool_name == "kb_search_dynamic"
        assert trace.status == ToolCallStatus.OK
        assert trace.arguments is not None
        assert trace.result is not None

    def test_tool_call_timeout_trace(self):
        """A timed-out trace carries TIMEOUT status and an error code, no result."""
        trace = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=2000,
            status=ToolCallStatus.TIMEOUT,
            error_code="TOOL_TIMEOUT",
            arguments={"query": "测试查询"},
        )
        print("\n[工具调用超时追踪测试]")
        print(f"Tool Name: {trace.tool_name}")
        print(f"Status: {trace.status.value}")
        print(f"Error Code: {trace.error_code}")
        print(f"Duration: {trace.duration_ms}ms")
        assert trace.status == ToolCallStatus.TIMEOUT
        assert trace.error_code == "TOOL_TIMEOUT"
class TestExecutionModeRouting:
    """Table-driven tests for PolicyRouter execution-mode decisions."""

    def test_policy_router_decision(self):
        """Each (message, session_mode) pair routes to the expected mode."""
        # NOTE(review): IntentMatch is imported but unused here.
        from app.services.mid.policy_router import PolicyRouter, IntentMatch
        router = PolicyRouter()
        test_cases = [
            {
                "name": "正常对话 -> Agent模式",
                "user_message": "你好,请问有什么可以帮助我的?",
                "session_mode": "BOT_ACTIVE",
                "expected_mode": ExecutionMode.AGENT,
            },
            {
                "name": "高风险退款 -> Micro Flow模式",
                "user_message": "我要退款",
                "session_mode": "BOT_ACTIVE",
                "expected_mode": ExecutionMode.MICRO_FLOW,
            },
            {
                "name": "转人工请求 -> Transfer模式",
                "user_message": "帮我转人工",
                "session_mode": "BOT_ACTIVE",
                "expected_mode": ExecutionMode.TRANSFER,
            },
            {
                # A human-active session transfers regardless of message content.
                "name": "人工模式 -> Transfer模式",
                "user_message": "你好",
                "session_mode": "HUMAN_ACTIVE",
                "expected_mode": ExecutionMode.TRANSFER,
            },
        ]
        print("\n[策略路由器决策测试]")
        for tc in test_cases:
            result = router.route(
                user_message=tc["user_message"],
                session_mode=tc["session_mode"],
            )
            print(f"\n测试用例: {tc['name']}")
            print(f" 用户消息: {tc['user_message']}")
            print(f" 会话模式: {tc['session_mode']}")
            print(f" 期望模式: {tc['expected_mode'].value}")
            print(f" 实际模式: {result.mode.value}")
            assert result.mode == tc["expected_mode"], f"模式不匹配: {result.mode} != {tc['expected_mode']}"
class TestKBSearchDynamic:
    """Structure tests for KbSearchDynamicResult (hit and missing-slot cases)."""

    def test_kb_search_result_structure(self):
        """A successful result carries hits, filters, and filter provenance."""
        from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult
        result = KbSearchDynamicResult(
            success=True,
            hits=[
                {
                    "id": "chunk_001",
                    "content": "退款流程1. 登录账户 2. 进入订单页面 3. 点击退款按钮...",
                    "score": 0.92,
                    "metadata": {"kb_id": "kb_001", "doc_id": "doc_001"},
                },
                {
                    "id": "chunk_002",
                    "content": "退款时效一般3-5个工作日到账...",
                    "score": 0.85,
                    "metadata": {"kb_id": "kb_001", "doc_id": "doc_002"},
                },
            ],
            applied_filter={"scene": "after_sale", "product_type": "vip"},
            missing_required_slots=[],
            filter_debug={"source": "slot_state"},
            filter_sources={"scene": "slot", "product_type": "context"},
            duration_ms=120,
        )
        print("\n[知识库检索结果结构测试]")
        print(f"Success: {result.success}")
        print(f"Hits Count: {len(result.hits)}")
        print(f"Applied Filter: {json.dumps(result.applied_filter, ensure_ascii=False)}")
        print(f"Filter Sources: {json.dumps(result.filter_sources, ensure_ascii=False)}")
        print(f"Duration: {result.duration_ms}ms")
        print(f"\n[命中详情]")
        for i, hit in enumerate(result.hits, 1):
            print(f" [{i}] ID: {hit['id']}")
            print(f" Score: {hit['score']}")
            print(f" Content: {hit['content'][:50]}...")
        assert result.success is True
        assert len(result.hits) == 2
        assert result.applied_filter is not None

    def test_kb_search_missing_slots(self):
        """Missing required slots produce a failed result with ask-back prompts."""
        from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult
        result = KbSearchDynamicResult(
            success=False,
            hits=[],
            applied_filter={},
            missing_required_slots=[
                {
                    "field_key": "order_id",
                    "label": "订单号",
                    "reason": "必填字段缺失",
                    "ask_back_prompt": "请提供您的订单号",
                },
            ],
            filter_debug={"source": "builder"},
            fallback_reason_code="MISSING_REQUIRED_SLOTS",
            duration_ms=50,
        )
        print("\n[知识库检索缺失槽位测试]")
        print(f"Success: {result.success}")
        print(f"Fallback Reason: {result.fallback_reason_code}")
        print(f"Missing Slots: {json.dumps(result.missing_required_slots, ensure_ascii=False, indent=2)}")
        assert result.success is False
        assert result.fallback_reason_code == "MISSING_REQUIRED_SLOTS"
        assert len(result.missing_required_slots) == 1
class TestTimingBreakdown:
    """Tests for per-stage timing breakdown reporting."""

    def test_timing_breakdown_structure(self):
        """Per-stage durations sum to the expected pipeline total."""
        stage_timings = [
            TimingRecord("intent_matching", 0, 0.05, 50),
            TimingRecord("high_risk_check", 0.05, 0.08, 30),
            TimingRecord("kb_search", 0.08, 0.2, 120),
            TimingRecord("llm_generation", 0.2, 1.5, 1300),
            TimingRecord("output_guardrail", 1.5, 1.55, 50),
            TimingRecord("response_formatting", 1.55, 1.6, 50),
        ]
        total_ms = sum(record.duration_ms for record in stage_timings)
        print("\n[耗时分解测试]")
        print(f"{'阶段':<25} {'耗时(ms)':<10} {'占比':<10}")
        print("-" * 45)
        for record in stage_timings:
            share = (record.duration_ms / total_ms * 100) if total_ms > 0 else 0
            print(f"{record.stage:<25} {record.duration_ms:<10} {share:.1f}%")
        print("-" * 45)
        print(f"{'总计':<25} {total_ms:<10} {'100.0%':<10}")
        assert total_ms == 1600
def run_manual_test():
    """CLI entry point for ad-hoc dialogue testing against a live service.

    Supports single-shot mode (one --message) and --interactive mode
    (loop reading messages until 'quit' or Ctrl-C).
    """
    import argparse
    parser = argparse.ArgumentParser(description="中台对话集成测试")
    parser.add_argument("--url", default="http://localhost:8000", help="服务地址")
    parser.add_argument("--tenant", default="test_tenant", help="租户ID")
    parser.add_argument("--api-key", default=None, help="API Key")
    parser.add_argument("--message", default="你好", help="测试消息")
    parser.add_argument("--scene", default=None, help="场景标识")
    parser.add_argument("--interactive", action="store_true", help="交互模式")
    args = parser.parse_args()
    tester = DialogueIntegrationTester(
        base_url=args.url,
        tenant_id=args.tenant,
        api_key=args.api_key,
    )
    if args.interactive:
        print("\n=== 中台对话集成测试 - 交互模式 ===")
        print("输入 'quit' 退出\n")
        while True:
            try:
                message = input("请输入消息: ").strip()
                if message.lower() == "quit":
                    break
                # Empty scene input falls back to None (no scene sent).
                scene = input("请输入场景(可选,直接回车跳过): ").strip() or None
                result = asyncio.run(tester.send_dialogue(
                    user_message=message,
                    scene=scene,
                ))
                tester.print_result(result)
            except KeyboardInterrupt:
                # Ctrl-C exits the interactive loop cleanly.
                break
    else:
        result = asyncio.run(tester.send_dialogue(
            user_message=args.message,
            scene=args.scene,
        ))
        tester.print_result(result)
if __name__ == "__main__":
    # Manual CLI entry point (not used by pytest); see run_manual_test for flags.
    run_manual_test()

View File

@ -0,0 +1,541 @@
"""
Unit tests for Retrieval Strategy Service.
[AC-AISVC-RES-01~15] Tests for strategy management, switching, validation, and rollback.
"""
import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime
from app.schemas.retrieval_strategy import (
ReactMode,
RolloutConfig,
RolloutMode,
StrategyType,
RetrievalStrategyStatus,
RetrievalStrategySwitchRequest,
RetrievalStrategyValidationRequest,
ValidationResult,
)
from app.services.retrieval.strategy_service import (
RetrievalStrategyService,
StrategyState,
get_strategy_service,
)
from app.services.retrieval.strategy_audit import (
StrategyAuditService,
get_audit_service,
)
from app.services.retrieval.strategy_metrics import (
StrategyMetricsService,
get_metrics_service,
)
class TestRetrievalStrategySchemas:
    """[AC-AISVC-RES-01~15] Schema-level tests for the strategy models."""

    def test_rollout_config_off_mode(self):
        """[AC-AISVC-RES-03] Off mode should not require percentage or allowlist."""
        cfg = RolloutConfig(mode=RolloutMode.OFF)
        assert cfg.mode == RolloutMode.OFF
        assert cfg.percentage is None
        assert cfg.allowlist is None

    def test_rollout_config_percentage_mode(self):
        """[AC-AISVC-RES-03] Percentage mode should require percentage."""
        cfg = RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0)
        assert cfg.mode == RolloutMode.PERCENTAGE
        assert cfg.percentage == 50.0

    def test_rollout_config_percentage_mode_missing_value(self):
        """[AC-AISVC-RES-03] Percentage mode without percentage should raise error."""
        with pytest.raises(ValueError, match="percentage is required"):
            RolloutConfig(mode=RolloutMode.PERCENTAGE)

    def test_rollout_config_allowlist_mode(self):
        """[AC-AISVC-RES-03] Allowlist mode should require allowlist."""
        tenants = ["tenant1", "tenant2"]
        cfg = RolloutConfig(mode=RolloutMode.ALLOWLIST, allowlist=tenants)
        assert cfg.mode == RolloutMode.ALLOWLIST
        assert cfg.allowlist == tenants

    def test_rollout_config_allowlist_mode_missing_value(self):
        """[AC-AISVC-RES-03] Allowlist mode without allowlist should raise error."""
        with pytest.raises(ValueError, match="allowlist is required"):
            RolloutConfig(mode=RolloutMode.ALLOWLIST)

    def test_retrieval_strategy_status(self):
        """[AC-AISVC-RES-01] Status should contain all required fields."""
        status = RetrievalStrategyStatus(
            active_strategy=StrategyType.DEFAULT,
            react_mode=ReactMode.NON_REACT,
            rollout=RolloutConfig(mode=RolloutMode.OFF),
        )
        assert status.active_strategy == StrategyType.DEFAULT
        assert status.react_mode == ReactMode.NON_REACT
        assert status.rollout.mode == RolloutMode.OFF

    def test_switch_request_minimal(self):
        """[AC-AISVC-RES-02] Switch request should work with minimal fields."""
        req = RetrievalStrategySwitchRequest(target_strategy=StrategyType.ENHANCED)
        assert req.target_strategy == StrategyType.ENHANCED
        # All optional knobs must default to None.
        assert req.react_mode is None
        assert req.rollout is None
        assert req.reason is None

    def test_switch_request_full(self):
        """[AC-AISVC-RES-02,03,05] Switch request should accept all fields."""
        req = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
            rollout=RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=30.0),
            reason="Testing enhanced strategy",
        )
        assert req.target_strategy == StrategyType.ENHANCED
        assert req.react_mode == ReactMode.REACT
        assert req.rollout.percentage == 30.0
        assert req.reason == "Testing enhanced strategy"
class TestRetrievalStrategyService:
    """[AC-AISVC-RES-01~15] Behavioral tests for the strategy service."""

    @pytest.fixture
    def service(self):
        """Create a fresh service instance for each test."""
        return RetrievalStrategyService()

    @staticmethod
    def _switch(service, **kwargs):
        # Build a switch request from kwargs and apply it; returns the response.
        return service.switch_strategy(RetrievalStrategySwitchRequest(**kwargs))

    def test_get_current_status_default(self, service):
        """[AC-AISVC-RES-01] Default status should be default strategy and non_react mode."""
        status = service.get_current_status()
        assert status.active_strategy == StrategyType.DEFAULT
        assert status.react_mode == ReactMode.NON_REACT
        assert status.rollout.mode == RolloutMode.OFF

    def test_switch_strategy_to_enhanced(self, service):
        """[AC-AISVC-RES-02] Should switch to enhanced strategy."""
        resp = self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )
        assert resp.previous.active_strategy == StrategyType.DEFAULT
        assert resp.current.active_strategy == StrategyType.ENHANCED
        assert resp.current.react_mode == ReactMode.REACT

    def test_switch_strategy_with_grayscale_percentage(self, service):
        """[AC-AISVC-RES-03] Should switch with grayscale percentage."""
        resp = self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            rollout=RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0),
        )
        assert resp.current.active_strategy == StrategyType.ENHANCED
        assert resp.current.rollout.mode == RolloutMode.PERCENTAGE
        assert resp.current.rollout.percentage == 50.0

    def test_switch_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] Should switch with allowlist grayscale."""
        resp = self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            rollout=RolloutConfig(
                mode=RolloutMode.ALLOWLIST,
                allowlist=["tenant_a", "tenant_b"],
            ),
        )
        assert resp.current.rollout.mode == RolloutMode.ALLOWLIST
        assert "tenant_a" in resp.current.rollout.allowlist

    def test_rollback_strategy(self, service):
        """[AC-AISVC-RES-07] Should rollback to previous strategy."""
        self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )
        resp = service.rollback_strategy()
        assert resp.rollback_to.active_strategy == StrategyType.DEFAULT
        assert resp.rollback_to.react_mode == ReactMode.NON_REACT

    def test_rollback_without_previous_returns_default(self, service):
        """[AC-AISVC-RES-07] Rollback without previous should return default."""
        resp = service.rollback_strategy()
        assert resp.rollback_to.active_strategy == StrategyType.DEFAULT

    def test_should_use_enhanced_strategy_default(self, service):
        """[AC-AISVC-RES-01] Default strategy should not use enhanced."""
        assert service.should_use_enhanced_strategy("tenant_a") is False

    def test_should_use_enhanced_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] Should use enhanced for tenants in allowlist."""
        self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            rollout=RolloutConfig(mode=RolloutMode.ALLOWLIST, allowlist=["tenant_a"]),
        )
        assert service.should_use_enhanced_strategy("tenant_a") is True
        assert service.should_use_enhanced_strategy("tenant_b") is False

    def test_get_route_mode_react(self, service):
        """[AC-AISVC-RES-10] React mode should return react route."""
        self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )
        assert service.get_route_mode("test query") == "react"

    def test_get_route_mode_direct(self, service):
        """[AC-AISVC-RES-09] Non-react mode should return direct route."""
        self._switch(
            service,
            target_strategy=StrategyType.DEFAULT,
            react_mode=ReactMode.NON_REACT,
        )
        assert service.get_route_mode("test query") == "direct"

    def test_get_route_mode_auto_short_query(self, service):
        """[AC-AISVC-RES-12] Short query with high confidence should use direct route."""
        # Clear the explicit mode so auto-routing takes over.
        service._state.react_mode = None
        assert service._auto_route("短问题", confidence=0.8) == "direct"

    def test_get_route_mode_auto_multiple_conditions(self, service):
        """[AC-AISVC-RES-13] Query with multiple conditions should use react route."""
        assert service._auto_route("查询订单状态和物流信息") == "react"

    def test_get_route_mode_auto_low_confidence(self, service):
        """[AC-AISVC-RES-13] Low confidence should use react route."""
        assert service._auto_route("test query", confidence=0.3) == "react"

    def test_get_switch_history(self, service):
        """Should track switch history."""
        self._switch(service, target_strategy=StrategyType.ENHANCED, reason="Testing")
        history = service.get_switch_history()
        assert len(history) == 1
        assert history[0]["to_strategy"] == "enhanced"
class TestRetrievalStrategyValidation:
    """[AC-AISVC-RES-04,06,08] Validation-check tests for strategies."""

    @pytest.fixture
    def service(self):
        return RetrievalStrategyService()

    def test_validate_default_strategy(self, service):
        """[AC-AISVC-RES-06] Default strategy should pass validation."""
        resp = service.validate_strategy(
            RetrievalStrategyValidationRequest(strategy=StrategyType.DEFAULT)
        )
        assert resp.passed is True

    def test_validate_enhanced_strategy(self, service):
        """[AC-AISVC-RES-06] Enhanced strategy validation."""
        resp = service.validate_strategy(
            RetrievalStrategyValidationRequest(strategy=StrategyType.ENHANCED)
        )
        assert isinstance(resp.passed, bool)
        assert len(resp.results) > 0

    def test_validate_specific_checks(self, service):
        """[AC-AISVC-RES-06] Should run specific validation checks."""
        wanted = ["metadata_consistency", "performance_budget"]
        resp = service.validate_strategy(
            RetrievalStrategyValidationRequest(
                strategy=StrategyType.ENHANCED,
                checks=wanted,
            )
        )
        ran = [r.check for r in resp.results]
        assert "metadata_consistency" in ran
        assert "performance_budget" in ran

    def test_check_metadata_consistency(self, service):
        """[AC-AISVC-RES-04] Metadata consistency check."""
        result = service._check_metadata_consistency(StrategyType.DEFAULT)
        assert result.check == "metadata_consistency"
        assert result.passed is True

    def test_check_rrf_config(self, service):
        """[AC-AISVC-RES-02] RRF config check."""
        result = service._check_rrf_config(StrategyType.DEFAULT)
        assert result.check == "rrf_config"
        assert isinstance(result.passed, bool)

    def test_check_performance_budget(self, service):
        """[AC-AISVC-RES-08] Performance budget check."""
        result = service._check_performance_budget(StrategyType.ENHANCED, ReactMode.REACT)
        assert result.check == "performance_budget"
        assert isinstance(result.passed, bool)
class TestStrategyAuditService:
    """[AC-AISVC-RES-07] Audit-trail behavior of the audit service."""

    @pytest.fixture
    def audit_service(self):
        return StrategyAuditService(max_entries=100)

    def test_log_switch_operation(self, audit_service):
        """[AC-AISVC-RES-07] Should log switch operation."""
        audit_service.log(
            operation="switch",
            previous_strategy="default",
            new_strategy="enhanced",
            reason="Testing",
            operator="admin",
        )
        log = audit_service.get_audit_log()
        assert len(log) == 1
        entry = log[0]
        assert entry.operation == "switch"
        assert entry.previous_strategy == "default"
        assert entry.new_strategy == "enhanced"

    def test_log_rollback_operation(self, audit_service):
        """[AC-AISVC-RES-07] Should log rollback operation."""
        audit_service.log_rollback(
            previous_strategy="enhanced",
            new_strategy="default",
            reason="Performance issue",
            operator="admin",
        )
        log = audit_service.get_audit_log(operation="rollback")
        assert len(log) == 1
        assert log[0].operation == "rollback"

    def test_log_validation_operation(self, audit_service):
        """[AC-AISVC-RES-06] Should log validation operation."""
        audit_service.log_validation(
            strategy="enhanced",
            checks=["metadata_consistency"],
            passed=True,
        )
        log = audit_service.get_audit_log(operation="validate")
        assert len(log) == 1
        assert log[0].operation == "validate"

    def test_get_audit_log_with_limit(self, audit_service):
        """Should limit audit log entries."""
        for idx in range(10):
            audit_service.log(operation="switch", new_strategy=f"strategy_{idx}")
        assert len(audit_service.get_audit_log(limit=5)) == 5

    def test_get_audit_stats(self, audit_service):
        """Should return audit statistics."""
        audit_service.log(operation="switch", new_strategy="enhanced")
        audit_service.log(operation="rollback", new_strategy="default")
        stats = audit_service.get_audit_stats()
        assert stats["total_entries"] == 2
        counts = stats["operation_counts"]
        assert counts["switch"] == 1
        assert counts["rollback"] == 1

    def test_clear_audit_log(self, audit_service):
        """Should clear audit log."""
        audit_service.log(operation="switch", new_strategy="enhanced")
        assert len(audit_service.get_audit_log()) == 1
        removed = audit_service.clear_audit_log()
        assert removed == 1
        assert len(audit_service.get_audit_log()) == 0
class TestStrategyMetricsService:
    """[AC-AISVC-RES-03,08] Metrics recording and threshold checks."""

    @pytest.fixture
    def metrics_service(self):
        return StrategyMetricsService()

    def test_record_request(self, metrics_service):
        """[AC-AISVC-RES-08] Should record request metrics."""
        metrics_service.record_request(latency_ms=100.0, success=True, route_mode="direct")
        snapshot = metrics_service.get_metrics()
        assert snapshot.total_requests == 1
        assert snapshot.successful_requests == 1
        assert snapshot.avg_latency_ms == 100.0

    def test_record_failed_request(self, metrics_service):
        """[AC-AISVC-RES-08] Should record failed request."""
        metrics_service.record_request(latency_ms=50.0, success=False)
        assert metrics_service.get_metrics().failed_requests == 1

    def test_record_fallback(self, metrics_service):
        """[AC-AISVC-RES-08] Should record fallback count."""
        metrics_service.record_request(latency_ms=100.0, success=True, fallback=True)
        assert metrics_service.get_metrics().fallback_count == 1

    def test_record_route_metrics(self, metrics_service):
        """[AC-AISVC-RES-08] Should track route mode metrics."""
        metrics_service.record_request(latency_ms=100.0, success=True, route_mode="react")
        metrics_service.record_request(latency_ms=50.0, success=True, route_mode="direct")
        per_route = metrics_service.get_route_metrics()
        assert "react" in per_route
        assert "direct" in per_route

    def test_get_all_metrics(self, metrics_service):
        """Should get metrics for all strategies."""
        metrics_service.set_current_strategy(StrategyType.ENHANCED, ReactMode.REACT)
        metrics_service.record_request(latency_ms=100.0, success=True)
        by_strategy = metrics_service.get_all_metrics()
        assert StrategyType.DEFAULT.value in by_strategy
        assert StrategyType.ENHANCED.value in by_strategy

    def test_get_performance_summary(self, metrics_service):
        """[AC-AISVC-RES-08] Should get performance summary."""
        for latency, ok in ((100.0, True), (200.0, True), (50.0, False)):
            metrics_service.record_request(latency_ms=latency, success=ok)
        summary = metrics_service.get_performance_summary()
        assert summary["total_requests"] == 3
        assert summary["successful_requests"] == 2
        assert summary["failed_requests"] == 1
        assert summary["success_rate"] == pytest.approx(0.6667, rel=0.01)

    def test_check_performance_threshold_ok(self, metrics_service):
        """[AC-AISVC-RES-08] Should pass performance threshold check."""
        metrics_service.record_request(latency_ms=100.0, success=True)
        verdict = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )
        assert verdict["latency_ok"] is True
        assert verdict["error_rate_ok"] is True
        assert verdict["overall_ok"] is True

    def test_check_performance_threshold_exceeded(self, metrics_service):
        """[AC-AISVC-RES-08] Should fail when threshold exceeded."""
        metrics_service.record_request(latency_ms=6000.0, success=True)
        metrics_service.record_request(latency_ms=100.0, success=False)
        verdict = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )
        # At least one of the two thresholds must have tripped.
        assert verdict["latency_ok"] is False or verdict["error_rate_ok"] is False

    def test_reset_metrics(self, metrics_service):
        """Should reset metrics."""
        metrics_service.record_request(latency_ms=100.0, success=True)
        metrics_service.reset_metrics()
        assert metrics_service.get_metrics().total_requests == 0
class TestSingletonInstances:
    """Tests for singleton instance getters.

    Each test resets the module-level cached instance to None, then verifies
    that two consecutive getter calls return the same (lazily created) object.
    """

    def test_get_strategy_service_singleton(self):
        """Should return same strategy service instance."""
        # The previous `from ... import _strategy_service` bound a stale copy
        # of the value and was never used; resetting must go through the
        # module attribute, which is all we need here.
        import app.services.retrieval.strategy_service as module
        module._strategy_service = None
        service1 = get_strategy_service()
        service2 = get_strategy_service()
        assert service1 is service2

    def test_get_audit_service_singleton(self):
        """Should return same audit service instance."""
        import app.services.retrieval.strategy_audit as module
        module._audit_service = None
        service1 = get_audit_service()
        service2 = get_audit_service()
        assert service1 is service2

    def test_get_metrics_service_singleton(self):
        """Should return same metrics service instance."""
        import app.services.retrieval.strategy_metrics as module
        module._metrics_service = None
        service1 = get_metrics_service()
        service2 = get_metrics_service()
        assert service1 is service2

View File

@ -0,0 +1,353 @@
"""
Integration tests for Retrieval Strategy API.
[AC-AISVC-RES-01~15] End-to-end tests for strategy management endpoints.
Tests the full API flow:
- GET /strategy/retrieval/current
- POST /strategy/retrieval/switch
- POST /strategy/retrieval/validate
- POST /strategy/retrieval/rollback
"""
import json
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from fastapi.testclient import TestClient
from app.main import app
@pytest.fixture(autouse=True)
def mock_api_key_service():
    """
    Mock API key service to bypass authentication in tests.
    """
    fake = MagicMock()
    fake._initialized = True
    fake._keys_cache = {"test-api-key": MagicMock()}
    # Any key presented with context is treated as valid.
    fake.validate_key_with_context.return_value = MagicMock(ok=True, reason=None)
    with patch("app.services.api_key.get_api_key_service", return_value=fake):
        yield fake
@pytest.fixture(autouse=True)
def reset_strategy_state():
    """
    Reset strategy state before and after each test.

    Drops the cached router singleton and restores a pristine default config
    so that state mutated by one test cannot leak into another. The same
    sequence runs on setup and teardown, so it lives in a local helper.
    """
    from app.services.retrieval.strategy.strategy_router import get_strategy_router, set_strategy_router
    from app.services.retrieval.strategy.config import RetrievalStrategyConfig

    def _reset():
        # Clear the singleton, then rebuild it with a default config.
        set_strategy_router(None)
        router = get_strategy_router()
        router.update_config(RetrievalStrategyConfig())

    _reset()
    yield
    _reset()
class TestRetrievalStrategyAPIIntegration:
    """
    [AC-AISVC-RES-01~15] End-to-end tests for the retrieval strategy API.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        return {
            "X-Tenant-Id": "test@ash@2026",
            "X-API-Key": "test-api-key",
        }

    def test_get_current_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-01] GET /current should return strategy status.
        """
        res = client.get("/strategy/retrieval/current", headers=valid_headers)
        assert res.status_code == 200
        body = res.json()
        assert "active_strategy" in body
        assert "grayscale" in body
        assert body["active_strategy"] in ["default", "enhanced"]

    def test_switch_strategy_to_enhanced(self, client, valid_headers):
        """
        [AC-AISVC-RES-02] POST /switch should switch to enhanced strategy.
        """
        res = client.post(
            "/strategy/retrieval/switch",
            json={"active_strategy": "enhanced"},
            headers=valid_headers,
        )
        assert res.status_code == 200
        body = res.json()
        assert "success" in body
        assert body["success"] is True
        assert body["current_strategy"] == "enhanced"

    def test_switch_strategy_with_grayscale_percentage(self, client, valid_headers):
        """
        [AC-AISVC-RES-03] POST /switch should accept grayscale percentage.
        """
        payload = {
            "active_strategy": "enhanced",
            "grayscale": {"enabled": True, "percentage": 30.0},
        }
        res = client.post(
            "/strategy/retrieval/switch", json=payload, headers=valid_headers
        )
        assert res.status_code == 200
        assert res.json()["success"] is True

    def test_switch_strategy_with_allowlist(self, client, valid_headers):
        """
        [AC-AISVC-RES-03] POST /switch should accept allowlist.
        """
        payload = {
            "active_strategy": "enhanced",
            "grayscale": {"enabled": True, "allowlist": ["tenant_a", "tenant_b"]},
        }
        res = client.post(
            "/strategy/retrieval/switch", json=payload, headers=valid_headers
        )
        assert res.status_code == 200
        assert res.json()["success"] is True

    def test_validate_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-06] POST /validate should validate strategy.
        """
        res = client.post(
            "/strategy/retrieval/validate",
            json={"active_strategy": "enhanced"},
            headers=valid_headers,
        )
        assert res.status_code == 200
        body = res.json()
        assert "valid" in body
        assert "errors" in body
        assert isinstance(body["valid"], bool)

    def test_validate_default_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-06] Default strategy should pass validation.
        """
        res = client.post(
            "/strategy/retrieval/validate",
            json={"active_strategy": "default"},
            headers=valid_headers,
        )
        assert res.status_code == 200
        assert res.json()["valid"] is True

    def test_rollback_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-07] POST /rollback should rollback to default.
        """
        # Move off the default first so the rollback has an effect.
        client.post(
            "/strategy/retrieval/switch",
            json={"active_strategy": "enhanced"},
            headers=valid_headers,
        )
        res = client.post("/strategy/retrieval/rollback", headers=valid_headers)
        assert res.status_code == 200
        body = res.json()
        assert "success" in body
        assert body["current_strategy"] == "default"
class TestRetrievalStrategyAPIValidation:
    """
    [AC-AISVC-RES-03] Request-validation behavior of the strategy API.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        return {
            "X-Tenant-Id": "test@ash@2026",
            "X-API-Key": "test-api-key",
        }

    def test_switch_invalid_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-03] Invalid strategy value should return error.
        """
        res = client.post(
            "/strategy/retrieval/switch",
            json={"active_strategy": "invalid_strategy"},
            headers=valid_headers,
        )
        assert res.status_code in [400, 422, 500]

    def test_switch_percentage_out_of_range(self, client, valid_headers):
        """
        [AC-AISVC-RES-03] Percentage > 100 should return validation error.
        """
        payload = {
            "active_strategy": "enhanced",
            "grayscale": {"percentage": 150.0},
        }
        res = client.post(
            "/strategy/retrieval/switch", json=payload, headers=valid_headers
        )
        assert res.status_code in [400, 422]
class TestRetrievalStrategyAPIFlow:
    """
    [AC-AISVC-RES-01~15] Complete API flow scenarios.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        return {
            "X-Tenant-Id": "test@ash@2026",
            "X-API-Key": "test-api-key",
        }

    def test_complete_strategy_lifecycle(self, client, valid_headers):
        """
        [AC-AISVC-RES-01~07] Test complete strategy lifecycle:
        1. Get current strategy
        2. Switch to enhanced
        3. Validate
        4. Rollback
        5. Verify back to default
        """
        # 1. Starting point must be the default strategy.
        res = client.get("/strategy/retrieval/current", headers=valid_headers)
        assert res.status_code == 200
        assert res.json()["active_strategy"] == "default"

        # 2. Switch to enhanced with a 50% grayscale rollout.
        res = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
                "grayscale": {"enabled": True, "percentage": 50.0},
            },
            headers=valid_headers,
        )
        assert res.status_code == 200
        assert res.json()["current_strategy"] == "enhanced"

        # 3. Validate the now-active enhanced strategy.
        res = client.post(
            "/strategy/retrieval/validate",
            json={"active_strategy": "enhanced"},
            headers=valid_headers,
        )
        assert res.status_code == 200

        # 4. Roll back to the previous (default) strategy.
        res = client.post("/strategy/retrieval/rollback", headers=valid_headers)
        assert res.status_code == 200
        assert res.json()["current_strategy"] == "default"

        # 5. Confirm the rollback is reflected by /current.
        res = client.get("/strategy/retrieval/current", headers=valid_headers)
        assert res.status_code == 200
        assert res.json()["active_strategy"] == "default"
class TestRetrievalStrategyAPIMissingTenant:
    """
    API behavior when the X-Tenant-Id header is absent.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def api_key_headers(self):
        # Only the API key — deliberately no X-Tenant-Id header.
        return {"X-API-Key": "test-api-key"}

    def test_current_without_tenant(self, client, api_key_headers):
        """
        Missing X-Tenant-Id should return 400.
        """
        res = client.get("/strategy/retrieval/current", headers=api_key_headers)
        assert res.status_code == 400

    def test_switch_without_tenant(self, client, api_key_headers):
        """
        Missing X-Tenant-Id should return 400.
        """
        res = client.post(
            "/strategy/retrieval/switch",
            json={"active_strategy": "enhanced"},
            headers=api_key_headers,
        )
        assert res.status_code == 400