feat(ai-service): implement context merging for T3.2 [AC-AISVC-14, AC-AISVC-15]
- Add ContextMerger class for combining local and external history - Implement message fingerprint computation (SHA256 hash) - Implement deduplication: local history takes priority - Implement token-based truncation using tiktoken - Add comprehensive unit tests (20 test cases)
This commit is contained in:
parent
4cee28e9f4
commit
550d0d8498
|
|
@ -101,7 +101,10 @@ def create_final_event(
|
||||||
transfer_reason=transfer_reason,
|
transfer_reason=transfer_reason,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
return format_sse_event(SSEEventType.FINAL, event_data.model_dump(exclude_none=True))
|
return format_sse_event(
|
||||||
|
SSEEventType.FINAL,
|
||||||
|
event_data.model_dump(exclude_none=True, by_alias=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_error_event(
|
def create_error_event(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,245 @@
|
||||||
|
"""
|
||||||
|
Context management utilities for AI Service.
|
||||||
|
[AC-AISVC-14, AC-AISVC-15] Context merging and truncation strategies.
|
||||||
|
|
||||||
|
Design reference: design.md Section 7 - 上下文合并规则
|
||||||
|
- H_local: Memory layer history (sorted by time)
|
||||||
|
- H_ext: External history from Java request (in passed order)
|
||||||
|
- Deduplication: fingerprint = hash(role + "|" + normalized(content))
|
||||||
|
- Truncation: Keep most recent N messages within token budget
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
|
from app.models import ChatMessage, Role
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MergedContext:
|
||||||
|
"""
|
||||||
|
Result of context merging.
|
||||||
|
[AC-AISVC-14, AC-AISVC-15] Contains merged messages and diagnostics.
|
||||||
|
"""
|
||||||
|
messages: list[dict[str, str]] = field(default_factory=list)
|
||||||
|
total_tokens: int = 0
|
||||||
|
local_count: int = 0
|
||||||
|
external_count: int = 0
|
||||||
|
duplicates_skipped: int = 0
|
||||||
|
truncated_count: int = 0
|
||||||
|
diagnostics: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextMerger:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-14, AC-AISVC-15] Context merger for combining local and external history.
|
||||||
|
|
||||||
|
Design reference: design.md Section 7
|
||||||
|
- Deduplication based on message fingerprint
|
||||||
|
- Priority: local history takes precedence
|
||||||
|
- Token-based truncation using tiktoken
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_history_tokens: int | None = None,
|
||||||
|
encoding_name: str = "cl100k_base",
|
||||||
|
):
|
||||||
|
settings = get_settings()
|
||||||
|
self._max_history_tokens = max_history_tokens or 4096
|
||||||
|
self._encoding = tiktoken.get_encoding(encoding_name)
|
||||||
|
|
||||||
|
def compute_fingerprint(self, role: str, content: str) -> str:
|
||||||
|
"""
|
||||||
|
Compute message fingerprint for deduplication.
|
||||||
|
[AC-AISVC-15] fingerprint = hash(role + "|" + normalized(content))
|
||||||
|
|
||||||
|
Args:
|
||||||
|
role: Message role (user/assistant)
|
||||||
|
content: Message content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SHA256 hash of the normalized message
|
||||||
|
"""
|
||||||
|
normalized_content = content.strip()
|
||||||
|
fingerprint_input = f"{role}|{normalized_content}"
|
||||||
|
return hashlib.sha256(fingerprint_input.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
def _message_to_dict(self, message: ChatMessage | dict[str, str]) -> dict[str, str]:
|
||||||
|
"""Convert ChatMessage or dict to standard dict format."""
|
||||||
|
if isinstance(message, ChatMessage):
|
||||||
|
return {"role": message.role.value, "content": message.content}
|
||||||
|
return message
|
||||||
|
|
||||||
|
def _count_tokens(self, messages: list[dict[str, str]]) -> int:
|
||||||
|
"""
|
||||||
|
Count total tokens in messages using tiktoken.
|
||||||
|
[AC-AISVC-14] Token counting for history truncation.
|
||||||
|
"""
|
||||||
|
total = 0
|
||||||
|
for msg in messages:
|
||||||
|
total += len(self._encoding.encode(msg.get("role", "")))
|
||||||
|
total += len(self._encoding.encode(msg.get("content", "")))
|
||||||
|
total += 4 # Approximate overhead for message structure
|
||||||
|
return total
|
||||||
|
|
||||||
|
def merge_context(
|
||||||
|
self,
|
||||||
|
local_history: list[ChatMessage] | list[dict[str, str]] | None,
|
||||||
|
external_history: list[ChatMessage] | list[dict[str, str]] | None,
|
||||||
|
) -> MergedContext:
|
||||||
|
"""
|
||||||
|
Merge local and external history with deduplication.
|
||||||
|
[AC-AISVC-14, AC-AISVC-15] Implements context merging strategy.
|
||||||
|
|
||||||
|
Design reference: design.md Section 7.2
|
||||||
|
1. Build seen set from H_local
|
||||||
|
2. Traverse H_ext, append if fingerprint not seen
|
||||||
|
3. Local history takes priority
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_history: History from Memory layer (H_local)
|
||||||
|
external_history: History from Java request (H_ext)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MergedContext with merged messages and diagnostics
|
||||||
|
"""
|
||||||
|
result = MergedContext()
|
||||||
|
seen_fingerprints: set[str] = set()
|
||||||
|
merged_messages: list[dict[str, str]] = []
|
||||||
|
diagnostics: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
local_messages = [self._message_to_dict(m) for m in (local_history or [])]
|
||||||
|
external_messages = [self._message_to_dict(m) for m in (external_history or [])]
|
||||||
|
|
||||||
|
for msg in local_messages:
|
||||||
|
fingerprint = self.compute_fingerprint(msg["role"], msg["content"])
|
||||||
|
seen_fingerprints.add(fingerprint)
|
||||||
|
merged_messages.append(msg)
|
||||||
|
result.local_count += 1
|
||||||
|
|
||||||
|
for msg in external_messages:
|
||||||
|
fingerprint = self.compute_fingerprint(msg["role"], msg["content"])
|
||||||
|
if fingerprint not in seen_fingerprints:
|
||||||
|
seen_fingerprints.add(fingerprint)
|
||||||
|
merged_messages.append(msg)
|
||||||
|
result.external_count += 1
|
||||||
|
else:
|
||||||
|
result.duplicates_skipped += 1
|
||||||
|
diagnostics.append({
|
||||||
|
"type": "duplicate_skipped",
|
||||||
|
"role": msg["role"],
|
||||||
|
"content_preview": msg["content"][:50] + "..." if len(msg["content"]) > 50 else msg["content"],
|
||||||
|
})
|
||||||
|
|
||||||
|
result.messages = merged_messages
|
||||||
|
result.diagnostics = diagnostics
|
||||||
|
result.total_tokens = self._count_tokens(merged_messages)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-14, AC-AISVC-15] Context merged: "
|
||||||
|
f"local={result.local_count}, external={result.external_count}, "
|
||||||
|
f"duplicates_skipped={result.duplicates_skipped}, "
|
||||||
|
f"total_tokens={result.total_tokens}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def truncate_context(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
) -> tuple[list[dict[str, str]], int]:
|
||||||
|
"""
|
||||||
|
Truncate context to fit within token budget.
|
||||||
|
[AC-AISVC-14] Keep most recent N messages within budget.
|
||||||
|
|
||||||
|
Design reference: design.md Section 7.4
|
||||||
|
- Budget = maxHistoryTokens (configurable)
|
||||||
|
- Strategy: Keep most recent messages (from tail backward)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages to truncate
|
||||||
|
max_tokens: Maximum token budget (uses default if not provided)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (truncated messages, truncated count)
|
||||||
|
"""
|
||||||
|
budget = max_tokens or self._max_history_tokens
|
||||||
|
if not messages:
|
||||||
|
return [], 0
|
||||||
|
|
||||||
|
total_tokens = self._count_tokens(messages)
|
||||||
|
if total_tokens <= budget:
|
||||||
|
return messages, 0
|
||||||
|
|
||||||
|
truncated_messages: list[dict[str, str]] = []
|
||||||
|
current_tokens = 0
|
||||||
|
truncated_count = 0
|
||||||
|
|
||||||
|
for msg in reversed(messages):
|
||||||
|
msg_tokens = len(self._encoding.encode(msg.get("role", "")))
|
||||||
|
msg_tokens += len(self._encoding.encode(msg.get("content", "")))
|
||||||
|
msg_tokens += 4
|
||||||
|
|
||||||
|
if current_tokens + msg_tokens <= budget:
|
||||||
|
truncated_messages.insert(0, msg)
|
||||||
|
current_tokens += msg_tokens
|
||||||
|
else:
|
||||||
|
truncated_count += 1
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-AISVC-14] Context truncated: "
|
||||||
|
f"original={len(messages)}, truncated={len(truncated_messages)}, "
|
||||||
|
f"removed={truncated_count}, tokens={current_tokens}/{budget}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return truncated_messages, truncated_count
|
||||||
|
|
||||||
|
def merge_and_truncate(
|
||||||
|
self,
|
||||||
|
local_history: list[ChatMessage] | list[dict[str, str]] | None,
|
||||||
|
external_history: list[ChatMessage] | list[dict[str, str]] | None,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
) -> MergedContext:
|
||||||
|
"""
|
||||||
|
Merge and truncate context in one operation.
|
||||||
|
[AC-AISVC-14, AC-AISVC-15] Complete context preparation pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_history: History from Memory layer (H_local)
|
||||||
|
external_history: History from Java request (H_ext)
|
||||||
|
max_tokens: Maximum token budget
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MergedContext with final messages after merge and truncate
|
||||||
|
"""
|
||||||
|
merged = self.merge_context(local_history, external_history)
|
||||||
|
|
||||||
|
truncated_messages, truncated_count = self.truncate_context(
|
||||||
|
merged.messages, max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
merged.messages = truncated_messages
|
||||||
|
merged.truncated_count = truncated_count
|
||||||
|
merged.total_tokens = self._count_tokens(truncated_messages)
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
_context_merger: ContextMerger | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_context_merger() -> ContextMerger:
|
||||||
|
"""Get or create context merger instance."""
|
||||||
|
global _context_merger
|
||||||
|
if _context_merger is None:
|
||||||
|
_context_merger = ContextMerger()
|
||||||
|
return _context_merger
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Orchestrator service for AI Service.
|
Orchestrator service for AI Service.
|
||||||
[AC-AISVC-01, AC-AISVC-02] Core orchestration logic for chat generation.
|
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Core orchestration logic for chat generation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -9,17 +9,31 @@ from typing import AsyncGenerator
|
||||||
from sse_starlette.sse import ServerSentEvent
|
from sse_starlette.sse import ServerSentEvent
|
||||||
|
|
||||||
from app.models import ChatRequest, ChatResponse
|
from app.models import ChatRequest, ChatResponse
|
||||||
from app.core.sse import create_final_event, create_message_event, SSEStateMachine
|
from app.core.sse import create_error_event, create_final_event, create_message_event, SSEStateMachine
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OrchestratorService:
|
class OrchestratorService:
|
||||||
"""
|
"""
|
||||||
[AC-AISVC-01, AC-AISVC-02] Orchestrator for chat generation.
|
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Orchestrator for chat generation.
|
||||||
Coordinates memory, retrieval, and LLM components.
|
Coordinates memory, retrieval, and LLM components.
|
||||||
|
|
||||||
|
SSE Event Flow (per design.md Section 6.2):
|
||||||
|
- message* (0 or more) -> final (exactly 1) -> close
|
||||||
|
- OR message* (0 or more) -> error (exactly 1) -> close
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, llm_client=None):
|
||||||
|
"""
|
||||||
|
Initialize orchestrator with optional LLM client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm_client: Optional LLM client for dependency injection.
|
||||||
|
If None, will use mock implementation for demo.
|
||||||
|
"""
|
||||||
|
self._llm_client = llm_client
|
||||||
|
|
||||||
async def generate(self, tenant_id: str, request: ChatRequest) -> ChatResponse:
|
async def generate(self, tenant_id: str, request: ChatRequest) -> ChatResponse:
|
||||||
"""
|
"""
|
||||||
Generate a non-streaming response.
|
Generate a non-streaming response.
|
||||||
|
|
@ -30,6 +44,15 @@ class OrchestratorService:
|
||||||
f"session={request.session_id}"
|
f"session={request.session_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._llm_client:
|
||||||
|
messages = self._build_messages(request)
|
||||||
|
response = await self._llm_client.generate(messages)
|
||||||
|
return ChatResponse(
|
||||||
|
reply=response.content,
|
||||||
|
confidence=0.85,
|
||||||
|
should_transfer=False,
|
||||||
|
)
|
||||||
|
|
||||||
reply = f"Received your message: {request.current_message}"
|
reply = f"Received your message: {request.current_message}"
|
||||||
return ChatResponse(
|
return ChatResponse(
|
||||||
reply=reply,
|
reply=reply,
|
||||||
|
|
@ -43,6 +66,16 @@ class OrchestratorService:
|
||||||
"""
|
"""
|
||||||
Generate a streaming response.
|
Generate a streaming response.
|
||||||
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence.
|
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence.
|
||||||
|
|
||||||
|
SSE Event Sequence (per design.md Section 6.2):
|
||||||
|
1. message events (multiple) - each with incremental delta
|
||||||
|
2. final event (exactly 1) - with complete response
|
||||||
|
3. connection close
|
||||||
|
|
||||||
|
OR on error:
|
||||||
|
1. message events (0 or more)
|
||||||
|
2. error event (exactly 1)
|
||||||
|
3. connection close
|
||||||
"""
|
"""
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AC-AISVC-06] Starting streaming generation for tenant={tenant_id}, "
|
f"[AC-AISVC-06] Starting streaming generation for tenant={tenant_id}, "
|
||||||
|
|
@ -53,14 +86,18 @@ class OrchestratorService:
|
||||||
await state_machine.transition_to_streaming()
|
await state_machine.transition_to_streaming()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
reply_parts = ["Received", " your", " message:", f" {request.current_message}"]
|
|
||||||
full_reply = ""
|
full_reply = ""
|
||||||
|
|
||||||
for part in reply_parts:
|
if self._llm_client:
|
||||||
if state_machine.can_send_message():
|
async for event in self._stream_from_llm(request, state_machine):
|
||||||
full_reply += part
|
if event.event == "message":
|
||||||
yield create_message_event(delta=part)
|
full_reply += self._extract_delta_from_event(event)
|
||||||
await self._simulate_llm_delay()
|
yield event
|
||||||
|
else:
|
||||||
|
async for event in self._stream_mock_response(request, state_machine):
|
||||||
|
if event.event == "message":
|
||||||
|
full_reply += self._extract_delta_from_event(event)
|
||||||
|
yield event
|
||||||
|
|
||||||
if await state_machine.transition_to_final():
|
if await state_machine.transition_to_final():
|
||||||
yield create_final_event(
|
yield create_final_event(
|
||||||
|
|
@ -72,7 +109,6 @@ class OrchestratorService:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[AC-AISVC-09] Error during streaming: {e}")
|
logger.error(f"[AC-AISVC-09] Error during streaming: {e}")
|
||||||
if await state_machine.transition_to_error():
|
if await state_machine.transition_to_error():
|
||||||
from app.core.sse import create_error_event
|
|
||||||
yield create_error_event(
|
yield create_error_event(
|
||||||
code="GENERATION_ERROR",
|
code="GENERATION_ERROR",
|
||||||
message=str(e),
|
message=str(e),
|
||||||
|
|
@ -80,10 +116,73 @@ class OrchestratorService:
|
||||||
finally:
|
finally:
|
||||||
await state_machine.close()
|
await state_machine.close()
|
||||||
|
|
||||||
async def _simulate_llm_delay(self) -> None:
|
async def _stream_from_llm(
|
||||||
"""Simulate LLM processing delay for demo purposes."""
|
self, request: ChatRequest, state_machine: SSEStateMachine
|
||||||
|
) -> AsyncGenerator[ServerSentEvent, None]:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-07] Stream from LLM client, wrapping each chunk as message event.
|
||||||
|
"""
|
||||||
|
messages = self._build_messages(request)
|
||||||
|
|
||||||
|
async for chunk in self._llm_client.stream_generate(messages):
|
||||||
|
if not state_machine.can_send_message():
|
||||||
|
break
|
||||||
|
|
||||||
|
if chunk.delta:
|
||||||
|
logger.debug(f"[AC-AISVC-07] Yielding message event with delta: {chunk.delta[:50]}...")
|
||||||
|
yield create_message_event(delta=chunk.delta)
|
||||||
|
|
||||||
|
if chunk.finish_reason:
|
||||||
|
logger.info(f"[AC-AISVC-07] LLM stream finished with reason: {chunk.finish_reason}")
|
||||||
|
break
|
||||||
|
|
||||||
|
async def _stream_mock_response(
|
||||||
|
self, request: ChatRequest, state_machine: SSEStateMachine
|
||||||
|
) -> AsyncGenerator[ServerSentEvent, None]:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-07] Mock streaming response for demo/testing purposes.
|
||||||
|
Simulates LLM-style incremental output.
|
||||||
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
reply_parts = ["Received", " your", " message:", f" {request.current_message}"]
|
||||||
|
|
||||||
|
for part in reply_parts:
|
||||||
|
if not state_machine.can_send_message():
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.debug(f"[AC-AISVC-07] Yielding mock message event with delta: {part}")
|
||||||
|
yield create_message_event(delta=part)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
def _build_messages(self, request: ChatRequest) -> list[dict[str, str]]:
|
||||||
|
"""Build messages list for LLM from request."""
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
if request.history:
|
||||||
|
for msg in request.history:
|
||||||
|
messages.append({
|
||||||
|
"role": msg.role.value,
|
||||||
|
"content": msg.content,
|
||||||
|
})
|
||||||
|
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": request.current_message,
|
||||||
|
})
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def _extract_delta_from_event(self, event: ServerSentEvent) -> str:
|
||||||
|
"""Extract delta content from a message event."""
|
||||||
|
import json
|
||||||
|
try:
|
||||||
|
if event.data:
|
||||||
|
data = json.loads(event.data)
|
||||||
|
return data.get("delta", "")
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
_orchestrator_service: OrchestratorService | None = None
|
_orchestrator_service: OrchestratorService | None = None
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,287 @@
|
||||||
|
"""
|
||||||
|
Unit tests for Context Merger.
|
||||||
|
[AC-AISVC-14, AC-AISVC-15] Tests for context merging and truncation.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- Message fingerprint computation
|
||||||
|
- Context merging with deduplication
|
||||||
|
- Token-based truncation
|
||||||
|
- Complete merge_and_truncate pipeline
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models import ChatMessage, Role
|
||||||
|
from app.services.context import ContextMerger, MergedContext, get_context_merger
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_settings():
|
||||||
|
"""Mock settings for testing."""
|
||||||
|
settings = MagicMock()
|
||||||
|
return settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def context_merger(mock_settings):
|
||||||
|
"""Create context merger with mocked settings."""
|
||||||
|
with patch("app.services.context.get_settings", return_value=mock_settings):
|
||||||
|
merger = ContextMerger(max_history_tokens=1000)
|
||||||
|
yield merger
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def local_history():
|
||||||
|
"""Sample local history messages."""
|
||||||
|
return [
|
||||||
|
ChatMessage(role=Role.USER, content="Hello"),
|
||||||
|
ChatMessage(role=Role.ASSISTANT, content="Hi there!"),
|
||||||
|
ChatMessage(role=Role.USER, content="How are you?"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def external_history():
|
||||||
|
"""Sample external history messages."""
|
||||||
|
return [
|
||||||
|
ChatMessage(role=Role.USER, content="Hello"),
|
||||||
|
ChatMessage(role=Role.ASSISTANT, content="Hi there!"),
|
||||||
|
ChatMessage(role=Role.USER, content="What's the weather?"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dict_local_history():
|
||||||
|
"""Sample local history as dicts."""
|
||||||
|
return [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hi there!"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dict_external_history():
|
||||||
|
"""Sample external history as dicts."""
|
||||||
|
return [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "user", "content": "What's the weather?"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestFingerprintComputation:
|
||||||
|
"""Tests for message fingerprint computation. [AC-AISVC-15]"""
|
||||||
|
|
||||||
|
def test_fingerprint_consistency(self, context_merger):
|
||||||
|
"""Test that same input produces same fingerprint."""
|
||||||
|
fp1 = context_merger.compute_fingerprint("user", "Hello world")
|
||||||
|
fp2 = context_merger.compute_fingerprint("user", "Hello world")
|
||||||
|
assert fp1 == fp2
|
||||||
|
|
||||||
|
def test_fingerprint_role_difference(self, context_merger):
|
||||||
|
"""Test that different roles produce different fingerprints."""
|
||||||
|
fp_user = context_merger.compute_fingerprint("user", "Hello")
|
||||||
|
fp_assistant = context_merger.compute_fingerprint("assistant", "Hello")
|
||||||
|
assert fp_user != fp_assistant
|
||||||
|
|
||||||
|
def test_fingerprint_content_difference(self, context_merger):
|
||||||
|
"""Test that different content produces different fingerprints."""
|
||||||
|
fp1 = context_merger.compute_fingerprint("user", "Hello")
|
||||||
|
fp2 = context_merger.compute_fingerprint("user", "World")
|
||||||
|
assert fp1 != fp2
|
||||||
|
|
||||||
|
def test_fingerprint_normalization(self, context_merger):
|
||||||
|
"""Test that content is normalized (trimmed)."""
|
||||||
|
fp1 = context_merger.compute_fingerprint("user", "Hello")
|
||||||
|
fp2 = context_merger.compute_fingerprint("user", " Hello ")
|
||||||
|
assert fp1 == fp2
|
||||||
|
|
||||||
|
def test_fingerprint_is_sha256(self, context_merger):
|
||||||
|
"""Test that fingerprint is SHA256 hash."""
|
||||||
|
fp = context_merger.compute_fingerprint("user", "Hello")
|
||||||
|
expected = hashlib.sha256("user|Hello".encode("utf-8")).hexdigest()
|
||||||
|
assert fp == expected
|
||||||
|
assert len(fp) == 64 # SHA256 produces 64 hex characters
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextMerging:
|
||||||
|
"""Tests for context merging with deduplication. [AC-AISVC-14, AC-AISVC-15]"""
|
||||||
|
|
||||||
|
def test_merge_empty_histories(self, context_merger):
|
||||||
|
"""[AC-AISVC-14] Test merging empty histories."""
|
||||||
|
result = context_merger.merge_context(None, None)
|
||||||
|
|
||||||
|
assert isinstance(result, MergedContext)
|
||||||
|
assert result.messages == []
|
||||||
|
assert result.local_count == 0
|
||||||
|
assert result.external_count == 0
|
||||||
|
assert result.duplicates_skipped == 0
|
||||||
|
|
||||||
|
def test_merge_local_only(self, context_merger, local_history):
|
||||||
|
"""[AC-AISVC-14] Test merging with only local history (no external)."""
|
||||||
|
result = context_merger.merge_context(local_history, None)
|
||||||
|
|
||||||
|
assert len(result.messages) == 3
|
||||||
|
assert result.local_count == 3
|
||||||
|
assert result.external_count == 0
|
||||||
|
assert result.duplicates_skipped == 0
|
||||||
|
|
||||||
|
def test_merge_external_only(self, context_merger, external_history):
|
||||||
|
"""[AC-AISVC-15] Test merging with only external history (no local)."""
|
||||||
|
result = context_merger.merge_context(None, external_history)
|
||||||
|
|
||||||
|
assert len(result.messages) == 3
|
||||||
|
assert result.local_count == 0
|
||||||
|
assert result.external_count == 3
|
||||||
|
assert result.duplicates_skipped == 0
|
||||||
|
|
||||||
|
def test_merge_with_duplicates(self, context_merger, local_history, external_history):
|
||||||
|
"""[AC-AISVC-15] Test deduplication when merging overlapping histories."""
|
||||||
|
result = context_merger.merge_context(local_history, external_history)
|
||||||
|
|
||||||
|
assert len(result.messages) == 4
|
||||||
|
assert result.local_count == 3
|
||||||
|
assert result.external_count == 1
|
||||||
|
assert result.duplicates_skipped == 2
|
||||||
|
|
||||||
|
roles = [m["role"] for m in result.messages]
|
||||||
|
contents = [m["content"] for m in result.messages]
|
||||||
|
assert "What's the weather?" in contents
|
||||||
|
|
||||||
|
def test_merge_with_dict_histories(self, context_merger, dict_local_history, dict_external_history):
|
||||||
|
"""[AC-AISVC-14, AC-AISVC-15] Test merging with dict format histories."""
|
||||||
|
result = context_merger.merge_context(dict_local_history, dict_external_history)
|
||||||
|
|
||||||
|
assert len(result.messages) == 3
|
||||||
|
assert result.local_count == 2
|
||||||
|
assert result.external_count == 1
|
||||||
|
assert result.duplicates_skipped == 1
|
||||||
|
|
||||||
|
def test_merge_priority_local(self, context_merger):
|
||||||
|
"""[AC-AISVC-15] Test that local history takes priority."""
|
||||||
|
local = [ChatMessage(role=Role.USER, content="Hello")]
|
||||||
|
external = [ChatMessage(role=Role.USER, content="Hello")]
|
||||||
|
|
||||||
|
result = context_merger.merge_context(local, external)
|
||||||
|
|
||||||
|
assert len(result.messages) == 1
|
||||||
|
assert result.duplicates_skipped == 1
|
||||||
|
|
||||||
|
def test_merge_records_diagnostics(self, context_merger, local_history, external_history):
|
||||||
|
"""[AC-AISVC-15] Test that duplicates are recorded in diagnostics."""
|
||||||
|
result = context_merger.merge_context(local_history, external_history)
|
||||||
|
|
||||||
|
assert len(result.diagnostics) == 2
|
||||||
|
for diag in result.diagnostics:
|
||||||
|
assert diag["type"] == "duplicate_skipped"
|
||||||
|
assert "role" in diag
|
||||||
|
assert "content_preview" in diag
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenTruncation:
|
||||||
|
"""Tests for token-based truncation. [AC-AISVC-14]"""
|
||||||
|
|
||||||
|
def test_truncate_empty_messages(self, context_merger):
|
||||||
|
"""[AC-AISVC-14] Test truncating empty message list."""
|
||||||
|
truncated, count = context_merger.truncate_context([], 100)
|
||||||
|
assert truncated == []
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
def test_truncate_within_budget(self, context_merger):
|
||||||
|
"""[AC-AISVC-14] Test that messages within budget are not truncated."""
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hi"},
|
||||||
|
]
|
||||||
|
truncated, count = context_merger.truncate_context(messages, 1000)
|
||||||
|
|
||||||
|
assert len(truncated) == 2
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
def test_truncate_exceeds_budget(self, context_merger):
|
||||||
|
"""[AC-AISVC-14] Test that messages exceeding budget are truncated."""
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "Hello world " * 100},
|
||||||
|
{"role": "assistant", "content": "Hi there " * 100},
|
||||||
|
{"role": "user", "content": "Short message"},
|
||||||
|
]
|
||||||
|
truncated, count = context_merger.truncate_context(messages, 50)
|
||||||
|
|
||||||
|
assert len(truncated) < len(messages)
|
||||||
|
assert count > 0
|
||||||
|
|
||||||
|
def test_truncate_keeps_recent_messages(self, context_merger):
|
||||||
|
"""[AC-AISVC-14] Test that truncation keeps most recent messages."""
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "First message"},
|
||||||
|
{"role": "assistant", "content": "Second message"},
|
||||||
|
{"role": "user", "content": "Third message"},
|
||||||
|
]
|
||||||
|
truncated, count = context_merger.truncate_context(messages, 20)
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
assert "Third message" in [m["content"] for m in truncated]
|
||||||
|
|
||||||
|
def test_truncate_with_default_budget(self, context_merger):
|
||||||
|
"""[AC-AISVC-14] Test truncation with default budget from config."""
|
||||||
|
messages = [{"role": "user", "content": "Test"}]
|
||||||
|
truncated, count = context_merger.truncate_context(messages)
|
||||||
|
|
||||||
|
assert len(truncated) == 1
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestMergeAndTruncate:
|
||||||
|
"""Tests for complete merge_and_truncate pipeline. [AC-AISVC-14, AC-AISVC-15]"""
|
||||||
|
|
||||||
|
def test_merge_and_truncate_combined(self, context_merger):
|
||||||
|
"""[AC-AISVC-14, AC-AISVC-15] Test complete pipeline."""
|
||||||
|
local = [
|
||||||
|
ChatMessage(role=Role.USER, content="Hello"),
|
||||||
|
ChatMessage(role=Role.ASSISTANT, content="Hi"),
|
||||||
|
]
|
||||||
|
external = [
|
||||||
|
ChatMessage(role=Role.USER, content="Hello"),
|
||||||
|
ChatMessage(role=Role.USER, content="What's up?"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = context_merger.merge_and_truncate(local, external, max_tokens=1000)
|
||||||
|
|
||||||
|
assert isinstance(result, MergedContext)
|
||||||
|
assert len(result.messages) == 3
|
||||||
|
assert result.local_count == 2
|
||||||
|
assert result.external_count == 1
|
||||||
|
assert result.duplicates_skipped == 1
|
||||||
|
|
||||||
|
def test_merge_and_truncate_with_truncation(self, context_merger):
|
||||||
|
"""[AC-AISVC-14, AC-AISVC-15] Test pipeline with truncation."""
|
||||||
|
local = [
|
||||||
|
ChatMessage(role=Role.USER, content="Hello " * 50),
|
||||||
|
ChatMessage(role=Role.ASSISTANT, content="Hi " * 50),
|
||||||
|
]
|
||||||
|
external = [
|
||||||
|
ChatMessage(role=Role.USER, content="Short"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = context_merger.merge_and_truncate(local, external, max_tokens=50)
|
||||||
|
|
||||||
|
assert result.truncated_count > 0
|
||||||
|
assert result.total_tokens <= 50
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextMergerSingleton:
|
||||||
|
"""Tests for singleton pattern."""
|
||||||
|
|
||||||
|
def test_get_context_merger_singleton(self, mock_settings):
|
||||||
|
"""Test that get_context_merger returns singleton."""
|
||||||
|
with patch("app.services.context.get_settings", return_value=mock_settings):
|
||||||
|
from app.services.context import _context_merger
|
||||||
|
import app.services.context as context_module
|
||||||
|
context_module._context_merger = None
|
||||||
|
|
||||||
|
merger1 = get_context_merger()
|
||||||
|
merger2 = get_context_merger()
|
||||||
|
|
||||||
|
assert merger1 is merger2
|
||||||
|
|
@ -0,0 +1,291 @@
|
||||||
|
"""
|
||||||
|
Tests for SSE event generator.
|
||||||
|
[AC-AISVC-07] Tests for message event generation with delta content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from sse_starlette.sse import ServerSentEvent
|
||||||
|
|
||||||
|
from app.core.sse import (
|
||||||
|
create_message_event,
|
||||||
|
create_final_event,
|
||||||
|
create_error_event,
|
||||||
|
SSEStateMachine,
|
||||||
|
SSEState,
|
||||||
|
)
|
||||||
|
from app.services.orchestrator import OrchestratorService
|
||||||
|
from app.models import ChatRequest, ChannelType
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEEventGenerator:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-07] Test cases for SSE event generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_create_message_event_format(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-07] Test that message event has correct format.
|
||||||
|
Event should have:
|
||||||
|
- event: "message"
|
||||||
|
- data: JSON with "delta" field
|
||||||
|
"""
|
||||||
|
event = create_message_event(delta="Hello, ")
|
||||||
|
|
||||||
|
assert event.event == "message"
|
||||||
|
assert event.data is not None
|
||||||
|
|
||||||
|
data = json.loads(event.data)
|
||||||
|
assert "delta" in data
|
||||||
|
assert data["delta"] == "Hello, "
|
||||||
|
|
||||||
|
def test_create_message_event_with_unicode(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-07] Test that message event handles unicode correctly.
|
||||||
|
"""
|
||||||
|
event = create_message_event(delta="你好,世界!")
|
||||||
|
|
||||||
|
assert event.event == "message"
|
||||||
|
data = json.loads(event.data)
|
||||||
|
assert data["delta"] == "你好,世界!"
|
||||||
|
|
||||||
|
def test_create_message_event_with_empty_delta(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-07] Test that message event handles empty delta.
|
||||||
|
"""
|
||||||
|
event = create_message_event(delta="")
|
||||||
|
|
||||||
|
assert event.event == "message"
|
||||||
|
data = json.loads(event.data)
|
||||||
|
assert data["delta"] == ""
|
||||||
|
|
||||||
|
def test_create_final_event_format(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test that final event has correct format.
|
||||||
|
"""
|
||||||
|
event = create_final_event(
|
||||||
|
reply="Complete response",
|
||||||
|
confidence=0.85,
|
||||||
|
should_transfer=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event == "final"
|
||||||
|
data = json.loads(event.data)
|
||||||
|
assert data["reply"] == "Complete response"
|
||||||
|
assert data["confidence"] == 0.85
|
||||||
|
assert data["shouldTransfer"] is False
|
||||||
|
|
||||||
|
def test_create_final_event_with_transfer_reason(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test final event with transfer reason.
|
||||||
|
"""
|
||||||
|
event = create_final_event(
|
||||||
|
reply="I cannot help with this",
|
||||||
|
confidence=0.3,
|
||||||
|
should_transfer=True,
|
||||||
|
transfer_reason="Low confidence score",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event == "final"
|
||||||
|
data = json.loads(event.data)
|
||||||
|
assert data["shouldTransfer"] is True
|
||||||
|
assert data["transferReason"] == "Low confidence score"
|
||||||
|
|
||||||
|
def test_create_error_event_format(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-09] Test that error event has correct format.
|
||||||
|
"""
|
||||||
|
event = create_error_event(
|
||||||
|
code="GENERATION_ERROR",
|
||||||
|
message="Failed to generate response",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event == "error"
|
||||||
|
data = json.loads(event.data)
|
||||||
|
assert data["code"] == "GENERATION_ERROR"
|
||||||
|
assert data["message"] == "Failed to generate response"
|
||||||
|
|
||||||
|
def test_create_error_event_with_details(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-09] Test error event with details.
|
||||||
|
"""
|
||||||
|
event = create_error_event(
|
||||||
|
code="VALIDATION_ERROR",
|
||||||
|
message="Invalid input",
|
||||||
|
details=[{"field": "message", "error": "too long"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event == "error"
|
||||||
|
data = json.loads(event.data)
|
||||||
|
assert data["details"] == [{"field": "message", "error": "too long"}]
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrchestratorStreaming:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-07] Test cases for orchestrator streaming with SSE events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def orchestrator(self):
|
||||||
|
return OrchestratorService()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def chat_request(self):
|
||||||
|
return ChatRequest(
|
||||||
|
session_id="test_session",
|
||||||
|
current_message="Hello",
|
||||||
|
channel_type=ChannelType.WECHAT,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_yields_message_events(self, orchestrator, chat_request):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-07] Test that streaming yields message events with delta content.
|
||||||
|
"""
|
||||||
|
events = []
|
||||||
|
async for event in orchestrator.generate_stream("tenant_001", chat_request):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
message_events = [e for e in events if e.event == "message"]
|
||||||
|
final_events = [e for e in events if e.event == "final"]
|
||||||
|
|
||||||
|
assert len(message_events) > 0, "Should have at least one message event"
|
||||||
|
assert len(final_events) == 1, "Should have exactly one final event"
|
||||||
|
|
||||||
|
for event in message_events:
|
||||||
|
data = json.loads(event.data)
|
||||||
|
assert "delta" in data
|
||||||
|
assert isinstance(data["delta"], str)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_message_events_contain_content(self, orchestrator, chat_request):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-07] Test that message events contain the expected content.
|
||||||
|
"""
|
||||||
|
events = []
|
||||||
|
async for event in orchestrator.generate_stream("tenant_001", chat_request):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
message_events = [e for e in events if e.event == "message"]
|
||||||
|
|
||||||
|
full_content = ""
|
||||||
|
for event in message_events:
|
||||||
|
data = json.loads(event.data)
|
||||||
|
full_content += data["delta"]
|
||||||
|
|
||||||
|
assert "Hello" in full_content, "Content should contain the user message"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_event_sequence(self, orchestrator, chat_request):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-07, AC-AISVC-08] Test that events follow proper sequence.
|
||||||
|
message* -> final -> close
|
||||||
|
"""
|
||||||
|
events = []
|
||||||
|
async for event in orchestrator.generate_stream("tenant_001", chat_request):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
event_types = [e.event for e in events]
|
||||||
|
|
||||||
|
final_index = event_types.index("final")
|
||||||
|
message_indices = [i for i, t in enumerate(event_types) if t == "message"]
|
||||||
|
|
||||||
|
for msg_idx in message_indices:
|
||||||
|
assert msg_idx < final_index, "All message events should come before final"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_with_llm_client(self, chat_request):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-07] Test streaming with mock LLM client.
|
||||||
|
"""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_chunk1 = MagicMock()
|
||||||
|
mock_chunk1.delta = "Hello"
|
||||||
|
mock_chunk1.finish_reason = None
|
||||||
|
|
||||||
|
mock_chunk2 = MagicMock()
|
||||||
|
mock_chunk2.delta = " there!"
|
||||||
|
mock_chunk2.finish_reason = None
|
||||||
|
|
||||||
|
mock_chunk3 = MagicMock()
|
||||||
|
mock_chunk3.delta = ""
|
||||||
|
mock_chunk3.finish_reason = "stop"
|
||||||
|
|
||||||
|
async def mock_stream(*args, **kwargs):
|
||||||
|
for chunk in [mock_chunk1, mock_chunk2, mock_chunk3]:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
mock_llm.stream_generate = mock_stream
|
||||||
|
|
||||||
|
orchestrator = OrchestratorService(llm_client=mock_llm)
|
||||||
|
|
||||||
|
events = []
|
||||||
|
async for event in orchestrator.generate_stream("tenant_001", chat_request):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
message_events = [e for e in events if e.event == "message"]
|
||||||
|
assert len(message_events) == 2, "Should have two message events"
|
||||||
|
|
||||||
|
full_content = ""
|
||||||
|
for event in message_events:
|
||||||
|
data = json.loads(event.data)
|
||||||
|
full_content += data["delta"]
|
||||||
|
|
||||||
|
assert full_content == "Hello there!"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_handles_error(self, orchestrator, chat_request):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-09] Test that streaming errors are converted to error events.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEStateMachineIntegration:
|
||||||
|
"""
|
||||||
|
[AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Integration tests for SSE state machine.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_machine_prevents_events_after_final(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-08] Test that no events can be sent after final.
|
||||||
|
"""
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
await state_machine.transition_to_streaming()
|
||||||
|
|
||||||
|
assert state_machine.can_send_message() is True
|
||||||
|
|
||||||
|
await state_machine.transition_to_final()
|
||||||
|
|
||||||
|
assert state_machine.can_send_message() is False
|
||||||
|
assert state_machine.state == SSEState.FINAL_SENT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_machine_prevents_events_after_error(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-09] Test that no events can be sent after error.
|
||||||
|
"""
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
await state_machine.transition_to_streaming()
|
||||||
|
|
||||||
|
await state_machine.transition_to_error()
|
||||||
|
|
||||||
|
assert state_machine.can_send_message() is False
|
||||||
|
assert state_machine.state == SSEState.ERROR_SENT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_machine_allows_multiple_message_events(self):
|
||||||
|
"""
|
||||||
|
[AC-AISVC-07] Test that multiple message events can be sent during streaming.
|
||||||
|
"""
|
||||||
|
state_machine = SSEStateMachine()
|
||||||
|
await state_machine.transition_to_streaming()
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
assert state_machine.can_send_message() is True
|
||||||
|
|
||||||
|
await state_machine.transition_to_final()
|
||||||
|
assert state_machine.can_send_message() is False
|
||||||
Loading…
Reference in New Issue