@@ -132,22 +132,22 @@ def _get_cuda_arch_flags(is_gemm: bool = False) -> Tuple[List[str], List[Tuple[i
132132 _arch_list = "3.7;5.0;5.2;6.0;6.1;7.0;7.5+PTX"
133133 elif (major , minor ) < (11 , 8 ):
134134 _arch_list = "5.2;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
135- elif (major , minor ) < (12 , 0 ):
135+ elif (major , minor ) < (12 , 8 ):
136136 _arch_list = "6.0;7.0;7.5;8.0;8.6;8.9;9.0+PTX"
137137 else :
138138 # remove sm < 70 prebuilt gemm kernels in CUDA 12.
139139 # these gemm kernels will be compiled via nvrtc.
140- _arch_list = "6.0;7.0;7. 5;8.0;8.6;8.9;9.0+PTX"
140+ _arch_list = "7. 5;8.0;8.6;8.9;9.0;10.0;12 .0+PTX"
141141 else :
142142 # flag for non-gemm kernels, they are usually simple and small.
143143 if (major , minor ) < (11 , 0 ):
144144 _arch_list = "3.5;3.7;5.0;5.2;6.0;6.1;7.0;7.5+PTX"
145145 elif (major , minor ) < (11 , 8 ):
146146 _arch_list = "3.5;3.7;5.0;5.2;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
147- elif (major , minor ) < (12 , 0 ):
147+ elif (major , minor ) < (12 , 8 ):
148148 _arch_list = "5.0;5.2;6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0+PTX"
149149 else :
150- _arch_list = "5.0;5.2;6.0;6.1;7.0;7. 5;8.0;8.6;8.9;9.0+PTX"
150+ _arch_list = "7. 5;8.0;8.6;8.9;9.0;10.0;12 .0+PTX"
151151 _all_arch = "5.2;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
152152 for named_arch , archval in named_arches .items ():
153153 _all_arch = _all_arch .replace (named_arch , archval )
0 commit comments