From b75199e80ebed61f914dacde08c4ff8090d33a3f Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Tue, 29 Mar 2022 12:09:12 -0700 Subject: [PATCH 1/2] [CUBLAS] Add cuBLAS as a Relay partitioning target (BYOC) This PR adds a partitioning pass for cuBLAS so that supported Relay patterns can be offloaded to cuBLAS. This initial commit only adds offloading support for nn.matmul. Although cuBLAS is already enabled in TVM by using strategy selection in TE, by exposing it explicitly as a Relay partitioning target we can more precisely describe how to execute a model in Relay. This is desirable particularly in the Collage effort to improve multi-backend graph partitioning. --- python/tvm/relay/op/contrib/cublas.py | 158 ++++++++++++++++++++++++++ tests/python/contrib/test_cublas.py | 90 ++++++++++++++- tests/scripts/task_mypy.sh | 3 + 3 files changed, 248 insertions(+), 3 deletions(-) create mode 100644 python/tvm/relay/op/contrib/cublas.py diff --git a/python/tvm/relay/op/contrib/cublas.py b/python/tvm/relay/op/contrib/cublas.py new file mode 100644 index 000000000000..a1c807adfee5 --- /dev/null +++ b/python/tvm/relay/op/contrib/cublas.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""cuBLAS Relay integration.""" +from typing import Callable, List, Tuple, Dict, Optional + +import tvm +import tvm.ir +from tvm import relay +from tvm import te +from tvm.relay import transform +from tvm.contrib import cublas + +from ...dataflow_pattern import is_op, wildcard +from .register import register_pattern_table + + +def partition_for_cublas( + mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None +) -> tvm.IRModule: + """Partition the graph to offload for cuBLAS. + + Parameters + ---------- + mod : tvm.IRModule + The module to partition. + params : Optional[Dict[str, tvm.runtime.NDArray]] + Constant input parameters. + + Returns + ------- + tvm.IRModule + The partitioned module. + """ + + seq = tvm.transform.Sequential( + [ + transform.InferType(), + transform.MergeComposite(pattern_table()), + transform.AnnotateTarget("cublas"), + transform.PartitionGraph(), + transform.InferType(), + ] + ) + return seq(mod) + + +@register_pattern_table("cublas") +def pattern_table() -> List[Tuple[str, relay.Pattern, Callable]]: + """Get the cuBLAS pattern table.""" + + def matmul_pattern() -> relay.Pattern: + """Create pattern for matrix multiply.""" + return is_op("nn.matmul")(wildcard(), wildcard()) + + def check_matmul(matched: relay.Call) -> bool: + """Check if matmul is supported by cuBLAS.""" + # Units not supported + if matched.attrs["units"] != None: + return False + # Input data types can't be mixed + if matched.args[0].checked_type.dtype != matched.args[1].checked_type.dtype: + return False + in_dtype = matched.args[0].checked_type.dtype + out_dtype = matched.checked_type.dtype + # Only the following data type combinations are supported + if (in_dtype, out_dtype) not in [ + ("float32", "float32"), + ("float16", "float16"), + ("float16", "float32"), + ("int8", "int32"), + ("float64", "float64"), + ("int8", "float32"), + ]: + return False + # If inputs are int8, input column strides must be a multiple of 4 + if in_dtype == "int8": + if ( + matched.args[0].checked_type.shape[1] % 4 != 0 + or matched.args[1].checked_type.shape[1] % 4 != 0 + ): + return False + + return True + + return [ + ("cublas.matmul", matmul_pattern(), check_matmul), + ] + + +_LOWER_MAP = {} + + +def lower_composite(comp_name: str) -> Callable: + """Register a lowering function for a given composite function name.""" + + def _register(f: Callable) -> Callable: + _LOWER_MAP[comp_name] = f + return f + + return _register + + +@lower_composite("cublas.matmul") +def lower_matmul( + comp_func: relay.Function, target: tvm.target.Target, global_name: str +) -> tvm.runtime.Module: + """Lower a matmul using cuBLAS.""" + op = comp_func.body + A = te.placeholder( + comp_func.params[0].checked_type.shape, + name="A", + dtype=comp_func.params[0].checked_type.dtype, + ) + B = te.placeholder( + comp_func.params[1].checked_type.shape, + name="B", + dtype=comp_func.params[1].checked_type.dtype, + ) + C = cublas.matmul( + A, + B, + transa=op.attrs["transpose_a"], + transb=op.attrs["transpose_b"], + dtype=comp_func.body.checked_type.dtype, + ) + s = te.create_schedule(C.op) + return tvm.build(s, [A, B, C], target=target, name=global_name) + + +@tvm._ffi.register_func("relay.ext.cublas") +def relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module: + """Compile cuBLAS Relay functions to a runtime module.""" + assert isinstance(partition, relay.Function) + assert isinstance(partition.body, relay.Call) + assert isinstance(partition.body.op, relay.Function) + + global_name = str(partition.attrs.global_symbol) + target = tvm.target.cuda() + comp_func = partition.body.op + comp_name = comp_func.attrs["Composite"] + assert comp_name in _LOWER_MAP + + return _LOWER_MAP[comp_name](comp_func, target, global_name) diff --git a/tests/python/contrib/test_cublas.py b/tests/python/contrib/test_cublas.py index 210e6877c926..6af13a3f0778 100644 --- a/tests/python/contrib/test_cublas.py +++ b/tests/python/contrib/test_cublas.py @@ -14,12 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest + import tvm from tvm import te +from tvm import relay import numpy as np from tvm.contrib import cublas from tvm.contrib import cublaslt +from tvm.contrib import graph_executor import tvm.testing +from tvm.relay.op.contrib import get_pattern_table +from tvm.relay.op.contrib.cublas import partition_for_cublas def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5): @@ -170,7 +176,85 @@ def test_batch_matmul(): verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "int8", "int32") +@tvm.testing.requires_cuda +@pytest.mark.parametrize( + "n,m,k,transpose_A,transpose_B", + [ + (64, 128, 32, False, False), + (17, 32, 16, True, False), + (24, 17, 12, False, True), + (96, 4, 17, True, True), + ], +) +@pytest.mark.parametrize( + "in_dtype,out_dtype", + [ + ("float32", "float32"), + ("float16", "float16"), + ("float16", "float32"), + ("int8", "int32"), + ("float64", "float64"), + ("int8", "float32"), + ], +) +def test_relay_cublas_matmul(n, m, k, in_dtype, out_dtype, transpose_A, transpose_B): + np.random.seed(42) + pattern_table = get_pattern_table("cublas") + matmul_pattern = None + check_matmul = None + for pattern_name, pattern, check_func in pattern_table: + if pattern_name == "cublas.matmul": + matmul_pattern = pattern + check_matmul = check_func + + assert matmul_pattern is not None + assert check_matmul is not None + A_shape = (k, n) if transpose_A else (n, k) + B_shape = (m, k) if transpose_B else (k, m) + + def _create_mod(): + mod = tvm.IRModule() + A_var = tvm.relay.var("A", tvm.relay.TensorType(A_shape, in_dtype)) + B_var = tvm.relay.var("B", tvm.relay.TensorType(B_shape, in_dtype)) + # Directly use matmul because nn.matmul sometimes defers to nn.dense + C_var = relay.op.nn._make.matmul(A_var, B_var, None, out_dtype, transpose_A, transpose_B) + f = tvm.relay.Function([A_var, B_var], C_var) + mod["main"] = f + mod = relay.transform.InferType()(mod) + return mod + + mod = _create_mod() + if not matmul_pattern.match(mod["main"].body) or not check_matmul(mod["main"].body): + pytest.skip("Unsupported parameters") + + cublas_mod = partition_for_cublas(mod) + # Assert that a new global function has been created for the cuBLAS partition + assert len(cublas_mod.get_global_vars()) == 2 + + A_data = np.random.uniform(0, 32, size=A_shape).astype(in_dtype) + B_data = np.random.uniform(0, 32, size=B_shape).astype(in_dtype) + + # Test against CPU reference + cuda_config = (tvm.target.cuda(), tvm.cuda(), cublas_mod) + cpu_config = (tvm.target.Target("llvm"), tvm.cpu(), mod) + outputs = [] + for target, dev, test_mod in [cuda_config, cpu_config]: + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(test_mod, target=target, target_host=cpu_config[0]) + a = tvm.nd.array(A_data, dev) + b = tvm.nd.array(B_data, dev) + module = graph_executor.GraphModule(lib["default"](dev)) + module.set_input("A", a) + module.set_input("B", b) + module.run() + outputs.append(module.get_output(0, tvm.nd.empty((n, m), dtype=out_dtype)).numpy()) + + tvm.testing.assert_allclose( + outputs[0], + outputs[1], + rtol=1e-2, + ) + + if __name__ == "__main__": - test_matmul_add() - test_batch_matmul() - test_matmul_add_igemm() + pytest.main([__file__]) diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index 2148aeb5e4b4..b7589d1d30e8 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -36,6 +36,9 @@ mypy --check-untyped-defs python/tvm/tir/transform/ echo "Checking MyPy Type defs in the TIR package with unittest" MYPYPATH=$TVM_PATH/python mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py +echo "Checking MyPy Type defs in tvm.relay.op.contrib.cublas" +mypy --disallow-untyped-defs python/tvm/relay/op/contrib/cublas.py + #TODO(@mikepapadim): This is failing atm # echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package." # mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/ From eb4c71f01ae350618b826bb5a3921442515d2850 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Wed, 30 Mar 2022 10:49:25 +0000 Subject: [PATCH 2/2] Refactor to remove boilerplate --- python/tvm/relay/op/contrib/cublas.py | 66 +++++++-------- tests/python/contrib/test_cublas.py | 114 +++++++++++++------------- 2 files changed, 90 insertions(+), 90 deletions(-) diff --git a/python/tvm/relay/op/contrib/cublas.py b/python/tvm/relay/op/contrib/cublas.py index a1c807adfee5..09505cdaa8d1 100644 --- a/python/tvm/relay/op/contrib/cublas.py +++ b/python/tvm/relay/op/contrib/cublas.py @@ -60,7 +60,7 @@ def partition_for_cublas( @register_pattern_table("cublas") -def pattern_table() -> List[Tuple[str, relay.Pattern, Callable]]: +def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], bool]]]: """Get the cuBLAS pattern table.""" def matmul_pattern() -> relay.Pattern: @@ -70,7 +70,7 @@ def matmul_pattern() -> relay.Pattern: def check_matmul(matched: relay.Call) -> bool: """Check if matmul is supported by cuBLAS.""" # Units not supported - if matched.attrs["units"] != None: + if matched.attrs["units"] is not None: return False # Input data types can't be mixed if matched.args[0].checked_type.dtype != matched.args[1].checked_type.dtype: @@ -102,46 +102,20 @@ def check_matmul(matched: relay.Call) -> bool: ] -_LOWER_MAP = {} +_LowerFunc = Callable[[relay.Call, List[te.Tensor]], te.Tensor] +_LOWER_MAP: Dict[str, _LowerFunc] = {} -def lower_composite(comp_name: str) -> Callable: +def _lower_composite(comp_name: str) -> Callable[[_LowerFunc], _LowerFunc]: """Register a lowering function for a given composite function name.""" - def _register(f: Callable) -> Callable: + def _register(f: _LowerFunc) -> _LowerFunc: _LOWER_MAP[comp_name] = f return f return _register -@lower_composite("cublas.matmul") -def lower_matmul( - comp_func: relay.Function, target: tvm.target.Target, global_name: str -) -> tvm.runtime.Module: - """Lower a matmul using cuBLAS.""" - op = comp_func.body - A = te.placeholder( - comp_func.params[0].checked_type.shape, - name="A", - dtype=comp_func.params[0].checked_type.dtype, - ) - B = te.placeholder( - comp_func.params[1].checked_type.shape, - name="B", - dtype=comp_func.params[1].checked_type.dtype, - ) - C = cublas.matmul( - A, - B, - transa=op.attrs["transpose_a"], - transb=op.attrs["transpose_b"], - dtype=comp_func.body.checked_type.dtype, - ) - s = te.create_schedule(C.op) - return tvm.build(s, [A, B, C], target=target, name=global_name) - - @tvm._ffi.register_func("relay.ext.cublas") def relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module: """Compile cuBLAS Relay functions to a runtime module.""" @@ -154,5 +128,31 @@ def relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module: comp_func = partition.body.op comp_name = comp_func.attrs["Composite"] assert comp_name in _LOWER_MAP + assert isinstance(comp_func.body, relay.Call) - return _LOWER_MAP[comp_name](comp_func, target, global_name) + op = comp_func.body + inputs = [] + for i, param in enumerate(comp_func.params): + inputs.append( + te.placeholder( + param.checked_type.shape, + name=f"input_{i}", + dtype=param.checked_type.dtype, + ) + ) + + output = _LOWER_MAP[comp_name](op, inputs) + prim_func = te.create_prim_func(inputs + [output]) + return tvm.build(prim_func, target=target, name=global_name) + + +@_lower_composite("cublas.matmul") +def _lower_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor: + """Lower a matmul using cuBLAS.""" + return cublas.matmul( + inputs[0], + inputs[1], + transa=op.attrs["transpose_a"], + transb=op.attrs["transpose_b"], + dtype=op.checked_type.dtype, + ) diff --git a/tests/python/contrib/test_cublas.py b/tests/python/contrib/test_cublas.py index 6af13a3f0778..64d954e50cfe 100644 --- a/tests/python/contrib/test_cublas.py +++ b/tests/python/contrib/test_cublas.py @@ -176,63 +176,21 @@ def test_batch_matmul(): verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "int8", "int32") -@tvm.testing.requires_cuda -@pytest.mark.parametrize( - "n,m,k,transpose_A,transpose_B", - [ - (64, 128, 32, False, False), - (17, 32, 16, True, False), - (24, 17, 12, False, True), - (96, 4, 17, True, True), - ], -) -@pytest.mark.parametrize( - "in_dtype,out_dtype", - [ - ("float32", "float32"), - ("float16", "float16"), - ("float16", "float32"), - ("int8", "int32"), - ("float64", "float64"), - ("int8", "float32"), - ], -) -def test_relay_cublas_matmul(n, m, k, in_dtype, out_dtype, transpose_A, transpose_B): +def _verify_cublas_relay(expr): np.random.seed(42) - pattern_table = get_pattern_table("cublas") - matmul_pattern = None - check_matmul = None - for pattern_name, pattern, check_func in pattern_table: - if pattern_name == "cublas.matmul": - matmul_pattern = pattern - check_matmul = check_func - - assert matmul_pattern is not None - assert check_matmul is not None - A_shape = (k, n) if transpose_A else (n, k) - B_shape = (m, k) if transpose_B else (k, m) - - def _create_mod(): - mod = tvm.IRModule() - A_var = tvm.relay.var("A", tvm.relay.TensorType(A_shape, in_dtype)) - B_var = tvm.relay.var("B", tvm.relay.TensorType(B_shape, in_dtype)) - # Directly use matmul because nn.matmul sometimes defers to nn.dense - C_var = relay.op.nn._make.matmul(A_var, B_var, None, out_dtype, transpose_A, transpose_B) - f = tvm.relay.Function([A_var, B_var], C_var) - mod["main"] = f - mod = relay.transform.InferType()(mod) - return mod - - mod = _create_mod() - if not matmul_pattern.match(mod["main"].body) or not check_matmul(mod["main"].body): - pytest.skip("Unsupported parameters") + mod = tvm.IRModule.from_expr(expr) + mod = relay.transform.InferType()(mod) + func = mod["main"] cublas_mod = partition_for_cublas(mod) - # Assert that a new global function has been created for the cuBLAS partition assert len(cublas_mod.get_global_vars()) == 2 - A_data = np.random.uniform(0, 32, size=A_shape).astype(in_dtype) - B_data = np.random.uniform(0, 32, size=B_shape).astype(in_dtype) + input_data = [] + for param in func.params: + shape = [int(x) for x in param.checked_type.shape] + input_data.append( + (param.name_hint, np.random.uniform(0, 32, size=shape).astype(param.checked_type.dtype)) + ) # Test against CPU reference cuda_config = (tvm.target.cuda(), tvm.cuda(), cublas_mod) @@ -241,13 +199,15 @@ def _create_mod(): for target, dev, test_mod in [cuda_config, cpu_config]: with tvm.transform.PassContext(opt_level=3): lib = relay.build(test_mod, target=target, target_host=cpu_config[0]) - a = tvm.nd.array(A_data, dev) - b = tvm.nd.array(B_data, dev) module = graph_executor.GraphModule(lib["default"](dev)) - module.set_input("A", a) - module.set_input("B", b) + for name, data in input_data: + module.set_input(name, tvm.nd.array(data, dev)) + module.run() - outputs.append(module.get_output(0, tvm.nd.empty((n, m), dtype=out_dtype)).numpy()) + out_type = func.body.checked_type + outputs.append( + module.get_output(0, tvm.nd.empty(out_type.shape, dtype=out_type.dtype)).numpy() + ) tvm.testing.assert_allclose( outputs[0], @@ -256,5 +216,45 @@ def _create_mod(): ) +@tvm.testing.requires_cuda +@pytest.mark.parametrize( + "n,m,k,transpose_a,transpose_b", + [ + (64, 128, 32, False, False), + (17, 32, 16, True, False), + (24, 17, 12, False, True), + (96, 4, 17, True, True), + ], +) +@pytest.mark.parametrize( + "in_dtype,out_dtype", + [ + ("float32", "float32"), + ("float16", "float16"), + ("float16", "float32"), + ("int8", "int32"), + ("float64", "float64"), + ("int8", "float32"), + ], +) +def test_relay_cublas_matmul(n, m, k, in_dtype, out_dtype, transpose_a, transpose_b): + unsupported_configs = [ + (17, 32, 16, "int8", "float32", True, False), + (96, 4, 17, "int8", "float32", True, True), + (17, 32, 16, "int8", "int32", True, False), + (96, 4, 17, "int8", "int32", True, True), + ] + if (n, m, k, in_dtype, out_dtype, transpose_a, transpose_b) in unsupported_configs: + pytest.skip("Unsupported parameters.") + + a_shape = (k, n) if transpose_a else (n, k) + b_shape = (m, k) if transpose_b else (k, m) + a = tvm.relay.var("A", tvm.relay.TensorType(a_shape, in_dtype)) + b = tvm.relay.var("B", tvm.relay.TensorType(b_shape, in_dtype)) + # Directly use matmul because nn.matmul sometimes defers to nn.dense + matmul = relay.op.nn._make.matmul(a, b, None, out_dtype, transpose_a, transpose_b) + _verify_cublas_relay(matmul) + + if __name__ == "__main__": pytest.main([__file__])