ai-robot-core/ai-service/app/services/guardrail/tester.py

217 lines
6.3 KiB
Python
Raw Normal View History

"""
Guardrail Tester for AI Service.
[AC-AISVC-105] Forbidden word testing service.
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import ForbiddenWord, ForbiddenWordStrategy
from app.services.guardrail.word_service import ForbiddenWordService
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class TriggeredWordInfo:
"""Information about a triggered forbidden word."""
word: str
category: str
strategy: str
replacement: str | None = None
fallbackReply: str | None = None
def to_dict(self) -> dict[str, Any]:
result = {
"word": self.word,
"category": self.category,
"strategy": self.strategy,
}
if self.replacement is not None:
result["replacement"] = self.replacement
if self.fallbackReply is not None:
result["fallbackReply"] = self.fallbackReply
return result
@dataclass
class GuardrailTestResult:
    """Outcome of running one text through the guardrail test."""

    # The text exactly as it was submitted for testing.
    originalText: str
    # Whether at least one forbidden word matched.
    triggered: bool
    # Details of each matched word; a fresh list per instance by default.
    triggeredWords: list[TriggeredWordInfo] = field(default_factory=list)
    # The text after filtering has been applied.
    filteredText: str = ""
    # Whether the text was blocked.
    blocked: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict, expanding each triggered word via ``to_dict``."""
        serialized_words = [hit.to_dict() for hit in self.triggeredWords]
        return dict(
            originalText=self.originalText,
            triggered=self.triggered,
            triggeredWords=serialized_words,
            filteredText=self.filteredText,
            blocked=self.blocked,
        )
@dataclass
class GuardrailTestSummary:
    """Aggregate counters for one guardrail test run."""

    totalTests: int
    triggeredCount: int
    blockedCount: int
    triggerRate: float

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict, with ``triggerRate`` rounded to two decimals.

        Rounding applies to the serialized value only; the attribute on the
        instance is left untouched.
        """
        rounded_rate = round(self.triggerRate, 2)
        return dict(
            totalTests=self.totalTests,
            triggeredCount=self.triggeredCount,
            blockedCount=self.blockedCount,
            triggerRate=rounded_rate,
        )
@dataclass
class GuardrailTestResponse:
    """Complete guardrail test payload: per-text results plus a summary."""

    results: list[GuardrailTestResult]
    summary: GuardrailTestSummary

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict by delegating to each member's ``to_dict``."""
        result_dicts = [entry.to_dict() for entry in self.results]
        summary_dict = self.summary.to_dict()
        return {"results": result_dicts, "summary": summary_dict}
class GuardrailTester:
    """
    [AC-AISVC-105] Guardrail testing service.

    Features:
    - Test forbidden word detection
    - Apply filter strategies (mask/replace/block)
    - Return detailed detection results
    - Intended as a read-only test: this class only fetches the enabled
      words through the word service and never writes anything itself
    """

    # Reply used in place of the whole text when a BLOCK word matches and
    # that word has no fallback reply of its own.
    DEFAULT_FALLBACK_REPLY = "抱歉,让我换个方式回答您"

    def __init__(self, session: AsyncSession):
        """Keep the session and build the word service on top of it."""
        self._session = session
        self._word_service = ForbiddenWordService(session)

    async def test_guardrail(
        self,
        tenant_id: str,
        test_texts: list[str],
    ) -> GuardrailTestResponse:
        """
        [AC-AISVC-105] Test forbidden word detection and filtering.

        The enabled words are fetched once for the tenant and every text is
        then evaluated against that same list.

        Args:
            tenant_id: Tenant ID passed to the word service to select words
            test_texts: List of texts to test

        Returns:
            GuardrailTestResponse with one result per text (in input order)
            and a summary. The trigger rate is 0.0 when no texts are given.
        """
        # Lazy %-style arguments: the message is only rendered when the
        # record is actually emitted.
        logger.info(
            "[AC-AISVC-105] Testing guardrail for tenant=%s, texts_count=%s",
            tenant_id,
            len(test_texts),
        )

        words = await self._word_service.get_enabled_words_for_filtering(tenant_id)

        results: list[GuardrailTestResult] = [
            self._test_single_text(text, words) for text in test_texts
        ]
        triggered_count = sum(1 for result in results if result.triggered)
        blocked_count = sum(1 for result in results if result.blocked)

        total_tests = len(test_texts)
        # Guard the division so an empty batch yields a rate of 0.0.
        trigger_rate = triggered_count / total_tests if total_tests > 0 else 0.0

        summary = GuardrailTestSummary(
            totalTests=total_tests,
            triggeredCount=triggered_count,
            blockedCount=blocked_count,
            triggerRate=trigger_rate,
        )

        logger.info(
            "[AC-AISVC-105] Guardrail test completed: tenant=%s, "
            "triggered=%s/%s, blocked=%s",
            tenant_id,
            triggered_count,
            total_tests,
            blocked_count,
        )

        return GuardrailTestResponse(results=results, summary=summary)

    def _test_single_text(
        self,
        text: str,
        words: list[ForbiddenWord],
    ) -> GuardrailTestResult:
        """Test a single text against forbidden words.

        Words are applied in the order given. A matching BLOCK word ends
        processing immediately and the filtered text becomes that word's
        fallback reply (or ``DEFAULT_FALLBACK_REPLY``). MASK swaps every
        occurrence for asterisks of the same length; REPLACE swaps every
        occurrence for the word's replacement (empty string when unset).

        Args:
            text: Text to evaluate. Empty or whitespace-only text is returned
                untouched and untriggered.
            words: Forbidden word entries to check, in priority order.

        Returns:
            GuardrailTestResult describing the matches and the filtered text.
        """
        if not text or not text.strip():
            return GuardrailTestResult(
                originalText=text,
                triggered=False,
                filteredText=text,
                blocked=False,
            )

        block_strategy = ForbiddenWordStrategy.BLOCK.value
        mask_strategy = ForbiddenWordStrategy.MASK.value
        replace_strategy = ForbiddenWordStrategy.REPLACE.value

        triggered_words: list[TriggeredWordInfo] = []
        filtered_text = text

        for word in words:
            term = word.word
            # An empty term is a substring of every string, so it would
            # trigger on (and possibly block or corrupt) every text; a None
            # term would make the ``in`` check raise. Skip both.
            if not term:
                continue

            # NOTE(review): matching runs against the progressively filtered
            # text, so an earlier MASK/REPLACE can hide a later word (or a
            # replacement can introduce one). Confirm this matches the
            # production filter before changing it.
            if term not in filtered_text:
                continue

            triggered_words.append(
                TriggeredWordInfo(
                    word=term,
                    category=word.category,
                    strategy=word.strategy,
                    replacement=word.replacement,
                    fallbackReply=word.fallback_reply,
                )
            )

            if word.strategy == block_strategy:
                # Blocking discards the text entirely; later words are moot.
                return GuardrailTestResult(
                    originalText=text,
                    triggered=True,
                    triggeredWords=triggered_words,
                    filteredText=word.fallback_reply or self.DEFAULT_FALLBACK_REPLY,
                    blocked=True,
                )

            if word.strategy == mask_strategy:
                filtered_text = filtered_text.replace(term, "*" * len(term))
            elif word.strategy == replace_strategy:
                filtered_text = filtered_text.replace(term, word.replacement or "")

        # Reaching this point means no BLOCK word matched.
        return GuardrailTestResult(
            originalText=text,
            triggered=bool(triggered_words),
            triggeredWords=triggered_words,
            filteredText=filtered_text,
            blocked=False,
        )