From f6ce93ea6ae6220b70825aee7e50ed33cd648e25 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 5 Sep 2024 13:09:29 +0900 Subject: [PATCH 1/2] use `_convert_torch_tensor_to_relax` where possible --- python/tvm/relax/frontend/torch/fx_translator.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 21a0b2d5642a..68c9c89c09e0 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1206,9 +1206,8 @@ def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var: module = self.named_modules[node.target] weight = self.params[module.weight] bias = self.params[module.bias] - dtype = TorchFXImporter._convert_data_type(str(module.running_mean.dtype)) - running_mean = relax.const(module.running_mean.cpu().detach().numpy(), dtype) - running_var = relax.const(module.running_var.cpu().detach().numpy(), dtype) + running_mean = self._convert_torch_tensor_to_relax(module.running_mean) + running_var = self._convert_torch_tensor_to_relax(module.running_var) eps = module.eps res_tuple = self.block_builder.emit( @@ -1769,7 +1768,7 @@ def from_fx( dtype = self._convert_data_type(str(param.data.dtype)) if dtype in ("float32", "float16"): if not keep_params_as_input: - self.params[param] = relax.const(param.data.cpu().numpy(), dtype) + self.params[param] = self._convert_torch_tensor_to_relax(param) else: raise ValueError("Unsupported data type for model parameters: %s" % dtype) # Translate the model. From 3557accdfa938afcdec4955aea95bd8f1fbdeaf9 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 5 Sep 2024 13:10:52 +0900 Subject: [PATCH 2/2] add type annotation --- python/tvm/relax/frontend/torch/fx_translator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 68c9c89c09e0..6e60c3bb6fc4 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -62,7 +62,7 @@ def _fetch_attr(self, model, target: str): return attr_itr @staticmethod - def _convert_data_type(input_type, env: Optional[Dict] = None): + def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] = None): """converts the PyTorch scalar type input_type to a TVM dtype.""" import torch # type: ignore