Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableThreadStorageSync, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSharedMemoryReuse, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kForceLetInline, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool);
Expand Down
2 changes: 2 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";
static constexpr const char *kEnableAggressiveSharedMemoryMerge =
"tl.enable_aggressive_shared_memory_merge";
static constexpr const char *kDisableSharedMemoryReuse =
"tl.disable_shared_memory_reuse";
static constexpr const char *kDisableFastMath = "tl.disable_fast_math";
static constexpr const char *kEnableFastMath = "tl.enable_fast_math";
static constexpr const char *kPtxasRegisterUsageLevel =
Expand Down
102 changes: 91 additions & 11 deletions src/transform/merge_shared_memory_allocations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
#include <tvm/tirx/function.h>
#include <tvm/tirx/stmt.h>

namespace tvm {
namespace tl {
Expand Down Expand Up @@ -460,18 +461,91 @@ class SharedMemoryRewriter : public StmtExprMutator {
* \param stmt the statement
*/
void PlanReuse(const Stmt &stmt, bool is_dynamic = true,
bool enable_aggressive_merge = false, bool verbose = false) {
bool enable_aggressive_merge = false, bool verbose = false,
bool disable_reuse = false) {
SharedMemLinearAccessPatternFinder finder(is_dynamic,
enable_aggressive_merge, verbose);
finder(stmt);
shmem_alignment_map_ = SharedMemoryAlignmentPlanner::Plan(stmt);
// First compute liveness over the flattened schedule, then feed it into the
// arena packer.
this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_);
this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_);
if (disable_reuse) {
this->PlanSequentialLayout();
} else {
// First compute liveness over the flattened schedule, then feed it into
// the arena packer.
this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_);
this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_);
}
}

