diff --git a/.gitignore b/.gitignore index 1d0a84a43..365b89d5c 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,9 @@ __pycache__/ *.py[cod] *$py.class +# Python dist ignore +dist + # C extensions *.so @@ -24,4 +27,4 @@ apache-rat-*.jar .env CHANGELOG.md.bak -docs/mdbook/book \ No newline at end of file +docs/mdbook/book diff --git a/datafusion/context.py b/datafusion/context.py new file mode 100644 index 000000000..aa9c9a8af --- /dev/null +++ b/datafusion/context.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+from abc import ABC, abstractmethod
+
+
+class BaseSessionContext(ABC):
+    """
+    Abstraction defining all methods, properties, and common functionality
+    shared amongst implementations using DataFusion as their SQL Parser/Engine
+    """
+
+    @abstractmethod
+    def register_table(
+        self,
+        table_name: str,
+        path: str,
+        **kwargs,
+    ):
+        """Register the data found at ``path`` under ``table_name``."""
+
+    # TODO: Remove abstraction, this functionality can be shared
+    # between all implementing classes since it just prints the
+    # logical plan from DataFusion
+    @abstractmethod
+    def explain(self, sql):
+        """Print the DataFusion logical plan for the ``sql`` query."""
+
+    @abstractmethod
+    def sql(self, sql):
+        """Run the ``sql`` query; the result type is implementation-defined."""
diff --git a/datafusion/cudf.py b/datafusion/cudf.py
index d5f02156f..594e5efea 100644
--- a/datafusion/cudf.py
+++ b/datafusion/cudf.py
@@ -17,18 +17,15 @@
 
 import cudf
 import datafusion
+from datafusion.context import BaseSessionContext
 from datafusion.expr import Projection, TableScan, Column
 
 
-class SessionContext:
+class SessionContext(BaseSessionContext):
     def __init__(self):
         self.datafusion_ctx = datafusion.SessionContext()
         self.parquet_tables = {}
 
-    def register_parquet(self, name, path):
-        self.parquet_tables[name] = path
-        self.datafusion_ctx.register_parquet(name, path)
-
     def to_cudf_expr(self, expr):
         # get Python wrapper for logical expression
         expr = expr.to_variant()
@@ -55,6 +52,22 @@ def to_cudf_df(self, plan):
                 "unsupported logical operator: {}".format(type(node))
             )
 
+    def register_table(self, table_name, path, **kwargs):
+        """Register the Parquet file at ``path`` as ``table_name``.
+
+        The path is remembered in ``parquet_tables`` and the table is also
+        registered with the wrapped DataFusion context; ``kwargs`` are
+        forwarded to its ``register_parquet``.
+        """
+        self.parquet_tables[table_name] = path
+        self.datafusion_ctx.register_parquet(table_name, path, **kwargs)
+
+    def explain(self, sql):
+        """Print the DataFusion logical plan for the ``sql`` query."""
+        # The abstract base method has no body to delegate to, so build
+        # the plan here the same way ``sql`` does.
+        print(self.datafusion_ctx.sql(sql).logical_plan())
+
     def sql(self, sql):
         datafusion_df = self.datafusion_ctx.sql(sql)
         plan = datafusion_df.logical_plan()
diff --git a/datafusion/pandas.py b/datafusion/pandas.py
index f8e56512b..935d9619b 100644
--- a/datafusion/pandas.py
+++ b/datafusion/pandas.py
@@ -17,18 +17,15 @@
 
 import pandas as pd
 import datafusion
+from datafusion.context import BaseSessionContext
 from datafusion.expr import Projection, TableScan, Column
 
 
-class SessionContext:
+class SessionContext(BaseSessionContext):
     def __init__(self):
         self.datafusion_ctx = datafusion.SessionContext()
         self.parquet_tables = {}
 
-    def register_parquet(self, name, path):
-        self.parquet_tables[name] = path
-        self.datafusion_ctx.register_parquet(name, path)
-
     def to_pandas_expr(self, expr):
         # get Python wrapper for logical expression
         expr = expr.to_variant()
@@ -55,6 +52,22 @@ def to_pandas_df(self, plan):
                 "unsupported logical operator: {}".format(type(node))
             )
 
