diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index 485ab0f6e7465..671deb0276b6a 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -16,11 +16,10 @@ // under the License. use crate::utils::{ - exprs_to_join_cols, find_join_exprs, only_or_err, split_conjunction, - verify_not_disjunction, + exprs_to_join_cols, find_join_exprs, split_conjunction, verify_not_disjunction, }; use crate::{utils, OptimizerConfig, OptimizerRule}; -use datafusion_common::{context, plan_err}; +use datafusion_common::{context, plan_err, DataFusionError}; use datafusion_expr::logical_plan::{Filter, JoinType, Subquery}; use datafusion_expr::{combine_filters, Expr, LogicalPlan, LogicalPlanBuilder}; use std::sync::Arc; @@ -134,11 +133,23 @@ fn optimize_exists( outer_input: &LogicalPlan, outer_other_exprs: &[Expr], ) -> datafusion_common::Result { - let subqry_inputs = query_info.query.subquery.inputs(); - let subqry_input = only_or_err(subqry_inputs.as_slice()) - .map_err(|e| context!("single expression projection required", e))?; - let subqry_filter = Filter::try_from_plan(subqry_input) - .map_err(|e| context!("cannot optimize non-correlated subquery", e))?; + let subqry_filter = match query_info.query.subquery.as_ref() { + LogicalPlan::Distinct(subqry_distinct) => match subqry_distinct.input.as_ref() { + LogicalPlan::Projection(subqry_proj) => { + Filter::try_from_plan(&*subqry_proj.input) + } + _ => Err(DataFusionError::NotImplemented( + "Subquery currently only supports distinct or projection".to_string(), + )), + }, + LogicalPlan::Projection(subqry_proj) => { + Filter::try_from_plan(&*subqry_proj.input) + } + _ => Err(DataFusionError::NotImplemented( + "Subquery currently only supports distinct or projection".to_string(), + )), + } + .map_err(|e| context!("cannot optimize non-correlated subquery", e))?; // split into filters let mut subqry_filter_exprs = vec![]; diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index dc452af3be0ac..f6fe685ee2820 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -118,6 +118,21 @@ fn anti_join_with_join_filter() -> Result<()> { Ok(()) } +#[test] +fn where_exists_distinct() -> Result<()> { + // regression test for https://github.com/apache/arrow-datafusion/issues/3724 + let sql = "SELECT * FROM test WHERE EXISTS (\ + SELECT DISTINCT col_int32 FROM test t2 WHERE test.col_int32 = t2.col_int32)"; + let plan = test_sql(sql)?; + let expected = r#"Projection: test.col_int32, test.col_uint32, test.col_utf8, test.col_date32, test.col_date64 + Semi Join: test.col_int32 = t2.col_int32 + TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64] + SubqueryAlias: t2 + TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64]"#; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + #[test] fn intersect() -> Result<()> { let sql = "SELECT col_int32, col_utf8 FROM test \