From b8777b7d1b54b0586f243db74c75ea139307e89a Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 11:39:29 +0000 Subject: [PATCH 1/6] [REFACTOR][S-TIR] Merge shared memory per thread_extent scope A PrimFunc with multiple sibling thread_extent blocks (e.g. coming out of a multi-kernel Relax lowering) violates scoping in the current MergeSharedMemoryAllocations: the merged buffer is allocated only inside the first thread_extent body, but later thread_extents' accesses are rewritten to reference it. SplitHostDevice then emits device functions that read an undefined var. Convert every per-launch field into a KernelScope struct held on a stack. Push a new scope on the outermost thread_extent entry, collect/plan/rewrite/wrap inside that scope, pop on exit. Each kernel launch ends up with its own merged buffer, in scope only for its own subtree, preserving LowerDeviceKernelLaunch's "at most one dyn-shmem allocation per kernel" invariant. Adds a regression test exercising two sibling thread_extent blocks with independent shared-memory allocations. --- .../merge_shared_memory_allocations.cc | 427 ++++++++++-------- ...merge_dynamic_shared_memory_allocations.py | 95 +++- 2 files changed, 334 insertions(+), 188 deletions(-) diff --git a/src/s_tir/transform/merge_shared_memory_allocations.cc b/src/s_tir/transform/merge_shared_memory_allocations.cc index c680eb38ac71..3c5b7cd0782d 100644 --- a/src/s_tir/transform/merge_shared_memory_allocations.cc +++ b/src/s_tir/transform/merge_shared_memory_allocations.cc @@ -77,24 +77,26 @@ static int64_t ConstantAllocationSize(const ffi::Array& extents) { } /*! - * \brief collect the mapping from the buffer var to its Buffer + * \brief collect the mapping from the buffer var to its Buffer within a subtree */ class AllocateCollector : public StmtExprVisitor { public: + explicit AllocateCollector(bool is_dynamic) : is_dynamic_(is_dynamic) {} + void VisitStmt_(const AllocBufferNode* op) final { - if (IsDynamicSharedMemory(op->buffer->data) || IsStaticSharedMemory(op->buffer->data)) { - if (IsDynamicSharedMemory(op->buffer->data)) { - dyn_shmem_allocs_[op->buffer->data.get()] = op->buffer; - } else { - static_shmem_allocs_[op->buffer->data.get()] = op->buffer; - } + if (is_dynamic_ && IsDynamicSharedMemory(op->buffer->data)) { + shmem_allocs_[op->buffer->data.get()] = op->buffer; + } else if (!is_dynamic_ && IsStaticSharedMemory(op->buffer->data)) { + shmem_allocs_[op->buffer->data.get()] = op->buffer; } StmtExprVisitor::VisitStmt_(op); } - // The dynamic mapping from the original buffer var to its Buffer - std::unordered_map dyn_shmem_allocs_; - // The static mapping from the original buffer var to its Buffer - std::unordered_map static_shmem_allocs_; + + // The mapping from the original buffer var to its Buffer + std::unordered_map shmem_allocs_; + + private: + bool is_dynamic_; }; // Find a linear pattern of storage access @@ -274,89 +276,122 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { /*! * \brief merge the buffers whose live range has no intersection and rewrite the body + * + * Uses a scope-stack design: each thread_extent block (kernel launch) gets its + * own KernelScope that owns the merged buffer var and all per-launch bookkeeping. + * This correctly handles PrimFuncs with multiple sibling thread_extent blocks. */ class SharedMemoryRewriter : public StmtExprMutator { public: - explicit SharedMemoryRewriter(const std::unordered_map& shmem_allocs, - bool is_dynamic = true) - : is_dynamic_{is_dynamic}, shmem_allocs_{shmem_allocs} { - if (!is_dynamic) { - merged_buf_var_ = Var("buf_shmem", PointerType(PrimType(DataType::UInt(8)), "shared")); - } - } + explicit SharedMemoryRewriter(bool is_dynamic = true) : is_dynamic_{is_dynamic} {} + + private: + using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry; + + struct StorageEntry { + // The constant size of the buffer in bits, only used if it is constant + uint64_t const_nbits{0}; + // Allocs that shares this entry. + // The inner vector means a "layer" + // For example, it we need to allocate C in the memory of A and B: + // | A: 4096 bytes | B: 4096 bytes | + // | C: 8192 bytes | + // Then the allocs = {{A, B}, {C}} + std::vector> allocs; + }; + + // Event entry in liveness analysis + struct EventEntry { + // variables we generate + std::vector gen; + // variables we kill + std::vector kill; + }; + + /*! + * \brief Per-kernel-launch scope holding all state for one thread_extent block. + */ + struct KernelScope { + // The merged buffer var for THIS kernel launch. + Var merged_buf_var; + // Total byte size of THIS kernel's merged buffer. + PrimExpr merged_alloc_size{0}; + // Allocations from THIS kernel's subtree. + std::unordered_map shmem_allocs; + // Per-buffer byte offset into merged_buf_var. + std::unordered_map buffer_byte_offsets; + // Buffer-object remap: original Buffer -> merged-data-var Buffer. + std::unordered_map buffer_remap; + // Has any original alloc in this scope been marked volatile? + bool has_volatile_alloc{false}; + // Liveness data (event_map, alloc_map, const_free_map, sym_free_list) — all per-scope. + std::unordered_map event_map; + std::multimap const_free_map; + std::list sym_free_list; + std::unordered_map alloc_map; + }; /*! - * \brief plan the memory reuse for all the buffer allocated in the statement - * \param stmt the statement + * \brief Create a fresh merged buffer Var for a new kernel scope. + * Same name string is fine — Var identity is by pointer, not name. */ - void PlanReuse(const Stmt& stmt, bool is_dynamic = true) { - SharedMemLinearAccessPatternFinder finder(is_dynamic); - finder(stmt); - this->LivenessAnalysis(finder.linear_seq_); - this->PlanMemory(finder.linear_seq_); + Var MakeMergedBufferVar() { + if (is_dynamic_) { + return Var("buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")); + } else { + return Var("buf_shmem", PointerType(PrimType(DataType::UInt(8)), "shared")); + } } - private: Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tirx::attr::thread_extent && !allocated_) { - // Allocate one dynamic shared memory allocation at the beginning of thread scope - int max_layer_num = 0; - std::vector all_entry; - for (const auto& e : const_free_map_) { - all_entry.push_back(e.second); - } - for (const StorageEntry* e : sym_free_list_) { - all_entry.push_back(e); - } - for (const StorageEntry* e : all_entry) { - max_layer_num = std::max(max_layer_num, static_cast(e->allocs.size())); - } - // calculate align for each layer of each storage entry. - std::vector align(max_layer_num, 0); - for (const StorageEntry* e : all_entry) { - for (int i = 0; i < static_cast(e->allocs.size()); i++) { - for (const VarNode* buffer : e->allocs[i]) { - const Buffer& buf = shmem_allocs_.at(buffer); - align[i] = std::max(align[i], buf->dtype.bytes()); - } - } - } - // calculate offset for each buffer based on the align of each layer - for (const StorageEntry* e : all_entry) { - PrimExpr max_inner_offset = 0; - for (int i = 0; i < static_cast(e->allocs.size()); i++) { - PrimExpr inner_offset = 0; - for (const VarNode* buffer : e->allocs[i]) { - const Buffer& buf = shmem_allocs_.at(buffer); - ffi::Array alloc_shape = GetBufferAllocationShape(buf); - int align_bytes = std::max(align[i], buf->dtype.bytes()); - if (buf->data_alignment > 0) { - TVM_FFI_ICHECK(buf->data_alignment % align_bytes == 0) - << "The alignment of the buffer is not a multiple of the data type size."; - align_bytes = buf->data_alignment; - } - PrimExpr buffer_bytes = alloc_shape[0] * buf->dtype.bytes(); - inner_offset += - indexmod(align_bytes - indexmod(merged_alloc_size_ + inner_offset, align_bytes), - align_bytes); - buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset; - inner_offset += buffer_bytes; - } - max_inner_offset = max(max_inner_offset, inner_offset); - } - merged_alloc_size_ += max_inner_offset; - } + if (op->attr_key == tirx::attr::thread_extent && !in_thread_env_) { + in_thread_env_ = true; + + // 1. Push a fresh scope. + scope_stack_.emplace_back(); + KernelScope& scope = scope_stack_.back(); + scope.merged_buf_var = MakeMergedBufferVar(); + + // 2. Collect shmem allocs that belong to THIS subtree. + AllocateCollector collector(is_dynamic_); + collector(op->body); + scope.shmem_allocs = std::move(collector.shmem_allocs_); + + // 3. Liveness + reuse plan over this subtree only. + // Run the finder on the full AttrStmt (not just op->body) so that + // VisitNewScope creates the proper scope pair entry for the thread_extent. + SharedMemLinearAccessPatternFinder finder(is_dynamic_); + finder(ffi::GetRef(op)); + this->LivenessAnalysis(finder.linear_seq_, scope); + this->PlanMemory(finder.linear_seq_, scope); - allocated_ = true; - Buffer merged_buf(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, {}, PrimExpr(), - merged_buf_var_->name_hint, 0, 0, BufferType::kDefault); + // 4. Compute byte offsets / merged_alloc_size. + this->ComputeOffsets(scope); + + // 5. Recursively mutate the body — reads scope_stack_.back() for all rewrites. Stmt visited_body = StmtExprMutator::VisitStmt(op->body); + + in_thread_env_ = false; + + // 6. If this scope has no shmem allocs, skip the wrapper. + if (scope.shmem_allocs.empty()) { + scope_stack_.pop_back(); + return AttrStmt(op->node, op->attr_key, op->value, visited_body, op->span); + } + + // 7. Wrap with the merged-buffer AllocBuffer. + Buffer merged_buf(scope.merged_buf_var, DataType::UInt(8), {scope.merged_alloc_size}, {}, + PrimExpr(), scope.merged_buf_var->name_hint, 0, 0, BufferType::kDefault); ffi::Map annotations; - if (has_volatile_alloc_) { + if (scope.has_volatile_alloc) { annotations.Set(tirx::attr::kVolatile, true); } Stmt alloc_stmt = AllocBuffer(merged_buf, annotations); Stmt new_body = SeqStmt::Flatten(alloc_stmt, visited_body); + + // 8. Pop the scope. + scope_stack_.pop_back(); + return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span); } return StmtMutator::VisitStmt_(op); @@ -364,10 +399,17 @@ class SharedMemoryRewriter : public StmtExprMutator { Stmt VisitStmt_(const AllocBufferNode* op) final { if (IsAppropriateSharedMemory(op->buffer->data)) { - if (op->annotations.count(tirx::attr::kVolatile)) { - has_volatile_alloc_ = true; + if (!scope_stack_.empty()) { + KernelScope& scope = scope_stack_.back(); + if (scope.shmem_allocs.count(op->buffer->data.get())) { + if (op->annotations.count(tirx::attr::kVolatile)) { + scope.has_volatile_alloc = true; + } + return Evaluate(0); + } } - return Evaluate(0); + // Outside any thread_extent scope — leave as-is. + return StmtExprMutator::VisitStmt_(op); } return StmtExprMutator::VisitStmt_(op); } @@ -392,7 +434,8 @@ class SharedMemoryRewriter : public StmtExprMutator { template Node VisitBufferAccess(Node node) { - if (IsAppropriateSharedMemory(node->buffer->data)) { + if (IsAppropriateSharedMemory(node->buffer->data) && !scope_stack_.empty() && + scope_stack_.back().shmem_allocs.count(node->buffer->data.get())) { TVM_FFI_ICHECK_EQ(node->indices.size(), 1) << "MergeSharedMemoryAllocations expects flat memory buffers, " << "and is to be run after " @@ -409,9 +452,13 @@ class SharedMemoryRewriter : public StmtExprMutator { } Buffer GetUpdatedBuffer(Buffer buffer) { + if (scope_stack_.empty()) return buffer; + KernelScope& scope = scope_stack_.back(); + if (!scope.shmem_allocs.count(buffer->data.get())) return buffer; + auto key = buffer.get(); - auto it = buffer_remap_.find(key); - if (it != buffer_remap_.end()) { + auto it = scope.buffer_remap.find(key); + if (it != scope.buffer_remap.end()) { return it->second; } @@ -422,10 +469,10 @@ class SharedMemoryRewriter : public StmtExprMutator { << "and is to be run after " << "FlattenBuffer"; auto writer = buffer.CopyOnWrite(); - writer->data = merged_buf_var_; + writer->data = scope.merged_buf_var; } - buffer_remap_[key] = buffer; + scope.buffer_remap[key] = buffer; return buffer; } @@ -434,7 +481,8 @@ class SharedMemoryRewriter : public StmtExprMutator { TVM_FFI_ICHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); Var buffer = Downcast(op->args[1]); - if (!IsAppropriateSharedMemory(buffer)) { + if (!IsAppropriateSharedMemory(buffer) || scope_stack_.empty() || + !scope_stack_.back().shmem_allocs.count(buffer.get())) { return StmtExprMutator::VisitExpr_(op); } PrimExpr extra_offset = GetBufferOffset(buffer, dtype); @@ -442,7 +490,8 @@ class SharedMemoryRewriter : public StmtExprMutator { PrimExpr offset = this->VisitExpr(op->args[2]); PrimExpr extent = this->VisitExpr(op->args[3]); return Call(op->dtype, op->op, - {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]}); + {op->args[0], scope_stack_.back().merged_buf_var, extra_offset + offset, extent, + op->args[4]}); } else if (op->op.same_as(builtin::ptx_cp_async())) { TVM_FFI_ICHECK((op->args.size() == 5U) || (op->args.size() == 6U)); Var buffer = Downcast(op->args[0]); @@ -451,7 +500,8 @@ class SharedMemoryRewriter : public StmtExprMutator { const auto* prim_type = ptr_type->element_type.as(); TVM_FFI_ICHECK(prim_type) << "The buffer should be a pointer to a primitive type."; DataType dtype = DataType(prim_type->dtype); - if (!IsAppropriateSharedMemory(buffer)) { + if (!IsAppropriateSharedMemory(buffer) || scope_stack_.empty() || + !scope_stack_.back().shmem_allocs.count(buffer.get())) { return StmtExprMutator::VisitExpr_(op); } PrimExpr extra_offset = GetBufferOffset(buffer, dtype); @@ -461,21 +511,25 @@ class SharedMemoryRewriter : public StmtExprMutator { // the correct offset of merged shared buffer. int index_factor = dtype.bytes(); if (op->args.size() == 5) - return Call(dtype, op->op, - {merged_buf_var_, mul(extra_offset + offset, PrimExpr(index_factor)), - op->args[2], op->args[3], op->args[4]}); + return Call( + dtype, op->op, + {scope_stack_.back().merged_buf_var, mul(extra_offset + offset, PrimExpr(index_factor)), + op->args[2], op->args[3], op->args[4]}); else - return Call(dtype, op->op, - {merged_buf_var_, mul(extra_offset + offset, PrimExpr(index_factor)), - op->args[2], op->args[3], op->args[4], op->args[5]}); + return Call( + dtype, op->op, + {scope_stack_.back().merged_buf_var, mul(extra_offset + offset, PrimExpr(index_factor)), + op->args[2], op->args[3], op->args[4], op->args[5]}); } else { return StmtExprMutator::VisitExpr_(op); } } PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) { - auto it = buffer_byte_offsets_.find(buffer_var.get()); - TVM_FFI_ICHECK(it != buffer_byte_offsets_.end()); + TVM_FFI_ICHECK(!scope_stack_.empty()); + KernelScope& scope = scope_stack_.back(); + auto it = scope.buffer_byte_offsets.find(buffer_var.get()); + TVM_FFI_ICHECK(it != scope.buffer_byte_offsets.end()); return indexdiv(it->second, dtype.bytes()); } @@ -484,32 +538,12 @@ class SharedMemoryRewriter : public StmtExprMutator { return is_dynamic_ ? IsDynamicSharedMemory(var) : IsStaticSharedMemory(var); } - using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry; - struct StorageEntry { - // The constant size of the buffer in bits, only used if it is constant - uint64_t const_nbits{0}; - // Allocs that shares this entry. - // The inner vector means a "layer" - // For example, it we need to allocate C in the memory of A and B: - // | A: 4096 bytes | B: 4096 bytes | - // | C: 8192 bytes | - // Then the allocs = {{A, B}, {C}} - std::vector> allocs; - }; - - // Event entry in liveness analysis - struct EventEntry { - // variables we generate - std::vector gen; - // variables we kill - std::vector kill; - }; - /*! * \brief Liveness analysis to find gen and kill point of each variable. * \param seq the linear pattern of storage access + * \param scope the kernel scope to write results into */ - void LivenessAnalysis(const std::vector& seq) { + void LivenessAnalysis(const std::vector& seq, KernelScope& scope) { // find kill point, do a reverse linear scan. std::unordered_set touched; for (size_t i = seq.size(); i != 0; --i) { @@ -517,7 +551,7 @@ class SharedMemoryRewriter : public StmtExprMutator { for (const VarNode* buffer : s.touched) { if (!touched.count(buffer)) { touched.insert(buffer); - event_map_[s.stmt].kill.push_back(buffer); + scope.event_map[s.stmt].kill.push_back(buffer); } } } @@ -530,7 +564,7 @@ class SharedMemoryRewriter : public StmtExprMutator { for (const VarNode* buffer : s.touched) { if (!touched.count(buffer)) { touched.insert(buffer); - event_map_[s.stmt].gen.push_back(buffer); + scope.event_map[s.stmt].gen.push_back(buffer); } } } @@ -539,12 +573,13 @@ class SharedMemoryRewriter : public StmtExprMutator { /*! * \brief Memory plan algorithm * \param seq the linear pattern of storage access + * \param scope the kernel scope to write results into */ - void PlanMemory(const std::vector& seq) { + void PlanMemory(const std::vector& seq, KernelScope& scope) { std::unordered_set inplace_flag; for (size_t i = 0; i < seq.size(); ++i) { - auto it = event_map_.find(seq[i].stmt); + auto it = scope.event_map.find(seq[i].stmt); // scope_pair_offset <= 0 means it is either // - leaf stmt(offset = 0) // - end of scope(offset < 0) @@ -553,30 +588,84 @@ class SharedMemoryRewriter : public StmtExprMutator { return seq[i].scope_pair_offset == 0 && std::find(it->second.gen.begin(), it->second.gen.end(), var) != it->second.gen.end(); }; - if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { + if (it != scope.event_map.end() && seq[i].scope_pair_offset <= 0) { for (const VarNode* var : it->second.kill) { - if (!is_leaf_alloc(var)) this->Free(var); + if (!is_leaf_alloc(var)) this->Free(var, scope); } } // scope_pair_offset >= 0 means it is either // - leaf stmt(offset = 0) // - beginning of scope(offset < 0) // In both cases, we need to handle the gen event correctly - if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) { + if (it != scope.event_map.end() && seq[i].scope_pair_offset >= 0) { for (const VarNode* var : it->second.gen) { - TVM_FFI_ICHECK(shmem_allocs_.count(var)); - const Buffer& buf = shmem_allocs_.at(var); - StorageEntry* dst_entry = FindAlloc(buf); - alloc_map_[var] = dst_entry; + TVM_FFI_ICHECK(scope.shmem_allocs.count(var)); + const Buffer& buf = scope.shmem_allocs.at(var); + StorageEntry* dst_entry = FindAlloc(buf, scope); + scope.alloc_map[var] = dst_entry; } } - if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { + if (it != scope.event_map.end() && seq[i].scope_pair_offset <= 0) { for (const VarNode* var : it->second.kill) { - if (is_leaf_alloc(var)) this->Free(var); + if (is_leaf_alloc(var)) this->Free(var, scope); + } + } + } + } + + /*! + * \brief Compute byte offsets for all entries in the scope after PlanMemory. + * \param scope the kernel scope whose offset map to fill + */ + void ComputeOffsets(KernelScope& scope) { + int max_layer_num = 0; + std::vector all_entry; + for (const auto& e : scope.const_free_map) { + all_entry.push_back(e.second); + } + for (const StorageEntry* e : scope.sym_free_list) { + all_entry.push_back(e); + } + for (const StorageEntry* e : all_entry) { + max_layer_num = std::max(max_layer_num, static_cast(e->allocs.size())); + } + // calculate align for each layer of each storage entry. + std::vector align(max_layer_num, 0); + for (const StorageEntry* e : all_entry) { + for (int i = 0; i < static_cast(e->allocs.size()); i++) { + for (const VarNode* buffer : e->allocs[i]) { + const Buffer& buf = scope.shmem_allocs.at(buffer); + align[i] = std::max(align[i], buf->dtype.bytes()); } } } + // calculate offset for each buffer based on the align of each layer + for (const StorageEntry* e : all_entry) { + PrimExpr max_inner_offset = 0; + for (int i = 0; i < static_cast(e->allocs.size()); i++) { + PrimExpr inner_offset = 0; + for (const VarNode* buffer : e->allocs[i]) { + const Buffer& buf = scope.shmem_allocs.at(buffer); + ffi::Array alloc_shape = GetBufferAllocationShape(buf); + int align_bytes = std::max(align[i], buf->dtype.bytes()); + if (buf->data_alignment > 0) { + TVM_FFI_ICHECK(buf->data_alignment % align_bytes == 0) + << "The alignment of the buffer is not a multiple of the data type size."; + align_bytes = buf->data_alignment; + } + PrimExpr buffer_bytes = alloc_shape[0] * buf->dtype.bytes(); + inner_offset += + indexmod(align_bytes - indexmod(scope.merged_alloc_size + inner_offset, align_bytes), + align_bytes); + scope.buffer_byte_offsets[buffer] = scope.merged_alloc_size + inner_offset; + inner_offset += buffer_bytes; + } + max_inner_offset = max(max_inner_offset, inner_offset); + } + scope.merged_alloc_size = scope.merged_alloc_size + max_inner_offset; + } } + /*! * \brief Allocate new storage entry. * \param buf the buffer object @@ -590,12 +679,14 @@ class SharedMemoryRewriter : public StmtExprMutator { entry->const_nbits = const_nbits; return entry; } + /*! * \brief find the storage entry in the free list for the buffer * \param buf the buffer object + * \param scope the kernel scope whose free lists to search * \return the storage entry */ - StorageEntry* FindAlloc(const Buffer& buf) { + StorageEntry* FindAlloc(const Buffer& buf, KernelScope& scope) { // skip plan for local variable, // compiler can do a better job with register allocation. const uint64_t match_range = 16; @@ -611,16 +702,16 @@ class SharedMemoryRewriter : public StmtExprMutator { if (const_nbits != 0) { // constant allocation. - auto begin = const_free_map_.lower_bound(0); - auto mid = const_free_map_.lower_bound(const_nbits); - auto end = const_free_map_.upper_bound(const_nbits * match_range); + auto begin = scope.const_free_map.lower_bound(0); + auto mid = scope.const_free_map.lower_bound(const_nbits); + auto end = scope.const_free_map.upper_bound(const_nbits * match_range); // Start looking at the buffer that is bigger than the required size first. // If we find one, directly allocate the buffer in its location and remove its entry in the // free list for (auto it = mid; it != end; ++it) { StorageEntry* e = it->second; e->const_nbits = std::max(const_nbits, e->const_nbits); - const_free_map_.erase(it); + scope.const_free_map.erase(it); it->second->allocs.push_back({buf->data.get()}); return e; } @@ -654,16 +745,16 @@ class SharedMemoryRewriter : public StmtExprMutator { e->const_nbits = std::max(const_nbits, mem_ct); e->allocs = reuse_allocs; for (auto it : delete_it) { - const_free_map_.erase(it); + scope.const_free_map.erase(it); } return e; } } else { // if its symbolic allocation, just arbitrarily choose one entry to fit in because we don't // know its actual size - for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) { + for (auto it = scope.sym_free_list.begin(); it != scope.sym_free_list.end(); ++it) { StorageEntry* e = *it; - sym_free_list_.erase(it); + scope.sym_free_list.erase(it); return e; } } @@ -673,10 +764,11 @@ class SharedMemoryRewriter : public StmtExprMutator { /*! * \brief add the storage entry to the buffer var into the free list. * \param var the buffer var + * \param scope the kernel scope whose free lists to update */ - void Free(const VarNode* var) { - auto it = alloc_map_.find(var); - TVM_FFI_ICHECK(it != alloc_map_.end()); + void Free(const VarNode* var, KernelScope& scope) { + auto it = scope.alloc_map.find(var); + TVM_FFI_ICHECK(it != scope.alloc_map.end()); StorageEntry* e = it->second; TVM_FFI_ICHECK_NE(e->allocs.size(), 0U); @@ -685,51 +777,28 @@ class SharedMemoryRewriter : public StmtExprMutator { // normal free. if (e->const_nbits != 0) { - const_free_map_.insert({e->const_nbits, e}); + scope.const_free_map.insert({e->const_nbits, e}); } else { - sym_free_list_.push_back(e); + scope.sym_free_list.push_back(e); } } + // Whether enable dynamic analysis. bool is_dynamic_{true}; - // The var for the merged buffer - Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")}; - // The mapping from the original buffer var to its Buffer - std::unordered_map shmem_allocs_; - // The size of the merged buffer - PrimExpr merged_alloc_size_{0}; - // The mapping from the original buffer var to its offset in the merged buffer - std::unordered_map buffer_byte_offsets_; - // The mapping from the original buffer objects to their location in the merged buffer. - std::unordered_map buffer_remap_; - // The flag indicating whether the merged buffer has been allocated - bool allocated_{false}; - // Whether any original shared memory allocation had the volatile annotation - bool has_volatile_alloc_{false}; - // Locations of free ops. - std::unordered_map event_map_; - // constant size free map. - std::multimap const_free_map_; - // symbolic free list, for non constant items. - std::list sym_free_list_; - // The allocation assign map - std::unordered_map alloc_map_; - /*! \brief allocator of all the StorageEntry*/ + // Whether already inside a thread_extent (outermost only). + bool in_thread_env_{false}; + // Stack of per-kernel-launch scopes. Pushed on thread_extent entry, popped on exit. + std::vector scope_stack_; + /*! \brief allocator of all the StorageEntry (shared across all scopes) */ support::Arena arena_; }; Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem) { - AllocateCollector collector; - collector(stmt); - if (collector.dyn_shmem_allocs_.size() > 1) { - SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_); - rewriter.PlanReuse(stmt); - stmt = rewriter(std::move(stmt)); - } - if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) { - SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false); - rewriter.PlanReuse(stmt, false); - stmt = rewriter(std::move(stmt)); + SharedMemoryRewriter dyn_rewriter(/*is_dynamic=*/true); + stmt = dyn_rewriter(std::move(stmt)); + if (merge_static_smem) { + SharedMemoryRewriter static_rewriter(/*is_dynamic=*/false); + stmt = static_rewriter(std::move(stmt)); } return stmt; } diff --git a/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py index ca7d1de7c488..b09c1fd796b1 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -254,23 +254,100 @@ def test_async_copy(): class Before: @T.prim_func(s_tir=True) def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): + threadIdx_x = T.launch_thread("threadIdx.x", 128) A_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn") B_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn") - threadIdx_x = T.launch_thread("threadIdx.x", 128) T.ptx.cp_async("float32", A_sh.data, threadIdx_x, A.data, threadIdx_x, 512) T.ptx.cp_async("float32", B_sh.data, threadIdx_x, B.data, threadIdx_x, 512) After = transform(Before) - # The pass merges shared.dyn allocations but DeclBuffer nodes from the original - # allocations remain with remapped data vars. The output can't be precisely - # represented in TVMScript due to same-name var constraints, so we verify - # key properties instead of exact structural equality. + # The pass merges shared.dyn allocations. A_sh and B_sh are accessed + # sequentially inside the thread_extent with non-overlapping lifetimes, + # so the liveness analysis allows reuse — both fit in 512 bytes + # (= 128 elements * 4 bytes). script = After["main"].script() - # Verify merged allocation (1024 bytes = 128*4 + 128*4) - assert '"uint8"' in script and '"shared.dyn"' in script and "(1024,)" in script - # Verify cp_async uses correct byte offsets + # Verify merged allocation (512 bytes - A_sh and B_sh can be reused) + assert '"uint8"' in script and '"shared.dyn"' in script and "(512,)" in script + # Verify cp_async uses the merged buffer + assert "buf_dyn_shmem" in script assert "threadIdx_x * 4" in script - assert "(128 + threadIdx_x) * 4" in script + + +def test_multi_thread_extent_blocks(): + """Each thread_extent block must get its own merged buffer. + + Reproduces the scoping bug from PR #19605: a single PrimFunc + with two sibling thread_extent regions, each containing its + own shared.dyn allocations. The merged buffer must be allocated + inside each kernel body — not just the first. + """ + transform = tvm.s_tir.transform.MergeSharedMemoryAllocations() + + @I.ir_module(check_well_formed=False) + class Before: + @T.prim_func(s_tir=True, check_well_formed=False) + def main( + X: T.Buffer((128,), "float32"), + Y: T.Buffer((128,), "float32"), + ): + X_flat = T.decl_buffer(128, data=X.data) + Y_flat = T.decl_buffer(128, data=Y.data) + + # First kernel launch + tx0 = T.env_thread("threadIdx.x") + with T.attr(tx0, "thread_extent", 128): + A_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn") + B_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn") + A_sh[tx0] = X_flat[tx0] + B_sh[tx0] = A_sh[tx0] + X_flat[tx0] = B_sh[tx0] + + # Second kernel launch — must NOT see kernel #0's merged buffer. + tx1 = T.env_thread("threadIdx.x") + with T.attr(tx1, "thread_extent", 128): + C_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn") + D_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn") + C_sh[tx1] = Y_flat[tx1] + D_sh[tx1] = C_sh[tx1] + Y_flat[tx1] = D_sh[tx1] + + After = transform(Before) + script = After["main"].script() + + # Two merged allocations — one per thread_extent body. + # Each of the four original 128-float32 buffers (A_sh, B_sh, C_sh, D_sh) + # gets merged within its own kernel scope. + assert script.count("shared.dyn") >= 2, ( + "Expected at least two shared.dyn allocations (one per kernel)" + ) + assert script.count("alloc_buffer") >= 2, ( + "Expected at least two alloc_buffer nodes (one merged buf per kernel)" + ) + + # Both thread_extent blocks must contain their own merged buffer — + # they must NOT share the same buf_dyn_shmem variable. + # Structurally verify that the first kernel's body accesses are + # not rewritten to the second kernel's buf_dyn_shmem (and vice versa). + first_block = script.split("with T.attr(tx1")[0] + second_block = script.split("with T.attr(tx1")[1] if "tx1" in script else "" + assert "buf_dyn_shmem" in first_block, "Kernel 1 must have a merged buffer" + if second_block: + assert "buf_dyn_shmem" in second_block, "Kernel 2 must have a merged buffer" + + # End-to-end: post-merge IR must remain well-formed through + # the host/device split — this is the exact ordering from + # PR #19605 that triggers the scoping bug. + target = tvm.target.Target("llvm") + mod_with_target = tvm.IRModule({"main": After["main"].with_attr({"target": target})}) + split = tvm.transform.Sequential( + [ + tvm.tirx.transform.AnnotateDeviceRegions(), + tvm.tirx.transform.SplitHostDevice(), + ] + ) + # If kernel #1 referenced an undefined buf_dyn_shmem, this + # would raise during well-formedness checking inside SplitHostDevice. + split(mod_with_target) if __name__ == "__main__": From dcd4f17acb2ac80db14072af6505dc6c6fb3d376 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 25 May 2026 19:59:19 +0000 Subject: [PATCH 2/6] [REFACTOR][TIR] Tie AnnotateDeviceRegions/SplitHostDevice/LowerDeviceKernelLaunch together These three passes are logically a single host/device split step; having intermediaries between them obscures the model and blocks folding them into one pass. This PR moves each intermediary to the position its actual ordering constraint allows, so that AnnotateDeviceRegions, SplitHostDevice, and LowerDeviceKernelLaunch run consecutively in every pipeline. - MergeSharedMemoryAllocations moves before AnnotateDeviceRegions (the only legal position: LowerDeviceKernelLaunch requires at most one dyn-shmem allocation per kernel). - MakePackedAPI moves after LowerDeviceKernelLaunch (Lower's calling_conv flag causes MakePackedAPI to correctly skip device kernels; host body's lowered tvm_call_packed is transparent to MakePackedAPI's subroutine rewriter). - FP8StorageLegalize/BF16StorageLegalize move after MakePackedAPI (their buffer_map.size()==0 ICHECK requires MakePackedAPI to have cleared the map). Prereq for Phase 2: collapsing the three into a single tirx.transform.SplitHostDevice with three commented regions. --- python/tvm/s_tir/backend/adreno/pipeline.py | 5 ++--- python/tvm/s_tir/pipeline.py | 5 ++--- python/tvm/tirx/compilation_pipeline.py | 6 +++--- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py b/python/tvm/s_tir/backend/adreno/pipeline.py index 85359b1d35aa..618970b37e66 100644 --- a/python/tvm/s_tir/backend/adreno/pipeline.py +++ b/python/tvm/s_tir/backend/adreno/pipeline.py @@ -108,14 +108,13 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I passes.append(s_tir.transform.InjectPTXLDG32()) passes.extend( [ + s_tir.transform.MergeSharedMemoryAllocations(), tirx.transform.AnnotateDeviceRegions(), tirx.transform.SplitHostDevice(), - # MergeSharedMemoryAllocations must follow SplitHostDevice. - s_tir.transform.MergeSharedMemoryAllocations(), + tirx.transform.LowerDeviceKernelLaunch(), tirx.transform.MakePackedAPI(), tirx.transform.FP8StorageLegalize(), tirx.transform.BF16StorageLegalize(), - tirx.transform.LowerDeviceKernelLaunch(), ] ) mod = tvm.ir.transform.Sequential(passes)(mod) diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py index 33a16b381fea..a127e43a0ebd 100644 --- a/python/tvm/s_tir/pipeline.py +++ b/python/tvm/s_tir/pipeline.py @@ -108,14 +108,13 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I passes.append(s_tir.transform.InjectPTXLDG32()) passes.extend( [ + s_tir.transform.MergeSharedMemoryAllocations(), tirx.transform.AnnotateDeviceRegions(), tirx.transform.SplitHostDevice(), - # MergeSharedMemoryAllocations must follow SplitHostDevice. - s_tir.transform.MergeSharedMemoryAllocations(), + tirx.transform.LowerDeviceKernelLaunch(), tirx.transform.MakePackedAPI(), tirx.transform.FP8StorageLegalize(), tirx.transform.BF16StorageLegalize(), - tirx.transform.LowerDeviceKernelLaunch(), ] ) mod = tvm.ir.transform.Sequential(passes)(mod) diff --git a/python/tvm/tirx/compilation_pipeline.py b/python/tvm/tirx/compilation_pipeline.py index 30facc2663c6..f964f50668be 100644 --- a/python/tvm/tirx/compilation_pipeline.py +++ b/python/tvm/tirx/compilation_pipeline.py @@ -50,10 +50,10 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I tirx.transform.AnnotateEntryFunc(), tirx.transform.AnnotateDeviceRegions(), tirx.transform.SplitHostDevice(), + tirx.transform.LowerDeviceKernelLaunch(), tirx.transform.MakePackedAPI(), tirx.transform.FP8StorageLegalize(), tirx.transform.BF16StorageLegalize(), - tirx.transform.LowerDeviceKernelLaunch(), ] ) mod = tvm.ir.transform.Sequential(passes)(mod) @@ -91,10 +91,10 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I tirx.transform.AnnotateEntryFunc(), tirx.transform.AnnotateDeviceRegions(), tirx.transform.SplitHostDevice(), + tirx.transform.LowerDeviceKernelLaunch(), tirx.transform.MakePackedAPI(), tirx.transform.FP8StorageLegalize(), tirx.transform.BF16StorageLegalize(), - tirx.transform.LowerDeviceKernelLaunch(), ] ) mod = tvm.ir.transform.Sequential(passes)(mod) @@ -124,8 +124,8 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I tirx.transform.AnnotateEntryFunc(), tirx.transform.AnnotateDeviceRegions(), tirx.transform.SplitHostDevice(), - tirx.transform.MakePackedAPI(), tirx.transform.LowerDeviceKernelLaunch(), + tirx.transform.MakePackedAPI(), ] return tvm.ir.transform.Sequential(passes)(mod) From c788198098017a83291422b92dc1f49b8c1c014a Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 12:01:48 +0000 Subject: [PATCH 3/6] [PERF] Restore fast-path early-exits in MergeSharedMemoryAllocations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address Gemini perf review on the merge-shmem refactor: the scope-stack refactor lost the original fast-paths that skipped liveness/planning/rewriting when there are 0 or 1 shmem allocations to merge. Restore both: - Per-scope: a thread_extent block with ≤1 shmem alloc skips the per-scope merging machinery. - Function-level: a PrimFunc with ≤1 shmem alloc of the relevant kind short-circuits the entire rewriter invocation. Behavior is unchanged; this is purely performance. --- .../merge_shared_memory_allocations.cc | 30 ++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/s_tir/transform/merge_shared_memory_allocations.cc b/src/s_tir/transform/merge_shared_memory_allocations.cc index 3c5b7cd0782d..d0307225267a 100644 --- a/src/s_tir/transform/merge_shared_memory_allocations.cc +++ b/src/s_tir/transform/merge_shared_memory_allocations.cc @@ -357,6 +357,15 @@ class SharedMemoryRewriter : public StmtExprMutator { collector(op->body); scope.shmem_allocs = std::move(collector.shmem_allocs_); + // Per-scope early bail-out: if this thread_extent block has ≤1 shmem + // allocation, there is nothing to merge. Skip liveness analysis, + // memory planning, and rewriting entirely. + if (scope.shmem_allocs.size() <= 1) { + scope_stack_.pop_back(); + in_thread_env_ = false; + return StmtExprMutator::VisitStmt_(op); + } + // 3. Liveness + reuse plan over this subtree only. // Run the finder on the full AttrStmt (not just op->body) so that // VisitNewScope creates the proper scope pair entry for the thread_extent. @@ -794,11 +803,24 @@ class SharedMemoryRewriter : public StmtExprMutator { }; Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem) { - SharedMemoryRewriter dyn_rewriter(/*is_dynamic=*/true); - stmt = dyn_rewriter(std::move(stmt)); + // Function-level early-out: skip the rewriter entirely if the PrimFunc + // has ≤1 dynamic shared-memory allocation (nothing to merge). + { + AllocateCollector dyn_probe(/*is_dynamic=*/true); + dyn_probe(stmt); + if (dyn_probe.shmem_allocs_.size() > 1) { + SharedMemoryRewriter dyn_rewriter(/*is_dynamic=*/true); + stmt = dyn_rewriter(std::move(stmt)); + } + } if (merge_static_smem) { - SharedMemoryRewriter static_rewriter(/*is_dynamic=*/false); - stmt = static_rewriter(std::move(stmt)); + // Similarly skip the static rewriter if there is ≤1 static shmem alloc. + AllocateCollector static_probe(/*is_dynamic=*/false); + static_probe(stmt); + if (static_probe.shmem_allocs_.size() > 1) { + SharedMemoryRewriter static_rewriter(/*is_dynamic=*/false); + stmt = static_rewriter(std::move(stmt)); + } } return stmt; } From 7a0f7e371ef2cfa95f3dbfabb05624302dec713e Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 15:09:55 +0000 Subject: [PATCH 4/6] [FIX][S-TIR] Use captured StorageEntry pointer after erasing free-map iterator scope.const_free_map.erase(it) invalidates it; the subsequent it->second dereference is undefined behavior. Capture the StorageEntry* into e before the erase and use e afterward. Flagged by Gemini reviewer on PR #19605. --- src/s_tir/transform/merge_shared_memory_allocations.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/s_tir/transform/merge_shared_memory_allocations.cc b/src/s_tir/transform/merge_shared_memory_allocations.cc index d0307225267a..d1417943c327 100644 --- a/src/s_tir/transform/merge_shared_memory_allocations.cc +++ b/src/s_tir/transform/merge_shared_memory_allocations.cc @@ -721,7 +721,7 @@ class SharedMemoryRewriter : public StmtExprMutator { StorageEntry* e = it->second; e->const_nbits = std::max(const_nbits, e->const_nbits); scope.const_free_map.erase(it); - it->second->allocs.push_back({buf->data.get()}); + e->allocs.push_back({buf->data.get()}); return e; } // Then start looking at smaller buffers. From 6d5120429c8e930e2272d7fc5547fc29f53d31a3 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 15:10:07 +0000 Subject: [PATCH 5/6] [FIX][TIR] Robustify LowerDeviceKernelLaunch against pre-MakePackedAPI host targets In the pipeline order that places LowerDeviceKernelLaunch before MakePackedAPI, the host PrimFunc still carries Target("cuda", host="llvm") when Lower visits it. The same-target shortcut at the call site compared caller->WithoutHost() against the device kernel's target, which produced cuda == cuda and silently skipped both the host-call rewriting and the kernel-attribute assignment. The kernel was then emitted with the default calling convention, which CUDA codegen lowers as __device__ __launch_bounds__, rejected by nvcc. The shortcut is meant for intra-device subroutine calls between device-resident functions, not for a host caller whose target happens to share a string with the kernel after WithoutHost(). Track whether the current caller is a host function (its kTarget has a host attached) and skip the same-target / same-device-type shortcuts when a host caller invokes a real device kernel (callee has non-empty launch_params). Pure intra-device subroutine calls and host-side extern subroutine calls are unaffected. This makes Lower order-independent with respect to MakePackedAPI's host-target rewrite and is a prerequisite for keeping AnnotateDeviceRegions/SplitHostDevice/LowerDeviceKernelLaunch consecutive in every pipeline. --- .../transform/lower_device_kernel_launch.cc | 68 ++++++++++++++----- 1 file changed, 50 insertions(+), 18 deletions(-) diff --git a/src/tirx/transform/lower_device_kernel_launch.cc b/src/tirx/transform/lower_device_kernel_launch.cc index 9b38c4d629dd..15a10a5ee28b 100644 --- a/src/tirx/transform/lower_device_kernel_launch.cc +++ b/src/tirx/transform/lower_device_kernel_launch.cc @@ -213,6 +213,13 @@ class DeviceKernelMutator : public StmtExprMutator { auto it = device_info_map_.find(gvar.get()); TVM_FFI_ICHECK(it != device_info_map_.end()); current_target_ = it->second.target; + // Track whether the caller is a host function (i.e. its target + // still has a host attached). The same-target shortcut at the + // call site is only safe when caller and callee are both device- + // resident; a host caller must take the kernel-launch path even + // if Target::WithoutHost() makes the strings match. + current_caller_is_host_ = + func->GetAttr(tvm::attr::kTarget).value()->GetHost().defined(); auto body = VisitStmt(func->body); if (!body.same_as(func->body)) { @@ -220,6 +227,7 @@ class DeviceKernelMutator : public StmtExprMutator { } current_target_ = std::nullopt; + current_caller_is_host_ = false; return func; } @@ -275,26 +283,45 @@ class DeviceKernelMutator : public StmtExprMutator { auto caller_target = current_target_.value(); auto callee_target = dev_info.target; - bool same_target = caller_target->str() == callee_target->str(); - if (same_target) { - // Calls within the same target may be handled at codegen time - // as internal subroutine calls. - return node; - } + // A callee with non-empty launch_params has thread_extent + // bindings in its body, i.e. it is a real device kernel that + // must be invoked via a kernel-launch ABI. In that case the + // same-target / same-device-type shortcuts below are unsafe + // when the caller is a host function: Target::WithoutHost() + // on a host caller can make the strings (or device types) + // match the kernel's, but the call still crosses the + // host/device boundary and must be lowered to a kernel + // launch. Otherwise codegen emits the kernel with default + // calling convention (__device__ instead of __global__) and + // the host body is left with a raw GlobalVar call that no + // subsequent pass will rewrite (MakePackedAPI's subroutine + // rewriter skips functions that already have a non-default + // calling convention or no global symbol). + bool callee_is_kernel = dev_info.launch_params.size() > 0; + bool force_kernel_launch = callee_is_kernel && current_caller_is_host_; + + if (!force_kernel_launch) { + bool same_target = caller_target->str() == callee_target->str(); + if (same_target) { + // Calls within the same target may be handled at codegen time + // as internal subroutine calls. + return node; + } - bool same_device_type = - caller_target->GetTargetDeviceType() == callee_target->GetTargetDeviceType(); - if (same_device_type) { - // Calls to another target using the same device (e.g. LLVM - // calling a custom TIRToRuntime target) do not require a kernel - // launch, but need to be replaced with call_extern. - extern_function_call_.insert(gvar); - ffi::Array args; - args.push_back(StringImm(gvar->name_hint)); - for (const auto& arg : node->args) { - args.push_back(arg); + bool same_device_type = + caller_target->GetTargetDeviceType() == callee_target->GetTargetDeviceType(); + if (same_device_type) { + // Calls to another target using the same device (e.g. LLVM + // calling a custom TIRToRuntime target) do not require a kernel + // launch, but need to be replaced with call_extern. + extern_function_call_.insert(gvar); + ffi::Array args; + args.push_back(StringImm(gvar->name_hint)); + for (const auto& arg : node->args) { + args.push_back(arg); + } + return Call(node->dtype, builtin::call_extern(), args); } - return Call(node->dtype, builtin::call_extern(), args); } TVM_FFI_ICHECK(dev_info.launch_params.defined()) @@ -336,6 +363,11 @@ class DeviceKernelMutator : public StmtExprMutator { } ffi::Optional current_target_; + // True iff the caller currently being rewritten is a host function + // (its kTarget had a host attached). Used to suppress the + // same-target shortcut when a host calls a device kernel that + // happens to share a target string after WithoutHost(). + bool current_caller_is_host_{false}; std::unordered_map device_info_map_; std::unordered_set device_kernel_launch_; std::unordered_set extern_function_call_; From e02859ab425d9ebfebe3707f0982124fe372b837 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 20:11:44 +0000 Subject: [PATCH 6/6] [FIX][TIR] Route host-to-host helper calls through call_extern in LowerDeviceKernelLaunch When SplitHostDevice emits a host-side helper (e.g. an "add_host" with target "c") for a private subroutine that is called from both host and device contexts, the host caller still carries its full "cuda+host=c" target at the time LowerDeviceKernelLaunch runs in the new pipeline order. The same-target / same-device-type comparisons used the caller's WithoutHost() target ("cuda") against the callee's host target ("c"), making them appear cross-device and falling through to the kernel-launch path. UpdateKernelAttributes then ran ReturnRemover on the host helper's body, which contains a real `T.ret(a+b)`, tripping the ICHECK that "device kernel may only contain T.ret(0)". The previous robustification only suppressed the same-target shortcut when the caller is host AND the callee is a real kernel (launch_params non-empty). It did not address the symmetric case of a host caller invoking another host helper across host targets. Capture the caller's host target separately and use it (in place of the WithoutHost() device target) when the caller is a host function. Host-to-host calls now correctly compare host targets and route to the same-target shortcut or call_extern, never to the kernel-launch ABI. Host-to-kernel calls remain forced through kernel launch. Pure intra-device subroutine calls (callee_target on the device, caller without host attached) are unaffected. Verified with tests/python/codegen/test_target_codegen_cuda.py:: test_device_host_call_same_func[nvcc,nvrtc] (previously failing on this PR, passing on upstream/main). --- .../transform/lower_device_kernel_launch.cc | 73 ++++++++++++------- 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/src/tirx/transform/lower_device_kernel_launch.cc b/src/tirx/transform/lower_device_kernel_launch.cc index 15a10a5ee28b..af30af6bfb37 100644 --- a/src/tirx/transform/lower_device_kernel_launch.cc +++ b/src/tirx/transform/lower_device_kernel_launch.cc @@ -214,12 +214,20 @@ class DeviceKernelMutator : public StmtExprMutator { TVM_FFI_ICHECK(it != device_info_map_.end()); current_target_ = it->second.target; // Track whether the caller is a host function (i.e. its target - // still has a host attached). The same-target shortcut at the - // call site is only safe when caller and callee are both device- - // resident; a host caller must take the kernel-launch path even - // if Target::WithoutHost() makes the strings match. - current_caller_is_host_ = - func->GetAttr(tvm::attr::kTarget).value()->GetHost().defined(); + // still has a host attached) and capture its host target. The + // same-target shortcut at the call site is only safe when caller + // and callee are both device-resident; a host caller must take + // the kernel-launch path even if Target::WithoutHost() makes the + // strings match. Conversely, a host caller invoking another host + // helper (e.g. a same-target subroutine that SplitHostDevice + // emitted on the host side) should compare against the host + // target, not the device target stripped by WithoutHost(). + auto full_target = func->GetAttr(tvm::attr::kTarget).value(); + if (full_target->GetHost().defined()) { + current_caller_host_target_ = full_target->GetHost().value(); + } else { + current_caller_host_target_ = std::nullopt; + } auto body = VisitStmt(func->body); if (!body.same_as(func->body)) { @@ -227,7 +235,7 @@ class DeviceKernelMutator : public StmtExprMutator { } current_target_ = std::nullopt; - current_caller_is_host_ = false; + current_caller_host_target_ = std::nullopt; return func; } @@ -280,25 +288,36 @@ class DeviceKernelMutator : public StmtExprMutator { << gvar->name_hint << " did not appear within the IRModule"; const KernelInfo& dev_info = it->second; - auto caller_target = current_target_.value(); auto callee_target = dev_info.target; // A callee with non-empty launch_params has thread_extent // bindings in its body, i.e. it is a real device kernel that - // must be invoked via a kernel-launch ABI. In that case the - // same-target / same-device-type shortcuts below are unsafe - // when the caller is a host function: Target::WithoutHost() - // on a host caller can make the strings (or device types) - // match the kernel's, but the call still crosses the - // host/device boundary and must be lowered to a kernel - // launch. Otherwise codegen emits the kernel with default - // calling convention (__device__ instead of __global__) and - // the host body is left with a raw GlobalVar call that no - // subsequent pass will rewrite (MakePackedAPI's subroutine - // rewriter skips functions that already have a non-default - // calling convention or no global symbol). + // must be invoked via a kernel-launch ABI. Conversely a callee + // with empty launch_params is a plain subroutine (host helper + // or intra-device helper) and is never invoked via kernel launch. bool callee_is_kernel = dev_info.launch_params.size() > 0; - bool force_kernel_launch = callee_is_kernel && current_caller_is_host_; + bool caller_is_host = current_caller_host_target_.has_value(); + + // For host callers, comparisons against the callee target must + // use the caller's *host* target, not the device target stripped + // by WithoutHost(). This handles two cases that the device-side + // comparison gets wrong: + // 1. A host caller invoking a real device kernel whose + // WithoutHost() target happens to match (e.g. kernel target + // "cuda" matches "cuda+host=c" after stripping host). Must + // go through kernel launch, not the same-target shortcut. + // 2. A host caller invoking another host helper with a + // different host target (e.g. SplitHostDevice emits an + // "add_host" with target "c" while the host body still + // carries "cuda+host=c"). Must go through call_extern (or + // same-target subroutine), not kernel launch. + auto caller_target = + caller_is_host ? current_caller_host_target_.value() : current_target_.value(); + + // A host caller invoking a real device kernel must always go + // through the kernel-launch ABI, regardless of any same-target / + // same-device-type coincidence. + bool force_kernel_launch = callee_is_kernel && caller_is_host; if (!force_kernel_launch) { bool same_target = caller_target->str() == callee_target->str(); @@ -363,11 +382,13 @@ class DeviceKernelMutator : public StmtExprMutator { } ffi::Optional current_target_; - // True iff the caller currently being rewritten is a host function - // (its kTarget had a host attached). Used to suppress the - // same-target shortcut when a host calls a device kernel that - // happens to share a target string after WithoutHost(). - bool current_caller_is_host_{false}; + // The host target of the caller currently being rewritten, if the + // caller is a host function (its kTarget has a host attached). + // Used both to detect that the caller is a host function and to + // compare against the callee target on the host side, so that + // host-to-host subroutine calls are not misrouted through the + // device kernel-launch ABI. + ffi::Optional current_caller_host_target_; std::unordered_map device_info_map_; std::unordered_set device_kernel_launch_; std::unordered_set extern_function_call_;