fix(merge_shmem): allow shared memory reuse for buffers with disjoint lifetimes #1987
…point
In merge_shared_memory_allocations.cc, the kill-point reorder scans forward
to move a deep-scope kill to the enclosing scope boundary. Previously it
could overshoot past another buffer's gen (birth) point, creating a false
liveness overlap that blocks memory reuse.
This patch adds an early-stop condition: if the next statement in the
linearized sequence generates a different shared-memory buffer, place the
kill before it. This is safe because T.alloc_shared is always outside
pipelined loop bodies — no new shared buffer is born inside the deep scope
where kills are being reordered from.
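The early-stop rule can be illustrated with a toy model of the linearized sequence. This is only a sketch under stated assumptions, not the code in merge_shared_memory_allocations.cc: the `Stmt` class, the `reorder_kill` function, and the five-statement sequence are hypothetical, and the real pass tracks more state than this.

```python
from dataclasses import dataclass, field


@dataclass
class Stmt:
    """Hypothetical stand-in for one entry of the linearized sequence."""
    name: str
    level: int                               # scope depth of the statement
    gen: set = field(default_factory=set)    # buffers born at this statement
    kill: set = field(default_factory=set)   # buffers last used at this statement


def reorder_kill(seq, kill_idx, buf, gen_level, early_stop):
    """Return the index that the kill of `buf` is moved to.

    Scans forward from the deep-scope kill until a statement at `gen_level`
    is reached. With `early_stop`, the scan halts before any statement that
    births a different buffer.
    """
    pos = kill_idx
    while pos + 1 < len(seq):
        nxt = seq[pos + 1]
        if early_stop and any(g != buf for g in nxt.gen):
            break                            # place the kill before `nxt`
        pos += 1
        if nxt.level <= gen_level:
            break                            # reached the enclosing scope
    return pos


# Illustrative sequence mirroring the Q_shared / O_shared pattern.
seq = [
    Stmt("copy Q -> Q_shared", level=0, gen={"Q_shared"}),
    Stmt("loop: gemm(Q_shared, K_shared)", level=1, kill={"Q_shared"}),
    Stmt("loop: gemm(S, V_shared)", level=1),
    Stmt("copy acc -> O_shared", level=0, gen={"O_shared"}),
    Stmt("copy O_shared -> Output", level=0, kill={"O_shared"}),
]
o_gen_idx = 3

moved_without_fix = reorder_kill(seq, 1, "Q_shared", gen_level=0, early_stop=False)
moved_with_fix = reorder_kill(seq, 1, "Q_shared", gen_level=0, early_stop=True)
print(moved_without_fix, moved_with_fix)  # 3 2
```

In this model the kill without the early stop lands on the same statement that births O_shared, so both buffers look live there. With the early stop it lands one statement earlier and the lifetimes no longer touch.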
Reproduce with examples/flash_attention/example_mha_fwd_bshd.py:
# Before fix — fails (192KB > A100 164KB limit):
kernel = flashattn(8, 32, 4096, 128, False,
block_M=128, block_N=128, num_stages=2, threads=128)
# → "Failed to set the allowed dynamic shared memory size to 196608"
# After fix — succeeds (160KB, Q_shared/O_shared merged):
# Same call now runs successfully.
The Flash Attention kernel allocates four shared buffers:
Q_shared (32KB) — loaded once before the pipeline loop
K_shared (32KB) — loaded each iteration (double-buffered → 64KB)
V_shared (32KB) — loaded each iteration (double-buffered → 64KB)
O_shared (32KB) — written after the pipeline loop
Q_shared and O_shared have disjoint lifetimes (Q is dead after the last
loop iteration; O is born after the loop). Before this fix, the kill
reorder extended Q_shared's kill past O_shared's gen, preventing the merge
pass from reusing Q's memory for O:
Without fix: 32 + 64 + 64 + 32 = 192KB (Q and O separate)
With fix: 32 + 64 + 64 = 160KB (Q and O share offset 0)
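As a quick check of the arithmetic above, the following sketch recomputes both footprints and compares them with the limit quoted in this PR. The helper name `footprint_kb` is hypothetical, and the 164KB figure is simply the value stated in the description.

```python
KB = 1024
A100_SMEM_LIMIT_KB = 164  # limit as quoted in the PR description

# Size of a single copy of each buffer, in KB (from the list above).
single_copy_kb = {"Q_shared": 32, "K_shared": 32, "V_shared": 32, "O_shared": 32}
# K and V are double-buffered with num_stages=2; Q and O are not.
copies = {"Q_shared": 1, "K_shared": 2, "V_shared": 2, "O_shared": 1}


def footprint_kb(merge_q_and_o):
    """Total shared memory in KB, optionally letting O_shared reuse Q_shared."""
    total = sum(single_copy_kb[n] * copies[n] for n in single_copy_kb)
    if merge_q_and_o:
        # The smaller of the two buffers fits inside the other's region.
        total -= min(single_copy_kb["Q_shared"], single_copy_kb["O_shared"])
    return total


before_kb = footprint_kb(merge_q_and_o=False)
after_kb = footprint_kb(merge_q_and_o=True)

print(before_kb, before_kb * KB, before_kb <= A100_SMEM_LIMIT_KB)  # 192 196608 False
print(after_kb, after_kb * KB, after_kb <= A100_SMEM_LIMIT_KB)     # 160 163840 True
```

The 196608 bytes in the first line match the size reported in the error message.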
🧹 Nitpick comments (1)
src/transform/merge_shared_memory_allocations.cc (1)
1031-1034: Please add a regression test for the stated safety invariant. The behavior now depends on the assumption in lines 1031-1034; a focused IR test that exercises "kill reorder with intervening different-buffer gen" would lock this in and prevent future regressions.
I can draft a minimal TIR regression test case for this scenario if you want.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/merge_shared_memory_allocations.cc` around lines 1031 - 1034, Add a focused regression test that reproduces the “kill reorder with intervening different-buffer gen” scenario relied upon by merge_shared_memory_allocations.cc: create a minimal TIR function that contains a pipelined loop with a T.alloc_shared allocation outside the loop, an inner kill + reorder transformation opportunity, and an intervening different-buffer generation so the pass must not merge or move the shared allocation incorrectly; ensure the test asserts the transformation preserves semantics (no new shared buffer born inside the deep scoped kill) and add it to the TIR transform tests (name it something like test_kill_reorder_with_intervening_buffer) so future changes to MergeSharedMemoryAllocations/MergeSharedMemoryAllocationsPass or the merge_shared_memory_allocations logic will fail if the invariant is violated.
📒 Files selected for processing (1)
src/transform/merge_shared_memory_allocations.cc
Pull request overview
Fixes shared-memory liveness kill-point reordering in merge_shared_memory_allocations.cc so buffers with disjoint lifetimes don’t appear to overlap, enabling additional shared-memory reuse (e.g., Q_shared/O_shared reuse in Flash Attention to reduce SMEM footprint).
Changes:
- Adds an early-stop condition when reordering kill points: if the next statement generates a different shared-memory buffer, place the reordered kill before it.
- Documents the rationale for the early-stop to prevent false liveness overlap that blocks reuse.
// Stop if the next statement births a different shared buffer.
auto next_event_it = event_map_.find(next_it->stmt);
if (next_event_it != event_map_.end() &&
    !next_event_it->second.gen.empty()) {
  bool has_other_gen = false;
This change adjusts kill-point reordering to stop before the next statement that generates a different shared-memory buffer, which is a subtle correctness/perf-sensitive part of the shared-memory arena planner. Please add a regression test under testing/python/transform/ that constructs a minimal TIR with two shared buffers whose lifetimes are disjoint across a pipelined loop boundary (similar to the Q_shared/O_shared pattern), and assert that MergeSharedMemoryAllocations assigns the same byte offset (or otherwise demonstrates successful reuse) and does not report overlapping lifetimes.
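The property such a test would assert can be sketched with a toy first-fit allocator over buffer lifetimes. This is not the TileLang test or the planner itself: `assign_offsets`, `total_kb`, and the lifetime indices are hypothetical, chosen to mirror the Q_shared/O_shared pattern, and the "before" lifetimes assume that every kill overshoots onto O_shared's gen statement.

```python
def assign_offsets(buffers):
    """First-fit offsets for (name, size_kb, gen_idx, kill_idx) tuples.

    Lifetimes are inclusive; two buffers may share memory only if their
    lifetimes do not overlap.
    """
    placed = []   # (offset, size, gen, kill) of buffers already assigned
    offsets = {}
    for name, size, gen, kill in sorted(buffers, key=lambda b: b[2]):
        # Regions held by buffers that are live at the same time as this one.
        busy = sorted((o, s) for o, s, g, k in placed if not (k < gen or kill < g))
        offset = 0
        for o, s in busy:
            if offset + size <= o:
                break                 # the gap before this region is big enough
            offset = max(offset, o + s)
        offsets[name] = offset
        placed.append((offset, size, gen, kill))
    return offsets


def total_kb(buffers, offsets):
    """Arena size needed to hold every buffer at its assigned offset."""
    return max(offsets[name] + size for name, size, _, _ in buffers)


# Assumed lifetimes before the fix: kills land on index 3, where O_shared is born.
before = [("Q_shared", 32, 0, 3), ("K_shared", 64, 1, 3),
          ("V_shared", 64, 1, 3), ("O_shared", 32, 3, 4)]
# Assumed lifetimes after the fix: kills stop at index 2, before O_shared's gen.
after = [("Q_shared", 32, 0, 2), ("K_shared", 64, 1, 2),
         ("V_shared", 64, 1, 2), ("O_shared", 32, 3, 4)]

offsets_before = assign_offsets(before)
offsets_after = assign_offsets(after)
print(total_kb(before, offsets_before), total_kb(after, offsets_after))  # 192 160
print(offsets_after["Q_shared"], offsets_after["O_shared"])              # 0 0
```

A real regression test would run the pass on a minimal TIR function and check the assigned offsets, but the assertion has the same shape: disjoint lifetimes give equal offsets and a smaller arena.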
auto next_event_it = event_map_.find(next_it->stmt);
if (next_event_it != event_map_.end() &&
    !next_event_it->second.gen.empty()) {
  bool has_other_gen = false;
  for (const VarNode *gen_buf : next_event_it->second.gen) {
next_event_it != event_map_.end() is effectively always true here: gen_kill_seq is populated using event_map_[stmt_entry.stmt], which ensures every stmt in gen_kill_seq already has an entry in event_map_. Consider simplifying by using event_map_.at(next_it->stmt) (or equivalent) and dropping the find/end() check to reduce branching and avoid hiding unexpected missing-key bugs.
Suggested change:
- auto next_event_it = event_map_.find(next_it->stmt);
- if (next_event_it != event_map_.end() &&
-     !next_event_it->second.gen.empty()) {
-   bool has_other_gen = false;
-   for (const VarNode *gen_buf : next_event_it->second.gen) {
+ const auto& next_event = event_map_.at(next_it->stmt);
+ if (!next_event.gen.empty()) {
+   bool has_other_gen = false;
+   for (const VarNode *gen_buf : next_event.gen) {
Summary
- Fixes the kill-point reorder in merge_shared_memory_allocations.cc overshooting past another shared buffer's gen point
- Enables Q_shared/O_shared memory reuse in Flash Attention, saving 32KB shared memory
- block_M=128, block_N=128, num_stages=2 now fits in A100's 164KB limit (160KB vs 192KB before)
Reproduce
Using examples/flash_attention/example_mha_fwd_bshd.py, call the kernel with block_M=128, block_N=128, num_stages=2.
Before: Failed to set the allowed dynamic shared memory size to 196608
After: Runs successfully (~12% faster than num_stages=1 thanks to double-buffered pipeline)
Root cause
The kill reorder in LivenessAnalysis moves buffer kills from deep scopes (inside pipeline loops) to their enclosing scope boundary. The forward scan searched for the first statement at gen_level but could land past another buffer's gen point. The fix adds an early-stop: when the next statement generates a different shared buffer, place the kill before it.
Test plan
- examples/flash_attention/example_mha_fwd_bshd.py with num_stages=2, block_M=128, block_N=128: was 196608 error, now succeeds
- num_stages=1: unchanged (147 TFlops)