ai-robot-core/ai-service/app/services/intent/clarification.py

386 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Clarification mechanism for intent recognition.
[AC-CLARIFY] Clarification mechanism implementation.
Core features:
1. Unified confidence computation
2. Hard-block rules: confidence check and required_slots check
3. Clarification state management
4. Instrumentation metrics collection
"""
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
import uuid
# Module-level logger used by the classes in this file.
logger = logging.getLogger(__name__)
# Default upper confidence threshold (default of ClarificationEngine's ``t_high``).
T_HIGH = 0.75
# Default lower confidence threshold (default of ClarificationEngine's ``t_low``).
T_LOW = 0.45
# Retry count at which ClarifyState.is_max_retry() starts returning True.
MAX_CLARIFY_RETRY = 3
class ClarifyReason(str, Enum):
    """Reason a clarification question is raised.

    The ``str`` mix-in makes every member compare equal to its plain string
    value, so members can be used directly where a string is expected.
    """

    # Several intents are plausible and none clearly wins.
    INTENT_AMBIGUITY = "intent_ambiguity"
    # The intent is known but a required slot has not been filled.
    MISSING_SLOT = "missing_slot"
    # The recognition confidence is too low to act on.
    LOW_CONFIDENCE = "low_confidence"
    # The message appears to carry more than one intent.
    MULTI_INTENT = "multi_intent"
class ClarifyMetrics:
    """Singleton holding the clarification counters.

    Every call to ``ClarifyMetrics()`` returns the same instance, so all
    callers update and read one shared set of counters:

    * trigger count  - clarifications that were started
    * converge count - clarifications that ended with a resolved intent
    * misroute count - clarifications abandoned after too many retries
    """

    _instance = None

    def __new__(cls):
        # Create and zero the shared instance on first use only; later calls
        # must not reset the counters.
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._clarify_trigger_count = 0
            instance._clarify_converge_count = 0
            instance._misroute_count = 0
            cls._instance = instance
        return cls._instance

    def record_clarify_trigger(self) -> None:
        """Count one triggered clarification."""
        self._clarify_trigger_count += 1
        logger.debug(
            "[AC-CLARIFY-METRICS] clarify_trigger_count: %s",
            self._clarify_trigger_count,
        )

    def record_clarify_converge(self) -> None:
        """Count one clarification that converged on an intent."""
        self._clarify_converge_count += 1
        logger.debug(
            "[AC-CLARIFY-METRICS] clarify_converge_count: %s",
            self._clarify_converge_count,
        )

    def record_misroute(self) -> None:
        """Count one clarification that was given up on."""
        self._misroute_count += 1
        logger.debug(
            "[AC-CLARIFY-METRICS] misroute_count: %s",
            self._misroute_count,
        )

    def get_metrics(self) -> dict[str, int]:
        """Return the raw counters.

        NOTE: the keys end in ``_rate`` but the values are absolute counts.
        The key names are kept as-is so existing consumers keep working; use
        :meth:`get_rates` for actual ratios.
        """
        return {
            "clarify_trigger_rate": self._clarify_trigger_count,
            "clarify_converge_rate": self._clarify_converge_count,
            "misroute_rate": self._misroute_count,
        }

    def get_rates(self, total_requests: int) -> dict[str, float]:
        """Return the counters as ratios.

        Args:
            total_requests: Number of requests handled in the measured window.

        Returns:
            ``clarify_trigger_rate`` and ``misroute_rate`` relative to
            ``total_requests``; ``clarify_converge_rate`` relative to the
            number of triggered clarifications (0.0 when none were
            triggered). All rates are 0.0 when ``total_requests`` is not
            positive.
        """
        if total_requests <= 0:
            return {
                "clarify_trigger_rate": 0.0,
                "clarify_converge_rate": 0.0,
                "misroute_rate": 0.0,
            }

        triggered = self._clarify_trigger_count
        # Convergence is measured against the clarifications that were
        # actually started, which is what the "triggered > 0" guard protects.
        converge_rate = (
            self._clarify_converge_count / triggered if triggered > 0 else 0.0
        )
        return {
            "clarify_trigger_rate": triggered / total_requests,
            "clarify_converge_rate": converge_rate,
            "misroute_rate": self._misroute_count / total_requests,
        }

    def reset(self) -> None:
        """Set every counter back to zero."""
        self._clarify_trigger_count = 0
        self._clarify_converge_count = 0
        self._misroute_count = 0
def get_clarify_metrics() -> ClarifyMetrics:
    """Return the shared ``ClarifyMetrics`` instance."""
    return ClarifyMetrics()
@dataclass
class IntentCandidate:
intent_id: str
intent_name: str
confidence: float
response_type: str | None = None
target_kb_ids: list[str] | None = None
flow_id: str | None = None
fixed_reply: str | None = None
transfer_message: str | None = None
def to_dict(self) -> dict[str, Any]:
return {
"intent_id": self.intent_id,
"intent_name": self.intent_name,
"confidence": self.confidence,
"response_type": self.response_type,
"target_kb_ids": self.target_kb_ids,
"flow_id": self.flow_id,
"fixed_reply": self.fixed_reply,
"transfer_message": self.transfer_message,
}
@dataclass
class HybridIntentResult:
    """Outcome of intent recognition, including any clarification request.

    ``intent`` is the selected candidate (or ``None``), ``candidates`` lists
    every candidate considered, and ``need_clarify`` / ``clarify_reason`` /
    ``missing_slots`` describe whether and why the user must be asked a
    follow-up question.
    """

    intent: IntentCandidate | None
    confidence: float
    candidates: list[IntentCandidate] = field(default_factory=list)
    need_clarify: bool = False
    clarify_reason: ClarifyReason | None = None
    missing_slots: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return the result as a plain dict; nested candidates are converted too."""
        return {
            "intent": self.intent.to_dict() if self.intent else None,
            "confidence": self.confidence,
            "candidates": [c.to_dict() for c in self.candidates],
            "need_clarify": self.need_clarify,
            "clarify_reason": self.clarify_reason.value if self.clarify_reason else None,
            "missing_slots": self.missing_slots,
        }

    @staticmethod
    def _candidate_from_rule(rule: Any, confidence: float) -> IntentCandidate:
        """Build an ``IntentCandidate`` from a rule-like object.

        ``rule`` must expose ``id`` and ``name``; the routing attributes are
        optional and default to ``None`` when absent.
        """
        flow_id = getattr(rule, "flow_id", None)
        return IntentCandidate(
            intent_id=str(rule.id),
            intent_name=rule.name,
            confidence=confidence,
            response_type=getattr(rule, "response_type", None),
            target_kb_ids=getattr(rule, "target_kb_ids", None),
            flow_id=str(flow_id) if flow_id else None,
            fixed_reply=getattr(rule, "fixed_reply", None),
            transfer_message=getattr(rule, "transfer_message", None),
        )

    @classmethod
    def from_fusion_result(cls, fusion_result: Any) -> "HybridIntentResult":
        """Convert a fusion result object into a ``HybridIntentResult``.

        Args:
            fusion_result: Object exposing ``clarify_candidates``,
                ``final_intent``, ``final_confidence``, ``need_clarify`` and
                ``decision_reason``.

        Returns:
            A result whose candidate list starts with the final intent (when
            there is one) carrying ``final_confidence``, followed by the
            remaining clarify candidates with confidence 0.0.
        """
        candidates = [
            cls._candidate_from_rule(rule, 0.0)
            for rule in (fusion_result.clarify_candidates or [])
        ]

        final_intent = fusion_result.final_intent
        if final_intent:
            final_candidate = cls._candidate_from_rule(
                final_intent, fusion_result.final_confidence
            )
            # The final intent may also be listed among the clarify
            # candidates. Drop that duplicate so the final intent is always
            # first and keeps its real confidence instead of the 0.0
            # placeholder.
            candidates = [final_candidate] + [
                c for c in candidates if c.intent_id != final_candidate.intent_id
            ]

        clarify_reason = None
        if fusion_result.need_clarify:
            if fusion_result.decision_reason == "multi_intent":
                clarify_reason = ClarifyReason.MULTI_INTENT
            elif fusion_result.decision_reason == "gray_zone":
                clarify_reason = ClarifyReason.INTENT_AMBIGUITY
            else:
                # Any other decision reason is reported as low confidence.
                clarify_reason = ClarifyReason.LOW_CONFIDENCE

        return cls(
            intent=candidates[0] if candidates else None,
            confidence=fusion_result.final_confidence,
            candidates=candidates,
            need_clarify=fusion_result.need_clarify,
            clarify_reason=clarify_reason,
        )
