diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 8f912c59ea16..fec214fa1fc7 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -340,7 +340,9 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block } auto consumers = sch->GetConsumers(block_rv); for (const auto& consumer : consumers) { - sch->ComputeInline(consumer); + auto sref = sch->GetSRef(consumer); + if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true))) + sch->ComputeInline(consumer); } } // Construct a mapping from tir loops back to LoopRVs diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py index df8607e55127..1fd2ab84749e 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py @@ -1055,5 +1055,157 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( ) +def test_padded_conv(): + # fmt: off + @T.prim_func + def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buffer((7, 7, 3, 64), "float16"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + conv2d_nhwc_reindex_shared = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="wmma.accumulator") + PadInput_reindex_pad_shared = T.alloc_buffer((12544, 160), "float16", scope="shared") + weight_reindex_pad_shared = T.alloc_buffer((160, 64), "float16", scope="shared") + PadInput_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((12544, 160), "float16", scope="wmma.matrix_a") + weight_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((160, 64), "float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(14, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, thread="threadIdx.y"): + for ax2_0_0 in range(10): + for ax0_ax1_fused in range(28672): + with T.block("PadInput_reindex_pad_shared"): + v0 = T.axis.spatial(12544, ax0_0_0_ax1_0_0_fused // 2 * 1792 + ax0_ax1_fused // 16) + v1 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused % 16) + T.reads(inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3]) + T.writes(PadInput_reindex_pad_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4}) + PadInput_reindex_pad_shared[v0, v1] = T.if_then_else(v1 < 147, T.if_then_else(3 <= v0 // 112 * 2 + v1 // 21 and v0 // 112 * 2 + v1 // 21 < 227 and 3 <= v0 % 112 * 2 + v1 % 21 // 3 and v0 % 112 * 2 + v1 % 21 // 3 < 227, inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3], T.float16(0)), T.float16(0)) + for ax0_ax1_fused in range(512): + with T.block("weight_reindex_pad_shared"): + v0 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused // 32) + v1 = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 2 * 32 + ax0_ax1_fused % 32) + T.reads(weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1]) + T.writes(weight_reindex_pad_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) + weight_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 147, weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1], T.float16(0)) + for ax2_0_1 in range(1): + for ax0_0, ax1_0 in T.grid(14, 1): + with T.block("PadInput_reindex_pad_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0) + v1_o = T.axis.spatial(10, ax2_0_0 + ax1_0) + T.reads(PadInput_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("PadInput_reindex_pad_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("weight_reindex_pad_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(10, ax2_0_0 + ax0_0) + v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0) + T.reads(weight_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("weight_reindex_pad_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(7, 2, 1, 2, 1): + with T.block("conv2d_nhwc_o"): + v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0_3 * 2 + ax0_0_4) + v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0_3 + ax1_0_4) + v2_o = T.axis.reduce(10, ax2_0_0 + ax2_0_1 + ax2_0_2) + T.reads(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("conv2d_nhwc"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i], PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(14): + for ax0_ax1_fused in T.thread_binding(8, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 2): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_fused) + v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2) + v2_o = T.axis.spatial(14, ax2 + ax2_1) + v3_o = T.axis.spatial(2, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(4096): + with T.block("conv2d_nhwc_reindex_shared"): + v0 = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_ax3_ax4_ax5_fused // 512) + v1 = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2) + v2 = T.axis.spatial(14, ax2) + v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused % 512 // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] + # fmt: on + + decision_0 = [ + ("SamplePerfectTile", [7, 1, 8, 7, 2]), + ("SamplePerfectTile", [2, 1, 1, 2, 1]), + ("SamplePerfectTile", [10, 1, 1]), + ("SampleCategorical", 2), + ("SampleCategorical", 2), + ("SampleCategorical", 1), + ] + mod = te.create_prim_func( + te_workload.conv2d_nhwc( + 1, + 224, + 224, + 3, + 64, + 7, + 2, + 3, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_70"), + types=None, + sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[padded_conv2d_0], + expected_decisions=[decision_0], + ) + + if __name__ == "__main__": tvm.testing.main()