Skip to content

Commit c9e610d

Browse files
committed
remove unnecessary logic on sql count wildcard
1 parent 68e3040 commit c9e610d

6 files changed

Lines changed: 186 additions & 51 deletions

File tree

datafusion/common/src/dfschema.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -630,9 +630,9 @@ impl ExprSchema for DFSchema {
630630
/// An Arrow `Field` paired with an optional qualifier.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
631631
pub struct DFField {
632632
/// Optional qualifier (usually a table or relation name)
633-
qualifier: Option<OwnedTableReference>,
633+
pub qualifier: Option<OwnedTableReference>,
634634
/// Arrow field definition
635-
field: Field,
635+
pub field: Field,
636636
}
637637

638638
impl DFField {

datafusion/core/tests/dataframe.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,9 @@ async fn count_wildcard() -> Result<()> {
5151
let sql_results = ctx
5252
.sql("select count(*) from alltypes_tiny_pages")
5353
.await?
54-
.select(vec![count(Expr::Wildcard)])?
5554
.explain(false, false)?
5655
.collect()
5756
.await?;
58-
5957
// add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all nodes instead of just the top node.
6058
let df_results = ctx
6159
.table("alltypes_tiny_pages")
@@ -452,7 +450,7 @@ async fn select_with_alias_overwrite() -> Result<()> {
452450
let results = df.collect().await?;
453451

454452
#[rustfmt::skip]
455-
let expected = vec![
453+
let expected = vec![
456454
"+-------+",
457455
"| a |",
458456
"+-------+",

datafusion/core/tests/sql/explain_analyze.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -756,10 +756,11 @@ async fn explain_logical_plan_only() {
756756
let expected = vec![
757757
vec![
758758
"logical_plan",
759-
"Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
760-
\n SubqueryAlias: t\
761-
\n Projection: column1\
762-
\n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"
759+
"Projection: COUNT(UInt8(1))\
760+
\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
761+
\n SubqueryAlias: t\
762+
\n Projection: column1\
763+
\n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"
763764
]];
764765
assert_eq!(expected, actual);
765766
}
@@ -775,9 +776,9 @@ async fn explain_physical_plan_only() {
775776

776777
let expected = vec![vec![
777778
"physical_plan",
778-
"ProjectionExec: expr=[2 as COUNT(UInt8(1))]\
779-
\n EmptyExec: produce_one_row=true\
780-
\n",
779+
"ProjectionExec: expr=[COUNT(UInt8(1))@0 as COUNT(UInt8(1))]\
780+
\n ProjectionExec: expr=[2 as COUNT(UInt8(1))]\
781+
\n EmptyExec: produce_one_row=true\n",
781782
]];
782783
assert_eq!(expected, actual);
783784
}

datafusion/core/tests/sql/json.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,18 +82,19 @@ async fn json_explain() {
8282
let actual = normalize_vec_for_explain(actual);
8383
let expected = vec![
8484
vec![
85-
"logical_plan",
86-
"Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
87-
\n TableScan: t1 projection=[a]",
85+
"logical_plan", "Projection: COUNT(UInt8(1))\
86+
\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
87+
\n TableScan: t1 projection=[a]"
8888
],
8989
vec![
90-
"physical_plan",
91-
"AggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))]\
92-
\n CoalescePartitionsExec\
93-
\n AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]\
94-
\n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1\
95-
\n JsonExec: limit=None, files={1 group: [[WORKING_DIR/tests/jsons/2.json]]}\n",
96-
],
90+
"physical_plan",
91+
"ProjectionExec: expr=[COUNT(UInt8(1))@0 as COUNT(UInt8(1))]\
92+
\n AggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))]\
93+
\n CoalescePartitionsExec\
94+
\n AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]\
95+
\n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1\
96+
\n JsonExec: limit=None, files={1 group: [[WORKING_DIR/tests/jsons/2.json]]}\
97+
\n" ],
9798
];
9899
assert_eq!(expected, actual);
99100
}

datafusion/optimizer/src/analyzer/count_wildcard_rule.rs

Lines changed: 163 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,36 @@
1616
// under the License.
1717

1818
use datafusion_common::config::ConfigOptions;
19-
use datafusion_common::Result;
19+
use datafusion_common::{Column, DFField, DFSchema, DFSchemaRef, Result};
2020
use datafusion_expr::expr::AggregateFunction;
21+
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
2122
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
22-
use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, Window};
23+
use datafusion_expr::Expr::Exists;
24+
use datafusion_expr::{
25+
aggregate_function, count, expr, lit, window_function, Aggregate, Expr, Filter,
26+
LogicalPlan, Projection, Subquery, Window,
27+
};
28+
use std::string::ToString;
29+
use std::sync::Arc;
2330

