Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class BlockReadWriteDetector : public StmtExprVisitor {
private:
/*! \brief Iteration range for loop_vars */
std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
/*! \brief Extra iteration range hint for free vars */
std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
/*! \brief The buffers that the current block reads */
std::vector<Buffer> read_buffers_;
/*! \brief The buffers that the current block writes */
Expand Down Expand Up @@ -96,6 +98,9 @@ class BlockReadWriteDetector : public StmtExprVisitor {
/*! \brief Helper function to update a opaque access. */
void UpdateOpaque(const Var& buffer_var);

/*! \brief Helper function to relax the buffer indices */
arith::IntSet RelaxAccessIndex(const PrimExpr& index);

void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const BlockRealizeNode* op) override;
Expand Down Expand Up @@ -140,10 +145,22 @@ void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
ExprVisitor::VisitExpr_(op);
}

arith::IntSet BlockReadWriteDetector::RelaxAccessIndex(const PrimExpr& index) {
arith::IntSet relaxed = arith::EvalSet(index, dom_map_);
if (!hint_map_.empty()) {
// take non-relaxed var bound hints into considerations
// eg, if i * 4 + j with i >= 10 and j in [0, 4), only j in domain scope
// then the index region can be relaxed to [i*4, i*4+4) ^ [40, inf)
arith::IntSet hint_bound = arith::EvalSet(relaxed, hint_map_);
relaxed = arith::Intersect({relaxed, hint_bound});
}
return relaxed;
}

void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
relaxed_region.push_back(arith::EvalSet(index, dom_map_));
relaxed_region.push_back(RelaxAccessIndex(index));
}
Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
ExprVisitor::VisitExpr_(op);
Expand All @@ -160,12 +177,12 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
VisitExpr(op->condition);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
// Visit else branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
StmtExprVisitor::VisitStmt(op->else_case);
}
}
Expand All @@ -175,12 +192,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
VisitExpr(op->args[0]);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
StmtExprVisitor::VisitExpr(op->args[1]);
}
{
// Visit else branch
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
StmtExprVisitor::VisitExpr(op->args[2]);
}
return;
Expand All @@ -196,7 +213,7 @@ void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
relaxed_region.push_back(arith::EvalSet(index, dom_map_));
relaxed_region.push_back(RelaxAccessIndex(index));
}
Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
StmtVisitor::VisitStmt_(op);
Expand Down
10 changes: 6 additions & 4 deletions src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
StmtExprVisitor::VisitExpr(op->condition);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
// Visit else branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
StmtExprVisitor::VisitStmt(op->else_case);
}
}
Expand All @@ -139,12 +139,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
StmtExprVisitor::VisitExpr(op->args[0]);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
StmtExprVisitor::VisitExpr(op->args[1]);
}
{
// Visit else branch
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
StmtExprVisitor::VisitExpr(op->args[2]);
}
return;
Expand Down Expand Up @@ -282,6 +282,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor {

/*! \brief The map from loop vars to their iter range. */
std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
/*! \brief Extra map from free vars to their iter range hints. */
std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
/*! \brief The analyzer aware of loop domains. */
arith::Analyzer dom_analyzer_;
/*! \brief The map from Buffer to it's relaxed access set. */
Expand Down
62 changes: 49 additions & 13 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,18 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
Array<Var> vars = Array<Var>(var_set.begin(), var_set.end());
Map<Var, Range> ranges;
for (const Var& v : vars) {
auto it = dom_map_->find(v.get());
if (it != dom_map_->end()) {
const auto& int_set = it->second;
ranges.Set(v, Range::FromMinExtent(int_set.min(),
analyzer.Simplify(int_set.max() - int_set.min() + 1)));
arith::IntSet dom;
auto relax_it = relax_map_->find(v.get());
if (relax_it != relax_map_->end()) {
dom = relax_it->second;
} else {
auto hint_it = hint_map_->find(v.get());
if (hint_it != hint_map_->end()) {
dom = hint_it->second;
}
}
if (dom.defined()) {
ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer.Simplify(dom.max() - dom.min() + 1)));
}
}
// solve constraints
Expand All @@ -314,24 +321,53 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
}

ConditionalBoundsContext::ConditionalBoundsContext(
const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* dom_map,
bool is_true_branch)
: condition_(condition), dom_map_(dom_map), is_true_branch_(is_true_branch) {}
const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
std::unordered_map<const VarNode*, arith::IntSet>* hint_map, bool is_true_branch)
: condition_(condition),
relax_map_(relax_map),
hint_map_(hint_map),
is_true_branch_(is_true_branch) {}

