Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmake/modules/contrib/BLAS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ if(IS_DIRECTORY ${USE_MKLDNN})
include_directories(SYSTEM ${USE_MKLDNN}/include)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${MKLDNN_LIBRARY})
list(APPEND RUNTIME_SRCS src/runtime/contrib/cblas/mkldnn.cc)
list(APPEND RUNTIME_SRCS src/runtime/contrib/dnnl/dnnl.cc)
add_definitions(-DUSE_DNNL=1)
message(STATUS "Use MKLDNN library " ${MKLDNN_LIBRARY})
endif()
Expand All @@ -84,6 +85,7 @@ elseif(USE_MKLDNN STREQUAL "ON")
add_definitions(-DUSE_DNNL=1)
message(STATUS "Use MKLDNN library " ${MKLDNN_LIBRARY})
list(APPEND RUNTIME_SRCS src/runtime/contrib/cblas/mkldnn.cc)
list(APPEND RUNTIME_SRCS src/runtime/contrib/dnnl/dnnl.cc)
endif()
elseif(USE_MKLDNN STREQUAL "OFF")
# pass
Expand Down
105 changes: 105 additions & 0 deletions python/tvm/contrib/mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""External function interface to BLAS libraries."""
import tvm
from tvm import te
from ..topi.nn.utils import get_pad_tuple


def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
Expand Down Expand Up @@ -50,3 +51,107 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
name="C",
**kwargs,
)


def dnnl_conv2d(
src,
weights,
stride,
padding,
dilation,
groups,
channel_last=False,
out_dtype="float32",
**kwargs,
):
"""Convolution operator in NCHW layout.

Parameters
----------
src : tvm.te.Tensor
4-D with shape [batch, in_channel, in_height, in_width]

weights : tvm.te.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width]

stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]

padding : int or a list/tuple of 2 or 4 ints
padding size, or
[pad_height, pad_width] for 2 ints, or
[pad_top, pad_left, pad_bottom, pad_right] for 4 ints

dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]

groups: str
input data layout: NCHW or NHWC

channel_last: bool
chose if input/output data format is in channel_last format(NHWC) or
in plain format(NCHW)

out_dtype: str
output datatype: now only support float32

Returns
-------
Output : tvm.te.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""

assert isinstance(stride, int) or len(stride) == 2
assert isinstance(dilation, int) or len(dilation) == 2
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation

if channel_last:
batch, in_height, in_width, _ = src.shape
kernel_h, kernel_w, _, num_filter = weights.shape
else:
batch, _, in_height, in_width = src.shape
num_filter, _, kernel_h, kernel_w = weights.shape

dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w)
)
out_channel = num_filter
out_height = (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1
out_width = (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1

if channel_last:
out_shape = (batch, out_height, out_width, out_channel)
else:
out_shape = (batch, out_channel, out_height, out_width)

return te.extern(
out_shape,
[src, weights],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.mkldnn.conv2d",

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name is a bit confusing...so does the use of libraries. We now have USE_MKLDNN (for cblas with matmul/dense) and USE_DNNL (for DNNL/OneDNN with matmal/dense/conv2d). AFAIK, MKL-DNN can be covered by DNNL, so should we deprecate MKL-DNN and use DNNL for both cases (e.g., -libs and BYOC)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with that should be a unified name of DNNL, we already have USE_DNNL_CODEGEN and USE_MKLDNN for BYOC and -libs,I suggest to have USE_DNNL_LIBS for -libs and USE_DNNL_CODEGEN for BYOC, and change 'mkldnn' to 'dnnl' in codes.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we really need to have 2 flags? It seems fine to enable both libs and BYOC when USE_DNNL is ON.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that would be more concise, 💯
I'll try and commit later.

ins[0],
ins[1],
outs[0],
pad_top,
pad_down,
pad_left,
pad_right,
stride[0],
stride[1],
groups,
channel_last,
),
name="C",
dtype=out_dtype,
**kwargs,
)
23 changes: 18 additions & 5 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.x86.schedule_conv2d_nchw_int8),
name="conv2d_nchw_int8.x86",
)
elif "mkldnn" in target.libs:
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_nchw_mkldnn),
wrap_topi_schedule(topi.x86.schedule_conv2d_nchw_mkldnn),
name="conv2d_nchw_mkldnn.x86",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_nchw),
Expand All @@ -133,11 +139,18 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
assert kernel_layout == "HWIO"
if not is_auto_scheduler_enabled():
logger.warning("conv2d NHWC layout is not optimized for x86 with autotvm.")
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True),
wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc),
name="conv2d_nhwc.x86",
)
if "mkldnn" in target.libs:
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_nhwc_mkldnn),
wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc_mkldnn),
name="conv2d_nhwc_mkldnn.x86",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True),
wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc),
name="conv2d_nhwc.x86",
)

