Skip to content

Commit abb2ae7

Browse files
NGA-TRANalamb
andauthored
Revert "Minor: remove unnecessary projection in `single_distinct_to_g… (#8176)
* Revert "Minor: remove unnecessary projection in `single_distinct_to_group_by` rule (#8061)" This reverts commit 15d8c9b. * Add regression test --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent f390f15 commit abb2ae7

4 files changed

Lines changed: 130 additions & 55 deletions

File tree

datafusion/optimizer/src/single_distinct_to_groupby.rs

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ use std::sync::Arc;
2222
use crate::optimizer::ApplyOrder;
2323
use crate::{OptimizerConfig, OptimizerRule};
2424

25-
use datafusion_common::Result;
25+
use datafusion_common::{DFSchema, Result};
2626
use datafusion_expr::{
2727
col,
2828
expr::AggregateFunction,
29-
logical_plan::{Aggregate, LogicalPlan},
30-
Expr,
29+
logical_plan::{Aggregate, LogicalPlan, Projection},
30+
utils::columnize_expr,
31+
Expr, ExprSchemable,
3132
};
3233

3334
use hashbrown::HashSet;
@@ -152,7 +153,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
152153

153154
// replace the distinct arg with alias
154155
let mut group_fields_set = HashSet::new();
155-
let outer_aggr_exprs = aggr_expr
156+
let new_aggr_exprs = aggr_expr
156157
.iter()
157158
.map(|aggr_expr| match aggr_expr {
158159
Expr::AggregateFunction(AggregateFunction {
@@ -174,24 +175,67 @@ impl OptimizerRule for SingleDistinctToGroupBy {
174175
false, // intentional to remove distinct here
175176
filter.clone(),
176177
order_by.clone(),
177-
))
178-
.alias(aggr_expr.display_name()?))
178+
)))
179179
}
180180
_ => Ok(aggr_expr.clone()),
181181
})
182182
.collect::<Result<Vec<_>>>()?;
183183

184184
// construct the inner AggrPlan
185+
let inner_fields = inner_group_exprs
186+
.iter()
187+
.map(|expr| expr.to_field(input.schema()))
188+
.collect::<Result<Vec<_>>>()?;
189+
let inner_schema = DFSchema::new_with_metadata(
190+
inner_fields,
191+
input.schema().metadata().clone(),
192+
)?;
185193
let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
186194
input.clone(),
187195
inner_group_exprs,
188196
Vec::new(),
189197
)?);
190198

191-
Ok(Some(LogicalPlan::Aggregate(Aggregate::try_new(
199+
let outer_fields = outer_group_exprs
200+
.iter()
201+
.chain(new_aggr_exprs.iter())
202+
.map(|expr| expr.to_field(&inner_schema))
203+
.collect::<Result<Vec<_>>>()?;
204+
let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata(
205+
outer_fields,
206+
input.schema().metadata().clone(),
207+
)?);
208+
209+
// so the aggregates are displayed in the same way even after the rewrite
210+
// this optimizer has two kinds of alias:
211+
// - group_by aggr
212+
// - aggr expr
213+
let group_size = group_expr.len();
214+
let alias_expr = out_group_expr_with_alias
215+
.into_iter()
216+
.map(|(group_expr, original_field)| {
217+
if let Some(name) = original_field {
218+
group_expr.alias(name)
219+
} else {
220+
group_expr
221+
}
222+
})
223+
.chain(new_aggr_exprs.iter().enumerate().map(|(idx, expr)| {
224+
let idx = idx + group_size;
225+
let name = fields[idx].qualified_name();
226+
columnize_expr(expr.clone().alias(name), &outer_aggr_schema)
227+
}))
228+
.collect();
229+
230+
let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new(
192231
Arc::new(inner_agg),
193232
outer_group_exprs,
194-
outer_aggr_exprs,
233+
new_aggr_exprs,
234+
)?);
235+
236+
Ok(Some(LogicalPlan::Projection(Projection::try_new(
237+
alias_expr,
238+
Arc::new(outer_aggr),
195239
)?)))
196240
} else {
197241
Ok(None)
@@ -255,9 +299,10 @@ mod tests {
255299
.build()?;
256300

257301
// Should work
258-
let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b)]] [COUNT(DISTINCT test.b):Int64;N]\
259-
\n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\
260-
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
302+
let expected = "Projection: COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):Int64;N]\
303+
\n Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]\
304+
\n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\
305+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
261306

262307
assert_optimized_plan_equal(&plan, expected)
263308
}
@@ -328,9 +373,10 @@ mod tests {
328373
.aggregate(Vec::<Expr>::new(), vec![count_distinct(lit(2) * col("b"))])?
329374
.build()?;
330375

331-
let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b)]] [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\
332-
\n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\
333-
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
376+
let expected = "Projection: COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\
377+
\n Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]\
378+
\n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\
379+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
334380

