diff --git a/pyrit/auxiliary_attacks/gcg/__init__.py b/pyrit/auxiliary_attacks/gcg/__init__.py index a10d862fe3..160b2f313a 100644 --- a/pyrit/auxiliary_attacks/gcg/__init__.py +++ b/pyrit/auxiliary_attacks/gcg/__init__.py @@ -47,18 +47,28 @@ # mechanism so all GCG public symbols share one re-export pathway. _LAZY_IMPORTS = { "CandidateFilter": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "CandidateFilter"), + "CrossEntropyLoss": ("pyrit.auxiliary_attacks.gcg.default_implementations", "CrossEntropyLoss"), "GCG": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"), "GCGContext": ("pyrit.auxiliary_attacks.gcg.generator", "GCGContext"), "GCGGenerator": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"), "GCGResult": ("pyrit.auxiliary_attacks.gcg.generator", "GCGResult"), + "LengthPreservingFilter": ("pyrit.auxiliary_attacks.gcg.default_implementations", "LengthPreservingFilter"), + "LiteralStringInit": ("pyrit.auxiliary_attacks.gcg.default_implementations", "LiteralStringInit"), "LossFunction": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "LossFunction"), "SamplingStrategy": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SamplingStrategy"), + "StandardGCGSampling": ("pyrit.auxiliary_attacks.gcg.default_implementations", "StandardGCGSampling"), "SuffixInitializer": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SuffixInitializer"), "load_goals_and_targets": ("pyrit.auxiliary_attacks.gcg.data", "load_goals_and_targets"), } if TYPE_CHECKING: from pyrit.auxiliary_attacks.gcg.data import load_goals_and_targets + from pyrit.auxiliary_attacks.gcg.default_implementations import ( + CrossEntropyLoss, + LengthPreservingFilter, + LiteralStringInit, + StandardGCGSampling, + ) from pyrit.auxiliary_attacks.gcg.extension_protocols import ( CandidateFilter, LossFunction, @@ -91,6 +101,7 @@ def __dir__() -> list[str]: __all__ = [ "CandidateFilter", + "CrossEntropyLoss", "GCG", "GCGAlgorithmConfig", "GCGConfig", @@ -101,8 +112,11 @@ def __dir__() 
logger = logging.getLogger(__name__)


class StandardGCGSampling:
    """Sample GCG candidates by swapping one token per row for a top-k pick.

    For each of the ``batch_size`` rows one control position is chosen
    (positions are spread evenly over the suffix), and the token at that
    position is replaced by a uniformly drawn token id from the ``top_k``
    entries with the smallest gradient at that position. ``temperature`` is
    accepted for protocol compatibility but never read.

    Intended to reproduce ``GCGPromptManager.sample_control`` from
    ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py``.
    """

    def sample_candidates(
        self,
        *,
        gradient: torch.Tensor,
        control_tokens: torch.Tensor,
        batch_size: int,
        top_k: int,
        temperature: float,
        allow_non_ascii: bool,
        non_ascii_tokens: torch.Tensor,
    ) -> torch.Tensor:
        """Sample ``batch_size`` candidate suffix token sequences.

        Args:
            gradient (torch.Tensor): Gradient over the control tokens, indexed
                as ``(control_length, vocab_size)``. Mutated in-place when
                ``allow_non_ascii`` is False: the ``non_ascii_tokens``
                columns are set to ``+inf``.
            control_tokens (torch.Tensor): Current suffix token ids, indexed
                as ``(control_length,)``. Not mutated.
            batch_size (int): Number of candidate rows to return.
            top_k (int): Size of the per-position pool of replacement tokens.
            temperature (float): Unused; kept to match the protocol signature.
            allow_non_ascii (bool): When False, the ``non_ascii_tokens``
                columns are masked so they cannot enter the top-k pool.
            non_ascii_tokens (torch.Tensor): Token ids to mask when
                ``allow_non_ascii`` is False.

        Returns:
            torch.Tensor: Candidates with shape ``(batch_size, control_length)``
            on the same device as ``gradient``.
        """
        device = gradient.device
        if not allow_non_ascii:
            # +inf gradient becomes -inf after negation, so these ids never
            # make it into the top-k pool below.
            gradient[:, non_ascii_tokens.to(device)] = np.inf
        top_indices = (-gradient).topk(top_k, dim=1).indices

        control_tokens = control_tokens.to(device)
        control_length = len(control_tokens)
        # The float step can, through rounding, make arange emit one element
        # more than batch_size, which would make scatter_ fail on a shape
        # mismatch. Slicing is a no-op whenever exactly batch_size positions
        # were produced, so results are unchanged in that case.
        new_token_pos = torch.arange(
            0,
            control_length,
            control_length / batch_size,
            device=device,
        ).type(torch.int64)[:batch_size]

        # One uniform draw per row selects which of the top-k ids to use.
        pool_choice = torch.randint(0, top_k, (batch_size, 1), device=device)
        new_token_val = torch.gather(top_indices[new_token_pos], 1, pool_choice)

        candidates = control_tokens.repeat(batch_size, 1)
        return candidates.scatter_(1, new_token_pos.unsqueeze(-1), new_token_val)


