655 lines
21 KiB
Python
655 lines
21 KiB
Python
|
|
"""
|
||
|
|
Tests for OrchestratorService.
|
||
|
|
[AC-AISVC-01, AC-AISVC-02] Test complete generation pipeline integration.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
|
from typing import AsyncGenerator
|
||
|
|
|
||
|
|
from app.models import ChatRequest, ChatResponse, ChannelType, ChatMessage, Role
|
||
|
|
from app.services.orchestrator import (
|
||
|
|
OrchestratorService,
|
||
|
|
OrchestratorConfig,
|
||
|
|
GenerationContext,
|
||
|
|
set_orchestrator_service,
|
||
|
|
)
|
||
|
|
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk
|
||
|
|
from app.services.memory import MemoryService
|
||
|
|
from app.services.retrieval.base import (
|
||
|
|
BaseRetriever,
|
||
|
|
RetrievalContext,
|
||
|
|
RetrievalResult,
|
||
|
|
RetrievalHit,
|
||
|
|
)
|
||
|
|
from app.services.confidence import ConfidenceCalculator, ConfidenceConfig
|
||
|
|
from app.services.context import ContextMerger
|
||
|
|
from app.models.entities import ChatMessage as ChatMessageEntity
|
||
|
|
|
||
|
|
|
||
|
|
class MockLLMClient(LLMClient):
    """Stub LLM client that records invocations and returns canned output."""

    def __init__(self, response_content: str = "Mock LLM response"):
        self._response_content = response_content
        # Call-tracking flags so tests can assert which code path ran.
        self._generate_called = False
        self._stream_generate_called = False

    async def generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Return the preconfigured content with fixed usage/finish metadata."""
        self._generate_called = True
        return LLMResponse(
            content=self._response_content,
            model="mock-model",
            usage={"prompt_tokens": 100, "completion_tokens": 50},
            finish_reason="stop",
        )

    async def stream_generate(
        self,
        messages: list[dict[str, str]],
        config: LLMConfig | None = None,
        **kwargs,
    ) -> AsyncGenerator[LLMStreamChunk, None]:
        """Yield a fixed sequence of deltas, then a terminal stop chunk."""
        self._stream_generate_called = True
        for piece in ("Hello", " from", " mock", " LLM"):
            yield LLMStreamChunk(delta=piece, model="mock-model")
        yield LLMStreamChunk(delta="", model="mock-model", finish_reason="stop")

    async def close(self) -> None:
        """No resources to release for the mock."""
        pass
|
|
class MockRetriever(BaseRetriever):
    """Stub retriever that always returns a preconfigured list of hits."""

    def __init__(self, hits: list[RetrievalHit] | None = None):
        # Falls back to an empty hit list when nothing (or None) is supplied.
        self._hits = hits or []

    async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult:
        """Ignore the query context; return the canned hits, flagged as mock."""
        return RetrievalResult(hits=self._hits, diagnostics={"mock": True})

    async def health_check(self) -> bool:
        """The mock backend is always considered healthy."""
        return True
|
|
class MockMemoryService:
    """Stub memory service that records session creation and saved messages."""

    def __init__(self, history: list[ChatMessageEntity] | None = None):
        # Preloaded history returned by load_history(); empty by default.
        self._history = history or []
        # Every message persisted via append_message/append_messages lands here.
        self._saved_messages: list[dict] = []
        self._session_created = False

    async def get_or_create_session(
        self,
        tenant_id: str,
        session_id: str,
        channel_type: str | None = None,
        metadata: dict | None = None,
    ):
        """Mark the session as created and hand back a lightweight stand-in."""
        self._session_created = True
        return MagicMock(tenant_id=tenant_id, session_id=session_id)

    async def load_history(
        self,
        tenant_id: str,
        session_id: str,
        limit: int | None = None,
    ):
        """Return the preloaded history regardless of tenant/session/limit."""
        return self._history

    async def append_message(
        self,
        tenant_id: str,
        session_id: str,
        role: str,
        content: str,
    ):
        """Record a single role/content pair for later assertions."""
        self._saved_messages.append({"role": role, "content": content})

    async def append_messages(
        self,
        tenant_id: str,
        session_id: str,
        messages: list[dict[str, str]],
    ):
        """Record a batch of message dicts verbatim."""
        self._saved_messages.extend(messages)
||
|
|
def create_chat_request(
    message: str = "Hello",
    session_id: str = "test-session",
    history: list[ChatMessage] | None = None,
    metadata: dict | None = None,
) -> ChatRequest:
    """Build a ChatRequest on the WECHAT channel with sensible test defaults."""
    return ChatRequest(
        session_id=session_id,
        current_message=message,
        channel_type=ChannelType.WECHAT,
        history=history,
        metadata=metadata,
    )
