diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index 89b1e9117ff4..fe56bb51eddd 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -115,7 +115,7 @@ class LinearCongruentialEngine { * \return The forked seed. */ TRandState ForkSeed() { - // In order for reproducibility, we computer the new seed using RNG's random state and a + // In order for reproducibility, we compute the new seed using RNG's random state and a // different set of parameters. Note that both 32767 and 1999999973 are prime numbers. return ((*this)() * 32767) % 1999999973; } diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 0273ece0b3b1..1d9bfc9843b5 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -123,7 +123,7 @@ class ScheduleNode : public runtime::Object { * 3) All the random variables are valid in the copy, pointing to the corresponding sref * reconstructed */ - virtual Schedule Copy() const = 0; + virtual Schedule Copy() = 0; /*! * \brief Seed the randomness * \param seed The new random seed, -1 if use device random, otherwise non-negative diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 331ae0209cc0..e261cf2a03de 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -182,11 +182,12 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb new_state->get()->DebugVerify(); } -Schedule ConcreteScheduleNode::Copy() const { +Schedule ConcreteScheduleNode::Copy() { ObjectPtr n = make_object(); n->error_render_level_ = this->error_render_level_; ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); n->analyzer_ = std::make_unique(); // new analyzer needed because it is stateful + n->rand_state_ = ForkSeed(); return Schedule(std::move(n)); } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 32aab1a7b44d..59764e36fe70 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -61,7 +61,7 @@ class ConcreteScheduleNode : public ScheduleNode { public: ScheduleState state() const final { return state_; } Optional trace() const override { return NullOpt; } - Schedule Copy() const override; + Schedule Copy() override; void Seed(support::LinearCongruentialEngine::TRandState seed = -1) final; support::LinearCongruentialEngine::TRandState ForkSeed() final; diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 8af66f1ede75..417f80dd9337 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -33,11 +33,12 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand return Schedule(std::move(n)); } -Schedule TracedScheduleNode::Copy() const { +Schedule TracedScheduleNode::Copy() { ObjectPtr n = make_object(); n->error_render_level_ = this->error_render_level_; ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); n->analyzer_ = std::make_unique(); // new analyzer needed because it is stateful + n->rand_state_ = ForkSeed(); n->trace_ = Trace(this->trace_->insts, this->trace_->decisions); return Schedule(std::move(n)); } diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 5d355bd70c99..442b50ad0cfc 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -43,7 +43,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: Optional trace() const final { return trace_; } - Schedule Copy() const final; + Schedule Copy() final; public: /******** Schedule: Sampling ********/ diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index cc2b114824a5..d8f9670250ed 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -194,5 +194,16 @@ def test_sample_compute_location(): numpy.testing.assert_allclose(expected_rate, cnt / n, atol=0.04) +def test_sample_perfect_tile_after_copy(): + sch = tir.Schedule(elementwise, debug_mask="all") + sch_copy = sch.copy() + _, _, i = sch.get_loops(sch.get_block("B")) + sch.sample_perfect_tile(i, n=4) + + _, _, i = sch_copy.get_loops(sch_copy.get_block("B")) + # Hangs if ForkSeed is not invoked when copying a schedule + sch_copy.sample_perfect_tile(i, n=4) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))