From 92cef20a86c388f34ab5ac51cdd745e8a9f98ba4 Mon Sep 17 00:00:00 2001 From: MerCry Date: Tue, 24 Feb 2026 13:22:04 +0800 Subject: [PATCH] test(ai-service): add Retrieval layer unit tests [AC-AISVC-10, AC-AISVC-16, AC-AISVC-17] --- ai-service/tests/test_retrieval.py | 264 +++++++++++++++++++++++++++++ 1 file changed, 264 insertions(+) create mode 100644 ai-service/tests/test_retrieval.py diff --git a/ai-service/tests/test_retrieval.py b/ai-service/tests/test_retrieval.py new file mode 100644 index 0000000..bb7dfe5 --- /dev/null +++ b/ai-service/tests/test_retrieval.py @@ -0,0 +1,264 @@ +""" +Unit tests for Retrieval layer. +[AC-AISVC-10, AC-AISVC-16, AC-AISVC-17] Tests for vector retrieval with tenant isolation. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult +from app.services.retrieval.vector_retriever import VectorRetriever + + +@pytest.fixture +def mock_qdrant_client(): + """Create a mock QdrantClient.""" + client = AsyncMock() + client.search = AsyncMock() + client.get_collection_name = MagicMock(side_effect=lambda tenant_id: f"kb_{tenant_id}") + return client + + +@pytest.fixture +def retrieval_context(): + """Create a sample RetrievalContext.""" + return RetrievalContext( + tenant_id="tenant_a", + query="What is the product price?", + session_id="session_123", + channel_type="wechat", + metadata={"user_id": "user_123"}, + ) + + +class TestRetrievalContext: + """ + [AC-AISVC-16] Tests for retrieval context. + """ + + def test_retrieval_context_creation(self): + """ + [AC-AISVC-16] Should create retrieval context with all fields. + """ + ctx = RetrievalContext( + tenant_id="tenant_a", + query="Test query", + session_id="session_123", + channel_type="wechat", + metadata={"key": "value"}, + ) + + assert ctx.tenant_id == "tenant_a" + assert ctx.query == "Test query" + assert ctx.session_id == "session_123" + assert ctx.channel_type == "wechat" + assert ctx.metadata == {"key": "value"} + + def test_retrieval_context_minimal(self): + """ + [AC-AISVC-16] Should create retrieval context with minimal fields. + """ + ctx = RetrievalContext( + tenant_id="tenant_a", + query="Test query", + ) + + assert ctx.tenant_id == "tenant_a" + assert ctx.query == "Test query" + assert ctx.session_id is None + assert ctx.channel_type is None + + +class TestRetrievalResult: + """ + [AC-AISVC-16, AC-AISVC-17] Tests for retrieval result. + """ + + def test_empty_result(self): + """ + [AC-AISVC-17] Empty result should indicate insufficient retrieval. + """ + result = RetrievalResult(hits=[]) + + assert result.is_empty is True + assert result.max_score == 0.0 + assert result.hit_count == 0 + + def test_result_with_hits(self): + """ + [AC-AISVC-16] Result with hits should calculate correct statistics. + """ + hits = [ + RetrievalHit(text="Doc 1", score=0.9, source="vector"), + RetrievalHit(text="Doc 2", score=0.7, source="vector"), + ] + result = RetrievalResult(hits=hits) + + assert result.is_empty is False + assert result.max_score == 0.9 + assert result.hit_count == 2 + + def test_result_max_score(self): + """ + [AC-AISVC-17] Max score should be the highest among hits. + """ + hits = [ + RetrievalHit(text="Doc 1", score=0.5, source="vector"), + RetrievalHit(text="Doc 2", score=0.95, source="vector"), + RetrievalHit(text="Doc 3", score=0.3, source="vector"), + ] + result = RetrievalResult(hits=hits) + + assert result.max_score == 0.95 + + +class TestVectorRetrieverTenantIsolation: + """ + [AC-AISVC-10, AC-AISVC-11] Tests for multi-tenant isolation in vector retrieval. + """ + + @pytest.mark.asyncio + async def test_search_uses_tenant_collection(self, mock_qdrant_client, retrieval_context): + """ + [AC-AISVC-10] Search should use tenant-specific collection. + """ + mock_qdrant_client.search.return_value = [ + {"id": "1", "score": 0.9, "payload": {"text": "Answer 1", "source": "kb"}} + ] + + retriever = VectorRetriever(qdrant_client=mock_qdrant_client) + + with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536): + result = await retriever.retrieve(retrieval_context) + + mock_qdrant_client.search.assert_called_once() + call_args = mock_qdrant_client.search.call_args + assert call_args.kwargs["tenant_id"] == "tenant_a" + + @pytest.mark.asyncio + async def test_different_tenants_separate_results(self, mock_qdrant_client): + """ + [AC-AISVC-11] Different tenants should get separate results. + """ + mock_qdrant_client.search.side_effect = [ + [{"id": "1", "score": 0.9, "payload": {"text": "Tenant A result"}}], + [{"id": "2", "score": 0.8, "payload": {"text": "Tenant B result"}}], + ] + + retriever = VectorRetriever(qdrant_client=mock_qdrant_client) + + with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536): + ctx_a = RetrievalContext(tenant_id="tenant_a", query="query") + ctx_b = RetrievalContext(tenant_id="tenant_b", query="query") + + result_a = await retriever.retrieve(ctx_a) + result_b = await retriever.retrieve(ctx_b) + + assert result_a.hits[0].text == "Tenant A result" + assert result_b.hits[0].text == "Tenant B result" + + +class TestVectorRetrieverScoreThreshold: + """ + [AC-AISVC-17] Tests for score threshold filtering. + """ + + @pytest.mark.asyncio + async def test_filter_by_score_threshold(self, mock_qdrant_client, retrieval_context): + """ + [AC-AISVC-17] Results below score threshold should be filtered. + """ + mock_qdrant_client.search.return_value = [ + {"id": "1", "score": 0.9, "payload": {"text": "High score"}}, + {"id": "2", "score": 0.5, "payload": {"text": "Low score"}}, + {"id": "3", "score": 0.8, "payload": {"text": "Medium score"}}, + ] + + retriever = VectorRetriever( + qdrant_client=mock_qdrant_client, + score_threshold=0.7, + ) + + with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536): + result = await retriever.retrieve(retrieval_context) + + assert len(result.hits) == 2 + assert all(hit.score >= 0.7 for hit in result.hits) + + @pytest.mark.asyncio + async def test_insufficient_hits_detection(self, mock_qdrant_client, retrieval_context): + """ + [AC-AISVC-17] Should detect insufficient retrieval when hits < min_hits. + """ + mock_qdrant_client.search.return_value = [ + {"id": "1", "score": 0.9, "payload": {"text": "Only one hit"}}, + ] + + retriever = VectorRetriever( + qdrant_client=mock_qdrant_client, + score_threshold=0.7, + min_hits=2, + ) + + with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536): + result = await retriever.retrieve(retrieval_context) + + assert result.diagnostics["is_insufficient"] is True + assert result.diagnostics["filtered_hits"] == 1 + + @pytest.mark.asyncio + async def test_sufficient_hits_detection(self, mock_qdrant_client, retrieval_context): + """ + [AC-AISVC-17] Should detect sufficient retrieval when hits >= min_hits. + """ + mock_qdrant_client.search.return_value = [ + {"id": "1", "score": 0.9, "payload": {"text": "Hit 1"}}, + {"id": "2", "score": 0.85, "payload": {"text": "Hit 2"}}, + {"id": "3", "score": 0.8, "payload": {"text": "Hit 3"}}, + ] + + retriever = VectorRetriever( + qdrant_client=mock_qdrant_client, + score_threshold=0.7, + min_hits=2, + ) + + with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536): + result = await retriever.retrieve(retrieval_context) + + assert result.diagnostics["is_insufficient"] is False + assert result.diagnostics["filtered_hits"] == 3 + + +class TestVectorRetrieverHealthCheck: + """ + [AC-AISVC-16] Tests for retriever health check. + """ + + @pytest.mark.asyncio + async def test_health_check_success(self, mock_qdrant_client): + """ + [AC-AISVC-16] Health check should return True when Qdrant is available. + """ + mock_qdrant = AsyncMock() + mock_qdrant.get_collections = AsyncMock() + mock_qdrant_client.get_client = AsyncMock(return_value=mock_qdrant) + + retriever = VectorRetriever(qdrant_client=mock_qdrant_client) + is_healthy = await retriever.health_check() + + assert is_healthy is True + + @pytest.mark.asyncio + async def test_health_check_failure(self, mock_qdrant_client): + """ + [AC-AISVC-16] Health check should return False when Qdrant is unavailable. + """ + mock_qdrant = AsyncMock() + mock_qdrant.get_collections = AsyncMock(side_effect=Exception("Connection failed")) + mock_qdrant_client.get_client = AsyncMock(return_value=mock_qdrant) + + retriever = VectorRetriever(qdrant_client=mock_qdrant_client) + is_healthy = await retriever.health_check() + + assert is_healthy is False