Skip to content

Commit f48a997

Browse files
authored
Evaluate expressions after type coercion (#3444)
* Evaluate expressions after type coercion
* Fix some explains
* Fix some explains
* Fix some explains
* Update test
* Update test
* Update test
* Update more tests
* Fix tests
* Use supported date string
1 parent 97b3a4b commit f48a997

6 files changed

Lines changed: 121 additions & 88 deletions

File tree

datafusion/core/tests/sql/aggregates.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,11 +1834,11 @@ async fn aggregate_avg_add() -> Result<()> {
18341834
assert_eq!(results.len(), 1);
18351835

18361836
let expected = vec![
1837-
"+--------------+-------------------------+-------------------------+-------------------------+",
1838-
"| AVG(test.c1) | AVG(test.c1) + Int64(1) | AVG(test.c1) + Int64(2) | Int64(1) + AVG(test.c1) |",
1839-
"+--------------+-------------------------+-------------------------+-------------------------+",
1840-
"| 1.5          | 2.5                     | 3.5                     | 2.5                     |",
1841-
"+--------------+-------------------------+-------------------------+-------------------------+",
1837+
"+--------------+---------------------------+---------------------------+---------------------------+",
1838+
"| AVG(test.c1) | AVG(test.c1) + Float64(1) | AVG(test.c1) + Float64(2) | Float64(1) + AVG(test.c1) |",
1839+
"+--------------+---------------------------+---------------------------+---------------------------+",
1840+
"| 1.5          | 2.5                       | 3.5                       | 2.5                       |",
1841+
"+--------------+---------------------------+---------------------------+---------------------------+",
18421842
];
18431843
assert_batches_sorted_eq!(expected, &results);
18441844

datafusion/core/tests/sql/decimal.rs

Lines changed: 57 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -376,25 +376,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
376376
actual[0].schema().field(0).data_type()
377377
);
378378
let expected = vec![
379-
"+------------------------------+",
380-
"| decimal_simple.c1 + Int64(1) |",
381-
"+------------------------------+",
382-
"| 1.000010                     |",
383-
"| 1.000020                     |",
384-
"| 1.000020                     |",
385-
"| 1.000030                     |",
386-
"| 1.000030                     |",
387-
"| 1.000030                     |",
388-
"| 1.000040                     |",
389-
"| 1.000040                     |",
390-
"| 1.000040                     |",
391-
"| 1.000040                     |",
392-
"| 1.000050                     |",
393-
"| 1.000050                     |",
394-
"| 1.000050                     |",
395-
"| 1.000050                     |",
396-
"| 1.000050                     |",
397-
"+------------------------------+",
379+
"+----------------------------------------------------+",
380+
"| decimal_simple.c1 + Decimal128(Some(1000000),27,6) |",
381+
"+----------------------------------------------------+",
382+
"| 1.000010                                           |",
383+
"| 1.000020                                           |",
384+
"| 1.000020                                           |",
385+
"| 1.000030                                           |",
386+
"| 1.000030                                           |",
387+
"| 1.000030                                           |",
388+
"| 1.000040                                           |",
389+
"| 1.000040                                           |",
390+
"| 1.000040                                           |",
391+
"| 1.000040                                           |",
392+
"| 1.000050                                           |",
393+
"| 1.000050                                           |",
394+
"| 1.000050                                           |",
395+
"| 1.000050                                           |",
396+
"| 1.000050                                           |",
397+
"+----------------------------------------------------+",
398398
];
399399
assert_batches_eq!(expected, &actual);
400400
// array decimal(10,6) + array decimal(12,7) => decimal(13,7)
@@ -434,25 +434,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
434434
actual[0].schema().field(0).data_type()
435435
);
436436
let expected = vec![
437-
"+------------------------------+",
438-
"| decimal_simple.c1 - Int64(1) |",
439-
"+------------------------------+",
440-
"| -0.999990                    |",
441-
"| -0.999980                    |",
442-
"| -0.999980                    |",
443-
"| -0.999970                    |",
444-
"| -0.999970                    |",
445-
"| -0.999970                    |",
446-
"| -0.999960                    |",
447-
"| -0.999960                    |",
448-
"| -0.999960                    |",
449-
"| -0.999960                    |",
450-
"| -0.999950                    |",
451-
"| -0.999950                    |",
452-
"| -0.999950                    |",
453-
"| -0.999950                    |",
454-
"| -0.999950                    |",
455-
"+------------------------------+",
437+
"+----------------------------------------------------+",
438+
"| decimal_simple.c1 - Decimal128(Some(1000000),27,6) |",
439+
"+----------------------------------------------------+",
440+
"| -0.999990                                          |",
441+
"| -0.999980                                          |",
442+
"| -0.999980                                          |",
443+
"| -0.999970                                          |",
444+
"| -0.999970                                          |",
445+
"| -0.999970                                          |",
446+
"| -0.999960                                          |",
447+
"| -0.999960                                          |",
448+
"| -0.999960                                          |",
449+
"| -0.999960                                          |",
450+
"| -0.999950                                          |",
451+
"| -0.999950                                          |",
452+
"| -0.999950                                          |",
453+
"| -0.999950                                          |",
454+
"| -0.999950                                          |",
455+
"+----------------------------------------------------+",
456456
];
457457
assert_batches_eq!(expected, &actual);
458458

