diff --git a/ffi/pyproject.toml b/ffi/pyproject.toml index 083a60fc3631..ab2a7f84dfc3 100644 --- a/ffi/pyproject.toml +++ b/ffi/pyproject.toml @@ -17,7 +17,7 @@ [project] name = "apache-tvm-ffi" -version = "0.1.0a6" +version = "0.1.0a7" description = "tvm ffi" authors = [{ name = "TVM FFI team" }] diff --git a/ffi/python/tvm_ffi/cython/function.pxi b/ffi/python/tvm_ffi/cython/function.pxi index a223da90cb7e..064473e134c4 100644 --- a/ffi/python/tvm_ffi/cython/function.pxi +++ b/ffi/python/tvm_ffi/cython/function.pxi @@ -24,41 +24,6 @@ except ImportError: torch = None -def load_torch_get_current_cuda_stream(): - """Create a faster get_current_cuda_stream for torch through cpp extension. - """ - source = """ - #include - - int64_t get_current_cuda_stream(int device_id) { - at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id); - // fast invariant, default stream is always 0 - if (stream.id() == 0) return 0; - // convert to cudaStream_t - return reinterpret_cast(static_cast(stream)); - } - """ - def fallback_get_current_cuda_stream(device_id): - """Fallback with python api""" - return torch.cuda.current_stream(device_id).cuda_stream - try: - from torch.utils import cpp_extension - result = cpp_extension.load_inline( - name="get_current_cuda_stream", - cpp_sources=[source], - cuda_sources=[], - extra_cflags=["-O3"], - extra_include_paths=cpp_extension.include_paths("cuda"), - functions=["get_current_cuda_stream"], - ) - return result.get_current_cuda_stream - except Exception: - return fallback_get_current_cuda_stream - - -torch_get_current_cuda_stream = None - - cdef inline object make_ret_small_str(TVMFFIAny result): """convert small string to return value.""" cdef TVMFFIByteArray bytes @@ -146,9 +111,8 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args, if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1: ctx_dev_type[0] = temp_dltensor.device.device_type ctx_dev_id[0] = temp_dltensor.device.device_id - if torch_get_current_cuda_stream is None: - torch_get_current_cuda_stream = load_torch_get_current_cuda_stream() - temp_ptr = torch_get_current_cuda_stream(temp_dltensor.device.device_id) + # This is an API that dynamo and other uses to get the raw stream from torch + temp_ptr = torch._C._cuda_getCurrentRawStream(temp_dltensor.device.device_id) ctx_stream[0] = temp_ptr temp_args.append(arg) elif hasattr(arg, "__dlpack__"):