[Enhancement] Support mixed-sign ramp indices in LegalizeNegativeIndex #2225

Merged
LeiWang1999 merged 2 commits into tile-ai:main from TerminusAkivili:enhance/legalize-mixed-sign-ramp
May 19, 2026

TerminusAkivili commented May 19, 2026
Contributor

Summary

  • Extend LegalizeNegativeIndex to handle constant Ramp indices with mixed negative and non-negative lanes.
  • Rewrite only lanes proven negative, preserving non-negative lanes.
  • Conservatively widen affected block read/write regions when the original region starts below zero.

Motivation

TileLang already legalizes Python-style negative indices in scalar and all-negative vector cases:

A[-1]
A[T.Ramp(-4, 1, 4)]

This PR extends the same behavior to constant mixed-sign ramps:

A[T.Ramp(-2, 1, 4)]  # lanes: [-2, -1, 0, 1]

The current pass classifies the whole ramp as unknown-sign and leaves it unchanged, even though each lane is constant and can be legalized independently.
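The three cases above (all-negative, all-non-negative, and mixed-sign) can be sketched with plain Python integers. This is only an illustration: `RampSign`, `ramp_lanes`, and `classify_ramp` are hypothetical names, and the actual pass reasons about TIR expressions in C++ rather than Python lists.

```python
from enum import Enum


class RampSign(Enum):
    """Hypothetical sign classes for a constant ramp (not the pass's own type)."""
    NEGATIVE = "negative"          # every lane < 0
    NON_NEGATIVE = "non_negative"  # every lane >= 0
    MIXED = "mixed"                # lanes straddle zero


def ramp_lanes(base: int, stride: int, lanes: int) -> list[int]:
    """Expand a constant ramp into its per-lane integer values."""
    return [base + i * stride for i in range(lanes)]


def classify_ramp(base: int, stride: int, lanes: int) -> RampSign:
    """Classify a constant ramp by inspecting every lane."""
    values = ramp_lanes(base, stride, lanes)
    if all(v < 0 for v in values):
        return RampSign.NEGATIVE
    if all(v >= 0 for v in values):
        return RampSign.NON_NEGATIVE
    return RampSign.MIXED


print(ramp_lanes(-2, 1, 4))     # [-2, -1, 0, 1]
print(classify_ramp(-4, 1, 4))  # RampSign.NEGATIVE
print(classify_ramp(-2, 1, 4))  # RampSign.MIXED
```

Because every lane of a constant ramp is known, the mixed case still carries enough information to legalize each lane on its own.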

Kernel

This minimal circular-buffer example exercises a vectorized load whose constant ramp crosses the zero boundary.

import torch
import tilelang
from tilelang import tvm as tvm
import tilelang.language as T


CACHE_LEN = 1024
WINDOW = 4
WRAP_LEFT = 2


@T.prim_func
def sliding_window_kv_cache_unwrap(
    K_cache: T.Tensor((CACHE_LEN,), T.float32),
    K_window: T.Tensor((WINDOW,), T.float32),
):
    """Gather a KV window that crosses the circular-cache boundary.

    Logical window:
      K_cache[-2], K_cache[-1], K_cache[0], K_cache[1]

    Expected physical slots:
      K_cache[1022], K_cache[1023], K_cache[0], K_cache[1]
    """
    with T.Kernel(1, threads=1) as _:
        logical_window = T.Ramp(-WRAP_LEFT, 1, WINDOW)
        K_window[T.Ramp(0, 1, WINDOW)] = K_cache[logical_window]


def show_lowered_ir():
    mod = tvm.IRModule.from_expr(
        sliding_window_kv_cache_unwrap.with_attr("global_symbol", "main")
    )
    lowered = tilelang.transform.LegalizeNegativeIndex()(mod)["main"]
    print(lowered.script())


def run_kernel():
    backing = torch.empty(CACHE_LEN + WRAP_LEFT, device="cuda", dtype=torch.float32)
    backing[0] = -2000.0
    backing[1] = -1000.0
    backing[WRAP_LEFT:] = torch.arange(CACHE_LEN, device="cuda", dtype=torch.float32)

    k_cache_view = backing[WRAP_LEFT:]
    k_window = torch.empty(WINDOW, device="cuda", dtype=torch.float32)
    expected = torch.tensor([1022.0, 1023.0, 0.0, 1.0], device="cuda")

    kernel = tilelang.compile(
        sliding_window_kv_cache_unwrap,
        target="cuda",
        execution_backend="nvrtc",
    )
    kernel(k_cache_view, k_window)
    torch.cuda.synchronize()

    print("actual  :", k_window.detach().cpu().tolist())
    print("expected:", expected.detach().cpu().tolist())
    assert torch.allclose(k_window, expected)


if __name__ == "__main__":
    show_lowered_ir()
    if torch.cuda.is_available():
        run_kernel()
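
The two sentinel values in `run_kernel` sit directly in front of the cache view, so a read that is not legalized would presumably pick them up instead of the wrapped slots. A CPU-only model of that setup, using a plain list in place of the CUDA tensor, might look like the sketch below. It assumes an unlegalized negative index behaves like a raw offset from the view's base pointer; `read_unlegalized` and `read_legalized` are hypothetical helpers, not TileLang functions.

```python
CACHE_LEN = 1024
WRAP_LEFT = 2

# Two sentinels sit directly in front of the cache data, as in run_kernel().
backing = [-2000.0, -1000.0] + [float(i) for i in range(CACHE_LEN)]


def read_unlegalized(lane: int) -> float:
    """Model a raw offset from the view's base pointer (assumed behavior)."""
    return backing[WRAP_LEFT + lane]


def read_legalized(lane: int) -> float:
    """Model Python-style wrapping: negative lanes index from the end."""
    physical = lane + CACHE_LEN if lane < 0 else lane
    return backing[WRAP_LEFT + physical]


logical_window = [-2, -1, 0, 1]
print([read_unlegalized(i) for i in logical_window])  # [-2000.0, -1000.0, 0.0, 1.0]
print([read_legalized(i) for i in logical_window])    # [1022.0, 1023.0, 0.0, 1.0]
```

Under that assumption, the sentinels make a missed legalization fail the `allclose` check loudly rather than read arbitrary memory.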

Before this patch, the mixed-sign access remains unlegalized:

T.reads(K_cache[-2:2])
K_window[0:4] = K_cache[logical_window]

After this patch, the transform rewrites the lanes independently:

T.reads(K_cache[0:1024])
K_window[0:4] = K_cache[T.Shuffle([1022, 1023, 0, 1], [0, 1, 2, 3])]
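
The change in `T.reads` from `K_cache[-2:2]` to `K_cache[0:1024]` reflects the conservative region widening. A minimal model of that rule, assuming each axis of a region is a `(min, extent)` pair of plain integers (the real pass works on symbolic minima and only widens when the minimum is provably negative), could be written as follows. `widen_region` is a hypothetical name.

```python
def widen_region(region: list[tuple[int, int]], shape: list[int]) -> list[tuple[int, int]]:
    """Return the full buffer region if any axis starts below zero.

    region: one (min, extent) pair per axis.
    shape:  the buffer extent along each axis.
    """
    if any(axis_min < 0 for axis_min, _ in region):
        # Conservative choice: cover the whole buffer on every axis.
        return [(0, dim) for dim in shape]
    return list(region)


print(widen_region([(-2, 4)], [1024]))  # [(0, 1024)]
print(widen_region([(0, 4)], [1024]))   # [(0, 4)]
```

Widening to the full buffer over-approximates the two disjoint pieces actually touched (`[1022, 1024)` and `[0, 2)`), which keeps the region a single contiguous range.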

Implementation

For unknown-sign constant ramps, the rewriter now checks each lane:

  • negative lane: rewrite to buffer_extent + lane
  • non-negative lane: keep unchanged
  • unknown lane: leave the original expression unchanged

The rewritten vector is emitted as T.Shuffle, preserving per-lane semantics while leaving truly dynamic unknown-sign cases untouched. Existing scalar-negative and all-negative vector paths are unchanged.
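
The per-lane rule can be sketched in plain Python. This is an illustration under simplifying assumptions, not the C++ implementation: lanes are modeled as integers, `None` stands in for a lane whose sign cannot be proven, and `rewrite_lanes` is a hypothetical name.

```python
from typing import Optional


def rewrite_lanes(lanes: list[Optional[int]], buffer_extent: int) -> Optional[list[int]]:
    """Legalize each lane independently.

    Returns the rewritten lane list, or None when any lane has an
    unknown sign, meaning the original index should be left unchanged.
    """
    rewritten = []
    for lane in lanes:
        if lane is None:
            return None  # unknown lane: give up and keep the original expression
        if lane < 0:
            rewritten.append(buffer_extent + lane)  # negative lane: wrap from the end
        else:
            rewritten.append(lane)  # non-negative lane: keep unchanged
    return rewritten


print(rewrite_lanes([-2, -1, 0, 1], 1024))    # [1022, 1023, 0, 1]
print(rewrite_lanes([-2, None, 0, 1], 1024))  # None
```

The first result matches the lane values inside the `T.Shuffle` shown in the lowered IR for the example kernel.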

Summary by CodeRabbit

  • Bug Fixes

    • Improved legalization of mixed-sign vector indices for loads and stores, ensuring correct handling in vectorized operations and preserving the integer width of buffer extents.
    • Block-level buffer regions with provably negative axis minima are conservatively expanded to full regions to prevent incorrect accesses.
  • Tests

    • Added tests covering mixed-sign vectorized load/store legalization, including cases inside kernels and int64 extents.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai Bot commented May 19, 2026

Contributor

No actionable comments were generated in the recent review. 🎉

📥 Commits

Reviewing files that changed from the base of the PR and between 50f210b and 3593f67.

📒 Files selected for processing (2)
  • src/transform/legalize_negative_index.cc
  • testing/python/transform/test_tilelang_transform_legalize_negative_index.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/transform/legalize_negative_index.cc

Walkthrough

Rewriter now legalizes mixed-sign vector Ramp indices by per-lane expansion and Shuffle-based reconstruction, updates Block buffer regions to FullRegion when axis minima are provably negative, and adds tests covering vectorized load/store cases (including int64 extent preservation and Kernel context).

Changes

Mixed-sign ramp index legalization

| Layer / File(s) | Summary |
| --- | --- |
| Ramp rewrite & region helper (src/transform/legalize_negative_index.cc) | Adds TryRewriteMixedRamp to expand a Ramp into per-lane expressions, adjust proven-negative lanes with buffer_extent, and reconstruct the vector via Shuffle; adds UpdateRegion to widen a BufferRegion to FullRegion when any axis min is provably negative. |
| Index update integration (src/transform/legalize_negative_index.cc) | Extends UpdateIdx so IndexSignState::kUnknown attempts TryRewriteMixedRamp and replaces the index with the rewritten vector when successful; kNegative still uses buffer_shape + index. |
| Block visitor region mutation (src/transform/legalize_negative_index.cc) | Overrides VisitStmt_(const BlockNode*) to mutate Block reads and writes in place by applying UpdateRegion, widening regions with provably negative minima to full regions. |
| Vectorized load/store tests (testing/python/transform/test_tilelang_transform_legalize_negative_index.py) | Adds test_buffer_load_vector_index_mixed_sign_ramp, test_buffer_load_vector_index_mixed_sign_ramp_preserves_int64_extent, test_buffer_load_vector_index_mixed_sign_ramp_in_kernel, and test_buffer_store_vector_index_mixed_sign_ramp validating Shuffle-based legalization and int64 extent preservation. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • kurisu6912
  • LeiWang1999

Poem

A rabbit peeks at ramps that swing both ways,
Splits lanes, shuffles, counts the negative days.
Each lane wears buffer_extent like a hat,
Blocks grow wide regions—no lane falls flat.
Hooray for shuffled indices—hip-hop, hooray! 🐰✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 36.36% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and specifically describes the main enhancement: adding support for mixed-sign ramp indices in the LegalizeNegativeIndex transform, which directly matches the primary change in the changeset. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

@LeiWang1999

Member

Thanks!

LeiWang1999 merged commit a7d9380 into tile-ai:main on May 19, 2026
4 of 6 checks passed
TerminusAkivili deleted the enhance/legalize-mixed-sign-ramp branch on May 31, 2026, 14:22