Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions core/cli/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,9 @@ func (s *backendSupervisor) stopBackend(backend string) {

// Network I/O outside the lock
client := grpc.NewClientWithToken(bp.addr, false, nil, false, s.cmd.RegistrationToken)
if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
xlog.Debug("Calling Free() before stopping backend", "backend", backend)
if err := freeFunc.Free(context.Background()); err != nil {
xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err)
}
xlog.Debug("Calling Free() before stopping backend", "backend", backend)
if err := client.Free(context.Background()); err != nil {
xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err)
}

xlog.Info("Stopping backend process", "backend", backend, "addr", bp.addr)
Expand Down Expand Up @@ -774,10 +772,8 @@ func (s *backendSupervisor) subscribeLifecycleEvents() {
if targetAddr != "" {
// Best-effort gRPC Free()
client := grpc.NewClientWithToken(targetAddr, false, nil, false, s.cmd.RegistrationToken)
if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok {
if err := freeFunc.Free(context.Background()); err != nil {
xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr)
}
if err := client.Free(context.Background()); err != nil {
xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr)
}
}

Expand Down
3 changes: 3 additions & 0 deletions core/services/nodes/health_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ func (c *fakeBackendClient) QuantizationProgress(_ context.Context, _ *pb.Quanti
func (c *fakeBackendClient) StopQuantization(_ context.Context, _ *pb.QuantizationStopRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
return nil, nil
}
func (c *fakeBackendClient) Free(_ context.Context) error {
return nil
}

// --- fakeBackendClientFactory ---

Expand Down
4 changes: 4 additions & 0 deletions core/services/nodes/inflight_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ func (f *fakeGRPCBackend) StopQuantization(_ context.Context, _ *pb.Quantization
return &pb.Result{}, nil
}

func (f *fakeGRPCBackend) Free(_ context.Context) error {
return nil
}

// --- Tests ---

var _ = Describe("InFlightTrackingClient", func() {
Expand Down
3 changes: 3 additions & 0 deletions pkg/grpc/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,7 @@ type Backend interface {
StartQuantization(ctx context.Context, in *pb.QuantizationRequest, opts ...grpc.CallOption) (*pb.QuantizationJobResult, error)
QuantizationProgress(ctx context.Context, in *pb.QuantizationProgressRequest, f func(update *pb.QuantizationProgressUpdate), opts ...grpc.CallOption) error
StopQuantization(ctx context.Context, in *pb.QuantizationStopRequest, opts ...grpc.CallOption) (*pb.Result, error)

// Free releases GPU/model resources (e.g. VRAM) without stopping the process.
Free(ctx context.Context) error
}
5 changes: 5 additions & 0 deletions pkg/grpc/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ func (e *embedBackend) StopQuantization(ctx context.Context, in *pb.Quantization
return e.s.StopQuantization(ctx, in)
}

func (e *embedBackend) Free(ctx context.Context) error {
_, err := e.s.Free(ctx, &pb.HealthMessage{})
return err
}

var _ pb.Backend_FineTuneProgressServer = new(embedBackendFineTuneProgressStream)

type embedBackendFineTuneProgressStream struct {
Expand Down
9 changes: 4 additions & 5 deletions pkg/model/process.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package model

import (
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -52,11 +53,9 @@ func (ml *ModelLoader) deleteProcess(s string) error {
}

// Free GPU resources before stopping the process to ensure VRAM is released
if freeFunc, ok := model.GRPC(false, ml.wd).(interface{ Free() error }); ok {
xlog.Debug("Calling Free() to release GPU resources", "model", s)
if err := freeFunc.Free(); err != nil {
xlog.Warn("Error freeing GPU resources", "error", err, "model", s)
}
xlog.Debug("Calling Free() to release GPU resources", "model", s)
if err := model.GRPC(false, ml.wd).Free(context.Background()); err != nil {
xlog.Warn("Error freeing GPU resources", "error", err, "model", s)
}

process := model.Process()
Expand Down
Loading