Skip to content

[REFACTOR][TIR] Phase out LetStmtNode: migrate to flat BindNode#5

Closed
tqchen wants to merge 36 commits into
mainfrom
refactor-letstmtnode-to-bindnode
Closed

[REFACTOR][TIR] Phase out LetStmtNode: migrate to flat BindNode#5
tqchen wants to merge 36 commits into
mainfrom
refactor-letstmtnode-to-bindnode

Conversation

@tqchen

@tqchen tqchen commented Mar 2, 2026

Copy link
Copy Markdown
Owner

Summary

This PR refactors TVM TIR to replace the nested LetStmtNode (with body field) with a flat BindNode that has no body. Variable scope is determined by the enclosing body-carrying statement rather than nesting.

Commits

  1. Introduce BindNode and Bind -- core IR node, functor dispatch, analysis infrastructure
  2. Phase out LetStmtNode body field -- migrate all passes to use flat Bind
  3. Fix passes and tests for flat Bind semantics -- address issues found in testing
  4. Complete LetStmt-to-Bind migration -- fix remaining edge cases
  5. Cleanup: rename LetStmt references to Bind -- naming consistency
  6. Simplify CSE -- remove VisitSeqStmtSlice, use flat Bind semantics
  7. Simplify ConvertSSA -- remove SeqStmt handler for flat Bind
  8. Simplify hoist_expression -- remove SeqStmt handler for flat Bind
  9. Simplify sblock_access_region_detector -- remove SeqStmt handler for flat Bind
  10. Simplify remove_no_op -- remove Bind elimination from SeqStmt
  11. Simplify lower_tvm_builtin -- flatten MakeNdMemAllocWithScope
  12. Remove obsolete roundtrip tests -- opt_gemm_mod_host and let_stmt_value
  13. Disable CanInlineLetStmt for flat Bind -- simplifier always keeps Bind
  14. Fix Bind scope management -- hoist_expression, ir_utils, and test updates
  15. Simplify tir_visitor_with_path -- use ScopeStack for Bind defs
  16. Fix remove_store_undef and inject_ptx_async_copy -- eagerly check undef in indices

Test plan

  • pytest tests/python/tir-transform/ (334 passed)
  • pytest tests/python/tir-base/ (269 passed)
  • pytest tests/python/tvmscript/ (771 passed)
  • pytest tests/python/s_tir/ (1276 passed)
  • pre-commit run --all-files (all passed)

@tqchen tqchen force-pushed the refactor-letstmtnode-to-bindnode branch from 00f1d73 to e35bd58 Compare March 2, 2026 19:04
mshr-h and others added 28 commits March 3, 2026 15:04
Enable opencl target for gpu tests.
Consolidates all Adreno tests under tests/python/relax/backend/adreno
Changes to CLML corresponding to recent changes on json codegen/runtime.
Docker specification for Adreno (ci_gpu + Android SDK, Gradle).
…er (apache#18865)

## Summary

This PR introduces `AllocBufferNode`/`AllocBuffer` as a single TIR
statement that both allocates memory and declares a buffer into scope.
This replaces the previous pattern of `Allocate(var, dtype, shape, cond,
DeclBuffer(buf, body))` with the simpler `AllocBuffer(buf, body)`.

### Main changes

- **New IR node** `AllocBufferNode` with fields `{buffer, annotations,
body}` — same semantics as `DeclBuffer` but also allocates memory
- **TVMScript**: `T.alloc_buffer(shape, dtype, scope)` now emits
`AllocBuffer` directly (statement-level allocation).
`T.sblock_alloc_buffer(...)` for SBlock-level buffer allocation (full
parameter set)
- **All codegen backends** (C, CUDA, Metal, OpenCL, WebGPU, LLVM, NVPTX,
AMDGPU, SPIR-V) updated to handle `AllocBufferNode`
- **All TIR transforms** (storage_rewrite, flatten_buffer,
vectorize_loop, lower_warp_memory, etc.) updated
- **All S-TIR transforms** (compact_buffer_region, merge_shared_memory,
inject_double_buffer, etc.) updated
- **Removed `AllocateNode`** entirely — `AllocBuffer` is now the sole
allocation primitive
- **Removed `AllocDescriptor`** from merge_shared_memory_allocations —
uses `Buffer` objects directly
- **Added `AllocBuffer::ConstantAllocationSize()`** inline helper method

### Design rationale

The old `Allocate + DeclBuffer` pair was a historical artifact:
`AllocateNode` stored raw fields (`buffer_var`, `dtype`, `extents`,
`condition`) separate from the `Buffer` object, requiring pattern
matching (`IsAllocateDeclBufferPattern`) to reconstruct the buffer
association. `AllocBuffer` unifies this into a single node with a proper
`Buffer` reference, simplifying codegen backends and transform passes.

225 files changed, ~3500 insertions/deletions (net near-zero, mostly
mechanical migration).

## Test plan

- [x] All TIR base tests pass
- [x] All TIR transform tests pass
- [x] TVMScript roundtrip tests pass
- [x] S-TIR transform tests pass
- [x] Codegen tests pass
- [x] All-platform minimal tests pass
- [x] C++ functor tests pass
- [x] Pre-commit clean (clang-format, ruff, etc.)
…ctor + analysis infrastructure)

