diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 04b7dd5ea2af..a9a373151db9 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -101,13 +101,16 @@ bool ProducerCoversConsumer(const Array& buffer_shape, if (produced_region[i].IsNothing()) { return false; } - arith::IntSet produced = arith::Intersect({produced_region[i], buffer_size}); - arith::IntSet consumed = arith::Intersect({consumed_region[i], buffer_size}); - PrimExpr produced_min = analyzer->Simplify(produced.min()); - PrimExpr produced_max = analyzer->Simplify(produced.max()); - PrimExpr consumed_min = analyzer->Simplify(consumed.min()); - PrimExpr consumed_max = analyzer->Simplify(consumed.max()); - if (!analyzer->CanProve((produced_min <= consumed_min) && (consumed_max <= produced_max))) { + arith::IntSet produced = + arith::IntSet::Interval(analyzer->canonical_simplify(produced_region[i].min()), + analyzer->canonical_simplify(produced_region[i].max())); + arith::IntSet consumed = + arith::IntSet::Interval(analyzer->canonical_simplify(consumed_region[i].min()), + analyzer->canonical_simplify(consumed_region[i].max())); + produced = arith::Intersect({produced, buffer_size}); + consumed = arith::Intersect({consumed, buffer_size}); + if (!analyzer->CanProve((analyzer->canonical_simplify(produced.min() - consumed.min()) <= 0) && + (analyzer->canonical_simplify(consumed.max() - produced.max()) <= 0))) { return false; } } diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index e88eacdb453b..8b731404f142 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -763,7 +763,7 @@ def test_non_perfect_tiling_cache(): ) assert s._get_cached_flags(_get_block(s, "compute")) == CachedFlags( affine_binding=True, - region_cover=False, + region_cover=True, stage_pipeline=True, ) # pylint: enable=protected-access