# Listing metadata: 216 lines, 7.2 KiB, Python.
"""
|
|
Unit tests for ScriptGenerator.
|
|
[AC-IDS-04, AC-IDS-11] Test script generation for flexible mode.
|
|
"""
|
|
|
|
import asyncio
|
|
import pytest
|
|
|
|
from app.services.flow.script_generator import ScriptGenerator
|
|
|
|
|
|
class MockLLMClient:
    """Mock LLM client for testing.

    Both entry points return the same canned response, optionally after an
    artificial delay so that a caller's slow-response handling can be
    exercised.
    """

    def __init__(self, response: str = "您好,请问怎么称呼您?", delay: float = 0):
        """Store the canned reply and the simulated latency.

        Args:
            response: Text returned by every generation call.
            delay: Seconds to sleep before replying; values <= 0 skip the
                sleep entirely.
        """
        self._response = response
        self._delay = delay

    async def _simulate_latency(self) -> None:
        """Sleep for the configured delay, if a positive one was requested."""
        if self._delay > 0:
            await asyncio.sleep(self._delay)

    async def generate_text(self, prompt: str) -> str:
        """Return the canned response as plain text.

        Args:
            prompt: Ignored; accepted only to match the client interface.

        Returns:
            The response string supplied at construction time.
        """
        await self._simulate_latency()
        return self._response

    async def generate(self, messages: list) -> "MockResponse":
        """Return the canned response wrapped in a ``MockResponse``.

        Args:
            messages: Ignored; accepted only to match the client interface.

        Returns:
            A ``MockResponse`` whose ``content`` is the canned response.
        """
        await self._simulate_latency()
        return MockResponse(self._response)
|
|
|
|
|
|
class MockResponse:
    """Mock LLM response.

    Carries the generated text in a single ``content`` attribute.
    """

    def __init__(self, content: str):
        """Store ``content`` as the response text."""
        self.content = content

    def __repr__(self) -> str:
        """Return a debug representation so assertion failures are readable."""
        return f"{type(self).__name__}(content={self.content!r})"
|
|
|
|
|
|
class TestScriptGenerator:
    """[AC-IDS-04, AC-IDS-11] Test cases for ScriptGenerator."""

    # Shared inputs: every test targets the same step intent, and all of the
    # generate() tests pass the same fallback script.
    _INTENT = "获取用户姓名"
    _FALLBACK = "请问怎么称呼您?"

    @pytest.mark.asyncio
    async def test_generate_fixed_mode_returns_fallback(self):
        """Test that fixed mode (no LLM client) returns fallback content."""
        generator = ScriptGenerator(llm_client=None)

        result = await generator.generate(
            intent=self._INTENT,
            intent_description="礼貌询问用户姓名",
            constraints=["必须礼貌"],
            context=None,
            history=None,
            fallback=self._FALLBACK,
        )

        assert result == self._FALLBACK

    @pytest.mark.asyncio
    async def test_generate_with_llm_client(self):
        """Test that the LLM client's reply is returned when it succeeds."""
        llm_client = MockLLMClient(response="您好,请问您贵姓?")
        generator = ScriptGenerator(llm_client=llm_client)

        result = await generator.generate(
            intent=self._INTENT,
            intent_description="礼貌询问用户姓名",
            constraints=["必须礼貌", "语气自然"],
            context={"inputs": [{"step": 1, "input": "我想咨询"}]},
            history=[{"role": "user", "content": "我想咨询"}],
            fallback=self._FALLBACK,
        )

        assert result == "您好,请问您贵姓?"

    @pytest.mark.asyncio
    async def test_generate_timeout_fallback(self):
        """Test that a slow LLM client results in the fallback content."""
        # NOTE(review): the 6 s delay is presumably chosen to exceed
        # ScriptGenerator's internal LLM timeout - confirm against the
        # implementation. The mock sleeps in real time, so this test is slow.
        llm_client = MockLLMClient(response="生成的话术", delay=6.0)
        generator = ScriptGenerator(llm_client=llm_client)

        result = await generator.generate(
            intent=self._INTENT,
            intent_description=None,
            constraints=None,
            context=None,
            history=None,
            fallback=self._FALLBACK,
        )

        assert result == self._FALLBACK

    @pytest.mark.asyncio
    async def test_generate_exception_fallback(self):
        """Test that an exception from the LLM client returns the fallback."""

        class FailingLLMClient:
            """LLM client stub whose text generation always raises."""

            async def generate_text(self, prompt: str) -> str:
                raise RuntimeError("LLM service unavailable")

        generator = ScriptGenerator(llm_client=FailingLLMClient())

        result = await generator.generate(
            intent=self._INTENT,
            intent_description=None,
            constraints=None,
            context=None,
            history=None,
            fallback=self._FALLBACK,
        )

        assert result == self._FALLBACK

    def test_build_prompt_basic(self):
        """Test prompt building with only the intent supplied."""
        generator = ScriptGenerator(llm_client=None)

        prompt = generator._build_prompt(
            intent=self._INTENT,
            intent_description=None,
            constraints=None,
            context=None,
            history=None,
        )

        assert self._INTENT in prompt
        assert "步骤目标" in prompt

    def test_build_prompt_with_description(self):
        """Test prompt building with an intent description."""
        generator = ScriptGenerator(llm_client=None)

        prompt = generator._build_prompt(
            intent=self._INTENT,
            intent_description="需要获取用户的真实姓名用于后续身份确认",
            constraints=None,
            context=None,
            history=None,
        )

        assert self._INTENT in prompt
        assert "需要获取用户的真实姓名用于后续身份确认" in prompt
        assert "详细说明" in prompt

    def test_build_prompt_with_constraints(self):
        """Test that each constraint is rendered as a bullet in the prompt."""
        generator = ScriptGenerator(llm_client=None)

        prompt = generator._build_prompt(
            intent=self._INTENT,
            intent_description=None,
            constraints=["必须礼貌", "语气自然", "不要生硬"],
            context=None,
            history=None,
        )

        assert "约束条件" in prompt
        assert "- 必须礼貌" in prompt
        assert "- 语气自然" in prompt
        assert "- 不要生硬" in prompt

    def test_build_prompt_with_history(self):
        """Test that conversation history is rendered with role labels."""
        generator = ScriptGenerator(llm_client=None)

        prompt = generator._build_prompt(
            intent=self._INTENT,
            intent_description=None,
            constraints=None,
            context=None,
            history=[
                {"role": "user", "content": "你好"},
                {"role": "assistant", "content": "您好,有什么可以帮您?"},
                {"role": "user", "content": "我想咨询"},
            ],
        )

        assert "对话历史" in prompt
        assert "用户: 你好" in prompt
        assert "客服: 您好,有什么可以帮您?" in prompt

    def test_build_prompt_with_context(self):
        """Test that collected session inputs are listed per step."""
        generator = ScriptGenerator(llm_client=None)

        prompt = generator._build_prompt(
            intent=self._INTENT,
            intent_description=None,
            constraints=None,
            context={
                "inputs": [
                    {"step": 1, "input": "我想咨询产品"},
                    {"step": 2, "input": "手机"},
                ]
            },
            history=None,
        )

        assert "已收集信息" in prompt
        assert "步骤1: 我想咨询产品" in prompt
        assert "步骤2: 手机" in prompt

    def test_build_prompt_complete(self):
        """Test prompt building with all parameters supplied at once."""
        generator = ScriptGenerator(llm_client=None)

        prompt = generator._build_prompt(
            intent=self._INTENT,
            intent_description="需要获取用户的真实姓名",
            constraints=["必须礼貌", "语气自然"],
            context={"inputs": [{"step": 1, "input": "咨询"}]},
            history=[{"role": "user", "content": "你好"}],
        )

        # Every optional section header must be present, plus the length limit.
        assert "步骤目标" in prompt
        assert "详细说明" in prompt
        assert "约束条件" in prompt
        assert "对话历史" in prompt
        assert "已收集信息" in prompt
        assert "不超过200字" in prompt
|