Introduces `BindNode`/`Bind`, a new TIR statement node that binds a variable
to a value with flat (no-body) scope semantics, as the first step of the
LetStmt-to-Bind refactor. Unlike `LetStmtNode`, `BindNode` has no body field;
the bound variable is visible in subsequent siblings of the enclosing SeqStmt.

PR 1 — Core IR node + functor infrastructure:
- `include/tvm/tir/stmt.h`: Define BindNode (var, value, no body) and Bind ref class
- `src/tir/ir/stmt.cc`: Implement Bind constructor, RegisterReflection, GlobalDef
- `include/tvm/tir/stmt_functor.h`: Add VisitStmt_(BindNode*) to StmtFunctor vtable,
  StmtVisitor, StmtMutator
- `src/tir/ir/stmt_functor.cc`: Implement StmtVisitor and StmtMutator for BindNode
- `src/tir/ir/py_functor.cc`: Add BindNode dispatch entries for Python functors
- `src/tir/ir/tir_visitor_with_path.{h,cc}`: Add BindNode visitor (visits value only)

PR 2 — Base visitors/mutators + arithmetic + analysis:
- `src/arith/ir_mutator_with_analyzer.{h,cc}`: BindNode handler binds in analyzer
- `src/arith/ir_visitor_with_analyzer.{h,cc}`: BindNode handler visits value + binds
- `src/tir/ir/data_type_rewriter.{h,cc}`: BindNode support in DataTypeLegalizer and
  IndexDataTypeRewriter
- `src/tir/analysis/var_use_def_analysis.{h,cc}`: BindNode registers HandleDef
- `src/tir/analysis/verify_ssa.cc`: BindNode calls MarkDef
- `src/tir/analysis/verify_memory.cc`: BindNode books defs_ map
- `src/tir/analysis/control_flow_graph.cc`: BindNode checks UsesLoopVar

LetStmtNode/LetStmt are kept intact (deprecated aliases come in PR 11).

Tests: TIR base (269 passed, 2 skipped), all-platform-minimal (75 passed, 77 skipped)
This commit completes the migration from tree-nested LetStmtNode
(with body field) to flat BindNode (no body) across the entire TVM
codebase. BindNode binds a variable visible to subsequent siblings
in the enclosing SeqStmt scope, replacing the old nested scoping.

Key changes:
- LetStmtNode is now a `using` alias for BindNode
- All C++ LetStmt(var,val,body) constructions -> SeqStmt({Bind(var,val), body})
- All VisitStmt_(LetStmtNode*) handlers -> VisitStmt_(BindNode*)
  with op->body access removed (parent SeqStmt handles traversal)
- TIRVisitorWithPath::SeqStmt handler tracks Bind-defined vars
  for well-formed verification
- CSE pass: new SeqStmtNode handler + VisitSeqStmtSlice to process
  flat Bind sequences (mirrors old nested LetStmt CSE behavior)
- Python: added Bind class, LetStmt = Bind alias
- Updated ~65 C++ files and test files
Update TIR passes and tests to work correctly with the flat BindNode
model (no body field) where variable scoping is managed via SeqStmt
siblings instead of nested tree structure.

Pass fixes:
- RemoveNoOp: Add SeqStmt handler for back-to-front unused Bind scan
- ConvertSSA: Add SeqStmt handler to maintain ScopedRedefine across siblings
- StorageRewrite: Push/pop scope entry in BindNode handler
- HoistExpression: Merge Bind lifecycle management into SeqStmt handler;
  only set reached_sequential_node for truly sequential (non-Bind) stmts
- SBlockAccessRegionDetector: Defer let_bindings_ erasure to SeqStmt end
- TVMScript printer: Add AsDocBodySeqSlice for scoped T.LetStmt form
  when printing already-defined-var Binds
- TVMScript parser: Support doc.Attribute in _duplicate_lhs_check

Test updates:
- verify_well_formed: Adjust for flat scope semantics
- convert_ssa: Update for flattened SeqStmt behavior
- tvmscript printer/annotation/syntax_sugar: Update access paths
- loop_partition: Fix pre-existing test with incorrect expected output
- Update comments in var.h and Python functor.py to reference BindNode
  instead of LetStmtNode
- Apply clang-format fixes to files modified by the BindNode migration
- Remove unused Bind import in functor.py (LetStmt alias is used instead)
- Remove extra blank lines left over from migration in analysis/rewriter files
…FI_UNREACHABLE

- Replace all 8 `__builtin_unreachable()` calls with `TVM_FFI_UNREACHABLE()`:
  src/s_tir/transform/inject_virtual_thread.cc,
  src/script/printer/relax/distributed.cc,
  src/script/printer/tir/stmt.cc,
  src/target/source/codegen_c.cc,
  src/tir/transform/vectorize_loop.cc,
  src/tir/transform/lower_tvm_builtin.cc,
  include/tvm/script/printer/ir_docsifier_functor.h,
  include/tvm/tir/stmt.h

- Rename `kLetStmt` enum value → `kBind` in HoistedLetBindings (C++ and Python)
  (src/s_tir/transform/hoist_expression.cc,
   python/tvm/tir/transform/transform.py)

- Rename `LetStmt()` → `Bind()` in script/ir_builder/tir:
  - C++ function in ir.h and ir.cc; keep `LetStmt` as a deprecated inline alias
  - Register `"script.ir_builder.tir.Bind"` as primary; keep `LetStmt` as alias
  - Python ir.py: add `Bind()` as primary function; `LetStmt()` delegates to it

- Update stale `LetStmt` mentions in comments and docstrings to `Bind`:
  src/s_tir/schedule/analysis/reducer.cc,
  src/s_tir/schedule/primitive/reduction.cc,
  src/s_tir/transform/hoist_expression.cc,
  src/tir/ir/specialize.cc,
  src/tir/transform/common_subexpr_elim.cc,
  src/tir/transform/tvm_ffi_binder.h,
  src/tir/transform/ir_utils.cc,
  src/te/operation/create_primfunc.cc,
  include/tvm/tir/stmt.h,
  python/tvm/tir/stmt.py,
  python/tvm/tir/functor.py

- Clean up `src/script/printer/tir/utils.h`: remove `AsDocBodySeqSlice`
  helper that used `TIR(d, "LetStmt")` scoped form; inline loop directly
  in `AsDocBody` (Bind is flat-assignment, no scoped form needed)
… semantics

Replace the recursive VisitSeqStmtSlice helper with an iterative
SeqStmt handler that processes children directly: Bind nodes augment
the context and trigger cross-sibling CSE on remaining siblings,
while non-Bind nodes are processed individually.
…Bind

Remove the custom SeqStmt handler that maintained ScopedRedefine
entries for Bind nodes. Instead, the simplified BindNode handler
adds persistent remappings via function_scope_var_remap_ directly,
which don't need scoped cleanup. The default StmtMutator processes
SeqStmt children sequentially, so remappings from Bind nodes are
automatically visible to subsequent siblings.
… flat Bind

Simplify the SeqStmt handler to only perform sequential detection
(counting non-Bind statements) and delegate visitation to the parent.
Remove the Bind-var lifecycle management (tracking and erasing from
let_var_to_loop_vars/let_var_to_let_vars maps at sequence boundaries).
Bind vars now persist in the tracking maps for the duration of the
HoistInfoCollector instance.
…t handler for flat Bind

Remove the custom SeqStmt handler that tracked and erased Bind-defined
let_bindings_ at sequence boundaries. The BindNode handler now just adds
to let_bindings_ and relies on the BlockReadWriteDetector instance scope
for cleanup. The default StmtVisitor processes SeqStmt children
sequentially, so bindings are visible to subsequent siblings.
…eqStmt

Remove the custom SeqStmt handler and dead-Bind-variable backward scan
from remove_no_op. The VisitStmt_(BindNode*) handler now simply mutates
the value and returns. Unused Bind elimination can be added back later
via a separate two-pass approach.
…hScope

Remove the custom SeqStmt handler that captured remaining siblings as
body for nd_mem_alloc_with_scope processing. MakeNdMemAllocWithScope
now rewrites the Bind value inline (lowering to tvm_call_packed) and
adds a null check, without body capture.
…roundtrip tests

