# Path: ai-robot-core/ai-service/app/services/orchestrator.py
# (file metadata: 636 lines, 23 KiB, Python)
"""
Orchestrator service for AI Service.
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Core orchestration logic for chat generation.
Design reference: design.md Section 2.2 - 关键数据流
1. Memory.load(tenantId, sessionId)
2. merge_context(local_history, external_history)
3. Retrieval.retrieve(query, tenantId, channelType, metadata)
4. build_prompt(merged_history, retrieved_docs, currentMessage)
5. LLM.generate(...) (non-streaming) or LLM.stream_generate(...) (streaming)
6. compute_confidence(...)
7. Memory.append(tenantId, sessionId, user/assistant messages)
8. Return ChatResponse (or output via SSE)
"""
import asyncio
import json
import logging
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator

from sse_starlette.sse import ServerSentEvent

from app.core.config import get_settings
from app.core.sse import (
create_error_event,
create_final_event,
create_message_event,
SSEStateMachine,
)
from app.models import ChatRequest, ChatResponse
from app.services.confidence import ConfidenceCalculator, ConfidenceResult
from app.services.context import ContextMerger, MergedContext
from app.services.llm.base import LLMClient, LLMConfig, LLMResponse
from app.services.memory import MemoryService
from app.services.retrieval.base import BaseRetriever, RetrievalContext, RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class OrchestratorConfig:
    """
    Configuration for OrchestratorService.
    [AC-AISVC-01] Centralized configuration for orchestration.
    """
    # Token budget for the merged conversation history (local + external).
    max_history_tokens: int = 4000
    # Token budget for retrieved RAG evidence injected into the system prompt.
    max_evidence_tokens: int = 2000
    # Default system prompt (Chinese customer-service assistant persona).
    system_prompt: str = "你是一个智能客服助手,请根据提供的知识库内容回答用户问题。"
    # When False (or when no retriever is configured), the RAG step is skipped.
    enable_rag: bool = True
@dataclass
class GenerationContext:
    """
    [AC-AISVC-01, AC-AISVC-02] Context accumulated during generation pipeline.
    Contains all intermediate results for diagnostics and response building.
    """
    # Request identity / routing inputs (set once at construction).
    tenant_id: str
    session_id: str
    current_message: str
    channel_type: str
    request_metadata: dict[str, Any] | None = None
    # Pipeline step 1: history loaded from Memory as {"role", "content"} dicts.
    local_history: list[dict[str, str]] = field(default_factory=list)
    # Pipeline step 2: local + external history after dedup and truncation.
    merged_context: MergedContext | None = None
    # Pipeline step 3: RAG output; None when RAG is disabled or not yet run.
    retrieval_result: RetrievalResult | None = None
    # Pipeline step 5: LLM output (may hold a fallback response on failure).
    llm_response: LLMResponse | None = None
    # Pipeline step 6: confidence score and human-transfer decision.
    confidence_result: ConfidenceResult | None = None
    # Free-form per-request diagnostics, surfaced in response metadata.
    diagnostics: dict[str, Any] = field(default_factory=dict)
