From 73dd89f328760d696b79386ab0694a7f73b1efca Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 2 Jun 2026 13:31:25 -0700 Subject: [PATCH 01/11] Convert pyrit.models.seeds to Pydantic v2 with a clean import boundary Convert all 8 seed classes (Seed, SeedPrompt, SeedObjective, SeedSimulatedConversation, SeedGroup, SeedAttackGroup, SeedAttackTechniqueGroup, SeedDataset) from dataclasses/plain classes to Pydantic v2 BaseModel. Establish a two-layer import rule between pyrit.models and pyrit.common, enforced by the import-boundary test. Add str->list coercion (shared coerce_str_to_list helper) so YAML seed files may specify list fields (harm_categories/authors/groups/parameters) as bare scalars, preserving the old unvalidated behavior. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/instructions/models.instructions.md | 31 ++- pyrit/models/seeds/seed.py | 70 +++-- pyrit/models/seeds/seed_attack_group.py | 21 +- .../seeds/seed_attack_technique_group.py | 37 +-- pyrit/models/seeds/seed_dataset.py | 257 ++++++++---------- pyrit/models/seeds/seed_group.py | 94 ++++--- pyrit/models/seeds/seed_objective.py | 15 +- pyrit/models/seeds/seed_prompt.py | 37 ++- .../seeds/seed_simulated_conversation.py | 87 +++--- tests/unit/models/test_import_boundary.py | 219 +++++++++------ tests/unit/models/test_seed.py | 34 ++- 11 files changed, 480 insertions(+), 422 deletions(-) diff --git a/.github/instructions/models.instructions.md b/.github/instructions/models.instructions.md index 8dd5e3255f..4a9e32baa0 100644 --- a/.github/instructions/models.instructions.md +++ b/.github/instructions/models.instructions.md @@ -6,19 +6,30 @@ applyTo: "pyrit/models/**" ## Import Boundary -`pyrit.models` is the canonical data layer. Files in `pyrit/models/` may -import only from: +PyRIT enforces a two-layer rule for its foundational packages. `pyrit.common` +is the foundation layer and `pyrit.models` is the canonical data layer that sits +directly on top of it. + +**Forward (models).** Files in `pyrit/models/` may import only from: - the standard library - `pydantic` -- `pyrit.common.deprecation` +- any `pyrit.common.*` submodule (the whole prefix) - other `pyrit.models.*` submodules -If a helper needs another `pyrit.*` package, it does not belong on a model — -put it in that package as a free function or static helper. +If a helper needs another `pyrit.*` package (e.g. `pyrit.memory`, +`pyrit.score`), it does not belong on a model — put it in that package as a free +function or static helper. + +**Reverse guard (common).** Files in `pyrit/common/` may import only from the +standard library, third-party libraries, and other `pyrit.common.*` submodules. +`pyrit.common` may never import any other `pyrit.*` package — this is what keeps +it a true foundation and prevents an import cycle with `pyrit.models`. If a +`pyrit.common` helper needs `pyrit.models` (or anything higher), it belongs in +that higher layer, not in `common`. -The CI test `tests/unit/models/test_import_boundary.py` enforces this using an -allowlist of known transitional violations, each tagged with the phase that -removes it. The list must shrink monotonically: removing an import from source -without also removing its allowlist entry fails the test, and adding a new -unlisted import also fails the test. +The CI test `tests/unit/models/test_import_boundary.py` enforces both directions +using allowlists of known transitional violations, each tagged with the phase +that removes it. The lists must shrink monotonically: removing an import from +source without also removing its allowlist entry fails the test, and adding a +new unlisted import also fails the test. diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index 2f0d045954..b1dc517ac7 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -9,33 +9,48 @@ from __future__ import annotations -import abc import logging import re import uuid -from dataclasses import dataclass, field from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union import yaml from jinja2 import StrictUndefined, Undefined from jinja2.sandbox import SandboxedEnvironment +from pydantic import BaseModel, ConfigDict, Field, field_validator from pyrit.common.utils import verify_and_resolve_path -from pyrit.common.yaml_loadable import YamlLoadable +from pyrit.models.literals import PromptDataType # noqa: TC001 (runtime-required by Pydantic field annotations) if TYPE_CHECKING: - from collections.abc import Iterator, Sequence + from collections.abc import Iterator from pathlib import Path - from pyrit.models.literals import PromptDataType - logger = logging.getLogger(__name__) # TypeVar for generic return type in class methods T = TypeVar("T", bound="Seed") +def coerce_str_to_list(value: Any) -> Any: + """ + Coerce a bare string into a single-element list, leaving other values unchanged. + + YAML seed files commonly specify list-typed fields as a single scalar (e.g. ``authors: Jane Doe``) + rather than a list. This wraps such a value so it satisfies a ``list[str]`` field type. + + Args: + value: The raw field value provided during validation. + + Returns: + The value wrapped in a list if it was a bare string, otherwise unchanged. + """ + if isinstance(value, str): + return [value] + return value + + class PartialUndefined(Undefined): """Jinja undefined value that preserves unresolved placeholders as text.""" @@ -81,10 +96,11 @@ def __bool__(self) -> bool: return True # Ensures it doesn't evaluate to False -@dataclass -class Seed(YamlLoadable): +class Seed(BaseModel): """Represents seed data with various attributes and metadata.""" + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + # The actual prompt value, which can be a string or a file path value: str @@ -92,7 +108,7 @@ class Seed(YamlLoadable): value_sha256: Optional[str] = None # Unique identifier for the prompt - id: Optional[uuid.UUID] = field(default_factory=lambda: uuid.uuid4()) + id: Optional[uuid.UUID] = Field(default_factory=uuid.uuid4) # Name of the prompt name: Optional[str] = None @@ -101,28 +117,28 @@ class Seed(YamlLoadable): dataset_name: Optional[str] = None # Categories of harm associated with this prompt - harm_categories: Optional[Sequence[str]] = field(default_factory=list) + harm_categories: Optional[list[str]] = Field(default_factory=list) # Description of the prompt description: Optional[str] = None # Authors of the prompt - authors: Optional[Sequence[str]] = field(default_factory=list) + authors: Optional[list[str]] = Field(default_factory=list) # Groups affiliated with the prompt - groups: Optional[Sequence[str]] = field(default_factory=list) + groups: Optional[list[str]] = Field(default_factory=list) # Source of the prompt source: Optional[str] = None # Date when the prompt was added to the dataset - date_added: Optional[datetime] = field(default_factory=lambda: datetime.now(tz=timezone.utc)) + date_added: Optional[datetime] = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) # User who added the prompt to the dataset added_by: Optional[str] = None # Arbitrary metadata that can be attached to the prompt - metadata: Optional[dict[str, Union[str, int]]] = field(default_factory=dict) + metadata: Optional[dict[str, Any]] = Field(default_factory=dict) # Unique identifier for the prompt group prompt_group_id: Optional[uuid.UUID] = None @@ -138,15 +154,24 @@ class Seed(YamlLoadable): # to prevent template injection. Trusted sources (YAML files) set this to True automatically. is_jinja_template: bool = False - @property - def data_type(self) -> PromptDataType: + # The type of data this seed represents (e.g., text, image_path, audio_path, video_path). + # SeedPrompt overrides the default to None and infers it from the value; other seed types + # keep "text". + data_type: PromptDataType = "text" + + @field_validator("harm_categories", "authors", "groups", mode="before") + @classmethod + def _coerce_str_to_list(cls, value: Any) -> Any: """ - Return the data type for this seed. + Coerce a bare string into a single-element list for these list-typed fields. + + Args: + value: The raw field value provided during validation. - Base implementation returns 'text'. SeedPrompt overrides this - to support multiple data types (image_path, audio_path, etc.). + Returns: + The value wrapped in a list if it was a bare string, otherwise unchanged. """ - return "text" + return coerce_str_to_list(value) def render_template_value(self, **kwargs: Any) -> str: """ @@ -271,7 +296,6 @@ def from_yaml_file(cls: type[T], file: Union[str, Path]) -> T: return cls(**yaml_data) @classmethod - @abc.abstractmethod def from_yaml_with_required_parameters( cls, template_path: Union[str, Path], @@ -281,6 +305,9 @@ def from_yaml_with_required_parameters( """ Load a Seed from a YAML file and validate that it contains specific parameters. + The base implementation simply loads the file; subclasses that support parameters + (e.g. SeedPrompt) override this to enforce ``required_parameters``. + Args: template_path: Path to the YAML file containing the template. required_parameters: List of parameter names that must exist in the template. @@ -290,3 +317,4 @@ def from_yaml_with_required_parameters( Seed: The loaded and validated seed of the specific subclass type. """ + return cls.from_yaml_file(template_path) diff --git a/pyrit/models/seeds/seed_attack_group.py b/pyrit/models/seeds/seed_attack_group.py index 3c02c96d9c..e99a867dbf 100644 --- a/pyrit/models/seeds/seed_attack_group.py +++ b/pyrit/models/seeds/seed_attack_group.py @@ -9,7 +9,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING from pyrit.models.seeds.seed_group import SeedGroup from pyrit.models.seeds.seed_objective import SeedObjective @@ -17,7 +17,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_attack_technique_group import SeedAttackTechniqueGroup @@ -32,24 +31,6 @@ class SeedAttackGroup(SeedGroup): next_message, etc.) is inherited from SeedGroup. """ - def __init__( - self, - *, - seeds: Sequence[Union[Seed, dict[str, Any]]], - ) -> None: - """ - Initialize a SeedAttackGroup. - - Args: - seeds: Sequence of seeds. Must include exactly one SeedObjective. - - Raises: - ValueError: If seeds is empty. - ValueError: If exactly one objective is not provided. - - """ - super().__init__(seeds=seeds) - def validate(self) -> None: """ Validate the seed attack group state. diff --git a/pyrit/models/seeds/seed_attack_technique_group.py b/pyrit/models/seeds/seed_attack_technique_group.py index ee8eb0475d..8509fe1022 100644 --- a/pyrit/models/seeds/seed_attack_technique_group.py +++ b/pyrit/models/seeds/seed_attack_technique_group.py @@ -11,16 +11,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Union +from typing import Optional from pyrit.models.seeds.seed_group import SeedGroup from pyrit.models.seeds.seed_objective import SeedObjective -if TYPE_CHECKING: - from collections.abc import Sequence - - from pyrit.models.seeds.seed import Seed - class SeedAttackTechniqueGroup(SeedGroup): """ @@ -33,33 +28,9 @@ class SeedAttackTechniqueGroup(SeedGroup): next_message, etc.) is inherited from SeedGroup. """ - def __init__( - self, - *, - seeds: Sequence[Union[Seed, dict[str, Any]]], - insertion_index: int | None = None, - ) -> None: - """ - Initialize a SeedAttackTechniqueGroup. - - Args: - seeds: Sequence of seeds. All seeds must have is_general_technique=True. - insertion_index: Where to insert technique seeds when merging into a - SeedAttackGroup via ``with_technique()``. ``None`` (default) appends - at the end. An integer inserts before that position in the target - group's seed list. - - Raises: - ValueError: If seeds is empty. - ValueError: If any seed does not have is_general_technique=True. - """ - self._insertion_index = insertion_index - super().__init__(seeds=seeds) - - @property - def insertion_index(self) -> int | None: - """Where to insert technique seeds when merging, or None to append at end.""" - return self._insertion_index + # Where to insert technique seeds when merging into a SeedAttackGroup via ``with_technique()``. + # ``None`` (default) appends at the end; an integer inserts before that position. + insertion_index: Optional[int] = None def validate(self) -> None: """ diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index ac49ce9082..61b0f56581 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -15,10 +15,15 @@ from typing import TYPE_CHECKING, Any, Optional, Union import yaml +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator from pyrit.common import utils from pyrit.common.utils import verify_and_resolve_path -from pyrit.common.yaml_loadable import YamlLoadable +from pyrit.models.literals import SeedType # noqa: TC001 (runtime-required by Pydantic field annotations) +from pyrit.models.seeds.seed import ( # noqa: TC001 (runtime-required by Pydantic field annotations) + Seed, + coerce_str_to_list, +) from pyrit.models.seeds.seed_attack_group import SeedAttackGroup from pyrit.models.seeds.seed_group import SeedGroup from pyrit.models.seeds.seed_objective import SeedObjective @@ -31,187 +36,151 @@ from pydantic.types import PositiveInt - from pyrit.models.literals import PromptDataType, SeedType - from pyrit.models.seeds.seed import Seed - logger = logging.getLogger(__name__) -class SeedDataset(YamlLoadable): +class SeedDataset(BaseModel): """ SeedDataset manages seed prompts plus optional top-level defaults. Prompts are stored as a Sequence[Seed], so references to prompt properties are straightforward (e.g. ds.seeds[0].value). """ - data_type: Optional[str] - name: Optional[str] - dataset_name: Optional[str] - harm_categories: Optional[Sequence[str]] - description: Optional[str] - authors: Optional[Sequence[str]] - groups: Optional[Sequence[str]] - source: Optional[str] - date_added: Optional[datetime] - added_by: Optional[str] - - # Now the actual prompts - seeds: Sequence[Seed] - + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + data_type: Optional[str] = "text" + name: Optional[str] = None + dataset_name: Optional[str] = None + harm_categories: Optional[list[str]] = None + description: Optional[str] = None + authors: Optional[list[str]] = Field(default_factory=list) + groups: Optional[list[str]] = Field(default_factory=list) + source: Optional[str] = None + date_added: Optional[datetime] = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) + added_by: Optional[str] = None + # The default seed type for items that don't specify their own ("prompt", "objective", ...). + seed_type: Optional[SeedType] = None + + # The actual prompts + seeds: list[SerializeAsAny[Seed]] + + @model_validator(mode="before") @classmethod - def from_yaml_file(cls, file: Union[str, Path]) -> SeedDataset: + def _build_seeds(cls, data: Any) -> Any: """ - Create a SeedDataset from a YAML file, marking nested seeds as trusted templates. + Convert dict seed entries into concrete Seed subclasses, merging dataset-level defaults. - Args: - file: The input file path. + ``is_jinja_template`` is a construction-time flag (consumed here, not stored) that marks + seed values as trusted Jinja2 templates. Returns: - SeedDataset: The loaded dataset. + Any: The input data with ``seeds`` replaced by built Seed instances. Raises: - ValueError: If the YAML file is invalid. + ValueError: If the dataset has no seeds. """ - file = verify_and_resolve_path(file) - try: - yaml_data = yaml.safe_load(file.read_text("utf-8")) - except yaml.YAMLError as exc: - raise ValueError(f"Invalid YAML file '{file}': {exc}") from exc + if not isinstance(data, dict): + return data - if yaml_data is None: - raise ValueError(f"YAML file '{file}' is empty.") - - yaml_data["is_jinja_template"] = True - if hasattr(cls, "from_dict") and callable(getattr(cls, "from_dict")): # noqa: B009 - return cls.from_dict(yaml_data) - return cls(**yaml_data) - - def __init__( - self, - *, - seeds: Optional[Union[Sequence[dict[str, Any]], Sequence[Seed]]] = None, - data_type: Optional[PromptDataType] = "text", - name: Optional[str] = None, - dataset_name: Optional[str] = None, - harm_categories: Optional[Sequence[str]] = None, - description: Optional[str] = None, - authors: Optional[Sequence[str]] = None, - groups: Optional[Sequence[str]] = None, - source: Optional[str] = None, - date_added: Optional[datetime] = None, - added_by: Optional[str] = None, - seed_type: Optional[SeedType] = None, - is_jinja_template: bool = False, - ) -> None: - """ - Initialize the dataset. - Typically, you'll call from_dict or from_yaml_file so that top-level defaults - are merged into each seed. If you're passing seeds directly, they can be - either a list of Seed objects or seed dictionaries (which then get - converted to Seed objects). - - Args: - seeds: List of seed dictionaries or Seed objects. - data_type: Default data type for seeds. - name: Name of the dataset. - dataset_name: Dataset name for categorization. - harm_categories: List of harm categories. - description: Description of the dataset. - authors: List of authors. - groups: List of groups. - source: Source of the dataset. - date_added: Date when the dataset was added. - added_by: User who added the dataset. - seed_type: The type of seeds in this dataset ("prompt", "objective", or "simulated_conversation"). - is_jinja_template: When True, seed values are Jinja2 templates. Set by from_yaml_file. - - Raises: - ValueError: If seeds are missing or contain invalid/contradictory seed definitions. - - """ - if not seeds: + data = dict(data) + is_jinja_template = data.pop("is_jinja_template", False) + raw_seeds = data.get("seeds") + if not raw_seeds: raise ValueError("SeedDataset cannot be empty.") - input_seeds = seeds + default_data_type = data.get("data_type", "text") + default_name = data.get("name") + default_dataset_name = data.get("dataset_name") + default_description = data.get("description") + default_source = data.get("source") + dataset_seed_type = data.get("seed_type") - # Store top-level fields - self.data_type = data_type - self.name = name - self.dataset_name = dataset_name - - self.harm_categories = harm_categories - self.description = description - self.authors = authors or [] - self.groups = groups or [] - self.source = source - self.date_added = date_added or datetime.now(tz=timezone.utc) - self.added_by = added_by - - # Convert any dictionaries in `seeds` to SeedPrompt and/or SeedObjective objects - self.seeds = [] - for p in input_seeds: + built: list[Seed] = [] + for p in raw_seeds: if isinstance(p, dict): - p_seed_type = p.get("seed_type", seed_type) # type: ignore[ty:no-matching-overload] - - effective_type: SeedType = "prompt" - if p_seed_type == "objective": - effective_type = "objective" - elif p_seed_type == "simulated_conversation": - effective_type = "simulated_conversation" - elif p_seed_type == "prompt": - effective_type = "prompt" - - # Extract common base parameters (from Seed base class) with dataset defaults. - # Note: If Seed base class param names change, update here too. - # SeedSimulatedConversation computes its own value, so we don't require it. - base_params = { - "value_sha256": p.get("value_sha256"), # type: ignore[ty:invalid-argument-type] + p_seed_type = p.get("seed_type", dataset_seed_type) + + base_params: dict[str, Any] = { + "value_sha256": p.get("value_sha256"), "id": uuid.uuid4(), - "name": p.get("name") or self.name, # type: ignore[ty:invalid-argument-type] - "dataset_name": p.get("dataset_name") or self.dataset_name or self.name, # type: ignore[ty:invalid-argument-type] - "harm_categories": p.get("harm_categories", []), # type: ignore[ty:no-matching-overload] - "description": p.get("description") or self.description, # type: ignore[ty:invalid-argument-type] - "authors": p.get("authors", []), # type: ignore[ty:no-matching-overload] - "groups": p.get("groups", []), # type: ignore[ty:no-matching-overload] - "source": p.get("source") or self.source, # type: ignore[ty:invalid-argument-type] - "date_added": p.get("date_added"), # type: ignore[ty:invalid-argument-type] - "added_by": p.get("added_by"), # type: ignore[ty:invalid-argument-type] - "metadata": p.get("metadata", {}), # type: ignore[ty:no-matching-overload] - "prompt_group_id": p.get("prompt_group_id"), # type: ignore[ty:invalid-argument-type] + "name": p.get("name") or default_name, + "dataset_name": p.get("dataset_name") or default_dataset_name or default_name, + "harm_categories": p.get("harm_categories", []), + "description": p.get("description") or default_description, + "authors": p.get("authors", []), + "groups": p.get("groups", []), + "source": p.get("source") or default_source, + "date_added": p.get("date_added"), + "added_by": p.get("added_by"), + "metadata": p.get("metadata", {}), + "prompt_group_id": p.get("prompt_group_id"), "is_jinja_template": is_jinja_template, } - if effective_type == "simulated_conversation": - _adv_path = p.get("adversarial_chat_system_prompt_path") # type: ignore[ty:invalid-argument-type] - _sim_path = p.get("simulated_target_system_prompt_path") # type: ignore[ty:invalid-argument-type] - _sc_kwargs: dict[str, Any] = {**base_params, "num_turns": p.get("num_turns", 3)} # type: ignore[ty:no-matching-overload] + if p_seed_type == "simulated_conversation": + _adv_path = p.get("adversarial_chat_system_prompt_path") + _sim_path = p.get("simulated_target_system_prompt_path") + _sc_kwargs: dict[str, Any] = {**base_params, "num_turns": p.get("num_turns", 3)} if _adv_path is not None: _sc_kwargs["adversarial_chat_system_prompt_path"] = str(_adv_path) if _sim_path is not None: _sc_kwargs["simulated_target_system_prompt_path"] = str(_sim_path) - self.seeds.append(SeedSimulatedConversation(**_sc_kwargs)) # type: ignore[ty:invalid-argument-type] - elif effective_type == "objective": - # SeedObjective inherits data_type="text" from base Seed property - base_params["value"] = p["value"] # type: ignore[ty:invalid-argument-type] - self.seeds.append(SeedObjective(**base_params)) # type: ignore[ty:invalid-argument-type] + built.append(SeedSimulatedConversation(**_sc_kwargs)) + elif p_seed_type == "objective": + base_params["value"] = p["value"] + built.append(SeedObjective(**base_params)) else: # prompt - base_params["value"] = p["value"] # type: ignore[ty:invalid-argument-type] - self.seeds.append( + base_params["value"] = p["value"] + built.append( SeedPrompt( - **base_params, # type: ignore[ty:invalid-argument-type] - data_type=p.get("data_type") or self.data_type, # type: ignore[ty:invalid-argument-type] - role=p.get("role", "user"), # type: ignore[ty:no-matching-overload] - sequence=p.get("sequence", 0), # type: ignore[ty:no-matching-overload] - parameters=p.get("parameters", {}), # type: ignore[ty:no-matching-overload] + **base_params, + data_type=p.get("data_type") or default_data_type, + role=p.get("role", "user"), + sequence=p.get("sequence", 0), + parameters=p.get("parameters") or [], ) ) elif isinstance(p, (SeedPrompt, SeedObjective, SeedSimulatedConversation)): - self.seeds.append(p) + built.append(p) else: raise ValueError( "Seeds should be dicts or Seed objects (SeedPrompt, SeedObjective, SeedSimulatedConversation)." ) + data["seeds"] = built + for key in ("harm_categories", "authors", "groups"): + data[key] = coerce_str_to_list(data.get(key)) + data["authors"] = data.get("authors") or [] + data["groups"] = data.get("groups") or [] + data["date_added"] = data.get("date_added") or datetime.now(tz=timezone.utc) + return data + + @classmethod + def from_yaml_file(cls, file: Union[str, Path]) -> SeedDataset: + """ + Create a SeedDataset from a YAML file, marking nested seeds as trusted templates. + + Args: + file: The input file path. + + Returns: + SeedDataset: The loaded dataset. + + Raises: + ValueError: If the YAML file is invalid. + """ + file = verify_and_resolve_path(file) + try: + yaml_data = yaml.safe_load(file.read_text("utf-8")) + except yaml.YAMLError as exc: + raise ValueError(f"Invalid YAML file '{file}': {exc}") from exc + + if yaml_data is None: + raise ValueError(f"YAML file '{file}' is empty.") + + yaml_data["is_jinja_template"] = True + return cls.from_dict(yaml_data) + def get_values( self, *, @@ -291,7 +260,7 @@ def from_dict(cls, data: dict[str, Any]) -> SeedDataset: dataset_defaults = data # everything else is top-level - merged_seeds = [] + merged_seeds: list[dict[str, Any]] = [] for p in seeds_data: # Merge dataset-level fields with the prompt-level fields merged = utils.combine_dict(dataset_defaults, p) @@ -323,7 +292,7 @@ def from_dict(cls, data: dict[str, Any]) -> SeedDataset: SeedDataset._set_seed_group_id_by_alias(seed_prompts=merged_seeds) # Now create the dataset with the newly merged prompt dicts - return cls(seeds=merged_seeds, **dataset_defaults) + return cls.model_validate({"seeds": merged_seeds, **dataset_defaults}) def render_template_value(self, **kwargs: object) -> None: """ diff --git a/pyrit/models/seeds/seed_group.py b/pyrit/models/seeds/seed_group.py index 6e76d0a171..7397e26be1 100644 --- a/pyrit/models/seeds/seed_group.py +++ b/pyrit/models/seeds/seed_group.py @@ -13,9 +13,10 @@ import logging import uuid from collections import defaultdict -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional + +from pydantic import BaseModel, ConfigDict, SerializeAsAny, model_validator -from pyrit.common.yaml_loadable import YamlLoadable from pyrit.models.message import Message from pyrit.models.message_piece import MessagePiece from pyrit.models.seeds.seed import Seed @@ -29,7 +30,7 @@ logger = logging.getLogger(__name__) -class SeedGroup(YamlLoadable): +class SeedGroup(BaseModel): """ A container for grouping prompts that need to be sent together. @@ -42,72 +43,83 @@ class SeedGroup(YamlLoadable): All prompts in the group share the same `prompt_group_id`. """ - seeds: list[Seed] + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + seeds: list[SerializeAsAny[Seed]] - def __init__( - self, - *, - seeds: Sequence[Union[Seed, dict[str, Any]]], - is_jinja_template: bool = False, - ) -> None: + @model_validator(mode="before") + @classmethod + def _coerce_seeds(cls, data: Any) -> Any: """ - Initialize a SeedGroup. + Coerce raw seed dicts into concrete Seed subclasses before validation. - Args: - seeds: Sequence of seeds. Can include: - - SeedObjective (or dict with seed_type="objective") - - SeedSimulatedConversation (or dict with seed_type="simulated_conversation") - - SeedPrompt for prompts (or dict with seed_type="prompt" or no seed_type) - is_jinja_template: When True, seed values are treated as Jinja2 templates. - Set automatically by from_yaml_file for trusted sources. + ``is_jinja_template`` is a construction-time flag (not a stored field): it is consumed + here and propagated to each dict seed so trusted YAML values are rendered as templates. - Raises: - ValueError: If seeds is empty. - ValueError: If multiple objectives are provided. - ValueError: If SeedPrompt sequences overlap with SeedSimulatedConversation range. + Returns: + Any: The input data with ``seeds`` replaced by concrete Seed instances. + Raises: + ValueError: If the group has no seeds or a seed has an unsupported type. """ - if not seeds: + if not isinstance(data, dict): + return data + + data = dict(data) + is_jinja_template = data.pop("is_jinja_template", False) + raw_seeds = data.get("seeds") + if not raw_seeds: raise ValueError("SeedGroup cannot be empty.") - self.seeds = [] - for seed in seeds: + coerced: list[Seed] = [] + for seed in raw_seeds: if isinstance(seed, Seed): - self.seeds.append(seed) + coerced.append(seed) elif isinstance(seed, dict): + seed = dict(seed) seed["is_jinja_template"] = is_jinja_template seed_type = seed.pop("seed_type", None) if seed_type == "simulated_conversation": # SeedSimulatedConversation doesn't use data_type (always text) seed.pop("data_type", None) - self.seeds.append(SeedSimulatedConversation.from_dict(seed)) + coerced.append(SeedSimulatedConversation.from_dict(seed)) elif seed_type == "objective": # SeedObjective doesn't use data_type (always text) seed.pop("data_type", None) - self.seeds.append(SeedObjective(**seed)) + coerced.append(SeedObjective(**seed)) else: - self.seeds.append(SeedPrompt(**seed)) + coerced.append(SeedPrompt(**seed)) else: raise ValueError(f"Invalid seed type: {type(seed)}") - # Validate and normalize the seeds - self.validate() + data["seeds"] = coerced + return data + + @model_validator(mode="after") + def _finalize(self) -> SeedGroup: + """ + Validate the group and reorder seeds into canonical order. + + Canonical order is: objective, simulated_conversation, then prompts sorted by sequence. - # Extract simulated conversation config - self._simulated_conversation_config = self._get_simulated_conversation() + Returns: + SeedGroup: The validated, reordered group. + """ + self.validate() - # Reconstruct seeds in canonical order: objective, simulated_conversation, sorted prompts objective = self._get_objective() - simulated_conv = self._simulated_conversation_config + simulated_conv = self._get_simulated_conversation() sorted_prompts = sorted(self.prompts, key=lambda p: p.sequence if p.sequence is not None else 0) - self.seeds = [] + new_seeds: list[Seed] = [] if objective: - self.seeds.append(objective) + new_seeds.append(objective) if simulated_conv: - self.seeds.append(simulated_conv) - self.seeds.extend(sorted_prompts) + new_seeds.append(simulated_conv) + new_seeds.extend(sorted_prompts) + self.seeds = new_seeds + return self # ========================================================================= # Validation @@ -293,12 +305,12 @@ def harm_categories(self) -> list[str]: @property def simulated_conversation_config(self) -> Optional[SeedSimulatedConversation]: """Get the simulated conversation configuration if set.""" - return self._simulated_conversation_config + return self._get_simulated_conversation() @property def has_simulated_conversation(self) -> bool: """Check if this group uses simulated conversation generation.""" - return self._simulated_conversation_config is not None + return self._get_simulated_conversation() is not None # ========================================================================= # Message Extraction diff --git a/pyrit/models/seeds/seed_objective.py b/pyrit/models/seeds/seed_objective.py index 0f0edd743a..26b59f1f1e 100644 --- a/pyrit/models/seeds/seed_objective.py +++ b/pyrit/models/seeds/seed_objective.py @@ -8,9 +8,10 @@ from __future__ import annotations import logging -from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Union +from pydantic import model_validator + from pyrit.common.path import PATHS_DICT from pyrit.models.seeds.seed import Seed @@ -20,16 +21,17 @@ logger = logging.getLogger(__name__) -@dataclass class SeedObjective(Seed): """Represents a seed objective with various attributes and metadata.""" - is_general_technique: bool = False - - def __post_init__(self) -> None: + @model_validator(mode="after") + def _validate_and_render(self) -> SeedObjective: """ Post-initialization to render the template to replace existing values. + Returns: + SeedObjective: The validated, rendered objective. + Raises: ValueError: If is_general_technique is True. """ @@ -37,7 +39,8 @@ def __post_init__(self) -> None: raise ValueError("SeedObjective cannot be a general technique.") # Only trusted templates are rendered through Jinja — see seed_prompt.py for details. if self.is_jinja_template: - self.value = super().render_template_value_silent(**PATHS_DICT) + self.value = self.render_template_value_silent(**PATHS_DICT) + return self @classmethod def from_yaml_with_required_parameters( diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index d2b867c105..4d8312eae0 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -9,32 +9,33 @@ import logging import os -from dataclasses import dataclass, field from typing import TYPE_CHECKING, Optional, Union +from pydantic import Field, field_validator, model_validator from tinytag import TinyTag from pyrit.common.path import PATHS_DICT from pyrit.models import DataTypeSerializer -from pyrit.models.seeds.seed import Seed +from pyrit.models.literals import ( # noqa: TC001 (runtime-required by Pydantic field annotations) + ChatMessageRole, + PromptDataType, +) +from pyrit.models.seeds.seed import Seed, coerce_str_to_list if TYPE_CHECKING: import uuid - from collections.abc import Sequence from pathlib import Path from pyrit.models import Message - from pyrit.models.literals import ChatMessageRole, PromptDataType logger = logging.getLogger(__name__) -@dataclass class SeedPrompt(Seed): """Represents a seed prompt with various attributes and metadata.""" # The type of data this prompt represents (e.g., text, image_path, audio_path, video_path) - # This field shadows the base class property to allow per-prompt data types + # This field overrides the base default to allow per-prompt data types inferred from the value data_type: Optional[PromptDataType] = None # Role of the prompt in a conversation (e.g., "user", "assistant") @@ -45,12 +46,30 @@ class SeedPrompt(Seed): sequence: int = 0 # Parameters that can be used in the prompt template - parameters: Optional[Sequence[str]] = field(default_factory=list) + parameters: Optional[list[str]] = Field(default_factory=list) - def __post_init__(self) -> None: + @field_validator("parameters", mode="before") + @classmethod + def _coerce_parameters_to_list(cls, value: object) -> object: + """ + Coerce a bare string ``parameters`` value into a single-element list. + + Args: + value: The raw field value provided during validation. + + Returns: + The value wrapped in a list if it was a bare string, otherwise unchanged. + """ + return coerce_str_to_list(value) + + @model_validator(mode="after") + def _render_and_infer_data_type(self) -> SeedPrompt: """ Render template placeholders and infer data_type after initialization. + Returns: + SeedPrompt: The validated prompt with rendered value and inferred data_type. + Raises: ValueError: If file-based data type cannot be inferred from extension. @@ -79,6 +98,8 @@ def __post_init__(self) -> None: else: self.data_type = "text" + return self + def set_encoding_metadata(self) -> None: """ Set encoding metadata for the prompt within metadata dictionary. For images, this is just the diff --git a/pyrit/models/seeds/seed_simulated_conversation.py b/pyrit/models/seeds/seed_simulated_conversation.py index 019e842faa..9db08f0b44 100644 --- a/pyrit/models/seeds/seed_simulated_conversation.py +++ b/pyrit/models/seeds/seed_simulated_conversation.py @@ -15,12 +15,14 @@ import enum import hashlib +import importlib.metadata import json import logging from pathlib import Path from typing import Any, Optional, Union -import pyrit +from pydantic import field_validator, model_validator + from pyrit.common.path import EXECUTOR_SIMULATED_TARGET_PATH from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_prompt import SeedPrompt @@ -66,64 +68,37 @@ class SeedSimulatedConversation(Seed): """ - def __init__( - self, - *, - adversarial_chat_system_prompt_path: Union[str, Path], - simulated_target_system_prompt_path: Optional[Union[str, Path]] = None, - next_message_system_prompt_path: Optional[Union[str, Path]] = None, - num_turns: int = 3, - sequence: int = 0, - pyrit_version: Optional[str] = None, - **kwargs: Any, - ) -> None: - """ - Initialize a SeedSimulatedConversation. + # value is computed from the config in the after-validator; it must not be supplied directly. + value: str = "" - Args: - adversarial_chat_system_prompt_path: Path to YAML file containing the adversarial - chat system prompt. - simulated_target_system_prompt_path: Optional path to YAML file containing - the simulated target system prompt. Defaults to the compliant prompt. - next_message_system_prompt_path: Optional path to YAML file containing the system - prompt for generating a final user message. If provided, after the simulated - conversation is generated, a single LLM call generates a user message that - attempts to get the target to fulfill the objective. Defaults to None - (no next message generation). - num_turns: Number of conversation turns to generate. Defaults to 3. - sequence: The starting sequence number for generated turns. When combined with - static SeedPrompts, this determines where the simulated turns are inserted. - Defaults to 0. - pyrit_version: PyRIT version for reproducibility tracking. Defaults to current version. - **kwargs: Additional arguments passed to the Seed base class. + # Simulated conversations are general techniques by default. + is_general_technique: bool = True - Raises: - ValueError: If num_turns is not positive or sequence is negative. + num_turns: int = 3 + sequence: int = 0 + adversarial_chat_system_prompt_path: Path + simulated_target_system_prompt_path: Path = SimulatedTargetSystemPromptPaths.COMPLIANT.value + next_message_system_prompt_path: Optional[Path] = None + pyrit_version: Optional[str] = None - """ - # Apply default for simulated target system prompt if not provided - if simulated_target_system_prompt_path is None: - simulated_target_system_prompt_path = SimulatedTargetSystemPromptPaths.COMPLIANT.value - if num_turns <= 0: + @field_validator("simulated_target_system_prompt_path", mode="before") + @classmethod + def _default_simulated_target_path(cls, value: Any) -> Any: + # Reconstruction from memory may pass an explicit None; fall back to the compliant default. + if value is None: + return SimulatedTargetSystemPromptPaths.COMPLIANT.value + return value + + @model_validator(mode="after") + def _validate_and_compute_value(self) -> SeedSimulatedConversation: + if self.num_turns <= 0: raise ValueError("num_turns must be a positive integer") - if sequence < 0: + if self.sequence < 0: raise ValueError("sequence must be a non-negative integer") - - self.adversarial_chat_system_prompt_path = Path(adversarial_chat_system_prompt_path) - self.simulated_target_system_prompt_path = Path(simulated_target_system_prompt_path) - self.next_message_system_prompt_path = ( - Path(next_message_system_prompt_path) if next_message_system_prompt_path else None - ) - self.num_turns = num_turns - self.sequence = sequence - self.pyrit_version = pyrit_version or pyrit.__version__ - - # Compute value and pass to parent - # Remove 'value' from kwargs if present since we compute it - kwargs.pop("value", None) - # Default is_general_technique to True for simulated conversations - kwargs.setdefault("is_general_technique", True) - super().__init__(value=self._compute_value(), **kwargs) + if not self.pyrit_version: + self.pyrit_version = importlib.metadata.version("pyrit") + self.value = self._compute_value() + return self def _compute_value(self) -> str: """ @@ -173,7 +148,9 @@ def from_dict(cls, data: dict[str, Any]) -> SeedSimulatedConversation: num_turns=data.get("num_turns", 3), sequence=data.get("sequence", 0), adversarial_chat_system_prompt_path=adversarial_path, - simulated_target_system_prompt_path=data.get("simulated_target_system_prompt_path"), + simulated_target_system_prompt_path=( + data.get("simulated_target_system_prompt_path") or SimulatedTargetSystemPromptPaths.COMPLIANT.value + ), next_message_system_prompt_path=data.get("next_message_system_prompt_path"), ) diff --git a/tests/unit/models/test_import_boundary.py b/tests/unit/models/test_import_boundary.py index bb4b6ca9e9..b4c33795e2 100644 --- a/tests/unit/models/test_import_boundary.py +++ b/tests/unit/models/test_import_boundary.py @@ -2,17 +2,22 @@ # Licensed under the MIT license. """ -Enforce the ``pyrit.models`` import boundary. +Enforce the ``pyrit.models`` / ``pyrit.common`` import boundary. -``pyrit.models`` is the canonical data layer. Files in ``pyrit/models/`` -may import only from stdlib, ``pydantic``, ``pyrit.common.deprecation``, -and other ``pyrit.models.*`` submodules. +PyRIT uses a two-layer rule for its foundational packages: -This test uses a ratchet pattern: ``KNOWN_TOP_LEVEL_VIOLATIONS`` and -``KNOWN_LAZY_VIOLATIONS`` track imports that exist today and are expected to -disappear in a specific phase. The lists must shrink monotonically — if a known -violation is no longer in source, this test fails and the entry must be -removed. +* **Forward (models):** files in ``pyrit/models/`` may import only from stdlib, + ``pydantic``, all of ``pyrit.common`` (the whole prefix), and other + ``pyrit.models.*`` submodules. +* **Reverse guard (common):** files in ``pyrit/common/`` may import only from + stdlib, third-party libraries, and other ``pyrit.common.*`` submodules — never + any other ``pyrit.*`` package. This keeps ``pyrit.common`` a true foundation + layer and prevents an import cycle with ``pyrit.models``. + +Both directions use a ratchet pattern: the ``KNOWN_*_VIOLATIONS`` lists track +imports that exist today and are expected to disappear in a later phase. The +lists must shrink monotonically — if a known violation is no longer in source, +this test fails and the entry must be removed. See plan.md / ``.github/instructions/models.instructions.md`` for context. """ @@ -25,82 +30,67 @@ import pytest +import pyrit.common import pyrit.models if TYPE_CHECKING: from collections.abc import Iterable # noqa: F401 MODELS_PACKAGE = Path(pyrit.models.__file__).parent +COMMON_PACKAGE = Path(pyrit.common.__file__).parent EXCLUDE_FILES: frozenset[str] = frozenset() -# Always allowed at module level (in addition to pyrit.models.* self-imports). -ALLOWED_TOP_LEVEL: frozenset[str] = frozenset( - { - "pyrit.common.deprecation", - } -) +# Forward rule: pyrit.models may import these pyrit prefixes freely. +MODELS_ALLOWED_PREFIXES: tuple[str, ...] = ("pyrit.models", "pyrit.common") + +# Reverse guard: pyrit.common may import only itself within the pyrit namespace. +COMMON_ALLOWED_PREFIXES: tuple[str, ...] = ("pyrit.common",) -# Transitional known top-level violations. Each entry names the phase that -# clears it (documentation only — the test does not parse the tag). New -# violations not in this list fail the test; entries that no longer match +# Transitional known top-level violations for pyrit.models. Each entry names the +# phase that clears it (documentation only — the test does not parse the tag). +# New violations not in this list fail the test; entries that no longer match # source also fail. -KNOWN_TOP_LEVEL_VIOLATIONS: dict[str, dict[str, str]] = { - "pyrit.models.message": { - "pyrit.common.utils": "phase-4", - }, - "pyrit.models.harm_definition": { - "pyrit.common.path": "phase-1", - }, - "pyrit.models.data_type_serializer": { - "pyrit.common.path": "phase-8", - }, - "pyrit.models.seeds.seed": { - "pyrit.common.utils": "seeds-followup", - "pyrit.common.yaml_loadable": "seeds-followup", - }, - "pyrit.models.seeds.seed_dataset": { - "pyrit.common": "seeds-followup", - "pyrit.common.utils": "seeds-followup", - "pyrit.common.yaml_loadable": "seeds-followup", - }, - "pyrit.models.seeds.seed_group": { - "pyrit.common.yaml_loadable": "seeds-followup", - }, - "pyrit.models.seeds.seed_objective": { - "pyrit.common.path": "seeds-followup", - }, - "pyrit.models.seeds.seed_prompt": { - "pyrit.common.path": "seeds-followup", - }, - "pyrit.models.seeds.seed_simulated_conversation": { - "pyrit.common.path": "seeds-followup", - }, -} +KNOWN_TOP_LEVEL_VIOLATIONS: dict[str, dict[str, str]] = {} -# Lazy / TYPE_CHECKING imports of cross-package modules. Same ratchet, tracked -# separately so the phase that removes the lazy workaround is explicit. +# Lazy / TYPE_CHECKING imports of cross-package modules from pyrit.models. Same +# ratchet, tracked separately so the phase that removes the lazy workaround is +# explicit. KNOWN_LAZY_VIOLATIONS: dict[str, dict[str, str]] = { "pyrit.models.data_type_serializer": { - "pyrit.memory": "phase-8", - "pyrit.common.path": "phase-8", + "pyrit.memory": "phase-9", }, "pyrit.models.scenario_result": { "pyrit.score.scorer_evaluation.scorer_metrics": "phase-7", "pyrit.score.scorer_evaluation.scorer_metrics_io": "phase-7", }, "pyrit.models.storage_io": { - "pyrit.auth": "phase-8", + "pyrit.auth": "phase-9", + }, +} + +# Reverse-guard violations: pyrit.common modules that still reach up into higher +# layers. These are slated to relocate; the ratchet forces them to shrink. +KNOWN_COMMON_VIOLATIONS: dict[str, dict[str, str]] = { + "pyrit.common.data_url_converter": { + "pyrit.models": "relocate", + }, + "pyrit.common.display_response": { + "pyrit.memory": "relocate", + "pyrit.models": "relocate", + }, + "pyrit.common.question_answer_helpers": { + "pyrit.models": "relocate", }, } -def _module_name_for(path: Path) -> str: - """Return the dotted module name for a file inside ``pyrit/models/``.""" - rel = path.relative_to(MODELS_PACKAGE).with_suffix("") +def _module_name_for(path: Path, *, package_root: Path, package_prefix: str) -> str: + """Return the dotted module name for a file inside ``package_root``.""" + rel = path.relative_to(package_root).with_suffix("") parts = [p for p in rel.parts if p != "__init__"] if not parts: - return "pyrit.models" - return "pyrit.models." + ".".join(parts) + return package_prefix + return package_prefix + "." + ".".join(parts) def _resolve_from_import(node: ast.ImportFrom, source_module: str) -> str: @@ -126,11 +116,17 @@ def _is_typecheck_test(test: ast.expr) -> bool: ) +def _is_allowed(mod: str, allowed_prefixes: tuple[str, ...]) -> bool: + """Return True iff ``mod`` is within one of the allowed self-prefixes.""" + return any(mod == prefix or mod.startswith(prefix + ".") for prefix in allowed_prefixes) + + class _ImportCollector(ast.NodeVisitor): - """Walk a module AST and bucket ``pyrit.*`` imports into top-level vs lazy.""" + """Walk a module AST and bucket disallowed ``pyrit.*`` imports into top-level vs lazy.""" - def __init__(self, source_module: str) -> None: + def __init__(self, source_module: str, *, allowed_prefixes: tuple[str, ...]) -> None: self.source_module = source_module + self.allowed_prefixes = allowed_prefixes self.top_level: set[str] = set() self.lazy: set[str] = set() self._in_lazy = False @@ -138,8 +134,7 @@ def __init__(self, source_module: str) -> None: def _record(self, mod: str) -> None: if not mod.startswith("pyrit."): return - # pyrit.models.* self-imports are always allowed (we are inside pyrit.models) - if mod == "pyrit.models" or mod.startswith("pyrit.models."): + if _is_allowed(mod, self.allowed_prefixes): return bucket = self.lazy if self._in_lazy else self.top_level bucket.add(mod) @@ -174,51 +169,77 @@ def visit_If(self, node: ast.If) -> None: self.generic_visit(node) -def _scan_files() -> list[Path]: - """Return all ``pyrit/models/**/*.py`` files in scope.""" - return sorted(p for p in MODELS_PACKAGE.rglob("*.py") if p.name not in EXCLUDE_FILES) +def _scan_files(package_root: Path) -> list[Path]: + """Return all ``*.py`` files under ``package_root`` in scope.""" + return sorted(p for p in package_root.rglob("*.py") if p.name not in EXCLUDE_FILES) -def _analyze(path: Path) -> tuple[str, set[str], set[str]]: - source_module = _module_name_for(path) +def _analyze( + path: Path, *, package_root: Path, package_prefix: str, allowed_prefixes: tuple[str, ...] +) -> tuple[str, set[str], set[str]]: + source_module = _module_name_for(path, package_root=package_root, package_prefix=package_prefix) tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) - collector = _ImportCollector(source_module) + collector = _ImportCollector(source_module, allowed_prefixes=allowed_prefixes) collector.visit(tree) return source_module, collector.top_level, collector.lazy -def _collect_actual_imports() -> tuple[dict[str, set[str]], dict[str, set[str]]]: +def _collect_actual_imports( + *, package_root: Path, package_prefix: str, allowed_prefixes: tuple[str, ...] +) -> tuple[dict[str, set[str]], dict[str, set[str]]]: top: dict[str, set[str]] = {} lazy: dict[str, set[str]] = {} - for path in _scan_files(): - source, top_imports, lazy_imports = _analyze(path) + for path in _scan_files(package_root): + source, top_imports, lazy_imports = _analyze( + path, + package_root=package_root, + package_prefix=package_prefix, + allowed_prefixes=allowed_prefixes, + ) top[source] = top_imports lazy[source] = lazy_imports return top, lazy +def _collect_models_imports() -> tuple[dict[str, set[str]], dict[str, set[str]]]: + return _collect_actual_imports( + package_root=MODELS_PACKAGE, + package_prefix="pyrit.models", + allowed_prefixes=MODELS_ALLOWED_PREFIXES, + ) + + +def _collect_common_imports() -> tuple[dict[str, set[str]], dict[str, set[str]]]: + return _collect_actual_imports( + package_root=COMMON_PACKAGE, + package_prefix="pyrit.common", + allowed_prefixes=COMMON_ALLOWED_PREFIXES, + ) + + def test_no_new_top_level_violations() -> None: """Module-level pyrit imports outside the allowlist must be listed in KNOWN_TOP_LEVEL_VIOLATIONS.""" - actual_top, _ = _collect_actual_imports() + actual_top, _ = _collect_models_imports() new_violations: list[str] = [] for source, imports in actual_top.items(): known = set(KNOWN_TOP_LEVEL_VIOLATIONS.get(source, {}).keys()) for imp in sorted(imports): - if imp in ALLOWED_TOP_LEVEL or imp in known: + if imp in known: continue new_violations.append(f"{source} -> {imp}") if new_violations: pytest.fail( "New top-level pyrit imports in pyrit.models (not allowed):\n " + "\n ".join(new_violations) - + "\n\nEither remove the import or, if it is transitional, add it to " + + "\n\npyrit.models may import stdlib, pydantic, pyrit.common.*, and pyrit.models.*. " + "Either remove the import or, if it is transitional, add it to " "KNOWN_TOP_LEVEL_VIOLATIONS in this file with a phase tag." ) def test_known_top_level_violations_still_apply() -> None: """Entries in KNOWN_TOP_LEVEL_VIOLATIONS that no longer exist in source must be removed.""" - actual_top, _ = _collect_actual_imports() + actual_top, _ = _collect_models_imports() stale: list[str] = [] for source, allowed in KNOWN_TOP_LEVEL_VIOLATIONS.items(): present = actual_top.get(source, set()) @@ -233,12 +254,12 @@ def test_known_top_level_violations_still_apply() -> None: def test_no_new_lazy_violations() -> None: """Lazy/TYPE_CHECKING pyrit imports outside the allowlist must be listed in KNOWN_LAZY_VIOLATIONS.""" - _, actual_lazy = _collect_actual_imports() + _, actual_lazy = _collect_models_imports() new_violations: list[str] = [] for source, imports in actual_lazy.items(): known = set(KNOWN_LAZY_VIOLATIONS.get(source, {}).keys()) for imp in sorted(imports): - if imp in ALLOWED_TOP_LEVEL or imp in known: + if imp in known: continue new_violations.append(f"{source} -> {imp}") if new_violations: @@ -252,7 +273,7 @@ def test_no_new_lazy_violations() -> None: def test_known_lazy_violations_still_apply() -> None: """Entries in KNOWN_LAZY_VIOLATIONS that no longer exist in source must be removed.""" - _, actual_lazy = _collect_actual_imports() + _, actual_lazy = _collect_models_imports() stale: list[str] = [] for source, allowed in KNOWN_LAZY_VIOLATIONS.items(): present = actual_lazy.get(source, set()) @@ -265,9 +286,45 @@ def test_known_lazy_violations_still_apply() -> None: ) +def test_no_new_common_violations() -> None: + """pyrit.common may not import any pyrit.* outside pyrit.common (reverse guard).""" + actual_top, actual_lazy = _collect_common_imports() + new_violations: list[str] = [] + for source in sorted({*actual_top, *actual_lazy}): + imports = actual_top.get(source, set()) | actual_lazy.get(source, set()) + known = set(KNOWN_COMMON_VIOLATIONS.get(source, {}).keys()) + for imp in sorted(imports): + if imp in known: + continue + new_violations.append(f"{source} -> {imp}") + if new_violations: + pytest.fail( + "pyrit.common modules importing outside pyrit.common (forbidden):\n " + + "\n ".join(new_violations) + + "\n\npyrit.common is the foundation layer and may import only stdlib, " + "third-party libraries, and pyrit.common.*. Either remove the import or, if it " + "is transitional, add it to KNOWN_COMMON_VIOLATIONS in this file with a tag." + ) + + +def test_known_common_violations_still_apply() -> None: + """Entries in KNOWN_COMMON_VIOLATIONS that no longer exist in source must be removed.""" + actual_top, actual_lazy = _collect_common_imports() + stale: list[str] = [] + for source, allowed in KNOWN_COMMON_VIOLATIONS.items(): + present = actual_top.get(source, set()) | actual_lazy.get(source, set()) + stale.extend(f"{source} -> {imp}" for imp in allowed if imp not in present) + if stale: + pytest.fail( + "KNOWN_COMMON_VIOLATIONS entries that no longer exist in source:\n " + + "\n ".join(stale) + + "\n\nThe allowlist must shrink monotonically. Remove these entries." + ) + + def test_scan_finds_expected_files() -> None: """Sanity check: the scanner picks up the known model files.""" - scanned = {p.name for p in _scan_files()} + scanned = {p.name for p in _scan_files(MODELS_PACKAGE)} # A non-exhaustive sample of files that must exist for this test to be meaningful. expected = { "attack_result.py", diff --git a/tests/unit/models/test_seed.py b/tests/unit/models/test_seed.py index 1dcebb0449..3cb4e4dcb1 100644 --- a/tests/unit/models/test_seed.py +++ b/tests/unit/models/test_seed.py @@ -962,7 +962,7 @@ def test_seed_dataset_dict_to_seed_prompt_all_base_params(): # SeedPrompt-specific fields "role": "assistant", "sequence": 5, - "parameters": {"param1": "val1"}, + "parameters": ["param1"], "seed_type": "prompt", } @@ -990,7 +990,7 @@ def test_seed_dataset_dict_to_seed_prompt_all_base_params(): # Verify SeedPrompt-specific fields assert seed.role == "assistant" assert seed.sequence == 5 - assert seed.parameters == {"param1": "val1"} + assert seed.parameters == ["param1"] def test_seed_dataset_dict_to_seed_objective_all_base_params(): @@ -1098,7 +1098,7 @@ def test_seed_dataset_uses_dataset_defaults_for_missing_params(): # These should use sensible defaults assert seed.role == "user" assert seed.sequence == 0 - assert seed.parameters == {} + assert seed.parameters == [] def test_next_message_single_turn_no_objective(): @@ -1110,6 +1110,34 @@ def test_next_message_single_turn_no_objective(): assert group.prepended_conversation is None assert group.next_message is not None assert len(group.next_message.message_pieces) == 1 + + +def test_seed_group_preserves_polymorphic_subclasses(): + """list[Seed] must preserve concrete subclasses: instances pass through and model_dump keeps subclass fields.""" + objective = SeedObjective(value="Reach the goal") + prompt = SeedPrompt(value="Hello", role="user", sequence=0) + group = SeedGroup(seeds=[objective, prompt]) + + # Instances are preserved, not coerced down to the base Seed. + assert isinstance(group.objective, SeedObjective) + assert any(isinstance(s, SeedPrompt) for s in group.seeds) + + # SerializeAsAny keeps subclass-specific fields on dump (SeedPrompt has role/parameters, SeedObjective does not). + dumped = group.model_dump() + by_value = {s["value"]: s for s in dumped["seeds"]} + assert by_value["Hello"]["role"] == "user" + assert "parameters" in by_value["Hello"] + assert "role" not in by_value["Reach the goal"] + + # Dict reconstruction uses the seed_type discriminator (the path used by YAML/dataset inputs). + rebuilt = SeedGroup( + seeds=[ + {"value": "Reach the goal", "seed_type": "objective"}, + {"value": "Hello", "role": "user", "sequence": 0, "seed_type": "prompt"}, + ] + ) + assert isinstance(rebuilt.objective, SeedObjective) + assert any(isinstance(s, SeedPrompt) for s in rebuilt.seeds) assert group.next_message.get_value() == "Hello" From 72c5d98d736e074676e7c176bce7a3468bb6c2d6 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 2 Jun 2026 14:55:34 -0700 Subject: [PATCH 02/11] Phase 8 review fixes: discriminated-union seeds, rename validate, lock data_type Implements all reviewer feedback on the seeds Pydantic conversion: Blocking fix: SeedGroup/SeedAttackGroup with mixed Seed subclasses no longer corrupts polymorphism on model_dump/model_validate round-trips. Introduce a Literal seed_type discriminator on each leaf class (SeedPrompt/SeedObjective/SeedSimulatedConversation), and switch the polymorphic seeds field to a Field(discriminator=seed_type) annotated union (SeedUnion). The base Seed class is deliberately excluded from the union. NB1: rename validate -> _check_invariants on SeedGroup/SeedAttackGroup/SeedAttackTechniqueGroup so it does not shadow Pydantic v1's BaseModel.validate. External callers updated (atomic_attack, attack_parameters). NB2: stop silently dropping fields on SeedSimulatedConversation from a dict. Delete the bespoke from_dict and route through model_validate; add a before-validator that drops only the computed value field so round-trips are clean. NB3: lock data_type to Literal[text] on SeedObjective and SeedSimulatedConversation. Strip dataset/group-level data_type, role, sequence, parameters from non-prompt seed dicts so dataset-level defaults do not bleed in. Thin-class cleanups: Annotated StrOrList alias replaces the per-field _coerce_str_to_list validators on Seed and SeedPrompt; deterministic order-preserving list merge replaces utils.combine_list (which was nondeterministic across processes). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../executor/attack/core/attack_parameters.py | 2 +- pyrit/models/seeds/seed.py | 31 +-- pyrit/models/seeds/seed_attack_group.py | 4 +- .../seeds/seed_attack_technique_group.py | 4 +- pyrit/models/seeds/seed_dataset.py | 227 +++++++++--------- pyrit/models/seeds/seed_group.py | 73 +++--- pyrit/models/seeds/seed_objective.py | 9 +- pyrit/models/seeds/seed_prompt.py | 25 +- .../seeds/seed_simulated_conversation.py | 63 +++-- pyrit/scenario/core/atomic_attack.py | 2 +- tests/unit/models/test_seed.py | 62 ++++- .../test_seed_attack_technique_group.py | 2 +- tests/unit/models/test_seed_group.py | 2 +- .../test_seed_simulated_conversation.py | 22 +- 14 files changed, 287 insertions(+), 241 deletions(-) diff --git a/pyrit/executor/attack/core/attack_parameters.py b/pyrit/executor/attack/core/attack_parameters.py index 6dc4166d7e..79bfa27b8c 100644 --- a/pyrit/executor/attack/core/attack_parameters.py +++ b/pyrit/executor/attack/core/attack_parameters.py @@ -120,7 +120,7 @@ async def from_seed_group_async( ) # Validate seed_group state before extracting parameters - seed_group.validate() + seed_group._check_invariants() # SeedAttackGroup validates in __init__ that objective is set if seed_group.objective is None: diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index b1dc517ac7..3f34fffedb 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -13,12 +13,12 @@ import re import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Annotated, Any, Optional, TypeVar, Union import yaml from jinja2 import StrictUndefined, Undefined from jinja2.sandbox import SandboxedEnvironment -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, BeforeValidator, ConfigDict, Field from pyrit.common.utils import verify_and_resolve_path from pyrit.models.literals import PromptDataType # noqa: TC001 (runtime-required by Pydantic field annotations) @@ -51,6 +51,11 @@ def coerce_str_to_list(value: Any) -> Any: return value +# Annotated type for list[str] fields that should accept a bare string as a one-element list. +# Use this for any seed list field populated from YAML where authors/groups/etc. may be scalars. +StrOrList = Annotated[list[str], BeforeValidator(coerce_str_to_list)] + + class PartialUndefined(Undefined): """Jinja undefined value that preserves unresolved placeholders as text.""" @@ -117,16 +122,16 @@ class Seed(BaseModel): dataset_name: Optional[str] = None # Categories of harm associated with this prompt - harm_categories: Optional[list[str]] = Field(default_factory=list) + harm_categories: Optional[StrOrList] = Field(default_factory=list) # Description of the prompt description: Optional[str] = None # Authors of the prompt - authors: Optional[list[str]] = Field(default_factory=list) + authors: Optional[StrOrList] = Field(default_factory=list) # Groups affiliated with the prompt - groups: Optional[list[str]] = Field(default_factory=list) + groups: Optional[StrOrList] = Field(default_factory=list) # Source of the prompt source: Optional[str] = None @@ -156,23 +161,9 @@ class Seed(BaseModel): # The type of data this seed represents (e.g., text, image_path, audio_path, video_path). # SeedPrompt overrides the default to None and infers it from the value; other seed types - # keep "text". + # narrow it to Literal["text"]. data_type: PromptDataType = "text" - @field_validator("harm_categories", "authors", "groups", mode="before") - @classmethod - def _coerce_str_to_list(cls, value: Any) -> Any: - """ - Coerce a bare string into a single-element list for these list-typed fields. - - Args: - value: The raw field value provided during validation. - - Returns: - The value wrapped in a list if it was a bare string, otherwise unchanged. - """ - return coerce_str_to_list(value) - def render_template_value(self, **kwargs: Any) -> str: """ Render self.value as a template with provided parameters. diff --git a/pyrit/models/seeds/seed_attack_group.py b/pyrit/models/seeds/seed_attack_group.py index e99a867dbf..f52f9dac7f 100644 --- a/pyrit/models/seeds/seed_attack_group.py +++ b/pyrit/models/seeds/seed_attack_group.py @@ -31,7 +31,7 @@ class SeedAttackGroup(SeedGroup): next_message, etc.) is inherited from SeedGroup. """ - def validate(self) -> None: + def _check_invariants(self) -> None: """ Validate the seed attack group state. @@ -41,7 +41,7 @@ def validate(self) -> None: ValueError: If validation fails. """ - super().validate() + super()._check_invariants() self._enforce_exactly_one_objective() def _enforce_exactly_one_objective(self) -> None: diff --git a/pyrit/models/seeds/seed_attack_technique_group.py b/pyrit/models/seeds/seed_attack_technique_group.py index 8509fe1022..536e125fb0 100644 --- a/pyrit/models/seeds/seed_attack_technique_group.py +++ b/pyrit/models/seeds/seed_attack_technique_group.py @@ -32,7 +32,7 @@ class SeedAttackTechniqueGroup(SeedGroup): # ``None`` (default) appends at the end; an integer inserts before that position. insertion_index: Optional[int] = None - def validate(self) -> None: + def _check_invariants(self) -> None: """ Validate the seed attack technique group state. @@ -42,7 +42,7 @@ def validate(self) -> None: Raises: ValueError: If validation fails. """ - super().validate() + super()._check_invariants() self._enforce_all_general_strategy() self._enforce_no_objectives() diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index 61b0f56581..344a28b8f0 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -15,20 +15,18 @@ from typing import TYPE_CHECKING, Any, Optional, Union import yaml -from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator -from pyrit.common import utils from pyrit.common.utils import verify_and_resolve_path from pyrit.models.literals import SeedType # noqa: TC001 (runtime-required by Pydantic field annotations) -from pyrit.models.seeds.seed import ( # noqa: TC001 (runtime-required by Pydantic field annotations) - Seed, - coerce_str_to_list, -) +from pyrit.models.seeds.seed import Seed, StrOrList from pyrit.models.seeds.seed_attack_group import SeedAttackGroup -from pyrit.models.seeds.seed_group import SeedGroup +from pyrit.models.seeds.seed_group import ( # noqa: TC001 (runtime-required by Pydantic field annotations) + SeedGroup, + SeedUnion, +) from pyrit.models.seeds.seed_objective import SeedObjective from pyrit.models.seeds.seed_prompt import SeedPrompt -from pyrit.models.seeds.seed_simulated_conversation import SeedSimulatedConversation if TYPE_CHECKING: from collections.abc import Sequence @@ -38,6 +36,43 @@ logger = logging.getLogger(__name__) +# Dataset-level defaults that get merged into each dict seed when missing on the seed. +# date_added/added_by/metadata are intentionally excluded — per-seed Pydantic defaults +# (default_factory) are the source of truth. +_SCALAR_DEFAULT_KEYS = ("name", "description", "source") +_LIST_DEFAULT_KEYS = ("harm_categories", "authors", "groups") + + +def _merge_unique(left: Any, right: Any) -> list[str]: + """ + Concatenate two list-or-str inputs into a deterministic, order-preserving deduped list. + + Treats ``None`` as empty, accepts bare strings as single-element lists, and preserves the + order of first occurrence (left first, then any new items from right). Used instead of + ``utils.combine_list`` because the latter goes through ``set()`` and is nondeterministic + across processes for non-trivial inputs. + + Args: + left: First list (or string) of values; falsy values are treated as empty. + right: Second list (or string) of values; falsy values are treated as empty. + + Returns: + list[str]: Deduplicated concatenation, preserving first-occurrence order. + """ + + def _as_list(v: Any) -> list[str]: + if not v: + return [] + return [v] if isinstance(v, str) else list(v) + + seen: set[str] = set() + result: list[str] = [] + for item in _as_list(left) + _as_list(right): + if item not in seen: + seen.add(item) + result.append(item) + return result + class SeedDataset(BaseModel): """ @@ -51,10 +86,10 @@ class SeedDataset(BaseModel): data_type: Optional[str] = "text" name: Optional[str] = None dataset_name: Optional[str] = None - harm_categories: Optional[list[str]] = None + harm_categories: Optional[StrOrList] = None description: Optional[str] = None - authors: Optional[list[str]] = Field(default_factory=list) - groups: Optional[list[str]] = Field(default_factory=list) + authors: Optional[StrOrList] = Field(default_factory=list) + groups: Optional[StrOrList] = Field(default_factory=list) source: Optional[str] = None date_added: Optional[datetime] = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) added_by: Optional[str] = None @@ -62,22 +97,32 @@ class SeedDataset(BaseModel): seed_type: Optional[SeedType] = None # The actual prompts - seeds: list[SerializeAsAny[Seed]] + seeds: list[SeedUnion] @model_validator(mode="before") @classmethod def _build_seeds(cls, data: Any) -> Any: """ - Convert dict seed entries into concrete Seed subclasses, merging dataset-level defaults. + Merge dataset-level defaults into each dict seed and normalize for the discriminator. - ``is_jinja_template`` is a construction-time flag (consumed here, not stored) that marks - seed values as trusted Jinja2 templates. + Concrete Seed instances pass through unchanged. For dict seeds: + + - ``seed_type`` defaults to the dataset's ``seed_type`` or ``"prompt"``. + - ``is_jinja_template`` (construction-time flag, popped from the dataset) is propagated. + - Scalar defaults (name, dataset_name, description, source) fall back to the dataset's + when the seed has none. + - List defaults (harm_categories, authors, groups) are concatenated with deterministic + order-preserving dedup (dataset values first, then seed-only additions). + - For prompts: ``data_type`` falls back to the dataset's; ``role`` defaults to ``"user"``. + - For objective/simulated_conversation: ``data_type``/``role``/``sequence``/ + ``parameters`` are stripped — they aren't valid fields on those classes and a + dataset-level value (e.g. ``data_type: image_path``) would otherwise be rejected. Returns: - Any: The input data with ``seeds`` replaced by built Seed instances. + The data with normalized ``seeds`` (passes through unchanged if not a dict). Raises: - ValueError: If the dataset has no seeds. + ValueError: If the dataset has no seeds or contains an unsupported seed entry. """ if not isinstance(data, dict): return data @@ -88,71 +133,47 @@ def _build_seeds(cls, data: Any) -> Any: if not raw_seeds: raise ValueError("SeedDataset cannot be empty.") - default_data_type = data.get("data_type", "text") - default_name = data.get("name") - default_dataset_name = data.get("dataset_name") - default_description = data.get("description") - default_source = data.get("source") - dataset_seed_type = data.get("seed_type") + default_seed_type = data.get("seed_type") or "prompt" + default_data_type = data.get("data_type") or "text" + default_dataset_name = data.get("dataset_name") or data.get("name") - built: list[Seed] = [] + normalized: list[Any] = [] for p in raw_seeds: - if isinstance(p, dict): - p_seed_type = p.get("seed_type", dataset_seed_type) - - base_params: dict[str, Any] = { - "value_sha256": p.get("value_sha256"), - "id": uuid.uuid4(), - "name": p.get("name") or default_name, - "dataset_name": p.get("dataset_name") or default_dataset_name or default_name, - "harm_categories": p.get("harm_categories", []), - "description": p.get("description") or default_description, - "authors": p.get("authors", []), - "groups": p.get("groups", []), - "source": p.get("source") or default_source, - "date_added": p.get("date_added"), - "added_by": p.get("added_by"), - "metadata": p.get("metadata", {}), - "prompt_group_id": p.get("prompt_group_id"), - "is_jinja_template": is_jinja_template, - } - - if p_seed_type == "simulated_conversation": - _adv_path = p.get("adversarial_chat_system_prompt_path") - _sim_path = p.get("simulated_target_system_prompt_path") - _sc_kwargs: dict[str, Any] = {**base_params, "num_turns": p.get("num_turns", 3)} - if _adv_path is not None: - _sc_kwargs["adversarial_chat_system_prompt_path"] = str(_adv_path) - if _sim_path is not None: - _sc_kwargs["simulated_target_system_prompt_path"] = str(_sim_path) - built.append(SeedSimulatedConversation(**_sc_kwargs)) - elif p_seed_type == "objective": - base_params["value"] = p["value"] - built.append(SeedObjective(**base_params)) - else: # prompt - base_params["value"] = p["value"] - built.append( - SeedPrompt( - **base_params, - data_type=p.get("data_type") or default_data_type, - role=p.get("role", "user"), - sequence=p.get("sequence", 0), - parameters=p.get("parameters") or [], - ) - ) - elif isinstance(p, (SeedPrompt, SeedObjective, SeedSimulatedConversation)): - built.append(p) - else: + if isinstance(p, Seed): + normalized.append(p) + continue + if not isinstance(p, dict): raise ValueError( "Seeds should be dicts or Seed objects (SeedPrompt, SeedObjective, SeedSimulatedConversation)." ) - data["seeds"] = built - for key in ("harm_categories", "authors", "groups"): - data[key] = coerce_str_to_list(data.get(key)) - data["authors"] = data.get("authors") or [] - data["groups"] = data.get("groups") or [] - data["date_added"] = data.get("date_added") or datetime.now(tz=timezone.utc) + p = dict(p) + seed_type = p.setdefault("seed_type", default_seed_type) + p["is_jinja_template"] = is_jinja_template + + for key in _SCALAR_DEFAULT_KEYS: + if not p.get(key) and data.get(key) is not None: + p[key] = data.get(key) + if not p.get("dataset_name") and default_dataset_name is not None: + p["dataset_name"] = default_dataset_name + + for key in _LIST_DEFAULT_KEYS: + p[key] = _merge_unique(data.get(key), p.get(key)) + + if seed_type == "prompt": + if not p.get("data_type"): + p["data_type"] = default_data_type + p.setdefault("role", "user") + else: + # Non-prompt seeds narrow data_type to Literal["text"] and don't have + # role/sequence/parameters fields. Drop those so dataset-level defaults + # don't bleed in and trip extra="forbid" on the leaf class. + for prompt_only in ("data_type", "role", "sequence", "parameters"): + p.pop(prompt_only, None) + + normalized.append(p) + + data["seeds"] = normalized return data @classmethod @@ -243,56 +264,36 @@ def get_random_values( @classmethod def from_dict(cls, data: dict[str, Any]) -> SeedDataset: """ - Build a SeedDataset by merging top-level defaults into each item in `seeds`. + Build a SeedDataset, assigning per-seed ``prompt_group_id`` by alias. + + Default merging now lives in :meth:`_build_seeds` so direct construction and + ``from_dict`` produce equivalent results. This method handles the YAML-only + concerns: rejecting pre-set ``prompt_group_id`` on input seeds and resolving + ``prompt_group_alias`` into a shared ``prompt_group_id``. Args: data (Dict[str, Any]): Dataset payload with top-level defaults and seed entries. Returns: - SeedDataset: Constructed dataset with merged defaults. + SeedDataset: Constructed dataset. Raises: - ValueError: If any seed entry includes a pre-set prompt_group_id. - + ValueError: If any seed entry includes a pre-set ``prompt_group_id``. """ - # Pop out the seeds section - seeds_data = data.pop("seeds", []) - - dataset_defaults = data # everything else is top-level - - merged_seeds: list[dict[str, Any]] = [] - for p in seeds_data: - # Merge dataset-level fields with the prompt-level fields - merged = utils.combine_dict(dataset_defaults, p) - - merged["harm_categories"] = utils.combine_list( - dataset_defaults.get("harm_categories", []), - p.get("harm_categories", []), - ) - - merged["authors"] = utils.combine_list( - dataset_defaults.get("authors", []), - p.get("authors", []), - ) - - merged["groups"] = utils.combine_list( - dataset_defaults.get("groups", []), - p.get("groups", []), - ) - - if "data_type" not in merged: - merged["data_type"] = dataset_defaults.get("data_type", "text") + data = dict(data) - merged_seeds.append(merged) + # Shallow-copy each dict seed so alias resolution doesn't mutate caller-owned dicts; + # non-dict seeds (e.g. Seed instances) pass through untouched. + seeds_data: list[Any] = [dict(seed) if isinstance(seed, dict) else seed for seed in data.get("seeds", [])] - for seed in merged_seeds: + dict_seeds = [s for s in seeds_data if isinstance(s, dict)] + for seed in dict_seeds: if "prompt_group_id" in seed: raise ValueError("prompt_group_id should not be set in seed data") + cls._set_seed_group_id_by_alias(dict_seeds) - SeedDataset._set_seed_group_id_by_alias(seed_prompts=merged_seeds) - - # Now create the dataset with the newly merged prompt dicts - return cls.model_validate({"seeds": merged_seeds, **dataset_defaults}) + data["seeds"] = seeds_data + return cls.model_validate(data) def render_template_value(self, **kwargs: object) -> None: """ diff --git a/pyrit/models/seeds/seed_group.py b/pyrit/models/seeds/seed_group.py index 7397e26be1..80e26e17db 100644 --- a/pyrit/models/seeds/seed_group.py +++ b/pyrit/models/seeds/seed_group.py @@ -13,9 +13,9 @@ import logging import uuid from collections import defaultdict -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Annotated, Any, Optional, Union -from pydantic import BaseModel, ConfigDict, SerializeAsAny, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from pyrit.models.message import Message from pyrit.models.message_piece import MessagePiece @@ -29,6 +29,12 @@ logger = logging.getLogger(__name__) +# Concrete leaf classes for the polymorphic seed list. Use SeedLeaf when you need a plain +# typing alias (e.g. for local list[...] variables); use SeedUnion in Pydantic field annotations +# so the discriminator is honored during validation. +SeedLeaf = Union[SeedPrompt, SeedObjective, SeedSimulatedConversation] +SeedUnion = Annotated[SeedLeaf, Field(discriminator="seed_type")] + class SeedGroup(BaseModel): """ @@ -45,19 +51,22 @@ class SeedGroup(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - seeds: list[SerializeAsAny[Seed]] + seeds: list[SeedUnion] @model_validator(mode="before") @classmethod def _coerce_seeds(cls, data: Any) -> Any: """ - Coerce raw seed dicts into concrete Seed subclasses before validation. + Normalize dict seed inputs so the polymorphic discriminator can dispatch. - ``is_jinja_template`` is a construction-time flag (not a stored field): it is consumed - here and propagated to each dict seed so trusted YAML values are rendered as templates. + Concrete Seed instances pass through; dicts are tagged with a default + ``seed_type="prompt"`` when missing and have the construction-time + ``is_jinja_template`` flag propagated in. ``data_type`` is stripped for + non-prompt seeds because they narrow it to ``Literal["text"]``; dataset-level + defaults must not bleed into them. Returns: - Any: The input data with ``seeds`` replaced by concrete Seed instances. + The data with normalized ``seeds`` (passes through unchanged if not a dict). Raises: ValueError: If the group has no seeds or a seed has an unsupported type. @@ -71,29 +80,27 @@ def _coerce_seeds(cls, data: Any) -> Any: if not raw_seeds: raise ValueError("SeedGroup cannot be empty.") - coerced: list[Seed] = [] + normalized: list[Any] = [] for seed in raw_seeds: if isinstance(seed, Seed): - coerced.append(seed) - elif isinstance(seed, dict): - seed = dict(seed) - seed["is_jinja_template"] = is_jinja_template - seed_type = seed.pop("seed_type", None) - - if seed_type == "simulated_conversation": - # SeedSimulatedConversation doesn't use data_type (always text) - seed.pop("data_type", None) - coerced.append(SeedSimulatedConversation.from_dict(seed)) - elif seed_type == "objective": - # SeedObjective doesn't use data_type (always text) - seed.pop("data_type", None) - coerced.append(SeedObjective(**seed)) - else: - coerced.append(SeedPrompt(**seed)) - else: + normalized.append(seed) + continue + if not isinstance(seed, dict): raise ValueError(f"Invalid seed type: {type(seed)}") - - data["seeds"] = coerced + seed = dict(seed) + seed.setdefault("seed_type", "prompt") + seed["is_jinja_template"] = is_jinja_template + if seed["seed_type"] == "prompt": + seed.setdefault("role", "user") + else: + # Non-prompt seeds narrow data_type to Literal["text"] and don't have + # role/sequence/parameters fields. Drop them so dataset/group-level + # defaults don't bleed in and trip extra="forbid". + for prompt_only in ("data_type", "role", "sequence", "parameters"): + seed.pop(prompt_only, None) + normalized.append(seed) + + data["seeds"] = normalized return data @model_validator(mode="after") @@ -106,13 +113,13 @@ def _finalize(self) -> SeedGroup: Returns: SeedGroup: The validated, reordered group. """ - self.validate() + self._check_invariants() objective = self._get_objective() simulated_conv = self._get_simulated_conversation() sorted_prompts = sorted(self.prompts, key=lambda p: p.sequence if p.sequence is not None else 0) - new_seeds: list[Seed] = [] + new_seeds: list[SeedLeaf] = [] if objective: new_seeds.append(objective) if simulated_conv: @@ -125,13 +132,13 @@ def _finalize(self) -> SeedGroup: # Validation # ========================================================================= - def validate(self) -> None: + def _check_invariants(self) -> None: """ Validate the seed group state. - This method can be called after external modifications to seeds - to ensure the group remains in a valid state. It is automatically - called during initialization. + Renamed from ``validate`` because that name shadows ``BaseModel.validate`` and + would silently return ``None`` instead of constructing when called on the class. + Subclasses override this hook to add stronger invariants. Raises: ValueError: If validation fails. diff --git a/pyrit/models/seeds/seed_objective.py b/pyrit/models/seeds/seed_objective.py index 26b59f1f1e..9f137c24b9 100644 --- a/pyrit/models/seeds/seed_objective.py +++ b/pyrit/models/seeds/seed_objective.py @@ -8,7 +8,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union from pydantic import model_validator @@ -24,6 +24,13 @@ class SeedObjective(Seed): """Represents a seed objective with various attributes and metadata.""" + # Discriminator field for the polymorphic Seed union (see seed_group.SeedUnion). + seed_type: Literal["objective"] = "objective" + + # Objectives are always text. Narrowing the base field rejects non-text values up-front + # rather than silently dropping them downstream. + data_type: Literal["text"] = "text" + @model_validator(mode="after") def _validate_and_render(self) -> SeedObjective: """ diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index 4d8312eae0..862017f729 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -9,9 +9,9 @@ import logging import os -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union -from pydantic import Field, field_validator, model_validator +from pydantic import Field, model_validator from tinytag import TinyTag from pyrit.common.path import PATHS_DICT @@ -20,7 +20,7 @@ ChatMessageRole, PromptDataType, ) -from pyrit.models.seeds.seed import Seed, coerce_str_to_list +from pyrit.models.seeds.seed import Seed, StrOrList if TYPE_CHECKING: import uuid @@ -34,6 +34,9 @@ class SeedPrompt(Seed): """Represents a seed prompt with various attributes and metadata.""" + # Discriminator field for the polymorphic Seed union (see seed_group.SeedUnion). + seed_type: Literal["prompt"] = "prompt" + # The type of data this prompt represents (e.g., text, image_path, audio_path, video_path) # This field overrides the base default to allow per-prompt data types inferred from the value data_type: Optional[PromptDataType] = None @@ -46,21 +49,7 @@ class SeedPrompt(Seed): sequence: int = 0 # Parameters that can be used in the prompt template - parameters: Optional[list[str]] = Field(default_factory=list) - - @field_validator("parameters", mode="before") - @classmethod - def _coerce_parameters_to_list(cls, value: object) -> object: - """ - Coerce a bare string ``parameters`` value into a single-element list. - - Args: - value: The raw field value provided during validation. - - Returns: - The value wrapped in a list if it was a bare string, otherwise unchanged. - """ - return coerce_str_to_list(value) + parameters: Optional[StrOrList] = Field(default_factory=list) @model_validator(mode="after") def _render_and_infer_data_type(self) -> SeedPrompt: diff --git a/pyrit/models/seeds/seed_simulated_conversation.py b/pyrit/models/seeds/seed_simulated_conversation.py index 9db08f0b44..d23a3891e7 100644 --- a/pyrit/models/seeds/seed_simulated_conversation.py +++ b/pyrit/models/seeds/seed_simulated_conversation.py @@ -19,7 +19,7 @@ import json import logging from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union from pydantic import field_validator, model_validator @@ -68,7 +68,16 @@ class SeedSimulatedConversation(Seed): """ - # value is computed from the config in the after-validator; it must not be supplied directly. + # Discriminator field for the polymorphic Seed union (see seed_group.SeedUnion). + seed_type: Literal["simulated_conversation"] = "simulated_conversation" + + # Simulated conversations are always text. Narrowing the base field rejects non-text values + # up-front rather than silently dropping them downstream. + data_type: Literal["text"] = "text" + + # value is computed from the config in the after-validator. The base default of "" plus a + # before-validator that strips any user-supplied value keeps round-trips clean: a dumped + # value comes back in, is dropped, then is recomputed (and matches if the config matches). value: str = "" # Simulated conversations are general techniques by default. @@ -81,6 +90,22 @@ class SeedSimulatedConversation(Seed): next_message_system_prompt_path: Optional[Path] = None pyrit_version: Optional[str] = None + @model_validator(mode="before") + @classmethod + def _strip_user_value(cls, data: Any) -> Any: + """ + Drop any user-supplied ``value`` from dict input; it is always recomputed in the + after-validator. This keeps round-tripping clean and makes the API honest about the + fact that ``value`` is a derived JSON serialization of the config. + + Returns: + The data with ``value`` removed if it was a dict; otherwise the input unchanged. + """ + if isinstance(data, dict) and "value" in data: + data = dict(data) + data.pop("value", None) + return data + @field_validator("simulated_target_system_prompt_path", mode="before") @classmethod def _default_simulated_target_path(cls, value: Any) -> Any: @@ -120,40 +145,6 @@ def _compute_value(self) -> str: } return json.dumps(config, sort_keys=True, separators=(",", ":")) - @classmethod - def from_dict(cls, data: dict[str, Any]) -> SeedSimulatedConversation: - """ - Create a SeedSimulatedConversation from a dictionary, typically from YAML. - - Expected format: - num_turns: 3 - adversarial_chat_system_prompt_path: path/to/adversarial.yaml - simulated_target_system_prompt_path: path/to/simulated.yaml # optional - - Args: - data: Dictionary containing the configuration. - - Returns: - A new SeedSimulatedConversation instance. - - Raises: - ValueError: If required configuration fields are missing. - - """ - adversarial_path = data.get("adversarial_chat_system_prompt_path") - if not adversarial_path: - raise ValueError("adversarial_chat_system_prompt_path is required") - - return cls( - num_turns=data.get("num_turns", 3), - sequence=data.get("sequence", 0), - adversarial_chat_system_prompt_path=adversarial_path, - simulated_target_system_prompt_path=( - data.get("simulated_target_system_prompt_path") or SimulatedTargetSystemPromptPaths.COMPLIANT.value - ), - next_message_system_prompt_path=data.get("next_message_system_prompt_path"), - ) - @classmethod def from_yaml_with_required_parameters( cls, diff --git a/pyrit/scenario/core/atomic_attack.py b/pyrit/scenario/core/atomic_attack.py index 652e203489..ec588e79ff 100644 --- a/pyrit/scenario/core/atomic_attack.py +++ b/pyrit/scenario/core/atomic_attack.py @@ -114,7 +114,7 @@ def __init__( # Validate each seed group to ensure they are in a valid state for sg in seed_groups: - sg.validate() + sg._check_invariants() self._seed_groups = seed_groups self._validate_unique_objective_hashes() diff --git a/tests/unit/models/test_seed.py b/tests/unit/models/test_seed.py index 3cb4e4dcb1..c2f26cf75d 100644 --- a/tests/unit/models/test_seed.py +++ b/tests/unit/models/test_seed.py @@ -16,12 +16,13 @@ from pyrit.models import ( Message, MessagePiece, + SeedAttackGroup, SeedDataset, SeedGroup, SeedObjective, SeedPrompt, ) -from pyrit.models.seeds import SeedSimulatedConversation +from pyrit.models.seeds import SeedSimulatedConversation, SimulatedTargetSystemPromptPaths @pytest.fixture @@ -1141,6 +1142,65 @@ def test_seed_group_preserves_polymorphic_subclasses(): assert group.next_message.get_value() == "Hello" +def test_seed_group_round_trip_preserves_subclasses(): + """model_validate(model_dump()) must reconstruct the original subclass instances, not coerce to base Seed.""" + objective = SeedObjective(value="goal") + prompt = SeedPrompt(value="hi", role="user", sequence=0) + group = SeedGroup(seeds=[objective, prompt]) + + rt = SeedGroup.model_validate(group.model_dump()) + + assert [type(s).__name__ for s in rt.seeds] == [type(s).__name__ for s in group.seeds] + assert isinstance(rt.objective, SeedObjective) + assert rt.objective.value == "goal" + + +def test_seed_attack_group_round_trip_preserves_subclasses(): + """The original blocking bug: SeedAttackGroup(model_validate(model_dump())) must work.""" + sag = SeedAttackGroup( + seeds=[ + SeedObjective(value="objective"), + SeedPrompt(value="hi", data_type="text", role="user", sequence=0), + ] + ) + rt = SeedAttackGroup.model_validate(sag.model_dump()) + assert isinstance(rt.objective, SeedObjective) + assert rt.objective.value == "objective" + assert any(isinstance(s, SeedPrompt) for s in rt.seeds) + + +def test_seed_simulated_conversation_round_trip(): + """SeedSimulatedConversation's computed ``value`` round-trips through both python and json modes.""" + sc = SeedSimulatedConversation( + adversarial_chat_system_prompt_path=SimulatedTargetSystemPromptPaths.COMPLIANT.value, + num_turns=4, + ) + rt_py = SeedSimulatedConversation.model_validate(sc.model_dump()) + rt_json = SeedSimulatedConversation.model_validate(sc.model_dump(mode="json")) + assert rt_py.value == sc.value + assert rt_json.value == sc.value + assert rt_py.num_turns == 4 + + +def test_seed_objective_rejects_non_text_data_type(): + """SeedObjective locks ``data_type`` to ``"text"``.""" + from pydantic import ValidationError + + with pytest.raises(ValidationError): + SeedObjective(value="goal", data_type="image_path") + + +def test_seed_simulated_conversation_rejects_non_text_data_type(): + """SeedSimulatedConversation locks ``data_type`` to ``"text"``.""" + from pydantic import ValidationError + + with pytest.raises(ValidationError): + SeedSimulatedConversation( + adversarial_chat_system_prompt_path=SimulatedTargetSystemPromptPaths.COMPLIANT.value, + data_type="image_path", + ) + + def test_next_message_single_turn_with_objective(): """Test next_message property for a single-turn SeedGroup with an objective.""" prompt = SeedPrompt(value="Hello", data_type="text", sequence=0, role="user") diff --git a/tests/unit/models/test_seed_attack_technique_group.py b/tests/unit/models/test_seed_attack_technique_group.py index 05c21a62e3..8c01a166d2 100644 --- a/tests/unit/models/test_seed_attack_technique_group.py +++ b/tests/unit/models/test_seed_attack_technique_group.py @@ -148,7 +148,7 @@ def test_validate_all_general_strategy_passes(self): ] ) # Should not raise - group.validate() + group._check_invariants() def test_error_message_includes_non_general_types(self): """Test that error message lists the types of non-general seeds.""" diff --git a/tests/unit/models/test_seed_group.py b/tests/unit/models/test_seed_group.py index 4a8e43a2c1..8c7e08a811 100644 --- a/tests/unit/models/test_seed_group.py +++ b/tests/unit/models/test_seed_group.py @@ -610,7 +610,7 @@ def test_merged_group_is_valid_seed_attack_group(self): merged = base.with_technique(technique=technique) assert isinstance(merged, SeedAttackGroup) - merged.validate() # should not raise + merged._check_invariants() # should not raise def test_raises_when_technique_has_simulated_conversation_and_prompts_overlap(self): """Merging a technique with SeedSimulatedConversation into a group with overlapping prompts raises.""" diff --git a/tests/unit/models/test_seed_simulated_conversation.py b/tests/unit/models/test_seed_simulated_conversation.py index 8afe06f914..c8239a9caa 100644 --- a/tests/unit/models/test_seed_simulated_conversation.py +++ b/tests/unit/models/test_seed_simulated_conversation.py @@ -174,11 +174,11 @@ def test_init_next_message_system_prompt_path_set(self, tmp_path): assert conv.next_message_system_prompt_path == next_msg_path -class TestSeedSimulatedConversationFromDict: - """Tests for SeedSimulatedConversation.from_dict method.""" +class TestSeedSimulatedConversationFromMapping: + """Tests for constructing SeedSimulatedConversation from a dict via ``model_validate``.""" def test_from_dict_with_paths(self, tmp_path): - """Test from_dict with path values.""" + """Test construction from a dict with path values.""" adv_path = tmp_path / "adversarial.yaml" adv_path.write_text("value: test\ndata_type: text") @@ -186,13 +186,13 @@ def test_from_dict_with_paths(self, tmp_path): "num_turns": 5, "adversarial_chat_system_prompt_path": str(adv_path), } - conv = SeedSimulatedConversation.from_dict(data) + conv = SeedSimulatedConversation.model_validate(data) assert conv.num_turns == 5 assert conv.adversarial_chat_system_prompt_path == adv_path def test_from_dict_without_simulated_target_path(self, tmp_path): - """Test from_dict without simulated_target_system_prompt_path uses compliant default.""" + """Test construction without simulated_target_system_prompt_path uses compliant default.""" adv_path = tmp_path / "adversarial.yaml" adv_path.write_text("value: test\ndata_type: text") @@ -200,29 +200,29 @@ def test_from_dict_without_simulated_target_path(self, tmp_path): "num_turns": 3, "adversarial_chat_system_prompt_path": str(adv_path), } - conv = SeedSimulatedConversation.from_dict(data) + conv = SeedSimulatedConversation.model_validate(data) # Default simulated_target_system_prompt_path is the compliant prompt assert conv.simulated_target_system_prompt_path == SimulatedTargetSystemPromptPaths.COMPLIANT.value def test_from_dict_default_num_turns(self, tmp_path): - """Test from_dict uses default num_turns when not specified.""" + """Test that num_turns defaults to 3 when not specified.""" adv_path = tmp_path / "adversarial.yaml" adv_path.write_text("value: test\ndata_type: text") data = { "adversarial_chat_system_prompt_path": str(adv_path), } - conv = SeedSimulatedConversation.from_dict(data) + conv = SeedSimulatedConversation.model_validate(data) assert conv.num_turns == 3 def test_from_dict_missing_adversarial_path_raises_error(self): - """Test that from_dict raises error when adversarial path is missing.""" + """Test that construction raises when adversarial path is missing (required field).""" data = {"num_turns": 3} - with pytest.raises(ValueError, match="adversarial_chat_system_prompt_path is required"): - SeedSimulatedConversation.from_dict(data) + with pytest.raises(ValueError, match="adversarial_chat_system_prompt_path"): + SeedSimulatedConversation.model_validate(data) class TestSeedSimulatedConversationGetIdentifier: From d9abfa356e7b1e651f465d3ef7b7ced7492e5d1d Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 2 Jun 2026 15:15:35 -0700 Subject: [PATCH 03/11] Phase 9.1: split YAML loading out of seed classes (Option A) Move yaml-loading, path-resolution, and the is_jinja_template=True trust marker from inside Seed/SeedDataset classmethods into a dedicated seed_loader module. - New pyrit/models/seeds/seed_loader.py exposes load_seed_from_yaml, load_seed_dataset_from_yaml, and load_seed_prompt_from_yaml_with_required_parameters. - Seed.from_yaml_file, SeedDataset.from_yaml_file, and SeedPrompt.from_yaml_with_required_parameters reduced to thin shims that delegate to the loader functions, so all ~70 existing callsites keep working unchanged. - Deleted SeedObjective.from_yaml_with_required_parameters (no-op shim, no callers), SeedSimulatedConversation.from_yaml_with_required_parameters (no callers), and the base Seed.from_yaml_with_required_parameters (only SeedPrompt's real validation is left, where it is actually used). - yaml and verify_and_resolve_path no longer imported in seed.py / seed_dataset.py. - Stricter loader validation: empty files and top-level non-mappings now raise ValueError with a clear message rather than cryptic TypeErrors. - New tests/unit/models/test_seed_loader.py (14 tests) covers the trust-marker behavior, error paths, dataset propagation, and classmethod-shim equivalence. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/seeds/__init__.py | 18 ++- pyrit/models/seeds/seed.py | 42 +---- pyrit/models/seeds/seed_dataset.py | 21 +-- pyrit/models/seeds/seed_loader.py | 137 ++++++++++++++++ pyrit/models/seeds/seed_objective.py | 27 +--- pyrit/models/seeds/seed_prompt.py | 19 ++- .../seeds/seed_simulated_conversation.py | 32 ---- tests/unit/models/test_seed_loader.py | 149 ++++++++++++++++++ 8 files changed, 324 insertions(+), 121 deletions(-) create mode 100644 pyrit/models/seeds/seed_loader.py create mode 100644 tests/unit/models/test_seed_loader.py diff --git a/pyrit/models/seeds/__init__.py b/pyrit/models/seeds/__init__.py index 6130580d63..b89f9aed13 100644 --- a/pyrit/models/seeds/__init__.py +++ b/pyrit/models/seeds/__init__.py @@ -20,6 +20,11 @@ from pyrit.models.seeds.seed_attack_technique_group import SeedAttackTechniqueGroup from pyrit.models.seeds.seed_dataset import SeedDataset from pyrit.models.seeds.seed_group import SeedGroup +from pyrit.models.seeds.seed_loader import ( + load_seed_dataset_from_yaml, + load_seed_from_yaml, + load_seed_prompt_from_yaml_with_required_parameters, +) from pyrit.models.seeds.seed_objective import SeedObjective from pyrit.models.seeds.seed_prompt import SeedPrompt from pyrit.models.seeds.seed_simulated_conversation import ( @@ -29,14 +34,17 @@ ) __all__ = [ + "NextMessageSystemPromptPaths", "Seed", - "SeedPrompt", - "SeedObjective", - "SeedGroup", "SeedAttackGroup", "SeedAttackTechniqueGroup", + "SeedDataset", + "SeedGroup", + "SeedObjective", + "SeedPrompt", "SeedSimulatedConversation", "SimulatedTargetSystemPromptPaths", - "NextMessageSystemPromptPaths", - "SeedDataset", + "load_seed_dataset_from_yaml", + "load_seed_from_yaml", + "load_seed_prompt_from_yaml_with_required_parameters", ] diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index 3f34fffedb..857c641350 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -15,12 +15,10 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Annotated, Any, Optional, TypeVar, Union -import yaml from jinja2 import StrictUndefined, Undefined from jinja2.sandbox import SandboxedEnvironment from pydantic import BaseModel, BeforeValidator, ConfigDict, Field -from pyrit.common.utils import verify_and_resolve_path from pyrit.models.literals import PromptDataType # noqa: TC001 (runtime-required by Pydantic field annotations) if TYPE_CHECKING: @@ -267,6 +265,9 @@ def from_yaml_file(cls: type[T], file: Union[str, Path]) -> T: """ Create a new Seed from a YAML file, marking it as a trusted Jinja2 template. + Thin shim that delegates to :func:`pyrit.models.seeds.seed_loader.load_seed_from_yaml`; + file I/O and the ``is_jinja_template`` trust marker live in the loader module. + Args: file: The input file path. @@ -274,38 +275,9 @@ def from_yaml_file(cls: type[T], file: Union[str, Path]) -> T: A new Seed of the specific subclass type. Raises: - ValueError: If the YAML file is invalid. - """ - file = verify_and_resolve_path(file) - - try: - yaml_data = yaml.safe_load(file.read_text("utf-8")) - except yaml.YAMLError as exc: - raise ValueError(f"Invalid YAML file '{file}': {exc}") from exc - - yaml_data["is_jinja_template"] = True - return cls(**yaml_data) - - @classmethod - def from_yaml_with_required_parameters( - cls, - template_path: Union[str, Path], - required_parameters: list[str], - error_message: Optional[str] = None, - ) -> Seed: + FileNotFoundError: If the path does not resolve to an existing file. + ValueError: If the YAML file is invalid or empty. """ - Load a Seed from a YAML file and validate that it contains specific parameters. + from pyrit.models.seeds.seed_loader import load_seed_from_yaml - The base implementation simply loads the file; subclasses that support parameters - (e.g. SeedPrompt) override this to enforce ``required_parameters``. - - Args: - template_path: Path to the YAML file containing the template. - required_parameters: List of parameter names that must exist in the template. - error_message: Custom error message if validation fails. If None, a default message is used. - - Returns: - Seed: The loaded and validated seed of the specific subclass type. - - """ - return cls.from_yaml_file(template_path) + return load_seed_from_yaml(file, cls=cls) diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index 344a28b8f0..d8dc6de7c6 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -14,10 +14,8 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Optional, Union -import yaml from pydantic import BaseModel, ConfigDict, Field, model_validator -from pyrit.common.utils import verify_and_resolve_path from pyrit.models.literals import SeedType # noqa: TC001 (runtime-required by Pydantic field annotations) from pyrit.models.seeds.seed import Seed, StrOrList from pyrit.models.seeds.seed_attack_group import SeedAttackGroup @@ -181,6 +179,10 @@ def from_yaml_file(cls, file: Union[str, Path]) -> SeedDataset: """ Create a SeedDataset from a YAML file, marking nested seeds as trusted templates. + Thin shim that delegates to + :func:`pyrit.models.seeds.seed_loader.load_seed_dataset_from_yaml`; file I/O and + the ``is_jinja_template`` trust marker live in the loader module. + Args: file: The input file path. @@ -188,19 +190,12 @@ def from_yaml_file(cls, file: Union[str, Path]) -> SeedDataset: SeedDataset: The loaded dataset. Raises: - ValueError: If the YAML file is invalid. + FileNotFoundError: If the path does not resolve to an existing file. + ValueError: If the YAML file is invalid or empty. """ - file = verify_and_resolve_path(file) - try: - yaml_data = yaml.safe_load(file.read_text("utf-8")) - except yaml.YAMLError as exc: - raise ValueError(f"Invalid YAML file '{file}': {exc}") from exc - - if yaml_data is None: - raise ValueError(f"YAML file '{file}' is empty.") + from pyrit.models.seeds.seed_loader import load_seed_dataset_from_yaml - yaml_data["is_jinja_template"] = True - return cls.from_dict(yaml_data) + return load_seed_dataset_from_yaml(file) def get_values( self, diff --git a/pyrit/models/seeds/seed_loader.py b/pyrit/models/seeds/seed_loader.py new file mode 100644 index 0000000000..fc38d0fae7 --- /dev/null +++ b/pyrit/models/seeds/seed_loader.py @@ -0,0 +1,137 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +YAML loaders for seed types. + +These functions live separately from the seed classes themselves because the +*trust claim* that a value came from a vetted local YAML file (vs. an untrusted +remote dataset) is a property of the loader, not of the data class. A ``Seed`` +instance can't know its own provenance; the loader can, so the +``is_jinja_template=True`` marker is set exactly once at this boundary. + +The ``from_yaml_file`` and ``from_yaml_with_required_parameters`` classmethods +on the seed classes are thin shims that delegate here. See the plan-gist +Phase 9.1 entry for the eventual move of this module to ``pyrit/io/seeds.py``. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union + +import yaml + +from pyrit.common.utils import verify_and_resolve_path + +if TYPE_CHECKING: + from pathlib import Path + + from pyrit.models.seeds.seed import Seed + from pyrit.models.seeds.seed_dataset import SeedDataset + from pyrit.models.seeds.seed_prompt import SeedPrompt + +T = TypeVar("T", bound="Seed") + + +def _read_yaml(file: Union[str, Path]) -> dict[str, Any]: + """ + Resolve, read, and parse a YAML file as a mapping. + + Args: + file: Path to a YAML file. + + Returns: + The parsed top-level mapping. + + Raises: + FileNotFoundError: If the path does not resolve to an existing file. + ValueError: If the YAML is malformed or empty. + """ + file = verify_and_resolve_path(file) + try: + data = yaml.safe_load(file.read_text("utf-8")) + except yaml.YAMLError as exc: + raise ValueError(f"Invalid YAML file '{file}': {exc}") from exc + + if data is None: + raise ValueError(f"YAML file '{file}' is empty.") + if not isinstance(data, dict): + raise ValueError(f"YAML file '{file}' must contain a mapping at the top level.") + return data + + +def load_seed_from_yaml(file: Union[str, Path], *, cls: type[T]) -> T: + """ + Load a single seed of type ``cls`` from a YAML file. + + The seed is marked ``is_jinja_template=True`` because the file is treated + as a trusted, vetted local template at this boundary. + + Args: + file: Path to the YAML file containing the seed definition. + cls: Seed subclass to instantiate (e.g. ``SeedPrompt``, ``SeedObjective``). + + Returns: + An instance of ``cls`` populated from the YAML file. + + Raises: + FileNotFoundError: If the path does not resolve to an existing file. + ValueError: If the YAML is malformed, empty, or fails validation for ``cls``. + """ + data = _read_yaml(file) + data["is_jinja_template"] = True + return cls(**data) + + +def load_seed_dataset_from_yaml(file: Union[str, Path]) -> SeedDataset: + """ + Load a ``SeedDataset`` from a YAML file. + + Nested seeds inherit the ``is_jinja_template=True`` trust marker set at this + boundary; per-seed overrides in the YAML are intentionally ignored. + + Args: + file: Path to the YAML file containing the dataset definition. + + Returns: + A ``SeedDataset`` populated from the YAML file. + + Raises: + FileNotFoundError: If the path does not resolve to an existing file. + ValueError: If the YAML is malformed, empty, or fails dataset validation. + """ + from pyrit.models.seeds.seed_dataset import SeedDataset + + data = _read_yaml(file) + data["is_jinja_template"] = True + return SeedDataset.from_dict(data) + + +def load_seed_prompt_from_yaml_with_required_parameters( + template_path: Union[str, Path], + required_parameters: list[str], + *, + error_message: Optional[str] = None, +) -> SeedPrompt: + """ + Load a ``SeedPrompt`` and assert that its ``parameters`` field declares each required name. + + Args: + template_path: Path to the YAML file containing the prompt template. + required_parameters: Parameter names that must appear in ``SeedPrompt.parameters``. + error_message: Optional custom message used in the raised ``ValueError``. + + Returns: + The loaded ``SeedPrompt``. + + Raises: + ValueError: If the loaded prompt is missing any required parameter. + """ + from pyrit.models.seeds.seed_prompt import SeedPrompt + + sp = load_seed_from_yaml(template_path, cls=SeedPrompt) + if sp.parameters is None or not all(p in sp.parameters for p in required_parameters): + if error_message is None: + error_message = f"Template must have these parameters: {', '.join(required_parameters)}" + raise ValueError(f"{error_message}: '{sp}'") + return sp diff --git a/pyrit/models/seeds/seed_objective.py b/pyrit/models/seeds/seed_objective.py index 9f137c24b9..dc0588d3f1 100644 --- a/pyrit/models/seeds/seed_objective.py +++ b/pyrit/models/seeds/seed_objective.py @@ -8,16 +8,13 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Literal, Optional, Union +from typing import Literal from pydantic import model_validator from pyrit.common.path import PATHS_DICT from pyrit.models.seeds.seed import Seed -if TYPE_CHECKING: - from pathlib import Path - logger = logging.getLogger(__name__) @@ -48,25 +45,3 @@ def _validate_and_render(self) -> SeedObjective: if self.is_jinja_template: self.value = self.render_template_value_silent(**PATHS_DICT) return self - - @classmethod - def from_yaml_with_required_parameters( - cls, - template_path: Union[str, Path], - required_parameters: list[str], - error_message: Optional[str] = None, - ) -> SeedObjective: - """ - Load a Seed from a YAML file. Because SeedObjectives do not have any parameters, the required_parameters - and error_message arguments are unused. - - Args: - template_path: Path to the YAML file containing the template. - required_parameters: List of parameter names that must exist in the template. - error_message: Custom error message if validation fails. If None, a default message is used. - - Returns: - SeedObjective: The loaded and validated seed of the specific subclass type. - - """ - return cls.from_yaml_file(template_path) diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index 862017f729..6251f78c7b 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -135,7 +135,10 @@ def from_yaml_with_required_parameters( error_message: Optional[str] = None, ) -> SeedPrompt: """ - Load a Seed from a YAML file and validate that it contains specific parameters. + Load a SeedPrompt from a YAML file and validate that it declares each required parameter. + + Thin shim that delegates to + :func:`pyrit.models.seeds.seed_loader.load_seed_prompt_from_yaml_with_required_parameters`. Args: template_path: Path to the YAML file containing the template. @@ -143,20 +146,16 @@ def from_yaml_with_required_parameters( error_message: Custom error message if validation fails. If None, a default message is used. Returns: - SeedPrompt: The loaded and validated SeedPrompt of the specific subclass type. + SeedPrompt: The loaded and validated SeedPrompt. Raises: ValueError: If the template doesn't contain all required parameters. - """ - sp = cls.from_yaml_file(template_path) - - if sp.parameters is None or not all(param in sp.parameters for param in required_parameters): - if error_message is None: - error_message = f"Template must have these parameters: {', '.join(required_parameters)}" - raise ValueError(f"{error_message}: '{sp}'") + from pyrit.models.seeds.seed_loader import load_seed_prompt_from_yaml_with_required_parameters - return sp + return load_seed_prompt_from_yaml_with_required_parameters( + template_path, required_parameters, error_message=error_message + ) @staticmethod def from_messages( diff --git a/pyrit/models/seeds/seed_simulated_conversation.py b/pyrit/models/seeds/seed_simulated_conversation.py index d23a3891e7..b1359445c8 100644 --- a/pyrit/models/seeds/seed_simulated_conversation.py +++ b/pyrit/models/seeds/seed_simulated_conversation.py @@ -145,38 +145,6 @@ def _compute_value(self) -> str: } return json.dumps(config, sort_keys=True, separators=(",", ":")) - @classmethod - def from_yaml_with_required_parameters( - cls, - template_path: Union[str, Path], - required_parameters: list[str], - error_message: Optional[str] = None, - ) -> SeedSimulatedConversation: - """ - Load a SeedSimulatedConversation from a YAML file and validate required parameters. - - Args: - template_path: Path to the YAML file containing the config. - required_parameters: List of parameter names that must exist. - error_message: Custom error message if validation fails. - - Returns: - The loaded and validated SeedSimulatedConversation. - - Raises: - ValueError: If required parameters are missing. - - """ - instance = cls.from_yaml_file(template_path) - - # Check required parameters - for param in required_parameters: - if not hasattr(instance, param) or getattr(instance, param) is None: - msg = error_message or f"Missing required parameter: {param}" - raise ValueError(msg) - - return instance - def get_identifier(self) -> dict[str, Any]: """ Get an identifier dict capturing this configuration for comparison/storage. diff --git a/tests/unit/models/test_seed_loader.py b/tests/unit/models/test_seed_loader.py new file mode 100644 index 0000000000..7f7f5fec91 --- /dev/null +++ b/tests/unit/models/test_seed_loader.py @@ -0,0 +1,149 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from pyrit.models.seeds import ( + SeedDataset, + SeedObjective, + SeedPrompt, + load_seed_dataset_from_yaml, + load_seed_from_yaml, + load_seed_prompt_from_yaml_with_required_parameters, +) + + +def _write(tmp_path, name, content): + p = tmp_path / name + p.write_text(content, encoding="utf-8") + return p + + +def test_load_seed_from_yaml_returns_typed_seed(tmp_path): + yaml_file = _write(tmp_path, "p.yaml", "value: hello world\ndata_type: text\n") + + loaded = load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + assert isinstance(loaded, SeedPrompt) + assert loaded.value == "hello world" + # Loader is the trust boundary — it sets the jinja-template marker exactly once. + assert loaded.is_jinja_template is True + + +def test_load_seed_from_yaml_supports_objective(tmp_path): + yaml_file = _write(tmp_path, "o.yaml", "value: stop the attacker\n") + + loaded = load_seed_from_yaml(yaml_file, cls=SeedObjective) + + assert isinstance(loaded, SeedObjective) + assert loaded.value == "stop the attacker" + assert loaded.is_jinja_template is True + + +def test_load_seed_from_yaml_overrides_in_file_value(tmp_path): + # An ``is_jinja_template: false`` in the YAML must not let the file claim it is untrusted — + # the loader's trust claim wins. + yaml_file = _write(tmp_path, "p.yaml", "value: x\nis_jinja_template: false\n") + + loaded = load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + assert loaded.is_jinja_template is True + + +def test_load_seed_from_yaml_accepts_string_path(tmp_path): + yaml_file = _write(tmp_path, "p.yaml", "value: x\n") + + loaded = load_seed_from_yaml(str(yaml_file), cls=SeedPrompt) + + assert loaded.value == "x" + + +def test_load_seed_from_yaml_missing_file_raises(tmp_path): + with pytest.raises(FileNotFoundError): + load_seed_from_yaml(tmp_path / "missing.yaml", cls=SeedPrompt) + + +def test_load_seed_from_yaml_empty_file_raises(tmp_path): + yaml_file = _write(tmp_path, "empty.yaml", "") + + with pytest.raises(ValueError, match="is empty"): + load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + +def test_load_seed_from_yaml_top_level_list_raises(tmp_path): + yaml_file = _write(tmp_path, "list.yaml", "- a\n- b\n") + + with pytest.raises(ValueError, match="must contain a mapping"): + load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + +def test_load_seed_from_yaml_invalid_yaml_raises(tmp_path): + yaml_file = _write(tmp_path, "bad.yaml", "value: [unterminated\n") + + with pytest.raises(ValueError, match="Invalid YAML file"): + load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + +def test_load_seed_dataset_from_yaml_marks_seeds_as_trusted(tmp_path): + yaml_file = _write( + tmp_path, + "ds.yaml", + "name: tiny\nseeds:\n - value: a\n - value: b\n", + ) + + dataset = load_seed_dataset_from_yaml(yaml_file) + + assert isinstance(dataset, SeedDataset) + assert [s.value for s in dataset.seeds] == ["a", "b"] + # Trust marker propagates from the loader to each nested seed. + assert all(s.is_jinja_template for s in dataset.seeds) + + +def test_load_seed_dataset_from_yaml_empty_raises(tmp_path): + yaml_file = _write(tmp_path, "empty.yaml", "") + + with pytest.raises(ValueError, match="is empty"): + load_seed_dataset_from_yaml(yaml_file) + + +def test_load_seed_prompt_from_yaml_with_required_parameters_succeeds(tmp_path): + yaml_file = _write( + tmp_path, + "t.yaml", + "value: hello {{ name }}\nparameters:\n - name\n", + ) + + loaded = load_seed_prompt_from_yaml_with_required_parameters(yaml_file, ["name"]) + + assert isinstance(loaded, SeedPrompt) + assert loaded.parameters == ["name"] + + +def test_load_seed_prompt_from_yaml_with_required_parameters_missing_raises(tmp_path): + yaml_file = _write( + tmp_path, + "t.yaml", + "value: hello {{ name }}\nparameters:\n - name\n", + ) + + with pytest.raises(ValueError, match="Template must have these parameters: name, age"): + load_seed_prompt_from_yaml_with_required_parameters(yaml_file, ["name", "age"]) + + +def test_load_seed_prompt_from_yaml_with_required_parameters_custom_error(tmp_path): + yaml_file = _write(tmp_path, "t.yaml", "value: no params here\n") + + with pytest.raises(ValueError, match="bespoke"): + load_seed_prompt_from_yaml_with_required_parameters(yaml_file, ["foo"], error_message="bespoke") + + +def test_classmethod_shims_delegate_to_loader(tmp_path): + # Verifies the public classmethod surface (Seed.from_yaml_file et al.) + # produces the same result as the loader functions — i.e. the shims are honest. + yaml_file = _write(tmp_path, "p.yaml", "value: x\n") + + via_classmethod = SeedPrompt.from_yaml_file(yaml_file) + via_function = load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + assert via_classmethod.value == via_function.value + assert via_classmethod.is_jinja_template == via_function.is_jinja_template is True From 5648e2eb75c4bd16cd2ff9d605d4074d8c5363fa Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 2 Jun 2026 15:24:33 -0700 Subject: [PATCH 04/11] Phase 9.1: move scalar-to-list YAML accommodation from model to seed_loader StrOrList / coerce_str_to_list existed solely to accommodate YAML's scalar-or-sequence shorthand (e.g. `authors: Jane Doe`). That is a loader-layer concern leaking into the data class, same pattern as the is_jinja_template trust marker handled in 9.1. Move it to the loader: a new _canonicalize_scalar_lists helper wraps bare strings for known list-typed seed fields (harm_categories, authors, groups, parameters) at the YAML boundary and recurses into nested seeds for dataset/group files. The model fields are now plain Optional[list[str]], so programmatic constructors are strict. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/seeds/seed.py | 33 ++----------- pyrit/models/seeds/seed_dataset.py | 8 +-- pyrit/models/seeds/seed_loader.py | 41 ++++++++++++++-- pyrit/models/seeds/seed_prompt.py | 4 +- tests/unit/models/test_seed_loader.py | 70 +++++++++++++++++++++++++++ 5 files changed, 119 insertions(+), 37 deletions(-) diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index 857c641350..e1471b673a 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -13,11 +13,11 @@ import re import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Annotated, Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from jinja2 import StrictUndefined, Undefined from jinja2.sandbox import SandboxedEnvironment -from pydantic import BaseModel, BeforeValidator, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field from pyrit.models.literals import PromptDataType # noqa: TC001 (runtime-required by Pydantic field annotations) @@ -31,29 +31,6 @@ T = TypeVar("T", bound="Seed") -def coerce_str_to_list(value: Any) -> Any: - """ - Coerce a bare string into a single-element list, leaving other values unchanged. - - YAML seed files commonly specify list-typed fields as a single scalar (e.g. ``authors: Jane Doe``) - rather than a list. This wraps such a value so it satisfies a ``list[str]`` field type. - - Args: - value: The raw field value provided during validation. - - Returns: - The value wrapped in a list if it was a bare string, otherwise unchanged. - """ - if isinstance(value, str): - return [value] - return value - - -# Annotated type for list[str] fields that should accept a bare string as a one-element list. -# Use this for any seed list field populated from YAML where authors/groups/etc. may be scalars. -StrOrList = Annotated[list[str], BeforeValidator(coerce_str_to_list)] - - class PartialUndefined(Undefined): """Jinja undefined value that preserves unresolved placeholders as text.""" @@ -120,16 +97,16 @@ class Seed(BaseModel): dataset_name: Optional[str] = None # Categories of harm associated with this prompt - harm_categories: Optional[StrOrList] = Field(default_factory=list) + harm_categories: Optional[list[str]] = Field(default_factory=list) # Description of the prompt description: Optional[str] = None # Authors of the prompt - authors: Optional[StrOrList] = Field(default_factory=list) + authors: Optional[list[str]] = Field(default_factory=list) # Groups affiliated with the prompt - groups: Optional[StrOrList] = Field(default_factory=list) + groups: Optional[list[str]] = Field(default_factory=list) # Source of the prompt source: Optional[str] = None diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index d8dc6de7c6..335cb07513 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator from pyrit.models.literals import SeedType # noqa: TC001 (runtime-required by Pydantic field annotations) -from pyrit.models.seeds.seed import Seed, StrOrList +from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_attack_group import SeedAttackGroup from pyrit.models.seeds.seed_group import ( # noqa: TC001 (runtime-required by Pydantic field annotations) SeedGroup, @@ -84,10 +84,10 @@ class SeedDataset(BaseModel): data_type: Optional[str] = "text" name: Optional[str] = None dataset_name: Optional[str] = None - harm_categories: Optional[StrOrList] = None + harm_categories: Optional[list[str]] = None description: Optional[str] = None - authors: Optional[StrOrList] = Field(default_factory=list) - groups: Optional[StrOrList] = Field(default_factory=list) + authors: Optional[list[str]] = Field(default_factory=list) + groups: Optional[list[str]] = Field(default_factory=list) source: Optional[str] = None date_added: Optional[datetime] = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) added_by: Optional[str] = None diff --git a/pyrit/models/seeds/seed_loader.py b/pyrit/models/seeds/seed_loader.py index fc38d0fae7..eff859c203 100644 --- a/pyrit/models/seeds/seed_loader.py +++ b/pyrit/models/seeds/seed_loader.py @@ -32,6 +32,38 @@ T = TypeVar("T", bound="Seed") +# Seed model fields that callers may write as a bare string in YAML +# (e.g. ``authors: Jane Doe``) but the model declares as ``list[str]``. +# The loader wraps such scalars before constructing the model so the model +# itself can stay strict and YAML's "scalar-or-sequence" idiom doesn't leak +# into the data class. +_SCALAR_OR_LIST_FIELDS: tuple[str, ...] = ("harm_categories", "authors", "groups", "parameters") + + +def _canonicalize_scalar_lists(data: dict[str, Any]) -> dict[str, Any]: + """ + Wrap bare-string values into single-element lists for known list-typed seed fields. + + Mutates ``data`` in place and recurses into nested ``seeds`` entries so + dataset/group YAML files (which carry both top-level defaults and a list of seed + dicts) are normalized in one pass. + + Args: + data: A YAML-decoded mapping representing a seed, group, or dataset. + + Returns: + The same mapping, with scalar values on known list fields wrapped into lists. + """ + for key in _SCALAR_OR_LIST_FIELDS: + if isinstance(data.get(key), str): + data[key] = [data[key]] + seeds = data.get("seeds") + if isinstance(seeds, list): + for seed in seeds: + if isinstance(seed, dict): + _canonicalize_scalar_lists(seed) + return data + def _read_yaml(file: Union[str, Path]) -> dict[str, Any]: """ @@ -65,7 +97,10 @@ def load_seed_from_yaml(file: Union[str, Path], *, cls: type[T]) -> T: Load a single seed of type ``cls`` from a YAML file. The seed is marked ``is_jinja_template=True`` because the file is treated - as a trusted, vetted local template at this boundary. + as a trusted, vetted local template at this boundary. Bare-string values + for known list-typed fields (``authors``, ``harm_categories``, ``groups``, + ``parameters``) are wrapped into single-element lists so the model itself + can stay strict about its shape. Args: file: Path to the YAML file containing the seed definition. @@ -78,7 +113,7 @@ def load_seed_from_yaml(file: Union[str, Path], *, cls: type[T]) -> T: FileNotFoundError: If the path does not resolve to an existing file. ValueError: If the YAML is malformed, empty, or fails validation for ``cls``. """ - data = _read_yaml(file) + data = _canonicalize_scalar_lists(_read_yaml(file)) data["is_jinja_template"] = True return cls(**data) @@ -102,7 +137,7 @@ def load_seed_dataset_from_yaml(file: Union[str, Path]) -> SeedDataset: """ from pyrit.models.seeds.seed_dataset import SeedDataset - data = _read_yaml(file) + data = _canonicalize_scalar_lists(_read_yaml(file)) data["is_jinja_template"] = True return SeedDataset.from_dict(data) diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index 6251f78c7b..7e18a679d8 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -20,7 +20,7 @@ ChatMessageRole, PromptDataType, ) -from pyrit.models.seeds.seed import Seed, StrOrList +from pyrit.models.seeds.seed import Seed if TYPE_CHECKING: import uuid @@ -49,7 +49,7 @@ class SeedPrompt(Seed): sequence: int = 0 # Parameters that can be used in the prompt template - parameters: Optional[StrOrList] = Field(default_factory=list) + parameters: Optional[list[str]] = Field(default_factory=list) @model_validator(mode="after") def _render_and_infer_data_type(self) -> SeedPrompt: diff --git a/tests/unit/models/test_seed_loader.py b/tests/unit/models/test_seed_loader.py index 7f7f5fec91..904ff3c2af 100644 --- a/tests/unit/models/test_seed_loader.py +++ b/tests/unit/models/test_seed_loader.py @@ -147,3 +147,73 @@ def test_classmethod_shims_delegate_to_loader(tmp_path): assert via_classmethod.value == via_function.value assert via_classmethod.is_jinja_template == via_function.is_jinja_template is True + + +# ----- Scalar-to-list canonicalization (loader-side YAML accommodation) ----- + + +def test_load_seed_from_yaml_canonicalizes_scalar_authors_to_list(tmp_path): + yaml_file = _write(tmp_path, "p.yaml", "value: hi\nauthors: Jane Doe\n") + + loaded = load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + assert loaded.authors == ["Jane Doe"] + + +def test_load_seed_from_yaml_canonicalizes_scalar_parameters(tmp_path): + yaml_file = _write( + tmp_path, + "p.yaml", + "value: hello {{ name }}\nparameters: name\n", + ) + + loaded = load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + assert loaded.parameters == ["name"] + + +def test_load_seed_from_yaml_passes_through_list_authors_unchanged(tmp_path): + yaml_file = _write( + tmp_path, + "p.yaml", + "value: hi\nauthors:\n - Alice\n - Bob\n", + ) + + loaded = load_seed_from_yaml(yaml_file, cls=SeedPrompt) + + assert loaded.authors == ["Alice", "Bob"] + + +def test_load_seed_dataset_from_yaml_canonicalizes_top_level_and_nested(tmp_path): + yaml_file = _write( + tmp_path, + "ds.yaml", + ( + "name: tiny\n" + "authors: Top Author\n" + "harm_categories: harm-1\n" + "seeds:\n" + " - value: a\n" + " authors: Seed Author\n" + " groups: g1\n" + " - value: b\n" + ), + ) + + dataset = load_seed_dataset_from_yaml(yaml_file) + + # Top-level scalars wrapped; dataset-level harm_categories propagates to seeds. + assert dataset.harm_categories == ["harm-1"] + # Per-seed scalar authors is wrapped before dataset-level merge with ["Top Author"]. + assert sorted(dataset.seeds[0].authors or []) == sorted(["Top Author", "Seed Author"]) + assert dataset.seeds[0].groups == ["g1"] + # Pure inheritance from dataset defaults still works for the second seed. + assert dataset.seeds[1].authors == ["Top Author"] + + +def test_model_now_rejects_programmatic_scalar_string(): + """The wrap-scalar accommodation has moved to the loader; the model is strict.""" + import pydantic + + with pytest.raises(pydantic.ValidationError): + SeedPrompt(value="hi", authors="not a list") # type: ignore[arg-type] From b470031e044b882aa8cbf0bf6b819bbd6f55e5d8 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 2 Jun 2026 15:31:29 -0700 Subject: [PATCH 05/11] AtomicAttack: replace cross-class _check_invariants() call with isinstance check The previous loop `for sg in seed_groups: sg._check_invariants()` reached across a class boundary into a private hook and was both redundant and ineffective: - Redundant: SeedAttackGroup is a Pydantic v2 model; its `_finalize` validator already runs `_check_invariants` at construction time. By the time AtomicAttack receives an instance, it has been validated. - Ineffective: the docstring claimed it caught seed groups ''missing an objective'', but SeedGroup._check_invariants allows zero objectives. Only SeedAttackGroup enforces ''exactly one'', so the old call silently passed on any plain SeedGroup with no objective. Replace with an `isinstance(sg, SeedAttackGroup)` check that enforces the runtime contract already expressed by the type annotation `seed_groups: list[SeedAttackGroup]`. Raise TypeError with a clear message if a caller passes a plain SeedGroup or a SeedAttackTechniqueGroup. Update tests that were passing the wrong type to the typed parameter. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/core/atomic_attack.py | 17 +++++++++----- tests/unit/scenario/airt/test_jailbreak.py | 14 ++++++------ tests/unit/scenario/airt/test_scam.py | 26 +++++++++++++--------- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/pyrit/scenario/core/atomic_attack.py b/pyrit/scenario/core/atomic_attack.py index ec588e79ff..acec06f8b0 100644 --- a/pyrit/scenario/core/atomic_attack.py +++ b/pyrit/scenario/core/atomic_attack.py @@ -76,8 +76,8 @@ def __init__( technique seeds. Preferred over the deprecated ``attack`` parameter. attack: **Deprecated.** Will be removed in v0.16.0. The configured attack strategy to execute. Use ``attack_technique`` instead. - seed_groups: List of seed attack groups. Each seed group must - have an objective set. + seed_groups: List of seed attack groups. Each must be a + ``SeedAttackGroup`` (which guarantees exactly one objective). adversarial_chat: Optional chat target for generating adversarial prompts or simulated conversations. objective_scorer: Optional scorer for evaluating simulated @@ -87,8 +87,9 @@ def __init__( execution method. Raises: - ValueError: If seed_groups list is empty or any seed group is missing an objective. - ValueError: If neither attack_technique nor attack is provided, or both are provided. + ValueError: If seed_groups list is empty, or if neither attack_technique + nor attack is provided, or both are provided. + TypeError: If any entry of ``seed_groups`` is not a ``SeedAttackGroup``. """ self.atomic_attack_name = atomic_attack_name self.display_group = display_group or atomic_attack_name @@ -112,9 +113,13 @@ def __init__( if not seed_groups: raise ValueError("seed_groups list cannot be empty") - # Validate each seed group to ensure they are in a valid state + # Validate that each seed_group is actually a SeedAttackGroup (which Pydantic + # already ensured holds the AtomicAttack invariant of "exactly one objective" + # at construction time). A plain SeedGroup or SeedAttackTechniqueGroup is not + # accepted here even though they share a base class. for sg in seed_groups: - sg._check_invariants() + if not isinstance(sg, SeedAttackGroup): + raise TypeError(f"seed_groups must contain SeedAttackGroup instances; got {type(sg).__name__}.") self._seed_groups = seed_groups self._validate_unique_objective_hashes() diff --git a/tests/unit/scenario/airt/test_jailbreak.py b/tests/unit/scenario/airt/test_jailbreak.py index 307981f485..db07b40c0f 100644 --- a/tests/unit/scenario/airt/test_jailbreak.py +++ b/tests/unit/scenario/airt/test_jailbreak.py @@ -13,7 +13,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.executor.attack.single_turn.role_play import RolePlayAttack from pyrit.executor.attack.single_turn.skeleton_key import SkeletonKeyAttack -from pyrit.models import ComponentIdentifier, SeedGroup, SeedObjective +from pyrit.models import ComponentIdentifier, SeedAttackGroup, SeedObjective from pyrit.prompt_target import PromptTarget from pyrit.scenario.core import BaselineAttackPolicy from pyrit.scenario.scenarios.airt.jailbreak import Jailbreak, JailbreakStrategy @@ -44,10 +44,10 @@ def mock_scenario_result_id() -> str: @pytest.fixture -def mock_memory_seed_groups() -> list[SeedGroup]: +def mock_memory_seed_groups() -> list[SeedAttackGroup]: """Create mock seed groups that _get_default_seed_groups() would return.""" return [ - SeedGroup(seeds=[SeedObjective(value=prompt)]) + SeedAttackGroup(seeds=[SeedObjective(value=prompt)]) for prompt in [ "sample objective 1", "sample objective 2", @@ -396,7 +396,7 @@ async def test_initialize_async_with_max_concurrency( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseInverterScorer, - mock_memory_seed_groups: list[SeedGroup], + mock_memory_seed_groups: list[SeedAttackGroup], ) -> None: """Test initialization with custom max_concurrency.""" with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): @@ -409,7 +409,7 @@ async def test_initialize_async_with_memory_labels( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseInverterScorer, - mock_memory_seed_groups: list[SeedGroup], + mock_memory_seed_groups: list[SeedAttackGroup], ) -> None: """Test initialization with memory labels.""" memory_labels = {"type": "jailbreak", "category": "scenario"} @@ -445,7 +445,7 @@ def test_scenario_default_dataset(self) -> None: assert Jailbreak.required_datasets() == ["airt_harms"] async def test_no_target_duplication_async( - self, *, mock_objective_target: PromptTarget, mock_memory_seed_groups: list[SeedGroup] + self, *, mock_objective_target: PromptTarget, mock_memory_seed_groups: list[SeedAttackGroup] ) -> None: """Test that all three targets (adversarial, object, scorer) are distinct.""" with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): @@ -489,7 +489,7 @@ async def test_roleplay_attacks_share_adversarial_target( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseInverterScorer, - mock_memory_seed_groups: list[SeedGroup], + mock_memory_seed_groups: list[SeedAttackGroup], roleplay_jailbreak_strategy: JailbreakStrategy, ) -> None: """Test that multiple role-play attacks share the same adversarial target instance.""" diff --git a/tests/unit/scenario/airt/test_scam.py b/tests/unit/scenario/airt/test_scam.py index 49befdbd45..3f2014578b 100644 --- a/tests/unit/scenario/airt/test_scam.py +++ b/tests/unit/scenario/airt/test_scam.py @@ -15,7 +15,7 @@ RolePlayAttack, ) from pyrit.executor.attack.core.attack_config import AttackScoringConfig -from pyrit.models import ComponentIdentifier, SeedAttackGroup, SeedDataset, SeedGroup, SeedObjective +from pyrit.models import ComponentIdentifier, SeedAttackGroup, SeedDataset, SeedObjective from pyrit.prompt_target import OpenAIChatTarget, PromptTarget from pyrit.scenario import DatasetConfiguration from pyrit.scenario.scenarios.airt.scam import Scam, ScamStrategy @@ -42,9 +42,9 @@ def _mock_target_id(name: str = "MockTarget") -> ComponentIdentifier: @pytest.fixture -def mock_memory_seed_groups() -> list[SeedGroup]: +def mock_memory_seed_groups() -> list[SeedAttackGroup]: """Create mock seed groups that _get_default_seed_groups() would return.""" - return [SeedGroup(seeds=[SeedObjective(value=prompt)]) for prompt in SEED_PROMPT_LIST] + return [SeedAttackGroup(seeds=[SeedObjective(value=prompt)]) for prompt in SEED_PROMPT_LIST] @pytest.fixture @@ -56,7 +56,7 @@ def mock_memory_seeds(): @pytest.fixture def mock_dataset_config(mock_memory_seed_groups): """Create a mock dataset config that returns the seed groups.""" - seed_attack_groups = [SeedAttackGroup(seeds=list(sg.seeds)) for sg in mock_memory_seed_groups] + seed_attack_groups = list(mock_memory_seed_groups) mock_config = MagicMock(spec=DatasetConfiguration) mock_config.get_all_seed_attack_groups.return_value = seed_attack_groups mock_config.get_default_dataset_names.return_value = ["airt_scam"] @@ -127,7 +127,7 @@ def test_init_with_default_objectives( self, *, mock_objective_scorer: TrueFalseCompositeScorer, - mock_memory_seed_groups: list[SeedGroup], + mock_memory_seed_groups: list[SeedAttackGroup], ) -> None: with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): scenario = Scam(objective_scorer=mock_objective_scorer) @@ -141,7 +141,7 @@ def test_init_with_default_scorer(self, mock_memory_seed_groups) -> None: scenario = Scam() assert scenario._objective_scorer_identifier - def test_init_with_custom_scorer(self, *, mock_memory_seed_groups: list[SeedGroup]) -> None: + def test_init_with_custom_scorer(self, *, mock_memory_seed_groups: list[SeedAttackGroup]) -> None: """Test initialization with custom scorer.""" scorer = MagicMock(spec=TrueFalseCompositeScorer) @@ -150,7 +150,7 @@ def test_init_with_custom_scorer(self, *, mock_memory_seed_groups: list[SeedGrou assert isinstance(scenario._scorer_config, AttackScoringConfig) def test_init_default_adversarial_chat( - self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: list[SeedGroup] + self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: list[SeedAttackGroup] ) -> None: with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): scenario = Scam(objective_scorer=mock_objective_scorer) @@ -159,7 +159,7 @@ def test_init_default_adversarial_chat( assert scenario._adversarial_chat._temperature == 1.2 def test_init_with_adversarial_chat( - self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: list[SeedGroup] + self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: list[SeedAttackGroup] ) -> None: adversarial_chat = MagicMock(OpenAIChatTarget) adversarial_chat.get_identifier.return_value = _mock_target_id("CustomAdversary") @@ -340,7 +340,7 @@ async def test_initialize_async_with_max_concurrency( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseCompositeScorer, - mock_memory_seed_groups: list[SeedGroup], + mock_memory_seed_groups: list[SeedAttackGroup], mock_dataset_config, ) -> None: """Test initialization with custom max_concurrency.""" @@ -356,7 +356,7 @@ async def test_initialize_async_with_memory_labels( *, mock_objective_target: PromptTarget, mock_objective_scorer: TrueFalseCompositeScorer, - mock_memory_seed_groups: list[SeedGroup], + mock_memory_seed_groups: list[SeedAttackGroup], mock_dataset_config, ) -> None: """Test initialization with memory labels.""" @@ -389,7 +389,11 @@ def test_scenario_version_is_set( assert scenario.VERSION == 1 async def test_no_target_duplication_async( - self, *, mock_objective_target: PromptTarget, mock_memory_seed_groups: list[SeedGroup], mock_dataset_config + self, + *, + mock_objective_target: PromptTarget, + mock_memory_seed_groups: list[SeedAttackGroup], + mock_dataset_config, ) -> None: """Test that all three targets (adversarial, object, scorer) are distinct.""" with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): From 6bc53ad1db6e451c53152858791c4f2cca9b3d7d Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 2 Jun 2026 15:34:09 -0700 Subject: [PATCH 06/11] Rename seed_loader.py to yaml_seed_loader.py Make the loader's medium explicit in the filename so future non-YAML loaders (e.g. JSON, remote dataset, hub) read as siblings rather than overloading a single ''seed_loader'' module. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/seeds/__init__.py | 10 +++++----- pyrit/models/seeds/seed.py | 4 ++-- pyrit/models/seeds/seed_dataset.py | 4 ++-- pyrit/models/seeds/seed_prompt.py | 4 ++-- .../seeds/{seed_loader.py => yaml_seed_loader.py} | 0 .../{test_seed_loader.py => test_yaml_seed_loader.py} | 0 6 files changed, 11 insertions(+), 11 deletions(-) rename pyrit/models/seeds/{seed_loader.py => yaml_seed_loader.py} (100%) rename tests/unit/models/{test_seed_loader.py => test_yaml_seed_loader.py} (100%) diff --git a/pyrit/models/seeds/__init__.py b/pyrit/models/seeds/__init__.py index b89f9aed13..c182f198eb 100644 --- a/pyrit/models/seeds/__init__.py +++ b/pyrit/models/seeds/__init__.py @@ -20,11 +20,6 @@ from pyrit.models.seeds.seed_attack_technique_group import SeedAttackTechniqueGroup from pyrit.models.seeds.seed_dataset import SeedDataset from pyrit.models.seeds.seed_group import SeedGroup -from pyrit.models.seeds.seed_loader import ( - load_seed_dataset_from_yaml, - load_seed_from_yaml, - load_seed_prompt_from_yaml_with_required_parameters, -) from pyrit.models.seeds.seed_objective import SeedObjective from pyrit.models.seeds.seed_prompt import SeedPrompt from pyrit.models.seeds.seed_simulated_conversation import ( @@ -32,6 +27,11 @@ SeedSimulatedConversation, SimulatedTargetSystemPromptPaths, ) +from pyrit.models.seeds.yaml_seed_loader import ( + load_seed_dataset_from_yaml, + load_seed_from_yaml, + load_seed_prompt_from_yaml_with_required_parameters, +) __all__ = [ "NextMessageSystemPromptPaths", diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index e1471b673a..ba9d9224c7 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -242,7 +242,7 @@ def from_yaml_file(cls: type[T], file: Union[str, Path]) -> T: """ Create a new Seed from a YAML file, marking it as a trusted Jinja2 template. - Thin shim that delegates to :func:`pyrit.models.seeds.seed_loader.load_seed_from_yaml`; + Thin shim that delegates to :func:`pyrit.models.seeds.yaml_seed_loader.load_seed_from_yaml`; file I/O and the ``is_jinja_template`` trust marker live in the loader module. Args: @@ -255,6 +255,6 @@ def from_yaml_file(cls: type[T], file: Union[str, Path]) -> T: FileNotFoundError: If the path does not resolve to an existing file. ValueError: If the YAML file is invalid or empty. """ - from pyrit.models.seeds.seed_loader import load_seed_from_yaml + from pyrit.models.seeds.yaml_seed_loader import load_seed_from_yaml return load_seed_from_yaml(file, cls=cls) diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index 335cb07513..52fcbc1ecc 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -180,7 +180,7 @@ def from_yaml_file(cls, file: Union[str, Path]) -> SeedDataset: Create a SeedDataset from a YAML file, marking nested seeds as trusted templates. Thin shim that delegates to - :func:`pyrit.models.seeds.seed_loader.load_seed_dataset_from_yaml`; file I/O and + :func:`pyrit.models.seeds.yaml_seed_loader.load_seed_dataset_from_yaml`; file I/O and the ``is_jinja_template`` trust marker live in the loader module. Args: @@ -193,7 +193,7 @@ def from_yaml_file(cls, file: Union[str, Path]) -> SeedDataset: FileNotFoundError: If the path does not resolve to an existing file. ValueError: If the YAML file is invalid or empty. """ - from pyrit.models.seeds.seed_loader import load_seed_dataset_from_yaml + from pyrit.models.seeds.yaml_seed_loader import load_seed_dataset_from_yaml return load_seed_dataset_from_yaml(file) diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index 7e18a679d8..a9fce68674 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -138,7 +138,7 @@ def from_yaml_with_required_parameters( Load a SeedPrompt from a YAML file and validate that it declares each required parameter. Thin shim that delegates to - :func:`pyrit.models.seeds.seed_loader.load_seed_prompt_from_yaml_with_required_parameters`. + :func:`pyrit.models.seeds.yaml_seed_loader.load_seed_prompt_from_yaml_with_required_parameters`. Args: template_path: Path to the YAML file containing the template. @@ -151,7 +151,7 @@ def from_yaml_with_required_parameters( Raises: ValueError: If the template doesn't contain all required parameters. """ - from pyrit.models.seeds.seed_loader import load_seed_prompt_from_yaml_with_required_parameters + from pyrit.models.seeds.yaml_seed_loader import load_seed_prompt_from_yaml_with_required_parameters return load_seed_prompt_from_yaml_with_required_parameters( template_path, required_parameters, error_message=error_message diff --git a/pyrit/models/seeds/seed_loader.py b/pyrit/models/seeds/yaml_seed_loader.py similarity index 100% rename from pyrit/models/seeds/seed_loader.py rename to pyrit/models/seeds/yaml_seed_loader.py diff --git a/tests/unit/models/test_seed_loader.py b/tests/unit/models/test_yaml_seed_loader.py similarity index 100% rename from tests/unit/models/test_seed_loader.py rename to tests/unit/models/test_yaml_seed_loader.py From 5ff4c00d23d45e8f04fe1cf685265a36942c9f21 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 2 Jun 2026 15:35:16 -0700 Subject: [PATCH 07/11] Drop 'Phase 9.1' reference from yaml_seed_loader docstring Per code review: phase references go stale fast and are confusing to future readers. Module-level docstring still explains the architectural reason the loader is separate; the planning context belonged in the gist, not the codebase. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/seeds/yaml_seed_loader.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyrit/models/seeds/yaml_seed_loader.py b/pyrit/models/seeds/yaml_seed_loader.py index eff859c203..242f31fbd0 100644 --- a/pyrit/models/seeds/yaml_seed_loader.py +++ b/pyrit/models/seeds/yaml_seed_loader.py @@ -11,8 +11,7 @@ ``is_jinja_template=True`` marker is set exactly once at this boundary. The ``from_yaml_file`` and ``from_yaml_with_required_parameters`` classmethods -on the seed classes are thin shims that delegate here. See the plan-gist -Phase 9.1 entry for the eventual move of this module to ``pyrit/io/seeds.py``. +on the seed classes are thin shims that delegate here. """ from __future__ import annotations From 31af07fe0a37d6b47f5173f094c4cca45540a79b Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 2 Jun 2026 15:36:43 -0700 Subject: [PATCH 08/11] Drop SeedLeaf alias; collapse into SeedUnion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SeedLeaf was used only as the inner type of SeedUnion plus one local list annotation in _finalize. The two-alias pattern read as ''which one do I import?'' confusion without buying anything — Annotated[..., Field(discriminator=...)] is fine as a local variable annotation too (the Field metadata is ignored outside Pydantic field contexts). Inline the Union into SeedUnion's definition and use SeedUnion for the local list. SeedDataset and any future container can keep importing the single name. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/seeds/seed_group.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pyrit/models/seeds/seed_group.py b/pyrit/models/seeds/seed_group.py index 80e26e17db..18baba6028 100644 --- a/pyrit/models/seeds/seed_group.py +++ b/pyrit/models/seeds/seed_group.py @@ -29,11 +29,14 @@ logger = logging.getLogger(__name__) -# Concrete leaf classes for the polymorphic seed list. Use SeedLeaf when you need a plain -# typing alias (e.g. for local list[...] variables); use SeedUnion in Pydantic field annotations -# so the discriminator is honored during validation. -SeedLeaf = Union[SeedPrompt, SeedObjective, SeedSimulatedConversation] -SeedUnion = Annotated[SeedLeaf, Field(discriminator="seed_type")] +# Polymorphic union of seed types that can appear inside a SeedGroup. The discriminator +# field ``seed_type`` is set per-leaf-class so Pydantic dispatches to the correct constructor +# during validation. Exported so SeedDataset (and any future container) can reuse the same +# tagged union for its own ``seeds`` field. +SeedUnion = Annotated[ + Union[SeedPrompt, SeedObjective, SeedSimulatedConversation], + Field(discriminator="seed_type"), +] class SeedGroup(BaseModel): @@ -119,7 +122,7 @@ def _finalize(self) -> SeedGroup: simulated_conv = self._get_simulated_conversation() sorted_prompts = sorted(self.prompts, key=lambda p: p.sequence if p.sequence is not None else 0) - new_seeds: list[SeedLeaf] = [] + new_seeds: list[SeedUnion] = [] if objective: new_seeds.append(objective) if simulated_conv: From 37cdc9b8eaed2f8e1b42bfe4f10a6a0e35ecd9ef Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 2 Jun 2026 15:39:48 -0700 Subject: [PATCH 09/11] Hoist SeedPrompt/SeedDataset/Seed imports to top of yaml_seed_loader Per style guide: top-of-file imports unless deferred for heavy third-party packages. SeedPrompt, SeedDataset, and Seed are lightweight first-party. The __init__.py imports yaml_seed_loader last, and none of those modules import yaml_seed_loader at module load, so no circular import. Also switch Union[str, Path] / Optional[str] to X | Y / X | None to match the modern type-syntax rules in the style guide. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/seeds/yaml_seed_loader.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/pyrit/models/seeds/yaml_seed_loader.py b/pyrit/models/seeds/yaml_seed_loader.py index 242f31fbd0..5d6719ede2 100644 --- a/pyrit/models/seeds/yaml_seed_loader.py +++ b/pyrit/models/seeds/yaml_seed_loader.py @@ -16,20 +16,19 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, TypeVar import yaml from pyrit.common.utils import verify_and_resolve_path +from pyrit.models.seeds.seed import Seed +from pyrit.models.seeds.seed_dataset import SeedDataset +from pyrit.models.seeds.seed_prompt import SeedPrompt if TYPE_CHECKING: from pathlib import Path - from pyrit.models.seeds.seed import Seed - from pyrit.models.seeds.seed_dataset import SeedDataset - from pyrit.models.seeds.seed_prompt import SeedPrompt - -T = TypeVar("T", bound="Seed") +T = TypeVar("T", bound=Seed) # Seed model fields that callers may write as a bare string in YAML # (e.g. ``authors: Jane Doe``) but the model declares as ``list[str]``. @@ -64,7 +63,7 @@ def _canonicalize_scalar_lists(data: dict[str, Any]) -> dict[str, Any]: return data -def _read_yaml(file: Union[str, Path]) -> dict[str, Any]: +def _read_yaml(file: str | Path) -> dict[str, Any]: """ Resolve, read, and parse a YAML file as a mapping. @@ -91,7 +90,7 @@ def _read_yaml(file: Union[str, Path]) -> dict[str, Any]: return data -def load_seed_from_yaml(file: Union[str, Path], *, cls: type[T]) -> T: +def load_seed_from_yaml(file: str | Path, *, cls: type[T]) -> T: """ Load a single seed of type ``cls`` from a YAML file. @@ -117,7 +116,7 @@ def load_seed_from_yaml(file: Union[str, Path], *, cls: type[T]) -> T: return cls(**data) -def load_seed_dataset_from_yaml(file: Union[str, Path]) -> SeedDataset: +def load_seed_dataset_from_yaml(file: str | Path) -> SeedDataset: """ Load a ``SeedDataset`` from a YAML file. @@ -134,18 +133,16 @@ def load_seed_dataset_from_yaml(file: Union[str, Path]) -> SeedDataset: FileNotFoundError: If the path does not resolve to an existing file. ValueError: If the YAML is malformed, empty, or fails dataset validation. """ - from pyrit.models.seeds.seed_dataset import SeedDataset - data = _canonicalize_scalar_lists(_read_yaml(file)) data["is_jinja_template"] = True return SeedDataset.from_dict(data) def load_seed_prompt_from_yaml_with_required_parameters( - template_path: Union[str, Path], + template_path: str | Path, required_parameters: list[str], *, - error_message: Optional[str] = None, + error_message: str | None = None, ) -> SeedPrompt: """ Load a ``SeedPrompt`` and assert that its ``parameters`` field declares each required name. @@ -161,8 +158,6 @@ def load_seed_prompt_from_yaml_with_required_parameters( Raises: ValueError: If the loaded prompt is missing any required parameter. """ - from pyrit.models.seeds.seed_prompt import SeedPrompt - sp = load_seed_from_yaml(template_path, cls=SeedPrompt) if sp.parameters is None or not all(p in sp.parameters for p in required_parameters): if error_message is None: From 0c69fbbe1c03e06d9181c4971b5b4d61e058ec6d Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 2 Jun 2026 15:44:32 -0700 Subject: [PATCH 10/11] Replace _check_invariants call in AttackParameters with isinstance guard MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Same pattern as the earlier atomic_attack fix: the prior _check_invariants() call on a SeedAttackGroup-typed parameter was redundant (Pydantic already validates at construction) and the downstream ''objective is None'' check guarded against an impossible state for a real SeedAttackGroup. The actually-useful runtime guard is rejecting incorrect subtypes — callers passing a plain SeedGroup, which is silently accepted by the type annotation but doesn''t enforce ''exactly one objective''. Switch to isinstance(seed_group, SeedAttackGroup) raising TypeError, drop the dead objective-None branch, and replace the test that exercised the impossible state with one covering the new isinstance guard. Audited the remaining _check_invariants references: only override/super() definitions in the three SeedGroup subclasses and two direct test calls that assert the method''s own behavior. No other cross-class consumers. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../executor/attack/core/attack_parameters.py | 19 ++++++++++-------- .../attack/core/test_attack_parameters.py | 20 ++++++++++--------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/pyrit/executor/attack/core/attack_parameters.py b/pyrit/executor/attack/core/attack_parameters.py index 79bfa27b8c..fed08d5f2f 100644 --- a/pyrit/executor/attack/core/attack_parameters.py +++ b/pyrit/executor/attack/core/attack_parameters.py @@ -101,14 +101,21 @@ async def from_seed_group_async( An instance of this AttackParameters type. Raises: - ValueError: If seed_group has no objective or if overrides contain invalid fields. - ValueError: If seed_group has simulated conversation but adversarial_chat/scorer not provided. + TypeError: If ``seed_group`` is not a ``SeedAttackGroup``. + ValueError: If overrides contain invalid fields, or if seed_group has simulated + conversation but adversarial_chat/scorer not provided. """ # Import here to avoid circular imports from pyrit.executor.attack.multi_turn.simulated_conversation import ( generate_simulated_conversation_async, ) + if not isinstance(seed_group, SeedAttackGroup): + raise TypeError( + f"seed_group must be a SeedAttackGroup, got {type(seed_group).__name__}. " + "Plain SeedGroup does not enforce the 'exactly one objective' invariant required for an attack." + ) + # Get valid field names for this params type valid_fields = {f.name for f in dataclasses.fields(cls)} @@ -119,12 +126,8 @@ async def from_seed_group_async( f"{cls.__name__} does not accept parameters: {invalid_fields}. Accepted parameters: {valid_fields}" ) - # Validate seed_group state before extracting parameters - seed_group._check_invariants() - - # SeedAttackGroup validates in __init__ that objective is set - if seed_group.objective is None: - raise ValueError("seed_group.objective is not initialized") + # SeedAttackGroup's Pydantic validator guarantees exactly one objective is present. + assert seed_group.objective is not None # Build params dict, only including fields this class accepts params: dict[str, Any] = {} diff --git a/tests/unit/executor/attack/core/test_attack_parameters.py b/tests/unit/executor/attack/core/test_attack_parameters.py index 18cf37341f..5e8ca975c2 100644 --- a/tests/unit/executor/attack/core/test_attack_parameters.py +++ b/tests/unit/executor/attack/core/test_attack_parameters.py @@ -309,12 +309,14 @@ async def test_excluded_class_rejects_excluded_field_overrides(self) -> None: ) -async def test_from_seed_group_async_raises_when_objective_is_none(): - """Test that from_seed_group_async raises ValueError when seed_group.objective is None.""" - seed_group = MagicMock(spec=SeedAttackGroup) - seed_group.validate = MagicMock() - seed_group.objective = None - seed_group.simulated_conversation = None - - with pytest.raises(ValueError, match="seed_group.objective is not initialized"): - await AttackParameters.from_seed_group_async(seed_group=seed_group) +async def test_from_seed_group_async_rejects_plain_seed_group(): + """Plain SeedGroup is rejected at the boundary because it doesn't enforce the + 'exactly one objective' invariant SeedAttackGroup does. A real SeedAttackGroup + can't reach this method with objective=None — Pydantic validation at construction + blocks that — so the runtime guard targets the more interesting failure mode: + callers passing the wrong subtype.""" + from pyrit.models import SeedGroup + + plain_group = SeedGroup(seeds=[SeedObjective(value="Test objective")]) + with pytest.raises(TypeError, match="seed_group must be a SeedAttackGroup"): + await AttackParameters.from_seed_group_async(seed_group=plain_group) # type: ignore[arg-type] From c08a2b0e451d9e3ab07b49d48a5538904f492c48 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 3 Jun 2026 09:34:43 -0700 Subject: [PATCH 11/11] Apply seeds style-guide fixes: modern type syntax, AwareDatetime, alphabetized __all__ Replace Optional[X]/Union[X, Y] with PEP 604 X | None syntax across the seeds module, route date_added through an AwareDatetimeUTC validator that coerces naive datetimes and bare date strings to UTC, share PROMPT_ONLY_SEED_KEYS, convert reST :func:/:meth: roles to plain backticks, and alphabetize __all__. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/seeds/__init__.py | 6 +- pyrit/models/seeds/seed.py | 62 +++++++++++++------ .../seeds/seed_attack_technique_group.py | 4 +- pyrit/models/seeds/seed_dataset.py | 58 +++++++++-------- pyrit/models/seeds/seed_group.py | 29 +++++---- pyrit/models/seeds/seed_prompt.py | 18 +++--- .../seeds/seed_simulated_conversation.py | 10 +-- 7 files changed, 111 insertions(+), 76 deletions(-) diff --git a/pyrit/models/seeds/__init__.py b/pyrit/models/seeds/__init__.py index c182f198eb..7b359847e2 100644 --- a/pyrit/models/seeds/__init__.py +++ b/pyrit/models/seeds/__init__.py @@ -34,6 +34,9 @@ ) __all__ = [ + "load_seed_dataset_from_yaml", + "load_seed_from_yaml", + "load_seed_prompt_from_yaml_with_required_parameters", "NextMessageSystemPromptPaths", "Seed", "SeedAttackGroup", @@ -44,7 +47,4 @@ "SeedPrompt", "SeedSimulatedConversation", "SimulatedTargetSystemPromptPaths", - "load_seed_dataset_from_yaml", - "load_seed_from_yaml", - "load_seed_prompt_from_yaml_with_required_parameters", ] diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index ba9d9224c7..68af6ea87a 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -13,11 +13,11 @@ import re import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Annotated, Any, TypeVar from jinja2 import StrictUndefined, Undefined from jinja2.sandbox import SandboxedEnvironment -from pydantic import BaseModel, ConfigDict, Field +from pydantic import AwareDatetime, BaseModel, BeforeValidator, ConfigDict, Field from pyrit.models.literals import PromptDataType # noqa: TC001 (runtime-required by Pydantic field annotations) @@ -31,6 +31,31 @@ T = TypeVar("T", bound="Seed") +def _ensure_aware_utc(value: Any) -> Any: + """ + Coerce naive datetimes (and bare date strings) to UTC so AwareDatetime accepts them. + + Args: + value: The raw value provided for a datetime field (string, datetime, or anything else). + + Returns: + Any: A timezone-aware datetime when the input was naive or a parseable date string; + otherwise the value unchanged for Pydantic to validate. + """ + if isinstance(value, str): + try: + value = datetime.fromisoformat(value) + except ValueError: + return value + if isinstance(value, datetime) and value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value + + +# Timezone-aware datetime that interprets naive inputs as UTC instead of rejecting them. +AwareDatetimeUTC = Annotated[AwareDatetime, BeforeValidator(_ensure_aware_utc)] + + class PartialUndefined(Undefined): """Jinja undefined value that preserves unresolved placeholders as text.""" @@ -85,46 +110,46 @@ class Seed(BaseModel): value: str # SHA256 hash of the value, used for deduplication - value_sha256: Optional[str] = None + value_sha256: str | None = None # Unique identifier for the prompt - id: Optional[uuid.UUID] = Field(default_factory=uuid.uuid4) + id: uuid.UUID | None = Field(default_factory=uuid.uuid4) # Name of the prompt - name: Optional[str] = None + name: str | None = None # Name of the dataset this prompt belongs to - dataset_name: Optional[str] = None + dataset_name: str | None = None # Categories of harm associated with this prompt - harm_categories: Optional[list[str]] = Field(default_factory=list) + harm_categories: list[str] | None = Field(default_factory=list) # Description of the prompt - description: Optional[str] = None + description: str | None = None # Authors of the prompt - authors: Optional[list[str]] = Field(default_factory=list) + authors: list[str] | None = Field(default_factory=list) # Groups affiliated with the prompt - groups: Optional[list[str]] = Field(default_factory=list) + groups: list[str] | None = Field(default_factory=list) # Source of the prompt - source: Optional[str] = None + source: str | None = None # Date when the prompt was added to the dataset - date_added: Optional[datetime] = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) + date_added: AwareDatetimeUTC | None = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) # User who added the prompt to the dataset - added_by: Optional[str] = None + added_by: str | None = None # Arbitrary metadata that can be attached to the prompt - metadata: Optional[dict[str, Any]] = Field(default_factory=dict) + metadata: dict[str, Any] | None = Field(default_factory=dict) # Unique identifier for the prompt group - prompt_group_id: Optional[uuid.UUID] = None + prompt_group_id: uuid.UUID | None = None # Alias for the prompt group - prompt_group_alias: Optional[str] = None + prompt_group_alias: str | None = None # Whether this seed represents a general attack technique (not tied to a specific objective) is_general_technique: bool = False @@ -238,11 +263,11 @@ def escape_for_jinja(value: str) -> str: return f"{{% raw %}}{value}{{% endraw %}}" @classmethod - def from_yaml_file(cls: type[T], file: Union[str, Path]) -> T: + def from_yaml_file(cls: type[T], file: str | Path) -> T: """ Create a new Seed from a YAML file, marking it as a trusted Jinja2 template. - Thin shim that delegates to :func:`pyrit.models.seeds.yaml_seed_loader.load_seed_from_yaml`; + Thin shim that delegates to ``load_seed_from_yaml`` in the ``yaml_seed_loader`` module; file I/O and the ``is_jinja_template`` trust marker live in the loader module. Args: @@ -255,6 +280,7 @@ def from_yaml_file(cls: type[T], file: Union[str, Path]) -> T: FileNotFoundError: If the path does not resolve to an existing file. ValueError: If the YAML file is invalid or empty. """ + # Deferred import: yaml_seed_loader imports Seed, so importing it at module top would cycle. from pyrit.models.seeds.yaml_seed_loader import load_seed_from_yaml return load_seed_from_yaml(file, cls=cls) diff --git a/pyrit/models/seeds/seed_attack_technique_group.py b/pyrit/models/seeds/seed_attack_technique_group.py index 536e125fb0..f21735641d 100644 --- a/pyrit/models/seeds/seed_attack_technique_group.py +++ b/pyrit/models/seeds/seed_attack_technique_group.py @@ -11,8 +11,6 @@ from __future__ import annotations -from typing import Optional - from pyrit.models.seeds.seed_group import SeedGroup from pyrit.models.seeds.seed_objective import SeedObjective @@ -30,7 +28,7 @@ class SeedAttackTechniqueGroup(SeedGroup): # Where to insert technique seeds when merging into a SeedAttackGroup via ``with_technique()``. # ``None`` (default) appends at the end; an integer inserts before that position. - insertion_index: Optional[int] = None + insertion_index: int | None = None def _check_invariants(self) -> None: """ diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index 52fcbc1ecc..149f5e79e7 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -12,14 +12,18 @@ import uuid from collections import defaultdict from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, Field, model_validator from pyrit.models.literals import SeedType # noqa: TC001 (runtime-required by Pydantic field annotations) -from pyrit.models.seeds.seed import Seed +from pyrit.models.seeds.seed import ( # noqa: TC001 (AwareDatetimeUTC is runtime-required by Pydantic) + AwareDatetimeUTC, + Seed, +) from pyrit.models.seeds.seed_attack_group import SeedAttackGroup from pyrit.models.seeds.seed_group import ( # noqa: TC001 (runtime-required by Pydantic field annotations) + PROMPT_ONLY_SEED_KEYS, SeedGroup, SeedUnion, ) @@ -81,18 +85,18 @@ class SeedDataset(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - data_type: Optional[str] = "text" - name: Optional[str] = None - dataset_name: Optional[str] = None - harm_categories: Optional[list[str]] = None - description: Optional[str] = None - authors: Optional[list[str]] = Field(default_factory=list) - groups: Optional[list[str]] = Field(default_factory=list) - source: Optional[str] = None - date_added: Optional[datetime] = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) - added_by: Optional[str] = None + data_type: str | None = "text" + name: str | None = None + dataset_name: str | None = None + harm_categories: list[str] | None = None + description: str | None = None + authors: list[str] | None = Field(default_factory=list) + groups: list[str] | None = Field(default_factory=list) + source: str | None = None + date_added: AwareDatetimeUTC | None = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) + added_by: str | None = None # The default seed type for items that don't specify their own ("prompt", "objective", ...). - seed_type: Optional[SeedType] = None + seed_type: SeedType | None = None # The actual prompts seeds: list[SeedUnion] @@ -166,7 +170,7 @@ def _build_seeds(cls, data: Any) -> Any: # Non-prompt seeds narrow data_type to Literal["text"] and don't have # role/sequence/parameters fields. Drop those so dataset-level defaults # don't bleed in and trip extra="forbid" on the leaf class. - for prompt_only in ("data_type", "role", "sequence", "parameters"): + for prompt_only in PROMPT_ONLY_SEED_KEYS: p.pop(prompt_only, None) normalized.append(p) @@ -175,12 +179,12 @@ def _build_seeds(cls, data: Any) -> Any: return data @classmethod - def from_yaml_file(cls, file: Union[str, Path]) -> SeedDataset: + def from_yaml_file(cls, file: str | Path) -> SeedDataset: """ Create a SeedDataset from a YAML file, marking nested seeds as trusted templates. Thin shim that delegates to - :func:`pyrit.models.seeds.yaml_seed_loader.load_seed_dataset_from_yaml`; file I/O and + ``pyrit.models.seeds.yaml_seed_loader.load_seed_dataset_from_yaml``; file I/O and the ``is_jinja_template`` trust marker live in the loader module. Args: @@ -193,6 +197,8 @@ def from_yaml_file(cls, file: Union[str, Path]) -> SeedDataset: FileNotFoundError: If the path does not resolve to an existing file. ValueError: If the YAML file is invalid or empty. """ + # Deferred import: yaml_seed_loader imports SeedDataset at module load, so importing + # it at the top of this module would create a circular import. from pyrit.models.seeds.yaml_seed_loader import load_seed_dataset_from_yaml return load_seed_dataset_from_yaml(file) @@ -200,17 +206,17 @@ def from_yaml_file(cls, file: Union[str, Path]) -> SeedDataset: def get_values( self, *, - first: Optional[PositiveInt] = None, - last: Optional[PositiveInt] = None, - harm_categories: Optional[Sequence[str]] = None, + first: PositiveInt | None = None, + last: PositiveInt | None = None, + harm_categories: Sequence[str] | None = None, ) -> Sequence[str]: """ Extract and return prompt values from the dataset. Args: - first (Optional[int]): If provided, values from the first N prompts are included. - last (Optional[int]): If provided, values from the last N prompts are included. - harm_categories (Optional[Sequence[str]]): If provided, only prompts containing at least one of + first (int | None): If provided, values from the first N prompts are included. + last (int | None): If provided, values from the last N prompts are included. + harm_categories (Sequence[str] | None): If provided, only prompts containing at least one of these harm categories are included. Returns: @@ -238,15 +244,13 @@ def get_values( return first_part + last_part - def get_random_values( - self, *, number: PositiveInt, harm_categories: Optional[Sequence[str]] = None - ) -> Sequence[str]: + def get_random_values(self, *, number: PositiveInt, harm_categories: Sequence[str] | None = None) -> Sequence[str]: """ Extract and return random prompt values from the dataset. Args: number (int): The number of random prompt values to return. - harm_categories (Optional[Sequence[str]]): If provided, only prompts containing at least one of + harm_categories (Sequence[str] | None): If provided, only prompts containing at least one of these harm categories are included. Returns: @@ -261,7 +265,7 @@ def from_dict(cls, data: dict[str, Any]) -> SeedDataset: """ Build a SeedDataset, assigning per-seed ``prompt_group_id`` by alias. - Default merging now lives in :meth:`_build_seeds` so direct construction and + Default merging now lives in ``_build_seeds`` so direct construction and ``from_dict`` produce equivalent results. This method handles the YAML-only concerns: rejecting pre-set ``prompt_group_id`` on input seeds and resolving ``prompt_group_alias`` into a shared ``prompt_group_id``. diff --git a/pyrit/models/seeds/seed_group.py b/pyrit/models/seeds/seed_group.py index 13b1a6a2a4..3998820c38 100644 --- a/pyrit/models/seeds/seed_group.py +++ b/pyrit/models/seeds/seed_group.py @@ -13,7 +13,7 @@ import logging import uuid from collections import defaultdict -from typing import TYPE_CHECKING, Annotated, Any, Optional, Union +from typing import TYPE_CHECKING, Annotated, Any from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -34,10 +34,15 @@ # during validation. Exported so SeedDataset (and any future container) can reuse the same # tagged union for its own ``seeds`` field. SeedUnion = Annotated[ - Union[SeedPrompt, SeedObjective, SeedSimulatedConversation], + SeedPrompt | SeedObjective | SeedSimulatedConversation, Field(discriminator="seed_type"), ] +# Fields that only exist on prompt-type seeds. They are stripped from non-prompt seed dicts so +# dataset/group-level defaults (e.g. ``data_type: image_path``) don't bleed in and trip +# ``extra="forbid"`` on the leaf class. Shared with SeedDataset, which imports it from here. +PROMPT_ONLY_SEED_KEYS = ("data_type", "role", "sequence", "parameters") + class SeedGroup(BaseModel): """ @@ -99,7 +104,7 @@ def _coerce_seeds(cls, data: Any) -> Any: # Non-prompt seeds narrow data_type to Literal["text"] and don't have # role/sequence/parameters fields. Drop them so dataset/group-level # defaults don't bleed in and trip extra="forbid". - for prompt_only in ("data_type", "role", "sequence", "parameters"): + for prompt_only in PROMPT_ONLY_SEED_KEYS: seed.pop(prompt_only, None) normalized.append(seed) @@ -257,12 +262,12 @@ def _enforce_no_sequence_overlap_with_simulated(self) -> None: # Seed Accessors # ========================================================================= - def _get_objective(self) -> Optional[SeedObjective]: + def _get_objective(self) -> SeedObjective | None: """ Get the objective seed if present. Returns: - Optional[SeedObjective]: Objective seed when available; otherwise None. + SeedObjective | None: Objective seed when available; otherwise None. """ for seed in self.seeds: @@ -270,12 +275,12 @@ def _get_objective(self) -> Optional[SeedObjective]: return seed return None - def _get_simulated_conversation(self) -> Optional[SeedSimulatedConversation]: + def _get_simulated_conversation(self) -> SeedSimulatedConversation | None: """ Get the simulated conversation seed if present. Returns: - Optional[SeedSimulatedConversation]: Simulated conversation seed when available; otherwise None. + SeedSimulatedConversation | None: Simulated conversation seed when available; otherwise None. """ for seed in self.seeds: @@ -289,7 +294,7 @@ def prompts(self) -> Sequence[SeedPrompt]: return [seed for seed in self.seeds if isinstance(seed, SeedPrompt)] @property - def objective(self) -> Optional[SeedObjective]: + def objective(self) -> SeedObjective | None: """Get the objective for this group.""" return self._get_objective() @@ -313,7 +318,7 @@ def harm_categories(self) -> list[str]: # ========================================================================= @property - def simulated_conversation_config(self) -> Optional[SeedSimulatedConversation]: + def simulated_conversation_config(self) -> SeedSimulatedConversation | None: """Get the simulated conversation configuration if set.""" return self._get_simulated_conversation() @@ -327,7 +332,7 @@ def has_simulated_conversation(self) -> bool: # ========================================================================= @property - def prepended_conversation(self) -> Optional[list[Message]]: + def prepended_conversation(self) -> list[Message] | None: """ Returns Messages that should be prepended as conversation history. @@ -357,7 +362,7 @@ def prepended_conversation(self) -> Optional[list[Message]]: return self._prompts_to_messages(list(self.prompts)) @property - def next_message(self) -> Optional[Message]: + def next_message(self) -> Message | None: """ Returns a Message containing only the last turn's prompts if it's a user message. @@ -397,7 +402,7 @@ def user_messages(self) -> list[Message]: return self._prompts_to_messages(list(self.prompts)) - def _get_last_sequence_role(self) -> Optional[str]: + def _get_last_sequence_role(self) -> str | None: """ Get the role of the last sequence. diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index 6357fbf6ae..027b6935e9 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -9,7 +9,7 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, Literal, Optional, Union +from typing import TYPE_CHECKING, Literal from pydantic import Field, model_validator from tinytag import TinyTag @@ -38,17 +38,17 @@ class SeedPrompt(Seed): # The type of data this prompt represents (e.g., text, image_path, audio_path, video_path) # This field overrides the base default to allow per-prompt data types inferred from the value - data_type: Optional[PromptDataType] = None + data_type: PromptDataType | None = None # Role of the prompt in a conversation (e.g., "user", "assistant") - role: Optional[ChatMessageRole] = None + role: ChatMessageRole | None = None # Sequence number for ordering prompts in a conversation, prompts with # the same sequence number are grouped together if they also share the same prompt_group_id sequence: int = 0 # Parameters that can be used in the prompt template - parameters: Optional[list[str]] = Field(default_factory=list) + parameters: list[str] | None = Field(default_factory=list) @model_validator(mode="after") def _render_and_infer_data_type(self) -> SeedPrompt: @@ -136,15 +136,15 @@ def set_encoding_metadata(self) -> None: @classmethod def from_yaml_with_required_parameters( cls, - template_path: Union[str, Path], + template_path: str | Path, required_parameters: list[str], - error_message: Optional[str] = None, + error_message: str | None = None, ) -> SeedPrompt: """ Load a SeedPrompt from a YAML file and validate that it declares each required parameter. Thin shim that delegates to - :func:`pyrit.models.seeds.yaml_seed_loader.load_seed_prompt_from_yaml_with_required_parameters`. + ``pyrit.models.seeds.yaml_seed_loader.load_seed_prompt_from_yaml_with_required_parameters``. Args: template_path: Path to the YAML file containing the template. @@ -157,6 +157,8 @@ def from_yaml_with_required_parameters( Raises: ValueError: If the template doesn't contain all required parameters. """ + # Deferred import: yaml_seed_loader imports SeedPrompt at module load, so importing + # it at the top of this module would create a circular import. from pyrit.models.seeds.yaml_seed_loader import load_seed_prompt_from_yaml_with_required_parameters return load_seed_prompt_from_yaml_with_required_parameters( @@ -168,7 +170,7 @@ def from_messages( messages: list[Message], *, starting_sequence: int = 0, - prompt_group_id: Optional[uuid.UUID] = None, + prompt_group_id: uuid.UUID | None = None, ) -> list[SeedPrompt]: """ Convert a list of Messages to a list of SeedPrompts. diff --git a/pyrit/models/seeds/seed_simulated_conversation.py b/pyrit/models/seeds/seed_simulated_conversation.py index b1359445c8..7b5daa0009 100644 --- a/pyrit/models/seeds/seed_simulated_conversation.py +++ b/pyrit/models/seeds/seed_simulated_conversation.py @@ -19,7 +19,7 @@ import json import logging from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from pydantic import field_validator, model_validator @@ -87,8 +87,8 @@ class SeedSimulatedConversation(Seed): sequence: int = 0 adversarial_chat_system_prompt_path: Path simulated_target_system_prompt_path: Path = SimulatedTargetSystemPromptPaths.COMPLIANT.value - next_message_system_prompt_path: Optional[Path] = None - pyrit_version: Optional[str] = None + next_message_system_prompt_path: Path | None = None + pyrit_version: str | None = None @model_validator(mode="before") @classmethod @@ -182,8 +182,8 @@ def load_simulated_target_system_prompt( *, objective: str, num_turns: int, - simulated_target_system_prompt_path: Optional[Union[str, Path]] = None, - ) -> Optional[str]: + simulated_target_system_prompt_path: str | Path | None = None, + ) -> str | None: """ Load and render the simulated target system prompt.