1515// specific language governing permissions and limitations
1616// under the License.
1717
18+ use std:: any:: Any ;
19+ use std:: fmt:: { Debug , Display } ;
20+ use std:: hash:: { Hash , Hasher } ;
21+ use std:: sync:: Arc ;
22+
1823use crate :: intervals:: Interval ;
1924use crate :: sort_properties:: SortProperties ;
2025use crate :: utils:: scatter;
@@ -27,11 +32,6 @@ use datafusion_common::utils::DataPtr;
2732use datafusion_common:: { internal_err, not_impl_err, DataFusionError , Result } ;
2833use datafusion_expr:: ColumnarValue ;
2934
30- use std:: any:: Any ;
31- use std:: fmt:: { Debug , Display } ;
32- use std:: hash:: { Hash , Hasher } ;
33- use std:: sync:: Arc ;
34-
3535/// Expression that can be evaluated against a RecordBatch
3636/// A Physical expression knows its type, nullability and how to evaluate itself.
3737pub trait PhysicalExpr : Send + Sync + Display + Debug + PartialEq < dyn Any > {
@@ -54,13 +54,12 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
5454 let tmp_batch = filter_record_batch ( batch, selection) ?;
5555
5656 let tmp_result = self . evaluate ( & tmp_batch) ?;
57- // All values from the `selection` filter are true.
57+
5858 if batch. num_rows ( ) == tmp_batch. num_rows ( ) {
59- return Ok ( tmp_result) ;
60- }
61- if let ColumnarValue :: Array ( a) = tmp_result {
62- let result = scatter ( selection, a. as_ref ( ) ) ?;
63- Ok ( ColumnarValue :: Array ( result) )
59+ // All values from the `selection` filter are true.
60+ Ok ( tmp_result)
61+ } else if let ColumnarValue :: Array ( a) = tmp_result {
62+ scatter ( selection, a. as_ref ( ) ) . map ( ColumnarValue :: Array )
6463 } else {
6564 Ok ( tmp_result)
6665 }
@@ -216,8 +215,8 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
216215 }
217216}
218217
219- /// It is similar to contains method of vector.
220- /// Finds whether `expr` is among `physical_exprs`.
218+ /// This function is similar to the ` contains` method of `Vec`. It finds
219+ /// whether `expr` is among `physical_exprs`.
221220pub fn physical_exprs_contains (
222221 physical_exprs : & [ Arc < dyn PhysicalExpr > ] ,
223222 expr : & Arc < dyn PhysicalExpr > ,
@@ -226,3 +225,49 @@ pub fn physical_exprs_contains(
226225 . iter ( )
227226 . any ( |physical_expr| physical_expr. eq ( expr) )
228227}
228+
229+ #[ cfg( test) ]
230+ mod tests {
231+ use std:: sync:: Arc ;
232+
233+ use crate :: expressions:: { Column , Literal } ;
234+ use crate :: physical_expr:: { physical_exprs_contains, PhysicalExpr } ;
235+
236+ use datafusion_common:: { Result , ScalarValue } ;
237+
238+ #[ test]
239+ fn test_physical_exprs_contains ( ) -> Result < ( ) > {
240+ let lit_true = Arc :: new ( Literal :: new ( ScalarValue :: Boolean ( Some ( true ) ) ) )
241+ as Arc < dyn PhysicalExpr > ;
242+ let lit_false = Arc :: new ( Literal :: new ( ScalarValue :: Boolean ( Some ( false ) ) ) )
243+ as Arc < dyn PhysicalExpr > ;
244+ let lit4 =
245+ Arc :: new ( Literal :: new ( ScalarValue :: Int32 ( Some ( 4 ) ) ) ) as Arc < dyn PhysicalExpr > ;
246+ let lit2 =
247+ Arc :: new ( Literal :: new ( ScalarValue :: Int32 ( Some ( 2 ) ) ) ) as Arc < dyn PhysicalExpr > ;
248+ let lit1 =
249+ Arc :: new ( Literal :: new ( ScalarValue :: Int32 ( Some ( 1 ) ) ) ) as Arc < dyn PhysicalExpr > ;
250+ let col_a_expr = Arc :: new ( Column :: new ( "a" , 0 ) ) as Arc < dyn PhysicalExpr > ;
251+ let col_b_expr = Arc :: new ( Column :: new ( "b" , 1 ) ) as Arc < dyn PhysicalExpr > ;
252+ let col_c_expr = Arc :: new ( Column :: new ( "c" , 2 ) ) as Arc < dyn PhysicalExpr > ;
253+
254+ // lit(true), lit(false), lit(4), lit(2), Col(a), Col(b)
255+ let physical_exprs: Vec < Arc < dyn PhysicalExpr > > = vec ! [
256+ lit_true. clone( ) ,
257+ lit_false. clone( ) ,
258+ lit4. clone( ) ,
259+ lit2. clone( ) ,
260+ col_a_expr. clone( ) ,
261+ col_b_expr. clone( ) ,
262+ ] ;
263+ // below expressions are inside physical_exprs
264+ assert ! ( physical_exprs_contains( & physical_exprs, & lit_true) ) ;
265+ assert ! ( physical_exprs_contains( & physical_exprs, & lit2) ) ;
266+ assert ! ( physical_exprs_contains( & physical_exprs, & col_b_expr) ) ;
267+
268+ // below expressions are not inside physical_exprs
269+ assert ! ( !physical_exprs_contains( & physical_exprs, & col_c_expr) ) ;
270+ assert ! ( !physical_exprs_contains( & physical_exprs, & lit1) ) ;
271+ Ok ( ( ) )
272+ }
273+ }
0 commit comments