Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 134 additions & 121 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

from abc import ABC, abstractmethod
from typing import Dict, List
from typing import Any, Dict, List

from datafusion.common import SqlSchema, SqlTable

Expand Down Expand Up @@ -79,8 +79,9 @@ def show_schemas(self, **kwargs) -> Dict[str, SqlSchema]:
@abstractmethod
def create_table(
self,
schema_name: str,
table_name: str,
schema_name: str = None,
input_source: Any,
**kwargs,
):
"""
Expand Down
22 changes: 22 additions & 0 deletions datafusion/input/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from .location import LocationInputPlugin

__all__ = [
LocationInputPlugin,
]
43 changes: 43 additions & 0 deletions datafusion/input/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from abc import ABC, abstractmethod
from typing import Any

from datafusion.common import SqlTable


class BaseInputSource(ABC):
"""
If a consuming library would like to provider their own InputSource
this is the class they should extend to write their own. Once
completed the Plugin InputSource can be registered with the
SessionContext to ensure that it will be used in order
to obtain the SqlTable information from the custom datasource.
"""

@abstractmethod
def is_correct_input(
self, input_item: Any, table_name: str, **kwargs
) -> bool:
pass

@abstractmethod
def build_table(
self, input_item: Any, table_name: str, **kwarg
) -> SqlTable:
pass
88 changes: 88 additions & 0 deletions datafusion/input/location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
from typing import Any

from datafusion.common import DataTypeMap, SqlTable
from datafusion.input.base import BaseInputSource


class LocationInputPlugin(BaseInputSource):
"""
Input Plugin for everything, which can be read
in from a file (on disk, remote etc.)
"""

def is_correct_input(self, input_item: Any, table_name: str, **kwargs):
return isinstance(input_item, str)

def build_table(
self,
input_file: str,
table_name: str,
**kwargs,
) -> SqlTable:
_, extension = os.path.splitext(input_file)
format = extension.lstrip(".").lower()
num_rows = 0 # Total number of rows in the file. Used for statistics
columns = []

if format == "parquet":
import pyarrow.parquet as pq

# Read the Parquet metadata
metadata = pq.read_metadata(input_file)
num_rows = metadata.num_rows

# Iterate through the schema and build the SqlTable
print(f"Metadata Schema: {metadata.schema}")
for col in metadata.schema:
columns.append(
(
col.name,
DataTypeMap.from_parquet_type_str(col.physical_type),
)
)

elif format == "csv":
import csv

# Consume header row and count number of rows for statistics.
# TODO: Possibly makes sense to have the eager number of rows
# calculated as a configuration since you must read the entire file
# to get that information. However, this should only be occuring
# at table creation time and therefore shouldn't
# slow down query performance.
with open(input_file, "r") as file:
reader = csv.reader(file)
header_row = next(reader)
print(header_row)
for _ in reader:
num_rows += 1

# TODO: Need to actually consume this row into resonable columns
raise RuntimeError(
"TODO: Currently unable to support CSV input files."
)
else:
raise RuntimeError(
f"Input of format: `{format}` is currently not supported.\
Only Parquet and CSV."
)

return SqlTable(table_name, columns, num_rows, input_file)
33 changes: 33 additions & 0 deletions datafusion/tests/test_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
from datafusion.input.location import LocationInputPlugin


def test_location_input():
location_input = LocationInputPlugin()

cwd = os.getcwd()
input_file = (
cwd + "/testing/data/parquet/generated_simple_numerics/blogs.parquet"
)
table_name = "blog"
tbl = location_input.build_table(input_file, table_name)
assert "blog" == tbl.name
assert 3 == len(tbl.columns)
assert "blogs.parquet" in tbl.filepath
56 changes: 55 additions & 1 deletion src/common/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
use datafusion_common::{DataFusionError, ScalarValue};
use pyo3::prelude::*;
use pyo3::{exceptions::PyValueError, prelude::*};

use crate::errors::py_datafusion_err;

Expand Down Expand Up @@ -303,12 +303,43 @@ impl DataTypeMap {
}
}

#[staticmethod]
#[pyo3(name = "from_parquet_type_str")]
/// When using pyarrow.parquet.read_metadata().schema.column(x).physical_type you are presented
/// with a String type for schema rather than an object type. Here we make a best effort
/// to convert that to a physical type.
pub fn py_map_from_parquet_type_str(parquet_str_type: String) -> PyResult<DataTypeMap> {
let arrow_dtype = match parquet_str_type.to_lowercase().as_str() {
"boolean" => Ok(DataType::Boolean),
"int32" => Ok(DataType::Int32),
"int64" => Ok(DataType::Int64),
"int96" => {
// Int96 is an old datatype that is now deprecated. We convert to nanosecond timestamp
Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
}
"float" => Ok(DataType::Float32),
"double" => Ok(DataType::Float64),
_ => Err(PyValueError::new_err(format!(
"Unable to determine Arrow Data Type from Parquet String type: {:?}",
parquet_str_type
))),
};
DataTypeMap::map_from_arrow_type(&arrow_dtype?)
}

#[staticmethod]
#[pyo3(name = "arrow")]
pub fn py_map_from_arrow_type(arrow_type: &PyDataType) -> PyResult<DataTypeMap> {
DataTypeMap::map_from_arrow_type(&arrow_type.data_type)
}

#[staticmethod]
#[pyo3(name = "arrow_str")]
pub fn py_map_from_arrow_type_str(arrow_type_str: String) -> PyResult<DataTypeMap> {
let data_type = PyDataType::py_map_from_arrow_type_str(arrow_type_str);
DataTypeMap::map_from_arrow_type(&data_type?.data_type)
}

#[staticmethod]
#[pyo3(name = "sql")]
pub fn py_map_from_sql_type(sql_type: &SqlType) -> PyResult<DataTypeMap> {
Expand Down Expand Up @@ -558,6 +589,29 @@ pub struct PyDataType {
pub data_type: DataType,
}

impl PyDataType {
/// There are situations when obtaining dtypes on the Python side where the Arrow type
/// is presented as a String rather than an actual DataType. This function is used to
/// convert that String to a DataType for the Python side to use.
pub fn py_map_from_arrow_type_str(arrow_str_type: String) -> PyResult<PyDataType> {
let arrow_dtype = match arrow_str_type.to_lowercase().as_str() {
"boolean" => Ok(DataType::Boolean),
"int32" => Ok(DataType::Int32),
"int64" => Ok(DataType::Int64),
"float" => Ok(DataType::Float32),
"double" => Ok(DataType::Float64),
"float64" => Ok(DataType::Float64),
_ => Err(PyValueError::new_err(format!(
"Unable to determine Arrow Data Type from Arrow String type: {:?}",
arrow_str_type
))),
};
Ok(PyDataType {
data_type: arrow_dtype?,
})
}
}

impl From<PyDataType> for DataType {
fn from(data_type: PyDataType) -> DataType {
data_type.data_type
Expand Down
26 changes: 20 additions & 6 deletions src/expr/table_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_common::TableReference;
use datafusion_expr::logical_plan::TableScan;
use pyo3::prelude::*;
use std::fmt::{self, Display, Formatter};
Expand Down Expand Up @@ -71,12 +72,25 @@ impl PyTableScan {
Ok(format!("{}", self.table_scan.table_name))
}

/// TODO: Bindings for `TableSource` need to exist first. Left as a
/// placeholder to display intention to add when able to.
// #[pyo3(name = "source")]
// fn py_source(&self) -> PyResult<Arc<dyn TableSource>> {
// Ok(self.table_scan.source)
// }
#[pyo3(name = "fqn")]
fn fqn(&self) -> PyResult<(Option<String>, Option<String>, String)> {
let table_ref: TableReference = self.table_scan.table_name.clone();
Ok(match table_ref {
TableReference::Bare { table } => (None, None, table.to_string()),
TableReference::Partial { schema, table } => {
(None, Some(schema.to_string()), table.to_string())
}
TableReference::Full {
catalog,
schema,
table,
} => (
Some(catalog.to_string()),
Some(schema.to_string()),
table.to_string(),
),
})
}

/// The column indexes that should be. Note if this is empty then
/// all columns should be read by the `TableProvider`. This function
Expand Down