217 lines
6.3 KiB
Python
217 lines
6.3 KiB
Python
|
|
"""
|
||
|
|
Guardrail Tester for AI Service.
|
||
|
|
[AC-AISVC-105] Forbidden word testing service.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import logging
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
|
||
|
|
from app.models.entities import ForbiddenWord, ForbiddenWordStrategy
|
||
|
|
from app.services.guardrail.word_service import ForbiddenWordService
|
||
|
|
|
||
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class TriggeredWordInfo:
|
||
|
|
"""Information about a triggered forbidden word."""
|
||
|
|
|
||
|
|
word: str
|
||
|
|
category: str
|
||
|
|
strategy: str
|
||
|
|
replacement: str | None = None
|
||
|
|
fallbackReply: str | None = None
|
||
|
|
|
||
|
|
def to_dict(self) -> dict[str, Any]:
|
||
|
|
result = {
|
||
|
|
"word": self.word,
|
||
|
|
"category": self.category,
|
||
|
|
"strategy": self.strategy,
|
||
|
|
}
|
||
|
|
if self.replacement is not None:
|
||
|
|
result["replacement"] = self.replacement
|
||
|
|
if self.fallbackReply is not None:
|
||
|
|
result["fallbackReply"] = self.fallbackReply
|
||
|
|
return result
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class GuardrailTestResult:
    """Result of testing a single text.

    The field names double as the keys produced by ``to_dict``.
    """

    originalText: str
    triggered: bool
    # default_factory gives every instance its own list.
    triggeredWords: list[TriggeredWordInfo] = field(default_factory=list)
    filteredText: str = ""
    blocked: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Return the result as a plain dict.

        Each entry of ``triggeredWords`` is converted through its own
        ``to_dict`` method; all other fields are copied unchanged.
        """
        serialized_words = [hit.to_dict() for hit in self.triggeredWords]
        return {
            "originalText": self.originalText,
            "triggered": self.triggered,
            "triggeredWords": serialized_words,
            "filteredText": self.filteredText,
            "blocked": self.blocked,
        }
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class GuardrailTestSummary:
    """Summary of guardrail testing.

    The field names double as the keys produced by ``to_dict``.
    """

    totalTests: int
    triggeredCount: int
    blockedCount: int
    triggerRate: float

    def to_dict(self) -> dict[str, Any]:
        """Return the summary as a plain dict.

        ``triggerRate`` is rounded to two decimal places in the output;
        the stored attribute itself is left untouched.
        """
        rounded_rate = round(self.triggerRate, 2)
        return {
            "totalTests": self.totalTests,
            "triggeredCount": self.triggeredCount,
            "blockedCount": self.blockedCount,
            "triggerRate": rounded_rate,
        }
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class GuardrailTestResponse:
    """Full response for guardrail testing: per-text results plus a summary."""

    results: list[GuardrailTestResult]
    summary: GuardrailTestSummary

    def to_dict(self) -> dict[str, Any]:
        """Return the response as a plain dict.

        Every element of ``results`` and the ``summary`` object are
        converted through their own ``to_dict`` methods.
        """
        serialized_results = [item.to_dict() for item in self.results]
        return {
            "results": serialized_results,
            "summary": self.summary.to_dict(),
        }
