From 12609f4f6218b0d253e079cb104a9f47842a51d4 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 18 Jul 2024 19:43:03 +0900 Subject: [PATCH] use `packaging.version.parse` instead of `distutils.version.LooseVersion` --- python/tvm/contrib/msc/core/utils/info.py | 6 +++--- python/tvm/relay/frontend/pytorch_utils.py | 4 ++-- python/tvm/relay/op/contrib/ethosn.py | 6 +++--- python/tvm/relay/testing/tflite.py | 4 ++-- .../test_arm_compute_lib/test_network.py | 4 ++-- .../frontend/tensorflow/test_forward.py | 9 ++++----- tests/python/frontend/tflite/test_forward.py | 19 +++++++++---------- 7 files changed, 25 insertions(+), 27 deletions(-) diff --git a/python/tvm/contrib/msc/core/utils/info.py b/python/tvm/contrib/msc/core/utils/info.py index 4fea45f8fab2..58b08112797a 100644 --- a/python/tvm/contrib/msc/core/utils/info.py +++ b/python/tvm/contrib/msc/core/utils/info.py @@ -17,7 +17,7 @@ """tvm.contrib.msc.core.utils.info""" from typing import List, Tuple, Dict, Any, Union -from distutils.version import LooseVersion +from packaging.version import parse import numpy as np import tvm @@ -409,8 +409,8 @@ def get_version(framework: str) -> List[int]: raw_version = "1.0.0" except: # pylint: disable=bare-except raw_version = "1.0.0" - raw_version = raw_version or "1.0.0" - return LooseVersion(raw_version).version + version = parse(raw_version or "1.0.0") + return [version.major, version.minor, version.micro] def compare_version(given_version: List[int], target_version: List[int]) -> int: diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 7de1248bda77..8686be4b1ea9 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -36,7 +36,7 @@ def is_version_greater_than(ver): than the one given as an argument. """ import torch - from distutils.version import LooseVersion + from packaging.version import parse torch_ver = torch.__version__ # PT version numbers can include +cu[cuda version code] @@ -44,7 +44,7 @@ def is_version_greater_than(ver): if "+cu" in torch_ver: torch_ver = torch_ver.split("+cu")[0] - return LooseVersion(torch_ver) > ver + return parse(torch_ver) > parse(ver) def getattr_attr_name(node): diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py index 81534d48a216..c1e87ad5d90b 100644 --- a/python/tvm/relay/op/contrib/ethosn.py +++ b/python/tvm/relay/op/contrib/ethosn.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name, unused-argument """Arm(R) Ethos(TM)-N NPU supported operators.""" from enum import Enum -from distutils.version import LooseVersion +from packaging.version import parse import tvm.ir from tvm.relay import transform @@ -118,7 +118,7 @@ def partition_for_ethosn(mod, params=None, **opts): """ api_version = ethosn_api_version() supported_api_versions = ["3.2.0"] - if all(api_version != LooseVersion(exp_ver) for exp_ver in supported_api_versions): + if all(parse(api_version) != parse(exp_ver) for exp_ver in supported_api_versions): raise ValueError( f"Driver stack version {api_version} is unsupported. " f"Please use version in {supported_api_versions}." @@ -433,7 +433,7 @@ def split(expr): """Check if a split is supported by Ethos-N.""" if not ethosn_available(): return False - if ethosn_api_version() == LooseVersion("3.0.1"): + if parse(ethosn_api_version()) == parse("3.0.1"): return False if not _ethosn.split(expr): return False diff --git a/python/tvm/relay/testing/tflite.py b/python/tvm/relay/testing/tflite.py index df9c0bcadf62..29f6bc62cad2 100644 --- a/python/tvm/relay/testing/tflite.py +++ b/python/tvm/relay/testing/tflite.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Common utilities for creating TFLite models""" -from distutils.version import LooseVersion +from packaging.version import parse import numpy as np import pytest import tflite.Model # pylint: disable=wrong-import-position @@ -134,7 +134,7 @@ def generate_reference_data(self): assert self.serial_model is not None, "TFLite model was not created." output_tolerance = None - if tf.__version__ < LooseVersion("2.5.0"): + if parse(tf.__version__) < parse("2.5.0"): output_tolerance = 1 interpreter = tf.lite.Interpreter(model_content=self.serial_model) else: diff --git a/tests/python/contrib/test_arm_compute_lib/test_network.py b/tests/python/contrib/test_arm_compute_lib/test_network.py index 3cf81e971f77..8c6302abf842 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_network.py +++ b/tests/python/contrib/test_arm_compute_lib/test_network.py @@ -16,7 +16,7 @@ # under the License. """Arm Compute Library network tests.""" -from distutils.version import LooseVersion +from packaging.version import parse import numpy as np import pytest @@ -137,7 +137,7 @@ def get_model(): mod, params = _get_keras_model(mobilenet, inputs) return mod, params, inputs - if keras.__version__ < LooseVersion("2.9"): + if parse(keras.__version__) < parse("2.9"): # This can be removed after we migrate to TF/Keras >= 2.9 expected_tvm_ops = 56 expected_acl_partitions = 31 diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index db270ccb2e9f..354ed38a62ce 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -21,7 +21,6 @@ This article is a test script to test tensorflow operator with Relay. """ from __future__ import print_function -from distutils.version import LooseVersion import threading import platform @@ -1755,7 +1754,7 @@ def _test_concat_v2(shape1, shape2, dim): def test_forward_concat_v2(): - if tf.__version__ < LooseVersion("1.4.1"): + if package_version.parse(tf.__version__) < package_version.parse("1.4.1"): return _test_concat_v2([2, 3], [2, 3], 0) @@ -3128,7 +3127,7 @@ def _test_forward_clip_by_value(ip_shape, clip_value_min, clip_value_max, dtype) def test_forward_clip_by_value(): """test ClipByValue op""" - if tf.__version__ < LooseVersion("1.9"): + if package_version.parse(tf.__version__) < package_version.parse("1.9"): _test_forward_clip_by_value((4,), 0.1, 5.0, "float32") _test_forward_clip_by_value((4, 4), 1, 5, "int32") @@ -4482,7 +4481,7 @@ def _test_forward_zeros_like(in_shape, dtype): def test_forward_zeros_like(): - if tf.__version__ < LooseVersion("1.2"): + if package_version.parse(tf.__version__) < package_version.parse("1.2"): _test_forward_zeros_like((2, 3), "int32") _test_forward_zeros_like((2, 3, 5), "int8") _test_forward_zeros_like((2, 3, 5, 7), "uint16") @@ -5566,7 +5565,7 @@ def test_forward_spop(): # This test is expected to fail in TF version >= 2.6 # as the generated graph will be considered frozen, hence # not passing the criteria for the test below. - if tf.__version__ < LooseVersion("2.6.1"): + if package_version.parse(tf.__version__) < package_version.parse("2.6.1"): _test_spop_resource_variables() # Placeholder test cases diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 75a2a37c636a..cb0b17ea3fcf 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -22,7 +22,6 @@ """ from __future__ import print_function from functools import partial -from distutils.version import LooseVersion import platform import os import tempfile @@ -1054,7 +1053,7 @@ def representative_data_gen(): input_node = subgraph.Tensors(model_input).Name().decode("utf-8") tflite_output = run_tflite_graph(tflite_model_quant, data) - if tf.__version__ < LooseVersion("2.9"): + if package_version.parse(tf.__version__) < package_version.parse("2.9"): input_node = data_in.name.replace(":0", "") else: input_node = "serving_default_" + data_in.name + ":0" @@ -1775,7 +1774,7 @@ def representative_data_gen(): tflite_output = run_tflite_graph(tflite_model_quant, data) - if tf.__version__ < LooseVersion("2.9"): + if package_version.parse(tf.__version__) < package_version.parse("2.9"): input_node = data_in.name.replace(":0", "") else: input_node = "serving_default_" + data_in.name + ":0" @@ -2219,9 +2218,9 @@ def _test_abs(data, quantized, int_quant_dtype=tf.int8): tflite_output = run_tflite_graph(tflite_model_quant, data) # TFLite 2.6.x upgrade support - if tf.__version__ < LooseVersion("2.6.1"): + if package_version.parse(tf.__version__) < package_version.parse("2.6.1"): in_node = ["serving_default_input_int8"] - elif tf.__version__ < LooseVersion("2.9"): + elif package_version.parse(tf.__version__) < package_version.parse("2.9"): in_node = ( ["serving_default_input_int16"] if int_quant_dtype == tf.int16 else ["tfl.quantize"] ) @@ -2245,7 +2244,7 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8): """One iteration of rsqrt""" # tensorflow version upgrade support - if tf.__version__ < LooseVersion("2.6.1") or not quantized: + if package_version.parse(tf.__version__) < package_version.parse("2.6.1") or not quantized: return _test_unary_elemwise( math_ops.rsqrt, data, quantized, quant_range=[1, 6], int_quant_dtype=int_quant_dtype ) @@ -2254,7 +2253,7 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8): tf.math.rsqrt, data, int_quant_dtype=int_quant_dtype ) tflite_output = run_tflite_graph(tflite_model_quant, data) - if tf.__version__ < LooseVersion("2.9"): + if package_version.parse(tf.__version__) < package_version.parse("2.9"): in_node = ["tfl.quantize"] else: in_node = "serving_default_input" @@ -2338,7 +2337,7 @@ def _test_cos(data, quantized, int_quant_dtype=tf.int8): tf.math.cos, data, int_quant_dtype=int_quant_dtype ) tflite_output = run_tflite_graph(tflite_model_quant, data) - if tf.__version__ < LooseVersion("2.9"): + if package_version.parse(tf.__version__) < package_version.parse("2.9"): in_node = ["tfl.quantize"] else: in_node = "serving_default_input" @@ -3396,7 +3395,7 @@ def representative_data_gen(): tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True) tflite_output = run_tflite_graph(tflite_model_quant, data) - if tf.__version__ < LooseVersion("2.9"): + if package_version.parse(tf.__version__) < package_version.parse("2.9"): in_node = data_in.name.split(":")[0] else: in_node = "serving_default_" + data_in.name + ":0" @@ -3426,7 +3425,7 @@ def representative_data_gen(): tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True) tflite_output = run_tflite_graph(tflite_model_quant, data) - if tf.__version__ < LooseVersion("2.9"): + if package_version.parse(tf.__version__) < package_version.parse("2.9"): in_node = data_in.name.split(":")[0] else: in_node = "serving_default_" + data_in.name + ":0"