[Bug] Employing double_buffer in tensor core conv2d results in error and precision dropping#10652
[Bug] Employing double_buffer in tensor core conv2d results in error and precision dropping#10652DzAvril wants to merge 3 commits into
Conversation
|
cc @FrozenGene |
|
Hi @FrozenGene, the pipeline of check |
I found this PR #10687 related to my CI check problem. I will try again after this PR merged. |
…annot find allocated buffer for buffer'
|
@FrozenGene All tests are passing. Please review. |
|
There is a bug that double_buffer doesn't work in tensor core conv2d template. The test code is the same as the attachment above. After lowered I found buffer // inject_double_buffer.cc:DoubleBufferDetector
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::double_buffer_scope) {
touched_.insert(op->node.as<VarNode>());
StmtExprVisitor::VisitStmt_(op);
} else {
StmtExprVisitor::VisitStmt_(op);
}
}As tensor core conv2d template employs tensor intrin, this brings a call node // inject_double_buffer.cc:DoubleBufferDetector
void VisitExpr_(const VarNode* op) final {
if (touched_.count(op)) {
touched_.erase(op);
}
}As the code shows, void VisitExpr_(const CallNode* op) final {
// do not visit var in tvm_access_ptr
if (op->op.same_as(builtin::tvm_access_ptr())) {
return;
}
StmtExprVisitor::VisitExpr_(op);
}Reference to origin PR: #405 |
|
After fixing the two bugs above, double buffer works in the final Cuda code, but it causes precision dropping. for (i, 0, 100) {
allocate B[float32 * 4]
for (i, 0, 4) {
B[i] = A[((i*4) + i)]
}
for (i, 0, 4) {
A[i] = (B[i] + 1.000000f)
}
}Target allocate B[float32 * 2 * 4]
for (i, 0, 4) {
B[i] = A[i]
}
for (i, 0, 99) {
// prefetch next iteration
for (i, 0, 4) {
B[((((i + 1) % 2)*4) + i)] = A[(((i*4) + i) + 4)]
}
for (i, 0, 4) {
A[i] = (B[(((i % 2)*4) + i)] + 1.000000f)
}
}
for (i, 0, 4) {
A[i] = (B[(i + 4)] + 1.000000f)
}In the target code, the size of B is doubled. In the second for loop, first read data into the last half part of B and then process the first half part of B. So computation can hide the latency of reading global memory. for (k.outer.outer.outer: int32, 0, 2) {
if ((k.outer.outer.outer + 1) < 3) {
attr [im2col_reshape.shared] "double_buffer_write" = 1;
for (ax0.ax1.outer.fused.outer.outer.outer_1: int32, 0, 4) {
attr [IterVar(threadIdx.z, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
attr [IterVar(threadIdx.y, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 2;
attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
im2col_reshape.shared[(broadcast((floormod((k.outer.outer.outer + 1), 2)*2560), 8) + ramp(((((ax0.ax1.outer.fused.outer.outer.outer_1*640) + (threadIdx.y*320)) + (floordiv(threadIdx.x, 4)*40)) + (floormod(threadIdx.x, 4)*8)), 1, 8))] = (int8x8*)placeholder_7[ramp((((((((blockIdx.x*12288) + (ax0.outer.outer*6144)) + (ax0.ax1.outer.fused.outer.outer.outer_1*1536)) + (threadIdx.y*768)) + (floordiv(threadIdx.x, 4)*96)) + ((k.outer.outer.outer + 1)*32)) + (floormod(threadIdx.x, 4)*8)), 1, 8)]
}
}
if ((k.outer.outer.outer + 1) < 3) {
attr [placeholder_reshape.shared] "double_buffer_write" = 1;
attr [IterVar(threadIdx.z, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
attr [IterVar(threadIdx.y, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 2;
attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
if (((threadIdx.y*8) + floordiv(threadIdx.x, 4)) < 8) {
if (((threadIdx.y*32) + threadIdx.x) < 32) {
if (threadIdx.y < 1) {
placeholder_reshape.shared[(broadcast((floormod((k.outer.outer.outer + 1), 2)*320), 8) + ramp((((threadIdx.y*320) + (floordiv(threadIdx.x, 4)*40)) + (floormod(threadIdx.x, 4)*8)), 1, 8))] = (int8x8*)placeholder_8[ramp((((((threadIdx.y*768) + (blockIdx.y*768)) + (floordiv(threadIdx.x, 4)*96)) + ((k.outer.outer.outer + 1)*32)) + (floormod(threadIdx.x, 4)*8)), 1, 8)]
}
}
}
}
for (k.outer.inner: int32, 0, 2) {
allocate(im2col_reshape.shared.wmma.matrix_a: Pointer(wmma.matrix_a int8), int8, [64, 16]), storage_scope = wmma.matrix_a {
for (ax0.outer: int32, 0, 2) {
@tir.tvm_load_matrix_sync(im2col_reshape.shared.wmma.matrix_a, 32, 8, 16, ax0.outer, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), im2col_reshape.shared, ((ax0.outer*1280) + (k.outer.inner*16)), 1280, 1, dtype=handle), 40, "row_major", dtype=handle)
}
allocate(placeholder_reshape.shared.wmma.matrix_b: Pointer(wmma.matrix_b int8), int8, [8, 16]), storage_scope = wmma.matrix_b {
@tir.tvm_load_matrix_sync(placeholder_reshape.shared.wmma.matrix_b, 32, 8, 16, 0, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=int8), placeholder_reshape.shared, (k.outer.inner*16), 320, 1, dtype=handle), 40, "col_major", dtype=handle)
for (i.c.outer: int32, 0, 2) {
@tir.tvm_mma_sync(implicit_gemm_conv.wmma.accumulator, i.c.outer, im2col_reshape.shared.wmma.matrix_a, i.c.outer, placeholder_reshape.shared.wmma.matrix_b, 0, implicit_gemm_conv.wmma.accumulator, i.c.outer, dtype=handle)
}
}
}
}
}In the first iterate in the loop
As I guessed in the previous comment, the author didn't expect double buffer as a parameter of a call node. So the solution is processing double buffer in call node. // inject_double_buffer:DoubleBufferInjector
PrimExpr VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
const VarNode* buf = op->args[1].as<VarNode>();
auto it = dbuffer_info_.find(buf);
if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second;
ICHECK(e.stride.defined());
ICHECK(e.switch_read_var.defined());
Array<PrimExpr> args;
// dtype
args.push_back(op->args[0]);
// data
args.push_back(op->args[1]);
// offset
args.push_back(e.switch_read_var * e.stride + op->args[2]);
// extent
args.push_back(op->args[3]);
// rw_mask
args.push_back(op->args[4]);
return Call(op->dtype, op->op, args);
} else {
return GetRef<PrimExpr>(op);
}
} else {
return StmtExprMutator::VisitExpr_(op);
}
} |
|
PR #10066 has a more elegant way to bring the double buffer into the final generated code. This PR will be closed in a few days if no one has interested in it. |
When employing double_buffer in tensor core conv2d template, such as:
It results in such assert:
Attached is my test script.
test_double_buffer.py.txt
As described above, buffer
im2col_reshape.sharedwill be doubled. There are several passes for lowering double_buffer. One of them is storage_flatten. The lowered tir fed to storage_flatten pass has two types of attributes include bufferim2col_reshape.shared.Noticed that at the begin of this pass the pointers of buffer
im2col_reshape.sharedare same. There is a pass buffer_stride in pass storage_flatten which has such process for attribute statement.In branch
op->attr_key == attr::buffer_bind_scope, bufferim2col_reshape.sharedis passed toWithStridesand modified in it. Then bufferim2col_reshape.sharedchanges to[buffer(buffer, 0x55e3524f44d0), buffer(im2col_reshape.shared, 0x55e352500860)], op->attr_key: buffer_bind_scope. Noticed again the pointer toim2col_reshape.sharedis changed afterwards. And because there is none branch for processing attributedouble_buffer_scope, this cause mismatch of pointer to bufferim2col_reshape.shared.So add branch below will solve this issue.