"""
Unit tests for ImageParser.

Covers extension/MIME mapping, LLM-response JSON extraction and parsing
(including markdown-fenced and malformed responses), the ImageChunk /
ImageParseResult dataclasses, chunk-type handling, and end-to-end parse
scenarios with single, multiple, and mixed-type chunks.
"""

import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from app.services.document.image_parser import (
    ImageChunk,
    ImageParseResult,
    ImageParser,
)


class TestImageParserBasics:
    """Test basic functionality of ImageParser."""

    def test_supported_extensions(self):
        """Test that ImageParser supports correct extensions."""
        parser = ImageParser()
        extensions = parser.get_supported_extensions()

        expected_extensions = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
        assert extensions == expected_extensions

    def test_get_mime_type(self):
        """Test MIME type mapping."""
        parser = ImageParser()

        assert parser._get_mime_type(".jpg") == "image/jpeg"
        assert parser._get_mime_type(".jpeg") == "image/jpeg"
        assert parser._get_mime_type(".png") == "image/png"
        assert parser._get_mime_type(".gif") == "image/gif"
        assert parser._get_mime_type(".webp") == "image/webp"
        assert parser._get_mime_type(".bmp") == "image/bmp"
        assert parser._get_mime_type(".tiff") == "image/tiff"
        assert parser._get_mime_type(".tif") == "image/tiff"
        # Unknown extensions fall back to image/jpeg.
        assert parser._get_mime_type(".unknown") == "image/jpeg"


class TestImageChunkParsing:
    """Test LLM response parsing functionality."""

    def test_extract_json_from_plain_json(self):
        """Test extracting JSON from plain JSON response."""
        parser = ImageParser()

        json_str = '{"image_summary": "test", "chunks": [{"chunk_index": 0, "content": "hello", "chunk_type": "text", "keywords": ["key"]}]}'
        result = parser._extract_json(json_str)

        assert result == json_str

    def test_extract_json_from_markdown(self):
        """Test extracting JSON from markdown code block."""
        parser = ImageParser()

        markdown = """Here is the analysis:

```json
{"image_summary": "test", "chunks": [{"chunk_index": 0, "content": "hello"}]}
```

Hope this helps!"""

        result = parser._extract_json(markdown)
        assert "image_summary" in result
        assert "test" in result

    def test_extract_json_from_text_with_json(self):
        """Test extracting JSON from text with JSON embedded."""
        parser = ImageParser()

        text = "The result is: {'image_summary': 'summary', 'chunks': []}"

        result = parser._extract_json(text)
        assert "image_summary" in result
        assert "chunks" in result

    def test_parse_llm_response_valid_json(self):
        """Test parsing valid JSON response from LLM."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "测试图片",
            "total_chunks": 2,
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "这是第一块内容",
                    "chunk_type": "text",
                    "keywords": ["测试", "内容"]
                },
                {
                    "chunk_index": 1,
                    "content": "这是第二块内容,包含表格数据",
                    "chunk_type": "table",
                    "keywords": ["表格", "数据"]
                }
            ]
        })

        result = parser._parse_llm_response(response)

        assert result.image_summary == "测试图片"
        assert len(result.chunks) == 2
        assert result.chunks[0].content == "这是第一块内容"
        assert result.chunks[0].chunk_type == "text"
        assert result.chunks[0].keywords == ["测试", "内容"]
        assert result.chunks[1].chunk_type == "table"
        assert result.chunks[1].keywords == ["表格", "数据"]

    def test_parse_llm_response_empty_chunks(self):
        """Test handling response with empty chunks."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "测试",
            "chunks": []
        })

        result = parser._parse_llm_response(response)

        # Empty chunk lists fall back to a single chunk holding the raw response.
        assert len(result.chunks) == 1
        assert result.chunks[0].content == response

    def test_parse_llm_response_invalid_json(self):
        """Test handling invalid JSON response with fallback."""
        parser = ImageParser()

        response = "This is not JSON at all"

        result = parser._parse_llm_response(response)

        assert len(result.chunks) == 1
        assert result.chunks[0].content == "This is not JSON at all"

    def test_parse_llm_response_partial_json(self):
        """Test handling response with partial/invalid JSON uses fallback."""
        parser = ImageParser()

        response = '{"image_summary": "test" some text here {"chunks": []}'

        result = parser._parse_llm_response(response)

        assert len(result.chunks) == 1
        assert result.chunks[0].content == response


class TestImageChunkDataClass:
    """Test ImageChunk dataclass functionality."""

    def test_image_chunk_creation(self):
        """Test creating ImageChunk."""
        chunk = ImageChunk(
            chunk_index=0,
            content="Test content",
            chunk_type="text",
            keywords=["test", "content"]
        )

        assert chunk.chunk_index == 0
        assert chunk.content == "Test content"
        assert chunk.chunk_type == "text"
        assert chunk.keywords == ["test", "content"]

    def test_image_chunk_default_values(self):
        """Test ImageChunk with default values."""
        chunk = ImageChunk(chunk_index=0, content="Test")

        assert chunk.chunk_type == "text"
        assert chunk.keywords == []

    def test_image_parse_result_creation(self):
        """Test creating ImageParseResult."""
        chunks = [
            ImageChunk(chunk_index=0, content="Chunk 1"),
            ImageChunk(chunk_index=1, content="Chunk 2"),
        ]

        result = ImageParseResult(
            image_summary="Test summary",
            chunks=chunks,
            raw_text="Chunk 1\n\nChunk 2",
            source_path="/path/to/image.png",
            file_size=1024,
            metadata={"format": "png"}
        )

        assert result.image_summary == "Test summary"
        assert len(result.chunks) == 2
        assert result.raw_text == "Chunk 1\n\nChunk 2"
        assert result.file_size == 1024
        assert result.metadata["format"] == "png"


class TestChunkTypes:
    """Test different chunk types."""

    def test_text_chunk_type(self):
        """Test text chunk type."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "Text content",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "Plain text content",
                    "chunk_type": "text",
                    "keywords": ["text"]
                }
            ]
        })

        result = parser._parse_llm_response(response)
        assert result.chunks[0].chunk_type == "text"

    def test_table_chunk_type(self):
        """Test table chunk type."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "Table content",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "Name | Age\n---- | ---\nJohn | 30",
                    "chunk_type": "table",
                    "keywords": ["table", "data"]
                }
            ]
        })

        result = parser._parse_llm_response(response)
        assert result.chunks[0].chunk_type == "table"

    def test_chart_chunk_type(self):
        """Test chart chunk type."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "Chart content",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "Bar chart showing sales data",
                    "chunk_type": "chart",
                    "keywords": ["chart", "sales"]
                }
            ]
        })

        result = parser._parse_llm_response(response)
        assert result.chunks[0].chunk_type == "chart"

    def test_list_chunk_type(self):
        """Test list chunk type."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "List content",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "1. First item\n2. Second item\n3. Third item",
                    "chunk_type": "list",
                    "keywords": ["list", "items"]
                }
            ]
        })

        result = parser._parse_llm_response(response)
        assert result.chunks[0].chunk_type == "list"


class TestIntegrationScenarios:
    """Test integration scenarios."""

    def test_single_chunk_scenario(self):
        """Test single chunk scenario - simple image with one main content."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "简单文档截图",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "这是一段完整的文档内容,包含所有的信息要点。",
                    "chunk_type": "text",
                    "keywords": ["文档", "信息"]
                }
            ]
        })

        result = parser._parse_llm_response(response)

        assert len(result.chunks) == 1
        assert result.image_summary == "简单文档截图"
        assert result.raw_text == "这是一段完整的文档内容,包含所有的信息要点。"

    def test_multi_chunk_scenario(self):
        """Test multi-chunk scenario - complex image with multiple sections."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "多段落文档",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "第一章:介绍部分,介绍项目的背景和目标。",
                    "chunk_type": "text",
                    "keywords": ["第一章", "介绍"]
                },
                {
                    "chunk_index": 1,
                    "content": "第二章:技术架构,包括前端、后端和数据库设计。",
                    "chunk_type": "text",
                    "keywords": ["第二章", "架构"]
                },
                {
                    "chunk_index": 2,
                    "content": "第三章:部署流程,包含开发环境和生产环境配置。",
                    "chunk_type": "text",
                    "keywords": ["第三章", "部署"]
                }
            ]
        })

        result = parser._parse_llm_response(response)

        assert len(result.chunks) == 3
        assert "第一章" in result.chunks[0].content
        assert "第二章" in result.chunks[1].content
        assert "第三章" in result.chunks[2].content
        # raw_text joins the three chunks with blank lines -> exactly two separators.
        assert result.raw_text.count("\n\n") == 2

    def test_mixed_content_scenario(self):
        """Test mixed content scenario - text and table."""
        parser = ImageParser()

        response = json.dumps({
            "image_summary": "混合内容图片",
            "chunks": [
                {
                    "chunk_index": 0,
                    "content": "产品介绍:本文档介绍我们的核心产品功能。",
                    "chunk_type": "text",
                    "keywords": ["产品", "功能"]
                },
                {
                    "chunk_index": 1,
                    "content": "产品规格表:\n| 型号 | 价格 | 库存 |\n| --- | --- | --- |\n| A1 | 100 | 50 |",
                    "chunk_type": "table",
                    "keywords": ["规格", "价格", "库存"]
                },
                {
                    "chunk_index": 2,
                    "content": "使用说明:\n1. 打开包装\n2. 连接电源\n3. 按下启动按钮",
                    "chunk_type": "list",
                    "keywords": ["说明", "步骤"]
                }
            ]
        })

        result = parser._parse_llm_response(response)

        assert len(result.chunks) == 3
        assert result.chunks[0].chunk_type == "text"
        assert result.chunks[1].chunk_type == "table"
        assert result.chunks[2].chunk_type == "list"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
"""
Unit tests for multi-usage LLM configuration.
Tests for LLMUsageType, LLMConfigManager multi-usage support, and API endpoints.
"""

import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.services.llm.factory import (
    LLMConfigManager,
    LLMProviderFactory,
    LLMUsageType,
    LLM_USAGE_DESCRIPTIONS,
    LLM_USAGE_DISPLAY_NAMES,
    get_llm_config_manager,
)


@pytest.fixture
def mock_settings():
    """Mock settings for testing."""
    settings = MagicMock()
    settings.llm_provider = "openai"
    settings.llm_api_key = "test-api-key"
    settings.llm_base_url = "https://api.openai.com/v1"
    settings.llm_model = "gpt-4o-mini"
    settings.llm_max_tokens = 2048
    settings.llm_temperature = 0.7
    settings.llm_timeout_seconds = 30
    settings.llm_max_retries = 3
    settings.redis_enabled = False
    settings.redis_url = "redis://localhost:6379"
    return settings


@pytest.fixture(autouse=True)
def reset_singleton():
    """Reset singleton before and after each test."""
    import app.services.llm.factory as factory
    factory._llm_config_manager = None
    yield
    factory._llm_config_manager = None


@pytest.fixture
def isolated_config_file(tmp_path):
    """Create an isolated config file for testing."""
    config_file = tmp_path / "llm_config.json"
    config_file.write_text("{}")
    return config_file


class TestLLMUsageType:
    """Tests for LLMUsageType enum."""

    def test_usage_types_exist(self):
        """Test that required usage types exist."""
        assert LLMUsageType.CHAT.value == "chat"
        assert LLMUsageType.KB_PROCESSING.value == "kb_processing"

    def test_usage_type_display_names(self):
        """Test that display names are defined for all usage types."""
        for ut in LLMUsageType:
            assert ut in LLM_USAGE_DISPLAY_NAMES
            assert isinstance(LLM_USAGE_DISPLAY_NAMES[ut], str)
            assert len(LLM_USAGE_DISPLAY_NAMES[ut]) > 0

    def test_usage_type_descriptions(self):
        """Test that descriptions are defined for all usage types."""
        for ut in LLMUsageType:
            assert ut in LLM_USAGE_DESCRIPTIONS
            assert isinstance(LLM_USAGE_DESCRIPTIONS[ut], str)
            assert len(LLM_USAGE_DESCRIPTIONS[ut]) > 0

    def test_usage_type_from_string(self):
        """Test creating usage type from string."""
        assert LLMUsageType("chat") == LLMUsageType.CHAT
        assert LLMUsageType("kb_processing") == LLMUsageType.KB_PROCESSING

    def test_invalid_usage_type(self):
        """Test that invalid usage type raises error."""
        with pytest.raises(ValueError):
            LLMUsageType("invalid_type")


