33
44import numpy as np
55
6- import onnx
7- import onnx .numpy_helper
8- import onnx .shape_inference
6+ from onnx import TensorProto , TypeProto
7+ from onnx .checker import ValidationError
8+ from onnx .defs import OpSchema , get_all_schemas_with_history , get_schema
9+ from onnx .helper import (
10+ make_graph ,
11+ make_node ,
12+ make_opsetid ,
13+ make_tensor_type_proto ,
14+ make_tensor_value_info ,
15+ )
16+ from onnx .numpy_helper import from_array
17+ from onnx .shape_inference import InferenceError , infer_node_outputs
918
# Most recent opset version of the core-domain "Add" operator schema.
ADD_SCHEMA = max(
    (
        schema
        for schema in get_all_schemas_with_history()
        if schema.name == "Add" and schema.domain == ""
    ),
    key=lambda schema: schema.since_version,
)
# Most recent opset version of the core-domain "Reshape" operator schema.
RESHAPE_SCHEMA = max(
    (
        schema
        for schema in get_all_schemas_with_history()
        if schema.name == "Reshape" and schema.domain == ""
    ),
    key=lambda schema: schema.since_version,
)
2631
27- _tensor = onnx .helper .make_tensor_type_proto
28-
2932
def _to_tensor_types(
    tensor_types: Dict[str, Tuple[int, Tuple[Union[int, str, None], ...]]]
) -> Dict[str, TypeProto]:
    """Convert ``{name: (elem_type, shape)}`` pairs into ``TypeProto`` values.

    Shape entries may be ints (static dims), strings (symbolic dims), or
    ``None`` (unknown dims), as accepted by ``make_tensor_type_proto``.
    """
    converted: Dict[str, TypeProto] = {}
    for name, (elem_type, shape) in tensor_types.items():
        converted[name] = make_tensor_type_proto(elem_type, shape)
    return converted
3737
3838
def _run_case(
    schema: OpSchema,
    input_names: List[str],
    output_names: List[str],
    input_types: Dict[str, TypeProto],
    input_data: Optional[Dict[str, np.ndarray]] = None,
) -> Dict[str, TypeProto]:
    """Run node-level shape inference for a single op and return output types.

    Builds a bare node from *schema* and delegates to ``infer_node_outputs``;
    *input_data* (if given) is converted to initializer tensors so inference
    can use concrete values (e.g. Reshape's target shape).
    """
    data = {} if input_data is None else input_data
    node = make_node(schema.name, input_names, output_names, domain=schema.domain)
    initializers = {name: from_array(array) for name, array in data.items()}
    return infer_node_outputs(schema, node, input_types, initializers)
