Skip to content
22 changes: 20 additions & 2 deletions pyrit/auxiliary_attacks/gcg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

"""Public API for the Greedy Coordinate Gradient (GCG) auxiliary attack.

The primary entry point is :class:`GCG` (alias for :class:`GCGGenerator`), a
:class:`pyrit.executor.promptgen.core.PromptGeneratorStrategy` that produces
The primary entry point is ``GCG`` (alias for ``GCGGenerator``), a
``pyrit.executor.promptgen.core.PromptGeneratorStrategy`` that produces
adversarial suffixes via the GCG algorithm.

Example:
Expand Down Expand Up @@ -41,16 +41,30 @@
# only have the base `dev` extra (no torch). Touching any of these names from
# the package root triggers the underlying module import on first access; if
# torch is missing the user gets a clear ModuleNotFoundError pointing at torch.
#
# The extension Protocols live in ``extension_protocols`` (typing-only — that
# module imports cleanly without torch) but are routed through the same lazy
# mechanism so all GCG public symbols share one re-export pathway.
# Lazy re-export table: public name -> (module path, attribute name).
# Each row below is (public name, submodule of pyrit.auxiliary_attacks.gcg,
# attribute name); rows stay in their original order so the resulting dict's
# key order is unchanged. "GCG" is an alias, so it resolves to the same
# attribute as "GCGGenerator".
_LAZY_IMPORTS = {
    public_name: (f"pyrit.auxiliary_attacks.gcg.{submodule}", attribute)
    for public_name, submodule, attribute in (
        ("CandidateFilter", "extension_protocols", "CandidateFilter"),
        ("GCG", "generator", "GCGGenerator"),
        ("GCGContext", "generator", "GCGContext"),
        ("GCGGenerator", "generator", "GCGGenerator"),
        ("GCGResult", "generator", "GCGResult"),
        ("LossFunction", "extension_protocols", "LossFunction"),
        ("SamplingStrategy", "extension_protocols", "SamplingStrategy"),
        ("SuffixInitializer", "extension_protocols", "SuffixInitializer"),
        ("load_goals_and_targets", "data", "load_goals_and_targets"),
    )
}

if TYPE_CHECKING:
from pyrit.auxiliary_attacks.gcg.data import load_goals_and_targets
from pyrit.auxiliary_attacks.gcg.extension_protocols import (
CandidateFilter,
LossFunction,
SamplingStrategy,
SuffixInitializer,
)
from pyrit.auxiliary_attacks.gcg.generator import (
GCGContext,
GCGGenerator,
Expand All @@ -76,6 +90,7 @@ def __dir__() -> list[str]:


__all__ = [
"CandidateFilter",
"GCG",
"GCGAlgorithmConfig",
"GCGConfig",
Expand All @@ -86,5 +101,8 @@ def __dir__() -> list[str]:
"GCGOutputConfig",
"GCGResult",
"GCGStrategyConfig",
"LossFunction",
"SamplingStrategy",
"SuffixInitializer",
"load_goals_and_targets",
]
2 changes: 1 addition & 1 deletion pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def run(
n_steps: int = 100,
batch_size: int = 1024,
topk: int = 256,
temp: int = 1,
temp: float = 1.0,
allow_non_ascii: bool = True,
target_weight: Optional[float] = None,
control_weight: Optional[float] = None,
Expand Down
8 changes: 4 additions & 4 deletions pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def sample_control(
grad: torch.Tensor,
batch_size: int,
topk: int = 256,
temp: int = 1,
temp: float = 1.0,
allow_non_ascii: bool = True,
) -> torch.Tensor:
"""
Expand All @@ -102,7 +102,7 @@ def sample_control(
grad (torch.Tensor): Gradient tensor for control tokens.
batch_size (int): Number of candidate controls to generate.
topk (int): Number of top gradient positions to sample from. Defaults to 256.
temp (int): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.
temp (float): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.0.
allow_non_ascii (bool): Whether to allow non-ASCII tokens. Defaults to True.

Returns:
Expand Down Expand Up @@ -130,7 +130,7 @@ def step(
*,
batch_size: int = 1024,
topk: int = 256,
temp: int = 1,
temp: float = 1.0,
allow_non_ascii: bool = True,
target_weight: float = 1,
control_weight: float = 0.1,
Expand All @@ -146,7 +146,7 @@ def step(
Args:
batch_size (int): Number of candidate controls per batch. Defaults to 1024.
topk (int): Number of top gradient positions to sample from. Defaults to 256.
temp (int): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.
temp (float): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.0.
allow_non_ascii (bool): Whether to allow non-ASCII tokens. Defaults to True.
target_weight (float): Weight for target loss. Defaults to 1.
control_weight (float): Weight for control loss. Defaults to 0.1.
Expand Down
20 changes: 10 additions & 10 deletions pyrit/auxiliary_attacks/gcg/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class GCGDataConfig:
Used as a typed bundle for AML transport (a job ships its data config as
a separate JSON file alongside the strategy ``GCGConfig``). Library
callers loading goals/targets from a CSV can construct one and pass it to
:func:`pyrit.auxiliary_attacks.gcg.data.load_goals_and_targets`.
``pyrit.auxiliary_attacks.gcg.data.load_goals_and_targets``.

Attributes:
train_data (str): URL or filesystem path to the training-data CSV. Empty
Expand Down Expand Up @@ -100,7 +100,7 @@ def to_json(self) -> str:

@classmethod
def from_json(cls, payload: str) -> GCGDataConfig:
"""Deserialize a config previously produced by :meth:`to_json`."""
"""Deserialize a config previously produced by ``to_json``."""
try:
data = json.loads(payload)
except json.JSONDecodeError as e:
Expand Down Expand Up @@ -131,8 +131,8 @@ class GCGAlgorithmConfig:
Defaults to 512.
topk (int): Top-k gradient positions considered for substitution.
Defaults to 256.
temp (int): Sampling temperature placeholder; the current sampling
implementation samples uniformly from the top-k. Defaults to 1.
temp (float): Sampling temperature placeholder; the current sampling
implementation samples uniformly from the top-k. Defaults to 1.0.
target_weight (float): Weight on the target-string cross-entropy loss.
Defaults to 1.0.
control_weight (float): Weight on the control-string cross-entropy loss.
Expand All @@ -153,7 +153,7 @@ class GCGAlgorithmConfig:
test_steps: int = 50
batch_size: int = 512
topk: int = 256
temp: int = 1
temp: float = 1.0
target_weight: float = 1.0
control_weight: float = 0.0
learning_rate: float = 0.01
Expand Down Expand Up @@ -240,10 +240,10 @@ class GCGOutputConfig:
class GCGConfig:
"""Top-level strategy configuration for one GCG attack run.

Bundles everything :class:`pyrit.auxiliary_attacks.gcg.GCGGenerator`'s
Bundles everything ``pyrit.auxiliary_attacks.gcg.GCGGenerator``'s
constructor needs. Per-execution data (goals, targets) is **not** here —
those flow through ``GCGGenerator.execute_async``, and for AML transport
they ride alongside this object as a separate :class:`GCGDataConfig` JSON.
they ride alongside this object as a separate ``GCGDataConfig`` JSON.

Attributes:
models (list[GCGModelConfig]): Training models the attack optimizes
Expand Down Expand Up @@ -287,11 +287,11 @@ def to_json(self) -> str:

@classmethod
def from_json(cls, payload: str) -> GCGConfig:
"""Deserialize a config previously produced by :meth:`to_json`.
"""Deserialize a config previously produced by ``to_json``.

Args:
payload (str): JSON document matching the shape produced by
:meth:`to_json`.
``to_json``.

Returns:
GCGConfig: A new ``GCGConfig`` reconstructed from ``payload``.
Expand All @@ -308,7 +308,7 @@ def from_json(cls, payload: str) -> GCGConfig:

@classmethod
def from_json_file(cls, path: str | Path) -> GCGConfig:
"""Load a config from a JSON file produced by :meth:`to_json_file`.
"""Load a config from a JSON file produced by ``to_json_file``.

Args:
path (str | Path): Filesystem path to a JSON config file.
Expand Down
Loading
Loading