From 3e2c503523ab5b7ca1a35ae3f9c7f200a86aff12 Mon Sep 17 00:00:00 2001 From: Amod Wani Date: Mon, 16 Dec 2024 17:39:49 +0530 Subject: [PATCH 1/4] mulyi gpu example --- .../examples/simple_multi_gpu_example.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 cuda_core/examples/simple_multi_gpu_example.py diff --git a/cuda_core/examples/simple_multi_gpu_example.py b/cuda_core/examples/simple_multi_gpu_example.py new file mode 100644 index 00000000000..7a4f0c67e8f --- /dev/null +++ b/cuda_core/examples/simple_multi_gpu_example.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# coding: utf-8 + +from cuda.core.experimental import Device +from cuda import cuda, cudart, nvrtc +from cuda.core.experimental import LaunchConfig, launch +from cuda.core.experimental import Program +from cuda.core.experimental._memory import Buffer + +import cupy as cp + +dtype = cp.float32 +size = 50000 + +# Set GPU0 +dev0 = Device(0) +dev0.set_current() +stream0 = dev0.create_stream() + +# allocate memory to GPU0 +a = cp.random.random(size, dtype=dtype) +b = cp.random.random(size, dtype=dtype) +c = cp.empty_like(a) + +# Set GPU1 +dev1 = Device(1) +dev1.set_current() +stream1 = dev1.create_stream() + +# allocate memory to GPU1 +x = cp.random.random(size, dtype=dtype) +y = cp.random.random(size, dtype=dtype) +z = cp.empty_like(a) + +# compute c = a + b +code_add = """ +template +__global__ void vector_add(const T* A, + const T* B, + T* C, + size_t N) { + const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (size_t i=tid; i +__global__ void vector_sub(const T* A, + const T* B, + T* C, + size_t N) { + const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (size_t i=tid; i",)) + +# run in single precision +ker_add = mod_add.get_kernel("vector_add") + +prog_sub = Program(code_sub, code_type="c++") +mod_sub = prog_sub.compile( + "cubin", + options=("-std=c++17", "-arch=sm_" + "".join(f"{i}" for i in dev1.compute_capability),), + name_expressions=("vector_sub",)) + +# run in single precision +ker_sub = mod_sub.get_kernel("vector_sub") + + +# Synchronize devices to ensure that memory has been created +dev0.sync() +dev1.sync() + +block = 256 +grid0 = (size + block - 1) // block +grid1 = (size + block - 1) // block + +config0 = LaunchConfig(grid=grid0, block=block, stream=stream0) +config1 = LaunchConfig(grid=grid1, block=block, stream=stream1) + +# First we update device 0 data with host data +dev0.set_current() +launch(ker_add, config0, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size)) +stream0.sync() +# Validate result +assert cp.allclose(c, a + b) + +dev1.set_current() +launch(ker_sub, config1, x.data.ptr, y.data.ptr, z.data.ptr, cp.uint64(size)) +stream1.sync() +assert cp.allclose(z, x -y) + +print('done') From 0d60da3bed53cd95a33f7433c205bbfec97644a2 Mon Sep 17 00:00:00 2001 From: Amod Wani Date: Tue, 17 Dec 2024 14:06:28 +0530 Subject: [PATCH 2/4] Cleanup and linted --- cuda_core/docs/source/conf.py | 4 +- .../examples/simple_multi_gpu_example.py | 69 ++++++++++--------- 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/cuda_core/docs/source/conf.py b/cuda_core/docs/source/conf.py index 4b3e17aea3d..ca59d588b92 100644 --- a/cuda_core/docs/source/conf.py +++ b/cuda_core/docs/source/conf.py @@ -94,9 +94,11 @@ section_titles = ["Returns"] + + def autodoc_process_docstring(app, what, name, obj, options, lines): if name.startswith("cuda.core.experimental.system"): - # patch the docstring (in lines) *in-place*. Should docstrings include section titles other than "Returns", + # patch the docstring (in lines) *in-place*. Should docstrings include section titles other than "Returns", # this will need to be modified to handle them. attr = name.split(".")[-1] from cuda.core.experimental._system import System diff --git a/cuda_core/examples/simple_multi_gpu_example.py b/cuda_core/examples/simple_multi_gpu_example.py index 7a4f0c67e8f..1b778bae0f8 100644 --- a/cuda_core/examples/simple_multi_gpu_example.py +++ b/cuda_core/examples/simple_multi_gpu_example.py @@ -1,14 +1,11 @@ -#!/usr/bin/env python -# coding: utf-8 - -from cuda.core.experimental import Device -from cuda import cuda, cudart, nvrtc -from cuda.core.experimental import LaunchConfig, launch -from cuda.core.experimental import Program -from cuda.core.experimental._memory import Buffer +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE import cupy as cp +from cuda.core.experimental import Device, LaunchConfig, Program, launch + dtype = cp.float32 size = 50000 @@ -17,7 +14,7 @@ dev0.set_current() stream0 = dev0.create_stream() -# allocate memory to GPU0 +# Allocate memory to GPU0 a = cp.random.random(size, dtype=dtype) b = cp.random.random(size, dtype=dtype) c = cp.empty_like(a) @@ -27,17 +24,17 @@ dev1.set_current() stream1 = dev1.create_stream() -# allocate memory to GPU1 +# Allocate memory to GPU1 x = cp.random.random(size, dtype=dtype) y = cp.random.random(size, dtype=dtype) z = cp.empty_like(a) # compute c = a + b code_add = """ -template -__global__ void vector_add(const T* A, - const T* B, - T* C, +extern "C" +__global__ void vector_add(const float* A, + const float* B, + float* C, size_t N) { const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; for (size_t i=tid; i -__global__ void vector_sub(const T* A, - const T* B, - T* C, +extern "C" +__global__ void vector_sub(const *float A, + const float* B, + float* C, size_t N) { const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; for (size_t i=tid; i",)) + options=( + "-std=c++17", + "-arch=sm_" + arch0, + ), +) # run in single precision -ker_add = mod_add.get_kernel("vector_add") +ker_add = mod_add.get_kernel("vector_add") +arch1 = "".join(f"{i}" for i in dev1.compute_capability) prog_sub = Program(code_sub, code_type="c++") mod_sub = prog_sub.compile( "cubin", - options=("-std=c++17", "-arch=sm_" + "".join(f"{i}" for i in dev1.compute_capability),), - name_expressions=("vector_sub",)) + options=( + "-std=c++17", + "-arch=sm_" + arch1, + ), +) # run in single precision -ker_sub = mod_sub.get_kernel("vector_sub") - +ker_sub = mod_sub.get_kernel("vector_sub") # Synchronize devices to ensure that memory has been created dev0.sync() dev1.sync() block = 256 -grid0 = (size + block - 1) // block -grid1 = (size + block - 1) // block +grid = (size + block - 1) // block -config0 = LaunchConfig(grid=grid0, block=block, stream=stream0) -config1 = LaunchConfig(grid=grid1, block=block, stream=stream1) +config0 = LaunchConfig(grid=grid, block=block, stream=stream0) +config1 = LaunchConfig(grid=grid, block=block, stream=stream1) -# First we update device 0 data with host data +# Launch GPU0 and Synchronize the stream dev0.set_current() launch(ker_add, config0, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size)) stream0.sync() + # Validate result assert cp.allclose(c, a + b) +# Launch GPU1 and Synchronize the stream dev1.set_current() launch(ker_sub, config1, x.data.ptr, y.data.ptr, z.data.ptr, cp.uint64(size)) stream1.sync() -assert cp.allclose(z, x -y) +assert cp.allclose(z, x - y) -print('done') +print("done") From 401491aec558353110e43bd7ac40fbe0b5d2b793 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 2 Jan 2025 17:35:41 +0000 Subject: [PATCH 3/4] avoid unnecessary sync --- .../examples/simple_multi_gpu_example.py | 116 +++++++++++------- 1 file changed, 70 insertions(+), 46 deletions(-) diff --git a/cuda_core/examples/simple_multi_gpu_example.py b/cuda_core/examples/simple_multi_gpu_example.py index 1b778bae0f8..89dd6cbea63 100644 --- a/cuda_core/examples/simple_multi_gpu_example.py +++ b/cuda_core/examples/simple_multi_gpu_example.py @@ -2,34 +2,25 @@ # # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE +import sys + import cupy as cp -from cuda.core.experimental import Device, LaunchConfig, Program, launch +from cuda.core.experimental import Device, LaunchConfig, Program, launch, system + +if system.num_devices < 2: + print("this example requires at least 2 GPUs", file=sys.stderr) + sys.exit(0) dtype = cp.float32 size = 50000 -# Set GPU0 +# Set GPU 0 dev0 = Device(0) dev0.set_current() stream0 = dev0.create_stream() -# Allocate memory to GPU0 -a = cp.random.random(size, dtype=dtype) -b = cp.random.random(size, dtype=dtype) -c = cp.empty_like(a) - -# Set GPU1 -dev1 = Device(1) -dev1.set_current() -stream1 = dev1.create_stream() - -# Allocate memory to GPU1 -x = cp.random.random(size, dtype=dtype) -y = cp.random.random(size, dtype=dtype) -z = cp.empty_like(a) - -# compute c = a + b +# Compile a kernel targeting GPU 0 to compute c = a + b code_add = """ extern "C" __global__ void vector_add(const float* A, @@ -42,11 +33,26 @@ } } """ +arch0 = "".join(f"{i}" for i in dev0.compute_capability) +prog_add = Program(code_add, code_type="c++") +mod_add = prog_add.compile( + "cubin", + options=( + "-std=c++17", + "-arch=sm_" + arch0, + ), +) +ker_add = mod_add.get_kernel("vector_add") -# compute c = a - b +# Set GPU 1 +dev1 = Device(1) +dev1.set_current() +stream1 = dev1.create_stream() + +# Compile a kernel targeting GPU 1 to compute c = a - b code_sub = """ extern "C" -__global__ void vector_sub(const *float A, +__global__ void vector_sub(const float* A, const float* B, float* C, size_t N) { @@ -56,20 +62,6 @@ } } """ - -arch0 = "".join(f"{i}" for i in dev0.compute_capability) -prog_add = Program(code_add, code_type="c++") -mod_add = prog_add.compile( - "cubin", - options=( - "-std=c++17", - "-arch=sm_" + arch0, - ), -) - -# run in single precision -ker_add = mod_add.get_kernel("vector_add") - arch1 = "".join(f"{i}" for i in dev1.compute_capability) prog_sub = Program(code_sub, code_type="c++") mod_sub = prog_sub.compile( @@ -79,31 +71,63 @@ "-arch=sm_" + arch1, ), ) - -# run in single precision ker_sub = mod_sub.get_kernel("vector_sub") -# Synchronize devices to ensure that memory has been created -dev0.sync() -dev1.sync() +# This adaptor ensures that any foreign stream (ex: from CuPy) that have not +# yet supported the __cuda_stream__ protocol can still be recognized by +# cuda.core. +class StreamAdaptor: + def __init__(self, obj): + self.obj = obj + + @property + def __cuda_stream__(self): + # Note: CuPy streams have a .ptr attribute + return (0, self.obj.ptr) + + +# Create launch configs for each kernel that will be executed on the respective +# CUDA streams. block = 256 grid = (size + block - 1) // block - config0 = LaunchConfig(grid=grid, block=block, stream=stream0) config1 = LaunchConfig(grid=grid, block=block, stream=stream1) -# Launch GPU0 and Synchronize the stream +# Allocate memory on GPU 0 +# Note: This runs on CuPy's current stream for GPU 0 dev0.set_current() -launch(ker_add, config0, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size)) -stream0.sync() +a = cp.random.random(size, dtype=dtype) +b = cp.random.random(size, dtype=dtype) +c = cp.empty_like(a) +cp_stream0 = StreamAdaptor(cp.cuda.get_current_stream()) -# Validate result -assert cp.allclose(c, a + b) +# Establish a stream order to ensure that memory has been initialized before +# accessed by the kernel. +stream0.wait(cp_stream0) -# Launch GPU1 and Synchronize the stream +# Launch the add kernel on GPU 0 / stream 0 +launch(ker_add, config0, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size)) + +# Allocate memory on GPU 1 +# Note: This runs on CuPy's current stream for GPU 1. dev1.set_current() +x = cp.random.random(size, dtype=dtype) +y = cp.random.random(size, dtype=dtype) +z = cp.empty_like(a) +cp_stream1 = StreamAdaptor(cp.cuda.get_current_stream()) + +# Establish a stream order +stream1.wait(cp_stream1) + +# Launch the subtract kernel on GPU 1 / stream 1 launch(ker_sub, config1, x.data.ptr, y.data.ptr, z.data.ptr, cp.uint64(size)) + +# Synchronize both GPUs are validate the results +dev0.set_current() +stream0.sync() +assert cp.allclose(c, a + b) +dev1.set_current() stream1.sync() assert cp.allclose(z, x - y) From fde06d591082236a045daf70d09c0c0edf6dea01 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 2 Jan 2025 18:07:37 -0500 Subject: [PATCH 4/4] fix TLS delete condition --- cuda_core/tests/conftest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py index 30d80f6f820..dc50585ab84 100644 --- a/cuda_core/tests/conftest.py +++ b/cuda_core/tests/conftest.py @@ -41,8 +41,9 @@ def _device_unset_current(): # no active context, do nothing return handle_return(driver.cuCtxPopCurrent()) - with _device._tls_lock: - del _device._tls.devices + if hasattr(_device._tls, "devices"): + with _device._tls_lock: + del _device._tls.devices @pytest.fixture(scope="function")