[AC-AISVC-02, AC-AISVC-16] 多个需求合并 #1

Merged
MerCry merged 45 commits from feature/prompt-unification-and-logging into main 2026-02-25 17:17:35 +00:00
5 changed files with 939 additions and 14 deletions
Showing only changes of commit 550d0d8498 - Show all commits

View File

@@ -101,7 +101,10 @@ def create_final_event(
transfer_reason=transfer_reason,
metadata=metadata,
)
return format_sse_event(SSEEventType.FINAL, event_data.model_dump(exclude_none=True))
return format_sse_event(
SSEEventType.FINAL,
event_data.model_dump(exclude_none=True, by_alias=True)
)
def create_error_event(

View File

@@ -0,0 +1,245 @@
"""
Context management utilities for AI Service.
[AC-AISVC-14, AC-AISVC-15] Context merging and truncation strategies.
Design reference: design.md Section 7 - 上下文合并规则
- H_local: Memory layer history (sorted by time)
- H_ext: External history from Java request (in passed order)
- Deduplication: fingerprint = hash(role + "|" + normalized(content))
- Truncation: Keep most recent N messages within token budget
"""
import hashlib
import logging
from dataclasses import dataclass, field
from typing import Any
import tiktoken
from app.core.config import get_settings
from app.models import ChatMessage, Role
logger = logging.getLogger(__name__)
@dataclass
class MergedContext:
    """
    Result of context merging.
    [AC-AISVC-14, AC-AISVC-15] Contains merged messages and diagnostics.
    """

    messages: list[dict[str, str]] = field(default_factory=list)  # merged role/content dicts
    total_tokens: int = 0  # tiktoken count of `messages` (incl. per-message overhead)
    local_count: int = 0  # messages taken from local (Memory-layer) history
    external_count: int = 0  # non-duplicate messages taken from external history
    duplicates_skipped: int = 0  # external messages dropped as duplicates of local ones
    truncated_count: int = 0  # messages removed by token-budget truncation
    diagnostics: list[dict[str, Any]] = field(default_factory=list)  # e.g. "duplicate_skipped" records
class ContextMerger:
    """
    [AC-AISVC-14, AC-AISVC-15] Context merger for combining local and external history.

    Design reference: design.md Section 7
    - Deduplication based on message fingerprint
    - Priority: local history takes precedence
    - Token-based truncation using tiktoken
    """

    # Approximate structural overhead (role markers / separators) charged per message.
    _MESSAGE_OVERHEAD_TOKENS = 4

    def __init__(
        self,
        max_history_tokens: int | None = None,
        encoding_name: str = "cl100k_base",
    ):
        """
        Initialize the merger.

        Args:
            max_history_tokens: Token budget used by truncation. Resolution
                order: this argument, then ``settings.max_history_tokens``
                (when it is an int), then the hard default 4096.
            encoding_name: tiktoken encoding name used for token counting.
        """
        settings = get_settings()
        # The settings object used to be fetched but never consulted; honor a
        # configured integer budget before falling back to the hard default.
        configured = getattr(settings, "max_history_tokens", None)
        self._max_history_tokens = (
            max_history_tokens
            or (configured if isinstance(configured, int) else None)
            or 4096
        )
        self._encoding = tiktoken.get_encoding(encoding_name)

    def compute_fingerprint(self, role: str, content: str) -> str:
        """
        Compute message fingerprint for deduplication.

        [AC-AISVC-15] fingerprint = hash(role + "|" + normalized(content))

        Args:
            role: Message role (user/assistant)
            content: Message content

        Returns:
            SHA256 hex digest of the normalized message
        """
        normalized_content = content.strip()
        fingerprint_input = f"{role}|{normalized_content}"
        return hashlib.sha256(fingerprint_input.encode("utf-8")).hexdigest()

    def _message_to_dict(self, message: ChatMessage | dict[str, str]) -> dict[str, str]:
        """Convert ChatMessage or dict to standard {"role", "content"} dict format."""
        if isinstance(message, ChatMessage):
            return {"role": message.role.value, "content": message.content}
        return message

    def _message_tokens(self, msg: dict[str, str]) -> int:
        """Token cost of a single message: role + content + structural overhead."""
        return (
            len(self._encoding.encode(msg.get("role", "")))
            + len(self._encoding.encode(msg.get("content", "")))
            + self._MESSAGE_OVERHEAD_TOKENS
        )

    def _count_tokens(self, messages: list[dict[str, str]]) -> int:
        """
        Count total tokens in messages using tiktoken.
        [AC-AISVC-14] Token counting for history truncation.
        """
        return sum(self._message_tokens(msg) for msg in messages)

    def merge_context(
        self,
        local_history: list[ChatMessage] | list[dict[str, str]] | None,
        external_history: list[ChatMessage] | list[dict[str, str]] | None,
    ) -> MergedContext:
        """
        Merge local and external history with deduplication.

        [AC-AISVC-14, AC-AISVC-15] Implements context merging strategy.
        Design reference: design.md Section 7.2
        1. Build seen set from H_local
        2. Traverse H_ext, append if fingerprint not seen
        3. Local history takes priority

        Args:
            local_history: History from Memory layer (H_local)
            external_history: History from Java request (H_ext)

        Returns:
            MergedContext with merged messages and diagnostics
        """
        result = MergedContext()
        seen_fingerprints: set[str] = set()
        merged_messages: list[dict[str, str]] = []
        diagnostics: list[dict[str, Any]] = []

        local_messages = [self._message_to_dict(m) for m in (local_history or [])]
        external_messages = [self._message_to_dict(m) for m in (external_history or [])]

        # Local history is authoritative: every local message is kept and its
        # fingerprint seeds the dedup set.
        for msg in local_messages:
            seen_fingerprints.add(self.compute_fingerprint(msg["role"], msg["content"]))
            merged_messages.append(msg)
            result.local_count += 1

        # External messages are appended only when not already present locally.
        for msg in external_messages:
            fingerprint = self.compute_fingerprint(msg["role"], msg["content"])
            if fingerprint not in seen_fingerprints:
                seen_fingerprints.add(fingerprint)
                merged_messages.append(msg)
                result.external_count += 1
            else:
                result.duplicates_skipped += 1
                diagnostics.append({
                    "type": "duplicate_skipped",
                    "role": msg["role"],
                    "content_preview": msg["content"][:50] + "..." if len(msg["content"]) > 50 else msg["content"],
                })

        result.messages = merged_messages
        result.diagnostics = diagnostics
        result.total_tokens = self._count_tokens(merged_messages)
        # Lazy %-style args avoid string formatting when INFO is disabled.
        logger.info(
            "[AC-AISVC-14, AC-AISVC-15] Context merged: "
            "local=%d, external=%d, duplicates_skipped=%d, total_tokens=%d",
            result.local_count,
            result.external_count,
            result.duplicates_skipped,
            result.total_tokens,
        )
        return result

    def truncate_context(
        self,
        messages: list[dict[str, str]],
        max_tokens: int | None = None,
    ) -> tuple[list[dict[str, str]], int]:
        """
        Truncate context to fit within token budget.

        [AC-AISVC-14] Keep most recent N messages within budget.
        Design reference: design.md Section 7.4
        - Budget = maxHistoryTokens (configurable)
        - Strategy: Keep most recent messages (from tail backward)

        Args:
            messages: List of messages to truncate
            max_tokens: Maximum token budget (uses default if not provided)

        Returns:
            Tuple of (truncated messages, truncated count)
        """
        budget = max_tokens or self._max_history_tokens
        if not messages:
            return [], 0
        if self._count_tokens(messages) <= budget:
            return messages, 0

        kept_reversed: list[dict[str, str]] = []
        current_tokens = 0
        truncated_count = 0
        # Walk from the newest message backward and stop at the first message
        # that does not fit, so the kept history is a contiguous recent suffix.
        # (The previous implementation kept scanning past a non-fitting message
        # and could retain an older message while dropping a newer one, leaving
        # gaps in the conversation.)  Appending then reversing once is O(n),
        # versus the old insert(0, ...) per message which was O(n^2).
        for idx in range(len(messages) - 1, -1, -1):
            msg_tokens = self._message_tokens(messages[idx])
            if current_tokens + msg_tokens > budget:
                truncated_count = idx + 1
                break
            kept_reversed.append(messages[idx])
            current_tokens += msg_tokens
        truncated_messages = kept_reversed[::-1]

        logger.info(
            "[AC-AISVC-14] Context truncated: original=%d, truncated=%d, removed=%d, tokens=%d/%d",
            len(messages),
            len(truncated_messages),
            truncated_count,
            current_tokens,
            budget,
        )
        return truncated_messages, truncated_count

    def merge_and_truncate(
        self,
        local_history: list[ChatMessage] | list[dict[str, str]] | None,
        external_history: list[ChatMessage] | list[dict[str, str]] | None,
        max_tokens: int | None = None,
    ) -> MergedContext:
        """
        Merge and truncate context in one operation.

        [AC-AISVC-14, AC-AISVC-15] Complete context preparation pipeline.

        Args:
            local_history: History from Memory layer (H_local)
            external_history: History from Java request (H_ext)
            max_tokens: Maximum token budget

        Returns:
            MergedContext with final messages after merge and truncate
        """
        merged = self.merge_context(local_history, external_history)
        truncated_messages, truncated_count = self.truncate_context(
            merged.messages, max_tokens
        )
        merged.messages = truncated_messages
        merged.truncated_count = truncated_count
        merged.total_tokens = self._count_tokens(truncated_messages)
        return merged
_context_merger: ContextMerger | None = None


def get_context_merger() -> ContextMerger:
    """Return the shared ContextMerger, constructing it lazily on first call."""
    global _context_merger
    merger = _context_merger
    if merger is None:
        merger = ContextMerger()
        _context_merger = merger
    return merger

View File

@@ -1,6 +1,6 @@
"""
Orchestrator service for AI Service.
[AC-AISVC-01, AC-AISVC-02] Core orchestration logic for chat generation.
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Core orchestration logic for chat generation.
"""
import logging
@@ -9,17 +9,31 @@ from typing import AsyncGenerator
from sse_starlette.sse import ServerSentEvent
from app.models import ChatRequest, ChatResponse
from app.core.sse import create_final_event, create_message_event, SSEStateMachine
from app.core.sse import create_error_event, create_final_event, create_message_event, SSEStateMachine
logger = logging.getLogger(__name__)
class OrchestratorService:
"""
[AC-AISVC-01, AC-AISVC-02] Orchestrator for chat generation.
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Orchestrator for chat generation.
Coordinates memory, retrieval, and LLM components.
SSE Event Flow (per design.md Section 6.2):
- message* (0 or more) -> final (exactly 1) -> close
- OR message* (0 or more) -> error (exactly 1) -> close
"""
def __init__(self, llm_client=None):
    """
    Initialize orchestrator with optional LLM client.

    Args:
        llm_client: Optional LLM client for dependency injection.
            Expected to expose ``generate`` / ``stream_generate``
            (as used by this class); if None, a mock streaming
            implementation is used for demo purposes.
    """
    self._llm_client = llm_client
async def generate(self, tenant_id: str, request: ChatRequest) -> ChatResponse:
"""
Generate a non-streaming response.
@@ -30,6 +44,15 @@ class OrchestratorService:
f"session={request.session_id}"
)
if self._llm_client:
messages = self._build_messages(request)
response = await self._llm_client.generate(messages)
return ChatResponse(
reply=response.content,
confidence=0.85,
should_transfer=False,
)
reply = f"Received your message: {request.current_message}"
return ChatResponse(
reply=reply,
@@ -43,6 +66,16 @@ class OrchestratorService:
"""
Generate a streaming response.
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence.
SSE Event Sequence (per design.md Section 6.2):
1. message events (multiple) - each with incremental delta
2. final event (exactly 1) - with complete response
3. connection close
OR on error:
1. message events (0 or more)
2. error event (exactly 1)
3. connection close
"""
logger.info(
f"[AC-AISVC-06] Starting streaming generation for tenant={tenant_id}, "
@@ -53,14 +86,18 @@ class OrchestratorService:
await state_machine.transition_to_streaming()
try:
reply_parts = ["Received", " your", " message:", f" {request.current_message}"]
full_reply = ""
for part in reply_parts:
if state_machine.can_send_message():
full_reply += part
yield create_message_event(delta=part)
await self._simulate_llm_delay()
if self._llm_client:
async for event in self._stream_from_llm(request, state_machine):
if event.event == "message":
full_reply += self._extract_delta_from_event(event)
yield event
else:
async for event in self._stream_mock_response(request, state_machine):
if event.event == "message":
full_reply += self._extract_delta_from_event(event)
yield event
if await state_machine.transition_to_final():
yield create_final_event(
@@ -72,7 +109,6 @@ class OrchestratorService:
except Exception as e:
logger.error(f"[AC-AISVC-09] Error during streaming: {e}")
if await state_machine.transition_to_error():
from app.core.sse import create_error_event
yield create_error_event(
code="GENERATION_ERROR",
message=str(e),
@@ -80,10 +116,73 @@ class OrchestratorService:
finally:
await state_machine.close()
async def _simulate_llm_delay(self) -> None:
"""Simulate LLM processing delay for demo purposes."""
async def _stream_from_llm(
    self, request: ChatRequest, state_machine: SSEStateMachine
) -> AsyncGenerator[ServerSentEvent, None]:
    """
    [AC-AISVC-07] Stream from LLM client, wrapping each chunk as message event.

    Stops early when the state machine no longer accepts message events or
    when the LLM reports a finish reason.
    """
    llm_messages = self._build_messages(request)
    async for chunk in self._llm_client.stream_generate(llm_messages):
        if not state_machine.can_send_message():
            break
        delta = chunk.delta
        if delta:
            logger.debug(f"[AC-AISVC-07] Yielding message event with delta: {delta[:50]}...")
            yield create_message_event(delta=delta)
        if chunk.finish_reason:
            logger.info(f"[AC-AISVC-07] LLM stream finished with reason: {chunk.finish_reason}")
            break
async def _stream_mock_response(
    self, request: ChatRequest, state_machine: SSEStateMachine
) -> AsyncGenerator[ServerSentEvent, None]:
    """
    [AC-AISVC-07] Mock streaming response for demo/testing purposes.
    Simulates LLM-style incremental output with small delays between deltas.
    """
    import asyncio

    await asyncio.sleep(0.1)
    deltas = ["Received", " your", " message:", f" {request.current_message}"]
    for delta in deltas:
        if not state_machine.can_send_message():
            break
        logger.debug(f"[AC-AISVC-07] Yielding mock message event with delta: {delta}")
        yield create_message_event(delta=delta)
        await asyncio.sleep(0.05)
def _build_messages(self, request: ChatRequest) -> list[dict[str, str]]:
    """Build the LLM message list: prior history (if any) plus the current user turn."""
    history = request.history or []
    payload = [
        {"role": entry.role.value, "content": entry.content}
        for entry in history
    ]
    payload.append({"role": "user", "content": request.current_message})
    return payload
def _extract_delta_from_event(self, event: ServerSentEvent) -> str:
    """Extract the "delta" field from a message event's JSON payload ("" on failure)."""
    import json

    payload = event.data
    if not payload:
        return ""
    try:
        return json.loads(payload).get("delta", "")
    except (json.JSONDecodeError, TypeError):
        return ""
# Lazily-created module-level singleton; presumably populated by an accessor
# outside this view — TODO confirm.
_orchestrator_service: OrchestratorService | None = None

View File

@@ -0,0 +1,287 @@
"""
Unit tests for Context Merger.
[AC-AISVC-14, AC-AISVC-15] Tests for context merging and truncation.
Tests cover:
- Message fingerprint computation
- Context merging with deduplication
- Token-based truncation
- Complete merge_and_truncate pipeline
"""
import hashlib
from unittest.mock import MagicMock, patch
import pytest
from app.models import ChatMessage, Role
from app.services.context import ContextMerger, MergedContext, get_context_merger
@pytest.fixture
def mock_settings():
    """Provide a stand-in settings object for tests."""
    return MagicMock()
@pytest.fixture
def context_merger(mock_settings):
    """Create a ContextMerger with patched settings and a 1000-token budget."""
    with patch("app.services.context.get_settings", return_value=mock_settings):
        yield ContextMerger(max_history_tokens=1000)
@pytest.fixture
def local_history():
    """Sample local (Memory-layer) history; first two turns also appear in external_history."""
    return [
        ChatMessage(role=Role.USER, content="Hello"),
        ChatMessage(role=Role.ASSISTANT, content="Hi there!"),
        ChatMessage(role=Role.USER, content="How are you?"),
    ]
@pytest.fixture
def external_history():
    """Sample external (request-supplied) history; first two turns duplicate local_history."""
    return [
        ChatMessage(role=Role.USER, content="Hello"),
        ChatMessage(role=Role.ASSISTANT, content="Hi there!"),
        ChatMessage(role=Role.USER, content="What's the weather?"),
    ]
@pytest.fixture
def dict_local_history():
    """Sample local history in plain-dict format (exercises _message_to_dict pass-through)."""
    return [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]
@pytest.fixture
def dict_external_history():
    """Sample external history in plain-dict format; first entry duplicates dict_local_history."""
    return [
        {"role": "user", "content": "Hello"},
        {"role": "user", "content": "What's the weather?"},
    ]
class TestFingerprintComputation:
    """Tests for message fingerprint computation. [AC-AISVC-15]"""

    def test_fingerprint_consistency(self, context_merger):
        """Identical role/content input must always hash to the same fingerprint."""
        first = context_merger.compute_fingerprint("user", "Hello world")
        repeat = context_merger.compute_fingerprint("user", "Hello world")
        assert first == repeat

    def test_fingerprint_role_difference(self, context_merger):
        """Same content under different roles must hash differently."""
        assert (
            context_merger.compute_fingerprint("user", "Hello")
            != context_merger.compute_fingerprint("assistant", "Hello")
        )

    def test_fingerprint_content_difference(self, context_merger):
        """Different content under the same role must hash differently."""
        assert (
            context_merger.compute_fingerprint("user", "Hello")
            != context_merger.compute_fingerprint("user", "World")
        )

    def test_fingerprint_normalization(self, context_merger):
        """Surrounding whitespace is stripped before hashing."""
        bare = context_merger.compute_fingerprint("user", "Hello")
        padded = context_merger.compute_fingerprint("user", " Hello ")
        assert bare == padded

    def test_fingerprint_is_sha256(self, context_merger):
        """Fingerprint equals the SHA256 hex digest of "role|content"."""
        digest = context_merger.compute_fingerprint("user", "Hello")
        assert digest == hashlib.sha256("user|Hello".encode("utf-8")).hexdigest()
        assert len(digest) == 64  # SHA256 hex digest length
class TestContextMerging:
    """Tests for context merging with deduplication. [AC-AISVC-14, AC-AISVC-15]"""

    def test_merge_empty_histories(self, context_merger):
        """[AC-AISVC-14] Test merging empty histories."""
        result = context_merger.merge_context(None, None)
        assert isinstance(result, MergedContext)
        assert result.messages == []
        assert result.local_count == 0
        assert result.external_count == 0
        assert result.duplicates_skipped == 0

    def test_merge_local_only(self, context_merger, local_history):
        """[AC-AISVC-14] Test merging with only local history (no external)."""
        result = context_merger.merge_context(local_history, None)
        assert len(result.messages) == 3
        assert result.local_count == 3
        assert result.external_count == 0
        assert result.duplicates_skipped == 0

    def test_merge_external_only(self, context_merger, external_history):
        """[AC-AISVC-15] Test merging with only external history (no local)."""
        result = context_merger.merge_context(None, external_history)
        assert len(result.messages) == 3
        assert result.local_count == 0
        assert result.external_count == 3
        assert result.duplicates_skipped == 0

    def test_merge_with_duplicates(self, context_merger, local_history, external_history):
        """[AC-AISVC-15] Test deduplication when merging overlapping histories."""
        result = context_merger.merge_context(local_history, external_history)
        # 3 local + 3 external; the first two external turns duplicate local ones.
        assert len(result.messages) == 4
        assert result.local_count == 3
        assert result.external_count == 1
        assert result.duplicates_skipped == 2
        # NOTE(review): `roles` is unused — candidate for removal.
        roles = [m["role"] for m in result.messages]
        contents = [m["content"] for m in result.messages]
        assert "What's the weather?" in contents

    def test_merge_with_dict_histories(self, context_merger, dict_local_history, dict_external_history):
        """[AC-AISVC-14, AC-AISVC-15] Test merging with dict format histories."""
        result = context_merger.merge_context(dict_local_history, dict_external_history)
        assert len(result.messages) == 3
        assert result.local_count == 2
        assert result.external_count == 1
        assert result.duplicates_skipped == 1

    def test_merge_priority_local(self, context_merger):
        """[AC-AISVC-15] Test that local history takes priority."""
        local = [ChatMessage(role=Role.USER, content="Hello")]
        external = [ChatMessage(role=Role.USER, content="Hello")]
        result = context_merger.merge_context(local, external)
        assert len(result.messages) == 1
        assert result.duplicates_skipped == 1

    def test_merge_records_diagnostics(self, context_merger, local_history, external_history):
        """[AC-AISVC-15] Test that duplicates are recorded in diagnostics."""
        result = context_merger.merge_context(local_history, external_history)
        assert len(result.diagnostics) == 2
        for diag in result.diagnostics:
            assert diag["type"] == "duplicate_skipped"
            assert "role" in diag
            assert "content_preview" in diag
class TestTokenTruncation:
    """Tests for token-based truncation. [AC-AISVC-14]"""

    def test_truncate_empty_messages(self, context_merger):
        """[AC-AISVC-14] Test truncating empty message list."""
        truncated, count = context_merger.truncate_context([], 100)
        assert truncated == []
        assert count == 0

    def test_truncate_within_budget(self, context_merger):
        """[AC-AISVC-14] Test that messages within budget are not truncated."""
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi"},
        ]
        truncated, count = context_merger.truncate_context(messages, 1000)
        assert len(truncated) == 2
        assert count == 0

    def test_truncate_exceeds_budget(self, context_merger):
        """[AC-AISVC-14] Test that messages exceeding budget are truncated."""
        messages = [
            {"role": "user", "content": "Hello world " * 100},
            {"role": "assistant", "content": "Hi there " * 100},
            {"role": "user", "content": "Short message"},
        ]
        truncated, count = context_merger.truncate_context(messages, 50)
        assert len(truncated) < len(messages)
        assert count > 0

    def test_truncate_keeps_recent_messages(self, context_merger):
        """[AC-AISVC-14] Test that truncation keeps most recent messages."""
        messages = [
            {"role": "user", "content": "First message"},
            {"role": "assistant", "content": "Second message"},
            {"role": "user", "content": "Third message"},
        ]
        truncated, count = context_merger.truncate_context(messages, 20)
        # Only meaningful when the 20-token budget actually forces truncation.
        if count > 0:
            assert "Third message" in [m["content"] for m in truncated]

    def test_truncate_with_default_budget(self, context_merger):
        """[AC-AISVC-14] Test truncation with default budget from config."""
        messages = [{"role": "user", "content": "Test"}]
        truncated, count = context_merger.truncate_context(messages)
        assert len(truncated) == 1
        assert count == 0
