diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 30140101df7b1..05a4f18e63ab2 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3364,6 +3364,12 @@ mod tests { ); assert_eq!(simplify(expr.clone()), lit(true)); + // 3.5 c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (4, 5, 6, 7) -> c1 != 4 (4 overlaps) + let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or( + in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true), + ); + assert_eq!(simplify(expr.clone()), col("c1").not_eq(lit(4))); + // 4. c1 NOT IN (1,2,3,4) AND c1 NOT IN (4,5,6,7) -> c1 NOT IN (1,2,3,4,5,6,7) let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and( in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true), @@ -3457,6 +3463,7 @@ mod tests { true, ))); // TODO: Further simplify this expression + // https://github.com/apache/arrow-datafusion/issues/8970 // assert_eq!(simplify(expr.clone()), lit(true)); assert_eq!(simplify(expr.clone()), expr); } diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index fa95f1688e6f4..e9ce2734636c4 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -52,85 +52,92 @@ impl TreeNodeRewriter for InListSimplifier { type N = Expr; fn mutate(&mut self, expr: Expr) -> Result { - if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr { - if let (Expr::InList(l1), Operator::And, Expr::InList(l2)) = - (left.as_ref(), op, right.as_ref()) - { - if l1.expr == l2.expr && !l1.negated && !l2.negated { - return inlist_intersection(l1, l2, false); - } else if l1.expr == l2.expr && l1.negated && l2.negated { - return inlist_union(l1, l2, true); - } else if l1.expr == l2.expr && !l1.negated && l2.negated { - return inlist_except(l1, l2); - } else if l1.expr == l2.expr && l1.negated && !l2.negated { - return inlist_except(l2, l1); + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr { + match (*left, op, *right) { + (Expr::InList(l1), Operator::And, Expr::InList(l2)) + if l1.expr == l2.expr && !l1.negated && !l2.negated => + { + inlist_intersection(l1, l2, false) } - } else if let (Expr::InList(l1), Operator::Or, Expr::InList(l2)) = - (left.as_ref(), op, right.as_ref()) - { - if l1.expr == l2.expr && l1.negated && l2.negated { - return inlist_intersection(l1, l2, true); + (Expr::InList(l1), Operator::And, Expr::InList(l2)) + if l1.expr == l2.expr && l1.negated && l2.negated => + { + inlist_union(l1, l2, true) + } + (Expr::InList(l1), Operator::And, Expr::InList(l2)) + if l1.expr == l2.expr && !l1.negated && l2.negated => + { + inlist_except(l1, l2) + } + (Expr::InList(l1), Operator::And, Expr::InList(l2)) + if l1.expr == l2.expr && l1.negated && !l2.negated => + { + inlist_except(l2, l1) + } + (Expr::InList(l1), Operator::Or, Expr::InList(l2)) + if l1.expr == l2.expr && l1.negated && l2.negated => + { + inlist_intersection(l1, l2, true) + } + (left, op, right) => { + // put the expression back together + Ok(Expr::BinaryExpr(BinaryExpr { + left: Box::new(left), + op, + right: Box::new(right), + })) } } + } else { + Ok(expr) } - - Ok(expr) } } -fn inlist_union(l1: &InList, l2: &InList, negated: bool) -> Result { - let mut seen: HashSet = HashSet::new(); - let list = l1 - .list - .iter() - .chain(l2.list.iter()) - .filter(|&e| seen.insert(e.to_owned())) - .cloned() - .collect::>(); - let merged_inlist = InList { - expr: l1.expr.clone(), - list, - negated, - }; - Ok(Expr::InList(merged_inlist)) -} +/// Return the union of two inlist expressions +/// maintaining the order of the elements in the two lists +fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { + // extend the list in l1 with the elements in l2 that are not already in l1 + let l1_items: HashSet<_> = l1.list.iter().collect(); -fn inlist_intersection(l1: &InList, l2: &InList, negated: bool) -> Result { - let l1_set: HashSet = l1.list.iter().cloned().collect(); - let intersect_list: Vec = l2 + // keep all l2 items that do not also appear in l1 + let keep_l2: Vec<_> = l2 .list - .iter() - .filter(|x| l1_set.contains(x)) - .cloned() + .into_iter() + .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) }) .collect(); + + l1.list.extend(keep_l2); + l1.negated = negated; + Ok(Expr::InList(l1)) +} + +/// Return the intersection of two inlist expressions +/// maintaining the order of the elements in the two lists +fn inlist_intersection(mut l1: InList, l2: InList, negated: bool) -> Result { + let l2_items = l2.list.iter().collect::>(); + + // remove all items from l1 that are not in l2 + l1.list.retain(|e| l2_items.contains(e)); + // e in () is always false // e not in () is always true - if intersect_list.is_empty() { + if l1.list.is_empty() { return Ok(lit(negated)); } - let merged_inlist = InList { - expr: l1.expr.clone(), - list: intersect_list, - negated, - }; - Ok(Expr::InList(merged_inlist)) + Ok(Expr::InList(l1)) } -fn inlist_except(l1: &InList, l2: &InList) -> Result { - let l2_set: HashSet = l2.list.iter().cloned().collect(); - let except_list: Vec = l1 - .list - .iter() - .filter(|x| !l2_set.contains(x)) - .cloned() - .collect(); - if except_list.is_empty() { +/// Return the all items in l1 that are not in l2 +/// maintaining the order of the elements in the two lists +fn inlist_except(mut l1: InList, l2: InList) -> Result { + let l2_items = l2.list.iter().collect::>(); + + // keep only items from l1 that are not in l2 + l1.list.retain(|e| !l2_items.contains(e)); + + if l1.list.is_empty() { return Ok(lit(false)); } - let merged_inlist = InList { - expr: l1.expr.clone(), - list: except_list, - negated: false, - }; - Ok(Expr::InList(merged_inlist)) + Ok(Expr::InList(l1)) }