fix(merge_shmem): allow shared memory reuse for buffers with disjoint lifetimes#1987

Merged
LeiWang1999 merged 1 commit into tile-ai:main from reoLantern:fix/merge-shmem-kill-reorder-gen-boundary on Mar 29, 2026

@reoLantern reoLantern commented Mar 28, 2026

Contributor

Summary

  • Fix the kill-point reorder in merge_shared_memory_allocations.cc, which could overshoot past another shared buffer's gen point
  • Enable Q_shared / O_shared memory reuse in Flash Attention, saving 32KB of shared memory
  • Flash Attention with block_M=128, block_N=128, num_stages=2 now fits within the A100's 164KB limit (160KB, down from 192KB)

Reproduce

Using examples/flash_attention/example_mha_fwd_bshd.py, call:

kernel = flashattn(8, 32, 4096, 128, False,
                   block_M=128, block_N=128, num_stages=2, threads=128)

Before: Failed to set the allowed dynamic shared memory size to 196608
After: Runs successfully (~12% faster than num_stages=1 thanks to double-buffered pipeline)

Root cause

The kill reorder in LivenessAnalysis moves buffer kills from deep scopes (inside pipeline loops) to their enclosing scope boundary. The forward scan searched for the first statement at gen_level but could land past another buffer's gen point:

Q_shared gen (before loop) → Q_shared read (inside loop) → loop ends →
O_shared gen (after loop) → Q_shared kill (reordered here — too late!)

The fix adds an early-stop: when the next statement generates a different shared buffer, place the kill before it.
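
The overshoot and the early-stop can be illustrated on a toy linearized sequence. This is only a sketch of the idea described above, not the code in merge_shared_memory_allocations.cc: `Stmt`, `reorder_kill`, and the `early_stop` flag are hypothetical names, and the five statements are a simplified stand-in for the Flash Attention kernel.

```python
from dataclasses import dataclass, field


@dataclass
class Stmt:
    name: str
    level: int                               # scope depth: 0 = kernel body, 1 = inside the loop
    gen: set = field(default_factory=set)    # buffers first touched here
    kill: set = field(default_factory=set)   # buffers last touched here


def reorder_kill(seq, buf, gen_level, early_stop):
    """Return the index the kill of `buf` lands on after the forward scan (toy model)."""
    i = next(idx for idx, s in enumerate(seq) if buf in s.kill)
    if seq[i].level <= gen_level:
        return i                             # kill is already at the gen scope
    while i + 1 < len(seq):
        nxt = seq[i + 1]
        if early_stop and any(b != buf for b in nxt.gen):
            break                            # next statement births another buffer
        i += 1
        if seq[i].level == gen_level:
            break                            # first statement back at the gen scope
    return i


# Hypothetical, simplified statement sequence (K/V buffers omitted).
seq = [
    Stmt("load Q_shared", 0, gen={"Q_shared"}),
    Stmt("gemm reads Q_shared", 1, kill={"Q_shared"}),
    Stmt("rest of loop body", 1),
    Stmt("write O_shared", 0, gen={"O_shared"}),
    Stmt("store O_shared", 0, kill={"O_shared"}),
]

old_kill = reorder_kill(seq, "Q_shared", gen_level=0, early_stop=False)
new_kill = reorder_kill(seq, "Q_shared", gen_level=0, early_stop=True)
o_gen = next(idx for idx, s in enumerate(seq) if "O_shared" in s.gen)

print("old:", seq[old_kill].name, "| overlaps O_shared:", old_kill >= o_gen)
print("new:", seq[new_kill].name, "| overlaps O_shared:", new_kill >= o_gen)
```

In this model the scan without the early-stop lands on the statement that generates O_shared, so the two lifetimes touch; with the early-stop the kill stays on the last loop statement and the lifetimes are disjoint.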

Test plan

  • examples/flash_attention/example_mha_fwd_bshd.py with num_stages=2, block_M=128, block_N=128: previously failed with the 196608-byte shared memory error, now succeeds
  • Same example with num_stages=1: unchanged (147 TFlops)
  • GEMM examples: no regression

Summary by CodeRabbit

  • Bug Fixes
    • Fixed memory buffer reordering logic to correctly handle scenarios with multiple shared memory allocations, preventing potential conflicts during buffer lifecycle management.

…point

In merge_shared_memory_allocations.cc, the kill-point reorder scans forward
to move a deep-scope kill to the enclosing scope boundary.  Previously it
could overshoot past another buffer's gen (birth) point, creating a false
liveness overlap that blocks memory reuse.

This patch adds an early-stop condition: if the next statement in the
linearized sequence generates a different shared-memory buffer, place the
kill before it.  This is safe because T.alloc_shared is always outside
pipelined loop bodies — no new shared buffer is born inside the deep scope
where kills are being reordered from.

Reproduce with examples/flash_attention/example_mha_fwd_bshd.py:

  # Before fix — fails (192KB > A100 164KB limit):
  kernel = flashattn(8, 32, 4096, 128, False,
                     block_M=128, block_N=128, num_stages=2, threads=128)
  # → "Failed to set the allowed dynamic shared memory size to 196608"

  # After fix — succeeds (160KB, Q_shared/O_shared merged):
  # Same call now runs successfully.

The Flash Attention kernel allocates four shared buffers:
  Q_shared (32KB) — loaded once before the pipeline loop
  K_shared (32KB) — loaded each iteration (double-buffered → 64KB)
  V_shared (32KB) — loaded each iteration (double-buffered → 64KB)
  O_shared (32KB) — written after the pipeline loop

Q_shared and O_shared have disjoint lifetimes (Q is dead after the last
loop iteration; O is born after the loop).  Before this fix, the kill
reorder extended Q_shared's kill past O_shared's gen, preventing the merge
pass from reusing Q's memory for O:
  Without fix: 32 + 64 + 64 + 32 = 192KB (Q and O separate)
  With fix:    32 + 64 + 64      = 160KB (Q and O share offset 0)
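
The byte counts above can be checked with a few lines of arithmetic. This is a sketch under one assumption not stated in the message: each tile is 128 x 128 half-precision elements (2 bytes each), which is what yields the 32KB per buffer quoted above. The 164KB limit is the figure given in this PR.

```python
KB = 1024
BYTES_PER_ELEM = 2  # assumption: float16 tiles


def tile_bytes(rows, cols, stages=1):
    """Size of a shared tile, multiplied by the number of pipeline stages."""
    return rows * cols * BYTES_PER_ELEM * stages


block_M = block_N = dim = 128
num_stages = 2

sizes = {
    "Q_shared": tile_bytes(block_M, dim),
    "K_shared": tile_bytes(block_N, dim, num_stages),
    "V_shared": tile_bytes(block_N, dim, num_stages),
    "O_shared": tile_bytes(block_M, dim),
}

# Without reuse every buffer gets its own region.
without_fix = sum(sizes.values())
# With reuse Q_shared and O_shared occupy the same region (sized by the larger one).
with_fix = (max(sizes["Q_shared"], sizes["O_shared"])
            + sizes["K_shared"] + sizes["V_shared"])

limit = 164 * KB  # limit quoted in this PR
print(without_fix, without_fix // KB, without_fix <= limit)
print(with_fix, with_fix // KB, with_fix <= limit)
```

The first total, 196608 bytes, matches the size in the reported error message.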
Copilot AI review requested due to automatic review settings March 28, 2026 00:47
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai

coderabbitai Bot commented Mar 28, 2026

Contributor
Walkthrough

The SharedMemoryRewriter class's liveness-kill reordering logic now includes an additional stopping condition when computing the reassignment point for a buffer's kill statement. The loop terminates early if the next statement is a gen site for a different shared-memory buffer, preventing kill statements from being moved past other buffers' generation sites.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Shared Memory Kill Reordering (src/transform/merge_shared_memory_allocations.cc) | Modified last_stmt_at_level computation in SharedMemoryRewriter to check if the next statement is a gen site for a different shared-memory buffer; if so, stops the loop and sets the reassignment point accordingly. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested reviewers

  • LeiWang1999

Poem

🐰 Through shared buffers we hop with care,
Reordering kills with proper flair,
When gen sites whisper of brothers near,
We pause and yield—a boundary clear!
No crossing past another's claim,
Just orderly swaps within the game.

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title 'fix(merge_shmem): allow shared memory reuse for buffers with disjoint lifetimes' directly and accurately summarizes the main change: enabling shared memory buffer reuse by fixing a kill-point reorder bug that was preventing buffers with non-overlapping lifetimes from reusing the same memory space. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%. |

@coderabbitai coderabbitai Bot left a comment

Contributor

🧹 Nitpick comments (1)
src/transform/merge_shared_memory_allocations.cc (1)

1031-1034: Please add a regression test for the stated safety invariant.

The behavior now depends on the assumption documented in lines 1031-1034; a focused IR test that exercises “kill reorder with intervening different-buffer gen” would lock this in and prevent future regressions.

I can draft a minimal TIR regression test case for this scenario if you want.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/transform/merge_shared_memory_allocations.cc` around lines 1031 - 1034,
Add a focused regression test that reproduces the “kill reorder with intervening
different-buffer gen” scenario relied upon by
merge_shared_memory_allocations.cc: create a minimal TIR function that contains
a pipelined loop with a T.alloc_shared allocation outside the loop, an inner
kill + reorder transformation opportunity, and an intervening different-buffer
generation so the pass must not merge or move the shared allocation incorrectly;
ensure the test asserts the transformation preserves semantics (no new shared
buffer born inside the deep scoped kill) and add it to the TIR transform tests
(name it something like test_kill_reorder_with_intervening_buffer) so future
changes to MergeSharedMemoryAllocations/MergeSharedMemoryAllocationsPass or the
merge_shared_memory_allocations logic will fail if the invariant is violated.

📥 Commits

Reviewing files that changed from the base of the PR and between bdf436d and 92c6957.

📒 Files selected for processing (1)
  • src/transform/merge_shared_memory_allocations.cc

Copilot AI left a comment

Contributor

Pull request overview

Fixes shared-memory liveness kill-point reordering in merge_shared_memory_allocations.cc so buffers with disjoint lifetimes don’t appear to overlap, enabling additional shared-memory reuse (e.g., Q_shared/O_shared reuse in Flash Attention to reduce SMEM footprint).

Changes:

  • Adds an early-stop condition when reordering kill points: if the next statement generates a different shared-memory buffer, place the reordered kill before it.
  • Documents the rationale for the early-stop to prevent false liveness overlap that blocks reuse.

Comment on lines +1043 to +1047
    // Stop if the next statement births a different shared buffer.
    auto next_event_it = event_map_.find(next_it->stmt);
    if (next_event_it != event_map_.end() &&
        !next_event_it->second.gen.empty()) {
      bool has_other_gen = false;

Copilot AI Mar 28, 2026

This change adjusts kill-point reordering to stop before the next statement that generates a different shared-memory buffer, which is a subtle correctness/perf-sensitive part of the shared-memory arena planner. Please add a regression test under testing/python/transform/ that constructs a minimal TIR with two shared buffers whose lifetimes are disjoint across a pipelined loop boundary (similar to the Q_shared/O_shared pattern), and assert that MergeSharedMemoryAllocations assigns the same byte offset (or otherwise demonstrates successful reuse) and does not report overlapping lifetimes.
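
The shape of the assertion requested here can be sketched on a toy first-fit planner. This does not call MergeSharedMemoryAllocations or build any TIR: `plan_offsets` and `check` are hypothetical helpers, the statement indices are made up, and only the Q_shared/O_shared pair is modeled.

```python
KB = 1024


def plan_offsets(buffers):
    """First-fit placement. Each buffer is (name, size, gen, kill), live on [gen, kill]."""
    placed, offsets = [], {}
    for name, size, gen, kill in sorted(buffers, key=lambda b: b[2]):
        # Address ranges held by already-placed buffers whose lifetime overlaps this one.
        busy = sorted((o, o + s) for o, s, g, k in placed
                      if not (k < gen or kill < g))
        offset = 0
        for lo, hi in busy:
            if offset + size <= lo:
                break                      # the gap before this range is large enough
            offset = max(offset, hi)
        placed.append((offset, size, gen, kill))
        offsets[name] = offset
    return offsets, max(o + s for o, s, _, _ in placed)


def check(q_kill):
    """Plan Q_shared (killed at q_kill) and O_shared (live on statements 3..4)."""
    return plan_offsets([
        ("Q_shared", 32 * KB, 0, q_kill),
        ("O_shared", 32 * KB, 3, 4),
    ])


overshoot_offsets, overshoot_total = check(q_kill=3)  # kill lands on O_shared's gen
fixed_offsets, fixed_total = check(q_kill=2)          # kill stays before O_shared's gen

# Regression-style checks: reuse happens only when the lifetimes are disjoint.
assert overshoot_offsets["Q_shared"] != overshoot_offsets["O_shared"]
assert fixed_offsets["Q_shared"] == fixed_offsets["O_shared"] == 0
print(overshoot_total // KB, fixed_total // KB)
```

A real regression test would make the same offset comparison on the output of the pass itself.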

Comment on lines +1044 to +1048
    auto next_event_it = event_map_.find(next_it->stmt);
    if (next_event_it != event_map_.end() &&
        !next_event_it->second.gen.empty()) {
      bool has_other_gen = false;
      for (const VarNode *gen_buf : next_event_it->second.gen) {

Copilot AI Mar 28, 2026

next_event_it != event_map_.end() is effectively always true here: gen_kill_seq is populated using event_map_[stmt_entry.stmt], which ensures every stmt in gen_kill_seq already has an entry in event_map_. Consider simplifying by using event_map_.at(next_it->stmt) (or equivalent) and dropping the find/end() check to reduce branching and avoid hiding unexpected missing-key bugs.

Suggested change
-   auto next_event_it = event_map_.find(next_it->stmt);
-   if (next_event_it != event_map_.end() &&
-       !next_event_it->second.gen.empty()) {
-     bool has_other_gen = false;
-     for (const VarNode *gen_buf : next_event_it->second.gen) {
+   const auto& next_event = event_map_.at(next_it->stmt);
+   if (!next_event.gen.empty()) {
+     bool has_other_gen = false;
+     for (const VarNode *gen_buf : next_event.gen) {
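
The reasoning behind this suggestion can be shown with a rough Python analogy, assuming event_map_ behaves like a standard C++ map whose operator[] default-inserts a missing key. `defaultdict` plays that role here; `Event`, `event_map`, and the statement names are hypothetical stand-ins, not identifiers from the pass.

```python
from collections import defaultdict


class Event:
    """Hypothetical stand-in for the per-statement gen/kill record."""

    def __init__(self):
        self.gen = []
        self.kill = []


# Indexing a defaultdict creates the entry when it is missing,
# comparable to operator[] on a C++ map.
event_map = defaultdict(Event)
gen_kill_seq = ["stmt0", "stmt1", "stmt2"]
for stmt in gen_kill_seq:
    event_map[stmt]  # building the sequence touches (and so creates) every entry

# A "was it found?" guard can therefore never fail for statements in the sequence.
always_present = all(stmt in event_map for stmt in gen_kill_seq)

# A strict lookup (comparable to .at()) fails loudly on a key that was never
# inserted, instead of silently skipping the branch.
strict_map = dict(event_map)
try:
    strict_map["stmt_not_in_sequence"]
    missing_key_detected = False
except KeyError:
    missing_key_detected = True

print(always_present, missing_key_detected)
```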

@LeiWang1999 LeiWang1999 merged commit 3a956e0 into tile-ai:main Mar 29, 2026
10 of 11 checks passed
3 participants