From 28674c6fde134860827e6be1c667f38ce8daa7c0 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 23 Aug 2025 08:27:34 -0400 Subject: [PATCH] [ROCm] Minor fixes for latest refactor This PR fixes a few ROCm and hipBLAS build issues after recent refactors. --- src/runtime/contrib/hipblas/hipblas_json_runtime.cc | 4 +++- src/runtime/contrib/hipblas/hipblas_utils.cc | 5 +++-- src/runtime/contrib/hipblas/hipblas_utils.h | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index ab8545561be4..08866fc1088a 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -22,6 +22,7 @@ * \brief A simple JSON runtime for HIPBLAS. */ +#include #include #include #include @@ -30,6 +31,7 @@ #include #include +#include "../../rocm/rocm_common.h" #include "../json/json_node.h" #include "../json/json_runtime.h" #include "hipblas_utils.h" @@ -86,7 +88,7 @@ class HipblasJSONRuntime : public JSONRuntimeBase { if (device_id == -1) { ROCM_CALL(hipGetDevice(&device_id)); } - auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(Device(kDLROCM, device_id)); + auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(DLDevice{kDLROCM, device_id}); hipStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLROCM, device_id)); auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) { diff --git a/src/runtime/contrib/hipblas/hipblas_utils.cc b/src/runtime/contrib/hipblas/hipblas_utils.cc index 454ab7a3707e..1b61cbd38219 100644 --- a/src/runtime/contrib/hipblas/hipblas_utils.cc +++ b/src/runtime/contrib/hipblas/hipblas_utils.cc @@ -23,6 +23,7 @@ #include "hipblas_utils.h" #include +#include #include #include "../../rocm/rocm_common.h" @@ -41,7 +42,7 @@ HipBlasThreadEntry::~HipBlasThreadEntry() { typedef dmlc::ThreadLocalStore HipBlasThreadStore; -HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal(Device curr_device) { +HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal(DLDevice curr_device) { HipBlasThreadEntry* retval = HipBlasThreadStore::Get(); TVMFFIStreamHandle stream = TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id); @@ -72,7 +73,7 @@ HipBlasLtThreadEntry::~HipBlasLtThreadEntry() { typedef dmlc::ThreadLocalStore HipBlasLtThreadStore; -HipBlasLtThreadEntry* HipBlasLtThreadEntry::ThreadLocal(Device curr_device) { +HipBlasLtThreadEntry* HipBlasLtThreadEntry::ThreadLocal(DLDevice curr_device) { return HipBlasLtThreadStore::Get(); } diff --git a/src/runtime/contrib/hipblas/hipblas_utils.h b/src/runtime/contrib/hipblas/hipblas_utils.h index 66d7afafbd64..d07e825c21c8 100644 --- a/src/runtime/contrib/hipblas/hipblas_utils.h +++ b/src/runtime/contrib/hipblas/hipblas_utils.h @@ -68,7 +68,7 @@ struct HipBlasThreadEntry { HipBlasThreadEntry(); ~HipBlasThreadEntry(); hipblasHandle_t handle{nullptr}; - static HipBlasThreadEntry* ThreadLocal(); + static HipBlasThreadEntry* ThreadLocal(DLDevice curr_device); }; // HipBlasThreadEntry struct HipBlasLtThreadEntry { @@ -82,7 +82,7 @@ struct HipBlasLtThreadEntry { // https://docs.nvidia.com/cuda/cublas/index.html#cublassetworkspace. static constexpr const size_t workspace_size = 33554432; - static HipBlasLtThreadEntry* ThreadLocal(); + static HipBlasLtThreadEntry* ThreadLocal(DLDevice curr_device); }; // HipBlasLtThreadEntry inline hipDataType GetHipDataType(DLDataType type) {