diff --git a/Cargo.lock b/Cargo.lock index 2622625c7..db7c9a768 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -899,6 +899,8 @@ dependencies = [ "mimalloc", "object_store", "parking_lot", + "prost", + "prost-types", "pyo3", "pyo3-build-config", "rand", diff --git a/Cargo.toml b/Cargo.toml index 8c205a4cd..4c5203d69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,8 @@ datafusion-expr = "23.0.0" datafusion-optimizer = "23.0.0" datafusion-sql = "23.0.0" datafusion-substrait = "23.0.0" +prost = "0.11" +prost-types = "0.11" uuid = { version = "1.2", features = ["v4"] } mimalloc = { version = "0.1", optional = true, default-features = false } async-trait = "0.1" diff --git a/datafusion/tests/test_substrait.py b/datafusion/tests/test_substrait.py index 9a08b760e..01df2d746 100644 --- a/datafusion/tests/test_substrait.py +++ b/datafusion/tests/test_substrait.py @@ -41,6 +41,8 @@ def test_substrait_serialization(ctx): substrait_plan = ss.substrait.serde.serialize_to_plan( "SELECT * FROM t", ctx ) + substrait_bytes = substrait_plan.encode() + assert type(substrait_bytes) is bytes substrait_bytes = ss.substrait.serde.serialize_bytes( "SELECT * FROM t", ctx ) diff --git a/examples/substrait.py b/examples/substrait.py index c167f7d90..515311d85 100644 --- a/examples/substrait.py +++ b/examples/substrait.py @@ -32,8 +32,13 @@ ) # type(substrait_plan) -> +# Encode it to bytes +substrait_bytes = substrait_plan.encode() +# type(substrait_bytes) -> , at this point the bytes can be distributed to file, network, etc safely +# where they could subsequently be deserialized on the receiving end. + # Alternative serialization approaches -# type(substrait_bytes) -> , at this point the bytes can be distributed to file, network, etc safely +# type(substrait_bytes) -> , at this point the bytes can be distributed to file, network, etc safely # where they could subsequently be deserialized on the receiving end. substrait_bytes = ss.substrait.serde.serialize_bytes( "SELECT * FROM aggregate_test_data", ctx diff --git a/src/errors.rs b/src/errors.rs index ce6b3c28a..d12b6ade1 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -21,6 +21,7 @@ use std::fmt::Debug; use datafusion::arrow::error::ArrowError; use datafusion::error::DataFusionError as InnerDataFusionError; +use prost::EncodeError; use pyo3::{exceptions::PyException, PyErr}; pub type Result = std::result::Result; @@ -31,6 +32,7 @@ pub enum DataFusionError { ArrowError(ArrowError), Common(String), PythonError(PyErr), + EncodeError(EncodeError), } impl fmt::Display for DataFusionError { @@ -40,6 +42,7 @@ impl fmt::Display for DataFusionError { DataFusionError::ArrowError(e) => write!(f, "Arrow error: {e:?}"), DataFusionError::PythonError(e) => write!(f, "Python error {e:?}"), DataFusionError::Common(e) => write!(f, "{e}"), + DataFusionError::EncodeError(e) => write!(f, "Failed to encode substrait plan: {e}"), } } } diff --git a/src/substrait.rs b/src/substrait.rs index 2bde01123..361e9ac4b 100644 --- a/src/substrait.rs +++ b/src/substrait.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use pyo3::prelude::*; +use pyo3::{prelude::*, types::PyBytes}; use crate::context::PySessionContext; use crate::errors::{py_datafusion_err, DataFusionError}; @@ -25,6 +25,7 @@ use crate::utils::wait_for_future; use datafusion_substrait::logical_plan::{consumer, producer}; use datafusion_substrait::serializer; use datafusion_substrait::substrait::proto::Plan; +use prost::Message; #[pyclass(name = "plan", module = "datafusion.substrait", subclass, unsendable)] #[derive(Debug, Clone)] @@ -32,6 +33,17 @@ pub(crate) struct PyPlan { pub(crate) plan: Plan, } +#[pymethods] +impl PyPlan { + fn encode(&self, py: Python) -> PyResult { + let mut proto_bytes = Vec::::new(); + self.plan + .encode(&mut proto_bytes) + .map_err(|e| DataFusionError::EncodeError(e))?; + Ok(PyBytes::new(py, &proto_bytes).into()) + } +} + impl From for Plan { fn from(plan: PyPlan) -> Plan { plan.plan @@ -63,16 +75,19 @@ impl PySubstraitSerializer { #[staticmethod] pub fn serialize_to_plan(sql: &str, ctx: PySessionContext, py: Python) -> PyResult { match PySubstraitSerializer::serialize_bytes(sql, ctx, py) { - Ok(proto_bytes) => PySubstraitSerializer::deserialize_bytes(proto_bytes, py), + Ok(proto_bytes) => { + let proto_bytes: &PyBytes = proto_bytes.as_ref(py).downcast().unwrap(); + PySubstraitSerializer::deserialize_bytes(proto_bytes.as_bytes().to_vec(), py) + } Err(e) => Err(py_datafusion_err(e)), } } #[staticmethod] - pub fn serialize_bytes(sql: &str, ctx: PySessionContext, py: Python) -> PyResult> { + pub fn serialize_bytes(sql: &str, ctx: PySessionContext, py: Python) -> PyResult { let proto_bytes: Vec = wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx)) .map_err(DataFusionError::from)?; - Ok(proto_bytes) + Ok(PyBytes::new(py, &proto_bytes).into()) } #[staticmethod] @@ -136,6 +151,7 @@ impl PySubstraitConsumer { } pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?;