Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/support/random_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class LinearCongruentialEngine {
* \return The forked seed.
*/
TRandState ForkSeed() {
// In order for reproducibility, we computer the new seed using RNG's random state and a
// In order for reproducibility, we compute the new seed using RNG's random state and a
// different set of parameters. Note that both 32767 and 1999999973 are prime numbers.
return ((*this)() * 32767) % 1999999973;
}
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class ScheduleNode : public runtime::Object {
* 3) All the random variables are valid in the copy, pointing to the corresponding sref
* reconstructed
*/
virtual Schedule Copy() const = 0;
virtual Schedule Copy() = 0;
/*!
* \brief Seed the randomness
* \param seed The new random seed, -1 if use device random, otherwise non-negative
Expand Down
3 changes: 2 additions & 1 deletion src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,12 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb
new_state->get()->DebugVerify();
}

Schedule ConcreteScheduleNode::Copy() const {
Schedule ConcreteScheduleNode::Copy() {
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
n->error_render_level_ = this->error_render_level_;
ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
n->analyzer_ = std::make_unique<arith::Analyzer>(); // new analyzer needed because it is stateful
n->rand_state_ = ForkSeed();
return Schedule(std::move(n));
}

Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class ConcreteScheduleNode : public ScheduleNode {
public:
ScheduleState state() const final { return state_; }
Optional<Trace> trace() const override { return NullOpt; }
Schedule Copy() const override;
Schedule Copy() override;
void Seed(support::LinearCongruentialEngine::TRandState seed = -1) final;
support::LinearCongruentialEngine::TRandState ForkSeed() final;

Expand Down
3 changes: 2 additions & 1 deletion src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand
return Schedule(std::move(n));
}

Schedule TracedScheduleNode::Copy() const {
Schedule TracedScheduleNode::Copy() {
ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
n->error_render_level_ = this->error_render_level_;
ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
n->analyzer_ = std::make_unique<arith::Analyzer>(); // new analyzer needed because it is stateful
n->rand_state_ = ForkSeed();
n->trace_ = Trace(this->trace_->insts, this->trace_->decisions);
return Schedule(std::move(n));
}
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {

public:
Optional<Trace> trace() const final { return trace_; }
Schedule Copy() const final;
Schedule Copy() final;

public:
/******** Schedule: Sampling ********/
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_tir_schedule_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,5 +194,16 @@ def test_sample_compute_location():
numpy.testing.assert_allclose(expected_rate, cnt / n, atol=0.04)


def test_sample_perfect_tile_after_copy():
sch = tir.Schedule(elementwise, debug_mask="all")
sch_copy = sch.copy()
_, _, i = sch.get_loops(sch.get_block("B"))
sch.sample_perfect_tile(i, n=4)

_, _, i = sch_copy.get_loops(sch_copy.get_block("B"))
# Hangs if ForkSeed is not invoked when copying a schedule
sch_copy.sample_perfect_tile(i, n=4)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))