From 2facd91be03ea5d8d84944e8ea5ed4e91e9087c1 Mon Sep 17 00:00:00 2001 From: romanlutz Date: Mon, 18 May 2026 06:03:10 -0700 Subject: [PATCH] MAINT: migrate AdversarialBenchmark off deprecated PromptChatTarget Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- doc/scanner/benchmark.ipynb | 3 +- doc/scanner/benchmark.py | 3 +- .../scenarios/benchmark/adversarial.py | 38 +++++++++++++------ tests/unit/scenario/test_adversarial.py | 38 ++++++++++++++++--- 4 files changed, 64 insertions(+), 18 deletions(-) diff --git a/doc/scanner/benchmark.ipynb b/doc/scanner/benchmark.ipynb index 88c92c1b7e..35b0ae3cf2 100644 --- a/doc/scanner/benchmark.ipynb +++ b/doc/scanner/benchmark.ipynb @@ -115,7 +115,8 @@ "\n", "await initialize_pyrit_async(memory_db_type=IN_MEMORY, initializers=[LoadDefaultDatasets()]) # type: ignore\n", "\n", - "# Pass any number of adversarial PromptChatTargets as a list; AdversarialBenchmark\n", + "# Pass any number of adversarial PromptTarget instances (with chat-target\n", + "# capabilities \u2014 multi-turn and editable history) as a list; AdversarialBenchmark\n", "# infers a label for each from its identifier and runs every benchmark-friendly\n", "# attack technique against the objective target with each adversarial model.\n", "adversarial_model = OpenAIChatTarget()\n", diff --git a/doc/scanner/benchmark.py b/doc/scanner/benchmark.py index 90dcba2a6d..177472af5c 100644 --- a/doc/scanner/benchmark.py +++ b/doc/scanner/benchmark.py @@ -26,7 +26,8 @@ await initialize_pyrit_async(memory_db_type=IN_MEMORY, initializers=[LoadDefaultDatasets()]) # type: ignore -# Pass any number of adversarial PromptChatTargets as a list; AdversarialBenchmark +# Pass any number of adversarial PromptTarget instances (with chat-target +# capabilities — multi-turn and editable history) as a list; AdversarialBenchmark # infers a label for each from its identifier and runs every benchmark-friendly # attack technique against the objective target with each adversarial model. adversarial_model = OpenAIChatTarget() diff --git a/pyrit/scenario/scenarios/benchmark/adversarial.py b/pyrit/scenario/scenarios/benchmark/adversarial.py index d3f873f67e..57024e79a0 100644 --- a/pyrit/scenario/scenarios/benchmark/adversarial.py +++ b/pyrit/scenario/scenarios/benchmark/adversarial.py @@ -10,6 +10,7 @@ from pyrit.common import apply_defaults from pyrit.executor.attack import AttackAdversarialConfig, AttackScoringConfig +from pyrit.prompt_target import CHAT_TARGET_REQUIREMENTS from pyrit.registry import AttackTechniqueRegistry, AttackTechniqueSpec from pyrit.registry.tag_query import TagQuery from pyrit.scenario.core.atomic_attack import AtomicAttack @@ -18,7 +19,7 @@ from pyrit.scenario.core.scenario_techniques import SCENARIO_TECHNIQUES if TYPE_CHECKING: - from pyrit.prompt_target import PromptChatTarget + from pyrit.prompt_target import PromptTarget from pyrit.scenario.core.scenario_strategy import ScenarioStrategy from pyrit.score import TrueFalseScorer @@ -79,7 +80,7 @@ def default_dataset_config(cls) -> DatasetConfiguration: def __init__( self, *, - adversarial_models: list[PromptChatTarget], + adversarial_models: list[PromptTarget], objective_scorer: TrueFalseScorer | None = None, scenario_result_id: str | None = None, ) -> None: @@ -87,7 +88,11 @@ def __init__( Initialize the AdversarialBenchmark scenario. Args: - adversarial_models: A non-empty list of ``PromptChatTarget`` instances. + adversarial_models: A non-empty list of ``PromptTarget`` instances + that each satisfy :data:`CHAT_TARGET_REQUIREMENTS` (multi-turn + with editable history). Individual techniques selected at + run time may impose stricter capability requirements which are + enforced when their attack instances are constructed. Labels are inferred from each target's identifier (preferring ``underlying_model_name`` over ``model_name`` over the class name). Identical targets are silently deduped and distinct @@ -99,13 +104,24 @@ def __init__( result to resume. Raises: - ValueError: If ``adversarial_models`` is empty or not a list. + ValueError: If ``adversarial_models`` is empty, not a list, or + contains a target that does not satisfy + :data:`CHAT_TARGET_REQUIREMENTS`. """ if not adversarial_models: - raise ValueError("adversarial_models must be a non-empty list of PromptChatTarget instances.") + raise ValueError("adversarial_models must be a non-empty list of PromptTarget instances.") if not isinstance(adversarial_models, list): - raise ValueError("adversarial_models must be a list of PromptChatTarget instances.") + raise ValueError("adversarial_models must be a list of PromptTarget instances.") + + for target in adversarial_models: + try: + CHAT_TARGET_REQUIREMENTS.validate(target=target) + except ValueError as exc: + raise ValueError( + f"adversarial_models entry {type(target).__name__} does not satisfy " + f"the chat-target capability requirements: {exc}" + ) from exc # Infer labels, then wrap each bare target in a default AttackAdversarialConfig # so it can be passed to factory.create() as an override. @@ -184,8 +200,8 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: @staticmethod def _infer_labels( *, - items: list[PromptChatTarget], - ) -> dict[str, PromptChatTarget]: + items: list[PromptTarget], + ) -> dict[str, PromptTarget]: """ Infer user-facing labels for a list of adversarial targets. @@ -195,14 +211,14 @@ def _infer_labels( and a ``logger.warning`` so the situation isn't silent. Args: - items: List of ``PromptChatTarget`` instances. + items: List of ``PromptTarget`` instances. Returns: - dict[str, PromptChatTarget]: Mapping from inferred label to the + dict[str, PromptTarget]: Mapping from inferred label to the original target. Targets are wrapped in an ``AttackAdversarialConfig`` by ``__init__`` after this call. """ - result: dict[str, PromptChatTarget] = {} + result: dict[str, PromptTarget] = {} seen_keys: dict[str, str | None] = {} for target in items: diff --git a/tests/unit/scenario/test_adversarial.py b/tests/unit/scenario/test_adversarial.py index c4dea8f3c5..4f4ccf7fed 100644 --- a/tests/unit/scenario/test_adversarial.py +++ b/tests/unit/scenario/test_adversarial.py @@ -20,8 +20,7 @@ SeedObjective, SeedPrompt, ) -from pyrit.prompt_target import PromptTarget -from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget, TargetCapabilities, TargetConfiguration from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.scenario.core import AtomicAttack, BaselinePolicy from pyrit.scenario.core.dataset_configuration import DatasetConfiguration @@ -59,17 +58,32 @@ def _mock_id(name: str, *, params: dict | None = None) -> ComponentIdentifier: return ComponentIdentifier(class_name=name, class_module="test", params=params or {}) +_CHAT_TARGET_CONFIGURATION = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_system_prompt=True, + supports_editable_history=True, + ), +) + + def _make_adversarial_target(name: str, *, params: dict | None = None) -> MagicMock: - """Create a mock PromptChatTarget with a given model name and optional identifier params. + """Create a mock adversarial PromptTarget with a given model name and optional identifier params. By default, ``model_name`` is stamped into the identifier params so the inferred label produced by ``_infer_labels`` matches ``name``. Pass an explicit ``params`` dict to override (e.g. to omit the key for collision testing or to add ``underlying_model_name`` / ``endpoint``). + + The mock exposes a real ``TargetConfiguration`` declaring multi-turn and + editable history so the target satisfies ``CHAT_TARGET_REQUIREMENTS`` at + construction time. """ - mock = MagicMock(spec=PromptChatTarget) + mock = MagicMock(spec=PromptTarget) mock._model_name = name mock.get_identifier.return_value = _mock_id(name, params=params if params is not None else {"model_name": name}) + mock.configuration = _CHAT_TARGET_CONFIGURATION return mock @@ -168,9 +182,23 @@ def test_empty_list_adversarial_models_raises(self): def test_unsupported_type_adversarial_models_raises(self): """Passing a non-list type must raise ValueError.""" - with pytest.raises(ValueError, match="non-empty list|list of PromptChatTarget"): + with pytest.raises(ValueError, match="non-empty list|list of PromptTarget"): AdversarialBenchmark(adversarial_models="not-a-list") # type: ignore[arg-type] + def test_adversarial_model_missing_chat_capabilities_raises(self): + """A target that does not satisfy CHAT_TARGET_REQUIREMENTS must be rejected at construction.""" + non_chat_target = MagicMock(spec=PromptTarget) + non_chat_target.get_identifier.return_value = _mock_id("NonChatTarget") + non_chat_target.configuration = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=False, + supports_editable_history=False, + ), + ) + + with pytest.raises(ValueError, match="chat-target capability requirements"): + AdversarialBenchmark(adversarial_models=[non_chat_target]) + def test_version_is_1(self): assert AdversarialBenchmark.VERSION == 1