Skip to content

Commit 6d3854f

Browse files
pepijnvealamb
andauthored
#17972 Restore case expr/expr optimisation while ensuring lazy evaluation (#17973)
* #17972 Restore case expr/expr optimisation while ensuring lazy evaluation * Avoid calling `PhysicalExpr::evaluate` from `PhysicalExpr::evaluate_selection` for empty selections. * Make `PhysicalExpr::evaluate_selection` correctly handle empty input sets and all false filters * Reorganize code to avoid scatter codepath when using `evaluate` fast path. * Clarify comments in case * Move null handling after true count check. * Tweaking comments * Add unit tests to help define the boundary case behaviour of evaluate_selection * Code polishing - Add extra comments - Use match for the scatter paragraph - Validate that the size of selection and batch match * Fix clippy errors * Add additional case SLTs --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent a65a2cb commit 6d3854f

3 files changed

Lines changed: 282 additions & 28 deletions

File tree

datafusion/physical-expr-common/src/physical_expr.rs

Lines changed: 244 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ use std::sync::Arc;
2323

2424
use crate::utils::scatter;
2525

26-
use arrow::array::{ArrayRef, BooleanArray};
26+
use arrow::array::{new_empty_array, ArrayRef, BooleanArray};
2727
use arrow::compute::filter_record_batch;
2828
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
2929
use arrow::record_batch::RecordBatch;
3030
use datafusion_common::tree_node::{
3131
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
3232
};
33-
use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
33+
use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue};
3434
use datafusion_expr_common::columnar_value::ColumnarValue;
3535
use datafusion_expr_common::interval_arithmetic::Interval;
3636
use datafusion_expr_common::sort_properties::ExprProperties;
@@ -90,36 +90,69 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
9090
self.nullable(input_schema)?,
9191
)))
9292
}
93-
/// Evaluate an expression against a RecordBatch after first applying a
94-
/// validity array
93+
/// Evaluate an expression against a RecordBatch after first applying a validity array
94+
///
95+
/// # Errors
96+
///
97+
/// Returns an `Err` if the expression could not be evaluated or if the length of the
98+
/// `selection` validity array and the number of rows in `batch` are not equal.
9599
fn evaluate_selection(
96100
&self,
97101
batch: &RecordBatch,
98102
selection: &BooleanArray,
99103
) -> Result<ColumnarValue> {
100-
let tmp_batch = filter_record_batch(batch, selection)?;
101-
102-
let tmp_result = self.evaluate(&tmp_batch)?;
103-
104-
if batch.num_rows() == tmp_batch.num_rows() {
105-
// All values from the `selection` filter are true.
106-
Ok(tmp_result)
107-
} else if let ColumnarValue::Array(a) = tmp_result {
108-
scatter(selection, a.as_ref()).map(ColumnarValue::Array)
109-
} else if let ColumnarValue::Scalar(ScalarValue::Boolean(value)) = &tmp_result {
110-
// When the scalar is true or false, skip the scatter process
111-
if let Some(v) = value {
112-
if *v {
113-
Ok(ColumnarValue::from(Arc::new(selection.clone()) as ArrayRef))
104+
let row_count = batch.num_rows();
105+
if row_count != selection.len() {
106+
return exec_err!("Selection array length does not match batch row count: {} != {row_count}", selection.len());
107+
}
108+
109+
let selection_count = selection.true_count();
110+
111+
// First, check if we can avoid filtering altogether.
112+
if selection_count == row_count {
113+
// All values from the `selection` filter are true and match the input batch.
114+
// No need to perform any filtering.
115+
return self.evaluate(batch);
116+
}
117+
118+
// Next, prepare the result array for each 'true' row in the selection vector.
119+
let filtered_result = if selection_count == 0 {
120+
// Do not call `evaluate` when the selection is empty.
121+
// `evaluate_selection` is used to conditionally evaluate expressions.
122+
// When the expression in question is fallible, evaluating it with an empty
123+
// record batch may trigger a runtime error (e.g. division by zero).
124+
//
125+
// Instead, create an empty array matching the expected return type.
126+
let datatype = self.data_type(batch.schema_ref().as_ref())?;
127+
ColumnarValue::Array(new_empty_array(&datatype))
128+
} else {
129+
// If we reach this point, there's no other option than to filter the batch.
130+
// This is a fairly costly operation since it requires creating partial copies
131+
// (worst case of length `row_count - 1`) of all the arrays in the record batch.
132+
// The resulting `filtered_batch` will contain `selection_count` rows.
133+
let filtered_batch = filter_record_batch(batch, selection)?;
134+
self.evaluate(&filtered_batch)?
135+
};
136+
137+
// Finally, scatter the filtered result array so that the indices match the input rows again.
138+
match &filtered_result {
139+
ColumnarValue::Array(a) => {
140+
scatter(selection, a.as_ref()).map(ColumnarValue::Array)
141+
}
142+
ColumnarValue::Scalar(ScalarValue::Boolean(value)) => {
143+
// When the scalar is true or false, skip the scatter process
144+
if let Some(v) = value {
145+
if *v {
146+
Ok(ColumnarValue::from(Arc::new(selection.clone()) as ArrayRef))
147+
} else {
148+
Ok(filtered_result)
149+
}
114150
} else {
115-
Ok(tmp_result)
151+
let array = BooleanArray::from(vec![None; row_count]);
152+
scatter(selection, &array).map(ColumnarValue::Array)
116153
}
117-
} else {
118-
let array = BooleanArray::from(vec![None; batch.num_rows()]);
119-
scatter(selection, &array).map(ColumnarValue::Array)
120154
}
121-
} else {
122-
Ok(tmp_result)
155+
ColumnarValue::Scalar(_) => Ok(filtered_result),
123156
}
124157
}
125158

@@ -601,3 +634,190 @@ pub fn is_volatile(expr: &Arc<dyn PhysicalExpr>) -> bool {
601634
.expect("infallible closure should not fail");
602635
is_volatile
603636
}
637+
638+
#[cfg(test)]
639+
mod test {
640+
use crate::physical_expr::PhysicalExpr;
641+
use arrow::array::{Array, BooleanArray, Int64Array, RecordBatch};
642+
use arrow::datatypes::{DataType, Schema};
643+
use datafusion_expr_common::columnar_value::ColumnarValue;
644+
use std::fmt::{Display, Formatter};
645+
use std::sync::Arc;
646+
647+
#[derive(Debug, PartialEq, Eq, Hash)]
648+
struct TestExpr {}
649+
650+
impl PhysicalExpr for TestExpr {
651+
fn as_any(&self) -> &dyn std::any::Any {
652+
self
653+
}
654+
655+
fn data_type(&self, _schema: &Schema) -> datafusion_common::Result<DataType> {
656+
Ok(DataType::Int64)
657+
}
658+
659+
fn nullable(&self, _schema: &Schema) -> datafusion_common::Result<bool> {
660+
Ok(false)
661+
}
662+
663+
fn evaluate(
664+
&self,
665+
batch: &RecordBatch,
666+
) -> datafusion_common::Result<ColumnarValue> {
667+
let data = vec![1; batch.num_rows()];
668+
Ok(ColumnarValue::Array(Arc::new(Int64Array::from(data))))
669+
}
670+
671+
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
672+
vec![]
673+
}
674+
675+
fn with_new_children(
676+
self: Arc<Self>,
677+
_children: Vec<Arc<dyn PhysicalExpr>>,
678+
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
679+
Ok(Arc::new(Self {}))
680+
}
681+
682+
fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
683+
f.write_str("TestExpr")
684+
}
685+
}
686+
687+
impl Display for TestExpr {
688+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
689+
self.fmt_sql(f)
690+
}
691+
}
692+
693+
macro_rules! assert_arrays_eq {
694+
($EXPECTED: expr, $ACTUAL: expr, $MESSAGE: expr) => {
695+
let expected = $EXPECTED.to_array(1).unwrap();
696+
let actual = $ACTUAL;
697+
698+
let actual_array = actual.to_array(expected.len()).unwrap();
699+
let actual_ref = actual_array.as_ref();
700+
let expected_ref = expected.as_ref();
701+
assert!(
702+
actual_ref == expected_ref,
703+
"{}: expected: {:?}, actual: {:?}",
704+
$MESSAGE,
705+
$EXPECTED,
706+
actual_ref
707+
);
708+
};
709+
}
710+
711+
fn test_evaluate_selection(
712+
batch: &RecordBatch,
713+
selection: &BooleanArray,
714+
expected: &ColumnarValue,
715+
) {
716+
let expr = TestExpr {};
717+
718+
// First check that the `evaluate_selection` result is the expected one
719+
let selection_result = expr.evaluate_selection(batch, selection).unwrap();
720+
assert_eq!(
721+
expected.to_array(1).unwrap().len(),
722+
selection_result.to_array(1).unwrap().len(),
723+
"evaluate_selection should output row count should match input record batch"
724+
);
725+
assert_arrays_eq!(
726+
expected,
727+
&selection_result,
728+
"evaluate_selection returned unexpected value"
729+
);
730+
731+
// If we're selecting all rows, the result should be the same as calling `evaluate`
732+
// with the full record batch.
733+
if (0..batch.num_rows())
734+
.all(|row_idx| row_idx < selection.len() && selection.value(row_idx))
735+
{
736+
let empty_result = expr.evaluate(batch).unwrap();
737+
738+
assert_arrays_eq!(
739+
empty_result,
740+
&selection_result,
741+
"evaluate_selection does not match unfiltered evaluate result"
742+
);
743+
}
744+
}
745+
746+
fn test_evaluate_selection_error(batch: &RecordBatch, selection: &BooleanArray) {
747+
let expr = TestExpr {};
748+
749+
// `evaluate_selection` is expected to return an error for this input
750+
let selection_result = expr.evaluate_selection(batch, selection);
751+
assert!(selection_result.is_err(), "evaluate_selection should fail");
752+
}
753+
754+
#[test]
755+
pub fn test_evaluate_selection_with_empty_record_batch() {
756+
test_evaluate_selection(
757+
&RecordBatch::new_empty(Arc::new(Schema::empty())),
758+
&BooleanArray::from(vec![false; 0]),
759+
&ColumnarValue::Array(Arc::new(Int64Array::new_null(0))),
760+
);
761+
}
762+
763+
#[test]
764+
pub fn test_evaluate_selection_with_empty_record_batch_with_larger_false_selection() {
765+
test_evaluate_selection_error(
766+
&RecordBatch::new_empty(Arc::new(Schema::empty())),
767+
&BooleanArray::from(vec![false; 10]),
768+
);
769+
}
770+
771+
#[test]
772+
pub fn test_evaluate_selection_with_empty_record_batch_with_larger_true_selection() {
773+
test_evaluate_selection_error(
774+
&RecordBatch::new_empty(Arc::new(Schema::empty())),
775+
&BooleanArray::from(vec![true; 10]),
776+
);
777+
}
778+
779+
#[test]
780+
pub fn test_evaluate_selection_with_non_empty_record_batch() {
781+
test_evaluate_selection(
782+
unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) },
783+
&BooleanArray::from(vec![true; 10]),
784+
&ColumnarValue::Array(Arc::new(Int64Array::from(vec![1; 10]))),
785+
);
786+
}
787+
788+
#[test]
789+
pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_false_selection(
790+
) {
791+
test_evaluate_selection_error(
792+
unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) },
793+
&BooleanArray::from(vec![false; 20]),
794+
);
795+
}
796+
797+
#[test]
798+
pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_true_selection(
799+
) {
800+
test_evaluate_selection_error(
801+
unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) },
802+
&BooleanArray::from(vec![true; 20]),
803+
);
804+
}
805+
806+
#[test]
807+
pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_false_selection(
808+
) {
809+
test_evaluate_selection_error(
810+
unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) },
811+
&BooleanArray::from(vec![false; 5]),
812+
);
813+
}
814+
815+
#[test]
816+
pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_true_selection(
817+
) {
818+
test_evaluate_selection_error(
819+
unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) },
820+
&BooleanArray::from(vec![true; 5]),
821+
);
822+
}
823+
}

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,7 @@ impl CaseExpr {
155155
&& else_expr.as_ref().unwrap().as_any().is::<Literal>()
156156
{
157157
EvalMethod::ScalarOrScalar
158-
} else if when_then_expr.len() == 1
159-
&& is_cheap_and_infallible(&(when_then_expr[0].1))
160-
&& else_expr.as_ref().is_some_and(is_cheap_and_infallible)
161-
{
158+
} else if when_then_expr.len() == 1 && else_expr.is_some() {
162159
EvalMethod::ExpressionOrExpression
163160
} else {
164161
EvalMethod::NoExpression
@@ -425,6 +422,16 @@ impl CaseExpr {
425422
)
426423
})?;
427424

425+
// For the true and false/null selection vectors, bypass `evaluate_selection` and merging
426+
// results. This avoids materializing the array for the other branch which we will discard
427+
// entirely anyway.
428+
let true_count = when_value.true_count();
429+
if true_count == batch.num_rows() {
430+
return self.when_then_expr[0].1.evaluate(batch);
431+
} else if true_count == 0 {
432+
return self.else_expr.as_ref().unwrap().evaluate(batch);
433+
}
434+
428435
// Treat 'NULL' as false value
429436
let when_value = match when_value.null_count() {
430437
0 => Cow::Borrowed(when_value),

datafusion/sqllogictest/test_files/case.slt

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,19 +467,46 @@ FROM t;
467467
----
468468
[{foo: blarg}]
469469

470+
# mix of then and else
470471
query II
471472
SELECT v, CASE WHEN v != 0 THEN 10/v ELSE 42 END FROM (VALUES (0), (1), (2)) t(v)
472473
----
473474
0 42
474475
1 10
475476
2 5
476477

478+
# when expression is always false, the then branch should never be evaluated
477479
query II
478480
SELECT v, CASE WHEN v < 0 THEN 10/0 ELSE 1 END FROM (VALUES (1), (2)) t(v)
479481
----
480482
1 1
481483
2 1
482484

485+
# when expression is always true, the else branch should never be evaluated
486+
query II
487+
SELECT v, CASE WHEN v > 0 THEN 1 ELSE 10/0 END FROM (VALUES (1), (2)) t(v)
488+
----
489+
1 1
490+
2 1
491+
492+
493+
# lazy evaluation of multiple when branches, else branch should never be evaluated
494+
query II
495+
SELECT v, CASE WHEN v == 1 THEN -1 WHEN v == 2 THEN -2 WHEN v == 3 THEN -3 ELSE 10/0 END FROM (VALUES (1), (2), (3)) t(v)
496+
----
497+
1 -1
498+
2 -2
499+
3 -3
500+
501+
# covers the InfallibleExprOrNull evaluation strategy
502+
query II
503+
SELECT v, CASE WHEN v THEN 1 END FROM (VALUES (1), (2), (3), (NULL)) t(v)
504+
----
505+
1 1
506+
2 1
507+
3 1
508+
NULL NULL
509+
483510
statement ok
484511
drop table t
485512

0 commit comments

Comments
 (0)