[Enhancement] Support mixed-sign ramp indices in LegalizeNegativeIndex #2222

Closed
TerminusAkivili wants to merge 1 commit into tile-ai:main from TerminusAkivili:codex/bugfix-mixed-negative-ramp

Conversation

@TerminusAkivili

@TerminusAkivili TerminusAkivili commented May 19, 2026

Contributor

Summary

  • Extend LegalizeNegativeIndex to handle constant Ramp indices whose lanes mix negative and non-negative values.
  • Preserve existing negative-index semantics by rewriting only the lanes proven negative and keeping non-negative lanes unchanged.
  • Widen affected block read/write regions conservatively when the original region starts below zero.

Motivation

TileLang already legalizes Python-style negative indices in scalar and all-negative vector cases, for example:

A[-1]
A[T.Ramp(-4, 1, 4)]

This PR extends the same behavior to the mixed-sign constant ramp case:

A[T.Ramp(-2, 1, 4)]  # lanes: [-2, -1, 0, 1]

The existing pass classifies the whole ramp as unknown-sign, because the vector expression is neither fully negative nor fully non-negative. As a result, the negative lanes are left unlegalized even though each lane is constant and can be handled independently.
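
To make the classification concrete, here is a minimal pure-Python sketch of how a whole-vector sign check ends up at "unknown" for a mixed-sign ramp. The helpers `ramp_lanes` and `classify_sign` are hypothetical names that only model the behavior described above; they are not part of TileLang or TVM.

```python
from typing import List


def ramp_lanes(base: int, stride: int, lanes: int) -> List[int]:
    """Expand a constant ramp into its per-lane index values."""
    return [base + i * stride for i in range(lanes)]


def classify_sign(values: List[int]) -> str:
    """Classify the vector as a whole (hypothetical model of a whole-vector check)."""
    if all(v < 0 for v in values):
        return "negative"
    if all(v >= 0 for v in values):
        return "non_negative"
    # Neither fully negative nor fully non-negative.
    return "unknown"


all_negative = ramp_lanes(-4, 1, 4)  # [-4, -3, -2, -1]
mixed = ramp_lanes(-2, 1, 4)         # [-2, -1, 0, 1]

print(classify_sign(all_negative))   # negative
print(classify_sign(mixed))          # unknown
```

Every lane of the mixed ramp has a known sign on its own, which is what makes a per-lane rewrite possible even though the whole-vector answer is "unknown".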

Kernel Reproducer

Kernels often split a wrapped circular-buffer window into two contiguous accesses, which is usually the faster lowering shape. This example is not meant to prescribe a preferred kernel style; it is a compact reproducer for the existing negative-index semantics when a vectorized load crosses a circular-buffer boundary.

Complete reproducer:

import torch
import tilelang
from tilelang import tvm as tvm
import tilelang.language as T


CACHE_LEN = 1024
WINDOW = 4
WRAP_LEFT = 2


@T.prim_func
def sliding_window_kv_cache_unwrap(
    K_cache: T.Tensor((CACHE_LEN,), T.float32),
    K_window: T.Tensor((WINDOW,), T.float32),
):
    """Gather a KV window that crosses the circular-cache boundary.

    Logical window:
      K_cache[-2], K_cache[-1], K_cache[0], K_cache[1]

    Expected physical slots:
      K_cache[1022], K_cache[1023], K_cache[0], K_cache[1]
    """
    with T.Kernel(1, threads=1) as _:
        logical_window = T.Ramp(-WRAP_LEFT, 1, WINDOW)
        K_window[T.Ramp(0, 1, WINDOW)] = K_cache[logical_window]


def show_lowered_ir():
    mod = tvm.IRModule.from_expr(
        sliding_window_kv_cache_unwrap.with_attr("global_symbol", "main")
    )
    lowered = tilelang.transform.LegalizeNegativeIndex()(mod)["main"]
    print(lowered.script())


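def run_kernel():
    # Place WRAP_LEFT sentinel slots in front of the cache. If the negative
    # lanes are left unlegalized, K_cache[-2] and K_cache[-1] read these
    # sentinels (-2000.0, -1000.0) instead of the wrapped slots 1022 and 1023.
    backing = torch.empty(CACHE_LEN + WRAP_LEFT, device="cuda", dtype=torch.float32)
    backing[0] = -2000.0
    backing[1] = -1000.0
    backing[WRAP_LEFT:] = torch.arange(CACHE_LEN, device="cuda", dtype=torch.float32)

    # The kernel sees only the CACHE_LEN-element view that starts after the sentinels.
    k_cache_view = backing[WRAP_LEFT:]
    k_window = torch.empty(WINDOW, device="cuda", dtype=torch.float32)
    expected = torch.tensor([1022.0, 1023.0, 0.0, 1.0], device="cuda")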

    kernel = tilelang.compile(
        sliding_window_kv_cache_unwrap,
        target="cuda",
        execution_backend="nvrtc",
    )
    kernel(k_cache_view, k_window)
    torch.cuda.synchronize()

    print("actual  :", k_window.detach().cpu().tolist())
    print("expected:", expected.detach().cpu().tolist())
    assert torch.allclose(k_window, expected)


if __name__ == "__main__":
    show_lowered_ir()
    if torch.cuda.is_available():
        run_kernel()

Before this patch, the transform leaves the mixed-sign access as an unlegalized negative range/ramp:

T.reads(K_cache[-2:2])
K_window[0:4] = K_cache[logical_window]

After this patch, the same transform produces the lane-wise legalized indices:

T.reads(K_cache[0:1024])
K_window[0:4] = K_cache[T.Shuffle([1022, 1023, 0, 1], [0, 1, 2, 3])]

This keeps the transform behavior consistent across scalar negative indices, all-negative vector ramps, and constant mixed-sign vector ramps.
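
The change in the read region from `K_cache[-2:2]` to `K_cache[0:1024]` can be modeled with a small pure-Python sketch. It assumes the conservative rule stated in the summary: a region whose start is below zero is replaced by the full buffer extent, and any other region is kept as-is. The function name `widen_region` is hypothetical and does not correspond to a TileLang API.

```python
from typing import Tuple


def widen_region(start: int, stop: int, extent: int) -> Tuple[int, int]:
    """Return a conservative [start, stop) region for a 1-D buffer of size `extent`.

    Assumption: a region that starts below zero may wrap to the end of the
    buffer, so it is widened to cover the whole buffer.
    """
    if start < 0:
        return (0, extent)
    return (start, stop)


print(widen_region(-2, 2, 1024))  # (0, 1024)
print(widen_region(0, 4, 1024))   # (0, 4)
```

Widening to the full extent is a natural conservative choice here: the legalized lanes touch slots 1022, 1023, 0, and 1, and the smallest single contiguous range that covers all four is the whole buffer.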

Implementation

When the sign state is unknown, the rewriter now tries a narrow lane-wise rewrite for constant Ramp indices:

  • If a lane is provably negative, rewrite it as buffer_extent + lane.
  • If a lane is provably non-negative, keep it unchanged.
  • If any lane remains unknown, leave the original expression unchanged.

The rewritten vector is represented as T.Shuffle, which preserves the per-lane indexing semantics without changing truly dynamic unknown-sign cases.
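
The three rules above can be sketched in plain Python. This is an illustrative model under stated assumptions, not the C++ code in the patch: lanes are given as a list where `None` stands in for a lane whose sign cannot be proven, and `legalize_lanes` is a hypothetical name.

```python
from typing import List, Optional, Sequence


def legalize_lanes(lanes: Sequence[Optional[int]], extent: int) -> Optional[List[int]]:
    """Rewrite each lane independently; return None to mean "leave unchanged".

    Assumption: `None` models a lane whose sign is unknown.
    """
    rewritten: List[int] = []
    for lane in lanes:
        if lane is None:
            # Any unknown lane aborts the rewrite for the whole vector.
            return None
        if lane < 0:
            # Provably negative lane: wrap Python-style from the end.
            rewritten.append(extent + lane)
        else:
            # Provably non-negative lane: keep as-is.
            rewritten.append(lane)
    return rewritten


print(legalize_lanes([-2, -1, 0, 1], 1024))     # [1022, 1023, 0, 1]
print(legalize_lanes([-2, None, 0, 1], 1024))   # None
```

The first result matches the lane values in the `T.Shuffle` shown in the "after" IR above. A contiguous ramp cannot express `[1022, 1023, 0, 1]`, which is why an explicit per-lane vector is needed for the rewritten index.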

Summary by CodeRabbit

  • New Features

    • Support for legalization of mixed-sign vector indices in buffer load/store operations.
  • Bug Fixes

    • Improved handling of negative lanes in vectorized buffer access and block-level region metadata.
  • Tests

    • Added tests covering mixed-sign vector index scenarios for loads and stores (including kernel contexts).

@github-actions


👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai

coderabbitai Bot commented May 19, 2026

Contributor

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8d0bd46f-f4e0-40e3-87a9-4b0827c41459

📥 Commits

Reviewing files that changed from the base of the PR and between 83eaef1 and 50f210b.

📒 Files selected for processing (2)
  • src/transform/legalize_negative_index.cc
  • testing/python/transform/test_tilelang_transform_legalize_negative_index.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • testing/python/transform/test_tilelang_transform_legalize_negative_index.py
  • src/transform/legalize_negative_index.cc

📝 Walkthrough

The PR extends the LegalizeNegativeIndex pass to handle mixed-sign vector ramp indices and block-level buffer region metadata. New helpers detect and rewrite ramp lanes that are provably negative through shuffle reconstruction, while a block visitor applies region-level legalization to buffer metadata.

Changes

Mixed-Sign Ramp Index Legalization

Layer / File(s) and summary:

  • Mixed-sign ramp legalization helpers (src/transform/legalize_negative_index.cc): TryRewriteMixedRamp attempts per-lane legalization of mixed-sign Ramp indices by comparing against buffer extent bounds and reconstructing via Shuffle; UpdateRegion scans buffer region ranges and replaces provably-negative regions with full regions.
  • UpdateIdx integration with mixed-ramp rewriting (src/transform/legalize_negative_index.cc): UpdateIdx is extended to try TryRewriteMixedRamp for kUnknown index axes, while kNegative indices continue using the existing buffer_shape[i] + index rewrite.
  • Block-level buffer region metadata mutation (src/transform/legalize_negative_index.cc): New VisitStmt_ override for BlockNode mutates block-level reads and writes buffer regions using UpdateRegion, enabling region-level negative-index legalization beyond individual load/store expressions.
  • Test coverage for mixed-sign ramp indices (testing/python/transform/test_tilelang_transform_legalize_negative_index.py): Three new tests validate mixed-sign ramp legalization: two load tests (standalone and within T.Kernel block) and one store test, all verifying index remapping via shuffled positive-index vectors.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • tile-ai/tilelang#1354: Both PRs modify LegalizeNegativeIndex's ramp/broadcast index sign handling, aligning transformation reasoning for when vector ramp lanes are provably negative/non-negative.
  • tile-ai/tilelang#1207: Both PRs improve legality rewriting for vector Ramp indices in NegativeIndexRewriter through enhanced sign analysis.
  • tile-ai/tilelang#1339: Both PRs modify core NegativeIndexRewriter/UpdateIdx logic to rewrite negative indices for buffer accesses; this PR further adds mixed-sign ramp and block region rewriting.

Suggested reviewers

  • kurisu6912
  • LeiWang1999

Poem

🐰 I hopped through ramps of mixed-sign lanes,

Shuffled lanes to heal the negative pains.
Blocks now widen where minima fall,
So indices stay lawful, one and all.
Hooray for vectors that answer the call!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 36.84% which is insufficient. The required threshold is 80.00%. Resolution: Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title check: ✅ Passed. The title accurately reflects the main enhancement: adding support for mixed-sign ramp indices in the LegalizeNegativeIndex transform, which is the core change across both the implementation and test files.
  • Linked Issues check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: ✅ Passed. Check skipped because no linked issues were found for this pull request.

@TerminusAkivili TerminusAkivili marked this pull request as draft May 19, 2026 05:40

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (1)
src/transform/legalize_negative_index.cc (1)

173-173: 💤 Low value

Consider adding a null check for non-constant lanes.

as_const_int returns nullptr if ramp->lanes is not a constant integer. While the existing analyzer code (line 68) uses the same pattern, dereferencing without a check is technically UB if a symbolic lanes value ever appears.

Suggested defensive check
-    int lanes = *as_const_int(ramp->lanes);
+    const int64_t* lanes_ptr = as_const_int(ramp->lanes);
+    if (lanes_ptr == nullptr)
+      return PrimExpr();
+    int lanes = static_cast<int>(*lanes_ptr);
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/transform/legalize_negative_index.cc` at line 173, The code dereferences
the result of as_const_int(ramp->lanes) without checking for nullptr (int lanes
= *as_const_int(ramp->lanes)); add a defensive null check around the
as_const_int call (referencing ramp->lanes and as_const_int) and handle the
non-constant case by skipping the legalization path or returning/continuing
appropriately (e.g., treat it as non-constant and leave the node unchanged or
bail out from the current transformation); ensure you avoid UB by only
dereferencing when the pointer is non-null and preserve existing analyzer
behavior used elsewhere in the file.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c8110274-8ddc-4c65-a03c-0f541296c2fa

📥 Commits

Reviewing files that changed from the base of the PR and between f11954c and 83eaef1.

📒 Files selected for processing (2)
  • src/transform/legalize_negative_index.cc
  • testing/python/transform/test_tilelang_transform_legalize_negative_index.py

@TerminusAkivili TerminusAkivili force-pushed the codex/bugfix-mixed-negative-ramp branch from 83eaef1 to 50f210b Compare May 19, 2026 08:17
@TerminusAkivili TerminusAkivili marked this pull request as ready for review May 19, 2026 08:27
@TerminusAkivili TerminusAkivili deleted the codex/bugfix-mixed-negative-ramp branch May 19, 2026 08:40