ai-robot-core/ai-service/app/services/orchestrator.py

197 lines
6.6 KiB
Python

"""
Orchestrator service for AI Service.
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Core orchestration logic for chat generation.
"""
import logging
from typing import AsyncGenerator
from sse_starlette.sse import ServerSentEvent
from app.models import ChatRequest, ChatResponse
from app.core.sse import create_error_event, create_final_event, create_message_event, SSEStateMachine
logger = logging.getLogger(__name__)
class OrchestratorService:
"""
[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06, AC-AISVC-07] Orchestrator for chat generation.
Coordinates memory, retrieval, and LLM components.
SSE Event Flow (per design.md Section 6.2):
- message* (0 or more) -> final (exactly 1) -> close
- OR message* (0 or more) -> error (exactly 1) -> close
"""
def __init__(self, llm_client=None):
"""
Initialize orchestrator with optional LLM client.
Args:
llm_client: Optional LLM client for dependency injection.
If None, will use mock implementation for demo.
"""
self._llm_client = llm_client
async def generate(self, tenant_id: str, request: ChatRequest) -> ChatResponse:
"""
Generate a non-streaming response.
[AC-AISVC-02] Returns ChatResponse with reply, confidence, shouldTransfer.
"""
logger.info(
f"[AC-AISVC-01] Generating response for tenant={tenant_id}, "
f"session={request.session_id}"
)
if self._llm_client:
messages = self._build_messages(request)
response = await self._llm_client.generate(messages)
return ChatResponse(
reply=response.content,
confidence=0.85,
should_transfer=False,
)
reply = f"Received your message: {request.current_message}"
return ChatResponse(
reply=reply,
confidence=0.85,
should_transfer=False,
)
async def generate_stream(
self, tenant_id: str, request: ChatRequest
) -> AsyncGenerator[ServerSentEvent, None]:
"""
Generate a streaming response.
[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence.
SSE Event Sequence (per design.md Section 6.2):
1. message events (multiple) - each with incremental delta
2. final event (exactly 1) - with complete response
3. connection close
OR on error:
1. message events (0 or more)
2. error event (exactly 1)
3. connection close
"""
logger.info(
f"[AC-AISVC-06] Starting streaming generation for tenant={tenant_id}, "
f"session={request.session_id}"
)
state_machine = SSEStateMachine()
await state_machine.transition_to_streaming()
try:
full_reply = ""
if self._llm_client:
async for event in self._stream_from_llm(request, state_machine):
if event.event == "message":
full_reply += self._extract_delta_from_event(event)
yield event
else:
async for event in self._stream_mock_response(request, state_machine):
if event.event == "message":
full_reply += self._extract_delta_from_event(event)
yield event
if await state_machine.transition_to_final():
yield create_final_event(
reply=full_reply,
confidence=0.85,
should_transfer=False,
)
except Exception as e:
logger.error(f"[AC-AISVC-09] Error during streaming: {e}")
if await state_machine.transition_to_error():
yield create_error_event(
code="GENERATION_ERROR",
message=str(e),
)
finally:
await state_machine.close()
async def _stream_from_llm(
self, request: ChatRequest, state_machine: SSEStateMachine
) -> AsyncGenerator[ServerSentEvent, None]:
"""
[AC-AISVC-07] Stream from LLM client, wrapping each chunk as message event.
"""
messages = self._build_messages(request)
async for chunk in self._llm_client.stream_generate(messages):
if not state_machine.can_send_message():
break
if chunk.delta:
logger.debug(f"[AC-AISVC-07] Yielding message event with delta: {chunk.delta[:50]}...")
yield create_message_event(delta=chunk.delta)
if chunk.finish_reason:
logger.info(f"[AC-AISVC-07] LLM stream finished with reason: {chunk.finish_reason}")
break
async def _stream_mock_response(
self, request: ChatRequest, state_machine: SSEStateMachine
) -> AsyncGenerator[ServerSentEvent, None]:
"""
[AC-AISVC-07] Mock streaming response for demo/testing purposes.
Simulates LLM-style incremental output.
"""
import asyncio
reply_parts = ["Received", " your", " message:", f" {request.current_message}"]
for part in reply_parts:
if not state_machine.can_send_message():
break
logger.debug(f"[AC-AISVC-07] Yielding mock message event with delta: {part}")
yield create_message_event(delta=part)
await asyncio.sleep(0.05)
def _build_messages(self, request: ChatRequest) -> list[dict[str, str]]:
"""Build messages list for LLM from request."""
messages = []
if request.history:
for msg in request.history:
messages.append({
"role": msg.role.value,
"content": msg.content,
})
messages.append({
"role": "user",
"content": request.current_message,
})
return messages
def _extract_delta_from_event(self, event: ServerSentEvent) -> str:
"""Extract delta content from a message event."""
import json
try:
if event.data:
data = json.loads(event.data)
return data.get("delta", "")
except (json.JSONDecodeError, TypeError):
pass
return ""
_orchestrator_service: OrchestratorService | None = None
def get_orchestrator_service() -> OrchestratorService:
"""Get or create orchestrator service instance."""
global _orchestrator_service
if _orchestrator_service is None:
_orchestrator_service = OrchestratorService()
return _orchestrator_service