5654
5755
5856class TestInferenceFunctionCall (unittest .TestCase ):
5957 def test_add_inference (self ) -> None :
6058 cases = [
6159 (
62- {"A" : (onnx . TensorProto .FLOAT , ()), "B" : (onnx . TensorProto .FLOAT , ())},
63- {"C" : (onnx . TensorProto .FLOAT , ())},
60+ {"A" : (TensorProto .FLOAT , ()), "B" : (TensorProto .FLOAT , ())},
61+ {"C" : (TensorProto .FLOAT , ())},
6462 ),
6563 (
6664 {
67- "A" : (onnx . TensorProto .FLOAT , (None , 2 )),
68- "B" : (onnx . TensorProto .FLOAT , (2 ,)),
65+ "A" : (TensorProto .FLOAT , (None , 2 )),
66+ "B" : (TensorProto .FLOAT , (2 ,)),
6967 },
70- {"C" : (onnx . TensorProto .FLOAT , (None , 2 ))},
68+ {"C" : (TensorProto .FLOAT , (None , 2 ))},
7169 ),
7270 (
7371 {
74- "A" : (onnx . TensorProto .FLOAT , (None , 2 )),
75- "B" : (onnx . TensorProto .FLOAT , (1 , 2 )),
72+ "A" : (TensorProto .FLOAT , (None , 2 )),
73+ "B" : (TensorProto .FLOAT , (1 , 2 )),
7674 },
77- {"C" : (onnx . TensorProto .FLOAT , (None , 2 ))},
75+ {"C" : (TensorProto .FLOAT , (None , 2 ))},
7876 ),
7977 (
8078 {
81- "A" : (onnx . TensorProto .DOUBLE , ("n" , "m" )),
82- "B" : (onnx . TensorProto .DOUBLE , (1 , "n" , "m" )),
79+ "A" : (TensorProto .DOUBLE , ("n" , "m" )),
80+ "B" : (TensorProto .DOUBLE , (1 , "n" , "m" )),
8381 },
84- {"C" : (onnx . TensorProto .DOUBLE , (1 , "n" , "m" ))},
82+ {"C" : (TensorProto .DOUBLE , (1 , "n" , "m" ))},
8583 ),
8684 (
8785 {
88- "A" : (onnx . TensorProto .FLOAT , ("x" , 2 )),
89- "B" : (onnx . TensorProto .FLOAT , ("y" , 2 )),
86+ "A" : (TensorProto .FLOAT , ("x" , 2 )),
87+ "B" : (TensorProto .FLOAT , ("y" , 2 )),
9088 },
91- {"C" : (onnx . TensorProto .FLOAT , (None , 2 ))},
89+ {"C" : (TensorProto .FLOAT , (None , 2 ))},
9290 ),
9391 ]
9492 for ins , outs in cases :
9593 assert _run_case (ADD_SCHEMA , ["A" , "B" ], ["C" ], _to_tensor_types (ins )) == _to_tensor_types (outs ) # type: ignore
9694
9795 def test_add_inference_raises_errors (self ) -> None :
98- with self .assertRaises (onnx . checker . ValidationError ):
96+ with self .assertRaises (ValidationError ):
9997 _run_case (
10098 ADD_SCHEMA ,
10199 ["A" ],
102100 ["C" ],
103- _to_tensor_types ({"A" : (onnx . TensorProto .FLOAT , (3 , 4 ))}),
101+ _to_tensor_types ({"A" : (TensorProto .FLOAT , (3 , 4 ))}),
104102 )
105- with self .assertRaises (onnx . checker . ValidationError ):
103+ with self .assertRaises (ValidationError ):
106104 _run_case (
107105 ADD_SCHEMA ,
108106 ["A" , "B" ],
109107 ["C" ],
110- _to_tensor_types (
111- {"A" : (onnx .TensorProto .FLOAT , (3 , 4 )), "B" : (2 , (3 , 4 ))}
112- ),
108+ _to_tensor_types ({"A" : (TensorProto .FLOAT , (3 , 4 )), "B" : (2 , (3 , 4 ))}),
113109 )
114- with self .assertRaises (onnx . shape_inference . InferenceError ):
110+ with self .assertRaises (InferenceError ):
115111 _run_case (
116112 ADD_SCHEMA ,
117113 ["A" , "B" ],
118114 ["C" ],
119115 _to_tensor_types (
120116 {
121- "A" : (onnx . TensorProto .FLOAT , (2 , 4 )),
122- "B" : (onnx . TensorProto .FLOAT , (3 , 4 )),
117+ "A" : (TensorProto .FLOAT , (2 , 4 )),
118+ "B" : (TensorProto .FLOAT , (3 , 4 )),
123119 }
124120 ),
125121 )
@@ -128,7 +124,7 @@ def test_add_inference_raises_errors(self) -> None:
128124 ADD_SCHEMA ,
129125 ["A" , "B" ],
130126 ["C" ],
131- _to_tensor_types ({"A" : (onnx . TensorProto .FLOAT , (3 , 4 ))}),
127+ _to_tensor_types ({"A" : (TensorProto .FLOAT , (3 , 4 ))}),
132128 )
133129
134130 def test_reshape_inference (self ) -> None :
@@ -138,12 +134,63 @@ def test_reshape_inference(self) -> None:
138134 ["y" ],
139135 _to_tensor_types (
140136 {
141- "x" : (onnx . TensorProto .FLOAT , (5 , 4 )),
142- "t" : (onnx . TensorProto .INT64 , (3 ,)),
137+ "x" : (TensorProto .FLOAT , (5 , 4 )),
138+ "t" : (TensorProto .INT64 , (3 ,)),
143139 }
144140 ),
145141 {"t" : np .array ([2 , 2 , 5 ], dtype = np .int64 )},
146- ) == _to_tensor_types ({"y" : (onnx .TensorProto .FLOAT , (2 , 2 , 5 ))})
142+ ) == _to_tensor_types ({"y" : (TensorProto .FLOAT , (2 , 2 , 5 ))})
143+
144+ def test_scan_inference_with_subgraph (self ) -> None :
145+ seq_len = "sequence"
146+ input_size = 2
147+ loop_state_size = 3
148+
149+ input_value_infos = [
150+ make_tensor_value_info ("loop_state_in" , TensorProto .UNDEFINED , None ),
151+ make_tensor_value_info ("input" , TensorProto .UNDEFINED , None ),
152+ make_tensor_value_info ("outer" , TensorProto .UNDEFINED , None ),
153+ ]
154+ output_value_infos = [
155+ make_tensor_value_info ("loop_state_out" , TensorProto .UNDEFINED , None ),
156+ make_tensor_value_info ("output" , TensorProto .FLOAT , (seq_len , input_size )),
157+ ]
158+
159+ subgraph = make_graph (
160+ [
161+ make_node ("Identity" , ["loop_state_in" ], ["loop_state_out" ]),
162+ make_node ("Add" , ["input" , "outer" ], ["output" ]),
163+ ],
164+ "subgraph" ,
165+ input_value_infos ,
166+ output_value_infos ,
167+ )
168+
169+ assert infer_node_outputs (
170+ get_schema ("Scan" , 9 ),
171+ make_node (
172+ "Scan" ,
173+ ["loop_state_orig" , "scan_input" , "scan_outer" ],
174+ ["loop_state_final" , "scan_output" ],
175+ num_scan_inputs = 1 ,
176+ body = subgraph ,
177+ ),
178+ _to_tensor_types (
179+ {
180+ "loop_state_orig" : (TensorProto .FLOAT , (loop_state_size ,)),
181+ "scan_input" : (TensorProto .FLOAT , (seq_len , input_size )),
182+ "scan_outer" : (TensorProto .FLOAT , (input_size ,)),
183+ }
184+ ),
185+ # Same as default value in Scan-9
186+ opset_imports = [make_opsetid ("" , 9 )],
187+ ir_version = 4 ,
188+ ) == _to_tensor_types (
189+ {
190+ "loop_state_final" : (TensorProto .FLOAT , (loop_state_size ,)),
191+ "scan_output" : (TensorProto .FLOAT , (seq_len , input_size )),
192+ }
193+ )
147194
148195
149196if __name__ == "__main__" :
0 commit comments