@@ -47,8 +47,8 @@ pub struct DistinctCount {
4747 name : String ,
4848 /// The DataType for the final count
4949 data_type : DataType ,
50- /// The DataType for each input argument
51- input_data_types : Vec < DataType > ,
50+ /// The DataType used to hold the state for each input
51+ state_data_types : Vec < DataType > ,
5252 /// The input arguments
5353 exprs : Vec < Arc < dyn PhysicalExpr > > ,
5454}
@@ -61,15 +61,26 @@ impl DistinctCount {
6161 name : String ,
6262 data_type : DataType ,
6363 ) -> Self {
64+ let state_data_types = input_data_types. into_iter ( ) . map ( state_type) . collect ( ) ;
65+
6466 Self {
65- input_data_types,
66- exprs,
6767 name,
6868 data_type,
69+ state_data_types,
70+ exprs,
6971 }
7072 }
7173}
7274
75+ /// return the type to use to accumulate state for the specified input type
76+ fn state_type ( data_type : DataType ) -> DataType {
77+ match data_type {
78+ // when aggregating dictionary values, use the underlying value type
79+ DataType :: Dictionary ( _key_type, value_type) => * value_type,
80+ t => t,
81+ }
82+ }
83+
7384impl AggregateExpr for DistinctCount {
7485 /// Return a reference to Any that can be used for downcasting
7586 fn as_any ( & self ) -> & dyn Any {
@@ -82,12 +93,16 @@ impl AggregateExpr for DistinctCount {
8293
8394 fn state_fields ( & self ) -> Result < Vec < Field > > {
8495 Ok ( self
85- . input_data_types
96+ . state_data_types
8697 . iter ( )
87- . map ( |data_type | {
98+ . map ( |state_data_type | {
8899 Field :: new (
89100 & format_state_name ( & self . name , "count distinct" ) ,
90- DataType :: List ( Box :: new ( Field :: new ( "item" , data_type. clone ( ) , true ) ) ) ,
101+ DataType :: List ( Box :: new ( Field :: new (
102+ "item" ,
103+ state_data_type. clone ( ) ,
104+ true ,
105+ ) ) ) ,
91106 false ,
92107 )
93108 } )
@@ -101,7 +116,7 @@ impl AggregateExpr for DistinctCount {
101116 fn create_accumulator ( & self ) -> Result < Box < dyn Accumulator > > {
102117 Ok ( Box :: new ( DistinctCountAccumulator {
103118 values : HashSet :: default ( ) ,
104- data_types : self . input_data_types . clone ( ) ,
119+ state_data_types : self . state_data_types . clone ( ) ,
105120 count_data_type : self . data_type . clone ( ) ,
106121 } ) )
107122 }
@@ -110,7 +125,7 @@ impl AggregateExpr for DistinctCount {
110125#[ derive( Debug ) ]
111126struct DistinctCountAccumulator {
112127 values : HashSet < DistinctScalarValues , RandomState > ,
113- data_types : Vec < DataType > ,
128+ state_data_types : Vec < DataType > ,
114129 count_data_type : DataType ,
115130}
116131
@@ -156,9 +171,11 @@ impl Accumulator for DistinctCountAccumulator {
156171
157172 fn state ( & self ) -> Result < Vec < ScalarValue > > {
158173 let mut cols_out = self
159- . data_types
174+ . state_data_types
160175 . iter ( )
161- . map ( |data_type| ScalarValue :: List ( Some ( Vec :: new ( ) ) , data_type. clone ( ) ) )
176+ . map ( |state_data_type| {
177+ ScalarValue :: List ( Some ( Vec :: new ( ) ) , state_data_type. clone ( ) )
178+ } )
162179 . collect :: < Vec < _ > > ( ) ;
163180
164181 let mut cols_vec = cols_out
0 commit comments