diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index e9c57eb78e26..ad8132521d47 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -258,6 +258,60 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter {
     return Parent::VisitExpr_(op);
   }
 
+  /*!
+   * \brief Define VisitExpr_(const OP*) so that the operands are re-visited when their
+   *        dtypes disagree after the first visit.
+   *
+   * Both operands are visited first. If neither changed and their dtypes agree, the
+   * original node is returned. If the visited operands disagree on dtype, is_enabled_ is
+   * set to true, both operands are visited again, the previous flag value is restored and
+   * the expression is rebuilt with FUNC from the second visit. Otherwise it is rebuilt
+   * with FUNC from the first visit.
+   *
+   * NOTE(review): is_enabled_ is declared outside this hunk; presumably it turns on dtype
+   * rewriting in the base class -- confirm.
+   *
+   * \param OP The binary expression node type, e.g. AddNode.
+   * \param FUNC The function or operator used to rebuild the expression from two operands.
+   */
+#define TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)               \
+  PrimExpr VisitExpr_(const OP* op) {                                       \
+    PrimExpr a = this->VisitExpr(op->a);                                    \
+    PrimExpr b = this->VisitExpr(op->b);                                    \
+    if (op->a.same_as(a) && op->b.same_as(b) && a.dtype() == b.dtype()) {   \
+      return GetRef<PrimExpr>(op);                                          \
+    } else {                                                                \
+      if (a.dtype() != b.dtype()) {                                         \
+        bool is_enabled = is_enabled_;                                      \
+        is_enabled_ = true;                                                 \
+        PrimExpr lhs = this->VisitExpr(op->a);                              \
+        PrimExpr rhs = this->VisitExpr(op->b);                              \
+        is_enabled_ = is_enabled;                                           \
+        return FUNC(lhs, rhs);                                              \
+      } else {                                                              \
+        return FUNC(a, b);                                                  \
+      }                                                                     \
+    }                                                                       \
+  }
+
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<);  // NOLINT(*)
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>);  // NOLINT(*)
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=);
+
+#undef TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH
+
  private:
   // the internal visitor to deduce the narrowed dtype
   DataTypeVisitor visitor_;
diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py
index 56b63c889335..b3b0c6f59b0f 100644
--- a/tests/python/unittest/test_tir_transform_narrow_datatype.py
+++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py
@@ -346,5 +346,69 @@ def expected_after(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32"
     tvm.ir.assert_structural_equal(after, expected_after)
 
 
+def test_avg_pool2d():
+    # Run NarrowDataType(32) and then Simplify on `before`, whose loop variables, index
+    # arithmetic and divisor are int64, and require the result to be structurally equal to
+    # `expected_after`, which uses int32 only and has neither the casts nor the outer max with 1.
+    @T.prim_func
+    def before(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,), "int32")):
+        for j in T.parallel(T.int64(0), T.int64(280)):
+            for i in T.serial(T.int64(0), T.int64(35)):
+                for vi in T.vectorized(T.int64(0), T.int64(32)):
+                    PAVG[(((j * T.int64(1120)) + (i * T.int64(32))) + vi)] = T.cast(
+                        T.Div(
+                            T.cast(PSUM[(((j * T.int64(1120)) + (i * T.int64(32))) + vi)], "int64"),
+                            T.max(
+                                (
+                                    (
+                                        (
+                                            T.min(
+                                                T.int64(1),
+                                                (T.int64(34) - T.floormod(j, T.int64(35))),
+                                            )
+                                            + T.int64(2)
+                                        )
+                                        - T.max(
+                                            (T.int64(1) - T.floormod(j, T.int64(35))), T.int64(0)
+                                        )
+                                    )
+                                    * (
+                                        (T.min(T.int64(1), (T.int64(34) - i)) + T.int64(2))
+                                        - T.max((T.int64(1) - i), T.int64(0))
+                                    )
+                                ),
+                                T.int64(1),
+                            ),
+                        ),
+                        "int32",
+                    )
+
+    @T.prim_func
+    def expected_after(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,), "int32")):
+        for j in T.parallel(T.int32(0), T.int32(280)):
+            for i in T.serial(T.int32(0), T.int32(35)):
+                for vi in T.vectorized(T.int32(0), T.int32(32)):
+                    PAVG[(((j * T.int32(1120)) + (i * T.int32(32))) + vi)] = T.Div(
+                        PSUM[(((j * T.int32(1120)) + (i * T.int32(32))) + vi)],
+                        (
+                            (
+                                (
+                                    T.min(T.int32(1), (T.int32(34) - T.floormod(j, T.int32(35))))
+                                    + T.int32(2)
+                                )
+                                - T.max((T.int32(1) - T.floormod(j, T.int32(35))), T.int32(0))
+                            )
+                            * (
+                                (T.min(T.int32(1), (T.int32(34) - i)) + T.int32(2))
+                                - T.max((T.int32(1) - i), T.int32(0))
+                            )
+                        ),
+                    )
+
+    after = tvm.tir.transform.NarrowDataType(32)(tvm.IRModule.from_expr(before))
+    after = tvm.tir.transform.Simplify()(after)
+    tvm.ir.assert_structural_equal(after["main"], expected_after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()