private:
/*!
* \brief Lay out all shared memory buffers sequentially without any reuse.
* Each buffer gets its own dedicated region in the merged allocation.
*/
void PlanSequentialLayout() {
buffer_byte_offsets_.clear();

if (shmem_allocs_.empty()) {
merged_alloc_size_ = make_const(DataType::Int(64), 0);
return;
}

// Sort allocations deterministically by name.
std::vector<const VarNode *> sorted_vars;
sorted_vars.reserve(shmem_allocs_.size());
for (const auto &kv : shmem_allocs_) {
sorted_vars.push_back(kv.first);
}
std::sort(sorted_vars.begin(), sorted_vars.end(),
[](const VarNode *a, const VarNode *b) {
return a->name_hint < b->name_hint;
});

DataType offset_dtype = DataType::Int(32);
PrimExpr cursor = make_const(offset_dtype, 0);
PrimExpr total_size = make_const(offset_dtype, 0);

for (const VarNode *var : sorted_vars) {
const AllocBufferNode *alloc = shmem_allocs_.at(var);
int64_t bytes_per_elem = static_cast<int64_t>(
alloc->buffer->dtype.bytes() * alloc->buffer->dtype.lanes());

DataType size_dtype = DataType::Int(32);
if (!alloc->buffer->shape.empty()) {
size_dtype = alloc->buffer->shape[0].dtype();
}
if (!size_dtype.is_int() && !size_dtype.is_uint()) {
size_dtype = DataType::Int(32);
}

PrimExpr size_expr = make_const(size_dtype, bytes_per_elem);
for (const PrimExpr &extent : alloc->buffer->shape) {
PrimExpr e = extent;
if (e.dtype() != size_dtype) {
e = cast(size_dtype, e);
}
size_expr = size_expr * e;
}

int alignment = align_bytes_;
auto align_it = shmem_alignment_map_.find(var);
if (align_it != shmem_alignment_map_.end()) {
alignment = std::max(alignment, align_it->second);
}

cursor = AlignPrimExpr(cursor, alignment);
if (size_expr.dtype() != offset_dtype) {
size_expr = cast(offset_dtype, size_expr);
}
buffer_byte_offsets_[var] = cursor;
PrimExpr buf_end = cursor + size_expr;
total_size = max(total_size, buf_end);
cursor = buf_end;
}

merged_alloc_size_ = AlignPrimExpr(total_size, align_bytes_);
}

Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tirx::attr::thread_extent && !allocated_) {
// Allocate one dynamic shared memory allocation at the beginning of
Expand Down Expand Up @@ -1495,19 +1569,24 @@ class SharedMemoryRewriter : public StmtExprMutator {

Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem,
bool enable_aggressive_merge,
int align_bytes = 16, bool verbose = false) {
int align_bytes = 16, bool verbose = false,
bool disable_reuse = false) {
AllocateCollector collector;
collector(stmt);
if (collector.dyn_shmem_allocs_.size() > 1) {
SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_, true, verbose,
align_bytes);
rewriter.PlanReuse(stmt, true, enable_aggressive_merge);
rewriter.PlanReuse(stmt, true,
disable_reuse ? false : enable_aggressive_merge, false,
disable_reuse);
stmt = rewriter(std::move(stmt));
}
if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) {
SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false,
verbose, align_bytes);
rewriter.PlanReuse(stmt, false, enable_aggressive_merge);
rewriter.PlanReuse(stmt, false,
disable_reuse ? false : enable_aggressive_merge, false,
disable_reuse);
stmt = rewriter(std::move(stmt));
}
return stmt;
Expand All @@ -1518,8 +1597,9 @@ using namespace tirx::transform;
namespace transform {

Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false,
int align_bytes = 16) {
auto pass_func = [enable_aggressive_merge, align_bytes](
int align_bytes = 16,
bool disable_reuse = false) {
auto pass_func = [enable_aggressive_merge, align_bytes, disable_reuse](
PrimFunc f, const IRModule &m, PassContext ctx) {
bool merge_static_smem =
ctx->GetConfig<Bool>("tirx.merge_static_smem", Bool(false)).value();
Expand All @@ -1529,7 +1609,7 @@ Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false,
auto *n = f.CopyOnWrite();
n->body = tl::MergeSharedMemoryAllocations(
std::move(n->body), merge_static_smem, enable_aggressive_merge,
align_bytes, debug_merge_shared_memory_allocations);
align_bytes, debug_merge_shared_memory_allocations, disable_reuse);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.MergeSharedMemoryAllocations",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""Tests for TL_DISABLE_SHARED_MEMORY_REUSE pass config.

When `TL_DISABLE_SHARED_MEMORY_REUSE` is True, shared memory allocations are still
merged into a single buffer, but each buffer gets its own dedicated region without
lifetime-based reuse (i.e., no two buffers share the same offset even if their
lifetimes don't overlap).
"""

import re

import torch
import tilelang
import tilelang.language as T
import tilelang.testing
from tilelang import PassConfigKey

N = 1024


def _make_data_integrity_kernel():

@tilelang.jit(
pass_configs={
PassConfigKey.TL_DISABLE_SHARED_MEMORY_REUSE: True,
},
)
def kernel(A, B, C, D):
A: T.Tensor[[N], T.float16]
B: T.Tensor[[N], T.float16]
C: T.Tensor[[N], T.float16]
D: T.Tensor[[N], T.float16]

with T.Kernel(1, threads=128):
a_shared = T.alloc_shared([N], T.float16)
b_shared = T.alloc_shared([N], T.float16)
c_shared = T.alloc_shared([N], T.float16)
d_frag = T.alloc_fragment([N], T.float16)

T.copy(A, a_shared)
T.copy(B, b_shared)
T.copy(C, c_shared)

for i in T.Parallel(N):
d_frag[i] = a_shared[i] + b_shared[i] + c_shared[i]

T.copy(d_frag, D)

return kernel


def _make_no_overlap_kernel(disable_reuse: bool):

@tilelang.jit(
pass_configs={
PassConfigKey.TL_DISABLE_SHARED_MEMORY_REUSE: disable_reuse,
},
)
def kernel(A, B, A_out, B_out):
A: T.Tensor[[N], T.float16]
B: T.Tensor[[N], T.float16]
A_out: T.Tensor[[N], T.float16]
B_out: T.Tensor[[N], T.float16]

with T.Kernel(1, threads=128):
a_shared = T.alloc_shared([N], T.float16)
T.copy(A, a_shared)
T.copy(a_shared, A_out)

b_shared = T.alloc_shared([N], T.float16)
T.copy(B, b_shared)
T.copy(b_shared, B_out)

return kernel


@tilelang.testing.requires_cuda
def test_disable_reuse_data_integrity():
"""Allocate multiple shared buffers, copy data in, compute sum, copy out.

Verifies data integrity when shared memory reuse is disabled.
"""
kernel = _make_data_integrity_kernel()

a = torch.randn(N, device="cuda", dtype=torch.float16)
b = torch.randn(N, device="cuda", dtype=torch.float16)
c = torch.randn(N, device="cuda", dtype=torch.float16)
d = torch.empty(N, device="cuda", dtype=torch.float16)

kernel(a, b, c, d)
ref = a + b + c
torch.testing.assert_close(d, ref, rtol=1e-3, atol=1e-3)


@tilelang.testing.requires_cuda
def test_disable_reuse_no_overlap():
"""Two sequential buffers must NOT share the same offset when reuse is disabled.

a_shared is used (copy in, copy out), then b_shared is used (copy in, copy out).
Their lifetimes don't overlap, so with reuse enabled they share offset 0.
With reuse disabled, b_shared must get a different (non-zero) offset.
"""
kernel_no_reuse = _make_no_overlap_kernel(disable_reuse=True)
kernel_reuse = _make_no_overlap_kernel(disable_reuse=False)

a = torch.randn(N, device="cuda", dtype=torch.float16)
b = torch.randn(N, device="cuda", dtype=torch.float16)
a_out = torch.empty(N, device="cuda", dtype=torch.float16)
b_out = torch.empty(N, device="cuda", dtype=torch.float16)

# Correctness: both should produce correct results
kernel_no_reuse(a, b, a_out, b_out)
torch.testing.assert_close(a_out, a, rtol=0, atol=0)
torch.testing.assert_close(b_out, b, rtol=0, atol=0)

a_out2 = torch.empty(N, device="cuda", dtype=torch.float16)
b_out2 = torch.empty(N, device="cuda", dtype=torch.float16)
kernel_reuse(a, b, a_out2, b_out2)
torch.testing.assert_close(a_out2, a, rtol=0, atol=0)
torch.testing.assert_close(b_out2, b, rtol=0, atol=0)

# Verify shared memory layout difference:
# With reuse: both buffers use offset 0 in buf_dyn_shmem.
# Without reuse: second buffer gets a separate region (non-zero offset).
src_no_reuse = kernel_no_reuse.get_kernel_source()
src_reuse = kernel_reuse.get_kernel_source()

def extract_smem_element_offsets(src: str) -> list[int]:
"""Extract element offsets from buf_dyn_shmem[OFFSET] patterns in generated code."""
# Matches patterns like: buf_dyn_shmem)[1024]) used in tma_store calls
pattern = r"buf_dyn_shmem\)\[(\d+)\]"
return sorted(set(int(m) for m in re.findall(pattern, src)))

offsets_no_reuse = extract_smem_element_offsets(src_no_reuse)
offsets_reuse = extract_smem_element_offsets(src_reuse)

# With reuse disabled: must have at least 2 distinct offsets (buffers not merged)
assert len(offsets_no_reuse) >= 2, f"Expected >=2 distinct smem offsets with reuse disabled, got {offsets_no_reuse}"

# With reuse enabled: should have only 1 offset (buffers share the same region)
assert len(offsets_reuse) == 1, f"Expected 1 smem offset with reuse enabled, got {offsets_reuse}"


if __name__ == "__main__":
test_disable_reuse_data_integrity()
test_disable_reuse_no_overlap()
print("All tests passed!")
9 changes: 8 additions & 1 deletion tilelang/engine/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ def should_enable_prelower_semantic_check(pass_ctx: PassContext | None = None) -
return enabled


def should_disable_shared_memory_reuse(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
return bool(pass_ctx.config.get(tilelang.PassConfigKey.TL_DISABLE_SHARED_MEMORY_REUSE, False))


def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
Expand Down Expand Up @@ -272,7 +278,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# MergeSharedMemoryAllocations must be applied after SplitHostDevice
# because the merged allocation site is at the beginning of each device function
enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target)
mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod)
disable_reuse = should_disable_shared_memory_reuse(pass_ctx=pass_ctx)
mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge, disable_reuse=disable_reuse)(mod)
# InjectFenceProxy is a no-op on targets that lack the TMA / async-proxy
# programming model; the pass itself checks the PrimFunc's target.
mod = tilelang.transform.InjectFenceProxy()(mod)
Expand Down
4 changes: 2 additions & 2 deletions tilelang/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,15 +387,15 @@ def FlattenBuffer():
return _ffi_api.FlattenBuffer() # type: ignore


def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False, align_bytes: int = 16):
def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False, align_bytes: int = 16, disable_reuse: bool = False):
"""MergeSharedMemoryAllocations

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge, align_bytes) # type: ignore
return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge, align_bytes, disable_reuse) # type: ignore


def MarkCudaSyncCalls(have_pdl: bool = False):
Expand Down
6 changes: 6 additions & 0 deletions tilelang/transform/pass_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ class PassConfigKey(str, Enum):
TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE = "tl.enable_aggressive_shared_memory_merge"
"""Enable aggressive merge of shared memory allocations. Default: False"""

TL_DISABLE_SHARED_MEMORY_REUSE = "tl.disable_shared_memory_reuse"
"""Disable shared memory reuse planning in MergeSharedMemoryAllocations.
When enabled, shared memory allocations are still merged into a single
allocation but each buffer gets its own dedicated region without lifetime-based
reuse. Default: False"""

TL_DISABLE_SHUFFLE_ELECT = "tl.disable_shuffle_elect"
"""Disable shuffle election optimization. Default: False"""

Expand Down
Loading