From abf8dd9578d674d86d79dc1fb0a81841283e0afa Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 25 May 2026 19:42:28 +0000 Subject: [PATCH 1/4] [REFACTOR][TIR][ARITH] Phase out ControlFlowGraph and NarrowPredicateExpression All production callers of ControlFlowGraph are gated by config flags (propagate_knowns_to_prove_conditional, propagate_knowns_to_simplify_expressions in SimplifyConfig; use_dataflow_analysis in RemoveNoOpConfig) that no standard lowering pipeline turns on. The dataflow-analysis path is effectively dead code in practice. Remove ControlFlowGraph and its sole dependent NarrowPredicateExpression entirely. Drop the config flags and their associated code paths from StmtSimplifier and NoOpRemover. The remaining simplification logic (transitively_prove_inequalities, convert_boolean_to_and_of_ors, apply_constraints_to_boolean_branches, max_simplification_steps) is unaffected. Delete the tests that exercised the removed paths. --- src/arith/narrow_predicate_expression.cc | 224 --- src/arith/narrow_predicate_expression.h | 57 - src/tirx/analysis/control_flow_graph.cc | 1692 ----------------- src/tirx/analysis/control_flow_graph.h | 667 ------- src/tirx/transform/remove_no_op.cc | 57 +- src/tirx/transform/remove_no_op.h | 14 +- src/tirx/transform/simplify.cc | 57 +- .../test_arith_narrow_predicate_expression.py | 87 - .../test_tir_transform_remove_no_op.py | 289 +-- .../test_tir_transform_simplify.py | 682 ------- 10 files changed, 19 insertions(+), 3807 deletions(-) delete mode 100644 src/arith/narrow_predicate_expression.cc delete mode 100644 src/arith/narrow_predicate_expression.h delete mode 100644 src/tirx/analysis/control_flow_graph.cc delete mode 100644 src/tirx/analysis/control_flow_graph.h delete mode 100644 tests/python/arith/test_arith_narrow_predicate_expression.py diff --git a/src/arith/narrow_predicate_expression.cc b/src/arith/narrow_predicate_expression.cc deleted file mode 100644 index 697db81f683a..000000000000 --- a/src/arith/narrow_predicate_expression.cc +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file narrow_predicate_expression.cc - * \brief Utility to deduce bound of expression - */ -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace arith { - -using namespace tirx; - -/* \brief Given a true expression that includes free parameter, - * generate a true expression without the free parameters. - * - * This function provides two guarantees: - * - * 1. If the resulting expression evaluates to True, then the original - * expression also evaluates to True. - * - * 2. The resulting expression does not contain any of the free - * parameters. - * - */ -// Utility for generating a known true expression from an expression -// with free parameters, and the range of those parameters. -class ExpressionNarrower : public tirx::ExprMutator { - public: - static PrimExpr Apply(PrimExpr expr, ffi::Map free_parameters) { - TVM_FFI_ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr; - ExpressionNarrower mutator(free_parameters); - return mutator(expr); - } - - private: - explicit ExpressionNarrower(ffi::Map free_parameters) - : free_parameters_(free_parameters) {} - - using Parent = tirx::ExprMutator; - using Parent::VisitExpr_; - - enum class Context { - Maximize, - Minimize, - }; - - template - PrimExpr VisitInequality(T t, Context a_ctx, Context b_ctx) { - PrimExpr a = [&]() { - WithContext context(this, a_ctx); - return this->VisitExpr(t->a); - }(); - - PrimExpr b = [&]() { - WithContext context(this, b_ctx); - return this->VisitExpr(t->b); - }(); - - if (contains_unknown_expr_ && t.dtype().is_bool()) { - contains_unknown_expr_ = false; - return Bool(CurrentContext() == Context::Minimize); - } else if (a.same_as(t->a) && b.same_as(t->b)) { - return t; - - } else { - return T(a, b); - } - } - - PrimExpr VisitExpr_(const FloorModNode* op) override { - // FloorMod is non-monotonic, so inserting min/max won't remove - // the free parameters. - contains_unknown_expr_ = true; - return Parent::VisitExpr_(op); - } - - PrimExpr VisitExpr_(const FloorDivNode* op) override { - auto res_a = this->VisitExpr(op->a); - auto res_b = this->VisitExpr(op->b); - if (is_zero(res_b)) { - contains_unknown_expr_ = true; - return IntImm(op->dtype, 0); - } else { - return floordiv(res_a, res_b); - } - } - - PrimExpr VisitExpr_(const GTNode* op) override { - auto current = CurrentContext(); - return VisitInequality(ffi::GetRef(op), OppositeContext(current), current); - } - - PrimExpr VisitExpr_(const GENode* op) override { - auto current = CurrentContext(); - return VisitInequality(ffi::GetRef(op), OppositeContext(current), current); - } - - PrimExpr VisitExpr_(const LTNode* op) override { - auto current = CurrentContext(); - return VisitInequality(ffi::GetRef(op), current, OppositeContext(current)); - } - - PrimExpr VisitExpr_(const LENode* op) override { - auto current = CurrentContext(); - return VisitInequality(ffi::GetRef(op), current, OppositeContext(current)); - } - - PrimExpr VisitExpr_(const EQNode* op) override { - auto res_a = this->VisitExpr(op->a <= op->b); - auto res_b = this->VisitExpr(op->b <= op->a); - return res_a && res_b; - } - - PrimExpr VisitExpr_(const NENode* op) override { - auto res_a = this->VisitExpr(op->a < op->b); - auto res_b = this->VisitExpr(op->b < op->a); - return res_a || res_b; - } - - PrimExpr VisitExpr_(const SubNode* op) override { - auto current = CurrentContext(); - return VisitInequality(ffi::GetRef(op), current, OppositeContext(current)); - } - - PrimExpr VisitExpr_(const NotNode* op) override { - auto current = CurrentContext(); - WithContext context(this, OppositeContext(current)); - return !VisitExpr(op->a); - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) override { - contains_unknown_expr_ = true; - return ffi::GetRef(op); - } - - PrimExpr VisitExpr_(const VarNode* op) override { - auto it = free_parameters_.find(ffi::GetRef(op)); - if (it == free_parameters_.end()) { - return Parent::VisitExpr_(op); - } - - Range range = (*it).second; - - switch (CurrentContext()) { - case Context::Minimize: - return range->min; - - case Context::Maximize: - return range->min + range->extent - 1; - } - - return Parent::VisitExpr_(op); - } - - Context CurrentContext() const { - if (context_stack_.size()) { - return context_stack_.back(); - } else { - return Context::Maximize; - } - } - - Context OppositeContext(Context context) const { - switch (context) { - case Context::Minimize: - return Context::Maximize; - - case Context::Maximize: - return Context::Minimize; - - default: - TVM_FFI_THROW(InternalError) << "Unhandled Context, all legal values should be handled"; - } - } - - struct WithContext { - WithContext(ExpressionNarrower* self, Context context) : self(self) { - self->context_stack_.push_back(context); - } - ~WithContext() { self->context_stack_.pop_back(); } - ExpressionNarrower* self; - }; - - std::vector context_stack_; - ffi::Map free_parameters_; - bool contains_unknown_expr_{false}; -}; - -PrimExpr NarrowPredicateExpression(PrimExpr expr, ffi::Map free_parameters) { - return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters)); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("arith.NarrowPredicateExpression", NarrowPredicateExpression); -} - -} // namespace arith -} // namespace tvm diff --git a/src/arith/narrow_predicate_expression.h b/src/arith/narrow_predicate_expression.h deleted file mode 100644 index 8262646caa2d..000000000000 --- a/src/arith/narrow_predicate_expression.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file narrow_predicate_expression.h - * \brief Utility for extracting and interacting with buffer touch points - */ - -#include -#include - -#ifndef TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ -#define TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ - -namespace tvm { -namespace arith { - -/* \brief Narrow a true expression to remove free parameters - * - * This function provides two guarantees: - * - * 1. If the resulting expression evaluates to True, then the original - * expression also evaluates to True. - * - * 2. The resulting expression does not contain any of the free - * parameters. - * - * 3. The resulting expression does not contain any BufferLoad - * - * \param expr The expression to be examined. - * - * \param ranges The variables to be removed from the expression - * - * \returns An expression that, if true, implies that the original - * expression is also true. - */ -PrimExpr NarrowPredicateExpression(PrimExpr expr, ffi::Map free_parameters); - -} // namespace arith -} // namespace tvm -#endif // TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ diff --git a/src/tirx/analysis/control_flow_graph.cc b/src/tirx/analysis/control_flow_graph.cc deleted file mode 100644 index 0a8371a1f338..000000000000 --- a/src/tirx/analysis/control_flow_graph.cc +++ /dev/null @@ -1,1692 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file control_flow_graph.cc - * \brief Utility to deduce bound of expression - */ - -#include "control_flow_graph.h" - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "../../arith/conjunctive_normal_form.h" -#include "../../arith/constraint_extract.h" -#include "../../arith/ir_mutator_with_analyzer.h" -#include "../../arith/ir_visitor_with_analyzer.h" -#include "../../arith/narrow_predicate_expression.h" -#include "../../arith/unwrap_vector_expr.h" - -namespace tvm { -namespace tirx { - -using namespace arith; - -namespace { -bool HasBufferLoad(PrimExpr expr) { - struct Visitor : public ExprVisitor { - void VisitExpr_(const BufferLoadNode* node) override { found_buffer_load = true; } - bool found_buffer_load{false}; - }; - - Visitor visitor; - visitor(expr); - return visitor.found_buffer_load; -} - -ffi::Optional SubstituteParamValues(const ffi::Array& param_vars, - const ffi::Array& param_values, - const PrimExpr& expr) { - TVM_FFI_ICHECK_EQ(param_vars.size(), param_values.size()) - << "Expression was defined as having " << param_vars.size() << " parameters, but received " - << param_values.size() << " arguments."; - - ffi::Map var_map; - for (size_t i = 0; i < param_values.size(); i++) { - var_map.Set(param_vars[i], param_values[i]); - } - - return Substitute(expr, var_map); -} -} // namespace - -PrimExpr BufferTouch::BeforeLoopIteration() const { - PrimExpr loop_predicate = Bool(true); - for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { - const Var& loop_var = it->first; - const PrimExpr& loop_expr = it->second; - loop_predicate = (loop_var <= loop_expr) || ((loop_var == loop_expr) && loop_predicate); - } - return loop_predicate; -} - -PrimExpr BufferTouch::AtLoopIteration() const { - PrimExpr loop_predicate = Bool(true); - for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { - const Var& loop_var = it->first; - const PrimExpr& loop_expr = it->second; - loop_predicate = (loop_var == loop_expr) && loop_predicate; - } - return loop_predicate; -} - -PrimExpr BufferTouch::AfterLoopIteration() const { - PrimExpr loop_predicate = Bool(true); - for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { - const Var& loop_var = it->first; - const PrimExpr& loop_expr = it->second; - loop_predicate = (loop_var >= loop_expr) || ((loop_var == loop_expr) && loop_predicate); - } - return loop_predicate; -} - -bool BufferTouch::IsSubsetOf(const BufferTouch& other, Analyzer* analyzer) const { - if (this->buffer.same_as(other.buffer)) { - With constraint(analyzer, predicate); - - return analyzer->CanProve(other.predicate); - } else { - return false; - } -} - -bool BufferTouch::IsDistinctFrom(const BufferTouch& other, Analyzer* analyzer) const { - if (this->buffer.same_as(other.buffer)) { - With constraint(analyzer, predicate); - - return analyzer->CanProve(!other.predicate); - } else { - return true; - } -} - -std::ostream& operator<<(std::ostream& os, const BufferTouch& tp) { - auto touch_type = [&]() { - if (tp.touch_type == BufferTouch::AccessType::Read) { - return "read"; - } else if (tp.touch_type == BufferTouch::AccessType::Write) { - return "write"; - } else if (tp.touch_type == BufferTouch::AccessType::Assume) { - return "assume"; - } else { - return "???"; - } - }(); - - os << "BufferTouch(" << tp.buffer->name << ", " << touch_type << ", " << tp.predicate - << ", value = " << tp.value << ")"; - return os; -} - -class BufferConstraintApply : public IRMutatorWithAnalyzer { - public: - using Parent = IRMutatorWithAnalyzer; - - BufferConstraintApply(const ffi::Map>& axis_var_lookup, - const std::vector& knowns, Analyzer* analyzer) - : Parent(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) {} - - using Parent::VisitExpr_; - - PrimExpr VisitExpr_(const BufferLoadNode* op) override { - for (const auto& known : knowns_) { - if (!op->buffer.same_as(known.buffer)) { - continue; - } - - ffi::Optional lane_var = std::nullopt; - IntImm num_lanes; - - ffi::Array indices = op->indices.Map([&](const auto& index) { - if (index.dtype().lanes() == 1) { - return index; - } else { - TVM_FFI_ICHECK(!lane_var) << "Multiple indices found with non-scalar values"; - lane_var = Var("lane", index.dtype().element_of()); - num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes()); - return UnwrapVectorExpr(index, lane_var.value()); - } - }); - - auto axis_vars = axis_var_lookup_.at(op->buffer); - PrimExpr predicate = SubstituteParamValues(axis_vars, indices, known.predicate).value(); - - std::optional> context; - if (lane_var.defined()) { - Var lanes = lane_var.value(); - PrimExpr known = (IntImm(lanes.dtype(), 0) <= lanes) && (lanes < num_lanes); - context.emplace(analyzer_, known); - } - - if (analyzer_->CanProve(predicate)) { - return SubstituteParamValues(axis_vars, op->indices, known.value).value(); - } - } - - return ffi::GetRef(op); - } - - private: - const ffi::Map>& axis_var_lookup_; - const std::vector& knowns_; -}; - -/*! \brief Extract the control-flow graph - * - * Walk through a statement, populating the control-flow graph. - */ -class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { - public: - static void Build(ControlFlowGraph* out, const Stmt& stmt) { - ControlFlowGraphBuilder extractor(out); - extractor.AppendControlBlock(); - extractor(stmt); - } - - private: - ControlFlowGraphBuilder(ControlFlowGraph* out) : out_(out) {} - - using Parent = IRVisitorWithAnalyzer; - using Parent::VisitExpr_; - using Parent::VisitStmt_; - - void VisitStmt(const Stmt& stmt) override { - // Update the lookup table to determine which control-flow block - // contains the start of the specified statement. This is used - // later to determine which set of known values should be used to - // simplify a statement. - out_->control_flow_lookup_[stmt.get()] = CurrentControlBlock(); - Stmt prev_stmt = current_stmt_; - current_stmt_ = stmt; - Parent::VisitStmt(stmt); - current_stmt_ = prev_stmt; - } - - void VisitStmt_(const EvaluateNode* op) override { - if (auto* call = op->value.as()) { - if (call->op.same_as(builtin::assume())) { - Assume(call->args[0], true); - return; - } - } - - Parent::VisitStmt_(op); - } - - void Assume(PrimExpr assumption, bool from_assume_statement) { - for (const auto& expr : ExtractConstraints(assumption, false)) { - AssumeConstraintComponent(expr, from_assume_statement); - } - } - - void AssumeConstraintComponent(PrimExpr assumption, bool from_assume_statement) { - PrimExpr additional_predicate = Bool(true); - - std::vector buffer_exprs; - for (const auto& expr : ExtractComponents(assumption)) { - auto side_effect = tirx::SideEffect(expr); - if (side_effect <= tirx::CallEffectKind::kPure) { - // Pulling out portions of the assumption that do not depend - // on a buffer value allows the following two forms to be - // treated identically. - // - // Option 1: if i < 3: T.assume(buf[i] == value) - // Option 2: T.assume(i>=3 or buf[i] == value) - additional_predicate = additional_predicate && logical_not(expr); - } else if (side_effect == tirx::CallEffectKind::kReadState) { - buffer_exprs.push_back(expr); - } else { - TVM_FFI_THROW(InternalError) - << "Assumption must be pure or read-only, but contained expression " << expr - << " with side-effect \'" << side_effect << "\'"; - } - } - - if (buffer_exprs.empty()) { - out_->non_buffer_assumptions_.push_back(!CurrentScopePredicate() || assumption); - return; - } - - TVM_FFI_ICHECK_EQ(buffer_exprs.size(), 1) - << "T.assume must contain only a single buffer expression"; - - auto* as_equal_node = buffer_exprs[0].as(); - TVM_FFI_ICHECK(as_equal_node || !from_assume_statement) - << "T.assume buffer constraint must be of the form 'buffer[indices] == " - "value', but received " - << assumption; - if (!as_equal_node) { - // This assumption is an inequality on a data-dependent - // conditional. Not an error for this to occur, but also not - // something that is currently supported. - return; - } - - tirx::BufferLoad load; - PrimExpr value; - if (auto opt = as_equal_node->a.as()) { - load = opt.value(); - value = as_equal_node->b; - } else if (auto opt = as_equal_node->b.as()) { - load = opt.value(); - value = as_equal_node->a; - } else if (!from_assume_statement) { - return; - } else { - TVM_FFI_THROW(InternalError) - << "T.assume buffer constraint must be of the form 'buffer[indices] == value'"; - } - - auto has_side_effect = tirx::SideEffect(value) > tirx::CallEffectKind::kPure; - TVM_FFI_ICHECK(!has_side_effect || !from_assume_statement) - << "Buffer value in constraint must be pure expression, but was " << value; - if (has_side_effect) { - return; - } - - { - InternalConstraintContext context(this, additional_predicate); - VisitAccess(load, BufferTouch::AccessType::Assume, value); - } - // Appending a control block ensures that all control blocks have - // at most one statement that changes the known buffer contents. - auto prev_block = CurrentControlBlock(); - auto new_block = AppendControlBlock(); - MarkControlFlow(prev_block, new_block); - } - - void VisitExpr_(const LetNode* op) override { - std::optional binding; - if (UsesLoopVar(op->value)) { - binding.emplace(this, op->var, op->value); - } - Parent::VisitExpr_(op); - } - - void VisitStmt_(const BindNode* op) override { - std::optional binding; - if (UsesLoopVar(op->value)) { - binding.emplace(this, op->var, op->value); - } - Parent::VisitStmt_(op); - } - - void VisitExpr_(const BufferLoadNode* op) override { - Parent::VisitExpr_(op); - BufferLoad load = ffi::GetRef(op); - VisitAccess(load, BufferTouch::AccessType::Read, load); - } - - void VisitStmt_(const BufferStoreNode* op) override { - Parent::VisitStmt_(op); - VisitAccess(ffi::GetRef(op), BufferTouch::AccessType::Write, op->value); - // Appending a control block ensures that all control blocks have - // at most one statement that changes the buffer contents. - auto prev_block = CurrentControlBlock(); - auto new_block = AppendControlBlock(); - MarkControlFlow(prev_block, new_block); - } - - void VisitStmt_(const ForNode* op) override { - out_->iterator_ranges_.Set(op->loop_var, Range::FromMinExtent(op->min, op->extent)); - - auto before_loop = CurrentControlBlock(); - size_t loop_start = -1; - - { - BindActiveLoopVar binding(this, op->loop_var, op->min, op->extent); - loop_start = AppendControlBlock(); - Parent::VisitStmt_(op); - } - - auto loop_end = CurrentControlBlock(); - auto after_loop = AppendControlBlock(); - PrimExpr max_iterator_value = analyzer_.Simplify(op->min + op->extent - 1); - { - auto [forward, backward] = MarkControlFlow(before_loop, loop_start); - backward.post_condition = (op->loop_var == op->min); - forward.var_remap = {{op->loop_var, op->min}}; - } - { - auto [forward, backward] = MarkControlFlow(loop_end, after_loop); - backward.var_remap = {{op->loop_var, max_iterator_value}}; - forward.post_condition = (op->loop_var == max_iterator_value); - } - { - auto [forward, backward] = MarkControlFlow(loop_end, loop_start); - backward.var_remap = {{op->loop_var, op->loop_var - 1}}; - forward.var_remap = {{op->loop_var, op->loop_var + 1}}; - backward.post_condition = (op->loop_var > op->min); - forward.post_condition = (op->loop_var < max_iterator_value); - } - } - - void VisitStmt_(const IfThenElseNode* op) override { - this->VisitExpr(op->condition); - - PrimExpr real_condition = ExtractRealCondition(op->condition); - - auto before_branching = CurrentControlBlock(); - - auto branch_start = AppendControlBlock(); - MarkControlFlow(before_branching, branch_start); - - { - InternalConstraintContext context(this, real_condition); - auto then_start = AppendControlBlock(); - if (context.assume.defined()) { - Assume(context.assume.value(), false); - } - auto [forward, backward] = MarkControlFlow(branch_start, then_start); - backward.post_condition = real_condition; - forward.post_condition = real_condition; - this->VisitStmt(op->then_case); - } - auto then_end = CurrentControlBlock(); - - auto negation = analyzer_.rewrite_simplify(!real_condition); - { - InternalConstraintContext context(this, negation); - auto else_start = AppendControlBlock(); - if (context.assume.defined()) { - Assume(context.assume.value(), false); - } - auto [forward, backward] = MarkControlFlow(branch_start, else_start); - backward.post_condition = negation; - forward.post_condition = negation; - - if (op->else_case.defined()) { - this->VisitStmt(op->else_case.value()); - } - } - - auto else_end = CurrentControlBlock(); - auto after_branching = AppendControlBlock(); - - if (HasBufferLoad(real_condition)) { - // The buffer value may have changed during the body of the - // condition, so we can't provide it as a post-condition. - MarkControlFlow(then_end, after_branching); - MarkControlFlow(else_end, after_branching); - } else { - { - auto [forward, backward] = MarkControlFlow(then_end, after_branching); - backward.post_condition = real_condition; - forward.post_condition = real_condition; - } - { - auto [forward, backward] = MarkControlFlow(else_end, after_branching); - backward.post_condition = negation; - forward.post_condition = negation; - } - } - } - - /*! \brief Internal utility, returns true if the expression depends - * on a loop iterator - */ - bool UsesLoopVar(const PrimExpr& expr) { - return UsesVar(expr, [&](const VarNode* expr_var) { - return loop_dependent_vars_.find(expr_var) != loop_dependent_vars_.end(); - }); - } - - /*! \brief Record the interaction with the buffer. - * - * \param node The TIR node that accesses the buffer. Should be - * either a BufferLoad or BufferStore node. - * - * \param touch_type The type of buffer access being performed. A - * BufferStore should always use AccessType::Write. A BufferLoad - * may use either AccessType::Read or AccessType::Assume, depending - * on whether the BufferLoad occurs within `builtin::assume`. - * - * \param known_value_expr The value in the buffer following the access. - */ - template - void VisitAccess(const BufferAccess& node, BufferTouch::AccessType touch_type, - PrimExpr known_value_expr) { - auto& current_block = out_->control_flow_.back(); - BufferTouch buffer_touch = current_block.MakeBufferTouch(out_, node->buffer, node->indices, - touch_type, known_value_expr); - current_block.touch_points.push_back(buffer_touch); - } - - /*! \brief Return a predicate for having reached the current - * control-flow block - * - * For example, while inside an IfThenElse, will return the - * IfThenElse's condition. - */ - PrimExpr CurrentScopePredicate() const { - PrimExpr predicate = Bool(true); - for (const auto& condition : conditions_) { - predicate = predicate && condition; - } - return predicate; - } - - /* \brief Add a new control block, returning its index */ - size_t AppendControlBlock() { - size_t index = out_->control_flow_.size(); - auto& block = out_->control_flow_.emplace_back(); - block.active_loop_iterators = active_loop_iterators_; - block.let_bindings_using_loop = let_bindings_using_loop_; - block.scope_predicate = CurrentScopePredicate(); - return index; - } - - /* \brief The index of the current control block */ - size_t CurrentControlBlock() { return out_->control_flow_.size() - 1; } - - /* \brief Mark a possible control from one block to another - * - * \param from_block The block from which control leaves - * - * \param to_block The block to which control enters - * - * \param var_remap Variable replacements that should be made in - * known expression while traversing this edge. For example, - * replacing `i` with `i-1` when entering the next loop iteration, - * or replacing `i` with `n-1` when concluding a loop. - */ - std::pair MarkControlFlow( - size_t from_block, size_t to_block) { - TVM_FFI_ICHECK_LE(from_block, out_->control_flow_.size()); - TVM_FFI_ICHECK_LE(to_block, out_->control_flow_.size()); - - auto& forward = out_->control_flow_[from_block].successors.emplace_back( - ControlFlowGraph::ControlFlowEdge{to_block, {}, std::nullopt}); - auto& backward = out_->control_flow_[to_block].predecessors.emplace_back( - ControlFlowGraph::ControlFlowEdge{from_block, {}, std::nullopt}); - return {forward, backward}; - } - - // Internal utility, context manager for entering/leaving a scoped constraint - struct InternalConstraintContext { - InternalConstraintContext(ControlFlowGraphBuilder* self, PrimExpr constraint) - : self(self), analyzer_context(&self->analyzer_, constraint) { - old_num_constraints = self->conditions_.size(); - - auto side_effect = tirx::SideEffect(constraint); - if (side_effect <= tirx::CallEffectKind::kPure) { - self->conditions_.push_back(constraint); - } else if (side_effect <= tirx::CallEffectKind::kReadState) { - assume = constraint; - } - - new_num_constraints = self->conditions_.size(); - } - ~InternalConstraintContext() { - TVM_FFI_ICHECK_EQ(self->conditions_.size(), new_num_constraints) - << "Internal error: Each condition should only be popped once."; - self->conditions_.erase(self->conditions_.begin() + old_num_constraints, - self->conditions_.end()); - } - - ControlFlowGraphBuilder* self{nullptr}; - With analyzer_context; - size_t old_num_constraints{0}; - size_t new_num_constraints{0}; - ffi::Optional assume{std::nullopt}; - - // Disable default-generated copy/move assignment and constructors - InternalConstraintContext(const InternalConstraintContext&) = delete; - InternalConstraintContext& operator=(const InternalConstraintContext&) = delete; - InternalConstraintContext(InternalConstraintContext&&) = delete; - InternalConstraintContext& operator=(InternalConstraintContext&&) = delete; - }; - - // Internal utility, context manager for tracking a loop - struct BindActiveLoopVar { - BindActiveLoopVar(ControlFlowGraphBuilder* self, Var var, PrimExpr loop_min, - PrimExpr loop_extent) - : self(self), var(var) { - PrimExpr loop_max = loop_min + (loop_extent - 1); - auto loop_range = Range::FromMinExtent(loop_min, loop_extent); - self->active_loop_iterators_.push_back({var, loop_min, loop_max, loop_range}); - self->loop_dependent_vars_.insert(var.get()); - } - ~BindActiveLoopVar() { self->active_loop_iterators_.pop_back(); } - - ControlFlowGraphBuilder* self; - Var var; - - // Disable default-generated copy/move assignment and constructors - BindActiveLoopVar(const BindActiveLoopVar&) = delete; - BindActiveLoopVar& operator=(const BindActiveLoopVar&) = delete; - BindActiveLoopVar(BindActiveLoopVar&&) = delete; - BindActiveLoopVar& operator=(BindActiveLoopVar&&) = delete; - }; - - // Internal utility, context manager for tracking a variable binding. - // Under SSA, each variable is bound exactly once, so the maps grow - // monotonically and cleanup is unnecessary. Omitting cleanup also - // ensures correctness for flat BindNode (which has no body): the - // binding must remain visible to subsequent sibling statements. - struct BindLetVar { - BindLetVar(ControlFlowGraphBuilder* self, Var var, PrimExpr value) { - self->let_bindings_using_loop_.Set(var, value); - self->loop_dependent_vars_.insert(var.get()); - } - ~BindLetVar() {} - - // Disable default-generated copy/move assignment and constructors - BindLetVar(const BindLetVar&) = delete; - BindLetVar& operator=(const BindLetVar&) = delete; - BindLetVar(BindLetVar&&) = delete; - BindLetVar& operator=(BindLetVar&&) = delete; - }; - - struct LoopEntry { - Var loop_var; - PrimExpr loop_min; - PrimExpr loop_max; - Range loop_range; - }; - - // Track in order to know which Vars to write in terms of the buffer - // indices and substitute out of the predicate. - std::vector active_loop_iterators_; - - // Track all loop iterators, along with values derived from loop iterators. - std::unordered_set loop_dependent_vars_; - - // Any let binding that depends, directly or indirectly, on a loop - // binding. When making a predicate in terms of the buffer indices, - // these need to be substituted out. - // std::unordered_map let_bindings_using_loop_; - ffi::Map let_bindings_using_loop_; - - // Track in order to know what conditions limit the buffer access - std::vector conditions_; - - // Track in order to know what statement initiated the buffer access - Stmt current_stmt_; - - // Output data structure - ControlFlowGraph* out_; -}; - -std::pair> ControlFlowGraph::ControlFlowBlock::MakeBufferTouch( - const tirx::Buffer& buf, ffi::Array index_variables, ffi::Array indices, - BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const { - const auto& current_block = *this; - - Analyzer local_analyzer; - - ffi::Optional lane_var = std::nullopt; - IntImm num_lanes; - - ffi::Array index_expressions = indices.Map([&](const auto& index) { - if (index.dtype().lanes() == 1) { - return index; - } else { - TVM_FFI_ICHECK(!lane_var) << "Multiple indices found with non-scalar values"; - lane_var = Var("lane", index.dtype().element_of()); - num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes()); - return UnwrapVectorExpr(index, lane_var.value()); - } - }); - - ffi::Array loop_vars; - - ffi::Map loop_ranges; - for (const auto& loop_entry : current_block.active_loop_iterators) { - loop_vars.push_back(loop_entry.loop_var); - loop_ranges.Set(loop_entry.loop_var, loop_entry.loop_range); - } - - // If the indices contain multiple lanes, treat the lane variable - // as an additional loop iterator to be solved for and substituted - // out. - if (lane_var) { - loop_vars.push_back(lane_var.value()); - loop_ranges.Set(lane_var.value(), Range::FromMinExtent(0, num_lanes)); - } - - IntConstraintsTransform transform = [&]() { - TVM_FFI_ICHECK_EQ(index_variables.size(), index_expressions.size()); - - ffi::Array relations; - - for (size_t i = 0; i < index_expressions.size(); i++) { - PrimExpr expr = index_expressions[i]; - Var var = index_variables[i]; - - expr = Substitute(expr, current_block.let_bindings_using_loop); - relations.push_back(var == expr); - } - - IntConstraints system(loop_vars, loop_ranges, relations); - return arith::SolveLinearEquations(system); - }(); - - ffi::Map loop_var_to_axis_var = transform->src_to_dst; - ffi::Map free_params = transform->dst->ranges; - PrimExpr transform_predicate = - std::accumulate(transform->dst->relations.begin(), transform->dst->relations.end(), - PrimExpr(Bool(true)), [](PrimExpr a, PrimExpr b) { return a && b; }); - - transform_predicate = SimplifyAsAndOfOrs(transform_predicate, &local_analyzer); - - auto find_removable_params = [&]() -> ffi::Map { - ffi::Map removable_params; - - // The arith::SolveLinearEquations is more general than the - // utilities in iter_affine_map.h, but can introduce free - // parameters that could later be determined with the known - // constraints. This step removes all such free parameters. - for (const auto& expr : ExtractConstraints(transform_predicate)) { - if (auto* as_equal = expr.as()) { - auto check_expr = [&](const PrimExpr& a, const PrimExpr& b) { - auto* var_ptr = a.as(); - if (!var_ptr) { - return; - } - - Var var = ffi::GetRef(var_ptr); - if (free_params.count(var) == 0) { - return; - } - - bool uses_free_param = UsesVar( - b, [&](const VarNode* v) { return free_params.count(ffi::GetRef(v)) > 0; }); - if (uses_free_param) { - return; - } - removable_params.Set(var, b); - }; - check_expr(as_equal->a, as_equal->b); - check_expr(as_equal->b, as_equal->a); - } - } - - // In addition, the arith::SolveLinearEquation can introduce - // free parameters with an extent of one. Filtering them out here - // avoids needing to track them through later simplifications. - for (const auto [var, range] : free_params) { - if (is_one(range->extent)) { - removable_params.Set(var, range->min); - } - } - - return removable_params; - }; - for (auto removable_params = find_removable_params(); removable_params.size() > 0; - removable_params = find_removable_params()) { - auto update = [&](const PrimExpr& expr) { - return local_analyzer.Simplify(Substitute(expr, removable_params)); - }; - - ffi::Map new_map; - for (const auto [loop_var, expr] : loop_var_to_axis_var) { - static_cast(expr); // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 - new_map.Set(loop_var, update(expr)); - } - loop_var_to_axis_var = new_map; - - transform_predicate = update(transform_predicate); - - for (const auto [var, expr] : removable_params) { - static_cast(expr); // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 - free_params.erase(var); - } - } - - // Normalization function, applied to both the predicate and the - // known value. Converts from an expression in terms of loop - // iterators to an expression in terms of buffer indices. - auto normalize_expr = [&](PrimExpr expr) -> PrimExpr { - expr = Substitute(expr, current_block.let_bindings_using_loop); - - if (lane_var) { - expr = UnwrapVectorExpr(expr, lane_var.value()); - } - expr = Substitute(expr, loop_var_to_axis_var); - - return expr; - }; - - // Collect the current loop variables, along with an expression for - // the loop variables in terms of the buffer axis variables. This - // is used during forward/backward propagation to generate predicate - // tracking whether a loop iteration has been reached. - std::vector> loop_var_expressions; - for (const auto& entry : current_block.active_loop_iterators) { - auto expr_it = loop_var_to_axis_var.find(entry.loop_var); - TVM_FFI_ICHECK(expr_it != loop_var_to_axis_var.end()); - loop_var_expressions.push_back({entry.loop_var, (*expr_it).second}); - } - - // The full predicate is composed of the values required to reach - // the scope of the BufferStore or builtin::assume(), any bounds - // implied by solving for the axis variables, and any additional - // statements resulting from unpacking the expression contained in - // builtin::assume(). - PrimExpr scope_predicate = normalize_expr(current_block.scope_predicate); - transform_predicate = normalize_expr(transform_predicate); - - known_value_expr = local_analyzer.Simplify(normalize_expr(known_value_expr)); - - // Deliberately use an analyzer without scope-based information, - // to avoid simplifying `scope_predicate` to True. - PrimExpr predicate_expr = local_analyzer.Simplify(transform_predicate && scope_predicate); - - BufferTouch buffer_touch = {buf, predicate_expr, known_value_expr, loop_var_expressions, - touch_type}; - - return {buffer_touch, free_params}; -} - -BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph* graph, - const tirx::Buffer& buf, - const ffi::Array& indices, - BufferTouch::AccessType touch_type, - PrimExpr known_value_expr) const { - TVM_FFI_ICHECK(graph); - auto [buffer_touch, free_params] = MakeBufferTouch(buf, graph->GetIndexVariables(buf, indices), - indices, touch_type, known_value_expr); - for (const auto& pair : free_params) { - graph->free_predicate_parameters_.Set(pair.first, pair.second); - } - return buffer_touch; -} - -ControlFlowGraph::ControlFlowGraph(const tirx::Stmt& stmt, int64_t max_simplification_steps, - size_t max_revisits) - : max_revisits_(max_revisits), max_simplification_steps_(max_simplification_steps) { - ControlFlowGraphBuilder::Build(this, stmt); - ForwardPropagateKnownValues(); - BackwardPropagateUnusedValues(); -} - -void ControlFlowGraph::RemoveStore(const tirx::BufferStore& store) { - size_t context_index = [&]() { - auto it = control_flow_lookup_.find(store.get()); - TVM_FFI_ICHECK(it != control_flow_lookup_.end()) - << "BufferStore did not occur in the Stmt provided to BufferTouchPattern's constructor"; - return it->second; - }(); - - auto& touch_points = control_flow_[context_index].touch_points; - - touch_points.erase(std::remove_if(touch_points.begin(), touch_points.end(), - [](const BufferTouch& touch) { - return touch.touch_type == BufferTouch::AccessType::Write; - }), - touch_points.end()); - ForwardPropagateKnownValues(context_index); - BackwardPropagateUnusedValues(context_index); -} - -std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowEdge& edge) { - os << edge.index; - if (edge.var_remap.size()) { - os << " with remap " << edge.var_remap; - } - if (edge.post_condition) { - os << " with postcondition " << edge.post_condition; - } - - return os; -} - -std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowBlock& block) { - os << "Predecessors: ["; - for (size_t i = 0; i < block.predecessors.size(); i++) { - if (i) { - os << ", "; - } - os << block.predecessors[i]; - } - os << "]\n"; - - os << "Active loop iterators: ["; - for (size_t i = 0; i < block.active_loop_iterators.size(); i++) { - if (i) { - os << ", "; - } - os << block.active_loop_iterators[i].loop_var; - } - os << "]\n"; - - os << "Before block knowns: " << block.known_at_block_start << "\n"; - - os << "Before block unused: " << block.unused_at_block_start << "\n"; - - for (size_t i = 0; i < block.touch_points.size(); i++) { - os << "Touch[" << i << "] = " << block.touch_points[i] << "\n"; - } - os << "After block: " << block.known_at_block_end << "\n"; - - os << "After block unused: " << block.unused_at_block_end << "\n"; - - os << "Successors: ["; - for (size_t i = 0; i < block.successors.size(); i++) { - if (i) { - os << ", "; - } - os << block.successors[i]; - } - os << "]"; - return os; -} - -std::ostream& operator<<(std::ostream& os, const ControlFlowGraph& pattern) { - os << "Touch pattern contains " << pattern.control_flow_.size() << " control blocks." - << (pattern.control_flow_.size() ? "\n" : ""); - for (size_t i = 0; i < pattern.control_flow_.size(); i++) { - os << "\t" - << "ControlBlock[" << i << "] = " << pattern.control_flow_[i] << "\n"; - } - - return os; -} - -bool BufferTouch::IsEquivalentTo(const BufferTouch& other, Analyzer* analyzer) const { - // Constraints must apply to the same buffer to be equivalent - if (!buffer.same_as(other.buffer) || touch_type != other.touch_type) { - return false; - } - - ExprDeepEqual deep_equal; - - auto implies = [&](const PrimExpr& a, const PrimExpr& b) -> bool { - With context(analyzer, a); - return analyzer->CanProve(b); - }; - - // Predicates must be equivalent expressions, or must both be undefined - bool equivalent_predicates = - deep_equal(predicate, other.predicate) || - (implies(predicate, other.predicate) && implies(other.predicate, predicate)); - if (!equivalent_predicates) { - return false; - } - - // The known value must be equal - if (!deep_equal(value, other.value) && !analyzer->CanProveEqual(value, other.value)) { - return false; - } - - return true; -} - -std::ostream& operator<<(std::ostream& os, const BufferState& state) { - for (size_t i = 0; i < state.constraints_.size(); i++) { - os << "constraints[" << i << "] = " << state.constraints_[i] - << (i + 1 == state.constraints_.size() ? "" : "\n"); - } - return os; -} - -PrimExpr BufferState::SubstituteKnownBufferValues( - PrimExpr expr, const ffi::Map>& axis_var_lookup, - Analyzer* analyzer) const { - BufferConstraintApply mutator(axis_var_lookup, constraints_, analyzer); - return mutator(std::move(expr)); -} - -void BufferState::AddCondition(const PrimExpr& condition) { - for (auto& constraint : constraints_) { - constraint.predicate = constraint.predicate && condition; - } -} - -void BufferState::Substitute(const ffi::Map& var_remap, Analyzer* analyzer) { - if (var_remap.size()) { - for (auto& prior : constraints_) { - PrimExpr updated = tvm::tirx::Substitute(prior.predicate, var_remap); - if (!updated.same_as(prior.predicate)) { - prior.predicate = SimplifyAsAndOfOrs(updated, analyzer); - } - } - } -} - -void BufferState::Simplify(Analyzer* analyzer) { - for (auto& constraint : constraints_) { - constraint.predicate = SimplifyAsAndOfOrs(constraint.predicate, analyzer); - } -} - -void BufferState::Union(const BufferState& b, Analyzer* analyzer) { - for (const auto& b_constraint : b.constraints_) { - bool used = false; - for (auto& a_constraint : constraints_) { - if (a_constraint.buffer.same_as(b_constraint.buffer) && - analyzer->CanProveEqual(a_constraint.value, b_constraint.value)) { - a_constraint.predicate = - SimplifyAsAndOfOrs(a_constraint.predicate || b_constraint.predicate, analyzer); - used = true; - break; - } - } - if (!used) { - constraints_.push_back(b_constraint); - } - } -} - -void BufferState::Intersection(const BufferState& b, Analyzer* analyzer) { - // For a constraint to be in the output, it must be present in both - // inputs. - - std::vector new_constraints; - for (const auto& ai : constraints_) { - for (const auto& bi : b.constraints_) { - if (ai.buffer.same_as(bi.buffer)) { - PrimExpr predicate = SimplifyAsAndOfOrs(ai.predicate && bi.predicate, analyzer); - if (!is_zero(predicate)) { - With context(analyzer, predicate); - PrimExpr known_value_a = ai.value; - PrimExpr known_value_b = bi.value; - - bool is_consistent = analyzer->CanProveEqual(known_value_a, known_value_b); - if (is_consistent) { - new_constraints.push_back({ai.buffer, predicate, known_value_a}); - } - } - } - } - } - - constraints_ = std::move(new_constraints); -} - -class BufferRegionCollector : public ExprVisitor { - public: - struct Region { - PrimExpr region_predicate; - std::unordered_map> known_values; - }; - - static std::vector Collect(const ffi::Map>& axis_var_lookup, - const std::vector& knowns, - const std::vector>& exprs, - Analyzer* analyzer) { - BufferRegionCollector collector(axis_var_lookup, knowns, analyzer); - for (const auto& expr : exprs) { - if (expr) { - collector(expr.value()); - } - } - - return collector.regions_; - } - - private: - using Parent = ExprVisitor; - - BufferRegionCollector(const ffi::Map>& axis_var_lookup, - const std::vector& knowns, Analyzer* analyzer) - : analyzer_(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) { - regions_.push_back(Region{Bool(true), {}}); - } - - using Parent::VisitExpr_; - - void VisitExpr_(const BufferLoadNode* op) override { - // Helper struct for the known values of this BufferLoad - struct Known { - PrimExpr predicate; - ffi::Optional value; - }; - - std::vector new_regions; - - PrimExpr unknown_region = Bool(true); - - for (const BufferTouch& constraint : knowns_) { - if (!op->buffer.same_as(constraint.buffer)) { - // This is a different buffer, so continue searching. - continue; - } - - auto axis_vars = axis_var_lookup_.at(op->buffer); - PrimExpr touch_predicate = - SubstituteParamValues(axis_vars, op->indices, constraint.predicate).value(); - touch_predicate = SimplifyAsAndOfOrs(touch_predicate, analyzer_); - - if (!is_zero(touch_predicate)) { - ffi::Optional known_value = - SubstituteParamValues(axis_vars, op->indices, constraint.value); - new_regions.push_back(Known{touch_predicate, known_value}); - - unknown_region = unknown_region && !touch_predicate; - unknown_region = SimplifyAsAndOfOrs(unknown_region, analyzer_); - } - } - - if (new_regions.size()) { - Analyzer local_analyzer; - - if (!is_zero(unknown_region)) { - new_regions.insert(new_regions.begin(), Known{unknown_region, std::nullopt}); - } - - std::vector updated_regions; - for (const auto& prev_region : regions_) { - for (const auto& new_region : new_regions) { - PrimExpr intersection = - SimplifyAsAndOfOrs(prev_region.region_predicate && new_region.predicate, analyzer_); - - if (!is_zero(intersection)) { - Region merged{intersection, prev_region.known_values}; - merged.known_values[op] = new_region.value; - updated_regions.push_back(std::move(merged)); - } - } - } - regions_ = updated_regions; - } - } - - Analyzer* analyzer_; - std::vector regions_; - const ffi::Map>& axis_var_lookup_; - const std::vector& knowns_; -}; - -class BufferRegionValueReplacer : public IRMutatorWithAnalyzer { - public: - static PrimExpr Apply( - const std::unordered_map>& known_values, - PrimExpr expr, Analyzer* analyzer) { - BufferRegionValueReplacer mutator(known_values, analyzer); - PrimExpr result = mutator(expr); - // Simplification must occur after the substitution, as known - // values may provide enable simplifications. Also, cannot track - // whether a BufferLoad was - result = analyzer->Simplify(result); - return result; - } - - private: - using Parent = IRMutatorWithAnalyzer; - - BufferRegionValueReplacer( - const std::unordered_map>& known_values, - Analyzer* analyzer) - : Parent(analyzer), known_values_(known_values) {} - - using Parent::VisitExpr_; - - PrimExpr VisitExpr_(const BufferLoadNode* op) override { - auto it = known_values_.find(op); - if (it != known_values_.end() && it->second) { - return it->second.value(); - } else { - return ffi::GetRef(op); - } - } - - const std::unordered_map>& known_values_; -}; - -void BufferState::ApplyTouches(const ffi::Map>& axis_var_lookup, - const std::vector& touch_points, Analyzer* analyzer) { - std::vector new_knowns; - ffi::Map keep_prior_known_at; - - for (auto& touch : touch_points) { - if (touch.touch_type == BufferTouch::AccessType::Read) { - continue; - } - - PrimExpr known_value = touch.value; - - PrimExpr predicate = touch.predicate && touch.AfterLoopIteration(); - auto regions = BufferRegionCollector::Collect(axis_var_lookup, constraints_, - {predicate, touch.value}, analyzer); - - for (const auto& region : regions) { - PrimExpr updated_predicate = BufferRegionValueReplacer::Apply( - region.known_values, region.region_predicate && predicate, analyzer); - - updated_predicate = SimplifyAsAndOfOrs(updated_predicate, analyzer); - PrimExpr updated_value = - BufferRegionValueReplacer::Apply(region.known_values, known_value, analyzer); - - if (!is_zero(updated_predicate)) { - if (auto it = keep_prior_known_at.find(touch.buffer); it != keep_prior_known_at.end()) { - keep_prior_known_at.Set(touch.buffer, (*it).second && !updated_predicate); - } else { - keep_prior_known_at.Set(touch.buffer, !updated_predicate); - } - - if (!HasBufferLoad(updated_value)) { - BufferTouch new_constraint{touch.buffer, updated_predicate, updated_value}; - new_knowns.push_back(new_constraint); - } - } - } - } - - if (keep_prior_known_at.size()) { - for (auto& constraint : constraints_) { - if (auto it = keep_prior_known_at.find(constraint.buffer); it != keep_prior_known_at.end()) { - constraint.predicate = SimplifyAsAndOfOrs(constraint.predicate && (*it).second, analyzer); - } - } - } - - if (new_knowns.size()) { - std::vector used(new_knowns.size(), false); - - for (auto& constraint : constraints_) { - PrimExpr expand_known_at = Bool(false); - - PrimExpr prev_value = constraint.value; - - for (size_t i = 0; i < new_knowns.size(); i++) { - if (new_knowns[i].buffer.same_as(constraint.buffer)) { - ffi::Optional overwritten_with = new_knowns[i].value; - if (overwritten_with && analyzer->CanProveEqual(prev_value, overwritten_with.value())) { - expand_known_at = - SimplifyAsAndOfOrs(expand_known_at || new_knowns[i].predicate, analyzer); - used[i] = true; - } - } - } - - if (!is_zero(expand_known_at)) { - constraint.predicate = - SimplifyAsAndOfOrs(constraint.predicate || expand_known_at, analyzer); - } - } - - for (size_t i = 0; i < new_knowns.size(); i++) { - if (!used[i]) { - constraints_.push_back(new_knowns[i]); - } - } - } - - constraints_.erase( - std::remove_if(constraints_.begin(), constraints_.end(), - [&](const auto& constraint) { return is_zero(constraint.predicate); }), - constraints_.end()); -} - -void BufferState::BackpropUnusedIndices(const ffi::Map>& axis_var_lookup, - const std::vector& touch_points, - Analyzer* analyzer) { - std::vector new_knowns; - ffi::Map keep_prior_known_at; - - ffi::Map regions_written; - ffi::Map regions_read; - for (auto it = touch_points.rbegin(); it != touch_points.rend(); it++) { - const auto& touch = *it; - - ffi::Map* to_update{nullptr}; - if (touch.touch_type == BufferTouch::AccessType::Write) { - to_update = ®ions_written; - - } else if (touch.touch_type == BufferTouch::AccessType::Read) { - to_update = ®ions_read; - } else { - continue; - } - - PrimExpr prev = to_update->Get(touch.buffer).value_or(Bool(false)); - PrimExpr new_predicate = touch.predicate && touch.BeforeLoopIteration(); - to_update->Set(touch.buffer, prev || new_predicate); - } - - auto update_map = [&](auto& map) { - ffi::Map new_map; - for (auto [buffer, predicate] : map) { - new_map.Set(buffer, SimplifyAsAndOfOrs(predicate, analyzer)); - } - map = std::move(new_map); - }; - update_map(regions_written); - update_map(regions_read); - - // If buffer is already in used, widen the predicate - for (auto& prev_unused : constraints_) { - if (auto opt_predicate = regions_written.Get(prev_unused.buffer)) { - PrimExpr new_predicate = prev_unused.predicate || opt_predicate.value(); - prev_unused.predicate = SimplifyAsAndOfOrs(new_predicate, analyzer); - regions_written.erase(prev_unused.buffer); - } - } - - // Otherwise, add new "touch" to represent the unused values - for (auto [buffer, predicate] : regions_written) { - constraints_.push_back( - BufferTouch{buffer, predicate, tirx::Call(buffer->dtype, builtin::undef(), {})}); - } - - // If buffer is read out, narrow the predicate - for (auto& prev_unused : constraints_) { - if (auto opt_pred = regions_read.Get(prev_unused.buffer)) { - PrimExpr predicate = opt_pred.value(); - prev_unused.predicate = SimplifyAsAndOfOrs(prev_unused.predicate && !predicate, analyzer); - } - } - - // Clean-up and remove any empty constraints - constraints_.erase( - std::remove_if(constraints_.begin(), constraints_.end(), - [](const auto& constraint) { return is_zero(constraint.predicate); }), - constraints_.end()); -} - -void BufferState::RemoveFreeParameters(const ffi::Map& free_predicate_parameters, - Analyzer* analyzer) { - for (auto& known : constraints_) { - known.predicate = NarrowPredicateExpression(known.predicate, free_predicate_parameters); - known.predicate = SimplifyAsAndOfOrs(known.predicate, analyzer); - } -} - -bool BufferState::IsEquivalentTo(const BufferState& other, Analyzer* analyzer) const { - if (constraints_.size() != other.constraints_.size()) { - return false; - } - - for (size_t i = 0; i < constraints_.size(); i++) { - if (!constraints_[i].IsEquivalentTo(other.constraints_[i], analyzer)) { - return false; - } - } - - return true; -} - -ffi::Optional> ControlFlowGraph::GetIndexVariables(const Buffer& buf) const { - if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) { - return (*it).second; - } else { - return std::nullopt; - } -} - -ffi::Array ControlFlowGraph::GetIndexVariables(const Buffer& buf, - const ffi::Array& indices) { - if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) { - return (*it).second; - } - - ffi::Array vars; - for (size_t i = 0; i < indices.size(); i++) { - std::stringstream ss; - ss << buf->name << "_axis_" << i; - vars.push_back(Var(ss.str(), indices[i].dtype().element_of())); - } - - axis_var_lookup_.Set(buf, vars); - return vars; -} - -void ControlFlowGraph::ForwardPropagateKnownValues(std::optional flow_from) { - // Values to visit when searching. Using a std::set to - // preferentially visit nodes near the start of the control flow. - std::set to_visit; - - if (flow_from.has_value()) { - to_visit.insert(flow_from.value()); - } else { - // Initiatize the locations to search from, propagating values - // forward from all locations that have a known value. - for (size_t i = 0; i < control_flow_.size(); i++) { - bool has_known_value = false; - for (const auto& touch : control_flow_[i].touch_points) { - if (!HasBufferLoad(touch.value)) { - has_known_value = true; - break; - } - } - - if (has_known_value) { - to_visit.insert(i); - } - } - } - - // Map from a block's index - std::unordered_map visit_count_lookup; - - Analyzer analyzer; - analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_); - analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( - arith::RewriteSimplifier::kTransitivelyProveInequalities | - arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | - arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); - - analyzer.Bind(iterator_ranges_); - analyzer.Bind(free_predicate_parameters_); - - while (to_visit.size()) { - size_t visiting = *to_visit.begin(); - to_visit.erase(visiting); - - size_t num_previous_visits = visit_count_lookup[visiting]++; - - ControlFlowBlock& block = control_flow_[visiting]; - - // Step 1: Collect known values provided from each predecessor - block.known_at_block_start = [&]() -> BufferState { - if (num_previous_visits >= max_revisits_) { - return BufferState(); - } - - // Validate internal constraint. This should be true by - // construction, as ControlFlowGraphBuilder only builds graphs - // that have two or fewer predecessors. - TVM_FFI_CHECK_LE(block.predecessors.size(), 2, InternalError) - << "Each block should have at most two predecessors. " - << "Graph constructed in ControlFlowGraphBuilder did not satisfy this constraint."; - - std::vector states; - for (const auto& pred : block.predecessors) { - const auto& pred_block = control_flow_[pred.index]; - BufferState state = pred_block.known_at_block_end; - state.Substitute(pred.var_remap, &analyzer); - states.push_back(state); - } - - if (std::all_of(block.predecessors.begin(), block.predecessors.end(), - [&](const auto& pred) { return visit_count_lookup[pred.index] == 0; })) { - // Predecessors, if any, are unvisited. - return {}; - } else if (block.predecessors.size() == 1) { - // SBlock has only a single predecessor - return states[0]; - } - - const auto& pred_a = block.predecessors[0]; - const auto& pred_b = block.predecessors[1]; - - auto& priors_a = states[0]; - auto& priors_b = states[1]; - - // During the first visit of a block, predecessor blocks may be - // unvisited, even though we preferentially visit earlier blocks - // first. (e.g. During the first visit of the start of a For - // loop, the end of the For loop has not yet been visited.) If - // this is the case, assume the best-case scenario that all - // knowns are consistent, and rely on a later visit to - // resolve/remove any conflicts. - if (visit_count_lookup[pred_a.index] == 0) { - return priors_b; - } else if (visit_count_lookup[pred_b.index] == 0) { - return priors_a; - } - - if (pred_a.post_condition && pred_b.post_condition) { - // The predicate can identify which predecessor block applies - // (e.g. i==0 for the first loop iteration, i>0 for remaining - // loop iterations). Therefore, we can use all buffer - // constraints, conditional on having come from the - // predecessor that provides it. - priors_a.AddCondition(pred_a.post_condition.value()); - priors_b.AddCondition(pred_b.post_condition.value()); - priors_a.Union(priors_b, &analyzer); - return priors_a; - } else { - // We don't know which predecessor applies. Therefore, the - // only buffer constraints that can be used are those that - // appear in both predecessors. - priors_a.Intersection(priors_b, &analyzer); - return priors_a; - } - }(); - - // Step 2: Collect knowns provided as a result of executing this block - auto post_state = [&]() { - if (num_previous_visits >= max_revisits_) { - return BufferState(); - } - auto post_state = block.known_at_block_start; - post_state.ApplyTouches(axis_var_lookup_, block.touch_points, &analyzer); - post_state.RemoveFreeParameters(free_predicate_parameters_, &analyzer); - return post_state; - }(); - - // Step 3: If any changes are made to the post knowns since the - // previous time we visited this block, mark the successor block - // as needing to be visited. - if (num_previous_visits == 0 || - !post_state.IsEquivalentTo(block.known_at_block_end, &analyzer)) { - block.known_at_block_end = std::move(post_state); - for (const auto& successor : block.successors) { - to_visit.insert(successor.index); - } - } - } -} - -void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional flow_from) { - // Values to visit when searching. Using a std::set to - // preferentially visit nodes near the end of the control flow. - std::set to_visit; - - if (flow_from.has_value()) { - to_visit.insert(flow_from.value()); - } else { - // Initiatize the locations to search from, propagating values - // backward from anywhere that performs a write. - for (size_t i = 0; i < control_flow_.size(); i++) { - const auto& touch_points = control_flow_[i].touch_points; - bool performs_write = std::any_of( - touch_points.begin(), touch_points.end(), - [](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; }); - if (performs_write) { - to_visit.insert(i); - } - } - } - - // Map from a block's index - std::unordered_map visit_count_lookup; - - Analyzer analyzer; - analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_); - analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( - arith::RewriteSimplifier::kTransitivelyProveInequalities | - arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | - arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); - - analyzer.Bind(iterator_ranges_); - analyzer.Bind(free_predicate_parameters_); - - while (to_visit.size()) { - size_t visiting = *to_visit.rbegin(); - to_visit.erase(visiting); - - size_t num_previous_visits = visit_count_lookup[visiting]++; - - ControlFlowBlock& block = control_flow_[visiting]; - - // Step 1: Collect known unused indices provided by each successor - block.unused_at_block_end = [&]() -> BufferState { - if (num_previous_visits >= max_revisits_) { - return BufferState(); - } - TVM_FFI_ICHECK_LE(block.successors.size(), 2) - << "Each block should have at most two successors, but block " << visiting - << " breaks this requirement"; - - std::vector states; - for (const auto& successor : block.successors) { - const auto& successor_block = control_flow_[successor.index]; - BufferState state = successor_block.unused_at_block_start; - state.Substitute(successor.var_remap, &analyzer); - states.push_back(state); - } - - if (std::all_of(block.successors.begin(), block.successors.end(), [&](const auto& successor) { - return visit_count_lookup[successor.index] == 0; - })) { - // Successors, if any, are unvisited. - return {}; - } else if (block.successors.size() == 1) { - // SBlock has only a single successor - return states[0]; - } - - const auto& successor_a = block.successors[0]; - const auto& successor_b = block.successors[1]; - - auto& post_a = states[0]; - auto& post_b = states[1]; - - // During the first visit of a block, successor blocks may be - // unvisited, even though we preferentially visit later blocks - // first. (e.g. During the first visit of the end of a For - // loop, the start of the For loop has not yet been visited.) - // If this is the case, assume the best-case scenario that all - // knowns are consistent, and rely on a later visit to - // resolve/remove any conflicts. - if (visit_count_lookup[successor_a.index] == 0) { - return post_b; - } else if (visit_count_lookup[successor_b.index] == 0) { - return post_a; - } - - if (successor_a.post_condition && successor_b.post_condition) { - // The predicate can identify which successor block applies - // (e.g. i==n-1 for the last loop iteration, i= max_revisits_) { - return BufferState(); - } - auto prior_state = block.unused_at_block_end; - prior_state.BackpropUnusedIndices(axis_var_lookup_, block.touch_points, &analyzer); - prior_state.RemoveFreeParameters(free_predicate_parameters_, &analyzer); - return prior_state; - }(); - - // Step 3: If any changes are made to the post knowns since the - // previous time we visited this block, mark the successor block - // as needing to be visited. - if (num_previous_visits == 0 || - !unused_at_block_start.IsEquivalentTo(block.unused_at_block_start, &analyzer)) { - block.unused_at_block_start = std::move(unused_at_block_start); - for (const auto& pred : block.predecessors) { - to_visit.insert(pred.index); - } - } - } -} - -bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tirx::BufferStore& store, - const Stmt& context) const { - ffi::Optional> index_variables = GetIndexVariables(store->buffer); - if (!index_variables) { - return false; - } - - auto it = control_flow_lookup_.find(context.get()); - TVM_FFI_ICHECK(it != control_flow_lookup_.end()) - << "Context did not occur within analyzed statement:\n" - << context; - const auto& context_block = control_flow_[it->second]; - - auto [store_touch, free_params] = context_block.MakeBufferTouch( - store->buffer, index_variables.value(), store->indices, BufferTouch::AccessType::Write, - BufferLoad(store->buffer, store->indices)); - - Analyzer local_analyzer; - local_analyzer.Bind(free_predicate_parameters_); - local_analyzer.Bind(iterator_ranges_); - local_analyzer.Bind(free_params); - local_analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( - arith::RewriteSimplifier::kTransitivelyProveInequalities | - arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | - arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); - - PrimExpr predicate = store_touch.predicate && store_touch.AtLoopIteration(); - - predicate = SimplifyAsAndOfOrs(predicate, &local_analyzer); - - for (const auto& unused : context_block.unused_at_block_end.constraints_) { - if (store_touch.buffer.same_as(unused.buffer)) { - PrimExpr difference = SimplifyAsAndOfOrs(predicate && !unused.predicate, &local_analyzer); - if (is_zero(difference)) { - return true; - } - } - } - return false; -} - -PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tirx::Stmt& context, - Analyzer* analyzer) const { - size_t context_index = [&]() { - auto it = control_flow_lookup_.find(context.get()); - TVM_FFI_ICHECK(it != control_flow_lookup_.end()) - << "Context did not occur in the Stmt provided to BufferTouchPattern's constructor"; - return it->second; - }(); - - const auto& control_flow_block = control_flow_[context_index]; - - PrimExpr constraint = Bool(true); - for (const auto& known : non_buffer_assumptions_) { - constraint = constraint && known; - } - With constraint_context(analyzer, constraint); - With control_flow_scope(analyzer, control_flow_block.scope_predicate); - - expr = control_flow_block.known_at_block_start.SubstituteKnownBufferValues( - std::move(expr), axis_var_lookup_, analyzer); - - expr = analyzer->Simplify(std::move(expr)); - return expr; -} - -} // namespace tirx -} // namespace tvm diff --git a/src/tirx/analysis/control_flow_graph.h b/src/tirx/analysis/control_flow_graph.h deleted file mode 100644 index 8f97d06f384e..000000000000 --- a/src/tirx/analysis/control_flow_graph.h +++ /dev/null @@ -1,667 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file control_flow_graph.h - * \brief Utility for extracting and interacting with buffer touch points - */ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#ifndef TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_ -#define TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_ - -namespace tvm { -namespace tirx { - -/*! \brief Represents an interaction with a buffer */ -struct BufferTouch { - enum class AccessType { - /*! \brief Buffer access occurs in BufferLoad */ - Read, - - /*! \brief Buffer access occurs in BufferStore */ - Write, - - /*! \brief Buffer access occurs in tirx::builtin::assume() */ - Assume, - }; - - BufferTouch(Buffer buffer, PrimExpr predicate, PrimExpr value) - : buffer(buffer), - predicate(predicate), - value(value), - loop_var_expressions({}), - touch_type(AccessType::Assume) {} - - BufferTouch(Buffer buffer, PrimExpr predicate, PrimExpr value, - std::vector> loop_var_expressions, AccessType touch_type) - : buffer(buffer), - predicate(predicate), - value(value), - loop_var_expressions(loop_var_expressions), - touch_type(touch_type) {} - - /*! \brief The buffer being touched */ - Buffer buffer; - - /*! \brief A predicate that is true when this touch applies - * - * May be in terms of axis variables to indicate touches that impact - * only a portion of a buffer. - */ - PrimExpr predicate; - - /*! \brief The value in this buffer after the touch - * - * May be in terms of axis variables to indicate a known - * non-constant value. May be in terms of a BufferLoad to indicate - * an unknown value. - */ - PrimExpr value; - - /*! \brief Active loops during the buffer touch - * - * The vector contains one entry for each loop that contains the - * buffer touch. The `Var` item in each entry is the loop variable - * itself. The `PrimExpr` item is an expression for the loop - * variable in terms of the buffer axis variables in - * `ControlFlowGraph::axis_var_lookup_`. - * - * Used to construct boolean expressions indicating whether the loop - * iteration that performs this touch has been reached. - */ - std::vector> loop_var_expressions; - - /*! \brief How the buffer was interacted with - * - * When used as a constraint (e.g. in BufferState), should use - * Assume. - */ - AccessType touch_type{AccessType::Assume}; - - /*! \brief Generate a boolean expression that is true for indices - * accessed by this touch during this iteration or a previous - * loop iteration. - * - * Used during forward propagation, to track known values that were - * written in the current loop iteration, or in a preceding loop - * iteration. - */ - PrimExpr BeforeLoopIteration() const; - - /*! \brief Generate a boolean expression that is true for indices - * accessed by this touch during this loop iteration. - * - * Used during speculative no-op insertion checks, to specify which - * indices must be later overwritten for a store to have no impact - * on final results. - */ - PrimExpr AtLoopIteration() const; - - /*! \brief Generate a boolean expression that is true for indices - * accessed by this touch during this loop iteration or a - * subsequent loop iteration. - * - * Used during backward propagation, to track indices that are - * overwritten in the current loop iteration or in a later loop - * iteration. - */ - PrimExpr AfterLoopIteration() const; - - /* \brief Checks if this touch affects a subset of indices of another - * - * Returns true if the indices accessed by this touch are a subset - * of predicate is true can be proven to be a subset of the other - * subset. Returns false if it cannot be proven to be a subset of - * ther other subset. - */ - bool IsSubsetOf(const BufferTouch& other, arith::Analyzer* analyzer) const; - - /* \brief Checks if this touch affects distinct indices from another - * - * Returns true if it can be proven that the two predicates cannot - * be simultaneously true. Returns false if it cannot be proven - * that the two predicates are distinct. - */ - bool IsDistinctFrom(const BufferTouch& other, arith::Analyzer* analyzer) const; - - /* \brief Checks if this touch affects distinct indices from another - * - * Returns true if it can be proven that the two predicates cannot - * be simultaneously true. Returns false if it cannot be proven - * that the two predicates are distinct. - */ - bool IsEquivalentTo(const BufferTouch& other, arith::Analyzer* analyzer) const; - - friend std::ostream& operator<<(std::ostream& os, const BufferTouch& expr); -}; - -/*! \brief Represents the known state of buffers at a specific point */ -class BufferState { - public: - /*! Default constructor - * - * Initialize the buffer state with no known information. - */ - BufferState() {} - - /*! \brief Replace BufferLoad instances with known values - * - * \param expr The expression to be updated. - * - * \param axis_var_lookup A map from buffer to the variables - * representing positions along the buffer's axes. - * - * \param analyzer The analyzer to use when validating a - * constraint's predicate. - * - * \returns The modified expression. If no substitutions are made, - * the original expression is returned. - */ - PrimExpr SubstituteKnownBufferValues(PrimExpr expr, - const ffi::Map>& axis_var_lookup, - arith::Analyzer* analyzer) const; - - /*! \brief Apply a condition to all known constraints - * - * For example, when propagating pre-loop constraints into the body - * of a loop, add a condition that the loop iterator is zero. - * - * \param condition The condition to apply - */ - void AddCondition(const PrimExpr& condition); - - /*! \brief Perform a variable substitution for all constraints - * - * For example, when propagating constraints from the end of a loop - * to the beginning, replace `i` with `i-1`. - * - * \param var_remap The variable remapping to apply. - */ - void Substitute(const ffi::Map& var_remap, arith::Analyzer* analyzer); - - /*! \brief Simplify the predicate of all constraints - * - * \param analyzer The analyzer with which to simplify - */ - void Simplify(arith::Analyzer* analyzer); - - /*! \brief Update the known buffer values based on buffer touches - * - * For any Write or Assume touches, update the known values. For - * any Read touches, ignore. Used to determine known values at the - * end of a control flow block, given the known values at the start. - * - * \param axis_var_lookup A map from buffer to the variables - * representing positions along the buffer's axes. - * - * \param touch_points The buffer touch points to apply - * - * \param analyzer The analyzer to use for simplifications - */ - void ApplyTouches(const ffi::Map>& axis_var_lookup, - const std::vector& touch_points, arith::Analyzer* analyzer); - - /*! \brief Update unused buffer locations based on buffer touches - * - * For any Write, mark the written-to indices as unused. (That is, - * immediately prior to assigning `buf[i] = expr`, the value stored - * at `buf[i]` is irrelevant.) For any Read, mark the read-from - * indices as used. This method is used to determine unused buffer - * indices at the start of a control flow block, given the unused - * buffer indices values at the end. - * - * \param axis_var_lookup A map from buffer to the variables - * representing positions along the buffer's axes. - * - * \param touch_points The buffer touch points to apply - * - * \param analyzer The analyzer to use for simplifications - */ - void BackpropUnusedIndices(const ffi::Map>& axis_var_lookup, - const std::vector& touch_points, - arith::Analyzer* analyzer); - - /*! \brief Remove free parameters from the constraints - * - * \param free_predicate_parameters - * - * \param analyzer The analyzer with which to simplify after removal - */ - void RemoveFreeParameters(const ffi::Map& free_predicate_parameters, - arith::Analyzer* analyzer); - - /*! \brief Check if two buffer states are equivalent - * - * \param other - * - * \param analyzer The analyzer used to check equality of PrimExpr - * - * \return True if the two states are provably equivalent, false otherwise. - */ - bool IsEquivalentTo(const BufferState& other, arith::Analyzer* analyzer) const; - - /* \brief Add known values provided by another state - * - * \param other The state with which to merge constraints - * - * \param analyzer The analyzer with which to simplify the result - */ - void Union(const BufferState& other, arith::Analyzer* analyzer); - - /* \brief Remove all known values not consistent with another state - * - * \param other The state with which to merge constraints - * - * \param analyzer The analyzer with which to simplify the result - */ - void Intersection(const BufferState& other, arith::Analyzer* analyzer); - - friend std::ostream& operator<<(std::ostream& os, const BufferState&); - - private: - friend class ControlFlowGraph; - /*! \brief The known constraints */ - std::vector constraints_; -}; - -/*! - * \brief Represents the flow of control through a `tirx::Stmt` - * - * This class contains an internal representation of the possible - * control flow that may occur during execution of a `tirx::Stmt`. It - * consists of a collection of ControlFlowBlock objects, each of which - * represents a subset of operations performed during execution, along - * with edges that represent allowed transitions between - * `ControlFlowBlock`. - * - * In addition, the following restrictions are used. - * - * 1. Each block may have at most two predecessors, and at most two - * successors. - * - * 2. Within each block, values stored in a buffer do not change. - * That is, encountering a `BufferStore` node requires creating a - * new block. - * - * For example, consider the following PrimFunc - * - * \code{.py} - * @T.prim_func - * def func(T.Buffer(16, "float32")): - * for i in T.serial(16): - * if i < 8: - * B[i] = i - * else: - * B[i] = i-8 - * \endcode - * - * The control flow graph would have eight control blocks. - * - * 1. function_entry, from the start of the function through the - * evaluation of the loop's extent. - * - * Predecessors: n/a - * Successors: loop_start - * - * 2. loop_start, after entering the body of the loop, through the - * evaluation of the conditional `i < 8` - * - * Predecessors: function_entry, after_conditional - * Successors: then_clause_start, else_clause_start - * - * 3. then_clause_start, after entering the then_clause of `i < 8`, - * through evaluation of the value `i`. - * - * Predecessors: loop_start - * Successors: then_clause_end - * - * 4. then_clause_end, after storing to `B[i]` prior to exiting the - * then_clause. - * - * Predecessors: then_clause_start - * Successors: after_conditional - * - * 5. else_clause_start, after entering the else_clause of `i < 8`, - * through evaluation of the value `i-8`. - * - * Predecessors: loop_start - * Successors: else_clause_end - * - * 6. else_clause_end, after storing to `B[i]` prior to exiting the - * else_clause. - * - * Predecessors: else_clause_start - * Successors: after_conditional - * - * 7. after_conditional, after the end of the if/then/else, before the - * end of the loop body - * - * Predecessors: then_clause_end, else_clause_end - * Successors: loop_start, after_loop - * - * 8. after_loop, after the loop - * - * Predecessors: after_conditional - * Successors: n/a - * - * - * By identifying `BufferStore` nodes whose value does not depend on - * values stored in input buffers (e.g. initializing `buf[i] = 0.0`), - * or whose values are provided using `builtin::assume()` - * (e.g. `T.assume(buf[i] == 0.0)`), the value stored in a buffer at - * those indices may be known for a given control block. These known - * values can then be propagated forward to successor blocks, to be - * used in context-dependent simplifications. - * - * In addition to the allowed transitions between control-flow - * blocks, each block also tracks the buffer touch points; which - * indices are read from a buffer, which values are written to which - * indices of a buffer, and assumptions are provided using - * `builtin::assume()`; that occur during the control-flow block. - * - * Note: The current implementation only tracks the values of - * buffers that are constrained to a specific value, and does not - * track inequalities that may partially constrain buffer values. - * That is, entering a scoped context with a data-dependent equality - * condition (e.g. `if buf[i] == value`) is tracked, but entering a - * scoped context with a data-dependent inequality condition - * (e.g. `if buf[i] > value`) is not tracked. - */ -class ControlFlowGraph { - public: - /* \brief Extract the touch pattern from a TIR statement - */ - explicit ControlFlowGraph(const Stmt& stmt, int64_t max_simplification_steps = 0, - size_t max_revisits = 5); - - /* \brief Check if a write is overwritten without impacting final results - * - * \param store The store to be examined - * - * \param context The context in which the buffer store occurs, used - * to identify the control-flow block in which the store occurs. In - * most cases, this will be the same object as the `store` itself. - * - * \param analyzer The analyzer to be used for simplifications - * - * \return True if the specified store can be proven to be - * overwritten without contributing to any later statements. - * Returns false otherwise. - */ - bool IsOverwrittenWithoutEffect(const BufferStore& store, const Stmt& context) const; - - /* \brief Simplify the expression, assuming it occurs within the given context - * - * \param expr The expression to be simplified. Does not need to - * have occurred within the statement used to construct this - * BufferTouchPattern. - * - * \param context The statement where this expression occurred, or - * is to be inserted. Must occur within the statement used to - * construct this BufferTouchPattern. - * - * \param analyzer The analyzer to be used for simplifications - * - * \returns The simplified statement - */ - PrimExpr SimplifyInContext(PrimExpr expr, const Stmt& context, arith::Analyzer* analyzer) const; - - /*! \brief Remove the specified BufferStore from the control-flow - * graph - * - * Removing the specified store, which may reflow known values. - * This is necessary when simplifying sequential stores of the same - * value. Otherwise, the first could be removed as a no-op because - * it is overwritten by the second, and the second could be removed - * as a no-op because it is the same value as the first. - * - * \param store The store to remove - */ - void RemoveStore(const tirx::BufferStore& store); - - friend std::ostream& operator<<(std::ostream& os, const ControlFlowGraph& pattern); - - private: - /*! \brief Return index variables representing locations within a - * buffer. - * - * For a given buffer, will always return the same set of variables. - * - * \param buf The buffer being accessed - * - * \param indices The indices at which the buffer is being accessed. - * These are used to set the dtype of the buffer axis variables. - * - * \returns Variables representing a position along the buffer's axis. - */ - ffi::Array GetIndexVariables(const Buffer& buf, const ffi::Array& indices); - - /*! \brief Return index variables representing locations within a - * buffer, if they have been generated before. - * - * For a given buffer, will always return the same set of variables. - * - * \param buf The buffer being accessed - * - * \returns Variables representing a position along the buffer's axis. - */ - ffi::Optional> GetIndexVariables(const Buffer& buf) const; - - /*! \brief Propagate known values from known BufferStore/assume - * subsequent control flow blocks - * - * \param flow_from If specified, re-flow only from that block. - */ - void ForwardPropagateKnownValues(std::optional flow_from = std::nullopt); - - /*! \brief Propagate overwritten/unused indices to preceding control - * flow blocks - * - * \param flow_from If specified, re-flow only from that block. - */ - void BackwardPropagateUnusedValues(std::optional flow_from = std::nullopt); - - struct ControlFlowEdge { - /* \brief The source block of the control flow edge - * - * Lookup index into `control_flow_` - */ - size_t index; - - /*! \brief Variable remaps - * - * e.g. Replacing loop iterator `i` with `i-1` when following an - * edge from the end of a loop to the beginning of the loop. - */ - ffi::Map var_remap; - - /*! \brief Condition that must to true after following this edge - * - * This is applied after variable remapping. For example, `i > - * loop_min` when following the an edge from the end of a loop to - * the beginning of the loop. - */ - ffi::Optional post_condition; - }; - friend std::ostream& operator<<(std::ostream& os, const ControlFlowEdge& edge); - - struct ControlFlowBlock { - struct LoopEntry { - Var loop_var; - PrimExpr loop_min; - PrimExpr loop_max; - Range loop_range; - }; - - /*! \brief Loop iterators that are active during this block */ - std::vector active_loop_iterators; - - /*! \brief Loop-dependent Let bindings that may appear within the block */ - ffi::Map let_bindings_using_loop; - - /*! \brief Predicate that must be true to have reached this block */ - PrimExpr scope_predicate{Bool(true)}; - - /*! \brief All known values prior to executing the block */ - BufferState known_at_block_start; - - /*! \brief All known values after executing the block */ - BufferState known_at_block_end; - - /*! \brief Indices whose value at the start of the block is known to be unused */ - BufferState unused_at_block_start; - - /*! \brief Indices whose value at the end of the block is known to be unused */ - BufferState unused_at_block_end; - - /* \brief Buffer touches that occur within the block - * - * All buffer touches within a block can be treated as occurring - * simultaneously. - */ - std::vector touch_points; - - /* \brief The blocks that occur after this block - * - * Lookup index into `control_flow_` - */ - std::vector successors; - - /* \brief The blocks that occur before this block */ - std::vector predecessors; - - /* \brief Construct a BufferTouch instance within this - * ControlFlowBlock - * - * \param graph The mutable ControlFlowGraph that owns the buffer - * touch. Any free parameters used in the BufferTouch's predicate - * will be tracked by the ControlFlowGraph. - * - * \param buf The Buffer being accessed - * - * \param indices The indices at which the buffer is accessed, in - * terms of the loop variables. - * - * \param touch_type The type of touch being generated - * - * \param known_expr_value The value being written to the buffer - * - * \returns The newly generated BufferTouch - */ - BufferTouch MakeBufferTouch(ControlFlowGraph* graph, const Buffer& buf, - const ffi::Array& indices, - BufferTouch::AccessType touch_type, - PrimExpr known_value_expr) const; - - /* \brief Construct a BufferTouch instance as if it occurred in - * this ControlFlowBlock - * - * Used when speculative checking if a BufferStore could be - * inserted. - * - * \param buf The Buffer being accessed - * - * \param index_variables The variables representing location - * within a buffer, with one variable for each axis of the buffer. - * - * \param indices The indices at which the buffer is accessed, in - * terms of the loop variables. - * - * \param touch_type The type of touch being generated - * - * \param known_expr_value The value being written to the buffer - * - * \returns The newly generated BufferTouch, and a map specifying - * all free parameters that may occur in the BufferTouch's - * predicate. - */ - std::pair> MakeBufferTouch(const Buffer& buf, - ffi::Array index_variables, - ffi::Array indices, - BufferTouch::AccessType touch_type, - PrimExpr known_value_expr) const; - }; - friend std::ostream& operator<<(std::ostream& os, const ControlFlowBlock& pattern); - - /* \brief The control flow that occurs within the analyzed statement */ - std::vector control_flow_; - - /* \brief A lookup into control_flow_ - * - * A map to look up the control flow block that contains the - * statement. - */ - std::unordered_map control_flow_lookup_; - - /*! \brief A map from free parameters to their range - * - * A BufferStore/BufferLoad has indices in terms of loop iterators, - * while the internal BufferTouch must have predicate in terms of - * the buffer's axes. While converting to the internal BufferTouch, - * reduction axes show up as free parameters. Tracking the range of - * the free parameters allows them to be removed later, by requiring - * a predicate to be true for all values of the free parameters. - */ - ffi::Map free_predicate_parameters_; - - /*! \brief Ranges of iterators found in the analyzed statement */ - ffi::Map iterator_ranges_; - - /* \brief A map from buffer to the variables representing positions - * along the buffer's axes. - * - * This is stored here, rather than as part of the BufferState or - * BufferTouch, to ensure that all access of a buffer use the same - * variables to represent the buffer's axes, reducing the amount of - * variable substitution required. - */ - ffi::Map> axis_var_lookup_; - - /* \brief Assumptions that do not depend on buffer values - * - * These may be collected as part of the handling of `builtin::assume()`, and do not depend on any - * buffer. Since TIR only allows mutable values as part of buffers, these assumptions may be used - * anywhere the - */ - std::vector non_buffer_assumptions_; - - friend class ControlFlowGraphBuilder; - - /*! \brief The maximum number of revisits while flowing constraints */ - size_t max_revisits_; - - /*! \brief The maximum number of revisits while flowing constraints */ - int64_t max_simplification_steps_; -}; - -} // namespace tirx -} // namespace tvm -#endif // TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_ diff --git a/src/tirx/transform/remove_no_op.cc b/src/tirx/transform/remove_no_op.cc index 4bdb5c083c01..aa2280215471 100644 --- a/src/tirx/transform/remove_no_op.cc +++ b/src/tirx/transform/remove_no_op.cc @@ -32,12 +32,10 @@ #include #include -#include #include #include "../../arith/const_fold.h" #include "../../arith/ir_mutator_with_analyzer.h" -#include "../analysis/control_flow_graph.h" #include "../analysis/var_use_def_analysis.h" #include "ir_utils.h" @@ -45,17 +43,12 @@ namespace tvm { namespace tirx { struct RemoveNoOpConfigNode : public ffi::Object { - bool use_dataflow_analysis; int64_t max_simplification_steps; bool ignore_profiler_call; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("use_dataflow_analysis", &RemoveNoOpConfigNode::use_dataflow_analysis, - "If true, known buffer values are propagated and used " - "to statically prove statements as no-ops.", - refl::DefaultValue(false)) .def_ro("max_simplification_steps", &RemoveNoOpConfigNode::max_simplification_steps, "If non-zero, RewriteSimplifier will throw an error " "after the number of steps specified. " @@ -81,10 +74,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tirx.RemoveNoOp", RemoveNoOpConfig); // Mark the statement of each stage. class NoOpRemover : public arith::IRMutatorWithAnalyzer { public: - static Stmt Apply(Stmt stmt, arith::Analyzer* analyzer, - std::optional touch_pattern, const StmtNode* context, - bool ignore_profiler_call = false) { - NoOpRemover visitor(analyzer, touch_pattern, context, ignore_profiler_call); + static Stmt Apply(Stmt stmt, arith::Analyzer* analyzer, bool ignore_profiler_call = false) { + NoOpRemover visitor(analyzer, ignore_profiler_call); return visitor(std::move(stmt)); } @@ -93,12 +84,8 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { using Parent::VisitStmt; using Parent::VisitStmt_; - NoOpRemover(arith::Analyzer* analyzer, std::optional touch_pattern, - const StmtNode* context, bool ignore_profiler_call = false) - : Parent(analyzer), - touch_pattern_(touch_pattern), - context_(context), - ignore_profiler_call_(ignore_profiler_call) {} + NoOpRemover(arith::Analyzer* analyzer, bool ignore_profiler_call = false) + : Parent(analyzer), ignore_profiler_call_(ignore_profiler_call) {} Stmt VisitStmt_(const BindNode* op) final { // Simply mutate the value and return. @@ -195,27 +182,11 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { return this->VisitStmt(SeqStmt(statements)); }; - if (touch_pattern_.has_value()) { - // A write that is later overwritten is a no-op. - Stmt context = context_ ? ffi::GetRef(context_) : store; - if (touch_pattern_->IsOverwrittenWithoutEffect(store, context)) { - touch_pattern_->RemoveStore(store); - return only_side_effects(); - } - } - // A write whose destination is known to already contain the // values to be written is a no-op. - // PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices); PrimExpr stores_existing_value = store->value - BufferLoad(store->buffer, store->indices, store->predicate) == 0; - if (touch_pattern_.has_value()) { - Stmt context_arg = context_ ? ffi::GetRef(context_) : Stmt(store); - stores_existing_value = - touch_pattern_->SimplifyInContext(stores_existing_value, context_arg, analyzer_); - } else { - stores_existing_value = analyzer_->Simplify(stores_existing_value); - } + stores_existing_value = analyzer_->Simplify(stores_existing_value); if (is_one(stores_existing_value)) { return only_side_effects(); } @@ -289,30 +260,20 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { } std::unordered_map var_range_map_; - std::optional touch_pattern_; - const StmtNode* context_; bool ignore_profiler_call_{false}; }; -Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer, std::optional touch_pattern, - const StmtNode* context, bool ignore_profiler_call = false) { - return NoOpRemover::Apply(std::move(stmt), analyzer, std::move(touch_pattern), context, - ignore_profiler_call); +Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer, bool ignore_profiler_call) { + return NoOpRemover::Apply(std::move(stmt), analyzer, ignore_profiler_call); } namespace transform { Pass RemoveNoOp() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - std::optional touch_pattern = std::nullopt; - RemoveNoOpConfig config = ctx->GetConfig("tirx.RemoveNoOp") .value_or(AttrsWithDefaultValues()); - if (config->use_dataflow_analysis) { - touch_pattern.emplace(f->body, config->max_simplification_steps); - } - arith::Analyzer analyzer; analyzer.rewrite_simplify.SetMaximumRewriteSteps(config->max_simplification_steps); @@ -320,8 +281,8 @@ Pass RemoveNoOp() { { auto* write_ptr = f.CopyOnWrite(); - write_ptr->body = NoOpRemover::Apply(std::move(write_ptr->body), &analyzer, - std::move(touch_pattern), nullptr, ignore_profiler_call); + write_ptr->body = + NoOpRemover::Apply(std::move(write_ptr->body), &analyzer, ignore_profiler_call); } return f; }; diff --git a/src/tirx/transform/remove_no_op.h b/src/tirx/transform/remove_no_op.h index 8bb4dee1f32e..21d1f917d50b 100644 --- a/src/tirx/transform/remove_no_op.h +++ b/src/tirx/transform/remove_no_op.h @@ -27,10 +27,6 @@ #include #include -#include - -#include "../analysis/control_flow_graph.h" - namespace tvm { namespace tirx { @@ -43,17 +39,9 @@ namespace tirx { * * \param analyzer The analyzer to use while proving no-ops * - * \param control_flow The analyzed control-flow graph, which contains - * the `stmt` to be analyzed. If provided, known buffer values will - * be used to remove no-ops. (e.g. Removing `buf[i] = 0` in cases - * where `buf[i]` is known to already contain zero.) If nullptr, - * known buffer values will not be used. - * * \return The modified statement with no-ops removed */ -Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer, - std::optional touch_pattern = std::nullopt, - const StmtNode* context = nullptr, bool ignore_profiler_call = false); +Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer, bool ignore_profiler_call = false); } // namespace tirx } // namespace tvm diff --git a/src/tirx/transform/simplify.cc b/src/tirx/transform/simplify.cc index bf80ad00a455..18e398d095a4 100644 --- a/src/tirx/transform/simplify.cc +++ b/src/tirx/transform/simplify.cc @@ -34,10 +34,7 @@ #include #include -#include - #include "../../arith/ir_mutator_with_analyzer.h" -#include "../../tirx/analysis/control_flow_graph.h" namespace tvm { namespace arith { @@ -46,8 +43,6 @@ using namespace tirx; struct SimplifyConfigNode : public ffi::Object { bool transitively_prove_inequalities; - bool propagate_knowns_to_prove_conditional; - bool propagate_knowns_to_simplify_expressions; bool convert_boolean_to_and_of_ors; bool apply_constraints_to_boolean_branches; @@ -58,17 +53,6 @@ struct SimplifyConfigNode : public ffi::Object { &SimplifyConfigNode::transitively_prove_inequalities, "If true, simplify conditionals with transitive combinations of scoped constraints", refl::DefaultValue(false)) - .def_ro( - "propagate_knowns_to_prove_conditional", - &SimplifyConfigNode::propagate_knowns_to_prove_conditional, - "If true, known buffer values are propagated and used to statically prove conditionals", - refl::DefaultValue(false)) - .def_ro( - "propagate_knowns_to_simplify_expressions", - &SimplifyConfigNode::propagate_knowns_to_simplify_expressions, - "If true, known buffer values are propagated and used to replace BufferLoad wherever " - "possible", - refl::DefaultValue(false)) .def_ro("convert_boolean_to_and_of_ors", &SimplifyConfigNode::convert_boolean_to_and_of_ors, "If true, simplify conditionals into an AND of ORs", refl::DefaultValue(false)) .def_ro("apply_constraints_to_boolean_branches", @@ -117,22 +101,15 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { auto config = config_opt.value_or(MakeDefaultSimplifyConfig()); analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions()); - std::optional touch_pattern = std::nullopt; - if (config->propagate_knowns_to_prove_conditional || - config->propagate_knowns_to_simplify_expressions) { - touch_pattern = ControlFlowGraph(func->body); - } - - StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern)); + StmtSimplifier simplifier(analyzer, config); simplifier.MarkBufferMapShapes(func); func.CopyOnWrite()->body = simplifier(func->body); return func; } private: - explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config, - std::optional touch_pattern) - : IRMutatorWithAnalyzer(analyzer), config_(config), touch_pattern_(touch_pattern) {} + explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config) + : IRMutatorWithAnalyzer(analyzer), config_(config) {} using Parent = IRMutatorWithAnalyzer; using Parent::VisitExpr_; @@ -152,24 +129,10 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { // to prevent inlining LetStmt vars that appear in buffer definitions. Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data) override { return buffer; } - PrimExpr VisitExpr(const PrimExpr& expr) final { - if (config_->propagate_knowns_to_simplify_expressions) { - return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(), analyzer_); - } else { - return analyzer_->Simplify(expr); - } - } + PrimExpr VisitExpr(const PrimExpr& expr) final { return analyzer_->Simplify(expr); } Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); } - Stmt VisitStmt(const Stmt& stmt) override { - ffi::Optional cache = this->current_stmt_; - this->current_stmt_ = stmt; - Stmt output = Parent::VisitStmt(stmt); - this->current_stmt_ = std::move(cache); - return output; - } - Stmt VisitStmt_(const ForNode* op) final { analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); With ctx1(analyzer_, op->loop_var >= op->min); @@ -262,17 +225,11 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { /* \brief Internal utility for checking conditionals * - * Uses more aggressive optimization, such as performing additional - * inlining and tracking known buffer values. + * Substitutes any known Bind values and then simplifies with the analyzer. */ ffi::Optional ProveCondition(PrimExpr condition) const { condition = Substitute(condition, non_inlined_bindings_); - if (config_->propagate_knowns_to_prove_conditional) { - TVM_FFI_ICHECK(touch_pattern_.has_value()); - condition = touch_pattern_->SimplifyInContext(condition, current_stmt_.value(), analyzer_); - } else { - condition = analyzer_->Simplify(condition); - } + condition = analyzer_->Simplify(condition); if (const int64_t* as_int = as_const_int(condition)) { return Bool(*as_int); } else { @@ -281,12 +238,10 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } SimplifyConfig config_; - std::optional touch_pattern_; // Pure Bind values kept for substitution into assert conditions. // Grows monotonically under SSA — no scope-based cleanup required. ffi::Map non_inlined_bindings_; - ffi::Optional current_stmt_{std::nullopt}; }; } // namespace arith diff --git a/tests/python/arith/test_arith_narrow_predicate_expression.py b/tests/python/arith/test_arith_narrow_predicate_expression.py deleted file mode 100644 index ea54d87dab92..000000000000 --- a/tests/python/arith/test_arith_narrow_predicate_expression.py +++ /dev/null @@ -1,87 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: F401 - -import tvm -import tvm.testing -from tvm import tirx -from tvm.runtime import convert -from tvm.script import tirx as T - -i = tirx.Var("i", "int32") -j = tirx.Var("j", "int32") -n = tirx.Var("n", "int32") -m = tirx.Var("m", "int32") -b = tirx.Var("b", "bool") -buf = tirx.decl_buffer(16, "int32", "buf") - -tir_false = tirx.IntImm("bool", False) -tir_true = tirx.IntImm("bool", True) - -before, expected = tvm.testing.parameters( - # General arithmatic - [tir_true, tir_true], - [tir_false, tir_false], - [b, b], - [i > 5, i > 5], - [i > n, i > 7], - [i < n, i < 0], - [i <= n, i <= 0], - [i >= n, i >= 7], - [n > i, T.int32(0) > i], - [n < i, T.int32(7) < i], - [n <= i, T.int32(7) <= i], - [n >= i, T.int32(0) >= i], - [i == n, tirx.all(i <= 0, T.int32(7) <= i)], - [n == i, tirx.all(T.int32(7) <= i, i <= 0)], - [i != n, tirx.any(i < 0, T.int32(7) < i)], - [n != i, tirx.any(T.int32(7) < i, i < 0)], - [i // 4 > n, i // 4 > 7], - [n < i // 4, T.int32(7) < i // 4], - [(i + n) // 4 > 0, tirx.Add(i, 0) // 4 > 0], - [(i + n) // 4 == 0, tirx.all(tirx.Add(i, 7) // 4 <= 0, T.int32(0) <= tirx.Add(i, 0) // 4)], - [i + n < 10, i + 7 < 10], - [i - n < 10, tirx.Sub(i, 0) < 10], - [tirx.Not(i < n), tirx.Not(i < 7)], - # Use of FloorMod should make the narrowing strategy bail out, as - # it is non-monotonic. - [i % 8 == n, tir_false], - # Ensure that dividing by a free parameter doesn't generate a - # divide-by-zero to be triggered later. - [i // n == 0, tir_false], - ### Buffer handling - [buf.vload(0) > 0, tir_false], - [buf.vload(0) > i, tir_false], - [buf.vload(i) > 0, tir_false], - [tirx.And(buf.vload(i) > 0, i <= 0), tirx.And(tir_false, i <= 0)], - [tirx.Or(buf.vload(i) > 0, i <= n), tirx.Or(tir_false, i <= 0)], - [tirx.Or(tirx.Not(buf.vload(i) > 0), i <= n), tirx.Or(tir_false, i <= 0)], -) - - -def test_narrow_expression(before, expected): - ranges = {n: tvm.ir.Range(0, 8)} - after = tvm.arith._ffi_api.NarrowPredicateExpression(before, ranges) - - if expected is None: - assert after is None - else: - tvm.ir.assert_structural_equal(after, expected) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/tirx-transform/test_tir_transform_remove_no_op.py b/tests/python/tirx-transform/test_tir_transform_remove_no_op.py index 35137ac4cf50..a4e732a77a01 100644 --- a/tests/python/tirx-transform/test_tir_transform_remove_no_op.py +++ b/tests/python/tirx-transform/test_tir_transform_remove_no_op.py @@ -86,11 +86,10 @@ def main(A: T.Buffer((16), "int32"), B: T.Buffer((16), "int32")) -> None: assert isinstance(ret, tvm.tirx.Evaluate) -def _apply_remove_no_op(mod, use_dataflow_analysis=False, max_simplification_steps=0): +def _apply_remove_no_op(mod, max_simplification_steps=0): """Helper function to apply RemoveNoOp transform with config.""" config = { "tirx.RemoveNoOp": { - "use_dataflow_analysis": use_dataflow_analysis, "max_simplification_steps": max_simplification_steps, } } @@ -242,30 +241,8 @@ def expected(A: T.Buffer(16, "int32")): tvm.ir.assert_structural_equal(mod["main"], expected) -def test_remove_unused_write(): - """For two sequential writes, the first is a no-op""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - A[i] = 100 - A[i] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32")): - for i in T.serial(16): - A[i] = 42 - - mod = tvm.IRModule.from_expr(before) - mod = _apply_remove_no_op(mod, use_dataflow_analysis=True) - tvm.ir.assert_structural_equal(mod["main"], expected) - - def test_suppress_removal_of_unused_write(): - """Dataflow analysis requires the config to opt-in - - Like test_remove_unused_write, but dataflow analysis isn't enabled. - """ + """Sequential writes to the same location are not removed without dataflow analysis.""" @T.prim_func(private=True, s_tir=True) def before(A: T.Buffer(16, "int32")): @@ -274,30 +251,10 @@ def before(A: T.Buffer(16, "int32")): A[i] = 42 mod = tvm.IRModule.from_expr(before) - mod = _apply_remove_no_op(mod, use_dataflow_analysis=False) + mod = _apply_remove_no_op(mod) tvm.ir.assert_structural_equal(mod["main"], before) -def test_keep_side_effects_of_unused_write(): - """For two sequential writes, the first value may have side effects""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - A[i] = T.call_extern("extern_func", dtype="int32") - A[i] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32")): - for i in T.serial(16): - T.evaluate(T.call_extern("extern_func", dtype="int32")) - A[i] = 42 - - mod = tvm.IRModule.from_expr(before) - mod = _apply_remove_no_op(mod, use_dataflow_analysis=True) - tvm.ir.assert_structural_equal(mod["main"], expected) - - def test_keep_first_write_when_used(): """For two sequential writes, keep the first if it is used""" @@ -312,56 +269,6 @@ def before(A: T.Buffer(16, "int32")): tvm.ir.assert_structural_equal(mod["main"], before) -def test_remove_overwritten_loop(): - """Remove repeated writes to the same region - - If two loops write to the same region, the first is a no-op. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - A[i] = 100 - - for i in T.serial(16): - A[i] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32")): - for i in T.serial(16): - A[i] = 42 - - mod = tvm.IRModule.from_expr(before) - mod = _apply_remove_no_op(mod, use_dataflow_analysis=True) - tvm.ir.assert_structural_equal(mod["main"], expected) - - -def test_remove_overwritten_subloop(): - """Remove repeated writes to the same region - - If the first loop writes to a subset of the region, the first loop - is a no-op. Similar to test_remove_overwritten_loop, but the first - loop's extents are a subset of the second loop. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(4, 12): - A[i] = 100 - - for i in T.serial(16): - A[i] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32")): - for i in T.serial(16): - A[i] = 42 - - mod = tvm.IRModule.from_expr(before) - mod = _apply_remove_no_op(mod, use_dataflow_analysis=True) - tvm.ir.assert_structural_equal(mod["main"], expected) - - def test_keep_partially_overwritten_loop(): """Keep partially overwritten regions @@ -383,148 +290,6 @@ def before(A: T.Buffer(16, "int32")): tvm.ir.assert_structural_equal(mod["main"], before) -def test_remove_overwritten_predicated_loop_with_identical_condition(): - """Remove repeated writes to the same predicated region. - - Similar to test_keep_partially_overwritten_loop, except the first loop - has the same predicate as the second, and can therefore be - removed. - - In the past, this test has had performance regressions in which - the runtime increased from a few seconds to nearly ten minutes. - The "max_simplification_steps" parameter is set at twice the - current number of steps required, in order to prevent similar - performance regression. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - if i < 12: - A[i] = 100 - - for i in T.serial(16): - if i < 12: - A[i] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32")): - for i in T.serial(16): - if i < 12: - A[i] = 42 - - mod = tvm.IRModule.from_expr(before) - mod = _apply_remove_no_op(mod, use_dataflow_analysis=True, max_simplification_steps=200000) - tvm.ir.assert_structural_equal(mod["main"], expected) - - -def test_remove_overwritten_predicated_loop_with_provable_condition(): - """Remove repeated writes to the same predicated region. - - Similar to - test_remove_overwritten_predicated_loop_with_identical_condition, except - the first loop's predicate is not a precise match for the second - loop's predicate. So long as the regions written in the first - loop are a subset of those written in the second loop, they can be - removed. - - In the past, this test has had performance regressions in which - the runtime increased from a few seconds to nearly ten minutes. - The "max_simplification_steps" parameter is set at twice the - current number of steps required, in order to prevent similar - performance regression. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - if i < 10: - A[i] = 100 - - for i in T.serial(16): - if i // 4 < 3: - A[i] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32")): - for i in T.serial(16): - if i // 4 < 3: - A[i] = 42 - - mod = tvm.IRModule.from_expr(before) - mod = _apply_remove_no_op(mod, use_dataflow_analysis=True, max_simplification_steps=200000) - tvm.ir.assert_structural_equal(mod["main"], expected) - - -def test_remove_separated_overwrites(): - """Remove repeated writes to the same predicated region. - - Similar to test_remove_overwritten_loop, but with an - independent loop between the first and second write of the buffer. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): - for i in T.serial(16): - A[i] = 100 - - for i in T.serial(16): - B[i] = 0 - - for i in T.serial(16): - A[i] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): - for i in T.serial(16): - B[i] = 0 - - for i in T.serial(16): - A[i] = 42 - - mod = tvm.IRModule.from_expr(before) - mod = _apply_remove_no_op(mod, use_dataflow_analysis=True) - tvm.ir.assert_structural_equal(mod["main"], expected) - - -@pytest.mark.xfail(reason="Not implemented yet") -def test_remove_separated_overwrite_of_predicated_loop(): - """Remove repeated writes to the same predicated region. - - Similar to test_remove_separated_overwrites, but the independent loop - between the first and second writes to a different subset - of the same buffer. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - if i < 12: - A[i] = 100 - - for i in T.serial(16): - if i > 12: - A[i] = 15 - - for i in T.serial(16): - if i < 12: - A[i] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32")): - for i in T.serial(16): - if i > 12: - A[i] = 15 - - for i in T.serial(16): - if i < 12: - A[i] = 42 - - mod = tvm.IRModule.from_expr(before) - mod = _apply_remove_no_op(mod, use_dataflow_analysis=True) - tvm.ir.assert_structural_equal(mod["main"], expected) - - def test_remove_read_write(): """Writing a value to the same location as was just read is a no-op.""" @@ -607,54 +372,6 @@ def expected(A: T.Buffer(16, "int32")): tvm.ir.assert_structural_equal(mod["main"], expected) -def test_remove_writing_of_known_value(): - """Writing a value that already exists at that index is a no-op""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - A[i] = i - - A[4] = 4 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32")): - for i in T.serial(16): - A[i] = i - - mod = tvm.IRModule.from_expr(before) - mod = _apply_remove_no_op(mod, use_dataflow_analysis=True) - tvm.ir.assert_structural_equal(mod["main"], expected) - - -def test_keep_one_of_duplicate_loops(): - """Must not reason based on a touch point after removing it. - - If the first loop is removed because it is overwritten by the - second loop, and the second loop is removed because it writes the - same value as the first loop, the overall transformation is no - longer valid. In this case, only one of the two should be - removed. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - A[i] = i - - for i in T.serial(16): - A[i] = i - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32")): - for i in T.serial(16): - A[i] = i - - mod = tvm.IRModule.from_expr(before) - mod = _apply_remove_no_op(mod, use_dataflow_analysis=True) - tvm.ir.assert_structural_equal(mod["main"], expected) - - @pytest.mark.xfail(reason="Dead alloc removal not yet implemented for flat AllocBuffer") def test_remove_empty_temporary(): """An allocation with a no-op body is a no-op.""" diff --git a/tests/python/tirx-transform/test_tir_transform_simplify.py b/tests/python/tirx-transform/test_tir_transform_simplify.py index 8340900fd815..dd3d9a01f87e 100644 --- a/tests/python/tirx-transform/test_tir_transform_simplify.py +++ b/tests/python/tirx-transform/test_tir_transform_simplify.py @@ -101,8 +101,6 @@ def _apply_simplify( transitively_prove_inequalities=False, convert_boolean_to_and_of_ors=False, apply_constraints_to_boolean_branches=False, - propagate_knowns_to_prove_conditional=False, - propagate_knowns_to_simplify_expressions=False, ): """Helper to apply simplify transform with config options.""" config = { @@ -110,8 +108,6 @@ def _apply_simplify( "transitively_prove_inequalities": transitively_prove_inequalities, "convert_boolean_to_and_of_ors": convert_boolean_to_and_of_ors, "apply_constraints_to_boolean_branches": apply_constraints_to_boolean_branches, - "propagate_knowns_to_prove_conditional": propagate_knowns_to_prove_conditional, - "propagate_knowns_to_simplify_expressions": propagate_knowns_to_simplify_expressions, } } mod = tvm.IRModule.from_expr(func) @@ -1149,684 +1145,6 @@ def expected_func(A: T.Buffer(1, "bool")): tvm.ir.assert_structural_equal(after, expected_func) -def test_altered_buffer_contents_with_propagation(): - """Propagation of data-dependent conditionals. - - A literal constraint must not be propagated if the values - referenced may change. TIR requires single assignment of - variables, so Var objects may be assumed constant, but BufferLoad - may not. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer((1,), "int32"), n: T.int32): - if A[0] == n: - A[0] = A[0] + 1 - # If the simplifier incorrectly uses the invalidated - # A[0]==n condition required to reach this point, then it - # will incorrectly simplify to the then-case. If the - # simplifier correctly determines that A[0] now contains - # n+1, then it will correctly simplify to the else-case. - if A[0] == n: - A[0] = 5 - else: - A[0] = 10 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer((1,), "int32"), n: T.int32): - if A[0] == n: - A[0] = A[0] + 1 - A[0] = 10 - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_possibly_altered_buffer_contents(): - """No simplification of data-dependent conditionals. - - Like test_altered_buffer_contents_with_propagation, but the `m==0` conditional - prevents the value of `A[0]` from being known at the point of the - inner conditional, either as `A[0] == n` from the outer - conditional or as `A[0] == n+1` from the write statement. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer((1,), "int32"), n: T.int32, m: T.int32): - if A[0] == n: - if m == 0: - A[0] = A[0] + 1 - - if A[0] == n: - A[0] = 5 - else: - A[0] = 10 - - expected = before - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_input_assumption(): - """A T.assume annotation may be used to simplify""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(1, "int32"), n: T.int32): - T.evaluate(T.assume(n == 0)) - if n == 0: - A[0] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(1, "int32"), n: T.int32): - T.evaluate(T.assume(n == 0)) - A[0] = 42 - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_no_simplify_from_scoped_input_assumption(): - """A T.assume inside a scope may not apply outside that scope""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(1, "int32"), n: T.int32, m: T.int32): - if m == 0: - T.evaluate(T.assume(n == 0)) - - if n == 0: - A[0] = 42 - - expected = before - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_conditional_using_buffer_value(): - """Simplify a conditional using the known value in the buffer""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(1, "int32")): - A[0] = 0 - - if A[0] == 0: - A[0] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(1, "int32")): - A[0] = 0 - A[0] = 42 - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_keep_expression_simplify_using_buffer_value(): - """Do not simplify expressions in general using known values in the buffer - - For now, because this is equivalent to inlining, preventing this - usage from occurring. Known buffer values may be used to prove - conditionals, but should not be used for other simplifications. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(1, "int32"), B: T.Buffer(1, "int32")): - A[0] = 0 - B[0] = A[0] - - expected = before - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_conditional_in_loop_using_buffer_value(): - """Simplify a conditional using the known value in the buffer - - Like test_simplify_conditional_using_buffer_value, but the value used - to simplify is set in a previous loop. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): - for i in T.serial(16): - A[i] = i - - for j in T.serial(16): - if A[j] == j: - B[j] = 42 - else: - B[j] = 100 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): - for i in T.serial(16): - A[i] = i - - for j in T.serial(16): - B[j] = 42 - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_using_buffer_assumption(): - """A T.assume may apply to a buffer's contents""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(1, "int32")): - T.evaluate(T.assume(A[0] == 0)) - - if A[0] == 0: - A[0] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(1, "int32")): - T.evaluate(T.assume(A[0] == 0)) - A[0] = 42 - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_using_buffer_assumption_in_loop(): - """An assumption about buffer contents may apply to a range""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - T.evaluate(T.assume(A[i] == i)) - - for i in T.serial(16): - if A[i] < 100: - A[i] = 0 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32")): - for i in T.serial(16): - T.evaluate(T.assume(A[i] == i)) - - for i in T.serial(16): - A[i] = 0 - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_using_partially_known_buffer_conditional(): - """An assumption about buffer contents may apply to only part of a buffer""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - if 14 <= i: - T.evaluate(T.assume(A[i] == 0)) - - for i in T.serial(16): - if 14 <= i: - if A[i] == 0: - A[i] = 42 - - else: - if A[i] == 0: - A[i] = 100 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32")): - for i in T.serial(16): - if 14 <= i: - T.evaluate(T.assume(A[i] == 0)) - - for i in T.serial(16): - if 14 <= i: - A[i] = 42 - - else: - if A[i] == 0: - A[i] = 100 - - after = _apply_simplify( - before, - propagate_knowns_to_prove_conditional=True, - apply_constraints_to_boolean_branches=True, - ) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_using_partially_known_buffer_expression(): - """An assumption about buffer contents may apply to only part of a buffer - - Like test_simplify_using_partially_known_buffer_conditional, but the - conditional is expressed as part of T.assume, instead of in the - control flow. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - T.evaluate(T.assume(i < 14 or A[i] == 0)) - - for i in T.serial(16): - if 14 <= i: - if A[i] == 0: - A[i] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32")): - for i in T.serial(16): - T.evaluate(T.assume(i < 14 or A[i] == 0)) - - for i in T.serial(16): - if 14 <= i: - A[i] = 42 - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_no_simplification_if_predicate_not_met(): - """Assumptions about buffer contents must apply to all cases to be used - - Like test_simplify_using_partial_buffer_assumption_in_loop, but the - predicate in the second loop does not match the predicate in the - first loop. Therefore, the `T.assume` refers to a different set - of indices. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - if 14 <= i: - T.evaluate(T.assume(A[i] == 0)) - - for i in T.serial(16): - if i < 14: - if A[i] == 0: - A[i] = 42 - - expected = before - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_no_simplify_using_invalidated_scoped_constraint(): - """A write may not be used for proofs outside its conditional""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - if i == 0: - A[i] = 0 - - if A[i] == 0: - A[i] = 42 - - expected = before - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_no_simplify_using_overwritten_value(): - """A write that may have been overwritten may not be treated as known - - The appearance of "A[i] = 5" must prevent the earlier constraint - from being used for simplification. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - T.evaluate(T.assume(A[i] == 0)) - - for i in T.serial(16): - if i == 0: - A[i] = 5 - - if A[i] == 0: - A[i] = 42 - - expected = before - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_no_simplify_using_loop_dependent_buffer_value(): - """Do not simplify assuming reads are invariant - - If a buffer's value changes across loop iterations, the buffer's - value before the loop should not be used to simplify conditionals - within the loop. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32"), B: T.Buffer(1, "int32")): - B[0] = 0 - for i in T.serial(16): - if B[0] < 10: - B[0] = A[i] * 2 + B[0] - else: - B[0] = A[i] + B[0] - - expected = before - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_prior_to_overwritten_value(): - """A known value may be used until it is overwritten - - Like test_no_simplify_using_overwritten_value, but the use of the - known `A[i]` value occurs before it is overwritten. - - Like test_no_simplify_using_loop_dependent_buffer_value, but the loop - iterations are all independent. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32")): - for i in T.serial(16): - T.evaluate(T.assume(A[i] == 0)) - - for i in T.serial(16): - if A[i] == 0: - A[i] = 17 - - if i == 0: - A[i] = 5 - - if A[i] == 0: - A[i] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32")): - for i in T.serial(16): - T.evaluate(T.assume(A[i] == 0)) - - for i in T.serial(16): - A[i] = 17 - - if i == 0: - A[i] = 5 - - if A[i] == 0: - A[i] = 42 - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_element_wise_using_pre_loop_buffer_value(): - """Allow data-Do not simplify assuming reads are invariant - - If an element-wise loop reads and overwrites a buffer value, the - pre-loop buffer value may be used to simplify conditions that - occur prior to the write. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): - for i in T.serial(16): - B[i] = 0 - - for i in T.serial(16): - if B[i] < 10: - B[i] = A[i] * 2 + B[i] - else: - B[i] = A[i] + B[i] - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): - for i in T.serial(16): - B[i] = 0 - - for i in T.serial(16): - B[i] = A[i] * 2 + B[i] - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_non_conditional(): - """Propagate a known value to later expressions.""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(1, "int32")): - A[0] = 0 - A[0] = A[0] + 1 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(1, "int32")): - A[0] = 0 - A[0] = 1 - - after = _apply_simplify(before, propagate_knowns_to_simplify_expressions=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_suppress_simplify_non_conditional(): - """Propagate a known value to later expressions. - - Like test_simplify_non_conditional, but with data-propagation turned off. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(1, "int32")): - A[0] = 0 - A[0] = A[0] + 1 - - expected = before - - after = _apply_simplify(before, propagate_knowns_to_simplify_expressions=False) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_using_transitive_known_buffer_value(): - """Propagate known buffer values - - If a known value of a buffer depends on another known value, it - can be tracked backwards through both. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(1, "int32")): - T.evaluate(T.assume(A[0] == 0)) - - A[0] = A[0] + 1 - A[0] = A[0] + 1 - A[0] = A[0] + 1 - - if A[0] == 3: - A[0] = 42 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(1, "int32")): - T.evaluate(T.assume(A[0] == 0)) - - A[0] = A[0] + 1 - A[0] = A[0] + 1 - A[0] = A[0] + 1 - - A[0] = 42 - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_ramp_index_broadcast_value(): - """Simplifications involving buffer loads with ramp indices""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(4, "int32")): - A[T.ramp(0, 1, 4)] = T.broadcast(0, 4) - - if A[0] == 0: - A[0] = 42 - - if A[1] == 0: - A[1] = 60 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(4, "int32")): - A[T.ramp(0, 1, 4)] = T.broadcast(0, 4) - - A[0] = 42 - A[1] = 60 - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_ramp_index_ramp_value(): - """Simplifications involving buffer loads with ramp indices""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(4, "int32")): - A[T.ramp(0, 1, 4)] = T.ramp(11, 1, 4) - - if A[0] == 11: - A[0] = 42 - - if A[1] == 12: - A[1] = 60 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(4, "int32")): - A[T.ramp(0, 1, 4)] = T.ramp(11, 1, 4) - - A[0] = 42 - A[1] = 60 - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_using_partially_proven_buffer_value_gather(): - """Propagate known buffer values in part of buffer. - - Even if a constraint can't be solved for all values in an - assignment, it may be provable in part of a buffer. Here, the - known 0 values in the padding of A produces known 0 values in the - padding of B. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(24, "int32"), B: T.Buffer(24, "int32"), F: T.Buffer(3, "int32")): - # A has non-zero values only in the range 3 <= i < 17 - for i in T.serial(24): - T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0)) - - # After convoluting with F, B has non-zero values only in the - # range 3 <= i < 19. - for i in T.serial(24): - B[i] = 0 - for f in T.serial(3): - if 0 <= i - f: - B[i] = B[i] + A[i - f] * F[f] - - # Which means that this loop is unnecessary. It would be - # removed entirely in tirx.transform.RemoveNoOp, but here we - # want to test that the simplification works as intended. - for i in T.serial(24): - if i < 3 or 19 <= i: - if B[i] != 0: - B[i] = 0 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(24, "int32"), B: T.Buffer(24, "int32"), F: T.Buffer(3, "int32")): - for i in T.serial(24): - T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0)) - - for i in T.serial(24): - B[i] = 0 - for f in T.serial(3): - if 0 <= i - f: - B[i] = B[i] + A[i - f] * F[f] - - for i in T.serial(24): - if i < 3 or 19 <= i: - T.evaluate(0) - - after = _apply_simplify( - before, transitively_prove_inequalities=True, propagate_knowns_to_prove_conditional=True - ) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_using_partially_proven_buffer_value_scatter(): - """Propagate known buffer values in part of buffer. - - Like test_simplify_using_partially_proven_buffer_value_gather, but the - compute loop is over the input buffer A, rather than the output - buffer B. - """ - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(24, "int32"), B: T.Buffer(24, "int32"), F: T.Buffer(3, "int32")): - # A has non-zero values only in the range 3 <= i < 17 - for i in T.serial(24): - T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0)) - - for i in T.serial(24): - B[i] = 0 - - # After convoluting with F, B has non-zero values only in the - # range 3 <= i < 19. - for i in T.serial(24): - for f in T.serial(3): - if i + f >= 0 and i + f < 24: - B[i + f] = B[i + f] + A[i] * F[f] - - # Which means that this loop is unnecessary. It actually gets - # removed in tirx.transform.RemoveNoOp, but here we want to - # test that the simplification works as intended. - for i in T.serial(24): - if i < 3 or 19 <= i: - if B[i] != 0: - B[i] = 0 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(24, "int32"), B: T.Buffer(24, "int32"), F: T.Buffer(3, "int32")): - for i in T.serial(24): - T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0)) - - for i in T.serial(24): - B[i] = 0 - - for i in T.serial(24): - for f in T.serial(3): - if i + f < 24: - B[i + f] = B[i + f] + A[i] * F[f] - - for i in T.serial(24): - if i < 3 or 19 <= i: - T.evaluate(0) - - after = _apply_simplify(before, propagate_knowns_to_prove_conditional=True) - tvm.ir.assert_structural_equal(after, expected) - - -def test_simplify_buffer_store(): - """Simplification using prior known""" - - @T.prim_func(private=True, s_tir=True) - def before(A: T.Buffer(1, "int32")): - A[0] = 5 - A[0] = A[0] + 7 - - @T.prim_func(private=True, s_tir=True) - def expected(A: T.Buffer(1, "int32")): - A[0] = 5 - A[0] = 12 - - after = _apply_simplify(before, propagate_knowns_to_simplify_expressions=True) - tvm.ir.assert_structural_equal(after, expected) - - def test_simplify_trivial_let_buffer_var(): """A Bind used in a buffer definition should be retained""" From e9ff4f5874f1af52bc0550e449bfd6442db5b150 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 25 May 2026 19:50:40 +0000 Subject: [PATCH 2/4] [REFACTOR][TIR] Rename Simplify to StmtSimplify and split file Rename clarifies that this is statement-level simplification (distinct from expression-level arithmetic simplifiers in src/arith/). The file names simplify.{h,cc} become stmt_simplify.{h,cc}, and all associated C++ types, free functions, pass functions, and FFI keys are renamed accordingly: tirx::Simplify -> tirx::StmtSimplify, tirx::transform::Simplify -> tirx::transform::StmtSimplify, SimplifyConfig -> StmtSimplifyConfig. Python wrappers and all call sites are updated to match. --- include/tvm/tirx/transform.h | 4 +- python/tvm/s_tir/backend/adreno/pipeline.py | 4 +- python/tvm/s_tir/pipeline.py | 4 +- python/tvm/testing/utils.py | 2 +- python/tvm/tirx/compilation_pipeline.py | 16 +++---- .../trn/compose_op/unary_reduce.py | 2 +- python/tvm/tirx/transform/transform.py | 12 ++--- .../analysis/calculate_allocated_memory.cc | 2 +- .../feature_extractor/per_store_feature.cc | 4 +- .../disallow_async_strided_mem_copy.cc | 2 +- .../meta_schedule/postproc/verify_gpu_code.cc | 4 +- .../schedule/primitive/blockize_tensorize.cc | 4 +- src/s_tir/transform/hoist_expression.cc | 6 +-- .../{simplify.cc => stmt_simplify.cc} | 48 ++++++++++--------- .../transform/{simplify.h => stmt_simplify.h} | 16 +++---- ...tproc_rewrite_parallel_vectorize_unroll.py | 4 +- ...t_s_tir_transform_compact_buffer_region.py | 2 +- ..._tir_transform_convert_blocks_to_opaque.py | 2 +- .../test_s_tir_transform_hoist_if.py | 2 +- ...st_s_tir_transform_inject_double_buffer.py | 6 +-- ..._tir_transform_inject_software_pipeline.py | 2 +- .../test_s_tir_transform_loop_partition.py | 22 ++++----- ...test_s_tir_transform_lower_match_buffer.py | 2 +- ...test_s_tir_transform_lower_opaque_block.py | 2 +- ...tir_transform_renormalize_split_pattern.py | 4 +- ...st_s_tir_transform_unify_thread_binding.py | 2 +- tests/python/te/test_te_create_primfunc.py | 2 +- .../python/tirx-base/test_tir_constructor.py | 2 +- .../test_tir_transform_flatten_buffer.py | 2 +- .../test_tir_transform_lower_intrin.py | 2 +- .../test_tir_transform_narrow_datatype.py | 2 +- .../test_tir_transform_simplify.py | 14 +++--- .../test_tir_transform_unroll_loop.py | 2 +- .../tile_primitive/trn/test_binary_trn.py | 2 +- .../tile_primitive/trn/test_compose_op_trn.py | 6 +-- .../tile_primitive/trn/test_copy_trn.py | 12 ++--- .../tile_primitive/trn/test_gemm_trn.py | 8 ++-- .../tile_primitive/trn/test_reduction_trn.py | 2 +- .../tile_primitive/trn/test_select_trn.py | 8 ++-- .../tile_primitive/trn/test_unary_trn.py | 2 +- 40 files changed, 124 insertions(+), 122 deletions(-) rename src/tirx/transform/{simplify.cc => stmt_simplify.cc} (84%) rename src/tirx/transform/{simplify.h => stmt_simplify.h} (70%) diff --git a/include/tvm/tirx/transform.h b/include/tvm/tirx/transform.h index 35d9779e79eb..186ebf3f5227 100644 --- a/include/tvm/tirx/transform.h +++ b/include/tvm/tirx/transform.h @@ -94,11 +94,11 @@ TVM_DLL Pass UnrollLoop(); TVM_DLL Pass RemoveNoOp(); /*! - * \brief Run arithmetic simplifications on the statements and expressions. + * \brief Run statement-level arithmetic simplifications on the TIR PrimFunc. * * \return The pass. */ -TVM_DLL Pass Simplify(); +TVM_DLL Pass StmtSimplify(); /*! * \brief Convert an IRModule to be SSA form. diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py b/python/tvm/s_tir/backend/adreno/pipeline.py index df6decb9949b..85359b1d35aa 100644 --- a/python/tvm/s_tir/backend/adreno/pipeline.py +++ b/python/tvm/s_tir/backend/adreno/pipeline.py @@ -44,7 +44,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I s_tir.transform.LowerAutoCopy(), s_tir.transform.UnifyThreadBinding(), s_tir.transform.LowerMatchBuffer(), - tirx.transform.Simplify(), + tirx.transform.StmtSimplify(), s_tir.transform.InjectPermutedLayout(), s_tir.transform.AnnotateIrregularLoop(), s_tir.transform.InjectSoftwarePipeline(), @@ -68,7 +68,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I s_tir.transform.HoistIfThenElse(), tirx.transform.UnrollLoop(), s_tir.transform.RenormalizeSplitPattern(), - tirx.transform.Simplify(), + tirx.transform.StmtSimplify(), tirx.transform.RemoveNoOp(), s_tir.transform.RewriteUnsafeSelect(), ] diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py index 9cb3995a8255..9f863ca4674f 100644 --- a/python/tvm/s_tir/pipeline.py +++ b/python/tvm/s_tir/pipeline.py @@ -45,7 +45,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I s_tir.transform.LowerAutoCopy(), s_tir.transform.UnifyThreadBinding(), s_tir.transform.LowerMatchBuffer(), - tirx.transform.Simplify(), + tirx.transform.StmtSimplify(), s_tir.transform.InjectPermutedLayout(), s_tir.transform.AnnotateIrregularLoop(), s_tir.transform.InjectSoftwarePipeline(), @@ -68,7 +68,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I s_tir.transform.HoistIfThenElse(), tirx.transform.UnrollLoop(), s_tir.transform.RenormalizeSplitPattern(), - tirx.transform.Simplify(), + tirx.transform.StmtSimplify(), tirx.transform.RemoveNoOp(), s_tir.transform.RewriteUnsafeSelect(), ] diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 3b78278de120..bdbf69396a1e 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -2017,7 +2017,7 @@ class object that inherits from `Exception`. .. code-block:: python class TestRemoveIf(tvm.testing.CompareBeforeAfter): - transform = tvm.tirx.transform.Simplify() + transform = tvm.tirx.transform.StmtSimplify() def before(A: T.Buffer(1, "int32")): if True: diff --git a/python/tvm/tirx/compilation_pipeline.py b/python/tvm/tirx/compilation_pipeline.py index 570f12da081b..30facc2663c6 100644 --- a/python/tvm/tirx/compilation_pipeline.py +++ b/python/tvm/tirx/compilation_pipeline.py @@ -33,13 +33,13 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I passes = [ tirx.transform.LowerInitBlock(), tvm.s_tir.transform.UnifyThreadBinding(), - tirx.transform.Simplify(), + tirx.transform.StmtSimplify(), tirx.transform.FlattenBuffer(), tirx.transform.BF16ComputeLegalize(), tirx.transform.NarrowDataType(32), tirx.transform.VectorizeLoop(not bool(config.get("tir.disable_vectorize", False))), tirx.transform.UnrollLoop(), - tirx.transform.Simplify(), + tirx.transform.StmtSimplify(), ] if not bool(config.get("tir.disable_cse_tir", False)): passes.append(tirx.transform.CommonSubexprElim()) @@ -73,14 +73,14 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I passes = [ tirx.transform.LowerTIRx(), tvm.s_tir.transform.UnifyThreadBinding(), - tirx.transform.Simplify(), + tirx.transform.StmtSimplify(), tirx.transform.LowerTIRxOpaque(), tirx.transform.FlattenBuffer(), tirx.transform.BF16ComputeLegalize(), tirx.transform.NarrowDataType(32), tirx.transform.VectorizeLoop(not bool(config.get("tir.disable_vectorize", False))), tirx.transform.UnrollLoop(), - tirx.transform.Simplify(), + tirx.transform.StmtSimplify(), ] if not bool(config.get("tir.disable_cse_tir", False)): passes.append(tirx.transform.CommonSubexprElim()) @@ -115,11 +115,11 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I tirx.transform.trn.TrnNaiveAllocator(), tirx.transform.LowerTIRx(), tvm.s_tir.transform.DecorateDeviceScope(), - tirx.transform.Simplify(), + tirx.transform.StmtSimplify(), tirx.transform.LowerTIRxOpaque(), tvm.s_tir.transform.LoopPartition(), tvm.s_tir.transform.HoistIfThenElse(), - tirx.transform.Simplify(), + tirx.transform.StmtSimplify(), tirx.transform.RemoveNoOp(), tirx.transform.AnnotateEntryFunc(), tirx.transform.AnnotateDeviceRegions(), @@ -146,7 +146,7 @@ def finalize_device_passes(): # pylint: disable=unused-argument """The default finalization passes for TIR backend.""" device_pass_list = [ tirx.transform.LowerWarpMemory(), - tirx.transform.Simplify(), + tirx.transform.StmtSimplify(), tirx.transform.LowerCustomDatatypes(), tirx.transform.LowerIntrin(), ] @@ -161,7 +161,7 @@ def finalize_device_passes_tirx(): # pylint: disable=unused-argument def finalize_device_passes_trn(): # pylint: disable=unused-argument """The default finalization passes for TRN backend.""" - device_pass_list = [tirx.transform.Simplify()] + device_pass_list = [tirx.transform.StmtSimplify()] return tvm.ir.transform.Sequential(device_pass_list) diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/unary_reduce.py b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/unary_reduce.py index 1677f4df1410..1fc801403842 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/unary_reduce.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/unary_reduce.py @@ -118,7 +118,7 @@ def impl(): import tvm mod = tvm.IRModule({"main": impl}) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) return mod["main"] else: # fmt: off diff --git a/python/tvm/tirx/transform/transform.py b/python/tvm/tirx/transform/transform.py index fbf07b5f4897..2c01863d32f3 100644 --- a/python/tvm/tirx/transform/transform.py +++ b/python/tvm/tirx/transform/transform.py @@ -210,20 +210,20 @@ def CommonSubexprElim(): return _ffi_api.CommonSubexprElim() # type: ignore -@_ffi.register_object("tirx.transform.SimplifyConfig") -class SimplifyConfig(_ffi.Object): - """Config for simplify pass""" +@_ffi.register_object("tirx.transform.StmtSimplifyConfig") +class StmtSimplifyConfig(_ffi.Object): + """Config for stmt simplify pass""" -def Simplify(): - """Run arithmetic simplifications on the statements and expressions. +def StmtSimplify(): + """Run statement-level arithmetic simplifications on the TIR PrimFunc. Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.Simplify() # type: ignore + return _ffi_api.StmtSimplify() # type: ignore def ConvertSSA(): diff --git a/src/s_tir/analysis/calculate_allocated_memory.cc b/src/s_tir/analysis/calculate_allocated_memory.cc index 5c67b8aaeb03..7b54cb4fe491 100644 --- a/src/s_tir/analysis/calculate_allocated_memory.cc +++ b/src/s_tir/analysis/calculate_allocated_memory.cc @@ -179,7 +179,7 @@ ffi::Array GetVTCMCompactionPasses() { pass_list.push_back(s_tir::transform::InjectSoftwarePipeline()); pass_list.push_back(s_tir::transform::LowerOpaqueBlock()); pass_list.push_back(tirx::transform::FlattenBuffer()); - pass_list.push_back(tirx::transform::Simplify()); + pass_list.push_back(tirx::transform::StmtSimplify()); pass_list.push_back(tirx::transform::VectorizeLoop(true)); pass_list.push_back(tirx::transform::StorageRewrite()); return pass_list; diff --git a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc index cba12d62ba1f..ad5a9e3b2fc9 100644 --- a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc @@ -319,11 +319,11 @@ tvm::transform::Sequential PassListForPerStoreFeature() { s_tir::transform::PlanAndUpdateBufferAllocationLocation(), s_tir::transform::ConvertBlocksToOpaque(), s_tir::transform::CompactBufferAllocation(), - tirx::transform::Simplify(), + tirx::transform::StmtSimplify(), s_tir::transform::LowerAutoCopy(), s_tir::transform::UnifyThreadBinding(), s_tir::transform::LowerMatchBuffer(), - tirx::transform::Simplify(), + tirx::transform::StmtSimplify(), }); } diff --git a/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index bf39aa54180c..6e1f195e75b3 100644 --- a/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -152,7 +152,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { pass_list.push_back(tirx::transform::FlattenBuffer()); pass_list.push_back(tirx::transform::BF16ComputeLegalize()); pass_list.push_back(tirx::transform::NarrowDataType(32)); - pass_list.push_back(tirx::transform::Simplify()); + pass_list.push_back(tirx::transform::StmtSimplify()); pass_list.push_back(s_tir::transform::InjectVirtualThread()); pass_list.push_back(s_tir::transform::InjectDoubleBuffer()); pass_list.push_back(tirx::transform::VectorizeLoop(true)); diff --git a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc index 99e0dbee6d81..0f55fcb70c66 100644 --- a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc @@ -166,7 +166,7 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(s_tir::transform::LiftThreadBinding()); pass_list.push_back(s_tir::transform::ManifestSharedMemoryLocalStage()); pass_list.push_back(s_tir::transform::CompactBufferAllocation()); - pass_list.push_back(tirx::transform::Simplify()); + pass_list.push_back(tirx::transform::StmtSimplify()); pass_list.push_back(s_tir::transform::LowerAutoCopy()); pass_list.push_back(s_tir::transform::UnifyThreadBinding()); pass_list.push_back(s_tir::transform::LowerMatchBuffer()); @@ -175,7 +175,7 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tirx::transform::FlattenBuffer()); pass_list.push_back(tirx::transform::BF16ComputeLegalize()); pass_list.push_back(tirx::transform::NarrowDataType(32)); - pass_list.push_back(tirx::transform::Simplify()); + pass_list.push_back(tirx::transform::StmtSimplify()); // Phase 2 pass_list.push_back(tirx::transform::VectorizeLoop(true)); pass_list.push_back(s_tir::transform::InjectVirtualThread()); diff --git a/src/s_tir/schedule/primitive/blockize_tensorize.cc b/src/s_tir/schedule/primitive/blockize_tensorize.cc index da4deb01bc87..a2f915b0bb86 100644 --- a/src/s_tir/schedule/primitive/blockize_tensorize.cc +++ b/src/s_tir/schedule/primitive/blockize_tensorize.cc @@ -23,7 +23,7 @@ #include #include "../../../tirx/ir/data_type_rewriter.h" -#include "../../../tirx/transform/simplify.h" +#include "../../../tirx/transform/stmt_simplify.h" #include "../ir_comparator.h" #include "../utils.h" @@ -768,7 +768,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int } arith::Analyzer analyzer; - PrimFunc intrin_desc = Simplify(intrin->desc, &analyzer); + PrimFunc intrin_desc = StmtSimplify(intrin->desc, &analyzer); PrimFunc intrin_impl = DeepCopy(intrin->impl); int index_dtype_bits = -1; diff --git a/src/s_tir/transform/hoist_expression.cc b/src/s_tir/transform/hoist_expression.cc index dbe389e84a63..448643bdb429 100644 --- a/src/s_tir/transform/hoist_expression.cc +++ b/src/s_tir/transform/hoist_expression.cc @@ -578,7 +578,7 @@ Pass HoistExpression() { return tvm::transform::Sequential( { insertion_pass, - tirx::transform::Simplify(), + tirx::transform::StmtSimplify(), tirx::transform::RemoveNoOp(), }, "s_tir.HoistExpression"); @@ -616,7 +616,7 @@ static Pass HoistIfThenElseImpl() { return tvm::transform::Sequential( { insertion_pass, - tirx::transform::Simplify(), + tirx::transform::StmtSimplify(), tirx::transform::RemoveNoOp(), }, "s_tir.HoistIfThenElse"); @@ -634,7 +634,7 @@ static Pass HoistIfThenElseBasicImpl() { return tvm::transform::Sequential( { insertion_pass, - tirx::transform::Simplify(), + tirx::transform::StmtSimplify(), tirx::transform::RemoveNoOp(), }, "s_tir.HoistIfThenElseBasic"); diff --git a/src/tirx/transform/simplify.cc b/src/tirx/transform/stmt_simplify.cc similarity index 84% rename from src/tirx/transform/simplify.cc rename to src/tirx/transform/stmt_simplify.cc index 18e398d095a4..f271c65c8a75 100644 --- a/src/tirx/transform/simplify.cc +++ b/src/tirx/transform/stmt_simplify.cc @@ -18,11 +18,11 @@ */ /*! - * \file simplify.cc + * \file stmt_simplify.cc * \brief Statement simplifier based on analyzer */ -#include "../../tirx/transform/simplify.h" +#include "../../tirx/transform/stmt_simplify.h" #include #include @@ -41,27 +41,28 @@ namespace arith { using namespace tirx; -struct SimplifyConfigNode : public ffi::Object { +struct StmtSimplifyConfigNode : public ffi::Object { bool transitively_prove_inequalities; bool convert_boolean_to_and_of_ors; bool apply_constraints_to_boolean_branches; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() + refl::ObjectDef() .def_ro("transitively_prove_inequalities", - &SimplifyConfigNode::transitively_prove_inequalities, + &StmtSimplifyConfigNode::transitively_prove_inequalities, "If true, simplify conditionals with transitive combinations of scoped constraints", refl::DefaultValue(false)) - .def_ro("convert_boolean_to_and_of_ors", &SimplifyConfigNode::convert_boolean_to_and_of_ors, + .def_ro("convert_boolean_to_and_of_ors", + &StmtSimplifyConfigNode::convert_boolean_to_and_of_ors, "If true, simplify conditionals into an AND of ORs", refl::DefaultValue(false)) .def_ro("apply_constraints_to_boolean_branches", - &SimplifyConfigNode::apply_constraints_to_boolean_branches, + &StmtSimplifyConfigNode::apply_constraints_to_boolean_branches, "If true, simplify each branch of AND/OR under a constraints provided by the other " "branch", refl::DefaultValue(false)); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.SimplifyConfig", SimplifyConfigNode, + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.StmtSimplifyConfig", StmtSimplifyConfigNode, ffi::Object); RewriteSimplifier::Extension GetEnabledExtensions() const { @@ -81,24 +82,25 @@ struct SimplifyConfigNode : public ffi::Object { } }; -class SimplifyConfig : public ffi::ObjectRef { +class StmtSimplifyConfig : public ffi::ObjectRef { public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, ffi::ObjectRef, SimplifyConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(StmtSimplifyConfig, ffi::ObjectRef, + StmtSimplifyConfigNode); }; -static SimplifyConfig MakeDefaultSimplifyConfig() { - return AttrsWithDefaultValues(); +static StmtSimplifyConfig MakeDefaultStmtSimplifyConfig() { + return AttrsWithDefaultValues(); } -TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); } +TVM_FFI_STATIC_INIT_BLOCK() { StmtSimplifyConfigNode::RegisterReflection(); } -TVM_REGISTER_PASS_CONFIG_OPTION("tirx.Simplify", SimplifyConfig); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.StmtSimplify", StmtSimplifyConfig); class StmtSimplifier : public IRMutatorWithAnalyzer { public: static PrimFunc Apply(PrimFunc func, Analyzer* analyzer, - ffi::Optional config_opt = std::nullopt) { - auto config = config_opt.value_or(MakeDefaultSimplifyConfig()); + ffi::Optional config_opt = std::nullopt) { + auto config = config_opt.value_or(MakeDefaultStmtSimplifyConfig()); analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions()); StmtSimplifier simplifier(analyzer, config); @@ -108,7 +110,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } private: - explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config) + explicit StmtSimplifier(Analyzer* analyzer, StmtSimplifyConfig config) : IRMutatorWithAnalyzer(analyzer), config_(config) {} using Parent = IRMutatorWithAnalyzer; @@ -237,7 +239,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } } - SimplifyConfig config_; + StmtSimplifyConfig config_; // Pure Bind values kept for substitution into assert conditions. // Grows monotonically under SSA — no scope-based cleanup required. @@ -248,25 +250,25 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { namespace tirx { -PrimFunc Simplify(PrimFunc func, arith::Analyzer* analyzer) { +PrimFunc StmtSimplify(PrimFunc func, arith::Analyzer* analyzer) { return arith::StmtSimplifier::Apply(std::move(func), analyzer); } namespace transform { -Pass Simplify() { +Pass StmtSimplify() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { arith::Analyzer analyzer; - auto cfg = ctx->GetConfig("tirx.Simplify"); + auto cfg = ctx->GetConfig("tirx.StmtSimplify"); return arith::StmtSimplifier::Apply(f, &analyzer, cfg); }; - return CreatePrimFuncPass(pass_func, 0, "tirx.Simplify", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.StmtSimplify", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tirx.transform.Simplify", Simplify); + refl::GlobalDef().def("tirx.transform.StmtSimplify", StmtSimplify); } } // namespace transform diff --git a/src/tirx/transform/simplify.h b/src/tirx/transform/stmt_simplify.h similarity index 70% rename from src/tirx/transform/simplify.h rename to src/tirx/transform/stmt_simplify.h index c59797fcff95..2e5e9b48cabb 100644 --- a/src/tirx/transform/simplify.h +++ b/src/tirx/transform/stmt_simplify.h @@ -18,11 +18,11 @@ */ /*! - * \file simplify.h - * \brief Helper functions to construct and compose IR nodes. + * \file stmt_simplify.h + * \brief Statement-level simplification of TIR PrimFuncs. */ -#ifndef TVM_TIR_TRANSFORM_SIMPLIFY_H_ -#define TVM_TIR_TRANSFORM_SIMPLIFY_H_ +#ifndef TVM_TIR_TRANSFORM_STMT_SIMPLIFY_H_ +#define TVM_TIR_TRANSFORM_STMT_SIMPLIFY_H_ #include #include @@ -30,12 +30,12 @@ namespace tvm { namespace tirx { -/* \brief Simplifies the prim func +/* \brief Simplify statements in the prim func * - * Applies the same behavior as the tirx.transform.Simplify pass. + * Applies the same behavior as the tirx.transform.StmtSimplify pass. */ -PrimFunc Simplify(PrimFunc stmt, arith::Analyzer* analyzer); +PrimFunc StmtSimplify(PrimFunc func, arith::Analyzer* analyzer); } // namespace tirx } // namespace tvm -#endif // TVM_TIR_TRANSFORM_SIMPLIFY_H_ +#endif // TVM_TIR_TRANSFORM_STMT_SIMPLIFY_H_ diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py index b376e1d99bcf..63a6343b714b 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -207,7 +207,7 @@ def test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize(): postproc = RewriteParallelVectorizeUnroll() sch = Schedule(Move_PUV) assert postproc.apply(sch) - mod = tvm.tirx.transform.Simplify()(sch.mod) + mod = tvm.tirx.transform.StmtSimplify()(sch.mod) tvm.ir.assert_structural_equal(mod["main"], Move_PUV0) @@ -283,7 +283,7 @@ def expected(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "flo postproc = RewriteParallelVectorizeUnroll() sch = Schedule(layer_norm) assert postproc.apply(sch) - mod = tvm.tirx.transform.Simplify()(sch.mod) + mod = tvm.tirx.transform.StmtSimplify()(sch.mod) assert_structural_equal_ignore_global_symbol(mod["main"], expected) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py b/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py index 81d69cb43983..e398876f35d6 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py @@ -37,7 +37,7 @@ def test_compact(self): before = tvm.IRModule.from_expr(self.before.with_attr("global_symbol", "main")) expected = tvm.IRModule.from_expr(self.expected.with_attr("global_symbol", "main")) simplify = tvm.transform.Sequential( - [tirx.transform.Simplify(), tirx.transform.RemoveNoOp()] + [tirx.transform.StmtSimplify(), tirx.transform.RemoveNoOp()] ) after = simplify(s_tir.transform.CompactBufferAllocation(is_strict=is_strict)(before)) expected = simplify(expected) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py b/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py index 5177d87d0a26..60c628b1e2cb 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py @@ -28,7 +28,7 @@ def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.ConvertBlocksToOpaque()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main")) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py b/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py index d30a9d81164d..66fb3d9a5d8f 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py @@ -467,7 +467,7 @@ def main(data: T.Buffer((1,), "float32"), l: T.int32, m: T.int32, n: T.int32): ] + T.float32(1.3) mod = Module - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) mod = tvm.tirx.transform.RemoveNoOp()(mod) stmt = mod["main"].body diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_double_buffer.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_double_buffer.py index bbe937fe5d87..b24f151a4e1e 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_double_buffer.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_double_buffer.py @@ -44,7 +44,7 @@ def db(A: T.handle("float32"), C: T.handle("float32")): mod = Module opt = tvm.transform.Sequential( - [tvm.s_tir.transform.InjectDoubleBuffer(), tvm.tirx.transform.Simplify()] + [tvm.s_tir.transform.InjectDoubleBuffer(), tvm.tirx.transform.StmtSimplify()] ) with tvm.transform.PassContext(config={"s_tir.InjectDoubleBuffer": {"split_loop": 2}}): @@ -78,7 +78,7 @@ def test_double_buffer_transform(): transform = tvm.ir.transform.Sequential( [ tvm.s_tir.transform.InjectDoubleBuffer(), - tvm.tirx.transform.Simplify(), + tvm.tirx.transform.StmtSimplify(), ] ) @@ -118,7 +118,7 @@ def test_double_buffer_with_decl_buffer(): transform = tvm.ir.transform.Sequential( [ tvm.s_tir.transform.InjectDoubleBuffer(), - tvm.tirx.transform.Simplify(), + tvm.tirx.transform.StmtSimplify(), ] ) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py index 36c54a2d89f9..338ab63b21af 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py @@ -41,7 +41,7 @@ def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.InjectSoftwarePipeline()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) tvm.ir.assert_structural_equal( mod["main"], transformed.with_attr("global_symbol", "main"), True ) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_loop_partition.py b/tests/python/s_tir/transform/test_s_tir_transform_loop_partition.py index 19663e3d2c5b..aa111bed1dca 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_loop_partition.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_loop_partition.py @@ -43,7 +43,7 @@ def func(n: T.int64, m: T.int64): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tirx.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.StmtSimplify()(mod)["main"].body assert not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tirx.IfThenElse))) @@ -65,7 +65,7 @@ def func(n: T.int64, m: T.int64): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tirx.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.StmtSimplify()(mod)["main"].body assert not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tirx.IfThenElse))) @@ -79,7 +79,7 @@ def func(m: T.int64, n: T.int64): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tirx.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.StmtSimplify()(mod)["main"].body assert not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tirx.Select))) @@ -93,7 +93,7 @@ def func(m: T.int64, n: T.int64): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tirx.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.StmtSimplify()(mod)["main"].body assert not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tirx.Select))) @@ -109,7 +109,7 @@ def func(m: T.int64, n: T.int64): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tirx.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.StmtSimplify()(mod)["main"].body assert isinstance(stmt.body.body, tvm.tirx.IfThenElse) @@ -139,7 +139,7 @@ def func(m: T.int64, data: T.handle("float32"), out: T.handle("float32")): with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tirx.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.StmtSimplify()(mod)["main"].body assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tirx.IfThenElse))) @@ -160,7 +160,7 @@ def func(A: T.Buffer((n * m,), "float16"), B: T.Buffer((n * m,), "float16")): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tirx.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.StmtSimplify()(mod)["main"].body assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tirx.IfThenElse))) @@ -181,7 +181,7 @@ def func(): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tirx.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.StmtSimplify()(mod)["main"].body assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tirx.IfThenElse))) @@ -202,7 +202,7 @@ def func(): with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tirx.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.StmtSimplify()(mod)["main"].body assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tirx.IfThenElse))) @@ -225,7 +225,7 @@ def partition_from_scheduled_tir(prim_func, pass_cfg, do_flatten=True): if do_flatten: mod = tvm.tirx.transform.FlattenBuffer()(mod) mod = tvm.s_tir.transform.LoopPartition()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) mod = tvm.tirx.transform.RemoveNoOp()(mod) return mod @@ -329,7 +329,7 @@ def partitioned_main( ) mod = tvm.tirx.transform.UnrollLoop()(mod) mod = tvm.tirx.transform.RemoveNoOp()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) tvm.ir.assert_structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py index 514497032932..42f0c98d0568 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py @@ -26,7 +26,7 @@ def _check(original, transformed): mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.LowerMatchBuffer()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main")) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py index 660c1e1d1caf..62ad915a575d 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py @@ -24,7 +24,7 @@ def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) tvm.ir.assert_structural_equal( mod["main"], transformed.with_attr("global_symbol", "main"), True ) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py b/tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py index 68d82da7c053..37a567a52e67 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py @@ -123,7 +123,7 @@ def main(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 51 def test_renormalize_split_pattern(): after = tvm.s_tir.transform.RenormalizeSplitPattern()(Before) tvm.ir.assert_structural_equal(after, After) - after = tvm.tirx.transform.Simplify()(after) + after = tvm.tirx.transform.StmtSimplify()(after) tvm.ir.assert_structural_equal(after, After_simplified) @@ -166,7 +166,7 @@ def test_analyze_inside_integer_conditional(integer_condition): """ # Similar issue would occur in most transformations that subclass - # IRMutatorWithAnalyzer. tirx.transform.Simplify() is an + # IRMutatorWithAnalyzer. tirx.transform.StmtSimplify() is an # exception, as it rewrites the integer conditionals first. These # tests are written using RenormalizeSplitPattern as it is the # first case identified. diff --git a/tests/python/s_tir/transform/test_s_tir_transform_unify_thread_binding.py b/tests/python/s_tir/transform/test_s_tir_transform_unify_thread_binding.py index 2ddd6f3bbdc9..c5e421774698 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_unify_thread_binding.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_unify_thread_binding.py @@ -28,7 +28,7 @@ def _check(original, transformed): mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.UnifyThreadBinding()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) tvm.ir.assert_structural_equal( mod["main"], transformed.with_attr("global_symbol", "main"), True ) diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py index fc29e82442c6..6aa5689ad10d 100644 --- a/tests/python/te/test_te_create_primfunc.py +++ b/tests/python/te/test_te_create_primfunc.py @@ -50,7 +50,7 @@ def test_unique_name_reduction_block(): def _check_workload(te_workload, tir_workload, index_dtype_override=None, do_simplify=False): func = te.create_prim_func(te_workload(), index_dtype_override) if do_simplify: - simplify = tirx.transform.Simplify() + simplify = tirx.transform.StmtSimplify() func = simplify(tvm.IRModule.from_expr(func))["main"] tir_workload = simplify(tvm.IRModule.from_expr(tir_workload))["main"] tvm.ir.assert_structural_equal(func, tir_workload) diff --git a/tests/python/tirx-base/test_tir_constructor.py b/tests/python/tirx-base/test_tir_constructor.py index 00cd63fa8590..d084fe2b2590 100644 --- a/tests/python/tirx-base/test_tir_constructor.py +++ b/tests/python/tirx-base/test_tir_constructor.py @@ -173,7 +173,7 @@ def test_expr_constructor(): [cond0, inner_if, tvm.tirx.IntImm("int32", 0)], annotations={"keep": True}, ) - simplified = tvm.tirx.transform.Simplify()( + simplified = tvm.tirx.transform.StmtSimplify()( tvm.IRModule({"main": tvm.tirx.PrimFunc([], tvm.tirx.Evaluate(outer_if))}) )["main"].body.value assert bool(simplified.annotations["keep"]) diff --git a/tests/python/tirx-transform/test_tir_transform_flatten_buffer.py b/tests/python/tirx-transform/test_tir_transform_flatten_buffer.py index 909070498706..9b1c171be472 100644 --- a/tests/python/tirx-transform/test_tir_transform_flatten_buffer.py +++ b/tests/python/tirx-transform/test_tir_transform_flatten_buffer.py @@ -24,7 +24,7 @@ def _transform(): return tvm.transform.Sequential( [ tvm.tirx.transform.FlattenBuffer(), - tvm.tirx.transform.Simplify(), + tvm.tirx.transform.StmtSimplify(), ] ) diff --git a/tests/python/tirx-transform/test_tir_transform_lower_intrin.py b/tests/python/tirx-transform/test_tir_transform_lower_intrin.py index 75e801dfd3e5..30ead37c841b 100644 --- a/tests/python/tirx-transform/test_tir_transform_lower_intrin.py +++ b/tests/python/tirx-transform/test_tir_transform_lower_intrin.py @@ -29,7 +29,7 @@ def lower_intrin(params, stmt): tvm.tirx.PrimFunc(params, stmt).with_attr("target", tvm.target.Target("llvm")) ) mod = tvm.transform.Sequential( - [tvm.tirx.transform.Simplify(), tvm.tirx.transform.LowerIntrin()] + [tvm.tirx.transform.StmtSimplify(), tvm.tirx.transform.LowerIntrin()] )(mod) func = mod["main"] stmt = func.body diff --git a/tests/python/tirx-transform/test_tir_transform_narrow_datatype.py b/tests/python/tirx-transform/test_tir_transform_narrow_datatype.py index 51cc29bbd1f5..cbd5b103389b 100644 --- a/tests/python/tirx-transform/test_tir_transform_narrow_datatype.py +++ b/tests/python/tirx-transform/test_tir_transform_narrow_datatype.py @@ -297,7 +297,7 @@ def expected_after(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,), after = tvm.tirx.transform.NarrowDataType(32)( tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) ) - after = tvm.tirx.transform.Simplify()(after) + after = tvm.tirx.transform.StmtSimplify()(after) tvm.ir.assert_structural_equal(after["main"], expected_after.with_attr("global_symbol", "main")) diff --git a/tests/python/tirx-transform/test_tir_transform_simplify.py b/tests/python/tirx-transform/test_tir_transform_simplify.py index dd3d9a01f87e..c2121ebeca44 100644 --- a/tests/python/tirx-transform/test_tir_transform_simplify.py +++ b/tests/python/tirx-transform/test_tir_transform_simplify.py @@ -32,7 +32,7 @@ def func(A: T.handle("float32"), C: T.handle("float32"), n: T.int32): A_ptr[i] = C_ptr[i] mod = tvm.IRModule.from_expr(func) - body = tvm.tirx.transform.Simplify()(mod)["main"].body + body = tvm.tirx.transform.StmtSimplify()(mod)["main"].body # Navigate through DeclBuffer nodes to reach the inner body while isinstance(body, tvm.tirx.DeclBuffer): body = body.body @@ -58,7 +58,7 @@ def func(A: T.handle("float32"), C: T.handle("float32"), n: T.int32): A_ptr[tx] = C_ptr[tx + ty] mod = tvm.IRModule.from_expr(func) - body = tvm.tirx.transform.Simplify()(mod)["main"].body + body = tvm.tirx.transform.StmtSimplify()(mod)["main"].body # Navigate through DeclBuffer nodes to reach the inner body while isinstance(body, tvm.tirx.DeclBuffer): body = body.body @@ -86,7 +86,7 @@ def func(A: T.handle("float32"), C: T.handle("float32"), n: T.int32): A_ptr[tx] = C_ptr[tx * 32 + ty] mod = tvm.IRModule.from_expr(func) - body = tvm.tirx.transform.Simplify()(mod)["main"].body + body = tvm.tirx.transform.StmtSimplify()(mod)["main"].body # With flat semantics, skip DeclBuffer/AllocBuffer siblings to find the For if isinstance(body, tvm.tirx.SeqStmt): for_stmts = [s for s in body.seq if isinstance(s, tvm.tirx.For)] @@ -104,7 +104,7 @@ def _apply_simplify( ): """Helper to apply simplify transform with config options.""" config = { - "tirx.Simplify": { + "tirx.StmtSimplify": { "transitively_prove_inequalities": transitively_prove_inequalities, "convert_boolean_to_and_of_ors": convert_boolean_to_and_of_ors, "apply_constraints_to_boolean_branches": apply_constraints_to_boolean_branches, @@ -112,7 +112,7 @@ def _apply_simplify( } mod = tvm.IRModule.from_expr(func) with tvm.transform.PassContext(config=config): - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) return mod["main"] @@ -1255,7 +1255,7 @@ def main(a: T.handle): A = T.match_buffer(a, (n * 32,), "float32") A[T.int64(0)] = T.float32(0) - after = tvm.tirx.transform.Simplify()(Before) + after = tvm.tirx.transform.StmtSimplify()(Before) tvm.ir.assert_structural_equal(after["main"], Expected["main"]) @@ -1276,7 +1276,7 @@ def main(a: T.handle): A = T.match_buffer(a, (n * 32 + 1 - 2,), "float32") A[T.int64(1)] = T.float32(0) - after = tvm.tirx.transform.Simplify()(Before) + after = tvm.tirx.transform.StmtSimplify()(Before) tvm.ir.assert_structural_equal(after["main"], Expected["main"]) diff --git a/tests/python/tirx-transform/test_tir_transform_unroll_loop.py b/tests/python/tirx-transform/test_tir_transform_unroll_loop.py index 4ece36a97b70..a6e3bf40e3cf 100644 --- a/tests/python/tirx-transform/test_tir_transform_unroll_loop.py +++ b/tests/python/tirx-transform/test_tir_transform_unroll_loop.py @@ -149,7 +149,7 @@ def main(B: T.Buffer((64,), "float32")): } ): after = tvm.tirx.transform.UnrollLoop()(Before) - after = tvm.tirx.transform.Simplify()(after) + after = tvm.tirx.transform.StmtSimplify()(after) tvm.ir.assert_structural_equal(after, Expected) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_binary_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_binary_trn.py index 1b9fd015728e..a6e3da9bf482 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_binary_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_binary_trn.py @@ -352,7 +352,7 @@ def expected(): with target: mod = tvm.IRModule({"main": binary}) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_compose_op_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_compose_op_trn.py index d014516cc214..8c2ec52c4583 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_compose_op_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_compose_op_trn.py @@ -575,7 +575,7 @@ def expected(): with target: mod = tvm.IRModule({"main": binary_reduce}) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -623,7 +623,7 @@ def expected(): mod = tvm.IRModule({"main": unary_reduce}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -663,7 +663,7 @@ def expected(): with target: mod = tvm.IRModule({"main": binary_chain}) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_copy_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_copy_trn.py index 7dba16555afa..6c831252bd61 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_copy_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_copy_trn.py @@ -233,7 +233,7 @@ def expected(): mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -282,7 +282,7 @@ def expected(): mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -696,7 +696,7 @@ def expected(A_ptr: Tx.handle): with target: mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -734,7 +734,7 @@ def expected(A_ptr: Tx.handle): with target: mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -785,7 +785,7 @@ def expected(): mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -861,7 +861,7 @@ def expected(A_ptr: Tx.handle): mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_gemm_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_gemm_trn.py index d806024b17b1..16ac5cbd8e08 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_gemm_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_gemm_trn.py @@ -291,7 +291,7 @@ def expected(): mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -421,7 +421,7 @@ def expected(): with target: mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -548,7 +548,7 @@ def expected(): mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -593,7 +593,7 @@ def expected(): with target: mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_reduction_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_reduction_trn.py index fa892d43f57f..c5b9506a7824 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_reduction_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_reduction_trn.py @@ -240,7 +240,7 @@ def expected(): mod = tvm.IRModule({"main": reduction}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_select_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_select_trn.py index ca0cb266a58d..f2f3a901643a 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_select_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_select_trn.py @@ -73,7 +73,7 @@ def expected(): with target: mod = tvm.IRModule({"main": select}) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -109,7 +109,7 @@ def expected(): with target: mod = tvm.IRModule({"main": select}) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -143,7 +143,7 @@ def expected(): with target: mod = tvm.IRModule({"main": select}) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -180,7 +180,7 @@ def expected(): with target: mod = tvm.IRModule({"main": select}) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_unary_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_unary_trn.py index efd91a388388..3e72c8bb28bb 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_unary_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_unary_trn.py @@ -286,7 +286,7 @@ def expected(): with target: mod = tvm.IRModule({"main": unary}) mod = tvm.tirx.transform.LowerTIRx()(mod) - mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) From 31cfa142dfd6724c0a630413d462ca3dbd974c32 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 25 May 2026 19:57:19 +0000 Subject: [PATCH 3/4] [REFACTOR][TIR] Address Gemini review on StmtSimplify rename Fix grammatical typo in apply_constraints_to_boolean_branches docstring and update the remove_no_op test docstring to reflect that dataflow-analysis is no longer a supported option (rather than just "not used here"). --- src/tirx/transform/stmt_simplify.cc | 2 +- .../python/tirx-transform/test_tir_transform_remove_no_op.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/tirx/transform/stmt_simplify.cc b/src/tirx/transform/stmt_simplify.cc index f271c65c8a75..5443702b8f9a 100644 --- a/src/tirx/transform/stmt_simplify.cc +++ b/src/tirx/transform/stmt_simplify.cc @@ -58,7 +58,7 @@ struct StmtSimplifyConfigNode : public ffi::Object { "If true, simplify conditionals into an AND of ORs", refl::DefaultValue(false)) .def_ro("apply_constraints_to_boolean_branches", &StmtSimplifyConfigNode::apply_constraints_to_boolean_branches, - "If true, simplify each branch of AND/OR under a constraints provided by the other " + "If true, simplify each branch of AND/OR under constraints provided by the other " "branch", refl::DefaultValue(false)); } diff --git a/tests/python/tirx-transform/test_tir_transform_remove_no_op.py b/tests/python/tirx-transform/test_tir_transform_remove_no_op.py index a4e732a77a01..08ff4728d8fb 100644 --- a/tests/python/tirx-transform/test_tir_transform_remove_no_op.py +++ b/tests/python/tirx-transform/test_tir_transform_remove_no_op.py @@ -242,7 +242,10 @@ def expected(A: T.Buffer(16, "int32")): def test_suppress_removal_of_unused_write(): - """Sequential writes to the same location are not removed without dataflow analysis.""" + """Sequential writes to the same location are not removed. + + Dataflow analysis is no longer supported. + """ @T.prim_func(private=True, s_tir=True) def before(A: T.Buffer(16, "int32")): From 05801364ab182a691e493b6117d51a05785184e0 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 25 May 2026 22:51:50 +0000 Subject: [PATCH 4/4] =?UTF-8?q?[REFACTOR][TIR]=20Fix=20missed=20Simplify?= =?UTF-8?q?=E2=86=92StmtSimplify=20rename=20in=20s=5Ftir=20finalize=5Fdevi?= =?UTF-8?q?ce=5Fpasses?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The bare `tir.transform.X()` form in s_tir/pipeline.py finalize_device_passes() resolved at runtime to tvm.tirx.transform, but the post-rename module no longer exposes Simplify (it's StmtSimplify). Without this fix, calling finalize_device_passes() raises AttributeError. The CI suite did not exercise this code path, so the regression slipped through PR #19604. Switch all four lines to use the explicit `tirx.transform` form (already imported) and rename Simplify → StmtSimplify. --- python/tvm/s_tir/pipeline.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py index 9f863ca4674f..33a16b381fea 100644 --- a/python/tvm/s_tir/pipeline.py +++ b/python/tvm/s_tir/pipeline.py @@ -137,10 +137,10 @@ def finalize_host_passes(): # pylint: disable=unused-argument def finalize_device_passes(): # pylint: disable=unused-argument """The default finalization passes for TIR backend.""" device_pass_list = [ - tir.transform.LowerWarpMemory(), - tir.transform.Simplify(), - tir.transform.LowerCustomDatatypes(), - tir.transform.LowerIntrin(), + tirx.transform.LowerWarpMemory(), + tirx.transform.StmtSimplify(), + tirx.transform.LowerCustomDatatypes(), + tirx.transform.LowerIntrin(), ] return tvm.ir.transform.Sequential(device_pass_list)