[REFACTOR][TIR] Tie AnnotateDeviceRegions/SplitHostDevice/LowerDeviceKernelLaunch together#19605
Conversation
There was a problem hiding this comment.
Code Review
This pull request reorganizes the order of compilation passes in the TVM pipelines, specifically moving MergeSharedMemoryAllocations before AnnotateDeviceRegions and SplitHostDevice, and shifting LowerDeviceKernelLaunch to run before MakePackedAPI. The reviewer identified a critical scoping and correctness bug introduced by moving MergeSharedMemoryAllocations before SplitHostDevice. For PrimFuncs containing multiple device regions, this change causes shared memory allocations to be merged globally but only allocated in the first device region, leading to undefined variables and compilation failures once the host and device functions are split.
|
@tvm-bot run |
|
Failed to re-run CI in https://github.com/apache/tvm/actions/runs/26423763866 Detailswith response |
6720e81 to
848feba
Compare
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request refactors the MergeSharedMemoryAllocations pass to use a scope-stack design (KernelScope), which correctly handles multiple sibling thread_extent blocks within a single PrimFunc and prevents undefined buffer references. Additionally, the pass order in the compilation pipelines is adjusted, and corresponding tests are added. A critical issue was identified in the C++ implementation where an iterator is accessed after being erased from a map, which leads to undefined behavior.
… iterator scope.const_free_map.erase(it) invalidates it; the subsequent it->second dereference is undefined behavior. Capture the StorageEntry* into e before the erase and use e afterward. Flagged by Gemini reviewer on PR apache#19605.
A PrimFunc with multiple sibling thread_extent blocks (e.g. coming out of a multi-kernel Relax lowering) violates scoping in the current MergeSharedMemoryAllocations: the merged buffer is allocated only inside the first thread_extent body, but later thread_extents' accesses are rewritten to reference it. SplitHostDevice then emits device functions that read an undefined var. Convert every per-launch field into a KernelScope struct held on a stack. Push a new scope on the outermost thread_extent entry, collect/plan/rewrite/wrap inside that scope, pop on exit. Each kernel launch ends up with its own merged buffer, in scope only for its own subtree, preserving LowerDeviceKernelLaunch's "at most one dyn-shmem allocation per kernel" invariant. Adds a regression test exercising two sibling thread_extent blocks with independent shared-memory allocations.
…KernelLaunch together These three passes are logically a single host/device split step; having intermediaries between them obscures the model and blocks folding them into one pass. This PR moves each intermediary to the position its actual ordering constraint allows, so that AnnotateDeviceRegions, SplitHostDevice, and LowerDeviceKernelLaunch run consecutively in every pipeline. - MergeSharedMemoryAllocations moves before AnnotateDeviceRegions (the only legal position: LowerDeviceKernelLaunch requires at most one dyn-shmem allocation per kernel). - MakePackedAPI moves after LowerDeviceKernelLaunch (Lower's calling_conv flag causes MakePackedAPI to correctly skip device kernels; host body's lowered tvm_call_packed is transparent to MakePackedAPI's subroutine rewriter). - FP8StorageLegalize/BF16StorageLegalize move after MakePackedAPI (their buffer_map.size()==0 ICHECK requires MakePackedAPI to have cleared the map). Prereq for Phase 2: collapsing the three into a single tirx.transform.SplitHostDevice with three commented regions.
Address Gemini perf review on the merge-shmem refactor: the scope-stack refactor lost the original fast-paths that skipped liveness/planning/rewriting when there are 0 or 1 shmem allocations to merge. Restore both: - Per-scope: a thread_extent block with ≤1 shmem alloc skips the per-scope merging machinery. - Function-level: a PrimFunc with ≤1 shmem alloc of the relevant kind short-circuits the entire rewriter invocation. Behavior is unchanged; this is purely performance.
… iterator scope.const_free_map.erase(it) invalidates it; the subsequent it->second dereference is undefined behavior. Capture the StorageEntry* into e before the erase and use e afterward. Flagged by Gemini reviewer on PR apache#19605.
…I host targets
In the pipeline order that places LowerDeviceKernelLaunch before
MakePackedAPI, the host PrimFunc still carries Target("cuda", host="llvm")
when Lower visits it. The same-target shortcut at the call site
compared caller->WithoutHost() against the device kernel's target,
which produced cuda == cuda and silently skipped both the host-call
rewriting and the kernel-attribute assignment. The kernel was then
emitted with the default calling convention, which CUDA codegen lowers
as __device__ __launch_bounds__, rejected by nvcc.
The shortcut is meant for intra-device subroutine calls between
device-resident functions, not for a host caller whose target happens
to share a string with the kernel after WithoutHost(). Track whether
the current caller is a host function (its kTarget has a host attached)
and skip the same-target / same-device-type shortcuts when a host
caller invokes a real device kernel (callee has non-empty
launch_params). Pure intra-device subroutine calls and host-side
extern subroutine calls are unaffected.
This makes Lower order-independent with respect to MakePackedAPI's
host-target rewrite and is a prerequisite for keeping
AnnotateDeviceRegions/SplitHostDevice/LowerDeviceKernelLaunch
consecutive in every pipeline.
…erDeviceKernelLaunch
When SplitHostDevice emits a host-side helper (e.g. an "add_host" with
target "c") for a private subroutine that is called from both host and
device contexts, the host caller still carries its full
"cuda+host=c" target at the time LowerDeviceKernelLaunch runs in the
new pipeline order. The same-target / same-device-type comparisons
used the caller's WithoutHost() target ("cuda") against the callee's
host target ("c"), making them appear cross-device and falling through
to the kernel-launch path. UpdateKernelAttributes then ran
ReturnRemover on the host helper's body, which contains a real
`T.ret(a+b)`, tripping the ICHECK that "device kernel may only
contain T.ret(0)".
The previous robustification only suppressed the same-target shortcut
when the caller is host AND the callee is a real kernel
(launch_params non-empty). It did not address the symmetric case of
a host caller invoking another host helper across host targets.
Capture the caller's host target separately and use it (in place of
the WithoutHost() device target) when the caller is a host function.
Host-to-host calls now correctly compare host targets and route to
the same-target shortcut or call_extern, never to the kernel-launch
ABI. Host-to-kernel calls remain forced through kernel launch.
Pure intra-device subroutine calls (callee_target on the device,
caller without host attached) are unaffected.
Verified with tests/python/codegen/test_target_codegen_cuda.py::
test_device_host_call_same_func[nvcc,nvrtc] (previously failing on
this PR, passing on upstream/main).
beab349 to
e02859a
Compare
Summary
These three passes are logically a single host/device split step;
having intermediaries between them obscures the model and blocks
folding them into one pass. This PR moves each intermediary to the
position its actual ordering constraint allows, so that
AnnotateDeviceRegions,SplitHostDevice, andLowerDeviceKernelLaunchrun consecutively in every pipeline.Rationale
MergeSharedMemoryAllocationsmoves beforeAnnotateDeviceRegions(the only legal position:
LowerDeviceKernelLaunchrequires at mostone dyn-shmem allocation per kernel, so Merge cannot move past Lower).
MakePackedAPImoves afterLowerDeviceKernelLaunch(Lower'skCallingConv = kDeviceKernelLaunchflag causesMakePackedAPItocorrectly skip device kernels; the host body's lowered
tvm_call_packedis transparent toMakePackedAPI's subroutinerewriter).
FP8StorageLegalize/BF16StorageLegalizemove afterMakePackedAPI(theirbuffer_map.size()==0ICHECK requiresMakePackedAPIto have cleared the map).Prereq for Phase 2: collapsing the three consecutive passes into a
single
tirx.transform.SplitHostDevicewith three commented regions.Test plan
test_tir_transform_bf16_legalize.py (13 pass)
test_target_codegen_device.py (6 pass including
test_subroutine_call — verifies Risk Start working on Tensor Infer #2)