@@ -198,12 +198,16 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
198198
199199#[ cfg( test) ]
200200mod tests {
201- use crate :: expressions:: lit;
201+ use crate :: expressions:: { col , lit} ;
202202
203203 use super :: * ;
204- use arrow:: datatypes:: { DataType , Field , Schema } ;
205- use datafusion_common:: ScalarValue ;
204+ use arrow:: {
205+ array:: { RecordBatch , RecordBatchOptions } ,
206+ datatypes:: { DataType , Field , Schema , SchemaRef } ,
207+ } ;
208+ use datafusion_common:: { record_batch, ScalarValue } ;
206209 use datafusion_expr:: Operator ;
210+ use itertools:: Itertools ;
207211 use std:: sync:: Arc ;
208212
209213 fn create_test_schema ( ) -> ( Schema , Schema ) {
@@ -369,4 +373,95 @@ mod tests {
369373 . to_string( )
370374 . contains( "Non-nullable column 'b' is missing" ) ) ;
371375 }
376+
377+ /// Stolen from ProjectionExec
378+ fn batch_project (
379+ expr : Vec < Arc < dyn PhysicalExpr > > ,
380+ batch : & RecordBatch ,
381+ schema : SchemaRef ,
382+ ) -> Result < RecordBatch > {
383+ // Records time on drop
384+ let arrays = expr
385+ . iter ( )
386+ . map ( |expr| {
387+ expr. evaluate ( batch)
388+ . and_then ( |v| v. into_array ( batch. num_rows ( ) ) )
389+ } )
390+ . collect :: < Result < Vec < _ > > > ( ) ?;
391+
392+ if arrays. is_empty ( ) {
393+ let options =
394+ RecordBatchOptions :: new ( ) . with_row_count ( Some ( batch. num_rows ( ) ) ) ;
395+ RecordBatch :: try_new_with_options ( Arc :: clone ( & schema) , arrays, & options)
396+ . map_err ( Into :: into)
397+ } else {
398+ RecordBatch :: try_new ( Arc :: clone ( & schema) , arrays) . map_err ( Into :: into)
399+ }
400+ }
401+
/// Demonstrates using `PhysicalExprSchemaRewriter` to adapt a `RecordBatch`
/// read with a physical (file) schema to a logical (table) schema in a single
/// projection pass: columns are reordered, `a` is converted from `Int32` to
/// `Int64`, and the column `b`, absent from the physical batch, comes back as
/// all nulls.
#[test]
fn test_adapt_batches() {
    // Data as it exists on disk: `a` is Int32 and there is an unrelated `extra` column.
    let file_batch = record_batch!(
        ("a", Int32, vec![Some(1), None, Some(3)]),
        ("extra", Utf8, vec![Some("x"), Some("y"), None])
    )
    .unwrap();
    let physical_schema = file_batch.schema();

    // Schema the query expects to see.
    let table_fields = vec![
        Field::new("a", DataType::Int64, true), // Different type
        Field::new("b", DataType::Utf8, true),  // Missing from physical
    ];
    let logical_schema = Arc::new(Schema::new(table_fields));

    let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);

    // Project `b` then `a` (expressed against the logical schema) and rewrite
    // each expression so it can be evaluated against the physical batch.
    let rewritten_exprs = ["b", "a"]
        .into_iter()
        .map(|name| {
            let logical_expr = col(name, &logical_schema).unwrap();
            rewriter.rewrite(logical_expr).unwrap()
        })
        .collect_vec();

    // The output schema is derived from the rewritten expressions themselves.
    let output_fields = rewritten_exprs
        .iter()
        .map(|rewritten| rewritten.return_field(&physical_schema).unwrap())
        .collect_vec();
    let output_schema = Arc::new(Schema::new(output_fields));

    let res = batch_project(rewritten_exprs, &file_batch, output_schema).unwrap();

    assert_eq!(res.num_columns(), 2);
    assert_eq!(res.column(0).data_type(), &DataType::Utf8);
    assert_eq!(res.column(1).data_type(), &DataType::Int64);

    // Column 0 is `b`: not present in the physical batch, so every row is null.
    let b_values = res
        .column(0)
        .as_any()
        .downcast_ref::<arrow::array::StringArray>()
        .unwrap();
    assert_eq!(b_values.iter().collect_vec(), vec![None, None, None]);

    // Column 1 is `a`: the same values as the input, now typed as Int64.
    let a_values = res
        .column(1)
        .as_any()
        .downcast_ref::<arrow::array::Int64Array>()
        .unwrap();
    assert_eq!(a_values.iter().collect_vec(), vec![Some(1), None, Some(3)]);
}
372467}
0 commit comments