diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index 9fd7a7428588..cfa4143b66c3 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -252,7 +252,7 @@ def _vmlink( runtime=_autodetect_system_lib_req(target, system_lib), ) for ext_mod in ext_libs: - if ext_mod.type_key == "cuda": + if ext_mod.is_device_module: tir_ext_libs.append(ext_mod) else: relax_ext_libs.append(ext_mod) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 2c3eff700009..ca151293bbbd 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -274,6 +274,10 @@ def is_runnable(self): """ return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0 + @property + def is_device_module(self): + return self.type_key in ["cuda", "opencl", "metal", "hip", "vulkan", "webgpu"] + @property def is_dso_exportable(self): """Returns true if module is 'DSO exportable', ie can be included in result of