# ai-robot-core/ai-service/tests/test_orchestrator.py

"""
Tests for OrchestratorService.
[AC-AISVC-01, AC-AISVC-02] Test complete generation pipeline integration.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from typing import AsyncGenerator
from app.models import ChatRequest, ChatResponse, ChannelType, ChatMessage, Role
from app.services.orchestrator import (
OrchestratorService,
OrchestratorConfig,
GenerationContext,
set_orchestrator_service,
)
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
from app.services.memory import MemoryService
from app.services.retrieval.base import (
BaseRetriever,
RetrievalContext,
RetrievalResult,
RetrievalHit,
)
from app.services.confidence import ConfidenceCalculator, ConfidenceConfig
from app.services.context import ContextMerger
from app.models.entities import ChatMessage as ChatMessageEntity
class MockLLMClient(LLMClient):
    """Stub LLM client that records invocations and returns canned output."""

    def __init__(self, response_content: str = "Mock LLM response"):
        # Flags let tests assert which generation path was exercised.
        self._generate_called = False
        self._stream_generate_called = False
        self._response_content = response_content

    async def generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Return a fixed, non-streaming response and mark the call."""
        self._generate_called = True
        return LLMResponse(
            content=self._response_content,
            model="mock-model",
            usage={"prompt_tokens": 100, "completion_tokens": 50},
            finish_reason="stop",
        )

    async def stream_generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs,
    ) -> AsyncGenerator[LLMStreamChunk, None]:
        """Yield a fixed delta sequence, then a terminal stop chunk."""
        self._stream_generate_called = True
        for piece in ("Hello", " from", " mock", " LLM"):
            yield LLMStreamChunk(delta=piece, model="mock-model")
        # Terminal chunk carries an empty delta and the finish reason.
        yield LLMStreamChunk(delta="", model="mock-model", finish_reason="stop")

    async def close(self) -> None:
        """Nothing to release for the stub."""
        pass
class MockRetriever(BaseRetriever):
    """Canned retriever: always answers with the hits it was built with."""

    def __init__(self, hits: list[RetrievalHit] | None = None):
        self._hits = hits or []

    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """Ignore the query context and return the preconfigured hits."""
        # The diagnostics flag lets tests confirm this stub was consulted.
        return RetrievalResult(hits=self._hits, diagnostics={"mock": True})

    async def health_check(self) -> bool:
        """Report healthy unconditionally."""
        return True
class MockMemoryService:
    """In-memory stand-in for MemoryService used by orchestrator tests.

    Records every persisted message in ``_saved_messages`` so tests can
    assert what the orchestrator attempted to store.
    """

    def __init__(self, history: list[ChatMessageEntity] | None = None):
        self._saved_messages: list[dict] = []
        self._session_created = False
        self._history = history or []

    async def get_or_create_session(
        self,
        tenant_id: str,
        session_id: str,
        channel_type: str | None = None,
        metadata: dict | None = None,
    ):
        """Mark session bootstrap and hand back a lightweight stub object."""
        self._session_created = True
        return MagicMock(tenant_id=tenant_id, session_id=session_id)

    async def load_history(
        self,
        tenant_id: str,
        session_id: str,
        limit: int | None = None,
    ):
        """Return the preloaded history; tenant/session/limit are ignored."""
        return self._history

    async def append_message(
        self,
        tenant_id: str,
        session_id: str,
        role: str,
        content: str,
    ):
        """Record one message as a role/content dict."""
        self._saved_messages.append({"role": role, "content": content})

    async def append_messages(
        self,
        tenant_id: str,
        session_id: str,
        messages: list[dict[str, str]],
    ):
        """Record a batch of messages in order."""
        self._saved_messages.extend(messages)
def create_chat_request(
    message: str = "Hello",
    session_id: str = "test-session",
    history: list[ChatMessage] | None = None,
    metadata: dict | None = None,
) -> ChatRequest:
    """Build a ChatRequest on the WeChat channel with test-friendly defaults."""
    return ChatRequest(
        current_message=message,
        session_id=session_id,
        history=history,
        metadata=metadata,
        # All tests in this module exercise the WeChat channel.
        channel_type=ChannelType.WECHAT,
    )
