Skip to content

Commit 10b0eff

Browse files
authored
Row accumulator support update Scalar values (#6003)
* support updating RowAccumulators using Scalar values * fix group-by count with multiple exprs * refine hot path, avoid Vec creation * fix UT * resolve review comments * remove redundant null check
1 parent 9798fbc commit 10b0eff

9 files changed

Lines changed: 330 additions & 41 deletions

File tree

datafusion/common/src/scalar.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1443,7 +1443,7 @@ impl std::hash::Hash for ScalarValue {
14431443
/// return a reference to the values array and the index into it for a
14441444
/// dictionary array
14451445
#[inline]
1446-
fn get_dict_value<K: ArrowDictionaryKeyType>(
1446+
pub fn get_dict_value<K: ArrowDictionaryKeyType>(
14471447
array: &dyn Array,
14481448
index: usize,
14491449
) -> (&ArrayRef, Option<usize>) {

datafusion/core/src/physical_plan/aggregates/row_hash.rs

Lines changed: 114 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,15 @@ use futures::stream::{Stream, StreamExt};
3131

3232
use crate::execution::context::TaskContext;
3333
use crate::execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
34+
use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
3435
use crate::physical_plan::aggregates::{
3536
evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AccumulatorItem,
3637
AggregateMode, PhysicalGroupBy, RowAccumulatorItem,
3738
};
3839
use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
3940
use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr};
4041
use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
41-
42-
use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
43-
use arrow::array::{new_null_array, Array, ArrayRef, PrimitiveArray, UInt32Builder};
42+
use arrow::array::*;
4443
use arrow::compute::{cast, filter};
4544
use arrow::datatypes::{DataType, Schema, UInt32Type};
4645
use arrow::{compute, datatypes::SchemaRef, record_batch::RecordBatch};
@@ -53,6 +52,7 @@ use datafusion_row::layout::RowLayout;
5352
use datafusion_row::reader::{read_row, RowReader};
5453
use datafusion_row::MutableRecordBatch;
5554
use hashbrown::raw::RawTable;
55+
use itertools::izip;
5656

5757
/// Grouping aggregate with row-format aggregation states inside.
5858
///
@@ -409,7 +409,7 @@ impl GroupedHashAggregateStream {
409409

410410
// Update the accumulator results, according to row_aggr_state.
411411
#[allow(clippy::too_many_arguments)]
412-
fn update_accumulators(
412+
fn update_accumulators_using_batch(
413413
&mut self,
414414
groups_with_rows: &[usize],
415415
offsets: &[usize],
@@ -490,6 +490,55 @@ impl GroupedHashAggregateStream {
490490
Ok(())
491491
}
492492

493+
// Update the accumulator results, according to row_aggr_state.
494+
fn update_accumulators_using_scalar(
495+
&mut self,
496+
groups_with_rows: &[usize],
497+
row_values: &[Vec<ArrayRef>],
498+
row_filter_values: &[Option<ArrayRef>],
499+
) -> Result<()> {
500+
let filter_bool_array = row_filter_values
501+
.iter()
502+
.map(|filter_opt| match filter_opt {
503+
Some(f) => Ok(Some(as_boolean_array(f)?)),
504+
None => Ok(None),
505+
})
506+
.collect::<Result<Vec<_>>>()?;
507+
508+
for group_idx in groups_with_rows {
509+
let group_state = &mut self.aggr_state.group_states[*group_idx];
510+
let mut state_accessor =
511+
RowAccessor::new_from_layout(self.row_aggr_layout.clone());
512+
state_accessor.point_to(0, group_state.aggregation_buffer.as_mut_slice());
513+
for idx in &group_state.indices {
514+
for (accumulator, values_array, filter_array) in izip!(
515+
self.row_accumulators.iter_mut(),
516+
row_values.iter(),
517+
filter_bool_array.iter()
518+
) {
519+
if values_array.len() == 1 {
520+
let scalar_value =
521+
col_to_scalar(&values_array[0], filter_array, *idx as usize)?;
522+
accumulator.update_scalar(&scalar_value, &mut state_accessor)?;
523+
} else {
524+
let scalar_values = values_array
525+
.iter()
526+
.map(|array| {
527+
col_to_scalar(array, filter_array, *idx as usize)
528+
})
529+
.collect::<Result<Vec<_>>>()?;
530+
accumulator
531+
.update_scalar_values(&scalar_values, &mut state_accessor)?;
532+
}
533+
}
534+
}
535+
// clear the group indices in this group
536+
group_state.indices.clear();
537+
}
538+
539+
Ok(())
540+
}
541+
493542
/// Perform group-by aggregation for the given [`RecordBatch`].
494543
///
495544
/// If successful, this returns the additional number of bytes that were allocated during this process.
@@ -515,35 +564,50 @@ impl GroupedHashAggregateStream {
515564
for group_values in &group_by_values {
516565
let groups_with_rows =
517566
self.update_group_state(group_values, &mut allocated)?;
518-
519-
// Collect all indices + offsets based on keys in this vec
520-
let mut batch_indices: UInt32Builder = UInt32Builder::with_capacity(0);
521-
let mut offsets = vec![0];
522-
let mut offset_so_far = 0;
523-
for &group_idx in groups_with_rows.iter() {
524-
let indices = &self.aggr_state.group_states[group_idx].indices;
525-
batch_indices.append_slice(indices);
526-
offset_so_far += indices.len();
527-
offsets.push(offset_so_far);
567+
// Decide the accumulators update mode: use scalar values to update the accumulators when all of the following conditions are met:
568+
// 1) The aggregation mode is Partial or Single
569+
// 2) There are no normal aggregation expressions
570+
// 3) The number of affected groups is high (entries in `aggr_state` have rows that need updating). This is usually the high-cardinality case
571+
if matches!(self.mode, AggregateMode::Partial | AggregateMode::Single)
572+
&& normal_aggr_input_values.is_empty()
573+
&& normal_filter_values.is_empty()
574+
&& groups_with_rows.len() >= batch.num_rows() / 10
575+
{
576+
self.update_accumulators_using_scalar(
577+
&groups_with_rows,
578+
&row_aggr_input_values,
579+
&row_filter_values,
580+
)?;
581+
} else {
582+
// Collect all indices + offsets based on keys in this vec
583+
let mut batch_indices: UInt32Builder = UInt32Builder::with_capacity(0);
584+
let mut offsets = vec![0];
585+
let mut offset_so_far = 0;
586+
for &group_idx in groups_with_rows.iter() {
587+
let indices = &self.aggr_state.group_states[group_idx].indices;
588+
batch_indices.append_slice(indices);
589+
offset_so_far += indices.len();
590+
offsets.push(offset_so_far);
591+
}
592+
let batch_indices = batch_indices.finish();
593+
594+
let row_values = get_at_indices(&row_aggr_input_values, &batch_indices)?;
595+
let normal_values =
596+
get_at_indices(&normal_aggr_input_values, &batch_indices)?;
597+
let row_filter_values =
598+
get_optional_filters(&row_filter_values, &batch_indices);
599+
let normal_filter_values =
600+
get_optional_filters(&normal_filter_values, &batch_indices);
601+
self.update_accumulators_using_batch(
602+
&groups_with_rows,
603+
&offsets,
604+
&row_values,
605+
&normal_values,
606+
&row_filter_values,
607+
&normal_filter_values,
608+
&mut allocated,
609+
)?;
528610
}
529-
let batch_indices = batch_indices.finish();
530-
531-
let row_values = get_at_indices(&row_aggr_input_values, &batch_indices)?;
532-
let normal_values =
533-
get_at_indices(&normal_aggr_input_values, &batch_indices)?;
534-
let row_filter_values =
535-
get_optional_filters(&row_filter_values, &batch_indices);
536-
let normal_filter_values =
537-
get_optional_filters(&normal_filter_values, &batch_indices);
538-
self.update_accumulators(
539-
&groups_with_rows,
540-
&offsets,
541-
&row_values,
542-
&normal_values,
543-
&row_filter_values,
544-
&normal_filter_values,
545-
&mut allocated,
546-
)?;
547611
}
548612
allocated += self
549613
.row_converter
@@ -791,3 +855,21 @@ fn slice_and_maybe_filter(
791855
};
792856
Ok(filtered_arrays)
793857
}
858+
859+
/// This method is similar to ScalarValue::try_from_array except for the Null handling.
860+
/// This method returns [ScalarValue::Null] instead of [ScalarValue::Type(None)]
861+
fn col_to_scalar(
862+
array: &ArrayRef,
863+
filter: &Option<&BooleanArray>,
864+
row_index: usize,
865+
) -> Result<ScalarValue> {
866+
if array.is_null(row_index) {
867+
return Ok(ScalarValue::Null);
868+
}
869+
if let Some(filter) = filter {
870+
if !filter.value(row_index) {
871+
return Ok(ScalarValue::Null);
872+
}
873+
}
874+
ScalarValue::try_from_array(array, row_index)
875+
}

datafusion/core/tests/sql/aggregates.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,57 @@ async fn count_multi_expr() -> Result<()> {
543543
Ok(())
544544
}
545545

546+
#[tokio::test]
547+
async fn count_multi_expr_group_by() -> Result<()> {
548+
let schema = Arc::new(Schema::new(vec![
549+
Field::new("c1", DataType::Int32, true),
550+
Field::new("c2", DataType::Int32, true),
551+
Field::new("c3", DataType::Int32, true),
552+
]));
553+
554+
let data = RecordBatch::try_new(
555+
schema.clone(),
556+
vec![
557+
Arc::new(Int32Array::from(vec![
558+
Some(0),
559+
None,
560+
Some(1),
561+
Some(2),
562+
None,
563+
])),
564+
Arc::new(Int32Array::from(vec![
565+
Some(1),
566+
Some(1),
567+
Some(0),
568+
None,
569+
None,
570+
])),
571+
Arc::new(Int32Array::from(vec![
572+
Some(10),
573+
Some(10),
574+
Some(10),
575+
Some(10),
576+
Some(10),
577+
])),
578+
],
579+
)?;
580+
581+
let ctx = SessionContext::new();
582+
ctx.register_batch("test", data)?;
583+
let sql = "SELECT c3, count(c1, c2) FROM test group by c3";
584+
let actual = execute_to_batches(&ctx, sql).await;
585+
586+
let expected = vec![
587+
"+----+------------------------+",
588+
"| c3 | COUNT(test.c1,test.c2) |",
589+
"+----+------------------------+",
590+
"| 10 | 2 |",
591+
"+----+------------------------+",
592+
];
593+
assert_batches_sorted_eq!(expected, &actual);
594+
Ok(())
595+
}
596+
546597
#[tokio::test]
547598
async fn simple_avg() -> Result<()> {
548599
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

datafusion/physical-expr/src/aggregate/average.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,24 @@ impl RowAccumulator for AvgRowAccumulator {
299299
self.state_index() + 1,
300300
accessor,
301301
&sum::sum_batch(values, &self.sum_datatype)?,
302-
)?;
303-
Ok(())
302+
)
303+
}
304+
305+
fn update_scalar_values(
306+
&mut self,
307+
values: &[ScalarValue],
308+
accessor: &mut RowAccessor,
309+
) -> Result<()> {
310+
let value = &values[0];
311+
sum::update_avg_to_row(self.state_index(), accessor, value)
312+
}
313+
314+
fn update_scalar(
315+
&mut self,
316+
value: &ScalarValue,
317+
accessor: &mut RowAccessor,
318+
) -> Result<()> {
319+
sum::update_avg_to_row(self.state_index(), accessor, value)
304320
}
305321

306322
fn merge_batch(
@@ -315,8 +331,7 @@ impl RowAccumulator for AvgRowAccumulator {
315331

316332
// sum
317333
let difference = sum::sum_batch(&states[1], &self.sum_datatype)?;
318-
sum::add_to_row(self.state_index() + 1, accessor, &difference)?;
319-
Ok(())
334+
sum::add_to_row(self.state_index() + 1, accessor, &difference)
320335
}
321336

322337
fn evaluate(&self, accessor: &RowAccessor) -> Result<ScalarValue> {

datafusion/physical-expr/src/aggregate/count.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,31 @@ impl RowAccumulator for CountRowAccumulator {
242242
Ok(())
243243
}
244244

245+
fn update_scalar_values(
246+
&mut self,
247+
values: &[ScalarValue],
248+
accessor: &mut RowAccessor,
249+
) -> Result<()> {
250+
if !values.iter().any(|s| matches!(s, ScalarValue::Null)) {
251+
accessor.add_u64(self.state_index, 1)
252+
}
253+
Ok(())
254+
}
255+
256+
fn update_scalar(
257+
&mut self,
258+
value: &ScalarValue,
259+
accessor: &mut RowAccessor,
260+
) -> Result<()> {
261+
match value {
262+
ScalarValue::Null => {
263+
// do not update the accumulator
264+
}
265+
_ => accessor.add_u64(self.state_index, 1),
266+
}
267+
Ok(())
268+
}
269+
245270
fn merge_batch(
246271
&mut self,
247272
states: &[ArrayRef],

datafusion/physical-expr/src/aggregate/min_max.rs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,9 @@ macro_rules! min_max_v2 {
565565
ScalarValue::Decimal128(rhs, ..) => {
566566
typed_min_max_v2!($INDEX, $ACC, rhs, i128, $OP)
567567
}
568+
ScalarValue::Null => {
569+
// do nothing
570+
}
568571
e => {
569572
return Err(DataFusionError::Internal(format!(
570573
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
@@ -709,8 +712,24 @@ impl RowAccumulator for MaxRowAccumulator {
709712
) -> Result<()> {
710713
let values = &values[0];
711714
let delta = &max_batch(values)?;
712-
max_row(self.index, accessor, delta)?;
713-
Ok(())
715+
max_row(self.index, accessor, delta)
716+
}
717+
718+
fn update_scalar_values(
719+
&mut self,
720+
values: &[ScalarValue],
721+
accessor: &mut RowAccessor,
722+
) -> Result<()> {
723+
let value = &values[0];
724+
max_row(self.index, accessor, value)
725+
}
726+
727+
fn update_scalar(
728+
&mut self,
729+
value: &ScalarValue,
730+
accessor: &mut RowAccessor,
731+
) -> Result<()> {
732+
max_row(self.index, accessor, value)
714733
}
715734

716735
fn merge_batch(
@@ -956,6 +975,23 @@ impl RowAccumulator for MinRowAccumulator {
956975
Ok(())
957976
}
958977

978+
fn update_scalar_values(
979+
&mut self,
980+
values: &[ScalarValue],
981+
accessor: &mut RowAccessor,
982+
) -> Result<()> {
983+
let value = &values[0];
984+
min_row(self.index, accessor, value)
985+
}
986+
987+
fn update_scalar(
988+
&mut self,
989+
value: &ScalarValue,
990+
accessor: &mut RowAccessor,
991+
) -> Result<()> {
992+
min_row(self.index, accessor, value)
993+
}
994+
959995
fn merge_batch(
960996
&mut self,
961997
states: &[ArrayRef],

0 commit comments

Comments
 (0)