386 lines
13 KiB
Python
386 lines
13 KiB
Python
"""
|
||
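Clarification mechanism for intent recognition.
[AC-CLARIFY] Clarification mechanism implementation.

Core features:
1. Unified confidence computation
2. Hard-block rules (confidence check, required_slots check)
3. Clarification state management
4. Instrumentation metrics collection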
|
||
"""
|
||
|
||
import logging
|
||
import time
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
from typing import Any
|
||
import uuid
|
||
|
||
# Module-level logger used by the metrics and session classes below.
logger = logging.getLogger(__name__)


# Confidence at or above which an intent is accepted without a
# clarification question (missing required slots can still trigger one).
T_HIGH: float = 0.75

# Confidence below which a result is classified as low-confidence.
T_LOW: float = 0.45

# Number of clarification rounds after which a session is considered exhausted.
MAX_CLARIFY_RETRY: int = 3
|
||
|
||
|
||
class ClarifyReason(str, Enum):
    """Reason a clarification question is being asked.

    Members also subclass ``str``; ``.value`` is the wire-format string used
    when results and states are serialized with ``to_dict``.
    """

    # More than one intent is plausible and the user must pick one.
    INTENT_AMBIGUITY = "intent_ambiguity"

    # A slot required by the intent has not been provided.
    MISSING_SLOT = "missing_slot"

    # Recognition confidence is too low to act on.
    LOW_CONFIDENCE = "low_confidence"

    # The fusion step reported a "multi_intent" decision.
    MULTI_INTENT = "multi_intent"
|
||
|
||
|
||
class ClarifyMetrics:
    """Singleton holding the clarification counters for this process.

    Every ``ClarifyMetrics()`` call returns the same object, so counters are
    shared by all callers. Three counters are tracked: clarifications
    triggered, clarifications that converged, and misroutes.
    """

    _instance = None

    def __new__(cls):
        # Create the instance on first use and fully initialise its counters
        # before publishing it on the class.
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._clarify_trigger_count = 0
            instance._clarify_converge_count = 0
            instance._misroute_count = 0
            cls._instance = instance
        return cls._instance

    def record_clarify_trigger(self) -> None:
        """Count one triggered clarification."""
        self._clarify_trigger_count += 1
        logger.debug(
            "[AC-CLARIFY-METRICS] clarify_trigger_count: %d",
            self._clarify_trigger_count,
        )

    def record_clarify_converge(self) -> None:
        """Count one clarification that converged."""
        self._clarify_converge_count += 1
        logger.debug(
            "[AC-CLARIFY-METRICS] clarify_converge_count: %d",
            self._clarify_converge_count,
        )

    def record_misroute(self) -> None:
        """Count one misroute."""
        self._misroute_count += 1
        logger.debug(
            "[AC-CLARIFY-METRICS] misroute_count: %d",
            self._misroute_count,
        )

    def get_metrics(self) -> dict[str, int]:
        """Return the raw counters.

        NOTE: the keys end in ``_rate`` but the values are absolute counts.
        The key names are kept unchanged so existing consumers keep working;
        use :meth:`get_rates` for actual ratios.
        """
        return {
            "clarify_trigger_rate": self._clarify_trigger_count,
            "clarify_converge_rate": self._clarify_converge_count,
            "misroute_rate": self._misroute_count,
        }

    def get_rates(self, total_requests: int) -> dict[str, float]:
        """Return the counters as ratios.

        Args:
            total_requests: Number of requests handled. Denominator for the
                trigger and misroute rates.

        Returns:
            ``clarify_trigger_rate`` and ``misroute_rate`` relative to
            ``total_requests``; ``clarify_converge_rate`` relative to the
            number of triggered clarifications (0.0 when none were
            triggered). All rates are 0.0 when ``total_requests`` is not
            positive.
        """
        if total_requests <= 0:
            return {
                "clarify_trigger_rate": 0.0,
                "clarify_converge_rate": 0.0,
                "misroute_rate": 0.0,
            }

        triggers = self._clarify_trigger_count
        # Convergence is measured against clarifications actually started;
        # that is what the zero-trigger guard protects against.
        converge_rate = self._clarify_converge_count / triggers if triggers > 0 else 0.0

        return {
            "clarify_trigger_rate": triggers / total_requests,
            "clarify_converge_rate": converge_rate,
            "misroute_rate": self._misroute_count / total_requests,
        }

    def reset(self) -> None:
        """Set every counter back to zero."""
        self._clarify_trigger_count = 0
        self._clarify_converge_count = 0
        self._misroute_count = 0
|
||
|
||
|
||
def get_clarify_metrics() -> ClarifyMetrics:
    """Return the shared ``ClarifyMetrics`` instance."""
    shared_metrics = ClarifyMetrics()
    return shared_metrics
|
||
|
||
|
||
@dataclass
|
||
class IntentCandidate:
|
||
intent_id: str
|
||
intent_name: str
|
||
confidence: float
|
||
response_type: str | None = None
|
||
target_kb_ids: list[str] | None = None
|
||
flow_id: str | None = None
|
||
fixed_reply: str | None = None
|
||
transfer_message: str | None = None
|
||
|
||
def to_dict(self) -> dict[str, Any]:
|
||
return {
|
||
"intent_id": self.intent_id,
|
||
"intent_name": self.intent_name,
|
||
"confidence": self.confidence,
|
||
"response_type": self.response_type,
|
||
"target_kb_ids": self.target_kb_ids,
|
||
"flow_id": self.flow_id,
|
||
"fixed_reply": self.fixed_reply,
|
||
"transfer_message": self.transfer_message,
|
||
}
|
||
|
||
|
||
@dataclass
class HybridIntentResult:
    """Outcome of intent recognition, including any clarification request.

    ``intent`` is the selected candidate (or ``None``), ``confidence`` is the
    overall confidence, and ``candidates`` lists every candidate considered.
    ``need_clarify`` / ``clarify_reason`` say whether and why the user should
    be asked a follow-up question.
    """

    intent: IntentCandidate | None
    confidence: float
    candidates: list[IntentCandidate] = field(default_factory=list)
    need_clarify: bool = False
    clarify_reason: ClarifyReason | None = None
    missing_slots: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view; nested candidates are converted too."""
        selected = self.intent.to_dict() if self.intent else None
        reason = self.clarify_reason.value if self.clarify_reason else None
        return {
            "intent": selected,
            "confidence": self.confidence,
            "candidates": [candidate.to_dict() for candidate in self.candidates],
            "need_clarify": self.need_clarify,
            "clarify_reason": reason,
            "missing_slots": self.missing_slots,
        }

    @classmethod
    def from_fusion_result(cls, fusion_result: Any) -> "HybridIntentResult":
        """Build a result from a fusion-step object.

        Args:
            fusion_result: Object exposing ``clarify_candidates``,
                ``final_intent``, ``final_confidence``, ``need_clarify`` and
                ``decision_reason`` (the last is read only when
                ``need_clarify`` is truthy). Candidate and intent objects
                must expose ``id`` and ``name``; the routing attributes are
                optional and default to ``None``.

        Returns:
            A result whose ``intent`` is the final intent when one exists,
            otherwise the first clarify candidate, otherwise ``None``.
            Clarify candidates carry confidence 0.0; the final intent carries
            ``final_confidence``.
        """

        def build(rule: Any, confidence: float) -> IntentCandidate:
            # Single place that maps a rule-like object to a candidate, so
            # clarify candidates and the final intent are treated alike.
            flow_id = getattr(rule, "flow_id", None)
            return IntentCandidate(
                intent_id=str(rule.id),
                intent_name=rule.name,
                confidence=confidence,
                response_type=getattr(rule, "response_type", None),
                target_kb_ids=getattr(rule, "target_kb_ids", None),
                flow_id=str(flow_id) if flow_id else None,
                fixed_reply=getattr(rule, "fixed_reply", None),
                transfer_message=getattr(rule, "transfer_message", None),
            )

        candidates = [build(rule, 0.0) for rule in (fusion_result.clarify_candidates or [])]

        final_candidate = None
        if fusion_result.final_intent:
            final_candidate = build(
                fusion_result.final_intent, fusion_result.final_confidence
            )
            for index, existing in enumerate(candidates):
                if existing.intent_id == final_candidate.intent_id:
                    # Already listed as a clarify candidate: keep its position
                    # but swap in the entry that carries the real confidence
                    # instead of silently dropping it.
                    candidates[index] = final_candidate
                    break
            else:
                candidates.insert(0, final_candidate)

        clarify_reason = None
        if fusion_result.need_clarify:
            if fusion_result.decision_reason == "multi_intent":
                clarify_reason = ClarifyReason.MULTI_INTENT
            elif fusion_result.decision_reason == "gray_zone":
                clarify_reason = ClarifyReason.INTENT_AMBIGUITY
            else:
                clarify_reason = ClarifyReason.LOW_CONFIDENCE

        # Prefer the final intent; candidates[0] is only a fallback because it
        # is not guaranteed to be the final intent when that intent was
        # already present further down the clarify list.
        if final_candidate is not None:
            selected = final_candidate
        else:
            selected = candidates[0] if candidates else None

        return cls(
            intent=selected,
            confidence=fusion_result.final_confidence,
            candidates=candidates,
            need_clarify=fusion_result.need_clarify,
            clarify_reason=clarify_reason,
        )
|
||
|
||
|
||
@dataclass
class ClarifyState:
    """Mutable state of one in-progress clarification exchange."""

    reason: ClarifyReason
    asked_slot: str | None = None
    retry_count: int = 0
    candidates: list[IntentCandidate] = field(default_factory=list)
    asked_intent_ids: list[str] = field(default_factory=list)
    # Wall-clock creation time in seconds, taken from time.time().
    created_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view; ``reason`` is emitted as its string value."""
        return dict(
            reason=self.reason.value,
            asked_slot=self.asked_slot,
            retry_count=self.retry_count,
            candidates=[candidate.to_dict() for candidate in self.candidates],
            asked_intent_ids=self.asked_intent_ids,
            created_at=self.created_at,
        )

    def increment_retry(self) -> "ClarifyState":
        """Add one to ``retry_count`` in place and return ``self``."""
        self.retry_count = self.retry_count + 1
        return self

    def is_max_retry(self) -> bool:
        """Return True once ``retry_count`` has reached ``MAX_CLARIFY_RETRY``."""
        exhausted = self.retry_count >= MAX_CLARIFY_RETRY
        return exhausted
|
||
|
||
|
||
class ClarificationEngine:
    """Decides whether to ask a clarifying question and builds the prompt.

    Thresholds: confidence ``>= t_high`` is accepted (unless a required slot
    is missing), confidence ``< t_low`` is low-confidence, and anything in
    between is treated as ambiguous.
    """

    def __init__(
        self,
        t_high: float = T_HIGH,
        t_low: float = T_LOW,
        max_retry: int = MAX_CLARIFY_RETRY,
    ):
        """Store the thresholds and bind the shared metrics object.

        Args:
            t_high: Acceptance threshold.
            t_low: Low-confidence threshold.
            max_retry: Clarification rounds allowed before giving up.
        """
        self._t_high = t_high
        self._t_low = t_low
        self._max_retry = max_retry
        self._metrics = get_clarify_metrics()

    def compute_confidence(
        self,
        rule_score: float = 0.0,
        semantic_score: float = 0.0,
        llm_score: float = 0.0,
        w_rule: float = 0.5,
        w_semantic: float = 0.3,
        w_llm: float = 0.2,
    ) -> float:
        """Return the weighted average of the three scores, clamped to [0, 1].

        Returns 0.0 when the weights sum to zero.
        """
        total_weight = w_rule + w_semantic + w_llm
        if total_weight == 0:
            return 0.0

        weighted_score = (
            rule_score * w_rule
            + semantic_score * w_semantic
            + llm_score * w_llm
        )
        return min(1.0, max(0.0, weighted_score / total_weight))

    @staticmethod
    def _missing_slots(
        required_slots: list[str] | None,
        filled_slots: dict[str, Any] | None,
    ) -> list[str]:
        """Return required slot names absent from ``filled_slots``.

        The check is skipped (empty list) when there are no required slots or
        when ``filled_slots`` is ``None``. A slot counts as filled as soon as
        its key is present, whatever its value.
        """
        if not required_slots or filled_slots is None:
            return []
        return [slot for slot in required_slots if slot not in filled_slots]

    @staticmethod
    def _two_way_prompt(first: IntentCandidate, second: IntentCandidate) -> str:
        """Build the either/or question for exactly two candidates."""
        return (
            f"请问您是想「{first.intent_name}」"
            f"还是「{second.intent_name}」?"
        )

    def check_hard_block(
        self,
        result: HybridIntentResult,
        required_slots: list[str] | None = None,
        filled_slots: dict[str, Any] | None = None,
    ) -> tuple[bool, ClarifyReason | None]:
        """Tell whether the result must not be executed as-is.

        Returns:
            ``(True, LOW_CONFIDENCE)`` when confidence is below ``t_high``,
            ``(True, MISSING_SLOT)`` when a required slot is missing,
            otherwise ``(False, None)``.
        """
        if result.confidence < self._t_high:
            return True, ClarifyReason.LOW_CONFIDENCE

        if self._missing_slots(required_slots, filled_slots):
            return True, ClarifyReason.MISSING_SLOT

        return False, None

    def should_trigger_clarify(
        self,
        result: HybridIntentResult,
        required_slots: list[str] | None = None,
        filled_slots: dict[str, Any] | None = None,
    ) -> tuple[bool, ClarifyState | None]:
        """Decide whether to clarify and, if so, create the state for it.

        Records one clarify-trigger metric whenever a clarification is
        requested.

        Returns:
            ``(False, None)`` when the result can be used directly, otherwise
            ``(True, state)`` where ``state.reason`` is ``MISSING_SLOT``
            (first missing slot stored in ``asked_slot``), ``LOW_CONFIDENCE``,
            or the result's own ``clarify_reason`` falling back to
            ``INTENT_AMBIGUITY``. The state shares ``result.candidates``.
        """
        asked_slot = None

        if result.confidence >= self._t_high:
            missing = self._missing_slots(required_slots, filled_slots)
            if not missing:
                return False, None
            reason = ClarifyReason.MISSING_SLOT
            asked_slot = missing[0]
        elif result.confidence < self._t_low:
            reason = ClarifyReason.LOW_CONFIDENCE
        else:
            # Between the two thresholds: keep the upstream reason if any.
            reason = result.clarify_reason or ClarifyReason.INTENT_AMBIGUITY

        self._metrics.record_clarify_trigger()
        return True, ClarifyState(
            reason=reason,
            asked_slot=asked_slot,
            candidates=result.candidates,
        )

    def generate_clarify_prompt(
        self,
        state: ClarifyState,
        slot_label: str | None = None,
    ) -> str:
        """Return the user-facing question for ``state``.

        Args:
            state: Current clarification state.
            slot_label: Display name for the missing slot; falls back to
                ``state.asked_slot`` and then to a generic phrase.
        """
        if state.reason == ClarifyReason.MISSING_SLOT:
            slot_name = slot_label or state.asked_slot or "相关信息"
            return f"为了更好地为您服务,请告诉我您的{slot_name}。"

        if state.reason == ClarifyReason.LOW_CONFIDENCE:
            return "抱歉,我不太理解您的意思,能否请您详细描述一下您的需求?"

        if state.reason == ClarifyReason.MULTI_INTENT and len(state.candidates) > 1:
            # At most three options are offered.
            shown = state.candidates[:3]
            if len(shown) == 2:
                return self._two_way_prompt(shown[0], shown[1])
            options = "、".join(f"「{c.intent_name}」" for c in shown[:-1])
            return f"请问您是想{options},还是「{shown[-1].intent_name}」?"

        if state.reason == ClarifyReason.INTENT_AMBIGUITY and len(state.candidates) > 1:
            return self._two_way_prompt(state.candidates[0], state.candidates[1])

        # Generic fallback, e.g. fewer than two candidates.
        return "请问您具体想了解什么?"

    def process_clarify_response(
        self,
        user_message: str,
        state: ClarifyState,
        intent_router: Any = None,
        rules: list[Any] | None = None,
    ) -> HybridIntentResult:
        """Consume one user reply to a clarification question.

        Increments ``state.retry_count`` in place on every call.

        Args:
            user_message: The user's reply (not inspected by this method).
            state: Clarification state; mutated.
            intent_router: Not used by this method.
            rules: Not used by this method.

        Returns:
            * retry limit reached: an empty result with ``need_clarify``
              False, and a misroute is recorded;
            * ``MISSING_SLOT``: the first stored candidate with confidence
              0.8, and a convergence is recorded;
            * otherwise: a result that asks for clarification again with the
              same reason.
        """
        state.increment_retry()

        # Use the limit this engine was configured with rather than the
        # module-wide default, so a custom ``max_retry`` takes effect.
        if state.retry_count >= self._max_retry:
            self._metrics.record_misroute()
            return HybridIntentResult(
                intent=None,
                confidence=0.0,
                need_clarify=False,
            )

        if state.reason == ClarifyReason.MISSING_SLOT:
            self._metrics.record_clarify_converge()
            return HybridIntentResult(
                intent=state.candidates[0] if state.candidates else None,
                confidence=0.8,
                candidates=state.candidates,
                need_clarify=False,
            )

        return HybridIntentResult(
            intent=None,
            confidence=0.0,
            candidates=state.candidates,
            need_clarify=True,
            clarify_reason=state.reason,
        )

    def get_metrics(self) -> dict[str, int]:
        """Return the shared metrics object's ``get_metrics()`` output."""
        return self._metrics.get_metrics()

    def get_rates(self, total_requests: int) -> dict[str, float]:
        """Return the shared metrics object's ``get_rates(total_requests)``."""
        return self._metrics.get_rates(total_requests)
|
||
|
||
|
||
class ClarifySessionManager:
    """Class-level, in-memory map from session id to clarification state.

    All methods are classmethods operating on the single shared
    ``_sessions`` dict.
    """

    _sessions: dict[str, ClarifyState] = {}

    @classmethod
    def get_session(cls, session_id: str) -> ClarifyState | None:
        """Return the state stored for ``session_id``, or ``None``."""
        stored = cls._sessions.get(session_id)
        return stored

    @classmethod
    def set_session(cls, session_id: str, state: ClarifyState) -> None:
        """Store ``state`` for ``session_id``, replacing any previous entry."""
        cls._sessions[session_id] = state
        logger.debug(f"[AC-CLARIFY] Session state set: session={session_id}, reason={state.reason}")

    @classmethod
    def clear_session(cls, session_id: str) -> None:
        """Remove the state for ``session_id``; no-op when none is stored."""
        if session_id not in cls._sessions:
            return
        del cls._sessions[session_id]
        logger.debug(f"[AC-CLARIFY] Session state cleared: session={session_id}")

    @classmethod
    def has_active_clarify(cls, session_id: str) -> bool:
        """Return True when a state exists and has not hit its retry limit."""
        state = cls._sessions.get(session_id)
        if not state:
            return False
        return not state.is_max_retry()
|