Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
dc5e2ef
[TIR] Support asynchronous stages in software pipeline transform
masahi Jun 14, 2022
1054638
Support interleaved async producers separated by a consumer
masahi Jun 16, 2022
fcb75a5
clean up
masahi Jun 17, 2022
ab78c35
adding doc
masahi Jun 17, 2022
b2ade84
adding doc
masahi Jun 17, 2022
769632b
simplifying
masahi Jun 17, 2022
67f81a7
make wait count computation a two pass process
masahi Jun 20, 2022
8c01129
commit_stage -> commit_queue, wait_stage -> wait_queue
masahi Jun 24, 2022
9d0f7d6
make async_commit_queue special scope stmt
masahi Jun 27, 2022
a5a4bfc
codegen async_commit_queue in cuda
masahi Jun 27, 2022
6e0b442
clean up
masahi Jun 27, 2022
75f8a38
clean up
masahi Jun 27, 2022
8f04f70
Move block predicate outside of commit_queue
masahi Jun 27, 2022
bc4f073
updating test
masahi Jun 28, 2022
c80bbd9
test updated
masahi Jun 28, 2022
7e50d2f
changed async_wait to an annotation
masahi Jul 20, 2022
b4289a3
update doc
masahi Jul 21, 2022
be51062
update meaning of software_pipeline_async_stages
masahi Jul 21, 2022
d446581
update test
masahi Jul 21, 2022
8228587
fixing codegen
masahi Jul 21, 2022
07dd0b2
more fix
masahi Jul 21, 2022
dca56c6
remove one of tests that have async and sync ops in the same stage
masahi Jul 21, 2022
8a0ff51
format
masahi Jul 21, 2022
468566f
lint and other fix
masahi Jul 25, 2022
bf13acf
Define attr::software_pipeline_async_stages
masahi Jul 26, 2022
787f608
populate wait count in a separate function
masahi Jul 26, 2022
44bbb12
fold variabel consumed into AsyncStateLocal
masahi Jul 26, 2022
d4ae91a
introduce CompletePipelineLoopStatements function for further refactor
masahi Jul 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,27 @@ constexpr const char* device_scope = "device_scope";
*/
constexpr const char* async_scope = "async_scope";

/*!
* \brief Annotations for invoking and synchronizing asynchronous operations.

* Synchronization is done in terms of "queue": It is an abstract entity associated
* with each asynchronous unit, and it tracks invocations and completions of asynchronous
* operations in the FIFO order.
*
* Similarly to PTX instructions commit_group and wait_group, these annotations express
* synchronization by "counting":
*
* async_commit_queue(i): Group one or more invocations of async operations in the given scope,
* and "commit" (or push) them to the queue i. A group of operations committed together is
* awaited as one chunk. Groups committed to the same queue complete in the FIFO order.
*
* async_wait_queue(i, N): Block until only N most recent committed groups are still in-flight at
* the queue i. N does not have to be a constant, but some backends may require a constant count.
*/
constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";

/*!
* \brief Mark that the shape of TensorCore fragment
*/
Expand Down Expand Up @@ -1483,6 +1504,12 @@ constexpr const char* software_pipeline_stage = "software_pipeline_stage";
/*! \brief Mark the order of a statement in the software pipeline */
constexpr const char* software_pipeline_order = "software_pipeline_order";

/*! \brief List stages in the software pipeline that should run asynchronously
* \note All statements in the provided stages are assumed to have asynchronous
* semantics (e.g. CUDA async global to shared memory copy).
*/
constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages";

/*! \brief Mark the buffers which is const access and can be transformed layout. */
constexpr const char* layout_free_buffers = "layout_free_buffers";

Expand Down
18 changes: 18 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,24 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* layout_str = op->value.as<StringImmNode>();
fragment_layouts[buffer] = layout_str->value;
} else if (op->attr_key == tir::attr::async_commit_queue_scope) {
const IntImmNode* queue_id = op->value.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
this->VisitStmt(op->body);
auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
this->VisitExpr(commit_group, this->stream);
return;
} else if (op->attr_key == tir::attr::async_wait_queue_scope) {
auto wait_attrs = GetAsyncWaitAttributes(op);
auto queue_id = wait_attrs.first.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
auto wait_cnt = wait_attrs.second;
auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
this->VisitExpr(wait_group, this->stream);
auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner);
this->VisitStmt(inner->body);
return;
}
CodeGenC::VisitStmt_(op);
}
Expand Down
Loading