Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions datafusion/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,37 @@ def test_from_arrow_table_with_name(ctx):
assert tables[0] == "tbl"


def test_from_arrow_table_empty(ctx):
    """An empty Arrow table built with an explicit schema converts to a DataFrame.

    The table has columns ``a`` (int32) and ``b`` (string) and zero rows.
    The test checks that the context registers exactly one table, that the
    resulting DataFrame exposes both column names, and that collecting it
    yields no record batches.
    """
    explicit_schema = pa.schema([("a", pa.int32()), ("b", pa.string())])
    empty_table = pa.Table.from_pydict({"a": [], "b": []}, schema=explicit_schema)

    # Convert to a DataFrame, then snapshot the tables registered on the context.
    df = ctx.from_arrow_table(empty_table)
    registered = list(ctx.tables())

    assert df
    assert len(registered) == 1
    assert isinstance(df, DataFrame)

    column_names = set(df.schema().names)
    assert column_names == {"a", "b"}

    collected = df.collect()
    assert len(collected) == 0


def test_from_arrow_table_empty_no_schema(ctx):
    """An empty Arrow table built without an explicit schema converts to a DataFrame.

    The table is created from two empty columns, ``a`` and ``b``, with no
    ``schema`` argument. The test checks that the context registers exactly
    one table, that the resulting DataFrame exposes both column names, and
    that collecting it yields no record batches.
    """
    empty_table = pa.Table.from_pydict({"a": [], "b": []})

    # Convert to a DataFrame, then snapshot the tables registered on the context.
    df = ctx.from_arrow_table(empty_table)
    registered = list(ctx.tables())

    assert df
    assert len(registered) == 1
    assert isinstance(df, DataFrame)

    column_names = set(df.schema().names)
    assert column_names == {"a", "b"}

    collected = df.collect()
    assert len(collected) == 0


def test_from_pylist(ctx):
# create a dataframe from Python list
data = [
Expand Down
15 changes: 12 additions & 3 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::store::StorageContexts;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::utils::{get_tokio_runtime, wait_for_future};
use datafusion::arrow::datatypes::{DataType, Schema};
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
Expand Down Expand Up @@ -289,9 +289,15 @@ impl PySessionContext {
&mut self,
partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
name: Option<&str>,
schema: Option<PyArrowType<Schema>>,
py: Python,
) -> PyResult<PyDataFrame> {
let schema = partitions.0[0][0].schema();
let schema = if let Some(schema) = schema {
SchemaRef::from(schema.0)
} else {
partitions.0[0][0].schema()
};

let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?;

// generate a random (unique) name for this table if none is provided
Expand Down Expand Up @@ -373,12 +379,15 @@ impl PySessionContext {
// Instantiate pyarrow Table object & convert to batches
let table = data.call_method0(py, "to_batches")?;

let schema = data.getattr(py, "schema")?;
let schema = schema.extract::<PyArrowType<Schema>>(py)?;

// Extract the Python object into a vector of record batches.
// create_dataframe() expects a vector of vectors of record batches,
// so wrap the extracted batches in an additional outer vector.
let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>(py)?;
let list_of_batches = PyArrowType::from(vec![batches.0]);
self.create_dataframe(list_of_batches, name, py)
self.create_dataframe(list_of_batches, name, Some(schema), py)
})
}

Expand Down