@dataclass
class ClarifyState:
    """Mutable state of one ongoing clarification dialogue.

    ``created_at`` defaults to the ``time.time()`` value at construction.
    """

    reason: ClarifyReason
    asked_slot: str | None = None
    retry_count: int = 0
    candidates: list[IntentCandidate] = field(default_factory=list)
    asked_intent_ids: list[str] = field(default_factory=list)
    created_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict[str, Any]:
        """Return the state as a plain dict; nested candidates are converted too."""
        return {
            "reason": self.reason.value,
            "asked_slot": self.asked_slot,
            "retry_count": self.retry_count,
            "candidates": [c.to_dict() for c in self.candidates],
            "asked_intent_ids": self.asked_intent_ids,
            "created_at": self.created_at,
        }

    def increment_retry(self) -> "ClarifyState":
        """Increase ``retry_count`` by one in place and return ``self``."""
        self.retry_count += 1
        return self

    def is_max_retry(self, max_retry: int = MAX_CLARIFY_RETRY) -> bool:
        """Tell whether the retry limit has been reached.

        Args:
            max_retry: Limit to compare against. Defaults to the module-level
                ``MAX_CLARIFY_RETRY`` so existing zero-argument calls behave
                as before; callers with a configured limit can pass it in.
        """
        return self.retry_count >= max_retry
