From 6cd89ac481ebb1eb2b1ce9692987d942c3b0cab2 Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Thu, 3 Aug 2023 14:30:09 +0000 Subject: [PATCH] [Bugfix][TOPI] Fix a bug in arm_cpu int8 conv2d i8mm schedule MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved` was failing compilation with the `+i8mm` extension enabled whenever the output height and output width were both equal to 1, such that OH x OW = 1. Padding was being removed during the `tir.BufferShapeLegalize` pass, causing an error in the `tir.BufferBindUnwrapper` pass. Some of the removed padding was necessary for tensorize (using the `gemm_acc_2x2_int8_int8_int32` intrinsic), which expects 2x2 output tiles. However, because of the optimisations mentioned above, the output tensor `C_interleaved` was reduced to having 1x2 tiles instead. e.g. for A = [1x1x1x8], W = [1x1x8x24], C = [1x1x1x24]: - Before fix: `C_interleaved = T.Buffer((1, 1, 2, 1, 6, 1, 2), "int32”)` - After fix: `C_interleaved = T.Buffer((1, 1, 2, 1, 6, 2, 2), "int32”)` To make sure the required padding is left untouched, while the rest of it is still removed, a dummy reference to the needed axis is declared. Finally, the leftover padding is still disregarded when computing the final output tensor `C`. --- python/tvm/topi/arm_cpu/conv2d_gemm.py | 45 +++++++++++++++---- .../topi/python/test_topi_conv2d_int8.py | 11 +++-- 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index 04748a4d81fb..ea9026688eec 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -211,18 +211,45 @@ def compute_conv2d_gemm_without_weight_transform( ), name="C_interleaved", ) + # Ensure the padding needed for tensorize does not get removed during tir passes + # by adding a dummy reference to the specific padded area of the result + zero = ( + tvm.tir.const(1, C_interleaved.dtype) + * C_interleaved[ + batches - 1, + M // tile_rows_A, + N_transformed - 1, + idxm(M, tile_rows_A) // 2, + tile_rows_B // 2 - 1, + 1, + 1, + ] + - tvm.tir.const(1, C_interleaved.dtype) + * C_interleaved[ + batches - 1, + M // tile_rows_A, + N_transformed - 1, + idxm(M, tile_rows_A) // 2, + tile_rows_B // 2 - 1, + 1, + 1, + ] + ) # Unpack the result C = te.compute( (batches, M, N), - lambda b, x, y: C_interleaved[ - b, - x // tile_rows_A, - y // tile_rows_B, - idxm(x, tile_rows_A) // 2, - idxm(y, tile_rows_B) // 2, - idxm(idxm(x, tile_rows_A), 2), - idxm(idxm(y, tile_rows_B), 2), - ].astype(out_dtype), + lambda b, x, y: ( + C_interleaved[ + b, + x // tile_rows_A, + y // tile_rows_B, + idxm(x, tile_rows_A) // 2, + idxm(y, tile_rows_B) // 2, + idxm(idxm(x, tile_rows_A), 2), + idxm(idxm(y, tile_rows_B), 2), + ] + + zero + ).astype(out_dtype), name="C", ) else: diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py index e05dba3dfee4..fd101fb79768 100644 --- a/tests/python/topi/python/test_topi_conv2d_int8.py +++ b/tests/python/topi/python/test_topi_conv2d_int8.py @@ -57,12 +57,11 @@ topi.arm_cpu.compute_conv2d_NHWC_quantized_native, topi.arm_cpu.schedule_conv2d_NHWC_quantized_native, ), - # TODO(giuseros) We need LLVM-11 in order to compile with +i8mm extension - # ( - # "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm", - # topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved, - # topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved, - # ), + ( + "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm", + topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved, + topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved, + ), ]