From f22966202bb96a5345d3eb4ea6a8c7202295c0b2 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sat, 15 Jan 2022 22:42:36 +0800 Subject: [PATCH 1/3] canonical simplify the intset before region cover proof --- src/tir/schedule/state.cc | 17 ++++++++++------- .../test_tir_schedule_state_cached_flags.py | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 04b7dd5ea2af..ba5928713182 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -101,13 +101,16 @@ bool ProducerCoversConsumer(const Array& buffer_shape, if (produced_region[i].IsNothing()) { return false; } - arith::IntSet produced = arith::Intersect({produced_region[i], buffer_size}); - arith::IntSet consumed = arith::Intersect({consumed_region[i], buffer_size}); - PrimExpr produced_min = analyzer->Simplify(produced.min()); - PrimExpr produced_max = analyzer->Simplify(produced.max()); - PrimExpr consumed_min = analyzer->Simplify(consumed.min()); - PrimExpr consumed_max = analyzer->Simplify(consumed.max()); - if (!analyzer->CanProve((produced_min <= consumed_min) && (consumed_max <= produced_max))) { + arith::IntSet produced = + arith::IntSet::Interval(analyzer->canonical_simplify(produced_region[i].min()), + analyzer->canonical_simplify(produced_region[i].max())); + arith::IntSet consumed = + arith::IntSet::Interval(analyzer->canonical_simplify(consumed_region[i].min()), + analyzer->canonical_simplify(consumed_region[i].max())); + produced = arith::Intersect({produced, buffer_size}); + consumed = arith::Intersect({consumed, buffer_size}); + if (!analyzer->CanProve((produced.min() <= consumed.min()) && + (consumed.max() <= produced.max()))) { return false; } } diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index e88eacdb453b..8b731404f142 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ 
b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -763,7 +763,7 @@ def test_non_perfect_tiling_cache(): ) assert s._get_cached_flags(_get_block(s, "compute")) == CachedFlags( affine_binding=True, - region_cover=False, + region_cover=True, stage_pipeline=True, ) # pylint: enable=protected-access From f076822b908eac3429bf0d8a273ee0d066facbea Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sun, 16 Jan 2022 23:16:06 +0800 Subject: [PATCH 2/3] add more rewrite rules for IntImms to fix the issue after rebase --- src/arith/rewrite_simplify.cc | 20 +++++++++++++++++++ .../unittest/test_arith_rewrite_simplify.py | 20 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 4a99e10211b7..b988a646be5c 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -153,6 +153,16 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y)); TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z)); + TVM_TRY_REWRITE(min(c1, x - c2) + c3, min(c1 + c3, x + (c3 - c2))); + TVM_TRY_REWRITE(min(c1, x + c2) + c3, min(c1 + c3, x + (c3 + c2))); + TVM_TRY_REWRITE(min(x - c2, c1) + c3, min(x + (c3 - c2), c1 + c3)); + TVM_TRY_REWRITE(min(x + c2, c1) + c3, min(x + (c3 + c2), c1 + c3)); + + TVM_TRY_REWRITE(max(c1, x - c2) + c3, max(c1 + c3, x + (c3 - c2))); + TVM_TRY_REWRITE(max(c1, x + c2) + c3, max(c1 + c3, x + (c3 + c2))); + TVM_TRY_REWRITE(max(x - c2, c1) + c3, max(x + (c3 - c2), c1 + c3)); + TVM_TRY_REWRITE(max(x + c2, c1) + c3, max(x + (c3 + c2), c1 + c3)); + TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y), c1.Eval()->value == -c2.Eval()->value); TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y), @@ -296,6 +306,16 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y)); TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y)); + 
TVM_TRY_REWRITE(min(c1, x - c2) - c3, min(c1 - c3, x - (c2 + c3))); + TVM_TRY_REWRITE(min(c1, x + c2) - c3, min(c1 - c3, x + (c2 - c3))); + TVM_TRY_REWRITE(min(x - c2, c1) - c3, min(x - (c2 + c3), c1 - c3)); + TVM_TRY_REWRITE(min(x + c2, c1) - c3, min(x + (c2 - c3), c1 - c3)); + + TVM_TRY_REWRITE(max(c1, x - c2) - c3, max(c1 - c3, x - (c2 + c3))); + TVM_TRY_REWRITE(max(c1, x + c2) - c3, max(c1 - c3, x + (c2 - c3))); + TVM_TRY_REWRITE(max(x - c2, c1) - c3, max(x - (c2 + c3), c1 - c3)); + TVM_TRY_REWRITE(max(x + c2, c1) - c3, max(x + (c2 - c3), c1 - c3)); + TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z)); TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z)); TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y)); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 6ca2a2a5fcb0..326dc4c15194 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -239,6 +239,16 @@ def test_add_index_simplify(): ck.verify(tvm.te.min(x, y + 2) + (-2), tvm.te.min(x + (-2), y)) ck.verify(tvm.te.min(x + 2, y + 3) + (-2), tvm.te.min(x, y + 1)) + ck.verify(tvm.te.min(1, x + 3) + 1, tvm.te.min(2, x + 4)) + ck.verify(tvm.te.min(1, x - 3) + 1, tvm.te.min(2, x + -2)) + ck.verify(tvm.te.min(x + 3, 1) + 1, tvm.te.min(x + 4, 2)) + ck.verify(tvm.te.min(x - 3, 1) + 1, tvm.te.min(x + -2, 2)) + + ck.verify(tvm.te.max(1, x + 3) + 1, tvm.te.max(2, x + 4)) + ck.verify(tvm.te.max(1, x - 3) + 1, tvm.te.max(2, x + -2)) + ck.verify(tvm.te.max(x + 3, 1) + 1, tvm.te.max(x + 4, 2)) + ck.verify(tvm.te.max(x - 3, 1) + 1, tvm.te.max(x + -2, 2)) + ck.verify(tvm.te.max(0, 1 - x * 4) + x * 4, tvm.te.max(x * 4, 1)) ck.verify(tvm.te.max(2 - x * 4, 0) + x * 4, tvm.te.max(x * 4, 2)) @@ -326,6 +336,16 @@ def test_sub_index_simplify(): ck.verify(tvm.te.max(z, x + y) - x, tvm.te.max(z - x, y)) ck.verify(tvm.te.max(z, y + x) - x, tvm.te.max(z - x, y)) + 
ck.verify(tvm.te.min(1, x + 3) - 1, tvm.te.min(0, x + 2)) + ck.verify(tvm.te.min(1, x - 3) - 1, tvm.te.min(0, x + -4)) + ck.verify(tvm.te.min(x + 3, 1) - 1, tvm.te.min(x + 2, 0)) + ck.verify(tvm.te.min(x - 3, 1) - 1, tvm.te.min(x + -4, 0)) + + ck.verify(tvm.te.max(1, x + 3) - 1, tvm.te.max(0, x + 2)) + ck.verify(tvm.te.max(1, x - 3) - 1, tvm.te.max(0, x + -4)) + ck.verify(tvm.te.max(x + 3, 1) - 1, tvm.te.max(x + 2, 0)) + ck.verify(tvm.te.max(x - 3, 1) - 1, tvm.te.max(x + -4, 0)) + ck.verify(x - tvm.te.min(x + y, z), tvm.te.max(0 - y, x - z)) ck.verify(x - tvm.te.min(y + x, z), tvm.te.max(0 - y, x - z)) ck.verify(x - tvm.te.min(z, x + y), tvm.te.max(x - z, 0 - y)) From 2fca571f9e98f346e8cd054f97b258e42c14df67 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sun, 23 Jan 2022 15:17:10 +0800 Subject: [PATCH 3/3] remove rewrite rules; existing rules already work after canonical simplify --- src/arith/rewrite_simplify.cc | 20 ------------------- src/tir/schedule/state.cc | 4 ++-- .../unittest/test_arith_rewrite_simplify.py | 20 ------------------- 3 files changed, 2 insertions(+), 42 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index b988a646be5c..4a99e10211b7 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -153,16 +153,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y)); TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z)); - TVM_TRY_REWRITE(min(c1, x - c2) + c3, min(c1 + c3, x + (c3 - c2))); - TVM_TRY_REWRITE(min(c1, x + c2) + c3, min(c1 + c3, x + (c3 + c2))); - TVM_TRY_REWRITE(min(x - c2, c1) + c3, min(x + (c3 - c2), c1 + c3)); - TVM_TRY_REWRITE(min(x + c2, c1) + c3, min(x + (c3 + c2), c1 + c3)); - - TVM_TRY_REWRITE(max(c1, x - c2) + c3, max(c1 + c3, x + (c3 - c2))); - TVM_TRY_REWRITE(max(c1, x + c2) + c3, max(c1 + c3, x + (c3 + c2))); - TVM_TRY_REWRITE(max(x - c2, c1) + c3, max(x + (c3 - c2), c1 + c3)); - 
TVM_TRY_REWRITE(max(x + c2, c1) + c3, max(x + (c3 + c2), c1 + c3)); - TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y), c1.Eval()->value == -c2.Eval()->value); TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y), @@ -306,16 +296,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y)); TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y)); - TVM_TRY_REWRITE(min(c1, x - c2) - c3, min(c1 - c3, x - (c2 + c3))); - TVM_TRY_REWRITE(min(c1, x + c2) - c3, min(c1 - c3, x + (c2 - c3))); - TVM_TRY_REWRITE(min(x - c2, c1) - c3, min(x - (c2 + c3), c1 - c3)); - TVM_TRY_REWRITE(min(x + c2, c1) - c3, min(x + (c2 - c3), c1 - c3)); - - TVM_TRY_REWRITE(max(c1, x - c2) - c3, max(c1 - c3, x - (c2 + c3))); - TVM_TRY_REWRITE(max(c1, x + c2) - c3, max(c1 - c3, x + (c2 - c3))); - TVM_TRY_REWRITE(max(x - c2, c1) - c3, max(x - (c2 + c3), c1 - c3)); - TVM_TRY_REWRITE(max(x + c2, c1) - c3, max(x + (c2 - c3), c1 - c3)); - TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z)); TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z)); TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y)); diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index ba5928713182..a9a373151db9 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -109,8 +109,8 @@ bool ProducerCoversConsumer(const Array& buffer_shape, analyzer->canonical_simplify(consumed_region[i].max())); produced = arith::Intersect({produced, buffer_size}); consumed = arith::Intersect({consumed, buffer_size}); - if (!analyzer->CanProve((produced.min() <= consumed.min()) && - (consumed.max() <= produced.max()))) { + if (!analyzer->CanProve((analyzer->canonical_simplify(produced.min() - consumed.min()) <= 0) && + (analyzer->canonical_simplify(consumed.max() - produced.max()) <= 0))) { return false; } } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 
326dc4c15194..6ca2a2a5fcb0 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -239,16 +239,6 @@ def test_add_index_simplify(): ck.verify(tvm.te.min(x, y + 2) + (-2), tvm.te.min(x + (-2), y)) ck.verify(tvm.te.min(x + 2, y + 3) + (-2), tvm.te.min(x, y + 1)) - ck.verify(tvm.te.min(1, x + 3) + 1, tvm.te.min(2, x + 4)) - ck.verify(tvm.te.min(1, x - 3) + 1, tvm.te.min(2, x + -2)) - ck.verify(tvm.te.min(x + 3, 1) + 1, tvm.te.min(x + 4, 2)) - ck.verify(tvm.te.min(x - 3, 1) + 1, tvm.te.min(x + -2, 2)) - - ck.verify(tvm.te.max(1, x + 3) + 1, tvm.te.max(2, x + 4)) - ck.verify(tvm.te.max(1, x - 3) + 1, tvm.te.max(2, x + -2)) - ck.verify(tvm.te.max(x + 3, 1) + 1, tvm.te.max(x + 4, 2)) - ck.verify(tvm.te.max(x - 3, 1) + 1, tvm.te.max(x + -2, 2)) - ck.verify(tvm.te.max(0, 1 - x * 4) + x * 4, tvm.te.max(x * 4, 1)) ck.verify(tvm.te.max(2 - x * 4, 0) + x * 4, tvm.te.max(x * 4, 2)) @@ -336,16 +326,6 @@ def test_sub_index_simplify(): ck.verify(tvm.te.max(z, x + y) - x, tvm.te.max(z - x, y)) ck.verify(tvm.te.max(z, y + x) - x, tvm.te.max(z - x, y)) - ck.verify(tvm.te.min(1, x + 3) - 1, tvm.te.min(0, x + 2)) - ck.verify(tvm.te.min(1, x - 3) - 1, tvm.te.min(0, x + -4)) - ck.verify(tvm.te.min(x + 3, 1) - 1, tvm.te.min(x + 2, 0)) - ck.verify(tvm.te.min(x - 3, 1) - 1, tvm.te.min(x + -4, 0)) - - ck.verify(tvm.te.max(1, x + 3) - 1, tvm.te.max(0, x + 2)) - ck.verify(tvm.te.max(1, x - 3) - 1, tvm.te.max(0, x + -4)) - ck.verify(tvm.te.max(x + 3, 1) - 1, tvm.te.max(x + 2, 0)) - ck.verify(tvm.te.max(x - 3, 1) - 1, tvm.te.max(x + -4, 0)) - ck.verify(x - tvm.te.min(x + y, z), tvm.te.max(0 - y, x - z)) ck.verify(x - tvm.te.min(y + x, z), tvm.te.max(0 - y, x - z)) ck.verify(x - tvm.te.min(z, x + y), tvm.te.max(x - z, 0 - y))