diff --git a/src/errors.rs b/src/errors.rs index e739fe31f..ce6b3c28a 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -23,6 +23,8 @@ use datafusion::arrow::error::ArrowError; use datafusion::error::DataFusionError as InnerDataFusionError; use pyo3::{exceptions::PyException, PyErr}; +pub type Result = std::result::Result; + #[derive(Debug)] pub enum DataFusionError { ExecutionError(InnerDataFusionError), diff --git a/src/expr/literal.rs b/src/expr/literal.rs index 27674ce6f..b29497e64 100644 --- a/src/expr/literal.rs +++ b/src/expr/literal.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::errors::py_runtime_err; +use crate::errors::{py_runtime_err, DataFusionError}; use datafusion_common::ScalarValue; use pyo3::prelude::*; @@ -37,6 +37,15 @@ impl From for PyLiteral { } } +macro_rules! extract_scalar_value { + ($self: expr, $variant: ident) => { + match &$self.value { + ScalarValue::$variant(value) => Ok(*value), + other => Err(unexpected_literal_value(other)), + } + }; +} + #[pymethods] impl PyLiteral { /// Get the data type of this literal value @@ -44,31 +53,109 @@ impl PyLiteral { format!("{}", self.value.get_datatype()) } - fn value_i32(&self) -> PyResult { - if let ScalarValue::Int32(Some(n)) = &self.value { - Ok(*n) - } else { - Err(py_runtime_err("Cannot access value as i32")) + pub fn value_f32(&self) -> PyResult> { + extract_scalar_value!(self, Float32) + } + + pub fn value_f64(&self) -> PyResult> { + extract_scalar_value!(self, Float64) + } + + pub fn value_decimal128(&mut self) -> PyResult<(Option, u8, i8)> { + match &self.value { + ScalarValue::Decimal128(value, precision, scale) => Ok((*value, *precision, *scale)), + other => Err(unexpected_literal_value(other)), + } + } + + pub fn value_i8(&self) -> PyResult> { + extract_scalar_value!(self, Int8) + } + + pub fn value_i16(&self) -> PyResult> { + extract_scalar_value!(self, Int16) + } + + pub fn value_i32(&self) -> PyResult> { + extract_scalar_value!(self, Int32) + } + + pub fn value_i64(&self) -> PyResult> { + extract_scalar_value!(self, Int64) + } + + pub fn value_u8(&self) -> PyResult> { + extract_scalar_value!(self, UInt8) + } + + pub fn value_u16(&self) -> PyResult> { + extract_scalar_value!(self, UInt16) + } + + pub fn value_u32(&self) -> PyResult> { + extract_scalar_value!(self, UInt32) + } + + pub fn value_u64(&self) -> PyResult> { + extract_scalar_value!(self, UInt64) + } + + pub fn value_date32(&self) -> PyResult> { + extract_scalar_value!(self, Date32) + } + + pub fn value_date64(&self) -> PyResult> { + extract_scalar_value!(self, Date64) + } + + pub fn value_time64(&self) -> PyResult> { + extract_scalar_value!(self, Time64Nanosecond) + } + + pub fn value_timestamp(&mut self) -> PyResult<(Option, Option)> { + match &self.value { + ScalarValue::TimestampNanosecond(iv, tz) + | ScalarValue::TimestampMicrosecond(iv, tz) + | ScalarValue::TimestampMillisecond(iv, tz) + | ScalarValue::TimestampSecond(iv, tz) => Ok((*iv, tz.clone())), + other => Err(unexpected_literal_value(other)), } } - fn value_i64(&self) -> PyResult { - if let ScalarValue::Int64(Some(n)) = &self.value { - Ok(*n) - } else { - Err(py_runtime_err("Cannot access value as i64")) + pub fn value_bool(&self) -> PyResult> { + extract_scalar_value!(self, Boolean) + } + + pub fn value_string(&self) -> PyResult> { + match &self.value { + ScalarValue::Utf8(value) => Ok(value.clone()), + other => Err(unexpected_literal_value(other)), } } - fn value_str(&self) -> PyResult { - if let ScalarValue::Utf8(Some(str)) = &self.value { - Ok(str.clone()) - } else { - Err(py_runtime_err("Cannot access value as string")) + pub fn value_interval_day_time(&self) -> PyResult> { + match &self.value { + ScalarValue::IntervalDayTime(Some(iv)) => { + let interval = *iv as u64; + let days = (interval >> 32) as i32; + let ms = interval as i32; + Ok(Some((days, ms))) + } + ScalarValue::IntervalDayTime(None) => Ok(None), + other => Err(unexpected_literal_value(other)), } } + #[allow(clippy::wrong_self_convention)] + fn into_type(&self, py: Python) -> PyResult { + Ok(self.clone().into_py(py)) + } + fn __repr__(&self) -> PyResult { Ok(format!("{}", self.value)) } } + +fn unexpected_literal_value(value: &ScalarValue) -> PyErr { + DataFusionError::Common(format!("getValue() - Unexpected value: {value}")).into() +}