diff --git a/README.md b/README.md index 77badc778..5df02f282 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ The Databricks SDK for Go includes functionality to accelerate development with - [Paginated responses](#paginated-responses) - [GetByName utility methods](#getbyname-utility-methods) - [Node type and Databricks Runtime selectors](#node-type-and-databricks-runtime-selectors) -- [io.Reader integration for DBFS](#ioreader-integration-for-dbfs) +- [Integration with `io` interfaces for DBFS](#integration-with-io-interfaces-for-dbfs) - [Logging](#logging) - [Interface stability](#interface-stability) @@ -477,17 +477,44 @@ runningCluster, err := w.Clusters.CreateAndWait(ctx, clusters.CreateCluster{ }) ``` -## `io.Reader` integration for DBFS +## Integration with `io` interfaces for DBFS -Use the higher-level `w.Dbfs.Open` and `w.Dbfs.Overwrite` methods to work with remote files through the `io.Reader` interface. Internally, these methods wrap the low-level intricacies of working with Databricks REST APIs, providing a convenient interface to you as a developer. +You can open a file on DBFS for reading or writing with `w.Dbfs.Open`. +This function returns a `dbfs.Handle` that is compatible with a subset of `io` +interfaces for reading, writing, and closing. 
+ +Uploading a file from an `io.Reader`: ```go upload, _ := os.Open("/path/to/local/file.ext") -_ = w.Dbfs.Overwrite(ctx, "/path/to/remote/file", upload) +remote, _ := w.Dbfs.Open(ctx, "/path/to/remote/file", dbfs.FileModeWrite|dbfs.FileModeOverwrite) +_, _ = io.Copy(remote, upload) +_ = remote.Close() +``` +Downloading a file to an `io.Writer`: + +```go download, _ := os.Create("/path/to/local") -remote, _ := w.Dbfs.Open(ctx, "/path/to/remote") -_ = io.Copy(download, remote) +remote, _ := w.Dbfs.Open(ctx, "/path/to/remote/file", dbfs.FileModeRead) +_, _ = io.Copy(download, remote) +``` + +### Reading into and writing from buffers + +You can read from or write to a DBFS file directly from a byte slice through +the convenience functions `w.Dbfs.ReadFile` and `w.Dbfs.WriteFile`. + +Uploading a file from a byte slice: + +```go +err := w.Dbfs.WriteFile(ctx, "/path/to/remote/file", []byte("Hello world!")) +``` + +Downloading a file into a byte slice: + +```go +buf, err := w.Dbfs.ReadFile(ctx, "/path/to/remote/file") ``` ## `pflag.Value` for enums diff --git a/internal/dbfs_test.go b/internal/dbfs_test.go index e965d2ed0..7edb69e5a 100644 --- a/internal/dbfs_test.go +++ b/internal/dbfs_test.go @@ -16,7 +16,15 @@ import ( "github.com/stretchr/testify/require" ) -func TestAccDbfsUtilities(t *testing.T) { +type hashable []byte + +func (buf hashable) Hash() uint32 { + h := fnv.New32a() + h.Write(buf) + return h.Sum32() +} + +func TestAccDbfsOpen(t *testing.T) { ctx, w := workspaceTest(t) if w.Config.IsGcp() { t.Skip("dbfs not available on gcp") @@ -26,47 +34,108 @@ func TestAccDbfsUtilities(t *testing.T) { rand.Seed(time.Now().UnixNano()) in := make([]byte, 1.44*1e6) _, _ = rand.Read(in) - h := fnv.New32a() - h.Write(in) - inHash := h.Sum32() - - err := w.Dbfs.Overwrite(ctx, path, bytes.NewReader(in)) - require.NoError(t, err) defer w.Dbfs.Delete(ctx, dbfs.Delete{ Path: path, }) - // Download directly [io.Reader] and let [io.ReadAll] determine buffer size. 
+ // Upload through [io.Writer]. + { + handle, err := w.Dbfs.Open(ctx, path, dbfs.FileModeWrite) + require.NoError(t, err) + n, err := handle.Write(in) + require.NoError(t, err) + assert.Equal(t, len(in), int(n)) + require.NoError(t, handle.Close()) + + // Verify contents hash. + out, err := w.Dbfs.ReadFile(ctx, path) + require.NoError(t, err) + assert.Equal(t, hashable(in).Hash(), hashable(out).Hash()) + } + + // Upload through [io.Writer] should fail because the file exists. + { + _, err := w.Dbfs.Open(ctx, path, dbfs.FileModeWrite) + require.ErrorContains(t, err, "dbfs open: A file or directory already exists at the input path") + } + + // Upload through [io.ReaderFrom] with overwrite bit set. + { + handle, err := w.Dbfs.Open(ctx, path, dbfs.FileModeWrite|dbfs.FileModeOverwrite) + require.NoError(t, err) + n, err := handle.ReadFrom(bytes.NewReader(in)) + require.NoError(t, err) + assert.Equal(t, len(in), int(n)) + require.NoError(t, handle.Close()) + + // Verify contents hash. + out, err := w.Dbfs.ReadFile(ctx, path) + require.NoError(t, err) + assert.Equal(t, hashable(in).Hash(), hashable(out).Hash()) + } + + // Download through [io.Reader] and let [io.ReadAll] determine buffer size. { - dbfsReader, err := w.Dbfs.Open(ctx, path) + handle, err := w.Dbfs.Open(ctx, path, dbfs.FileModeRead) require.NoError(t, err) // Note: [io.ReadAll] always calls into the [io.Reader] interface. - out, err := io.ReadAll(dbfsReader) + out, err := io.ReadAll(handle) require.NoError(t, err) // Verify contents hash. - h := fnv.New32a() - h.Write(out) - require.Equal(t, inHash, h.Sum32()) + assert.Equal(t, hashable(in).Hash(), hashable(out).Hash()) } - // Download through [io.WriterTo] with maximum buffer size. + // Download through [io.WriterTo]. { - dbfsReader, err := w.Dbfs.Open(ctx, path) + handle, err := w.Dbfs.Open(ctx, path, dbfs.FileModeRead) require.NoError(t, err) - // Note: [io.Copy] leverages the [io.WriterTo] interface if available. 
var buf bytes.Buffer - _, err = io.Copy(&buf, dbfsReader) + _, err = handle.WriteTo(&buf) require.NoError(t, err) // Verify contents hash. - h := fnv.New32a() - h.Write(buf.Bytes()) - require.Equal(t, inHash, h.Sum32()) + assert.Equal(t, hashable(in).Hash(), hashable(buf.Bytes()).Hash()) + } +} + +func TestAccDbfsReadFileWriteFile(t *testing.T) { + ctx, w := workspaceTest(t) + if w.Config.IsGcp() { + t.Skip("dbfs not available on gcp") } + + path := RandomName("/tmp/.sdk/fake") + rand.Seed(time.Now().UnixNano()) + in := make([]byte, 1.44*1e6) + _, _ = rand.Read(in) + + defer w.Dbfs.Delete(ctx, dbfs.Delete{ + Path: path, + }) + + // Write file to DBFS. + err := w.Dbfs.WriteFile(ctx, path, in) + require.NoError(t, err) + + // Verify contents hash. + out, err := w.Dbfs.ReadFile(ctx, path) + require.NoError(t, err) + assert.Equal(t, hashable(in).Hash(), hashable(out).Hash()) + + hello := []byte("Hello world!") + + // Writing to the same path should truncate the existing file. + err = w.Dbfs.WriteFile(ctx, path, hello) + require.NoError(t, err) + + // Verify contents hash. + out, err = w.Dbfs.ReadFile(ctx, path) + require.NoError(t, err) + assert.Equal(t, hashable(hello).Hash(), hashable(out).Hash()) } func TestAccListDbfsIntegration(t *testing.T) { diff --git a/service/dbfs/doc.go b/service/dbfs/doc.go index a8dc9a3f8..0779ace43 100644 --- a/service/dbfs/doc.go +++ b/service/dbfs/doc.go @@ -1,33 +1,49 @@ -// Databricks FileSystem (DBFS) API +// Databricks File System (DBFS) API // -// We strongly recommend using clients created via -// [github.com/databricks/databricks-sdk-go/workspaces.New] to simplify -// configuration experience. +// We recommend using a client created via [databricks.NewWorkspaceClient] +// to simplify the configuration experience. // -// Please use the high-level [DbfsAPI.Open] and [DbfsAPI.Overwrite] methods -// to work with remote files through Go's [io] interfaces. 
The return value -// of [DbfsAPI.Open] implements the [io.Reader] and [io.WriterTo] interfaces. -// The [io.WriterTo] interface is used by [io.Copy] and maximizes throughput by -// reading data with the DBFS maximum read chunk size of 1MB. +// # Reading and writing files // -// Internally, these methods wrap the low level [DbfsAPI.Create], -// [DbfsAPI.Close], [DbfsAPI.Read], and [DbfsAPI.AddBlock] methods: +// You can open a file on DBFS for reading or writing with [DbfsAPI.Open]. +// This function returns a [Handle] that is compatible with a subset of [io] +// interfaces for reading, writing, and closing. +// +// Uploading a file from an [io.Reader]: // // upload, _ := os.Open("/path/to/local/file.ext") -// _ = w.Dbfs.Overwrite(ctx, "/path/to/remote/file", upload) +// remote, _ := w.Dbfs.Open(ctx, "/path/to/remote/file", dbfs.FileModeWrite|dbfs.FileModeOverwrite) +// io.Copy(remote, upload) +// remote.Close() +// +// Downloading a file to an [io.Writer]: // // download, _ := os.Create("/path/to/local") -// remote, _ := w.Dbfs.Open(ctx, "/path/to/remote") +// remote, _ := w.Dbfs.Open(ctx, "/path/to/remote/file", dbfs.FileModeRead) // _ = io.Copy(download, remote) // -// Moving files: +// # Reading and writing files from buffers +// +// You can read from or write to a DBFS file directly from a byte slice through +// the convenience functions [DbfsAPI.ReadFile] and [DbfsAPI.WriteFile]. 
+// +// Uploading a file from a byte slice: +// +// buf := []byte("Hello world!") +// _ = w.Dbfs.WriteFile(ctx, "/path/to/remote/file", buf) +// +// Downloading a file into a byte slice: +// +// buf, err := w.Dbfs.ReadFile(ctx, "/path/to/remote/file") +// +// # Moving files // // err := w.Dbfs.Move(ctx, dbfs.Move{ // SourcePath: "/remote/src/path", // DestinationPath: "/remote/dst/path", // }) // -// Creating directories: +// # Creating directories // // w.Dbfs.MkdirsByPath(ctx, "/remote/dir/path") package dbfs diff --git a/service/dbfs/utilities.go b/service/dbfs/utilities.go index 5b7a9452a..566f5240b 100644 --- a/service/dbfs/utilities.go +++ b/service/dbfs/utilities.go @@ -1,7 +1,7 @@ package dbfs import ( - "bytes" + "bufio" "context" "encoding/base64" "fmt" @@ -11,116 +11,290 @@ import ( "github.com/databricks/databricks-sdk-go/useragent" ) -var b64 = base64.StdEncoding +// FileMode conveys user intent when opening a file. +type FileMode int -// Overwrite is like Put, but more friendly -func (a *DbfsAPI) Overwrite(ctx context.Context, path string, r io.Reader) (err error) { - ctx = useragent.InContext(ctx, "sdk-feature", "dbfs-overwrite") - handle, err := a.Create(ctx, Create{ - Path: path, - Overwrite: true, - }) +const ( + // Exactly one of FileModeRead or FileModeWrite must be specified. + FileModeRead FileMode = 1 << iota + FileModeWrite + FileModeOverwrite +) + +// Maximum read or write length for the DBFS API. +const maxDbfsBlockSize = 1024 * 1024 + +// Internal only state for a read handle. +type fileHandleReader struct { + size int64 + offset int64 +} + +func (f *fileHandleReader) errorf(format string, a ...any) error { + return fmt.Errorf("dbfs read: "+format, a...) +} + +func (f *fileHandleReader) error(err error) error { + if err == nil { + return nil + } + return f.errorf("%w", err) +} + +// Internal only state for a write handle. 
+type fileHandleWriter struct { + handle int64 +} + +func (f *fileHandleWriter) errorf(format string, a ...any) error { + return fmt.Errorf("dbfs write: "+format, a...) +} + +func (f *fileHandleWriter) error(err error) error { + if err == nil { + return nil + } + return f.errorf("%w", err) +} + +// Internal only state for a DBFS file handle. +type fileHandle struct { + ctx context.Context + api *DbfsAPI + path string + + reader *fileHandleReader + writer *fileHandleWriter +} + +func (h *fileHandle) checkRead() (*fileHandleReader, error) { + if h.reader != nil { + return h.reader, nil + } + return nil, fmt.Errorf("dbfs: file not open for reading") +} + +func (h *fileHandle) checkWrite() (*fileHandleWriter, error) { + if h.writer != nil { + return h.writer, nil + } + return nil, fmt.Errorf("dbfs: file not open for writing") +} + +// Handle defines the interface of the object returned by [DbfsAPI.Open]. +type Handle interface { + io.ReadWriteCloser + io.WriterTo + io.ReaderFrom +} + +// Implement the [io.Reader] interface. +func (h *fileHandle) Read(p []byte) (int, error) { + r, err := h.checkRead() if err != nil { - return fmt.Errorf("create: %w", err) + return 0, err } - defer func() { - cerr := a.CloseByHandle(ctx, handle.Handle) - if cerr != nil { - err = fmt.Errorf("close: %w", cerr) + + var ntotal int + for ntotal < len(p) { + if r.offset >= r.size { + return ntotal, io.EOF + } + + chunk := p[ntotal:] + if len(chunk) > maxDbfsBlockSize { + chunk = chunk[:maxDbfsBlockSize] + } + + res, err := h.api.Read(h.ctx, Read{ + Path: h.path, + Length: len(chunk), + Offset: int(r.offset), // TODO: make int32/int64 work properly + }) + if err != nil { + return ntotal, r.error(err) } - }() - buffer := make([]byte, 1e6) - for { - n, err := r.Read(buffer) - if err == io.EOF { - break + + // The guard against offset >= size happens above, so this can only happen + // if the file is modified or truncated while reading. 
If this happens, + // the read contents will likely be corrupted, so we return an error. + if res.BytesRead == 0 { + return ntotal, r.errorf("unexpected EOF at offset %d (size %d)", r.offset, r.size) } + + nread, err := base64.StdEncoding.Decode(chunk, []byte(res.Data)) if err != nil { - return fmt.Errorf("read: %w", err) + return ntotal, r.error(err) + } + + ntotal += nread + r.offset += int64(nread) + } + + return ntotal, nil +} + +// Implement the [io.Writer] interface. +func (h *fileHandle) Write(p []byte) (int, error) { + w, err := h.checkWrite() + if err != nil { + return 0, err + } + + var ntotal int + for ntotal < len(p) { + chunk := p[ntotal:] + if len(chunk) > maxDbfsBlockSize { + chunk = chunk[:maxDbfsBlockSize] } - err = a.AddBlock(ctx, AddBlock{ - Data: b64.EncodeToString(buffer[0:n]), - Handle: handle.Handle, + + err := h.api.AddBlock(h.ctx, AddBlock{ + Data: base64.StdEncoding.EncodeToString(chunk), + Handle: w.handle, }) if err != nil { - return fmt.Errorf("add block: %w", err) + return ntotal, w.error(err) } + + ntotal += len(chunk) } - return err + + return ntotal, nil } -type FileReader struct { - Size int64 - ctx context.Context - api *DbfsAPI - path string - offset int64 +// Implement the [io.Closer] interface. +func (h *fileHandle) Close() error { + w, err := h.checkWrite() + if err != nil { + return err + } + + return w.error(h.api.CloseByHandle(h.ctx, w.handle)) +} + +// Implement the [io.WriterTo] interface. +func (h *fileHandle) WriteTo(w io.Writer) (int64, error) { + _, err := h.checkRead() + if err != nil { + return 0, err + } + + // Limit types to io.Reader and io.Writer to avoid recursion + // into WriteTo or ReadFrom functions on underlying types. + ior := struct{ io.Reader }{h} + iow := struct{ io.Writer }{w} + return bufio.NewReaderSize(ior, maxDbfsBlockSize).WriteTo(iow) +} + +// Implement the [io.ReaderFrom] interface. 
+func (h *fileHandle) ReadFrom(r io.Reader) (int64, error) { + _, err := h.checkWrite() + if err != nil { + return 0, err + } + + // Limit types to io.Reader and io.Writer to avoid recursion + // into WriteTo or ReadFrom functions on underlying types. + ior := struct{ io.Reader }{r} + iow := struct{ io.Writer }{h} + bw := bufio.NewWriterSize(iow, maxDbfsBlockSize) + n, err := bw.ReadFrom(ior) + if err != nil { + return n, err + } + return n, bw.Flush() } -func (r *FileReader) Read(p []byte) (n int, err error) { - if r.api == nil { - panic("invalid call") +func (h *fileHandle) openForRead(mode FileMode) error { + res, err := h.api.GetStatusByPath(h.ctx, h.path) + if err != nil { + return err } - if r.offset >= r.Size { - return 0, io.EOF + h.reader = &fileHandleReader{ + size: res.FileSize, } - resp, err := r.api.Read(r.ctx, Read{ - Path: r.path, - Length: len(p), - Offset: int(r.offset), // TODO: make int32/in64 work properly + return nil +} + +func (h *fileHandle) openForWrite(mode FileMode) error { + res, err := h.api.Create(h.ctx, Create{ + Path: h.path, + Overwrite: (mode & FileModeOverwrite) != 0, }) if err != nil { - return 0, fmt.Errorf("dbfs read: %w", err) + return err + } + h.writer = &fileHandleWriter{ + handle: res.Handle, } - // The guard against offset >= size happens above, so this can only happen - // if the file is modified or truncated while reading. If this happens, - // the read contents will likely be corrupted, so we return an error. - if resp.BytesRead == 0 { - return 0, fmt.Errorf("dbfs read: unexpected EOF at offset %d (size %d)", r.offset, r.Size) + return nil +} + +// Open opens a remote DBFS file for reading or writing. +// The returned object implements relevant [io] interfaces for convenient +// integration with other code that reads or writes bytes. +// +// The [io.WriterTo] interface is provided and maximizes throughput for +// bulk reads by reading data with the DBFS maximum read chunk size of 1MB. 
+// Similarly, the [io.ReaderFrom] interface is provided for bulk writing. +// +// A file opened for writing must always be closed. +func (a *DbfsAPI) Open(ctx context.Context, path string, mode FileMode) (Handle, error) { + h := &fileHandle{ + ctx: useragent.InContext(ctx, "sdk-feature", "dbfs-handle"), + api: a, + path: path, + } + + isRead := (mode & FileModeRead) != 0 + isWrite := (mode & FileModeWrite) != 0 + if (isRead && isWrite) || (!isRead && !isWrite) { + return nil, fmt.Errorf("dbfs open: must specify dbfs.FileModeRead or dbfs.FileModeWrite") } - r.offset += resp.BytesRead - return b64.Decode(p, []byte(resp.Data)) + + var err error + if isRead { + err = h.openForRead(mode) + } + if isWrite { + err = h.openForWrite(mode) + } + if err != nil { + return nil, fmt.Errorf("dbfs open: %w", err) + } + + return h, nil } -// Maximum read length for the DBFS read API (see [DbfsApi.Read]). -const maxDbfsReadSize = 1024 * 1024 +// ReadFile is identical to [os.ReadFile] but for DBFS. +func (a *DbfsAPI) ReadFile(ctx context.Context, name string) ([]byte, error) { + h, err := a.Open(ctx, name, FileModeRead) + if err != nil { + return nil, err + } -// WriteTo makes [FileReader] implement the [io.WriterTo] interface. -// This can be used with [io.Copy] to maximize throughput, as -// it uses the maximum buffer size allowed by the DBFS API. -func (r *FileReader) WriteTo(w io.Writer) (n int64, err error) { - buf := make([]byte, maxDbfsReadSize) - nwritten := int64(0) - for { - n, err := r.Read(buf) - if err != nil { - // EOF on read means we're done. - // For writers being done means returning a nil error. 
- if err == io.EOF { - err = nil - } - return nwritten, err - } - n64, err := io.Copy(w, bytes.NewReader(buf[:n])) - nwritten += n64 - if err != nil { - return nwritten, err - } + h_ := h.(*fileHandle) + buf := make([]byte, h_.reader.size) + _, err = h.Read(buf) + if err != nil && err != io.EOF { + return nil, err } + return buf, nil } -func (a *DbfsAPI) Open(ctx context.Context, path string) (*FileReader, error) { - ctx = useragent.InContext(ctx, "sdk-feature", "dbfs-open") - info, err := a.GetStatusByPath(ctx, path) +// WriteFile is identical to [os.WriteFile] but for DBFS. +func (a *DbfsAPI) WriteFile(ctx context.Context, name string, data []byte) error { + h, err := a.Open(ctx, name, FileModeWrite|FileModeOverwrite) if err != nil { - return nil, fmt.Errorf("get status: %w", err) + return err } - return &FileReader{ - Size: info.FileSize, - path: path, - ctx: ctx, - api: a, - }, nil + + _, err = h.Write(data) + cerr := h.Close() + if err == nil && cerr != nil { + err = cerr + } + return err } // RecursiveList traverses the DBFS tree and returns all non-directory