Skip to content

Commit 32725dd

Browse files
committed
add example for RecordBatch adaptation
1 parent 3db0b09 commit 32725dd

1 file changed

Lines changed: 98 additions & 3 deletions

File tree

datafusion/physical-expr/src/schema_rewriter.rs

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,16 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
198198

199199
#[cfg(test)]
200200
mod tests {
201-
use crate::expressions::lit;
201+
use crate::expressions::{col, lit};
202202

203203
use super::*;
204-
use arrow::datatypes::{DataType, Field, Schema};
205-
use datafusion_common::ScalarValue;
204+
use arrow::{
205+
array::{RecordBatch, RecordBatchOptions},
206+
datatypes::{DataType, Field, Schema, SchemaRef},
207+
};
208+
use datafusion_common::{record_batch, ScalarValue};
206209
use datafusion_expr::Operator;
210+
use itertools::Itertools;
207211
use std::sync::Arc;
208212

209213
fn create_test_schema() -> (Schema, Schema) {
@@ -369,4 +373,95 @@ mod tests {
369373
.to_string()
370374
.contains("Non-nullable column 'b' is missing"));
371375
}
376+
377+
/// Stolen from ProjectionExec
378+
fn batch_project(
379+
expr: Vec<Arc<dyn PhysicalExpr>>,
380+
batch: &RecordBatch,
381+
schema: SchemaRef,
382+
) -> Result<RecordBatch> {
383+
// Records time on drop
384+
let arrays = expr
385+
.iter()
386+
.map(|expr| {
387+
expr.evaluate(batch)
388+
.and_then(|v| v.into_array(batch.num_rows()))
389+
})
390+
.collect::<Result<Vec<_>>>()?;
391+
392+
if arrays.is_empty() {
393+
let options =
394+
RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
395+
RecordBatch::try_new_with_options(Arc::clone(&schema), arrays, &options)
396+
.map_err(Into::into)
397+
} else {
398+
RecordBatch::try_new(Arc::clone(&schema), arrays).map_err(Into::into)
399+
}
400+
}
401+
402+
/// Example showing how we can use the `PhysicalExprSchemaRewriter` to adapt RecordBatches during a scan
/// to apply projections, type conversions and handling of missing columns all at once.
#[test]
fn test_adapt_batches() {
    // Physical data as it exists on disk: `a` is Int32 and there is an
    // `extra` column the logical schema does not mention.
    let physical_batch = record_batch!(
        ("a", Int32, vec![Some(1), None, Some(3)]),
        ("extra", Utf8, vec![Some("x"), Some("y"), None])
    )
    .unwrap();

    let physical_schema = physical_batch.schema();

    // Logical (table) schema the query is written against.
    let logical_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, true), // Different type
        Field::new("b", DataType::Utf8, true),  // Missing from physical
    ]));

    // Projection expressed against the logical schema: (b, a).
    let logical_projection = vec![
        col("b", &logical_schema).unwrap(),
        col("a", &logical_schema).unwrap(),
    ];

    // Rewrite each expression so it can be evaluated directly against the
    // physical schema (inserting casts and missing-column fills).
    let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
    let adapted_projection: Vec<_> = logical_projection
        .into_iter()
        .map(|expr| rewriter.rewrite(expr).unwrap())
        .collect();

    // Derive the output schema from the rewritten expressions.
    let adapted_fields: Vec<_> = adapted_projection
        .iter()
        .map(|expr| expr.return_field(&physical_schema).unwrap())
        .collect();
    let adapted_schema = Arc::new(Schema::new(adapted_fields));

    let result = batch_project(
        adapted_projection,
        &physical_batch,
        Arc::clone(&adapted_schema),
    )
    .unwrap();

    // Two projected columns, carrying the logical types.
    assert_eq!(result.num_columns(), 2);
    assert_eq!(result.column(0).data_type(), &DataType::Utf8);
    assert_eq!(result.column(1).data_type(), &DataType::Int64);

    // `b` did not exist physically, so it was filled with nulls.
    let b_values = result
        .column(0)
        .as_any()
        .downcast_ref::<arrow::array::StringArray>()
        .unwrap();
    assert_eq!(
        b_values.iter().collect::<Vec<_>>(),
        vec![None, None, None]
    );

    // `a` keeps its data, cast from Int32 to Int64.
    let a_values = result
        .column(1)
        .as_any()
        .downcast_ref::<arrow::array::Int64Array>()
        .unwrap();
    assert_eq!(
        a_values.iter().collect::<Vec<_>>(),
        vec![Some(1), None, Some(3)]
    );
}
372467
}

0 commit comments

Comments (0)