diff --git a/datafusion/__init__.py b/datafusion/__init__.py index 8bc0e2306..8ce63018d 100644 --- a/datafusion/__init__.py +++ b/datafusion/__init__.py @@ -47,6 +47,7 @@ Filter, Limit, Projection, + ScalarVariable, Sort, TableScan, ) @@ -72,6 +73,7 @@ "Sort", "Limit", "Filter", + "ScalarVariable", "Alias", ] diff --git a/datafusion/tests/test_imports.py b/datafusion/tests/test_imports.py index cadd31edd..83d64f385 100644 --- a/datafusion/tests/test_imports.py +++ b/datafusion/tests/test_imports.py @@ -44,6 +44,7 @@ Aggregate, Sort, Analyze, + ScalarVariable, Alias, ) @@ -78,6 +79,7 @@ def test_class_module_is_datafusion(): Limit, Filter, Analyze, + ScalarVariable, Alias, ]: assert klass.__module__ == "datafusion.expr" diff --git a/src/expr.rs b/src/expr.rs index 41bd6b18e..a656e798e 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -30,6 +30,7 @@ use crate::expr::literal::PyLiteral; use datafusion::scalar::ScalarValue; use self::alias::PyAlias; +use self::scalar_variable::PyScalarVariable; pub mod aggregate; pub mod aggregate_expr; @@ -43,6 +44,7 @@ pub mod limit; pub mod literal; pub mod logical_node; pub mod projection; +pub mod scalar_variable; pub mod sort; pub mod table_scan; @@ -72,6 +74,9 @@ impl PyExpr { Python::with_gil(|_| match &self.expr { Expr::Alias(alias, name) => Ok(PyAlias::new(alias, name).into_py(py)), Expr::Column(col) => Ok(PyColumn::from(col.clone()).into_py(py)), + Expr::ScalarVariable(data_type, variables) => { + Ok(PyScalarVariable::new(data_type, variables).into_py(py)) + } Expr::Literal(value) => Ok(PyLiteral::from(value.clone()).into_py(py)), Expr::BinaryExpr(expr) => Ok(PyBinaryExpr::from(expr.clone()).into_py(py)), Expr::AggregateFunction(expr) => { @@ -193,6 +198,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; // operators m.add_class::()?; diff --git a/src/expr/scalar_variable.rs b/src/expr/scalar_variable.rs new file mode 100644 index 000000000..7b50ba241 --- /dev/null +++ b/src/expr/scalar_variable.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::arrow::datatypes::DataType; +use pyo3::prelude::*; + +use crate::common::data_type::PyDataType; + +#[pyclass(name = "ScalarVariable", module = "datafusion.expr", subclass)] +#[derive(Clone)] +pub struct PyScalarVariable { + data_type: DataType, + variables: Vec, +} + +impl PyScalarVariable { + pub fn new(data_type: &DataType, variables: &[String]) -> Self { + Self { + data_type: data_type.to_owned(), + variables: variables.to_vec(), + } + } +} + +#[pymethods] +impl PyScalarVariable { + /// Get the data type + fn data_type(&self) -> PyResult { + Ok(self.data_type.clone().into()) + } + + fn variables(&self) -> PyResult> { + Ok(self.variables.clone()) + } + + fn __repr__(&self) -> PyResult { + Ok(format!("{}{:?}", self.data_type, self.variables)) + } +}