From aab30238161567221b545b86555a3c590ffa7dc4 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 2 Jun 2026 15:17:26 -0700 Subject: [PATCH 1/5] Reorganize AttackResult/StrategyResult into pyrit.models.results package Move AttackResult, AttackOutcome, and StrategyResult into a new pyrit.models.results sub-package, mirroring the pyrit.models.messages reorganization. The old pyrit/models/attack_result.py and pyrit/models/strategy_result.py modules become silent backward-compat shims that re-export from the new location, so existing deep-path imports keep working. Classes remain @dataclass for now; the Pydantic v2 conversion lands in a follow-up once Score (Phase 5) is merged. Also add a .gitignore negation so the new package dir is not caught by the existing runtime 'results/' ignore rule. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .gitignore | 1 + pyrit/models/__init__.py | 4 +- pyrit/models/attack_result.py | 346 +--------------------- pyrit/models/results/__init__.py | 21 ++ pyrit/models/results/attack_result.py | 343 +++++++++++++++++++++ pyrit/models/results/strategy_result.py | 26 ++ pyrit/models/strategy_result.py | 28 +- tests/unit/models/test_attack_result.py | 2 +- tests/unit/models/test_scenario_result.py | 2 +- tests/unit/models/test_strategy_result.py | 2 +- 10 files changed, 421 insertions(+), 354 deletions(-) create mode 100644 pyrit/models/results/__init__.py create mode 100644 pyrit/models/results/attack_result.py create mode 100644 pyrit/models/results/strategy_result.py diff --git a/.gitignore b/.gitignore index 3668d43b6d..9bb4847b0f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # PyRIT-specific configs submodules/ results/ +!pyrit/models/results/ dbdata/ eval/ default_memory.json.memory diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index cb071a0b9a..520f37f6e2 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -20,7 +20,6 @@ from typing import TYPE_CHECKING, Any from pyrit.common.deprecation import print_deprecation_message -from pyrit.models.attack_result import AttackOutcome, AttackResult, AttackResultT from pyrit.models.chat_message import ( ALLOWED_CHAT_MESSAGE_ROLES, ChatMessage, @@ -74,6 +73,8 @@ sort_message_pieces, ) from pyrit.models.question_answering import QuestionAnsweringDataset, QuestionAnsweringEntry, QuestionChoice +from pyrit.models.results.attack_result import AttackOutcome, AttackResult, AttackResultT +from pyrit.models.results.strategy_result import StrategyResult, StrategyResultT from pyrit.models.retry_event import RetryEvent from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult from pyrit.models.score import Score, ScoreType, UnvalidatedScore @@ -96,7 +97,6 @@ # Keep old module-level imports working (deprecated, will be removed) # These are re-exported from the seeds submodule from pyrit.models.storage_io import AzureBlobStorageIO, DiskStorageIO, StorageIO -from pyrit.models.strategy_result import StrategyResult, StrategyResultT __all__ = [ "ALLOWED_CHAT_MESSAGE_ROLES", diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index eaafc67030..19ab95a058 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -1,343 +1,23 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from __future__ import annotations +""" +Backward-compatibility shim. -import functools -import uuid -from dataclasses import dataclass, field -from datetime import datetime, timezone -from enum import Enum -from typing import Any, Optional, TypeVar +``AttackResult`` and ``AttackOutcome`` now live in ``pyrit.models.results``. +Import from there (or from ``pyrit.models``) instead. This module re-exports the +public names so existing ``from pyrit.models.attack_result import ...`` imports +keep working. +""" -from pyrit.common.deprecation import print_deprecation_message -from pyrit.models.conversation_reference import ConversationReference, ConversationType -from pyrit.models.identifiers.atomic_attack_identifier import build_atomic_attack_identifier -from pyrit.models.identifiers.component_identifier import ComponentIdentifier -from pyrit.models.messages.message_piece import MessagePiece -from pyrit.models.retry_event import RetryEvent -from pyrit.models.score import Score -from pyrit.models.strategy_result import StrategyResult +from typing import Any -AttackResultT = TypeVar("AttackResultT", bound="AttackResult") +from pyrit.models.results import attack_result as _attack_result +from pyrit.models.results.attack_result import AttackOutcome, AttackResult, AttackResultT -class AttackOutcome(str, Enum): - """ - Enum representing the possible outcomes of an attack. +def __getattr__(name: str) -> Any: + return getattr(_attack_result, name) - Inherits from ``str`` so that values serialize naturally in Pydantic - models and REST responses without a dedicated mapping function. - """ - # The attack was successful in achieving its objective - SUCCESS = "success" - - # The attack failed to achieve its objective - FAILURE = "failure" - - # The attack failed due to an infrastructure error (exception), not a defensive refusal - ERROR = "error" - - # The outcome of the attack is unknown or could not be determined - UNDETERMINED = "undetermined" - - -@dataclass -class AttackResult(StrategyResult): - """Base class for all attack results.""" - - # Identity - # Unique identifier of the conversation that produced this result - conversation_id: str - - # Natural-language description of the attacker's objective - objective: str - - # Database-assigned unique ID for this AttackResult row. - # Auto-generated if not provided (e.g. when loading from DB, the persisted ID is passed in). - attack_result_id: str = field(default_factory=lambda: str(uuid.uuid4())) - - # Composite identifier combining the attack strategy identity with - # seed identifiers from the dataset. - # Contains the attack strategy as children["attack"] plus optional seeds. - atomic_attack_identifier: Optional[ComponentIdentifier] = None - - # Evidence - # Model response generated in the final turn of the attack - last_response: Optional[MessagePiece] = None - - # Score assigned to the final response by a scorer component - last_score: Optional[Score] = None - - # Metrics - # Total number of turns that were executed - executed_turns: int = 0 - - # Total execution time of the attack in milliseconds - execution_time_ms: int = 0 - - # Outcome - # The outcome of the attack, indicating success, failure, or undetermined - outcome: AttackOutcome = AttackOutcome.UNDETERMINED - - # Optional reason for the outcome, providing additional context - outcome_reason: Optional[str] = None - - # Wall-clock time the result was created or persisted. - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - - # Flexible conversation refs (nothing unused) - related_conversations: set[ConversationReference] = field(default_factory=set) - - # Arbitrary metadata - metadata: dict[str, Any] = field(default_factory=dict) - - # labels associated with this attack result - labels: dict[str, str] = field(default_factory=dict) - - # Error information (populated when attack fails with exception) - error_message: str | None = None - error_type: str | None = None - error_traceback: str | None = None - - # Retry tracking - retry_events: list[RetryEvent] = field(default_factory=list) - total_retries: int = 0 - - # Attribution / parent linkage (infrastructure-managed). Set by the attack - # persistence path when an AttackResultAttribution is present on the - # AttackContext. User code should not set these directly; ad-hoc - # AttackResults created outside an orchestrator leave both fields as None - # and the corresponding DB columns remain NULL. - attribution_parent_id: str | None = None - attribution_data: dict[str, Any] | None = None - - @property - def attack_identifier(self) -> Optional[ComponentIdentifier]: - """ - Deprecated: use ``get_attack_strategy_identifier()`` or ``atomic_attack_identifier`` instead. - - Returns the attack strategy ``ComponentIdentifier`` extracted from - ``atomic_attack_identifier``, emitting a deprecation warning. - - Returns: - Optional[ComponentIdentifier]: The attack strategy identifier, or ``None``. - - """ - print_deprecation_message( - old_item="AttackResult.attack_identifier", - new_item="AttackResult.atomic_attack_identifier or get_attack_strategy_identifier()", - removed_in="0.15.0", - ) - return self.get_attack_strategy_identifier() - - def get_attack_strategy_identifier(self) -> Optional[ComponentIdentifier]: - """ - Return the attack strategy identifier from the composite atomic identifier. - - This is the non-deprecated replacement for the ``attack_identifier`` property. - Extracts the ``"attack"`` child from the nested ``"attack_technique"`` child - of ``atomic_attack_identifier``. - - Falls back to ``children["attack"]`` for rows created before the nested - structure was introduced. - - Returns: - Optional[ComponentIdentifier]: The attack strategy identifier, or ``None`` if - ``atomic_attack_identifier`` is not set or the expected children are missing. - - """ - if self.atomic_attack_identifier is None: - return None - technique = self.atomic_attack_identifier.get_child("attack_technique") - if technique is not None: - return technique.get_child("attack") - # Fallback for pre-nesting rows that had children["attack"] directly. - return self.atomic_attack_identifier.get_child("attack") - - def get_conversations_by_type(self, conversation_type: ConversationType) -> list[ConversationReference]: - """ - Return all related conversations of the requested type. - - Args: - conversation_type (ConversationType): The type of conversation to filter by. - - Returns: - list: A list of related conversations matching the specified type. - - """ - return [ref for ref in self.related_conversations if ref.conversation_type == conversation_type] - - def get_all_conversation_ids(self) -> set[str]: - """ - Return the main conversation ID plus all related conversation IDs. - - Returns: - set[str]: All conversation IDs associated with this attack. - """ - return {self.conversation_id} | {ref.conversation_id for ref in self.related_conversations} - - def get_active_conversation_ids(self) -> set[str]: - """ - Return the main conversation ID plus pruned (user-visible) related conversation IDs. - - Excludes adversarial chat conversations which are internal implementation details. - - Returns: - set[str]: Main + pruned conversation IDs. - """ - return {self.conversation_id} | { - ref.conversation_id - for ref in self.related_conversations - if ref.conversation_type == ConversationType.PRUNED - } - - def get_pruned_conversation_ids(self) -> list[str]: - """ - Return IDs of pruned (branched) conversations only. - - Returns: - list[str]: Pruned conversation IDs. - """ - return [ - ref.conversation_id - for ref in self.related_conversations - if ref.conversation_type == ConversationType.PRUNED - ] - - def includes_conversation(self, conversation_id: str) -> bool: - """ - Check whether a conversation belongs to this attack (main or any related). - - Args: - conversation_id (str): The conversation ID to check. - - Returns: - bool: True if the conversation is part of this attack. - """ - return conversation_id in self.get_all_conversation_ids() - - def __str__(self) -> str: - """ - Return a concise string representation of this attack result. - - Returns: - str: Summary containing conversation ID, outcome, and objective preview. - - """ - return f"AttackResult: {self.conversation_id}: {self.outcome.value}: {self.objective[:50]}..." - - def to_dict(self) -> dict[str, Any]: - """ - Serialize this attack result to a JSON-compatible dictionary. - - Returns: - dict[str, Any]: Serialized payload suitable for REST APIs or persistence. - """ - return { - "conversation_id": self.conversation_id, - "objective": self.objective, - "attack_result_id": self.attack_result_id, - "atomic_attack_identifier": ( - self.atomic_attack_identifier.model_dump() if self.atomic_attack_identifier else None - ), - "last_response": self.last_response.model_dump(mode="json") if self.last_response else None, - "last_score": self.last_score.to_dict() if self.last_score else None, - "executed_turns": self.executed_turns, - "execution_time_ms": self.execution_time_ms, - "outcome": self.outcome.value, - "outcome_reason": self.outcome_reason, - "timestamp": self.timestamp.isoformat(), - "related_conversations": sorted( - [ref.model_dump(mode="json") for ref in self.related_conversations], - key=lambda r: r["conversation_id"], - ), - "metadata": self.metadata, - "labels": self.labels, - "error_message": self.error_message, - "error_type": self.error_type, - "error_traceback": self.error_traceback, - "retry_events": [e.model_dump(mode="json") for e in self.retry_events], - "total_retries": self.total_retries, - } - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> AttackResult: - """ - Reconstruct an AttackResult from a dictionary. - - Args: - data (dict[str, Any]): Dictionary as produced by to_dict(). - - Returns: - AttackResult: Reconstructed instance. - """ - return cls( - conversation_id=data["conversation_id"], - objective=data["objective"], - attack_result_id=data.get("attack_result_id", str(uuid.uuid4())), - atomic_attack_identifier=( - ComponentIdentifier.model_validate(data["atomic_attack_identifier"]) - if data.get("atomic_attack_identifier") - else None - ), - last_response=(MessagePiece.model_validate(data["last_response"]) if data.get("last_response") else None), - last_score=Score.from_dict(data["last_score"]) if data.get("last_score") else None, - executed_turns=data.get("executed_turns", 0), - execution_time_ms=data.get("execution_time_ms", 0), - outcome=AttackOutcome(data.get("outcome", "undetermined")), - outcome_reason=data.get("outcome_reason"), - timestamp=( - datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else datetime.now(timezone.utc) - ), - related_conversations={ - ConversationReference.model_validate(r) for r in data.get("related_conversations", []) - }, - metadata=data.get("metadata", {}), - labels=data.get("labels", {}), - error_message=data.get("error_message"), - error_type=data.get("error_type"), - error_traceback=data.get("error_traceback"), - retry_events=[RetryEvent.model_validate(e) for e in data.get("retry_events", [])], - total_retries=data.get("total_retries", 0), - ) - - -def _add_attack_identifier_compat(cls: type) -> type: - """ - Wrap a dataclass ``__init__`` to accept the deprecated ``attack_identifier`` kwarg. - - When ``attack_identifier`` is passed, it is automatically promoted to - ``atomic_attack_identifier`` via ``build_atomic_attack_identifier`` and a - deprecation warning is emitted. - - Args: - cls: The dataclass to wrap. - - Returns: - The same class with a wrapped ``__init__``. - - """ - original_init = cls.__init__ - - @functools.wraps(original_init) - def wrapped_init(self: Any, *args: Any, **kwargs: Any) -> None: - attack_identifier = kwargs.pop("attack_identifier", None) - if attack_identifier is not None: - print_deprecation_message( - old_item="AttackResult(attack_identifier=...)", - new_item="AttackResult(atomic_attack_identifier=...)", - removed_in="0.15.0", - ) - if kwargs.get("atomic_attack_identifier") is None: - kwargs["atomic_attack_identifier"] = build_atomic_attack_identifier( - attack_identifier=attack_identifier, - ) - original_init(self, *args, **kwargs) - - cls.__init__ = wrapped_init # type: ignore[ty:invalid-assignment] - return cls - - -_add_attack_identifier_compat(AttackResult) +__all__ = ["AttackOutcome", "AttackResult", "AttackResultT"] diff --git a/pyrit/models/results/__init__.py b/pyrit/models/results/__init__.py new file mode 100644 index 0000000000..4bcc2f8848 --- /dev/null +++ b/pyrit/models/results/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Results module - strategy and attack result types for PyRIT. + +- StrategyResult: Base class for all strategy results. +- AttackResult: Result of an attack execution, with conversation/scoring evidence. +- AttackOutcome: Enum of possible attack outcomes. +""" + +from pyrit.models.results.attack_result import AttackOutcome, AttackResult, AttackResultT +from pyrit.models.results.strategy_result import StrategyResult, StrategyResultT + +__all__ = [ + "AttackOutcome", + "AttackResult", + "AttackResultT", + "StrategyResult", + "StrategyResultT", +] diff --git a/pyrit/models/results/attack_result.py b/pyrit/models/results/attack_result.py new file mode 100644 index 0000000000..e937ed90e4 --- /dev/null +++ b/pyrit/models/results/attack_result.py @@ -0,0 +1,343 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import functools +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Optional, TypeVar + +from pyrit.common.deprecation import print_deprecation_message +from pyrit.models.conversation_reference import ConversationReference, ConversationType +from pyrit.models.identifiers.atomic_attack_identifier import build_atomic_attack_identifier +from pyrit.models.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.messages.message_piece import MessagePiece +from pyrit.models.results.strategy_result import StrategyResult +from pyrit.models.retry_event import RetryEvent +from pyrit.models.score import Score + +AttackResultT = TypeVar("AttackResultT", bound="AttackResult") + + +class AttackOutcome(str, Enum): + """ + Enum representing the possible outcomes of an attack. + + Inherits from ``str`` so that values serialize naturally in Pydantic + models and REST responses without a dedicated mapping function. + """ + + # The attack was successful in achieving its objective + SUCCESS = "success" + + # The attack failed to achieve its objective + FAILURE = "failure" + + # The attack failed due to an infrastructure error (exception), not a defensive refusal + ERROR = "error" + + # The outcome of the attack is unknown or could not be determined + UNDETERMINED = "undetermined" + + +@dataclass +class AttackResult(StrategyResult): + """Base class for all attack results.""" + + # Identity + # Unique identifier of the conversation that produced this result + conversation_id: str + + # Natural-language description of the attacker's objective + objective: str + + # Database-assigned unique ID for this AttackResult row. + # Auto-generated if not provided (e.g. when loading from DB, the persisted ID is passed in). + attack_result_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + # Composite identifier combining the attack strategy identity with + # seed identifiers from the dataset. + # Contains the attack strategy as children["attack"] plus optional seeds. + atomic_attack_identifier: Optional[ComponentIdentifier] = None + + # Evidence + # Model response generated in the final turn of the attack + last_response: Optional[MessagePiece] = None + + # Score assigned to the final response by a scorer component + last_score: Optional[Score] = None + + # Metrics + # Total number of turns that were executed + executed_turns: int = 0 + + # Total execution time of the attack in milliseconds + execution_time_ms: int = 0 + + # Outcome + # The outcome of the attack, indicating success, failure, or undetermined + outcome: AttackOutcome = AttackOutcome.UNDETERMINED + + # Optional reason for the outcome, providing additional context + outcome_reason: Optional[str] = None + + # Wall-clock time the result was created or persisted. + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + # Flexible conversation refs (nothing unused) + related_conversations: set[ConversationReference] = field(default_factory=set) + + # Arbitrary metadata + metadata: dict[str, Any] = field(default_factory=dict) + + # labels associated with this attack result + labels: dict[str, str] = field(default_factory=dict) + + # Error information (populated when attack fails with exception) + error_message: str | None = None + error_type: str | None = None + error_traceback: str | None = None + + # Retry tracking + retry_events: list[RetryEvent] = field(default_factory=list) + total_retries: int = 0 + + # Attribution / parent linkage (infrastructure-managed). Set by the attack + # persistence path when an AttackResultAttribution is present on the + # AttackContext. User code should not set these directly; ad-hoc + # AttackResults created outside an orchestrator leave both fields as None + # and the corresponding DB columns remain NULL. + attribution_parent_id: str | None = None + attribution_data: dict[str, Any] | None = None + + @property + def attack_identifier(self) -> Optional[ComponentIdentifier]: + """ + Deprecated: use ``get_attack_strategy_identifier()`` or ``atomic_attack_identifier`` instead. + + Returns the attack strategy ``ComponentIdentifier`` extracted from + ``atomic_attack_identifier``, emitting a deprecation warning. + + Returns: + Optional[ComponentIdentifier]: The attack strategy identifier, or ``None``. + + """ + print_deprecation_message( + old_item="AttackResult.attack_identifier", + new_item="AttackResult.atomic_attack_identifier or get_attack_strategy_identifier()", + removed_in="0.15.0", + ) + return self.get_attack_strategy_identifier() + + def get_attack_strategy_identifier(self) -> Optional[ComponentIdentifier]: + """ + Return the attack strategy identifier from the composite atomic identifier. + + This is the non-deprecated replacement for the ``attack_identifier`` property. + Extracts the ``"attack"`` child from the nested ``"attack_technique"`` child + of ``atomic_attack_identifier``. + + Falls back to ``children["attack"]`` for rows created before the nested + structure was introduced. + + Returns: + Optional[ComponentIdentifier]: The attack strategy identifier, or ``None`` if + ``atomic_attack_identifier`` is not set or the expected children are missing. + + """ + if self.atomic_attack_identifier is None: + return None + technique = self.atomic_attack_identifier.get_child("attack_technique") + if technique is not None: + return technique.get_child("attack") + # Fallback for pre-nesting rows that had children["attack"] directly. + return self.atomic_attack_identifier.get_child("attack") + + def get_conversations_by_type(self, conversation_type: ConversationType) -> list[ConversationReference]: + """ + Return all related conversations of the requested type. + + Args: + conversation_type (ConversationType): The type of conversation to filter by. + + Returns: + list: A list of related conversations matching the specified type. + + """ + return [ref for ref in self.related_conversations if ref.conversation_type == conversation_type] + + def get_all_conversation_ids(self) -> set[str]: + """ + Return the main conversation ID plus all related conversation IDs. + + Returns: + set[str]: All conversation IDs associated with this attack. + """ + return {self.conversation_id} | {ref.conversation_id for ref in self.related_conversations} + + def get_active_conversation_ids(self) -> set[str]: + """ + Return the main conversation ID plus pruned (user-visible) related conversation IDs. + + Excludes adversarial chat conversations which are internal implementation details. + + Returns: + set[str]: Main + pruned conversation IDs. + """ + return {self.conversation_id} | { + ref.conversation_id + for ref in self.related_conversations + if ref.conversation_type == ConversationType.PRUNED + } + + def get_pruned_conversation_ids(self) -> list[str]: + """ + Return IDs of pruned (branched) conversations only. + + Returns: + list[str]: Pruned conversation IDs. + """ + return [ + ref.conversation_id + for ref in self.related_conversations + if ref.conversation_type == ConversationType.PRUNED + ] + + def includes_conversation(self, conversation_id: str) -> bool: + """ + Check whether a conversation belongs to this attack (main or any related). + + Args: + conversation_id (str): The conversation ID to check. + + Returns: + bool: True if the conversation is part of this attack. + """ + return conversation_id in self.get_all_conversation_ids() + + def __str__(self) -> str: + """ + Return a concise string representation of this attack result. + + Returns: + str: Summary containing conversation ID, outcome, and objective preview. + + """ + return f"AttackResult: {self.conversation_id}: {self.outcome.value}: {self.objective[:50]}..." + + def to_dict(self) -> dict[str, Any]: + """ + Serialize this attack result to a JSON-compatible dictionary. + + Returns: + dict[str, Any]: Serialized payload suitable for REST APIs or persistence. + """ + return { + "conversation_id": self.conversation_id, + "objective": self.objective, + "attack_result_id": self.attack_result_id, + "atomic_attack_identifier": ( + self.atomic_attack_identifier.model_dump() if self.atomic_attack_identifier else None + ), + "last_response": self.last_response.model_dump(mode="json") if self.last_response else None, + "last_score": self.last_score.to_dict() if self.last_score else None, + "executed_turns": self.executed_turns, + "execution_time_ms": self.execution_time_ms, + "outcome": self.outcome.value, + "outcome_reason": self.outcome_reason, + "timestamp": self.timestamp.isoformat(), + "related_conversations": sorted( + [ref.model_dump(mode="json") for ref in self.related_conversations], + key=lambda r: r["conversation_id"], + ), + "metadata": self.metadata, + "labels": self.labels, + "error_message": self.error_message, + "error_type": self.error_type, + "error_traceback": self.error_traceback, + "retry_events": [e.model_dump(mode="json") for e in self.retry_events], + "total_retries": self.total_retries, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> AttackResult: + """ + Reconstruct an AttackResult from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + AttackResult: Reconstructed instance. + """ + return cls( + conversation_id=data["conversation_id"], + objective=data["objective"], + attack_result_id=data.get("attack_result_id", str(uuid.uuid4())), + atomic_attack_identifier=( + ComponentIdentifier.model_validate(data["atomic_attack_identifier"]) + if data.get("atomic_attack_identifier") + else None + ), + last_response=(MessagePiece.model_validate(data["last_response"]) if data.get("last_response") else None), + last_score=Score.from_dict(data["last_score"]) if data.get("last_score") else None, + executed_turns=data.get("executed_turns", 0), + execution_time_ms=data.get("execution_time_ms", 0), + outcome=AttackOutcome(data.get("outcome", "undetermined")), + outcome_reason=data.get("outcome_reason"), + timestamp=( + datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else datetime.now(timezone.utc) + ), + related_conversations={ + ConversationReference.model_validate(r) for r in data.get("related_conversations", []) + }, + metadata=data.get("metadata", {}), + labels=data.get("labels", {}), + error_message=data.get("error_message"), + error_type=data.get("error_type"), + error_traceback=data.get("error_traceback"), + retry_events=[RetryEvent.model_validate(e) for e in data.get("retry_events", [])], + total_retries=data.get("total_retries", 0), + ) + + +def _add_attack_identifier_compat(cls: type) -> type: + """ + Wrap a dataclass ``__init__`` to accept the deprecated ``attack_identifier`` kwarg. + + When ``attack_identifier`` is passed, it is automatically promoted to + ``atomic_attack_identifier`` via ``build_atomic_attack_identifier`` and a + deprecation warning is emitted. + + Args: + cls: The dataclass to wrap. + + Returns: + The same class with a wrapped ``__init__``. + + """ + original_init = cls.__init__ + + @functools.wraps(original_init) + def wrapped_init(self: Any, *args: Any, **kwargs: Any) -> None: + attack_identifier = kwargs.pop("attack_identifier", None) + if attack_identifier is not None: + print_deprecation_message( + old_item="AttackResult(attack_identifier=...)", + new_item="AttackResult(atomic_attack_identifier=...)", + removed_in="0.15.0", + ) + if kwargs.get("atomic_attack_identifier") is None: + kwargs["atomic_attack_identifier"] = build_atomic_attack_identifier( + attack_identifier=attack_identifier, + ) + original_init(self, *args, **kwargs) + + cls.__init__ = wrapped_init # type: ignore[ty:invalid-assignment] + return cls + + +_add_attack_identifier_compat(AttackResult) diff --git a/pyrit/models/results/strategy_result.py b/pyrit/models/results/strategy_result.py new file mode 100644 index 0000000000..38fac2af04 --- /dev/null +++ b/pyrit/models/results/strategy_result.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from abc import ABC +from copy import deepcopy +from dataclasses import dataclass +from typing import TypeVar + +StrategyResultT = TypeVar("StrategyResultT", bound="StrategyResult") + + +@dataclass +class StrategyResult(ABC): # noqa: B024 + """Base class for all strategy results.""" + + def duplicate(self: StrategyResultT) -> StrategyResultT: + """ + Create a deep copy of the result. + + Returns: + StrategyResult: A deep copy of the result. + + """ + return deepcopy(self) diff --git a/pyrit/models/strategy_result.py b/pyrit/models/strategy_result.py index 38fac2af04..bd4367695b 100644 --- a/pyrit/models/strategy_result.py +++ b/pyrit/models/strategy_result.py @@ -1,26 +1,22 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from __future__ import annotations +""" +Backward-compatibility shim. -from abc import ABC -from copy import deepcopy -from dataclasses import dataclass -from typing import TypeVar +``StrategyResult`` now lives in ``pyrit.models.results``. Import from there (or +from ``pyrit.models``) instead. This module re-exports the public names so +existing ``from pyrit.models.strategy_result import ...`` imports keep working. +""" -StrategyResultT = TypeVar("StrategyResultT", bound="StrategyResult") +from typing import Any +from pyrit.models.results import strategy_result as _strategy_result +from pyrit.models.results.strategy_result import StrategyResult, StrategyResultT -@dataclass -class StrategyResult(ABC): # noqa: B024 - """Base class for all strategy results.""" - def duplicate(self: StrategyResultT) -> StrategyResultT: - """ - Create a deep copy of the result. +def __getattr__(name: str) -> Any: + return getattr(_strategy_result, name) - Returns: - StrategyResult: A deep copy of the result. - """ - return deepcopy(self) +__all__ = ["StrategyResult", "StrategyResultT"] diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index 47465d0299..357c3059ae 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -6,9 +6,9 @@ from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import ComponentIdentifier, build_atomic_attack_identifier -from pyrit.models.attack_result import AttackOutcome, AttackResult from pyrit.models.conversation_reference import ConversationReference, ConversationType from pyrit.models.messages.message_piece import MessagePiece +from pyrit.models.results.attack_result import AttackOutcome, AttackResult from pyrit.models.retry_event import RetryEvent from pyrit.models.score import Score diff --git a/tests/unit/models/test_scenario_result.py b/tests/unit/models/test_scenario_result.py index d9722ccea1..fbf46e19da 100644 --- a/tests/unit/models/test_scenario_result.py +++ b/tests/unit/models/test_scenario_result.py @@ -4,7 +4,7 @@ import uuid from pyrit.models import ComponentIdentifier -from pyrit.models.attack_result import AttackOutcome, AttackResult +from pyrit.models.results.attack_result import AttackOutcome, AttackResult from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult diff --git a/tests/unit/models/test_strategy_result.py b/tests/unit/models/test_strategy_result.py index 15eceb5cbb..1eb3b3277d 100644 --- a/tests/unit/models/test_strategy_result.py +++ b/tests/unit/models/test_strategy_result.py @@ -3,7 +3,7 @@ from dataclasses import dataclass -from pyrit.models.strategy_result import StrategyResult +from pyrit.models.results.strategy_result import StrategyResult @dataclass From 6dc9c52a181317707b7b77f4d58fbe9a69c9a823 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 2 Jun 2026 16:17:44 -0700 Subject: [PATCH 2/5] Convert AttackResult/StrategyResult and subclasses to Pydantic Complete the Phase 6 Pydantic conversion: StrategyResult and AttackResult (relocated into pyrit.models.results) are now Pydantic v2 BaseModels with extra='forbid', naive-timestamp coercion, and a deprecated attack_identifier kwarg/property that promotes to atomic_attack_identifier. to_dict/from_dict are retained as deprecated shims preserving the legacy wire shape. The before- validators copy the input dict so model_validate never mutates caller payloads. De-dataclass all StrategyResult/AttackResult subclasses so their Pydantic __init__ works correctly: Crescendo/TAP/Sequential attack results, Workflow/XPIA results, and PromptGenerator/GCG/Fuzzer/Anecdoctor results. Context classes remain dataclasses. field(default_factory=...) becomes pydantic Field(...). Fix tests that relied on the old dataclass accepting type-invalid values, and add pre-existing None-guards in crescendo/tree_of_attacks surfaced by ty once these files entered the changed-file set. Add conversion tests covering silent shims, extra='forbid', timestamp coercion, to_dict/from_dict deprecation, deep- copy independence, and combined validator / no-mutation behavior. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/auxiliary_attacks/gcg/generator.py | 8 +- .../attack/compound/sequential_attack.py | 5 +- pyrit/executor/attack/multi_turn/crescendo.py | 5 +- .../attack/multi_turn/tree_of_attacks.py | 6 +- pyrit/executor/promptgen/anecdoctor.py | 1 - .../core/prompt_generator_strategy.py | 3 +- pyrit/executor/promptgen/fuzzer/fuzzer.py | 6 +- .../workflow/core/workflow_strategy.py | 3 +- pyrit/executor/workflow/xpia.py | 1 - pyrit/models/results/attack_result.py | 139 ++++++++++------- pyrit/models/results/strategy_result.py | 22 ++- .../component/test_simulated_conversation.py | 24 +-- .../attack/core/test_attack_strategy.py | 6 +- .../attack/multi_turn/test_red_teaming.py | 1 + .../attack/multi_turn/test_tree_of_attacks.py | 4 +- .../attack/single_turn/test_skeleton_key.py | 6 +- tests/unit/models/test_attack_result.py | 142 ++++++++++++++++++ tests/unit/models/test_strategy_result.py | 21 ++- .../unit/scenario/core/test_atomic_attack.py | 4 +- 19 files changed, 306 insertions(+), 101 deletions(-) diff --git a/pyrit/auxiliary_attacks/gcg/generator.py b/pyrit/auxiliary_attacks/gcg/generator.py index ace55d0c32..98e9d5544b 100644 --- a/pyrit/auxiliary_attacks/gcg/generator.py +++ b/pyrit/auxiliary_attacks/gcg/generator.py @@ -42,6 +42,7 @@ import numpy as np import torch.multiprocessing as mp +from pydantic import Field import pyrit.auxiliary_attacks.gcg.attack.gcg.gcg_attack as attack_lib from pyrit.auxiliary_attacks.gcg.attack.base.attack_manager import ( @@ -96,7 +97,6 @@ class GCGContext(PromptGeneratorStrategyContext): logfile_path: Optional[str] = None -@dataclass class GCGResult(PromptGeneratorStrategyResult): """Result of one GCGGenerator run. @@ -117,10 +117,10 @@ class GCGResult(PromptGeneratorStrategyResult): final_suffix: str = "" final_loss: float = float("nan") step_count: int = 0 - loss_history: list[float] = field(default_factory=list) - control_history: list[str] = field(default_factory=list) + loss_history: list[float] = Field(default_factory=list) + control_history: list[str] = Field(default_factory=list) log_path: Optional[str] = None - memory_labels: dict[str, str] = field(default_factory=dict) + memory_labels: dict[str, str] = Field(default_factory=dict) class GCGGenerator( diff --git a/pyrit/executor/attack/compound/sequential_attack.py b/pyrit/executor/attack/compound/sequential_attack.py index 2d07bb23c9..7e851507d4 100644 --- a/pyrit/executor/attack/compound/sequential_attack.py +++ b/pyrit/executor/attack/compound/sequential_attack.py @@ -28,6 +28,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Optional +from pydantic import Field + from pyrit.executor.attack.core.attack_executor import AttackExecutor from pyrit.executor.attack.core.attack_parameters import AttackParameters from pyrit.executor.attack.core.attack_strategy import AttackContext, AttackStrategy @@ -110,7 +112,6 @@ class SequentialChildAttack: memory_labels: Mapping[str, str] = field(default_factory=dict) -@dataclass class SequentialAttackResult(AttackResult): """ Result of a ``SequentialAttack`` execution. @@ -138,7 +139,7 @@ class SequentialAttackResult(AttackResult): round-trip. """ - child_attack_results: list[AttackResult] = field(default_factory=list) + child_attack_results: list[AttackResult] = Field(default_factory=list) completion_policy: SequenceCompletionPolicy = SequenceCompletionPolicy.FIRST_SUCCESS @property diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index dfd57e2515..3c1796d21c 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -81,7 +81,6 @@ class CrescendoAttackContext(MultiTurnAttackContext[Any]): backtrack_count: int = 0 -@dataclass class CrescendoAttackResult(AttackResult): """Result of the Crescendo attack strategy execution.""" @@ -832,7 +831,9 @@ async def _perform_backtrack_if_refused_async( # Check for refusal using the scorer (handles blocked/error responses internally) refusal_score = await self._check_refusal_async(context, prompt_sent) - self._logger.debug(f"Refusal check: {refusal_score.get_value()} - {refusal_score.score_rationale[:100]}...") + self._logger.debug( + f"Refusal check: {refusal_score.get_value()} - {(refusal_score.score_rationale or '')[:100]}..." + ) is_refusal = bool(refusal_score.get_value()) if not is_refusal: diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 04b5aef915..94f3bc9ac1 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -176,7 +176,6 @@ class TAPAttackContext(MultiTurnAttackContext[Any]): best_adversarial_conversation_id: Optional[str] = None -@dataclass class TAPAttackResult(AttackResult): """ Result of the Tree of Attacks with Pruning (TAP) attack strategy execution. @@ -699,7 +698,8 @@ async def _score_response_async(self, *, response: Message, objective: str) -> N # Extract auxiliary scores auxiliary_scores = scoring_results["auxiliary_scores"] for score in auxiliary_scores: - scorer_name = score.scorer_class_identifier.class_name + scorer_identifier = score.scorer_class_identifier + scorer_name = scorer_identifier.class_name if scorer_identifier else "unknown" self.auxiliary_scores[scorer_name] = score logger.debug(f"Node {self.node_id}: {scorer_name} score: {score.get_value()}") @@ -904,7 +904,7 @@ async def _generate_red_teaming_prompt_async(self, objective: str) -> str: # Generate feedback prompt and get a new response feedback_prompt = self._generate_off_topic_feedback_prompt( original_prompt=prompt, - off_topic_rationale=on_topic_score.score_rationale, + off_topic_rationale=on_topic_score.score_rationale or "", objective=objective, ) diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index f4af880288..3e32fa4faa 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -57,7 +57,6 @@ class AnecdoctorContext(PromptGeneratorStrategyContext): memory_labels: dict[str, str] = field(default_factory=dict) -@dataclass class AnecdoctorResult(PromptGeneratorStrategyResult): """ Result of Anecdoctor prompt generation. diff --git a/pyrit/executor/promptgen/core/prompt_generator_strategy.py b/pyrit/executor/promptgen/core/prompt_generator_strategy.py index 6caafb437d..ff31513275 100644 --- a/pyrit/executor/promptgen/core/prompt_generator_strategy.py +++ b/pyrit/executor/promptgen/core/prompt_generator_strategy.py @@ -26,8 +26,7 @@ class PromptGeneratorStrategyContext(StrategyContext, ABC): """Base class for all prompt generator strategy contexts.""" -@dataclass -class PromptGeneratorStrategyResult(StrategyResult, ABC): +class PromptGeneratorStrategyResult(StrategyResult, ABC): # noqa: B024 """Base class for all prompt generator strategy results.""" diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index 533a581675..df04314069 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -12,6 +12,7 @@ import numpy as np from colorama import Fore, Style +from pydantic import Field from pyrit.common.utils import combine_dict, get_kwarg_param from pyrit.exceptions import MissingPromptPlaceholderException, pyrit_placeholder_retry @@ -212,7 +213,6 @@ def __post_init__(self) -> None: ) -@dataclass class FuzzerResult(PromptGeneratorStrategyResult): """ Result of the Fuzzer prompt generation strategy execution. @@ -222,8 +222,8 @@ class FuzzerResult(PromptGeneratorStrategyResult): """ # Concrete fields instead of metadata storage - successful_templates: list[str] = field(default_factory=list) - jailbreak_conversation_ids: list[Union[str, uuid.UUID]] = field(default_factory=list) + successful_templates: list[str] = Field(default_factory=list) + jailbreak_conversation_ids: list[Union[str, uuid.UUID]] = Field(default_factory=list) total_queries: int = 0 templates_explored: int = 0 diff --git a/pyrit/executor/workflow/core/workflow_strategy.py b/pyrit/executor/workflow/core/workflow_strategy.py index cee2abfbd6..455178ae9a 100644 --- a/pyrit/executor/workflow/core/workflow_strategy.py +++ b/pyrit/executor/workflow/core/workflow_strategy.py @@ -27,8 +27,7 @@ class WorkflowContext(StrategyContext, ABC): """Base class for all workflow contexts.""" -@dataclass -class WorkflowResult(StrategyResult, ABC): +class WorkflowResult(StrategyResult, ABC): # noqa: B024 """Base class for all workflow results.""" diff --git a/pyrit/executor/workflow/xpia.py b/pyrit/executor/workflow/xpia.py index e95366ceb4..e981c46b63 100644 --- a/pyrit/executor/workflow/xpia.py +++ b/pyrit/executor/workflow/xpia.py @@ -81,7 +81,6 @@ class XPIAContext(WorkflowContext): memory_labels: dict[str, str] = field(default_factory=dict) -@dataclass class XPIAResult(WorkflowResult): """ Result of XPIA workflow execution. diff --git a/pyrit/models/results/attack_result.py b/pyrit/models/results/attack_result.py index de72afb7e7..1df42ce6e3 100644 --- a/pyrit/models/results/attack_result.py +++ b/pyrit/models/results/attack_result.py @@ -3,13 +3,13 @@ from __future__ import annotations -import functools import uuid -from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum from typing import Any, Optional, TypeVar +from pydantic import AwareDatetime, Field, model_validator + from pyrit.common.deprecation import print_deprecation_message from pyrit.models.conversation_reference import ConversationReference, ConversationType from pyrit.models.identifiers.atomic_attack_identifier import build_atomic_attack_identifier @@ -43,7 +43,6 @@ class AttackOutcome(str, Enum): UNDETERMINED = "undetermined" -@dataclass class AttackResult(StrategyResult): """Base class for all attack results.""" @@ -56,7 +55,7 @@ class AttackResult(StrategyResult): # Database-assigned unique ID for this AttackResult row. # Auto-generated if not provided (e.g. when loading from DB, the persisted ID is passed in). - attack_result_id: str = field(default_factory=lambda: str(uuid.uuid4())) + attack_result_id: str = Field(default_factory=lambda: str(uuid.uuid4())) # Composite identifier combining the attack strategy identity with # seed identifiers from the dataset. @@ -85,24 +84,24 @@ class AttackResult(StrategyResult): outcome_reason: Optional[str] = None # Wall-clock time the result was created or persisted. - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + timestamp: AwareDatetime = Field(default_factory=lambda: datetime.now(timezone.utc)) # Flexible conversation refs (nothing unused) - related_conversations: set[ConversationReference] = field(default_factory=set) + related_conversations: set[ConversationReference] = Field(default_factory=set) # Arbitrary metadata - metadata: dict[str, Any] = field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) # labels associated with this attack result - labels: dict[str, str] = field(default_factory=dict) + labels: dict[str, str] = Field(default_factory=dict) # Error information (populated when attack fails with exception) - error_message: str | None = None - error_type: str | None = None - error_traceback: str | None = None + error_message: Optional[str] = None + error_type: Optional[str] = None + error_traceback: Optional[str] = None # Retry tracking - retry_events: list[RetryEvent] = field(default_factory=list) + retry_events: list[RetryEvent] = Field(default_factory=list) total_retries: int = 0 # Attribution / parent linkage (infrastructure-managed). Set by the attack @@ -110,8 +109,62 @@ class AttackResult(StrategyResult): # AttackContext. User code should not set these directly; ad-hoc # AttackResults created outside an orchestrator leave both fields as None # and the corresponding DB columns remain NULL. - attribution_parent_id: str | None = None - attribution_data: dict[str, Any] | None = None + attribution_parent_id: Optional[str] = None + attribution_data: Optional[dict[str, Any]] = None + + @model_validator(mode="before") + @classmethod + def _coerce_naive_timestamp(cls, data: Any) -> Any: + """ + Coerce a naive ``timestamp`` (datetime or ISO string) to UTC. + + ``AwareDatetime`` rejects naive datetimes that the legacy dataclass + accepted (e.g. SQLite-loaded timestamps). Mirror ``_ensure_utc`` so + existing naive inputs keep validating. + + Returns: + The input ``data`` with a tz-aware ``timestamp`` when one was supplied. + """ + if not isinstance(data, dict): + return data + data = dict(data) + ts = data.get("timestamp") + if isinstance(ts, str): + ts = datetime.fromisoformat(ts) + if isinstance(ts, datetime) and ts.tzinfo is None: + ts = ts.replace(tzinfo=timezone.utc) + if ts is not None: + data["timestamp"] = ts + return data + + @model_validator(mode="before") + @classmethod + def _promote_deprecated_attack_identifier(cls, data: Any) -> Any: + """ + Promote the deprecated ``attack_identifier`` kwarg to ``atomic_attack_identifier``. + + Runs ahead of ``extra="forbid"`` so the legacy kwarg is consumed before + Pydantic would reject it. Emits a deprecation warning when present. + + Returns: + The input ``data`` with ``attack_identifier`` removed and (when it was + set and ``atomic_attack_identifier`` was not) promoted. + """ + if not isinstance(data, dict): + return data + data = dict(data) + attack_identifier = data.pop("attack_identifier", None) + if attack_identifier is not None: + print_deprecation_message( + old_item="AttackResult(attack_identifier=...)", + new_item="AttackResult(atomic_attack_identifier=...)", + removed_in="0.15.0", + ) + if data.get("atomic_attack_identifier") is None: + data["atomic_attack_identifier"] = build_atomic_attack_identifier( + attack_identifier=attack_identifier, + ) + return data @property def attack_identifier(self) -> Optional[ComponentIdentifier]: @@ -232,9 +285,19 @@ def to_dict(self) -> dict[str, Any]: """ Serialize this attack result to a JSON-compatible dictionary. + Deprecated: use ``model_dump(mode="json")`` for the canonical Pydantic + serialization. This shim preserves the legacy wire shape (base fields + only, raw ``metadata``, sorted ``related_conversations``) through the + deprecation window. + Returns: dict[str, Any]: Serialized payload suitable for REST APIs or persistence. """ + print_deprecation_message( + old_item="AttackResult.to_dict()", + new_item="AttackResult.model_dump(mode='json')", + removed_in="0.16.0", + ) return { "conversation_id": self.conversation_id, "objective": self.objective, @@ -267,12 +330,21 @@ def from_dict(cls, data: dict[str, Any]) -> AttackResult: """ Reconstruct an AttackResult from a dictionary. + Deprecated: use ``model_validate(...)`` for the canonical Pydantic + deserialization. This shim accepts the legacy ``to_dict()`` wire shape + (base fields only) through the deprecation window. + Args: data (dict[str, Any]): Dictionary as produced by to_dict(). Returns: AttackResult: Reconstructed instance. """ + print_deprecation_message( + old_item="AttackResult.from_dict(...)", + new_item="AttackResult.model_validate(...)", + removed_in="0.16.0", + ) return cls( conversation_id=data["conversation_id"], objective=data["objective"], @@ -302,42 +374,3 @@ def from_dict(cls, data: dict[str, Any]) -> AttackResult: retry_events=[RetryEvent.model_validate(e) for e in data.get("retry_events", [])], total_retries=data.get("total_retries", 0), ) - - -def _add_attack_identifier_compat(cls: type) -> type: - """ - Wrap a dataclass ``__init__`` to accept the deprecated ``attack_identifier`` kwarg. - - When ``attack_identifier`` is passed, it is automatically promoted to - ``atomic_attack_identifier`` via ``build_atomic_attack_identifier`` and a - deprecation warning is emitted. - - Args: - cls: The dataclass to wrap. - - Returns: - The same class with a wrapped ``__init__``. - - """ - original_init = cls.__init__ - - @functools.wraps(original_init) - def wrapped_init(self: Any, *args: Any, **kwargs: Any) -> None: - attack_identifier = kwargs.pop("attack_identifier", None) - if attack_identifier is not None: - print_deprecation_message( - old_item="AttackResult(attack_identifier=...)", - new_item="AttackResult(atomic_attack_identifier=...)", - removed_in="0.15.0", - ) - if kwargs.get("atomic_attack_identifier") is None: - kwargs["atomic_attack_identifier"] = build_atomic_attack_identifier( - attack_identifier=attack_identifier, - ) - original_init(self, *args, **kwargs) - - cls.__init__ = wrapped_init # type: ignore[ty:invalid-assignment] - return cls - - -_add_attack_identifier_compat(AttackResult) diff --git a/pyrit/models/results/strategy_result.py b/pyrit/models/results/strategy_result.py index 38fac2af04..2c30c4509b 100644 --- a/pyrit/models/results/strategy_result.py +++ b/pyrit/models/results/strategy_result.py @@ -4,18 +4,26 @@ from __future__ import annotations from abc import ABC -from copy import deepcopy -from dataclasses import dataclass -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar + +from pydantic import BaseModel, ConfigDict + +if TYPE_CHECKING: + from typing import Self StrategyResultT = TypeVar("StrategyResultT", bound="StrategyResult") -@dataclass -class StrategyResult(ABC): # noqa: B024 +class StrategyResult(BaseModel, ABC): # noqa: B024 """Base class for all strategy results.""" - def duplicate(self: StrategyResultT) -> StrategyResultT: + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + validate_assignment=False, + ) + + def duplicate(self) -> Self: """ Create a deep copy of the result. @@ -23,4 +31,4 @@ def duplicate(self: StrategyResultT) -> StrategyResultT: StrategyResult: A deep copy of the result. """ - return deepcopy(self) + return self.model_copy(deep=True) diff --git a/tests/unit/executor/attack/component/test_simulated_conversation.py b/tests/unit/executor/attack/component/test_simulated_conversation.py index 3baacfc375..7e2f0eddee 100644 --- a/tests/unit/executor/attack/component/test_simulated_conversation.py +++ b/tests/unit/executor/attack/component/test_simulated_conversation.py @@ -162,7 +162,7 @@ async def test_uses_adversarial_chat_as_simulated_target( ) mock_attack.execute_async = AsyncMock( return_value=AttackResult( - attack_identifier=ComponentIdentifier( + atomic_attack_identifier=ComponentIdentifier( class_name="RedTeamingAttack", class_module="pyrit.executor.attack" ), conversation_id=str(uuid.uuid4()), @@ -206,7 +206,7 @@ async def test_creates_attack_with_score_last_turn_only_true( ) mock_attack.execute_async = AsyncMock( return_value=AttackResult( - attack_identifier=ComponentIdentifier( + atomic_attack_identifier=ComponentIdentifier( class_name="RedTeamingAttack", class_module="pyrit.executor.attack" ), conversation_id=str(uuid.uuid4()), @@ -249,7 +249,7 @@ async def test_creates_attack_with_correct_max_turns( ) mock_attack.execute_async = AsyncMock( return_value=AttackResult( - attack_identifier=ComponentIdentifier( + atomic_attack_identifier=ComponentIdentifier( class_name="RedTeamingAttack", class_module="pyrit.executor.attack" ), conversation_id=str(uuid.uuid4()), @@ -295,7 +295,7 @@ async def test_returns_simulated_conversation_result( ) mock_attack.execute_async = AsyncMock( return_value=AttackResult( - attack_identifier=ComponentIdentifier( + atomic_attack_identifier=ComponentIdentifier( class_name="RedTeamingAttack", class_module="pyrit.executor.attack" ), conversation_id=conversation_id, @@ -344,7 +344,7 @@ async def test_passes_system_prompt_via_prepended_conversation( ) mock_attack.execute_async = AsyncMock( return_value=AttackResult( - attack_identifier=ComponentIdentifier( + atomic_attack_identifier=ComponentIdentifier( class_name="RedTeamingAttack", class_module="pyrit.executor.attack" ), conversation_id=str(uuid.uuid4()), @@ -395,7 +395,7 @@ async def test_passes_memory_labels_to_execute( ) mock_attack.execute_async = AsyncMock( return_value=AttackResult( - attack_identifier=ComponentIdentifier( + atomic_attack_identifier=ComponentIdentifier( class_name="RedTeamingAttack", class_module="pyrit.executor.attack" ), conversation_id=str(uuid.uuid4()), @@ -441,7 +441,7 @@ async def test_passes_converter_config_to_attack( ) mock_attack.execute_async = AsyncMock( return_value=AttackResult( - attack_identifier=ComponentIdentifier( + atomic_attack_identifier=ComponentIdentifier( class_name="RedTeamingAttack", class_module="pyrit.executor.attack" ), conversation_id=str(uuid.uuid4()), @@ -485,7 +485,7 @@ async def test_prepends_system_message_to_conversation( ) mock_attack.execute_async = AsyncMock( return_value=AttackResult( - attack_identifier=ComponentIdentifier( + atomic_attack_identifier=ComponentIdentifier( class_name="RedTeamingAttack", class_module="pyrit.executor.attack" ), conversation_id=str(uuid.uuid4()), @@ -533,7 +533,7 @@ async def test_uses_default_num_turns_of_3( ) mock_attack.execute_async = AsyncMock( return_value=AttackResult( - attack_identifier=ComponentIdentifier( + atomic_attack_identifier=ComponentIdentifier( class_name="RedTeamingAttack", class_module="pyrit.executor.attack" ), conversation_id=str(uuid.uuid4()), @@ -591,7 +591,7 @@ async def test_next_message_system_prompt_path_generates_final_user_message( ) mock_attack.execute_async = AsyncMock( return_value=AttackResult( - attack_identifier=ComponentIdentifier( + atomic_attack_identifier=ComponentIdentifier( class_name="RedTeamingAttack", class_module="pyrit.executor.attack" ), conversation_id=conversation_id, @@ -660,7 +660,7 @@ async def test_next_message_system_prompt_path_sets_system_prompt( ) mock_attack.execute_async = AsyncMock( return_value=AttackResult( - attack_identifier=ComponentIdentifier( + atomic_attack_identifier=ComponentIdentifier( class_name="RedTeamingAttack", class_module="pyrit.executor.attack" ), conversation_id=conversation_id, @@ -707,7 +707,7 @@ async def test_starting_sequence_sets_first_sequence_number( ) mock_attack.execute_async = AsyncMock( return_value=AttackResult( - attack_identifier=ComponentIdentifier( + atomic_attack_identifier=ComponentIdentifier( class_name="RedTeamingAttack", class_module="pyrit.executor.attack" ), conversation_id=conversation_id, diff --git a/tests/unit/executor/attack/core/test_attack_strategy.py b/tests/unit/executor/attack/core/test_attack_strategy.py index 31cd5fdb87..0cc9f42d2b 100644 --- a/tests/unit/executor/attack/core/test_attack_strategy.py +++ b/tests/unit/executor/attack/core/test_attack_strategy.py @@ -426,7 +426,11 @@ async def test_on_post_execute_adds_results_to_memory(self, mock_memory): sample_context = MagicMock() sample_context.start_time = 100.0 - sample_result = MagicMock(spec=AttackResult) + sample_result = AttackResult( + conversation_id="conv-id", + objective="test objective", + outcome=AttackOutcome.SUCCESS, + ) event_data = StrategyEventData( event=StrategyEvent.ON_POST_EXECUTE, diff --git a/tests/unit/executor/attack/multi_turn/test_red_teaming.py b/tests/unit/executor/attack/multi_turn/test_red_teaming.py index 547709862b..1fdf2cbe97 100644 --- a/tests/unit/executor/attack/multi_turn/test_red_teaming.py +++ b/tests/unit/executor/attack/multi_turn/test_red_teaming.py @@ -1662,6 +1662,7 @@ async def test_attack_result_includes_adversarial_chat_conversation_ids( mock_send.return_value = sample_response mock_score.return_value = {"objective_scores": [success_score]} mock_generate.return_value = generated_message + mock_objective_scorer.score_async = AsyncMock(return_value=[success_score]) # Run setup and attack await attack._setup_async(context=basic_context) diff --git a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py index 21974e72ab..5c805455bf 100644 --- a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py @@ -95,7 +95,7 @@ def create_node(config: Optional[NodeMockConfig] = None) -> "_TreeOfAttacksNode" # Set up objective score if config.objective_score_value is not None: node.objective_score = MagicMock( - get_value=MagicMock(return_value=config.objective_score_value), score_metadata=None + spec=Score, get_value=MagicMock(return_value=config.objective_score_value), score_metadata=None ) else: node.objective_score = None @@ -2207,11 +2207,13 @@ async def _send_prompt(objective: str) -> None: node.completed = True if b.error is not None: node.objective_score = MagicMock( + spec=Score, get_value=MagicMock(return_value=0.0), score_metadata=None, ) elif b.score is not None: node.objective_score = MagicMock( + spec=Score, get_value=MagicMock(return_value=b.score), score_metadata=None, ) diff --git a/tests/unit/executor/attack/single_turn/test_skeleton_key.py b/tests/unit/executor/attack/single_turn/test_skeleton_key.py index e0d3205ff7..8676199d51 100644 --- a/tests/unit/executor/attack/single_turn/test_skeleton_key.py +++ b/tests/unit/executor/attack/single_turn/test_skeleton_key.py @@ -343,7 +343,7 @@ async def test_perform_attack_skeleton_key_success_objective_success( mock_parent.return_value = AttackResult( conversation_id=basic_context.conversation_id, objective=basic_context.objective, - last_response=sample_response, + last_response=sample_response.get_piece(), last_score=success_score, outcome=AttackOutcome.SUCCESS, executed_turns=1, @@ -360,7 +360,7 @@ async def test_perform_attack_skeleton_key_success_objective_success( # Verify result properties assert result.outcome == AttackOutcome.SUCCESS assert result.executed_turns == 2 # Should be updated to 2 turns - assert result.last_response == sample_response + assert result.last_response == sample_response.get_piece() assert result.last_score == success_score async def test_perform_attack_skeleton_key_failure(self, mock_target, basic_context): @@ -408,7 +408,7 @@ async def test_perform_attack_skeleton_key_success_objective_failure( mock_parent.return_value = AttackResult( conversation_id=basic_context.conversation_id, objective=basic_context.objective, - last_response=sample_response, + last_response=sample_response.get_piece(), last_score=failure_score, outcome=AttackOutcome.FAILURE, executed_turns=1, diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index 357c3059ae..e1ed815454 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -4,6 +4,8 @@ import warnings from datetime import datetime, timezone +import pytest + from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import ComponentIdentifier, build_atomic_attack_identifier from pyrit.models.conversation_reference import ConversationReference, ConversationType @@ -434,3 +436,143 @@ def test_to_dict_from_dict_roundtrip(): ) roundtripped = AttackResult.from_dict(original.to_dict()) assert original.to_dict() == roundtripped.to_dict() + + +class TestAttackResultValidation: + """Tests for the Pydantic validation behaviour introduced by the BaseModel conversion.""" + + def test_extra_fields_are_forbidden(self) -> None: + """Unknown kwargs must raise (extra='forbid' on the StrategyResult config).""" + with pytest.raises(ValueError): + AttackResult(conversation_id="c1", objective="test", not_a_field="boom") + + def test_naive_datetime_timestamp_is_coerced_to_utc(self) -> None: + """A naive datetime passed at construction is coerced to tz-aware UTC.""" + naive = datetime(2026, 1, 1, 12, 0, 0) # noqa: DTZ001 + result = AttackResult(conversation_id="c1", objective="test", timestamp=naive) + assert result.timestamp.tzinfo is timezone.utc + assert result.timestamp.replace(tzinfo=None) == naive + + def test_naive_iso_string_timestamp_is_coerced_to_utc(self) -> None: + """A naive ISO-8601 string is parsed and coerced to tz-aware UTC.""" + result = AttackResult(conversation_id="c1", objective="test", timestamp="2026-01-01T12:00:00") + assert result.timestamp.tzinfo is timezone.utc + assert result.timestamp.replace(tzinfo=None) == datetime(2026, 1, 1, 12, 0, 0) # noqa: DTZ001 + + def test_aware_iso_string_timestamp_is_preserved(self) -> None: + """An ISO string carrying an offset is parsed without altering the instant.""" + result = AttackResult(conversation_id="c1", objective="test", timestamp="2026-01-01T12:00:00+00:00") + assert result.timestamp == datetime(2026, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + def test_deprecated_kwarg_and_naive_timestamp_together(self) -> None: + """Both before-validators apply: kwarg promoted, timestamp coerced, no extra-field error.""" + attack_id = ComponentIdentifier(class_name="TestAttack", class_module="tests.unit") + naive = datetime(2026, 1, 1, 12, 0, 0) # noqa: DTZ001 + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + result = AttackResult( + conversation_id="c1", + objective="test", + attack_identifier=attack_id, + timestamp=naive, + ) + + deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(deprecation_warnings) >= 1 + assert result.atomic_attack_identifier is not None + assert result.timestamp.tzinfo is timezone.utc + + def test_model_validate_does_not_mutate_input_dict(self) -> None: + """Before-validators must copy, not mutate, the caller-provided payload dict.""" + attack_id = ComponentIdentifier(class_name="TestAttack", class_module="tests.unit") + payload = { + "conversation_id": "c1", + "objective": "test", + "attack_identifier": attack_id, + "timestamp": "2026-01-01T12:00:00", + } + original = dict(payload) + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + AttackResult.model_validate(payload) + + assert payload == original, "model_validate must not mutate the input dict" + + +class TestAttackResultLegacyDictDeprecation: + """to_dict()/from_dict() are retained as deprecated shims and must warn.""" + + def test_to_dict_emits_deprecation_warning(self) -> None: + result = AttackResult(conversation_id="c1", objective="test") + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + result.to_dict() + + deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(deprecation_warnings) >= 1 + assert "to_dict" in str(deprecation_warnings[0].message).lower() + + def test_from_dict_emits_deprecation_warning(self) -> None: + result = AttackResult(conversation_id="c1", objective="test") + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + payload = result.to_dict() + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + AttackResult.from_dict(payload) + + deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(deprecation_warnings) >= 1 + assert "from_dict" in str(deprecation_warnings[0].message).lower() + + +class TestAttackResultDuplicate: + """duplicate() must deep-copy so mutations on the copy never touch the original.""" + + def test_duplicate_metadata_is_independent(self) -> None: + original = AttackResult( + conversation_id="c1", + objective="test", + metadata={"nested": {"key": "value"}}, + ) + copy = original.duplicate() + copy.metadata["nested"]["key"] = "mutated" + copy.metadata["added"] = "new" + + assert original.metadata == {"nested": {"key": "value"}} + assert type(copy) is AttackResult + + def test_duplicate_preserves_subclass_type(self) -> None: + """duplicate() on a subclass returns the same subclass.""" + from pyrit.executor.attack.multi_turn.crescendo import CrescendoAttackResult + + original = CrescendoAttackResult(conversation_id="c1", objective="test") + original.backtrack_count = 3 + copy = original.duplicate() + + assert type(copy) is CrescendoAttackResult + assert copy.backtrack_count == 3 + copy.backtrack_count = 9 + assert original.backtrack_count == 3 + + +class TestAttackResultShim: + """The relocated module must be importable from the legacy path silently.""" + + def test_shim_reexports_same_classes_silently(self) -> None: + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + from pyrit.models.attack_result import AttackOutcome as ShimOutcome + from pyrit.models.attack_result import AttackResult as ShimResult + + assert ShimResult is AttackResult + assert ShimOutcome is AttackOutcome + deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(deprecation_warnings) == 0, "Shim import must be silent" + + def test_shim_getattr_reexports_dynamic_names(self) -> None: + """The module __getattr__ falls through to the relocated module.""" + import pyrit.models.attack_result as shim + + assert shim.AttackResultT is not None diff --git a/tests/unit/models/test_strategy_result.py b/tests/unit/models/test_strategy_result.py index 1eb3b3277d..07508abef7 100644 --- a/tests/unit/models/test_strategy_result.py +++ b/tests/unit/models/test_strategy_result.py @@ -1,12 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from dataclasses import dataclass +import warnings + +import pytest from pyrit.models.results.strategy_result import StrategyResult -@dataclass class ConcreteResult(StrategyResult): value: str = "" count: int = 0 @@ -33,3 +34,19 @@ def test_strategy_result_duplicate_preserves_type(): original = ConcreteResult(value="test", count=1) copy = original.duplicate() assert type(copy) is ConcreteResult + + +def test_strategy_result_forbids_extra_fields(): + with pytest.raises(ValueError): + ConcreteResult(value="hello", count=1, unexpected="boom") + + +def test_strategy_result_shim_reexports_same_class_silently(): + """The old import path must re-export the identical class without warning.""" + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + from pyrit.models.strategy_result import StrategyResult as ShimStrategyResult + + assert ShimStrategyResult is StrategyResult + deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(deprecation_warnings) == 0, "Shim import must be silent" diff --git a/tests/unit/scenario/core/test_atomic_attack.py b/tests/unit/scenario/core/test_atomic_attack.py index 17940a8348..af28362fc0 100644 --- a/tests/unit/scenario/core/test_atomic_attack.py +++ b/tests/unit/scenario/core/test_atomic_attack.py @@ -998,7 +998,7 @@ async def test_enrichment_persists_to_db(self, mock_attack): assert persisted["class_name"] == "AtomicAttack" async def test_enrichment_skips_db_update_when_no_attack_result_id(self, mock_attack): - """Test that enrichment does not attempt a DB update when attack_result_id is None.""" + """Test that enrichment does not attempt a DB update when attack_result_id is empty.""" seed_groups = [ SeedAttackGroup( seeds=[ @@ -1013,7 +1013,7 @@ async def test_enrichment_skips_db_update_when_no_attack_result_id(self, mock_at objective="obj1", outcome=AttackOutcome.SUCCESS, executed_turns=1, - attack_result_id=None, + attack_result_id="", atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack_id), ) From c7132f5a4cdbdfc77c314285550a3ebe46166454 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 2 Jun 2026 16:31:42 -0700 Subject: [PATCH 3/5] Drop bespoke naive-timestamp validator on AttackResult Use plain AwareDatetime (matching Score/MessagePiece) instead of a custom _coerce_naive_timestamp model_validator. Naive timestamps from the DB are already normalized to UTC by AttackResultEntry.get_attack_result via _ensure_utc before the constructor runs, so the model can stay strict. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/results/attack_result.py | 27 +-------------------- tests/unit/models/test_attack_result.py | 31 +++++++++++-------------- 2 files changed, 14 insertions(+), 44 deletions(-) diff --git a/pyrit/models/results/attack_result.py b/pyrit/models/results/attack_result.py index 1df42ce6e3..eeca8b1fff 100644 --- a/pyrit/models/results/attack_result.py +++ b/pyrit/models/results/attack_result.py @@ -84,7 +84,7 @@ class AttackResult(StrategyResult): outcome_reason: Optional[str] = None # Wall-clock time the result was created or persisted. - timestamp: AwareDatetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + timestamp: AwareDatetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) # Flexible conversation refs (nothing unused) related_conversations: set[ConversationReference] = Field(default_factory=set) @@ -112,31 +112,6 @@ class AttackResult(StrategyResult): attribution_parent_id: Optional[str] = None attribution_data: Optional[dict[str, Any]] = None - @model_validator(mode="before") - @classmethod - def _coerce_naive_timestamp(cls, data: Any) -> Any: - """ - Coerce a naive ``timestamp`` (datetime or ISO string) to UTC. - - ``AwareDatetime`` rejects naive datetimes that the legacy dataclass - accepted (e.g. SQLite-loaded timestamps). Mirror ``_ensure_utc`` so - existing naive inputs keep validating. - - Returns: - The input ``data`` with a tz-aware ``timestamp`` when one was supplied. - """ - if not isinstance(data, dict): - return data - data = dict(data) - ts = data.get("timestamp") - if isinstance(ts, str): - ts = datetime.fromisoformat(ts) - if isinstance(ts, datetime) and ts.tzinfo is None: - ts = ts.replace(tzinfo=timezone.utc) - if ts is not None: - data["timestamp"] = ts - return data - @model_validator(mode="before") @classmethod def _promote_deprecated_attack_identifier(cls, data: Any) -> Any: diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index e1ed815454..b16cec2ab6 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -446,50 +446,45 @@ def test_extra_fields_are_forbidden(self) -> None: with pytest.raises(ValueError): AttackResult(conversation_id="c1", objective="test", not_a_field="boom") - def test_naive_datetime_timestamp_is_coerced_to_utc(self) -> None: - """A naive datetime passed at construction is coerced to tz-aware UTC.""" - naive = datetime(2026, 1, 1, 12, 0, 0) # noqa: DTZ001 - result = AttackResult(conversation_id="c1", objective="test", timestamp=naive) - assert result.timestamp.tzinfo is timezone.utc - assert result.timestamp.replace(tzinfo=None) == naive + def test_naive_datetime_timestamp_is_rejected(self) -> None: + """Naive datetimes are rejected (AwareDatetime), matching Score/MessagePiece. - def test_naive_iso_string_timestamp_is_coerced_to_utc(self) -> None: - """A naive ISO-8601 string is parsed and coerced to tz-aware UTC.""" - result = AttackResult(conversation_id="c1", objective="test", timestamp="2026-01-01T12:00:00") - assert result.timestamp.tzinfo is timezone.utc - assert result.timestamp.replace(tzinfo=None) == datetime(2026, 1, 1, 12, 0, 0) # noqa: DTZ001 + SQLite-loaded naive timestamps are normalized to UTC by the memory layer + (``AttackResultEntry.get_attack_result`` via ``_ensure_utc``) before they + ever reach this constructor, so the model itself stays strict. + """ + naive = datetime(2026, 1, 1, 12, 0, 0) # noqa: DTZ001 + with pytest.raises(ValueError): + AttackResult(conversation_id="c1", objective="test", timestamp=naive) def test_aware_iso_string_timestamp_is_preserved(self) -> None: """An ISO string carrying an offset is parsed without altering the instant.""" result = AttackResult(conversation_id="c1", objective="test", timestamp="2026-01-01T12:00:00+00:00") assert result.timestamp == datetime(2026, 1, 1, 12, 0, 0, tzinfo=timezone.utc) - def test_deprecated_kwarg_and_naive_timestamp_together(self) -> None: - """Both before-validators apply: kwarg promoted, timestamp coerced, no extra-field error.""" + def test_deprecated_kwarg_promotes_without_extra_field_error(self) -> None: + """The promote before-validator pops attack_identifier before extra='forbid' runs.""" attack_id = ComponentIdentifier(class_name="TestAttack", class_module="tests.unit") - naive = datetime(2026, 1, 1, 12, 0, 0) # noqa: DTZ001 with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") result = AttackResult( conversation_id="c1", objective="test", attack_identifier=attack_id, - timestamp=naive, ) deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] assert len(deprecation_warnings) >= 1 assert result.atomic_attack_identifier is not None - assert result.timestamp.tzinfo is timezone.utc def test_model_validate_does_not_mutate_input_dict(self) -> None: - """Before-validators must copy, not mutate, the caller-provided payload dict.""" + """The promote before-validator must copy, not mutate, the caller-provided payload dict.""" attack_id = ComponentIdentifier(class_name="TestAttack", class_module="tests.unit") payload = { "conversation_id": "c1", "objective": "test", "attack_identifier": attack_id, - "timestamp": "2026-01-01T12:00:00", + "timestamp": "2026-01-01T12:00:00+00:00", } original = dict(payload) with warnings.catch_warnings(record=True): From 9c4a2686c234f7ecdcbc7b955913bd75f315a7d8 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 3 Jun 2026 09:00:46 -0700 Subject: [PATCH 4/5] Use X | None instead of Optional[...] per style guide Address review feedback on AttackResult and GCGResult: replace the Optional[...] type annotations with the modern union syntax (X | None) mandated by the PyRIT style guide. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/auxiliary_attacks/gcg/generator.py | 2 +- pyrit/models/results/attack_result.py | 28 ++++++++++++------------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pyrit/auxiliary_attacks/gcg/generator.py b/pyrit/auxiliary_attacks/gcg/generator.py index 98e9d5544b..cf353a2adf 100644 --- a/pyrit/auxiliary_attacks/gcg/generator.py +++ b/pyrit/auxiliary_attacks/gcg/generator.py @@ -119,7 +119,7 @@ class GCGResult(PromptGeneratorStrategyResult): step_count: int = 0 loss_history: list[float] = Field(default_factory=list) control_history: list[str] = Field(default_factory=list) - log_path: Optional[str] = None + log_path: str | None = None memory_labels: dict[str, str] = Field(default_factory=dict) diff --git a/pyrit/models/results/attack_result.py b/pyrit/models/results/attack_result.py index eeca8b1fff..8d2043f16d 100644 --- a/pyrit/models/results/attack_result.py +++ b/pyrit/models/results/attack_result.py @@ -6,7 +6,7 @@ import uuid from datetime import datetime, timezone from enum import Enum -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar from pydantic import AwareDatetime, Field, model_validator @@ -60,14 +60,14 @@ class AttackResult(StrategyResult): # Composite identifier combining the attack strategy identity with # seed identifiers from the dataset. # Contains the attack strategy as children["attack"] plus optional seeds. - atomic_attack_identifier: Optional[ComponentIdentifier] = None + atomic_attack_identifier: ComponentIdentifier | None = None # Evidence # Model response generated in the final turn of the attack - last_response: Optional[MessagePiece] = None + last_response: MessagePiece | None = None # Score assigned to the final response by a scorer component - last_score: Optional[Score] = None + last_score: Score | None = None # Metrics # Total number of turns that were executed @@ -81,7 +81,7 @@ class AttackResult(StrategyResult): outcome: AttackOutcome = AttackOutcome.UNDETERMINED # Optional reason for the outcome, providing additional context - outcome_reason: Optional[str] = None + outcome_reason: str | None = None # Wall-clock time the result was created or persisted. timestamp: AwareDatetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) @@ -96,9 +96,9 @@ class AttackResult(StrategyResult): labels: dict[str, str] = Field(default_factory=dict) # Error information (populated when attack fails with exception) - error_message: Optional[str] = None - error_type: Optional[str] = None - error_traceback: Optional[str] = None + error_message: str | None = None + error_type: str | None = None + error_traceback: str | None = None # Retry tracking retry_events: list[RetryEvent] = Field(default_factory=list) @@ -109,8 +109,8 @@ class AttackResult(StrategyResult): # AttackContext. User code should not set these directly; ad-hoc # AttackResults created outside an orchestrator leave both fields as None # and the corresponding DB columns remain NULL. - attribution_parent_id: Optional[str] = None - attribution_data: Optional[dict[str, Any]] = None + attribution_parent_id: str | None = None + attribution_data: dict[str, Any] | None = None @model_validator(mode="before") @classmethod @@ -142,7 +142,7 @@ def _promote_deprecated_attack_identifier(cls, data: Any) -> Any: return data @property - def attack_identifier(self) -> Optional[ComponentIdentifier]: + def attack_identifier(self) -> ComponentIdentifier | None: """ Deprecated: use ``get_attack_strategy_identifier()`` or ``atomic_attack_identifier`` instead. @@ -150,7 +150,7 @@ def attack_identifier(self) -> Optional[ComponentIdentifier]: ``atomic_attack_identifier``, emitting a deprecation warning. Returns: - Optional[ComponentIdentifier]: The attack strategy identifier, or ``None``. + ComponentIdentifier | None: The attack strategy identifier, or ``None``. """ print_deprecation_message( @@ -160,7 +160,7 @@ def attack_identifier(self) -> Optional[ComponentIdentifier]: ) return self.get_attack_strategy_identifier() - def get_attack_strategy_identifier(self) -> Optional[ComponentIdentifier]: + def get_attack_strategy_identifier(self) -> ComponentIdentifier | None: """ Return the attack strategy identifier from the composite atomic identifier. @@ -172,7 +172,7 @@ def get_attack_strategy_identifier(self) -> Optional[ComponentIdentifier]: structure was introduced. Returns: - Optional[ComponentIdentifier]: The attack strategy identifier, or ``None`` if + ComponentIdentifier | None: The attack strategy identifier, or ``None`` if ``atomic_attack_identifier`` is not set or the expected children are missing. """ From 61204590b646534d8836af64a105f8003b1e07f4 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 3 Jun 2026 09:07:19 -0700 Subject: [PATCH 5/5] Anchor results/ gitignore rule to repo root instead of negating The bare results/ pattern matched a directory named results at any depth, which swallowed the new pyrit/models/results/ source package and required a !pyrit/models/results/ negation. Anchor it as /results/ so it only ignores the top-level runtime output dir (memory results_path), and drop the negation. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .gitignore | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 9bb4847b0f..e050563978 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,6 @@ # PyRIT-specific configs submodules/ -results/ -!pyrit/models/results/ +/results/ dbdata/ eval/ default_memory.json.memory