|
||
|
|
class TestOrchestratorServiceGenerate:
    """Tests for OrchestratorService.generate() method."""

    @pytest.mark.asyncio
    async def test_generate_basic_without_dependencies(self):
        """
        [AC-AISVC-01, AC-AISVC-02] With no external dependencies wired in,
        generation still succeeds via the fallback path with a bounded
        confidence score.
        """
        svc = OrchestratorService(config=OrchestratorConfig(enable_rag=False))

        result = await svc.generate(
            tenant_id="tenant-1",
            request=create_chat_request(message="What is the price?"),
        )

        assert isinstance(result, ChatResponse)
        assert result.reply is not None
        assert 0.0 <= result.confidence <= 1.0
        assert isinstance(result.should_transfer, bool)
        assert "diagnostics" in result.metadata

    @pytest.mark.asyncio
    async def test_generate_with_llm_client(self):
        """
        [AC-AISVC-02] When an LLM client is configured, its output becomes
        the reply and the client is actually invoked.
        """
        llm = MockLLMClient(response_content="This is the AI response.")
        svc = OrchestratorService(
            llm_client=llm,
            config=OrchestratorConfig(enable_rag=False),
        )

        result = await svc.generate(
            tenant_id="tenant-1",
            request=create_chat_request(message="Hello"),
        )

        assert result.reply == "This is the AI response."
        assert llm._generate_called is True

    @pytest.mark.asyncio
    async def test_generate_with_memory_service(self):
        """
        [AC-AISVC-13] Memory-backed generation loads stored history and
        persists the new user turn plus the assistant reply.
        """
        memory = MockMemoryService(
            history=[
                ChatMessageEntity(
                    tenant_id="tenant-1",
                    session_id="test-session",
                    role="user",
                    content="Previous message",
                )
            ]
        )
        svc = OrchestratorService(
            llm_client=MockLLMClient(),
            memory_service=memory,
            config=OrchestratorConfig(enable_rag=False),
        )

        await svc.generate(
            tenant_id="tenant-1",
            request=create_chat_request(message="New message"),
        )

        # Exactly one user turn and one assistant turn should be persisted.
        assert len(memory._saved_messages) == 2
        assert memory._saved_messages[0]["role"] == "user"
        assert memory._saved_messages[1]["role"] == "assistant"

    @pytest.mark.asyncio
    async def test_generate_with_retrieval(self):
        """
        [AC-AISVC-16, AC-AISVC-17] RAG retrieval runs and its hit count is
        surfaced in the response diagnostics.
        """
        retriever = MockRetriever(
            hits=[
                RetrievalHit(
                    text="Product price is $100",
                    score=0.85,
                    source="kb",
                )
            ]
        )
        svc = OrchestratorService(
            llm_client=MockLLMClient(),
            retriever=retriever,
            config=OrchestratorConfig(enable_rag=True),
        )

        result = await svc.generate(
            tenant_id="tenant-1",
            request=create_chat_request(message="What is the price?"),
        )

        diagnostics = result.metadata["diagnostics"]
        assert "retrieval" in diagnostics
        assert diagnostics["retrieval"]["hit_count"] == 1

    @pytest.mark.asyncio
    async def test_generate_with_context_merging(self):
        """
        [AC-AISVC-14, AC-AISVC-15] Locally stored history and externally
        supplied history are merged, and the counts are reported.
        """
        memory = MockMemoryService(
            history=[
                ChatMessageEntity(
                    tenant_id="tenant-1",
                    session_id="test-session",
                    role="user",
                    content="Local message",
                )
            ]
        )
        svc = OrchestratorService(
            llm_client=MockLLMClient(),
            memory_service=memory,
            config=OrchestratorConfig(enable_rag=False),
        )

        request = create_chat_request(
            message="New message",
            history=[
                ChatMessage(role=Role.USER, content="External message"),
                ChatMessage(role=Role.ASSISTANT, content="External response"),
            ],
        )
        result = await svc.generate(tenant_id="tenant-1", request=request)

        assert "merged_context" in result.metadata["diagnostics"]
        merged = result.metadata["diagnostics"]["merged_context"]
        assert merged["local_count"] == 1
        assert merged["external_count"] == 2

    @pytest.mark.asyncio
    async def test_generate_with_confidence_calculation(self):
        """
        [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Strong retrieval hits drive
        the confidence score above the midpoint.
        """
        retriever = MockRetriever(
            hits=[
                RetrievalHit(text="High relevance content", score=0.9, source="kb"),
                RetrievalHit(text="Medium relevance", score=0.8, source="kb"),
            ]
        )
        svc = OrchestratorService(
            llm_client=MockLLMClient(),
            retriever=retriever,
            config=OrchestratorConfig(enable_rag=True),
        )

        result = await svc.generate(
            tenant_id="tenant-1",
            request=create_chat_request(message="Test query"),
        )

        assert result.confidence > 0.5
        assert "confidence" in result.metadata["diagnostics"]

    @pytest.mark.asyncio
    async def test_generate_low_confidence_triggers_transfer(self):
        """
        [AC-AISVC-18, AC-AISVC-19] With no retrieval evidence at all, the
        response flags a transfer to a human agent with a reason.
        """
        svc = OrchestratorService(
            llm_client=MockLLMClient(),
            retriever=MockRetriever(hits=[]),
            config=OrchestratorConfig(enable_rag=True),
        )

        result = await svc.generate(
            tenant_id="tenant-1",
            request=create_chat_request(message="Unknown topic"),
        )

        assert result.should_transfer is True
        assert result.transfer_reason is not None

    @pytest.mark.asyncio
    async def test_generate_handles_llm_error(self):
        """
        [AC-AISVC-02] An LLM failure degrades to a fallback reply and
        records the error in diagnostics instead of raising.
        """
        failing_llm = MagicMock()
        failing_llm.generate = AsyncMock(side_effect=Exception("LLM unavailable"))

        svc = OrchestratorService(
            llm_client=failing_llm,
            config=OrchestratorConfig(enable_rag=False),
        )

        result = await svc.generate(
            tenant_id="tenant-1",
            request=create_chat_request(message="Hello"),
        )

        assert result.reply is not None
        assert "llm_error" in result.metadata["diagnostics"]

    @pytest.mark.asyncio
    async def test_generate_handles_retrieval_error(self):
        """
        [AC-AISVC-16] A retriever failure is tolerated: generation continues
        with no evidence and the error is recorded in diagnostics.
        """
        failing_retriever = MagicMock()
        failing_retriever.retrieve = AsyncMock(
            side_effect=Exception("Qdrant unavailable")
        )

        svc = OrchestratorService(
            llm_client=MockLLMClient(),
            retriever=failing_retriever,
            config=OrchestratorConfig(enable_rag=True),
        )

        result = await svc.generate(
            tenant_id="tenant-1",
            request=create_chat_request(message="Hello"),
        )

        assert result.reply == "Mock LLM response"
        assert "retrieval_error" in result.metadata["diagnostics"]

    @pytest.mark.asyncio
    async def test_generate_full_pipeline_integration(self):
        """
        [AC-AISVC-01, AC-AISVC-02] End-to-end: memory, retrieval and LLM all
        wired together produce a live reply with full diagnostics.
        """
        memory = MockMemoryService(
            history=[
                ChatMessageEntity(
                    tenant_id="tenant-1",
                    session_id="test-session",
                    role="user",
                    content="Previous question",
                ),
                ChatMessageEntity(
                    tenant_id="tenant-1",
                    session_id="test-session",
                    role="assistant",
                    content="Previous answer",
                ),
            ]
        )
        retriever = MockRetriever(
            hits=[
                RetrievalHit(text="Knowledge base content", score=0.85, source="kb"),
            ]
        )
        svc = OrchestratorService(
            llm_client=MockLLMClient(response_content="AI generated response"),
            memory_service=memory,
            retriever=retriever,
            config=OrchestratorConfig(enable_rag=True),
        )

        request = create_chat_request(
            message="New question",
            history=[
                ChatMessage(role=Role.USER, content="External history"),
            ],
        )
        result = await svc.generate(tenant_id="tenant-1", request=request)

        assert result.reply == "AI generated response"
        assert result.confidence > 0.0
        assert len(memory._saved_messages) == 2

        diagnostics = result.metadata["diagnostics"]
        assert diagnostics["memory_enabled"] is True
        assert diagnostics["retrieval"]["hit_count"] == 1
        assert diagnostics["llm_mode"] == "live"
