diff --git a/ai-service/README.md b/ai-service/README.md new file mode 100644 index 0000000..2234221 --- /dev/null +++ b/ai-service/README.md @@ -0,0 +1,26 @@ +# AI Service + +Python AI Service for intelligent chat with RAG support. + +## Features + +- Multi-tenant isolation via X-Tenant-Id header +- SSE streaming support via Accept: text/event-stream +- RAG-powered responses with confidence scoring + +## Installation + +```bash +pip install -e ".[dev]" +``` + +## Running + +```bash +uvicorn app.main:app --host 0.0.0.0 --port 8080 +``` + +## API Endpoints + +- `POST /ai/chat` - Generate AI reply +- `GET /ai/health` - Health check diff --git a/ai-service/app/__init__.py b/ai-service/app/__init__.py new file mode 100644 index 0000000..bb68855 --- /dev/null +++ b/ai-service/app/__init__.py @@ -0,0 +1,4 @@ +""" +AI Service - Python AI Middle Platform +[AC-AISVC-01] FastAPI-based AI chat service with multi-tenant support. +""" diff --git a/ai-service/app/api/__init__.py b/ai-service/app/api/__init__.py new file mode 100644 index 0000000..b726039 --- /dev/null +++ b/ai-service/app/api/__init__.py @@ -0,0 +1,8 @@ +""" +API module for AI Service. +""" + +from app.api.chat import router as chat_router +from app.api.health import router as health_router + +__all__ = ["chat_router", "health_router"] diff --git a/ai-service/app/api/chat.py b/ai-service/app/api/chat.py new file mode 100644 index 0000000..fd58bd8 --- /dev/null +++ b/ai-service/app/api/chat.py @@ -0,0 +1,127 @@ +""" +Chat endpoint for AI Service. +[AC-AISVC-01, AC-AISVC-02, AC-AISVC-06] Main chat endpoint with streaming/non-streaming modes. +""" + +import logging +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, Header, Request +from fastapi.responses import JSONResponse +from sse_starlette.sse import EventSourceResponse + +from app.core.middleware import get_response_mode, is_sse_request +from app.core.sse import create_error_event +from app.core.tenant import get_tenant_id +from app.models import ChatRequest, ChatResponse, ErrorResponse +from app.services.orchestrator import OrchestratorService, get_orchestrator_service + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["AI Chat"]) + + +@router.post( + "/ai/chat", + operation_id="generateReply", + summary="Generate AI reply", + description=""" + [AC-AISVC-01, AC-AISVC-02, AC-AISVC-06] Generate AI reply based on user message. + + Response mode is determined by Accept header: + - Accept: text/event-stream -> SSE streaming response + - Other -> JSON response + """, + responses={ + 200: { + "description": "Success - JSON or SSE stream", + "content": { + "application/json": {"schema": {"$ref": "#/components/schemas/ChatResponse"}}, + "text/event-stream": {"schema": {"type": "string"}}, + }, + }, + 400: {"description": "Invalid request", "model": ErrorResponse}, + 500: {"description": "Internal error", "model": ErrorResponse}, + 503: {"description": "Service unavailable", "model": ErrorResponse}, + }, +) +async def generate_reply( + request: Request, + chat_request: ChatRequest, + accept: Annotated[str | None, Header()] = None, + orchestrator: OrchestratorService = Depends(get_orchestrator_service), +) -> Any: + """ + [AC-AISVC-06] Generate AI reply with automatic response mode switching. 
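+
+    Illustrative request (placeholder tenant id and payload; assumes the local
+    dev server from the README; omit or change Accept to pick the mode):
+
+        curl -X POST http://localhost:8080/ai/chat -H 'X-Tenant-Id: tenant-123' -H 'Content-Type: application/json' -d '{"sessionId": "s-1", "currentMessage": "Hi", "channelType": "wechat"}'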
+
+    Based on Accept header:
+    - text/event-stream: Returns SSE stream with message/final/error events
+    - Other: Returns JSON ChatResponse
+    """
+    tenant_id = get_tenant_id()
+    if not tenant_id:
+        from app.core.exceptions import MissingTenantIdException
+        raise MissingTenantIdException()
+
+    logger.info(
+        f"[AC-AISVC-06] Processing chat request: tenant={tenant_id}, "
+        f"session={chat_request.session_id}, mode={get_response_mode(request)}"
+    )
+
+    if is_sse_request(request):
+        return await _handle_streaming_request(tenant_id, chat_request, orchestrator)
+    else:
+        return await _handle_json_request(tenant_id, chat_request, orchestrator)
+
+
+async def _handle_json_request(
+    tenant_id: str,
+    chat_request: ChatRequest,
+    orchestrator: OrchestratorService,
+) -> JSONResponse:
+    """
+    [AC-AISVC-02] Handle non-streaming JSON request.
+    Returns ChatResponse with reply, confidence, shouldTransfer.
+    """
+    logger.info(f"[AC-AISVC-02] Processing JSON request for tenant={tenant_id}")
+
+    try:
+        response = await orchestrator.generate(tenant_id, chat_request)
+        return JSONResponse(
+            content=response.model_dump(exclude_none=True, by_alias=True),
+        )
+    except Exception as e:
+        logger.error(f"[AC-AISVC-04] Error generating response: {e}")
+        from app.core.exceptions import AIServiceException, ErrorCode
+        if isinstance(e, AIServiceException):
+            raise
+        raise AIServiceException(
+            code=ErrorCode.INTERNAL_ERROR,
+            message=str(e),
+        )
+
+
+async def _handle_streaming_request(
+    tenant_id: str,
+    chat_request: ChatRequest,
+    orchestrator: OrchestratorService,
+) -> EventSourceResponse:
+    """
+    [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] Handle SSE streaming request.
+    Yields message events followed by final or error event.
+    """
+    logger.info(f"[AC-AISVC-06] Processing SSE request for tenant={tenant_id}")
+
+    async def event_generator():
+        try:
+            async for event in orchestrator.generate_stream(tenant_id, chat_request):
+                yield event
+        except Exception as e:
+            logger.error(f"[AC-AISVC-09] Streaming error: {e}")
+            yield create_error_event(
+                code="STREAMING_ERROR",
+                message=str(e),
+            )
+
+    return EventSourceResponse(event_generator(), ping=15)
diff --git a/ai-service/app/api/health.py b/ai-service/app/api/health.py
new file mode 100644
index 0000000..4d2366e
--- /dev/null
+++ b/ai-service/app/api/health.py
@@ -0,0 +1,30 @@
+"""
+Health check endpoint.
+[AC-AISVC-20] Health check for service monitoring.
+"""
+
+from fastapi import APIRouter, status
+from fastapi.responses import JSONResponse
+
+router = APIRouter(tags=["Health"])
+
+
+@router.get(
+    "/ai/health",
+    operation_id="healthCheck",
+    summary="Health check",
+    description="[AC-AISVC-20] Check if AI service is healthy",
+    responses={
+        200: {"description": "Service is healthy"},
+        503: {"description": "Service is unhealthy"},
+    },
+)
+async def health_check() -> JSONResponse:
+    """
+    [AC-AISVC-20] Health check endpoint.
+    Returns 200 with a status payload; 503 is reserved for dependency checks (the MVP always reports healthy).
+    """
+    return JSONResponse(
+        status_code=status.HTTP_200_OK,
+        content={"status": "healthy"},
+    )
diff --git a/ai-service/app/core/__init__.py b/ai-service/app/core/__init__.py
new file mode 100644
index 0000000..dee8983
--- /dev/null
+++ b/ai-service/app/core/__init__.py
@@ -0,0 +1,19 @@
+"""
+Core module - Configuration, dependencies, and utilities.
+[AC-AISVC-01, AC-AISVC-10, AC-AISVC-11] Core infrastructure components.
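+
+Typical usage (sketch of the re-exports):
+
+    from app.core import get_settings
+
+    settings = get_settings()  # cached Settings instance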
+"""
+
+from app.core.config import Settings, get_settings
+from app.core.database import async_session_maker, get_session, init_db, close_db
+from app.core.qdrant_client import QdrantClient, get_qdrant_client
+
+__all__ = [
+    "Settings",
+    "get_settings",
+    "async_session_maker",
+    "get_session",
+    "init_db",
+    "close_db",
+    "QdrantClient",
+    "get_qdrant_client",
+]
diff --git a/ai-service/app/core/config.py b/ai-service/app/core/config.py
new file mode 100644
index 0000000..913df53
--- /dev/null
+++ b/ai-service/app/core/config.py
@@ -0,0 +1,54 @@
+"""
+Configuration management for AI Service.
+[AC-AISVC-01] Centralized configuration with environment variable support.
+"""
+
+from functools import lru_cache
+
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+class Settings(BaseSettings):
+    model_config = SettingsConfigDict(env_prefix="AI_SERVICE_", env_file=".env", extra="ignore")
+
+    app_name: str = "AI Service"
+    app_version: str = "0.1.0"
+    debug: bool = False
+
+    host: str = "0.0.0.0"
+    port: int = 8080
+
+    request_timeout_seconds: int = 20
+    sse_ping_interval_seconds: int = 15
+
+    log_level: str = "INFO"
+
+    llm_provider: str = "openai"
+    llm_api_key: str = ""
+    llm_base_url: str = "https://api.openai.com/v1"
+    llm_model: str = "gpt-4o-mini"
+    llm_max_tokens: int = 2048
+    llm_temperature: float = 0.7
+    llm_timeout_seconds: int = 30
+    llm_max_retries: int = 3
+
+    database_url: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/ai_service"
+    database_pool_size: int = 10
+    database_max_overflow: int = 20
+
+    qdrant_url: str = "http://localhost:6333"
+    qdrant_collection_prefix: str = "kb_"
+    qdrant_vector_size: int = 1536
+
+    rag_top_k: int = 5
+    rag_score_threshold: float = 0.7
+    rag_min_hits: int = 1
+    rag_max_evidence_tokens: int = 2000
+
+    confidence_threshold_low: float = 0.5
+    max_history_tokens: int = 4000
+
+
+@lru_cache
+def get_settings() -> Settings:
+    return Settings()
diff --git a/ai-service/app/core/database.py b/ai-service/app/core/database.py
new file mode 100644
index 0000000..b15cde1
--- /dev/null
+++ b/ai-service/app/core/database.py
@@ -0,0 +1,67 @@
+"""
+Database client for AI Service.
+[AC-AISVC-11] PostgreSQL database with SQLModel for multi-tenant data isolation.
+"""
+
+import logging
+from typing import AsyncGenerator
+
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
+from sqlmodel import SQLModel
+
+from app.core.config import get_settings
+
+logger = logging.getLogger(__name__)
+
+settings = get_settings()
+
+engine = create_async_engine(
+    settings.database_url,
+    pool_size=settings.database_pool_size,
+    max_overflow=settings.database_max_overflow,
+    echo=settings.debug,
+    pool_pre_ping=True,
+)
+
+async_session_maker = async_sessionmaker(
+    engine,
+    class_=AsyncSession,
+    expire_on_commit=False,
+    autocommit=False,
+    autoflush=False,
+)
+
+
+async def init_db() -> None:
+    """
+    [AC-AISVC-11] Initialize database tables.
+    Creates all tables defined in SQLModel metadata.
+    """
+    async with engine.begin() as conn:
+        await conn.run_sync(SQLModel.metadata.create_all)
+    logger.info("[AC-AISVC-11] Database tables initialized")
+
+
+async def close_db() -> None:
+    """
+    Close database connections.
+    """
+    await engine.dispose()
+    logger.info("Database connections closed")
+
+
+async def get_session() -> AsyncGenerator[AsyncSession, None]:
+    """
+    [AC-AISVC-11] Dependency injection for database session.
+    Ensures proper session lifecycle management.
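+
+    Example of wiring into a route (sketch; ``Depends`` from fastapi assumed):
+
+        @router.get("/items")
+        async def list_items(session: AsyncSession = Depends(get_session)):
+            ...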
+ """ + async with async_session_maker() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() diff --git a/ai-service/app/core/exceptions.py b/ai-service/app/core/exceptions.py new file mode 100644 index 0000000..dfd3262 --- /dev/null +++ b/ai-service/app/core/exceptions.py @@ -0,0 +1,99 @@ +""" +Exception handling for AI Service. +[AC-AISVC-03, AC-AISVC-04, AC-AISVC-05] Structured error responses. +""" + +from fastapi import HTTPException, Request, status +from fastapi.responses import JSONResponse + +from app.models import ErrorCode, ErrorResponse + + +class AIServiceException(Exception): + def __init__( + self, + code: ErrorCode, + message: str, + status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR, + details: list[dict] | None = None, + ): + self.code = code + self.message = message + self.status_code = status_code + self.details = details + super().__init__(message) + + +class MissingTenantIdException(AIServiceException): + def __init__(self, message: str = "Missing required header: X-Tenant-Id"): + super().__init__( + code=ErrorCode.MISSING_TENANT_ID, + message=message, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + +class InvalidRequestException(AIServiceException): + def __init__(self, message: str, details: list[dict] | None = None): + super().__init__( + code=ErrorCode.INVALID_REQUEST, + message=message, + status_code=status.HTTP_400_BAD_REQUEST, + details=details, + ) + + +class ServiceUnavailableException(AIServiceException): + def __init__(self, message: str = "Service temporarily unavailable"): + super().__init__( + code=ErrorCode.SERVICE_UNAVAILABLE, + message=message, + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + ) + + +class TimeoutException(AIServiceException): + def __init__(self, message: str = "Request timeout"): + super().__init__( + code=ErrorCode.TIMEOUT, + message=message, + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + ) + + +async def ai_service_exception_handler(request: Request, exc: AIServiceException) -> JSONResponse: + return JSONResponse( + status_code=exc.status_code, + content=ErrorResponse( + code=exc.code.value, + message=exc.message, + details=exc.details, + ).model_dump(exclude_none=True), + ) + + +async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse: + if exc.status_code == status.HTTP_400_BAD_REQUEST: + code = ErrorCode.INVALID_REQUEST + elif exc.status_code == status.HTTP_503_SERVICE_UNAVAILABLE: + code = ErrorCode.SERVICE_UNAVAILABLE + else: + code = ErrorCode.INTERNAL_ERROR + + return JSONResponse( + status_code=exc.status_code, + content=ErrorResponse( + code=code.value, + message=exc.detail or "An error occurred", + ).model_dump(exclude_none=True), + ) + + +async def generic_exception_handler(request: Request, exc: Exception) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=ErrorResponse( + code=ErrorCode.INTERNAL_ERROR.value, + message="An unexpected error occurred", + ).model_dump(exclude_none=True), + ) diff --git a/ai-service/app/core/middleware.py b/ai-service/app/core/middleware.py new file mode 100644 index 0000000..8e87288 --- /dev/null +++ b/ai-service/app/core/middleware.py @@ -0,0 +1,74 @@ +""" +Middleware for AI Service. +[AC-AISVC-10, AC-AISVC-12] X-Tenant-Id header validation and tenant context injection. 
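+
+Requests without a non-empty X-Tenant-Id are rejected (the health check is
+exempt), e.g. a 400 response with body:
+
+    {"code": "MISSING_TENANT_ID", "message": "Missing required header: X-Tenant-Id"}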
+"""
+
+import logging
+from typing import Callable
+
+from fastapi import Request, Response, status
+from fastapi.responses import JSONResponse
+from starlette.middleware.base import BaseHTTPMiddleware
+
+from app.core.exceptions import ErrorCode, ErrorResponse
+from app.core.tenant import clear_tenant_context, set_tenant_context
+
+logger = logging.getLogger(__name__)
+
+TENANT_ID_HEADER = "X-Tenant-Id"
+ACCEPT_HEADER = "Accept"
+SSE_CONTENT_TYPE = "text/event-stream"
+
+
+class TenantContextMiddleware(BaseHTTPMiddleware):
+    """
+    [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header.
+    Injects tenant context into request state for downstream processing.
+    """
+
+    async def dispatch(self, request: Request, call_next: Callable) -> Response:
+        clear_tenant_context()
+
+        if request.url.path == "/ai/health":
+            return await call_next(request)
+
+        tenant_id = (request.headers.get(TENANT_ID_HEADER) or "").strip()
+
+        if not tenant_id:
+            logger.warning("[AC-AISVC-12] Missing or empty X-Tenant-Id header")
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content=ErrorResponse(
+                    code=ErrorCode.MISSING_TENANT_ID.value,
+                    message="Missing required header: X-Tenant-Id",
+                ).model_dump(exclude_none=True),
+            )
+
+        set_tenant_context(tenant_id)
+        request.state.tenant_id = tenant_id
+
+        logger.info(f"[AC-AISVC-10] Tenant context set: tenant_id={tenant_id}")
+
+        try:
+            response = await call_next(request)
+        finally:
+            clear_tenant_context()
+
+        return response
+
+
+def is_sse_request(request: Request) -> bool:
+    """
+    [AC-AISVC-06] Check if the request expects SSE streaming response.
+    Based on Accept header: text/event-stream indicates SSE mode.
+    """
+    accept_header = request.headers.get(ACCEPT_HEADER, "")
+    return SSE_CONTENT_TYPE in accept_header
+
+
+def get_response_mode(request: Request) -> str:
+    """
+    [AC-AISVC-06] Determine response mode based on Accept header.
+    Returns 'streaming' for SSE, 'json' for regular JSON response.
+    """
+    return "streaming" if is_sse_request(request) else "json"
diff --git a/ai-service/app/core/qdrant_client.py b/ai-service/app/core/qdrant_client.py
new file mode 100644
index 0000000..1f8f21f
--- /dev/null
+++ b/ai-service/app/core/qdrant_client.py
@@ -0,0 +1,175 @@
+"""
+Qdrant client for AI Service.
+[AC-AISVC-10] Vector database client with tenant-isolated collection management.
+"""
+
+import logging
+from typing import Any
+
+from qdrant_client import AsyncQdrantClient
+from qdrant_client.models import Distance, PointStruct, VectorParams
+
+from app.core.config import get_settings
+
+logger = logging.getLogger(__name__)
+
+settings = get_settings()
+
+
+class QdrantClient:
+    """
+    [AC-AISVC-10] Qdrant client with tenant-isolated collection management.
+    Collection naming: kb_{tenantId} for tenant isolation.
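+
+    Example: with the default prefix "kb_", tenant "acme" maps to collection "kb_acme".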
+ """ + + def __init__(self): + self._client: AsyncQdrantClient | None = None + self._collection_prefix = settings.qdrant_collection_prefix + self._vector_size = settings.qdrant_vector_size + + async def get_client(self) -> AsyncQdrantClient: + """Get or create Qdrant client instance.""" + if self._client is None: + self._client = AsyncQdrantClient(url=settings.qdrant_url) + logger.info(f"[AC-AISVC-10] Qdrant client initialized: {settings.qdrant_url}") + return self._client + + async def close(self) -> None: + """Close Qdrant client connection.""" + if self._client: + await self._client.close() + self._client = None + logger.info("Qdrant client connection closed") + + def get_collection_name(self, tenant_id: str) -> str: + """ + [AC-AISVC-10] Get collection name for a tenant. + Naming convention: kb_{tenantId} + """ + return f"{self._collection_prefix}{tenant_id}" + + async def ensure_collection_exists(self, tenant_id: str) -> bool: + """ + [AC-AISVC-10] Ensure collection exists for tenant. + Note: MVP uses pre-provisioned collections, this is for development/testing. + """ + client = await self.get_client() + collection_name = self.get_collection_name(tenant_id) + + try: + collections = await client.get_collections() + exists = any(c.name == collection_name for c in collections.collections) + + if not exists: + await client.create_collection( + collection_name=collection_name, + vectors_config=VectorParams( + size=self._vector_size, + distance=Distance.COSINE, + ), + ) + logger.info( + f"[AC-AISVC-10] Created collection: {collection_name} for tenant={tenant_id}" + ) + return True + except Exception as e: + logger.error(f"[AC-AISVC-10] Error ensuring collection: {e}") + return False + + async def upsert_vectors( + self, + tenant_id: str, + points: list[PointStruct], + ) -> bool: + """ + [AC-AISVC-10] Upsert vectors into tenant's collection. + """ + client = await self.get_client() + collection_name = self.get_collection_name(tenant_id) + + try: + await client.upsert( + collection_name=collection_name, + points=points, + ) + logger.info( + f"[AC-AISVC-10] Upserted {len(points)} vectors for tenant={tenant_id}" + ) + return True + except Exception as e: + logger.error(f"[AC-AISVC-10] Error upserting vectors: {e}") + return False + + async def search( + self, + tenant_id: str, + query_vector: list[float], + limit: int = 5, + score_threshold: float | None = None, + ) -> list[dict[str, Any]]: + """ + [AC-AISVC-10] Search vectors in tenant's collection. + Returns results with score >= score_threshold if specified. + """ + client = await self.get_client() + collection_name = self.get_collection_name(tenant_id) + + try: + results = await client.search( + collection_name=collection_name, + query_vector=query_vector, + limit=limit, + score_threshold=score_threshold, + ) + + hits = [ + { + "id": str(result.id), + "score": result.score, + "payload": result.payload or {}, + } + for result in results + ] + + logger.info( + f"[AC-AISVC-10] Search returned {len(hits)} results for tenant={tenant_id}" + ) + return hits + except Exception as e: + logger.error(f"[AC-AISVC-10] Error searching vectors: {e}") + return [] + + async def delete_collection(self, tenant_id: str) -> bool: + """ + [AC-AISVC-10] Delete tenant's collection. + Used when tenant is removed. 
+        """
+        client = await self.get_client()
+        collection_name = self.get_collection_name(tenant_id)
+
+        try:
+            await client.delete_collection(collection_name=collection_name)
+            logger.info(f"[AC-AISVC-10] Deleted collection: {collection_name}")
+            return True
+        except Exception as e:
+            logger.error(f"[AC-AISVC-10] Error deleting collection: {e}")
+            return False
+
+
+_qdrant_client: QdrantClient | None = None
+
+
+async def get_qdrant_client() -> QdrantClient:
+    """Get or create Qdrant client instance."""
+    global _qdrant_client
+    if _qdrant_client is None:
+        _qdrant_client = QdrantClient()
+    return _qdrant_client
+
+
+async def close_qdrant_client() -> None:
+    """Close Qdrant client connection."""
+    global _qdrant_client
+    if _qdrant_client:
+        await _qdrant_client.close()
+        _qdrant_client = None
diff --git a/ai-service/app/core/sse.py b/ai-service/app/core/sse.py
new file mode 100644
index 0000000..91d892f
--- /dev/null
+++ b/ai-service/app/core/sse.py
@@ -0,0 +1,170 @@
+"""
+SSE utilities for AI Service.
+[AC-AISVC-06, AC-AISVC-07, AC-AISVC-08, AC-AISVC-09] SSE event generation and state machine.
+"""
+
+import asyncio
+import json
+import logging
+from enum import Enum
+from typing import Any, AsyncGenerator
+
+from sse_starlette.sse import EventSourceResponse, ServerSentEvent
+
+from app.core.config import get_settings
+from app.models import SSEErrorEvent, SSEEventType, SSEFinalEvent, SSEMessageEvent
+
+logger = logging.getLogger(__name__)
+
+
+class SSEState(str, Enum):
+    INIT = "INIT"
+    STREAMING = "STREAMING"
+    FINAL_SENT = "FINAL_SENT"
+    ERROR_SENT = "ERROR_SENT"
+    CLOSED = "CLOSED"
+
+
+class SSEStateMachine:
+    """
+    [AC-AISVC-08, AC-AISVC-09] SSE state machine ensuring proper event sequence.
+    State transitions: INIT -> STREAMING -> FINAL_SENT/ERROR_SENT -> CLOSED
+    """
+
+    def __init__(self):
+        self._state = SSEState.INIT
+        self._lock = asyncio.Lock()
+
+    @property
+    def state(self) -> SSEState:
+        return self._state
+
+    async def transition_to_streaming(self) -> bool:
+        async with self._lock:
+            if self._state == SSEState.INIT:
+                self._state = SSEState.STREAMING
+                logger.debug("[AC-AISVC-07] SSE state transition: INIT -> STREAMING")
+                return True
+            return False
+
+    async def transition_to_final(self) -> bool:
+        async with self._lock:
+            if self._state == SSEState.STREAMING:
+                self._state = SSEState.FINAL_SENT
+                logger.debug("[AC-AISVC-08] SSE state transition: STREAMING -> FINAL_SENT")
+                return True
+            return False
+
+    async def transition_to_error(self) -> bool:
+        async with self._lock:
+            if self._state in (SSEState.INIT, SSEState.STREAMING):
+                # Capture the state before mutating it so the log shows the real transition.
+                previous_state = self._state
+                self._state = SSEState.ERROR_SENT
+                logger.debug(f"[AC-AISVC-09] SSE state transition: {previous_state} -> ERROR_SENT")
+                return True
+            return False
+
+    async def close(self) -> None:
+        async with self._lock:
+            self._state = SSEState.CLOSED
+            logger.debug("SSE state transition: -> CLOSED")
+
+    def can_send_message(self) -> bool:
+        return self._state == SSEState.STREAMING
+
+
+def format_sse_event(event_type: SSEEventType, data: dict[str, Any]) -> ServerSentEvent:
+    """Format data as SSE event."""
+    return ServerSentEvent(
+        event=event_type.value,
+        data=json.dumps(data, ensure_ascii=False),
+    )
+
+
+def create_message_event(delta: str) -> ServerSentEvent:
+    """[AC-AISVC-07] Create a message event with incremental content."""
+    event_data = SSEMessageEvent(delta=delta)
+    return format_sse_event(SSEEventType.MESSAGE, event_data.model_dump())
+
+
+def create_final_event(
+    reply: str,
+    confidence: float,
+    should_transfer: bool,
+    transfer_reason: str | None = None,
+    metadata: dict[str, Any] | None = None,
+) -> ServerSentEvent:
+    """[AC-AISVC-08] Create a final event with complete response."""
+    event_data = SSEFinalEvent(
+        reply=reply,
+        confidence=confidence,
+        should_transfer=should_transfer,
+        transfer_reason=transfer_reason,
+        metadata=metadata,
+    )
+    # Serialize with by_alias so field names match the camelCase wire contract
+    # (shouldTransfer, transferReason), consistent with the JSON response path.
+    return format_sse_event(SSEEventType.FINAL, event_data.model_dump(by_alias=True, exclude_none=True))
+
+
+def create_error_event(
+    code: str,
+    message: str,
+    details: list[dict[str, Any]] | None = None,
+) -> ServerSentEvent:
+    """[AC-AISVC-09] Create an error event."""
+    event_data = SSEErrorEvent(
+        code=code,
+        message=message,
+        details=details,
+    )
+    return format_sse_event(SSEEventType.ERROR, event_data.model_dump(exclude_none=True))
+
+
+async def ping_generator(interval_seconds: int) -> AsyncGenerator[str, None]:
+    """
+    [AC-AISVC-06] Generate ping comments for SSE keep-alive.
+    Sends ': ping' as comment lines (not events) to keep connection alive.
+    """
+    while True:
+        await asyncio.sleep(interval_seconds)
+        yield ": ping\n\n"
+
+
+class SSEResponseBuilder:
+    """
+    Builder for SSE response with proper event sequencing and ping keep-alive.
+    """
+
+    def __init__(self):
+        self._state_machine = SSEStateMachine()
+        self._settings = get_settings()
+
+    async def build_response(
+        self,
+        content_generator: AsyncGenerator[ServerSentEvent, None],
+    ) -> EventSourceResponse:
+        """
+        Build SSE response with ping keep-alive mechanism.
+        [AC-AISVC-06] Implements ping keep-alive to prevent connection timeout.
+        """
+
+        async def event_generator() -> AsyncGenerator[ServerSentEvent, None]:
+            await self._state_machine.transition_to_streaming()
+            try:
+                async for event in content_generator:
+                    if self._state_machine.can_send_message():
+                        yield event
+                    else:
+                        break
+            except Exception as e:
+                logger.error(f"[AC-AISVC-09] Error during SSE streaming: {e}")
+                if await self._state_machine.transition_to_error():
+                    yield create_error_event(
+                        code="STREAMING_ERROR",
+                        message=str(e),
+                    )
+            finally:
+                await self._state_machine.close()
+
+        return EventSourceResponse(
+            event_generator(),
+            ping=self._settings.sse_ping_interval_seconds,
+        )
diff --git a/ai-service/app/core/tenant.py b/ai-service/app/core/tenant.py
new file mode 100644
index 0000000..8baf4a6
--- /dev/null
+++ b/ai-service/app/core/tenant.py
@@ -0,0 +1,31 @@
+"""
+Tenant context management.
+[AC-AISVC-10, AC-AISVC-12] Multi-tenant isolation via X-Tenant-Id header.
+"""
+
+from contextvars import ContextVar
+from dataclasses import dataclass
+
+tenant_context: ContextVar["TenantContext | None"] = ContextVar("tenant_context", default=None)
+
+
+@dataclass
+class TenantContext:
+    tenant_id: str
+
+
+def set_tenant_context(tenant_id: str) -> None:
+    tenant_context.set(TenantContext(tenant_id=tenant_id))
+
+
+def get_tenant_context() -> TenantContext | None:
+    return tenant_context.get()
+
+
+def get_tenant_id() -> str | None:
+    ctx = get_tenant_context()
+    return ctx.tenant_id if ctx else None
+
+
+def clear_tenant_context() -> None:
+    tenant_context.set(None)
diff --git a/ai-service/app/main.py b/ai-service/app/main.py
new file mode 100644
index 0000000..5746604
--- /dev/null
+++ b/ai-service/app/main.py
@@ -0,0 +1,123 @@
+"""
+Main FastAPI application for AI Service.
+[AC-AISVC-01] Entry point with middleware and exception handlers.
+""" + +import logging +from contextlib import asynccontextmanager + +from fastapi import FastAPI, Request, status +from fastapi.exceptions import HTTPException, RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from app.api import chat_router, health_router +from app.core.config import get_settings +from app.core.database import close_db, init_db +from app.core.exceptions import ( + AIServiceException, + ErrorCode, + ErrorResponse, + ai_service_exception_handler, + generic_exception_handler, + http_exception_handler, +) +from app.core.middleware import TenantContextMiddleware +from app.core.qdrant_client import close_qdrant_client + +settings = get_settings() + +logging.basicConfig( + level=getattr(logging, settings.log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) + +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + [AC-AISVC-01, AC-AISVC-11] Application lifespan manager. + Handles startup and shutdown of database and external connections. + """ + logger.info(f"[AC-AISVC-01] Starting {settings.app_name} v{settings.app_version}") + + try: + await init_db() + logger.info("[AC-AISVC-11] Database initialized successfully") + except Exception as e: + logger.warning(f"[AC-AISVC-11] Database initialization skipped: {e}") + + yield + + await close_db() + await close_qdrant_client() + logger.info(f"Shutting down {settings.app_name}") + + +app = FastAPI( + title=settings.app_name, + version=settings.app_version, + description=""" + Python AI Service for intelligent chat with RAG support. + + ## Features + - Multi-tenant isolation via X-Tenant-Id header + - SSE streaming support via Accept: text/event-stream + - RAG-powered responses with confidence scoring + + ## Response Modes + - **JSON**: Default response mode (Accept: application/json or no Accept header) + - **SSE Streaming**: Set Accept: text/event-stream for streaming responses + """, + docs_url="/docs", + redoc_url="/redoc", + lifespan=lifespan, +) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.add_middleware(TenantContextMiddleware) + +app.add_exception_handler(AIServiceException, ai_service_exception_handler) +app.add_exception_handler(HTTPException, http_exception_handler) +app.add_exception_handler(Exception, generic_exception_handler) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """ + [AC-AISVC-03] Handle request validation errors with structured response. 
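+
+    Example response body (shape only; values depend on the failed input):
+
+        {"code": "INVALID_REQUEST", "message": "Request validation failed",
+         "details": [{"loc": ["body", "sessionId"], "msg": "Field required", "type": "missing"}]}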
+ """ + logger.warning(f"[AC-AISVC-03] Request validation error: {exc.errors()}") + error_response = ErrorResponse( + code=ErrorCode.INVALID_REQUEST.value, + message="Request validation failed", + details=[{"loc": list(err["loc"]), "msg": err["msg"], "type": err["type"]} for err in exc.errors()], + ) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=error_response.model_dump(exclude_none=True), + ) + + +app.include_router(health_router) +app.include_router(chat_router) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run( + "app.main:app", + host=settings.host, + port=settings.port, + reload=settings.debug, + ) diff --git a/ai-service/app/models/__init__.py b/ai-service/app/models/__init__.py new file mode 100644 index 0000000..a97360e --- /dev/null +++ b/ai-service/app/models/__init__.py @@ -0,0 +1,88 @@ +""" +Data models for AI Service. +[AC-AISVC-02] Request/Response models aligned with OpenAPI contract. +[AC-AISVC-13] Entity models for database persistence. +""" + +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field + + +class ChannelType(str, Enum): + WECHAT = "wechat" + DOUYIN = "douyin" + JD = "jd" + + +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + + +class ChatMessage(BaseModel): + role: Role = Field(..., description="Message role: user or assistant") + content: str = Field(..., description="Message content") + + +class ChatRequest(BaseModel): + session_id: str = Field(..., alias="sessionId", description="Session ID for conversation tracking") + current_message: str = Field(..., alias="currentMessage", description="Current user message") + channel_type: ChannelType = Field(..., alias="channelType", description="Channel type: wechat, douyin, jd") + history: list[ChatMessage] | None = Field(default=None, description="Optional conversation history") + metadata: dict[str, Any] | None = Field(default=None, description="Optional metadata") + + model_config = {"populate_by_name": True} + + +class ChatResponse(BaseModel): + reply: str = Field(..., description="AI generated reply content") + confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score between 0.0 and 1.0") + should_transfer: bool = Field(..., alias="shouldTransfer", description="Whether to suggest transfer to human agent") + transfer_reason: str | None = Field(default=None, alias="transferReason", description="Reason for transfer suggestion") + metadata: dict[str, Any] | None = Field(default=None, description="Response metadata") + + model_config = {"populate_by_name": True} + + +class ErrorCode(str, Enum): + INVALID_REQUEST = "INVALID_REQUEST" + MISSING_TENANT_ID = "MISSING_TENANT_ID" + INTERNAL_ERROR = "INTERNAL_ERROR" + SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE" + TIMEOUT = "TIMEOUT" + LLM_ERROR = "LLM_ERROR" + RETRIEVAL_ERROR = "RETRIEVAL_ERROR" + + +class ErrorResponse(BaseModel): + code: str = Field(..., description="Error code") + message: str = Field(..., description="Error message") + details: list[dict[str, Any]] | None = Field(default=None, description="Detailed error information") + + +class SSEEventType(str, Enum): + MESSAGE = "message" + FINAL = "final" + ERROR = "error" + + +class SSEMessageEvent(BaseModel): + delta: str = Field(..., description="Incremental text content") + + +class SSEFinalEvent(BaseModel): + reply: str = Field(..., description="Complete AI reply") + confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score") + should_transfer: bool = Field(..., 
alias="shouldTransfer", description="Transfer suggestion") + transfer_reason: str | None = Field(default=None, alias="transferReason", description="Transfer reason") + metadata: dict[str, Any] | None = Field(default=None, description="Response metadata") + + model_config = {"populate_by_name": True} + + +class SSEErrorEvent(BaseModel): + code: str = Field(..., description="Error code") + message: str = Field(..., description="Error message") + details: list[dict[str, Any]] | None = Field(default=None, description="Error details") diff --git a/ai-service/app/models/entities.py b/ai-service/app/models/entities.py new file mode 100644 index 0000000..df329b5 --- /dev/null +++ b/ai-service/app/models/entities.py @@ -0,0 +1,74 @@ +""" +Memory layer entities for AI Service. +[AC-AISVC-13] SQLModel entities for chat sessions and messages with tenant isolation. +""" + +import uuid +from datetime import datetime +from typing import Any + +from sqlalchemy import Column, JSON +from sqlmodel import Field, Index, SQLModel + + +class ChatSession(SQLModel, table=True): + """ + [AC-AISVC-13] Chat session entity with tenant isolation. + Primary key: (tenant_id, session_id) composite unique constraint. + """ + + __tablename__ = "chat_sessions" + __table_args__ = ( + Index("ix_chat_sessions_tenant_session", "tenant_id", "session_id", unique=True), + Index("ix_chat_sessions_tenant_id", "tenant_id"), + ) + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True) + session_id: str = Field(..., description="Session ID for conversation tracking") + channel_type: str | None = Field(default=None, description="Channel type: wechat, douyin, jd") + metadata_: dict[str, Any] | None = Field( + default=None, + sa_column=Column("metadata", JSON, nullable=True), + description="Session metadata" + ) + created_at: datetime = Field(default_factory=datetime.utcnow, description="Session creation time") + updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time") + + +class ChatMessage(SQLModel, table=True): + """ + [AC-AISVC-13] Chat message entity with tenant isolation. + Messages are scoped by (tenant_id, session_id) for multi-tenant security. 
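+
+    Example tenant-scoped query (sketch, as used by MemoryService):
+
+        select(ChatMessage).where(
+            ChatMessage.tenant_id == tenant_id,
+            ChatMessage.session_id == session_id,
+        )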
+ """ + + __tablename__ = "chat_messages" + __table_args__ = ( + Index("ix_chat_messages_tenant_session", "tenant_id", "session_id"), + Index("ix_chat_messages_tenant_session_created", "tenant_id", "session_id", "created_at"), + ) + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True) + session_id: str = Field(..., description="Session ID for conversation tracking", index=True) + role: str = Field(..., description="Message role: user or assistant") + content: str = Field(..., description="Message content") + created_at: datetime = Field(default_factory=datetime.utcnow, description="Message creation time") + + +class ChatSessionCreate(SQLModel): + """Schema for creating a new chat session.""" + + tenant_id: str + session_id: str + channel_type: str | None = None + metadata_: dict[str, Any] | None = None + + +class ChatMessageCreate(SQLModel): + """Schema for creating a new chat message.""" + + tenant_id: str + session_id: str + role: str + content: str diff --git a/ai-service/app/services/__init__.py b/ai-service/app/services/__init__.py new file mode 100644 index 0000000..22c50c7 --- /dev/null +++ b/ai-service/app/services/__init__.py @@ -0,0 +1,9 @@ +""" +Services module for AI Service. +[AC-AISVC-13, AC-AISVC-16] Core services for memory and retrieval. +""" + +from app.services.memory import MemoryService +from app.services.orchestrator import OrchestratorService, get_orchestrator_service + +__all__ = ["MemoryService", "OrchestratorService", "get_orchestrator_service"] diff --git a/ai-service/app/services/llm/__init__.py b/ai-service/app/services/llm/__init__.py new file mode 100644 index 0000000..f616fac --- /dev/null +++ b/ai-service/app/services/llm/__init__.py @@ -0,0 +1,15 @@ +""" +LLM Adapter module for AI Service. +[AC-AISVC-02, AC-AISVC-06] Provides unified interface for LLM providers. +""" + +from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk +from app.services.llm.openai_client import OpenAIClient + +__all__ = [ + "LLMClient", + "LLMConfig", + "LLMResponse", + "LLMStreamChunk", + "OpenAIClient", +] diff --git a/ai-service/app/services/llm/base.py b/ai-service/app/services/llm/base.py new file mode 100644 index 0000000..cf46d3c --- /dev/null +++ b/ai-service/app/services/llm/base.py @@ -0,0 +1,115 @@ +""" +Base LLM client interface. +[AC-AISVC-02, AC-AISVC-06] Abstract interface for LLM providers. + +Design reference: design.md Section 8.1 - LLMClient interface +- generate(prompt, params) -> text +- stream_generate(prompt, params) -> iterator[delta] +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, AsyncGenerator + + +@dataclass +class LLMConfig: + """ + Configuration for LLM client. + [AC-AISVC-02] Supports configurable model parameters. + """ + model: str = "gpt-4o-mini" + max_tokens: int = 2048 + temperature: float = 0.7 + top_p: float = 1.0 + timeout_seconds: int = 30 + max_retries: int = 3 + extra_params: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class LLMResponse: + """ + Response from LLM generation. + [AC-AISVC-02] Contains generated content and metadata. + """ + content: str + model: str + usage: dict[str, int] = field(default_factory=dict) + finish_reason: str = "stop" + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class LLMStreamChunk: + """ + Streaming chunk from LLM. 
+ [AC-AISVC-06, AC-AISVC-07] Incremental output for SSE streaming. + """ + delta: str + model: str + finish_reason: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +class LLMClient(ABC): + """ + Abstract base class for LLM clients. + [AC-AISVC-02, AC-AISVC-06] Provides unified interface for different LLM providers. + + Design reference: design.md Section 8.2 - Plugin points + - OpenAICompatibleClient / LocalModelClient can be swapped + """ + + @abstractmethod + async def generate( + self, + messages: list[dict[str, str]], + config: LLMConfig | None = None, + **kwargs: Any, + ) -> LLMResponse: + """ + Generate a non-streaming response. + [AC-AISVC-02] Returns complete response for ChatResponse. + + Args: + messages: List of chat messages with 'role' and 'content'. + config: Optional LLM configuration overrides. + **kwargs: Additional provider-specific parameters. + + Returns: + LLMResponse with generated content and metadata. + + Raises: + LLMException: If generation fails. + """ + pass + + @abstractmethod + async def stream_generate( + self, + messages: list[dict[str, str]], + config: LLMConfig | None = None, + **kwargs: Any, + ) -> AsyncGenerator[LLMStreamChunk, None]: + """ + Generate a streaming response. + [AC-AISVC-06, AC-AISVC-07] Yields incremental chunks for SSE. + + Args: + messages: List of chat messages with 'role' and 'content'. + config: Optional LLM configuration overrides. + **kwargs: Additional provider-specific parameters. + + Yields: + LLMStreamChunk with incremental content. + + Raises: + LLMException: If generation fails. + """ + pass + + @abstractmethod + async def close(self) -> None: + """Close the client and release resources.""" + pass diff --git a/ai-service/app/services/llm/openai_client.py b/ai-service/app/services/llm/openai_client.py new file mode 100644 index 0000000..882ef93 --- /dev/null +++ b/ai-service/app/services/llm/openai_client.py @@ -0,0 +1,319 @@ +""" +OpenAI-compatible LLM client implementation. +[AC-AISVC-02, AC-AISVC-06] Concrete implementation using httpx for OpenAI API. + +Design reference: design.md Section 8.1 - LLMClient interface +- Uses langchain-openai or official SDK pattern +- Supports generate and stream_generate +""" + +import json +import logging +from typing import Any, AsyncGenerator + +import httpx +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from app.core.config import get_settings +from app.core.exceptions import AIServiceException, ErrorCode, ServiceUnavailableException, TimeoutException +from app.services.llm.base import LLMClient, LLMConfig, LLMResponse, LLMStreamChunk + +logger = logging.getLogger(__name__) + + +class LLMException(AIServiceException): + """Exception raised when LLM operations fail.""" + + def __init__(self, message: str, details: list[dict] | None = None): + super().__init__( + code=ErrorCode.LLM_ERROR, + message=message, + status_code=503, + details=details, + ) + + +class OpenAIClient(LLMClient): + """ + OpenAI-compatible LLM client. + [AC-AISVC-02, AC-AISVC-06] Implements LLMClient interface for OpenAI API. + + Supports: + - OpenAI API (official) + - OpenAI-compatible endpoints (Azure, local models, etc.) 
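+
+    Example (hypothetical local endpoint; any OpenAI-compatible /v1 server works):
+
+        client = OpenAIClient(base_url="http://localhost:8000/v1", model="my-local-model")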
+    """
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        base_url: str | None = None,
+        model: str | None = None,
+        default_config: LLMConfig | None = None,
+    ):
+        settings = get_settings()
+        self._api_key = api_key or settings.llm_api_key
+        self._base_url = (base_url or settings.llm_base_url).rstrip("/")
+        self._model = model or settings.llm_model
+        self._default_config = default_config or LLMConfig(
+            model=self._model,
+            max_tokens=settings.llm_max_tokens,
+            temperature=settings.llm_temperature,
+            timeout_seconds=settings.llm_timeout_seconds,
+            max_retries=settings.llm_max_retries,
+        )
+        self._client: httpx.AsyncClient | None = None
+
+    def _get_client(self, timeout_seconds: int) -> httpx.AsyncClient:
+        """Get or create HTTP client."""
+        if self._client is None:
+            self._client = httpx.AsyncClient(
+                timeout=httpx.Timeout(timeout_seconds),
+                headers={
+                    "Authorization": f"Bearer {self._api_key}",
+                    "Content-Type": "application/json",
+                },
+            )
+        return self._client
+
+    def _build_request_body(
+        self,
+        messages: list[dict[str, str]],
+        config: LLMConfig,
+        stream: bool = False,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """Build request body for OpenAI API."""
+        body: dict[str, Any] = {
+            "model": config.model,
+            "messages": messages,
+            "max_tokens": config.max_tokens,
+            "temperature": config.temperature,
+            "top_p": config.top_p,
+            "stream": stream,
+        }
+        body.update(config.extra_params)
+        body.update(kwargs)
+        return body
+
+    # Retry on both the raw httpx timeout and the converted TimeoutException,
+    # since the except block below re-raises timeouts as TimeoutException and
+    # the predicate would otherwise never match.
+    @retry(
+        retry=retry_if_exception_type((httpx.TimeoutException, TimeoutException)),
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=1, max=10),
+    )
+    async def generate(
+        self,
+        messages: list[dict[str, str]],
+        config: LLMConfig | None = None,
+        **kwargs: Any,
+    ) -> LLMResponse:
+        """
+        Generate a non-streaming response.
+        [AC-AISVC-02] Returns complete response for ChatResponse.
+
+        Args:
+            messages: List of chat messages with 'role' and 'content'.
+            config: Optional LLM configuration overrides.
+            **kwargs: Additional provider-specific parameters.
+
+        Returns:
+            LLMResponse with generated content and metadata.
+
+        Raises:
+            LLMException: If generation fails.
+            TimeoutException: If request times out.
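+
+        Example (sketch):
+
+            resp = await client.generate([{"role": "user", "content": "Hello"}])
+            print(resp.content, resp.usage)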
+ """ + effective_config = config or self._default_config + client = self._get_client(effective_config.timeout_seconds) + + body = self._build_request_body(messages, effective_config, stream=False, **kwargs) + + logger.info(f"[AC-AISVC-02] Generating response with model={effective_config.model}") + + try: + response = await client.post( + f"{self._base_url}/chat/completions", + json=body, + ) + response.raise_for_status() + data = response.json() + + except httpx.TimeoutException as e: + logger.error(f"[AC-AISVC-02] LLM request timeout: {e}") + raise TimeoutException(message=f"LLM request timed out: {e}") + + except httpx.HTTPStatusError as e: + logger.error(f"[AC-AISVC-02] LLM API error: {e}") + error_detail = self._parse_error_response(e.response) + raise LLMException( + message=f"LLM API error: {error_detail}", + details=[{"status_code": e.response.status_code, "response": error_detail}], + ) + + except json.JSONDecodeError as e: + logger.error(f"[AC-AISVC-02] Failed to parse LLM response: {e}") + raise LLMException(message=f"Failed to parse LLM response: {e}") + + try: + choice = data["choices"][0] + content = choice["message"]["content"] + usage = data.get("usage", {}) + finish_reason = choice.get("finish_reason", "stop") + + logger.info( + f"[AC-AISVC-02] Generated response: " + f"tokens={usage.get('total_tokens', 'N/A')}, " + f"finish_reason={finish_reason}" + ) + + return LLMResponse( + content=content, + model=data.get("model", effective_config.model), + usage=usage, + finish_reason=finish_reason, + metadata={"raw_response": data}, + ) + + except (KeyError, IndexError) as e: + logger.error(f"[AC-AISVC-02] Unexpected LLM response format: {e}") + raise LLMException( + message=f"Unexpected LLM response format: {e}", + details=[{"response": str(data)}], + ) + + async def stream_generate( + self, + messages: list[dict[str, str]], + config: LLMConfig | None = None, + **kwargs: Any, + ) -> AsyncGenerator[LLMStreamChunk, None]: + """ + Generate a streaming response. + [AC-AISVC-06, AC-AISVC-07] Yields incremental chunks for SSE. + + Args: + messages: List of chat messages with 'role' and 'content'. + config: Optional LLM configuration overrides. + **kwargs: Additional provider-specific parameters. + + Yields: + LLMStreamChunk with incremental content. + + Raises: + LLMException: If generation fails. + TimeoutException: If request times out. 
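+
+        Example (sketch):
+
+            async for chunk in client.stream_generate(messages):
+                print(chunk.delta, end="", flush=True)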
+ """ + effective_config = config or self._default_config + client = self._get_client(effective_config.timeout_seconds) + + body = self._build_request_body(messages, effective_config, stream=True, **kwargs) + + logger.info(f"[AC-AISVC-06] Starting streaming generation with model={effective_config.model}") + + try: + async with client.stream( + "POST", + f"{self._base_url}/chat/completions", + json=body, + ) as response: + response.raise_for_status() + + async for line in response.aiter_lines(): + if not line or line == "data: [DONE]": + continue + + if line.startswith("data: "): + json_str = line[6:] + try: + chunk_data = json.loads(json_str) + chunk = self._parse_stream_chunk(chunk_data, effective_config.model) + if chunk: + yield chunk + except json.JSONDecodeError as e: + logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}") + continue + + except httpx.TimeoutException as e: + logger.error(f"[AC-AISVC-06] LLM streaming request timeout: {e}") + raise TimeoutException(message=f"LLM streaming request timed out: {e}") + + except httpx.HTTPStatusError as e: + logger.error(f"[AC-AISVC-06] LLM streaming API error: {e}") + error_detail = self._parse_error_response(e.response) + raise LLMException( + message=f"LLM streaming API error: {error_detail}", + details=[{"status_code": e.response.status_code, "response": error_detail}], + ) + + logger.info(f"[AC-AISVC-06] Streaming generation completed") + + def _parse_stream_chunk( + self, + data: dict[str, Any], + model: str, + ) -> LLMStreamChunk | None: + """Parse a streaming chunk from OpenAI API.""" + try: + choices = data.get("choices", []) + if not choices: + return None + + delta = choices[0].get("delta", {}) + content = delta.get("content", "") + finish_reason = choices[0].get("finish_reason") + + if not content and not finish_reason: + return None + + return LLMStreamChunk( + delta=content, + model=data.get("model", model), + finish_reason=finish_reason, + metadata={"raw_chunk": data}, + ) + + except (KeyError, IndexError) as e: + logger.warning(f"[AC-AISVC-06] Failed to parse stream chunk: {e}") + return None + + def _parse_error_response(self, response: httpx.Response) -> str: + """Parse error response from API.""" + try: + data = response.json() + if "error" in data: + error = data["error"] + if isinstance(error, dict): + return error.get("message", str(error)) + return str(error) + return response.text + except Exception: + return response.text + + async def close(self) -> None: + """Close the HTTP client.""" + if self._client: + await self._client.aclose() + self._client = None + + +_llm_client: OpenAIClient | None = None + + +def get_llm_client() -> OpenAIClient: + """Get or create LLM client instance.""" + global _llm_client + if _llm_client is None: + _llm_client = OpenAIClient() + return _llm_client + + +async def close_llm_client() -> None: + """Close the global LLM client.""" + global _llm_client + if _llm_client: + await _llm_client.close() + _llm_client = None diff --git a/ai-service/app/services/memory.py b/ai-service/app/services/memory.py new file mode 100644 index 0000000..5db74f5 --- /dev/null +++ b/ai-service/app/services/memory.py @@ -0,0 +1,170 @@ +""" +Memory service for AI Service. +[AC-AISVC-13] Session-based memory management with tenant isolation. 
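+
+Typical flow (sketch, inside a request-scoped AsyncSession):
+
+    memory = MemoryService(session)
+    await memory.get_or_create_session(tenant_id, session_id, channel_type="wechat")
+    history = await memory.load_history(tenant_id, session_id, limit=50)
+    await memory.append_message(tenant_id, session_id, "user", text)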
+""" + +import logging +from typing import Sequence + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import col + +from app.models.entities import ChatMessage, ChatMessageCreate, ChatSession, ChatSessionCreate + +logger = logging.getLogger(__name__) + + +class MemoryService: + """ + [AC-AISVC-13] Memory service for session-based conversation history. + All operations are scoped by (tenant_id, session_id) for multi-tenant isolation. + """ + + def __init__(self, session: AsyncSession): + self._session = session + + async def get_or_create_session( + self, + tenant_id: str, + session_id: str, + channel_type: str | None = None, + metadata: dict | None = None, + ) -> ChatSession: + """ + [AC-AISVC-13] Get existing session or create a new one. + Ensures tenant isolation by querying with tenant_id. + """ + stmt = select(ChatSession).where( + ChatSession.tenant_id == tenant_id, + ChatSession.session_id == session_id, + ) + result = await self._session.execute(stmt) + existing_session = result.scalar_one_or_none() + + if existing_session: + logger.info( + f"[AC-AISVC-13] Found existing session: tenant={tenant_id}, session={session_id}" + ) + return existing_session + + new_session = ChatSession( + tenant_id=tenant_id, + session_id=session_id, + channel_type=channel_type, + metadata_=metadata, + ) + self._session.add(new_session) + await self._session.flush() + + logger.info( + f"[AC-AISVC-13] Created new session: tenant={tenant_id}, session={session_id}" + ) + return new_session + + async def load_history( + self, + tenant_id: str, + session_id: str, + limit: int | None = None, + ) -> Sequence[ChatMessage]: + """ + [AC-AISVC-13] Load conversation history for a session. + All queries are filtered by tenant_id to ensure isolation. + """ + stmt = ( + select(ChatMessage) + .where( + ChatMessage.tenant_id == tenant_id, + ChatMessage.session_id == session_id, + ) + .order_by(col(ChatMessage.created_at).asc()) + ) + + if limit: + stmt = stmt.limit(limit) + + result = await self._session.execute(stmt) + messages = result.scalars().all() + + logger.info( + f"[AC-AISVC-13] Loaded {len(messages)} messages for tenant={tenant_id}, session={session_id}" + ) + return messages + + async def append_message( + self, + tenant_id: str, + session_id: str, + role: str, + content: str, + ) -> ChatMessage: + """ + [AC-AISVC-13] Append a message to the session history. + Message is scoped by tenant_id for isolation. + """ + message = ChatMessage( + tenant_id=tenant_id, + session_id=session_id, + role=role, + content=content, + ) + self._session.add(message) + await self._session.flush() + + logger.info( + f"[AC-AISVC-13] Appended message: tenant={tenant_id}, session={session_id}, role={role}" + ) + return message + + async def append_messages( + self, + tenant_id: str, + session_id: str, + messages: list[dict[str, str]], + ) -> list[ChatMessage]: + """ + [AC-AISVC-13] Append multiple messages to the session history. + Used for batch insertion of conversation turns. 
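+
+        Example (sketch):
+
+            await memory.append_messages(tenant_id, session_id, [
+                {"role": "user", "content": "Hi"},
+                {"role": "assistant", "content": "Hello! How can I help?"},
+            ])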
+ """ + chat_messages = [] + for msg in messages: + message = ChatMessage( + tenant_id=tenant_id, + session_id=session_id, + role=msg["role"], + content=msg["content"], + ) + self._session.add(message) + chat_messages.append(message) + + await self._session.flush() + + logger.info( + f"[AC-AISVC-13] Appended {len(chat_messages)} messages for tenant={tenant_id}, session={session_id}" + ) + return chat_messages + + async def clear_history(self, tenant_id: str, session_id: str) -> int: + """ + [AC-AISVC-13] Clear all messages for a session. + Only affects messages within the tenant's scope. + """ + stmt = select(ChatMessage).where( + ChatMessage.tenant_id == tenant_id, + ChatMessage.session_id == session_id, + ) + result = await self._session.execute(stmt) + messages = result.scalars().all() + + count = 0 + for message in messages: + await self._session.delete(message) + count += 1 + + await self._session.flush() + + logger.info( + f"[AC-AISVC-13] Cleared {count} messages for tenant={tenant_id}, session={session_id}" + ) + return count diff --git a/ai-service/app/services/orchestrator.py b/ai-service/app/services/orchestrator.py new file mode 100644 index 0000000..167551a --- /dev/null +++ b/ai-service/app/services/orchestrator.py @@ -0,0 +1,97 @@ +""" +Orchestrator service for AI Service. +[AC-AISVC-01, AC-AISVC-02] Core orchestration logic for chat generation. +""" + +import logging +from typing import AsyncGenerator + +from sse_starlette.sse import ServerSentEvent + +from app.models import ChatRequest, ChatResponse +from app.core.sse import create_final_event, create_message_event, SSEStateMachine + +logger = logging.getLogger(__name__) + + +class OrchestratorService: + """ + [AC-AISVC-01, AC-AISVC-02] Orchestrator for chat generation. + Coordinates memory, retrieval, and LLM components. + """ + + async def generate(self, tenant_id: str, request: ChatRequest) -> ChatResponse: + """ + Generate a non-streaming response. + [AC-AISVC-02] Returns ChatResponse with reply, confidence, shouldTransfer. + """ + logger.info( + f"[AC-AISVC-01] Generating response for tenant={tenant_id}, " + f"session={request.session_id}" + ) + + reply = f"Received your message: {request.current_message}" + return ChatResponse( + reply=reply, + confidence=0.85, + should_transfer=False, + ) + + async def generate_stream( + self, tenant_id: str, request: ChatRequest + ) -> AsyncGenerator[ServerSentEvent, None]: + """ + Generate a streaming response. + [AC-AISVC-06, AC-AISVC-07, AC-AISVC-08] Yields SSE events in proper sequence. 
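+
+        Wire-level sequence is message* followed by exactly one final or error
+        event, e.g. (shapes only):
+
+            event: message
+            data: {"delta": "Received"}
+
+            event: final
+            data: {"reply": "...", "confidence": 0.85, "shouldTransfer": false}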
+ """ + logger.info( + f"[AC-AISVC-06] Starting streaming generation for tenant={tenant_id}, " + f"session={request.session_id}" + ) + + state_machine = SSEStateMachine() + await state_machine.transition_to_streaming() + + try: + reply_parts = ["Received", " your", " message:", f" {request.current_message}"] + full_reply = "" + + for part in reply_parts: + if state_machine.can_send_message(): + full_reply += part + yield create_message_event(delta=part) + await self._simulate_llm_delay() + + if await state_machine.transition_to_final(): + yield create_final_event( + reply=full_reply, + confidence=0.85, + should_transfer=False, + ) + + except Exception as e: + logger.error(f"[AC-AISVC-09] Error during streaming: {e}") + if await state_machine.transition_to_error(): + from app.core.sse import create_error_event + yield create_error_event( + code="GENERATION_ERROR", + message=str(e), + ) + finally: + await state_machine.close() + + async def _simulate_llm_delay(self) -> None: + """Simulate LLM processing delay for demo purposes.""" + import asyncio + await asyncio.sleep(0.1) + + +_orchestrator_service: OrchestratorService | None = None + + +def get_orchestrator_service() -> OrchestratorService: + """Get or create orchestrator service instance.""" + global _orchestrator_service + if _orchestrator_service is None: + _orchestrator_service = OrchestratorService() + return _orchestrator_service diff --git a/ai-service/app/services/retrieval/__init__.py b/ai-service/app/services/retrieval/__init__.py new file mode 100644 index 0000000..61e2a2f --- /dev/null +++ b/ai-service/app/services/retrieval/__init__.py @@ -0,0 +1,21 @@ +""" +Retrieval module for AI Service. +[AC-AISVC-16] Provides retriever implementations with plugin architecture. +""" + +from app.services.retrieval.base import ( + BaseRetriever, + RetrievalContext, + RetrievalHit, + RetrievalResult, +) +from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever + +__all__ = [ + "BaseRetriever", + "RetrievalContext", + "RetrievalHit", + "RetrievalResult", + "VectorRetriever", + "get_vector_retriever", +] diff --git a/ai-service/app/services/retrieval/base.py b/ai-service/app/services/retrieval/base.py new file mode 100644 index 0000000..fe20bc6 --- /dev/null +++ b/ai-service/app/services/retrieval/base.py @@ -0,0 +1,96 @@ +""" +Retrieval layer for AI Service. +[AC-AISVC-16] Abstract base class for retrievers with plugin point support. +""" + +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class RetrievalContext: + """ + [AC-AISVC-16] Context for retrieval operations. + Contains all necessary information for retrieval plugins. + """ + + tenant_id: str + query: str + session_id: str | None = None + channel_type: str | None = None + metadata: dict[str, Any] | None = None + + +@dataclass +class RetrievalHit: + """ + [AC-AISVC-16] Single retrieval result hit. + Unified structure for all retriever types. + """ + + text: str + score: float + source: str + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class RetrievalResult: + """ + [AC-AISVC-16] Result from retrieval operation. + Contains hits and optional diagnostics. 
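+
+    Example consumer check (sketch; the threshold comes from settings.rag_score_threshold):
+
+        result = await retriever.retrieve(ctx)
+        if result.is_empty or result.max_score < settings.rag_score_threshold:
+            ...  # treat evidence as insufficient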
+ """ + + hits: list[RetrievalHit] = field(default_factory=list) + diagnostics: dict[str, Any] | None = None + + @property + def is_empty(self) -> bool: + """Check if no hits were found.""" + return len(self.hits) == 0 + + @property + def max_score(self) -> float: + """Get the maximum score among hits.""" + if not self.hits: + return 0.0 + return max(hit.score for hit in self.hits) + + @property + def hit_count(self) -> int: + """Get the number of hits.""" + return len(self.hits) + + +class BaseRetriever(ABC): + """ + [AC-AISVC-16] Abstract base class for retrievers. + Provides plugin point for different retrieval strategies (Vector, Graph, Hybrid). + """ + + @abstractmethod + async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult: + """ + [AC-AISVC-16] Retrieve relevant documents for the given context. + + Args: + ctx: Retrieval context containing tenant_id, query, and optional metadata. + + Returns: + RetrievalResult with hits and optional diagnostics. + """ + pass + + @abstractmethod + async def health_check(self) -> bool: + """ + Check if the retriever is healthy and ready to serve requests. + + Returns: + True if healthy, False otherwise. + """ + pass diff --git a/ai-service/app/services/retrieval/vector_retriever.py b/ai-service/app/services/retrieval/vector_retriever.py new file mode 100644 index 0000000..5b6d80b --- /dev/null +++ b/ai-service/app/services/retrieval/vector_retriever.py @@ -0,0 +1,169 @@ +""" +Vector retriever for AI Service. +[AC-AISVC-16, AC-AISVC-17] Qdrant-based vector retrieval with score threshold filtering. +""" + +import logging +from typing import Any + +from app.core.config import get_settings +from app.core.qdrant_client import QdrantClient, get_qdrant_client +from app.services.retrieval.base import ( + BaseRetriever, + RetrievalContext, + RetrievalHit, + RetrievalResult, +) + +logger = logging.getLogger(__name__) + +settings = get_settings() + + +class VectorRetriever(BaseRetriever): + """ + [AC-AISVC-16, AC-AISVC-17] Vector-based retriever using Qdrant. + Supports score threshold filtering and tenant isolation. + """ + + def __init__( + self, + qdrant_client: QdrantClient | None = None, + top_k: int | None = None, + score_threshold: float | None = None, + min_hits: int | None = None, + ): + self._qdrant_client = qdrant_client + self._top_k = top_k or settings.rag_top_k + self._score_threshold = score_threshold or settings.rag_score_threshold + self._min_hits = min_hits or settings.rag_min_hits + + async def _get_client(self) -> QdrantClient: + """Get Qdrant client instance.""" + if self._qdrant_client is None: + self._qdrant_client = await get_qdrant_client() + return self._qdrant_client + + async def retrieve(self, ctx: RetrievalContext) -> RetrievalResult: + """ + [AC-AISVC-16, AC-AISVC-17] Retrieve documents from vector store. + + Steps: + 1. Generate embedding for query (placeholder - requires embedding provider) + 2. Search in tenant's collection + 3. Filter by score threshold + 4. Return structured result + + Args: + ctx: Retrieval context with tenant_id and query. + + Returns: + RetrievalResult with filtered hits. + """ + logger.info( + f"[AC-AISVC-16] Starting vector retrieval for tenant={ctx.tenant_id}, query={ctx.query[:50]}..." 
+ ) + + try: + client = await self._get_client() + + query_vector = await self._get_embedding(ctx.query) + + hits = await client.search( + tenant_id=ctx.tenant_id, + query_vector=query_vector, + limit=self._top_k, + score_threshold=self._score_threshold, + ) + + retrieval_hits = [ + RetrievalHit( + text=hit.get("payload", {}).get("text", ""), + score=hit.get("score", 0.0), + source=hit.get("payload", {}).get("source", "vector"), + metadata=hit.get("payload", {}), + ) + for hit in hits + if hit.get("score", 0.0) >= self._score_threshold + ] + + is_insufficient = len(retrieval_hits) < self._min_hits + + diagnostics = { + "query_length": len(ctx.query), + "top_k": self._top_k, + "score_threshold": self._score_threshold, + "min_hits": self._min_hits, + "total_candidates": len(hits), + "filtered_hits": len(retrieval_hits), + "is_insufficient": is_insufficient, + "max_score": max((h.score for h in retrieval_hits), default=0.0), + } + + logger.info( + f"[AC-AISVC-17] Retrieval complete: {len(retrieval_hits)} hits, " + f"insufficient={is_insufficient}, max_score={diagnostics['max_score']:.3f}" + ) + + return RetrievalResult( + hits=retrieval_hits, + diagnostics=diagnostics, + ) + + except Exception as e: + logger.error(f"[AC-AISVC-16] Retrieval error: {e}") + return RetrievalResult( + hits=[], + diagnostics={"error": str(e), "is_insufficient": True}, + ) + + async def _get_embedding(self, text: str) -> list[float]: + """ + Generate embedding for text. + [AC-AISVC-16] Placeholder for embedding generation. + + TODO: Integrate with actual embedding provider (OpenAI, local model, etc.) + """ + import hashlib + + hash_obj = hashlib.sha256(text.encode()) + hash_bytes = hash_obj.digest() + + embedding = [] + for i in range(0, min(len(hash_bytes) * 8, settings.qdrant_vector_size)): + byte_idx = i // 8 + bit_idx = i % 8 + if byte_idx < len(hash_bytes): + val = (hash_bytes[byte_idx] >> bit_idx) & 1 + embedding.append(float(val)) + else: + embedding.append(0.0) + + while len(embedding) < settings.qdrant_vector_size: + embedding.append(0.0) + + return embedding[: settings.qdrant_vector_size] + + async def health_check(self) -> bool: + """ + [AC-AISVC-16] Check if Qdrant connection is healthy. 
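+        Performs a lightweight get_collections() round-trip against Qdrant;
+        any exception is treated as unhealthy.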
+ """ + try: + client = await self._get_client() + qdrant = await client.get_client() + await qdrant.get_collections() + return True + except Exception as e: + logger.error(f"[AC-AISVC-16] Health check failed: {e}") + return False + + +_vector_retriever: VectorRetriever | None = None + + +async def get_vector_retriever() -> VectorRetriever: + """Get or create VectorRetriever instance.""" + global _vector_retriever + if _vector_retriever is None: + _vector_retriever = VectorRetriever() + return _vector_retriever diff --git a/ai-service/pyproject.toml b/ai-service/pyproject.toml new file mode 100644 index 0000000..432330f --- /dev/null +++ b/ai-service/pyproject.toml @@ -0,0 +1,53 @@ +[project] +name = "ai-service" +version = "0.1.0" +description = "Python AI Service for intelligent chat with RAG support" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "fastapi>=0.109.0", + "uvicorn[standard]>=0.27.0", + "pydantic>=2.5.0", + "pydantic-settings>=2.1.0", + "sse-starlette>=2.0.0", + "httpx>=0.26.0", + "tenacity>=8.2.0", + "sqlmodel>=0.0.14", + "asyncpg>=0.29.0", + "qdrant-client>=1.7.0", + "tiktoken>=0.5.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.4.0", + "pytest-asyncio>=0.23.0", + "pytest-cov>=4.1.0", + "httpx>=0.26.0", + "ruff>=0.1.0", + "mypy>=1.8.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["app"] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "W", "UP"] + +[tool.mypy] +python_version = "3.10" +strict = true +warn_return_any = true +warn_unused_configs = true + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] diff --git a/ai-service/tests/__init__.py b/ai-service/tests/__init__.py new file mode 100644 index 0000000..30e3933 --- /dev/null +++ b/ai-service/tests/__init__.py @@ -0,0 +1,3 @@ +""" +Tests package for AI Service. +""" diff --git a/ai-service/tests/conftest.py b/ai-service/tests/conftest.py new file mode 100644 index 0000000..a30005f --- /dev/null +++ b/ai-service/tests/conftest.py @@ -0,0 +1,10 @@ +""" +Pytest configuration for AI Service tests. +""" + +import pytest + + +@pytest.fixture +def anyio_backend(): + return "asyncio" diff --git a/ai-service/tests/test_accept_switching.py b/ai-service/tests/test_accept_switching.py new file mode 100644 index 0000000..5e6c7c5 --- /dev/null +++ b/ai-service/tests/test_accept_switching.py @@ -0,0 +1,285 @@ +""" +Tests for response mode switching based on Accept header. +[AC-AISVC-06] Tests for automatic switching between JSON and SSE streaming modes. +""" + +import pytest +from fastapi.testclient import TestClient +from httpx import AsyncClient + +from app.main import app + + +class TestAcceptHeaderSwitching: + """ + [AC-AISVC-06] Test cases for Accept header based response mode switching. + """ + + @pytest.fixture + def client(self): + return TestClient(app) + + @pytest.fixture + def valid_request_body(self): + return { + "sessionId": "test_session_001", + "currentMessage": "Hello, how are you?", + "channelType": "wechat", + } + + @pytest.fixture + def valid_headers(self): + return {"X-Tenant-Id": "tenant_001"} + + def test_json_response_with_default_accept( + self, client: TestClient, valid_request_body: dict, valid_headers: dict + ): + """ + [AC-AISVC-06] Test that default Accept header returns JSON response. 
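+        No Accept header is set explicitly, so the request carries the HTTP
+        client's default (httpx sends Accept: */*), which must map to JSON.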
+        """
+        response = client.post(
+            "/ai/chat",
+            json=valid_request_body,
+            headers=valid_headers,
+        )
+
+        assert response.status_code == 200
+        assert response.headers["content-type"] == "application/json"
+
+        data = response.json()
+        assert "reply" in data
+        assert "confidence" in data
+        assert "shouldTransfer" in data
+
+    def test_json_response_with_application_json_accept(
+        self, client: TestClient, valid_request_body: dict, valid_headers: dict
+    ):
+        """
+        [AC-AISVC-06] Test that Accept: application/json returns JSON response.
+        """
+        headers = {**valid_headers, "Accept": "application/json"}
+
+        response = client.post(
+            "/ai/chat",
+            json=valid_request_body,
+            headers=headers,
+        )
+
+        assert response.status_code == 200
+        assert response.headers["content-type"] == "application/json"
+
+        data = response.json()
+        assert "reply" in data
+        assert "confidence" in data
+        assert "shouldTransfer" in data
+
+    def test_sse_response_with_text_event_stream_accept(
+        self, client: TestClient, valid_request_body: dict, valid_headers: dict
+    ):
+        """
+        [AC-AISVC-06] Test that Accept: text/event-stream returns SSE response.
+        """
+        headers = {**valid_headers, "Accept": "text/event-stream"}
+
+        response = client.post(
+            "/ai/chat",
+            json=valid_request_body,
+            headers=headers,
+        )
+
+        assert response.status_code == 200
+        assert "text/event-stream" in response.headers["content-type"]
+
+        content = response.text
+        assert "event: message" in content
+        assert "event: final" in content
+
+    def test_sse_response_event_sequence(
+        self, client: TestClient, valid_request_body: dict, valid_headers: dict
+    ):
+        """
+        [AC-AISVC-07, AC-AISVC-08] Test that SSE events follow proper sequence.
+        message* -> final -> close
+        """
+        headers = {**valid_headers, "Accept": "text/event-stream"}
+
+        response = client.post(
+            "/ai/chat",
+            json=valid_request_body,
+            headers=headers,
+        )
+
+        content = response.text
+
+        def find_event(name: str) -> int:
+            # sse-starlette emits "event: <name>"; tolerate a missing space as well.
+            idx = content.find(f"event: {name}")
+            return idx if idx != -1 else content.find(f"event:{name}")
+
+        message_idx = find_event("message")
+        final_idx = find_event("final")
+
+        assert message_idx != -1, f"Expected message event in: {content[:500]}"
+        assert final_idx != -1, f"Expected final event in: {content[:500]}"
+        assert final_idx > message_idx, "final event should come after message events"
+
+    def test_missing_tenant_id_returns_400(
+        self, client: TestClient, valid_request_body: dict
+    ):
+        """
+        [AC-AISVC-12] Test that missing X-Tenant-Id returns 400 error.
+        """
+        response = client.post(
+            "/ai/chat",
+            json=valid_request_body,
+        )
+
+        assert response.status_code == 400
+
+        data = response.json()
+        assert data["code"] == "MISSING_TENANT_ID"
+        assert "message" in data
+
+    def test_invalid_channel_type_returns_400(
+        self, client: TestClient, valid_headers: dict
+    ):
+        """
+        [AC-AISVC-03] Test that invalid channel type returns 400 error.
+        """
+        invalid_body = {
+            "sessionId": "test_session_001",
+            "currentMessage": "Hello",
+            "channelType": "invalid_channel",
+        }
+
+        response = client.post(
+            "/ai/chat",
+            json=invalid_body,
+            headers=valid_headers,
+        )
+
+        assert response.status_code == 400
+
+    def test_missing_required_fields_returns_400(
+        self, client: TestClient, valid_headers: dict
+    ):
+        """
+        [AC-AISVC-03] Test that missing required fields return 400 error.
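+        The body supplies only sessionId; currentMessage is required by ChatRequest.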
+ """ + incomplete_body = { + "sessionId": "test_session_001", + } + + response = client.post( + "/ai/chat", + json=incomplete_body, + headers=valid_headers, + ) + + assert response.status_code == 400 + + +class TestHealthEndpoint: + """ + [AC-AISVC-20] Test cases for health check endpoint. + """ + + @pytest.fixture + def client(self): + return TestClient(app) + + def test_health_check_returns_200(self, client: TestClient): + """ + [AC-AISVC-20] Test that health check returns 200 with status. + """ + response = client.get("/ai/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + +class TestSSEStateMachine: + """ + [AC-AISVC-08, AC-AISVC-09] Test cases for SSE state machine. + """ + + @pytest.mark.asyncio + async def test_state_transitions(self): + from app.core.sse import SSEState, SSEStateMachine + + state_machine = SSEStateMachine() + + assert state_machine.state == SSEState.INIT + + success = await state_machine.transition_to_streaming() + assert success is True + assert state_machine.state == SSEState.STREAMING + + assert state_machine.can_send_message() is True + + success = await state_machine.transition_to_final() + assert success is True + assert state_machine.state == SSEState.FINAL_SENT + + assert state_machine.can_send_message() is False + + await state_machine.close() + assert state_machine.state == SSEState.CLOSED + + @pytest.mark.asyncio + async def test_error_transition_from_streaming(self): + from app.core.sse import SSEState, SSEStateMachine + + state_machine = SSEStateMachine() + await state_machine.transition_to_streaming() + + success = await state_machine.transition_to_error() + assert success is True + assert state_machine.state == SSEState.ERROR_SENT + + @pytest.mark.asyncio + async def test_cannot_transition_to_final_from_init(self): + from app.core.sse import SSEStateMachine + + state_machine = SSEStateMachine() + + success = await state_machine.transition_to_final() + assert success is False + + +class TestMiddleware: + """ + [AC-AISVC-10, AC-AISVC-12] Test cases for middleware. + """ + + @pytest.fixture + def client(self): + return TestClient(app) + + def test_tenant_context_extraction( + self, client: TestClient + ): + """ + [AC-AISVC-10] Test that X-Tenant-Id is properly extracted and used. + """ + headers = {"X-Tenant-Id": "tenant_test_123"} + body = { + "sessionId": "session_001", + "currentMessage": "Test message", + "channelType": "wechat", + } + + response = client.post("/ai/chat", json=body, headers=headers) + + assert response.status_code == 200 + + def test_health_endpoint_bypasses_tenant_check( + self, client: TestClient + ): + """ + Test that health endpoint doesn't require X-Tenant-Id. + """ + response = client.get("/ai/health") + + assert response.status_code == 200 diff --git a/ai-service/tests/test_llm_adapter.py b/ai-service/tests/test_llm_adapter.py new file mode 100644 index 0000000..b964974 --- /dev/null +++ b/ai-service/tests/test_llm_adapter.py @@ -0,0 +1,319 @@ +""" +Unit tests for LLM Adapter. +[AC-AISVC-02, AC-AISVC-06] Tests for LLM client interface. 
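+
+All HTTP traffic is mocked by patching the client's _get_client, so no live
+OpenAI endpoint is contacted.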
+ +Tests cover: +- Non-streaming generation +- Streaming generation +- Error handling +- Retry logic +""" + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from app.services.llm.base import LLMConfig, LLMResponse, LLMStreamChunk +from app.services.llm.openai_client import ( + LLMException, + OpenAIClient, + TimeoutException, +) + + +@pytest.fixture +def mock_settings(): + """Mock settings for testing.""" + settings = MagicMock() + settings.llm_api_key = "test-api-key" + settings.llm_base_url = "https://api.openai.com/v1" + settings.llm_model = "gpt-4o-mini" + settings.llm_max_tokens = 2048 + settings.llm_temperature = 0.7 + settings.llm_timeout_seconds = 30 + settings.llm_max_retries = 3 + return settings + + +@pytest.fixture +def llm_client(mock_settings): + """Create LLM client with mocked settings.""" + with patch("app.services.llm.openai_client.get_settings", return_value=mock_settings): + client = OpenAIClient() + yield client + + +@pytest.fixture +def mock_messages(): + """Sample chat messages for testing.""" + return [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + ] + + +@pytest.fixture +def mock_generate_response(): + """Sample non-streaming response from OpenAI API.""" + return { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o-mini", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! I'm doing well, thank you for asking!", + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 20, + "completion_tokens": 15, + "total_tokens": 35, + }, + } + + +@pytest.fixture +def mock_stream_chunks(): + """Sample streaming chunks from OpenAI API.""" + return [ + "data: {\"id\":\"chatcmpl-123\",\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n", + "data: {\"id\":\"chatcmpl-123\",\"choices\":[{\"delta\":{\"content\":\"!\"},\"finish_reason\":null}]}\n", + "data: {\"id\":\"chatcmpl-123\",\"choices\":[{\"delta\":{\"content\":\" How\"},\"finish_reason\":null}]}\n", + "data: {\"id\":\"chatcmpl-123\",\"choices\":[{\"delta\":{\"content\":\" can I help?\"},\"finish_reason\":\"stop\"}]}\n", + "data: [DONE]\n", + ] + + +class TestOpenAIClientGenerate: + """Tests for non-streaming generation. [AC-AISVC-02]""" + + @pytest.mark.asyncio + async def test_generate_success(self, llm_client, mock_messages, mock_generate_response): + """[AC-AISVC-02] Test successful non-streaming generation.""" + mock_response = MagicMock() + mock_response.json.return_value = mock_generate_response + mock_response.raise_for_status = MagicMock() + + with patch.object( + llm_client, "_get_client" + ) as mock_get_client: + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_get_client.return_value = mock_client + + result = await llm_client.generate(mock_messages) + + assert isinstance(result, LLMResponse) + assert result.content == "Hello! I'm doing well, thank you for asking!" 
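+            # model, finish_reason, and usage are parsed verbatim from the mocked payload.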
+ assert result.model == "gpt-4o-mini" + assert result.finish_reason == "stop" + assert result.usage["total_tokens"] == 35 + + @pytest.mark.asyncio + async def test_generate_with_custom_config(self, llm_client, mock_messages, mock_generate_response): + """[AC-AISVC-02] Test generation with custom configuration.""" + custom_config = LLMConfig( + model="gpt-4", + max_tokens=1024, + temperature=0.5, + ) + + mock_response = MagicMock() + mock_response.json.return_value = {**mock_generate_response, "model": "gpt-4"} + mock_response.raise_for_status = MagicMock() + + with patch.object(llm_client, "_get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_get_client.return_value = mock_client + + result = await llm_client.generate(mock_messages, config=custom_config) + + assert result.model == "gpt-4" + + @pytest.mark.asyncio + async def test_generate_timeout_error(self, llm_client, mock_messages): + """[AC-AISVC-02] Test timeout error handling.""" + with patch.object(llm_client, "_get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("Timeout")) + mock_get_client.return_value = mock_client + + with pytest.raises(TimeoutException): + await llm_client.generate(mock_messages) + + @pytest.mark.asyncio + async def test_generate_api_error(self, llm_client, mock_messages): + """[AC-AISVC-02] Test API error handling.""" + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = '{"error": {"message": "Invalid API key"}}' + mock_response.json.return_value = {"error": {"message": "Invalid API key"}} + + http_error = httpx.HTTPStatusError( + "Unauthorized", + request=MagicMock(), + response=mock_response, + ) + + with patch.object(llm_client, "_get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.post = AsyncMock(side_effect=http_error) + mock_get_client.return_value = mock_client + + with pytest.raises(LLMException) as exc_info: + await llm_client.generate(mock_messages) + + assert "Invalid API key" in str(exc_info.value.message) + + @pytest.mark.asyncio + async def test_generate_malformed_response(self, llm_client, mock_messages): + """[AC-AISVC-02] Test handling of malformed response.""" + mock_response = MagicMock() + mock_response.json.return_value = {"invalid": "response"} + mock_response.raise_for_status = MagicMock() + + with patch.object(llm_client, "_get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_get_client.return_value = mock_client + + with pytest.raises(LLMException): + await llm_client.generate(mock_messages) + + +class MockAsyncStreamContext: + """Mock async context manager for streaming.""" + + def __init__(self, response): + self._response = response + + async def __aenter__(self): + return self._response + + async def __aexit__(self, *args): + pass + + +class TestOpenAIClientStreamGenerate: + """Tests for streaming generation. 
[AC-AISVC-06, AC-AISVC-07]""" + + @pytest.mark.asyncio + async def test_stream_generate_success(self, llm_client, mock_messages, mock_stream_chunks): + """[AC-AISVC-06, AC-AISVC-07] Test successful streaming generation.""" + async def mock_aiter_lines(): + for chunk in mock_stream_chunks: + yield chunk + + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.aiter_lines = mock_aiter_lines + + mock_client = AsyncMock() + mock_client.stream = MagicMock(return_value=MockAsyncStreamContext(mock_response)) + + with patch.object(llm_client, "_get_client", return_value=mock_client): + chunks = [] + async for chunk in llm_client.stream_generate(mock_messages): + chunks.append(chunk) + + assert len(chunks) == 4 + assert chunks[0].delta == "Hello" + assert chunks[-1].finish_reason == "stop" + + @pytest.mark.asyncio + async def test_stream_generate_timeout_error(self, llm_client, mock_messages): + """[AC-AISVC-06] Test streaming timeout error handling.""" + mock_client = AsyncMock() + + class TimeoutContext: + async def __aenter__(self): + raise httpx.TimeoutException("Timeout") + async def __aexit__(self, *args): + pass + + mock_client.stream = MagicMock(return_value=TimeoutContext()) + + with patch.object(llm_client, "_get_client", return_value=mock_client): + with pytest.raises(TimeoutException): + async for _ in llm_client.stream_generate(mock_messages): + pass + + @pytest.mark.asyncio + async def test_stream_generate_api_error(self, llm_client, mock_messages): + """[AC-AISVC-06] Test streaming API error handling.""" + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + mock_response.json.return_value = {"error": {"message": "Internal Server Error"}} + + http_error = httpx.HTTPStatusError( + "Internal Server Error", + request=MagicMock(), + response=mock_response, + ) + + mock_client = AsyncMock() + + class ErrorContext: + async def __aenter__(self): + raise http_error + async def __aexit__(self, *args): + pass + + mock_client.stream = MagicMock(return_value=ErrorContext()) + + with patch.object(llm_client, "_get_client", return_value=mock_client): + with pytest.raises(LLMException): + async for _ in llm_client.stream_generate(mock_messages): + pass + + +class TestOpenAIClientConfig: + """Tests for LLM configuration.""" + + def test_default_config(self, mock_settings): + """Test default configuration from settings.""" + with patch("app.services.llm.openai_client.get_settings", return_value=mock_settings): + client = OpenAIClient() + + assert client._model == "gpt-4o-mini" + assert client._default_config.max_tokens == 2048 + assert client._default_config.temperature == 0.7 + + def test_custom_config_override(self, mock_settings): + """Test custom configuration override.""" + with patch("app.services.llm.openai_client.get_settings", return_value=mock_settings): + client = OpenAIClient( + api_key="custom-key", + base_url="https://custom.api.com/v1", + model="gpt-4", + ) + + assert client._api_key == "custom-key" + assert client._base_url == "https://custom.api.com/v1" + assert client._model == "gpt-4" + + +class TestOpenAIClientClose: + """Tests for client cleanup.""" + + @pytest.mark.asyncio + async def test_close_client(self, llm_client): + """Test client close releases resources.""" + mock_client = AsyncMock() + mock_client.aclose = AsyncMock() + llm_client._client = mock_client + + await llm_client.close() + + mock_client.aclose.assert_called_once() + assert llm_client._client is None diff --git 
a/ai-service/tests/test_memory.py b/ai-service/tests/test_memory.py new file mode 100644 index 0000000..a39895f --- /dev/null +++ b/ai-service/tests/test_memory.py @@ -0,0 +1,210 @@ +""" +Unit tests for Memory service. +[AC-AISVC-10, AC-AISVC-11, AC-AISVC-13] Tests for multi-tenant session and message management. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.entities import ChatMessage, ChatSession +from app.services.memory import MemoryService + + +@pytest.fixture +def mock_session(): + """Create a mock AsyncSession.""" + session = AsyncMock(spec=AsyncSession) + session.add = MagicMock() + session.flush = AsyncMock() + session.delete = AsyncMock() + return session + + +@pytest.fixture +def memory_service(mock_session): + """Create MemoryService with mocked session.""" + return MemoryService(mock_session) + + +class TestMemoryServiceTenantIsolation: + """ + [AC-AISVC-10, AC-AISVC-11] Tests for multi-tenant isolation in memory service. + """ + + @pytest.mark.asyncio + async def test_get_or_create_session_tenant_isolation(self, memory_service, mock_session): + """ + [AC-AISVC-11] Different tenants with same session_id should have separate sessions. + """ + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_session.execute = AsyncMock(return_value=mock_result) + + session1 = await memory_service.get_or_create_session( + tenant_id="tenant_a", + session_id="session_123", + ) + session2 = await memory_service.get_or_create_session( + tenant_id="tenant_b", + session_id="session_123", + ) + + assert session1.tenant_id == "tenant_a" + assert session2.tenant_id == "tenant_b" + assert session1.session_id == "session_123" + assert session2.session_id == "session_123" + + @pytest.mark.asyncio + async def test_load_history_tenant_isolation(self, memory_service, mock_session): + """ + [AC-AISVC-11] Loading history should only return messages for the specific tenant. + """ + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.all.return_value = [ + ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content="Hello"), + ] + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + messages = await memory_service.load_history( + tenant_id="tenant_a", + session_id="session_123", + ) + + assert len(messages) == 1 + assert messages[0].tenant_id == "tenant_a" + + @pytest.mark.asyncio + async def test_append_message_tenant_scoped(self, memory_service, mock_session): + """ + [AC-AISVC-10, AC-AISVC-13] Appended messages should be scoped to tenant. + """ + message = await memory_service.append_message( + tenant_id="tenant_a", + session_id="session_123", + role="user", + content="Test message", + ) + + assert message.tenant_id == "tenant_a" + assert message.session_id == "session_123" + assert message.role == "user" + assert message.content == "Test message" + + +class TestMemoryServiceSessionManagement: + """ + [AC-AISVC-13] Tests for session-based memory management. + """ + + @pytest.mark.asyncio + async def test_get_existing_session(self, memory_service, mock_session): + """ + [AC-AISVC-13] Should return existing session if it exists. 
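+        The lookup is keyed on (tenant_id, session_id); the mocked query returns
+        a stored row, so no new session is created.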
+        """
+        existing_session = ChatSession(
+            tenant_id="tenant_a",
+            session_id="session_123",
+        )
+        mock_result = MagicMock()
+        mock_result.scalar_one_or_none.return_value = existing_session
+        mock_session.execute = AsyncMock(return_value=mock_result)
+
+        session = await memory_service.get_or_create_session(
+            tenant_id="tenant_a",
+            session_id="session_123",
+        )
+
+        assert session.tenant_id == "tenant_a"
+        assert session.session_id == "session_123"
+
+    @pytest.mark.asyncio
+    async def test_create_new_session(self, memory_service, mock_session):
+        """
+        [AC-AISVC-13] Should create new session if it doesn't exist.
+        """
+        mock_result = MagicMock()
+        mock_result.scalar_one_or_none.return_value = None
+        mock_session.execute = AsyncMock(return_value=mock_result)
+
+        session = await memory_service.get_or_create_session(
+            tenant_id="tenant_a",
+            session_id="session_new",
+            channel_type="wechat",
+            metadata={"user_id": "user_123"},
+        )
+
+        assert session.tenant_id == "tenant_a"
+        assert session.session_id == "session_new"
+        assert session.channel_type == "wechat"
+
+    @pytest.mark.asyncio
+    async def test_append_multiple_messages(self, memory_service, mock_session):
+        """
+        [AC-AISVC-13] Should append multiple messages in batch.
+        """
+        messages_data = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+        ]
+
+        messages = await memory_service.append_messages(
+            tenant_id="tenant_a",
+            session_id="session_123",
+            messages=messages_data,
+        )
+
+        assert len(messages) == 2
+        assert messages[0].role == "user"
+        assert messages[1].role == "assistant"
+
+    @pytest.mark.asyncio
+    async def test_load_history_with_limit(self, memory_service, mock_session):
+        """
+        [AC-AISVC-13] Should pass the limit through to the history query.
+        The LIMIT is applied in the SQL statement; the mocked session never
+        executes SQL, so all stubbed rows come back regardless of the limit.
+        """
+        mock_result = MagicMock()
+        mock_scalars = MagicMock()
+        mock_scalars.all.return_value = [
+            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content=f"Msg {i}")
+            for i in range(5)
+        ]
+        mock_result.scalars.return_value = mock_scalars
+        mock_session.execute = AsyncMock(return_value=mock_result)
+
+        messages = await memory_service.load_history(
+            tenant_id="tenant_a",
+            session_id="session_123",
+            limit=3,
+        )
+
+        # The mock bypasses the SQL LIMIT, so all 5 stubbed rows are returned.
+        assert len(messages) == 5
+
+
+class TestMemoryServiceClearHistory:
+    """
+    [AC-AISVC-13] Tests for clearing session history.
+    """
+
+    @pytest.mark.asyncio
+    async def test_clear_history_tenant_scoped(self, memory_service, mock_session):
+        """
+        [AC-AISVC-11] Clearing history should only affect the specified tenant's messages.
+        """
+        mock_result = MagicMock()
+        mock_scalars = MagicMock()
+        mock_scalars.all.return_value = [
+            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content="Msg 1"),
+            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="assistant", content="Msg 2"),
+        ]
+        mock_result.scalars.return_value = mock_scalars
+        mock_session.execute = AsyncMock(return_value=mock_result)
+
+        count = await memory_service.clear_history(
+            tenant_id="tenant_a",
+            session_id="session_123",
+        )
+
+        assert count == 2
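+
+        # clear_history deletes row-by-row (see MemoryService.clear_history),
+        # so the mocked delete is awaited once per message.
+        assert mock_session.delete.await_count == 2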