# ai-robot-core/ai-service/app/services/context.py

"""
Context management utilities for AI Service.
[AC-AISVC-14, AC-AISVC-15] Context merging and truncation strategies.
Design reference: design.md Section 7 - Context Merging Rules
- H_local: Memory layer history (sorted by time)
- H_ext: External history from Java request (in passed order)
- Deduplication: fingerprint = hash(role + "|" + normalized(content))
- Truncation: Keep most recent N messages within token budget
"""
import hashlib
import logging
from dataclasses import dataclass, field
from typing import Any
import tiktoken
from app.core.config import get_settings
from app.models import ChatMessage, Role
logger = logging.getLogger(__name__)
@dataclass
class MergedContext:
    """
    Result of context merging.
    [AC-AISVC-14, AC-AISVC-15] Contains merged messages and diagnostics.
    """

    # Final merged conversation: local history first, then the external
    # messages whose fingerprints were not already seen.
    messages: list[dict[str, str]] = field(default_factory=list)
    # tiktoken-based token count of `messages`, including the fixed
    # per-message structural overhead used by ContextMerger._count_tokens.
    total_tokens: int = 0
    # How many messages came from the Memory-layer history (H_local).
    local_count: int = 0
    # How many external (H_ext) messages were kept after deduplication.
    external_count: int = 0
    # External messages dropped because a matching fingerprint was seen.
    duplicates_skipped: int = 0
    # Messages removed by token-budget truncation (merge_and_truncate).
    truncated_count: int = 0
    # Structured debug entries, e.g. {"type": "duplicate_skipped", ...}.
    diagnostics: list[dict[str, Any]] = field(default_factory=list)
