From 2d8d558034eea0a88336fc23c5fa96c811cd5e0e Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Sat, 20 Apr 2024 09:12:40 -0400
Subject: [PATCH 1/2] Expose unnest feature

---
 datafusion/tests/test_dataframe.py | 34 ++++++++++++
 src/dataframe.rs                   | 12 +++++
 src/expr.rs                        |  2 +
 src/expr/unnest.rs                 | 85 ++++++++++++++++++++++++++++++
 src/sql/logical.rs                 |  2 +
 5 files changed, 135 insertions(+)
 create mode 100644 src/expr/unnest.rs

diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index c8c74fa2c..49774d89b 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -62,6 +62,20 @@ def struct_df():
     return ctx.create_dataframe([[batch]])
 
 
+@pytest.fixture
+def nested_df():
+    ctx = SessionContext()
+
+    # create a RecordBatch and a new DataFrame from it
+    # Intentionally make each array of different length
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([[1], [2, 3], [4, 5, 6], None]), pa.array([7, 8, 9, 10])],
+        names=["a", "b"],
+    )
+
+    return ctx.create_dataframe([[batch]])
+
+
 @pytest.fixture
 def aggregate_df():
     ctx = SessionContext()
@@ -160,6 +174,26 @@ def test_with_column_renamed(df):
     assert result.schema.field(2).name == "sum"
 
 
