Add FP8 kernel acceleration for compressed-tensors quantized models #45699

Open: jiqing-feng wants to merge 47 commits into huggingface:main from jiqing-feng:fp8
Conversation

@jiqing-feng jiqing-feng commented Apr 29, 2026

Contributor

What does this PR do?

This PR adds native FP8 matmul kernel support for compressed-tensors FP8 quantized models in transformers. Previously, compressed-tensors FP8 models were loaded via the compressed-tensors library and dequantized back to FP16/BF16 for inference. With this change, FP8 weights are kept in FP8 format and inference uses hardware-accelerated FP8 matmul kernels (torch._scaled_mm on XPU, fbgemm.f8f8bf16_rowwise on CUDA).

Key changes:

New file: src/transformers/integrations/compressed_tensors_fp8.py

  • CTFP8Linear: FP8 linear layer that stores weights in FP8 and uses row-wise FP8 matmul kernels. Activations are dynamically quantized per-row via quantize_fp8_per_row.
  • Weight converters (CompressedTensorsScaleConvert, CompressedTensorsFp8Dequantize) to handle the checkpoint format conversion (e.g. weight_scale → weight_scale_inv).
  • CTFP8PerRowQuantize: Online quantization support — quantize BF16 weights to FP8 per-row on-the-fly during model loading.
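
For readers unfamiliar with per-row dynamic quantization, here is a minimal sketch of the scale arithmetic it relies on. It is plain Python on nested lists, not the PR's quantize_fp8_per_row: the helper names are hypothetical, and the actual cast to an FP8 dtype (and its rounding) is deliberately left out. The only hard fact used is that 448 is the largest finite value of the float8 e4m3fn format.

```python
# Illustrative only: plain-Python arithmetic, no torch and no real FP8 cast.
FP8_E4M3_MAX = 448.0  # largest finite value of the float8 e4m3fn format


def per_row_scales(rows, eps=1e-12):
    """Return one dequantization scale per row: max(|row|) / FP8_E4M3_MAX."""
    return [max(max(abs(v) for v in row), eps) / FP8_E4M3_MAX for row in rows]


def scale_rows(rows, scales):
    """Divide each row by its scale and clamp into the FP8 range.

    A real kernel would then cast the result to an FP8 dtype; that rounding
    step is omitted in this sketch.
    """
    return [
        [min(max(v / s, -FP8_E4M3_MAX), FP8_E4M3_MAX) for v in row]
        for row, s in zip(rows, scales)
    ]


activations = [[0.5, -2.0, 1.0], [100.0, 25.0, -50.0]]
scales = per_row_scales(activations)
scaled = scale_rows(activations, scales)
# Multiplying back by the per-row scale recovers the input (up to float error).
restored = [[v * s for v in row] for row, s in zip(scaled, scales)]
```

Because each row gets its own scale, a row with large activations does not squeeze the dynamic range available to the other rows.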

Modified: src/transformers/quantizers/quantizer_compressed_tensors.py

  • CompressedTensorsHfQuantizer now detects FP8 quantization configs (float type, num_bits=8) and automatically routes to the FP8 kernel path when GPU/XPU is available. Falls back to the default compressed-tensors dequantize path on CPU.
  • Added get_weight_conversions() and get_quantize_ops() to support both pre-quantized loading and online quantization.
  • No changes to the non-FP8 code path — existing INT8/INT4 compressed-tensors models are unaffected.

Modified: src/transformers/quantizers/auto.py

  • Minor formatting change (no functional change).

Supported models

Usage

Pre-quantized model (no config needed)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

Online quantization

import torch
from transformers import AutoModelForCausalLM, CompressedTensorsConfig
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs, QuantizationType, QuantizationStrategy

ct_config = CompressedTensorsConfig(
    config_groups={
        "group_0": QuantizationScheme(
            weights=QuantizationArgs(
                num_bits=8, type=QuantizationType.FLOAT, strategy=QuantizationStrategy.CHANNEL,
            ),
        ),
    },
    run_compressed=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    quantization_config=ct_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

Devices

  • XPU (Intel Data Center Max / Arc): uses torch._scaled_mm
  • CUDA (SM89+): uses fbgemm.f8f8bf16_rowwise
  • CPU: falls back to default compressed-tensors dequantize path
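
The device table above amounts to a small routing decision. The sketch below is only an illustration of that decision: select_fp8_backend is a hypothetical helper, the returned strings are labels rather than callables, and treating CUDA devices below SM89 as a dequantize case is an assumption based on the SM80 fallback discussed later in this thread.

```python
def select_fp8_backend(device_type, cuda_capability=None):
    """Pick a matmul path from the device table in the PR description.

    device_type: "xpu", "cuda" or "cpu".
    cuda_capability: (major, minor) tuple, only consulted for CUDA.
    The return values are plain labels for this sketch.
    """
    if device_type == "xpu":
        return "torch._scaled_mm"
    if (
        device_type == "cuda"
        and cuda_capability is not None
        and tuple(cuda_capability) >= (8, 9)
    ):
        return "fbgemm.f8f8bf16_rowwise"
    # CPU, or a CUDA device without FP8 hardware (assumed to dequantize).
    return "dequantize"
```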

@sywangyi

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng jiqing-feng changed the title Fp8 Add FP8 kernel acceleration for compressed-tensors quantized models Apr 29, 2026
@Rocketknight1

Member

cc @SunMarc

@SunMarc SunMarc left a comment

Member

Thanks, left a comment !

Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng jiqing-feng marked this pull request as ready for review April 30, 2026 03:11
@jiqing-feng

Contributor Author

Hi @SunMarc. Please check whether the integration is OK. I'll clean up the tests and docs once you approve the integration.

@jiqing-feng

Contributor Author

Hi @SunMarc . Would you please review the PR? Thanks!

@jiqing-feng

Contributor Author

Hi @Rocketknight1. It seems that @SunMarc does not have the bandwidth to review this PR. Would you please help review it? Thanks!

Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py
Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py
Comment thread src/transformers/quantizers/auto.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py

@stevhliu stevhliu left a comment

Member

thanks! one more minor change, otherwise docs lgtm :)

Comment thread docs/source/en/quantization/compressed_tensors.md Outdated
jiqing-feng and others added 2 commits June 3, 2026 09:15
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Comment on lines +137 to +140
else:
    # CUDA SM80 (A100): no FP8 hardware, dequantize weight to BF16 + normal matmul
    w = self.weight.to(input.dtype) * self.weight_scale_inv.to(input.dtype)
    output = F.linear(input.view(-1, input.shape[-1]), w, self.bias)

Member

a bit opinionated but i don't think we should have a dequant path here, for me it kinda goes against the principle of quantizers. wdyt @SunMarc

Contributor Author

I agree with keeping the quantizer pure, but removing the dequant fallback conflicts with checking hardware at runtime (the ZeroGPU requirement).

If we delay the hardware check to forward() and get assigned an SM80 (A100) GPU at runtime, the model is already initialized with CompressedTensorsFP8Linear. Without the fallback, the forward pass will hard crash.

How should we handle this?

  1. Keep the fallback (and maybe add a logger.warning_once).
  2. Remove the fallback and raise a RuntimeError if an unsupported GPU is detected at runtime.

Comment thread src/transformers/integrations/hub_kernels.py Outdated
Comment thread docs/source/en/quantization/compressed_tensors.md Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated

@IlyasMoutawwakil IlyasMoutawwakil left a comment

Member

sorry for the late re-review 😅, overall pretty good, just a couple of change requests around naming and a question around cpu fallback for @SunMarc

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@HuggingFaceDocBuilderDev


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@IlyasMoutawwakil IlyasMoutawwakil left a comment

Member

LGTM for the integration part ! cc @SunMarc for the quantizer part

Comment thread src/transformers/quantizers/quantizer_compressed_tensors.py Outdated
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

@ArthurZucker ArthurZucker left a comment

Collaborator

Nice! Nit on walking / updating existing conversion

Comment thread src/transformers/integrations/compressed_tensors_fp8.py Outdated
Comment thread src/transformers/integrations/compressed_tensors_fp8.py
Comment thread src/transformers/quantizers/quantizer_compressed_tensors.py Outdated
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng

Contributor Author

Hi @ArthurZucker. Would you please review my new commit and check whether it addresses your comments? Thanks!

@jiqing-feng

Contributor Author

Hi @SunMarc. Would you please let me know what needs to be changed before merging? Thanks!

@SunMarc SunMarc left a comment

Member

Thanks for your work !

I think it would still be better if we don't do online quantization, no? Users can just use compressed-tensors to do it. If they want online FP8, they have finegrained-fp8 for that, and we can update the support if needed. The thing is that I don't want to introduce a new way to create CT checkpoints, plus the maintenance overhead of reverse ops that need to match both the CT implementation and ours. Right now, what you did is dequantize back to BF16 if someone wants to save the online-quantized model.

Also, for the FP8 kernels, we should probably add a new arg called dequantize, like the other quantization methods. run_compressed doesn't work anymore, as you saw: CT just decompresses the model on the first forward.

x = input.reshape(-1, input.shape[-1])
x_quantized, x_scale = _quantize_fp8_per_row(x)

weight_scale_float32 = self.weight_scale_inv.to(torch.float32)

Member

The weight_scale will be in float32 if you specified it in the modeling so we don't need to do that

Suggested change
weight_scale_float32 = self.weight_scale_inv.to(torch.float32)

Comment on lines +123 to +124
scale_b = weight_scale_float32.t()
if scale_b.shape[-1] == 1 and self.out_features > 1:

Member

maybe define a variable called is_per_tensor ?


if _can_use_fp8_kernel():
    # XPU or CUDA SM89+: FP8 kernel path (quantize activation + scaled_mm)
    x = input.reshape(-1, input.shape[-1])

Member

maybe we can reshape the input beforehand, since we do it for both paths?


module_kwargs = {} if pre_quantized else {"dtype": None}
if isinstance(module, nn.Linear):
    with torch.device("meta"):

Member

we don't need this normally as this method is already under a context manager that does this

Suggested change
with torch.device("meta"):

Comment on lines +28 to +38
def _is_fp8_config(quantization_config: CompressedTensorsConfig) -> bool:
    """Check if a CompressedTensorsConfig describes FP8 quantization."""
    ct_qconfig = quantization_config.quantization_config
    if ct_qconfig is None:
        return False
    for group in ct_qconfig.config_groups.values():
        weights = group.weights
        if weights is not None and weights.type == "float" and weights.num_bits == 8:
            return True
    return False

Member

we can move that to compressed_tensors config class maybe ?

Comment on lines +171 to +172
if self.is_fp8:
    return False

Member

depends if we dequantize or not also

Comment on lines +181 to +182
if self.is_fp8:
    return False

Member

same

Comment on lines +218 to +227
class CompressedTensorsActivationScaleConvert(ConversionOps):
    """Rename compressed-tensors `input_scale` to `activation_scale`."""

    def convert(self, input_dict, **kwargs):
        scale = input_dict["input_scale"][0]
        return {"activation_scale": scale.to(torch.float32)}

    @property
    def reverse_op(self):
        return _IdentityOp()

Member

we can just keep the same name no ?

Comment on lines +183 to +193
class CompressedTensorsScaleConvert(ConversionOps):
    """Convert compressed-tensors `weight_scale` to `weight_scale_inv`.

    In compressed-tensors, `weight_scale` is the dequantization multiplier:
        bf16_weight = fp8_weight * weight_scale

    In our CompressedTensorsFP8Linear, `weight_scale_inv` has the same semantics (it's
    multiplied with the FP8 weight to get the dequantized value), so no inversion is needed.
    The conversion also reshapes the scale: scalar → (1, 1), 1D (N,) → (N, 1).
    """

Member

same here, we can keep the same name no ?

Comment on lines +230 to +240
class CompressedTensorsFp8Dequantize(ConversionOps):
    """Dequantize compressed-tensors FP8 weights back to BF16.

    Folds the per-channel / per-tensor ``weight_scale`` into the FP8 weight,
    producing a BF16 tensor. Prepended to a converter chain for layers that
    cannot stay in FP8 (e.g. merged MoE experts, which are not ``nn.Linear``):
    it pairs each weight with its sibling scale *by index* and preserves the
    per-expert list structure so the downstream merge / concat ops still see
    one tensor per expert.
    """

Member

Do we really need this ? Like it is not really useful to online quantize a model to finally save it in bf16 no ?

With compressed tensors, it will dequantize the model in any case if you specify run_compressed=False for a quantized model, no?

Member

also can you explain the moe bit, i didn't fully understand

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng

Contributor Author

Hi @SunMarc, thanks for the review! Addressed everything; kept the MoE dequant-before-merge on purpose (explained below).

Done

  • No more online quant — removed CompressedTensorsFP8PerRowQuantize; get_quantize_ops returns None and param_needs_quantization is False for FP8. Online FP8 → use finegrained-fp8.
  • dequantize arg — added dequantize: bool = False to CompressedTensorsConfig (like the other methods), wired through to replace_with_compressed_tensors_fp8_linear. run_compressed no longer drives this.
  • (1) dropped the redundant .to(float32) (weight_scale is already f32).
  • (2) added an is_per_tensor variable.
  • (3) input is reshaped to 2D once, shared by both paths.
  • (4) removed with torch.device("meta").
  • (5) moved _is_fp8_config → CompressedTensorsConfig.is_fp8.
  • (6) / (7) is_trainable / is_qat_trainable now return dequantize for FP8 (trainable only on the dequantized BF16 path).
  • (8) kept input_scale — the converter was dead code, removed it.
  • (9) kept weight_scale (no more weight_scale_inv); the converter only reshapes.

Side effect: CompressedTensorsFp8Dequantize.reverse_op is now _IdentityOp() (we never re-quantize on save).

The MoE dequant (comments 10 & 11)

MoE checkpoints store per-expert weights+scales, but transformers merges the experts (stack/cat) into a single 3D packed nn.Parameter — not an nn.Linear, so it can't hold a weight_scale, and the per-expert scales differ so they can't survive the merge. The only correct option is to dequantize each expert (weight * weight_scale) to BF16 before the merge: update_weight_conversions prepends CompressedTensorsFp8Dequantize to the merging converters, pairs weight+scale by expert index, drops the scales. Without it, MoE FP8 loading crashes. (Attention/router linears in the same model still stay FP8.) The dequantize=True flag reuses this same op to also fold plain linears to BF16 — that's the answer to comment 10.
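
The dequantize-before-merge step described above can be sketched as follows. This is plain Python on nested lists, not the PR's CompressedTensorsFp8Dequantize: dequantize_experts and merge_experts are hypothetical names, ordinary floats stand in for FP8 values, and one scale per output row is assumed.

```python
def dequantize_experts(weights, scales):
    """Fold each expert's per-row scale into its weight, pairing by index.

    weights: list of experts, each a list of rows (floats standing in for FP8).
    scales:  list of experts, each a list with one scale per output row.
    Returns one dequantized matrix per expert; the scales are dropped.
    """
    if len(weights) != len(scales):
        raise ValueError("each expert weight needs exactly one sibling scale")
    return [
        [[v * s for v in row] for row, s in zip(w, sc)]
        for w, sc in zip(weights, scales)
    ]


def merge_experts(expert_matrices):
    """Stack experts into one 3D nested list indexed as [expert][row][col]."""
    return list(expert_matrices)


# Two experts with identical raw weights but different scales: merging the raw
# weights would lose the difference, so the scales are folded in first.
expert_weights = [[[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]]]
expert_scales = [[0.5, 0.25], [2.0, 1.0]]
merged = merge_experts(dequantize_experts(expert_weights, expert_scales))
```

The per-expert list structure is preserved until the final stack, which is what lets the downstream merge still see one tensor per expert.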

Known limitation: since merged experts land in BF16, this PR doesn't save memory on MoE expert weights (the bulk of params) — only dense models and the attention/router part of MoE benefit. Not fundamental: experts could stay FP8 with (1) a 3D FP8 expert param, (2) a stacked per-expert scale tensor, (3) a grouped scaled-mm in the MoE forward. That's a bigger change, so I'd suggest a follow-up and keeping this dequant path as the correct baseline.

Happy to adjust naming or split differently — let me know!

@github-actions

Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: compressed_tensors_integration

@github-actions

Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45699&sha=b73018

9 participants