From 5ad79dde80b39a11ec7b6af53dd7deaa720e0a33 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Tue, 14 Mar 2023 15:38:06 +0300 Subject: [PATCH] [TIR][Hexagon] Enhancement of NarrowDataType pass for binary ops This is enhancement of PR#13327. Motivation: Playing with MetaScheduler for Hexagon target it was found that avg_pool2d has rather poor performance due to lack of vectorized code. IndexDataTypeNormalizer pass converts all indices to int64 format and NarrowDataTypeRewriter should do the opposite (back to int32). In case of fail, we have a lot of int64 arithmetic for average pooling that can not be vectorized. What was done: Added support of binary ops ("div", "max", "min", "+" etc.) in NarrowDataTypeRewriter. In case of different bitwidth of operands in binary opeation it does downcasting instead of upcasting (as it was before). Performance impact: avg_pool2d from quantized InceptionV3 with the shape [1, 8, 35, 35, 32] (NCHW32c layout) tuned with MetaScheduler on Snapdragon 8gen1: shape | Before fix, ms | After fix, ms | speedup | ------------------|----------------|---------------|-------------| avg_pool2d, int32 | 6.67 | 4.41 | +34% | -----------------------------------------------------------------| --- src/tir/transforms/narrow_datatype.cc | 38 ++++++++++++ .../test_tir_transform_narrow_datatype.py | 61 +++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index e9c57eb78e26..ad8132521d47 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -258,6 +258,44 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter { return Parent::VisitExpr_(op); } +#define TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr VisitExpr_(const OP* op) { \ + PrimExpr a = this->VisitExpr(op->a); \ + PrimExpr b = this->VisitExpr(op->b); \ + if (op->a.same_as(a) && op->b.same_as(b) && a.dtype() == b.dtype()) { \ + return GetRef(op); \ + } else { \ + if (a.dtype() != b.dtype()) { \ + bool is_enabled = is_enabled_; \ + is_enabled_ = true; \ + PrimExpr lhs = this->VisitExpr(op->a); \ + PrimExpr rhs = this->VisitExpr(op->b); \ + is_enabled_ = is_enabled; \ + return FUNC(lhs, rhs); \ + } else { \ + return FUNC(a, b); \ + } \ + } \ + } + + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+); + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-); + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*); + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div); + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod); + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv); + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod); + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min); + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max); + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*) + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*) + TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); + +#undef TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH + private: // the internal visitor to deduce the narrowed dtype DataTypeVisitor visitor_; diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index 56b63c889335..b3b0c6f59b0f 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -346,5 +346,66 @@ def expected_after(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32" tvm.ir.assert_structural_equal(after, expected_after) +def test_avg_pool2d(): + @T.prim_func + def before(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,), "int32")): + for j in T.parallel(T.int64(0), T.int64(280)): + for i in T.serial(T.int64(0), T.int64(35)): + for vi in T.vectorized(T.int64(0), T.int64(32)): + PAVG[(((j * T.int64(1120)) + (i * T.int64(32))) + vi)] = T.cast( + T.Div( + T.cast(PSUM[(((j * T.int64(1120)) + (i * T.int64(32))) + vi)], "int64"), + T.max( + ( + ( + ( + T.min( + T.int64(1), + (T.int64(34) - T.floormod(j, T.int64(35))), + ) + + T.int64(2) + ) + - T.max( + (T.int64(1) - T.floormod(j, T.int64(35))), T.int64(0) + ) + ) + * ( + (T.min(T.int64(1), (T.int64(34) - i)) + T.int64(2)) + - T.max((T.int64(1) - i), T.int64(0)) + ) + ), + T.int64(1), + ), + ), + "int32", + ) + + @T.prim_func + def expected_after(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,), "int32")): + for j in T.parallel(T.int32(0), T.int32(280)): + for i in T.serial(T.int32(0), T.int32(35)): + for vi in T.vectorized(T.int32(0), T.int32(32)): + PAVG[(((j * T.int32(1120)) + (i * T.int32(32))) + vi)] = T.Div( + PSUM[(((j * T.int32(1120)) + (i * T.int32(32))) + vi)], + ( + ( + ( + T.min(T.int32(1), (T.int32(34) - T.floormod(j, T.int32(35)))) + + T.int32(2) + ) + - T.max((T.int32(1) - T.floormod(j, T.int32(35))), T.int32(0)) + ) + * ( + (T.min(T.int32(1), (T.int32(34) - i)) + T.int32(2)) + - T.max((T.int32(1) - i), T.int32(0)) + ) + ), + ) + + after = tvm.tir.transform.NarrowDataType(32)(tvm.IRModule.from_expr(before)) + after = tvm.tir.transform.Simplify()(after) + tvm.ir.assert_structural_equal(after["main"], expected_after) + + if __name__ == "__main__": tvm.testing.main()