Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 46 additions & 27 deletions include/tvm/topi/nn/layer_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_TOPI_NN_LAYER_NORM_H_

#include <tvm/te/operation.h>
#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>

#include <string>
Expand Down Expand Up @@ -59,17 +60,18 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor&
TVM_FFI_ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16))
<< "layer_norm: only support float32 and float16 for now";
bool is_float16 = data_type == DataType::Float(16);
// sum x and x^2
// Two-pass algorithm for improved numerical stability:
// pass1: mean = E[x]
// pass2: var = E[(x - mean)^2]
auto ndim = data->shape.size();
TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
auto reduce_axes = MakeReduceAxes(real_axis, data);
auto target_shape =
MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/false);
auto func = MakeTupleSumReducer();

auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func,
&data](const ffi::Array<Var>& indices) {
auto make_eval_range = [&real_axis, &reduce_axes,
ndim](const ffi::Array<Var>& non_reduce_indices) {
ffi::Array<PrimExpr> eval_range;
int arg_counter = 0;
int red_counter = 0;
Expand All @@ -80,34 +82,51 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor&
eval_range.push_back(reduce_axes[red_counter]);
red_counter++;
} else {
eval_range.push_back(indices[arg_counter]);
eval_range.push_back(non_reduce_indices[arg_counter]);
arg_counter++;
}
}
auto square = [is_float16](const PrimExpr& x) {
if (is_float16) {
return Cast(DataType::Float(32), x) * Cast(DataType::Float(32), x);
}
return x * x;
};
if (is_float16) {
return func({Cast(DataType::Float(32), data(eval_range)), square(data(eval_range))},
reduce_axes, nullptr);
} else {
return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr);
}
return eval_range;
};

auto temp_x_x2 =
tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce);
Tensor temp_sum = te::compute(
target_shape,
[is_float16, &data, &reduce_axes, &make_eval_range](const ffi::Array<Var>& indices) {
auto eval_range = make_eval_range(indices);
PrimExpr x = data(eval_range);
if (is_float16) {
x = Cast(DataType::Float(32), x);
}
return sum(x, reduce_axes);
},
data->op->name + "_sum", kCommReduce);

auto temp_x = temp_x_x2[0];
auto temp_x2 = temp_x_x2[1];

auto reduce_extent = make_const(data->dtype, 1);
DataType reduce_dtype = is_float16 ? DataType::Float(32) : data->dtype;
PrimExpr reduce_extent = make_const(reduce_dtype, 1);
for (int i : real_axis) {
reduce_extent *= data->shape[i];
}
Tensor temp_mean = te::compute(
target_shape,
[&temp_sum, &reduce_extent](const ffi::Array<Var>& indices) {
return temp_sum(indices) / reduce_extent;
},
data->op->name + "_mean", kInjective);

Tensor temp_var_sum = te::compute(
target_shape,
[is_float16, &data, &reduce_axes, &make_eval_range,
&temp_mean](const ffi::Array<Var>& indices) {
auto eval_range = make_eval_range(indices);
PrimExpr x = data(eval_range);
if (is_float16) {
x = Cast(DataType::Float(32), x);
}
PrimExpr diff = x - temp_mean(indices);
return sum(diff * diff, reduce_axes);
},
data->op->name + "_var_sum", kCommReduce);

auto layer_norm_func = [&](const ffi::Array<Var>& indices) {
ffi::Array<Var> reduce_indices, non_reduce_indices;
for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
Expand All @@ -117,9 +136,9 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor&
non_reduce_indices.push_back(indices[i]);
}
}
auto mean = temp_x(non_reduce_indices) / reduce_extent;
auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
auto layer_norm = (data(indices) - mean) * tvm::rsqrt(var + make_const(var->dtype, epsilon));
auto mean = temp_mean(non_reduce_indices);
auto var = temp_var_sum(non_reduce_indices) / reduce_extent;
auto layer_norm = (data(indices) - mean) * rsqrt(var + make_const(var->dtype, epsilon));
if (is_float16) {
layer_norm = Cast(DataType::Float(16), layer_norm);
}
Expand All @@ -129,7 +148,7 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor&
}
return layer_norm;
};
return tvm::te::compute(data->shape, layer_norm_func, name, tag);
return te::compute(data->shape, layer_norm_func, name, tag);
}

} // namespace nn
Expand Down
38 changes: 38 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2309,6 +2309,44 @@ def test_layer_norm_with_nd_gamma_beta():
check_correctness(model)


def test_layer_norm_numerical_stability():
"""Numerical stability test for https://github.com/apache/tvm/issues/19592."""
layer_norm_node = helper.make_node(
"LayerNormalization", ["input", "scale", "bias"], ["Y"], axis=-1, epsilon=1e-5
)
graph = helper.make_graph(
[layer_norm_node],
"layer_norm_numerical_stability",
inputs=[
helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 4]),
helper.make_tensor_value_info("scale", TensorProto.FLOAT, [4]),
helper.make_tensor_value_info("bias", TensorProto.FLOAT, [4]),
],
outputs=[
helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4]),
],
)
model = helper.make_model(graph, producer_name="layer_norm_numerical_stability")

input_array = np.array([[80000.0, 80001.0, 80002.0, 80003.0]], dtype=np.float32)
scale_array = np.ones(4, dtype=np.float32)
bias_array = np.zeros(4, dtype=np.float32)
inputs = {"input": input_array, "scale": scale_array, "bias": bias_array}

# ONNXRuntime also returns NaN for Large-value, small-variance inputs, so we here
# compare against a two-pass reference instead of ORT.
mean = input_array.mean(axis=-1, keepdims=True)
var = ((input_array - mean) ** 2).mean(axis=-1, keepdims=True)
expected = ((input_array - mean) / np.sqrt(var + 1e-5) * scale_array + bias_array).astype(
np.float32
)

tvm_output = run_in_tvm(model, inputs=inputs, ir_version=9, opset=17)

assert np.isfinite(tvm_output.numpy()).all()
tvm.testing.assert_allclose(tvm_output.numpy(), expected)


def test_rms_norm():
# Basic test: default axis=-1
rms_norm_node = helper.make_node("RMSNormalization", ["input", "scale"], ["Y"], epsilon=1e-05)
Expand Down
Loading
Loading