@@ -492,25 +492,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
492492
actual[0].schema().field(0).data_type()
493493
);
494494
let expected = vec![
495-
"+-------------------------------+",
496-
"| decimal_simple.c1 * Int64(20) |",
497-
"+-------------------------------+",
498-
"| 0.000200                      |",
499-
"| 0.000400                      |",
500-
"| 0.000400                      |",
501-
"| 0.000600                      |",
502-
"| 0.000600                      |",
503-
"| 0.000600                      |",
504-
"| 0.000800                      |",
505-
"| 0.000800                      |",
506-
"| 0.000800                      |",
507-
"| 0.000800                      |",
508-
"| 0.001000                      |",
509-
"| 0.001000                      |",
510-
"| 0.001000                      |",
511-
"| 0.001000                      |",
512-
"| 0.001000                      |",
513-
"+-------------------------------+",
495+
"+-----------------------------------------------------+",
496+
"| decimal_simple.c1 * Decimal128(Some(20000000),31,6) |",
497+
"+-----------------------------------------------------+",
498+
"| 0.000200                                            |",
499+
"| 0.000400                                            |",
500+
"| 0.000400                                            |",
501+
"| 0.000600                                            |",
502+
"| 0.000600                                            |",
503+
"| 0.000600                                            |",
504+
"| 0.000800                                            |",
505+
"| 0.000800                                            |",
506+
"| 0.000800                                            |",
507+
"| 0.000800                                            |",
508+
"| 0.001000                                            |",
509+
"| 0.001000                                            |",
510+
"| 0.001000                                            |",
511+
"| 0.001000                                            |",
512+
"| 0.001000                                            |",
513+
"+-----------------------------------------------------+",
514514
];
515515
assert_batches_eq!(expected, &actual);
516516

