
Fix atomic_load access_ptr lowering for dynamic indices #2157

Merged
LeiWang1999 merged 11 commits into tile-ai:main from VitalyAnkh:fix/issue-2123-access-ptr-helper on May 25, 2026

@VitalyAnkh

@VitalyAnkh VitalyAnkh commented May 6, 2026

Contributor

Summary

Validation

./format.sh passed pre-commit hooks; the script reported unstaged changes before commit.

./format.sh
cmake -S . -B build
cmake --build build -j$(nproc)
PYTHONPATH=$(pwd) python -m pytest testing/python/language/test_tilelang_language_reduce.py -q
PYTHONPATH=$(pwd) python -m pytest testing/python/language/test_tilelang_language_atomic.py -q -k "test_atomic_add or test_tile_atomic_add or test_tma_atomic_add"
PYTHONPATH=$(pwd) python -m pytest testing/python/language/test_tilelang_language_reduce_maxmin_nan.py -q
PYTHONPATH=$(pwd) python -m pytest testing/python/issue/test_tilelang_issue_tma_no_ws.py::test_sparse_ws_regular_metadata_copy_stays_in_producer testing/python/issue/test_tilelang_issue_ws_simt_copy_full_producer_extent.py::test_ws_keeps_full_producer_extent_for_lowered_simt_copy -q

Summary by CodeRabbit

  • Bug Fixes
    • Improved atomic load lowering and access pointer transformations for enhanced correctness in memory access handling
  • Tests
    • Added comprehensive test coverage for atomic load operations across different execution contexts

VitalyAnkh added 2 commits May 7, 2026 01:53
Add the issue 2123 regression test and keep the access_ptr base BufferLoad intact through safe-memory legalization and lowering.

Closes tile-ai#2123
Move shared tl.access_ptr BufferLoad base visitation into a common transform helper so LowerAccessPtr and safe-memory legalization stay aligned.

Extend the issue 2123 regression with a direct LowerAccessPtr pass check while keeping the LowerAndLegalize pipeline coverage.
@coderabbitai

coderabbitai Bot commented May 6, 2026

Contributor

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Walkthrough

Adds a reusable helper to visit/rebuild a BufferLoad's indices and predicate for tl.access_ptr, integrates it into lowering and safe-memory legalization, updates visitors to recurse uniformly, and adds a regression test validating the lowering of tl.access_ptr to tir.tvm_access_ptr for an atomic_load scenario.

Changes

Access Pointer Lowering and Legalization Fix

Layer: Core Helper
File(s): src/transform/common/access_ptr_utils.h
Summary: New templated VisitAccessPtrBase that validates a PrimExpr as a tir::BufferLoad, visits each index and optional predicate with a supplied visitor, and returns either the original or a rebuilt BufferLoad.

Layer: Legalization Integration
File(s): src/transform/legalize_safe_memory_access.cc
Summary: Include common/access_ptr_utils.h; add SafeMemorysRewriter(arith::Analyzer*) constructor; add VisitAccessPtrBase helper and override VisitExpr_(const CallNode *op) to special-case tl.access_ptr by validating 3 args, transforming the base via the helper, visiting other args, and reconstructing the call; non-access_ptr calls delegated to base visitor.

Layer: Lowering Integration
File(s): src/transform/lower_access_ptr.cc
Summary: Include common/access_ptr_utils.h; refactor AccessPtrLowerer::VisitExpr_ to early-return for non-tl.access_ptr, ensure 3 args, obtain/transform base BufferLoad via detail::VisitAccessPtrBase (using an inline visitor), visit extent and rw_mask with VisitExpr, and add a private VisitAccessPtrBase mirroring the utility logic.

Layer: Tests
File(s): testing/python/issue/test_tilelang_issue_2123.py
Summary: New repro issue_2123_atomic_load_repro() and two tests asserting tl.access_ptr is lowered to tir.tvm_access_ptr: direct lowering via LowerAccessPtr and pipeline lowering via LowerAndLegalize (CUDA target). Includes helpers to detect op calls and assert expected lowering.
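
The visit-and-rebuild behavior summarized above can be sketched in plain Python. This is only an illustration under stated assumptions: the `Var`, `BufferLoad`, and `Call` classes are toy stand-ins for TVM's IR nodes, and `rewrite` and `contains_call` are hypothetical helpers, not the code in this PR. It shows a base `BufferLoad` whose indices are visited while the node itself stays a `BufferLoad`, and a call visitor that special-cases `tl.access_ptr`.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


# Toy IR nodes (assumptions for illustration, not TVM classes).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class BufferLoad:
    buffer: str
    indices: Tuple[object, ...]
    predicate: Optional[object] = None


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[object, ...]


def visit_access_ptr_base(expr, visit_expr: Callable[[object], object]) -> BufferLoad:
    """Visit the indices and optional predicate; always return a BufferLoad."""
    if not isinstance(expr, BufferLoad):
        raise TypeError(f"tl.access_ptr base must be BufferLoad, but got {expr!r}")
    indices = tuple(visit_expr(index) for index in expr.indices)
    predicate = None if expr.predicate is None else visit_expr(expr.predicate)
    changed = any(new is not old for new, old in zip(indices, expr.indices))
    changed = changed or predicate is not expr.predicate
    if not changed:
        return expr  # nothing was rewritten, so reuse the original node
    return BufferLoad(expr.buffer, indices, predicate)


def rewrite(expr, subst):
    """Hypothetical rewriter: substitutes variables by name."""
    if isinstance(expr, Var):
        return subst.get(expr.name, expr)
    if isinstance(expr, BufferLoad):
        return visit_access_ptr_base(expr, lambda e: rewrite(e, subst))
    if isinstance(expr, Call):
        if expr.op == "tl.access_ptr":
            if len(expr.args) != 3:
                raise ValueError("tl.access_ptr expects 3 arguments")
            base, extent, rw_mask = expr.args
            # The base is visited through the helper so it remains a BufferLoad.
            new_base = visit_access_ptr_base(base, lambda e: rewrite(e, subst))
            return Call(expr.op, (new_base, rewrite(extent, subst), rewrite(rw_mask, subst)))
        return Call(expr.op, tuple(rewrite(arg, subst) for arg in expr.args))
    return expr  # constants such as ints pass through unchanged


def contains_call(expr, op_name: str) -> bool:
    """Test-style helper: does the expression contain a call to op_name?"""
    if isinstance(expr, Call):
        return expr.op == op_name or any(contains_call(a, op_name) for a in expr.args)
    if isinstance(expr, BufferLoad):
        return any(contains_call(i, op_name) for i in expr.indices)
    return False


# The extent and mask values (1, 1) are placeholders.
ptr = Call("tl.access_ptr", (BufferLoad("status", (Var("look"),)), 1, 1))
lowered = rewrite(ptr, {"look": Var("bx")})
```

Returning the original node when nothing changed mirrors the "original or rebuilt" wording in the summary; the toy uses object identity for that check, which is an assumption about how the real helper decides.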

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • tile-ai/tilelang#1292: Modifies SafeMemorysRewriter and related legalization wiring; overlaps with constructor/visitor changes.
  • tile-ai/tilelang#1827: Related work on access_ptr/tvm_access_ptr lowering and VisitAccessPtrBase-style helpers.
  • tile-ai/tilelang#1050: Prior edits to legalize_safe_memory_access.cc and load/store rewriting that overlap with this PR's changes.

Suggested reviewers

  • LeiWang1999

"I hopped through indices, nibbling bugs away,
I rebuilt the loads where loose predicates stray,
Atomic loads now clear,
Access paths draw near,
Kernels bound and happy — hip-hip-hooray!" 🐰✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name: Docstring Coverage
Status: ⚠️ Warning
Explanation: Docstring coverage is 7.14% which is insufficient. The required threshold is 80.00%.
Resolution: Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

Check name: Title check
Status: ✅ Passed
Explanation: The title 'Fix atomic_load access_ptr lowering for dynamic indices' directly and specifically describes the main change: fixing tl.access_ptr lowering to handle dynamic indices in atomic_load contexts.

Check name: Linked Issues check
Status: ✅ Passed
Explanation: Check skipped because no linked issues were found for this pull request.

Check name: Out of Scope Changes check
Status: ✅ Passed
Explanation: Check skipped because no linked issues were found for this pull request.

Check name: Description Check
Status: ✅ Passed
Explanation: Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@github-actions

github-actions Bot commented May 6, 2026


👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai coderabbitai Bot left a comment

Contributor


🧹 Nitpick comments (1)
src/transform/common/access_ptr_utils.h (1)

7-9: ⚡ Quick win

Avoid using namespace directives in headers.

using namespace tir; at namespace scope in a header injects the entire tir namespace into tvm::tl for every translation unit that includes this file. This can cause name-lookup surprises and ODR clashes as the codebase grows. Prefer fully-qualifying names (tir::BufferLoad, tir::PrimExpr, etc.) within the header.

♻️ Suggested change
 namespace tvm {
 namespace tl {
-using namespace tir;
 
 namespace detail {
 
 template <typename VisitExprFn>
-BufferLoad VisitAccessPtrBase(const PrimExpr &expr, VisitExprFn &&visit_expr) {
-  const auto *base_load_node = expr.as<BufferLoadNode>();
+tir::BufferLoad VisitAccessPtrBase(const tir::PrimExpr &expr, VisitExprFn &&visit_expr) {
+  const auto *base_load_node = expr.as<tir::BufferLoadNode>();
   ICHECK(base_load_node) << "tl.access_ptr base must be BufferLoad, but got "
                          << expr;
-  BufferLoad base_load = tvm::ffi::GetRef<BufferLoad>(base_load_node);
+  tir::BufferLoad base_load = tvm::ffi::GetRef<tir::BufferLoad>(base_load_node);
 
-  Array<PrimExpr> indices;
+  tvm::Array<tir::PrimExpr> indices;
   bool changed = false;
-  for (const PrimExpr &index : base_load->indices) {
+  for (const tir::PrimExpr &index : base_load->indices) {
     ...
   }
 
-  Optional<PrimExpr> predicate = base_load->predicate;
+  tvm::Optional<tir::PrimExpr> predicate = base_load->predicate;
   ...
-  return BufferLoad(base_load->buffer, indices, predicate, base_load->span);
+  return tir::BufferLoad(base_load->buffer, indices, predicate, base_load->span);
 }
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/transform/common/access_ptr_utils.h` around lines 7 - 9, Remove the
header-level "using namespace tir;" from src/transform/common/access_ptr_utils.h
and update all references that rely on it to fully-qualified names (e.g.,
tir::BufferLoad, tir::PrimExpr, tir::Var, etc.) inside the tvm::tl namespace and
in functions/classes such as any occurrences in AccessPtr utilities; this keeps
the header from polluting translation units and prevents lookup/ODR issues while
preserving the existing API and symbols.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 13264419-1cb3-4c81-92f3-d99c3796f3ca

📥 Commits

Reviewing files that changed from the base of the PR and between 7204f78 and 14e9308.

📒 Files selected for processing (1)
  • src/transform/common/access_ptr_utils.h

@coderabbitai coderabbitai Bot left a comment

Contributor


🧹 Nitpick comments (1)
src/transform/common/access_ptr_utils.h (1)

8-9: 💤 Low value

Remove unnecessary include <tvm/ffi/reflection/registry.h>.

This header is only used for FFI registration macros, not for tvm::ffi::GetRef. The codebase shows GetRef works without it (e.g., loop_vectorization_utils.h uses GetRef extensively without this header). The required declarations come from <tvm/tir/expr.h> or its transitive includes.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/transform/common/access_ptr_utils.h` around lines 8 - 9, Remove the
unnecessary include of <tvm/ffi/reflection/registry.h> from the include block —
it is only for FFI registration macros and not required for tvm::ffi::GetRef or
the code in this file; delete that include line in access_ptr_utils.h and keep
the existing <tvm/tir/expr.h> (and other necessary includes) so compilation uses
the transitive declarations instead.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: dbd08278-cfbd-468d-b7d2-82a67e67b051

📥 Commits

Reviewing files that changed from the base of the PR and between 31a7fff and 0b4fdcc.

📒 Files selected for processing (3)
  • src/transform/common/access_ptr_utils.h
  • src/transform/legalize_safe_memory_access.cc
  • src/transform/lower_access_ptr.cc

@LeiWang1999 LeiWang1999 self-requested a review May 20, 2026 09:49
@LeiWang1999

Member

Thanks for tracking this down and putting together a fix. However, I do not think accepting a non-BufferLoad as the first argument of tl.access_ptr is the fundamental fix here. Semantically, the first argument of tl.access_ptr is the address-producing buffer access, so it really should remain a BufferLoad. If LegalizeSafeMemoryAccess rewrites that BufferLoad into an if_then_else, the IR starts to mean "take the address of either status[look] or safe_value", which is not valid pointer semantics.

In other words, we should avoid producing IR like this:

tl.access_ptr(
    tirx.if_then_else(
        in_bounds(look),
        BufferLoad(status, [look]),
        safe_value,
    ),
    1,
    read_mask,
)

and instead preserve the pointer shape and put the safety guard around the operation that consumes the pointer, for example:

tirx.if_then_else(
    in_bounds(look),
    tl.atomic_load_elem_op(
        tl.access_ptr(
            BufferLoad(status, [look]),
            1,
            read_mask,
        ),
        ...
    ),
    safe_value,
)

I will make some follow-up improvements in LegalizeSafeMemoryAccess so it treats tl.access_ptr specially instead of recursively rewriting its base BufferLoad as a value load.
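
To make the guard placement concrete, here is a small Python sketch of the second shape, where the bounds check wraps the operation that consumes the pointer. It assumes toy `BufferLoad` and `Call` classes, and `guard_pointer_consumer` is a hypothetical helper name; none of this is the actual LegalizeSafeMemoryAccess implementation or the follow-up mentioned above.

```python
from dataclasses import dataclass
from typing import Tuple


# Toy IR nodes (assumptions for illustration, not TVM classes).
@dataclass(frozen=True)
class BufferLoad:
    buffer: str
    indices: Tuple[str, ...]


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[object, ...]


def guard_pointer_consumer(consumer: Call, safe_value) -> Call:
    """Wrap a pointer-consuming call in a bounds guard.

    The tl.access_ptr base is left untouched, so it stays a BufferLoad and
    the pointer keeps a well-defined meaning.
    """
    ptr = consumer.args[0]
    if not (isinstance(ptr, Call) and ptr.op == "tl.access_ptr"):
        return consumer  # nothing to guard in this toy model
    base = ptr.args[0]
    if not isinstance(base, BufferLoad):
        raise TypeError(f"tl.access_ptr base must be BufferLoad, but got {base!r}")
    # "in_bounds" is a symbolic placeholder for the real bounds condition.
    condition = Call("in_bounds", base.indices)
    return Call("tirx.if_then_else", (condition, consumer, safe_value))


# The extent and mask arguments (1, "read_mask") follow the snippet above.
atomic_load = Call(
    "tl.atomic_load_elem_op",
    (Call("tl.access_ptr", (BufferLoad("status", ("look",)), 1, "read_mask")),),
)
guarded = guard_pointer_consumer(atomic_load, safe_value=0)
```

In this sketch the original consumer call is reused as the true branch, so the pointer expression inside it is identical before and after guarding.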

…s.cc

Updated the return structure of AccessPtrInfo to improve readability and maintainability. The logic for handling the tvm_access_ptr operation has been streamlined, ensuring clearer checks and consistent formatting of returned values.
@LeiWang1999

Member

@regression-perf

@github-actions


Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/26383116746

Results

File Original Latency Current Latency Speedup
example_topk 38.4832 61.9263 0.621436
example_mha_bwd_bhsd 0.0294688 0.0298276 0.987969
example_tilelang_gemm_fp8 0.236327 0.238852 0.989425
example_linear_attn_fwd 0.0284143 0.0286875 0.990476
example_mha_bwd_bshd 0.0290475 0.0293207 0.990682
example_warp_specialize_gemm_copy_1_gemm_0 0.0194005 0.0195475 0.992482
topk_selector 0.0412668 0.041573 0.992635
example_mha_fwd_bhsd 0.00904037 0.00909887 0.993571
example_dequant_gemm_bf16_mxfp4_hopper 0.355616 0.357812 0.993864
example_gqa_fwd_bshd 0.0511719 0.0514453 0.994686
example_gqa_bwd 0.0327486 0.0329131 0.995004
example_gemv 0.201261 0.202245 0.995134
example_gemm 0.0170397 0.0171204 0.995286
block_sparse_attn_tilelang 0.00669874 0.00672606 0.995937
fp8_lighting_indexer 0.0226888 0.0227789 0.996045
example_mha_sink_bwd_bhsd_sliding_window 0.0380939 0.0382276 0.996502
example_gqa_decode 0.0410856 0.0412222 0.996686
example_convolution 0.922269 0.925256 0.996772
example_group_per_split_token_cast_to_fp8 0.00761487 0.00763476 0.997394
example_fusedmoe_tilelang 0.0953438 0.0955928 0.997396
example_mhc_pre 0.144111 0.144468 0.997531
example_mha_sink_fwd_bhsd_sliding_window 0.0126635 0.0126943 0.997574
example_warp_specialize_gemm_barrierpipe_stage2 0.0295769 0.0296192 0.998572
sparse_mla_fwd 0.0825764 0.0826105 0.999587
example_tilelang_sparse_gqa_decode_varlen_indice 0.0117931 0.0117956 0.999788
example_dequant_gemm_w4a8 3.83158 3.83207 0.999871
example_warp_specialize_gemm_copy_0_gemm_1 0.0269352 0.0269361 0.999969
example_tilelang_block_sparse_attn 0.00724506 0.00724522 0.999978
example_gqa_sink_bwd_bhsd_sliding_window 0.018109 0.0181086 1.00002
example_blocksparse_gemm 0.0137177 0.0137171 1.00004
example_mha_fwd_bshd 0.0190445 0.0190417 1.00015
example_tilelang_sparse_gqa_decode_varlen_mask 0.0127861 0.0127838 1.00018
example_vertical_slash_sparse_attn 0.167321 0.167268 1.00032
example_per_token_cast_to_fp8 0.0065059 0.00650262 1.0005
example_dequant_gemv_fp16xint4 0.0269908 0.0269733 1.00065
example_convolution_autotune 0.723778 0.723245 1.00074
example_tilelang_nsa_decode 0.00551607 0.00551163 1.0008
example_tilelang_nsa_fwd 0.00562289 0.00561754 1.00095
example_elementwise_add 0.113091 0.112982 1.00096
example_dynamic 0.499459 0.49888 1.00116
example_mha_sink_fwd_bhsd 0.0127363 0.0127186 1.00138
example_gemm_intrinsics 0.0254448 0.0254095 1.00139
example_tilelang_gemm_splitk 0.768636 0.76692 1.00224
example_gemm_autotune 0.0162357 0.0161965 1.00242
example_mhc_post 0.10674 0.106394 1.00324
example_mha_inference 0.0628477 0.0626338 1.00342
sparse_mla_fwd_pipelined 0.059363 0.0591577 1.00347
example_dequant_gemm_bf16_fp4_hopper 0.398265 0.396877 1.0035
sparse_mla_bwd 0.229959 0.228967 1.00433
example_mla_decode 0.315918 0.314207 1.00545
example_tilelang_gemm_fp8_2xAcc 0.0913236 0.0908246 1.00549
example_warp_specialize_gemm_softpipe_stage2 0.0195565 0.0194366 1.00617
example_gqa_sink_bwd_bhsd 0.0301934 0.0299761 1.00725
example_mha_fwd_varlen 0.032942 0.0326363 1.00937
example_dequant_gemm_fp4_hopper 0.717325 0.709321 1.01128
example_linear_attn_bwd 0.118981 0.117587 1.01186
example_tilelang_gemm_splitk_vectorize_atomicadd 0.793654 0.78344 1.01304
example_gqa_bwd_tma_reduce_varlen 0.0334657 0.0330253 1.01333
example_mha_sink_bwd_bhsd 0.0526599 0.0518506 1.01561
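
The Speedup column appears to be the original latency divided by the current latency, so values below 1 indicate a slowdown. The short Python sketch below recomputes it for three rows copied from the table and flags rows under a cutoff. The 0.95 cutoff and the function names are assumptions for illustration; they are not part of the regression workflow.

```python
# Rows copied from the report: (name, original latency, current latency).
rows = [
    ("example_topk", 38.4832, 61.9263),
    ("example_gemm", 0.0170397, 0.0171204),
    ("example_mha_sink_bwd_bhsd", 0.0526599, 0.0518506),
]


def speedup(original: float, current: float) -> float:
    """Speedup as original / current; below 1.0 means the current build is slower."""
    return original / current


def regressions(table, threshold: float = 0.95):
    """Names of rows whose speedup falls below the (assumed) threshold."""
    return [name for name, orig, cur in table if speedup(orig, cur) < threshold]


for name, orig, cur in rows:
    print(f"{name}: {speedup(orig, cur):.4f}")

flagged = regressions(rows)
```

Because the reported latencies are rounded, a recomputed ratio can differ from the Speedup column in the last printed digit.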

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@LeiWang1999 LeiWang1999 merged commit b1083d7 into tile-ai:main May 25, 2026
5 of 6 checks passed


Development

Successfully merging this pull request may close these issues.

[BUG] T.atomic_load fails in LowerAccessPtr when the address uses a block-derived loop variable

2 participants