Skip to content

Commit 16d628f

Browse files
committed
More Arrow IPC scaffolding
Change-Id: I9eaf54a1a058a18f17251816ec22e5e4e3a260da
1 parent 591aceb commit 16d628f

3 files changed

Lines changed: 156 additions & 0 deletions

File tree

cpp/src/arrow/gpu/cuda-test.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include "gtest/gtest.h"
2323

2424
#include "arrow/status.h"
25+
#include "arrow/ipc/test-common.h"
26+
#include "arrow/ipc/api.h"
2527
#include "arrow/test-util.h"
2628

2729
#include "arrow/gpu/cuda_api.h"
@@ -262,5 +264,41 @@ TEST_F(TestCudaBufferReader, Basics) {
262264
ASSERT_EQ(0, std::memcmp(stack_buffer, host_data + 925, 75));
263265
}
264266

267+
class TestCudaArrowIpc : public TestCudaBufferBase {
268+
public:
269+
void SetUp() {
270+
TestCudaBufferBase::SetUp();
271+
pool_ = default_memory_pool();
272+
}
273+
274+
protected:
275+
MemoryPool* pool_;
276+
};
277+
278+
TEST_F(TestCudaArrowIpc, BasicWriteRead) {
279+
std::shared_ptr<RecordBatch> batch;
280+
ASSERT_OK(ipc::MakeIntRecordBatch(&batch));
281+
282+
std::shared_ptr<CudaBuffer> device_serialized;
283+
ASSERT_OK(arrow::gpu::SerializeRecordBatch(*batch, context_.get(),
284+
&device_serialized));
285+
286+
// Test that ReadRecordBatch works properly
287+
std::shared_ptr<RecordBatch> device_batch;
288+
ASSERT_OK(ReadRecordBatch(batch->schema(), device_serialized, &device_batch));
289+
290+
// Copy data from device, read batch, and compare
291+
std::shared_ptr<MutableBuffer> host_buffer;
292+
int64_t size = device_serialized->size();
293+
ASSERT_OK(AllocateBuffer(pool_, size, &host_buffer));
294+
ASSERT_OK(device_serialized->CopyToHost(0, size, host_buffer->mutable_data()));
295+
296+
std::shared_ptr<RecordBatch> cpu_batch;
297+
io::BufferReader cpu_reader(host_buffer);
298+
ASSERT_OK(ipc::ReadRecordBatch(batch->schema(), &cpu_reader, &cpu_batch));
299+
300+
ipc::CompareBatch(*batch, *cpu_batch);
301+
}
302+
265303
} // namespace gpu
266304
} // namespace arrow
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include <cstdint>
19+
#include <memory>
20+
21+
#include "arrow/buffer.h"
22+
#include "arrow/ipc/message.h"
23+
#include "arrow/ipc/writer.h"
24+
#include "arrow/status.h"
25+
#include "arrow/table.h"
26+
#include "arrow/util/visibility.h"
27+
28+
#include "arrow/gpu/cuda_context.h"
29+
#include "arrow/gpu/cuda_memory.h"
30+
31+
namespace arrow {
32+
namespace gpu {
33+
34+
Status SerializeRecordBatch(const RecordBatch& batch, CudaContext* ctx,
35+
std::shared_ptr<CudaBuffer>* out) {
36+
int64_t size = 0;
37+
RETURN_NOT_OK(ipc::GetRecordBatchSize(batch, &size));
38+
39+
std::shared_ptr<CudaBuffer> buffer;
40+
RETURN_NOT_OK(ctx->Allocate(size, &buffer));
41+
42+
CudaBufferWriter stream(buffer);
43+
44+
// Use 8MB buffering, which yields generally good performance
45+
RETURN_NOT_OK(stream.SetBufferSize(1 << 23));
46+
47+
// We use the default memory pool here since any allocations are ephemeral
48+
RETURN_NOT_OK(ipc::SerializeRecordBatch(batch, default_memory_pool(),
49+
&stream));
50+
*out = buffer;
51+
return Status::OK();
52+
}
53+
54+
Status ReadMessage(CudaBufferReader* stream, MemoryPool* pool,
55+
std::unique_ptr<Message>* message) {
56+
uint8_t length_buf[4] = {0};
57+
58+
int64_t bytes_read = 0;
59+
RETURN_NOT_OK(file->Read(sizeof(int32_t), &bytes_read, length_buf));
60+
if (bytes_read != sizeof(int32_t)) {
61+
*message = nullptr;
62+
return Status::OK();
63+
}
64+
65+
const int32_t metadata_length = *reinterpret_cast<const int32_t*>(length_buf);
66+
67+
if (metadata_length == 0) {
68+
// Optional 0 EOS control message
69+
*message = nullptr;
70+
return Status::OK();
71+
}
72+
73+
std::shared_ptr<MutableBuffer> metadata;
74+
RETURN_NOT_OK(AllocateBuffer(pool, metadata_length, &metadata));
75+
RETURN_NOT_OK(file->Read(message_length, &bytes_read, metadata->mutable_data()));
76+
if (bytes_read != metadata_length) {
77+
return Status::IOError("Unexpected end of stream trying to read message");
78+
}
79+
80+
auto fb_message = flatbuf::GetMessage(metadata->data());
81+
82+
int64_t body_length = fb_message->bodyLength();
83+
84+
// Zero copy
85+
std::shared_ptr<Buffer> body;
86+
RETURN_NOT_OK(stream->Read(body_length, &body));
87+
if (body->size() < body_length) {
88+
std::stringstream ss;
89+
ss << "Expected to be able to read " << body_length << " bytes for message body, got "
90+
<< body->size();
91+
return Status::IOError(ss.str());
92+
}
93+
94+
return Message::Open(metadata, body, message);
95+
}
96+
97+
} // namespace gpu
98+
} // namespace arrow

cpp/src/arrow/gpu/cuda_arrow_ipc.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,26 @@ ARROW_EXPORT
4343
Status SerializeRecordBatch(const RecordBatch& batch, CudaContext* ctx,
4444
std::shared_ptr<CudaBuffer>* out);
4545

46+
/// \brief Read Arrow IPC message located on GPU device
47+
/// \param[in] stream a CudaBufferReader
48+
/// \param[in] pool a MemoryPool to allocate CPU memory for the metadata
49+
/// \param[out] message the deserialized message, body still on device
50+
///
51+
/// This function reads the message metadata into host memory, but leaves the
52+
/// message body on the device
53+
ARROW_EXPORT
54+
Status ReadMessage(io::CudaBufferReader* stream, MemoryPool* pool,
55+
std::unique_ptr<Message>* message);
56+
57+
/// \brief ReadRecordBatch specialized to handle metadata on CUDA device
58+
/// \param[in] schema
59+
/// \param[in] buffer
60+
/// \param[out] out the reconstructed RecordBatch, with device pointers
61+
ARROW_EXPORT
62+
Status ReadRecordBatch(const std::shared_ptr<Schema>& schema
63+
const std::shared_ptr<CudaBuffer>& buffer,
64+
std::shared_ptr<RecordBatch>* out);
65+
4666
} // namespace gpu
4767
} // namespace arrow
4868

0 commit comments

Comments
 (0)