Fix dtype of weights in from_pretrained when device_map is set #20602

Merged
sgugger merged 1 commit into main from fix_dtype_none on Dec 6, 2022

Conversation

sgugger commented Dec 5, 2022

Collaborator

What does this PR do?

As reported in #20390, the dtype of the weights after loading a checkpoint with from_pretrained is inconsistent between device_map=None and device_map set:

  • device_map=None (which uses nn.Module.load_state_dict) keeps the dtype of the model unchanged, even if the checkpoint is in a different dtype (so loading a float16 checkpoint in a float32 model gives a float32 model)
  • device_map set (which manually sets the parameters) will change the dtype of the model to the dtype of the checkpoint (so loading a float16 checkpoint in a float32 model gives a float16 model).

This PR addresses this by making the two code paths consistent.
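
The two behaviours described above can be reproduced outside of from_pretrained with a small PyTorch sketch. This is only an illustration: a plain nn.Linear stands in for the model, the "checkpoint" is an in-memory state dict, and replacing parameters with setattr is an assumed simplification of the manual parameter setting used when device_map is set. The last loop shows one possible way to make the paths agree (casting to the existing parameter's dtype); it is not necessarily the change made in this PR.

```python
import torch
from torch import nn

# A float16 "checkpoint": a state dict whose tensors were saved in half precision.
checkpoint = {
    name: tensor.to(torch.float16)
    for name, tensor in nn.Linear(4, 2).state_dict().items()
}

# Path 1: nn.Module.load_state_dict copies values into the existing parameters,
# so the float32 model stays float32.
model_a = nn.Linear(4, 2)  # parameters are float32 by default
model_a.load_state_dict(checkpoint)
dtype_load_state_dict = model_a.weight.dtype

# Path 2 (assumed simplification of manual parameter setting): replacing each
# parameter object with the checkpoint tensor adopts the checkpoint dtype.
model_b = nn.Linear(4, 2)
for name, tensor in checkpoint.items():
    setattr(model_b, name, nn.Parameter(tensor))
dtype_manual = model_b.weight.dtype

# Hypothetical reconciliation: cast to the dtype of the parameter being replaced.
model_c = nn.Linear(4, 2)
for name, tensor in checkpoint.items():
    old_param = getattr(model_c, name)
    setattr(model_c, name, nn.Parameter(tensor.to(old_param.dtype)))
dtype_manual_cast = model_c.weight.dtype

print(dtype_load_state_dict, dtype_manual, dtype_manual_cast)
```

In this sketch the first path keeps the model in float32 while the second ends up in float16, which mirrors the inconsistency reported in #20390.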

HuggingFaceDocBuilderDev commented Dec 5, 2022

The documentation is not available anymore as the PR was closed or merged.

younesbelkada left a comment

Contributor

Thanks a lot for fixing this and making model loading consistent between device_map=auto (or any other value) and device_map=None!
Just wondering if you need a special safety checker for safetensors (bear in mind that I am not very knowledgeable about safetensors): what happens if old_params.dtype == torch.float16 and is_safetensors==True?

sgugger commented Dec 5, 2022

Collaborator Author

There is no more safetensors at this stage (is_safetensors only means the checkpoint comes from safetensors; the state dict is a dictionary mapping names to parameters in this case as well).

LysandreJik left a comment

Member

LGTM

4 participants