Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,13 @@ TVM_DLL Pass PointerValueTypeRewrite();
*/
TVM_DLL Pass HoistIfThenElse();

/*!
* \brief Lower cross-thread reduction from thread
* bindings to intrinsic function calls.
* \return The pass.
*/
TVM_DLL Pass LowerCrossThreadReduction();

/*!
* \brief Lower block init stmt into IfThenElse stmts
* \return The pass.
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,18 @@ def HoistIfThenElse(variant: Optional[str] = None):
return _ffi_api.HoistIfThenElse() # type: ignore


def LowerCrossThreadReduction():
"""Lower cross-thread reduction from thread bindings to
intrinsic function calls.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerCrossThreadReduction() # type: ignore


def LowerInitBlock():
"""Lower block init stmt into IfThenElse statements.

Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(tir::transform::TextureFlatten());
pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
pass_list.push_back(tir::transform::LowerCrossThreadReduction());
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
Expand Down
50 changes: 49 additions & 1 deletion src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@
#ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_
#define TVM_TIR_SCHEDULE_ANALYSIS_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/schedule/state.h>

#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -323,14 +328,57 @@ struct ProducerConsumerSplit {
*/
Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);

/******** Reduction Block Related ********/

/*!
* \brief Convert the `init` and `body` of the input block to BufferStores
* \param self The schedule state
* \param block The block to be analyzed
* \return The BufferStores of the `init` and `body` of the input block
* \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same
* buffer
*/
std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(
const Optional<ScheduleState>& self, const Block& block);

/*!
* \brief Check whether the input array of IterVars only contains data-parallel and reduction block
* iters
* \param iters The input array of IterVars to be checked
* \return A boolean indicating whether the input array of IterVars only contains data-parallel and
* reduction block iters
*/
bool ContainsOnlyDataParAndReductionBlockIter(const Array<IterVar>& iters);

/*!
* \brief Check whether the block's reduction block iters are not used to index the block's output
* buffers
* \param block The block to be checked
* \return A boolean indicating whether the block's reduction block iters are not used to index the
* block's output buffer
*/
bool ReductionIterNotIndexOutputBuffer(const Block& block);

/*!
* \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative
* reducer, and extract the combiner lhs and combiner rhs
* \param self The schedule state
* \param identity The reduction identity to be analyzed
* \param combiner The reduction combiner to be analyzed
* \return The corresponding CommReducer, the combiner lhs and the combiner rhs
* \throw ScheduleError If no corresponding commutative reducer can be matched
*/
std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
const Optional<ScheduleState>& self, const PrimExpr& identity, const BufferStore& combiner);

/******** Commutative Reducer ********/

/*!
* \brief Get the list of the registered reducer-getter functions
* \return The list of the registered reducer-getter functions
* \sa ReducerRegistry
*/
std::vector<TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters();
std::vector<runtime::TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters();

/*!
* \brief Given the input identity and the combiner BufferStore of a reduction, extract the
Expand Down
Loading