Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions pyrit/datasets/seed_datasets/remote/_image_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Shared image-fetch-and-cache helper for multimodal seed-dataset loaders.

Multiple loaders under ``pyrit.datasets.seed_datasets.remote`` need to
download an image from a URL (or write bytes already in hand), persist it
under the ``seed-prompt-entries`` cache, and return the local path while
skipping the network call on a cache hit. This module centralizes that
logic so individual loaders only need to construct the appropriate filename.
"""

import logging
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional

from pyrit.common.net_utility import make_request_and_raise_if_error_async
from pyrit.models import data_serializer_factory

logger = logging.getLogger(__name__)


async def fetch_and_cache_image_async(
*,
filename: str,
image_url: Optional[str] = None,
image_bytes: Optional[bytes] = None,
log_prefix: str = "image-cache",
request_headers: Optional[Mapping[str, str]] = None,
request_timeout: Optional[float] = None,
follow_redirects: bool = False,
) -> str:
"""
Fetch (or accept) image bytes and cache them under ``seed-prompt-entries``.

The cached path is constructed deterministically from the configured
``results_path`` plus the serializer's ``data_sub_directory`` plus
``filename``, normalized through ``pathlib.Path`` so the same on-disk
location is produced on Windows and POSIX. If a file already exists at
that path, the path is returned immediately without performing the network
fetch or rewriting bytes.

Exactly one of ``image_url`` or ``image_bytes`` must be provided. When
``image_bytes`` is supplied, the network is never contacted regardless of
whether ``image_url`` is also provided.

Args:
filename (str): On-disk filename for the cached image, including
extension (e.g. ``"harmbench_<id>.png"``). The caller controls this
so per-loader naming and existing cache files stay intact.
image_url (str | None): URL to fetch the image from. Required when
``image_bytes`` is not provided.
image_bytes (bytes | None): Raw image bytes (e.g. extracted from a PIL
image). When provided, no network request is made.
log_prefix (str): Short tag prepended to warning log messages
(e.g. ``"HarmBench-Multimodal"``) so existing log output stays
recognizable per-loader.
request_headers (Mapping[str, str] | None): Optional HTTP headers to
send with the request (only used when fetching from ``image_url``).
request_timeout (float | None): Optional request timeout in seconds.
follow_redirects (bool): Whether the HTTP client should follow
redirects. Defaults to ``False``.

Returns:
str: Local path to the cached image.

