Skip to content
30 changes: 16 additions & 14 deletions pyrit/models/data_type_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import asyncio
import base64
import hashlib
import tempfile
import time
import wave
from mimetypes import guess_type
Expand Down Expand Up @@ -211,23 +212,24 @@ async def save_formatted_audio_async(

# save audio file locally first if in AzureStorageBlob so we can use wave.open to set audio parameters
if self._is_azure_storage_url(str(file_path)):
local_temp_path = Path(DB_DATA_PATH, "temp_audio.wav")
await asyncio.to_thread(
_write_wav_sync,
str(local_temp_path),
num_channels=num_channels,
sample_width=sample_width,
sample_rate=sample_rate,
data=data,
)

async with aiofiles.open(local_temp_path, "rb") as f:
audio_data = await f.read()
with tempfile.NamedTemporaryFile(suffix=".wav", dir=DB_DATA_PATH, delete=False) as tmp:
local_temp_path = Path(tmp.name)
try:
await asyncio.to_thread(
_write_wav_sync,
str(local_temp_path),
num_channels=num_channels,
sample_width=sample_width,
sample_rate=sample_rate,
data=data,
)
async with aiofiles.open(local_temp_path, "rb") as f:
audio_data = await f.read()
if self._memory.results_storage_io is None:
raise RuntimeError("self._memory.results_storage_io is not initialized")
await self._memory.results_storage_io.write_file_async(file_path, audio_data)
local_temp_path.unlink()

finally:
local_temp_path.unlink(missing_ok=True)
Comment thread
romanlutz marked this conversation as resolved.
# If local, we can just save straight to disk and do not need to delete temp file after
else:
await asyncio.to_thread(
Expand Down
30 changes: 28 additions & 2 deletions tests/unit/models/test_data_type_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ async def _capture_write(file_path, data):
pcm = b"\xaa\xbb\xcc\xdd\xee\xff\x00\x11"
with patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory):
with patch.object(serializer, "get_data_filename_async", new_callable=AsyncMock, return_value=azure_url):
# Redirect DB_DATA_PATH so the temp_audio.wav write lands in tmp_path
# Redirect DB_DATA_PATH so the temporary .wav file is written under tmp_path
with patch.object(common_path, "DB_DATA_PATH", str(tmp_path)):
from pyrit.models import data_type_serializer as dts_module

Expand Down Expand Up @@ -616,7 +616,33 @@ async def test_save_formatted_audio_azure_storage_unlinks_local_temp(tmp_path):
await serializer.save_formatted_audio_async(data=b"\x00\x01\x02\x03")

# The local temp file written via wave.open should have been unlinked after upload.
assert not (tmp_path / "temp_audio.wav").exists()
assert list(tmp_path.glob("*.wav")) == []
mock_storage_io.write_file_async.assert_awaited_once()
assert mock_storage_io.write_file_async.call_args[0][0] == azure_url
assert serializer.value == azure_url


@pytest.mark.asyncio
Comment thread
romanlutz marked this conversation as resolved.
async def test_save_formatted_audio_async_cleans_up_temp_file_on_azure_upload_failure(tmp_path):
    """Regression test: a failed Azure upload must not leave the local temp .wav behind.

    The storage backend's ``write_file_async`` is mocked to raise, the data
    directory is redirected to ``tmp_path``, and the test then checks that the
    error propagates and that no new ``*.wav`` file remains in ``tmp_path``.
    """
    serializer = data_serializer_factory(category="prompt-memory-entries", data_type="audio_path")

    # Storage backend whose upload always fails.
    failing_storage = AsyncMock()
    failing_storage.write_file_async.side_effect = RuntimeError("Azure upload failed")
    memory_stub = MagicMock(results_storage_io=failing_storage)

    blob_url = "https://account.blob.core.windows.net/container/audio/test.wav"

    # Snapshot the wav files present before the call so only new ones count as leaks.
    wav_files_before = set(tmp_path.glob("*.wav"))

    # Build the patchers up front; they are entered in the order listed in the `with`.
    memory_patch = patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=memory_stub)
    filename_patch = patch.object(serializer, "get_data_filename_async", new_callable=AsyncMock, return_value=blob_url)
    data_path_patch = patch("pyrit.models.data_type_serializer.DB_DATA_PATH", tmp_path)

    with memory_patch, filename_patch, data_path_patch:
        with pytest.raises(RuntimeError, match="Azure upload failed"):
            await serializer.save_formatted_audio_async(data=b"\x00\x01\x02")

    # Anything that appeared during the call and is still on disk is a leak.
    leaked_files = set(tmp_path.glob("*.wav")) - wav_files_before
    assert not leaked_files, f"Temp files leaked: {leaked_files}"
Loading