Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,47 @@ std::vector<runtime::TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters()
bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs);

/******** Misc ********/

/*!
* \brief Checks if a block could be successfully computed inline into its consumer
* \param self The schedule state
* \param block_sref The block to be checked
* \return A boolean indicating whether the block could be successfully computed inline
* \note The check runs the compute-inline implementation in check-only mode, which returns
* before the final replacement on the schedule state; a `tvm::runtime::Error` raised
* along the way is reported as `false`.
*/
bool CanComputeInline(const ScheduleState& self, const StmtSRef& block_sref);

/*!
* \brief Checks if a block could be successfully computed inline into its producer
* \param self The schedule state
* \param block_sref The block to be checked
* \return A boolean indicating whether the block could be successfully reverse computed inline
* \note The check runs the reverse-compute-inline implementation in check-only mode, which
* returns before the final replacement on the schedule state; a `tvm::runtime::Error`
* raised along the way is reported as `false`.
*/
bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref);

/*!
* \brief Checks if a producer block could be successfully computed at the specific loop.
* \param self The schedule state
* \param block_sref The block to be moved
* \param loop_sref The loop that the block is to be moved to
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
* \return A boolean indicating whether the block could be successfully computed at the
* specific loop
*/
bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops);

/*!
* \brief Checks if a consumer block could be successfully computed at the specific loop.
* \param self The schedule state
* \param block_sref The block to be moved
* \param loop_sref The loop that the block is to be moved to
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
* \return A boolean indicating whether the block could be successfully reverse computed at the
* specific loop
*/
bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops);

} // namespace tir
} // namespace tvm

Expand Down
46 changes: 39 additions & 7 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,8 @@ void CalculateProvidedRequiredRegions(

template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops) {
const StmtSRef& loop_sref, bool preserve_unit_loops,
arith::Analyzer* analyzer, bool check_only = false) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
// Step 1. Bunch of checks
Expand All @@ -463,11 +464,10 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
BlockScope scope = self->GetBlockScope(scope_root_sref);
Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
arith::Analyzer analyzer;
// Check condition 3): `block` and `loop` are under the same scope,
// and `loop` is not the ancestor of `block`
NotInSameScopeError::CheckAndBindLoopDomain(self, block_sref, loop_sref, scope_root_sref,
&analyzer);
analyzer);
// Check condition 4): `block` is not an output block
if (is_compute_at) {
CheckNotOutputBlock(self, block_sref, scope_root_sref);
Expand Down Expand Up @@ -501,29 +501,61 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
CalculateBlockVarDomain(/*iter_vars=*/block->iter_vars,
/*provided_regions=*/std::move(provided_regions),
/*required_regions=*/std::move(required_regions),
/*analyzer=*/&analyzer);
/*analyzer=*/analyzer);
// Step 6. Create the new scope according to the iteration domain
reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms),
/*preserve_unit_loops=*/preserve_unit_loops);
Block new_scope_root = Downcast<Block>(reconstructor(scope_root));

// Step 7. Do the actual replacement
if (check_only) {
return;
}
self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
// Step 8. Update the cached flags
BlockInfo& block_info = self->block_info[block_sref];
block_info.affine_binding = IsAffineBinding(
/*realize=*/reconstructor.new_block_realize_,
/*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent)),
/*analyzer=*/&analyzer);
/*analyzer=*/analyzer);
}

void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops);
arith::Analyzer analyzer;
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer);
}

void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops);
arith::Analyzer analyzer;
ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer);
}

bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
arith::Analyzer analyzer;
try {
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer, true);
} catch (const tvm::runtime::Error& e) {
return false;
}
return true;
}

bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops) {
arith::Analyzer analyzer;
try {
ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer, true);
} catch (const tvm::runtime::Error& e) {
return false;
}
return true;
}