+    def register_table(self, table_name, path, **kwargs):
+        """Register the Parquet file at ``path`` as ``table_name``.
+
+        The path is remembered in ``parquet_tables`` and the table is also
+        registered with the wrapped DataFusion context; ``kwargs`` are
+        forwarded to its ``register_parquet``.
+        """
+        self.parquet_tables[table_name] = path
+        self.datafusion_ctx.register_parquet(table_name, path, **kwargs)
+
+    def explain(self, sql):
+        """Print the DataFusion logical plan for the ``sql`` query."""
+        # The abstract base method has no body to delegate to, so build
+        # the plan here the same way ``sql`` does.
+        print(self.datafusion_ctx.sql(sql).logical_plan())
+
     def sql(self, sql):
         datafusion_df = self.datafusion_ctx.sql(sql)
         plan = datafusion_df.logical_plan()
diff --git a/datafusion/polars.py b/datafusion/polars.py
index a1bafbef8..bbc1fd7c2 100644
--- a/datafusion/polars.py
+++ b/datafusion/polars.py
@@ -17,19 +17,16 @@
 
 import polars
 import datafusion
+from datafusion.context import BaseSessionContext
 from datafusion.expr import Projection, TableScan, Aggregate
 from datafusion.expr import Column, AggregateFunction
 
 
-class SessionContext:
+class SessionContext(BaseSessionContext):
     def __init__(self):
         self.datafusion_ctx = datafusion.SessionContext()
         self.parquet_tables = {}
 
-    def register_parquet(self, name, path):
-        self.parquet_tables[name] = path
-        self.datafusion_ctx.register_parquet(name, path)
-
     def to_polars_expr(self, expr):
         # get Python wrapper for logical expression
         expr = expr.to_variant()
@@ -78,6 +75,22 @@ def to_polars_df(self, plan):
                 "unsupported logical operator: {}".format(type(node))
             )
 
+    def register_table(self, table_name, path, **kwargs):
+        """Register the Parquet file at ``path`` as ``table_name``.
+
+        The path is remembered in ``parquet_tables`` and the table is also
+        registered with the wrapped DataFusion context; ``kwargs`` are
+        forwarded to its ``register_parquet``.
+        """
+        self.parquet_tables[table_name] = path
+        self.datafusion_ctx.register_parquet(table_name, path, **kwargs)
+
+    def explain(self, sql):
+        """Print the DataFusion logical plan for the ``sql`` query."""
+        # The abstract base method has no body to delegate to, so build
+        # the plan here the same way ``sql`` does.
+        print(self.datafusion_ctx.sql(sql).logical_plan())
+
     def sql(self, sql):
         datafusion_df = self.datafusion_ctx.sql(sql)
         plan = datafusion_df.logical_plan()
diff --git a/docs/mdbook/src/usage/create-table.md
b/docs/mdbook/src/usage/create-table.md index 332863a16..98870fac0 100644 --- a/docs/mdbook/src/usage/create-table.md +++ b/docs/mdbook/src/usage/create-table.md @@ -55,5 +55,5 @@ ctx.register_csv("csv_1e8", "G1_1e8_1e2_0_0.csv") You can read a Parquet file into a DataFusion DataFrame. Here's how to read the `yellow_tripdata_2021-01.parquet` file into a table named `taxi`. ```python -ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet") +ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet") ``` diff --git a/examples/sql-on-cudf.py b/examples/sql-on-cudf.py index 999756fc8..b64d8f046 100644 --- a/examples/sql-on-cudf.py +++ b/examples/sql-on-cudf.py @@ -19,6 +19,6 @@ ctx = SessionContext() -ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet") +ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet") df = ctx.sql("select passenger_count from taxi") print(df) diff --git a/examples/sql-on-pandas.py b/examples/sql-on-pandas.py index 0efd77631..e3312a201 100644 --- a/examples/sql-on-pandas.py +++ b/examples/sql-on-pandas.py @@ -19,6 +19,6 @@ ctx = SessionContext() -ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet") +ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet") df = ctx.sql("select passenger_count from taxi") print(df) diff --git a/examples/sql-on-polars.py b/examples/sql-on-polars.py index c208114c1..dd7a9e021 100644 --- a/examples/sql-on-polars.py +++ b/examples/sql-on-polars.py @@ -19,7 +19,7 @@ ctx = SessionContext() -ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet") +ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet") df = ctx.sql( "select passenger_count, count(*) from taxi group by passenger_count" )