Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 54 additions & 22 deletions src/tir/schedule/primitive/cache_read_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,16 @@ class CacheLocDetector : public StmtVisitor {
static void Detect(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_sref, CacheStageInfo* info) {
std::vector<StmtSRef> related_blocks;
for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) {
if (def->kind == DepKind::kRAW) {
related_blocks.push_back(def->dst);
// If consumer is specified, skip detecting the others
if (info->consumer_blocks.size() > 0) {
for (StmtSRef consumer : info->consumer_blocks) {
related_blocks.emplace_back(consumer);
}
} else {
for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) {
if (def->kind == DepKind::kRAW) {
related_blocks.push_back(def->dst);
}
}
}
if (!related_blocks.empty()) {
Expand Down Expand Up @@ -416,41 +423,37 @@ class CacheLocDetector : public StmtVisitor {

void VisitStmt_(const SeqStmtNode* seq_stmt) final {
bool previous_visited_block = visited_block_;
bool previous_visited_related = visited_related_;
visited_block_ = visited_related_ = false;
visited_block_ = false;

int pos = -1;
for (size_t i = 0; i < seq_stmt->size(); ++i) {
if (loc_pos_ != -1) {
break;
}
VisitStmt(seq_stmt->seq[i]);
// `pos` can be assigned only once when we visited `block_sref`
if (visited_block_ && visited_related_ && pos == -1) {
if (visited_block_ && visited_related_ && loc_pos_ == -1) {
// The offset of insert position from the block
pos = i;
loc_pos_ = i;
return;
} else if (visited_related_) {
// If meet the target consumer, stop searching
visited_block_ = visited_block_ || previous_visited_block;
return;
}
}
visited_block_ = visited_block_ || previous_visited_block;
visited_related_ = visited_related_ || previous_visited_related;
// Only we visited the writing block and any one of the related blocks
// That means that we have found the lowest ancestor
// of the block and any one of the related ones
if (visited_block_ && visited_related_ && loc_pos_ == -1) {
loc_pos_ = pos;
}
}

void VisitStmt_(const BlockNode* block) final {
// Only visit the current scope under buffer writer's parent block
if (block == scope_sref_->stmt) {
// The block visited is the current parent scope
StmtVisitor::VisitStmt_(block);
// Handling cache_read for input buffer
if (visited_block_ && visited_related_ && !loc_sref_.defined()) {
// Handling cases when insert outside any loop or cache_read for input buffer
if (visited_related_ && !loc_sref_.defined()) {
loc_sref_ = self_->stmt2ref.at(block);
if (loc_pos_ == -1) {
loc_pos_ = 1;
// Handling cache_read for input buffer
if (visited_block_ == false && loc_pos_ == -1) {
loc_pos_ = 0;
}
}
return;
Expand Down Expand Up @@ -980,6 +983,33 @@ class ReIndexRewriter : public StmtExprMutator {
Region region_;
};

void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root) {
class NotRegionCoverError : public ScheduleError {
public:
explicit NotRegionCoverError(IRModule mod, Block block) : mod_(mod), block_(block) {}
IRModule mod() const final { return mod_; }
String FastErrorString() const final {
return "ScheduleError: The scope root's region cover is not complete.";
}
String DetailRenderTemplate() const final {
return R"(The scope {0} 's region cover is not complete.
The region cover property require to hold for every of its child blocks
)";
}
Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
IRModule mod_;
Block block_;
};
BlockScope scope = self->GetBlockScope(scope_root);
for (const auto& kv : scope->dst2deps) {
const StmtSRef& consumer_block_sref = kv.first;
if (!self->block_info.at(consumer_block_sref).region_cover) {
const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root);
throw NotRegionCoverError(self->mod, GetRef<Block>(block));
}
}
}

/******** Implementation ********/

StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
Expand All @@ -1002,7 +1032,9 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
Buffer read_buffer =
GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, BufferIndexType::kRead);
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
// Check required region cover for cache_read
CheckRegionCover(self, scope_sref);
const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);

// Step 2. Create CacheStageInfo
Expand Down Expand Up @@ -1075,7 +1107,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
Buffer write_buffer =
GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, BufferIndexType::kWrite);
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);

// Step 2. Creating CacheStageInfo
CacheStageInfo info;
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ class BlockInfoCollector : private StmtVisitor {
if (!ProducerCoversConsumer(buffer->shape, produced_region, consumed_region,
&analyzer_)) {
region_cover = false;
self_->block_info.at(consumer_block_sref).region_cover = region_cover;
break;
}
}
Expand Down
63 changes: 59 additions & 4 deletions tests/python/unittest/test_tir_schedule_cache_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,24 @@ def func_with_block_predicate() -> None:
B[ax] = A[ax] + 1.0


# Test input for cache_read on an in-place workload: `data_io` is read by the first
# block and written by the last one, with `data_1d` as the intermediate buffer.
# Used as the schedule input in test_inplace_cache_read.
@T.prim_func
def inplace_func(data_io: T.Buffer[(64), "int32"]):
data_1d = T.alloc_buffer([64], dtype="int32")
# Step 1: copy every element of data_io into data_1d.
for i0 in T.serial(64):
with T.block("copy_in"):
v0 = T.axis.remap("S", [i0])
data_1d[v0] = data_io[v0]
# Step 2: extern call that is declared to both read and write all of data_1d.
for i0 in T.serial(1):
with T.block("ext_call"):
T.reads(data_1d[:64])
T.writes(data_1d[:64])
T.evaluate(T.call_extern("call_impl", data_1d.data, dtype=""))
# Step 3: copy data_1d back into data_io.
for i0 in T.serial(64):
with T.block("copy_out"):
v0 = T.axis.remap("S", [i0])
data_io[v0] = data_1d[v0]


########## Expected function after cache_read ##########


Expand Down Expand Up @@ -414,15 +432,15 @@ def cache_read_multi_consumer_target() -> None:
with T.block("A"):
vi = T.axis.S(128, i * 16 + j)
A[vi] = 1.0
for j in T.grid(16):
with T.block("A"):
vi = T.axis.S(128, i * 16 + j)
A_global[vi] = A[vi]
for j in T.grid(16):
with T.block("B"):
vi = T.axis.S(128, i * 16 + j)
B[vi] = A[vi] + 1.0

for i in T.grid(128):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to change this testcase? A common practice is not to change the existing test case if possible.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, indeed. But consider a case where the same buffer is accessed as R-W-R and we want to cache_read the second R: the first R should be ignored. This PR's solution is to specify the second R in consumer_blocks and to ignore any R that is not in consumer_blocks. So for this test case, block("B") is ignored. Since the purpose of this case is to cache_read block("C"), it seems more reasonable for the cache block to be placed next to block("C"). Any better solution is welcome.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @jwfromm since he is the author of the test case

with T.block("A"):
vi = T.axis.S(128, i)
A_global[vi] = A[vi]
for i in T.grid(128):
with T.block("C"):
vi = T.axis.S(128, i)
Expand Down Expand Up @@ -501,6 +519,35 @@ def cache_read_shape_int64(var_A: T.handle, var_C: T.handle) -> None:
C[vi, vj] = B[vi, vj] + T.float32(1)


# Expected module for test_inplace_cache_read: compared structurally against the
# schedule result after cache_read of block "copy_in" with storage scope "local".
# Relative to inplace_func, a "data_io_local" staging block is added before "copy_in".
@T.prim_func
def cache_read_inplace(data_io: T.Buffer[64, "int32"]) -> None:
data_1d = T.alloc_buffer([64], dtype="int32")
# Cache buffer in "local" scope.
data_io_local = T.alloc_buffer([64], dtype="int32", scope="local")
# Cache stage: copy data_io into data_io_local.
for ax0 in T.serial(64):
with T.block("data_io_local"):
v0 = T.axis.spatial(64, ax0)
T.reads(data_io[v0])
T.writes(data_io_local[v0])
data_io_local[v0] = data_io[v0]
# "copy_in" now reads from the cache buffer instead of data_io.
for i0 in T.serial(64):
with T.block("copy_in"):
v0 = T.axis.spatial(64, i0)
T.reads(data_io_local[v0])
T.writes(data_1d[v0])
data_1d[v0] = data_io_local[v0]
# Extern call that reads and writes all of data_1d.
for i0 in T.serial(1):
with T.block("ext_call"):
T.reads(data_1d[0:64])
T.writes(data_1d[0:64])
T.evaluate(T.call_extern("call_impl", data_1d.data, dtype=""))
# "copy_out" still writes to data_io directly.
for i0 in T.serial(64):
with T.block("copy_out"):
v0 = T.axis.spatial(64, i0)
T.reads(data_1d[v0])
T.writes(data_io[v0])
data_io[v0] = data_1d[v0]


########## Expected function after cache_write ##########


Expand Down Expand Up @@ -876,6 +923,14 @@ def test_cache_read_fail_invalid_storage_scope(use_block_name):
sch.cache_read(block_b, 0, "test_scope")


def test_inplace_cache_read():
    # cache_read on the "copy_in" block of an in-place workload, naming that same
    # block as the only consumer, must produce `cache_read_inplace`.
    schedule = tvm.tir.Schedule(inplace_func, debug_mask="all")
    copy_in_block = schedule.get_block("copy_in")
    consumers = [copy_in_block]
    schedule.cache_read(copy_in_block, 0, "local", consumers)
    scheduled_main = schedule.mod["main"]
    tvm.ir.assert_structural_equal(cache_read_inplace, scheduled_main)
    # The recorded trace must replay on the original function.
    verify_trace_roundtrip(sch=schedule, mod=inplace_func)


########## Testcases for cache_write ##########


Expand Down