class CrossEntropyLoss:
    """Weighted token-level cross-entropy on the target and control slices.

    Per candidate the loss is ``target_weight * CE(target_slice) +
    control_weight * CE(control_slice)``; each cross-entropy term is averaged
    over its slice with ``.mean(dim=-1)`` so that one scalar per candidate is
    returned. A term whose weight is 0 is not computed at all.

    Intended to reproduce ``AttackPrompt.target_loss`` and
    ``AttackPrompt.control_loss`` from
    ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``, combined as
    in ``GCGMultiPromptAttack.step``.
    """

    def __init__(self, *, target_weight: float = 1.0, control_weight: float = 0.0) -> None:
        """Initialize the cross-entropy loss with target / control weights.

        Args:
            target_weight (float): Weight on the target-slice cross-entropy.
                Defaults to 1.0.
            control_weight (float): Weight on the control-slice cross-entropy.
                Defaults to 0.0 (target-only signal).

        Raises:
            ValueError: If either weight is negative or NaN, or if both are
                zero.
        """
        # Written as "not (>= 0)" rather than "< 0" so that NaN, for which
        # every comparison is False, is rejected here instead of slipping
        # through and leaving compute_loss with no term to return.
        if not (target_weight >= 0 and control_weight >= 0):
            raise ValueError(
                "CrossEntropyLoss target_weight and control_weight must be >= 0, "
                f"got target_weight={target_weight}, control_weight={control_weight}."
            )
        if target_weight == 0 and control_weight == 0:
            raise ValueError(
                "CrossEntropyLoss requires at least one of target_weight or control_weight to be > 0; "
                "with both at 0 the loss is identically zero and provides no signal."
            )
        self._target_weight = target_weight
        self._control_weight = control_weight

    @staticmethod
    def _mean_slice_loss(
        criterion: nn.Module,
        logits: torch.Tensor,
        token_ids: torch.Tensor,
        token_slice: slice,
    ) -> torch.Tensor:
        """Return the per-candidate mean cross-entropy over ``token_slice``."""
        # Logits at position i are scored against the token at position i + 1,
        # hence the slice into the logits is shifted left by one.
        shifted = slice(token_slice.start - 1, token_slice.stop - 1)
        return criterion(
            logits[:, shifted, :].transpose(1, 2),
            token_ids[:, token_slice],
        ).mean(dim=-1)

    def compute_loss(
        self,
        *,
        logits: torch.Tensor,
        token_ids: torch.Tensor,
        target_slice: slice,
        control_slice: slice,
    ) -> torch.Tensor:
        """Compute the per-candidate weighted cross-entropy loss.

        Args:
            logits (torch.Tensor): Model logits for the candidate batch,
                indexed as ``(batch_size, seq_len, vocab_size)``.
            token_ids (torch.Tensor): Token ids the model was run on, indexed
                as ``(batch_size, seq_len)``.
            target_slice (slice): Sequence-dimension slice of the target
                tokens. ``start`` and ``stop`` must be integers.
            control_slice (slice): Sequence-dimension slice of the control
                (suffix) tokens. ``start`` and ``stop`` must be integers.

        Returns:
            torch.Tensor: Per-candidate loss with shape ``(batch_size,)``.

        Raises:
            RuntimeError: If neither weight is positive, which the
                constructor is meant to make impossible.
        """
        criterion = nn.CrossEntropyLoss(reduction="none")
        total = None

        if self._target_weight > 0:
            total = self._target_weight * self._mean_slice_loss(criterion, logits, token_ids, target_slice)

        if self._control_weight > 0:
            weighted_control = self._control_weight * self._mean_slice_loss(
                criterion, logits, token_ids, control_slice
            )
            total = weighted_control if total is None else total + weighted_control

        # Unreachable for instances built through __init__; kept so the return
        # type is never None.
        if total is None:
            raise RuntimeError(
                "CrossEntropyLoss.compute_loss produced no terms; "
                "this indicates a corrupted instance with both weights at 0."
            )
        return total