class TestOrchestratorServiceGenerate:
    """Tests for OrchestratorService.generate() method."""

    @pytest.mark.asyncio
    async def test_generate_basic_without_dependencies(self):
        """
        [AC-AISVC-01, AC-AISVC-02] Test basic generation without external dependencies.
        Should return fallback response with low confidence.
        """
        # No LLM, memory, or retriever wired in; RAG explicitly disabled.
        orchestrator = OrchestratorService(
            config=OrchestratorConfig(enable_rag=False),
        )
        request = create_chat_request(message="What is the price?")
        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )
        # Even with no dependencies, the pipeline must yield a well-formed response.
        assert isinstance(response, ChatResponse)
        assert response.reply is not None
        assert response.confidence >= 0.0
        assert response.confidence <= 1.0
        assert isinstance(response.should_transfer, bool)
        assert "diagnostics" in response.metadata

    @pytest.mark.asyncio
    async def test_generate_with_llm_client(self):
        """
        [AC-AISVC-02] Test generation with LLM client.
        Should use LLM response.
        """
        mock_llm = MockLLMClient(response_content="This is the AI response.")
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            config=OrchestratorConfig(enable_rag=False),
        )
        request = create_chat_request(message="Hello")
        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )
        # The reply must come verbatim from the LLM, not a fallback.
        assert response.reply == "This is the AI response."
        assert mock_llm._generate_called is True

    @pytest.mark.asyncio
    async def test_generate_with_memory_service(self):
        """
        [AC-AISVC-13] Test generation with memory service.
        Should load history and save messages.
        """
        mock_memory = MockMemoryService(
            history=[
                ChatMessageEntity(
                    tenant_id="tenant-1",
                    session_id="test-session",
                    role="user",
                    content="Previous message",
                )
            ]
        )
        mock_llm = MockLLMClient()
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            memory_service=mock_memory,
            config=OrchestratorConfig(enable_rag=False),
        )
        request = create_chat_request(message="New message")
        # The response itself is not asserted here; the persistence side
        # effects on the memory service are the subject of this test.
        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )
        # Exactly two writes: the incoming user turn, then the assistant reply.
        assert len(mock_memory._saved_messages) == 2
        assert mock_memory._saved_messages[0]["role"] == "user"
        assert mock_memory._saved_messages[1]["role"] == "assistant"

    @pytest.mark.asyncio
    async def test_generate_with_retrieval(self):
        """
        [AC-AISVC-16, AC-AISVC-17] Test generation with RAG retrieval.
        Should include evidence in LLM prompt.
        """
        mock_retriever = MockRetriever(
            hits=[
                RetrievalHit(
                    text="Product price is $100",
                    score=0.85,
                    source="kb",
                )
            ]
        )
        mock_llm = MockLLMClient()
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            retriever=mock_retriever,
            config=OrchestratorConfig(enable_rag=True),
        )
        request = create_chat_request(message="What is the price?")
        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )
        # Retrieval diagnostics must reflect the single configured hit.
        assert "retrieval" in response.metadata["diagnostics"]
        assert response.metadata["diagnostics"]["retrieval"]["hit_count"] == 1

    @pytest.mark.asyncio
    async def test_generate_with_context_merging(self):
        """
        [AC-AISVC-14, AC-AISVC-15] Test context merging with external history.
        Should merge local and external history.
        """
        # One message in local (memory-service) history ...
        mock_memory = MockMemoryService(
            history=[
                ChatMessageEntity(
                    tenant_id="tenant-1",
                    session_id="test-session",
                    role="user",
                    content="Local message",
                )
            ]
        )
        mock_llm = MockLLMClient()
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            memory_service=mock_memory,
            config=OrchestratorConfig(enable_rag=False),
        )
        # ... and two messages supplied externally on the request itself.
        request = create_chat_request(
            message="New message",
            history=[
                ChatMessage(role=Role.USER, content="External message"),
                ChatMessage(role=Role.ASSISTANT, content="External response"),
            ],
        )
        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )
        # Diagnostics must account for both sources separately.
        assert "merged_context" in response.metadata["diagnostics"]
        merged = response.metadata["diagnostics"]["merged_context"]
        assert merged["local_count"] == 1
        assert merged["external_count"] == 2

    @pytest.mark.asyncio
    async def test_generate_with_confidence_calculation(self):
        """
        [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Test confidence calculation.
        Should calculate confidence based on retrieval results.
        """
        # Two strong hits should drive confidence above the 0.5 mark.
        mock_retriever = MockRetriever(
            hits=[
                RetrievalHit(text="High relevance content", score=0.9, source="kb"),
                RetrievalHit(text="Medium relevance", score=0.8, source="kb"),
            ]
        )
        mock_llm = MockLLMClient()
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            retriever=mock_retriever,
            config=OrchestratorConfig(enable_rag=True),
        )
        request = create_chat_request(message="Test query")
        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )
        assert response.confidence > 0.5
        assert "confidence" in response.metadata["diagnostics"]

    @pytest.mark.asyncio
    async def test_generate_low_confidence_triggers_transfer(self):
        """
        [AC-AISVC-18, AC-AISVC-19] Test low confidence triggers transfer.
        Should set should_transfer=True when confidence is low.
        """
        # Zero retrieval hits -> low confidence -> human handoff expected.
        mock_retriever = MockRetriever(hits=[])
        mock_llm = MockLLMClient()
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            retriever=mock_retriever,
            config=OrchestratorConfig(enable_rag=True),
        )
        request = create_chat_request(message="Unknown topic")
        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )
        assert response.should_transfer is True
        assert response.transfer_reason is not None

    @pytest.mark.asyncio
    async def test_generate_handles_llm_error(self):
        """
        [AC-AISVC-02] Test handling of LLM errors.
        Should return fallback response on error.
        """
        # Simulate a hard LLM failure on every generate() call.
        mock_llm = MagicMock()
        mock_llm.generate = AsyncMock(side_effect=Exception("LLM unavailable"))
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            config=OrchestratorConfig(enable_rag=False),
        )
        request = create_chat_request(message="Hello")
        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )
        # Pipeline must degrade gracefully: a reply is still produced and
        # the failure is surfaced through diagnostics.
        assert response.reply is not None
        assert "llm_error" in response.metadata["diagnostics"]

    @pytest.mark.asyncio
    async def test_generate_handles_retrieval_error(self):
        """
        [AC-AISVC-16] Test handling of retrieval errors.
        Should continue with empty retrieval result.
        """
        # Simulate the vector store being down.
        mock_retriever = MagicMock()
        mock_retriever.retrieve = AsyncMock(side_effect=Exception("Qdrant unavailable"))
        mock_llm = MockLLMClient()
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            retriever=mock_retriever,
            config=OrchestratorConfig(enable_rag=True),
        )
        request = create_chat_request(message="Hello")
        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )
        # Generation still proceeds via the LLM; the error is recorded.
        assert response.reply == "Mock LLM response"
        assert "retrieval_error" in response.metadata["diagnostics"]

    @pytest.mark.asyncio
    async def test_generate_full_pipeline_integration(self):
        """
        [AC-AISVC-01, AC-AISVC-02] Test complete pipeline integration.
        All components working together.
        """
        mock_memory = MockMemoryService(
            history=[
                ChatMessageEntity(
                    tenant_id="tenant-1",
                    session_id="test-session",
                    role="user",
                    content="Previous question",
                ),
                ChatMessageEntity(
                    tenant_id="tenant-1",
                    session_id="test-session",
                    role="assistant",
                    content="Previous answer",
                ),
            ]
        )
        mock_retriever = MockRetriever(
            hits=[
                RetrievalHit(text="Knowledge base content", score=0.85, source="kb"),
            ]
        )
        mock_llm = MockLLMClient(response_content="AI generated response")
        orchestrator = OrchestratorService(
            llm_client=mock_llm,
            memory_service=mock_memory,
            retriever=mock_retriever,
            config=OrchestratorConfig(enable_rag=True),
        )
        request = create_chat_request(
            message="New question",
            history=[
                ChatMessage(role=Role.USER, content="External history"),
            ],
        )
        response = await orchestrator.generate(
            tenant_id="tenant-1",
            request=request,
        )
        assert response.reply == "AI generated response"
        assert response.confidence > 0.0
        # New user turn + assistant reply persisted to memory.
        assert len(mock_memory._saved_messages) == 2
        diagnostics = response.metadata["diagnostics"]
        assert diagnostics["memory_enabled"] is True
        assert diagnostics["retrieval"]["hit_count"] == 1
        # "live" indicates a real LLM call was made (vs. fallback mode).
        assert diagnostics["llm_mode"] == "live"
class TestOrchestratorServiceGenerationContext:
    """Tests for GenerationContext dataclass."""

    def test_generation_context_initialization(self):
        """A freshly built context echoes its fields and starts empty."""
        context = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Hello",
            channel_type="wechat",
        )
        assert context.tenant_id == "tenant-1"
        assert context.session_id == "session-1"
        assert context.current_message == "Hello"
        assert context.channel_type == "wechat"
        # Collection-valued fields must default to empty containers.
        assert context.local_history == []
        assert context.diagnostics == {}

    def test_generation_context_with_metadata(self):
        """request_metadata supplied at construction is stored verbatim."""
        context = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Hello",
            channel_type="wechat",
            request_metadata={"user_id": "user-123"},
        )
        assert context.request_metadata == {"user_id": "user-123"}
class TestOrchestratorConfig:
    """Tests for OrchestratorConfig dataclass."""

    def test_default_config(self):
        """Defaults: 4000 history tokens, 2000 evidence tokens, RAG on."""
        cfg = OrchestratorConfig()
        assert cfg.max_history_tokens == 4000
        assert cfg.max_evidence_tokens == 2000
        assert cfg.enable_rag is True
        # Default system prompt contains the Chinese customer-service persona.
        assert "智能客服" in cfg.system_prompt

    def test_custom_config(self):
        """Explicit constructor arguments must override the defaults."""
        cfg = OrchestratorConfig(
            max_history_tokens=8000,
            enable_rag=False,
            system_prompt="Custom prompt",
        )
        assert cfg.max_history_tokens == 8000
        assert cfg.enable_rag is False
        assert cfg.system_prompt == "Custom prompt"
class TestOrchestratorServiceHelperMethods:
    """Tests for OrchestratorService helper methods."""

    def test_build_llm_messages_basic(self):
        """Test _build_llm_messages with basic context."""
        orchestrator = OrchestratorService(
            config=OrchestratorConfig(enable_rag=False),
        )
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Hello",
            channel_type="wechat",
        )
        messages = orchestrator._build_llm_messages(ctx)
        # Minimal prompt shape: one system message, then the user turn.
        assert len(messages) == 2
        assert messages[0]["role"] == "system"
        assert messages[1]["role"] == "user"
        assert messages[1]["content"] == "Hello"

    def test_build_llm_messages_with_evidence(self):
        """Test _build_llm_messages includes evidence from retrieval."""
        orchestrator = OrchestratorService(
            config=OrchestratorConfig(enable_rag=True),
        )
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="What is the price?",
            channel_type="wechat",
            retrieval_result=RetrievalResult(
                hits=[
                    RetrievalHit(text="Price is $100", score=0.9, source="kb"),
                ]
            ),
        )
        messages = orchestrator._build_llm_messages(ctx)
        # Retrieved evidence is folded into the system message, under the
        # knowledge-base reference header, with the hit text verbatim.
        assert "知识库参考内容" in messages[0]["content"]
        assert "Price is $100" in messages[0]["content"]

    def test_build_llm_messages_with_history(self):
        """Test _build_llm_messages includes merged history."""
        from app.services.context import MergedContext
        orchestrator = OrchestratorService(
            config=OrchestratorConfig(enable_rag=False),
        )
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="New question",
            channel_type="wechat",
            merged_context=MergedContext(
                messages=[
                    {"role": "user", "content": "Previous question"},
                    {"role": "assistant", "content": "Previous answer"},
                ]
            ),
        )
        messages = orchestrator._build_llm_messages(ctx)
        # Expected order: system, history (user, assistant), current user turn.
        assert len(messages) == 4
        assert messages[1]["role"] == "user"
        assert messages[1]["content"] == "Previous question"
        assert messages[2]["role"] == "assistant"
        assert messages[3]["role"] == "user"
        assert messages[3]["content"] == "New question"

    def test_fallback_response_with_evidence(self):
        """Test _fallback_response when retrieval has evidence."""
        orchestrator = OrchestratorService()
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Question",
            channel_type="wechat",
            retrieval_result=RetrievalResult(
                hits=[RetrievalHit(text="Evidence", score=0.8, source="kb")]
            ),
        )
        fallback = orchestrator._fallback_response(ctx)
        # With evidence available, the fallback should mention the knowledge base.
        assert "知识库" in fallback

    def test_fallback_response_without_evidence(self):
        """Test _fallback_response when no retrieval evidence."""
        orchestrator = OrchestratorService()
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Question",
            channel_type="wechat",
            retrieval_result=RetrievalResult(hits=[]),
        )
        fallback = orchestrator._fallback_response(ctx)
        # Without evidence, the fallback either apologizes ("无法处理") or
        # redirects to a human agent ("人工客服").
        assert "无法处理" in fallback or "人工客服" in fallback

    def test_format_evidence(self):
        """Test _format_evidence formats hits correctly."""
        orchestrator = OrchestratorService()
        result = RetrievalResult(
            hits=[
                RetrievalHit(text="First result", score=0.9, source="kb"),
                RetrievalHit(text="Second result", score=0.8, source="kb"),
            ]
        )
        formatted = orchestrator._format_evidence(result)
        # Hits are rendered as a 1-based numbered list with their text.
        assert "[1]" in formatted
        assert "[2]" in formatted
        assert "First result" in formatted
        assert "Second result" in formatted
class TestOrchestratorServiceSetInstance:
    """Tests for set_orchestrator_service function."""

    def test_set_orchestrator_service(self):
        """The getter must return exactly the instance installed by the setter."""
        from app.services.orchestrator import get_orchestrator_service

        replacement = OrchestratorService(
            config=OrchestratorConfig(enable_rag=False),
        )
        set_orchestrator_service(replacement)
        # Identity check: the module-level singleton was swapped, not copied.
        assert get_orchestrator_service() is replacement