234 lines
7.7 KiB
Python
234 lines
7.7 KiB
Python
"""
|
|
Retrieval Strategy Integration for Dialogue Flow.

[AC-AISVC-RES-01~15] Integrates StrategyRouter and ModeRouter into the dialogue pipeline.

Usage:
    from app.services.retrieval.strategy_integration import RetrievalStrategyIntegration

    integration = RetrievalStrategyIntegration()
    result = await integration.execute(ctx)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from app.services.retrieval.routing_config import (
|
|
RagRuntimeMode,
|
|
StrategyType,
|
|
RoutingConfig,
|
|
StrategyContext,
|
|
StrategyResult,
|
|
)
|
|
from app.services.retrieval.strategy_router import (
|
|
StrategyRouter,
|
|
get_strategy_router,
|
|
)
|
|
from app.services.retrieval.mode_router import (
|
|
ModeRouter,
|
|
ModeRouteResult,
|
|
get_mode_router,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from app.services.retrieval.base import RetrievalResult
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class RetrievalStrategyResult:
    """Outcome of one routed retrieval: what was retrieved plus how it was routed.

    The first four fields are required; the remaining ones default to an
    "no fallback, no extra information" state.
    """

    # --- required -----------------------------------------------------
    # Retrieval output, or None when nothing was retrieved.
    retrieval_result: "RetrievalResult | None"
    # Answer text produced by the mode router, or None when it produced none.
    final_answer: str | None
    # Strategy chosen by the strategy router.
    strategy: StrategyType
    # Runtime mode chosen by the mode router.
    mode: RagRuntimeMode

    # --- optional -----------------------------------------------------
    # Fallback flag and its human-readable reason.
    should_fallback: bool = False
    fallback_reason: str | None = None
    # Raw mode-routing result, kept for callers that need the details.
    mode_route_result: ModeRouteResult | None = None
    # Free-form diagnostics; a fresh dict is created per instance.
    diagnostics: dict[str, Any] = field(default_factory=dict)
    # Elapsed time of the execution in milliseconds.
    duration_ms: int = 0
|
|
|
|
|
|
class RetrievalStrategyIntegration:
    """
    [AC-AISVC-RES-01~15] Integration layer for retrieval strategy.

    Combines StrategyRouter and ModeRouter to provide a unified interface
    for the dialogue pipeline.

    Flow:
        1. StrategyRouter selects default or enhanced strategy
        2. ModeRouter selects direct, react, or auto mode
        3. Execute retrieval with selected strategy and mode
        4. Handle fallback scenarios
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
        strategy_router: StrategyRouter | None = None,
        mode_router: ModeRouter | None = None,
    ):
        """
        Wire up the configuration and the two routers.

        Args:
            config: Routing configuration; a default ``RoutingConfig()`` is
                created when omitted.
            strategy_router: Router to use; falls back to
                ``get_strategy_router()`` when omitted.
            mode_router: Router to use; falls back to ``get_mode_router()``
                when omitted.
        """
        self._config = config or RoutingConfig()
        self._strategy_router = strategy_router or get_strategy_router()
        self._mode_router = mode_router or get_mode_router()

    @property
    def config(self) -> RoutingConfig:
        """Get current configuration."""
        return self._config

    def update_config(self, new_config: RoutingConfig) -> None:
        """
        [AC-AISVC-RES-15] Update all routing configurations.

        Stores ``new_config`` on this instance and forwards it to both the
        strategy router and the mode router.

        Args:
            new_config: Configuration to apply.
        """
        self._config = new_config
        self._strategy_router.update_config(new_config)
        self._mode_router.update_config(new_config)

        logger.info(
            f"[AC-AISVC-RES-15] RetrievalStrategyIntegration config updated: "
            f"strategy={new_config.strategy.value}, mode={new_config.rag_runtime_mode.value}"
        )

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> RetrievalStrategyResult:
        """
        Execute retrieval with strategy and mode routing.

        Args:
            ctx: Strategy context with tenant, query, metadata, etc.

        Returns:
            RetrievalStrategyResult with retrieval results and diagnostics

        Raises:
            Exception: Any error raised during mode execution is re-raised
                unchanged when the routed strategy is not ``ENHANCED``. For
                ``ENHANCED`` the strategy router is rolled back and a plain
                retrieval is attempted instead; errors from that fallback
                retrieval propagate as well.
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed (or made negative) by wall-clock adjustments.
        started = time.perf_counter()

        strategy_result = self._strategy_router.route(ctx)
        mode_result = self._mode_router.route(ctx)

        logger.info(
            f"[AC-AISVC-RES-01~15] Strategy routing: "
            f"strategy={strategy_result.strategy.value}, mode={mode_result.mode.value}, "
            f"tenant={ctx.tenant_id}, query_len={len(ctx.query)}"
        )

        retrieval_result = None
        final_answer = None
        should_fallback = False
        fallback_reason = None

        try:
            if mode_result.mode == RagRuntimeMode.REACT:
                # The third element of the tuple is not used here.
                final_answer, retrieval_result, _ = await self._mode_router.execute_react(ctx)
            else:
                # Every non-REACT mode (DIRECT included) goes through the
                # same call; the mode router returns an updated route
                # result that replaces the initial one.
                retrieval_result, final_answer, mode_result = (
                    await self._mode_router.execute_with_fallback(ctx)
                )
                should_fallback = mode_result.should_fallback_to_react
                fallback_reason = mode_result.fallback_reason

        except Exception as e:
            # exc_info keeps the traceback in the log; the message itself is
            # unchanged.
            logger.error(
                f"[AC-AISVC-RES-07] Retrieval strategy execution failed: {e}",
                exc_info=True,
            )

            # Only the enhanced strategy has something to roll back to; any
            # other strategy surfaces the error to the caller.
            if strategy_result.strategy != StrategyType.ENHANCED:
                raise

            self._strategy_router.rollback(
                reason=str(e),
                tenant_id=ctx.tenant_id,
            )
            retrieval_result = await self._retrieve_after_rollback(ctx)
            should_fallback = True
            fallback_reason = str(e)

        duration_ms = int((time.perf_counter() - started) * 1000)

        return RetrievalStrategyResult(
            retrieval_result=retrieval_result,
            final_answer=final_answer,
            strategy=strategy_result.strategy,
            mode=mode_result.mode,
            should_fallback=should_fallback,
            fallback_reason=fallback_reason,
            mode_route_result=mode_result,
            diagnostics={
                "strategy_diagnostics": strategy_result.diagnostics,
                "mode_diagnostics": mode_result.diagnostics,
                "duration_ms": duration_ms,
            },
            duration_ms=duration_ms,
        )

    async def _retrieve_after_rollback(self, ctx: StrategyContext) -> "RetrievalResult":
        """Run a plain retrieval through the optimized retriever for ``ctx``."""
        # Imported at call time rather than module level.
        # NOTE(review): presumably to avoid an import cycle - confirm.
        from app.services.retrieval.optimized_retriever import get_optimized_retriever
        from app.services.retrieval.base import RetrievalContext

        retriever = await get_optimized_retriever()
        retrieval_ctx = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )
        return await retriever.retrieve(retrieval_ctx)

    def get_current_strategy(self) -> StrategyType:
        """Get current active strategy."""
        return self._strategy_router.current_strategy

    def get_rollback_records(self, limit: int = 10) -> list[dict[str, Any]]:
        """
        Get recent rollback records.

        Args:
            limit: Passed through to the strategy router's
                ``get_rollback_records``.

        Returns:
            One plain dict per record with the keys ``timestamp``,
            ``from_strategy``, ``to_strategy``, ``reason`` and ``tenant_id``;
            the two strategy entries hold the enum ``.value``.
        """
        records = self._strategy_router.get_rollback_records(limit)
        return [
            {
                "timestamp": r.timestamp,
                "from_strategy": r.from_strategy.value,
                "to_strategy": r.to_strategy.value,
                "reason": r.reason,
                "tenant_id": r.tenant_id,
            }
            for r in records
        ]

    def validate_config(self) -> tuple[bool, list[str]]:
        """Validate current configuration by delegating to ``config.validate()``."""
        return self._config.validate()
|
|
|
|
|
|
# Process-wide instance, created lazily on first request.
_integration: RetrievalStrategyIntegration | None = None


def get_retrieval_strategy_integration() -> RetrievalStrategyIntegration:
    """Return the shared RetrievalStrategyIntegration, building it on first use."""
    global _integration
    if _integration is not None:
        return _integration
    _integration = RetrievalStrategyIntegration()
    return _integration
|
|
|
|
|
|
def reset_retrieval_strategy_integration() -> None:
    """Drop the cached RetrievalStrategyIntegration singleton (test helper)."""
    global _integration
    # Clearing the module-level reference is all a reset needs to do.
    _integration = None
|