Remove opt_gemm_mod_host and let_stmt_value test functions from
test_tvmscript_roundtrip.py. Both use non-SSA re-binds (with T.LetStmt
var= pattern) that cannot roundtrip with flat Bind semantics.
With flat Bind there is no body to inspect for usage patterns, so
Bind inlining (removing the Bind and substituting its value) is
disabled. The analyzer still records variable bindings for constraint
proving, but the Bind statement is always kept. Remove the
CollectVarsUsedInBufferDefinition utility and used_in_buffer_def_
tracking which were only needed for the inlining codepath. Update
tests to reflect that Binds are no longer eliminated.
…ls, and tests

Update hoist_expression to manage Bind lifecycle in SeqStmt, fix
IRConvertSSA to handle Bind redefinitions across SeqStmt siblings,
and update test expectations for flat Bind semantics.
…defs

Use ScopeStack to manage Bind variable definitions. Body-carrying
statements (For, IfThenElse, Allocate, DeclBuffer, AttrStmt, While,
SBlock) push a new scope; BindNode pushes its WithDef into the
current scope. When the scope exits all Bind defs are cleaned up
automatically, removing the need for custom SeqStmt handling.
…flat Bind

- remove_store_undef: eagerly check buffer indices for undef in the
  locator phase (flat Bind means the undef Bind is a sibling, not an
  ancestor, so post-validation alone cannot catch it). Also remove
  Bind nodes whose value contains undef alongside the removed stores.
- inject_ptx_async_copy test: update expected CUDA to reflect that
  analyzer->Bind substitutes the variable with its value.
…nagement

Restore the original SeqStmt handler logic that tracks Bind vars defined
in a sequence and erases them from let_var_to_loop_vars/let_var_to_let_vars
maps when the sequence ends. Keep the refactor simple per user feedback.
The migration from LetStmt to Bind is complete. Remove all backward-
compatibility aliases and deprecated wrappers:

- Remove `using LetStmtNode = BindNode` and `using LetStmt = Bind`
  from include/tvm/tir/stmt.h
- Remove `LetStmt()` wrapper and `LegacyLetStmt()` from C++ and Python
  script ir_builder
- Remove `tir.LetStmt` FFI factory from stmt.cc
- Remove `LetStmt = Bind` alias from python/tvm/tir/stmt.py
- Rename `visit_let_stmt_` to `visit_bind_` in Python functor metadata
  and method names, matching the C++ `f_visit_bind` field
- Rename `f_visit_let_stmt` parameters in py_functor.cc to
  `f_visit_bind`
- Update all test files: T.LetStmt -> T.Bind, comments, function names
…anagement

Replace the ScopedRedefine RAII struct and custom SeqStmt handler with
ScopeStack<ScopeLevel> for cleaner scope management:

- Body-carrying statements (For, IfThenElse, AttrStmt, DeclBuffer,
  While, Allocate, SBlock) push a new scope via scope_.WithNewScope()
- Bind pushes var remaps to the current scope level, persisting across
  SeqStmt siblings
- Scope exit automatically pops all remaps via ScopeLevel destructor
- Remove the custom VisitStmt_(SeqStmtNode*) -- default sequential
  iteration works because Bind remaps persist in the enclosing scope
- Add IfThenElse handler with separate scopes per branch to prevent
  remap leakage between then/else cases
…ement

Replace manual save/restore of context_ in the Common Subexpression
Elimination pass with ScopeStack-based automatic scope management.

Key changes:
- Add ScopeStack<ContextScopeLevel> where each scope level records the
  context size on entry and truncates it back on exit via destructor
- ForNode, LetNode: WithNewScope replaces manual context save/restore
- New scope-boundary overrides for IfThenElse, AttrStmt, Allocate,
  DeclBuffer, While to prevent context leaks across scope boundaries
- SeqStmtNode: remove manual context save/restore (enclosing scope
  handles cleanup), retain wrap-remaining-siblings pattern for
  cross-sibling CSE after Bind nodes
- BindNode: entries persist across SeqStmt siblings, cleaned up
  automatically when enclosing body-carrying statement's scope exits
The LetStmt-to-Bind migration dropped the free_nd call that was
previously wrapped after the LetStmt body, causing a memory leak
for nd allocations (Hexagon VTCM, Adreno textures).

