Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@
from tvm.script import tirx as T


def _has_volatile_alloc_buffer(mod):
has_volatile_alloc = False

def visit(node):
nonlocal has_volatile_alloc
if isinstance(node, tvm.tirx.AllocBuffer) and "tirx.volatile" in node.annotations:
has_volatile_alloc = has_volatile_alloc or node.annotations["tirx.volatile"] is True

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using is True to check the value of a TVM annotation can lead to failures. TVM map lookups typically return TVM object wrappers (such as tvm.tir.IntImm or tvm.runtime.Bool) rather than Python's built-in True singleton, so identity comparison (is) will evaluate to False. Using bool(...) is more robust and correctly evaluates the truthiness of the TVM object.

Suggested change
has_volatile_alloc = has_volatile_alloc or node.annotations["tirx.volatile"] is True
has_volatile_alloc = has_volatile_alloc or bool(node.annotations["tirx.volatile"])


tvm.tirx.stmt_functor.post_order_visit(mod["main"].body, visit)
return has_volatile_alloc


def test_basic():
transform = tvm.s_tir.transform.LowerThreadAllreduce()

Expand Down Expand Up @@ -503,7 +515,7 @@ def main(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32
After_script = After.script()
assert "tvm_warp_shuffle_down" in After_script
assert "tvm_storage_sync" in After_script
assert '"tirx.volatile": T.bool(True)' in After_script
assert _has_volatile_alloc_buffer(After)
assert "T.uint32(" not in After_script


Expand Down
Loading