Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -498,13 +498,24 @@ std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int nu
TVM_FFI_ICHECK(!result.defined());
if (ffi::Optional<Schedule> sch = pp.Apply(mod, trace, rand_state)) {
result = sch.value();
} else {
TVM_FFI_THROW(ValueError) << "Cannot postprocess the trace:\n" << trace;
throw;
}
};
support::parallel_for_dynamic(0, actual_num, self->ctx_->num_threads, f_proc_measured);
return results;
TVM_PY_LOG(INFO, self->ctx_->logger) << "Pick-Best-From-Database summary:\n"
<< pp.SummarizeFailures();
if (pp.TraceFailCount() > 0) {
TVM_PY_LOG(WARNING, self->ctx_->logger)
<< "PickBestFromDatabase skipped " << pp.TraceFailCount()
<< " candidate(s) due to trace replay failures";
}
std::vector<Schedule> filtered;
filtered.reserve(actual_num);
for (const Schedule& sch : results) {
if (sch.defined()) {
filtered.push_back(sch);
}
}
return filtered;
}

std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int num) {
Expand Down Expand Up @@ -538,6 +549,11 @@ std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int nu
fail_count += !found_new;
TVM_PY_LOG(INFO, self->ctx_->logger) << "Sample-Init-Population summary:\n"
<< pp.SummarizeFailures();
if (pp.TraceFailCount() > 0) {
TVM_PY_LOG(WARNING, self->ctx_->logger)
<< "SampleInitPopulation encountered " << pp.TraceFailCount()
<< " trace replay failure(s); invalid candidates were skipped";
}
}
return out_schs;
}
Expand Down
35 changes: 27 additions & 8 deletions src/s_tir/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,24 @@ struct ThreadedTraceApply {
*/
ffi::Optional<s_tir::Schedule> Apply(const IRModule& mod, const s_tir::Trace& trace,
TRandState* rand_state) {
s_tir::Schedule sch =
s_tir::Schedule::Traced(mod,
/*rand_state=*/ForkSeed(rand_state),
/*debug_mode=*/0,
/*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kNone);

trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
sch->EnterPostproc();
s_tir::Schedule sch{nullptr};
try {
sch = s_tir::Schedule::Traced(mod,
/*rand_state=*/ForkSeed(rand_state),
/*debug_mode=*/0,
/*error_render_level=*/
s_tir::ScheduleErrorRenderLevel::kNone);
trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
sch->EnterPostproc();
} catch (const s_tir::ScheduleError& e) {
TVM_PY_LOG(WARNING, nullptr) << "Trace replay failed with ScheduleError: " << e.what();
this->trace_fail_counter_++;
return std::nullopt;
} catch (const std::exception& e) {
TVM_PY_LOG(WARNING, nullptr) << "Trace replay failed with exception: " << e.what();
this->trace_fail_counter_++;
return std::nullopt;
}

for (int i = 0; i < n_; ++i) {
Item& item = items_[i];
Expand All @@ -364,6 +374,10 @@ struct ThreadedTraceApply {
/*! \brief Returns a string summarizing the failures on each postprocessor */
std::string SummarizeFailures() const {
std::ostringstream os;
os << "Trace replay failures: " << this->trace_fail_counter_.load() << " failure(s)";
if (n_ > 0) {
os << "\n";
}
for (int i = 0; i < n_; ++i) {
const Item& item = items_[i];
os << "Postproc #" << i << " [" << item.postproc //
Expand All @@ -375,6 +389,9 @@ struct ThreadedTraceApply {
return os.str();
}

/*! \brief Returns the number of trace replay failures. */
int TraceFailCount() const { return this->trace_fail_counter_.load(); }

private:
/*! \brief A helper data structure that stores the fail count for each postprocessor. */
struct Item {
Expand All @@ -386,6 +403,8 @@ struct ThreadedTraceApply {

/*! \brief The number of total postprocessors. */
int n_;
/*! \brief The thread-safe trace replay failure counter. */
std::atomic<int> trace_fail_counter_{0};
/*! \brief The pointer to the list of postprocessor items. */
Item* items_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # type: ignore
C[vi, vj] = 0.0 # type: ignore
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@tvm.script.ir_module
class OtherBlock:
@T.prim_func
def main(a: T.handle, b: T.handle, c: T.handle) -> None: # type: ignore
T.func_attr({"global_symbol": "main"})
A = T.match_buffer(a, (32, 32), "float32")
B = T.match_buffer(b, (32, 32), "float32")
C = T.match_buffer(c, (32, 32), "float32")
for i, j, k in T.grid(32, 32, 32):
with T.sblock("other"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0 # type: ignore
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

# fmt: on
# pylint: enable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument

Expand Down Expand Up @@ -308,6 +324,63 @@ def __str__(self) -> str:
assert candidates is None


def test_meta_schedule_evolutionary_search_skip_invalid_measured_trace() # pylint: disable = invalid-name
# Construct an incompatible measured trace: it references block name "other",
# which doesn't exist in Matmul. Replaying this trace should fail and be skipped.
wrong_sch = Schedule(OtherBlock)
wrong_sch.get_sblock("other")
wrong_trace = wrong_sch.trace

database = ms.database.MemoryDatabase()
workload = database.commit_workload(Matmul)
database.commit_tuning_record(
ms.database.TuningRecord(
trace=wrong_trace,
workload=workload,
run_secs=[0.1],
target=tvm.target.Target("llvm"),
args_info=ms.arg_info.ArgInfo.from_prim_func(func=Matmul["main"]),
)
)

context = ms.TuneContext(
mod=Matmul,
space_generator=ms.space_generator.ScheduleFn(
sch_fn=_schedule_matmul,
sch_rules=[],
postprocs=[],
mutator_probs={
DummyMutator(): 1.0,
},
),
search_strategy=ms.search_strategy.EvolutionarySearch(
population_size=5,
init_measured_ratio=1.0,
init_min_unmeasured=1,
genetic_num_iters=1,
genetic_mutate_prob=0.5,
genetic_max_fail_count=4,
eps_greedy=0.9,
),
target=tvm.target.Target("llvm"),
num_threads=1,
)
strategy = context.search_strategy
strategy.pre_tuning(
max_trials=4,
num_trials_per_iter=2,
design_spaces=context.space_generator.generate_design_space(context.mod),
database=database,
cost_model=ms.cost_model.RandomModel(),
)

candidates = strategy.generate_measure_candidates()
strategy.post_tuning()

# Regression assertion: invalid measured trace should be skipped, not crash
assert candidates is not None


def test_search_strategy_abstract_class_instantiation():
"""Test that directly instantiating abstract SearchStrategy raises TypeError instead of segfault."""
from tvm.s_tir.meta_schedule import SearchStrategy, TuneContext
Expand Down
Loading