datafusion/core/tests/sql/explain_analyze.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ order by
653653
let expected = "\
654654
Sort: #revenue DESC NULLS FIRST\
655655
\n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\
656-
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(CAST(Int64(1) AS Decimal128(23, 2)) - CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\
656+
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\
657657
\n Inner Join: #customer.c_nationkey = #nation.n_nationkey\
658658
\n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
659659
\n Inner Join: #customer.c_custkey = #orders.o_custkey\

datafusion/core/tests/sql/subqueries.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,9 +336,9 @@ order by s_name;
336336
Projection: #part.p_partkey AS p_partkey, alias=__sq_1
337337
Filter: #part.p_name LIKE Utf8("forest%")
338338
TableScan: part projection=[p_partkey, p_name], partial_filters=[#part.p_name LIKE Utf8("forest%")]
339-
Projection: #lineitem.l_partkey, #lineitem.l_suppkey, CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(#SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
339+
Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(#SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
340340
Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], aggr=[[SUM(#lineitem.l_quantity)]]
341-
Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)
341+
Filter: #lineitem.l_shipdate >= Date32("8766")
342342
TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)]"#
343343
.to_string();
344344
assert_eq!(actual, expected);
@@ -393,7 +393,7 @@ order by cntrycode;"#;
393393
TableScan: orders projection=[o_custkey]
394394
Projection: #AVG(customer.c_acctbal) AS __value, alias=__sq_1
395395
Aggregate: groupBy=[[]], aggr=[[AVG(#customer.c_acctbal)]]
396-
Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
396+
Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
397397
TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[#customer.c_acctbal > Float64(0), substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
398398
.to_string();
399399
assert_eq!(actual, expected);
@@ -453,7 +453,7 @@ order by value desc;
453453
TableScan: supplier projection=[s_suppkey, s_nationkey]
454454
Filter: #nation.n_name = Utf8("GERMANY")
455455
TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("GERMANY")]
456-
Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS __value, alias=__sq_1
456+
Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1
457457
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]]
458458
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
459459
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey

datafusion/optimizer/src/type_coercion.rs

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
//! Optimizer rule for type validation and coercion
1919
20+
use crate::simplify_expressions::ConstEvaluator;
2021
use crate::{OptimizerConfig, OptimizerRule};
2122
use arrow::datatypes::DataType;
2223
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
@@ -26,6 +27,7 @@ use datafusion_expr::type_coercion::data_types;
2627
use datafusion_expr::utils::from_plan;
2728
use datafusion_expr::{Expr, LogicalPlan};
2829
use datafusion_expr::{ExprSchemable, Signature};
30+
use datafusion_physical_expr::execution_props::ExecutionProps;
2931
use std::sync::Arc;
3032

3133
#[derive(Default)]
@@ -64,8 +66,14 @@ impl OptimizerRule for TypeCoercion {
6466
},
6567
);
6668

69+
let mut execution_props = ExecutionProps::new();
70+
execution_props.query_execution_start_time =
71+
optimizer_config.query_execution_start_time;
72+
let const_evaluator = ConstEvaluator::try_new(&execution_props)?;
73+
6774
let mut expr_rewrite = TypeCoercionRewriter {
6875
schema: Arc::new(schema),
76+
const_evaluator,
6977
};
7078

7179
let new_expr = plan
@@ -78,11 +86,12 @@ impl OptimizerRule for TypeCoercion {
7886
}
7987
}
8088

81-
struct TypeCoercionRewriter {
89+
struct TypeCoercionRewriter<'a> {
8290
schema: DFSchemaRef,
91+
const_evaluator: ConstEvaluator<'a>,
8392
}
8493

85-
impl ExprRewriter for TypeCoercionRewriter {
94+
impl ExprRewriter for TypeCoercionRewriter<'_> {
8695
fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
8796
Ok(RewriteRecursion::Continue)
8897
}
@@ -106,15 +115,17 @@ impl ExprRewriter for TypeCoercionRewriter {
106115
}
107116
_ => {
108117
let coerced_type = coerce_types(&left_type, &op, &right_type)?;
109-
Ok(Expr::BinaryExpr {
118+
let expr = Expr::BinaryExpr {
110119
left: Box::new(
111120
left.clone().cast_to(&coerced_type, &self.schema)?,
112121
),
113122
op,
114123
right: Box::new(
115124
right.clone().cast_to(&coerced_type, &self.schema)?,
116125
),
117-
})
126+
};
127+
128+
expr.rewrite(&mut self.const_evaluator)
118129
}
119130
}
120131
}
@@ -133,23 +144,25 @@ impl ExprRewriter for TypeCoercionRewriter {
133144
expr_type, low_type
134145
))
135146
})?;
136-
Ok(Expr::Between {
147+
let expr = Expr::Between {
137148
expr: Box::new(expr.cast_to(&coerced_type, &self.schema)?),
138149
negated,
139150
low: Box::new(low.cast_to(&coerced_type, &self.schema)?),
140151
high: Box::new(high.cast_to(&coerced_type, &self.schema)?),
141-
})
152+
};
153+
expr.rewrite(&mut self.const_evaluator)
142154
}
143155
Expr::ScalarUDF { fun, args } => {
144156
let new_expr = coerce_arguments_for_signature(
145157
args.as_slice(),
146158
&self.schema,
147159
&fun.signature,
148160
)?;
149-
Ok(Expr::ScalarUDF {
161+
let expr = Expr::ScalarUDF {
150162
fun,
151163
args: new_expr,
152-
})
164+
};
165+
expr.rewrite(&mut self.const_evaluator)
153166
}
154167
expr => Ok(expr),
155168
}
@@ -188,7 +201,8 @@ mod test {
188201
use crate::type_coercion::TypeCoercion;
189202
use crate::{OptimizerConfig, OptimizerRule};
190203
use arrow::datatypes::DataType;
191-
use datafusion_common::{DFSchema, Result, ScalarValue};
204+
use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
205+
use datafusion_expr::{col, ColumnarValue};
192206
use datafusion_expr::{
193207
lit,
194208
logical_plan::{EmptyRelation, Projection},
@@ -199,28 +213,40 @@ mod test {
199213

200214
#[test]
201215
fn simple_case() -> Result<()> {
202-
let expr = lit(1.2_f64).lt(lit(2_u32));
216+
let expr = col("a").lt(lit(2_u32));
203217
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
204218
produce_one_row: false,
205-
schema: Arc::new(DFSchema::empty()),
219+
schema: Arc::new(
220+
DFSchema::new_with_metadata(
221+
vec![DFField::new(None, "a", DataType::Float64, true)],
222+
std::collections::HashMap::new(),
223+
)
224+
.unwrap(),
225+
),
206226
}));
207227
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
208228
let rule = TypeCoercion::new();
209229
let mut config = OptimizerConfig::default();
210230
let plan = rule.optimize(&plan, &mut config)?;
211231
assert_eq!(
212-
"Projection: Float64(1.2) < CAST(UInt32(2) AS Float64)\n EmptyRelation",
232+
"Projection: #a < Float64(2)\n EmptyRelation",
213233
&format!("{:?}", plan)
214234
);
215235
Ok(())
216236
}
217237

