Skip to content

Add TL_DISABLE_SHARED_MEMORY_REUSE pass config#2228

Merged
LeiWang1999 merged 4 commits into
tile-ai:mainfrom
kurisu6912:add-disable-merge-smem-config
May 21, 2026
Merged

Add TL_DISABLE_SHARED_MEMORY_REUSE pass config#2228
LeiWang1999 merged 4 commits into
tile-ai:mainfrom
kurisu6912:add-disable-merge-smem-config

Conversation

@kurisu6912

@kurisu6912 kurisu6912 commented May 20, 2026

Copy link
Copy Markdown
Collaborator

Summary

  • Add TL_DISABLE_SHARED_MEMORY_REUSE pass config to disable lifetime-based shared memory reuse in MergeSharedMemoryAllocations. Allocations are still merged into a single buffer but each gets its own dedicated region without sharing.

Summary by CodeRabbit

  • New Features
    • Added a tl.disable_shared_memory_reuse pass configuration option to toggle shared-memory reuse. When enabled, buffers are still combined but placed into dedicated, deterministically ordered sequential regions (no lifetime-based reuse). The shared-memory merge transform now honors this option.
  • Tests
    • Added CUDA-focused tests that validate correctness with reuse enabled and disabled, ensuring outputs remain correct when reuse is turned off.

Review Change Stack

@github-actions

Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai

coderabbitai Bot commented May 20, 2026

Copy link
Copy Markdown
Contributor

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds a pass config key and wiring to disable shared-memory lifetime-based reuse; implements a sequential-layout fallback in the merge transform, updates the C++ and Python pass APIs, integrates the option into the engine pipeline, and adds CUDA tests validating correctness.

Changes

Disable shared-memory reuse feature

Layer / File(s) Summary
Config key declaration and FFI registration
tilelang/transform/pass_config.py, src/op/builtin.h, src/op/builtin.cc
PassConfigKey.TL_DISABLE_SHARED_MEMORY_REUSE enum member is defined and registered as a boolean pass configuration option at the TVM runtime level via the C++ header constant and pass-config registration.
Reuse planning: disable_reuse mode with sequential layout
src/transform/merge_shared_memory_allocations.cc
SharedMemoryRewriter::PlanReuse now accepts a disable_reuse flag. When true, it uses a new PlanSequentialLayout method to deterministically lay out shared buffers back-to-back with alignment, bypassing liveness analysis. When false, the existing reuse path (liveness + arena packing) is used. Includes the added tvm/tirx/stmt.h include.
C++ pass API surface: statement-level and pass factory
src/transform/merge_shared_memory_allocations.cc
The statement-level tl::MergeSharedMemoryAllocations and the FFI pass factory MergeSharedMemoryAllocations now accept disable_reuse; the flag is forwarded to PlanReuse and forces aggressive merge off when reuse is disabled.
Python transform wrapper
tilelang/transform/__init__.py
MergeSharedMemoryAllocations wrapper now accepts disable_reuse and forwards all three parameters (enable_aggressive_merge, align_bytes, disable_reuse) to the underlying FFI pass.
Engine helper and pipeline wiring
tilelang/engine/phase.py
New should_disable_shared_memory_reuse(pass_ctx) helper reads the config key from the active pass context; OptimizeForTarget computes and passes the disable_reuse flag into the transform.
Tests
testing/python/transform/test_tilelang_transform_disable_memory_reuse.py
Adds CUDA-gated tests that verify numerical correctness with TL_DISABLE_SHARED_MEMORY_REUSE enabled and disabled and exercise sequential/shared allocation patterns via JIT kernels.

Possibly related PRs

  • tile-ai/tilelang#1570: Modifies deterministic offset ordering in src/transform/merge_shared_memory_allocations.cc; related to offset computation logic that the new sequential layout relies on.
  • tile-ai/tilelang#1987: Fixes kill-reordering in shared-memory reuse planning; complementary to the disable_reuse codepath that provides an alternative to lifetime-based reuse.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

🐰 I bound through bytes and aligned each cue,

Placing buffers tidy, one, two, two.
No overlapping hops, just steady parade,
Config flips the plan that the rewriter made.
Hooray — sequential memory, neat and true!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 39.13% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The pull request title accurately summarizes the main change: adding a new TL_DISABLE_SHARED_MEMORY_REUSE pass configuration flag. It is concise, specific, and clearly reflects the primary objective of the changeset.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@LeiWang1999

Copy link
Copy Markdown
Member

surprised the test can pass, because dynamic shared memory can only be allocated once. That means allocations for shared.dyn buffers must be merged. Would you mind double-checking the allocations?

@kurisu6912 kurisu6912 changed the title Add TL_DISABLE_MERGE_SHARED_MEMORY_ALLOCATIONS pass config Add TL_DISABLE_SHARED_MEMORY_REUSE pass config May 20, 2026
…in C++

Instead of skipping MergeSharedMemoryAllocations entirely, add a disable_reuse
flag that still merges allocations into one but lays out each buffer sequentially
without lifetime-based sharing.
@kurisu6912 kurisu6912 force-pushed the add-disable-merge-smem-config branch from 56424ec to f8c4721 Compare May 21, 2026 02:14

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@testing/python/transform/test_tilelang_transform_disable_memory_reuse.py`:
- Around line 105-118: The test only checks numeric equality but must assert a
structural difference in the lowered/kernel source between disable_reuse=True
and False; call make_kernel(disable_reuse=True) and
make_kernel(disable_reuse=False) (already present as kernel_no_reuse and
kernel_with_reuse), extract their lowered kernel source or allocation metadata
(e.g., the kernel source string or allocation list exposed by your lowering
API), and assert a concrete structural change such as different shared-memory
allocation layout/offsets or total shared-memory bytes (for example, no
overlapping offsets or larger total shared-memory size when reuse is disabled).
Implement an assertion comparing the allocation entries (names/offsets/size) or
a substring in the lowered source that indicates reuse was disabled versus
enabled to ensure the pass changed the allocation strategy.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ed1d1772-b0b9-4644-be7b-cf21dba76269

📥 Commits

Reviewing files that changed from the base of the PR and between 059fe6b and 0ee4417.

📒 Files selected for processing (1)
  • testing/python/transform/test_tilelang_transform_disable_memory_reuse.py

Comment thread testing/python/transform/test_tilelang_transform_disable_memory_reuse.py Outdated
@kurisu6912 kurisu6912 force-pushed the add-disable-merge-smem-config branch from 0ee4417 to 36e660a Compare May 21, 2026 06:19
@kurisu6912 kurisu6912 force-pushed the add-disable-merge-smem-config branch from 36e660a to 93861c5 Compare May 21, 2026 06:21
@LeiWang1999 LeiWang1999 merged commit b271a69 into tile-ai:main May 21, 2026
4 of 6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants