Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ class DeviceAPIManager {
// Constructor: every entry of api_ starts out as nullptr.
DeviceAPIManager() {
  for (auto& slot : api_) {
    slot = nullptr;
  }
}
// Global static variable.
static DeviceAPIManager* Global() {
static DeviceAPIManager inst;
return &inst;
static DeviceAPIManager* inst = new DeviceAPIManager();
return inst;
}
// Get or initialize API.
DeviceAPI* GetAPI(int type, bool allow_missing) {
Expand Down
8 changes: 5 additions & 3 deletions src/runtime/cpu_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ class CPUDeviceAPI final : public DeviceAPI {
void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final;

static const std::shared_ptr<CPUDeviceAPI>& Global() {
static std::shared_ptr<CPUDeviceAPI> inst = std::make_shared<CPUDeviceAPI>();
static CPUDeviceAPI* Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static auto* inst = new CPUDeviceAPI();
return inst;
}
};
Expand All @@ -99,7 +101,7 @@ void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
}

TVM_REGISTER_GLOBAL("device_api.cpu").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = CPUDeviceAPI::Global().get();
DeviceAPI* ptr = CPUDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
Expand Down
10 changes: 6 additions & 4 deletions src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,10 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}

static const std::shared_ptr<CUDADeviceAPI>& Global() {
static std::shared_ptr<CUDADeviceAPI> inst = std::make_shared<CUDADeviceAPI>();
static CUDADeviceAPI* Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static auto* inst = new CUDADeviceAPI();
return inst;
}

Expand All @@ -230,12 +232,12 @@ CUDAThreadEntry::CUDAThreadEntry() : pool(kDLGPU, CUDADeviceAPI::Global()) {}
// Returns the entry handed out by CUDAThreadStore::Get().
// NOTE(review): CUDAThreadStore is declared outside this hunk; presumably a thread-local store — confirm.
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
  CUDAThreadEntry* entry = CUDAThreadStore::Get();
  return entry;
}

TVM_REGISTER_GLOBAL("device_api.gpu").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get();
DeviceAPI* ptr = CUDADeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});

TVM_REGISTER_GLOBAL("device_api.cpu_pinned").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get();
DeviceAPI* ptr = CUDADeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
8 changes: 5 additions & 3 deletions src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ class HexagonDeviceAPI : public DeviceAPI {
void* AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint = {}) final;
void FreeWorkspace(TVMContext ctx, void* ptr) final;

static const std::shared_ptr<HexagonDeviceAPI>& Global() {
static std::shared_ptr<HexagonDeviceAPI> inst = std::make_shared<HexagonDeviceAPI>();
static HexagonDeviceAPI* Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static HexagonDeviceAPI* inst = new HexagonDeviceAPI();
return inst;
}
};
Expand Down Expand Up @@ -121,7 +123,7 @@ inline void HexagonDeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) {
}

TVM_REGISTER_GLOBAL("device_api.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = HexagonDeviceAPI::Global().get();
DeviceAPI* ptr = HexagonDeviceAPI::Global();
*rv = ptr;
});
} // namespace runtime
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/metal/metal_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class MetalWorkspace final : public DeviceAPI {
void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
// get the global workspace
static const std::shared_ptr<MetalWorkspace>& Global();
static MetalWorkspace* Global();
};

/*! \brief Thread local workspace */
Expand Down
8 changes: 5 additions & 3 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
namespace runtime {
namespace metal {

const std::shared_ptr<MetalWorkspace>& MetalWorkspace::Global() {
static std::shared_ptr<MetalWorkspace> inst = std::make_shared<MetalWorkspace>();
MetalWorkspace* MetalWorkspace::Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static MetalWorkspace* inst = new MetalWorkspace();
return inst;
}

Expand Down Expand Up @@ -273,7 +275,7 @@ int GetWarpSize(id<MTLDevice> dev) {
// Returns the entry handed out by MetalThreadStore::Get().
// NOTE(review): MetalThreadStore is declared outside this hunk; presumably a thread-local store — confirm.
MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
  MetalThreadEntry* entry = MetalThreadStore::Get();
  return entry;
}

TVM_REGISTER_GLOBAL("device_api.metal").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = MetalWorkspace::Global().get();
DeviceAPI* ptr = MetalWorkspace::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void SaveToBinary(dmlc::Stream* stream) final {
}
// get the pipeline state for func_name from the primary context of device_id
id<MTLComputePipelineState> GetPipelineState(size_t device_id, const std::string& func_name) {
metal::MetalWorkspace* w = metal::MetalWorkspace::Global().get();
metal::MetalWorkspace* w = metal::MetalWorkspace::Global();
CHECK_LT(device_id, w->devices.size());
// start lock scope.
std::lock_guard<std::mutex> lock(mutex_);
Expand Down Expand Up @@ -168,7 +168,7 @@ void SaveToBinary(dmlc::Stream* stream) final {
void Init(MetalModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
size_t num_buffer_args, size_t num_pack_args,
const std::vector<std::string>& thread_axis_tags) {
w_ = metal::MetalWorkspace::Global().get();
w_ = metal::MetalWorkspace::Global();
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/micro/micro_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ class MicroDeviceAPI final : public DeviceAPI {
* \brief obtain a global singleton of MicroDeviceAPI
* \return pointer to the global MicroDeviceAPI singleton
*/
static const std::shared_ptr<MicroDeviceAPI>& Global() {
static std::shared_ptr<MicroDeviceAPI> inst = std::make_shared<MicroDeviceAPI>();
static MicroDeviceAPI* Global() {
static MicroDeviceAPI* inst = new MicroDeviceAPI();
return inst;
}

Expand All @@ -155,7 +155,7 @@ class MicroDeviceAPI final : public DeviceAPI {

// register device that can be obtained from Python frontend
TVM_REGISTER_GLOBAL("device_api.micro_dev").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = MicroDeviceAPI::Global().get();
DeviceAPI* ptr = MicroDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/opencl/aocl/aocl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class AOCLWorkspace final : public OpenCLWorkspace {
bool IsOpenCLDevice(TVMContext ctx) final;
OpenCLThreadEntry* GetThreadEntry() final;
// get the global workspace
static const std::shared_ptr<OpenCLWorkspace>& Global();
static OpenCLWorkspace* Global();
};

/*! \brief Thread local workspace for AOCL */
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/opencl/aocl/aocl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ namespace cl {

// AOCL override: hand back the entry provided by AOCLThreadEntry::ThreadLocal().
OpenCLThreadEntry* AOCLWorkspace::GetThreadEntry() {
  OpenCLThreadEntry* entry = AOCLThreadEntry::ThreadLocal();
  return entry;
}

const std::shared_ptr<OpenCLWorkspace>& AOCLWorkspace::Global() {
static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<AOCLWorkspace>();
OpenCLWorkspace* AOCLWorkspace::Global() {
static OpenCLWorkspace* inst = new AOCLWorkspace();
return inst;
}

Expand All @@ -49,7 +49,7 @@ typedef dmlc::ThreadLocalStore<AOCLThreadEntry> AOCLThreadStore;
// Returns the AOCLThreadEntry held by AOCLThreadStore (a dmlc::ThreadLocalStore).
AOCLThreadEntry* AOCLThreadEntry::ThreadLocal() {
  AOCLThreadEntry* entry = AOCLThreadStore::Get();
  return entry;
}

TVM_REGISTER_GLOBAL("device_api.aocl").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = AOCLWorkspace::Global().get();
DeviceAPI* ptr = AOCLWorkspace::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
6 changes: 2 additions & 4 deletions src/runtime/opencl/aocl/aocl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,10 @@ class AOCLModuleNode : public OpenCLModuleNode {
explicit AOCLModuleNode(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: OpenCLModuleNode(data, fmt, fmap, source) {}
const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace() final;
cl::OpenCLWorkspace* GetGlobalWorkspace() final;
};

const std::shared_ptr<cl::OpenCLWorkspace>& AOCLModuleNode::GetGlobalWorkspace() {
return cl::AOCLWorkspace::Global();
}
cl::OpenCLWorkspace* AOCLModuleNode::GetGlobalWorkspace() { return cl::AOCLWorkspace::Global(); }

Module AOCLModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
Expand Down
9 changes: 4 additions & 5 deletions src/runtime/opencl/opencl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class OpenCLWorkspace : public DeviceAPI {
virtual OpenCLThreadEntry* GetThreadEntry();

// get the global workspace
static const std::shared_ptr<OpenCLWorkspace>& Global();
static OpenCLWorkspace* Global();
};

/*! \brief Thread local workspace */
Expand All @@ -265,8 +265,7 @@ class OpenCLThreadEntry {
/*! \brief workspace pool */
WorkspacePool pool;
// constructor
OpenCLThreadEntry(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device)
: pool(device_type, device) {
OpenCLThreadEntry(DLDeviceType device_type, DeviceAPI* device) : pool(device_type, device) {
context.device_id = 0;
context.device_type = device_type;
}
Expand Down Expand Up @@ -298,7 +297,7 @@ class OpenCLModuleNode : public ModuleNode {
/*!
* \brief Get the global workspace
*/
virtual const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace();
virtual cl::OpenCLWorkspace* GetGlobalWorkspace();

const char* type_key() const final { return workspace_->type_key.c_str(); }

Expand All @@ -315,7 +314,7 @@ class OpenCLModuleNode : public ModuleNode {
private:
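// The workspace; cached here because the destructor needs to use it.
// A raw pointer is enough: the global workspace is allocated with new and never
// deleted, so it cannot be torn down first by static destruction order.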
std::shared_ptr<cl::OpenCLWorkspace> workspace_;
cl::OpenCLWorkspace* workspace_;
// the binary data
std::string data_;
// The format
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/opencl/opencl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ namespace cl {

// Default implementation: hand back the entry provided by OpenCLThreadEntry::ThreadLocal().
OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() {
  OpenCLThreadEntry* entry = OpenCLThreadEntry::ThreadLocal();
  return entry;
}

const std::shared_ptr<OpenCLWorkspace>& OpenCLWorkspace::Global() {
static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<OpenCLWorkspace>();
OpenCLWorkspace* OpenCLWorkspace::Global() {
static OpenCLWorkspace* inst = new OpenCLWorkspace();
return inst;
}

Expand Down Expand Up @@ -276,7 +276,7 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
}

TVM_REGISTER_GLOBAL("device_api.opencl").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = OpenCLWorkspace::Global().get();
DeviceAPI* ptr = OpenCLWorkspace::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/opencl/opencl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class OpenCLWrappedFunc {
void Init(OpenCLModuleNode* m, ObjectPtr<Object> sptr, OpenCLModuleNode::KTRefEntry entry,
std::string func_name, std::vector<size_t> arg_size,
const std::vector<std::string>& thread_axis_tags) {
w_ = m->GetGlobalWorkspace().get();
w_ = m->GetGlobalWorkspace();
m_ = m;
sptr_ = sptr;
entry_ = entry;
Expand Down Expand Up @@ -110,7 +110,7 @@ OpenCLModuleNode::~OpenCLModuleNode() {
}
}

const std::shared_ptr<cl::OpenCLWorkspace>& OpenCLModuleNode::GetGlobalWorkspace() {
cl::OpenCLWorkspace* OpenCLModuleNode::GetGlobalWorkspace() {
return cl::OpenCLWorkspace::Global();
}

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/opencl/sdaccel/sdaccel_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class SDAccelWorkspace final : public OpenCLWorkspace {
bool IsOpenCLDevice(TVMContext ctx) final;
OpenCLThreadEntry* GetThreadEntry() final;
// get the global workspace
static const std::shared_ptr<OpenCLWorkspace>& Global();
static OpenCLWorkspace* Global();
};

/*! \brief Thread local workspace for SDAccel*/
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/opencl/sdaccel/sdaccel_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ namespace cl {

// SDAccel override: hand back the entry provided by SDAccelThreadEntry::ThreadLocal().
OpenCLThreadEntry* SDAccelWorkspace::GetThreadEntry() {
  OpenCLThreadEntry* entry = SDAccelThreadEntry::ThreadLocal();
  return entry;
}

const std::shared_ptr<OpenCLWorkspace>& SDAccelWorkspace::Global() {
static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<SDAccelWorkspace>();
OpenCLWorkspace* SDAccelWorkspace::Global() {
static OpenCLWorkspace* inst = new SDAccelWorkspace();
return inst;
}

Expand All @@ -47,7 +47,7 @@ typedef dmlc::ThreadLocalStore<SDAccelThreadEntry> SDAccelThreadStore;
// Returns the SDAccelThreadEntry held by SDAccelThreadStore (a dmlc::ThreadLocalStore).
SDAccelThreadEntry* SDAccelThreadEntry::ThreadLocal() {
  SDAccelThreadEntry* entry = SDAccelThreadStore::Get();
  return entry;
}

TVM_REGISTER_GLOBAL("device_api.sdaccel").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = SDAccelWorkspace::Global().get();
DeviceAPI* ptr = SDAccelWorkspace::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/opencl/sdaccel/sdaccel_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ class SDAccelModuleNode : public OpenCLModuleNode {
explicit SDAccelModuleNode(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: OpenCLModuleNode(data, fmt, fmap, source) {}
const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace() final;
cl::OpenCLWorkspace* GetGlobalWorkspace() final;
};

const std::shared_ptr<cl::OpenCLWorkspace>& SDAccelModuleNode::GetGlobalWorkspace() {
cl::OpenCLWorkspace* SDAccelModuleNode::GetGlobalWorkspace() {
return cl::SDAccelWorkspace::Global();
}

Expand Down
6 changes: 3 additions & 3 deletions src/runtime/rocm/rocm_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
ROCMThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}

static const std::shared_ptr<ROCMDeviceAPI>& Global() {
static std::shared_ptr<ROCMDeviceAPI> inst = std::make_shared<ROCMDeviceAPI>();
static ROCMDeviceAPI* Global() {
static ROCMDeviceAPI* inst = new ROCMDeviceAPI();
return inst;
}

Expand All @@ -197,7 +197,7 @@ ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {}
// Returns the entry handed out by ROCMThreadStore::Get().
// NOTE(review): ROCMThreadStore is declared outside this hunk; presumably a thread-local store — confirm.
ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
  ROCMThreadEntry* entry = ROCMThreadStore::Get();
  return entry;
}

TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
DeviceAPI* ptr = ROCMDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,8 @@ class VulkanDeviceAPI final : public DeviceAPI {
VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(ctx, data);
}

static const std::shared_ptr<VulkanDeviceAPI>& Global() {
static std::shared_ptr<VulkanDeviceAPI> inst = std::make_shared<VulkanDeviceAPI>();
static VulkanDeviceAPI* Global() {
static VulkanDeviceAPI* inst = new VulkanDeviceAPI();
return inst;
}

Expand Down Expand Up @@ -1159,7 +1159,7 @@ TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModul
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);

TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = VulkanDeviceAPI::Global().get();
DeviceAPI* ptr = VulkanDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
6 changes: 3 additions & 3 deletions src/runtime/workspace_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class WorkspacePool::Pool {
std::vector<Entry> allocated_;
};

WorkspacePool::WorkspacePool(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device)
WorkspacePool::WorkspacePool(DLDeviceType device_type, DeviceAPI* device)
: device_type_(device_type), device_(device) {}

WorkspacePool::~WorkspacePool() {
Expand All @@ -143,7 +143,7 @@ WorkspacePool::~WorkspacePool() {
TVMContext ctx;
ctx.device_type = device_type_;
ctx.device_id = static_cast<int>(i);
array_[i]->Release(ctx, device_.get());
array_[i]->Release(ctx, device_);
delete array_[i];
}
}
Expand All @@ -156,7 +156,7 @@ void* WorkspacePool::AllocWorkspace(TVMContext ctx, size_t size) {
if (array_[ctx.device_id] == nullptr) {
array_[ctx.device_id] = new Pool();
}
return array_[ctx.device_id]->Alloc(ctx, device_.get(), size);
return array_[ctx.device_id]->Alloc(ctx, device_, size);
}

void WorkspacePool::FreeWorkspace(TVMContext ctx, void* ptr) {
Expand Down
Loading