@@ -23,14 +23,14 @@ use std::sync::Arc;
2323
2424use crate :: utils:: scatter;
2525
26- use arrow:: array:: { ArrayRef , BooleanArray } ;
26+ use arrow:: array:: { new_empty_array , ArrayRef , BooleanArray } ;
2727use arrow:: compute:: filter_record_batch;
2828use arrow:: datatypes:: { DataType , Field , FieldRef , Schema } ;
2929use arrow:: record_batch:: RecordBatch ;
3030use datafusion_common:: tree_node:: {
3131 Transformed , TransformedResult , TreeNode , TreeNodeRecursion ,
3232} ;
33- use datafusion_common:: { internal_err, not_impl_err, Result , ScalarValue } ;
33+ use datafusion_common:: { exec_err , internal_err, not_impl_err, Result , ScalarValue } ;
3434use datafusion_expr_common:: columnar_value:: ColumnarValue ;
3535use datafusion_expr_common:: interval_arithmetic:: Interval ;
3636use datafusion_expr_common:: sort_properties:: ExprProperties ;
@@ -90,36 +90,69 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
9090 self . nullable ( input_schema) ?,
9191 ) ) )
9292 }
93- /// Evaluate an expression against a RecordBatch after first applying a
94- /// validity array
93+ /// Evaluate an expression against a RecordBatch after first applying a validity array
94+ ///
95+ /// # Errors
96+ ///
97+ /// Returns an `Err` if the expression could not be evaluated or if the length of the
98+ /// `selection` validity array and the number of row in `batch` is not equal.
9599 fn evaluate_selection (
96100 & self ,
97101 batch : & RecordBatch ,
98102 selection : & BooleanArray ,
99103 ) -> Result < ColumnarValue > {
100- let tmp_batch = filter_record_batch ( batch, selection) ?;
101-
102- let tmp_result = self . evaluate ( & tmp_batch) ?;
103-
104- if batch. num_rows ( ) == tmp_batch. num_rows ( ) {
105- // All values from the `selection` filter are true.
106- Ok ( tmp_result)
107- } else if let ColumnarValue :: Array ( a) = tmp_result {
108- scatter ( selection, a. as_ref ( ) ) . map ( ColumnarValue :: Array )
109- } else if let ColumnarValue :: Scalar ( ScalarValue :: Boolean ( value) ) = & tmp_result {
110- // When the scalar is true or false, skip the scatter process
111- if let Some ( v) = value {
112- if * v {
113- Ok ( ColumnarValue :: from ( Arc :: new ( selection. clone ( ) ) as ArrayRef ) )
104+ let row_count = batch. num_rows ( ) ;
105+ if row_count != selection. len ( ) {
106+ return exec_err ! ( "Selection array length does not match batch row count: {} != {row_count}" , selection. len( ) ) ;
107+ }
108+
109+ let selection_count = selection. true_count ( ) ;
110+
111+ // First, check if we can avoid filtering altogether.
112+ if selection_count == row_count {
113+ // All values from the `selection` filter are true and match the input batch.
114+ // No need to perform any filtering.
115+ return self . evaluate ( batch) ;
116+ }
117+
118+ // Next, prepare the result array for each 'true' row in the selection vector.
119+ let filtered_result = if selection_count == 0 {
120+ // Do not call `evaluate` when the selection is empty.
121+ // `evaluate_selection` is used to conditionally evaluate expressions.
122+ // When the expression in question is fallible, evaluating it with an empty
123+ // record batch may trigger a runtime error (e.g. division by zero).
124+ //
125+ // Instead, create an empty array matching the expected return type.
126+ let datatype = self . data_type ( batch. schema_ref ( ) . as_ref ( ) ) ?;
127+ ColumnarValue :: Array ( new_empty_array ( & datatype) )
128+ } else {
129+ // If we reach this point, there's no other option than to filter the batch.
130+ // This is a fairly costly operation since it requires creating partial copies
131+ // (worst case of length `row_count - 1`) of all the arrays in the record batch.
132+ // The resulting `filtered_batch` will contain `selection_count` rows.
133+ let filtered_batch = filter_record_batch ( batch, selection) ?;
134+ self . evaluate ( & filtered_batch) ?
135+ } ;
136+
137+ // Finally, scatter the filtered result array so that the indices match the input rows again.
138+ match & filtered_result {
139+ ColumnarValue :: Array ( a) => {
140+ scatter ( selection, a. as_ref ( ) ) . map ( ColumnarValue :: Array )
141+ }
142+ ColumnarValue :: Scalar ( ScalarValue :: Boolean ( value) ) => {
143+ // When the scalar is true or false, skip the scatter process
144+ if let Some ( v) = value {
145+ if * v {
146+ Ok ( ColumnarValue :: from ( Arc :: new ( selection. clone ( ) ) as ArrayRef ) )
147+ } else {
148+ Ok ( filtered_result)
149+ }
114150 } else {
115- Ok ( tmp_result)
151+ let array = BooleanArray :: from ( vec ! [ None ; row_count] ) ;
152+ scatter ( selection, & array) . map ( ColumnarValue :: Array )
116153 }
117- } else {
118- let array = BooleanArray :: from ( vec ! [ None ; batch. num_rows( ) ] ) ;
119- scatter ( selection, & array) . map ( ColumnarValue :: Array )
120154 }
121- } else {
122- Ok ( tmp_result)
155+ ColumnarValue :: Scalar ( _) => Ok ( filtered_result) ,
123156 }
124157 }
125158
@@ -601,3 +634,190 @@ pub fn is_volatile(expr: &Arc<dyn PhysicalExpr>) -> bool {
601634 . expect ( "infallible closure should not fail" ) ;
602635 is_volatile
603636}
637+
638+ #[ cfg( test) ]
639+ mod test {
640+ use crate :: physical_expr:: PhysicalExpr ;
641+ use arrow:: array:: { Array , BooleanArray , Int64Array , RecordBatch } ;
642+ use arrow:: datatypes:: { DataType , Schema } ;
643+ use datafusion_expr_common:: columnar_value:: ColumnarValue ;
644+ use std:: fmt:: { Display , Formatter } ;
645+ use std:: sync:: Arc ;
646+
647+ #[ derive( Debug , PartialEq , Eq , Hash ) ]
648+ struct TestExpr { }
649+
650+ impl PhysicalExpr for TestExpr {
651+ fn as_any ( & self ) -> & dyn std:: any:: Any {
652+ self
653+ }
654+
655+ fn data_type ( & self , _schema : & Schema ) -> datafusion_common:: Result < DataType > {
656+ Ok ( DataType :: Int64 )
657+ }
658+
659+ fn nullable ( & self , _schema : & Schema ) -> datafusion_common:: Result < bool > {
660+ Ok ( false )
661+ }
662+
663+ fn evaluate (
664+ & self ,
665+ batch : & RecordBatch ,
666+ ) -> datafusion_common:: Result < ColumnarValue > {
667+ let data = vec ! [ 1 ; batch. num_rows( ) ] ;
668+ Ok ( ColumnarValue :: Array ( Arc :: new ( Int64Array :: from ( data) ) ) )
669+ }
670+
671+ fn children ( & self ) -> Vec < & Arc < dyn PhysicalExpr > > {
672+ vec ! [ ]
673+ }
674+
675+ fn with_new_children (
676+ self : Arc < Self > ,
677+ _children : Vec < Arc < dyn PhysicalExpr > > ,
678+ ) -> datafusion_common:: Result < Arc < dyn PhysicalExpr > > {
679+ Ok ( Arc :: new ( Self { } ) )
680+ }
681+
682+ fn fmt_sql ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
683+ f. write_str ( "TestExpr" )
684+ }
685+ }
686+
687+ impl Display for TestExpr {
688+ fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
689+ self . fmt_sql ( f)
690+ }
691+ }
692+
693+ macro_rules! assert_arrays_eq {
694+ ( $EXPECTED: expr, $ACTUAL: expr, $MESSAGE: expr) => {
695+ let expected = $EXPECTED. to_array( 1 ) . unwrap( ) ;
696+ let actual = $ACTUAL;
697+
698+ let actual_array = actual. to_array( expected. len( ) ) . unwrap( ) ;
699+ let actual_ref = actual_array. as_ref( ) ;
700+ let expected_ref = expected. as_ref( ) ;
701+ assert!(
702+ actual_ref == expected_ref,
703+ "{}: expected: {:?}, actual: {:?}" ,
704+ $MESSAGE,
705+ $EXPECTED,
706+ actual_ref
707+ ) ;
708+ } ;
709+ }
710+
711+ fn test_evaluate_selection (
712+ batch : & RecordBatch ,
713+ selection : & BooleanArray ,
714+ expected : & ColumnarValue ,
715+ ) {
716+ let expr = TestExpr { } ;
717+
718+ // First check that the `evaluate_selection` is the expected one
719+ let selection_result = expr. evaluate_selection ( batch, selection) . unwrap ( ) ;
720+ assert_eq ! (
721+ expected. to_array( 1 ) . unwrap( ) . len( ) ,
722+ selection_result. to_array( 1 ) . unwrap( ) . len( ) ,
723+ "evaluate_selection should output row count should match input record batch"
724+ ) ;
725+ assert_arrays_eq ! (
726+ expected,
727+ & selection_result,
728+ "evaluate_selection returned unexpected value"
729+ ) ;
730+
731+ // If we're selecting all rows, the result should be the same as calling `evaluate`
732+ // with the full record batch.
733+ if ( 0 ..batch. num_rows ( ) )
734+ . all ( |row_idx| row_idx < selection. len ( ) && selection. value ( row_idx) )
735+ {
736+ let empty_result = expr. evaluate ( batch) . unwrap ( ) ;
737+
738+ assert_arrays_eq ! (
739+ empty_result,
740+ & selection_result,
741+ "evaluate_selection does not match unfiltered evaluate result"
742+ ) ;
743+ }
744+ }
745+
746+ fn test_evaluate_selection_error ( batch : & RecordBatch , selection : & BooleanArray ) {
747+ let expr = TestExpr { } ;
748+
749+ // First check that the `evaluate_selection` is the expected one
750+ let selection_result = expr. evaluate_selection ( batch, selection) ;
751+ assert ! ( selection_result. is_err( ) , "evaluate_selection should fail" ) ;
752+ }
753+
754+ #[ test]
755+ pub fn test_evaluate_selection_with_empty_record_batch ( ) {
756+ test_evaluate_selection (
757+ & RecordBatch :: new_empty ( Arc :: new ( Schema :: empty ( ) ) ) ,
758+ & BooleanArray :: from ( vec ! [ false ; 0 ] ) ,
759+ & ColumnarValue :: Array ( Arc :: new ( Int64Array :: new_null ( 0 ) ) ) ,
760+ ) ;
761+ }
762+
763+ #[ test]
764+ pub fn test_evaluate_selection_with_empty_record_batch_with_larger_false_selection ( ) {
765+ test_evaluate_selection_error (
766+ & RecordBatch :: new_empty ( Arc :: new ( Schema :: empty ( ) ) ) ,
767+ & BooleanArray :: from ( vec ! [ false ; 10 ] ) ,
768+ ) ;
769+ }
770+
771+ #[ test]
772+ pub fn test_evaluate_selection_with_empty_record_batch_with_larger_true_selection ( ) {
773+ test_evaluate_selection_error (
774+ & RecordBatch :: new_empty ( Arc :: new ( Schema :: empty ( ) ) ) ,
775+ & BooleanArray :: from ( vec ! [ true ; 10 ] ) ,
776+ ) ;
777+ }
778+
779+ #[ test]
780+ pub fn test_evaluate_selection_with_non_empty_record_batch ( ) {
781+ test_evaluate_selection (
782+ unsafe { & RecordBatch :: new_unchecked ( Arc :: new ( Schema :: empty ( ) ) , vec ! [ ] , 10 ) } ,
783+ & BooleanArray :: from ( vec ! [ true ; 10 ] ) ,
784+ & ColumnarValue :: Array ( Arc :: new ( Int64Array :: from ( vec ! [ 1 ; 10 ] ) ) ) ,
785+ ) ;
786+ }
787+
788+ #[ test]
789+ pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_false_selection (
790+ ) {
791+ test_evaluate_selection_error (
792+ unsafe { & RecordBatch :: new_unchecked ( Arc :: new ( Schema :: empty ( ) ) , vec ! [ ] , 10 ) } ,
793+ & BooleanArray :: from ( vec ! [ false ; 20 ] ) ,
794+ ) ;
795+ }
796+
797+ #[ test]
798+ pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_true_selection (
799+ ) {
800+ test_evaluate_selection_error (
801+ unsafe { & RecordBatch :: new_unchecked ( Arc :: new ( Schema :: empty ( ) ) , vec ! [ ] , 10 ) } ,
802+ & BooleanArray :: from ( vec ! [ true ; 20 ] ) ,
803+ ) ;
804+ }
805+
806+ #[ test]
807+ pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_false_selection (
808+ ) {
809+ test_evaluate_selection_error (
810+ unsafe { & RecordBatch :: new_unchecked ( Arc :: new ( Schema :: empty ( ) ) , vec ! [ ] , 10 ) } ,
811+ & BooleanArray :: from ( vec ! [ false ; 5 ] ) ,
812+ ) ;
813+ }
814+
815+ #[ test]
816+ pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_true_selection (
817+ ) {
818+ test_evaluate_selection_error (
819+ unsafe { & RecordBatch :: new_unchecked ( Arc :: new ( Schema :: empty ( ) ) , vec ! [ ] , 10 ) } ,
820+ & BooleanArray :: from ( vec ! [ true ; 5 ] ) ,
821+ ) ;
822+ }
823+ }
0 commit comments