judge_winograd_auto_scheduler = False
if len(kernel.shape) == 4:
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/topi/x86/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
import tvm
from tvm import te
from tvm import autotvm
from tvm.contrib import mkldnn
from .. import nn
from ..generic import schedule_extern
from ..nn.conv2d import conv2d_infer_layout, _get_workload as _get_conv2d_workload
from ..nn.conv2d import unpack_NCHWc_to_nchw
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
Expand Down Expand Up @@ -267,6 +269,34 @@ def _callback(op):
return s


@autotvm.register_topi_compute("conv2d_nchw_mkldnn.x86")
def conv2d_nchw_mkldnn(cfg, data, kernel, strides, padding, dilation, out_dtype):
"""Compute conv2d in NCHW format using mkldnn."""
groups = 1
_out = mkldnn.dnnl_conv2d(data, kernel, strides, padding, dilation, groups, False, out_dtype)
return _out


@autotvm.register_topi_schedule("conv2d_nchw_mkldnn.x86")
def schedule_conv2d_nchw_mkldnn(_, outs):
"""Create schedule for conv2d_nchw_mkldnn"""
return schedule_extern(outs)


@autotvm.register_topi_compute("conv2d_nhwc_mkldnn.x86")
def conv2d_nhwc_mkldnn(cfg, data, kernel, strides, padding, dilation, out_dtype):
"""Compute conv2d in NHWC format using mkldnn."""
groups = 1
_out = mkldnn.dnnl_conv2d(data, kernel, strides, padding, dilation, groups, True, out_dtype)
return _out


@autotvm.register_topi_schedule("conv2d_nhwc_mkldnn.x86")
def schedule_conv2d_nhwc_mkldnn(_, outs):
"""Create schedule for conv2d_nhwc_mkldnn"""
return schedule_extern(outs)


# FIXME - https://github.com/apache/tvm/issues/4122
# _declaration_conv_nhwc_pack expects kernel layout to be HWOI. However, the tests use HWIO
# layout. Commenting until we have clarity about the nhwc_pack implementation from the author.
Expand Down
82 changes: 72 additions & 10 deletions src/runtime/contrib/dnnl/dnnl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) {

void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_,
int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph0_, int p_Pw0_, int p_Ph1_,
int p_Pw1_, int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_,
primitive_attr attr) {
int p_Pw1_, int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_, primitive_attr attr,
bool channel_last) {
using tag = memory::format_tag;
using dt = memory::data_type;
engine eng(engine::kind::cpu, 0);
Expand All @@ -97,32 +97,62 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in
memory::dims conv2d_padding0 = {p_Ph0_, p_Pw0_};
memory::dims conv2d_padding1 = {p_Ph1_, p_Pw1_};

auto user_src_memory = memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data);
auto user_src_memory =
memory({{conv2d_src_tz}, dt::f32, channel_last ? tag::nhwc : tag::nchw}, eng, data);
auto user_weights_memory =
memory({{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng, weights);
memory({{conv2d_weights_tz}, dt::f32, channel_last ? tag::hwio : tag::oihw}, eng, weights);
if (p_G_ > 1)
user_weights_memory = memory(
{{conv2d_weights_tz}, dt::f32, channel_last ? tag::ghwio : tag::goihw}, eng, weights);
auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias);
auto user_dst_memory =
memory({{conv2d_dst_tz}, dt::f32, channel_last ? tag::nhwc : tag::nchw}, eng, out);

auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any);
auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any);
auto conv2d_weights_md = memory::desc({conv2d_weights_tz}, dt::f32, tag::any);
auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw);
auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::any);

auto conv2d_desc = convolution_forward::desc(
prop_kind::forward_inference, algorithm::convolution_direct, conv2d_src_md, conv2d_weights_md,
conv2d_bias_md, conv2d_dst_md, conv2d_strides, conv2d_padding0, conv2d_padding1);
auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, attr, eng);

// reorder if src layout not DNNL chosen.
auto conv2d_src_memory = user_src_memory;
if (conv2d_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv2d_src_memory = memory(conv2d_prim_desc.src_desc(), eng);
auto reorder_src = reorder(user_src_memory, conv2d_src_memory);
reorder_src.execute(s, {{DNNL_ARG_FROM, user_src_memory}, {DNNL_ARG_TO, conv2d_src_memory}});
}

// reorder if weights layout not DNNL chosen.
auto conv2d_weights_memory = user_weights_memory;
auto conv2d_dst_memory = memory(conv2d_prim_desc.dst_desc(), eng);
if (conv2d_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
conv2d_weights_memory = memory(conv2d_prim_desc.weights_desc(), eng);
auto reorder_weights = reorder(user_weights_memory, conv2d_weights_memory);
reorder_weights.execute(
s, {{DNNL_ARG_FROM, user_weights_memory}, {DNNL_ARG_TO, conv2d_weights_memory}});
}

auto conv2d_dst_memory = user_dst_memory;
if (conv2d_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
conv2d_dst_memory = memory(conv2d_prim_desc.dst_desc(), eng);
}

auto conv = convolution_forward(conv2d_prim_desc);
conv.execute(s, {{DNNL_ARG_SRC, conv2d_src_memory},
{DNNL_ARG_WEIGHTS, conv2d_weights_memory},
{DNNL_ARG_BIAS, conv2d_user_bias_memory},
{DNNL_ARG_DST, conv2d_dst_memory}});

// reorder if dst layout not DNNL chosen.
if (conv2d_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(conv2d_dst_memory, user_dst_memory)
.execute(s, {{DNNL_ARG_FROM, conv2d_dst_memory}, {DNNL_ARG_TO, user_dst_memory}});
}

s.wait();
read_from_dnnl_memory(out, conv2d_dst_memory);
}

extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, int p_C_, int p_H_,
Expand All @@ -131,7 +161,8 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, i
primitive_attr attr;
std::vector<float> bias(p_O_, 0);
return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr);
p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr,
false);
}

primitive_attr create_attr_with_relu_post_op() {
Expand All @@ -151,7 +182,7 @@ extern "C" void dnnl_fused_conv2d_relu(float* data, float* weights, float* out,
std::vector<float> bias(p_O_, 0);
return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
create_attr_with_relu_post_op());
create_attr_with_relu_post_op(), false);
}

extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out,
Expand All @@ -161,7 +192,7 @@ extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float*
int p_Sw_) {
return dnnl_conv2d_common(data, weights, bias, out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph0_,
p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
create_attr_with_relu_post_op());
create_attr_with_relu_post_op(), false);
}

extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_) {
Expand Down Expand Up @@ -306,6 +337,37 @@ extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_
read_from_dnnl_memory(out, dst_memory);
}

// DNNL Conv2d single OP
TVM_REGISTER_GLOBAL("tvm.contrib.mkldnn.conv2d").set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* input = args[0];
DLTensor* weights = args[1];
DLTensor* output = args[2];
int p_Ph0_ = args[3], p_Pw0_ = args[4], p_Ph1_ = args[5], p_Pw1_ = args[6], p_Sh_ = args[7],
p_Sw_ = args[8], p_G_ = args[9];
bool channel_last = args[10];

int p_N_ = input->shape[0], p_C_ = input->shape[1], p_H_ = input->shape[2],
p_W_ = input->shape[3], p_O_ = output->shape[1], p_Kh_ = weights->shape[2],
p_Kw_ = weights->shape[3];

if (channel_last) {
p_N_ = input->shape[0];
p_H_ = input->shape[1];
p_W_ = input->shape[2];
p_C_ = input->shape[3];
p_O_ = output->shape[3];
p_Kh_ = weights->shape[0];
p_Kw_ = weights->shape[1];
}

std::vector<float> bias(p_O_, 0);
primitive_attr attr;
return dnnl_conv2d_common(static_cast<float*>(input->data), static_cast<float*>(weights->data),
bias.data(), static_cast<float*>(output->data), p_N_, p_C_, p_H_, p_W_,
p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
attr, channel_last);
});

} // namespace contrib
} // namespace runtime
} // namespace tvm
1 change: 1 addition & 0 deletions src/runtime/contrib/dnnl/dnnl_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>

#include <vector>

Expand Down
Loading