diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index e554648c41ad..33a22b34fcc0 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1634,6 +1634,12 @@ def _any(self, node: fx.Node) -> relax.Var: dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + # max doesn't support boolean tensors directly, so we compute it in int8 and cast back + if x.struct_info.dtype == "bool": + x = relax.op.astype(x, "int8") + ret = relax.op.max(x, dim, keepdims=keepdim) + return self.block_builder.emit(relax.op.astype(ret, "bool")) + # For boolean tensors, any is equivalent to max (checking if any element is True) return self.block_builder.emit(relax.op.max(x, dim, keepdims=keepdim)) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 662df5e76a62..7397b3f21aef 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1693,8 +1693,10 @@ def main( with R.dataflow(): lv: R.Tensor((10, 10, 1), dtype="float32") = R.reshape(x, R.shape([10, 10, 1])) lv1: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, test_elements) - lv2: R.Tensor((10, 10), dtype="bool") = R.max(lv1, axis=[-1], keepdims=False) - gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv2,) + lv2: R.Tensor((10, 10, 8), dtype="int8") = R.astype(lv1, dtype="int8") + lv3: R.Tensor((10, 10), dtype="int8") = R.max(lv2, axis=[-1], keepdims=False) + lv4: R.Tensor((10, 10), dtype="bool") = R.astype(lv3, dtype="bool") + gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv4,) R.output(gv) return gv @@ -4118,71 +4120,22 @@ def main( v: R.Tensor((32, 8, 128, 64), dtype="float32"), ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")): with R.dataflow(): - lv: R.Tensor((32, 8, 128, 64), dtype="float32") = R.multiply( - q, R.const(0.35355338454246521, "float32") + lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + q, axes=[0, 2, 1, 3] ) - lv1: R.Tensor((32, 8, 64, 128), dtype="float32") = R.permute_dims( - k, axes=[0, 1, 3, 2] + lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + k, axes=[0, 2, 1, 3] ) - lv2: R.Tensor((32, 8, 64, 128), dtype="float32") = R.multiply( - lv1, R.const(0.35355338454246521, "float32") + lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + v, axes=[0, 2, 1, 3] ) - lv3: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to( - lv, R.shape([32, 8, 128, 64]) + lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention( + lv, lv1, lv2, scale=None, causal_mask=None, window_size=None ) - lv4: R.Tensor((256, 128, 64), dtype="float32") = R.reshape( - lv3, R.shape([256, 128, 64]) + lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( + lv3, axes=[0, 2, 1, 3] ) - lv5: R.Tensor((32, 8, 64, 128), dtype="float32") = R.broadcast_to( - lv2, R.shape([32, 8, 64, 128]) - ) - lv6: R.Tensor((256, 64, 128), dtype="float32") = R.reshape( - lv5, R.shape([256, 64, 128]) - ) - lv7: R.Tensor((256, 128, 128), dtype="float32") = R.matmul( - lv4, lv6, out_dtype="float32" - ) - lv8: R.Tensor((32, 8, 128, 128), dtype="float32") = R.reshape( - lv7, R.shape([32, 8, 128, 128]) - ) - lv9: R.Tensor((32, 8, 128, 128), dtype="float32") = R.nn.softmax(lv8, axis=-1) - lv10: R.Tensor((32, 8, 128, 128), dtype="bool") = R.equal( - lv8, R.const(float("-inf"), "float32") - ) - lv11: R.Tensor((32, 8, 128, 128), dtype="bool") = R.logical_not(lv10) - lv12: R.Tensor((32, 8, 128, 1), dtype="bool") = R.max( - lv11, axis=[-1], keepdims=True - ) - lv13: R.Tensor((32, 8, 128, 1), dtype="bool") = R.logical_not(lv12) - lv14: R.Tensor((32, 8, 128, 128), dtype="float32") = R.full_like( - lv9, R.const(0, "int32"), dtype="void" - ) - lv15: R.Tensor((32, 8, 128, 128), dtype="float32") = R.where(lv13, lv14, lv9) - lv16: R.Tensor((32, 8, 128, 128), dtype="float32") = R.broadcast_to( - lv15, R.shape([32, 8, 128, 128]) - ) - lv17: R.Tensor((256, 128, 128), dtype="float32") = R.reshape( - lv16, R.shape([256, 128, 128]) - ) - lv18: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to( - v, R.shape([32, 8, 128, 64]) - ) - lv19: R.Tensor((256, 128, 64), dtype="float32") = R.reshape( - lv18, R.shape([256, 128, 64]) - ) - lv20: R.Tensor((256, 128, 64), dtype="float32") = R.matmul( - lv17, lv19, out_dtype="float32" - ) - lv21: R.Tensor((32, 8, 128, 64), dtype="float32") = R.reshape( - lv20, R.shape([32, 8, 128, 64]) - ) - lv22: R.Tensor((128, 32, 8, 64), dtype="float32") = R.permute_dims( - lv21, axes=[2, 0, 1, 3] - ) - lv23: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( - lv22, axes=[1, 2, 0, 3] - ) - gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv23,) + gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,) R.output(gv) return gv @@ -4200,72 +4153,22 @@ def main( mask: R.Tensor((32, 8, 128, 128), dtype="float32"), ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")): with R.dataflow(): - lv: R.Tensor((32, 8, 128, 64), dtype="float32") = R.multiply( - q, R.const(0.35355338454246521, "float32") - ) - lv1: R.Tensor((32, 8, 64, 128), dtype="float32") = R.permute_dims( - k, axes=[0, 1, 3, 2] - ) - lv2: R.Tensor((32, 8, 64, 128), dtype="float32") = R.multiply( - lv1, R.const(0.35355338454246521, "float32") - ) - lv3: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to( - lv, R.shape([32, 8, 128, 64]) - ) - lv4: R.Tensor((256, 128, 64), dtype="float32") = R.reshape( - lv3, R.shape([256, 128, 64]) - ) - lv5: R.Tensor((32, 8, 64, 128), dtype="float32") = R.broadcast_to( - lv2, R.shape([32, 8, 64, 128]) - ) - lv6: R.Tensor((256, 64, 128), dtype="float32") = R.reshape( - lv5, R.shape([256, 64, 128]) - ) - lv7: R.Tensor((256, 128, 128), dtype="float32") = R.matmul( - lv4, lv6, out_dtype="float32" - ) - lv8: R.Tensor((32, 8, 128, 128), dtype="float32") = R.reshape( - lv7, R.shape([32, 8, 128, 128]) + lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + q, axes=[0, 2, 1, 3] ) - lv9: R.Tensor((32, 8, 128, 128), dtype="float32") = R.add(lv8, mask) - lv10: R.Tensor((32, 8, 128, 128), dtype="float32") = R.nn.softmax(lv9, axis=-1) - lv11: R.Tensor((32, 8, 128, 128), dtype="bool") = R.equal( - lv9, R.const(float("-inf"), "float32") + lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + k, axes=[0, 2, 1, 3] ) - lv12: R.Tensor((32, 8, 128, 128), dtype="bool") = R.logical_not(lv11) - lv13: R.Tensor((32, 8, 128, 1), dtype="bool") = R.max( - lv12, axis=[-1], keepdims=True + lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + v, axes=[0, 2, 1, 3] ) - lv14: R.Tensor((32, 8, 128, 1), dtype="bool") = R.logical_not(lv13) - lv15: R.Tensor((32, 8, 128, 128), dtype="float32") = R.full_like( - lv10, R.const(0, "int32"), dtype="void" + lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention_bias( + lv, lv1, lv2, mask, scale=None, causal_mask=None, window_size=None ) - lv16: R.Tensor((32, 8, 128, 128), dtype="float32") = R.where(lv14, lv15, lv10) - lv17: R.Tensor((32, 8, 128, 128), dtype="float32") = R.broadcast_to( - lv16, R.shape([32, 8, 128, 128]) + lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( + lv3, axes=[0, 2, 1, 3] ) - lv18: R.Tensor((256, 128, 128), dtype="float32") = R.reshape( - lv17, R.shape([256, 128, 128]) - ) - lv19: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to( - v, R.shape([32, 8, 128, 64]) - ) - lv20: R.Tensor((256, 128, 64), dtype="float32") = R.reshape( - lv19, R.shape([256, 128, 64]) - ) - lv21: R.Tensor((256, 128, 64), dtype="float32") = R.matmul( - lv18, lv20, out_dtype="float32" - ) - lv22: R.Tensor((32, 8, 128, 64), dtype="float32") = R.reshape( - lv21, R.shape([32, 8, 128, 64]) - ) - lv23: R.Tensor((128, 32, 8, 64), dtype="float32") = R.permute_dims( - lv22, axes=[2, 0, 1, 3] - ) - lv24: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( - lv23, axes=[1, 2, 0, 3] - ) - gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv24,) + gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,) R.output(gv) return gv @@ -4278,7 +4181,7 @@ def main( ), {}, Expected1, - run_ep_decomposition=True, + run_ep_decomposition=False, ) verify_model( @@ -4291,7 +4194,7 @@ def main( ), {}, Expected2, - run_ep_decomposition=True, + run_ep_decomposition=False, ) # Test 2D input (seq_len, head_dim) - bug fix for #18441 @@ -7307,6 +7210,29 @@ def main( verify_model(Take(), example_args, {}, Expected) +def test_any(): + class AnyAten(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.any(x, dim=1) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3), dtype="bool"), + ) -> R.Tuple(R.Tensor((2,), dtype="bool")): + with R.dataflow(): + lv: R.Tensor((2, 3), dtype="int8") = relax.op.astype(x, dtype="int8") + lv2: R.Tensor((2,), dtype="int8") = relax.op.max(lv, axis=1, keepdims=False) + lv3: R.Tensor((2,), dtype="bool") = relax.op.astype(lv2, dtype="bool") + gv: R.Tuple(R.Tensor((2,), dtype="bool")) = (lv3,) + R.output(gv) + return gv + + example_args = (torch.tensor([[0, 0, 0], [0, 1, 0]], dtype=torch.bool),) + verify_model(AnyAten(), example_args, {}, Expected) + + def test_std(): class Std(Module): def forward(self, x):