class TestLLMConfigManagerMultiUsage:
    """Tests for LLMConfigManager multi-usage support."""

    @pytest.mark.asyncio
    async def test_get_all_configs(self, mock_settings, isolated_config_file):
        """Test getting all configs at once."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                all_configs = manager.get_current_config()

                for ut in LLMUsageType:
                    assert ut.value in all_configs
                    assert "provider" in all_configs[ut.value]
                    assert "config" in all_configs[ut.value]

    @pytest.mark.asyncio
    async def test_update_specific_usage_config(self, mock_settings, isolated_config_file):
        """Test updating config for a specific usage type."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                await manager.update_usage_config(
                    usage_type=LLMUsageType.KB_PROCESSING,
                    provider="ollama",
                    config={"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
                )

                kb_config = manager.get_current_config(LLMUsageType.KB_PROCESSING)
                assert kb_config["provider"] == "ollama"
                assert kb_config["config"]["model"] == "llama3.2"

    @pytest.mark.asyncio
    async def test_update_all_configs(self, mock_settings, isolated_config_file):
        """Test updating all configs at once."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                await manager.update_config(
                    provider="deepseek",
                    config={"api_key": "test-key", "model": "deepseek-chat"},
                )

                # update_config (no usage type) applies to every usage type.
                for ut in LLMUsageType:
                    config = manager.get_current_config(ut)
                    assert config["provider"] == "deepseek"

    @pytest.mark.asyncio
    async def test_get_client_for_usage_type(self, mock_settings, isolated_config_file):
        """Test getting client for specific usage type."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                chat_client = manager.get_client(LLMUsageType.CHAT)
                assert chat_client is not None

                kb_client = manager.get_client(LLMUsageType.KB_PROCESSING)
                assert kb_client is not None

    @pytest.mark.asyncio
    async def test_get_chat_client(self, mock_settings, isolated_config_file):
        """Test get_chat_client convenience method."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                client = manager.get_chat_client()
                assert client is not None

    @pytest.mark.asyncio
    async def test_get_kb_processing_client(self, mock_settings, isolated_config_file):
        """Test get_kb_processing_client convenience method."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                client = manager.get_kb_processing_client()
                assert client is not None

    @pytest.mark.asyncio
    async def test_close_all_clients(self, mock_settings, isolated_config_file):
        """Test that close() closes all clients."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                # Instantiate a client per usage type so close() has work to do.
                for ut in LLMUsageType:
                    manager.get_client(ut)

                await manager.close()

                for ut in LLMUsageType:
                    assert manager._clients[ut] is None

    @pytest.mark.asyncio
    async def test_config_persistence_to_file(self, mock_settings, isolated_config_file):
        """Test that configs are persisted to file."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                await manager.update_usage_config(
                    usage_type=LLMUsageType.KB_PROCESSING,
                    provider="ollama",
                    config={"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
                )

                assert isolated_config_file.exists()

                with open(isolated_config_file, "r", encoding="utf-8") as f:
                    saved = json.load(f)

                assert "chat" in saved
                assert "kb_processing" in saved
                assert saved["kb_processing"]["provider"] == "ollama"

    @pytest.mark.asyncio
    async def test_load_config_from_file(self, mock_settings, tmp_path):
        """Test loading configs from file."""
        config_file = tmp_path / "llm_config.json"
        saved_config = {
            "chat": {
                "provider": "openai",
                "config": {"api_key": "test-key", "model": "gpt-4o"},
            },
            "kb_processing": {
                "provider": "ollama",
                "config": {"base_url": "http://localhost:11434/v1", "model": "llama3.2"},
            },
        }

        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(saved_config, f)

        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", config_file):
                manager = LLMConfigManager()

                chat_config = manager.get_current_config(LLMUsageType.CHAT)
                assert chat_config["config"]["model"] == "gpt-4o"

                kb_config = manager.get_current_config(LLMUsageType.KB_PROCESSING)
                assert kb_config["provider"] == "ollama"
                assert kb_config["config"]["model"] == "llama3.2"

    @pytest.mark.asyncio
    async def test_backward_compatibility_old_config_format(self, mock_settings, tmp_path):
        """Test backward compatibility with old single-config format."""
        config_file = tmp_path / "llm_config.json"
        old_config = {
            "provider": "deepseek",
            "config": {"api_key": "test-key", "model": "deepseek-chat"},
        }

        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(old_config, f)

        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", config_file):
                manager = LLMConfigManager()

                # Old flat format is applied to every usage type.
                for ut in LLMUsageType:
                    config = manager.get_current_config(ut)
                    assert config["provider"] == "deepseek"
                    assert config["config"]["model"] == "deepseek-chat"


class TestLLMConfigManagerTestConnection:
    """Tests for test_connection with usage type support."""

    @pytest.mark.asyncio
    async def test_test_connection_with_usage_type(self, mock_settings, isolated_config_file):
        """Test connection testing with specific usage type."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()

                mock_client = AsyncMock()
                mock_client.generate = AsyncMock(
                    return_value=MagicMock(
                        content="Test response",
                        model="gpt-4o-mini",
                        usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
                    )
                )
                mock_client.close = AsyncMock()

                with patch.object(
                    LLMProviderFactory, "create_client", return_value=mock_client
                ):
                    result = await manager.test_connection(
                        test_prompt="Hello",
                        usage_type=LLMUsageType.CHAT,
                    )

                assert result["success"] is True
                assert "response" in result
                assert "latency_ms" in result


class TestGetLLMConfigManager:
    """Tests for get_llm_config_manager singleton."""

    def test_singleton_instance(self, mock_settings, isolated_config_file):
        """Test that get_llm_config_manager returns singleton."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager1 = get_llm_config_manager()
                manager2 = get_llm_config_manager()

                assert manager1 is manager2

    def test_reset_singleton(self, mock_settings, isolated_config_file):
        """Test resetting singleton instance."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = get_llm_config_manager()
                assert manager is not None


class TestLLMConfigManagerProperties:
    """Tests for LLMConfigManager properties."""

    def test_chat_config_property(self, mock_settings, isolated_config_file):
        """Test chat_config property."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()
                config = manager.chat_config
                assert isinstance(config, dict)
                assert "model" in config

    def test_kb_processing_config_property(self, mock_settings, isolated_config_file):
        """Test kb_processing_config property."""
        with patch("app.services.llm.factory.get_settings", return_value=mock_settings):
            with patch("app.services.llm.factory.LLM_CONFIG_FILE", isolated_config_file):
                manager = LLMConfigManager()
                config = manager.kb_processing_config
                assert isinstance(config, dict)
                assert "model" in config
"""
Unit tests for Markdown intelligent chunker.
Tests for MarkdownParser, MarkdownChunker, and integration.
"""

import pytest

from app.services.document.markdown_chunker import (
    MarkdownChunk,
    MarkdownChunker,
    MarkdownElement,
    MarkdownElementType,
    MarkdownParser,
    chunk_markdown,
)


class TestMarkdownParser:
    """Tests for MarkdownParser."""

    def test_parse_headers(self):
        """Test header extraction."""
        text = """# Main Title

## Section 1

### Subsection 1.1

#### Deep Header
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        headers = [e for e in elements if e.type == MarkdownElementType.HEADER]

        assert len(headers) == 4
        assert headers[0].content == "Main Title"
        assert headers[0].level == 1
        assert headers[1].content == "Section 1"
        assert headers[1].level == 2
        assert headers[2].content == "Subsection 1.1"
        assert headers[2].level == 3
        assert headers[3].content == "Deep Header"
        assert headers[3].level == 4

    def test_parse_code_blocks(self):
        """Test code block extraction with language."""
        text = """Here is some code:

```python
def hello():
    print("Hello, World!")
```

And some more text.
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]

        assert len(code_blocks) == 1
        assert code_blocks[0].language == "python"
        assert 'def hello():' in code_blocks[0].content
        assert 'print("Hello, World!")' in code_blocks[0].content

    def test_parse_code_blocks_no_language(self):
        """Test code block without language specification."""
        text = """```
plain code here
multiple lines
```
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]

        assert len(code_blocks) == 1
        assert code_blocks[0].language == ""
        assert "plain code here" in code_blocks[0].content

    def test_parse_tables(self):
        """Test table extraction."""
        text = """| Name | Age | City |
|------|-----|------|
| Alice | 30 | NYC |
| Bob | 25 | LA |
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        tables = [e for e in elements if e.type == MarkdownElementType.TABLE]

        assert len(tables) == 1
        assert "Name" in tables[0].content
        assert "Alice" in tables[0].content
        assert tables[0].metadata.get("headers") == ["Name", "Age", "City"]
        assert tables[0].metadata.get("row_count") == 2

    def test_parse_lists(self):
        """Test list extraction."""
        text = """- Item 1
- Item 2
- Item 3
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        lists = [e for e in elements if e.type == MarkdownElementType.LIST]

        assert len(lists) == 1
        assert "Item 1" in lists[0].content
        assert "Item 2" in lists[0].content
        assert "Item 3" in lists[0].content

    def test_parse_ordered_lists(self):
        """Test ordered list extraction."""
        text = """1. First
2. Second
3. Third
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        lists = [e for e in elements if e.type == MarkdownElementType.LIST]

        assert len(lists) == 1
        assert "First" in lists[0].content
        assert "Second" in lists[0].content
        assert "Third" in lists[0].content

    def test_parse_blockquotes(self):
        """Test blockquote extraction."""
        text = """> This is a quote.
> It spans multiple lines.
> And continues here.
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        quotes = [e for e in elements if e.type == MarkdownElementType.BLOCKQUOTE]

        assert len(quotes) == 1
        assert "This is a quote." in quotes[0].content
        assert "It spans multiple lines." in quotes[0].content

    def test_parse_paragraphs(self):
        """Test paragraph extraction."""
        text = """This is the first paragraph.

This is the second paragraph.
It has multiple lines.

This is the third.
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        paragraphs = [e for e in elements if e.type == MarkdownElementType.PARAGRAPH]

        assert len(paragraphs) == 3
        assert "first paragraph" in paragraphs[0].content
        assert "second paragraph" in paragraphs[1].content

    def test_parse_mixed_content(self):
        """Test parsing mixed Markdown content."""
        text = """# Documentation

## Introduction

This is an introduction paragraph.

## Code Example

```python
def example():
    return 42
```

## Data Table

| Column A | Column B |
|----------|----------|
| Value 1 | Value 2 |

## List

- Item A
- Item B

> Note: This is important.
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        headers = [e for e in elements if e.type == MarkdownElementType.HEADER]
        code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]
        tables = [e for e in elements if e.type == MarkdownElementType.TABLE]
        lists = [e for e in elements if e.type == MarkdownElementType.LIST]
        quotes = [e for e in elements if e.type == MarkdownElementType.BLOCKQUOTE]
        paragraphs = [e for e in elements if e.type == MarkdownElementType.PARAGRAPH]

        assert len(headers) == 5
        assert len(code_blocks) == 1
        assert len(tables) == 1
        assert len(lists) == 1
        assert len(quotes) == 1
        assert len(paragraphs) >= 1

    def test_code_blocks_not_parsed_as_other_elements(self):
        """Test that code blocks don't get parsed as headers or lists."""
        text = """```markdown
# This is not a header
- This is not a list
| This is not a table |
```
"""
        parser = MarkdownParser()
        elements = parser.parse(text)

        headers = [e for e in elements if e.type == MarkdownElementType.HEADER]
        lists = [e for e in elements if e.type == MarkdownElementType.LIST]
        tables = [e for e in elements if e.type == MarkdownElementType.TABLE]
        code_blocks = [e for e in elements if e.type == MarkdownElementType.CODE_BLOCK]

        assert len(headers) == 0
        assert len(lists) == 0
        assert len(tables) == 0
        assert len(code_blocks) == 1