|
class TestOrchestratorServiceGenerationContext:
    """Tests for GenerationContext dataclass."""

    def test_generation_context_initialization(self):
        """A freshly built context exposes its fields and empty defaults."""
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Hello",
            channel_type="wechat",
        )

        assert ctx.tenant_id == "tenant-1"
        assert ctx.session_id == "session-1"
        assert ctx.current_message == "Hello"
        assert ctx.channel_type == "wechat"
        # Collection fields default to empty containers, never None.
        assert ctx.local_history == []
        assert ctx.diagnostics == {}

    def test_generation_context_with_metadata(self):
        """Request metadata passed at construction is stored verbatim."""
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Hello",
            channel_type="wechat",
            request_metadata={"user_id": "user-123"},
        )

        assert ctx.request_metadata == {"user_id": "user-123"}
|
|
class TestOrchestratorConfig:
    """Tests for OrchestratorConfig dataclass."""

    def test_default_config(self):
        """Defaults: token budgets, RAG on, customer-service system prompt."""
        cfg = OrchestratorConfig()

        assert cfg.max_history_tokens == 4000
        assert cfg.max_evidence_tokens == 2000
        assert cfg.enable_rag is True
        assert "智能客服" in cfg.system_prompt

    def test_custom_config(self):
        """Explicit constructor arguments override the defaults."""
        cfg = OrchestratorConfig(
            max_history_tokens=8000,
            enable_rag=False,
            system_prompt="Custom prompt",
        )

        assert cfg.max_history_tokens == 8000
        assert cfg.enable_rag is False
        assert cfg.system_prompt == "Custom prompt"