class OrchestratorService:
    """
    [AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Orchestrator for chat generation.
    Coordinates memory, retrieval, and LLM components.

    SSE Event Flow (per design.md Section 6.2):
    - message* (0 or more) -> final (exactly 1) -> close
    - OR message* (0 or more) -> error (exactly 1) -> close

    Each pipeline step is fail-soft: memory, retrieval, and LLM failures are
    recorded in ``ctx.diagnostics`` and degrade to fallback behavior instead of
    propagating, so a single dependency outage never drops the request.
    """

    def __init__(
        self,
        llm_client: LLMClient | None = None,
        memory_service: MemoryService | None = None,
        retriever: BaseRetriever | None = None,
        context_merger: ContextMerger | None = None,
        confidence_calculator: ConfidenceCalculator | None = None,
        config: OrchestratorConfig | None = None,
    ):
        """
        Initialize orchestrator with optional dependencies for DI.

        Args:
            llm_client: LLM client for generation (fallback replies when absent)
            memory_service: Memory service for session history (skipped when absent)
            retriever: Retriever for RAG (RAG skipped when absent)
            context_merger: Context merger for history deduplication
            confidence_calculator: Confidence calculator for response scoring
            config: Orchestrator configuration
        """
        settings = get_settings()
        self._llm_client = llm_client
        self._memory_service = memory_service
        self._retriever = retriever
        # getattr with defaults keeps the service constructible even when the
        # settings object predates these configuration fields.
        self._context_merger = context_merger or ContextMerger(
            max_history_tokens=getattr(settings, "max_history_tokens", 4000)
        )
        self._confidence_calculator = confidence_calculator or ConfidenceCalculator()
        self._config = config or OrchestratorConfig(
            max_history_tokens=getattr(settings, "max_history_tokens", 4000),
            max_evidence_tokens=getattr(settings, "rag_max_evidence_tokens", 2000),
            enable_rag=True,
        )
        self._llm_config = LLMConfig(
            model=getattr(settings, "llm_model", "gpt-4o-mini"),
            max_tokens=getattr(settings, "llm_max_tokens", 2048),
            temperature=getattr(settings, "llm_temperature", 0.7),
            timeout_seconds=getattr(settings, "llm_timeout_seconds", 30),
            max_retries=getattr(settings, "llm_max_retries", 3),
        )

    async def generate(
        self,
        tenant_id: str,
        request: ChatRequest,
    ) -> ChatResponse:
        """
        Generate a non-streaming response.
        [AC-AISVC-01, AC-AISVC-02] Complete generation pipeline.

        Pipeline (per design.md Section 2.2):
        1. Load local history from Memory
        2. Merge with external history (dedup + truncate)
        3. RAG retrieval (optional)
        4. Build prompt with context and evidence
        5. LLM generation
        6. Calculate confidence
        7. Save messages to Memory
        8. Return ChatResponse
        """
        logger.info(
            f"[AC-AISVC-01] Starting generation for tenant={tenant_id}, "
            f"session={request.session_id}"
        )
        ctx = GenerationContext(
            tenant_id=tenant_id,
            session_id=request.session_id,
            current_message=request.current_message,
            channel_type=request.channel_type.value,
            request_metadata=request.metadata,
        )
        try:
            await self._load_local_history(ctx)
            await self._merge_context(ctx, request.history)
            # RAG runs only when enabled AND a retriever is wired in.
            if self._config.enable_rag and self._retriever:
                await self._retrieve_evidence(ctx)
            await self._generate_response(ctx)
            self._calculate_confidence(ctx)
            await self._save_messages(ctx)
            return self._build_response(ctx)
        except Exception as e:
            # Last-resort guard: individual steps are already fail-soft, so this
            # only fires on unexpected errors. Fail safe: route to a human agent.
            logger.exception(f"[AC-AISVC-01] Generation failed: {e}")
            return ChatResponse(
                reply="抱歉,服务暂时不可用,请稍后重试或联系人工客服。",
                confidence=0.0,
                should_transfer=True,
                transfer_reason=f"服务异常: {str(e)}",
                metadata={"error": str(e), "diagnostics": ctx.diagnostics},
            )

    async def _load_local_history(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-13] Load local history from Memory service into ctx.local_history.
        Step 1 of the generation pipeline. Fail-soft: load errors are logged and
        recorded in diagnostics; generation continues with an empty history.
        """
        if not self._memory_service:
            logger.info("[AC-AISVC-13] No memory service configured, skipping history load")
            ctx.diagnostics["memory_enabled"] = False
            return
        try:
            messages = await self._memory_service.load_history(
                tenant_id=ctx.tenant_id,
                session_id=ctx.session_id,
            )
            # Normalize stored messages into plain role/content dicts for merging.
            ctx.local_history = [
                {"role": msg.role, "content": msg.content}
                for msg in messages
            ]
            ctx.diagnostics["memory_enabled"] = True
            ctx.diagnostics["local_history_count"] = len(ctx.local_history)
            logger.info(
                f"[AC-AISVC-13] Loaded {len(ctx.local_history)} messages from memory "
                f"for tenant={ctx.tenant_id}, session={ctx.session_id}"
            )
        except Exception as e:
            logger.warning(f"[AC-AISVC-13] Failed to load history: {e}")
            ctx.diagnostics["memory_error"] = str(e)

    async def _merge_context(
        self,
        ctx: GenerationContext,
        external_history: list | None,
    ) -> None:
        """
        [AC-AISVC-14, AC-AISVC-15] Merge local and external history.
        Step 2 of the generation pipeline.
        Design reference: design.md Section 7
        - Deduplication based on fingerprint
        - Truncation to fit token budget
        """
        external_messages = None
        if external_history:
            # External history arrives as model objects; role is an enum here
            # (vs. a plain string in local history), hence .value.
            external_messages = [
                {"role": msg.role.value, "content": msg.content}
                for msg in external_history
            ]
        ctx.merged_context = self._context_merger.merge_and_truncate(
            local_history=ctx.local_history,
            external_history=external_messages,
            max_tokens=self._config.max_history_tokens,
        )
        ctx.diagnostics["merged_context"] = {
            "local_count": ctx.merged_context.local_count,
            "external_count": ctx.merged_context.external_count,
            "duplicates_skipped": ctx.merged_context.duplicates_skipped,
            "truncated_count": ctx.merged_context.truncated_count,
            "total_tokens": ctx.merged_context.total_tokens,
        }
        logger.info(
            f"[AC-AISVC-14, AC-AISVC-15] Context merged: "
            f"local={ctx.merged_context.local_count}, "
            f"external={ctx.merged_context.external_count}, "
            f"tokens={ctx.merged_context.total_tokens}"
        )

    async def _retrieve_evidence(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-16, AC-AISVC-17] RAG retrieval for evidence.
        Step 3 of the generation pipeline. Fail-soft: on error an empty
        RetrievalResult is substituted so confidence scoring still runs.
        """
        try:
            retrieval_ctx = RetrievalContext(
                tenant_id=ctx.tenant_id,
                query=ctx.current_message,
                session_id=ctx.session_id,
                channel_type=ctx.channel_type,
                metadata=ctx.request_metadata,
            )
            ctx.retrieval_result = await self._retriever.retrieve(retrieval_ctx)
            ctx.diagnostics["retrieval"] = {
                "hit_count": ctx.retrieval_result.hit_count,
                "max_score": ctx.retrieval_result.max_score,
                "is_empty": ctx.retrieval_result.is_empty,
            }
            logger.info(
                f"[AC-AISVC-16, AC-AISVC-17] Retrieval complete: "
                f"hits={ctx.retrieval_result.hit_count}, "
                f"max_score={ctx.retrieval_result.max_score:.3f}"
            )
        except Exception as e:
            logger.warning(f"[AC-AISVC-16] Retrieval failed: {e}")
            ctx.retrieval_result = RetrievalResult(
                hits=[],
                diagnostics={"error": str(e)},
            )
            ctx.diagnostics["retrieval_error"] = str(e)

    async def _generate_response(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-02] Generate response using LLM.
        Steps 4-5 of the generation pipeline. Fail-soft: when no client is
        configured or generation fails, a canned fallback reply is used and the
        condition is recorded in diagnostics.
        """
        messages = self._build_llm_messages(ctx)
        if not self._llm_client:
            logger.warning("[AC-AISVC-02] No LLM client configured, using fallback")
            ctx.llm_response = LLMResponse(
                content=self._fallback_response(ctx),
                model="fallback",
                usage={},
                finish_reason="fallback",
            )
            ctx.diagnostics["llm_mode"] = "fallback"
            return
        try:
            ctx.llm_response = await self._llm_client.generate(
                messages=messages,
                config=self._llm_config,
            )
            ctx.diagnostics["llm_mode"] = "live"
            ctx.diagnostics["llm_model"] = ctx.llm_response.model
            ctx.diagnostics["llm_usage"] = ctx.llm_response.usage
            logger.info(
                f"[AC-AISVC-02] LLM response generated: "
                f"model={ctx.llm_response.model}, "
                f"tokens={ctx.llm_response.usage}"
            )
        except Exception as e:
            logger.error(f"[AC-AISVC-02] LLM generation failed: {e}")
            # finish_reason distinguishes a failed call ("error") from the
            # deliberate no-client path ("fallback") above.
            ctx.llm_response = LLMResponse(
                content=self._fallback_response(ctx),
                model="fallback",
                usage={},
                finish_reason="error",
                metadata={"error": str(e)},
            )
            ctx.diagnostics["llm_error"] = str(e)

    def _build_llm_messages(self, ctx: GenerationContext) -> list[dict[str, str]]:
        """
        [AC-AISVC-02] Build messages for LLM including system prompt and evidence.

        Layout: system prompt (+ optional evidence) -> merged history -> current
        user message.
        """
        messages = []
        system_content = self._config.system_prompt
        if ctx.retrieval_result and not ctx.retrieval_result.is_empty:
            evidence_text = self._format_evidence(ctx.retrieval_result)
            system_content += f"\n\n知识库参考内容:\n{evidence_text}"
        messages.append({"role": "system", "content": system_content})
        if ctx.merged_context and ctx.merged_context.messages:
            messages.extend(ctx.merged_context.messages)
        messages.append({"role": "user", "content": ctx.current_message})
        return messages

    def _format_evidence(self, retrieval_result: RetrievalResult) -> str:
        """
        [AC-AISVC-17] Format retrieval hits as evidence text.

        At most 5 hits are included, and the configured
        ``max_evidence_tokens`` budget is enforced (previously this config
        field was populated but never applied). Token counts use the same
        rough ``len(text.split()) * 2`` estimate as confidence scoring.
        The first hit is always included even if it alone exceeds the budget.
        """
        evidence_parts: list[str] = []
        used_tokens = 0
        budget = self._config.max_evidence_tokens
        for i, hit in enumerate(retrieval_result.hits[:5], 1):
            hit_tokens = len(hit.text.split()) * 2
            if evidence_parts and used_tokens + hit_tokens > budget:
                break
            evidence_parts.append(f"[{i}] (相关度: {hit.score:.2f}) {hit.text}")
            used_tokens += hit_tokens
        return "\n".join(evidence_parts)

    def _fallback_response(self, ctx: GenerationContext) -> str:
        """
        [AC-AISVC-17] Generate fallback response when LLM is unavailable.
        Wording differs depending on whether retrieval found anything.
        """
        if ctx.retrieval_result and not ctx.retrieval_result.is_empty:
            return (
                "根据知识库信息,我找到了一些相关内容,"
                "但暂时无法生成完整回复。建议您稍后重试或联系人工客服。"
            )
        return (
            "抱歉,我暂时无法处理您的请求。"
            "请稍后重试或联系人工客服获取帮助。"
        )

    def _calculate_confidence(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-17, AC-AISVC-18, AC-AISVC-19] Calculate confidence score.
        Step 6 of the generation pipeline.
        """
        if ctx.retrieval_result:
            evidence_tokens = 0
            if not ctx.retrieval_result.is_empty:
                # Rough token estimate: whitespace word count * 2 (no tokenizer
                # dependency here; mirrors the estimate in _format_evidence).
                evidence_tokens = sum(
                    len(hit.text.split()) * 2
                    for hit in ctx.retrieval_result.hits
                )
            ctx.confidence_result = self._confidence_calculator.calculate_confidence(
                retrieval_result=ctx.retrieval_result,
                evidence_tokens=evidence_tokens,
            )
        else:
            ctx.confidence_result = self._confidence_calculator.calculate_confidence_no_retrieval()
        ctx.diagnostics["confidence"] = {
            "score": ctx.confidence_result.confidence,
            "should_transfer": ctx.confidence_result.should_transfer,
            "is_insufficient": ctx.confidence_result.is_retrieval_insufficient,
        }
        logger.info(
            f"[AC-AISVC-17, AC-AISVC-18] Confidence calculated: "
            f"{ctx.confidence_result.confidence:.3f}, "
            f"should_transfer={ctx.confidence_result.should_transfer}"
        )

    async def _save_messages(self, ctx: GenerationContext) -> None:
        """
        [AC-AISVC-13] Save user and assistant messages to Memory.
        Step 7 of the generation pipeline. Fail-soft: save errors are logged
        and recorded in diagnostics but do not fail the request.
        """
        if not self._memory_service:
            logger.info("[AC-AISVC-13] No memory service configured, skipping save")
            return
        try:
            # Ensure the session row exists before appending messages to it.
            await self._memory_service.get_or_create_session(
                tenant_id=ctx.tenant_id,
                session_id=ctx.session_id,
                channel_type=ctx.channel_type,
                metadata=ctx.request_metadata,
            )
            messages_to_save = [
                {"role": "user", "content": ctx.current_message},
            ]
            if ctx.llm_response:
                messages_to_save.append({
                    "role": "assistant",
                    "content": ctx.llm_response.content,
                })
            await self._memory_service.append_messages(
                tenant_id=ctx.tenant_id,
                session_id=ctx.session_id,
                messages=messages_to_save,
            )
            ctx.diagnostics["messages_saved"] = len(messages_to_save)
            logger.info(
                f"[AC-AISVC-13] Saved {len(messages_to_save)} messages "
                f"for tenant={ctx.tenant_id}, session={ctx.session_id}"
            )
        except Exception as e:
            logger.warning(f"[AC-AISVC-13] Failed to save messages: {e}")
            ctx.diagnostics["save_error"] = str(e)

    def _build_response(self, ctx: GenerationContext) -> ChatResponse:
        """
        [AC-AISVC-02] Build final ChatResponse from generation context.
        Step 8 of the generation pipeline. When confidence is missing the
        defaults fail safe: should_transfer=True.
        """
        reply = ctx.llm_response.content if ctx.llm_response else self._fallback_response(ctx)
        confidence = ctx.confidence_result.confidence if ctx.confidence_result else 0.5
        should_transfer = ctx.confidence_result.should_transfer if ctx.confidence_result else True
        transfer_reason = ctx.confidence_result.transfer_reason if ctx.confidence_result else None
        response_metadata = {
            "session_id": ctx.session_id,
            "channel_type": ctx.channel_type,
            "diagnostics": ctx.diagnostics,
        }
        return ChatResponse(
            reply=reply,
            confidence=confidence,
            should_transfer=should_transfer,
            transfer_reason=transfer_reason,
            metadata=response_metadata,
        )

    async def generate_stream(
        self,
        tenant_id: str,
        request: ChatRequest,
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        Generate a streaming response.
        [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence.

        SSE Event Sequence (per design.md Section 6.2):
        1. message events (multiple) - each with incremental delta
        2. final event (exactly 1) - with complete response
        3. connection close
        OR on error:
        1. message events (0 or more)
        2. error event (exactly 1)
        3. connection close
        """
        logger.info(
            f"[AC-AISVC-06] Starting streaming generation for tenant={tenant_id}, "
            f"session={request.session_id}"
        )
        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()
        ctx = GenerationContext(
            tenant_id=tenant_id,
            session_id=request.session_id,
            current_message=request.current_message,
            channel_type=request.channel_type.value,
            request_metadata=request.metadata,
        )
        try:
            await self._load_local_history(ctx)
            await self._merge_context(ctx, request.history)
            if self._config.enable_rag and self._retriever:
                await self._retrieve_evidence(ctx)
            # Accumulate the full reply from deltas so it can be persisted and
            # echoed in the final event.
            full_reply = ""
            if self._llm_client:
                async for event in self._stream_from_llm(ctx, state_machine):
                    if event.event == "message":
                        full_reply += self._extract_delta_from_event(event)
                    yield event
            else:
                async for event in self._stream_mock_response(ctx, state_machine):
                    if event.event == "message":
                        full_reply += self._extract_delta_from_event(event)
                    yield event
            # Streaming clients never populate ctx.llm_response; synthesize one
            # so _save_messages stores the assistant turn.
            if ctx.llm_response is None:
                ctx.llm_response = LLMResponse(
                    content=full_reply,
                    model="streaming",
                    usage={},
                    finish_reason="stop",
                )
            self._calculate_confidence(ctx)
            await self._save_messages(ctx)
            if await state_machine.transition_to_final():
                yield create_final_event(
                    reply=full_reply,
                    confidence=ctx.confidence_result.confidence if ctx.confidence_result else 0.5,
                    # Fail-safe default True, matching _build_response (the
                    # non-streaming path); previously this inconsistently
                    # defaulted to False.
                    should_transfer=ctx.confidence_result.should_transfer if ctx.confidence_result else True,
                    transfer_reason=ctx.confidence_result.transfer_reason if ctx.confidence_result else None,
                )
        except Exception as e:
            logger.exception(f"[AC-AISVC-09] Error during streaming: {e}")
            if await state_machine.transition_to_error():
                yield create_error_event(
                    code="GENERATION_ERROR",
                    message=str(e),
                )
        finally:
            # Always close the state machine so the event sequence terminates.
            await state_machine.close()

    async def _stream_from_llm(
        self,
        ctx: GenerationContext,
        state_machine: SSEStateMachine,
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        [AC-AISVC-07] Stream from LLM client, wrapping each chunk as message event.
        Stops early if the state machine no longer permits message events.
        """
        messages = self._build_llm_messages(ctx)
        async for chunk in self._llm_client.stream_generate(messages, self._llm_config):
            if not state_machine.can_send_message():
                break
            if chunk.delta:
                logger.debug(f"[AC-AISVC-07] Yielding message event with delta: {chunk.delta[:50]}...")
                yield create_message_event(delta=chunk.delta)
            if chunk.finish_reason:
                logger.info(f"[AC-AISVC-07] LLM stream finished with reason: {chunk.finish_reason}")
                break

    async def _stream_mock_response(
        self,
        ctx: GenerationContext,
        state_machine: SSEStateMachine,
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        [AC-AISVC-07] Mock streaming response for demo/testing purposes.
        Simulates LLM-style incremental output with a small delay per chunk.
        """
        reply_parts = ["收到", "您的", "消息:", f" {ctx.current_message}"]
        for part in reply_parts:
            if not state_machine.can_send_message():
                break
            logger.debug(f"[AC-AISVC-07] Yielding mock message event with delta: {part}")
            yield create_message_event(delta=part)
            await asyncio.sleep(0.05)

    def _extract_delta_from_event(self, event: ServerSentEvent) -> str:
        """Extract delta content from a message event; '' on missing/bad data."""
        try:
            if event.data:
                data = json.loads(event.data)
                return data.get("delta", "")
        except (json.JSONDecodeError, TypeError):
            pass
        return ""
# Lazily-created, process-wide orchestrator instance.
_orchestrator_service: OrchestratorService | None = None
def get_orchestrator_service() -> OrchestratorService:
    """Return the shared orchestrator, constructing it on first access."""
    global _orchestrator_service
    service = _orchestrator_service
    if service is None:
        service = OrchestratorService()
        _orchestrator_service = service
    return service
def set_orchestrator_service(service: OrchestratorService) -> None:
    """Replace the shared orchestrator instance (intended for tests/DI)."""
    global _orchestrator_service
    _orchestrator_service = service