"""
Orchestrator service for AI Service.

[AC-AISVC-01, AC-AISVC-02] Core orchestration logic for chat generation.
"""
|
||
|
|
|
||
|
|
import asyncio
import logging
from typing import AsyncGenerator

from sse_starlette.sse import ServerSentEvent

from app.core.sse import (
    SSEStateMachine,
    create_error_event,
    create_final_event,
    create_message_event,
)
from app.models import ChatRequest, ChatResponse
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class OrchestratorService:
    """
    [AC-AISVC-01, AC-AISVC-02] Orchestrator for chat generation.

    Coordinates memory, retrieval, and LLM components.
    """

    async def generate(self, tenant_id: str, request: ChatRequest) -> ChatResponse:
        """
        Generate a non-streaming response.

        [AC-AISVC-02] Returns ChatResponse with reply, confidence, shouldTransfer.

        Args:
            tenant_id: Tenant identifier, used here for log scoping.
            request: Incoming chat request; only ``current_message`` and
                ``session_id`` are read by this placeholder implementation.

        Returns:
            ChatResponse echoing the message with fixed confidence and
            no transfer flag.
        """
        # Lazy %-style args: formatting is skipped when INFO is disabled.
        logger.info(
            "[AC-AISVC-01] Generating response for tenant=%s, session=%s",
            tenant_id,
            request.session_id,
        )

        # Placeholder echo logic until the real LLM pipeline is wired in.
        reply = f"Received your message: {request.current_message}"
        return ChatResponse(
            reply=reply,
            confidence=0.85,
            should_transfer=False,
        )

    async def generate_stream(
        self, tenant_id: str, request: ChatRequest
    ) -> AsyncGenerator[ServerSentEvent, None]:
        """
        Generate a streaming response.

        [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence.

        Args:
            tenant_id: Tenant identifier, used here for log scoping.
            request: Incoming chat request.

        Yields:
            ServerSentEvent objects: zero or more message deltas, then a
            final event; on failure, a single error event instead.
        """
        logger.info(
            "[AC-AISVC-06] Starting streaming generation for tenant=%s, session=%s",
            tenant_id,
            request.session_id,
        )

        state_machine = SSEStateMachine()
        await state_machine.transition_to_streaming()

        try:
            # Placeholder chunking until a real LLM stream is wired in.
            reply_parts = ["Received", " your", " message:", f" {request.current_message}"]
            full_reply = ""

            for part in reply_parts:
                # State machine gates delta emission (e.g. after close/error).
                if state_machine.can_send_message():
                    full_reply += part
                    yield create_message_event(delta=part)
                    await self._simulate_llm_delay()

            # [AC-AISVC-07] Final event carries the fully assembled reply.
            if await state_machine.transition_to_final():
                yield create_final_event(
                    reply=full_reply,
                    confidence=0.85,
                    should_transfer=False,
                )
        except Exception as e:
            # logger.exception records the traceback, unlike logger.error.
            logger.exception("[AC-AISVC-09] Error during streaming: %s", e)
            # Only emit the error event if the state machine allows the
            # transition (e.g. final event not already sent).
            if await state_machine.transition_to_error():
                yield create_error_event(
                    code="GENERATION_ERROR",
                    message=str(e),
                )
        finally:
            # Always release state-machine resources, success or failure.
            await state_machine.close()

    async def _simulate_llm_delay(self) -> None:
        """Simulate LLM processing delay (100 ms) for demo purposes."""
        await asyncio.sleep(0.1)
|
||
|
|
|
||
|
|
|
||
|
|
# Module-level singleton holder; populated lazily on first access.
_orchestrator_service: OrchestratorService | None = None


def get_orchestrator_service() -> OrchestratorService:
    """Return the process-wide OrchestratorService singleton.

    The instance is created lazily on the first call; every later call
    returns the same object.
    """
    global _orchestrator_service
    if _orchestrator_service is not None:
        return _orchestrator_service
    _orchestrator_service = OrchestratorService()
    return _orchestrator_service
|