From e2ef3baeb7d35e809eb33ab2eacb87c9606218d5 Mon Sep 17 00:00:00 2001 From: Min Chen Date: Sun, 11 Sep 2022 04:43:49 +0000 Subject: [PATCH] [TIR][Schedule] Relax cache read/write's restriction and fix unexpected behavior. --- .../schedule/primitive/cache_read_write.cc | 76 +++++++++++++------ src/tir/schedule/state.cc | 1 + .../test_tir_schedule_cache_read_write.py | 63 ++++++++++++++- 3 files changed, 114 insertions(+), 26 deletions(-) diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index a221733eb394..c76e6abaebb5 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -382,9 +382,16 @@ class CacheLocDetector : public StmtVisitor { static void Detect(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_sref, CacheStageInfo* info) { std::vector related_blocks; - for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) { - if (def->kind == DepKind::kRAW) { - related_blocks.push_back(def->dst); + // If consumer is specified, skip detecting the others + if (info->consumer_blocks.size() > 0) { + for (StmtSRef consumer : info->consumer_blocks) { + related_blocks.emplace_back(consumer); + } + } else { + for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) { + if (def->kind == DepKind::kRAW) { + related_blocks.push_back(def->dst); + } } } if (!related_blocks.empty()) { @@ -416,29 +423,24 @@ class CacheLocDetector : public StmtVisitor { void VisitStmt_(const SeqStmtNode* seq_stmt) final { bool previous_visited_block = visited_block_; - bool previous_visited_related = visited_related_; - visited_block_ = visited_related_ = false; + visited_block_ = false; - int pos = -1; for (size_t i = 0; i < seq_stmt->size(); ++i) { if (loc_pos_ != -1) { break; } VisitStmt(seq_stmt->seq[i]); // `pos` can be assigned only once when we visited `block_sref` - if (visited_block_ && visited_related_ && pos == -1) { + if (visited_block_ && visited_related_ && loc_pos_ == -1) { // The offset of insert position from the block - pos = i; + loc_pos_ = i; + return; + } else if (visited_related_) { + // If meet the target consumer, stop searching + visited_block_ = visited_block_ || previous_visited_block; + return; } } - visited_block_ = visited_block_ || previous_visited_block; - visited_related_ = visited_related_ || previous_visited_related; - // Only we visited the writing block and any one of the related blocks - // That means that we have found the lowest ancestor - // of the block and any one of the related ones - if (visited_block_ && visited_related_ && loc_pos_ == -1) { - loc_pos_ = pos; - } } void VisitStmt_(const BlockNode* block) final { @@ -446,11 +448,12 @@ class CacheLocDetector : public StmtVisitor { if (block == scope_sref_->stmt) { // The block vistied is the current parent scope StmtVisitor::VisitStmt_(block); - // Handling cache_read for input buffer - if (visited_block_ && visited_related_ && !loc_sref_.defined()) { + // Handling cases when insert outside any loop or cache_read for input buffer + if (visited_related_ && !loc_sref_.defined()) { loc_sref_ = self_->stmt2ref.at(block); - if (loc_pos_ == -1) { - loc_pos_ = 1; + // Handling cache_read for input buffer + if (visited_block_ == false && loc_pos_ == -1) { + loc_pos_ = 0; } } return; @@ -980,6 +983,33 @@ class ReIndexRewriter : public StmtExprMutator { Region region_; }; +void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root) { + class NotRegionCoverError : public ScheduleError { + public: + explicit NotRegionCoverError(IRModule mod, Block block) : mod_(mod), block_(block) {} + IRModule mod() const final { return mod_; } + String FastErrorString() const final { + return "ScheduleError: The scope root's region cover is not complete."; + } + String DetailRenderTemplate() const final { + return R"(The scope {0} 's region cover is not complete. +The region cover property require to hold for every of its child blocks +)"; + } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + Block block_; + }; + BlockScope scope = self->GetBlockScope(scope_root); + for (const auto& kv : scope->dst2deps) { + const StmtSRef& consumer_block_sref = kv.first; + if (!self->block_info.at(consumer_block_sref).region_cover) { + const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root); + throw NotRegionCoverError(self->mod, GetRef(block)); + } + } +} + /******** Implementation ********/ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, @@ -1002,7 +1032,9 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer read_buffer = GetNthAccessBuffer(self, GetRef(block), read_buffer_index, BufferIndexType::kRead); - StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); + // Check required region cover for cache_read + CheckRegionCover(self, scope_sref); const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); // Step 2. Create CacheStageInfo @@ -1075,7 +1107,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer write_buffer = GetNthAccessBuffer(self, GetRef(block), write_buffer_index, BufferIndexType::kWrite); - StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); // Step 2. Creating CacheStageInfo CacheStageInfo info; diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 6d4a42236f57..27056124d9e1 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -346,6 +346,7 @@ class BlockInfoCollector : private StmtVisitor { if (!ProducerCoversConsumer(buffer->shape, produced_region, consumed_region, &analyzer_)) { region_cover = false; + self_->block_info.at(consumer_block_sref).region_cover = region_cover; break; } } diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index cf4836e5361e..334fb988d775 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -223,6 +223,24 @@ def func_with_block_predicate() -> None: B[ax] = A[ax] + 1.0 +@T.prim_func +def inplace_func(data_io: T.Buffer[(64), "int32"]): + data_1d = T.alloc_buffer([64], dtype="int32") + for i0 in T.serial(64): + with T.block("copy_in"): + v0 = T.axis.remap("S", [i0]) + data_1d[v0] = data_io[v0] + for i0 in T.serial(1): + with T.block("ext_call"): + T.reads(data_1d[:64]) + T.writes(data_1d[:64]) + T.evaluate(T.call_extern("call_impl", data_1d.data, dtype="")) + for i0 in T.serial(64): + with T.block("copy_out"): + v0 = T.axis.remap("S", [i0]) + data_io[v0] = data_1d[v0] + + ########## Expected function after cache_read ########## @@ -414,15 +432,15 @@ def cache_read_multi_consumer_target() -> None: with T.block("A"): vi = T.axis.S(128, i * 16 + j) A[vi] = 1.0 - for j in T.grid(16): - with T.block("A"): - vi = T.axis.S(128, i * 16 + j) - A_global[vi] = A[vi] for j in T.grid(16): with T.block("B"): vi = T.axis.S(128, i * 16 + j) B[vi] = A[vi] + 1.0 + for i in T.grid(128): + with T.block("A"): + vi = T.axis.S(128, i) + A_global[vi] = A[vi] for i in T.grid(128): with T.block("C"): vi = T.axis.S(128, i) @@ -501,6 +519,35 @@ def cache_read_shape_int64(var_A: T.handle, var_C: T.handle) -> None: C[vi, vj] = B[vi, vj] + T.float32(1) +@T.prim_func +def cache_read_inplace(data_io: T.Buffer[64, "int32"]) -> None: + data_1d = T.alloc_buffer([64], dtype="int32") + data_io_local = T.alloc_buffer([64], dtype="int32", scope="local") + for ax0 in T.serial(64): + with T.block("data_io_local"): + v0 = T.axis.spatial(64, ax0) + T.reads(data_io[v0]) + T.writes(data_io_local[v0]) + data_io_local[v0] = data_io[v0] + for i0 in T.serial(64): + with T.block("copy_in"): + v0 = T.axis.spatial(64, i0) + T.reads(data_io_local[v0]) + T.writes(data_1d[v0]) + data_1d[v0] = data_io_local[v0] + for i0 in T.serial(1): + with T.block("ext_call"): + T.reads(data_1d[0:64]) + T.writes(data_1d[0:64]) + T.evaluate(T.call_extern("call_impl", data_1d.data, dtype="")) + for i0 in T.serial(64): + with T.block("copy_out"): + v0 = T.axis.spatial(64, i0) + T.reads(data_1d[v0]) + T.writes(data_io[v0]) + data_io[v0] = data_1d[v0] + + ########## Expected function after cache_write ########## @@ -876,6 +923,14 @@ def test_cache_read_fail_invalid_storage_scope(use_block_name): sch.cache_read(block_b, 0, "test_scope") +def test_inplace_cache_read(): + sch = tvm.tir.Schedule(inplace_func, debug_mask="all") + block = sch.get_block("copy_in") + sch.cache_read(block, 0, "local", [block]) + tvm.ir.assert_structural_equal(cache_read_inplace, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=inplace_func) + + ########## Testcases for cache_write ##########