Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions src/s_tir/transform/inject_permuted_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,17 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
return access_ptr_call;
}

// Device intrinsics are registered under both a flat name (the builtin Op)
// and a canonical dotted name (emitted by TVMScript and the tensor
// intrinsics), so compare against both.
static bool IsOp(const Call& call, const Op& compat_op, const char* canonical_name) {
if (call->op.same_as(compat_op)) {
return true;
}
const auto* op_node = call->op.as<OpNode>();
return op_node != nullptr && op_node->name == canonical_name;
}

PrimExpr VisitExpr_(const CallNode* op) final {
// Rewrite from/to shared or shared.dyn to/from local
auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
Comment thread
tlopex marked this conversation as resolved.
Expand All @@ -254,12 +265,12 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
return call;
}

if (!call->op.same_as(builtin::ptx_ldmatrix()) && !call->op.same_as(builtin::mma_store())) {
return call;
}

if (call->op.same_as(builtin::ptx_ldmatrix())) {
// form: T.ptx_ldmatrix(..., smem_ptr, smem_offset)
// Only the legacy intrinsic forms fold the shared memory access into a
// tvm_access_ptr + offset, which must be rewritten here. The non-legacy
// forms address shared memory through BufferLoad (e.g. via address_of),
// which is already handled by the BufferLoad visitor above.
if (IsOp(call, builtin::ptx_ldmatrix_legacy(), "tirx.ptx.ldmatrix_legacy")) {
Comment thread
tlopex marked this conversation as resolved.
// form: T.ptx.ldmatrix_legacy(..., smem_ptr, smem_offset)
// smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
auto access_ptr = call->args[5];
PrimExpr smem_offset = call->args[6];
Expand All @@ -268,17 +279,16 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
new_call->args.Set(5, new_access_ptr);
new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
return call;
} else if (call->op.same_as(builtin::mma_store())) {
} else if (IsOp(call, builtin::mma_store_legacy(), "tirx.mma_store_legacy")) {
Comment thread
tlopex marked this conversation as resolved.
// TODO(yixin): mma_store is not fully tested yet
// because we will directly store result to Buffer instead of calling mma_store now
auto access_ptr = call->args[2];
auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr);
auto new_call = call.CopyOnWrite();
new_call->args.Set(2, new_access_ptr);
return call;
} else {
TVM_FFI_THROW(InternalError) << "Invalid call node: " << call;
}
return call;
}

static constexpr size_t VECTORIZE_FACTOR = 8;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -575,14 +575,14 @@ def test_scalar_block_no_loops():
# fmt: off
@tvm.script.ir_module
class Before:
@T.prim_func
@T.prim_func(s_tir=True)
def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"), c: T.Buffer((), "float32")):
with T.sblock("scalar_add"):
c[()] = a[()] + b[()]

@tvm.script.ir_module
class Expected:
@T.prim_func
@T.prim_func(s_tir=True)
def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"), c: T.Buffer((), "float32")):
T.func_attr({"tirx.is_scheduled": True})
# with T.sblock("root"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ def transformed_elementwise_func(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
C = T.match_buffer(c, (16, 16), "float32")
for i in T.serial(0, 16):
B_new = T.decl_buffer(shape=[1, 16], dtype="float32")
B_new = T.alloc_buffer(
[1, 16],
"float32",
annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
)
for j in T.serial(0, 16):
B_new[0, j] = A[i, j] + 1.0
for j in T.serial(0, 16):
Expand Down Expand Up @@ -98,7 +102,12 @@ def transformed_gpu_func(a: T.handle, c: T.handle) -> None:
T.launch_thread(i0, 4)
T.launch_thread(i1, 2)
T.launch_thread(i2, 2)
B = T.decl_buffer(shape=[1, 16], dtype="float32", scope="local")
B = T.alloc_buffer(
[1, 16],
"float32",
scope="local",
annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
)
for j in range(0, 16):
B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
for j in range(0, 16):
Expand Down Expand Up @@ -133,7 +142,11 @@ def transformed_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32)
C = T.match_buffer(c, (n, m), "float32")

for i in range(0, n):
B = T.decl_buffer(shape=[m], dtype="float32")
B = T.alloc_buffer(
[m],
"float32",
annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
)
for j in range(0, m):
B[j] = A[i, j] + 1.0
for j in range(0, m):
Expand Down Expand Up @@ -206,8 +219,16 @@ def transformed_multi_alloc_func(a: T.handle, d: T.handle) -> None:
D = T.match_buffer(d, (32), "float32")

for i in range(0, 32):
B = T.decl_buffer(shape=(32,), dtype="float32")
C = T.decl_buffer(shape=(32,), dtype="float32")
B = T.alloc_buffer(
(32,),
"float32",
annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
)
C = T.alloc_buffer(
(32,),
"float32",
annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
)
B[i] = A[i] + 1.0
C[i] = A[i] + B[i]
D[i] = C[i] * 2.0
Expand Down Expand Up @@ -242,7 +263,12 @@ def transformed_strided_buffer_func(
) -> None:
# body
for i0 in T.serial(4):
B = T.decl_buffer(shape=[4, 16], dtype="float32", strides=[17, 1])
B = T.alloc_buffer(
[4, 16],
"float32",
strides=[17, 1],
annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
)
for i1, j in T.grid(4, 16):
B[i1, j] = A[i0 * 4 + i1, j] + T.float32(1)
for i1, j in T.grid(4, 16):
Expand Down Expand Up @@ -275,10 +301,11 @@ def transformed_symbolic_strided_buffer_func(a: T.handle):
n = T.int32()
A = T.match_buffer(a, (1, n, 10240))
for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
A_pad_shared_dyn = T.decl_buffer(
A_pad_shared_dyn = T.alloc_buffer(
(1, T.min((n + 63) // 64 * 64, 96), 64),
strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1),
scope="shared.dyn",
annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
)
for ax0, ax1 in T.grid(96, 64):
if i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64:
Expand Down
Loading