diff --git a/pyrit/auth/manual_copilot_authenticator.py b/pyrit/auth/manual_copilot_authenticator.py index 175db871e1..8dca7293df 100644 --- a/pyrit/auth/manual_copilot_authenticator.py +++ b/pyrit/auth/manual_copilot_authenticator.py @@ -59,7 +59,9 @@ def __init__(self, *, access_token: Optional[str] = None) -> None: self._access_token = resolved_token try: - self._claims = jwt.decode(resolved_token, algorithms=["RS256"], options={"verify_signature": False}) + self._claims: dict[str, Any] = jwt.decode( + resolved_token, algorithms=["RS256"], options={"verify_signature": False} + ) except jwt.exceptions.DecodeError as e: raise ValueError(f"Failed to decode access_token as JWT: {e}") @@ -97,7 +99,7 @@ async def get_claims(self) -> dict[str, Any]: Returns: dict[str, Any]: The JWT claims decoded from the access token. """ - return self._claims # type: ignore + return self._claims async def refresh_token_async(self) -> str: """ diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 84fcb93a2e..75228c0f2d 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -4,12 +4,13 @@ import logging import uuid from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from pyrit.common.utils import combine_dict from pyrit.executor.attack.component.prepended_conversation_config import ( PrependedConversationConfig, ) +from pyrit.identifiers import TargetIdentifier from pyrit.memory import CentralMemory from pyrit.message_normalizer import ConversationContextNormalizer from pyrit.models import ChatMessageRole, Message, MessagePiece, Score @@ -54,7 +55,7 @@ def get_adversarial_chat_messages( *, adversarial_chat_conversation_id: str, attack_identifier: Dict[str, str], - adversarial_chat_target_identifier: Dict[str, str], + adversarial_chat_target_identifier: Union[TargetIdentifier, Dict[str, Any]], labels: Optional[Dict[str, str]] = None, ) -> List[Message]: """ diff --git a/pyrit/message_normalizer/tokenizer_template_normalizer.py b/pyrit/message_normalizer/tokenizer_template_normalizer.py index ce8813a232..b62e3b5234 100644 --- a/pyrit/message_normalizer/tokenizer_template_normalizer.py +++ b/pyrit/message_normalizer/tokenizer_template_normalizer.py @@ -126,7 +126,7 @@ def _load_tokenizer(model_name: str, token: Optional[str]) -> "PreTrainedTokeniz return cast( PreTrainedTokenizerBase, - AutoTokenizer.from_pretrained(model_name, token=token or None), # type: ignore[no-untyped-call] + AutoTokenizer.from_pretrained(model_name, token=token or None), # type: ignore[no-untyped-call, unused-ignore] ) @classmethod diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index 5fe1104a8b..f320248e2c 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -5,11 +5,12 @@ import logging import os from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, cast from transformers import ( AutoModelForCausalLM, AutoTokenizer, + BatchEncoding, PretrainedConfig, ) @@ -167,7 +168,7 @@ def _load_from_path(self, path: str, **kwargs: Any) -> None: **kwargs: Additional keyword arguments to pass to the model loader. """ logger.info(f"Loading model and tokenizer from path: {path}...") - self.tokenizer = AutoTokenizer.from_pretrained( # type: ignore[no-untyped-call] + self.tokenizer = AutoTokenizer.from_pretrained( # type: ignore[no-untyped-call, unused-ignore] path, trust_remote_code=self.trust_remote_code ) self.model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=self.trust_remote_code, **kwargs) @@ -245,7 +246,7 @@ async def load_model_and_tokenizer(self) -> None: # Load the tokenizer and model from the specified directory logger.info(f"Loading model {self.model_id} from cache path: {cache_dir}...") - self.tokenizer = AutoTokenizer.from_pretrained( # type: ignore[no-untyped-call] + self.tokenizer = AutoTokenizer.from_pretrained( # type: ignore[no-untyped-call, unused-ignore] self.model_id, cache_dir=cache_dir, trust_remote_code=self.trust_remote_code ) self.model = AutoModelForCausalLM.from_pretrained( @@ -329,8 +330,9 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: generated_tokens = generated_ids[0][input_length:] # Decode the assistant's response from the generated token IDs - assistant_response = self.tokenizer.decode( - generated_tokens, skip_special_tokens=self.skip_special_tokens + assistant_response = cast( + str, + self.tokenizer.decode(generated_tokens, skip_special_tokens=self.skip_special_tokens), ).strip() if not assistant_response: @@ -369,12 +371,15 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any: logger.info("Tokenizer has a chat template. Applying it to the input messages.") # Apply the chat template to format and tokenize the messages - tokenized_chat = self.tokenizer.apply_chat_template( - messages, - tokenize=True, - add_generation_prompt=True, - return_tensors=self.tensor_format, - return_dict=True, + tokenized_chat = cast( + BatchEncoding, + self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_tensors=self.tensor_format, + return_dict=True, + ), ).to(self.device) return tokenized_chat else: