diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py index d48bdd929..499069ebb 100644 --- a/datafusion/tests/test_context.py +++ b/datafusion/tests/test_context.py @@ -138,6 +138,37 @@ def test_from_arrow_table_with_name(ctx): assert tables[0] == "tbl" +def test_from_arrow_table_empty(ctx): + data = {"a": [], "b": []} + schema = pa.schema([("a", pa.int32()), ("b", pa.string())]) + table = pa.Table.from_pydict(data, schema=schema) + + # convert to DataFrame + df = ctx.from_arrow_table(table) + tables = list(ctx.tables()) + + assert df + assert len(tables) == 1 + assert isinstance(df, DataFrame) + assert set(df.schema().names) == {"a", "b"} + assert len(df.collect()) == 0 + + +def test_from_arrow_table_empty_no_schema(ctx): + data = {"a": [], "b": []} + table = pa.Table.from_pydict(data) + + # convert to DataFrame + df = ctx.from_arrow_table(table) + tables = list(ctx.tables()) + + assert df + assert len(tables) == 1 + assert isinstance(df, DataFrame) + assert set(df.schema().names) == {"a", "b"} + assert len(df.collect()) == 0 + + def test_from_pylist(ctx): # create a dataframe from Python list data = [ diff --git a/src/context.rs b/src/context.rs index f34fbce81..fe915c56b 100644 --- a/src/context.rs +++ b/src/context.rs @@ -39,7 +39,7 @@ use crate::store::StorageContexts; use crate::udaf::PyAggregateUDF; use crate::udf::PyScalarUDF; use crate::utils::{get_tokio_runtime, wait_for_future}; -use datafusion::arrow::datatypes::{DataType, Schema}; +use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; @@ -289,9 +289,15 @@ impl PySessionContext { &mut self, partitions: PyArrowType>>, name: Option<&str>, + schema: Option>, py: Python, ) -> PyResult { - let schema = partitions.0[0][0].schema(); + let schema = if let Some(schema) = schema { + SchemaRef::from(schema.0) + } else { + partitions.0[0][0].schema() + }; + let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?; // generate a random (unique) name for this table if none is provided @@ -373,12 +379,15 @@ impl PySessionContext { // Instantiate pyarrow Table object & convert to batches let table = data.call_method0(py, "to_batches")?; + let schema = data.getattr(py, "schema")?; + let schema = schema.extract::>(py)?; + // Cast PyObject to RecordBatch type // Because create_dataframe() expects a vector of vectors of record batches // here we need to wrap the vector of record batches in an additional vector let batches = table.extract::>>(py)?; let list_of_batches = PyArrowType::from(vec![batches.0]); - self.create_dataframe(list_of_batches, name, py) + self.create_dataframe(list_of_batches, name, Some(schema), py) }) }