Skip to content

Commit 690b0b1

Browse files
committed
Only update TopK dynamic filters if the new ones are more selective
1 parent 7d52145 commit 690b0b1

2 files changed

Lines changed: 137 additions & 51 deletions

File tree

datafusion/physical-plan/src/sorts/sort.rs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ use crate::spill::get_record_batch_memory_size;
4040
use crate::spill::in_progress_spill_file::InProgressSpillFile;
4141
use crate::spill::spill_manager::{GetSlicedSize, SpillManager};
4242
use crate::stream::RecordBatchStreamAdapter;
43-
use crate::topk::TopK;
43+
use crate::topk::{TopK, TopKDynamicFilters};
4444
use crate::{
4545
DisplayAs, DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan,
4646
ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream,
@@ -858,8 +858,10 @@ pub struct SortExec {
858858
common_sort_prefix: Vec<PhysicalSortExpr>,
859859
/// Cache holding plan properties like equivalences, output partitioning etc.
860860
cache: PlanProperties,
861-
/// Filter matching the state of the sort for dynamic filter pushdown
862-
filter: Option<Arc<DynamicFilterPhysicalExpr>>,
861+
/// Filter matching the state of the sort for dynamic filter pushdown.
862+
/// If `fetch` is `Some`, this will also be set and a TopK operator may be used.
863+
/// If `fetch` is `None`, this will be `None`.
864+
filter: Option<TopKDynamicFilters>,
863865
}
864866

865867
impl SortExec {
@@ -905,14 +907,14 @@ impl SortExec {
905907
self
906908
}
907909

908-
/// Add or reset `self.filter` to a new `DynamicFilterPhysicalExpr`.
909-
fn create_filter(&self) -> Arc<DynamicFilterPhysicalExpr> {
910+
/// Add or reset `self.filter` to a new `TopKDynamicFilters`.
911+
fn create_filter(&self) -> TopKDynamicFilters {
910912
let children = self
911913
.expr
912914
.iter()
913915
.map(|sort_expr| Arc::clone(&sort_expr.expr))
914916
.collect::<Vec<_>>();
915-
Arc::new(DynamicFilterPhysicalExpr::new(children, lit(true)))
917+
TopKDynamicFilters::new(Arc::new(DynamicFilterPhysicalExpr::new(children, lit(true))))
916918
}
917919

918920
fn cloned(&self) -> Self {
@@ -1051,7 +1053,7 @@ impl DisplayAs for SortExec {
10511053
Some(fetch) => {
10521054
write!(f, "SortExec: TopK(fetch={fetch}), expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr)?;
10531055
if let Some(filter) = &self.filter {
1054-
if let Ok(current) = filter.current() {
1056+
if let Ok(current) = filter.expr().current() {
10551057
if !current.eq(&lit(true)) {
10561058
write!(f, ", filter=[{current}]")?;
10571059
}
@@ -1196,7 +1198,10 @@ impl ExecutionPlan for SortExec {
11961198
context.session_config().batch_size(),
11971199
context.runtime_env(),
11981200
&self.metrics_set,
1199-
self.filter.clone(),
1201+
self.filter
1202+
.as_ref()
1203+
.expect("Filter should be set when fetch is Some")
1204+
.clone(),
12001205
)?;
12011206
Ok(Box::pin(RecordBatchStreamAdapter::new(
12021207
self.schema(),
@@ -1320,8 +1325,7 @@ impl ExecutionPlan for SortExec {
13201325

13211326
if let Some(filter) = &self.filter {
13221327
if config.optimizer.enable_dynamic_filter_pushdown {
1323-
child =
1324-
child.with_self_filter(Arc::clone(filter) as Arc<dyn PhysicalExpr>);
1328+
child = child.with_self_filter(filter.expr());
13251329
}
13261330
}
13271331

datafusion/physical-plan/src/topk/mod.rs

Lines changed: 123 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use arrow::{
2323
row::{RowConverter, Rows, SortField},
2424
};
2525
use datafusion_expr::{ColumnarValue, Operator};
26+
use parking_lot::RwLock;
2627
use std::mem::size_of;
2728
use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};
2829

@@ -121,13 +122,37 @@ pub struct TopK {
121122
/// Common sort prefix between the input and the sort expressions to allow early exit optimization
122123
common_sort_prefix: Arc<[PhysicalSortExpr]>,
123124
/// Filter matching the state of the `TopK` heap used for dynamic filter pushdown
124-
filter: Option<Arc<DynamicFilterPhysicalExpr>>,
125+
filter: TopKDynamicFilters,
125126
/// If true, indicates that all rows of subsequent batches are guaranteed
126127
/// to be greater (by byte order, after row conversion) than the top K,
127128
/// which means the top K won't change and the computation can be finished early.
128129
pub(crate) finished: bool,
129130
}
130131

132+
#[derive(Debug, Clone)]
133+
pub struct TopKDynamicFilters {
134+
/// The current *global* threshold for the dynamic filter.
135+
/// This is shared across all partitions and is updated by any of them.
136+
/// Stored as row bytes for efficient comparison.
137+
threshold_row: Arc<RwLock<Option<Vec<u8>>>>,
138+
/// The expression used to evaluate the dynamic filter
139+
expr: Arc<DynamicFilterPhysicalExpr>,
140+
}
141+
142+
impl TopKDynamicFilters {
143+
/// Create a new `TopKDynamicFilters` with the given expression
144+
pub fn new(expr: Arc<DynamicFilterPhysicalExpr>) -> Self {
145+
Self {
146+
threshold_row: Arc::new(RwLock::new(None)),
147+
expr,
148+
}
149+
}
150+
151+
pub fn expr(&self) -> Arc<DynamicFilterPhysicalExpr> {
152+
Arc::clone(&self.expr)
153+
}
154+
}
155+
131156
// Guesstimate for memory allocation: estimated number of bytes used per row in the RowConverter
132157
const ESTIMATED_BYTES_PER_ROW: usize = 20;
133158

@@ -160,7 +185,7 @@ impl TopK {
160185
batch_size: usize,
161186
runtime: Arc<RuntimeEnv>,
162187
metrics: &ExecutionPlanMetricsSet,
163-
filter: Option<Arc<DynamicFilterPhysicalExpr>>,
188+
filter: TopKDynamicFilters,
164189
) -> Result<Self> {
165190
let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]"))
166191
.register(&runtime.memory_pool);
@@ -214,41 +239,39 @@ impl TopK {
214239

215240
let mut selected_rows = None;
216241

217-
if let Some(filter) = self.filter.as_ref() {
218-
// If a filter is provided, update it with the new rows
219-
let filter = filter.current()?;
220-
let filtered = filter.evaluate(&batch)?;
221-
let num_rows = batch.num_rows();
222-
let array = filtered.into_array(num_rows)?;
223-
let mut filter = array.as_boolean().clone();
224-
let true_count = filter.true_count();
225-
if true_count == 0 {
226-
// nothing to filter, so no need to update
227-
return Ok(());
242+
// If a filter is provided, update it with the new rows
243+
let filter = self.filter.expr.current()?;
244+
let filtered = filter.evaluate(&batch)?;
245+
let num_rows = batch.num_rows();
246+
let array = filtered.into_array(num_rows)?;
247+
let mut filter = array.as_boolean().clone();
248+
let true_count = filter.true_count();
249+
if true_count == 0 {
250+
// nothing to filter, so no need to update
251+
return Ok(());
252+
}
253+
// only update the keys / rows if the filter does not match all rows
254+
if true_count < num_rows {
255+
// Indices in `set_indices` should be correct if filter contains nulls
256+
// So we prepare the filter here. Note this is also done in the `FilterBuilder`
257+
// so there is no overhead to do this here.
258+
if filter.nulls().is_some() {
259+
filter = prep_null_mask_filter(&filter);
228260
}
229-
// only update the keys / rows if the filter does not match all rows
230-
if true_count < num_rows {
231-
// Indices in `set_indices` should be correct if filter contains nulls
232-
// So we prepare the filter here. Note this is also done in the `FilterBuilder`
233-
// so there is no overhead to do this here.
234-
if filter.nulls().is_some() {
235-
filter = prep_null_mask_filter(&filter);
236-
}
237261

238-
let filter_predicate = FilterBuilder::new(&filter);
239-
let filter_predicate = if sort_keys.len() > 1 {
240-
// Optimize filter when it has multiple sort keys
241-
filter_predicate.optimize().build()
242-
} else {
243-
filter_predicate.build()
244-
};
245-
selected_rows = Some(filter);
246-
sort_keys = sort_keys
247-
.iter()
248-
.map(|key| filter_predicate.filter(key).map_err(|x| x.into()))
249-
.collect::<Result<Vec<_>>>()?;
250-
}
251-
};
262+
let filter_predicate = FilterBuilder::new(&filter);
263+
let filter_predicate = if sort_keys.len() > 1 {
264+
// Optimize filter when it has multiple sort keys
265+
filter_predicate.optimize().build()
266+
} else {
267+
filter_predicate.build()
268+
};
269+
selected_rows = Some(filter);
270+
sort_keys = sort_keys
271+
.iter()
272+
.map(|key| filter_predicate.filter(key).map_err(|x| x.into()))
273+
.collect::<Result<Vec<_>>>()?;
274+
}
252275
// reuse existing `Rows` to avoid reallocations
253276
let rows = &mut self.scratch_rows;
254277
rows.clear();
@@ -319,19 +342,75 @@ impl TopK {
319342
/// (a > 2 OR (a = 2 AND b < 3))
320343
/// ```
321344
fn update_filter(&mut self) -> Result<()> {
322-
let Some(filter) = &self.filter else {
345+
// If the heap doesn't have k elements yet, we can't create thresholds
346+
let Some(max_row) = self.heap.max() else {
323347
return Ok(());
324348
};
325-
let Some(thresholds) = self.heap.get_threshold_values(&self.expr)? else {
326-
return Ok(());
349+
350+
let new_threshold_row = &max_row.row;
351+
352+
// Extract filter expression reference before entering critical section
353+
let filter_expr = Arc::clone(&self.filter.expr);
354+
355+
// Check if we need to update and do both threshold and filter update atomically
356+
{
357+
let mut threshold_guard = self.filter.threshold_row.write();
358+
if let Some(current_row) = threshold_guard.as_ref() {
359+
match current_row.as_slice().cmp(new_threshold_row) {
360+
Ordering::Greater => {
361+
// new < current, so new threshold is more selective
362+
// Update threshold and filter atomically to prevent race conditions
363+
*threshold_guard = Some(new_threshold_row.to_vec());
364+
365+
// Extract scalar values for filter expression creation
366+
let thresholds =
367+
match self.heap.get_threshold_values(&self.expr)? {
368+
Some(t) => t,
369+
None => return Ok(()),
370+
};
371+
372+
// Update the filter expression while still holding the lock
373+
Self::update_filter_expression(
374+
&filter_expr,
375+
&self.expr,
376+
thresholds,
377+
)?;
378+
}
379+
_ => {
380+
// Same threshold or current is more selective, no need to update
381+
}
382+
}
383+
} else {
384+
// No current thresholds, so update with the new ones
385+
*threshold_guard = Some(new_threshold_row.to_vec());
386+
387+
// Extract scalar values for filter expression creation
388+
let thresholds = match self.heap.get_threshold_values(&self.expr)? {
389+
Some(t) => t,
390+
None => return Ok(()),
391+
};
392+
393+
// Update the filter expression while still holding the lock
394+
Self::update_filter_expression(&filter_expr, &self.expr, thresholds)?;
395+
}
327396
};
328397

398+
Ok(())
399+
}
400+
401+
/// Update the filter expression with the given thresholds.
402+
/// This should only be called while holding the threshold lock.
403+
fn update_filter_expression(
404+
filter_expr: &DynamicFilterPhysicalExpr,
405+
sort_exprs: &[PhysicalSortExpr],
406+
thresholds: Vec<ScalarValue>,
407+
) -> Result<()> {
329408
// Create filter expressions for each threshold
330409
let mut filters: Vec<Arc<dyn PhysicalExpr>> =
331410
Vec::with_capacity(thresholds.len());
332411

333412
let mut prev_sort_expr: Option<Arc<dyn PhysicalExpr>> = None;
334-
for (sort_expr, value) in self.expr.iter().zip(thresholds.iter()) {
413+
for (sort_expr, value) in sort_exprs.iter().zip(thresholds.iter()) {
335414
// Create the appropriate operator based on sort order
336415
let op = if sort_expr.options.descending {
337416
// For descending sort, we want col > threshold (exclude smaller values)
@@ -405,7 +484,7 @@ impl TopK {
405484

406485
if let Some(predicate) = dynamic_predicate {
407486
if !predicate.eq(&lit(true)) {
408-
filter.update(predicate)?;
487+
filter_expr.update(predicate)?;
409488
}
410489
}
411490

@@ -1053,7 +1132,10 @@ mod tests {
10531132
2,
10541133
runtime,
10551134
&metrics,
1056-
None,
1135+
TopKDynamicFilters::new(Arc::new(DynamicFilterPhysicalExpr::new(
1136+
vec![],
1137+
lit(true),
1138+
))),
10571139
)?;
10581140

10591141
// Create the first batch with two columns:

0 commit comments

Comments
 (0)