
Update mhc_pre hip kernel support hc_head #3044

Merged
valarLip merged 1 commit into main from jun/hc_head
May 6, 2026
Conversation

@junhaha666

Contributor

test hc_head cmd: python3 op_tests/test_mhc.py --hc_head
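
For orientation, a flag like this is usually wired up with argparse. The sketch below is an assumption about the general shape of such a harness, not the actual contents of op_tests/test_mhc.py: `build_parser` and `build_config` are hypothetical helpers, and the fallback repeat count of 20 is borrowed from the mhc_pre default shown later in this PR.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser; only the --hc_head flag name comes from the PR.
    parser = argparse.ArgumentParser(description="MHC pre-kernel test sketch")
    parser.add_argument(
        "--hc_head",
        action="store_true",
        help="run the pre-only (hc_head) configuration",
    )
    return parser


def build_config(argv: list[str]) -> dict:
    args = build_parser().parse_args(argv)
    # Assumption: pre-only mode is requested by passing sinkhorn_repeat == 0.
    sinkhorn_repeat = 0 if args.hc_head else 20
    return {"test_hc_head": args.hc_head, "sinkhorn_repeat": sinkhorn_repeat}
```

Calling `build_config(["--hc_head"])` would then select the pre-only settings, while an empty argument list keeps the full path.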

@junhaha666 junhaha666 requested review from a team and Copilot May 6, 2026 06:11
@github-actions

github-actions Bot commented May 6, 2026

Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 3044 --add-label <label>

Copilot AI left a comment

Contributor


Pull request overview

This PR adds an hc_head / “pre-only” execution mode for the MHC pre-kernel path (skipping post/comb mix + Sinkhorn), and updates the corresponding Python test harness to exercise it via a new CLI flag.

Changes:

  • Extend mhc_pre to support sinkhorn_repeat == 0 as an hc_head-only mode (allowing fn.size(0) == hc_mult).
  • Update the HIP kernel to skip post/comb writes when sinkhorn_repeat == 0, and tweak GEMM-SQRSUM dispatch heuristics for small hc_mult3.
  • Add --hc_head to op_tests/test_mhc.py to run the pre-only configuration.
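
The relaxed shape rule from the first bullet can be sketched in isolation. This is a minimal illustration based on the assertion quoted later in this PR; `expected_fn_rows` is a hypothetical helper name, not a function in aiter.

```python
def expected_fn_rows(hc_mult: int, sinkhorn_repeat: int) -> set[int]:
    """Row counts of `fn` (hc_mult3) that the relaxed assertion accepts."""
    # Full path: hc_mult pre rows + hc_mult post rows + hc_mult * hc_mult comb rows.
    allowed = {hc_mult * 2 + hc_mult * hc_mult}
    # hc_head (pre-only) path: only the hc_mult pre rows, and only when
    # Sinkhorn is disabled via sinkhorn_repeat == 0.
    if sinkhorn_repeat == 0:
        allowed.add(hc_mult)
    return allowed
```

With hc_mult = 4, the full path expects 24 rows, and the pre-only mode additionally accepts 4 rows when sinkhorn_repeat is 0.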

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| op_tests/test_mhc.py | Adds --hc_head flag and adjusts reference/test logic to support pre-only validation. |
| csrc/kernels/mhc_kernels.cu | Updates kernel dispatch and gates post/comb computation on sinkhorn_repeat > 0. |
| aiter/ops/mhc.py | Adds defaults and allows hc_head mode via sinkhorn_repeat == 0 + relaxed hc_mult3 assertion. |
Comments suppressed due to low confidence (2)

csrc/kernels/mhc_kernels.cu:253

  • In MHC_PRE_GEMM_SQRSUM_KERNEL_DISPATCH, the condition else if (tile_k == 128 || hc_mult3 <= 16) can bypass the tile_k validation: if a caller passes an unsupported tile_k while hc_mult3 <= 16, it will still dispatch the 128-kernel instead of throwing. Please keep the tile_k check strict (64 vs 128) and handle the hc_mult3 <= 16 special-casing inside the valid branches so invalid tile_k still errors out deterministically.
#define MHC_PRE_GEMM_SQRSUM_KERNEL_DISPATCH(tile_k) \
    if (tile_k == 64) { \
        if (cu_num * 2 > m_blocks * split_k || hc_mult3 <= 16) { \
            MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 16, 64); \
        } else { \
            MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 32, 64); \
        } \
    } else if (tile_k == 128 || hc_mult3 <= 16) { \
        if (cu_num > m_blocks * split_k) { \
            MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 16, 128); \
        } else { \
            MHC_PRE_GEMM_SQRSUM_KERNEL_IMPL(256, 32, 128); \
        } \
    } else { \
        TORCH_CHECK(false, "tile_k must be 64 or 128"); \
    }
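
The stricter ordering suggested above can be modeled in Python to make the control flow easy to check. This is a sketch, not the HIP macro: `select_gemm_sqrsum_tile` is a hypothetical name, the returned tuple simply echoes the three macro arguments, and treating hc_mult3 <= 16 in the 128 branch the same way as in the 64 branch is an assumption about the intended behavior.

```python
def select_gemm_sqrsum_tile(
    tile_k: int, cu_num: int, m_blocks: int, split_k: int, hc_mult3: int
) -> tuple[int, int, int]:
    # Validate first, so an unsupported tile_k always fails,
    # regardless of how small hc_mult3 is.
    if tile_k not in (64, 128):
        raise ValueError("tile_k must be 64 or 128")

    work_items = m_blocks * split_k
    small_hc = hc_mult3 <= 16

    if tile_k == 64:
        # Same condition as the 64 branch of the macro.
        use_16 = cu_num * 2 > work_items or small_hc
    else:
        # Assumption: mirror the 64 branch by folding the small-hc_mult3
        # case into the inner condition instead of the outer one.
        use_16 = cu_num > work_items or small_hc

    return (256, 16 if use_16 else 32, tile_k)
```

In this form the hc_mult3 special case only affects which valid variant is chosen; it can no longer mask an invalid tile_k.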

op_tests/test_mhc.py:398

  • mhc_pre_ref() is annotated to return tuple[torch.Tensor, torch.Tensor, torch.Tensor], but in the test_hc_head path it returns None for post_mix and res_mix. Please update the type annotation to reflect the actual behavior (e.g., use Optional[torch.Tensor] for the first two elements) to avoid misleading type hints and downstream tooling issues.
# copy from tilelang/examples/deepseek_mhc/example_mhc_pre.py
def mhc_pre_ref(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    test_hc_head: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    hc_mult = residual.shape[-2]

    residual_flat = residual.flatten(-2, -1).float()
    sqrsum = residual_flat.square().sum(-1)
    out = residual_flat @ fn.T
    mixes = out * (sqrsum.unsqueeze(-1) / fn.shape[-1] + rms_eps).rsqrt()

    if not test_hc_head:
        hc_scale = torch.cat(
            [
                hc_scale[0].expand(hc_mult),
                hc_scale[1].expand(hc_mult),
                hc_scale[2].expand(hc_mult * hc_mult),
            ],
        )
        mixes = mixes * hc_scale + hc_base

        pre_mix = mixes[:, :hc_mult].sigmoid().unsqueeze(-1) + hc_pre_eps
        post_mix = (
            mixes[:, hc_mult : 2 * hc_mult].sigmoid() * hc_post_mult_value
        ).unsqueeze(-1)
        res_mix = mixes[:, 2 * hc_mult :].view(-1, hc_mult, hc_mult)

        def sinkhorn_normalize_ref(
            x: torch.Tensor, repeat: int, eps: float
        ) -> torch.Tensor:
            x = x.softmax(-1) + eps
            x = x / (x.sum(-2, keepdim=True) + eps)
            for _ in range(repeat - 1):
                x = x / (x.sum(-1, keepdim=True) + eps)
                x = x / (x.sum(-2, keepdim=True) + eps)
            return x

        res_mix = sinkhorn_normalize_ref(
            res_mix, repeat=sinkhorn_repeat, eps=hc_sinkhorn_eps
        )
    else:
        hc_scale = hc_scale[0].expand(hc_mult)
        mixes = mixes * hc_scale + hc_base
        pre_mix = mixes[:, :hc_mult].sigmoid().unsqueeze(-1) + hc_pre_eps
        post_mix = None
        res_mix = None

    layer_input = (residual * pre_mix).sum(-2).bfloat16()

    return post_mix, res_mix, layer_input
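
The nested sinkhorn_normalize_ref above alternates row and column normalization after a row-wise softmax. A dependency-free sketch on nested lists, assuming a single square matrix rather than a batched tensor, shows the same sequence of steps; the helper names are illustrative only.

```python
import math


def softmax_rows(x: list[list[float]]) -> list[list[float]]:
    out = []
    for row in x:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def normalize_rows(x: list[list[float]], eps: float) -> list[list[float]]:
    # Counterpart of x / (x.sum(-1, keepdim=True) + eps)
    return [[v / (sum(row) + eps) for v in row] for row in x]


def normalize_cols(x: list[list[float]], eps: float) -> list[list[float]]:
    # Counterpart of x / (x.sum(-2, keepdim=True) + eps)
    col_sums = [sum(col) for col in zip(*x)]
    return [[v / (col_sums[j] + eps) for j, v in enumerate(row)] for row in x]


def sinkhorn_normalize(
    x: list[list[float]], repeat: int, eps: float
) -> list[list[float]]:
    # Softmax over the last axis, then shift by eps.
    x = [[v + eps for v in row] for row in softmax_rows(x)]
    # The first pass normalizes columns only; later passes do rows, then columns.
    x = normalize_cols(x, eps)
    for _ in range(repeat - 1):
        x = normalize_rows(x, eps)
        x = normalize_cols(x, eps)
    return x
```

For a small matrix with strictly positive entries after the softmax, the result is close to doubly stochastic after a handful of repeats: every row sum and column sum approaches 1, up to the eps terms in the denominators.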


Comment thread aiter/ops/mhc.py
Comment on lines +47 to 60
+    rms_eps: float = 1e-6,
+    hc_pre_eps: float = 1e-6,
+    hc_sinkhorn_eps: float = 1e-6,
+    hc_post_mult_value: float = 1.0,
+    sinkhorn_repeat: int = 20,  # if 0, only do pre for hc_head
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     m = residual.size(0)
     hc_mult = residual.size(1)
     hidden_size = residual.size(2)
     hc_mult3 = fn.size(0)
-    assert hc_mult3 == hc_mult * 2 + hc_mult * hc_mult
+    assert hc_mult3 == hc_mult * 2 + hc_mult * hc_mult or (
+        hc_mult3 == hc_mult and sinkhorn_repeat == 0
+    )
     hc_hidden_size = hc_mult * hidden_size
Comment thread aiter/ops/mhc.py
Comment on lines +51 to +59
+    sinkhorn_repeat: int = 20,  # if 0, only do pre for hc_head
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     m = residual.size(0)
     hc_mult = residual.size(1)
     hidden_size = residual.size(2)
     hc_mult3 = fn.size(0)
-    assert hc_mult3 == hc_mult * 2 + hc_mult * hc_mult
+    assert hc_mult3 == hc_mult * 2 + hc_mult * hc_mult or (
+        hc_mult3 == hc_mult and sinkhorn_repeat == 0
+    )
@valarLip valarLip merged commit ee62170 into main May 6, 2026
32 checks passed
@valarLip valarLip deleted the jun/hc_head branch May 6, 2026 10:31

3 participants