diff --git a/pyrit/auxiliary_attacks/gcg/__init__.py b/pyrit/auxiliary_attacks/gcg/__init__.py index bff04bc1ec..a10d862fe3 100644 --- a/pyrit/auxiliary_attacks/gcg/__init__.py +++ b/pyrit/auxiliary_attacks/gcg/__init__.py @@ -3,8 +3,8 @@ """Public API for the Greedy Coordinate Gradient (GCG) auxiliary attack. -The primary entry point is :class:`GCG` (alias for :class:`GCGGenerator`), a -:class:`pyrit.executor.promptgen.core.PromptGeneratorStrategy` that produces +The primary entry point is ``GCG`` (alias for ``GCGGenerator``), a +``pyrit.executor.promptgen.core.PromptGeneratorStrategy`` that produces adversarial suffixes via the GCG algorithm. Example: @@ -41,16 +41,30 @@ # only have the base `dev` extra (no torch). Touching any of these names from # the package root triggers the underlying module import on first access; if # torch is missing the user gets a clear ModuleNotFoundError pointing at torch. +# +# The extension Protocols live in ``extension_protocols`` (typing-only — that +# module imports cleanly without torch) but are routed through the same lazy +# mechanism so all GCG public symbols share one re-export pathway. _LAZY_IMPORTS = { + "CandidateFilter": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "CandidateFilter"), "GCG": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"), "GCGContext": ("pyrit.auxiliary_attacks.gcg.generator", "GCGContext"), "GCGGenerator": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"), "GCGResult": ("pyrit.auxiliary_attacks.gcg.generator", "GCGResult"), + "LossFunction": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "LossFunction"), + "SamplingStrategy": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SamplingStrategy"), + "SuffixInitializer": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SuffixInitializer"), "load_goals_and_targets": ("pyrit.auxiliary_attacks.gcg.data", "load_goals_and_targets"), } if TYPE_CHECKING: from pyrit.auxiliary_attacks.gcg.data import load_goals_and_targets + from pyrit.auxiliary_attacks.gcg.extension_protocols import ( + CandidateFilter, + LossFunction, + SamplingStrategy, + SuffixInitializer, + ) from pyrit.auxiliary_attacks.gcg.generator import ( GCGContext, GCGGenerator, @@ -76,6 +90,7 @@ def __dir__() -> list[str]: __all__ = [ + "CandidateFilter", "GCG", "GCGAlgorithmConfig", "GCGConfig", @@ -86,5 +101,8 @@ def __dir__() -> list[str]: "GCGOutputConfig", "GCGResult", "GCGStrategyConfig", + "LossFunction", + "SamplingStrategy", + "SuffixInitializer", "load_goals_and_targets", ] diff --git a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py index 6e7991ea30..787154fe75 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py +++ b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py @@ -654,7 +654,7 @@ def run( n_steps: int = 100, batch_size: int = 1024, topk: int = 256, - temp: int = 1, + temp: float = 1.0, allow_non_ascii: bool = True, target_weight: Optional[float] = None, control_weight: Optional[float] = None, diff --git a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py index 3fb5a8aa46..4df1ae9205 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py +++ b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py @@ -92,7 +92,7 @@ def sample_control( grad: torch.Tensor, batch_size: int, topk: int = 256, - temp: int = 1, + temp: float = 1.0, allow_non_ascii: bool = True, ) -> torch.Tensor: """ @@ -102,7 +102,7 @@ def sample_control( grad (torch.Tensor): Gradient tensor for control tokens. batch_size (int): Number of candidate controls to generate. topk (int): Number of top gradient positions to sample from. Defaults to 256. - temp (int): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1. + temp (float): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.0. allow_non_ascii (bool): Whether to allow non-ASCII tokens. Defaults to True. Returns: @@ -130,7 +130,7 @@ def step( *, batch_size: int = 1024, topk: int = 256, - temp: int = 1, + temp: float = 1.0, allow_non_ascii: bool = True, target_weight: float = 1, control_weight: float = 0.1, @@ -146,7 +146,7 @@ def step( Args: batch_size (int): Number of candidate controls per batch. Defaults to 1024. topk (int): Number of top gradient positions to sample from. Defaults to 256. - temp (int): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1. + temp (float): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.0. allow_non_ascii (bool): Whether to allow non-ASCII tokens. Defaults to True. target_weight (float): Weight for target loss. Defaults to 1. control_weight (float): Weight for control loss. Defaults to 0.1. diff --git a/pyrit/auxiliary_attacks/gcg/config.py b/pyrit/auxiliary_attacks/gcg/config.py index cd5ee405e7..097a9087af 100644 --- a/pyrit/auxiliary_attacks/gcg/config.py +++ b/pyrit/auxiliary_attacks/gcg/config.py @@ -71,7 +71,7 @@ class GCGDataConfig: Used as a typed bundle for AML transport (a job ships its data config as a separate JSON file alongside the strategy ``GCGConfig``). Library callers loading goals/targets from a CSV can construct one and pass it to - :func:`pyrit.auxiliary_attacks.gcg.data.load_goals_and_targets`. + ``pyrit.auxiliary_attacks.gcg.data.load_goals_and_targets``. Attributes: train_data (str): URL or filesystem path to the training-data CSV. Empty @@ -100,7 +100,7 @@ def to_json(self) -> str: @classmethod def from_json(cls, payload: str) -> GCGDataConfig: - """Deserialize a config previously produced by :meth:`to_json`.""" + """Deserialize a config previously produced by ``to_json``.""" try: data = json.loads(payload) except json.JSONDecodeError as e: @@ -131,8 +131,8 @@ class GCGAlgorithmConfig: Defaults to 512. topk (int): Top-k gradient positions considered for substitution. Defaults to 256. - temp (int): Sampling temperature placeholder; the current sampling - implementation samples uniformly from the top-k. Defaults to 1. + temp (float): Sampling temperature placeholder; the current sampling + implementation samples uniformly from the top-k. Defaults to 1.0. target_weight (float): Weight on the target-string cross-entropy loss. Defaults to 1.0. control_weight (float): Weight on the control-string cross-entropy loss. @@ -153,7 +153,7 @@ class GCGAlgorithmConfig: test_steps: int = 50 batch_size: int = 512 topk: int = 256 - temp: int = 1 + temp: float = 1.0 target_weight: float = 1.0 control_weight: float = 0.0 learning_rate: float = 0.01 @@ -240,10 +240,10 @@ class GCGOutputConfig: class GCGConfig: """Top-level strategy configuration for one GCG attack run. - Bundles everything :class:`pyrit.auxiliary_attacks.gcg.GCGGenerator`'s + Bundles everything ``pyrit.auxiliary_attacks.gcg.GCGGenerator``'s constructor needs. Per-execution data (goals, targets) is **not** here — those flow through ``GCGGenerator.execute_async``, and for AML transport - they ride alongside this object as a separate :class:`GCGDataConfig` JSON. + they ride alongside this object as a separate ``GCGDataConfig`` JSON. Attributes: models (list[GCGModelConfig]): Training models the attack optimizes @@ -287,11 +287,11 @@ def to_json(self) -> str: @classmethod def from_json(cls, payload: str) -> GCGConfig: - """Deserialize a config previously produced by :meth:`to_json`. + """Deserialize a config previously produced by ``to_json``. Args: payload (str): JSON document matching the shape produced by - :meth:`to_json`. + ``to_json``. Returns: GCGConfig: A new ``GCGConfig`` reconstructed from ``payload``. @@ -308,7 +308,7 @@ def from_json(cls, payload: str) -> GCGConfig: @classmethod def from_json_file(cls, path: str | Path) -> GCGConfig: - """Load a config from a JSON file produced by :meth:`to_json_file`. + """Load a config from a JSON file produced by ``to_json_file``. Args: path (str | Path): Filesystem path to a JSON config file. diff --git a/pyrit/auxiliary_attacks/gcg/extension_protocols.py b/pyrit/auxiliary_attacks/gcg/extension_protocols.py new file mode 100644 index 0000000000..f9f1a3013e --- /dev/null +++ b/pyrit/auxiliary_attacks/gcg/extension_protocols.py @@ -0,0 +1,284 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Typing-only extension points for the Greedy Coordinate Gradient (GCG) attack. + +This module defines four ``runtime_checkable`` ``Protocol``s that mark the four +algorithmic seams inside the GCG optimization loop where a future caller may +substitute custom behavior: + +- ``SamplingStrategy`` — how candidate suffix token sequences are drawn from + the per-step gradient. +- ``LossFunction`` — how each candidate suffix is scored against the target + string. +- ``CandidateFilter`` — how proposed candidate token tensors get pruned and + decoded into the string form consumed by the evaluation pass. +- ``SuffixInitializer`` — how the initial suffix string fed into the + optimization loop is constructed. + +The module is **typing surface only**. It ships no concrete implementations, +no defaults, and no wiring into ``GCGAlgorithmConfig`` or +``GCGMultiPromptAttack``. The default behaviors that match the current attack +code will land as concrete classes in a follow-up PR; the optional +``GCGAlgorithmConfig`` fields that select between defaults and custom +implementations will land in the PR after that. + +Tensor-typed signatures are kept lazy via ``from __future__ import +annotations`` plus a ``TYPE_CHECKING`` import for ``torch`` so that +``pyrit.auxiliary_attacks.gcg.extension_protocols`` itself imports cleanly on +installs that only have the base ``dev`` extra (no torch). At call time the +implementations are still operating on real ``torch.Tensor`` objects — the +forward references just keep the runtime import side-effect free. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + import torch + + +@runtime_checkable +class SamplingStrategy(Protocol): + """Proposes a batch of candidate suffix token sequences from the gradient. + + Invoked once per GCG optimization step, after the per-worker gradients have + been aggregated. The implementation receives the aggregated gradient over + the control (suffix) positions and the current control token sequence, and + returns a batch of candidate replacement sequences for the search pass to + evaluate. + + Implementations must preserve two invariants: + + - The returned tensor has shape ``(batch_size, control_length)`` where + ``control_length == control_tokens.shape[0]``. + - The returned tensor lives on the same device as ``gradient`` (the + orchestrator does not re-locate the result). + + The current GCG implementation (top-k by ``-grad``, uniform pick within the + top-k at one randomly-chosen position per row) lives in + ``GCGPromptManager.sample_control``. + + References: + ``GCGPromptManager.sample_control`` in + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py``. + """ + + def sample_candidates( + self, + *, + gradient: torch.Tensor, + control_tokens: torch.Tensor, + batch_size: int, + top_k: int, + temperature: float, + allow_non_ascii: bool, + non_ascii_tokens: torch.Tensor, + ) -> torch.Tensor: + """Sample ``batch_size`` candidate suffix token sequences. + + Args: + gradient (torch.Tensor): Aggregated gradient over the control + tokens with shape ``(control_length, vocab_size)`` and dtype + matching the model's embedding matrix. + control_tokens (torch.Tensor): The current suffix token sequence + with shape ``(control_length,)`` and integer dtype. + batch_size (int): Number of candidate suffix rows to return. + top_k (int): Number of top gradient positions per control slot + that the strategy is permitted to draw from. + temperature (float): Sampling temperature. The current default + sampling strategy samples uniformly within the top-k and does + not use this value; it is part of the protocol so custom + strategies that need it (for example, softmax weighting) can + receive it. + allow_non_ascii (bool): When False, the implementation must + ensure ``non_ascii_tokens`` are excluded from the candidate + vocabulary (typically by masking those positions of + ``gradient`` to ``+inf`` before top-k selection). + non_ascii_tokens (torch.Tensor): Token ids to exclude when + ``allow_non_ascii`` is False, shape ``(num_disallowed,)`` + and integer dtype. + + Returns: + torch.Tensor: Candidate suffix token sequences with shape + ``(batch_size, control_length)`` on the same device as + ``gradient``. + """ + ... + + +@runtime_checkable +class LossFunction(Protocol): + """Scores a batch of candidate suffixes against the target completion. + + Invoked once per worker per training prompt per candidate batch during the + search pass. The implementation receives the model's logits for the + candidate batch together with the input ids that produced them, plus the + slices that locate the target and control regions inside the sequence, and + returns a per-candidate scalar loss tensor. + + Owning the entire loss computation in one method (criterion choice, + slicing, and any weighted combination of target/control terms) keeps the + protocol orthogonal to the orchestrator: the caller does not need to know + whether the implementation uses cross-entropy or something else, nor how + target and control contributions are combined. The current GCG code path + weights cross-entropy on the target slice by ``target_weight`` and on the + control slice by ``control_weight`` and sums them; a custom + ``LossFunction`` would encapsulate equivalent knobs in its own + constructor. + + Implementations must preserve one invariant: + + - The returned tensor has shape ``(batch_size,)`` — one scalar loss per + candidate. Lower values indicate a better candidate (the orchestrator + selects the ``argmin``). + + References: + ``AttackPrompt.target_loss`` and ``AttackPrompt.control_loss`` in + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``, plus + the weighted-sum aggregation inside ``GCGMultiPromptAttack.step`` in + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py``. + """ + + def compute_loss( + self, + *, + logits: torch.Tensor, + token_ids: torch.Tensor, + target_slice: slice, + control_slice: slice, + ) -> torch.Tensor: + """Compute the per-candidate loss for a candidate batch. + + Args: + logits (torch.Tensor): Model logits for the candidate batch with + shape ``(batch_size, seq_len, vocab_size)``. + token_ids (torch.Tensor): Input token ids the model was run on + with shape ``(batch_size, seq_len)`` and integer dtype. + target_slice (slice): Slice into the sequence dimension that + identifies the target tokens (the completion being optimized + toward). + control_slice (slice): Slice into the sequence dimension that + identifies the control (suffix) tokens. + + Returns: + torch.Tensor: Per-candidate scalar loss with shape + ``(batch_size,)``. Lower is better. + """ + ... + + +@runtime_checkable +class CandidateFilter(Protocol): + """Decodes and prunes a batch of candidate suffix token tensors. + + Invoked once per worker per optimization step, immediately after sampling. + The implementation receives the raw sampled token tensor and the worker's + tokenizer, and is expected to: + + - Decode each row into its string form (this is what the evaluation pass + consumes — it sends the candidate strings back through the tokenizer + together with the goal prompt). + - Optionally drop candidates that fail some quality check. + - Return *exactly* ``batch_size`` strings (the orchestrator allocates a + flat loss buffer of size ``batch_size`` and does not tolerate ragged + outputs). The current implementation pads any dropped rows by repeating + the last accepted candidate. + + Implementations must preserve two invariants: + + - The returned list has length equal to ``candidate_tokens.shape[0]``. + - No element of the returned list equals ``current_control`` *unless* the + implementation explicitly allows the no-op candidate (the legacy + filter drops it on the assumption that re-evaluating the current + suffix wastes a slot). + + References: + ``MultiPromptAttack.get_filtered_cands`` in + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``. + """ + + def filter_candidates( + self, + *, + candidate_tokens: torch.Tensor, + tokenizer: Any, + current_control: str, + ) -> list[str]: + """Decode and filter a batch of candidate suffix token tensors. + + Args: + candidate_tokens (torch.Tensor): Sampled candidate suffixes with + shape ``(batch_size, control_length)`` and integer dtype. + tokenizer (Any): The worker's HuggingFace-style tokenizer. + ``tokenizer.decode`` is used to render each row to text and + ``tokenizer(text, add_special_tokens=False).input_ids`` is + used by the default length-preserving filter to detect + re-tokenization drift. + current_control (str): The current suffix string. Used by the + default filter to drop the no-op candidate. + + Returns: + list[str]: Decoded candidate suffix strings of length exactly + ``candidate_tokens.shape[0]``. Dropped rows are typically padded + by repeating the last accepted candidate. + """ + ... + + +@runtime_checkable +class SuffixInitializer(Protocol): + """Produces the initial suffix string fed into the optimization loop. + + Invoked once at attack-setup time. The returned string is threaded through + the existing ``control_init`` parameter of ``AttackPrompt`` / + ``PromptManager`` / ``MultiPromptAttack`` / the per-strategy ``*Attack`` + constructors, so a custom initializer is fully decoupled from the per-step + optimization machinery. + + The tokenizer is passed in case an implementation needs vocab access (for + example, a random initializer that draws tokens from the model's + vocabulary). The current default — a literal string supplied via + ``GCGAlgorithmConfig.control_init`` — ignores the tokenizer and returns + the configured string verbatim. + + Implementations must return a non-empty string. The downstream + ``AttackPrompt`` constructor raises ``ValueError`` if the suffix cannot be + located inside the chat-templated prompt, so the returned string must also + survive tokenizer round-tripping inside the goal+control message body + (the default twenty space-separated ``!`` tokens satisfies this for every + chat template PyRIT has been tested against). + + References: + ``AttackPrompt.__init__`` in + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py`` assigns + ``self.control = control_init``. The same ``control_init`` parameter + is threaded through the ``PromptManager``, ``MultiPromptAttack``, + ``ProgressiveMultiPromptAttack``, ``IndividualPromptAttack``, and + ``EvaluateAttack`` constructors in the same module. + """ + + def make_initial_suffix(self, *, tokenizer: Any) -> str: + """Return the initial suffix string for the optimization loop. + + Args: + tokenizer (Any): A HuggingFace-style tokenizer the implementation + may consult (for example, to sample tokens from the + vocabulary). The default literal-string initializer ignores + this argument. + + Returns: + str: The initial suffix string. Must be non-empty and must + survive tokenizer round-tripping inside the chat-templated + prompt body. + """ + ... + + +__all__ = [ + "CandidateFilter", + "LossFunction", + "SamplingStrategy", + "SuffixInitializer", +] diff --git a/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py b/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py new file mode 100644 index 0000000000..404ef4b541 --- /dev/null +++ b/tests/unit/auxiliary_attacks/gcg/test_extension_protocols.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for :mod:`pyrit.auxiliary_attacks.gcg.extension_protocols`. + +These are typing-surface tests: they verify the four ``Protocol``s are +exposed, are ``runtime_checkable``, and accept a minimal in-test concrete +implementation. They do not exercise any real GCG attack code; the default +implementations and the wiring of these protocols into ``GCGAlgorithmConfig`` / +``GCGMultiPromptAttack`` land in follow-up PRs. +""" + +from typing import Any + +import pytest + +# The protocol method signatures reference ``torch.Tensor``. The test bodies +# instantiate real tensors when calling the stub implementations, so the whole +# file is skipped on installs that only have the base ``dev`` extra. +torch = pytest.importorskip("torch", reason="GCG extension protocols reference torch.Tensor") + +import pyrit.auxiliary_attacks.gcg as gcg_pkg # noqa: E402 +from pyrit.auxiliary_attacks.gcg import ( # noqa: E402 + CandidateFilter, + LossFunction, + SamplingStrategy, + SuffixInitializer, +) +from pyrit.auxiliary_attacks.gcg import extension_protocols as protocols_module # noqa: E402 + +PROTOCOL_NAMES = ( + "CandidateFilter", + "LossFunction", + "SamplingStrategy", + "SuffixInitializer", +) + + +def test_module_exports_exactly_four_protocols() -> None: + assert set(protocols_module.__all__) == set(PROTOCOL_NAMES) + + +def test_protocols_are_reexported_from_package_with_identity() -> None: + for name in PROTOCOL_NAMES: + package_attr = getattr(gcg_pkg, name) + module_attr = getattr(protocols_module, name) + assert package_attr is module_attr, ( + f"{name} re-exported from pyrit.auxiliary_attacks.gcg must be the same " + f"object as pyrit.auxiliary_attacks.gcg.extension_protocols.{name}" + ) + + +def test_protocols_are_in_package_dunder_all() -> None: + for name in PROTOCOL_NAMES: + assert name in gcg_pkg.__all__ + + +@pytest.mark.parametrize("name", PROTOCOL_NAMES) +def test_protocols_are_runtime_checkable(name: str) -> None: + proto = getattr(protocols_module, name) + # ``runtime_checkable`` marks the protocol with ``_is_runtime_protocol = True``; + # ``isinstance(obj, Proto)`` is only legal on runtime-checkable protocols. + assert getattr(proto, "_is_runtime_protocol", False), ( + f"{name} must be decorated with @runtime_checkable so isinstance() checks work" + ) + + +class _StubSamplingStrategy: + def sample_candidates( + self, + *, + gradient: torch.Tensor, + control_tokens: torch.Tensor, + batch_size: int, + top_k: int, + temperature: float, + allow_non_ascii: bool, + non_ascii_tokens: torch.Tensor, + ) -> torch.Tensor: + return control_tokens.unsqueeze(0).repeat(batch_size, 1) + + +class _StubLossFunction: + def compute_loss( + self, + *, + logits: torch.Tensor, + token_ids: torch.Tensor, + target_slice: slice, + control_slice: slice, + ) -> torch.Tensor: + return torch.zeros(logits.shape[0]) + + +class _StubCandidateFilter: + def filter_candidates( + self, + *, + candidate_tokens: torch.Tensor, + tokenizer: Any, + current_control: str, + ) -> list[str]: + return ["stub"] * candidate_tokens.shape[0] + + +class _StubSuffixInitializer: + def make_initial_suffix(self, *, tokenizer: Any) -> str: + return "! ! ! !" + + +def test_sampling_strategy_accepts_minimal_impl() -> None: + impl = _StubSamplingStrategy() + assert isinstance(impl, SamplingStrategy) + + +def test_loss_function_accepts_minimal_impl() -> None: + impl = _StubLossFunction() + assert isinstance(impl, LossFunction) + + +def test_candidate_filter_accepts_minimal_impl() -> None: + impl = _StubCandidateFilter() + assert isinstance(impl, CandidateFilter) + + +def test_suffix_initializer_accepts_minimal_impl() -> None: + impl = _StubSuffixInitializer() + assert isinstance(impl, SuffixInitializer) + + +class _ClassWithoutAnyProtocolMethods: + """Has none of the protocol methods. Must fail all isinstance checks.""" + + +@pytest.mark.parametrize( + "proto", + [SamplingStrategy, LossFunction, CandidateFilter, SuffixInitializer], +) +def test_class_missing_protocol_method_fails_isinstance(proto: type) -> None: + bare = _ClassWithoutAnyProtocolMethods() + assert not isinstance(bare, proto), ( + f"A class missing every method must NOT satisfy {proto.__name__}; " + "if this assertion fires, the protocol signature has drifted to require nothing." + ) + + +def test_sampling_strategy_stub_returns_expected_shape() -> None: + impl = _StubSamplingStrategy() + control_tokens = torch.tensor([1, 2, 3, 4], dtype=torch.long) + gradient = torch.zeros((4, 100)) + non_ascii_tokens = torch.tensor([], dtype=torch.long) + + out = impl.sample_candidates( + gradient=gradient, + control_tokens=control_tokens, + batch_size=5, + top_k=8, + temperature=1.0, + allow_non_ascii=True, + non_ascii_tokens=non_ascii_tokens, + ) + assert out.shape == (5, 4) + + +def test_loss_function_stub_returns_expected_shape() -> None: + impl = _StubLossFunction() + logits = torch.zeros((3, 10, 50)) + token_ids = torch.zeros((3, 10), dtype=torch.long) + + out = impl.compute_loss( + logits=logits, + token_ids=token_ids, + target_slice=slice(5, 8), + control_slice=slice(2, 5), + ) + assert out.shape == (3,) + + +def test_candidate_filter_stub_returns_expected_length() -> None: + impl = _StubCandidateFilter() + candidate_tokens = torch.zeros((7, 4), dtype=torch.long) + + out = impl.filter_candidates( + candidate_tokens=candidate_tokens, + tokenizer=object(), + current_control="prev", + ) + assert isinstance(out, list) + assert len(out) == 7 + assert all(isinstance(item, str) for item in out) + + +def test_suffix_initializer_stub_returns_string() -> None: + impl = _StubSuffixInitializer() + out = impl.make_initial_suffix(tokenizer=object()) + assert isinstance(out, str) + assert out