Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ __pycache__/
*.py[cod]
*$py.class

# Python dist ignore
dist

# C extensions
*.so

Expand All @@ -24,4 +27,4 @@ apache-rat-*.jar
.env
CHANGELOG.md.bak

docs/mdbook/book
docs/mdbook/book
45 changes: 45 additions & 0 deletions datafusion/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from abc import ABC, abstractmethod


class BaseSessionContext(ABC):
    """
    Interface shared by every session context that uses DataFusion as
    its SQL parser/engine.

    All three methods are abstract, so a subclass can only be
    instantiated once it overrides each of them.
    """

    @abstractmethod
    def register_table(self, table_name: str, path: str, **kwargs):
        """
        Register a table called ``table_name`` for the data at ``path``.

        Any additional keyword arguments are passed through untouched;
        their meaning is left to the implementing class.
        """

    # TODO: Remove abstraction, this functionality can be shared
    # between all implementing classes since it just prints the
    # logical plan from DataFusion
    @abstractmethod
    def explain(self, sql):
        """
        Explain the query ``sql``; the base implementation does nothing
        and returns ``None``.
        """

    @abstractmethod
    def sql(self, sql):
        """
        Run the query ``sql``; the base implementation does nothing and
        returns ``None``.
        """
14 changes: 9 additions & 5 deletions datafusion/cudf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,15 @@

import cudf
import datafusion
from datafusion.context import BaseSessionContext
from datafusion.expr import Projection, TableScan, Column


class SessionContext:
class SessionContext(BaseSessionContext):
def __init__(self):
    """
    Create the wrapped DataFusion session and an empty table registry.
    """
    # DataFusion context that the other methods of this class delegate to.
    self.datafusion_ctx = datafusion.SessionContext()
    # Table name -> path, exactly as handed to the register call.
    self.parquet_tables = dict()

def register_parquet(self, name, path):
    """
    Remember ``path`` under ``name`` and register the Parquet file with
    the wrapped DataFusion context.

    :param name: table name to register.
    :param path: path handed to DataFusion's ``register_parquet``.
    """
    # Record locally first, then delegate (same order as before).
    known_tables = self.parquet_tables
    known_tables[name] = path
    engine = self.datafusion_ctx
    engine.register_parquet(name, path)

def to_cudf_expr(self, expr):
# get Python wrapper for logical expression
expr = expr.to_variant()
Expand All @@ -55,6 +52,13 @@ def to_cudf_df(self, plan):
"unsupported logical operator: {}".format(type(node))
)

def register_table(self, name, path, **kwargs):
    """
    Remember ``path`` under ``name`` and register the Parquet file with
    the wrapped DataFusion context.

    :param name: table name to register.
    :param path: path handed to DataFusion's ``register_parquet``.
    :param kwargs: accepted for interface compatibility; currently unused.

    NOTE(review): the abstract base declares the first parameter as
    ``table_name`` while this override calls it ``name`` -- confirm no
    caller passes it by keyword.
    """
    # Record locally first, then delegate (same order as before).
    known_tables = self.parquet_tables
    known_tables[name] = path
    engine = self.datafusion_ctx
    engine.register_parquet(name, path)

def explain(self, sql):
    """
    Explain the query ``sql`` by delegating to the parent class.

    The earlier body, ``super.explain()``, accessed ``explain`` on the
    built-in ``super`` type itself (it was never called), which raises
    ``AttributeError``, and it also dropped the ``sql`` argument.

    :param sql: SQL query text to explain.
    :return: whatever the parent class's ``explain`` returns.
    """
    # super must be *called* to get the proxy bound to this instance.
    return super().explain(sql)

def sql(self, sql):
datafusion_df = self.datafusion_ctx.sql(sql)
plan = datafusion_df.logical_plan()
Expand Down
14 changes: 9 additions & 5 deletions datafusion/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,15 @@

import pandas as pd
import datafusion
from datafusion.context import BaseSessionContext
from datafusion.expr import Projection, TableScan, Column


class SessionContext:
class SessionContext(BaseSessionContext):
def __init__(self):
    """
    Create the wrapped DataFusion session and an empty table registry.
    """
    # DataFusion context that the other methods of this class delegate to.
    self.datafusion_ctx = datafusion.SessionContext()
    # Table name -> path, exactly as handed to the register call.
    self.parquet_tables = dict()

def register_parquet(self, name, path):
    """
    Remember ``path`` under ``name`` and register the Parquet file with
    the wrapped DataFusion context.

    :param name: table name to register.
    :param path: path handed to DataFusion's ``register_parquet``.
    """
    # Record locally first, then delegate (same order as before).
    known_tables = self.parquet_tables
    known_tables[name] = path
    engine = self.datafusion_ctx
    engine.register_parquet(name, path)

def to_pandas_expr(self, expr):
# get Python wrapper for logical expression
expr = expr.to_variant()
Expand All @@ -55,6 +52,13 @@ def to_pandas_df(self, plan):
"unsupported logical operator: {}".format(type(node))
)