Raises:
ValueError: If neither ``image_url`` nor ``image_bytes`` is provided.
RuntimeError: If the serializer's underlying memory is not properly
configured (``results_path`` or ``results_storage_io`` missing).
Exception: Any error raised by the underlying HTTP fetch or by
``serializer.save_data`` is propagated so callers can catch and
skip individual rows.
"""
if image_bytes is None and not image_url:
raise ValueError("fetch_and_cache_image_async requires either image_url or image_bytes")

extension = Path(filename).suffix.lstrip(".") or None

serializer = data_serializer_factory(
category="seed-prompt-entries",
data_type="image_path",
extension=extension,
)

results_path = serializer._memory.results_path if serializer._memory is not None else None
results_storage_io = serializer._memory.results_storage_io if serializer._memory is not None else None
if not results_path or results_storage_io is None:
raise RuntimeError(
f"[{log_prefix}] Serializer memory is not properly configured: "
"results_path and results_storage_io must be set."
)

sub_directory = serializer.data_sub_directory.lstrip("/\\")
serializer.value = str(Path(results_path) / sub_directory / filename)

try:
if await results_storage_io.path_exists(serializer.value):
return serializer.value
except Exception as e:
logger.warning(f"[{log_prefix}] Failed to check if cached image {filename} exists: {e}")

if image_bytes is None:
# image_url is guaranteed non-empty by the validation above when image_bytes is None.
assert image_url is not None
httpx_kwargs: dict[str, Any] = {"follow_redirects": follow_redirects}
if request_timeout is not None:
httpx_kwargs["timeout"] = request_timeout

headers_dict = dict(request_headers) if request_headers is not None else None
response = await make_request_and_raise_if_error_async(
endpoint_uri=image_url,
method="GET",
headers=headers_dict,
**httpx_kwargs,
)
image_bytes = response.content

await serializer.save_data(data=image_bytes, output_filename=Path(filename).stem)

return str(serializer.value)
29 changes: 9 additions & 20 deletions pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from dataclasses import dataclass
from typing import Literal

from pyrit.common.net_utility import make_request_and_raise_if_error_async
from pyrit.common.path import DB_DATA_PATH
from pyrit.datasets.seed_datasets.remote._image_cache import (
fetch_and_cache_image_async,
)
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
_RemoteDatasetLoader,
)
from pyrit.models import Seed, SeedDataset, SeedObjective, SeedPrompt, data_serializer_factory
from pyrit.models import Seed, SeedDataset, SeedObjective, SeedPrompt

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -346,20 +347,8 @@ async def _fetch_template_async(self, template_name: str) -> str:
f"Invalid template name '{template_name}'. Must be one of: {', '.join(self.TEMPLATE_NAMES)}"
)

filename = f"comic_jailbreak_{template_name}.png"
serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png")

results_path = serializer._memory.results_path or str(DB_DATA_PATH)
storage_io = serializer._memory.results_storage_io
serializer.value = str(results_path + serializer.data_sub_directory + f"/{filename}")
try:
if storage_io and await storage_io.path_exists(serializer.value):
return serializer.value
except Exception as e:
logger.warning(f"[ComicJailbreak] Failed to check cache for template {template_name}: {e}")

image_url = f"{self.TEMPLATE_BASE_URL}{template_name}.png"
response = await make_request_and_raise_if_error_async(endpoint_uri=image_url, method="GET")
await serializer.save_data(data=response.content, output_filename=filename.replace(".png", ""))

return str(serializer.value)
return await fetch_and_cache_image_async(
filename=f"comic_jailbreak_{template_name}.png",
image_url=f"{self.TEMPLATE_BASE_URL}{template_name}.png",
log_prefix="ComicJailbreak",
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from enum import Enum
from typing import Literal, Optional

from pyrit.common.net_utility import make_request_and_raise_if_error_async
from pyrit.datasets.seed_datasets.remote._image_cache import (
fetch_and_cache_image_async,
)
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
_RemoteDatasetLoader,
)
from pyrit.models import SeedDataset, SeedPrompt, data_serializer_factory
from pyrit.models import SeedDataset, SeedPrompt

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -225,25 +227,8 @@ async def _fetch_and_save_image_async(self, image_url: str, behavior_id: str) ->
Raises:
RuntimeError: If the serializer memory is not properly configured.
"""
filename = f"harmbench_{behavior_id}.png"
serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png")

# Return existing path if image already exists for this BehaviorID
results_path = serializer._memory.results_path
results_storage_io = serializer._memory.results_storage_io
if not results_path or results_storage_io is None:
raise RuntimeError(
"[HarmBench-Multimodal] Serializer memory is not properly configured: "
"results_path and results_storage_io must be set."
)
serializer.value = str(results_path + serializer.data_sub_directory + f"/{filename}")
try:
if await results_storage_io.path_exists(serializer.value):
return serializer.value
except Exception as e:
logger.warning(f"[HarmBench-Multimodal] Failed to check if image for {behavior_id} exists in cache: {e}")

response = await make_request_and_raise_if_error_async(endpoint_uri=image_url, method="GET")
await serializer.save_data(data=response.content, output_filename=filename.replace(".png", ""))

return str(serializer.value)
return await fetch_and_cache_image_async(
filename=f"harmbench_{behavior_id}.png",
image_url=image_url,
log_prefix="HarmBench-Multimodal",
)
31 changes: 9 additions & 22 deletions pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from enum import Enum
from typing import Literal, Optional

