Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/scanner/benchmark.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@
"\n",
"await initialize_pyrit_async(memory_db_type=IN_MEMORY, initializers=[LoadDefaultDatasets()]) # type: ignore\n",
"\n",
"# Pass any number of adversarial PromptChatTargets as a list; AdversarialBenchmark\n",
"# Pass any number of adversarial PromptTarget instances (with chat-target\n",
"# capabilities \u2014 multi-turn and editable history) as a list; AdversarialBenchmark\n",
"# infers a label for each from its identifier and runs every benchmark-friendly\n",
"# attack technique against the objective target with each adversarial model.\n",
"adversarial_model = OpenAIChatTarget()\n",
Expand Down
3 changes: 2 additions & 1 deletion doc/scanner/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@

await initialize_pyrit_async(memory_db_type=IN_MEMORY, initializers=[LoadDefaultDatasets()]) # type: ignore

# Pass any number of adversarial PromptChatTargets as a list; AdversarialBenchmark
# Pass any number of adversarial PromptTarget instances (with chat-target
# capabilities — multi-turn and editable history) as a list; AdversarialBenchmark
# infers a label for each from its identifier and runs every benchmark-friendly
# attack technique against the objective target with each adversarial model.
adversarial_model = OpenAIChatTarget()
Expand Down
38 changes: 27 additions & 11 deletions pyrit/scenario/scenarios/benchmark/adversarial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from pyrit.common import apply_defaults
from pyrit.executor.attack import AttackAdversarialConfig, AttackScoringConfig
from pyrit.prompt_target import CHAT_TARGET_REQUIREMENTS
from pyrit.registry import AttackTechniqueRegistry, AttackTechniqueSpec
from pyrit.registry.tag_query import TagQuery
from pyrit.scenario.core.atomic_attack import AtomicAttack
Expand All @@ -18,7 +19,7 @@
from pyrit.scenario.core.scenario_techniques import SCENARIO_TECHNIQUES

if TYPE_CHECKING:
from pyrit.prompt_target import PromptChatTarget
from pyrit.prompt_target import PromptTarget
from pyrit.scenario.core.scenario_strategy import ScenarioStrategy
from pyrit.score import TrueFalseScorer

Expand Down Expand Up @@ -79,15 +80,19 @@ def default_dataset_config(cls) -> DatasetConfiguration:
def __init__(
self,
*,
adversarial_models: list[PromptChatTarget],
adversarial_models: list[PromptTarget],
objective_scorer: TrueFalseScorer | None = None,
scenario_result_id: str | None = None,
) -> None:
"""
Initialize the AdversarialBenchmark scenario.

Args:
adversarial_models: A non-empty list of ``PromptChatTarget`` instances.
adversarial_models: A non-empty list of ``PromptTarget`` instances
that each satisfy :data:`CHAT_TARGET_REQUIREMENTS` (multi-turn
with editable history). Individual techniques selected at
run time may impose stricter capability requirements which are
enforced when their attack instances are constructed.
Labels are inferred from each target's identifier (preferring
``underlying_model_name`` over ``model_name`` over the class
name). Identical targets are silently deduped and distinct
Expand All @@ -99,13 +104,24 @@ def __init__(
result to resume.

Raises:
ValueError: If ``adversarial_models`` is empty or not a list.
ValueError: If ``adversarial_models`` is empty, not a list, or
contains a target that does not satisfy
:data:`CHAT_TARGET_REQUIREMENTS`.
"""
if not adversarial_models:
raise ValueError("adversarial_models must be a non-empty list of PromptChatTarget instances.")
raise ValueError("adversarial_models must be a non-empty list of PromptTarget instances.")

if not isinstance(adversarial_models, list):
raise ValueError("adversarial_models must be a list of PromptChatTarget instances.")
raise ValueError("adversarial_models must be a list of PromptTarget instances.")

for target in adversarial_models:
try:
CHAT_TARGET_REQUIREMENTS.validate(target=target)
except ValueError as exc:
raise ValueError(
f"adversarial_models entry {type(target).__name__} does not satisfy "
f"the chat-target capability requirements: {exc}"
) from exc

