@@ -22,21 +22,32 @@ use arrow::compute::{concat_batches, SortOptions};
 use arrow::datatypes::DataType;
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
-use datafusion::physical_plan::aggregates::{
-    AggregateExec, AggregateMode, PhysicalGroupBy,
-};
+use arrow_array::cast::AsArray;
+use arrow_array::types::Int64Type;
+use arrow_array::Array;
+use hashbrown::HashMap;
 use rand::rngs::StdRng;
 use rand::{Rng, SeedableRng};
+use tokio::task::JoinSet;
 
+use datafusion::common::Result;
+use datafusion::datasource::MemTable;
+use datafusion::physical_plan::aggregates::{
+    AggregateExec, AggregateMode, PhysicalGroupBy,
+};
 use datafusion::physical_plan::memory::MemoryExec;
 use datafusion::physical_plan::{collect, displayable, ExecutionPlan};
-use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion::prelude::{DataFrame, SessionConfig, SessionContext};
+use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion};
 use datafusion_physical_expr::expressions::{col, Sum};
 use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr};
-use test_utils::add_empty_batches;
+use datafusion_physical_plan::InputOrderMode;
+use test_utils::{add_empty_batches, StringBatchGenerator};
 
-#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
-async fn aggregate_test() {
+/// Tests that streaming aggregate and batch (non-streaming) aggregate
+/// produce the same results
+#[tokio::test(flavor = "multi_thread")]
+async fn streaming_aggregate_test() {
     let test_cases = vec![
         vec!["a"],
         vec!["b", "a"],
@@ -50,18 +61,18 @@ async fn aggregate_test() {
     let n = 300;
     let distincts = vec![10, 20];
     for distinct in distincts {
-        let mut handles = Vec::new();
+        let mut join_set = JoinSet::new();
         for i in 0..n {
             let test_idx = i % test_cases.len();
             let group_by_columns = test_cases[test_idx].clone();
-            let job = tokio::spawn(run_aggregate_test(
+            join_set.spawn(run_aggregate_test(
                 make_staggered_batches::<true>(1000, distinct, i as u64),
                 group_by_columns,
             ));
-            handles.push(job);
         }
-        for job in handles {
-            job.await.unwrap();
+        while let Some(join_handle) = join_set.join_next().await {
+            // propagate errors
+            join_handle.unwrap();
         }
     }
 }
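
The `JoinSet` rewrite drops the manual `Vec` of handles and still propagates a panic from any spawned task. A minimal standalone sketch of the pattern, with a hypothetical doubling task standing in for `run_aggregate_test`:

```rust
use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
    let mut join_set = JoinSet::new();
    for i in 0..4u64 {
        // each spawned task runs concurrently on the runtime
        join_set.spawn(async move { i * 2 });
    }

    // join_next() yields results in completion order and returns None once
    // the set is drained; unwrap() re-raises any panic from the task
    let mut results = Vec::new();
    while let Some(res) = join_set.join_next().await {
        results.push(res.unwrap());
    }
    results.sort_unstable();
    assert_eq!(results, vec![0, 2, 4, 6]);
}
```
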
@@ -234,3 +245,158 @@ pub(crate) fn make_staggered_batches<const STREAM: bool>(
     }
     add_empty_batches(batches, &mut rng)
 }
+
+/// Test group by with string/large string columns
+#[tokio::test(flavor = "multi_thread")]
+async fn group_by_strings() {
+    let mut join_set = JoinSet::new();
+    for large in [true, false] {
+        for sorted in [true, false] {
+            for generator in StringBatchGenerator::interesting_cases() {
+                join_set.spawn(group_by_string_test(generator, sorted, large));
+            }
+        }
+    }
+    while let Some(join_handle) = join_set.join_next().await {
+        // propagate errors
+        join_handle.unwrap();
+    }
+}
+
+/// Run GROUP BY <x> using SQL and ensure the results are correct
+///
+/// If `sorted` is true, the input batches will be sorted by the group by column
+/// to test the streaming group by case
+///
+/// If `large` is true, the input batches will be `LargeStringArray`
+async fn group_by_string_test(
+    mut generator: StringBatchGenerator,
+    sorted: bool,
+    large: bool,
+) {
+    let column_name = "a";
+    let input = if sorted {
+        generator.make_sorted_input_batches(large)
+    } else {
+        generator.make_input_batches()
+    };
+
+    let expected = compute_counts(&input, column_name);
+
+    let schema = input[0].schema();
+    let session_config = SessionConfig::new().with_batch_size(50);
+    let ctx = SessionContext::new_with_config(session_config);
+
+    let provider = MemTable::try_new(schema.clone(), vec![input]).unwrap();
+    let provider = if sorted {
+        let sort_expr = datafusion::prelude::col("a").sort(true, true);
+        provider.with_sort_order(vec![vec![sort_expr]])
+    } else {
+        provider
+    };
+
+    ctx.register_table("t", Arc::new(provider)).unwrap();
+
+    let df = ctx
+        .sql("SELECT a, COUNT(*) FROM t GROUP BY a")
+        .await
+        .unwrap();
+    verify_ordered_aggregate(&df, sorted).await;
+    let results = df.collect().await.unwrap();
+
+    // verify that the results are correct
+    let actual = extract_result_counts(results);
+    assert_eq!(expected, actual);
+}
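
Declaring a sort order on the `MemTable` is what lets the planner choose a streaming aggregate; without it the input is assumed unordered. The registration step in isolation, as a sketch (the `register_sorted` helper name and table name `t` are illustrative):

```rust
use std::sync::Arc;

use arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::prelude::{col, SessionContext};

/// Hypothetical helper: register `batches` as table "t", declared sorted on "a".
fn register_sorted(ctx: &SessionContext, batches: Vec<RecordBatch>) {
    let schema = batches[0].schema();
    let provider = MemTable::try_new(schema, vec![batches])
        .unwrap()
        // promise the optimizer the partition is sorted by "a" (ascending,
        // nulls first); the batches must actually satisfy this ordering
        .with_sort_order(vec![vec![col("a").sort(true, true)]]);
    ctx.register_table("t", Arc::new(provider)).unwrap();
}
```
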
+async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) {
+    struct Visitor {
+        expected_sort: bool,
+    }
+    let mut visitor = Visitor { expected_sort };
+
+    impl TreeNodeVisitor for Visitor {
+        type N = Arc<dyn ExecutionPlan>;
+        fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion> {
+            if let Some(exec) = node.as_any().downcast_ref::<AggregateExec>() {
+                if self.expected_sort {
+                    assert!(matches!(
+                        exec.input_order_mode(),
+                        InputOrderMode::PartiallySorted(_) | InputOrderMode::Sorted
+                    ));
+                } else {
+                    assert!(matches!(exec.input_order_mode(), InputOrderMode::Linear));
+                }
+            }
+            Ok(VisitRecursion::Continue)
+        }
+    }
+
+    let plan = frame.clone().create_physical_plan().await.unwrap();
+    plan.visit(&mut visitor).unwrap();
+}
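
`TreeNodeVisitor` is a generic pre-order/post-order walk over the plan tree, so the same machinery can express simpler checks. A sketch of a node-counting visitor under the same trait (the `CountVisitor` name is illustrative):

```rust
use std::sync::Arc;

use datafusion::common::Result;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion};

/// Counts every node in a physical plan via the visitor API.
struct CountVisitor {
    nodes: usize,
}

impl TreeNodeVisitor for CountVisitor {
    type N = Arc<dyn ExecutionPlan>;

    fn pre_visit(&mut self, _node: &Self::N) -> Result<VisitRecursion> {
        self.nodes += 1;
        // Continue descends into children; Stop would end the walk early
        Ok(VisitRecursion::Continue)
    }
}
```

Driving it mirrors the test above: bind the visitor to a local (`let mut v = CountVisitor { nodes: 0 }; plan.visit(&mut v)?;`, with `TreeNode` in scope) and read `v.nodes` afterwards.
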
+
+/// Compute the count of each distinct value in the specified column
+///
+/// ```text
+/// +---------------+---------------+
+/// | a             | b             |
+/// +---------------+---------------+
+/// | 𭏷𑩁           | 𘱦𫎛          |
+/// |               | 𬿪            |
+/// ```
+fn compute_counts(batches: &[RecordBatch], col: &str) -> HashMap<Option<String>, i64> {
+    let mut output = HashMap::new();
+    for arr in batches
+        .iter()
+        .map(|batch| batch.column_by_name(col).unwrap())
+    {
+        for value in to_str_vec(arr) {
+            output.entry(value).and_modify(|e| *e += 1).or_insert(1);
+        }
+    }
+    output
+}
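
The `entry` API keeps the counting loop to one expression per value, and `hashbrown::HashMap` exposes the same interface as the standard library's map. The same idiom with `std`:

```rust
use std::collections::HashMap;

fn main() {
    // insert 1 the first time a key is seen, otherwise bump the count
    let mut counts: HashMap<&str, i64> = HashMap::new();
    for word in ["a", "b", "a"] {
        counts.entry(word).and_modify(|e| *e += 1).or_insert(1);
    }
    assert_eq!(counts["a"], 2);
    assert_eq!(counts["b"], 1);
}
```
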
+
+fn to_str_vec(array: &ArrayRef) -> Vec<Option<String>> {
+    match array.data_type() {
+        DataType::Utf8 => array
+            .as_string::<i32>()
+            .iter()
+            .map(|x| x.map(|x| x.to_string()))
+            .collect(),
+        DataType::LargeUtf8 => array
+            .as_string::<i64>()
+            .iter()
+            .map(|x| x.map(|x| x.to_string()))
+            .collect(),
+        _ => panic!("unexpected type"),
+    }
+}
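
The `AsArray` downcasts used here select the concrete string array by offset width: `i32` offsets for `Utf8`, `i64` for `LargeUtf8`. A small self-contained sketch of the `Utf8` case:

```rust
use std::sync::Arc;

use arrow_array::cast::AsArray;
use arrow_array::{Array, ArrayRef, StringArray};

fn main() {
    // as_string::<i32>() downcasts a dynamically typed ArrayRef to
    // &GenericStringArray<i32> (StringArray); i64 gives LargeStringArray
    let array: ArrayRef = Arc::new(StringArray::from(vec![Some("x"), None]));
    let strings = array.as_string::<i32>();
    assert_eq!(strings.value(0), "x");
    assert!(strings.is_null(1));
}
```
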
+
+/// Extracts the value of the first column and the count of the second column
+/// ```text
+/// +----------------+----------+
+/// | a              | COUNT(*) |
+/// +----------------+----------+
+/// |                | 8        |
+/// |                | 11       |
+/// ```
+fn extract_result_counts(results: Vec<RecordBatch>) -> HashMap<Option<String>, i64> {
+    let group_arrays = results.iter().map(|batch| batch.column(0));
+
+    let count_arrays = results
+        .iter()
+        .map(|batch| batch.column(1).as_primitive::<Int64Type>());
+
+    let mut output = HashMap::new();
+    for (group_arr, count_arr) in group_arrays.zip(count_arrays) {
+        assert_eq!(group_arr.len(), count_arr.len());
+        let group_values = to_str_vec(group_arr);
+        for (group, count) in group_values.into_iter().zip(count_arr.iter()) {
+            // each group key must appear exactly once across all batches
+            assert!(output.get(&group).is_none());
+            let count = count.unwrap(); // counts can never be null
+            output.insert(group, count);
+        }
+    }
+    output
+}
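
The count column is read back the same way, through a typed downcast. A small sketch of `as_primitive` on an `Int64` column (the literal values are illustrative):

```rust
use std::sync::Arc;

use arrow_array::cast::AsArray;
use arrow_array::types::Int64Type;
use arrow_array::{ArrayRef, Int64Array};

fn main() {
    // as_primitive::<Int64Type>() yields &PrimitiveArray<Int64Type>,
    // the type DataFusion uses for COUNT(*) results
    let array: ArrayRef = Arc::new(Int64Array::from(vec![8, 11]));
    let counts = array.as_primitive::<Int64Type>();
    assert_eq!(counts.value(0), 8);
    assert_eq!(counts.value(1), 11);
}
```
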