Skip to content

Commit 68061aa

Browse files
authored
Parallel merge sort (#6162) (#6308)
* Parallel merge sort (#6162) * Fix test
1 parent a07d6eb commit 68061aa

3 files changed

Lines changed: 42 additions & 31 deletions

File tree

datafusion/core/src/physical_plan/common.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use super::SendableRecordBatchStream;
2121
use crate::error::{DataFusionError, Result};
2222
use crate::execution::context::TaskContext;
2323
use crate::execution::memory_pool::MemoryReservation;
24+
use crate::physical_plan::stream::RecordBatchReceiverStream;
2425
use crate::physical_plan::{displayable, ColumnStatistics, ExecutionPlan, Statistics};
2526
use arrow::datatypes::Schema;
2627
use arrow::ipc::writer::{FileWriter, IpcWriteOptions};
@@ -131,6 +132,31 @@ pub(crate) fn spawn_execution(
131132
})
132133
}
133134

135+
/// If running in a tokio context, spawns the execution of `input` on a separate task,
136+
/// allowing it to execute in parallel with an intermediate buffer of size `buffer`.
137+
pub(crate) fn spawn_buffered(
138+
mut input: SendableRecordBatchStream,
139+
buffer: usize,
140+
) -> SendableRecordBatchStream {
141+
// Use tokio only if running from a tokio context (#2201)
142+
let handle = match tokio::runtime::Handle::try_current() {
143+
Ok(handle) => handle,
144+
Err(_) => return input,
145+
};
146+
147+
let schema = input.schema();
148+
let (sender, receiver) = mpsc::channel(buffer);
149+
let join = handle.spawn(async move {
150+
while let Some(item) = input.next().await {
151+
if sender.send(item).await.is_err() {
152+
return;
153+
}
154+
}
155+
});
156+
157+
RecordBatchReceiverStream::create(&schema, receiver, join)
158+
}
159+
134160
/// Computes the statistics for an in-memory RecordBatch
135161
///
136162
/// Only computes statistics that are in Arrow's metadata (num rows, byte size and nulls)

datafusion/core/src/physical_plan/sorts/sort.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::execution::memory_pool::{
2525
human_readable_size, MemoryConsumer, MemoryReservation,
2626
};
2727
use crate::execution::runtime_env::RuntimeEnv;
28-
use crate::physical_plan::common::{batch_byte_size, IPCWriter};
28+
use crate::physical_plan::common::{batch_byte_size, spawn_buffered, IPCWriter};
2929
use crate::physical_plan::expressions::PhysicalSortExpr;
3030
use crate::physical_plan::metrics::{
3131
BaselineMetrics, CompositeMetricsSet, MemTrackingMetrics, MetricsSet,
@@ -284,11 +284,13 @@ impl ExternalSorter {
284284
self.partition_id,
285285
&self.runtime.memory_pool,
286286
);
287-
sort_batch_stream(batch, self.expr.clone(), self.fetch, metrics)
287+
Ok(spawn_buffered(
288+
sort_batch_stream(batch, self.expr.clone(), self.fetch, metrics)?,
289+
1,
290+
))
288291
})
289292
.collect::<Result<_>>()?;
290293

291-
// TODO: Run batch sorts concurrently (#6162)
292294
// TODO: Pushdown fetch to streaming merge (#6000)
293295

294296
streaming_merge(

datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,17 @@ use std::sync::Arc;
2222

2323
use arrow::datatypes::SchemaRef;
2424
use log::{debug, trace};
25-
use tokio::sync::mpsc;
2625

2726
use crate::error::{DataFusionError, Result};
2827
use crate::execution::context::TaskContext;
28+
use crate::physical_plan::common::spawn_buffered;
2929
use crate::physical_plan::metrics::{
3030
ExecutionPlanMetricsSet, MemTrackingMetrics, MetricsSet,
3131
};
3232
use crate::physical_plan::sorts::streaming_merge;
33-
use crate::physical_plan::stream::RecordBatchReceiverStream;
3433
use crate::physical_plan::{
35-
common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType,
36-
Distribution, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
34+
expressions::PhysicalSortExpr, DisplayFormatType, Distribution, ExecutionPlan,
35+
Partitioning, SendableRecordBatchStream, Statistics,
3736
};
3837
use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement};
3938

@@ -181,29 +180,12 @@ impl ExecutionPlan for SortPreservingMergeExec {
181180
result
182181
}
183182
_ => {
184-
// Use tokio only if running from a tokio context (#2201)
185-
let receivers = match tokio::runtime::Handle::try_current() {
186-
Ok(_) => (0..input_partitions)
187-
.map(|part_i| {
188-
let (sender, receiver) = mpsc::channel(1);
189-
let join_handle = spawn_execution(
190-
self.input.clone(),
191-
sender,
192-
part_i,
193-
context.clone(),
194-
);
195-
196-
RecordBatchReceiverStream::create(
197-
&schema,
198-
receiver,
199-
join_handle,
200-
)
201-
})
202-
.collect(),
203-
Err(_) => (0..input_partitions)
204-
.map(|partition| self.input.execute(partition, context.clone()))
205-
.collect::<Result<_>>()?,
206-
};
183+
let receivers = (0..input_partitions)
184+
.map(|partition| {
185+
let stream = self.input.execute(partition, context.clone())?;
186+
Ok(spawn_buffered(stream, 1))
187+
})
188+
.collect::<Result<_>>()?;
207189

208190
debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute");
209191

@@ -262,6 +244,7 @@ mod tests {
262244
use crate::physical_plan::memory::MemoryExec;
263245
use crate::physical_plan::metrics::MetricValue;
264246
use crate::physical_plan::sorts::sort::SortExec;
247+
use crate::physical_plan::stream::RecordBatchReceiverStream;
265248
use crate::physical_plan::{collect, common};
266249
use crate::prelude::{SessionConfig, SessionContext};
267250
use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
@@ -812,7 +795,7 @@ mod tests {
812795
let mut streams = Vec::with_capacity(partition_count);
813796

814797
for partition in 0..partition_count {
815-
let (sender, receiver) = mpsc::channel(1);
798+
let (sender, receiver) = tokio::sync::mpsc::channel(1);
816799
let mut stream = batches.execute(partition, task_ctx.clone()).unwrap();
817800
let join_handle = tokio::spawn(async move {
818801
while let Some(batch) = stream.next().await {

0 commit comments

Comments
 (0)