@@ -43,7 +43,15 @@ use arrow::array::*;
4343use arrow:: compute:: { cast, filter} ;
4444use arrow:: datatypes:: { DataType , Schema , UInt32Type } ;
4545use arrow:: { compute, datatypes:: SchemaRef , record_batch:: RecordBatch } ;
46- use datafusion_common:: cast:: { as_boolean_array, as_decimal128_array} ;
46+ use arrow_array:: types:: {
47+ Int16Type , Int32Type , Int64Type , Int8Type , UInt16Type , UInt64Type , UInt8Type ,
48+ } ;
49+ use arrow_schema:: { IntervalUnit , TimeUnit } ;
50+ use datafusion_common:: cast:: {
51+ as_boolean_array, as_decimal128_array, as_fixed_size_binary_array,
52+ as_fixed_size_list_array, as_list_array, as_struct_array,
53+ } ;
54+ use datafusion_common:: scalar:: get_dict_value;
4755use datafusion_common:: utils:: get_arrayref_at_indices;
4856use datafusion_common:: { DataFusionError , Result , ScalarValue } ;
4957use datafusion_expr:: Accumulator ;
@@ -865,8 +873,17 @@ macro_rules! typed_cast_to_scalar {
865873 } } ;
866874}
867875
868- /// This method is similar to [Scalar::try_from_array], it is used to update the Row Accumulators
869- /// This method only covers the types which support the row layout and the Null handling is different.
/// Builds a timezone-carrying `ScalarValue` from one element of an array.
///
/// Arguments:
/// * `$array` - value exposing `as_any()`; it is downcast to `$ARRAYTYPE`.
/// * `$index` - position passed to `value(..)` on the downcast array.
/// * `$ARRAYTYPE` - concrete array type to downcast to.
/// * `$SCALAR` - `ScalarValue` variant to construct; it must accept
///   `(Option<_>, tz)` as its two fields.
/// * `$TZ` - timezone expression; it is cloned, never moved.
///
/// Evaluates to `Ok(ScalarValue::$SCALAR(Some(value), tz))`. The macro always
/// wraps the element in `Some(..)` and performs no null check of its own, so
/// any null handling has to happen before it is invoked.
///
/// # Panics
///
/// Panics if `$array` is not actually a `$ARRAYTYPE`. The panic message names
/// the expected type so a mismatched match arm is easy to locate.
macro_rules! typed_cast_tz_to_scalar {
    ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{
        // The caller selects `$ARRAYTYPE` from the array's data type, so a
        // failed downcast is a programming error rather than a runtime condition.
        let array = $array
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .expect(concat!(
                "typed_cast_tz_to_scalar: array is not a ",
                stringify!($ARRAYTYPE)
            ));
        Ok(ScalarValue::$SCALAR(
            Some(array.value($index).into()),
            $TZ.clone(),
        ))
    }};
}
885+
886+ /// This function is similar to [ScalarValue::try_from_array] except for the Null handling.
870887/// This method returns [ScalarValue::Null] instead of [ScalarValue::Type(None)]
871888fn col_to_scalar (
872889 array : & ArrayRef ,
@@ -917,13 +934,161 @@ fn col_to_scalar(
917934 DataType :: Binary => {
918935 typed_cast_to_scalar ! ( array, row_index, BinaryArray , Binary )
919936 }
937+ DataType :: LargeBinary => {
938+ typed_cast_to_scalar ! ( array, row_index, LargeBinaryArray , LargeBinary )
939+ }
920940 DataType :: Utf8 => typed_cast_to_scalar ! ( array, row_index, StringArray , Utf8 ) ,
941+ DataType :: LargeUtf8 => {
942+ typed_cast_to_scalar ! ( array, row_index, LargeStringArray , LargeUtf8 )
943+ }
944+ DataType :: List ( nested_type) => {
945+ let list_array = as_list_array ( array) ?;
946+
947+ let nested_array = list_array. value ( row_index) ;
948+ let scalar_vec = ( 0 ..nested_array. len ( ) )
949+ . map ( |i| ScalarValue :: try_from_array ( & nested_array, i) )
950+ . collect :: < Result < Vec < _ > > > ( ) ?;
951+ let value = Some ( scalar_vec) ;
952+ Ok ( ScalarValue :: new_list (
953+ value,
954+ nested_type. data_type ( ) . clone ( ) ,
955+ ) )
956+ }
921957 DataType :: Date32 => {
922958 typed_cast_to_scalar ! ( array, row_index, Date32Array , Date32 )
923959 }
924960 DataType :: Date64 => {
925961 typed_cast_to_scalar ! ( array, row_index, Date64Array , Date64 )
926962 }
963+ DataType :: Time32 ( TimeUnit :: Second ) => {
964+ typed_cast_to_scalar ! ( array, row_index, Time32SecondArray , Time32Second )
965+ }
966+ DataType :: Time32 ( TimeUnit :: Millisecond ) => typed_cast_to_scalar ! (
967+ array,
968+ row_index,
969+ Time32MillisecondArray ,
970+ Time32Millisecond
971+ ) ,
972+ DataType :: Time64 ( TimeUnit :: Microsecond ) => typed_cast_to_scalar ! (
973+ array,
974+ row_index,
975+ Time64MicrosecondArray ,
976+ Time64Microsecond
977+ ) ,
978+ DataType :: Time64 ( TimeUnit :: Nanosecond ) => typed_cast_to_scalar ! (
979+ array,
980+ row_index,
981+ Time64NanosecondArray ,
982+ Time64Nanosecond
983+ ) ,
984+ DataType :: Timestamp ( TimeUnit :: Second , tz_opt) => typed_cast_tz_to_scalar ! (
985+ array,
986+ row_index,
987+ TimestampSecondArray ,
988+ TimestampSecond ,
989+ tz_opt
990+ ) ,
991+ DataType :: Timestamp ( TimeUnit :: Millisecond , tz_opt) => {
992+ typed_cast_tz_to_scalar ! (
993+ array,
994+ row_index,
995+ TimestampMillisecondArray ,
996+ TimestampMillisecond ,
997+ tz_opt
998+ )
999+ }
1000+ DataType :: Timestamp ( TimeUnit :: Microsecond , tz_opt) => {
1001+ typed_cast_tz_to_scalar ! (
1002+ array,
1003+ row_index,
1004+ TimestampMicrosecondArray ,
1005+ TimestampMicrosecond ,
1006+ tz_opt
1007+ )
1008+ }
1009+ DataType :: Timestamp ( TimeUnit :: Nanosecond , tz_opt) => {
1010+ typed_cast_tz_to_scalar ! (
1011+ array,
1012+ row_index,
1013+ TimestampNanosecondArray ,
1014+ TimestampNanosecond ,
1015+ tz_opt
1016+ )
1017+ }
1018+ DataType :: Dictionary ( key_type, _) => {
1019+ let ( values_array, values_index) = match key_type. as_ref ( ) {
1020+ DataType :: Int8 => get_dict_value :: < Int8Type > ( array, row_index) ,
1021+ DataType :: Int16 => get_dict_value :: < Int16Type > ( array, row_index) ,
1022+ DataType :: Int32 => get_dict_value :: < Int32Type > ( array, row_index) ,
1023+ DataType :: Int64 => get_dict_value :: < Int64Type > ( array, row_index) ,
1024+ DataType :: UInt8 => get_dict_value :: < UInt8Type > ( array, row_index) ,
1025+ DataType :: UInt16 => get_dict_value :: < UInt16Type > ( array, row_index) ,
1026+ DataType :: UInt32 => get_dict_value :: < UInt32Type > ( array, row_index) ,
1027+ DataType :: UInt64 => get_dict_value :: < UInt64Type > ( array, row_index) ,
1028+ _ => unreachable ! ( "Invalid dictionary keys type: {:?}" , key_type) ,
1029+ } ;
1030+ // look up the index in the values dictionary
1031+ match values_index {
1032+ Some ( values_index) => {
1033+ let value = ScalarValue :: try_from_array ( values_array, values_index) ?;
1034+ Ok ( ScalarValue :: Dictionary ( key_type. clone ( ) , Box :: new ( value) ) )
1035+ }
1036+ // else entry was null, so return null
1037+ None => Ok ( ScalarValue :: Null ) ,
1038+ }
1039+ }
1040+ DataType :: Struct ( fields) => {
1041+ let array = as_struct_array ( array) ?;
1042+ let mut field_values: Vec < ScalarValue > = Vec :: new ( ) ;
1043+ for col_index in 0 ..array. num_columns ( ) {
1044+ let col_array = array. column ( col_index) ;
1045+ let col_scalar = ScalarValue :: try_from_array ( col_array, row_index) ?;
1046+ field_values. push ( col_scalar) ;
1047+ }
1048+ Ok ( ScalarValue :: Struct ( Some ( field_values) , fields. clone ( ) ) )
1049+ }
1050+ DataType :: FixedSizeList ( nested_type, _len) => {
1051+ let list_array = as_fixed_size_list_array ( array) ?;
1052+ match list_array. is_null ( row_index) {
1053+ true => Ok ( ScalarValue :: Null ) ,
1054+ false => {
1055+ let nested_array = list_array. value ( row_index) ;
1056+ let scalar_vec = ( 0 ..nested_array. len ( ) )
1057+ . map ( |i| ScalarValue :: try_from_array ( & nested_array, i) )
1058+ . collect :: < Result < Vec < _ > > > ( ) ?;
1059+ Ok ( ScalarValue :: new_list (
1060+ Some ( scalar_vec) ,
1061+ nested_type. data_type ( ) . clone ( ) ,
1062+ ) )
1063+ }
1064+ }
1065+ }
1066+ DataType :: FixedSizeBinary ( _) => {
1067+ let array = as_fixed_size_binary_array ( array) ?;
1068+ let size = match array. data_type ( ) {
1069+ DataType :: FixedSizeBinary ( size) => * size,
1070+ _ => unreachable ! ( ) ,
1071+ } ;
1072+ Ok ( ScalarValue :: FixedSizeBinary (
1073+ size,
1074+ Some ( array. value ( row_index) . into ( ) ) ,
1075+ ) )
1076+ }
1077+ DataType :: Interval ( IntervalUnit :: DayTime ) => {
1078+ typed_cast_to_scalar ! ( array, row_index, IntervalDayTimeArray , IntervalDayTime )
1079+ }
1080+ DataType :: Interval ( IntervalUnit :: YearMonth ) => typed_cast_to_scalar ! (
1081+ array,
1082+ row_index,
1083+ IntervalYearMonthArray ,
1084+ IntervalYearMonth
1085+ ) ,
1086+ DataType :: Interval ( IntervalUnit :: MonthDayNano ) => typed_cast_to_scalar ! (
1087+ array,
1088+ row_index,
1089+ IntervalMonthDayNanoArray ,
1090+ IntervalMonthDayNano
1091+ ) ,
9271092 other => Err ( DataFusionError :: NotImplemented ( format ! (
9281093 "GroupedHashAggregate: can't create a scalar from array of type \" {other:?}\" "
9291094 ) ) ) ,
0 commit comments