From a2ca2d7a1d0f21fa1936c1b60aa7db7b74863211 Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Tue, 28 May 2024 08:39:49 +0000 Subject: [PATCH 1/4] [SME][TOPI] Add conv2d NHWC SME fp32 schedule This commit adds a scalable `arm_cpu` conv2d NHWC schedule for fp32 which generates SME instructions by using the tensor intrinsics introduced in #16921. Alongside the SME schedule, the logic of the TE schedule `schedule_conv2d_gemm_native()` for both non-scalable and scalable vector implementations has also been translated into the new TIR schedule. This means that the TE compute definition `compute_conv2d_NHWC_hybrid()` is now compatible with both the original TE schedules (e.g. `schedule_conv2d_NHWC_hybrid()`) and the newly introduced TIR schedule `schedule_conv2d_NHWC_hybrid_TIR()`. The corresponding TOPI test has been extended to reflect that. --- python/tvm/relay/op/strategy/arm_cpu.py | 15 ++ python/tvm/topi/arm_cpu/arm_utils.py | 18 +- python/tvm/topi/arm_cpu/conv2d.py | 237 +++++++++++++++++- python/tvm/topi/arm_cpu/conv2d_gemm.py | 12 +- python/tvm/topi/nn/conv2d.py | 6 +- src/arith/scalable_expression.cc | 7 - .../codegen/test_target_codegen_aarch64.py | 69 ++++- .../relay/strategy/arm_cpu/test_conv2d.py | 138 +++++++++- .../strategy/test_select_implementation.py | 8 + tests/python/topi/test_topi_conv2d_nhwc.py | 36 ++- 10 files changed, 513 insertions(+), 33 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 9974d2691d4b..ff3676656a5c 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -254,6 +254,18 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): ) # Non-quantized cases if is_aarch64 and data.dtype in ["float32", "float16"]: + if ( + target.features.has_sme + and data.dtype in ["float32"] + and kernel.dtype in ["float32"] + and out_type.dtype in ["float32"] + ): + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME), + lambda: None, + name="conv2d_NHWC_hybrid_SME.arm_cpu", + plevel=12, + ) if target.features.has_sve: # This strategy is currently suboptimal because of LLVM's limited support # for scalable vector alias analysis, which causes redundant loads / stores @@ -801,6 +813,9 @@ def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool: if current_target.features.has_sme and has_block(sch, "matmul_sme_gemm"): topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch) return True + elif has_block(sch, "conv2d_gemm_output"): + topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR(sch) + return True # Fallback to TE schedule for operators we have not written a special TIR schedule for return False diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py index f2e01c5aefd6..5c4b3c045661 100644 --- a/python/tvm/topi/arm_cpu/arm_utils.py +++ b/python/tvm/topi/arm_cpu/arm_utils.py @@ -22,7 +22,7 @@ from tvm.tir.expr import PrimExpr -def get_tiling_A(interleave_A, in_dtype): +def get_tiling_A(interleave_A, in_dtype, use_sme=False): """Compute the tiling information for matrix A in C=A*B, which corresponds to the im2col-transformed input matrix. @@ -42,6 +42,8 @@ def get_tiling_A(interleave_A, in_dtype): determines if A is expected to be interleaved in_dtype : str input datatype + use_sme : bool + determines if SME operations on scalable vectors are expected Returns ---------- @@ -65,8 +67,11 @@ def get_tiling_A(interleave_A, in_dtype): # tile size should be 4x16 tile_M = 4 tile_K = 16 + elif use_sme: + tile_M = 2 * 4 * tvm.tir.vscale() + tile_K = 2 * 4 * tvm.tir.vscale() else: - # In non-quantized cases, A is not interleaved. + # In non-SME, non-quantized cases, A is not interleaved. # We are loading 4 rows from A. # Each row will contain 4 elements, along the dimension of reduction tile_M = 4 @@ -75,7 +80,7 @@ def get_tiling_A(interleave_A, in_dtype): return tile_M, tile_K -def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False): +def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False, use_sme=False): """Compute the tiling information for matrix B', where B' is the tiled, interleaved (and transposed) version of matrix B in C=A*B. @@ -97,6 +102,8 @@ def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False) input datatype use_scalable_vectors : bool determines if operations on scalable vectors are expected + use_sme : bool + determines if SME operations on scalable vectors are expected Returns @@ -131,7 +138,10 @@ def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False) # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements tile_N = 4 tile_K = 16 - # In non-quantized cases, A is not interleaved. + elif use_sme: + tile_N = 2 * 4 * tvm.tir.vscale() + tile_K = 2 * 4 * tvm.tir.vscale() + # In non-SME, non-quantized cases, A is not interleaved. elif use_scalable_vectors: if in_dtype == "float16": # Each load from B' contains 32 * vscale elements (i.e. 32 * vscale columns from B) diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index 44c4f7f76f69..05548f3b013e 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -21,13 +21,15 @@ import tvm from tvm import te from tvm import autotvm +from tvm.script import tir as T import tvm.contrib.nnpack +from tvm.tir.schedule.analysis import has_block from ..utils import traverse_inline, get_const_tuple from .. import nn from ..nn.utils import get_const_int, get_pad_tuple from ..nn.winograd_util import winograd_transform_matrices -from .arm_utils import get_tiling_B_transformed +from .arm_utils import get_tiling_A, get_tiling_B_transformed from .conv2d_spatial_pack import ( conv2d_spatial_pack_nchw, conv2d_spatial_pack_nhwc, @@ -527,13 +529,16 @@ def compute_conv2d_NHWC( out_dtype, interleave_A, use_scalable_vectors=False, + use_sme=False, ): """Compute definition for conv2d NHWC""" N, IH, IW, IC = get_const_tuple(data.shape) KH, KW, _, OC = get_const_tuple(kernel.shape) - tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype, use_scalable_vectors) + tile_N, tile_K = get_tiling_B_transformed( + interleave_A, data.dtype, use_scalable_vectors, use_sme + ) - kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors) + kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors, use_sme) return compute_conv2d_gemm_without_weight_transform( cfg, data, @@ -546,6 +551,7 @@ def compute_conv2d_NHWC( OC, interleave_A, use_scalable_vectors, + use_sme, ) @@ -655,3 +661,228 @@ def compute_conv2d_NHWC_hybrid_SVE(cfg, data, kernel, strides, padding, dilation def schedule_conv2d_NHWC_hybrid_SVE(cfg, outs): """Interface for hybrid schedule_conv2d_NHWC_hybrid_SVE""" return schedule_conv2d_NHWC(cfg, outs, False) + + +@autotvm.register_topi_compute("conv2d_NHWC_hybrid_SME.arm_cpu") +def compute_conv2d_NHWC_hybrid_SME(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Interface for hybrid compute_conv2d_NHWC_hybrid_SME""" + return compute_conv2d_NHWC( + cfg, + data, + kernel, + strides, + padding, + dilation, + out_dtype, + False, + True, + True, + ) + + +def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): + """ + Perform TIR scheduling for conv2d NHWC. + """ + # Get ordered buffer list + primfunc = sch.mod["main"] + buffer_names = primfunc.params + buffer_list = [primfunc.buffer_map[buf] for buf in buffer_names] + dtype = buffer_list[0].dtype + + # Determine PrimFunc blocks + block_list = [ + "data_pad", + "data_im2col", + "T_reshape", + "A_padded_K", + "A_padded_M", + "weight_flatten", + "C", + "conv2d_gemm_output", + ] + func_blocks = {} + for block in block_list: + func_blocks[block] = sch.get_block(block) if has_block(sch, block) else None + + gemm_block = func_blocks["C"] + b, m, n, k = sch.get_loops(gemm_block) + + # Get tiling information + use_scalable_vectors = sch.get(func_blocks["conv2d_gemm_output"]).annotations[ + "use_scalable_vectors" + ] + use_sme = sch.get(func_blocks["conv2d_gemm_output"]).annotations["use_sme"] + M_padded = sch.get(m).extent + N_padded = sch.get(n).extent + K_padded = sch.get(k).extent + tile_M, tile_K = get_tiling_A(False, dtype, use_sme) + tile_N, _ = get_tiling_B_transformed(False, dtype, use_scalable_vectors, use_sme) + tile_M = T.cast(tile_M, M_padded.dtype) + tile_N = T.cast(tile_N, N_padded.dtype) + tile_K = T.cast(tile_K, K_padded.dtype) + + # GeMM + # Compute each tile_M x tile_N tile + # By summing up K outer products + if use_sme: + # pylint: disable=import-outside-toplevel + from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes + from tvm.tir.tensor_intrin.arm_cpu import ( + ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE, + ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA, + ARM_SME_INIT, + get_sme_gemm_interleaved_mopa_2svlx2svl_intrin, + ) + + # Interleave the padded im2col matrix utilizing the matrix tile + interleave_t_A_block = sch.cache_read(gemm_block, 0, "global") + sch.transform_layout(interleave_t_A_block, ("write", 0), lambda b, m, k: (b, k, m)) + b, m, k = sch.get_loops(interleave_t_A_block) + mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True) + ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True) + sch.reorder(b, ko, mo, ki, mi) + sch.tensorize(ki, ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE) + + # Split and reorder the loops of the GeMM for tensorization + b, m, n, k = sch.get_loops(gemm_block) + mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True) + no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True) + sch.parallel(b) + sch.reorder(b, mo, no, mi, ni, k) + + # Tensorize the GeMM output matrix initialization to zero + init_block = sch.decompose_reduction(gemm_block, mi) + sch.tensorize(sch.get_loops(init_block)[-2], ARM_SME_INIT) + + # Tensorize the GeMM update + sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}" + tvm.tir.TensorIntrin.register( + sme_gemm_interleaved_intrin_name, + *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded), + override=True, + ) + sch.tensorize(mi, sme_gemm_interleaved_intrin_name) + + # Add pstate annotations + root_block = sch.get_block("root") + sch.annotate( + root_block, SMEAttributes.STREAMING_MODE, SMEAttributes.StreamingModeValues.ENABLED + ) + sch.annotate(root_block, SMEAttributes.ZA_STORAGE, SMEAttributes.ZAStorageValues.NEW) + elif use_scalable_vectors: + mo, mi = sch.split(m, [None, tile_M]) + no, ni = sch.split(n, [None, tile_N], disable_predication=True) + ko, ki = sch.split(k, [None, tile_K]) + b_mo_fused = sch.fuse(b, mo) + sch.parallel(b_mo_fused) + sch.reorder( + b_mo_fused, + no, + ko, + ki, + mi, + ni, + ) + sch.vectorize(ni) + sch.unroll(mi) + + # GeMM - Init + # Initialise an entire GeMM tile at once + sch.decompose_reduction(gemm_block, ko) + else: + mo, mi = sch.split(m, [None, tile_M]) + no, ni = sch.split(n, [None, tile_N]) + ko, ki = sch.split(k, [None, tile_K]) + ni_outer, ni_inner = sch.split(ni, [4, None]) + b_mo_fused = sch.fuse(b, mo) + sch.parallel(b_mo_fused) + sch.reorder( + b_mo_fused, + no, + ko, + ki, + ni_outer, + mi, + ni_inner, + ) + sch.vectorize(ni_inner) + sch.unroll(mi) + sch.unroll(ni_outer) + + # GeMM - Init + # Initialise an entire GeMM tile at once + sch.decompose_reduction(gemm_block, ko) + + # Input padding + if func_blocks["data_pad"]: + input_padding_block = func_blocks["data_pad"] + b, h, w, ic = sch.get_loops(input_padding_block) + b_h_fused = sch.fuse(b, h) + sch.parallel(b_h_fused) + + # Im2col + padding to tile size + # Computed outside GeMM + if func_blocks["data_im2col"]: + im2col_block = func_blocks["data_im2col"] + b1, m1, k1 = sch.get_loops(im2col_block) + b_m_fused_1 = sch.fuse(b1, m1) + if func_blocks["A_padded_K"]: + im2col_pad_K_block = func_blocks["A_padded_K"] + b2, m2, k2 = sch.get_loops(im2col_pad_K_block) + b_m_fused_2 = sch.fuse(b2, m2) + sch.parallel(b_m_fused_2) + sch.compute_at(im2col_block, b_m_fused_2) + _, k1 = sch.get_loops(sch.get_block("data_im2col")) + elif func_blocks["A_padded_M"]: + im2col_pad_M_block = func_blocks["A_padded_M"] + b2, m2, k2 = sch.get_loops(im2col_pad_M_block) + b_m_fused_2 = sch.fuse(b2, m2) + sch.parallel(b_m_fused_1) + sch.parallel(b_m_fused_2) + else: + sch.parallel(b_m_fused_1) + + K = sch.get(k1).extent.value + if K % 16 == 0: + split_factor = 16 + elif K % 8 == 0: + split_factor = 8 + else: + IC = buffer_list[0].shape[3] + split_factor = IC + k_outer, k_inner = sch.split(k1, [None, split_factor]) + sch.vectorize(k_inner) + sch.unroll(k_outer) + + # Reshape + padding to tile size + # Computed inside GeMM + elif func_blocks["T_reshape"]: + reshape_block = func_blocks["T_reshape"] + A_pad_block = func_blocks["A_padded_K"] if func_blocks["A_padded_K"] else None + A_pad_block = func_blocks["A_padded_M"] if func_blocks["A_padded_M"] else A_pad_block + if use_sme: + sch.compute_inline(reshape_block) + elif A_pad_block: + sch.compute_inline(reshape_block) + b, m, k = sch.get_loops(A_pad_block) + _, k_inner = sch.split(k, [None, tile_N]) + sch.vectorize(k_inner) + sch.compute_at(A_pad_block, mi) + else: + sch.compute_at(reshape_block, mi) + + # Weight flattening + if func_blocks["weight_flatten"]: + weight_flatten_block = func_blocks["weight_flatten"] + sch.compute_inline(weight_flatten_block) + + # Conv2d output block + output_block = sch.get_block("conv2d_gemm_output") + n, h, w, c = sch.get_loops(output_block) + n_h_fused = sch.fuse(n, h) + _, inner = sch.split(c, [None, 4]) + sch.vectorize(inner) + sch.parallel(n_h_fused) + + return sch diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index 5ff2ccb2c137..0c3908bb7017 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -68,6 +68,7 @@ def compute_conv2d_gemm_without_weight_transform( output_channels, interleave_A, use_scalable_vectors=False, + use_sme=False, ): """Compute conv2d by transforming the input, executing GEMM and transforming the output back""" @@ -123,9 +124,12 @@ def compute_conv2d_gemm_without_weight_transform( ) # Select the tiling strategy for A and B - tile_M, tile_K_A = arm_utils.get_tiling_A(interleave_A, in_dtype) + tile_M, tile_K_A = arm_utils.get_tiling_A(interleave_A, in_dtype, use_sme) tile_N, tile_K_B = arm_utils.get_tiling_B_transformed( - interleave_A, in_dtype, use_scalable_vectors + interleave_A, + in_dtype, + use_scalable_vectors, + use_sme, ) # Pad to tiles (if necessary) @@ -285,7 +289,7 @@ def compute_conv2d_gemm_without_weight_transform( tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] ) - elif use_scalable_vectors: + elif use_scalable_vectors or use_sme: assert len(B_interleaved_t.shape) == 2 C = te.compute( (batches, M_padded, N_padded), @@ -333,7 +337,7 @@ def compute_conv2d_gemm_without_weight_transform( out_shape, lambda b, x, y, z: (C(b, y + OW * x, z) + zero).astype(out_dtype), name="conv2d_gemm_output", - attrs={"use_scalable_vectors": use_scalable_vectors}, + attrs={"use_scalable_vectors": use_scalable_vectors, "use_sme": use_sme}, ) return out diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index e21c0bd4e106..8d61c622504b 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -615,7 +615,7 @@ def conv2d_NCHWc_int8( ) -def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=False): +def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=False, use_sme=False): """Weight transformation for winograd Parameters @@ -628,6 +628,8 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=Fa Tile size across K axis of the weight transformation for ConvGemm. (K = KW * KH * IC) use_scalable_vectors : bool determines if operations on scalable vectors are expected + use_sme : bool + determines if SME operations on scalable vectors are expected Returns ------- @@ -652,7 +654,7 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=Fa kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padding" ) - if use_scalable_vectors: + if use_sme or use_scalable_vectors: return kernel_flat if kernel.dtype in ["int8", "uint8"]: diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc index 2df035d6151a..8821ef661c04 100644 --- a/src/arith/scalable_expression.cc +++ b/src/arith/scalable_expression.cc @@ -71,15 +71,8 @@ std::optional ExtractVscaleFactor(const PrimExpr& lanes) { } } -bool IsComparison(const PrimExpr& expr) { - return expr->IsInstance() || expr->IsInstance() || - expr->IsInstance() || expr->IsInstance() || - expr->IsInstance() || expr->IsInstance(); -} - bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const PrimExpr& expr, const std::vector& vscale_values) { - ICHECK(IsComparison(expr)) << "Expected comparison but got: " << expr; bool can_prove_expr = true; for (const unsigned int vscale_value : vscale_values) { PrimExpr result = SubstituteVScaleWithKnownValue(expr, vscale_value); diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index f73d96e7c916..d76a8488e07f 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -729,20 +729,36 @@ def prim_func(a: T.handle, c: T.handle): llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) @pytest.mark.parametrize("dtype", ["float16", "float32"]) -def test_conv2d_sve(dtype): +@pytest.mark.parametrize( + "conv2d_impl", + [ + ( + tvm.topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE, + tvm.topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE, + False, + ), + ( + tvm.topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE, + tvm.topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR, + True, + ), + ], +) +def test_conv2d_sve(dtype, conv2d_impl): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(dtype): + def check_correct_assembly(dtype, compute, schedule, use_tir_schedule): A = te.placeholder((1, 32, 32, 3), dtype=dtype, name="A") W = te.placeholder((3, 3, 3, 8), dtype=dtype, name="B") stride = padding = dilation = 1 - - compute = tvm.topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE - schedule = tvm.topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE B = compute(A, W, stride, padding, dilation, dtype) - s = schedule([B]) - - f = tvm.build(s, [A, W, B], target) + if use_tir_schedule: + func = te.create_prim_func([A, W, B]) + sch = schedule(tvm.tir.Schedule(func)) + f = tvm.build(sch.mod["main"], target) + else: + s = schedule([B]) + f = tvm.build(s, [A, W, B], target) assembly = f.get_source("asm") loads = re.findall(r"ld1[r]?[q]?[whdb]\t{\s?z", assembly) @@ -756,6 +772,43 @@ def check_correct_assembly(dtype): assert len(compute_ops) > 0 assert len(stores) > 0 + with tvm.target.Target(target): + check_correct_assembly(dtype, *conv2d_impl) + + +@pytest.mark.skipif( + llvm_version_major() < 16, reason="Test requires an LLVM version of at least 16 to target SVE" +) +@pytest.mark.parametrize("dtype", ["float32"]) +def test_conv2d_sme(dtype): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+v9a,+sme" + + def check_correct_assembly(dtype): + A = te.placeholder((1, 32, 32, 3), dtype=dtype, name="A") + W = te.placeholder((3, 3, 3, 8), dtype=dtype, name="B") + stride = padding = dilation = 1 + + B = tvm.topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME(A, W, stride, padding, dilation, dtype) + func = te.create_prim_func([A, W, B]) + sch = tvm.topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR(tvm.tir.Schedule(func)) + f = tvm.build(sch.mod["main"], target) + + assembly = f.get_source("asm") + smstart = re.findall(r"smstart\t(sm|za)", assembly) + loads = re.findall(r"ld1[whdb]\t{\s?za", assembly) + mopa = re.findall( + r"fmopa\tza[0-9].[shdb],( p[0-9]/[zm],)?( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", + assembly, + ) + stores = re.findall(r"st1[whdb]\t{\s?za", assembly) + smstop = re.findall(r"smstop\t(sm|za)", assembly) + + assert len(smstart) > 0 + assert len(loads) > 0 + assert len(mopa) > 0 + assert len(stores) > 0 + assert len(smstop) > 0 + with tvm.target.Target(target): check_correct_assembly(dtype=dtype) diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_conv2d.py index 1b9c1a5e2e94..a22d61a029d0 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv2d.py @@ -16,8 +16,22 @@ # under the License. """Tests for arm_cpu schedules for regular conv2d.""" +from tests.python.relay.strategy.arm_cpu.scalable_utils import ( + calculate_extra_workspace_size_from_scalable_extents, +) +import tvm +import pytest +import numpy as np from test_generalized_conv2d import GeneralizedConv2dTests +from tvm.target.codegen import llvm_version_major from tvm.testing import fixture, main, parameter, parameters +from tvm import relay +import tvm.topi.testing +from tvm.topi.nn.utils import get_pad_tuple +from tvm.topi.utils import get_const_tuple +from tvm.testing.aot import AOTTestModel, AOTCompiledTestModel, run_and_check, generate_ref_data +from tvm.micro.testing.aot_test_utils import AOT_APROFILE_AEM_RUNNER +from tvm.relay.op.strategy.arm_cpu import arm_cpu_tir_strategy class Conv2dTests(GeneralizedConv2dTests): @@ -107,5 +121,127 @@ class TestConv2d_NCHW_Spatial_Pack(Conv2dTests): schedule_name = parameter("conv2d_nchw_spatial_pack.arm_cpu") +dtype = tvm.testing.parameter("float32") + +batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters( + # Pad M, N, K + (1, 1, 1, 1, 1, 1, "SAME", 1), + (1, 1, 3, 15, 1, 1, "SAME", 1), + # Pad M, K + (1, 3, 9, 16, 3, 1, "SAME", 1), + # Pad M, N + (1, 2, 9, 15, 4, 1, "SAME", 1), + # Pad K, N + (1, 7, 4, 15, 3, 1, "SAME", 1), + # Pad M + (1, 2, 9, 16, 4, 1, "SAME", 1), + # Pad K + (1, 7, 4, 16, 3, 1, "SAME", 1), + # Pad N + (1, 2, 4, 15, 4, 1, "SAME", 1), + (1, 2, 4, 20, 1, 1, "SAME", 1), + # Large workloads + (1, 128, 32, 128, 3, 1, "SAME", 1), + (4, 64, 16, 64, 5, 2, "SAME", 1), + (1, 128, 32, 128, 3, 1, "VALID", 1), + (4, 64, 16, 64, 5, 2, "VALID", 1), + (1, 64, 16, 64, 3, 2, (0, 0, 1, 1), 1), + (1, 64, 16, 64, 3, 2, (1, 1, 2, 2), 1), + (1, 64, 16, 64, 5, 2, (3, 3, 2, 2), 1), + (1, 64, 16, 64, 3, 2, (0, 1, 2, 3), 1), + (1, 64, 32, 64, 3, 1, "SAME", 2), + (1, 64, 32, 64, 3, 1, (1, 1, 2, 2), 2), +) + + +@tvm.testing.fixture(cache_return_value=True) +def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation): + in_height = in_width = in_size + a_shape = (batch, in_height, in_width, in_channel) + w_shape = (kernel, kernel, in_channel, num_filter) + + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + return a_np, w_np + + +@pytest.mark.skipif( + llvm_version_major() < 16, reason="SME is not supported in earlier versions of LLVM" +) +@tvm.testing.requires_aprofile_aem_fvp +def test_conv2d_fp32(target, ref_data, dtype, stride, padding, dilation): + a_np, w_np = ref_data + dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + + kernel_size = get_const_tuple(w_np.shape[:2]) + out_channels = w_np.shape[3] + + x = relay.var("data", shape=a_np.shape, dtype=dtype) + weight = relay.const(w_np, dtype=dtype) + conv2d = relay.nn.conv2d( + x, + weight, + channels=out_channels, + kernel_size=kernel_size, + strides=stride, + dilation=dilation, + padding=get_pad_tuple(padding, dw_np.shape[:2]), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype=dtype, + ) + + func = relay.Function(relay.analysis.free_vars(conv2d), conv2d) + + ir_mod = tvm.IRModule.from_expr(func) + ir_mod = tvm.relay.transform.InferType()(ir_mod) + + inputs = {"data": a_np} + params = {} + ref_outputs = generate_ref_data(ir_mod, inputs, params) + + target = tvm.target.Target("llvm -mtriple=aarch64-none-elf -mattr=+v9.2a,+sme") + runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True}) + executor = tvm.relay.backend.Executor( + "aot", + { + "interface-api": "packed", + "unpacked-api": False, + }, + ) + + with tvm.transform.PassContext( + opt_level=3, config=AOT_APROFILE_AEM_RUNNER.pass_config + ), tvm.meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy): + executor_factory = tvm.relay.build( + ir_mod, + target=target, + executor=executor, + runtime=runtime, + params=params, + ) + generated_func = executor_factory.lowered_ir_mods.items()[0][1][ + "tvmgen_default_fused_nn_conv2d" + ] + extra_memory_in_bytes = calculate_extra_workspace_size_from_scalable_extents(generated_func, 4) + + test_model = AOTTestModel( + ir_mod, inputs, ref_outputs, params=params, extra_memory_in_bytes=extra_memory_in_bytes + ) + compiled = AOTCompiledTestModel(test_model, executor_factory) + + assembly = ( + compiled.executor_factory.module.imported_modules[0].imported_modules[0].get_source("asm") + ) + assert "fmopa" in assembly + + assert run_and_check( + models=[compiled], + interface_api="packed", + runner=AOT_APROFILE_AEM_RUNNER, + print_output_on_mismatch=True, + ) + + if __name__ == "__main__": - main() + tvm.testing.main() diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py index 71dd688e2929..01a914e793c1 100644 --- a/tests/python/relay/strategy/test_select_implementation.py +++ b/tests/python/relay/strategy/test_select_implementation.py @@ -161,6 +161,10 @@ def test_int8_conv2d(target, expected_impl): "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9a", "conv2d_NHWC_hybrid_without_transform.arm_cpu", ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme", + "conv2d_NHWC_hybrid_SME.arm_cpu", + ), ], ) def test_fp32_conv2d(target, expected_impl): @@ -197,6 +201,10 @@ def test_fp32_conv2d(target, expected_impl): "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v9a", "conv2d_NHWC_hybrid_without_transform.arm_cpu", ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", + ), ], ) def test_fp16_conv2d(target, expected_impl): diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index b5c9518d3419..028d1a6cbb47 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -51,16 +51,37 @@ "llvm --device arm_cpu --mtriple aarch64-linux-gnu", topi.arm_cpu.conv2d_nhwc_spatial_pack, topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack, + False, ), ( "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+fullfp16", topi.arm_cpu.compute_conv2d_NHWC_hybrid, topi.arm_cpu.schedule_conv2d_NHWC_hybrid, + False, ), ( "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.6a,+sve", topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE, topi.arm_cpu.schedule_conv2d_NHWC_hybrid_SVE, + False, + ), + ( + "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a", + topi.arm_cpu.compute_conv2d_NHWC_hybrid, + topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR, + True, + ), + ( + "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.6a,+sve", + topi.arm_cpu.compute_conv2d_NHWC_hybrid_SVE, + topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR, + True, + ), + ( + "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v9a,+sme", + topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME, + topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR, + True, ), ) @@ -68,6 +89,7 @@ batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters( # Pad M, N, K + (1, 1, 1, 1, 1, 1, "SAME", 1), (1, 1, 3, 15, 1, 1, "SAME", 1), # Pad M, K (1, 3, 9, 16, 3, 1, "SAME", 1), @@ -139,16 +161,21 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation): A = te.placeholder(a_np.shape, name="A", dtype=dtype) W = te.placeholder(w_np.shape, name="W", dtype=dtype) - target, compute, schedule = device + target, compute, schedule, use_tir_schedule = device dev = tvm.device(target, 0) with tvm.target.Target(target) as target: - B = compute(A, W, stride, padding, dilation, dtype) - s = schedule([B]) a = tvm.nd.array(a_np, dev) w = tvm.nd.array(w_np, dev) + B = compute(A, W, stride, padding, dilation, dtype) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) - func = tvm.build(s, [A, W, B], target) + if use_tir_schedule: + primfunc = te.create_prim_func([A, W, B], index_dtype_override="int64") + sch = schedule(tvm.tir.Schedule(primfunc)) + func = tvm.build(sch.mod["main"], target) + else: + s = schedule([B]) + func = tvm.build(s, [A, W, B], target) # Run only on AArch64 devices # Do not run SVE schedules on non-SVE devices @@ -160,6 +187,7 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation): and target.features.has_fp16_simd and not tvm.testing.requires_arm_fp16.run_time_check() ) + or target.features.has_sme ) if build_only: return From fcc9013c53a64d9247da5fa3eee17812d577bbc0 Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Tue, 28 May 2024 08:39:49 +0000 Subject: [PATCH 2/4] Fix tests --- tests/python/arith/test_arith_simplify.py | 10 ---------- tests/python/topi/test_topi_conv2d_nhwc.py | 13 ++++++++++--- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index fd8316d1e007..1a876548af31 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -90,16 +90,6 @@ def test_simplify_vscale_comparison_without_sve_target(capfd): assert warning_msg in capture -def test_simplify_vscale_non_comparison(): - ana = tvm.arith.Analyzer() - vs = tvm.tir.vscale() - - err_msg = r".*Expected comparison but got: T.vscale\(\) \* 4" - with pytest.raises(tvm.TVMError, match=err_msg): - with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"): - ana.can_prove(vs * 4) - - def test_regression_simplify_inf_recursion(): ana = tvm.arith.Analyzer() cond = tir.Var("cond", "int32") diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index 028d1a6cbb47..672e34c87911 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -21,6 +21,7 @@ import tvm from tvm import te from tvm import topi +from tvm.target.codegen import llvm_version_major import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize from tvm.topi.utils import get_const_tuple @@ -161,10 +162,16 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation): A = te.placeholder(a_np.shape, name="A", dtype=dtype) W = te.placeholder(w_np.shape, name="W", dtype=dtype) - target, compute, schedule, use_tir_schedule = device - dev = tvm.device(target, 0) + target_string, compute, schedule, use_tir_schedule = device + dev = tvm.device(target_string, 0) + target = tvm.target.Target(target_string) + + if (target.features.has_sve and llvm_version_major() < 15) or ( + target.features.has_sme and llvm_version_major() < 16 + ): + return - with tvm.target.Target(target) as target: + with target: a = tvm.nd.array(a_np, dev) w = tvm.nd.array(w_np, dev) B = compute(A, W, stride, padding, dilation, dtype) From 72f8bbd632f81865c8d7d3eda3ca8da9cbd0768b Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Tue, 28 May 2024 08:40:29 +0000 Subject: [PATCH 3/4] Address comments --- python/tvm/testing/utils.py | 7 +++++++ python/tvm/topi/arm_cpu/conv2d.py | 3 ++- .../codegen/test_target_codegen_aarch64.py | 2 +- .../relay/strategy/arm_cpu/test_conv2d.py | 18 +++++++++--------- tests/python/topi/test_topi_conv2d_nhwc.py | 14 ++++++++------ 5 files changed, 27 insertions(+), 17 deletions(-) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 84b631cf3823..a208459dd88d 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1071,6 +1071,13 @@ def _has_cpu_feat(features): ) +requires_aarch64_sme = Feature( + "arm_sme", + "AArch64 SME", + run_time_check=lambda: _has_cpu_feat("sme"), +) + + requires_x86_vnni = Feature( "x86_vnni", "x86 VNNI Extensions", diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index 05548f3b013e..58c909301ede 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -741,6 +741,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): b, m, k = sch.get_loops(interleave_t_A_block) mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True) ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True) + sch.parallel(b) sch.reorder(b, ko, mo, ki, mi) sch.tensorize(ki, ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE) @@ -878,7 +879,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): sch.compute_inline(weight_flatten_block) # Conv2d output block - output_block = sch.get_block("conv2d_gemm_output") + output_block = func_blocks["conv2d_gemm_output"] n, h, w, c = sch.get_loops(output_block) n_h_fused = sch.fuse(n, h) _, inner = sch.split(c, [None, 4]) diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index d76a8488e07f..9d44ca9e11ea 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -777,7 +777,7 @@ def check_correct_assembly(dtype, compute, schedule, use_tir_schedule): @pytest.mark.skipif( - llvm_version_major() < 16, reason="Test requires an LLVM version of at least 16 to target SVE" + llvm_version_major() < 16, reason="Test requires an LLVM version of at least 16 to target SME" ) @pytest.mark.parametrize("dtype", ["float32"]) def test_conv2d_sme(dtype): diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_conv2d.py index a22d61a029d0..2708094afb08 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv2d.py @@ -16,22 +16,21 @@ # under the License. """Tests for arm_cpu schedules for regular conv2d.""" -from tests.python.relay.strategy.arm_cpu.scalable_utils import ( - calculate_extra_workspace_size_from_scalable_extents, -) -import tvm import pytest import numpy as np + +import tvm +import tvm.topi.testing +from tvm import relay from test_generalized_conv2d import GeneralizedConv2dTests -from tvm.target.codegen import llvm_version_major from tvm.testing import fixture, main, parameter, parameters -from tvm import relay -import tvm.topi.testing from tvm.topi.nn.utils import get_pad_tuple from tvm.topi.utils import get_const_tuple +from tvm.target.codegen import llvm_version_major from tvm.testing.aot import AOTTestModel, AOTCompiledTestModel, run_and_check, generate_ref_data from tvm.micro.testing.aot_test_utils import AOT_APROFILE_AEM_RUNNER from tvm.relay.op.strategy.arm_cpu import arm_cpu_tir_strategy +from scalable_utils import calculate_extra_workspace_size_from_scalable_extents class Conv2dTests(GeneralizedConv2dTests): @@ -154,8 +153,9 @@ class TestConv2d_NCHW_Spatial_Pack(Conv2dTests): ) -@tvm.testing.fixture(cache_return_value=True) +@tvm.testing.fixture() def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation): + np.random.seed(0) in_height = in_width = in_size a_shape = (batch, in_height, in_width, in_channel) w_shape = (kernel, kernel, in_channel, num_filter) @@ -212,7 +212,7 @@ def test_conv2d_fp32(target, ref_data, dtype, stride, padding, dilation): with tvm.transform.PassContext( opt_level=3, config=AOT_APROFILE_AEM_RUNNER.pass_config - ), tvm.meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy): + ), target, tvm.meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy): executor_factory = tvm.relay.build( ir_mod, target=target, diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index 672e34c87911..30401337a068 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -17,6 +17,7 @@ """Example code to do convolution.""" import os import platform +import pytest import numpy as np import tvm from tvm import te @@ -166,10 +167,11 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation): dev = tvm.device(target_string, 0) target = tvm.target.Target(target_string) - if (target.features.has_sve and llvm_version_major() < 15) or ( - target.features.has_sme and llvm_version_major() < 16 - ): - return + if target.features.has_sve and llvm_version_major() < 15: + pytest.skip(f"LLVM {llvm_version_major()} does not support targetting SVE.") + + if target.features.has_sme and llvm_version_major() < 16: + pytest.skip(f"LLVM {llvm_version_major()} does not support targetting SME.") with target: a = tvm.nd.array(a_np, dev) @@ -177,7 +179,7 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation): B = compute(A, W, stride, padding, dilation, dtype) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) if use_tir_schedule: - primfunc = te.create_prim_func([A, W, B], index_dtype_override="int64") + primfunc = te.create_prim_func([A, W, B]) sch = schedule(tvm.tir.Schedule(primfunc)) func = tvm.build(sch.mod["main"], target) else: @@ -194,7 +196,7 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation): and target.features.has_fp16_simd and not tvm.testing.requires_arm_fp16.run_time_check() ) - or target.features.has_sme + or (target.features.has_sme and not tvm.testing.requires_aarch64_sme.run_time_check()) ) if build_only: return From 8e934620fcca171982e043b6354570c59dca1641 Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Tue, 28 May 2024 09:26:13 +0000 Subject: [PATCH 4/4] Disable fp16 conv2d SME testing --- tests/python/topi/test_topi_conv2d_nhwc.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index 30401337a068..02f16b59c00b 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -173,6 +173,9 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation): if target.features.has_sme and llvm_version_major() < 16: pytest.skip(f"LLVM {llvm_version_major()} does not support targetting SME.") + if target.features.has_sme and dtype == "float16": + pytest.skip(f"Conv2d fp16 targetting SME not implemented.") + with target: a = tvm.nd.array(a_np, dev) w = tvm.nd.array(w_np, dev)