class TestMarkdownChunker:
    """Tests for MarkdownChunker."""

    def test_chunk_simple_document(self):
        """Test chunking a simple document."""
        text = """# Title

This is a paragraph.

## Section

Another paragraph.
"""
        chunker = MarkdownChunker()
        chunks = chunker.chunk(text, "test_doc")

        assert len(chunks) >= 2
        assert all(isinstance(chunk, MarkdownChunk) for chunk in chunks)
        assert all(chunk.chunk_id.startswith("test_doc") for chunk in chunks)

    def test_chunk_preserves_header_context(self):
        """Test that header context is preserved."""
        text = """# Main Title

## Section A

Content under section A.

### Subsection A1

Content under subsection A1.
"""
        chunker = MarkdownChunker(include_header_context=True)
        chunks = chunker.chunk(text, "test")

        # NOTE(review): the filter excludes chunks containing the exact-case
        # "Subsection A1", yet the inner check looks for it lowercased — the
        # assertions may never execute; confirm the intended filtering.
        subsection_chunks = [c for c in chunks if "Subsection A1" not in c.content]
        for chunk in subsection_chunks:
            if "subsection a1" in chunk.content.lower():
                assert "Main Title" in chunk.header_context
                assert "Section A" in chunk.header_context

    def test_chunk_code_blocks_preserved(self):
        """Test that code blocks are preserved as single chunks when possible."""
        text = """```python
def function_one():
    pass

def function_two():
    pass
```
"""
        chunker = MarkdownChunker(max_chunk_size=2000, preserve_code_blocks=True)
        chunks = chunker.chunk(text, "test")

        code_chunks = [c for c in chunks if c.element_type == MarkdownElementType.CODE_BLOCK]

        assert len(code_chunks) == 1
        assert "def function_one" in code_chunks[0].content
        assert "def function_two" in code_chunks[0].content
        assert code_chunks[0].language == "python"

    def test_chunk_large_code_block_split(self):
        """Test that large code blocks are split properly."""
        lines = ["def function_{}(): pass".format(i) for i in range(100)]
        code_content = "\n".join(lines)
        text = f"""```python\n{code_content}\n```"""

        chunker = MarkdownChunker(max_chunk_size=500, preserve_code_blocks=True)
        chunks = chunker.chunk(text, "test")

        code_chunks = [c for c in chunks if c.element_type == MarkdownElementType.CODE_BLOCK]

        assert len(code_chunks) > 1
        # Every fragment keeps the fenced-code framing and language tag.
        for chunk in code_chunks:
            assert chunk.language == "python"
            assert "```python" in chunk.content
            assert "```" in chunk.content

    def test_chunk_table_preserved(self):
        """Test that tables are preserved."""
        text = """| Name | Age |
|------|-----|
| Alice | 30 |
| Bob | 25 |
"""
        chunker = MarkdownChunker(max_chunk_size=2000, preserve_tables=True)
        chunks = chunker.chunk(text, "test")

        table_chunks = [c for c in chunks if c.element_type == MarkdownElementType.TABLE]

        assert len(table_chunks) == 1
        assert "Alice" in table_chunks[0].content
        assert "Bob" in table_chunks[0].content

    def test_chunk_large_table_split(self):
        """Test that large tables are split with header preserved."""
        rows = [f"| Name{i} | {i * 10} |" for i in range(50)]
        table_content = "| Name | Age |\n|------|-----|\n" + "\n".join(rows)
        text = table_content

        chunker = MarkdownChunker(max_chunk_size=200, preserve_tables=True)
        chunks = chunker.chunk(text, "test")

        table_chunks = [c for c in chunks if c.element_type == MarkdownElementType.TABLE]

        assert len(table_chunks) > 1
        # Each fragment repeats the header and separator rows.
        for chunk in table_chunks:
            assert "| Name | Age |" in chunk.content
            assert "|------|-----|" in chunk.content

    def test_chunk_list_preserved(self):
        """Test that lists are chunked properly."""
        text = """- Item 1
- Item 2
- Item 3
- Item 4
- Item 5
"""
        chunker = MarkdownChunker(max_chunk_size=2000, preserve_lists=True)
        chunks = chunker.chunk(text, "test")

        list_chunks = [c for c in chunks if c.element_type == MarkdownElementType.LIST]

        assert len(list_chunks) == 1
        assert "Item 1" in list_chunks[0].content
        assert "Item 5" in list_chunks[0].content

    def test_chunk_empty_document(self):
        """Test chunking an empty document."""
        text = ""
        chunker = MarkdownChunker()
        chunks = chunker.chunk(text, "test")

        assert len(chunks) == 0

    def test_chunk_only_headers(self):
        """Test chunking a document with only headers."""
        text = """# Title 1
## Title 2
### Title 3
"""
        chunker = MarkdownChunker()
        chunks = chunker.chunk(text, "test")

        assert len(chunks) == 0


class TestChunkMarkdownFunction:
    """Tests for the convenience chunk_markdown function."""

    def test_basic_chunking(self):
        """Test basic chunking via convenience function."""
        text = """# Title

Content paragraph.

```python
code = "here"
```
"""
        chunks = chunk_markdown(text, "doc1")

        assert len(chunks) >= 1
        assert all("chunk_id" in chunk for chunk in chunks)
        assert all("content" in chunk for chunk in chunks)
        assert all("element_type" in chunk for chunk in chunks)
        assert all("header_context" in chunk for chunk in chunks)

    def test_custom_parameters(self):
        """Test chunking with custom parameters."""
        text = "A" * 2000

        chunks = chunk_markdown(
            text,
            "doc1",
            max_chunk_size=500,
            min_chunk_size=50,
            preserve_code_blocks=False,
            preserve_tables=False,
            preserve_lists=False,
            include_header_context=False,
        )

        assert len(chunks) >= 1


class TestMarkdownElement:
    """Tests for MarkdownElement dataclass."""

    def test_to_dict(self):
        """Test serialization to dictionary."""
        elem = MarkdownElement(
            type=MarkdownElementType.HEADER,
            content="Test Header",
            level=2,
            line_start=10,
            line_end=10,
            metadata={"level": 2},
        )

        result = elem.to_dict()

        assert result["type"] == "header"
        assert result["content"] == "Test Header"
        assert result["level"] == 2
        assert result["line_start"] == 10
        assert result["line_end"] == 10

    def test_code_block_with_language(self):
        """Test code block element with language."""
        elem = MarkdownElement(
            type=MarkdownElementType.CODE_BLOCK,
            content="print('hello')",
            language="python",
            line_start=5,
            line_end=7,
        )

        result = elem.to_dict()

        assert result["type"] == "code_block"
        assert result["language"] == "python"

    def test_table_with_metadata(self):
        """Test table element with metadata."""
        elem = MarkdownElement(
            type=MarkdownElementType.TABLE,
            content="| A | B |\n|---|---|\n| 1 | 2 |",
            line_start=1,
            line_end=3,
            metadata={"headers": ["A", "B"], "row_count": 1},
        )

        result = elem.to_dict()

        assert result["type"] == "table"
        assert result["metadata"]["headers"] == ["A", "B"]
        assert result["metadata"]["row_count"] == 1


class TestMarkdownChunk:
    """Tests for MarkdownChunk dataclass."""

    def test_to_dict(self):
        """Test serialization to dictionary."""
        chunk = MarkdownChunk(
            chunk_id="doc_chunk_0",
            content="Test content",
            element_type=MarkdownElementType.PARAGRAPH,
            header_context=["Main Title", "Section"],
            metadata={"key": "value"},
        )

        result = chunk.to_dict()

        assert result["chunk_id"] == "doc_chunk_0"
        assert result["content"] == "Test content"
        assert result["element_type"] == "paragraph"
        assert result["header_context"] == ["Main Title", "Section"]
        assert result["metadata"]["key"] == "value"

    def test_with_language(self):
        """Test chunk with language info."""
        chunk = MarkdownChunk(
            chunk_id="code_0",
            content="```python\nprint('hi')\n```",
            element_type=MarkdownElementType.CODE_BLOCK,
            header_context=[],
            language="python",
        )

        result = chunk.to_dict()

        assert result["language"] == "python"


class TestMarkdownElementType:
    """Tests for MarkdownElementType enum."""

    def test_all_types_exist(self):
        """Test that all expected element types exist."""
        expected_types = [
            "header",
            "paragraph",
            "code_block",
            "inline_code",
            "table",
            "list",
            "blockquote",
            "horizontal_rule",
            "image",
            "link",
            "text",
        ]

        # Accept either the upper-cased attribute name or a matching enum value.
        for type_name in expected_types:
            assert hasattr(MarkdownElementType, type_name.upper()) or \
                any(t.value == type_name for t in MarkdownElementType)
