1717
//! Defines physical expressions that can be evaluated at runtime during query execution
1919
20- use arrow:: array:: { Array , ArrayRef } ;
20+ use arrow:: array:: { Array , ArrayRef , AsArray } ;
2121use arrow:: datatypes:: DataType ;
2222use arrow_schema:: Field ;
2323
@@ -29,6 +29,7 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
2929use datafusion_expr:: utils:: format_state_name;
3030use datafusion_expr:: AggregateUDFImpl ;
3131use datafusion_expr:: { Accumulator , Signature , Volatility } ;
32+ use std:: collections:: HashSet ;
3233use std:: sync:: Arc ;
3334
3435make_udaf_expr_and_func ! (
@@ -82,6 +83,14 @@ impl AggregateUDFImpl for ArrayAgg {
8283 }
8384
8485 fn state_fields ( & self , args : StateFieldsArgs ) -> Result < Vec < Field > > {
86+ if args. is_distinct {
87+ return Ok ( vec ! [ Field :: new_list(
88+ format_state_name( args. name, "distinct_array_agg" ) ,
89+ Field :: new( "item" , args. input_type. clone( ) , true ) ,
90+ true ,
91+ ) ] ) ;
92+ }
93+
8594 Ok ( vec ! [ Field :: new_list(
8695 format_state_name( args. name, "array_agg" ) ,
8796 Field :: new( "item" , args. input_type. clone( ) , true ) ,
@@ -90,6 +99,12 @@ impl AggregateUDFImpl for ArrayAgg {
9099 }
91100
92101 fn accumulator ( & self , acc_args : AccumulatorArgs ) -> Result < Box < dyn Accumulator > > {
102+ if acc_args. is_distinct {
103+ return Ok ( Box :: new ( DistinctArrayAggAccumulator :: try_new (
104+ acc_args. input_type ,
105+ ) ?) ) ;
106+ }
107+
93108 Ok ( Box :: new ( ArrayAggAccumulator :: try_new ( acc_args. input_type ) ?) )
94109 }
95110}
@@ -170,3 +185,65 @@ impl Accumulator for ArrayAggAccumulator {
170185 - std:: mem:: size_of_val ( & self . datatype )
171186 }
172187}
188+
189+ #[ derive( Debug ) ]
190+ struct DistinctArrayAggAccumulator {
191+ values : HashSet < ScalarValue > ,
192+ datatype : DataType ,
193+ }
194+
195+ impl DistinctArrayAggAccumulator {
196+ pub fn try_new ( datatype : & DataType ) -> Result < Self > {
197+ Ok ( Self {
198+ values : HashSet :: new ( ) ,
199+ datatype : datatype. clone ( ) ,
200+ } )
201+ }
202+ }
203+
204+ impl Accumulator for DistinctArrayAggAccumulator {
205+ fn state ( & mut self ) -> Result < Vec < ScalarValue > > {
206+ Ok ( vec ! [ self . evaluate( ) ?] )
207+ }
208+
209+ fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
210+ assert_eq ! ( values. len( ) , 1 , "batch input should only include 1 column!" ) ;
211+
212+ let array = & values[ 0 ] ;
213+
214+ for i in 0 ..array. len ( ) {
215+ let scalar = ScalarValue :: try_from_array ( & array, i) ?;
216+ self . values . insert ( scalar) ;
217+ }
218+
219+ Ok ( ( ) )
220+ }
221+
222+ fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
223+ if states. is_empty ( ) {
224+ return Ok ( ( ) ) ;
225+ }
226+
227+ states[ 0 ]
228+ . as_list :: < i32 > ( )
229+ . iter ( )
230+ . flatten ( )
231+ . try_for_each ( |val| self . update_batch ( & [ val] ) )
232+ }
233+
234+ fn evaluate ( & mut self ) -> Result < ScalarValue > {
235+ let values: Vec < ScalarValue > = self . values . iter ( ) . cloned ( ) . collect ( ) ;
236+ if values. is_empty ( ) {
237+ return Ok ( ScalarValue :: new_null_list ( self . datatype . clone ( ) , true , 1 ) ) ;
238+ }
239+ let arr = ScalarValue :: new_list ( & values, & self . datatype , true ) ;
240+ Ok ( ScalarValue :: List ( arr) )
241+ }
242+
243+ fn size ( & self ) -> usize {
244+ std:: mem:: size_of_val ( self ) + ScalarValue :: size_of_hashset ( & self . values )
245+ - std:: mem:: size_of_val ( & self . values )
246+ + self . datatype . size ( )
247+ - std:: mem:: size_of_val ( & self . datatype )
248+ }
249+ }
0 commit comments