Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions .github/instructions/models.instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,30 @@ applyTo: "pyrit/models/**"

## Import Boundary

`pyrit.models` is the canonical data layer. Files in `pyrit/models/` may
import only from:
PyRIT enforces a two-layer rule for its foundational packages. `pyrit.common`
is the foundation layer and `pyrit.models` is the canonical data layer that sits
directly on top of it.

**Forward (models).** Files in `pyrit/models/` may import only from:

- the standard library
- `pydantic`
- `pyrit.common.deprecation`
- any `pyrit.common.*` submodule (the whole prefix)
- other `pyrit.models.*` submodules

If a helper needs another `pyrit.*` package, it does not belong on a model —
put it in that package as a free function or static helper.
If a helper needs another `pyrit.*` package (e.g. `pyrit.memory`,
`pyrit.score`), it does not belong on a model — put it in that package as a free
function or static helper.

**Reverse guard (common).** Files in `pyrit/common/` may import only from the
standard library, third-party libraries, and other `pyrit.common.*` submodules.
`pyrit.common` may never import any other `pyrit.*` package — this is what keeps
it a true foundation and prevents an import cycle with `pyrit.models`. If a
`pyrit.common` helper needs `pyrit.models` (or anything higher), it belongs in
that higher layer, not in `common`.

