Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 222 additions & 7 deletions src/runtime/metal/metal_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,37 @@ class AutoReleasePoolWrapper {
};

/*!
* \brief Structure for error handling in queues
* \brief Metal command stream with batched dispatch support.
*
* Compute dispatches are batched into a single command buffer via
* GetPendingComputeEncoder(). Blit operations (copies) are interleaved
* on the same command buffer via GetBlitEncoderOnPendingBuffer().
* The command buffer is committed when FlushCommandBuffer() is called.
*
* Must call FlushCommandBuffer() before:
* - GPU→CPU readback (need data in CPU memory)
* - Buffer deallocation (FreeDataSpace, setPurgeableState:Empty on
* a buffer referenced by an uncommitted CB crashes Metal)
* - Stream sync (StreamSync / Synchronize)
*/
class Stream {
public:
explicit Stream(id<MTLDevice> device) { queue_ = [device newCommandQueue]; }
~Stream() { [queue_ release]; }
id<MTLCommandBuffer> GetCommandBuffer(std::string label = "", bool attach_error_callback = true) {
// Stream is only destroyed during MetalWorkspace teardown (process exit
// or ReinitializeDefaultStreams), so no GPU work is in flight. We flush
// to commit any pending CB but do not wait for completion.
~Stream() {
FlushCommandBuffer();
[queue_ release];
}
Comment on lines +125 to +128

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

A critical Use-After-Free (UAF) vulnerability exists in the Stream class destructor. The Stream destructor flushes the pending command buffer but does not wait for its completion. This allows the addCompletedHandler callback, which captures this, to access a deleted Stream object if the stream is destroyed while work is pending, leading to a process crash. To prevent this UAF, the destructor must ensure all pending work is completed before destruction. Replacing FlushCommandBuffer() with Synchronize() will resolve this.

  ~Stream() {
    Synchronize();
    [queue_ release];
  }

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Stream destructor is only called during MetalWorkspace teardown (process exit or ReinitializeDefaultStreams). At that point no GPU work is in flight. FlushCommandBuffer commits the CB but we do not need to wait for completion because the process is tearing down and the completion handler captures this which is about to be freed regardless. Using Synchronize() here would block the main thread on GPU idle for no reason. The original code (before this PR) also did not wait on completion during destruction.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be useful to document this via comment for context. This indeed places an implicit requirement that Stream have to be destructed in teardown


/*!
* \brief Get a standalone command buffer (for GPU→CPU readback only).
*
* Used when we need a separate command buffer that we can commit
* and waitUntilCompleted on independently.
*/
id<MTLCommandBuffer> GetCommandBuffer(std::string label = "") {
id<MTLCommandBuffer> cb = [queue_ commandBuffer];
if (!label.empty()) {
cb.label = [NSString stringWithUTF8String:label.c_str()];
Expand All @@ -123,6 +147,99 @@ class Stream {
return cb;
}

/*!
* \brief Get the pending compute command encoder, creating one if needed.
*
* Multiple compute dispatches are batched into a single command buffer
* and encoder. Blit operations (copies) can be interleaved on the same
* command buffer via GetBlitEncoderOnPendingBuffer(). The entire command
* buffer is committed when FlushCommandBuffer() is called.
*
* Must flush before:
* - GPU→CPU readback (need data on CPU immediately)
* - Buffer deallocation (FreeDataSpace)
* - Stream sync (StreamSync)
*/
id<MTLComputeCommandEncoder> GetPendingComputeEncoder(const std::string& kernel_name = "") {
if (pending_compute_encoder_ == nil) {
id<MTLCommandBuffer> cb = GetOrCreatePendingCommandBuffer();
pending_compute_encoder_ = [[cb computeCommandEncoder] retain];
}
if (!kernel_name.empty()) {
last_dispatched_kernel_ = kernel_name;
}
profile.dispatches++;
return pending_compute_encoder_;
}

/*!
* \brief Get a blit encoder on the pending command buffer.
*
* Pauses the active compute encoder (if any), creates a blit encoder
* on the same command buffer. Caller must call [encoder endEncoding]
* when done. The next GetPendingComputeEncoder() call will create a
* new compute encoder on the same command buffer.
*
* Metal guarantees sequential ordering of encoders within a command
* buffer, so blits encoded here execute after prior compute dispatches
* and before subsequent ones.
*/
id<MTLBlitCommandEncoder> GetBlitEncoderOnPendingBuffer() {
EndPendingComputeEncoder();
id<MTLCommandBuffer> cb = GetOrCreatePendingCommandBuffer();
profile.blits++;
return [cb blitCommandEncoder];
}

/*!
* \brief Flush: end active encoder, commit the command buffer.
*
* Safe to call when nothing is pending (no-op).
*/
void FlushCommandBuffer() {
EndPendingComputeEncoder();
if (pending_command_buffer_ != nil) {
[pending_command_buffer_ commit];
[pending_command_buffer_ release];
pending_command_buffer_ = nil;
profile.flushes++;
}
}

/*!
* \brief Flush pending work, then wait for all submitted work to complete.
*/
void Synchronize() {
FlushCommandBuffer();
id<MTLCommandBuffer> cb = [queue_ commandBuffer];
[cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
if (buffer.status == MTLCommandBufferStatusError) {
TVM_FFI_ICHECK(buffer.error != nil);
this->SetError(buffer.error.localizedDescription.UTF8String);
}
}];
[cb commit];
[cb waitUntilCompleted];
profile.syncs++;
}

bool HasPendingWork() const { return pending_command_buffer_ != nil; }

/*! \brief Profiling counters for diagnosing dispatch/copy/sync overhead. */
struct ProfileCounters {
size_t dispatches = 0;
size_t flushes = 0;
size_t syncs = 0;
size_t blits = 0;
size_t gpu_to_cpu = 0;
size_t cpu_to_gpu = 0;
size_t gpu_to_gpu = 0;
size_t free_syncs = 0; // FreeDataSpace calls that triggered a sync

void Reset() { *this = ProfileCounters(); }
};
ProfileCounters profile;

void SetError(std::string error_description) {
error_happened_ = true;
error_description_ = std::move(error_description);
Expand All @@ -133,8 +250,42 @@ class Stream {
const std::string& ErrorDescription() const { return error_description_; }

private:
/*! \brief Get or create the pending command buffer (shared by compute and blit). */
id<MTLCommandBuffer> GetOrCreatePendingCommandBuffer() {
if (pending_command_buffer_ == nil) {
pending_command_buffer_ = [[queue_ commandBuffer] retain];
pending_command_buffer_.label = @"TVMBatched";
[pending_command_buffer_ addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
if (buffer.status == MTLCommandBufferStatusError) {
TVM_FFI_ICHECK(buffer.error != nil);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The completedHandler uses TVM_FFI_ICHECK, which throws an exception if the condition fails. Because this handler is executed on a background thread owned by the Metal runtime, an unhandled exception will likely cause the entire process to terminate abruptly (via std::terminate), as there is typically no exception handling logic on the stack of these background threads.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TVM_FFI_ICHECK is the standard error reporting mechanism used throughout the TVM Metal runtime. The pre-existing code (GetCommandBuffer, which this PR does not change) already uses TVM_FFI_ICHECK in the same completion handler pattern. This is not a regression introduced by this PR.

std::string msg = buffer.error.localizedDescription.UTF8String;
if (!this->last_dispatched_kernel_.empty()) {
msg = "GPUError after kernel " + this->last_dispatched_kernel_ + ": " + msg;
}
this->SetError(msg);
}
}];
}
return pending_command_buffer_;
}

/*! \brief End the active compute encoder without committing the command buffer. */
void EndPendingComputeEncoder() {
if (pending_compute_encoder_ != nil) {
[pending_compute_encoder_ endEncoding];
[pending_compute_encoder_ release];
pending_compute_encoder_ = nil;
}
}

// Queue
id<MTLCommandQueue> queue_;
// Pending command buffer (shared by compute and blit encoders)
id<MTLCommandBuffer> pending_command_buffer_ = nil;
// Active compute encoder on the pending command buffer (nil when paused/blit)
id<MTLComputeCommandEncoder> pending_compute_encoder_ = nil;
// Last dispatched kernel name (for error diagnostics)
std::string last_dispatched_kernel_;
// Check if error happened in one previous run
bool error_happened_{false};
// error description
Expand Down Expand Up @@ -201,22 +352,86 @@ class MetalThreadEntry {
Device device;
/*! \brief The current stream */
std::vector<TVMStreamHandle> stream;
/*! \brief The shared buffer used for copy. */
/*! \brief The shared buffer used for GPU→CPU readback. */
std::vector<id<MTLBuffer>> temp_buffer_;
/*!
* \brief Pool of staging buffers for CPU→GPU copies that are inlined
* into the pending command buffer. Each inlined copy needs its own
* staging buffer because the GPU reads them asynchronously.
* Buffers are recycled after FlushCommandBuffer()/Synchronize().
*/
struct StagingBufferPool {
Comment thread
mitiskuma marked this conversation as resolved.
public:
/*! \brief Maximum staging buffers before requiring a flush.
* Prevents unbounded pool growth in workloads with many CPU→GPU copies
* between syncs. When this limit is reached, the caller must flush the
* stream (to make all pending staging buffers safe to reuse) before
* requesting more buffers. */
static constexpr size_t kMaxStagingBuffers = 64;

id<MTLBuffer> GetOrCreate(id<MTLDevice> dev, size_t nbytes) {
if (next_index_ < pool_.size() && pool_[next_index_].size >= nbytes) {
return pool_[next_index_++].buffer;
}
// Need a new or bigger buffer at this index
if (next_index_ < pool_.size() && pool_[next_index_].buffer != nil) {
[pool_[next_index_].buffer release];
}
if (next_index_ >= pool_.size()) {
pool_.push_back({nil, 0});
}
pool_[next_index_].buffer = [dev newBufferWithLength:nbytes options:MTLStorageModeShared];
TVM_FFI_ICHECK(pool_[next_index_].buffer != nil)
<< "Failed to allocate staging buffer of size " << nbytes;
pool_[next_index_].size = nbytes;
return pool_[next_index_++].buffer;
}

// Called after flush/sync, all staging buffers are safe to reuse
void ResetIndex() { next_index_ = 0; }

// Number of staging buffers used in the current batch
size_t Size() const { return next_index_; }

// True when the pool has reached its limit and needs a flush before more allocations
bool NeedsFlush() const { return next_index_ >= kMaxStagingBuffers; }

~StagingBufferPool() {
for (auto& e : pool_) {
if (e.buffer != nil) {
[e.buffer release];
}
}
}

private:
struct Entry {
id<MTLBuffer> buffer = nil;
size_t size = 0;
};
std::vector<Entry> pool_;
size_t next_index_ = 0; // sequential within current batch, reset on sync
};
std::vector<StagingBufferPool> staging_pools_; // per device
/*! \brief workspace pool */
WorkspacePool pool;
// constructor
MetalThreadEntry() : pool(static_cast<DLDeviceType>(kDLMetal), MetalWorkspace::Global()) {
device.device_id = 0;
device.device_type = static_cast<DLDeviceType>(kDLMetal);
MetalWorkspace* global_ws = MetalWorkspace::Global();
// by default, set the stream to nullptr, which indicate
// that we are using default stream
this->stream.resize(global_ws->devices.size(), nullptr);
this->staging_pools_.resize(global_ws->devices.size());
}
~MetalThreadEntry();
// Get temp buffer with at least size under dev.
// Get temp buffer with at least size under dev (for GPU→CPU readback).
id<MTLBuffer> GetTempBuffer(Device dev, size_t size);
// Get a staging buffer for inlined CPU→GPU copy (from pool).
id<MTLBuffer> GetOrCreateStagingBuffer(Device dev, size_t size);
// Check if the staging pool has reached its limit and needs a flush.
bool StagingPoolNeedsFlush(Device dev);
// Reset the staging pool index after a flush.
void ResetStagingPool(Device dev);
// get the global workspace
static MetalThreadEntry* ThreadLocal();
};
Expand Down
Loading
Loading