Skip to content

Commit 61834d4

Browse files
authored
Implement basic common subexpression eliminate optimization (#792)
* basic impl desc identifier Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * expr rewriter Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * fix: get expr's type from input schema Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * clean & doc Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * fix: alias projection & rewriter check expr count Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * fill blank exprs & plans Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * fix: correct column expr's name in generating proj plan Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * fix: step index in fast path Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * add rewriter control enum, fix create_name Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * fix topk test case Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * add LogicalPlan::Analyze into optimize() * more tests * fix unprojected filter
1 parent 6402200 commit 61834d4

6 files changed

Lines changed: 893 additions & 16 deletions

File tree

datafusion/src/execution/context.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ use crate::execution::dataframe_impl::DataFrameImpl;
5757
use crate::logical_plan::{
5858
FunctionRegistry, LogicalPlan, LogicalPlanBuilder, UNNAMED_TABLE,
5959
};
60+
use crate::optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
6061
use crate::optimizer::constant_folding::ConstantFolding;
6162
use crate::optimizer::filter_push_down::FilterPushDown;
6263
use crate::optimizer::limit_push_down::LimitPushDown;
@@ -741,6 +742,7 @@ impl Default for ExecutionConfig {
741742
batch_size: 8192,
742743
optimizers: vec![
743744
Arc::new(ConstantFolding::new()),
745+
Arc::new(CommonSubexprEliminate::new()),
744746
Arc::new(EliminateLimit::new()),
745747
Arc::new(ProjectionPushDown::new()),
746748
Arc::new(FilterPushDown::new()),
@@ -1020,6 +1022,7 @@ impl FunctionRegistry for ExecutionContextState {
10201022
mod tests {
10211023

10221024
use super::*;
1025+
use crate::logical_plan::{binary_expr, lit, Operator};
10231026
use crate::physical_plan::functions::make_scalar_function;
10241027
use crate::physical_plan::{collect, collect_partitioned};
10251028
use crate::test;
@@ -1998,6 +2001,27 @@ mod tests {
19982001
Ok(())
19992002
}
20002003

2004+
#[tokio::test]
2005+
async fn aggregate_avg_add() -> Result<()> {
2006+
let results = execute(
2007+
"SELECT AVG(c1), AVG(c1) + 1, AVG(c1) + 2, 1 + AVG(c1) FROM test",
2008+
4,
2009+
)
2010+
.await?;
2011+
assert_eq!(results.len(), 1);
2012+
2013+
let expected = vec![
2014+
"+--------------+----------------------------+----------------------------+----------------------------+",
2015+
"| AVG(test.c1) | AVG(test.c1) Plus Int64(1) | AVG(test.c1) Plus Int64(2) | Int64(1) Plus AVG(test.c1) |",
2016+
"+--------------+----------------------------+----------------------------+----------------------------+",
2017+
"| 1.5 | 2.5 | 3.5 | 2.5 |",
2018+
"+--------------+----------------------------+----------------------------+----------------------------+",
2019+
];
2020+
assert_batches_sorted_eq!(expected, &results);
2021+
2022+
Ok(())
2023+
}
2024+
20012025
#[tokio::test]
20022026
async fn join_partitioned() -> Result<()> {
20032027
// self join on partition id (workaround for duplicate column name)
@@ -2166,6 +2190,30 @@ mod tests {
21662190
}
21672191
}
21682192

2193+
#[tokio::test]
2194+
async fn unprojected_filter() {
2195+
let mut ctx = ExecutionContext::new();
2196+
let df = ctx
2197+
.read_table(test::table_with_sequence(1, 3).unwrap())
2198+
.unwrap();
2199+
2200+
let df = df
2201+
.select(vec![binary_expr(col("i"), Operator::Plus, col("i"))])
2202+
.unwrap()
2203+
.filter(col("i").gt(lit(2)))
2204+
.unwrap();
2205+
let results = df.collect().await.unwrap();
2206+
2207+
let expected = vec![
2208+
"+--------------------------+",
2209+
"| ?table?.i Plus ?table?.i |",
2210+
"+--------------------------+",
2211+
"| 6 |",
2212+
"+--------------------------+",
2213+
];
2214+
assert_batches_sorted_eq!(expected, &results);
2215+
}
2216+
21692217
#[tokio::test]
21702218
async fn group_by_dictionary() {
21712219
async fn run_test_case<K: ArrowDictionaryKeyType>() {

datafusion/src/logical_plan/expr.rs

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -804,8 +804,11 @@ impl Expr {
804804
where
805805
R: ExprRewriter,
806806
{
807-
if !rewriter.pre_visit(&self)? {
808-
return Ok(self);
807+
let need_mutate = match rewriter.pre_visit(&self)? {
808+
RewriteRecursion::Mutate => return rewriter.mutate(self),
809+
RewriteRecursion::Stop => return Ok(self),
810+
RewriteRecursion::Continue => true,
811+
RewriteRecursion::Skip => false,
809812
};
810813

811814
// recurse into all sub expressions(and cover all expression types)
@@ -915,14 +918,18 @@ impl Expr {
915918
negated,
916919
} => Expr::InList {
917920
expr: rewrite_boxed(expr, rewriter)?,
918-
list,
921+
list: rewrite_vec(list, rewriter)?,
919922
negated,
920923
},
921924
Expr::Wildcard => Expr::Wildcard,
922925
};
923926

924927
// now rewrite this expression itself
925-
rewriter.mutate(expr)
928+
if need_mutate {
929+
rewriter.mutate(expr)
930+
} else {
931+
Ok(expr)
932+
}
926933
}
927934
}
928935

@@ -990,15 +997,27 @@ pub trait ExpressionVisitor: Sized {
990997
}
991998
}
992999

1000+
/// Controls how the [ExprRewriter] recursion should proceed.
1001+
pub enum RewriteRecursion {
1002+
/// Continue rewrite / visit this expression.
1003+
Continue,
1004+
/// Call [mutate()] immediately and return.
1005+
Mutate,
1006+
/// Do not rewrite / visit the children of this expression.
1007+
Stop,
1008+
/// Keep recursive but skip mutate on this expression
1009+
Skip,
1010+
}
1011+
9931012
/// Trait for potentially recursively rewriting an [`Expr`] expression
9941013
/// tree. When passed to `Expr::rewrite`, `ExpressionVisitor::mutate` is
9951014
/// invoked recursively on all nodes of an expression tree. See the
9961015
/// comments on `Expr::rewrite` for details on its use
9971016
pub trait ExprRewriter: Sized {
9981017
/// Invoked before any children of `expr` are rewritten /
999-
/// visited. Default implementation returns `Ok(true)`
1000-
fn pre_visit(&mut self, _expr: &Expr) -> Result<bool> {
1001-
Ok(true)
1018+
/// visited. Default implementation returns `Ok(RewriteRecursion::Continue)`
1019+
fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
1020+
Ok(RewriteRecursion::Continue)
10021021
}
10031022

10041023
/// Invoked after all children of `expr` have been mutated and
@@ -1721,13 +1740,17 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
17211740
} => {
17221741
let mut name = "CASE ".to_string();
17231742
if let Some(e) = expr {
1724-
name += &format!("{:?} ", e);
1743+
let e = create_name(e, input_schema)?;
1744+
name += &format!("{} ", e);
17251745
}
17261746
for (w, t) in when_then_expr {
1727-
name += &format!("WHEN {:?} THEN {:?} ", w, t);
1747+
let when = create_name(w, input_schema)?;
1748+
let then = create_name(t, input_schema)?;
1749+
name += &format!("WHEN {} THEN {} ", when, then);
17281750
}
17291751
if let Some(e) = else_expr {
1730-
name += &format!("ELSE {:?} ", e);
1752+
let e = create_name(e, input_schema)?;
1753+
name += &format!("ELSE {} ", e);
17311754
}
17321755
name += "END";
17331756
Ok(name)
@@ -1887,9 +1910,9 @@ mod tests {
18871910
Ok(expr)
18881911
}
18891912

1890-
fn pre_visit(&mut self, expr: &Expr) -> Result<bool> {
1913+
fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
18911914
self.v.push(format!("Previsited {:?}", expr));
1892-
Ok(true)
1915+
Ok(RewriteRecursion::Continue)
18931916
}
18941917
}
18951918

datafusion/src/logical_plan/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ pub use expr::{
4545
right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part,
4646
sqrt, starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc,
4747
unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter,
48-
ExpressionVisitor, Literal, Recursion,
48+
ExpressionVisitor, Literal, Recursion, RewriteRecursion,
4949
};
5050
pub use extension::UserDefinedLogicalNode;
5151
pub use operators::Operator;

0 commit comments

Comments
 (0)