335381
assert_optimized_plan_equal(&plan, expected)
336382
}
@@ -344,9 +390,10 @@ mod tests {
344390
.build()?;
345391

346392
// Should work
347-
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\
348-
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
349-
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
393+
let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\
394+
\n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1)]] [a:UInt32, COUNT(alias1):Int64;N]\
395+
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
396+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
350397

351398
assert_optimized_plan_equal(&plan, expected)
352399
}
@@ -389,9 +436,10 @@ mod tests {
389436
)?
390437
.build()?;
391438
// Should work
392-
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
393-
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
394-
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
439+
let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
440+
\n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1), MAX(alias1)]] [a:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\
441+
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
442+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
395443

396444
assert_optimized_plan_equal(&plan, expected)
397445
}
@@ -423,9 +471,10 @@ mod tests {
423471
.build()?;
424472

425473
// Should work
426-
let expected = "Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.c)]] [group_alias_0:Int32, COUNT(DISTINCT test.c):Int64;N]\
427-
\n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
428-
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
474+
let expected = "Projection: group_alias_0 AS test.a + Int32(1), COUNT(alias1) AS COUNT(DISTINCT test.c) [test.a + Int32(1):Int32, COUNT(DISTINCT test.c):Int64;N]\
475+
\n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\
476+
\n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
477+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
429478

430479
assert_optimized_plan_equal(&plan, expected)
431480
}

datafusion/sqllogictest/test_files/groupby.slt

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3823,21 +3823,44 @@ query TT
38233823
EXPLAIN SELECT SUM(DISTINCT CAST(x AS DOUBLE)), MAX(DISTINCT CAST(x AS DOUBLE)) FROM t1 GROUP BY y;
38243824
----
38253825
logical_plan
3826-
Projection: SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)
3827-
--Aggregate: groupBy=[[t1.y]], aggr=[[SUM(alias1) AS SUM(DISTINCT t1.x), MAX(alias1) AS MAX(DISTINCT t1.x)]]
3826+
Projection: SUM(alias1) AS SUM(DISTINCT t1.x), MAX(alias1) AS MAX(DISTINCT t1.x)
3827+
--Aggregate: groupBy=[[t1.y]], aggr=[[SUM(alias1), MAX(alias1)]]
38283828
----Aggregate: groupBy=[[t1.y, CAST(t1.x AS Float64)t1.x AS t1.x AS alias1]], aggr=[[]]
38293829
------Projection: CAST(t1.x AS Float64) AS CAST(t1.x AS Float64)t1.x, t1.y
38303830
--------TableScan: t1 projection=[x, y]
38313831
physical_plan
3832-
ProjectionExec: expr=[SUM(DISTINCT t1.x)@1 as SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)@2 as MAX(DISTINCT t1.x)]
3833-
--AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)]
3832+
ProjectionExec: expr=[SUM(alias1)@1 as SUM(DISTINCT t1.x), MAX(alias1)@2 as MAX(DISTINCT t1.x)]
3833+
--AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[SUM(alias1), MAX(alias1)]
38343834
----CoalesceBatchesExec: target_batch_size=2
38353835
------RepartitionExec: partitioning=Hash([y@0], 8), input_partitions=8
3836-
--------AggregateExec: mode=Partial, gby=[y@0 as y], aggr=[SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)]
3836+
--------AggregateExec: mode=Partial, gby=[y@0 as y], aggr=[SUM(alias1), MAX(alias1)]
38373837
----------AggregateExec: mode=FinalPartitioned, gby=[y@0 as y, alias1@1 as alias1], aggr=[]
38383838
------------CoalesceBatchesExec: target_batch_size=2
38393839
--------------RepartitionExec: partitioning=Hash([y@0, alias1@1], 8), input_partitions=8
38403840
----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
38413841
------------------AggregateExec: mode=Partial, gby=[y@1 as y, CAST(t1.x AS Float64)t1.x@0 as alias1], aggr=[]
38423842
--------------------ProjectionExec: expr=[CAST(x@0 AS Float64) as CAST(t1.x AS Float64)t1.x, y@1 as y]
38433843
----------------------MemoryExec: partitions=1, partition_sizes=[1]
3844+
3845+
statement ok
3846+
drop table t1
3847+
3848+
# Reproducer for https://github.com/apache/arrow-datafusion/issues/8175
3849+
3850+
statement ok
3851+
create table t1(state string, city string, min_temp float, area int, time timestamp) as values
3852+
('MA', 'Boston', 70.4, 1, 50),
3853+
('MA', 'Bedford', 71.59, 2, 150);
3854+
3855+
query RI
3856+
select date_part('year', time) as bla, count(distinct state) as count from t1 group by bla;
3857+
----
3858+
1970 1
3859+
3860+
query PI
3861+
select date_bin(interval '1 year', time) as bla, count(distinct state) as count from t1 group by bla;
3862+
----
3863+
1970-01-01T00:00:00 1
3864+
3865+
statement ok
3866+
drop table t1

