diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index e885195b3d42..96cb92850d5a 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1425,16 +1425,15 @@ class InverseAffineIterMapTransformer { return; } - // Case 2: If the sum expression has multiple components, match the fuse pattern and then split + // Case 2: If the sum expression has multiple components, check the fuse pattern and then split // the sum expression for each components. // For example, consider the iterator i1[dom = (0, 16)], i2[dom = (0, 8)], fusing i1 and i2 // we will have i1_i2_fused[dom = (0, 64)]. During back propagation, we need to split the // propagated value to get the corresponding components of i1 and i2, which are // floordiv(i1_i2_fused, 8) and floormod(i1_i2_fused, 8), respectively. - Array splits = MatchFusePattern(iter_map_expr); - ICHECK(!splits.empty()); - - for (const IterSplitExpr& split : splits) { + CheckFusePattern(iter_map_expr); + for (size_t i = iter_map_expr->args.size(); i > 0; i--) { + const IterSplitExpr& split = iter_map_expr->args[i - 1]; backprop_.Set(split, backprop_.at(split) + floormod(floordiv(input, split->scale), split->extent)); } @@ -1485,33 +1484,17 @@ class InverseAffineIterMapTransformer { } } - Array MatchFusePattern(const IterSumExpr sum_expr) { - IntImm base_scale(nullptr); - size_t base_index = 0; - for (size_t i = 0; i < sum_expr->args.size(); ++i) { - if (const auto* op = sum_expr->args[i]->scale.as()) { - if (!base_scale.defined() || op->value < base_scale->value) { - base_scale = GetRef(op); - base_index = i; - } - } - } - ICHECK(base_scale.defined()); - std::vector iters; - std::vector visited(sum_expr->args.size(), false); - PrimExpr expected_scale = base_scale; - for (size_t i = 0; i < sum_expr->args.size(); i++) { - size_t j = i == 0 ? base_index : 0; - for (; j < sum_expr->args.size(); ++j) { - if (!visited[j] && analyzer_->CanProveEqual(sum_expr->args[j]->scale, expected_scale)) - break; - } - ICHECK(j != sum_expr->args.size()); - visited[j] = true; - iters.push_back(sum_expr->args[j]); - expected_scale *= sum_expr->args[j]->extent; + /* + * \brief Check the fuse pattern of sum_expr. We assume components of sum_expr is sorted in + * descending order of lower_factor. + */ + void CheckFusePattern(const IterSumExpr sum_expr) { + ICHECK(sum_expr->args.size()); + PrimExpr expected_scale = sum_expr->args.back()->scale; + for (size_t i = sum_expr->args.size(); i > 0; i--) { + ICHECK(analyzer_->CanProveEqual(sum_expr->args[i - 1]->scale, expected_scale)); + expected_scale *= sum_expr->args[i - 1]->extent; } - return iters; } Analyzer* analyzer_;