From fd67595831c7b8741f30577bc91488bcce34a76a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 23 Dec 2021 14:31:06 +0900 Subject: [PATCH 1/5] Refactor cutlass kernel generation and selection --- python/tvm/contrib/cutlass/build.py | 78 +++---- python/tvm/contrib/cutlass/gen_conv2d.py | 194 ++++++++++------- python/tvm/contrib/cutlass/gen_gemm.py | 230 +++++++++++--------- python/tvm/contrib/cutlass/gen_tensor_op.py | 18 ++ python/tvm/relay/op/contrib/cutlass.py | 23 +- 5 files changed, 302 insertions(+), 241 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 3bc3b5defaf2..e921302eafce 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -94,15 +94,17 @@ def visit_call(self, call): def select_gemm_kernel( - cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing + cutlass_profiler, op_type, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing ): """Run CUTLASS profiler to select the best kernel, or return the default one for dynamic workloads.""" if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]): - out = cutlass_profiler.get_default(out_dtype, batched=batched) - logger.info("Picked the default kernel %s", out["name"]) + out = cutlass_profiler.get_default(op_type, out_dtype, batched=batched) + name, cutlass_op_def = out["name"], out["opdef"] + logger.info("Picked the default kernel %s", name) else: - out = cutlass_profiler.profile( + name, cutlass_op_def, _ = cutlass_profiler.profile( + op_type, MM, NN, KK, @@ -112,10 +114,11 @@ def select_gemm_kernel( use_multiprocessing=use_multiprocessing, ) if profile_all: - logger.info("The best kernel is %s", out["name"]) + logger.info("The best kernel is %s", name) else: - logger.info("Picked the first kernel found %s", out["name"]) - return out + logger.info("Picked the first kernel found %s", name) + + return name, cutlass_op_def def handle_batch_matmul( @@ -126,24 +129,17 @@ def handle_batch_matmul( KK = arg0_shape[2] NN = arg1_shape[1] - out = select_gemm_kernel( - cutlass_profiler, MM, KK, NN, out_dtype, True, profile_all, use_multiprocessing + name, cutlass_op_def = select_gemm_kernel( + cutlass_profiler, op_type, MM, KK, NN, out_dtype, True, profile_all, use_multiprocessing ) - if op_type == "cutlass.batch_matmul": - cutlass_op_def = out["opdef"] - else: - raise ValueError("%s pattern is not implemented." % op_type) - - assert "tn_align" in out["name"], "Only supports (row_major, col_major) input layout for now." - return { "batch": arg0_shape[0], "batch_stride_A": arg0_shape[1] * arg0_shape[2], "batch_stride_B": arg1_shape[1] * arg1_shape[2], "batch_stride_C": arg0_shape[1] * arg1_shape[1], "cutlass_op_def": cutlass_op_def, - "cutlass_op_name": out["name"], + "cutlass_op_name": name, "lda": "K", "ldb": "K", "ldc": "N", @@ -158,26 +154,15 @@ def handle_dense( KK = arg0_shape[1] NN = arg1_shape[0] - out = select_gemm_kernel( - cutlass_profiler, MM, KK, NN, out_dtype, False, profile_all, use_multiprocessing + name, cutlass_op_def = select_gemm_kernel( + cutlass_profiler, op_type, MM, KK, NN, out_dtype, False, profile_all, use_multiprocessing ) - if op_type == "cutlass.dense": - cutlass_op_def = out["opdef"] - elif op_type == "cutlass.dense_bias": - cutlass_op_def = out["opdef_bias"] - elif op_type == "cutlass.dense_bias_relu": - cutlass_op_def = out["opdef_bias_relu"] - elif "cutlass.dense_bias_gelu" in op_type: - cutlass_op_def = out["opdef_bias_gelu"] - else: - raise ValueError("%s pattern is not implemented." % op_type) - - assert "tn_align" in out["name"], "Only supports (row_major, col_major) input layout for now." + assert "tn_align" in name, "Only supports (row_major, col_major) input layout for now." return { "cutlass_op_def": cutlass_op_def, - "cutlass_op_name": out["name"], + "cutlass_op_name": name, "lda": "K", "ldb": "K", "ldc": "N", @@ -198,10 +183,12 @@ def handle_conv2d( ): """Profile and select a kernel for conv2d op workload.""" if any(isinstance(s, tvm.tir.Any) for s in d_shape): - out = cutlass_profiler.get_default(out_dtype) - logger.info("Picked the default kernel %s", out["name"]) + out = cutlass_profiler.get_default(op_type, out_dtype) + name, cutlass_op_def = out["name"], out["opdef"] + logger.info("Picked the default kernel %s", name) else: - out = cutlass_profiler.profile( + name, cutlass_op_def, _ = cutlass_profiler.profile( + op_type, d_shape, w_shape, padding, @@ -212,28 +199,13 @@ def handle_conv2d( use_multiprocessing=use_multiprocessing, ) if profile_all: - logger.info("The best kernel is %s", out["name"]) + logger.info("The best kernel is %s", name) else: - logger.info("Picked the first kernel found %s", out["name"]) - - if op_type == "cutlass.conv2d": - cutlass_op_def = out["opdef"] - elif op_type == "cutlass.conv2d_bias": - cutlass_op_def = out["opdef_bias"] - elif op_type == "cutlass.conv2d_bias_relu": - cutlass_op_def = out["opdef_bias_relu"] - elif op_type == "cutlass.conv2d_bias_sigmoid": - cutlass_op_def = out["opdef_bias_sigmoid"] - elif op_type == "cutlass.conv2d_bias_silu": - cutlass_op_def = out["opdef_bias_silu"] - elif op_type == "cutlass.conv2d_bias_hardswish": - cutlass_op_def = out["opdef_bias_hardswish"] - else: - raise ValueError("%s pattern is not implemented." % op_type) + logger.info("Picked the first kernel found %s", name) return { "cutlass_op_def": cutlass_op_def, - "cutlass_op_name": out["name"], + "cutlass_op_name": name, } diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index 43317f9054bb..9516d9be82dd 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -20,10 +20,7 @@ from .conv2d_operation import Conv2dOperation, EmitConv2dInstance from .gen_gemm import CutlassGemmProfiler from .conv2d_profiler import Conv2dProfilerEmitter -from .gen_tensor_op import ( - ProfilerEngine, - GENERATOR_FUNC_TABLE, -) +from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP from .library import ( EpilogueFunctor, SwizzlingFunctor, @@ -35,7 +32,39 @@ ) -def create_conv2d_operator( +def create_conv2d_operator_with_epilogue( + op_type, tile_description, data_type, alignment, swizzling_functor +): + """TODO""" + epilogue, no_beta_scaling = EPILOGUE_MAP[op_type] + + element_a, element_b, element_c, element_epilogue = data_type + + A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment) + B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment) + C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment) + + op = Conv2dOperation( + ConvKind.Fprop, + IteratorAlgorithm.Optimized, + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + StrideSupport.Strided, + epilogue, + swizzling_functor, + ) + + name = op.procedural_name() + opdef = EmitConv2dInstance().emit(op, no_beta_scaling=no_beta_scaling) + + return name, opdef + + +def enumerate_conv2d_operators( tile_descriptions, data_type, alignment_constraints, @@ -48,77 +77,38 @@ def create_conv2d_operator( profiler_emitter = Conv2dProfilerEmitter() element_a, element_b, element_c, element_epilogue = data_type - iterator_algorithms = [IteratorAlgorithm.Optimized] - layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) for tile in tile_descriptions: for alignment in alignment_constraints: - alignment_c = min(8, alignment) - - A = TensorDescription(element_a, layout[0], alignment) - B = TensorDescription(element_b, layout[1], alignment) - C = TensorDescription(element_c, layout[2], alignment_c) - - swizzling_functor_ = swizzling_functor - - for iterator_algorithm in iterator_algorithms: - op_entry = {} - - op = Conv2dOperation( - ConvKind.Fprop, - iterator_algorithm, - tile.minimum_compute_capability, - tile, - A, - B, - C, - element_epilogue, - StrideSupport.Strided, - EpilogueFunctor.LinearCombination, - swizzling_functor_, - ) - - op_entry["opdef"] = kernel_emitter.emit(op) - op_entry["op"] = op - op_entry["src"] = profiler_emitter.emit(op_entry["opdef"], op.procedural_name()) - op_entry["name"] = op.procedural_name() - - # fused ops - for epilogue, opdef, no_bias_scaling in zip( - [ - EpilogueFunctor.LinearCombinationBias, - EpilogueFunctor.LinearCombinationRelu, - EpilogueFunctor.LinearCombinationSigmoid, - EpilogueFunctor.LinearCombinationSilu, - EpilogueFunctor.LinearCombinationHardSwish, - ], - [ - "opdef_bias", - "opdef_bias_relu", - "opdef_bias_sigmoid", - "opdef_bias_silu", - "opdef_bias_hardswish", - ], - [True, True, False, False, False], - ): - op = Conv2dOperation( - ConvKind.Fprop, - iterator_algorithm, - tile.minimum_compute_capability, - tile, - A, - B, - C, - element_epilogue, - StrideSupport.Strided, - epilogue, - swizzling_functor_, - ) - - op_entry[opdef] = kernel_emitter.emit(op, no_bias_scaling) - - ret.append(op_entry) + A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment) + B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment) + C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment) + + op = Conv2dOperation( + ConvKind.Fprop, + IteratorAlgorithm.Optimized, + tile.minimum_compute_capability, + tile, + A, + B, + C, + element_epilogue, + StrideSupport.Strided, + EpilogueFunctor.LinearCombination, + swizzling_functor, + ) + + ret.append( + { + "src": profiler_emitter.emit(kernel_emitter.emit(op), op.procedural_name()), + "name": op.procedural_name(), + "tile_description": tile, + "alignment": alignment, + "data_type": data_type, + "swizzle_functor": swizzling_functor, + } + ) return ret @@ -133,12 +123,15 @@ def __init__(self, sm, cutlass_path, binary_path): self.engine = ProfilerEngine(sm, cutlass_path, binary_path) self.cache = {} - def get_default(self, out_dtype): - gemm_profile_result = self.gemm_profiler.get_default(out_dtype) + def get_default(self, op_type, out_dtype): + gemm_profile_result = self.gemm_profiler.get_default(op_type, out_dtype) tile_description = gemm_profile_result["tile_description"] alignment = gemm_profile_result["alignment"] data_type = gemm_profile_result["data_type"] - return create_conv2d_operator([tile_description], data_type, [alignment])[0] + name, opdef = create_conv2d_operator_with_epilogue( + op_type, tile_description, data_type, alignment, SwizzlingFunctor.Identity4 + ) + return {"name": name, "opdef": opdef} def check_align(self, op_name, C, K): """Filter out kernels that cannot be supported.""" @@ -147,7 +140,7 @@ def check_align(self, op_name, C, K): align = int(aligns[0][-1]) return all([dim % align == 0 for dim in [C, K]]) - def profile( + def select_op( self, d_shape, w_shape, @@ -158,10 +151,7 @@ def profile( profile_all=True, use_multiprocessing=False, ): - """Profile and select the best kernel from candidate kernels. - If profile_all is False, return immediately after the first applicable kernel is found. - If use_multiprocessing is True, compile all profiler executables in parallel. - """ + """TODO""" N, H, W, IC = d_shape OC, R, S, _ = w_shape workload = ( @@ -183,7 +173,10 @@ def profile( if workload in self.cache: return self.cache[workload] - ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator) + ops = GENERATOR_FUNC_TABLE[self.sm]( + out_dtype, + op_creator=enumerate_conv2d_operators, + ) ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops)) if profile_all: @@ -201,6 +194,39 @@ def profile( self.cache[workload] = op return op - output = min(ops, key=lambda i: i["runtime"]) - self.cache[workload] = output - return output + op = min(ops, key=lambda i: i["runtime"]) + self.cache[workload] = op + return op + + def profile( + self, + op_type, + d_shape, + w_shape, + padding, + stride, + dilation, + out_dtype, + profile_all=True, + use_multiprocessing=False, + ): + """Profile and select the best kernel from candidate kernels. + If profile_all is False, return immediately after the first applicable kernel is found. + If use_multiprocessing is True, compile all profiler executables in parallel. + """ + op = self.select_op( + d_shape, + w_shape, + padding, + stride, + dilation, + out_dtype, + profile_all=profile_all, + use_multiprocessing=use_multiprocessing, + ) + + name, opdef = create_conv2d_operator_with_epilogue( + op_type, op["tile_description"], op["data_type"], op["alignment"], op["swizzle_functor"] + ) + + return name, opdef, op["runtime"] diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 7048c32fe1da..2cb89cfbfe48 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -16,14 +16,10 @@ # under the License. # pylint: disable=invalid-name """GEMM kernel generator and profiler for CUTLASS.""" -from functools import partial import re from .gemm_operation import GemmOperation, EmitGemmInstance from .gemm_profiler import GemmProfilerEmitter -from .gen_tensor_op import ( - ProfilerEngine, - GENERATOR_FUNC_TABLE, -) +from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP from .library import ( EpilogueFunctor, SwizzlingFunctor, @@ -33,12 +29,47 @@ ) -def create_gemm_operator( +def create_gemm_operator_with_epilogue( + op_type, + tile_description, + data_type, + alignment, + swizzling_functor, + batched=False, +): + """TODO""" + element_a, element_b, element_c, element_epilogue = data_type + + A = TensorDescription(element_a, LayoutType.RowMajor, alignment) + B = TensorDescription(element_b, LayoutType.ColumnMajor, alignment) + C = TensorDescription(element_c, LayoutType.RowMajor, alignment) + + if batched: + swizzling_functor = SwizzlingFunctor.Batched + + epilogue, no_beta_scaling = EPILOGUE_MAP[op_type] + + op = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + epilogue, + swizzling_functor, + ) + + return op.procedural_name(), EmitGemmInstance().emit( + op, no_beta_scaling=no_beta_scaling, batched=batched + ) + + +def enumerate_gemm_operators( tile_descriptions, data_type, alignment_constraints, swizzling_functor=SwizzlingFunctor.Identity8, - batched=False, ): """Exhaustively instantiate all kernels from a given configuration.""" ret = [] @@ -47,86 +78,44 @@ def create_gemm_operator( element_a, element_b, element_c, element_epilogue = data_type - if batched: - swizzling_functor = SwizzlingFunctor.Batched + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + A = TensorDescription(element_a, LayoutType.RowMajor, alignment) + B = TensorDescription(element_b, LayoutType.ColumnMajor, alignment) + C = TensorDescription(element_c, LayoutType.RowMajor, alignment) + + op = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + EpilogueFunctor.LinearCombination, + swizzling_functor, + ) + + src = profiler_emitter.emit( + op.procedural_name(), + kernel_emitter.emit(op, batched=False), + DataTypeTag[element_a], + DataTypeTag[element_b], + DataTypeTag[element_c], + op.leading_dim(), + ) + + ret.append( + { + "src": src, + "op": op, + "name": op.procedural_name(), + "tile_description": tile_description, + "alignment": alignment, + "data_type": data_type, + "swizzle_functor": swizzling_functor, + } + ) - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), - ] - - for layout in layouts: - for tile_description in tile_descriptions: - for alignment in alignment_constraints: - alignment_c = min(8, alignment) - - A = TensorDescription(element_a, layout[0], alignment) - B = TensorDescription(element_b, layout[1], alignment) - C = TensorDescription(element_c, layout[2], alignment_c) - - op_entry = {} - op = GemmOperation( - tile_description.minimum_compute_capability, - tile_description, - A, - B, - C, - element_epilogue, - EpilogueFunctor.LinearCombination, - swizzling_functor, - ) - op_bias = GemmOperation( - tile_description.minimum_compute_capability, - tile_description, - A, - B, - C, - element_epilogue, - EpilogueFunctor.LinearCombinationBias, - swizzling_functor, - ) - op_bias_relu = GemmOperation( - tile_description.minimum_compute_capability, - tile_description, - A, - B, - C, - element_epilogue, - EpilogueFunctor.LinearCombinationRelu, - swizzling_functor, - ) - op_bias_gelu = GemmOperation( - tile_description.minimum_compute_capability, - tile_description, - A, - B, - C, - element_epilogue, - EpilogueFunctor.LinearCombinationGelu, - swizzling_functor, - ) - - op_entry["op"] = op - op_entry["name"] = op.procedural_name() - op_entry["opdef"] = kernel_emitter.emit(op, batched=batched) - op_entry["opdef_bias"] = kernel_emitter.emit( - op_bias, no_beta_scaling=True, batched=batched - ) - op_entry["opdef_bias_relu"] = kernel_emitter.emit( - op_bias_relu, no_beta_scaling=True, batched=batched - ) - op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu, batched=batched) - op_entry["src"] = profiler_emitter.emit( - op.procedural_name(), - kernel_emitter.emit(op, batched=False), - DataTypeTag[element_a], - DataTypeTag[element_b], - DataTypeTag[element_c], - op.leading_dim(), - ) - op_entry["tile_description"] = tile_description - op_entry["alignment"] = alignment - op_entry["data_type"] = data_type - ret.append(op_entry) return ret @@ -164,30 +153,35 @@ def check_align(self, op_name, M, N, K): # When the above issue is resolved, we can remove the alignment check on M below. return all([dim % align == 0 for dim in [M, N, K]]) - def get_default(self, out_dtype, batched=False): + def get_default(self, op_type, out_dtype, batched=False): """Return the default kernel for the requested architecture. For now, the default kernel was picked arbitrary. """ - ops = GENERATOR_FUNC_TABLE[self.sm]( - out_dtype, op_creator=partial(create_gemm_operator, batched=batched) - ) + ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=enumerate_gemm_operators) default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype] filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops)) assert len(filtered) == 1 - return filtered[0] + op = filtered[0] + name, opdef = create_gemm_operator_with_epilogue( + op_type, + op["tile_description"], + op["data_type"], + op["alignment"], + op["swizzle_functor"], + batched=batched, + ) + op.update({"name": name, "opdef": opdef}) + return op - def profile( - self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False, batched=False - ): - """Profile and select the best kernel from candidate kernels. - If profile_all is False, return immediately after the first applicable kernel is found. - If use_multiprocessing is True, compile all profiler executables in parallel. - """ + def select_op(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False): + """TODO""" if (M, N, K) in self.cache: - return self.cache[(M, N, K)] + op = self.cache[(M, N, K)] + return op ops = GENERATOR_FUNC_TABLE[self.sm]( - out_dtype, op_creator=partial(create_gemm_operator, batched=batched) + out_dtype, + op_creator=enumerate_gemm_operators, ) ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops)) @@ -201,6 +195,36 @@ def profile( self.cache[(M, N, K)] = op return op - output = min(ops, key=lambda i: i["runtime"]) - self.cache[(M, N, K)] = output - return output + op = min(ops, key=lambda i: i["runtime"]) + self.cache[(M, N, K)] = op + return op + + def profile( + self, + op_type, + M, + N, + K, + out_dtype, + profile_all=True, + use_multiprocessing=False, + batched=False, + ): + """Profile and select the best kernel from candidate kernels. + If profile_all is False, return immediately after the first applicable kernel is found. + If use_multiprocessing is True, compile all profiler executables in parallel. + """ + op = self.select_op( + M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing + ) + + name, opdef = create_gemm_operator_with_epilogue( + op_type, + op["tile_description"], + op["data_type"], + op["alignment"], + op["swizzle_functor"], + batched=batched, + ) + + return name, opdef, op["runtime"] diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 9ccde37bfe91..4b2486b0cb55 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -27,6 +27,7 @@ OpcodeClass, MathOperation, TileDescription, + EpilogueFunctor, ) logger = logging.getLogger("cutlass") @@ -165,6 +166,23 @@ def get_tile_descriptions(math_inst): } +# (Epilogue functor name, no_beta_scaling) +EPILOGUE_MAP = { + "cutlass.dense": (EpilogueFunctor.LinearCombination, True), + "cutlass.dense_bias": (EpilogueFunctor.LinearCombinationBias, True), + "cutlass.dense_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True), + "cutlass.dense_bias_gelu_fp16": (EpilogueFunctor.LinearCombinationGelu, False), + "cutlass.dense_bias_gelu_fp32": (EpilogueFunctor.LinearCombinationGelu, False), + "cutlass.batch_matmul": (EpilogueFunctor.LinearCombination, True), + "cutlass.conv2d_bias_hardswish": (EpilogueFunctor.LinearCombinationHardSwish, False), + "cutlass.conv2d_bias_silu": (EpilogueFunctor.LinearCombinationSilu, False), + "cutlass.conv2d_bias_sigmoid": (EpilogueFunctor.LinearCombinationSigmoid, False), + "cutlass.conv2d_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True), + "cutlass.conv2d_bias": (EpilogueFunctor.LinearCombinationBias, True), + "cutlass.conv2d": (EpilogueFunctor.LinearCombination, True), +} + + class ProfilerEngine: """Compile and run a given profiler executable.""" diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index eb36dc2d7c9f..96691ae47e9d 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name """Patterns supported CUTLASS.""" +from functools import partial +from tvm import relay from tvm.ir.transform import Sequential, PassContext from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name @@ -95,6 +97,8 @@ def check_dtype(lhs, rhs): def get_root_call(call, root_op_name): + if not isinstance(call, relay.Call): + return None if str(call.op) == root_op_name: return call return get_root_call(call.args[0], root_op_name) @@ -151,13 +155,27 @@ def partition_for_cutlass(mod, params=None): make_gemm_pattern(True, "gelu", out_dtype="float32"), check_gemm, ) - cutlass_patterns = [ + + dense_patterns = [ dense_bias_gelu_fp16_pat, dense_bias_gelu_fp32_pat, dense_bias_relu_pat, dense_bias_pat, dense_pat, ("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul), + ] + + conv2d_patterns = [ + ( + "cutlass.conv2d_bias_hardswish", + make_conv2d_pattern(with_bias=True, with_act="hardswish"), + check_conv2d, + ), + ( + "cutlass.conv2d_bias_silu", + make_conv2d_pattern(with_bias=True, with_act="silu"), + check_conv2d, + ), ( "cutlass.conv2d_bias_hardswish", make_conv2d_pattern(with_bias=True, with_act="hardswish"), @@ -182,6 +200,8 @@ def partition_for_cutlass(mod, params=None): ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d), ] + cutlass_patterns = dense_patterns + conv2d_patterns + if params is not None: mod["main"] = bind_params_by_name(mod["main"], params) remove_bn_pass = Sequential( @@ -198,6 +218,7 @@ def partition_for_cutlass(mod, params=None): seq = Sequential( [ transform.InferType(), + transform.SimplifyExpr(), transform.MergeComposite(cutlass_patterns), transform.AnnotateTarget(["cutlass"], include_non_call_ops=False), transform.PartitionGraph(bind_constants=False), From 87b36dbbb11adb582ffb628fc6ad62668dcdee7e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 23 Dec 2021 14:59:40 +0900 Subject: [PATCH 2/5] fill in TODO doc --- python/tvm/contrib/cutlass/gen_conv2d.py | 10 ++++++++-- python/tvm/contrib/cutlass/gen_gemm.py | 10 ++++++++-- python/tvm/relay/op/contrib/cutlass.py | 1 - 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index 9516d9be82dd..4e4a7b2458e2 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -35,7 +35,10 @@ def create_conv2d_operator_with_epilogue( op_type, tile_description, data_type, alignment, swizzling_functor ): - """TODO""" + """ + Instantiate a cutlass kernel from the given configuration, + along with the epilouge functor + """ epilogue, no_beta_scaling = EPILOGUE_MAP[op_type] element_a, element_b, element_c, element_epilogue = data_type @@ -151,7 +154,10 @@ def select_op( profile_all=True, use_multiprocessing=False, ): - """TODO""" + """ + Profile and select the best kernel from candidate kernels. + See the documentation for the profile method below. + """ N, H, W, IC = d_shape OC, R, S, _ = w_shape workload = ( diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 2cb89cfbfe48..9159ed881c74 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -37,7 +37,10 @@ def create_gemm_operator_with_epilogue( swizzling_functor, batched=False, ): - """TODO""" + """ + Instantiate a cutlass kernel from the given configuration, + along with the epilouge functor + """ element_a, element_b, element_c, element_epilogue = data_type A = TensorDescription(element_a, LayoutType.RowMajor, alignment) @@ -174,7 +177,10 @@ def get_default(self, op_type, out_dtype, batched=False): return op def select_op(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False): - """TODO""" + """ + Profile and select the best kernel from candidate kernels. + See the documentation for the profile method below. + """ if (M, N, K) in self.cache: op = self.cache[(M, N, K)] return op diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 96691ae47e9d..92b8f9e979ef 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=invalid-name """Patterns supported CUTLASS.""" -from functools import partial from tvm import relay from tvm.ir.transform import Sequential, PassContext from tvm.relay import transform From d3b681d95977b6fc0965a0a3ec8af3f866bd9e91 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 23 Dec 2021 15:47:07 +0900 Subject: [PATCH 3/5] fix no_beta_scaling values --- python/tvm/contrib/cutlass/conv2d_operation.py | 2 +- python/tvm/contrib/cutlass/gen_tensor_op.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py index 35308928cdab..1c7f9a31b955 100644 --- a/python/tvm/contrib/cutlass/conv2d_operation.py +++ b/python/tvm/contrib/cutlass/conv2d_operation.py @@ -186,7 +186,7 @@ def __init__(self): >::Kernel; """ - def emit(self, operation, no_beta_scaling=True): + def emit(self, operation, no_beta_scaling=False): """Instantiate a Conv2d kernel from given `operation`.""" warp_shape = [ int( diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 4b2486b0cb55..6632b159febd 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -168,18 +168,18 @@ def get_tile_descriptions(math_inst): # (Epilogue functor name, no_beta_scaling) EPILOGUE_MAP = { - "cutlass.dense": (EpilogueFunctor.LinearCombination, True), + "cutlass.dense": (EpilogueFunctor.LinearCombination, False), "cutlass.dense_bias": (EpilogueFunctor.LinearCombinationBias, True), "cutlass.dense_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True), "cutlass.dense_bias_gelu_fp16": (EpilogueFunctor.LinearCombinationGelu, False), "cutlass.dense_bias_gelu_fp32": (EpilogueFunctor.LinearCombinationGelu, False), - "cutlass.batch_matmul": (EpilogueFunctor.LinearCombination, True), + "cutlass.batch_matmul": (EpilogueFunctor.LinearCombination, False), "cutlass.conv2d_bias_hardswish": (EpilogueFunctor.LinearCombinationHardSwish, False), "cutlass.conv2d_bias_silu": (EpilogueFunctor.LinearCombinationSilu, False), "cutlass.conv2d_bias_sigmoid": (EpilogueFunctor.LinearCombinationSigmoid, False), "cutlass.conv2d_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True), "cutlass.conv2d_bias": (EpilogueFunctor.LinearCombinationBias, True), - "cutlass.conv2d": (EpilogueFunctor.LinearCombination, True), + "cutlass.conv2d": (EpilogueFunctor.LinearCombination, False), } From ce9d52fd629d6119abdd471b00ff6a79223d6752 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 23 Dec 2021 15:56:32 +0900 Subject: [PATCH 4/5] Remove SimplifyExpr pass from the pipeline (makes DETR result nan) --- python/tvm/relay/op/contrib/cutlass.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 92b8f9e979ef..9750f3c300a4 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -217,7 +217,6 @@ def partition_for_cutlass(mod, params=None): seq = Sequential( [ transform.InferType(), - transform.SimplifyExpr(), transform.MergeComposite(cutlass_patterns), transform.AnnotateTarget(["cutlass"], include_non_call_ops=False), transform.PartitionGraph(bind_constants=False), From 6bb1c3bb36b7ec098e97accb7b5d518dc575e484 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 23 Dec 2021 16:10:23 +0900 Subject: [PATCH 5/5] fix bad merge --- python/tvm/relay/op/contrib/cutlass.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 9750f3c300a4..cbbc45a5d1c0 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -165,16 +165,6 @@ def partition_for_cutlass(mod, params=None): ] conv2d_patterns = [ - ( - "cutlass.conv2d_bias_hardswish", - make_conv2d_pattern(with_bias=True, with_act="hardswish"), - check_conv2d, - ), - ( - "cutlass.conv2d_bias_silu", - make_conv2d_pattern(with_bias=True, with_act="silu"), - check_conv2d, - ), ( "cutlass.conv2d_bias_hardswish", make_conv2d_pattern(with_bias=True, with_act="hardswish"),