diff --git a/.github/instructions/models.instructions.md b/.github/instructions/models.instructions.md index 8dd5e3255f..4a9e32baa0 100644 --- a/.github/instructions/models.instructions.md +++ b/.github/instructions/models.instructions.md @@ -6,19 +6,30 @@ applyTo: "pyrit/models/**" ## Import Boundary -`pyrit.models` is the canonical data layer. Files in `pyrit/models/` may -import only from: +PyRIT enforces a two-layer rule for its foundational packages. `pyrit.common` +is the foundation layer and `pyrit.models` is the canonical data layer that sits +directly on top of it. + +**Forward (models).** Files in `pyrit/models/` may import only from: - the standard library - `pydantic` -- `pyrit.common.deprecation` +- any `pyrit.common.*` submodule (the whole prefix) - other `pyrit.models.*` submodules -If a helper needs another `pyrit.*` package, it does not belong on a model — -put it in that package as a free function or static helper. +If a helper needs another `pyrit.*` package (e.g. `pyrit.memory`, +`pyrit.score`), it does not belong on a model — put it in that package as a free +function or static helper. + +**Reverse guard (common).** Files in `pyrit/common/` may import only from the +standard library, third-party libraries, and other `pyrit.common.*` submodules. +`pyrit.common` may never import any other `pyrit.*` package — this is what keeps +it a true foundation and prevents an import cycle with `pyrit.models`. If a +`pyrit.common` helper needs `pyrit.models` (or anything higher), it belongs in +that higher layer, not in `common`. -The CI test `tests/unit/models/test_import_boundary.py` enforces this using an -allowlist of known transitional violations, each tagged with the phase that -removes it. The list must shrink monotonically: removing an import from source -without also removing its allowlist entry fails the test, and adding a new -unlisted import also fails the test. +The CI test `tests/unit/models/test_import_boundary.py` enforces both directions +using allowlists of known transitional violations, each tagged with the phase +that removes it. The lists must shrink monotonically: removing an import from +source without also removing its allowlist entry fails the test, and adding a +new unlisted import also fails the test. diff --git a/pyrit/executor/attack/core/attack_parameters.py b/pyrit/executor/attack/core/attack_parameters.py index 6dc4166d7e..fed08d5f2f 100644 --- a/pyrit/executor/attack/core/attack_parameters.py +++ b/pyrit/executor/attack/core/attack_parameters.py @@ -101,14 +101,21 @@ async def from_seed_group_async( An instance of this AttackParameters type. Raises: - ValueError: If seed_group has no objective or if overrides contain invalid fields. - ValueError: If seed_group has simulated conversation but adversarial_chat/scorer not provided. + TypeError: If ``seed_group`` is not a ``SeedAttackGroup``. + ValueError: If overrides contain invalid fields, or if seed_group has simulated + conversation but adversarial_chat/scorer not provided. """ # Import here to avoid circular imports from pyrit.executor.attack.multi_turn.simulated_conversation import ( generate_simulated_conversation_async, ) + if not isinstance(seed_group, SeedAttackGroup): + raise TypeError( + f"seed_group must be a SeedAttackGroup, got {type(seed_group).__name__}. " + "Plain SeedGroup does not enforce the 'exactly one objective' invariant required for an attack." + ) + # Get valid field names for this params type valid_fields = {f.name for f in dataclasses.fields(cls)} @@ -119,12 +126,8 @@ async def from_seed_group_async( f"{cls.__name__} does not accept parameters: {invalid_fields}. Accepted parameters: {valid_fields}" ) - # Validate seed_group state before extracting parameters - seed_group.validate() - - # SeedAttackGroup validates in __init__ that objective is set - if seed_group.objective is None: - raise ValueError("seed_group.objective is not initialized") + # SeedAttackGroup's Pydantic validator guarantees exactly one objective is present. + assert seed_group.objective is not None # Build params dict, only including fields this class accepts params: dict[str, Any] = {} diff --git a/pyrit/models/seeds/__init__.py b/pyrit/models/seeds/__init__.py index 6130580d63..7b359847e2 100644 --- a/pyrit/models/seeds/__init__.py +++ b/pyrit/models/seeds/__init__.py @@ -27,16 +27,24 @@ SeedSimulatedConversation, SimulatedTargetSystemPromptPaths, ) +from pyrit.models.seeds.yaml_seed_loader import ( + load_seed_dataset_from_yaml, + load_seed_from_yaml, + load_seed_prompt_from_yaml_with_required_parameters, +) __all__ = [ + "load_seed_dataset_from_yaml", + "load_seed_from_yaml", + "load_seed_prompt_from_yaml_with_required_parameters", + "NextMessageSystemPromptPaths", "Seed", - "SeedPrompt", - "SeedObjective", - "SeedGroup", "SeedAttackGroup", "SeedAttackTechniqueGroup", + "SeedDataset", + "SeedGroup", + "SeedObjective", + "SeedPrompt", "SeedSimulatedConversation", "SimulatedTargetSystemPromptPaths", - "NextMessageSystemPromptPaths", - "SeedDataset", ] diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index 2f0d045954..68af6ea87a 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -9,33 +9,53 @@ from __future__ import annotations -import abc import logging import re import uuid -from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Annotated, Any, TypeVar -import yaml from jinja2 import StrictUndefined, Undefined from jinja2.sandbox import SandboxedEnvironment +from pydantic import AwareDatetime, BaseModel, BeforeValidator, ConfigDict, Field -from pyrit.common.utils import verify_and_resolve_path -from pyrit.common.yaml_loadable import YamlLoadable +from pyrit.models.literals import PromptDataType # noqa: TC001 (runtime-required by Pydantic field annotations) if TYPE_CHECKING: - from collections.abc import Iterator, Sequence + from collections.abc import Iterator from pathlib import Path - from pyrit.models.literals import PromptDataType - logger = logging.getLogger(__name__) # TypeVar for generic return type in class methods T = TypeVar("T", bound="Seed") +def _ensure_aware_utc(value: Any) -> Any: + """ + Coerce naive datetimes (and bare date strings) to UTC so AwareDatetime accepts them. + + Args: + value: The raw value provided for a datetime field (string, datetime, or anything else). + + Returns: + Any: A timezone-aware datetime when the input was naive or a parseable date string; + otherwise the value unchanged for Pydantic to validate. + """ + if isinstance(value, str): + try: + value = datetime.fromisoformat(value) + except ValueError: + return value + if isinstance(value, datetime) and value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value + + +# Timezone-aware datetime that interprets naive inputs as UTC instead of rejecting them. +AwareDatetimeUTC = Annotated[AwareDatetime, BeforeValidator(_ensure_aware_utc)] + + class PartialUndefined(Undefined): """Jinja undefined value that preserves unresolved placeholders as text.""" @@ -81,54 +101,55 @@ def __bool__(self) -> bool: return True # Ensures it doesn't evaluate to False -@dataclass -class Seed(YamlLoadable): +class Seed(BaseModel): """Represents seed data with various attributes and metadata.""" + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + # The actual prompt value, which can be a string or a file path value: str # SHA256 hash of the value, used for deduplication - value_sha256: Optional[str] = None + value_sha256: str | None = None # Unique identifier for the prompt - id: Optional[uuid.UUID] = field(default_factory=lambda: uuid.uuid4()) + id: uuid.UUID | None = Field(default_factory=uuid.uuid4) # Name of the prompt - name: Optional[str] = None + name: str | None = None # Name of the dataset this prompt belongs to - dataset_name: Optional[str] = None + dataset_name: str | None = None # Categories of harm associated with this prompt - harm_categories: Optional[Sequence[str]] = field(default_factory=list) + harm_categories: list[str] | None = Field(default_factory=list) # Description of the prompt - description: Optional[str] = None + description: str | None = None # Authors of the prompt - authors: Optional[Sequence[str]] = field(default_factory=list) + authors: list[str] | None = Field(default_factory=list) # Groups affiliated with the prompt - groups: Optional[Sequence[str]] = field(default_factory=list) + groups: list[str] | None = Field(default_factory=list) # Source of the prompt - source: Optional[str] = None + source: str | None = None # Date when the prompt was added to the dataset - date_added: Optional[datetime] = field(default_factory=lambda: datetime.now(tz=timezone.utc)) + date_added: AwareDatetimeUTC | None = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) # User who added the prompt to the dataset - added_by: Optional[str] = None + added_by: str | None = None # Arbitrary metadata that can be attached to the prompt - metadata: Optional[dict[str, Union[str, int]]] = field(default_factory=dict) + metadata: dict[str, Any] | None = Field(default_factory=dict) # Unique identifier for the prompt group - prompt_group_id: Optional[uuid.UUID] = None + prompt_group_id: uuid.UUID | None = None # Alias for the prompt group - prompt_group_alias: Optional[str] = None + prompt_group_alias: str | None = None # Whether this seed represents a general attack technique (not tied to a specific objective) is_general_technique: bool = False @@ -138,15 +159,10 @@ class Seed(YamlLoadable): # to prevent template injection. Trusted sources (YAML files) set this to True automatically. is_jinja_template: bool = False - @property - def data_type(self) -> PromptDataType: - """ - Return the data type for this seed. - - Base implementation returns 'text'. SeedPrompt overrides this - to support multiple data types (image_path, audio_path, etc.). - """ - return "text" + # The type of data this seed represents (e.g., text, image_path, audio_path, video_path). + # SeedPrompt overrides the default to None and infers it from the value; other seed types + # narrow it to Literal["text"]. + data_type: PromptDataType = "text" def render_template_value(self, **kwargs: Any) -> str: """ @@ -247,10 +263,13 @@ def escape_for_jinja(value: str) -> str: return f"{{% raw %}}{value}{{% endraw %}}" @classmethod - def from_yaml_file(cls: type[T], file: Union[str, Path]) -> T: + def from_yaml_file(cls: type[T], file: str | Path) -> T: """ Create a new Seed from a YAML file, marking it as a trusted Jinja2 template. + Thin shim that delegates to ``load_seed_from_yaml`` in the ``yaml_seed_loader`` module; + file I/O and the ``is_jinja_template`` trust marker live in the loader module. + Args: file: The input file path. @@ -258,35 +277,10 @@ def from_yaml_file(cls: type[T], file: Union[str, Path]) -> T: A new Seed of the specific subclass type. Raises: - ValueError: If the YAML file is invalid. - """ - file = verify_and_resolve_path(file) - - try: - yaml_data = yaml.safe_load(file.read_text("utf-8")) - except yaml.YAMLError as exc: - raise ValueError(f"Invalid YAML file '{file}': {exc}") from exc - - yaml_data["is_jinja_template"] = True - return cls(**yaml_data) - - @classmethod - @abc.abstractmethod - def from_yaml_with_required_parameters( - cls, - template_path: Union[str, Path], - required_parameters: list[str], - error_message: Optional[str] = None, - ) -> Seed: + FileNotFoundError: If the path does not resolve to an existing file. + ValueError: If the YAML file is invalid or empty. """ - Load a Seed from a YAML file and validate that it contains specific parameters. - - Args: - template_path: Path to the YAML file containing the template. - required_parameters: List of parameter names that must exist in the template. - error_message: Custom error message if validation fails. If None, a default message is used. + # Deferred import: yaml_seed_loader imports Seed, so importing it at module top would cycle. + from pyrit.models.seeds.yaml_seed_loader import load_seed_from_yaml - Returns: - Seed: The loaded and validated seed of the specific subclass type. - - """ + return load_seed_from_yaml(file, cls=cls) diff --git a/pyrit/models/seeds/seed_attack_group.py b/pyrit/models/seeds/seed_attack_group.py index 3c02c96d9c..f52f9dac7f 100644 --- a/pyrit/models/seeds/seed_attack_group.py +++ b/pyrit/models/seeds/seed_attack_group.py @@ -9,7 +9,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING from pyrit.models.seeds.seed_group import SeedGroup from pyrit.models.seeds.seed_objective import SeedObjective @@ -17,7 +17,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_attack_technique_group import SeedAttackTechniqueGroup @@ -32,25 +31,7 @@ class SeedAttackGroup(SeedGroup): next_message, etc.) is inherited from SeedGroup. """ - def __init__( - self, - *, - seeds: Sequence[Union[Seed, dict[str, Any]]], - ) -> None: - """ - Initialize a SeedAttackGroup. - - Args: - seeds: Sequence of seeds. Must include exactly one SeedObjective. - - Raises: - ValueError: If seeds is empty. - ValueError: If exactly one objective is not provided. - - """ - super().__init__(seeds=seeds) - - def validate(self) -> None: + def _check_invariants(self) -> None: """ Validate the seed attack group state. @@ -60,7 +41,7 @@ def validate(self) -> None: ValueError: If validation fails. """ - super().validate() + super()._check_invariants() self._enforce_exactly_one_objective() def _enforce_exactly_one_objective(self) -> None: diff --git a/pyrit/models/seeds/seed_attack_technique_group.py b/pyrit/models/seeds/seed_attack_technique_group.py index ee8eb0475d..f21735641d 100644 --- a/pyrit/models/seeds/seed_attack_technique_group.py +++ b/pyrit/models/seeds/seed_attack_technique_group.py @@ -11,16 +11,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Union - from pyrit.models.seeds.seed_group import SeedGroup from pyrit.models.seeds.seed_objective import SeedObjective -if TYPE_CHECKING: - from collections.abc import Sequence - - from pyrit.models.seeds.seed import Seed - class SeedAttackTechniqueGroup(SeedGroup): """ @@ -33,35 +26,11 @@ class SeedAttackTechniqueGroup(SeedGroup): next_message, etc.) is inherited from SeedGroup. """ - def __init__( - self, - *, - seeds: Sequence[Union[Seed, dict[str, Any]]], - insertion_index: int | None = None, - ) -> None: - """ - Initialize a SeedAttackTechniqueGroup. - - Args: - seeds: Sequence of seeds. All seeds must have is_general_technique=True. - insertion_index: Where to insert technique seeds when merging into a - SeedAttackGroup via ``with_technique()``. ``None`` (default) appends - at the end. An integer inserts before that position in the target - group's seed list. - - Raises: - ValueError: If seeds is empty. - ValueError: If any seed does not have is_general_technique=True. - """ - self._insertion_index = insertion_index - super().__init__(seeds=seeds) - - @property - def insertion_index(self) -> int | None: - """Where to insert technique seeds when merging, or None to append at end.""" - return self._insertion_index + # Where to insert technique seeds when merging into a SeedAttackGroup via ``with_technique()``. + # ``None`` (default) appends at the end; an integer inserts before that position. + insertion_index: int | None = None - def validate(self) -> None: + def _check_invariants(self) -> None: """ Validate the seed attack technique group state. @@ -71,7 +40,7 @@ def validate(self) -> None: Raises: ValueError: If validation fails. """ - super().validate() + super()._check_invariants() self._enforce_all_general_strategy() self._enforce_no_objectives() diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index ac49ce9082..149f5e79e7 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -12,18 +12,23 @@ import uuid from collections import defaultdict from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any -import yaml +from pydantic import BaseModel, ConfigDict, Field, model_validator -from pyrit.common import utils -from pyrit.common.utils import verify_and_resolve_path -from pyrit.common.yaml_loadable import YamlLoadable +from pyrit.models.literals import SeedType # noqa: TC001 (runtime-required by Pydantic field annotations) +from pyrit.models.seeds.seed import ( # noqa: TC001 (AwareDatetimeUTC is runtime-required by Pydantic) + AwareDatetimeUTC, + Seed, +) from pyrit.models.seeds.seed_attack_group import SeedAttackGroup -from pyrit.models.seeds.seed_group import SeedGroup +from pyrit.models.seeds.seed_group import ( # noqa: TC001 (runtime-required by Pydantic field annotations) + PROMPT_ONLY_SEED_KEYS, + SeedGroup, + SeedUnion, +) from pyrit.models.seeds.seed_objective import SeedObjective from pyrit.models.seeds.seed_prompt import SeedPrompt -from pyrit.models.seeds.seed_simulated_conversation import SeedSimulatedConversation if TYPE_CHECKING: from collections.abc import Sequence @@ -31,201 +36,187 @@ from pydantic.types import PositiveInt - from pyrit.models.literals import PromptDataType, SeedType - from pyrit.models.seeds.seed import Seed - logger = logging.getLogger(__name__) +# Dataset-level defaults that get merged into each dict seed when missing on the seed. +# date_added/added_by/metadata are intentionally excluded — per-seed Pydantic defaults +# (default_factory) are the source of truth. +_SCALAR_DEFAULT_KEYS = ("name", "description", "source") +_LIST_DEFAULT_KEYS = ("harm_categories", "authors", "groups") + + +def _merge_unique(left: Any, right: Any) -> list[str]: + """ + Concatenate two list-or-str inputs into a deterministic, order-preserving deduped list. + + Treats ``None`` as empty, accepts bare strings as single-element lists, and preserves the + order of first occurrence (left first, then any new items from right). Used instead of + ``utils.combine_list`` because the latter goes through ``set()`` and is nondeterministic + across processes for non-trivial inputs. + + Args: + left: First list (or string) of values; falsy values are treated as empty. + right: Second list (or string) of values; falsy values are treated as empty. + + Returns: + list[str]: Deduplicated concatenation, preserving first-occurrence order. + """ + + def _as_list(v: Any) -> list[str]: + if not v: + return [] + return [v] if isinstance(v, str) else list(v) + + seen: set[str] = set() + result: list[str] = [] + for item in _as_list(left) + _as_list(right): + if item not in seen: + seen.add(item) + result.append(item) + return result + -class SeedDataset(YamlLoadable): +class SeedDataset(BaseModel): """ SeedDataset manages seed prompts plus optional top-level defaults. Prompts are stored as a Sequence[Seed], so references to prompt properties are straightforward (e.g. ds.seeds[0].value). """ - data_type: Optional[str] - name: Optional[str] - dataset_name: Optional[str] - harm_categories: Optional[Sequence[str]] - description: Optional[str] - authors: Optional[Sequence[str]] - groups: Optional[Sequence[str]] - source: Optional[str] - date_added: Optional[datetime] - added_by: Optional[str] - - # Now the actual prompts - seeds: Sequence[Seed] - + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + data_type: str | None = "text" + name: str | None = None + dataset_name: str | None = None + harm_categories: list[str] | None = None + description: str | None = None + authors: list[str] | None = Field(default_factory=list) + groups: list[str] | None = Field(default_factory=list) + source: str | None = None + date_added: AwareDatetimeUTC | None = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) + added_by: str | None = None + # The default seed type for items that don't specify their own ("prompt", "objective", ...). + seed_type: SeedType | None = None + + # The actual prompts + seeds: list[SeedUnion] + + @model_validator(mode="before") @classmethod - def from_yaml_file(cls, file: Union[str, Path]) -> SeedDataset: + def _build_seeds(cls, data: Any) -> Any: """ - Create a SeedDataset from a YAML file, marking nested seeds as trusted templates. + Merge dataset-level defaults into each dict seed and normalize for the discriminator. - Args: - file: The input file path. + Concrete Seed instances pass through unchanged. For dict seeds: + + - ``seed_type`` defaults to the dataset's ``seed_type`` or ``"prompt"``. + - ``is_jinja_template`` (construction-time flag, popped from the dataset) is propagated. + - Scalar defaults (name, dataset_name, description, source) fall back to the dataset's + when the seed has none. + - List defaults (harm_categories, authors, groups) are concatenated with deterministic + order-preserving dedup (dataset values first, then seed-only additions). + - For prompts: ``data_type`` falls back to the dataset's; ``role`` defaults to ``"user"``. + - For objective/simulated_conversation: ``data_type``/``role``/``sequence``/ + ``parameters`` are stripped — they aren't valid fields on those classes and a + dataset-level value (e.g. ``data_type: image_path``) would otherwise be rejected. Returns: - SeedDataset: The loaded dataset. + The data with normalized ``seeds`` (passes through unchanged if not a dict). Raises: - ValueError: If the YAML file is invalid. + ValueError: If the dataset has no seeds or contains an unsupported seed entry. """ - file = verify_and_resolve_path(file) - try: - yaml_data = yaml.safe_load(file.read_text("utf-8")) - except yaml.YAMLError as exc: - raise ValueError(f"Invalid YAML file '{file}': {exc}") from exc + if not isinstance(data, dict): + return data - if yaml_data is None: - raise ValueError(f"YAML file '{file}' is empty.") + data = dict(data) + is_jinja_template = data.pop("is_jinja_template", False) + raw_seeds = data.get("seeds") + if not raw_seeds: + raise ValueError("SeedDataset cannot be empty.") - yaml_data["is_jinja_template"] = True - if hasattr(cls, "from_dict") and callable(getattr(cls, "from_dict")): # noqa: B009 - return cls.from_dict(yaml_data) - return cls(**yaml_data) + default_seed_type = data.get("seed_type") or "prompt" + default_data_type = data.get("data_type") or "text" + default_dataset_name = data.get("dataset_name") or data.get("name") - def __init__( - self, - *, - seeds: Optional[Union[Sequence[dict[str, Any]], Sequence[Seed]]] = None, - data_type: Optional[PromptDataType] = "text", - name: Optional[str] = None, - dataset_name: Optional[str] = None, - harm_categories: Optional[Sequence[str]] = None, - description: Optional[str] = None, - authors: Optional[Sequence[str]] = None, - groups: Optional[Sequence[str]] = None, - source: Optional[str] = None, - date_added: Optional[datetime] = None, - added_by: Optional[str] = None, - seed_type: Optional[SeedType] = None, - is_jinja_template: bool = False, - ) -> None: + normalized: list[Any] = [] + for p in raw_seeds: + if isinstance(p, Seed): + normalized.append(p) + continue + if not isinstance(p, dict): + raise ValueError( + "Seeds should be dicts or Seed objects (SeedPrompt, SeedObjective, SeedSimulatedConversation)." + ) + + p = dict(p) + seed_type = p.setdefault("seed_type", default_seed_type) + p["is_jinja_template"] = is_jinja_template + + for key in _SCALAR_DEFAULT_KEYS: + if not p.get(key) and data.get(key) is not None: + p[key] = data.get(key) + if not p.get("dataset_name") and default_dataset_name is not None: + p["dataset_name"] = default_dataset_name + + for key in _LIST_DEFAULT_KEYS: + p[key] = _merge_unique(data.get(key), p.get(key)) + + if seed_type == "prompt": + if not p.get("data_type"): + p["data_type"] = default_data_type + p.setdefault("role", "user") + else: + # Non-prompt seeds narrow data_type to Literal["text"] and don't have + # role/sequence/parameters fields. Drop those so dataset-level defaults + # don't bleed in and trip extra="forbid" on the leaf class. + for prompt_only in PROMPT_ONLY_SEED_KEYS: + p.pop(prompt_only, None) + + normalized.append(p) + + data["seeds"] = normalized + return data + + @classmethod + def from_yaml_file(cls, file: str | Path) -> SeedDataset: """ - Initialize the dataset. - Typically, you'll call from_dict or from_yaml_file so that top-level defaults - are merged into each seed. If you're passing seeds directly, they can be - either a list of Seed objects or seed dictionaries (which then get - converted to Seed objects). + Create a SeedDataset from a YAML file, marking nested seeds as trusted templates. + + Thin shim that delegates to + ``pyrit.models.seeds.yaml_seed_loader.load_seed_dataset_from_yaml``; file I/O and + the ``is_jinja_template`` trust marker live in the loader module. Args: - seeds: List of seed dictionaries or Seed objects. - data_type: Default data type for seeds. - name: Name of the dataset. - dataset_name: Dataset name for categorization. - harm_categories: List of harm categories. - description: Description of the dataset. - authors: List of authors. - groups: List of groups. - source: Source of the dataset. - date_added: Date when the dataset was added. - added_by: User who added the dataset. - seed_type: The type of seeds in this dataset ("prompt", "objective", or "simulated_conversation"). - is_jinja_template: When True, seed values are Jinja2 templates. Set by from_yaml_file. + file: The input file path. - Raises: - ValueError: If seeds are missing or contain invalid/contradictory seed definitions. + Returns: + SeedDataset: The loaded dataset. + Raises: + FileNotFoundError: If the path does not resolve to an existing file. + ValueError: If the YAML file is invalid or empty. """ - if not seeds: - raise ValueError("SeedDataset cannot be empty.") + # Deferred import: yaml_seed_loader imports SeedDataset at module load, so importing + # it at the top of this module would create a circular import. + from pyrit.models.seeds.yaml_seed_loader import load_seed_dataset_from_yaml - input_seeds = seeds - - # Store top-level fields - self.data_type = data_type - self.name = name - self.dataset_name = dataset_name - - self.harm_categories = harm_categories - self.description = description - self.authors = authors or [] - self.groups = groups or [] - self.source = source - self.date_added = date_added or datetime.now(tz=timezone.utc) - self.added_by = added_by - - # Convert any dictionaries in `seeds` to SeedPrompt and/or SeedObjective objects - self.seeds = [] - for p in input_seeds: - if isinstance(p, dict): - p_seed_type = p.get("seed_type", seed_type) # type: ignore[ty:no-matching-overload] - - effective_type: SeedType = "prompt" - if p_seed_type == "objective": - effective_type = "objective" - elif p_seed_type == "simulated_conversation": - effective_type = "simulated_conversation" - elif p_seed_type == "prompt": - effective_type = "prompt" - - # Extract common base parameters (from Seed base class) with dataset defaults. - # Note: If Seed base class param names change, update here too. - # SeedSimulatedConversation computes its own value, so we don't require it. - base_params = { - "value_sha256": p.get("value_sha256"), # type: ignore[ty:invalid-argument-type] - "id": uuid.uuid4(), - "name": p.get("name") or self.name, # type: ignore[ty:invalid-argument-type] - "dataset_name": p.get("dataset_name") or self.dataset_name or self.name, # type: ignore[ty:invalid-argument-type] - "harm_categories": p.get("harm_categories", []), # type: ignore[ty:no-matching-overload] - "description": p.get("description") or self.description, # type: ignore[ty:invalid-argument-type] - "authors": p.get("authors", []), # type: ignore[ty:no-matching-overload] - "groups": p.get("groups", []), # type: ignore[ty:no-matching-overload] - "source": p.get("source") or self.source, # type: ignore[ty:invalid-argument-type] - "date_added": p.get("date_added"), # type: ignore[ty:invalid-argument-type] - "added_by": p.get("added_by"), # type: ignore[ty:invalid-argument-type] - "metadata": p.get("metadata", {}), # type: ignore[ty:no-matching-overload] - "prompt_group_id": p.get("prompt_group_id"), # type: ignore[ty:invalid-argument-type] - "is_jinja_template": is_jinja_template, - } - - if effective_type == "simulated_conversation": - _adv_path = p.get("adversarial_chat_system_prompt_path") # type: ignore[ty:invalid-argument-type] - _sim_path = p.get("simulated_target_system_prompt_path") # type: ignore[ty:invalid-argument-type] - _sc_kwargs: dict[str, Any] = {**base_params, "num_turns": p.get("num_turns", 3)} # type: ignore[ty:no-matching-overload] - if _adv_path is not None: - _sc_kwargs["adversarial_chat_system_prompt_path"] = str(_adv_path) - if _sim_path is not None: - _sc_kwargs["simulated_target_system_prompt_path"] = str(_sim_path) - self.seeds.append(SeedSimulatedConversation(**_sc_kwargs)) # type: ignore[ty:invalid-argument-type] - elif effective_type == "objective": - # SeedObjective inherits data_type="text" from base Seed property - base_params["value"] = p["value"] # type: ignore[ty:invalid-argument-type] - self.seeds.append(SeedObjective(**base_params)) # type: ignore[ty:invalid-argument-type] - else: # prompt - base_params["value"] = p["value"] # type: ignore[ty:invalid-argument-type] - self.seeds.append( - SeedPrompt( - **base_params, # type: ignore[ty:invalid-argument-type] - data_type=p.get("data_type") or self.data_type, # type: ignore[ty:invalid-argument-type] - role=p.get("role", "user"), # type: ignore[ty:no-matching-overload] - sequence=p.get("sequence", 0), # type: ignore[ty:no-matching-overload] - parameters=p.get("parameters", {}), # type: ignore[ty:no-matching-overload] - ) - ) - elif isinstance(p, (SeedPrompt, SeedObjective, SeedSimulatedConversation)): - self.seeds.append(p) - else: - raise ValueError( - "Seeds should be dicts or Seed objects (SeedPrompt, SeedObjective, SeedSimulatedConversation)." - ) + return load_seed_dataset_from_yaml(file) def get_values( self, *, - first: Optional[PositiveInt] = None, - last: Optional[PositiveInt] = None, - harm_categories: Optional[Sequence[str]] = None, + first: PositiveInt | None = None, + last: PositiveInt | None = None, + harm_categories: Sequence[str] | None = None, ) -> Sequence[str]: """ Extract and return prompt values from the dataset. Args: - first (Optional[int]): If provided, values from the first N prompts are included. - last (Optional[int]): If provided, values from the last N prompts are included. - harm_categories (Optional[Sequence[str]]): If provided, only prompts containing at least one of + first (int | None): If provided, values from the first N prompts are included. + last (int | None): If provided, values from the last N prompts are included. + harm_categories (Sequence[str] | None): If provided, only prompts containing at least one of these harm categories are included. Returns: @@ -253,15 +244,13 @@ def get_values( return first_part + last_part - def get_random_values( - self, *, number: PositiveInt, harm_categories: Optional[Sequence[str]] = None - ) -> Sequence[str]: + def get_random_values(self, *, number: PositiveInt, harm_categories: Sequence[str] | None = None) -> Sequence[str]: """ Extract and return random prompt values from the dataset. Args: number (int): The number of random prompt values to return. - harm_categories (Optional[Sequence[str]]): If provided, only prompts containing at least one of + harm_categories (Sequence[str] | None): If provided, only prompts containing at least one of these harm categories are included. Returns: @@ -274,56 +263,36 @@ def get_random_values( @classmethod def from_dict(cls, data: dict[str, Any]) -> SeedDataset: """ - Build a SeedDataset by merging top-level defaults into each item in `seeds`. + Build a SeedDataset, assigning per-seed ``prompt_group_id`` by alias. + + Default merging now lives in ``_build_seeds`` so direct construction and + ``from_dict`` produce equivalent results. This method handles the YAML-only + concerns: rejecting pre-set ``prompt_group_id`` on input seeds and resolving + ``prompt_group_alias`` into a shared ``prompt_group_id``. Args: data (Dict[str, Any]): Dataset payload with top-level defaults and seed entries. Returns: - SeedDataset: Constructed dataset with merged defaults. + SeedDataset: Constructed dataset. Raises: - ValueError: If any seed entry includes a pre-set prompt_group_id. - + ValueError: If any seed entry includes a pre-set ``prompt_group_id``. """ - # Pop out the seeds section - seeds_data = data.pop("seeds", []) - - dataset_defaults = data # everything else is top-level + data = dict(data) - merged_seeds = [] - for p in seeds_data: - # Merge dataset-level fields with the prompt-level fields - merged = utils.combine_dict(dataset_defaults, p) + # Shallow-copy each dict seed so alias resolution doesn't mutate caller-owned dicts; + # non-dict seeds (e.g. Seed instances) pass through untouched. + seeds_data: list[Any] = [dict(seed) if isinstance(seed, dict) else seed for seed in data.get("seeds", [])] - merged["harm_categories"] = utils.combine_list( - dataset_defaults.get("harm_categories", []), - p.get("harm_categories", []), - ) - - merged["authors"] = utils.combine_list( - dataset_defaults.get("authors", []), - p.get("authors", []), - ) - - merged["groups"] = utils.combine_list( - dataset_defaults.get("groups", []), - p.get("groups", []), - ) - - if "data_type" not in merged: - merged["data_type"] = dataset_defaults.get("data_type", "text") - - merged_seeds.append(merged) - - for seed in merged_seeds: + dict_seeds = [s for s in seeds_data if isinstance(s, dict)] + for seed in dict_seeds: if "prompt_group_id" in seed: raise ValueError("prompt_group_id should not be set in seed data") + cls._set_seed_group_id_by_alias(dict_seeds) - SeedDataset._set_seed_group_id_by_alias(seed_prompts=merged_seeds) - - # Now create the dataset with the newly merged prompt dicts - return cls(seeds=merged_seeds, **dataset_defaults) + data["seeds"] = seeds_data + return cls.model_validate(data) def render_template_value(self, **kwargs: object) -> None: """ diff --git a/pyrit/models/seeds/seed_group.py b/pyrit/models/seeds/seed_group.py index 579ba44360..3998820c38 100644 --- a/pyrit/models/seeds/seed_group.py +++ b/pyrit/models/seeds/seed_group.py @@ -13,9 +13,10 @@ import logging import uuid from collections import defaultdict -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Annotated, Any + +from pydantic import BaseModel, ConfigDict, Field, model_validator -from pyrit.common.yaml_loadable import YamlLoadable from pyrit.models.messages.message import Message from pyrit.models.messages.message_piece import MessagePiece from pyrit.models.seeds.seed import Seed @@ -28,8 +29,22 @@ logger = logging.getLogger(__name__) +# Polymorphic union of seed types that can appear inside a SeedGroup. The discriminator +# field ``seed_type`` is set per-leaf-class so Pydantic dispatches to the correct constructor +# during validation. Exported so SeedDataset (and any future container) can reuse the same +# tagged union for its own ``seeds`` field. +SeedUnion = Annotated[ + SeedPrompt | SeedObjective | SeedSimulatedConversation, + Field(discriminator="seed_type"), +] + +# Fields that only exist on prompt-type seeds. They are stripped from non-prompt seed dicts so +# dataset/group-level defaults (e.g. ``data_type: image_path``) don't bleed in and trip +# ``extra="forbid"`` on the leaf class. Shared with SeedDataset, which imports it from here. +PROMPT_ONLY_SEED_KEYS = ("data_type", "role", "sequence", "parameters") -class SeedGroup(YamlLoadable): + +class SeedGroup(BaseModel): """ A container for grouping prompts that need to be sent together. @@ -42,84 +57,96 @@ class SeedGroup(YamlLoadable): All prompts in the group share the same `prompt_group_id`. """ - seeds: list[Seed] + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + seeds: list[SeedUnion] - def __init__( - self, - *, - seeds: Sequence[Union[Seed, dict[str, Any]]], - is_jinja_template: bool = False, - ) -> None: + @model_validator(mode="before") + @classmethod + def _coerce_seeds(cls, data: Any) -> Any: """ - Initialize a SeedGroup. + Normalize dict seed inputs so the polymorphic discriminator can dispatch. - Args: - seeds: Sequence of seeds. Can include: - - SeedObjective (or dict with seed_type="objective") - - SeedSimulatedConversation (or dict with seed_type="simulated_conversation") - - SeedPrompt for prompts (or dict with seed_type="prompt" or no seed_type) - is_jinja_template: When True, seed values are treated as Jinja2 templates. - Set automatically by from_yaml_file for trusted sources. + Concrete Seed instances pass through; dicts are tagged with a default + ``seed_type="prompt"`` when missing and have the construction-time + ``is_jinja_template`` flag propagated in. ``data_type`` is stripped for + non-prompt seeds because they narrow it to ``Literal["text"]``; dataset-level + defaults must not bleed into them. - Raises: - ValueError: If seeds is empty. - ValueError: If multiple objectives are provided. - ValueError: If SeedPrompt sequences overlap with SeedSimulatedConversation range. + Returns: + The data with normalized ``seeds`` (passes through unchanged if not a dict). + Raises: + ValueError: If the group has no seeds or a seed has an unsupported type. """ - if not seeds: + if not isinstance(data, dict): + return data + + data = dict(data) + is_jinja_template = data.pop("is_jinja_template", False) + raw_seeds = data.get("seeds") + if not raw_seeds: raise ValueError("SeedGroup cannot be empty.") - self.seeds = [] - for seed in seeds: + normalized: list[Any] = [] + for seed in raw_seeds: if isinstance(seed, Seed): - self.seeds.append(seed) - elif isinstance(seed, dict): - seed["is_jinja_template"] = is_jinja_template - seed_type = seed.pop("seed_type", None) - - if seed_type == "simulated_conversation": - # SeedSimulatedConversation doesn't use data_type (always text) - seed.pop("data_type", None) - self.seeds.append(SeedSimulatedConversation.from_dict(seed)) - elif seed_type == "objective": - # SeedObjective doesn't use data_type (always text) - seed.pop("data_type", None) - self.seeds.append(SeedObjective(**seed)) - else: - self.seeds.append(SeedPrompt(**seed)) - else: + normalized.append(seed) + continue + if not isinstance(seed, dict): raise ValueError(f"Invalid seed type: {type(seed)}") + seed = dict(seed) + seed.setdefault("seed_type", "prompt") + seed["is_jinja_template"] = is_jinja_template + if seed["seed_type"] == "prompt": + seed.setdefault("role", "user") + else: + # Non-prompt seeds narrow data_type to Literal["text"] and don't have + # role/sequence/parameters fields. Drop them so dataset/group-level + # defaults don't bleed in and trip extra="forbid". + for prompt_only in PROMPT_ONLY_SEED_KEYS: + seed.pop(prompt_only, None) + normalized.append(seed) + + data["seeds"] = normalized + return data - # Validate and normalize the seeds - self.validate() + @model_validator(mode="after") + def _finalize(self) -> SeedGroup: + """ + Validate the group and reorder seeds into canonical order. + + Canonical order is: objective, simulated_conversation, then prompts sorted by sequence. - # Extract simulated conversation config - self._simulated_conversation_config = self._get_simulated_conversation() + Returns: + SeedGroup: The validated, reordered group. + """ + self._check_invariants() - # Reconstruct seeds in canonical order: objective, simulated_conversation, sorted prompts objective = self._get_objective() - simulated_conv = self._simulated_conversation_config + simulated_conv = self._get_simulated_conversation() sorted_prompts = sorted(self.prompts, key=lambda p: p.sequence if p.sequence is not None else 0) - self.seeds = [] + new_seeds: list[SeedUnion] = [] if objective: - self.seeds.append(objective) + new_seeds.append(objective) if simulated_conv: - self.seeds.append(simulated_conv) - self.seeds.extend(sorted_prompts) + new_seeds.append(simulated_conv) + new_seeds.extend(sorted_prompts) + self.seeds = new_seeds + return self # ========================================================================= # Validation # ========================================================================= - def validate(self) -> None: + def _check_invariants(self) -> None: """ Validate the seed group state. - This method can be called after external modifications to seeds - to ensure the group remains in a valid state. It is automatically - called during initialization. + Renamed from ``validate`` because that name shadows ``BaseModel.validate`` and + would silently return ``None`` instead of constructing when called on the class. + Subclasses override this hook to add stronger invariants. Raises: ValueError: If validation fails. @@ -235,12 +262,12 @@ def _enforce_no_sequence_overlap_with_simulated(self) -> None: # Seed Accessors # ========================================================================= - def _get_objective(self) -> Optional[SeedObjective]: + def _get_objective(self) -> SeedObjective | None: """ Get the objective seed if present. Returns: - Optional[SeedObjective]: Objective seed when available; otherwise None. + SeedObjective | None: Objective seed when available; otherwise None. """ for seed in self.seeds: @@ -248,12 +275,12 @@ def _get_objective(self) -> Optional[SeedObjective]: return seed return None - def _get_simulated_conversation(self) -> Optional[SeedSimulatedConversation]: + def _get_simulated_conversation(self) -> SeedSimulatedConversation | None: """ Get the simulated conversation seed if present. Returns: - Optional[SeedSimulatedConversation]: Simulated conversation seed when available; otherwise None. + SeedSimulatedConversation | None: Simulated conversation seed when available; otherwise None. """ for seed in self.seeds: @@ -267,7 +294,7 @@ def prompts(self) -> Sequence[SeedPrompt]: return [seed for seed in self.seeds if isinstance(seed, SeedPrompt)] @property - def objective(self) -> Optional[SeedObjective]: + def objective(self) -> SeedObjective | None: """Get the objective for this group.""" return self._get_objective() @@ -291,21 +318,21 @@ def harm_categories(self) -> list[str]: # ========================================================================= @property - def simulated_conversation_config(self) -> Optional[SeedSimulatedConversation]: + def simulated_conversation_config(self) -> SeedSimulatedConversation | None: """Get the simulated conversation configuration if set.""" - return self._simulated_conversation_config + return self._get_simulated_conversation() @property def has_simulated_conversation(self) -> bool: """Check if this group uses simulated conversation generation.""" - return self._simulated_conversation_config is not None + return self._get_simulated_conversation() is not None # ========================================================================= # Message Extraction # ========================================================================= @property - def prepended_conversation(self) -> Optional[list[Message]]: + def prepended_conversation(self) -> list[Message] | None: """ Returns Messages that should be prepended as conversation history. @@ -335,7 +362,7 @@ def prepended_conversation(self) -> Optional[list[Message]]: return self._prompts_to_messages(list(self.prompts)) @property - def next_message(self) -> Optional[Message]: + def next_message(self) -> Message | None: """ Returns a Message containing only the last turn's prompts if it's a user message. @@ -375,7 +402,7 @@ def user_messages(self) -> list[Message]: return self._prompts_to_messages(list(self.prompts)) - def _get_last_sequence_role(self) -> Optional[str]: + def _get_last_sequence_role(self) -> str | None: """ Get the role of the last sequence. diff --git a/pyrit/models/seeds/seed_objective.py b/pyrit/models/seeds/seed_objective.py index 0f0edd743a..dc0588d3f1 100644 --- a/pyrit/models/seeds/seed_objective.py +++ b/pyrit/models/seeds/seed_objective.py @@ -8,28 +8,34 @@ from __future__ import annotations import logging -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union +from typing import Literal + +from pydantic import model_validator from pyrit.common.path import PATHS_DICT from pyrit.models.seeds.seed import Seed -if TYPE_CHECKING: - from pathlib import Path - logger = logging.getLogger(__name__) -@dataclass class SeedObjective(Seed): """Represents a seed objective with various attributes and metadata.""" - is_general_technique: bool = False + # Discriminator field for the polymorphic Seed union (see seed_group.SeedUnion). + seed_type: Literal["objective"] = "objective" + + # Objectives are always text. Narrowing the base field rejects non-text values up-front + # rather than silently dropping them downstream. + data_type: Literal["text"] = "text" - def __post_init__(self) -> None: + @model_validator(mode="after") + def _validate_and_render(self) -> SeedObjective: """ Post-initialization to render the template to replace existing values. + Returns: + SeedObjective: The validated, rendered objective. + Raises: ValueError: If is_general_technique is True. """ @@ -37,26 +43,5 @@ def __post_init__(self) -> None: raise ValueError("SeedObjective cannot be a general technique.") # Only trusted templates are rendered through Jinja — see seed_prompt.py for details. if self.is_jinja_template: - self.value = super().render_template_value_silent(**PATHS_DICT) - - @classmethod - def from_yaml_with_required_parameters( - cls, - template_path: Union[str, Path], - required_parameters: list[str], - error_message: Optional[str] = None, - ) -> SeedObjective: - """ - Load a Seed from a YAML file. Because SeedObjectives do not have any parameters, the required_parameters - and error_message arguments are unused. - - Args: - template_path: Path to the YAML file containing the template. - required_parameters: List of parameter names that must exist in the template. - error_message: Custom error message if validation fails. If None, a default message is used. - - Returns: - SeedObjective: The loaded and validated seed of the specific subclass type. - - """ - return cls.from_yaml_file(template_path) + self.value = self.render_template_value_silent(**PATHS_DICT) + return self diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index fa6b9b59db..027b6935e9 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -8,48 +8,56 @@ from __future__ import annotations import logging -from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Literal +from pydantic import Field, model_validator from tinytag import TinyTag from pyrit.common.path import PATHS_DICT from pyrit.models import DataTypeSerializer +from pyrit.models.literals import ( # noqa: TC001 (runtime-required by Pydantic field annotations) + ChatMessageRole, + PromptDataType, +) from pyrit.models.seeds.seed import Seed if TYPE_CHECKING: import uuid - from collections.abc import Sequence from pyrit.models import Message - from pyrit.models.literals import ChatMessageRole, PromptDataType logger = logging.getLogger(__name__) -@dataclass class SeedPrompt(Seed): """Represents a seed prompt with various attributes and metadata.""" + # Discriminator field for the polymorphic Seed union (see seed_group.SeedUnion). + seed_type: Literal["prompt"] = "prompt" + # The type of data this prompt represents (e.g., text, image_path, audio_path, video_path) - # This field shadows the base class property to allow per-prompt data types - data_type: Optional[PromptDataType] = None + # This field overrides the base default to allow per-prompt data types inferred from the value + data_type: PromptDataType | None = None # Role of the prompt in a conversation (e.g., "user", "assistant") - role: Optional[ChatMessageRole] = None + role: ChatMessageRole | None = None # Sequence number for ordering prompts in a conversation, prompts with # the same sequence number are grouped together if they also share the same prompt_group_id sequence: int = 0 # Parameters that can be used in the prompt template - parameters: Optional[Sequence[str]] = field(default_factory=list) + parameters: list[str] | None = Field(default_factory=list) - def __post_init__(self) -> None: + @model_validator(mode="after") + def _render_and_infer_data_type(self) -> SeedPrompt: """ Render template placeholders and infer data_type after initialization. + Returns: + SeedPrompt: The validated prompt with rendered value and inferred data_type. + Raises: ValueError: If file-based data type cannot be inferred from extension. @@ -85,6 +93,8 @@ def __post_init__(self) -> None: else: self.data_type = "text" + return self + def set_encoding_metadata(self) -> None: """ Set encoding metadata for the prompt within metadata dictionary. For images, this is just the @@ -126,12 +136,15 @@ def set_encoding_metadata(self) -> None: @classmethod def from_yaml_with_required_parameters( cls, - template_path: Union[str, Path], + template_path: str | Path, required_parameters: list[str], - error_message: Optional[str] = None, + error_message: str | None = None, ) -> SeedPrompt: """ - Load a Seed from a YAML file and validate that it contains specific parameters. + Load a SeedPrompt from a YAML file and validate that it declares each required parameter. + + Thin shim that delegates to + ``pyrit.models.seeds.yaml_seed_loader.load_seed_prompt_from_yaml_with_required_parameters``. Args: template_path: Path to the YAML file containing the template. @@ -139,27 +152,25 @@ def from_yaml_with_required_parameters( error_message: Custom error message if validation fails. If None, a default message is used. Returns: - SeedPrompt: The loaded and validated SeedPrompt of the specific subclass type. + SeedPrompt: The loaded and validated SeedPrompt. Raises: ValueError: If the template doesn't contain all required parameters. - """ - sp = cls.from_yaml_file(template_path) - - if sp.parameters is None or not all(param in sp.parameters for param in required_parameters): - if error_message is None: - error_message = f"Template must have these parameters: {', '.join(required_parameters)}" - raise ValueError(f"{error_message}: '{sp}'") + # Deferred import: yaml_seed_loader imports SeedPrompt at module load, so importing + # it at the top of this module would create a circular import. + from pyrit.models.seeds.yaml_seed_loader import load_seed_prompt_from_yaml_with_required_parameters - return sp + return load_seed_prompt_from_yaml_with_required_parameters( + template_path, required_parameters, error_message=error_message + ) @staticmethod def from_messages( messages: list[Message], *, starting_sequence: int = 0, - prompt_group_id: Optional[uuid.UUID] = None, + prompt_group_id: uuid.UUID | None = None, ) -> list[SeedPrompt]: """ Convert a list of Messages to a list of SeedPrompts. diff --git a/pyrit/models/seeds/seed_simulated_conversation.py b/pyrit/models/seeds/seed_simulated_conversation.py index 019e842faa..7b5daa0009 100644 --- a/pyrit/models/seeds/seed_simulated_conversation.py +++ b/pyrit/models/seeds/seed_simulated_conversation.py @@ -15,12 +15,14 @@ import enum import hashlib +import importlib.metadata import json import logging from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Literal + +from pydantic import field_validator, model_validator -import pyrit from pyrit.common.path import EXECUTOR_SIMULATED_TARGET_PATH from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_prompt import SeedPrompt @@ -66,64 +68,62 @@ class SeedSimulatedConversation(Seed): """ - def __init__( - self, - *, - adversarial_chat_system_prompt_path: Union[str, Path], - simulated_target_system_prompt_path: Optional[Union[str, Path]] = None, - next_message_system_prompt_path: Optional[Union[str, Path]] = None, - num_turns: int = 3, - sequence: int = 0, - pyrit_version: Optional[str] = None, - **kwargs: Any, - ) -> None: - """ - Initialize a SeedSimulatedConversation. + # Discriminator field for the polymorphic Seed union (see seed_group.SeedUnion). + seed_type: Literal["simulated_conversation"] = "simulated_conversation" - Args: - adversarial_chat_system_prompt_path: Path to YAML file containing the adversarial - chat system prompt. - simulated_target_system_prompt_path: Optional path to YAML file containing - the simulated target system prompt. Defaults to the compliant prompt. - next_message_system_prompt_path: Optional path to YAML file containing the system - prompt for generating a final user message. If provided, after the simulated - conversation is generated, a single LLM call generates a user message that - attempts to get the target to fulfill the objective. Defaults to None - (no next message generation). - num_turns: Number of conversation turns to generate. Defaults to 3. - sequence: The starting sequence number for generated turns. When combined with - static SeedPrompts, this determines where the simulated turns are inserted. - Defaults to 0. - pyrit_version: PyRIT version for reproducibility tracking. Defaults to current version. - **kwargs: Additional arguments passed to the Seed base class. + # Simulated conversations are always text. Narrowing the base field rejects non-text values + # up-front rather than silently dropping them downstream. + data_type: Literal["text"] = "text" - Raises: - ValueError: If num_turns is not positive or sequence is negative. + # value is computed from the config in the after-validator. The base default of "" plus a + # before-validator that strips any user-supplied value keeps round-trips clean: a dumped + # value comes back in, is dropped, then is recomputed (and matches if the config matches). + value: str = "" + + # Simulated conversations are general techniques by default. + is_general_technique: bool = True + num_turns: int = 3 + sequence: int = 0 + adversarial_chat_system_prompt_path: Path + simulated_target_system_prompt_path: Path = SimulatedTargetSystemPromptPaths.COMPLIANT.value + next_message_system_prompt_path: Path | None = None + pyrit_version: str | None = None + + @model_validator(mode="before") + @classmethod + def _strip_user_value(cls, data: Any) -> Any: """ - # Apply default for simulated target system prompt if not provided - if simulated_target_system_prompt_path is None: - simulated_target_system_prompt_path = SimulatedTargetSystemPromptPaths.COMPLIANT.value - if num_turns <= 0: - raise ValueError("num_turns must be a positive integer") - if sequence < 0: - raise ValueError("sequence must be a non-negative integer") + Drop any user-supplied ``value`` from dict input; it is always recomputed in the + after-validator. This keeps round-tripping clean and makes the API honest about the + fact that ``value`` is a derived JSON serialization of the config. - self.adversarial_chat_system_prompt_path = Path(adversarial_chat_system_prompt_path) - self.simulated_target_system_prompt_path = Path(simulated_target_system_prompt_path) - self.next_message_system_prompt_path = ( - Path(next_message_system_prompt_path) if next_message_system_prompt_path else None - ) - self.num_turns = num_turns - self.sequence = sequence - self.pyrit_version = pyrit_version or pyrit.__version__ + Returns: + The data with ``value`` removed if it was a dict; otherwise the input unchanged. + """ + if isinstance(data, dict) and "value" in data: + data = dict(data) + data.pop("value", None) + return data - # Compute value and pass to parent - # Remove 'value' from kwargs if present since we compute it - kwargs.pop("value", None) - # Default is_general_technique to True for simulated conversations - kwargs.setdefault("is_general_technique", True) - super().__init__(value=self._compute_value(), **kwargs) + @field_validator("simulated_target_system_prompt_path", mode="before") + @classmethod + def _default_simulated_target_path(cls, value: Any) -> Any: + # Reconstruction from memory may pass an explicit None; fall back to the compliant default. + if value is None: + return SimulatedTargetSystemPromptPaths.COMPLIANT.value + return value + + @model_validator(mode="after") + def _validate_and_compute_value(self) -> SeedSimulatedConversation: + if self.num_turns <= 0: + raise ValueError("num_turns must be a positive integer") + if self.sequence < 0: + raise ValueError("sequence must be a non-negative integer") + if not self.pyrit_version: + self.pyrit_version = importlib.metadata.version("pyrit") + self.value = self._compute_value() + return self def _compute_value(self) -> str: """ @@ -145,70 +145,6 @@ def _compute_value(self) -> str: } return json.dumps(config, sort_keys=True, separators=(",", ":")) - @classmethod - def from_dict(cls, data: dict[str, Any]) -> SeedSimulatedConversation: - """ - Create a SeedSimulatedConversation from a dictionary, typically from YAML. - - Expected format: - num_turns: 3 - adversarial_chat_system_prompt_path: path/to/adversarial.yaml - simulated_target_system_prompt_path: path/to/simulated.yaml # optional - - Args: - data: Dictionary containing the configuration. - - Returns: - A new SeedSimulatedConversation instance. - - Raises: - ValueError: If required configuration fields are missing. - - """ - adversarial_path = data.get("adversarial_chat_system_prompt_path") - if not adversarial_path: - raise ValueError("adversarial_chat_system_prompt_path is required") - - return cls( - num_turns=data.get("num_turns", 3), - sequence=data.get("sequence", 0), - adversarial_chat_system_prompt_path=adversarial_path, - simulated_target_system_prompt_path=data.get("simulated_target_system_prompt_path"), - next_message_system_prompt_path=data.get("next_message_system_prompt_path"), - ) - - @classmethod - def from_yaml_with_required_parameters( - cls, - template_path: Union[str, Path], - required_parameters: list[str], - error_message: Optional[str] = None, - ) -> SeedSimulatedConversation: - """ - Load a SeedSimulatedConversation from a YAML file and validate required parameters. - - Args: - template_path: Path to the YAML file containing the config. - required_parameters: List of parameter names that must exist. - error_message: Custom error message if validation fails. - - Returns: - The loaded and validated SeedSimulatedConversation. - - Raises: - ValueError: If required parameters are missing. - - """ - instance = cls.from_yaml_file(template_path) - - # Check required parameters - for param in required_parameters: - if not hasattr(instance, param) or getattr(instance, param) is None: - msg = error_message or f"Missing required parameter: {param}" - raise ValueError(msg) - - return instance - def get_identifier(self) -> dict[str, Any]: """ Get an identifier dict capturing this configuration for comparison/storage. @@ -246,8 +182,8 @@ def load_simulated_target_system_prompt( *, objective: str, num_turns: int, - simulated_target_system_prompt_path: Optional[Union[str, Path]] = None, - ) -> Optional[str]: + simulated_target_system_prompt_path: str | Path | None = None, + ) -> str | None: """ Load and render the simulated target system prompt. diff --git a/pyrit/models/seeds/yaml_seed_loader.py b/pyrit/models/seeds/yaml_seed_loader.py new file mode 100644 index 0000000000..5d6719ede2 --- /dev/null +++ b/pyrit/models/seeds/yaml_seed_loader.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +YAML loaders for seed types. + +These functions live separately from the seed classes themselves because the +*trust claim* that a value came from a vetted local YAML file (vs. an untrusted +remote dataset) is a property of the loader, not of the data class. A ``Seed`` +instance can't know its own provenance; the loader can, so the +``is_jinja_template=True`` marker is set exactly once at this boundary. + +The ``from_yaml_file`` and ``from_yaml_with_required_parameters`` classmethods +on the seed classes are thin shims that delegate here. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypeVar + +import yaml + +from pyrit.common.utils import verify_and_resolve_path +from pyrit.models.seeds.seed import Seed +from pyrit.models.seeds.seed_dataset import SeedDataset +from pyrit.models.seeds.seed_prompt import SeedPrompt + +if TYPE_CHECKING: + from pathlib import Path + +T = TypeVar("T", bound=Seed) + +# Seed model fields that callers may write as a bare string in YAML +# (e.g. ``authors: Jane Doe``) but the model declares as ``list[str]``. +# The loader wraps such scalars before constructing the model so the model +# itself can stay strict and YAML's "scalar-or-sequence" idiom doesn't leak +# into the data class. +_SCALAR_OR_LIST_FIELDS: tuple[str, ...] = ("harm_categories", "authors", "groups", "parameters") + + +def _canonicalize_scalar_lists(data: dict[str, Any]) -> dict[str, Any]: + """ + Wrap bare-string values into single-element lists for known list-typed seed fields. + + Mutates ``data`` in place and recurses into nested ``seeds`` entries so + dataset/group YAML files (which carry both top-level defaults and a list of seed + dicts) are normalized in one pass. + + Args: + data: A YAML-decoded mapping representing a seed, group, or dataset. + + Returns: + The same mapping, with scalar values on known list fields wrapped into lists. + """ + for key in _SCALAR_OR_LIST_FIELDS: + if isinstance(data.get(key), str): + data[key] = [data[key]] + seeds = data.get("seeds") + if isinstance(seeds, list): + for seed in seeds: + if isinstance(seed, dict): + _canonicalize_scalar_lists(seed) + return data + + +def _read_yaml(file: str | Path) -> dict[str, Any]: + """ + Resolve, read, and parse a YAML file as a mapping. + + Args: + file: Path to a YAML file. + + Returns: + The parsed top-level mapping. + + Raises: + FileNotFoundError: If the path does not resolve to an existing file. + ValueError: If the YAML is malformed or empty. + """ + file = verify_and_resolve_path(file) + try: + data = yaml.safe_load(file.read_text("utf-8")) + except yaml.YAMLError as exc: + raise ValueError(f"Invalid YAML file '{file}': {exc}") from exc + + if data is None: + raise ValueError(f"YAML file '{file}' is empty.") + if not isinstance(data, dict): + raise ValueError(f"YAML file '{file}' must contain a mapping at the top level.") + return data + + +def load_seed_from_yaml(file: str | Path, *, cls: type[T]) -> T: + """ + Load a single seed of type ``cls`` from a YAML file. + + The seed is marked ``is_jinja_template=True`` because the file is treated + as a trusted, vetted local template at this boundary. Bare-string values + for known list-typed fields (``authors``, ``harm_categories``, ``groups``, + ``parameters``) are wrapped into single-element lists so the model itself + can stay strict about its shape. + + Args: + file: Path to the YAML file containing the seed definition. + cls: Seed subclass to instantiate (e.g. ``SeedPrompt``, ``SeedObjective``). + + Returns: + An instance of ``cls`` populated from the YAML file. + + Raises: + FileNotFoundError: If the path does not resolve to an existing file. + ValueError: If the YAML is malformed, empty, or fails validation for ``cls``. + """ + data = _canonicalize_scalar_lists(_read_yaml(file)) + data["is_jinja_template"] = True + return cls(**data) + + +def load_seed_dataset_from_yaml(file: str | Path) -> SeedDataset: + """ + Load a ``SeedDataset`` from a YAML file. + + Nested seeds inherit the ``is_jinja_template=True`` trust marker set at this + boundary; per-seed overrides in the YAML are intentionally ignored. + + Args: + file: Path to the YAML file containing the dataset definition. + + Returns: + A ``SeedDataset`` populated from the YAML file. + + Raises: + FileNotFoundError: If the path does not resolve to an existing file. + ValueError: If the YAML is malformed, empty, or fails dataset validation. + """ + data = _canonicalize_scalar_lists(_read_yaml(file)) + data["is_jinja_template"] = True + return SeedDataset.from_dict(data) + + +def load_seed_prompt_from_yaml_with_required_parameters( + template_path: str | Path, + required_parameters: list[str], + *, + error_message: str | None = None, +) -> SeedPrompt: + """ + Load a ``SeedPrompt`` and assert that its ``parameters`` field declares each required name. + + Args: + template_path: Path to the YAML file containing the prompt template. + required_parameters: Parameter names that must appear in ``SeedPrompt.parameters``. + error_message: Optional custom message used in the raised ``ValueError``. + + Returns: + The loaded ``SeedPrompt``. + + Raises: + ValueError: If the loaded prompt is missing any required parameter. + """ + sp = load_seed_from_yaml(template_path, cls=SeedPrompt) + if sp.parameters is None or not all(p in sp.parameters for p in required_parameters): + if error_message is None: + error_message = f"Template must have these parameters: {', '.join(required_parameters)}" + raise ValueError(f"{error_message}: '{sp}'") + return sp diff --git a/pyrit/scenario/core/atomic_attack.py b/pyrit/scenario/core/atomic_attack.py index 652e203489..acec06f8b0 100644 --- a/pyrit/scenario/core/atomic_attack.py +++ b/pyrit/scenario/core/atomic_attack.py @@ -76,8 +76,8 @@ def __init__( technique seeds. Preferred over the deprecated ``attack`` parameter. attack: **Deprecated.** Will be removed in v0.16.0. The configured attack strategy to execute. Use ``attack_technique`` instead. - seed_groups: List of seed attack groups. Each seed group must - have an objective set. + seed_groups: List of seed attack groups. Each must be a + ``SeedAttackGroup`` (which guarantees exactly one objective). adversarial_chat: Optional chat target for generating adversarial prompts or simulated conversations. objective_scorer: Optional scorer for evaluating simulated @@ -87,8 +87,9 @@ def __init__( execution method. Raises: - ValueError: If seed_groups list is empty or any seed group is missing an objective. - ValueError: If neither attack_technique nor attack is provided, or both are provided. + ValueError: If seed_groups list is empty, or if neither attack_technique + nor attack is provided, or both are provided. + TypeError: If any entry of ``seed_groups`` is not a ``SeedAttackGroup``. """ self.atomic_attack_name = atomic_attack_name self.display_group = display_group or atomic_attack_name @@ -112,9 +113,13 @@ def __init__( if not seed_groups: raise ValueError("seed_groups list cannot be empty") - # Validate each seed group to ensure they are in a valid state + # Validate that each seed_group is actually a SeedAttackGroup (which Pydantic + # already ensured holds the AtomicAttack invariant of "exactly one objective" + # at construction time). A plain SeedGroup or SeedAttackTechniqueGroup is not + # accepted here even though they share a base class. for sg in seed_groups: - sg.validate() + if not isinstance(sg, SeedAttackGroup): + raise TypeError(f"seed_groups must contain SeedAttackGroup instances; got {type(sg).__name__}.") self._seed_groups = seed_groups self._validate_unique_objective_hashes() diff --git a/tests/unit/executor/attack/core/test_attack_parameters.py b/tests/unit/executor/attack/core/test_attack_parameters.py index 18cf37341f..5e8ca975c2 100644 --- a/tests/unit/executor/attack/core/test_attack_parameters.py +++ b/tests/unit/executor/attack/core/test_attack_parameters.py @@ -309,12 +309,14 @@ async def test_excluded_class_rejects_excluded_field_overrides(self) -> None: ) -async def test_from_seed_group_async_raises_when_objective_is_none(): - """Test that from_seed_group_async raises ValueError when seed_group.objective is None.""" - seed_group = MagicMock(spec=SeedAttackGroup) - seed_group.validate = MagicMock() - seed_group.objective = None - seed_group.simulated_conversation = None - - with pytest.raises(ValueError, match="seed_group.objective is not initialized"): - await AttackParameters.from_seed_group_async(seed_group=seed_group) +async def test_from_seed_group_async_rejects_plain_seed_group(): + """Plain SeedGroup is rejected at the boundary because it doesn't enforce the + 'exactly one objective' invariant SeedAttackGroup does. A real SeedAttackGroup + can't reach this method with objective=None — Pydantic validation at construction + blocks that — so the runtime guard targets the more interesting failure mode: + callers passing the wrong subtype.""" + from pyrit.models import SeedGroup + + plain_group = SeedGroup(seeds=[SeedObjective(value="Test objective")]) + with pytest.raises(TypeError, match="seed_group must be a SeedAttackGroup"): + await AttackParameters.from_seed_group_async(seed_group=plain_group) # type: ignore[arg-type] diff --git a/tests/unit/models/test_import_boundary.py b/tests/unit/models/test_import_boundary.py index bc6342d8d8..b4c33795e2 100644 --- a/tests/unit/models/test_import_boundary.py +++ b/tests/unit/models/test_import_boundary.py @@ -2,17 +2,22 @@ # Licensed under the MIT license. """ -Enforce the ``pyrit.models`` import boundary. +Enforce the ``pyrit.models`` / ``pyrit.common`` import boundary. -``pyrit.models`` is the canonical data layer. Files in ``pyrit/models/`` -may import only from stdlib, ``pydantic``, ``pyrit.common.deprecation``, -and other ``pyrit.models.*`` submodules. +PyRIT uses a two-layer rule for its foundational packages: -This test uses a ratchet pattern: ``KNOWN_TOP_LEVEL_VIOLATIONS`` and -``KNOWN_LAZY_VIOLATIONS`` track imports that exist today and are expected to -disappear in a specific phase. The lists must shrink monotonically — if a known -violation is no longer in source, this test fails and the entry must be -removed. +* **Forward (models):** files in ``pyrit/models/`` may import only from stdlib, + ``pydantic``, all of ``pyrit.common`` (the whole prefix), and other + ``pyrit.models.*`` submodules. +* **Reverse guard (common):** files in ``pyrit/common/`` may import only from + stdlib, third-party libraries, and other ``pyrit.common.*`` submodules — never + any other ``pyrit.*`` package. This keeps ``pyrit.common`` a true foundation + layer and prevents an import cycle with ``pyrit.models``. + +Both directions use a ratchet pattern: the ``KNOWN_*_VIOLATIONS`` lists track +imports that exist today and are expected to disappear in a later phase. The +lists must shrink monotonically — if a known violation is no longer in source, +this test fails and the entry must be removed. See plan.md / ``.github/instructions/models.instructions.md`` for context. """ @@ -25,79 +30,67 @@ import pytest +import pyrit.common import pyrit.models if TYPE_CHECKING: from collections.abc import Iterable # noqa: F401 MODELS_PACKAGE = Path(pyrit.models.__file__).parent +COMMON_PACKAGE = Path(pyrit.common.__file__).parent EXCLUDE_FILES: frozenset[str] = frozenset() -# Always allowed at module level (in addition to pyrit.models.* self-imports). -ALLOWED_TOP_LEVEL: frozenset[str] = frozenset( - { - "pyrit.common.deprecation", - } -) +# Forward rule: pyrit.models may import these pyrit prefixes freely. +MODELS_ALLOWED_PREFIXES: tuple[str, ...] = ("pyrit.models", "pyrit.common") + +# Reverse guard: pyrit.common may import only itself within the pyrit namespace. +COMMON_ALLOWED_PREFIXES: tuple[str, ...] = ("pyrit.common",) -# Transitional known top-level violations. Each entry names the phase that -# clears it (documentation only — the test does not parse the tag). New -# violations not in this list fail the test; entries that no longer match +# Transitional known top-level violations for pyrit.models. Each entry names the +# phase that clears it (documentation only — the test does not parse the tag). +# New violations not in this list fail the test; entries that no longer match # source also fail. -KNOWN_TOP_LEVEL_VIOLATIONS: dict[str, dict[str, str]] = { - "pyrit.models.harm_definition": { - "pyrit.common.path": "phase-1", - }, - "pyrit.models.data_type_serializer": { - "pyrit.common.path": "phase-8", - }, - "pyrit.models.seeds.seed": { - "pyrit.common.utils": "seeds-followup", - "pyrit.common.yaml_loadable": "seeds-followup", - }, - "pyrit.models.seeds.seed_dataset": { - "pyrit.common": "seeds-followup", - "pyrit.common.utils": "seeds-followup", - "pyrit.common.yaml_loadable": "seeds-followup", - }, - "pyrit.models.seeds.seed_group": { - "pyrit.common.yaml_loadable": "seeds-followup", - }, - "pyrit.models.seeds.seed_objective": { - "pyrit.common.path": "seeds-followup", - }, - "pyrit.models.seeds.seed_prompt": { - "pyrit.common.path": "seeds-followup", - }, - "pyrit.models.seeds.seed_simulated_conversation": { - "pyrit.common.path": "seeds-followup", - }, -} +KNOWN_TOP_LEVEL_VIOLATIONS: dict[str, dict[str, str]] = {} -# Lazy / TYPE_CHECKING imports of cross-package modules. Same ratchet, tracked -# separately so the phase that removes the lazy workaround is explicit. +# Lazy / TYPE_CHECKING imports of cross-package modules from pyrit.models. Same +# ratchet, tracked separately so the phase that removes the lazy workaround is +# explicit. KNOWN_LAZY_VIOLATIONS: dict[str, dict[str, str]] = { "pyrit.models.data_type_serializer": { - "pyrit.memory": "phase-8", - "pyrit.common.path": "phase-8", + "pyrit.memory": "phase-9", }, "pyrit.models.scenario_result": { "pyrit.score.scorer_evaluation.scorer_metrics": "phase-7", "pyrit.score.scorer_evaluation.scorer_metrics_io": "phase-7", }, "pyrit.models.storage_io": { - "pyrit.auth": "phase-8", + "pyrit.auth": "phase-9", + }, +} + +# Reverse-guard violations: pyrit.common modules that still reach up into higher +# layers. These are slated to relocate; the ratchet forces them to shrink. +KNOWN_COMMON_VIOLATIONS: dict[str, dict[str, str]] = { + "pyrit.common.data_url_converter": { + "pyrit.models": "relocate", + }, + "pyrit.common.display_response": { + "pyrit.memory": "relocate", + "pyrit.models": "relocate", + }, + "pyrit.common.question_answer_helpers": { + "pyrit.models": "relocate", }, } -def _module_name_for(path: Path) -> str: - """Return the dotted module name for a file inside ``pyrit/models/``.""" - rel = path.relative_to(MODELS_PACKAGE).with_suffix("") +def _module_name_for(path: Path, *, package_root: Path, package_prefix: str) -> str: + """Return the dotted module name for a file inside ``package_root``.""" + rel = path.relative_to(package_root).with_suffix("") parts = [p for p in rel.parts if p != "__init__"] if not parts: - return "pyrit.models" - return "pyrit.models." + ".".join(parts) + return package_prefix + return package_prefix + "." + ".".join(parts) def _resolve_from_import(node: ast.ImportFrom, source_module: str) -> str: @@ -123,11 +116,17 @@ def _is_typecheck_test(test: ast.expr) -> bool: ) +def _is_allowed(mod: str, allowed_prefixes: tuple[str, ...]) -> bool: + """Return True iff ``mod`` is within one of the allowed self-prefixes.""" + return any(mod == prefix or mod.startswith(prefix + ".") for prefix in allowed_prefixes) + + class _ImportCollector(ast.NodeVisitor): - """Walk a module AST and bucket ``pyrit.*`` imports into top-level vs lazy.""" + """Walk a module AST and bucket disallowed ``pyrit.*`` imports into top-level vs lazy.""" - def __init__(self, source_module: str) -> None: + def __init__(self, source_module: str, *, allowed_prefixes: tuple[str, ...]) -> None: self.source_module = source_module + self.allowed_prefixes = allowed_prefixes self.top_level: set[str] = set() self.lazy: set[str] = set() self._in_lazy = False @@ -135,8 +134,7 @@ def __init__(self, source_module: str) -> None: def _record(self, mod: str) -> None: if not mod.startswith("pyrit."): return - # pyrit.models.* self-imports are always allowed (we are inside pyrit.models) - if mod == "pyrit.models" or mod.startswith("pyrit.models."): + if _is_allowed(mod, self.allowed_prefixes): return bucket = self.lazy if self._in_lazy else self.top_level bucket.add(mod) @@ -171,51 +169,77 @@ def visit_If(self, node: ast.If) -> None: self.generic_visit(node) -def _scan_files() -> list[Path]: - """Return all ``pyrit/models/**/*.py`` files in scope.""" - return sorted(p for p in MODELS_PACKAGE.rglob("*.py") if p.name not in EXCLUDE_FILES) +def _scan_files(package_root: Path) -> list[Path]: + """Return all ``*.py`` files under ``package_root`` in scope.""" + return sorted(p for p in package_root.rglob("*.py") if p.name not in EXCLUDE_FILES) -def _analyze(path: Path) -> tuple[str, set[str], set[str]]: - source_module = _module_name_for(path) +def _analyze( + path: Path, *, package_root: Path, package_prefix: str, allowed_prefixes: tuple[str, ...] +) -> tuple[str, set[str], set[str]]: + source_module = _module_name_for(path, package_root=package_root, package_prefix=package_prefix) tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) - collector = _ImportCollector(source_module) + collector = _ImportCollector(source_module, allowed_prefixes=allowed_prefixes) collector.visit(tree) return source_module, collector.top_level, collector.lazy -def _collect_actual_imports() -> tuple[dict[str, set[str]], dict[str, set[str]]]: +def _collect_actual_imports( + *, package_root: Path, package_prefix: str, allowed_prefixes: tuple[str, ...] +) -> tuple[dict[str, set[str]], dict[str, set[str]]]: top: dict[str, set[str]] = {} lazy: dict[str, set[str]] = {} - for path in _scan_files(): - source, top_imports, lazy_imports = _analyze(path) + for path in _scan_files(package_root): + source, top_imports, lazy_imports = _analyze( + path, + package_root=package_root, + package_prefix=package_prefix, + allowed_prefixes=allowed_prefixes, + ) top[source] = top_imports lazy[source] = lazy_imports return top, lazy +def _collect_models_imports() -> tuple[dict[str, set[str]], dict[str, set[str]]]: + return _collect_actual_imports( + package_root=MODELS_PACKAGE, + package_prefix="pyrit.models", + allowed_prefixes=MODELS_ALLOWED_PREFIXES, + ) + + +def _collect_common_imports() -> tuple[dict[str, set[str]], dict[str, set[str]]]: + return _collect_actual_imports( + package_root=COMMON_PACKAGE, + package_prefix="pyrit.common", + allowed_prefixes=COMMON_ALLOWED_PREFIXES, + ) + + def test_no_new_top_level_violations() -> None: """Module-level pyrit imports outside the allowlist must be listed in KNOWN_TOP_LEVEL_VIOLATIONS.""" - actual_top, _ = _collect_actual_imports() + actual_top, _ = _collect_models_imports() new_violations: list[str] = [] for source, imports in actual_top.items(): known = set(KNOWN_TOP_LEVEL_VIOLATIONS.get(source, {}).keys()) for imp in sorted(imports): - if imp in ALLOWED_TOP_LEVEL or imp in known: + if imp in known: continue new_violations.append(f"{source} -> {imp}") if new_violations: pytest.fail( "New top-level pyrit imports in pyrit.models (not allowed):\n " + "\n ".join(new_violations) - + "\n\nEither remove the import or, if it is transitional, add it to " + + "\n\npyrit.models may import stdlib, pydantic, pyrit.common.*, and pyrit.models.*. " + "Either remove the import or, if it is transitional, add it to " "KNOWN_TOP_LEVEL_VIOLATIONS in this file with a phase tag." ) def test_known_top_level_violations_still_apply() -> None: """Entries in KNOWN_TOP_LEVEL_VIOLATIONS that no longer exist in source must be removed.""" - actual_top, _ = _collect_actual_imports() + actual_top, _ = _collect_models_imports() stale: list[str] = [] for source, allowed in KNOWN_TOP_LEVEL_VIOLATIONS.items(): present = actual_top.get(source, set()) @@ -230,12 +254,12 @@ def test_known_top_level_violations_still_apply() -> None: def test_no_new_lazy_violations() -> None: """Lazy/TYPE_CHECKING pyrit imports outside the allowlist must be listed in KNOWN_LAZY_VIOLATIONS.""" - _, actual_lazy = _collect_actual_imports() + _, actual_lazy = _collect_models_imports() new_violations: list[str] = [] for source, imports in actual_lazy.items(): known = set(KNOWN_LAZY_VIOLATIONS.get(source, {}).keys()) for imp in sorted(imports): - if imp in ALLOWED_TOP_LEVEL or imp in known: + if imp in known: continue new_violations.append(f"{source} -> {imp}") if new_violations: @@ -249,7 +273,7 @@ def test_no_new_lazy_violations() -> None: def test_known_lazy_violations_still_apply() -> None: """Entries in KNOWN_LAZY_VIOLATIONS that no longer exist in source must be removed.""" - _, actual_lazy = _collect_actual_imports() + _, actual_lazy = _collect_models_imports() stale: list[str] = [] for source, allowed in KNOWN_LAZY_VIOLATIONS.items(): present = actual_lazy.get(source, set()) @@ -262,9 +286,45 @@ def test_known_lazy_violations_still_apply() -> None: ) +def test_no_new_common_violations() -> None: + """pyrit.common may not import any pyrit.* outside pyrit.common (reverse guard).""" + actual_top, actual_lazy = _collect_common_imports() + new_violations: list[str] = [] + for source in sorted({*actual_top, *actual_lazy}): + imports = actual_top.get(source, set()) | actual_lazy.get(source, set()) + known = set(KNOWN_COMMON_VIOLATIONS.get(source, {}).keys()) + for imp in sorted(imports): + if imp in known: + continue + new_violations.append(f"{source} -> {imp}") + if new_violations: + pytest.fail( + "pyrit.common modules importing outside pyrit.common (forbidden):\n " + + "\n ".join(new_violations) + + "\n\npyrit.common is the foundation layer and may import only stdlib, " + "third-party libraries, and pyrit.common.*. Either remove the import or, if it " + "is transitional, add it to KNOWN_COMMON_VIOLATIONS in this file with a tag." + ) + + +def test_known_common_violations_still_apply() -> None: + """Entries in KNOWN_COMMON_VIOLATIONS that no longer exist in source must be removed.""" + actual_top, actual_lazy = _collect_common_imports() + stale: list[str] = [] + for source, allowed in KNOWN_COMMON_VIOLATIONS.items(): + present = actual_top.get(source, set()) | actual_lazy.get(source, set()) + stale.extend(f"{source} -> {imp}" for imp in allowed if imp not in present) + if stale: + pytest.fail( + "KNOWN_COMMON_VIOLATIONS entries that no longer exist in source:\n " + + "\n ".join(stale) + + "\n\nThe allowlist must shrink monotonically. Remove these entries." + ) + + def test_scan_finds_expected_files() -> None: """Sanity check: the scanner picks up the known model files.""" - scanned = {p.name for p in _scan_files()} + scanned = {p.name for p in _scan_files(MODELS_PACKAGE)} # A non-exhaustive sample of files that must exist for this test to be meaningful. expected = { "attack_result.py", diff --git a/tests/unit/models/test_seed.py b/tests/unit/models/test_seed.py index 1dcebb0449..c2f26cf75d 100644 --- a/tests/unit/models/test_seed.py +++ b/tests/unit/models/test_seed.py @@ -16,12 +16,13 @@ from pyrit.models import ( Message, MessagePiece, + SeedAttackGroup, SeedDataset, SeedGroup, SeedObjective, SeedPrompt, ) -from pyrit.models.seeds import SeedSimulatedConversation +from pyrit.models.seeds import SeedSimulatedConversation, SimulatedTargetSystemPromptPaths @pytest.fixture @@ -962,7 +963,7 @@ def test_seed_dataset_dict_to_seed_prompt_all_base_params(): # SeedPrompt-specific fields "role": "assistant", "sequence": 5, - "parameters": {"param1": "val1"}, + "parameters": ["param1"], "seed_type": "prompt", } @@ -990,7 +991,7 @@ def test_seed_dataset_dict_to_seed_prompt_all_base_params(): # Verify SeedPrompt-specific fields assert seed.role == "assistant" assert seed.sequence == 5 - assert seed.parameters == {"param1": "val1"} + assert seed.parameters == ["param1"] def test_seed_dataset_dict_to_seed_objective_all_base_params(): @@ -1098,7 +1099,7 @@ def test_seed_dataset_uses_dataset_defaults_for_missing_params(): # These should use sensible defaults assert seed.role == "user" assert seed.sequence == 0 - assert seed.parameters == {} + assert seed.parameters == [] def test_next_message_single_turn_no_objective(): @@ -1110,9 +1111,96 @@ def test_next_message_single_turn_no_objective(): assert group.prepended_conversation is None assert group.next_message is not None assert len(group.next_message.message_pieces) == 1 + + +def test_seed_group_preserves_polymorphic_subclasses(): + """list[Seed] must preserve concrete subclasses: instances pass through and model_dump keeps subclass fields.""" + objective = SeedObjective(value="Reach the goal") + prompt = SeedPrompt(value="Hello", role="user", sequence=0) + group = SeedGroup(seeds=[objective, prompt]) + + # Instances are preserved, not coerced down to the base Seed. + assert isinstance(group.objective, SeedObjective) + assert any(isinstance(s, SeedPrompt) for s in group.seeds) + + # SerializeAsAny keeps subclass-specific fields on dump (SeedPrompt has role/parameters, SeedObjective does not). + dumped = group.model_dump() + by_value = {s["value"]: s for s in dumped["seeds"]} + assert by_value["Hello"]["role"] == "user" + assert "parameters" in by_value["Hello"] + assert "role" not in by_value["Reach the goal"] + + # Dict reconstruction uses the seed_type discriminator (the path used by YAML/dataset inputs). + rebuilt = SeedGroup( + seeds=[ + {"value": "Reach the goal", "seed_type": "objective"}, + {"value": "Hello", "role": "user", "sequence": 0, "seed_type": "prompt"}, + ] + ) + assert isinstance(rebuilt.objective, SeedObjective) + assert any(isinstance(s, SeedPrompt) for s in rebuilt.seeds) assert group.next_message.get_value() == "Hello" +def test_seed_group_round_trip_preserves_subclasses(): + """model_validate(model_dump()) must reconstruct the original subclass instances, not coerce to base Seed.""" + objective = SeedObjective(value="goal") + prompt = SeedPrompt(value="hi", role="user", sequence=0) + group = SeedGroup(seeds=[objective, prompt]) + + rt = SeedGroup.model_validate(group.model_dump()) + + assert [type(s).__name__ for s in rt.seeds] == [type(s).__name__ for s in group.seeds] + assert isinstance(rt.objective, SeedObjective) + assert rt.objective.value == "goal" + + +def test_seed_attack_group_round_trip_preserves_subclasses(): + """The original blocking bug: SeedAttackGroup(model_validate(model_dump())) must work.""" + sag = SeedAttackGroup( + seeds=[ + SeedObjective(value="objective"), + SeedPrompt(value="hi", data_type="text", role="user", sequence=0), + ] + ) + rt = SeedAttackGroup.model_validate(sag.model_dump()) + assert isinstance(rt.objective, SeedObjective) + assert rt.objective.value == "objective" + assert any(isinstance(s, SeedPrompt) for s in rt.seeds) + + +def test_seed_simulated_conversation_round_trip(): + """SeedSimulatedConversation's computed ``value`` round-trips through both python and json modes.""" + sc = SeedSimulatedConversation( + adversarial_chat_system_prompt_path=SimulatedTargetSystemPromptPaths.COMPLIANT.value, + num_turns=4, + ) + rt_py = SeedSimulatedConversation.model_validate(sc.model_dump()) + rt_json = SeedSimulatedConversation.model_validate(sc.model_dump(mode="json")) + assert rt_py.value == sc.value + assert rt_json.value == sc.value + assert rt_py.num_turns == 4 + + +def test_seed_objective_rejects_non_text_data_type(): + """SeedObjective locks ``data_type`` to ``"text"``.""" + from pydantic import ValidationError + + with pytest.raises(ValidationError): + SeedObjective(value="goal", data_type="image_path") + + +def test_seed_simulated_conversation_rejects_non_text_data_type(): + """SeedSimulatedConversation locks ``data_type`` to ``"text"``.""" + from pydantic import ValidationError + + with pytest.raises(ValidationError): + SeedSimulatedConversation( + adversarial_chat_system_prompt_path=SimulatedTargetSystemPromptPaths.COMPLIANT.value, + data_type="image_path", + ) + + def test_next_message_single_turn_with_objective(): """Test next_message property for a single-turn SeedGroup with an objective.""" prompt = SeedPrompt(value="Hello", data_type="text", sequence=0, role="user") diff --git a/tests/unit/models/test_seed_attack_technique_group.py b/tests/unit/models/test_seed_attack_technique_group.py index 05c21a62e3..8c01a166d2 100644 --- a/tests/unit/models/test_seed_attack_technique_group.py +++ b/tests/unit/models/test_seed_attack_technique_group.py @@ -148,7 +148,7 @@ def test_validate_all_general_strategy_passes(self): ] ) # Should not raise - group.validate() + group._check_invariants() def test_error_message_includes_non_general_types(self): """Test that error message lists the types of non-general seeds.""" diff --git a/tests/unit/models/test_seed_group.py b/tests/unit/models/test_seed_group.py index 4a8e43a2c1..8c7e08a811 100644 --- a/tests/unit/models/test_seed_group.py +++ b/tests/unit/models/test_seed_group.py @@ -610,7 +610,7 @@ def test_merged_group_is_valid_seed_attack_group(self): merged = base.with_technique(technique=technique) assert isinstance(merged, SeedAttackGroup) - merged.validate() # should not raise + merged._check_invariants() # should not raise def test_raises_when_technique_has_simulated_conversation_and_prompts_overlap(self): """Merging a technique with SeedSimulatedConversation into a group with overlapping prompts raises.""" diff --git a/tests/unit/models/test_seed_simulated_conversation.py b/tests/unit/models/test_seed_simulated_conversation.py index 8afe06f914..c8239a9caa 100644 --- a/tests/unit/models/test_seed_simulated_conversation.py +++ b/tests/unit/models/test_seed_simulated_conversation.py @@ -174,11 +174,11 @@ def test_init_next_message_system_prompt_path_set(self, tmp_path): assert conv.next_message_system_prompt_path == next_msg_path -class TestSeedSimulatedConversationFromDict: - """Tests for SeedSimulatedConversation.from_dict method.""" +class TestSeedSimulatedConversationFromMapping: + """Tests for constructing SeedSimulatedConversation from a dict via ``model_validate``.""" def test_from_dict_with_paths(self, tmp_path): - """Test from_dict with path values.""" + """Test construction from a dict with path values.""" adv_path = tmp_path / "adversarial.yaml" adv_path.write_text("value: test\ndata_type: text") @@ -186,13 +186,13 @@ def test_from_dict_with_paths(self, tmp_path): "num_turns": 5, "adversarial_chat_system_prompt_path": str(adv_path), } - conv = SeedSimulatedConversation.from_dict(data) + conv = SeedSimulatedConversation.model_validate(data) assert conv.num_turns == 5 assert conv.adversarial_chat_system_prompt_path == adv_path def test_from_dict_without_simulated_target_path(self, tmp_path): - """Test from_dict without simulated_target_system_prompt_path uses compliant default.""" + """Test construction without simulated_target_system_prompt_path uses compliant default.""" adv_path = tmp_path / "adversarial.yaml" adv_path.write_text("value: test\ndata_type: text") @@ -200,29 +200,29 @@ def test_from_dict_without_simulated_target_path(self, tmp_path): "num_turns": 3, "adversarial_chat_system_prompt_path": str(adv_path), } - conv = SeedSimulatedConversation.from_dict(data) + conv = SeedSimulatedConversation.model_validate(data) # Default simulated_target_system_prompt_path is the compliant prompt assert conv.simulated_target_system_prompt_path == SimulatedTargetSystemPromptPaths.COMPLIANT.value def test_from_dict_default_num_turns(self, tmp_path): - """Test from_dict uses default num_turns when not specified.""" + """Test that num_turns defaults to 3 when not specified.""" adv_path = tmp_path / "adversarial.yaml" adv_path.write_text("value: test\ndata_type: text") data = { "adversarial_chat_system_prompt_path": str(adv_path), } - conv = SeedSimulatedConversation.from_dict(data) + conv = SeedSimulatedConversation.model_validate(data) assert conv.num_turns == 3 def test_from_dict_missing_adversarial_path_raises_error(self): - """Test that from_dict raises error when adversarial path is missing.""" + """Test that construction raises when adversarial path is missing (required field).""" data = {"num_turns": 3} - with pytest.raises(ValueError, match="adversarial_chat_system_prompt_path is required"): - SeedSimulatedConversation.from_dict(data) + with pytest.raises(ValueError, match="adversarial_chat_system_prompt_path"): + SeedSimulatedConversation.model_validate(data) class TestSeedSimulatedConversationGetIdentifier: diff --git a/tests/unit/models/test_yaml_seed_loader.py b/tests/unit/models/test_yaml_seed_loader.py new file mode 100644 index 0000000000..904ff3c2af --- /dev/null +++ b/tests/unit/models/test_yaml_seed_loader.py @@ -0,0 +1,219 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from pyrit.models.seeds import ( + SeedDataset, + SeedObjective, + SeedPrompt, + load_seed_dataset_from_yaml, + load_seed_from_yaml, + load_seed_prompt_from_yaml_with_required_parameters, +) + + +def _write(tmp_path, name, content): + p = tmp_path / name + p.write_text(content, encoding="utf-8") + return p + + +def test_load_seed_from_yaml_returns_typed_seed(tmp_path): + yaml_file = _write(tmp_path, "p.yaml", "value: hello world\ndata_type: text\n") + + loaded = load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + assert isinstance(loaded, SeedPrompt) + assert loaded.value == "hello world" + # Loader is the trust boundary — it sets the jinja-template marker exactly once. + assert loaded.is_jinja_template is True + + +def test_load_seed_from_yaml_supports_objective(tmp_path): + yaml_file = _write(tmp_path, "o.yaml", "value: stop the attacker\n") + + loaded = load_seed_from_yaml(yaml_file, cls=SeedObjective) + + assert isinstance(loaded, SeedObjective) + assert loaded.value == "stop the attacker" + assert loaded.is_jinja_template is True + + +def test_load_seed_from_yaml_overrides_in_file_value(tmp_path): + # An ``is_jinja_template: false`` in the YAML must not let the file claim it is untrusted — + # the loader's trust claim wins. + yaml_file = _write(tmp_path, "p.yaml", "value: x\nis_jinja_template: false\n") + + loaded = load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + assert loaded.is_jinja_template is True + + +def test_load_seed_from_yaml_accepts_string_path(tmp_path): + yaml_file = _write(tmp_path, "p.yaml", "value: x\n") + + loaded = load_seed_from_yaml(str(yaml_file), cls=SeedPrompt) + + assert loaded.value == "x" + + +def test_load_seed_from_yaml_missing_file_raises(tmp_path): + with pytest.raises(FileNotFoundError): + load_seed_from_yaml(tmp_path / "missing.yaml", cls=SeedPrompt) + + +def test_load_seed_from_yaml_empty_file_raises(tmp_path): + yaml_file = _write(tmp_path, "empty.yaml", "") + + with pytest.raises(ValueError, match="is empty"): + load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + +def test_load_seed_from_yaml_top_level_list_raises(tmp_path): + yaml_file = _write(tmp_path, "list.yaml", "- a\n- b\n") + + with pytest.raises(ValueError, match="must contain a mapping"): + load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + +def test_load_seed_from_yaml_invalid_yaml_raises(tmp_path): + yaml_file = _write(tmp_path, "bad.yaml", "value: [unterminated\n") + + with pytest.raises(ValueError, match="Invalid YAML file"): + load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + +def test_load_seed_dataset_from_yaml_marks_seeds_as_trusted(tmp_path): + yaml_file = _write( + tmp_path, + "ds.yaml", + "name: tiny\nseeds:\n - value: a\n - value: b\n", + ) + + dataset = load_seed_dataset_from_yaml(yaml_file) + + assert isinstance(dataset, SeedDataset) + assert [s.value for s in dataset.seeds] == ["a", "b"] + # Trust marker propagates from the loader to each nested seed. + assert all(s.is_jinja_template for s in dataset.seeds) + + +def test_load_seed_dataset_from_yaml_empty_raises(tmp_path): + yaml_file = _write(tmp_path, "empty.yaml", "") + + with pytest.raises(ValueError, match="is empty"): + load_seed_dataset_from_yaml(yaml_file) + + +def test_load_seed_prompt_from_yaml_with_required_parameters_succeeds(tmp_path): + yaml_file = _write( + tmp_path, + "t.yaml", + "value: hello {{ name }}\nparameters:\n - name\n", + ) + + loaded = load_seed_prompt_from_yaml_with_required_parameters(yaml_file, ["name"]) + + assert isinstance(loaded, SeedPrompt) + assert loaded.parameters == ["name"] + + +def test_load_seed_prompt_from_yaml_with_required_parameters_missing_raises(tmp_path): + yaml_file = _write( + tmp_path, + "t.yaml", + "value: hello {{ name }}\nparameters:\n - name\n", + ) + + with pytest.raises(ValueError, match="Template must have these parameters: name, age"): + load_seed_prompt_from_yaml_with_required_parameters(yaml_file, ["name", "age"]) + + +def test_load_seed_prompt_from_yaml_with_required_parameters_custom_error(tmp_path): + yaml_file = _write(tmp_path, "t.yaml", "value: no params here\n") + + with pytest.raises(ValueError, match="bespoke"): + load_seed_prompt_from_yaml_with_required_parameters(yaml_file, ["foo"], error_message="bespoke") + + +def test_classmethod_shims_delegate_to_loader(tmp_path): + # Verifies the public classmethod surface (Seed.from_yaml_file et al.) + # produces the same result as the loader functions — i.e. the shims are honest. + yaml_file = _write(tmp_path, "p.yaml", "value: x\n") + + via_classmethod = SeedPrompt.from_yaml_file(yaml_file) + via_function = load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + assert via_classmethod.value == via_function.value + assert via_classmethod.is_jinja_template == via_function.is_jinja_template is True + + +# ----- Scalar-to-list canonicalization (loader-side YAML accommodation) ----- + + +def test_load_seed_from_yaml_canonicalizes_scalar_authors_to_list(tmp_path): + yaml_file = _write(tmp_path, "p.yaml", "value: hi\nauthors: Jane Doe\n") + + loaded = load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + assert loaded.authors == ["Jane Doe"] + + +def test_load_seed_from_yaml_canonicalizes_scalar_parameters(tmp_path): + yaml_file = _write( + tmp_path, + "p.yaml", + "value: hello {{ name }}\nparameters: name\n", + ) + + loaded = load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + assert loaded.parameters == ["name"] + + +def test_load_seed_from_yaml_passes_through_list_authors_unchanged(tmp_path): + yaml_file = _write( + tmp_path, + "p.yaml", + "value: hi\nauthors:\n - Alice\n - Bob\n", + ) + + loaded = load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + assert loaded.authors == ["Alice", "Bob"] + + +def test_load_seed_dataset_from_yaml_canonicalizes_top_level_and_nested(tmp_path): + yaml_file = _write( + tmp_path, + "ds.yaml", + ( + "name: tiny\n" + "authors: Top Author\n" + "harm_categories: harm-1\n" + "seeds:\n" + " - value: a\n" + " authors: Seed Author\n" + " groups: g1\n" + " - value: b\n" + ), + ) + + dataset = load_seed_dataset_from_yaml(yaml_file) + + # Top-level scalars wrapped; dataset-level harm_categories propagates to seeds. + assert dataset.harm_categories == ["harm-1"] + # Per-seed scalar authors is wrapped before dataset-level merge with ["Top Author"]. + assert sorted(dataset.seeds[0].authors or []) == sorted(["Top Author", "Seed Author"]) + assert dataset.seeds[0].groups == ["g1"] + # Pure inheritance from dataset defaults still works for the second seed. + assert dataset.seeds[1].authors == ["Top Author"] + + +def test_model_now_rejects_programmatic_scalar_string(): + """The wrap-scalar accommodation has moved to the loader; the model is strict.""" + import pydantic + + with pytest.raises(pydantic.ValidationError): + SeedPrompt(value="hi", authors="not a list") # type: ignore[arg-type] diff --git a/tests/unit/scenario/airt/test_jailbreak.py b/tests/unit/scenario/airt/test_jailbreak.py index 307981f485..db07b40c0f 100644 --- a/tests/unit/scenario/airt/test_jailbreak.py +++ b/tests/unit/scenario/airt/test_jailbreak.py @@ -13,7 +13,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.executor.attack.single_turn.role_play import RolePlayAttack from pyrit.executor.attack.single_turn.skeleton_key import SkeletonKeyAttack -from pyrit.models import ComponentIdentifier, SeedGroup, SeedObjective +from pyrit.models import ComponentIdentifier, SeedAttackGroup, SeedObjective from pyrit.prompt_target import PromptTarget from pyrit.scenario.core import BaselineAttackPolicy from pyrit.scenario.scenarios.airt.jailbreak import Jailbreak, JailbreakStrategy @@ -44,10 +44,10 @@ def mock_scenario_result_id() -> str: @pytest.fixture -def mock_memory_seed_groups() -> list[SeedGroup]: +def mock_memory_seed_groups() -> list[SeedAttackGroup]: """Create mock seed groups that _get_default_seed_groups() would return.""" return [ - SeedGroup(seeds=[SeedObjective(value=prompt)]) + SeedAttackGroup(seeds=[SeedObjective(value=prompt)]) for prompt in [ "sample objective 1", "sample objective 2", @@ -396,7 +396,7 @@ async def test_initialize_async_with_max_concurrency( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseInverterScorer, - mock_memory_seed_groups: list[SeedGroup], + mock_memory_seed_groups: list[SeedAttackGroup], ) -> None: """Test initialization with custom max_concurrency.""" with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): @@ -409,7 +409,7 @@ async def test_initialize_async_with_memory_labels( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseInverterScorer, - mock_memory_seed_groups: list[SeedGroup], + mock_memory_seed_groups: list[SeedAttackGroup], ) -> None: """Test initialization with memory labels.""" memory_labels = {"type": "jailbreak", "category": "scenario"} @@ -445,7 +445,7 @@ def test_scenario_default_dataset(self) -> None: assert Jailbreak.required_datasets() == ["airt_harms"] async def test_no_target_duplication_async( - self, *, mock_objective_target: PromptTarget, mock_memory_seed_groups: list[SeedGroup] + self, *, mock_objective_target: PromptTarget, mock_memory_seed_groups: list[SeedAttackGroup] ) -> None: """Test that all three targets (adversarial, object, scorer) are distinct.""" with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): @@ -489,7 +489,7 @@ async def test_roleplay_attacks_share_adversarial_target( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseInverterScorer, - mock_memory_seed_groups: list[SeedGroup], + mock_memory_seed_groups: list[SeedAttackGroup], roleplay_jailbreak_strategy: JailbreakStrategy, ) -> None: """Test that multiple role-play attacks share the same adversarial target instance.""" diff --git a/tests/unit/scenario/airt/test_scam.py b/tests/unit/scenario/airt/test_scam.py index 49befdbd45..3f2014578b 100644 --- a/tests/unit/scenario/airt/test_scam.py +++ b/tests/unit/scenario/airt/test_scam.py @@ -15,7 +15,7 @@ RolePlayAttack, ) from pyrit.executor.attack.core.attack_config import AttackScoringConfig -from pyrit.models import ComponentIdentifier, SeedAttackGroup, SeedDataset, SeedGroup, SeedObjective +from pyrit.models import ComponentIdentifier, SeedAttackGroup, SeedDataset, SeedObjective from pyrit.prompt_target import OpenAIChatTarget, PromptTarget from pyrit.scenario import DatasetConfiguration from pyrit.scenario.scenarios.airt.scam import Scam, ScamStrategy @@ -42,9 +42,9 @@ def _mock_target_id(name: str = "MockTarget") -> ComponentIdentifier: @pytest.fixture -def mock_memory_seed_groups() -> list[SeedGroup]: +def mock_memory_seed_groups() -> list[SeedAttackGroup]: """Create mock seed groups that _get_default_seed_groups() would return.""" - return [SeedGroup(seeds=[SeedObjective(value=prompt)]) for prompt in SEED_PROMPT_LIST] + return [SeedAttackGroup(seeds=[SeedObjective(value=prompt)]) for prompt in SEED_PROMPT_LIST] @pytest.fixture @@ -56,7 +56,7 @@ def mock_memory_seeds(): @pytest.fixture def mock_dataset_config(mock_memory_seed_groups): """Create a mock dataset config that returns the seed groups.""" - seed_attack_groups = [SeedAttackGroup(seeds=list(sg.seeds)) for sg in mock_memory_seed_groups] + seed_attack_groups = list(mock_memory_seed_groups) mock_config = MagicMock(spec=DatasetConfiguration) mock_config.get_all_seed_attack_groups.return_value = seed_attack_groups mock_config.get_default_dataset_names.return_value = ["airt_scam"] @@ -127,7 +127,7 @@ def test_init_with_default_objectives( self, *, mock_objective_scorer: TrueFalseCompositeScorer, - mock_memory_seed_groups: list[SeedGroup], + mock_memory_seed_groups: list[SeedAttackGroup], ) -> None: with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): scenario = Scam(objective_scorer=mock_objective_scorer) @@ -141,7 +141,7 @@ def test_init_with_default_scorer(self, mock_memory_seed_groups) -> None: scenario = Scam() assert scenario._objective_scorer_identifier - def test_init_with_custom_scorer(self, *, mock_memory_seed_groups: list[SeedGroup]) -> None: + def test_init_with_custom_scorer(self, *, mock_memory_seed_groups: list[SeedAttackGroup]) -> None: """Test initialization with custom scorer.""" scorer = MagicMock(spec=TrueFalseCompositeScorer) @@ -150,7 +150,7 @@ def test_init_with_custom_scorer(self, *, mock_memory_seed_groups: list[SeedGrou assert isinstance(scenario._scorer_config, AttackScoringConfig) def test_init_default_adversarial_chat( - self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: list[SeedGroup] + self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: list[SeedAttackGroup] ) -> None: with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): scenario = Scam(objective_scorer=mock_objective_scorer) @@ -159,7 +159,7 @@ def test_init_default_adversarial_chat( assert scenario._adversarial_chat._temperature == 1.2 def test_init_with_adversarial_chat( - self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: list[SeedGroup] + self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: list[SeedAttackGroup] ) -> None: adversarial_chat = MagicMock(OpenAIChatTarget) adversarial_chat.get_identifier.return_value = _mock_target_id("CustomAdversary") @@ -340,7 +340,7 @@ async def test_initialize_async_with_max_concurrency( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseCompositeScorer, - mock_memory_seed_groups: list[SeedGroup], + mock_memory_seed_groups: list[SeedAttackGroup], mock_dataset_config, ) -> None: """Test initialization with custom max_concurrency.""" @@ -356,7 +356,7 @@ async def test_initialize_async_with_memory_labels( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseCompositeScorer, - mock_memory_seed_groups: list[SeedGroup], + mock_memory_seed_groups: list[SeedAttackGroup], mock_dataset_config, ) -> None: """Test initialization with memory labels.""" @@ -389,7 +389,11 @@ def test_scenario_version_is_set( assert scenario.VERSION == 1 async def test_no_target_duplication_async( - self, *, mock_objective_target: PromptTarget, mock_memory_seed_groups: list[SeedGroup], mock_dataset_config + self, + *, + mock_objective_target: PromptTarget, + mock_memory_seed_groups: list[SeedAttackGroup], + mock_dataset_config, ) -> None: """Test that all three targets (adversarial, object, scorer) are distinct.""" with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups):