We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Params4bit
1 parent 0d2280d commit 456afd9Copy full SHA for 456afd9
1 file changed
src/accelerate/utils/modeling.py
@@ -328,7 +328,7 @@ def set_module_tensor_to_device(
328
param is not None
329
and param.device.type != "cuda"
330
and torch.device(device).type == "cuda"
331
- and param_cls.__name__ in ["Int8Params", "FP4Params"]
+ and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]
332
):
333
device_quantization = device
334
device = "cpu"
0 commit comments