ai-robot-core/ai-service/tests/test_context.py

288 lines
10 KiB
Python
Raw Permalink Normal View History

"""
Unit tests for Context Merger.
[AC-AISVC-14, AC-AISVC-15] Tests for context merging and truncation.
Tests cover:
- Message fingerprint computation
- Context merging with deduplication
- Token-based truncation
- Complete merge_and_truncate pipeline
"""
import hashlib
from unittest.mock import MagicMock, patch
import pytest
from app.models import ChatMessage, Role
from app.services.context import ContextMerger, MergedContext, get_context_merger
@pytest.fixture
def mock_settings():
    """Provide a bare ``MagicMock`` that stands in for the settings object."""
    return MagicMock()
@pytest.fixture
def context_merger(mock_settings):
    """Yield a ``ContextMerger`` built while ``get_settings`` is patched.

    ``app.services.context.get_settings`` is replaced with a mock returning
    ``mock_settings``; the merger is created with ``max_history_tokens=1000``
    and yielded from inside the patch context.
    """
    settings_patch = patch(
        "app.services.context.get_settings",
        return_value=mock_settings,
    )
    with settings_patch:
        yield ContextMerger(max_history_tokens=1000)
@pytest.fixture
def local_history():
    """Return three sample local-history turns as ``ChatMessage`` objects."""
    turns = (
        (Role.USER, "Hello"),
        (Role.ASSISTANT, "Hi there!"),
        (Role.USER, "How are you?"),
    )
    return [ChatMessage(role=speaker, content=text) for speaker, text in turns]
@pytest.fixture
def external_history():
    """Return three sample external-history turns as ``ChatMessage`` objects.

    The first two turns are identical to those in ``local_history``; only the
    last one differs.
    """
    turns = (
        (Role.USER, "Hello"),
        (Role.ASSISTANT, "Hi there!"),
        (Role.USER, "What's the weather?"),
    )
    return [ChatMessage(role=speaker, content=text) for speaker, text in turns]
@pytest.fixture
def dict_local_history():
    """Return sample local history as plain ``role``/``content`` dicts."""
    turns = (
        ("user", "Hello"),
        ("assistant", "Hi there!"),
    )
    return [{"role": speaker, "content": text} for speaker, text in turns]
@pytest.fixture
def dict_external_history():
    """Return sample external history as plain ``role``/``content`` dicts."""
    turns = (
        ("user", "Hello"),
        ("user", "What's the weather?"),
    )
    return [{"role": speaker, "content": text} for speaker, text in turns]
class TestFingerprintComputation:
    """Checks for ``ContextMerger.compute_fingerprint``. [AC-AISVC-15]"""

    def test_fingerprint_consistency(self, context_merger):
        """Hashing the same role and content twice gives equal fingerprints."""
        first, second = (
            context_merger.compute_fingerprint("user", "Hello world")
            for _ in range(2)
        )
        assert first == second

    def test_fingerprint_role_difference(self, context_merger):
        """Changing only the role changes the fingerprint."""
        by_role = {
            role: context_merger.compute_fingerprint(role, "Hello")
            for role in ("user", "assistant")
        }
        assert by_role["user"] != by_role["assistant"]

    def test_fingerprint_content_difference(self, context_merger):
        """Changing only the content changes the fingerprint."""
        hello_fp = context_merger.compute_fingerprint("user", "Hello")
        world_fp = context_merger.compute_fingerprint("user", "World")
        assert hello_fp != world_fp

    def test_fingerprint_normalization(self, context_merger):
        """Surrounding whitespace in the content does not affect the result."""
        bare_fp = context_merger.compute_fingerprint("user", "Hello")
        padded_fp = context_merger.compute_fingerprint("user", " Hello ")
        assert bare_fp == padded_fp

    def test_fingerprint_is_sha256(self, context_merger):
        """The fingerprint equals the SHA-256 hex digest of ``role|content``."""
        reference = hashlib.sha256("user|Hello".encode("utf-8")).hexdigest()
        actual = context_merger.compute_fingerprint("user", "Hello")
        assert actual == reference
        assert len(actual) == 64  # a SHA-256 hex digest is 64 characters long
class TestContextMerging:
    """Tests for context merging with deduplication. [AC-AISVC-14, AC-AISVC-15]"""

    def test_merge_empty_histories(self, context_merger):
        """[AC-AISVC-14] Merging two missing histories yields an empty result."""
        result = context_merger.merge_context(None, None)
        assert isinstance(result, MergedContext)
        assert result.messages == []
        assert result.local_count == 0
        assert result.external_count == 0
        assert result.duplicates_skipped == 0

    def test_merge_local_only(self, context_merger, local_history):
        """[AC-AISVC-14] Only local history is given (no external)."""
        result = context_merger.merge_context(local_history, None)
        assert len(result.messages) == 3
        assert result.local_count == 3
        assert result.external_count == 0
        assert result.duplicates_skipped == 0

    def test_merge_external_only(self, context_merger, external_history):
        """[AC-AISVC-15] Only external history is given (no local)."""
        result = context_merger.merge_context(None, external_history)
        assert len(result.messages) == 3
        assert result.local_count == 0
        assert result.external_count == 3
        assert result.duplicates_skipped == 0

    def test_merge_with_duplicates(
        self, context_merger, local_history, external_history
    ):
        """[AC-AISVC-15] Overlapping histories are deduplicated.

        The two fixtures share their first two turns, so only the third
        external turn is expected to be added.
        """
        result = context_merger.merge_context(local_history, external_history)
        assert len(result.messages) == 4
        assert result.local_count == 3
        assert result.external_count == 1
        assert result.duplicates_skipped == 2
        # The one non-duplicate external message must survive the merge.
        contents = [m["content"] for m in result.messages]
        assert "What's the weather?" in contents

    def test_merge_with_dict_histories(
        self, context_merger, dict_local_history, dict_external_history
    ):
        """[AC-AISVC-14, AC-AISVC-15] Histories given as plain dicts merge too."""
        result = context_merger.merge_context(
            dict_local_history, dict_external_history
        )
        assert len(result.messages) == 3
        assert result.local_count == 2
        assert result.external_count == 1
        assert result.duplicates_skipped == 1

    def test_merge_priority_local(self, context_merger):
        """[AC-AISVC-15] An identical external message is skipped as a duplicate."""
        local = [ChatMessage(role=Role.USER, content="Hello")]
        external = [ChatMessage(role=Role.USER, content="Hello")]
        result = context_merger.merge_context(local, external)
        assert len(result.messages) == 1
        assert result.duplicates_skipped == 1

    def test_merge_records_diagnostics(
        self, context_merger, local_history, external_history
    ):
        """[AC-AISVC-15] Each skipped duplicate produces one diagnostics entry."""
        result = context_merger.merge_context(local_history, external_history)
        assert len(result.diagnostics) == 2
        for diag in result.diagnostics:
            assert diag["type"] == "duplicate_skipped"
            assert "role" in diag
            assert "content_preview" in diag