datafusion/sqllogictest/test_files/joins.slt

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,29 +1361,31 @@ from join_t1
13611361
inner join join_t2 on join_t1.t1_id = join_t2.t2_id
13621362
----
13631363
logical_plan
1364-
Aggregate: groupBy=[[]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT join_t1.t1_id)]]
1365-
--Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]]
1366-
----Projection: join_t1.t1_id
1367-
------Inner Join: join_t1.t1_id = join_t2.t2_id
1368-
--------TableScan: join_t1 projection=[t1_id]
1369-
--------TableScan: join_t2 projection=[t2_id]
1364+
Projection: COUNT(alias1) AS COUNT(DISTINCT join_t1.t1_id)
1365+
--Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]]
1366+
----Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]]
1367+
------Projection: join_t1.t1_id
1368+
--------Inner Join: join_t1.t1_id = join_t2.t2_id
1369+
----------TableScan: join_t1 projection=[t1_id]
1370+
----------TableScan: join_t2 projection=[t2_id]
13701371
physical_plan
1371-
AggregateExec: mode=Final, gby=[], aggr=[COUNT(DISTINCT join_t1.t1_id)]
1372-
--CoalescePartitionsExec
1373-
----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(DISTINCT join_t1.t1_id)]
1374-
------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[]
1375-
--------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[]
1376-
----------ProjectionExec: expr=[t1_id@0 as t1_id]
1377-
------------CoalesceBatchesExec: target_batch_size=2
1378-
--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)]
1379-
----------------CoalesceBatchesExec: target_batch_size=2
1380-
------------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
1381-
--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
1382-
----------------------MemoryExec: partitions=1, partition_sizes=[1]
1383-
----------------CoalesceBatchesExec: target_batch_size=2
1384-
------------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
1385-
--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
1386-
----------------------MemoryExec: partitions=1, partition_sizes=[1]
1372+
ProjectionExec: expr=[COUNT(alias1)@0 as COUNT(DISTINCT join_t1.t1_id)]
1373+
--AggregateExec: mode=Final, gby=[], aggr=[COUNT(alias1)]
1374+
----CoalescePartitionsExec
1375+
------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(alias1)]
1376+
--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[]
1377+
----------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[]
1378+
------------ProjectionExec: expr=[t1_id@0 as t1_id]
1379+
--------------CoalesceBatchesExec: target_batch_size=2
1380+
----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)]
1381+
------------------CoalesceBatchesExec: target_batch_size=2
1382+
--------------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
1383+
----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
1384+
------------------------MemoryExec: partitions=1, partition_sizes=[1]
1385+
------------------CoalesceBatchesExec: target_batch_size=2
1386+
--------------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
1387+
----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
1388+
------------------------MemoryExec: partitions=1, partition_sizes=[1]
13871389

13881390
statement ok
13891391
set datafusion.explain.logical_plan_only = true;
@@ -3407,3 +3409,4 @@ set datafusion.optimizer.prefer_existing_sort = false;
34073409

34083410
statement ok
34093411
drop table annotated_data;
3412+

datafusion/sqllogictest/test_files/tpch/q16.slt.part

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ limit 10;
5252
logical_plan
5353
Limit: skip=0, fetch=10
5454
--Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type ASC NULLS LAST, part.p_size ASC NULLS LAST, fetch=10
55-
----Projection: part.p_brand, part.p_type, part.p_size, COUNT(DISTINCT partsupp.ps_suppkey) AS supplier_cnt
56-
------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT partsupp.ps_suppkey)]]
55+
----Projection: part.p_brand, part.p_type, part.p_size, COUNT(alias1) AS supplier_cnt
56+
------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[COUNT(alias1)]]
5757
--------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey AS alias1]], aggr=[[]]
5858
----------LeftAnti Join: partsupp.ps_suppkey = __correlated_sq_1.s_suppkey
5959
------------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size
@@ -69,11 +69,11 @@ physical_plan
6969
GlobalLimitExec: skip=0, fetch=10
7070
--SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], fetch=10
7171
----SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST]
72-
------ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, COUNT(DISTINCT partsupp.ps_suppkey)@3 as supplier_cnt]
73-
--------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(DISTINCT partsupp.ps_suppkey)]
72+
------ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, COUNT(alias1)@3 as supplier_cnt]
73+
--------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)]
7474
----------CoalesceBatchesExec: target_batch_size=8192
7575
------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2], 4), input_partitions=4
76-
--------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(DISTINCT partsupp.ps_suppkey)]
76+
--------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)]
7777
----------------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, alias1@3 as alias1], aggr=[]
7878
------------------CoalesceBatchesExec: target_batch_size=8192
7979
--------------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2, alias1@3], 4), input_partitions=4

0 commit comments

Comments
 (0)