Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 70 additions & 11 deletions datafusion/physical-plan/src/repartition/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,12 @@ impl ExecutionPlan for RepartitionExec {
if preserve_order {
// Store streams from all the input partitions:
// Each input partition gets its own spill reader to maintain proper FIFO ordering
//
// Pass None for metrics here — these intermediate streams feed into
// StreamingMerge which is the actual output. Only the merge's
// BaselineMetrics should contribute to the operator's reported
// output_rows. Without this, every row would be counted twice
// (once by PerPartitionStream, once by StreamingMerge).
let input_streams = rx
.into_iter()
.zip(spill_readers)
Expand All @@ -1049,7 +1055,7 @@ impl ExecutionPlan for RepartitionExec {
Arc::clone(&reservation),
spill_stream,
1, // Each receiver handles one input partition
BaselineMetrics::new(&metrics, partition),
None,
None, // subsequent merge sort already does batching https://github.com/apache/datafusion/blob/e4dcf0c85611ad0bd291f03a8e03fe56d773eb16/datafusion/physical-plan/src/sorts/merge.rs#L286
)) as SendableRecordBatchStream
})
Expand Down Expand Up @@ -1088,7 +1094,7 @@ impl ExecutionPlan for RepartitionExec {
reservation,
spill_stream,
num_input_partitions,
BaselineMetrics::new(&metrics, partition),
Some(BaselineMetrics::new(&metrics, partition)),
Some(context.session_config().batch_size()),
)) as SendableRecordBatchStream)
}
Expand Down Expand Up @@ -1576,8 +1582,8 @@ struct PerPartitionStream {
/// each sending None when complete. We must wait for all of them.
remaining_partitions: usize,

/// Execution metrics
baseline_metrics: BaselineMetrics,
/// Execution metrics (None in preserve-order mode where StreamingMerge owns the metrics)
baseline_metrics: Option<BaselineMetrics>,

/// None for sort preserving variant (merge sort already does coalescing)
batch_coalescer: Option<LimitedBatchCoalescer>,
Expand All @@ -1592,7 +1598,7 @@ impl PerPartitionStream {
reservation: SharedMemoryReservation,
spill_stream: SendableRecordBatchStream,
num_input_partitions: usize,
baseline_metrics: BaselineMetrics,
baseline_metrics: Option<BaselineMetrics>,
batch_size: Option<usize>,
) -> Self {
let batch_coalescer =
Expand All @@ -1615,8 +1621,11 @@ impl PerPartitionStream {
cx: &mut Context<'_>,
) -> Poll<Option<Result<RecordBatch>>> {
use futures::StreamExt;
let cloned_time = self.baseline_metrics.elapsed_compute().clone();
let _timer = cloned_time.timer();
let elapsed = self
.baseline_metrics
.as_ref()
.map(|m| m.elapsed_compute().clone());
let _timer = elapsed.as_ref().map(|t| t.timer());

loop {
match self.state {
Expand Down Expand Up @@ -1696,7 +1705,10 @@ impl PerPartitionStream {
cx: &mut Context<'_>,
coalescer: &mut LimitedBatchCoalescer,
) -> Poll<Option<Result<RecordBatch>>> {
let cloned_time = self.baseline_metrics.elapsed_compute().clone();
let cloned_time = self
.baseline_metrics
.as_ref()
.map(|m| m.elapsed_compute().clone());
let mut completed = false;

loop {
Expand All @@ -1709,7 +1721,7 @@ impl PerPartitionStream {

match ready!(self.poll_next_inner(cx)) {
Some(Ok(batch)) => {
let _timer = cloned_time.timer();
let _timer = cloned_time.as_ref().map(|t| t.timer());
if let Err(err) = coalescer.push_batch(batch) {
return Poll::Ready(Some(Err(err)));
}
Expand All @@ -1719,7 +1731,7 @@ impl PerPartitionStream {
}
None => {
completed = true;
let _timer = cloned_time.timer();
let _timer = cloned_time.as_ref().map(|t| t.timer());
if let Err(err) = coalescer.finish() {
return Poll::Ready(Some(Err(err)));
}
Expand All @@ -1743,7 +1755,11 @@ impl Stream for PerPartitionStream {
} else {
poll = self.poll_next_inner(cx);
}
self.baseline_metrics.record_poll(poll)
if let Some(metrics) = &self.baseline_metrics {
metrics.record_poll(poll)
} else {
poll
}
}
}

Expand Down Expand Up @@ -2953,4 +2969,47 @@ mod test {
let exec = Arc::new(exec);
Arc::new(TestMemoryExec::update_cache(&exec))
}

/// Regression test: with `preserve_order` enabled, the `output_rows` metric
/// reported by `RepartitionExec` must equal the number of rows actually
/// produced instead of double-counting them.
#[tokio::test]
async fn test_preserve_order_output_rows_not_double_counted() -> Result<()> {
    use datafusion_execution::TaskContext;

    // Input: two sorted partitions holding 2 rows apiece (4 rows overall).
    let first_batch = record_batch!(("c0", UInt32, [1, 3])).unwrap();
    let second_batch = record_batch!(("c0", UInt32, [2, 4])).unwrap();
    let schema = first_batch.schema();
    let ordering = sort_exprs(&schema);

    // Build the in-memory source and declare the sort order of each partition.
    let partitions = vec![vec![first_batch], vec![second_batch]];
    let source = Arc::new(
        TestMemoryExec::try_new(&partitions, Arc::clone(&schema), None)?
            .try_with_sort_information(vec![ordering.clone(), ordering])?,
    );
    let source = Arc::new(TestMemoryExec::update_cache(&source));

    // Round-robin into 3 output partitions while preserving the input order.
    let repartition =
        RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(3))?
            .with_preserve_order();

    // Drain every output partition and tally the rows that actually come out.
    let task_ctx = Arc::new(TaskContext::default());
    let partition_count = repartition.partitioning().partition_count();
    let mut total_rows = 0;
    for partition in 0..partition_count {
        let mut stream = repartition.execute(partition, Arc::clone(&task_ctx))?;
        while let Some(batch) = stream.next().await {
            total_rows += batch?.num_rows();
        }
    }

    assert_eq!(total_rows, 4, "actual rows collected should be 4");

    // The operator-level metric must agree with the observed row count.
    let reported_output_rows = repartition.metrics().unwrap().output_rows().unwrap();
    assert_eq!(
        reported_output_rows, total_rows,
        "metrics output_rows ({reported_output_rows}) should match \
         actual rows collected ({total_rows}), not double-count"
    );

    Ok(())
}
}
Loading