diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h index 873a5fd1b2d2..d74bbce23f65 100644 --- a/include/tvm/topi/nn/layer_norm.h +++ b/include/tvm/topi/nn/layer_norm.h @@ -25,6 +25,7 @@ #define TVM_TOPI_NN_LAYER_NORM_H_ #include +#include #include #include @@ -59,17 +60,18 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& TVM_FFI_ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) << "layer_norm: only support float32 and float16 for now"; bool is_float16 = data_type == DataType::Float(16); - // sum x and x^2 + // Two-pass algorithm for improved numerical stability: + // pass1: mean = E[x] + // pass2: var = E[(x - mean)^2] auto ndim = data->shape.size(); TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); auto reduce_axes = MakeReduceAxes(real_axis, data); auto target_shape = MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/false); - auto func = MakeTupleSumReducer(); - auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func, - &data](const ffi::Array& indices) { + auto make_eval_range = [&real_axis, &reduce_axes, + ndim](const ffi::Array& non_reduce_indices) { ffi::Array eval_range; int arg_counter = 0; int red_counter = 0; @@ -80,34 +82,51 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& eval_range.push_back(reduce_axes[red_counter]); red_counter++; } else { - eval_range.push_back(indices[arg_counter]); + eval_range.push_back(non_reduce_indices[arg_counter]); arg_counter++; } } - auto square = [is_float16](const PrimExpr& x) { - if (is_float16) { - return Cast(DataType::Float(32), x) * Cast(DataType::Float(32), x); - } - return x * x; - }; - if (is_float16) { - return func({Cast(DataType::Float(32), data(eval_range)), square(data(eval_range))}, - reduce_axes, nullptr); - } else { - return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr); - } + return eval_range; }; - auto temp_x_x2 = - tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce); + Tensor temp_sum = te::compute( + target_shape, + [is_float16, &data, &reduce_axes, &make_eval_range](const ffi::Array& indices) { + auto eval_range = make_eval_range(indices); + PrimExpr x = data(eval_range); + if (is_float16) { + x = Cast(DataType::Float(32), x); + } + return sum(x, reduce_axes); + }, + data->op->name + "_sum", kCommReduce); - auto temp_x = temp_x_x2[0]; - auto temp_x2 = temp_x_x2[1]; - - auto reduce_extent = make_const(data->dtype, 1); + DataType reduce_dtype = is_float16 ? DataType::Float(32) : data->dtype; + PrimExpr reduce_extent = make_const(reduce_dtype, 1); for (int i : real_axis) { reduce_extent *= data->shape[i]; } + Tensor temp_mean = te::compute( + target_shape, + [&temp_sum, &reduce_extent](const ffi::Array& indices) { + return temp_sum(indices) / reduce_extent; + }, + data->op->name + "_mean", kInjective); + + Tensor temp_var_sum = te::compute( + target_shape, + [is_float16, &data, &reduce_axes, &make_eval_range, + &temp_mean](const ffi::Array& indices) { + auto eval_range = make_eval_range(indices); + PrimExpr x = data(eval_range); + if (is_float16) { + x = Cast(DataType::Float(32), x); + } + PrimExpr diff = x - temp_mean(indices); + return sum(diff * diff, reduce_axes); + }, + data->op->name + "_var_sum", kCommReduce); + auto layer_norm_func = [&](const ffi::Array& indices) { ffi::Array reduce_indices, non_reduce_indices; for (int i = 0, n = static_cast(indices.size()); i < n; ++i) { @@ -117,9 +136,9 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& non_reduce_indices.push_back(indices[i]); } } - auto mean = temp_x(non_reduce_indices) / reduce_extent; - auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; - auto layer_norm = (data(indices) - mean) * tvm::rsqrt(var + make_const(var->dtype, epsilon)); + auto mean = temp_mean(non_reduce_indices); + auto var = temp_var_sum(non_reduce_indices) / reduce_extent; + auto layer_norm = (data(indices) - mean) * rsqrt(var + make_const(var->dtype, epsilon)); if (is_float16) { layer_norm = Cast(DataType::Float(16), layer_norm); } @@ -129,7 +148,7 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& } return layer_norm; }; - return tvm::te::compute(data->shape, layer_norm_func, name, tag); + return te::compute(data->shape, layer_norm_func, name, tag); } } // namespace nn diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 427881243663..7ee10993a4e9 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2309,6 +2309,44 @@ def test_layer_norm_with_nd_gamma_beta(): check_correctness(model) +def test_layer_norm_numerical_stability(): + """Numerical stability test for https://github.com/apache/tvm/issues/19592.""" + layer_norm_node = helper.make_node( + "LayerNormalization", ["input", "scale", "bias"], ["Y"], axis=-1, epsilon=1e-5 + ) + graph = helper.make_graph( + [layer_norm_node], + "layer_norm_numerical_stability", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 4]), + helper.make_tensor_value_info("scale", TensorProto.FLOAT, [4]), + helper.make_tensor_value_info("bias", TensorProto.FLOAT, [4]), + ], + outputs=[ + helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4]), + ], + ) + model = helper.make_model(graph, producer_name="layer_norm_numerical_stability") + + input_array = np.array([[80000.0, 80001.0, 80002.0, 80003.0]], dtype=np.float32) + scale_array = np.ones(4, dtype=np.float32) + bias_array = np.zeros(4, dtype=np.float32) + inputs = {"input": input_array, "scale": scale_array, "bias": bias_array} + + # ONNXRuntime also returns NaN for Large-value, small-variance inputs, so we here + # compare against a two-pass reference instead of ORT. + mean = input_array.mean(axis=-1, keepdims=True) + var = ((input_array - mean) ** 2).mean(axis=-1, keepdims=True) + expected = ((input_array - mean) / np.sqrt(var + 1e-5) * scale_array + bias_array).astype( + np.float32 + ) + + tvm_output = run_in_tvm(model, inputs=inputs, ir_version=9, opset=17) + + assert np.isfinite(tvm_output.numpy()).all() + tvm.testing.assert_allclose(tvm_output.numpy(), expected) + + def test_rms_norm(): # Basic test: default axis=-1 rms_norm_node = helper.make_node("RMSNormalization", ["input", "scale"], ["Y"], epsilon=1e-05) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 6badc7fc3324..4a708b5da1f4 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -2734,28 +2734,40 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32"), gamma: R.Tensor((4, 5), "float32" return gv @T.prim_func(private=True, s_tir=True) - def layer_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(5)), "float32"), rxplaceholder_2: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_layer_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")): + def layer_norm(x: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), gamma: T.Buffer((T.int64(4), T.int64(5)), "float32"), beta: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_layer_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")): T.func_attr({"tirx.noalias": True}) - rxplaceholder_red_temp_v0 = T.sblock_alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") - rxplaceholder_red_temp_v1 = T.sblock_alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") - for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): - with T.sblock("rxplaceholder_red_temp"): - ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) - T.reads(rxplaceholder[ax0, ax1, k2, k3]) - T.writes(rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1]) + # with T.sblock("root"): + x_sum = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) + x_mean = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) + x_var_sum = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) + for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.sblock("x_sum"): + v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3]) + T.reads(x[v_ax0, v_ax1, v_k2, v_k3]) + T.writes(x_sum[v_ax0, v_ax1]) + with T.init(): + x_sum[v_ax0, v_ax1] = T.float32(0.0) + x_sum[v_ax0, v_ax1] = x_sum[v_ax0, v_ax1] + x[v_ax0, v_ax1, v_k2, v_k3] + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.sblock("x_mean"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x_sum[v_ax0, v_ax1]) + T.writes(x_mean[v_ax0, v_ax1]) + x_mean[v_ax0, v_ax1] = x_sum[v_ax0, v_ax1] / T.float32(20.0) + for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.sblock("x_var_sum"): + v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3]) + T.reads(x[v_ax0, v_ax1, v_k2, v_k3], x_mean[v_ax0, v_ax1]) + T.writes(x_var_sum[v_ax0, v_ax1]) with T.init(): - rxplaceholder_red_temp_v0[ax0, ax1] = T.float32(0) - rxplaceholder_red_temp_v1[ax0, ax1] = T.float32(0) - v_rxplaceholder_red_temp_v0: T.let[T.float32] = rxplaceholder_red_temp_v0[ax0, ax1] + rxplaceholder[ax0, ax1, k2, k3] - v_rxplaceholder_red_temp_v1: T.let[T.float32] = rxplaceholder_red_temp_v1[ax0, ax1] + rxplaceholder[ax0, ax1, k2, k3] * rxplaceholder[ax0, ax1, k2, k3] - rxplaceholder_red_temp_v0[ax0, ax1] = v_rxplaceholder_red_temp_v0 - rxplaceholder_red_temp_v1[ax0, ax1] = v_rxplaceholder_red_temp_v1 - for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + x_var_sum[v_ax0, v_ax1] = T.float32(0.0) + x_var_sum[v_ax0, v_ax1] = x_var_sum[v_ax0, v_ax1] + (x[v_ax0, v_ax1, v_k2, v_k3] - x_mean[v_ax0, v_ax1]) * (x[v_ax0, v_ax1, v_k2, v_k3] - x_mean[v_ax0, v_ax1]) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.sblock("T_layer_norm"): - ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[ax0, ax1, ax2, ax3], rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], rxplaceholder_1[ax2, ax3], rxplaceholder_2[ax2, ax3]) - T.writes(T_layer_norm[ax0, ax1, ax2, ax3]) - T_layer_norm[ax0, ax1, ax2, ax3] = (rxplaceholder[ax0, ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20)) * T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] / T.float32(20) - rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20) * (rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20)) + T.float32(1e-05), dtype="float32") * rxplaceholder_1[ax2, ax3] + rxplaceholder_2[ax2, ax3] + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], x_mean[v_ax0, v_ax1], x_var_sum[v_ax0, v_ax1], gamma[v_ax2, v_ax3], beta[v_ax2, v_ax3]) + T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3]) + T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (x[v_ax0, v_ax1, v_ax2, v_ax3] - x_mean[v_ax0, v_ax1]) * T.rsqrt(x_var_sum[v_ax0, v_ax1] / T.float32(20.0) + T.float32(1.0000000000000001e-05)) * gamma[v_ax2, v_ax3] + beta[v_ax2, v_ax3] # fmt: on mod = LegalizeOps()(LayerNorm) tvm.ir.assert_structural_equal(mod, Expected) @@ -2780,26 +2792,36 @@ class LayerNorm_1D_Expected: def layer_norm(x: T.Buffer((T.int64(3),), "float32"), layer_norm_weight: T.Buffer((T.int64(3),), "float32"), layer_norm_bias: T.Buffer((T.int64(3),), "float32"), T_layer_norm: T.Buffer((T.int64(3),), "float32")): T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): - x_red_temp_v0 = T.sblock_alloc_buffer(()) - x_red_temp_v1 = T.sblock_alloc_buffer(()) + x_sum = T.sblock_alloc_buffer(()) + x_mean = T.sblock_alloc_buffer(()) + x_var_sum = T.sblock_alloc_buffer(()) for k0 in range(T.int64(3)): - with T.sblock("x_red_temp"): + with T.sblock("x_sum"): v_k0 = T.axis.reduce(T.int64(3), k0) T.reads(x[v_k0]) - T.writes(x_red_temp_v0[()], x_red_temp_v1[()]) + T.writes(x_sum[()]) + with T.init(): + x_sum[()] = T.float32(0.0) + x_sum[()] = x_sum[()] + x[v_k0] + with T.sblock("x_mean"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(x_sum[()]) + T.writes(x_mean[()]) + x_mean[()] = x_sum[()] / T.float32(3.0) + for k0 in range(T.int64(3)): + with T.sblock("x_var_sum"): + v_k0 = T.axis.reduce(T.int64(3), k0) + T.reads(x[v_k0], x_mean[()]) + T.writes(x_var_sum[()]) with T.init(): - x_red_temp_v0[()] = T.float32(0.0) - x_red_temp_v1[()] = T.float32(0.0) - v_x_red_temp_v0: T.let[T.float32] = x_red_temp_v0[()] + x[v_k0] - v_x_red_temp_v1: T.let[T.float32] = x_red_temp_v1[()] + x[v_k0] * x[v_k0] - x_red_temp_v0[()] = v_x_red_temp_v0 - x_red_temp_v1[()] = v_x_red_temp_v1 + x_var_sum[()] = T.float32(0.0) + x_var_sum[()] = x_var_sum[()] + (x[v_k0] - x_mean[()]) * (x[v_k0] - x_mean[()]) for ax0 in range(T.int64(3)): with T.sblock("T_layer_norm"): v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(x[v_ax0], x_red_temp_v0[()], x_red_temp_v1[()], layer_norm_weight[v_ax0], layer_norm_bias[v_ax0]) + T.reads(x[v_ax0], x_mean[()], x_var_sum[()], layer_norm_weight[v_ax0], layer_norm_bias[v_ax0]) T.writes(T_layer_norm[v_ax0]) - T_layer_norm[v_ax0] = (x[v_ax0] - x_red_temp_v0[()] / T.float32(3)) * T.rsqrt(x_red_temp_v1[()] / T.float32(3) - x_red_temp_v0[()] / T.float32(3) * (x_red_temp_v0[()] / T.float32(3)) + T.float32(1.0000000000000001e-05)) * layer_norm_weight[v_ax0] + layer_norm_bias[v_ax0] + T_layer_norm[v_ax0] = (x[v_ax0] - x_mean[()]) * T.rsqrt(x_var_sum[()] / T.float32(3.0) + T.float32(1.0000000000000001e-05)) * layer_norm_weight[v_ax0] + layer_norm_bias[v_ax0] @R.function def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), dtype="float32")) -> R.Tensor((3,), dtype="float32"): @@ -2827,47 +2849,45 @@ def main(x: R.Tensor((2, 3, 4, 5), "float16"), gamma: R.Tensor((4, 5), "float16" @I.ir_module(s_tir=True) class Expected: @T.prim_func(private=True, s_tir=True) - def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle): + def layer_norm( + x: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"), + gamma: T.Buffer((T.int64(4), T.int64(5)), "float16"), + beta: T.Buffer((T.int64(4), T.int64(5)), "float16"), + T_layer_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"), + ): T.func_attr({"tirx.noalias": True}) - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16") - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4), T.int64(5)), "float16") - rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4), T.int64(5)), "float16") - T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16") - with T.sblock("root"): - T.reads() - T.writes() - rxplaceholder_red_temp_v0 = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) - rxplaceholder_red_temp_v1 = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) - for ax0 in range(T.int64(2)): - for ax1 in range(T.int64(3)): - for k2 in range(T.int64(4)): - for k3 in range(T.int64(5)): - with T.sblock("rxplaceholder_red_temp"): - v_ax0 = T.axis.spatial(T.int64(2), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_k2 = T.axis.reduce(T.int64(4), k2) - v_k3 = T.axis.reduce(T.int64(5), k3) - T.reads(rxplaceholder[v_ax0, v_ax1, v_k2, v_k3]) - T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) - with T.init(): - rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0) - rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - v_rxplaceholder_red_temp_v0: T.let[T.float32] = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_k2, v_k3]) - v_rxplaceholder_red_temp_v1: T.let[T.float32] = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_k2, v_k3]) * T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_k2, v_k3]) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 - rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 - for ax0 in range(T.int64(2)): - for ax1 in range(T.int64(3)): - for ax2 in range(T.int64(4)): - for ax3 in range(T.int64(5)): - with T.sblock("T_layer_norm"): - v_ax0 = T.axis.spatial(T.int64(2), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(4), ax2) - v_ax3 = T.axis.spatial(T.int64(5), ax3) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], rxplaceholder_1[v_ax2, v_ax3], rxplaceholder_2[v_ax2, v_ax3]) - T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3]) - T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float16", (T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3]) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * T.float16(5))) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * T.float16(5)) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * T.float16(5)) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) * T.float16(5))) + T.float32(1.0000000000000001e-05))) * rxplaceholder_1[v_ax2, v_ax3] + rxplaceholder_2[v_ax2, v_ax3] + # with T.sblock("root"): + x_sum = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) + x_mean = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) + x_var_sum = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) + for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.sblock("x_sum"): + v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3]) + T.reads(x[v_ax0, v_ax1, v_k2, v_k3]) + T.writes(x_sum[v_ax0, v_ax1]) + with T.init(): + x_sum[v_ax0, v_ax1] = T.float32(0.0) + x_sum[v_ax0, v_ax1] = x_sum[v_ax0, v_ax1] + T.Cast("float32", x[v_ax0, v_ax1, v_k2, v_k3]) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.sblock("x_mean"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x_sum[v_ax0, v_ax1]) + T.writes(x_mean[v_ax0, v_ax1]) + x_mean[v_ax0, v_ax1] = x_sum[v_ax0, v_ax1] / T.float32(20.0) + for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.sblock("x_var_sum"): + v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3]) + T.reads(x[v_ax0, v_ax1, v_k2, v_k3], x_mean[v_ax0, v_ax1]) + T.writes(x_var_sum[v_ax0, v_ax1]) + with T.init(): + x_var_sum[v_ax0, v_ax1] = T.float32(0.0) + x_var_sum[v_ax0, v_ax1] = x_var_sum[v_ax0, v_ax1] + (T.Cast("float32", x[v_ax0, v_ax1, v_k2, v_k3]) - x_mean[v_ax0, v_ax1]) * (T.Cast("float32", x[v_ax0, v_ax1, v_k2, v_k3]) - x_mean[v_ax0, v_ax1]) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.sblock("T_layer_norm"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], x_mean[v_ax0, v_ax1], x_var_sum[v_ax0, v_ax1], gamma[v_ax2, v_ax3], beta[v_ax2, v_ax3]) + T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3]) + T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float16", (T.Cast("float32", x[v_ax0, v_ax1, v_ax2, v_ax3]) - x_mean[v_ax0, v_ax1]) * T.rsqrt(x_var_sum[v_ax0, v_ax1] / T.float32(20.0) + T.float32(1.0000000000000001e-05))) * gamma[v_ax2, v_ax3] + beta[v_ax2, v_ax3] @R.function def main(x: R.Tensor((2, 3, 4, 5), dtype="float16"), gamma: R.Tensor((4, 5), dtype="float16"), beta: R.Tensor((4, 5), dtype="float16")) -> R.Tensor((2, 3, 4, 5), dtype="float16"): @@ -2901,35 +2921,45 @@ def main(x: R.Tensor(("n", "s", "f"), "float32"), gamma: R.Tensor(("s", "f"), "f return gv @T.prim_func(private=True, s_tir=True) - def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle): + def layer_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_T_layer_norm: T.handle): T.func_attr({"tirx.noalias": True}) - f = T.int64() - n = T.int64() - s = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [n, s, f], dtype="float32") - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [s, f], dtype="float32") - rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [s, f], dtype="float32") - T_layer_norm = T.match_buffer(var_T_layer_norm, [n, s, f], dtype="float32") - rxplaceholder_red_temp_v0 = T.sblock_alloc_buffer([n], dtype="float32") - rxplaceholder_red_temp_v1 = T.sblock_alloc_buffer([n], dtype="float32") - for i0, i1, i2 in T.grid(n, s, f): - with T.sblock("rxplaceholder_red_temp"): - ax0, k1, k2 = T.axis.remap("SRR", [i0, i1, i2]) - T.reads(rxplaceholder[ax0, k1, k2]) - T.writes(rxplaceholder_red_temp_v0[ax0], rxplaceholder_red_temp_v1[ax0]) + n, s, f = T.int64(), T.int64(), T.int64() + x = T.match_buffer(var_x, (n, s, f)) + gamma = T.match_buffer(var_gamma, (s, f)) + beta = T.match_buffer(var_beta, (s, f)) + T_layer_norm = T.match_buffer(var_T_layer_norm, (n, s, f)) + # with T.sblock("root"): + x_sum = T.sblock_alloc_buffer((n,)) + x_mean = T.sblock_alloc_buffer((n,)) + x_var_sum = T.sblock_alloc_buffer((n,)) + for ax0, k1, k2 in T.grid(n, s, f): + with T.sblock("x_sum"): + v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2]) + T.reads(x[v_ax0, v_k1, v_k2]) + T.writes(x_sum[v_ax0]) with T.init(): - rxplaceholder_red_temp_v0[ax0] = T.float32(0) - rxplaceholder_red_temp_v1[ax0] = T.float32(0) - v_rxplaceholder_red_temp_v0: T.let[T.float32] = rxplaceholder_red_temp_v0[ax0] + rxplaceholder[ax0, k1, k2] - v_rxplaceholder_red_temp_v1: T.let[T.float32] = rxplaceholder_red_temp_v1[ax0] + rxplaceholder[ax0, k1, k2] * rxplaceholder[ax0, k1, k2] - rxplaceholder_red_temp_v0[ax0] = v_rxplaceholder_red_temp_v0 - rxplaceholder_red_temp_v1[ax0] = v_rxplaceholder_red_temp_v1 - for i0, i1, i2 in T.grid(n, s, f): + x_sum[v_ax0] = T.float32(0.0) + x_sum[v_ax0] = x_sum[v_ax0] + x[v_ax0, v_k1, v_k2] + for ax0 in range(n): + with T.sblock("x_mean"): + v_ax0 = T.axis.spatial(n, ax0) + T.reads(x_sum[v_ax0]) + T.writes(x_mean[v_ax0]) + x_mean[v_ax0] = x_sum[v_ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) + for ax0, k1, k2 in T.grid(n, s, f): + with T.sblock("x_var_sum"): + v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2]) + T.reads(x[v_ax0, v_k1, v_k2], x_mean[v_ax0]) + T.writes(x_var_sum[v_ax0]) + with T.init(): + x_var_sum[v_ax0] = T.float32(0.0) + x_var_sum[v_ax0] = x_var_sum[v_ax0] + (x[v_ax0, v_k1, v_k2] - x_mean[v_ax0]) * (x[v_ax0, v_k1, v_k2] - x_mean[v_ax0]) + for ax0, ax1, ax2 in T.grid(n, s, f): with T.sblock("T_layer_norm"): - ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_red_temp_v0[ax0], rxplaceholder_red_temp_v1[ax0], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, ax2]) - T.writes(T_layer_norm[ax0, ax1, ax2]) - T_layer_norm[ax0, ax1, ax2] = (rxplaceholder[ax0, ax1, ax2] - rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32", f))) * T.rsqrt(rxplaceholder_red_temp_v1[ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) - rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) * (rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32", f))) + T.float32(1e-05), dtype="float32") * rxplaceholder_1[ax1, ax2] + rxplaceholder_2[ax1, ax2] + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(x[v_ax0, v_ax1, v_ax2], x_mean[v_ax0], x_var_sum[v_ax0], gamma[v_ax1, v_ax2], beta[v_ax1, v_ax2]) + T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) + T_layer_norm[v_ax0, v_ax1, v_ax2] = (x[v_ax0, v_ax1, v_ax2] - x_mean[v_ax0]) * T.rsqrt(x_var_sum[v_ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) + T.float32(1.0000000000000001e-05)) * gamma[v_ax1, v_ax2] + beta[v_ax1, v_ax2] # fmt: on mod = LegalizeOps()(LayerNorm) tvm.ir.assert_structural_equal(mod, Expected)