Skip to content

Commit 5bfc141

Browse files
committed
fix UT
1 parent b95c769 commit 5bfc141

1 file changed

Lines changed: 168 additions & 3 deletions

File tree

  • datafusion/core/src/physical_plan/aggregates

datafusion/core/src/physical_plan/aggregates/row_hash.rs

Lines changed: 168 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,15 @@ use arrow::array::*;
4343
use arrow::compute::{cast, filter};
4444
use arrow::datatypes::{DataType, Schema, UInt32Type};
4545
use arrow::{compute, datatypes::SchemaRef, record_batch::RecordBatch};
46-
use datafusion_common::cast::{as_boolean_array, as_decimal128_array};
46+
use arrow_array::types::{
47+
Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt64Type, UInt8Type,
48+
};
49+
use arrow_schema::{IntervalUnit, TimeUnit};
50+
use datafusion_common::cast::{
51+
as_boolean_array, as_decimal128_array, as_fixed_size_binary_array,
52+
as_fixed_size_list_array, as_list_array, as_struct_array,
53+
};
54+
use datafusion_common::scalar::get_dict_value;
4755
use datafusion_common::utils::get_arrayref_at_indices;
4856
use datafusion_common::{DataFusionError, Result, ScalarValue};
4957
use datafusion_expr::Accumulator;
@@ -865,8 +873,17 @@ macro_rules! typed_cast_to_scalar {
865873
}};
866874
}
867875