|
|
class TestOrchestratorServiceHelperMethods:
    """Tests for OrchestratorService helper methods."""

    def test_build_llm_messages_basic(self):
        """Without history or evidence: system prompt + current user turn."""
        svc = OrchestratorService(config=OrchestratorConfig(enable_rag=False))
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Hello",
            channel_type="wechat",
        )

        messages = svc._build_llm_messages(ctx)

        assert len(messages) == 2
        assert messages[0]["role"] == "system"
        assert messages[1]["role"] == "user"
        assert messages[1]["content"] == "Hello"

    def test_build_llm_messages_with_evidence(self):
        """Retrieval hits are injected into the system prompt as KB evidence."""
        svc = OrchestratorService(config=OrchestratorConfig(enable_rag=True))
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="What is the price?",
            channel_type="wechat",
            retrieval_result=RetrievalResult(
                hits=[
                    RetrievalHit(text="Price is $100", score=0.9, source="kb"),
                ]
            ),
        )

        messages = svc._build_llm_messages(ctx)

        assert "知识库参考内容" in messages[0]["content"]
        assert "Price is $100" in messages[0]["content"]

    def test_build_llm_messages_with_history(self):
        """Merged history sits between the system prompt and the current turn."""
        from app.services.context import MergedContext

        svc = OrchestratorService(config=OrchestratorConfig(enable_rag=False))
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="New question",
            channel_type="wechat",
            merged_context=MergedContext(
                messages=[
                    {"role": "user", "content": "Previous question"},
                    {"role": "assistant", "content": "Previous answer"},
                ]
            ),
        )

        messages = svc._build_llm_messages(ctx)

        assert len(messages) == 4
        assert messages[1]["role"] == "user"
        assert messages[1]["content"] == "Previous question"
        assert messages[2]["role"] == "assistant"
        assert messages[3]["role"] == "user"
        assert messages[3]["content"] == "New question"

    def test_fallback_response_with_evidence(self):
        """With evidence present the fallback text references the knowledge base."""
        svc = OrchestratorService()
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Question",
            channel_type="wechat",
            retrieval_result=RetrievalResult(
                hits=[RetrievalHit(text="Evidence", score=0.8, source="kb")]
            ),
        )

        assert "知识库" in svc._fallback_response(ctx)

    def test_fallback_response_without_evidence(self):
        """Without evidence the fallback suggests a human agent / inability."""
        svc = OrchestratorService()
        ctx = GenerationContext(
            tenant_id="tenant-1",
            session_id="session-1",
            current_message="Question",
            channel_type="wechat",
            retrieval_result=RetrievalResult(hits=[]),
        )

        text = svc._fallback_response(ctx)
        assert "无法处理" in text or "人工客服" in text

    def test_format_evidence(self):
        """Hits are rendered as a numbered [i] list containing each snippet."""
        svc = OrchestratorService()
        result = RetrievalResult(
            hits=[
                RetrievalHit(text="First result", score=0.9, source="kb"),
                RetrievalHit(text="Second result", score=0.8, source="kb"),
            ]
        )

        formatted = svc._format_evidence(result)

        for fragment in ("[1]", "[2]", "First result", "Second result"):
            assert fragment in formatted
||
|
|
class TestOrchestratorServiceSetInstance:
    """Tests for set_orchestrator_service function."""

    def test_set_orchestrator_service(self):
        """The module getter returns exactly the instance the setter stored."""
        from app.services.orchestrator import get_orchestrator_service

        custom = OrchestratorService(
            config=OrchestratorConfig(enable_rag=False),
        )
        set_orchestrator_service(custom)

        assert get_orchestrator_service() is custom