With flat Bind semantics, the free is pushed to a pending_nd_frees_
vector and appended at the end of the enclosing SeqStmt by a new
VisitStmt_(SeqStmtNode*) override.
The LetStmt-to-Bind refactor accidentally duplicated the LOG(WARNING)
call in IRDocsifierFunctor::operator(). Remove the extra one.
Rename test fixtures and functions that still use "letstmt" to "bind"
to match the LetStmt-to-Bind refactor:
- argmax_split_letstmt_{fewer,more}_than_init -> argmax_split_bind_*
- test_letstmt_bufferload_without_type_annotation -> test_bind_*
- test_letstmt_bind_with_constant -> test_bind_with_constant
Rename LetFrameNode/LetFrame to BindFrameNode/BindFrame across C++
headers, implementation, and Python bindings to align with the
LetStmt-to-Bind refactor. Updates FFI type key from
"script.ir_builder.tir.LetFrame" to "script.ir_builder.tir.BindFrame".
…base

Resolve remaining AllocateNode references that should be
AllocBufferNode after rebasing onto the AllocBuffer commit.
Also add TVM_FFI_UNREACHABLE after throw in blockize_tensorize.
@tqchen tqchen force-pushed the refactor-letstmtnode-to-bindnode branch from 95f473e to 6d07e9d Compare March 4, 2026 19:48
tqchen added 8 commits March 4, 2026 19:51
The pending_nd_frees_ approach hoisted free_nd calls to the nearest
SeqStmt boundary, which could incorrectly escape conditional branches.
Use ScopeStack instead: register free_nd in the current scope when
Bind allocates via nd_mem_alloc_with_scope, and emit frees on scope
exit. This matches the old LetStmt body semantics structurally.
The old handler wrapped remaining siblings after each individual Bind
node and re-ran VisitStmt, causing O(n^2) complexity for sequences with
many consecutive Bind nodes.

The new hybrid approach batches consecutive trivial Binds (constant or
variable values) and defers the cross-sibling CSE until the batch ends,
reducing the common case to O(n). Non-trivial Binds (whose values may
contain eligible computations) still use the per-Bind wrap pattern to
preserve full CSE effectiveness.
…ph and layout_transformation

BindLetVar and BindVariableDefinition RAII guards erased map entries on
destruction, but flat BindNode has no body -- the guard is destroyed when
the handler returns, making the binding invisible to subsequent sibling
statements.  Under SSA each variable is bound exactly once, so the maps
grow monotonically and cleanup is unnecessary.  Remove the cleanup from
both destructors to fix the bug.

Also add a comment explaining the dead cse_v1 variable in
test_s_tir_transform_inject_ptx_async_copy: CSE extracts (i < 12) before
inject_ptx_async_copy replaces IfThenElse guards with new cast(int32, ...)
expressions for predicated copies, leaving the CSE variable unused.
Bug 1 (inject_virtual_thread.cc): When a Bind in a SeqStmt touches
vt_var, the VT loop must wrap the Bind together with all remaining
siblings (which may reference the bound variable).  Previously, the
Bind handler wrapped only itself, breaking semantics.  Rewrite the
SeqStmt handler to pre-check Bind children and group them with
remaining siblings before wrapping with InjectVTLoop.

Bug 2 (lower_tvm_builtin.cc): MakeNdMemAllocWithScope was returning
without re-visiting via StmtExprMutator::VisitStmt, leaving
tvm_call_packed builtins in both the Bind value and the free_stmt
unlowered.  Re-wrap with VisitStmt and visit free_stmt before pushing
to pending_frees.

Bug 3 (frame.cc): BindFrameNode::ExitWithScope used SeqStmt
constructor (which does not flatten) instead of SeqStmt::Flatten,
creating nested SeqStmts.  Also, when stmts is empty, emit just the
Bind without wrapping in a SeqStmt with a spurious Evaluate(0).
Bind is now a direct statement like Evaluate -- it emits a Bind stmt
to the current frame and returns the Var, with no context manager or
RAII scope needed.

Changes:
- C++ ir_builder: Bind() creates var, calls AddToParent(tir::Bind(...)),
  returns var instead of BindFrame
- Remove BindFrameNode/BindFrame classes from frame.h and frame.cc
- Python ir_builder: Bind() returns Var instead of BindFrame
- Parser: bind_assign_value and visit_ann_assign simplified to call
  T.Bind() directly without frame lifecycle management
- Parser: visit_expr_stmt skips standalone Var results (from T.Bind())
  instead of wrapping them in T.evaluate()
- Remove BindFrame Python class from frame.py
- Update all tests from `with T.Bind(...) as v:` to `v = T.Bind(...)`
@tqchen tqchen closed this Mar 5, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants