Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ use datafusion::arrow::error::ArrowError;
use datafusion::error::DataFusionError as InnerDataFusionError;
use pyo3::{exceptions::PyException, PyErr};

pub type Result<T> = std::result::Result<T, DataFusionError>;

#[derive(Debug)]
pub enum DataFusionError {
ExecutionError(InnerDataFusionError),
Expand Down
119 changes: 103 additions & 16 deletions src/expr/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::errors::py_runtime_err;
use crate::errors::{py_runtime_err, DataFusionError};
use datafusion_common::ScalarValue;
use pyo3::prelude::*;

Expand All @@ -37,38 +37,125 @@ impl From<ScalarValue> for PyLiteral {
}
}

macro_rules! extract_scalar_value {
($self: expr, $variant: ident) => {
match &$self.value {
ScalarValue::$variant(value) => Ok(*value),
other => Err(unexpected_literal_value(other)),
}
};
}

#[pymethods]
impl PyLiteral {
/// Get the data type of this literal value
fn data_type(&self) -> String {
format!("{}", self.value.get_datatype())
}

fn value_i32(&self) -> PyResult<i32> {
if let ScalarValue::Int32(Some(n)) = &self.value {
Ok(*n)
} else {
Err(py_runtime_err("Cannot access value as i32"))
pub fn value_f32(&self) -> PyResult<Option<f32>> {
extract_scalar_value!(self, Float32)
}

pub fn value_f64(&self) -> PyResult<Option<f64>> {
extract_scalar_value!(self, Float64)
}

pub fn value_decimal128(&mut self) -> PyResult<(Option<i128>, u8, i8)> {
match &self.value {
ScalarValue::Decimal128(value, precision, scale) => Ok((*value, *precision, *scale)),
other => Err(unexpected_literal_value(other)),
}
}

pub fn value_i8(&self) -> PyResult<Option<i8>> {
extract_scalar_value!(self, Int8)
}

pub fn value_i16(&self) -> PyResult<Option<i16>> {
extract_scalar_value!(self, Int16)
}

pub fn value_i32(&self) -> PyResult<Option<i32>> {
extract_scalar_value!(self, Int32)
}

pub fn value_i64(&self) -> PyResult<Option<i64>> {
extract_scalar_value!(self, Int64)
}

pub fn value_u8(&self) -> PyResult<Option<u8>> {
extract_scalar_value!(self, UInt8)
}

pub fn value_u16(&self) -> PyResult<Option<u16>> {
extract_scalar_value!(self, UInt16)
}

pub fn value_u32(&self) -> PyResult<Option<u32>> {
extract_scalar_value!(self, UInt32)
}

pub fn value_u64(&self) -> PyResult<Option<u64>> {
extract_scalar_value!(self, UInt64)
}

pub fn value_date32(&self) -> PyResult<Option<i32>> {
extract_scalar_value!(self, Date32)
}

pub fn value_date64(&self) -> PyResult<Option<i64>> {
extract_scalar_value!(self, Date64)
}

pub fn value_time64(&self) -> PyResult<Option<i64>> {
extract_scalar_value!(self, Time64Nanosecond)
}

pub fn value_timestamp(&mut self) -> PyResult<(Option<i64>, Option<String>)> {
match &self.value {
ScalarValue::TimestampNanosecond(iv, tz)
| ScalarValue::TimestampMicrosecond(iv, tz)
| ScalarValue::TimestampMillisecond(iv, tz)
| ScalarValue::TimestampSecond(iv, tz) => Ok((*iv, tz.clone())),
other => Err(unexpected_literal_value(other)),
}
}

fn value_i64(&self) -> PyResult<i64> {
if let ScalarValue::Int64(Some(n)) = &self.value {
Ok(*n)
} else {
Err(py_runtime_err("Cannot access value as i64"))
pub fn value_bool(&self) -> PyResult<Option<bool>> {
extract_scalar_value!(self, Boolean)
}

pub fn value_string(&self) -> PyResult<Option<String>> {
match &self.value {
ScalarValue::Utf8(value) => Ok(value.clone()),
other => Err(unexpected_literal_value(other)),
}
}

fn value_str(&self) -> PyResult<String> {
if let ScalarValue::Utf8(Some(str)) = &self.value {
Ok(str.clone())
} else {
Err(py_runtime_err("Cannot access value as string"))
pub fn value_interval_day_time(&self) -> PyResult<Option<(i32, i32)>> {
match &self.value {
ScalarValue::IntervalDayTime(Some(iv)) => {
let interval = *iv as u64;
let days = (interval >> 32) as i32;
let ms = interval as i32;
Ok(Some((days, ms)))
}
ScalarValue::IntervalDayTime(None) => Ok(None),
other => Err(unexpected_literal_value(other)),
}
}

#[allow(clippy::wrong_self_convention)]
fn into_type(&self, py: Python) -> PyResult<PyObject> {
Ok(self.clone().into_py(py))
}

fn __repr__(&self) -> PyResult<String> {
Ok(format!("{}", self.value))
}
}

fn unexpected_literal_value(value: &ScalarValue) -> PyErr {
DataFusionError::Common(format!("getValue<T>() - Unexpected value: {value}")).into()
}