class TestTokenTruncation:
    """Checks for ``ContextMerger.truncate_context``. [AC-AISVC-14]"""

    def test_truncate_empty_messages(self, context_merger):
        """[AC-AISVC-14] An empty list comes back empty with a zero count."""
        kept, dropped = context_merger.truncate_context([], 100)
        assert kept == []
        assert dropped == 0

    def test_truncate_within_budget(self, context_merger):
        """[AC-AISVC-14] Messages that fit the budget are all retained."""
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi"},
        ]
        kept, dropped = context_merger.truncate_context(conversation, 1000)
        assert len(kept) == 2
        assert dropped == 0

    def test_truncate_exceeds_budget(self, context_merger):
        """[AC-AISVC-14] Going over the budget removes at least one message."""
        conversation = [
            {"role": "user", "content": "Hello world " * 100},
            {"role": "assistant", "content": "Hi there " * 100},
            {"role": "user", "content": "Short message"},
        ]
        kept, dropped = context_merger.truncate_context(conversation, 50)
        assert len(kept) < len(conversation)
        assert dropped > 0

    def test_truncate_keeps_recent_messages(self, context_merger):
        """[AC-AISVC-14] When truncation happens, the newest message survives."""
        conversation = [
            {"role": "user", "content": "First message"},
            {"role": "assistant", "content": "Second message"},
            {"role": "user", "content": "Third message"},
        ]
        kept, dropped = context_merger.truncate_context(conversation, 20)
        # NOTE(review): this check is skipped entirely when nothing was
        # truncated, so the test can pass vacuously - confirm a budget of 20
        # actually forces truncation with the real token counter.
        if dropped > 0:
            assert any(m["content"] == "Third message" for m in kept)

    def test_truncate_with_default_budget(self, context_merger):
        """[AC-AISVC-14] Omitting the budget argument uses the merger's default."""
        conversation = [{"role": "user", "content": "Test"}]
        kept, dropped = context_merger.truncate_context(conversation)
        assert len(kept) == 1
        assert dropped == 0
class TestMergeAndTruncate:
    """End-to-end checks for ``merge_and_truncate``. [AC-AISVC-14, AC-AISVC-15]"""

    def test_merge_and_truncate_combined(self, context_merger):
        """[AC-AISVC-14, AC-AISVC-15] Merge and dedupe under a generous budget."""
        local_turns = [
            ChatMessage(role=Role.USER, content="Hello"),
            ChatMessage(role=Role.ASSISTANT, content="Hi"),
        ]
        external_turns = [
            ChatMessage(role=Role.USER, content="Hello"),
            ChatMessage(role=Role.USER, content="What's up?"),
        ]
        outcome = context_merger.merge_and_truncate(
            local_turns, external_turns, max_tokens=1000
        )
        assert isinstance(outcome, MergedContext)
        assert len(outcome.messages) == 3
        assert outcome.local_count == 2
        assert outcome.external_count == 1
        assert outcome.duplicates_skipped == 1

    def test_merge_and_truncate_with_truncation(self, context_merger):
        """[AC-AISVC-14, AC-AISVC-15] A tight budget forces truncation."""
        local_turns = [
            ChatMessage(role=Role.USER, content="Hello " * 50),
            ChatMessage(role=Role.ASSISTANT, content="Hi " * 50),
        ]
        external_turns = [ChatMessage(role=Role.USER, content="Short")]
        outcome = context_merger.merge_and_truncate(
            local_turns, external_turns, max_tokens=50
        )
        assert outcome.truncated_count > 0
        assert outcome.total_tokens <= 50
class TestContextMergerSingleton:
    """Tests for the singleton behaviour of ``get_context_merger``."""

    def test_get_context_merger_singleton(self, mock_settings):
        """Two consecutive ``get_context_merger()`` calls return one object.

        The module attribute ``_context_merger`` is reset to ``None`` for the
        duration of the test so the accessor starts from a clean state.
        """
        import app.services.context as context_module

        # patch.object restores the previous ``_context_merger`` value on exit,
        # so the merger created against mocked settings does not leak into
        # other tests. It also fails loudly if the attribute does not exist.
        with patch(
            "app.services.context.get_settings", return_value=mock_settings
        ), patch.object(context_module, "_context_merger", None):
            merger1 = get_context_merger()
            merger2 = get_context_merger()
            assert merger1 is merger2