1616// under the License.
1717
1818use datafusion_common:: config:: ConfigOptions ;
19- use datafusion_common:: Result ;
19+ use datafusion_common:: { Column , DFField , DFSchema , DFSchemaRef , Result } ;
2020use datafusion_expr:: expr:: AggregateFunction ;
21+ use datafusion_expr:: expr_rewriter:: { ExprRewritable , ExprRewriter } ;
2122use datafusion_expr:: utils:: COUNT_STAR_EXPANSION ;
22- use datafusion_expr:: { aggregate_function, lit, Aggregate , Expr , LogicalPlan , Window } ;
23+ use datafusion_expr:: Expr :: Exists ;
24+ use datafusion_expr:: {
25+ aggregate_function, count, expr, lit, window_function, Aggregate , Expr , Filter ,
26+ LogicalPlan , Projection , Subquery , Window ,
27+ } ;
28+ use std:: string:: ToString ;
29+ use std:: sync:: Arc ;
2330
2431use crate :: analyzer:: AnalyzerRule ;
2532use crate :: rewrite:: TreeNodeRewritable ;
2633
34+ pub const COUNT_STAR : & str = "COUNT(*)" ;
35+
2736/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
2837/// Resolve issue: https://github.com/apache/arrow-datafusion/issues/5473.
38+ #[ derive( Default ) ]
2939pub struct CountWildcardRule { }
3040
31- impl Default for CountWildcardRule {
32- fn default ( ) -> Self {
33- CountWildcardRule :: new ( )
34- }
35- }
36-
3741impl CountWildcardRule {
3842 pub fn new ( ) -> Self {
3943 CountWildcardRule { }
4044 }
4145}
4246impl AnalyzerRule for CountWildcardRule {
4347 fn analyze ( & self , plan : & LogicalPlan , _: & ConfigOptions ) -> Result < LogicalPlan > {
44- plan. clone ( ) . transform_down ( & analyze_internal)
48+ Ok ( plan. clone ( ) . transform_down ( & analyze_internal) . unwrap ( ) )
4549 }
4650
4751 fn name ( & self ) -> & str {
@@ -50,35 +54,145 @@ impl AnalyzerRule for CountWildcardRule {
5054}
5155
5256fn analyze_internal ( plan : LogicalPlan ) -> Result < Option < LogicalPlan > > {
57+ let mut rewriter = CountWildcardRewriter { } ;
58+
5359 match plan {
5460 LogicalPlan :: Window ( window) => {
55- let window_expr = handle_wildcard ( & window. window_expr ) ;
61+ let window_expr = window
62+ . window_expr
63+ . iter ( )
64+ . map ( |expr| {
65+ let name = expr. name ( ) ;
66+ let variant_name = expr. variant_name ( ) ;
67+ expr. clone ( ) . rewrite ( & mut rewriter) . unwrap ( )
68+ } )
69+ . collect :: < Vec < Expr > > ( ) ;
70+
5671 Ok ( Some ( LogicalPlan :: Window ( Window {
5772 input : window. input . clone ( ) ,
5873 window_expr,
59- schema : window. schema ,
74+ schema : rewrite_schema ( window. schema ) ,
6075 } ) ) )
6176 }
6277 LogicalPlan :: Aggregate ( agg) => {
63- let aggr_expr = handle_wildcard ( & agg. aggr_expr ) ;
78+ let aggr_expr = agg
79+ . aggr_expr
80+ . iter ( )
81+ . map ( |expr| expr. clone ( ) . rewrite ( & mut rewriter) . unwrap ( ) )
82+ . collect ( ) ;
6483 Ok ( Some ( LogicalPlan :: Aggregate (
6584 Aggregate :: try_new_with_schema (
6685 agg. input . clone ( ) ,
6786 agg. group_expr . clone ( ) ,
6887 aggr_expr,
69- agg. schema ,
88+ rewrite_schema ( agg. schema ) ,
89+ ) ?,
90+ ) ) )
91+ }
92+ LogicalPlan :: Projection ( projection) => {
93+ let projection_expr = projection
94+ . expr
95+ . iter ( )
96+ . map ( |expr| {
97+ let name = expr. name ( ) ;
98+ let variant_name = expr. variant_name ( ) ;
99+ expr. clone ( ) . rewrite ( & mut rewriter) . unwrap ( )
100+ } )
101+ . collect ( ) ;
102+ Ok ( Some ( LogicalPlan :: Projection (
103+ Projection :: try_new_with_schema (
104+ projection_expr,
105+ projection. input ,
106+ rewrite_schema ( projection. schema ) ,
70107 ) ?,
71108 ) ) )
72109 }
110+ LogicalPlan :: Filter ( Filter {
111+ predicate, input, ..
112+ } ) => {
113+ let predicate = match predicate {
114+ Exists { subquery, negated } => {
115+ let new_plan = subquery
116+ . subquery
117+ . as_ref ( )
118+ . clone ( )
119+ . transform_down ( & analyze_internal)
120+ . unwrap ( ) ;
121+
122+ Exists {
123+ subquery : Subquery {
124+ subquery : Arc :: new ( new_plan) ,
125+ outer_ref_columns : subquery. outer_ref_columns ,
126+ } ,
127+ negated,
128+ }
129+ }
130+ _ => predicate,
131+ } ;
132+
133+ Ok ( Some ( LogicalPlan :: Filter (
134+ Filter :: try_new ( predicate, input) . unwrap ( ) ,
135+ ) ) )
136+ }
137+
73138 _ => Ok ( None ) ,
74139 }
75140}
76141
77- // handle Count(Expr:Wildcard) with DataFrame API
78- pub fn handle_wildcard ( exprs : & [ Expr ] ) -> Vec < Expr > {
79- exprs
80- . iter ( )
81- . map ( |expr| match expr {
142+ struct CountWildcardRewriter { }
143+
144+ impl ExprRewriter for CountWildcardRewriter {
145+ fn mutate ( & mut self , expr : Expr ) -> Result < Expr > {
146+ let count_star: String = count ( Expr :: Wildcard ) . to_string ( ) ;
147+ let old_expr = expr. clone ( ) ;
148+
149+ let new_expr = match old_expr. clone ( ) {
150+ Expr :: Column ( Column { name, relation } ) if name. contains ( & count_star) => {
151+ Expr :: Column ( Column {
152+ name : name. replace (
153+ & count_star,
154+ count ( lit ( COUNT_STAR_EXPANSION ) ) . to_string ( ) . as_str ( ) ,
155+ ) ,
156+ relation : relation. clone ( ) ,
157+ } )
158+ }
159+ Expr :: WindowFunction ( expr:: WindowFunction {
160+ fun :
161+ window_function:: WindowFunction :: AggregateFunction (
162+ aggregate_function:: AggregateFunction :: Count ,
163+ ) ,
164+ args,
165+ partition_by,
166+ order_by,
167+ window_frame,
168+ } ) if args. len ( ) == 1 => match args[ 0 ] {
169+ Expr :: Wildcard => {
170+ Expr :: WindowFunction ( datafusion_expr:: expr:: WindowFunction {
171+ fun : window_function:: WindowFunction :: AggregateFunction (
172+ aggregate_function:: AggregateFunction :: Count ,
173+ ) ,
174+ args : vec ! [ lit( COUNT_STAR_EXPANSION ) ] ,
175+ partition_by : partition_by. clone ( ) ,
176+ order_by : order_by. clone ( ) ,
177+ window_frame : window_frame. clone ( ) ,
178+ } )
179+ }
180+
181+ _ => old_expr. clone ( ) ,
182+ } ,
183+ Expr :: WindowFunction ( expr:: WindowFunction {
184+ fun :
185+ window_function:: WindowFunction :: AggregateFunction (
186+ aggregate_function:: AggregateFunction :: Count ,
187+ ) ,
188+ args,
189+ partition_by,
190+ order_by,
191+ window_frame,
192+ } ) => {
193+ println ! ( "hahahhaha {}" , args[ 0 ] ) ;
194+ old_expr. clone ( )
195+ }
82196 Expr :: AggregateFunction ( AggregateFunction {
83197 fun : aggregate_function:: AggregateFunction :: Count ,
84198 args,
@@ -88,12 +202,39 @@ pub fn handle_wildcard(exprs: &[Expr]) -> Vec<Expr> {
88202 Expr :: Wildcard => Expr :: AggregateFunction ( AggregateFunction {
89203 fun : aggregate_function:: AggregateFunction :: Count ,
90204 args : vec ! [ lit( COUNT_STAR_EXPANSION ) ] ,
91- distinct : * distinct ,
205+ distinct,
92206 filter : filter. clone ( ) ,
93207 } ) ,
94- _ => expr . clone ( ) ,
208+ _ => old_expr . clone ( ) ,
95209 } ,
96- _ => expr. clone ( ) ,
210+ _ => old_expr. clone ( ) ,
211+ } ;
212+ Ok ( new_expr)
213+ }
214+ }
215+
216+ fn rewrite_schema ( schema : DFSchemaRef ) -> DFSchemaRef {
217+ let new_fields = schema
218+ . fields ( )
219+ . iter ( )
220+ . map ( |DFField { qualifier, field } | {
221+ let mut name = field. name ( ) . clone ( ) ;
222+ if name. contains ( COUNT_STAR . clone ( ) ) {
223+ name = name. replace (
224+ COUNT_STAR ,
225+ count ( lit ( COUNT_STAR_EXPANSION ) ) . to_string ( ) . as_str ( ) ,
226+ )
227+ }
228+ DFField :: new (
229+ qualifier. clone ( ) ,
230+ name. as_str ( ) ,
231+ field. data_type ( ) . clone ( ) ,
232+ field. is_nullable ( ) ,
233+ )
97234 } )
98- . collect ( )
235+ . collect :: < Vec < DFField > > ( ) ;
236+
237+ DFSchemaRef :: new (
238+ DFSchema :: new_with_metadata ( new_fields, schema. metadata ( ) . clone ( ) ) . unwrap ( ) ,
239+ )
99240}
0 commit comments