class ClarificationEngine:
    """Decides when to ask a clarification question and what to ask.

    Confidence handling uses two thresholds:

    * ``confidence >= t_high``           - accept, unless a required slot is missing
    * ``t_low <= confidence < t_high``   - gray zone, ask which intent was meant
    * ``confidence < t_low``             - ask the user to rephrase
    """

    def __init__(
        self,
        t_high: float = T_HIGH,
        t_low: float = T_LOW,
        max_retry: int = MAX_CLARIFY_RETRY,
    ):
        """Store the thresholds, the retry limit and the shared metrics object."""
        self._t_high = t_high
        self._t_low = t_low
        self._max_retry = max_retry
        self._metrics = get_clarify_metrics()

    def compute_confidence(
        self,
        rule_score: float = 0.0,
        semantic_score: float = 0.0,
        llm_score: float = 0.0,
        w_rule: float = 0.5,
        w_semantic: float = 0.3,
        w_llm: float = 0.2,
    ) -> float:
        """Return the weighted average of the three scores, clamped to [0, 1].

        Returns 0.0 when the weights sum to zero.
        """
        total_weight = w_rule + w_semantic + w_llm
        if total_weight == 0:
            return 0.0
        weighted_score = (
            rule_score * w_rule
            + semantic_score * w_semantic
            + llm_score * w_llm
        )
        return min(1.0, max(0.0, weighted_score / total_weight))

    @staticmethod
    def _missing_slots(
        required_slots: list[str] | None,
        filled_slots: dict[str, Any] | None,
    ) -> list[str]:
        """Return required slots absent from ``filled_slots``, in order.

        Returns an empty list when there are no required slots or when
        ``filled_slots`` is ``None`` (no slot check is performed then).
        """
        if not required_slots or filled_slots is None:
            return []
        return [slot for slot in required_slots if slot not in filled_slots]

    def check_hard_block(
        self,
        result: HybridIntentResult,
        required_slots: list[str] | None = None,
        filled_slots: dict[str, Any] | None = None,
    ) -> tuple[bool, ClarifyReason | None]:
        """Apply the hard-block rules without touching metrics or state.

        Returns:
            ``(True, LOW_CONFIDENCE)`` when confidence is below ``t_high``,
            ``(True, MISSING_SLOT)`` when a required slot is missing,
            otherwise ``(False, None)``.
        """
        if result.confidence < self._t_high:
            return True, ClarifyReason.LOW_CONFIDENCE
        if self._missing_slots(required_slots, filled_slots):
            return True, ClarifyReason.MISSING_SLOT
        return False, None

    def should_trigger_clarify(
        self,
        result: HybridIntentResult,
        required_slots: list[str] | None = None,
        filled_slots: dict[str, Any] | None = None,
    ) -> tuple[bool, ClarifyState | None]:
        """Decide whether to clarify and build the matching state.

        Records one trigger in the metrics whenever clarification is needed.

        Returns:
            ``(False, None)`` when no clarification is needed, otherwise
            ``(True, state)`` with a freshly created ``ClarifyState``.
        """
        if result.confidence >= self._t_high:
            missing = self._missing_slots(required_slots, filled_slots)
            if not missing:
                return False, None
            # Ask for one slot at a time, starting with the first missing one.
            state = ClarifyState(
                reason=ClarifyReason.MISSING_SLOT,
                asked_slot=missing[0],
                candidates=result.candidates,
            )
        elif result.confidence < self._t_low:
            state = ClarifyState(
                reason=ClarifyReason.LOW_CONFIDENCE,
                candidates=result.candidates,
            )
        else:
            # Gray zone: keep the upstream reason if one was given.
            state = ClarifyState(
                reason=result.clarify_reason or ClarifyReason.INTENT_AMBIGUITY,
                candidates=result.candidates,
            )
        self._metrics.record_clarify_trigger()
        return True, state

    @staticmethod
    def _either_or_prompt(first: IntentCandidate, second: IntentCandidate) -> str:
        """Build the two-way question with both intent names in 「」 brackets."""
        return (
            f"请问您是想「{first.intent_name}」"
            f"还是「{second.intent_name}」?"
        )

    def generate_clarify_prompt(
        self,
        state: ClarifyState,
        slot_label: str | None = None,
    ) -> str:
        """Return the user-facing clarification question for ``state``.

        Args:
            state: Current clarification state.
            slot_label: Display name for the missing slot; falls back to
                ``state.asked_slot`` and then to a generic phrase.
        """
        if state.reason == ClarifyReason.MISSING_SLOT:
            slot_name = slot_label or state.asked_slot or "相关信息"
            return f"为了更好地为您服务,请告诉我您的{slot_name}"

        if state.reason == ClarifyReason.LOW_CONFIDENCE:
            return "抱歉,我不太理解您的意思,能否请您详细描述一下您的需求?"

        if state.reason == ClarifyReason.MULTI_INTENT and len(state.candidates) > 1:
            # Offer at most three options.
            top = state.candidates[:3]
            if len(top) == 2:
                return self._either_or_prompt(top[0], top[1])
            # Bracket every option and separate them so the names do not run
            # together in the rendered question.
            options = "、".join(f"「{c.intent_name}」" for c in top[:-1])
            return f"请问您是想{options},还是「{top[-1].intent_name}」?"

        if state.reason == ClarifyReason.INTENT_AMBIGUITY and len(state.candidates) > 1:
            return self._either_or_prompt(state.candidates[0], state.candidates[1])

        # No usable candidates: fall back to an open question.
        return "请问您具体想了解什么?"

    def process_clarify_response(
        self,
        user_message: str,
        state: ClarifyState,
        intent_router: Any = None,
        rules: list[Any] | None = None,
    ) -> HybridIntentResult:
        """Handle the user's answer to a clarification question.

        Increments ``state.retry_count`` first. When the engine's retry limit
        is reached, a misroute is recorded and an empty result with
        ``need_clarify=False`` is returned. A ``MISSING_SLOT`` clarification
        is treated as converged on the first candidate; any other reason
        keeps asking.

        NOTE: ``user_message``, ``intent_router`` and ``rules`` are accepted
        but not used by this implementation.
        """
        state.increment_retry()

        # Compare against the limit this engine was configured with rather
        # than the module-level default.
        if state.retry_count >= self._max_retry:
            self._metrics.record_misroute()
            return HybridIntentResult(
                intent=None,
                confidence=0.0,
                need_clarify=False,
            )

        if state.reason == ClarifyReason.MISSING_SLOT:
            self._metrics.record_clarify_converge()
            return HybridIntentResult(
                intent=state.candidates[0] if state.candidates else None,
                confidence=0.8,
                candidates=state.candidates,
                need_clarify=False,
            )

        return HybridIntentResult(
            intent=None,
            confidence=0.0,
            candidates=state.candidates,
            need_clarify=True,
            clarify_reason=state.reason,
        )

    def get_metrics(self) -> dict[str, int]:
        """Return the raw counters from the shared metrics object."""
        return self._metrics.get_metrics()

    def get_rates(self, total_requests: int) -> dict[str, float]:
        """Return the rates from the shared metrics object."""
        return self._metrics.get_rates(total_requests)