868-
/// This method is similar to [Scalar::try_from_array], it is used to update the Row Accumulators
869-
/// This method only covers the types which support the row layout and the Null handling is different.
876+
/// Downcasts `$array` to the concrete array type `$ARRAYTYPE` and wraps the
/// element at `$index` in a timezone-carrying `ScalarValue::$SCALAR` variant.
///
/// Companion to `typed_cast_to_scalar` for timestamp types, whose scalar
/// variants take an extra timezone argument (`$TZ` is cloned into the scalar).
macro_rules! typed_cast_tz_to_scalar {
    ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{
        // The caller has already matched on the DataType, so this downcast
        // is expected to succeed for a well-formed input batch.
        let typed_array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
        let raw_value = typed_array.value($index);
        Ok(ScalarValue::$SCALAR(Some(raw_value.into()), $TZ.clone()))
    }};
}
885+
886+
/// This method is similar to [ScalarValue::try_from_array] except for the Null handling.
870887
/// This method returns [ScalarValue::Null] instead of [ScalarValue::Type(None)]
871888
fn col_to_scalar(
872889
array: &ArrayRef,
@@ -917,13 +934,161 @@ fn col_to_scalar(
917934
DataType::Binary => {
918935
typed_cast_to_scalar!(array, row_index, BinaryArray, Binary)
919936
}
937+
DataType::LargeBinary => {
938+
typed_cast_to_scalar!(array, row_index, LargeBinaryArray, LargeBinary)
939+
}
920940
DataType::Utf8 => typed_cast_to_scalar!(array, row_index, StringArray, Utf8),
941+
DataType::LargeUtf8 => {
942+
typed_cast_to_scalar!(array, row_index, LargeStringArray, LargeUtf8)
943+
}
944+
DataType::List(nested_type) => {
945+
let list_array = as_list_array(array)?;
946+
947+
let nested_array = list_array.value(row_index);
948+
let scalar_vec = (0..nested_array.len())
949+
.map(|i| ScalarValue::try_from_array(&nested_array, i))
950+
.collect::<Result<Vec<_>>>()?;
951+
let value = Some(scalar_vec);
952+
Ok(ScalarValue::new_list(
953+
value,
954+
nested_type.data_type().clone(),
955+
))
956+
}
921957
DataType::Date32 => {
922958
typed_cast_to_scalar!(array, row_index, Date32Array, Date32)
923959
}
924960
DataType::Date64 => {
925961
typed_cast_to_scalar!(array, row_index, Date64Array, Date64)
926962
}
963+
DataType::Time32(TimeUnit::Second) => {
964+
typed_cast_to_scalar!(array, row_index, Time32SecondArray, Time32Second)
965+
}
966+
DataType::Time32(TimeUnit::Millisecond) => typed_cast_to_scalar!(
967+
array,
968+
row_index,
969+
Time32MillisecondArray,
970+
Time32Millisecond
971+
),
972+
DataType::Time64(TimeUnit::Microsecond) => typed_cast_to_scalar!(
973+
array,
974+
row_index,
975+
Time64MicrosecondArray,
976+
Time64Microsecond
977+
),
978+
DataType::Time64(TimeUnit::Nanosecond) => typed_cast_to_scalar!(
979+
array,
980+
row_index,
981+
Time64NanosecondArray,
982+
Time64Nanosecond
983+
),
984+
DataType::Timestamp(TimeUnit::Second, tz_opt) => typed_cast_tz_to_scalar!(
985+
array,
986+
row_index,
987+
TimestampSecondArray,
988+
TimestampSecond,
989+
tz_opt
990+
),
991+
DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
992+
typed_cast_tz_to_scalar!(
993+
array,
994+
row_index,
995+
TimestampMillisecondArray,
996+
TimestampMillisecond,
997+
tz_opt
998+
)
999+
}
1000+
DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
1001+
typed_cast_tz_to_scalar!(
1002+
array,
1003+
row_index,
1004+
TimestampMicrosecondArray,
1005+
TimestampMicrosecond,
1006+
tz_opt
1007+
)
1008+
}
1009+
DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
1010+
typed_cast_tz_to_scalar!(
1011+
array,
1012+
row_index,
1013+
TimestampNanosecondArray,
1014+
TimestampNanosecond,
1015+
tz_opt
1016+
)
1017+
}
1018+
DataType::Dictionary(key_type, _) => {
1019+
let (values_array, values_index) = match key_type.as_ref() {
1020+
DataType::Int8 => get_dict_value::<Int8Type>(array, row_index),
1021+
DataType::Int16 => get_dict_value::<Int16Type>(array, row_index),
1022+
DataType::Int32 => get_dict_value::<Int32Type>(array, row_index),
1023+
DataType::Int64 => get_dict_value::<Int64Type>(array, row_index),
1024+
DataType::UInt8 => get_dict_value::<UInt8Type>(array, row_index),
1025+
DataType::UInt16 => get_dict_value::<UInt16Type>(array, row_index),
1026+
DataType::UInt32 => get_dict_value::<UInt32Type>(array, row_index),
1027+
DataType::UInt64 => get_dict_value::<UInt64Type>(array, row_index),
1028+
_ => unreachable!("Invalid dictionary keys type: {:?}", key_type),
1029+
};
1030+
// look up the index in the values dictionary
1031+
match values_index {
1032+
Some(values_index) => {
1033+
let value = ScalarValue::try_from_array(values_array, values_index)?;
1034+
Ok(ScalarValue::Dictionary(key_type.clone(), Box::new(value)))
1035+
}
1036+
// else entry was null, so return null
1037+
None => Ok(ScalarValue::Null),
1038+
}
1039+
}
1040+
DataType::Struct(fields) => {
1041+
let array = as_struct_array(array)?;
1042+
let mut field_values: Vec<ScalarValue> = Vec::new();
1043+
for col_index in 0..array.num_columns() {
1044+
let col_array = array.column(col_index);
1045+
let col_scalar = ScalarValue::try_from_array(col_array, row_index)?;
1046+
field_values.push(col_scalar);
1047+
}
1048+
Ok(ScalarValue::Struct(Some(field_values), fields.clone()))
1049+
}
1050+
DataType::FixedSizeList(nested_type, _len) => {
1051+
let list_array = as_fixed_size_list_array(array)?;
1052+
match list_array.is_null(row_index) {
1053+
true => Ok(ScalarValue::Null),
1054+
false => {
1055+
let nested_array = list_array.value(row_index);
1056+
let scalar_vec = (0..nested_array.len())
1057+
.map(|i| ScalarValue::try_from_array(&nested_array, i))
1058+
.collect::<Result<Vec<_>>>()?;
1059+
Ok(ScalarValue::new_list(
1060+
Some(scalar_vec),
1061+
nested_type.data_type().clone(),
1062+
))
1063+
}
1064+
}
1065+
}
1066+
DataType::FixedSizeBinary(_) => {
1067+
let array = as_fixed_size_binary_array(array)?;
1068+
let size = match array.data_type() {
1069+
DataType::FixedSizeBinary(size) => *size,
1070+
_ => unreachable!(),
1071+
};
1072+
Ok(ScalarValue::FixedSizeBinary(
1073+
size,
1074+
Some(array.value(row_index).into()),
1075+
))
1076+
}
1077+
DataType::Interval(IntervalUnit::DayTime) => {
1078+
typed_cast_to_scalar!(array, row_index, IntervalDayTimeArray, IntervalDayTime)
1079+
}
1080+
DataType::Interval(IntervalUnit::YearMonth) => typed_cast_to_scalar!(
1081+
array,
1082+
row_index,
1083+
IntervalYearMonthArray,
1084+
IntervalYearMonth
1085+
),
1086+
DataType::Interval(IntervalUnit::MonthDayNano) => typed_cast_to_scalar!(
1087+
array,
1088+
row_index,
1089+
IntervalMonthDayNanoArray,
1090+
IntervalMonthDayNano
1091+
),
9271092
other => Err(DataFusionError::NotImplemented(format!(
9281093
"GroupedHashAggregate: can't create a scalar from array of type \"{other:?}\""
9291094
))),

0 commit comments

Comments
 (0)