From f3785ff44db1acedaa5cede62320ebffcf122cba Mon Sep 17 00:00:00 2001 From: Danilo Date: Wed, 16 Aug 2023 11:37:37 +0200 Subject: [PATCH 1/3] [MetaSchedule] Fix metaschedule flop estimation for non-integer loop dimensions --- src/tir/analysis/estimate_flops.cc | 20 ++++++++++--- .../test_tir_analysis_estimate_tir_flops.py | 28 +++++++++++++++---- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index 8d5aeb8ddbe2..ee2869d99303 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -19,6 +19,8 @@ #include #include +#include "tvm/arith/analyzer.h" + namespace tvm { namespace tir { @@ -84,6 +86,8 @@ struct TResult { class FlopEstimator : private ExprFunctor, private StmtFunctor { + arith::Analyzer ana; + public: TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); } TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); } @@ -112,6 +116,15 @@ class FlopEstimator : private ExprFunctor, TResult VisitExpr_(const GTNode* op) override { return TResult(); } TResult VisitExpr_(const GENode* op) override { return TResult(); } + int64_t GetLoopExtent(const ForNode* node, const arith::Analyzer& ana) { + int64_t bound = ana.const_int_bound(node->extent)->max_value; + if (bound == arith::ConstIntBound::kPosInf) { + return 1; // Analyzer could not determine a valid bound, use 1 instead. + } else { + return bound; + } + } + TResult VisitExpr_(const NotNode* op) override { return VisitExpr(op->a); } TResult VisitExpr_(const AndNode* op) final { TResult result = VisitExpr(op->a); @@ -138,11 +151,10 @@ class FlopEstimator : private ExprFunctor, return result; } TResult VisitStmt_(const ForNode* loop) override { + ana.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + const auto int_imm = GetLoopExtent(loop, ana); TResult result = VisitStmt(loop->body); - const auto* int_imm = loop->extent.as(); - ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: " - << loop->extent->GetTypeKey(); - result *= int_imm->value; + result *= int_imm; return result; } diff --git a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py index 1b494a8cd7d0..4db1b6ba137b 100644 --- a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py +++ b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py @@ -91,7 +91,7 @@ def flops_override(A: T.Buffer(16, "float32")): A[0] = A[0] + 1 -def test_estimate_flops_forloop_as_experssion(): +def test_estimate_flops_forloop_as_expression(): flops = estimate_tir_flops( IRModule({"main": flops_with_forloop_as_expression.with_attr("estimated_flops", 32)}) ) @@ -102,11 +102,6 @@ def test_estimate_flops_forloop_as_experssion(): assert flops == 32 -def test_exception(): - with pytest.raises(tvm.TVMError): - flops = estimate_tir_flops(IRModule({"main": flops_with_forloop_as_expression})) - - def test_estimate_flops_with_decl_buffer(): def make_func(use_decl_buffer): buffer_func = T.decl_buffer if use_decl_buffer else T.Buffer @@ -124,5 +119,26 @@ def func(A_data: T.handle("float32")): assert flops_with_decl_buffer == flops_without_decl_buffer +@T.prim_func +def flops_with_nonint_extent(a: T.Buffer(16, "float32")): + for i in range(4 + 4): + a[i] = a[i] + + +def test_flops_with_nonint_extent(): + estimate_tir_flops(IRModule({"main": flops_with_nonint_extent})) + + +@T.prim_func +def flops_with_variable_extent(a: T.Buffer(16, "float32")): + for i in range(4 + 4): + for j in range(i + 8): + a[j] = a[i] + + +def test_flops_with_variable_extent(): + estimate_tir_flops(IRModule({"main": flops_with_variable_extent})) + + if __name__ == "__main__": tvm.testing.main() From d2c2d6c34c0e1f6830992335d6ef46d5965ce56b Mon Sep 17 00:00:00 2001 From: Danilo Date: Thu, 17 Aug 2023 12:01:31 +0200 Subject: [PATCH 2/3] Add assert checks in the test --- .../unittest/test_tir_analysis_estimate_tir_flops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py index 4db1b6ba137b..b6996b7181fb 100644 --- a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py +++ b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py @@ -122,22 +122,22 @@ def func(A_data: T.handle("float32")): @T.prim_func def flops_with_nonint_extent(a: T.Buffer(16, "float32")): for i in range(4 + 4): - a[i] = a[i] + a[i] = 2*a[i] def test_flops_with_nonint_extent(): - estimate_tir_flops(IRModule({"main": flops_with_nonint_extent})) + assert estimate_tir_flops(IRModule({"main": flops_with_nonint_extent}))==8 @T.prim_func def flops_with_variable_extent(a: T.Buffer(16, "float32")): for i in range(4 + 4): for j in range(i + 8): - a[j] = a[i] + a[j] = 2*a[i] def test_flops_with_variable_extent(): - estimate_tir_flops(IRModule({"main": flops_with_variable_extent})) + assert estimate_tir_flops(IRModule({"main": flops_with_variable_extent}))==120 if __name__ == "__main__": From 79a59ce162c73c35692550ec2b354f228a17525c Mon Sep 17 00:00:00 2001 From: Danilo Date: Thu, 17 Aug 2023 12:07:24 +0200 Subject: [PATCH 3/3] Fix formatting --- .../unittest/test_tir_analysis_estimate_tir_flops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py index b6996b7181fb..10a6f3b349cd 100644 --- a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py +++ b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py @@ -122,22 +122,22 @@ def func(A_data: T.handle("float32")): @T.prim_func def flops_with_nonint_extent(a: T.Buffer(16, "float32")): for i in range(4 + 4): - a[i] = 2*a[i] + a[i] = 2 * a[i] def test_flops_with_nonint_extent(): - assert estimate_tir_flops(IRModule({"main": flops_with_nonint_extent}))==8 + assert estimate_tir_flops(IRModule({"main": flops_with_nonint_extent})) == 8 @T.prim_func def flops_with_variable_extent(a: T.Buffer(16, "float32")): for i in range(4 + 4): for j in range(i + 8): - a[j] = 2*a[i] + a[j] = 2 * a[i] def test_flops_with_variable_extent(): - assert estimate_tir_flops(IRModule({"main": flops_with_variable_extent}))==120 + assert estimate_tir_flops(IRModule({"main": flops_with_variable_extent})) == 120 if __name__ == "__main__":