Summary (free text before the first "diff --git" header; not part of any hunk):

* src/s_tir/transform/inject_permuted_layout.cc
  - Add PermutedLayoutInjector::IsOp, which accepts a call when its op is the
    given builtin Op or when its Op node name equals the given dotted name.
  - VisitExpr_(const CallNode*) now rewrites only ptx_ldmatrix_legacy and
    mma_store_legacy calls. The up-front "neither op" guard and the trailing
    else/throw branch are removed; any other call falls through to a plain
    `return call;`.
* tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
  - scalar_add (Before/Expected) is declared with @T.prim_func(s_tir=True).
* tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py
  - Expected functions use T.alloc_buffer(..., annotations={...}) where they
    previously used T.decl_buffer(...).

diff --git a/src/s_tir/transform/inject_permuted_layout.cc b/src/s_tir/transform/inject_permuted_layout.cc
index fe90f38cec67..74e843a6e5d0 100644
--- a/src/s_tir/transform/inject_permuted_layout.cc
+++ b/src/s_tir/transform/inject_permuted_layout.cc
@@ -246,6 +246,17 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
     return access_ptr_call;
   }
 
+  // Device intrinsics are registered under both a flat name (the builtin Op)
+  // and a canonical dotted name (emitted by TVMScript and the tensor
+  // intrinsics), so compare against both.
+  static bool IsOp(const Call& call, const Op& compat_op, const char* canonical_name) {
+    if (call->op.same_as(compat_op)) {
+      return true;
+    }
+    const auto* op_node = call->op.as<OpNode>();
+    return op_node != nullptr && op_node->name == canonical_name;
+  }
+
   PrimExpr VisitExpr_(const CallNode* op) final {
     // Rewrite from/to shared or shared.dyn to/from local
     auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
@@ -254,12 +265,12 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
       return call;
     }
 
-    if (!call->op.same_as(builtin::ptx_ldmatrix()) && !call->op.same_as(builtin::mma_store())) {
-      return call;
-    }
-
-    if (call->op.same_as(builtin::ptx_ldmatrix())) {
-      // form: T.ptx_ldmatrix(..., smem_ptr, smem_offset)
+    // Only the legacy intrinsic forms fold the shared memory access into a
+    // tvm_access_ptr + offset, which must be rewritten here. The non-legacy
+    // forms address shared memory through BufferLoad (e.g. via address_of),
+    // which is already handled by the BufferLoad visitor above.
+    if (IsOp(call, builtin::ptx_ldmatrix_legacy(), "tirx.ptx.ldmatrix_legacy")) {
+      // form: T.ptx.ldmatrix_legacy(..., smem_ptr, smem_offset)
       // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
       auto access_ptr = call->args[5];
       PrimExpr smem_offset = call->args[6];
@@ -268,7 +279,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
       new_call->args.Set(5, new_access_ptr);
       new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
       return call;
-    } else if (call->op.same_as(builtin::mma_store())) {
+    } else if (IsOp(call, builtin::mma_store_legacy(), "tirx.mma_store_legacy")) {
       // TODO(yixin): mma_store is not fully tested yet
       // because we will directly store result to Buffer instead of calling mma_store now
       auto access_ptr = call->args[2];
@@ -276,9 +287,8 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
       auto new_call = call.CopyOnWrite();
       new_call->args.Set(2, new_access_ptr);
       return call;
-    } else {
-      TVM_FFI_THROW(InternalError) << "Invalid call node: " << call;
     }
+    return call;
   }
 
   static constexpr size_t VECTORIZE_FACTOR = 8;
diff --git a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
index 891ba3f20869..875fe1818247 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
@@ -575,14 +575,14 @@ def test_scalar_block_no_loops():
     # fmt: off
     @tvm.script.ir_module
     class Before:
-        @T.prim_func
+        @T.prim_func(s_tir=True)
         def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"), c: T.Buffer((), "float32")):
             with T.sblock("scalar_add"):
                 c[()] = a[()] + b[()]
 
     @tvm.script.ir_module
     class Expected:
-        @T.prim_func
+        @T.prim_func(s_tir=True)
         def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"), c: T.Buffer((), "float32")):
             T.func_attr({"tirx.is_scheduled": True})
             # with T.sblock("root"):
diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py
index 62ad915a575d..441074128e6d 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py
@@ -56,7 +56,11 @@ def transformed_elementwise_func(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (16, 16), "float32")
     C = T.match_buffer(c, (16, 16), "float32")
     for i in T.serial(0, 16):
-        B_new = T.decl_buffer(shape=[1, 16], dtype="float32")
+        B_new = T.alloc_buffer(
+            [1, 16],
+            "float32",
+            annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
+        )
         for j in T.serial(0, 16):
             B_new[0, j] = A[i, j] + 1.0
         for j in T.serial(0, 16):
@@ -98,7 +102,12 @@ def transformed_gpu_func(a: T.handle, c: T.handle) -> None:
     T.launch_thread(i0, 4)
     T.launch_thread(i1, 2)
     T.launch_thread(i2, 2)
-    B = T.decl_buffer(shape=[1, 16], dtype="float32", scope="local")
+    B = T.alloc_buffer(
+        [1, 16],
+        "float32",
+        scope="local",
+        annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
+    )
     for j in range(0, 16):
         B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
     for j in range(0, 16):
@@ -133,7 +142,11 @@ def transformed_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32)
     C = T.match_buffer(c, (n, m), "float32")
 
     for i in range(0, n):
-        B = T.decl_buffer(shape=[m], dtype="float32")
+        B = T.alloc_buffer(
+            [m],
+            "float32",
+            annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
+        )
         for j in range(0, m):
             B[j] = A[i, j] + 1.0
         for j in range(0, m):
@@ -206,8 +219,16 @@ def transformed_multi_alloc_func(a: T.handle, d: T.handle) -> None:
     D = T.match_buffer(d, (32), "float32")
 
     for i in range(0, 32):
-        B = T.decl_buffer(shape=(32,), dtype="float32")
-        C = T.decl_buffer(shape=(32,), dtype="float32")
+        B = T.alloc_buffer(
+            (32,),
+            "float32",
+            annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
+        )
+        C = T.alloc_buffer(
+            (32,),
+            "float32",
+            annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
+        )
         B[i] = A[i] + 1.0
         C[i] = A[i] + B[i]
         D[i] = C[i] * 2.0
@@ -242,7 +263,12 @@ def transformed_strided_buffer_func(
 ) -> None:
     # body
     for i0 in T.serial(4):
-        B = T.decl_buffer(shape=[4, 16], dtype="float32", strides=[17, 1])
+        B = T.alloc_buffer(
+            [4, 16],
+            "float32",
+            strides=[17, 1],
+            annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
+        )
         for i1, j in T.grid(4, 16):
             B[i1, j] = A[i0 * 4 + i1, j] + T.float32(1)
         for i1, j in T.grid(4, 16):
@@ -275,10 +301,11 @@ def transformed_symbolic_strided_buffer_func(a: T.handle):
     n = T.int32()
     A = T.match_buffer(a, (1, n, 10240))
     for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
-        A_pad_shared_dyn = T.decl_buffer(
+        A_pad_shared_dyn = T.alloc_buffer(
             (1, T.min((n + 63) // 64 * 64, 96), 64),
             strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1),
             scope="shared.dyn",
+            annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
         )
         for ax0, ax1 in T.grid(96, 64):
             if i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64: