1818//! [`UnwrapCastInComparison`] rewrites `CAST(col) = lit` to `col = CAST(lit)`
1919
2020use std:: cmp:: Ordering ;
21+ use std:: mem;
2122use std:: sync:: Arc ;
2223
2324use crate :: optimizer:: ApplyOrder ;
@@ -32,9 +33,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
3233use datafusion_common:: { internal_err, DFSchema , DFSchemaRef , Result , ScalarValue } ;
3334use datafusion_expr:: expr:: { BinaryExpr , Cast , InList , TryCast } ;
3435use datafusion_expr:: utils:: merge_schema;
35- use datafusion_expr:: {
36- binary_expr, in_list, lit, Expr , ExprSchemable , LogicalPlan , Operator ,
37- } ;
36+ use datafusion_expr:: { lit, Expr , ExprSchemable , LogicalPlan , Operator } ;
3837
3938/// [`UnwrapCastInComparison`] attempts to remove casts from
4039/// comparisons to literals ([`ScalarValue`]s) by applying the casts
@@ -140,140 +139,132 @@ struct UnwrapCastExprRewriter {
140139impl TreeNodeRewriter for UnwrapCastExprRewriter {
141140 type Node = Expr ;
142141
143- fn f_up ( & mut self , expr : Expr ) -> Result < Transformed < Expr > > {
144- match & expr {
142+ fn f_up ( & mut self , mut expr : Expr ) -> Result < Transformed < Expr > > {
143+ match & mut expr {
145144 // For case:
146145 // try_cast/cast(expr as data_type) op literal
147146 // literal op try_cast/cast(expr as data_type)
148- Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) => {
149- let left = left. as_ref ( ) . clone ( ) ;
150- let right = right. as_ref ( ) . clone ( ) ;
151- let left_type = left. get_type ( & self . schema ) ?;
152- let right_type = right. get_type ( & self . schema ) ?;
153- // Because the plan has been done the type coercion, the left and right must be equal
154- if is_support_data_type ( & left_type)
155- && is_support_data_type ( & right_type)
156- && is_comparison_op ( op)
157- {
158- match ( & left, & right) {
159- (
160- Expr :: Literal ( left_lit_value) ,
161- Expr :: TryCast ( TryCast { expr, .. } )
162- | Expr :: Cast ( Cast { expr, .. } ) ,
163- ) => {
164- // if the left_lit_value can be casted to the type of expr
165- // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
166- let expr_type = expr. get_type ( & self . schema ) ?;
167- let casted_scalar_value =
168- try_cast_literal_to_type ( left_lit_value, & expr_type) ?;
169- if let Some ( value) = casted_scalar_value {
170- // unwrap the cast/try_cast for the right expr
171- return Ok ( Transformed :: yes ( binary_expr (
172- lit ( value) ,
173- * op,
174- expr. as_ref ( ) . clone ( ) ,
175- ) ) ) ;
176- }
177- }
178- (
179- Expr :: TryCast ( TryCast { expr, .. } )
180- | Expr :: Cast ( Cast { expr, .. } ) ,
181- Expr :: Literal ( right_lit_value) ,
182- ) => {
183- // if the right_lit_value can be casted to the type of expr
184- // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
185- let expr_type = expr. get_type ( & self . schema ) ?;
186- let casted_scalar_value =
187- try_cast_literal_to_type ( right_lit_value, & expr_type) ?;
188- if let Some ( value) = casted_scalar_value {
189- // unwrap the cast/try_cast for the left expr
190- return Ok ( Transformed :: yes ( binary_expr (
191- expr. as_ref ( ) . clone ( ) ,
192- * op,
193- lit ( value) ,
194- ) ) ) ;
195- }
196- }
197- ( _, _) => {
198- // do nothing
199- }
147+ Expr :: BinaryExpr ( BinaryExpr { left, op, right } )
148+ if {
149+ let Ok ( left_type) = left. get_type ( & self . schema ) else {
150+ return Ok ( Transformed :: no ( expr) ) ;
200151 } ;
152+ let Ok ( right_type) = right. get_type ( & self . schema ) else {
153+ return Ok ( Transformed :: no ( expr) ) ;
154+ } ;
155+ is_support_data_type ( & left_type)
156+ && is_support_data_type ( & right_type)
157+ && is_comparison_op ( op)
158+ } =>
159+ {
160+ match ( left. as_mut ( ) , right. as_mut ( ) ) {
161+ (
162+ Expr :: Literal ( left_lit_value) ,
163+ Expr :: TryCast ( TryCast {
164+ expr : right_expr, ..
165+ } )
166+ | Expr :: Cast ( Cast {
167+ expr : right_expr, ..
168+ } ) ,
169+ ) => {
170+ // if the left_lit_value can be casted to the type of expr
171+ // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
172+ let Ok ( expr_type) = right_expr. get_type ( & self . schema ) else {
173+ return Ok ( Transformed :: no ( expr) ) ;
174+ } ;
175+ let Ok ( Some ( value) ) =
176+ try_cast_literal_to_type ( left_lit_value, & expr_type)
177+ else {
178+ return Ok ( Transformed :: no ( expr) ) ;
179+ } ;
180+ * * left = lit ( value) ;
181+ // unwrap the cast/try_cast for the right expr
182+ * * right =
183+ mem:: replace ( right_expr, Expr :: Literal ( ScalarValue :: Null ) ) ;
184+ Ok ( Transformed :: yes ( expr) )
185+ }
186+ (
187+ Expr :: TryCast ( TryCast {
188+ expr : left_expr, ..
189+ } )
190+ | Expr :: Cast ( Cast {
191+ expr : left_expr, ..
192+ } ) ,
193+ Expr :: Literal ( right_lit_value) ,
194+ ) => {
195+ // if the right_lit_value can be casted to the type of expr
196+ // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
197+ let Ok ( expr_type) = left_expr. get_type ( & self . schema ) else {
198+ return Ok ( Transformed :: no ( expr) ) ;
199+ } ;
200+ let Ok ( Some ( value) ) =
201+ try_cast_literal_to_type ( right_lit_value, & expr_type)
202+ else {
203+ return Ok ( Transformed :: no ( expr) ) ;
204+ } ;
205+ // unwrap the cast/try_cast for the left expr
206+ * * left =
207+ mem:: replace ( left_expr, Expr :: Literal ( ScalarValue :: Null ) ) ;
208+ * * right = lit ( value) ;
209+ Ok ( Transformed :: yes ( expr) )
210+ }
211+ _ => Ok ( Transformed :: no ( expr) ) ,
201212 }
202- // return the new binary op
203- Ok ( Transformed :: yes ( binary_expr ( left, * op, right) ) )
204213 }
205214 // For case:
206215 // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
207216 Expr :: InList ( InList {
208- expr : left_expr,
209- list,
210- negated,
217+ expr : left, list, ..
211218 } ) => {
212- if let Some (
213- Expr :: TryCast ( TryCast {
214- expr : internal_left_expr,
215- ..
216- } )
217- | Expr :: Cast ( Cast {
218- expr : internal_left_expr,
219- ..
220- } ) ,
221- ) = Some ( left_expr. as_ref ( ) )
222- {
223- let internal_left = internal_left_expr. as_ref ( ) . clone ( ) ;
224- let internal_left_type = internal_left. get_type ( & self . schema ) ;
225- if internal_left_type. is_err ( ) {
226- // error data type
227- return Ok ( Transformed :: no ( expr) ) ;
228- }
229- let internal_left_type = internal_left_type?;
230- if !is_support_data_type ( & internal_left_type) {
231- // not supported data type
232- return Ok ( Transformed :: no ( expr) ) ;
233- }
234- let right_exprs = list
235- . iter ( )
236- . map ( |right| {
237- let right_type = right. get_type ( & self . schema ) ?;
238- if !is_support_data_type ( & right_type) {
239- return internal_err ! (
240- "The type of list expr {} not support" ,
241- & right_type
242- ) ;
243- }
244- match right {
245- Expr :: Literal ( right_lit_value) => {
246- // if the right_lit_value can be casted to the type of internal_left_expr
247- // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
248- let casted_scalar_value =
249- try_cast_literal_to_type ( right_lit_value, & internal_left_type) ?;
250- if let Some ( value) = casted_scalar_value {
251- Ok ( lit ( value) )
252- } else {
253- internal_err ! (
254- "Can't cast the list expr {:?} to type {:?}" ,
255- right_lit_value, & internal_left_type
256- )
257- }
258- }
259- other_expr => internal_err ! (
260- "Only support literal expr to optimize, but the expr is {:?}" ,
261- & other_expr
262- ) ,
263- }
264- } )
265- . collect :: < Result < Vec < _ > > > ( ) ;
266- match right_exprs {
267- Ok ( right_exprs) => Ok ( Transformed :: yes ( in_list (
268- internal_left,
269- right_exprs,
270- * negated,
271- ) ) ) ,
272- Err ( _) => Ok ( Transformed :: no ( expr) ) ,
273- }
274- } else {
275- Ok ( Transformed :: no ( expr) )
219+ let ( Expr :: TryCast ( TryCast {
220+ expr : left_expr, ..
221+ } )
222+ | Expr :: Cast ( Cast {
223+ expr : left_expr, ..
224+ } ) ) = left. as_mut ( )
225+ else {
226+ return Ok ( Transformed :: no ( expr) ) ;
227+ } ;
228+ let Ok ( expr_type) = left_expr. get_type ( & self . schema ) else {
229+ return Ok ( Transformed :: no ( expr) ) ;
230+ } ;
231+ if !is_support_data_type ( & expr_type) {
232+ return Ok ( Transformed :: no ( expr) ) ;
276233 }
234+ let Ok ( right_exprs) = list
235+ . iter ( )
236+ . map ( |right| {
237+ let right_type = right. get_type ( & self . schema ) ?;
238+ if !is_support_data_type ( & right_type) {
239+ internal_err ! (
240+ "The type of list expr {} is not supported" ,
241+ & right_type
242+ ) ?;
243+ }
244+ match right {
245+ Expr :: Literal ( right_lit_value) => {
246+ // if the right_lit_value can be casted to the type of internal_left_expr
247+ // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
248+ let Ok ( Some ( value) ) = try_cast_literal_to_type ( right_lit_value, & expr_type) else {
249+ internal_err ! (
250+ "Can't cast the list expr {:?} to type {:?}" ,
251+ right_lit_value, & expr_type
252+ ) ?
253+ } ;
254+ Ok ( lit ( value) )
255+ }
256+ other_expr => internal_err ! (
257+ "Only support literal expr to optimize, but the expr is {:?}" ,
258+ & other_expr
259+ ) ,
260+ }
261+ } )
262+ . collect :: < Result < Vec < _ > > > ( ) else {
263+ return Ok ( Transformed :: no ( expr) )
264+ } ;
265+ * * left = mem:: replace ( left_expr, Expr :: Literal ( ScalarValue :: Null ) ) ;
266+ * list = right_exprs;
267+ Ok ( Transformed :: yes ( expr) )
277268 }
278269 // TODO: handle other expr type and dfs visit them
279270 _ => Ok ( Transformed :: no ( expr) ) ,
0 commit comments