From cd479a8ad5a161190e35f82ba16ea1f4874c940d Mon Sep 17 00:00:00 2001
From: Roman Lutz
Date: Fri, 29 May 2026 16:44:03 -0700
Subject: [PATCH 1/6] Define GCG extension protocols module (typing surface only)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adds pyrit/auxiliary_attacks/gcg/extension_protocols.py containing four
runtime_checkable Protocol classes that mark the algorithmic seams in the GCG
optimization loop where a future caller may substitute custom behavior:

- SamplingStrategy.sample_candidates — abstracts GCGPromptManager.sample_control
- LossFunction.compute_loss — abstracts the weighted target/control cross-entropy loss
- CandidateFilter.filter_candidates — abstracts MultiPromptAttack.get_filtered_cands
- SuffixInitializer.make_initial_suffix — abstracts the literal control_init plumbing

This PR is pure typing surface: no concrete implementations, no defaults, no
wiring into GCGAlgorithmConfig or GCGMultiPromptAttack. The default
implementations (extracted byte-for-byte from current attack code with a parity
gate) and the optional config fields that select between defaults and custom
impls land in follow-up PRs.

The module uses `from __future__ import annotations` plus a TYPE_CHECKING
import for torch so it imports cleanly on the base `dev` extra (no torch),
preserving the invariant added by commit 36aaaa31 in Sub-PR A.

All four Protocols are re-exported from pyrit.auxiliary_attacks.gcg via the
existing PEP 562 _LAZY_IMPORTS pathway so the public surface is consistent with
how GCG / GCGGenerator / GCGContext / GCGResult are exposed.

