@@ -124,6 +124,21 @@ impl ExternalSorter {
124124 // calls to `timer.done()` below.
125125 let _timer = tracking_metrics. elapsed_compute ( ) . timer ( ) ;
126126 let partial = sort_batch ( input, self . schema . clone ( ) , & self . expr , self . fetch ) ?;
// The resulting batch might be smaller than the input batch if there
// is a propagated limit.
130+ if self . fetch . is_some ( ) {
131+ let new_size = batch_byte_size ( & partial. sorted_batch ) ;
132+ let size_delta = size. checked_sub ( new_size) . ok_or_else ( || {
133+ DataFusionError :: Internal ( format ! (
134+ "The size of the sorted batch is larger than the size of the input batch: {} > {}" ,
135+ size,
136+ new_size
137+ ) )
138+ } ) ?;
139+ self . shrink ( size_delta) ;
140+ self . metrics . mem_used ( ) . sub ( size_delta) ;
141+ }
127142 in_mem_batches. push ( partial) ;
128143 }
129144 Ok ( ( ) )
@@ -1062,6 +1077,65 @@ mod tests {
10621077 Ok ( ( ) )
10631078 }
10641079
1080+ #[ tokio:: test]
1081+ async fn test_sort_fetch_memory_calculation ( ) -> Result < ( ) > {
1082+ // This test mirrors down the size from the example above.
1083+ let avg_batch_size = 5336 ;
1084+ let partitions = 4 ;
1085+
1086+ // A tuple of (fetch, expect_spillage)
1087+ let test_options = vec ! [
1088+ // Since we don't have a limit (and the memory is less than the total size of
1089+ // all the batches we are processing, we expect it to spill.
1090+ ( None , true ) ,
1091+ // When we have a limit however, the buffered size of batches should fit in memory
1092+ // since it is much lover than the total size of the input batch.
1093+ ( Some ( 1 ) , false ) ,
1094+ ] ;
1095+
1096+ for ( fetch, expect_spillage) in test_options {
1097+ let config = RuntimeConfig :: new ( )
1098+ . with_memory_limit ( avg_batch_size * ( partitions - 1 ) , 1.0 ) ;
1099+ let runtime = Arc :: new ( RuntimeEnv :: new ( config) ?) ;
1100+ let session_ctx =
1101+ SessionContext :: with_config_rt ( SessionConfig :: new ( ) , runtime) ;
1102+
1103+ let csv = test:: scan_partitioned_csv ( partitions) ?;
1104+ let schema = csv. schema ( ) ;
1105+
1106+ let sort_exec = Arc :: new ( SortExec :: try_new (
1107+ vec ! [
1108+ // c1 string column
1109+ PhysicalSortExpr {
1110+ expr: col( "c1" , & schema) ?,
1111+ options: SortOptions :: default ( ) ,
1112+ } ,
1113+ // c2 uin32 column
1114+ PhysicalSortExpr {
1115+ expr: col( "c2" , & schema) ?,
1116+ options: SortOptions :: default ( ) ,
1117+ } ,
1118+ // c7 uin8 column
1119+ PhysicalSortExpr {
1120+ expr: col( "c7" , & schema) ?,
1121+ options: SortOptions :: default ( ) ,
1122+ } ,
1123+ ] ,
1124+ Arc :: new ( CoalescePartitionsExec :: new ( csv) ) ,
1125+ fetch,
1126+ ) ?) ;
1127+
1128+ let task_ctx = session_ctx. task_ctx ( ) ;
1129+ let result = collect ( sort_exec. clone ( ) , task_ctx) . await ?;
1130+ assert_eq ! ( result. len( ) , 1 ) ;
1131+
1132+ let metrics = sort_exec. metrics ( ) . unwrap ( ) ;
1133+ let did_it_spill = metrics. spill_count ( ) . unwrap ( ) > 0 ;
1134+ assert_eq ! ( did_it_spill, expect_spillage, "with fetch: {:?}" , fetch) ;
1135+ }
1136+ Ok ( ( ) )
1137+ }
1138+
10651139 #[ tokio:: test]
10661140 async fn test_sort_metadata ( ) -> Result < ( ) > {
10671141 let session_ctx = SessionContext :: new ( ) ;
0 commit comments