@@ -22,12 +22,13 @@ use std::sync::Arc;
2222use crate :: optimizer:: ApplyOrder ;
2323use crate :: { OptimizerConfig , OptimizerRule } ;
2424
25- use datafusion_common:: Result ;
25+ use datafusion_common:: { DFSchema , Result } ;
2626use datafusion_expr:: {
2727 col,
2828 expr:: AggregateFunction ,
29- logical_plan:: { Aggregate , LogicalPlan } ,
30- Expr ,
29+ logical_plan:: { Aggregate , LogicalPlan , Projection } ,
30+ utils:: columnize_expr,
31+ Expr , ExprSchemable ,
3132} ;
3233
3334use hashbrown:: HashSet ;
@@ -152,7 +153,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
152153
153154 // replace the distinct arg with alias
154155 let mut group_fields_set = HashSet :: new ( ) ;
155- let outer_aggr_exprs = aggr_expr
156+ let new_aggr_exprs = aggr_expr
156157 . iter ( )
157158 . map ( |aggr_expr| match aggr_expr {
158159 Expr :: AggregateFunction ( AggregateFunction {
@@ -174,24 +175,67 @@ impl OptimizerRule for SingleDistinctToGroupBy {
174175 false , // intentional to remove distinct here
175176 filter. clone ( ) ,
176177 order_by. clone ( ) ,
177- ) )
178- . alias ( aggr_expr. display_name ( ) ?) )
178+ ) ) )
179179 }
180180 _ => Ok ( aggr_expr. clone ( ) ) ,
181181 } )
182182 . collect :: < Result < Vec < _ > > > ( ) ?;
183183
184184 // construct the inner AggrPlan
185+ let inner_fields = inner_group_exprs
186+ . iter ( )
187+ . map ( |expr| expr. to_field ( input. schema ( ) ) )
188+ . collect :: < Result < Vec < _ > > > ( ) ?;
189+ let inner_schema = DFSchema :: new_with_metadata (
190+ inner_fields,
191+ input. schema ( ) . metadata ( ) . clone ( ) ,
192+ ) ?;
185193 let inner_agg = LogicalPlan :: Aggregate ( Aggregate :: try_new (
186194 input. clone ( ) ,
187195 inner_group_exprs,
188196 Vec :: new ( ) ,
189197 ) ?) ;
190198
191- Ok ( Some ( LogicalPlan :: Aggregate ( Aggregate :: try_new (
199+ let outer_fields = outer_group_exprs
200+ . iter ( )
201+ . chain ( new_aggr_exprs. iter ( ) )
202+ . map ( |expr| expr. to_field ( & inner_schema) )
203+ . collect :: < Result < Vec < _ > > > ( ) ?;
204+ let outer_aggr_schema = Arc :: new ( DFSchema :: new_with_metadata (
205+ outer_fields,
206+ input. schema ( ) . metadata ( ) . clone ( ) ,
207+ ) ?) ;
208+
209+ // so the aggregates are displayed in the same way even after the rewrite
210+ // this optimizer has two kinds of alias:
211+ // - group_by aggr
212+ // - aggr expr
213+ let group_size = group_expr. len ( ) ;
214+ let alias_expr = out_group_expr_with_alias
215+ . into_iter ( )
216+ . map ( |( group_expr, original_field) | {
217+ if let Some ( name) = original_field {
218+ group_expr. alias ( name)
219+ } else {
220+ group_expr
221+ }
222+ } )
223+ . chain ( new_aggr_exprs. iter ( ) . enumerate ( ) . map ( |( idx, expr) | {
224+ let idx = idx + group_size;
225+ let name = fields[ idx] . qualified_name ( ) ;
226+ columnize_expr ( expr. clone ( ) . alias ( name) , & outer_aggr_schema)
227+ } ) )
228+ . collect ( ) ;
229+
230+ let outer_aggr = LogicalPlan :: Aggregate ( Aggregate :: try_new (
192231 Arc :: new ( inner_agg) ,
193232 outer_group_exprs,
194- outer_aggr_exprs,
233+ new_aggr_exprs,
234+ ) ?) ;
235+
236+ Ok ( Some ( LogicalPlan :: Projection ( Projection :: try_new (
237+ alias_expr,
238+ Arc :: new ( outer_aggr) ,
195239 ) ?) ) )
196240 } else {
197241 Ok ( None )
@@ -255,9 +299,10 @@ mod tests {
255299 . build ( ) ?;
256300
257301 // Should work
258- let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b)]] [COUNT(DISTINCT test.b):Int64;N]\
259- \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\
260- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
302+ let expected = "Projection: COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):Int64;N]\
303+ \n Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]\
304+ \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\
305+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
261306
262307 assert_optimized_plan_equal ( & plan, expected)
263308 }
@@ -328,9 +373,10 @@ mod tests {
328373 . aggregate ( Vec :: < Expr > :: new ( ) , vec ! [ count_distinct( lit( 2 ) * col( "b" ) ) ] ) ?
329374 . build ( ) ?;
330375
331- let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b)]] [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\
332- \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\
333- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
376+ let expected = "Projection: COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\
377+ \n Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]\
378+ \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\
379+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
334380
335381 assert_optimized_plan_equal ( & plan, expected)
336382 }
@@ -344,9 +390,10 @@ mod tests {
344390 . build ( ) ?;
345391
346392 // Should work
347- let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\
348- \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
349- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
393+ let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\
394+ \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1)]] [a:UInt32, COUNT(alias1):Int64;N]\
395+ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
396+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
350397
351398 assert_optimized_plan_equal ( & plan, expected)
352399 }
@@ -389,9 +436,10 @@ mod tests {
389436 ) ?
390437 . build ( ) ?;
391438 // Should work
392- let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
393- \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
394- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
439+ let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
440+ \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1), MAX(alias1)]] [a:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\
441+ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
442+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
395443
396444 assert_optimized_plan_equal ( & plan, expected)
397445 }
@@ -423,9 +471,10 @@ mod tests {
423471 . build ( ) ?;
424472
425473 // Should work
426- let expected = "Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.c)]] [group_alias_0:Int32, COUNT(DISTINCT test.c):Int64;N]\
427- \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
428- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
474+ let expected = "Projection: group_alias_0 AS test.a + Int32(1), COUNT(alias1) AS COUNT(DISTINCT test.c) [test.a + Int32(1):Int32, COUNT(DISTINCT test.c):Int64;N]\
475+ \n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\
476+ \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
477+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
429478
430479 assert_optimized_plan_equal ( & plan, expected)
431480 }
0 commit comments