[REFACTOR][TIR] Phase out LetStmtNode: migrate to flat BindNode #5
Closed
tqchen wants to merge 36 commits into
00f1d73 to
e35bd58
Compare
Enable opencl target for gpu tests. Consolidates all Adreno tests under tests/python/relax/backend/adreno Changes to CLML corresponding to recent changes on json codegen/runtime. Docker specification for Adreno (ci_gpu + Android SDK, Gradle).
…er (apache#18865)

## Summary

This PR introduces `AllocBufferNode`/`AllocBuffer` as a single TIR statement that both allocates memory and declares a buffer into scope. This replaces the previous pattern of `Allocate(var, dtype, shape, cond, DeclBuffer(buf, body))` with the simpler `AllocBuffer(buf, body)`.

### Main changes

- **New IR node** `AllocBufferNode` with fields `{buffer, annotations, body}` -- same semantics as `DeclBuffer` but also allocates memory
- **TVMScript**: `T.alloc_buffer(shape, dtype, scope)` now emits `AllocBuffer` directly (statement-level allocation). `T.sblock_alloc_buffer(...)` for SBlock-level buffer allocation (full parameter set)
- **All codegen backends** (C, CUDA, Metal, OpenCL, WebGPU, LLVM, NVPTX, AMDGPU, SPIR-V) updated to handle `AllocBufferNode`
- **All TIR transforms** (storage_rewrite, flatten_buffer, vectorize_loop, lower_warp_memory, etc.) updated
- **All S-TIR transforms** (compact_buffer_region, merge_shared_memory, inject_double_buffer, etc.) updated
- **Removed `AllocateNode`** entirely -- `AllocBuffer` is now the sole allocation primitive
- **Removed `AllocDescriptor`** from merge_shared_memory_allocations -- uses `Buffer` objects directly
- **Added `AllocBuffer::ConstantAllocationSize()`** inline helper method

### Design rationale

The old `Allocate + DeclBuffer` pair was a historical artifact: `AllocateNode` stored raw fields (`buffer_var`, `dtype`, `extents`, `condition`) separate from the `Buffer` object, requiring pattern matching (`IsAllocateDeclBufferPattern`) to reconstruct the buffer association. `AllocBuffer` unifies this into a single node with a proper `Buffer` reference, simplifying codegen backends and transform passes.

225 files changed, ~3500 insertions/deletions (net near-zero, mostly mechanical migration).

## Test plan

- [x] All TIR base tests pass
- [x] All TIR transform tests pass
- [x] TVMScript roundtrip tests pass
- [x] S-TIR transform tests pass
- [x] Codegen tests pass
- [x] All-platform minimal tests pass
- [x] C++ functor tests pass
- [x] Pre-commit clean (clang-format, ruff, etc.)
…ctor + analysis infrastructure)
Introduces `BindNode`/`Bind`, a new TIR statement node that binds a variable
to a value with flat (no-body) scope semantics, as the first step of the
LetStmt-to-Bind refactor. Unlike `LetStmtNode`, `BindNode` has no body field;
the bound variable is visible in subsequent siblings of the enclosing SeqStmt.
PR 1 — Core IR node + functor infrastructure:
- `include/tvm/tir/stmt.h`: Define BindNode (var, value, no body) and Bind ref class
- `src/tir/ir/stmt.cc`: Implement Bind constructor, RegisterReflection, GlobalDef
- `include/tvm/tir/stmt_functor.h`: Add VisitStmt_(BindNode*) to StmtFunctor vtable,
StmtVisitor, StmtMutator
- `src/tir/ir/stmt_functor.cc`: Implement StmtVisitor and StmtMutator for BindNode
- `src/tir/ir/py_functor.cc`: Add BindNode dispatch entries for Python functors
- `src/tir/ir/tir_visitor_with_path.{h,cc}`: Add BindNode visitor (visits value only)
PR 2 — Base visitors/mutators + arithmetic + analysis:
- `src/arith/ir_mutator_with_analyzer.{h,cc}`: BindNode handler binds in analyzer
- `src/arith/ir_visitor_with_analyzer.{h,cc}`: BindNode handler visits value + binds
- `src/tir/ir/data_type_rewriter.{h,cc}`: BindNode support in DataTypeLegalizer and
IndexDataTypeRewriter
- `src/tir/analysis/var_use_def_analysis.{h,cc}`: BindNode registers HandleDef
- `src/tir/analysis/verify_ssa.cc`: BindNode calls MarkDef
- `src/tir/analysis/verify_memory.cc`: BindNode books defs_ map
- `src/tir/analysis/control_flow_graph.cc`: BindNode checks UsesLoopVar
LetStmtNode/LetStmt are kept intact (deprecated aliases come in PR 11).
Tests: TIR base (269 passed, 2 skipped), all-platform-minimal (75 passed, 77 skipped)
This commit completes the migration from tree-nested LetStmtNode
(with body field) to flat BindNode (no body) across the entire TVM
codebase. BindNode binds a variable visible to subsequent siblings
in the enclosing SeqStmt scope, replacing the old nested scoping.
Key changes:
- LetStmtNode is now a `using` alias for BindNode
- All C++ LetStmt(var,val,body) constructions -> SeqStmt({Bind(var,val), body})
- All VisitStmt_(LetStmtNode*) handlers -> VisitStmt_(BindNode*)
with op->body access removed (parent SeqStmt handles traversal)
- TIRVisitorWithPath::SeqStmt handler tracks Bind-defined vars
for well-formed verification
- CSE pass: new SeqStmtNode handler + VisitSeqStmtSlice to process
flat Bind sequences (mirrors old nested LetStmt CSE behavior)
- Python: added Bind class, LetStmt = Bind alias
- Updated ~65 C++ files and test files
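The rewrite from `LetStmt(var, val, body)` to `SeqStmt({Bind(var, val), body})` listed above can be illustrated with a toy model. This is a minimal sketch in plain Python, not TVM code: the dataclasses and the `flatten_let` helper are hypothetical stand-ins for the C++ nodes, and expressions are plain strings.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Evaluate:
    expr: str


@dataclass
class LetStmt:  # old nested form: the body is a child of the binding
    var: str
    value: str
    body: "Stmt"


@dataclass
class Bind:  # new flat form: no body field
    var: str
    value: str


@dataclass
class SeqStmt:
    seq: List["Stmt"]


Stmt = Union[Evaluate, LetStmt, Bind, SeqStmt]


def flatten_let(stmt: Stmt) -> List[Stmt]:
    """Return a flat statement list equivalent to `stmt` with no LetStmt left."""
    if isinstance(stmt, LetStmt):
        # The binding becomes a sibling that precedes its former body.
        return [Bind(stmt.var, stmt.value)] + flatten_let(stmt.body)
    if isinstance(stmt, SeqStmt):
        out: List[Stmt] = []
        for child in stmt.seq:
            out.extend(flatten_let(child))
        return out
    return [stmt]


nested = LetStmt("x", "a + 1", LetStmt("y", "x * 2", Evaluate("f(x, y)")))
flat = SeqStmt(flatten_let(nested))
```

Here `flat` holds `Bind x`, `Bind y`, and the `Evaluate` as three siblings. Flattening widens each variable's visibility to all later siblings, so the sketch assumes every variable is bound once (SSA).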
Update TIR passes and tests to work correctly with the flat BindNode model (no body field) where variable scoping is managed via SeqStmt siblings instead of nested tree structure.
Pass fixes:
- RemoveNoOp: Add SeqStmt handler for back-to-front unused Bind scan
- ConvertSSA: Add SeqStmt handler to maintain ScopedRedefine across siblings
- StorageRewrite: Push/pop scope entry in BindNode handler
- HoistExpression: Merge Bind lifecycle management into SeqStmt handler; only set reached_sequential_node for truly sequential (non-Bind) stmts
- SBlockAccessRegionDetector: Defer let_bindings_ erasure to SeqStmt end
- TVMScript printer: Add AsDocBodySeqSlice for scoped T.LetStmt form when printing already-defined-var Binds
- TVMScript parser: Support doc.Attribute in _duplicate_lhs_check
Test updates:
- verify_well_formed: Adjust for flat scope semantics
- convert_ssa: Update for flattened SeqStmt behavior
- tvmscript printer/annotation/syntax_sugar: Update access paths
- loop_partition: Fix pre-existing test with incorrect expected output
- Update comments in var.h and Python functor.py to reference BindNode instead of LetStmtNode
- Apply clang-format fixes to files modified by the BindNode migration
- Remove unused Bind import in functor.py (LetStmt alias is used instead)
- Remove extra blank lines left over from migration in analysis/rewriter files
…FI_UNREACHABLE
- Replace all 8 `__builtin_unreachable()` calls with `TVM_FFI_UNREACHABLE()`:
  src/s_tir/transform/inject_virtual_thread.cc, src/script/printer/relax/distributed.cc,
  src/script/printer/tir/stmt.cc, src/target/source/codegen_c.cc,
  src/tir/transform/vectorize_loop.cc, src/tir/transform/lower_tvm_builtin.cc,
  include/tvm/script/printer/ir_docsifier_functor.h, include/tvm/tir/stmt.h
- Rename `kLetStmt` enum value → `kBind` in HoistedLetBindings (C++ and Python)
  (src/s_tir/transform/hoist_expression.cc, python/tvm/tir/transform/transform.py)
- Rename `LetStmt()` → `Bind()` in script/ir_builder/tir:
  - C++ function in ir.h and ir.cc; keep `LetStmt` as a deprecated inline alias
  - Register `"script.ir_builder.tir.Bind"` as primary; keep `LetStmt` as alias
  - Python ir.py: add `Bind()` as primary function; `LetStmt()` delegates to it
- Update stale `LetStmt` mentions in comments and docstrings to `Bind`:
  src/s_tir/schedule/analysis/reducer.cc, src/s_tir/schedule/primitive/reduction.cc,
  src/s_tir/transform/hoist_expression.cc, src/tir/ir/specialize.cc,
  src/tir/transform/common_subexpr_elim.cc, src/tir/transform/tvm_ffi_binder.h,
  src/tir/transform/ir_utils.cc, src/te/operation/create_primfunc.cc,
  include/tvm/tir/stmt.h, python/tvm/tir/stmt.py, python/tvm/tir/functor.py
- Clean up `src/script/printer/tir/utils.h`: remove `AsDocBodySeqSlice` helper that used `TIR(d, "LetStmt")` scoped form; inline loop directly in `AsDocBody` (Bind is flat-assignment, no scoped form needed)
… semantics
Replace the recursive VisitSeqStmtSlice helper with an iterative SeqStmt handler that processes children directly: Bind nodes augment the context and trigger cross-sibling CSE on remaining siblings, while non-Bind nodes are processed individually.

…Bind
Remove the custom SeqStmt handler that maintained ScopedRedefine entries for Bind nodes. Instead, the simplified BindNode handler adds persistent remappings via function_scope_var_remap_ directly, which don't need scoped cleanup. The default StmtMutator processes SeqStmt children sequentially, so remappings from Bind nodes are automatically visible to subsequent siblings.

… flat Bind
Simplify the SeqStmt handler to only perform sequential detection (counting non-Bind statements) and delegate visitation to the parent. Remove the Bind-var lifecycle management (tracking and erasing from let_var_to_loop_vars/let_var_to_let_vars maps at sequence boundaries). Bind vars now persist in the tracking maps for the duration of the HoistInfoCollector instance.

…t handler for flat Bind
Remove the custom SeqStmt handler that tracked and erased Bind-defined let_bindings_ at sequence boundaries. The BindNode handler now just adds to let_bindings_ and relies on the BlockReadWriteDetector instance scope for cleanup. The default StmtVisitor processes SeqStmt children sequentially, so bindings are visible to subsequent siblings.

…eqStmt
Remove the custom SeqStmt handler and dead-Bind-variable backward scan from remove_no_op. The VisitStmt_(BindNode*) handler now simply mutates the value and returns. Unused Bind elimination can be added back later via a separate two-pass approach.
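For reference, the kind of back-to-front scan for unused Binds that this commit removes can be sketched on a toy encoding. This is not the removed TVM code: the tuple encoding, `free_vars`, and `drop_unused_binds` are hypothetical, expressions are strings, and the sketch assumes Bind values are free of side effects.

```python
import re
from typing import List, Set, Tuple

# Toy encoding: ("bind", var, value_expr) or ("eval", expr).
Stmt = Tuple[str, ...]


def free_vars(expr: str) -> Set[str]:
    """Collect identifier-like tokens from a string expression."""
    return set(re.findall(r"[A-Za-z_]\w*", expr))


def drop_unused_binds(seq: List[Stmt]) -> List[Stmt]:
    used: Set[str] = set()
    kept: List[Stmt] = []
    # Walk from the last sibling to the first so that every use of a
    # variable has been seen before its Bind is reached.
    for stmt in reversed(seq):
        if stmt[0] == "bind":
            _, var, value = stmt
            if var not in used:
                continue  # no later sibling reads var: drop the Bind
            used |= free_vars(value)
        else:
            used |= free_vars(stmt[1])
        kept.append(stmt)
    kept.reverse()
    return kept


seq = [
    ("bind", "a", "n + 1"),
    ("bind", "b", "a * 2"),  # only read by c, which is itself unused
    ("bind", "c", "b + 3"),  # never read
    ("eval", "f(a)"),
]
result = drop_unused_binds(seq)
```

A single backward pass removes chains of dead Binds (`c`, then `b`) because a dropped Bind never adds its own operands to the used set.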
…hScope
Remove the custom SeqStmt handler that captured remaining siblings as body for nd_mem_alloc_with_scope processing. MakeNdMemAllocWithScope now rewrites the Bind value inline (lowering to tvm_call_packed) and adds a null check, without body capture.

…roundtrip tests
Remove opt_gemm_mod_host and let_stmt_value test functions from test_tvmscript_roundtrip.py. Both use non-SSA re-binds (with T.LetStmt var= pattern) that cannot roundtrip with flat Bind semantics.
With flat Bind there is no body to inspect for usage patterns, so Bind inlining (removing the Bind and substituting its value) is disabled. The analyzer still records variable bindings for constraint proving, but the Bind statement is always kept. Remove the CollectVarsUsedInBufferDefinition utility and used_in_buffer_def_ tracking which were only needed for the inlining codepath. Update tests to reflect that Binds are no longer eliminated.
…ls, and tests
Update hoist_expression to manage Bind lifecycle in SeqStmt, fix IRConvertSSA to handle Bind redefinitions across SeqStmt siblings, and update test expectations for flat Bind semantics.

…defs
Use ScopeStack to manage Bind variable definitions. Body-carrying statements (For, IfThenElse, Allocate, DeclBuffer, AttrStmt, While, SBlock) push a new scope; BindNode pushes its WithDef into the current scope. When the scope exits all Bind defs are cleaned up automatically, removing the need for custom SeqStmt handling.
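The scoping rule described here (body-carrying statements open a scope, a Bind defines into the current one, scope exit discards its definitions) can be sketched as follows. This is a toy Python analogue, not the C++ `ScopeStack`: the class and its method names are hypothetical.

```python
from contextlib import contextmanager
from typing import Iterator, List, Set


class ScopeStack:
    """Toy stand-in: one set of defined vars per open body-carrying statement."""

    def __init__(self) -> None:
        self._levels: List[Set[str]] = [set()]

    @contextmanager
    def with_new_scope(self) -> Iterator[None]:
        self._levels.append(set())
        try:
            yield
        finally:
            self._levels.pop()  # all Bind defs of this scope vanish together

    def define(self, var: str) -> None:
        # A Bind registers its variable in the innermost open scope.
        self._levels[-1].add(var)

    def is_defined(self, var: str) -> bool:
        return any(var in level for level in self._levels)


scopes = ScopeStack()
scopes.define("n")             # Bind at the top level of the function body
with scopes.with_new_scope():  # entering, for example, a For body
    scopes.define("x")         # Bind inside the loop body
    inside = (scopes.is_defined("n"), scopes.is_defined("x"))
after = (scopes.is_defined("n"), scopes.is_defined("x"))
```

Inside the loop body both variables are visible. After the scope exits only `n` remains, with no per-SeqStmt bookkeeping.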
…flat Bind
- remove_store_undef: eagerly check buffer indices for undef in the locator phase (flat Bind means the undef Bind is a sibling, not an ancestor, so post-validation alone cannot catch it). Also remove Bind nodes whose value contains undef alongside the removed stores.
- inject_ptx_async_copy test: update expected CUDA to reflect that analyzer->Bind substitutes the variable with its value.

…nagement
Restore the original SeqStmt handler logic that tracks Bind vars defined in a sequence and erases them from let_var_to_loop_vars/let_var_to_let_vars maps when the sequence ends. Keep the refactor simple per user feedback.
The migration from LetStmt to Bind is complete. Remove all backward-compatibility aliases and deprecated wrappers:
- Remove `using LetStmtNode = BindNode` and `using LetStmt = Bind` from include/tvm/tir/stmt.h
- Remove `LetStmt()` wrapper and `LegacyLetStmt()` from C++ and Python script ir_builder
- Remove `tir.LetStmt` FFI factory from stmt.cc
- Remove `LetStmt = Bind` alias from python/tvm/tir/stmt.py
- Rename `visit_let_stmt_` to `visit_bind_` in Python functor metadata and method names, matching the C++ `f_visit_bind` field
- Rename `f_visit_let_stmt` parameters in py_functor.cc to `f_visit_bind`
- Update all test files: T.LetStmt -> T.Bind, comments, function names
…anagement
Replace the ScopedRedefine RAII struct and custom SeqStmt handler with ScopeStack<ScopeLevel> for cleaner scope management:
- Body-carrying statements (For, IfThenElse, AttrStmt, DeclBuffer, While, Allocate, SBlock) push a new scope via scope_.WithNewScope()
- Bind pushes var remaps to the current scope level, persisting across SeqStmt siblings
- Scope exit automatically pops all remaps via ScopeLevel destructor
- Remove the custom VisitStmt_(SeqStmtNode*) -- default sequential iteration works because Bind remaps persist in the enclosing scope
- Add IfThenElse handler with separate scopes per branch to prevent remap leakage between then/else cases
…ement
Replace manual save/restore of context_ in the Common Subexpression Elimination pass with ScopeStack-based automatic scope management.
Key changes:
- Add ScopeStack<ContextScopeLevel> where each scope level records the context size on entry and truncates it back on exit via destructor
- ForNode, LetNode: WithNewScope replaces manual context save/restore
- New scope-boundary overrides for IfThenElse, AttrStmt, Allocate, DeclBuffer, While to prevent context leaks across scope boundaries
- SeqStmtNode: remove manual context save/restore (enclosing scope handles cleanup), retain wrap-remaining-siblings pattern for cross-sibling CSE after Bind nodes
- BindNode: entries persist across SeqStmt siblings, cleaned up automatically when enclosing body-carrying statement's scope exits
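The "record the context size on entry, truncate on exit" idea can be shown with a small Python context manager. This is a sketch of the technique only: `ContextScope` and the list-of-pairs context are hypothetical, and the real pass stores TIR variables and expressions rather than strings.

```python
from typing import List, Tuple


class ContextScope:
    """Records len(context) on entry and truncates back to it on exit."""

    def __init__(self, context: List[Tuple[str, str]]) -> None:
        self._context = context
        self._size_on_entry = 0

    def __enter__(self) -> "ContextScope":
        self._size_on_entry = len(self._context)
        return self

    def __exit__(self, *exc) -> None:
        # Drop every entry added since the scope was entered.
        del self._context[self._size_on_entry:]


context: List[Tuple[str, str]] = []
context.append(("n", "m + 1"))          # Bind in the enclosing scope
with ContextScope(context):             # e.g. entering a For body
    context.append(("x", "n * 2"))      # Binds inside the body persist
    context.append(("y", "x + n"))      # across siblings of that body
    depth_inside = len(context)
```

Because truncation restores the exact prefix, entries added by Binds stay visible to all later siblings in the same body and disappear together when the body ends.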
The LetStmt-to-Bind migration dropped the free_nd call that was previously wrapped after the LetStmt body, causing a memory leak for nd allocations (Hexagon VTCM, Adreno textures). With flat Bind semantics, the free is pushed to a pending_nd_frees_ vector and appended at the end of the enclosing SeqStmt by a new VisitStmt_(SeqStmtNode*) override.
The LetStmt-to-Bind refactor accidentally duplicated the LOG(WARNING) call in IRDocsifierFunctor::operator(). Remove the extra one.
Rename test fixtures and functions that still use "letstmt" to "bind"
to match the LetStmt-to-Bind refactor:
- argmax_split_letstmt_{fewer,more}_than_init -> argmax_split_bind_*
- test_letstmt_bufferload_without_type_annotation -> test_bind_*
- test_letstmt_bind_with_constant -> test_bind_with_constant
Rename LetFrameNode/LetFrame to BindFrameNode/BindFrame across C++ headers, implementation, and Python bindings to align with the LetStmt-to-Bind refactor. Updates FFI type key from "script.ir_builder.tir.LetFrame" to "script.ir_builder.tir.BindFrame".
…base
Resolve remaining AllocateNode references that should be AllocBufferNode after rebasing onto the AllocBuffer commit. Also add TVM_FFI_UNREACHABLE after throw in blockize_tensorize.
95f473e to
6d07e9d
Compare
The pending_nd_frees_ approach hoisted free_nd calls to the nearest SeqStmt boundary, which could incorrectly escape conditional branches. Use ScopeStack instead: register free_nd in the current scope when Bind allocates via nd_mem_alloc_with_scope, and emit frees on scope exit. This matches the old LetStmt body semantics structurally.
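A toy illustration of emitting frees at scope exit rather than at the nearest SeqStmt boundary: each call to the hypothetical `lower_scope` below stands for one scope (for example one branch of an IfThenElse), with statements encoded as strings. The reverse free order is an assumption of the sketch, not something the commit states.

```python
from typing import List


def lower_scope(body: List[str]) -> List[str]:
    """Lower one scope: copy statements, then emit the frees registered in it."""
    pending_frees: List[str] = []  # belongs to this scope only
    out: List[str] = []
    for stmt in body:
        out.append(stmt)
        # Toy convention: "bind <var> = nd_alloc" marks an nd allocation.
        if stmt.startswith("bind ") and stmt.endswith("= nd_alloc"):
            pending_frees.append(stmt.split()[1])
    # Scope exit: free in reverse allocation order (an assumption of this sketch).
    out.extend(f"free_nd({var})" for var in reversed(pending_frees))
    return out


then_branch = lower_scope(["bind buf = nd_alloc", "use(buf)"])
else_branch = lower_scope(["other()"])
```

The free for `buf` stays inside the branch that allocated it, so the other branch never frees a buffer it did not allocate.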
The old handler wrapped remaining siblings after each individual Bind node and re-ran VisitStmt, causing O(n^2) complexity for sequences with many consecutive Bind nodes. The new hybrid approach batches consecutive trivial Binds (constant or variable values) and defers the cross-sibling CSE until the batch ends, reducing the common case to O(n). Non-trivial Binds (whose values may contain eligible computations) still use the per-Bind wrap pattern to preserve full CSE effectiveness.
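The effect of batching can be seen by counting how many cross-sibling passes a run of Binds triggers. The following is a rough model, not the pass itself: `is_trivial` and `count_cross_sibling_passes` are hypothetical, values are strings, and each counted pass stands for one re-visit of the remaining siblings.

```python
from typing import List, Tuple

Bind = Tuple[str, str]  # (var, value expression)


def is_trivial(value: str) -> bool:
    """Constants and bare variables contain nothing for CSE to extract."""
    return value.isidentifier() or value.lstrip("-").isdigit()


def count_cross_sibling_passes(binds: List[Bind]) -> int:
    passes = 0
    in_trivial_batch = False
    for _, value in binds:
        if is_trivial(value):
            in_trivial_batch = True  # defer: just extend the current batch
            continue
        if in_trivial_batch:
            passes += 1  # batch ends: one deferred pass for the whole batch
            in_trivial_batch = False
        passes += 1  # non-trivial Bind keeps the per-Bind wrap
    if in_trivial_batch:
        passes += 1
    return passes


many_trivial = [(f"v{i}", str(i)) for i in range(100)]
mixed = [("a", "1"), ("b", "a"), ("c", "x + y"), ("d", "2")]
```

A hundred consecutive trivial Binds cost one pass instead of one hundred. Since each pass visits the remaining siblings, that is the difference between linear and quadratic work for such sequences.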
…ph and layout_transformation
BindLetVar and BindVariableDefinition RAII guards erased map entries on destruction, but flat BindNode has no body -- the guard is destroyed when the handler returns, making the binding invisible to subsequent sibling statements. Under SSA each variable is bound exactly once, so the maps grow monotonically and cleanup is unnecessary. Remove the cleanup from both destructors to fix the bug.
Also add a comment explaining the dead cse_v1 variable in test_s_tir_transform_inject_ptx_async_copy: CSE extracts (i < 12) before inject_ptx_async_copy replaces IfThenElse guards with new cast(int32, ...) expressions for predicated copies, leaving the CSE variable unused.
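The guard bug can be reproduced in a few lines. This sketch uses a Python context manager in place of the C++ RAII guards. `ErasingGuard`, `visit`, and the tuple encoding are hypothetical.

```python
from typing import Dict, List, Tuple


class ErasingGuard:
    """Mimics an RAII guard that removes its map entry when the handler returns."""

    def __init__(self, bindings: Dict[str, str], var: str, value: str) -> None:
        self._bindings, self._var = bindings, var
        bindings[var] = value

    def __enter__(self) -> "ErasingGuard":
        return self

    def __exit__(self, *exc) -> None:
        self._bindings.pop(self._var, None)


def visit(seq: List[Tuple[str, str, str]], erase_on_exit: bool) -> List[bool]:
    """Return, for each ("use", var, _) sibling, whether var was still bound."""
    bindings: Dict[str, str] = {}
    seen: List[bool] = []
    for kind, name, value in seq:
        if kind == "bind":
            if erase_on_exit:
                with ErasingGuard(bindings, name, value):
                    pass  # flat Bind has no body to visit inside the guard
            else:
                bindings[name] = value  # entry persists for later siblings
        else:
            seen.append(name in bindings)
    return seen


seq = [("bind", "x", "n + 1"), ("use", "x", "")]
buggy = visit(seq, erase_on_exit=True)
fixed = visit(seq, erase_on_exit=False)
```

With the erasing guard, the sibling that uses `x` no longer finds the binding. Without the cleanup it does.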
Bug 1 (inject_virtual_thread.cc): When a Bind in a SeqStmt touches vt_var, the VT loop must wrap the Bind together with all remaining siblings (which may reference the bound variable). Previously, the Bind handler wrapped only itself, breaking semantics. Rewrite the SeqStmt handler to pre-check Bind children and group them with remaining siblings before wrapping with InjectVTLoop.

Bug 2 (lower_tvm_builtin.cc): MakeNdMemAllocWithScope was returning without re-visiting via StmtExprMutator::VisitStmt, leaving tvm_call_packed builtins in both the Bind value and the free_stmt unlowered. Re-wrap with VisitStmt and visit free_stmt before pushing to pending_frees.

Bug 3 (frame.cc): BindFrameNode::ExitWithScope used SeqStmt constructor (which does not flatten) instead of SeqStmt::Flatten, creating nested SeqStmts. Also, when stmts is empty, emit just the Bind without wrapping in a SeqStmt with a spurious Evaluate(0).
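The difference behind Bug 3 between constructing a sequence directly and flattening it can be sketched with toy nodes. The `flatten` helper below is a hypothetical Python analogue of the behaviour described for `SeqStmt::Flatten`. Returning a lone statement unwrapped mirrors the "emit just the Bind" fix and is otherwise an assumption of the sketch.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Bind:
    var: str
    value: str


@dataclass
class Evaluate:
    expr: str


@dataclass
class SeqStmt:
    seq: List["Stmt"]


Stmt = Union[Bind, Evaluate, SeqStmt]


def flatten(*stmts: Stmt) -> Stmt:
    """Splice nested sequences into one level; return a lone statement unwrapped."""
    out: List[Stmt] = []

    def collect(stmt: Stmt) -> None:
        if isinstance(stmt, SeqStmt):
            for child in stmt.seq:
                collect(child)
        else:
            out.append(stmt)

    for stmt in stmts:
        collect(stmt)
    if not out:
        raise ValueError("flatten needs at least one statement")
    return out[0] if len(out) == 1 else SeqStmt(out)


body = SeqStmt([Evaluate("a"), Evaluate("b")])
nested = SeqStmt([Bind("x", "1"), body])  # direct construction keeps the nesting
flat = flatten(Bind("x", "1"), body)
alone = flatten(Bind("x", "1"))
```

`nested` has two children, one of which is itself a sequence, while `flat` has three siblings at one level. `alone` is just the Bind.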
Bind is now a direct statement like Evaluate -- it emits a Bind stmt to the current frame and returns the Var, with no context manager or RAII scope needed.
Changes:
- C++ ir_builder: Bind() creates var, calls AddToParent(tir::Bind(...)), returns var instead of BindFrame
- Remove BindFrameNode/BindFrame classes from frame.h and frame.cc
- Python ir_builder: Bind() returns Var instead of BindFrame
- Parser: bind_assign_value and visit_ann_assign simplified to call T.Bind() directly without frame lifecycle management
- Parser: visit_expr_stmt skips standalone Var results (from T.Bind()) instead of wrapping them in T.evaluate()
- Remove BindFrame Python class from frame.py
- Update all tests from `with T.Bind(...) as v:` to `v = T.Bind(...)`
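The "emit to the current frame and return the Var" style can be mimicked with a small stand-alone builder. This is a hypothetical sketch, not the TVM `ir_builder` API: `ToyBuilder`, its methods, and the generated variable names are all invented for illustration.

```python
from typing import List, Tuple


class ToyBuilder:
    """Hypothetical mini IR builder; statements are appended to a single frame."""

    def __init__(self) -> None:
        self.frame: List[Tuple[str, ...]] = []
        self._counter = 0

    def bind(self, value: str) -> str:
        # Emit the Bind immediately and hand the variable back to the caller,
        # so no `with` block or frame object is needed.
        var = f"v{self._counter}"
        self._counter += 1
        self.frame.append(("Bind", var, value))
        return var

    def evaluate(self, expr: str) -> None:
        self.frame.append(("Evaluate", expr))


b = ToyBuilder()
x = b.bind("a + 1")
y = b.bind(f"{x} * 2")
b.evaluate(f"f({x}, {y})")
```

The three statements land in the frame as siblings in program order, which matches the flat scoping of Bind.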
Summary
This PR refactors TVM TIR to replace the nested LetStmtNode (with body field) with a flat BindNode that has no body. Variable scope is determined by the enclosing body-carrying statement rather than nesting.
Commits
Test plan