+""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from dataclasses import dataclass + +from app.services.metadata_auto_inference_service import ( + AutoInferenceResult, + InferenceFieldContext, + MetadataAutoInferenceService, +) + + +@dataclass +class MockFieldDefinition: + """Mock field definition for testing""" + field_key: str + label: str + type: str + required: bool + options: list[str] | None = None + + +class TestInferenceFieldContext: + """Test InferenceFieldContext dataclass.""" + + def test_creation(self): + """Test creating InferenceFieldContext.""" + ctx = InferenceFieldContext( + field_key="grade", + label="年级", + type="enum", + required=True, + options=["初一", "初二", "初三"], + ) + + assert ctx.field_key == "grade" + assert ctx.label == "年级" + assert ctx.type == "enum" + assert ctx.required is True + assert ctx.options == ["初一", "初二", "初三"] + + def test_default_values(self): + """Test default values.""" + ctx = InferenceFieldContext( + field_key="test", + label="Test", + type="text", + required=False, + ) + + assert ctx.options is None + assert ctx.description is None + + +class TestAutoInferenceResult: + """Test AutoInferenceResult dataclass.""" + + def test_success_result(self): + """Test successful inference result.""" + result = AutoInferenceResult( + inferred_metadata={"grade": "初一", "subject": "数学"}, + confidence_scores={"grade": 0.95, "subject": 0.85}, + raw_response='{"inferred_metadata": {...}}', + success=True, + ) + + assert result.success is True + assert result.error_message is None + assert result.inferred_metadata["grade"] == "初一" + + def test_failure_result(self): + """Test failed inference result.""" + result = AutoInferenceResult( + inferred_metadata={}, + confidence_scores={}, + raw_response="", + success=False, + error_message="JSON parse error", + ) + + assert result.success is False + assert result.error_message == "JSON parse error" + + +class TestMetadataAutoInferenceService: + """Test 
MetadataAutoInferenceService functionality.""" + + @pytest.fixture + def mock_session(self): + """Create mock session.""" + return AsyncMock() + + @pytest.fixture + def service(self, mock_session): + """Create service instance.""" + return MetadataAutoInferenceService(mock_session) + + def test_build_field_contexts(self, service): + """Test building field contexts from definitions.""" + fields = [ + MockFieldDefinition( + field_key="grade", + label="年级", + type="enum", + required=True, + options=["初一", "初二", "初三"], + ), + MockFieldDefinition( + field_key="subject", + label="学科", + type="enum", + required=True, + options=["语文", "数学", "英语"], + ), + ] + + contexts = service._build_field_contexts(fields) + + assert len(contexts) == 2 + assert contexts[0].field_key == "grade" + assert contexts[0].options == ["初一", "初二", "初三"] + assert contexts[1].field_key == "subject" + + def test_build_user_prompt(self, service): + """Test building user prompt.""" + field_contexts = [ + InferenceFieldContext( + field_key="grade", + label="年级", + type="enum", + required=True, + options=["初一", "初二", "初三"], + ), + ] + + prompt = service._build_user_prompt( + content="这是一道初一数学题", + field_contexts=field_contexts, + ) + + assert "年级" in prompt + assert "初一, 初二, 初三" in prompt + assert "这是一道初一数学题" in prompt + + def test_build_user_prompt_with_existing_metadata(self, service): + """Test building user prompt with existing metadata.""" + field_contexts = [ + InferenceFieldContext( + field_key="grade", + label="年级", + type="enum", + required=True, + options=["初一", "初二", "初三"], + ), + ] + + prompt = service._build_user_prompt( + content="这是一道数学题", + field_contexts=field_contexts, + existing_metadata={"grade": "初二"}, + ) + + assert "已有值: 初二" in prompt + + def test_extract_json_from_plain_json(self, service): + """Test extracting JSON from plain JSON response.""" + json_str = '{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}' + + result = service._extract_json(json_str) + 
+ assert result == json_str + + def test_extract_json_from_markdown(self, service): + """Test extracting JSON from markdown code block.""" + markdown = """Here is the result: + +```json +{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}} +``` +""" + + result = service._extract_json(markdown) + + assert "inferred_metadata" in result + assert "grade" in result + + def test_parse_llm_response_valid(self, service): + """Test parsing valid LLM response.""" + response = json.dumps({ + "inferred_metadata": { + "grade": "初一", + "subject": "数学", + }, + "confidence_scores": { + "grade": 0.95, + "subject": 0.85, + } + }) + + field_contexts = [ + InferenceFieldContext( + field_key="grade", + label="年级", + type="enum", + required=True, + options=["初一", "初二", "初三"], + ), + InferenceFieldContext( + field_key="subject", + label="学科", + type="enum", + required=True, + options=["语文", "数学", "英语"], + ), + ] + + result = service._parse_llm_response(response, field_contexts) + + assert result.success is True + assert result.inferred_metadata["grade"] == "初一" + assert result.inferred_metadata["subject"] == "数学" + assert result.confidence_scores["grade"] == 0.95 + + def test_parse_llm_response_invalid_option(self, service): + """Test parsing response with invalid enum option.""" + response = json.dumps({ + "inferred_metadata": { + "grade": "高一", # Not in options + }, + "confidence_scores": { + "grade": 0.90, + } + }) + + field_contexts = [ + InferenceFieldContext( + field_key="grade", + label="年级", + type="enum", + required=True, + options=["初一", "初二", "初三"], + ), + ] + + result = service._parse_llm_response(response, field_contexts) + + assert result.success is True + assert "grade" not in result.inferred_metadata + + def test_parse_llm_response_invalid_json(self, service): + """Test parsing invalid JSON response.""" + response = "This is not valid JSON" + + field_contexts = [ + InferenceFieldContext( + field_key="grade", + label="年级", + type="text", + 
required=False, + ), + ] + + result = service._parse_llm_response(response, field_contexts) + + assert result.success is False + assert "JSON parse error" in result.error_message + + def test_validate_field_value_text(self, service): + """Test validating text field value.""" + ctx = InferenceFieldContext( + field_key="title", + label="标题", + type="text", + required=False, + ) + + result = service._validate_field_value(ctx, "测试标题") + + assert result == "测试标题" + + def test_validate_field_value_number(self, service): + """Test validating number field value.""" + ctx = InferenceFieldContext( + field_key="count", + label="数量", + type="number", + required=False, + ) + + assert service._validate_field_value(ctx, 42) == 42 + assert service._validate_field_value(ctx, "3.14") == 3.14 + assert service._validate_field_value(ctx, "invalid") is None + + def test_validate_field_value_boolean(self, service): + """Test validating boolean field value.""" + ctx = InferenceFieldContext( + field_key="active", + label="是否激活", + type="boolean", + required=False, + ) + + assert service._validate_field_value(ctx, True) is True + assert service._validate_field_value(ctx, "true") is True + assert service._validate_field_value(ctx, "false") is False + assert service._validate_field_value(ctx, 1) is True + + def test_validate_field_value_enum(self, service): + """Test validating enum field value.""" + ctx = InferenceFieldContext( + field_key="grade", + label="年级", + type="enum", + required=False, + options=["初一", "初二", "初三"], + ) + + assert service._validate_field_value(ctx, "初一") == "初一" + assert service._validate_field_value(ctx, "高一") is None + + def test_validate_field_value_array_enum(self, service): + """Test validating array_enum field value.""" + ctx = InferenceFieldContext( + field_key="tags", + label="标签", + type="array_enum", + required=False, + options=["重点", "难点", "易错"], + ) + + result = service._validate_field_value(ctx, ["重点", "难点"]) + assert result == ["重点", "难点"] + + result = 
service._validate_field_value(ctx, ["重点", "不存在"]) + assert result == ["重点"] + + result = service._validate_field_value(ctx, "重点") + assert result == ["重点"] + + +class TestIntegrationScenarios: + """Test integration scenarios.""" + + @pytest.fixture + def mock_session(self): + """Create mock session.""" + return AsyncMock() + + @pytest.fixture + def service(self, mock_session): + """Create service instance.""" + return MetadataAutoInferenceService(mock_session) + + def test_education_scenario(self, service): + """Test education scenario with grade and subject.""" + response = json.dumps({ + "inferred_metadata": { + "grade": "初二", + "subject": "物理", + "type": "痛点", + }, + "confidence_scores": { + "grade": 0.95, + "subject": 0.90, + "type": 0.85, + } + }) + + field_contexts = [ + InferenceFieldContext( + field_key="grade", + label="年级", + type="enum", + required=True, + options=["初一", "初二", "初三"], + ), + InferenceFieldContext( + field_key="subject", + label="学科", + type="enum", + required=True, + options=["语文", "数学", "英语", "物理", "化学"], + ), + InferenceFieldContext( + field_key="type", + label="类型", + type="enum", + required=False, + options=["痛点", "重点", "难点"], + ), + ] + + result = service._parse_llm_response(response, field_contexts) + + assert result.success is True + assert result.inferred_metadata == { + "grade": "初二", + "subject": "物理", + "type": "痛点", + } + assert result.confidence_scores["grade"] == 0.95 + + def test_partial_inference(self, service): + """Test partial inference when some fields cannot be inferred.""" + response = json.dumps({ + "inferred_metadata": { + "grade": "初一", + }, + "confidence_scores": { + "grade": 0.90, + } + }) + + field_contexts = [ + InferenceFieldContext( + field_key="grade", + label="年级", + type="enum", + required=True, + options=["初一", "初二", "初三"], + ), + InferenceFieldContext( + field_key="subject", + label="学科", + type="enum", + required=True, + options=["语文", "数学", "英语"], + ), + ] + + result = 
service._parse_llm_response(response, field_contexts) + + assert result.success is True + assert "grade" in result.inferred_metadata + assert "subject" not in result.inferred_metadata + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/ai-service/tests/test_mid_dialogue_integration.py b/ai-service/tests/test_mid_dialogue_integration.py new file mode 100644 index 0000000..e24cf93 --- /dev/null +++ b/ai-service/tests/test_mid_dialogue_integration.py @@ -0,0 +1,881 @@ +""" +Mid Platform Dialogue Integration Test. +中台联调界面对话过程集成测试脚本 + +测试重点: +1. 意图置信度参数 +2. 执行模式(通用API vs ReAct模式) +3. ReAct模式下的工具调用(工具名称、入参、返回结果) +4. 知识库查询(是否命中、入参) +5. 各部分耗时 +6. 提示词模板使用情况 +""" + +import asyncio +import json +import logging +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi.testclient import TestClient +from httpx import AsyncClient + +from app.main import app +from app.models.mid.schemas import ( + DialogueRequest, + DialogueResponse, + ExecutionMode, + FeatureFlags, + HistoryMessage, + Segment, + TraceInfo, + ToolCallTrace, + ToolCallStatus, + ToolType, + IntentHintOutput, + HighRiskCheckResult, + HighRiskScenario, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class TimingRecord: + """耗时记录""" + stage: str + start_time: float + end_time: float + duration_ms: int + + def to_dict(self) -> dict: + return { + "stage": self.stage, + "duration_ms": self.duration_ms, + } + + +@dataclass +class DialogueTestResult: + """对话测试结果""" + request_id: str + user_message: str + response_text: str = "" + + timing_records: list[TimingRecord] = field(default_factory=list) + total_duration_ms: int = 0 + + execution_mode: ExecutionMode = ExecutionMode.AGENT + intent: str | None = None + confidence: float | None = None + + react_iterations: int = 0 + tools_used: list[str] = field(default_factory=list) + tool_calls: 
list[dict] = field(default_factory=list) + + kb_tool_called: bool = False + kb_hit: bool = False + kb_query: str | None = None + kb_filter: dict | None = None + kb_hits_count: int = 0 + + prompt_template_used: str | None = None + prompt_template_scene: str | None = None + + guardrail_triggered: bool = False + fallback_reason_code: str | None = None + + raw_trace: dict | None = None + + def to_summary(self) -> dict: + return { + "request_id": self.request_id, + "user_message": self.user_message[:100], + "execution_mode": self.execution_mode.value, + "intent": self.intent, + "confidence": self.confidence, + "total_duration_ms": self.total_duration_ms, + "timing_breakdown": [t.to_dict() for t in self.timing_records], + "react_iterations": self.react_iterations, + "tools_used": self.tools_used, + "tool_calls_count": len(self.tool_calls), + "kb_tool_called": self.kb_tool_called, + "kb_hit": self.kb_hit, + "kb_hits_count": self.kb_hits_count, + "prompt_template_used": self.prompt_template_used, + "guardrail_triggered": self.guardrail_triggered, + "fallback_reason_code": self.fallback_reason_code, + } + + +class DialogueIntegrationTester: + """对话集成测试器""" + + def __init__( + self, + base_url: str = "http://localhost:8000", + tenant_id: str = "test_tenant", + api_key: str | None = None, + ): + self.base_url = base_url + self.tenant_id = tenant_id + self.api_key = api_key + self.session_id = f"test_session_{uuid.uuid4().hex[:8]}" + self.user_id = f"test_user_{uuid.uuid4().hex[:8]}" + + def _get_headers(self) -> dict: + headers = { + "Content-Type": "application/json", + "X-Tenant-Id": self.tenant_id, + } + if self.api_key: + headers["X-API-Key"] = self.api_key + return headers + + async def send_dialogue( + self, + user_message: str, + history: list[dict] | None = None, + scene: str | None = None, + feature_flags: dict | None = None, + ) -> DialogueTestResult: + """发送对话请求并记录详细信息""" + request_id = str(uuid.uuid4()) + result = DialogueTestResult( + request_id=request_id, + 
user_message=user_message, + response_text="", + ) + + overall_start = time.time() + + request_body = { + "session_id": self.session_id, + "user_id": self.user_id, + "user_message": user_message, + "history": history or [], + } + + if scene: + request_body["scene"] = scene + if feature_flags: + request_body["feature_flags"] = feature_flags + + timing_records = [] + + try: + async with AsyncClient(base_url=self.base_url, timeout=120.0) as client: + request_start = time.time() + + response = await client.post( + "/mid/dialogue/respond", + json=request_body, + headers=self._get_headers(), + ) + + request_end = time.time() + timing_records.append(TimingRecord( + stage="http_request", + start_time=request_start, + end_time=request_end, + duration_ms=int((request_end - request_start) * 1000), + )) + + if response.status_code != 200: + result.response_text = f"Error: {response.status_code}" + result.fallback_reason_code = f"http_error_{response.status_code}" + return result + + response_data = response.json() + + except Exception as e: + result.response_text = f"Exception: {str(e)}" + result.fallback_reason_code = "request_exception" + result.total_duration_ms = int((time.time() - overall_start) * 1000) + return result + + overall_end = time.time() + result.total_duration_ms = int((overall_end - overall_start) * 1000) + result.timing_records = timing_records + + try: + dialogue_response = DialogueResponse(**response_data) + + if dialogue_response.segments: + result.response_text = "\n".join(s.text for s in dialogue_response.segments) + + trace = dialogue_response.trace + result.raw_trace = trace.model_dump() if trace else None + + if trace: + result.execution_mode = trace.mode + result.intent = trace.intent + result.react_iterations = trace.react_iterations or 0 + result.tools_used = trace.tools_used or [] + result.kb_tool_called = trace.kb_tool_called or False + result.kb_hit = trace.kb_hit or False + result.guardrail_triggered = trace.guardrail_triggered or False + 
result.fallback_reason_code = trace.fallback_reason_code + + if trace.tool_calls: + result.tool_calls = [tc.model_dump() for tc in trace.tool_calls] + + for tc in trace.tool_calls: + if tc.tool_name == "kb_search_dynamic": + result.kb_tool_called = True + if tc.arguments: + result.kb_query = tc.arguments.get("query") + result.kb_filter = tc.arguments.get("context") + if tc.result and isinstance(tc.result, dict): + result.kb_hits_count = len(tc.result.get("hits", [])) + result.kb_hit = result.kb_hits_count > 0 + + if trace.duration_ms: + timing_records.append(TimingRecord( + stage="server_processing", + start_time=overall_start, + end_time=overall_end, + duration_ms=trace.duration_ms, + )) + + if trace.scene: + result.prompt_template_scene = trace.scene + + except Exception as e: + logger.error(f"Failed to parse response: {e}") + result.response_text = str(response_data) + + return result + + def print_result(self, result: DialogueTestResult): + """打印测试结果""" + print("\n" + "=" * 80) + print(f"[对话测试结果] Request ID: {result.request_id}") + print("=" * 80) + + print(f"\n[用户消息] {result.user_message}") + print(f"[回复内容] {result.response_text[:200]}...") + + print(f"\n[执行模式] {result.execution_mode.value}") + print(f"[意图识别] intent={result.intent}, confidence={result.confidence}") + + print(f"\n[耗时统计] 总耗时: {result.total_duration_ms}ms") + for tr in result.timing_records: + print(f" - {tr.stage}: {tr.duration_ms}ms") + + if result.execution_mode == ExecutionMode.AGENT: + print(f"\n[ReAct模式]") + print(f" - 迭代次数: {result.react_iterations}") + print(f" - 使用的工具: {result.tools_used}") + + if result.tool_calls: + print(f"\n[工具调用详情]") + for i, tc in enumerate(result.tool_calls, 1): + print(f" [{i}] 工具: {tc.get('tool_name')}") + print(f" 状态: {tc.get('status')}") + print(f" 耗时: {tc.get('duration_ms')}ms") + if tc.get('arguments'): + print(f" 入参: {json.dumps(tc.get('arguments'), ensure_ascii=False)[:200]}") + if tc.get('result'): + result_str = str(tc.get('result'))[:300] + print(f" 结果: 
{result_str}") + + print(f"\n[知识库查询]") + print(f" - 是否调用: {result.kb_tool_called}") + print(f" - 是否命中: {result.kb_hit}") + if result.kb_query: + print(f" - 查询内容: {result.kb_query}") + if result.kb_filter: + print(f" - 过滤条件: {json.dumps(result.kb_filter, ensure_ascii=False)[:200]}") + print(f" - 命中数量: {result.kb_hits_count}") + + print(f"\n[提示词模板]") + print(f" - 场景: {result.prompt_template_scene}") + print(f" - 使用模板: {result.prompt_template_used or '默认模板'}") + + print(f"\n[其他信息]") + print(f" - 护栏触发: {result.guardrail_triggered}") + print(f" - 降级原因: {result.fallback_reason_code or '无'}") + + print("\n" + "=" * 80) + + +class TestMidDialogueIntegration: + """中台对话集成测试""" + + @pytest.fixture + def tester(self): + return DialogueIntegrationTester( + base_url="http://localhost:8000", + tenant_id="test_tenant", + ) + + @pytest.fixture + def mock_llm_client(self): + """模拟 LLM 客户端""" + mock = AsyncMock() + mock.generate = AsyncMock(return_value=MagicMock( + content="这是测试回复", + has_tool_calls=False, + tool_calls=[], + )) + return mock + + @pytest.fixture + def mock_kb_tool(self): + """模拟知识库工具""" + mock = AsyncMock() + mock.execute = AsyncMock(return_value=MagicMock( + success=True, + hits=[ + {"id": "1", "content": "测试知识库内容", "score": 0.9}, + ], + applied_filter={"scene": "test"}, + missing_required_slots=[], + fallback_reason_code=None, + duration_ms=100, + tool_trace=None, + )) + return mock + + @pytest.mark.asyncio + async def test_simple_greeting(self, tester: DialogueIntegrationTester): + """测试简单问候""" + result = await tester.send_dialogue( + user_message="你好", + ) + + tester.print_result(result) + + assert result.request_id is not None + assert result.total_duration_ms > 0 + + @pytest.mark.asyncio + async def test_kb_query(self, tester: DialogueIntegrationTester): + """测试知识库查询""" + result = await tester.send_dialogue( + user_message="退款流程是什么?", + scene="after_sale", + ) + + tester.print_result(result) + + assert result.request_id is not None + + @pytest.mark.asyncio + 
async def test_high_risk_scenario(self, tester: DialogueIntegrationTester): + """测试高风险场景""" + result = await tester.send_dialogue( + user_message="我要投诉你们的服务", + ) + + tester.print_result(result) + + assert result.request_id is not None + + @pytest.mark.asyncio + async def test_transfer_request(self, tester: DialogueIntegrationTester): + """测试转人工请求""" + result = await tester.send_dialogue( + user_message="帮我转人工客服", + ) + + tester.print_result(result) + + assert result.request_id is not None + + @pytest.mark.asyncio + async def test_with_history(self, tester: DialogueIntegrationTester): + """测试带历史记录的对话""" + result = await tester.send_dialogue( + user_message="那退款要多久呢?", + history=[ + {"role": "user", "content": "我想退款"}, + {"role": "assistant", "content": "好的,请问您要退款的订单号是多少?"}, + ], + ) + + tester.print_result(result) + + assert result.request_id is not None + + +class TestDialogueWithMock: + """使用 Mock 的对话测试""" + + @pytest.fixture + def mock_app(self): + """创建带 Mock 的测试应用""" + from fastapi import FastAPI + from app.api.mid.dialogue import router + + app = FastAPI() + app.include_router(router) + return app + + @pytest.fixture + def client(self, mock_app): + return TestClient(mock_app) + + @pytest.fixture + def mock_session(self): + """模拟数据库会话""" + mock = AsyncMock() + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] + mock.execute.return_value = mock_result + return mock + + @pytest.fixture + def mock_llm(self): + """模拟 LLM 响应""" + mock_response = MagicMock() + mock_response.content = "这是测试回复内容" + mock_response.has_tool_calls = False + mock_response.tool_calls = [] + return mock_response + + def test_dialogue_request_structure(self, client: TestClient): + """测试对话请求结构""" + request_body = { + "session_id": "test_session_001", + "user_id": "test_user_001", + "user_message": "你好", + "history": [], + "scene": "open_consult", + } + + print("\n[测试请求结构]") + print(f"Request Body: {json.dumps(request_body, ensure_ascii=False, indent=2)}") + + 
assert request_body["session_id"] == "test_session_001" + assert request_body["user_message"] == "你好" + + def test_trace_info_structure(self): + """测试追踪信息结构""" + trace = TraceInfo( + mode=ExecutionMode.AGENT, + intent="greeting", + request_id=str(uuid.uuid4()), + generation_id=str(uuid.uuid4()), + kb_tool_called=True, + kb_hit=True, + react_iterations=2, + tools_used=["kb_search_dynamic"], + tool_calls=[ + ToolCallTrace( + tool_name="kb_search_dynamic", + tool_type=ToolType.INTERNAL, + duration_ms=150, + status=ToolCallStatus.OK, + arguments={"query": "测试查询", "scene": "test"}, + result={"hits": [{"content": "测试内容"}]}, + ), + ], + ) + + print("\n[TraceInfo 结构测试]") + print(f"Mode: {trace.mode.value}") + print(f"Intent: {trace.intent}") + print(f"KB Tool Called: {trace.kb_tool_called}") + print(f"KB Hit: {trace.kb_hit}") + print(f"React Iterations: {trace.react_iterations}") + print(f"Tools Used: {trace.tools_used}") + + if trace.tool_calls: + print(f"\n[Tool Calls]") + for tc in trace.tool_calls: + print(f" - Tool: {tc.tool_name}") + print(f" Status: {tc.status.value}") + print(f" Duration: {tc.duration_ms}ms") + if tc.arguments: + print(f" Arguments: {json.dumps(tc.arguments, ensure_ascii=False)}") + + assert trace.mode == ExecutionMode.AGENT + assert trace.kb_tool_called is True + assert len(trace.tool_calls) == 1 + + def test_intent_hint_output_structure(self): + """测试意图提示输出结构""" + hint = IntentHintOutput( + intent="refund", + confidence=0.85, + response_type="flow", + suggested_mode=ExecutionMode.MICRO_FLOW, + target_flow_id="flow_refund_001", + high_risk_detected=False, + duration_ms=50, + ) + + print("\n[IntentHintOutput 结构测试]") + print(f"Intent: {hint.intent}") + print(f"Confidence: {hint.confidence}") + print(f"Response Type: {hint.response_type}") + print(f"Suggested Mode: {hint.suggested_mode.value if hint.suggested_mode else None}") + print(f"High Risk Detected: {hint.high_risk_detected}") + print(f"Duration: {hint.duration_ms}ms") + + assert hint.intent 
== "refund" + assert hint.confidence == 0.85 + assert hint.suggested_mode == ExecutionMode.MICRO_FLOW + + def test_high_risk_check_result_structure(self): + """测试高风险检测结果结构""" + result = HighRiskCheckResult( + matched=True, + risk_scenario=HighRiskScenario.REFUND, + confidence=0.95, + recommended_mode=ExecutionMode.MICRO_FLOW, + rule_id="rule_refund_001", + reason="检测到退款关键词", + duration_ms=30, + ) + + print("\n[HighRiskCheckResult 结构测试]") + print(f"Matched: {result.matched}") + print(f"Risk Scenario: {result.risk_scenario.value if result.risk_scenario else None}") + print(f"Confidence: {result.confidence}") + print(f"Recommended Mode: {result.recommended_mode.value if result.recommended_mode else None}") + print(f"Rule ID: {result.rule_id}") + print(f"Duration: {result.duration_ms}ms") + + assert result.matched is True + assert result.risk_scenario == HighRiskScenario.REFUND + + +class TestPromptTemplateUsage: + """提示词模板使用测试""" + + def test_template_resolution(self): + """测试模板解析""" + from app.services.prompt.variable_resolver import VariableResolver + + resolver = VariableResolver() + + template = "你好,{{user_name}}!我是{{bot_name}},很高兴为您服务。" + variables = [ + {"key": "user_name", "value": "张三"}, + {"key": "bot_name", "value": "智能客服"}, + ] + + resolved = resolver.resolve(template, variables) + + print("\n[模板解析测试]") + print(f"原始模板: {template}") + print(f"变量: {json.dumps(variables, ensure_ascii=False)}") + print(f"解析结果: {resolved}") + + assert resolved == "你好,张三!我是智能客服,很高兴为您服务。" + + def test_template_with_extra_context(self): + """测试带额外上下文的模板解析""" + from app.services.prompt.variable_resolver import VariableResolver + + resolver = VariableResolver() + + template = "当前场景:{{scene}},用户问题:{{query}}" + extra_context = { + "scene": "售后服务", + "query": "退款流程", + } + + resolved = resolver.resolve(template, [], extra_context) + + print("\n[带上下文的模板解析测试]") + print(f"原始模板: {template}") + print(f"额外上下文: {json.dumps(extra_context, ensure_ascii=False)}") + print(f"解析结果: {resolved}") + + 
# NOTE(review): the flattened source cuts into the middle of
# TestPromptTemplateUsage.test_template_with_extra_context here; its
# surviving trailing assertions were:
#     assert "售后服务" in resolved
#     assert "退款流程" in resolved


class TestToolCallRecording:
    """Tool-call trace recording tests."""

    def test_tool_call_trace_creation(self):
        """A successful kb_search_dynamic call is fully recorded in the trace."""
        trace = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=150,
            status=ToolCallStatus.OK,
            args_digest="query=退款流程",
            result_digest="hits=3",
            arguments={
                "query": "退款流程是什么",
                "scene": "after_sale",
                "context": {"product_type": "vip"},
            },
            result={
                "success": True,
                "hits": [
                    {"id": "1", "content": "退款流程说明...", "score": 0.95},
                    {"id": "2", "content": "退款注意事项...", "score": 0.88},
                    {"id": "3", "content": "退款时效说明...", "score": 0.82},
                ],
                "applied_filter": {"product_type": "vip"},
            },
        )

        print("\n[工具调用追踪测试]")
        print(f"Tool Name: {trace.tool_name}")
        print(f"Tool Type: {trace.tool_type.value}")
        print(f"Status: {trace.status.value}")
        print(f"Duration: {trace.duration_ms}ms")
        print("\n[入参详情]")
        if trace.arguments:
            for key, value in trace.arguments.items():
                print(f" - {key}: {value}")
        print("\n[返回结果]")
        if trace.result:
            if isinstance(trace.result, dict):
                print(f" - success: {trace.result.get('success')}")
                print(f" - hits count: {len(trace.result.get('hits', []))}")
                for i, hit in enumerate(trace.result.get('hits', [])[:2], 1):
                    print(f" - hit[{i}]: score={hit.get('score')}, content={hit.get('content')[:30]}...")

        assert trace.tool_name == "kb_search_dynamic"
        assert trace.status == ToolCallStatus.OK
        assert trace.arguments is not None
        assert trace.result is not None

    def test_tool_call_timeout_trace(self):
        """A timed-out tool call records the TIMEOUT status and error code."""
        trace = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=2000,
            status=ToolCallStatus.TIMEOUT,
            error_code="TOOL_TIMEOUT",
            arguments={"query": "测试查询"},
        )

        print("\n[工具调用超时追踪测试]")
        print(f"Tool Name: {trace.tool_name}")
        print(f"Status: {trace.status.value}")
        print(f"Error Code: {trace.error_code}")
        print(f"Duration: {trace.duration_ms}ms")

        assert trace.status == ToolCallStatus.TIMEOUT
        assert trace.error_code == "TOOL_TIMEOUT"


class TestExecutionModeRouting:
    """Execution-mode routing tests."""

    def test_policy_router_decision(self):
        """Table-driven check of PolicyRouter mode decisions."""
        # NOTE(review): the original also imported IntentMatch here; it was
        # unused in this test and has been dropped.
        from app.services.mid.policy_router import PolicyRouter

        router = PolicyRouter()

        test_cases = [
            {
                "name": "正常对话 -> Agent模式",
                "user_message": "你好,请问有什么可以帮助我的?",
                "session_mode": "BOT_ACTIVE",
                "expected_mode": ExecutionMode.AGENT,
            },
            {
                "name": "高风险退款 -> Micro Flow模式",
                "user_message": "我要退款",
                "session_mode": "BOT_ACTIVE",
                "expected_mode": ExecutionMode.MICRO_FLOW,
            },
            {
                "name": "转人工请求 -> Transfer模式",
                "user_message": "帮我转人工",
                "session_mode": "BOT_ACTIVE",
                "expected_mode": ExecutionMode.TRANSFER,
            },
            {
                "name": "人工模式 -> Transfer模式",
                "user_message": "你好",
                "session_mode": "HUMAN_ACTIVE",
                "expected_mode": ExecutionMode.TRANSFER,
            },
        ]

        print("\n[策略路由器决策测试]")

        for tc in test_cases:
            result = router.route(
                user_message=tc["user_message"],
                session_mode=tc["session_mode"],
            )

            print(f"\n测试用例: {tc['name']}")
            print(f" 用户消息: {tc['user_message']}")
            print(f" 会话模式: {tc['session_mode']}")
            print(f" 期望模式: {tc['expected_mode'].value}")
            print(f" 实际模式: {result.mode.value}")

            assert result.mode == tc["expected_mode"], f"模式不匹配: {result.mode} != {tc['expected_mode']}"


class TestKBSearchDynamic:
    """Knowledge-base dynamic retrieval tests."""

    def test_kb_search_result_structure(self):
        """A successful retrieval result exposes hits, filters and timing."""
        from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult

        result = KbSearchDynamicResult(
            success=True,
            hits=[
                {
                    "id": "chunk_001",
                    "content": "退款流程:1. 登录账户 2. 进入订单页面 3. 点击退款按钮...",
                    "score": 0.92,
                    "metadata": {"kb_id": "kb_001", "doc_id": "doc_001"},
                },
                {
                    "id": "chunk_002",
                    "content": "退款时效:一般3-5个工作日到账...",
                    "score": 0.85,
                    "metadata": {"kb_id": "kb_001", "doc_id": "doc_002"},
                },
            ],
            applied_filter={"scene": "after_sale", "product_type": "vip"},
            missing_required_slots=[],
            filter_debug={"source": "slot_state"},
            filter_sources={"scene": "slot", "product_type": "context"},
            duration_ms=120,
        )

        print("\n[知识库检索结果结构测试]")
        print(f"Success: {result.success}")
        print(f"Hits Count: {len(result.hits)}")
        print(f"Applied Filter: {json.dumps(result.applied_filter, ensure_ascii=False)}")
        print(f"Filter Sources: {json.dumps(result.filter_sources, ensure_ascii=False)}")
        print(f"Duration: {result.duration_ms}ms")

        print("\n[命中详情]")
        for i, hit in enumerate(result.hits, 1):
            print(f" [{i}] ID: {hit['id']}")
            print(f" Score: {hit['score']}")
            print(f" Content: {hit['content'][:50]}...")

        assert result.success is True
        assert len(result.hits) == 2
        assert result.applied_filter is not None

    def test_kb_search_missing_slots(self):
        """A retrieval with missing required slots fails with a reason code."""
        from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult

        result = KbSearchDynamicResult(
            success=False,
            hits=[],
            applied_filter={},
            missing_required_slots=[
                {
                    "field_key": "order_id",
                    "label": "订单号",
                    "reason": "必填字段缺失",
                    "ask_back_prompt": "请提供您的订单号",
                },
            ],
            filter_debug={"source": "builder"},
            fallback_reason_code="MISSING_REQUIRED_SLOTS",
            duration_ms=50,
        )

        print("\n[知识库检索缺失槽位测试]")
        print(f"Success: {result.success}")
        print(f"Fallback Reason: {result.fallback_reason_code}")
        print(f"Missing Slots: {json.dumps(result.missing_required_slots, ensure_ascii=False, indent=2)}")

        assert result.success is False
        assert result.fallback_reason_code == "MISSING_REQUIRED_SLOTS"
        assert len(result.missing_required_slots) == 1

# NOTE(review): the flattened source ends this region mid-way into the
# TestTimingBreakdown class header / first `def`; that class is
# reconstructed in the next flattened line's edit.
test_timing_breakdown_structure(self): + """测试耗时分解结构""" + timings = [ + TimingRecord("intent_matching", 0, 0.05, 50), + TimingRecord("high_risk_check", 0.05, 0.08, 30), + TimingRecord("kb_search", 0.08, 0.2, 120), + TimingRecord("llm_generation", 0.2, 1.5, 1300), + TimingRecord("output_guardrail", 1.5, 1.55, 50), + TimingRecord("response_formatting", 1.55, 1.6, 50), + ] + + total = sum(t.duration_ms for t in timings) + + print("\n[耗时分解测试]") + print(f"{'阶段':<25} {'耗时(ms)':<10} {'占比':<10}") + print("-" * 45) + + for t in timings: + percentage = (t.duration_ms / total * 100) if total > 0 else 0 + print(f"{t.stage:<25} {t.duration_ms:<10} {percentage:.1f}%") + + print("-" * 45) + print(f"{'总计':<25} {total:<10} {'100.0%':<10}") + + assert total == 1600 + + +def run_manual_test(): + """手动运行测试""" + import argparse + + parser = argparse.ArgumentParser(description="中台对话集成测试") + parser.add_argument("--url", default="http://localhost:8000", help="服务地址") + parser.add_argument("--tenant", default="test_tenant", help="租户ID") + parser.add_argument("--api-key", default=None, help="API Key") + parser.add_argument("--message", default="你好", help="测试消息") + parser.add_argument("--scene", default=None, help="场景标识") + parser.add_argument("--interactive", action="store_true", help="交互模式") + + args = parser.parse_args() + + tester = DialogueIntegrationTester( + base_url=args.url, + tenant_id=args.tenant, + api_key=args.api_key, + ) + + if args.interactive: + print("\n=== 中台对话集成测试 - 交互模式 ===") + print("输入 'quit' 退出\n") + + while True: + try: + message = input("请输入消息: ").strip() + if message.lower() == "quit": + break + + scene = input("请输入场景(可选,直接回车跳过): ").strip() or None + + result = asyncio.run(tester.send_dialogue( + user_message=message, + scene=scene, + )) + + tester.print_result(result) + + except KeyboardInterrupt: + break + else: + result = asyncio.run(tester.send_dialogue( + user_message=args.message, + scene=args.scene, + )) + + tester.print_result(result) + + +if __name__ == 
"""
Unit tests for Retrieval Strategy Service.
[AC-AISVC-RES-01~15] Tests for strategy management, switching, validation, and rollback.
"""

import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime

from app.schemas.retrieval_strategy import (
    ReactMode,
    RolloutConfig,
    RolloutMode,
    StrategyType,
    RetrievalStrategyStatus,
    RetrievalStrategySwitchRequest,
    RetrievalStrategyValidationRequest,
    ValidationResult,
)
from app.services.retrieval.strategy_service import (
    RetrievalStrategyService,
    StrategyState,
    get_strategy_service,
)
from app.services.retrieval.strategy_audit import (
    StrategyAuditService,
    get_audit_service,
)
from app.services.retrieval.strategy_metrics import (
    StrategyMetricsService,
    get_metrics_service,
)


class TestRetrievalStrategySchemas:
    """[AC-AISVC-RES-01~15] Schema-level contracts for rollout/switch models."""

    def test_rollout_config_off_mode(self):
        """[AC-AISVC-RES-03] OFF mode needs neither percentage nor allowlist."""
        cfg = RolloutConfig(mode=RolloutMode.OFF)
        assert cfg.mode == RolloutMode.OFF
        assert cfg.percentage is None
        assert cfg.allowlist is None

    def test_rollout_config_percentage_mode(self):
        """[AC-AISVC-RES-03] PERCENTAGE mode carries its percentage value."""
        cfg = RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0)
        assert (cfg.mode, cfg.percentage) == (RolloutMode.PERCENTAGE, 50.0)

    def test_rollout_config_percentage_mode_missing_value(self):
        """[AC-AISVC-RES-03] Omitting the percentage must be rejected."""
        with pytest.raises(ValueError, match="percentage is required"):
            RolloutConfig(mode=RolloutMode.PERCENTAGE)

    def test_rollout_config_allowlist_mode(self):
        """[AC-AISVC-RES-03] ALLOWLIST mode carries its tenant list."""
        cfg = RolloutConfig(mode=RolloutMode.ALLOWLIST, allowlist=["tenant1", "tenant2"])
        assert cfg.mode == RolloutMode.ALLOWLIST
        assert cfg.allowlist == ["tenant1", "tenant2"]

    def test_rollout_config_allowlist_mode_missing_value(self):
        """[AC-AISVC-RES-03] Omitting the allowlist must be rejected."""
        with pytest.raises(ValueError, match="allowlist is required"):
            RolloutConfig(mode=RolloutMode.ALLOWLIST)

    def test_retrieval_strategy_status(self):
        """[AC-AISVC-RES-01] Status exposes strategy, react mode and rollout."""
        status = RetrievalStrategyStatus(
            active_strategy=StrategyType.DEFAULT,
            react_mode=ReactMode.NON_REACT,
            rollout=RolloutConfig(mode=RolloutMode.OFF),
        )
        assert status.active_strategy == StrategyType.DEFAULT
        assert status.react_mode == ReactMode.NON_REACT
        assert status.rollout.mode == RolloutMode.OFF

    def test_switch_request_minimal(self):
        """[AC-AISVC-RES-02] Only the target strategy is mandatory."""
        req = RetrievalStrategySwitchRequest(target_strategy=StrategyType.ENHANCED)
        assert req.target_strategy == StrategyType.ENHANCED
        for optional_field in (req.react_mode, req.rollout, req.reason):
            assert optional_field is None

    def test_switch_request_full(self):
        """[AC-AISVC-RES-02,03,05] All optional fields round-trip unchanged."""
        req = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
            rollout=RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=30.0),
            reason="Testing enhanced strategy",
        )
        assert req.target_strategy == StrategyType.ENHANCED
        assert req.react_mode == ReactMode.REACT
        assert req.rollout.percentage == 30.0
        assert req.reason == "Testing enhanced strategy"
class TestRetrievalStrategyService:
    """[AC-AISVC-RES-01~15] Behavior of the in-memory strategy service."""

    @pytest.fixture
    def service(self):
        """Fresh, unshared service per test."""
        return RetrievalStrategyService()

    def test_get_current_status_default(self, service):
        """[AC-AISVC-RES-01] Out of the box: default strategy, non-react, rollout off."""
        status = service.get_current_status()
        assert (
            status.active_strategy,
            status.react_mode,
            status.rollout.mode,
        ) == (StrategyType.DEFAULT, ReactMode.NON_REACT, RolloutMode.OFF)

    def test_switch_strategy_to_enhanced(self, service):
        """[AC-AISVC-RES-02] Switching records previous state and applies the new one."""
        resp = service.switch_strategy(
            RetrievalStrategySwitchRequest(
                target_strategy=StrategyType.ENHANCED,
                react_mode=ReactMode.REACT,
            )
        )
        assert resp.previous.active_strategy == StrategyType.DEFAULT
        assert resp.current.active_strategy == StrategyType.ENHANCED
        assert resp.current.react_mode == ReactMode.REACT

    def test_switch_strategy_with_grayscale_percentage(self, service):
        """[AC-AISVC-RES-03] Percentage rollout is stored on the new state."""
        resp = service.switch_strategy(
            RetrievalStrategySwitchRequest(
                target_strategy=StrategyType.ENHANCED,
                rollout=RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0),
            )
        )
        assert resp.current.active_strategy == StrategyType.ENHANCED
        assert resp.current.rollout.mode == RolloutMode.PERCENTAGE
        assert resp.current.rollout.percentage == 50.0

    def test_switch_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] Allowlist rollout is stored on the new state."""
        resp = service.switch_strategy(
            RetrievalStrategySwitchRequest(
                target_strategy=StrategyType.ENHANCED,
                rollout=RolloutConfig(
                    mode=RolloutMode.ALLOWLIST,
                    allowlist=["tenant_a", "tenant_b"],
                ),
            )
        )
        assert resp.current.rollout.mode == RolloutMode.ALLOWLIST
        assert "tenant_a" in resp.current.rollout.allowlist

    def test_rollback_strategy(self, service):
        """[AC-AISVC-RES-07] Rollback restores the pre-switch configuration."""
        service.switch_strategy(
            RetrievalStrategySwitchRequest(
                target_strategy=StrategyType.ENHANCED,
                react_mode=ReactMode.REACT,
            )
        )
        resp = service.rollback_strategy()
        assert resp.rollback_to.active_strategy == StrategyType.DEFAULT
        assert resp.rollback_to.react_mode == ReactMode.NON_REACT

    def test_rollback_without_previous_returns_default(self, service):
        """[AC-AISVC-RES-07] With no switch history, rollback lands on default."""
        resp = service.rollback_strategy()
        assert resp.rollback_to.active_strategy == StrategyType.DEFAULT

    def test_should_use_enhanced_strategy_default(self, service):
        """[AC-AISVC-RES-01] Default configuration never routes to enhanced."""
        assert service.should_use_enhanced_strategy("tenant_a") is False

    def test_should_use_enhanced_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] Only allowlisted tenants get the enhanced strategy."""
        service.switch_strategy(
            RetrievalStrategySwitchRequest(
                target_strategy=StrategyType.ENHANCED,
                rollout=RolloutConfig(mode=RolloutMode.ALLOWLIST, allowlist=["tenant_a"]),
            )
        )
        assert service.should_use_enhanced_strategy("tenant_a") is True
        assert service.should_use_enhanced_strategy("tenant_b") is False

    def test_get_route_mode_react(self, service):
        """[AC-AISVC-RES-10] Forced react mode always yields the react route."""
        service.switch_strategy(
            RetrievalStrategySwitchRequest(
                target_strategy=StrategyType.ENHANCED,
                react_mode=ReactMode.REACT,
            )
        )
        assert service.get_route_mode("test query") == "react"

    def test_get_route_mode_direct(self, service):
        """[AC-AISVC-RES-09] Forced non-react mode always yields the direct route."""
        service.switch_strategy(
            RetrievalStrategySwitchRequest(
                target_strategy=StrategyType.DEFAULT,
                react_mode=ReactMode.NON_REACT,
            )
        )
        assert service.get_route_mode("test query") == "direct"

    def test_get_route_mode_auto_short_query(self, service):
        """[AC-AISVC-RES-12] Auto-routing: a short, confident query goes direct."""
        # Clearing react_mode forces the auto-routing code path.
        service._state.react_mode = None
        assert service._auto_route("短问题", confidence=0.8) == "direct"

    def test_get_route_mode_auto_multiple_conditions(self, service):
        """[AC-AISVC-RES-13] Auto-routing: a compound query goes react."""
        assert service._auto_route("查询订单状态和物流信息") == "react"

    def test_get_route_mode_auto_low_confidence(self, service):
        """[AC-AISVC-RES-13] Auto-routing: low confidence goes react."""
        assert service._auto_route("test query", confidence=0.3) == "react"

    def test_get_switch_history(self, service):
        """Every switch is appended to the history log."""
        service.switch_strategy(
            RetrievalStrategySwitchRequest(
                target_strategy=StrategyType.ENHANCED,
                reason="Testing",
            )
        )
        history = service.get_switch_history()
        assert len(history) == 1
        assert history[0]["to_strategy"] == "enhanced"
class TestRetrievalStrategyValidation:
    """[AC-AISVC-RES-04,06,08] Validation checks run before/after a switch."""

    @pytest.fixture
    def service(self):
        return RetrievalStrategyService()

    def test_validate_default_strategy(self, service):
        """[AC-AISVC-RES-06] The default strategy always validates cleanly."""
        resp = service.validate_strategy(
            RetrievalStrategyValidationRequest(strategy=StrategyType.DEFAULT)
        )
        assert resp.passed is True

    def test_validate_enhanced_strategy(self, service):
        """[AC-AISVC-RES-06] Enhanced validation returns a verdict plus results."""
        resp = service.validate_strategy(
            RetrievalStrategyValidationRequest(strategy=StrategyType.ENHANCED)
        )
        assert isinstance(resp.passed, bool)
        assert len(resp.results) > 0

    def test_validate_specific_checks(self, service):
        """[AC-AISVC-RES-06] All requested checks appear in the result set."""
        resp = service.validate_strategy(
            RetrievalStrategyValidationRequest(
                strategy=StrategyType.ENHANCED,
                checks=["metadata_consistency", "performance_budget"],
            )
        )
        executed = {r.check for r in resp.results}
        assert "metadata_consistency" in executed
        assert "performance_budget" in executed

    def test_check_metadata_consistency(self, service):
        """[AC-AISVC-RES-04] Metadata consistency passes for the default strategy."""
        outcome = service._check_metadata_consistency(StrategyType.DEFAULT)
        assert outcome.check == "metadata_consistency"
        assert outcome.passed is True

    def test_check_rrf_config(self, service):
        """[AC-AISVC-RES-02] RRF config check reports a boolean verdict."""
        outcome = service._check_rrf_config(StrategyType.DEFAULT)
        assert outcome.check == "rrf_config"
        assert isinstance(outcome.passed, bool)

    def test_check_performance_budget(self, service):
        """[AC-AISVC-RES-08] Performance budget check reports a boolean verdict."""
        outcome = service._check_performance_budget(
            StrategyType.ENHANCED,
            ReactMode.REACT,
        )
        assert outcome.check == "performance_budget"
        assert isinstance(outcome.passed, bool)


class TestStrategyAuditService:
    """[AC-AISVC-RES-07] Audit-trail recording for strategy operations."""

    @pytest.fixture
    def audit_service(self):
        return StrategyAuditService(max_entries=100)

    def test_log_switch_operation(self, audit_service):
        """[AC-AISVC-RES-07] A switch entry keeps both old and new strategy."""
        audit_service.log(
            operation="switch",
            previous_strategy="default",
            new_strategy="enhanced",
            reason="Testing",
            operator="admin",
        )
        entries = audit_service.get_audit_log()
        assert len(entries) == 1
        entry = entries[0]
        assert entry.operation == "switch"
        assert entry.previous_strategy == "default"
        assert entry.new_strategy == "enhanced"

    def test_log_rollback_operation(self, audit_service):
        """[AC-AISVC-RES-07] Rollbacks are retrievable via the operation filter."""
        audit_service.log_rollback(
            previous_strategy="enhanced",
            new_strategy="default",
            reason="Performance issue",
            operator="admin",
        )
        entries = audit_service.get_audit_log(operation="rollback")
        assert len(entries) == 1
        assert entries[0].operation == "rollback"

    def test_log_validation_operation(self, audit_service):
        """[AC-AISVC-RES-06] Validation runs are logged under 'validate'."""
        audit_service.log_validation(
            strategy="enhanced",
            checks=["metadata_consistency"],
            passed=True,
        )
        entries = audit_service.get_audit_log(operation="validate")
        assert len(entries) == 1
        assert entries[0].operation == "validate"

    def test_get_audit_log_with_limit(self, audit_service):
        """The limit argument caps how many entries come back."""
        for idx in range(10):
            audit_service.log(operation="switch", new_strategy=f"strategy_{idx}")
        assert len(audit_service.get_audit_log(limit=5)) == 5

    def test_get_audit_stats(self, audit_service):
        """Stats aggregate the total and per-operation counts."""
        audit_service.log(operation="switch", new_strategy="enhanced")
        audit_service.log(operation="rollback", new_strategy="default")
        stats = audit_service.get_audit_stats()
        assert stats["total_entries"] == 2
        assert stats["operation_counts"]["switch"] == 1
        assert stats["operation_counts"]["rollback"] == 1

    def test_clear_audit_log(self, audit_service):
        """Clearing returns the removed count and empties the log."""
        audit_service.log(operation="switch", new_strategy="enhanced")
        assert len(audit_service.get_audit_log()) == 1
        removed = audit_service.clear_audit_log()
        assert removed == 1
        assert len(audit_service.get_audit_log()) == 0
class TestStrategyMetricsService:
    """[AC-AISVC-RES-03,08] Request-metrics accounting per strategy."""

    @pytest.fixture
    def metrics_service(self):
        return StrategyMetricsService()

    def test_record_request(self, metrics_service):
        """[AC-AISVC-RES-08] A successful request updates counts and latency."""
        metrics_service.record_request(
            latency_ms=100.0,
            success=True,
            route_mode="direct",
        )
        snapshot = metrics_service.get_metrics()
        assert snapshot.total_requests == 1
        assert snapshot.successful_requests == 1
        assert snapshot.avg_latency_ms == 100.0

    def test_record_failed_request(self, metrics_service):
        """[AC-AISVC-RES-08] Failures are counted separately."""
        metrics_service.record_request(latency_ms=50.0, success=False)
        assert metrics_service.get_metrics().failed_requests == 1

    def test_record_fallback(self, metrics_service):
        """[AC-AISVC-RES-08] Fallback requests bump the fallback counter."""
        metrics_service.record_request(
            latency_ms=100.0,
            success=True,
            fallback=True,
        )
        assert metrics_service.get_metrics().fallback_count == 1

    def test_record_route_metrics(self, metrics_service):
        """[AC-AISVC-RES-08] Metrics are bucketed by route mode."""
        metrics_service.record_request(latency_ms=100.0, success=True, route_mode="react")
        metrics_service.record_request(latency_ms=50.0, success=True, route_mode="direct")
        buckets = metrics_service.get_route_metrics()
        assert "react" in buckets
        assert "direct" in buckets

    def test_get_all_metrics(self, metrics_service):
        """Both strategies appear in the aggregate view."""
        metrics_service.set_current_strategy(StrategyType.ENHANCED, ReactMode.REACT)
        metrics_service.record_request(latency_ms=100.0, success=True)
        everything = metrics_service.get_all_metrics()
        assert StrategyType.DEFAULT.value in everything
        assert StrategyType.ENHANCED.value in everything

    def test_get_performance_summary(self, metrics_service):
        """[AC-AISVC-RES-08] Summary aggregates totals and the success rate."""
        for latency, ok in ((100.0, True), (200.0, True), (50.0, False)):
            metrics_service.record_request(latency_ms=latency, success=ok)
        summary = metrics_service.get_performance_summary()
        assert summary["total_requests"] == 3
        assert summary["successful_requests"] == 2
        assert summary["failed_requests"] == 1
        assert summary["success_rate"] == pytest.approx(0.6667, rel=0.01)

    def test_check_performance_threshold_ok(self, metrics_service):
        """[AC-AISVC-RES-08] Healthy metrics pass every threshold."""
        metrics_service.record_request(latency_ms=100.0, success=True)
        verdict = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )
        assert verdict["latency_ok"] is True
        assert verdict["error_rate_ok"] is True
        assert verdict["overall_ok"] is True

    def test_check_performance_threshold_exceeded(self, metrics_service):
        """[AC-AISVC-RES-08] Busting latency or error-rate budgets trips a flag."""
        metrics_service.record_request(latency_ms=6000.0, success=True)
        metrics_service.record_request(latency_ms=100.0, success=False)
        verdict = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )
        # At least one of the two individual gates must fail.
        assert verdict["latency_ok"] is False or verdict["error_rate_ok"] is False

    def test_reset_metrics(self, metrics_service):
        """Reset zeroes the counters."""
        metrics_service.record_request(latency_ms=100.0, success=True)
        metrics_service.reset_metrics()
        assert metrics_service.get_metrics().total_requests == 0
class TestSingletonInstances:
    """Singleton accessors must return one shared instance per process.

    Fix: the original tests each did `from app.services... import _xxx_service`
    before re-importing the module — that only copied the private value into a
    local name and was never used. The dead imports are removed; the module
    attribute reset (which is what actually matters) is kept.
    """

    def test_get_strategy_service_singleton(self):
        """get_strategy_service() creates once, then returns the same object."""
        # Reset the module-level singleton so this test observes a fresh creation.
        import app.services.retrieval.strategy_service as module
        module._strategy_service = None

        service1 = get_strategy_service()
        service2 = get_strategy_service()

        assert service1 is service2

    def test_get_audit_service_singleton(self):
        """get_audit_service() creates once, then returns the same object."""
        import app.services.retrieval.strategy_audit as module
        module._audit_service = None

        service1 = get_audit_service()
        service2 = get_audit_service()

        assert service1 is service2

    def test_get_metrics_service_singleton(self):
        """get_metrics_service() creates once, then returns the same object."""
        import app.services.retrieval.strategy_metrics as module
        module._metrics_service = None

        service1 = get_metrics_service()
        service2 = get_metrics_service()

        assert service1 is service2


"""
Integration tests for Retrieval Strategy API.
[AC-AISVC-RES-01~15] End-to-end tests for strategy management endpoints.

Tests the full API flow:
- GET /strategy/retrieval/current
- POST /strategy/retrieval/switch
- POST /strategy/retrieval/validate
- POST /strategy/retrieval/rollback
"""

import json
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from fastapi.testclient import TestClient

from app.main import app


@pytest.fixture(autouse=True)
def mock_api_key_service():
    """Bypass API-key authentication: every request in this module validates
    against a stubbed key service that accepts 'test-api-key'."""
    mock_service = MagicMock()
    mock_service._initialized = True
    mock_service._keys_cache = {"test-api-key": MagicMock()}

    mock_validation = MagicMock()
    mock_validation.ok = True
    mock_validation.reason = None
    mock_service.validate_key_with_context.return_value = mock_validation

    with patch("app.services.api_key.get_api_key_service", return_value=mock_service):
        yield mock_service
@pytest.fixture(autouse=True)
def reset_strategy_state():
    """Reset the strategy router to factory defaults before and after each test."""
    from app.services.retrieval.strategy.strategy_router import get_strategy_router, set_strategy_router
    from app.services.retrieval.strategy.config import RetrievalStrategyConfig

    def _reset():
        set_strategy_router(None)
        get_strategy_router().update_config(RetrievalStrategyConfig())

    _reset()
    yield
    _reset()


class TestRetrievalStrategyAPIIntegration:
    """[AC-AISVC-RES-01~15] End-to-end checks for the strategy endpoints."""

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        return {
            "X-Tenant-Id": "test@ash@2026",
            "X-API-Key": "test-api-key",
        }

    def test_get_current_strategy(self, client, valid_headers):
        """[AC-AISVC-RES-01] GET /current reports the active strategy and grayscale."""
        resp = client.get("/strategy/retrieval/current", headers=valid_headers)
        assert resp.status_code == 200
        body = resp.json()
        assert "active_strategy" in body
        assert "grayscale" in body
        assert body["active_strategy"] in ["default", "enhanced"]

    def test_switch_strategy_to_enhanced(self, client, valid_headers):
        """[AC-AISVC-RES-02] POST /switch moves the service to enhanced."""
        resp = client.post(
            "/strategy/retrieval/switch",
            json={"active_strategy": "enhanced"},
            headers=valid_headers,
        )
        assert resp.status_code == 200
        body = resp.json()
        assert "success" in body
        assert body["success"] is True
        assert body["current_strategy"] == "enhanced"

    def test_switch_strategy_with_grayscale_percentage(self, client, valid_headers):
        """[AC-AISVC-RES-03] POST /switch accepts a percentage grayscale config."""
        resp = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
                "grayscale": {"enabled": True, "percentage": 30.0},
            },
            headers=valid_headers,
        )
        assert resp.status_code == 200
        assert resp.json()["success"] is True

    def test_switch_strategy_with_allowlist(self, client, valid_headers):
        """[AC-AISVC-RES-03] POST /switch accepts an allowlist grayscale config."""
        resp = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
                "grayscale": {"enabled": True, "allowlist": ["tenant_a", "tenant_b"]},
            },
            headers=valid_headers,
        )
        assert resp.status_code == 200
        assert resp.json()["success"] is True

    def test_validate_strategy(self, client, valid_headers):
        """[AC-AISVC-RES-06] POST /validate returns a verdict and error list."""
        resp = client.post(
            "/strategy/retrieval/validate",
            json={"active_strategy": "enhanced"},
            headers=valid_headers,
        )
        assert resp.status_code == 200
        body = resp.json()
        assert "valid" in body
        assert "errors" in body
        assert isinstance(body["valid"], bool)

    def test_validate_default_strategy(self, client, valid_headers):
        """[AC-AISVC-RES-06] The default strategy always validates."""
        resp = client.post(
            "/strategy/retrieval/validate",
            json={"active_strategy": "default"},
            headers=valid_headers,
        )
        assert resp.status_code == 200
        assert resp.json()["valid"] is True

    def test_rollback_strategy(self, client, valid_headers):
        """[AC-AISVC-RES-07] POST /rollback reverts to the default strategy."""
        # Move away from the default first so the rollback has an effect.
        client.post(
            "/strategy/retrieval/switch",
            json={"active_strategy": "enhanced"},
            headers=valid_headers,
        )
        resp = client.post("/strategy/retrieval/rollback", headers=valid_headers)
        assert resp.status_code == 200
        body = resp.json()
        assert "success" in body
        assert body["current_strategy"] == "default"


class TestRetrievalStrategyAPIValidation:
    """[AC-AISVC-RES-03] Request-validation behavior of the switch endpoint."""

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        return {
            "X-Tenant-Id": "test@ash@2026",
            "X-API-Key": "test-api-key",
        }

    def test_switch_invalid_strategy(self, client, valid_headers):
        """[AC-AISVC-RES-03] An unknown strategy name is rejected."""
        resp = client.post(
            "/strategy/retrieval/switch",
            json={"active_strategy": "invalid_strategy"},
            headers=valid_headers,
        )
        assert resp.status_code in [400, 422, 500]

    def test_switch_percentage_out_of_range(self, client, valid_headers):
        """[AC-AISVC-RES-03] A grayscale percentage above 100 fails validation."""
        resp = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
                "grayscale": {"percentage": 150.0},
            },
            headers=valid_headers,
        )
        assert resp.status_code in [400, 422]
class TestRetrievalStrategyAPIFlow:
    """[AC-AISVC-RES-01~15] Whole-lifecycle scenarios across several endpoints."""

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        return {
            "X-Tenant-Id": "test@ash@2026",
            "X-API-Key": "test-api-key",
        }

    def test_complete_strategy_lifecycle(self, client, valid_headers):
        """[AC-AISVC-RES-01~07] current -> switch -> validate -> rollback -> current."""
        # 1. The service starts on the default strategy.
        current = client.get("/strategy/retrieval/current", headers=valid_headers)
        assert current.status_code == 200
        assert current.json()["active_strategy"] == "default"

        # 2. Switch to enhanced with a 50% grayscale rollout.
        switch = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
                "grayscale": {"enabled": True, "percentage": 50.0},
            },
            headers=valid_headers,
        )
        assert switch.status_code == 200
        assert switch.json()["current_strategy"] == "enhanced"

        # 3. The enhanced strategy validates.
        validate = client.post(
            "/strategy/retrieval/validate",
            json={"active_strategy": "enhanced"},
            headers=valid_headers,
        )
        assert validate.status_code == 200

        # 4. Roll back.
        rollback = client.post("/strategy/retrieval/rollback", headers=valid_headers)
        assert rollback.status_code == 200
        assert rollback.json()["current_strategy"] == "default"

        # 5. Verify the service is back on the default strategy.
        final = client.get("/strategy/retrieval/current", headers=valid_headers)
        assert final.status_code == 200
        assert final.json()["active_strategy"] == "default"


class TestRetrievalStrategyAPIMissingTenant:
    """Requests lacking X-Tenant-Id must be rejected with 400."""

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @pytest.fixture
    def api_key_headers(self):
        return {"X-API-Key": "test-api-key"}

    def test_current_without_tenant(self, client, api_key_headers):
        """GET /current requires the tenant header."""
        resp = client.get("/strategy/retrieval/current", headers=api_key_headers)
        assert resp.status_code == 400

    def test_switch_without_tenant(self, client, api_key_headers):
        """POST /switch requires the tenant header."""
        resp = client.post(
            "/strategy/retrieval/switch",
            json={"active_strategy": "enhanced"},
            headers=api_key_headers,
        )
        assert resp.status_code == 400