Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
4fab5ce
[Driver] Single-module lowering flow in driver_api.cc
Lunderberg Mar 24, 2023
b43209a
Clarify behavior of emit_fwd_func_decl_
Lunderberg Jun 6, 2023
6ac4259
Annotate entire body as target region if no subregions found
Lunderberg Jun 6, 2023
0f1d17f
Better detection of host target
Lunderberg Jun 12, 2023
406e973
Allow "ext_dev" to act as host.
Lunderberg Jun 15, 2023
77ade9d
Improved device region annotation
Lunderberg Jun 16, 2023
b507879
Allow call_extern from kDLCPU to kDLExtDev
Lunderberg Jun 21, 2023
5a82fe0
Save .so instead of .o in tutorials
Lunderberg Jun 21, 2023
4b75c77
Mark VTA's CPU-side allocations with disable_lower_builtin
Lunderberg Jul 5, 2023
358d459
Use int8_t* for StringImm literals.
Lunderberg Jul 6, 2023
e9a4c32
Allow call_extern in device-side functions
Lunderberg Jul 7, 2023
85b1d5e
[EthosU] Use ethos-u as both device and host
Lunderberg Jul 12, 2023
60c866a
Merge branch 'single_module_during_build_pr_14985' into HEAD
Lunderberg Aug 8, 2023
775742d
[TIR] End-to-end tests for PrimFunc-to-PrimFunc subroutines
Lunderberg May 15, 2023
210b674
Updated example of PrintFuncPrefix
Lunderberg Jun 16, 2023
3015ab0
Handle missing global symbol in ExtractFuncInfo
Lunderberg Jun 16, 2023
92b74ee
Annotate internal functions with __device__ instead of __global__
Lunderberg Jun 16, 2023
a8815f2
AnnotateDeviceRegions, use host to call exposed device functions
Lunderberg Jun 16, 2023
c4e4fd4
Relax cuda requirement for CallingConv::kDeviceKernelLaunch
Lunderberg Jun 16, 2023
3d6f97e
Enable unit test for internal GPU subroutine
Lunderberg Jun 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/extension/tests/test_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_ext_dev():
def check_llvm():
if not tvm.testing.device_enabled("llvm"):
return
f = tvm.build(s, [A, B], "ext_dev", "llvm")
f = tvm.build(s, [A, B], "ext_dev", "ext_dev")
dev = tvm.ext_dev(0)
# launch the kernel.
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/driver/driver_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ using tvm::transform::Pass;
* \param target The device Target.
* \return The composite Pass for the fused module.
// */
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target);
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod,
Optional<Target> target = NullOpt);

/*!
* \brief Configures and returns the composite Pass for the device Target after device/host from
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
primfunc = tir_mod["main"]
primfunc = primfunc.with_attr("global_symbol", func.attrs["global_symbol"])
primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
primfunc = primfunc.with_attr("target", tvm.target.Target(compiler_name))
primfunc = primfunc.with_attr(
"target", tvm.target.Target(compiler_name, host=compiler_name)
)
return primfunc

def __call__(self, *args, **kwargs):
Expand Down
227 changes: 142 additions & 85 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,17 +276,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
return pass_list;
}

IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
auto optimize = tvm::transform::Sequential(pass_list);
mod = optimize(std::move(mod));
return mod;
}

IRModule ApplyPasses(IRModule mod, transform::Sequential seq) {
mod = seq(std::move(mod));
return mod;
}

// Convert te schedule to IRModule
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
Expand Down Expand Up @@ -339,7 +328,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")

IRModule LowerModule(IRModule mod, bool simple_mode) {
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(std::move(mod), pass_list);
tvm::transform::Sequential optimize(pass_list, "tvm.lower");
return optimize(std::move(mod));
}

TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) {
Expand All @@ -356,10 +346,7 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
}
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));

// Get the pass list
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(std::move(mod), pass_list);
return LowerModule(mod, simple_mode);
}

TVM_REGISTER_GLOBAL("driver.lower_primfunc")
Expand All @@ -381,9 +368,7 @@ IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const std
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
GlobalVarSupply global_var_supply, bool simple_mode) {
IRModule mod = ScheduleToModule(std::move(sch), args, name, binds, global_var_supply);
// Get the legacy TE pass list
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(mod, pass_list);
return LowerModule(mod, simple_mode);
}

TVM_REGISTER_GLOBAL("driver.lower_schedule")
Expand All @@ -400,35 +385,42 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
simple_mode);
});

/**
* This function takes the input module that contains both the device and host opts.
* Then, it applies transformation on the original module before splitting into separate modules for
* device and host. Then it also applies transformations on the new splitted modules.
*/
std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const Target& target_arg,
const Target& target_host_arg) {
Target target = target_arg, target_host = target_host_arg;
CheckAndUpdateHostConsistency(&target, &target_host);

ICHECK(mod_mixed.defined()) << "This module must be defined";
IRModule MergeModules(const Map<Target, IRModule>& inputs) {
if (inputs.size() == 1) {
auto [target, mod] = *inputs.begin();
return tir::transform::BindTarget(target)(mod);
}

mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
// Take the attrs from the first module so the eventual modules have them.
IRModule first_module = (*inputs.begin()).second;
IRModule merged = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);

IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host));
for (auto [target, mod] : inputs) {
mod = tir::transform::BindTarget(target)(mod);
merged->Update(mod);
}

IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target));
return merged;
}

auto keys = target->GetKeys();
Map<Target, IRModule> SplitModule(const IRModule& module) {
Map<String, IRModule> split;

CheckAndUpdateHostConsistency(&target, &target_host);
for (auto [gvar, base_func] : module->functions) {
auto target_str = base_func->GetAttr<Target>(tvm::attr::kTarget).value()->str();
if (auto it = split.find(target_str); it != split.end()) {
(*it).second->Add(gvar, base_func);
} else {
split.Set(target_str, IRModule({{gvar, base_func}}, {}, {}, {}, module->attrs));
}
}

bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
if (target_is_gpu && device_mod->functions.size() == 0) {
DLOG(WARNING) << "Specified target " << target->str()
<< " but cannot find device code. Did you forget to bind?";
Map<Target, IRModule> out;
for (auto [str, mod] : split) {
out.Set(Target(str), mod);
}

return {host_mod, device_mod};
return out;
}

runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
Expand Down Expand Up @@ -457,52 +449,86 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
// Update target host for all targets
CheckAndUpdateHostConsistency(&inputs, &target_host);

// Take the attrs from the first module so the eventual modules have them.
// Ideally this would just be one unified module all the way through;
IRModule first_module = (*inputs.begin()).second;
IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);

ICHECK(mhost_all.defined()) << "The host module must be defined";

for (const auto& it : inputs) {
if (it.second.defined()) {
const Target& target = it.first;
const IRModule& ir_module = it.second;
auto pair = SplitMixedModule(ir_module, target, target_host);
auto& host_mod = pair.first;
auto& device_mod = pair.second;

ICHECK(host_mod.defined()) << "The split host module must be defined";

ICHECK(mhost_all.defined()) << "The host module must be defined";

// We don't want library modules going back into host codegen
// unless they're supposed to. Here if we overrode the target host
// to allow lowering previously we check that it's meant to be placed
// back into the host Module.
bool overrides_host_target =
target->GetTargetDeviceType() == target_host->GetTargetDeviceType();
bool non_host_target_kind = target->kind != target_host->kind;
if (overrides_host_target && non_host_target_kind) {
device_modules.push_back(codegen::Build(host_mod, it.first));
} else {
mhost_all->Update(host_mod);
auto has_gpu_function = [](const IRModule& mod) -> bool {
for (const auto& [gvar, func] : mod->functions) {
if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
if (target.value()->HasKey("gpu")) {
return true;
}
}
}
return false;
};

IRModule merged = MergeModules(inputs);

bool contains_gpu_function_pre = has_gpu_function(merged);
merged = MixedModulePassManager(merged)(merged);
bool contains_gpu_function_post = has_gpu_function(merged);
if (contains_gpu_function_pre && !contains_gpu_function_post) {
DLOG(WARNING) << "Specified GPU targets, "
<< "but cannot find device code. Did you forget to bind?";
}

Map<Target, IRModule> split = SplitModule(merged);

Map<Target, runtime::Module> built;
for (const auto& [target, mod] : split) {
built.Set(target, codegen::Build(mod, target));
}

auto host_target = [&]() -> Target {
// All targets that contain a kIsEntryFunc=True function
Array<Target> targets_with_entry_func;

// All targets that can run on the CPU and contain at least one
// function without kIsEntryFunc=False.
Array<Target> cpu_targets;
for (const auto& [target, mod] : split) {
bool contains_entry_func = false;
bool may_contain_entry_func = false;
for (const auto& [gvar, func] : mod->functions) {
Optional<Bool> is_entry_func = func->attrs.GetAttr<Bool>(tvm::tir::attr::kIsEntryFunc);
if (is_entry_func.defined() && is_entry_func.value()->value) {
contains_entry_func = true;
} else if (!is_entry_func.defined()) {
may_contain_entry_func = true;
}
}

if (contains_entry_func) {
targets_with_entry_func.push_back(target);
}

if (device_mod->functions.size() != 0) {
device_modules.push_back(codegen::Build(device_mod, it.first));
if (may_contain_entry_func && target->HasKey("cpu")) {
cpu_targets.push_back(target);
}
}
}

runtime::Module mhost = codegen::Build(mhost_all, target_host);
for (const auto& it : device_modules) {
if (it.operator->()) {
mhost.Import(it);
if (targets_with_entry_func.size()) {
ICHECK_EQ(targets_with_entry_func.size(), 1)
<< "Expected at most one function "
<< "annotated with tvm::tir::attr::kIsEntryFunc "
<< "(\"" << tvm::tir::attr::kIsEntryFunc << "\"), "
<< "but found: " << targets_with_entry_func;
return targets_with_entry_func[0];
} else if (cpu_targets.size() == 1) {
return cpu_targets[0];
} else {
LOG(FATAL) << "Could not determine which target is the host. "
<< "No function was annotated with tvm::tir::attr::kIsEntryFunc (\""
<< tvm::tir::attr::kIsEntryFunc << "\"), "
<< "and " << cpu_targets.size() << " targets have the 'cpu' key";
}
}
}();

return mhost;
auto runtime_module = built[host_target];
for (const auto& [target, mod] : built) {
if (!mod.same_as(runtime_module)) {
runtime_module.Import(mod);
}
}
return runtime_module;
}

TVM_REGISTER_GLOBAL("driver.tir_to_runtime")
Expand Down Expand Up @@ -543,18 +569,20 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg,
return TIRToRuntime(inputs, target_host);
}

transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
transform::Sequential MixedModulePassManager(IRModule mixed_mod, Optional<Target> target) {
transform::PassContext pass_ctx = transform::PassContext::Current();

Array<Pass> mixed_pass_list;

if (target) {
mixed_pass_list.push_back(tir::transform::BindTarget(target.value()));
}

// VerifyVTCMLimit must occur before LowerVtcmAlloc
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());

mixed_pass_list.push_back(tir::transform::BindTarget(target));

mixed_pass_list.push_back(tir::transform::VerifyMemory());

mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());
Expand Down Expand Up @@ -600,7 +628,28 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

