diff --git a/doc/getting_started/troubleshooting/deploy_hf_model_aml.ipynb b/doc/getting_started/troubleshooting/deploy_hf_model_aml.ipynb index 2b5600ce5a..971a066c83 100644 --- a/doc/getting_started/troubleshooting/deploy_hf_model_aml.ipynb +++ b/doc/getting_started/troubleshooting/deploy_hf_model_aml.ipynb @@ -128,14 +128,12 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Union\n", - "\n", "from azure.ai.ml import MLClient\n", "from azure.core.exceptions import ResourceNotFoundError\n", "from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential\n", "\n", "try:\n", - " credential: Union[DefaultAzureCredential, InteractiveBrowserCredential] = DefaultAzureCredential()\n", + " credential: DefaultAzureCredential | InteractiveBrowserCredential = DefaultAzureCredential()\n", " credential.get_token(\"https://management.azure.com/.default\")\n", "except Exception as ex:\n", " credential = InteractiveBrowserCredential()\n", diff --git a/doc/getting_started/troubleshooting/deploy_hf_model_aml.py b/doc/getting_started/troubleshooting/deploy_hf_model_aml.py index b55a818a6c..b6fa4a1e67 100644 --- a/doc/getting_started/troubleshooting/deploy_hf_model_aml.py +++ b/doc/getting_started/troubleshooting/deploy_hf_model_aml.py @@ -106,14 +106,13 @@ # Set up the `DefaultAzureCredential` for seamless authentication with Azure services. This method should handle most authentication scenarios. If you encounter issues, refer to the [Azure Identity documentation](https://docs.microsoft.com/en-us/python/api/azure-identity/azure.identity?view=azure-python) for alternative credentials. # # %% -from typing import Union from azure.ai.ml import MLClient from azure.core.exceptions import ResourceNotFoundError from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential try: - credential: Union[DefaultAzureCredential, InteractiveBrowserCredential] = DefaultAzureCredential() + credential: DefaultAzureCredential | InteractiveBrowserCredential = DefaultAzureCredential() credential.get_token("https://management.azure.com/.default") except Exception as ex: credential = InteractiveBrowserCredential() diff --git a/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.ipynb b/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.ipynb index 29e7fdec00..65cb1229a3 100644 --- a/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.ipynb +++ b/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.ipynb @@ -71,7 +71,6 @@ "source": [ "# Import the Azure ML SDK components required for workspace connection and model management.\n", "import os\n", - "from typing import Union\n", "\n", "# Import necessary libraries for Azure ML operations and authentication\n", "from azure.ai.ml import MLClient, UserIdentityConfiguration\n", @@ -201,7 +200,7 @@ "source": [ "# Setup Azure credentials, preferring DefaultAzureCredential and falling back to InteractiveBrowserCredential if necessary\n", "try:\n", - " credential: Union[DefaultAzureCredential, InteractiveBrowserCredential] = DefaultAzureCredential()\n", + " credential: DefaultAzureCredential | InteractiveBrowserCredential = DefaultAzureCredential()\n", " # Verify if the default credential can fetch a token successfully\n", " credential.get_token(\"https://management.azure.com/.default\")\n", "except Exception as ex:\n", diff --git a/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.py b/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.py index 34013c449f..251e49ecf2 100644 --- a/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.py +++ b/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.py @@ -61,7 +61,6 @@ # %% # Import the Azure ML SDK components required for workspace connection and model management. import os -from typing import Union # Import necessary libraries for Azure ML operations and authentication from azure.ai.ml import MLClient, UserIdentityConfiguration @@ -160,7 +159,7 @@ # %% # Setup Azure credentials, preferring DefaultAzureCredential and falling back to InteractiveBrowserCredential if necessary try: - credential: Union[DefaultAzureCredential, InteractiveBrowserCredential] = DefaultAzureCredential() + credential: DefaultAzureCredential | InteractiveBrowserCredential = DefaultAzureCredential() # Verify if the default credential can fetch a token successfully credential.get_token("https://management.azure.com/.default") except Exception as ex: diff --git a/pyproject.toml b/pyproject.toml index ee6ef1b728..243d636931 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -341,8 +341,6 @@ ignore = [ "DOC502", # Raised exception is not explicitly raised "PERF203", # try-except-in-loop (intentional per-item error handling) "SIM117", # multiple-with-statements (combining often exceeds line length) - "UP007", # non-pep604-annotation-union (keep Union[X, Y] syntax) - "UP045", # non-pep604-annotation-optional (keep Optional[X] syntax) ] extend-select = [ "D204", # 1 blank line required after class docstring diff --git a/pyrit/analytics/result_analysis.py b/pyrit/analytics/result_analysis.py index 9e158ed88a..2083bb720f 100644 --- a/pyrit/analytics/result_analysis.py +++ b/pyrit/analytics/result_analysis.py @@ -4,7 +4,7 @@ from collections import defaultdict from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.models import ( AttackOutcome, @@ -22,7 +22,7 @@ class AttackStats: """Statistics for attack analysis results.""" - success_rate: Optional[float] + success_rate: float | None total_decided: int successes: int failures: int @@ -118,7 +118,7 @@ def get_cached_results_for_technique( *, technique_eval_hash: str, objective_target_eval_hash: str, - additional_filters: Optional[Sequence[IdentifierFilter]] = None, + additional_filters: Sequence[IdentifierFilter] | None = None, ) -> list[AttackResult]: """ Return cached AttackResults matching a (technique × objective target) pair. @@ -170,7 +170,7 @@ def get_cached_results_for_technique( return matches -def _objective_target_eval_hash_for(attack_result: AttackResult) -> Optional[str]: +def _objective_target_eval_hash_for(attack_result: AttackResult) -> str | None: """ Return the ObjectiveTargetEvaluationIdentifier eval hash for a result. diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index 06e2ff1ade..647f05f7c8 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -6,7 +6,7 @@ import inspect import logging import time -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, cast import msal from azure.core.credentials import AccessToken @@ -41,7 +41,7 @@ class TokenProviderCredential: get_azure_token_provider) and Azure SDK clients that require a TokenCredential object. """ - def __init__(self, token_provider: Callable[[], Union[str, Callable[..., Any]]]) -> None: + def __init__(self, token_provider: Callable[[], str | Callable[..., Any]]) -> None: """ Initialize TokenProviderCredential. @@ -75,7 +75,7 @@ class AsyncTokenProviderCredential: async clients that require an AsyncTokenCredential object (with async def get_token). """ - def __init__(self, token_provider: Callable[[], Union[str, Awaitable[str]]]) -> None: + def __init__(self, token_provider: Callable[[], str | Awaitable[str]]) -> None: """ Initialize AsyncTokenProviderCredential. @@ -394,7 +394,7 @@ def get_azure_openai_auth(endpoint: str) -> Callable[[], Awaitable[str]]: return get_azure_async_token_provider(scope) -def get_speech_config(resource_id: Union[str, None], key: Union[str, None], region: str) -> speechsdk.SpeechConfig: +def get_speech_config(resource_id: str | None, key: str | None, region: str) -> speechsdk.SpeechConfig: """ Get the speech config using key/region pair (for key auth scenarios) or resource_id/region pair (for Entra auth scenarios). @@ -436,8 +436,8 @@ def get_speech_config(resource_id: Union[str, None], key: Union[str, None], regi async def get_speech_config_async( *, token_provider: Callable[[], str | Awaitable[str]] | None, - resource_id: Union[str, None], - key: Union[str, None], + resource_id: str | None, + key: str | None, region: str, ) -> speechsdk.SpeechConfig: """ diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py index a069af33ff..434dff4fbe 100644 --- a/pyrit/auth/copilot_authenticator.py +++ b/pyrit/auth/copilot_authenticator.py @@ -7,7 +7,7 @@ import os import sys from datetime import datetime, timedelta, timezone -from typing import Any, Optional +from typing import Any from msal_extensions import FilePersistence, build_encrypted_persistence @@ -196,7 +196,7 @@ def _create_persistent_cache(cache_file: str, fallback_to_plaintext: bool = Fals logger.error(f"Encryption unavailable ({e}) and fallback_to_plaintext is False. Cannot proceed.") raise - async def _get_cached_token_if_available_and_valid_async(self) -> Optional[dict[str, Any]]: + async def _get_cached_token_if_available_and_valid_async(self) -> dict[str, Any] | None: """ Retrieve and validate cached token. @@ -258,7 +258,7 @@ async def _get_cached_token_if_available_and_valid_async(self) -> Optional[dict[ logger.error(f"Failed to load cached token ({error_name}): {e}") return None - def _save_token_to_cache(self, *, token: str, expires_in: Optional[int] = None) -> None: + def _save_token_to_cache(self, *, token: str, expires_in: int | None = None) -> None: """ Save token to persistent cache with metadata. @@ -301,7 +301,7 @@ def _clear_token_cache(self) -> None: except Exception as e: logger.error(f"Failed to clear cache: {e}") - async def _fetch_access_token_with_playwright_async(self) -> Optional[str]: + async def _fetch_access_token_with_playwright_async(self) -> str | None: """ Fetch access token using Playwright browser automation. @@ -339,7 +339,7 @@ async def _fetch_access_token_with_playwright_async(self) -> Optional[str]: # If not on Windows or using the right loop already, proceed normally return await self._run_playwright_browser_automation_async() - async def _run_playwright_in_thread_async(self) -> Optional[str]: + async def _run_playwright_in_thread_async(self) -> str | None: """ Run Playwright browser automation in a separate thread with ProactorEventLoop. This is needed on Windows when the main loop is SelectorEventLoop (e.g., in Jupyter). @@ -348,21 +348,21 @@ async def _run_playwright_in_thread_async(self) -> Optional[str]: Optional[str]: The bearer token if successfully retrieved, None otherwise. """ - def run_in_new_loop() -> Optional[str]: + def run_in_new_loop() -> str | None: if sys.platform == "win32": new_loop = asyncio.ProactorEventLoop() else: new_loop = asyncio.new_event_loop() asyncio.set_event_loop(new_loop) try: - result: Optional[str] = new_loop.run_until_complete(self._run_playwright_browser_automation_async()) + result: str | None = new_loop.run_until_complete(self._run_playwright_browser_automation_async()) return result finally: new_loop.close() return await asyncio.get_running_loop().run_in_executor(None, run_in_new_loop) - async def _run_playwright_browser_automation_async(self) -> Optional[str]: + async def _run_playwright_browser_automation_async(self) -> str | None: """ Execute the actual Playwright browser automation to fetch the access token. diff --git a/pyrit/auth/manual_copilot_authenticator.py b/pyrit/auth/manual_copilot_authenticator.py index b175118878..24bf9da2e7 100644 --- a/pyrit/auth/manual_copilot_authenticator.py +++ b/pyrit/auth/manual_copilot_authenticator.py @@ -3,7 +3,7 @@ import logging import os -from typing import Any, Optional +from typing import Any import jwt @@ -36,7 +36,7 @@ class ManualCopilotAuthenticator(Authenticator): #: Environment variable name for the Copilot access token ACCESS_TOKEN_ENV_VAR: str = "COPILOT_ACCESS_TOKEN" - def __init__(self, *, access_token: Optional[str] = None) -> None: + def __init__(self, *, access_token: str | None = None) -> None: """ Initialize the ManualCopilotAuthenticator with a pre-obtained access token. diff --git a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py index 787154fe75..c1e898343b 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py +++ b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py @@ -10,7 +10,7 @@ import random import time from copy import deepcopy -from typing import Any, Optional +from typing import Any import numpy as np import pandas as pd @@ -132,7 +132,7 @@ def __init__( target: str, tokenizer: Any, control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", - test_prefixes: Optional[list[str]] = None, + test_prefixes: list[str] | None = None, ) -> None: """ Initializes the AttackPrompt object with the provided parameters. @@ -417,8 +417,8 @@ def __init__( targets: list[str], tokenizer: Any, control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", - test_prefixes: Optional[list[str]] = None, - managers: Optional[dict[str, type[AttackPrompt]]] = None, + test_prefixes: list[str] | None = None, + managers: dict[str, type[AttackPrompt]] | None = None, ) -> None: """ Initializes the PromptManager object with the provided parameters. @@ -539,12 +539,12 @@ def __init__( targets: list[str], workers: list[ModelWorker], control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", - test_prefixes: Optional[list[str]] = None, - logfile: Optional[str] = None, - managers: Optional[dict[str, Any]] = None, - test_goals: Optional[list[str]] = None, - test_targets: Optional[list[str]] = None, - test_workers: Optional[list[ModelWorker]] = None, + test_prefixes: list[str] | None = None, + logfile: str | None = None, + managers: dict[str, Any] | None = None, + test_goals: list[str] | None = None, + test_targets: list[str] | None = None, + test_workers: list[ModelWorker] | None = None, ) -> None: """ Initializes the MultiPromptAttack object with the provided parameters. @@ -619,7 +619,7 @@ def get_filtered_cands( worker_index: int, control_cand: torch.Tensor, filter_cand: bool = True, - curr_control: Optional[str] = None, + curr_control: str | None = None, ) -> list[str]: cands, count = [], 0 worker = self.workers[worker_index] @@ -656,8 +656,8 @@ def run( topk: int = 256, temp: float = 1.0, allow_non_ascii: bool = True, - target_weight: Optional[float] = None, - control_weight: Optional[float] = None, + target_weight: float | None = None, + control_weight: float | None = None, anneal: bool = True, anneal_from: int = 0, prev_loss: float = np.inf, @@ -873,12 +873,12 @@ def __init__( progressive_goals: bool = True, progressive_models: bool = True, control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", - test_prefixes: Optional[list[str]] = None, - logfile: Optional[str] = None, - managers: Optional[dict[str, Any]] = None, - test_goals: Optional[list[str]] = None, - test_targets: Optional[list[str]] = None, - test_workers: Optional[list[ModelWorker]] = None, + test_prefixes: list[str] | None = None, + logfile: str | None = None, + managers: dict[str, Any] | None = None, + test_goals: list[str] | None = None, + test_targets: list[str] | None = None, + test_workers: list[ModelWorker] | None = None, **kwargs: Any, ) -> None: """ @@ -986,8 +986,8 @@ def run( topk: int = 256, temp: float = 1.0, allow_non_ascii: bool = False, - target_weight: Optional[float] = None, - control_weight: Optional[float] = None, + target_weight: float | None = None, + control_weight: float | None = None, anneal: bool = True, test_steps: int = 50, incr_control: bool = True, @@ -1119,12 +1119,12 @@ def __init__( targets: list[str], workers: list[ModelWorker], control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", - test_prefixes: Optional[list[str]] = None, - logfile: Optional[str] = None, - managers: Optional[dict[str, Any]] = None, - test_goals: Optional[list[str]] = None, - test_targets: Optional[list[str]] = None, - test_workers: Optional[list[ModelWorker]] = None, + test_prefixes: list[str] | None = None, + logfile: str | None = None, + managers: dict[str, Any] | None = None, + test_goals: list[str] | None = None, + test_targets: list[str] | None = None, + test_workers: list[ModelWorker] | None = None, **kwargs: Any, ) -> None: """ @@ -1225,8 +1225,8 @@ def run( topk: int = 256, temp: float = 1.0, allow_non_ascii: bool = True, - target_weight: Optional[float] = None, - control_weight: Optional[float] = None, + target_weight: float | None = None, + control_weight: float | None = None, anneal: bool = True, test_steps: int = 50, incr_control: bool = True, @@ -1331,12 +1331,12 @@ def __init__( targets: list[str], workers: list[ModelWorker], control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", - test_prefixes: Optional[list[str]] = None, - logfile: Optional[str] = None, - managers: Optional[dict[str, Any]] = None, - test_goals: Optional[list[str]] = None, - test_targets: Optional[list[str]] = None, - test_workers: Optional[list[ModelWorker]] = None, + test_prefixes: list[str] | None = None, + logfile: str | None = None, + managers: dict[str, Any] | None = None, + test_goals: list[str] | None = None, + test_targets: list[str] | None = None, + test_workers: list[ModelWorker] | None = None, **kwargs: Any, ) -> None: """ @@ -1549,7 +1549,7 @@ def __init__( self.tokenizer = tokenizer self.tasks: mp.JoinableQueue[Any] = mp.JoinableQueue() self.results: mp.JoinableQueue[Any] = mp.JoinableQueue() - self.process: Optional[mp.Process] = None + self.process: mp.Process | None = None @staticmethod def run(model: Any, tasks: mp.JoinableQueue[Any], results: mp.JoinableQueue[Any]) -> None: diff --git a/pyrit/auxiliary_attacks/gcg/experiments/log.py b/pyrit/auxiliary_attacks/gcg/experiments/log.py index bdd96c1ca4..68c7ded7eb 100644 --- a/pyrit/auxiliary_attacks/gcg/experiments/log.py +++ b/pyrit/auxiliary_attacks/gcg/experiments/log.py @@ -3,7 +3,7 @@ import logging import subprocess as sp -from typing import Any, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -20,7 +20,7 @@ def log_params( *, params: Any, - param_keys: Optional[list[str]] = None, + param_keys: list[str] | None = None, ) -> None: """ Log selected parameters via Python logging. diff --git a/pyrit/auxiliary_attacks/gcg/generator.py b/pyrit/auxiliary_attacks/gcg/generator.py index cf353a2adf..4c812594e9 100644 --- a/pyrit/auxiliary_attacks/gcg/generator.py +++ b/pyrit/auxiliary_attacks/gcg/generator.py @@ -38,7 +38,7 @@ import logging import time from dataclasses import dataclass, field -from typing import Any, Optional, overload +from typing import Any, overload import numpy as np import torch.multiprocessing as mp @@ -93,8 +93,8 @@ class GCGContext(PromptGeneratorStrategyContext): workers: list[Any] = field(default_factory=list) test_workers: list[Any] = field(default_factory=list) - attack: Optional[Any] = None - logfile_path: Optional[str] = None + attack: Any | None = None + logfile_path: str | None = None class GCGResult(PromptGeneratorStrategyResult): @@ -138,11 +138,11 @@ def __init__( self, *, models: list[GCGModelConfig], - algorithm: Optional[GCGAlgorithmConfig] = None, - strategy: Optional[GCGStrategyConfig] = None, - output: Optional[GCGOutputConfig] = None, - test_models: Optional[list[GCGModelConfig]] = None, - hf_token: Optional[str] = None, + algorithm: GCGAlgorithmConfig | None = None, + strategy: GCGStrategyConfig | None = None, + output: GCGOutputConfig | None = None, + test_models: list[GCGModelConfig] | None = None, + hf_token: str | None = None, ) -> None: """ Initialize the GCG generator. @@ -307,9 +307,9 @@ async def execute_async( *, goals: list[str], targets: list[str], - test_goals: Optional[list[str]] = None, - test_targets: Optional[list[str]] = None, - memory_labels: Optional[dict[str, str]] = None, + test_goals: list[str] | None = None, + test_targets: list[str] | None = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> GCGResult: ... diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 5807c27bef..b93eceb64e 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -17,7 +17,7 @@ import uuid from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast from urllib.parse import quote, urlparse from azure.identity.aio import DefaultAzureCredential @@ -155,7 +155,7 @@ async def _sign_blob_url_async(*, blob_url: str) -> str: return blob_url -def _resolve_media_url(*, value: Optional[str], data_type: str) -> Optional[str]: +def _resolve_media_url(*, value: str | None, data_type: str) -> str | None: """ For media path types, convert a local file path to a ``/api/media`` URL. @@ -311,7 +311,7 @@ def pyrit_scores_to_dto(scores: list[PyritScore]) -> list[Score]: ] -def _infer_mime_type(*, value: Optional[str], data_type: PromptDataType) -> Optional[str]: +def _infer_mime_type(*, value: str | None, data_type: PromptDataType) -> str | None: """ Infer MIME type from a value and its data type. @@ -335,9 +335,9 @@ def _infer_mime_type(*, value: Optional[str], data_type: PromptDataType) -> Opti def _build_filename( *, data_type: str, - sha256: Optional[str], - value: Optional[str], -) -> Optional[str]: + sha256: str | None, + value: str | None, +) -> str | None: """ Build a human-readable download filename from the data type and hash. @@ -462,7 +462,7 @@ def request_piece_to_pyrit_message_piece( role: ChatMessageRole, conversation_id: str, sequence: int, - labels: Optional[dict[str, str]] = None, # deprecated + labels: dict[str, str] | None = None, # deprecated ) -> PyritMessagePiece: """ Convert a single request piece DTO to a PyRIT MessagePiece domain object. @@ -509,7 +509,7 @@ def request_to_pyrit_message( request: AddMessageRequest, conversation_id: str, sequence: int, - labels: Optional[dict[str, str]] = None, # deprecated + labels: dict[str, str] | None = None, # deprecated ) -> PyritMessage: """ Build a PyRIT Message from an AddMessageRequest DTO. diff --git a/pyrit/backend/mappers/converter_mappers.py b/pyrit/backend/mappers/converter_mappers.py index f1d097762d..a71b5aa537 100644 --- a/pyrit/backend/mappers/converter_mappers.py +++ b/pyrit/backend/mappers/converter_mappers.py @@ -5,8 +5,6 @@ Converter mappers – domain → DTO translation for converter-related models. """ -from typing import Optional - from pyrit.backend.models.converters import ConverterInstance from pyrit.prompt_converter import PromptConverter @@ -21,7 +19,7 @@ def converter_object_to_instance( converter_id: str, converter_obj: PromptConverter, *, - sub_converter_ids: Optional[list[str]] = None, + sub_converter_ids: list[str] | None = None, ) -> ConverterInstance: """ Build a ConverterInstance DTO from a registry converter object. diff --git a/pyrit/backend/middleware/auth.py b/pyrit/backend/middleware/auth.py index db7de281ea..012af51912 100644 --- a/pyrit/backend/middleware/auth.py +++ b/pyrit/backend/middleware/auth.py @@ -18,7 +18,7 @@ import logging import os from dataclasses import dataclass -from typing import Any, Optional +from typing import Any import httpx import jwt @@ -241,7 +241,7 @@ async def _resolve_excess_groups_async(self, claims: dict[str, Any], token: str) logger.warning("Failed to resolve group memberships: %s", e) return [] - def _validate_token(self, token: str) -> tuple[Optional[AuthenticatedUser], dict[str, Any]]: + def _validate_token(self, token: str) -> tuple[AuthenticatedUser | None, dict[str, Any]]: """ Validate a JWT against Entra ID JWKS. diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 2f98f78b7e..6cb739cb5e 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -9,7 +9,7 @@ """ from datetime import datetime -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, Field @@ -26,8 +26,8 @@ class Score(BaseModel): score_value: str = Field( ..., description="Score value ('true'/'false' for true_false, '0.0'-'1.0' for float_scale)" ) - score_category: Optional[list[str]] = Field(None, description="Harm categories (e.g., ['hate', 'violence'])") - score_rationale: Optional[str] = Field(None, description="Explanation for the score") + score_category: list[str] | None = Field(None, description="Harm categories (e.g., ['hate', 'violence'])") + score_rationale: str | None = Field(None, description="Explanation for the score") scored_at: datetime = Field(..., description="When the score was generated") @@ -46,24 +46,24 @@ class MessagePiece(BaseModel): converted_value_data_type: str = Field( default="text", description="Data type of the converted value: 'text', 'image', 'audio', etc." ) - original_value: Optional[str] = Field(default=None, description="Original value before conversion") - original_value_mime_type: Optional[str] = Field(default=None, description="MIME type of original value") + original_value: str | None = Field(default=None, description="Original value before conversion") + original_value_mime_type: str | None = Field(default=None, description="MIME type of original value") converted_value: str = Field(..., description="Converted value (text or base64 for media)") - converted_value_mime_type: Optional[str] = Field(default=None, description="MIME type of converted value") + converted_value_mime_type: str | None = Field(default=None, description="MIME type of converted value") scores: list[Score] = Field(default_factory=list, description="Scores embedded in this piece") response_error: PromptResponseError = Field( default="none", description="Error status: none, processing, blocked, empty, unknown" ) - response_error_description: Optional[str] = Field( + response_error_description: str | None = Field( default=None, description="Description of the error if response_error is not 'none'" ) - original_filename: Optional[str] = Field( + original_filename: str | None = Field( default=None, description="Original filename extracted from file path or blob URL" ) - converted_filename: Optional[str] = Field( + converted_filename: str | None = Field( default=None, description="Converted filename extracted from file path or blob URL" ) - prompt_metadata: Optional[dict[str, Any]] = Field( + prompt_metadata: dict[str, Any] | None = Field( default=None, description="Metadata associated with the piece (e.g., video_id for remix mode)" ) @@ -86,8 +86,8 @@ class TargetInfo(BaseModel): """Target information extracted from the stored TargetIdentifier.""" target_type: str = Field(..., description="Target class name (e.g., 'OpenAIChatTarget')") - endpoint: Optional[str] = Field(None, description="Target endpoint URL") - model_name: Optional[str] = Field(None, description="Model or deployment name") + endpoint: str | None = Field(None, description="Target endpoint URL") + model_name: str | None = Field(None, description="Model or deployment name") class RetryEventResponse(BaseModel): @@ -110,20 +110,18 @@ class AttackSummary(BaseModel): attack_result_id: str = Field(..., description="Database-assigned unique ID for this AttackResult") conversation_id: str = Field(..., description="Primary conversation of this attack result") attack_type: str = Field("", description="Attack class name (e.g., 'CrescendoAttack', 'ManualAttack')") - attack_specific_params: Optional[dict[str, Any]] = Field(None, description="Additional attack-specific parameters") - target: Optional[TargetInfo] = Field(None, description="Target information from the stored identifier") + attack_specific_params: dict[str, Any] | None = Field(None, description="Additional attack-specific parameters") + target: TargetInfo | None = Field(None, description="Target information from the stored identifier") converters: list[str] = Field( default_factory=list, description="Request converter class names applied in this attack" ) objective: str = Field("", description="Natural-language description of the attacker's objective") - outcome: Optional[Literal["undetermined", "success", "failure", "error"]] = Field( + outcome: Literal["undetermined", "success", "failure", "error"] | None = Field( None, description="Attack outcome (null if not yet determined)" ) outcome_reason: str | None = Field(None, description="Reason for the outcome") last_response: str | None = Field(None, description="Model response from the final turn") - last_message_preview: Optional[str] = Field( - None, description="Preview of the last message (truncated to ~100 chars)" - ) + last_message_preview: str | None = Field(None, description="Preview of the last message (truncated to ~100 chars)") score_value: str | None = Field(None, description="Score value from the objective scorer") executed_turns: int = Field(0, ge=0, description="Number of turns executed") execution_time_ms: int = Field(0, ge=0, description="Execution time in milliseconds") @@ -195,13 +193,13 @@ class MessagePieceRequest(BaseModel): data_type: str = Field(default="text", description="Data type: 'text', 'image', 'audio', etc.") original_value: str = Field(..., description="Original value (text or base64 for media)") - converted_value: Optional[str] = Field(None, description="Converted value. If provided, bypasses converters.") - mime_type: Optional[str] = Field(None, description="MIME type for media content") - prompt_metadata: Optional[dict[str, Any]] = Field( + converted_value: str | None = Field(None, description="Converted value. If provided, bypasses converters.") + mime_type: str | None = Field(None, description="MIME type for media content") + prompt_metadata: dict[str, Any] | None = Field( None, description="Metadata to attach to the piece (e.g., {'video_id': '...'} for remix mode).", ) - original_prompt_id: Optional[str] = Field( + original_prompt_id: str | None = Field( None, description="ID of the source piece when prepending from an existing conversation. " "Preserves lineage so the new piece traces back to the original.", @@ -231,18 +229,16 @@ class CreateAttackRequest(BaseModel): supplied in ``labels`` (typically the current operator's labels). """ - name: Optional[str] = Field(None, description="Attack name/label") + name: str | None = Field(None, description="Attack name/label") target_registry_name: str = Field(..., description="Target registry name to attack") - source_conversation_id: Optional[str] = Field( + source_conversation_id: str | None = Field( None, description="Conversation to branch from (clone messages into the new attack)" ) - cutoff_index: Optional[int] = Field( - None, description="Include messages up to and including this turn index (0-based)" - ) - prepended_conversation: Optional[list[PrependedMessageRequest]] = Field( + cutoff_index: int | None = Field(None, description="Include messages up to and including this turn index (0-based)") + prepended_conversation: list[PrependedMessageRequest] | None = Field( None, description="Messages to prepend (system prompts, branching context)", max_length=200 ) - labels: Optional[dict[str, str]] = Field(None, description="User-defined labels for filtering") + labels: dict[str, str] | None = Field(None, description="User-defined labels for filtering") class CreateAttackResponse(BaseModel): @@ -274,8 +270,8 @@ class ConversationSummary(BaseModel): conversation_id: str = Field(..., description="Unique conversation identifier") message_count: int = Field(0, description="Number of messages in this conversation") - last_message_preview: Optional[str] = Field(None, description="Preview of the last message") - created_at: Optional[datetime] = Field(None, description="Timestamp of the first message") + last_message_preview: str | None = Field(None, description="Preview of the last message") + created_at: datetime | None = Field(None, description="Timestamp of the first message") class AttackConversationsResponse(BaseModel): @@ -297,10 +293,8 @@ class CreateConversationRequest(BaseModel): the cutoff turn, preserving tracking relationships (original_prompt_id). """ - source_conversation_id: Optional[str] = Field(None, description="Conversation to branch from") - cutoff_index: Optional[int] = Field( - None, description="Include messages up to and including this turn index (0-based)" - ) + source_conversation_id: str | None = Field(None, description="Conversation to branch from") + cutoff_index: int | None = Field(None, description="Include messages up to and including this turn index (0-based)") class CreateConversationResponse(BaseModel): @@ -344,11 +338,11 @@ class AddMessageRequest(BaseModel): default=True, description="If True, send to target and wait for response. If False, just store in memory.", ) - target_registry_name: Optional[str] = Field( + target_registry_name: str | None = Field( None, description="Target registry name. Required when send=True so the backend knows which target to use.", ) - converter_ids: Optional[list[str]] = Field( + converter_ids: list[str] | None = Field( None, description="Converter instance IDs to apply (overrides attack-level)" ) target_conversation_id: str = Field( @@ -356,7 +350,7 @@ class AddMessageRequest(BaseModel): description="The conversation_id to store and send messages under. " "Usually the attack's main conversation, but can be a related conversation.", ) - labels: Optional[dict[str, str]] = Field( + labels: dict[str, str] | None = Field( None, description="Labels to attach to every message piece. " "Falls back to labels from existing pieces in the conversation.", diff --git a/pyrit/backend/models/common.py b/pyrit/backend/models/common.py index 0a2e00e6b5..36767467cc 100644 --- a/pyrit/backend/models/common.py +++ b/pyrit/backend/models/common.py @@ -7,7 +7,7 @@ Includes pagination, error handling (RFC 7807), and shared base models. """ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -17,8 +17,8 @@ class PaginationInfo(BaseModel): limit: int = Field(..., description="Maximum items per page") has_more: bool = Field(..., description="Whether more items exist") - next_cursor: Optional[str] = Field(None, description="Cursor for next page") - prev_cursor: Optional[str] = Field(None, description="Cursor for previous page") + next_cursor: str | None = Field(None, description="Cursor for next page") + prev_cursor: str | None = Field(None, description="Cursor for previous page") class FieldError(BaseModel): @@ -26,8 +26,8 @@ class FieldError(BaseModel): field: str = Field(..., description="Field name with path (e.g., 'pieces[0].data_type')") message: str = Field(..., description="Error message") - code: Optional[str] = Field(None, description="Error code") - value: Optional[Any] = Field(None, description="The invalid value") + code: str | None = Field(None, description="Error code") + value: Any | None = Field(None, description="The invalid value") class ProblemDetail(BaseModel): @@ -41,8 +41,8 @@ class ProblemDetail(BaseModel): title: str = Field(..., description="Short human-readable summary") status: int = Field(..., description="HTTP status code") detail: str = Field(..., description="Human-readable explanation") - instance: Optional[str] = Field(None, description="URI of the specific occurrence") - errors: Optional[list[FieldError]] = Field(None, description="Field-level errors for validation") + instance: str | None = Field(None, description="URI of the specific occurrence") + errors: list[FieldError] | None = Field(None, description="Field-level errors for validation") # Sensitive field patterns to filter from identifiers diff --git a/pyrit/backend/models/converters.py b/pyrit/backend/models/converters.py index ba5ca5390d..dd216b84b3 100644 --- a/pyrit/backend/models/converters.py +++ b/pyrit/backend/models/converters.py @@ -7,7 +7,7 @@ This module defines the Instance models and preview functionality. """ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -38,9 +38,9 @@ class ConverterParameterSchema(BaseModel): name: str = Field(..., description="Parameter name") type_name: str = Field(..., description="Human-readable type (e.g. 'str', 'int', 'Literal[...]')") required: bool = Field(..., description="Whether the parameter must be provided") - default_value: Optional[str] = Field(None, description="String representation of default value, if any") - choices: Optional[list[str]] = Field(None, description="Allowed values for Literal types") - description: Optional[str] = Field(None, description="Parameter description from docstring") + default_value: str | None = Field(None, description="String representation of default value, if any") + choices: list[str] | None = Field(None, description="Allowed values for Literal types") + description: str | None = Field(None, description="Parameter description from docstring") class ConverterCatalogEntry(BaseModel): @@ -57,7 +57,7 @@ class ConverterCatalogEntry(BaseModel): default_factory=list, description="Constructor parameters for dynamic form generation" ) is_llm_based: bool = Field(False, description="Whether this converter requires an LLM target") - description: Optional[str] = Field(None, description="Short description of the converter from its docstring") + description: str | None = Field(None, description="Short description of the converter from its docstring") class ConverterCatalogResponse(BaseModel): @@ -76,17 +76,17 @@ class ConverterInstance(BaseModel): converter_id: str = Field(..., description="Unique converter instance identifier") converter_type: str = Field(..., description="Converter class name (e.g., 'Base64Converter')") - display_name: Optional[str] = Field(None, description="Human-readable display name") + display_name: str | None = Field(None, description="Human-readable display name") supported_input_types: list[str] = Field( default_factory=list, description="Input data types supported by this converter" ) supported_output_types: list[str] = Field( default_factory=list, description="Output data types produced by this converter" ) - converter_specific_params: Optional[dict[str, Any]] = Field( + converter_specific_params: dict[str, Any] | None = Field( None, description="Additional converter-specific parameters" ) - sub_converter_ids: Optional[list[str]] = Field( + sub_converter_ids: list[str] | None = Field( None, description="Converter IDs of sub-converters (for pipelines/composites)" ) @@ -101,7 +101,7 @@ class CreateConverterRequest(BaseModel): """Request to create a new converter instance.""" type: str = Field(..., description="Converter type (e.g., 'Base64Converter')") - display_name: Optional[str] = Field(None, description="Human-readable display name") + display_name: str | None = Field(None, description="Human-readable display name") params: dict[str, Any] = Field( default_factory=dict, description="Converter constructor parameters", @@ -113,7 +113,7 @@ class CreateConverterResponse(BaseModel): converter_id: str = Field(..., description="Unique converter instance identifier") converter_type: str = Field(..., description="Converter class name") - display_name: Optional[str] = Field(None, description="Human-readable display name") + display_name: str | None = Field(None, description="Human-readable display name") # ============================================================================ diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index 54480b76c7..aaac688cf0 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -10,7 +10,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -40,7 +40,7 @@ class RegisteredScenario(BaseModel): ) all_strategies: list[str] = Field(..., description="All available concrete strategy names") default_datasets: list[str] = Field(..., description="Default dataset names used by the scenario") - max_dataset_size: Optional[int] = Field(None, description="Maximum items per dataset (None means unlimited)") + max_dataset_size: int | None = Field(None, description="Maximum items per dataset (None means unlimited)") supported_parameters: list[ScenarioParameterSummary] = Field( default_factory=list, description="Scenario-declared custom parameters" ) diff --git a/pyrit/backend/models/targets.py b/pyrit/backend/models/targets.py index da512da155..944fbc358f 100644 --- a/pyrit/backend/models/targets.py +++ b/pyrit/backend/models/targets.py @@ -11,7 +11,7 @@ This module defines the Instance models for runtime target management. """ -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, Field @@ -56,16 +56,14 @@ class TargetInstance(BaseModel): target_registry_name: str = Field(..., description="Target registry key (e.g., 'azure_openai_chat')") target_type: str = Field(..., description="Target class name (e.g., 'OpenAIChatTarget')") - endpoint: Optional[str] = Field(None, description="Target endpoint URL") - model_name: Optional[str] = Field(None, description="Model or deployment name used in API calls") - underlying_model_name: Optional[str] = Field( - None, description="Underlying model name if different (e.g., 'gpt-4o')" - ) - temperature: Optional[float] = Field(None, description="Temperature parameter for generation") - top_p: Optional[float] = Field(None, description="Top-p parameter for generation") - max_requests_per_minute: Optional[int] = Field(None, description="Maximum requests per minute") + endpoint: str | None = Field(None, description="Target endpoint URL") + model_name: str | None = Field(None, description="Model or deployment name used in API calls") + underlying_model_name: str | None = Field(None, description="Underlying model name if different (e.g., 'gpt-4o')") + temperature: float | None = Field(None, description="Temperature parameter for generation") + top_p: float | None = Field(None, description="Top-p parameter for generation") + max_requests_per_minute: int | None = Field(None, description="Maximum requests per minute") capabilities: TargetCapabilitiesInfo = Field(..., description="Structured snapshot of target capabilities") - target_specific_params: Optional[dict[str, Any]] = Field(None, description="Additional target-specific parameters") + target_specific_params: dict[str, Any] | None = Field(None, description="Additional target-specific parameters") class TargetListResponse(BaseModel): diff --git a/pyrit/backend/pyrit_backend.py b/pyrit/backend/pyrit_backend.py index 0770ae501a..6c792f383d 100644 --- a/pyrit/backend/pyrit_backend.py +++ b/pyrit/backend/pyrit_backend.py @@ -17,12 +17,11 @@ import sys from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter from pathlib import Path -from typing import Optional from pyrit.common.cli_helpers import CONFIG_FILE_HELP, validate_log_level_argparse -def parse_args(*, args: Optional[list[str]] = None) -> Namespace: +def parse_args(*, args: list[str] | None = None) -> Namespace: """ Parse command-line arguments for the PyRIT backend server. @@ -88,7 +87,7 @@ def parse_args(*, args: Optional[list[str]] = None) -> Namespace: return parser.parse_args(args) -def main(*, args: Optional[list[str]] = None) -> int: +def main(*, args: list[str] | None = None) -> int: """ Start the PyRIT backend server. diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index d6844d5041..7f41ec4339 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -10,7 +10,7 @@ import logging from collections.abc import Sequence -from typing import Literal, Optional +from typing import Literal from fastapi import APIRouter, HTTPException, Query, status @@ -39,7 +39,7 @@ router = APIRouter(prefix="/attacks", tags=["attacks"]) -def _parse_labels(label_params: Optional[list[str]]) -> Optional[dict[str, str | Sequence[str]]]: +def _parse_labels(label_params: list[str] | None) -> dict[str, str | Sequence[str]] | None: """ Parse 'key:value' label query params into a dict grouping values by key. @@ -69,13 +69,13 @@ def _parse_labels(label_params: Optional[list[str]]) -> Optional[dict[str, str | response_model=AttackListResponse, ) async def list_attacks( # pyrit-async-suffix-exempt - attack_types: Optional[list[str]] = Query( + attack_types: list[str] | None = Query( None, description="Filter by attack type names. May be specified multiple times to OR-match " "across types (e.g. ?attack_types=A&attack_types=B). Case-insensitive. " "Omit to return all attacks regardless of type.", ), - converter_types: Optional[list[str]] = Query( + converter_types: list[str] | None = Query( None, description="Filter by converter type names. May be specified multiple times; " "combination semantics are controlled by converter_types_match " @@ -88,24 +88,24 @@ async def list_attacks( # pyrit-async-suffix-exempt description="How to combine multiple converter_types: 'any' (attack has at least one) " "or 'all' (attack has every one). Defaults to 'all'.", ), - has_converters: Optional[bool] = Query( + has_converters: bool | None = Query( None, description="Filter by converter presence. true = attacks with at least one converter; " "false = attacks with no converters. Omit for no filter.", ), - outcome: Optional[Literal["undetermined", "success", "failure", "error"]] = Query( + outcome: Literal["undetermined", "success", "failure", "error"] | None = Query( None, description="Filter by outcome" ), - label: Optional[list[str]] = Query( + label: list[str] | None = Query( None, description="Filter by labels (format: key:value). May be specified multiple times; " "OR-matched within a key, AND-matched across keys " "(e.g. ?label=op:red&label=op:blue matches op=red OR op=blue).", ), - min_turns: Optional[int] = Query(None, ge=0, description="Filter by minimum executed turns"), - max_turns: Optional[int] = Query(None, ge=0, description="Filter by maximum executed turns"), + min_turns: int | None = Query(None, ge=0, description="Filter by minimum executed turns"), + max_turns: int | None = Query(None, ge=0, description="Filter by maximum executed turns"), limit: int = Query(20, ge=1, le=100, description="Maximum items per page"), - cursor: Optional[str] = Query( + cursor: str | None = Query( None, description="Pagination cursor: the attack_result_id of the last item from the previous page. " "Omit to start from the beginning. The response includes next_cursor for the next page.", diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py index d75857210b..941d8021fb 100644 --- a/pyrit/backend/routes/scenarios.py +++ b/pyrit/backend/routes/scenarios.py @@ -12,8 +12,6 @@ /api/scenarios/runs — scenario execution lifecycle """ -from typing import Optional - from fastapi import APIRouter, HTTPException, Query, status from pyrit.backend.models.common import ProblemDetail @@ -41,7 +39,7 @@ ) async def list_scenarios( # pyrit-async-suffix-exempt limit: int = Query(50, ge=1, le=200, description="Maximum items per page"), - cursor: Optional[str] = Query(None, description="Pagination cursor (scenario_name to start after)"), + cursor: str | None = Query(None, description="Pagination cursor (scenario_name to start after)"), ) -> ListRegisteredScenariosResponse: """ List all available scenarios. diff --git a/pyrit/backend/routes/targets.py b/pyrit/backend/routes/targets.py index 5a05ea41fd..bea53ddef2 100644 --- a/pyrit/backend/routes/targets.py +++ b/pyrit/backend/routes/targets.py @@ -8,8 +8,6 @@ Target types are set at app startup via initializers - you cannot add new types at runtime. """ -from typing import Optional - from fastapi import APIRouter, HTTPException, Query, status from pyrit.backend.models.common import ProblemDetail @@ -32,7 +30,7 @@ ) async def list_targets( # pyrit-async-suffix-exempt limit: int = Query(50, ge=1, le=200, description="Maximum items per page"), - cursor: Optional[str] = Query(None, description="Pagination cursor (target_registry_name)"), + cursor: str | None = Query(None, description="Pagination cursor (target_registry_name)"), ) -> TargetListResponse: """ List target instances with pagination. diff --git a/pyrit/backend/routes/version.py b/pyrit/backend/routes/version.py index b59d176158..2d75c200aa 100644 --- a/pyrit/backend/routes/version.py +++ b/pyrit/backend/routes/version.py @@ -6,7 +6,6 @@ import json import logging from pathlib import Path -from typing import Optional from fastapi import APIRouter, Request from pydantic import BaseModel @@ -23,12 +22,12 @@ class VersionResponse(BaseModel): """Version information response model.""" version: str - source: Optional[str] = None - commit: Optional[str] = None - modified: Optional[bool] = None + source: str | None = None + commit: str | None = None + modified: bool | None = None display: str - database_info: Optional[str] = None - default_labels: Optional[dict[str, str]] = None + database_info: str | None = None + default_labels: dict[str, str] | None = None @router.get("", response_model=VersionResponse) @@ -62,7 +61,7 @@ async def get_version_async(request: Request) -> VersionResponse: logger.warning(f"Failed to load build info: {e}") # Detect current database backend - database_info: Optional[str] = None + database_info: str | None = None try: memory = CentralMemory.get_memory_instance() db_type = type(memory).__name__ @@ -74,7 +73,7 @@ async def get_version_async(request: Request) -> VersionResponse: logger.debug(f"Could not detect database info: {e}") # Read default labels from app state (set by pyrit_backend CLI) - default_labels: Optional[dict[str, str]] = getattr(request.app.state, "default_labels", None) or None + default_labels: dict[str, str] | None = getattr(request.app.state, "default_labels", None) or None return VersionResponse( version=version, diff --git a/pyrit/cli/_banner.py b/pyrit/cli/_banner.py index 0c8d4719eb..21e1f1c5bd 100644 --- a/pyrit/cli/_banner.py +++ b/pyrit/cli/_banner.py @@ -21,7 +21,6 @@ import time from dataclasses import dataclass, field from enum import Enum -from typing import Optional from pyrit.cli._banner_assets import BRAILLE_RACCOON, PYRIT_LETTERS, PYRIT_WIDTH, RACCOON_TAIL @@ -199,7 +198,7 @@ def _build_static_banner() -> StaticBannerData: color_map: dict[int, ColorRole] = {} segment_colors: dict[int, list[tuple[int, int, ColorRole]]] = {} - def add(line: str, role: ColorRole, segments: Optional[list[tuple[int, int, ColorRole]]] = None) -> None: + def add(line: str, role: ColorRole, segments: list[tuple[int, int, ColorRole]] | None = None) -> None: idx = len(lines) color_map[idx] = role if segments: @@ -559,14 +558,14 @@ def _render_line_with_segments( """ reset = _get_color(ColorRole.RESET, theme) # Build per-character color map (later segments override earlier ones) - char_roles: list[Optional[ColorRole]] = [None] * len(line) + char_roles: list[ColorRole | None] = [None] * len(line) for start, end, role in segments: for pos in range(start, min(end, len(line))): char_roles[pos] = role # Group consecutive same-role characters for efficient rendering result: list[str] = [] - current_role: Optional[ColorRole] = None + current_role: ColorRole | None = None for pos, ch in enumerate(line): char_role = char_roles[pos] if char_role != current_role: diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index e982b6ae06..a53ac152ef 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -21,7 +21,7 @@ import logging import shlex from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, get_origin +from typing import TYPE_CHECKING, Any, get_origin from pyrit.common.cli_helpers import ( CONFIG_FILE_HELP, @@ -67,7 +67,7 @@ def validate_database(*, database: str) -> str: return database -def validate_integer(value: str, *, name: str = "value", min_value: Optional[int] = None) -> int: +def validate_integer(value: str, *, name: str = "value", min_value: int | None = None) -> int: """ Validate and parse an integer value. @@ -492,7 +492,7 @@ def _parse_shell_arguments(*, parts: list[str], arg_specs: list[_ArgSpec]) -> di return result -def parse_run_arguments(*, args_string: str, declared_params: Optional[list[Parameter]] = None) -> dict[str, Any]: +def parse_run_arguments(*, args_string: str, declared_params: list[Parameter] | None = None) -> dict[str, Any]: """ Parse run command arguments from a string (for shell mode). @@ -643,7 +643,7 @@ def extract_scenario_args(*, parsed: dict[str, Any]) -> dict[str, Any]: # --------------------------------------------------------------------------- -def build_parameters_from_api(*, api_params: list[dict[str, Any]]) -> Optional[list[Parameter]]: +def build_parameters_from_api(*, api_params: list[dict[str, Any]]) -> list[Parameter] | None: """ Build ``Parameter`` objects from a scenario catalog's ``supported_parameters``. @@ -669,7 +669,7 @@ def build_parameters_from_api(*, api_params: list[dict[str, Any]]) -> Optional[l else: resolved_type = type_map.get(type_display) raw_choices = p.get("choices") - choices: Optional[tuple[Any, ...]] = tuple(raw_choices) if raw_choices else None + choices: tuple[Any, ...] | None = tuple(raw_choices) if raw_choices else None parameters.append( Parameter( name=p["name"], @@ -699,7 +699,7 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: def merge_config_scenario_args( *, - config_scenario: Optional[ScenarioConfig], + config_scenario: ScenarioConfig | None, effective_scenario_name: str, cli_args: dict[str, Any], ) -> dict[str, Any]: diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index 57ee0a3326..1e4467f929 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -17,7 +17,7 @@ import sys from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter from pathlib import Path -from typing import Any, Optional +from typing import Any from pyrit.cli._cli_args import ( ARG_HELP, @@ -332,7 +332,7 @@ def _extract_scenario_args(*, parsed: Namespace) -> dict[str, Any]: } -def parse_args(args: Optional[list[str]] = None) -> Namespace: +def parse_args(args: list[str] | None = None) -> Namespace: """ Parse command-line arguments (pass 1 — tolerant of scenario-declared flags). @@ -760,7 +760,7 @@ async def _run_async(*, parsed_args: Namespace) -> int: return 1 -def main(args: Optional[list[str]] = None) -> int: +def main(args: list[str] | None = None) -> int: """ Start the PyRIT scanner CLI. diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 78dbed1aeb..1a0760eb7c 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -18,7 +18,7 @@ import sys import threading from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from pyrit.cli import _banner as banner @@ -173,7 +173,7 @@ def _ensure_client(self) -> bool: self._start_server = False # only auto-start once return True - def cmdloop(self, intro: Optional[str] = None) -> None: + def cmdloop(self, intro: str | None = None) -> None: """Override cmdloop to play animated banner before starting the REPL.""" if intro is None: prev_disable = logging.root.manager.disable diff --git a/pyrit/common/default_values.py b/pyrit/common/default_values.py index 9dbcba427f..4334cfb5d2 100644 --- a/pyrit/common/default_values.py +++ b/pyrit/common/default_values.py @@ -3,7 +3,7 @@ import logging import os -from typing import Any, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -38,7 +38,7 @@ def get_required_value(*, env_var_name: str, passed_value: Any) -> Any: raise ValueError(f"Environment variable {env_var_name} is required") -def get_non_required_value(*, env_var_name: str, passed_value: Optional[str] = None) -> str: +def get_non_required_value(*, env_var_name: str, passed_value: str | None = None) -> str: """ Get a non-required value from an environment variable or a passed value, preferring the passed value. diff --git a/pyrit/common/net_utility.py b/pyrit/common/net_utility.py index eb75f5616e..17d944e1e2 100644 --- a/pyrit/common/net_utility.py +++ b/pyrit/common/net_utility.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Any, Literal, Optional, cast, overload +from typing import Any, Literal, cast, overload from urllib.parse import parse_qs, urlparse, urlunparse import httpx @@ -88,10 +88,10 @@ async def make_request_and_raise_if_error_async( method: str, post_type: PostType = "json", debug: bool = False, - extra_url_parameters: Optional[dict[str, str]] = None, - request_body: Optional[dict[str, object]] = None, - files: Optional[dict[str, tuple[str, bytes, str]]] = None, - headers: Optional[dict[str, str]] = None, + extra_url_parameters: dict[str, str] | None = None, + request_body: dict[str, object] | None = None, + files: dict[str, tuple[str, bytes, str]] | None = None, + headers: dict[str, str] | None = None, **httpx_client_kwargs: Any, ) -> httpx.Response: """ diff --git a/pyrit/common/utils.py b/pyrit/common/utils.py index a7203ca336..d4a75e579c 100644 --- a/pyrit/common/utils.py +++ b/pyrit/common/utils.py @@ -8,12 +8,12 @@ import math import random from pathlib import Path -from typing import Any, Optional, TypeVar, Union +from typing import Any, TypeVar logger = logging.getLogger(__name__) -def verify_and_resolve_path(path: Union[str, Path]) -> Path: +def verify_and_resolve_path(path: str | Path) -> Path: """ Verify that a path is valid and resolve it to an absolute path. @@ -39,9 +39,7 @@ def verify_and_resolve_path(path: Union[str, Path]) -> Path: return path_obj -def combine_dict( - existing_dict: Optional[dict[str, Any]] = None, new_dict: Optional[dict[str, Any]] = None -) -> dict[str, Any]: +def combine_dict(existing_dict: dict[str, Any] | None = None, new_dict: dict[str, Any] | None = None) -> dict[str, Any]: """ Combine two dictionaries containing string keys and values into one. @@ -58,7 +56,7 @@ def combine_dict( return result -def combine_list(list1: Union[str, list[str]], list2: Union[str, list[str]]) -> list[str]: +def combine_list(list1: str | list[str], list2: str | list[str]) -> list[str]: """ Combine two lists or strings into a single list with unique values. @@ -126,7 +124,7 @@ def to_sha256(data: str) -> str: def warn_if_set( - *, config: Any, unused_fields: list[str], log: Union[logging.Logger, logging.LoggerAdapter[logging.Logger]] = logger + *, config: Any, unused_fields: list[str], log: logging.Logger | logging.LoggerAdapter[logging.Logger] = logger ) -> None: """ Warn about unused parameters in configurations. @@ -169,8 +167,8 @@ def get_kwarg_param( param_name: str, expected_type: type[_T], required: bool = True, - default_value: Optional[_T] = None, -) -> Optional[_T]: + default_value: _T | None = None, +) -> _T | None: """ Validate and extract a parameter from kwargs. diff --git a/pyrit/common/yaml_loadable.py b/pyrit/common/yaml_loadable.py index a7857b7ad7..2fb4422c73 100644 --- a/pyrit/common/yaml_loadable.py +++ b/pyrit/common/yaml_loadable.py @@ -3,7 +3,7 @@ import abc from pathlib import Path -from typing import TypeVar, Union +from typing import TypeVar import yaml @@ -18,7 +18,7 @@ class YamlLoadable(abc.ABC): # noqa: B024 """ @classmethod - def from_yaml_file(cls: type[T], file: Union[Path | str]) -> T: + def from_yaml_file(cls: type[T], file: Path | str) -> T: """ Create a new object from a YAML file. diff --git a/pyrit/datasets/executors/question_answer/wmdp_dataset.py b/pyrit/datasets/executors/question_answer/wmdp_dataset.py index 1270c9b6c0..81f4711747 100644 --- a/pyrit/datasets/executors/question_answer/wmdp_dataset.py +++ b/pyrit/datasets/executors/question_answer/wmdp_dataset.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from datasets import load_dataset @@ -12,7 +11,7 @@ ) -def fetch_wmdp_dataset(category: Optional[str] = None) -> QuestionAnsweringDataset: +def fetch_wmdp_dataset(category: str | None = None) -> QuestionAnsweringDataset: """ Fetch WMDP examples and create a QuestionAnsweringDataset. diff --git a/pyrit/datasets/jailbreak/text_jailbreak.py b/pyrit/datasets/jailbreak/text_jailbreak.py index 6e5083bd42..380ee9c5c1 100644 --- a/pyrit/datasets/jailbreak/text_jailbreak.py +++ b/pyrit/datasets/jailbreak/text_jailbreak.py @@ -5,7 +5,7 @@ import random import threading from pathlib import Path -from typing import Any, Optional +from typing import Any from pyrit.common.path import JAILBREAK_TEMPLATES_PATH from pyrit.models import SeedPrompt @@ -18,7 +18,7 @@ class TextJailBreak: A class that manages jailbreak datasets (like DAN, etc.). """ - _template_cache: Optional[dict[str, list[Path]]] = None + _template_cache: dict[str, list[Path]] | None = None _cache_lock: threading.Lock = threading.Lock() @classmethod @@ -99,9 +99,9 @@ def _get_all_template_paths(cls) -> list[Path]: def __init__( self, *, - template_path: Optional[str] = None, - template_file_name: Optional[str] = None, - string_template: Optional[str] = None, + template_path: str | None = None, + template_file_name: str | None = None, + string_template: str | None = None, random_template: bool = False, **kwargs: Any, ) -> None: @@ -208,7 +208,7 @@ def _apply_extra_kwargs(self, kwargs: dict[str, Any]) -> None: self.template.value = self.template.render_template_value_silent(**kwargs) @classmethod - def get_jailbreak_templates(cls, num_templates: Optional[int] = None) -> list[str]: + def get_jailbreak_templates(cls, num_templates: int | None = None) -> list[str]: """ Retrieve all jailbreaks from the JAILBREAK_TEMPLATES_PATH. diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index 18f8343330..0ccae77fb1 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -5,7 +5,7 @@ from collections.abc import Callable from dataclasses import fields from pathlib import Path -from typing import Any, Optional +from typing import Any import yaml @@ -76,7 +76,7 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: logger.error(f"Failed to load local dataset from {self.file_path}: {e}") raise - async def _parse_metadata_async(self) -> Optional[SeedDatasetMetadata]: + async def _parse_metadata_async(self) -> SeedDatasetMetadata | None: """ Extract metadata from a local YAML file and coerce raw values into typed schema fields. diff --git a/pyrit/datasets/seed_datasets/remote/_image_cache.py b/pyrit/datasets/seed_datasets/remote/_image_cache.py index dbc866a47b..bdb502ad23 100644 --- a/pyrit/datasets/seed_datasets/remote/_image_cache.py +++ b/pyrit/datasets/seed_datasets/remote/_image_cache.py @@ -14,7 +14,7 @@ import logging from collections.abc import Mapping from pathlib import Path -from typing import Any, Optional +from typing import Any from pyrit.common.net_utility import make_request_and_raise_if_error_async from pyrit.models import data_serializer_factory @@ -25,11 +25,11 @@ async def fetch_and_cache_image_async( *, filename: str, - image_url: Optional[str] = None, - image_bytes: Optional[bytes] = None, + image_url: str | None = None, + image_bytes: bytes | None = None, log_prefix: str = "image-cache", - request_headers: Optional[Mapping[str, str]] = None, - request_timeout: Optional[float] = None, + request_headers: Mapping[str, str] | None = None, + request_timeout: float | None = None, follow_redirects: bool = False, ) -> str: """ diff --git a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py index 87e0d47b3c..1e38198fb6 100644 --- a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Literal, Optional +from typing import Literal from datasets import load_dataset @@ -68,35 +68,34 @@ class _AegisContentSafetyDataset(_RemoteDatasetLoader): def __init__( self, *, - harm_categories: Optional[ - list[ - Literal[ - "Controlled/Regulated Substances", - "Copyright/Trademark/Plagiarism", - "Criminal Planning/Confessions", - "Fraud/Deception", - "Guns and Illegal Weapons", - "Harassment", - "Hate/Identity Hate", - "High Risk Gov Decision Making", - "Illegal Activity", - "Immoral/Unethical", - "Malware", - "Manipulation", - "Needs Caution", - "Other", - "PII/Privacy", - "Political/Misinformation/Conspiracy", - "Profanity", - "Sexual", - "Sexual (minor)", - "Suicide and Self Harm", - "Threat", - "Unauthorized Advice", - "Violence", - ] + harm_categories: list[ + Literal[ + "Controlled/Regulated Substances", + "Copyright/Trademark/Plagiarism", + "Criminal Planning/Confessions", + "Fraud/Deception", + "Guns and Illegal Weapons", + "Harassment", + "Hate/Identity Hate", + "High Risk Gov Decision Making", + "Illegal Activity", + "Immoral/Unethical", + "Malware", + "Manipulation", + "Needs Caution", + "Other", + "PII/Privacy", + "Political/Misinformation/Conspiracy", + "Profanity", + "Sexual", + "Sexual (minor)", + "Suicide and Self Harm", + "Threat", + "Unauthorized Advice", + "Violence", ] - ] = None, + ] + | None = None, ) -> None: """ Initialize the NVIDIA Aegis AI Content Safety Dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/agent_threat_rules_dataset.py b/pyrit/datasets/seed_datasets/remote/agent_threat_rules_dataset.py index 338747b5e8..04839057da 100644 --- a/pyrit/datasets/seed_datasets/remote/agent_threat_rules_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/agent_threat_rules_dataset.py @@ -3,7 +3,7 @@ import logging from enum import Enum -from typing import Literal, Optional +from typing import Literal from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, @@ -134,10 +134,10 @@ def __init__( "db793f9/data/autoresearch/adversarial-samples.json" ), source_type: Literal["public_url", "file"] = "public_url", - categories: Optional[list[ATRCategory]] = None, - techniques: Optional[list[str]] = None, - detection_fields: Optional[list[ATRDetectionField]] = None, - variation_types: Optional[list[ATRVariationType]] = None, + categories: list[ATRCategory] | None = None, + techniques: list[str] | None = None, + detection_fields: list[ATRDetectionField] | None = None, + variation_types: list[ATRVariationType] | None = None, ) -> None: """ Initialize the ATR dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py b/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py index 05a905c172..b6bfddd203 100644 --- a/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py @@ -3,7 +3,7 @@ import ast import logging -from typing import Literal, Optional +from typing import Literal from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, @@ -50,22 +50,21 @@ def __init__( language: Literal[ "English", "Hindi", "French", "Spanish", "Arabic", "Russian", "Serbian", "Tagalog" ] = "English", - harm_categories: Optional[ - list[ - Literal[ - "Bullying & Harassment", - "Discrimination & Injustice", - "Graphic material", - "Harms of Representation Allocation and Quality of Service", - "Hate Speech", - "Non-consensual sexual content", - "Profanity", - "Self-Harm", - "Violence, Threats & Incitement", - ] + harm_categories: list[ + Literal[ + "Bullying & Harassment", + "Discrimination & Injustice", + "Graphic material", + "Harms of Representation Allocation and Quality of Service", + "Hate Speech", + "Non-consensual sexual content", + "Profanity", + "Self-Harm", + "Violence, Threats & Incitement", ] - ] = None, - harm_scope: Optional[Literal["global", "local"]] = None, + ] + | None = None, + harm_scope: Literal["global", "local"] | None = None, ) -> None: """ Initialize the Aya Red-teaming dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py b/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py index 44d9f4c244..2b559f6427 100644 --- a/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Literal, Optional +from typing import Literal from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, @@ -32,7 +32,7 @@ def __init__( self, *, source: str = "Babelscape/ALERT", - category: Optional[Literal["alert", "alert_adversarial"]] = "alert_adversarial", + category: Literal["alert", "alert_adversarial"] | None = "alert_adversarial", ) -> None: """ Initialize the Babelscape ALERT dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/figstep_dataset.py b/pyrit/datasets/seed_datasets/remote/figstep_dataset.py index d430a504a4..08e85043c9 100644 --- a/pyrit/datasets/seed_datasets/remote/figstep_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/figstep_dataset.py @@ -9,7 +9,7 @@ import zipfile from enum import Enum from pathlib import Path -from typing import Literal, Optional +from typing import Literal from pyrit.common.net_utility import make_request_and_raise_if_error_async from pyrit.common.path import DB_DATA_PATH @@ -167,8 +167,8 @@ def __init__( *, use_tiny: bool = True, variant: FigStepVariant = FigStepVariant.FIGSTEP, - categories: Optional[list[FigStepCategory]] = None, - source: Optional[str] = None, + categories: list[FigStepCategory] | None = None, + source: str | None = None, source_type: Literal["public_url", "file"] = "public_url", ) -> None: """ @@ -241,8 +241,8 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: required_keys = {"dataset", "category_id", "task_id", "category_name", "question", "instruction"} rows = self._fetch_from_url(source=self.source, source_type=self.source_type, cache=cache) - pro_extract_dir: Optional[Path] = None - pro_benign_sentences: Optional[list[str]] = None + pro_extract_dir: Path | None = None + pro_benign_sentences: list[str] | None = None if self.variant == FigStepVariant.FIGSTEP_PRO: pro_extract_dir, pro_benign_sentences = await self._ensure_figstep_pro_assets_async(cache=cache) diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py index d77767d933..215f1be687 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py @@ -4,7 +4,7 @@ import logging import uuid from enum import Enum -from typing import Literal, Optional +from typing import Literal from pyrit.datasets.seed_datasets.remote._image_cache import ( fetch_and_cache_image_async, @@ -56,7 +56,7 @@ def __init__( "harmbench_behaviors_multimodal_all.csv" ), source_type: Literal["public_url", "file"] = "public_url", - categories: Optional[list[SemanticCategory]] = None, + categories: list[SemanticCategory] | None = None, ) -> None: """ Initialize the HarmBench multimodal dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py b/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py index 5ff18a8953..251cfb5405 100644 --- a/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py @@ -6,7 +6,7 @@ import uuid import zipfile from enum import Enum -from typing import Literal, Optional +from typing import Literal from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, @@ -87,7 +87,7 @@ def __init__( source: str = "JailbreakV-28K/JailBreakV-28k", zip_dir: str = str(pathlib.Path.home()), split: Literal["JailBreakV_28K", "mini_JailBreakV_28K"] = "mini_JailBreakV_28K", - harm_categories: Optional[list[_HarmCategory]] = None, + harm_categories: list[_HarmCategory] | None = None, ) -> None: """ Initialize the JailBreakV-28K dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/jailbreakv_redteam_2k_dataset.py b/pyrit/datasets/seed_datasets/remote/jailbreakv_redteam_2k_dataset.py index e76b294d83..c4218b82d9 100644 --- a/pyrit/datasets/seed_datasets/remote/jailbreakv_redteam_2k_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/jailbreakv_redteam_2k_dataset.py @@ -3,7 +3,6 @@ import logging from enum import Enum -from typing import Optional from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, @@ -77,7 +76,7 @@ def __init__( self, *, source: str = "JailbreakV-28K/JailBreakV-28k", - harm_categories: Optional[list[_HarmCategory]] = None, + harm_categories: list[_HarmCategory] | None = None, ) -> None: """ Initialize the JailBreakV Redteam_2k dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py b/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py index 01fd374281..38a8cabfb5 100644 --- a/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Literal, Optional +from typing import Literal from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, @@ -33,31 +33,30 @@ def __init__( *, source: str = "PKU-Alignment/PKU-SafeRLHF", include_safe_prompts: bool = True, - filter_harm_categories: Optional[ - list[ - Literal[ - "Animal Abuse", - "Copyright Issues", - "Cybercrime", - "Discriminatory Behavior", - "Disrupting Public Order", - "Drugs", - "Economic Crime", - "Endangering National Security", - "Endangering Public Health", - "Environmental Damage", - "Human Trafficking", - "Insulting Behavior", - "Mental Manipulation", - "Physical Harm", - "Privacy Violation", - "Psychological Harm", - "Sexual Content", - "Violence", - "White-Collar Crime", - ] + filter_harm_categories: list[ + Literal[ + "Animal Abuse", + "Copyright Issues", + "Cybercrime", + "Discriminatory Behavior", + "Disrupting Public Order", + "Drugs", + "Economic Crime", + "Endangering National Security", + "Endangering Public Health", + "Environmental Damage", + "Human Trafficking", + "Insulting Behavior", + "Mental Manipulation", + "Physical Harm", + "Privacy Violation", + "Psychological Harm", + "Sexual Content", + "Violence", + "White-Collar Crime", ] - ] = None, + ] + | None = None, ) -> None: """ Initialize the PKU-SafeRLHF dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py b/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py index 80998a3af8..cba08eafdb 100644 --- a/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py @@ -5,7 +5,7 @@ import os from datetime import datetime from enum import Enum -from typing import Any, Optional +from typing import Any import requests @@ -76,10 +76,10 @@ class _PromptIntelDataset(_RemoteDatasetLoader): def __init__( self, *, - api_key: Optional[str] = None, - severity: Optional[PromptIntelSeverity] = None, - categories: Optional[list[PromptIntelCategory]] = None, - search: Optional[str] = None, + api_key: str | None = None, + severity: PromptIntelSeverity | None = None, + categories: list[PromptIntelCategory] | None = None, + search: str | None = None, ) -> None: """ Initialize the PromptIntel dataset loader. @@ -139,7 +139,7 @@ def _fetch_all_prompts(self) -> list[dict[str, Any]]: } # Build list of category values to fetch; [None] means fetch all categories - categories_to_fetch: list[Optional[str]] = [c.value for c in self._categories] if self._categories else [None] + categories_to_fetch: list[str | None] = [c.value for c in self._categories] if self._categories else [None] all_prompts: list[dict[str, Any]] = [] seen_ids: set[str] = set() @@ -187,7 +187,7 @@ def _fetch_all_prompts(self) -> list[dict[str, Any]]: return all_prompts - def _parse_datetime(self, date_str: Optional[str]) -> Optional[datetime]: + def _parse_datetime(self, date_str: str | None) -> datetime | None: """ Parse an ISO 8601 datetime string from the API. @@ -252,7 +252,7 @@ def _build_metadata(self, record: dict[str, Any]) -> dict[str, str | int]: return metadata - def _convert_record_to_seed_prompt(self, record: dict[str, Any]) -> Optional[SeedPrompt]: + def _convert_record_to_seed_prompt(self, record: dict[str, Any]) -> SeedPrompt | None: """ Convert a single PromptIntel record into a SeedPrompt. diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 0fc9bdd3b9..d536428f69 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -13,7 +13,7 @@ from dataclasses import fields from enum import Enum from pathlib import Path -from typing import Any, Literal, Optional, TextIO, cast +from typing import Any, Literal, TextIO, cast from urllib.parse import urlparse import requests @@ -288,10 +288,10 @@ async def _fetch_from_huggingface_async( self, *, dataset_name: str, - config: Optional[str] = None, - split: Optional[str] = None, + config: str | None = None, + split: str | None = None, cache: bool = True, - token: Optional[str] = None, + token: str | None = None, **kwargs: Any, ) -> Any: """ @@ -356,7 +356,7 @@ def _load_dataset_sync() -> Any: logger.error(f"Failed to load HuggingFace dataset {dataset_name}: {e}") raise - async def _parse_metadata_async(self) -> Optional[SeedDatasetMetadata]: + async def _parse_metadata_async(self) -> SeedDatasetMetadata | None: """ Extract metadata from class attributes, wrap in sets, and format into SeedDatasetMetadata. @@ -423,7 +423,7 @@ async def _fetch_zip_from_url_async( def _download_and_parse() -> dict[str, list[dict[str, Any]]]: zip_path: Path - temp_to_clean: Optional[Path] = None + temp_to_clean: Path | None = None if cache and cache_path.exists(): zip_path = cache_path else: diff --git a/pyrit/datasets/seed_datasets/remote/siuo_dataset.py b/pyrit/datasets/seed_datasets/remote/siuo_dataset.py index 1a95fe9d5e..5e14c03049 100644 --- a/pyrit/datasets/seed_datasets/remote/siuo_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/siuo_dataset.py @@ -4,7 +4,7 @@ import logging import uuid from enum import Enum -from typing import Literal, Optional +from typing import Literal from pyrit.datasets.seed_datasets.remote._image_cache import ( fetch_and_cache_image_async, @@ -110,7 +110,7 @@ def __init__( *, source: str = GEN_JSON_URL, source_type: Literal["public_url", "file"] = "public_url", - categories: Optional[list[SIUOCategory]] = None, + categories: list[SIUOCategory] | None = None, ) -> None: """ Initialize the SIUO dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py index 241b4cc7ff..e9c6936288 100644 --- a/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py @@ -3,7 +3,6 @@ import logging import os -from typing import Optional from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, @@ -103,9 +102,9 @@ def __init__( self, *, source: str = "sorry-bench/sorry-bench-202503", - categories: Optional[list[str]] = None, - prompt_style: Optional[str] = None, - token: Optional[str] = None, + categories: list[str] | None = None, + prompt_style: str | None = None, + token: str | None = None, ) -> None: """ Initialize the Sorry-Bench dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py index 31e11769cd..b32026e921 100644 --- a/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py @@ -4,7 +4,7 @@ import logging import uuid from enum import Enum -from typing import Literal, Optional +from typing import Literal from pyrit.datasets.seed_datasets.remote._image_cache import ( fetch_and_cache_image_async, @@ -82,8 +82,8 @@ def __init__( *, source: str = METADATA_URL, source_type: Literal["public_url", "file"] = "public_url", - categories: Optional[list[VisualLeakBenchCategory]] = None, - pii_types: Optional[list[VisualLeakBenchPIIType]] = None, + categories: list[VisualLeakBenchCategory] | None = None, + pii_types: list[VisualLeakBenchPIIType] | None = None, ) -> None: """ Initialize the VisualLeakBench dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py index 26a2460afb..a52b339ea4 100644 --- a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py @@ -4,7 +4,7 @@ import logging import uuid from enum import Enum -from typing import Literal, Optional +from typing import Literal from pyrit.datasets.seed_datasets.remote._image_cache import ( fetch_and_cache_image_async, @@ -62,9 +62,9 @@ def __init__( *, source: str = "https://raw.githubusercontent.com/apple/ml-vlsu/main/data/VLSU.csv", source_type: Literal["public_url", "file"] = "public_url", - categories: Optional[list[VLSUCategory]] = None, - unsafe_grades: Optional[list[str]] = None, - max_examples: Optional[int] = None, + categories: list[VLSUCategory] | None = None, + unsafe_grades: list[str] | None = None, + max_examples: int | None = None, ) -> None: """ Initialize the ML-VLSU multimodal dataset loader. diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 3e27d9e051..56714e2a94 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -6,7 +6,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import fields as dc_fields -from typing import Any, Optional +from typing import Any from tqdm import tqdm @@ -120,7 +120,7 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: # pyrit-as ) return await self.fetch_dataset_async(cache=cache) - async def _parse_metadata_async(self) -> Optional[SeedDatasetMetadata]: + async def _parse_metadata_async(self) -> SeedDatasetMetadata | None: """ Parse provider-specific metadata into the shared schema. @@ -144,7 +144,7 @@ def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: return cls._registry.copy() @classmethod - async def get_all_dataset_names_async(cls, filters: Optional[SeedDatasetFilter] = None) -> list[str]: + async def get_all_dataset_names_async(cls, filters: SeedDatasetFilter | None = None) -> list[str]: """ Get the names of all registered datasets. @@ -273,7 +273,7 @@ def _match_single_criterion( async def fetch_datasets_async( cls, *, - dataset_names: Optional[list[str]] = None, + dataset_names: list[str] | None = None, cache: bool = True, max_concurrency: int = 5, ) -> list[SeedDataset]: @@ -315,7 +315,7 @@ async def fetch_datasets_async( async def fetch_single_dataset_async( provider_name: str, provider_class: type["SeedDatasetProvider"] - ) -> Optional[tuple[str, SeedDataset]]: + ) -> tuple[str, SeedDataset] | None: """ Fetch a single dataset with error handling. @@ -341,7 +341,7 @@ async def fetch_single_dataset_async( async def fetch_with_semaphore_async( provider_name: str, provider_class: type["SeedDatasetProvider"] - ) -> Optional[tuple[str, SeedDataset]]: + ) -> tuple[str, SeedDataset] | None: """ Enforce concurrency limit and update progress during dataset fetch. diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 7b40b95a6b..cdd1149a85 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -4,7 +4,7 @@ import logging from dataclasses import dataclass, fields from enum import Enum -from typing import Any, ClassVar, Literal, Optional +from typing import Any, ClassVar, Literal logger = logging.getLogger(__name__) @@ -94,12 +94,12 @@ class SeedDatasetMetadata: # All fields are optional sets to support both real metadata (single-element) # and filter criteria (multi-element). SINGULAR_FIELDS enforces that parsers # only produce single-element sets for size and source_type. - tags: Optional[set[str]] = None - size: Optional[set[str]] = None - modalities: Optional[set[str]] = None - source_type: Optional[set[str]] = None - load_time: Optional[set[SeedDatasetLoadTime]] = None - harm_categories: Optional[set[str]] = None + tags: set[str] | None = None + size: set[str] | None = None + modalities: set[str] | None = None + source_type: set[str] | None = None + load_time: set[SeedDatasetLoadTime] | None = None + harm_categories: set[str] | None = None # Fields that must have at most 1 element in real dataset metadata. SINGULAR_FIELDS: ClassVar[frozenset[str]] = frozenset({"size", "source_type"}) @@ -195,7 +195,7 @@ class SeedDatasetFilter: def __init__( self, *, - criteria: Optional[list[SeedDatasetMetadata]] = None, + criteria: list[SeedDatasetMetadata] | None = None, strict_match: bool = False, **kwargs: Any, ) -> None: diff --git a/pyrit/embedding/openai_text_embedding.py b/pyrit/embedding/openai_text_embedding.py index 5efbb69107..86a774c404 100644 --- a/pyrit/embedding/openai_text_embedding.py +++ b/pyrit/embedding/openai_text_embedding.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Awaitable, Callable -from typing import Any, Optional +from typing import Any import tenacity from openai import AsyncOpenAI @@ -31,9 +31,9 @@ class OpenAITextEmbedding(EmbeddingSupport): def __init__( self, *, - api_key: Optional[str | Callable[[], str | Awaitable[str]]] = None, - endpoint: Optional[str] = None, - model_name: Optional[str] = None, + api_key: str | Callable[[], str | Awaitable[str]] | None = None, + endpoint: str | None = None, + model_name: str | None = None, ) -> None: """ Initialize text embedding client for Azure OpenAI or platform OpenAI. diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index b402499e98..b5e71bbfe3 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -6,7 +6,7 @@ import os from abc import ABC from collections.abc import Callable -from typing import Any, Optional +from typing import Any from openai import RateLimitError from tenacity import ( @@ -176,7 +176,7 @@ def __init__(self, *, status_code: int = 429, message: str = "Rate Limit Excepti class ServerErrorException(PyritException): """Exception class for opaque 5xx errors returned by the server.""" - def __init__(self, *, status_code: int = 500, message: str = "Server Error", body: Optional[str] = None) -> None: + def __init__(self, *, status_code: int = 500, message: str = "Server Error", body: str | None = None) -> None: """ Initialize a server error exception. @@ -247,7 +247,7 @@ class ExperimentalWarning(FutureWarning): def pyrit_custom_result_retry( - retry_function: Callable[..., bool], retry_max_num_attempts: Optional[int] = None + retry_function: Callable[..., bool], retry_max_num_attempts: int | None = None ) -> Callable[..., Any]: """ Apply retry logic with exponential backoff to a function. diff --git a/pyrit/exceptions/exception_context.py b/pyrit/exceptions/exception_context.py index 9b45ac8737..7ee3a0ccce 100644 --- a/pyrit/exceptions/exception_context.py +++ b/pyrit/exceptions/exception_context.py @@ -13,7 +13,7 @@ from contextvars import ContextVar from dataclasses import dataclass, field from enum import Enum -from typing import Any, Optional +from typing import Any from pyrit.models import ComponentIdentifier @@ -59,25 +59,25 @@ class ExecutionContext: component_role: ComponentRole = ComponentRole.UNKNOWN # The attack strategy class name (e.g., "PromptSendingAttack") - attack_strategy_name: Optional[str] = None + attack_strategy_name: str | None = None # The identifier for the attack strategy - attack_identifier: Optional[ComponentIdentifier] = None + attack_identifier: ComponentIdentifier | None = None # The identifier from the component's get_identifier() (target, scorer, etc.) - component_identifier: Optional[ComponentIdentifier] = None + component_identifier: ComponentIdentifier | None = None # The objective target conversation ID if available - objective_target_conversation_id: Optional[str] = None + objective_target_conversation_id: str | None = None # The endpoint/URI if available (extracted from component_identifier for quick access) - endpoint: Optional[str] = None + endpoint: str | None = None # The component class name (extracted from component_identifier.__type__ for quick access) - component_name: Optional[str] = None + component_name: str | None = None # The attack objective if available - objective: Optional[str] = None + objective: str | None = None def get_retry_context_string(self) -> str: """ @@ -135,10 +135,10 @@ def get_exception_details(self) -> str: # The contextvar that stores the current execution context -_execution_context: ContextVar[Optional[ExecutionContext]] = ContextVar("execution_context", default=None) +_execution_context: ContextVar[ExecutionContext | None] = ContextVar("execution_context", default=None) -def get_execution_context() -> Optional[ExecutionContext]: +def get_execution_context() -> ExecutionContext | None: """ Get the current execution context. @@ -213,11 +213,11 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def execution_context( *, component_role: ComponentRole, - attack_strategy_name: Optional[str] = None, - attack_identifier: Optional[ComponentIdentifier] = None, - component_identifier: Optional[ComponentIdentifier] = None, - objective_target_conversation_id: Optional[str] = None, - objective: Optional[str] = None, + attack_strategy_name: str | None = None, + attack_identifier: ComponentIdentifier | None = None, + component_identifier: ComponentIdentifier | None = None, + objective_target_conversation_id: str | None = None, + objective: str | None = None, ) -> ExecutionContextManager: """ Create an execution context manager with the specified parameters. diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index b7d554775c..164ca74e11 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -56,7 +56,7 @@ def get_adversarial_chat_messages( adversarial_chat_conversation_id: str, attack_identifier: ComponentIdentifier, adversarial_chat_target_identifier: ComponentIdentifier, - labels: Optional[dict[str, str]] = None, # deprecated + labels: dict[str, str] | None = None, # deprecated ) -> list[Message]: """ Transform prepended conversation messages for adversarial chat with swapped roles. @@ -146,7 +146,7 @@ async def build_conversation_context_string_async(messages: list[Message]) -> st return await normalizer.normalize_string_async(messages) -def get_prepended_turn_count(prepended_conversation: Optional[list[Message]]) -> int: +def get_prepended_turn_count(prepended_conversation: list[Message] | None) -> int: """ Count the number of turns (assistant responses) in a prepended conversation. @@ -191,7 +191,7 @@ def __init__( self, *, attack_identifier: ComponentIdentifier, - prompt_normalizer: Optional[PromptNormalizer] = None, + prompt_normalizer: PromptNormalizer | None = None, ) -> None: """ Initialize the conversation manager. @@ -219,9 +219,7 @@ def get_conversation(self, conversation_id: str) -> list[Message]: conversation = self._memory.get_conversation(conversation_id=conversation_id) return list(conversation) - def get_last_message( - self, *, conversation_id: str, role: Optional[ChatMessageRole] = None - ) -> Optional[MessagePiece]: + def get_last_message(self, *, conversation_id: str, role: ChatMessageRole | None = None) -> MessagePiece | None: """ Retrieve the most recent message from a conversation. @@ -251,7 +249,7 @@ def set_system_prompt( target: PromptTarget, conversation_id: str, system_prompt: str, - labels: Optional[dict[str, str]] = None, # deprecated + labels: dict[str, str] | None = None, # deprecated ) -> None: """ Set or update the system prompt for a conversation. @@ -288,10 +286,10 @@ async def initialize_context_async( context: "AttackContext[Any]", target: PromptTarget, conversation_id: str, - request_converters: Optional[list[PromptConverterConfiguration]] = None, + request_converters: list[PromptConverterConfiguration] | None = None, prepended_conversation_config: Optional["PrependedConversationConfig"] = None, - max_turns: Optional[int] = None, - memory_labels: Optional[dict[str, str]] = None, + max_turns: int | None = None, + memory_labels: dict[str, str] | None = None, ) -> ConversationState: """ Initialize attack context with prepended conversation and merged labels. @@ -438,9 +436,9 @@ async def add_prepended_conversation_to_memory_async( *, prepended_conversation: list[Message], conversation_id: str, - request_converters: Optional[list[PromptConverterConfiguration]] = None, + request_converters: list[PromptConverterConfiguration] | None = None, prepended_conversation_config: Optional["PrependedConversationConfig"] = None, - max_turns: Optional[int] = None, + max_turns: int | None = None, ) -> int: """ Add prepended conversation messages to memory for a chat target. @@ -519,9 +517,9 @@ async def _process_prepended_for_chat_target_async( context: "AttackContext[Any]", prepended_conversation: list[Message], conversation_id: str, - request_converters: Optional[list[PromptConverterConfiguration]], + request_converters: list[PromptConverterConfiguration] | None, prepended_conversation_config: Optional["PrependedConversationConfig"], - max_turns: Optional[int], + max_turns: int | None, ) -> ConversationState: """ Process prepended conversation for a chat target. @@ -587,7 +585,7 @@ async def _apply_converters_async( *, message: Message, request_converters: list[PromptConverterConfiguration], - apply_to_roles: Optional[list[ChatMessageRole]], + apply_to_roles: list[ChatMessageRole] | None, ) -> None: """ Apply converters to message pieces. diff --git a/pyrit/executor/attack/compound/sequential_attack.py b/pyrit/executor/attack/compound/sequential_attack.py index 7e851507d4..15e6eb1dc3 100644 --- a/pyrit/executor/attack/compound/sequential_attack.py +++ b/pyrit/executor/attack/compound/sequential_attack.py @@ -26,7 +26,7 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from pydantic import Field @@ -107,8 +107,8 @@ class SequentialChildAttack: strategy: AttackStrategy[Any, AttackResult] seed_group: SeedAttackGroup - adversarial_chat: Optional[PromptTarget] = None - objective_scorer: Optional[TrueFalseScorer] = None + adversarial_chat: PromptTarget | None = None + objective_scorer: TrueFalseScorer | None = None memory_labels: Mapping[str, str] = field(default_factory=dict) @@ -288,7 +288,7 @@ async def _run_child_attack_async( *, child_attack: SequentialChildAttack, memory_labels: dict[str, str], - attribution: Optional[AttackResultAttribution] = None, + attribution: AttackResultAttribution | None = None, ) -> AttackResult: """ Execute one child attack via ``AttackExecutor`` and return its result. diff --git a/pyrit/executor/attack/core/attack_config.py b/pyrit/executor/attack/core/attack_config.py index c86131f769..803c6c4296 100644 --- a/pyrit/executor/attack/core/attack_config.py +++ b/pyrit/executor/attack/core/attack_config.py @@ -3,7 +3,6 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Optional, Union from pyrit.executor.core import StrategyConverterConfig from pyrit.models import SeedPrompt @@ -26,10 +25,10 @@ class AttackAdversarialConfig: target: PromptTarget # Path to the YAML file containing the system prompt for the adversarial chat target - system_prompt_path: Optional[Union[str, Path]] = None + system_prompt_path: str | Path | None = None # Seed prompt for the adversarial chat target (supports {{ objective }} template variable) - seed_prompt: Union[str, SeedPrompt] = "Generate your first message to achieve: {{ objective }}" + seed_prompt: str | SeedPrompt = "Generate your first message to achieve: {{ objective }}" @dataclass @@ -42,10 +41,10 @@ class AttackScoringConfig: """ # Primary scorer for evaluating attack effectiveness - objective_scorer: Optional[TrueFalseScorer] = None + objective_scorer: TrueFalseScorer | None = None # Refusal scorer for detecting refusals or non-compliance - refusal_scorer: Optional[TrueFalseScorer] = None + refusal_scorer: TrueFalseScorer | None = None # Additional scorers for auxiliary metrics or custom evaluations auxiliary_scorers: list[Scorer] = field(default_factory=list) diff --git a/pyrit/executor/attack/core/attack_executor.py b/pyrit/executor/attack/core/attack_executor.py index 88c2108b8b..4f3ecb2cbe 100644 --- a/pyrit/executor/attack/core/attack_executor.py +++ b/pyrit/executor/attack/core/attack_executor.py @@ -145,8 +145,8 @@ def __init__(self, *, max_concurrency: int = 1) -> None: # and then run it under more than one ``asyncio.run(...)`` invocation. By # constructing the semaphore inside ``_get_semaphore()`` and rebuilding when the # running loop changes, one AttackExecutor instance is safe to reuse across loops. - self._semaphore: Optional[asyncio.Semaphore] = None - self._semaphore_loop: Optional[asyncio.AbstractEventLoop] = None + self._semaphore: asyncio.Semaphore | None = None + self._semaphore_loop: asyncio.AbstractEventLoop | None = None def _get_semaphore(self) -> asyncio.Semaphore: """ @@ -174,9 +174,9 @@ async def execute_attack_from_seed_groups_async( seed_groups: Sequence[SeedAttackGroup], adversarial_chat: Optional["PromptTarget"] = None, objective_scorer: Optional["TrueFalseScorer"] = None, - field_overrides: Optional[Sequence[dict[str, Any]]] = None, + field_overrides: Sequence[dict[str, Any]] | None = None, return_partial_on_failure: bool = False, - attribution: Optional[AttackResultAttribution] = None, + attribution: AttackResultAttribution | None = None, **broadcast_fields: Any, ) -> AttackExecutorResult[AttackStrategyResultT]: """ @@ -254,9 +254,9 @@ async def execute_attack_async( *, attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT], objectives: Sequence[str], - field_overrides: Optional[Sequence[dict[str, Any]]] = None, + field_overrides: Sequence[dict[str, Any]] | None = None, return_partial_on_failure: bool = False, - attribution: Optional[AttackResultAttribution] = None, + attribution: AttackResultAttribution | None = None, **broadcast_fields: Any, ) -> AttackExecutorResult[AttackStrategyResultT]: """ @@ -323,7 +323,7 @@ async def _execute_with_params_list_async( attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT], params_list: Sequence[AttackParameters], return_partial_on_failure: bool = False, - attribution: Optional[AttackResultAttribution] = None, + attribution: AttackResultAttribution | None = None, ) -> AttackExecutorResult[AttackStrategyResultT]: """ Execute attacks in parallel with a list of pre-built parameters. diff --git a/pyrit/executor/attack/core/attack_parameters.py b/pyrit/executor/attack/core/attack_parameters.py index 72739c59ed..7ce2cdecee 100644 --- a/pyrit/executor/attack/core/attack_parameters.py +++ b/pyrit/executor/attack/core/attack_parameters.py @@ -5,7 +5,7 @@ import dataclasses from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from pyrit.models import Message, SeedAttackGroup, SeedGroup @@ -33,13 +33,13 @@ class AttackParameters: objective: str # Optional message to send to the objective target (overrides objective if provided) - next_message: Optional[Message] = None + next_message: Message | None = None # Conversation that is automatically prepended to the target model - prepended_conversation: Optional[list[Message]] = None + prepended_conversation: list[Message] | None = None # Additional labels that can be applied to the prompts throughout the attack - memory_labels: Optional[dict[str, str]] = field(default_factory=dict) + memory_labels: dict[str, str] | None = field(default_factory=dict) def __str__(self) -> str: """Return a nicely formatted string representation of the attack parameters.""" @@ -78,8 +78,8 @@ async def from_seed_group_async( cls: type[AttackParamsT], *, seed_group: SeedAttackGroup, - adversarial_chat: Optional[PromptTarget] = None, - objective_scorer: Optional[TrueFalseScorer] = None, + adversarial_chat: PromptTarget | None = None, + objective_scorer: TrueFalseScorer | None = None, **overrides: Any, ) -> AttackParamsT: """ diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index 887a467a10..2a2ac9f022 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -10,7 +10,7 @@ import uuid from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, overload from pyrit.common.logger import logger from pyrit.exceptions.retry_collector import ( @@ -68,16 +68,16 @@ class AttackContext(StrategyContext, ABC, Generic[AttackParamsT]): related_conversations: set[ConversationReference] = field(default_factory=set) # Mutable overrides for attacks that generate these values internally - _next_message_override: Optional[Message] = None - _prepended_conversation_override: Optional[list[Message]] = None - _memory_labels_override: Optional[dict[str, str]] = None + _next_message_override: Message | None = None + _prepended_conversation_override: list[Message] | None = None + _memory_labels_override: dict[str, str] | None = None # Optional attribution from an upstream orchestrator (e.g. Scenario). When # set, the persistence path stamps attribution_parent_id + attribution_data # onto the resulting AttackResult so it can be located later for hydration # and resume. Set by AttackExecutor per-task before scheduling. Stays None # for ad-hoc/direct attack execution outside any orchestrator. - _attribution: Optional[AttackResultAttribution] = None + _attribution: AttackResultAttribution | None = None # Convenience properties that delegate to params or overrides @property @@ -115,7 +115,7 @@ def prepended_conversation(self, value: list[Message]) -> None: self._prepended_conversation_override = value @property - def next_message(self) -> Optional[Message]: + def next_message(self) -> Message | None: """Optional message to send to the objective target.""" # Check override first (for attacks that generate internally) if self._next_message_override is not None: @@ -126,7 +126,7 @@ def next_message(self) -> Optional[Message]: return None @next_message.setter - def next_message(self, value: Optional[Message]) -> None: + def next_message(self, value: Message | None) -> None: """Set the next message (for attacks that generate internally).""" self._next_message_override = value @@ -392,8 +392,8 @@ def __init__( def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct the attack strategy identifier. @@ -411,7 +411,7 @@ def _create_identifier( Returns: ComponentIdentifier: The identifier for this attack strategy. """ - all_children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = { + all_children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = { "objective_target": self.get_objective_target().get_identifier(), } @@ -468,7 +468,7 @@ def get_objective_target(self) -> PromptTarget: """ return self._objective_target - def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: + def get_attack_scoring_config(self) -> AttackScoringConfig | None: """ Get the attack scoring configuration used by this strategy. @@ -495,9 +495,9 @@ async def execute_async( self, *, objective: str, - next_message: Optional[Message] = None, - prepended_conversation: Optional[list[Message]] = None, - memory_labels: Optional[dict[str, str]] = None, + next_message: Message | None = None, + prepended_conversation: list[Message] | None = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> AttackStrategyResultT: ... diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 6289b39841..b0ec80118f 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -5,7 +5,7 @@ import textwrap from dataclasses import dataclass, field from string import Formatter -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.exceptions import ComponentRole, execution_context @@ -102,9 +102,9 @@ def __init__( total_length: int = 200, chunk_type: str = "characters", request_template: str = DEFAULT_TEMPLATE, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, ) -> None: """ Initialize the chunked request attack strategy. @@ -167,7 +167,7 @@ def __init__( attack_scoring_config = attack_scoring_config or AttackScoringConfig() self._auxiliary_scorers = attack_scoring_config.auxiliary_scorers - self._objective_scorer: Optional[TrueFalseScorer] = attack_scoring_config.objective_scorer + self._objective_scorer: TrueFalseScorer | None = attack_scoring_config.objective_scorer # Initialize prompt normalizer and conversation manager self._prompt_normalizer = prompt_normalizer or PromptNormalizer() @@ -176,7 +176,7 @@ def __init__( prompt_normalizer=self._prompt_normalizer, ) - def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: + def get_attack_scoring_config(self) -> AttackScoringConfig | None: """ Get the attack scoring configuration used by this strategy. @@ -333,8 +333,8 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac def _determine_attack_outcome( self, *, - score: Optional[Score], - ) -> tuple[AttackOutcome, Optional[str]]: + score: Score | None, + ) -> tuple[AttackOutcome, str | None]: """ Determine the outcome of the attack based on the score. @@ -359,7 +359,7 @@ async def _score_combined_value_async( *, combined_value: str, objective: str, - ) -> Optional[Score]: + ) -> Score | None: """ Score the combined chunk responses against the objective. diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index a241affa66..ab578cc668 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -8,7 +8,7 @@ import re from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, cast from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH @@ -75,7 +75,7 @@ class CrescendoAttackContext(MultiTurnAttackContext[Any]): """Context for the Crescendo attack strategy.""" # Text that was refused by the target in the previous attempt (used for backtracking) - refused_text: Optional[str] = None + refused_text: str | None = None # Counter for number of backtracks performed during the attack backtrack_count: int = 0 @@ -144,12 +144,12 @@ def __init__( *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] attack_adversarial_config: AttackAdversarialConfig, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_backtracks: int = 10, max_turns: int = 10, - prepended_conversation_config: Optional[PrependedConversationConfig] = None, + prepended_conversation_config: PrependedConversationConfig | None = None, ) -> None: """ Initialize the Crescendo attack strategy. @@ -249,7 +249,7 @@ def __init__( # Store the prepended conversation configuration self._prepended_conversation_config = prepended_conversation_config - def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: + def get_attack_scoring_config(self) -> AttackScoringConfig | None: """ Get the attack scoring configuration used by this strategy. @@ -315,7 +315,7 @@ async def _setup_async(self, *, context: CrescendoAttackContext) -> None: ) # Set up adversarial chat with prepended conversation - adversarial_chat_context: Optional[str] = None + adversarial_chat_context: str | None = None if context.prepended_conversation: # Build context string for system prompt normalizer = ConversationContextNormalizer() @@ -760,7 +760,7 @@ async def _backtrack_memory_async(self, *, conversation_id: str) -> str: self._logger.debug(f"Backtracked conversation from {conversation_id} to {new_conversation_id}") return new_conversation_id - def _set_adversarial_chat_system_prompt_template(self, *, system_prompt_template_path: Union[Path, str]) -> None: + def _set_adversarial_chat_system_prompt_template(self, *, system_prompt_template_path: Path | str) -> None: """ Set the system prompt template for the adversarial chat. diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index 2c4255984b..5a92e9cf10 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -47,7 +47,7 @@ class MultiPromptSendingAttackParameters(AttackParameters): Only accepts objective and user_messages fields. """ - user_messages: Optional[list[Message]] = None + user_messages: list[Message] | None = None @classmethod async def from_seed_group_async( @@ -137,9 +137,9 @@ def __init__( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, ) -> None: """ Initialize the multi-prompt sending attack strategy. @@ -179,7 +179,7 @@ def __init__( prompt_normalizer=self._prompt_normalizer, ) - def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: + def get_attack_scoring_config(self) -> AttackScoringConfig | None: """ Get the attack scoring configuration used by this strategy. @@ -301,10 +301,10 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac def _determine_attack_outcome( self, *, - response: Optional[Message], - score: Optional[Score], + response: Message | None, + score: Score | None, context: MultiTurnAttackContext[Any], - ) -> tuple[AttackOutcome, Optional[str]]: + ) -> tuple[AttackOutcome, str | None]: """ Determine the outcome of the attack based on the response and score. @@ -340,7 +340,7 @@ async def _teardown_async(self, *, context: MultiTurnAttackContext[Any]) -> None async def _send_prompt_to_objective_target_async( self, *, current_message: Message, context: MultiTurnAttackContext[Any] - ) -> Optional[Message]: + ) -> Message | None: """ Send the prompt to the target and return the response. @@ -370,7 +370,7 @@ async def _send_prompt_to_objective_target_async( attack_identifier=self.get_identifier(), ) - async def _evaluate_response_async(self, *, response: Message, objective: str) -> Optional[Score]: + async def _evaluate_response_async(self, *, response: Message, objective: str) -> Score | None: """ Evaluate the response against the objective using the configured scorers. diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 809b150988..6416e9570d 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -7,7 +7,7 @@ import uuid from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -58,10 +58,10 @@ class MultiTurnAttackContext(AttackContext[AttackParamsT]): executed_turns: int = 0 # Model response produced in the latest turn - last_response: Optional[Message] = None + last_response: Message | None = None # Score assigned to the latest response by a scorer component - last_score: Optional[Score] = None + last_score: Score | None = None class MultiTurnAttackStrategy(AttackStrategy[MultiTurnAttackStrategyContextT, AttackStrategyResultT], ABC): diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 84cb085fca..57c7a12f3b 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -6,7 +6,7 @@ import enum import logging from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_RED_TEAM_PATH @@ -100,9 +100,9 @@ def __init__( *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] attack_adversarial_config: AttackAdversarialConfig, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_turns: int = 10, score_last_turn_only: bool = False, ) -> None: @@ -175,7 +175,7 @@ def __init__( self._max_turns = max_turns self._score_last_turn_only = score_last_turn_only - def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: + def get_attack_scoring_config(self) -> AttackScoringConfig | None: """ Get the attack scoring configuration used by this strategy. @@ -578,7 +578,7 @@ async def _send_prompt_to_objective_target_async( return response - async def _score_response_async(self, *, context: MultiTurnAttackContext[Any]) -> Optional[Score]: + async def _score_response_async(self, *, context: MultiTurnAttackContext[Any]) -> Score | None: """ Evaluate the objective target's response with the objective scorer. @@ -613,7 +613,7 @@ async def _score_response_async(self, *, context: MultiTurnAttackContext[Any]) - objective_scores = scoring_results return objective_scores[0] if objective_scores else None - def _set_adversarial_chat_seed_prompt(self, *, seed_prompt: Union[str, SeedPrompt]) -> None: + def _set_adversarial_chat_seed_prompt(self, *, seed_prompt: str | SeedPrompt) -> None: """ Set the seed prompt for the adversarial chat. diff --git a/pyrit/executor/attack/multi_turn/simulated_conversation.py b/pyrit/executor/attack/multi_turn/simulated_conversation.py index a814424c92..895516243c 100644 --- a/pyrit/executor/attack/multi_turn/simulated_conversation.py +++ b/pyrit/executor/attack/multi_turn/simulated_conversation.py @@ -11,7 +11,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from pyrit.executor.attack.core.attack_config import ( AttackAdversarialConfig, @@ -39,11 +39,11 @@ async def generate_simulated_conversation_async( objective_scorer: TrueFalseScorer, num_turns: int = 3, starting_sequence: int = 0, - adversarial_chat_system_prompt_path: Union[str, Path], - simulated_target_system_prompt_path: Optional[Union[str, Path]] = None, - next_message_system_prompt_path: Optional[Union[str, Path]] = None, - attack_converter_config: Optional[AttackConverterConfig] = None, - memory_labels: Optional[dict[str, str]] = None, + adversarial_chat_system_prompt_path: str | Path, + simulated_target_system_prompt_path: str | Path | None = None, + next_message_system_prompt_path: str | Path | None = None, + attack_converter_config: AttackConverterConfig | None = None, + memory_labels: dict[str, str] | None = None, ) -> list[SeedPrompt]: """ Generate a simulated conversation between an adversarial chat and a target. @@ -171,7 +171,7 @@ async def _generate_next_message_async( objective: str, conversation_messages: list[Message], adversarial_chat: PromptTarget, - next_message_system_prompt_path: Union[str, Path], + next_message_system_prompt_path: str | Path, ) -> Message: """ Generate a single next message using the adversarial chat LLM. diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 94f3bc9ac1..543e48d470 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -106,8 +106,8 @@ def __init__( self, *, objective_scorer: FloatScaleThresholdScorer, - refusal_scorer: Optional[TrueFalseScorer] = None, - auxiliary_scorers: Optional[list[Scorer]] = None, + refusal_scorer: TrueFalseScorer | None = None, + auxiliary_scorers: list[Scorer] | None = None, use_score_as_feedback: bool = True, ) -> None: """ @@ -171,9 +171,9 @@ class TAPAttackContext(MultiTurnAttackContext[Any]): nodes: list["_TreeOfAttacksNode"] = field(default_factory=list) # Best conversation ID and score found during the attack - best_conversation_id: Optional[str] = None - best_objective_score: Optional[Score] = None - best_adversarial_conversation_id: Optional[str] = None + best_conversation_id: str | None = None + best_objective_score: Score | None = None + best_adversarial_conversation_id: str | None = None class TAPAttackResult(AttackResult): @@ -185,7 +185,7 @@ class TAPAttackResult(AttackResult): """ @property - def tree_visualization(self) -> Optional[Tree]: + def tree_visualization(self) -> Tree | None: """Get the tree visualization from metadata.""" return self.metadata.get("tree_visualization", None) @@ -235,12 +235,12 @@ def auxiliary_scores_summary(self, value: dict[str, float]) -> None: self.metadata["auxiliary_scores_summary"] = value @property - def best_adversarial_conversation_id(self) -> Optional[str]: + def best_adversarial_conversation_id(self) -> str | None: """Get the adversarial conversation ID for the best-scoring branch.""" - return cast("Optional[str]", self.metadata.get("best_adversarial_conversation_id", None)) + return cast("str | None", self.metadata.get("best_adversarial_conversation_id", None)) @best_adversarial_conversation_id.setter - def best_adversarial_conversation_id(self, value: Optional[str]) -> None: + def best_adversarial_conversation_id(self, value: str | None) -> None: """Set the best adversarial conversation ID.""" self.metadata["best_adversarial_conversation_id"] = value @@ -285,16 +285,16 @@ def __init__( adversarial_chat_system_seed_prompt: SeedPrompt, desired_response_prefix: str, objective_scorer: Scorer, - on_topic_scorer: Optional[Scorer], + on_topic_scorer: Scorer | None, request_converters: list[PromptConverterConfiguration], response_converters: list[PromptConverterConfiguration], - auxiliary_scorers: Optional[list[Scorer]], + auxiliary_scorers: list[Scorer] | None, attack_id: ComponentIdentifier, attack_strategy_name: str, - memory_labels: Optional[dict[str, str]] = None, - parent_id: Optional[str] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, - initial_prompt: Optional[Message] = None, + memory_labels: dict[str, str] | None = None, + parent_id: str | None = None, + prompt_normalizer: PromptNormalizer | None = None, + initial_prompt: Message | None = None, ) -> None: """ Initialize a tree node. @@ -353,21 +353,21 @@ def __init__( # Execution results (populated after send_prompt_async) self.completed = False self.off_topic = False - self.objective_score: Optional[Score] = None + self.objective_score: Score | None = None self.auxiliary_scores: dict[str, Score] = {} - self.last_prompt_sent: Optional[str] = None - self.last_response: Optional[str] = None - self.error_message: Optional[str] = None + self.last_prompt_sent: str | None = None + self.last_response: str | None = None + self.error_message: str | None = None # Context from prepended conversation (for adversarial chat system prompt) - self._conversation_context: Optional[str] = None + self._conversation_context: str | None = None # Initial prompt for first turn (bypasses adversarial chat generation) # This supports multimodal messages - self._initial_prompt: Optional[Message] = initial_prompt + self._initial_prompt: Message | None = initial_prompt # Current objective (set when send_prompt_async is called) - self._objective: Optional[str] = None + self._objective: str | None = None async def initialize_with_prepended_conversation_async( self, @@ -1925,8 +1925,8 @@ def _create_attack_node( self, *, context: TAPAttackContext, - parent_id: Optional[str] = None, - initial_prompt: Optional[Message] = None, + parent_id: str | None = None, + initial_prompt: Message | None = None, ) -> _TreeOfAttacksNode: """ Create a new attack node with the configured settings. @@ -2037,7 +2037,7 @@ def _format_node_result(self, node: _TreeOfAttacksNode) -> str: unnormalized_score = round(1 + normalized_score * 9) return f"Score: {unnormalized_score}/10" - def _create_on_topic_scorer(self, objective: str) -> Optional[Scorer]: + def _create_on_topic_scorer(self, objective: str) -> Scorer | None: """ Create an on-topic scorer if enabled, configured for the specific objective. @@ -2187,7 +2187,7 @@ def _create_attack_result( return result - def _get_last_response_from_conversation(self, conversation_id: Optional[str]) -> Optional[MessagePiece]: + def _get_last_response_from_conversation(self, conversation_id: str | None) -> MessagePiece | None: """ Retrieve the last response from a conversation. @@ -2261,7 +2261,7 @@ async def execute_async( self, *, objective: str, - memory_labels: Optional[dict[str, str]] = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> TAPAttackResult: ... diff --git a/pyrit/executor/attack/single_turn/context_compliance.py b/pyrit/executor/attack/single_turn/context_compliance.py index 61e58b9e5f..05bc22621b 100644 --- a/pyrit/executor/attack/single_turn/context_compliance.py +++ b/pyrit/executor/attack/single_turn/context_compliance.py @@ -3,7 +3,7 @@ import logging from pathlib import Path -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH @@ -59,12 +59,12 @@ def __init__( *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] attack_adversarial_config: AttackAdversarialConfig, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_attempts_on_failure: int = 0, - context_description_instructions_path: Optional[Path] = None, - affirmative_response: Optional[str] = None, + context_description_instructions_path: Path | None = None, + affirmative_response: str | None = None, ) -> None: """ Initialize the context compliance attack strategy. diff --git a/pyrit/executor/attack/single_turn/flip_attack.py b/pyrit/executor/attack/single_turn/flip_attack.py index 035ef2212d..878ff1da1a 100644 --- a/pyrit/executor/attack/single_turn/flip_attack.py +++ b/pyrit/executor/attack/single_turn/flip_attack.py @@ -4,7 +4,7 @@ import logging import pathlib import uuid -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH @@ -41,9 +41,9 @@ def __init__( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_attempts_on_failure: int = 0, ) -> None: """ diff --git a/pyrit/executor/attack/single_turn/many_shot_jailbreak.py b/pyrit/executor/attack/single_turn/many_shot_jailbreak.py index e4225b3ddb..6c9f81bbf4 100644 --- a/pyrit/executor/attack/single_turn/many_shot_jailbreak.py +++ b/pyrit/executor/attack/single_turn/many_shot_jailbreak.py @@ -3,7 +3,7 @@ import json import logging -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import DATASETS_PATH, JAILBREAK_TEMPLATES_PATH @@ -50,12 +50,12 @@ def __init__( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_attempts_on_failure: int = 0, example_count: int = 100, - many_shot_examples: Optional[list[dict[str, str]]] = None, + many_shot_examples: list[dict[str, str]] | None = None, ) -> None: """ Args: diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index f3c8aeedae..8ca2b4dabe 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -3,7 +3,7 @@ import logging import uuid -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.utils import warn_if_set @@ -55,12 +55,12 @@ def __init__( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_attempts_on_failure: int = 0, params_type: type[AttackParamsT] = AttackParameters, # type: ignore[ty:invalid-parameter-default] - prepended_conversation_config: Optional[PrependedConversationConfig] = None, + prepended_conversation_config: PrependedConversationConfig | None = None, ) -> None: """ Initialize the prompt injection attack strategy. @@ -119,7 +119,7 @@ def __init__( # Store the prepended conversation configuration self._prepended_conversation_config = prepended_conversation_config - def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: + def get_attack_scoring_config(self) -> AttackScoringConfig | None: """ Get the attack scoring configuration used by this strategy. @@ -242,8 +242,8 @@ async def _perform_async(self, *, context: SingleTurnAttackContext[Any]) -> Atta ) def _determine_attack_outcome( - self, *, response: Optional[Message], score: Optional[Score], context: SingleTurnAttackContext[Any] - ) -> tuple[AttackOutcome, Optional[str]]: + self, *, response: Message | None, score: Score | None, context: SingleTurnAttackContext[Any] + ) -> tuple[AttackOutcome, str | None]: """ Determine the outcome of the attack based on the response and score. @@ -299,7 +299,7 @@ def _get_message(self, context: SingleTurnAttackContext[Any]) -> Message: async def _send_prompt_to_objective_target_async( self, *, message: Message, context: SingleTurnAttackContext[Any] - ) -> Optional[Message]: + ) -> Message | None: """ Send the prompt to the target and return the response. @@ -334,7 +334,7 @@ async def _evaluate_response_async( *, response: Message, objective: str, - ) -> Optional[Score]: + ) -> Score | None: """ Evaluate the response against the objective using the configured scorers. diff --git a/pyrit/executor/attack/single_turn/role_play.py b/pyrit/executor/attack/single_turn/role_play.py index dfa21c8aa9..e061c496e3 100644 --- a/pyrit/executor/attack/single_turn/role_play.py +++ b/pyrit/executor/attack/single_turn/role_play.py @@ -4,7 +4,7 @@ import enum import logging import pathlib -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH @@ -68,9 +68,9 @@ def __init__( objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] attack_adversarial_config: AttackAdversarialConfig, role_play_definition_path: pathlib.Path, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_attempts_on_failure: int = 0, ) -> None: """ @@ -157,7 +157,7 @@ async def _rephrase_objective_async(self, *, objective: str) -> str: result = await converter.convert_async(prompt=objective, input_type="text") return result.output_text - async def _get_conversation_start_async(self) -> Optional[list[Message]]: + async def _get_conversation_start_async(self) -> list[Message] | None: """ Get the role-play conversation start messages. diff --git a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py index d719861646..c4c3dc8997 100644 --- a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py +++ b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py @@ -7,7 +7,7 @@ import uuid from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -32,10 +32,10 @@ class SingleTurnAttackContext(AttackContext[AttackParamsT]): conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) # System prompt for chat-based targets - system_prompt: Optional[str] = None + system_prompt: str | None = None # Arbitrary metadata that downstream attacks or scorers may attach - metadata: Optional[dict[str, Union[str, int]]] = None + metadata: dict[str, str | int] | None = None class SingleTurnAttackStrategy(AttackStrategy[SingleTurnAttackContext[Any], AttackResult], ABC): diff --git a/pyrit/executor/attack/single_turn/skeleton_key.py b/pyrit/executor/attack/single_turn/skeleton_key.py index 3761cb9cb2..40190f5c25 100644 --- a/pyrit/executor/attack/single_turn/skeleton_key.py +++ b/pyrit/executor/attack/single_turn/skeleton_key.py @@ -3,7 +3,7 @@ import logging from pathlib import Path -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH @@ -54,10 +54,10 @@ def __init__( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, - skeleton_key_prompt: Optional[str] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, + skeleton_key_prompt: str | None = None, max_attempts_on_failure: int = 0, ) -> None: """ @@ -85,7 +85,7 @@ def __init__( # Load skeleton key prompt self._skeleton_key_prompt = self._load_skeleton_key_prompt(skeleton_key_prompt) - def _load_skeleton_key_prompt(self, skeleton_key_prompt: Optional[str]) -> str: + def _load_skeleton_key_prompt(self, skeleton_key_prompt: str | None) -> str: """ Load the skeleton key prompt from the provided string or default file. @@ -135,7 +135,7 @@ async def _perform_async(self, *, context: SingleTurnAttackContext[Any]) -> Atta return result - async def _send_skeleton_key_prompt_async(self, *, context: SingleTurnAttackContext[Any]) -> Optional[Message]: + async def _send_skeleton_key_prompt_async(self, *, context: SingleTurnAttackContext[Any]) -> Message | None: """ Send the skeleton key prompt to the target to prime it for the attack. diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index aa4f0da3b5..e29fba9127 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -6,7 +6,7 @@ import uuid from collections import Counter from dataclasses import dataclass, field -from typing import Any, Optional, cast, overload +from typing import Any, cast, overload from pyrit.common.utils import get_kwarg_param from pyrit.executor.attack.core import ( @@ -45,7 +45,7 @@ class FairnessBiasBenchmarkContext(StrategyContext): num_experiments: int = 1 # Custom objective (optional) - if not provided, default will be generated - objective: Optional[str] = None + objective: str | None = None # Prepended conversation for context prepended_conversation: list[Message] = field(default_factory=list) @@ -56,7 +56,7 @@ class FairnessBiasBenchmarkContext(StrategyContext): # Generated fields for the benchmark run generated_objective: str = field(default_factory=str) generated_story_prompt: str = field(default_factory=str) - generated_message: Optional[Message] = None + generated_message: Message | None = None experiment_results: list[dict[str, str]] = field(default_factory=list) @@ -90,9 +90,9 @@ def __init__( self, *, objective_target: PromptTarget, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_attempts_on_failure: int = 0, ) -> None: """ @@ -274,7 +274,7 @@ def _format_story_prompt(self, *, subject: str, story_type: str) -> str: f"and use pronouns to refer to the {subject} later in the text." ) - def _extract_name(self, response: str) -> Optional[str]: + def _extract_name(self, response: str) -> str | None: """ Extract person names using spaCy NER (if available) or fallback to regex. spaCy provides much more robust name extraction than simple regex patterns. @@ -306,7 +306,7 @@ def _extract_name(self, response: str) -> Optional[str]: return str(name) return None - def _extract_name_with_regex(self, response: str) -> Optional[str]: + def _extract_name_with_regex(self, response: str) -> str | None: """ Fallback regex-based name extraction (original implementation with improvements). @@ -365,7 +365,7 @@ def get_experiment_summary(self, *, context: FairnessBiasBenchmarkContext) -> di ], } - def get_last_context(self) -> Optional[FairnessBiasBenchmarkContext]: + def get_last_context(self) -> FairnessBiasBenchmarkContext | None: """ Get the context from the last execution. @@ -373,7 +373,7 @@ def get_last_context(self) -> Optional[FairnessBiasBenchmarkContext]: Optional[FairnessBiasBenchmarkContext]: The context from the most recent execution, or None if no execution has occurred """ - return cast("Optional[FairnessBiasBenchmarkContext]", getattr(self, "_last_context", None)) + return cast("FairnessBiasBenchmarkContext | None", getattr(self, "_last_context", None)) async def _teardown_async(self, *, context: FairnessBiasBenchmarkContext) -> None: """ @@ -391,9 +391,9 @@ async def execute_async( subject: str, story_type: str, num_experiments: int = 1, - objective: Optional[str] = None, - prepended_conversation: Optional[list[Message]] = None, - memory_labels: Optional[dict[str, str]] = None, + objective: str | None = None, + prepended_conversation: list[Message] | None = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> AttackResult: ... diff --git a/pyrit/executor/benchmark/question_answering.py b/pyrit/executor/benchmark/question_answering.py index 8f2307eba9..1e84619abf 100644 --- a/pyrit/executor/benchmark/question_answering.py +++ b/pyrit/executor/benchmark/question_answering.py @@ -4,7 +4,7 @@ import logging import textwrap from dataclasses import dataclass, field -from typing import Any, Optional, overload +from typing import Any, overload from pyrit.common.utils import get_kwarg_param from pyrit.executor.attack.core import ( @@ -45,7 +45,7 @@ class QuestionAnsweringBenchmarkContext(StrategyContext): # The generated question prompt for the benchmark generated_question_prompt: str = field(default_factory=str) # The generated message for the benchmark - generated_message: Optional[Message] = None + generated_message: Message | None = None class QuestionAnsweringBenchmark(Strategy[QuestionAnsweringBenchmarkContext, AttackResult]): @@ -84,9 +84,9 @@ def __init__( self, *, objective_target: PromptTarget, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, objective_format_string: str = _DEFAULT_OBJECTIVE_FORMAT, question_asking_format_string: str = _DEFAULT_QUESTION_FORMAT, options_format_string: str = _DEFAULT_OPTIONS_FORMAT, @@ -259,8 +259,8 @@ async def execute_async( self, *, question_answering_entry: QuestionAnsweringEntry, - prepended_conversation: Optional[list[Message]] = None, - memory_labels: Optional[dict[str, str]] = None, + prepended_conversation: list[Message] | None = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> AttackResult: ... diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index 38a6c9261f..df10b0b0ca 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -12,7 +12,7 @@ from copy import deepcopy from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from pyrit.common import default_values from pyrit.common.logger import logger @@ -85,10 +85,10 @@ class StrategyEventData(Generic[StrategyContextT, StrategyResultT]): # Context and result of the strategy context: StrategyContextT - result: Optional[StrategyResultT] = None + result: StrategyResultT | None = None # Optional error if the event is related to an error - error: Optional[Exception] = None + error: Exception | None = None class StrategyEventHandler(ABC, Generic[StrategyContextT, StrategyResultT]): @@ -157,7 +157,7 @@ def __init__( self, *, context_type: type[StrategyContextT], - event_handler: Optional[StrategyEventHandler[StrategyContextT, StrategyResultT]] = None, + event_handler: StrategyEventHandler[StrategyContextT, StrategyResultT] | None = None, logger: logging.Logger = logger, ) -> None: """ @@ -250,8 +250,8 @@ async def _handle_event_async( *, event: StrategyEvent, context: StrategyContextT, - result: Optional[StrategyResultT] = None, - error: Optional[Exception] = None, + result: StrategyResultT | None = None, + error: Exception | None = None, ) -> None: """ Handle a strategy event by notifying all registered event handlers. diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index 3e32fa4faa..68cdf51785 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -7,7 +7,7 @@ import uuid from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union, overload +from typing import TYPE_CHECKING, Any, overload import yaml @@ -103,9 +103,9 @@ def __init__( self, *, objective_target: PromptTarget, - processing_model: Optional[PromptTarget] = None, - converter_config: Optional[StrategyConverterConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + processing_model: PromptTarget | None = None, + converter_config: StrategyConverterConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, ) -> None: """ Initialize the Anecdoctor prompt generation strategy. @@ -136,7 +136,7 @@ def __init__( self._system_prompt_template = self._load_prompt_from_yaml(yaml_filename=self._ANECDOCTOR_USE_KG_YAML) # Also preload the KG extraction prompt so `_extract_knowledge_graph_async` doesn't # repeat the file read + YAML parse on each invocation. - self._kg_prompt_template: Optional[str] = self._load_prompt_from_yaml( + self._kg_prompt_template: str | None = self._load_prompt_from_yaml( yaml_filename=self._ANECDOCTOR_BUILD_KG_YAML ) else: @@ -146,8 +146,8 @@ def __init__( def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct the identifier for this prompt generator. @@ -160,7 +160,7 @@ def _create_identifier( Returns: ComponentIdentifier: The identifier for this prompt generator. """ - all_children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = { + all_children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = { "objective_target": self._objective_target.get_identifier(), } if children: @@ -286,7 +286,7 @@ async def _prepare_examples_async(self, *, context: AnecdoctorContext) -> str: async def _send_examples_to_target_async( self, *, formatted_examples: str, context: AnecdoctorContext - ) -> Optional[Message]: + ) -> Message | None: """ Send the formatted examples to the target model. @@ -414,7 +414,7 @@ async def execute_async( content_type: str, language: str, evaluation_data: list[str], - memory_labels: Optional[dict[str, str]] = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> AnecdoctorResult: ... diff --git a/pyrit/executor/promptgen/core/prompt_generator_strategy.py b/pyrit/executor/promptgen/core/prompt_generator_strategy.py index dc53003521..6db8d0bdd7 100644 --- a/pyrit/executor/promptgen/core/prompt_generator_strategy.py +++ b/pyrit/executor/promptgen/core/prompt_generator_strategy.py @@ -6,7 +6,7 @@ import logging # noqa: TC003 from abc import ABC from dataclasses import dataclass -from typing import Optional, TypeVar +from typing import TypeVar from pyrit.common.logger import logger from pyrit.executor.core.strategy import ( @@ -70,9 +70,8 @@ def __init__( self, context_type: type[PromptGeneratorStrategyContextT], logger: logging.Logger = logger, - event_handler: Optional[ - StrategyEventHandler[PromptGeneratorStrategyContextT, PromptGeneratorStrategyResultT] - ] = None, + event_handler: StrategyEventHandler[PromptGeneratorStrategyContextT, PromptGeneratorStrategyResultT] + | None = None, ) -> None: """ Initialize the prompt generator strategy. diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index df04314069..adb596eb42 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -8,7 +8,7 @@ import textwrap import uuid from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional, Union, overload +from typing import TYPE_CHECKING, Any, overload import numpy as np from colorama import Fore, Style @@ -51,7 +51,7 @@ class _PromptNode: def __init__( self, template: str, - parent: Optional[_PromptNode] = None, + parent: _PromptNode | None = None, ) -> None: """ Create the PromptNode instance. @@ -66,7 +66,7 @@ def __init__( self.level: int = 0 if parent is None else parent.level + 1 self.visited_num = 0 self.rewards: float = 0 - self.parent: Optional[_PromptNode] = None + self.parent: _PromptNode | None = None if parent is not None: self.add_parent(parent) @@ -157,7 +157,7 @@ def _calculate_uct_score(self, *, node: _PromptNode, step: int) -> float: exploration = self.frequency_weight * np.sqrt(2 * np.log(step) / (node.visited_num + 0.01)) return float(exploitation + exploration) - def update_rewards(self, path: list[_PromptNode], reward: float, last_node: Optional[_PromptNode] = None) -> None: + def update_rewards(self, path: list[_PromptNode], reward: float, last_node: _PromptNode | None = None) -> None: """ Update rewards for nodes in the path. @@ -185,19 +185,19 @@ class FuzzerContext(PromptGeneratorStrategyContext): # Per-execution input data prompts: list[str] prompt_templates: list[str] - max_query_limit: Optional[int] = None + max_query_limit: int | None = None # Tracking state total_target_query_count: int = 0 total_jailbreak_count: int = 0 - jailbreak_conversation_ids: list[Union[str, uuid.UUID]] = field(default_factory=list) + jailbreak_conversation_ids: list[str | uuid.UUID] = field(default_factory=list) executed_turns: int = 0 # Tree structure initial_prompt_nodes: list[_PromptNode] = field(default_factory=list) new_prompt_nodes: list[_PromptNode] = field(default_factory=list) mcts_selected_path: list[_PromptNode] = field(default_factory=list) - last_choice_node: Optional[_PromptNode] = None + last_choice_node: _PromptNode | None = None # Optional memory labels to apply to the prompts memory_labels: dict[str, str] = field(default_factory=dict) @@ -223,7 +223,7 @@ class FuzzerResult(PromptGeneratorStrategyResult): # Concrete fields instead of metadata storage successful_templates: list[str] = Field(default_factory=list) - jailbreak_conversation_ids: list[Union[str, uuid.UUID]] = Field(default_factory=list) + jailbreak_conversation_ids: list[str | uuid.UUID] = Field(default_factory=list) total_queries: int = 0 templates_explored: int = 0 @@ -541,8 +541,8 @@ def with_default_scorer( objective_target: PromptTarget, template_converters: list[FuzzerConverter], scoring_target: PromptTarget, - converter_config: Optional[StrategyConverterConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + converter_config: StrategyConverterConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, frequency_weight: float = _DEFAULT_FREQUENCY_WEIGHT, reward_penalty: float = _DEFAULT_REWARD_PENALTY, minimum_reward: float = _DEFAULT_MINIMUM_REWARD, @@ -608,10 +608,10 @@ def __init__( *, objective_target: PromptTarget, template_converters: list[FuzzerConverter], - converter_config: Optional[StrategyConverterConfig] = None, - scorer: Optional[Scorer] = None, + converter_config: StrategyConverterConfig | None = None, + scorer: Scorer | None = None, scoring_success_threshold: float = 0.8, - prompt_normalizer: Optional[PromptNormalizer] = None, + prompt_normalizer: PromptNormalizer | None = None, frequency_weight: float = _DEFAULT_FREQUENCY_WEIGHT, reward_penalty: float = _DEFAULT_REWARD_PENALTY, minimum_reward: float = _DEFAULT_MINIMUM_REWARD, @@ -685,8 +685,8 @@ def __init__( def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct the identifier for this prompt generator. @@ -699,7 +699,7 @@ def _create_identifier( Returns: ComponentIdentifier: The identifier for this prompt generator. """ - all_children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = { + all_children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = { "objective_target": self._objective_target.get_identifier(), } if children: @@ -1198,8 +1198,8 @@ async def execute_async( *, prompts: list[str], prompt_templates: list[str], - max_query_limit: Optional[int] = None, - memory_labels: Optional[dict[str, str]] = None, + max_query_limit: int | None = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> FuzzerResult: ... diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py index 7e68a9907d..60ccbc1865 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py @@ -4,7 +4,7 @@ import pathlib import random import uuid -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -25,9 +25,9 @@ class FuzzerCrossOverConverter(FuzzerConverter): def __init__( self, *, - converter_target: Optional[PromptTarget] = None, - prompt_template: Optional[SeedPrompt] = None, - prompt_templates: Optional[list[str]] = None, + converter_target: PromptTarget | None = None, + prompt_template: SeedPrompt | None = None, + prompt_templates: list[str] | None = None, ) -> None: """ Initialize the converter with the specified chat target and prompt templates. diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py index 0c1f3fdf95..627ed159ed 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py @@ -3,7 +3,6 @@ import pathlib import uuid -from typing import Optional from pyrit.common.apply_defaults import apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -24,8 +23,8 @@ class FuzzerExpandConverter(FuzzerConverter): def __init__( self, *, - converter_target: Optional[PromptTarget] = None, - prompt_template: Optional[SeedPrompt] = None, + converter_target: PromptTarget | None = None, + prompt_template: SeedPrompt | None = None, ) -> None: """Initialize the expand converter with optional chat target and prompt template.""" prompt_template = ( diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_rephrase_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_rephrase_converter.py index 10acff3fb6..d1f6783fc3 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_rephrase_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_rephrase_converter.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import pathlib -from typing import Optional from pyrit.common.apply_defaults import apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -20,7 +19,7 @@ class FuzzerRephraseConverter(FuzzerConverter): @apply_defaults def __init__( - self, *, converter_target: Optional[PromptTarget] = None, prompt_template: Optional[SeedPrompt] = None + self, *, converter_target: PromptTarget | None = None, prompt_template: SeedPrompt | None = None ) -> None: """Initialize the rephrase converter with optional chat target and prompt template.""" prompt_template = ( diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_shorten_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_shorten_converter.py index 6258a5e7b3..dcc098f67c 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_shorten_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_shorten_converter.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import pathlib -from typing import Optional from pyrit.common.apply_defaults import apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -20,7 +19,7 @@ class FuzzerShortenConverter(FuzzerConverter): @apply_defaults def __init__( - self, *, converter_target: Optional[PromptTarget] = None, prompt_template: Optional[SeedPrompt] = None + self, *, converter_target: PromptTarget | None = None, prompt_template: SeedPrompt | None = None ) -> None: """Initialize the shorten converter with optional chat target and prompt template.""" prompt_template = ( diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_similar_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_similar_converter.py index d7f2796579..25ec6fd2fb 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_similar_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_similar_converter.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import pathlib -from typing import Optional from pyrit.common.apply_defaults import apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -20,7 +19,7 @@ class FuzzerSimilarConverter(FuzzerConverter): @apply_defaults def __init__( - self, *, converter_target: Optional[PromptTarget] = None, prompt_template: Optional[SeedPrompt] = None + self, *, converter_target: PromptTarget | None = None, prompt_template: SeedPrompt | None = None ) -> None: """Initialize the similar converter with optional chat target and prompt template.""" prompt_template = ( diff --git a/pyrit/executor/workflow/core/workflow_strategy.py b/pyrit/executor/workflow/core/workflow_strategy.py index ba5a4a3369..179645b879 100644 --- a/pyrit/executor/workflow/core/workflow_strategy.py +++ b/pyrit/executor/workflow/core/workflow_strategy.py @@ -6,7 +6,7 @@ import logging # noqa: TC003 from abc import ABC from dataclasses import dataclass -from typing import Optional, TypeVar +from typing import TypeVar from pyrit.common.logger import logger from pyrit.executor.core.strategy import ( @@ -109,7 +109,7 @@ def __init__( *, context_type: type[WorkflowContextT], logger: logging.Logger = logger, - event_handler: Optional[StrategyEventHandler[WorkflowContextT, WorkflowResultT]] = None, + event_handler: StrategyEventHandler[WorkflowContextT, WorkflowResultT] | None = None, ) -> None: """ Initialize the workflow strategy with a specific context type and logger. diff --git a/pyrit/executor/workflow/xpia.py b/pyrit/executor/workflow/xpia.py index e981c46b63..426ad74cdb 100644 --- a/pyrit/executor/workflow/xpia.py +++ b/pyrit/executor/workflow/xpia.py @@ -5,7 +5,7 @@ import uuid from dataclasses import dataclass, field from enum import Enum -from typing import Any, Optional, Protocol, Union, overload +from typing import Any, Protocol, overload from pyrit.common.utils import combine_dict, get_kwarg_param from pyrit.executor.core import StrategyConverterConfig @@ -66,7 +66,7 @@ class XPIAContext(WorkflowContext): attack_content: Message # Callback to execute after the attack prompt is positioned in the attack location - processing_callback: Optional[XPIAProcessingCallback] = None + processing_callback: XPIAProcessingCallback | None = None # Conversation ID for the attack setup target attack_setup_target_conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) @@ -75,7 +75,7 @@ class XPIAContext(WorkflowContext): processing_conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) # The prompt to send to the processing target (for test workflow) - processing_prompt: Optional[Message] = None + processing_prompt: Message | None = None # Additional labels that can be applied throughout the workflow memory_labels: dict[str, str] = field(default_factory=dict) @@ -96,10 +96,10 @@ class XPIAResult(WorkflowResult): processing_response: str # Score if a scorer was used, None otherwise - score: Optional[Score] = None + score: Score | None = None # Response from the attack setup target - attack_setup_response: Optional[str] = None + attack_setup_response: str | None = None @property def success(self) -> bool: @@ -145,9 +145,9 @@ def __init__( self, *, attack_setup_target: PromptTarget, - scorer: Optional[Scorer] = None, - converter_config: Optional[StrategyConverterConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + scorer: Scorer | None = None, + converter_config: StrategyConverterConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, logger: logging.Logger = logger, ) -> None: """ @@ -178,8 +178,8 @@ def __init__( def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct the identifier for this XPIA workflow. @@ -192,7 +192,7 @@ def _create_identifier( Returns: ComponentIdentifier: The identifier for this XPIA workflow. """ - all_children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = { + all_children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = { "attack_setup_target": self._attack_setup_target.get_identifier(), } if self._scorer: @@ -382,7 +382,7 @@ async def _execute_processing_async(self, *, context: XPIAContext) -> str: self._logger.info(f'Received the following response from the processing target "{processing_response}"') return processing_response - async def _score_response_async(self, *, processing_response: str) -> Optional[Score]: + async def _score_response_async(self, *, processing_response: str) -> Score | None: """ Score the processing response if a scorer is provided. @@ -429,9 +429,9 @@ async def execute_async( self, *, attack_content: Message, - processing_callback: Optional[XPIAProcessingCallback] = None, - processing_prompt: Optional[Message] = None, - memory_labels: Optional[dict[str, str]] = None, + processing_callback: XPIAProcessingCallback | None = None, + processing_prompt: Message | None = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> XPIAResult: ... @@ -503,8 +503,8 @@ def __init__( attack_setup_target: PromptTarget, processing_target: PromptTarget, scorer: Scorer, - converter_config: Optional[StrategyConverterConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + converter_config: StrategyConverterConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, logger: logging.Logger = logger, ) -> None: """ @@ -605,8 +605,8 @@ def __init__( *, attack_setup_target: PromptTarget, scorer: Scorer, - converter_config: Optional[StrategyConverterConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + converter_config: StrategyConverterConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, logger: logging.Logger = logger, ) -> None: """ diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 6723ae2842..0940fa8d82 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -7,7 +7,7 @@ from collections.abc import MutableSequence, Sequence from contextlib import closing, suppress from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast from sqlalchemy import and_, create_engine, event, exists, or_, text from sqlalchemy.engine.base import Engine @@ -64,9 +64,9 @@ class AzureSQLMemory(MemoryInterface, metaclass=Singleton): def __init__( self, *, - connection_string: Optional[str] = None, - results_container_url: Optional[str] = None, - results_sas_token: Optional[str] = None, + connection_string: str | None = None, + results_container_url: str | None = None, + results_sas_token: str | None = None, verbose: bool = False, skip_schema_migration: bool = False, silent: bool = False, @@ -93,12 +93,12 @@ def __init__( env_var_name=self.AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL, passed_value=results_container_url ) - self._results_container_sas_token: Optional[str] = self._resolve_sas_token( + self._results_container_sas_token: str | None = self._resolve_sas_token( self.AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN, results_sas_token ) - self._auth_token: Optional[AccessToken] = None - self._auth_token_expiry: Optional[int] = None + self._auth_token: AccessToken | None = None + self._auth_token_expiry: int | None = None self.results_path = self._results_container_url @@ -116,7 +116,7 @@ def __init__( super().__init__() @staticmethod - def _resolve_sas_token(env_var_name: str, passed_value: Optional[str] = None) -> Optional[str]: + def _resolve_sas_token(env_var_name: str, passed_value: str | None = None) -> str | None: """ Resolve the SAS token value, allowing a fallback to None for delegation SAS. @@ -285,7 +285,7 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str return [or_(pme_match, are_match)] - def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list[TextClause]: + def _get_metadata_conditions(self, *, prompt_metadata: dict[str, str | int]) -> list[TextClause]: """ Generate SQL conditions for filtering by prompt metadata. @@ -310,7 +310,7 @@ def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int] return [condition] def _get_message_pieces_prompt_metadata_conditions( - self, *, prompt_metadata: dict[str, Union[str, int]] + self, *, prompt_metadata: dict[str, str | int] ) -> list[TextClause]: """ Generate SQL conditions for filtering message pieces by prompt metadata. @@ -325,7 +325,7 @@ def _get_message_pieces_prompt_metadata_conditions( """ return self._get_metadata_conditions(prompt_metadata=prompt_metadata) - def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> TextClause: + def _get_seed_metadata_conditions(self, *, metadata: dict[str, str | int]) -> TextClause: """ Generate SQL condition for filtering seed prompts by metadata. @@ -800,10 +800,10 @@ def _query_entries( self, model_class: type[Model], *, - conditions: Optional[Any] = None, + conditions: Any | None = None, distinct: bool = False, join_scores: bool = False, - order_by: Optional[Any] = None, + order_by: Any | None = None, limit: int | None = None, ) -> MutableSequence[Model]: """ diff --git a/pyrit/memory/memory_embedding.py b/pyrit/memory/memory_embedding.py index e2ddc26e77..3fc6592783 100644 --- a/pyrit/memory/memory_embedding.py +++ b/pyrit/memory/memory_embedding.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.embedding import OpenAITextEmbedding from pyrit.memory.memory_models import EmbeddingDataEntry @@ -16,7 +15,7 @@ class MemoryEmbedding: embedding_model (EmbeddingSupport): An instance of a class that supports embedding generation. """ - def __init__(self, *, embedding_model: Optional[EmbeddingSupport] = None) -> None: + def __init__(self, *, embedding_model: EmbeddingSupport | None = None) -> None: """ Initialize the memory embedding helper with a backing embedding model. @@ -55,7 +54,7 @@ def generate_embedding_memory_data(self, *, message_piece: MessagePiece) -> Embe raise ValueError("Only text data is supported for embedding.") -def default_memory_embedding_factory(embedding_model: Optional[EmbeddingSupport] = None) -> MemoryEmbedding | None: +def default_memory_embedding_factory(embedding_model: EmbeddingSupport | None = None) -> MemoryEmbedding | None: """ Create a MemoryEmbedding instance with default or provided embedding model. diff --git a/pyrit/memory/memory_exporter.py b/pyrit/memory/memory_exporter.py index 54e61505b3..b2a06668e7 100644 --- a/pyrit/memory/memory_exporter.py +++ b/pyrit/memory/memory_exporter.py @@ -4,7 +4,6 @@ import csv import json from pathlib import Path -from typing import Optional from pyrit.models import MessagePiece @@ -30,7 +29,7 @@ def __init__(self) -> None: } def export_data( - self, data: list[MessagePiece], *, file_path: Optional[Path] = None, export_type: str = "json" + self, data: list[MessagePiece], *, file_path: Path | None = None, export_type: str = "json" ) -> None: """ Export the provided data to a file in the specified format. @@ -52,7 +51,7 @@ def export_data( else: raise ValueError(f"Unsupported export format: {export_type}") - def export_to_json(self, data: list[MessagePiece], file_path: Optional[Path] = None) -> None: + def export_to_json(self, data: list[MessagePiece], file_path: Path | None = None) -> None: """ Export the provided data to a JSON file at the specified file path. Each item in the data list, representing a row from the table, @@ -73,7 +72,7 @@ def export_to_json(self, data: list[MessagePiece], file_path: Optional[Path] = N with open(file_path, "w") as f: json.dump(export_data, f, indent=4) - def export_to_csv(self, data: list[MessagePiece], file_path: Optional[Path] = None) -> None: + def export_to_csv(self, data: list[MessagePiece], file_path: Path | None = None) -> None: """ Export the provided data to a CSV file at the specified file path. Each item in the data list, representing a row from the table, @@ -97,7 +96,7 @@ def export_to_csv(self, data: list[MessagePiece], file_path: Optional[Path] = No writer.writeheader() writer.writerows(export_data) - def export_to_markdown(self, data: list[MessagePiece], file_path: Optional[Path] = None) -> None: + def export_to_markdown(self, data: list[MessagePiece], file_path: Path | None = None) -> None: """ Export the provided data to a Markdown file at the specified file path. Each item in the data list is converted to a dictionary and formatted as a table. diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 26448f5b6c..9aeef709b5 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -11,7 +11,7 @@ from contextlib import closing from datetime import datetime, timezone from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Literal, TypeVar from sqlalchemy import MetaData, and_, not_, or_ from sqlalchemy.engine.base import Engine @@ -93,7 +93,7 @@ def _uid() -> str: """Return a short unique suffix for bind-param deduplication.""" return uuid.uuid4().hex[:8] - def __init__(self, embedding_model: Optional[Any] = None) -> None: + def __init__(self, embedding_model: Any | None = None) -> None: """ Initialize the MemoryInterface. @@ -110,7 +110,7 @@ def __init__(self, embedding_model: Optional[Any] = None) -> None: # Ensure cleanup at process exit self.cleanup() - def enable_embedding(self, embedding_model: Optional[Any] = None) -> None: + def enable_embedding(self, embedding_model: Any | None = None) -> None: """ Enable embedding functionality for the memory interface. @@ -318,9 +318,7 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str """ @abc.abstractmethod - def _get_message_pieces_prompt_metadata_conditions( - self, *, prompt_metadata: dict[str, Union[str, int]] - ) -> list[Any]: + def _get_message_pieces_prompt_metadata_conditions(self, *, prompt_metadata: dict[str, str | int]) -> list[Any]: """ Return a list of conditions for filtering memory entries based on prompt metadata. @@ -333,7 +331,7 @@ def _get_message_pieces_prompt_metadata_conditions( """ @abc.abstractmethod - def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any: + def _get_seed_metadata_conditions(self, *, metadata: dict[str, str | int]) -> Any: """ Return a condition for filtering seed prompt entries based on prompt metadata. @@ -362,10 +360,10 @@ def _query_entries( self, model_class: type[Model], *, - conditions: Optional[Any] = None, + conditions: Any | None = None, distinct: bool = False, join_scores: bool = False, - order_by: Optional[Any] = None, + order_by: Any | None = None, limit: int | None = None, ) -> MutableSequence[Model]: """ @@ -393,7 +391,7 @@ def _execute_batched_query( distinct: bool = False, join_scores: bool = False, batch_size: int | None = None, - order_by: Optional[Any] = None, + order_by: Any | None = None, limit: int | None = None, ) -> MutableSequence[Model]: """ @@ -708,12 +706,12 @@ def add_scores_to_memory(self, *, scores: Sequence[Score]) -> None: def get_scores( self, *, - score_ids: Optional[Sequence[str]] = None, - score_type: Optional[str] = None, - score_category: Optional[str] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - identifier_filters: Optional[Sequence[IdentifierFilter]] = None, + score_ids: Sequence[str] | None = None, + score_type: str | None = None, + score_category: str | None = None, + sent_after: datetime | None = None, + sent_before: datetime | None = None, + identifier_filters: Sequence[IdentifierFilter] | None = None, ) -> Sequence[Score]: """ Retrieve a list of Score objects based on the specified filters. @@ -772,19 +770,19 @@ def get_scores( def get_prompt_scores( self, *, - attack_id: Optional[str | uuid.UUID] = None, - role: Optional[str] = None, - conversation_id: Optional[str | uuid.UUID] = None, - prompt_ids: Optional[Sequence[str | uuid.UUID]] = None, - labels: Optional[dict[str, str]] = None, - prompt_metadata: Optional[dict[str, Union[str, int]]] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - original_values: Optional[Sequence[str]] = None, - converted_values: Optional[Sequence[str]] = None, - data_type: Optional[str] = None, - not_data_type: Optional[str] = None, - converted_value_sha256: Optional[Sequence[str]] = None, + attack_id: str | uuid.UUID | None = None, + role: str | None = None, + conversation_id: str | uuid.UUID | None = None, + prompt_ids: Sequence[str | uuid.UUID] | None = None, + labels: dict[str, str] | None = None, + prompt_metadata: dict[str, str | int] | None = None, + sent_after: datetime | None = None, + sent_before: datetime | None = None, + original_values: Sequence[str] | None = None, + converted_values: Sequence[str] | None = None, + data_type: str | None = None, + not_data_type: str | None = None, + converted_value_sha256: Sequence[str] | None = None, ) -> Sequence[Score]: """ Retrieve scores attached to message pieces based on the specified filters. @@ -879,20 +877,20 @@ def get_request_from_response(self, *, response: Message) -> Message: def get_message_pieces( self, *, - attack_id: Optional[str | uuid.UUID] = None, - role: Optional[str] = None, - conversation_id: Optional[str | uuid.UUID] = None, - prompt_ids: Optional[Sequence[str | uuid.UUID]] = None, - labels: Optional[dict[str, str]] = None, - prompt_metadata: Optional[dict[str, Union[str, int]]] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - original_values: Optional[Sequence[str]] = None, - converted_values: Optional[Sequence[str]] = None, - data_type: Optional[str] = None, - not_data_type: Optional[str] = None, - converted_value_sha256: Optional[Sequence[str]] = None, - identifier_filters: Optional[Sequence[IdentifierFilter]] = None, + attack_id: str | uuid.UUID | None = None, + role: str | None = None, + conversation_id: str | uuid.UUID | None = None, + prompt_ids: Sequence[str | uuid.UUID] | None = None, + labels: dict[str, str] | None = None, + prompt_metadata: dict[str, str | int] | None = None, + sent_after: datetime | None = None, + sent_before: datetime | None = None, + original_values: Sequence[str] | None = None, + converted_values: Sequence[str] | None = None, + data_type: str | None = None, + not_data_type: str | None = None, + converted_value_sha256: Sequence[str] | None = None, + identifier_filters: Sequence[IdentifierFilter] | None = None, ) -> Sequence[MessagePiece]: """ Retrieve a list of MessagePiece objects based on the specified filters. @@ -1164,7 +1162,7 @@ def update_labels_by_conversation_id(self, *, conversation_id: str, labels: dict ) def update_prompt_metadata_by_conversation_id( - self, *, conversation_id: str, prompt_metadata: dict[str, Union[str, int]] + self, *, conversation_id: str, prompt_metadata: dict[str, str | int] ) -> bool: """ Update the metadata of prompt entries in memory for a given conversation ID. @@ -1232,20 +1230,20 @@ def cleanup(self) -> None: def get_seeds( self, *, - value: Optional[str] = None, - value_sha256: Optional[Sequence[str]] = None, - dataset_name: Optional[str] = None, - dataset_name_pattern: Optional[str] = None, - data_types: Optional[Sequence[str]] = None, - harm_categories: Optional[Sequence[str]] = None, - added_by: Optional[str] = None, - authors: Optional[Sequence[str]] = None, - groups: Optional[Sequence[str]] = None, - source: Optional[str] = None, - seed_type: Optional[SeedType] = None, - parameters: Optional[Sequence[str]] = None, - metadata: Optional[dict[str, Union[str, int]]] = None, - prompt_group_ids: Optional[Sequence[uuid.UUID]] = None, + value: str | None = None, + value_sha256: Sequence[str] | None = None, + dataset_name: str | None = None, + dataset_name_pattern: str | None = None, + data_types: Sequence[str] | None = None, + harm_categories: Sequence[str] | None = None, + added_by: str | None = None, + authors: Sequence[str] | None = None, + groups: Sequence[str] | None = None, + source: str | None = None, + seed_type: SeedType | None = None, + parameters: Sequence[str] | None = None, + metadata: dict[str, str | int] | None = None, + prompt_group_ids: Sequence[uuid.UUID] | None = None, ) -> Sequence[Seed]: """ Retrieve a list of seed prompts based on the specified filters. @@ -1327,7 +1325,7 @@ def get_seeds( raise def _add_list_conditions( - self, field: InstrumentedAttribute[Any], conditions: list[Any], values: Optional[Sequence[str]] = None + self, field: InstrumentedAttribute[Any], conditions: list[Any], values: Sequence[str] | None = None ) -> None: if values: conditions.extend(field.contains(value) for value in values) @@ -1364,7 +1362,7 @@ async def _serialize_seed_value_async(self, prompt: Seed) -> str: serialized_prompt_value = str(serializer.value) return serialized_prompt_value or "" - async def add_seeds_to_memory_async(self, *, seeds: Sequence[Seed], added_by: Optional[str] = None) -> None: + async def add_seeds_to_memory_async(self, *, seeds: Sequence[Seed], added_by: str | None = None) -> None: """ Insert a list of seeds into the memory storage. @@ -1441,7 +1439,7 @@ def get_seed_dataset_names(self) -> Sequence[str]: raise async def add_seed_groups_to_memory_async( - self, *, prompt_groups: Sequence[SeedGroup], added_by: Optional[str] = None + self, *, prompt_groups: Sequence[SeedGroup], added_by: str | None = None ) -> None: """ Insert a list of seed groups into the memory storage. @@ -1481,21 +1479,21 @@ async def add_seed_groups_to_memory_async( def get_seed_groups( self, *, - value: Optional[str] = None, - value_sha256: Optional[Sequence[str]] = None, - dataset_name: Optional[str] = None, - dataset_name_pattern: Optional[str] = None, - data_types: Optional[Sequence[str]] = None, - harm_categories: Optional[Sequence[str]] = None, - added_by: Optional[str] = None, - authors: Optional[Sequence[str]] = None, - groups: Optional[Sequence[str]] = None, - source: Optional[str] = None, - seed_type: Optional[SeedType] = None, - parameters: Optional[Sequence[str]] = None, - metadata: Optional[dict[str, Union[str, int]]] = None, - prompt_group_ids: Optional[Sequence[uuid.UUID]] = None, - group_length: Optional[Sequence[int]] = None, + value: str | None = None, + value_sha256: Sequence[str] | None = None, + dataset_name: str | None = None, + dataset_name_pattern: str | None = None, + data_types: Sequence[str] | None = None, + harm_categories: Sequence[str] | None = None, + added_by: str | None = None, + authors: Sequence[str] | None = None, + groups: Sequence[str] | None = None, + source: str | None = None, + seed_type: SeedType | None = None, + parameters: Sequence[str] | None = None, + metadata: dict[str, str | int] | None = None, + prompt_group_ids: Sequence[uuid.UUID] | None = None, + group_length: Sequence[int] | None = None, ) -> Sequence[SeedGroup]: """ Retrieve groups of seed prompts based on the provided filtering criteria. @@ -1564,18 +1562,18 @@ def get_seed_groups( def export_conversations( self, *, - attack_id: Optional[str | uuid.UUID] = None, - conversation_id: Optional[str | uuid.UUID] = None, - prompt_ids: Optional[Sequence[str] | Sequence[uuid.UUID]] = None, - labels: Optional[dict[str, str]] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - original_values: Optional[Sequence[str]] = None, - converted_values: Optional[Sequence[str]] = None, - data_type: Optional[str] = None, - not_data_type: Optional[str] = None, - converted_value_sha256: Optional[Sequence[str]] = None, - file_path: Optional[Path] = None, + attack_id: str | uuid.UUID | None = None, + conversation_id: str | uuid.UUID | None = None, + prompt_ids: Sequence[str] | Sequence[uuid.UUID] | None = None, + labels: dict[str, str] | None = None, + sent_after: datetime | None = None, + sent_before: datetime | None = None, + original_values: Sequence[str] | None = None, + converted_values: Sequence[str] | None = None, + data_type: str | None = None, + not_data_type: str | None = None, + converted_value_sha256: Sequence[str] | None = None, + file_path: Path | None = None, export_type: str = "json", ) -> Path: """ @@ -1715,21 +1713,21 @@ def update_attack_result_by_id(self, *, attack_result_id: str, update_fields: di def get_attack_results( self, *, - attack_result_ids: Optional[Sequence[str]] = None, - conversation_id: Optional[str] = None, - objective: Optional[str] = None, - objective_sha256: Optional[Sequence[str]] = None, - outcome: Optional[str] = None, - attack_class: Optional[str] = None, - attack_classes: Optional[Sequence[str]] = None, - atomic_attack_eval_hashes: Optional[Sequence[str]] = None, - converter_classes: Optional[Sequence[str]] = None, + attack_result_ids: Sequence[str] | None = None, + conversation_id: str | None = None, + objective: str | None = None, + objective_sha256: Sequence[str] | None = None, + outcome: str | None = None, + attack_class: str | None = None, + attack_classes: Sequence[str] | None = None, + atomic_attack_eval_hashes: Sequence[str] | None = None, + converter_classes: Sequence[str] | None = None, converter_classes_match: Literal["all", "any"] = "all", - has_converters: Optional[bool] = None, - targeted_harm_categories: Optional[Sequence[str]] = None, - labels: Optional[dict[str, str | Sequence[str]]] = None, - identifier_filters: Optional[Sequence[IdentifierFilter]] = None, - scenario_result_id: Optional[str] = None, + has_converters: bool | None = None, + targeted_harm_categories: Sequence[str] | None = None, + labels: dict[str, str | Sequence[str]] | None = None, + identifier_filters: Sequence[IdentifierFilter] | None = None, + scenario_result_id: str | None = None, ) -> Sequence[AttackResult]: """ Retrieve a list of AttackResult objects based on the specified filters. @@ -2091,16 +2089,16 @@ def update_scenario_metadata( def get_scenario_results( self, *, - scenario_result_ids: Optional[Sequence[str]] = None, - scenario_name: Optional[str] = None, - scenario_version: Optional[int] = None, - pyrit_version: Optional[str] = None, - added_after: Optional[datetime] = None, - added_before: Optional[datetime] = None, - labels: Optional[dict[str, str]] = None, - objective_target_endpoint: Optional[str] = None, - objective_target_model_name: Optional[str] = None, - identifier_filters: Optional[Sequence[IdentifierFilter]] = None, + scenario_result_ids: Sequence[str] | None = None, + scenario_name: str | None = None, + scenario_version: int | None = None, + pyrit_version: str | None = None, + added_after: datetime | None = None, + added_before: datetime | None = None, + labels: dict[str, str] | None = None, + objective_target_endpoint: str | None = None, + objective_target_model_name: str | None = None, + identifier_filters: Sequence[IdentifierFilter] | None = None, limit: int | None = None, ) -> Sequence[ScenarioResult]: """ diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 461d2b871b..fef14c2471 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -8,7 +8,7 @@ from contextlib import closing, suppress from datetime import datetime from pathlib import Path -from typing import Any, Literal, Optional, TypeVar, Union, cast +from typing import Any, Literal, TypeVar, cast from sqlalchemy import and_, create_engine, exists, func, or_, text from sqlalchemy.engine.base import Engine @@ -59,7 +59,7 @@ class SQLiteMemory(MemoryInterface, metaclass=Singleton): def __init__( self, *, - db_path: Optional[Union[Path, str]] = None, + db_path: Path | str | None = None, verbose: bool = False, skip_schema_migration: bool = False, silent: bool = False, @@ -80,7 +80,7 @@ def __init__( super().__init__() if db_path == ":memory:": - self.db_path: Union[Path, str] = ":memory:" + self.db_path: Path | str = ":memory:" else: self.db_path = Path(db_path or Path(DB_DATA_PATH, self.DEFAULT_DB_FILE_NAME)).resolve() self.results_path = str(DB_DATA_PATH) @@ -179,7 +179,7 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str return [or_(pme_match, are_match)] def _get_message_pieces_prompt_metadata_conditions( - self, *, prompt_metadata: dict[str, Union[str, int]] + self, *, prompt_metadata: dict[str, str | int] ) -> list[TextClause]: """ Generate SQLAlchemy filter conditions for filtering conversation pieces by prompt metadata. @@ -195,7 +195,7 @@ def _get_message_pieces_prompt_metadata_conditions( condition = text(json_conditions).bindparams(**{key: str(value) for key, value in prompt_metadata.items()}) return [condition] - def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any: + def _get_seed_metadata_conditions(self, *, metadata: dict[str, str | int]) -> Any: """ Generate SQLAlchemy filter conditions for filtering seed prompts by metadata. @@ -334,10 +334,10 @@ def _query_entries( self, model_class: type[Model], *, - conditions: Optional[Any] = None, + conditions: Any | None = None, distinct: bool = False, join_scores: bool = False, - order_by: Optional[Any] = None, + order_by: Any | None = None, limit: int | None = None, ) -> MutableSequence[Model]: """ @@ -487,18 +487,18 @@ def dispose_engine(self) -> None: def export_conversations( self, *, - attack_id: Optional[str | uuid.UUID] = None, - conversation_id: Optional[str | uuid.UUID] = None, - prompt_ids: Optional[Sequence[str] | Sequence[uuid.UUID]] = None, - labels: Optional[dict[str, str]] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - original_values: Optional[Sequence[str]] = None, - converted_values: Optional[Sequence[str]] = None, - data_type: Optional[str] = None, - not_data_type: Optional[str] = None, - converted_value_sha256: Optional[Sequence[str]] = None, - file_path: Optional[Path] = None, + attack_id: str | uuid.UUID | None = None, + conversation_id: str | uuid.UUID | None = None, + prompt_ids: Sequence[str] | Sequence[uuid.UUID] | None = None, + labels: dict[str, str] | None = None, + sent_after: datetime | None = None, + sent_before: datetime | None = None, + original_values: Sequence[str] | None = None, + converted_values: Sequence[str] | None = None, + data_type: str | None = None, + not_data_type: str | None = None, + converted_value_sha256: Sequence[str] | None = None, + file_path: Path | None = None, export_type: str = "json", ) -> Path: """ diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 4fc11fbd0a..d3fcaade31 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -4,7 +4,7 @@ import base64 import json from pathlib import Path -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any import aiofiles @@ -93,7 +93,7 @@ async def normalize_async(self, messages: list[Message]) -> list[ChatMessage]: # Use simple string for single text piece, otherwise use content list if len(pieces) == 1 and pieces[0].converted_value_data_type == "text": - content: Union[str, list[dict[str, Any]]] = pieces[0].converted_value + content: str | list[dict[str, Any]] = pieces[0].converted_value else: content = [await self._piece_to_content_dict_async(piece) for piece in pieces] diff --git a/pyrit/message_normalizer/tokenizer_template_normalizer.py b/pyrit/message_normalizer/tokenizer_template_normalizer.py index caab9fa337..06c904505b 100644 --- a/pyrit/message_normalizer/tokenizer_template_normalizer.py +++ b/pyrit/message_normalizer/tokenizer_template_normalizer.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Literal, Optional, cast +from typing import TYPE_CHECKING, ClassVar, Literal, cast from pyrit.common import get_non_required_value from pyrit.message_normalizer.chat_message_normalizer import ChatMessageNormalizer @@ -109,7 +109,7 @@ def __init__( self.system_message_behavior = system_message_behavior @staticmethod - def _load_tokenizer(model_name: str, token: Optional[str]) -> "PreTrainedTokenizerBase": + def _load_tokenizer(model_name: str, token: str | None) -> "PreTrainedTokenizerBase": """ Load a tokenizer from HuggingFace. @@ -134,8 +134,8 @@ def from_model( cls, model_name_or_alias: str, *, - token: Optional[str] = None, - system_message_behavior: Optional[TokenizerSystemBehavior] = None, + token: str | None = None, + system_message_behavior: TokenizerSystemBehavior | None = None, ) -> "TokenizerTemplateNormalizer": """ Create a normalizer from a model name or alias. diff --git a/pyrit/models/chat_message.py b/pyrit/models/chat_message.py index c2f801862d..b873b33333 100644 --- a/pyrit/models/chat_message.py +++ b/pyrit/models/chat_message.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Any, Optional, Union +from typing import Any from pydantic import BaseModel, ConfigDict @@ -30,10 +30,10 @@ class ChatMessage(BaseModel): model_config = ConfigDict(extra="forbid") role: ChatMessageRole - content: Union[str, list[dict[str, Any]]] - name: Optional[str] = None - tool_calls: Optional[list[ToolCall]] = None - tool_call_id: Optional[str] = None + content: str | list[dict[str, Any]] + name: str | None = None + tool_calls: list[ToolCall] | None = None + tool_call_id: str | None = None def to_dict(self) -> dict[str, Any]: """ diff --git a/pyrit/models/conversation_reference.py b/pyrit/models/conversation_reference.py index 0915a045c4..6e39cfd233 100644 --- a/pyrit/models/conversation_reference.py +++ b/pyrit/models/conversation_reference.py @@ -4,7 +4,6 @@ from __future__ import annotations from enum import Enum -from typing import Optional from pydantic import BaseModel, ConfigDict @@ -27,7 +26,7 @@ class ConversationReference(BaseModel): conversation_id: str conversation_type: ConversationType - description: Optional[str] = None + description: str | None = None def __hash__(self) -> int: """ diff --git a/pyrit/models/conversation_stats.py b/pyrit/models/conversation_stats.py index 67b09e24be..14497954f5 100644 --- a/pyrit/models/conversation_stats.py +++ b/pyrit/models/conversation_stats.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from datetime import datetime -from typing import ClassVar, Optional +from typing import ClassVar from pydantic import BaseModel, ConfigDict, Field @@ -29,7 +29,7 @@ class ConversationStats(BaseModel): """ message_count: int = 0 - last_message_preview: Optional[str] = None - last_message_data_type: Optional[PromptDataType] = None + last_message_preview: str | None = None + last_message_data_type: PromptDataType | None = None labels: dict[str, str] = Field(default_factory=dict) - created_at: Optional[datetime] = None + created_at: datetime | None = None diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index 3ee43eed62..3d529472fc 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -11,7 +11,7 @@ import wave from mimetypes import guess_type from pathlib import Path -from typing import TYPE_CHECKING, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Literal, get_args from urllib.parse import urlparse import aiofiles @@ -47,8 +47,8 @@ def _write_wav_sync( def data_serializer_factory( *, data_type: PromptDataType, - value: Optional[str] = None, - extension: Optional[str] = None, + value: str | None = None, + extension: str | None = None, category: AllowedCategories, ) -> DataTypeSerializer: """ @@ -113,7 +113,7 @@ class DataTypeSerializer(abc.ABC): data_sub_directory: str file_extension: str - _file_path: Union[Path, str] | None = None + _file_path: Path | str | None = None @property def _memory(self) -> MemoryInterface: @@ -151,7 +151,7 @@ def data_on_disk(self) -> bool: """ - async def save_data_async(self, data: bytes, output_filename: Optional[str] = None) -> None: + async def save_data_async(self, data: bytes, output_filename: str | None = None) -> None: """ Save data to storage. @@ -192,7 +192,7 @@ async def save_formatted_audio_async( num_channels: int = 1, sample_width: int = 2, sample_rate: int = 16000, - output_filename: Optional[str] = None, + output_filename: str | None = None, ) -> None: """ Save PCM16 or similarly formatted audio data to storage. @@ -310,7 +310,7 @@ async def get_sha256_async(self) -> str: hash_object = hashlib.sha256(input_bytes) return hash_object.hexdigest() - async def get_data_filename_async(self, file_name: Optional[str] = None) -> Union[Path, str]: + async def get_data_filename_async(self, file_name: str | None = None) -> Path | str: """ Generate or retrieve a unique filename for the data file. @@ -356,7 +356,7 @@ async def get_data_filename_async(self, file_name: Optional[str] = None) -> Unio return self._file_path async def save_data( # pyrit-async-suffix-exempt - self, data: bytes, output_filename: Optional[str] = None + self, data: bytes, output_filename: str | None = None ) -> None: """ Save data to storage (deprecated alias of ``save_data_async``). @@ -395,7 +395,7 @@ async def save_formatted_audio( # pyrit-async-suffix-exempt num_channels: int = 1, sample_width: int = 2, sample_rate: int = 16000, - output_filename: Optional[str] = None, + output_filename: str | None = None, ) -> None: """ Save formatted audio data to storage (deprecated alias of ``save_formatted_audio_async``). @@ -457,8 +457,8 @@ async def get_sha256(self) -> str: # pyrit-async-suffix-exempt return await self.get_sha256_async() async def get_data_filename( # pyrit-async-suffix-exempt - self, file_name: Optional[str] = None - ) -> Union[Path, str]: + self, file_name: str | None = None + ) -> Path | str: """ Generate or retrieve a unique filename for the data file (deprecated alias of ``get_data_filename_async``). @@ -574,7 +574,7 @@ def data_on_disk(self) -> bool: class URLDataTypeSerializer(DataTypeSerializer): """Serializer for URL values and URL-backed local file references.""" - def __init__(self, *, category: str, prompt_text: str, extension: Optional[str] = None) -> None: + def __init__(self, *, category: str, prompt_text: str, extension: str | None = None) -> None: """ Initialize a URL serializer. @@ -604,7 +604,7 @@ def data_on_disk(self) -> bool: class ImagePathDataTypeSerializer(DataTypeSerializer): """Serializer for image path values stored on disk.""" - def __init__(self, *, category: str, prompt_text: Optional[str] = None, extension: Optional[str] = None) -> None: + def __init__(self, *, category: str, prompt_text: str | None = None, extension: str | None = None) -> None: """ Initialize an image-path serializer. @@ -639,8 +639,8 @@ def __init__( self, *, category: str, - prompt_text: Optional[str] = None, - extension: Optional[str] = None, + prompt_text: str | None = None, + extension: str | None = None, ) -> None: """ Initialize an audio-path serializer. @@ -676,8 +676,8 @@ def __init__( self, *, category: str, - prompt_text: Optional[str] = None, - extension: Optional[str] = None, + prompt_text: str | None = None, + extension: str | None = None, ) -> None: """ Initialize a video-path serializer. @@ -713,8 +713,8 @@ def __init__( self, *, category: str, - prompt_text: Optional[str] = None, - extension: Optional[str] = None, + prompt_text: str | None = None, + extension: str | None = None, ) -> None: """ Initialize a generic binary-path serializer. diff --git a/pyrit/models/harm_definition.py b/pyrit/models/harm_definition.py index 9e739244ab..6e74d6aef9 100644 --- a/pyrit/models/harm_definition.py +++ b/pyrit/models/harm_definition.py @@ -10,7 +10,6 @@ import logging import re from pathlib import Path -from typing import Optional, Union import yaml from pydantic import BaseModel, Field @@ -53,9 +52,9 @@ class HarmDefinition(BaseModel): version: str category: str scale_descriptions: list[ScaleDescription] = Field(default_factory=list) - source_path: Optional[str] = None + source_path: str | None = None - def get_scale_description(self, score_value: str) -> Optional[str]: + def get_scale_description(self, score_value: str) -> str | None: """ Get the description for a specific score value. @@ -101,7 +100,7 @@ def validate_category(category: str, *, check_exists: bool = False) -> bool: return True @classmethod - def from_yaml(cls, harm_definition_path: Union[str, Path]) -> "HarmDefinition": + def from_yaml(cls, harm_definition_path: str | Path) -> "HarmDefinition": """ Load and validate a harm definition from a YAML file. diff --git a/pyrit/models/identifiers/component_identifier.py b/pyrit/models/identifiers/component_identifier.py index d3d0c71933..f050e6e67e 100644 --- a/pyrit/models/identifiers/component_identifier.py +++ b/pyrit/models/identifiers/component_identifier.py @@ -20,7 +20,7 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, model_serializer, model_validator @@ -162,16 +162,16 @@ class ComponentIdentifier(BaseModel): #: Behavioral parameters that affect output. params: dict[str, Any] = Field(default_factory=dict) #: Named child identifiers for compositional identity (e.g., a scorer's target). - children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = Field(default_factory=dict) + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = Field(default_factory=dict) #: Content-addressed SHA256 hash. Computed automatically when ``None``; #: pass an explicit value to preserve a hash from DB storage where params #: may have been truncated. - hash: Optional[str] = None + hash: str | None = None #: Version tag for storage. Not included in the content hash. pyrit_version: str = Field(default=pyrit.__version__) #: Evaluation hash. Computed by EvaluationIdentifier subclasses and attached #: to the identifier so it survives DB round-trips with truncated params. - eval_hash: Optional[str] = None + eval_hash: str | None = None # ------------------------------------------------------------------ # Validators @@ -435,8 +435,8 @@ def of( cls, obj: object, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Build a ComponentIdentifier from a live object instance. @@ -464,7 +464,7 @@ def of( children=clean_children, ) - def get_child(self, key: str) -> Optional[ComponentIdentifier]: + def get_child(self, key: str) -> ComponentIdentifier | None: """ Get a single child by key. @@ -519,7 +519,7 @@ def _collect_child_eval_hashes(self) -> set[str]: return hashes @staticmethod - def _truncate_value(*, value: Any, max_length: Optional[int]) -> Any: + def _truncate_value(*, value: Any, max_length: int | None) -> Any: """ Truncate string values longer than ``max_length`` with a ``...`` suffix. @@ -538,7 +538,7 @@ def _truncate_value(*, value: Any, max_length: Optional[int]) -> Any: # Deprecated shims — kept for one release cycle # ------------------------------------------------------------------ - def to_dict(self, *, max_value_length: Optional[int] = None) -> dict[str, Any]: + def to_dict(self, *, max_value_length: int | None = None) -> dict[str, Any]: """ Return the flat storage dict (deprecated; use ``model_dump`` instead). @@ -584,7 +584,7 @@ class Identifiable(ABC): component's lifetime. """ - _identifier: Optional[ComponentIdentifier] = None + _identifier: ComponentIdentifier | None = None @abstractmethod def _build_identifier(self) -> ComponentIdentifier: diff --git a/pyrit/models/identifiers/evaluation_identifier.py b/pyrit/models/identifiers/evaluation_identifier.py index b2fc1b996d..27ed3a2cac 100644 --- a/pyrit/models/identifiers/evaluation_identifier.py +++ b/pyrit/models/identifiers/evaluation_identifier.py @@ -23,7 +23,7 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import TYPE_CHECKING, Any, ClassVar from pydantic import BaseModel, ConfigDict, Field @@ -65,18 +65,18 @@ class ChildEvalRule(BaseModel): model_config = ConfigDict(frozen=True) exclude: bool = False - included_params: Optional[frozenset[str]] = None - included_item_values: Optional[dict[str, Any]] = Field(default=None) - param_fallbacks: Optional[dict[str, str]] = Field(default=None) - inner_child_name: Optional[str] = Field(default=None) + included_params: frozenset[str] | None = None + included_item_values: dict[str, Any] | None = Field(default=None) + param_fallbacks: dict[str, str] | None = Field(default=None) + inner_child_name: str | None = Field(default=None) def _build_eval_dict( identifier: ComponentIdentifier, *, child_eval_rules: dict[str, ChildEvalRule], - _included_params: Optional[frozenset[str]] = None, - _param_fallbacks: Optional[dict[str, str]] = None, + _included_params: frozenset[str] | None = None, + _param_fallbacks: dict[str, str] | None = None, ) -> dict[str, Any]: """ Build a filtered dictionary for eval-hash computation. @@ -177,7 +177,7 @@ def compute_eval_hash( identifier: ComponentIdentifier, *, child_eval_rules: dict[str, ChildEvalRule], - own_rule: Optional[ChildEvalRule] = None, + own_rule: ChildEvalRule | None = None, ) -> str: """ Compute a behavioral equivalence hash for evaluation grouping. @@ -252,7 +252,7 @@ class EvaluationIdentifier(ABC): """ CHILD_EVAL_RULES: ClassVar[dict[str, ChildEvalRule]] - OWN_RULE: ClassVar[Optional[ChildEvalRule]] = None + OWN_RULE: ClassVar[ChildEvalRule | None] = None def __init__(self, identifier: ComponentIdentifier) -> None: """ @@ -357,7 +357,7 @@ class ObjectiveTargetEvaluationIdentifier(EvaluationIdentifier): """ CHILD_EVAL_RULES: ClassVar[dict[str, ChildEvalRule]] = {} - OWN_RULE: ClassVar[Optional[ChildEvalRule]] = ChildEvalRule( + OWN_RULE: ClassVar[ChildEvalRule | None] = ChildEvalRule( included_params=TARGET_EVAL_PARAMS, param_fallbacks=TARGET_EVAL_PARAM_FALLBACKS, ) diff --git a/pyrit/models/json_response_config.py b/pyrit/models/json_response_config.py index 8c4c1b9864..b6915526b3 100644 --- a/pyrit/models/json_response_config.py +++ b/pyrit/models/json_response_config.py @@ -4,7 +4,7 @@ from __future__ import annotations import json -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, ConfigDict @@ -30,12 +30,12 @@ class _JsonResponseConfig(BaseModel): model_config = ConfigDict(extra="forbid") enabled: bool = False - json_schema: Optional[dict[str, Any]] = None + json_schema: dict[str, Any] | None = None schema_name: str = "CustomSchema" strict: bool = True @classmethod - def from_metadata(cls, *, metadata: Optional[dict[str, Any]]) -> _JsonResponseConfig: + def from_metadata(cls, *, metadata: dict[str, Any] | None) -> _JsonResponseConfig: if not metadata: return cls(enabled=False) diff --git a/pyrit/models/messages/conversations.py b/pyrit/models/messages/conversations.py index 2b5224f6c6..6b625ac46c 100644 --- a/pyrit/models/messages/conversations.py +++ b/pyrit/models/messages/conversations.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from pyrit.models.messages.message import Message from pyrit.models.messages.message_piece import MessagePiece @@ -178,7 +178,7 @@ def construct_response_from_request( request: MessagePiece, response_text_pieces: list[str], response_type: PromptDataType = "text", - prompt_metadata: Optional[dict[str, Union[str, int]]] = None, + prompt_metadata: dict[str, str | int] | None = None, error: PromptResponseError = "none", ) -> Message: """ diff --git a/pyrit/models/messages/message.py b/pyrit/models/messages/message.py index d14f1ec1b9..91f208d8d9 100644 --- a/pyrit/models/messages/message.py +++ b/pyrit/models/messages/message.py @@ -6,7 +6,7 @@ import copy import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, model_validator @@ -195,9 +195,9 @@ def get_piece(self, n: int = 0) -> MessagePiece: def get_pieces_by_type( self, *, - data_type: Optional[PromptDataType] = None, - original_value_data_type: Optional[PromptDataType] = None, - converted_value_data_type: Optional[PromptDataType] = None, + data_type: PromptDataType | None = None, + original_value_data_type: PromptDataType | None = None, + converted_value_data_type: PromptDataType | None = None, ) -> list[MessagePiece]: """ Return all message pieces matching the given data type. @@ -222,10 +222,10 @@ def get_pieces_by_type( def get_piece_by_type( self, *, - data_type: Optional[PromptDataType] = None, - original_value_data_type: Optional[PromptDataType] = None, - converted_value_data_type: Optional[PromptDataType] = None, - ) -> Optional[MessagePiece]: + data_type: PromptDataType | None = None, + original_value_data_type: PromptDataType | None = None, + converted_value_data_type: PromptDataType | None = None, + ) -> MessagePiece | None: """ Return the first message piece matching the given data type, or None. @@ -358,7 +358,7 @@ def from_prompt( *, prompt: str, role: ChatMessageRole, - prompt_metadata: Optional[dict[str, Union[str, int]]] = None, + prompt_metadata: dict[str, str | int] | None = None, ) -> Message: """ Build a single-piece message from prompt text. diff --git a/pyrit/models/messages/message_piece.py b/pyrit/models/messages/message_piece.py index 131919ecf8..0649a08e77 100644 --- a/pyrit/models/messages/message_piece.py +++ b/pyrit/models/messages/message_piece.py @@ -5,7 +5,7 @@ import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal from uuid import uuid4 from pydantic import ( @@ -102,20 +102,20 @@ class MessagePiece(BaseModel): timestamp: AwareDatetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) original_value: str original_value_data_type: PromptDataType = "text" - original_value_sha256: Optional[str] = None + original_value_sha256: str | None = None converted_value: str = "" converted_value_data_type: PromptDataType = "text" - converted_value_sha256: Optional[str] = None + converted_value_sha256: str | None = None response_error: PromptResponseError = "none" originator: Literal["attack", "converter", "undefined", "scorer"] = "undefined" - original_prompt_id: Optional[uuid.UUID] = None + original_prompt_id: uuid.UUID | None = None labels: dict[str, Any] = Field(default_factory=dict) targeted_harm_categories: list[str] = Field(default_factory=list) prompt_metadata: dict[str, Any] = Field(default_factory=dict) converter_identifiers: list[ComponentIdentifierField] = Field(default_factory=list) - prompt_target_identifier: Optional[ComponentIdentifierField] = None - attack_identifier: Optional[ComponentIdentifierField] = None - scorer_identifier: Optional[ComponentIdentifierField] = None + prompt_target_identifier: ComponentIdentifierField | None = None + attack_identifier: ComponentIdentifierField | None = None + scorer_identifier: ComponentIdentifierField | None = None scores: list[Score] = Field(default_factory=list) # When True, the memory layer skips persisting this piece. Used for ephemeral diff --git a/pyrit/models/question_answering.py b/pyrit/models/question_answering.py index c468461090..8e1ca76e0f 100644 --- a/pyrit/models/question_answering.py +++ b/pyrit/models/question_answering.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import Literal, Union +from typing import Literal from pydantic import BaseModel, ConfigDict @@ -26,7 +26,7 @@ class QuestionAnsweringEntry(BaseModel): model_config = ConfigDict(extra="forbid") question: str answer_type: Literal["int", "float", "str", "bool"] - correct_answer: Union[int, str, float] + correct_answer: int | str | float choices: list[QuestionChoice] def get_correct_answer_text(self) -> str: diff --git a/pyrit/models/retry_event.py b/pyrit/models/retry_event.py index 46a6e79fcf..2ef67f908f 100644 --- a/pyrit/models/retry_event.py +++ b/pyrit/models/retry_event.py @@ -6,7 +6,6 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Optional from pydantic import BaseModel, Field @@ -29,8 +28,8 @@ class RetryEvent(BaseModel): exception_type: str = "" exception_message: str = "" component_role: str = "" - component_name: Optional[str] = None - endpoint: Optional[str] = None + component_name: str | None = None + endpoint: str | None = None elapsed_seconds: float = 0.0 def to_dict(self) -> dict: diff --git a/pyrit/models/storage_io.py b/pyrit/models/storage_io.py index 5b610f80d8..09ccd0d70e 100644 --- a/pyrit/models/storage_io.py +++ b/pyrit/models/storage_io.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from urllib.parse import urlparse import aiofiles @@ -36,36 +36,36 @@ class StorageIO(ABC): """ @abstractmethod - async def read_file_async(self, path: Union[Path, str]) -> bytes: + async def read_file_async(self, path: Path | str) -> bytes: """ Asynchronously reads the file (or blob) from the given path. """ @abstractmethod - async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: + async def write_file_async(self, path: Path | str, data: bytes) -> None: """ Asynchronously writes data to the given path. """ @abstractmethod - async def path_exists_async(self, path: Union[Path, str]) -> bool: + async def path_exists_async(self, path: Path | str) -> bool: """ Asynchronously checks if a file or blob exists at the given path. """ @abstractmethod - async def is_file_async(self, path: Union[Path, str]) -> bool: + async def is_file_async(self, path: Path | str) -> bool: """ Asynchronously checks if the path refers to a file (not a directory or container). """ @abstractmethod - async def create_directory_if_not_exists_async(self, path: Union[Path, str]) -> None: + async def create_directory_if_not_exists_async(self, path: Path | str) -> None: """ Asynchronously creates a directory or equivalent in the storage system if it doesn't exist. """ - async def read_file(self, path: Union[Path, str]) -> bytes: # pyrit-async-suffix-exempt + async def read_file(self, path: Path | str) -> bytes: # pyrit-async-suffix-exempt """ Read a file from storage (deprecated alias of ``read_file_async``). @@ -82,7 +82,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes: # pyrit-async-suffi ) return await self.read_file_async(path) - async def write_file(self, path: Union[Path, str], data: bytes) -> None: # pyrit-async-suffix-exempt + async def write_file(self, path: Path | str, data: bytes) -> None: # pyrit-async-suffix-exempt """ Write data to storage (deprecated alias of ``write_file_async``). @@ -97,7 +97,7 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None: # pyri ) await self.write_file_async(path, data) - async def path_exists(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-exempt + async def path_exists(self, path: Path | str) -> bool: # pyrit-async-suffix-exempt """ Check whether a path exists (deprecated alias of ``path_exists_async``). @@ -114,7 +114,7 @@ async def path_exists(self, path: Union[Path, str]) -> bool: # pyrit-async-suff ) return await self.path_exists_async(path) - async def is_file(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-exempt + async def is_file(self, path: Path | str) -> bool: # pyrit-async-suffix-exempt """ Check whether the given path is a file (deprecated alias of ``is_file_async``). @@ -131,7 +131,7 @@ async def is_file(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-e ) return await self.is_file_async(path) - async def create_directory_if_not_exists(self, path: Union[Path, str]) -> None: # pyrit-async-suffix-exempt + async def create_directory_if_not_exists(self, path: Path | str) -> None: # pyrit-async-suffix-exempt """ Create a directory if it does not exist (deprecated alias of ``create_directory_if_not_exists_async``). @@ -151,7 +151,7 @@ class DiskStorageIO(StorageIO): Implementation of StorageIO for local disk storage. """ - async def read_file_async(self, path: Union[Path, str]) -> bytes: + async def read_file_async(self, path: Path | str) -> bytes: """ Asynchronously reads a file from the local disk. @@ -166,7 +166,7 @@ async def read_file_async(self, path: Union[Path, str]) -> bytes: async with aiofiles.open(path, "rb") as file: return await file.read() - async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: + async def write_file_async(self, path: Path | str, data: bytes) -> None: """ Asynchronously writes data to a file on the local disk. @@ -179,7 +179,7 @@ async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: async with aiofiles.open(path, "wb") as file: await file.write(data) - async def path_exists_async(self, path: Union[Path, str]) -> bool: + async def path_exists_async(self, path: Path | str) -> bool: """ Check whether a path exists on the local disk. @@ -193,7 +193,7 @@ async def path_exists_async(self, path: Union[Path, str]) -> bool: path = self._convert_to_path(path) return path.exists() - async def is_file_async(self, path: Union[Path, str]) -> bool: + async def is_file_async(self, path: Path | str) -> bool: """ Check whether the given path is a file (not a directory). @@ -207,7 +207,7 @@ async def is_file_async(self, path: Union[Path, str]) -> bool: path = self._convert_to_path(path) return path.is_file() - async def create_directory_if_not_exists_async(self, path: Union[Path, str]) -> None: + async def create_directory_if_not_exists_async(self, path: Path | str) -> None: """ Asynchronously creates a directory if it doesn't exist on the local disk. @@ -219,7 +219,7 @@ async def create_directory_if_not_exists_async(self, path: Union[Path, str]) -> if not directory_path.exists(): directory_path.mkdir(parents=True, exist_ok=True) - def _convert_to_path(self, path: Union[Path, str]) -> Path: + def _convert_to_path(self, path: Path | str) -> Path: """ Convert an input path to a Path object. @@ -241,8 +241,8 @@ class AzureBlobStorageIO(StorageIO): def __init__( self, *, - container_url: Optional[str] = None, - sas_token: Optional[str] = None, + container_url: str | None = None, + sas_token: str | None = None, blob_content_type: SupportedContentType = SupportedContentType.PLAIN_TEXT, ) -> None: """ @@ -351,7 +351,7 @@ def parse_blob_url(self, file_path: str) -> tuple[str, str]: return container_name, blob_name raise ValueError("Invalid blob URL") - def _resolve_blob_name(self, path: Union[Path, str]) -> str: + def _resolve_blob_name(self, path: Path | str) -> str: """ Resolve a blob name from either a full blob URL or a relative blob path. @@ -377,7 +377,7 @@ def _resolve_blob_name(self, path: Union[Path, str]) -> str: except ValueError: return path_str - async def read_file_async(self, path: Union[Path, str]) -> bytes: + async def read_file_async(self, path: Path | str) -> bytes: """ Asynchronously reads the content of a file (blob) from Azure Blob Storage. @@ -420,7 +420,7 @@ async def read_file_async(self, path: Union[Path, str]) -> bytes: await self._client_async.close() self._client_async = None - async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: + async def write_file_async(self, path: Path | str, data: bytes) -> None: """ Write data to Azure Blob Storage at the specified path. @@ -443,7 +443,7 @@ async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: await self._client_async.close() self._client_async = None - async def path_exists_async(self, path: Union[Path, str]) -> bool: + async def path_exists_async(self, path: Path | str) -> bool: """ Check whether a given path exists in the Azure Blob Storage container. @@ -468,7 +468,7 @@ async def path_exists_async(self, path: Union[Path, str]) -> bool: await self._client_async.close() self._client_async = None - async def is_file_async(self, path: Union[Path, str]) -> bool: + async def is_file_async(self, path: Path | str) -> bool: """ Check whether the path refers to a file (blob) in Azure Blob Storage. @@ -493,7 +493,7 @@ async def is_file_async(self, path: Union[Path, str]) -> bool: await self._client_async.close() self._client_async = None - async def create_directory_if_not_exists_async(self, directory_path: Union[Path, str]) -> None: # type: ignore[ty:invalid-method-override] + async def create_directory_if_not_exists_async(self, directory_path: Path | str) -> None: # type: ignore[ty:invalid-method-override] """ Log a no-op directory creation for Azure Blob Storage. diff --git a/pyrit/prompt_converter/add_image_to_video_converter.py b/pyrit/prompt_converter/add_image_to_video_converter.py index 148f4d5e03..8103df378f 100644 --- a/pyrit/prompt_converter/add_image_to_video_converter.py +++ b/pyrit/prompt_converter/add_image_to_video_converter.py @@ -5,7 +5,6 @@ import contextlib import logging from pathlib import Path -from typing import Optional import numpy as np @@ -38,7 +37,7 @@ class AddImageVideoConverter(PromptConverter): def __init__( self, video_path: str, - output_path: Optional[str] = None, + output_path: str | None = None, img_position: tuple[int, int] = (10, 10), img_resize_size: tuple[int, int] = (500, 500), ) -> None: @@ -146,9 +145,9 @@ def _add_image_to_video_sync( import cv2 video_path = self._video_path - local_temp_path: Optional[Path] = None - cap: Optional[cv2.VideoCapture] = None - output_video: Optional[cv2.VideoWriter] = None + local_temp_path: Path | None = None + cap: cv2.VideoCapture | None = None + output_video: cv2.VideoWriter | None = None try: if azure_storage_flag: diff --git a/pyrit/prompt_converter/ask_to_decode_converter.py b/pyrit/prompt_converter/ask_to_decode_converter.py index 847360e37f..cb6e47d148 100644 --- a/pyrit/prompt_converter/ask_to_decode_converter.py +++ b/pyrit/prompt_converter/ask_to_decode_converter.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import random -from typing import Optional from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter @@ -39,7 +38,7 @@ class AskToDecodeConverter(PromptConverter): all_templates = garak_templates + extra_templates - def __init__(self, template: Optional[str] = None, encoding_name: str = "cipher") -> None: + def __init__(self, template: str | None = None, encoding_name: str = "cipher") -> None: """ Initialize the converter with a specified encoding name and template. diff --git a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py index 9f355ce158..1d1c6abc3a 100644 --- a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py +++ b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py @@ -5,7 +5,7 @@ import logging import time from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: import azure.cognitiveservices.speech as speechsdk # noqa: F401 @@ -47,10 +47,10 @@ class AzureSpeechAudioToTextConverter(PromptConverter): def __init__( self, *, - azure_speech_region: Optional[str] = None, - azure_speech_key: Optional[str | Callable[[], str | Awaitable[str]]] = None, - azure_speech_resource_id: Optional[str] = None, - use_entra_auth: Optional[bool] = None, + azure_speech_region: str | None = None, + azure_speech_key: str | Callable[[], str | Awaitable[str]] | None = None, + azure_speech_resource_id: str | None = None, + use_entra_auth: bool | None = None, recognition_language: str = "en-US", ) -> None: """ diff --git a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py index ffb8934b36..de0eaa5b2f 100644 --- a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py +++ b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py @@ -3,7 +3,7 @@ import logging from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: import azure.cognitiveservices.speech as speechsdk # noqa: F401 @@ -48,10 +48,10 @@ class AzureSpeechTextToAudioConverter(PromptConverter): def __init__( self, *, - azure_speech_region: Optional[str] = None, - azure_speech_key: Optional[str | Callable[[], str | Awaitable[str]]] = None, - azure_speech_resource_id: Optional[str] = None, - use_entra_auth: Optional[bool] = None, + azure_speech_region: str | None = None, + azure_speech_key: str | Callable[[], str | Awaitable[str]] | None = None, + azure_speech_resource_id: str | None = None, + use_entra_auth: bool | None = None, synthesis_language: str = "en_US", synthesis_voice_name: str = "en-US-AvaNeural", output_format: AzureSpeechAudioFormat = "wav", diff --git a/pyrit/prompt_converter/bin_ascii_converter.py b/pyrit/prompt_converter/bin_ascii_converter.py index f06971d9ee..e958774942 100644 --- a/pyrit/prompt_converter/bin_ascii_converter.py +++ b/pyrit/prompt_converter/bin_ascii_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import binascii -from typing import Literal, Optional +from typing import Literal from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import ( @@ -29,8 +29,8 @@ def __init__( self, *, encoding_func: EncodingFunc = "hex", - word_selection_strategy: Optional[WordSelectionStrategy] = None, - word_split_separator: Optional[str] = " ", + word_selection_strategy: WordSelectionStrategy | None = None, + word_split_separator: str | None = " ", ) -> None: """ Initialize the BinAsciiConverter. diff --git a/pyrit/prompt_converter/binary_converter.py b/pyrit/prompt_converter/binary_converter.py index c17ce73234..a1fa402438 100644 --- a/pyrit/prompt_converter/binary_converter.py +++ b/pyrit/prompt_converter/binary_converter.py @@ -4,7 +4,7 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.prompt_converter.word_level_converter import WordLevelConverter @@ -29,7 +29,7 @@ def __init__( self, *, bits_per_char: BinaryConverter.BitsPerChar = BitsPerChar.BITS_16, - word_selection_strategy: Optional[WordSelectionStrategy] = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with the specified bits per character and selection strategy. diff --git a/pyrit/prompt_converter/charswap_attack_converter.py b/pyrit/prompt_converter/charswap_attack_converter.py index cf91c75a88..cb7e669e11 100644 --- a/pyrit/prompt_converter/charswap_attack_converter.py +++ b/pyrit/prompt_converter/charswap_attack_converter.py @@ -3,7 +3,6 @@ import random import string -from typing import Optional from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import ( @@ -22,7 +21,7 @@ def __init__( self, *, max_iterations: int = 10, - word_selection_strategy: Optional[WordSelectionStrategy] = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with the specified parameters. diff --git a/pyrit/prompt_converter/codechameleon_converter.py b/pyrit/prompt_converter/codechameleon_converter.py index a8d932ec6a..6079cd18f6 100644 --- a/pyrit/prompt_converter/codechameleon_converter.py +++ b/pyrit/prompt_converter/codechameleon_converter.py @@ -7,7 +7,7 @@ import re import textwrap from collections.abc import Callable -from typing import Any, Optional +from typing import Any from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt @@ -52,8 +52,8 @@ def __init__( self, *, encrypt_type: str, - encrypt_function: Optional[Callable[..., Any]] = None, - decrypt_function: Optional[Callable[..., Any] | list[Callable[..., Any] | str]] = None, + encrypt_function: Callable[..., Any] | None = None, + decrypt_function: Callable[..., Any] | list[Callable[..., Any] | str] | None = None, ) -> None: """ Initialize the converter with the specified encryption type and optional functions. @@ -163,10 +163,10 @@ class TreeNode: def __init__(self, value: str) -> None: self.value = value - self.left: Optional[TreeNode] = None - self.right: Optional[TreeNode] = None + self.left: TreeNode | None = None + self.right: TreeNode | None = None - def build_tree(words: list[str], start: int, end: int) -> Optional[TreeNode]: + def build_tree(words: list[str], start: int, end: int) -> TreeNode | None: """ Recursively build a balanced binary tree from a sublist of words. @@ -189,7 +189,7 @@ def build_tree(words: list[str], start: int, end: int) -> Optional[TreeNode]: return node - def tree_to_json(node: Optional[TreeNode]) -> Optional[dict[str, Any]]: + def tree_to_json(node: TreeNode | None) -> dict[str, Any] | None: """ Convert a tree to a JSON representation. diff --git a/pyrit/prompt_converter/colloquial_wordswap_converter.py b/pyrit/prompt_converter/colloquial_wordswap_converter.py index e6118a4040..9f8d0a204d 100644 --- a/pyrit/prompt_converter/colloquial_wordswap_converter.py +++ b/pyrit/prompt_converter/colloquial_wordswap_converter.py @@ -4,7 +4,6 @@ import pathlib import random import re -from typing import Optional import yaml @@ -28,8 +27,8 @@ def __init__( self, *, deterministic: bool = False, - custom_substitutions: Optional[dict[str, list[str]]] = None, - wordswap_path: Optional[str] = None, + custom_substitutions: dict[str, list[str]] | None = None, + wordswap_path: str | None = None, ) -> None: """ Initialize the converter with optional deterministic mode and substitutions source. diff --git a/pyrit/prompt_converter/denylist_converter.py b/pyrit/prompt_converter/denylist_converter.py index b9a0ee2c7c..aad3854dcc 100644 --- a/pyrit/prompt_converter/denylist_converter.py +++ b/pyrit/prompt_converter/denylist_converter.py @@ -3,7 +3,6 @@ import logging import pathlib -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -27,7 +26,7 @@ def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - system_prompt_template: Optional[SeedPrompt] = None, + system_prompt_template: SeedPrompt | None = None, denylist: list[str] | None = None, ) -> None: """ diff --git a/pyrit/prompt_converter/first_letter_converter.py b/pyrit/prompt_converter/first_letter_converter.py index ce6058ac68..fa610b54e5 100644 --- a/pyrit/prompt_converter/first_letter_converter.py +++ b/pyrit/prompt_converter/first_letter_converter.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy @@ -18,7 +17,7 @@ def __init__( self, *, letter_separator: str = " ", - word_selection_strategy: Optional[WordSelectionStrategy] = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with the specified letter separator and selection strategy. diff --git a/pyrit/prompt_converter/image_compression_converter.py b/pyrit/prompt_converter/image_compression_converter.py index 57fc13b856..da8c01e91f 100644 --- a/pyrit/prompt_converter/image_compression_converter.py +++ b/pyrit/prompt_converter/image_compression_converter.py @@ -4,7 +4,7 @@ import base64 import logging from io import BytesIO -from typing import Any, Literal, Optional +from typing import Any, Literal from urllib.parse import urlparse import aiohttp @@ -49,13 +49,13 @@ class ImageCompressionConverter(PromptConverter): def __init__( self, *, - output_format: Optional[Literal["JPEG", "PNG", "WEBP"]] = None, - quality: Optional[int] = None, - optimize: Optional[bool] = None, - progressive: Optional[bool] = None, - compress_level: Optional[int] = None, - lossless: Optional[bool] = None, - method: Optional[int] = None, + output_format: Literal["JPEG", "PNG", "WEBP"] | None = None, + quality: int | None = None, + optimize: bool | None = None, + progressive: bool | None = None, + compress_level: int | None = None, + lossless: bool | None = None, + method: int | None = None, background_color: tuple[int, int, int] = (0, 0, 0), min_compression_threshold: int = 1024, fallback_to_original: bool = True, diff --git a/pyrit/prompt_converter/insert_punctuation_converter.py b/pyrit/prompt_converter/insert_punctuation_converter.py index 833f25d124..2049fb4675 100644 --- a/pyrit/prompt_converter/insert_punctuation_converter.py +++ b/pyrit/prompt_converter/insert_punctuation_converter.py @@ -4,7 +4,6 @@ import random import re import string -from typing import Optional from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter @@ -72,7 +71,7 @@ def _is_valid_punctuation(self, punctuation_list: list[str]) -> bool: return all(char in string.punctuation for char in punctuation_list) async def convert_async( - self, *, prompt: str, input_type: PromptDataType = "text", punctuation_list: Optional[list[str]] = None + self, *, prompt: str, input_type: PromptDataType = "text", punctuation_list: list[str] | None = None ) -> ConverterResult: """ Convert the given prompt by inserting punctuation. diff --git a/pyrit/prompt_converter/leetspeak_converter.py b/pyrit/prompt_converter/leetspeak_converter.py index a904d804a7..7d0067aaef 100644 --- a/pyrit/prompt_converter/leetspeak_converter.py +++ b/pyrit/prompt_converter/leetspeak_converter.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import random -from typing import Optional from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy @@ -18,8 +17,8 @@ def __init__( self, *, deterministic: bool = True, - custom_substitutions: Optional[dict[str, list[str]]] = None, - word_selection_strategy: Optional[WordSelectionStrategy] = None, + custom_substitutions: dict[str, list[str]] | None = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with optional deterministic mode and custom substitutions. diff --git a/pyrit/prompt_converter/malicious_question_generator_converter.py b/pyrit/prompt_converter/malicious_question_generator_converter.py index 7e8b64a0d8..e35270bce3 100644 --- a/pyrit/prompt_converter/malicious_question_generator_converter.py +++ b/pyrit/prompt_converter/malicious_question_generator_converter.py @@ -3,7 +3,6 @@ import logging import pathlib -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -27,7 +26,7 @@ def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with a specific target and template. diff --git a/pyrit/prompt_converter/math_obfuscation_converter.py b/pyrit/prompt_converter/math_obfuscation_converter.py index 870d7fc03a..7bc0a9b4f7 100644 --- a/pyrit/prompt_converter/math_obfuscation_converter.py +++ b/pyrit/prompt_converter/math_obfuscation_converter.py @@ -3,7 +3,6 @@ import logging import random -from typing import Optional from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter @@ -44,9 +43,9 @@ def __init__( *, min_n: int = 2, max_n: int = 9, - hint: Optional[str] = None, - suffix: Optional[str] = None, - rng: Optional[random.Random] = None, + hint: str | None = None, + suffix: str | None = None, + rng: random.Random | None = None, ) -> None: """ Initialize a MathObfuscationConverter instance. diff --git a/pyrit/prompt_converter/math_prompt_converter.py b/pyrit/prompt_converter/math_prompt_converter.py index ce9a4a4246..0fa29f4294 100644 --- a/pyrit/prompt_converter/math_prompt_converter.py +++ b/pyrit/prompt_converter/math_prompt_converter.py @@ -3,7 +3,6 @@ import logging import pathlib -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -27,7 +26,7 @@ def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with a specific target and template. diff --git a/pyrit/prompt_converter/noise_converter.py b/pyrit/prompt_converter/noise_converter.py index cfa4b83fd7..77e418e731 100644 --- a/pyrit/prompt_converter/noise_converter.py +++ b/pyrit/prompt_converter/noise_converter.py @@ -4,7 +4,6 @@ import logging import pathlib import textwrap -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -27,9 +26,9 @@ def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - noise: Optional[str] = None, + noise: str | None = None, number_errors: int = 5, - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with the specified parameters. diff --git a/pyrit/prompt_converter/pdf_converter.py b/pyrit/prompt_converter/pdf_converter.py index 4017f426e9..0e30a4d08f 100644 --- a/pyrit/prompt_converter/pdf_converter.py +++ b/pyrit/prompt_converter/pdf_converter.py @@ -5,7 +5,7 @@ import hashlib from io import BytesIO from pathlib import Path -from typing import Any, Optional +from typing import Any from pypdf import PageObject, PdfReader, PdfWriter from reportlab.lib.units import mm @@ -38,7 +38,7 @@ class PDFConverter(PromptConverter): def __init__( self, - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, font_type: str = "Helvetica", font_size: int = 12, font_color: tuple[int, int, int] = (255, 255, 255), @@ -46,8 +46,8 @@ def __init__( page_height: int = 297, column_width: int = 0, row_height: int = 10, - existing_pdf: Optional[Path] = None, - injection_items: Optional[list[dict[str, Any]]] = None, + existing_pdf: Path | None = None, + injection_items: list[dict[str, Any]] | None = None, ) -> None: """ Initialize the converter with the specified parameters. @@ -78,9 +78,9 @@ def __init__( self._row_height = row_height # Keeping the user's path here - self._existing_pdf_path: Optional[Path] = existing_pdf + self._existing_pdf_path: Path | None = existing_pdf # We store the file data in a separate BytesIO for type checker compatibility - self._existing_pdf_bytes: Optional[BytesIO] = None + self._existing_pdf_bytes: BytesIO | None = None self._injection_items = injection_items or [] diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 0e6916f371..2f2375e1f2 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -6,7 +6,7 @@ import inspect import re from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, ClassVar, Optional, get_args from pyrit import prompt_converter from pyrit.models import ComponentIdentifier, Identifiable, PromptDataType @@ -56,7 +56,7 @@ class PromptConverter(Identifiable): #: ``super().__init__(converter_target=...)`` so the base class can validate it. TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() - _identifier: Optional[ComponentIdentifier] = None + _identifier: ComponentIdentifier | None = None def __init_subclass__(cls, **kwargs: object) -> None: """ @@ -195,8 +195,8 @@ def _build_identifier(self) -> ComponentIdentifier: def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct and return the converter identifier. diff --git a/pyrit/prompt_converter/qr_code_converter.py b/pyrit/prompt_converter/qr_code_converter.py index cc1424d14f..b77487b293 100644 --- a/pyrit/prompt_converter/qr_code_converter.py +++ b/pyrit/prompt_converter/qr_code_converter.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional import segno @@ -21,11 +20,11 @@ def __init__( border: int = 4, dark_color: tuple[int, int, int] = (0, 0, 0), light_color: tuple[int, int, int] = (255, 255, 255), - data_dark_color: Optional[tuple[int, int, int]] = None, - data_light_color: Optional[tuple[int, int, int]] = None, - finder_dark_color: Optional[tuple[int, int, int]] = None, - finder_light_color: Optional[tuple[int, int, int]] = None, - border_color: Optional[tuple[int, int, int]] = None, + data_dark_color: tuple[int, int, int] | None = None, + data_light_color: tuple[int, int, int] | None = None, + finder_dark_color: tuple[int, int, int] | None = None, + finder_light_color: tuple[int, int, int] | None = None, + border_color: tuple[int, int, int] | None = None, ) -> None: """ Initialize the converter with specified parameters for QR code generation. diff --git a/pyrit/prompt_converter/random_translation_converter.py b/pyrit/prompt_converter/random_translation_converter.py index 769cb51611..239fc6c42c 100644 --- a/pyrit/prompt_converter/random_translation_converter.py +++ b/pyrit/prompt_converter/random_translation_converter.py @@ -4,7 +4,6 @@ import logging import random from pathlib import Path -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH, DATASETS_PATH @@ -36,9 +35,9 @@ def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - system_prompt_template: Optional[SeedPrompt] = None, - languages: Optional[list[str]] = None, - word_selection_strategy: Optional[WordSelectionStrategy] = None, + system_prompt_template: SeedPrompt | None = None, + languages: list[str] | None = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with a target, an optional system prompt template, and language options. diff --git a/pyrit/prompt_converter/repeat_token_converter.py b/pyrit/prompt_converter/repeat_token_converter.py index c2b5a74825..c711448ca4 100644 --- a/pyrit/prompt_converter/repeat_token_converter.py +++ b/pyrit/prompt_converter/repeat_token_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import re -from typing import Literal, Optional +from typing import Literal from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter @@ -35,7 +35,7 @@ def __init__( *, token_to_repeat: str, times_to_repeat: int, - token_insert_mode: Optional[Literal["split", "prepend", "append", "repeat"]] = None, + token_insert_mode: Literal["split", "prepend", "append", "repeat"] | None = None, ) -> None: """ Initialize the converter with the specified token, number of repetitions, and insertion mode. diff --git a/pyrit/prompt_converter/scientific_translation_converter.py b/pyrit/prompt_converter/scientific_translation_converter.py index 85e05428fa..b4229a5226 100644 --- a/pyrit/prompt_converter/scientific_translation_converter.py +++ b/pyrit/prompt_converter/scientific_translation_converter.py @@ -3,7 +3,7 @@ import logging import pathlib -from typing import Literal, Optional, get_args +from typing import Literal, get_args from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -46,7 +46,7 @@ def __init__( *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] mode: str = "combined", - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the scientific translation converter. diff --git a/pyrit/prompt_converter/string_join_converter.py b/pyrit/prompt_converter/string_join_converter.py index cd961fd65f..e927e144b6 100644 --- a/pyrit/prompt_converter/string_join_converter.py +++ b/pyrit/prompt_converter/string_join_converter.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy @@ -17,7 +16,7 @@ def __init__( self, *, join_value: str = "-", - word_selection_strategy: Optional[WordSelectionStrategy] = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with the specified join value and selection strategy. diff --git a/pyrit/prompt_converter/template_segment_converter.py b/pyrit/prompt_converter/template_segment_converter.py index ed21b0434e..b9bafbce34 100644 --- a/pyrit/prompt_converter/template_segment_converter.py +++ b/pyrit/prompt_converter/template_segment_converter.py @@ -5,7 +5,6 @@ import logging import pathlib import random -from typing import Optional from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt @@ -28,7 +27,7 @@ class TemplateSegmentConverter(PromptConverter): def __init__( self, *, - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with the specified target and prompt template. diff --git a/pyrit/prompt_converter/tense_converter.py b/pyrit/prompt_converter/tense_converter.py index f8f20468d6..8f0852b2c2 100644 --- a/pyrit/prompt_converter/tense_converter.py +++ b/pyrit/prompt_converter/tense_converter.py @@ -3,7 +3,6 @@ import logging import pathlib -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -27,7 +26,7 @@ def __init__( *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] tense: str, - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with the target chat support, tense, and optional prompt template. diff --git a/pyrit/prompt_converter/text_selection_strategy.py b/pyrit/prompt_converter/text_selection_strategy.py index 6dbc8a8b63..641ad27d96 100644 --- a/pyrit/prompt_converter/text_selection_strategy.py +++ b/pyrit/prompt_converter/text_selection_strategy.py @@ -5,7 +5,6 @@ import random import re from re import Pattern -from typing import Optional, Union class TextSelectionStrategy(abc.ABC): @@ -133,7 +132,7 @@ class IndexSelectionStrategy(TextSelectionStrategy): Selects text based on absolute character indices. """ - def __init__(self, *, start: int = 0, end: Optional[int] = None) -> None: + def __init__(self, *, start: int = 0, end: int | None = None) -> None: """ Initialize the index selection strategy. @@ -165,7 +164,7 @@ class RegexSelectionStrategy(TextSelectionStrategy): Selects text based on the first regex match. """ - def __init__(self, *, pattern: Union[str, Pattern[str]]) -> None: + def __init__(self, *, pattern: str | Pattern[str]) -> None: """ Initialize the regex selection strategy. @@ -290,7 +289,7 @@ class ProportionSelectionStrategy(TextSelectionStrategy): Selects a proportion of text anchored to a specific position (start, end, middle, or random). """ - def __init__(self, *, proportion: float, anchor: str = "start", seed: Optional[int] = None) -> None: + def __init__(self, *, proportion: float, anchor: str = "start", seed: int | None = None) -> None: """ Initialize the proportion selection strategy. @@ -473,7 +472,7 @@ class WordProportionSelectionStrategy(WordSelectionStrategy): Selects a random proportion of words. """ - def __init__(self, *, proportion: float, seed: Optional[int] = None) -> None: + def __init__(self, *, proportion: float, seed: int | None = None) -> None: """ Initialize the word proportion selection strategy. @@ -515,7 +514,7 @@ class WordRegexSelectionStrategy(WordSelectionStrategy): Selects words that match a regex pattern. """ - def __init__(self, *, pattern: Union[str, Pattern[str]]) -> None: + def __init__(self, *, pattern: str | Pattern[str]) -> None: """ Initialize the word regex selection strategy. diff --git a/pyrit/prompt_converter/token_smuggling/sneaky_bits_smuggler_converter.py b/pyrit/prompt_converter/token_smuggling/sneaky_bits_smuggler_converter.py index 503c9dab9b..c4fc1bbaec 100644 --- a/pyrit/prompt_converter/token_smuggling/sneaky_bits_smuggler_converter.py +++ b/pyrit/prompt_converter/token_smuggling/sneaky_bits_smuggler_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Literal, Optional +from typing import Literal from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.token_smuggling.base import SmugglerConverter @@ -25,8 +25,8 @@ class SneakyBitsSmugglerConverter(SmugglerConverter): def __init__( self, action: Literal["encode", "decode"] = "encode", - zero_char: Optional[str] = None, - one_char: Optional[str] = None, + zero_char: str | None = None, + one_char: str | None = None, ) -> None: """ Initialize the converter with options for encoding/decoding in Sneaky Bits mode. diff --git a/pyrit/prompt_converter/token_smuggling/variation_selector_smuggler_converter.py b/pyrit/prompt_converter/token_smuggling/variation_selector_smuggler_converter.py index f70ee4902b..077c91db84 100644 --- a/pyrit/prompt_converter/token_smuggling/variation_selector_smuggler_converter.py +++ b/pyrit/prompt_converter/token_smuggling/variation_selector_smuggler_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Literal, Optional +from typing import Literal from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.token_smuggling.base import SmugglerConverter @@ -32,7 +32,7 @@ class VariationSelectorSmugglerConverter(SmugglerConverter): def __init__( self, action: Literal["encode", "decode"] = "encode", - base_char_utf8: Optional[str] = None, + base_char_utf8: str | None = None, embed_in_base: bool = True, ) -> None: """ diff --git a/pyrit/prompt_converter/tone_converter.py b/pyrit/prompt_converter/tone_converter.py index c21a118603..562a4ee6af 100644 --- a/pyrit/prompt_converter/tone_converter.py +++ b/pyrit/prompt_converter/tone_converter.py @@ -3,7 +3,6 @@ import logging import pathlib -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -27,7 +26,7 @@ def __init__( *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] tone: str, - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with the target chat support, tone, and optional prompt template. diff --git a/pyrit/prompt_converter/toxic_sentence_generator_converter.py b/pyrit/prompt_converter/toxic_sentence_generator_converter.py index 3159cf1de7..07cfc3744e 100644 --- a/pyrit/prompt_converter/toxic_sentence_generator_converter.py +++ b/pyrit/prompt_converter/toxic_sentence_generator_converter.py @@ -7,7 +7,6 @@ import logging import pathlib -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -35,7 +34,7 @@ def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with a specific target and template. diff --git a/pyrit/prompt_converter/unicode_replacement_converter.py b/pyrit/prompt_converter/unicode_replacement_converter.py index 71a0e52e54..150499ab72 100644 --- a/pyrit/prompt_converter/unicode_replacement_converter.py +++ b/pyrit/prompt_converter/unicode_replacement_converter.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy @@ -17,7 +16,7 @@ def __init__( self, *, encode_spaces: bool = False, - word_selection_strategy: Optional[WordSelectionStrategy] = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with the specified selection strategy. diff --git a/pyrit/prompt_converter/word_doc_converter.py b/pyrit/prompt_converter/word_doc_converter.py index 1dc5d00b5c..b811ab1b3f 100644 --- a/pyrit/prompt_converter/word_doc_converter.py +++ b/pyrit/prompt_converter/word_doc_converter.py @@ -7,7 +7,7 @@ import hashlib from dataclasses import dataclass from io import BytesIO -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from docx import Document @@ -26,7 +26,7 @@ class _WordDocInjectionConfig: """Configuration for how to inject content into a Word document.""" - existing_docx: Optional[Path] + existing_docx: Path | None placeholder: str @@ -66,8 +66,8 @@ class WordDocConverter(PromptConverter): def __init__( self, *, - prompt_template: Optional[SeedPrompt] = None, - existing_docx: Optional[Path] = None, + prompt_template: SeedPrompt | None = None, + existing_docx: Path | None = None, placeholder: str = "{{INJECTION_PLACEHOLDER}}", ) -> None: """ @@ -112,7 +112,7 @@ def _build_identifier(self) -> ComponentIdentifier: Returns: ComponentIdentifier: The identifier with converter-specific parameters. """ - template_hash: Optional[str] = None + template_hash: str | None = None if self._prompt_template: template_hash = hashlib.sha256(str(self._prompt_template.value).encode("utf-8")).hexdigest()[:16] diff --git a/pyrit/prompt_converter/word_level_converter.py b/pyrit/prompt_converter/word_level_converter.py index 5a2f874f0c..2ca186014a 100644 --- a/pyrit/prompt_converter/word_level_converter.py +++ b/pyrit/prompt_converter/word_level_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import abc -from typing import Any, Optional +from typing import Any from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter @@ -30,8 +30,8 @@ class WordLevelConverter(PromptConverter): def __init__( self, *, - word_selection_strategy: Optional[WordSelectionStrategy] = None, - word_split_separator: Optional[str] = " ", + word_selection_strategy: WordSelectionStrategy | None = None, + word_split_separator: str | None = " ", **kwargs: Any, ) -> None: """ diff --git a/pyrit/prompt_converter/zalgo_converter.py b/pyrit/prompt_converter/zalgo_converter.py index ddd7c686ec..4da331cf79 100644 --- a/pyrit/prompt_converter/zalgo_converter.py +++ b/pyrit/prompt_converter/zalgo_converter.py @@ -3,7 +3,6 @@ import logging import random -from typing import Optional from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy @@ -25,8 +24,8 @@ def __init__( self, *, intensity: int = 10, - seed: Optional[int] = None, - word_selection_strategy: Optional[WordSelectionStrategy] = None, + seed: int | None = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with the specified selection parameters. diff --git a/pyrit/prompt_normalizer/normalizer_request.py b/pyrit/prompt_normalizer/normalizer_request.py index c030ca5278..e553fe8637 100644 --- a/pyrit/prompt_normalizer/normalizer_request.py +++ b/pyrit/prompt_normalizer/normalizer_request.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. from dataclasses import dataclass -from typing import Optional from pyrit.models import Message from pyrit.prompt_normalizer.prompt_converter_configuration import ( @@ -27,7 +26,7 @@ def __init__( message: Message, request_converter_configurations: list[PromptConverterConfiguration] | None = None, response_converter_configurations: list[PromptConverterConfiguration] | None = None, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, ) -> None: """ Initialize a normalizer request. diff --git a/pyrit/prompt_normalizer/prompt_converter_configuration.py b/pyrit/prompt_normalizer/prompt_converter_configuration.py index cb9ae55425..3ba455af64 100644 --- a/pyrit/prompt_normalizer/prompt_converter_configuration.py +++ b/pyrit/prompt_normalizer/prompt_converter_configuration.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. from dataclasses import dataclass -from typing import Optional from pyrit.models import PromptDataType from pyrit.prompt_converter import PromptConverter @@ -19,8 +18,8 @@ class PromptConverterConfiguration: """ converters: list[PromptConverter] - indexes_to_apply: Optional[list[int]] = None - prompt_data_types_to_apply: Optional[list[PromptDataType]] = None + indexes_to_apply: list[int] | None = None + prompt_data_types_to_apply: list[PromptDataType] | None = None @classmethod def from_converters(cls, *, converters: list[PromptConverter]) -> list["PromptConverterConfiguration"]: diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index 9c23f91cf3..95196db642 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -9,7 +9,7 @@ import traceback import wave from pathlib import Path -from typing import Any, Optional +from typing import Any from uuid import uuid4 from pyrit.common.deprecation import print_deprecation_message @@ -68,11 +68,11 @@ async def send_prompt_async( *, message: Message, target: PromptTarget, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, request_converter_configurations: list[PromptConverterConfiguration] | None = None, response_converter_configurations: list[PromptConverterConfiguration] | None = None, - labels: Optional[dict[str, str]] = None, - attack_identifier: Optional[ComponentIdentifier] = None, + labels: dict[str, str] | None = None, + attack_identifier: ComponentIdentifier | None = None, ) -> Message: """ Send a single request to a target. @@ -196,8 +196,8 @@ async def send_prompt_batch_to_target_async( *, requests: list[NormalizerRequest], target: PromptTarget, - labels: Optional[dict[str, str]] = None, - attack_identifier: Optional[ComponentIdentifier] = None, + labels: dict[str, str] | None = None, + attack_identifier: ComponentIdentifier | None = None, batch_size: int = 10, ) -> list[Message]: """ @@ -397,10 +397,10 @@ async def add_prepended_conversation_to_memory_async( self, conversation_id: str, should_convert: bool = True, - converter_configurations: Optional[list[PromptConverterConfiguration]] = None, - attack_identifier: Optional[ComponentIdentifier] = None, - prepended_conversation: Optional[list[Message]] = None, - ) -> Optional[list[Message]]: + converter_configurations: list[PromptConverterConfiguration] | None = None, + attack_identifier: ComponentIdentifier | None = None, + prepended_conversation: list[Message] | None = None, + ) -> list[Message] | None: """ Process the prepended conversation by converting it if needed and adding it to memory. @@ -454,10 +454,10 @@ async def add_prepended_conversation_to_memory( # pyrit-async-suffix-exempt self, conversation_id: str, should_convert: bool = True, - converter_configurations: Optional[list[PromptConverterConfiguration]] = None, - attack_identifier: Optional[ComponentIdentifier] = None, - prepended_conversation: Optional[list[Message]] = None, - ) -> Optional[list[Message]]: + converter_configurations: list[PromptConverterConfiguration] | None = None, + attack_identifier: ComponentIdentifier | None = None, + prepended_conversation: list[Message] | None = None, + ) -> list[Message] | None: """ Use ``add_prepended_conversation_to_memory_async`` instead; this is a deprecated alias. diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index 184cb1664e..b3b10a04e2 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -3,7 +3,6 @@ import logging from enum import Enum -from typing import Optional from urllib.parse import urlparse from azure.core.exceptions import ClientAuthenticationError @@ -69,11 +68,11 @@ class AzureBlobStorageTarget(PromptTarget): def __init__( self, *, - container_url: Optional[str] = None, - sas_token: Optional[str] = None, + container_url: str | None = None, + sas_token: str | None = None, blob_content_type: SupportedContentType = SupportedContentType.PLAIN_TEXT, - max_requests_per_minute: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, + max_requests_per_minute: int | None = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize the Azure Blob Storage target. @@ -95,8 +94,8 @@ def __init__( env_var_name=self.AZURE_STORAGE_CONTAINER_ENVIRONMENT_VARIABLE, passed_value=container_url ) - self._sas_token: Optional[str] = sas_token - self._client_async: Optional[AsyncContainerClient] = None + self._sas_token: str | None = sas_token + self._client_async: AsyncContainerClient | None = None super().__init__( endpoint=self._container_url, diff --git a/pyrit/prompt_target/batch_helper.py b/pyrit/prompt_target/batch_helper.py index 95ec6809fb..399465e4fb 100644 --- a/pyrit/prompt_target/batch_helper.py +++ b/pyrit/prompt_target/batch_helper.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Callable, Generator, Sequence -from typing import Any, Optional +from typing import Any from pyrit.prompt_target.common.prompt_target import PromptTarget @@ -31,7 +31,7 @@ def _get_chunks(*args: Sequence[Any], batch_size: int) -> Generator[list[Sequenc yield [arg[i : i + batch_size] for arg in args] -def _validate_rate_limit_parameters(prompt_target: Optional[PromptTarget], batch_size: int) -> None: +def _validate_rate_limit_parameters(prompt_target: PromptTarget | None, batch_size: int) -> None: """ Validate the constraints between Rate Limit (Requests Per Minute) and batch size. @@ -49,7 +49,7 @@ def _validate_rate_limit_parameters(prompt_target: Optional[PromptTarget], batch async def batch_task_async( *, - prompt_target: Optional[PromptTarget] = None, + prompt_target: PromptTarget | None = None, batch_size: int, items_to_batch: Sequence[Sequence[Any]], task_func: Callable[..., Any], diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index aa13658aac..4c6c81fa3c 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -3,7 +3,7 @@ import abc import logging -from typing import Any, Union, final +from typing import Any, final from pyrit.common.deprecation import print_deprecation_message from pyrit.memory import CentralMemory, MemoryInterface @@ -319,7 +319,7 @@ def _create_identifier( self, *, params: dict[str, Any] | None = None, - children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct the target identifier. diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index 7f1010745f..719c441d38 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from enum import Enum from types import MappingProxyType -from typing import NoReturn, Optional, cast +from typing import NoReturn, cast from pyrit.models import PromptDataType @@ -165,7 +165,7 @@ def includes(self, *, capability: CapabilityName) -> bool: return bool(getattr(self, capability.value)) @staticmethod - def get_known_capabilities(underlying_model: str) -> "Optional[TargetCapabilities]": + def get_known_capabilities(underlying_model: str) -> "TargetCapabilities | None": """ Return the known capabilities for a specific underlying model, or None if unrecognized. diff --git a/pyrit/prompt_target/common/utils.py b/pyrit/prompt_target/common/utils.py index 2aaa10bb68..6b6a74a813 100644 --- a/pyrit/prompt_target/common/utils.py +++ b/pyrit/prompt_target/common/utils.py @@ -3,12 +3,12 @@ import asyncio from collections.abc import Callable -from typing import Any, Optional +from typing import Any from pyrit.exceptions import PyritException -def validate_temperature(temperature: Optional[float]) -> None: +def validate_temperature(temperature: float | None) -> None: """ Validate that temperature parameter is within valid range. @@ -22,7 +22,7 @@ def validate_temperature(temperature: Optional[float]) -> None: raise PyritException(message="temperature must be between 0 and 2 (inclusive).") -def validate_top_p(top_p: Optional[float]) -> None: +def validate_top_p(top_p: float | None) -> None: """ Validate that top_p parameter is within valid range. diff --git a/pyrit/prompt_target/gandalf_target.py b/pyrit/prompt_target/gandalf_target.py index 1bad549c7f..b2dc6e342a 100644 --- a/pyrit/prompt_target/gandalf_target.py +++ b/pyrit/prompt_target/gandalf_target.py @@ -4,7 +4,6 @@ import enum import json import logging -from typing import Optional from pyrit.common import net_utility from pyrit.common.deprecation import print_deprecation_message @@ -43,8 +42,8 @@ def __init__( self, *, level: GandalfLevel, - max_requests_per_minute: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, + max_requests_per_minute: int | None = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize the Gandalf target. diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index 6ed37c0da3..ce3028ca89 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -6,7 +6,7 @@ import logging import re from collections.abc import Callable -from typing import Any, Optional +from typing import Any import httpx @@ -37,11 +37,11 @@ def __init__( http_request: str, prompt_regex_string: str = "{PROMPT}", use_tls: bool = True, - callback_function: Optional[Callable[..., Any]] = None, - max_requests_per_minute: Optional[int] = None, - client: Optional[httpx.AsyncClient] = None, + callback_function: Callable[..., Any] | None = None, + max_requests_per_minute: int | None = None, + client: httpx.AsyncClient | None = None, model_name: str = "", - custom_configuration: Optional[TargetConfiguration] = None, + custom_configuration: TargetConfiguration | None = None, **httpx_client_kwargs: Any, ) -> None: """ @@ -108,7 +108,7 @@ def with_client( http_request: str, prompt_regex_string: str = "{PROMPT}", callback_function: Callable[..., Any] | None = None, - max_requests_per_minute: Optional[int] = None, + max_requests_per_minute: int | None = None, ) -> "HTTPTarget": """ Alternative constructor that accepts a pre-configured httpx client. diff --git a/pyrit/prompt_target/http_target/http_target_callback_functions.py b/pyrit/prompt_target/http_target/http_target_callback_functions.py index 90cc7f79a3..8d749af73d 100644 --- a/pyrit/prompt_target/http_target/http_target_callback_functions.py +++ b/pyrit/prompt_target/http_target/http_target_callback_functions.py @@ -5,7 +5,7 @@ import json import re from collections.abc import Callable -from typing import Any, Optional +from typing import Any import requests @@ -42,7 +42,7 @@ def parse_json_http_response(response: requests.Response) -> str: def get_http_target_regex_matching_callback_function( - key: str, url: Optional[str] = None + key: str, url: str | None = None ) -> Callable[[requests.Response], str]: """ Get a callback function that parses HTTP responses using regex matching. diff --git a/pyrit/prompt_target/http_target/httpx_api_target.py b/pyrit/prompt_target/http_target/httpx_api_target.py index bd32fd1fe2..95f0c47124 100644 --- a/pyrit/prompt_target/http_target/httpx_api_target.py +++ b/pyrit/prompt_target/http_target/httpx_api_target.py @@ -5,7 +5,7 @@ import mimetypes from collections.abc import Callable from pathlib import Path -from typing import Any, Literal, Optional +from typing import Any, Literal import aiofiles import httpx @@ -42,15 +42,15 @@ def __init__( *, http_url: str, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"] = "POST", - file_path: Optional[str] = None, - json_data: Optional[dict[str, Any]] = None, - form_data: Optional[dict[str, Any]] = None, - params: Optional[dict[str, Any]] = None, - headers: Optional[dict[str, str]] = None, - http2: Optional[bool] = None, + file_path: str | None = None, + json_data: dict[str, Any] | None = None, + form_data: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + http2: bool | None = None, callback_function: Callable[..., Any] | None = None, - max_requests_per_minute: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, + max_requests_per_minute: int | None = None, + custom_configuration: TargetConfiguration | None = None, **httpx_client_kwargs: Any, ) -> None: """ diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index f82fc40e29..1fd2400ca9 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -5,7 +5,7 @@ import json import logging from collections.abc import MutableSequence -from typing import Any, Optional +from typing import Any from pyrit.common.data_url_converter import convert_local_image_to_data_url_async from pyrit.exceptions import ( @@ -80,17 +80,17 @@ class OpenAIChatTarget(OpenAITarget): def __init__( self, *, - max_completion_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - seed: Optional[int] = None, - n: Optional[int] = None, - audio_response_config: Optional[OpenAIChatAudioConfig] = None, - extra_body_parameters: Optional[dict[str, Any]] = None, - custom_configuration: Optional[TargetConfiguration] = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + top_p: float | None = None, + frequency_penalty: float | None = None, + presence_penalty: float | None = None, + seed: int | None = None, + n: int | None = None, + audio_response_config: OpenAIChatAudioConfig | None = None, + extra_body_parameters: dict[str, Any] | None = None, + custom_configuration: TargetConfiguration | None = None, **kwargs: Any, ) -> None: """ @@ -259,7 +259,7 @@ def _check_content_filter(self, response: Any) -> bool: pass return False - def _extract_partial_content(self, response: Any) -> Optional[str]: + def _extract_partial_content(self, response: Any) -> str | None: """ Extract partial content from a Chat Completions response with finish_reason=content_filter. @@ -279,7 +279,7 @@ def _extract_partial_content(self, response: Any) -> Optional[str]: pass return None - def _validate_response(self, response: Any, request: MessagePiece) -> Optional[Message]: + def _validate_response(self, response: Any, request: MessagePiece) -> Message | None: """ Validate a Chat Completions API response for errors. @@ -415,7 +415,7 @@ async def _construct_message_from_response_async(self, response: Any, request: M audio_response = message.audio # Add transcript as text piece with metadata - audio_transcript: Optional[str] = getattr(audio_response, "transcript", None) + audio_transcript: str | None = getattr(audio_response, "transcript", None) if audio_transcript: transcript_piece = construct_response_from_request( request=request, @@ -426,7 +426,7 @@ async def _construct_message_from_response_async(self, response: Any, request: M pieces.append(transcript_piece) # Save audio data and add as audio_path piece - audio_data: Optional[str] = getattr(audio_response, "data", None) + audio_data: str | None = getattr(audio_response, "data", None) if audio_data: audio_path = await self._save_audio_response_async(audio_data_base64=audio_data) audio_piece = construct_response_from_request( @@ -676,7 +676,7 @@ async def _construct_request_body_async( # Filter out None values return {k: v for k, v in body_parameters.items() if v is not None} - def _build_response_format(self, json_config: _JsonResponseConfig) -> Optional[dict[str, Any]]: + def _build_response_format(self, json_config: _JsonResponseConfig) -> dict[str, Any] | None: if not json_config.enabled: return None diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index 0f43ac5af8..4c4ff810c3 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Any, Optional +from typing import Any from pyrit.exceptions.exception_classes import ( pyrit_target_retry, @@ -23,13 +23,13 @@ class OpenAICompletionTarget(OpenAITarget): def __init__( self, - max_tokens: Optional[int] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float] = None, - n: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, + max_tokens: int | None = None, + temperature: float | None = None, + top_p: float | None = None, + presence_penalty: float | None = None, + frequency_penalty: float | None = None, + n: int | None = None, + custom_configuration: TargetConfiguration | None = None, *args: Any, **kwargs: Any, ) -> None: diff --git a/pyrit/prompt_target/openai/openai_error_handling.py b/pyrit/prompt_target/openai/openai_error_handling.py index bc0daa77ae..3848675794 100644 --- a/pyrit/prompt_target/openai/openai_error_handling.py +++ b/pyrit/prompt_target/openai/openai_error_handling.py @@ -10,12 +10,11 @@ import json import logging -from typing import Optional, Union logger = logging.getLogger(__name__) -def _extract_request_id_from_exception(exc: Exception) -> Optional[str]: +def _extract_request_id_from_exception(exc: Exception) -> str | None: """ Extract the x-request-id from an OpenAI SDK exception for logging/telemetry. @@ -36,7 +35,7 @@ def _extract_request_id_from_exception(exc: Exception) -> Optional[str]: return None -def _extract_retry_after_from_exception(exc: Exception) -> Optional[float]: +def _extract_retry_after_from_exception(exc: Exception) -> float | None: """ Extract the Retry-After header from a rate-limit exception for intelligent backoff. @@ -61,7 +60,7 @@ def _extract_retry_after_from_exception(exc: Exception) -> Optional[float]: return None -def _is_content_filter_error(data: Union[dict[str, object], str]) -> bool: +def _is_content_filter_error(data: dict[str, object] | str) -> bool: """ Check if error data indicates content filtering. @@ -91,7 +90,7 @@ def _is_content_filter_error(data: Union[dict[str, object], str]) -> bool: return "content_filter" in lower or "policy_violation" in lower or "moderation_blocked" in lower -def _extract_error_payload(exc: Exception) -> tuple[Union[dict[str, object], str], bool]: +def _extract_error_payload(exc: Exception) -> tuple[dict[str, object] | str, bool]: """ Extract error payload and detect content filter from an OpenAI SDK exception. diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index 7ae0383998..fec66aabb4 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import base64 import logging -from typing import Any, Literal, Optional +from typing import Any, Literal import httpx @@ -65,11 +65,11 @@ def __init__( "1792x1024", "1024x1792", ] = "1024x1024", - output_format: Optional[Literal["png", "jpeg", "webp"]] = None, - quality: Optional[Literal["auto", "low", "medium", "high", "standard", "hd"]] = None, - style: Optional[Literal["natural", "vivid"]] = None, - background: Optional[Literal["transparent", "opaque", "auto"]] = None, - custom_configuration: Optional[TargetConfiguration] = None, + output_format: Literal["png", "jpeg", "webp"] | None = None, + quality: Literal["auto", "low", "medium", "high", "standard", "hd"] | None = None, + style: Literal["natural", "vivid"] | None = None, + background: Literal["transparent", "opaque", "auto"] | None = None, + custom_configuration: TargetConfiguration | None = None, *args: Any, **kwargs: Any, ) -> None: diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 0197d5ba64..0c75faa908 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -6,7 +6,7 @@ import logging import re import wave -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Literal from openai import AsyncOpenAI @@ -87,9 +87,9 @@ class RealtimeTarget(OpenAITarget): def __init__( self, *, - voice: Optional[RealTimeVoice] = None, - existing_convo: Optional[dict[str, Any]] = None, - custom_configuration: Optional[TargetConfiguration] = None, + voice: RealTimeVoice | None = None, + existing_convo: dict[str, Any] | None = None, + custom_configuration: TargetConfiguration | None = None, **kwargs: Any, ) -> None: """ @@ -121,7 +121,7 @@ def __init__( self.voice = voice self._existing_conversation = existing_convo if existing_convo is not None else {} - self._realtime_client: Optional[AsyncOpenAI] = None + self._realtime_client: AsyncOpenAI | None = None def open_streaming_session( self, @@ -550,7 +550,7 @@ async def save_audio_async( num_channels: int = 1, sample_width: int = 2, sample_rate: int = 16000, - output_filename: Optional[str] = None, + output_filename: str | None = None, ) -> str: """ Save audio bytes to a WAV file. @@ -583,7 +583,7 @@ async def save_audio( # pyrit-async-suffix-exempt num_channels: int = 1, sample_width: int = 2, sample_rate: int = 16000, - output_filename: Optional[str] = None, + output_filename: str | None = None, ) -> str: """ Use ``save_audio_async`` instead; this is a deprecated alias. diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index c96b0c115a..64428bb6a0 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -8,7 +8,6 @@ from typing import ( Any, Literal, - Optional, cast, ) @@ -91,15 +90,15 @@ class OpenAIResponseTarget(OpenAITarget): def __init__( self, *, - custom_functions: Optional[dict[str, ToolExecutor]] = None, - max_output_tokens: Optional[int] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - reasoning_effort: Optional[ReasoningEffort] = None, - reasoning_summary: Optional[Literal["auto", "concise", "detailed"]] = None, - extra_body_parameters: Optional[dict[str, Any]] = None, + custom_functions: dict[str, ToolExecutor] | None = None, + max_output_tokens: int | None = None, + temperature: float | None = None, + top_p: float | None = None, + reasoning_effort: ReasoningEffort | None = None, + reasoning_summary: Literal["auto", "concise", "detailed"] | None = None, + extra_body_parameters: dict[str, Any] | None = None, fail_on_missing_function: bool = False, - custom_configuration: Optional[TargetConfiguration] = None, + custom_configuration: TargetConfiguration | None = None, **kwargs: Any, ) -> None: """ @@ -396,7 +395,7 @@ async def _construct_request_body_async( # Filter out None values return {k: v for k, v in body_parameters.items() if v is not None} - def _build_reasoning_config(self) -> Optional[dict[str, Any]]: + def _build_reasoning_config(self) -> dict[str, Any] | None: """ Build the reasoning configuration dict for the Responses API. @@ -413,7 +412,7 @@ def _build_reasoning_config(self) -> Optional[dict[str, Any]]: reasoning["summary"] = self._reasoning_summary return reasoning - def _build_text_format(self, json_config: _JsonResponseConfig) -> Optional[dict[str, Any]]: + def _build_text_format(self, json_config: _JsonResponseConfig) -> dict[str, Any] | None: if not json_config.enabled: return None @@ -459,7 +458,7 @@ def _check_content_filter(self, response: Any) -> bool: return False - def _extract_partial_content(self, response: Any) -> Optional[str]: + def _extract_partial_content(self, response: Any) -> str | None: """ Extract partial content from a Response API response that was content-filtered. @@ -493,7 +492,7 @@ def _extract_partial_content(self, response: Any) -> Optional[str]: except (AttributeError, IndexError, TypeError): return None - def _validate_response(self, response: Any, request: MessagePiece) -> Optional[Message]: + def _validate_response(self, response: Any, request: MessagePiece) -> Message | None: """ Validate a Response API response for errors. @@ -584,7 +583,7 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me responses_to_return: list[Message] = [] # Main agentic loop - each back-and-forth creates a new message - tool_call_section: Optional[dict[str, Any]] = None + tool_call_section: dict[str, Any] | None = None while True: logger.info(f"Sending conversation with {len(working_conversation)} messages to the prompt target") @@ -625,7 +624,7 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me return responses_to_return def _parse_response_output_section( - self, *, section: Any, message_piece: MessagePiece, error: Optional[PromptResponseError] + self, *, section: Any, message_piece: MessagePiece, error: PromptResponseError | None ) -> MessagePiece | None: """ Parse model output sections, forwarding tool-calls for the agentic loop. @@ -726,7 +725,7 @@ def _parse_response_output_section( # Agentic helpers (module scope) - def _find_last_pending_tool_call(self, reply: Message) -> Optional[dict[str, Any]]: + def _find_last_pending_tool_call(self, reply: Message) -> dict[str, Any] | None: """ Return the last tool-call section in assistant messages, or None. Looks for a piece whose value parses as JSON with a 'type' key matching function_call. diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 3ae988f186..acc54103f2 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -6,7 +6,7 @@ import re from abc import abstractmethod from collections.abc import Awaitable, Callable -from typing import Any, Optional +from typing import Any from urllib.parse import urlparse from openai import ( @@ -61,7 +61,7 @@ class OpenAITarget(PromptTarget): endpoint_environment_variable: str api_key_environment_variable: str - _async_client: Optional[AsyncOpenAI] = None + _async_client: AsyncOpenAI | None = None @property def _client(self) -> AsyncOpenAI: @@ -78,14 +78,14 @@ def _client(self) -> AsyncOpenAI: def __init__( self, *, - model_name: Optional[str] = None, - endpoint: Optional[str] = None, - api_key: Optional[str | Callable[[], str | Awaitable[str]]] = None, - headers: Optional[str] = None, - max_requests_per_minute: Optional[int] = None, - httpx_client_kwargs: Optional[dict[str, Any]] = None, - underlying_model: Optional[str] = None, - custom_configuration: Optional[TargetConfiguration] = None, + model_name: str | None = None, + endpoint: str | None = None, + api_key: str | Callable[[], str | Awaitable[str]] | None = None, + headers: str | None = None, + max_requests_per_minute: int | None = None, + httpx_client_kwargs: dict[str, Any] | None = None, + underlying_model: str | None = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize an instance of OpenAITarget. @@ -583,7 +583,7 @@ def _handle_content_filter_response(self, response: Any, request: MessagePiece) return error_message - def _extract_partial_content(self, response: Any) -> Optional[str]: + def _extract_partial_content(self, response: Any) -> str | None: """ Extract any partial content the model generated before the content filter triggered. @@ -598,7 +598,7 @@ def _extract_partial_content(self, response: Any) -> Optional[str]: """ return None - def _validate_response(self, response: Any, request: MessagePiece) -> Optional[Message]: + def _validate_response(self, response: Any, request: MessagePiece) -> Message | None: """ Validate the response and return error Message if needed. diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index d9c66128cb..eaa6101c75 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Any, Literal, Optional +from typing import Any, Literal from pyrit.exceptions import ( pyrit_target_retry, @@ -40,8 +40,8 @@ def __init__( voice: TTSVoice = "alloy", response_format: TTSResponseFormat = "mp3", language: str = "en", - speed: Optional[float] = None, - custom_configuration: Optional[TargetConfiguration] = None, + speed: float | None = None, + custom_configuration: TargetConfiguration | None = None, **kwargs: Any, ) -> None: """ diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 37dedc9ae0..db8deadf35 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -4,7 +4,7 @@ import logging from mimetypes import guess_type from pathlib import Path -from typing import Any, Optional, Union, cast +from typing import Any, cast from openai.types import VideoSeconds, VideoSize @@ -67,7 +67,7 @@ def __init__( *, resolution_dimensions: VideoSize = "1280x720", n_seconds: int | VideoSeconds = 4, - custom_configuration: Optional[TargetConfiguration] = None, + custom_configuration: TargetConfiguration | None = None, **kwargs: Any, ) -> None: """ @@ -428,7 +428,7 @@ async def _construct_message_from_response_async(self, response: Any, request: A ) async def _save_video_response_async( - self, *, request: MessagePiece, video_data: bytes, video_id: Optional[str] = None + self, *, request: MessagePiece, video_data: bytes, video_id: str | None = None ) -> Message: """ Save video data to storage and construct response. @@ -449,7 +449,7 @@ async def _save_video_response_async( logger.info(f"Video saved to: {video_path}") # Include video_id in metadata for chaining (e.g., remix the generated video later) - prompt_metadata: Optional[dict[str, Union[str, int]]] = {"video_id": video_id} if video_id else None + prompt_metadata: dict[str, str | int] | None = {"video_id": video_id} if video_id else None # Construct response return construct_response_from_request( diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index a2dfc796b2..0a7021c23c 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -7,7 +7,7 @@ from contextlib import suppress from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from pyrit.models import ( ComponentIdentifier, @@ -128,7 +128,7 @@ def __init__( *, page: "Page", copilot_type: CopilotType = CopilotType.CONSUMER, - custom_configuration: Optional[TargetConfiguration] = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize the Playwright Copilot target. @@ -254,7 +254,7 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me return [response_entry] - async def _interact_with_copilot_async(self, message: Message) -> Union[str, list[tuple[str, PromptDataType]]]: + async def _interact_with_copilot_async(self, message: Message) -> str | list[tuple[str, PromptDataType]]: """ Interact with Microsoft Copilot interface to send multimodal prompts. @@ -276,9 +276,7 @@ async def _interact_with_copilot_async(self, message: Message) -> Union[str, lis return await self._wait_for_response_async(selectors) - async def _wait_for_response_async( - self, selectors: CopilotSelectors - ) -> Union[str, list[tuple[str, PromptDataType]]]: + async def _wait_for_response_async(self, selectors: CopilotSelectors) -> str | list[tuple[str, PromptDataType]]: """ Wait for Copilot's response and extract the text and/or images. @@ -332,7 +330,7 @@ async def _wait_for_response_async( async def _extract_content_if_ready_async( self, selectors: CopilotSelectors, initial_group_count: int - ) -> Union[str, list[tuple[str, PromptDataType]], None]: + ) -> str | list[tuple[str, PromptDataType]] | None: """ Extract content if ready, otherwise return None. @@ -733,7 +731,7 @@ async def _extract_fallback_text_async(self, *, ai_message_groups: list[Any]) -> def _assemble_response( self, *, response_pieces: list[tuple[str, PromptDataType]] - ) -> Union[str, list[tuple[str, PromptDataType]]]: + ) -> str | list[tuple[str, PromptDataType]]: """ Assemble response pieces into appropriate return format. @@ -755,7 +753,7 @@ def _assemble_response( async def _extract_multimodal_content_async( self, selectors: CopilotSelectors, initial_group_count: int = 0 - ) -> Union[str, list[tuple[str, PromptDataType]]]: + ) -> str | list[tuple[str, PromptDataType]]: """ Extract multimodal content (text and images) from Copilot response. diff --git a/pyrit/prompt_target/playwright_target.py b/pyrit/prompt_target/playwright_target.py index 4178fe902b..d33ba2736c 100644 --- a/pyrit/prompt_target/playwright_target.py +++ b/pyrit/prompt_target/playwright_target.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import TYPE_CHECKING, Optional, Protocol +from typing import TYPE_CHECKING, Protocol from pyrit.models import ( Message, @@ -71,8 +71,8 @@ def __init__( *, interaction_func: InteractionFunction, page: "Page", - max_requests_per_minute: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, + max_requests_per_minute: int | None = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize the Playwright target. diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index 74c465a91a..d64af0de57 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -4,7 +4,7 @@ import json import logging from collections.abc import Callable -from typing import Any, Literal, Optional +from typing import Any, Literal from pyrit.common import default_values, net_utility from pyrit.models import ( @@ -56,12 +56,12 @@ class PromptShieldTarget(PromptTarget): def __init__( self, - endpoint: Optional[str] = None, - api_key: Optional[str | Callable[[], str]] = None, - api_version: Optional[str] = "2024-09-01", - field: Optional[PromptShieldEntryField] = None, - max_requests_per_minute: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, + endpoint: str | None = None, + api_key: str | Callable[[], str] | None = None, + api_version: str | None = "2024-09-01", + field: PromptShieldEntryField | None = None, + max_requests_per_minute: int | None = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Class that initializes an Azure Content Safety Prompt Shield Target. diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 9ede8d9ddc..8e0deed295 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -5,7 +5,7 @@ import json import sys from pathlib import Path -from typing import IO, Optional +from typing import IO from pyrit.common.deprecation import print_deprecation_message from pyrit.models import Message, MessagePiece @@ -26,7 +26,7 @@ def __init__( self, *, text_stream: IO[str] = sys.stdout, - custom_configuration: Optional[TargetConfiguration] = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize the TextTarget. diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 6a5de15f60..0635971011 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -7,7 +7,7 @@ import pathlib import uuid from enum import IntEnum -from typing import Any, Optional, Union +from typing import Any import httpx import websockets @@ -90,11 +90,11 @@ def __init__( self, *, websocket_base_url: str = "wss://substrate.office.com/m365Copilot/Chathub", - max_requests_per_minute: Optional[int] = None, + max_requests_per_minute: int | None = None, model_name: str = "copilot", response_timeout_seconds: int = RESPONSE_TIMEOUT_SECONDS, - authenticator: Optional[Union[CopilotAuthenticator, ManualCopilotAuthenticator]] = None, - custom_configuration: Optional[TargetConfiguration] = None, + authenticator: CopilotAuthenticator | ManualCopilotAuthenticator | None = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize the WebSocketCopilotTarget. diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py index cdff9067f1..11e98d8787 100644 --- a/pyrit/registry/base.py +++ b/pyrit/registry/base.py @@ -11,7 +11,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable if TYPE_CHECKING: from collections.abc import Iterator @@ -88,8 +88,8 @@ def get_names(self) -> list[str]: def list_metadata( self, *, - include_filters: Optional[dict[str, Any]] = None, - exclude_filters: Optional[dict[str, Any]] = None, + include_filters: dict[str, Any] | None = None, + exclude_filters: dict[str, Any] | None = None, ) -> list[MetadataT]: """ List metadata for all registered items, optionally filtered. @@ -148,8 +148,8 @@ def _get_metadata_value(metadata: Any, key: str) -> tuple[bool, Any]: def _matches_filters( metadata: Any, *, - include_filters: Optional[dict[str, Any]] = None, - exclude_filters: Optional[dict[str, Any]] = None, + include_filters: dict[str, Any] | None = None, + exclude_filters: dict[str, Any] | None = None, ) -> bool: """ Check if a metadata object matches all provided filters. diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py index 85211aba84..4a2f504d7a 100644 --- a/pyrit/registry/class_registries/base_class_registry.py +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -19,7 +19,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar if TYPE_CHECKING: from collections.abc import Callable, Iterator @@ -54,8 +54,8 @@ def __init__( self, *, registered_class: type[T], - factory: Optional[Callable[..., T]] = None, - default_kwargs: Optional[dict[str, object]] = None, + factory: Callable[..., T] | None = None, + default_kwargs: dict[str, object] | None = None, ) -> None: """ Initialize a class entry. @@ -129,7 +129,7 @@ def __init__(self, *, lazy_discovery: bool = True) -> None: """ # Maps registry names to ClassEntry wrappers self._class_entries: dict[str, ClassEntry[T]] = {} - self._metadata_cache: Optional[list[MetadataT]] = None + self._metadata_cache: list[MetadataT] | None = None self._discovered = False self._lazy_discovery = lazy_discovery @@ -211,7 +211,7 @@ def get_class(self, name: str) -> type[T]: raise KeyError(f"'{name}' not found in registry. Available: {available}") return entry.registered_class - def get_entry(self, name: str) -> Optional[ClassEntry[T]]: + def get_entry(self, name: str) -> ClassEntry[T] | None: """ Get the full ClassEntry for a registered class. @@ -242,8 +242,8 @@ def get_names(self) -> list[str]: def list_metadata( self, *, - include_filters: Optional[dict[str, object]] = None, - exclude_filters: Optional[dict[str, object]] = None, + include_filters: dict[str, object] | None = None, + exclude_filters: dict[str, object] | None = None, ) -> list[MetadataT]: """ List metadata for all registered classes, optionally filtered. @@ -286,9 +286,9 @@ def register( self, cls: type[T], *, - name: Optional[str] = None, - factory: Optional[Callable[..., T]] = None, - default_kwargs: Optional[dict[str, object]] = None, + name: str | None = None, + factory: Callable[..., T] | None = None, + default_kwargs: dict[str, object] | None = None, ) -> None: """ Register a class with the registry. diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index 17fb39fb7c..5310af3d69 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -15,7 +15,7 @@ import logging from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.models import class_name_to_snake_case, validate_registry_name from pyrit.registry.base import ClassRegistryEntry @@ -47,7 +47,7 @@ class InitializerMetadata(ClassRegistryEntry): required_env_vars: tuple[str, ...] = field(kw_only=True) # Supported parameters as tuples of (name, description, default). - supported_parameters: tuple[tuple[str, str, Optional[list[str]]], ...] = field(kw_only=True, default=()) + supported_parameters: tuple[tuple[str, str, list[str] | None], ...] = field(kw_only=True, default=()) class InitializerRegistry(BaseClassRegistry["PyRITInitializer", InitializerMetadata]): @@ -61,7 +61,7 @@ class InitializerRegistry(BaseClassRegistry["PyRITInitializer", InitializerMetad The directory structure is used for organization but not exposed to users. """ - def __init__(self, *, discovery_path: Optional[Path] = None, lazy_discovery: bool = False) -> None: + def __init__(self, *, discovery_path: Path | None = None, lazy_discovery: bool = False) -> None: """ Initialize the initializer registry. diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index 2702b1860e..0300b00a06 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -13,7 +13,7 @@ import logging from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, get_origin +from typing import TYPE_CHECKING, Any, NamedTuple, get_origin from pyrit.models import class_name_to_snake_case from pyrit.registry.base import ClassRegistryEntry @@ -53,7 +53,7 @@ class ScenarioMetadata(ClassRegistryEntry): default_datasets: tuple[str, ...] = field(kw_only=True) # Maximum number of items per dataset. - max_dataset_size: Optional[int] = field(kw_only=True) + max_dataset_size: int | None = field(kw_only=True) # Scenario-declared custom parameters. supported_parameters: tuple[ScenarioParameterMetadata, ...] = field(kw_only=True, default=()) @@ -71,7 +71,7 @@ class ScenarioParameterMetadata(NamedTuple): description: str default: Any param_type: str - choices: Optional[list[str]] + choices: list[str] | None is_list: bool = False diff --git a/pyrit/registry/discovery.py b/pyrit/registry/discovery.py index 5df0c14fee..34c1562bc3 100644 --- a/pyrit/registry/discovery.py +++ b/pyrit/registry/discovery.py @@ -15,7 +15,7 @@ import pkgutil from collections.abc import Callable, Iterator from pathlib import Path -from typing import Optional, TypeVar +from typing import TypeVar logger = logging.getLogger(__name__) @@ -92,7 +92,7 @@ def discover_in_package( package_name: str, base_class: type[T], recursive: bool = True, - name_builder: Optional[Callable[[str, str], str]] = None, + name_builder: Callable[[str, str], str] | None = None, _prefix: str = "", ) -> Iterator[tuple[str, type[T]]]: """ @@ -156,7 +156,7 @@ def name_builder(prefix: str, name: str) -> str: def discover_subclasses_in_loaded_modules( *, base_class: type[T], - exclude_module_prefixes: Optional[tuple[str, ...]] = None, + exclude_module_prefixes: tuple[str, ...] | None = None, ) -> Iterator[tuple[str, type[T]]]: """ Discover subclasses of a base class from already-loaded modules. diff --git a/pyrit/registry/object_registries/converter_registry.py b/pyrit/registry/object_registries/converter_registry.py index 4d83c9e1fd..568d1e6332 100644 --- a/pyrit/registry/object_registries/converter_registry.py +++ b/pyrit/registry/object_registries/converter_registry.py @@ -12,7 +12,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from pyrit.registry.object_registries.retrievable_instance_registry import ( RetrievableInstanceRegistry, @@ -37,8 +37,8 @@ def register_instance( self, converter: PromptConverter, *, - name: Optional[str] = None, - tags: Optional[Union[dict[str, str], list[str]]] = None, + name: str | None = None, + tags: dict[str, str] | list[str] | None = None, ) -> None: """ Register a converter instance. @@ -56,7 +56,7 @@ def register_instance( self.register(converter, name=name, tags=tags) logger.debug(f"Registered converter instance: {name} ({converter.__class__.__name__})") - def get_instance_by_name(self, name: str) -> Optional[PromptConverter]: + def get_instance_by_name(self, name: str) -> PromptConverter | None: """ Get a registered converter instance by name. diff --git a/pyrit/registry/object_registries/scorer_registry.py b/pyrit/registry/object_registries/scorer_registry.py index af5c59946f..d1a938aa30 100644 --- a/pyrit/registry/object_registries/scorer_registry.py +++ b/pyrit/registry/object_registries/scorer_registry.py @@ -10,7 +10,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from pyrit.registry.object_registries.retrievable_instance_registry import ( RetrievableInstanceRegistry, @@ -38,8 +38,8 @@ def register_instance( self, scorer: Scorer, *, - name: Optional[str] = None, - tags: Optional[Union[dict[str, str], list[str]]] = None, + name: str | None = None, + tags: dict[str, str] | list[str] | None = None, ) -> None: """ Register a scorer instance. @@ -60,7 +60,7 @@ def register_instance( self.register(scorer, name=name, tags=tags) logger.debug(f"Registered scorer instance: {name} ({scorer.__class__.__name__})") - def get_instance_by_name(self, name: str) -> Optional[Scorer]: + def get_instance_by_name(self, name: str) -> Scorer | None: """ Get a registered scorer instance by name. diff --git a/pyrit/registry/object_registries/target_registry.py b/pyrit/registry/object_registries/target_registry.py index c6fefd3926..170bad2078 100644 --- a/pyrit/registry/object_registries/target_registry.py +++ b/pyrit/registry/object_registries/target_registry.py @@ -10,7 +10,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from pyrit.registry.object_registries.retrievable_instance_registry import ( RetrievableInstanceRegistry, @@ -38,8 +38,8 @@ def register_instance( self, target: PromptTarget, *, - name: Optional[str] = None, - tags: Optional[Union[dict[str, str], list[str]]] = None, + name: str | None = None, + tags: dict[str, str] | list[str] | None = None, ) -> None: """ Register a target instance. @@ -61,7 +61,7 @@ def register_instance( self.register(target, name=name, tags=tags) logger.debug(f"Registered target instance: {name} ({target.__class__.__name__})") - def get_instance_by_name(self, name: str) -> Optional[PromptTarget]: + def get_instance_by_name(self, name: str) -> PromptTarget | None: """ Get a registered target instance by name. diff --git a/pyrit/scenario/core/atomic_attack.py b/pyrit/scenario/core/atomic_attack.py index acec06f8b0..d247058c8f 100644 --- a/pyrit/scenario/core/atomic_attack.py +++ b/pyrit/scenario/core/atomic_attack.py @@ -59,7 +59,7 @@ def __init__( seed_groups: list[SeedAttackGroup], adversarial_chat: Optional["PromptTarget"] = None, objective_scorer: Optional["TrueFalseScorer"] = None, - memory_labels: Optional[dict[str, str]] = None, + memory_labels: dict[str, str] | None = None, **attack_execute_params: Any, ) -> None: """ diff --git a/pyrit/scenario/core/dataset_configuration.py b/pyrit/scenario/core/dataset_configuration.py index 25cd9162c3..cbc22df444 100644 --- a/pyrit/scenario/core/dataset_configuration.py +++ b/pyrit/scenario/core/dataset_configuration.py @@ -11,7 +11,7 @@ from __future__ import annotations import random -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.memory import CentralMemory from pyrit.models import SeedAttackGroup, SeedGroup @@ -46,10 +46,10 @@ class DatasetConfiguration: def __init__( self, *, - seed_groups: Optional[list[SeedGroup]] = None, - dataset_names: Optional[list[str]] = None, - max_dataset_size: Optional[int] = None, - scenario_strategies: Optional[Sequence[ScenarioStrategy]] = None, + seed_groups: list[SeedGroup] | None = None, + dataset_names: list[str] | None = None, + max_dataset_size: int | None = None, + scenario_strategies: Sequence[ScenarioStrategy] | None = None, ) -> None: """ Initialize a DatasetConfiguration. diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 2238886cdb..67dcd01cb7 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -18,7 +18,7 @@ from collections.abc import Sequence from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast, get_origin +from typing import TYPE_CHECKING, Any, ClassVar, cast, get_origin try: # Built-in on Python 3.11+. Fall back to the ``exceptiongroup`` backport on 3.10 @@ -177,7 +177,7 @@ def __init__( default_strategy: ScenarioStrategy, default_dataset_config: DatasetConfiguration, objective_scorer: Scorer, - scenario_result_id: Optional[Union[uuid.UUID, str]] = None, + scenario_result_id: uuid.UUID | str | None = None, include_default_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. ) -> None: """ @@ -223,10 +223,10 @@ def __init__( self._default_dataset_config = default_dataset_config # These will be set in initialize_async - self._objective_target: Optional[PromptTarget] = None - self._objective_target_identifier: Optional[ComponentIdentifier] = None + self._objective_target: PromptTarget | None = None + self._objective_target_identifier: ComponentIdentifier | None = None self._memory_labels: dict[str, str] = {} - self._max_concurrency: Optional[int] = None + self._max_concurrency: int | None = None self._max_retries: int = 0 self._objective_scorer = objective_scorer @@ -235,7 +235,7 @@ def __init__( self._name = name if name else type(self).__name__ self._memory = CentralMemory.get_memory_instance() self._atomic_attacks: list[AtomicAttack] = [] - self._scenario_result_id: Optional[str] = str(scenario_result_id) if scenario_result_id else None + self._scenario_result_id: str | None = str(scenario_result_id) if scenario_result_id else None # Store prepared strategies for use in _get_atomic_attacks_async self._scenario_strategies: list[ScenarioStrategy] = [] @@ -533,7 +533,7 @@ def _validate_params(self, *, params: dict[str, Any], declared: list[Parameter]) def _prepare_strategies( self, - strategies: Optional[Sequence[ScenarioStrategy]], + strategies: Sequence[ScenarioStrategy] | None, ) -> list[ScenarioStrategy]: """ Resolve strategy inputs into a concrete list for this scenario. @@ -558,11 +558,11 @@ async def initialize_async( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - scenario_strategies: Optional[Sequence[ScenarioStrategy]] = None, - dataset_config: Optional[DatasetConfiguration] = None, + scenario_strategies: Sequence[ScenarioStrategy] | None = None, + dataset_config: DatasetConfiguration | None = None, max_concurrency: int = 4, max_retries: int = 0, - memory_labels: Optional[dict[str, str]] = None, + memory_labels: dict[str, str] | None = None, include_baseline: bool | None = None, ) -> None: """ @@ -1413,7 +1413,7 @@ def _collect_errors_from_outcomes( for outcome in outcomes: if isinstance(outcome, BaseException): logger.error(f"Atomic attack failed in scenario '{self._name}': {str(outcome)}") - error: Optional[BaseException] = outcome + error: BaseException | None = outcome else: atomic_attack, atomic_results = outcome error = self._partial_result_to_exception(atomic_attack=atomic_attack, atomic_results=atomic_results) diff --git a/pyrit/scenario/scenarios/airt/jailbreak.py b/pyrit/scenario/scenarios/airt/jailbreak.py index 935b81b51f..f7f8f4c351 100644 --- a/pyrit/scenario/scenarios/airt/jailbreak.py +++ b/pyrit/scenario/scenarios/airt/jailbreak.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message # Deprecated. Will be removed in 0.16.0. @@ -90,9 +90,9 @@ def required_datasets(cls) -> list[str]: def __init__( self, *, - objective_scorer: Optional[TrueFalseScorer] = None, - scenario_result_id: Optional[str] = None, - num_templates: Optional[int] = None, + objective_scorer: TrueFalseScorer | None = None, + scenario_result_id: str | None = None, + num_templates: int | None = None, num_attempts: int = 1, jailbreak_names: list[str] | None = None, include_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. @@ -132,7 +132,7 @@ def __init__( self._num_templates = num_templates self._num_attempts = num_attempts - self._adversarial_target: Optional[PromptTarget] = None + self._adversarial_target: PromptTarget | None = None # Note that num_templates and jailbreak_names are mutually exclusive. # If self._num_templates is None, then this returns all discoverable jailbreak templates. @@ -170,7 +170,7 @@ def __init__( self._legacy_include_baseline = include_baseline # Will be resolved in _get_atomic_attacks_async - self._seed_groups: Optional[list[SeedAttackGroup]] = None + self._seed_groups: list[SeedAttackGroup] | None = None def _get_or_create_adversarial_target(self) -> PromptTarget: """ @@ -233,7 +233,7 @@ async def _get_atomic_attack_from_strategy_async( request_converters=PromptConverterConfiguration.from_converters(converters=[jailbreak_converter]) ) - attack: Optional[Union[ManyShotJailbreakAttack, PromptSendingAttack, RolePlayAttack, SkeletonKeyAttack]] = None + attack: ManyShotJailbreakAttack | PromptSendingAttack | RolePlayAttack | SkeletonKeyAttack | None = None args: dict[str, Any] = { "objective_target": self._objective_target, "attack_scoring_config": AttackScoringConfig(objective_scorer=self._objective_scorer), diff --git a/pyrit/scenario/scenarios/airt/psychosocial.py b/pyrit/scenario/scenarios/airt/psychosocial.py index 8ba3991649..df39ff23c6 100644 --- a/pyrit/scenario/scenarios/airt/psychosocial.py +++ b/pyrit/scenario/scenarios/airt/psychosocial.py @@ -4,7 +4,7 @@ import logging import pathlib from dataclasses import dataclass -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar import yaml @@ -73,7 +73,7 @@ class ResolvedSeedData: """Helper dataclass for resolved seed data.""" seed_groups: list[SeedAttackGroup] - subharm: Optional[str] + subharm: str | None class PsychosocialStrategy(ScenarioStrategy): @@ -97,7 +97,7 @@ class PsychosocialStrategy(ScenarioStrategy): LicensedTherapist = ("licensed_therapist", set[str]()) @property - def harm_category_filter(self) -> Optional[str]: + def harm_category_filter(self) -> str | None: """ Get the harm category filter for this strategy. @@ -179,11 +179,11 @@ class Psychosocial(Scenario): def __init__( self, *, - objectives: Optional[list[str]] = None, - adversarial_chat: Optional[PromptTarget] = None, - objective_scorer: Optional[FloatScaleThresholdScorer] = None, - scenario_result_id: Optional[str] = None, - subharm_configs: Optional[dict[str, SubharmConfig]] = None, + objectives: list[str] | None = None, + adversarial_chat: PromptTarget | None = None, + objective_scorer: FloatScaleThresholdScorer | None = None, + scenario_result_id: str | None = None, + subharm_configs: dict[str, SubharmConfig] | None = None, max_turns: int = 5, include_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. ) -> None: @@ -255,7 +255,7 @@ def __init__( # Store deprecated objectives for later resolution in _resolve_seed_groups self._deprecated_objectives = objectives # Will be resolved in _get_atomic_attacks_async - self._seed_groups: Optional[list[SeedAttackGroup]] = None + self._seed_groups: list[SeedAttackGroup] | None = None def _resolve_seed_groups(self) -> ResolvedSeedData: """ @@ -300,7 +300,7 @@ def _resolve_seed_groups(self) -> ResolvedSeedData: subharm=harm_category_filter, ) - def _extract_harm_category_filter(self) -> Optional[str]: + def _extract_harm_category_filter(self) -> str | None: """ Extract harm category filter from scenario strategies. @@ -339,7 +339,7 @@ def _filter_by_harm_category( filtered_groups.append(SeedAttackGroup(seeds=filtered_seeds)) return filtered_groups - def _get_scorer(self, subharm: Optional[str] = None) -> FloatScaleThresholdScorer: + def _get_scorer(self, subharm: str | None = None) -> FloatScaleThresholdScorer: """ Create scorer for psychosocial harms evaluation. @@ -420,7 +420,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: return atomic_attacks - def _create_scoring_config(self, subharm: Optional[str]) -> AttackScoringConfig: + def _create_scoring_config(self, subharm: str | None) -> AttackScoringConfig: subharm_config = self._subharm_configs.get(subharm) if subharm else None scorer = self._get_scorer(subharm=subharm) if subharm_config else self._objective_scorer return AttackScoringConfig(objective_scorer=scorer) @@ -470,7 +470,7 @@ def _create_multi_turn_attack( self, *, scoring_config: AttackScoringConfig, - subharm: Optional[str], + subharm: str | None, seed_groups: list[SeedAttackGroup], ) -> AtomicAttack: subharm_config = self._subharm_configs.get(subharm) if subharm else None diff --git a/pyrit/scenario/scenarios/airt/scam.py b/pyrit/scenario/scenarios/airt/scam.py index ab05c0fc81..9e6dca8fbe 100644 --- a/pyrit/scenario/scenarios/airt/scam.py +++ b/pyrit/scenario/scenarios/airt/scam.py @@ -3,7 +3,7 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from pyrit.common import Parameter, apply_defaults from pyrit.common.deprecation import print_deprecation_message # Deprecated. Will be removed in 0.16.0. @@ -121,9 +121,9 @@ def supported_parameters(cls) -> list[Parameter]: def __init__( self, *, - objective_scorer: Optional[TrueFalseScorer] = None, - adversarial_chat: Optional[PromptTarget] = None, - scenario_result_id: Optional[str] = None, + objective_scorer: TrueFalseScorer | None = None, + adversarial_chat: PromptTarget | None = None, + scenario_result_id: str | None = None, include_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. ) -> None: """ @@ -166,7 +166,7 @@ def __init__( self._legacy_include_baseline = include_baseline # Will be resolved in _get_atomic_attacks_async - self._seed_groups: Optional[list[SeedAttackGroup]] = None + self._seed_groups: list[SeedAttackGroup] | None = None def _resolve_seed_groups(self) -> list[SeedAttackGroup]: """ @@ -201,7 +201,7 @@ def _get_atomic_attack_from_strategy(self, strategy: str) -> AtomicAttack: raise ValueError( "Scenario not properly initialized. Call await scenario.initialize_async() before running." ) - attack_strategy: Optional[AttackStrategy[Any, Any]] = None + attack_strategy: AttackStrategy[Any, Any] | None = None if strategy == "persuasive_rta": # Set system prompt to generic persuasion persona diff --git a/pyrit/scenario/scenarios/foundry/red_team_agent.py b/pyrit/scenario/scenarios/foundry/red_team_agent.py index 142aa53959..dc14cac8c4 100644 --- a/pyrit/scenario/scenarios/foundry/red_team_agent.py +++ b/pyrit/scenario/scenarios/foundry/red_team_agent.py @@ -13,7 +13,7 @@ from collections.abc import Sequence from dataclasses import dataclass, field from inspect import signature -from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast from pyrit.common import REQUIRED_VALUE, apply_defaults from pyrit.common.deprecation import print_deprecation_message # Deprecated. Will be removed in 0.16.0. @@ -219,9 +219,9 @@ class RedTeamAgent(Scenario): def __init__( self, *, - adversarial_chat: Optional[PromptTarget] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - scenario_result_id: Optional[str] = None, + adversarial_chat: PromptTarget | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + scenario_result_id: str | None = None, include_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. ) -> None: """ @@ -280,13 +280,11 @@ async def initialize_async( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - scenario_strategies: Optional[ - Sequence["FoundryStrategy | FoundryComposite | ScenarioCompositeStrategy"] - ] = None, - dataset_config: Optional[DatasetConfiguration] = None, + scenario_strategies: Sequence["FoundryStrategy | FoundryComposite | ScenarioCompositeStrategy"] | None = None, + dataset_config: DatasetConfiguration | None = None, max_concurrency: int = 4, max_retries: int = 0, - memory_labels: Optional[dict[str, str]] = None, + memory_labels: dict[str, str] | None = None, include_baseline: bool | None = None, ) -> None: """ @@ -320,7 +318,7 @@ async def initialize_async( def _prepare_strategies( # type: ignore[ty:invalid-method-override] self, - strategies: "Optional[Sequence[FoundryStrategy | FoundryComposite | ScenarioCompositeStrategy]]", + strategies: "Sequence[FoundryStrategy | FoundryComposite | ScenarioCompositeStrategy] | None", ) -> list[ScenarioStrategy]: """ Resolve strategies and build FoundryComposite objects. @@ -510,7 +508,7 @@ def _get_attack( *, attack_type: type[AttackStrategyT], converters: list[PromptConverter], - attack_kwargs: Optional[dict[str, Any]] = None, + attack_kwargs: dict[str, Any] | None = None, ) -> AttackStrategyT: """ Create an attack instance with the specified converters. diff --git a/pyrit/scenario/scenarios/garak/encoding.py b/pyrit/scenario/scenarios/garak/encoding.py index 65f36e3218..5e8674ea0d 100644 --- a/pyrit/scenario/scenarios/garak/encoding.py +++ b/pyrit/scenario/scenarios/garak/encoding.py @@ -4,7 +4,6 @@ import logging from collections.abc import Sequence -from typing import Optional from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message # Deprecated. Will be removed in 0.16.0. @@ -138,9 +137,9 @@ class Encoding(Scenario): def __init__( self, *, - objective_scorer: Optional[TrueFalseScorer] = None, - encoding_templates: Optional[Sequence[str]] = None, - scenario_result_id: Optional[str] = None, + objective_scorer: TrueFalseScorer | None = None, + encoding_templates: Sequence[str] | None = None, + scenario_result_id: str | None = None, include_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. ) -> None: """ @@ -184,7 +183,7 @@ def __init__( self._legacy_include_baseline = include_baseline # Will be resolved in _get_atomic_attacks_async - self._resolved_seed_groups: Optional[list[SeedAttackGroup]] = None + self._resolved_seed_groups: list[SeedAttackGroup] | None = None def _resolve_seed_groups(self) -> list[SeedAttackGroup]: """ diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index eb9cc87454..13764d4f61 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -6,7 +6,6 @@ import tempfile import uuid from pathlib import Path -from typing import Optional import av @@ -107,7 +106,7 @@ def __init__( self, *, text_capable_scorer: Scorer, - use_entra_auth: Optional[bool] = None, + use_entra_auth: bool | None = None, ) -> None: """ Initialize the base audio scorer. @@ -154,7 +153,7 @@ def _validate_text_scorer(scorer: Scorer) -> None: f"Supported types: {scorer._validator._supported_data_types}" ) - async def _score_audio_async(self, *, message_piece: MessagePiece, objective: Optional[str] = None) -> list[Score]: + async def _score_audio_async(self, *, message_piece: MessagePiece, objective: str | None = None) -> list[Score]: """ Transcribe audio and score the transcript. @@ -267,7 +266,7 @@ def _ensure_wav_format(self, audio_path: str) -> str: channels=self._DEFAULT_CHANNELS, ) - def _extract_audio_from_video(self, video_path: str) -> Optional[str]: + def _extract_audio_from_video(self, video_path: str) -> str | None: """ Extract audio track from a video file. @@ -281,7 +280,7 @@ def _extract_audio_from_video(self, video_path: str) -> Optional[str]: return AudioTranscriptHelper.extract_audio_from_video(video_path) @staticmethod - def extract_audio_from_video(video_path: str) -> Optional[str]: + def extract_audio_from_video(video_path: str) -> str | None: """ Extract audio track from a video file (static version). diff --git a/pyrit/score/batch_scorer.py b/pyrit/score/batch_scorer.py index 66beec5261..2c9d993540 100644 --- a/pyrit/score/batch_scorer.py +++ b/pyrit/score/batch_scorer.py @@ -5,7 +5,6 @@ import uuid from collections.abc import Sequence from datetime import datetime -from typing import Optional from pyrit.memory import CentralMemory from pyrit.models import ( @@ -47,17 +46,17 @@ async def score_responses_by_filters_async( self, *, scorer: Scorer, - attack_id: Optional[str | uuid.UUID] = None, - conversation_id: Optional[str | uuid.UUID] = None, - prompt_ids: Optional[list[str] | list[uuid.UUID]] = None, - labels: Optional[dict[str, str]] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - original_values: Optional[list[str]] = None, - converted_values: Optional[list[str]] = None, - data_type: Optional[str] = None, - not_data_type: Optional[str] = None, - converted_value_sha256: Optional[list[str]] = None, + attack_id: str | uuid.UUID | None = None, + conversation_id: str | uuid.UUID | None = None, + prompt_ids: list[str] | list[uuid.UUID] | None = None, + labels: dict[str, str] | None = None, + sent_after: datetime | None = None, + sent_before: datetime | None = None, + original_values: list[str] | None = None, + converted_values: list[str] | None = None, + data_type: str | None = None, + not_data_type: str | None = None, + converted_value_sha256: list[str] | None = None, objective: str = "", ) -> list[Score]: """ diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index d4e824d1fe..57385e6f1d 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -3,7 +3,7 @@ import uuid from abc import ABC, abstractmethod -from typing import Optional, cast +from typing import cast from uuid import UUID from pyrit.models import ComponentIdentifier, Message, MessagePiece, Score @@ -32,7 +32,7 @@ class ConversationScorer(Scorer, ABC): enforce_all_pieces_valid=False, ) - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: """ Scores the entire conversation history by concatenating all messages and passing to the wrapped scorer. @@ -128,7 +128,7 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non return scores - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Not used - ConversationScorer operates at conversation level via _score_async. @@ -159,7 +159,7 @@ def validate_return_scores(self, scores: list[Score]) -> None: def create_conversation_scorer( *, scorer: Scorer, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, ) -> Scorer: """ Create a ConversationScorer that inherits from the same type as the wrapped scorer. @@ -187,7 +187,7 @@ def create_conversation_scorer( >>> isinstance(conversation_scorer, ConversationScorer) # True """ # Determine the base class of the wrapped scorer - scorer_base_class: Optional[type[Scorer]] = None + scorer_base_class: type[Scorer] | None = None if isinstance(scorer, FloatScaleScorer): scorer_base_class = FloatScaleScorer diff --git a/pyrit/score/float_scale/audio_float_scale_scorer.py b/pyrit/score/float_scale/audio_float_scale_scorer.py index 203f8e1281..17653c9d5f 100644 --- a/pyrit/score/float_scale/audio_float_scale_scorer.py +++ b/pyrit/score/float_scale/audio_float_scale_scorer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.models import ComponentIdentifier, MessagePiece, Score from pyrit.score.audio_transcript_scorer import AudioTranscriptHelper @@ -23,8 +22,8 @@ def __init__( self, *, text_capable_scorer: FloatScaleScorer, - validator: Optional[ScorerPromptValidator] = None, - use_entra_auth: Optional[bool] = None, + validator: ScorerPromptValidator | None = None, + use_entra_auth: bool | None = None, ) -> None: """ Initialize the AudioFloatScaleScorer. @@ -62,7 +61,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score an audio file by transcribing it and scoring the transcript. diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 2ef3b412fb..85658563de 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -95,10 +95,10 @@ def _get_eval_files_for_category(cls, category: TextCategory) -> Optional["Score def __init__( self, *, - endpoint: Optional[str | None] = None, - api_key: Optional[str | Callable[[], str | Awaitable[str]] | None] = None, - harm_categories: Optional[list[TextCategory]] = None, - validator: Optional[ScorerPromptValidator] = None, + endpoint: str | None = None, + api_key: str | Callable[[], str | Awaitable[str]] | None = None, + harm_categories: list[TextCategory] | None = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize an Azure Content Filter Scorer. @@ -247,7 +247,7 @@ def _get_chunks(self, text: str) -> list[str]: return [text[i : i + self.MAX_TEXT_LENGTH] for i in range(0, len(text), self.MAX_TEXT_LENGTH)] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Evaluate the input text or image using the Azure Content Filter API. @@ -343,7 +343,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op for result in aggregated_results ] - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: """ Build one neutral ``0.0`` fallback score per configured harm category. diff --git a/pyrit/score/float_scale/float_scale_scorer.py b/pyrit/score/float_scale/float_scale_scorer.py index c888c117f3..6d1f9ab3e5 100644 --- a/pyrit/score/float_scale/float_scale_scorer.py +++ b/pyrit/score/float_scale/float_scale_scorer.py @@ -35,7 +35,7 @@ class FloatScaleScorer(Scorer): "blocked = True") should override ``_score_piece_async`` or ``_build_fallback_score``. """ - def __init__(self, *, validator: ScorerPromptValidator, chat_target: Optional[PromptTarget] = None) -> None: + def __init__(self, *, validator: ScorerPromptValidator, chat_target: PromptTarget | None = None) -> None: """ Initialize the FloatScaleScorer. @@ -46,7 +46,7 @@ def __init__(self, *, validator: ScorerPromptValidator, chat_target: Optional[Pr """ super().__init__(validator=validator, chat_target=chat_target) - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: """ Build a single-element list containing a neutral ``0.0`` score when no pieces could be scored. @@ -138,15 +138,15 @@ async def _score_value_with_llm_async( message_value: str, message_data_type: PromptDataType, scored_prompt_id: str | UUID, - prepended_text_message_piece: Optional[str] = None, - category: Optional[str | UUID] = None, - objective: Optional[str] = None, + prepended_text_message_piece: str | None = None, + category: str | UUID | None = None, + objective: str | None = None, score_value_output_key: str = "score_value", rationale_output_key: str = "rationale", description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: Optional[ComponentIdentifier] = None, + attack_identifier: ComponentIdentifier | None = None, ) -> UnvalidatedScore: score: UnvalidatedScore | None = None try: diff --git a/pyrit/score/float_scale/insecure_code_scorer.py b/pyrit/score/float_scale/insecure_code_scorer.py index 91d795e9e4..128b536a10 100644 --- a/pyrit/score/float_scale/insecure_code_scorer.py +++ b/pyrit/score/float_scale/insecure_code_scorer.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. from pathlib import Path -from typing import Optional, Union from pyrit.common import verify_and_resolve_path from pyrit.common.path import SCORER_SEED_PROMPT_PATH @@ -26,8 +25,8 @@ def __init__( self, *, chat_target: PromptTarget, - system_prompt_path: Optional[Union[str, Path]] = None, - validator: Optional[ScorerPromptValidator] = None, + system_prompt_path: str | Path | None = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize the Insecure Code Scorer. @@ -72,7 +71,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the given message piece using LLM to detect security vulnerabilities. diff --git a/pyrit/score/float_scale/plagiarism_scorer.py b/pyrit/score/float_scale/plagiarism_scorer.py index 3547767114..e608f79ca4 100644 --- a/pyrit/score/float_scale/plagiarism_scorer.py +++ b/pyrit/score/float_scale/plagiarism_scorer.py @@ -3,7 +3,6 @@ import re from enum import Enum -from typing import Optional import numpy as np @@ -38,7 +37,7 @@ def __init__( reference_text: str, metric: PlagiarismMetric = PlagiarismMetric.LCS, n: int = 5, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize the PlagiarismScorer. @@ -164,7 +163,7 @@ def _plagiarism_score( raise ValueError("metric must be 'lcs', 'levenshtein', or 'jaccard'") - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the AI response against the reference text using the specified metric. diff --git a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py index fa1b56627d..9e5d8bd3a4 100644 --- a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.prompt_target import CHAT_TARGET_REQUIREMENTS from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer @@ -31,11 +31,11 @@ def __init__( *, chat_target: PromptTarget, system_prompt_format_string: str, - prompt_format_string: Optional[str] = None, - category: Optional[str] = None, + prompt_format_string: str | None = None, + category: str | None = None, min_value: int = 0, max_value: int = 100, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, score_value_output_key: str = "score_value", rationale_output_key: str = "rationale", description_output_key: str = "description", @@ -112,7 +112,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score a single message piece using the configured prompts and scale to [0, 1]. diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index 6134f6f7af..15c496dd6c 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -5,7 +5,6 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import Optional import yaml @@ -32,7 +31,7 @@ class LikertScaleEvalFiles: human_labeled_datasets_files: list[str] result_file: str - harm_category: Optional[str] = None + harm_category: str | None = None class LikertScalePaths(enum.Enum): @@ -158,7 +157,7 @@ def path(self) -> Path: return self.value[0] @property - def evaluation_files(self) -> Optional[LikertScaleEvalFiles]: + def evaluation_files(self) -> LikertScaleEvalFiles | None: """Get the evaluation file configuration, or None if no evaluation dataset exists.""" return self.value[1] @@ -178,10 +177,10 @@ def __init__( self, *, chat_target: PromptTarget, - likert_scale: Optional[LikertScalePaths] = None, - custom_likert_path: Optional[Path] = None, - custom_system_prompt_path: Optional[Path] = None, - validator: Optional[ScorerPromptValidator] = None, + likert_scale: LikertScalePaths | None = None, + custom_likert_path: Path | None = None, + custom_system_prompt_path: Path | None = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize the SelfAskLikertScorer. @@ -211,9 +210,7 @@ def __init__( if likert_scale is None and custom_likert_path is None: raise ValueError("One of 'likert_scale' or 'custom_likert_path' must be provided.") - self._scoring_instructions_template: Optional[SeedPrompt] = ( - None # Will be set in _set_likert_scale_system_prompt - ) + self._scoring_instructions_template: SeedPrompt | None = None # Will be set in _set_likert_scale_system_prompt if custom_system_prompt_path is not None: self._validate_custom_system_prompt_path(custom_system_prompt_path) self._scoring_instructions_template = SeedPrompt.from_yaml_file(custom_system_prompt_path) @@ -436,7 +433,7 @@ def _validate_custom_likert_path(custom_likert_path: Path) -> None: f"Custom Likert scale file must be a YAML file (.yaml or .yml), got '{custom_likert_path.suffix}'." ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score the given message_piece using "self-ask" for the chat target. diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index 6cdc1e2921..992d403ca7 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -3,7 +3,7 @@ import enum from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import yaml @@ -44,9 +44,9 @@ def __init__( self, *, chat_target: PromptTarget, - scale_arguments_path: Optional[Union[Path, str]] = None, - system_prompt_path: Optional[Union[Path, str]] = None, - validator: Optional[ScorerPromptValidator] = None, + scale_arguments_path: Path | str | None = None, + system_prompt_path: Path | str | None = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize the SelfAskScaleScorer. @@ -101,7 +101,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the given message_piece using "self-ask" for the chat target. diff --git a/pyrit/score/float_scale/video_float_scale_scorer.py b/pyrit/score/float_scale/video_float_scale_scorer.py index cb337aa506..9845bb34ac 100644 --- a/pyrit/score/float_scale/video_float_scale_scorer.py +++ b/pyrit/score/float_scale/video_float_scale_scorer.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.models import ComponentIdentifier, MessagePiece, Score from pyrit.score.float_scale.float_scale_score_aggregator import ( @@ -42,12 +42,12 @@ def __init__( self, *, image_capable_scorer: FloatScaleScorer, - audio_scorer: Optional[FloatScaleScorer] = None, - num_sampled_frames: Optional[int] = None, - validator: Optional[ScorerPromptValidator] = None, + audio_scorer: FloatScaleScorer | None = None, + num_sampled_frames: int | None = None, + validator: ScorerPromptValidator | None = None, score_aggregator: FloatScaleAggregatorFunc = FloatScaleScorerByCategory.MAX, - image_objective_template: Optional[str] = VideoHelper._DEFAULT_IMAGE_OBJECTIVE_TEMPLATE, - audio_objective_template: Optional[str] = None, + image_objective_template: str | None = VideoHelper._DEFAULT_IMAGE_OBJECTIVE_TEMPLATE, + audio_objective_template: str | None = None, ) -> None: """ Initialize the VideoFloatScaleScorer. @@ -116,7 +116,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score a single video piece by extracting frames and optionally audio, then aggregating their scores. diff --git a/pyrit/score/score_aggregator_result.py b/pyrit/score/score_aggregator_result.py index de5b8dc212..84133eb27a 100644 --- a/pyrit/score/score_aggregator_result.py +++ b/pyrit/score/score_aggregator_result.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. from dataclasses import dataclass -from typing import Union @dataclass(frozen=True, slots=True) @@ -19,8 +18,8 @@ class ScoreAggregatorResult: metadata (Dict[str, Union[str, int, float]]): Combined metadata from constituent scores. """ - value: Union[bool, float] + value: bool | float description: str rationale: str category: list[str] - metadata: dict[str, Union[str, int, float]] + metadata: dict[str, str | int | float] diff --git a/pyrit/score/score_utils.py b/pyrit/score/score_utils.py index 5ae68c3939..4429b34e67 100644 --- a/pyrit/score/score_utils.py +++ b/pyrit/score/score_utils.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional, Union from pyrit.common.utils import combine_dict from pyrit.models import Score @@ -11,7 +10,7 @@ ORIGINAL_FLOAT_VALUE_KEY = "original_float_value" -def combine_metadata_and_categories(scores: list[Score]) -> tuple[dict[str, Union[str, int, float]], list[str]]: +def combine_metadata_and_categories(scores: list[Score]) -> tuple[dict[str, str | int | float], list[str]]: """ Combine metadata and categories from multiple scores with deduplication. @@ -21,7 +20,7 @@ def combine_metadata_and_categories(scores: list[Score]) -> tuple[dict[str, Unio Returns: Tuple of (metadata dict, sorted category list with empty strings filtered). """ - metadata: dict[str, Union[str, int, float]] = {} + metadata: dict[str, str | int | float] = {} category_set: set[str] = set() for s in scores: @@ -47,7 +46,7 @@ def format_score_for_rationale(score: Score) -> str: return f" - {class_type} {score.score_value}: {score.score_rationale or ''}" -def normalize_score_to_float(score: Optional[Score]) -> float: +def normalize_score_to_float(score: Score | None) -> float: """ Normalize any score to a float value between 0.0 and 1.0. diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index ce133df66a..28c7d330b0 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -13,8 +13,6 @@ TYPE_CHECKING, Any, ClassVar, - Optional, - Union, cast, ) @@ -61,7 +59,7 @@ class Scorer(Identifiable, abc.ABC): # Evaluation configuration - maps input dataset files to a result file. # Specifies glob patterns for datasets and a result file name. - evaluation_file_mapping: Optional[ScorerEvalDatasetFiles] = None + evaluation_file_mapping: ScorerEvalDatasetFiles | None = None #: Capability requirements placed on the scorer's chat target (if any). #: Subclasses that use a chat target should override this and pass the @@ -69,7 +67,7 @@ class Scorer(Identifiable, abc.ABC): #: validate it. TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() - _identifier: Optional[ComponentIdentifier] = None + _identifier: ComponentIdentifier | None = None #: When True, blocked responses that contain partial content #: (in prompt_metadata["partial_content"]) will be scored using that content @@ -80,7 +78,7 @@ class Scorer(Identifiable, abc.ABC): #: (Chat Completions API) and ``OpenAIResponseTarget`` (Responses API). score_blocked_content: bool = False - def __init__(self, *, validator: ScorerPromptValidator, chat_target: Optional[PromptTarget] = None) -> None: + def __init__(self, *, validator: ScorerPromptValidator, chat_target: PromptTarget | None = None) -> None: """ Initialize the Scorer. @@ -148,8 +146,8 @@ def _memory(self) -> MemoryInterface: def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct the scorer identifier. @@ -182,8 +180,8 @@ async def score_async( self, message: Message, *, - objective: Optional[str] = None, - role_filter: Optional[ChatMessageRole] = None, + objective: str | None = None, + role_filter: ChatMessageRole | None = None, skip_on_error_result: bool = False, infer_objective_from_request: bool = False, ) -> list[Score]: @@ -266,7 +264,7 @@ async def score_async( return scores - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: """ Score the given request response asynchronously. @@ -299,11 +297,11 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non return [score for sublist in piece_score_lists for score in sublist] @abstractmethod - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: raise NotImplementedError @staticmethod - def _create_text_piece_from_blocked(piece: MessagePiece) -> Optional[MessagePiece]: + def _create_text_piece_from_blocked(piece: MessagePiece) -> MessagePiece | None: """ Create a text-typed copy of a blocked MessagePiece using its partial content. @@ -382,7 +380,7 @@ def _get_supported_pieces(self, message: Message) -> list[MessagePiece]: ] @abstractmethod - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: """ Return neutral fallback ``Score`` objects when ``_score_async`` produced no scores. @@ -420,12 +418,12 @@ def validate_return_scores(self, scores: list[Score]) -> None: async def evaluate_async( self, - file_mapping: Optional[ScorerEvalDatasetFiles] = None, + file_mapping: ScorerEvalDatasetFiles | None = None, *, num_scorer_trials: int = 3, update_registry_behavior: RegistryUpdateBehavior | None = None, max_concurrency: int = 10, - ) -> Optional[ScorerMetrics]: + ) -> ScorerMetrics | None: """ Evaluate this scorer against human-labeled datasets. @@ -474,7 +472,7 @@ async def evaluate_async( ) @abstractmethod - def get_scorer_metrics(self) -> Optional[ScorerMetrics]: + def get_scorer_metrics(self) -> ScorerMetrics | None: """ Get evaluation metrics for this scorer from the configured evaluation result file. @@ -490,7 +488,7 @@ def get_scorer_metrics(self) -> Optional[ScorerMetrics]: """ raise NotImplementedError("Subclasses must implement get_scorer_metrics") - async def score_text_async(self, text: str, *, objective: Optional[str] = None) -> list[Score]: + async def score_text_async(self, text: str, *, objective: str | None = None) -> list[Score]: """ Scores the given text based on the task using the chat target. @@ -513,7 +511,7 @@ async def score_text_async(self, text: str, *, objective: Optional[str] = None) request.message_pieces[0].not_in_memory = True return await self.score_async(request, objective=objective) - async def score_image_async(self, image_path: str, *, objective: Optional[str] = None) -> list[Score]: + async def score_image_async(self, image_path: str, *, objective: str | None = None) -> list[Score]: """ Score the given image using the chat target. @@ -541,9 +539,9 @@ async def score_prompts_batch_async( self, *, messages: Sequence[Message], - objectives: Optional[Sequence[str]] = None, + objectives: Sequence[str] | None = None, batch_size: int = 10, - role_filter: Optional[ChatMessageRole] = None, + role_filter: ChatMessageRole | None = None, skip_on_error_result: bool = False, infer_objective_from_request: bool = False, ) -> list[Score]: @@ -593,7 +591,7 @@ async def score_prompts_batch_async( return [score for sublist in results for score in sublist] async def score_image_batch_async( - self, *, image_paths: Sequence[str], objectives: Optional[Sequence[str]] = None, batch_size: int = 10 + self, *, image_paths: Sequence[str], objectives: Sequence[str] | None = None, batch_size: int = 10 ) -> list[Score]: """ Score a batch of images asynchronously. @@ -653,15 +651,15 @@ async def _score_value_with_llm_async( message_value: str, message_data_type: PromptDataType, scored_prompt_id: str, - prepended_text_message_piece: Optional[str] = None, - category: Optional[Sequence[str] | str] = None, - objective: Optional[str] = None, + prepended_text_message_piece: str | None = None, + category: Sequence[str] | str | None = None, + objective: str | None = None, score_value_output_key: str = "score_value", rationale_output_key: str = "rationale", description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: Optional[ComponentIdentifier] = None, + attack_identifier: ComponentIdentifier | None = None, ) -> UnvalidatedScore: """ Send a request to a target, and take care of retries. @@ -769,7 +767,7 @@ async def _score_value_with_llm_async( # Validate and normalize category to a list of strings cat_val = category_response if category_response is not None else category - normalized_category: Optional[list[str]] + normalized_category: list[str] | None if cat_val is None: normalized_category = None elif isinstance(cat_val, str): @@ -784,7 +782,7 @@ async def _score_value_with_llm_async( # Normalize metadata to a dictionary with string keys and string/int/float values raw_md = parsed_response.get(metadata_output_key) - normalized_md: Optional[dict[str, Union[str, int, float]]] + normalized_md: dict[str, str | int | float] | None if raw_md is None: normalized_md = None elif isinstance(raw_md, dict): @@ -851,10 +849,10 @@ def _extract_objective_from_response(self, response: Message) -> str: async def score_response_async( *, response: Message, - objective_scorer: Optional[Scorer] = None, - auxiliary_scorers: Optional[list[Scorer]] = None, + objective_scorer: Scorer | None = None, + auxiliary_scorers: list[Scorer] | None = None, role_filter: ChatMessageRole = "assistant", - objective: Optional[str] = None, + objective: str | None = None, skip_on_error_result: bool = True, ) -> dict[str, list[Score]]: """ @@ -929,7 +927,7 @@ async def score_response_multiple_scorers_async( response: Message, scorers: list[Scorer], role_filter: ChatMessageRole = "assistant", - objective: Optional[str] = None, + objective: str | None = None, skip_on_error_result: bool = True, ) -> list[Score]: """ diff --git a/pyrit/score/scorer_evaluation/human_labeled_dataset.py b/pyrit/score/scorer_evaluation/human_labeled_dataset.py index f0f0fdcd87..635d9daf2f 100644 --- a/pyrit/score/scorer_evaluation/human_labeled_dataset.py +++ b/pyrit/score/scorer_evaluation/human_labeled_dataset.py @@ -5,7 +5,7 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, cast import pandas as pd @@ -126,8 +126,8 @@ def __init__( entries: list[HumanLabeledEntry], metrics_type: MetricsType, version: str, - harm_definition: Optional[str] = None, - harm_definition_version: Optional[str] = None, + harm_definition: str | None = None, + harm_definition_version: str | None = None, ) -> None: """ Initialize the HumanLabeledDataset. @@ -156,7 +156,7 @@ def __init__( self.version = version self.harm_definition = harm_definition self.harm_definition_version = harm_definition_version - self._harm_definition_obj: Optional[HarmDefinition] = None + self._harm_definition_obj: HarmDefinition | None = None def get_harm_definition(self) -> Optional["HarmDefinition"]: """ @@ -188,12 +188,12 @@ def get_harm_definition(self) -> Optional["HarmDefinition"]: def from_csv( cls, *, - csv_path: Union[str, Path], + csv_path: str | Path, metrics_type: MetricsType, - dataset_name: Optional[str] = None, - version: Optional[str] = None, - harm_definition: Optional[str] = None, - harm_definition_version: Optional[str] = None, + dataset_name: str | None = None, + version: str | None = None, + harm_definition: str | None = None, + harm_definition_version: str | None = None, ) -> "HumanLabeledDataset": """ Load a human-labeled dataset from a CSV file with standard column names. diff --git a/pyrit/score/scorer_evaluation/scorer_evaluator.py b/pyrit/score/scorer_evaluation/scorer_evaluator.py index 5f203753fc..8034babc56 100644 --- a/pyrit/score/scorer_evaluation/scorer_evaluator.py +++ b/pyrit/score/scorer_evaluation/scorer_evaluator.py @@ -7,7 +7,7 @@ import logging import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast import numpy as np from scipy.stats import ttest_1samp @@ -70,7 +70,7 @@ class ScorerEvalDatasetFiles: human_labeled_datasets_files: list[str] result_file: str - harm_category: Optional[str] = None + harm_category: str | None = None class ScorerEvaluator(abc.ABC): @@ -92,7 +92,7 @@ def __init__(self, scorer: Scorer) -> None: self.scorer = scorer @classmethod - def from_scorer(cls, scorer: Scorer, metrics_type: Optional[MetricsType] = None) -> ScorerEvaluator: + def from_scorer(cls, scorer: Scorer, metrics_type: MetricsType | None = None) -> ScorerEvaluator: """ Create a ScorerEvaluator based on the type of scoring. @@ -120,7 +120,7 @@ async def run_evaluation_async( num_scorer_trials: int = 3, update_registry_behavior: RegistryUpdateBehavior = RegistryUpdateBehavior.SKIP_IF_EXISTS, max_concurrency: int = 10, - ) -> Optional[ScorerMetrics]: + ) -> ScorerMetrics | None: """ Evaluate scorer using dataset files configuration. @@ -265,11 +265,11 @@ def _should_skip_evaluation( self, *, dataset_version: str, - harm_definition_version: Optional[str] = None, + harm_definition_version: str | None = None, num_scorer_trials: int, - harm_category: Optional[str] = None, + harm_category: str | None = None, result_file_path: Path, - ) -> tuple[bool, Optional[ScorerMetrics]]: + ) -> tuple[bool, ScorerMetrics | None]: """ Determine whether to skip evaluation based on existing registry entries. @@ -302,7 +302,7 @@ def _should_skip_evaluation( # Determine if this is a harm or objective evaluation metrics_type = MetricsType.OBJECTIVE if isinstance(self.scorer, TrueFalseScorer) else MetricsType.HARM - existing: Optional[ScorerMetrics] = None + existing: ScorerMetrics | None = None if metrics_type == MetricsType.HARM: if harm_category is None: logger.warning("harm_category must be provided for harm scorer evaluations") @@ -449,7 +449,7 @@ async def evaluate_dataset_async( def _validate_and_extract_data( self, labeled_dataset: HumanLabeledDataset, - ) -> tuple[list[Message], list[list[float]], Optional[list[str]]]: + ) -> tuple[list[Message], list[list[float]], list[str] | None]: """ Validate the dataset and extract data for evaluation. @@ -471,11 +471,11 @@ def _compute_metrics( all_human_scores: np.ndarray, all_model_scores: np.ndarray, num_scorer_trials: int, - dataset_name: Optional[str] = None, - dataset_version: Optional[str] = None, - harm_category: Optional[str] = None, - harm_definition: Optional[str] = None, - harm_definition_version: Optional[str] = None, + dataset_name: str | None = None, + dataset_version: str | None = None, + harm_category: str | None = None, + harm_definition: str | None = None, + harm_definition_version: str | None = None, ) -> ScorerMetrics: """ Compute evaluation metrics from human and model scores. @@ -532,7 +532,7 @@ class HarmScorerEvaluator(ScorerEvaluator): def _validate_and_extract_data( self, labeled_dataset: HumanLabeledDataset, - ) -> tuple[list[Message], list[list[float]], Optional[list[str]]]: + ) -> tuple[list[Message], list[list[float]], list[str] | None]: """ Validate harm dataset and extract evaluation data. @@ -569,11 +569,11 @@ def _compute_metrics( all_human_scores: np.ndarray, all_model_scores: np.ndarray, num_scorer_trials: int, - dataset_name: Optional[str] = None, - dataset_version: Optional[str] = None, - harm_category: Optional[str] = None, - harm_definition: Optional[str] = None, - harm_definition_version: Optional[str] = None, + dataset_name: str | None = None, + dataset_version: str | None = None, + harm_category: str | None = None, + harm_definition: str | None = None, + harm_definition_version: str | None = None, ) -> HarmScorerMetrics: reliability_data = np.concatenate((all_human_scores, all_model_scores)) # Calculate the median of human scores for each response, which is considered the gold label @@ -647,7 +647,7 @@ class ObjectiveScorerEvaluator(ScorerEvaluator): def _validate_and_extract_data( self, labeled_dataset: HumanLabeledDataset, - ) -> tuple[list[Message], list[list[float]], Optional[list[str]]]: + ) -> tuple[list[Message], list[list[float]], list[str] | None]: """ Validate objective dataset and extract evaluation data. @@ -685,11 +685,11 @@ def _compute_metrics( all_human_scores: np.ndarray, all_model_scores: np.ndarray, num_scorer_trials: int, - dataset_name: Optional[str] = None, - dataset_version: Optional[str] = None, - harm_category: Optional[str] = None, - harm_definition: Optional[str] = None, - harm_definition_version: Optional[str] = None, + dataset_name: str | None = None, + dataset_version: str | None = None, + harm_category: str | None = None, + harm_definition: str | None = None, + harm_definition_version: str | None = None, ) -> ObjectiveScorerMetrics: # Calculate the majority vote of human scores for each response, which is considered the gold label. # If the vote is split, the resulting gold score will be 0 (i.e. False). Same logic is applied to model trials. diff --git a/pyrit/score/scorer_evaluation/scorer_metrics.py b/pyrit/score/scorer_evaluation/scorer_metrics.py index fab1f9a505..87546f4af1 100644 --- a/pyrit/score/scorer_evaluation/scorer_metrics.py +++ b/pyrit/score/scorer_evaluation/scorer_metrics.py @@ -5,7 +5,7 @@ import json from dataclasses import asdict, dataclass, field -from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Generic, TypeVar from pyrit.common.utils import verify_and_resolve_path @@ -43,9 +43,9 @@ class ScorerMetrics: num_responses: int num_human_raters: int num_scorer_trials: int = field(default=1, kw_only=True) - dataset_name: Optional[str] = field(default=None, kw_only=True) - dataset_version: Optional[str] = field(default=None, kw_only=True) - trial_scores: Optional[np.ndarray] = field(default=None, kw_only=True) + dataset_name: str | None = field(default=None, kw_only=True) + dataset_version: str | None = field(default=None, kw_only=True) + trial_scores: np.ndarray | None = field(default=None, kw_only=True) average_score_time_seconds: float = field(default=0.0, kw_only=True) def to_json(self) -> str: @@ -63,7 +63,7 @@ def to_json(self) -> str: return json.dumps(asdict(self)) @classmethod - def from_json_file(cls: type[T], file_path: Union[str, Path]) -> T: + def from_json_file(cls: type[T], file_path: str | Path) -> T: """ Load a metrics instance from a JSON file on disk. @@ -96,7 +96,7 @@ def from_json_file(cls: type[T], file_path: Union[str, Path]) -> T: return cls(**filtered_data) @classmethod - def from_json(cls: type[T], file_path: Union[str, Path]) -> T: + def from_json(cls: type[T], file_path: str | Path) -> T: """ Load a metrics instance from a JSON file (deprecated alias for ``from_json_file``). @@ -157,14 +157,14 @@ class HarmScorerMetrics(ScorerMetrics): t_statistic: float p_value: float krippendorff_alpha_combined: float - harm_category: Optional[str] = field(default=None, kw_only=True) - harm_definition: Optional[str] = field(default=None, kw_only=True) - harm_definition_version: Optional[str] = field(default=None, kw_only=True) - krippendorff_alpha_humans: Optional[float] = None - krippendorff_alpha_model: Optional[float] = None - _harm_definition_obj: Optional[HarmDefinition] = field(default=None, init=False, repr=False) - - def get_harm_definition(self) -> Optional[HarmDefinition]: + harm_category: str | None = field(default=None, kw_only=True) + harm_definition: str | None = field(default=None, kw_only=True) + harm_definition_version: str | None = field(default=None, kw_only=True) + krippendorff_alpha_humans: float | None = None + krippendorff_alpha_model: float | None = None + _harm_definition_obj: HarmDefinition | None = field(default=None, init=False, repr=False) + + def get_harm_definition(self) -> HarmDefinition | None: """ Load and return the HarmDefinition object for this metrics instance. diff --git a/pyrit/score/scorer_evaluation/scorer_metrics_io.py b/pyrit/score/scorer_evaluation/scorer_metrics_io.py index d915dc24ab..080d6b45e6 100644 --- a/pyrit/score/scorer_evaluation/scorer_metrics_io.py +++ b/pyrit/score/scorer_evaluation/scorer_metrics_io.py @@ -11,7 +11,7 @@ import threading from dataclasses import asdict from pathlib import Path -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar from pyrit.common.path import ( SCORER_EVALS_PATH, @@ -53,7 +53,7 @@ def _metrics_to_registry_dict(metrics: ScorerMetrics) -> dict[str, Any]: def get_all_objective_metrics( - file_path: Optional[Path] = None, + file_path: Path | None = None, ) -> list[ScorerMetricsWithIdentity[ObjectiveScorerMetrics]]: """ Load all objective scorer metrics with full scorer identity for comparison. @@ -151,8 +151,8 @@ def _load_metrics_from_file( def find_objective_metrics_by_eval_hash( *, eval_hash: str, - file_path: Optional[Path] = None, -) -> Optional[ObjectiveScorerMetrics]: + file_path: Path | None = None, +) -> ObjectiveScorerMetrics | None: """ Find objective scorer metrics by evaluation hash. @@ -175,7 +175,7 @@ def find_harm_metrics_by_eval_hash( *, eval_hash: str, harm_category: str, -) -> Optional[HarmScorerMetrics]: +) -> HarmScorerMetrics | None: """ Find harm scorer metrics by evaluation hash. @@ -195,7 +195,7 @@ def _find_metrics_by_eval_hash( file_path: Path, eval_hash: str, metrics_class: type[M], -) -> Optional[M]: +) -> M | None: """ Find scorer metrics by evaluation hash in a specific file. diff --git a/pyrit/score/scorer_prompt_validator.py b/pyrit/score/scorer_prompt_validator.py index f89c93d54d..0b1140c86a 100644 --- a/pyrit/score/scorer_prompt_validator.py +++ b/pyrit/score/scorer_prompt_validator.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from collections.abc import Sequence -from typing import Optional, get_args +from typing import get_args from pyrit.models import ChatMessageRole, Message, MessagePiece, PromptDataType @@ -18,13 +18,13 @@ class ScorerPromptValidator: def __init__( self, *, - supported_data_types: Optional[Sequence[PromptDataType]] = None, - required_metadata: Optional[Sequence[str]] = None, - supported_roles: Optional[Sequence[ChatMessageRole]] = None, - max_pieces_in_response: Optional[int] = None, - max_text_length: Optional[int] = None, - enforce_all_pieces_valid: Optional[bool] = False, - raise_on_no_valid_pieces: Optional[bool] = False, + supported_data_types: Sequence[PromptDataType] | None = None, + required_metadata: Sequence[str] | None = None, + supported_roles: Sequence[ChatMessageRole] | None = None, + max_pieces_in_response: int | None = None, + max_text_length: int | None = None, + enforce_all_pieces_valid: bool | None = False, + raise_on_no_valid_pieces: bool | None = False, is_objective_required: bool = False, ) -> None: """ diff --git a/pyrit/score/true_false/audio_true_false_scorer.py b/pyrit/score/true_false/audio_true_false_scorer.py index c10befbf44..58397a3a29 100644 --- a/pyrit/score/true_false/audio_true_false_scorer.py +++ b/pyrit/score/true_false/audio_true_false_scorer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.models import ComponentIdentifier, MessagePiece, Score from pyrit.score.audio_transcript_scorer import AudioTranscriptHelper @@ -23,8 +22,8 @@ def __init__( self, *, text_capable_scorer: TrueFalseScorer, - validator: Optional[ScorerPromptValidator] = None, - use_entra_auth: Optional[bool] = None, + validator: ScorerPromptValidator | None = None, + use_entra_auth: bool | None = None, ) -> None: """ Initialize the AudioTrueFalseScorer. @@ -62,7 +61,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score an audio file by transcribing it and scoring the transcript. diff --git a/pyrit/score/true_false/decoding_scorer.py b/pyrit/score/true_false/decoding_scorer.py index f9cecc5f07..a683ee9786 100644 --- a/pyrit/score/true_false/decoding_scorer.py +++ b/pyrit/score/true_false/decoding_scorer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.analytics.text_matching import ExactTextMatching, TextMatching from pyrit.memory.central_memory import CentralMemory @@ -30,10 +29,10 @@ class DecodingScorer(TrueFalseScorer): def __init__( self, *, - text_matcher: Optional[TextMatching] = None, - categories: Optional[list[str]] = None, + text_matcher: TextMatching | None = None, + categories: list[str] | None = None, aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize the DecodingScorer. @@ -65,7 +64,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score the given request piece based on text matching strategy. diff --git a/pyrit/score/true_false/float_scale_threshold_scorer.py b/pyrit/score/true_false/float_scale_threshold_scorer.py index b89d5439fd..cbb565e9f8 100644 --- a/pyrit/score/true_false/float_scale_threshold_scorer.py +++ b/pyrit/score/true_false/float_scale_threshold_scorer.py @@ -88,8 +88,8 @@ async def _score_async( self, message: Message, *, - objective: Optional[str] = None, - role_filter: Optional[ChatMessageRole] = None, + objective: str | None = None, + role_filter: ChatMessageRole | None = None, ) -> list[Score]: """ Scores the piece using the underlying float-scale scorer and thresholds the resulting score. @@ -173,7 +173,7 @@ async def _score_async( return [score] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Float Scale scorers do not support piecewise scoring. diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index c37469ff62..bfa0664c51 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -3,7 +3,6 @@ import uuid from textwrap import dedent -from typing import Optional import httpx from openai import BadRequestError @@ -37,7 +36,7 @@ def __init__( *, level: GandalfLevel, chat_target: PromptTarget, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ @@ -143,7 +142,7 @@ async def _check_for_password_in_conversation_async(self, conversation_id: str) return "" return response_text - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the text based on the password found in the text. diff --git a/pyrit/score/true_false/markdown_injection.py b/pyrit/score/true_false/markdown_injection.py index 517124120c..d7c0f63d5c 100644 --- a/pyrit/score/true_false/markdown_injection.py +++ b/pyrit/score/true_false/markdown_injection.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import re -from typing import Optional from pyrit.models import ComponentIdentifier, MessagePiece, Score from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -27,7 +26,7 @@ class MarkdownInjectionScorer(TrueFalseScorer): def __init__( self, *, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ @@ -55,7 +54,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Check for markdown injection in the text. It returns True if markdown injection is detected, else False. diff --git a/pyrit/score/true_false/prompt_shield_scorer.py b/pyrit/score/true_false/prompt_shield_scorer.py index 8f048300cc..9ecab929d4 100644 --- a/pyrit/score/true_false/prompt_shield_scorer.py +++ b/pyrit/score/true_false/prompt_shield_scorer.py @@ -4,7 +4,7 @@ import json import logging import uuid -from typing import Any, Optional +from typing import Any from pyrit.models import ComponentIdentifier, Message, MessagePiece, Score, ScoreType from pyrit.prompt_target import PromptShieldTarget @@ -32,7 +32,7 @@ def __init__( self, *, prompt_shield_target: PromptShieldTarget, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ @@ -64,7 +64,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: conversation_id = str(uuid.uuid4()) body = message_piece.original_value diff --git a/pyrit/score/true_false/question_answer_scorer.py b/pyrit/score/true_false/question_answer_scorer.py index a2346fc650..57756635dd 100644 --- a/pyrit/score/true_false/question_answer_scorer.py +++ b/pyrit/score/true_false/question_answer_scorer.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.models import MessagePiece, Score from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -32,8 +32,8 @@ def __init__( self, *, correct_answer_matching_patterns: list[str] = CORRECT_ANSWER_MATCHING_PATTERNS, - category: Optional[list[str]] = None, - validator: Optional[ScorerPromptValidator] = None, + category: list[str] | None = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ @@ -67,7 +67,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score the message piece using question answering evaluation. diff --git a/pyrit/score/true_false/self_ask_category_scorer.py b/pyrit/score/true_false/self_ask_category_scorer.py index 26f28bcddf..472268d9a3 100644 --- a/pyrit/score/true_false/self_ask_category_scorer.py +++ b/pyrit/score/true_false/self_ask_category_scorer.py @@ -3,7 +3,6 @@ import enum from pathlib import Path -from typing import Optional, Union import yaml @@ -42,9 +41,9 @@ def __init__( self, *, chat_target: PromptTarget, - content_classifier_path: Union[str, Path], + content_classifier_path: str | Path, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize a new instance of the SelfAskCategoryScorer class. @@ -129,7 +128,7 @@ def _content_classifier_to_string(self, categories: list[dict[str, str]]) -> str return category_descriptions - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the given message using the chat target. diff --git a/pyrit/score/true_false/self_ask_general_true_false_scorer.py b/pyrit/score/true_false/self_ask_general_true_false_scorer.py index 4fc934aa42..c3d1a37002 100644 --- a/pyrit/score/true_false/self_ask_general_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_general_true_false_scorer.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.prompt_target import CHAT_TARGET_REQUIREMENTS from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -35,9 +35,9 @@ def __init__( *, chat_target: PromptTarget, system_prompt_format_string: str, - prompt_format_string: Optional[str] = None, - category: Optional[str] = None, - validator: Optional[ScorerPromptValidator] = None, + prompt_format_string: str | None = None, + category: str | None = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, score_value_output_key: str = "score_value", rationale_output_key: str = "rationale", @@ -112,7 +112,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score a single message piece using the configured prompts. diff --git a/pyrit/score/true_false/self_ask_question_answer_scorer.py b/pyrit/score/true_false/self_ask_question_answer_scorer.py index 7ea9c9a834..d4a7b78a66 100644 --- a/pyrit/score/true_false/self_ask_question_answer_scorer.py +++ b/pyrit/score/true_false/self_ask_question_answer_scorer.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.common.path import SCORER_SEED_PROMPT_PATH from pyrit.common.utils import verify_and_resolve_path @@ -38,8 +38,8 @@ def __init__( self, *, chat_target: PromptTarget, - true_false_question_path: Optional[pathlib.Path] = None, - validator: Optional[ScorerPromptValidator] = None, + true_false_question_path: pathlib.Path | None = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ @@ -67,7 +67,7 @@ def __init__( score_aggregator=score_aggregator, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score the message piece using question answering evaluation. diff --git a/pyrit/score/true_false/self_ask_refusal_scorer.py b/pyrit/score/true_false/self_ask_refusal_scorer.py index b27fce74f2..726886ac18 100644 --- a/pyrit/score/true_false/self_ask_refusal_scorer.py +++ b/pyrit/score/true_false/self_ask_refusal_scorer.py @@ -3,7 +3,6 @@ import enum from pathlib import Path -from typing import Optional, Union from pyrit.common.path import SCORER_SEED_PROMPT_PATH from pyrit.models import ComponentIdentifier, MessagePiece, Score, SeedPrompt, UnvalidatedScore @@ -69,9 +68,9 @@ def __init__( self, *, chat_target: PromptTarget, - refusal_system_prompt_path: Union[RefusalScorerPaths, Path, str] = RefusalScorerPaths.OBJECTIVE_STRICT, - prompt_format_string: Optional[str] = None, - validator: Optional[ScorerPromptValidator] = None, + refusal_system_prompt_path: RefusalScorerPaths | Path | str = RefusalScorerPaths.OBJECTIVE_STRICT, + prompt_format_string: str | None = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ @@ -138,7 +137,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the prompt and determines whether the response is a refusal. diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index 7c5f860c74..8ac3121d34 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -4,7 +4,7 @@ import enum from collections.abc import Iterator from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import yaml @@ -105,10 +105,10 @@ def __init__( self, *, chat_target: PromptTarget, - true_false_question_path: Optional[Union[str, Path]] = None, - true_false_question: Optional[TrueFalseQuestion] = None, - true_false_system_prompt_path: Optional[Union[str, Path]] = None, - validator: Optional[ScorerPromptValidator] = None, + true_false_question_path: str | Path | None = None, + true_false_question: TrueFalseQuestion | None = None, + true_false_system_prompt_path: str | Path | None = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ @@ -190,7 +190,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the given message piece using "self-ask" for the chat target. diff --git a/pyrit/score/true_false/substring_scorer.py b/pyrit/score/true_false/substring_scorer.py index 194f5d19eb..5bcd20937a 100644 --- a/pyrit/score/true_false/substring_scorer.py +++ b/pyrit/score/true_false/substring_scorer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.analytics.text_matching import ExactTextMatching, TextMatching from pyrit.models import ComponentIdentifier, MessagePiece, Score @@ -27,10 +26,10 @@ def __init__( self, *, substring: str, - text_matcher: Optional[TextMatching] = None, - categories: Optional[list[str]] = None, + text_matcher: TextMatching | None = None, + categories: list[str] | None = None, aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize the SubStringScorer. @@ -65,7 +64,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score the given message piece based on presence of the substring. diff --git a/pyrit/score/true_false/true_false_composite_scorer.py b/pyrit/score/true_false/true_false_composite_scorer.py index 148e80322c..2ac1219fee 100644 --- a/pyrit/score/true_false/true_false_composite_scorer.py +++ b/pyrit/score/true_false/true_false_composite_scorer.py @@ -83,8 +83,8 @@ async def _score_async( self, message: Message, *, - objective: Optional[str] = None, - role_filter: Optional[ChatMessageRole] = None, + objective: str | None = None, + role_filter: ChatMessageRole | None = None, ) -> list[Score]: """ Score a request/response by combining results from all constituent scorers. @@ -140,7 +140,7 @@ async def _score_async( return [return_score] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Composite scorers do not support piecewise scoring. diff --git a/pyrit/score/true_false/true_false_inverter_scorer.py b/pyrit/score/true_false/true_false_inverter_scorer.py index a62d2ac287..4fff7614b3 100644 --- a/pyrit/score/true_false/true_false_inverter_scorer.py +++ b/pyrit/score/true_false/true_false_inverter_scorer.py @@ -15,7 +15,7 @@ class TrueFalseInverterScorer(TrueFalseScorer): """A scorer that inverts a true false score.""" - def __init__(self, *, scorer: TrueFalseScorer, validator: Optional[ScorerPromptValidator] = None) -> None: + def __init__(self, *, scorer: TrueFalseScorer, validator: ScorerPromptValidator | None = None) -> None: """ Initialize the TrueFalseInverterScorer. @@ -62,8 +62,8 @@ async def _score_async( self, message: Message, *, - objective: Optional[str] = None, - role_filter: Optional[ChatMessageRole] = None, + objective: str | None = None, + role_filter: ChatMessageRole | None = None, ) -> list[Score]: """ Scores the piece using the underlying true-false scorer and returns the inverted score. @@ -100,7 +100,7 @@ async def _score_async( return [inv_score] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Indicate that True False Inverter scorers do not support piecewise scoring. diff --git a/pyrit/score/true_false/true_false_scorer.py b/pyrit/score/true_false/true_false_scorer.py index 68183f544c..8a1cf59d30 100644 --- a/pyrit/score/true_false/true_false_scorer.py +++ b/pyrit/score/true_false/true_false_scorer.py @@ -117,7 +117,7 @@ def get_scorer_metrics(self) -> Optional["ObjectiveScorerMetrics"]: return find_objective_metrics_by_eval_hash(eval_hash=eval_hash, file_path=result_file) - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: """ Score the given request response asynchronously. @@ -158,7 +158,7 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non ) ] - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: """ Build a single-element list containing a ``false`` score when no pieces could be scored. diff --git a/pyrit/score/true_false/video_true_false_scorer.py b/pyrit/score/true_false/video_true_false_scorer.py index d1895aa4bb..5c45eae477 100644 --- a/pyrit/score/true_false/video_true_false_scorer.py +++ b/pyrit/score/true_false/video_true_false_scorer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.models import ComponentIdentifier, MessagePiece, Score from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -30,11 +29,11 @@ def __init__( self, *, image_capable_scorer: TrueFalseScorer, - audio_scorer: Optional[TrueFalseScorer] = None, - num_sampled_frames: Optional[int] = None, - validator: Optional[ScorerPromptValidator] = None, - image_objective_template: Optional[str] = VideoHelper._DEFAULT_IMAGE_OBJECTIVE_TEMPLATE, - audio_objective_template: Optional[str] = None, + audio_scorer: TrueFalseScorer | None = None, + num_sampled_frames: int | None = None, + validator: ScorerPromptValidator | None = None, + image_objective_template: str | None = VideoHelper._DEFAULT_IMAGE_OBJECTIVE_TEMPLATE, + audio_objective_template: str | None = None, ) -> None: """ Initialize the VideoTrueFalseScorer. @@ -94,7 +93,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score a single video piece by extracting frames and optionally audio, then aggregating their scores. diff --git a/pyrit/score/video_scorer.py b/pyrit/score/video_scorer.py index 400d2522a9..b4fbecad38 100644 --- a/pyrit/score/video_scorer.py +++ b/pyrit/score/video_scorer.py @@ -6,7 +6,6 @@ import tempfile import uuid from pathlib import Path -from typing import Optional from pyrit.memory import CentralMemory from pyrit.models import MessagePiece, Score @@ -43,9 +42,9 @@ def __init__( self, *, image_capable_scorer: Scorer, - num_sampled_frames: Optional[int] = None, - image_objective_template: Optional[str] = _DEFAULT_IMAGE_OBJECTIVE_TEMPLATE, - audio_objective_template: Optional[str] = None, + num_sampled_frames: int | None = None, + image_objective_template: str | None = _DEFAULT_IMAGE_OBJECTIVE_TEMPLATE, + audio_objective_template: str | None = None, ) -> None: """ Initialize the base video scorer. @@ -95,7 +94,7 @@ def _validate_audio_scorer(scorer: Scorer) -> None: f"Supported types: {scorer._validator._supported_data_types}" ) - async def _score_frames_async(self, *, message_piece: MessagePiece, objective: Optional[str] = None) -> list[Score]: + async def _score_frames_async(self, *, message_piece: MessagePiece, objective: str | None = None) -> list[Score]: """ Extract frames from video and score them. @@ -211,7 +210,7 @@ def _extract_frames(self, video_path: str) -> list[str]: return frame_paths async def _score_video_audio_async( - self, *, message_piece: MessagePiece, audio_scorer: Optional[Scorer] = None, objective: Optional[str] = None + self, *, message_piece: MessagePiece, audio_scorer: Scorer | None = None, objective: str | None = None ) -> list[Score]: """ Extract and score audio from the video. diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 0af45c4adf..51c1c30e9a 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -11,7 +11,7 @@ import pathlib from collections.abc import Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from pyrit.common.path import DEFAULT_CONFIG_PATH from pyrit.common.yaml_loadable import YamlLoadable @@ -29,8 +29,8 @@ # Type alias for YAML-serializable values that can be passed as initializer args # This matches what YAML can represent: primitives, lists, and nested dicts -YamlPrimitive = Union[str, int, float, bool, None] -YamlValue = Union[YamlPrimitive, list["YamlValue"], dict[str, "YamlValue"]] +YamlPrimitive = str | int | float | bool | None +YamlValue = YamlPrimitive | list["YamlValue"] | dict[str, "YamlValue"] # Mapping from snake_case config values to internal constants _MEMORY_DB_TYPE_MAP: dict[str, str] = { @@ -51,7 +51,7 @@ class InitializerConfig: """ name: str - args: Optional[dict[str, YamlValue]] = None + args: dict[str, YamlValue] | None = None @dataclass @@ -77,7 +77,7 @@ class ScenarioConfig: """ name: str - args: Optional[dict[str, YamlValue]] = None + args: dict[str, YamlValue] | None = None def _scenario_config_to_dict(config: ScenarioConfig) -> dict[str, Any]: @@ -137,16 +137,16 @@ class ConfigurationLoader(YamlLoadable): """ memory_db_type: str = "sqlite" - initializers: list[Union[str, dict[str, Any]]] = field(default_factory=list) - initialization_scripts: Optional[list[str]] = None - env_files: Optional[list[str]] = None + initializers: list[str | dict[str, Any]] = field(default_factory=list) + initialization_scripts: list[str] | None = None + env_files: list[str] | None = None silent: bool = False - operator: Optional[str] = None - operation: Optional[str] = None - scenario: Optional[Union[str, dict[str, Any]]] = None + operator: str | None = None + operation: str | None = None + scenario: str | dict[str, Any] | None = None max_concurrent_scenario_runs: int = 3 allow_custom_initializers: bool = False - server: Optional[dict[str, Any]] = None + server: dict[str, Any] | None = None extensions: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: @@ -229,7 +229,7 @@ def _normalize_scenario(self) -> None: ValueError: For any other shape. """ if self.scenario is None: - self._scenario_config: Optional[ScenarioConfig] = None + self._scenario_config: ScenarioConfig | None = None return if isinstance(self.scenario, str): @@ -263,7 +263,7 @@ def _normalize_server(self) -> None: ValueError: If ``server`` is not ``None`` or a dict, or if ``url`` is not a string. """ if self.server is None: - self._server_config: Optional[ServerConfig] = None + self._server_config: ServerConfig | None = None return if isinstance(self.server, dict): @@ -276,12 +276,12 @@ def _normalize_server(self) -> None: raise ValueError(f"Server entry must be a dict, got: {type(self.server).__name__}") @property - def server_config(self) -> Optional[ServerConfig]: + def server_config(self) -> ServerConfig | None: """The normalized ``server:`` block, or ``None`` when not configured.""" return self._server_config @property - def scenario_config(self) -> Optional[ScenarioConfig]: + def scenario_config(self) -> ScenarioConfig | None: """The normalized ``scenario:`` block, or ``None`` when not configured.""" return self._scenario_config @@ -313,12 +313,12 @@ def from_dict(cls, data: dict[str, Any]) -> "ConfigurationLoader": @staticmethod def load_with_overrides( - config_file: Optional[pathlib.Path] = None, + config_file: pathlib.Path | None = None, *, - memory_db_type: Optional[str] = None, - initializers: Optional[Sequence[Union[str, dict[str, Any]]]] = None, - initialization_scripts: Optional[Sequence[str]] = None, - env_files: Optional[Sequence[str]] = None, + memory_db_type: str | None = None, + initializers: Sequence[str | dict[str, Any]] | None = None, + initialization_scripts: Sequence[str] | None = None, + env_files: Sequence[str] | None = None, ) -> "ConfigurationLoader": """ Load configuration with optional overrides. @@ -487,7 +487,7 @@ def resolve_initializers(self) -> Sequence["PyRITInitializer"]: return resolved - def resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]: + def resolve_initialization_scripts(self) -> Sequence[pathlib.Path] | None: """ Resolve initialization script paths. @@ -512,7 +512,7 @@ def resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]: return resolved - def resolve_env_files(self) -> Optional[Sequence[pathlib.Path]]: + def resolve_env_files(self) -> Sequence[pathlib.Path] | None: """ Resolve environment file paths. @@ -564,7 +564,7 @@ async def initialize_pyrit_async(self) -> None: async def initialize_from_config_async( - config_path: Optional[Union[str, pathlib.Path]] = None, + config_path: str | pathlib.Path | None = None, ) -> ConfigurationLoader: """ Initialize PyRIT from a configuration file. diff --git a/pyrit/setup/initialization.py b/pyrit/setup/initialization.py index 5d3dfe6663..ab8caa95ed 100644 --- a/pyrit/setup/initialization.py +++ b/pyrit/setup/initialization.py @@ -3,7 +3,7 @@ import logging import pathlib from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, Literal, get_args import dotenv @@ -27,7 +27,7 @@ MemoryDatabaseType = Literal["InMemory", "SQLite", "AzureSQL"] -def _load_environment_files(env_files: Optional[Sequence[pathlib.Path]], *, silent: bool = False) -> None: +def _load_environment_files(env_files: Sequence[pathlib.Path] | None, *, silent: bool = False) -> None: """ Load environment files in the order they are provided. Later files override values from earlier files. @@ -95,9 +95,7 @@ def _print_msg(message: str, quiet: bool, log: bool) -> None: logger.info(message) -def _load_initializers_from_scripts( - *, script_paths: Sequence[Union[str, pathlib.Path]] -) -> Sequence["PyRITInitializer"]: +def _load_initializers_from_scripts(*, script_paths: Sequence[str | pathlib.Path]) -> Sequence["PyRITInitializer"]: """ Load PyRITInitializer instances from external Python files. @@ -228,11 +226,11 @@ async def _execute_initializers_async(*, initializers: Sequence["PyRITInitialize async def initialize_pyrit_async( - memory_db_type: Union[MemoryDatabaseType, str], + memory_db_type: MemoryDatabaseType | str, *, - initialization_scripts: Optional[Sequence[Union[str, pathlib.Path]]] = None, - initializers: Optional[Sequence["PyRITInitializer"]] = None, - env_files: Optional[Sequence[pathlib.Path]] = None, + initialization_scripts: Sequence[str | pathlib.Path] | None = None, + initializers: Sequence["PyRITInitializer"] | None = None, + env_files: Sequence[pathlib.Path] | None = None, silent: bool = False, **memory_instance_kwargs: Any, ) -> None: diff --git a/pyrit/setup/initializers/components/targets.py b/pyrit/setup/initializers/components/targets.py index fd0284c7d9..017bb873bf 100644 --- a/pyrit/setup/initializers/components/targets.py +++ b/pyrit/setup/initializers/components/targets.py @@ -17,7 +17,7 @@ from collections import defaultdict from dataclasses import dataclass, field from enum import Enum -from typing import Any, Optional +from typing import Any from pyrit.auth import get_azure_openai_auth, get_azure_token_provider from pyrit.common.parameter import Parameter @@ -71,9 +71,9 @@ class TargetConfig: target_class: type[PromptTarget] endpoint_var: str key_var: str = "" # Empty string means no auth required - model_var: Optional[str] = None - underlying_model_var: Optional[str] = None - temperature: Optional[float] = None + model_var: str | None = None + underlying_model_var: str | None = None + temperature: float | None = None extra_kwargs: dict[str, Any] = field(default_factory=dict) tags: list[TargetInitializerTags] = field(default_factory=lambda: [TargetInitializerTags.DEFAULT]) default_objective_target: bool = False diff --git a/tests/integration/mocks.py b/tests/integration/mocks.py index 512a4b5986..f222f501e6 100644 --- a/tests/integration/mocks.py +++ b/tests/integration/mocks.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. from collections.abc import Generator -from typing import Optional from sqlalchemy import inspect @@ -57,8 +56,8 @@ def set_system_prompt( *, system_prompt: str, conversation_id: str, - attack_identifier: Optional[ComponentIdentifier] = None, - labels: Optional[dict[str, str]] = None, + attack_identifier: ComponentIdentifier | None = None, + labels: dict[str, str] | None = None, ) -> None: self.system_prompt = system_prompt if self._memory: diff --git a/tests/unit/analytics/test_result_analysis.py b/tests/unit/analytics/test_result_analysis.py index 21d18541ed..8edf1e8ef7 100644 --- a/tests/unit/analytics/test_result_analysis.py +++ b/tests/unit/analytics/test_result_analysis.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. from datetime import datetime, timedelta, timezone -from typing import Optional from unittest.mock import MagicMock import pytest @@ -27,13 +26,13 @@ # helpers def make_attack( outcome: AttackOutcome, - attack_type: Optional[str] = "default", + attack_type: str | None = "default", conversation_id: str = "conv-1", ) -> AttackResult: """ Minimal valid AttackResult for analytics tests. """ - attack_identifier: Optional[ComponentIdentifier] = None + attack_identifier: ComponentIdentifier | None = None if attack_type is not None: attack_identifier = ComponentIdentifier(class_name=attack_type, class_module="tests.unit.analytics") @@ -190,7 +189,7 @@ def _make_attack_with_target( target: ComponentIdentifier, *, outcome: AttackOutcome = AttackOutcome.SUCCESS, - timestamp: Optional[datetime] = None, + timestamp: datetime | None = None, ) -> AttackResult: technique = ComponentIdentifier( class_name="PromptSendingAttack", diff --git a/tests/unit/common/test_pyrit_default_value.py b/tests/unit/common/test_pyrit_default_value.py index e29981a6f3..cbd9584293 100644 --- a/tests/unit/common/test_pyrit_default_value.py +++ b/tests/unit/common/test_pyrit_default_value.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional import pytest @@ -28,7 +27,7 @@ def test_no_defaults_configured_returns_none(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: self.param1 = param1 self.param2 = param2 @@ -41,7 +40,7 @@ def test_single_default_value_applied(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None) -> None: self.param1 = param1 set_default_value(class_type=TestClass, parameter_name="param1", value="default_value") @@ -55,7 +54,7 @@ def test_multiple_default_values_applied(self) -> None: class TestClass: @apply_defaults def __init__( - self, *, param1: Optional[str] = None, param2: Optional[int] = None, param3: Optional[float] = None + self, *, param1: str | None = None, param2: int | None = None, param3: float | None = None ) -> None: self.param1 = param1 self.param2 = param2 @@ -75,7 +74,7 @@ def test_explicit_value_overrides_default(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: self.param1 = param1 self.param2 = param2 @@ -92,7 +91,7 @@ def test_partial_override_uses_remaining_defaults(self) -> None: class TestClass: @apply_defaults def __init__( - self, *, param1: Optional[str] = None, param2: Optional[int] = None, param3: Optional[float] = None + self, *, param1: str | None = None, param2: int | None = None, param3: float | None = None ) -> None: self.param1 = param1 self.param2 = param2 @@ -115,9 +114,9 @@ class TestClass: def __init__( self, *, - param_int: Optional[int] = None, - param_bool: Optional[bool] = None, - param_str: Optional[str] = None, + param_int: int | None = None, + param_bool: bool | None = None, + param_str: str | None = None, ) -> None: self.param_int = param_int self.param_bool = param_bool @@ -145,13 +144,13 @@ def test_subclass_inherits_parent_defaults(self) -> None: class ParentClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: self.param1 = param1 self.param2 = param2 class ChildClass(ParentClass): @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: super().__init__(param1=param1, param2=param2) set_default_value(class_type=ParentClass, parameter_name="param1", value="parent_value") @@ -166,13 +165,13 @@ def test_subclass_specific_defaults_override_parent(self) -> None: class ParentClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: self.param1 = param1 self.param2 = param2 class ChildClass(ParentClass): @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: super().__init__(param1=param1, param2=param2) set_default_value(class_type=ParentClass, parameter_name="param1", value="parent_value") @@ -189,19 +188,19 @@ def test_multiple_inheritance_levels(self) -> None: class GrandParent: @apply_defaults - def __init__(self, *, param1: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None) -> None: self.param1 = param1 class Parent(GrandParent): @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: super().__init__(param1=param1) self.param2 = param2 class Child(Parent): @apply_defaults def __init__( - self, *, param1: Optional[str] = None, param2: Optional[int] = None, param3: Optional[float] = None + self, *, param1: str | None = None, param2: int | None = None, param3: float | None = None ) -> None: super().__init__(param1=param1, param2=param2) self.param3 = param3 @@ -220,12 +219,12 @@ def test_parent_not_affected_by_child_defaults(self) -> None: class ParentClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None) -> None: self.param1 = param1 class ChildClass(ParentClass): @apply_defaults - def __init__(self, *, param1: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None) -> None: super().__init__(param1=param1) set_default_value(class_type=ChildClass, parameter_name="param1", value="child_value") @@ -354,7 +353,7 @@ def test_set_default_value_stores_value(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None) -> None: self.param1 = param1 set_default_value(class_type=TestClass, parameter_name="param1", value="stored_value") @@ -367,7 +366,7 @@ def test_set_default_value_overwrites_existing(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None) -> None: self.param1 = param1 set_default_value(class_type=TestClass, parameter_name="param1", value="first_value") @@ -392,9 +391,9 @@ class OpenAIChatTarget: def __init__( self, *, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, + temperature: float | None = None, + top_p: float | None = None, + max_tokens: int | None = None, ) -> None: self.temperature = temperature self.top_p = top_p @@ -405,9 +404,9 @@ class AzureOpenAIChatTarget(OpenAIChatTarget): def __init__( self, *, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, + temperature: float | None = None, + top_p: float | None = None, + max_tokens: int | None = None, ) -> None: super().__init__(temperature=temperature, top_p=top_p, max_tokens=max_tokens) @@ -441,12 +440,12 @@ def test_multiple_classes_independent_defaults(self) -> None: class ClassA: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param class ClassB: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param set_default_value(class_type=ClassA, parameter_name="param", value="value_a") @@ -471,7 +470,7 @@ def test_reset_clears_all_defaults(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: self.param1 = param1 self.param2 = param2 @@ -497,12 +496,12 @@ def test_reset_affects_multiple_classes(self) -> None: class ClassA: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param class ClassB: @apply_defaults - def __init__(self, *, param: Optional[int] = None) -> None: + def __init__(self, *, param: int | None = None) -> None: self.param = param # Set defaults for multiple classes @@ -523,7 +522,7 @@ def test_reset_allows_setting_new_defaults(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param # Set initial default @@ -544,7 +543,7 @@ def test_reset_with_no_defaults_does_nothing(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param # Reset when no defaults are set @@ -562,12 +561,12 @@ def test_reset_clears_inheritance_based_defaults(self) -> None: class ParentClass: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param class ChildClass(ParentClass): @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: super().__init__(param=param) # Set defaults for both parent and child @@ -588,7 +587,7 @@ def test_reset_clears_include_subclasses_flag_variations(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: str | None = None) -> None: self.param1 = param1 self.param2 = param2 @@ -761,7 +760,7 @@ def __init__( self, *, required_param: str = REQUIRED_VALUE, # type: ignore[assignment] - optional_param: Optional[str] = None, + optional_param: str | None = None, ) -> None: self.required_param = required_param self.optional_param = optional_param @@ -847,7 +846,7 @@ def test_required_value_none_is_different(self) -> None: class TestClass1: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param class TestClass2: diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index ed56813139..47d8678b7d 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -18,7 +18,6 @@ """ import uuid -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -56,7 +55,7 @@ class _TestAttackContext(AttackContext): """Concrete AttackContext for testing.""" # Add last_score to match MultiTurnAttackContext behavior for testing - last_score: Optional[Score] = None + last_score: Score | None = None # ============================================================================= diff --git a/tests/unit/executor/attack/compound/test_sequential_attack.py b/tests/unit/executor/attack/compound/test_sequential_attack.py index 79865e8f55..68cf8180d3 100644 --- a/tests/unit/executor/attack/compound/test_sequential_attack.py +++ b/tests/unit/executor/attack/compound/test_sequential_attack.py @@ -3,7 +3,6 @@ """Tests for ``SequentialAttack``.""" -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -35,7 +34,7 @@ def _make_seed_group(objective: str = "obj") -> SeedAttackGroup: def _make_context( *, objective: str = "obj", - labels: Optional[dict[str, str]] = None, + labels: dict[str, str] | None = None, ) -> AttackContext[AttackParameters]: params_type = AttackParameters.excluding("next_message", "prepended_conversation") return AttackContext(params=params_type(objective=objective, memory_labels=labels or {})) diff --git a/tests/unit/executor/attack/multi_turn/test_crescendo.py b/tests/unit/executor/attack/multi_turn/test_crescendo.py index 1bd1c38f46..e8e295ba88 100644 --- a/tests/unit/executor/attack/multi_turn/test_crescendo.py +++ b/tests/unit/executor/attack/multi_turn/test_crescendo.py @@ -4,7 +4,6 @@ import json import uuid from pathlib import Path -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -85,11 +84,11 @@ def create_score( *, score_type: ScoreType, score_value: str, - score_category: Optional[list[str]] = None, + score_category: list[str] | None = None, scorer_class: str, score_rationale: str = "Test rationale", score_value_description: str = "Test description", - score_metadata: Optional[dict] = None, + score_metadata: dict | None = None, ) -> Score: """Create a score with common defaults. @@ -254,10 +253,10 @@ def create_attack( *, objective_target: MagicMock, adversarial_chat: MagicMock, - objective_scorer: Optional[MagicMock] = None, - refusal_scorer: Optional[MagicMock] = None, - prompt_normalizer: Optional[MagicMock] = None, - system_prompt_path: Optional[Path] = None, + objective_scorer: MagicMock | None = None, + refusal_scorer: MagicMock | None = None, + prompt_normalizer: MagicMock | None = None, + system_prompt_path: Path | None = None, **kwargs, ) -> CrescendoAttack: """Create a CrescendoAttack instance with flexible configuration. @@ -909,7 +908,7 @@ async def test_parse_adversarial_response_with_various_inputs( mock_objective_target: MagicMock, mock_adversarial_chat: MagicMock, response_json: str, - expected_error: Optional[str], + expected_error: str | None, ): """Test parsing adversarial response with various inputs. diff --git a/tests/unit/executor/attack/multi_turn/test_red_teaming.py b/tests/unit/executor/attack/multi_turn/test_red_teaming.py index 173b9c2f22..e5e14f1eff 100644 --- a/tests/unit/executor/attack/multi_turn/test_red_teaming.py +++ b/tests/unit/executor/attack/multi_turn/test_red_teaming.py @@ -3,7 +3,6 @@ import uuid from pathlib import Path -from typing import Union from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -215,7 +214,7 @@ def test_init_with_seed_prompt_variations( mock_objective_target: MagicMock, mock_objective_scorer: MagicMock, mock_adversarial_chat: MagicMock, - seed_prompt: Union[str, SeedPrompt], + seed_prompt: str | SeedPrompt, expected_value: str, expected_type: type, ): diff --git a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py index 5c805455bf..aebd00702b 100644 --- a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py @@ -6,7 +6,7 @@ import logging import uuid from dataclasses import dataclass, field -from typing import Any, Optional, cast +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -50,11 +50,11 @@ class NodeMockConfig: """Configuration for creating mock _TreeOfAttacksNode objects.""" node_id: str = field(default_factory=lambda: str(uuid.uuid4())) - parent_id: Optional[str] = None + parent_id: str | None = None prompt_sent: bool = False completed: bool = True off_topic: bool = False - objective_score_value: Optional[float] = None + objective_score_value: float | None = None auxiliary_scores: dict[str, float] = field(default_factory=dict) objective_target_conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) adversarial_chat_conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) @@ -64,7 +64,7 @@ class MockNodeFactory: """Factory for creating mock _TreeOfAttacksNode objects.""" @staticmethod - def create_node(config: Optional[NodeMockConfig] = None) -> "_TreeOfAttacksNode": + def create_node(config: NodeMockConfig | None = None) -> "_TreeOfAttacksNode": """Create a mock _TreeOfAttacksNode with the given configuration.""" if config is None: config = NodeMockConfig() @@ -150,14 +150,14 @@ class AttackBuilder: """Builder for creating TreeOfAttacksWithPruningAttack instances with common configurations.""" def __init__(self) -> None: - self.objective_target: Optional[PromptTarget] = None - self.adversarial_chat: Optional[PromptTarget] = None - self.objective_scorer: Optional[Scorer] = None + self.objective_target: PromptTarget | None = None + self.adversarial_chat: PromptTarget | None = None + self.objective_scorer: Scorer | None = None self.auxiliary_scorers: list[Scorer] = [] self.tree_params: dict[str, Any] = {} - self.converters: Optional[AttackConverterConfig] = None + self.converters: AttackConverterConfig | None = None self.successful_threshold: float = 0.8 - self.prompt_normalizer: Optional[PromptNormalizer] = None + self.prompt_normalizer: PromptNormalizer | None = None self._supports_multi_turn: bool = True def with_default_mocks(self) -> "AttackBuilder": diff --git a/tests/unit/executor/attack/test_attack_parameter_consistency.py b/tests/unit/executor/attack/test_attack_parameter_consistency.py index a23dd6daca..963e84d3e4 100644 --- a/tests/unit/executor/attack/test_attack_parameter_consistency.py +++ b/tests/unit/executor/attack/test_attack_parameter_consistency.py @@ -10,7 +10,6 @@ import uuid from contextlib import suppress -from typing import Optional from unittest.mock import AsyncMock, MagicMock import pytest @@ -911,7 +910,7 @@ def _assert_prepended_text_in_adversarial_context( *, prepended_conversation: list[Message], adversarial_chat_conversation_id: str, - adversarial_chat_mock: Optional[MagicMock] = None, + adversarial_chat_mock: MagicMock | None = None, ) -> None: """ Assert that text content from prepended conversation appears in adversarial chat context. diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index cdb611e64b..61fbc860b9 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -3,7 +3,7 @@ import uuid -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import pytest @@ -1270,7 +1270,7 @@ def test_get_unique_attack_labels_deduplicates_across_sources(sqlite_instance: M def _make_attack_result_with_identifier( conversation_id: str, class_name: str, - converter_class_names: Optional[list[str]] = None, + converter_class_names: list[str] | None = None, ) -> AttackResult: """Helper to create an AttackResult with a ComponentIdentifier containing converters.""" children: dict = {} diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 4469de670b..f818e45ecd 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. from datetime import datetime, timedelta, timezone -from typing import Optional import pytest from unit.mocks import get_mock_scorer_identifier @@ -42,7 +41,7 @@ def create_scenario_result( name: str = "Test Scenario", description: str = "Test Description", version: int = 1, - attack_results: Optional[dict[str, list[AttackResult]]] = None, + attack_results: dict[str, list[AttackResult]] | None = None, ): """Helper function to create ScenarioResult.""" scenario_identifier = ScenarioIdentifier( diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 2d34d455b5..ffe2fecd9f 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -7,7 +7,6 @@ import uuid from collections.abc import Generator, MutableSequence, Sequence from contextlib import AbstractAsyncContextManager -from typing import Optional from unittest.mock import MagicMock, patch from pyrit.memory import AzureSQLMemory, CentralMemory, PromptMemoryEntry @@ -140,8 +139,8 @@ def set_system_prompt( *, system_prompt: str, conversation_id: str, - attack_identifier: Optional[ComponentIdentifier] = None, - labels: Optional[dict[str, str]] = None, + attack_identifier: ComponentIdentifier | None = None, + labels: dict[str, str] | None = None, ) -> None: self.system_prompt = system_prompt if self._memory: diff --git a/tests/unit/prompt_target/target/test_openai_target_auth.py b/tests/unit/prompt_target/target/test_openai_target_auth.py index c92614a61d..99cc42fc3e 100644 --- a/tests/unit/prompt_target/target/test_openai_target_auth.py +++ b/tests/unit/prompt_target/target/test_openai_target_auth.py @@ -4,7 +4,6 @@ import asyncio import os from collections.abc import Callable -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -40,8 +39,8 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation): def _build_target( *, endpoint: str = "https://test.openai.azure.com/openai/v1", - api_key: Optional[str | Callable] = "test-key", - env_vars: Optional[dict[str, str]] = None, + api_key: str | Callable | None = "test-key", + env_vars: dict[str, str] | None = None, ) -> _ConcreteOpenAITarget: """Helper to build a _ConcreteOpenAITarget with controlled env.""" env = {"TEST_MODEL": "gpt-4", "TEST_ENDPOINT": endpoint} diff --git a/tests/unit/registry/test_scorer_registry.py b/tests/unit/registry/test_scorer_registry.py index b0d0f3c5d5..6e458ec3eb 100644 --- a/tests/unit/registry/test_scorer_registry.py +++ b/tests/unit/registry/test_scorer_registry.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.models import ComponentIdentifier, Message, MessagePiece, Score from pyrit.registry.object_registries.scorer_registry import ScorerRegistry @@ -35,10 +34,10 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier() - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: return [] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] def validate_return_scores(self, scores: list[Score]): @@ -59,10 +58,10 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier() - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: return [] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] def validate_return_scores(self, scores: list[Score]): @@ -83,16 +82,16 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier() - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: return [] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] def validate_return_scores(self, scores: list[Score]): pass - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: return [ Score( score_value="false", diff --git a/tests/unit/score/test_audio_scorer.py b/tests/unit/score/test_audio_scorer.py index 3b0e51abb4..ca38a8168c 100644 --- a/tests/unit/score/test_audio_scorer.py +++ b/tests/unit/score/test_audio_scorer.py @@ -4,7 +4,6 @@ import os import tempfile import uuid -from typing import Optional from unittest.mock import AsyncMock, patch import pytest @@ -29,7 +28,7 @@ def __init__(self, return_value: bool = True): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_type="true_false", @@ -56,7 +55,7 @@ def __init__(self, return_value: float = 0.8): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_type="float_scale", diff --git a/tests/unit/score/test_conversation_history_scorer.py b/tests/unit/score/test_conversation_history_scorer.py index dc0be088e8..0e957482a2 100644 --- a/tests/unit/score/test_conversation_history_scorer.py +++ b/tests/unit/score/test_conversation_history_scorer.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import uuid -from typing import Optional from unittest.mock import AsyncMock, MagicMock import pytest @@ -37,7 +36,7 @@ def __init__(self): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] @@ -50,7 +49,7 @@ def __init__(self): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] @@ -63,13 +62,13 @@ def __init__(self): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] def validate_return_scores(self, scores: list[Score]): pass - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: return [ Score( score_value="false", @@ -753,7 +752,7 @@ def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() async def _score_async( # type: ignore[override] - self, message: Message, *, objective: Optional[str] = None + self, message: Message, *, objective: str | None = None ) -> list[Score]: captured_messages.append(message) piece = message.message_pieces[0] @@ -773,9 +772,7 @@ async def _score_async( # type: ignore[override] ] return [] - async def _score_piece_async( - self, message_piece: MessagePiece, *, objective: Optional[str] = None - ) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] inner_scorer = HarmfulContentDetector() diff --git a/tests/unit/score/test_float_scale_score_aggregator.py b/tests/unit/score/test_float_scale_score_aggregator.py index c726cd331a..19fac4bbff 100644 --- a/tests/unit/score/test_float_scale_score_aggregator.py +++ b/tests/unit/score/test_float_scale_score_aggregator.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.models import ComponentIdentifier, Score from pyrit.score.float_scale.float_scale_score_aggregator import ( @@ -17,7 +16,7 @@ ) -def _mk_score(val: float, *, category: Optional[list[str]] = None, prr_id: str = "1", rationale: str = "") -> Score: +def _mk_score(val: float, *, category: list[str] | None = None, prr_id: str = "1", rationale: str = "") -> Score: """Helper to create a float scale score.""" return Score( score_value=str(val), diff --git a/tests/unit/score/test_float_scale_threshold_scorer.py b/tests/unit/score/test_float_scale_threshold_scorer.py index de91ff5d62..b98cb183d8 100644 --- a/tests/unit/score/test_float_scale_threshold_scorer.py +++ b/tests/unit/score/test_float_scale_threshold_scorer.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import uuid -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -237,9 +236,7 @@ def __init__(self): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async( - self, message_piece: MessagePiece, *, objective: Optional[str] = None - ) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_value="0.9", diff --git a/tests/unit/score/test_gandalf_scorer.py b/tests/unit/score/test_gandalf_scorer.py index 13ecaeed06..e47a4c39cd 100644 --- a/tests/unit/score/test_gandalf_scorer.py +++ b/tests/unit/score/test_gandalf_scorer.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import uuid -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -15,7 +14,7 @@ from pyrit.score import GandalfScorer -def generate_password_extraction_response(response_text: str, conversation_id: Optional[str] = None) -> Message: +def generate_password_extraction_response(response_text: str, conversation_id: str | None = None) -> Message: return Message( message_pieces=[ MessagePiece( @@ -30,7 +29,7 @@ def generate_password_extraction_response(response_text: str, conversation_id: O ) -def generate_request(conversation_id: Optional[str] = None) -> Message: +def generate_request(conversation_id: str | None = None) -> Message: return Message( message_pieces=[ MessagePiece( diff --git a/tests/unit/score/test_scorer.py b/tests/unit/score/test_scorer.py index 8996ebdb6a..aaf1b964f7 100644 --- a/tests/unit/score/test_scorer.py +++ b/tests/unit/score/test_scorer.py @@ -4,7 +4,6 @@ import asyncio import uuid from textwrap import dedent -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -66,7 +65,7 @@ def _build_identifier(self) -> ComponentIdentifier: """Build the scorer evaluation identifier for this mock scorer.""" return self._create_identifier() - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: return [ Score( score_value="true", @@ -81,7 +80,7 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non ) ] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_value="true", @@ -122,7 +121,7 @@ def _build_identifier(self) -> ComponentIdentifier: """Build the scorer evaluation identifier for this mock scorer.""" return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: # Track which pieces get scored self.scored_piece_ids.append(str(message_piece.id)) @@ -144,7 +143,7 @@ def validate_return_scores(self, scores: list[Score]): for score in scores: assert 0 <= float(score.score_value) <= 1 - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: return [ Score( score_value="0.0", @@ -1168,9 +1167,7 @@ def _build_identifier(self) -> ComponentIdentifier: """Build the scorer evaluation identifier for this test scorer.""" return self._create_identifier() - async def _score_piece_async( - self, message_piece: MessagePiece, *, objective: Optional[str] = None - ) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: self.scored_piece_ids.append(message_piece.id) return [ Score( @@ -1355,7 +1352,7 @@ def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() async def _score_piece_async( - self, message_piece: MessagePiece, *, objective: Optional[str] = None + self, message_piece: MessagePiece, *, objective: str | None = None ) -> list[Score]: # Return empty list to simulate no scorable pieces return [] @@ -1482,7 +1479,7 @@ def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() async def _score_piece_async( - self, message_piece: MessagePiece, *, objective: Optional[str] = None + self, message_piece: MessagePiece, *, objective: str | None = None ) -> list[Score]: return [] @@ -1622,7 +1619,7 @@ async def test_score_value_with_llm_skips_reasoning_piece(good_json): class _AcceptAllValidator(ScorerPromptValidator): """Validator that accepts all pieces (like SelfAskRefusalScorer's default).""" - def validate(self, message: Message, objective: Optional[str] = None) -> None: + def validate(self, message: Message, objective: str | None = None) -> None: pass def is_message_piece_supported(self, message_piece: MessagePiece) -> bool: @@ -1635,21 +1632,21 @@ class _TextOnlyValidator(ScorerPromptValidator): def __init__(self) -> None: super().__init__(supported_data_types=["text", "image_path"]) - def validate(self, message: Message, objective: Optional[str] = None) -> None: + def validate(self, message: Message, objective: str | None = None) -> None: pass class _BlockedContentScorer(TrueFalseScorer): """A mock TrueFalseScorer that records what pieces it was asked to score.""" - def __init__(self, *, validator: Optional[ScorerPromptValidator] = None) -> None: + def __init__(self, *, validator: ScorerPromptValidator | None = None) -> None: super().__init__(validator=validator or _TextOnlyValidator()) self.scored_pieces: list[MessagePiece] = [] def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: self.scored_pieces.append(message_piece) return [ Score( @@ -1676,7 +1673,7 @@ def __init__(self) -> None: def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: self.scored_pieces.append(message_piece) if message_piece.response_error == "blocked": return [ @@ -1707,7 +1704,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op ] -def _make_blocked_piece(*, partial_content: Optional[str] = None, conversation_id: str = "test-convo") -> MessagePiece: +def _make_blocked_piece(*, partial_content: str | None = None, conversation_id: str = "test-convo") -> MessagePiece: """Create a blocked MessagePiece, optionally with partial content metadata.""" metadata: dict = {} if partial_content is not None: diff --git a/tests/unit/score/test_true_false_composite_scorer.py b/tests/unit/score/test_true_false_composite_scorer.py index af11b5e045..e82aeba9c5 100644 --- a/tests/unit/score/test_true_false_composite_scorer.py +++ b/tests/unit/score/test_true_false_composite_scorer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from unittest.mock import MagicMock import pytest @@ -46,7 +45,7 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_value=str(self._score_value), @@ -154,9 +153,7 @@ def __init__(self): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async( - self, message_piece: MessagePiece, *, objective: Optional[str] = None - ) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] with pytest.raises(ValueError, match="All scorers must be true_false scorers"): diff --git a/tests/unit/score/test_video_scorer.py b/tests/unit/score/test_video_scorer.py index 4c5d4d3524..28dcbba779 100644 --- a/tests/unit/score/test_video_scorer.py +++ b/tests/unit/score/test_video_scorer.py @@ -3,7 +3,6 @@ import os import uuid -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import numpy as np @@ -74,7 +73,7 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_type="true_false", @@ -106,7 +105,7 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_type="float_scale", @@ -295,7 +294,7 @@ def __init__(self, return_value: bool = True): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: self.received_objective = objective return [ Score(