From d8f8193c3e05926eb40a3c7c9da789c38002e172 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Fri, 24 Apr 2026 19:41:17 +0800 Subject: [PATCH 1/3] [S-TIR][MetaSchedule] Make evolutionary search resilient to trace replay failures --- .../search_strategy/evolutionary_search.cc | 24 ++++++++++--- src/s_tir/meta_schedule/utils.h | 35 ++++++++++++++----- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc b/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc index fabe50dd60f9..ec85ffeee510 100644 --- a/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc @@ -498,13 +498,24 @@ std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int nu TVM_FFI_ICHECK(!result.defined()); if (ffi::Optional sch = pp.Apply(mod, trace, rand_state)) { result = sch.value(); - } else { - TVM_FFI_THROW(ValueError) << "Cannot postprocess the trace:\n" << trace; - throw; } }; support::parallel_for_dynamic(0, actual_num, self->ctx_->num_threads, f_proc_measured); - return results; + TVM_PY_LOG(INFO, self->ctx_->logger) << "Pick-Best-From-Database summary:\n" + << pp.SummarizeFailures(); + if (pp.TraceFailCount() > 0) { + TVM_PY_LOG(WARNING, self->ctx_->logger) + << "PickBestFromDatabase skipped " << pp.TraceFailCount() + << " candidate(s) due to trace replay failures"; + } + std::vector filtered; + filtered.reserve(actual_num); + for (const Schedule& sch : results) { + if (sch.defined()) { + filtered.push_back(sch); + } + } + return filtered; } std::vector EvolutionarySearchNode::State::SampleInitPopulation(int num) { @@ -538,6 +549,11 @@ std::vector EvolutionarySearchNode::State::SampleInitPopulation(int nu fail_count += !found_new; TVM_PY_LOG(INFO, self->ctx_->logger) << "Sample-Init-Population summary:\n" << pp.SummarizeFailures(); + if (pp.TraceFailCount() > 0) { + TVM_PY_LOG(WARNING, self->ctx_->logger) + << "SampleInitPopulation encountered " << pp.TraceFailCount() + << " trace replay failure(s); invalid candidates were skipped"; + } } return out_schs; } diff --git a/src/s_tir/meta_schedule/utils.h b/src/s_tir/meta_schedule/utils.h index 847adc2591da..524a99dfddfd 100644 --- a/src/s_tir/meta_schedule/utils.h +++ b/src/s_tir/meta_schedule/utils.h @@ -330,14 +330,24 @@ struct ThreadedTraceApply { */ ffi::Optional Apply(const IRModule& mod, const s_tir::Trace& trace, TRandState* rand_state) { - s_tir::Schedule sch = - s_tir::Schedule::Traced(mod, - /*rand_state=*/ForkSeed(rand_state), - /*debug_mode=*/0, - /*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kNone); - - trace->ApplyToSchedule(sch, /*remove_postproc=*/true); - sch->EnterPostproc(); + s_tir::Schedule sch{nullptr}; + try { + sch = s_tir::Schedule::Traced(mod, + /*rand_state=*/ForkSeed(rand_state), + /*debug_mode=*/0, + /*error_render_level=*/ + s_tir::ScheduleErrorRenderLevel::kNone); + trace->ApplyToSchedule(sch, /*remove_postproc=*/true); + sch->EnterPostproc(); + } catch (const s_tir::ScheduleError& e) { + DLOG(WARNING) << "Trace replay failed with ScheduleError: " << e.what(); + this->trace_fail_counter_++; + return std::nullopt; + } catch (const std::exception& e) { + DLOG(WARNING) << "Trace replay failed with exception: " << e.what(); + this->trace_fail_counter_++; + return std::nullopt; + } for (int i = 0; i < n_; ++i) { Item& item = items_[i]; @@ -364,6 +374,10 @@ struct ThreadedTraceApply { /*! \brief Returns a string summarizing the failures on each postprocessor */ std::string SummarizeFailures() const { std::ostringstream os; + os << "Trace replay failures: " << this->trace_fail_counter_.load() << " failure(s)"; + if (n_ > 0) { + os << "\n"; + } for (int i = 0; i < n_; ++i) { const Item& item = items_[i]; os << "Postproc #" << i << " [" << item.postproc // @@ -375,6 +389,9 @@ struct ThreadedTraceApply { return os.str(); } + /*! \brief Returns the number of trace replay failures. */ + int TraceFailCount() const { return this->trace_fail_counter_.load(); } + private: /*! \brief A helper data structure that stores the fail count for each postprocessor. */ struct Item { @@ -386,6 +403,8 @@ struct ThreadedTraceApply { /*! \brief The number of total postprocessors. */ int n_; + /*! \brief The thread-safe trace replay failure counter. */ + std::atomic trace_fail_counter_{0}; /*! \brief The pointer to the list of postprocessor items. */ Item* items_; }; From 9f54fe08e74b7169519615f436c99d2e31d884ae Mon Sep 17 00:00:00 2001 From: cchung100m Date: Fri, 24 Apr 2026 21:12:05 +0800 Subject: [PATCH 2/3] [S-TIR][MetaSchedule] Add test case --- .../test_meta_schedule_search_strategy.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py index 5df88ba7d511..370eff27c77b 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py @@ -49,6 +49,22 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # type: ignore C[vi, vj] = 0.0 # type: ignore C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +@tvm.script.ir_module +class OtherBlock: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: # type: ignore + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (32, 32), "float32") + B = T.match_buffer(b, (32, 32), "float32") + C = T.match_buffer(c, (32, 32), "float32") + for i, j, k in T.grid(32, 32, 32): + with T.sblock("other"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 # type: ignore + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + # fmt: on # pylint: enable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument @@ -308,6 +324,63 @@ def __str__(self) -> str: assert candidates is None +def test_meta_schedule_evolutionary_search_skip_invalid_measured_trace() # pylint: disable = invalid-name + # Construct an incompatible measured trace: it references block name "other", + # which doesn't exist in Matmul. Replaying this trace should fail and be skipped. + wrong_sch = Schedule(OtherBlock) + wrong_sch.get_sblock("other") + wrong_trace = wrong_sch.trace + + database = ms.database.MemoryDatabase() + workload = database.commit_workload(Matmul) + database.commit_tuning_record( + ms.database.TuningRecord( + trace=wrong_trace, + workload=workload, + run_secs=[0.1], + target=tvm.target.Target("llvm"), + args_info=ms.arg_info.ArgInfo.from_prim_func(func=Matmul["main"]), + ) + ) + + context = ms.TuneContext( + mod=Matmul, + space_generator=ms.space_generator.ScheduleFn( + sch_fn=_schedule_matmul, + sch_rules=[], + postprocs=[], + mutator_probs={ + DummyMutator(): 1.0, + }, + ), + search_strategy=ms.search_strategy.EvolutionarySearch( + population_size=5, + init_measured_ratio=1.0, + init_min_unmeasured=1, + genetic_num_iters=1, + genetic_mutate_prob=0.5, + genetic_max_fail_count=4, + eps_greedy=0.9, + ), + target=tvm.target.Target("llvm"), + num_threads=1, + ) + strategy = context.search_strategy + strategy.pre_tuning( + max_trials=4, + num_trials_per_iter=2, + design_spaces=context.space_generator.generate_design_space(context.mod), + database=database, + cost_model=ms.cost_model.RandomModel(), + ) + + candidates = strategy.generate_measure_candidates() + strategy.post_tuning() + + # Regression assertion: invalid measured trace should be skipped, not crash + assert candidates is not None + + def test_search_strategy_abstract_class_instantiation(): """Test that directly instantiating abstract SearchStrategy raises TypeError instead of segfault.""" from tvm.s_tir.meta_schedule import SearchStrategy, TuneContext From fadd2991550c548e59f72a12af1c788863980892 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 25 Apr 2026 15:30:45 +0800 Subject: [PATCH 3/3] [S-TIR][MetaSchedule] Update the log --- src/s_tir/meta_schedule/utils.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/s_tir/meta_schedule/utils.h b/src/s_tir/meta_schedule/utils.h index 524a99dfddfd..a7804361ebca 100644 --- a/src/s_tir/meta_schedule/utils.h +++ b/src/s_tir/meta_schedule/utils.h @@ -340,11 +340,11 @@ struct ThreadedTraceApply { trace->ApplyToSchedule(sch, /*remove_postproc=*/true); sch->EnterPostproc(); } catch (const s_tir::ScheduleError& e) { - DLOG(WARNING) << "Trace replay failed with ScheduleError: " << e.what(); + TVM_PY_LOG(WARNING, nullptr) << "Trace replay failed with ScheduleError: " << e.what(); this->trace_fail_counter_++; return std::nullopt; } catch (const std::exception& e) { - DLOG(WARNING) << "Trace replay failed with exception: " << e.what(); + TVM_PY_LOG(WARNING, nullptr) << "Trace replay failed with exception: " << e.what(); this->trace_fail_counter_++; return std::nullopt; }