# Infer labels, then wrap each bare target in a default AttackAdversarialConfig
# so it can be passed to factory.create() as an override.
Expand Down Expand Up @@ -184,8 +200,8 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]:
@staticmethod
def _infer_labels(
*,
items: list[PromptChatTarget],
) -> dict[str, PromptChatTarget]:
items: list[PromptTarget],
) -> dict[str, PromptTarget]:
"""
Infer user-facing labels for a list of adversarial targets.

Expand All @@ -195,14 +211,14 @@ def _infer_labels(
and a ``logger.warning`` so the situation isn't silent.

Args:
items: List of ``PromptChatTarget`` instances.
items: List of ``PromptTarget`` instances.

Returns:
dict[str, PromptChatTarget]: Mapping from inferred label to the
dict[str, PromptTarget]: Mapping from inferred label to the
original target. Targets are wrapped in an
``AttackAdversarialConfig`` by ``__init__`` after this call.
"""
result: dict[str, PromptChatTarget] = {}
result: dict[str, PromptTarget] = {}
seen_keys: dict[str, str | None] = {}

for target in items:
Expand Down
38 changes: 33 additions & 5 deletions tests/unit/scenario/test_adversarial.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
SeedObjective,
SeedPrompt,
)
from pyrit.prompt_target import PromptTarget
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
from pyrit.prompt_target import PromptTarget, TargetCapabilities, TargetConfiguration
from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry
from pyrit.scenario.core import AtomicAttack, BaselinePolicy
from pyrit.scenario.core.dataset_configuration import DatasetConfiguration
Expand Down Expand Up @@ -59,17 +58,32 @@ def _mock_id(name: str, *, params: dict | None = None) -> ComponentIdentifier:
return ComponentIdentifier(class_name=name, class_module="test", params=params or {})


# Shared TargetConfiguration assigned to the mock adversarial targets built by
# _make_adversarial_target. It declares multi-turn, multi-message-piece,
# system-prompt, and editable-history support so that a mock carrying it
# satisfies CHAT_TARGET_REQUIREMENTS when AdversarialBenchmark is constructed.
_CHAT_TARGET_CONFIGURATION = TargetConfiguration(
capabilities=TargetCapabilities(
supports_multi_turn=True,
supports_multi_message_pieces=True,
supports_system_prompt=True,
supports_editable_history=True,
),
)


def _make_adversarial_target(name: str, *, params: dict | None = None) -> MagicMock:
"""Create a mock PromptChatTarget with a given model name and optional identifier params.
"""Create a mock adversarial PromptTarget with a given model name and optional identifier params.

By default, ``model_name`` is stamped into the identifier params so the
inferred label produced by ``_infer_labels`` matches ``name``. Pass an
explicit ``params`` dict to override (e.g. to omit the key for collision
testing or to add ``underlying_model_name`` / ``endpoint``).

The mock exposes a real ``TargetConfiguration`` declaring multi-turn and
editable history so the target satisfies ``CHAT_TARGET_REQUIREMENTS`` at
construction time.
"""
mock = MagicMock(spec=PromptChatTarget)
mock = MagicMock(spec=PromptTarget)
mock._model_name = name
mock.get_identifier.return_value = _mock_id(name, params=params if params is not None else {"model_name": name})
mock.configuration = _CHAT_TARGET_CONFIGURATION
return mock


Expand Down Expand Up @@ -168,9 +182,23 @@ def test_empty_list_adversarial_models_raises(self):

def test_unsupported_type_adversarial_models_raises(self):
"""Passing a non-list type must raise ValueError."""
with pytest.raises(ValueError, match="non-empty list|list of PromptChatTarget"):
with pytest.raises(ValueError, match="non-empty list|list of PromptTarget"):
AdversarialBenchmark(adversarial_models="not-a-list") # type: ignore[arg-type]

def test_adversarial_model_missing_chat_capabilities_raises(self):
    """Construction must fail for a target lacking the chat-target capabilities.

    The mock target declares neither multi-turn nor editable-history support,
    so ``AdversarialBenchmark`` is expected to raise ``ValueError``.
    """
    limited_capabilities = TargetCapabilities(
        supports_multi_turn=False,
        supports_editable_history=False,
    )
    single_turn_target = MagicMock(spec=PromptTarget)
    single_turn_target.configuration = TargetConfiguration(capabilities=limited_capabilities)
    single_turn_target.get_identifier.return_value = _mock_id("NonChatTarget")

    expected_message = "chat-target capability requirements"
    with pytest.raises(ValueError, match=expected_message):
        AdversarialBenchmark(adversarial_models=[single_turn_target])

def test_version_is_1(self):
    """The scenario's ``VERSION`` class attribute must equal 1."""
    declared_version = AdversarialBenchmark.VERSION
    assert declared_version == 1

Expand Down
Loading