From f664ac0de6aa612f4f00363b6c534db9b778d595 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 6 Dec 2022 14:14:58 -0800 Subject: [PATCH 1/2] Add DecomposePoint annotation to hint RewriteReductionBlock --- include/tvm/tir/stmt.h | 5 +++ .../postproc/rewrite_reduction_block.cc | 20 ++++++--- ...hedule_postproc_rewrite_reduction_block.py | 42 +++++++++++++++++++ 3 files changed, 61 insertions(+), 6 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 5beea44cdb1a..c813591010dd 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1597,6 +1597,11 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_expl /*! \brief Mark auto-unroll setting on the block. */ constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit"; +/*! \brief Mark the target loop for decompose reduction. This serves as a hint for + * RewriteReductionBlock postprocessor. + */ +constexpr const char* meta_schedule_decompose_point = "meta_schedule.decompose_point"; + /*! \brief Mark that a block should be further rewritten using tensorization. */ constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize"; diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index 05a7640f047c..ce3bd925485e 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -86,18 +86,23 @@ struct ReductionBlockFinder : private StmtVisitor { /*! * \brief Find the innermost loop that the `init` of the input block could be decomposed to * \param block_sref The StmtSRef of the block to be decomposed - * \return The index of the innermost loop where the `init` of the input block could be decomposed, - * or -1 if the `init` does not need to be decomposed. + * \return A pair of the loop index and the boolean flag. 
The index indicates the innermost loop + * where the `init` of the input block could be decomposed, or -1 if the `init` does not need to + * be decomposed. The boolean flag indicates whether there is a hint of the decompose point in the + * loop annotations that needs to be unannotated. */ -int FindDecomposePoint(const StmtSRef& block_sref) { +std::pair FindDecomposePoint(const StmtSRef& block_sref) { Array loop_srefs = GetLoops(block_sref); int n = loop_srefs.size(); for (int i = 0; i < n; ++i) { + if (HasAnn(loop_srefs[i], attr::meta_schedule_decompose_point, true)) { + return {i, true}; + } if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) { - return i; + return {i, false}; } } - return -1; + return {-1, false}; } } // namespace tir @@ -133,7 +138,7 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { for (const auto& kv : results) { const tir::StmtSRef& block_sref = kv.first; const String& global_var_name = kv.second; - int decompose_point = tir::FindDecomposePoint(block_sref); + auto [decompose_point, need_unannotate] = tir::FindDecomposePoint(block_sref); if (decompose_point == -1) { continue; } @@ -141,6 +146,9 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { Array loop_rvs = sch->GetLoops(block_rv); tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); + if (need_unannotate) { + sch->Unannotate(loop_rvs[decompose_point], tir::attr::meta_schedule_decompose_point); + } // Rewrite auto tensorization related annotations if (tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize).defined()) { // Remove tensorization annotation as it shouldn't be propagated to the init block. 
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py index 7e499424058d..87961436d510 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py @@ -218,6 +218,48 @@ def test_rewrite_softmax(): tvm.ir.assert_structural_equal(sch.mod, Softmax_cross_thread_reduction) +def test_rewrite_unit_reduction_loop(): + # fmt: off + # pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + @T.prim_func + def before(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for i0, j0 in T.grid(4, 4): + for k0 in T.serial(1, annotations={"meta_schedule.decompose_point": 1}): + for i1, j1, k1 in T.grid(4, 4, 16): + with T.block("C"): + i = T.axis.spatial(16, i0 * 4 + i1) + j = T.axis.spatial(16, j0 * 4 + j1) + k = T.axis.reduce(16, k1) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @T.prim_func + def expected(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for i0, j0 in T.grid(4, 4): + for i1_init, j1_init in T.grid(4, 4): + with T.block("C_init"): + i = T.axis.spatial(16, i0 * 4 + i1_init) + j = T.axis.spatial(16, j0 * 4 + j1_init) + C[i, j] = T.float32(0) + for k0, i1, j1, k1 in T.grid(1, 4, 4, 16): + with T.block("C_update"): + i = T.axis.spatial(16, i0 * 4 + i1) + j = T.axis.spatial(16, j0 * 4 + j1) + k = T.axis.reduce(16, k1) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + # pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + # fmt: on + + target = _target() + ctx = _create_context(before, target) + sch = tir.Schedule(before, 
debug_mask="all") + sch.enter_postproc() + print(tvm.ir.base.get_first_structural_mismatch(sch.mod, expected)) + tvm.ir.assert_structural_equal(sch.mod["main"], expected) + + if __name__ == "__main__": test_rewrite_tiled_matmul() test_rewrite_softmax() From 2c195669a4d6a07781b4fd42688e394ddf236397 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 6 Dec 2022 14:52:36 -0800 Subject: [PATCH 2/2] Add DecomposePoint hint in MultiLevelTiling --- .../multi_level_tiling_tensor_core.cc | 14 ++ ...hedule_postproc_rewrite_reduction_block.py | 1 - ...test_meta_schedule_schedule_rule_mlt_tc.py | 151 +++++++++--------- 3 files changed, 90 insertions(+), 76 deletions(-) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index d5cca52d41f9..6f656bc062a8 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -130,6 +130,8 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { inline std::vector AddWriteReuseTensorCore(TensorCoreState state) const; // Subrule: Add software pipeline inline std::vector AddSoftwarePipeline(TensorCoreState state) const; + // Subrule: Add annotation of outermost reduction loop + inline std::vector AnnotateOutermostReduction(State state) const; // Override ApplySubRules to apply tensorization-specific sub-rules std::vector ApplySubRules(std::vector states) final; @@ -225,6 +227,8 @@ std::vector MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector(state)); }); states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); + states = + SubRule(std::move(states), [&](State state) { return AnnotateOutermostReduction(state); }); states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); states = SubRule(std::move(states), [&](State state) { return AddWriteReuseTensorCore(Downcast(state)); @@ 
-556,6 +560,16 @@ inline std::vector MultiLevelTilingTensorCoreNode::TransformForTensorizat return {std::move(state)}; } +inline std::vector MultiLevelTilingTensorCoreNode::AnnotateOutermostReduction( + State state) const { + Schedule& sch = state->sch; + if (r_indices_.size()) { + LoopRV outermost_reduction_loop = state->tiles[r_indices_.front()].front(); + sch->Annotate(outermost_reduction_loop, tir::attr::meta_schedule_decompose_point, Integer(1)); + } + return {std::move(state)}; +} + ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( Array> intrin_groups, String structure, Optional> tile_binds, Optional max_innermost_factor, Optional> vector_load_lens, diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py index 87961436d510..78dfa9b7606a 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py @@ -256,7 +256,6 @@ def expected(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"], ctx = _create_context(before, target) sch = tir.Schedule(before, debug_mask="all") sch.enter_postproc() - print(tvm.ir.base.get_first_structural_mismatch(sch.mod, expected)) tvm.ir.assert_structural_equal(sch.mod["main"], expected) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py index acc626b904a1..c681b477336b 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -83,7 +83,7 @@ def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "f for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, 
thread="threadIdx.y"): - for ax2_0_0 in T.serial(1): + for ax2_0_0 in T.serial(1, annotations={"meta_schedule.decompose_point": 1}): for ax0_ax1_fused in T.serial(4096): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128) @@ -225,7 +225,7 @@ def matmul_relu_fallback_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): - for ax2_0_0 in T.serial(2): + for ax2_0_0 in T.serial(2, annotations={"meta_schedule.decompose_point": 1}): for ax0_ax1_fused in T.serial(2048): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 64) @@ -376,7 +376,7 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): - for ax2_0_0 in T.serial(1): + for ax2_0_0 in T.serial(1, annotations={"meta_schedule.decompose_point": 1}): for ax0_ax1_fused in T.serial(4608): with T.block("PadInput_reindex_shared"): v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 288) @@ -539,7 +539,7 @@ def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): - for ax2_0_0 in T.serial(4, annotations={"software_pipeline_order":[0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage":[0, 0, 0, 0, 0, 1, 1]}): + for ax2_0_0 in T.serial(4, annotations={"software_pipeline_order":[0, 3, 1, 
4, 5, 2, 6], "software_pipeline_stage":[0, 0, 0, 0, 0, 1, 1], "meta_schedule.decompose_point": 1}): for ax0_ax1_fused in T.serial(1024): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0_ax1_fused // 32) @@ -686,7 +686,7 @@ def matmul_relu_global_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 1 for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): - for ax2_0_0 in T.serial(2): + for ax2_0_0 in T.serial(2, annotations={"meta_schedule.decompose_point": 1}): for ax0_ax1_fused in T.serial(8192): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_ax1_fused // 64) @@ -841,7 +841,7 @@ def padded_matmul_relu_0(A: T.Buffer[(127, 127), "float16"], B: T.Buffer[(127, 1 for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): - for ax2_0_0 in T.serial(1): + for ax2_0_0 in T.serial(1, annotations={"meta_schedule.decompose_point": 1}): for ax0_ax1_fused in T.serial(4096): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128) @@ -979,78 +979,79 @@ def conv2d_1x1_0(inputs: T.Buffer[(1, 16, 16, 64), "float16"], weight: T.Buffer[ for ax2_0_0_ax3_0_0_fused in T.thread_binding(16, thread="blockIdx.y"): for ax2_0_1_ax3_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): - for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 1): - for ax0_ax1_fused in T.serial(1024): - with T.block("PadInput_reindex_shared"): - v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 64) - v1 = T.axis.spatial(64, ax0_ax1_fused % 64) - T.reads(inputs[v0 // 
256, v0 // 16, v0 % 16, v1]) - T.writes(PadInput_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) - PadInput_reindex_shared[v0, v1] = inputs[v0 // 256, v0 // 16, v0 % 16, v1] - for ax0_ax1_ax2_ax3_fused in T.serial(2048): - with T.block("weight_reindex_shared"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(1, 0) - v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused // 32) - v3 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) - T.reads(weight[v0, v1, v2, v3]) - T.writes(weight_reindex_shared[v0, v1, v2, v3]) - T.block_attr({"buffer_dim_align":[[0, 2, 32, 8]], "meta_schedule.cooperative_fetch":4}) - weight_reindex_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] - for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): - for ax0_0_1, ax1_0_1 in T.grid(1, 4): - with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): - v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) - v1_o = T.axis.spatial(4, ax1_0_1) - T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) - for ax0_1_1, ax1_1_1 in T.grid(16, 16): - with T.block("PadInput_reindex_shared_wmma.matrix_a"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) - T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 4, 1): - with T.block("weight_reindex_shared_wmma.matrix_b_o"): + for ax0_0 in T.serial(1, annotations={"meta_schedule.decompose_point":1}): + for ax1_0, ax4_0_0 in T.grid(1, 1): + for ax0_ax1_fused in 
T.serial(1024): + with T.block("PadInput_reindex_shared"): + v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 64) + v1 = T.axis.spatial(64, ax0_ax1_fused % 64) + T.reads(inputs[v0 // 256, v0 // 16, v0 % 16, v1]) + T.writes(PadInput_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) + PadInput_reindex_shared[v0, v1] = inputs[v0 // 256, v0 // 16, v0 % 16, v1] + for ax0_ax1_ax2_ax3_fused in T.serial(2048): + with T.block("weight_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) - v2_o = T.axis.spatial(4, ax2_0) - v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused) - T.reads(weight_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) - for ax2_1, ax3_1 in T.grid(16, 16): - with T.block("weight_reindex_shared_wmma.matrix_b"): - v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) - T.reads(weight_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = weight_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] - for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 4, 1, 1): - with T.block("conv2d_nhwc_o"): - v0 = T.axis.reduce(1, 0) - v1 = T.axis.reduce(1, 0) - v2_o = T.axis.spatial(16, ax2_0_4 + ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax2_0_3) - v3_o = T.axis.spatial(4, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0_3) - v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2) - T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o 
* 16 + 16, v4_o * 16 : v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 : v4_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) - with T.init(): + v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused // 32) + v3 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) + T.reads(weight[v0, v1, v2, v3]) + T.writes(weight_reindex_shared[v0, v1, v2, v3]) + T.block_attr({"buffer_dim_align":[[0, 2, 32, 8]], "meta_schedule.cooperative_fetch":4}) + weight_reindex_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] + for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): + for ax0_0_1, ax1_0_1 in T.grid(1, 4): + with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) + v1_o = T.axis.spatial(4, ax1_0_1) + T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + for ax0_1_1, ax1_1_1 in T.grid(16, 16): + with T.block("PadInput_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) + T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 4, 1): + with T.block("weight_reindex_shared_wmma.matrix_b_o"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1, 0) + v2_o = T.axis.spatial(4, ax2_0) 
+ v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused) + T.reads(weight_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) for ax2_1, ax3_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_init"): - v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) - T.reads() - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init]) - conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = T.float32(0) - for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): - with T.block("conv2d_nhwc"): - v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i], "float32") + with T.block("weight_reindex_shared_wmma.matrix_b"): + v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) + T.reads(weight_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = 
weight_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] + for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 4, 1, 1): + with T.block("conv2d_nhwc_o"): + v0 = T.axis.reduce(1, 0) + v1 = T.axis.reduce(1, 0) + v2_o = T.axis.spatial(16, ax2_0_4 + ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax2_0_3) + v3_o = T.axis.spatial(4, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0_3) + v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2) + T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 : v4_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + with T.init(): + for ax2_1, ax3_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_init"): + v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = T.float32(0) + for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): + with T.block("conv2d_nhwc"): + v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = 
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) for ax0_0, ax1_0 in T.grid(1, 1): with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused)