diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index 2ea641a2cbd4..d54be8a05fdc 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -214,6 +214,42 @@ class OpaqueAccessError : public ScheduleError {
   Block scope_root_;
 };
 
+/*!
+ * \brief Error thrown by ReverseComputeInliner when the predicate of the producer block cannot
+ * be proven from the predicate synthesized for the new inlined block.
+ */
+class ProducerHasNonTrivialPredicateError : public ScheduleError {
+ public:
+  /*!
+   * \brief Constructor
+   * \param mod The module reported by mod() when the error is rendered
+   * \param producer The BlockRealize of the producer block, whose predicate is printed
+   * \param new_predicate The predicate synthesized for the new inlined block
+   */
+  explicit ProducerHasNonTrivialPredicateError(IRModule mod, BlockRealize producer,
+                                               PrimExpr new_predicate)
+      : mod_(mod), producer_(producer), new_predicate_(new_predicate) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The producer block has a non-trivial predicate.";
+  }
+
+  String DetailRenderTemplate() const final {
+    return "ScheduleError: The producer block {0} has a non-trivial predicate " +
+           PrettyPrint(producer_->predicate) +
+           " that cannot be implied "
+           "by the synthesized predicate " +
+           PrettyPrint(new_predicate_) + " of the new inlined block.";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {producer_}; }
+
+  IRModule mod_;
+  BlockRealize producer_;
+  PrimExpr new_predicate_;
+};
+
 /*!
  * \brief The base class of the inliner, which handles:
  * 1) Substitute a subtree with the specific block being inlined
@@ -533,10 +569,11 @@ class ReverseComputeInliner : public BaseInliner {
  public:
   explicit ReverseComputeInliner(const Buffer& inlined_buffer, const BlockNode* producer_block,
                                  const BlockRealize& consumer_block_realize,
-                                 const StmtSRef& scope_root_sref)
+                                 const StmtSRef& scope_root_sref, const IRModule& mod)
       : BaseInliner(inlined_buffer, consumer_block_realize->block, scope_root_sref),
         producer_block_(producer_block),
-        consumer_block_(consumer_block_realize->block.get()) {
+        consumer_block_(consumer_block_realize->block.get()),
+        mod_(mod) {
     // Initialize the predicates to ensure consumer block iters are in-bound
     consumer_iter_in_bound_ = Bool(true);
     for (const IterVar& iter : consumer_block_realize->block->iter_vars) {
@@ -632,8 +669,15 @@ class ReverseComputeInliner : public BaseInliner {
   Stmt VisitStmt_(const BlockRealizeNode* op) final {
     BlockRealize new_block_realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
     if (op->block.get() == producer_block_) {
-      new_block_realize.CopyOnWrite()->predicate =
-          BuildInlinedConsumerPredicate(new_block_realize.get());
+      auto new_predicate = BuildInlinedConsumerPredicate(new_block_realize.get());
+
+      With<arith::ConstraintContext> ctx(&analyzer_, new_predicate);
+      if (!analyzer_.CanProve(op->predicate)) {
+        // We do not allow cases where the new predicate for the inlined block cannot
+        // imply the original predicate in the producer block.
+        throw ProducerHasNonTrivialPredicateError(mod_, GetRef<BlockRealize>(op), new_predicate);
+      }
+      new_block_realize.CopyOnWrite()->predicate = new_predicate;
     }
     return std::move(new_block_realize);
   }
@@ -749,6 +793,8 @@ class ReverseComputeInliner : public BaseInliner {
   PrimExpr consumer_iter_in_bound_{nullptr};
   /*! \brief The arithmetic analyzer */
   arith::Analyzer analyzer_;
+  /*! \brief The target module, only used for error reporting. */
+  const IRModule& mod_;
 };
 
 void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
@@ -814,7 +860,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
   NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref);
   // Step 4. Analyze the block body
   ReverseComputeInliner inliner(inlined_buffer, producer_block_sref->StmtAs<BlockNode>(),
-                                consumer_block_realize, scope_root_sref);
+                                consumer_block_realize, scope_root_sref, self->mod);
   if (!inliner.BodyPatternAllowInline(consumer_block_realize)) {
     throw BodyAnalysisError(true, self->mod, consumer_block);
   }
diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py
index 7e361d2c095c..c8e6bf6a0c73 100644
--- a/tests/python/unittest/test_meta_schedule_trace_apply.py
+++ b/tests/python/unittest/test_meta_schedule_trace_apply.py
@@ -3037,8 +3037,8 @@ def test_inline_order():
     # reverse-inlined at the very end of ScheduleUsingAnchorTrace, where its producer block
     # "conv2d_nhwc_reindex_shared" has the predicate
     # T.where(((ax1_0 * 4 + ax1_1) * 32 + ax1_2) * 2 + ax1_3 < 64) due to anchor-block scheduling
-    # (see Conv2dInt8_with_predicate_scheduled). Currently, if we try to reverse-inline a block to
-    # its producer that has a predicate, the predicate disappears after reverse inlining.
+    # (see Conv2dInt8_with_predicate_scheduled). ReverseComputeInline cannot be applied in
+    # such cases.
def apply_trace(sch: Schedule) -> None: b0 = sch.get_block(name="pad_temp", func_name="main") diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index 20eafabc7a22..f9c5e22e97ce 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -626,6 +626,158 @@ def elementwise_producer_not_cover_consumer( D[vi, vj] = T.if_then_else(vi >= 128, B[vi - 128, vj], T.float32(0), dtype="float32") +@T.prim_func +def elementwise_predicate_producer(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((127, 128)) + C = T.match_buffer(c, (127, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.where(i < 127) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(127, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + +@T.prim_func +def elementwise_predicate_producer_inlined(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (127, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + T.where(i < 127) + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = A[vi, vj] * T.float32(2) + T.float32(1) + + +# fmt: off +@tvm.script.ir_module +class Conv2dInt8_TensorCore_with_predicate: + @T.prim_func + def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), "int8"], p2: T.Buffer[(1, 1, 1, 256), "int32"], p3: T.Buffer[(1, 1, 1, 256), "int32"], p4: T.Buffer[256, "int32"], p5: T.Buffer[256, "int32"], p6: T.Buffer[256, "int32"], p7: T.Buffer[(), "int32"], p8: T.Buffer[1, "int32"], p9: T.Buffer[(16, 56, 56, 256), "int32"], compute: T.Buffer[(16, 56, 56, 256), "int32"]): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + with T.block("root"): + T.reads() + 
T.writes() + T.block_attr({"meta_schedule.unroll_explicit":1024}) + compute_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") + conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256], dtype="int32", scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([50176, 256], dtype="int32", scope="wmma.accumulator") + pad_temp_reindex_shared = T.alloc_buffer([50176, 64], dtype="int8", scope="shared") + p1_reindex_shared = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="shared") + pad_temp_reindex_shared_wmma_matrix_a = T.alloc_buffer([50176, 64], dtype="int8", scope="wmma.matrix_a") + p1_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="wmma.matrix_b") + for ax2_0_0_ax3_0_0_fused in T.thread_binding(32, thread="blockIdx.y"): + for ax2_0_1_ax3_0_1_fused in T.thread_binding(196, thread="blockIdx.x"): + for ax2_0_2_ax3_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): + for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 2): + for ax0_ax1_fused in T.serial(1024): + with T.block("pad_temp_reindex_shared"): + v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0_ax1_fused // 32) + v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32) + T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) + T.writes(pad_temp_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 16]], "meta_schedule.cooperative_fetch":4}) + pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] + for ax0_ax1_ax2_ax3_fused in T.serial(2048): + with T.block("p1_reindex_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1, 0) + v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + ax0_ax1_ax2_ax3_fused // 32) + v3 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused % 32) + T.reads(p1[v2, v0, v1, v3]) + T.writes(p1_reindex_shared[v0, v1, v2, v3]) + T.block_attr({"buffer_dim_align":[[0, 2, 32, 16]], "meta_schedule.cooperative_fetch":3}) + 
p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] + for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 2): + for ax0_0_1, ax1_0_1 in T.grid(1, 1): + with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) + v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) + T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a"}) + for ax0_1_1, ax1_1_1 in T.grid(16, 16): + with T.block("pad_temp_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) + T.reads(pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 1): + with T.block("p1_reindex_shared_wmma.matrix_b_o"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1, 0) + v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax2_0) + v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) + T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans"}) + for ax2_1, ax3_1 in T.grid(16, 16): + with T.block("p1_reindex_shared_wmma.matrix_b"): + v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) + T.reads(p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + 
p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] + for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 1, 1, 2): + with T.block("conv2d_nhwc_o"): + v0 = T.axis.reduce(1, 0) + v1 = T.axis.reduce(1, 0) + v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) + v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax3_0_3 * 2 + ax3_0_4) + v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 + ax4_0_2) + T.reads(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) + with T.init(): + for ax2_1, ax3_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_init"): + v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = 0 + for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): + with T.block("conv2d_nhwc"): + v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) + 
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "int32") * T.cast(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i], "int32") + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) + v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax1_0) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_s32_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0, ax1_0, ax1_1, ax1_2, ax1_3 in T.grid(32, 1, 4, 32, 2): + with T.block("conv2d_nhwc_reindex_shared"): + T.where(((ax1_0 * 4 + ax1_1) * 32 + ax1_2) * 2 + ax1_3 < 64) + v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0) + v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + (ax1_0 * 256 + ax1_1 * 64 + ax1_2 * 2 + ax1_3)) + T.reads(p7[()], 
conv2d_nhwc_reindex_shared[v0, v1], p2[0, 0, 0, v1], p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], p8[0]) + T.writes(compute_3[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) + compute_3[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] = T.q_multiply_shift(T.max(T.min(p7[()] + T.q_multiply_shift_per_axis(conv2d_nhwc_reindex_shared[v0, v1] - p2[0, 0, 0, v1] + p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], 31, False, True, dtype="int32"), 255), 0) - p8[0], 1457846997, 31, 0, dtype="int32") + for i0_12, i1_12, i2_12, i3_12 in T.grid(16, 56, 56, 256): + with T.block("compute_4"): + i0_13, i1_13, i2_13, i3_13 = T.axis.remap("SSSS", [i0_12, i1_12, i2_12, i3_12]) + T.reads(compute_3[i0_13, i1_13, i2_13, i3_13], p9[i0_13, i1_13, i2_13, i3_13]) + T.writes(compute[i0_13, i1_13, i2_13, i3_13]) + compute[i0_13, i1_13, i2_13, i3_13] = T.max(T.min(compute_3[i0_13, i1_13, i2_13, i3_13] + T.q_multiply_shift(p9[i0_13, i1_13, i2_13, i3_13], 2101000910, 31, 0, dtype="int32"), 255), 0) +# fmt: on + # pylint: enable=no-member,invalid-name,unused-variable use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) @@ -883,5 +1035,31 @@ def test_reverse_compute_inline_error_producer_not_cover_consumer(use_block_name sch.reverse_compute_inline(compute) +def test_reverse_compute_inline_producer_predicate_allowed(): + """Test a case where reverse compute inline is allowed even though the producer has a + non-trivial predicate. + """ + + sch = tir.Schedule(elementwise_predicate_producer, debug_mask="all") + sch.reverse_compute_inline(sch.get_block("C")) + tvm.ir.assert_structural_equal(elementwise_predicate_producer_inlined, sch.mod["main"]) + + +def test_reverse_compute_inline_producer_predicate_disallowed(): + """Test reverse compute inline failure when the producer has a non-trivial predicate that cannot be + implied by the synthesized predicate of the new inlined block. 
+ """ + + sch = tir.Schedule(Conv2dInt8_TensorCore_with_predicate, debug_mask="all") + + with pytest.raises(tvm.tir.ScheduleError) as e: + sch.reverse_compute_inline(sch.get_block("compute_4")) + + assert ( + "that cannot be implied by the synthesized predicate True of the new inlined block" + in str(e.value) + ) + + if __name__ == "__main__": tvm.testing.main()