diff --git a/pyrit/datasets/seed_datasets/remote/_image_cache.py b/pyrit/datasets/seed_datasets/remote/_image_cache.py new file mode 100644 index 0000000000..b9d62d019c --- /dev/null +++ b/pyrit/datasets/seed_datasets/remote/_image_cache.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Shared image-fetch-and-cache helper for multimodal seed-dataset loaders. + +Multiple loaders under ``pyrit.datasets.seed_datasets.remote`` need to +download an image from a URL (or write bytes already in hand), persist it +under the ``seed-prompt-entries`` cache, and return the local path while +skipping the network call on a cache hit. This module centralizes that +logic so individual loaders only need to construct the appropriate filename. +""" + +import logging +from collections.abc import Mapping +from pathlib import Path +from typing import Any, Optional + +from pyrit.common.net_utility import make_request_and_raise_if_error_async +from pyrit.models import data_serializer_factory + +logger = logging.getLogger(__name__) + + +async def fetch_and_cache_image_async( + *, + filename: str, + image_url: Optional[str] = None, + image_bytes: Optional[bytes] = None, + log_prefix: str = "image-cache", + request_headers: Optional[Mapping[str, str]] = None, + request_timeout: Optional[float] = None, + follow_redirects: bool = False, +) -> str: + """ + Fetch (or accept) image bytes and cache them under ``seed-prompt-entries``. + + The cached path is constructed deterministically from the configured + ``results_path`` plus the serializer's ``data_sub_directory`` plus + ``filename``, normalized through ``pathlib.Path`` so the same on-disk + location is produced on Windows and POSIX. If a file already exists at + that path, the path is returned immediately without performing the network + fetch or rewriting bytes. + + Exactly one of ``image_url`` or ``image_bytes`` must be provided. When + ``image_bytes`` is supplied, the network is never contacted regardless of + whether ``image_url`` is also provided. + + Args: + filename (str): On-disk filename for the cached image, including + extension (e.g. ``"harmbench_.png"``). The caller controls this + so per-loader naming and existing cache files stay intact. + image_url (str | None): URL to fetch the image from. Required when + ``image_bytes`` is not provided. + image_bytes (bytes | None): Raw image bytes (e.g. extracted from a PIL + image). When provided, no network request is made. + log_prefix (str): Short tag prepended to warning log messages + (e.g. ``"HarmBench-Multimodal"``) so existing log output stays + recognizable per-loader. + request_headers (Mapping[str, str] | None): Optional HTTP headers to + send with the request (only used when fetching from ``image_url``). + request_timeout (float | None): Optional request timeout in seconds. + follow_redirects (bool): Whether the HTTP client should follow + redirects. Defaults to ``False``. + + Returns: + str: Local path to the cached image. + + Raises: + ValueError: If neither ``image_url`` nor ``image_bytes`` is provided. + RuntimeError: If the serializer's underlying memory is not properly + configured (``results_path`` or ``results_storage_io`` missing). + Exception: Any error raised by the underlying HTTP fetch or by + ``serializer.save_data`` is propagated so callers can catch and + skip individual rows. + """ + if image_bytes is None and not image_url: + raise ValueError("fetch_and_cache_image_async requires either image_url or image_bytes") + + extension = Path(filename).suffix.lstrip(".") or None + + serializer = data_serializer_factory( + category="seed-prompt-entries", + data_type="image_path", + extension=extension, + ) + + results_path = serializer._memory.results_path if serializer._memory is not None else None + results_storage_io = serializer._memory.results_storage_io if serializer._memory is not None else None + if not results_path or results_storage_io is None: + raise RuntimeError( + f"[{log_prefix}] Serializer memory is not properly configured: " + "results_path and results_storage_io must be set." + ) + + sub_directory = serializer.data_sub_directory.lstrip("/\\") + serializer.value = str(Path(results_path) / sub_directory / filename) + + try: + if await results_storage_io.path_exists(serializer.value): + return serializer.value + except Exception as e: + logger.warning(f"[{log_prefix}] Failed to check if cached image {filename} exists: {e}") + + if image_bytes is None: + # image_url is guaranteed non-empty by the validation above when image_bytes is None. + assert image_url is not None + httpx_kwargs: dict[str, Any] = {"follow_redirects": follow_redirects} + if request_timeout is not None: + httpx_kwargs["timeout"] = request_timeout + + headers_dict = dict(request_headers) if request_headers is not None else None + response = await make_request_and_raise_if_error_async( + endpoint_uri=image_url, + method="GET", + headers=headers_dict, + **httpx_kwargs, + ) + image_bytes = response.content + + await serializer.save_data(data=image_bytes, output_filename=Path(filename).stem) + + return str(serializer.value) diff --git a/pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py b/pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py index 59a57b5cf4..b4d43bca04 100644 --- a/pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py @@ -6,12 +6,13 @@ from dataclasses import dataclass from typing import Literal -from pyrit.common.net_utility import make_request_and_raise_if_error_async -from pyrit.common.path import DB_DATA_PATH +from pyrit.datasets.seed_datasets.remote._image_cache import ( + fetch_and_cache_image_async, +) from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) -from pyrit.models import Seed, SeedDataset, SeedObjective, SeedPrompt, data_serializer_factory +from pyrit.models import Seed, SeedDataset, SeedObjective, SeedPrompt logger = logging.getLogger(__name__) @@ -346,20 +347,8 @@ async def _fetch_template_async(self, template_name: str) -> str: f"Invalid template name '{template_name}'. Must be one of: {', '.join(self.TEMPLATE_NAMES)}" ) - filename = f"comic_jailbreak_{template_name}.png" - serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png") - - results_path = serializer._memory.results_path or str(DB_DATA_PATH) - storage_io = serializer._memory.results_storage_io - serializer.value = str(results_path + serializer.data_sub_directory + f"/{filename}") - try: - if storage_io and await storage_io.path_exists(serializer.value): - return serializer.value - except Exception as e: - logger.warning(f"[ComicJailbreak] Failed to check cache for template {template_name}: {e}") - - image_url = f"{self.TEMPLATE_BASE_URL}{template_name}.png" - response = await make_request_and_raise_if_error_async(endpoint_uri=image_url, method="GET") - await serializer.save_data(data=response.content, output_filename=filename.replace(".png", "")) - - return str(serializer.value) + return await fetch_and_cache_image_async( + filename=f"comic_jailbreak_{template_name}.png", + image_url=f"{self.TEMPLATE_BASE_URL}{template_name}.png", + log_prefix="ComicJailbreak", + ) diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py index 6e063c6242..514c5a1cb8 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py @@ -6,11 +6,13 @@ from enum import Enum from typing import Literal, Optional -from pyrit.common.net_utility import make_request_and_raise_if_error_async +from pyrit.datasets.seed_datasets.remote._image_cache import ( + fetch_and_cache_image_async, +) from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) -from pyrit.models import SeedDataset, SeedPrompt, data_serializer_factory +from pyrit.models import SeedDataset, SeedPrompt logger = logging.getLogger(__name__) @@ -225,25 +227,8 @@ async def _fetch_and_save_image_async(self, image_url: str, behavior_id: str) -> Raises: RuntimeError: If the serializer memory is not properly configured. """ - filename = f"harmbench_{behavior_id}.png" - serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png") - - # Return existing path if image already exists for this BehaviorID - results_path = serializer._memory.results_path - results_storage_io = serializer._memory.results_storage_io - if not results_path or results_storage_io is None: - raise RuntimeError( - "[HarmBench-Multimodal] Serializer memory is not properly configured: " - "results_path and results_storage_io must be set." - ) - serializer.value = str(results_path + serializer.data_sub_directory + f"/{filename}") - try: - if await results_storage_io.path_exists(serializer.value): - return serializer.value - except Exception as e: - logger.warning(f"[HarmBench-Multimodal] Failed to check if image for {behavior_id} exists in cache: {e}") - - response = await make_request_and_raise_if_error_async(endpoint_uri=image_url, method="GET") - await serializer.save_data(data=response.content, output_filename=filename.replace(".png", "")) - - return str(serializer.value) + return await fetch_and_cache_image_async( + filename=f"harmbench_{behavior_id}.png", + image_url=image_url, + log_prefix="HarmBench-Multimodal", + ) diff --git a/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py index 0028554bec..2edd921ebf 100644 --- a/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py @@ -6,11 +6,13 @@ from enum import Enum from typing import Literal, Optional -from pyrit.common.net_utility import make_request_and_raise_if_error_async +from pyrit.datasets.seed_datasets.remote._image_cache import ( + fetch_and_cache_image_async, +) from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) -from pyrit.models import SeedDataset, SeedPrompt, data_serializer_factory +from pyrit.models import SeedDataset, SeedPrompt logger = logging.getLogger(__name__) @@ -321,23 +323,8 @@ async def _fetch_and_save_image_async(self, image_url: str, example_id: str) -> Returns: str: Local path to the saved image. """ - filename = f"visual_leak_bench_{example_id}.png" - serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png") - - # Return existing path if image already exists - results_path = (serializer._memory.results_path if serializer._memory is not None else None) or "" - serializer.value = str(results_path + serializer.data_sub_directory + f"/{filename}") - try: - if ( - serializer._memory is not None - and serializer._memory.results_storage_io is not None - and await serializer._memory.results_storage_io.path_exists(serializer.value) - ): - return serializer.value - except Exception as e: - logger.warning(f"[VisualLeakBench] Failed to check if image {example_id} exists in cache: {e}") - - response = await make_request_and_raise_if_error_async(endpoint_uri=image_url, method="GET") - await serializer.save_data(data=response.content, output_filename=filename.replace(".png", "")) - - return str(serializer.value) + return await fetch_and_cache_image_async( + filename=f"visual_leak_bench_{example_id}.png", + image_url=image_url, + log_prefix="VisualLeakBench", + ) diff --git a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py index 3be8bd64be..74ccf9434f 100644 --- a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py @@ -6,11 +6,13 @@ from enum import Enum from typing import Literal, Optional -from pyrit.common.net_utility import make_request_and_raise_if_error_async +from pyrit.datasets.seed_datasets.remote._image_cache import ( + fetch_and_cache_image_async, +) from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) -from pyrit.models import SeedDataset, SeedPrompt, data_serializer_factory +from pyrit.models import SeedDataset, SeedPrompt logger = logging.getLogger(__name__) @@ -252,22 +254,7 @@ async def _fetch_and_save_image_async(self, image_url: str, group_id: str) -> st Raises: RuntimeError: If the serializer memory is not properly configured. """ - filename = f"ml_vlsu_{group_id}.png" - serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png") - - # Return existing path if image already exists - results_path = serializer._memory.results_path - results_storage_io = serializer._memory.results_storage_io - if not results_path or results_storage_io is None: - raise RuntimeError("[ML-VLSU] Serializer memory is not properly configured.") - serializer.value = str(results_path + serializer.data_sub_directory + f"/{filename}") - try: - if await results_storage_io.path_exists(serializer.value): - return serializer.value - except Exception as e: - logger.warning(f"[ML-VLSU] Failed to check if image for {group_id} exists in cache: {e}") - - # Add browser-like headers for better success rate + # Browser-like headers improve fetch success rate for the ML-VLSU hosting setup. headers = { "User-Agent": ( "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)" @@ -282,13 +269,11 @@ async def _fetch_and_save_image_async(self, image_url: str, group_id: str) -> st "Upgrade-Insecure-Requests": "1", } - response = await make_request_and_raise_if_error_async( - endpoint_uri=image_url, - method="GET", - headers=headers, - timeout=2.0, + return await fetch_and_cache_image_async( + filename=f"ml_vlsu_{group_id}.png", + image_url=image_url, + log_prefix="ML-VLSU", + request_headers=headers, + request_timeout=2.0, follow_redirects=True, ) - await serializer.save_data(data=response.content, output_filename=filename.replace(".png", "")) - - return str(serializer.value) diff --git a/tests/unit/datasets/test_harmbench_multimodal_dataset.py b/tests/unit/datasets/test_harmbench_multimodal_dataset.py index b0ad4af8c9..c8e7cdd4bd 100644 --- a/tests/unit/datasets/test_harmbench_multimodal_dataset.py +++ b/tests/unit/datasets/test_harmbench_multimodal_dataset.py @@ -153,7 +153,7 @@ async def test_fetch_and_save_image_raises_when_memory_not_configured(): mock_serializer._memory = mock_memory with patch( - "pyrit.datasets.seed_datasets.remote.harmbench_multimodal_dataset.data_serializer_factory", + "pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory", return_value=mock_serializer, ): loader = _HarmBenchMultimodalDataset() @@ -163,6 +163,7 @@ async def test_fetch_and_save_image_raises_when_memory_not_configured(): async def test_fetch_and_save_image_returns_cached_path(): """Test that _fetch_and_save_image_async returns cached path when image already exists.""" + from pathlib import Path from unittest.mock import MagicMock mock_serializer = MagicMock() @@ -175,7 +176,7 @@ async def test_fetch_and_save_image_returns_cached_path(): mock_serializer.data_sub_directory = "/images" with patch( - "pyrit.datasets.seed_datasets.remote.harmbench_multimodal_dataset.data_serializer_factory", + "pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory", return_value=mock_serializer, ): loader = _HarmBenchMultimodalDataset() @@ -183,6 +184,6 @@ async def test_fetch_and_save_image_returns_cached_path(): behavior_id="test_id", image_url="https://example.com/img.png" ) - expected_path = "/results/images/harmbench_test_id.png" + expected_path = str(Path("/results") / "images" / "harmbench_test_id.png") assert result == expected_path assert mock_serializer.value == expected_path diff --git a/tests/unit/datasets/test_image_cache.py b/tests/unit/datasets/test_image_cache.py new file mode 100644 index 0000000000..d7936bd391 --- /dev/null +++ b/tests/unit/datasets/test_image_cache.py @@ -0,0 +1,211 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.datasets.seed_datasets.remote._image_cache import ( + fetch_and_cache_image_async, +) + + +def _make_mock_serializer(*, exists: bool = False) -> MagicMock: + """Build a MagicMock serializer with memory configured.""" + mock_serializer = MagicMock() + mock_memory = MagicMock() + mock_memory.results_path = "/results" + mock_storage_io = AsyncMock() + mock_storage_io.path_exists = AsyncMock(return_value=exists) + mock_memory.results_storage_io = mock_storage_io + mock_serializer._memory = mock_memory + mock_serializer.data_sub_directory = "/seed-prompt-entries/images" + mock_serializer.save_data = AsyncMock() + return mock_serializer + + +async def test_returns_cached_path_when_file_exists_and_skips_network(): + mock_serializer = _make_mock_serializer(exists=True) + + with ( + patch( + "pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory", + return_value=mock_serializer, + ), + patch( + "pyrit.datasets.seed_datasets.remote._image_cache.make_request_and_raise_if_error_async", + new=AsyncMock(), + ) as mock_request, + ): + result = await fetch_and_cache_image_async( + filename="test_image.png", + image_url="https://example.com/image.png", + log_prefix="TestLoader", + ) + + expected_path = str(Path("/results") / "seed-prompt-entries" / "images" / "test_image.png") + assert result == expected_path + assert mock_serializer.value == expected_path + mock_request.assert_not_called() + mock_serializer.save_data.assert_not_called() + + +async def test_downloads_when_cache_miss_and_writes_bytes(): + mock_serializer = _make_mock_serializer(exists=False) + + mock_response = MagicMock() + mock_response.content = b"fake-image-bytes" + + with ( + patch( + "pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory", + return_value=mock_serializer, + ), + patch( + "pyrit.datasets.seed_datasets.remote._image_cache.make_request_and_raise_if_error_async", + new=AsyncMock(return_value=mock_response), + ) as mock_request, + ): + await fetch_and_cache_image_async( + filename="test_image.png", + image_url="https://example.com/image.png", + log_prefix="TestLoader", + ) + + mock_request.assert_called_once() + assert mock_request.call_args.kwargs["endpoint_uri"] == "https://example.com/image.png" + assert mock_request.call_args.kwargs["method"] == "GET" + + mock_serializer.save_data.assert_called_once() + save_kwargs = mock_serializer.save_data.call_args.kwargs + assert save_kwargs["data"] == b"fake-image-bytes" + assert save_kwargs["output_filename"] == "test_image" + + +async def test_image_bytes_path_skips_network_call(): + mock_serializer = _make_mock_serializer(exists=False) + + with ( + patch( + "pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory", + return_value=mock_serializer, + ), + patch( + "pyrit.datasets.seed_datasets.remote._image_cache.make_request_and_raise_if_error_async", + new=AsyncMock(), + ) as mock_request, + ): + await fetch_and_cache_image_async( + filename="bytes_image.png", + image_bytes=b"raw-pil-bytes", + log_prefix="TestLoader", + ) + + mock_request.assert_not_called() + mock_serializer.save_data.assert_called_once() + assert mock_serializer.save_data.call_args.kwargs["data"] == b"raw-pil-bytes" + assert mock_serializer.save_data.call_args.kwargs["output_filename"] == "bytes_image" + + +async def test_raises_value_error_when_neither_url_nor_bytes_provided(): + with pytest.raises(ValueError, match="either image_url or image_bytes"): + await fetch_and_cache_image_async(filename="test.png") + + +async def test_raises_runtime_error_when_memory_not_configured(): + mock_serializer = MagicMock() + mock_memory = MagicMock() + mock_memory.results_path = None + mock_memory.results_storage_io = None + mock_serializer._memory = mock_memory + + with patch( + "pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory", + return_value=mock_serializer, + ): + with pytest.raises(RuntimeError, match="Serializer memory is not properly configured"): + await fetch_and_cache_image_async( + filename="test.png", + image_url="https://example.com/img.png", + ) + + +async def test_propagates_http_failures(): + mock_serializer = _make_mock_serializer(exists=False) + + with ( + patch( + "pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory", + return_value=mock_serializer, + ), + patch( + "pyrit.datasets.seed_datasets.remote._image_cache.make_request_and_raise_if_error_async", + new=AsyncMock(side_effect=Exception("download failed")), + ), + ): + with pytest.raises(Exception, match="download failed"): + await fetch_and_cache_image_async( + filename="test.png", + image_url="https://example.com/img.png", + ) + + mock_serializer.save_data.assert_not_called() + + +async def test_passes_custom_headers_timeout_and_redirects_to_http_client(): + mock_serializer = _make_mock_serializer(exists=False) + mock_response = MagicMock() + mock_response.content = b"bytes" + + custom_headers = {"User-Agent": "test-agent", "Accept": "image/*"} + + with ( + patch( + "pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory", + return_value=mock_serializer, + ), + patch( + "pyrit.datasets.seed_datasets.remote._image_cache.make_request_and_raise_if_error_async", + new=AsyncMock(return_value=mock_response), + ) as mock_request, + ): + await fetch_and_cache_image_async( + filename="custom.png", + image_url="https://example.com/img.png", + request_headers=custom_headers, + request_timeout=5.0, + follow_redirects=True, + ) + + kwargs = mock_request.call_args.kwargs + assert kwargs["headers"] == custom_headers + assert kwargs["timeout"] == 5.0 + assert kwargs["follow_redirects"] is True + + +async def test_path_exists_failure_is_logged_and_treated_as_cache_miss(): + mock_serializer = _make_mock_serializer(exists=False) + mock_serializer._memory.results_storage_io.path_exists = AsyncMock(side_effect=Exception("storage IO unavailable")) + + mock_response = MagicMock() + mock_response.content = b"bytes" + + with ( + patch( + "pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory", + return_value=mock_serializer, + ), + patch( + "pyrit.datasets.seed_datasets.remote._image_cache.make_request_and_raise_if_error_async", + new=AsyncMock(return_value=mock_response), + ) as mock_request, + ): + await fetch_and_cache_image_async( + filename="failing_cache.png", + image_url="https://example.com/img.png", + ) + + # Treated as cache miss: fetch happens and save runs. + mock_request.assert_called_once() + mock_serializer.save_data.assert_called_once() diff --git a/tests/unit/datasets/test_visual_leak_bench_dataset.py b/tests/unit/datasets/test_visual_leak_bench_dataset.py index eb984762bc..a425e6b3d9 100644 --- a/tests/unit/datasets/test_visual_leak_bench_dataset.py +++ b/tests/unit/datasets/test_visual_leak_bench_dataset.py @@ -351,6 +351,7 @@ def test_get_query_prompt_pii(self): async def test_fetch_and_save_image_returns_cached_path(): """Test that _fetch_and_save_image_async returns cached path when image already exists.""" + from pathlib import Path from unittest.mock import AsyncMock, MagicMock mock_serializer = MagicMock() @@ -363,7 +364,7 @@ async def test_fetch_and_save_image_returns_cached_path(): mock_serializer.data_sub_directory = "/images" with patch( - "pyrit.datasets.seed_datasets.remote.visual_leak_bench_dataset.data_serializer_factory", + "pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory", return_value=mock_serializer, ): loader = _VisualLeakBenchDataset() @@ -371,6 +372,6 @@ async def test_fetch_and_save_image_returns_cached_path(): image_url="https://example.com/img.png", example_id="test_001" ) - expected_path = "/results/images/visual_leak_bench_test_001.png" + expected_path = str(Path("/results") / "images" / "visual_leak_bench_test_001.png") assert result == expected_path assert mock_serializer.value == expected_path diff --git a/tests/unit/datasets/test_vlsu_multimodal_dataset.py b/tests/unit/datasets/test_vlsu_multimodal_dataset.py index 9ffa84f810..6fda88e706 100644 --- a/tests/unit/datasets/test_vlsu_multimodal_dataset.py +++ b/tests/unit/datasets/test_vlsu_multimodal_dataset.py @@ -382,7 +382,7 @@ async def test_fetch_and_save_image_raises_when_memory_not_configured(): mock_serializer._memory = mock_memory with patch( - "pyrit.datasets.seed_datasets.remote.vlsu_multimodal_dataset.data_serializer_factory", + "pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory", return_value=mock_serializer, ): loader = _VLSUMultimodalDataset() @@ -392,6 +392,7 @@ async def test_fetch_and_save_image_raises_when_memory_not_configured(): async def test_fetch_and_save_image_returns_cached_path(): """Test that _fetch_and_save_image_async returns cached path when image already exists.""" + from pathlib import Path from unittest.mock import AsyncMock, MagicMock mock_serializer = MagicMock() @@ -404,7 +405,7 @@ async def test_fetch_and_save_image_returns_cached_path(): mock_serializer.data_sub_directory = "/images" with patch( - "pyrit.datasets.seed_datasets.remote.vlsu_multimodal_dataset.data_serializer_factory", + "pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory", return_value=mock_serializer, ): loader = _VLSUMultimodalDataset() @@ -412,6 +413,6 @@ async def test_fetch_and_save_image_returns_cached_path(): group_id="test_group", image_url="https://example.com/img.png" ) - expected_path = "/results/images/ml_vlsu_test_group.png" + expected_path = str(Path("/results") / "images" / "ml_vlsu_test_group.png") assert result == expected_path assert mock_serializer.value == expected_path