From 2bff51752d5ae2467ee70e51ad0b59aa7e45d0ae Mon Sep 17 00:00:00 2001 From: thecaptain789 Date: Mon, 9 Feb 2026 23:27:03 +0000 Subject: [PATCH] fix: correct additional typos identified in review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix typos in docstrings and comments identified by Gemini Code Assist during PR #18739 review: - python/tvm/ir/base.py: 'assert_strucural_equal' → 'assert_structural_equal' - python/tvm/ir/base.py: 'structrual_equal' → 'structural_equal' - include/tvm/node/structural_hash.h: fix file path in comment - tests/python/tir-base/test_tir_structural_equal_hash.py: 'sequal' → 'struct_equal' for consistency with variable names --- include/tvm/node/structural_hash.h | 2 +- python/tvm/ir/base.py | 4 +-- .../test_tir_structural_equal_hash.py | 28 +++++++++---------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index ba7cbaf88aa6..b8a00741f73c 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file tvm/node/structural_equal.h + * \file tvm/node/structural_hash.h * \brief Structural hash class. */ #ifndef TVM_NODE_STRUCTURAL_HASH_H_ diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 651ab392039c..450a9a2d7914 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -188,7 +188,7 @@ def structural_equal(lhs, rhs, map_free_vars=False): See Also -------- structural_hash - assert_strucural_equal + assert_structural_equal """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) @@ -289,7 +289,7 @@ def structural_hash(node, map_free_vars=False): See Also -------- - structrual_equal + structural_equal """ return _ffi_node_api.StructuralHash(node, map_free_vars) # type: ignore # pylint: disable=no-member diff --git a/tests/python/tir-base/test_tir_structural_equal_hash.py b/tests/python/tir-base/test_tir_structural_equal_hash.py index 296450b9a23c..80505f651bc3 100644 --- a/tests/python/tir-base/test_tir_structural_equal_hash.py +++ b/tests/python/tir-base/test_tir_structural_equal_hash.py @@ -31,7 +31,7 @@ def consistent_equal(x, y, map_free_vars=False): if struct_equal0 != struct_equal1: raise ValueError( - "Non-commutative {} vs {}, sequal0={}, sequal1={}".format( + "Non-commutative {} vs {}, struct_equal0={}, struct_equal1={}".format( x, y, struct_equal0, struct_equal1 ) ) @@ -40,14 +40,14 @@ def consistent_equal(x, y, map_free_vars=False): # we can confirm that hash colison doesn't happen for our testcases if struct_equal0 != (xhash == yhash): raise ValueError( - "Inconsistent {} vs {}, sequal={}, xhash={}, yhash={}".format( + "Inconsistent {} vs {}, struct_equal={}, xhash={}, yhash={}".format( x, y, struct_equal0, xhash, yhash ) ) return struct_equal0 -def get_sequal_mismatch(x, y, map_free_vars=False): +def get_struct_equal_mismatch(x, y, map_free_vars=False): mismatch_0 = tvm.ir.base.get_first_structural_mismatch(x, y, map_free_vars) mismatch_1 = tvm.ir.base.get_first_structural_mismatch(y, x, map_free_vars) @@ -138,7 +138,7 @@ def test_prim_func_param_count_mismatch(): # counter example of same equality func0 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x)) func1 = tvm.tir.PrimFunc([x, y, z], tvm.tir.Evaluate(x)) - lhs_path, rhs_path = get_sequal_mismatch(func0, func1) + lhs_path, rhs_path = get_struct_equal_mismatch(func0, func1) expected_lhs_path = AccessPath.root().attr("params").array_item_missing(2) expected_rhs_path = AccessPath.root().attr("params").array_item(2) assert lhs_path == expected_lhs_path @@ -152,7 +152,7 @@ def test_prim_func_param_dtype_mismatch(): # counter example of same equality func0 = tvm.tir.PrimFunc([x, y_0], tvm.tir.Evaluate(x)) func1 = tvm.tir.PrimFunc([x, y_1], tvm.tir.Evaluate(x)) - lhs_path, rhs_path = get_sequal_mismatch(func0, func1) + lhs_path, rhs_path = get_struct_equal_mismatch(func0, func1) expected_path = AccessPath.root().attr("params").array_item(1).attr("dtype") assert lhs_path == expected_path assert rhs_path == expected_path @@ -166,7 +166,7 @@ def test_prim_func_body_mismatch(): # counter example of same equality func0 = tvm.tir.PrimFunc([x_0, y_0], tvm.tir.Evaluate(x_0 + x_0)) func1 = tvm.tir.PrimFunc([x_1, y_1], tvm.tir.Evaluate(x_1 + y_1)) - lhs_path, rhs_path = get_sequal_mismatch(func0, func1) + lhs_path, rhs_path = get_struct_equal_mismatch(func0, func1) expected_path = AccessPath.root().attr("body").attr("value").attr("b") assert lhs_path == expected_path assert rhs_path == expected_path @@ -257,14 +257,14 @@ def test_buffer_map_mismatch(): func_0_clone = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0_clone}) func_1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_1}) - lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1) + lhs_path, rhs_path = get_struct_equal_mismatch(func_0, func_1) expected_path = ( AccessPath.root().attr("buffer_map").map_item(x).attr("shape").array_item(1).attr("value") ) assert lhs_path == expected_path assert rhs_path == expected_path - assert get_sequal_mismatch(func_0, func_0_clone) is None + assert get_struct_equal_mismatch(func_0, func_0_clone) is None def test_buffer_map_length_mismatch(): @@ -277,7 +277,7 @@ def test_buffer_map_length_mismatch(): func_0 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0}) func_1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0, y: buffer_1}) - lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1) + lhs_path, rhs_path = get_struct_equal_mismatch(func_0, func_1) expected_lhs_path = AccessPath.root().attr("buffer_map").map_item_missing(y) assert lhs_path == expected_lhs_path @@ -314,7 +314,7 @@ def test_while_condition_mismatch(): x = tvm.tir.Var("x", "int32") w_0 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x)) w_1 = tvm.tir.While(x < 0, tvm.tir.Evaluate(x)) - lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1) + lhs_path, rhs_path = get_struct_equal_mismatch(w_0, w_1) expected_path = AccessPath.root().attr("condition") assert lhs_path == expected_path assert rhs_path == expected_path @@ -324,7 +324,7 @@ def test_while_body_mismatch(): x = tvm.tir.Var("x", "int32") w_0 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x)) w_1 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x + 1)) - lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1) + lhs_path, rhs_path = get_struct_equal_mismatch(w_0, w_1) expected_path = AccessPath.root().attr("body").attr("value") assert lhs_path == expected_path assert rhs_path == expected_path @@ -348,7 +348,7 @@ def test_seq_mismatch(): tvm.tir.Evaluate(x + 3), ] ) - lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1) + lhs_path, rhs_path = get_struct_equal_mismatch(seq_0, seq_1) expected_path = ( AccessPath.root().attr("seq").array_item(2).attr("value").attr("b").attr("value") ) @@ -368,7 +368,7 @@ def test_seq_mismatch_different_lengths(): ] ) seq_1 = tvm.tir.SeqStmt([tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 3)]) - lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1) + lhs_path, rhs_path = get_struct_equal_mismatch(seq_0, seq_1) expected_path = ( AccessPath.root().attr("seq").array_item(2).attr("value").attr("b").attr("value") ) @@ -387,7 +387,7 @@ def test_seq_length_mismatch(): ] ) seq_1 = tvm.tir.SeqStmt([tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 2)]) - lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1) + lhs_path, rhs_path = get_struct_equal_mismatch(seq_0, seq_1) expected_lhs_path = AccessPath.root().attr("seq").array_item(3) expected_rhs_path = AccessPath.root().attr("seq").array_item_missing(3) assert lhs_path == expected_lhs_path