Add FP8 kernel acceleration for compressed-tensors quantized models#45699
Add FP8 kernel acceleration for compressed-tensors quantized models#45699jiqing-feng wants to merge 47 commits into
Conversation
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
|
cc @SunMarc |
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
|
Hi @SunMarc . Please check it the integration is ok. I'll clean the tests and doc after you approved the integration. |
|
Hi @SunMarc . Would you please review the PR? Thanks! |
|
Hi @Rocketknight1 . It seems that @SunMarc does not have bandwidth to review this PR. Would you please help to review the PR? Thanks! |
stevhliu
left a comment
There was a problem hiding this comment.
thanks! one more minor change, otherwise docs lgtm :)
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
| else: | ||
| # CUDA SM80 (A100): no FP8 hardware, dequantize weight to BF16 + normal matmul | ||
| w = self.weight.to(input.dtype) * self.weight_scale_inv.to(input.dtype) | ||
| output = F.linear(input.view(-1, input.shape[-1]), w, self.bias) |
There was a problem hiding this comment.
a bit opinionated but i don't think we should have a dequant path here, for me it kidna goes againt the principle of quantizers. wdyt @SunMarc
There was a problem hiding this comment.
I agree with keeping the quantizer pure, but removing the dequant fallback conflicts with checking hardware at runtime (the ZeroGPU requirement).
If we delay the hardware check to forward() and get assigned an SM80 (A100) GPU at runtime, the model is already initialized with CompressedTensorsFP8Linear. Without the fallback, the forward pass will hard crash.
How should we handle this?
- Keep the fallback (and maybe add a
logger.warning_once). - Remove the fallback and raise a
RuntimeErrorif an unsupported GPU is detected at runtime.
There was a problem hiding this comment.
sorry for the late re-review 😅, overall pretty good, just a couple change requests around naming and a qustion around cpu fallback for @SunMarc
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
LGTM for the integration part ! cc @SunMarc for the quantizer part
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
ArthurZucker
left a comment
There was a problem hiding this comment.
Nice! Nit on walking / updating existing conversion
|
Hi @ArthurZucker . Would you please review my new commit to check if it fixed your comments? Thanks! |
|
Hi @SunMarc . Would you please let me know what need to be changed before merging? Thanks! |
SunMarc
left a comment
There was a problem hiding this comment.
Thanks for your work !
I think it will still be better if we don't do online quantization no ? Like users can just use compressed tensors to do it no ? If they want to use online fp8, they have finegrained-fp8 for that and we can update the support if needed. The thing is that I don't want to introduce a new way to create CT checkpoints + maintenance overhead for reverse ops that needs to match CT implementation and ours. Right now, what you did is to dequantize back to BF16 if someone wants to save the online quantized model.
Also, for the FP8 kernels, we should probably add a new arg like the other quantization method called dequantize. run_compressed don't work anymore as you saw. CT just decompresses the model on the first forward.
| x = input.reshape(-1, input.shape[-1]) | ||
| x_quantized, x_scale = _quantize_fp8_per_row(x) | ||
|
|
||
| weight_scale_float32 = self.weight_scale_inv.to(torch.float32) |
There was a problem hiding this comment.
The weight_scale will be in float32 if you specified it in the modeling so we don't need to do that
| weight_scale_float32 = self.weight_scale_inv.to(torch.float32) |
| scale_b = weight_scale_float32.t() | ||
| if scale_b.shape[-1] == 1 and self.out_features > 1: |
There was a problem hiding this comment.
maybe define a variable called is_per_tensor ?
|
|
||
| if _can_use_fp8_kernel(): | ||
| # XPU or CUDA SM89+: FP8 kernel path (quantize activation + scaled_mm) | ||
| x = input.reshape(-1, input.shape[-1]) |
There was a problem hiding this comment.
maybe we can reshape the input before as we do for both path ?
|
|
||
| module_kwargs = {} if pre_quantized else {"dtype": None} | ||
| if isinstance(module, nn.Linear): | ||
| with torch.device("meta"): |
There was a problem hiding this comment.
we don't need this normally as this method is already under a context manager that does this
| with torch.device("meta"): |
| def _is_fp8_config(quantization_config: CompressedTensorsConfig) -> bool: | ||
| """Check if a CompressedTensorsConfig describes FP8 quantization.""" | ||
| ct_qconfig = quantization_config.quantization_config | ||
| if ct_qconfig is None: | ||
| return False | ||
| for group in ct_qconfig.config_groups.values(): | ||
| weights = group.weights | ||
| if weights is not None and weights.type == "float" and weights.num_bits == 8: | ||
| return True | ||
| return False | ||
|
|
There was a problem hiding this comment.
we can move that to compressed_tensors config class maybe ?
| if self.is_fp8: | ||
| return False |
There was a problem hiding this comment.
depends if we dequantize or not also
| if self.is_fp8: | ||
| return False |
| class CompressedTensorsActivationScaleConvert(ConversionOps): | ||
| """Rename compressed-tensors `input_scale` to `activation_scale`.""" | ||
|
|
||
| def convert(self, input_dict, **kwargs): | ||
| scale = input_dict["input_scale"][0] | ||
| return {"activation_scale": scale.to(torch.float32)} | ||
|
|
||
| @property | ||
| def reverse_op(self): | ||
| return _IdentityOp() |
There was a problem hiding this comment.
we can just keep the same name no ?
| class CompressedTensorsScaleConvert(ConversionOps): | ||
| """Convert compressed-tensors `weight_scale` to `weight_scale_inv`. | ||
|
|
||
| In compressed-tensors, `weight_scale` is the dequantization multiplier: | ||
| bf16_weight = fp8_weight * weight_scale | ||
|
|
||
| In our CompressedTensorsFP8Linear, `weight_scale_inv` has the same semantics (it's | ||
| multiplied with the FP8 weight to get the dequantized value), so no inversion is needed. | ||
| The conversion also reshapes the scale: scalar → (1, 1), 1D (N,) → (N, 1). | ||
| """ | ||
|
|
There was a problem hiding this comment.
same here, we can keep the same name no ?
| class CompressedTensorsFp8Dequantize(ConversionOps): | ||
| """Dequantize compressed-tensors FP8 weights back to BF16. | ||
|
|
||
| Folds the per-channel / per-tensor ``weight_scale`` into the FP8 weight, | ||
| producing a BF16 tensor. Prepended to a converter chain for layers that | ||
| cannot stay in FP8 (e.g. merged MoE experts, which are not ``nn.Linear``): | ||
| it pairs each weight with its sibling scale *by index* and preserves the | ||
| per-expert list structure so the downstream merge / concat ops still see | ||
| one tensor per expert. | ||
| """ | ||
|
|
There was a problem hiding this comment.
Do we really need this ? Like it is not really useful to online quantize a model to finally save it in bf16 no ?
With compressed tensors, it will dequantize the model in any case no if you specify run_compressed=False no for a quantized model.
There was a problem hiding this comment.
also can you explain the moe bit, i didn't fully understand
|
Hi @SunMarc, thanks for the review! Addressed everything; kept the MoE dequant-before-merge on purpose (explained below). Done
Side effect: The MoE dequant (comments 10 & 11) MoE checkpoints store per-expert weights+scales, but transformers merges the experts ( Known limitation: since merged experts land in BF16, this PR doesn't save memory on MoE expert weights (the bulk of params) — only dense models and the attention/router part of MoE benefit. Not fundamental: experts could stay FP8 with (1) a 3D FP8 expert param, (2) a stacked per-expert scale tensor, (3) a grouped scaled-mm in the MoE forward. That's a bigger change, so I'd suggest a follow-up and keeping this dequant path as the correct baseline. Happy to adjust naming or split differently — let me know! |
|
[For maintainers] Suggested jobs to run (before merge) run-slow: compressed_tensors_integration |
|
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45699&sha=b73018 |
What does this PR do?
This PR adds native FP8 matmul kernel support for compressed-tensors FP8 quantized models in transformers. Previously, compressed-tensors FP8 models were loaded via the
compressed-tensorslibrary and dequantized back to FP16/BF16 for inference. With this change, FP8 weights are kept in FP8 format and inference uses hardware-accelerated FP8 matmul kernels (torch._scaled_mmon XPU,fbgemm.f8f8bf16_rowwiseon CUDA).Key changes:
New file:
src/transformers/integrations/compressed_tensors_fp8.pyCTFP8Linear: FP8 linear layer that stores weights in FP8 and uses row-wise FP8 matmul kernels. Activations are dynamically quantized per-row viaquantize_fp8_per_row.CompressedTensorsScaleConvert,CompressedTensorsFp8Dequantize) to handle the checkpoint format conversion (e.g.weight_scale→weight_scale_inv).CTFP8PerRowQuantize: Online quantization support — quantize BF16 weights to FP8 per-row on-the-fly during model loading.Modified:
src/transformers/quantizers/quantizer_compressed_tensors.pyCompressedTensorsHfQuantizernow detects FP8 quantization configs (floattype,num_bits=8) and automatically routes to the FP8 kernel path when GPU/XPU is available. Falls back to the default compressed-tensors dequantize path on CPU.get_weight_conversions()andget_quantize_ops()to support both pre-quantized loading and online quantization.Modified:
src/transformers/quantizers/auto.pySupported models
CompressedTensorsConfigwith FP8 quantization scheme.Usage
Pre-quantized model (no config needed)
Online quantization
Devices
torch._scaled_mmfbgemm.f8f8bf16_rowwise@sywangyi