From ac0c51c6639bb0c0c00b199f26c05a5826534d1d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 7 Jan 2022 20:30:27 +0900 Subject: [PATCH 01/22] add int8 type in library --- python/tvm/contrib/cutlass/library.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py index 08cdb323c126..0d58ea724ee8 100644 --- a/python/tvm/contrib/cutlass/library.py +++ b/python/tvm/contrib/cutlass/library.py @@ -28,36 +28,51 @@ class GeneratorTarget(enum.Enum): class DataType(enum.Enum): f16 = enum_auto() f32 = enum_auto() + s8 = enum_auto() + u8 = enum_auto() + s32 = enum_auto() ShortDataTypeNames = { DataType.f16: "h", DataType.f32: "s", + DataType.s32: 'i', + } DataTypeNames = { DataType.f16: "f16", DataType.f32: "f32", + DataType.s8: "s8", + DataType.u8: "u8", + DataType.s32: "s32", } DataTypeTag = { DataType.f16: "cutlass::half_t", DataType.f32: "float", + DataType.s8: "int8_t", + DataType.s32: "int32_t", + DataType.u8: "uint8_t", } DataTypeSize = { DataType.f16: 16, DataType.f32: 32, + DataType.u8: 8, + DataType.s8: 8, + DataType.s32: 32, } class MathOperation(enum.Enum): multiply_add = enum_auto() - + multiply_add_saturate = enum_auto() MathOperationTag = { MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd", + MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', } From 0e4e8e086eeeffbe28ea291d1dc38adf3f5ef2ef Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 7 Jan 2022 20:30:44 +0900 Subject: [PATCH 02/22] wip --- python/tvm/contrib/cutlass/gen_tensor_op.py | 69 ++++++++++++++------- 1 file changed, 47 insertions(+), 22 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 6632b159febd..96d86a9a7383 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -103,28 +103,53 @@ def get_tile_descriptions(math_inst): def generate_sm80_tensor_op_16816(out_dtype, op_creator): """Generate GEMM or Conv2D kernels for Ampere.""" assert out_dtype in ["float32", "float16"] - math_instructions = { - "float32": [ - MathInstruction( - [16, 8, 16], - DataType.f16, - DataType.f16, - DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add, - ) - ], - "float16": [ - MathInstruction( - [16, 8, 16], - DataType.f16, - DataType.f16, - DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add, - ) - ], - }[out_dtype] + if "float" in out_dtype: + math_instructions = { + "float32": [ + MathInstruction( + [16, 8, 16], + DataType.f16, + DataType.f16, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + "float16": [ + MathInstruction( + [16, 8, 16], + DataType.f16, + DataType.f16, + DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + }[out_dtype] + elif out_dtype == "int32": + # TODO: add input types + math_instructions = { + "int8": [ + MathInstruction( + [16, 8, 32], + DataType.s8, + DataType.s8, + DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_saturate, + ), + ], + "uint8": [ + MathInstruction( + [16, 8, 32], + DataType.u8, + DataType.u8, + DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_saturate, + ), + ], + }["int8"] alignment_constraints = [8, 4, 2] From 2070d2c93c7fccfdb5ddb9751656172fa864f039 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 8 Jan 2022 11:31:44 +0900 Subject: [PATCH 03/22] adding test and plumbing data and weight dtype --- python/tvm/contrib/cutlass/build.py | 8 ++ python/tvm/contrib/cutlass/gen_conv2d.py | 8 ++ python/tvm/contrib/cutlass/gen_tensor_op.py | 133 +++++++++++++------- python/tvm/contrib/cutlass/library.py | 6 +- python/tvm/relay/op/contrib/cutlass.py | 2 +- tests/python/contrib/test_cutlass.py | 50 +++++++- 6 files changed, 153 insertions(+), 54 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index e921302eafce..bd996525e77c 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -178,6 +178,8 @@ def handle_conv2d( strides, dilation, out_dtype, + data_dtype, + weight_dtype, profile_all, use_multiprocessing, ): @@ -195,6 +197,8 @@ def handle_conv2d( strides, dilation, out_dtype, + data_dtype, + weight_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing, ) @@ -258,6 +262,8 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t new_attrs.update(func.attrs) arg0_shape = new_attrs["arg0_shape"] arg1_shape = new_attrs["arg1_shape"] + arg0_dtype = new_attrs["arg0_dtype"] + arg1_dtype = new_attrs["arg1_dtype"] if "conv2d" in op_type: new_attrs["padding"] = annotator.op_attrs.padding @@ -273,6 +279,8 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t annotator.op_attrs.strides, annotator.op_attrs.dilation, out_dtype, + arg0_dtype, + arg1_dtype, profile_all, use_multiprocessing, ) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index 39db9fd01319..90bf7630f9a5 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -178,6 +178,8 @@ def select_op( stride, dilation, out_dtype, + data_dtype, + weight_dtype, profile_all=True, use_multiprocessing=False, ): @@ -208,6 +210,8 @@ def select_op( ops = GENERATOR_FUNC_TABLE[self.sm]( out_dtype, + data_dtype, + weight_dtype, op_creator=enumerate_conv2d_operators, ) ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops)) @@ -240,6 +244,8 @@ def profile( stride, dilation, out_dtype, + data_dtype, + weight_dtype, profile_all=True, use_multiprocessing=False, ): @@ -254,6 +260,8 @@ def profile( stride, dilation, out_dtype, + data_dtype, + weight_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing, ) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 96d86a9a7383..ce865feeef20 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -33,6 +33,9 @@ logger = logging.getLogger("cutlass") +dtype_map = {"int8": DataType.s8, "uint8": DataType.u8} + + def generate_tensor_op_common( math_instructions, alignment_constraints, get_tile_descriptions, op_creator ): @@ -100,9 +103,8 @@ def get_tile_descriptions(math_inst): ) -def generate_sm80_tensor_op_16816(out_dtype, op_creator): +def generate_sm80_tensor_op_16816(out_dtype, arg0_dtype, arg1_dtype, op_creator): """Generate GEMM or Conv2D kernels for Ampere.""" - assert out_dtype in ["float32", "float16"] if "float" in out_dtype: math_instructions = { "float32": [ @@ -126,59 +128,96 @@ def generate_sm80_tensor_op_16816(out_dtype, op_creator): ) ], }[out_dtype] - elif out_dtype == "int32": - # TODO: add input types - math_instructions = { - "int8": [ - MathInstruction( - [16, 8, 32], - DataType.s8, - DataType.s8, - DataType.s32, - OpcodeClass.TensorOp, - MathOperation.multiply_add_saturate, - ), - ], - "uint8": [ - MathInstruction( - [16, 8, 32], - DataType.u8, - DataType.u8, - DataType.s32, - OpcodeClass.TensorOp, - MathOperation.multiply_add_saturate, - ), - ], - }["int8"] - - alignment_constraints = [8, 4, 2] + alignment_constraints = [8, 4, 2] + block_k_factor = 1 + else: + assert out_dtype == "int32" + math_instructions = [ + MathInstruction( + [16, 8, 32], + dtype_map[arg0_dtype], + dtype_map[arg1_dtype], + DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_saturate, + ), + ] + # TODO: Is this the only possible value? + alignment_constraints = [ + 16, + ] + block_k_factor = 2 def get_tile_descriptions(math_inst): min_cc = 80 max_cc = 1024 max_cc_smem_limited = 80 return [ - TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription( + [256, 128, 32 * block_k_factor], 3, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 256, 32 * block_k_factor], 3, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 64, 32 * block_k_factor], 4, [4, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 256, 32 * block_k_factor], 4, [1, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32 * block_k_factor], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32 * block_k_factor], 4, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32 * block_k_factor], 5, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 32 * block_k_factor], 6, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 128, 32 * block_k_factor], 6, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 64, 32 * block_k_factor], 10, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 128, 64 * block_k_factor], + 3, + [4, 2, 1], + math_inst, + min_cc, + max_cc_smem_limited, + ), + TileDescription( + [128, 256, 64 * block_k_factor], + 3, + [2, 4, 1], + math_inst, + min_cc, + max_cc_smem_limited, + ), + TileDescription( + [256, 64, 64 * block_k_factor], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited + ), + TileDescription( + [64, 256, 64 * block_k_factor], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited + ), + TileDescription( + [128, 128, 64 * block_k_factor], 4, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 64 * block_k_factor], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 128, 64 * block_k_factor], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription([64, 64, 64 * block_k_factor], 5, [2, 2, 1], math_inst, min_cc, max_cc), ] - sm75_kernels = generate_sm75_tensor_op_1688(out_dtype, op_creator) + sm75_kernels = [] # generate_sm75_tensor_op_1688(out_dtype, op_creator) sm80_kernels = generate_tensor_op_common( math_instructions, alignment_constraints, get_tile_descriptions, op_creator ) diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py index 0d58ea724ee8..52b4e121710e 100644 --- a/python/tvm/contrib/cutlass/library.py +++ b/python/tvm/contrib/cutlass/library.py @@ -36,8 +36,7 @@ class DataType(enum.Enum): ShortDataTypeNames = { DataType.f16: "h", DataType.f32: "s", - DataType.s32: 'i', - + DataType.s32: "i", } @@ -70,9 +69,10 @@ class MathOperation(enum.Enum): multiply_add = enum_auto() multiply_add_saturate = enum_auto() + MathOperationTag = { MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd", - MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', + MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate", } diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 31f0408c0f04..920b45099c68 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -146,7 +146,7 @@ def check_conv2d(call): kernel_layout = conv2d.attrs.kernel_layout data = conv2d.args[0].checked_type weight = conv2d.args[1].checked_type - if data_layout != "NHWC" or kernel_layout != "OHWI" or not check_dtype(data, weight): + if data_layout != "NHWC" or kernel_layout != "OHWI": return False IC = data.shape[3] OC = weight.shape[0] diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 54738ddd772b..7923f8fa33bf 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -371,6 +371,14 @@ def convert_conv2d_layout(mod, desired_layouts): return seq(mod) +def get_random_ndarray(shape, dtype): + if dtype == "int8": + return np.random.randint(-128, 128, shape).astype(dtype) + elif dtype == "uint8": + return np.random.randint(0, 256, shape).astype(dtype) + return np.random.uniform(-1, 1, shape).astype(dtype) + + def verify_conv2d( expr_nchw, # can be dynamic batch expr_ref, # always static batch @@ -382,6 +390,8 @@ def verify_conv2d( use_cudnn_ref=False, run_benchmark=False, use_fast_math=False, + data_dtype="float16", + weight_dtype="float16", ): if not has_cutlass(): return @@ -392,9 +402,9 @@ def verify_conv2d( typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type out_dtype = typ.dtype - np_data = np.random.uniform(-1, 1, d_shape).astype("float16") - np_weight = np.random.uniform(-1, 1, w_shape).astype("float16") - np_bias = np.random.uniform(-1, 1, (w_shape[0],)).astype(out_dtype) + np_data = get_random_ndarray(d_shape, data_dtype) + np_weight = get_random_ndarray(w_shape, weight_dtype) + np_bias = get_random_ndarray((w_shape[0],), out_dtype) params = {"weight": np_weight, "bias": np_bias} @@ -537,5 +547,39 @@ def test_conv2d_residual_block(): verify_conv2d(func, func, d_shape, w_shape, sm=80, atol=tol, rtol=tol, run_benchmark=False) +def get_conv2d_nchw_int8(d_shape, w_shape, padding, activation_dtype="int8"): + data = relay.var("data", shape=d_shape, dtype=activation_dtype) + weight = relay.var("weight", shape=w_shape, dtype="int8") + out_channel = w_shape[0] + return relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[2:], + channels=out_channel, + padding=padding, + out_dtype="int32", + ) + + +def test_int8(): + d_shape = (16, 16, 32, 32) + w_shape = (32, 16, 3, 3) + padding = (1, 1) + mod_nchw = get_conv2d_nchw_int8(d_shape, w_shape, padding) + + verify_conv2d( + mod_nchw, + mod_nchw, + d_shape, + w_shape, + sm=80, + atol=1e-5, + rtol=1e-5, + run_benchmark=False, + data_dtype="int8", + weight_dtype="int8", + ) + + if __name__ == "__main__": pytest.main([__file__]) From aecbbbf8b7aa97b05a3368691e04303fc28b16ae Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 8 Jan 2022 12:17:20 +0900 Subject: [PATCH 04/22] adding 3xtf32 support and refactor tile description enum --- python/tvm/contrib/cutlass/gen_tensor_op.py | 162 +++++++++----------- python/tvm/contrib/cutlass/library.py | 2 + 2 files changed, 72 insertions(+), 92 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index ce865feeef20..82f3cc11445d 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -33,7 +33,7 @@ logger = logging.getLogger("cutlass") -dtype_map = {"int8": DataType.s8, "uint8": DataType.u8} +dtype_map = {"int8": DataType.s8, "uint8": DataType.u8, "float32": DataType.f32} def generate_tensor_op_common( @@ -105,31 +105,72 @@ def get_tile_descriptions(math_inst): def generate_sm80_tensor_op_16816(out_dtype, arg0_dtype, arg1_dtype, op_creator): """Generate GEMM or Conv2D kernels for Ampere.""" - if "float" in out_dtype: - math_instructions = { - "float32": [ - MathInstruction( - [16, 8, 16], - DataType.f16, - DataType.f16, - DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add, - ) - ], - "float16": [ - MathInstruction( - [16, 8, 16], - DataType.f16, - DataType.f16, - DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add, - ) - ], - }[out_dtype] + min_cc = 80 + max_cc = 1024 + max_cc_smem_limited = 80 + + def get_default_tile_descriptions(block_k_factor): + return [ + ([256, 128, 32 * block_k_factor], 3, [4, 2, 1], min_cc, max_cc), + ([128, 256, 32 * block_k_factor], 3, [2, 4, 1], min_cc, max_cc), + ([256, 64, 32 * block_k_factor], 4, [4, 1, 1], min_cc, max_cc), + ([64, 256, 32 * block_k_factor], 4, [1, 4, 1], min_cc, max_cc), + ([128, 128, 32 * block_k_factor], 3, [2, 2, 1], min_cc, max_cc), + ([128, 128, 32 * block_k_factor], 4, [2, 2, 1], min_cc, max_cc), + ([128, 128, 32 * block_k_factor], 5, [2, 2, 1], min_cc, max_cc), + ([128, 64, 32 * block_k_factor], 6, [2, 2, 1], min_cc, max_cc), + ([64, 128, 32 * block_k_factor], 6, [2, 2, 1], min_cc, max_cc), + ([64, 64, 32 * block_k_factor], 10, [2, 2, 1], min_cc, max_cc), + ([256, 128, 64 * block_k_factor], 3, [4, 2, 1], min_cc, max_cc_smem_limited), + ([128, 256, 64 * block_k_factor], 3, [2, 4, 1], min_cc, max_cc_smem_limited), + ([256, 64, 64 * block_k_factor], 4, [4, 1, 1], min_cc, max_cc_smem_limited), + ([64, 256, 64 * block_k_factor], 4, [1, 4, 1], min_cc, max_cc_smem_limited), + ([128, 128, 64 * block_k_factor], 4, [2, 2, 1], min_cc, max_cc), + ([128, 64, 64 * block_k_factor], 3, [2, 2, 1], min_cc, max_cc), + ([64, 128, 64 * block_k_factor], 3, [2, 2, 1], min_cc, max_cc), + ([64, 64, 64 * block_k_factor], 5, [2, 2, 1], min_cc, max_cc), + ] + + if arg0_dtype == "float16" and arg1_dtype == "float16": + math_instructions = [ + MathInstruction( + [16, 8, 16], + DataType.f16, + DataType.f16, + dtype_map[out_dtype], + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ] alignment_constraints = [8, 4, 2] - block_k_factor = 1 + tile_descriptions = get_default_tile_descriptions(1) + elif arg0_dtype == "float32" and arg1_dtype == "float32": + math_instructions = [ + MathInstruction( + [16, 8, 8], + DataType.f32, + DataType.f32, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_f32, + ), + ] + alignment_constraints = [4, 2, 1] + tile_descriptions = [ + ([128, 128, 16], 4, [4, 2, 1], min_cc, max_cc), + ([128, 128, 16], 3, [4, 2, 1], min_cc, max_cc), + ([256, 64, 16], 3, [4, 2, 1], min_cc, max_cc), + ([64, 256, 16], 3, [2, 4, 1], min_cc, max_cc), + ([128, 64, 16], 4, [2, 2, 1], min_cc, max_cc), + ([64, 128, 16], 4, [2, 2, 1], min_cc, max_cc), + ([64, 64, 16], 3, [2, 2, 1], min_cc, max_cc), + ([128, 128, 32], 3, [4, 2, 1], min_cc, max_cc), + ([256, 64, 32], 3, [4, 2, 1], min_cc, max_cc_smem_limited), + ([64, 256, 32], 3, [2, 4, 1], min_cc, max_cc_smem_limited), + ([128, 64, 32], 3, [2, 2, 1], min_cc, max_cc), + ([64, 128, 32], 3, [2, 2, 1], min_cc, max_cc), + ([64, 64, 32], 3, [2, 2, 1], min_cc, max_cc), + ] else: assert out_dtype == "int32" math_instructions = [ @@ -146,78 +187,15 @@ def generate_sm80_tensor_op_16816(out_dtype, arg0_dtype, arg1_dtype, op_creator) alignment_constraints = [ 16, ] - block_k_factor = 2 + tile_descriptions = get_default_tile_descriptions(2) def get_tile_descriptions(math_inst): - min_cc = 80 - max_cc = 1024 - max_cc_smem_limited = 80 return [ - TileDescription( - [256, 128, 32 * block_k_factor], 3, [4, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [128, 256, 32 * block_k_factor], 3, [2, 4, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [256, 64, 32 * block_k_factor], 4, [4, 1, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [64, 256, 32 * block_k_factor], 4, [1, 4, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [128, 128, 32 * block_k_factor], 3, [2, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [128, 128, 32 * block_k_factor], 4, [2, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [128, 128, 32 * block_k_factor], 5, [2, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [128, 64, 32 * block_k_factor], 6, [2, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [64, 128, 32 * block_k_factor], 6, [2, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [64, 64, 32 * block_k_factor], 10, [2, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [256, 128, 64 * block_k_factor], - 3, - [4, 2, 1], - math_inst, - min_cc, - max_cc_smem_limited, - ), - TileDescription( - [128, 256, 64 * block_k_factor], - 3, - [2, 4, 1], - math_inst, - min_cc, - max_cc_smem_limited, - ), - TileDescription( - [256, 64, 64 * block_k_factor], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited - ), - TileDescription( - [64, 256, 64 * block_k_factor], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited - ), - TileDescription( - [128, 128, 64 * block_k_factor], 4, [2, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [128, 64, 64 * block_k_factor], 3, [2, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [64, 128, 64 * block_k_factor], 3, [2, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription([64, 64, 64 * block_k_factor], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription(threadblock_shape, stages, warp_count, math_inst, min_cc, max_cc) + for threadblock_shape, stages, warp_count, min_cc, max_cc in tile_descriptions ] - sm75_kernels = [] # generate_sm75_tensor_op_1688(out_dtype, op_creator) + sm75_kernels = [] # generate_sm75_tensor_op_1688(out_dtype, op_creator) sm80_kernels = generate_tensor_op_common( math_instructions, alignment_constraints, get_tile_descriptions, op_creator ) diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py index 52b4e121710e..ab02f5b62581 100644 --- a/python/tvm/contrib/cutlass/library.py +++ b/python/tvm/contrib/cutlass/library.py @@ -68,11 +68,13 @@ class DataType(enum.Enum): class MathOperation(enum.Enum): multiply_add = enum_auto() multiply_add_saturate = enum_auto() + multiply_add_fast_f32 = enum_auto() MathOperationTag = { MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd", MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate", + MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32' } From b7b3aa85dbd39d8b21b9fa754d221058b07abf10 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 8 Jan 2022 12:21:24 +0900 Subject: [PATCH 05/22] add 3xtf32 test --- tests/python/contrib/test_cutlass.py | 46 ++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 7923f8fa33bf..b44879b69c13 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -547,11 +547,15 @@ def test_conv2d_residual_block(): verify_conv2d(func, func, d_shape, w_shape, sm=80, atol=tol, rtol=tol, run_benchmark=False) -def get_conv2d_nchw_int8(d_shape, w_shape, padding, activation_dtype="int8"): - data = relay.var("data", shape=d_shape, dtype=activation_dtype) +def test_int8(): + d_shape = (16, 16, 32, 32) + w_shape = (32, 16, 3, 3) + padding = (1, 1) + + data = relay.var("data", shape=d_shape, dtype="int8") weight = relay.var("weight", shape=w_shape, dtype="int8") out_channel = w_shape[0] - return relay.nn.conv2d( + expr = relay.nn.conv2d( data=data, weight=weight, kernel_size=w_shape[2:], @@ -560,24 +564,48 @@ def get_conv2d_nchw_int8(d_shape, w_shape, padding, activation_dtype="int8"): out_dtype="int32", ) + verify_conv2d( + expr, + expr, + d_shape, + w_shape, + sm=80, + atol=1e-5, + rtol=1e-5, + run_benchmark=False, + data_dtype="int8", + weight_dtype="int8", + ) + -def test_int8(): +def test_3xtf32(): d_shape = (16, 16, 32, 32) w_shape = (32, 16, 3, 3) padding = (1, 1) - mod_nchw = get_conv2d_nchw_int8(d_shape, w_shape, padding) + + data = relay.var("data", shape=d_shape, dtype="float32") + weight = relay.var("weight", shape=w_shape, dtype="float32") + out_channel = w_shape[0] + expr = relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[2:], + channels=out_channel, + padding=padding, + out_dtype="float32", + ) verify_conv2d( - mod_nchw, - mod_nchw, + expr, + expr, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False, - data_dtype="int8", - weight_dtype="int8", + data_dtype="float32", + weight_dtype="float32", ) From e9b0287591b407018ae652371256aa7fee495a9d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 8 Jan 2022 12:29:08 +0900 Subject: [PATCH 06/22] update gemm generator too --- python/tvm/contrib/cutlass/build.py | 66 +++++++++++++++++++++++--- python/tvm/contrib/cutlass/gen_gemm.py | 31 ++++++++++-- python/tvm/contrib/cutlass/library.py | 2 +- 3 files changed, 88 insertions(+), 11 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index bd996525e77c..1cfe2556da22 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -94,12 +94,24 @@ def visit_call(self, call): def select_gemm_kernel( - cutlass_profiler, op_type, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing + cutlass_profiler, + op_type, + MM, + KK, + NN, + out_dtype, + arg0_dtype, + arg1_dtype, + batched, + profile_all, + use_multiprocessing, ): """Run CUTLASS profiler to select the best kernel, or return the default one for dynamic workloads.""" if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]): - out = cutlass_profiler.get_default(op_type, out_dtype, batched=batched) + out = cutlass_profiler.get_default( + op_type, out_dtype, arg0_dtype, arg1_dtype, batched=batched + ) name, cutlass_op_def = out["name"], out["opdef"] logger.info("Picked the default kernel %s", name) else: @@ -109,6 +121,8 @@ def select_gemm_kernel( NN, KK, out_dtype, + arg0_dtype, + arg1_dtype, batched=batched, profile_all=profile_all, use_multiprocessing=use_multiprocessing, @@ -122,7 +136,15 @@ def select_gemm_kernel( def handle_batch_matmul( - cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing + cutlass_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + arg0_dtype, + arg1_dtype, + profile_all, + use_multiprocessing, ): """Profile and select a kernel for batch_matmul op workload.""" MM = arg0_shape[1] @@ -130,7 +152,17 @@ def handle_batch_matmul( NN = arg1_shape[1] name, cutlass_op_def = select_gemm_kernel( - cutlass_profiler, op_type, MM, KK, NN, out_dtype, True, profile_all, use_multiprocessing + cutlass_profiler, + op_type, + MM, + KK, + NN, + out_dtype, + arg0_dtype, + arg1_dtype, + True, + profile_all, + use_multiprocessing, ) return { @@ -147,7 +179,15 @@ def handle_batch_matmul( def handle_dense( - cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing + cutlass_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + arg0_dtype, + arg1_dtype, + profile_all, + use_multiprocessing, ): """Profile and select a kernel for dense op workload.""" MM = arg0_shape[0] @@ -155,7 +195,17 @@ def handle_dense( NN = arg1_shape[0] name, cutlass_op_def = select_gemm_kernel( - cutlass_profiler, op_type, MM, KK, NN, out_dtype, False, profile_all, use_multiprocessing + cutlass_profiler, + op_type, + MM, + KK, + NN, + out_dtype, + arg0_dtype, + arg1_dtype, + False, + profile_all, + use_multiprocessing, ) assert "tn_align" in name, "Only supports (row_major, col_major) input layout for now." @@ -293,6 +343,8 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t arg0_shape, arg1_shape, out_dtype, + arg0_dtype, + arg1_dtype, profile_all, use_multiprocessing, ) @@ -305,6 +357,8 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t arg0_shape, arg1_shape, out_dtype, + arg0_dtype, + arg1_dtype, profile_all, use_multiprocessing, ) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 9159ed881c74..d848fc79b409 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -156,11 +156,13 @@ def check_align(self, op_name, M, N, K): # When the above issue is resolved, we can remove the alignment check on M below. return all([dim % align == 0 for dim in [M, N, K]]) - def get_default(self, op_type, out_dtype, batched=False): + def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, batched=False): """Return the default kernel for the requested architecture. For now, the default kernel was picked arbitrary. """ - ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=enumerate_gemm_operators) + ops = GENERATOR_FUNC_TABLE[self.sm]( + out_dtype, arg0_dtype, arg1_dtype, op_creator=enumerate_gemm_operators + ) default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype] filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops)) assert len(filtered) == 1 @@ -176,7 +178,17 @@ def get_default(self, op_type, out_dtype, batched=False): op.update({"name": name, "opdef": opdef}) return op - def select_op(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False): + def select_op( + self, + M, + N, + K, + out_dtype, + arg0_dtype, + arg1_dtype, + profile_all=True, + use_multiprocessing=False, + ): """ Profile and select the best kernel from candidate kernels. See the documentation for the profile method below. @@ -187,6 +199,8 @@ def select_op(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=Fa ops = GENERATOR_FUNC_TABLE[self.sm]( out_dtype, + arg0_dtype, + arg1_dtype, op_creator=enumerate_gemm_operators, ) ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops)) @@ -212,6 +226,8 @@ def profile( N, K, out_dtype, + arg0_dtype, + arg1_dtype, profile_all=True, use_multiprocessing=False, batched=False, @@ -221,7 +237,14 @@ def profile( If use_multiprocessing is True, compile all profiler executables in parallel. """ op = self.select_op( - M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing + M, + N, + K, + out_dtype, + arg0_dtype, + arg1_dtype, + profile_all=profile_all, + use_multiprocessing=use_multiprocessing, ) name, opdef = create_gemm_operator_with_epilogue( diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py index ab02f5b62581..5d986f4d03a7 100644 --- a/python/tvm/contrib/cutlass/library.py +++ b/python/tvm/contrib/cutlass/library.py @@ -74,7 +74,7 @@ class MathOperation(enum.Enum): MathOperationTag = { MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd", MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate", - MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32' + MathOperation.multiply_add_fast_f32: "cutlass::arch::OpMultiplyAddFastF32", } From ad3faf6b8b60e554af420b26efff31dbd6ce32c7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 8 Jan 2022 17:50:27 +0900 Subject: [PATCH 07/22] int8 test worked --- src/relay/backend/contrib/cutlass/codegen.cc | 6 +++++- tests/python/contrib/test_cutlass.py | 7 +++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index dc03eea014ab..0a945793b775 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -43,7 +43,11 @@ namespace contrib { using namespace backend; using Str2StrMap = std::unordered_map; -static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, {"float32", "float"}}; +static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, + {"float32", "float"}, + {"int8", "int8_t"}, + {"uint8", "uint8_t"}, + {"int32", "int32_t"}}; constexpr const char* kAnyDim = "Any"; diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index b44879b69c13..cefb45a9b7a5 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -392,6 +392,7 @@ def verify_conv2d( use_fast_math=False, data_dtype="float16", weight_dtype="float16", + ref_target="cuda", ): if not has_cutlass(): return @@ -436,7 +437,7 @@ def verify_conv2d( rt_mod_ref, dev = get_ref_rt_mod( convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}), params, - target="cuda", + target=ref_target, ) ref_out = get_output(rt_mod_ref, ["data"], [np_data]) @@ -575,6 +576,7 @@ def test_int8(): run_benchmark=False, data_dtype="int8", weight_dtype="int8", + ref_target="llvm", ) @@ -610,4 +612,5 @@ def test_3xtf32(): if __name__ == "__main__": - pytest.main([__file__]) + # pytest.main([__file__]) + test_int8() From 77644087d356fe6b2090daedd19ef004f495554e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 8 Jan 2022 17:55:58 +0900 Subject: [PATCH 08/22] 3xtf32 also works --- tests/python/contrib/test_cutlass.py | 53 +++++++++++++++------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index cefb45a9b7a5..fa4508a6feb9 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -553,31 +553,32 @@ def test_int8(): w_shape = (32, 16, 3, 3) padding = (1, 1) - data = relay.var("data", shape=d_shape, dtype="int8") - weight = relay.var("weight", shape=w_shape, dtype="int8") - out_channel = w_shape[0] - expr = relay.nn.conv2d( - data=data, - weight=weight, - kernel_size=w_shape[2:], - channels=out_channel, - padding=padding, - out_dtype="int32", - ) + for data_dtype in ["int8", "uint8"]: + data = relay.var("data", shape=d_shape, dtype=data_dtype) + weight = relay.var("weight", shape=w_shape, dtype="int8") + out_channel = w_shape[0] + expr = relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[2:], + channels=out_channel, + padding=padding, + out_dtype="int32", + ) - verify_conv2d( - expr, - expr, - d_shape, - w_shape, - sm=80, - atol=1e-5, - rtol=1e-5, - run_benchmark=False, - data_dtype="int8", - weight_dtype="int8", - ref_target="llvm", - ) + verify_conv2d( + expr, + expr, + d_shape, + w_shape, + sm=80, + atol=1e-5, + rtol=1e-5, + run_benchmark=False, + data_dtype=data_dtype, + weight_dtype="int8", + ref_target="llvm", + ) def test_3xtf32(): @@ -608,9 +609,11 @@ def test_3xtf32(): run_benchmark=False, data_dtype="float32", weight_dtype="float32", + ref_target="llvm" ) if __name__ == "__main__": # pytest.main([__file__]) - test_int8() + # test_int8() + test_3xtf32() From 2aaed84a3cc398f4d8a8787fc6a7ff4df606a51d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 8 Jan 2022 18:19:50 +0900 Subject: [PATCH 09/22] int8 and 3xtf32 gemm works --- python/tvm/relay/op/contrib/cutlass.py | 5 +-- tests/python/contrib/test_cutlass.py | 44 +++++++++++++++++--------- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 920b45099c68..6ae4ea5624bc 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -119,10 +119,7 @@ def get_root_call(call, root_op_name): def check_gemm(call): """Check if the given dense workload can be offloaded to CUTLASS.""" - dense = get_root_call(call, "nn.dense") - lhs = dense.args[0].checked_type - rhs = dense.args[1].checked_type - return check_dtype(lhs, rhs) + return True def check_batch_matmul(call): diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index fa4508a6feb9..fa4a8b725eed 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -68,14 +68,14 @@ def get_output_vm(vm, names, inputs): return vm.invoke("main", **params).numpy() -def get_dense_with_shape(data_shape, weight_shape, out_dtype="float16"): - data = relay.var("data", shape=data_shape, dtype="float16") - weight = relay.var("weight", shape=weight_shape, dtype="float16") +def get_dense_with_shape(data_shape, weight_shape, out_dtype="float16", data_dtype="float16", weight_dtype="float16"): + data = relay.var("data", shape=data_shape, dtype=data_dtype) + weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype) return relay.nn.dense(data, weight, out_dtype=out_dtype) -def get_dense(M, N, K, out_dtype="float16"): - return get_dense_with_shape((M, K), (N, K), out_dtype) +def get_dense(M, N, K, out_dtype="float16", data_dtype="float16", weight_dtype="float16"): + return get_dense_with_shape((M, K), (N, K), out_dtype, data_dtype, weight_dtype) def get_dense_bias(M, N, K, out_dtype="float16"): @@ -178,6 +178,7 @@ def get_conv2d_nchw_bias_residual(d_shape, w_shape, padding, out_dtype="float16" def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False): mod = partition_for_cutlass(mod) + print(mod) mod, num_cutlass_partition = tune_cutlass_kernels( mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir ) @@ -210,7 +211,17 @@ def profile_and_build_vm( def verify_dense( - func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False + func, + M, + N, + K, + ref_target="cuda", + sm=80, + atol=1e-5, + rtol=1e-5, + run_benchmark=False, + data_dtype="float16", + weight_dtype="float16", ): if not has_cutlass(): return @@ -218,9 +229,9 @@ def verify_dense( typ = relay.transform.InferType()(mod)["main"].body.checked_type out_dtype = typ.dtype use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape) - np_data = np.random.uniform(-1, 1, (M, K)).astype("float16") - np_weight = np.random.uniform(-1, 1, (N, K)).astype("float16") - np_bias = np.random.uniform(-1, 1, (N,)).astype(out_dtype) + np_data = get_random_ndarray((M, K), data_dtype) + np_weight = get_random_ndarray((N, K), weight_dtype) + np_bias = get_random_ndarray((N,), out_dtype) params = {"weight": np_weight, "bias": np_bias} @@ -292,7 +303,7 @@ def verify_batch_matmul( print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600)) -M = 1820 +M = 1024 N = 768 K = 768 @@ -302,6 +313,9 @@ def test_dense(): verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K) # Test align1 case verify_dense(get_dense_bias(M, N + 1, K), M, N + 1, K) + verify_dense(get_dense(M, N, K, "int32", "int8", "int8"), M, N, K, data_dtype="int8", weight_dtype="int8") + # Test 3xtf32 kernels + verify_dense(get_dense(M, N, K, "float32", "float32", "float32"), M, N, K, data_dtype="float32", weight_dtype="float32") def test_dense_bias(): @@ -548,7 +562,7 @@ def test_conv2d_residual_block(): verify_conv2d(func, func, d_shape, w_shape, sm=80, atol=tol, rtol=tol, run_benchmark=False) -def test_int8(): +def test_conv2d_int8(): d_shape = (16, 16, 32, 32) w_shape = (32, 16, 3, 3) padding = (1, 1) @@ -581,7 +595,7 @@ def test_int8(): ) -def test_3xtf32(): +def test_conv2d_3xtf32(): d_shape = (16, 16, 32, 32) w_shape = (32, 16, 3, 3) padding = (1, 1) @@ -609,11 +623,11 @@ def test_3xtf32(): run_benchmark=False, data_dtype="float32", weight_dtype="float32", - ref_target="llvm" + ref_target="llvm", ) if __name__ == "__main__": # pytest.main([__file__]) - # test_int8() - test_3xtf32() + test_dense() + # test_3xtf32() From bc563c90b54193338b0bf6de0292c3ea7e7faa95 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 8 Jan 2022 18:27:21 +0900 Subject: [PATCH 10/22] clean up test --- tests/python/contrib/test_cutlass.py | 89 +++++++++++++++------------- 1 file changed, 48 insertions(+), 41 deletions(-) diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index fa4a8b725eed..d6fddfb3c4d4 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -68,7 +68,9 @@ def get_output_vm(vm, names, inputs): return vm.invoke("main", **params).numpy() -def get_dense_with_shape(data_shape, weight_shape, out_dtype="float16", data_dtype="float16", weight_dtype="float16"): +def get_dense_with_shape( + data_shape, weight_shape, out_dtype="float16", data_dtype="float16", weight_dtype="float16" +): data = relay.var("data", shape=data_shape, dtype=data_dtype) weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype) return relay.nn.dense(data, weight, out_dtype=out_dtype) @@ -110,9 +112,11 @@ def get_batch_matmul(batch, M, N, K, out_dtype="float16"): return get_batch_matmul_with_shape((batch, M, K), (batch, N, K), out_dtype="float16") -def get_conv2d_nchw(d_shape, w_shape, padding, out_dtype="float16"): - data = relay.var("data", shape=d_shape, dtype="float16") - weight = relay.var("weight", shape=w_shape, dtype="float16") +def get_conv2d_nchw( + d_shape, w_shape, padding, out_dtype="float16", data_dtype="float16", weight_dtype="float16" +): + data = relay.var("data", shape=d_shape, dtype=data_dtype) + weight = relay.var("weight", shape=w_shape, dtype=weight_dtype) out_channel = w_shape[0] return relay.nn.conv2d( data=data, @@ -313,9 +317,18 @@ def test_dense(): verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K) # Test align1 case verify_dense(get_dense_bias(M, N + 1, K), M, N + 1, K) - verify_dense(get_dense(M, N, K, "int32", "int8", "int8"), M, N, K, data_dtype="int8", weight_dtype="int8") + verify_dense( + get_dense(M, N, K, "int32", "int8", "int8"), M, N, K, data_dtype="int8", weight_dtype="int8" + ) # Test 3xtf32 kernels - verify_dense(get_dense(M, N, K, "float32", "float32", "float32"), M, N, K, data_dtype="float32", weight_dtype="float32") + verify_dense( + get_dense(M, N, K, "float32", "float32", "float32"), + M, + N, + K, + data_dtype="float32", + weight_dtype="float32", + ) def test_dense_bias(): @@ -494,6 +507,34 @@ def test_conv2d(): mod_dyn, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False ) + for data_dtype, weight_dtype, out_dtype in [ + ("float32", "float32", "float32"), # 3xtf32 + ("int8", "int8", "int32"), + ("uint8", "int8", "int32"), + ]: + expr = get_conv2d_nchw( + d_shape, + w_shape, + padding, + out_dtype=out_dtype, + data_dtype=data_dtype, + weight_dtype=weight_dtype, + ) + + verify_conv2d( + expr, + expr, + d_shape, + w_shape, + sm=80, + atol=1e-5, + rtol=1e-5, + run_benchmark=False, + data_dtype=data_dtype, + weight_dtype=weight_dtype, + ref_target="llvm", + ) + def test_conv2d_fusion(): d_shape = (16, 16, 32, 32) @@ -595,39 +636,5 @@ def test_conv2d_int8(): ) -def test_conv2d_3xtf32(): - d_shape = (16, 16, 32, 32) - w_shape = (32, 16, 3, 3) - padding = (1, 1) - - data = relay.var("data", shape=d_shape, dtype="float32") - weight = relay.var("weight", shape=w_shape, dtype="float32") - out_channel = w_shape[0] - expr = relay.nn.conv2d( - data=data, - weight=weight, - kernel_size=w_shape[2:], - channels=out_channel, - padding=padding, - out_dtype="float32", - ) - - verify_conv2d( - expr, - expr, - d_shape, - w_shape, - sm=80, - atol=1e-5, - rtol=1e-5, - run_benchmark=False, - data_dtype="float32", - weight_dtype="float32", - ref_target="llvm", - ) - - if __name__ == "__main__": - # pytest.main([__file__]) - test_dense() - # test_3xtf32() + pytest.main([__file__]) From 2712172182591dc459eebd0d320476341953c563 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 8 Jan 2022 19:06:46 +0900 Subject: [PATCH 11/22] support int8 in sm75 --- python/tvm/contrib/cutlass/build.py | 2 +- python/tvm/contrib/cutlass/gen_conv2d.py | 7 +- python/tvm/contrib/cutlass/gen_tensor_op.py | 86 ++++++++++++++------- 3 files changed, 64 insertions(+), 31 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 1cfe2556da22..760160045a58 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -235,7 +235,7 @@ def handle_conv2d( ): """Profile and select a kernel for conv2d op workload.""" if any(isinstance(s, tvm.tir.Any) for s in d_shape): - out = cutlass_profiler.get_default(op_type, out_dtype) + out = cutlass_profiler.get_default(op_type, out_dtype, data_dtype, weight_dtype) name, cutlass_op_def = out["name"], out["opdef"] logger.info("Picked the default kernel %s", name) else: diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index 90bf7630f9a5..f2a04ca7582a 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -153,8 +153,10 @@ def __init__(self, sm, cutlass_path, binary_path): self.engine = ProfilerEngine(sm, cutlass_path, binary_path) self.cache = {} - def get_default(self, op_type, out_dtype): - gemm_profile_result = self.gemm_profiler.get_default(op_type, out_dtype) + def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype): + gemm_profile_result = self.gemm_profiler.get_default( + op_type, out_dtype, arg0_dtype, arg1_dtype + ) tile_description = gemm_profile_result["tile_description"] alignment = gemm_profile_result["alignment"] data_type = gemm_profile_result["data_type"] @@ -214,6 +216,7 @@ def select_op( weight_dtype, op_creator=enumerate_conv2d_operators, ) + ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops)) if profile_all: diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 82f3cc11445d..3c68b3c82c2d 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -33,7 +33,12 @@ logger = logging.getLogger("cutlass") -dtype_map = {"int8": DataType.s8, "uint8": DataType.u8, "float32": DataType.f32} +dtype_map = { + "int8": DataType.s8, + "uint8": DataType.u8, + "float32": DataType.f32, + "float16": DataType.f16, +} def generate_tensor_op_common( @@ -57,45 +62,65 @@ def generate_tensor_op_common( return ops -def generate_sm75_tensor_op_1688(out_dtype, op_creator): +def generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator): """Generate GEMM or Conv2D kernels for Turing.""" - assert out_dtype in ["float32", "float16"] - math_instructions = { - "float32": [ + assert out_dtype in ["float32", "float16", "int32"] + min_cc = 75 + max_cc = 1024 + + if arg0_dtype == "float16" and arg1_dtype == "float16": + math_instructions = [ MathInstruction( [16, 8, 8], DataType.f16, DataType.f16, - DataType.f32, + dtype_map[out_dtype], OpcodeClass.TensorOp, MathOperation.multiply_add, ) - ], - "float16": [ + ] + alignment_constraints = [8, 4, 2, 1] + tile_descriptions = [ + ([256, 128, 32], 2, [4, 2, 1], min_cc, max_cc), + ([128, 256, 32], 2, [2, 4, 1], min_cc, max_cc), + ([128, 128, 32], 2, [2, 2, 1], min_cc, max_cc), + ([64, 128, 32], 2, [2, 2, 1], min_cc, max_cc), + ([128, 64, 32], 2, [2, 2, 1], min_cc, max_cc), + ([64, 64, 32], 2, [2, 2, 1], min_cc, max_cc), + ([64, 128, 64], 2, [1, 2, 2], min_cc, max_cc), + ] + + else: + assert out_dtype == "int32" + math_instructions = [ MathInstruction( - [16, 8, 8], - DataType.f16, - DataType.f16, - DataType.f16, + [8, 8, 16], + dtype_map[arg0_dtype], + dtype_map[arg1_dtype], + DataType.s32, OpcodeClass.TensorOp, - MathOperation.multiply_add, - ) - ], - }[out_dtype] - - alignment_constraints = [8, 4, 2, 1] + MathOperation.multiply_add_saturate, + ), + ] + # TODO: Is this the only possible value? + alignment_constraints = [ + 16, + ] + tile_descriptions = [ + ([256, 128, 64], 2, [4, 2, 1], min_cc, max_cc), + ([128, 256, 64], 2, [2, 4, 1], min_cc, max_cc), + ([128, 128, 64], 2, [2, 2, 1], min_cc, max_cc), + ([64, 256, 64], 2, [1, 4, 1], min_cc, max_cc), + ([256, 64, 64], 2, [4, 1, 1], min_cc, max_cc), + ([64, 128, 64], 2, [2, 2, 1], min_cc, max_cc), + ([128, 64, 64], 2, [2, 2, 1], min_cc, max_cc), + ([64, 64, 64], 2, [2, 2, 1], min_cc, max_cc), + ] def get_tile_descriptions(math_inst): - min_cc = 75 - max_cc = 1024 return [ - TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription(threadblock_shape, stages, warp_count, math_inst, min_cc, max_cc) + for threadblock_shape, stages, warp_count, min_cc, max_cc in tile_descriptions ] return generate_tensor_op_common( @@ -195,7 +220,12 @@ def get_tile_descriptions(math_inst): for threadblock_shape, stages, warp_count, min_cc, max_cc in tile_descriptions ] - sm75_kernels = [] # generate_sm75_tensor_op_1688(out_dtype, op_creator) + if arg0_dtype != "float32" and arg1_dtype != "float32": + sm75_kernels = generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator) + else: + # TF32 (float32 + float32 case) is only supported on sm80 + sm75_kernels = [] + sm80_kernels = generate_tensor_op_common( math_instructions, alignment_constraints, get_tile_descriptions, op_creator ) From c814b7e6db903e8a97d608089e0ed97d6878e0fb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 9 Jan 2022 05:06:17 +0900 Subject: [PATCH 12/22] refined int8 alignment constraints --- python/tvm/contrib/cutlass/gen_conv2d.py | 7 +++-- python/tvm/contrib/cutlass/gen_gemm.py | 6 ++-- python/tvm/contrib/cutlass/gen_tensor_op.py | 6 ++-- tests/python/contrib/test_cutlass.py | 35 +-------------------- 4 files changed, 10 insertions(+), 44 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index f2a04ca7582a..b8cf0de8daee 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -167,9 +167,10 @@ def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype): def check_align(self, op_name, C, K): """Filter out kernels that cannot be supported.""" - aligns = re.findall(r"align[1|2|4|8]", op_name) - assert len(aligns) == 1 - align = int(aligns[0][-1]) + match = re.match(".*_align([1-9]+)", op_name) + assert match is not None and len(match.groups()) == 1 + # The same alignment is used for all axes + align = int(match.groups()[0]) return all([dim % align == 0 for dim in [C, K]]) def select_op( diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index d848fc79b409..e25aca0da0f1 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -147,10 +147,10 @@ def __init__(self, sm, cutlass_path, binary_path): def check_align(self, op_name, M, N, K): """Filter out kernels that cannot be supported.""" - aligns = re.findall(r"align[1|2|4|8]", op_name) - assert len(aligns) == 1 + match = re.match(".*_align([1-9]+)", op_name) + assert match is not None and len(match.groups()) == 1 # The same alignment is used for all axes - align = int(aligns[0][-1]) + align = int(match.groups()[0]) # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive. # See https://github.com/NVIDIA/cutlass/issues/362. # When the above issue is resolved, we can remove the alignment check on M below. diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 3c68b3c82c2d..83da56f2f5ec 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -102,9 +102,8 @@ def generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator): MathOperation.multiply_add_saturate, ), ] - # TODO: Is this the only possible value? alignment_constraints = [ - 16, + 16, 8, 4, 2, 1 ] tile_descriptions = [ ([256, 128, 64], 2, [4, 2, 1], min_cc, max_cc), @@ -208,9 +207,8 @@ def get_default_tile_descriptions(block_k_factor): MathOperation.multiply_add_saturate, ), ] - # TODO: Is this the only possible value? alignment_constraints = [ - 16, + 16, 8, 4 ] tile_descriptions = get_default_tile_descriptions(2) diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index d6fddfb3c4d4..9e30eb13e184 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -307,7 +307,7 @@ def verify_batch_matmul( print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600)) -M = 1024 +M = 1820 N = 768 K = 768 @@ -603,38 +603,5 @@ def test_conv2d_residual_block(): verify_conv2d(func, func, d_shape, w_shape, sm=80, atol=tol, rtol=tol, run_benchmark=False) -def test_conv2d_int8(): - d_shape = (16, 16, 32, 32) - w_shape = (32, 16, 3, 3) - padding = (1, 1) - - for data_dtype in ["int8", "uint8"]: - data = relay.var("data", shape=d_shape, dtype=data_dtype) - weight = relay.var("weight", shape=w_shape, dtype="int8") - out_channel = w_shape[0] - expr = relay.nn.conv2d( - data=data, - weight=weight, - kernel_size=w_shape[2:], - channels=out_channel, - padding=padding, - out_dtype="int32", - ) - - verify_conv2d( - expr, - expr, - d_shape, - w_shape, - sm=80, - atol=1e-5, - rtol=1e-5, - run_benchmark=False, - data_dtype=data_dtype, - weight_dtype="int8", - ref_target="llvm", - ) - - if __name__ == "__main__": pytest.main([__file__]) From 4cfa8e7f84fd85cf376d51ac54426fd3938571b6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 9 Jan 2022 05:11:16 +0900 Subject: [PATCH 13/22] black --- python/tvm/contrib/cutlass/gen_tensor_op.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 83da56f2f5ec..f9ff7866d33e 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -102,9 +102,7 @@ def generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator): MathOperation.multiply_add_saturate, ), ] - alignment_constraints = [ - 16, 8, 4, 2, 1 - ] + alignment_constraints = [16, 8, 4, 2, 1] tile_descriptions = [ ([256, 128, 64], 2, [4, 2, 1], min_cc, max_cc), ([128, 256, 64], 2, [2, 4, 1], min_cc, max_cc), @@ -207,9 +205,7 @@ def get_default_tile_descriptions(block_k_factor): MathOperation.multiply_add_saturate, ), ] - alignment_constraints = [ - 16, 8, 4 - ] + alignment_constraints = [16, 8, 4] tile_descriptions = get_default_tile_descriptions(2) def get_tile_descriptions(math_inst): From f408eabf7d2de1884c120e80db29d293256dec8f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 9 Jan 2022 05:42:20 +0900 Subject: [PATCH 14/22] support 3xtf32 in default kernel --- python/tvm/contrib/cutlass/gen_gemm.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index e25aca0da0f1..33a1676a3ef5 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -125,13 +125,14 @@ def enumerate_gemm_operators( # TODO(masahi): A sensible way to pick reasonable default kernels DEFAULT_KERNELS = { 75: { - "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", - "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1", + ("float16", "float16"): "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", + ("float16", "float32"): "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1", }, # align1 variants do not seem to be available for sm80 80: { - "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", - "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1", + ("float16", "float16"): "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", + ("float16", "float32"): "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1", + ("float32", "float32"): "cutlass_tensorop_s1688gemm_64x64_16x3_tn_align1", }, } @@ -163,7 +164,7 @@ def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, batched=False) ops = GENERATOR_FUNC_TABLE[self.sm]( out_dtype, arg0_dtype, arg1_dtype, op_creator=enumerate_gemm_operators ) - default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype] + default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)] filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops)) assert len(filtered) == 1 op = filtered[0] From 85560f8f5cdb26b19bc7a7f1e5418231289d6e81 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 9 Jan 2022 05:45:36 +0900 Subject: [PATCH 15/22] remove log --- tests/python/contrib/test_cutlass.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 9e30eb13e184..86c1a2ebd4fb 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -182,7 +182,6 @@ def get_conv2d_nchw_bias_residual(d_shape, w_shape, padding, out_dtype="float16" def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False): mod = partition_for_cutlass(mod) - print(mod) mod, num_cutlass_partition = tune_cutlass_kernels( mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir ) From b090d592fcf11c3361d116b941d8621d0e927777 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 9 Jan 2022 06:01:42 +0900 Subject: [PATCH 16/22] refine dtype check --- python/tvm/relay/op/contrib/cutlass.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 6ae4ea5624bc..2cc61923d4b2 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -105,8 +105,11 @@ def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu") def check_dtype(lhs, rhs): """Check if dtypes in the given workload are supported by CUTLASS.""" - # Only fp16 inputs are supported for now. - return lhs.dtype == rhs.dtype and lhs.dtype == "float16" and rhs.dtype == "float16" + return ( + (lhs.dtype == "float16" and rhs.dtype == "float16") + or (lhs.dtype == "float32" and rhs.dtype == "float32") + or (lhs.dtype in ["int8", "uint8"] and rhs.dtype in ["int8", "uint8"]) + ) def get_root_call(call, root_op_name): @@ -119,7 +122,10 @@ def get_root_call(call, root_op_name): def check_gemm(call): """Check if the given dense workload can be offloaded to CUTLASS.""" - return True + dense = get_root_call(call, "nn.dense") + lhs = dense.args[0].checked_type + rhs = dense.args[1].checked_type + return check_dtype(lhs, rhs) def check_batch_matmul(call): @@ -143,7 +149,7 @@ def check_conv2d(call): kernel_layout = conv2d.attrs.kernel_layout data = conv2d.args[0].checked_type weight = conv2d.args[1].checked_type - if data_layout != "NHWC" or kernel_layout != "OHWI": + if data_layout != "NHWC" or kernel_layout != "OHWI" or not check_dtype(data, weight): return False IC = data.shape[3] OC = weight.shape[0] From 974636ddac856171ba8577395c592e33f2cc6c78 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 9 Jan 2022 13:24:24 +0900 Subject: [PATCH 17/22] support tf32 --- python/tvm/contrib/cutlass/build.py | 19 +++++++++-- python/tvm/contrib/cutlass/gen_conv2d.py | 12 +++---- python/tvm/contrib/cutlass/gen_gemm.py | 12 +++++-- python/tvm/contrib/cutlass/gen_tensor_op.py | 38 ++++++++++++--------- 4 files changed, 52 insertions(+), 29 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 760160045a58..c919ff283343 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -102,6 +102,7 @@ def select_gemm_kernel( out_dtype, arg0_dtype, arg1_dtype, + use_3xtf32, batched, profile_all, use_multiprocessing, @@ -110,7 +111,7 @@ def select_gemm_kernel( workloads.""" if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]): out = cutlass_profiler.get_default( - op_type, out_dtype, arg0_dtype, arg1_dtype, batched=batched + op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32, batched=batched ) name, cutlass_op_def = out["name"], out["opdef"] logger.info("Picked the default kernel %s", name) @@ -123,6 +124,7 @@ def select_gemm_kernel( out_dtype, arg0_dtype, arg1_dtype, + use_3xtf32, batched=batched, profile_all=profile_all, use_multiprocessing=use_multiprocessing, @@ -143,6 +145,7 @@ def handle_batch_matmul( out_dtype, arg0_dtype, arg1_dtype, + use_3xtf32, profile_all, use_multiprocessing, ): @@ -160,6 +163,7 @@ def handle_batch_matmul( out_dtype, arg0_dtype, arg1_dtype, + use_3xtf32, True, profile_all, use_multiprocessing, @@ -186,6 +190,7 @@ def handle_dense( out_dtype, arg0_dtype, arg1_dtype, + use_3xtf32, profile_all, use_multiprocessing, ): @@ -203,6 +208,7 @@ def handle_dense( out_dtype, arg0_dtype, arg1_dtype, + use_3xtf32, False, profile_all, use_multiprocessing, @@ -230,12 +236,13 @@ def handle_conv2d( out_dtype, data_dtype, weight_dtype, + use_3xtf32, profile_all, use_multiprocessing, ): """Profile and select a kernel for conv2d op workload.""" if any(isinstance(s, tvm.tir.Any) for s in d_shape): - out = cutlass_profiler.get_default(op_type, out_dtype, data_dtype, weight_dtype) + out = cutlass_profiler.get_default(op_type, out_dtype, data_dtype, weight_dtype, use_3xtf32) name, cutlass_op_def = out["name"], out["opdef"] logger.info("Picked the default kernel %s", name) else: @@ -249,6 +256,7 @@ def handle_conv2d( out_dtype, data_dtype, weight_dtype, + use_3xtf32, profile_all=profile_all, use_multiprocessing=use_multiprocessing, ) @@ -263,7 +271,9 @@ def handle_conv2d( } -def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"): +def tune_cutlass_kernels( + mod, sm, use_3xtf32=True, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp" +): """Given a module partitioned for CUTLASS offloading, profile each workload to select which kernels to emit. @@ -331,6 +341,7 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t out_dtype, arg0_dtype, arg1_dtype, + use_3xtf32, profile_all, use_multiprocessing, ) @@ -345,6 +356,7 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t out_dtype, arg0_dtype, arg1_dtype, + use_3xtf32, profile_all, use_multiprocessing, ) @@ -359,6 +371,7 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t out_dtype, arg0_dtype, arg1_dtype, + use_3xtf32, profile_all, use_multiprocessing, ) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index b8cf0de8daee..ff5f512428ea 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -153,9 +153,9 @@ def __init__(self, sm, cutlass_path, binary_path): self.engine = ProfilerEngine(sm, cutlass_path, binary_path) self.cache = {} - def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype): + def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32): gemm_profile_result = self.gemm_profiler.get_default( - op_type, out_dtype, arg0_dtype, arg1_dtype + op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32 ) tile_description = gemm_profile_result["tile_description"] alignment = gemm_profile_result["alignment"] @@ -183,6 +183,7 @@ def select_op( out_dtype, data_dtype, weight_dtype, + use_3xtf32, profile_all=True, use_multiprocessing=False, ): @@ -212,10 +213,7 @@ def select_op( return self.cache[workload] ops = GENERATOR_FUNC_TABLE[self.sm]( - out_dtype, - data_dtype, - weight_dtype, - op_creator=enumerate_conv2d_operators, + out_dtype, data_dtype, weight_dtype, enumerate_conv2d_operators, use_3xtf32 ) ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops)) @@ -250,6 +248,7 @@ def profile( out_dtype, data_dtype, weight_dtype, + use_3xtf32=True, profile_all=True, use_multiprocessing=False, ): @@ -266,6 +265,7 @@ def profile( out_dtype, data_dtype, weight_dtype, + use_3xtf32, profile_all=profile_all, use_multiprocessing=use_multiprocessing, ) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 33a1676a3ef5..ff118da58468 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -157,12 +157,14 @@ def check_align(self, op_name, M, N, K): # When the above issue is resolved, we can remove the alignment check on M below. return all([dim % align == 0 for dim in [M, N, K]]) - def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, batched=False): + def get_default( + self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False + ): """Return the default kernel for the requested architecture. For now, the default kernel was picked arbitrary. """ ops = GENERATOR_FUNC_TABLE[self.sm]( - out_dtype, arg0_dtype, arg1_dtype, op_creator=enumerate_gemm_operators + out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, use_3xtf32 ) default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)] filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops)) @@ -187,6 +189,7 @@ def select_op( out_dtype, arg0_dtype, arg1_dtype, + use_3xtf32, profile_all=True, use_multiprocessing=False, ): @@ -202,7 +205,8 @@ def select_op( out_dtype, arg0_dtype, arg1_dtype, - op_creator=enumerate_gemm_operators, + enumerate_gemm_operators, + use_3xtf32=use_3xtf32, ) ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops)) @@ -229,6 +233,7 @@ def profile( out_dtype, arg0_dtype, arg1_dtype, + use_3xtf32=True, profile_all=True, use_multiprocessing=False, batched=False, @@ -244,6 +249,7 @@ def profile( out_dtype, arg0_dtype, arg1_dtype, + use_3xtf32, profile_all=profile_all, use_multiprocessing=use_multiprocessing, ) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index f9ff7866d33e..decd2bf1f861 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -125,7 +125,7 @@ def get_tile_descriptions(math_inst): ) -def generate_sm80_tensor_op_16816(out_dtype, arg0_dtype, arg1_dtype, op_creator): +def generate_sm80_tensor_op_16816(out_dtype, arg0_dtype, arg1_dtype, op_creator, use_3xtf32=True): """Generate GEMM or Conv2D kernels for Ampere.""" min_cc = 80 max_cc = 1024 @@ -174,25 +174,29 @@ def get_default_tile_descriptions(block_k_factor): DataType.f32, DataType.f32, OpcodeClass.TensorOp, - MathOperation.multiply_add_fast_f32, + MathOperation.multiply_add_fast_f32 if use_3xtf32 else MathOperation.multiply_add, ), ] alignment_constraints = [4, 2, 1] - tile_descriptions = [ - ([128, 128, 16], 4, [4, 2, 1], min_cc, max_cc), - ([128, 128, 16], 3, [4, 2, 1], min_cc, max_cc), - ([256, 64, 16], 3, [4, 2, 1], min_cc, max_cc), - ([64, 256, 16], 3, [2, 4, 1], min_cc, max_cc), - ([128, 64, 16], 4, [2, 2, 1], min_cc, max_cc), - ([64, 128, 16], 4, [2, 2, 1], min_cc, max_cc), - ([64, 64, 16], 3, [2, 2, 1], min_cc, max_cc), - ([128, 128, 32], 3, [4, 2, 1], min_cc, max_cc), - ([256, 64, 32], 3, [4, 2, 1], min_cc, max_cc_smem_limited), - ([64, 256, 32], 3, [2, 4, 1], min_cc, max_cc_smem_limited), - ([128, 64, 32], 3, [2, 2, 1], min_cc, max_cc), - ([64, 128, 32], 3, [2, 2, 1], min_cc, max_cc), - ([64, 64, 32], 3, [2, 2, 1], min_cc, max_cc), - ] + + if use_3xtf32: + tile_descriptions = [ + ([128, 128, 16], 4, [4, 2, 1], min_cc, max_cc), + ([128, 128, 16], 3, [4, 2, 1], min_cc, max_cc), + ([256, 64, 16], 3, [4, 2, 1], min_cc, max_cc), + ([64, 256, 16], 3, [2, 4, 1], min_cc, max_cc), + ([128, 64, 16], 4, [2, 2, 1], min_cc, max_cc), + ([64, 128, 16], 4, [2, 2, 1], min_cc, max_cc), + ([64, 64, 16], 3, [2, 2, 1], min_cc, max_cc), + ([128, 128, 32], 3, [4, 2, 1], min_cc, max_cc), + ([256, 64, 32], 3, [4, 2, 1], min_cc, max_cc_smem_limited), + ([64, 256, 32], 3, [2, 4, 1], min_cc, max_cc_smem_limited), + ([128, 64, 32], 3, [2, 2, 1], min_cc, max_cc), + ([64, 128, 32], 3, [2, 2, 1], min_cc, max_cc), + ([64, 64, 32], 3, [2, 2, 1], min_cc, max_cc), + ] + else: + tile_descriptions = get_default_tile_descriptions(0.5) else: assert out_dtype == "int32" math_instructions = [ From 5038fc3c7eacf0edf11292fef1a36310d15cfbf3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 9 Jan 2022 13:30:54 +0900 Subject: [PATCH 18/22] leave TODO for alignment modification on int8 kernels --- python/tvm/contrib/cutlass/gen_tensor_op.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index decd2bf1f861..2aa918910835 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -227,6 +227,15 @@ def get_tile_descriptions(math_inst): sm80_kernels = generate_tensor_op_common( math_instructions, alignment_constraints, get_tile_descriptions, op_creator ) + + # TODO(masahi): For int8 kernels, The CUTLASS generator modifies the output tensor alignment + # after ops are created. Revisit how important this modification is. + # for op in operations: + # if op.tile_description.threadblock_shape[1] >= 128: + # op.C.alignment = 16 + # else: + # op.C.alignment = 8 + return sm75_kernels + sm80_kernels From 4a4875ccc1ace2b2f25c091b2f96fb6192179622 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 9 Jan 2022 14:53:18 +0900 Subject: [PATCH 19/22] tf32 test working --- python/tvm/contrib/cutlass/gen_tensor_op.py | 36 +++++++++--------- tests/python/contrib/test_cutlass.py | 42 +++++++++++++++++---- 2 files changed, 53 insertions(+), 25 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 2aa918910835..42a49748583d 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -133,24 +133,24 @@ def generate_sm80_tensor_op_16816(out_dtype, arg0_dtype, arg1_dtype, op_creator, def get_default_tile_descriptions(block_k_factor): return [ - ([256, 128, 32 * block_k_factor], 3, [4, 2, 1], min_cc, max_cc), - ([128, 256, 32 * block_k_factor], 3, [2, 4, 1], min_cc, max_cc), - ([256, 64, 32 * block_k_factor], 4, [4, 1, 1], min_cc, max_cc), - ([64, 256, 32 * block_k_factor], 4, [1, 4, 1], min_cc, max_cc), - ([128, 128, 32 * block_k_factor], 3, [2, 2, 1], min_cc, max_cc), - ([128, 128, 32 * block_k_factor], 4, [2, 2, 1], min_cc, max_cc), - ([128, 128, 32 * block_k_factor], 5, [2, 2, 1], min_cc, max_cc), - ([128, 64, 32 * block_k_factor], 6, [2, 2, 1], min_cc, max_cc), - ([64, 128, 32 * block_k_factor], 6, [2, 2, 1], min_cc, max_cc), - ([64, 64, 32 * block_k_factor], 10, [2, 2, 1], min_cc, max_cc), - ([256, 128, 64 * block_k_factor], 3, [4, 2, 1], min_cc, max_cc_smem_limited), - ([128, 256, 64 * block_k_factor], 3, [2, 4, 1], min_cc, max_cc_smem_limited), - ([256, 64, 64 * block_k_factor], 4, [4, 1, 1], min_cc, max_cc_smem_limited), - ([64, 256, 64 * block_k_factor], 4, [1, 4, 1], min_cc, max_cc_smem_limited), - ([128, 128, 64 * block_k_factor], 4, [2, 2, 1], min_cc, max_cc), - ([128, 64, 64 * block_k_factor], 3, [2, 2, 1], min_cc, max_cc), - ([64, 128, 64 * block_k_factor], 3, [2, 2, 1], min_cc, max_cc), - ([64, 64, 64 * block_k_factor], 5, [2, 2, 1], min_cc, max_cc), + ([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc), + ([128, 256, int(32 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc), + ([256, 64, int(32 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc), + ([64, 256, int(32 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc), + ([128, 128, int(32 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), + ([128, 128, int(32 * block_k_factor)], 4, [2, 2, 1], min_cc, max_cc), + ([128, 128, int(32 * block_k_factor)], 5, [2, 2, 1], min_cc, max_cc), + ([128, 64, int(32 * block_k_factor)], 6, [2, 2, 1], min_cc, max_cc), + ([64, 128, int(32 * block_k_factor)], 6, [2, 2, 1], min_cc, max_cc), + ([64, 64, int(32 * block_k_factor)], 10, [2, 2, 1], min_cc, max_cc), + ([256, 128, int(64 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc_smem_limited), + ([128, 256, int(64 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc_smem_limited), + ([256, 64, int(64 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc_smem_limited), + ([64, 256, int(64 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc_smem_limited), + ([128, 128, int(64 * block_k_factor)], 4, [2, 2, 1], min_cc, max_cc), + ([128, 64, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), + ([64, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), + ([64, 64, int(64 * block_k_factor)], 5, [2, 2, 1], min_cc, max_cc), ] if arg0_dtype == "float16" and arg1_dtype == "float16": diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 86c1a2ebd4fb..57f2f39c641b 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -180,10 +180,17 @@ def get_conv2d_nchw_bias_residual(d_shape, w_shape, padding, out_dtype="float16" return bias_add, data -def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False): +def profile_and_build( + mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False, use_3xtf32=True +): mod = partition_for_cutlass(mod) mod, num_cutlass_partition = tune_cutlass_kernels( - mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir + mod, + sm, + use_3xtf32=use_3xtf32, + profile_all=False, + use_multiprocessing=False, + tmp_dir=tmp_dir, ) with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target="cuda", params=params) @@ -201,9 +208,12 @@ def profile_and_build_vm( lib_path="compile.so", vmcode_path="vmcode.ro", use_fast_math=False, + use_3xtf32=True, ): mod = partition_for_cutlass(mod) - mod, num_cutlass_partition = tune_cutlass_kernels(mod, sm, tmp_dir=tmp_dir) + mod, num_cutlass_partition = tune_cutlass_kernels( + mod, sm, use_3xtf32=use_3xtf32, tmp_dir=tmp_dir + ) with tvm.transform.PassContext(opt_level=3): vm_exec = relay.vm.compile(mod, target="cuda", params=params) vm_exec = build_cutlass_kernels_vm( @@ -225,6 +235,7 @@ def verify_dense( run_benchmark=False, data_dtype="float16", weight_dtype="float16", + use_3xtf32=True, ): if not has_cutlass(): return @@ -249,7 +260,9 @@ def verify_dense( ) return else: - rt_mod, dev, num_partition = profile_and_build_vm(mod, params, sm) + rt_mod, dev, num_partition = profile_and_build_vm( + mod, params, sm, use_3xtf32=use_3xtf32 + ) rt_mod_ref, dev = get_ref_vm(mod, params, target=ref_target) x = tvm.nd.array(np_data, device=dev) @@ -257,7 +270,7 @@ def verify_dense( ref_out = get_output_vm(rt_mod_ref, ["data"], [x]) else: rt_mod_ref, dev = get_ref_rt_mod(mod, params, target=ref_target) - rt_mod, dev, num_partition = profile_and_build(mod, params, sm) + rt_mod, dev, num_partition = profile_and_build(mod, params, sm, use_3xtf32=use_3xtf32) x = tvm.nd.array(np_data, device=dev) out = get_output(rt_mod, ["data"], [x]) ref_out = get_output(rt_mod_ref, ["data"], [x]) @@ -316,12 +329,27 @@ def test_dense(): verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K) # Test align1 case verify_dense(get_dense_bias(M, N + 1, K), M, N + 1, K) + # int8 verify_dense( get_dense(M, N, K, "int32", "int8", "int8"), M, N, K, data_dtype="int8", weight_dtype="int8" ) - # Test 3xtf32 kernels + + dense_fp32 = get_dense(M, N, K, "float32", "float32", "float32") + # tf32 + verify_dense( + dense_fp32, + M, + N, + K, + data_dtype="float32", + weight_dtype="float32", + use_3xtf32=False, + atol=1e-2, + rtol=1e-2, + ) + # 3xtf32 verify_dense( - get_dense(M, N, K, "float32", "float32", "float32"), + dense_fp32, M, N, K, From 59f86908708bed853405639334f0d101f8cc2611 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 9 Jan 2022 15:12:31 +0900 Subject: [PATCH 20/22] fix default kernel for tf32 --- python/tvm/contrib/cutlass/gen_gemm.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index ff118da58468..445acb9305c8 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -132,7 +132,11 @@ def enumerate_gemm_operators( 80: { ("float16", "float16"): "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", ("float16", "float32"): "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1", - ("float32", "float32"): "cutlass_tensorop_s1688gemm_64x64_16x3_tn_align1", + # two kernels for tf32 and 3xtf32 + ("float32", "float32"): ( + "cutlass_tensorop_s1688gemm_128x64_32x3_tn_align1", + "cutlass_tensorop_s1688gemm_64x64_16x3_tn_align1", + ), }, } @@ -167,6 +171,12 @@ def get_default( out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, use_3xtf32 ) default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)] + + if arg0_dtype == "float32": + default_kernel_name = ( + default_kernel_name[0] if not use_3xtf32 else default_kernel_name[1] + ) + filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops)) assert len(filtered) == 1 op = filtered[0] From 9b7f835d09e3ea9275a10c62fa5ceb6dc31b5c7a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 9 Jan 2022 17:08:16 +0900 Subject: [PATCH 21/22] workaround for compilation failure --- python/tvm/contrib/cutlass/gen_tensor_op.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 42a49748583d..6bb4f290233e 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -306,6 +306,9 @@ def evaluate(self, op, args): opath = os.path.join(self.binary_prefix, op_name) if not os.path.exists(opath): self._compile(op) + if not os.path.exists(opath): + # Bail out if compilation fails for a whatever reason (e.g. static assert failure) + return float("inf") cmd = [opath] for arg in args: cmd.append(str(arg)) From b42f0433e0d627d4873d7733ae4669667bdd6575 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 11 Jan 2022 15:36:40 +0900 Subject: [PATCH 22/22] lint --- python/tvm/contrib/cutlass/gen_conv2d.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index ff5f512428ea..c09017adfd95 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -154,6 +154,9 @@ def __init__(self, sm, cutlass_path, binary_path): self.cache = {} def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32): + """Return the default kernel for the requested architecture. + For now, the default kernel was picked arbitrary. + """ gemm_profile_result = self.gemm_profiler.get_default( op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32 )