218238
#[test]
219239
fn nested_case() -> Result<()> {
220-
let expr = lit(1.2_f64).lt(lit(2_u32));
240+
let expr = col("a").lt(lit(2_u32));
221241
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
222242
produce_one_row: false,
223-
schema: Arc::new(DFSchema::empty()),
243+
schema: Arc::new(
244+
DFSchema::new_with_metadata(
245+
vec![DFField::new(None, "a", DataType::Float64, true)],
246+
std::collections::HashMap::new(),
247+
)
248+
.unwrap(),
249+
),
224250
}));
225251
let plan = LogicalPlan::Projection(Projection::try_new(
226252
vec![expr.clone().or(expr)],
@@ -230,8 +256,11 @@ mod test {
230256
let rule = TypeCoercion::new();
231257
let mut config = OptimizerConfig::default();
232258
let plan = rule.optimize(&plan, &mut config)?;
233-
assert_eq!("Projection: Float64(1.2) < CAST(UInt32(2) AS Float64) OR Float64(1.2) < CAST(UInt32(2) AS Float64)\
234-
\n EmptyRelation", &format!("{:?}", plan));
259+
assert_eq!(
260+
"Projection: #a < Float64(2) OR #a < Float64(2)\
261+
\n EmptyRelation",
262+
&format!("{:?}", plan)
263+
);
235264
Ok(())
236265
}
237266

@@ -240,7 +269,11 @@ mod test {
240269
let empty = empty();
241270
let return_type: ReturnTypeFunction =
242271
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
243-
let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!());
272+
let fun: ScalarFunctionImplementation = Arc::new(move |_| {
273+
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
274+
"a".to_string(),
275+
))))
276+
});
244277
let udf = Expr::ScalarUDF {
245278
fun: Arc::new(ScalarUDF::new(
246279
"TestScalarUDF",
@@ -255,7 +288,7 @@ mod test {
255288
let mut config = OptimizerConfig::default();
256289
let plan = rule.optimize(&plan, &mut config)?;
257290
assert_eq!(
258-
"Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation",
291+
"Projection: Utf8(\"a\")\n EmptyRelation",
259292
&format!("{:?}", plan)
260293
);
261294
Ok(())

0 commit comments

Comments
 (0)