Skip to content

Commit 9262a5d

Browse files
committed
ARROW-11790: [Rust][DataFusion] Change builder signatures to take impl Iterator<Item=Expr> rather than &[Expr]
# NOTE: Since this is a fairly major backwards incompatible change (many callsites need to be updated, though mostly mechanically), I gathered some feedback on this approach in #9692 and this is the PR I propose for merge. I'll leave this open for several days and also send a note to the mailing lists for additional comment. It is part of my overall plan to make the DataFusion optimizer more idiomatic and do much less copying [ARROW-11689](https://issues.apache.org/jira/browse/ARROW-11689) # Rationale: All callsites currently need an owned `Vec` (or equivalent) so they can pass in `&[Expr]`, and then DataFusion copies all the `Expr`s. Many times the original `Vec<Expr>` is discarded immediately after use (I'll point out where this happens in a few places below). Thus I think it would be better (more idiomatic, and often less copying/faster) to take something that can produce an iterator over `Expr`. # Changes 1. Change `Dataframe` so it takes `Vec<Expr>` rather than `&[Expr]` 2. Change `LogicalPlanBuilder` so it takes `impl Iterator<Item=Expr>` rather than `&[Expr]` I couldn't figure out how to allow the `Dataframe` API (which is a Trait) to take an `impl Iterator<Item=Expr>`. Closes #9703 from alamb/alamb/less_copy_in_plan_builder_final Authored-by: Andrew Lamb <andrew@nerdnetworks.org> Signed-off-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 21483ad commit 9262a5d

19 files changed

Lines changed: 181 additions & 132 deletions

File tree

rust/benchmarks/src/bin/tpch.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1636,7 +1636,7 @@ mod tests {
16361636
.file_extension(".out");
16371637
let df = ctx.read_csv(&format!("{}/answers/q{}.out", path, n), options)?;
16381638
let df = df.select(
1639-
&get_answer_schema(n)
1639+
get_answer_schema(n)
16401640
.fields()
16411641
.iter()
16421642
.map(|field| {

rust/datafusion/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ async fn main() -> datafusion::error::Result<()> {
100100
let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?;
101101

102102
let df = df.filter(col("a").lt_eq(col("b")))?
103-
.aggregate(&[col("a")], &[min(col("b"))])?
103+
.aggregate(vec![col("a")], vec![min(col("b"))])?
104104
.limit(100)?;
105105

106106
// execute and print results

rust/datafusion/examples/simple_udaf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ async fn main() -> Result<()> {
148148
let df = ctx.table("t")?;
149149

150150
// perform the aggregation
151-
let df = df.aggregate(&[], &[geometric_mean.call(vec![col("a")])])?;
151+
let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?;
152152

153153
// note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.
154154

rust/datafusion/examples/simple_udf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ async fn main() -> Result<()> {
133133
let expr1 = pow.call(vec![col("a"), col("b")]);
134134

135135
// equivalent to `'SELECT pow(a, b), pow(a, b) AS pow1 FROM t'`
136-
let df = df.select(&[
136+
let df = df.select(vec![
137137
expr,
138138
// alias so that they have different column names
139139
expr1.alias("pow1"),

rust/datafusion/src/dataframe.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ use async_trait::async_trait;
4444
/// let mut ctx = ExecutionContext::new();
4545
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?;
4646
/// let df = df.filter(col("a").lt_eq(col("b")))?
47-
/// .aggregate(&[col("a")], &[min(col("b"))])?
47+
/// .aggregate(vec![col("a")], vec![min(col("b"))])?
4848
/// .limit(100)?;
4949
/// let results = df.collect();
5050
/// # Ok(())
@@ -75,11 +75,11 @@ pub trait DataFrame: Send + Sync {
7575
/// # fn main() -> Result<()> {
7676
/// let mut ctx = ExecutionContext::new();
7777
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?;
78-
/// let df = df.select(&[col("a") * col("b"), col("c")])?;
78+
/// let df = df.select(vec![col("a") * col("b"), col("c")])?;
7979
/// # Ok(())
8080
/// # }
8181
/// ```
82-
fn select(&self, expr: &[Expr]) -> Result<Arc<dyn DataFrame>>;
82+
fn select(&self, expr: Vec<Expr>) -> Result<Arc<dyn DataFrame>>;
8383

8484
/// Filter a DataFrame to only include rows that match the specified filter expression.
8585
///
@@ -105,17 +105,17 @@ pub trait DataFrame: Send + Sync {
105105
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?;
106106
///
107107
/// // The following use is the equivalent of "SELECT MIN(b) GROUP BY a"
108-
/// let _ = df.aggregate(&[col("a")], &[min(col("b"))])?;
108+
/// let _ = df.aggregate(vec![col("a")], vec![min(col("b"))])?;
109109
///
110110
/// // The following use is the equivalent of "SELECT MIN(b)"
111-
/// let _ = df.aggregate(&[], &[min(col("b"))])?;
111+
/// let _ = df.aggregate(vec![], vec![min(col("b"))])?;
112112
/// # Ok(())
113113
/// # }
114114
/// ```
115115
fn aggregate(
116116
&self,
117-
group_expr: &[Expr],
118-
aggr_expr: &[Expr],
117+
group_expr: Vec<Expr>,
118+
aggr_expr: Vec<Expr>,
119119
) -> Result<Arc<dyn DataFrame>>;
120120

121121
/// Limit the number of rows returned from this DataFrame.
@@ -155,11 +155,11 @@ pub trait DataFrame: Send + Sync {
155155
/// # fn main() -> Result<()> {
156156
/// let mut ctx = ExecutionContext::new();
157157
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?;
158-
/// let df = df.sort(&[col("a").sort(true, true), col("b").sort(false, false)])?;
158+
/// let df = df.sort(vec![col("a").sort(true, true), col("b").sort(false, false)])?;
159159
/// # Ok(())
160160
/// # }
161161
/// ```
162-
fn sort(&self, expr: &[Expr]) -> Result<Arc<dyn DataFrame>>;
162+
fn sort(&self, expr: Vec<Expr>) -> Result<Arc<dyn DataFrame>>;
163163

164164
/// Join this DataFrame with another DataFrame using the specified columns as join keys
165165
///
@@ -171,7 +171,7 @@ pub trait DataFrame: Send + Sync {
171171
/// let mut ctx = ExecutionContext::new();
172172
/// let left = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?;
173173
/// let right = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?
174-
/// .select(&[
174+
/// .select(vec![
175175
/// col("a").alias("a2"),
176176
/// col("b").alias("b2"),
177177
/// col("c").alias("c2")])?;

rust/datafusion/src/execution/context.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ use parquet::file::properties::WriterProperties;
8282
/// let mut ctx = ExecutionContext::new();
8383
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?;
8484
/// let df = df.filter(col("a").lt_eq(col("b")))?
85-
/// .aggregate(&[col("a")], &[min(col("b"))])?
85+
/// .aggregate(vec![col("a")], vec![min(col("b"))])?
8686
/// .limit(100)?;
8787
/// let results = df.collect();
8888
/// # Ok(())
@@ -954,7 +954,7 @@ mod tests {
954954

955955
let table = ctx.table("test")?;
956956
let logical_plan = LogicalPlanBuilder::from(&table.to_logical_plan())
957-
.project(&[col("c2")])?
957+
.project(vec![col("c2")])?
958958
.build()?;
959959

960960
let optimized_plan = ctx.optimize(&logical_plan)?;
@@ -999,7 +999,7 @@ mod tests {
999999
assert_eq!(schema.field_with_name("c1")?.is_nullable(), false);
10001000

10011001
let plan = LogicalPlanBuilder::scan_empty("", &schema, None)?
1002-
.project(&[col("c1")])?
1002+
.project(vec![col("c1")])?
10031003
.build()?;
10041004

10051005
let plan = ctx.optimize(&plan)?;
@@ -1030,7 +1030,7 @@ mod tests {
10301030
)?]];
10311031

10321032
let plan = LogicalPlanBuilder::scan_memory(partitions, schema, None)?
1033-
.project(&[col("b")])?
1033+
.project(vec![col("b")])?
10341034
.build()?;
10351035
assert_fields_eq(&plan, vec!["b"]);
10361036

@@ -1660,8 +1660,8 @@ mod tests {
16601660
]));
16611661

16621662
let plan = LogicalPlanBuilder::scan_empty("", schema.as_ref(), None)?
1663-
.aggregate(&[col("c1")], &[sum(col("c2"))])?
1664-
.project(&[col("c1"), col("SUM(c2)").alias("total_salary")])?
1663+
.aggregate(vec![col("c1")], vec![sum(col("c2"))])?
1664+
.project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])?
16651665
.build()?;
16661666

16671667
let plan = ctx.optimize(&plan)?;
@@ -1886,7 +1886,7 @@ mod tests {
18861886
let t = ctx.table("t")?;
18871887

18881888
let plan = LogicalPlanBuilder::from(&t.to_logical_plan())
1889-
.project(&[
1889+
.project(vec![
18901890
col("a"),
18911891
col("b"),
18921892
ctx.udf("my_add")?.call(vec![col("a"), col("b")]),

rust/datafusion/src/execution/dataframe_impl.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ impl DataFrame for DataFrameImpl {
5858
.map(|name| self.plan.schema().field_with_unqualified_name(name))
5959
.collect::<Result<Vec<_>>>()?;
6060
let expr: Vec<Expr> = fields.iter().map(|f| col(f.name())).collect();
61-
self.select(&expr)
61+
self.select(expr)
6262
}
6363

6464
/// Create a projection based on arbitrary expressions
65-
fn select(&self, expr_list: &[Expr]) -> Result<Arc<dyn DataFrame>> {
65+
fn select(&self, expr_list: Vec<Expr>) -> Result<Arc<dyn DataFrame>> {
6666
let plan = LogicalPlanBuilder::from(&self.plan)
6767
.project(expr_list)?
6868
.build()?;
@@ -80,8 +80,8 @@ impl DataFrame for DataFrameImpl {
8080
/// Perform an aggregate query
8181
fn aggregate(
8282
&self,
83-
group_expr: &[Expr],
84-
aggr_expr: &[Expr],
83+
group_expr: Vec<Expr>,
84+
aggr_expr: Vec<Expr>,
8585
) -> Result<Arc<dyn DataFrame>> {
8686
let plan = LogicalPlanBuilder::from(&self.plan)
8787
.aggregate(group_expr, aggr_expr)?
@@ -96,7 +96,7 @@ impl DataFrame for DataFrameImpl {
9696
}
9797

9898
/// Sort by specified sorting expressions
99-
fn sort(&self, expr: &[Expr]) -> Result<Arc<dyn DataFrame>> {
99+
fn sort(&self, expr: Vec<Expr>) -> Result<Arc<dyn DataFrame>> {
100100
let plan = LogicalPlanBuilder::from(&self.plan).sort(expr)?.build()?;
101101
Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
102102
}
@@ -204,7 +204,7 @@ mod tests {
204204
fn select_expr() -> Result<()> {
205205
// build plan using Table API
206206
let t = test_table()?;
207-
let t2 = t.select(&[col("c1"), col("c2"), col("c11")])?;
207+
let t2 = t.select(vec![col("c1"), col("c2"), col("c11")])?;
208208
let plan = t2.to_logical_plan();
209209

210210
// build query using SQL
@@ -220,8 +220,8 @@ mod tests {
220220
fn aggregate() -> Result<()> {
221221
// build plan using DataFrame API
222222
let df = test_table()?;
223-
let group_expr = &[col("c1")];
224-
let aggr_expr = &[
223+
let group_expr = vec![col("c1")];
224+
let aggr_expr = vec![
225225
min(col("c12")),
226226
max(col("c12")),
227227
avg(col("c12")),
@@ -322,7 +322,7 @@ mod tests {
322322

323323
let f = df.registry();
324324

325-
let df = df.select(&[f.udf("my_fn")?.call(vec![col("c12")])])?;
325+
let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?;
326326
let plan = df.to_logical_plan();
327327

328328
// build query using SQL

rust/datafusion/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
//!
4949
//! // create a plan
5050
//! let df = df.filter(col("a").lt_eq(col("b")))?
51-
//! .aggregate(&[col("a")], &[min(col("b"))])?
51+
//! .aggregate(vec![col("a")], vec![min(col("b"))])?
5252
//! .limit(100)?;
5353
//!
5454
//! // execute the plan

0 commit comments

Comments
 (0)