def register_table(self, name, path, **kwargs):
    """
    Remember ``path`` under ``name`` and register the Parquet file with
    the wrapped DataFusion context.

    :param name: table name to register.
    :param path: path handed to DataFusion's ``register_parquet``.
    :param kwargs: accepted for interface compatibility; currently unused.

    NOTE(review): the abstract base declares the first parameter as
    ``table_name`` while this override calls it ``name`` -- confirm no
    caller passes it by keyword.
    """
    # Record locally first, then delegate (same order as before).
    known_tables = self.parquet_tables
    known_tables[name] = path
    engine = self.datafusion_ctx
    engine.register_parquet(name, path)

def explain(self, sql):
    """
    Explain the query ``sql`` by delegating to the parent class.

    The earlier body, ``super.explain()``, accessed ``explain`` on the
    built-in ``super`` type itself (it was never called), which raises
    ``AttributeError``, and it also dropped the ``sql`` argument.

    :param sql: SQL query text to explain.
    :return: whatever the parent class's ``explain`` returns.
    """
    # super must be *called* to get the proxy bound to this instance.
    return super().explain(sql)

def sql(self, sql):
datafusion_df = self.datafusion_ctx.sql(sql)
plan = datafusion_df.logical_plan()
Expand Down
14 changes: 9 additions & 5 deletions datafusion/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,16 @@

import polars
import datafusion
from datafusion.context import BaseSessionContext
from datafusion.expr import Projection, TableScan, Aggregate
from datafusion.expr import Column, AggregateFunction


class SessionContext:
class SessionContext(BaseSessionContext):
def __init__(self):
    """
    Create the wrapped DataFusion session and an empty table registry.
    """
    # DataFusion context that the other methods of this class delegate to.
    self.datafusion_ctx = datafusion.SessionContext()
    # Table name -> path, exactly as handed to the register call.
    self.parquet_tables = dict()

def register_parquet(self, name, path):
    """
    Remember ``path`` under ``name`` and register the Parquet file with
    the wrapped DataFusion context.

    :param name: table name to register.
    :param path: path handed to DataFusion's ``register_parquet``.
    """
    # Record locally first, then delegate (same order as before).
    known_tables = self.parquet_tables
    known_tables[name] = path
    engine = self.datafusion_ctx
    engine.register_parquet(name, path)

def to_polars_expr(self, expr):
# get Python wrapper for logical expression
expr = expr.to_variant()
Expand Down Expand Up @@ -78,6 +75,13 @@ def to_polars_df(self, plan):
"unsupported logical operator: {}".format(type(node))
)

def register_table(self, name, path, **kwargs):
    """
    Remember ``path`` under ``name`` and register the Parquet file with
    the wrapped DataFusion context.

    :param name: table name to register.
    :param path: path handed to DataFusion's ``register_parquet``.
    :param kwargs: accepted for interface compatibility; currently unused.

    NOTE(review): the abstract base declares the first parameter as
    ``table_name`` while this override calls it ``name`` -- confirm no
    caller passes it by keyword.
    """
    # Record locally first, then delegate (same order as before).
    known_tables = self.parquet_tables
    known_tables[name] = path
    engine = self.datafusion_ctx
    engine.register_parquet(name, path)

def explain(self, sql):
    """
    Explain the query ``sql`` by delegating to the parent class.

    The earlier body, ``super.explain()``, accessed ``explain`` on the
    built-in ``super`` type itself (it was never called), which raises
    ``AttributeError``, and it also dropped the ``sql`` argument.

    :param sql: SQL query text to explain.
    :return: whatever the parent class's ``explain`` returns.
    """
    # super must be *called* to get the proxy bound to this instance.
    return super().explain(sql)

def sql(self, sql):
datafusion_df = self.datafusion_ctx.sql(sql)
plan = datafusion_df.logical_plan()
Expand Down
2 changes: 1 addition & 1 deletion docs/mdbook/src/usage/create-table.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,5 @@ ctx.register_csv("csv_1e8", "G1_1e8_1e2_0_0.csv")
You can read a Parquet file into a DataFusion DataFrame. Here's how to read the `yellow_tripdata_2021-01.parquet` file into a table named `taxi`.

```python
ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet")
```
2 changes: 1 addition & 1 deletion examples/sql-on-cudf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@


ctx = SessionContext()
ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet")
df = ctx.sql("select passenger_count from taxi")
print(df)
2 changes: 1 addition & 1 deletion examples/sql-on-pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@


ctx = SessionContext()
ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet")
df = ctx.sql("select passenger_count from taxi")
print(df)
2 changes: 1 addition & 1 deletion examples/sql-on-polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


ctx = SessionContext()
ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet")
df = ctx.sql(
"select passenger_count, count(*) from taxi group by passenger_count"
)
Expand Down