return transform::Sequential(mixed_pass_list);
// Only applies to the device functions, identified by inspection of
// each function's tvm::attr::kTarget attribute.
mixed_pass_list.push_back(tir::transform::LowerWarpMemory());

// Only applies to the host functions, identified by inspection of
// each function's tvm::attr::kTarget attribute.
mixed_pass_list.push_back(tir::transform::LowerTVMBuiltin());

// Apply to both host and device functions
mixed_pass_list.push_back(tir::transform::Simplify());
mixed_pass_list.push_back(tir::transform::LowerCustomDatatypes());
mixed_pass_list.push_back(tir::transform::LowerIntrin());
mixed_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());

// Only applies to the host functions, identified by inspection of
// each function's tvm::attr::kTarget attribute.
mixed_pass_list.push_back(tir::transform::CombineContextCall());
if (pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value()) {
mixed_pass_list.push_back(tir::transform::InstallDebugSpans());
}

return transform::Sequential(mixed_pass_list, "tvm.build");
}

TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
Expand All @@ -609,6 +658,10 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
});

transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) {
LOG(WARNING) << "Use of driver.host_mod_passes is deprecated. "
<< "All lowering passes are now included "
<< "as part of driver.mixed_mod_passes.";

transform::PassContext pass_ctx = transform::PassContext::Current();
bool enable_debug = pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value();

Expand All @@ -634,7 +687,7 @@ transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_ho
host_pass_list.push_back(tir::transform::InstallDebugSpans());
}

return transform::Sequential(host_pass_list);
return transform::Sequential(host_pass_list, "tir.host_mod_passes");
}

TVM_REGISTER_GLOBAL("driver.host_mod_passes")
Expand All @@ -643,6 +696,10 @@ TVM_REGISTER_GLOBAL("driver.host_mod_passes")
});

transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) {
LOG(WARNING) << "Use of driver.device_mod_passes is deprecated. "
<< "All lowering passes are now included "
<< "as part of driver.mixed_mod_passes.";

Array<Pass> device_pass_list;
runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
Expand All @@ -658,7 +715,7 @@ transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target)
device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
device_pass_list.push_back(tir::transform::LowerIntrin());

return transform::Sequential(device_pass_list);
return transform::Sequential(device_pass_list, "tir.device_mod_passes");
}

TVM_REGISTER_GLOBAL("driver.device_mod_passes")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class ConvertAddToSubtract : public MixedModeMutator {
explicit ConvertAddToSubtract(IRModule ir_module, Target host_target)
: ir_module_(ir_module),
host_target_(host_target),
custom_target_(Target("example_target_hook")) {}
custom_target_(Target(Target("example_target_hook"), Target("example_target_hook"))) {}

IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
Expand Down
10 changes: 8 additions & 2 deletions src/target/build_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,14 @@ inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
info.launch_param_tags.push_back(tag);
}
}
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol.value())] = info;

if (auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
info.name = global_symbol.value();
} else {
info.name = kv.first->name_hint;
}

fmap[info.name] = info;
}
return fmap;
}
Expand Down
Loading