Skip to content

Commit 1c27844

Browse files
prepare cuda 12.8
1 parent d65cc41 commit 1c27844

2 files changed

Lines changed: 7 additions & 7 deletions

File tree

.github/workflows/build.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
strategy:
1515
matrix:
1616
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
17-
cuda-version: ['11.4', '11.8', '12.1', '12.4', '12.6', '']
17+
cuda-version: ['11.4', '11.8', '12.1', '12.4', '12.6', '12.8', '']
1818
steps:
1919
- uses: actions/checkout@master
2020
- name: Install CUDA
@@ -91,7 +91,7 @@ jobs:
9191
strategy:
9292
matrix:
9393
python-version: ['3.12'] # this version is only used for upload.
94-
cuda-version: ['114', '118', '121', '124', '126', '']
94+
cuda-version: ['114', '118', '121', '124', '126', '128', '']
9595

9696
steps:
9797
- uses: actions/checkout@master
@@ -112,7 +112,7 @@ jobs:
112112
PYTHON_VERSION: ${{ matrix.python-version }}
113113
DOCKER_IMAGE: scrin/manylinux2014-cuda:cu${{ matrix.cuda-version }}-devel-1.0.0
114114
PLAT: ${{ matrix.cuda-version > '123' && 'manylinux_2_28_x86_64' || 'manylinux2014_x86_64' }}
115-
if: (github.event_name == 'push' && (startsWith(github.ref, 'refs/tags')) && (env.CUDA_VERSION != '') ) || env.CUDA_VERSION == '126'
115+
if: (github.event_name == 'push' && (startsWith(github.ref, 'refs/tags')) && (env.CUDA_VERSION != '') ) || env.CUDA_VERSION == '128'
116116
run: |
117117
# clone nvidia cuda cccl to third_party/
118118
if [ $CUDA_VERSION -lt "120" ]; then

cumm/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,22 +132,22 @@ def _get_cuda_arch_flags(is_gemm: bool = False) -> Tuple[List[str], List[Tuple[i
132132
_arch_list = "3.7;5.0;5.2;6.0;6.1;7.0;7.5+PTX"
133133
elif (major, minor) < (11, 8):
134134
_arch_list = "5.2;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
135-
elif (major, minor) < (12, 0):
135+
elif (major, minor) < (12, 8):
136136
_arch_list = "6.0;7.0;7.5;8.0;8.6;8.9;9.0+PTX"
137137
else:
138138
# remove sm < 70 prebuilt gemm kernels in CUDA 12.
139139
# these gemm kernels will be compiled via nvrtc.
140-
_arch_list = "6.0;7.0;7.5;8.0;8.6;8.9;9.0+PTX"
140+
_arch_list = "7.5;8.0;8.6;8.9;9.0;10.0;12.0+PTX"
141141
else:
142142
# flag for non-gemm kernels, they are usually simple and small.
143143
if (major, minor) < (11, 0):
144144
_arch_list = "3.5;3.7;5.0;5.2;6.0;6.1;7.0;7.5+PTX"
145145
elif (major, minor) < (11, 8):
146146
_arch_list = "3.5;3.7;5.0;5.2;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
147-
elif (major, minor) < (12, 0):
147+
elif (major, minor) < (12, 8):
148148
_arch_list = "5.0;5.2;6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0+PTX"
149149
else:
150-
_arch_list = "5.0;5.2;6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0+PTX"
150+
_arch_list = "7.5;8.0;8.6;8.9;9.0;10.0;12.0+PTX"
151151
_all_arch = "5.2;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
152152
for named_arch, archval in named_arches.items():
153153
_all_arch = _all_arch.replace(named_arch, archval)

0 commit comments

Comments
 (0)