diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py
index 578efca5cc..a5de83dbba 100644
--- a/pyrit/models/data_type_serializer.py
+++ b/pyrit/models/data_type_serializer.py
@@ -7,6 +7,7 @@
 import base64
 import hashlib
 import os
+import tempfile
 import time
 import wave
 from mimetypes import guess_type
@@ -194,19 +195,24 @@ async def save_formatted_audio(
 
         # save audio file locally first if in AzureStorageBlob so we can use wave.open to set audio parameters
         if self._is_azure_storage_url(str(file_path)):
-            local_temp_path = Path(DB_DATA_PATH, "temp_audio.wav")
-            with wave.open(str(local_temp_path), "wb") as wav_file:
-                wav_file.setnchannels(num_channels)
-                wav_file.setsampwidth(sample_width)
-                wav_file.setframerate(sample_rate)
-                wav_file.writeframes(data)
-
-            async with aiofiles.open(local_temp_path, "rb") as f:
-                audio_data = await f.read()
+            with tempfile.NamedTemporaryFile(
+                suffix=".wav", dir=DB_DATA_PATH, delete=False
+            ) as tmp:
+                local_temp_path = Path(tmp.name)
+
+            try:
+                with wave.open(str(local_temp_path), "wb") as wav_file:
+                    wav_file.setnchannels(num_channels)
+                    wav_file.setsampwidth(sample_width)
+                    wav_file.setframerate(sample_rate)
+                    wav_file.writeframes(data)
+                async with aiofiles.open(local_temp_path, "rb") as f:
+                    audio_data = await f.read()
                 if self._memory.results_storage_io is None:
                     raise RuntimeError("self._memory.results_storage_io is not initialized")
                 await self._memory.results_storage_io.write_file(file_path, audio_data)
-            os.remove(local_temp_path)
+            finally:
+                local_temp_path.unlink(missing_ok=True)
 
         # If local, we can just save straight to disk and do not need to delete temp file after
         else:
diff --git a/tests/unit/models/test_data_type_serializer.py b/tests/unit/models/test_data_type_serializer.py
--- a/tests/unit/models/test_data_type_serializer.py
+++ b/tests/unit/models/test_data_type_serializer.py
@@ -426,3 +426,31 @@ async def test_get_data_filename_uses_db_data_path_when_results_path_falsy():
     result_str = str(result).replace("\\", "/")
     assert "/fallback/db_data" in result_str
     assert result_str.endswith(".png")
+
+
+@pytest.mark.asyncio
+async def test_save_formatted_audio_cleans_up_temp_file_on_azure_upload_failure(patch_central_database, tmp_path):
+    """Regression test: the local temp file must be deleted even when the Azure upload fails."""
+    serializer = data_serializer_factory(category="prompt-memory-entries", data_type="audio_path")
+
+    mock_storage_io = AsyncMock()
+    mock_storage_io.write_file.side_effect = RuntimeError("Azure upload failed")
+    mock_memory = MagicMock()
+    mock_memory.results_storage_io = mock_storage_io
+
+    azure_url = "https://account.blob.core.windows.net/container/audio/test.wav"
+
+    # Point the serializer's temp directory at pytest's per-test folder so the check is hermetic:
+    # it never touches the real DB_DATA_PATH and cannot race with other tests writing .wav files.
+    patch_data_dir = patch("pyrit.models.data_type_serializer.DB_DATA_PATH", tmp_path)
+    patch_memory = patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory)
+    patch_filename = patch.object(serializer, "get_data_filename", new_callable=AsyncMock, return_value=azure_url)
+
+    with patch_data_dir, patch_memory, patch_filename:
+        with pytest.raises(RuntimeError, match="Azure upload failed"):
+            await serializer.save_formatted_audio(data=b"\x00\x01\x02")
+
+    # Guard against a vacuous pass: the failure must come from the upload step itself.
+    mock_storage_io.write_file.assert_awaited_once()
+    leaked_files = list(tmp_path.iterdir())
+    assert not leaked_files, f"Temp files leaked: {leaked_files}"