2020use std:: sync:: Arc ;
2121
2222use arrow:: {
23- array:: { ArrayRef , Int32Array } ,
23+ array:: { as_string_array , ArrayRef , Int32Array , StringArray } ,
2424 compute:: SortOptions ,
2525 record_batch:: RecordBatch ,
2626} ;
@@ -29,6 +29,7 @@ use datafusion::physical_plan::expressions::PhysicalSortExpr;
2929use datafusion:: physical_plan:: sorts:: sort:: SortExec ;
3030use datafusion:: physical_plan:: { collect, ExecutionPlan } ;
3131use datafusion:: prelude:: { SessionConfig , SessionContext } ;
32+ use datafusion_common:: cast:: as_int32_array;
3233use datafusion_execution:: memory_pool:: GreedyMemoryPool ;
3334use datafusion_physical_expr:: expressions:: col;
3435use datafusion_physical_expr_common:: sort_expr:: LexOrdering ;
@@ -42,12 +43,17 @@ const KB: usize = 1 << 10;
4243#[ cfg_attr( tarpaulin, ignore) ]
4344async fn test_sort_10k_mem ( ) {
4445 for ( batch_size, should_spill) in [ ( 5 , false ) , ( 20000 , true ) , ( 500000 , true ) ] {
45- SortTest :: new ( )
46+ let ( input , collected ) = SortTest :: new ( )
4647 . with_int32_batches ( batch_size)
48+ . with_sort_columns ( vec ! [ "x" ] )
4749 . with_pool_size ( 10 * KB )
4850 . with_should_spill ( should_spill)
4951 . run ( )
5052 . await ;
53+
54+ let expected = partitions_to_sorted_vec ( & input) ;
55+ let actual = batches_to_vec ( & collected) ;
56+ assert_eq ! ( expected, actual, "failure in @ batch_size {batch_size:?}" ) ;
5157 }
5258}
5359
@@ -57,29 +63,123 @@ async fn test_sort_100k_mem() {
5763 for ( batch_size, should_spill) in
5864 [ ( 5 , false ) , ( 10000 , false ) , ( 20000 , true ) , ( 1000000 , true ) ]
5965 {
60- SortTest :: new ( )
66+ let ( input , collected ) = SortTest :: new ( )
6167 . with_int32_batches ( batch_size)
68+ . with_sort_columns ( vec ! [ "x" ] )
69+ . with_pool_size ( 100 * KB )
70+ . with_should_spill ( should_spill)
71+ . run ( )
72+ . await ;
73+
74+ let expected = partitions_to_sorted_vec ( & input) ;
75+ let actual = batches_to_vec ( & collected) ;
76+ assert_eq ! ( expected, actual, "failure in @ batch_size {batch_size:?}" ) ;
77+ }
78+ }
79+
80+ #[ tokio:: test]
81+ #[ cfg_attr( tarpaulin, ignore) ]
82+ async fn test_sort_strings_100k_mem ( ) {
83+ for ( batch_size, should_spill) in
84+ [ ( 5 , false ) , ( 1000 , false ) , ( 10000 , true ) , ( 20000 , true ) ]
85+ {
86+ let ( input, collected) = SortTest :: new ( )
87+ . with_utf8_batches ( batch_size)
88+ . with_sort_columns ( vec ! [ "x" ] )
6289 . with_pool_size ( 100 * KB )
6390 . with_should_spill ( should_spill)
6491 . run ( )
6592 . await ;
93+
94+ let mut input = input
95+ . iter ( )
96+ . flat_map ( |p| p. iter ( ) )
97+ . map ( |b| {
98+ let array = b. column ( 0 ) ;
99+ as_string_array ( array)
100+ . iter ( )
101+ . map ( |s| s. unwrap ( ) . to_string ( ) )
102+ } )
103+ . flatten ( )
104+ . collect :: < Vec < String > > ( ) ;
105+ input. sort_unstable ( ) ;
106+ let actual = collected
107+ . iter ( )
108+ . map ( |b| {
109+ let array = b. column ( 0 ) ;
110+ as_string_array ( array)
111+ . iter ( )
112+ . map ( |s| s. unwrap ( ) . to_string ( ) )
113+ } )
114+ . flatten ( )
115+ . collect :: < Vec < String > > ( ) ;
116+ assert_eq ! ( input, actual) ;
117+ }
118+ }
119+
120+ #[ tokio:: test]
121+ #[ cfg_attr( tarpaulin, ignore) ]
122+ async fn test_sort_multi_columns_100k_mem ( ) {
123+ for ( batch_size, should_spill) in
124+ [ ( 5 , false ) , ( 1000 , false ) , ( 10000 , true ) , ( 20000 , true ) ]
125+ {
126+ let ( input, collected) = SortTest :: new ( )
127+ . with_int32_utf8_batches ( batch_size)
128+ . with_sort_columns ( vec ! [ "x" , "y" ] )
129+ . with_pool_size ( 100 * KB )
130+ . with_should_spill ( should_spill)
131+ . run ( )
132+ . await ;
133+
134+ fn record_batch_to_vec ( b : & RecordBatch ) -> Vec < ( i32 , String ) > {
135+ let mut rows: Vec < _ > = Vec :: new ( ) ;
136+ let i32_array = as_int32_array ( b. column ( 0 ) ) . unwrap ( ) ;
137+ let string_array = as_string_array ( b. column ( 1 ) ) ;
138+ for i in 0 ..b. num_rows ( ) {
139+ let str = string_array. value ( i) . to_string ( ) ;
140+ let i32 = i32_array. value ( i) ;
141+ rows. push ( ( i32, str) ) ;
142+ }
143+ rows
144+ }
145+ let mut input = input
146+ . iter ( )
147+ . flat_map ( |p| p. iter ( ) )
148+ . map ( record_batch_to_vec)
149+ . flatten ( )
150+ . collect :: < Vec < ( i32 , String ) > > ( ) ;
151+ input. sort_unstable ( ) ;
152+ let actual = collected
153+ . iter ( )
154+ . map ( record_batch_to_vec)
155+ . flatten ( )
156+ . collect :: < Vec < ( i32 , String ) > > ( ) ;
157+ assert_eq ! ( input, actual) ;
66158 }
67159}
68160
69161#[ tokio:: test]
70162async fn test_sort_unlimited_mem ( ) {
71163 for ( batch_size, should_spill) in [ ( 5 , false ) , ( 20000 , false ) , ( 1000000 , false ) ] {
72- SortTest :: new ( )
164+ let ( input , collected ) = SortTest :: new ( )
73165 . with_int32_batches ( batch_size)
166+ . with_sort_columns ( vec ! [ "x" ] )
74167 . with_pool_size ( usize:: MAX )
75168 . with_should_spill ( should_spill)
76169 . run ( )
77170 . await ;
171+
172+ let expected = partitions_to_sorted_vec ( & input) ;
173+ let actual = batches_to_vec ( & collected) ;
174+ assert_eq ! ( expected, actual, "failure in @ batch_size {batch_size:?}" ) ;
78175 }
79176}
177+
80178#[ derive( Debug , Default ) ]
81179struct SortTest {
82180 input : Vec < Vec < RecordBatch > > ,
181+ /// The names of the columns to sort by
182+ sort_columns : Vec < String > ,
83183 /// GreedyMemoryPool size, if specified
84184 pool_size : Option < usize > ,
85185 /// If true, expect the sort to spill
@@ -91,12 +191,29 @@ impl SortTest {
91191 Default :: default ( )
92192 }
93193
194+ fn with_sort_columns ( mut self , sort_columns : Vec < & str > ) -> Self {
195+ self . sort_columns = sort_columns. iter ( ) . map ( |s| s. to_string ( ) ) . collect ( ) ;
196+ self
197+ }
198+
94199 /// Create batches of int32 values of rows
95200 fn with_int32_batches ( mut self , rows : usize ) -> Self {
96201 self . input = vec ! [ make_staggered_i32_batches( rows) ] ;
97202 self
98203 }
99204
205+ /// Create batches of utf8 values of rows
206+ fn with_utf8_batches ( mut self , rows : usize ) -> Self {
207+ self . input = vec ! [ make_staggered_utf8_batches( rows) ] ;
208+ self
209+ }
210+
211+ /// Create batches of int32 and utf8 values of rows
212+ fn with_int32_utf8_batches ( mut self , rows : usize ) -> Self {
213+ self . input = vec ! [ make_staggered_i32_utf8_batches( rows) ] ;
214+ self
215+ }
216+
100217 /// specify that this test should use a memory pool of the specified size
101218 fn with_pool_size ( mut self , pool_size : usize ) -> Self {
102219 self . pool_size = Some ( pool_size) ;
@@ -110,7 +227,7 @@ impl SortTest {
110227
111228 /// Sort the input using SortExec and ensure the results are
112229 /// correct according to `Vec::sort` both with and without spilling
113- async fn run ( & self ) {
230+ async fn run ( & self ) -> ( Vec < Vec < RecordBatch > > , Vec < RecordBatch > ) {
114231 let input = self . input . clone ( ) ;
115232 let first_batch = input
116233 . iter ( )
@@ -119,16 +236,21 @@ impl SortTest {
119236 . expect ( "at least one batch" ) ;
120237 let schema = first_batch. schema ( ) ;
121238
122- let sort = LexOrdering :: new ( vec ! [ PhysicalSortExpr {
123- expr: col( "x" , & schema) . unwrap( ) ,
124- options: SortOptions {
125- descending: false ,
126- nulls_first: true ,
127- } ,
128- } ] ) ;
239+ let sort_ordering = LexOrdering :: new (
240+ self . sort_columns
241+ . iter ( )
242+ . map ( |c| PhysicalSortExpr {
243+ expr : col ( c, & schema) . unwrap ( ) ,
244+ options : SortOptions {
245+ descending : false ,
246+ nulls_first : true ,
247+ } ,
248+ } )
249+ . collect ( ) ,
250+ ) ;
129251
130252 let exec = MemorySourceConfig :: try_new_exec ( & input, schema, None ) . unwrap ( ) ;
131- let sort = Arc :: new ( SortExec :: new ( sort , exec) ) ;
253+ let sort = Arc :: new ( SortExec :: new ( sort_ordering , exec) ) ;
132254
133255 let session_config = SessionConfig :: new ( ) ;
134256 let session_ctx = if let Some ( pool_size) = self . pool_size {
@@ -153,9 +275,6 @@ impl SortTest {
153275 let task_ctx = session_ctx. task_ctx ( ) ;
154276 let collected = collect ( sort. clone ( ) , task_ctx) . await . unwrap ( ) ;
155277
156- let expected = partitions_to_sorted_vec ( & input) ;
157- let actual = batches_to_vec ( & collected) ;
158-
159278 if self . should_spill {
160279 assert_ne ! (
161280 sort. metrics( ) . unwrap( ) . spill_count( ) . unwrap( ) ,
@@ -175,7 +294,8 @@ impl SortTest {
175294 0 ,
176295 "The sort should have returned all memory used back to the memory pool"
177296 ) ;
178- assert_eq ! ( expected, actual, "failure in @ pool_size {self:?}" ) ;
297+
298+ ( input, collected)
179299 }
180300}
181301
@@ -203,3 +323,63 @@ fn make_staggered_i32_batches(len: usize) -> Vec<RecordBatch> {
203323 }
204324 batches
205325}
326+
327+ /// Return randomly sized record batches in a field named 'x' of type `Utf8`
328+ /// with randomized content
329+ fn make_staggered_utf8_batches ( len : usize ) -> Vec < RecordBatch > {
330+ let mut rng = rand:: thread_rng ( ) ;
331+ let max_batch = 1024 ;
332+
333+ let mut batches = vec ! [ ] ;
334+ let mut remaining = len;
335+ while remaining != 0 {
336+ let to_read = rng. gen_range ( 0 ..=remaining. min ( max_batch) ) ;
337+ remaining -= to_read;
338+
339+ batches. push (
340+ RecordBatch :: try_from_iter ( vec ! [ (
341+ "x" ,
342+ Arc :: new( StringArray :: from_iter_values(
343+ ( 0 ..to_read) . map( |_| format!( "test_string_{}" , rng. gen :: <u32 >( ) ) ) ,
344+ ) ) as ArrayRef ,
345+ ) ] )
346+ . unwrap ( ) ,
347+ )
348+ }
349+ batches
350+ }
351+
352+ /// Return randomly sized record batches in a field named 'x' of type `Int32`
353+ /// with randomized i32 content and a field named 'y' of type `Utf8`
354+ /// with randomized content
355+ fn make_staggered_i32_utf8_batches ( len : usize ) -> Vec < RecordBatch > {
356+ let mut rng = rand:: thread_rng ( ) ;
357+ let max_batch = 1024 ;
358+
359+ let mut batches = vec ! [ ] ;
360+ let mut remaining = len;
361+ while remaining != 0 {
362+ let to_read = rng. gen_range ( 0 ..=remaining. min ( max_batch) ) ;
363+ remaining -= to_read;
364+
365+ batches. push (
366+ RecordBatch :: try_from_iter ( vec ! [
367+ (
368+ "x" ,
369+ Arc :: new( Int32Array :: from_iter_values(
370+ ( 0 ..to_read) . map( |_| rng. gen ( ) ) ,
371+ ) ) as ArrayRef ,
372+ ) ,
373+ (
374+ "y" ,
375+ Arc :: new( StringArray :: from_iter_values(
376+ ( 0 ..to_read) . map( |_| format!( "test_string_{}" , rng. gen :: <u32 >( ) ) ) ,
377+ ) ) as ArrayRef ,
378+ ) ,
379+ ] )
380+ . unwrap ( ) ,
381+ )
382+ }
383+
384+ batches
385+ }
0 commit comments