Skip to content

Commit 33cc3f2

Browse files
committed
Fix for real
1 parent 91df46f commit 33cc3f2

2 files changed

Lines changed: 179 additions & 4 deletions

File tree

datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -354,12 +354,10 @@ fn pushdown_requirement_to_children(
354354
Ok(None)
355355
}
356356
}
357+
} else if let Some(aggregate_exec) = plan.as_any().downcast_ref::<AggregateExec>() {
358+
handle_aggregate_pushdown(aggregate_exec, parent_required)
357359
} else if maintains_input_order.is_empty()
358360
|| !maintains_input_order.iter().any(|o| *o)
359-
// Aggregate output columns can be computed expressions that are not
360-
// order-preserving wrt input columns. The generic index-based rewrite
361-
// in handle_custom_pushdown is not safe for AggregateExec.
362-
|| plan.as_any().is::<AggregateExec>()
363361
|| plan.as_any().is::<RepartitionExec>()
364362
|| plan.as_any().is::<FilterExec>()
365363
// TODO: Add support for Projection push down
@@ -393,6 +391,68 @@ fn pushdown_requirement_to_children(
393391
// TODO: Add support for Projection push down
394392
}
395393

394+
/// Handle pushdown through aggregates by mapping sort requirements on aggregate
395+
/// output columns back to the corresponding GROUP BY input expressions, if possible.
396+
///
397+
/// Currently only push down when the mapped requirement is already satisfied by
398+
/// the aggregate input. Otherwise, keep the sort above the aggregate.
///
/// Returns `Ok(Some(..))` holding a single child requirement when every parent
/// requirement is a plain `Column` whose index falls within the group-by input
/// expressions and the aggregate input already satisfies the mapped ordering.
/// Returns `Ok(None)` in every other case, including when the aggregate does not
/// maintain its input order or uses GROUPING SETS.
399+
fn handle_aggregate_pushdown(
400+
aggregate_exec: &AggregateExec,
401+
parent_required: OrderingRequirements,
402+
) -> Result<Option<Vec<Option<OrderingRequirements>>>> {
403+
// Nothing to push down if the aggregate does not report maintaining the
// order of any of its inputs.
if !aggregate_exec
404+
.maintains_input_order()
405+
.into_iter()
406+
.any(|o| o)
407+
{
408+
return Ok(None);
409+
}
410+
411+
let group_expr = aggregate_exec.group_expr();
412+
// GROUPING SETS introduce additional output columns and NULL substitutions;
413+
// skip pushdown until we can map those cases safely.
414+
if group_expr.has_grouping_set() {
415+
return Ok(None);
416+
}
417+
418+
let group_input_exprs = group_expr.input_exprs();
419+
// NOTE(review): `into_single()` presumably reduces the parent requirement to
// one alternative; confirm that ignoring any others is intended here.
let parent_requirement = parent_required.into_single();
420+
let mut child_requirement = Vec::with_capacity(parent_requirement.len());
421+
422+
for req in parent_requirement {
423+
// Sort above AggregateExec should reference its output columns. Map each
424+
// output group-by column to its original input expression.
// Anything other than a plain column reference is not mapped; give up.
425+
let Some(column) = req.expr.as_any().downcast_ref::<Column>() else {
426+
return Ok(None);
427+
};
428+
if column.index() >= group_input_exprs.len() {
429+
// Sorting by aggregate result columns cannot be pushed through.
430+
return Ok(None);
431+
}
432+
// Keep the parent's sort options, but express the requirement in terms of
// the group-by input expression at the same position.
child_requirement.push(PhysicalSortRequirement::new(
433+
Arc::clone(&group_input_exprs[column.index()]),
434+
req.options,
435+
));
436+
}
437+
438+
// `LexRequirement::new` yields `None` when no usable requirement can be
// built from the collected entries; treat that as "no pushdown".
let Some(child_requirement) = LexRequirement::new(child_requirement) else {
439+
return Ok(None);
440+
};
441+
442+
// Keep sort above aggregate unless input ordering already satisfies the
443+
// mapped requirement.
444+
if aggregate_exec
445+
.input()
446+
.equivalence_properties()
447+
.ordering_satisfy_requirement(child_requirement.iter().cloned())?
448+
{
449+
let child_requirements = OrderingRequirements::new(child_requirement);
450+
Ok(Some(vec![Some(child_requirements)]))
451+
} else {
452+
Ok(None)
453+
}
454+
}
455+
396456
/// Return true if pushing the sort requirements through a node would violate
397457
/// the input sorting requirements for the plan
398458
fn pushdown_would_violate_requirements(

datafusion/sqllogictest/test_files/sort_pushdown.slt

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,121 @@ ORDER BY x, CAST(y AS BIGINT) % 2;
914914
2 0 50
915915
2 1 100
916916

917+
# Test 3.8: Aggregate ORDER BY monotonic expression can push down (no SortExec)
918+
query TT
919+
EXPLAIN SELECT
920+
x,
921+
CAST(y AS BIGINT),
922+
SUM(v)
923+
FROM agg_expr_parquet
924+
GROUP BY x, CAST(y AS BIGINT)
925+
ORDER BY x, CAST(y AS BIGINT);
926+
----
927+
logical_plan
928+
01)Sort: agg_expr_parquet.x ASC NULLS LAST, agg_expr_parquet.y ASC NULLS LAST
929+
02)--Aggregate: groupBy=[[agg_expr_parquet.x, CAST(agg_expr_parquet.y AS Int64)]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]]
930+
03)----TableScan: agg_expr_parquet projection=[x, y, v]
931+
physical_plan
932+
01)AggregateExec: mode=Single, gby=[x@0 as x, CAST(y@1 AS Int64) as agg_expr_parquet.y], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted
933+
02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, y, v], output_ordering=[x@0 ASC NULLS LAST, y@1 ASC NULLS LAST], file_type=parquet
934+
935+
query III
936+
SELECT
937+
x,
938+
CAST(y AS BIGINT),
939+
SUM(v)
940+
FROM agg_expr_parquet
941+
GROUP BY x, CAST(y AS BIGINT)
942+
ORDER BY x, CAST(y AS BIGINT);
943+
----
944+
1 1 10
945+
1 2 20
946+
1 3 30
947+
2 1 40
948+
2 2 50
949+
2 3 60
950+
951+
# Test 3.9: Aggregate ORDER BY aggregate output should keep SortExec
952+
query TT
953+
EXPLAIN SELECT x, SUM(v)
954+
FROM agg_expr_parquet
955+
GROUP BY x
956+
ORDER BY SUM(v);
957+
----
958+
logical_plan
959+
01)Sort: sum(agg_expr_parquet.v) ASC NULLS LAST
960+
02)--Aggregate: groupBy=[[agg_expr_parquet.x]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]]
961+
03)----TableScan: agg_expr_parquet projection=[x, v]
962+
physical_plan
963+
01)SortExec: expr=[sum(agg_expr_parquet.v)@1 ASC NULLS LAST], preserve_partitioning=[false]
964+
02)--AggregateExec: mode=Single, gby=[x@0 as x], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted
965+
03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, v], output_ordering=[x@0 ASC NULLS LAST], file_type=parquet
966+
967+
query II
968+
SELECT x, SUM(v)
969+
FROM agg_expr_parquet
970+
GROUP BY x
971+
ORDER BY SUM(v);
972+
----
973+
1 60
974+
2 150
975+
976+
# Test 3.10: Aggregate with non-preserved input order should keep SortExec
977+
# v is not part of the input file's declared sort order (x, y), so the aggregate input is not sorted on v
978+
query TT
979+
EXPLAIN SELECT v, SUM(y)
980+
FROM agg_expr_parquet
981+
GROUP BY v
982+
ORDER BY v;
983+
----
984+
logical_plan
985+
01)Sort: agg_expr_parquet.v ASC NULLS LAST
986+
02)--Aggregate: groupBy=[[agg_expr_parquet.v]], aggr=[[sum(CAST(agg_expr_parquet.y AS Int64))]]
987+
03)----TableScan: agg_expr_parquet projection=[y, v]
988+
physical_plan
989+
01)SortExec: expr=[v@0 ASC NULLS LAST], preserve_partitioning=[false]
990+
02)--AggregateExec: mode=Single, gby=[v@1 as v], aggr=[sum(agg_expr_parquet.y)]
991+
03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[y, v], file_type=parquet
992+
993+
query II
994+
SELECT v, SUM(y)
995+
FROM agg_expr_parquet
996+
GROUP BY v
997+
ORDER BY v;
998+
----
999+
10 1
1000+
20 2
1001+
30 3
1002+
40 1
1003+
50 2
1004+
60 3
1005+
1006+
# Test 3.11: Aggregate ORDER BY non-column expression (unsatisfied) keeps SortExec
1007+
# (though note in theory DataFusion could figure out that data sorted by x will also be sorted by x+1)
1008+
query TT
1009+
EXPLAIN SELECT x, SUM(v)
1010+
FROM agg_expr_parquet
1011+
GROUP BY x
1012+
ORDER BY x + 1 DESC;
1013+
----
1014+
logical_plan
1015+
01)Sort: CAST(agg_expr_parquet.x AS Int64) + Int64(1) DESC NULLS FIRST
1016+
02)--Aggregate: groupBy=[[agg_expr_parquet.x]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]]
1017+
03)----TableScan: agg_expr_parquet projection=[x, v]
1018+
physical_plan
1019+
01)SortExec: expr=[CAST(x@0 AS Int64) + 1 DESC], preserve_partitioning=[false]
1020+
02)--AggregateExec: mode=Single, gby=[x@0 as x], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted
1021+
03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, v], output_ordering=[x@0 ASC NULLS LAST], file_type=parquet
1022+
1023+
query II
1024+
SELECT x, SUM(v)
1025+
FROM agg_expr_parquet
1026+
GROUP BY x
1027+
ORDER BY x + 1 DESC;
1028+
----
1029+
2 150
1030+
1 60
1031+
9171032
# Cleanup
9181033
statement ok
9191034
DROP TABLE timestamp_data;

0 commit comments

Comments
 (0)