From b04855053832ace9c315846a680a54474142d85c Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Wed, 25 May 2022 10:42:00 +0000 Subject: [PATCH 01/14] enable oneDNN conv op by using -libs=mkldnn --- cmake/modules/contrib/BLAS.cmake | 2 + python/tvm/contrib/mkldnn.py | 91 ++++++++++++++++++++++++++ python/tvm/relay/op/strategy/x86.py | 6 ++ python/tvm/topi/x86/conv2d.py | 15 ++++- src/runtime/contrib/dnnl/dnnl.cc | 18 +++++ src/runtime/contrib/dnnl/dnnl_kernel.h | 1 + tests/python/relay/test_op_level2.py | 41 ++++++++++++ 7 files changed, 173 insertions(+), 1 deletion(-) diff --git a/cmake/modules/contrib/BLAS.cmake b/cmake/modules/contrib/BLAS.cmake index 06c8755882d5..f31218088a9e 100644 --- a/cmake/modules/contrib/BLAS.cmake +++ b/cmake/modules/contrib/BLAS.cmake @@ -72,6 +72,7 @@ if(IS_DIRECTORY ${USE_MKLDNN}) include_directories(SYSTEM ${USE_MKLDNN}/include) list(APPEND TVM_RUNTIME_LINKER_LIBS ${MKLDNN_LIBRARY}) list(APPEND RUNTIME_SRCS src/runtime/contrib/cblas/mkldnn.cc) + list(APPEND RUNTIME_SRCS src/runtime/contrib/dnnl/dnnl.cc) add_definitions(-DUSE_DNNL=1) message(STATUS "Use MKLDNN library " ${MKLDNN_LIBRARY}) endif() @@ -84,6 +85,7 @@ elseif(USE_MKLDNN STREQUAL "ON") add_definitions(-DUSE_DNNL=1) message(STATUS "Use MKLDNN library " ${MKLDNN_LIBRARY}) list(APPEND RUNTIME_SRCS src/runtime/contrib/cblas/mkldnn.cc) + list(APPEND RUNTIME_SRCS src/runtime/contrib/dnnl/dnnl.cc) endif() elseif(USE_MKLDNN STREQUAL "OFF") # pass diff --git a/python/tvm/contrib/mkldnn.py b/python/tvm/contrib/mkldnn.py index 8d5f4da0345b..4a0b1e4abcd7 100644 --- a/python/tvm/contrib/mkldnn.py +++ b/python/tvm/contrib/mkldnn.py @@ -17,6 +17,7 @@ """External function interface to BLAS libraries.""" import tvm from tvm import te +from ..topi.nn.utils import get_pad_tuple def matmul(lhs, rhs, transa=False, transb=False, **kwargs): @@ -50,3 +51,93 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs): name="C", **kwargs, ) + +def dnnl_conv2d( + Input, + Filter, + stride, + 
padding, + dilation, + groups, + out_dtype="float32", + **kwargs +): + """Convolution operator in NCHW layout. + + Parameters + ---------- + Input : tvm.te.Tensor + 4-D with shape [batch, in_channel, in_height, in_width] + + Filter : tvm.te.Tensor + 4-D with shape [num_filter, in_channel, filter_height, filter_width] + + stride : int or a list/tuple of two ints + Stride size, or [stride_height, stride_width] + + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + groups: str + input data layout: NCHW or NHWC + + out_dtype: str + output datatype: now only support float32 + + Returns + ------- + Output : tvm.te.Tensor + 4-D with shape [batch, out_channel, out_height, out_width] + """ + + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, _, in_height, in_width = Input.shape + num_filter, _, kernel_h, kernel_w = Filter.shape + + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + out_channel = num_filter + out_height = ((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = ((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + + out_shape = (batch, out_channel, out_height, out_width) + + return te.extern( + out_shape, + [Input, Filter], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.mkldnn.conv2d", + ins[0], + ins[1], + outs[0], + pad_top, + pad_down, + pad_left, + 
pad_right, + stride[0], + stride[1], + groups, + ), + name="C", + dtype=out_dtype, + **kwargs, + ) \ No newline at end of file diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 0beb99e4f7db..a785568f0ad9 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -127,6 +127,12 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): wrap_topi_schedule(topi.x86.schedule_conv2d_nchw_int8), name="conv2d_nchw_int8.x86", ) + elif "mkldnn" in target.libs: + strategy.add_implementation( + wrap_compute_conv2d(topi.x86.conv2d_nchw_mkldnn), + wrap_topi_schedule(topi.x86.schedule_conv2d_nchw_mkldnn), + name="conv2d_nchw_mkldnn.x86", + ) else: strategy.add_implementation( wrap_compute_conv2d(topi.x86.conv2d_nchw), diff --git a/python/tvm/topi/x86/conv2d.py b/python/tvm/topi/x86/conv2d.py index 182454acf3a6..f9e957c25725 100644 --- a/python/tvm/topi/x86/conv2d.py +++ b/python/tvm/topi/x86/conv2d.py @@ -23,7 +23,8 @@ import tvm from tvm import te from tvm import autotvm -from .. import nn +from tvm.contrib import mkldnn +from .. 
import nn, generic from ..nn.conv2d import conv2d_infer_layout, _get_workload as _get_conv2d_workload from ..nn.conv2d import unpack_NCHWc_to_nchw from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload @@ -266,6 +267,18 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s +@autotvm.register_topi_compute("conv2d_nchw_mkldnn.x86") +def conv2d_nchw_mkldnn(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d in NCHW format using mkldnn.""" + groups=1 + _out = mkldnn.dnnl_conv2d(data, kernel, strides, padding, dilation, groups, out_dtype) + return _out + +@autotvm.register_topi_schedule("conv2d_nchw_mkldnn.x86") +def schedule_conv2d_nchw_mkldnn(_, outs): + """Create schedule for conv2d_nchw_mkldnn""" + return generic.schedule_extern(outs) + # FIXME - https://github.com/apache/tvm/issues/4122 # _declaration_conv_nhwc_pack expects kernel layout to be HWOI. However, the tests use HWIO diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index d1190df91375..b10a03d13645 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -306,6 +306,24 @@ extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_ read_from_dnnl_memory(out, dst_memory); } +// DNNL Conv2d single OP +TVM_REGISTER_GLOBAL("tvm.contrib.mkldnn.conv2d").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* input = args[0]; + DLTensor* weights = args[1]; + DLTensor* output = args[2]; + int PH_L = args[3], + PW_L = args[4], + PH_R = args[5], + PW_R = args[6], + SH = args[7], + SW = args[8], + G = args[9]; + + dnnl_conv2d(static_cast(input->data), static_cast(weights->data), static_cast(output->data), + input->shape[0], input->shape[1], input->shape[2], input->shape[3], output->shape[1], G, + PH_L, PW_L , PH_R, PW_R , weights->shape[2], weights->shape[3], SH, SW); +}); + } // namespace contrib } // namespace runtime } // namespace tvm diff --git 
a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index 522313ae5a64..04e06d9c9e94 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -27,6 +27,7 @@ #include #include +#include #include diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index c644890bbcbe..b5c90f6d239a 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1993,6 +1993,47 @@ def test_conv2d_rocm_sdot4(): np.testing.assert_equal(out, ref) +@tvm.testing.requires_x86 +def test_conv2d_mkldnn(): + d_shape = (1, 64, 56, 56) + w_shape = (64, 64, 3, 3) + padding = (1, 1) + strides = (1, 1) + + data = relay.var("data", shape=d_shape, dtype="float32") + weight = relay.var("weight", shape=w_shape, dtype="float32") + out_channel = w_shape[0] + conv2d = relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[2:], + channels=out_channel, + padding=padding, + strides=strides, + out_dtype="float32", + ) + + mod = tvm.IRModule.from_expr(conv2d) + + data_np = np.random.uniform(1, 10, d_shape).astype("float32") + weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32") + + target = "llvm -mcpu=skylake-avx512 -libs=mkldnn" + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params={"weight": weight_np}) + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + runtime.set_input("data", data_np) + runtime.run() + + out = runtime.get_output(0).numpy() + + ref = tvm.topi.testing.conv2d_nchw_python(data_np, weight_np, strides, padding) + + np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) + if __name__ == "__main__": sys.exit(pytest.main(sys.argv)) From 7164030c7b19351372a4d973cd25f33b58b4c208 Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Wed, 1 Jun 2022 00:01:27 +0000 Subject: [PATCH 02/14] add channel last format support and let oneDNN 
choose blocked format. --- python/tvm/contrib/mkldnn.py | 23 +++++-- python/tvm/relay/op/strategy/x86.py | 17 ++++-- python/tvm/topi/x86/conv2d.py | 18 +++++- src/runtime/contrib/dnnl/dnnl.cc | 89 +++++++++++++++++++++------- tests/python/relay/test_op_level2.py | 45 +++++++++++++- 5 files changed, 159 insertions(+), 33 deletions(-) diff --git a/python/tvm/contrib/mkldnn.py b/python/tvm/contrib/mkldnn.py index 4a0b1e4abcd7..7dd16e6030b8 100644 --- a/python/tvm/contrib/mkldnn.py +++ b/python/tvm/contrib/mkldnn.py @@ -59,6 +59,7 @@ def dnnl_conv2d( padding, dilation, groups, + channel_last=False, out_dtype="float32", **kwargs ): @@ -85,10 +86,14 @@ def dnnl_conv2d( groups: str input data layout: NCHW or NHWC - + + channel_last: bool + chose if input/output data format is in channel_last format(NHWC) or + in plain format(NCHW) + out_dtype: str output datatype: now only support float32 - + Returns ------- Output : tvm.te.Tensor @@ -107,8 +112,12 @@ def dnnl_conv2d( else: dilation_h, dilation_w = dilation - batch, _, in_height, in_width = Input.shape - num_filter, _, kernel_h, kernel_w = Filter.shape + if channel_last: + batch, in_height, in_width, _ = Input.shape + kernel_h, kernel_w, _, num_filter = Filter.shape + else: + batch, _, in_height, in_width = Input.shape + num_filter, _, kernel_h, kernel_w = Filter.shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 @@ -119,7 +128,10 @@ def dnnl_conv2d( out_height = ((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) out_width = ((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) - out_shape = (batch, out_channel, out_height, out_width) + if channel_last: + out_shape = (batch, out_height, out_width, out_channel) + else: + out_shape = (batch, out_channel, out_height, out_width) return te.extern( out_shape, @@ -136,6 +148,7 @@ def dnnl_conv2d( stride[0], stride[1], groups, + channel_last, ), name="C", dtype=out_dtype, diff --git
a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index a785568f0ad9..f5028d8f285d 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -146,11 +146,18 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): assert kernel_layout == "HWIO" if not is_auto_scheduler_enabled(): logger.warning("conv2d NHWC layout is not optimized for x86 with autotvm.") - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True), - wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc), - name="conv2d_nhwc.x86", - ) + if "mkldnn" in target.libs: + strategy.add_implementation( + wrap_compute_conv2d(topi.x86.conv2d_nhwc_mkldnn), + wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc_mkldnn), + name="conv2d_nhwc_mkldnn.x86", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True), + wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc), + name="conv2d_nhwc.x86", + ) judge_winograd_auto_scheduler = False if len(kernel.shape) == 4: diff --git a/python/tvm/topi/x86/conv2d.py b/python/tvm/topi/x86/conv2d.py index f9e957c25725..daccfba53b96 100644 --- a/python/tvm/topi/x86/conv2d.py +++ b/python/tvm/topi/x86/conv2d.py @@ -24,13 +24,15 @@ from tvm import te from tvm import autotvm from tvm.contrib import mkldnn -from .. import nn, generic +from .. import nn + from ..nn.conv2d import conv2d_infer_layout, _get_workload as _get_conv2d_workload from ..nn.conv2d import unpack_NCHWc_to_nchw from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.utils import get_pad_tuple from ..utils import get_const_tuple, traverse_inline from . import conv2d_avx_1x1, conv2d_avx_common +from .. 
import generic logger = logging.getLogger("topi") @@ -271,7 +273,7 @@ def _callback(op): def conv2d_nchw_mkldnn(cfg, data, kernel, strides, padding, dilation, out_dtype): """Compute conv2d in NCHW format using mkldnn.""" groups=1 - _out = mkldnn.dnnl_conv2d(data, kernel, strides, padding, dilation, groups, out_dtype) + _out = mkldnn.dnnl_conv2d(data, kernel, strides, padding, dilation, groups, False, out_dtype) return _out @autotvm.register_topi_schedule("conv2d_nchw_mkldnn.x86") @@ -279,6 +281,18 @@ def schedule_conv2d_nchw_mkldnn(_, outs): """Create schedule for conv2d_nchw_mkldnn""" return generic.schedule_extern(outs) +@autotvm.register_topi_compute("conv2d_nhwc_mkldnn.x86") +def conv2d_nhwc_mkldnn(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d in NHWC format using mkldnn.""" + groups=1 + _out = mkldnn.dnnl_conv2d(data, kernel, strides, padding, dilation, groups, True, out_dtype) + return _out + +@autotvm.register_topi_schedule("conv2d_nhwc_mkldnn.x86") +def schedule_conv2d_nhwc_mkldnn(_, outs): + """Create schedule for conv2d_nhwc_mkldnn""" + return generic.schedule_extern(outs) + # FIXME - https://github.com/apache/tvm/issues/4122 # _declaration_conv_nhwc_pack expects kernel layout to be HWOI. 
However, the tests use HWIO diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index b10a03d13645..7784d074999c 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -81,7 +81,7 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) { void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_, int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph0_, int p_Pw0_, int p_Ph1_, int p_Pw1_, int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_, - primitive_attr attr) { + primitive_attr attr, bool channel_last) { using tag = memory::format_tag; using dt = memory::data_type; engine eng(engine::kind::cpu, 0); @@ -97,32 +97,59 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in memory::dims conv2d_padding0 = {p_Ph0_, p_Pw0_}; memory::dims conv2d_padding1 = {p_Ph1_, p_Pw1_}; - auto user_src_memory = memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data); + auto user_src_memory = memory({{conv2d_src_tz}, dt::f32, channel_last? tag::nhwc : tag::nchw}, eng, data); auto user_weights_memory = - memory({{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng, weights); + memory({{conv2d_weights_tz}, dt::f32, channel_last? tag::hwio : tag::oihw}, eng, weights); + if (p_G_ > 1) user_weights_memory = + memory({{conv2d_weights_tz}, dt::f32, channel_last? tag::ghwio : tag::goihw}, eng, weights); auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias); + auto user_dst_memory = memory({{conv2d_dst_tz}, dt::f32, channel_last? 
tag::nhwc : tag::nchw}, eng, out); auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any); auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any); auto conv2d_weights_md = memory::desc({conv2d_weights_tz}, dt::f32, tag::any); - auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw); + auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::any); auto conv2d_desc = convolution_forward::desc( prop_kind::forward_inference, algorithm::convolution_direct, conv2d_src_md, conv2d_weights_md, conv2d_bias_md, conv2d_dst_md, conv2d_strides, conv2d_padding0, conv2d_padding1); auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, attr, eng); + // reorder if src layout not DNNL chosen. auto conv2d_src_memory = user_src_memory; + if (conv2d_prim_desc.src_desc() != user_src_memory.get_desc()) { + conv2d_src_memory = memory(conv2d_prim_desc.src_desc(), eng); + auto reorder_src = reorder(user_src_memory, conv2d_src_memory); + reorder_src.execute(s, {{DNNL_ARG_FROM, user_src_memory}, {DNNL_ARG_TO, conv2d_src_memory}}); + } + + // reorder if weights layout not DNNL chosen. 
auto conv2d_weights_memory = user_weights_memory; - auto conv2d_dst_memory = memory(conv2d_prim_desc.dst_desc(), eng); + if (conv2d_prim_desc.weights_desc() != user_weights_memory.get_desc()) { + conv2d_weights_memory = memory(conv2d_prim_desc.weights_desc(), eng); + auto reorder_weights = reorder(user_weights_memory, conv2d_weights_memory); + reorder_weights.execute( + s, {{DNNL_ARG_FROM, user_weights_memory}, {DNNL_ARG_TO, conv2d_weights_memory}}); + } + + auto conv2d_dst_memory = user_dst_memory; + if (conv2d_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + conv2d_dst_memory = memory(conv2d_prim_desc.dst_desc(), eng); + } auto conv = convolution_forward(conv2d_prim_desc); conv.execute(s, {{DNNL_ARG_SRC, conv2d_src_memory}, {DNNL_ARG_WEIGHTS, conv2d_weights_memory}, {DNNL_ARG_BIAS, conv2d_user_bias_memory}, {DNNL_ARG_DST, conv2d_dst_memory}}); + + // reorder if dst layout not DNNL chosen. + if (conv2d_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + reorder(conv2d_dst_memory, user_dst_memory).execute( + s, {{DNNL_ARG_FROM, conv2d_dst_memory}, {DNNL_ARG_TO, user_dst_memory}}); + } + s.wait(); - read_from_dnnl_memory(out, conv2d_dst_memory); } extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, int p_C_, int p_H_, @@ -131,7 +158,7 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, i primitive_attr attr; std::vector bias(p_O_, 0); return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, - p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr); + p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr, false); } primitive_attr create_attr_with_relu_post_op() { @@ -151,7 +178,7 @@ extern "C" void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, std::vector bias(p_O_, 0); return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, - 
create_attr_with_relu_post_op()); + create_attr_with_relu_post_op(), false); } extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out, @@ -161,7 +188,7 @@ extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* int p_Sw_) { return dnnl_conv2d_common(data, weights, bias, out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, - create_attr_with_relu_post_op()); + create_attr_with_relu_post_op(), false); } extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_) { @@ -311,17 +338,39 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mkldnn.conv2d").set_body([](TVMArgs args, TVMRe DLTensor* input = args[0]; DLTensor* weights = args[1]; DLTensor* output = args[2]; - int PH_L = args[3], - PW_L = args[4], - PH_R = args[5], - PW_R = args[6], - SH = args[7], - SW = args[8], - G = args[9]; - - dnnl_conv2d(static_cast(input->data), static_cast(weights->data), static_cast(output->data), - input->shape[0], input->shape[1], input->shape[2], input->shape[3], output->shape[1], G, - PH_L, PW_L , PH_R, PW_R , weights->shape[2], weights->shape[3], SH, SW); + int p_Ph0_ = args[3], + p_Pw0_ = args[4], + p_Ph1_ = args[5], + p_Pw1_ = args[6], + p_Sh_ = args[7], + p_Sw_ = args[8], + p_G_ = args[9]; + bool channel_last = args[10]; + + int p_N_ = input->shape[0], + p_C_ = input->shape[1], + p_H_ = input->shape[2], + p_W_ = input->shape[3], + p_O_ = output->shape[1], + p_Kh_ = weights->shape[2], + p_Kw_ = weights->shape[3]; + + if (channel_last) { + p_N_ = input->shape[0]; + p_H_ = input->shape[1]; + p_W_ = input->shape[2]; + p_C_ = input->shape[3]; + p_O_ = output->shape[3]; + p_Kh_ = weights->shape[0]; + p_Kw_ = weights->shape[1]; + } + + std::vector bias(p_O_, 0); + primitive_attr attr; + return dnnl_conv2d_common(static_cast(input->data), static_cast(weights->data), bias.data(), + static_cast(output->data), p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, + p_Ph0_, 
p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr, channel_last); + }); } // namespace contrib diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index b5c90f6d239a..ff5a7034f991 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1994,7 +1994,7 @@ def test_conv2d_rocm_sdot4(): np.testing.assert_equal(out, ref) @tvm.testing.requires_x86 -def test_conv2d_mkldnn(): +def test_conv2d_nchw_mkldnn(): d_shape = (1, 64, 56, 56) w_shape = (64, 64, 3, 3) padding = (1, 1) @@ -2034,6 +2034,49 @@ def test_conv2d_mkldnn(): np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) +@tvm.testing.requires_x86 +def test_conv2d_nhwc_mkldnn(): + d_shape = (1, 56, 56, 64) + w_shape = (3, 3, 64, 64) + padding = (1, 1) + strides = (1, 1) + + data = relay.var("data", shape=d_shape, dtype="float32") + weight = relay.var("weight", shape=w_shape, dtype="float32") + out_channel = w_shape[3] + conv2d = relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[:2], + channels=out_channel, + padding=padding, + strides=strides, + out_dtype="float32", + data_layout="NHWC", + kernel_layout="HWIO" + ) + + mod = tvm.IRModule.from_expr(conv2d) + + data_np = np.random.uniform(1, 10, d_shape).astype("float32") + weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32") + + target = "llvm -mcpu=skylake-avx512 -libs=mkldnn" + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params={"weight": weight_np}) + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + runtime.set_input("data", data_np) + runtime.run() + + out = runtime.get_output(0).numpy() + + ref = tvm.topi.testing.conv2d_nhwc_python(data_np, weight_np, strides, padding) + + np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) + if __name__ == "__main__": sys.exit(pytest.main(sys.argv)) From de8159095def5836fb60e450a17081fc8c5e01a8 Mon Sep 17 
00:00:00 2001 From: "Jiang, Qianshui" Date: Wed, 1 Jun 2022 00:18:07 +0000 Subject: [PATCH 03/14] remove unnecessary changes --- python/tvm/topi/x86/conv2d.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/topi/x86/conv2d.py b/python/tvm/topi/x86/conv2d.py index daccfba53b96..17b0aacabb1d 100644 --- a/python/tvm/topi/x86/conv2d.py +++ b/python/tvm/topi/x86/conv2d.py @@ -24,15 +24,13 @@ from tvm import te from tvm import autotvm from tvm.contrib import mkldnn -from .. import nn - +from .. import nn, generic from ..nn.conv2d import conv2d_infer_layout, _get_workload as _get_conv2d_workload from ..nn.conv2d import unpack_NCHWc_to_nchw from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.utils import get_pad_tuple from ..utils import get_const_tuple, traverse_inline from . import conv2d_avx_1x1, conv2d_avx_common -from .. import generic logger = logging.getLogger("topi") From def83cb621193d6e65d677dc204d3027e91c3e44 Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Sat, 4 Jun 2022 17:29:25 +0000 Subject: [PATCH 04/14] reformat 3 files --- python/tvm/contrib/mkldnn.py | 15 ++++++++------- python/tvm/topi/x86/conv2d.py | 8 ++++++-- tests/python/relay/test_op_level2.py | 10 ++++++---- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/python/tvm/contrib/mkldnn.py b/python/tvm/contrib/mkldnn.py index 7dd16e6030b8..3fec666cde65 100644 --- a/python/tvm/contrib/mkldnn.py +++ b/python/tvm/contrib/mkldnn.py @@ -52,6 +52,7 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs): **kwargs, ) + def dnnl_conv2d( Input, Filter, @@ -61,7 +62,7 @@ def dnnl_conv2d( groups, channel_last=False, out_dtype="float32", - **kwargs + **kwargs, ): """Convolution operator in NCHW layout. 
@@ -125,9 +126,9 @@ def dnnl_conv2d( padding, (dilated_kernel_h, dilated_kernel_w) ) out_channel = num_filter - out_height = ((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) - out_width = ((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) - + out_height = (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1 + out_width = (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1 + if channel_last: out_shape = (batch, out_height, out_width, out_channel) else: @@ -137,8 +138,8 @@ def dnnl_conv2d( out_shape, [Input, Filter], lambda ins, outs: tvm.tir.call_packed( - "tvm.contrib.mkldnn.conv2d", - ins[0], + "tvm.contrib.mkldnn.conv2d", + ins[0], ins[1], outs[0], pad_top, @@ -153,4 +154,4 @@ def dnnl_conv2d( name="C", dtype=out_dtype, **kwargs, - ) \ No newline at end of file + ) diff --git a/python/tvm/topi/x86/conv2d.py b/python/tvm/topi/x86/conv2d.py index 17b0aacabb1d..108214707851 100644 --- a/python/tvm/topi/x86/conv2d.py +++ b/python/tvm/topi/x86/conv2d.py @@ -267,25 +267,29 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + @autotvm.register_topi_compute("conv2d_nchw_mkldnn.x86") def conv2d_nchw_mkldnn(cfg, data, kernel, strides, padding, dilation, out_dtype): """Compute conv2d in NCHW format using mkldnn.""" - groups=1 + groups = 1 _out = mkldnn.dnnl_conv2d(data, kernel, strides, padding, dilation, groups, False, out_dtype) return _out + @autotvm.register_topi_schedule("conv2d_nchw_mkldnn.x86") def schedule_conv2d_nchw_mkldnn(_, outs): """Create schedule for conv2d_nchw_mkldnn""" return generic.schedule_extern(outs) + @autotvm.register_topi_compute("conv2d_nhwc_mkldnn.x86") def conv2d_nhwc_mkldnn(cfg, data, kernel, strides, padding, dilation, out_dtype): """Compute conv2d in NHWC format using mkldnn.""" - groups=1 + groups = 1 _out = mkldnn.dnnl_conv2d(data, kernel, strides, padding, dilation, groups, True, out_dtype) return _out + 
@autotvm.register_topi_schedule("conv2d_nhwc_mkldnn.x86") def schedule_conv2d_nhwc_mkldnn(_, outs): """Create schedule for conv2d_nhwc_mkldnn""" diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 367d116bbfe7..3718a6106065 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1995,6 +1995,7 @@ def test_conv2d_rocm_sdot4(): np.testing.assert_equal(out, ref) + @tvm.testing.requires_x86 def test_conv2d_nchw_mkldnn(): d_shape = (1, 64, 56, 56) @@ -2034,7 +2035,8 @@ def test_conv2d_nchw_mkldnn(): ref = tvm.topi.testing.conv2d_nchw_python(data_np, weight_np, strides, padding) - np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) + @tvm.testing.requires_x86 def test_conv2d_nhwc_mkldnn(): @@ -2053,9 +2055,9 @@ def test_conv2d_nhwc_mkldnn(): channels=out_channel, padding=padding, strides=strides, - out_dtype="float32", + out_dtype="float32", data_layout="NHWC", - kernel_layout="HWIO" + kernel_layout="HWIO", ) mod = tvm.IRModule.from_expr(conv2d) @@ -2077,7 +2079,7 @@ def test_conv2d_nhwc_mkldnn(): ref = tvm.topi.testing.conv2d_nhwc_python(data_np, weight_np, strides, padding) - np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) if __name__ == "__main__": From 12fa1e9be0fb3bd0178c301bb306c1d7d85e63c3 Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Sat, 4 Jun 2022 17:39:14 +0000 Subject: [PATCH 05/14] reformat 1 file --- python/tvm/topi/x86/conv2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/x86/conv2d.py b/python/tvm/topi/x86/conv2d.py index 108214707851..e1710ef1eea5 100644 --- a/python/tvm/topi/x86/conv2d.py +++ b/python/tvm/topi/x86/conv2d.py @@ -285,7 +285,7 @@ def schedule_conv2d_nchw_mkldnn(_, outs): @autotvm.register_topi_compute("conv2d_nhwc_mkldnn.x86") def conv2d_nhwc_mkldnn(cfg, data, kernel, 
strides, padding, dilation, out_dtype): """Compute conv2d in NHWC format using mkldnn.""" - groups = 1 + groups = 1 _out = mkldnn.dnnl_conv2d(data, kernel, strides, padding, dilation, groups, True, out_dtype) return _out From d02894b03f8519bb2d74093894a792ebd75c997f Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Sat, 4 Jun 2022 17:51:06 +0000 Subject: [PATCH 06/14] change the argument name --- python/tvm/contrib/mkldnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/mkldnn.py b/python/tvm/contrib/mkldnn.py index 3fec666cde65..36f08c67b469 100644 --- a/python/tvm/contrib/mkldnn.py +++ b/python/tvm/contrib/mkldnn.py @@ -54,8 +54,8 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs): def dnnl_conv2d( - Input, - Filter, + input, + filter, stride, padding, dilation, From 3ee2a659893f755ff57b6427bfcbc9207cfe37f8 Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Sat, 4 Jun 2022 17:54:05 +0000 Subject: [PATCH 07/14] change the argument name --- python/tvm/contrib/mkldnn.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/contrib/mkldnn.py b/python/tvm/contrib/mkldnn.py index 36f08c67b469..ca430933b399 100644 --- a/python/tvm/contrib/mkldnn.py +++ b/python/tvm/contrib/mkldnn.py @@ -114,11 +114,11 @@ def dnnl_conv2d( dilation_h, dilation_w = dilation if channel_last: - batch, in_height, in_width, _ = Input.shape - kernel_h, kernel_w, _, num_filter = Filter.shape + batch, in_height, in_width, _ = input.shape + kernel_h, kernel_w, _, num_filter = filter.shape else: - batch, _, in_height, in_width = Input.shape - num_filter, _, kernel_h, kernel_w = Filter.shape + batch, _, in_height, in_width = input.shape + num_filter, _, kernel_h, kernel_w = filter.shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 @@ -136,7 +136,7 @@ def dnnl_conv2d( return te.extern( out_shape, - [Input, Filter], + [input, filter], lambda ins, 
outs: tvm.tir.call_packed( "tvm.contrib.mkldnn.conv2d", ins[0], From 041955835c1e5f33cd9087f8ccd576e512f5294d Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Sat, 4 Jun 2022 18:17:37 +0000 Subject: [PATCH 08/14] rename the arguments --- python/tvm/contrib/mkldnn.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/tvm/contrib/mkldnn.py b/python/tvm/contrib/mkldnn.py index ca430933b399..a60a35f0ad04 100644 --- a/python/tvm/contrib/mkldnn.py +++ b/python/tvm/contrib/mkldnn.py @@ -54,8 +54,8 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs): def dnnl_conv2d( - input, - filter, + src, + weights, stride, padding, dilation, @@ -68,10 +68,10 @@ def dnnl_conv2d( Parameters ---------- - Input : tvm.te.Tensor + src : tvm.te.Tensor 4-D with shape [batch, in_channel, in_height, in_width] - Filter : tvm.te.Tensor + weights : tvm.te.Tensor 4-D with shape [num_filter, in_channel, filter_height, filter_width] stride : int or a list/tuple of two ints @@ -114,11 +114,11 @@ def dnnl_conv2d( dilation_h, dilation_w = dilation if channel_last: - batch, in_height, in_width, _ = input.shape - kernel_h, kernel_w, _, num_filter = filter.shape + batch, in_height, in_width, _ = src.shape + kernel_h, kernel_w, _, num_filter = weights.shape else: - batch, _, in_height, in_width = input.shape - num_filter, _, kernel_h, kernel_w = filter.shape + batch, _, in_height, in_width = src.shape + num_filter, _, kernel_h, kernel_w = weights.shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 @@ -136,7 +136,7 @@ def dnnl_conv2d( return te.extern( out_shape, - [input, filter], + [src, weights], lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.mkldnn.conv2d", ins[0], From cb9468bef63d77c38e521131acd3feea6ac7df6a Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Sat, 4 Jun 2022 18:38:10 +0000 Subject: [PATCH 09/14] fix cpp lint issue --- src/runtime/contrib/dnnl/dnnl.cc | 21 
++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 7784d074999c..872ebcbeaf17 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -97,12 +97,14 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in memory::dims conv2d_padding0 = {p_Ph0_, p_Pw0_}; memory::dims conv2d_padding1 = {p_Ph1_, p_Pw1_}; - auto user_src_memory = memory({{conv2d_src_tz}, dt::f32, channel_last? tag::nhwc : tag::nchw}, eng, data); + auto user_src_memory = + memory({{conv2d_src_tz}, dt::f32, channel_last? tag::nhwc : tag::nchw}, eng, data); auto user_weights_memory = memory({{conv2d_weights_tz}, dt::f32, channel_last? tag::hwio : tag::oihw}, eng, weights); if (p_G_ > 1) user_weights_memory = memory({{conv2d_weights_tz}, dt::f32, channel_last? tag::ghwio : tag::goihw}, eng, weights); - auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias); + auto conv2d_user_bias_memory = + memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias); auto user_dst_memory = memory({{conv2d_dst_tz}, dt::f32, channel_last? 
tag::nhwc : tag::nchw}, eng, out); auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any); @@ -158,7 +160,8 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, i primitive_attr attr; std::vector bias(p_O_, 0); return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, - p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr, false); + p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr, + false); } primitive_attr create_attr_with_relu_post_op() { @@ -333,9 +336,9 @@ extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_ read_from_dnnl_memory(out, dst_memory); } -// DNNL Conv2d single OP +// DNNL Conv2d single OP TVM_REGISTER_GLOBAL("tvm.contrib.mkldnn.conv2d").set_body([](TVMArgs args, TVMRetValue* ret) { - DLTensor* input = args[0]; + DLTensor* input = args[0]; DLTensor* weights = args[1]; DLTensor* output = args[2]; int p_Ph0_ = args[3], @@ -367,10 +370,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mkldnn.conv2d").set_body([](TVMArgs args, TVMRe std::vector bias(p_O_, 0); primitive_attr attr; - return dnnl_conv2d_common(static_cast(input->data), static_cast(weights->data), bias.data(), - static_cast(output->data), p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, - p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr, channel_last); - + return dnnl_conv2d_common(static_cast(input->data), static_cast(weights->data), + bias.data(), static_cast(output->data), p_N_, p_C_, p_H_, + p_W_, p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, + p_Sh_, p_Sw_, attr, channel_last); }); } // namespace contrib From 8895ec6cf61e13c07f193ec0f14a395c143d6e19 Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Sat, 4 Jun 2022 19:44:51 +0000 Subject: [PATCH 10/14] fix cpp lint issue --- src/runtime/contrib/dnnl/dnnl.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc 
index 872ebcbeaf17..527c1279ea65 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -105,7 +105,8 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in memory({{conv2d_weights_tz}, dt::f32, channel_last? tag::ghwio : tag::goihw}, eng, weights); auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias); - auto user_dst_memory = memory({{conv2d_dst_tz}, dt::f32, channel_last? tag::nhwc : tag::nchw}, eng, out); + auto user_dst_memory = + memory({{conv2d_dst_tz}, dt::f32, channel_last? tag::nhwc : tag::nchw}, eng, out); auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any); auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any); From 2ecf5f601693c22584b9d58a7ed4f97330f59a16 Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Sat, 4 Jun 2022 19:57:51 +0000 Subject: [PATCH 11/14] fix cpp lint issue --- src/runtime/contrib/dnnl/dnnl.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 527c1279ea65..f2f00bbfea1e 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -105,7 +105,7 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in memory({{conv2d_weights_tz}, dt::f32, channel_last? tag::ghwio : tag::goihw}, eng, weights); auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias); - auto user_dst_memory = + auto user_dst_memory = memory({{conv2d_dst_tz}, dt::f32, channel_last? 
tag::nhwc : tag::nchw}, eng, out); auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any); From 570657c7f1dd8655b2f98c9cd34265783ede7b9a Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Sat, 4 Jun 2022 20:17:24 +0000 Subject: [PATCH 12/14] clang reformated --- src/runtime/contrib/dnnl/dnnl.cc | 53 +++++++++++++------------------- 1 file changed, 22 insertions(+), 31 deletions(-) diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index f2f00bbfea1e..7d3763d411ba 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -80,8 +80,8 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) { void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_, int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph0_, int p_Pw0_, int p_Ph1_, - int p_Pw1_, int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_, - primitive_attr attr, bool channel_last) { + int p_Pw1_, int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_, primitive_attr attr, + bool channel_last) { using tag = memory::format_tag; using dt = memory::data_type; engine eng(engine::kind::cpu, 0); @@ -98,15 +98,15 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in memory::dims conv2d_padding1 = {p_Ph1_, p_Pw1_}; auto user_src_memory = - memory({{conv2d_src_tz}, dt::f32, channel_last? tag::nhwc : tag::nchw}, eng, data); + memory({{conv2d_src_tz}, dt::f32, channel_last ? tag::nhwc : tag::nchw}, eng, data); auto user_weights_memory = - memory({{conv2d_weights_tz}, dt::f32, channel_last? tag::hwio : tag::oihw}, eng, weights); - if (p_G_ > 1) user_weights_memory = - memory({{conv2d_weights_tz}, dt::f32, channel_last? tag::ghwio : tag::goihw}, eng, weights); - auto conv2d_user_bias_memory = - memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias); + memory({{conv2d_weights_tz}, dt::f32, channel_last ? 
tag::hwio : tag::oihw}, eng, weights); + if (p_G_ > 1) + user_weights_memory = memory( + {{conv2d_weights_tz}, dt::f32, channel_last ? tag::ghwio : tag::goihw}, eng, weights); + auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias); auto user_dst_memory = - memory({{conv2d_dst_tz}, dt::f32, channel_last? tag::nhwc : tag::nchw}, eng, out); + memory({{conv2d_dst_tz}, dt::f32, channel_last ? tag::nhwc : tag::nchw}, eng, out); auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any); auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any); @@ -129,10 +129,10 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in // reorder if weights layout not DNNL chosen. auto conv2d_weights_memory = user_weights_memory; if (conv2d_prim_desc.weights_desc() != user_weights_memory.get_desc()) { - conv2d_weights_memory = memory(conv2d_prim_desc.weights_desc(), eng); - auto reorder_weights = reorder(user_weights_memory, conv2d_weights_memory); - reorder_weights.execute( - s, {{DNNL_ARG_FROM, user_weights_memory}, {DNNL_ARG_TO, conv2d_weights_memory}}); + conv2d_weights_memory = memory(conv2d_prim_desc.weights_desc(), eng); + auto reorder_weights = reorder(user_weights_memory, conv2d_weights_memory); + reorder_weights.execute( + s, {{DNNL_ARG_FROM, user_weights_memory}, {DNNL_ARG_TO, conv2d_weights_memory}}); } auto conv2d_dst_memory = user_dst_memory; @@ -148,8 +148,8 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in // reorder if dst layout not DNNL chosen. 
if (conv2d_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(conv2d_dst_memory, user_dst_memory).execute( - s, {{DNNL_ARG_FROM, conv2d_dst_memory}, {DNNL_ARG_TO, user_dst_memory}}); + reorder(conv2d_dst_memory, user_dst_memory) + .execute(s, {{DNNL_ARG_FROM, conv2d_dst_memory}, {DNNL_ARG_TO, user_dst_memory}}); } s.wait(); @@ -342,21 +342,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mkldnn.conv2d").set_body([](TVMArgs args, TVMRe DLTensor* input = args[0]; DLTensor* weights = args[1]; DLTensor* output = args[2]; - int p_Ph0_ = args[3], - p_Pw0_ = args[4], - p_Ph1_ = args[5], - p_Pw1_ = args[6], - p_Sh_ = args[7], - p_Sw_ = args[8], - p_G_ = args[9]; + int p_Ph0_ = args[3], p_Pw0_ = args[4], p_Ph1_ = args[5], p_Pw1_ = args[6], p_Sh_ = args[7], + p_Sw_ = args[8], p_G_ = args[9]; bool channel_last = args[10]; - int p_N_ = input->shape[0], - p_C_ = input->shape[1], - p_H_ = input->shape[2], - p_W_ = input->shape[3], - p_O_ = output->shape[1], - p_Kh_ = weights->shape[2], + int p_N_ = input->shape[0], p_C_ = input->shape[1], p_H_ = input->shape[2], + p_W_ = input->shape[3], p_O_ = output->shape[1], p_Kh_ = weights->shape[2], p_Kw_ = weights->shape[3]; if (channel_last) { @@ -372,9 +363,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mkldnn.conv2d").set_body([](TVMArgs args, TVMRe std::vector bias(p_O_, 0); primitive_attr attr; return dnnl_conv2d_common(static_cast(input->data), static_cast(weights->data), - bias.data(), static_cast(output->data), p_N_, p_C_, p_H_, - p_W_, p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, - p_Sh_, p_Sw_, attr, channel_last); + bias.data(), static_cast(output->data), p_N_, p_C_, p_H_, p_W_, + p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, + attr, channel_last); }); } // namespace contrib From 16cb147447119543d4c821d2f33bfb2ba3acc4ce Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Sat, 4 Jun 2022 21:37:06 +0000 Subject: [PATCH 13/14] adjust .py import for testing --- python/tvm/topi/x86/conv2d.py | 7 
++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/x86/conv2d.py b/python/tvm/topi/x86/conv2d.py index e1710ef1eea5..a28c75b81d3f 100644 --- a/python/tvm/topi/x86/conv2d.py +++ b/python/tvm/topi/x86/conv2d.py @@ -24,7 +24,8 @@ from tvm import te from tvm import autotvm from tvm.contrib import mkldnn -from .. import nn, generic +from .. import nn +from ..generic import schedule_extern from ..nn.conv2d import conv2d_infer_layout, _get_workload as _get_conv2d_workload from ..nn.conv2d import unpack_NCHWc_to_nchw from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload @@ -279,7 +280,7 @@ def conv2d_nchw_mkldnn(cfg, data, kernel, strides, padding, dilation, out_dtype) @autotvm.register_topi_schedule("conv2d_nchw_mkldnn.x86") def schedule_conv2d_nchw_mkldnn(_, outs): """Create schedule for conv2d_nchw_mkldnn""" - return generic.schedule_extern(outs) + return schedule_extern(outs) @autotvm.register_topi_compute("conv2d_nhwc_mkldnn.x86") @@ -293,7 +294,7 @@ def conv2d_nhwc_mkldnn(cfg, data, kernel, strides, padding, dilation, out_dtype) @autotvm.register_topi_schedule("conv2d_nhwc_mkldnn.x86") def schedule_conv2d_nhwc_mkldnn(_, outs): """Create schedule for conv2d_nhwc_mkldnn""" - return generic.schedule_extern(outs) + return schedule_extern(outs) # FIXME - https://github.com/apache/tvm/issues/4122 From 4a42fa05b51b42d84dffcd1052062ca32b7a69ca Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Tue, 7 Jun 2022 11:11:34 +0000 Subject: [PATCH 14/14] function existence check in test --- tests/python/relay/test_op_level2.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 3718a6106065..db1eb16b8ca3 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1998,6 +1998,12 @@ def test_conv2d_rocm_sdot4(): @tvm.testing.requires_x86 def test_conv2d_nchw_mkldnn(): + if not 
tvm.get_global_func("tvm.contrib.mkldnn.conv2d", allow_missing=True): + print( + "skip because extern mkldnn function is not available, \ + built with MKLDNN=ON" + ) + return d_shape = (1, 64, 56, 56) w_shape = (64, 64, 3, 3) padding = (1, 1) @@ -2040,6 +2046,12 @@ def test_conv2d_nchw_mkldnn(): @tvm.testing.requires_x86 def test_conv2d_nhwc_mkldnn(): + if not tvm.get_global_func("tvm.contrib.mkldnn.conv2d", allow_missing=True): + print( + "skip because extern mkldnn function is not available, \ + built with MKLDNN=ON" + ) + return d_shape = (1, 56, 56, 64) w_shape = (3, 3, 64, 64) padding = (1, 1)