Skip to content

Commit 456afd9

Browse files
authored
Params4bit added to bnb classes in set_module_tensor_to_device() (#2315)
1 parent 0d2280d commit 456afd9

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/accelerate/utils/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def set_module_tensor_to_device(
328328
param is not None
329329
and param.device.type != "cuda"
330330
and torch.device(device).type == "cuda"
331-
and param_cls.__name__ in ["Int8Params", "FP4Params"]
331+
and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]
332332
):
333333
device_quantization = device
334334
device = "cpu"

0 commit comments

Comments
 (0)