The CI test `tests/unit/models/test_import_boundary.py` enforces this using an
allowlist of known transitional violations, each tagged with the phase that
removes it. The list must shrink monotonically: removing an import from source
without also removing its allowlist entry fails the test, and adding a new
unlisted import also fails the test.
The CI test `tests/unit/models/test_import_boundary.py` enforces both directions
using allowlists of known transitional violations, each tagged with the phase
that removes it. The lists must shrink monotonically: removing an import from
source without also removing its allowlist entry fails the test, and adding a
new unlisted import also fails the test.
19 changes: 11 additions & 8 deletions pyrit/executor/attack/core/attack_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,21 @@ async def from_seed_group_async(
An instance of this AttackParameters type.

Raises:
ValueError: If seed_group has no objective or if overrides contain invalid fields.
ValueError: If seed_group has simulated conversation but adversarial_chat/scorer not provided.
TypeError: If ``seed_group`` is not a ``SeedAttackGroup``.
ValueError: If overrides contain invalid fields, or if seed_group has a simulated
conversation but adversarial_chat/scorer are not provided.
"""
# Import here to avoid circular imports
from pyrit.executor.attack.multi_turn.simulated_conversation import (
generate_simulated_conversation_async,
)

if not isinstance(seed_group, SeedAttackGroup):
raise TypeError(
f"seed_group must be a SeedAttackGroup, got {type(seed_group).__name__}. "
"Plain SeedGroup does not enforce the 'exactly one objective' invariant required for an attack."
)

# Get valid field names for this params type
valid_fields = {f.name for f in dataclasses.fields(cls)}

Expand All @@ -119,12 +126,8 @@ async def from_seed_group_async(
f"{cls.__name__} does not accept parameters: {invalid_fields}. Accepted parameters: {valid_fields}"
)

# Validate seed_group state before extracting parameters
seed_group.validate()

# SeedAttackGroup validates in __init__ that objective is set
if seed_group.objective is None:
raise ValueError("seed_group.objective is not initialized")
# SeedAttackGroup's Pydantic validator guarantees exactly one objective is present.
assert seed_group.objective is not None

# Build params dict, only including fields this class accepts
params: dict[str, Any] = {}
Expand Down
18 changes: 13 additions & 5 deletions pyrit/models/seeds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,24 @@
SeedSimulatedConversation,
SimulatedTargetSystemPromptPaths,
)
from pyrit.models.seeds.yaml_seed_loader import (
load_seed_dataset_from_yaml,
load_seed_from_yaml,
load_seed_prompt_from_yaml_with_required_parameters,
)

__all__ = [
"load_seed_dataset_from_yaml",
"load_seed_from_yaml",
"load_seed_prompt_from_yaml_with_required_parameters",
"NextMessageSystemPromptPaths",
"Seed",
"SeedPrompt",
"SeedObjective",
"SeedGroup",
"SeedAttackGroup",
"SeedAttackTechniqueGroup",
"SeedDataset",
"SeedGroup",
"SeedObjective",
"SeedPrompt",
"SeedSimulatedConversation",
"SimulatedTargetSystemPromptPaths",
"NextMessageSystemPromptPaths",
"SeedDataset",
]
124 changes: 59 additions & 65 deletions pyrit/models/seeds/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,53 @@

from __future__ import annotations

import abc
import logging
import re
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Annotated, Any, TypeVar

import yaml
from jinja2 import StrictUndefined, Undefined
from jinja2.sandbox import SandboxedEnvironment
from pydantic import AwareDatetime, BaseModel, BeforeValidator, ConfigDict, Field

from pyrit.common.utils import verify_and_resolve_path
from pyrit.common.yaml_loadable import YamlLoadable
from pyrit.models.literals import PromptDataType # noqa: TC001 (runtime-required by Pydantic field annotations)

if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from collections.abc import Iterator
from pathlib import Path

from pyrit.models.literals import PromptDataType

logger = logging.getLogger(__name__)

# TypeVar for generic return type in class methods
T = TypeVar("T", bound="Seed")


def _ensure_aware_utc(value: Any) -> Any:
    """
    Coerce naive datetimes, plain dates, and ISO date strings to UTC so AwareDatetime accepts them.

    Intended as a ``BeforeValidator``: it runs on the raw input before Pydantic's own
    datetime validation, so anything it does not recognise is passed through untouched
    and Pydantic reports the error.

    Args:
        value: The raw value provided for a datetime field (string, datetime, date, or anything else).

    Returns:
        Any: A timezone-aware datetime when the input was a naive datetime, a plain
            ``datetime.date`` (interpreted as midnight UTC), or a parseable ISO string without
            an offset; an already-aware datetime unchanged; otherwise the value unchanged
            for Pydantic to validate.
    """
    # Local import: the module-level import only brings in ``datetime`` and ``timezone``.
    from datetime import date

    if isinstance(value, str):
        try:
            value = datetime.fromisoformat(value)
        except ValueError:
            # Not an ISO-8601 string: hand the raw text to Pydantic unchanged.
            return value
    elif isinstance(value, date) and not isinstance(value, datetime):
        # ``datetime`` subclasses ``date``, hence the explicit exclusion above.
        # YAML loaders yield plain ``date`` objects for unquoted ``YYYY-MM-DD`` scalars;
        # promote them to midnight so the naive-to-UTC step below applies.
        value = datetime(value.year, value.month, value.day)

    if isinstance(value, datetime) and value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value


# Timezone-aware datetime that interprets naive inputs as UTC instead of rejecting them.
AwareDatetimeUTC = Annotated[AwareDatetime, BeforeValidator(_ensure_aware_utc)]


class PartialUndefined(Undefined):
"""Jinja undefined value that preserves unresolved placeholders as text."""

Expand Down Expand Up @@ -81,54 +101,55 @@ def __bool__(self) -> bool:
return True # Ensures it doesn't evaluate to False


@dataclass
class Seed(YamlLoadable):
class Seed(BaseModel):
"""Represents seed data with various attributes and metadata."""

model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

# The actual prompt value, which can be a string or a file path
value: str

# SHA256 hash of the value, used for deduplication
value_sha256: Optional[str] = None
value_sha256: str | None = None

# Unique identifier for the prompt
id: Optional[uuid.UUID] = field(default_factory=lambda: uuid.uuid4())
id: uuid.UUID | None = Field(default_factory=uuid.uuid4)

# Name of the prompt
name: Optional[str] = None
name: str | None = None

# Name of the dataset this prompt belongs to
dataset_name: Optional[str] = None
dataset_name: str | None = None

# Categories of harm associated with this prompt
harm_categories: Optional[Sequence[str]] = field(default_factory=list)
harm_categories: list[str] | None = Field(default_factory=list)

# Description of the prompt
description: Optional[str] = None
description: str | None = None

# Authors of the prompt
authors: Optional[Sequence[str]] = field(default_factory=list)
authors: list[str] | None = Field(default_factory=list)

# Groups affiliated with the prompt
groups: Optional[Sequence[str]] = field(default_factory=list)
groups: list[str] | None = Field(default_factory=list)

# Source of the prompt
source: Optional[str] = None
source: str | None = None

# Date when the prompt was added to the dataset
date_added: Optional[datetime] = field(default_factory=lambda: datetime.now(tz=timezone.utc))
date_added: AwareDatetimeUTC | None = Field(default_factory=lambda: datetime.now(tz=timezone.utc))

# User who added the prompt to the dataset
added_by: Optional[str] = None
added_by: str | None = None

# Arbitrary metadata that can be attached to the prompt
metadata: Optional[dict[str, Union[str, int]]] = field(default_factory=dict)
metadata: dict[str, Any] | None = Field(default_factory=dict)

# Unique identifier for the prompt group
prompt_group_id: Optional[uuid.UUID] = None
prompt_group_id: uuid.UUID | None = None

# Alias for the prompt group
prompt_group_alias: Optional[str] = None
prompt_group_alias: str | None = None

# Whether this seed represents a general attack technique (not tied to a specific objective)
is_general_technique: bool = False
Expand All @@ -138,15 +159,10 @@ class Seed(YamlLoadable):
# to prevent template injection. Trusted sources (YAML files) set this to True automatically.
is_jinja_template: bool = False

@property
def data_type(self) -> PromptDataType:
"""
Return the data type for this seed.

Base implementation returns 'text'. SeedPrompt overrides this
to support multiple data types (image_path, audio_path, etc.).
"""
return "text"
# The type of data this seed represents (e.g., text, image_path, audio_path, video_path).
# SeedPrompt overrides the default to None and infers it from the value; other seed types
# narrow it to Literal["text"].
data_type: PromptDataType = "text"

def render_template_value(self, **kwargs: Any) -> str:
"""
Expand Down Expand Up @@ -247,46 +263,24 @@ def escape_for_jinja(value: str) -> str:
return f"{{% raw %}}{value}{{% endraw %}}"

@classmethod
def from_yaml_file(cls: type[T], file: Union[str, Path]) -> T:
def from_yaml_file(cls: type[T], file: str | Path) -> T:
"""
Create a new Seed from a YAML file, marking it as a trusted Jinja2 template.

Thin shim that delegates to ``load_seed_from_yaml`` in the ``yaml_seed_loader`` module;
file I/O and the ``is_jinja_template`` trust marker live in the loader module.

Args:
file: The input file path.

Returns:
A new Seed of the specific subclass type.

Raises:
ValueError: If the YAML file is invalid.
"""
file = verify_and_resolve_path(file)

try:
yaml_data = yaml.safe_load(file.read_text("utf-8"))
except yaml.YAMLError as exc:
raise ValueError(f"Invalid YAML file '{file}': {exc}") from exc

yaml_data["is_jinja_template"] = True
return cls(**yaml_data)

@classmethod
@abc.abstractmethod
def from_yaml_with_required_parameters(
cls,
template_path: Union[str, Path],
required_parameters: list[str],
error_message: Optional[str] = None,
) -> Seed:
FileNotFoundError: If the path does not resolve to an existing file.
ValueError: If the YAML file is invalid or empty.
"""
Load a Seed from a YAML file and validate that it contains specific parameters.

Args:
template_path: Path to the YAML file containing the template.
required_parameters: List of parameter names that must exist in the template.
error_message: Custom error message if validation fails. If None, a default message is used.
# Deferred import: yaml_seed_loader imports Seed, so importing it at module top would cycle.
from pyrit.models.seeds.yaml_seed_loader import load_seed_from_yaml

Returns:
Seed: The loaded and validated seed of the specific subclass type.

"""
return load_seed_from_yaml(file, cls=cls)
25 changes: 3 additions & 22 deletions pyrit/models/seeds/seed_attack_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING

from pyrit.models.seeds.seed_group import SeedGroup
from pyrit.models.seeds.seed_objective import SeedObjective

if TYPE_CHECKING:
from collections.abc import Sequence

from pyrit.models.seeds.seed import Seed
from pyrit.models.seeds.seed_attack_technique_group import SeedAttackTechniqueGroup


Expand All @@ -32,25 +31,7 @@ class SeedAttackGroup(SeedGroup):
next_message, etc.) is inherited from SeedGroup.
"""

def __init__(
self,
*,
seeds: Sequence[Union[Seed, dict[str, Any]]],
) -> None:
"""
Initialize a SeedAttackGroup.

Args:
seeds: Sequence of seeds. Must include exactly one SeedObjective.

Raises:
ValueError: If seeds is empty.
ValueError: If exactly one objective is not provided.

"""
super().__init__(seeds=seeds)

def validate(self) -> None:
def _check_invariants(self) -> None:
"""
Validate the seed attack group state.

Expand All @@ -60,7 +41,7 @@ def validate(self) -> None:
ValueError: If validation fails.

"""
super().validate()
super()._check_invariants()
self._enforce_exactly_one_objective()

def _enforce_exactly_one_objective(self) -> None:
Expand Down
Loading
Loading