Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ dist
# C extensions
*.so

# Python dist
dist

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
Expand Down
52 changes: 52 additions & 0 deletions datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
# under the License.

from abc import ABC, abstractmethod
from typing import Dict

from datafusion.common import SqlSchema


class BaseSessionContext(ABC):
Expand All @@ -24,6 +27,55 @@ class BaseSessionContext(ABC):
shared amongst implementations using DataFusion as their SQL Parser/Engine
"""

DEFAULT_CATALOG_NAME = "root"
DEFAULT_SCHEMA_NAME = "datafusion"

@abstractmethod
def create_schema(self, schema_name: str, **kwargs):
    """
    Create/register a schema called ``schema_name``.

    A schema is a logical container for database objects (tables,
    views, indexes and related objects). Grouping objects this way
    keeps related items together, lets a database user own the
    group, and makes separate logical groups easier to manage.
    """

@abstractmethod
def update_schema(self, schema_name: str, new_schema: SqlSchema, **kwargs):
    """
    Update the existing schema ``schema_name`` in the SessionContext,
    using ``new_schema`` as its new definition.
    """

@abstractmethod
def drop_schema(self, schema_name: str, **kwargs):
    """
    Drop the schema identified by ``schema_name`` from the current context.
    """

@abstractmethod
def show_schemas(self, **kwargs) -> Dict[str, SqlSchema]:
    """
    Return every schema known to the current SessionContext
    implementation, as a mapping of ``str`` keys to ``SqlSchema`` values.
    """

@abstractmethod
def register_table(
self,
Expand Down
42 changes: 38 additions & 4 deletions datafusion/cudf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,36 @@
# specific language governing permissions and limitations
# under the License.

import logging
import cudf
import datafusion
from datafusion.context import BaseSessionContext
from datafusion.expr import Projection, TableScan, Column

from datafusion.common import SqlSchema

logger = logging.getLogger(__name__)


class SessionContext(BaseSessionContext):
def __init__(self, context, logging_level=logging.INFO):
    """
    Create a new Session.

    ``context`` is the caller-supplied context that schemas are
    registered with. ``logging_level`` is forwarded to
    ``logging.basicConfig``.
    """
    # Cudf requires a provided context
    self.context = context

    # register_table() still reads these two attributes, so the
    # constructor must keep creating them.
    self.datafusion_ctx = datafusion.SessionContext()
    self.parquet_tables = {}

    # Set the logging level for this SQL context
    logging.basicConfig(level=logging_level)

    # Names of the root catalog and of the root schema
    self.catalog_name = self.DEFAULT_CATALOG_NAME
    self.schema_name = self.DEFAULT_SCHEMA_NAME

    # Track the root schema locally and register the same object
    # with the provided context.
    root_schema = SqlSchema(self.schema_name)
    self.schemas = {self.schema_name: root_schema}
    self.context.register_schema(self.schema_name, root_schema)

def to_cudf_expr(self, expr):
# get Python wrapper for logical expression
Expand Down Expand Up @@ -52,6 +72,20 @@ def to_cudf_df(self, plan):
"unsupported logical operator: {}".format(type(node))
)

def create_schema(self, schema_name: str, **kwargs):
    """
    Create an empty schema named ``schema_name`` and register it.

    One ``SqlSchema`` instance is both stored in ``self.schemas`` and
    passed to ``self.context.register_schema``, the same way
    ``__init__`` handles the root schema. An existing ``self.schemas``
    entry with the same name is overwritten. ``kwargs`` is accepted
    for interface compatibility and is not used.
    """
    logger.debug("Creating schema: %s", schema_name)
    schema = SqlSchema(schema_name)
    self.schemas[schema_name] = schema
    self.context.register_schema(schema_name, schema)

def update_schema(self, schema_name: str, new_schema: SqlSchema, **kwargs):
    """
    Point the ``schema_name`` entry of ``self.schemas`` at ``new_schema``.

    Only the local mapping is changed; ``kwargs`` is not used.
    """
    self.schemas.update({schema_name: new_schema})

def drop_schema(self, schema_name, **kwargs):
    """
    Remove ``schema_name`` from ``self.schemas``.

    Raises ``KeyError`` when no such entry exists; ``kwargs`` is not used.
    """
    self.schemas.pop(schema_name)

def show_schemas(self, **kwargs):
    """
    Return the ``self.schemas`` mapping itself (not a copy).
    """
    known_schemas = self.schemas
    return known_schemas

def register_table(self, name, path, **kwargs):
self.parquet_tables[name] = path
self.datafusion_ctx.register_parquet(name, path)
Expand Down
32 changes: 31 additions & 1 deletion datafusion/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,33 @@
# specific language governing permissions and limitations
# under the License.

import logging
import pandas as pd
import datafusion
from datafusion.common import SqlSchema
from datafusion.context import BaseSessionContext
from datafusion.expr import Projection, TableScan, Column

logger = logging.getLogger(__name__)


class SessionContext(BaseSessionContext):
def __init__(self, logging_level=logging.INFO):
    """
    Create a new pandas-backed Session.

    ``logging_level`` is forwarded to ``logging.basicConfig``.
    """
    self.datafusion_ctx = datafusion.SessionContext()
    self.parquet_tables = {}

    # The schema methods of this class call ``self.context``; without
    # this assignment they fail with AttributeError.
    # NOTE(review): assumes datafusion.SessionContext provides
    # register_schema -- confirm against the bindings.
    self.context = self.datafusion_ctx

    # Set the logging level for this SQL context
    logging.basicConfig(level=logging_level)

    # Names of the root catalog and of the root schema
    self.catalog_name = self.DEFAULT_CATALOG_NAME
    self.schema_name = self.DEFAULT_SCHEMA_NAME

    # The mapping has to exist before the root schema is stored in it;
    # the same object is then registered with the context.
    root_schema = SqlSchema(self.schema_name)
    self.schemas = {self.schema_name: root_schema}
    self.context.register_schema(self.schema_name, root_schema)

def to_pandas_expr(self, expr):
# get Python wrapper for logical expression
expr = expr.to_variant()
Expand All @@ -52,6 +68,20 @@ def to_pandas_df(self, plan):
"unsupported logical operator: {}".format(type(node))
)

def create_schema(self, schema_name: str, **kwargs):
    """
    Create an empty schema named ``schema_name`` and register it.

    One ``SqlSchema`` instance is both stored in ``self.schemas`` and
    passed to ``self.context.register_schema``, so the two never refer
    to different objects. An existing ``self.schemas`` entry with the
    same name is overwritten. ``kwargs`` is accepted for interface
    compatibility and is not used.

    Relies on ``self.schemas`` and ``self.context`` having been set up
    by the constructor.
    """
    logger.debug("Creating schema: %s", schema_name)
    schema = SqlSchema(schema_name)
    self.schemas[schema_name] = schema
    self.context.register_schema(schema_name, schema)

def update_schema(self, schema_name: str, new_schema: SqlSchema, **kwargs):
    """
    Store ``new_schema`` under ``schema_name`` in ``self.schemas``.

    Only the local mapping is touched; ``kwargs`` is not used.
    """
    self.schemas.update({schema_name: new_schema})

def drop_schema(self, schema_name, **kwargs):
    """
    Delete the ``schema_name`` entry from ``self.schemas``.

    A missing name raises ``KeyError``; ``kwargs`` is not used.
    """
    self.schemas.pop(schema_name)

def show_schemas(self, **kwargs):
    """
    Hand back the ``self.schemas`` mapping (the live object, not a copy).
    """
    current = self.schemas
    return current

def register_table(self, name, path, **kwargs):
self.parquet_tables[name] = path
self.datafusion_ctx.register_parquet(name, path)
Expand Down
22 changes: 20 additions & 2 deletions datafusion/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
import polars
import datafusion
from datafusion.context import BaseSessionContext
from datafusion.expr import Projection, TableScan, Aggregate
from datafusion.expr import Column, AggregateFunction

from datafusion.common import SqlSchema

logger = logging.getLogger(__name__)


class SessionContext(BaseSessionContext):
def __init__(self):
def __init__(self, logging_level=logging.INFO):
self.datafusion_ctx = datafusion.SessionContext()
self.parquet_tables = {}

Expand Down Expand Up @@ -75,6 +79,20 @@ def to_polars_df(self, plan):
"unsupported logical operator: {}".format(type(node))
)

def create_schema(self, schema_name: str, **kwargs):
    """
    Create an empty schema named ``schema_name`` and register it.

    One ``SqlSchema`` instance is both stored in ``self.schemas`` and
    passed to ``self.context.register_schema``, so the two never refer
    to different objects. An existing ``self.schemas`` entry with the
    same name is overwritten. ``kwargs`` is accepted for interface
    compatibility and is not used.

    Relies on ``self.schemas`` and ``self.context`` having been set up
    by the constructor.
    """
    logger.debug("Creating schema: %s", schema_name)
    schema = SqlSchema(schema_name)
    self.schemas[schema_name] = schema
    self.context.register_schema(schema_name, schema)

def update_schema(self, schema_name: str, new_schema: SqlSchema, **kwargs):
    """
    Replace (or add) the ``schema_name`` entry of ``self.schemas`` with
    ``new_schema``.

    Only the local mapping is modified; ``kwargs`` is not used.
    """
    self.schemas.update({schema_name: new_schema})

def drop_schema(self, schema_name, **kwargs):
    """
    Take ``schema_name`` out of ``self.schemas``.

    ``KeyError`` is raised for an unknown name; ``kwargs`` is not used.
    """
    self.schemas.pop(schema_name)

def show_schemas(self, **kwargs):
    """
    Give the caller the ``self.schemas`` mapping (the same object, not a copy).
    """
    schema_map = self.schemas
    return schema_map

def register_table(self, name, path, **kwargs):
self.parquet_tables[name] = path
self.datafusion_ctx.register_parquet(name, path)
Expand Down
7 changes: 7 additions & 0 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use pyo3::prelude::*;
pub mod data_type;
pub mod df_field;
pub mod df_schema;
pub mod function;
pub mod schema;

/// Initializes the `common` module to match the pattern of `datafusion-common` https://docs.rs/datafusion-common/18.0.0/datafusion_common/index.html
pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
Expand All @@ -29,5 +31,10 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_class::<data_type::DataTypeMap>()?;
m.add_class::<data_type::PythonType>()?;
m.add_class::<data_type::SqlType>()?;
m.add_class::<schema::SqlTable>()?;
m.add_class::<schema::SqlSchema>()?;
m.add_class::<schema::SqlView>()?;
m.add_class::<schema::SqlStatistics>()?;
m.add_class::<function::SqlFunction>()?;
Ok(())
}
55 changes: 55 additions & 0 deletions src/common/function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::collections::HashMap;

use datafusion::arrow::datatypes::DataType;
use pyo3::prelude::*;

use super::data_type::PyDataType;

// Description of a SQL function: its name, the return type for each
// supported list of argument types, and an aggregation flag.
// Exposed to Python as `datafusion.common.SqlFunction`.
// (Plain `//` comments are used so the Python-visible class docstring
// is left unchanged.)
#[pyclass(name = "SqlFunction", module = "datafusion.common", subclass)]
#[derive(Debug, Clone)]
pub struct SqlFunction {
    // Function name as given to `SqlFunction::new`.
    pub name: String,
    // Maps an ordered list of input types to the corresponding return type.
    pub return_types: HashMap<Vec<DataType>, DataType>,
    // Flag supplied by the caller of `new`; presumably true for
    // aggregate functions -- confirm against users of this field.
    pub aggregation: bool,
}

impl SqlFunction {
pub fn new(
function_name: String,
input_types: Vec<PyDataType>,
return_type: PyDataType,
aggregation_bool: bool,
) -> Self {
let mut func = Self {
name: function_name,
return_types: HashMap::new(),
aggregation: aggregation_bool,
};
func.add_type_mapping(input_types, return_type);
func
}

pub fn add_type_mapping(&mut self, input_types: Vec<PyDataType>, return_type: PyDataType) {
self.return_types.insert(
input_types.iter().map(|t| t.clone().into()).collect(),
return_type.into(),
);
}
}
Loading