Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/node/structural_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* under the License.
*/
/*!
* \file tvm/node/structural_equal.h
* \file tvm/node/structural_hash.h
* \brief Structural hash class.
*/
#ifndef TVM_NODE_STRUCTURAL_HASH_H_
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def structural_equal(lhs, rhs, map_free_vars=False):
See Also
--------
structural_hash
assert_strucural_equal
assert_structural_equal
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
Expand Down Expand Up @@ -289,7 +289,7 @@ def structural_hash(node, map_free_vars=False):

See Also
--------
structrual_equal
structural_equal
"""
return _ffi_node_api.StructuralHash(node, map_free_vars) # type: ignore # pylint: disable=no-member

Expand Down
28 changes: 14 additions & 14 deletions tests/python/tir-base/test_tir_structural_equal_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def consistent_equal(x, y, map_free_vars=False):

if struct_equal0 != struct_equal1:
raise ValueError(
"Non-commutative {} vs {}, sequal0={}, sequal1={}".format(
"Non-commutative {} vs {}, struct_equal0={}, struct_equal1={}".format(
x, y, struct_equal0, struct_equal1
)
Comment on lines +34 to 36

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For improved readability and conciseness, consider using an f-string for this error message. F-strings are generally preferred over .format() in modern Python.

            f"Non-commutative {x} vs {y}, struct_equal0={struct_equal0}, struct_equal1={struct_equal1}"

)
Expand All @@ -40,14 +40,14 @@ def consistent_equal(x, y, map_free_vars=False):
# we can confirm that hash colison doesn't happen for our testcases
if struct_equal0 != (xhash == yhash):
raise ValueError(
"Inconsistent {} vs {}, sequal={}, xhash={}, yhash={}".format(
"Inconsistent {} vs {}, struct_equal={}, xhash={}, yhash={}".format(
x, y, struct_equal0, xhash, yhash
)
Comment on lines +43 to 45

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the previous suggestion, using an f-string here would make the error message more readable and is consistent with modern Python style.

            f"Inconsistent {x} vs {y}, struct_equal={struct_equal0}, xhash={xhash}, yhash={yhash}"

)
return struct_equal0


def get_sequal_mismatch(x, y, map_free_vars=False):
def get_struct_equal_mismatch(x, y, map_free_vars=False):
mismatch_0 = tvm.ir.base.get_first_structural_mismatch(x, y, map_free_vars)
mismatch_1 = tvm.ir.base.get_first_structural_mismatch(y, x, map_free_vars)

Expand Down Expand Up @@ -138,7 +138,7 @@ def test_prim_func_param_count_mismatch():
# counter example of same equality
func0 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x))
func1 = tvm.tir.PrimFunc([x, y, z], tvm.tir.Evaluate(x))
lhs_path, rhs_path = get_sequal_mismatch(func0, func1)
lhs_path, rhs_path = get_struct_equal_mismatch(func0, func1)
expected_lhs_path = AccessPath.root().attr("params").array_item_missing(2)
expected_rhs_path = AccessPath.root().attr("params").array_item(2)
assert lhs_path == expected_lhs_path
Expand All @@ -152,7 +152,7 @@ def test_prim_func_param_dtype_mismatch():
# counter example of same equality
func0 = tvm.tir.PrimFunc([x, y_0], tvm.tir.Evaluate(x))
func1 = tvm.tir.PrimFunc([x, y_1], tvm.tir.Evaluate(x))
lhs_path, rhs_path = get_sequal_mismatch(func0, func1)
lhs_path, rhs_path = get_struct_equal_mismatch(func0, func1)
expected_path = AccessPath.root().attr("params").array_item(1).attr("dtype")
assert lhs_path == expected_path
assert rhs_path == expected_path
Expand All @@ -166,7 +166,7 @@ def test_prim_func_body_mismatch():
# counter example of same equality
func0 = tvm.tir.PrimFunc([x_0, y_0], tvm.tir.Evaluate(x_0 + x_0))
func1 = tvm.tir.PrimFunc([x_1, y_1], tvm.tir.Evaluate(x_1 + y_1))
lhs_path, rhs_path = get_sequal_mismatch(func0, func1)
lhs_path, rhs_path = get_struct_equal_mismatch(func0, func1)
expected_path = AccessPath.root().attr("body").attr("value").attr("b")
assert lhs_path == expected_path
assert rhs_path == expected_path
Expand Down Expand Up @@ -257,14 +257,14 @@ def test_buffer_map_mismatch():
func_0_clone = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0_clone})
func_1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_1})

lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1)
lhs_path, rhs_path = get_struct_equal_mismatch(func_0, func_1)
expected_path = (
AccessPath.root().attr("buffer_map").map_item(x).attr("shape").array_item(1).attr("value")
)
assert lhs_path == expected_path
assert rhs_path == expected_path

assert get_sequal_mismatch(func_0, func_0_clone) is None
assert get_struct_equal_mismatch(func_0, func_0_clone) is None


def test_buffer_map_length_mismatch():
Expand All @@ -277,7 +277,7 @@ def test_buffer_map_length_mismatch():
func_0 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0})
func_1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0, y: buffer_1})

lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1)
lhs_path, rhs_path = get_struct_equal_mismatch(func_0, func_1)

expected_lhs_path = AccessPath.root().attr("buffer_map").map_item_missing(y)
assert lhs_path == expected_lhs_path
Expand Down Expand Up @@ -314,7 +314,7 @@ def test_while_condition_mismatch():
x = tvm.tir.Var("x", "int32")
w_0 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x))
w_1 = tvm.tir.While(x < 0, tvm.tir.Evaluate(x))
lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1)
lhs_path, rhs_path = get_struct_equal_mismatch(w_0, w_1)
expected_path = AccessPath.root().attr("condition")
assert lhs_path == expected_path
assert rhs_path == expected_path
Expand All @@ -324,7 +324,7 @@ def test_while_body_mismatch():
x = tvm.tir.Var("x", "int32")
w_0 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x))
w_1 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x + 1))
lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1)
lhs_path, rhs_path = get_struct_equal_mismatch(w_0, w_1)
expected_path = AccessPath.root().attr("body").attr("value")
assert lhs_path == expected_path
assert rhs_path == expected_path
Expand All @@ -348,7 +348,7 @@ def test_seq_mismatch():
tvm.tir.Evaluate(x + 3),
]
)
lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1)
lhs_path, rhs_path = get_struct_equal_mismatch(seq_0, seq_1)
expected_path = (
AccessPath.root().attr("seq").array_item(2).attr("value").attr("b").attr("value")
)
Expand All @@ -368,7 +368,7 @@ def test_seq_mismatch_different_lengths():
]
)
seq_1 = tvm.tir.SeqStmt([tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 3)])
lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1)
lhs_path, rhs_path = get_struct_equal_mismatch(seq_0, seq_1)
expected_path = (
AccessPath.root().attr("seq").array_item(2).attr("value").attr("b").attr("value")
)
Expand All @@ -387,7 +387,7 @@ def test_seq_length_mismatch():
]
)
seq_1 = tvm.tir.SeqStmt([tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 2)])
lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1)
lhs_path, rhs_path = get_struct_equal_mismatch(seq_0, seq_1)
expected_lhs_path = AccessPath.root().attr("seq").array_item(3)
expected_rhs_path = AccessPath.root().attr("seq").array_item_missing(3)
assert lhs_path == expected_lhs_path
Expand Down