|
||
|
|
|
||
|
|
|
||
|
|
class GuardrailTester:
    """
    [AC-AISVC-105] Guardrail testing service.

    Features:
    - Test forbidden word detection
    - Apply filter strategies (mask/replace/block)
    - Return detailed detection results
    - No database modification (read-only test)
    """

    # Returned as the filtered text for a blocked input when the matching
    # word carries no fallback reply of its own.
    DEFAULT_FALLBACK_REPLY = "抱歉,让我换个方式回答您"

    def __init__(self, session: AsyncSession):
        """Store the session and build a ForbiddenWordService on top of it."""
        self._session = session
        self._word_service = ForbiddenWordService(session)

    async def test_guardrail(
        self,
        tenant_id: str,
        test_texts: list[str],
    ) -> GuardrailTestResponse:
        """
        [AC-AISVC-105] Test forbidden word detection and filtering.

        The enabled words for the tenant are fetched once and every text is
        checked against that same list.

        Args:
            tenant_id: Tenant ID for isolation
            test_texts: List of texts to test

        Returns:
            GuardrailTestResponse with one result per input text (in input
            order) and a summary; the trigger rate is 0.0 for an empty list.
        """
        logger.info(
            f"[AC-AISVC-105] Testing guardrail for tenant={tenant_id}, "
            f"texts_count={len(test_texts)}"
        )

        words = await self._word_service.get_enabled_words_for_filtering(tenant_id)

        results: list[GuardrailTestResult] = [
            self._test_single_text(text, words) for text in test_texts
        ]
        triggered_count = sum(1 for result in results if result.triggered)
        blocked_count = sum(1 for result in results if result.blocked)

        total_tests = len(test_texts)
        # Guard the division so an empty batch yields a rate of 0.0.
        trigger_rate = triggered_count / total_tests if total_tests > 0 else 0.0

        summary = GuardrailTestSummary(
            totalTests=total_tests,
            triggeredCount=triggered_count,
            blockedCount=blocked_count,
            triggerRate=trigger_rate,
        )

        logger.info(
            f"[AC-AISVC-105] Guardrail test completed: tenant={tenant_id}, "
            f"triggered={triggered_count}/{total_tests}, blocked={blocked_count}"
        )

        return GuardrailTestResponse(results=results, summary=summary)

    def _test_single_text(
        self,
        text: str,
        words: list[ForbiddenWord],
    ) -> GuardrailTestResult:
        """Test a single text against forbidden words.

        Words are processed in list order using plain substring matching.
        A word whose strategy is BLOCK ends processing immediately: the
        result is marked blocked, its filtered text is the word's fallback
        reply (or DEFAULT_FALLBACK_REPLY), and only the words matched up to
        and including that one are reported. MASK replaces each occurrence
        with asterisks of the same length; REPLACE substitutes the word's
        replacement (empty string when unset). A matched word with any
        other strategy is reported but leaves the text unchanged.

        Args:
            text: The text to check; empty or whitespace-only text is
                returned untouched and never triggers.
            words: Forbidden-word entries; this method reads their
                ``word``, ``category``, ``strategy``, ``replacement`` and
                ``fallback_reply`` attributes.

        Returns:
            GuardrailTestResult describing the matches and the filtered text.
        """
        if not text or not text.strip():
            return GuardrailTestResult(
                originalText=text,
                triggered=False,
                filteredText=text,
                blocked=False,
            )

        triggered_words: list[TriggeredWordInfo] = []
        filtered_text = text

        for word in words:
            # An empty pattern is a substring of every string, and
            # str.replace("", x) inserts x between all characters, so an
            # empty (or missing) word would flag and mangle every text.
            if not word.word:
                continue

            # NOTE(review): matching runs against the progressively filtered
            # text, so an earlier mask/replace can hide or introduce a later
            # match. Kept as-is; confirm it mirrors the production filter.
            if word.word not in filtered_text:
                continue

            triggered_words.append(
                TriggeredWordInfo(
                    word=word.word,
                    category=word.category,
                    strategy=word.strategy,
                    replacement=word.replacement,
                    fallbackReply=word.fallback_reply,
                )
            )

            if word.strategy == ForbiddenWordStrategy.BLOCK.value:
                # Blocking short-circuits: remaining words are not evaluated.
                return GuardrailTestResult(
                    originalText=text,
                    triggered=True,
                    triggeredWords=triggered_words,
                    filteredText=word.fallback_reply or self.DEFAULT_FALLBACK_REPLY,
                    blocked=True,
                )

            if word.strategy == ForbiddenWordStrategy.MASK.value:
                filtered_text = filtered_text.replace(word.word, "*" * len(word.word))
            elif word.strategy == ForbiddenWordStrategy.REPLACE.value:
                filtered_text = filtered_text.replace(word.word, word.replacement or "")

        # Reaching this point means no BLOCK word matched.
        return GuardrailTestResult(
            originalText=text,
            triggered=bool(triggered_words),
            triggeredWords=triggered_words,
            filteredText=filtered_text,
            blocked=False,
        )
|