diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 2060fb7920ed..5dd4103e8202 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1448,6 +1448,27 @@ constexpr const char* device_scope = "device_scope"; */ constexpr const char* async_scope = "async_scope"; +/*! + * \brief Annotations for invoking and synchronizing asynchronous operations. + + * Synchronization is done in terms of "queue": It is an abstract entity associated + * with each asynchronous unit, and it tracks invocations and completions of asynchronous + * operations in the FIFO order. + * + * Similarly to PTX instructions commit_group and wait_group, these annotations express + * synchronization by "counting": + * + * async_commit_queue(i): Group one or more invocations of async operations in the given scope, + * and "commit" (or push) them to the queue i. A group of operations committed together is + * awaited as one chunk. Groups committed to the same queue complete in the FIFO order. + * + * async_wait_queue(i, N): Block until only N most recent committed groups are still in-flight at + * the queue i. N does not have to be a constant, but some backends may require a constant count. +*/ +constexpr const char* async_commit_queue_scope = "async_commit_queue_scope"; +constexpr const char* async_wait_queue_scope = "async_wait_queue_scope"; +constexpr const char* async_wait_inflight_count = "async_wait_inflight_count"; + /*! * \brief Mark that the shape of TensorCore fragment */ @@ -1483,6 +1504,12 @@ constexpr const char* software_pipeline_stage = "software_pipeline_stage"; /*! \brief Mark the order of a statement in the software pipeline */ constexpr const char* software_pipeline_order = "software_pipeline_order"; +/*! \brief List stages in the software pipeline that should run asynchronously + * \note All statements in the provided stages are assumed to have asynchronous + * semantics (e.g. CUDA async global to shared memory copy). 
+ */ +constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages"; + /*! \brief Mark the buffers which is const access and can be transformed layout. */ constexpr const char* layout_free_buffers = "layout_free_buffers"; diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 616e75f2e776..3ea6f8d9edbd 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -917,6 +917,24 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { const VarNode* buffer = op->node.as(); const StringImmNode* layout_str = op->value.as(); fragment_layouts[buffer] = layout_str->value; + } else if (op->attr_key == tir::attr::async_commit_queue_scope) { + const IntImmNode* queue_id = op->value.as(); + ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; + this->VisitStmt(op->body); + auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); + this->VisitExpr(commit_group, this->stream); + return; + } else if (op->attr_key == tir::attr::async_wait_queue_scope) { + auto wait_attrs = GetAsyncWaitAttributes(op); + auto queue_id = wait_attrs.first.as(); + ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; + auto wait_cnt = wait_attrs.second; + auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); + this->VisitExpr(wait_group, this->stream); + auto inner = op->body.as(); + ICHECK(inner); + this->VisitStmt(inner->body); + return; } CodeGenC::VisitStmt_(op); } diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index b4a597fe97d8..227935bf72dd 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -25,6 +25,8 @@ #include #include +#include + #include "../../support/utils.h" #include "../schedule/utils.h" #include "./ir_utils.h" @@ -60,13 +62,14 @@ Block 
MakeBlock(const Stmt& body, const Map& buffer_data_to_buffer) return block; } -/*! Structure that represents the stage and order of the software pipeline component. */ -struct PipelineStageOrder { +/*! Structure that represents the provided annotation per block or loop. */ +struct PipelineAnnotation { int stage; int order; + bool async; }; -using PipelineInfo = std::unordered_map; +using PipelineInfo = std::unordered_map; struct BufferAccessInfo { int def = -1; // the defining stage of the buffer @@ -99,6 +102,8 @@ class PipelineOpaqueAccessRewriter { static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync(); static const auto& mma_sync = builtin::tvm_mma_sync(); static const auto& access_ptr = builtin::tvm_access_ptr(); + static const auto& ptx_ldmatrix = builtin::ptx_ldmatrix(); + static const auto& ptx_mma = builtin::ptx_mma(); if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) { const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[0])); auto it = buffer_remap_.find(buffer); @@ -122,24 +127,11 @@ class PipelineOpaqueAccessRewriter { } return Call(call->dtype, call->op, new_args, call->span); } else if (call->op.same_as(access_ptr)) { - const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[1])); - auto it = buffer_remap_.find(buffer); - if (it != buffer_remap_.end()) { - Array new_args = call->args; - const Buffer& new_buffer = (*it).second; - const PrimExpr& old_index = call->args[2]; - PrimExpr offset; - if (new_buffer->strides.empty()) { - offset = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), buffer->shape); - } else { - offset = new_buffer->strides[0]; - } - PrimExpr new_index = - old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset; - new_args.Set(2, new_index); - return Call(call->dtype, call->op, new_args, call->span); - } + return RewriteBufferAccess(call, {1}); + } else if 
(call->op.same_as(ptx_mma)) { + return RewriteBufferAccess(call, {6, 8, 10}); + } else if (call->op.same_as(ptx_ldmatrix)) { + return RewriteBufferAccess(call, {3}); } return call; } @@ -166,6 +158,32 @@ class PipelineOpaqueAccessRewriter { return new_buffer_offset; } + PrimExpr RewriteBufferAccess(const Call& call, const std::vector arg_indices) { + auto product = [](const Array& input) { + return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), input); + }; + Array new_args = call->args; + for (int i : arg_indices) { + const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[i])); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + const Buffer& new_buffer = (*it).second; + const PrimExpr& old_index = call->args[i + 1]; + PrimExpr offset; + if (new_buffer->strides.empty()) { + offset = product(buffer->shape); + } else { + offset = new_buffer->strides[0]; + } + PrimExpr new_index = + old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset; + new_args.Set(i + 1, new_index); + } + } + return Call(call->dtype, call->op, new_args, call->span); + } + const Map& buffer_data_to_buffer_; const Map& buffer_remap_; const For& pipeline_loop_; @@ -494,6 +512,267 @@ class PipelineRewriter : public StmtExprMutator { return Buffer(new_buffer); } + // Per-stage states that need to be tracked across pipeline prologue, body, and epilogue. + struct AsyncStateGlobal { + // Buffers that this stage asynchronously writes. + std::unordered_set dst_buffers; + // An imaginary index that the latest async operation associated with this stage has written + // into. Only valid if all associated predicates are true, so that we can count the number of + // async invocations exactly. When it is valid, it is the "sum of extents of loops that have + // been executed" - 1, e.g. for epilogue it is prologue extent + body extent - 1. 
This + // is only needed to compute wait count for epilogue without async producers. + Optional producer_head{PrimExpr(-1)}; + + bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; } + }; + + // Per-stage states that are local to each of pipeline prologue, body, and epilogue. + struct AsyncStateLocal { + struct { + // The index into a list of blocks, where async_wait_queue should be attached at the + // beginning. + int insert_before; + // in_flight_count would be a more precise name, but the implementation uses wait_count for + // brevity. + PrimExpr wait_count{nullptr}; + + bool valid() const { return wait_count.defined(); } + } pending_wait; + + // Destination buffers of async operations that have been encountered so far in the loop + // + // for (size_t i = 0; i < new_blocks.size(); ++i) { + // ... + // } + // + // This is for tracking which async operations have been issued at the "current" iteration, up + // until a point where we encounter a consumer of async result buffers. This is used to decide + // if the producer_head of each buffer points to a copy written in the current or previous + // iteration. + std::unordered_set seen; + + // A symbolic expression representing the index the latest async operation associated with this + // stage has written into, at the "current" iteration. + Optional producer_head; + // The predicate of BlockRealize containing the async operation of this stage. + Optional predicate; + // Indices into a list of blocks, where async_commit_queue scope should be attached. + // If multiple async producers are interleaved with their consumer in between, we need separate + // async_commit_queue for each producer. Thus, we need multiple sets of indices. + std::vector> commit_groups; + + // This is set to true when we reach a stage that consumes this async stage. + bool consumed{false}; + }; + + /*! Structure holding intermediate information for pipeline loop rewriting. 
*/ + struct RewrittenBlockInfo { + int stage; + PrimExpr predicate; + Block block; + PrimExpr access_index; + bool is_async; + }; + + // Determine where to insert async_wait and the corresponding wait count. + void PopulateWaitCounts(const std::vector& new_blocks, + arith::Analyzer* ana_normalized, + const std::unordered_map& buffer_to_commit_group, + std::map* async_states_local) { + for (size_t i = 0; i < new_blocks.size(); ++i) { + if (new_blocks[i].is_async) { + // Record the fact that we have encountered these write buffers. + for (auto write_region : new_blocks[i].block->writes) { + (*async_states_local)[new_blocks[i].stage].seen.insert(write_region->buffer.get()); + } + } + + int producer_stage_idx = -1; + for (auto read_region : new_blocks[i].block->reads) { + for (auto kv : async_states) { + if (kv.first <= new_blocks[i].stage && kv.second.writes(read_region->buffer)) { + // Found an earlier stage where read_region->buffer was asynchronously written + ICHECK(producer_stage_idx == -1 || producer_stage_idx == kv.first) + << "A dependency on multiple async stages is not supported"; + producer_stage_idx = kv.first; + } + } + } + + if (producer_stage_idx == -1) continue; + + // The following logic has become complicated to handle case like this: + // + // for i in range(13): + // # Stage 0 + // async_commit_queue(0): + // async_scope: + // A_shared[(i + 3) % 4] = A[...] + // + // + // # Stage 1 + // async_wait_queue(0, 5): + // compute(A_shared[i], B_shared[i]) + // + // # Stage 0 + // async_commit_queue(0) + // async_scope: + // B_shared[(i + 3) % 4] = B[...] + // + // + // Here, multiple async producers in the same stage are interleaved with their consumer in + // between. Since each buffer is associated with different commit groups, the wait_count + // before the consumer should be bigger than the simpler case: + // + // for i in range(13): + // # Stage 0 + // async_commit_queue(0): + // async_scope: + // A_shared[(i + 3) % 4] = A[...] 
+ // B_shared[(i + 3) % 4] = B[...] + // + // # Stage 1 + // async_wait_queue(0, 3): + // compute(A_shared[i], B_shared[i]) + // + // The correct wait_count can be determined by considering each commit group separately, and + // summing "per-commit" wait_counts. + // + // From A_shared's perspective, it allows for (i + 3) - i async commit groups to be in + // flight while from B_shared's perspective, the producer head at compute points to the copy + // done by the previous iteration, so its wait_count is calculated as ((i - 1) + 3) - i. The + // sum of the two wait_counts gives 5. + + auto& dep_local_state = (*async_states_local)[producer_stage_idx]; + const auto num_commit_group = dep_local_state.commit_groups.size(); + std::vector> producer_head_per_commit; + + if (num_commit_group == 0) { + // Epilogue, no async producer. Since "local" producer_head is not available, use + // "global" producer_head. + ICHECK(!dep_local_state.producer_head); + producer_head_per_commit.push_back(async_states[producer_stage_idx].producer_head); + } else { + ICHECK(dep_local_state.producer_head); + std::vector need_wait_count(num_commit_group, true); + + for (auto read_region : new_blocks[i].block->reads) { + if (!async_states[producer_stage_idx].writes(read_region->buffer)) continue; + auto commit_group_id = buffer_to_commit_group.at(read_region->buffer.get()); + if (!need_wait_count[commit_group_id]) continue; + + if (!dep_local_state.seen.count(read_region->buffer.get())) { + // Multiple async producers interleaved: The most recent async write is from the + // previous iteration. This is the B_shared case above. 
+ producer_head_per_commit.push_back(dep_local_state.producer_head.value() - 1); + } else { + // Normal case + producer_head_per_commit.push_back(dep_local_state.producer_head.value()); + } + + need_wait_count[commit_group_id] = false; + } + } + + auto wait_count = [=, &ana_normalized]() { + auto sum = PrimExpr(0); + for (auto producer_head : producer_head_per_commit) { + if (producer_head && ana_normalized->CanProve(producer_head.value() >= 0)) { + // Here, new_blocks[i].access_index corresponds to "consumer_head". + // The difference of producer_head and consumer_head is precisely the number of + // async commit groups that can still be in flight after this wait. + sum += analyzer_.Simplify(producer_head.value() - new_blocks[i].access_index); + } else { + // The precise count cannot be determined, give up. + return PrimExpr(0); + } + } + return sum; + }(); + + auto& pending_wait = dep_local_state.pending_wait; + + if (!pending_wait.valid()) { + pending_wait = {static_cast(i), wait_count}; + } else if (analyzer_.CanProve(wait_count < pending_wait.wait_count)) { + // Coalesce multiple wait_queue if the later one allows fewer in-flight ops. + pending_wait = {pending_wait.insert_before, wait_count}; + } + } + } + + // Given pipelined blocks and async-related information, generate final loop statements with async + // scopes (if any). 
+ Array CompletePipelineLoopStatements( + const std::vector& blocks, + const std::map& async_states_local, + arith::Analyzer* ana_normalized) const { + std::vector new_blocks = blocks; + std::vector commit_group_indices(new_blocks.size(), -1); + for (const auto& kv : async_states_local) { + const int stage_id = kv.first; + const AsyncStateLocal& state = kv.second; + + if (!state.commit_groups.empty()) { + for (size_t i = 0; i < state.commit_groups.size(); ++i) { + for (size_t j = 0; j < state.commit_groups[i].size(); ++j) { + ICHECK(state.commit_groups[i][0] + j < new_blocks.size()); + commit_group_indices[state.commit_groups[i][0] + j] = stage_id; + } + } + } + + if (state.pending_wait.valid()) { + auto attach_wait_scope = [&new_blocks](int i, int stage_id, PrimExpr wait_count) { + auto& block = new_blocks[i].block; + BlockNode* n = block.CopyOnWrite(); + auto zero = make_zero(DataType::Int(32)); + n->body = + AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id, + AttrStmt(zero, tir::attr::async_wait_inflight_count, wait_count, n->body)); + }; + + if (state.predicate && !ana_normalized->CanProve(state.predicate.value())) { + // If the async operation that this wait_queue is waiting on is predicated, and we cannot + // prove that the predicate is always true, the precise wait count is only valid + // at iterations where the predicate is true; + auto wait_count = Call(DataType::Int(32), builtin::if_then_else(), + {state.predicate.value(), state.pending_wait.wait_count, 0}); + attach_wait_scope(state.pending_wait.insert_before, stage_id, wait_count); + } else { + attach_wait_scope(state.pending_wait.insert_before, stage_id, + state.pending_wait.wait_count); + } + } + } + + Array stmts; + + for (size_t i = 0; i < new_blocks.size();) { + if (commit_group_indices[i] == -1) { + // A synchronous block, not part of any commit group + stmts.push_back(BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block)); + ++i; + } else { + Array group_bodies; + auto 
stage_id = commit_group_indices[i]; + auto predicate = new_blocks[i].predicate; + for (; i < commit_group_indices.size() && commit_group_indices[i] == stage_id; ++i) { + ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate)) + << "Predicates in the same stage are expected to be identical"; + group_bodies.push_back(new_blocks[i].block->body); + } + auto body = group_bodies.size() > 1 ? SeqStmt(group_bodies) : group_bodies[0]; + auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)), + tir::attr::async_commit_queue_scope, stage_id, body); + auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_); + stmts.push_back(BlockRealize({}, predicate, new_block)); + } + } + + return stmts; + } + /*! * \brief Emit the pipeline loop in the given range. * \param start The start of the range @@ -502,7 +781,6 @@ class PipelineRewriter : public StmtExprMutator { * \return The result loop. */ Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop) { - Array stmts; PrimExpr new_loop_var; PrimExpr extent = end - start; @@ -519,6 +797,19 @@ class PipelineRewriter : public StmtExprMutator { analyzer_.Bind(Downcast(new_loop_var), Range(start, end)); } + // In contrast to analyzer_ which is bound to [start, end), this one is bound to + // the "normalized" range, [pipeline_loop_->min, extent). 
+ arith::Analyzer ana_normalized; + if (!is_unit_loop) { + ana_normalized.Bind(Downcast(new_loop_var), Range(pipeline_loop_->min, extent)); + } + + std::vector new_blocks; + + // Async related + std::map async_states_local; + std::unordered_map buffer_to_commit_group; + for (const Block& block : ordered_stmts_) { int stage = pipeline_info_.at(block).stage; PrimExpr skewed_loop_var = new_loop_var - stage; @@ -530,20 +821,78 @@ class PipelineRewriter : public StmtExprMutator { Block new_block = Downcast(PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, max_stage_ != 1, fragment_info_)(block)); - Map subst_map; - if (is_unit_loop) { - subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var); - } else { - // normalize loop range - PrimExpr delta = start - pipeline_loop_->min; - subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + delta); + + PrimExpr delta = start - pipeline_loop_->min; + // This variable corresponds to + // - "producer_head" if this stage is an async producer + // - "consumer_head" if this stage reads from asynchronously written buffers. + PrimExpr normalized_access_index = is_unit_loop ? skewed_loop_var : skewed_loop_var + delta; + + // Adjust the block predicate and the body according to the final loop bound + // [pipeline_loop_->min, extent). + if (!is_unit_loop) { Var loop_iter = Downcast(new_loop_var); - inbound = Substitute(inbound, Map{{loop_iter, loop_iter + delta}}); + inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}}); + } + + new_block = Downcast( + Substitute(new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); + + if (pipeline_info_[block].async) { + auto& local_state = async_states_local[stage]; + + int commit_group_id = -1; + if (local_state.commit_groups.empty() || local_state.consumed) { + // consumed == true means there is already a consumer stage waiting for an + // earlier async operation of this stage. In such cases, we make multiple commit_queue + // for this stage. 
+ commit_group_id = local_state.commit_groups.size(); + local_state.commit_groups.push_back({new_blocks.size()}); + } else { + // This is the case when one commit_queue groups multiple async blocks. + // with commit_queue(stage): + // async_scope: + // A_shared[...] = ... + // async_scope: + // B_shared[...] = ... + + commit_group_id = local_state.commit_groups.size() - 1; + local_state.commit_groups.back().push_back(new_blocks.size()); + } + + for (auto write_region : new_block->writes) { + async_states[stage].dst_buffers.insert(write_region->buffer.get()); + buffer_to_commit_group[write_region->buffer.get()] = commit_group_id; + } + + local_state.producer_head = normalized_access_index; + + if (!local_state.predicate || ana_normalized.CanProve(local_state.predicate.value())) { + local_state.predicate = inbound; + } else if (local_state.predicate) { + local_state.predicate = ana_normalized.Simplify(local_state.predicate.value() & inbound); + } + + BlockNode* n = new_block.CopyOnWrite(); + n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope, 1, n->body); + } + + new_blocks.push_back( + {stage, inbound, new_block, normalized_access_index, pipeline_info_[block].async}); + + for (auto read_region : new_block->reads) { + for (auto kv : async_states) { + int producer_stage_id = kv.first; + if (producer_stage_id <= stage && kv.second.writes(read_region->buffer)) { + async_states_local[producer_stage_id].consumed = true; + } + } } - new_block = Downcast(Substitute(new_block, subst_map)); - stmts.push_back(BlockRealize({}, inbound, new_block)); } + PopulateWaitCounts(new_blocks, &ana_normalized, buffer_to_commit_group, &async_states_local); + auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local, &ana_normalized); + Stmt new_loop{nullptr}; if (stmts.empty()) { @@ -559,6 +908,24 @@ class PipelineRewriter : public StmtExprMutator { new_loop = For(Downcast(new_loop_var), pipeline_loop_->min, extent, unroll_loop ? 
ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop)); } + + // Update producer heads in the global async states. + for (const auto& kv : async_states_local) { + const int stage_id = kv.first; + const AsyncStateLocal& state = kv.second; + + if (state.predicate && ana_normalized.CanProve(state.predicate.value()) && + async_states[stage_id].producer_head) { + // Advance the "global" producer head if it is still valid and we know exactly how much we + // can increment + async_states[stage_id].producer_head = + async_states[stage_id].producer_head.value() + extent; + } else { + // Otherwise, invalidate the global producer head + async_states[stage_id].producer_head = NullOpt; + } + } + return BlockRealize({}, Bool(true), MakeBlock(std::move(new_loop), buffer_data_to_buffer_)); } @@ -572,6 +939,7 @@ class PipelineRewriter : public StmtExprMutator { int max_stage_ = -1; Map buffer_remap_; Array ordered_stmts_; + std::map async_states; }; /*! @@ -727,11 +1095,23 @@ class PipelineInjector : private StmtExprMutator { Downcast>(op->annotations.at(attr::software_pipeline_order)); CHECK_EQ(pipeline_stages.size(), original_order.size()); CHECK_EQ(pipeline_orders.size(), original_order.size()); + + std::unordered_set pipeline_async_stages; + if (auto annot = op->annotations.Get(attr::software_pipeline_async_stages)) { + for (auto s : Downcast>(annot)) { + pipeline_async_stages.insert(s->value); + } + } + for (size_t i = 0; i < pipeline_stages.size(); i++) { - PipelineStageOrder stage_order{/*stage=*/static_cast(pipeline_stages[i]->value), - /*order=*/static_cast(pipeline_orders[i]->value)}; + int stage = static_cast(pipeline_stages[i]->value); + bool is_async = pipeline_async_stages.find(stage) != pipeline_async_stages.end(); + PipelineAnnotation stage_order{stage, + /*order=*/static_cast(pipeline_orders[i]->value), + is_async}; pipeline_info.emplace(original_order[i], stage_order); } + ValidatePipelineBody(pipeline_info, original_order); // Step 4: Rewrite the 
pipeline body. diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 700c9931bba0..66b04bd67892 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -441,5 +441,12 @@ void ConditionalBoundsContext::ExitWithScope() { } } +std::pair GetAsyncWaitAttributes(const AttrStmtNode* op) { + ICHECK(op && op->attr_key == tir::attr::async_wait_queue_scope); + auto inner = op->body.as(); + ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count); + return std::make_pair(op->value, inner->value); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 2234cc22bcfa..d89ee3619699 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -35,6 +35,7 @@ #include #include #include +#include #include namespace tvm { @@ -306,6 +307,10 @@ struct FragmentInfo { */ std::unordered_map GetTensorCoreFragmentInfo(const Stmt& stmt); +// Return the queue id and the in-flight count associated with the given +// attr::async_wait_queue_scope annotation. 
+std::pair GetAsyncWaitAttributes(const AttrStmtNode* op); + } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index c8c77b8badf5..ce0d9b87c433 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -21,6 +21,7 @@ * \file remove_no_op.cc * \brief Remove no op from the stmt */ +#include #include #include #include @@ -30,6 +31,8 @@ #include +#include "ir_utils.h" + namespace tvm { namespace tir { @@ -44,7 +47,20 @@ class NoOpRemover : public StmtMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_debug_skip_region") { return MakeEvaluate(0); + } else if (op->attr_key == attr::async_wait_queue_scope) { + auto wait_attrs = GetAsyncWaitAttributes(op); + auto wait_cnt = wait_attrs.second; + arith::Analyzer ana; + if (ana.CanProve(wait_cnt < 0)) { + // A negative wait count can arise if it depends on a loop variable. + // For example, a wait count 1 - i can be negative after loop unrolling. + // We assume that such wait is a nop. + auto inner = op->body.as(); + ICHECK(inner); + return StmtMutator::VisitStmt(inner->body); + } } + Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt; diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index ce3f8fd3e3ac..954f4f7cc47d 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -230,6 +230,48 @@ class ThreadSyncPlanner : public StorageAccessVisitor { StorageScope sync_scope_; }; +// There are cases where necessary syncthreads is not inserted by ThreadSyncInserter. +// For example, syncthreads is needed after async_wait_queue in the second loop below, +// but since ThreadSyncInserter is not aware of the asynchronous semantics, it cannot tell +// that the syncthreads is needed there. 
+// +// // Pipeline prologue +// for i in range(125): +// async_commit_queue(0): +// async_scope: +// shared[(i + 3) % 4] = ... +// ... +// +// // Pipeline Epilogue +// for i in range(3): +// async_wait_queue(0, 2 - i): +// local[...] = shared[(i + 125) % 4] + +// This class adds syncthreads after all async_wait_queue. That includes syncthreads that +// can be inserted by ThreadSyncInserter as well, but ThreadSyncInserter will not insert +// duplicate syncthreads if it finds an existing one at the synchronization point. +class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator { + public: + explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope) : sync_scope_(sync_scope) {} + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::async_wait_queue_scope) { + auto sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), + {StringImm(sync_scope_.to_string())})); + auto inner = op->body.as(); + ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count); + auto zero = make_zero(DataType::Int(32)); + auto new_body = SeqStmt({sync, inner->body}); + return AttrStmt(zero, tir::attr::async_wait_queue_scope, op->value, + AttrStmt(zero, tir::attr::async_wait_inflight_count, inner->value, new_body)); + } + return StmtExprMutator::VisitStmt_(op); + } + + private: + StorageScope sync_scope_; +}; + class ThreadSyncInserter : public StmtExprMutator { public: ThreadSyncInserter(StorageScope sync_scope, const std::unordered_set& syncs) @@ -384,6 +426,9 @@ class ThreadSyncInserter : public StmtExprMutator { Stmt ThreadSync(Stmt stmt, std::string storage_scope) { StorageScope sync_scope = StorageScope::Create(storage_scope); + if (sync_scope.rank == StorageRank::kShared && sync_scope.tag == "") { + stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt); + } ThreadSyncPlanner planner(sync_scope); planner(stmt); return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt)); diff --git 
a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 2f08249ed76f..edaeb7c9b639 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -92,26 +92,32 @@ def transformed_trivial_pipeline( C[tx, 0] = B[0, tx, 0] + T.float32(1) -@T.prim_func -def simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): - for tx in T.thread_binding(0, 16, thread="threadIdx.x"): - for i in T.serial( - 0, - 16, - annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}, - ): - with T.block(): - T.reads(A[tx, i]) - T.writes(C[tx, i]) - B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - with T.block(): +def gen_simple_compute(num_stages): + @T.prim_func + def simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, num_stages], + "software_pipeline_order": [0, 1], + }, + ): + with T.block("compute"): T.reads(A[tx, i]) - T.writes(B[tx, 0]) - B[tx, 0] = A[tx, i] * T.float32(2) - with T.block(): - T.reads(B[tx, 0]) T.writes(C[tx, i]) - C[tx, i] = B[tx, 0] + T.float32(1) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + return simple_compute @T.prim_func @@ -156,7 +162,7 @@ def three_stage_compute(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "software_pipeline_order": [0, 1, 2], }, ): - with T.block(): + with T.block("compute"): T.reads(A[tx, i]) T.writes(D[tx, i]) B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") @@ -991,7 +997,7 @@ def 
simple_compute_missing_annotation( def test_simple_compute(): - _check(simple_compute, transformed_simple_compute) + _check(gen_simple_compute(1), transformed_simple_compute) def test_trivial_pipeline(): @@ -1034,15 +1040,322 @@ def test_error_missing_annotation(): _check_error(simple_compute_missing_annotation) -@tvm.testing.requires_cuda -def test_three_stage_gemm(): - N = K = M = 4096 - i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1] +def test_simple_compute_async(): + mod = tvm.IRModule.from_expr(gen_simple_compute(1)) + sch = tvm.tir.Schedule(mod) - def is_ampere_or_newer(): - arch = tvm.contrib.nvcc.get_target_compute_version() - major, _ = tvm.contrib.nvcc.parse_compute_version(arch) - return major >= 8 + _, loop = sch.get_loops(sch.get_block("compute")) + sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0]) + mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod) + + @T.prim_func + def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None: + for tx in T.thread_binding(16, thread="threadIdx.x"): + with T.block(): + T.reads(A[tx, 0:16]) + T.writes(C[tx, 0:16]) + B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, 0]) + T.writes(B[0, tx, 0]) + with T.attr(0, "async_commit_queue_scope", 0): + with T.attr(0, "async_scope", 1): + B[0 % 2, tx, 0] = A[tx, 0] * T.float32(2) + with T.block(): + T.reads(A[tx, 1:16], B[0:2, tx, 0]) + T.writes(B[0:2, tx, 0], C[tx, 0:15]) + for i in T.serial(15): + with T.block(): + T.where(i + 1 < 16) + T.reads(A[tx, i + 1]) + T.writes(B[(i + 1) % 2, tx, 0]) + with T.attr(0, "async_commit_queue_scope", 0): + with T.attr(0, "async_scope", 1): + B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) + with T.block(): + T.where(i + 1 - 1 < 16) + T.reads(B[(i - 1 + 1) % 2, tx, 0]) + T.writes(C[tx, i - 1 + 1]) + with T.attr(0, "async_wait_queue_scope", 0): + with T.attr(0, "async_wait_inflight_count", 1): + C[tx, i - 
1 + 1] = B[(i - 1 + 1) % 2, tx, 0] + T.float32(1) + with T.block(): + T.reads(B[15 % 2, tx, 0]) + T.writes(C[tx, 15]) + with T.attr(0, "async_wait_queue_scope", 0): + with T.attr(0, "async_wait_inflight_count", 0): + C[tx, 15] = B[15 % 2, tx, 0] + T.float32(1) + + tvm.ir.assert_structural_equal(mod["main"], ref, True) + + mod = tvm.IRModule.from_expr(gen_simple_compute(3)) + sch = tvm.tir.Schedule(mod) + + _, loop = sch.get_loops(sch.get_block("compute")) + sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0]) + mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod) + + @T.prim_func + def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None: + for tx in T.thread_binding(16, thread="threadIdx.x"): + with T.block(): + T.reads(A[tx, 0:16]) + T.writes(C[tx, 0:16]) + B = T.alloc_buffer([4, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, 0:3]) + T.writes(B[0:3, tx, 0]) + for i in T.unroll(3): + with T.block(): + T.where(i < 16) + T.reads(A[tx, i]) + T.writes(B[i % 4, tx, 0]) + T.attr(0, "async_commit_queue_scope", 0) + T.attr(0, "async_scope", 1) + B[i % 4, tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(A[tx, 3:16], B[0:4, tx, 0]) + T.writes(B[0:4, tx, 0], C[tx, 0:13]) + for i in T.serial(13): + with T.block(): + T.where(i + 3 < 16) + T.reads(A[tx, i + 3]) + T.writes(B[(i + 3) % 4, tx, 0]) + T.attr(0, "async_commit_queue_scope", 0) + T.attr(0, "async_scope", 1) + B[(i + 3) % 4, tx, 0] = A[tx, i + 3] * T.float32(2) + with T.block(): + T.where(i + 3 - 3 < 16) + T.reads(B[0:4, tx, 0]) + T.writes(C[tx, i - 3 + 3]) + with T.attr(0, "async_wait_queue_scope", 0): + with T.attr(0, "async_wait_inflight_count", 3): + C[tx, i - 3 + 3] = B[(i - 3 + 3) % 4, tx, 0] + T.float32(1) + with T.block(): + T.reads(B[0:4, tx, 0]) + T.writes(C[tx, 13:16]) + for i in T.unroll(3): + with T.block(): + T.where(i + 16 - 3 < 16) + T.reads(B[0:4, tx, 0]) + T.writes(C[tx, i - 3 + 16]) + with T.attr(0, 
"async_wait_queue_scope", 0): + with T.attr(0, "async_wait_inflight_count", 2 - i): + C[tx, i - 3 + 16] = B[(i - 3 + 16) % 4, tx, 0] + T.float32(1) + + tvm.ir.assert_structural_equal(mod["main"], ref, True) + + +def test_async_producer_interleaving(): + @T.prim_func + def simple_compute( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], + ): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in range(16): + with T.block("compute"): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + A_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + B_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(A_shared[tx, 0]) + A_shared[tx, 0] = A[tx, i] + with T.block(): + T.reads(B[tx, i]) + T.writes(B_shared[tx, 0]) + B_shared[tx, 0] = B[tx, i] + with T.block(): + T.reads(A_shared[tx, 0], B_shared[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = A_shared[tx, 0] + B_shared[tx, 0] + + mod = tvm.IRModule.from_expr(simple_compute) + sch = tvm.tir.Schedule(mod) + + _, loop = sch.get_loops(sch.get_block("compute")) + sch.annotate(loop, ann_key="software_pipeline_stage", ann_val=[0, 0, 3]) + sch.annotate(loop, ann_key="software_pipeline_order", ann_val=[0, 2, 1]) + sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0]) + mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod) + + @T.prim_func + def ref( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], + ) -> None: + for tx in T.thread_binding(16, thread="threadIdx.x"): + with T.block(): + T.reads(A[tx, 0:16], B[tx, 0:16]) + T.writes(C[tx, 0:16]) + A_shared = T.alloc_buffer([4, 16, 1], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([4, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, 0:3], B[tx, 0:3]) + T.writes(A_shared[0:3, tx, 0], B_shared[0:3, tx, 0]) + for i in T.unroll(3): + with 
T.block(): + T.where(i < 16) + T.reads(A[tx, i], B[tx, i]) + T.writes(A_shared[i % 4, tx, 0], B_shared[i % 4, tx, 0]) + with T.attr(0, "async_commit_queue_scope", 0): + with T.attr(0, "async_scope", 1): + A_shared[i % 4, tx, 0] = A[tx, i] + with T.attr(0, "async_scope", 1): + B_shared[i % 4, tx, 0] = B[tx, i] + with T.block(): + T.reads(A[tx, 3:16], A_shared[0:4, tx, 0], B_shared[0:4, tx, 0], B[tx, 3:16]) + T.writes(A_shared[0:4, tx, 0], C[tx, 0:13], B_shared[0:4, tx, 0]) + for i in T.serial(13): + with T.block(): + T.where(i + 3 < 16) + T.reads(A[tx, i + 3]) + T.writes(A_shared[(i + 3) % 4, tx, 0]) + with T.attr(0, "async_commit_queue_scope", 0): + with T.attr(0, "async_scope", 1): + A_shared[(i + 3) % 4, tx, 0] = A[tx, i + 3] + with T.block(): + T.where(i + 3 - 3 < 16) + T.reads(A_shared[0:4, tx, 0], B_shared[0:4, tx, 0]) + T.writes(C[tx, i - 3 + 3]) + with T.attr(0, "async_wait_queue_scope", 0): + with T.attr(0, "async_wait_inflight_count", 5): + C[tx, i - 3 + 3] = ( + A_shared[(i - 3 + 3) % 4, tx, 0] + + B_shared[(i - 3 + 3) % 4, tx, 0] + ) + with T.block(): + T.where(i + 3 < 16) + T.reads(B[tx, i + 3]) + T.writes(B_shared[(i + 3) % 4, tx, 0]) + with T.attr(0, "async_commit_queue_scope", 0): + with T.attr(0, "async_scope", 1): + B_shared[(i + 3) % 4, tx, 0] = B[tx, i + 3] + with T.block(): + T.reads(A_shared[0:4, tx, 0], B_shared[0:4, tx, 0]) + T.writes(C[tx, 13:16]) + for i in T.unroll(3): + with T.block(): + T.where(i + 16 - 3 < 16) + T.reads(A_shared[0:4, tx, 0], B_shared[0:4, tx, 0]) + T.writes(C[tx, i - 3 + 16]) + with T.attr(0, "async_wait_queue_scope", 0): + with T.attr(0, "async_wait_inflight_count", 2 - i): + C[tx, i - 3 + 16] = ( + A_shared[(i - 3 + 16) % 4, tx, 0] + + B_shared[(i - 3 + 16) % 4, tx, 0] + ) + + tvm.ir.assert_structural_equal(mod["main"], ref, True) + + +def test_three_stage_compute_two_stage_async(): + mod = tvm.IRModule.from_expr(three_stage_compute) + sch = tvm.tir.Schedule(mod) + + _, loop = sch.get_loops(sch.get_block("compute")) + 
sch.annotate(loop, ann_key="software_pipeline_async_stages", ann_val=[0, 1]) + + mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod) + + @T.prim_func + def ref(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]) -> None: + for tx in T.thread_binding(16, thread="threadIdx.x"): + with T.block(): + T.reads(A[tx, 0:16]) + T.writes(D[tx, 0:16]) + B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + C = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, 0:2], B[0:2, tx, 0]) + T.writes(B[0:2, tx, 0], C[0:2, tx, 0]) + for i in T.unroll(2): + with T.block(): + T.where(i < 16) + T.reads(A[tx, i]) + T.writes(B[i % 2, tx, 0]) + with T.attr(0, "async_commit_queue_scope", 0): + with T.attr(0, "async_scope", 1): + B[i % 2, tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.where(1 <= i and i - 1 < 16) + T.reads(B[(i + 1) % 2, tx, 0]) + T.writes(C[(i + 1) % 2, tx, 0]) + with T.attr(0, "async_commit_queue_scope", 1): + with T.attr(0, "async_wait_queue_scope", 0): + with T.attr(0, "async_wait_inflight_count", 1): + with T.attr(0, "async_scope", 1): + C[(i - 1) % 2, tx, 0] = B[ + (i - 1) % 2, tx, 0 + ] + T.float32(2) + with T.block(): + T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0]) + T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14]) + for i in T.serial(14): + with T.block(): + T.where(i + 2 < 16) + T.reads(A[tx, i + 2]) + T.writes(B[i % 2, tx, 0]) + with T.attr(0, "async_commit_queue_scope", 0): + with T.attr(0, "async_scope", 1): + B[(i + 2) % 2, tx, 0] = A[tx, i + 2] * T.float32(2) + with T.block(): + T.where(i + 2 - 1 < 16) + T.reads(B[(i + 1) % 2, tx, 0]) + T.writes(C[(i + 1) % 2, tx, 0]) + with T.attr(0, "async_commit_queue_scope", 1): + with T.attr(0, "async_wait_queue_scope", 0): + with T.attr(0, "async_wait_inflight_count", 1): + with T.attr(0, "async_scope", 1): + C[(i - 1 + 2) % 2, tx, 0] = B[ + (i - 1 + 2) % 2, tx, 0 + ] + T.float32(2) + with T.block(): + T.where(i + 2 - 2 < 16) + 
T.reads(C[0:2, tx, 0]) + T.writes(D[tx, i - 2 + 2]) + with T.attr(0, "async_wait_queue_scope", 1): + with T.attr(0, "async_wait_inflight_count", 1): + D[tx, i - 2 + 2] = C[(i - 2 + 2) % 2, tx, 0] + T.float32(1) + with T.block(): + T.reads(B[0:2, tx, 0], C[0:2, tx, 0]) + T.writes(C[0:2, tx, 0], D[tx, 14:16]) + for i in T.unroll(2): + with T.block(): + T.where(i + 16 - 1 < 16) + T.reads(B[(i + 1) % 2, tx, 0]) + T.writes(C[(i + 1) % 2, tx, 0]) + with T.attr(0, "async_commit_queue_scope", 1): + with T.attr(0, "async_wait_queue_scope", 0): + with T.attr(0, "async_wait_inflight_count", 0 - i): + with T.attr(0, "async_scope", 1): + C[(i - 1 + 16) % 2, tx, 0] = B[ + (i - 1 + 16) % 2, tx, 0 + ] + T.float32(2) + with T.block(): + T.where(i + 16 - 2 < 16) + T.reads(C[0:2, tx, 0]) + T.writes(D[tx, i - 2 + 16]) + with T.attr(0, "async_wait_queue_scope", 1): + with T.attr( + 0, + "async_wait_inflight_count", + T.if_then_else(i + 16 - 1 < 16, 1, 0, dtype="int32"), + ): + D[tx, i - 2 + 16] = C[(i - 2 + 16) % 2, tx, 0] + T.float32(1) + + tvm.ir.assert_structural_equal(mod["main"], ref, True) + + +N = K = M = 4096 + + +def get_mma_schedule(): + i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [16, 2, 4, 1, 2], [128, 2, 1] def index_map(i, j): return ( @@ -1055,7 +1368,7 @@ def index_map(i, j): te_workload.matmul(N, M, K, in_dtype="float16", out_dtype="float32") ) - sch = mma_schedule( + return mma_schedule( workload, 16, "float16", @@ -1074,13 +1387,11 @@ def index_map(i, j): "shared.dyn", ) - k0 = sch.get_loops(sch.get_block("C_o_update"))[3] - - sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3]) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) - if is_ampere_or_newer(): - f = tvm.build(sch.mod["main"], target="cuda") +def build_and_run(sch): + if tvm.testing.is_ampere_or_newer(): + with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}): + f = tvm.build(sch.mod["main"], target="cuda") dev = tvm.device("cuda", 0) a_np = 
np.random.uniform(size=(N, K)).astype("float16") @@ -1093,5 +1404,93 @@ def index_map(i, j): tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) +@tvm.testing.requires_cuda +def test_async_pipelined_mma_gemm_simple(): + sch = get_mma_schedule() + + k0 = sch.get_loops(sch.get_block("C_o_update"))[3] + + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + sch.annotate(k0, ann_key="software_pipeline_async_stages", ann_val=[0]) + + seq = tvm.transform.Sequential( + [ + tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(), + tvm.tir.transform.ConvertBlocksToOpaque(), + tvm.tir.transform.UnifyThreadBinding(), + tvm.tir.transform.LowerMatchBuffer(), + tvm.tir.transform.InjectSoftwarePipeline(), + ] + ) + mod = seq(sch.mod) + + pipeline = mod["main"].body.block.body.body.body.body.body.block.body[1].block.body + prologue, body, epilogue = pipeline + + commit_queue_scope = prologue.block.body.body.block.body + assert len(commit_queue_scope.body) == 2 + assert commit_queue_scope.value == 0 + + commit_queue_scope = body.block.body.body[0].block.body + assert len(commit_queue_scope.body) == 2 + assert commit_queue_scope.value == 0 + + assert body.block.body.body[1].block.body.body.attr_key == "async_wait_inflight_count" + assert body.block.body.body[1].block.body.body.value == 3 + + assert epilogue.block.body.body.block.body.body.attr_key == "async_wait_inflight_count" + assert str(epilogue.block.body.body.block.body.body.value) == "(2 - i2_0_0: int32)" + + build_and_run(sch) + + +@tvm.testing.requires_cuda +def test_async_nested_pipeline_mma_gemm_ideal_annotation(): + sch = get_mma_schedule() + + k0 = sch.get_loops(sch.get_block("C_o_update"))[3] + k1 = sch.get_loops(sch.get_block("C_o_update"))[4] + + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 2, 3, 3]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 3, 2, 4]) + sch.annotate(k0, 
ann_key="software_pipeline_async_stages", ann_val=[0]) + + sch.annotate(k1, ann_key="software_pipeline_stage", ann_val=[0, 0, 1]) + sch.annotate(k1, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + + seq = tvm.transform.Sequential( + [ + tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(), + tvm.tir.transform.ConvertBlocksToOpaque(), + tvm.tir.transform.UnifyThreadBinding(), + tvm.tir.transform.LowerMatchBuffer(), + tvm.tir.transform.InjectSoftwarePipeline(), + ] + ) + mod = seq(sch.mod) + + pipeline = mod["main"].body.block.body.body.body.body.body.block.body[1].block.body + prologue, body, epilogue = pipeline + + commit_queue_scope = prologue.block.body.body[0].block.body + assert len(commit_queue_scope.body) == 2 + assert commit_queue_scope.value == 0 + + assert prologue.block.body.body[1].block.body.body.attr_key == "async_wait_inflight_count" + assert prologue.block.body.body[1].block.body.body.value == 2 + + commit_queue_scope = body.block.body.body[0].block.body + assert len(commit_queue_scope.body) == 2 + assert commit_queue_scope.value == 0 + + assert body.block.body.body[1].block.body.body.attr_key == "async_wait_inflight_count" + assert body.block.body.body[1].block.body.body.value == 2 + + assert str(epilogue.block.body.body[0].block.body.body.value) == "(1 - i2_0_0: int32)" + + build_and_run(sch) + + if __name__ == "__main__": tvm.testing.main()