diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 9898fe382d..2e9fc46ee8 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -4,7 +4,7 @@ import logging import uuid from dataclasses import dataclass, field -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional from pyrit.memory import CentralMemory from pyrit.models import ChatMessageRole, Message, MessagePiece, Score @@ -12,8 +12,8 @@ PromptConverterConfiguration, ) from pyrit.prompt_normalizer.prompt_normalizer import PromptNormalizer +from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget -from pyrit.prompt_target.common.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -213,8 +213,8 @@ def set_system_prompt( async def update_conversation_state_async( self, *, + target: PromptTarget, conversation_id: str, - target: Optional[Union[PromptTarget, PromptChatTarget]] = None, prepended_conversation: List[Message], request_converters: Optional[List[PromptConverterConfiguration]] = None, response_converters: Optional[List[PromptConverterConfiguration]] = None, @@ -237,9 +237,9 @@ async def update_conversation_state_async( and extracts per-session counters such as the current turn index. Args: + target (PromptTarget): The target for which the conversation is being prepared. + Used to validate that prepended_conversation is compatible with the target type. conversation_id (str): Unique identifier for the conversation to update or create. - target (Optional[Union[PromptTarget, PromptChatTarget]]): The target to set system prompts on (if - applicable). prepended_conversation (List[Message]): List of messages to prepend to the conversation history. request_converters (Optional[List[PromptConverterConfiguration]]): @@ -254,12 +254,21 @@ async def update_conversation_state_async( messages, including turn count and last user message. Raises: - ValueError: If `conversation_id` is empty or if the last message in a multi-turn - context is a user message (which should not be prepended). + ValueError: If `conversation_id` is empty, if the last message in a multi-turn + context is a user message (which should not be prepended), or if + prepended_conversation is provided with a non-PromptChatTarget target. """ if not conversation_id: raise ValueError("conversation_id cannot be empty") + # Validate prepended_conversation compatibility with target type + # Non-chat targets do not read conversation history from memory + if prepended_conversation and not isinstance(target, PromptChatTarget): + raise ValueError( + "prepended_conversation requires target to be a PromptChatTarget. " + "Non-chat targets do not support explicit conversation history management." + ) + # Initialize conversation state state = ConversationState() logger.debug(f"Preparing conversation with ID: {conversation_id}") @@ -304,7 +313,6 @@ async def update_conversation_state_async( request=request, conversation_id=conversation_id, conversation_state=state, - target=target, max_turns=max_turns, ) @@ -364,7 +372,6 @@ async def _process_prepended_message_async( request: Message, conversation_id: str, conversation_state: ConversationState, - target: Optional[Union[PromptTarget, PromptChatTarget]] = None, max_turns: Optional[int] = None, ) -> None: """ @@ -376,39 +383,27 @@ async def _process_prepended_message_async( request (Message): The request containing pieces to process. conversation_id (str): The ID of the conversation to update. conversation_state (ConversationState): The current state of the conversation. - target (Optional[Union[PromptTarget, PromptChatTarget]]): The target to set system prompts on (if - applicable). max_turns (Optional[int]): Maximum allowed turns for the conversation. - - Raises: - ValueError: If the request is invalid or if a system prompt is provided but target doesn't support it. """ # Validate the request before processing if not request or not request.message_pieces: return # Set the conversation ID and attack ID for each piece in the request - save_to_memory = True for piece in request.message_pieces: piece.conversation_id = conversation_id piece.attack_identifier = self._attack_identifier piece.id = uuid.uuid4() - # Process the piece based on its role + # Process the piece based on its role (validates turn count for multi-turn) self._process_piece( piece=piece, conversation_state=conversation_state, max_turns=max_turns, - target=target, ) - if ConversationManager._should_exclude_piece_from_memory(piece=piece, max_turns=max_turns): - # it is excluded, so we don't want to save it to memory - save_to_memory = False - - # Add the request to memory if it was not a system piece - if save_to_memory: - self._memory.add_message_to_memory(request=request) + # Add the request to memory + self._memory.add_message_to_memory(request=request) def _process_piece( self, @@ -416,47 +411,28 @@ def _process_piece( piece: MessagePiece, conversation_state: ConversationState, max_turns: Optional[int] = None, - target: Optional[Union[PromptTarget, PromptChatTarget]] = None, ) -> None: """ Process a message piece based on its role and update conversation state. + For multi-turn conversations, this validates that the turn count doesn't exceed + max_turns. Only assistant messages count as turns. + Args: piece (MessagePiece): The piece to process. conversation_state (ConversationState): The current state of the conversation. max_turns (Optional[int]): Maximum allowed turns (for validation). - target (Optional[Union[PromptTarget, PromptChatTarget]]): The target to set system prompts on. Raises: ValueError: If max_turns would be exceeded by this piece. - ValueError: If a system prompt is provided but target doesn't support it. """ - # Check if multiturn is_multi_turn = max_turns is not None - # Handle system prompts (both single-turn and multi-turn) - if piece.role == "system": - if target is None: - raise ValueError("Target must be provided to handle system prompts") - - if not isinstance(target, PromptChatTarget): - raise ValueError("Target must be a PromptChatTarget to set system prompts") - - # Set system prompt and exclude from memory - self.set_system_prompt( - target=target, - conversation_id=piece.conversation_id, - system_prompt=piece.converted_value, - labels=piece.labels, - ) - - # Handle assistant messages (count turns for multi-turn only) - elif piece.role == "assistant" and is_multi_turn: - # Update turn count + # Only assistant messages count as turns + if piece.role == "assistant" and is_multi_turn: conversation_state.turn_count += 1 - # Validate against max_turns - if max_turns and conversation_state.turn_count > max_turns: + if conversation_state.turn_count > max_turns: raise ValueError( f"The number of turns in the prepended conversation ({conversation_state.turn_count-1}) is equal to" + f" or exceeds the maximum number of turns ({max_turns}), which means the" @@ -464,12 +440,6 @@ def _process_piece( + " the prepended conversation or increase the maximum number of turns and try again." ) - @staticmethod - def _should_exclude_piece_from_memory(*, piece: MessagePiece, max_turns: Optional[int] = None) -> bool: - # System pieces should always be excluded from memory because set_system_prompt function - # is called on the target, which internally adds them to memory - return piece.role == "system" - async def _populate_conversation_state_async( self, *, diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 967406e30a..4759b7fb49 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -279,8 +279,8 @@ async def initialize_with_prepended_conversation_async( # Add to objective target conversation (handles system prompts and memory) await conversation_manager.update_conversation_state_async( - conversation_id=self.objective_target_conversation_id, target=self._objective_target, + conversation_id=self.objective_target_conversation_id, prepended_conversation=prepended_conversation, ) diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index a9fb1fb6fa..73e32018e8 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -266,22 +266,24 @@ class TestConversationStateUpdate: @pytest.mark.asyncio async def test_update_conversation_state_with_empty_conversation_id_raises_error( - self, attack_identifier: dict[str, str] + self, attack_identifier: dict[str, str], mock_chat_target: MagicMock ): manager = ConversationManager(attack_identifier=attack_identifier) with pytest.raises(ValueError, match="conversation_id cannot be empty"): - await manager.update_conversation_state_async(conversation_id="", prepended_conversation=[]) + await manager.update_conversation_state_async( + target=mock_chat_target, conversation_id="", prepended_conversation=[] + ) @pytest.mark.asyncio async def test_update_conversation_state_with_empty_history_returns_default_state( - self, attack_identifier: dict[str, str] + self, attack_identifier: dict[str, str], mock_chat_target: MagicMock ): manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) state = await manager.update_conversation_state_async( - conversation_id=conversation_id, prepended_conversation=[] + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=[] ) assert isinstance(state, ConversationState) @@ -291,14 +293,14 @@ async def test_update_conversation_state_with_empty_history_returns_default_stat @pytest.mark.asyncio async def test_update_conversation_state_single_turn_mode( - self, attack_identifier: dict[str, str], sample_conversation: list[Message] + self, attack_identifier: dict[str, str], mock_chat_target: MagicMock, sample_conversation: list[Message] ): manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) # Single-turn mode (no max_turns) state = await manager.update_conversation_state_async( - conversation_id=conversation_id, prepended_conversation=sample_conversation + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=sample_conversation ) # Verify all messages were added to memory @@ -311,7 +313,7 @@ async def test_update_conversation_state_single_turn_mode( @pytest.mark.asyncio async def test_update_conversation_state_multi_turn_mode_excludes_last_user_message( - self, attack_identifier: dict[str, str], sample_user_piece: MessagePiece + self, attack_identifier: dict[str, str], mock_chat_target: MagicMock, sample_user_piece: MessagePiece ): manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) @@ -320,7 +322,7 @@ async def test_update_conversation_state_multi_turn_mode_excludes_last_user_mess conversation = [Message(message_pieces=[sample_user_piece])] state = await manager.update_conversation_state_async( - conversation_id=conversation_id, prepended_conversation=conversation, max_turns=5 + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=conversation, max_turns=5 ) # Last user message should be excluded from memory in multi-turn mode @@ -334,6 +336,7 @@ async def test_update_conversation_state_multi_turn_mode_excludes_last_user_mess async def test_update_conversation_state_with_role_specific_converters( self, attack_identifier: dict[str, str], + mock_chat_target: MagicMock, mock_prompt_normalizer: MagicMock, sample_conversation: list[Message], ): @@ -345,6 +348,7 @@ async def test_update_conversation_state_with_role_specific_converters( response_converter_config = [PromptConverterConfiguration(converters=[])] await manager.update_conversation_state_async( + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=sample_conversation, request_converters=request_converter_config, @@ -389,9 +393,9 @@ async def test_update_conversation_state_system_messages_no_converters( response_converter_config = [PromptConverterConfiguration(converters=[])] await manager.update_conversation_state_async( + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=conversation, - target=mock_chat_target, request_converters=request_converter_config, response_converters=response_converter_config, max_turns=5, # Multi-turn mode to trigger system prompt handling @@ -404,6 +408,7 @@ async def test_update_conversation_state_system_messages_no_converters( async def test_update_conversation_state_processes_system_prompts_multi_turn( self, attack_identifier: dict[str, str], mock_chat_target: MagicMock, sample_system_piece: MessagePiece ): + """Test that system messages are added to memory like any other message.""" manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) @@ -411,24 +416,22 @@ async def test_update_conversation_state_processes_system_prompts_multi_turn( conversation = [Message(message_pieces=[sample_system_piece])] await manager.update_conversation_state_async( + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=conversation, - target=mock_chat_target, - max_turns=5, # Multi-turn mode to trigger system prompt handling + max_turns=5, # Multi-turn mode ) - # System prompt should be set on target - mock_chat_target.set_system_prompt.assert_called_once() - - # System messages should not be added to memory in multi-turn mode + # System messages SHOULD be added to memory stored_conversation = manager.get_conversation(conversation_id) - assert len(stored_conversation) == 0 + assert len(stored_conversation) == 1 + assert stored_conversation[0].get_piece().role == "system" @pytest.mark.asyncio async def test_update_conversation_state_processes_system_prompts_single_turn( self, attack_identifier: dict[str, str], mock_chat_target: MagicMock, sample_system_piece: MessagePiece ): - """Test that system messages in single-turn mode are NOT added to memory""" + """Test that system messages in single-turn mode are added to memory.""" manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) @@ -436,14 +439,16 @@ async def test_update_conversation_state_processes_system_prompts_single_turn( conversation = [Message(message_pieces=[sample_system_piece])] await manager.update_conversation_state_async( + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=conversation, - target=mock_chat_target, # No max_turns = single-turn mode ) - # System prompt should be set on target - mock_chat_target.set_system_prompt.assert_called_once() + # System messages SHOULD be added to memory + stored_conversation = manager.get_conversation(conversation_id) + assert len(stored_conversation) == 1 + assert stored_conversation[0].get_piece().role == "system" @pytest.mark.asyncio async def test_update_conversation_state_single_turn_behavior_matches_legacy( @@ -454,7 +459,7 @@ async def test_update_conversation_state_single_turn_behavior_matches_legacy( sample_assistant_piece: MessagePiece, sample_system_piece: MessagePiece, ): - """Test that single-turn behavior correctly excludes system messages from memory""" + """Test that all message types are correctly saved to memory.""" manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) @@ -466,21 +471,21 @@ async def test_update_conversation_state_single_turn_behavior_matches_legacy( ] # Store original IDs to verify they get updated - # Since we are mocking the target, the system piece won't be stored, so we only check user and assistant original_user_id = sample_user_piece.id original_assistant_id = sample_assistant_piece.id await manager.update_conversation_state_async( + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=conversation, - target=mock_chat_target, # No max_turns = single-turn mode ) + # All messages including system should be stored stored_conversation = manager.get_conversation(conversation_id) - assert len(stored_conversation) == 2 + assert len(stored_conversation) == 3 - # Verify that user and assistant pieces have the correct conversation_id and attack_identifier + # Verify that all pieces have the correct conversation_id and attack_identifier for stored_response in stored_conversation: for piece in stored_response.message_pieces: assert piece.conversation_id == conversation_id @@ -488,49 +493,71 @@ async def test_update_conversation_state_single_turn_behavior_matches_legacy( # Verify that IDs were regenerated assert piece.id != original_user_id assert piece.id != original_assistant_id - # System piece should not be in memory, since we mocked the target - # Verify roles are preserved and in order (excluding system) - assert stored_conversation[0].get_piece().role == "user" - assert stored_conversation[1].get_piece().role == "assistant" - - # System prompt should still be set on target even in single-turn mode - mock_chat_target.set_system_prompt.assert_called_once() + # Verify roles are preserved and in order + assert stored_conversation[0].get_piece().role == "system" + assert stored_conversation[1].get_piece().role == "user" + assert stored_conversation[2].get_piece().role == "assistant" @pytest.mark.asyncio - async def test_update_conversation_state_system_prompt_without_target_raises_error( - self, attack_identifier: dict[str, str], sample_system_piece: MessagePiece + async def test_update_conversation_state_system_message_without_target_succeeds( + self, attack_identifier: dict[str, str], mock_chat_target: MagicMock, sample_system_piece: MessagePiece ): - """Test that providing system prompts without a target raises an error""" + """Test that system messages work fine with a chat target - they're saved to memory.""" manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) # Create conversation with system message conversation = [Message(message_pieces=[sample_system_piece])] - with pytest.raises(ValueError, match="Target must be provided to handle system prompts"): - await manager.update_conversation_state_async( - conversation_id=conversation_id, - prepended_conversation=conversation, - # No target provided - ) + # Should succeed with a chat target + await manager.update_conversation_state_async( + target=mock_chat_target, + conversation_id=conversation_id, + prepended_conversation=conversation, + ) + + # System message should be in memory + stored_conversation = manager.get_conversation(conversation_id) + assert len(stored_conversation) == 1 + assert stored_conversation[0].get_piece().role == "system" @pytest.mark.asyncio - async def test_update_conversation_state_system_prompt_with_non_chat_target_raises_error( + async def test_update_conversation_state_system_message_with_non_chat_target_succeeds( self, attack_identifier: dict[str, str], mock_prompt_target: MagicMock, sample_system_piece: MessagePiece ): - """Test that providing system prompts with non-chat target raises an error""" + """Test that empty prepended_conversation works fine with non-chat target.""" manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) - # Create conversation with system message - conversation = [Message(message_pieces=[sample_system_piece])] + # Should succeed with empty prepended_conversation + await manager.update_conversation_state_async( + target=mock_prompt_target, + conversation_id=conversation_id, + prepended_conversation=[], + ) + + # Empty prepended_conversation means no messages stored + stored_conversation = manager.get_conversation(conversation_id) + assert len(stored_conversation) == 0 + + @pytest.mark.asyncio + async def test_update_conversation_state_with_non_chat_target_and_prepended_conversation_raises_error( + self, attack_identifier: dict[str, str], mock_prompt_target: MagicMock, sample_user_piece: MessagePiece + ): + """Test that prepended_conversation with non-chat target raises ValueError.""" + manager = ConversationManager(attack_identifier=attack_identifier) + conversation_id = str(uuid.uuid4()) - with pytest.raises(ValueError, match="Target must be a PromptChatTarget to set system prompts"): + # Create conversation with messages + conversation = [Message(message_pieces=[sample_user_piece])] + + # Should raise ValueError because non-chat targets don't support conversation history + with pytest.raises(ValueError, match="prepended_conversation requires target to be a PromptChatTarget"): await manager.update_conversation_state_async( + target=mock_prompt_target, conversation_id=conversation_id, prepended_conversation=conversation, - target=mock_prompt_target, # Non-chat target ) @pytest.mark.asyncio @@ -542,7 +569,7 @@ async def test_update_conversation_state_mixed_conversation_multi_turn( sample_assistant_piece: MessagePiece, sample_system_piece: MessagePiece, ): - """Test that in multi-turn mode, system prompts are excluded but other messages are added""" + """Test that in multi-turn mode, all messages including system are saved to memory.""" manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) @@ -554,26 +581,24 @@ async def test_update_conversation_state_mixed_conversation_multi_turn( ] await manager.update_conversation_state_async( + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=conversation, - target=mock_chat_target, max_turns=5, # Multi-turn mode ) - # System prompt should be set on target - mock_chat_target.set_system_prompt.assert_called_once() - - # Only user and assistant messages should be in memory - # Since the target is mocked, the system piece won't be stored + # All messages including system should be in memory stored_conversation = manager.get_conversation(conversation_id) - assert len(stored_conversation) == 2 - assert stored_conversation[0].get_piece().role == "user" - assert stored_conversation[1].get_piece().role == "assistant" + assert len(stored_conversation) == 3 + assert stored_conversation[0].get_piece().role == "system" + assert stored_conversation[1].get_piece().role == "user" + assert stored_conversation[2].get_piece().role == "assistant" @pytest.mark.asyncio async def test_update_conversation_state_preserves_original_values_like_legacy( self, attack_identifier: dict[str, str], + mock_chat_target: MagicMock, sample_user_piece: MessagePiece, ): """Test that original values and other piece properties are preserved like the legacy function""" @@ -589,6 +614,7 @@ async def test_update_conversation_state_preserves_original_values_like_legacy( conversation = [Message(message_pieces=[sample_user_piece])] await manager.update_conversation_state_async( + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=conversation, # Single-turn mode @@ -610,6 +636,7 @@ async def test_update_conversation_state_preserves_original_values_like_legacy( async def test_update_conversation_state_counts_turns_correctly( self, attack_identifier: dict[str, str], + mock_chat_target: MagicMock, sample_user_piece: MessagePiece, sample_assistant_piece: MessagePiece, ): @@ -625,7 +652,7 @@ async def test_update_conversation_state_counts_turns_correctly( ] state = await manager.update_conversation_state_async( - conversation_id=conversation_id, prepended_conversation=conversation, max_turns=5 + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=conversation, max_turns=5 ) assert state.turn_count == 2 @@ -634,6 +661,7 @@ async def test_update_conversation_state_counts_turns_correctly( async def test_update_conversation_state_exceeds_max_turns_raises_error( self, attack_identifier: dict[str, str], + mock_chat_target: MagicMock, sample_user_piece: MessagePiece, sample_assistant_piece: MessagePiece, ): @@ -650,6 +678,7 @@ async def test_update_conversation_state_exceeds_max_turns_raises_error( with pytest.raises(ValueError, match="exceeds the maximum number of turns"): await manager.update_conversation_state_async( + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=conversation, max_turns=1, # Only allow 1 turn @@ -659,6 +688,7 @@ async def test_update_conversation_state_exceeds_max_turns_raises_error( async def test_update_conversation_state_extracts_assistant_scores( self, attack_identifier: dict[str, str], + mock_chat_target: MagicMock, sample_user_piece: MessagePiece, sample_assistant_piece: MessagePiece, sample_score: Score, @@ -690,7 +720,7 @@ async def test_update_conversation_state_extracts_assistant_scores( ] state = await manager.update_conversation_state_async( - conversation_id=conversation_id, prepended_conversation=conversation, max_turns=5 + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=conversation, max_turns=5 ) # Should extract scores for last assistant message @@ -699,7 +729,7 @@ async def test_update_conversation_state_extracts_assistant_scores( @pytest.mark.asyncio async def test_update_conversation_state_no_scores_for_assistant_message( - self, attack_identifier: dict[str, str], sample_assistant_piece: MessagePiece + self, attack_identifier: dict[str, str], mock_chat_target: MagicMock, sample_assistant_piece: MessagePiece ): manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) @@ -710,7 +740,7 @@ async def test_update_conversation_state_no_scores_for_assistant_message( ] state = await manager.update_conversation_state_async( - conversation_id=conversation_id, prepended_conversation=conversation, max_turns=5 + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=conversation, max_turns=5 ) # Should not set last_user_message when no scores found @@ -719,7 +749,11 @@ async def test_update_conversation_state_no_scores_for_assistant_message( @pytest.mark.asyncio async def test_update_conversation_state_assistant_without_preceding_user_raises_error( - self, attack_identifier: dict[str, str], sample_assistant_piece: MessagePiece, sample_score: Score + self, + attack_identifier: dict[str, str], + mock_chat_target: MagicMock, + sample_assistant_piece: MessagePiece, + sample_score: Score, ): manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) @@ -741,37 +775,13 @@ async def test_update_conversation_state_assistant_without_preceding_user_raises with pytest.raises(ValueError, match="There must be a user message preceding"): await manager.update_conversation_state_async( - conversation_id=conversation_id, prepended_conversation=conversation, max_turns=5 + target=mock_chat_target, + conversation_id=conversation_id, + prepended_conversation=conversation, + max_turns=5, ) -class TestPrivateMethods: - """Tests for private helper methods""" - - def test_should_exclude_piece_from_memory_single_turn_mode(self, sample_system_piece: MessagePiece): - # System pieces should be excluded in both single-turn and multi-turn modes - # because set_system_prompt() is called on the target, which internally adds them to memory - assert ConversationManager._should_exclude_piece_from_memory(piece=sample_system_piece, max_turns=None) - - def test_should_exclude_piece_from_memory_multi_turn_system_piece(self, sample_system_piece: MessagePiece): - # System pieces should be excluded in both single-turn and multi-turn modes - assert ConversationManager._should_exclude_piece_from_memory(piece=sample_system_piece, max_turns=5) - - def test_should_exclude_piece_from_memory_single_turn_non_system_piece( - self, sample_user_piece: MessagePiece, sample_assistant_piece: MessagePiece - ): - # In single-turn mode, non-system pieces should not be excluded - assert not ConversationManager._should_exclude_piece_from_memory(piece=sample_user_piece, max_turns=None) - assert not ConversationManager._should_exclude_piece_from_memory(piece=sample_assistant_piece, max_turns=None) - - def test_should_exclude_piece_from_memory_multi_turn_non_system_piece( - self, sample_user_piece: MessagePiece, sample_assistant_piece: MessagePiece - ): - # In multi-turn mode, non-system pieces should not be excluded - assert not ConversationManager._should_exclude_piece_from_memory(piece=sample_user_piece, max_turns=5) - assert not ConversationManager._should_exclude_piece_from_memory(piece=sample_assistant_piece, max_turns=5) - - @pytest.mark.usefixtures("patch_central_database") class TestEdgeCasesAndErrorHandling: """Tests for edge cases and error handling""" @@ -783,7 +793,9 @@ async def test_update_conversation_state_with_empty_message_pieces(self, attack_ Message(message_pieces=[]) @pytest.mark.asyncio - async def test_update_conversation_state_with_none_request(self, attack_identifier: dict[str, str]): + async def test_update_conversation_state_with_none_request( + self, attack_identifier: dict[str, str], mock_chat_target: MagicMock + ): manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) @@ -792,14 +804,16 @@ async def test_update_conversation_state_with_none_request(self, attack_identifi # Should handle gracefully state = await manager.update_conversation_state_async( - conversation_id=conversation_id, prepended_conversation=conversation # type: ignore + target=mock_chat_target, + conversation_id=conversation_id, + prepended_conversation=conversation, # type: ignore ) assert state.turn_count == 0 @pytest.mark.asyncio async def test_update_conversation_state_preserves_piece_metadata( - self, attack_identifier: dict[str, str], sample_user_piece: MessagePiece + self, attack_identifier: dict[str, str], mock_chat_target: MagicMock, sample_user_piece: MessagePiece ): manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) @@ -811,7 +825,7 @@ async def test_update_conversation_state_preserves_piece_metadata( conversation = [Message(message_pieces=[sample_user_piece])] await manager.update_conversation_state_async( - conversation_id=conversation_id, prepended_conversation=conversation + target=mock_chat_target, conversation_id=conversation_id, prepended_conversation=conversation ) # Verify piece was processed with metadata intact