diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index efa26a31d02e..fb837ca3d01b 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -618,6 +618,22 @@ class TVM_DLL Analyzer { TransitiveComparisonAnalyzer transitive_comparisons; /*! \brief constructor */ Analyzer(); + /*! + * \brief Mark the value as non-negative value globally in analyzer. + * + * Only call this function if the non-neg condition is global and + * not context-dependent. + * + * This function does best-effort propagations to the sub-analyzers + * + * \note We expose this function because non-negative global values, + * such as symbolic buffer shapes in function arguments are really + * important to ensure the best simplification, and usually they + * can be handled in a simpler way than the generic constraints. + * + * This function may call into the Update function of the sub-analyzers. + */ + void MarkGlobalNonNegValue(const PrimExpr& value); /*! * \brief Notify all the sub-analyzers that var * is created and binded to expr. 
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 722a2cd00e75..9e5b1414edf4 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -25,6 +25,7 @@ #include #include +#include "const_fold.h" #include "product_normal_form.h" namespace tvm { @@ -63,6 +64,38 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { // skip rewrite simplify } +void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) { + // split out the symbolic and non-symbolic part + int64_t cscale = 1; + PrimExpr symbolic = tir::make_const(value.dtype(), 1); + auto fcollect = [&](PrimExpr val) { + if (const auto* intimm = val.as()) { + cscale *= intimm->value; + } else { + symbolic = symbolic * val; + } + }; + UnpackReduction(value, fcollect); + if (cscale <= 0) return; + // override the constant int bound by marking it as non-negative + // NOTE: there might be future opportunities of more bound hint + // this is a simple step and covers all the current needs + // + // We may consider enhance the sub analyzer to directly take + // MarkPositiveVar so their bounds do not overlap + if (const auto* var_ptr = symbolic.as()) { + Var var = GetRef(var_ptr); + // skip non-index type, keep it to be compatible + // with any_dim that do not represent any value + if (!IsIndexType(var.dtype())) return; + bool allow_override = true; + // mark the constant bound is sufficient + // we cannot mark interval set as that will cause relaxation of the var + // during bound proof which is not our intention + this->const_int_bound.Update(var, ConstIntBound(0, ConstIntBound::kPosInf), allow_override); + } +} + void Analyzer::Bind(const Map& variables, bool allow_override) { for (const auto& iter : variables) { this->Bind(iter.first, iter.second, allow_override); diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 68ade3bb5400..8ce502523159 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -195,6 +195,31 @@ class 
ConstIntBoundAnalyzer::Impl return Intersect(a, b); } + /*! + * \brief Process the divisor by making assumption that divide by zero + * won't happen in a valid program. + * + * This is important for us to get a lot of symbolic shape bound right + * where most conditions know that the shape n >= 0, but in cases + * when mod or divide of n occur, the intention is actually n > 0 + * + * \param divisor The input divisor entry + * \return The processed entry + */ + Entry AssumeNoZeroDivisor(Entry divisor) { + ICHECK(!divisor.is_const(0)) << "Find divide by zero"; + // NOTE: here we make the assumption that + // divide by zero won't happen in a valid program + // this is important for us to get a lot of symbolic shape bound right + // where most conditions know that the shape n >= 0, but in cases + // when mod or divide of n occur, the intention is actually n > 0 + if (divisor.min_value == 0) { + divisor.min_value = 1; + ICHECK_GE(divisor.max_value, 1); + } + return divisor; + } + Entry VisitExpr_(const IntImmNode* op) final { return MakeBound(op->value, op->value); } Entry VisitExpr_(const AddNode* op) final { @@ -223,14 +248,14 @@ class ConstIntBoundAnalyzer::Impl Entry VisitExpr_(const DivNode* op) final { Entry a = VisitExpr(op->a); - Entry b = VisitExpr(op->b); - ICHECK(!b.is_const(0)) << "divide by zero"; + Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); return HandleDivision(a, b, op->dtype, InfAwareDiv); } Entry VisitExpr_(const ModNode* op) final { Entry a = VisitExpr(op->a); - Entry b = VisitExpr(op->b); + Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); + if (b.min_value > 0) { int64_t b_max_cap = InfAwareAdd(b.max_value, -1); if (a.min_value >= 0) { @@ -252,8 +277,7 @@ class ConstIntBoundAnalyzer::Impl Entry VisitExpr_(const FloorDivNode* op) final { Entry a = VisitExpr(op->a); - Entry b = VisitExpr(op->b); - ICHECK(!b.is_const(0)) << "floordiv by zero"; + Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); return HandleDivision(a, b, op->dtype, InfAwareFloorDiv); } @@ -276,7 +300,8
@@ class ConstIntBoundAnalyzer::Impl * That is, min(0, b_min + 1) <= floormod(a, b) <= max(0, b_max - 1) */ Entry a = VisitExpr(op->a); - Entry b = VisitExpr(op->b); + Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); + if (b.min_value > 0) { int64_t b_max_cap = InfAwareAdd(b.max_value, -1); if (a.min_value >= 0) { @@ -457,7 +482,6 @@ class ConstIntBoundAnalyzer::Impl // at a negative value and ends at a positive one, narrow it down to // be closer to 0, because BinaryOpBoundary only checks end-points of // the domain ranges. - // If the range of b contains 0, then some infinity will be involved if (b.min_value <= 0 && 0 <= b.max_value && dt.is_int()) { Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) : Everything(dt); diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index c201a245e190..1f087d993428 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -30,6 +30,15 @@ namespace arith { using namespace tir; +void IRMutatorWithAnalyzer::MarkBufferMapShapes(const tir::PrimFunc& func) { + // Mark all the symbolic buffer shape values in the buffer map as non-negative values. + for (auto kv : func->buffer_map) { + for (PrimExpr shape : kv.second->shape) { + analyzer_->MarkGlobalNonNegValue(shape); + } + } +} + Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) { // record the loop variable as iterators Range dom = Range::FromMinExtent(op->min, op->extent); diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index ed62c91df913..f04b40e7ae4e 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -62,6 +62,14 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { PrimExpr VisitExpr_(const tir::ReduceNode* op) override; protected: + /*! + * \brief Mark all the buffer shape values in the buffer map as non-negative values.
+ * + * \note call this function before Visit function's body to maximize + * simplification efficiency + */ + void MarkBufferMapShapes(const tir::PrimFunc& func); + /*! \brief internal analyzer field. */ Analyzer* analyzer_; // the following two fields are useful in case we want diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index ffdff45a7d3f..f37c21593f23 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -42,6 +42,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { arith::Analyzer ana; auto pass = BufferFlattener(&ana); auto writer = func.CopyOnWrite(); + pass.MarkBufferMapShapes(func); writer->body = pass.VisitStmt(func->body); // The buffers in func->buffer_map are deliberately left // unflattened, as they are used for validation of user-provided diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 130cbe37c167..44d64df63d9f 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -142,20 +142,24 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify", SimplifyConfig); class StmtSimplifier : public IRMutatorWithAnalyzer { public: - static Stmt Apply(Stmt stmt, Analyzer* analyzer, Optional config_opt = NullOpt) { + static PrimFunc Apply(PrimFunc func, Analyzer* analyzer, + Optional config_opt = NullOpt) { auto config = config_opt.value_or(AttrsWithDefaultValues()); analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions()); std::optional touch_pattern = std::nullopt; if (config->propagate_knowns_to_prove_conditional || config->propagate_knowns_to_simplify_expressions) { - touch_pattern = ControlFlowGraph(stmt); + touch_pattern = ControlFlowGraph(func->body); } - std::unordered_set used_in_buffer_def = CollectVarsUsedInBufferDefinition(stmt); + std::unordered_set used_in_buffer_def = + CollectVarsUsedInBufferDefinition(func->body); StmtSimplifier simplifier(analyzer, config, 
std::move(touch_pattern), std::move(used_in_buffer_def)); - return simplifier(std::move(stmt)); + simplifier.MarkBufferMapShapes(func); + func.CopyOnWrite()->body = simplifier(func->body); + return func; } private: @@ -335,11 +339,6 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } // namespace arith namespace tir { - -Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer) { - return arith::StmtSimplifier::Apply(stmt, analyzer); -} - namespace transform { Pass Simplify() { @@ -347,9 +346,7 @@ Pass Simplify() { arith::Analyzer analyzer; auto cfg = ctx->GetConfig("tir.Simplify"); - auto* n = f.CopyOnWrite(); - n->body = arith::StmtSimplifier::Apply(std::move(n->body), &analyzer, cfg); - return f; + return arith::StmtSimplifier::Apply(f, &analyzer, cfg); }; return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {}); } diff --git a/tests/python/unittest/test_arith_const_int_bound.py b/tests/python/unittest/test_arith_const_int_bound.py index d9ea36206b06..5667c79aaced 100644 --- a/tests/python/unittest/test_arith_const_int_bound.py +++ b/tests/python/unittest/test_arith_const_int_bound.py @@ -339,6 +339,23 @@ def test_floormod_negative_divisor(): assert bd.max_value == 6 +def test_divmod_assume_no_zero_divsor(): + # Divmod non negative expression makes assumption that divide by zero won't occur + # this assumption is important to get best result from symbolic shape programs + analyzer = tvm.arith.Analyzer() + flm, fld = tvm.te.floormod, tvm.te.floordiv + a, b = te.var("a"), te.var("b") + analyzer.update(a, tvm.arith.ConstIntBound(0, 6)) + analyzer.update(b, tvm.arith.ConstIntBound(0, tvm.arith.ConstIntBound.POS_INF)) + bd = analyzer.const_int_bound(fld(a, b)) + assert bd.min_value == 0 + assert bd.max_value == 6 + + bd = analyzer.const_int_bound(flm(a, b)) + assert bd.min_value == 0 + assert bd.max_value == 6 + + def test_multiple_condition(): analyzer = tvm.arith.Analyzer() flm, fld = tvm.te.floormod, tvm.te.floordiv diff --git 
a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 1f25405ec9d1..79fd5e143418 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -1733,5 +1733,21 @@ def before(A_ptr: T.handle("float32"), A_stride: T.int32): expected = before +class TestBufferShapeConstraint(BaseBeforeAfter): + """Symbolic buffer shapes in the buffer map are treated as non-negative when simplifying""" + + convert_boolean_to_and_of_ors = True + + def before(a: T.handle): + n = T.int64() + A = T.match_buffer(a, (n * 32,), "float32") + A[T.min(T.int64(0), n)] = T.float32(0) + + def expected(a: T.handle): + n = T.int64() + A = T.match_buffer(a, (n * 32,), "float32") + A[T.int64(0)] = T.float32(0) + + if __name__ == "__main__": tvm.testing.main()