diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 826711538c68..c768ea19af7d 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -74,7 +74,8 @@ StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { if (data_sinfo->IsUnknownNdim()) { return data_sinfo; } - if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float() && + !data_sinfo->dtype.is_bfloat()) { ctx->ReportFatal(Diagnostic::Error(call) << "Softmax requires the input tensor to have float " "dtype. However, the given input dtype is " << data_sinfo->dtype); diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index 7adfc8428355..ec4551872fc8 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -15,10 +15,10 @@ # specific language governing permissions and limitations # under the License. import pytest + import tvm import tvm.testing -from tvm import relax, tir -from tvm import TVMError +from tvm import TVMError, relax, tir from tvm.ir import Op, VDevice from tvm.script import relax as R @@ -143,6 +143,7 @@ def test_softmax_log_softmax_infer_struct_info(): x3 = relax.Var("x", R.Tensor((2, 3))) x4 = relax.Var("x", R.Tensor()) x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0)) + x6 = relax.Var("x", R.Tensor((2, 3), "bfloat16")) _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float32")) _check_inference(bb, relax.op.nn.softmax(x5), relax.TensorStructInfo((2, 3), "float32", vdev0)) @@ -164,6 +165,10 @@ def test_softmax_log_softmax_infer_struct_info(): bb, relax.op.nn.log_softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype="") ) _check_inference(bb, relax.op.nn.log_softmax(x4, axis=-2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.nn.softmax(x6), relax.TensorStructInfo((2, 3), dtype="bfloat16")) + _check_inference( + bb, relax.op.nn.log_softmax(x6), relax.TensorStructInfo((2, 3), dtype="bfloat16") + ) def test_softmax_log_softmax_infer_struct_info_shape_symbolic():