diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 13dc02fde4a1..bf29556768c5 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -465,30 +465,20 @@ void LinkModules(ObjectPtr exec, const ffi::Map& ext_libs) { // query if we need const loader for ext_modules // Wrap all submodules in the initialization wrapper. - std::unordered_map> const_vars_by_symbol; + ffi::Map> const_vars_by_symbol; for (tvm::ffi::Module mod : ext_libs) { auto pf_sym = mod->GetFunction("get_symbol"); auto pf_var = mod->GetFunction("get_const_vars"); - std::vector symbol_const_vars; if (pf_sym.has_value() && pf_var.has_value()) { ffi::String symbol = (*pf_sym)().cast(); ffi::Array variables = (*pf_var)().cast>(); - for (size_t i = 0; i < variables.size(); i++) { - symbol_const_vars.push_back(variables[i].operator std::string()); - } - TVM_FFI_ICHECK_EQ(const_vars_by_symbol.count(symbol), 0U) - << "Found duplicated symbol: " << symbol; - const_vars_by_symbol[symbol] = symbol_const_vars; + TVM_FFI_ICHECK(!const_vars_by_symbol.count(symbol)) << "Found duplicated symbol: " << symbol; + const_vars_by_symbol.Set(symbol, variables); } } if (!const_vars_by_symbol.empty() || !params.empty()) { // need runtime const information, run link const loader - std::unordered_map const_var_tensor; - for (const auto& [name, param] : params) { - const_var_tensor[name] = param; - } - ffi::Module const_loader_mod = - runtime::ConstLoaderModuleCreate(const_var_tensor, const_vars_by_symbol); + ffi::Module const_loader_mod = runtime::ConstLoaderModuleCreate(params, const_vars_by_symbol); const_loader_mod->ImportModule(lib); for (const auto& it : ext_libs) { const_loader_mod->ImportModule(it); diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index 006c1f1e1acd..aaaeb9737e2d 100644 --- a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -39,6 +39,9 @@ #include #include +#include +#include +#include #include "../support/bytes_io.h" diff --git a/src/runtime/const_loader_module.h b/src/runtime/const_loader_module.h index c97232016d8a..6722785cc950 100644 --- a/src/runtime/const_loader_module.h +++ b/src/runtime/const_loader_module.h @@ -29,13 +29,10 @@ #include #include #include +#include #include #include -#include -#include -#include - namespace tvm { namespace runtime { @@ -52,26 +49,13 @@ namespace runtime { * The creator is always available (ConstLoaderModule is a runtime-universal module). */ inline ffi::Module ConstLoaderModuleCreate( - const std::unordered_map& const_var_tensor, - const std::unordered_map>& const_vars_by_symbol) { + const ffi::Map& const_var_tensor, + const ffi::Map>& const_vars_by_symbol) { static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.const_loader"); TVM_FFI_CHECK(fcreate.has_value(), RuntimeError) << "ffi.Module.create.const_loader is not registered in runtime. " << "Ensure libtvm_runtime is loaded."; - // Convert to FFI-compatible types. - ffi::Map ffi_const_var_tensor; - for (const auto& kv : const_var_tensor) { - ffi_const_var_tensor.Set(kv.first, kv.second); - } - ffi::Map> ffi_const_vars_by_symbol; - for (const auto& kv : const_vars_by_symbol) { - ffi::Array vars; - for (const auto& v : kv.second) { - vars.push_back(ffi::String(v)); - } - ffi_const_vars_by_symbol.Set(kv.first, vars); - } - return (*fcreate)(ffi_const_var_tensor, ffi_const_vars_by_symbol).cast(); + return (*fcreate)(const_var_tensor, const_vars_by_symbol).cast(); } } // namespace runtime diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h index fe9454f674d1..3f4b3965adc5 100644 --- a/src/runtime/metal/metal_module.h +++ b/src/runtime/metal/metal_module.h @@ -28,11 +28,6 @@ #include #include -#include -#include -#include -#include - #include "../metadata.h" namespace tvm { diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 6697badd4885..9d16ea9231b2 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -28,10 +28,8 @@ #include #include -#include #include #include -#include #include "../../support/bytes_io.h" #include "../metadata.h" @@ -74,7 +72,7 @@ inline ffi::Module OpenCLModuleCreate(ffi::String data, ffi::String fmt, */ inline ffi::Module OpenCLModuleCreate( const std::unordered_map& shaders, - const std::string& spirv_text, ffi::Map fmap) { + const ffi::String& spirv_text, const ffi::Map& fmap) { static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.opencl.spirv"); TVM_FFI_CHECK(fcreate.has_value(), RuntimeError) << "ffi.Module.create.opencl.spirv is not registered in runtime. " @@ -87,7 +85,7 @@ inline ffi::Module OpenCLModuleCreate( strm.Write(kv.second); shader_bytes.Set(kv.first, ffi::Bytes(std::move(buf))); } - return (*fcreate)(shader_bytes, ffi::String(spirv_text), fmap).cast(); + return (*fcreate)(shader_bytes, spirv_text, fmap).cast(); } } // namespace runtime } // namespace tvm diff --git a/src/runtime/vulkan/vulkan_module.h b/src/runtime/vulkan/vulkan_module.h index 87df473753d4..d8fdda4d9251 100644 --- a/src/runtime/vulkan/vulkan_module.h +++ b/src/runtime/vulkan/vulkan_module.h @@ -48,9 +48,9 @@ namespace vulkan { * and rehydrated on the runtime side. * Requires libtvm_runtime built with USE_VULKAN=ON to have registered the creator. */ -inline ffi::Module VulkanModuleCreate(std::unordered_map smap, - ffi::Map fmap, - std::string source) { +inline ffi::Module VulkanModuleCreate(const std::unordered_map& smap, + const ffi::Map& fmap, + const ffi::String& source) { static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.vulkan"); TVM_FFI_CHECK(fcreate.has_value(), RuntimeError) << "ffi.Module.create.vulkan is not registered in runtime. " @@ -63,7 +63,7 @@ inline ffi::Module VulkanModuleCreate(std::unordered_map(); + return (*fcreate)(shader_bytes, fmap, source).cast(); } } // namespace vulkan