Skip to content

Commit ec3c71d

Browse files
Overflow in negate operator (#11084)
* Do checked negative op instead of unchecked * add tests for checking if overflow error occurs * add context to negating complexer ScalarValues * put format! call to create error message in closure * seperate test case for f16 that should panic with not implemented
1 parent 5f02c8a commit ec3c71d

File tree

1 file changed

+160
-17
lines changed
  • datafusion/common/src/scalar

1 file changed

+160
-17
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 160 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use std::iter::repeat;
2929
use std::str::FromStr;
3030
use std::sync::Arc;
3131

32+
use crate::arrow_datafusion_err;
3233
use crate::cast::{
3334
as_decimal128_array, as_decimal256_array, as_dictionary_array,
3435
as_fixed_size_binary_array, as_fixed_size_list_array,
@@ -1168,6 +1169,13 @@ impl ScalarValue {
11681169

11691170
/// Calculate arithmetic negation for a scalar value
11701171
pub fn arithmetic_negate(&self) -> Result<Self> {
1172+
fn neg_checked_with_ctx<T: ArrowNativeTypeOp>(
1173+
v: T,
1174+
ctx: impl Fn() -> String,
1175+
) -> Result<T> {
1176+
v.neg_checked()
1177+
.map_err(|e| arrow_datafusion_err!(e).context(ctx()))
1178+
}
11711179
match self {
11721180
ScalarValue::Int8(None)
11731181
| ScalarValue::Int16(None)
@@ -1177,40 +1185,91 @@ impl ScalarValue {
11771185
| ScalarValue::Float64(None) => Ok(self.clone()),
11781186
ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))),
11791187
ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))),
1180-
ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(-v))),
1181-
ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(-v))),
1182-
ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(-v))),
1183-
ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(-v))),
1184-
ScalarValue::IntervalYearMonth(Some(v)) => {
1185-
Ok(ScalarValue::IntervalYearMonth(Some(-v)))
1186-
}
1188+
ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))),
1189+
ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(v.neg_checked()?))),
1190+
ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(v.neg_checked()?))),
1191+
ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(v.neg_checked()?))),
1192+
ScalarValue::IntervalYearMonth(Some(v)) => Ok(
1193+
ScalarValue::IntervalYearMonth(Some(neg_checked_with_ctx(*v, || {
1194+
format!("In negation of IntervalYearMonth({v})")
1195+
})?)),
1196+
),
11871197
ScalarValue::IntervalDayTime(Some(v)) => {
11881198
let (days, ms) = IntervalDayTimeType::to_parts(*v);
1189-
let val = IntervalDayTimeType::make_value(-days, -ms);
1199+
let val = IntervalDayTimeType::make_value(
1200+
neg_checked_with_ctx(days, || {
1201+
format!("In negation of days {days} in IntervalDayTime")
1202+
})?,
1203+
neg_checked_with_ctx(ms, || {
1204+
format!("In negation of milliseconds {ms} in IntervalDayTime")
1205+
})?,
1206+
);
11901207
Ok(ScalarValue::IntervalDayTime(Some(val)))
11911208
}
11921209
ScalarValue::IntervalMonthDayNano(Some(v)) => {
11931210
let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v);
1194-
let val = IntervalMonthDayNanoType::make_value(-months, -days, -nanos);
1211+
let val = IntervalMonthDayNanoType::make_value(
1212+
neg_checked_with_ctx(months, || {
1213+
format!("In negation of months {months} of IntervalMonthDayNano")
1214+
})?,
1215+
neg_checked_with_ctx(days, || {
1216+
format!("In negation of days {days} of IntervalMonthDayNano")
1217+
})?,
1218+
neg_checked_with_ctx(nanos, || {
1219+
format!("In negation of nanos {nanos} of IntervalMonthDayNano")
1220+
})?,
1221+
);
11951222
Ok(ScalarValue::IntervalMonthDayNano(Some(val)))
11961223
}
11971224
ScalarValue::Decimal128(Some(v), precision, scale) => {
1198-
Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale))
1225+
Ok(ScalarValue::Decimal128(
1226+
Some(neg_checked_with_ctx(*v, || {
1227+
format!("In negation of Decimal128({v}, {precision}, {scale})")
1228+
})?),
1229+
*precision,
1230+
*scale,
1231+
))
1232+
}
1233+
ScalarValue::Decimal256(Some(v), precision, scale) => {
1234+
Ok(ScalarValue::Decimal256(
1235+
Some(neg_checked_with_ctx(*v, || {
1236+
format!("In negation of Decimal256({v}, {precision}, {scale})")
1237+
})?),
1238+
*precision,
1239+
*scale,
1240+
))
11991241
}
1200-
ScalarValue::Decimal256(Some(v), precision, scale) => Ok(
1201-
ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale),
1202-
),
12031242
ScalarValue::TimestampSecond(Some(v), tz) => {
1204-
Ok(ScalarValue::TimestampSecond(Some(-v), tz.clone()))
1243+
Ok(ScalarValue::TimestampSecond(
1244+
Some(neg_checked_with_ctx(*v, || {
1245+
format!("In negation of TimestampSecond({v})")
1246+
})?),
1247+
tz.clone(),
1248+
))
12051249
}
12061250
ScalarValue::TimestampNanosecond(Some(v), tz) => {
1207-
Ok(ScalarValue::TimestampNanosecond(Some(-v), tz.clone()))
1251+
Ok(ScalarValue::TimestampNanosecond(
1252+
Some(neg_checked_with_ctx(*v, || {
1253+
format!("In negation of TimestampNanoSecond({v})")
1254+
})?),
1255+
tz.clone(),
1256+
))
12081257
}
12091258
ScalarValue::TimestampMicrosecond(Some(v), tz) => {
1210-
Ok(ScalarValue::TimestampMicrosecond(Some(-v), tz.clone()))
1259+
Ok(ScalarValue::TimestampMicrosecond(
1260+
Some(neg_checked_with_ctx(*v, || {
1261+
format!("In negation of TimestampMicroSecond({v})")
1262+
})?),
1263+
tz.clone(),
1264+
))
12111265
}
12121266
ScalarValue::TimestampMillisecond(Some(v), tz) => {
1213-
Ok(ScalarValue::TimestampMillisecond(Some(-v), tz.clone()))
1267+
Ok(ScalarValue::TimestampMillisecond(
1268+
Some(neg_checked_with_ctx(*v, || {
1269+
format!("In negation of TimestampMilliSecond({v})")
1270+
})?),
1271+
tz.clone(),
1272+
))
12141273
}
12151274
value => _internal_err!(
12161275
"Can not run arithmetic negative on scalar value {value:?}"
@@ -3501,6 +3560,7 @@ mod tests {
35013560
use crate::assert_batches_eq;
35023561
use arrow::buffer::OffsetBuffer;
35033562
use arrow::compute::{is_null, kernels};
3563+
use arrow::error::ArrowError;
35043564
use arrow::util::pretty::pretty_format_columns;
35053565
use arrow_buffer::Buffer;
35063566
use arrow_schema::Fields;
@@ -5494,6 +5554,89 @@ mod tests {
54945554
Ok(())
54955555
}
54965556

5557+
#[test]
5558+
#[allow(arithmetic_overflow)] // we want to test them
5559+
fn test_scalar_negative_overflows() -> Result<()> {
5560+
macro_rules! test_overflow_on_value {
5561+
($($val:expr),* $(,)?) => {$(
5562+
{
5563+
let value: ScalarValue = $val;
5564+
let err = value.arithmetic_negate().expect_err("Should receive overflow error on negating {value:?}");
5565+
let root_err = err.find_root();
5566+
match root_err{
5567+
DataFusionError::ArrowError(
5568+
ArrowError::ComputeError(_),
5569+
_,
5570+
) => {}
5571+
_ => return Err(err),
5572+
};
5573+
}
5574+
)*};
5575+
}
5576+
test_overflow_on_value!(
5577+
// the integers
5578+
i8::MIN.into(),
5579+
i16::MIN.into(),
5580+
i32::MIN.into(),
5581+
i64::MIN.into(),
5582+
// for decimals, only value needs to be tested
5583+
ScalarValue::try_new_decimal128(i128::MIN, 10, 5)?,
5584+
ScalarValue::Decimal256(Some(i256::MIN), 20, 5),
5585+
// interval, check all possible values
5586+
ScalarValue::IntervalYearMonth(Some(i32::MIN)),
5587+
ScalarValue::new_interval_dt(i32::MIN, 999),
5588+
ScalarValue::new_interval_dt(1, i32::MIN),
5589+
ScalarValue::new_interval_mdn(i32::MIN, 15, 123_456),
5590+
ScalarValue::new_interval_mdn(12, i32::MIN, 123_456),
5591+
ScalarValue::new_interval_mdn(12, 15, i64::MIN),
5592+
// tz doesn't matter when negating
5593+
ScalarValue::TimestampSecond(Some(i64::MIN), None),
5594+
ScalarValue::TimestampMillisecond(Some(i64::MIN), None),
5595+
ScalarValue::TimestampMicrosecond(Some(i64::MIN), None),
5596+
ScalarValue::TimestampNanosecond(Some(i64::MIN), None),
5597+
);
5598+
5599+
let float_cases = [
5600+
(
5601+
ScalarValue::Float16(Some(f16::MIN)),
5602+
ScalarValue::Float16(Some(f16::MAX)),
5603+
),
5604+
(
5605+
ScalarValue::Float16(Some(f16::MAX)),
5606+
ScalarValue::Float16(Some(f16::MIN)),
5607+
),
5608+
(f32::MIN.into(), f32::MAX.into()),
5609+
(f32::MAX.into(), f32::MIN.into()),
5610+
(f64::MIN.into(), f64::MAX.into()),
5611+
(f64::MAX.into(), f64::MIN.into()),
5612+
];
5613+
// skip float 16 because they aren't supported
5614+
for (test, expected) in float_cases.into_iter().skip(2) {
5615+
assert_eq!(test.arithmetic_negate()?, expected);
5616+
}
5617+
Ok(())
5618+
}
5619+
5620+
#[test]
5621+
#[should_panic(expected = "Can not run arithmetic negative on scalar value Float16")]
5622+
fn f16_test_overflow() {
5623+
// TODO: if negate supports f16, add these cases to `test_scalar_negative_overflows` test case
5624+
let cases = [
5625+
(
5626+
ScalarValue::Float16(Some(f16::MIN)),
5627+
ScalarValue::Float16(Some(f16::MAX)),
5628+
),
5629+
(
5630+
ScalarValue::Float16(Some(f16::MAX)),
5631+
ScalarValue::Float16(Some(f16::MIN)),
5632+
),
5633+
];
5634+
5635+
for (test, expected) in cases {
5636+
assert_eq!(test.arithmetic_negate().unwrap(), expected);
5637+
}
5638+
}
5639+
54975640
macro_rules! expect_operation_error {
54985641
($TEST_NAME:ident, $FUNCTION:ident, $EXPECTED_ERROR:expr) => {
54995642
#[test]

0 commit comments

Comments
 (0)