void ConditionalBoundsContext::EnterWithScope() {
for (const auto& p : GetVarBoundsFromCondition()) {
const auto* var = p.first.get();
auto it = dom_map_->find(var);
if (it != dom_map_->end()) {
origin_map_.emplace(var, it->second);
it->second = arith::Intersect({it->second, arith::IntSet::FromRange(p.second)});
arith::IntSet new_dom = arith::IntSet::FromRange(p.second);
auto relax_it = relax_map_->find(var);
if (relax_it != relax_map_->end()) {
// this is a bound for relaxed var
origin_map_.emplace(var, relax_it->second);
relax_it->second = arith::Intersect({relax_it->second, new_dom});
} else {
// this is a bound for free var
auto hint_it = hint_map_->find(var);
if (hint_it != hint_map_->end()) {
origin_map_.emplace(var, hint_it->second);
hint_it->second = arith::Intersect({hint_it->second, new_dom});
} else {
origin_map_.emplace(var, arith::IntSet::Nothing());
hint_map_->insert(hint_it, {var, new_dom});
}
}
}
}

void ConditionalBoundsContext::ExitWithScope() {
for (const auto& p : origin_map_) {
(*dom_map_)[p.first] = p.second;
const auto* var = p.first;
auto relax_it = relax_map_->find(var);
if (relax_it != relax_map_->end()) {
// recover bound for relaxed var
relax_it->second = p.second;
} else {
// recover bound for free var
auto hint_it = hint_map_->find(var);
ICHECK(hint_it != hint_map_->end());
if (p.second.IsNothing()) {
hint_map_->erase(hint_it);
} else {
hint_it->second = p.second;
}
}
}
}

Expand Down
18 changes: 11 additions & 7 deletions src/tir/transforms/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,21 +231,23 @@ Bool IsFromLegacyTESchedule(PrimFunc f);
*\brief Context helper to update domain map within conditional scope.
*
* Assume the condition is `0 <= i && i < 9` and global domain of i is [0, 20], thus `bounds[i]` is
*[0, 8]. Then `With<ConditionalBoundsContext> ctx(&dom_map, bounds, true)` step into scope where
*dom_map[i] is [0, 8] and `With<ConditionalBoundsContext> ctx(&dom_map, bounds, false)` step into
*scope where dom_map[i] is [9, 20]
* [0, 8]. Then `With<ConditionalBoundsContext> ctx(condition, &relax_map, &hint_map, true)` step
*into scope where dom_map[i] is [0, 8] and `With<ConditionalBoundsContext> ctx(condition,
*&relax_map, &hint_map, false)` step into scope where dom_map[i] is [9, 20]
*/
class ConditionalBoundsContext {
private:
friend class With<ConditionalBoundsContext>;
/*!
* \brief Construct a condition bounds context.
* \param condition The condition holds on true branch.
* \param dom_map The global domain map to be updated.
* \param relax_map The domain map for relaxed vars to update.
* \param hint_map The domain map for free vars to update.
* \param is_true_branch Whether step into the branch where condition bounds holds.
*/
ConditionalBoundsContext(const PrimExpr& condition,
std::unordered_map<const VarNode*, arith::IntSet>* dom_map,
std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
std::unordered_map<const VarNode*, arith::IntSet>* hint_map,
bool is_true_branch);
void EnterWithScope();
void ExitWithScope();
Expand All @@ -255,8 +257,10 @@ class ConditionalBoundsContext {

/*! \brief the condition holds on true branch. */
const PrimExpr& condition_;
/*! \brief global domain map to updated */
std::unordered_map<const VarNode*, arith::IntSet>* dom_map_;
/*! \brief domain map for relaxed vars to update */
std::unordered_map<const VarNode*, arith::IntSet>* relax_map_;
/*! \brief domain map for free vars to update */
std::unordered_map<const VarNode*, arith::IntSet>* hint_map_;
/*! \brief whether is on true branch */
bool is_true_branch_;
/*! \brief used to record and restore original var bounds */
Expand Down
66 changes: 66 additions & 0 deletions tests/python/unittest/test_tir_analysis_get_block_access_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,41 @@ def access_in_branch_func() -> None:
B[i] = A[i - 1]


@T.prim_func
def access_of_padding_pattern() -> None:
X = T.alloc_buffer([28, 28])
X_pad = T.alloc_buffer([32, 32])
Y = T.alloc_buffer([28, 28])
for i, j in T.grid(32, 32):
with T.block("padding"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(
[
X[
T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
]
]
)
T.writes([X_pad[vi, vj]])
X_pad[vi, vj] = T.if_then_else(
2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32"
)
with T.block("padding_reverse"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads([X_pad[T.max(vi, 2) : T.min(vi, 29) + 1, T.max(vj, 2) : T.min(vj, 29) + 1]])
T.writes(
[
Y[
T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
]
]
)
if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
Y[vi - 2, vj - 2] = X_pad[vi, vj]


def test_block_access_region_detector():
block = func.body.block.body.block
alloc_buffers = func.body.block.alloc_buffers
Expand Down Expand Up @@ -220,10 +255,41 @@ def test_access_in_branch_func():
tvm.ir.assert_structural_equal(ret0[1], ret1[1])


def test_access_of_padding_pattern():
s = tvm.tir.schedule.Schedule(access_of_padding_pattern)
alloc_buffers = s.get_sref(s.get_block("root")).stmt.alloc_buffers
buffer_var_map = {buf.data: buf for buf in alloc_buffers}

def do_compare_buffer_region(region, expect):
assert region.buffer == expect.buffer
analyzer = tvm.arith.Analyzer()
for k, rng in enumerate(region.region):
tvm.ir.assert_structural_equal(
analyzer.simplify(rng.min), analyzer.simplify(expect.region[k].min)
)
tvm.ir.assert_structural_equal(
analyzer.simplify(rng.extent), analyzer.simplify(expect.region[k].extent)
)

def do_check_block(block_name):
block = s.get_sref(s.get_block(block_name)).stmt
expect_reads = block.reads
expect_writes = block.writes
ret = tir.analysis.get_block_access_region(block, buffer_var_map)
for i, read in enumerate(ret[0]):
do_compare_buffer_region(read, expect_reads[i])
for i, write in enumerate(ret[1]):
do_compare_buffer_region(write, expect_writes[i])

do_check_block("padding")
do_check_block("padding_reverse")


if __name__ == "__main__":
test_block_access_region_detector()
test_opaque_block()
test_opaque_access()
test_match_buffer()
test_access_in_if_then_else_func()
test_access_in_branch_func()
test_access_of_padding_pattern()
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def _check(original, transformed):
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.CompactBufferAllocation()(mod)
mod = tvm.tir.transform.Simplify()(mod)
transformed = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(transformed))["main"]
tvm.ir.assert_structural_equal(mod["main"], transformed)


Expand Down