diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py index 61f9ea07fbec..993b060f18a6 100644 --- a/python/mxnet/gluon/nn/basic_layers.py +++ b/python/mxnet/gluon/nn/basic_layers.py @@ -654,6 +654,14 @@ class GroupNorm(HybridBlock): Initializer for the beta weight. gamma_initializer: str or `Initializer`, default 'ones' Initializer for the gamma weight. + v2: bool, default False + If True, a correct implementation of the group normalization operator from the original + paper will be used. This is the only version of this operator that will be available in + MXNet 2.0. + If False, an incorrect implementation of group normalization will be used. This setting is + only present for backward compatibility with older MXNet 1.x models. Please see + `#18199 <https://github.com/apache/incubator-mxnet/issues/18199>`_ and + `#17139 <https://github.com/apache/incubator-mxnet/issues/17139>`_ for more information. Inputs: @@ -690,26 +698,39 @@ class GroupNorm(HybridBlock): """ def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True, beta_initializer='zeros', gamma_initializer='ones', - prefix=None, params=None): + in_channels=0, v2=False, prefix=None, params=None): super(GroupNorm, self).__init__(prefix=prefix, params=params) self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale} self._num_groups = num_groups self._epsilon = epsilon self._center = center self._scale = scale + self._v2 = v2 + + if not v2: + warnings.warn("You are using an incorrect implementation of GroupNorm. To use a " + "corrected implementation, pass v2=True into GroupNorm's constructor. " + "The current (broken) implementation will be removed in MXNet 2.0. 
" + "Please see #18199 and #17139 for more details.") + self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null', - shape=(num_groups,), init=gamma_initializer, - allow_deferred_init=True) + shape=(in_channels if v2 else num_groups,), + init=gamma_initializer, allow_deferred_init=True) self.beta = self.params.get('beta', grad_req='write' if center else 'null', - shape=(num_groups,), init=beta_initializer, - allow_deferred_init=True) + shape=(in_channels if v2 else num_groups,), + init=beta_initializer, allow_deferred_init=True) def hybrid_forward(self, F, data, gamma, beta): - norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon) + norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon, + v2=self._v2) return norm_data def __repr__(self): - s = '{name}({content})' + s = '{name}({content}' + in_channels = self.gamma.shape[0] + s += ', in_channels={0}'.format(in_channels) + s += ', v2={0}'.format(self._v2) + s += ')' return s.format(name=self.__class__.__name__, content=', '.join(['='.join([k, v.__repr__()]) for k, v in self._kwargs.items()])) diff --git a/src/operator/nn/group_norm-inl.h b/src/operator/nn/group_norm-inl.h index 69d5a304dc2c..aede3dd2429f 100644 --- a/src/operator/nn/group_norm-inl.h +++ b/src/operator/nn/group_norm-inl.h @@ -54,6 +54,7 @@ struct GroupNormParam : public dmlc::Parameter { int num_groups; float eps; bool output_mean_var; + bool v2; DMLC_DECLARE_PARAMETER(GroupNormParam) { DMLC_DECLARE_FIELD(num_groups).set_default(1) .describe("Total number of groups."); @@ -61,6 +62,11 @@ struct GroupNormParam : public dmlc::Parameter { .describe("An `epsilon` parameter to prevent division by 0."); DMLC_DECLARE_FIELD(output_mean_var).set_default(false) .describe("Output the mean and std calculated along the given axis."); + DMLC_DECLARE_FIELD(v2).set_default(false) + .describe("Set to true to use a corrected version of this operator that 
will be used in " + "MXNet 2.0. The default value is currently false for backward compatibility. " + "This option will be removed with MXNet 2.0. Please see #18199 and #17139 for more " + "information."); } }; @@ -76,6 +82,7 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs, using namespace mxnet_op; const GroupNormParam& param = nnvm::get(attrs.parsed); const int num_groups = param.num_groups; + const bool v2 = param.v2; if (req[0] == kNullOp) return; CHECK_NE(req[0], kAddTo); @@ -136,16 +143,16 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs, TBlob data_grp = data.reshape(temp_data_shape); const TBlob& mean_grp = mean.reshape(moments_shape); const TBlob& std_grp = std.reshape(moments_shape); - const TBlob& output = outputs[groupnorm::kOut].reshape(temp_data_shape); + const TBlob& output_grp = outputs[groupnorm::kOut].reshape(temp_data_shape); // Calculate data = data - mean BinaryBroadcastCompute(attrs, ctx, {data_grp, mean_grp}, - {kWriteTo}, {output}); + {kWriteTo}, {output_grp}); // Calculate std const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape); - MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, { + MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { broadcast::Reduce( s, std_, req[0], workspace, centered_out); @@ -157,11 +164,12 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs, // Calculate data = data / std BinaryBroadcastCompute(attrs, ctx, - {output, std_grp}, - {kWriteTo}, {output}); + {output_grp, std_grp}, + {kWriteTo}, {output_grp}); - mxnet::TShape new_param_shape(data_shape.ndim() + 1, 1); - new_param_shape[1] = num_groups; + const TBlob& output = v2 ? outputs[groupnorm::kOut] : output_grp; + mxnet::TShape new_param_shape(data_shape.ndim() + (v2 ? 0 : 1), 1); + new_param_shape[1] = v2 ? 
data_shape[1] : num_groups; const TBlob& gamma = inputs[groupnorm::kGamma].reshape(new_param_shape); const TBlob& beta = inputs[groupnorm::kBeta].reshape(new_param_shape); @@ -199,6 +207,7 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 3U); const GroupNormParam& param = nnvm::get(attrs.parsed); const int num_groups = param.num_groups; + const bool v2 = param.v2; const TBlob& data = inputs[1]; const mxnet::TShape& dshape = data.shape_; @@ -215,8 +224,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, Stream *s = ctx.get_stream(); // Reshape gamma to be broadcastable - mxnet::TShape new_param_shape(dshape.ndim() + 1, 1); - new_param_shape[1] = num_groups; + mxnet::TShape new_param_shape(dshape.ndim() + (v2 ? 0 : 1), 1); + new_param_shape[1] = v2 ? dshape[1] : num_groups; const TBlob& gamma = inputs[2].reshape(new_param_shape); @@ -233,7 +242,7 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, // Prepare the necessary shapes for reduction mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape; BroadcastReduceShapeCompact(temp_dshape, mean_.shape_, &red_src_shape, &red_dst_shape); - BroadcastReduceShapeCompact(temp_dshape, gamma.shape_, + BroadcastReduceShapeCompact(v2 ? dshape : temp_dshape, gamma.shape_, &red_exclude_src_shape, &red_exclude_dst_shape); int N = red_src_shape.Size() / red_dst_shape.Size(); @@ -308,8 +317,11 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, if (req[0] != kNullOp) { const TBlob output_ = outputs[0].reshape(data_.shape_); BinaryBroadcastCompute(attrs, ctx, - {ograd, gamma}, - {kWriteTo}, {ograd_mult}); + {v2 ? inputs[0] : ograd, gamma}, + {kWriteTo}, + {v2 + ? 
ograd_mult.reshape(data.shape_) + : ograd_mult}); BinaryBroadcastCompute(attrs, ctx, {ograd_mult, std_}, {kWriteTo}, {ograd_mult}); diff --git a/src/operator/nn/group_norm.cc b/src/operator/nn/group_norm.cc index 6b8fe9bbd4c9..4b98fb81fc77 100644 --- a/src/operator/nn/group_norm.cc +++ b/src/operator/nn/group_norm.cc @@ -47,8 +47,8 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs, return false; } - in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups)); - in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups)); + in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(param.v2 ? dshape[1] : num_groups)); + in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(param.v2 ? dshape[1] : num_groups)); out_shape->clear(); out_shape->push_back(dshape); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index e22d529eeb41..b67778c74538 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1960,19 +1960,23 @@ def x_hat_helper(x, num_groups, eps): x_hat = (data - mean.reshape(new_moments_shape)) / std.reshape(new_moments_shape) return x_hat, mean, std - def np_groupnorm(data, gamma, beta, num_groups, eps): - new_param_shape = (1, num_groups, 1, 1, 1) + def np_groupnorm(data, gamma, beta, num_groups, eps, v2): + new_param_shape = (1, dshape[1], 1, 1) if v2 else (1, num_groups, 1, 1, 1) x_hat, mean, std = x_hat_helper(data, num_groups, eps) - out = x_hat * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape) - return out.reshape(dshape), mean, std + if v2: + out = x_hat.reshape(dshape) * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape) + return out, mean, std + else: + out = x_hat * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape) + return out.reshape(dshape), mean, std - def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps): + def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, 
eps, v2): x_hat, mean, std = x_hat_helper(data, num_groups, eps) new_shape = x_hat.shape dshape = data.shape dtype = data.dtype new_moments_shape = (new_shape[0], num_groups, 1, 1, 1) - new_param_shape = (1, num_groups, 1, 1, 1) + new_param_shape = (1, dshape[1], 1, 1) if v2 else (1, num_groups, 1, 1, 1) acc_type = acc_types[str(dtype)] ograd = ograd.reshape(new_shape) data = data.reshape(new_shape) @@ -1980,9 +1984,16 @@ def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps): beta = beta.reshape(new_param_shape) mean = mean.reshape(new_moments_shape) std = std.reshape(new_moments_shape) - beta_grad = np.sum(ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype) - gamma_grad = np.sum(x_hat * ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype) - x_hat_grad = ograd * gamma + beta_grad = np.sum(ograd, axis=(0, 3, 4) if v2 else (0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype) + if v2: + beta_grad = beta_grad.flatten() + gamma_grad = np.sum(x_hat * ograd, axis=(0, 3, 4) if v2 else (0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype) + if v2: + gamma_grad = gamma_grad.flatten() + if v2: + x_hat_grad = ograd * gamma.reshape(1, num_groups, dshape[1] // num_groups, 1, 1) + else: + x_hat_grad = ograd * gamma ograd_mult = x_hat_grad / std red_out = np.mean(ograd_mult, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype) data_grad = ograd_mult - red_out @@ -1990,49 +2001,50 @@ def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps): data_grad = data_grad - x_hat * red_out return data_grad.reshape(dshape), gamma_grad, beta_grad - - batch_size = random.randint(1, 8) - num_groups = random.randint(2, 3) - num_channels = random.randint(2, 3) * num_groups - height = random.randint(1, 5) - width = random.randint(1, 5) - dshape = (batch_size, num_channels, height, width) - param_shape = (num_groups,) - temp_shape = (batch_size, num_groups, int(num_channels / num_groups), height, 
width) - np_data = np.random.uniform(0.2, 1.0, dshape) - np_gamma = np.random.uniform(-1.0, 1.0, param_shape) - np_beta = np.random.uniform(-1.0, 1.0, param_shape) - data_sym = mx.sym.Variable("data") - gamma_sym = mx.sym.Variable("gamma") - beta_sym = mx.sym.Variable("beta") - for dtype in [np.float16, np.float32, np.float64]: - eps = 1e-2 if dtype == np.float16 else 1e-5 - mx_data = mx.nd.array(np_data, dtype=dtype) - mx_gamma = mx.nd.array(np_gamma, dtype=dtype) - mx_beta = mx.nd.array(np_beta, dtype=dtype) - np_out, np_mean, np_std = np_groupnorm(np_data.astype(dtype), - np_gamma.astype(dtype), - np_beta.astype(dtype), - num_groups=num_groups, - eps=eps) - mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym, - num_groups=num_groups, eps=eps, output_mean_var=True) - check_symbolic_forward(mx_sym, [mx_data, mx_gamma, mx_beta], [np_out, np_mean, np_std], - rtol=1e-2 if dtype == np.float16 else 1e-3, - atol=5e-3 if dtype == np.float16 else 1e-4, dtype=dtype) - mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym, - num_groups=num_groups, eps=eps, output_mean_var=False) - np_ograd = np.random.uniform(-1.0, 1.0, dshape).astype(dtype) - np_data_grad, np_gamma_grad, np_beta_grad = np_groupnorm_grad(np_ograd, - np_data.astype(dtype), - np_gamma.astype(dtype), - np_beta.astype(dtype), - np_mean, np_std, - num_groups, eps) - check_symbolic_backward(mx_sym, [mx_data, mx_gamma, mx_beta], [mx.nd.array(np_ograd)], - [np_data_grad, np_gamma_grad, np_beta_grad], + for v2 in [False, True]: + batch_size = random.randint(1, 8) + num_groups = random.randint(2, 3) + num_channels = random.randint(2, 3) * num_groups + height = random.randint(1, 5) + width = random.randint(1, 5) + dshape = (batch_size, num_channels, height, width) + param_shape = (num_channels if v2 else num_groups,) + temp_shape = (batch_size, num_groups, int(num_channels / num_groups), height, width) + np_data = np.random.uniform(0.2, 1.0, dshape) + np_gamma = 
np.random.uniform(-1.0, 1.0, param_shape) + np_beta = np.random.uniform(-1.0, 1.0, param_shape) + data_sym = mx.sym.Variable("data") + gamma_sym = mx.sym.Variable("gamma") + beta_sym = mx.sym.Variable("beta") + for dtype in [np.float16, np.float32, np.float64]: + eps = 1e-2 if dtype == np.float16 else 1e-5 + mx_data = mx.nd.array(np_data, dtype=dtype) + mx_gamma = mx.nd.array(np_gamma, dtype=dtype) + mx_beta = mx.nd.array(np_beta, dtype=dtype) + np_out, np_mean, np_std = np_groupnorm(np_data.astype(dtype), + np_gamma.astype(dtype), + np_beta.astype(dtype), + num_groups=num_groups, + eps=eps, + v2=v2) + mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym, + num_groups=num_groups, eps=eps, output_mean_var=True, v2=v2) + check_symbolic_forward(mx_sym, [mx_data, mx_gamma, mx_beta], [np_out, np_mean, np_std], rtol=1e-2 if dtype == np.float16 else 1e-3, - atol=5e-2 if dtype == np.float16 else 1e-4, dtype=dtype) + atol=5e-3 if dtype == np.float16 else 1e-4, dtype=dtype) + mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym, + num_groups=num_groups, eps=eps, output_mean_var=False, v2=v2) + np_ograd = np.random.uniform(-1.0, 1.0, dshape).astype(dtype) + np_data_grad, np_gamma_grad, np_beta_grad = np_groupnorm_grad(np_ograd, + np_data.astype(dtype), + np_gamma.astype(dtype), + np_beta.astype(dtype), + np_mean, np_std, + num_groups, eps, v2) + check_symbolic_backward(mx_sym, [mx_data, mx_gamma, mx_beta], [mx.nd.array(np_ograd)], + [np_data_grad, np_gamma_grad, np_beta_grad], + rtol=1e-2 if dtype == np.float16 else 1e-3, + atol=5e-2 if dtype == np.float16 else 1e-4, dtype=dtype) @with_seed()