diff --git a/core/cli/worker.go b/core/cli/worker.go index 3cb3eda9fb9a..affde4b08842 100644 --- a/core/cli/worker.go +++ b/core/cli/worker.go @@ -512,11 +512,9 @@ func (s *backendSupervisor) stopBackend(backend string) { // Network I/O outside the lock client := grpc.NewClientWithToken(bp.addr, false, nil, false, s.cmd.RegistrationToken) - if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok { - xlog.Debug("Calling Free() before stopping backend", "backend", backend) - if err := freeFunc.Free(context.Background()); err != nil { - xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err) - } + xlog.Debug("Calling Free() before stopping backend", "backend", backend) + if err := client.Free(context.Background()); err != nil { + xlog.Warn("Free() failed (best-effort)", "backend", backend, "error", err) } xlog.Info("Stopping backend process", "backend", backend, "addr", bp.addr) @@ -774,10 +772,8 @@ func (s *backendSupervisor) subscribeLifecycleEvents() { if targetAddr != "" { // Best-effort gRPC Free() client := grpc.NewClientWithToken(targetAddr, false, nil, false, s.cmd.RegistrationToken) - if freeFunc, ok := client.(interface{ Free(context.Context) error }); ok { - if err := freeFunc.Free(context.Background()); err != nil { - xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr) - } + if err := client.Free(context.Background()); err != nil { + xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr) } } diff --git a/core/services/nodes/health_mock_test.go b/core/services/nodes/health_mock_test.go index 13a84e37f575..4b49d75a327a 100644 --- a/core/services/nodes/health_mock_test.go +++ b/core/services/nodes/health_mock_test.go @@ -231,6 +231,9 @@ func (c *fakeBackendClient) QuantizationProgress(_ context.Context, _ *pb.Quanti func (c *fakeBackendClient) StopQuantization(_ context.Context, _ *pb.QuantizationStopRequest, _ ...ggrpc.CallOption) (*pb.Result, error) { return nil, nil } +func (c *fakeBackendClient) Free(_ context.Context) error { + return nil +} // --- fakeBackendClientFactory --- diff --git a/core/services/nodes/inflight_test.go b/core/services/nodes/inflight_test.go index 8c1ba068fe3a..8266b6f215b3 100644 --- a/core/services/nodes/inflight_test.go +++ b/core/services/nodes/inflight_test.go @@ -175,6 +175,10 @@ func (f *fakeGRPCBackend) StopQuantization(_ context.Context, _ *pb.Quantization return &pb.Result{}, nil } +func (f *fakeGRPCBackend) Free(_ context.Context) error { + return nil +} + // --- Tests --- var _ = Describe("InFlightTrackingClient", func() { diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index bfb55a3a75f1..8d9818186a0e 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -85,4 +85,7 @@ type Backend interface { StartQuantization(ctx context.Context, in *pb.QuantizationRequest, opts ...grpc.CallOption) (*pb.QuantizationJobResult, error) QuantizationProgress(ctx context.Context, in *pb.QuantizationProgressRequest, f func(update *pb.QuantizationProgressUpdate), opts ...grpc.CallOption) error StopQuantization(ctx context.Context, in *pb.QuantizationStopRequest, opts ...grpc.CallOption) (*pb.Result, error) + + // Free releases GPU/model resources (e.g. VRAM) without stopping the process. + Free(ctx context.Context) error } diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index f198f25aad07..c14b20427b35 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -163,6 +163,11 @@ func (e *embedBackend) StopQuantization(ctx context.Context, in *pb.Quantization return e.s.StopQuantization(ctx, in) } +func (e *embedBackend) Free(ctx context.Context) error { + _, err := e.s.Free(ctx, &pb.HealthMessage{}) + return err +} + var _ pb.Backend_FineTuneProgressServer = new(embedBackendFineTuneProgressStream) type embedBackendFineTuneProgressStream struct { diff --git a/pkg/model/process.go b/pkg/model/process.go index aa451ac1a415..21cb53728868 100644 --- a/pkg/model/process.go +++ b/pkg/model/process.go @@ -1,6 +1,7 @@ package model import ( + "context" "errors" "fmt" "os" @@ -52,11 +53,9 @@ func (ml *ModelLoader) deleteProcess(s string) error { } // Free GPU resources before stopping the process to ensure VRAM is released - if freeFunc, ok := model.GRPC(false, ml.wd).(interface{ Free() error }); ok { - xlog.Debug("Calling Free() to release GPU resources", "model", s) - if err := freeFunc.Free(); err != nil { - xlog.Warn("Error freeing GPU resources", "error", err, "model", s) - } + xlog.Debug("Calling Free() to release GPU resources", "model", s) + if err := model.GRPC(false, ml.wd).Free(context.Background()); err != nil { + xlog.Warn("Error freeing GPU resources", "error", err, "model", s) } process := model.Process()