Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,30 +465,20 @@ void LinkModules(ObjectPtr<VMExecutable> exec, const ffi::Map<ffi::String, runti
const tvm::ffi::Module& lib, const ffi::Array<ffi::Module>& ext_libs) {
// query if we need const loader for ext_modules
// Wrap all submodules in the initialization wrapper.
std::unordered_map<std::string, std::vector<std::string>> const_vars_by_symbol;
ffi::Map<ffi::String, ffi::Array<ffi::String>> const_vars_by_symbol;
for (tvm::ffi::Module mod : ext_libs) {
auto pf_sym = mod->GetFunction("get_symbol");
auto pf_var = mod->GetFunction("get_const_vars");
std::vector<std::string> symbol_const_vars;
if (pf_sym.has_value() && pf_var.has_value()) {
ffi::String symbol = (*pf_sym)().cast<ffi::String>();
ffi::Array<ffi::String> variables = (*pf_var)().cast<ffi::Array<ffi::String>>();
for (size_t i = 0; i < variables.size(); i++) {
symbol_const_vars.push_back(variables[i].operator std::string());
}
TVM_FFI_ICHECK_EQ(const_vars_by_symbol.count(symbol), 0U)
<< "Found duplicated symbol: " << symbol;
const_vars_by_symbol[symbol] = symbol_const_vars;
TVM_FFI_ICHECK(!const_vars_by_symbol.count(symbol)) << "Found duplicated symbol: " << symbol;
const_vars_by_symbol.Set(symbol, variables);
}
}
if (!const_vars_by_symbol.empty() || !params.empty()) {
// need runtime const information, run link const loader
std::unordered_map<std::string, runtime::Tensor> const_var_tensor;
for (const auto& [name, param] : params) {
const_var_tensor[name] = param;
}
ffi::Module const_loader_mod =
runtime::ConstLoaderModuleCreate(const_var_tensor, const_vars_by_symbol);
ffi::Module const_loader_mod = runtime::ConstLoaderModuleCreate(params, const_vars_by_symbol);
const_loader_mod->ImportModule(lib);
for (const auto& it : ext_libs) {
const_loader_mod->ImportModule(it);
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/const_loader_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
#include <tvm/support/io.h>

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

#include "../support/bytes_io.h"

Expand Down
24 changes: 4 additions & 20 deletions src/runtime/const_loader_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,10 @@
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/string.h>
#include <tvm/runtime/base.h>
#include <tvm/runtime/tensor.h>

#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace runtime {

Expand All @@ -52,26 +49,13 @@ namespace runtime {
* The creator is always available (ConstLoaderModule is a runtime-universal module).
*/
inline ffi::Module ConstLoaderModuleCreate(
const std::unordered_map<std::string, Tensor>& const_var_tensor,
const std::unordered_map<std::string, std::vector<std::string>>& const_vars_by_symbol) {
const ffi::Map<ffi::String, Tensor>& const_var_tensor,
const ffi::Map<ffi::String, ffi::Array<ffi::String>>& const_vars_by_symbol) {
static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.const_loader");
TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
<< "ffi.Module.create.const_loader is not registered in runtime. "
<< "Ensure libtvm_runtime is loaded.";
// Convert to FFI-compatible types.
ffi::Map<ffi::String, Tensor> ffi_const_var_tensor;
for (const auto& kv : const_var_tensor) {
ffi_const_var_tensor.Set(kv.first, kv.second);
}
ffi::Map<ffi::String, ffi::Array<ffi::String>> ffi_const_vars_by_symbol;
for (const auto& kv : const_vars_by_symbol) {
ffi::Array<ffi::String> vars;
for (const auto& v : kv.second) {
vars.push_back(ffi::String(v));
}
ffi_const_vars_by_symbol.Set(kv.first, vars);
}
return (*fcreate)(ffi_const_var_tensor, ffi_const_vars_by_symbol).cast<ffi::Module>();
return (*fcreate)(const_var_tensor, const_vars_by_symbol).cast<ffi::Module>();
}

} // namespace runtime
Expand Down
5 changes: 0 additions & 5 deletions src/runtime/metal/metal_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/function.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "../metadata.h"

namespace tvm {
Expand Down
6 changes: 2 additions & 4 deletions src/runtime/opencl/opencl_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@
#include <tvm/ffi/function.h>
#include <tvm/runtime/base.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "../../support/bytes_io.h"
#include "../metadata.h"
Expand Down Expand Up @@ -74,7 +72,7 @@ inline ffi::Module OpenCLModuleCreate(ffi::String data, ffi::String fmt,
*/
inline ffi::Module OpenCLModuleCreate(
const std::unordered_map<std::string, spirv::SPIRVShader>& shaders,
const std::string& spirv_text, ffi::Map<ffi::String, FunctionInfo> fmap) {
const ffi::String& spirv_text, const ffi::Map<ffi::String, FunctionInfo>& fmap) {
static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.opencl.spirv");
TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
<< "ffi.Module.create.opencl.spirv is not registered in runtime. "
Expand All @@ -87,7 +85,7 @@ inline ffi::Module OpenCLModuleCreate(
strm.Write(kv.second);
shader_bytes.Set(kv.first, ffi::Bytes(std::move(buf)));
}
return (*fcreate)(shader_bytes, ffi::String(spirv_text), fmap).cast<ffi::Module>();
return (*fcreate)(shader_bytes, spirv_text, fmap).cast<ffi::Module>();
}
} // namespace runtime
} // namespace tvm
Expand Down
8 changes: 4 additions & 4 deletions src/runtime/vulkan/vulkan_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ namespace vulkan {
* and rehydrated on the runtime side.
* Requires libtvm_runtime built with USE_VULKAN=ON to have registered the creator.
*/
inline ffi::Module VulkanModuleCreate(std::unordered_map<std::string, SPIRVShader> smap,
ffi::Map<ffi::String, FunctionInfo> fmap,
std::string source) {
inline ffi::Module VulkanModuleCreate(const std::unordered_map<std::string, SPIRVShader>& smap,
const ffi::Map<ffi::String, FunctionInfo>& fmap,
const ffi::String& source) {
static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.vulkan");
TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
<< "ffi.Module.create.vulkan is not registered in runtime. "
Expand All @@ -63,7 +63,7 @@ inline ffi::Module VulkanModuleCreate(std::unordered_map<std::string, SPIRVShade
strm.Write(kv.second);
shader_bytes.Set(kv.first, ffi::Bytes(std::move(buf)));
}
return (*fcreate)(shader_bytes, fmap, ffi::String(source)).cast<ffi::Module>();
return (*fcreate)(shader_bytes, fmap, source).cast<ffi::Module>();
}

} // namespace vulkan
Expand Down
Loading