class ClarifySessionManager:
    """Class-level registry mapping session ids to their ``ClarifyState``.

    All methods are classmethods operating on one shared dict; entries stay
    in it until ``clear_session`` is called for them.
    """

    _sessions: dict[str, ClarifyState] = {}

    @classmethod
    def get_session(cls, session_id: str) -> ClarifyState | None:
        """Return the state stored for ``session_id``, or ``None``."""
        return cls._sessions.get(session_id)

    @classmethod
    def set_session(cls, session_id: str, state: ClarifyState) -> None:
        """Store (or replace) the state for ``session_id``."""
        cls._sessions[session_id] = state
        logger.debug(
            "[AC-CLARIFY] Session state set: session=%s, reason=%s",
            session_id,
            state.reason,
        )

    @classmethod
    def clear_session(cls, session_id: str) -> None:
        """Remove the state for ``session_id``; unknown ids are ignored."""
        if cls._sessions.pop(session_id, None) is not None:
            logger.debug(
                "[AC-CLARIFY] Session state cleared: session=%s", session_id
            )

    @classmethod
    def has_active_clarify(cls, session_id: str) -> bool:
        """Tell whether the session has a state that is still below its retry limit."""
        state = cls._sessions.get(session_id)
        if state is None:
            return False
        return not state.is_max_retry()