Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions pyrit/models/data_type_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import base64
import hashlib
import os
import tempfile
import time
import wave
from mimetypes import guess_type
Expand Down Expand Up @@ -194,19 +195,24 @@ async def save_formatted_audio(

# save audio file locally first if in AzureStorageBlob so we can use wave.open to set audio parameters
if self._is_azure_storage_url(str(file_path)):
local_temp_path = Path(DB_DATA_PATH, "temp_audio.wav")
with wave.open(str(local_temp_path), "wb") as wav_file:
wav_file.setnchannels(num_channels)
wav_file.setsampwidth(sample_width)
wav_file.setframerate(sample_rate)
wav_file.writeframes(data)

async with aiofiles.open(local_temp_path, "rb") as f:
audio_data = await f.read()
with tempfile.NamedTemporaryFile(
suffix=".wav", dir=DB_DATA_PATH, delete=False
) as tmp:
local_temp_path = Path(tmp.name)

try:
with wave.open(str(local_temp_path), "wb") as wav_file:
wav_file.setnchannels(num_channels)
wav_file.setsampwidth(sample_width)
wav_file.setframerate(sample_rate)
wav_file.writeframes(data)
async with aiofiles.open(local_temp_path, "rb") as f:
audio_data = await f.read()
if self._memory.results_storage_io is None:
raise RuntimeError("self._memory.results_storage_io is not initialized")
await self._memory.results_storage_io.write_file(file_path, audio_data)
os.remove(local_temp_path)
finally:
local_temp_path.unlink(missing_ok=True)

# If local, we can just save straight to disk and do not need to delete temp file after
else:
Expand Down
28 changes: 28 additions & 0 deletions tests/unit/models/test_data_type_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import pytest
from PIL import Image
import glob
from pyrit.common.path import DB_DATA_PATH

from pyrit.models import (
AllowedCategories,
Expand Down Expand Up @@ -426,3 +428,29 @@ async def test_get_data_filename_uses_db_data_path_when_results_path_falsy():
result_str = str(result).replace("\\", "/")
assert "/fallback/db_data" in result_str
assert result_str.endswith(".png")


@pytest.mark.asyncio
async def test_save_formatted_audio_cleans_up_temp_file_on_azure_upload_failure(patch_central_database):
"""Regression test: temp file must be deleted even when Azure upload fails."""
serializer = data_serializer_factory(category="prompt-memory-entries", data_type="audio_path")

mock_memory = MagicMock()
mock_storage_io = AsyncMock()
mock_storage_io.write_file.side_effect = RuntimeError("Azure upload failed")
mock_memory.results_storage_io = mock_storage_io

azure_url = "https://account.blob.core.windows.net/container/audio/test.wav"

# Record existing wav files BEFORE test runs
existing_wav_files = set(glob.glob(str(DB_DATA_PATH / "*.wav")))

with patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory):
with patch.object(serializer, "get_data_filename", new_callable=AsyncMock, return_value=azure_url):
with pytest.raises(RuntimeError, match="Azure upload failed"):
await serializer.save_formatted_audio(data=b"\x00\x01\x02")

# Check no NEW wav files leaked after test
leaked_files = set(glob.glob(str(DB_DATA_PATH / "*.wav"))) - existing_wav_files
assert len(leaked_files) == 0, f"Temp files leaked: {leaked_files}"