from pyrit.common.net_utility import make_request_and_raise_if_error_async
from pyrit.datasets.seed_datasets.remote._image_cache import (
fetch_and_cache_image_async,
)
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
_RemoteDatasetLoader,
)
from pyrit.models import SeedDataset, SeedPrompt, data_serializer_factory
from pyrit.models import SeedDataset, SeedPrompt

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -321,23 +323,8 @@ async def _fetch_and_save_image_async(self, image_url: str, example_id: str) ->
Returns:
str: Local path to the saved image.
"""
filename = f"visual_leak_bench_{example_id}.png"
serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png")

# Return existing path if image already exists
results_path = (serializer._memory.results_path if serializer._memory is not None else None) or ""
serializer.value = str(results_path + serializer.data_sub_directory + f"/{filename}")
try:
if (
serializer._memory is not None
and serializer._memory.results_storage_io is not None
and await serializer._memory.results_storage_io.path_exists(serializer.value)
):
return serializer.value
except Exception as e:
logger.warning(f"[VisualLeakBench] Failed to check if image {example_id} exists in cache: {e}")

response = await make_request_and_raise_if_error_async(endpoint_uri=image_url, method="GET")
await serializer.save_data(data=response.content, output_filename=filename.replace(".png", ""))

return str(serializer.value)
return await fetch_and_cache_image_async(
filename=f"visual_leak_bench_{example_id}.png",
image_url=image_url,
log_prefix="VisualLeakBench",
)
37 changes: 11 additions & 26 deletions pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from enum import Enum
from typing import Literal, Optional

from pyrit.common.net_utility import make_request_and_raise_if_error_async
from pyrit.datasets.seed_datasets.remote._image_cache import (
fetch_and_cache_image_async,
)
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
_RemoteDatasetLoader,
)
from pyrit.models import SeedDataset, SeedPrompt, data_serializer_factory
from pyrit.models import SeedDataset, SeedPrompt

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -252,22 +254,7 @@ async def _fetch_and_save_image_async(self, image_url: str, group_id: str) -> st
Raises:
RuntimeError: If the serializer memory is not properly configured.
"""
filename = f"ml_vlsu_{group_id}.png"
serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png")

# Return existing path if image already exists
results_path = serializer._memory.results_path
results_storage_io = serializer._memory.results_storage_io
if not results_path or results_storage_io is None:
raise RuntimeError("[ML-VLSU] Serializer memory is not properly configured.")
serializer.value = str(results_path + serializer.data_sub_directory + f"/{filename}")
try:
if await results_storage_io.path_exists(serializer.value):
return serializer.value
except Exception as e:
logger.warning(f"[ML-VLSU] Failed to check if image for {group_id} exists in cache: {e}")

# Add browser-like headers for better success rate
# Browser-like headers improve fetch success rate for the ML-VLSU hosting setup.
headers = {
"User-Agent": (
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)"
Expand All @@ -282,13 +269,11 @@ async def _fetch_and_save_image_async(self, image_url: str, group_id: str) -> st
"Upgrade-Insecure-Requests": "1",
}

response = await make_request_and_raise_if_error_async(
endpoint_uri=image_url,
method="GET",
headers=headers,
timeout=2.0,
return await fetch_and_cache_image_async(
filename=f"ml_vlsu_{group_id}.png",
image_url=image_url,
log_prefix="ML-VLSU",
request_headers=headers,
request_timeout=2.0,
follow_redirects=True,
)
await serializer.save_data(data=response.content, output_filename=filename.replace(".png", ""))

return str(serializer.value)
7 changes: 4 additions & 3 deletions tests/unit/datasets/test_harmbench_multimodal_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ async def test_fetch_and_save_image_raises_when_memory_not_configured():
mock_serializer._memory = mock_memory

with patch(
"pyrit.datasets.seed_datasets.remote.harmbench_multimodal_dataset.data_serializer_factory",
"pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory",
return_value=mock_serializer,
):
loader = _HarmBenchMultimodalDataset()
Expand All @@ -163,6 +163,7 @@ async def test_fetch_and_save_image_raises_when_memory_not_configured():

async def test_fetch_and_save_image_returns_cached_path():
"""Test that _fetch_and_save_image_async returns cached path when image already exists."""
from pathlib import Path
from unittest.mock import MagicMock

mock_serializer = MagicMock()
Expand All @@ -175,14 +176,14 @@ async def test_fetch_and_save_image_returns_cached_path():
mock_serializer.data_sub_directory = "/images"

with patch(
"pyrit.datasets.seed_datasets.remote.harmbench_multimodal_dataset.data_serializer_factory",
"pyrit.datasets.seed_datasets.remote._image_cache.data_serializer_factory",
return_value=mock_serializer,
):
loader = _HarmBenchMultimodalDataset()
result = await loader._fetch_and_save_image_async(
behavior_id="test_id", image_url="https://example.com/img.png"
)

expected_path = "/results/images/harmbench_test_id.png"
expected_path = str(Path("/results") / "images" / "harmbench_test_id.png")
assert result == expected_path
assert mock_serializer.value == expected_path
Loading
Loading