diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index a84c35e62234..67d93b066972 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1197,6 +1197,7 @@ def from_exported_program( keep_params_as_input: bool = False, unwrap_unit_return_tuple: bool = False, no_bind_return_tuple: bool = False, + run_ep_decomposition: bool = False, ) -> tvm.IRModule: """Convert a PyTorch ExportedProgram to a Relax program @@ -1216,6 +1217,12 @@ def from_exported_program( A boolean flag indicating whether to bind the return tuple as a relax var. If the flag is true and the return value is a tuple, it will not bind it to a var. + run_ep_decomposition : bool + A boolean flag indicating whether to run PyTorch's decomposition on the + exported program before translation. When True, high-level operators will + be decomposed into their constituent parts. Defaults to False for backward + compatibility. + Returns ------- output : tvm.IRModule @@ -1255,8 +1262,9 @@ def forward(self, input): # Use the importer to import the ExportedProgram to Relax. mod: tvm.IRModule = from_exported_program(exported_program) """ - # decompose into Core ATen operators - exported_program.run_decompositions() + # Conditionally decompose into Core ATen operators + if run_ep_decomposition: + exported_program = exported_program.run_decompositions() return ExportedProgramImporter().from_exported_program( exported_program,