class ContextMerger:
    """
    [AC-AISVC-14, AC-AISVC-15] Context merger for combining local and external history.

    Design reference: design.md Section 7
    - Deduplication based on message fingerprint
    - Priority: local history takes precedence
    - Token-based truncation using tiktoken
    """

    def __init__(
        self,
        max_history_tokens: int | None = None,
        encoding_name: str = "cl100k_base",
    ):
        """
        Args:
            max_history_tokens: Token budget for history truncation.
                An explicit 0 is honored; None falls back to the 4096 default.
            encoding_name: Name of the tiktoken encoding used for counting.
        """
        # NOTE(review): the original code called get_settings() but never used
        # the result — presumably the budget was meant to come from settings.
        # The dead call is removed; confirm against app.core.config whether a
        # settings-driven default is intended here.
        #
        # Explicit None check so a caller-supplied 0 is not silently replaced
        # by the default (the old `max_history_tokens or 4096` swallowed it).
        self._max_history_tokens = (
            4096 if max_history_tokens is None else max_history_tokens
        )
        self._encoding = tiktoken.get_encoding(encoding_name)

    def compute_fingerprint(self, role: str, content: str) -> str:
        """
        Compute message fingerprint for deduplication.

        [AC-AISVC-15] fingerprint = hash(role + "|" + normalized(content))

        Args:
            role: Message role (user/assistant)
            content: Message content

        Returns:
            SHA256 hex digest of the normalized message
        """
        # Normalization is whitespace stripping only.
        normalized_content = content.strip()
        fingerprint_input = f"{role}|{normalized_content}"
        return hashlib.sha256(fingerprint_input.encode("utf-8")).hexdigest()

    def _message_to_dict(self, message: ChatMessage | dict[str, str]) -> dict[str, str]:
        """Convert ChatMessage or dict to the standard {"role", "content"} dict."""
        if isinstance(message, dict):
            return message
        # ChatMessage path: Role enum is flattened to its string value.
        return {"role": message.role.value, "content": message.content}

    def _count_message_tokens(self, msg: dict[str, str]) -> int:
        """Token cost of one message: role + content + fixed structural overhead."""
        return (
            len(self._encoding.encode(msg.get("role", "")))
            + len(self._encoding.encode(msg.get("content", "")))
            + 4  # approximate per-message overhead for chat structure
        )

    def _count_tokens(self, messages: list[dict[str, str]]) -> int:
        """
        Count total tokens in messages using tiktoken.
        [AC-AISVC-14] Token counting for history truncation.
        """
        return sum(self._count_message_tokens(msg) for msg in messages)

    def merge_context(
        self,
        local_history: list[ChatMessage] | list[dict[str, str]] | None,
        external_history: list[ChatMessage] | list[dict[str, str]] | None,
    ) -> MergedContext:
        """
        Merge local and external history with deduplication.

        [AC-AISVC-14, AC-AISVC-15] Implements context merging strategy.
        Design reference: design.md Section 7.2
        1. Build seen set from H_local
        2. Traverse H_ext, append if fingerprint not seen
        3. Local history takes priority

        Args:
            local_history: History from Memory layer (H_local)
            external_history: History from Java request (H_ext)

        Returns:
            MergedContext with merged messages and diagnostics
        """
        result = MergedContext()
        seen_fingerprints: set[str] = set()

        local_messages = [self._message_to_dict(m) for m in (local_history or [])]
        external_messages = [self._message_to_dict(m) for m in (external_history or [])]

        # H_local is taken wholesale and seeds the deduplication set.
        for msg in local_messages:
            seen_fingerprints.add(self.compute_fingerprint(msg["role"], msg["content"]))
            result.messages.append(msg)
        result.local_count = len(local_messages)

        # H_ext messages are appended only when their fingerprint is new.
        for msg in external_messages:
            fingerprint = self.compute_fingerprint(msg["role"], msg["content"])
            if fingerprint in seen_fingerprints:
                result.duplicates_skipped += 1
                result.diagnostics.append({
                    "type": "duplicate_skipped",
                    "role": msg["role"],
                    "content_preview": msg["content"][:50] + "..." if len(msg["content"]) > 50 else msg["content"],
                })
            else:
                seen_fingerprints.add(fingerprint)
                result.messages.append(msg)
                result.external_count += 1

        result.total_tokens = self._count_tokens(result.messages)
        # Lazy %-args: no string formatting when INFO is disabled.
        logging.getLogger(__name__).info(
            "[AC-AISVC-14, AC-AISVC-15] Context merged: local=%d, external=%d, "
            "duplicates_skipped=%d, total_tokens=%d",
            result.local_count,
            result.external_count,
            result.duplicates_skipped,
            result.total_tokens,
        )
        return result

    def truncate_context(
        self,
        messages: list[dict[str, str]],
        max_tokens: int | None = None,
    ) -> tuple[list[dict[str, str]], int]:
        """
        Truncate context to fit within token budget.

        [AC-AISVC-14] Keep most recent N messages within budget.
        Design reference: design.md Section 7.4
        - Budget = maxHistoryTokens (configurable)
        - Strategy: keep the longest contiguous suffix (most recent
          messages) that fits in the budget

        Args:
            messages: List of messages to truncate
            max_tokens: Maximum token budget (uses default if not provided;
                an explicit 0 is honored)

        Returns:
            Tuple of (truncated messages, truncated count)
        """
        budget = self._max_history_tokens if max_tokens is None else max_tokens
        if not messages:
            return [], 0
        if self._count_tokens(messages) <= budget:
            return messages, 0

        # Walk from the newest message backward, stopping at the first one
        # that no longer fits. FIX: the original kept scanning past an
        # oversized message and could retain older messages while dropping a
        # newer one, leaving gaps in the conversation — contradicting the
        # "keep most recent" strategy.
        suffix: list[dict[str, str]] = []
        used_tokens = 0
        for msg in reversed(messages):
            cost = self._count_message_tokens(msg)
            if used_tokens + cost > budget:
                break
            suffix.append(msg)
            used_tokens += cost
        suffix.reverse()  # restore chronological order

        truncated_count = len(messages) - len(suffix)
        logging.getLogger(__name__).info(
            "[AC-AISVC-14] Context truncated: original=%d, truncated=%d, "
            "removed=%d, tokens=%d/%d",
            len(messages),
            len(suffix),
            truncated_count,
            used_tokens,
            budget,
        )
        return suffix, truncated_count

    def merge_and_truncate(
        self,
        local_history: list[ChatMessage] | list[dict[str, str]] | None,
        external_history: list[ChatMessage] | list[dict[str, str]] | None,
        max_tokens: int | None = None,
    ) -> MergedContext:
        """
        Merge and truncate context in one operation.

        [AC-AISVC-14, AC-AISVC-15] Complete context preparation pipeline.

        Args:
            local_history: History from Memory layer (H_local)
            external_history: History from Java request (H_ext)
            max_tokens: Maximum token budget

        Returns:
            MergedContext with final messages after merge and truncate
        """
        merged = self.merge_context(local_history, external_history)
        merged.messages, merged.truncated_count = self.truncate_context(
            merged.messages, max_tokens
        )
        # Recount after truncation so total_tokens reflects the final list.
        merged.total_tokens = self._count_tokens(merged.messages)
        return merged
# Module-level cache backing the shared ContextMerger singleton.
_context_merger: ContextMerger | None = None


def get_context_merger() -> ContextMerger:
    """Return the process-wide ContextMerger, constructing it on first use."""
    global _context_merger
    merger = _context_merger
    if merger is None:
        merger = ContextMerger()
        _context_merger = merger
    return merger