Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 14 additions & 31 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1425,16 +1425,15 @@ class InverseAffineIterMapTransformer {
return;
}

// Case 2: If the sum expression has multiple components, match the fuse pattern and then split
// Case 2: If the sum expression has multiple components, check the fuse pattern and then split
// the sum expression for each component.
// For example, consider the iterator i1[dom = (0, 16)], i2[dom = (0, 8)], fusing i1 and i2
// we will have i1_i2_fused[dom = (0, 128)]. During back propagation, we need to split the
// propagated value to get the corresponding components of i1 and i2, which are
// floordiv(i1_i2_fused, 8) and floormod(i1_i2_fused, 8), respectively.
Array<IterSplitExpr> splits = MatchFusePattern(iter_map_expr);
ICHECK(!splits.empty());

for (const IterSplitExpr& split : splits) {
CheckFusePattern(iter_map_expr);
for (size_t i = iter_map_expr->args.size(); i > 0; i--) {
const IterSplitExpr& split = iter_map_expr->args[i - 1];
backprop_.Set(split,
backprop_.at(split) + floormod(floordiv(input, split->scale), split->extent));
}
Expand Down Expand Up @@ -1485,33 +1484,17 @@ class InverseAffineIterMapTransformer {
}
}

Array<IterSplitExpr> MatchFusePattern(const IterSumExpr sum_expr) {
IntImm base_scale(nullptr);
size_t base_index = 0;
for (size_t i = 0; i < sum_expr->args.size(); ++i) {
if (const auto* op = sum_expr->args[i]->scale.as<IntImmNode>()) {
if (!base_scale.defined() || op->value < base_scale->value) {
base_scale = GetRef<IntImm>(op);
base_index = i;
}
}
}
ICHECK(base_scale.defined());
std::vector<IterSplitExpr> iters;
std::vector<bool> visited(sum_expr->args.size(), false);
PrimExpr expected_scale = base_scale;
for (size_t i = 0; i < sum_expr->args.size(); i++) {
size_t j = i == 0 ? base_index : 0;
for (; j < sum_expr->args.size(); ++j) {
if (!visited[j] && analyzer_->CanProveEqual(sum_expr->args[j]->scale, expected_scale))
break;
}
ICHECK(j != sum_expr->args.size());
visited[j] = true;
iters.push_back(sum_expr->args[j]);
expected_scale *= sum_expr->args[j]->extent;
/*
* \brief Check the fuse pattern of sum_expr. We assume the components of sum_expr are sorted
* in descending order of lower_factor.
*/
void CheckFusePattern(const IterSumExpr sum_expr) {
ICHECK(sum_expr->args.size());
PrimExpr expected_scale = sum_expr->args.back()->scale;
for (size_t i = sum_expr->args.size(); i > 0; i--) {
ICHECK(analyzer_->CanProveEqual(sum_expr->args[i - 1]->scale, expected_scale));
expected_scale *= sum_expr->args[i - 1]->extent;
}
return iters;
}

Analyzer* analyzer_;
Expand Down