diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index 04748a4d81fb..ea9026688eec 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -211,18 +211,45 @@ def compute_conv2d_gemm_without_weight_transform( ), name="C_interleaved", ) + # Ensure the padding needed for tensorize does not get removed during tir passes + # by adding a dummy reference to the specific padded area of the result + zero = ( + tvm.tir.const(1, C_interleaved.dtype) + * C_interleaved[ + batches - 1, + M // tile_rows_A, + N_transformed - 1, + idxm(M, tile_rows_A) // 2, + tile_rows_B // 2 - 1, + 1, + 1, + ] + - tvm.tir.const(1, C_interleaved.dtype) + * C_interleaved[ + batches - 1, + M // tile_rows_A, + N_transformed - 1, + idxm(M, tile_rows_A) // 2, + tile_rows_B // 2 - 1, + 1, + 1, + ] + ) # Unpack the result C = te.compute( (batches, M, N), - lambda b, x, y: C_interleaved[ - b, - x // tile_rows_A, - y // tile_rows_B, - idxm(x, tile_rows_A) // 2, - idxm(y, tile_rows_B) // 2, - idxm(idxm(x, tile_rows_A), 2), - idxm(idxm(y, tile_rows_B), 2), - ].astype(out_dtype), + lambda b, x, y: ( + C_interleaved[ + b, + x // tile_rows_A, + y // tile_rows_B, + idxm(x, tile_rows_A) // 2, + idxm(y, tile_rows_B) // 2, + idxm(idxm(x, tile_rows_A), 2), + idxm(idxm(y, tile_rows_B), 2), + ] + + zero + ).astype(out_dtype), name="C", ) else: diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py index e05dba3dfee4..fd101fb79768 100644 --- a/tests/python/topi/python/test_topi_conv2d_int8.py +++ b/tests/python/topi/python/test_topi_conv2d_int8.py @@ -57,12 +57,11 @@ topi.arm_cpu.compute_conv2d_NHWC_quantized_native, topi.arm_cpu.schedule_conv2d_NHWC_quantized_native, ), - # TODO(giuseros) We need LLVM-11 in order to compile with +i8mm extension - # ( - # "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm", - # topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved, - # topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved, - # ), + ( + "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm", + topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved, + topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved, + ), ]