/******** InstructionKind Registration ********/
Expand Down
66 changes: 59 additions & 7 deletions src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,27 @@ class NotSingleReadWriteBuffer : public ScheduleError {
bool is_read_;
Block block_;

static Buffer GetSingleRead(const ScheduleState& self, const Block& block) {
if (block->reads.size() != 1) {
static Buffer GetSingleRead(const ScheduleState& self, const Block& block,
const StmtSRef& scope_root_sref) {
const std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual>&
buffer_writers = self->block_info.at(scope_root_sref).scope->buffer_writers;
const BufferNode* read_buffer = nullptr;
for (const BufferRegion& read_region : block->reads) {
const BufferNode* buffer = read_region->buffer.get();
if (buffer == read_buffer) {
continue;
}
if (buffer_writers.count(GetRef<Buffer>(buffer)) > 0) {
if (read_buffer != nullptr) {
throw NotSingleReadWriteBuffer(self->mod, true, block);
}
read_buffer = buffer;
}
}
if (read_buffer == nullptr) {
throw NotSingleReadWriteBuffer(self->mod, true, block);
}
return block->reads[0]->buffer;
return GetRef<Buffer>(read_buffer);
}

static Buffer GetSingleWrite(const ScheduleState& self, const Block& block) {
Expand Down Expand Up @@ -167,7 +183,7 @@ class OpaqueAccessError : public ScheduleError {
* \brief The base class of the inliner, which handles:
* 1) Substitute a subtree with the specific block being inlined
* 2) Update the block signature to reflect the changes of read/write/allocated buffers
* 3) Maintain a list of index variables and their substition of the buffer being inlined
* 3) Maintain a list of index variables and their substitution of the buffer being inlined
*/
class BaseInliner : public StmtExprMutator {
protected:
Expand Down Expand Up @@ -526,7 +542,8 @@ class ReverseComputeInliner : public BaseInliner {
PrimExpr producer_rhs_{nullptr};
};

void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
bool check_only = false) {
const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref);
Block producer_block = GetRef<Block>(_producer_block);
Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block);
Expand All @@ -535,6 +552,7 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
/*require_stage_pipeline=*/true,
/*require_subtree_compact_dataflow=*/false);
// Step 2. Check completeness
CheckNotOutputBlock(self, producer_block_sref, scope_root_sref);
CheckCompleteBlock(self, producer_block_sref, scope_root_sref);
// Step 3. Analyze the block body
ComputeInliner inliner(inlined_buffer, producer_block, scope_root_sref);
Expand All @@ -550,17 +568,35 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
throw OpaqueAccessError(self->mod, scope_root_sref);
}
// Step 6. Do the real mutation on the AST and the sref tree in the schedule state
if (check_only) {
return;
}
self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
}

void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) {
void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
  // Mutating entry point: run the shared implementation with the dry-run switch off.
  ComputeInlineImpl(self, producer_block_sref, /*check_only=*/false);
}

bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_sref) {
try {
ComputeInlineImpl(self, producer_block_sref, true);
} catch (const tvm::runtime::Error& e) {
return false;
}
return true;
}

void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref,
bool check_only = false) {
const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref);
Block consumer_block = GetRef<Block>(_consumer_block);
Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block);
// Step 1. Get the scope block
StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, //
/*require_stage_pipeline=*/true,
/*require_subtree_compact_dataflow=*/false);
Buffer inlined_buffer =
NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref);
// Step 2. Check completeness
CheckCompleteBlock(self, consumer_block_sref, scope_root_sref);
// Step 3. Check if the consumer has a single complete producer
Expand All @@ -579,9 +615,25 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre
throw OpaqueAccessError(self->mod, scope_root_sref);
}
// Step 7. Do the real mutation on the AST and the sref tree in the schedule state
if (check_only) {
return;
}
self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
}

bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) {
try {
ReverseComputeInlineImpl(self, block_sref, true);
} catch (const tvm::runtime::Error& e) {
return false;
}
return true;
}

void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) {
  // Mutating entry point: run the shared implementation with the dry-run switch off.
  ReverseComputeInlineImpl(self, consumer_block_sref, /*check_only=*/false);
}

/******** InstructionKind Registration ********/

struct ComputeInlineTraits : public UnpackedInstTraits<ComputeInlineTraits> {
Expand Down
29 changes: 29 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,28 @@ def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
B[vi] = A_cache[vi] * 2.0 + 1.0


@T.prim_func
def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None:
# Test workload: a 512x512x512 matmul into the intermediate buffer `C`, followed by an
# element-wise max(C, 0) written to `compute`, which is bound to a function parameter.
A = T.match_buffer(var_A, [512, 512], dtype="float32")
B = T.match_buffer(var_B, [512, 512], dtype="float32")
compute = T.match_buffer(var_compute, [512, 512], dtype="float32")
# `C` is allocated inside the function rather than bound to a parameter.
C = T.alloc_buffer([512, 512], dtype="float32")
for i0, i1, i2 in T.grid(512, 512, 512):
with T.block("C"):
# "SSR": i and j are spatial axes, k is the reduction axis.
i, j, k = T.axis.remap("SSR", [i0, i1, i2])
T.reads([C[i, j], A[i, k], B[k, j]])
T.writes([C[i, j]])
with T.init():
C[i, j] = T.float32(0)
C[i, j] = C[i, j] + A[i, k] * B[k, j]
for i0, i1 in T.grid(512, 512):
with T.block("compute"):
i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
T.reads([C[i0_1, i1_1]])
T.writes([compute[i0_1, i1_1]])
# Element-wise max against zero on the matmul result.
compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))


# pylint: enable=no-member,invalid-name,unused-variable


Expand Down Expand Up @@ -458,6 +480,13 @@ def test_buffer_matched():
sch.compute_inline(block_b)


def test_output_block():
    """Inlining the "compute" block of `matmul_relu`, which writes a parameter buffer, must fail."""
    schedule = tir.Schedule(matmul_relu, debug_mask="all")
    output_block = schedule.get_block("compute")
    # compute_inline is expected to reject this block with a ScheduleError.
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.compute_inline(output_block)


def test_compute_inline_predicate():
sch = tir.Schedule(elementwise_predicate, debug_mask="all")
block_b = sch.get_block("B")
Expand Down