-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[Metal] Batched command dispatch and staging buffer pool #18877
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2029121
a3eb5d1
b18aab8
7335ab2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -103,13 +103,37 @@ class AutoReleasePoolWrapper { | |
| }; | ||
|
|
||
| /*! | ||
| * \brief Structure for error handling in queues | ||
| * \brief Metal command stream with batched dispatch support. | ||
| * | ||
| * Compute dispatches are batched into a single command buffer via | ||
| * GetPendingComputeEncoder(). Blit operations (copies) are interleaved | ||
| * on the same command buffer via GetBlitEncoderOnPendingBuffer(). | ||
| * The command buffer is committed when FlushCommandBuffer() is called. | ||
| * | ||
| * Must call FlushCommandBuffer() before: | ||
| * - GPU→CPU readback (need data in CPU memory) | ||
| * - Buffer deallocation (FreeDataSpace, setPurgeableState:Empty on | ||
| * a buffer referenced by an uncommitted CB crashes Metal) | ||
| * - Stream sync (StreamSync / Synchronize) | ||
| */ | ||
| class Stream { | ||
| public: | ||
| explicit Stream(id<MTLDevice> device) { queue_ = [device newCommandQueue]; } | ||
| ~Stream() { [queue_ release]; } | ||
| id<MTLCommandBuffer> GetCommandBuffer(std::string label = "", bool attach_error_callback = true) { | ||
| // Stream is only destroyed during MetalWorkspace teardown (process exit | ||
| // or ReinitializeDefaultStreams), so no GPU work is in flight. We flush | ||
| // to commit any pending CB but do not wait for completion. | ||
| ~Stream() { | ||
| FlushCommandBuffer(); | ||
| [queue_ release]; | ||
| } | ||
|
|
||
| /*! | ||
| * \brief Get a standalone command buffer (for GPU→CPU readback only). | ||
| * | ||
| * Used when we need a separate command buffer that we can commit | ||
| * and waitUntilCompleted on independently. | ||
| */ | ||
| id<MTLCommandBuffer> GetCommandBuffer(std::string label = "") { | ||
| id<MTLCommandBuffer> cb = [queue_ commandBuffer]; | ||
| if (!label.empty()) { | ||
| cb.label = [NSString stringWithUTF8String:label.c_str()]; | ||
|
|
@@ -123,6 +147,99 @@ class Stream { | |
| return cb; | ||
| } | ||
|
|
||
| /*! | ||
| * \brief Get the pending compute command encoder, creating one if needed. | ||
| * | ||
| * Multiple compute dispatches are batched into a single command buffer | ||
| * and encoder. Blit operations (copies) can be interleaved on the same | ||
| * command buffer via GetBlitEncoderOnPendingBuffer(). The entire command | ||
| * buffer is committed when FlushCommandBuffer() is called. | ||
| * | ||
| * Must flush before: | ||
| * - GPU→CPU readback (need data on CPU immediately) | ||
| * - Buffer deallocation (FreeDataSpace) | ||
| * - Stream sync (StreamSync) | ||
| */ | ||
| id<MTLComputeCommandEncoder> GetPendingComputeEncoder(const std::string& kernel_name = "") { | ||
| if (pending_compute_encoder_ == nil) { | ||
| id<MTLCommandBuffer> cb = GetOrCreatePendingCommandBuffer(); | ||
| pending_compute_encoder_ = [[cb computeCommandEncoder] retain]; | ||
| } | ||
| if (!kernel_name.empty()) { | ||
| last_dispatched_kernel_ = kernel_name; | ||
| } | ||
| profile.dispatches++; | ||
| return pending_compute_encoder_; | ||
| } | ||
|
|
||
| /*! | ||
| * \brief Get a blit encoder on the pending command buffer. | ||
| * | ||
| * Pauses the active compute encoder (if any), creates a blit encoder | ||
| * on the same command buffer. Caller must call [encoder endEncoding] | ||
| * when done. The next GetPendingComputeEncoder() call will create a | ||
| * new compute encoder on the same command buffer. | ||
| * | ||
| * Metal guarantees sequential ordering of encoders within a command | ||
| * buffer, so blits encoded here execute after prior compute dispatches | ||
| * and before subsequent ones. | ||
| */ | ||
| id<MTLBlitCommandEncoder> GetBlitEncoderOnPendingBuffer() { | ||
| EndPendingComputeEncoder(); | ||
| id<MTLCommandBuffer> cb = GetOrCreatePendingCommandBuffer(); | ||
| profile.blits++; | ||
| return [cb blitCommandEncoder]; | ||
| } | ||
|
|
||
| /*! | ||
| * \brief Flush: end active encoder, commit the command buffer. | ||
| * | ||
| * Safe to call when nothing is pending (no-op). | ||
| */ | ||
| void FlushCommandBuffer() { | ||
| EndPendingComputeEncoder(); | ||
| if (pending_command_buffer_ != nil) { | ||
| [pending_command_buffer_ commit]; | ||
| [pending_command_buffer_ release]; | ||
| pending_command_buffer_ = nil; | ||
| profile.flushes++; | ||
| } | ||
| } | ||
|
|
||
| /*! | ||
| * \brief Flush pending work, then wait for all submitted work to complete. | ||
| */ | ||
| void Synchronize() { | ||
| FlushCommandBuffer(); | ||
| id<MTLCommandBuffer> cb = [queue_ commandBuffer]; | ||
| [cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) { | ||
| if (buffer.status == MTLCommandBufferStatusError) { | ||
| TVM_FFI_ICHECK(buffer.error != nil); | ||
| this->SetError(buffer.error.localizedDescription.UTF8String); | ||
| } | ||
| }]; | ||
| [cb commit]; | ||
| [cb waitUntilCompleted]; | ||
| profile.syncs++; | ||
| } | ||
|
|
||
| bool HasPendingWork() const { return pending_command_buffer_ != nil; } | ||
|
|
||
| /*! \brief Profiling counters for diagnosing dispatch/copy/sync overhead. */ | ||
| struct ProfileCounters { | ||
| size_t dispatches = 0; | ||
| size_t flushes = 0; | ||
| size_t syncs = 0; | ||
| size_t blits = 0; | ||
| size_t gpu_to_cpu = 0; | ||
| size_t cpu_to_gpu = 0; | ||
| size_t gpu_to_gpu = 0; | ||
| size_t free_syncs = 0; // FreeDataSpace calls that triggered a sync | ||
|
|
||
| void Reset() { *this = ProfileCounters(); } | ||
| }; | ||
| ProfileCounters profile; | ||
|
|
||
| void SetError(std::string error_description) { | ||
| error_happened_ = true; | ||
| error_description_ = std::move(error_description); | ||
|
|
@@ -133,8 +250,42 @@ class Stream { | |
| const std::string& ErrorDescription() const { return error_description_; } | ||
|
|
||
| private: | ||
| /*! \brief Get or create the pending command buffer (shared by compute and blit). */ | ||
| id<MTLCommandBuffer> GetOrCreatePendingCommandBuffer() { | ||
| if (pending_command_buffer_ == nil) { | ||
| pending_command_buffer_ = [[queue_ commandBuffer] retain]; | ||
| pending_command_buffer_.label = @"TVMBatched"; | ||
| [pending_command_buffer_ addCompletedHandler:^(id<MTLCommandBuffer> buffer) { | ||
| if (buffer.status == MTLCommandBufferStatusError) { | ||
| TVM_FFI_ICHECK(buffer.error != nil); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TVM_FFI_ICHECK is the standard error reporting mechanism used throughout the TVM Metal runtime. The pre-existing code (GetCommandBuffer, which this PR does not change) already uses TVM_FFI_ICHECK in the same completion handler pattern. This is not a regression introduced by this PR. |
||
| std::string msg = buffer.error.localizedDescription.UTF8String; | ||
| if (!this->last_dispatched_kernel_.empty()) { | ||
| msg = "GPUError after kernel " + this->last_dispatched_kernel_ + ": " + msg; | ||
| } | ||
| this->SetError(msg); | ||
| } | ||
| }]; | ||
| } | ||
| return pending_command_buffer_; | ||
| } | ||
|
|
||
| /*! \brief End the active compute encoder without committing the command buffer. */ | ||
| void EndPendingComputeEncoder() { | ||
| if (pending_compute_encoder_ != nil) { | ||
| [pending_compute_encoder_ endEncoding]; | ||
| [pending_compute_encoder_ release]; | ||
| pending_compute_encoder_ = nil; | ||
| } | ||
| } | ||
|
|
||
| // Queue | ||
| id<MTLCommandQueue> queue_; | ||
| // Pending command buffer (shared by compute and blit encoders) | ||
| id<MTLCommandBuffer> pending_command_buffer_ = nil; | ||
| // Active compute encoder on the pending command buffer (nil when paused/blit) | ||
| id<MTLComputeCommandEncoder> pending_compute_encoder_ = nil; | ||
| // Last dispatched kernel name (for error diagnostics) | ||
| std::string last_dispatched_kernel_; | ||
| // Check if error happened in one previous run | ||
| bool error_happened_{false}; | ||
| // error description | ||
|
|
@@ -201,22 +352,86 @@ class MetalThreadEntry { | |
| Device device; | ||
| /*! \brief The current stream */ | ||
| std::vector<TVMStreamHandle> stream; | ||
| /*! \brief The shared buffer used for copy. */ | ||
| /*! \brief The shared buffer used for GPU→CPU readback. */ | ||
| std::vector<id<MTLBuffer>> temp_buffer_; | ||
| /*! | ||
| * \brief Pool of staging buffers for CPU→GPU copies that are inlined | ||
| * into the pending command buffer. Each inlined copy needs its own | ||
| * staging buffer because the GPU reads them asynchronously. | ||
| * Buffers are recycled after FlushCommandBuffer()/Synchronize(). | ||
| */ | ||
| struct StagingBufferPool { | ||
|
mitiskuma marked this conversation as resolved.
|
||
| public: | ||
| /*! \brief Maximum staging buffers before requiring a flush. | ||
| * Prevents unbounded pool growth in workloads with many CPU→GPU copies | ||
| * between syncs. When this limit is reached, the caller must flush the | ||
| * stream (to make all pending staging buffers safe to reuse) before | ||
| * requesting more buffers. */ | ||
| static constexpr size_t kMaxStagingBuffers = 64; | ||
|
|
||
| id<MTLBuffer> GetOrCreate(id<MTLDevice> dev, size_t nbytes) { | ||
| if (next_index_ < pool_.size() && pool_[next_index_].size >= nbytes) { | ||
| return pool_[next_index_++].buffer; | ||
| } | ||
| // Need a new or bigger buffer at this index | ||
| if (next_index_ < pool_.size() && pool_[next_index_].buffer != nil) { | ||
| [pool_[next_index_].buffer release]; | ||
| } | ||
| if (next_index_ >= pool_.size()) { | ||
| pool_.push_back({nil, 0}); | ||
| } | ||
| pool_[next_index_].buffer = [dev newBufferWithLength:nbytes options:MTLStorageModeShared]; | ||
| TVM_FFI_ICHECK(pool_[next_index_].buffer != nil) | ||
| << "Failed to allocate staging buffer of size " << nbytes; | ||
| pool_[next_index_].size = nbytes; | ||
| return pool_[next_index_++].buffer; | ||
| } | ||
|
|
||
| // Called after flush/sync, all staging buffers are safe to reuse | ||
| void ResetIndex() { next_index_ = 0; } | ||
|
|
||
| // Number of staging buffers used in the current batch | ||
| size_t Size() const { return next_index_; } | ||
|
|
||
| // True when the pool has reached its limit and needs a flush before more allocations | ||
| bool NeedsFlush() const { return next_index_ >= kMaxStagingBuffers; } | ||
|
|
||
| ~StagingBufferPool() { | ||
| for (auto& e : pool_) { | ||
| if (e.buffer != nil) { | ||
| [e.buffer release]; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private: | ||
| struct Entry { | ||
| id<MTLBuffer> buffer = nil; | ||
| size_t size = 0; | ||
| }; | ||
| std::vector<Entry> pool_; | ||
| size_t next_index_ = 0; // sequential within current batch, reset on sync | ||
| }; | ||
| std::vector<StagingBufferPool> staging_pools_; // per device | ||
| /*! \brief workspace pool */ | ||
| WorkspacePool pool; | ||
| // constructor | ||
| MetalThreadEntry() : pool(static_cast<DLDeviceType>(kDLMetal), MetalWorkspace::Global()) { | ||
| device.device_id = 0; | ||
| device.device_type = static_cast<DLDeviceType>(kDLMetal); | ||
| MetalWorkspace* global_ws = MetalWorkspace::Global(); | ||
| // by default, set the stream to nullptr, which indicate | ||
| // that we are using default stream | ||
| this->stream.resize(global_ws->devices.size(), nullptr); | ||
| this->staging_pools_.resize(global_ws->devices.size()); | ||
| } | ||
| ~MetalThreadEntry(); | ||
| // Get temp buffer with at least size under dev. | ||
| // Get temp buffer with at least size under dev (for GPU→CPU readback). | ||
| id<MTLBuffer> GetTempBuffer(Device dev, size_t size); | ||
| // Get a staging buffer for inlined CPU→GPU copy (from pool). | ||
| id<MTLBuffer> GetOrCreateStagingBuffer(Device dev, size_t size); | ||
| // Check if the staging pool has reached its limit and needs a flush. | ||
| bool StagingPoolNeedsFlush(Device dev); | ||
| // Reset the staging pool index after a flush. | ||
| void ResetStagingPool(Device dev); | ||
| // get the global workspace | ||
| static MetalThreadEntry* ThreadLocal(); | ||
| }; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A critical Use-After-Free (UAF) vulnerability exists in the
Streamclass destructor. TheStreamdestructor flushes the pending command buffer but does not wait for its completion. This allows theaddCompletedHandlercallback, which capturesthis, to access a deletedStreamobject if the stream is destroyed while work is pending, leading to a process crash. To prevent this UAF, the destructor must ensure all pending work is completed before destruction. ReplacingFlushCommandBuffer()withSynchronize()will resolve this.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Stream destructor is only called during MetalWorkspace teardown (process exit or ReinitializeDefaultStreams). At that point no GPU work is in flight. FlushCommandBuffer commits the CB but we do not need to wait for completion because the process is tearing down and the completion handler captures
thiswhich is about to be freed regardless. Using Synchronize() here would block the main thread on GPU idle for no reason. The original code (before this PR) also did not wait on completion during destruction.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would be useful to document this via comment for context. This indeed places an implicit requirement that Stream have to be destructed in teardown