diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index c5318a0642..c8236aceda 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -267,6 +267,13 @@ def _build_init_kwargs( Resolves strategies and dataset configuration from the request. + Dataset configuration is built so that the scenario's default + ``DatasetConfiguration`` *subclass* (e.g. ``EncodingDatasetConfiguration``) + is preserved when the caller overrides ``dataset_names`` or + ``max_dataset_size``. Subclasses commonly override + ``get_all_seed_attack_groups()`` or ``_load_seed_groups_for_dataset()`` + to shape seeds into scenario-appropriate ``SeedAttackGroup`` objects. + Args: request: The run request. scenario_class: The resolved scenario class. @@ -276,7 +283,10 @@ def _build_init_kwargs( Dict of kwargs to pass to scenario.initialize_async. Raises: - ValueError: If a strategy name is invalid for the scenario. + ValueError: If a strategy name is invalid for the scenario, or the + scenario class cannot be instantiated with no arguments when + introspection is required to resolve strategies or dataset + configuration. """ init_kwargs: dict[str, Any] = { "objective_target": objective_target, @@ -287,47 +297,78 @@ def _build_init_kwargs( if request.labels: init_kwargs["memory_labels"] = request.labels - # Resolve strategies and default dataset config from a temporary instance - # of the scenario. The downstream _initialize_scenario_async builds its - # own instance (so that scenario_result_id can be passed), so this is a - # cheap throwaway used only for introspection. - needs_introspection = bool(request.strategies) or ( - request.max_dataset_size is not None and not request.dataset_names + # Resolve strategies and dataset config from a temporary instance of the + # scenario. The downstream _initialize_scenario_async builds its own + # instance (so scenario_result_id can be passed), so this is a cheap + # throwaway used only for introspection. Introspection is required + # whenever the caller wants to override strategies, dataset names, or + # the sample cap, because each of those needs the scenario's own + # strategy enum or dataset-config subclass to be resolved correctly. + needs_introspection = ( + bool(request.strategies) or bool(request.dataset_names) or request.max_dataset_size is not None ) - if needs_introspection: - try: - introspection_instance = scenario_class() # type: ignore[ty:missing-argument] - except Exception as exc: - raise ValueError( - f"Cannot resolve runtime configuration for scenario '{request.scenario_name}': " - f"scenario class is not instantiable without arguments ({exc})." - ) from exc - - if request.strategies: - strategy_class = introspection_instance._strategy_class - strategy_enums = [] - for name in request.strategies: - try: - strategy_enums.append(strategy_class(name)) - except ValueError: - available_strategies = [s.value for s in strategy_class] - raise ValueError( - f"Strategy '{name}' not found for scenario '{request.scenario_name}'. " - f"Available: {', '.join(available_strategies)}" - ) from None - init_kwargs["scenario_strategies"] = strategy_enums - - if request.max_dataset_size is not None and not request.dataset_names: - default_config = introspection_instance._default_dataset_config + if not needs_introspection: + return init_kwargs + + try: + introspection_instance = scenario_class() # type: ignore[ty:missing-argument] + except Exception as exc: + raise ValueError( + f"Cannot resolve runtime configuration for scenario '{request.scenario_name}': " + f"scenario class is not instantiable without arguments ({exc})." + ) from exc + + if request.strategies: + strategy_class = introspection_instance._strategy_class + strategy_enums = [] + for name in request.strategies: + try: + strategy_enums.append(strategy_class(name)) + except ValueError: + available_strategies = [s.value for s in strategy_class] + raise ValueError( + f"Strategy '{name}' not found for scenario '{request.scenario_name}'. " + f"Available: {', '.join(available_strategies)}" + ) from None + init_kwargs["scenario_strategies"] = strategy_enums + + if request.dataset_names or request.max_dataset_size is not None: + default_config = introspection_instance._default_dataset_config + + if request.dataset_names: + # Construct a fresh instance of the scenario's own dataset-config + # class so subclass-specific behavior is preserved. + default_config_class = type(default_config) + try: + init_kwargs["dataset_config"] = default_config_class( + dataset_names=request.dataset_names, + max_dataset_size=request.max_dataset_size, + ) + except TypeError as exc: + # The subclass __init__ takes extra required kwargs we cannot + # supply from a backend request. Fall back to the base + # DatasetConfiguration so the run can still proceed; downstream + # scenarios that strictly require the subclass should either + # define a no-extra-required-args constructor or surface the + # incompatibility through their own initialize_async validation. + logger.warning( + "Cannot construct %s(dataset_names=..., max_dataset_size=...) (%s). " + "Falling back to a generic DatasetConfiguration; scenario-specific " + "dataset-config behavior may be lost.", + default_config_class.__name__, + exc, + ) + init_kwargs["dataset_config"] = DatasetConfiguration( + dataset_names=request.dataset_names, + max_dataset_size=request.max_dataset_size, + ) + elif request.max_dataset_size is not None: + # Reuse the scenario's default dataset config (preserves subtype + + # the scenario's own default dataset names) and override only the + # sample cap. Safe because the introspection instance is throwaway. default_config.max_dataset_size = request.max_dataset_size init_kwargs["dataset_config"] = default_config - if request.dataset_names: - init_kwargs["dataset_config"] = DatasetConfiguration( - dataset_names=request.dataset_names, - max_dataset_size=request.max_dataset_size, - ) - return init_kwargs async def _initialize_scenario_async( diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 9827a22e8e..0a7463d1f0 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -6,6 +6,7 @@ """ from datetime import datetime, timezone +from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -20,6 +21,7 @@ ScenarioRunService, ) from pyrit.models import AttackOutcome +from pyrit.scenario.core import DatasetConfiguration _REGISTRY_PATCH_BASE = "pyrit.registry" _MEMORY_PATCH = "pyrit.memory.CentralMemory.get_memory_instance" @@ -279,6 +281,139 @@ async def test_start_run_max_dataset_size_uses_default_config(self, mock_all_reg init_call = scenario_instance.initialize_async.await_args assert init_call.kwargs["dataset_config"] is default_config + async def test_start_run_dataset_names_preserves_subclass_config_type(self, mock_all_registries) -> None: + """``dataset_names`` rebuilds the config using the scenario's own DatasetConfiguration subclass. + + Regression: passing ``dataset_names`` via the backend used to construct + a plain ``DatasetConfiguration``, silently losing subclass behavior + (e.g. ``EncodingDatasetConfiguration``'s objective shaping). + """ + + # Create a marker subclass so we can verify type preservation without + # depending on any concrete scenario implementation. + class _MarkerDatasetConfiguration(DatasetConfiguration): + pass + + default_config = _MarkerDatasetConfiguration(dataset_names=["original"], max_dataset_size=100) + scenario_instance = mock_all_registries["scenario_instance"] + scenario_instance._default_dataset_config = default_config + + service = ScenarioRunService() + await service.start_run_async(request=_make_request(dataset_names=["custom_a", "custom_b"], max_dataset_size=3)) + + init_call = scenario_instance.initialize_async.await_args + built_config = init_call.kwargs["dataset_config"] + + # Type is preserved (this is the regression assertion) + assert type(built_config) is _MarkerDatasetConfiguration + # And carries the caller-supplied values, not the scenario defaults + assert built_config.get_default_dataset_names() == ["custom_a", "custom_b"] + assert built_config.max_dataset_size == 3 + # The original default config is not mutated when a fresh dataset_names is supplied + assert default_config.get_default_dataset_names() == ["original"] + assert default_config.max_dataset_size == 100 + + async def test_start_run_dataset_names_without_max_dataset_size_preserves_subclass( + self, mock_all_registries + ) -> None: + """``dataset_names`` alone (no ``max_dataset_size``) still preserves the subclass type.""" + + class _MarkerDatasetConfiguration(DatasetConfiguration): + pass + + scenario_instance = mock_all_registries["scenario_instance"] + scenario_instance._default_dataset_config = _MarkerDatasetConfiguration(dataset_names=["original"]) + + service = ScenarioRunService() + await service.start_run_async(request=_make_request(dataset_names=["only_this"])) + + init_call = scenario_instance.initialize_async.await_args + built_config = init_call.kwargs["dataset_config"] + assert type(built_config) is _MarkerDatasetConfiguration + assert built_config.get_default_dataset_names() == ["only_this"] + assert built_config.max_dataset_size is None + + async def test_start_run_dataset_names_falls_back_when_subclass_constructor_incompatible( + self, mock_all_registries, caplog + ) -> None: + """If the subclass __init__ rejects standard kwargs, fall back to plain ``DatasetConfiguration``.""" + + class _RequiresExtraArgConfiguration(DatasetConfiguration): + def __init__(self, *, required_extra: str, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._required_extra = required_extra + + scenario_instance = mock_all_registries["scenario_instance"] + # Build the default with the required kwarg so introspection succeeds. + scenario_instance._default_dataset_config = _RequiresExtraArgConfiguration( + required_extra="seeded", dataset_names=["original"] + ) + + service = ScenarioRunService() + with caplog.at_level("WARNING", logger=_svc_mod.logger.name): + await service.start_run_async(request=_make_request(dataset_names=["custom"])) + + init_call = scenario_instance.initialize_async.await_args + built_config = init_call.kwargs["dataset_config"] + + # Fallback is the generic base class, not the subclass + assert type(built_config) is DatasetConfiguration + assert built_config.get_default_dataset_names() == ["custom"] + # Warning was logged so the operator can see the silent degradation + assert any( + "_RequiresExtraArgConfiguration" in record.message + and "Falling back to a generic DatasetConfiguration" in record.message + for record in caplog.records + ) + + async def test_start_run_dataset_names_introspection_failure_raises(self, mock_memory) -> None: + """Passing ``dataset_names`` against a non-no-arg-instantiable scenario fails fast.""" + # Mirrors test_start_run_scenario_not_no_arg_instantiable_raises but for the dataset_names path. + mock_scenario_class = MagicMock( + side_effect=[ + TypeError("missing 1 required positional argument: 'objective_target'"), + ] + ) + mock_sr = MagicMock() + mock_sr.get_class.return_value = mock_scenario_class + + mock_tr = MagicMock() + mock_tr.get_instance_by_name.return_value = MagicMock() + mock_tr.get_names.return_value = ["my_target"] + + mock_ir = MagicMock() + + service = ScenarioRunService() + + with ( + patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), + patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_tr), + patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton", return_value=mock_ir), + ): + with pytest.raises(ValueError, match="not instantiable without arguments"): + await service.start_run_async(request=_make_request(dataset_names=["custom"])) + + async def test_start_run_max_dataset_size_with_dataset_names_uses_subclass_with_both( + self, mock_all_registries + ) -> None: + """When both ``dataset_names`` and ``max_dataset_size`` are supplied, both flow into the subclass instance.""" + + class _MarkerDatasetConfiguration(DatasetConfiguration): + pass + + scenario_instance = mock_all_registries["scenario_instance"] + scenario_instance._default_dataset_config = _MarkerDatasetConfiguration( + dataset_names=["original"], max_dataset_size=99 + ) + + service = ScenarioRunService() + await service.start_run_async(request=_make_request(dataset_names=["a", "b"], max_dataset_size=7)) + + built_config = scenario_instance.initialize_async.await_args.kwargs["dataset_config"] + assert type(built_config) is _MarkerDatasetConfiguration + assert built_config.get_default_dataset_names() == ["a", "b"] + assert built_config.max_dataset_size == 7 + async def test_start_run_exceeds_concurrent_limit(self, mock_all_registries) -> None: """Test that exceeding concurrent run limit raises ValueError.""" service = ScenarioRunService()