+def test_unnest(nested_df):
+    nested_df = nested_df.unnest("a")
+
+    # execute and collect the first (and only) batch
+    result = nested_df.collect()[0]
+
+    assert result.column(0) == pa.array([1, 2, 3, 4, 5, 6, None])
+    assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9, 10])
+
+
+def test_unnest_without_nulls(nested_df):
+    nested_df = nested_df.unnest("a", preserve_nulls=False)
+
+    # execute and collect the first (and only) batch
+    result = nested_df.collect()[0]
+
+    assert result.column(0) == pa.array([1, 2, 3, 4, 5, 6])
+    assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9])
+
+
 def test_udf(df):
     # is_null is a pa function over arrays
     is_null = udf(
diff --git a/src/dataframe.rs b/src/dataframe.rs
index a239a35f9..53769a463 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -25,6 +25,7 @@ use datafusion::execution::SendableRecordBatchStream;
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
 use datafusion::parquet::file::properties::WriterProperties;
 use datafusion::prelude::*;
+use datafusion_common::UnnestOptions;
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::types::PyTuple;
@@ -293,6 +294,17 @@ impl PyDataFrame {
         Ok(Self::new(new_df))
     }
 
+    #[pyo3(signature = (column, preserve_nulls=true))]
+    fn unnest(&self, column: &str, preserve_nulls: bool) -> PyResult<Self> {
+        let unnest_options = UnnestOptions { preserve_nulls };
+        let df = self
+            .df
+            .as_ref()
+            .clone()
+            .unnest_column_with_options(column, unnest_options)?;
+        Ok(Self::new(df))
+    }
+
     /// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
     fn intersect(&self, py_df: PyDataFrame) -> PyResult<Self> {
         let new_df = self
diff --git a/src/expr.rs b/src/expr.rs
index c0e7019f0..0958c4a98 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -89,6 +89,7 @@ pub mod subquery;
 pub mod subquery_alias;
 pub mod table_scan;
 pub mod union;
+pub mod unnest;
 pub mod window;
 
 /// A PyExpr that can be used on a DataFrame
@@ -684,6 +685,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::<unnest::PyUnnest>()?;
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
diff --git a/src/expr/unnest.rs b/src/expr/unnest.rs
new file mode 100644
index 000000000..33fb82feb
--- /dev/null
+++ b/src/expr/unnest.rs
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_expr::logical_plan::Unnest;
+use pyo3::prelude::*;
+use std::fmt::{self, Display, Formatter};
+
+use crate::common::df_schema::PyDFSchema;
+use crate::expr::logical_node::LogicalNode;
+use crate::sql::logical::PyLogicalPlan;
+
+#[pyclass(name = "Unnest", module = "datafusion.expr", subclass)]
+#[derive(Clone)]
+pub struct PyUnnest {
+    unnest_: Unnest,
+}
+
+impl From<Unnest> for PyUnnest {
+    fn from(unnest_: Unnest) -> PyUnnest {
+        PyUnnest { unnest_ }
+    }
+}
+
+impl From<PyUnnest> for Unnest {
+    fn from(unnest_: PyUnnest) -> Self {
+        unnest_.unnest_
+    }
+}
+
+impl Display for PyUnnest {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        write!(
+            f,
+            "Unnest
+            Inputs: {:?}
+            Schema: {:?}",
+            &self.unnest_.input, &self.unnest_.schema,
+        )
+    }
+}
+
+#[pymethods]
+impl PyUnnest {
+    /// Retrieves the input `LogicalPlan` to this `Unnest` node
+    fn input(&self) -> PyResult<Vec<PyLogicalPlan>> {
+        Ok(Self::inputs(self))
+    }
+
+    /// Resulting Schema for this `Unnest` node instance
+    fn schema(&self) -> PyResult<PyDFSchema> {
+        Ok(self.unnest_.schema.as_ref().clone().into())
+    }
+
+    fn __repr__(&self) -> PyResult<String> {
+        Ok(format!("Unnest({})", self))
+    }
+
+    fn __name__(&self) -> PyResult<String> {
+        Ok("Unnest".to_string())
+    }
+}
+
+impl LogicalNode for PyUnnest {
+    fn inputs(&self) -> Vec<PyLogicalPlan> {
+        vec![PyLogicalPlan::from((*self.unnest_.input).clone())]
+    }
+
+    fn to_variant(&self, py: Python) -> PyResult<PyObject> {
+        Ok(self.clone().into_py(py))
+    }
+}
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index 3aa8a699f..62515c3dd 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -33,6 +33,7 @@ use crate::expr::sort::PySort;
 use crate::expr::subquery::PySubquery;
 use crate::expr::subquery_alias::PySubqueryAlias;
 use crate::expr::table_scan::PyTableScan;
+use crate::expr::unnest::PyUnnest;
 use crate::expr::window::PyWindow;
 use datafusion_expr::LogicalPlan;
 use pyo3::prelude::*;
@@ -78,6 +79,7 @@ impl PyLogicalPlan {
             LogicalPlan::TableScan(plan) => PyTableScan::from(plan.clone()).to_variant(py),
             LogicalPlan::Subquery(plan) => PySubquery::from(plan.clone()).to_variant(py),
             LogicalPlan::SubqueryAlias(plan) => PySubqueryAlias::from(plan.clone()).to_variant(py),
+            LogicalPlan::Unnest(plan) => PyUnnest::from(plan.clone()).to_variant(py),
             LogicalPlan::Window(plan) => PyWindow::from(plan.clone()).to_variant(py),
             other => Err(py_unsupported_variant_err(format!(
                 "Cannot convert this plan to a LogicalNode: {:?}",

From 1a226f9e2d5c70eeea219a7d69fe18bd3b1a5a1d Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Sun, 21 Apr 2024 07:53:01 -0400
Subject: [PATCH 2/2] Update dataframe operation name to match rust implementation

---
 datafusion/tests/test_dataframe.py | 4 ++--
 src/dataframe.rs                   | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index 49774d89b..efb1679b9 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -175,7 +175,7 @@ def test_with_column_renamed(df):
 
 
 def test_unnest(nested_df):
-    nested_df = nested_df.unnest("a")
+    nested_df = nested_df.unnest_column("a")
 
     # execute and collect the first (and only) batch
     result = nested_df.collect()[0]
@@ -185,7 +185,7 @@ def test_unnest(nested_df):
 
 
 def test_unnest_without_nulls(nested_df):
-    nested_df = nested_df.unnest("a", preserve_nulls=False)
+    nested_df = nested_df.unnest_column("a", preserve_nulls=False)
 
     # execute and collect the first (and only) batch
     result = nested_df.collect()[0]
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 53769a463..a319b3d73 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -295,7 +295,7 @@ impl PyDataFrame {
     }
 
     #[pyo3(signature = (column, preserve_nulls=true))]
-    fn unnest(&self, column: &str, preserve_nulls: bool) -> PyResult<Self> {
+    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyResult<Self> {
         let unnest_options = UnnestOptions { preserve_nulls };
         let df = self
             .df