class LengthPreservingFilter:
    """Decode candidate rows and drop those that are unchanged or change length.

    With ``filter=True`` a candidate is dropped when its decoded string either
    equals ``current_control`` or re-tokenizes to a different number of tokens
    than the row had; dropped rows are replaced by repeating the last accepted
    candidate so the result keeps one entry per input row. With
    ``filter=False`` every row is decoded and returned as-is.

    Before decoding, token ids greater than ``tokenizer.vocab_size`` are
    overwritten in-place with the first id of ``tokenizer("!")``.

    Intended to reproduce ``MultiPromptAttack.get_filtered_cands`` from
    ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``, except that
    a batch in which no candidate survives no longer raises ``IndexError``.
    """

    def __init__(self, *, filter: bool = True) -> None:
        """Initialize the filter.

        Args:
            filter (bool): When True, drop candidates that equal
                ``current_control`` or re-tokenize to a different length and
                pad the result back to the batch size. When False, decode
                every row and return them all. Defaults to True.
        """
        self._filter = filter

    def filter_candidates(
        self,
        *,
        candidate_tokens: torch.Tensor,
        tokenizer: Any,
        current_control: str,
    ) -> list[str]:
        """Decode and filter a batch of candidate suffix token tensors.

        Args:
            candidate_tokens (torch.Tensor): Candidate suffixes, indexed as
                ``(batch_size, control_length)``. Mutated in-place by the
                out-of-vocab clamp.
            tokenizer (Any): Object providing ``vocab_size``,
                ``decode(row, skip_special_tokens=..., clean_up_tokenization_spaces=...)``
                and ``tokenizer(text, add_special_tokens=False).input_ids``.
            current_control (str): Current suffix string. When filtering,
                candidates that decode to this string are dropped; it is also
                the padding value when no candidate survives.

        Returns:
            list[str]: Decoded suffix strings, exactly
            ``candidate_tokens.shape[0]`` of them.
        """
        logger.info("Masking out of range token_id.")
        vocab_size = tokenizer.vocab_size
        candidate_tokens[candidate_tokens > vocab_size] = tokenizer("!").input_ids[0]

        batch_size = candidate_tokens.shape[0]
        accepted: list[str] = []
        for row in candidate_tokens:
            decoded_str = tokenizer.decode(row, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            if self._filter:
                if decoded_str == current_control:
                    continue
                retokenized = tokenizer(decoded_str, add_special_tokens=False).input_ids
                if len(retokenized) != len(row):
                    continue
            accepted.append(decoded_str)

        missing = batch_size - len(accepted)
        if missing == 0:
            # Covers passthrough mode, a fully accepted batch and an empty batch.
            return accepted

        if accepted:
            filler = accepted[-1]
        else:
            # Previously ``candidates[-1]`` raised a bare IndexError here.
            # Fall back to the unchanged suffix so the caller still receives
            # batch_size strings.
            logger.warning(
                "All %d candidates were filtered out; padding with the current control.",
                batch_size,
            )
            filler = current_control
        return accepted + [filler] * missing


class LiteralStringInit:
    """Return a fixed, pre-configured suffix string; the tokenizer is ignored.

    Wraps the literal ``control_init`` string behind the suffix-initializer
    call shape so that initializers which do need the tokenizer can be
    substituted without changing call sites.

    Intended to reproduce the ``self.control = control_init`` assignment in
    ``AttackPrompt.__init__`` in
    ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``.
    """

    def __init__(self, *, suffix: str) -> None:
        """Initialize the literal-string suffix initializer.

        Args:
            suffix (str): The suffix returned by every call to
                ``make_initial_suffix``. Must be non-empty.

        Raises:
            ValueError: If ``suffix`` is empty.
        """
        if not suffix:
            raise ValueError("LiteralStringInit.suffix must be a non-empty string.")
        self._suffix = suffix

    def make_initial_suffix(self, *, tokenizer: Any) -> str:
        """Return the configured suffix string.

        Args:
            tokenizer (Any): Ignored; present only to match the protocol
                signature.

        Returns:
            str: The suffix supplied at construction.
        """
        return self._suffix


__all__ = [
    "CrossEntropyLoss",
    "LengthPreservingFilter",
    "LiteralStringInit",
    "StandardGCGSampling",
]
# Legacy classes the defaults are compared against, pulled from the modules
# that were imported (or skipped) via ``pytest.importorskip`` above.
AttackPrompt = attack_manager_mod.AttackPrompt
MultiPromptAttack = attack_manager_mod.MultiPromptAttack
GCGPromptManager = gcg_attack_mod.GCGPromptManager


# Names of the four default implementations expected at the package root.
DEFAULT_NAMES = (
    "CrossEntropyLoss",
    "LengthPreservingFilter",
    "LiteralStringInit",
    "StandardGCGSampling",
)


class TestPackageReExports:
    """Verify the four default classes are re-exported from the package root."""

    @pytest.mark.parametrize("name", DEFAULT_NAMES)
    def test_default_is_reexported_with_identity(self, name: str) -> None:
        # Identity (not equality): the package attribute must be the very
        # object defined in default_implementations.
        package_attr = getattr(gcg_pkg, name)
        module_attr = getattr(defaults_module, name)
        assert package_attr is module_attr, (
            f"{name} re-exported from pyrit.auxiliary_attacks.gcg must be the same "
            f"object as pyrit.auxiliary_attacks.gcg.default_implementations.{name}"
        )

    @pytest.mark.parametrize("name", DEFAULT_NAMES)
    def test_default_in_package_dunder_all(self, name: str) -> None:
        assert name in gcg_pkg.__all__


class TestStandardGCGSampling:
    """Parity: ``StandardGCGSampling`` vs ``GCGPromptManager.sample_control``."""

    def _make_legacy_prompt_manager(
        self,
        *,
        control_tokens: torch.Tensor,
        non_ascii_tokens: torch.Tensor,
    ) -> GCGPromptManager:
        """Build a ``GCGPromptManager`` without running its ``__init__``."""
        # Mirrors the construction pattern used by TestSampleControl in
        # test_gcg_core.py: skip __init__ and seed just the attributes that
        # sample_control reads.
        prompt_manager = object.__new__(GCGPromptManager)
        prompt_manager._nonascii_toks = non_ascii_tokens
        prompt_manager._prompts = [MagicMock()]
        prompt_manager._prompts[0].control_toks = control_tokens.clone()
        return prompt_manager

    def test_sample_candidates_matches_legacy_with_ascii_only(self) -> None:
        """Legacy reference: ``GCGPromptManager.sample_control(grad, batch_size,
        topk=top_k, temp=1.0, allow_non_ascii=False)`` in
        ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py``.
        """
        n_control_tokens = 5
        vocab_size = 50
        batch_size = 4
        top_k = 8

        torch.manual_seed(2026)
        gradient_template = torch.randn(n_control_tokens, vocab_size)
        control_tokens = torch.randint(0, vocab_size, (n_control_tokens,))
        non_ascii_tokens = torch.tensor([2, 7, 13])

        # Legacy path
        prompt_manager = self._make_legacy_prompt_manager(
            control_tokens=control_tokens, non_ascii_tokens=non_ascii_tokens
        )
        # Re-seed immediately before each sampled call so both paths draw
        # from the same RNG state; each gets its own gradient clone.
        torch.manual_seed(12345)
        legacy_out = prompt_manager.sample_control(
            gradient_template.clone(),
            batch_size,
            topk=top_k,
            temp=1.0,
            allow_non_ascii=False,
        )

        # Default path
        default = StandardGCGSampling()
        torch.manual_seed(12345)
        default_out = default.sample_candidates(
            gradient=gradient_template.clone(),
            control_tokens=control_tokens.clone(),
            batch_size=batch_size,
            top_k=top_k,
            temperature=1.0,
            allow_non_ascii=False,
            non_ascii_tokens=non_ascii_tokens,
        )

        assert torch.equal(default_out, legacy_out)

    def test_sample_candidates_matches_legacy_with_non_ascii_allowed(self) -> None:
        """Legacy reference: same as above but with ``allow_non_ascii=True``
        (the no-mask branch where the gradient is not mutated).
        """
        n_control_tokens = 6
        vocab_size = 40
        batch_size = 5
        top_k = 10

        torch.manual_seed(2027)
        gradient_template = torch.randn(n_control_tokens, vocab_size)
        control_tokens = torch.randint(0, vocab_size, (n_control_tokens,))
        non_ascii_tokens = torch.tensor([1, 4])

        prompt_manager = self._make_legacy_prompt_manager(
            control_tokens=control_tokens, non_ascii_tokens=non_ascii_tokens
        )
        # Same seed before both calls so the random draws line up.
        torch.manual_seed(54321)
        legacy_out = prompt_manager.sample_control(
            gradient_template.clone(),
            batch_size,
            topk=top_k,
            temp=1.0,
            allow_non_ascii=True,
        )

        default = StandardGCGSampling()
        torch.manual_seed(54321)
        default_out = default.sample_candidates(
            gradient=gradient_template.clone(),
            control_tokens=control_tokens.clone(),
            batch_size=batch_size,
            top_k=top_k,
            temperature=1.0,
            allow_non_ascii=True,
            non_ascii_tokens=non_ascii_tokens,
        )

        assert torch.equal(default_out, legacy_out)


class TestCrossEntropyLoss:
    """Parity: ``CrossEntropyLoss`` vs ``AttackPrompt.target_loss`` +
    ``AttackPrompt.control_loss``.
    """

    def _make_legacy_prompt(
        self,
        *,
        target_slice: slice,
        control_slice: slice,
    ) -> AttackPrompt:
        """Build an ``AttackPrompt`` without running its ``__init__``."""
        # Mirrors TestTargetAndControlLoss in test_gcg_core.py: skip
        # __init__ and seed only the slice attributes that the loss methods
        # consult.
        prompt = object.__new__(AttackPrompt)
        prompt._target_slice = target_slice
        prompt._control_slice = control_slice
        return prompt

    def test_compute_loss_matches_legacy_weighted_sum(self) -> None:
        """Legacy reference:
        ``target_weight * AttackPrompt.target_loss(logits, ids).mean(dim=-1)``
        ``+ control_weight * AttackPrompt.control_loss(logits, ids).mean(dim=-1)``,
        per ``GCGMultiPromptAttack.step`` in
        ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py``.
        """
        batch_size = 4
        seq_len = 10
        vocab_size = 30
        target_slice = slice(5, 8)
        control_slice = slice(2, 5)
        target_weight = 1.0
        control_weight = 0.1

        torch.manual_seed(99)
        logits = torch.randn(batch_size, seq_len, vocab_size)
        token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

        prompt = self._make_legacy_prompt(target_slice=target_slice, control_slice=control_slice)
        legacy_target = prompt.target_loss(logits, token_ids).mean(dim=-1)
        legacy_control = prompt.control_loss(logits, token_ids).mean(dim=-1)
        legacy_total = target_weight * legacy_target + control_weight * legacy_control

        default = CrossEntropyLoss(target_weight=target_weight, control_weight=control_weight)
        default_total = default.compute_loss(
            logits=logits,
            token_ids=token_ids,
            target_slice=target_slice,
            control_slice=control_slice,
        )

        # Exact equality, not allclose: the default is expected to perform
        # the same floating-point operations in the same order.
        assert torch.equal(default_total, legacy_total)

    def test_compute_loss_target_only_matches_legacy_target_loss(self) -> None:
        """With ``control_weight=0`` the legacy ``step`` skips the control
        term (its ``if control_weight != 0:`` guard). The default must
        produce the same per-candidate value as
        ``target_weight * target_loss(...).mean(dim=-1)`` alone.
        """
        target_slice = slice(4, 7)
        control_slice = slice(1, 4)

        torch.manual_seed(7)
        logits = torch.randn(3, 9, 25)
        token_ids = torch.randint(0, 25, (3, 9))

        prompt = self._make_legacy_prompt(target_slice=target_slice, control_slice=control_slice)
        legacy_total = 1.0 * prompt.target_loss(logits, token_ids).mean(dim=-1)

        default = CrossEntropyLoss(target_weight=1.0, control_weight=0.0)
        default_total = default.compute_loss(
            logits=logits,
            token_ids=token_ids,
            target_slice=target_slice,
            control_slice=control_slice,
        )

        assert torch.equal(default_total, legacy_total)

    def test_compute_loss_control_only_matches_legacy_control_loss(self) -> None:
        """With ``target_weight=0`` the default must produce the same value
        as ``control_weight * control_loss(...).mean(dim=-1)`` alone.
        """
        target_slice = slice(4, 7)
        control_slice = slice(1, 4)

        torch.manual_seed(13)
        logits = torch.randn(3, 9, 25)
        token_ids = torch.randint(0, 25, (3, 9))

        prompt = self._make_legacy_prompt(target_slice=target_slice, control_slice=control_slice)
        legacy_total = 0.5 * prompt.control_loss(logits, token_ids).mean(dim=-1)

        default = CrossEntropyLoss(target_weight=0.0, control_weight=0.5)
        default_total = default.compute_loss(
            logits=logits,
            token_ids=token_ids,
            target_slice=target_slice,
            control_slice=control_slice,
        )

        assert torch.equal(default_total, legacy_total)

    def test_init_rejects_both_weights_zero(self) -> None:
        with pytest.raises(ValueError, match="at least one"):
            CrossEntropyLoss(target_weight=0.0, control_weight=0.0)

    def test_init_rejects_negative_target_weight(self) -> None:
        with pytest.raises(ValueError, match=">= 0"):
            CrossEntropyLoss(target_weight=-0.5, control_weight=1.0)

    def test_init_rejects_negative_control_weight(self) -> None:
        with pytest.raises(ValueError, match=">= 0"):
            CrossEntropyLoss(target_weight=1.0, control_weight=-0.5)

    def test_compute_loss_returns_batch_sized_tensor(self) -> None:
        batch_size = 4
        logits = torch.randn(batch_size, 10, 20)
        token_ids = torch.randint(0, 20, (batch_size, 10))

        default = CrossEntropyLoss(target_weight=1.0, control_weight=0.1)
        out = default.compute_loss(
            logits=logits,
            token_ids=token_ids,
            target_slice=slice(5, 8),
            control_slice=slice(2, 5),
        )

        assert out.shape == (batch_size,)


def _make_filter_tokenizer() -> MagicMock:
    """Build a fresh, deterministic, stateless mock tokenizer for filter tests.

    Behavior:
    - ``decode(tensor)`` -> ``"x" * int(tensor[0].item())`` — string length
      is keyed off the first token id, so each row maps to a distinct
      predictable string.
    - ``tokenizer(text, ...).input_ids`` has length ``len(text)`` — so the
      retokenized length check is fully predictable from the decoded
      string.
    - ``tokenizer("!").input_ids[0] == 0`` — provides the clamp
      replacement id.
    - ``vocab_size == 100``.
    """
    tokenizer = MagicMock()
    tokenizer.vocab_size = 100

    def decode_fn(ids, **_kwargs):
        return "x" * int(ids[0].item())

    tokenizer.decode.side_effect = decode_fn

    def call_tokenizer(text, **_kwargs):
        result = MagicMock()
        if text == "!":
            result.input_ids = [0]
        else:
            result.input_ids = list(range(len(text)))
        return result

    tokenizer.side_effect = call_tokenizer
    return tokenizer


class TestLengthPreservingFilter:
    """Parity: ``LengthPreservingFilter`` vs
    ``MultiPromptAttack.get_filtered_cands``.
    """

    def _make_legacy_attack(self, *, tokenizer: MagicMock) -> MultiPromptAttack:
        """Build a ``MultiPromptAttack`` without running its ``__init__``."""
        # Mirrors TestGetFilteredCands in test_gcg_core.py: skip __init__
        # and only attach the workers list that get_filtered_cands reads.
        attack = object.__new__(MultiPromptAttack)
        worker = MagicMock()
        worker.tokenizer = tokenizer
        attack.workers = [worker]
        return attack

    def test_filter_candidates_matches_legacy_filtered(self) -> None:
        """Legacy reference:
        ``MultiPromptAttack.get_filtered_cands(0, control_cand,
        filter_cand=True, curr_control=...)`` in
        ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``.

        With the helper tokenizer:
        - Row 0 ``[3, 0, 1]`` -> decode ``"xxx"`` (len 3); retok len 3 ==
          control_length 3 -> KEEP.
        - Row 1 ``[5, 0, 0]`` -> decode ``"xxxxx"`` (len 5); retok len 5
          != 3 -> DROP.
        - Row 2 ``[2, 0, 1]`` -> decode ``"xx"`` (len 2); retok len 2 !=
          3 -> DROP.
        Pad-with-last gives ``["xxx", "xxx", "xxx"]``.
        """
        candidate_template = torch.tensor([[3, 0, 1], [5, 0, 0], [2, 0, 1]])

        # Each path gets its own tokenizer and its own clone of the
        # candidates because both mutate the tensor in-place.
        legacy_attack = self._make_legacy_attack(tokenizer=_make_filter_tokenizer())
        legacy_out = legacy_attack.get_filtered_cands(
            0, candidate_template.clone(), filter_cand=True, curr_control="never_matches"
        )

        default = LengthPreservingFilter(filter=True)
        default_out = default.filter_candidates(
            candidate_tokens=candidate_template.clone(),
            tokenizer=_make_filter_tokenizer(),
            current_control="never_matches",
        )

        assert default_out == legacy_out
        assert legacy_out == ["xxx", "xxx", "xxx"]

    def test_filter_candidates_matches_legacy_unfiltered(self) -> None:
        """Legacy reference: ``get_filtered_cands(0, control_cand,
        filter_cand=False)``. Every row is decoded and returned unchanged.
        """
        candidate_template = torch.tensor([[3, 0, 1], [5, 0, 0], [2, 0, 1]])

        legacy_attack = self._make_legacy_attack(tokenizer=_make_filter_tokenizer())
        legacy_out = legacy_attack.get_filtered_cands(0, candidate_template.clone(), filter_cand=False)

        default = LengthPreservingFilter(filter=False)
        default_out = default.filter_candidates(
            candidate_tokens=candidate_template.clone(),
            tokenizer=_make_filter_tokenizer(),
            current_control="ignored_when_filter_false",
        )

        assert default_out == legacy_out
        assert legacy_out == ["xxx", "xxxxx", "xx"]

    def test_filter_candidates_clamps_out_of_vocab_tokens(self) -> None:
        """Both code paths apply the legacy vocab-clamp in-place: tokens
        above ``vocab_size`` are replaced by the id of ``"!"`` before any
        decoding happens.
        """
        candidate_template = torch.tensor([[150, 0, 1], [3, 0, 1]])  # 150 > vocab_size=100

        legacy_input = candidate_template.clone()
        legacy_attack = self._make_legacy_attack(tokenizer=_make_filter_tokenizer())
        legacy_attack.get_filtered_cands(0, legacy_input, filter_cand=False)

        default_input = candidate_template.clone()
        default = LengthPreservingFilter(filter=False)
        default.filter_candidates(
            candidate_tokens=default_input,
            tokenizer=_make_filter_tokenizer(),
            current_control="",
        )

        # The return values are ignored here; the assertion is on the
        # in-place mutation of the input tensors.
        assert torch.equal(default_input, legacy_input)
        assert default_input[0, 0].item() == 0


class TestLiteralStringInit:
    """Parity: ``LiteralStringInit`` vs the literal-string ``control_init``
    assignment inside ``AttackPrompt.__init__`` (``self.control =
    control_init``).
    """

    def test_make_initial_suffix_returns_default_control_init(self) -> None:
        """Legacy reference: ``GCGAlgorithmConfig.control_init`` (default
        ``_DEFAULT_CONTROL_INIT``) is assigned to ``self.control`` in
        ``AttackPrompt.__init__``.
        """
        default_suffix = GCGAlgorithmConfig().control_init
        initializer = LiteralStringInit(suffix=default_suffix)
        assert initializer.make_initial_suffix(tokenizer=MagicMock()) == default_suffix

    def test_make_initial_suffix_ignores_tokenizer(self) -> None:
        suffix = "custom suffix string"
        initializer = LiteralStringInit(suffix=suffix)
        assert initializer.make_initial_suffix(tokenizer=None) == suffix

    def test_init_rejects_empty_suffix(self) -> None:
        with pytest.raises(ValueError, match="non-empty"):
            LiteralStringInit(suffix="")