class TestMergeAndTruncate:
    """Tests for complete merge_and_truncate pipeline. [AC-AISVC-14, AC-AISVC-15]"""

    def test_merge_and_truncate_combined(self, context_merger):
        """[AC-AISVC-14, AC-AISVC-15] Test complete pipeline."""
        local = [
            ChatMessage(role=Role.USER, content="Hello"),
            ChatMessage(role=Role.ASSISTANT, content="Hi"),
        ]
        external = [
            ChatMessage(role=Role.USER, content="Hello"),  # duplicate of local
            ChatMessage(role=Role.USER, content="What's up?"),
        ]
        result = context_merger.merge_and_truncate(local, external, max_tokens=1000)
        assert isinstance(result, MergedContext)
        assert len(result.messages) == 3
        assert result.local_count == 2
        assert result.external_count == 1
        assert result.duplicates_skipped == 1

    def test_merge_and_truncate_with_truncation(self, context_merger):
        """[AC-AISVC-14, AC-AISVC-15] Test pipeline with truncation."""
        local = [
            ChatMessage(role=Role.USER, content="Hello " * 50),
            ChatMessage(role=Role.ASSISTANT, content="Hi " * 50),
        ]
        external = [
            ChatMessage(role=Role.USER, content="Short"),
        ]
        result = context_merger.merge_and_truncate(local, external, max_tokens=50)
        assert result.truncated_count > 0
        assert result.total_tokens <= 50
class TestContextMergerSingleton:
    """Tests for singleton pattern."""

    def test_get_context_merger_singleton(self, mock_settings):
        """Test that get_context_merger returns singleton.

        Fixes two issues in the original test: it imported the module-level
        ``_context_merger`` name without using it (a stale snapshot that can
        mislead readers), and it left the module global mutated, leaking state
        into other tests. The global is now restored in a finally block.
        """
        with patch("app.services.context.get_settings", return_value=mock_settings):
            import app.services.context as context_module

            original = context_module._context_merger
            context_module._context_merger = None
            try:
                merger1 = get_context_merger()
                merger2 = get_context_merger()
                assert merger1 is merger2
            finally:
                context_module._context_merger = original

View File

@@ -0,0 +1,291 @@
"""
Tests for SSE event generator.
[AC-AISVC-07] Tests for message event generation with delta content.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock
from sse_starlette.sse import ServerSentEvent
from app.core.sse import (
create_message_event,
create_final_event,
create_error_event,
SSEStateMachine,
SSEState,
)
from app.services.orchestrator import OrchestratorService
from app.models import ChatRequest, ChannelType
class TestSSEEventGenerator:
    """
    [AC-AISVC-07] Test cases for SSE event generation.

    Verifies the wire format of message/final/error events (event name plus
    JSON-encoded data payload).
    """

    def test_create_message_event_format(self):
        """
        [AC-AISVC-07] Test that message event has correct format.
        Event should have:
        - event: "message"
        - data: JSON with "delta" field
        """
        event = create_message_event(delta="Hello, ")
        assert event.event == "message"
        assert event.data is not None
        data = json.loads(event.data)
        assert "delta" in data
        assert data["delta"] == "Hello, "

    def test_create_message_event_with_unicode(self):
        """
        [AC-AISVC-07] Test that message event handles unicode correctly.
        """
        event = create_message_event(delta="你好,世界!")
        assert event.event == "message"
        data = json.loads(event.data)
        assert data["delta"] == "你好,世界!"

    def test_create_message_event_with_empty_delta(self):
        """
        [AC-AISVC-07] Test that message event handles empty delta.
        """
        event = create_message_event(delta="")
        assert event.event == "message"
        data = json.loads(event.data)
        assert data["delta"] == ""

    def test_create_final_event_format(self):
        """
        [AC-AISVC-08] Test that final event has correct format.
        """
        event = create_final_event(
            reply="Complete response",
            confidence=0.85,
            should_transfer=False,
        )
        assert event.event == "final"
        data = json.loads(event.data)
        assert data["reply"] == "Complete response"
        assert data["confidence"] == 0.85
        # camelCase keys: final events are serialized with by_alias=True.
        assert data["shouldTransfer"] is False

    def test_create_final_event_with_transfer_reason(self):
        """
        [AC-AISVC-08] Test final event with transfer reason.
        """
        event = create_final_event(
            reply="I cannot help with this",
            confidence=0.3,
            should_transfer=True,
            transfer_reason="Low confidence score",
        )
        assert event.event == "final"
        data = json.loads(event.data)
        assert data["shouldTransfer"] is True
        assert data["transferReason"] == "Low confidence score"

    def test_create_error_event_format(self):
        """
        [AC-AISVC-09] Test that error event has correct format.
        """
        event = create_error_event(
            code="GENERATION_ERROR",
            message="Failed to generate response",
        )
        assert event.event == "error"
        data = json.loads(event.data)
        assert data["code"] == "GENERATION_ERROR"
        assert data["message"] == "Failed to generate response"

    def test_create_error_event_with_details(self):
        """
        [AC-AISVC-09] Test error event with details.
        """
        event = create_error_event(
            code="VALIDATION_ERROR",
            message="Invalid input",
            details=[{"field": "message", "error": "too long"}],
        )
        assert event.event == "error"
        data = json.loads(event.data)
        assert data["details"] == [{"field": "message", "error": "too long"}]
class TestOrchestratorStreaming:
    """
    [AC-AISVC-07] Test cases for orchestrator streaming with SSE events.

    Exercises both the mock streaming path (no LLM client) and the
    LLM-client path, including error conversion to an error event.
    """

    @pytest.fixture
    def orchestrator(self):
        """Orchestrator with no LLM client (uses the mock streaming path)."""
        return OrchestratorService()

    @pytest.fixture
    def chat_request(self):
        """Minimal chat request used by the streaming tests."""
        return ChatRequest(
            session_id="test_session",
            current_message="Hello",
            channel_type=ChannelType.WECHAT,
        )

    @pytest.mark.asyncio
    async def test_stream_yields_message_events(self, orchestrator, chat_request):
        """
        [AC-AISVC-07] Test that streaming yields message events with delta content.
        """
        events = []
        async for event in orchestrator.generate_stream("tenant_001", chat_request):
            events.append(event)
        message_events = [e for e in events if e.event == "message"]
        final_events = [e for e in events if e.event == "final"]
        assert len(message_events) > 0, "Should have at least one message event"
        assert len(final_events) == 1, "Should have exactly one final event"
        for event in message_events:
            data = json.loads(event.data)
            assert "delta" in data
            assert isinstance(data["delta"], str)

    @pytest.mark.asyncio
    async def test_stream_message_events_contain_content(self, orchestrator, chat_request):
        """
        [AC-AISVC-07] Test that message events contain the expected content.
        """
        events = []
        async for event in orchestrator.generate_stream("tenant_001", chat_request):
            events.append(event)
        message_events = [e for e in events if e.event == "message"]
        # Reassemble the full reply from the incremental deltas.
        full_content = ""
        for event in message_events:
            data = json.loads(event.data)
            full_content += data["delta"]
        assert "Hello" in full_content, "Content should contain the user message"

    @pytest.mark.asyncio
    async def test_stream_event_sequence(self, orchestrator, chat_request):
        """
        [AC-AISVC-07, AC-AISVC-08] Test that events follow proper sequence.
        message* -> final -> close
        """
        events = []
        async for event in orchestrator.generate_stream("tenant_001", chat_request):
            events.append(event)
        event_types = [e.event for e in events]
        final_index = event_types.index("final")
        message_indices = [i for i, t in enumerate(event_types) if t == "message"]
        for msg_idx in message_indices:
            assert msg_idx < final_index, "All message events should come before final"

    @pytest.mark.asyncio
    async def test_stream_with_llm_client(self, chat_request):
        """
        [AC-AISVC-07] Test streaming with mock LLM client.
        """
        mock_llm = MagicMock()
        mock_chunk1 = MagicMock()
        mock_chunk1.delta = "Hello"
        mock_chunk1.finish_reason = None
        mock_chunk2 = MagicMock()
        mock_chunk2.delta = " there!"
        mock_chunk2.finish_reason = None
        # Terminal chunk: empty delta, finish_reason set.
        mock_chunk3 = MagicMock()
        mock_chunk3.delta = ""
        mock_chunk3.finish_reason = "stop"

        async def mock_stream(*args, **kwargs):
            for chunk in [mock_chunk1, mock_chunk2, mock_chunk3]:
                yield chunk

        mock_llm.stream_generate = mock_stream
        orchestrator = OrchestratorService(llm_client=mock_llm)
        events = []
        async for event in orchestrator.generate_stream("tenant_001", chat_request):
            events.append(event)
        message_events = [e for e in events if e.event == "message"]
        assert len(message_events) == 2, "Should have two message events"
        full_content = ""
        for event in message_events:
            data = json.loads(event.data)
            full_content += data["delta"]
        assert full_content == "Hello there!"

    @pytest.mark.asyncio
    async def test_stream_handles_error(self, orchestrator, chat_request):
        """
        [AC-AISVC-09] Test that streaming errors are converted to error events.

        Previously an empty placeholder (bare ``pass``) that always passed
        without exercising anything. Now drives generate_stream with an LLM
        client whose stream raises, and asserts exactly one error event (and
        no final event) is emitted.
        """
        mock_llm = MagicMock()

        async def failing_stream(*args, **kwargs):
            raise RuntimeError("LLM backend unavailable")
            yield  # unreachable; makes this an async generator

        mock_llm.stream_generate = failing_stream
        failing_orchestrator = OrchestratorService(llm_client=mock_llm)
        events = []
        async for event in failing_orchestrator.generate_stream("tenant_001", chat_request):
            events.append(event)
        error_events = [e for e in events if e.event == "error"]
        final_events = [e for e in events if e.event == "final"]
        assert len(error_events) == 1, "Should have exactly one error event"
        assert len(final_events) == 0, "No final event should follow an error"
        data = json.loads(error_events[0].data)
        assert data["code"] == "GENERATION_ERROR"
class TestSSEStateMachineIntegration:
    """
    [AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Integration tests for SSE state machine.

    Verifies the terminal-state guarantees: message events are allowed only
    while streaming, and never after a final or error event has been sent.
    """

    @pytest.mark.asyncio
    async def test_state_machine_prevents_events_after_final(self):
        """
        [AC-AISVC-08] Test that no events can be sent after final.
        """
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()
        assert state_machine.can_send_message() is True
        await state_machine.transition_to_final()
        assert state_machine.can_send_message() is False
        assert state_machine.state == SSEState.FINAL_SENT

    @pytest.mark.asyncio
    async def test_state_machine_prevents_events_after_error(self):
        """
        [AC-AISVC-09] Test that no events can be sent after error.
        """
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()
        await state_machine.transition_to_error()
        assert state_machine.can_send_message() is False
        assert state_machine.state == SSEState.ERROR_SENT

    @pytest.mark.asyncio
    async def test_state_machine_allows_multiple_message_events(self):
        """
        [AC-AISVC-07] Test that multiple message events can be sent during streaming.
        """
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()
        # can_send_message is a pure query; repeated calls must stay True while streaming.
        for _ in range(5):
            assert state_machine.can_send_message() is True
        await state_machine.transition_to_final()
        assert state_machine.can_send_message() is False