Tests cover module `__all__`, package re-export identity, runtime_checkable
positive and negative isinstance, and a return-shape smoke test per protocol
with a trivial in-test stub implementation. `pytest.importorskip("torch")`
gates the whole file because the stubs construct real `torch.Tensor` arguments
for the shape assertions.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/auxiliary_attacks/gcg/__init__.py | 18 ++ .../gcg/extension_protocols.py | 282 ++++++++++++++++++ .../gcg/test_extension_protocols.py | 197 ++++++++++++ 3 files changed, 497 insertions(+) create mode 100644 pyrit/auxiliary_attacks/gcg/extension_protocols.py create mode 100644 tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py diff --git a/pyrit/auxiliary_attacks/gcg/__init__.py b/pyrit/auxiliary_attacks/gcg/__init__.py index bff04bc1ec..b578f381e3 100644 --- a/pyrit/auxiliary_attacks/gcg/__init__.py +++ b/pyrit/auxiliary_attacks/gcg/__init__.py @@ -41,16 +41,30 @@ # only have the base `dev` extra (no torch). Touching any of these names from # the package root triggers the underlying module import on first access; if # torch is missing the user gets a clear ModuleNotFoundError pointing at torch. +# +# The extension Protocols live in ``extension_protocols`` (typing-only — that +# module imports cleanly without torch) but are routed through the same lazy +# mechanism so all GCG public symbols share one re-export pathway. 
_LAZY_IMPORTS = { + "CandidateFilter": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "CandidateFilter"), "GCG": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"), "GCGContext": ("pyrit.auxiliary_attacks.gcg.generator", "GCGContext"), "GCGGenerator": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"), "GCGResult": ("pyrit.auxiliary_attacks.gcg.generator", "GCGResult"), + "LossFunction": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "LossFunction"), + "SamplingStrategy": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SamplingStrategy"), + "SuffixInitializer": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SuffixInitializer"), "load_goals_and_targets": ("pyrit.auxiliary_attacks.gcg.data", "load_goals_and_targets"), } if TYPE_CHECKING: from pyrit.auxiliary_attacks.gcg.data import load_goals_and_targets + from pyrit.auxiliary_attacks.gcg.extension_protocols import ( + CandidateFilter, + LossFunction, + SamplingStrategy, + SuffixInitializer, + ) from pyrit.auxiliary_attacks.gcg.generator import ( GCGContext, GCGGenerator, @@ -76,6 +90,7 @@ def __dir__() -> list[str]: __all__ = [ + "CandidateFilter", "GCG", "GCGAlgorithmConfig", "GCGConfig", @@ -86,5 +101,8 @@ def __dir__() -> list[str]: "GCGOutputConfig", "GCGResult", "GCGStrategyConfig", + "LossFunction", + "SamplingStrategy", + "SuffixInitializer", "load_goals_and_targets", ] diff --git a/pyrit/auxiliary_attacks/gcg/extension_protocols.py b/pyrit/auxiliary_attacks/gcg/extension_protocols.py new file mode 100644 index 0000000000..d41c7443e8 --- /dev/null +++ b/pyrit/auxiliary_attacks/gcg/extension_protocols.py @@ -0,0 +1,282 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Typing-only extension points for the Greedy Coordinate Gradient (GCG) attack. 
+ +This module defines four ``runtime_checkable`` ``Protocol``s that mark the four +algorithmic seams inside the GCG optimization loop where a future caller may +substitute custom behavior: + +- ``SamplingStrategy`` — how candidate suffix token sequences are drawn from + the per-step gradient. +- ``LossFunction`` — how each candidate suffix is scored against the target + string. +- ``CandidateFilter`` — how proposed candidate token tensors get pruned and + decoded into the string form consumed by the evaluation pass. +- ``SuffixInitializer`` — how the initial suffix string fed into the + optimization loop is constructed. + +The module is **typing surface only**. It ships no concrete implementations, +no defaults, and no wiring into ``GCGAlgorithmConfig`` or +``GCGMultiPromptAttack``. The default behaviors that match the current attack +code will land as concrete classes in a follow-up PR; the optional +``GCGAlgorithmConfig`` fields that select between defaults and custom +implementations will land in the PR after that. + +Tensor-typed signatures are kept lazy via ``from __future__ import +annotations`` plus a ``TYPE_CHECKING`` import for ``torch`` so that +``pyrit.auxiliary_attacks.gcg.extension_protocols`` itself imports cleanly on +installs that only have the base ``dev`` extra (no torch). At call time the +implementations are still operating on real ``torch.Tensor`` objects — the +forward references just keep the runtime import side-effect free. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + import torch + + +@runtime_checkable +class SamplingStrategy(Protocol): + """Proposes a batch of candidate suffix token sequences from the gradient. + + Invoked once per GCG optimization step, after the per-worker gradients have + been aggregated. 
The implementation receives the aggregated gradient over + the control (suffix) positions and the current control token sequence, and + returns a batch of candidate replacement sequences for the search pass to + evaluate. + + Implementations must preserve two invariants: + + - The returned tensor has shape ``(batch_size, control_len)`` where + ``control_len == control_toks.shape[0]``. + - The returned tensor lives on the same device as ``gradient`` (the + orchestrator does not re-locate the result). + + The current GCG implementation (top-k by ``-grad``, uniform pick within the + top-k at one randomly-chosen position per row) lives in + ``GCGPromptManager.sample_control``. + + References: + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py`` lines 90-122 + (``GCGPromptManager.sample_control``). + """ + + def sample_candidates( + self, + *, + gradient: torch.Tensor, + control_toks: torch.Tensor, + batch_size: int, + topk: int, + temp: int, + allow_non_ascii: bool, + nonascii_toks: torch.Tensor, + ) -> torch.Tensor: + """Sample ``batch_size`` candidate suffix token sequences. + + Args: + gradient (torch.Tensor): Aggregated gradient over the control + tokens with shape ``(control_len, vocab_size)`` and dtype + matching the model's embedding matrix. + control_toks (torch.Tensor): The current suffix token sequence + with shape ``(control_len,)`` and integer dtype. + batch_size (int): Number of candidate suffix rows to return. + topk (int): Number of top gradient positions per control slot + that the strategy is permitted to draw from. + temp (int): Sampling temperature placeholder kept for API + compatibility with the legacy code path. The current default + strategy samples uniformly within the top-k and does not use + this value. + allow_non_ascii (bool): When False, the implementation must + ensure ``nonascii_toks`` are excluded from the candidate + vocabulary (typically by masking those positions of + ``gradient`` to ``+inf`` before top-k selection). 
+ nonascii_toks (torch.Tensor): Token ids to exclude when + ``allow_non_ascii`` is False, shape ``(num_disallowed,)`` + and integer dtype. + + Returns: + torch.Tensor: Candidate suffix token sequences with shape + ``(batch_size, control_len)`` on the same device as ``gradient``. + """ + ... + + +@runtime_checkable +class LossFunction(Protocol): + """Scores a batch of candidate suffixes against the target completion. + + Invoked once per worker per training prompt per candidate batch during the + search pass. The implementation receives the model's logits for the + candidate batch together with the input ids that produced them, plus the + slices that locate the target and control regions inside the sequence, and + returns a per-candidate scalar loss tensor. + + Owning the entire loss computation in one method (criterion choice, + slicing, and any weighted combination of target/control terms) keeps the + protocol orthogonal to the orchestrator: the caller does not need to know + whether the implementation uses cross-entropy or something else, nor how + target and control contributions are combined. The current GCG code path + weights cross-entropy on the target slice by ``target_weight`` and on the + control slice by ``control_weight`` and sums them; a custom + ``LossFunction`` would encapsulate equivalent knobs in its own + constructor. + + Implementations must preserve one invariant: + + - The returned tensor has shape ``(batch_size,)`` — one scalar loss per + candidate. Lower values indicate a better candidate (the orchestrator + selects the ``argmin``). + + References: + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py`` lines + 326-336 (``AttackPrompt.target_loss`` and + ``AttackPrompt.control_loss``) and + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py`` lines + 207-215 (the weighted-sum aggregation inside + ``GCGMultiPromptAttack.step``). 
+ """ + + def compute_loss( + self, + *, + logits: torch.Tensor, + ids: torch.Tensor, + target_slice: slice, + control_slice: slice, + ) -> torch.Tensor: + """Compute the per-candidate loss for a candidate batch. + + Args: + logits (torch.Tensor): Model logits for the candidate batch with + shape ``(batch_size, seq_len, vocab_size)``. + ids (torch.Tensor): Input token ids the model was run on with + shape ``(batch_size, seq_len)`` and integer dtype. + target_slice (slice): Slice into the sequence dimension that + identifies the target tokens (the completion being optimized + toward). + control_slice (slice): Slice into the sequence dimension that + identifies the control (suffix) tokens. + + Returns: + torch.Tensor: Per-candidate scalar loss with shape + ``(batch_size,)``. Lower is better. + """ + ... + + +@runtime_checkable +class CandidateFilter(Protocol): + """Decodes and prunes a batch of candidate suffix token tensors. + + Invoked once per worker per optimization step, immediately after sampling. + The implementation receives the raw sampled token tensor and the worker's + tokenizer, and is expected to: + + - Decode each row into its string form (this is what the evaluation pass + consumes — it sends the candidate strings back through the tokenizer + together with the goal prompt). + - Optionally drop candidates that fail some quality check. + - Return *exactly* ``batch_size`` strings (the orchestrator allocates a + flat loss buffer of size ``batch_size`` and does not tolerate ragged + outputs). The current implementation pads any dropped rows by repeating + the last accepted candidate. + + Implementations must preserve two invariants: + + - The returned list has length equal to ``candidate_toks.shape[0]``. + - No element of the returned list equals ``current_control`` *unless* the + implementation explicitly allows the no-op candidate (the legacy + filter drops it on the assumption that re-evaluating the current + suffix wastes a slot). 
+ + References: + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py`` lines + 617-647 (``MultiPromptAttack.get_filtered_cands``). + """ + + def filter_candidates( + self, + *, + candidate_toks: torch.Tensor, + tokenizer: Any, + current_control: str, + ) -> list[str]: + """Decode and filter a batch of candidate suffix token tensors. + + Args: + candidate_toks (torch.Tensor): Sampled candidate suffixes with + shape ``(batch_size, control_len)`` and integer dtype. + tokenizer (Any): The worker's HuggingFace-style tokenizer. + ``tokenizer.decode`` is used to render each row to text and + ``tokenizer(text, add_special_tokens=False).input_ids`` is + used by the default length-preserving filter to detect + re-tokenization drift. + current_control (str): The current suffix string. Used by the + default filter to drop the no-op candidate. + + Returns: + list[str]: Decoded candidate suffix strings of length exactly + ``candidate_toks.shape[0]``. Dropped rows are typically padded by + repeating the last accepted candidate. + """ + ... + + +@runtime_checkable +class SuffixInitializer(Protocol): + """Produces the initial suffix string fed into the optimization loop. + + Invoked once at attack-setup time. The returned string is threaded through + the existing ``control_init`` parameter of ``AttackPrompt`` / + ``PromptManager`` / ``MultiPromptAttack`` / the per-strategy ``*Attack`` + constructors, so a custom initializer is fully decoupled from the per-step + optimization machinery. + + The tokenizer is passed in case an implementation needs vocab access (for + example, a random initializer that draws tokens from the model's + vocabulary). The current default — a literal string supplied via + ``GCGAlgorithmConfig.control_init`` — ignores the tokenizer and returns + the configured string verbatim. + + Implementations must return a non-empty string. 
The downstream + ``AttackPrompt`` constructor raises ``ValueError`` if the suffix cannot be + located inside the chat-templated prompt, so the returned string must also + survive tokenizer round-tripping inside the goal+control message body + (the default twenty space-separated ``!`` tokens satisfies this for every + chat template PyRIT has been tested against). + + References: + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py`` line + 158 (``AttackPrompt.__init__`` assigns ``self.control = control_init``) + and the ``control_init`` parameter threaded through lines 134, 419, + 541, 875, 1121, and 1333 of the same file. + """ + + def make_initial_suffix(self, *, tokenizer: Any) -> str: + """Return the initial suffix string for the optimization loop. + + Args: + tokenizer (Any): A HuggingFace-style tokenizer the implementation + may consult (for example, to sample tokens from the + vocabulary). The default literal-string initializer ignores + this argument. + + Returns: + str: The initial suffix string. Must be non-empty and must + survive tokenizer round-tripping inside the chat-templated + prompt body. + """ + ... + + +__all__ = [ + "CandidateFilter", + "LossFunction", + "SamplingStrategy", + "SuffixInitializer", +] diff --git a/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py b/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py new file mode 100644 index 0000000000..0e88b95b31 --- /dev/null +++ b/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for :mod:`pyrit.auxiliary_attacks.gcg.extension_protocols`. + +These are typing-surface tests: they verify the four ``Protocol``s are +exposed, are ``runtime_checkable``, and accept a minimal in-test concrete +implementation. 
They do not exercise any real GCG attack code; the default +implementations and the wiring of these protocols into ``GCGAlgorithmConfig`` / +``GCGMultiPromptAttack`` land in follow-up PRs. +""" + +from typing import Any + +import pytest + +# The protocol method signatures reference ``torch.Tensor``. The test bodies +# instantiate real tensors when calling the stub implementations, so the whole +# file is skipped on installs that only have the base ``dev`` extra. +torch = pytest.importorskip("torch", reason="GCG extension protocols reference torch.Tensor") + +import pyrit.auxiliary_attacks.gcg as gcg_pkg # noqa: E402 +from pyrit.auxiliary_attacks.gcg import ( # noqa: E402 + CandidateFilter, + LossFunction, + SamplingStrategy, + SuffixInitializer, +) +from pyrit.auxiliary_attacks.gcg import extension_protocols as protocols_module # noqa: E402 + +PROTOCOL_NAMES = ( + "CandidateFilter", + "LossFunction", + "SamplingStrategy", + "SuffixInitializer", +) + + +def test_module_exports_exactly_four_protocols() -> None: + assert set(protocols_module.__all__) == set(PROTOCOL_NAMES) + + +def test_protocols_are_reexported_from_package_with_identity() -> None: + for name in PROTOCOL_NAMES: + package_attr = getattr(gcg_pkg, name) + module_attr = getattr(protocols_module, name) + assert package_attr is module_attr, ( + f"{name} re-exported from pyrit.auxiliary_attacks.gcg must be the same " + f"object as pyrit.auxiliary_attacks.gcg.extension_protocols.{name}" + ) + + +def test_protocols_are_in_package_dunder_all() -> None: + for name in PROTOCOL_NAMES: + assert name in gcg_pkg.__all__ + + +@pytest.mark.parametrize("name", PROTOCOL_NAMES) +def test_protocols_are_runtime_checkable(name: str) -> None: + proto = getattr(protocols_module, name) + # ``runtime_checkable`` marks the protocol with ``_is_runtime_protocol = True``; + # ``isinstance(obj, Proto)`` is only legal on runtime-checkable protocols. 
+ assert getattr(proto, "_is_runtime_protocol", False), ( + f"{name} must be decorated with @runtime_checkable so isinstance() checks work" + ) + + +class _StubSamplingStrategy: + def sample_candidates( + self, + *, + gradient: torch.Tensor, + control_toks: torch.Tensor, + batch_size: int, + topk: int, + temp: int, + allow_non_ascii: bool, + nonascii_toks: torch.Tensor, + ) -> torch.Tensor: + return control_toks.unsqueeze(0).repeat(batch_size, 1) + + +class _StubLossFunction: + def compute_loss( + self, + *, + logits: torch.Tensor, + ids: torch.Tensor, + target_slice: slice, + control_slice: slice, + ) -> torch.Tensor: + return torch.zeros(logits.shape[0]) + + +class _StubCandidateFilter: + def filter_candidates( + self, + *, + candidate_toks: torch.Tensor, + tokenizer: Any, + current_control: str, + ) -> list[str]: + return ["stub"] * candidate_toks.shape[0] + + +class _StubSuffixInitializer: + def make_initial_suffix(self, *, tokenizer: Any) -> str: + return "! ! ! !" + + +def test_sampling_strategy_accepts_minimal_impl() -> None: + impl = _StubSamplingStrategy() + assert isinstance(impl, SamplingStrategy) + + +def test_loss_function_accepts_minimal_impl() -> None: + impl = _StubLossFunction() + assert isinstance(impl, LossFunction) + + +def test_candidate_filter_accepts_minimal_impl() -> None: + impl = _StubCandidateFilter() + assert isinstance(impl, CandidateFilter) + + +def test_suffix_initializer_accepts_minimal_impl() -> None: + impl = _StubSuffixInitializer() + assert isinstance(impl, SuffixInitializer) + + +class _ClassWithoutAnyProtocolMethods: + """Has none of the protocol methods. 
Must fail all isinstance checks.""" + + +@pytest.mark.parametrize( + "proto", + [SamplingStrategy, LossFunction, CandidateFilter, SuffixInitializer], +) +def test_class_missing_protocol_method_fails_isinstance(proto: type) -> None: + bare = _ClassWithoutAnyProtocolMethods() + assert not isinstance(bare, proto), ( + f"A class missing every method must NOT satisfy {proto.__name__}; " + "if this assertion fires, the protocol signature has drifted to require nothing." + ) + + +def test_sampling_strategy_stub_returns_expected_shape() -> None: + impl = _StubSamplingStrategy() + control_toks = torch.tensor([1, 2, 3, 4], dtype=torch.long) + gradient = torch.zeros((4, 100)) + nonascii_toks = torch.tensor([], dtype=torch.long) + + out = impl.sample_candidates( + gradient=gradient, + control_toks=control_toks, + batch_size=5, + topk=8, + temp=1, + allow_non_ascii=True, + nonascii_toks=nonascii_toks, + ) + assert out.shape == (5, 4) + + +def test_loss_function_stub_returns_expected_shape() -> None: + impl = _StubLossFunction() + logits = torch.zeros((3, 10, 50)) + ids = torch.zeros((3, 10), dtype=torch.long) + + out = impl.compute_loss( + logits=logits, + ids=ids, + target_slice=slice(5, 8), + control_slice=slice(2, 5), + ) + assert out.shape == (3,) + + +def test_candidate_filter_stub_returns_expected_length() -> None: + impl = _StubCandidateFilter() + candidate_toks = torch.zeros((7, 4), dtype=torch.long) + + out = impl.filter_candidates( + candidate_toks=candidate_toks, + tokenizer=object(), + current_control="prev", + ) + assert isinstance(out, list) + assert len(out) == 7 + assert all(isinstance(item, str) for item in out) + + +def test_suffix_initializer_stub_returns_string() -> None: + impl = _StubSuffixInitializer() + out = impl.make_initial_suffix(tokenizer=object()) + assert isinstance(out, str) + assert out From 34b6cc92c251950190db2e7c073db760ccffb166 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Fri, 29 May 2026 20:26:38 -0700 Subject: [PATCH 2/6] Spell out 
abbreviated parameter names in extension protocols Renames in pyrit/auxiliary_attacks/gcg/extension_protocols.py and the corresponding test stubs: control_toks -> control_tokens candidate_toks -> candidate_tokens nonascii_toks -> non_ascii_tokens (mirrors allow_non_ascii) topk -> top_k temp -> temperature control_len -> control_length (docstring shape annotations) The legacy GCGAlgorithmConfig fields (topk, temp) and the legacy attack code (GCGPromptManager.sample_control, get_filtered_cands) keep their existing names. Renaming those is a separate API change that belongs in the B3 wiring PR (where GCGAlgorithmConfig is extended anyway). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../gcg/extension_protocols.py | 41 ++++++++++--------- .../gcg/test_extension_protocols.py | 30 +++++++------- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/pyrit/auxiliary_attacks/gcg/extension_protocols.py b/pyrit/auxiliary_attacks/gcg/extension_protocols.py index d41c7443e8..7b06c97efd 100644 --- a/pyrit/auxiliary_attacks/gcg/extension_protocols.py +++ b/pyrit/auxiliary_attacks/gcg/extension_protocols.py @@ -51,8 +51,8 @@ class SamplingStrategy(Protocol): Implementations must preserve two invariants: - - The returned tensor has shape ``(batch_size, control_len)`` where - ``control_len == control_toks.shape[0]``. + - The returned tensor has shape ``(batch_size, control_length)`` where + ``control_length == control_tokens.shape[0]``. - The returned tensor lives on the same device as ``gradient`` (the orchestrator does not re-locate the result). @@ -69,39 +69,40 @@ def sample_candidates( self, *, gradient: torch.Tensor, - control_toks: torch.Tensor, + control_tokens: torch.Tensor, batch_size: int, - topk: int, - temp: int, + top_k: int, + temperature: int, allow_non_ascii: bool, - nonascii_toks: torch.Tensor, + non_ascii_tokens: torch.Tensor, ) -> torch.Tensor: """Sample ``batch_size`` candidate suffix token sequences. 
Args: gradient (torch.Tensor): Aggregated gradient over the control - tokens with shape ``(control_len, vocab_size)`` and dtype + tokens with shape ``(control_length, vocab_size)`` and dtype matching the model's embedding matrix. - control_toks (torch.Tensor): The current suffix token sequence - with shape ``(control_len,)`` and integer dtype. + control_tokens (torch.Tensor): The current suffix token sequence + with shape ``(control_length,)`` and integer dtype. batch_size (int): Number of candidate suffix rows to return. - topk (int): Number of top gradient positions per control slot + top_k (int): Number of top gradient positions per control slot that the strategy is permitted to draw from. - temp (int): Sampling temperature placeholder kept for API + temperature (int): Sampling temperature placeholder kept for API compatibility with the legacy code path. The current default strategy samples uniformly within the top-k and does not use this value. allow_non_ascii (bool): When False, the implementation must - ensure ``nonascii_toks`` are excluded from the candidate + ensure ``non_ascii_tokens`` are excluded from the candidate vocabulary (typically by masking those positions of ``gradient`` to ``+inf`` before top-k selection). - nonascii_toks (torch.Tensor): Token ids to exclude when + non_ascii_tokens (torch.Tensor): Token ids to exclude when ``allow_non_ascii`` is False, shape ``(num_disallowed,)`` and integer dtype. Returns: torch.Tensor: Candidate suffix token sequences with shape - ``(batch_size, control_len)`` on the same device as ``gradient``. + ``(batch_size, control_length)`` on the same device as + ``gradient``. """ ... @@ -188,7 +189,7 @@ class CandidateFilter(Protocol): Implementations must preserve two invariants: - - The returned list has length equal to ``candidate_toks.shape[0]``. + - The returned list has length equal to ``candidate_tokens.shape[0]``. 
- No element of the returned list equals ``current_control`` *unless* the implementation explicitly allows the no-op candidate (the legacy filter drops it on the assumption that re-evaluating the current @@ -202,15 +203,15 @@ class CandidateFilter(Protocol): def filter_candidates( self, *, - candidate_toks: torch.Tensor, + candidate_tokens: torch.Tensor, tokenizer: Any, current_control: str, ) -> list[str]: """Decode and filter a batch of candidate suffix token tensors. Args: - candidate_toks (torch.Tensor): Sampled candidate suffixes with - shape ``(batch_size, control_len)`` and integer dtype. + candidate_tokens (torch.Tensor): Sampled candidate suffixes with + shape ``(batch_size, control_length)`` and integer dtype. tokenizer (Any): The worker's HuggingFace-style tokenizer. ``tokenizer.decode`` is used to render each row to text and ``tokenizer(text, add_special_tokens=False).input_ids`` is @@ -221,8 +222,8 @@ def filter_candidates( Returns: list[str]: Decoded candidate suffix strings of length exactly - ``candidate_toks.shape[0]``. Dropped rows are typically padded by - repeating the last accepted candidate. + ``candidate_tokens.shape[0]``. Dropped rows are typically padded + by repeating the last accepted candidate. """ ... 
diff --git a/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py b/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py index 0e88b95b31..896504d3a9 100644 --- a/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py +++ b/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py @@ -70,14 +70,14 @@ def sample_candidates( self, *, gradient: torch.Tensor, - control_toks: torch.Tensor, + control_tokens: torch.Tensor, batch_size: int, - topk: int, - temp: int, + top_k: int, + temperature: int, allow_non_ascii: bool, - nonascii_toks: torch.Tensor, + non_ascii_tokens: torch.Tensor, ) -> torch.Tensor: - return control_toks.unsqueeze(0).repeat(batch_size, 1) + return control_tokens.unsqueeze(0).repeat(batch_size, 1) class _StubLossFunction: @@ -96,11 +96,11 @@ class _StubCandidateFilter: def filter_candidates( self, *, - candidate_toks: torch.Tensor, + candidate_tokens: torch.Tensor, tokenizer: Any, current_control: str, ) -> list[str]: - return ["stub"] * candidate_toks.shape[0] + return ["stub"] * candidate_tokens.shape[0] class _StubSuffixInitializer: @@ -146,18 +146,18 @@ def test_class_missing_protocol_method_fails_isinstance(proto: type) -> None: def test_sampling_strategy_stub_returns_expected_shape() -> None: impl = _StubSamplingStrategy() - control_toks = torch.tensor([1, 2, 3, 4], dtype=torch.long) + control_tokens = torch.tensor([1, 2, 3, 4], dtype=torch.long) gradient = torch.zeros((4, 100)) - nonascii_toks = torch.tensor([], dtype=torch.long) + non_ascii_tokens = torch.tensor([], dtype=torch.long) out = impl.sample_candidates( gradient=gradient, - control_toks=control_toks, + control_tokens=control_tokens, batch_size=5, - topk=8, - temp=1, + top_k=8, + temperature=1, allow_non_ascii=True, - nonascii_toks=nonascii_toks, + non_ascii_tokens=non_ascii_tokens, ) assert out.shape == (5, 4) @@ -178,10 +178,10 @@ def test_loss_function_stub_returns_expected_shape() -> None: def test_candidate_filter_stub_returns_expected_length() -> None: 
impl = _StubCandidateFilter() - candidate_toks = torch.zeros((7, 4), dtype=torch.long) + candidate_tokens = torch.zeros((7, 4), dtype=torch.long) out = impl.filter_candidates( - candidate_toks=candidate_toks, + candidate_tokens=candidate_tokens, tokenizer=object(), current_control="prev", ) From b1e6c01d76fe420977dc36d24c6c89674195b869 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Fri, 29 May 2026 20:38:46 -0700 Subject: [PATCH 3/6] Rename LossFunction.compute_loss ids parameter to token_ids Completes the parameter-name spell-out pass on the four extension protocols (previous commit handled SamplingStrategy / CandidateFilter). `ids` is a common ML shorthand but `token_ids` is unambiguous and consistent with the other `*_tokens` parameters in the same module. Descriptive uses of the word `ids` in surrounding docstring prose are left as-is since they read naturally. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/auxiliary_attacks/gcg/extension_protocols.py | 6 +++--- .../unit/auxiliary_attacks/gcg/test_extension_protocols.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pyrit/auxiliary_attacks/gcg/extension_protocols.py b/pyrit/auxiliary_attacks/gcg/extension_protocols.py index 7b06c97efd..c07f097684 100644 --- a/pyrit/auxiliary_attacks/gcg/extension_protocols.py +++ b/pyrit/auxiliary_attacks/gcg/extension_protocols.py @@ -146,7 +146,7 @@ def compute_loss( self, *, logits: torch.Tensor, - ids: torch.Tensor, + token_ids: torch.Tensor, target_slice: slice, control_slice: slice, ) -> torch.Tensor: @@ -155,8 +155,8 @@ def compute_loss( Args: logits (torch.Tensor): Model logits for the candidate batch with shape ``(batch_size, seq_len, vocab_size)``. - ids (torch.Tensor): Input token ids the model was run on with - shape ``(batch_size, seq_len)`` and integer dtype. + token_ids (torch.Tensor): Input token ids the model was run on + with shape ``(batch_size, seq_len)`` and integer dtype. 
target_slice (slice): Slice into the sequence dimension that identifies the target tokens (the completion being optimized toward). diff --git a/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py b/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py index 896504d3a9..cd65c18016 100644 --- a/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py +++ b/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py @@ -85,7 +85,7 @@ def compute_loss( self, *, logits: torch.Tensor, - ids: torch.Tensor, + token_ids: torch.Tensor, target_slice: slice, control_slice: slice, ) -> torch.Tensor: @@ -165,11 +165,11 @@ def test_sampling_strategy_stub_returns_expected_shape() -> None: def test_loss_function_stub_returns_expected_shape() -> None: impl = _StubLossFunction() logits = torch.zeros((3, 10, 50)) - ids = torch.zeros((3, 10), dtype=torch.long) + token_ids = torch.zeros((3, 10), dtype=torch.long) out = impl.compute_loss( logits=logits, - ids=ids, + token_ids=token_ids, target_slice=slice(5, 8), control_slice=slice(2, 5), ) From 97cc3ba7627d683726f70c2428d76bd3dd025a31 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 2 Jun 2026 15:18:25 -0700 Subject: [PATCH 4/6] Address review feedback: drop line refs, type temperature as float MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per @rlundeen2's review on PR #1861: 1. Replace `References:` blocks that cited line ranges in `gcg_attack.py` / `attack_manager.py` with symbol-only references. Line numbers drift the moment the legacy attack code is touched (B3 wiring will do exactly that); symbol names are stable across the refactors that follow. 2. Re-type `SamplingStrategy.sample_candidates(temperature=)` as `float` instead of `int`. The protocol is a brand-new surface and was previously mirroring the legacy `GCGAlgorithmConfig.temp: int = 1` field for no good reason — sampling temperatures are conceptually continuous. 
The legacy field stays as-is; B3 wiring owns deciding whether to widen it or coerce at the boundary. The stub used by the runtime-checkable tests is updated to match, and the shape-smoke test now passes `temperature=1.0`. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../gcg/extension_protocols.py | 32 +++++++++---------- .../gcg/test_extension_protocols.py | 4 +-- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pyrit/auxiliary_attacks/gcg/extension_protocols.py b/pyrit/auxiliary_attacks/gcg/extension_protocols.py index c07f097684..4c5bc29a51 100644 --- a/pyrit/auxiliary_attacks/gcg/extension_protocols.py +++ b/pyrit/auxiliary_attacks/gcg/extension_protocols.py @@ -61,8 +61,8 @@ class SamplingStrategy(Protocol): ``GCGPromptManager.sample_control``. References: - ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py`` lines 90-122 - (``GCGPromptManager.sample_control``). + ``GCGPromptManager.sample_control`` in + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py``. """ def sample_candidates( @@ -72,7 +72,7 @@ def sample_candidates( control_tokens: torch.Tensor, batch_size: int, top_k: int, - temperature: int, + temperature: float, allow_non_ascii: bool, non_ascii_tokens: torch.Tensor, ) -> torch.Tensor: @@ -87,7 +87,7 @@ def sample_candidates( batch_size (int): Number of candidate suffix rows to return. top_k (int): Number of top gradient positions per control slot that the strategy is permitted to draw from. - temperature (int): Sampling temperature placeholder kept for API + temperature (float): Sampling temperature placeholder kept for API compatibility with the legacy code path. The current default strategy samples uniformly within the top-k and does not use this value. @@ -134,12 +134,10 @@ class LossFunction(Protocol): selects the ``argmin``). 
References: - ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py`` lines - 326-336 (``AttackPrompt.target_loss`` and - ``AttackPrompt.control_loss``) and - ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py`` lines - 207-215 (the weighted-sum aggregation inside - ``GCGMultiPromptAttack.step``). + ``AttackPrompt.target_loss`` and ``AttackPrompt.control_loss`` in + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``, plus + the weighted-sum aggregation inside ``GCGMultiPromptAttack.step`` in + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py``. """ def compute_loss( @@ -196,8 +194,8 @@ class CandidateFilter(Protocol): suffix wastes a slot). References: - ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py`` lines - 617-647 (``MultiPromptAttack.get_filtered_cands``). + ``MultiPromptAttack.get_filtered_cands`` in + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``. """ def filter_candidates( @@ -252,10 +250,12 @@ class SuffixInitializer(Protocol): chat template PyRIT has been tested against). References: - ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py`` line - 158 (``AttackPrompt.__init__`` assigns ``self.control = control_init``) - and the ``control_init`` parameter threaded through lines 134, 419, - 541, 875, 1121, and 1333 of the same file. + ``AttackPrompt.__init__`` in + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py`` assigns + ``self.control = control_init``. The same ``control_init`` parameter + is threaded through the ``PromptManager``, ``MultiPromptAttack``, + ``ProgressiveMultiPromptAttack``, ``IndividualPromptAttack``, and + ``EvaluateAttack`` constructors in the same module. 
""" def make_initial_suffix(self, *, tokenizer: Any) -> str: diff --git a/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py b/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py index cd65c18016..404ef4b541 100644 --- a/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py +++ b/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py @@ -73,7 +73,7 @@ def sample_candidates( control_tokens: torch.Tensor, batch_size: int, top_k: int, - temperature: int, + temperature: float, allow_non_ascii: bool, non_ascii_tokens: torch.Tensor, ) -> torch.Tensor: @@ -155,7 +155,7 @@ def test_sampling_strategy_stub_returns_expected_shape() -> None: control_tokens=control_tokens, batch_size=5, top_k=8, - temperature=1, + temperature=1.0, allow_non_ascii=True, non_ascii_tokens=non_ascii_tokens, ) From 7851fb1d5af909ebfa393c6521d38a4a206681be Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 2 Jun 2026 15:49:42 -0700 Subject: [PATCH 5/6] Widen temperature field to float across GCG surface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit GCGAlgorithmConfig.temp goes from `int = 1` to `float = 1.0`. The matching parameter on the three downstream methods that still typed it as `int` is widened too: - GCGPromptManager.sample_control - GCGMultiPromptAttack.step - MultiPromptAttack.run The other two strategy `run` overloads (ProgressiveMultiPromptAttack, IndividualPromptAttack) were already `float = 1.0` — the pre-existing inconsistency is now resolved. Sampling temperature is conceptually continuous; typing it as int in a brand-new public-API field made no sense. The module is experimental, no deprecation cycle owed. Also updates the SamplingStrategy protocol docstring to drop the stale "kept for API compatibility with the legacy code path" framing in favour of a description of why the parameter exists (the default sampler ignores it, but custom strategies that want softmax weighting receive it). 
While here, replace seven Sphinx reST cross-reference roles (`:class:...`, `:meth:...`, `:func:...`) in `config.py` with plain double-backtick code spans. PyRIT renders docstrings with MyST, not Sphinx — these roles show up as raw literal text in the built docs and are now blocked by the `check-no-rest-roles` pre-commit hook. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../gcg/attack/base/attack_manager.py | 2 +- .../gcg/attack/gcg/gcg_attack.py | 8 ++++---- pyrit/auxiliary_attacks/gcg/config.py | 20 +++++++++---------- .../gcg/extension_protocols.py | 9 +++++---- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py index 6e7991ea30..787154fe75 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py +++ b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py @@ -654,7 +654,7 @@ def run( n_steps: int = 100, batch_size: int = 1024, topk: int = 256, - temp: int = 1, + temp: float = 1.0, allow_non_ascii: bool = True, target_weight: Optional[float] = None, control_weight: Optional[float] = None, diff --git a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py index 3fb5a8aa46..4df1ae9205 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py +++ b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py @@ -92,7 +92,7 @@ def sample_control( grad: torch.Tensor, batch_size: int, topk: int = 256, - temp: int = 1, + temp: float = 1.0, allow_non_ascii: bool = True, ) -> torch.Tensor: """ @@ -102,7 +102,7 @@ def sample_control( grad (torch.Tensor): Gradient tensor for control tokens. batch_size (int): Number of candidate controls to generate. topk (int): Number of top gradient positions to sample from. Defaults to 256. - temp (int): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1. 
+ temp (float): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.0. allow_non_ascii (bool): Whether to allow non-ASCII tokens. Defaults to True. Returns: @@ -130,7 +130,7 @@ def step( *, batch_size: int = 1024, topk: int = 256, - temp: int = 1, + temp: float = 1.0, allow_non_ascii: bool = True, target_weight: float = 1, control_weight: float = 0.1, @@ -146,7 +146,7 @@ def step( Args: batch_size (int): Number of candidate controls per batch. Defaults to 1024. topk (int): Number of top gradient positions to sample from. Defaults to 256. - temp (int): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1. + temp (float): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.0. allow_non_ascii (bool): Whether to allow non-ASCII tokens. Defaults to True. target_weight (float): Weight for target loss. Defaults to 1. control_weight (float): Weight for control loss. Defaults to 0.1. diff --git a/pyrit/auxiliary_attacks/gcg/config.py b/pyrit/auxiliary_attacks/gcg/config.py index cd5ee405e7..097a9087af 100644 --- a/pyrit/auxiliary_attacks/gcg/config.py +++ b/pyrit/auxiliary_attacks/gcg/config.py @@ -71,7 +71,7 @@ class GCGDataConfig: Used as a typed bundle for AML transport (a job ships its data config as a separate JSON file alongside the strategy ``GCGConfig``). Library callers loading goals/targets from a CSV can construct one and pass it to - :func:`pyrit.auxiliary_attacks.gcg.data.load_goals_and_targets`. + ``pyrit.auxiliary_attacks.gcg.data.load_goals_and_targets``. Attributes: train_data (str): URL or filesystem path to the training-data CSV. 
Empty @@ -100,7 +100,7 @@ def to_json(self) -> str: @classmethod def from_json(cls, payload: str) -> GCGDataConfig: - """Deserialize a config previously produced by :meth:`to_json`.""" + """Deserialize a config previously produced by ``to_json``.""" try: data = json.loads(payload) except json.JSONDecodeError as e: @@ -131,8 +131,8 @@ class GCGAlgorithmConfig: Defaults to 512. topk (int): Top-k gradient positions considered for substitution. Defaults to 256. - temp (int): Sampling temperature placeholder; the current sampling - implementation samples uniformly from the top-k. Defaults to 1. + temp (float): Sampling temperature placeholder; the current sampling + implementation samples uniformly from the top-k. Defaults to 1.0. target_weight (float): Weight on the target-string cross-entropy loss. Defaults to 1.0. control_weight (float): Weight on the control-string cross-entropy loss. @@ -153,7 +153,7 @@ class GCGAlgorithmConfig: test_steps: int = 50 batch_size: int = 512 topk: int = 256 - temp: int = 1 + temp: float = 1.0 target_weight: float = 1.0 control_weight: float = 0.0 learning_rate: float = 0.01 @@ -240,10 +240,10 @@ class GCGOutputConfig: class GCGConfig: """Top-level strategy configuration for one GCG attack run. - Bundles everything :class:`pyrit.auxiliary_attacks.gcg.GCGGenerator`'s + Bundles everything ``pyrit.auxiliary_attacks.gcg.GCGGenerator``'s constructor needs. Per-execution data (goals, targets) is **not** here — those flow through ``GCGGenerator.execute_async``, and for AML transport - they ride alongside this object as a separate :class:`GCGDataConfig` JSON. + they ride alongside this object as a separate ``GCGDataConfig`` JSON. Attributes: models (list[GCGModelConfig]): Training models the attack optimizes @@ -287,11 +287,11 @@ def to_json(self) -> str: @classmethod def from_json(cls, payload: str) -> GCGConfig: - """Deserialize a config previously produced by :meth:`to_json`. + """Deserialize a config previously produced by ``to_json``. 
Args: payload (str): JSON document matching the shape produced by - :meth:`to_json`. + ``to_json``. Returns: GCGConfig: A new ``GCGConfig`` reconstructed from ``payload``. @@ -308,7 +308,7 @@ def from_json(cls, payload: str) -> GCGConfig: @classmethod def from_json_file(cls, path: str | Path) -> GCGConfig: - """Load a config from a JSON file produced by :meth:`to_json_file`. + """Load a config from a JSON file produced by ``to_json_file``. Args: path (str | Path): Filesystem path to a JSON config file. diff --git a/pyrit/auxiliary_attacks/gcg/extension_protocols.py b/pyrit/auxiliary_attacks/gcg/extension_protocols.py index 4c5bc29a51..f9f1a3013e 100644 --- a/pyrit/auxiliary_attacks/gcg/extension_protocols.py +++ b/pyrit/auxiliary_attacks/gcg/extension_protocols.py @@ -87,10 +87,11 @@ def sample_candidates( batch_size (int): Number of candidate suffix rows to return. top_k (int): Number of top gradient positions per control slot that the strategy is permitted to draw from. - temperature (float): Sampling temperature placeholder kept for API - compatibility with the legacy code path. The current default - strategy samples uniformly within the top-k and does not use - this value. + temperature (float): Sampling temperature. The current default + sampling strategy samples uniformly within the top-k and does + not use this value; it is part of the protocol so custom + strategies that need it (for example, softmax weighting) can + receive it. 
allow_non_ascii (bool): When False, the implementation must ensure ``non_ascii_tokens`` are excluded from the candidate vocabulary (typically by masking those positions of From 1595fbbc6bc0a8233478011d526fd84c20ae6f35 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 2 Jun 2026 15:58:48 -0700 Subject: [PATCH 6/6] Drop Sphinx reST roles from gcg/__init__.py module docstring The `check-no-rest-roles` pre-commit hook blocks `:class:Foo` patterns; PyRIT renders docstrings with MyST, not Sphinx, so those roles appear as raw literal text in the built docs. Three `:class:...` roles in the module-level docstring (`GCG`, `GCGGenerator`, `PromptGeneratorStrategy`) are replaced with plain double-backtick code spans per the documented convention. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/auxiliary_attacks/gcg/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrit/auxiliary_attacks/gcg/__init__.py b/pyrit/auxiliary_attacks/gcg/__init__.py index b578f381e3..a10d862fe3 100644 --- a/pyrit/auxiliary_attacks/gcg/__init__.py +++ b/pyrit/auxiliary_attacks/gcg/__init__.py @@ -3,8 +3,8 @@ """Public API for the Greedy Coordinate Gradient (GCG) auxiliary attack. -The primary entry point is :class:`GCG` (alias for :class:`GCGGenerator`), a -:class:`pyrit.executor.promptgen.core.PromptGeneratorStrategy` that produces +The primary entry point is ``GCG`` (alias for ``GCGGenerator``), a +``pyrit.executor.promptgen.core.PromptGeneratorStrategy`` that produces adversarial suffixes via the GCG algorithm. Example: