Skip to content

Commit 858b284

Browse files
Update test to expect remaining filter clause
1 parent 23b0ffb commit 858b284

2 files changed

Lines changed: 52 additions & 47 deletions

File tree

datafusion/core/tests/sql/subqueries.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ async fn tpch_q4_correlated() -> Result<()> {
4646
TableScan: orders projection=[o_orderkey, o_orderpriority]
4747
Projection: #lineitem.l_orderkey
4848
Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[]]
49-
TableScan: lineitem projection=[l_orderkey]"#
49+
Filter: #lineitem.l_commitdate < #lineitem.l_receiptdate
50+
TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate], partial_filters=[#lineitem.l_commitdate < #lineitem.l_receiptdate]"#
5051
.to_string();
5152
assert_eq!(actual, expected);
5253

datafusion/optimizer/src/subquery_decorrelate.rs

Lines changed: 50 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
use crate::{utils, OptimizerConfig, OptimizerRule};
2-
use datafusion_common::{Column};
2+
use datafusion_common::Column;
33
use datafusion_expr::logical_plan::{Filter, JoinType, Subquery};
4-
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Operator};
4+
use datafusion_expr::{combine_filters, Expr, LogicalPlan, LogicalPlanBuilder, Operator};
55
use hashbrown::HashSet;
6-
use std::sync::Arc;
76
use itertools::{Either, Itertools};
7+
use std::sync::Arc;
88

99
/// Optimizer rule for rewriting subquery filters to joins
1010
#[derive(Default)]
@@ -25,7 +25,7 @@ impl OptimizerRule for SubqueryDecorrelate {
2525
) -> datafusion_common::Result<LogicalPlan> {
2626
match plan {
2727
LogicalPlan::Filter(Filter { predicate, input }) => {
28-
return match predicate {
28+
match predicate {
2929
// TODO: arbitrary expressions
3030
Expr::Exists { subquery, negated } => {
3131
if *negated {
@@ -34,7 +34,7 @@ impl OptimizerRule for SubqueryDecorrelate {
3434
optimize_exists(plan, subquery, input)
3535
}
3636
_ => Ok(plan.clone()),
37-
};
37+
}
3838
}
3939
_ => {
4040
// Apply the optimization to all inputs of the plan
@@ -84,7 +84,8 @@ fn optimize_exists(
8484
utils::split_conjunction(&filter.predicate, &mut filters);
8585

8686
// get names of fields TODO: Must fully qualify these!
87-
let fields: HashSet<_> = sub_input.schema()
87+
let fields: HashSet<_> = sub_input
88+
.schema()
8889
.fields()
8990
.iter()
9091
.map(|f| f.name())
@@ -97,18 +98,26 @@ fn optimize_exists(
9798
}
9899

99100
// Only operate if one column is present and the other closed upon from outside scope
100-
let l_col: Vec<_> = cols.iter()
101+
let l_col: Vec<_> = cols
102+
.iter()
101103
.map(|it| &it.0)
102104
.map(|it| Column::from_qualified_name(it.as_str()))
103105
.collect();
104-
let r_col: Vec<_> = cols.iter()
106+
let r_col: Vec<_> = cols
107+
.iter()
105108
.map(|it| &it.1)
106109
.map(|it| Column::from_qualified_name(it.as_str()))
107110
.collect();
108111
let expr: Vec<_> = r_col.iter().map(|it| Expr::Column(it.clone())).collect();
109112
let aggr_expr: Vec<Expr> = vec![];
110-
let join_keys = (l_col.clone(), r_col.clone());
111-
let right = LogicalPlanBuilder::from((*filter.input).clone())
113+
let join_keys = (l_col, r_col);
114+
let right = LogicalPlanBuilder::from((*filter.input).clone());
115+
let right = if let Some(expr) = combine_filters(&others) {
116+
right.filter(expr)?
117+
} else {
118+
right
119+
};
120+
let right = right
112121
.aggregate(expr.clone(), aggr_expr)?
113122
.project(expr)?
114123
.build()?;
@@ -122,43 +131,38 @@ fn find_join_exprs(
122131
filters: Vec<&Expr>,
123132
fields: &HashSet<&String>,
124133
) -> (Vec<(String, String)>, Vec<Expr>) {
125-
let (joins, others): (Vec<_>, Vec<_>) = filters.iter()
126-
.partition_map(|filter| {
127-
let (left, op, right) = match filter {
128-
Expr::BinaryExpr { left, op, right } => {
129-
(*left.clone(), op.clone(), *right.clone())
130-
}
131-
_ => {
132-
return Either::Right((*filter).clone())
133-
}
134-
};
135-
match op {
136-
Operator::Eq => {}
137-
_ => return Either::Right((*filter).clone()),
138-
}
139-
let left = match left {
140-
Expr::Column(c) => c,
141-
_ => return Either::Right((*filter).clone()),
142-
};
143-
let right = match right {
144-
Expr::Column(c) => c,
145-
_ => return Either::Right((*filter).clone()),
146-
};
147-
if fields.contains(&left.name) && fields.contains(&right.name) {
148-
return Either::Right((*filter).clone()); // Need one of each
149-
}
150-
if !fields.contains(&left.name) && !fields.contains(&right.name) {
151-
return Either::Right((*filter).clone()); // Need one of each
152-
}
134+
let (joins, others): (Vec<_>, Vec<_>) = filters.iter().partition_map(|filter| {
135+
let (left, op, right) = match filter {
136+
Expr::BinaryExpr { left, op, right } => (*left.clone(), *op, *right.clone()),
137+
_ => return Either::Right((*filter).clone()),
138+
};
139+
match op {
140+
Operator::Eq => {}
141+
_ => return Either::Right((*filter).clone()),
142+
}
143+
let left = match left {
144+
Expr::Column(c) => c,
145+
_ => return Either::Right((*filter).clone()),
146+
};
147+
let right = match right {
148+
Expr::Column(c) => c,
149+
_ => return Either::Right((*filter).clone()),
150+
};
151+
if fields.contains(&left.name) && fields.contains(&right.name) {
152+
return Either::Right((*filter).clone()); // Need one of each
153+
}
154+
if !fields.contains(&left.name) && !fields.contains(&right.name) {
155+
return Either::Right((*filter).clone()); // Need one of each
156+
}
153157

154-
let sorted = if fields.contains(&left.name) {
155-
(right.name.clone(), left.name.clone())
156-
} else {
157-
(left.name.clone(), right.name.clone())
158-
};
158+
let sorted = if fields.contains(&left.name) {
159+
(right.name, left.name)
160+
} else {
161+
(left.name, right.name)
162+
};
159163

160-
Either::Left(sorted)
161-
});
164+
Either::Left(sorted)
165+
});
162166

163167
(joins, others)
164-
}
168+
}

0 commit comments

Comments
 (0)