"""
Unit tests for Retrieval layer.
[AC-AISVC-10, AC-AISVC-11, AC-AISVC-16, AC-AISVC-17]
Tests for vector retrieval with tenant isolation.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult
from app.services.retrieval.vector_retriever import VectorRetriever


@pytest.fixture
def mock_qdrant_client():
    """Create a mock QdrantClient."""
    client = AsyncMock()
    client.search = AsyncMock()
    # Collection naming mirrors a per-tenant "kb_<tenant_id>" scheme.
    client.get_collection_name = MagicMock(side_effect=lambda tenant_id: f"kb_{tenant_id}")
    return client


@pytest.fixture
def retrieval_context():
    """Create a sample RetrievalContext."""
    return RetrievalContext(
        tenant_id="tenant_a",
        query="What is the product price?",
        session_id="session_123",
        channel_type="wechat",
        metadata={"user_id": "user_123"},
    )


class TestRetrievalContext:
    """
    [AC-AISVC-16] Tests for retrieval context.
    """

    def test_retrieval_context_creation(self):
        """
        [AC-AISVC-16] Should create retrieval context with all fields.
        """
        ctx = RetrievalContext(
            tenant_id="tenant_a",
            query="Test query",
            session_id="session_123",
            channel_type="wechat",
            metadata={"key": "value"},
        )

        assert ctx.tenant_id == "tenant_a"
        assert ctx.query == "Test query"
        assert ctx.session_id == "session_123"
        assert ctx.channel_type == "wechat"
        assert ctx.metadata == {"key": "value"}

    def test_retrieval_context_minimal(self):
        """
        [AC-AISVC-16] Should create retrieval context with minimal fields.
        """
        ctx = RetrievalContext(
            tenant_id="tenant_a",
            query="Test query",
        )

        assert ctx.tenant_id == "tenant_a"
        assert ctx.query == "Test query"
        assert ctx.session_id is None
        assert ctx.channel_type is None


class TestRetrievalResult:
    """
    [AC-AISVC-16, AC-AISVC-17] Tests for retrieval result.
    """

    def test_empty_result(self):
        """
        [AC-AISVC-17] Empty result should indicate insufficient retrieval.
        """
        result = RetrievalResult(hits=[])

        assert result.is_empty is True
        assert result.max_score == 0.0
        assert result.hit_count == 0

    def test_result_with_hits(self):
        """
        [AC-AISVC-16] Result with hits should calculate correct statistics.
        """
        hits = [
            RetrievalHit(text="Doc 1", score=0.9, source="vector"),
            RetrievalHit(text="Doc 2", score=0.7, source="vector"),
        ]
        result = RetrievalResult(hits=hits)

        assert result.is_empty is False
        assert result.max_score == 0.9
        assert result.hit_count == 2

    def test_result_max_score(self):
        """
        [AC-AISVC-17] Max score should be the highest among hits.
        """
        hits = [
            RetrievalHit(text="Doc 1", score=0.5, source="vector"),
            RetrievalHit(text="Doc 2", score=0.95, source="vector"),
            RetrievalHit(text="Doc 3", score=0.3, source="vector"),
        ]
        result = RetrievalResult(hits=hits)

        assert result.max_score == 0.95


class TestVectorRetrieverTenantIsolation:
    """
    [AC-AISVC-10, AC-AISVC-11] Tests for multi-tenant isolation in vector retrieval.
    """

    @pytest.mark.asyncio
    async def test_search_uses_tenant_collection(self, mock_qdrant_client, retrieval_context):
        """
        [AC-AISVC-10] Search should use tenant-specific collection.
        """
        mock_qdrant_client.search.return_value = [
            {"id": "1", "score": 0.9, "payload": {"text": "Answer 1", "source": "kb"}}
        ]

        retriever = VectorRetriever(qdrant_client=mock_qdrant_client)

        # Stub out embedding so no model/network call is made.
        with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
            await retriever.retrieve(retrieval_context)

        mock_qdrant_client.search.assert_called_once()
        call_args = mock_qdrant_client.search.call_args
        assert call_args.kwargs["tenant_id"] == "tenant_a"

    @pytest.mark.asyncio
    async def test_different_tenants_separate_results(self, mock_qdrant_client):
        """
        [AC-AISVC-11] Different tenants should get separate results.
        """
        # side_effect returns one canned response per search call, in call order.
        mock_qdrant_client.search.side_effect = [
            [{"id": "1", "score": 0.9, "payload": {"text": "Tenant A result"}}],
            [{"id": "2", "score": 0.8, "payload": {"text": "Tenant B result"}}],
        ]

        retriever = VectorRetriever(qdrant_client=mock_qdrant_client)

        with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
            ctx_a = RetrievalContext(tenant_id="tenant_a", query="query")
            ctx_b = RetrievalContext(tenant_id="tenant_b", query="query")

            result_a = await retriever.retrieve(ctx_a)
            result_b = await retriever.retrieve(ctx_b)

        assert result_a.hits[0].text == "Tenant A result"
        assert result_b.hits[0].text == "Tenant B result"

        # The canned results above only prove call ordering; verify each search
        # was actually scoped to the requesting tenant.
        searched_tenants = [
            call.kwargs["tenant_id"] for call in mock_qdrant_client.search.call_args_list
        ]
        assert searched_tenants == ["tenant_a", "tenant_b"]


class TestVectorRetrieverScoreThreshold:
    """
    [AC-AISVC-17] Tests for score threshold filtering.
    """

    @pytest.mark.asyncio
    async def test_filter_by_score_threshold(self, mock_qdrant_client, retrieval_context):
        """
        [AC-AISVC-17] Results below score threshold should be filtered.
        """
        mock_qdrant_client.search.return_value = [
            {"id": "1", "score": 0.9, "payload": {"text": "High score"}},
            {"id": "2", "score": 0.5, "payload": {"text": "Low score"}},
            {"id": "3", "score": 0.8, "payload": {"text": "Medium score"}},
        ]

        retriever = VectorRetriever(
            qdrant_client=mock_qdrant_client,
            score_threshold=0.7,
        )

        with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
            result = await retriever.retrieve(retrieval_context)

        assert len(result.hits) == 2
        assert all(hit.score >= 0.7 for hit in result.hits)

    @pytest.mark.asyncio
    async def test_insufficient_hits_detection(self, mock_qdrant_client, retrieval_context):
        """
        [AC-AISVC-17] Should detect insufficient retrieval when hits < min_hits.
        """
        mock_qdrant_client.search.return_value = [
            {"id": "1", "score": 0.9, "payload": {"text": "Only one hit"}},
        ]

        retriever = VectorRetriever(
            qdrant_client=mock_qdrant_client,
            score_threshold=0.7,
            min_hits=2,
        )

        with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
            result = await retriever.retrieve(retrieval_context)

        assert result.diagnostics["is_insufficient"] is True
        assert result.diagnostics["filtered_hits"] == 1

    @pytest.mark.asyncio
    async def test_sufficient_hits_detection(self, mock_qdrant_client, retrieval_context):
        """
        [AC-AISVC-17] Should detect sufficient retrieval when hits >= min_hits.
        """
        mock_qdrant_client.search.return_value = [
            {"id": "1", "score": 0.9, "payload": {"text": "Hit 1"}},
            {"id": "2", "score": 0.85, "payload": {"text": "Hit 2"}},
            {"id": "3", "score": 0.8, "payload": {"text": "Hit 3"}},
        ]

        retriever = VectorRetriever(
            qdrant_client=mock_qdrant_client,
            score_threshold=0.7,
            min_hits=2,
        )

        with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
            result = await retriever.retrieve(retrieval_context)

        assert result.diagnostics["is_insufficient"] is False
        assert result.diagnostics["filtered_hits"] == 3


class TestVectorRetrieverHealthCheck:
    """
    [AC-AISVC-16] Tests for retriever health check.
    """

    @pytest.mark.asyncio
    async def test_health_check_success(self, mock_qdrant_client):
        """
        [AC-AISVC-16] Health check should return True when Qdrant is available.
        """
        mock_qdrant = AsyncMock()
        mock_qdrant.get_collections = AsyncMock()
        mock_qdrant_client.get_client = AsyncMock(return_value=mock_qdrant)

        retriever = VectorRetriever(qdrant_client=mock_qdrant_client)
        is_healthy = await retriever.health_check()

        assert is_healthy is True
        # Guard against a health check that reports healthy without probing Qdrant.
        mock_qdrant.get_collections.assert_awaited()

    @pytest.mark.asyncio
    async def test_health_check_failure(self, mock_qdrant_client):
        """
        [AC-AISVC-16] Health check should return False when Qdrant is unavailable.
        """
        mock_qdrant = AsyncMock()
        mock_qdrant.get_collections = AsyncMock(side_effect=Exception("Connection failed"))
        mock_qdrant_client.get_client = AsyncMock(return_value=mock_qdrant)

        retriever = VectorRetriever(qdrant_client=mock_qdrant_client)
        is_healthy = await retriever.health_check()

        assert is_healthy is False