From 4e89183598ff3136686041b90bafff58f5e14b17 Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Tue, 6 Jan 2026 17:08:51 +0700 Subject: [PATCH] [Relax][Op] Fixed incorrect output shape of Pool op when ceil_mode = true - Skip the last window as it would start in the bottom padded region --- include/tvm/topi/nn/pooling.h | 14 +++++++++++-- src/relax/op/nn/pooling.cc | 39 +++++++++++++++++++++++++++++------ 2 files changed, 45 insertions(+), 8 deletions(-) diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index b977a54a5920..3caf7bf1f7d2 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -563,8 +563,18 @@ inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array& kernel_s PrimExpr numerator = data_shape[ii] - (kernel[i] - 1) * dilation[i] - 1 + pad_head[i] + pad_tail[i]; - auto out_dim = analyzer.Simplify(indexdiv(numerator, stride[i]) + 1); - out_shape.Set(ii, out_dim); + auto raw_out = indexdiv(numerator, stride[i]) + 1; + if (ceil_mode) { + // In the case of ceil_mode=True, we need to check if the last pooling window is valid. + // If not, we skip the last window as it would start in the bottom padded region, + // we need to minus 1 to get the correct output shape. + auto invalid_last = (raw_out - 1) * stride[i] >= data_shape[ii] + pad_head[i]; + auto out_dim = analyzer.Simplify(if_then_else(invalid_last, raw_out - 1, raw_out)); + out_shape.Set(ii, out_dim); + } else { + auto out_dim = analyzer.Simplify(raw_out); + out_shape.Set(ii, out_dim); + } } ffi::Map attrs; diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 584135520000..1a19872c2794 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -111,7 +111,13 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { if (attrs->ceil_mode) { numerator_w += attrs->strides[0] - 1; } - out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1); + PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[0]) + 1; + if (attrs->ceil_mode) { + PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[0] >= input_w + attrs->padding[0]; + out_NCW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); + } else { + out_NCW_shape[2] = analyzer->Simplify(raw_out_w); + } ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); @@ -232,8 +238,17 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { numerator_h += attrs->strides[0] - 1; numerator_w += attrs->strides[1] - 1; } - out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1); - out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1); + PrimExpr raw_out_h = floordiv(numerator_h, attrs->strides[0]) + 1; + PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[1]) + 1; + if (attrs->ceil_mode) { + PrimExpr invalid_last_h = (raw_out_h - 1) * attrs->strides[0] >= input_h + attrs->padding[0]; + PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[1] >= input_w + attrs->padding[1]; + out_NCHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h)); + out_NCHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); + } else { + out_NCHW_shape[2] = analyzer->Simplify(raw_out_h); + out_NCHW_shape[3] = analyzer->Simplify(raw_out_w); + } ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); @@ -380,9 +395,21 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { numerator_h += attrs->strides[1] - 1; numerator_w += attrs->strides[2] - 1; } - out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d, attrs->strides[0]) + 1); - out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[1]) + 1); - out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[2]) + 1); + PrimExpr raw_out_d = floordiv(numerator_d, attrs->strides[0]) + 1; + PrimExpr raw_out_h = floordiv(numerator_h, attrs->strides[1]) + 1; + PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[2]) + 1; + if (attrs->ceil_mode) { + PrimExpr invalid_last_d = (raw_out_d - 1) * attrs->strides[0] >= input_d + attrs->padding[0]; + PrimExpr invalid_last_h = (raw_out_h - 1) * attrs->strides[1] >= input_h + attrs->padding[1]; + PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[2] >= input_w + attrs->padding[2]; + out_NCDHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_d, raw_out_d - 1, raw_out_d)); + out_NCDHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h)); + out_NCDHW_shape[4] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); + } else { + out_NCDHW_shape[2] = analyzer->Simplify(raw_out_d); + out_NCDHW_shape[3] = analyzer->Simplify(raw_out_h); + out_NCDHW_shape[4] = analyzer->Simplify(raw_out_w); + } ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);