diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 35fd2b7a78d7..f4b47084017b 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -110,6 +110,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): """conv2d arm cpu strategy""" strategy = _op.OpStrategy() data, kernel = inputs + data_shape = data.shape + kernel_shape = kernel.shape dilation_h, dilation_w = attrs.get_int_tuple("dilation") stride_h, stride_w = attrs.get_int_tuple("strides") padding = attrs.get_int_tuple("padding") @@ -258,6 +260,11 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): target.features.has_sme and kernel.dtype == data.dtype and out_type.dtype == "float32" + and data_shape[0] == 1 + # The schedule uses tensorization which does not work when the + # reduction axis of the gemm has unit iters. See + # https://github.com/apache/tvm/issues/16566 + and (data_shape[3] * kernel_shape[0] * kernel_shape[1]) > 1 ): strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME), diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py index 3a3430af514f..a6f3538846e7 100644 --- a/python/tvm/tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/tir/tensor_intrin/arm_cpu.py @@ -176,7 +176,51 @@ def _create_ptrue_mask(dtype): return T.broadcast(T.bool(True), tir.get_vscale_expr(dtype)) -def get_sme_transpose_interleave_2svlx2svl_fp32_intrin(): +def _create_active_lane_mask(tensor, relative_offsets, vertical_limit): + """ + Get the active lane mask intrinsic call for predicated accesses. + + Parameters + ---------- + tensor : tvm.tir.Buffer + The tensor the buffer access will be performed on. + relative_offsets : Tuple[PrimExpr, PrimExpr] + The vertical and horizontal offsets into the accumulator tile. + vertical_limit : PrimExpr + An absolute offset specifying the limit at which rows should be stored. + + Returns + ------- + PrimExpr + The active lane mask intrinsic. + """ + vertical_offset, horizontal_offset = relative_offsets + stride = tensor.strides[0] + + # The base is the offset of the first value we wish to store + base = T.int32(tensor.offset_of([vertical_offset, horizontal_offset])[0]) + + # The limit is the maximum offset in the current row of 'base' that we wish to allow values + # to be stored. Calculating this limit is a bit tricky since we can only request offsets of + # elements in the tensorized tile of the output tensor. One way to calculate this is to find + # the offset of the first value in the row of the output tensor that 'base' is in and add + # 'stride' to it. + limit = ( + base + - T.int32(horizontal_offset) + - T.int32((tensor.offset_of([0, 0])[0] % stride)) + + T.int32(stride) + ) + limit = T.Min(limit, T.Cast("int32", vertical_limit) * stride) + + return T.get_active_lane_mask( + "uint1xvscalex4", + T.Cast("int32", base), + T.Cast("int32", limit), + ) + + +def get_sme_transpose_interleave_2svlx2svl_fp32_intrin(cols, rows): """ Transpose a matrix of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length) using the Scalable Matrix Extension (SME). @@ -247,9 +291,6 @@ def impl(): strides=[T.int32(), 1], ) - # Disable predication - ptrue = _create_ptrue_mask("float32") - with T.block("root"): T.reads(A[0:SVF2, 0:SVF2]) T.writes(A_t[0:SVF2, 0:SVF2]) @@ -263,19 +304,22 @@ def impl(): input_ptr = A.access_ptr("r", offset=offset) sub_tile = T.int32(sub_tile_idx) + predicate = _create_active_lane_mask( + A, (row_offset + slice_idx, col_offset), cols + ) T.evaluate( T.call_llvm_intrin( "void", "llvm.aarch64.sme.ld1w.horiz", T.uint32(4), - ptrue, + predicate, input_ptr, sub_tile, slice_idx, ) ) - # Store columns to the ouptut matrix + # Store columns to the output matrix with T.serial(0, SVF) as slice_idx: for sub_tile_idx in range(0, sub_tile_count): col_offset = SVF if sub_tile_idx >= (sub_tile_count // 2) else 0 @@ -284,12 +328,15 @@ def impl(): output_ptr = A_t.access_ptr("w", offset=offset) sub_tile = T.int32(sub_tile_idx) + predicate = _create_active_lane_mask( + A_t, (row_offset + slice_idx, col_offset), rows + ) T.evaluate( T.call_llvm_intrin( "void", "llvm.aarch64.sme.st1w.vert", T.uint32(4), - ptrue, + predicate, output_ptr, sub_tile, slice_idx, @@ -445,7 +492,24 @@ def impl(): return desc, impl() -def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, in_dtype): +def get_transpose_interleave_intrin_name(in_dtype, out_dtype, extent_cols, extent_rows): + if in_dtype == "float32" and out_dtype == "float32": + sme_transpose_interleave_intrin_name = ( + ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE + f"_{extent_cols}_{extent_rows}" + ) + tir.TensorIntrin.register( + sme_transpose_interleave_intrin_name, + *get_sme_transpose_interleave_2svlx2svl_fp32_intrin(extent_cols, extent_rows), + override=True, + ) + return sme_transpose_interleave_intrin_name + elif in_dtype == "float16" and out_dtype == "float32": + return ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE + else: + raise ValueError("Input/output data type combination not supported.") + + +def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M, K, in_dtype): """ Compute a GEMM of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length using outer product operations from the Scalable Matrix Extension (SME). @@ -579,15 +643,39 @@ def impl(): k_row = k * rows_per_iter in_dtype_svf = tir.get_vscale_expr(in_dtype) - a_low = T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)]) - b_low = T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)]) - + # Ideally we'd rely on predicating the loads and use the same predicate + # for the outer product operation. However, support for predicated + # buffers is not currently supported by multiple lowering passes such as + # "LowerMatchBuffer", therefore the predicate is passed directly to the + # outer product operation for now. if in_dtype == "float32": - a_high = T.BufferLoad(A, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)]) - b_high = T.BufferLoad(B, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)]) + a_low = ( + T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)]), + _create_active_lane_mask(A, (k_row, 0), K), + ) + b_low = ( + T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)]), + _create_active_lane_mask(B, (k_row, 0), K), + ) + a_high = ( + T.BufferLoad(A, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)]), + _create_active_lane_mask(A, (k_row, in_dtype_svf), K), + ) + b_high = ( + T.BufferLoad(B, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)]), + _create_active_lane_mask(B, (k_row, in_dtype_svf), K), + ) else: - a_high = T.BufferLoad(A, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)]) - b_high = T.BufferLoad(B, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)]) + a_low = (T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)]), ptrue) + b_low = (T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)]), ptrue) + a_high = ( + T.BufferLoad(A, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)]), + ptrue, + ) + b_high = ( + T.BufferLoad(B, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)]), + ptrue, + ) input_combinations = [ (a_low, b_low), @@ -606,10 +694,10 @@ def impl(): fmopa_intrin, T.uint32(5), sub_tile, - ptrue, - ptrue, - input_1, - input_2, + input_1[1], + input_2[1], + input_1[0], + input_2[0], ) ) @@ -626,7 +714,9 @@ def impl(): "void", "llvm.aarch64.sme.st1w.horiz", T.uint32(4), - _create_ptrue_mask("float32"), + _create_active_lane_mask( + C, (vert_offset + slice_idx, horiz_offset), M + ), output_ptr, T.int32(sub_tile_idx), T.int32(slice_idx), @@ -691,10 +781,6 @@ def impl(c: T.handle) -> None: # in versions of LLVM >= 15. Installations with older versions of LLVM will # not be able to use them. if llvm_version_major() >= 15: - TensorIntrin.register( - ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE, - *get_sme_transpose_interleave_2svlx2svl_fp32_intrin(), - ) TensorIntrin.register( ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE, *get_sme_transpose_interleave_block2_2svl_fp16_intrin(), diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index a6c951c07830..b7327d5b52e8 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -24,7 +24,6 @@ from tvm.script import tir as T import tvm.contrib.nnpack from tvm.tir.schedule.analysis import has_block -from tvm.topi.arm_cpu.matmul import _get_transpose_interleave_intrin_name from ..utils import traverse_inline, get_const_tuple from .. import nn @@ -773,10 +772,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA, ARM_SME_INIT, get_sme_gemm_interleaved_mopa_2svlx2svl_intrin, - ) - - transpose_interleave_intrin_name = _get_transpose_interleave_intrin_name( - in_dtype, out_dtype + get_transpose_interleave_intrin_name, ) # Interleave the padded im2col matrix utilizing the matrix tile @@ -787,7 +783,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True) sch.parallel(b) sch.reorder(b, ko, mo, ki, mi) - sch.tensorize(ki, transpose_interleave_intrin_name) + sch.tensorize( + ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded) + ) # Interleave the padded weights matrix utilizing the matrix tile if in_dtype == "float16": @@ -797,7 +795,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True) no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True) sch.reorder(ko, no, ki, ni) - sch.tensorize(ki, transpose_interleave_intrin_name) + sch.tensorize( + ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded) + ) # Split and reorder the loops of the GeMM for tensorization b, m, n, k = sch.get_loops(gemm_block) @@ -816,11 +816,11 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): # Tensorize the GeMM update sme_gemm_interleaved_intrin_name = ( - ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}_{in_dtype}" + ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{M_padded}_{K_padded}_{in_dtype}" ) tvm.tir.TensorIntrin.register( sme_gemm_interleaved_intrin_name, - *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, in_dtype), + *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M_padded, K_padded, in_dtype), override=True, ) sch.tensorize(mi, sme_gemm_interleaved_intrin_name) @@ -922,16 +922,18 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): reshape_block = func_blocks["T_reshape"] A_pad_block = func_blocks["A_padded_K"] if func_blocks["A_padded_K"] else None A_pad_block = func_blocks["A_padded_M"] if func_blocks["A_padded_M"] else A_pad_block - if use_sme: - sch.compute_inline(reshape_block) - elif A_pad_block: - sch.compute_inline(reshape_block) - b, m, k = sch.get_loops(A_pad_block) - _, k_inner = sch.split(k, [None, tile_N]) - sch.vectorize(k_inner) - sch.compute_at(A_pad_block, mi) - else: - sch.compute_at(reshape_block, mi) + use_explicit_predication = use_sme and in_dtype == "float32" + if not use_explicit_predication: + if use_sme: + sch.compute_inline(reshape_block) + elif A_pad_block: + sch.compute_inline(reshape_block) + b, m, k = sch.get_loops(A_pad_block) + _, k_inner = sch.split(k, [None, tile_N]) + sch.vectorize(k_inner) + sch.compute_at(A_pad_block, mi) + else: + sch.compute_at(reshape_block, mi) # Weight flattening if func_blocks["weight_flatten"]: diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index e637aa91e5b4..bf6a9c75516f 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -133,23 +133,25 @@ def compute_conv2d_gemm_without_weight_transform( ) # Pad to tiles (if necessary) - pad_M, pad_K = arm_utils.get_conv2d_im2col_padding(M, K, tile_M, tile_K_A) - pad_N, _ = arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K_B) + use_explicit_predication = use_sme and in_dtype == "float32" + if not use_explicit_predication: + pad_M, pad_K = arm_utils.get_conv2d_im2col_padding(M, K, tile_M, tile_K_A) + pad_N, _ = arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K_B) - M_padded = M + pad_M - K_padded = K + pad_K - N_padded = N + pad_N + M_padded = M + pad_M + K_padded = K + pad_K + N_padded = N + pad_N - pad_before = (0, 0, 0) - pad_after = (0, pad_M, pad_K) + pad_before = (0, 0, 0) + pad_after = (0, pad_M, pad_K) - if pad_K != 0: - A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_K") - elif pad_M != 0: - A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_M") + if pad_K != 0: + A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_K") + elif pad_M != 0: + A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_M") idxm = tvm.tir.indexmod - k = te.reduce_axis((0, K_padded), "k") + k = te.reduce_axis((0, K if use_explicit_predication else K_padded), "k") # Determine matrix multiplication compute definition target = Target.current(allow_none=False) @@ -300,7 +302,18 @@ def compute_conv2d_gemm_without_weight_transform( name="C", ) zero = tvm.tir.const(0) - elif use_scalable_vectors or use_sme: + elif use_explicit_predication: + assert len(B_interleaved_t.shape) == 2 + C = te.compute( + (batches, M, N), + lambda b, x, y: te.sum( + A[b, x, k].astype(in_dtype) * B_interleaved_t[k, y].astype(in_dtype), + axis=k, + ), + name="C", + ) + zero = tvm.tir.const(0) + elif use_scalable_vectors: assert len(B_interleaved_t.shape) == 2 C = te.compute( (batches, M_padded, N_padded), diff --git a/python/tvm/topi/arm_cpu/matmul.py b/python/tvm/topi/arm_cpu/matmul.py index 23b8734a0ba4..63f6289f0eb7 100644 --- a/python/tvm/topi/arm_cpu/matmul.py +++ b/python/tvm/topi/arm_cpu/matmul.py @@ -53,19 +53,16 @@ def compute_matmul_sme(cfg, data_a, data_b, _, out_dtype, transpose_a=False, tra tile_k *= 2 tile_n = 2 * tvm.tir.get_vscale_expr(data_a.dtype) - M_padded, pad_M = pad_dim_to_multiple(M, tile_m) - _, pad_K = pad_dim_to_multiple(K, tile_k) - N_padded, pad_N = pad_dim_to_multiple(N, tile_n) - - m_pad_after = (pad_M, pad_K) - n_pad_after = (pad_K, pad_N) - if transpose_b: - n_pad_after = (pad_N, pad_K) - - if pad_M != 0: - data_a = nn.pad(data_a, pad_before=(0, 0), pad_after=m_pad_after) - if pad_N != 0: - data_b = nn.pad(data_b, pad_before=(0, 0), pad_after=n_pad_after) + if data_a.dtype == "float16": + _, pad_M = pad_dim_to_multiple(M, tile_m) + _, pad_K = pad_dim_to_multiple(K, tile_k) + _, pad_N = pad_dim_to_multiple(N, tile_n) + m_pad_after = (pad_M, pad_K) + n_pad_after = (pad_N, pad_K) if transpose_b else (pad_K, pad_N) + if pad_M != 0: + data_a = nn.pad(data_a, pad_before=(0, 0), pad_after=m_pad_after) + if pad_N != 0: + data_b = nn.pad(data_b, pad_before=(0, 0), pad_after=n_pad_after) if out_dtype is None: out_dtype = data_a.dtype @@ -87,28 +84,12 @@ def compute(*indices): (False, False): "T_matmul_NN", }[(transpose_a, transpose_b)] - C = te.compute( - (M_padded, N_padded), + return te.compute( + (M, N), compute, name=compute_name, attrs={"schedule_type": "sme"}, ) - return te.compute((M, N), lambda m, n: C[m, n]) - - -def _get_transpose_interleave_intrin_name(in_dtype, out_dtype): - # pylint: disable=import-outside-toplevel - from tvm.tir.tensor_intrin.arm_cpu import ( - ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE, - ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE, - ) - - if in_dtype == "float32" and out_dtype == "float32": - return ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE - elif in_dtype == "float16" and out_dtype == "float32": - return ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE - else: - raise ValueError("Input/output data type combination not supported.") def tir_schedule_matmul_sme(sch): @@ -120,6 +101,7 @@ def tir_schedule_matmul_sme(sch): ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA, ARM_SME_INIT, get_sme_gemm_interleaved_mopa_2svlx2svl_intrin, + get_transpose_interleave_intrin_name, ) main_func = sch.mod["main"] @@ -157,9 +139,9 @@ def tir_schedule_matmul_sme(sch): outer_m, inner_m = sch.split(m, factors=(None, tile_m), disable_predication=True) outer_k, inner_k = sch.split(k, factors=(None, tile_k), disable_predication=True) sch.reorder(outer_k, outer_m, inner_k, inner_m) - - transpose_interleave_intrin_name = _get_transpose_interleave_intrin_name(in_dtype, out_dtype) - sch.tensorize(inner_k, transpose_interleave_intrin_name) + sch.tensorize( + inner_k, get_transpose_interleave_intrin_name(in_dtype, out_dtype, extent_m, extent_k) + ) # Interleave the weights utilizing the matrix tile if transpose_b: @@ -169,7 +151,9 @@ def tir_schedule_matmul_sme(sch): outer_k, inner_k = sch.split(k, factors=(None, tile_k), disable_predication=True) outer_n, inner_n = sch.split(n, factors=(None, tile_n), disable_predication=True) sch.reorder(outer_k, outer_n, inner_k, inner_n) - sch.tensorize(inner_k, transpose_interleave_intrin_name) + sch.tensorize( + inner_k, get_transpose_interleave_intrin_name(in_dtype, out_dtype, extent_k, extent_n) + ) # Split and reorder the loops of the GeMM for tensorization tile_m = T.cast(2 * tvm.tir.get_vscale_expr(out_dtype), extent_m.dtype) @@ -185,11 +169,11 @@ def tir_schedule_matmul_sme(sch): # Tensorize the GeMM update sme_gemm_interleaved_intrin_name = ( - ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{extent_k}_{in_dtype}" + ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{extent_m}_{extent_k}_{in_dtype}" ) tvm.tir.TensorIntrin.register( sme_gemm_interleaved_intrin_name, - *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(extent_k, in_dtype), + *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(extent_m, extent_k, in_dtype), override=True, ) sch.tensorize(inner_m, sme_gemm_interleaved_intrin_name) diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 9b0408b949a0..f596549a10d0 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -530,12 +530,14 @@ def check_correct_assembly(dtype): ) stores = re.findall(r"st1[whdb]\t{\s?za", assembly) smstop = re.findall(r"smstop\t(sm|za)", assembly) + whilelo = re.findall(r"whilelo\tp[0-9].[shdb]", assembly) assert len(smstart) > 0 assert len(loads) > 0 assert len(mopa) > 0 assert len(stores) > 0 assert len(smstop) > 0 + assert len(whilelo) > 0 check_correct_assembly(dtype=dtype) @@ -819,12 +821,14 @@ def check_correct_assembly(dtype): ) stores = re.findall(r"st1[whdb]\t{\s?za", assembly) smstop = re.findall(r"smstop\t(sm|za)", assembly) + whilelo = re.findall(r"whilelo\tp[0-9].[shdb]", assembly) assert len(smstart) > 0 assert len(loads) > 0 assert len(mopa) > 0 assert len(stores) > 0 assert len(smstop) > 0 + assert len(whilelo) > 0 with tvm.target.Target(target): check_correct_assembly(dtype=dtype) diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index d46db1b28b37..e7009ed179f5 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -168,10 +168,16 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation): target = tvm.target.Target(target_string) if target.features.has_sve and llvm_version_major() < 15: - pytest.skip(f"LLVM {llvm_version_major()} does not support targetting SVE.") + pytest.skip(f"LLVM {llvm_version_major()} does not support targeting SVE.") if target.features.has_sme and llvm_version_major() < 16: - pytest.skip(f"LLVM {llvm_version_major()} does not support targetting SME.") + pytest.skip(f"LLVM {llvm_version_major()} does not support targeting SME.") + + if target.features.has_sme and a_np.shape[0] > 1: + pytest.skip(f"Conv2d with batches > 1 targeting SME not implemented.") + + if target.features.has_sme and (a_np.shape[3] * w_np.shape[0] * w_np.shape[1]) <= 1: + pytest.skip(f"Conv2d with unit reduction dimension targeting SME not supported.") # SME schedule always outputs float32 results, regardless of input dtype. # Otherwise, output dtype is the same as input dtype.