From 2029121b55fb1658fe655091b0a325e2f04e548f Mon Sep 17 00:00:00 2001 From: mitiskuma Date: Fri, 6 Mar 2026 01:56:32 +0100 Subject: [PATCH 1/4] [Metal] Batch compute dispatches into single command buffer, add staging pool for CPU->GPU copies Benchmark results (Metal, M4 Max, MLC-LLM serve, temperature=0): 256 decode tokens: Qwen2.5-0.5B-Instruct-q4f16_1: 238 t/s -> 466 t/s (1.95x) Qwen2.5-1.5B-Instruct-q4f16_1: 177 t/s -> 239 t/s (1.35x) Qwen2.5-3B-Instruct-q4f16_1: 114 t/s -> 139 t/s (1.21x) Llama-3.1-8B-Instruct-q4f16_1: 76 t/s -> 89 t/s (1.18x) 1024 decode tokens: Qwen2.5-0.5B-Instruct-q4f16_1: 239 t/s -> 398 t/s (1.67x) Qwen2.5-1.5B-Instruct-q4f16_1: 137 t/s -> 190 t/s (1.38x) Qwen2.5-3B-Instruct-q4f16_1: 92 t/s -> 115 t/s (1.25x) Llama-3.1-8B-Instruct-q4f16_1: 70 t/s -> 80 t/s (1.14x) Baseline and optimized use the same MLC-LLM, same compiled models, only the TVM Metal runtime differs. Servers run sequentially (not parallel) to avoid GPU contention. Each run preceded by 2 warmup requests. The speedup is larger on smaller models because they are dispatch-bound (262 dispatches/token for 0.5B vs 394 for 8B). Larger models spend more time in actual compute, so the per-dispatch overhead is a smaller fraction. At 1024 tokens the 0.5B speedup drops from 1.95x to 1.67x because KV cache growth increases per-token compute, shifting the bottleneck toward memory bandwidth. What changed: 1. Batched compute dispatch. Kernel dispatches are accumulated in a single MTLCommandBuffer via a shared MTLComputeCommandEncoder. Previously each dispatch created its own command buffer and committed immediately. The pending encoder is flushed on GPU->CPU readback, buffer deallocation, or stream sync. 2. Inline blit encoders for copies. CPU->GPU and GPU->GPU copies now use blit encoders on the same pending command buffer instead of creating a separate command buffer per copy. Metal guarantees sequential ordering of encoders within a command buffer, so no explicit sync is needed between compute and copy operations. 3. Staging buffer pool for CPU->GPU copies. Each inlined CPU->GPU copy needs its own staging buffer because the GPU reads them asynchronously from the deferred command buffer. A per-device StagingBufferPool hands out shared-mode buffers and recycles them after flush/sync. 4. Conditional sync in FreeDataSpace. Instead of always calling StreamSync, we check HasPendingWork() first. When the GPU->CPU readback path has already flushed and waited, FreeDataSpace can skip the redundant sync. --- src/runtime/metal/metal_common.h | 205 +++++++++++++++++++++++++- src/runtime/metal/metal_device_api.mm | 99 +++++++++---- src/runtime/metal/metal_module.mm | 19 +-- 3 files changed, 274 insertions(+), 49 deletions(-) diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 8d72fac97a8a..66fdb6d7c3a4 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -103,13 +103,34 @@ class AutoReleasePoolWrapper { }; /*! - * \brief Structure for error handling in queues + * \brief Metal command stream with batched dispatch support. + * + * Compute dispatches are batched into a single command buffer via + * GetPendingComputeEncoder(). Blit operations (copies) are interleaved + * on the same command buffer via GetBlitEncoderOnPendingBuffer(). + * The command buffer is committed when FlushCommandBuffer() is called. + * + * Must call FlushCommandBuffer() before: + * - GPU→CPU readback (need data in CPU memory) + * - Buffer deallocation (FreeDataSpace, setPurgeableState:Empty on + * a buffer referenced by an uncommitted CB crashes Metal) + * - Stream sync (StreamSync / Synchronize) */ class Stream { public: explicit Stream(id device) { queue_ = [device newCommandQueue]; } - ~Stream() { [queue_ release]; } - id GetCommandBuffer(std::string label = "", bool attach_error_callback = true) { + ~Stream() { + FlushCommandBuffer(); + [queue_ release]; + } + + /*! + * \brief Get a standalone command buffer (for GPU→CPU readback only). + * + * Used when we need a separate command buffer that we can commit + * and waitUntilCompleted on independently. + */ + id GetCommandBuffer(std::string label = "") { id cb = [queue_ commandBuffer]; if (!label.empty()) { cb.label = [NSString stringWithUTF8String:label.c_str()]; @@ -123,6 +144,99 @@ class Stream { return cb; } + /*! + * \brief Get the pending compute command encoder, creating one if needed. + * + * Multiple compute dispatches are batched into a single command buffer + * and encoder. Blit operations (copies) can be interleaved on the same + * command buffer via GetBlitEncoderOnPendingBuffer(). The entire command + * buffer is committed when FlushCommandBuffer() is called. + * + * Must flush before: + * - GPU→CPU readback (need data on CPU immediately) + * - Buffer deallocation (FreeDataSpace) + * - Stream sync (StreamSync) + */ + id GetPendingComputeEncoder(const std::string& kernel_name = "") { + if (pending_compute_encoder_ == nil) { + id cb = GetOrCreatePendingCommandBuffer(); + pending_compute_encoder_ = [[cb computeCommandEncoder] retain]; + } + if (!kernel_name.empty()) { + last_dispatched_kernel_ = kernel_name; + } + profile.dispatches++; + return pending_compute_encoder_; + } + + /*! + * \brief Get a blit encoder on the pending command buffer. + * + * Pauses the active compute encoder (if any), creates a blit encoder + * on the same command buffer. Caller must call [encoder endEncoding] + * when done. The next GetPendingComputeEncoder() call will create a + * new compute encoder on the same command buffer. + * + * Metal guarantees sequential ordering of encoders within a command + * buffer, so blits encoded here execute after prior compute dispatches + * and before subsequent ones. + */ + id GetBlitEncoderOnPendingBuffer() { + PauseComputeEncoder(); + id cb = GetOrCreatePendingCommandBuffer(); + profile.blits++; + return [cb blitCommandEncoder]; + } + + /*! + * \brief Flush: end active encoder, commit the command buffer. + * + * Safe to call when nothing is pending (no-op). + */ + void FlushCommandBuffer() { + PauseComputeEncoder(); + if (pending_command_buffer_ != nil) { + [pending_command_buffer_ commit]; + [pending_command_buffer_ release]; + pending_command_buffer_ = nil; + profile.flushes++; + } + } + + /*! + * \brief Flush pending work, then wait for all submitted work to complete. + */ + void Synchronize() { + FlushCommandBuffer(); + id cb = [queue_ commandBuffer]; + [cb addCompletedHandler:^(id buffer) { + if (buffer.status == MTLCommandBufferStatusError) { + TVM_FFI_ICHECK(buffer.error != nil); + this->SetError(buffer.error.localizedDescription.UTF8String); + } + }]; + [cb commit]; + [cb waitUntilCompleted]; + profile.syncs++; + } + + bool HasPendingWork() const { return pending_command_buffer_ != nil; } + + /*! \brief Profiling counters for diagnosing dispatch/copy/sync overhead. */ + struct ProfileCounters { + size_t dispatches = 0; + size_t flushes = 0; + size_t syncs = 0; + size_t blits = 0; + size_t gpu_to_cpu = 0; + size_t cpu_to_gpu = 0; + size_t gpu_to_gpu = 0; + size_t free_syncs = 0; // FreeDataSpace calls that triggered a sync + + void Reset() { *this = ProfileCounters(); } + }; + ProfileCounters profile; + void SetError(std::string error_description) { error_happened_ = true; error_description_ = std::move(error_description); @@ -133,8 +247,42 @@ class Stream { const std::string& ErrorDescription() const { return error_description_; } private: + /*! \brief Get or create the pending command buffer (shared by compute and blit). */ + id GetOrCreatePendingCommandBuffer() { + if (pending_command_buffer_ == nil) { + pending_command_buffer_ = [[queue_ commandBuffer] retain]; + pending_command_buffer_.label = @"TVMBatched"; + [pending_command_buffer_ addCompletedHandler:^(id buffer) { + if (buffer.status == MTLCommandBufferStatusError) { + TVM_FFI_ICHECK(buffer.error != nil); + std::string msg = buffer.error.localizedDescription.UTF8String; + if (!this->last_dispatched_kernel_.empty()) { + msg = "GPUError after kernel " + this->last_dispatched_kernel_ + ": " + msg; + } + this->SetError(msg); + } + }]; + } + return pending_command_buffer_; + } + + /*! \brief End the active compute encoder without committing the command buffer. */ + void PauseComputeEncoder() { + if (pending_compute_encoder_ != nil) { + [pending_compute_encoder_ endEncoding]; + [pending_compute_encoder_ release]; + pending_compute_encoder_ = nil; + } + } + // Queue id queue_; + // Pending command buffer (shared by compute and blit encoders) + id pending_command_buffer_ = nil; + // Active compute encoder on the pending command buffer (nil when paused/blit) + id pending_compute_encoder_ = nil; + // Last dispatched kernel name (for error diagnostics) + std::string last_dispatched_kernel_; // Check if error happened in one previous run bool error_happened_{false}; // error description @@ -201,8 +349,50 @@ class MetalThreadEntry { Device device; /*! \brief The current stream */ std::vector stream; - /*! \brief The shared buffer used for copy. */ + /*! \brief The shared buffer used for GPU→CPU readback. */ std::vector> temp_buffer_; + /*! + * \brief Pool of staging buffers for CPU→GPU copies that are inlined + * into the pending command buffer. Each inlined copy needs its own + * staging buffer because the GPU reads them asynchronously. + * Buffers are recycled after FlushCommandBuffer()/Synchronize(). + */ + struct StagingBufferPool { + struct Entry { + id buffer = nil; + size_t size = 0; + }; + std::vector pool; + size_t next_index = 0; // round-robin within current batch + + id GetOrCreate(id dev, size_t nbytes) { + if (next_index < pool.size() && pool[next_index].size >= nbytes) { + return pool[next_index++].buffer; + } + // Need a new or bigger buffer at this index + if (next_index < pool.size() && pool[next_index].buffer != nil) { + [pool[next_index].buffer release]; + } + if (next_index >= pool.size()) { + pool.push_back({nil, 0}); + } + pool[next_index].buffer = [dev newBufferWithLength:nbytes options:MTLStorageModeShared]; + pool[next_index].size = nbytes; + return pool[next_index++].buffer; + } + + // Called after flush/sync, all staging buffers are safe to reuse + void ResetIndex() { next_index = 0; } + + ~StagingBufferPool() { + for (auto& e : pool) { + if (e.buffer != nil) { + [e.buffer release]; + } + } + } + }; + std::vector staging_pools_; // per device /*! \brief workspace pool */ WorkspacePool pool; // constructor @@ -210,13 +400,14 @@ class MetalThreadEntry { device.device_id = 0; device.device_type = static_cast(kDLMetal); MetalWorkspace* global_ws = MetalWorkspace::Global(); - // by default, set the stream to nullptr, which indicate - // that we are using default stream this->stream.resize(global_ws->devices.size(), nullptr); + this->staging_pools_.resize(global_ws->devices.size()); } ~MetalThreadEntry(); - // Get temp buffer with at least size under dev. + // Get temp buffer with at least size under dev (for GPU→CPU readback). id GetTempBuffer(Device dev, size_t size); + // Get a staging buffer for inlined CPU→GPU copy (from pool). + id GetOrCreateStagingBuffer(Device dev, size_t size); // get the global workspace static MetalThreadEntry* ThreadLocal(); }; diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 5ff9c2dfcd9b..9d0dc5a180a9 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -200,14 +200,18 @@ int GetWarpSize(id dev) { void MetalWorkspace::FreeDataSpace(Device dev, void* ptr) { AUTORELEASEPOOL { - // need to make sure buffer is not in use in command buffer - // before set the purgeable state to empty - // otherwise can cause issues sometimes - this->StreamSync(dev, nullptr); - // MTLBuffer PurgeableState should be set to empty before manual - // release in order to prevent memory leak + Stream* s = CastStreamOrGetDefault(nullptr, dev.device_id); + if (s->HasPendingWork()) { + s->profile.free_syncs++; + // Buffer may be referenced by pending compute/blit encoders. + // Must fully sync, setPurgeableState:Empty on a buffer in an + // uncommitted or incomplete CB crashes Metal. + this->StreamSync(dev, nullptr); + } + // No pending work, safe to release immediately. + // Either nothing was dispatched since last sync, or the GPU→CPU + // readback path already flushed+waited. [(id)ptr setPurgeableState:MTLPurgeableStateEmpty]; - // release the ptr. CFRelease(ptr); }; } @@ -229,25 +233,30 @@ int GetWarpSize(id dev) { if (s->HasErrorHappened()) { LOG(FATAL) << "GPUError: " << s->ErrorDescription(); } - id cb = s->GetCommandBuffer(/*label=*/"TVMCopyDataFromTo"); int from_dev_type = static_cast(dev_from.device_type); int to_dev_type = static_cast(dev_to.device_type); if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) { + s->profile.gpu_to_gpu++; + // GPU→GPU: inline blit into the pending command buffer. + // No flush needed, Metal guarantees encoder ordering within a CB. TVM_FFI_ICHECK_EQ(dev_from.device_id, dev_to.device_id) << "Metal disallow cross device copy."; - id encoder = [cb blitCommandEncoder]; + id encoder = s->GetBlitEncoderOnPendingBuffer(); [encoder copyFromBuffer:(id)(from) sourceOffset:from_offset toBuffer:(id)(to)destinationOffset:to_offset size:size]; [encoder endEncoding]; - [cb commit]; + } else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) { - // copy to a local buffer before get into global buffer. + s->profile.gpu_to_cpu++; + // GPU→CPU: must flush and wait, we need data in CPU memory. + s->FlushCommandBuffer(); id from_buf = (id)(from); if (from_buf.storageMode != MTLStorageModeShared) { id temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(dev_from, size); + id cb = s->GetCommandBuffer("TVMCopyGPUtoCPU"); id encoder = [cb blitCommandEncoder]; [encoder copyFromBuffer:from_buf sourceOffset:from_offset @@ -262,24 +271,31 @@ int GetWarpSize(id dev) { memcpy(static_cast(to) + to_offset, static_cast([from_buf contents]) + from_offset, size); } + } else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) { + s->profile.cpu_to_gpu++; + // CPU→GPU: inline blit into the pending command buffer. + // We use a staging buffer from the pool (not the single temp_buffer_) + // so multiple CPU→GPU copies can be inlined before a flush. id to_buf = (id)(to); if (to_buf.storageMode != MTLStorageModeShared) { - id temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(dev_to, size); - memcpy([temp contents], static_cast(from) + from_offset, size); - id encoder = [cb blitCommandEncoder]; - [encoder copyFromBuffer:temp + MetalThreadEntry* t = MetalThreadEntry::ThreadLocal(); + id staging = t->GetOrCreateStagingBuffer(dev_to, size); + memcpy([staging contents], static_cast(from) + from_offset, size); + id encoder = s->GetBlitEncoderOnPendingBuffer(); + [encoder copyFromBuffer:staging sourceOffset:0 toBuffer:to_buf destinationOffset:to_offset size:size]; [encoder endEncoding]; - [cb commit]; - [cb waitUntilCompleted]; + // No flush, no wait. Metal executes encoders in order within the CB. + // The staging buffer stays alive until flush, when the pool resets. } else { memcpy(static_cast([to_buf contents]) + to_offset, static_cast(from) + from_offset, size); } + } else { LOG(FATAL) << "Expect copy from/to Metal or between Metal" << ", from=" << from_dev_type << ", to=" << to_dev_type; @@ -302,10 +318,9 @@ int GetWarpSize(id dev) { void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { AUTORELEASEPOOL { Stream* s = CastStreamOrGetDefault(stream, dev.device_id); - // commit an empty command buffer and wait until it completes. - id cb = s->GetCommandBuffer(/*label=*/"TVMStreamSync"); - [cb commit]; - [cb waitUntilCompleted]; + s->Synchronize(); + // After sync, all staging buffers are safe to reuse. + MetalThreadEntry::ThreadLocal()->staging_pools_[dev.device_id].ResetIndex(); if (s->HasErrorHappened()) { LOG(FATAL) << "GPUError: " << s->ErrorDescription(); } @@ -336,10 +351,17 @@ int GetWarpSize(id dev) { if (temp_buffer_[dev.device_id] == nil || temp_buffer_[dev.device_id].length < size) { id mtl_dev = MetalWorkspace::Global()->GetDevice(dev); if (temp_buffer_[dev.device_id] != nil) { - // need to make sure buffer is not in use in command buffer - // before set the purgeable state to empty - // otherwise can cause issues sometimes - MetalWorkspace::Global()->StreamSync(dev, nullptr); + // The caller (GPU→CPU path in CopyDataFromTo) already called + // FlushCommandBuffer() before calling us, so all pending work + // using this buffer has been committed. We just need to wait + // for completion before releasing. + auto* ws = MetalWorkspace::Global(); + Stream* s = ws->CastStreamOrGetDefault(nullptr, dev.device_id); + if (s->HasPendingWork()) { + // Only sync if there's actually pending work (shouldn't happen + // since caller flushed, but be safe). + ws->StreamSync(dev, nullptr); + } [temp_buffer_[dev.device_id] setPurgeableState:MTLPurgeableStateEmpty]; [temp_buffer_[dev.device_id] release]; } @@ -348,6 +370,11 @@ int GetWarpSize(id dev) { return temp_buffer_[dev.device_id]; } +id MetalThreadEntry::GetOrCreateStagingBuffer(Device dev, size_t size) { + id mtl_dev = MetalWorkspace::Global()->GetDevice(dev); + return staging_pools_[dev.device_id].GetOrCreate(mtl_dev, size); +} + MetalThreadEntry* MetalThreadEntry::ThreadLocal() { static thread_local MetalThreadEntry inst; return &inst; @@ -362,7 +389,27 @@ int GetWarpSize(id dev) { *rv = static_cast(ptr); }) .def("metal.ResetGlobalState", - []() { MetalWorkspace::Global()->ReinitializeDefaultStreams(); }); + []() { MetalWorkspace::Global()->ReinitializeDefaultStreams(); }) + .def("metal.GetProfileCounters", + [](int device_id) { + auto* ws = MetalWorkspace::Global(); + Stream* s = ws->CastStreamOrGetDefault(nullptr, device_id); + const auto& p = s->profile; + ffi::Map result; + result.Set("dispatches", static_cast(p.dispatches)); + result.Set("flushes", static_cast(p.flushes)); + result.Set("syncs", static_cast(p.syncs)); + result.Set("blits", static_cast(p.blits)); + result.Set("gpu_to_cpu", static_cast(p.gpu_to_cpu)); + result.Set("cpu_to_gpu", static_cast(p.cpu_to_gpu)); + result.Set("gpu_to_gpu", static_cast(p.gpu_to_gpu)); + result.Set("free_syncs", static_cast(p.free_syncs)); + return result; + }) + .def("metal.ResetProfileCounters", [](int device_id) { + auto* ws = MetalWorkspace::Global(); + ws->CastStreamOrGetDefault(nullptr, device_id)->profile.Reset(); + }); } class MetalTimerNode : public TimerNode { diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 0066b651fc7d..6837404ad3be 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -213,10 +213,9 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2); auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup; TVM_FFI_ICHECK_LE(blockSize, maxTotalThreadsPerThreadgroup); - // attach error message directly in this functio - id cb = stream->GetCommandBuffer(/*label=*/"TVMKernel:" + func_name_, - /*attach_error_callback=*/false); - id encoder = [cb computeCommandEncoder]; + // Reuse the pending compute encoder to batch dispatches. + // The encoder is flushed on sync, copy, or buffer deallocation. + id encoder = stream->GetPendingComputeEncoder(func_name_); [encoder setComputePipelineState:scache_[device_id]]; for (size_t i = 0; i < num_buffer_args_; ++i) { void* buf = args[static_cast(i)].cast(); @@ -231,18 +230,6 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) MTLSize dimGrid = MTLSizeMake(wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), wl.block_dim(2)); [encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock]; - [encoder endEncoding]; - // attach error message with function name - [cb addCompletedHandler:^(id buffer) { - if (buffer.status == MTLCommandBufferStatusError) { - TVM_FFI_ICHECK(buffer.error != nil); - std::ostringstream os; - os << "GPUError happens after running " << func_name_ << ": " - << buffer.error.localizedDescription.UTF8String; - stream->SetError(os.str()); - } - }]; - [cb commit]; }; } From a3eb5d19343ca76c837ee3add0407ee4a9f77eee Mon Sep 17 00:00:00 2001 From: mitiskuma Date: Fri, 6 Mar 2026 02:07:24 +0100 Subject: [PATCH 2/4] [Metal] Add nil check for staging buffer allocation, fix comment --- src/runtime/metal/metal_common.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 66fdb6d7c3a4..ca3d5b50c24a 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -363,7 +363,7 @@ class MetalThreadEntry { size_t size = 0; }; std::vector pool; - size_t next_index = 0; // round-robin within current batch + size_t next_index = 0; // sequential within current batch, reset on sync id GetOrCreate(id dev, size_t nbytes) { if (next_index < pool.size() && pool[next_index].size >= nbytes) { @@ -377,6 +377,8 @@ class MetalThreadEntry { pool.push_back({nil, 0}); } pool[next_index].buffer = [dev newBufferWithLength:nbytes options:MTLStorageModeShared]; + TVM_FFI_ICHECK(pool[next_index].buffer != nil) + << "Failed to allocate staging buffer of size " << nbytes; pool[next_index].size = nbytes; return pool[next_index++].buffer; } From b18aab8c5d4b3cec75c0602bb5df4bad4bca22d5 Mon Sep 17 00:00:00 2001 From: mitiskuma Date: Fri, 6 Mar 2026 02:36:31 +0100 Subject: [PATCH 3/4] fixes/renamings, Size query for sync threshold --- src/runtime/metal/metal_common.h | 52 ++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index ca3d5b50c24a..b20d924ed6f1 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -119,6 +119,9 @@ class AutoReleasePoolWrapper { class Stream { public: explicit Stream(id device) { queue_ = [device newCommandQueue]; } + // Stream is only destroyed during MetalWorkspace teardown (process exit + // or ReinitializeDefaultStreams), so no GPU work is in flight. We flush + // to commit any pending CB but do not wait for completion. ~Stream() { FlushCommandBuffer(); [queue_ release]; @@ -182,7 +185,7 @@ class Stream { * and before subsequent ones. */ id GetBlitEncoderOnPendingBuffer() { - PauseComputeEncoder(); + EndPendingComputeEncoder(); id cb = GetOrCreatePendingCommandBuffer(); profile.blits++; return [cb blitCommandEncoder]; @@ -194,7 +197,7 @@ class Stream { * Safe to call when nothing is pending (no-op). */ void FlushCommandBuffer() { - PauseComputeEncoder(); + EndPendingComputeEncoder(); if (pending_command_buffer_ != nil) { [pending_command_buffer_ commit]; [pending_command_buffer_ release]; @@ -267,7 +270,7 @@ class Stream { } /*! \brief End the active compute encoder without committing the command buffer. */ - void PauseComputeEncoder() { + void EndPendingComputeEncoder() { if (pending_compute_encoder_ != nil) { [pending_compute_encoder_ endEncoding]; [pending_compute_encoder_ release]; @@ -358,41 +361,46 @@ class MetalThreadEntry { * Buffers are recycled after FlushCommandBuffer()/Synchronize(). */ struct StagingBufferPool { - struct Entry { - id buffer = nil; - size_t size = 0; - }; - std::vector pool; - size_t next_index = 0; // sequential within current batch, reset on sync - + public: id GetOrCreate(id dev, size_t nbytes) { - if (next_index < pool.size() && pool[next_index].size >= nbytes) { - return pool[next_index++].buffer; + if (next_index_ < pool_.size() && pool_[next_index_].size >= nbytes) { + return pool_[next_index_++].buffer; } // Need a new or bigger buffer at this index - if (next_index < pool.size() && pool[next_index].buffer != nil) { - [pool[next_index].buffer release]; + if (next_index_ < pool_.size() && pool_[next_index_].buffer != nil) { + [pool_[next_index_].buffer release]; } - if (next_index >= pool.size()) { - pool.push_back({nil, 0}); + if (next_index_ >= pool_.size()) { + pool_.push_back({nil, 0}); } - pool[next_index].buffer = [dev newBufferWithLength:nbytes options:MTLStorageModeShared]; - TVM_FFI_ICHECK(pool[next_index].buffer != nil) + pool_[next_index_].buffer = [dev newBufferWithLength:nbytes options:MTLStorageModeShared]; + TVM_FFI_ICHECK(pool_[next_index_].buffer != nil) << "Failed to allocate staging buffer of size " << nbytes; - pool[next_index].size = nbytes; - return pool[next_index++].buffer; + pool_[next_index_].size = nbytes; + return pool_[next_index_++].buffer; } // Called after flush/sync, all staging buffers are safe to reuse - void ResetIndex() { next_index = 0; } + void ResetIndex() { next_index_ = 0; } + + // Number of staging buffers used in the current batch + size_t Size() const { return next_index_; } ~StagingBufferPool() { - for (auto& e : pool) { + for (auto& e : pool_) { if (e.buffer != nil) { [e.buffer release]; } } } + + private: + struct Entry { + id buffer = nil; + size_t size = 0; + }; + std::vector pool_; + size_t next_index_ = 0; // sequential within current batch, reset on sync }; std::vector staging_pools_; // per device /*! \brief workspace pool */ From 7335ab28a4b1cdb972933852d1ee169dcd04e206 Mon Sep 17 00:00:00 2001 From: mitiskuma Date: Fri, 6 Mar 2026 14:11:42 +0100 Subject: [PATCH 4/4] guard on maximum staging buffer --- src/runtime/metal/metal_common.h | 14 ++++++++++++++ src/runtime/metal/metal_device_api.mm | 11 +++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index b20d924ed6f1..cc538f84dce0 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -362,6 +362,13 @@ class MetalThreadEntry { */ struct StagingBufferPool { public: + /*! \brief Maximum staging buffers before requiring a flush. + * Prevents unbounded pool growth in workloads with many CPU→GPU copies + * between syncs. When this limit is reached, the caller must flush the + * stream (to make all pending staging buffers safe to reuse) before + * requesting more buffers. */ + static constexpr size_t kMaxStagingBuffers = 64; + id GetOrCreate(id dev, size_t nbytes) { if (next_index_ < pool_.size() && pool_[next_index_].size >= nbytes) { return pool_[next_index_++].buffer; @@ -386,6 +393,9 @@ class MetalThreadEntry { // Number of staging buffers used in the current batch size_t Size() const { return next_index_; } + // True when the pool has reached its limit and needs a flush before more allocations + bool NeedsFlush() const { return next_index_ >= kMaxStagingBuffers; } + ~StagingBufferPool() { for (auto& e : pool_) { if (e.buffer != nil) { @@ -418,6 +428,10 @@ class MetalThreadEntry { id GetTempBuffer(Device dev, size_t size); // Get a staging buffer for inlined CPU→GPU copy (from pool). id GetOrCreateStagingBuffer(Device dev, size_t size); + // Check if the staging pool has reached its limit and needs a flush. + bool StagingPoolNeedsFlush(Device dev); + // Reset the staging pool index after a flush. + void ResetStagingPool(Device dev); // get the global workspace static MetalThreadEntry* ThreadLocal(); }; diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 9d0dc5a180a9..f240f589c103 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -280,6 +280,11 @@ int GetWarpSize(id dev) { id to_buf = (id)(to); if (to_buf.storageMode != MTLStorageModeShared) { MetalThreadEntry* t = MetalThreadEntry::ThreadLocal(); + // If the staging pool is full, flush pending work so buffers can be reused. + if (t->StagingPoolNeedsFlush(dev_to)) { + s->FlushCommandBuffer(); + t->ResetStagingPool(dev_to); + } id staging = t->GetOrCreateStagingBuffer(dev_to, size); memcpy([staging contents], static_cast(from) + from_offset, size); id encoder = s->GetBlitEncoderOnPendingBuffer(); @@ -375,6 +380,12 @@ int GetWarpSize(id dev) { return staging_pools_[dev.device_id].GetOrCreate(mtl_dev, size); } +bool MetalThreadEntry::StagingPoolNeedsFlush(Device dev) { + return staging_pools_[dev.device_id].NeedsFlush(); +} + +void MetalThreadEntry::ResetStagingPool(Device dev) { staging_pools_[dev.device_id].ResetIndex(); } + MetalThreadEntry* MetalThreadEntry::ThreadLocal() { static thread_local MetalThreadEntry inst; return &inst;