Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
35 changes: 28 additions & 7 deletions python/mxnet/gluon/nn/basic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,14 @@ class GroupNorm(HybridBlock):
Initializer for the beta weight.
gamma_initializer: str or `Initializer`, default 'ones'
Initializer for the gamma weight.
v2: bool, default False
If True, the corrected implementation of the group normalization operator, as described in
the original paper, is used: `gamma` and `beta` hold one value per channel and have shape
`(in_channels,)`. This is the only version of this operator that will be available in
MXNet 2.0.
If False, the legacy, incorrect implementation of group normalization is used: `gamma` and
`beta` hold one value per group and have shape `(num_groups,)`. This setting is only
present for backward compatibility with older MXNet 1.x models. Please see
`#18199 <https://github.com/apache/incubator-mxnet/pull/18199>`_ and
`#17139 <https://github.com/apache/incubator-mxnet/issues/17139>`_ for more information.


Inputs:
Expand Down Expand Up @@ -690,26 +698,39 @@ class GroupNorm(HybridBlock):
"""
def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
beta_initializer='zeros', gamma_initializer='ones',
prefix=None, params=None):
in_channels=0, v2=False, prefix=None, params=None):
super(GroupNorm, self).__init__(prefix=prefix, params=params)
self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale}
self._num_groups = num_groups
self._epsilon = epsilon
self._center = center
self._scale = scale
self._v2 = v2

if not v2:
warnings.warn("You are using an incorrect implementation of GroupNorm. To use a "
"corrected implementation, pass v2=True into GroupNorm's constructor. "
"The current (broken) implementation will be removed in MXNet 2.0. "
"Please see #18199 and #17139 for more details.")

self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
shape=(num_groups,), init=gamma_initializer,
allow_deferred_init=True)
shape=(in_channels if v2 else num_groups,),
init=gamma_initializer, allow_deferred_init=True)
self.beta = self.params.get('beta', grad_req='write' if center else 'null',
shape=(num_groups,), init=beta_initializer,
allow_deferred_init=True)
shape=(in_channels if v2 else num_groups,),
init=beta_initializer, allow_deferred_init=True)

def hybrid_forward(self, F, data, gamma, beta):
norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon)
norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon,
v2=self._v2)
return norm_data

def __repr__(self):
s = '{name}({content})'
s = '{name}({content}'
in_channels = self.gamma.shape[0]
s += ', in_channels={0}'.format(in_channels)
s += ', v2={0}'.format(self._v2)
s += ')'
return s.format(name=self.__class__.__name__,
content=', '.join(['='.join([k, v.__repr__()])
for k, v in self._kwargs.items()]))
Expand Down
36 changes: 24 additions & 12 deletions src/operator/nn/group_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,19 @@ struct GroupNormParam : public dmlc::Parameter<GroupNormParam> {
int num_groups;
float eps;
bool output_mean_var;
bool v2;
DMLC_DECLARE_PARAMETER(GroupNormParam) {
DMLC_DECLARE_FIELD(num_groups).set_default(1)
.describe("Total number of groups.");
DMLC_DECLARE_FIELD(eps).set_default(1e-5f)
.describe("An `epsilon` parameter to prevent division by 0.");
DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
.describe("Output the mean and std calculated along the given axis.");
DMLC_DECLARE_FIELD(v2).set_default(false)
.describe("Set to true to use a corrected version of this operator that will be used in "
"MXNet 2.0. The default value is currently false for backward compatibility. "
"This option will be removed with MXNet 2.0. Please see #18199 and #17139 for more "
"information.");
}
};

Expand All @@ -76,6 +82,7 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
using namespace mxnet_op;
const GroupNormParam& param = nnvm::get<GroupNormParam>(attrs.parsed);
const int num_groups = param.num_groups;
const bool v2 = param.v2;
if (req[0] == kNullOp) return;
CHECK_NE(req[0], kAddTo);

Expand Down Expand Up @@ -136,16 +143,16 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
TBlob data_grp = data.reshape(temp_data_shape);
const TBlob& mean_grp = mean.reshape(moments_shape);
const TBlob& std_grp = std.reshape(moments_shape);
const TBlob& output = outputs[groupnorm::kOut].reshape(temp_data_shape);
const TBlob& output_grp = outputs[groupnorm::kOut].reshape(temp_data_shape);

// Calculate data = data - mean
BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
{data_grp, mean_grp},
{kWriteTo}, {output});
{kWriteTo}, {output_grp});

// Calculate std
const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape);
MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, true>(
s, std_, req[0], workspace, centered_out);
Expand All @@ -157,11 +164,12 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,

// Calculate data = data / std
BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
{output, std_grp},
{kWriteTo}, {output});
{output_grp, std_grp},
{kWriteTo}, {output_grp});

mxnet::TShape new_param_shape(data_shape.ndim() + 1, 1);
new_param_shape[1] = num_groups;
const TBlob& output = v2 ? outputs[groupnorm::kOut] : output_grp;
mxnet::TShape new_param_shape(data_shape.ndim() + (v2 ? 0 : 1), 1);
new_param_shape[1] = v2 ? data_shape[1] : num_groups;

const TBlob& gamma = inputs[groupnorm::kGamma].reshape(new_param_shape);
const TBlob& beta = inputs[groupnorm::kBeta].reshape(new_param_shape);
Expand Down Expand Up @@ -199,6 +207,7 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
CHECK_EQ(outputs.size(), 3U);
const GroupNormParam& param = nnvm::get<GroupNormParam>(attrs.parsed);
const int num_groups = param.num_groups;
const bool v2 = param.v2;

const TBlob& data = inputs[1];
const mxnet::TShape& dshape = data.shape_;
Expand All @@ -215,8 +224,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,

Stream<xpu> *s = ctx.get_stream<xpu>();
// Reshape gamma to be broadcastable
mxnet::TShape new_param_shape(dshape.ndim() + 1, 1);
new_param_shape[1] = num_groups;
mxnet::TShape new_param_shape(dshape.ndim() + (v2 ? 0 : 1), 1);
new_param_shape[1] = v2 ? dshape[1] : num_groups;

const TBlob& gamma = inputs[2].reshape(new_param_shape);

Expand All @@ -233,7 +242,7 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
// Prepare the necessary shapes for reduction
mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape;
BroadcastReduceShapeCompact(temp_dshape, mean_.shape_, &red_src_shape, &red_dst_shape);
BroadcastReduceShapeCompact(temp_dshape, gamma.shape_,
BroadcastReduceShapeCompact(v2 ? dshape : temp_dshape, gamma.shape_,
&red_exclude_src_shape, &red_exclude_dst_shape);

int N = red_src_shape.Size() / red_dst_shape.Size();
Expand Down Expand Up @@ -308,8 +317,11 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
if (req[0] != kNullOp) {
const TBlob output_ = outputs[0].reshape(data_.shape_);
BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
{ograd, gamma},
{kWriteTo}, {ograd_mult});
{v2 ? inputs[0] : ograd, gamma},
{kWriteTo},
{v2
? ograd_mult.reshape(data.shape_)
: ograd_mult});
BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
{ograd_mult, std_},
{kWriteTo}, {ograd_mult});
Expand Down
4 changes: 2 additions & 2 deletions src/operator/nn/group_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
return false;
}

in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups));
in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups));
in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(param.v2 ? dshape[1] : num_groups));
in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(param.v2 ? dshape[1] : num_groups));

out_shape->clear();
out_shape->push_back(dshape);
Expand Down
114 changes: 63 additions & 51 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1960,79 +1960,91 @@ def x_hat_helper(x, num_groups, eps):
x_hat = (data - mean.reshape(new_moments_shape)) / std.reshape(new_moments_shape)
return x_hat, mean, std

def np_groupnorm(data, gamma, beta, num_groups, eps):
new_param_shape = (1, num_groups, 1, 1, 1)
def np_groupnorm(data, gamma, beta, num_groups, eps, v2):
new_param_shape = (1, dshape[1], 1, 1) if v2 else (1, num_groups, 1, 1, 1)
x_hat, mean, std = x_hat_helper(data, num_groups, eps)
out = x_hat * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
return out.reshape(dshape), mean, std
if v2:
out = x_hat.reshape(dshape) * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
return out, mean, std
else:
out = x_hat * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
return out.reshape(dshape), mean, std

def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps, v2):
x_hat, mean, std = x_hat_helper(data, num_groups, eps)
new_shape = x_hat.shape
dshape = data.shape
dtype = data.dtype
new_moments_shape = (new_shape[0], num_groups, 1, 1, 1)
new_param_shape = (1, num_groups, 1, 1, 1)
new_param_shape = (1, dshape[1], 1, 1) if v2 else (1, num_groups, 1, 1, 1)
acc_type = acc_types[str(dtype)]
ograd = ograd.reshape(new_shape)
data = data.reshape(new_shape)
gamma = gamma.reshape(new_param_shape)
beta = beta.reshape(new_param_shape)
mean = mean.reshape(new_moments_shape)
std = std.reshape(new_moments_shape)
beta_grad = np.sum(ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
gamma_grad = np.sum(x_hat * ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
x_hat_grad = ograd * gamma
beta_grad = np.sum(ograd, axis=(0, 3, 4) if v2 else (0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
if v2:
beta_grad = beta_grad.flatten()
gamma_grad = np.sum(x_hat * ograd, axis=(0, 3, 4) if v2 else (0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
if v2:
gamma_grad = gamma_grad.flatten()
if v2:
x_hat_grad = ograd * gamma.reshape(1, num_groups, dshape[1] // num_groups, 1, 1)
else:
x_hat_grad = ograd * gamma
ograd_mult = x_hat_grad / std
red_out = np.mean(ograd_mult, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype)
data_grad = ograd_mult - red_out
red_out = np.mean(ograd_mult * x_hat, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype)
data_grad = data_grad - x_hat * red_out
return data_grad.reshape(dshape), gamma_grad, beta_grad


batch_size = random.randint(1, 8)
num_groups = random.randint(2, 3)
num_channels = random.randint(2, 3) * num_groups
height = random.randint(1, 5)
width = random.randint(1, 5)
dshape = (batch_size, num_channels, height, width)
param_shape = (num_groups,)
temp_shape = (batch_size, num_groups, int(num_channels / num_groups), height, width)
np_data = np.random.uniform(0.2, 1.0, dshape)
np_gamma = np.random.uniform(-1.0, 1.0, param_shape)
np_beta = np.random.uniform(-1.0, 1.0, param_shape)
data_sym = mx.sym.Variable("data")
gamma_sym = mx.sym.Variable("gamma")
beta_sym = mx.sym.Variable("beta")
for dtype in [np.float16, np.float32, np.float64]:
eps = 1e-2 if dtype == np.float16 else 1e-5
mx_data = mx.nd.array(np_data, dtype=dtype)
mx_gamma = mx.nd.array(np_gamma, dtype=dtype)
mx_beta = mx.nd.array(np_beta, dtype=dtype)
np_out, np_mean, np_std = np_groupnorm(np_data.astype(dtype),
np_gamma.astype(dtype),
np_beta.astype(dtype),
num_groups=num_groups,
eps=eps)
mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym,
num_groups=num_groups, eps=eps, output_mean_var=True)
check_symbolic_forward(mx_sym, [mx_data, mx_gamma, mx_beta], [np_out, np_mean, np_std],
rtol=1e-2 if dtype == np.float16 else 1e-3,
atol=5e-3 if dtype == np.float16 else 1e-4, dtype=dtype)
mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym,
num_groups=num_groups, eps=eps, output_mean_var=False)
np_ograd = np.random.uniform(-1.0, 1.0, dshape).astype(dtype)
np_data_grad, np_gamma_grad, np_beta_grad = np_groupnorm_grad(np_ograd,
np_data.astype(dtype),
np_gamma.astype(dtype),
np_beta.astype(dtype),
np_mean, np_std,
num_groups, eps)
check_symbolic_backward(mx_sym, [mx_data, mx_gamma, mx_beta], [mx.nd.array(np_ograd)],
[np_data_grad, np_gamma_grad, np_beta_grad],
for v2 in [False, True]:
batch_size = random.randint(1, 8)
num_groups = random.randint(2, 3)
num_channels = random.randint(2, 3) * num_groups
height = random.randint(1, 5)
width = random.randint(1, 5)
dshape = (batch_size, num_channels, height, width)
param_shape = (num_channels if v2 else num_groups,)
temp_shape = (batch_size, num_groups, int(num_channels / num_groups), height, width)
np_data = np.random.uniform(0.2, 1.0, dshape)
np_gamma = np.random.uniform(-1.0, 1.0, param_shape)
np_beta = np.random.uniform(-1.0, 1.0, param_shape)
data_sym = mx.sym.Variable("data")
gamma_sym = mx.sym.Variable("gamma")
beta_sym = mx.sym.Variable("beta")
for dtype in [np.float16, np.float32, np.float64]:
eps = 1e-2 if dtype == np.float16 else 1e-5
mx_data = mx.nd.array(np_data, dtype=dtype)
mx_gamma = mx.nd.array(np_gamma, dtype=dtype)
mx_beta = mx.nd.array(np_beta, dtype=dtype)
np_out, np_mean, np_std = np_groupnorm(np_data.astype(dtype),
np_gamma.astype(dtype),
np_beta.astype(dtype),
num_groups=num_groups,
eps=eps,
v2=v2)
mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym,
num_groups=num_groups, eps=eps, output_mean_var=True, v2=v2)
check_symbolic_forward(mx_sym, [mx_data, mx_gamma, mx_beta], [np_out, np_mean, np_std],
rtol=1e-2 if dtype == np.float16 else 1e-3,
atol=5e-2 if dtype == np.float16 else 1e-4, dtype=dtype)
atol=5e-3 if dtype == np.float16 else 1e-4, dtype=dtype)
mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym,
num_groups=num_groups, eps=eps, output_mean_var=False, v2=v2)
np_ograd = np.random.uniform(-1.0, 1.0, dshape).astype(dtype)
np_data_grad, np_gamma_grad, np_beta_grad = np_groupnorm_grad(np_ograd,
np_data.astype(dtype),
np_gamma.astype(dtype),
np_beta.astype(dtype),
np_mean, np_std,
num_groups, eps, v2)
check_symbolic_backward(mx_sym, [mx_data, mx_gamma, mx_beta], [mx.nd.array(np_ograd)],
[np_data_grad, np_gamma_grad, np_beta_grad],
rtol=1e-2 if dtype == np.float16 else 1e-3,
atol=5e-2 if dtype == np.float16 else 1e-4, dtype=dtype)


@with_seed()
Expand Down