diff --git a/src/s_tir/transform/default_gpu_schedule.cc b/src/s_tir/transform/default_gpu_schedule.cc
index d41e2f58433e..b130cbfe45f2 100644
--- a/src/s_tir/transform/default_gpu_schedule.cc
+++ b/src/s_tir/transform/default_gpu_schedule.cc
@@ -103,6 +103,77 @@ IRModule MarkScheduled(const IRModule& mod) {
                   mod->global_infos);  // global_infos
 }
 
+/*!
+ * \brief Wrap a PrimFunc body that is a bare \c SBlockRealize (no enclosing
+ * loops, no iter vars) so the realized block is no longer the function's root
+ * sref.
+ *
+ * Without this, \c ThreadBind below calls \c Schedule::AddUnitLoop(block) on
+ * a block that is itself the prim_func's root sref, hitting the
+ * "Cannot add loops on top of the root block" check in
+ * \c s_tir::AddUnitLoop. The schedule infrastructure additionally requires
+ * the prim_func body to be an \c SBlockRealize, so we keep that shape and
+ * push the original block one level deeper, inside a wrapping root block
+ * that holds a unit serial loop. The synthesised data-parallel iter keeps
+ * iter_values/iter_vars counts consistent for downstream checks.
+ *
+ * The function is returned unchanged when its body is not an
+ * \c SBlockRealize, when the realized block already has iter vars, or when
+ * the block body is (or is a \c SeqStmt directly containing) a \c For or an
+ * \c SBlockRealize.
+ *
+ * \param func The PrimFunc to inspect.
+ * \return \p func itself when no wrapping is needed, otherwise a copy whose
+ * body is the new wrapping root \c SBlockRealize.
+ */
+tirx::PrimFunc WrapBareSBlockBody(const tirx::PrimFunc& func) {
+  const auto* realize = func->body.as<tirx::SBlockRealizeNode>();
+  if (realize == nullptr || !realize->block->iter_vars.empty()) {
+    return func;
+  }
+  // Only wrap when the block is a leaf computation. A well-formed PrimFunc
+  // produced by the rest of the pipeline has an implicit root SBlockRealize
+  // whose block body is a For loop (or a nested SBlockRealize) — that case
+  // already has somewhere to put thread bindings, so leave it alone.
+  auto is_loop_or_block = [](const tirx::Stmt& stmt) {
+    return stmt->IsInstance<tirx::ForNode>() || stmt->IsInstance<tirx::SBlockRealizeNode>();
+  };
+  const tirx::Stmt& inner = realize->block->body;
+  if (is_loop_or_block(inner)) {
+    return func;
+  }
+  // A root block holding several top-level loop nests or blocks stores them in
+  // a SeqStmt. Wrapping that root would nest every stage under one unit loop,
+  // so treat it like the single-loop case above.
+  if (const auto* seq = inner.as<tirx::SeqStmtNode>()) {
+    for (const tirx::Stmt& stmt : seq->seq) {
+      if (is_loop_or_block(stmt)) {
+        return func;
+      }
+    }
+  }
+  tvm::IntImm zero(tvm::DataType::Int(32), 0);
+  tvm::IntImm one(tvm::DataType::Int(32), 1);
+  tirx::Var loop_var("u", tvm::DataType::Int(32));
+  tirx::Var iter_var_var("vu", tvm::DataType::Int(32));
+  tirx::IterVar new_iter(tvm::Range::FromMinExtent(zero, one), iter_var_var,
+                         tirx::IterVarType::kDataPar);
+  tirx::SBlock inner_block = realize->block;
+  inner_block.CopyOnWrite()->iter_vars = ffi::Array<tirx::IterVar>{new_iter};
+  tirx::SBlockRealize inner_realize(/*iter_values=*/ffi::Array<PrimExpr>{loop_var},
+                                    /*predicate=*/realize->predicate, inner_block);
+  tirx::Stmt for_stmt = tirx::For(loop_var, zero, one, tirx::ForKind::kSerial, inner_realize);
+  tirx::SBlock root_block(/*iter_vars=*/ffi::Array<tirx::IterVar>{},
+                          /*reads=*/ffi::Array<tirx::BufferRegion>{},
+                          /*writes=*/ffi::Array<tirx::BufferRegion>{},
+                          /*name_hint=*/"root", /*body=*/for_stmt);
+  tirx::SBlockRealize root_realize(/*iter_values=*/ffi::Array<PrimExpr>{},
+                                   /*predicate=*/tvm::Bool(true), root_block);
+  tirx::PrimFunc result = func;
+  result.CopyOnWrite()->body = std::move(root_realize);
+  return result;
+}
+
 bool IsScheduledOnGPU(const BaseFunc& func) {
   // the target from context.
   tvm::Target target = tvm::Target::Current();
@@ -125,6 +196,27 @@ bool IsScheduledOnGPU(const BaseFunc& func) {
 Pass DefaultGPUSchedule() {
   auto pass_func =  //
       [=](IRModule m, PassContext pc) {
+        // Wrap any GPU-bound PrimFunc whose body is a bare SBlockRealize
+        // (e.g. a scalar op) so ThreadBind below has a loop to operate on.
+        ffi::Map<GlobalVar, BaseFunc> wrapped;
+        bool any_wrapped = false;
+        for (const auto& [gv, base_func] : m->functions) {
+          if (const auto* prim_func_node = base_func.as<tirx::PrimFuncNode>();
+              prim_func_node != nullptr && IsScheduledOnGPU(base_func) &&
+              !base_func->HasNonzeroAttr(tirx::attr::kIsScheduled)) {
+            tirx::PrimFunc func = ffi::GetRef<tirx::PrimFunc>(prim_func_node);
+            tirx::PrimFunc new_func = WrapBareSBlockBody(func);
+            if (!new_func.same_as(func)) {
+              wrapped.Set(gv, new_func);
+              any_wrapped = true;
+              continue;
+            }
+          }
+          wrapped.Set(gv, base_func);
+        }
+        if (any_wrapped) {
+          m = IRModule(wrapped, m->source_map, m->attrs, m->global_infos);
+        }
         s_tir::Schedule sch = s_tir::Schedule::Traced(m, /*seed=*/-1, /*debug_mask=*/0,
                                                       s_tir::ScheduleErrorRenderLevel::kDetail);
         for (const auto& [gv, func] : m->functions) {
diff --git a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
index c562a29e8781..f08dba00d6c2 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
@@ -567,5 +567,39 @@ def sum(A: T.Buffer((T.int64(2), T.int64(2)), "float64"), A_red: T.Buffer((), "f
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_scalar_block_no_loops():
+    # A PrimFunc whose body is a bare SBlockRealize (e.g. a fully-scalar op)
+    # used to crash DefaultGPUSchedule with "Cannot add loops on top of the
+    # root block" because the realized block was the function's root sref.
+    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+    # fmt: off
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"), c: T.Buffer((), "float32")):
+            with T.sblock("scalar_add"):
+                c[()] = a[()] + b[()]
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"), c: T.Buffer((), "float32")):
+            T.func_attr({"tirx.is_scheduled": True})
+            # with T.sblock("root"):
+            for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+                for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
+                    with T.sblock("scalar_add"):
+                        vu = T.axis.spatial(1, 0)
+                        T.reads()
+                        T.writes()
+                        c[()] = a[()] + b[()]
+    # fmt: on
+    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+    target = tvm.target.Target("nvidia/geforce-rtx-3070")
+    with target, tvm.transform.PassContext(opt_level=0):
+        mod = DefaultGPUSchedule()(Before)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()