Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 82 additions & 38 deletions pyrit/backend/services/scenario_run_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,16 @@ def _build_init_kwargs(

Resolves strategies and dataset configuration from the request.

Dataset configuration is built so that the scenario's default
``DatasetConfiguration`` *subclass* (e.g. ``EncodingDatasetConfiguration``)
is preserved when the caller overrides ``dataset_names`` or
``max_dataset_size``. Subclasses commonly override
``get_all_seed_attack_groups()`` or ``_load_seed_groups_for_dataset()``
to shape seeds into scenario-appropriate ``SeedAttackGroup`` objects;
substituting a plain ``DatasetConfiguration`` silently loses that
behavior and surfaces as obscure downstream errors at attack-construction
time (e.g. ``SeedAttackGroup must have exactly one objective``).

Args:
request: The run request.
scenario_class: The resolved scenario class.
Expand All @@ -276,7 +286,10 @@ def _build_init_kwargs(
Dict of kwargs to pass to scenario.initialize_async.

Raises:
ValueError: If a strategy name is invalid for the scenario.
ValueError: If a strategy name is invalid for the scenario, or the
scenario class cannot be instantiated with no arguments when
introspection is required to resolve strategies or dataset
configuration.
"""
init_kwargs: dict[str, Any] = {
"objective_target": objective_target,
Expand All @@ -287,46 +300,77 @@ def _build_init_kwargs(
if request.labels:
init_kwargs["memory_labels"] = request.labels

# Resolve strategies and default dataset config from a temporary instance
# of the scenario. The downstream _initialize_scenario_async builds its
# own instance (so that scenario_result_id can be passed), so this is a
# cheap throwaway used only for introspection.
needs_introspection = bool(request.strategies) or (
request.max_dataset_size is not None and not request.dataset_names
# Resolve strategies and dataset config from a temporary instance of the
# scenario. The downstream _initialize_scenario_async builds its own
# instance (so scenario_result_id can be passed), so this is a cheap
# throwaway used only for introspection. Introspection is required
# whenever the caller wants to override strategies, dataset names, or
# the sample cap, because each of those needs the scenario's own
# strategy enum or dataset-config subclass to be resolved correctly.
needs_introspection = (
bool(request.strategies) or bool(request.dataset_names) or request.max_dataset_size is not None
)
if needs_introspection:
try:
introspection_instance = scenario_class() # type: ignore[ty:missing-argument]
except Exception as exc:
raise ValueError(
f"Cannot resolve runtime configuration for scenario '{request.scenario_name}': "
f"scenario class is not instantiable without arguments ({exc})."
) from exc

if request.strategies:
strategy_class = introspection_instance._strategy_class
strategy_enums = []
for name in request.strategies:
try:
strategy_enums.append(strategy_class(name))
except ValueError:
available_strategies = [s.value for s in strategy_class]
raise ValueError(
f"Strategy '{name}' not found for scenario '{request.scenario_name}'. "
f"Available: {', '.join(available_strategies)}"
) from None
init_kwargs["scenario_strategies"] = strategy_enums

if request.max_dataset_size is not None and not request.dataset_names:
default_config = introspection_instance._default_dataset_config
default_config.max_dataset_size = request.max_dataset_size
init_kwargs["dataset_config"] = default_config
if not needs_introspection:
return init_kwargs

try:
introspection_instance = scenario_class() # type: ignore[ty:missing-argument]
except Exception as exc:
raise ValueError(
f"Cannot resolve runtime configuration for scenario '{request.scenario_name}': "
f"scenario class is not instantiable without arguments ({exc})."
) from exc

if request.strategies:
strategy_class = introspection_instance._strategy_class
strategy_enums = []
for name in request.strategies:
try:
strategy_enums.append(strategy_class(name))
except ValueError:
available_strategies = [s.value for s in strategy_class]
raise ValueError(
f"Strategy '{name}' not found for scenario '{request.scenario_name}'. "
f"Available: {', '.join(available_strategies)}"
) from None
init_kwargs["scenario_strategies"] = strategy_enums

default_config = introspection_instance._default_dataset_config

if request.dataset_names:
init_kwargs["dataset_config"] = DatasetConfiguration(
dataset_names=request.dataset_names,
max_dataset_size=request.max_dataset_size,
)
# Construct a fresh instance of the scenario's own dataset-config
# class so subclass-specific behavior (e.g. EncodingDatasetConfiguration's
# objective-shaping in get_all_seed_attack_groups) is preserved.
default_config_class = type(default_config)
try:
init_kwargs["dataset_config"] = default_config_class(
dataset_names=request.dataset_names,
max_dataset_size=request.max_dataset_size,
)
except TypeError as exc:
# The subclass __init__ takes extra required kwargs we cannot
# supply from a backend request. Fall back to the base
# DatasetConfiguration so the run can still proceed; downstream
# scenarios that strictly require the subclass should either
# define a no-extra-required-args constructor or surface the
# incompatibility through their own initialize_async validation.
logger.warning(
"Cannot construct %s(dataset_names=..., max_dataset_size=...) (%s). "
"Falling back to a generic DatasetConfiguration; scenario-specific "
"dataset-config behavior may be lost.",
default_config_class.__name__,
exc,
)
init_kwargs["dataset_config"] = DatasetConfiguration(
dataset_names=request.dataset_names,
max_dataset_size=request.max_dataset_size,
)
elif request.max_dataset_size is not None:
# Reuse the scenario's default dataset config (preserves subtype +
# the scenario's own default dataset names) and override only the
# sample cap. Safe because the introspection instance is throwaway.
default_config.max_dataset_size = request.max_dataset_size
init_kwargs["dataset_config"] = default_config

return init_kwargs

Expand Down
135 changes: 135 additions & 0 deletions tests/unit/backend/test_scenario_run_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

from datetime import datetime, timezone
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
Expand All @@ -20,6 +21,7 @@
ScenarioRunService,
)
from pyrit.models import AttackOutcome
from pyrit.scenario.core import DatasetConfiguration

_REGISTRY_PATCH_BASE = "pyrit.registry"
_MEMORY_PATCH = "pyrit.memory.CentralMemory.get_memory_instance"
Expand Down Expand Up @@ -279,6 +281,139 @@ async def test_start_run_max_dataset_size_uses_default_config(self, mock_all_reg
init_call = scenario_instance.initialize_async.await_args
assert init_call.kwargs["dataset_config"] is default_config

async def test_start_run_dataset_names_preserves_subclass_config_type(self, mock_all_registries) -> None:
"""``dataset_names`` rebuilds the config using the scenario's own DatasetConfiguration subclass.

Regression: passing ``dataset_names`` via the backend used to construct
a plain ``DatasetConfiguration``, silently losing subclass behavior
(e.g. ``EncodingDatasetConfiguration``'s objective shaping).
"""

# Create a marker subclass so we can verify type preservation without
# depending on any concrete scenario implementation.
class _MarkerDatasetConfiguration(DatasetConfiguration):
pass

default_config = _MarkerDatasetConfiguration(dataset_names=["original"], max_dataset_size=100)
scenario_instance = mock_all_registries["scenario_instance"]
scenario_instance._default_dataset_config = default_config

service = ScenarioRunService()
await service.start_run_async(request=_make_request(dataset_names=["custom_a", "custom_b"], max_dataset_size=3))

init_call = scenario_instance.initialize_async.await_args
built_config = init_call.kwargs["dataset_config"]

# Type is preserved (this is the regression assertion)
assert type(built_config) is _MarkerDatasetConfiguration
# And carries the caller-supplied values, not the scenario defaults
assert built_config.get_default_dataset_names() == ["custom_a", "custom_b"]
assert built_config.max_dataset_size == 3
# The original default config is not mutated when a fresh dataset_names is supplied
assert default_config.get_default_dataset_names() == ["original"]
assert default_config.max_dataset_size == 100

async def test_start_run_dataset_names_without_max_dataset_size_preserves_subclass(
self, mock_all_registries
) -> None:
"""``dataset_names`` alone (no ``max_dataset_size``) still preserves the subclass type."""

class _MarkerDatasetConfiguration(DatasetConfiguration):
pass

scenario_instance = mock_all_registries["scenario_instance"]
scenario_instance._default_dataset_config = _MarkerDatasetConfiguration(dataset_names=["original"])

service = ScenarioRunService()
await service.start_run_async(request=_make_request(dataset_names=["only_this"]))

init_call = scenario_instance.initialize_async.await_args
built_config = init_call.kwargs["dataset_config"]
assert type(built_config) is _MarkerDatasetConfiguration
assert built_config.get_default_dataset_names() == ["only_this"]
assert built_config.max_dataset_size is None

async def test_start_run_dataset_names_falls_back_when_subclass_constructor_incompatible(
self, mock_all_registries, caplog
) -> None:
"""If the subclass __init__ rejects standard kwargs, fall back to plain ``DatasetConfiguration``."""

class _RequiresExtraArgConfiguration(DatasetConfiguration):
def __init__(self, *, required_extra: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._required_extra = required_extra

scenario_instance = mock_all_registries["scenario_instance"]
# Build the default with the required kwarg so introspection succeeds.
scenario_instance._default_dataset_config = _RequiresExtraArgConfiguration(
required_extra="seeded", dataset_names=["original"]
)

service = ScenarioRunService()
with caplog.at_level("WARNING", logger=_svc_mod.logger.name):
await service.start_run_async(request=_make_request(dataset_names=["custom"]))

init_call = scenario_instance.initialize_async.await_args
built_config = init_call.kwargs["dataset_config"]

# Fallback is the generic base class, not the subclass
assert type(built_config) is DatasetConfiguration
assert built_config.get_default_dataset_names() == ["custom"]
# Warning was logged so the operator can see the silent degradation
assert any(
"_RequiresExtraArgConfiguration" in record.message
and "Falling back to a generic DatasetConfiguration" in record.message
for record in caplog.records
)

async def test_start_run_dataset_names_introspection_failure_raises(self, mock_memory) -> None:
"""Passing ``dataset_names`` against a non-no-arg-instantiable scenario fails fast."""
# Mirrors test_start_run_scenario_not_no_arg_instantiable_raises but for the dataset_names path.
mock_scenario_class = MagicMock(
side_effect=[
TypeError("missing 1 required positional argument: 'objective_target'"),
]
)
mock_sr = MagicMock()
mock_sr.get_class.return_value = mock_scenario_class

mock_tr = MagicMock()
mock_tr.get_instance_by_name.return_value = MagicMock()
mock_tr.get_names.return_value = ["my_target"]

mock_ir = MagicMock()

service = ScenarioRunService()

with (
patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr),
patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_tr),
patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton", return_value=mock_ir),
):
with pytest.raises(ValueError, match="not instantiable without arguments"):
await service.start_run_async(request=_make_request(dataset_names=["custom"]))

async def test_start_run_max_dataset_size_with_dataset_names_uses_subclass_with_both(
self, mock_all_registries
) -> None:
"""When both ``dataset_names`` and ``max_dataset_size`` are supplied, both flow into the subclass instance."""

class _MarkerDatasetConfiguration(DatasetConfiguration):
pass

scenario_instance = mock_all_registries["scenario_instance"]
scenario_instance._default_dataset_config = _MarkerDatasetConfiguration(
dataset_names=["original"], max_dataset_size=99
)

service = ScenarioRunService()
await service.start_run_async(request=_make_request(dataset_names=["a", "b"], max_dataset_size=7))

built_config = scenario_instance.initialize_async.await_args.kwargs["dataset_config"]
assert type(built_config) is _MarkerDatasetConfiguration
assert built_config.get_default_dataset_names() == ["a", "b"]
assert built_config.max_dataset_size == 7

async def test_start_run_exceeds_concurrent_limit(self, mock_all_registries) -> None:
"""Test that exceeding concurrent run limit raises ValueError."""
service = ScenarioRunService()
Expand Down
Loading