diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 449c75c257f27..b6a49fbf65c72 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -721,7 +721,7 @@ impl std::hash::Hash for ScalarValue { /// dictionary array #[inline] fn get_dict_value( - array: &ArrayRef, + array: &dyn Array, index: usize, ) -> (&ArrayRef, Option) { let dict_array = as_dictionary_array::(array).unwrap(); @@ -1963,7 +1963,7 @@ impl ScalarValue { } fn get_decimal_value_from_array( - array: &ArrayRef, + array: &dyn Array, index: usize, precision: u8, scale: i8, @@ -1978,7 +1978,7 @@ impl ScalarValue { } /// Converts a value in `array` at `index` into a ScalarValue - pub fn try_from_array(array: &ArrayRef, index: usize) -> Result { + pub fn try_from_array(array: &dyn Array, index: usize) -> Result { // handle NULL value if !array.is_valid(index) { return array.data_type().try_into(); diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt index c0b747702d131..c1cefd70ec6a6 100644 --- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt +++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt @@ -79,7 +79,7 @@ SELECT stddev_pop(c2) FROM aggregate_test_100 1.3665650368716449 # csv_query_stddev_2 -query R +query R SELECT stddev_pop(c6) FROM aggregate_test_100 ---- 5.114326382039172e18 @@ -216,6 +216,70 @@ SELECT approx_median(a) FROM median_f64_nan ---- NaN +# median_multi +# test case for https://github.com/apache/arrow-datafusion/issues/3105 +# has an intermediate grouping +statement ok +create table cpu (host string, usage float) as select * from (values +('host0', 90.1), +('host1', 90.2), +('host1', 90.4) +); + +query CI rowsort +select host, median(usage) from cpu group by host; +---- +host1 90.3 +host0 90.1 + +query CI +select median(usage) from cpu; +---- +90.2 + + +statement ok +drop table cpu; + +# median_multi_odd + +# data is 
not sorted and has an odd number of values per group +statement ok +create table cpu (host string, usage float) as select * from (values + ('host0', 90.2), + ('host1', 90.1), + ('host1', 90.5), + ('host0', 90.5), + ('host1', 90.0), + ('host1', 90.3), + ('host0', 87.9), + ('host1', 89.3) +); + +query CI rowsort +select host, median(usage) from cpu group by host; +---- +host0 90.2 +host1 90.1 + + +statement ok +drop table cpu; + +# median_multi_even +# data is not sorted and has an even number of values per group +statement ok +create table cpu (host string, usage float) as select * from (values ('host0', 90.2), ('host1', 90.1), ('host1', 90.5), ('host0', 90.5), ('host1', 90.0), ('host1', 90.3), ('host1', 90.2), ('host1', 90.3)); + +query CI rowsort +select host, median(usage) from cpu group by host; +---- +host1 90.25 +host0 90.35 + +statement ok +drop table cpu + # csv_query_external_table_count query I SELECT COUNT(c12) FROM aggregate_test_100 @@ -818,7 +882,7 @@ select c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count # SELECT array_agg(c13 ORDER BY c1) FROM aggregate_test_100; # csv_query_array_cube_agg_with_overflow -query TIIRIII +query TIIRIII select c1, c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count(c3) count_c3 from aggregate_test_100 group by CUBE (c1,c2) order by c1, c2 ---- a 1 -88 -17.6 83 -85 5 @@ -870,7 +934,7 @@ e 847 40.333333333333336 120 -95 21 # query IIII # SELECT count(nanos), count(micros), count(millis), count(secs) FROM t # ---- -# 3 3 3 3 +# 3 3 3 3 # aggregate_timestamps_min # query TTTT diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs index a04bd5369210c..abde3702fb565 100644 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ b/datafusion/physical-expr/src/aggregate/median.rs @@ -19,13 +19,9 @@ use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::{Array, ArrayRef, 
PrimitiveBuilder}; -use arrow::compute::sort; -use arrow::datatypes::{ - ArrowPrimitiveType, DataType, Field, Float32Type, Float64Type, Int16Type, Int32Type, - Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, -}; -use datafusion_common::cast::as_primitive_array; +use arrow::array::{Array, ArrayRef, UInt32Array}; +use arrow::compute::sort_to_indices; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{Accumulator, AggregateState}; use std::any::Any; @@ -74,9 +70,13 @@ impl AggregateExpr for Median { } fn state_fields(&self) -> Result> { + //Intermediate state is a list of the elements we have collected so far + let field = Field::new("item", self.data_type.clone(), true); + let data_type = DataType::List(Box::new(field)); + Ok(vec![Field::new( &format_state_name(&self.name, "median"), - self.data_type.clone(), + data_type, true, )]) } @@ -91,158 +91,126 @@ impl AggregateExpr for Median { } #[derive(Debug)] +/// The median accumulator accumulates the raw input values +/// as `ScalarValue`s +/// +/// The intermediate state is represented as a List of those scalars struct MedianAccumulator { data_type: DataType, - all_values: Vec, -} - -macro_rules! 
median { - ($SELF:ident, $TY:ty, $SCALAR_TY:ident, $TWO:expr) => {{ - let combined = combine_arrays::<$TY>($SELF.all_values.as_slice())?; - if combined.is_empty() { - return Ok(ScalarValue::Null); - } - let sorted = sort(&combined, None)?; - let array = as_primitive_array::<$TY>(&sorted)?; - let len = sorted.len(); - let mid = len / 2; - if len % 2 == 0 { - Ok(ScalarValue::$SCALAR_TY(Some( - (array.value(mid - 1) + array.value(mid)) / $TWO, - ))) - } else { - Ok(ScalarValue::$SCALAR_TY(Some(array.value(mid)))) - } - }}; + all_values: Vec, } impl Accumulator for MedianAccumulator { fn state(&self) -> Result> { - let mut vec: Vec = self - .all_values - .iter() - .map(|v| AggregateState::Array(v.clone())) - .collect(); - if vec.is_empty() { - match self.data_type { - DataType::UInt8 => vec.push(empty_array::()), - DataType::UInt16 => vec.push(empty_array::()), - DataType::UInt32 => vec.push(empty_array::()), - DataType::UInt64 => vec.push(empty_array::()), - DataType::Int8 => vec.push(empty_array::()), - DataType::Int16 => vec.push(empty_array::()), - DataType::Int32 => vec.push(empty_array::()), - DataType::Int64 => vec.push(empty_array::()), - DataType::Float32 => vec.push(empty_array::()), - DataType::Float64 => vec.push(empty_array::()), - _ => { - return Err(DataFusionError::Execution( - "unsupported data type for median".to_string(), - )) - } - } - } - Ok(vec) + let state = + ScalarValue::new_list(Some(self.all_values.clone()), self.data_type.clone()); + Ok(vec![AggregateState::Scalar(state)]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let x = values[0].clone(); - self.all_values.extend_from_slice(&[x]); + assert_eq!(values.len(), 1); + let array = &values[0]; + + assert_eq!(array.data_type(), &self.data_type); + self.all_values.reserve(self.all_values.len() + array.len()); + for index in 0..array.len() { + self.all_values + .push(ScalarValue::try_from_array(array, index)?); + } + Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) 
-> Result<()> { - for array in states { - self.all_values.extend_from_slice(&[array.clone()]); + assert_eq!(states.len(), 1); + + let array = &states[0]; + assert!(matches!(array.data_type(), DataType::List(_))); + for index in 0..array.len() { + match ScalarValue::try_from_array(array, index)? { + ScalarValue::List(Some(mut values), _) => { + self.all_values.append(&mut values); + } + ScalarValue::List(None, _) => {} // skip empty state + v => { + return Err(DataFusionError::Internal(format!( + "unexpected state in median. Expected DataType::List, got {:?}", + v + ))) + } + } } Ok(()) } fn evaluate(&self) -> Result { - match self.all_values[0].data_type() { - DataType::Int8 => median!(self, arrow::datatypes::Int8Type, Int8, 2), - DataType::Int16 => median!(self, arrow::datatypes::Int16Type, Int16, 2), - DataType::Int32 => median!(self, arrow::datatypes::Int32Type, Int32, 2), - DataType::Int64 => median!(self, arrow::datatypes::Int64Type, Int64, 2), - DataType::UInt8 => median!(self, arrow::datatypes::UInt8Type, UInt8, 2), - DataType::UInt16 => median!(self, arrow::datatypes::UInt16Type, UInt16, 2), - DataType::UInt32 => median!(self, arrow::datatypes::UInt32Type, UInt32, 2), - DataType::UInt64 => median!(self, arrow::datatypes::UInt64Type, UInt64, 2), - DataType::Float32 => { - median!(self, arrow::datatypes::Float32Type, Float32, 2_f32) - } - DataType::Float64 => { - median!(self, arrow::datatypes::Float64Type, Float64, 2_f64) + // Create an array of all the non null values and find the + // sorted indexes + let array = ScalarValue::iter_to_array( + self.all_values + .iter() + // ignore null values + .filter(|v| !v.is_null()) + .cloned(), + )?; + + // find the mid point + let len = array.len(); + let mid = len / 2; + + // only sort up to the top size/2 elements + let limit = Some(mid + 1); + let options = None; + let indices = sort_to_indices(&array, options, limit)?; + + // pick the relevant indices in the original arrays + let result = if len >= 2 && len % 2 == 
0 { + // even number of values, average the two mid points + let s1 = scalar_at_index(&array, &indices, mid - 1)?; + let s2 = scalar_at_index(&array, &indices, mid)?; + match s1.add(s2)? { + ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(v / 2)), + ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(v / 2)), + ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(v / 2)), + ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(v / 2)), + ScalarValue::UInt8(Some(v)) => ScalarValue::UInt8(Some(v / 2)), + ScalarValue::UInt16(Some(v)) => ScalarValue::UInt16(Some(v / 2)), + ScalarValue::UInt32(Some(v)) => ScalarValue::UInt32(Some(v / 2)), + ScalarValue::UInt64(Some(v)) => ScalarValue::UInt64(Some(v / 2)), + ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(v / 2.0)), + ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(v / 2.0)), + v => { + return Err(DataFusionError::Internal(format!( + "Unsupported type in MedianAccumulator: {:?}", + v + ))) + } } - _ => Err(DataFusionError::Execution( - "unsupported data type for median".to_string(), - )), - } + } else { + // odd number of values, pick that one + scalar_at_index(&array, &indices, mid)? 
+ }; + + Ok(result) } fn size(&self) -> usize { - std::mem::align_of_val(self) - + (std::mem::size_of::() * self.all_values.capacity()) - + self - .all_values - .iter() - .map(|array_ref| { - std::mem::size_of_val(array_ref.as_ref()) - + array_ref.get_array_memory_size() - }) - .sum::() + std::mem::size_of_val(self) + ScalarValue::size_of_vec(&self.all_values) + - std::mem::size_of_val(&self.all_values) + self.data_type.size() - std::mem::size_of_val(&self.data_type) } } -/// Create an empty array -fn empty_array() -> AggregateState { - AggregateState::Array(Arc::new(PrimitiveBuilder::::with_capacity(0).finish())) -} - -/// Combine all non-null values from provided arrays into a single array -fn combine_arrays(arrays: &[ArrayRef]) -> Result { - let len = arrays.iter().map(|a| a.len() - a.null_count()).sum(); - let mut builder: PrimitiveBuilder = PrimitiveBuilder::with_capacity(len); - for array in arrays { - let array = as_primitive_array::(array)?; - for i in 0..array.len() { - if !array.is_null(i) { - builder.append_value(array.value(i)); - } - } - } - Ok(Arc::new(builder.finish())) -} - -#[cfg(test)] -mod test { - use crate::aggregate::median::combine_arrays; - use arrow::array::{Int32Array, UInt32Array}; - use arrow::datatypes::{Int32Type, UInt32Type}; - use datafusion_common::Result; - use std::sync::Arc; - - #[test] - fn combine_i32_array() -> Result<()> { - let a = Arc::new(Int32Array::from(vec![1, 2, 3])); - let b = combine_arrays::(&[a.clone(), a])?; - assert_eq!( - "PrimitiveArray\n[\n 1,\n 2,\n 3,\n 1,\n 2,\n 3,\n]", - format!("{:?}", b) - ); - Ok(()) - } - - #[test] - fn combine_u32_array() -> Result<()> { - let a = Arc::new(UInt32Array::from(vec![1, 2, 3])); - let b = combine_arrays::(&[a.clone(), a])?; - assert_eq!( - "PrimitiveArray\n[\n 1,\n 2,\n 3,\n 1,\n 2,\n 3,\n]", - format!("{:?}", b) - ); - Ok(()) - } +/// Returns `array[indices[indicies_index]]` as a `ScalarValue` +fn scalar_at_index( + array: &dyn Array, + indices: &UInt32Array, + 
indicies_index: usize, +) -> Result { + let array_index = indices + .value(indicies_index) + .try_into() + .expect("Convert uint32 to usize"); + ScalarValue::try_from_array(array, array_index) }