diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 7a3b168fc8fd..b1b01b87f715 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2227,8 +2227,17 @@ def body_fn(*loop_inputs): # Add new scan outputs to tracking combined_scan_outputs = [] for i, scan in enumerate(scan_outputs): - new_scan = _op.expand_dims(new_scan_outputs[i], axis=0) - combined_scan = _op.concatenate([scan, new_scan], axis=0) + rank = len(infer_shape(scan)) - 1 + new_scan = new_scan_outputs[i] + expand_scan = _op.expand_dims(new_scan, axis=0) + # For non scalar outputs we need to broadcast the initial value. + if rank > 0: + new_scan_shape = _op.shape_of(new_scan, dtype=iter_dtype) + scan_broadcast = _op.concatenate( + [_op.reshape(loop_count, [1]), new_scan_shape], axis=0 + ) + scan = _op.broadcast_to(scan, scan_broadcast) + combined_scan = _op.concatenate([scan, expand_scan], axis=0) combined_scan_outputs.append(combined_scan) # Increment counter. diff --git a/python/tvm/topi/x86/injective.py b/python/tvm/topi/x86/injective.py index 29f903fd4e35..6492b78d6037 100644 --- a/python/tvm/topi/x86/injective.py +++ b/python/tvm/topi/x86/injective.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name """x86 declaration and schedules.""" from tvm import te +from tvm.tir import IntImm from ..utils import is_empty_shape @@ -100,18 +101,20 @@ def schedule_concatenate(outs): def vectorize(sch, tensor, vectorize_limit): """Internal vectorization function for concatenate.""" inner_axis = s[tensor].op.axis[len(s[tensor].op.axis) - 1] - inner_length = tensor.shape[len(tensor.shape) - 1].value - if inner_length <= vectorize_limit: - sch[tensor].vectorize(inner_axis) - else: - split_factor = 1 - for i in range(vectorize_limit, 1, -1): - if inner_length % i == 0: - split_factor = i - break - if split_factor > 1: - _, inner_i = sch[tensor].split(inner_axis, split_factor) - sch[tensor].vectorize(inner_i) + # Check that the tensor shape is static. Otherwise skip vectorization. + if isinstance(tensor.shape[len(tensor.shape) - 1], IntImm): + inner_length = tensor.shape[len(tensor.shape) - 1].value + if inner_length <= vectorize_limit: + sch[tensor].vectorize(inner_axis) + else: + split_factor = 1 + for i in range(vectorize_limit, 1, -1): + if inner_length % i == 0: + split_factor = i + break + if split_factor > 1: + _, inner_i = sch[tensor].split(inner_axis, split_factor) + sch[tensor].vectorize(inner_i) outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs x = outs[0] diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 20937d2060c5..c666604d0e89 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3654,14 +3654,14 @@ def verify_cond_loop(): def verify_count_loop(): - y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [1]) - y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [1]) - scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [1]) + y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, []) + y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, []) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, []) cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, []) - y = np.array([-2]).astype(np.float32) + y = np.array(-2).astype(np.float32) iter_cast_node = helper.make_node( "Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT @@ -3693,11 +3693,11 @@ def verify_count_loop(): inputs=[ onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []), onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), - onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1]), + onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, []), ], outputs=[ - onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [1]), - onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 1]), + onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, []), + onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5]), ], ) loop_model = onnx.helper.make_model(loop_graph) @@ -3708,11 +3708,69 @@ def verify_count_loop(): verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True) +def verify_tensor_loop(): + y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [3, 3, 3, 3]) + y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [3, 3, 3, 3]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [3, 3, 3, 3]) + cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) + cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) + iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, []) + + y = np.random.normal(size=[3, 3, 3, 3]).astype(np.float32) + + iter_cast_node = helper.make_node( + "Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT + ) + + y_add_node = helper.make_node("Add", inputs=["y_in", "iter_cast"], outputs=["y_out"]) + + identity_node = helper.make_node("Identity", inputs=["cond_in"], outputs=["cond_out"]) + + scan_identity_node = helper.make_node("Identity", inputs=["y_out"], outputs=["scan_out"]) + + loop_body = helper.make_graph( + [identity_node, iter_cast_node, y_add_node, scan_identity_node], + "loop_body", + [iter_count, cond_in, y_in], + [cond_out, y_out, scan_out], + ) + + loop_node = helper.make_node( + "Loop", inputs=["trip_count", "cond", "y"], outputs=["res_y", "res_scan"], body=loop_body + ) + + trip_count = np.array(5).astype(np.int64) + cond = np.array(1).astype(np.bool) + loop_graph = onnx.helper.make_graph( + [loop_node], + "loop_outer", + inputs=[ + onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []), + onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), + onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]), + onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 3, 3, 3, 3]), + ], + ) + loop_model = onnx.helper.make_model(loop_graph) + + trip_count = np.array(5).astype(np.int64) + cond = np.array(1).astype(np.bool) + input_vals = [trip_count, cond, y] + verify_with_ort_with_inputs( + loop_model, input_vals, use_vm=True, freeze_params=True, convert_to_static=True + ) + + def test_loop(): # Test a loop that exits once a condition is met. verify_cond_loop() - # Test a loop that exits after a fixed number of iterations. + # Test a loop that exits after a fixed number of iterations with scalar outputs. verify_count_loop() + # Test a loop that uses an array output. + verify_tensor_loop() def verify_if(cond_array):