Skip to content

Commit 88c98e1

Browse files
authored
Refactor UnwrapCastInComparison to remove Expr clones (#10115)
1 parent f715d8c commit 88c98e1

1 file changed

Lines changed: 117 additions & 126 deletions

File tree

datafusion/optimizer/src/unwrap_cast_in_comparison.rs

Lines changed: 117 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
//! [`UnwrapCastInComparison`] rewrites `CAST(col) = lit` to `col = CAST(lit)`
1919
2020
use std::cmp::Ordering;
21+
use std::mem;
2122
use std::sync::Arc;
2223

2324
use crate::optimizer::ApplyOrder;
@@ -32,9 +33,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
3233
use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue};
3334
use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast};
3435
use datafusion_expr::utils::merge_schema;
35-
use datafusion_expr::{
36-
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
37-
};
36+
use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan, Operator};
3837

3938
/// [`UnwrapCastInComparison`] attempts to remove casts from
4039
/// comparisons to literals ([`ScalarValue`]s) by applying the casts
@@ -140,140 +139,132 @@ struct UnwrapCastExprRewriter {
140139
impl TreeNodeRewriter for UnwrapCastExprRewriter {
141140
type Node = Expr;
142141

143-
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
144-
match &expr {
142+
fn f_up(&mut self, mut expr: Expr) -> Result<Transformed<Expr>> {
143+
match &mut expr {
145144
// For case:
146145
// try_cast/cast(expr as data_type) op literal
147146
// literal op try_cast/cast(expr as data_type)
148-
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
149-
let left = left.as_ref().clone();
150-
let right = right.as_ref().clone();
151-
let left_type = left.get_type(&self.schema)?;
152-
let right_type = right.get_type(&self.schema)?;
153-
// Because the plan has been done the type coercion, the left and right must be equal
154-
if is_support_data_type(&left_type)
155-
&& is_support_data_type(&right_type)
156-
&& is_comparison_op(op)
157-
{
158-
match (&left, &right) {
159-
(
160-
Expr::Literal(left_lit_value),
161-
Expr::TryCast(TryCast { expr, .. })
162-
| Expr::Cast(Cast { expr, .. }),
163-
) => {
164-
// if the left_lit_value can be casted to the type of expr
165-
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
166-
let expr_type = expr.get_type(&self.schema)?;
167-
let casted_scalar_value =
168-
try_cast_literal_to_type(left_lit_value, &expr_type)?;
169-
if let Some(value) = casted_scalar_value {
170-
// unwrap the cast/try_cast for the right expr
171-
return Ok(Transformed::yes(binary_expr(
172-
lit(value),
173-
*op,
174-
expr.as_ref().clone(),
175-
)));
176-
}
177-
}
178-
(
179-
Expr::TryCast(TryCast { expr, .. })
180-
| Expr::Cast(Cast { expr, .. }),
181-
Expr::Literal(right_lit_value),
182-
) => {
183-
// if the right_lit_value can be casted to the type of expr
184-
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
185-
let expr_type = expr.get_type(&self.schema)?;
186-
let casted_scalar_value =
187-
try_cast_literal_to_type(right_lit_value, &expr_type)?;
188-
if let Some(value) = casted_scalar_value {
189-
// unwrap the cast/try_cast for the left expr
190-
return Ok(Transformed::yes(binary_expr(
191-
expr.as_ref().clone(),
192-
*op,
193-
lit(value),
194-
)));
195-
}
196-
}
197-
(_, _) => {
198-
// do nothing
199-
}
147+
Expr::BinaryExpr(BinaryExpr { left, op, right })
148+
if {
149+
let Ok(left_type) = left.get_type(&self.schema) else {
150+
return Ok(Transformed::no(expr));
200151
};
152+
let Ok(right_type) = right.get_type(&self.schema) else {
153+
return Ok(Transformed::no(expr));
154+
};
155+
is_support_data_type(&left_type)
156+
&& is_support_data_type(&right_type)
157+
&& is_comparison_op(op)
158+
} =>
159+
{
160+
match (left.as_mut(), right.as_mut()) {
161+
(
162+
Expr::Literal(left_lit_value),
163+
Expr::TryCast(TryCast {
164+
expr: right_expr, ..
165+
})
166+
| Expr::Cast(Cast {
167+
expr: right_expr, ..
168+
}),
169+
) => {
170+
// if the left_lit_value can be casted to the type of expr
171+
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
172+
let Ok(expr_type) = right_expr.get_type(&self.schema) else {
173+
return Ok(Transformed::no(expr));
174+
};
175+
let Ok(Some(value)) =
176+
try_cast_literal_to_type(left_lit_value, &expr_type)
177+
else {
178+
return Ok(Transformed::no(expr));
179+
};
180+
**left = lit(value);
181+
// unwrap the cast/try_cast for the right expr
182+
**right =
183+
mem::replace(right_expr, Expr::Literal(ScalarValue::Null));
184+
Ok(Transformed::yes(expr))
185+
}
186+
(
187+
Expr::TryCast(TryCast {
188+
expr: left_expr, ..
189+
})
190+
| Expr::Cast(Cast {
191+
expr: left_expr, ..
192+
}),
193+
Expr::Literal(right_lit_value),
194+
) => {
195+
// if the right_lit_value can be casted to the type of expr
196+
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
197+
let Ok(expr_type) = left_expr.get_type(&self.schema) else {
198+
return Ok(Transformed::no(expr));
199+
};
200+
let Ok(Some(value)) =
201+
try_cast_literal_to_type(right_lit_value, &expr_type)
202+
else {
203+
return Ok(Transformed::no(expr));
204+
};
205+
// unwrap the cast/try_cast for the left expr
206+
**left =
207+
mem::replace(left_expr, Expr::Literal(ScalarValue::Null));
208+
**right = lit(value);
209+
Ok(Transformed::yes(expr))
210+
}
211+
_ => Ok(Transformed::no(expr)),
201212
}
202-
// return the new binary op
203-
Ok(Transformed::yes(binary_expr(left, *op, right)))
204213
}
205214
// For case:
206215
// try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
207216
Expr::InList(InList {
208-
expr: left_expr,
209-
list,
210-
negated,
217+
expr: left, list, ..
211218
}) => {
212-
if let Some(
213-
Expr::TryCast(TryCast {
214-
expr: internal_left_expr,
215-
..
216-
})
217-
| Expr::Cast(Cast {
218-
expr: internal_left_expr,
219-
..
220-
}),
221-
) = Some(left_expr.as_ref())
222-
{
223-
let internal_left = internal_left_expr.as_ref().clone();
224-
let internal_left_type = internal_left.get_type(&self.schema);
225-
if internal_left_type.is_err() {
226-
// error data type
227-
return Ok(Transformed::no(expr));
228-
}
229-
let internal_left_type = internal_left_type?;
230-
if !is_support_data_type(&internal_left_type) {
231-
// not supported data type
232-
return Ok(Transformed::no(expr));
233-
}
234-
let right_exprs = list
235-
.iter()
236-
.map(|right| {
237-
let right_type = right.get_type(&self.schema)?;
238-
if !is_support_data_type(&right_type) {
239-
return internal_err!(
240-
"The type of list expr {} not support",
241-
&right_type
242-
);
243-
}
244-
match right {
245-
Expr::Literal(right_lit_value) => {
246-
// if the right_lit_value can be casted to the type of internal_left_expr
247-
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
248-
let casted_scalar_value =
249-
try_cast_literal_to_type(right_lit_value, &internal_left_type)?;
250-
if let Some(value) = casted_scalar_value {
251-
Ok(lit(value))
252-
} else {
253-
internal_err!(
254-
"Can't cast the list expr {:?} to type {:?}",
255-
right_lit_value, &internal_left_type
256-
)
257-
}
258-
}
259-
other_expr => internal_err!(
260-
"Only support literal expr to optimize, but the expr is {:?}",
261-
&other_expr
262-
),
263-
}
264-
})
265-
.collect::<Result<Vec<_>>>();
266-
match right_exprs {
267-
Ok(right_exprs) => Ok(Transformed::yes(in_list(
268-
internal_left,
269-
right_exprs,
270-
*negated,
271-
))),
272-
Err(_) => Ok(Transformed::no(expr)),
273-
}
274-
} else {
275-
Ok(Transformed::no(expr))
219+
let (Expr::TryCast(TryCast {
220+
expr: left_expr, ..
221+
})
222+
| Expr::Cast(Cast {
223+
expr: left_expr, ..
224+
})) = left.as_mut()
225+
else {
226+
return Ok(Transformed::no(expr));
227+
};
228+
let Ok(expr_type) = left_expr.get_type(&self.schema) else {
229+
return Ok(Transformed::no(expr));
230+
};
231+
if !is_support_data_type(&expr_type) {
232+
return Ok(Transformed::no(expr));
276233
}
234+
let Ok(right_exprs) = list
235+
.iter()
236+
.map(|right| {
237+
let right_type = right.get_type(&self.schema)?;
238+
if !is_support_data_type(&right_type) {
239+
internal_err!(
240+
"The type of list expr {} is not supported",
241+
&right_type
242+
)?;
243+
}
244+
match right {
245+
Expr::Literal(right_lit_value) => {
246+
// if the right_lit_value can be casted to the type of internal_left_expr
247+
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
248+
let Ok(Some(value)) = try_cast_literal_to_type(right_lit_value, &expr_type) else {
249+
internal_err!(
250+
"Can't cast the list expr {:?} to type {:?}",
251+
right_lit_value, &expr_type
252+
)?
253+
};
254+
Ok(lit(value))
255+
}
256+
other_expr => internal_err!(
257+
"Only support literal expr to optimize, but the expr is {:?}",
258+
&other_expr
259+
),
260+
}
261+
})
262+
.collect::<Result<Vec<_>>>() else {
263+
return Ok(Transformed::no(expr))
264+
};
265+
**left = mem::replace(left_expr, Expr::Literal(ScalarValue::Null));
266+
*list = right_exprs;
267+
Ok(Transformed::yes(expr))
277268
}
278269
// TODO: handle other expr type and dfs visit them
279270
_ => Ok(Transformed::no(expr)),

0 commit comments

Comments
 (0)