diff --git a/ai-service/app/api/admin/kb.py b/ai-service/app/api/admin/kb.py index 7279dbf..f1601c4 100644 --- a/ai-service/app/api/admin/kb.py +++ b/ai-service/app/api/admin/kb.py @@ -4,19 +4,34 @@ Knowledge Base management endpoints. """ import logging -from typing import Annotated, Any, Optional +import os +import uuid +from typing import Annotated, Optional -from fastapi import APIRouter, Depends, Header, Query, UploadFile, File, Form +from fastapi import APIRouter, Depends, Query, UploadFile, File, Form from fastapi.responses import JSONResponse +from sqlalchemy.ext.asyncio import AsyncSession +from app.core.database import get_session +from app.core.exceptions import MissingTenantIdException from app.core.tenant import get_tenant_id from app.models import ErrorResponse +from app.models.entities import DocumentStatus, IndexJobStatus +from app.services.kb import KBService logger = logging.getLogger(__name__) router = APIRouter(prefix="/admin/kb", tags=["KB Management"]) +def get_current_tenant_id() -> str: + """Dependency to get current tenant ID or raise exception.""" + tenant_id = get_tenant_id() + if not tenant_id: + raise MissingTenantIdException() + return tenant_id + + @router.get( "/documents", operation_id="listDocuments", @@ -29,7 +44,8 @@ router = APIRouter(prefix="/admin/kb", tags=["KB Management"]) }, ) async def list_documents( - tenant_id: Annotated[str, Depends(get_tenant_id)], + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], kb_id: Annotated[Optional[str], Query()] = None, status: Annotated[Optional[str], Query()] = None, page: int = Query(1, ge=1), @@ -43,45 +59,32 @@ async def list_documents( f"status={status}, page={page}, page_size={page_size}" ) - mock_documents = [ + kb_service = KBService(session) + documents, total = await kb_service.list_documents( + tenant_id=tenant_id, + kb_id=kb_id, + status=status, + page=page, + page_size=page_size, + ) + + total_pages = (total + page_size - 1) // page_size if total > 0 else 0 + + data = [ { - "docId": "doc_001", - "kbId": kb_id or "kb_default", - "fileName": "product_manual.pdf", - "status": "completed", - "createdAt": "2026-02-20T10:00:00Z", - "updatedAt": "2026-02-20T10:30:00Z", - }, - { - "docId": "doc_002", - "kbId": kb_id or "kb_default", - "fileName": "faq.docx", - "status": "processing", - "createdAt": "2026-02-21T14:00:00Z", - "updatedAt": "2026-02-21T14:15:00Z", - }, - { - "docId": "doc_003", - "kbId": kb_id or "kb_default", - "fileName": "invalid_file.txt", - "status": "failed", - "createdAt": "2026-02-22T09:00:00Z", - "updatedAt": "2026-02-22T09:05:00Z", - }, + "docId": str(doc.id), + "kbId": doc.kb_id, + "fileName": doc.file_name, + "status": doc.status, + "createdAt": doc.created_at.isoformat() + "Z", + "updatedAt": doc.updated_at.isoformat() + "Z", + } + for doc in documents ] - filtered = mock_documents - if kb_id: - filtered = [d for d in filtered if d["kbId"] == kb_id] - if status: - filtered = [d for d in filtered if d["status"] == status] - - total = len(filtered) - total_pages = (total + page_size - 1) // page_size - return JSONResponse( content={ - "data": filtered, + "data": data, "pagination": { "page": page, "pageSize": page_size, @@ -104,7 +107,8 @@ async def list_documents( }, ) async def upload_document( - tenant_id: Annotated[str, Depends(get_tenant_id)], + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], file: UploadFile = File(...), kb_id: str = Form(...), ) -> JSONResponse: @@ -116,19 +120,112 @@ async def upload_document( f"kb_id={kb_id}, filename={file.filename}" ) - import uuid + kb_service = KBService(session) - job_id = f"job_{uuid.uuid4().hex[:8]}" + kb = await kb_service.get_or_create_kb(tenant_id, kb_id) + + file_content = await file.read() + document, job = await kb_service.upload_document( + tenant_id=tenant_id, + kb_id=str(kb.id), + file_name=file.filename or "unknown", + file_content=file_content, + file_type=file.content_type, + ) + + _schedule_indexing(tenant_id, str(job.id), str(document.id), file_content) return JSONResponse( status_code=202, content={ - "jobId": job_id, - "status": "pending", + "jobId": str(job.id), + "docId": str(document.id), + "status": job.status, }, ) +def _schedule_indexing(tenant_id: str, job_id: str, doc_id: str, content: bytes): + """ + Schedule background indexing task. + For MVP, we simulate indexing with a simple text extraction. + In production, this would use a task queue like Celery. + """ + import asyncio + + async def run_indexing(): + from app.core.database import async_session_maker + from app.services.kb import KBService + from app.core.qdrant_client import get_qdrant_client + from qdrant_client.models import PointStruct + import hashlib + + await asyncio.sleep(1) + + async with async_session_maker() as session: + kb_service = KBService(session) + try: + await kb_service.update_job_status( + tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=10 + ) + + text = content.decode("utf-8", errors="ignore") + + chunks = [text[i:i+500] for i in range(0, len(text), 500)] + + qdrant = await get_qdrant_client() + await qdrant.ensure_collection_exists(tenant_id) + + points = [] + for i, chunk in enumerate(chunks): + hash_obj = hashlib.sha256(chunk.encode()) + hash_bytes = hash_obj.digest() + embedding = [] + for j in range(0, min(len(hash_bytes) * 8, 1536)): + byte_idx = j // 8 + bit_idx = j % 8 + if byte_idx < len(hash_bytes): + val = (hash_bytes[byte_idx] >> bit_idx) & 1 + embedding.append(float(val)) + else: + embedding.append(0.0) + while len(embedding) < 1536: + embedding.append(0.0) + + points.append( + PointStruct( + id=str(uuid.uuid4()), + vector=embedding[:1536], + payload={ + "text": chunk, + "source": doc_id, + "chunk_index": i, + }, + ) + ) + + if points: + await qdrant.upsert_vectors(tenant_id, points) + + await kb_service.update_job_status( + tenant_id, job_id, IndexJobStatus.COMPLETED.value, progress=100 + ) + + logger.info( + f"[AC-ASA-01] Indexing completed: tenant={tenant_id}, " + f"job_id={job_id}, chunks={len(chunks)}" + ) + + except Exception as e: + logger.error(f"[AC-ASA-01] Indexing failed: {e}") + await kb_service.update_job_status( + tenant_id, job_id, IndexJobStatus.FAILED.value, + progress=0, error_msg=str(e) + ) + + asyncio.create_task(run_indexing()) + + @router.get( "/index/jobs/{job_id}", operation_id="getIndexJob", @@ -141,7 +238,8 @@ async def upload_document( }, ) async def get_index_job( - tenant_id: Annotated[str, Depends(get_tenant_id)], + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], job_id: str, ) -> JSONResponse: """ @@ -151,33 +249,68 @@ async def get_index_job( f"[AC-ASA-02] Getting job status: tenant={tenant_id}, job_id={job_id}" ) - mock_job_statuses = { - "job_pending": { - "jobId": job_id, - "status": "pending", - "progress": 0, - "errorMsg": None, - }, - "job_processing": { - "jobId": job_id, - "status": "processing", - "progress": 45, - "errorMsg": None, - }, - "job_completed": { - "jobId": job_id, - "status": "completed", - "progress": 100, - "errorMsg": None, - }, - "job_failed": { - "jobId": job_id, - "status": "failed", - "progress": 30, - "errorMsg": "Failed to parse PDF: Invalid format", - }, - } + kb_service = KBService(session) + job = await kb_service.get_index_job(tenant_id, job_id) - job_status = mock_job_statuses.get(job_id, mock_job_statuses["job_processing"]) + if not job: + return JSONResponse( + status_code=404, + content={ + "code": "JOB_NOT_FOUND", + "message": f"Job {job_id} not found", + }, + ) - return JSONResponse(content=job_status) + return JSONResponse( + content={ + "jobId": str(job.id), + "docId": str(job.doc_id), + "status": job.status, + "progress": job.progress, + "errorMsg": job.error_msg, + } + ) + + +@router.delete( + "/documents/{doc_id}", + operation_id="deleteDocument", + summary="Delete document", + description="[AC-ASA-08] Delete a document and its associated files.", + responses={ + 200: {"description": "Document deleted"}, + 404: {"description": "Document not found"}, + 401: {"description": "Unauthorized", "model": ErrorResponse}, + 403: {"description": "Forbidden", "model": ErrorResponse}, + }, +) +async def delete_document( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + doc_id: str, +) -> JSONResponse: + """ + [AC-ASA-08] Delete a document. + """ + logger.info( + f"[AC-ASA-08] Deleting document: tenant={tenant_id}, doc_id={doc_id}" + ) + + kb_service = KBService(session) + deleted = await kb_service.delete_document(tenant_id, doc_id) + + if not deleted: + return JSONResponse( + status_code=404, + content={ + "code": "DOCUMENT_NOT_FOUND", + "message": f"Document {doc_id} not found", + }, + ) + + return JSONResponse( + content={ + "success": True, + "message": "Document deleted", + } + ) diff --git a/ai-service/app/api/admin/rag.py b/ai-service/app/api/admin/rag.py index a584fb1..5a75bb4 100644 --- a/ai-service/app/api/admin/rag.py +++ b/ai-service/app/api/admin/rag.py @@ -6,17 +6,39 @@ RAG Lab endpoints for debugging and experimentation. import logging from typing import Annotated, Any, List -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Body from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession +from app.core.config import get_settings +from app.core.database import get_session +from app.core.exceptions import MissingTenantIdException from app.core.tenant import get_tenant_id +from app.core.qdrant_client import get_qdrant_client from app.models import ErrorResponse +from app.services.retrieval.vector_retriever import get_vector_retriever +from app.services.retrieval.base import RetrievalContext logger = logging.getLogger(__name__) router = APIRouter(prefix="/admin/rag", tags=["RAG Lab"]) +def get_current_tenant_id() -> str: + """Dependency to get current tenant ID or raise exception.""" + tenant_id = get_tenant_id() + if not tenant_id: + raise MissingTenantIdException() + return tenant_id + + +class RAGExperimentRequest(BaseModel): + query: str = Field(..., description="Query text for retrieval") + kb_ids: List[str] | None = Field(default=None, description="Knowledge base IDs to search") + params: dict[str, Any] | None = Field(default=None, description="Retrieval parameters") + + @router.post( "/experiments/run", operation_id="runRagExperiment", @@ -29,51 +51,111 @@ router = APIRouter(prefix="/admin/rag", tags=["RAG Lab"]) }, ) async def run_rag_experiment( - tenant_id: Annotated[str, Depends(get_tenant_id)], - query: str, - kb_ids: List[str], - params: dict = None, + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + request: RAGExperimentRequest = Body(...), ) -> JSONResponse: """ [AC-ASA-05] Run RAG experiment and return retrieval results with final prompt. """ logger.info( f"[AC-ASA-05] Running RAG experiment: tenant={tenant_id}, " - f"query={query}, kb_ids={kb_ids}" + f"query={request.query[:50]}..., kb_ids={request.kb_ids}" ) - mock_retrieval_results = [ - { - "content": "产品价格根据套餐不同有所差异,基础版每月99元,专业版每月299元。", - "score": 0.92, - "source": "product_manual.pdf", - }, - { - "content": "企业版提供定制化服务,请联系销售获取报价。", - "score": 0.85, - "source": "pricing_guide.docx", - }, - { - "content": "所有套餐均支持7天无理由退款。", - "score": 0.78, - "source": "faq.pdf", - }, - ] + settings = get_settings() - mock_final_prompt = f"""基于以下检索到的信息,回答用户问题: + params = request.params or {} + top_k = params.get("topK", settings.rag_top_k) + threshold = params.get("threshold", settings.rag_score_threshold) + + try: + retriever = await get_vector_retriever() + + retrieval_ctx = RetrievalContext( + tenant_id=tenant_id, + query=request.query, + session_id="rag_experiment", + channel_type="admin", + metadata={"kb_ids": request.kb_ids}, + ) + + result = await retriever.retrieve(retrieval_ctx) + + retrieval_results = [ + { + "content": hit.text, + "score": hit.score, + "source": hit.source, + "metadata": hit.metadata, + } + for hit in result.hits + ] + + final_prompt = _build_final_prompt(request.query, retrieval_results) + + logger.info( + f"[AC-ASA-05] RAG experiment complete: hits={len(retrieval_results)}, " + f"max_score={result.max_score:.3f}" + ) + + return JSONResponse( + content={ + "retrievalResults": retrieval_results, + "finalPrompt": final_prompt, + "diagnostics": result.diagnostics, + } + ) + + except Exception as e: + logger.error(f"[AC-ASA-05] RAG experiment failed: {e}") + + fallback_results = _get_fallback_results(request.query) + fallback_prompt = _build_final_prompt(request.query, fallback_results) + + return JSONResponse( + content={ + "retrievalResults": fallback_results, + "finalPrompt": fallback_prompt, + "diagnostics": { + "error": str(e), + "fallback": True, + }, + } + ) + + +def _build_final_prompt(query: str, retrieval_results: list[dict]) -> str: + """ + Build the final prompt from query and retrieval results. + """ + if not retrieval_results: + return f"""用户问题:{query} + +未找到相关检索结果,请基于通用知识回答用户问题。""" + + evidence_text = "\n".join([ + f"{i+1}. [Score: {hit['score']:.2f}] {hit['content'][:200]}{'...' if len(hit['content']) > 200 else ''}" + for i, hit in enumerate(retrieval_results[:5]) + ]) + + return f"""基于以下检索到的信息,回答用户问题: 用户问题:{query} 检索结果: -1. [Score: 0.92] 产品价格根据套餐不同有所差异,基础版每月99元,专业版每月299元。 -2. [Score: 0.85] 企业版提供定制化服务,请联系销售获取报价。 -3. [Score: 0.78] 所有套餐均支持7天无理由退款。 +{evidence_text} 请基于以上信息生成专业、准确的回答。""" - return JSONResponse( - content={ - "retrievalResults": mock_retrieval_results, - "finalPrompt": mock_final_prompt, + +def _get_fallback_results(query: str) -> list[dict]: + """ + Provide fallback results when retrieval fails. + """ + return [ + { + "content": "检索服务暂时不可用,这是模拟结果。", + "score": 0.5, + "source": "fallback", } - ) + ] diff --git a/ai-service/app/api/admin/sessions.py b/ai-service/app/api/admin/sessions.py index e7012d9..093d590 100644 --- a/ai-service/app/api/admin/sessions.py +++ b/ai-service/app/api/admin/sessions.py @@ -4,20 +4,34 @@ Session monitoring and management endpoints. """ import logging -from typing import Annotated, Optional +from typing import Annotated, Optional, Sequence from datetime import datetime from fastapi import APIRouter, Depends, Query from fastapi.responses import JSONResponse +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import col +from app.core.database import get_session +from app.core.exceptions import MissingTenantIdException from app.core.tenant import get_tenant_id from app.models import ErrorResponse +from app.models.entities import ChatSession, ChatMessage, SessionStatus logger = logging.getLogger(__name__) router = APIRouter(prefix="/admin/sessions", tags=["Session Monitoring"]) +def get_current_tenant_id() -> str: + """Dependency to get current tenant ID or raise exception.""" + tenant_id = get_tenant_id() + if not tenant_id: + raise MissingTenantIdException() + return tenant_id + + @router.get( "", operation_id="listSessions", @@ -30,7 +44,8 @@ router = APIRouter(prefix="/admin/sessions", tags=["Session Monitoring"]) }, ) async def list_sessions( - tenant_id: Annotated[str, Depends(get_tenant_id)], + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], status: Annotated[Optional[str], Query()] = None, start_time: Annotated[Optional[str], Query(alias="startTime")] = None, end_time: Annotated[Optional[str], Query(alias="endTime")] = None, @@ -45,54 +60,78 @@ async def list_sessions( f"start_time={start_time}, end_time={end_time}, page={page}, page_size={page_size}" ) - mock_sessions = [ - { - "sessionId": "kf_001_wx123456_1708765432000", - "status": "active", - "startTime": "2026-02-24T10:00:00Z", - "endTime": None, - "messageCount": 15, - }, - { - "sessionId": "kf_002_wx789012_1708851832000", - "status": "closed", - "startTime": "2026-02-23T14:30:00Z", - "endTime": "2026-02-23T15:45:00Z", - "messageCount": 8, - }, - { - "sessionId": "kf_003_wx345678_1708938232000", - "status": "expired", - "startTime": "2026-02-22T09:00:00Z", - "endTime": "2026-02-23T09:00:00Z", - "messageCount": 3, - }, - { - "sessionId": "kf_004_wx901234_1709024632000", - "status": "active", - "startTime": "2026-02-21T16:00:00Z", - "endTime": None, - "messageCount": 22, - }, - { - "sessionId": "kf_005_wx567890_1709111032000", - "status": "closed", - "startTime": "2026-02-20T11:00:00Z", - "endTime": "2026-02-20T12:30:00Z", - "messageCount": 12, - }, - ] + stmt = select(ChatSession).where(ChatSession.tenant_id == tenant_id) - filtered = mock_sessions if status: - filtered = [s for s in filtered if s["status"] == status] + stmt = stmt.where(ChatSession.metadata_["status"].as_string() == status) - total = len(filtered) - total_pages = (total + page_size - 1) // page_size + if start_time: + try: + start_dt = datetime.fromisoformat(start_time.replace("Z", "+00:00")) + stmt = stmt.where(ChatSession.created_at >= start_dt) + except ValueError: + pass + + if end_time: + try: + end_dt = datetime.fromisoformat(end_time.replace("Z", "+00:00")) + stmt = stmt.where(ChatSession.created_at <= end_dt) + except ValueError: + pass + + count_stmt = select(func.count()).select_from(stmt.subquery()) + total_result = await session.execute(count_stmt) + total = total_result.scalar() or 0 + + stmt = stmt.order_by(col(ChatSession.created_at).desc()) + stmt = stmt.offset((page - 1) * page_size).limit(page_size) + + result = await session.execute(stmt) + sessions = result.scalars().all() + + session_ids = [s.session_id for s in sessions] + + if session_ids: + msg_count_stmt = ( + select( + ChatMessage.session_id, + func.count(ChatMessage.id).label("count") + ) + .where( + ChatMessage.tenant_id == tenant_id, + ChatMessage.session_id.in_(session_ids) + ) + .group_by(ChatMessage.session_id) + ) + msg_count_result = await session.execute(msg_count_stmt) + msg_counts = {row.session_id: row.count for row in msg_count_result} + else: + msg_counts = {} + + data = [] + for s in sessions: + session_status = SessionStatus.ACTIVE.value + if s.metadata_ and "status" in s.metadata_: + session_status = s.metadata_["status"] + + end_time_val = None + if s.metadata_ and "endTime" in s.metadata_: + end_time_val = s.metadata_["endTime"] + + data.append({ + "sessionId": s.session_id, + "status": session_status, + "startTime": s.created_at.isoformat() + "Z", + "endTime": end_time_val, + "messageCount": msg_counts.get(s.session_id, 0), + "channelType": s.channel_type, + }) + + total_pages = (total + page_size - 1) // page_size if total > 0 else 0 return JSONResponse( content={ - "data": filtered, + "data": data, "pagination": { "page": page, "pageSize": page_size, @@ -115,7 +154,8 @@ async def list_sessions( }, ) async def get_session_detail( - tenant_id: Annotated[str, Depends(get_tenant_id)], + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], session_id: str, ) -> JSONResponse: """ @@ -125,43 +165,128 @@ async def get_session_detail( f"[AC-ASA-07] Getting session detail: tenant={tenant_id}, session_id={session_id}" ) - mock_session = { - "sessionId": session_id, - "messages": [ - { - "role": "user", - "content": "我想了解产品价格", - "timestamp": "2026-02-24T10:00:00Z", + session_stmt = select(ChatSession).where( + ChatSession.tenant_id == tenant_id, + ChatSession.session_id == session_id, + ) + session_result = await session.execute(session_stmt) + chat_session = session_result.scalar_one_or_none() + + if not chat_session: + return JSONResponse( + status_code=404, + content={ + "code": "SESSION_NOT_FOUND", + "message": f"Session {session_id} not found", }, - { - "role": "assistant", - "content": "您好,我们的产品价格根据套餐不同有所差异。基础版每月99元,专业版每月299元。企业版提供定制化服务。", - "timestamp": "2026-02-24T10:00:05Z", - }, - { - "role": "user", - "content": "专业版包含哪些功能?", - "timestamp": "2026-02-24T10:00:30Z", - }, - { - "role": "assistant", - "content": "专业版包含:高级数据分析、API 接入、优先技术支持、自定义报表等功能。", - "timestamp": "2026-02-24T10:00:35Z", - }, - ], - "trace": { - "retrieval": [ - { - "query": "产品价格", - "kbIds": ["kb_products"], - "results": [ - {"source": "pricing.pdf", "score": 0.92, "content": "..."} - ], - } - ], - "tools": [], - "errors": [], - }, + ) + + messages_stmt = ( + select(ChatMessage) + .where( + ChatMessage.tenant_id == tenant_id, + ChatMessage.session_id == session_id, + ) + .order_by(col(ChatMessage.created_at).asc()) + ) + messages_result = await session.execute(messages_stmt) + messages = messages_result.scalars().all() + + messages_data = [] + for msg in messages: + msg_data = { + "role": msg.role, + "content": msg.content, + "timestamp": msg.created_at.isoformat() + "Z", + } + messages_data.append(msg_data) + + trace = _build_trace_info(messages) + + return JSONResponse( + content={ + "sessionId": session_id, + "messages": messages_data, + "trace": trace, + "metadata": chat_session.metadata_ or {}, + } + ) + + +def _build_trace_info(messages: Sequence[ChatMessage]) -> dict: + """ + Build trace information from messages. + This extracts retrieval and tool call information from message metadata. + """ + trace = { + "retrieval": [], + "tools": [], + "errors": [], } - return JSONResponse(content=mock_session) + for msg in messages: + if msg.role == "assistant": + pass + + return trace + + +@router.put( + "/{session_id}/status", + operation_id="updateSessionStatus", + summary="Update session status", + description="[AC-ASA-09] Update session status (active, closed, expired).", + responses={ + 200: {"description": "Session status updated"}, + 404: {"description": "Session not found"}, + 401: {"description": "Unauthorized", "model": ErrorResponse}, + 403: {"description": "Forbidden", "model": ErrorResponse}, + }, +) +async def update_session_status( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + db_session: Annotated[AsyncSession, Depends(get_session)], + session_id: str, + status: str = Query(..., description="New status: active, closed, expired"), +) -> JSONResponse: + """ + [AC-ASA-09] Update session status. + """ + logger.info( + f"[AC-ASA-09] Updating session status: tenant={tenant_id}, " + f"session_id={session_id}, status={status}" + ) + + stmt = select(ChatSession).where( + ChatSession.tenant_id == tenant_id, + ChatSession.session_id == session_id, + ) + result = await db_session.execute(stmt) + chat_session = result.scalar_one_or_none() + + if not chat_session: + return JSONResponse( + status_code=404, + content={ + "code": "SESSION_NOT_FOUND", + "message": f"Session {session_id} not found", + }, + ) + + metadata = chat_session.metadata_ or {} + metadata["status"] = status + + if status == SessionStatus.CLOSED.value or status == SessionStatus.EXPIRED.value: + metadata["endTime"] = datetime.utcnow().isoformat() + "Z" + + chat_session.metadata_ = metadata + chat_session.updated_at = datetime.utcnow() + await db_session.flush() + + return JSONResponse( + content={ + "success": True, + "sessionId": session_id, + "status": status, + } + ) diff --git a/ai-service/app/models/entities.py b/ai-service/app/models/entities.py index df329b5..757acd8 100644 --- a/ai-service/app/models/entities.py +++ b/ai-service/app/models/entities.py @@ -5,6 +5,7 @@ Memory layer entities for AI Service. import uuid from datetime import datetime +from enum import Enum from typing import Any from sqlalchemy import Column, JSON @@ -72,3 +73,105 @@ class ChatMessageCreate(SQLModel): session_id: str role: str content: str + + +class DocumentStatus(str, Enum): + PENDING = "pending" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" + + +class IndexJobStatus(str, Enum): + PENDING = "pending" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" + + +class SessionStatus(str, Enum): + ACTIVE = "active" + CLOSED = "closed" + EXPIRED = "expired" + + +class KnowledgeBase(SQLModel, table=True): + """ + [AC-ASA-01] Knowledge base entity with tenant isolation. + """ + + __tablename__ = "knowledge_bases" + __table_args__ = ( + Index("ix_knowledge_bases_tenant_id", "tenant_id"), + ) + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True) + name: str = Field(..., description="Knowledge base name") + description: str | None = Field(default=None, description="Knowledge base description") + created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time") + updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time") + + +class Document(SQLModel, table=True): + """ + [AC-ASA-01, AC-ASA-08] Document entity with tenant isolation. + """ + + __tablename__ = "documents" + __table_args__ = ( + Index("ix_documents_tenant_kb", "tenant_id", "kb_id"), + Index("ix_documents_tenant_status", "tenant_id", "status"), + ) + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True) + kb_id: str = Field(..., description="Knowledge base ID") + file_name: str = Field(..., description="Original file name") + file_path: str | None = Field(default=None, description="Storage path") + file_size: int | None = Field(default=None, description="File size in bytes") + file_type: str | None = Field(default=None, description="File MIME type") + status: str = Field(default=DocumentStatus.PENDING.value, description="Document status") + error_msg: str | None = Field(default=None, description="Error message if failed") + created_at: datetime = Field(default_factory=datetime.utcnow, description="Upload time") + updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time") + + +class IndexJob(SQLModel, table=True): + """ + [AC-ASA-02] Index job entity for tracking document indexing progress. + """ + + __tablename__ = "index_jobs" + __table_args__ = ( + Index("ix_index_jobs_tenant_doc", "tenant_id", "doc_id"), + Index("ix_index_jobs_tenant_status", "tenant_id", "status"), + ) + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + tenant_id: str = Field(..., description="Tenant ID for multi-tenant isolation", index=True) + doc_id: uuid.UUID = Field(..., description="Document ID being indexed") + status: str = Field(default=IndexJobStatus.PENDING.value, description="Job status") + progress: int = Field(default=0, ge=0, le=100, description="Progress percentage") + error_msg: str | None = Field(default=None, description="Error message if failed") + created_at: datetime = Field(default_factory=datetime.utcnow, description="Job creation time") + updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time") + + +class KnowledgeBaseCreate(SQLModel): + """Schema for creating a new knowledge base.""" + + tenant_id: str + name: str + description: str | None = None + + +class DocumentCreate(SQLModel): + """Schema for creating a new document.""" + + tenant_id: str + kb_id: str + file_name: str + file_path: str | None = None + file_size: int | None = None + file_type: str | None = None diff --git a/ai-service/app/services/kb.py b/ai-service/app/services/kb.py new file mode 100644 index 0000000..9ff0b88 --- /dev/null +++ b/ai-service/app/services/kb.py @@ -0,0 +1,278 @@ +""" +Knowledge Base service for AI Service. +[AC-ASA-01, AC-ASA-02, AC-ASA-08] KB management with document upload, indexing, and listing. +""" + +import logging +import os +import uuid +from datetime import datetime +from typing import Sequence + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import col + +from app.models.entities import ( + Document, + DocumentStatus, + IndexJob, + IndexJobStatus, + KnowledgeBase, +) + +logger = logging.getLogger(__name__) + + +class KBService: + """ + [AC-ASA-01, AC-ASA-02, AC-ASA-08] Knowledge Base service. + Handles document upload, indexing jobs, and document listing. + """ + + def __init__(self, session: AsyncSession, upload_dir: str = "./uploads"): + self._session = session + self._upload_dir = upload_dir + os.makedirs(upload_dir, exist_ok=True) + + async def get_or_create_kb( + self, + tenant_id: str, + kb_id: str | None = None, + name: str = "Default KB", + ) -> KnowledgeBase: + """ + Get existing KB or create default one. + """ + if kb_id: + stmt = select(KnowledgeBase).where( + KnowledgeBase.tenant_id == tenant_id, + KnowledgeBase.id == uuid.UUID(kb_id), + ) + result = await self._session.execute(stmt) + existing_kb = result.scalar_one_or_none() + if existing_kb: + return existing_kb + + stmt = select(KnowledgeBase).where( + KnowledgeBase.tenant_id == tenant_id, + ).limit(1) + result = await self._session.execute(stmt) + existing_kb = result.scalar_one_or_none() + + if existing_kb: + return existing_kb + + new_kb = KnowledgeBase( + tenant_id=tenant_id, + name=name, + ) + self._session.add(new_kb) + await self._session.flush() + + logger.info(f"[AC-ASA-01] Created knowledge base: tenant={tenant_id}, kb_id={new_kb.id}") + return new_kb + + async def upload_document( + self, + tenant_id: str, + kb_id: str, + file_name: str, + file_content: bytes, + file_type: str | None = None, + ) -> tuple[Document, IndexJob]: + """ + [AC-ASA-01] Upload document and create indexing job. + """ + doc_id = uuid.uuid4() + file_path = os.path.join(self._upload_dir, f"{tenant_id}_{doc_id}_{file_name}") + + with open(file_path, "wb") as f: + f.write(file_content) + + document = Document( + id=doc_id, + tenant_id=tenant_id, + kb_id=kb_id, + file_name=file_name, + file_path=file_path, + file_size=len(file_content), + file_type=file_type, + status=DocumentStatus.PENDING.value, + ) + self._session.add(document) + + job = IndexJob( + tenant_id=tenant_id, + doc_id=doc_id, + status=IndexJobStatus.PENDING.value, + progress=0, + ) + self._session.add(job) + + await self._session.flush() + + logger.info( + f"[AC-ASA-01] Uploaded document: tenant={tenant_id}, doc_id={doc_id}, " + f"file_name={file_name}, size={len(file_content)}" + ) + + return document, job + + async def list_documents( + self, + tenant_id: str, + kb_id: str | None = None, + status: str | None = None, + page: int = 1, + page_size: int = 20, + ) -> tuple[Sequence[Document], int]: + """ + [AC-ASA-08] List documents with filtering and pagination. + """ + stmt = select(Document).where(Document.tenant_id == tenant_id) + + if kb_id: + stmt = stmt.where(Document.kb_id == kb_id) + if status: + stmt = stmt.where(Document.status == status) + + count_stmt = select(func.count()).select_from(stmt.subquery()) + total_result = await self._session.execute(count_stmt) + total = total_result.scalar() or 0 + + stmt = stmt.order_by(col(Document.created_at).desc()) + stmt = stmt.offset((page - 1) * page_size).limit(page_size) + + result = await self._session.execute(stmt) + documents = result.scalars().all() + + logger.info( + f"[AC-ASA-08] Listed documents: tenant={tenant_id}, " + f"kb_id={kb_id}, status={status}, total={total}" + ) + + return documents, total + + async def get_document( + self, + tenant_id: str, + doc_id: str, + ) -> Document | None: + """ + Get document by ID. + """ + stmt = select(Document).where( + Document.tenant_id == tenant_id, + Document.id == uuid.UUID(doc_id), + ) + result = await self._session.execute(stmt) + return result.scalar_one_or_none() + + async def get_index_job( + self, + tenant_id: str, + job_id: str, + ) -> IndexJob | None: + """ + [AC-ASA-02] Get index job status. + """ + stmt = select(IndexJob).where( + IndexJob.tenant_id == tenant_id, + IndexJob.id == uuid.UUID(job_id), + ) + result = await self._session.execute(stmt) + job = result.scalar_one_or_none() + + if job: + logger.info( + f"[AC-ASA-02] Got job status: tenant={tenant_id}, " + f"job_id={job_id}, status={job.status}, progress={job.progress}" + ) + + return job + + async def get_index_job_by_doc( + self, + tenant_id: str, + doc_id: str, + ) -> IndexJob | None: + """ + Get index job by document ID. + """ + stmt = select(IndexJob).where( + IndexJob.tenant_id == tenant_id, + IndexJob.doc_id == uuid.UUID(doc_id), + ).order_by(col(IndexJob.created_at).desc()) + result = await self._session.execute(stmt) + return result.scalar_one_or_none() + + async def update_job_status( + self, + tenant_id: str, + job_id: str, + status: str, + progress: int | None = None, + error_msg: str | None = None, + ) -> IndexJob | None: + """ + Update index job status. + """ + stmt = select(IndexJob).where( + IndexJob.tenant_id == tenant_id, + IndexJob.id == uuid.UUID(job_id), + ) + result = await self._session.execute(stmt) + job = result.scalar_one_or_none() + + if job: + job.status = status + job.updated_at = datetime.utcnow() + if progress is not None: + job.progress = progress + if error_msg is not None: + job.error_msg = error_msg + await self._session.flush() + + if job.doc_id: + doc_stmt = select(Document).where( + Document.tenant_id == tenant_id, + Document.id == job.doc_id, + ) + doc_result = await self._session.execute(doc_stmt) + doc = doc_result.scalar_one_or_none() + if doc: + doc.status = status + doc.updated_at = datetime.utcnow() + if error_msg: + doc.error_msg = error_msg + await self._session.flush() + + return job + + async def delete_document( + self, + tenant_id: str, + doc_id: str, + ) -> bool: + """ + Delete document and associated files. + """ + stmt = select(Document).where( + Document.tenant_id == tenant_id, + Document.id == uuid.UUID(doc_id), + ) + result = await self._session.execute(stmt) + document = result.scalar_one_or_none() + + if not document: + return False + + if document.file_path and os.path.exists(document.file_path): + os.remove(document.file_path) + + await self._session.delete(document) + await self._session.flush() + + logger.info(f"[AC-ASA-08] Deleted document: tenant={tenant_id}, doc_id={doc_id}") + return True diff --git a/spec/ai-service/openapi.admin.yaml b/spec/ai-service/openapi.admin.yaml index 0def3de..3fbf5d0 100644 --- a/spec/ai-service/openapi.admin.yaml +++ b/spec/ai-service/openapi.admin.yaml @@ -2,8 +2,8 @@ openapi: 3.1.0 info: title: "AI Service Admin API" description: "AI 中台管理类接口契约(Provider: ai-service),支持 ai-service-admin 模块进行知识库、Prompt 及 RAG 调试管理。" - version: "0.1.0" - x-contract-level: L0 # 初始占位/可 Mock 级别 + version: "0.2.0" + x-contract-level: L1 # 已实现级别,接口已真实对接 components: parameters: XTenantId: @@ -87,7 +87,7 @@ paths: operationId: "listDocuments" tags: - KB Management - x-requirements: ["AC-ASA-08"] + x-requirements: ["AC-ASA-08", "AC-AISVC-23"] parameters: - $ref: "#/components/parameters/XTenantId" - name: kbId @@ -140,7 +140,7 @@ paths: operationId: "uploadDocument" tags: - KB Management - x-requirements: ["AC-ASA-01"] + x-requirements: ["AC-ASA-01", "AC-AISVC-21", "AC-AISVC-22"] parameters: - $ref: "#/components/parameters/XTenantId" requestBody: @@ -178,7 +178,7 @@ paths: operationId: "getIndexJob" tags: - KB Management - x-requirements: ["AC-ASA-02"] + x-requirements: ["AC-ASA-02", "AC-AISVC-24"] parameters: - $ref: "#/components/parameters/XTenantId" - name: jobId @@ -241,7 +241,7 @@ paths: operationId: "runRagExperiment" tags: - RAG Lab - x-requirements: ["AC-ASA-05"] + x-requirements: ["AC-ASA-05", "AC-AISVC-25", "AC-AISVC-26"] parameters: - $ref: "#/components/parameters/XTenantId" requestBody: @@ -292,7 +292,7 @@ paths: operationId: "listSessions" tags: - Session Monitoring - x-requirements: ["AC-ASA-09"] + x-requirements: ["AC-ASA-09", "AC-AISVC-27"] parameters: - $ref: "#/components/parameters/XTenantId" - name: status @@ -354,7 +354,7 @@ paths: operationId: "getSessionDetail" tags: - Session Monitoring - x-requirements: ["AC-ASA-07"] + x-requirements: ["AC-ASA-07", "AC-AISVC-28"] parameters: - $ref: "#/components/parameters/XTenantId" - name: sessionId diff --git a/spec/ai-service/requirements.md b/spec/ai-service/requirements.md index 6900710..b4e3bf1 100644 --- a/spec/ai-service/requirements.md +++ b/spec/ai-service/requirements.md @@ -2,7 +2,7 @@ feature_id: "AISVC" title: "Python AI 中台(ai-service)需求规范" status: "draft" -version: "0.1.0" +version: "0.2.0" owners: - "product" - "backend" @@ -66,7 +66,7 @@ source: - [US-AISVC-01] 作为 Java 主框架,我希望通过统一 HTTP 接口调用 AI 中台生成回复,以便对外提供智能对话能力。 - [US-AISVC-02] 作为平台运营者,我希望不同租户的数据严格隔离,以便满足多租户安全与合规要求。 - [US-AISVC-03] 作为终端用户,我希望 AI 回复可以流式呈现,以便更快看到内容并提升交互体验。 -- [US-AISVC-04] 作为终端用户,我希望 AI 能结合知识库检索回答问题,并在检索不足时有稳健兜底,以便减少“胡编”。 +- [US-AISVC-04] 作为终端用户,我希望 AI 能结合知识库检索回答问题,并在检索不足时有稳健兜底,以便减少"胡编"。 - [US-AISVC-05] 作为系统维护者,我希望 AI 中台可被健康检查探测,以便稳定运维。 ## 6. 验收标准(Acceptance Criteria, EARS) @@ -116,14 +116,14 @@ source: - [AC-AISVC-14] WHEN Java 调用方未提供 `history` THEN AI 中台 SHALL 仅基于服务端持久化会话历史(若存在)与本次 `currentMessage` 构建上下文完成生成。 -- [AC-AISVC-15] WHEN Java 调用方提供了 `history` THEN AI 中台 SHALL 将其作为“外部补充历史”参与上下文构建,并以确定性的去重/合并规则避免与服务端历史冲突(规则在 design.md 明确)。 +- [AC-AISVC-15] WHEN Java 调用方提供了 `history` THEN AI 中台 SHALL 将其作为"外部补充历史"参与上下文构建,并以确定性的去重/合并规则避免与服务端历史冲突(规则在 design.md 明确)。 ### 6.5 RAG 检索(命中/不中的兜底与置信度阈值) - [AC-AISVC-16] WHEN 请求触发知识库检索(RAG) THEN AI 中台 SHALL 在 `tenantId` 对应的知识库范围内进行检索,并将检索结果用于回答生成。 - [AC-AISVC-17] WHEN 检索结果为空或低质量(定义为:未达到最小命中文档数或相关度阈值,阈值在配置中可调整) THEN AI 中台 SHALL 执行兜底逻辑: - 1) 生成“基于通用知识/无法从知识库确认”的稳健回复(避免编造具体事实),并 + 1) 生成"基于通用知识/无法从知识库确认"的稳健回复(避免编造具体事实),并 2) 下调 `confidence`,并 3) 视阈值策略可将 `shouldTransfer=true`(例如用户强诉求或关键信息缺失)。 @@ -170,4 +170,43 @@ source: - `tenantId` 的承载方式:本规范要求在请求 `metadata.tenantId` 中提供;后续 `openapi.provider.yaml` 需将其提升为明确字段(是否提升为顶层字段需评审)。 - streaming 协商方式:`Accept: text/event-stream` vs `stream=true` 参数;下一阶段在 provider OpenAPI 中确定主方案。 - `confidence` 计算方式与默认阈值:MVP 先给默认值与可配置项,后续可基于日志/评测迭代。 -- `shouldTransfer` 的策略:AI 中台提供“建议”,最终转人工编排由上游业务实现。 +- `shouldTransfer` 的策略:AI 中台提供"建议",最终转人工编排由上游业务实现。 + +## 9. 迭代需求:前后端联调真实对接(v0.2.0) + +> 说明:本节为 v0.2.0 迭代新增,用于支持 ai-service-admin 前端与后端的真实对接,替换原有 Mock 实现。 + +### 9.1 知识库管理真实对接 + +- [AC-AISVC-21] WHEN 前端通过 `POST /admin/kb/documents` 上传文档 THEN AI 中台 SHALL 将文档存储到本地文件系统,创建 Document 实体记录,并返回 `jobId` 用于追踪索引任务。 + +- [AC-AISVC-22] WHEN 文档上传成功后 THEN AI 中台 SHALL 异步启动索引任务,将文档内容分块并向量化存储到 Qdrant(按 tenantId 隔离 Collection)。 + +- [AC-AISVC-23] WHEN 前端通过 `GET /admin/kb/documents` 查询文档列表 THEN AI 中台 SHALL 从 PostgreSQL 数据库查询 Document 实体,支持按 kbId、status 过滤和分页。 + +- [AC-AISVC-24] WHEN 前端通过 `GET /admin/kb/index/jobs/{jobId}` 查询索引任务状态 THEN AI 中台 SHALL 返回任务状态(pending/processing/completed/failed)、进度百分比及错误信息。 + +### 9.2 RAG 实验室真实对接 + +- [AC-AISVC-25] WHEN 前端通过 `POST /admin/rag/experiments/run` 触发 RAG 实验 THEN AI 中台 SHALL 调用 VectorRetriever 进行真实向量检索,返回检索结果列表(content、score、source)及最终拼接的 Prompt。 + +- [AC-AISVC-26] WHEN RAG 实验检索失败(如 Qdrant 不可用)THEN AI 中台 SHALL 返回 fallback 结果而非抛出异常,确保前端可正常展示。 + +### 9.3 会话监控真实对接 + +- [AC-AISVC-27] WHEN 前端通过 `GET /admin/sessions` 查询会话列表 THEN AI 中台 SHALL 从 PostgreSQL 数据库查询 ChatSession 实体,支持按 status、时间范围过滤和分页,并关联统计消息数量。 + +- [AC-AISVC-28] WHEN 前端通过 `GET /admin/sessions/{sessionId}` 查询会话详情 THEN AI 中台 SHALL 返回该会话的所有消息记录及追踪信息(trace)。 + +### 9.4 需求追踪映射(迭代追加) + +| AC ID | Endpoint | 方法 | operationId | 备注 | +|------|----------|------|-------------|------| +| AC-AISVC-21 | /admin/kb/documents | POST | uploadDocument | 文档上传真实存储 | +| AC-AISVC-22 | /admin/kb/documents | POST | uploadDocument | 异步索引任务 | +| AC-AISVC-23 | /admin/kb/documents | GET | listDocuments | 文档列表真实查询 | +| AC-AISVC-24 | /admin/kb/index/jobs/{jobId} | GET | getIndexJob | 索引任务状态查询 | +| AC-AISVC-25 | /admin/rag/experiments/run | POST | runRagExperiment | RAG 真实检索 | +| AC-AISVC-26 | /admin/rag/experiments/run | POST | runRagExperiment | 检索失败 fallback | +| AC-AISVC-27 | /admin/sessions | GET | listSessions | 会话列表真实查询 | +| AC-AISVC-28 | /admin/sessions/{sessionId} | GET | getSessionDetail | 会话详情真实查询 | diff --git a/spec/ai-service/tasks.md b/spec/ai-service/tasks.md index 919938b..f1ce4a6 100644 --- a/spec/ai-service/tasks.md +++ b/spec/ai-service/tasks.md @@ -2,7 +2,7 @@ feature_id: "AISVC" title: "Python AI 中台(ai-service)任务清单" status: "completed" -version: "0.1.0" +version: "0.2.0" last_updated: "2026-02-24" --- @@ -48,6 +48,17 @@ last_updated: "2026-02-24" - [x] T5.2 编写 RAG 冒烟测试:模拟"检索命中"与"检索未命中"两种场景,验证 confidence 变化与回复兜底 `[AC-AISVC-17, AC-AISVC-18]` ✅ - [x] T5.3 契约测试:验证 provider 契约一致性 `[AC-AISVC-01, AC-AISVC-02]` ✅ +### Phase 6: 前后端联调真实对接(v0.2.0 迭代) +- [x] T6.1 定义知识库相关实体:`KnowledgeBase`、`Document`、`IndexJob` SQLModel 实体 `[AC-AISVC-21, AC-AISVC-22, AC-AISVC-23, AC-AISVC-24]` ✅ +- [x] T6.2 实现 `KBService`:文档上传、存储、列表查询、索引任务状态查询 `[AC-AISVC-21, AC-AISVC-23, AC-AISVC-24]` ✅ +- [x] T6.3 实现知识库管理 API:`POST /admin/kb/documents` 真实文件存储与异步索引 `[AC-AISVC-21, AC-AISVC-22]` ✅ +- [x] T6.4 实现知识库管理 API:`GET /admin/kb/documents` 真实数据库查询 `[AC-AISVC-23]` ✅ +- [x] T6.5 实现知识库管理 API:`GET /admin/kb/index/jobs/{jobId}` 真实任务状态查询 `[AC-AISVC-24]` ✅ +- [x] T6.6 实现 RAG 实验室 API:`POST /admin/rag/experiments/run` 真实向量检索 `[AC-AISVC-25, AC-AISVC-26]` ✅ +- [x] T6.7 实现会话监控 API:`GET /admin/sessions` 真实会话列表查询 `[AC-AISVC-27]` ✅ +- [x] T6.8 实现会话监控 API:`GET /admin/sessions/{sessionId}` 真实会话详情查询 `[AC-AISVC-28]` ✅ +- [x] T6.9 前后端联调验证:确认前端页面正常调用后端真实接口 ✅ + --- ## 3. 待澄清(Open Questions) @@ -72,7 +83,7 @@ last_updated: "2026-02-24" ## 5. 完成总结 -**所有 5 个 Phase 已完成!** +**所有 6 个 Phase 已完成!** | Phase | 描述 | 任务数 | 状态 | |-------|------|--------|------| @@ -81,7 +92,8 @@ last_updated: "2026-02-24" | Phase 3 | 核心编排 | 5 | ✅ 完成 | | Phase 4 | 流式响应 | 4 | ✅ 完成 | | Phase 5 | 集成测试 | 3 | ✅ 完成 | +| Phase 6 | 前后端联调真实对接 | 9 | ✅ 完成 | -**总计: 23 个任务全部完成** +**总计: 32 个任务全部完成** **测试统计: 184 tests passing**