@@ -31,16 +31,15 @@ use futures::stream::{Stream, StreamExt};
3131
3232use crate :: execution:: context:: TaskContext ;
3333use crate :: execution:: memory_pool:: proxy:: { RawTableAllocExt , VecAllocExt } ;
34+ use crate :: execution:: memory_pool:: { MemoryConsumer , MemoryReservation } ;
3435use crate :: physical_plan:: aggregates:: {
3536 evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AccumulatorItem ,
3637 AggregateMode , PhysicalGroupBy , RowAccumulatorItem ,
3738} ;
3839use crate :: physical_plan:: metrics:: { BaselineMetrics , RecordOutput } ;
3940use crate :: physical_plan:: { aggregates, AggregateExpr , PhysicalExpr } ;
4041use crate :: physical_plan:: { RecordBatchStream , SendableRecordBatchStream } ;
41-
42- use crate :: execution:: memory_pool:: { MemoryConsumer , MemoryReservation } ;
43- use arrow:: array:: { new_null_array, Array , ArrayRef , PrimitiveArray , UInt32Builder } ;
42+ use arrow:: array:: * ;
4443use arrow:: compute:: { cast, filter} ;
4544use arrow:: datatypes:: { DataType , Schema , UInt32Type } ;
4645use arrow:: { compute, datatypes:: SchemaRef , record_batch:: RecordBatch } ;
@@ -53,6 +52,7 @@ use datafusion_row::layout::RowLayout;
5352use datafusion_row:: reader:: { read_row, RowReader } ;
5453use datafusion_row:: MutableRecordBatch ;
5554use hashbrown:: raw:: RawTable ;
55+ use itertools:: izip;
5656
5757/// Grouping aggregate with row-format aggregation states inside.
5858///
@@ -409,7 +409,7 @@ impl GroupedHashAggregateStream {
409409
410410 // Update the accumulator results, according to row_aggr_state.
411411 #[ allow( clippy:: too_many_arguments) ]
412- fn update_accumulators (
412+ fn update_accumulators_using_batch (
413413 & mut self ,
414414 groups_with_rows : & [ usize ] ,
415415 offsets : & [ usize ] ,
@@ -490,6 +490,55 @@ impl GroupedHashAggregateStream {
490490 Ok ( ( ) )
491491 }
492492
493+ // Update the accumulator results, according to row_aggr_state.
494+ fn update_accumulators_using_scalar (
495+ & mut self ,
496+ groups_with_rows : & [ usize ] ,
497+ row_values : & [ Vec < ArrayRef > ] ,
498+ row_filter_values : & [ Option < ArrayRef > ] ,
499+ ) -> Result < ( ) > {
500+ let filter_bool_array = row_filter_values
501+ . iter ( )
502+ . map ( |filter_opt| match filter_opt {
503+ Some ( f) => Ok ( Some ( as_boolean_array ( f) ?) ) ,
504+ None => Ok ( None ) ,
505+ } )
506+ . collect :: < Result < Vec < _ > > > ( ) ?;
507+
508+ for group_idx in groups_with_rows {
509+ let group_state = & mut self . aggr_state . group_states [ * group_idx] ;
510+ let mut state_accessor =
511+ RowAccessor :: new_from_layout ( self . row_aggr_layout . clone ( ) ) ;
512+ state_accessor. point_to ( 0 , group_state. aggregation_buffer . as_mut_slice ( ) ) ;
513+ for idx in & group_state. indices {
514+ for ( accumulator, values_array, filter_array) in izip ! (
515+ self . row_accumulators. iter_mut( ) ,
516+ row_values. iter( ) ,
517+ filter_bool_array. iter( )
518+ ) {
519+ if values_array. len ( ) == 1 {
520+ let scalar_value =
521+ col_to_scalar ( & values_array[ 0 ] , filter_array, * idx as usize ) ?;
522+ accumulator. update_scalar ( & scalar_value, & mut state_accessor) ?;
523+ } else {
524+ let scalar_values = values_array
525+ . iter ( )
526+ . map ( |array| {
527+ col_to_scalar ( array, filter_array, * idx as usize )
528+ } )
529+ . collect :: < Result < Vec < _ > > > ( ) ?;
530+ accumulator
531+ . update_scalar_values ( & scalar_values, & mut state_accessor) ?;
532+ }
533+ }
534+ }
535+ // clear the group indices in this group
536+ group_state. indices . clear ( ) ;
537+ }
538+
539+ Ok ( ( ) )
540+ }
541+
493542 /// Perform group-by aggregation for the given [`RecordBatch`].
494543 ///
495544 /// If successful, this returns the additional number of bytes that were allocated during this process.
@@ -515,35 +564,50 @@ impl GroupedHashAggregateStream {
515564 for group_values in & group_by_values {
516565 let groups_with_rows =
517566 self . update_group_state ( group_values, & mut allocated) ?;
518-
519- // Collect all indices + offsets based on keys in this vec
520- let mut batch_indices: UInt32Builder = UInt32Builder :: with_capacity ( 0 ) ;
521- let mut offsets = vec ! [ 0 ] ;
522- let mut offset_so_far = 0 ;
523- for & group_idx in groups_with_rows. iter ( ) {
524- let indices = & self . aggr_state . group_states [ group_idx] . indices ;
525- batch_indices. append_slice ( indices) ;
526- offset_so_far += indices. len ( ) ;
527- offsets. push ( offset_so_far) ;
567+ // Decide the accumulators update mode, use scalar value to update the accumulators when all of the conditions are meet:
568+ // 1) The aggregation mode is Partial or Single
569+ // 2) There is not normal aggregation expressions
570+ // 3) The number of affected groups is high (entries in `aggr_state` have rows need to update). Usually the high cardinality case
571+ if matches ! ( self . mode, AggregateMode :: Partial | AggregateMode :: Single )
572+ && normal_aggr_input_values. is_empty ( )
573+ && normal_filter_values. is_empty ( )
574+ && groups_with_rows. len ( ) >= batch. num_rows ( ) / 10
575+ {
576+ self . update_accumulators_using_scalar (
577+ & groups_with_rows,
578+ & row_aggr_input_values,
579+ & row_filter_values,
580+ ) ?;
581+ } else {
582+ // Collect all indices + offsets based on keys in this vec
583+ let mut batch_indices: UInt32Builder = UInt32Builder :: with_capacity ( 0 ) ;
584+ let mut offsets = vec ! [ 0 ] ;
585+ let mut offset_so_far = 0 ;
586+ for & group_idx in groups_with_rows. iter ( ) {
587+ let indices = & self . aggr_state . group_states [ group_idx] . indices ;
588+ batch_indices. append_slice ( indices) ;
589+ offset_so_far += indices. len ( ) ;
590+ offsets. push ( offset_so_far) ;
591+ }
592+ let batch_indices = batch_indices. finish ( ) ;
593+
594+ let row_values = get_at_indices ( & row_aggr_input_values, & batch_indices) ?;
595+ let normal_values =
596+ get_at_indices ( & normal_aggr_input_values, & batch_indices) ?;
597+ let row_filter_values =
598+ get_optional_filters ( & row_filter_values, & batch_indices) ;
599+ let normal_filter_values =
600+ get_optional_filters ( & normal_filter_values, & batch_indices) ;
601+ self . update_accumulators_using_batch (
602+ & groups_with_rows,
603+ & offsets,
604+ & row_values,
605+ & normal_values,
606+ & row_filter_values,
607+ & normal_filter_values,
608+ & mut allocated,
609+ ) ?;
528610 }
529- let batch_indices = batch_indices. finish ( ) ;
530-
531- let row_values = get_at_indices ( & row_aggr_input_values, & batch_indices) ?;
532- let normal_values =
533- get_at_indices ( & normal_aggr_input_values, & batch_indices) ?;
534- let row_filter_values =
535- get_optional_filters ( & row_filter_values, & batch_indices) ;
536- let normal_filter_values =
537- get_optional_filters ( & normal_filter_values, & batch_indices) ;
538- self . update_accumulators (
539- & groups_with_rows,
540- & offsets,
541- & row_values,
542- & normal_values,
543- & row_filter_values,
544- & normal_filter_values,
545- & mut allocated,
546- ) ?;
547611 }
548612 allocated += self
549613 . row_converter
@@ -791,3 +855,21 @@ fn slice_and_maybe_filter(
791855 } ;
792856 Ok ( filtered_arrays)
793857}
858+
859+ /// This method is similar to Scalar::try_from_array except for the Null handling.
860+ /// This method returns [ScalarValue::Null] instead of [ScalarValue::Type(None)]
861+ fn col_to_scalar (
862+ array : & ArrayRef ,
863+ filter : & Option < & BooleanArray > ,
864+ row_index : usize ,
865+ ) -> Result < ScalarValue > {
866+ if array. is_null ( row_index) {
867+ return Ok ( ScalarValue :: Null ) ;
868+ }
869+ if let Some ( filter) = filter {
870+ if !filter. value ( row_index) {
871+ return Ok ( ScalarValue :: Null ) ;
872+ }
873+ }
874+ ScalarValue :: try_from_array ( array, row_index)
875+ }
0 commit comments