Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1597,6 +1597,11 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_expl
/*! \brief Mark auto-unroll setting on the block. */
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";

/*! \brief Mark the target loop for decompose reduction. This serves as a hint for
* RewriteReductionBlock postprocessor.
*/
constexpr const char* meta_schedule_decompose_point = "meta_schedule.decompose_point";

/*! \brief Mark that a block should be further rewritten using tensorization. */
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";

Expand Down
20 changes: 14 additions & 6 deletions src/meta_schedule/postproc/rewrite_reduction_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,23 @@ struct ReductionBlockFinder : private StmtVisitor {
/*!
* \brief Find the innermost loop that the `init` of the input block could be decomposed to
* \param block_sref The StmtSRef of the block to be decomposed
* \return The index of the innermost loop where the `init` of the input block could be decomposed,
* or -1 if the `init` does not need to be decomposed.
* \return A pair of the loop index and the boolean flag. The index indicates the innermost loop
* where the `init` of the input block could be decomposed, or -1 if the `init` does not need to
* be decomposed. The boolean flag indicates whether these is a hint of the decompose point in the
* loop annotations that need to be unannotated.
*/
int FindDecomposePoint(const StmtSRef& block_sref) {
std::pair<int, bool> FindDecomposePoint(const StmtSRef& block_sref) {
Array<StmtSRef> loop_srefs = GetLoops(block_sref);
int n = loop_srefs.size();
for (int i = 0; i < n; ++i) {
if (HasAnn(loop_srefs[i], attr::meta_schedule_decompose_point, true)) {
return {i, true};
}
if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) {
return i;
return {i, false};
}
}
return -1;
return {-1, false};
}

} // namespace tir
Expand Down Expand Up @@ -133,14 +138,17 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) {
for (const auto& kv : results) {
const tir::StmtSRef& block_sref = kv.first;
const String& global_var_name = kv.second;
int decompose_point = tir::FindDecomposePoint(block_sref);
auto [decompose_point, need_unannotate] = tir::FindDecomposePoint(block_sref);
if (decompose_point == -1) {
continue;
}
tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name);
Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]);

if (need_unannotate) {
sch->Unannotate(loop_rvs[decompose_point], tir::attr::meta_schedule_decompose_point);
}
// Rewrite auto tensorization related annotations
if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize).defined()) {
// Remove tensorization annotation as it shouldn't be propagated to the init block.
Expand Down
14 changes: 14 additions & 0 deletions src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
inline std::vector<State> AddWriteReuseTensorCore(TensorCoreState state) const;
// Subrule: Add software pipeline
inline std::vector<State> AddSoftwarePipeline(TensorCoreState state) const;
// Subrule: Add annotation of outermost reduction loop
inline std::vector<State> AnnotateOutermostReduction(State state) const;

// Override ApplySubRules to apply tensorization-specific sub-rules
std::vector<State> ApplySubRules(std::vector<State> states) final;
Expand Down Expand Up @@ -225,6 +227,8 @@ std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<Sta
return TransformForTensorization(Downcast<TensorCoreState>(state));
});
states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); });
states =
SubRule(std::move(states), [&](State state) { return AnnotateOutermostReduction(state); });
states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); });
states = SubRule(std::move(states), [&](State state) {
return AddWriteReuseTensorCore(Downcast<TensorCoreState>(state));
Expand Down Expand Up @@ -556,6 +560,16 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorizat
return {std::move(state)};
}

inline std::vector<State> MultiLevelTilingTensorCoreNode::AnnotateOutermostReduction(
State state) const {
Schedule& sch = state->sch;
if (r_indices_.size()) {
LoopRV outermost_reduction_loop = state->tiles[r_indices_.front()].front();
sch->Annotate(outermost_reduction_loop, tir::attr::meta_schedule_decompose_point, Integer(1));
}
return {std::move(state)};
}

ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
Array<Map<String, String>> intrin_groups, String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,47 @@ def test_rewrite_softmax():
tvm.ir.assert_structural_equal(sch.mod, Softmax_cross_thread_reduction)


def test_rewrite_unit_reduction_loop():
# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
@T.prim_func
def before(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
for i0, j0 in T.grid(4, 4):
for k0 in T.serial(1, annotations={"meta_schedule.decompose_point": 1}):
for i1, j1, k1 in T.grid(4, 4, 16):
with T.block("C"):
i = T.axis.spatial(16, i0 * 4 + i1)
j = T.axis.spatial(16, j0 * 4 + j1)
k = T.axis.reduce(16, k1)
with T.init():
C[i, j] = T.float32(0)
C[i, j] = C[i, j] + A[i, k] * B[k, j]

@T.prim_func
def expected(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
for i0, j0 in T.grid(4, 4):
for i1_init, j1_init in T.grid(4, 4):
with T.block("C_init"):
i = T.axis.spatial(16, i0 * 4 + i1_init)
j = T.axis.spatial(16, j0 * 4 + j1_init)
C[i, j] = T.float32(0)
for k0, i1, j1, k1 in T.grid(1, 4, 4, 16):
with T.block("C_update"):
i = T.axis.spatial(16, i0 * 4 + i1)
j = T.axis.spatial(16, j0 * 4 + j1)
k = T.axis.reduce(16, k1)
C[i, j] = C[i, j] + A[i, k] * B[k, j]

# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on

target = _target()
ctx = _create_context(before, target)
sch = tir.Schedule(before, debug_mask="all")
sch.enter_postproc()
tvm.ir.assert_structural_equal(sch.mod["main"], expected)


if __name__ == "__main__":
test_rewrite_tiled_matmul()
test_rewrite_softmax()
Loading