2431
use crate::analyzer::AnalyzerRule;
2532
use crate::rewrite::TreeNodeRewritable;
2633

34+
pub const COUNT_STAR: &str = "COUNT(*)";
35+
2736
/// Analyzer rule that rewrites `Count(Expr::Wildcard)` to `Count(Expr::Literal)`.
///
/// Resolves <https://github.com/apache/arrow-datafusion/issues/5473>.
#[derive(Default)]
pub struct CountWildcardRule {}

impl CountWildcardRule {
    /// Creates the rule. The struct has no fields, so this is equivalent to
    /// `CountWildcardRule::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
4246
impl AnalyzerRule for CountWildcardRule {
4347
fn analyze(&self, plan: &LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
44-
plan.clone().transform_down(&analyze_internal)
48+
Ok(plan.clone().transform_down(&analyze_internal).unwrap())
4549
}
4650

4751
fn name(&self) -> &str {
@@ -50,35 +54,145 @@ impl AnalyzerRule for CountWildcardRule {
5054
}
5155

5256
fn analyze_internal(plan: LogicalPlan) -> Result<Option<LogicalPlan>> {
57+
let mut rewriter = CountWildcardRewriter {};
58+
5359
match plan {
5460
LogicalPlan::Window(window) => {
55-
let window_expr = handle_wildcard(&window.window_expr);
61+
let window_expr = window
62+
.window_expr
63+
.iter()
64+
.map(|expr| {
65+
let name = expr.name();
66+
let variant_name = expr.variant_name();
67+
expr.clone().rewrite(&mut rewriter).unwrap()
68+
})
69+
.collect::<Vec<Expr>>();
70+
5671
Ok(Some(LogicalPlan::Window(Window {
5772
input: window.input.clone(),
5873
window_expr,
59-
schema: window.schema,
74+
schema: rewrite_schema(window.schema),
6075
})))
6176
}
6277
LogicalPlan::Aggregate(agg) => {
63-
let aggr_expr = handle_wildcard(&agg.aggr_expr);
78+
let aggr_expr = agg
79+
.aggr_expr
80+
.iter()
81+
.map(|expr| expr.clone().rewrite(&mut rewriter).unwrap())
82+
.collect();
6483
Ok(Some(LogicalPlan::Aggregate(
6584
Aggregate::try_new_with_schema(
6685
agg.input.clone(),
6786
agg.group_expr.clone(),
6887
aggr_expr,
69-
agg.schema,
88+
rewrite_schema(agg.schema),
89+
)?,
90+
)))
91+
}
92+
LogicalPlan::Projection(projection) => {
93+
let projection_expr = projection
94+
.expr
95+
.iter()
96+
.map(|expr| {
97+
let name = expr.name();
98+
let variant_name = expr.variant_name();
99+
expr.clone().rewrite(&mut rewriter).unwrap()
100+
})
101+
.collect();
102+
Ok(Some(LogicalPlan::Projection(
103+
Projection::try_new_with_schema(
104+
projection_expr,
105+
projection.input,
106+
rewrite_schema(projection.schema),
70107
)?,
71108
)))
72109
}
110+
LogicalPlan::Filter(Filter {
111+
predicate, input, ..
112+
}) => {
113+
let predicate = match predicate {
114+
Exists { subquery, negated } => {
115+
let new_plan = subquery
116+
.subquery
117+
.as_ref()
118+
.clone()
119+
.transform_down(&analyze_internal)
120+
.unwrap();
121+
122+
Exists {
123+
subquery: Subquery {
124+
subquery: Arc::new(new_plan),
125+
outer_ref_columns: subquery.outer_ref_columns,
126+
},
127+
negated,
128+
}
129+
}
130+
_ => predicate,
131+
};
132+
133+
Ok(Some(LogicalPlan::Filter(
134+
Filter::try_new(predicate, input).unwrap(),
135+
)))
136+
}
137+
73138
_ => Ok(None),
74139
}
75140
}
76141

77-
// handle Count(Expr:Wildcard) with DataFrame API
78-
pub fn handle_wildcard(exprs: &[Expr]) -> Vec<Expr> {
79-
exprs
80-
.iter()
81-
.map(|expr| match expr {
142+
struct CountWildcardRewriter {}
143+
144+
impl ExprRewriter for CountWildcardRewriter {
145+
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
146+
let count_star: String = count(Expr::Wildcard).to_string();
147+
let old_expr = expr.clone();
148+
149+
let new_expr = match old_expr.clone() {
150+
Expr::Column(Column { name, relation }) if name.contains(&count_star) => {
151+
Expr::Column(Column {
152+
name: name.replace(
153+
&count_star,
154+
count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(),
155+
),
156+
relation: relation.clone(),
157+
})
158+
}
159+
Expr::WindowFunction(expr::WindowFunction {
160+
fun:
161+
window_function::WindowFunction::AggregateFunction(
162+
aggregate_function::AggregateFunction::Count,
163+
),
164+
args,
165+
partition_by,
166+
order_by,
167+
window_frame,
168+
}) if args.len() == 1 => match args[0] {
169+
Expr::Wildcard => {
170+
Expr::WindowFunction(datafusion_expr::expr::WindowFunction {
171+
fun: window_function::WindowFunction::AggregateFunction(
172+
aggregate_function::AggregateFunction::Count,
173+
),
174+
args: vec![lit(COUNT_STAR_EXPANSION)],
175+
partition_by: partition_by.clone(),
176+
order_by: order_by.clone(),
177+
window_frame: window_frame.clone(),
178+
})
179+
}
180+
181+
_ => old_expr.clone(),
182+
},
183+
Expr::WindowFunction(expr::WindowFunction {
184+
fun:
185+
window_function::WindowFunction::AggregateFunction(
186+
aggregate_function::AggregateFunction::Count,
187+
),
188+
args,
189+
partition_by,
190+
order_by,
191+
window_frame,
192+
}) => {
193+
println!("hahahhaha {}", args[0]);
194+
old_expr.clone()
195+
}
82196
Expr::AggregateFunction(AggregateFunction {
83197
fun: aggregate_function::AggregateFunction::Count,
84198
args,
@@ -88,12 +202,39 @@ pub fn handle_wildcard(exprs: &[Expr]) -> Vec<Expr> {
88202
Expr::Wildcard => Expr::AggregateFunction(AggregateFunction {
89203
fun: aggregate_function::AggregateFunction::Count,
90204
args: vec![lit(COUNT_STAR_EXPANSION)],
91-
distinct: *distinct,
205+
distinct,
92206
filter: filter.clone(),
93207
}),
94-
_ => expr.clone(),
208+
_ => old_expr.clone(),
95209
},
96-
_ => expr.clone(),
210+
_ => old_expr.clone(),
211+
};
212+
Ok(new_expr)
213+
}
214+
}
215+
216+
fn rewrite_schema(schema: DFSchemaRef) -> DFSchemaRef {
217+
let new_fields = schema
218+
.fields()
219+
.iter()
220+
.map(|DFField { qualifier, field }| {
221+
let mut name = field.name().clone();
222+
if name.contains(COUNT_STAR.clone()) {
223+
name = name.replace(
224+
COUNT_STAR,
225+
count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(),
226+
)
227+
}
228+
DFField::new(
229+
qualifier.clone(),
230+
name.as_str(),
231+
field.data_type().clone(),
232+
field.is_nullable(),
233+
)
97234
})
98-
.collect()
235+
.collect::<Vec<DFField>>();
236+
237+
DFSchemaRef::new(
238+
DFSchema::new_with_metadata(new_fields, schema.metadata().clone()).unwrap(),
239+
)
99240
}

datafusion/sql/src/expr/function.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
1919
use crate::utils::normalize_ident;
2020
use datafusion_common::{DFSchema, DataFusionError, Result};
21-
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
2221
use datafusion_expr::window_frame::regularize;
2322
use datafusion_expr::{
2423
expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame,
@@ -216,12 +215,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
216215
// Special case rewrite COUNT(*) to COUNT(constant)
217216
AggregateFunction::Count => args
218217
.into_iter()
219-
.map(|a| match a {
220-
FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => {
221-
Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone()))
222-
}
223-
_ => self.sql_fn_arg_to_logical_expr(a, schema, planner_context),
224-
})
218+
.map(|a| self.sql_fn_arg_to_logical_expr(a, schema, planner_context))
225219
.collect::<Result<Vec<Expr>>>()?,
226220
_ => self.function_args_to_expr(args, schema, planner_context)?,
227221
};

0 commit comments

Comments
 (0)