Fix OOMs happening in case of accelerate >= 0.16.0 #310

Merged: borzunov merged 1 commit into main from from-pretrained-dtype on Apr 25, 2023
@borzunov (Collaborator) commented Apr 25, 2023

@borzunov requested a review from mryab on April 25, 2023 11:31

@mryab (Member) left a comment:

Thanks for a quick investigation!

Comment thread on setup.cfg:
  torch>=1.12
  bitsandbytes==0.38.0.post2
- accelerate>=0.15.0,<1.0.0
+ accelerate>=0.16.0,<1.0.0

@borzunov (Collaborator, Author) replied:

set_module_tensor_to_device's dtype arg didn't exist before 0.16.0
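
The version bound discussed in this thread can be sketched as a small runtime check. This is only an illustration: `parse_version` and `satisfies_requirement` are hypothetical helpers, not part of accelerate or of this PR, and the parsing is deliberately simplified (it ignores pre-release suffixes, which a real PEP 440 parser would order correctly).

```python
import re

MIN_VERSION = (0, 16, 0)  # first release with the dtype argument, per the comment above
MAX_VERSION = (1, 0, 0)   # exclusive upper bound from setup.cfg


def parse_version(version: str) -> tuple:
    """Return the first three numeric components, e.g. "0.16.0" -> (0, 16, 0)."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    # Pad short versions such as "0.16" so tuple comparison behaves as expected.
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def satisfies_requirement(version: str) -> bool:
    """Check MIN_VERSION <= version < MAX_VERSION (simplified, no pre-release handling)."""
    return MIN_VERSION <= parse_version(version) < MAX_VERSION


for candidate in ("0.15.0", "0.16.0", "0.18.0", "1.0.0"):
    print(candidate, satisfies_requirement(candidate))
```

Of the four candidates, only 0.16.0 and 0.18.0 satisfy the bound. In practice the pin in setup.cfg already enforces this at install time, so a check like this would only matter for environments assembled by hand.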

@borzunov changed the title from "Fix OOMs caused by dtype in load_pretrained_block()" to "Fix OOMs happened with accelerate >= 0.16.0" on Apr 25, 2023
@borzunov changed the title from "Fix OOMs happened with accelerate >= 0.16.0" to "Fix OOMs happened in case of accelerate >= 0.16.0" on Apr 25, 2023
@borzunov changed the title from "Fix OOMs happened in case of accelerate >= 0.16.0" to "Fix OOMs happening in case of accelerate >= 0.16.0" on Apr 25, 2023
@borzunov merged commit 454c193 into main on Apr 25, 2023
@borzunov deleted the from-pretrained-dtype branch on April 25, 2023 13:20
fthrvi pushed a commit to fthrvi/nakshatra that referenced this pull request Jun 1, 2026
…op#310)

- After bigscience-workshop#285, `load_pretrained_block()` uses `accelerate.utils.set_module_tensor_to_device()`
- In accelerate>=0.16.0, this function stores the tensor in the dtype previously used by the model instead of the dtype of the loaded weights (huggingface/accelerate#920)
- Because of that, blocks and attention caches used float32, which caused OOMs
- This PR makes `load_pretrained_block()` respect `torch_dtype` (default: `"auto"`, which means reading `torch_dtype` from `config.json`)
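
The effect described in the list above can be sketched with plain Python. This is a rough illustration rather than the PR's actual code: `resolve_dtype` and `weights_size_bytes` are hypothetical helpers, dtypes are represented as strings instead of `torch.dtype` objects, and the config contents and parameter count are made up. The float32 fallback for a config without `torch_dtype` is also an assumption.

```python
import json

# Bytes per element for a few common dtypes (standard sizes).
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def resolve_dtype(torch_dtype: str, config_json: str) -> str:
    """Hypothetical helper: "auto" means "read torch_dtype from config.json"."""
    if torch_dtype != "auto":
        return torch_dtype
    config = json.loads(config_json)
    # Assumption: fall back to float32 when the config does not name a dtype.
    return config.get("torch_dtype") or "float32"


def weights_size_bytes(num_params: int, dtype: str) -> int:
    """Memory needed to hold num_params values of the given dtype."""
    return num_params * BYTES_PER_ELEMENT[dtype]


# Illustrative config and parameter count, not taken from a real model.
example_config = json.dumps({"torch_dtype": "bfloat16"})
num_params = 1_000_000

intended = resolve_dtype("auto", example_config)
print(intended, weights_size_bytes(num_params, intended))
print("float32", weights_size_bytes(num_params, "float32"))
```

Holding the same weights in float32 takes twice the memory of a 16-bit dtype, which is consistent with blocks and attention caches that silently used float32 running out of memory.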