From 2665166593a0f8ec55275c3c4a73e278e848a645 Mon Sep 17 00:00:00 2001 From: Thomas Tanon Date: Fri, 13 Mar 2026 17:43:43 +0100 Subject: [PATCH] Remove validate_pycapsule The Bound<'_, PyCapsule>::pointer_checked does the same validation and is already used across the codebase --- crates/core/src/array.rs | 6 +----- crates/core/src/catalog.rs | 11 +++-------- crates/core/src/context.rs | 21 ++++++--------------- crates/core/src/dataframe.rs | 7 ++----- crates/core/src/udaf.rs | 7 ++----- crates/core/src/udf.rs | 11 ++++------- crates/core/src/udtf.rs | 8 ++------ crates/core/src/udwf.rs | 11 ++++------- crates/util/src/lib.rs | 13 ++++--------- 9 files changed, 28 insertions(+), 67 deletions(-) diff --git a/crates/core/src/array.rs b/crates/core/src/array.rs index 99e63ef50..f284fa9de 100644 --- a/crates/core/src/array.rs +++ b/crates/core/src/array.rs @@ -22,8 +22,6 @@ use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{Field, FieldRef}; use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema}; use arrow::pyarrow::ToPyArrow; -use datafusion_python_util::validate_pycapsule; -use pyo3::ffi::c_str; use pyo3::prelude::{PyAnyMethods, PyCapsuleMethods}; use pyo3::types::PyCapsule; use pyo3::{Bound, PyAny, PyResult, Python, pyclass, pymethods}; @@ -53,10 +51,8 @@ impl PyArrowArrayExportable { requested_schema: Option>, ) -> PyDataFusionResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> { let field = if let Some(schema_capsule) = requested_schema { - validate_pycapsule(&schema_capsule, "arrow_schema")?; - let data: NonNull = schema_capsule - .pointer_checked(Some(c_str!("arrow_schema")))? + .pointer_checked(Some(c"arrow_schema"))? .cast(); let schema_ptr = unsafe { data.as_ref() }; let desired_field = Field::try_from(schema_ptr)?; diff --git a/crates/core/src/catalog.rs b/crates/core/src/catalog.rs index f707e7e5c..30ec4744c 100644 --- a/crates/core/src/catalog.rs +++ b/crates/core/src/catalog.rs @@ -31,12 +31,10 @@ use datafusion_ffi::catalog_provider::FFI_CatalogProvider; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::schema_provider::FFI_SchemaProvider; use datafusion_python_util::{ - create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, validate_pycapsule, - wait_for_future, + create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, wait_for_future, }; use pyo3::IntoPyObjectExt; use pyo3::exceptions::PyKeyError; -use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::types::PyCapsule; @@ -659,9 +657,8 @@ fn extract_catalog_provider_from_pyobj( } let provider = if let Ok(capsule) = catalog_provider.cast::() { - validate_pycapsule(capsule, "datafusion_catalog_provider")?; let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_catalog_provider")))? + .pointer_checked(Some(c"datafusion_catalog_provider"))? .cast(); let provider = unsafe { data.as_ref() }; let provider: Arc = provider.into(); @@ -692,10 +689,8 @@ fn extract_schema_provider_from_pyobj( } let provider = if let Ok(capsule) = schema_provider.cast::() { - validate_pycapsule(capsule, "datafusion_schema_provider")?; - let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_schema_provider")))? + .pointer_checked(Some(c"datafusion_schema_provider"))? .cast(); let provider = unsafe { data.as_ref() }; let provider: Arc = provider.into(); diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 200b6470b..00cab4be4 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -55,12 +55,11 @@ use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory; use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec; use datafusion_python_util::{ create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, get_global_ctx, - get_tokio_runtime, spawn_future, validate_pycapsule, wait_for_future, + get_tokio_runtime, spawn_future, wait_for_future, }; use object_store::ObjectStore; use pyo3::IntoPyObjectExt; use pyo3::exceptions::{PyKeyError, PyValueError}; -use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple}; use url::Url; @@ -675,10 +674,8 @@ impl PySessionContext { let factory: Arc = if let Ok(capsule) = factory.cast::().map_err(py_datafusion_err) { - validate_pycapsule(capsule, "datafusion_table_provider_factory")?; - let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_table_provider_factory")))? + .pointer_checked(Some(c"datafusion_table_provider_factory"))? .cast(); let factory = unsafe { data.as_ref() }; factory.into() @@ -709,12 +706,9 @@ impl PySessionContext { .call1((codec_capsule,))?; } - let provider = if let Ok(capsule) = provider.cast::().map_err(py_datafusion_err) - { - validate_pycapsule(capsule, "datafusion_catalog_provider_list")?; - + let provider = if let Ok(capsule) = provider.cast::() { let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_catalog_provider_list")))? + .pointer_checked(Some(c"datafusion_catalog_provider_list"))? .cast(); let provider = unsafe { data.as_ref() }; let provider: Arc = provider.into(); @@ -747,12 +741,9 @@ impl PySessionContext { .call1((codec_capsule,))?; } - let provider = if let Ok(capsule) = provider.cast::().map_err(py_datafusion_err) - { - validate_pycapsule(capsule, "datafusion_catalog_provider")?; - + let provider = if let Ok(capsule) = provider.cast::() { let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_catalog_provider")))? + .pointer_checked(Some(c"datafusion_catalog_provider"))? .cast(); let provider = unsafe { data.as_ref() }; let provider: Arc = provider.into(); diff --git a/crates/core/src/dataframe.rs b/crates/core/src/dataframe.rs index 29fc05ed3..72595ba81 100644 --- a/crates/core/src/dataframe.rs +++ b/crates/core/src/dataframe.rs @@ -41,12 +41,11 @@ use datafusion::logical_expr::SortExpr; use datafusion::logical_expr::dml::InsertOp; use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; use datafusion::prelude::*; -use datafusion_python_util::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future}; +use datafusion_python_util::{is_ipython_env, spawn_future, wait_for_future}; use futures::{StreamExt, TryStreamExt}; use parking_lot::Mutex; use pyo3::PyErr; use pyo3::exceptions::PyValueError; -use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods}; @@ -1117,10 +1116,8 @@ impl PyDataFrame { let mut projection: Option = None; if let Some(schema_capsule) = requested_schema { - validate_pycapsule(&schema_capsule, "arrow_schema")?; - let data: NonNull = schema_capsule - .pointer_checked(Some(c_str!("arrow_schema")))? + .pointer_checked(Some(c"arrow_schema"))? .cast(); let schema_ptr = unsafe { data.as_ref() }; let desired_schema = Schema::try_from(schema_ptr)?; diff --git a/crates/core/src/udaf.rs b/crates/core/src/udaf.rs index ed26c79cc..80ef51716 100644 --- a/crates/core/src/udaf.rs +++ b/crates/core/src/udaf.rs @@ -27,8 +27,7 @@ use datafusion::logical_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, create_udaf, }; use datafusion_ffi::udaf::FFI_AggregateUDF; -use datafusion_python_util::{parse_volatility, validate_pycapsule}; -use pyo3::ffi::c_str; +use datafusion_python_util::parse_volatility; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyTuple}; @@ -157,10 +156,8 @@ pub fn to_rust_accumulator(accum: Py) -> AccumulatorFactoryFunction { } fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult { - validate_pycapsule(capsule, "datafusion_aggregate_udf")?; - let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_aggregate_udf")))? + .pointer_checked(Some(c"datafusion_aggregate_udf"))? .cast(); let udaf = unsafe { data.as_ref() }; let udaf: Arc = udaf.into(); diff --git a/crates/core/src/udf.rs b/crates/core/src/udf.rs index 7543f96d4..c0a39cb47 100644 --- a/crates/core/src/udf.rs +++ b/crates/core/src/udf.rs @@ -32,13 +32,12 @@ use datafusion::logical_expr::{ Volatility, }; use datafusion_ffi::udf::FFI_ScalarUDF; -use datafusion_python_util::{parse_volatility, validate_pycapsule}; -use pyo3::ffi::c_str; +use datafusion_python_util::parse_volatility; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyTuple}; use crate::array::PyArrowArrayExportable; -use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err}; +use crate::errors::{PyDataFusionResult, to_datafusion_err}; use crate::expr::PyExpr; /// This struct holds the Python written function that is a @@ -194,11 +193,9 @@ impl PyScalarUDF { pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { if func.hasattr("__datafusion_scalar_udf__")? { let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?; - let capsule = capsule.cast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_scalar_udf")?; - + let capsule = capsule.cast::().map_err(to_datafusion_err)?; let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_scalar_udf")))? + .pointer_checked(Some(c"datafusion_scalar_udf"))? .cast(); let udf = unsafe { data.as_ref() }; let udf: Arc = udf.into(); diff --git a/crates/core/src/udtf.rs b/crates/core/src/udtf.rs index 77c5ffbbc..9371732dc 100644 --- a/crates/core/src/udtf.rs +++ b/crates/core/src/udtf.rs @@ -22,10 +22,8 @@ use datafusion::catalog::{TableFunctionImpl, TableProvider}; use datafusion::error::Result as DataFusionResult; use datafusion::logical_expr::Expr; use datafusion_ffi::udtf::FFI_TableFunction; -use datafusion_python_util::validate_pycapsule; use pyo3::IntoPyObjectExt; use pyo3::exceptions::{PyImportError, PyTypeError}; -use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyTuple, PyType}; @@ -73,11 +71,9 @@ impl PyTableFunction { err } })?; - let capsule = capsule.cast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_function")?; - + let capsule = capsule.cast::()?; let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_table_function")))? + .pointer_checked(Some(c"datafusion_table_function"))? .cast(); let ffi_func = unsafe { data.as_ref() }; let foreign_func: Arc = ffi_func.to_owned().into(); diff --git a/crates/core/src/udwf.rs b/crates/core/src/udwf.rs index ff7ab0352..1d3608ada 100644 --- a/crates/core/src/udwf.rs +++ b/crates/core/src/udwf.rs @@ -32,14 +32,13 @@ use datafusion::logical_expr::{ }; use datafusion::scalar::ScalarValue; use datafusion_ffi::udwf::FFI_WindowUDF; -use datafusion_python_util::{parse_volatility, validate_pycapsule}; +use datafusion_python_util::parse_volatility; use pyo3::exceptions::PyValueError; -use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyList, PyTuple}; use crate::common::data_type::PyScalarValue; -use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err}; +use crate::errors::{PyDataFusionResult, to_datafusion_err}; use crate::expr::PyExpr; #[derive(Debug)] @@ -262,11 +261,9 @@ impl PyWindowUDF { func }; - let capsule = capsule.cast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_window_udf")?; - + let capsule = capsule.cast::().map_err(to_datafusion_err)?; let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_window_udf")))? + .pointer_checked(Some(c"datafusion_window_udf"))? .cast(); let udwf = unsafe { data.as_ref() }; let udwf: Arc = udwf.into(); diff --git a/crates/util/src/lib.rs b/crates/util/src/lib.rs index 2678a6b9a..5b1c89936 100644 --- a/crates/util/src/lib.rs +++ b/crates/util/src/lib.rs @@ -26,14 +26,13 @@ use datafusion::logical_expr::Volatility; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; use pyo3::exceptions::{PyImportError, PyTypeError, PyValueError}; -use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyType}; use tokio::runtime::Runtime; use tokio::task::JoinHandle; use tokio::time::sleep; -use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err}; +use crate::errors::{PyDataFusionError, PyDataFusionResult, to_datafusion_err}; pub mod errors; @@ -186,11 +185,9 @@ pub fn table_provider_from_pycapsule<'py>( })?; } - if let Ok(capsule) = obj.cast::().map_err(py_datafusion_err) { - validate_pycapsule(capsule, "datafusion_table_provider")?; - + if let Ok(capsule) = obj.cast::() { let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_table_provider")))? + .pointer_checked(Some(c"datafusion_table_provider"))? .cast(); let provider = unsafe { data.as_ref() }; let provider: Arc = provider.into(); @@ -220,10 +217,8 @@ pub fn ffi_logical_codec_from_pycapsule(obj: Bound) -> PyResult()?; - validate_pycapsule(capsule, "datafusion_logical_extension_codec")?; - let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_logical_extension_codec")))? + .pointer_checked(Some(c"datafusion_logical_extension_codec"))? .cast(); let codec = unsafe { data.as_ref() };