@@ -29,6 +29,7 @@ use std::iter::repeat;
2929use std:: str:: FromStr ;
3030use std:: sync:: Arc ;
3131
32+ use crate :: arrow_datafusion_err;
3233use crate :: cast:: {
3334 as_decimal128_array, as_decimal256_array, as_dictionary_array,
3435 as_fixed_size_binary_array, as_fixed_size_list_array,
@@ -1168,6 +1169,13 @@ impl ScalarValue {
11681169
11691170 /// Calculate arithmetic negation for a scalar value
11701171 pub fn arithmetic_negate ( & self ) -> Result < Self > {
1172+ fn neg_checked_with_ctx < T : ArrowNativeTypeOp > (
1173+ v : T ,
1174+ ctx : impl Fn ( ) -> String ,
1175+ ) -> Result < T > {
1176+ v. neg_checked ( )
1177+ . map_err ( |e| arrow_datafusion_err ! ( e) . context ( ctx ( ) ) )
1178+ }
11711179 match self {
11721180 ScalarValue :: Int8 ( None )
11731181 | ScalarValue :: Int16 ( None )
@@ -1177,40 +1185,91 @@ impl ScalarValue {
11771185 | ScalarValue :: Float64 ( None ) => Ok ( self . clone ( ) ) ,
11781186 ScalarValue :: Float64 ( Some ( v) ) => Ok ( ScalarValue :: Float64 ( Some ( -v) ) ) ,
11791187 ScalarValue :: Float32 ( Some ( v) ) => Ok ( ScalarValue :: Float32 ( Some ( -v) ) ) ,
1180- ScalarValue :: Int8 ( Some ( v) ) => Ok ( ScalarValue :: Int8 ( Some ( -v) ) ) ,
1181- ScalarValue :: Int16 ( Some ( v) ) => Ok ( ScalarValue :: Int16 ( Some ( -v) ) ) ,
1182- ScalarValue :: Int32 ( Some ( v) ) => Ok ( ScalarValue :: Int32 ( Some ( -v) ) ) ,
1183- ScalarValue :: Int64 ( Some ( v) ) => Ok ( ScalarValue :: Int64 ( Some ( -v) ) ) ,
1184- ScalarValue :: IntervalYearMonth ( Some ( v) ) => {
1185- Ok ( ScalarValue :: IntervalYearMonth ( Some ( -v) ) )
1186- }
1188+ ScalarValue :: Int8 ( Some ( v) ) => Ok ( ScalarValue :: Int8 ( Some ( v. neg_checked ( ) ?) ) ) ,
1189+ ScalarValue :: Int16 ( Some ( v) ) => Ok ( ScalarValue :: Int16 ( Some ( v. neg_checked ( ) ?) ) ) ,
1190+ ScalarValue :: Int32 ( Some ( v) ) => Ok ( ScalarValue :: Int32 ( Some ( v. neg_checked ( ) ?) ) ) ,
1191+ ScalarValue :: Int64 ( Some ( v) ) => Ok ( ScalarValue :: Int64 ( Some ( v. neg_checked ( ) ?) ) ) ,
1192+ ScalarValue :: IntervalYearMonth ( Some ( v) ) => Ok (
1193+ ScalarValue :: IntervalYearMonth ( Some ( neg_checked_with_ctx ( * v, || {
1194+ format ! ( "In negation of IntervalYearMonth({v})" )
1195+ } ) ?) ) ,
1196+ ) ,
11871197 ScalarValue :: IntervalDayTime ( Some ( v) ) => {
11881198 let ( days, ms) = IntervalDayTimeType :: to_parts ( * v) ;
1189- let val = IntervalDayTimeType :: make_value ( -days, -ms) ;
1199+ let val = IntervalDayTimeType :: make_value (
1200+ neg_checked_with_ctx ( days, || {
1201+ format ! ( "In negation of days {days} in IntervalDayTime" )
1202+ } ) ?,
1203+ neg_checked_with_ctx ( ms, || {
1204+ format ! ( "In negation of milliseconds {ms} in IntervalDayTime" )
1205+ } ) ?,
1206+ ) ;
11901207 Ok ( ScalarValue :: IntervalDayTime ( Some ( val) ) )
11911208 }
11921209 ScalarValue :: IntervalMonthDayNano ( Some ( v) ) => {
11931210 let ( months, days, nanos) = IntervalMonthDayNanoType :: to_parts ( * v) ;
1194- let val = IntervalMonthDayNanoType :: make_value ( -months, -days, -nanos) ;
1211+ let val = IntervalMonthDayNanoType :: make_value (
1212+ neg_checked_with_ctx ( months, || {
1213+ format ! ( "In negation of months {months} of IntervalMonthDayNano" )
1214+ } ) ?,
1215+ neg_checked_with_ctx ( days, || {
1216+ format ! ( "In negation of days {days} of IntervalMonthDayNano" )
1217+ } ) ?,
1218+ neg_checked_with_ctx ( nanos, || {
1219+ format ! ( "In negation of nanos {nanos} of IntervalMonthDayNano" )
1220+ } ) ?,
1221+ ) ;
11951222 Ok ( ScalarValue :: IntervalMonthDayNano ( Some ( val) ) )
11961223 }
11971224 ScalarValue :: Decimal128 ( Some ( v) , precision, scale) => {
1198- Ok ( ScalarValue :: Decimal128 ( Some ( -v) , * precision, * scale) )
1225+ Ok ( ScalarValue :: Decimal128 (
1226+ Some ( neg_checked_with_ctx ( * v, || {
1227+ format ! ( "In negation of Decimal128({v}, {precision}, {scale})" )
1228+ } ) ?) ,
1229+ * precision,
1230+ * scale,
1231+ ) )
1232+ }
1233+ ScalarValue :: Decimal256 ( Some ( v) , precision, scale) => {
1234+ Ok ( ScalarValue :: Decimal256 (
1235+ Some ( neg_checked_with_ctx ( * v, || {
1236+ format ! ( "In negation of Decimal256({v}, {precision}, {scale})" )
1237+ } ) ?) ,
1238+ * precision,
1239+ * scale,
1240+ ) )
11991241 }
1200- ScalarValue :: Decimal256 ( Some ( v) , precision, scale) => Ok (
1201- ScalarValue :: Decimal256 ( Some ( v. neg_wrapping ( ) ) , * precision, * scale) ,
1202- ) ,
12031242 ScalarValue :: TimestampSecond ( Some ( v) , tz) => {
1204- Ok ( ScalarValue :: TimestampSecond ( Some ( -v) , tz. clone ( ) ) )
1243+ Ok ( ScalarValue :: TimestampSecond (
1244+ Some ( neg_checked_with_ctx ( * v, || {
1245+ format ! ( "In negation of TimestampSecond({v})" )
1246+ } ) ?) ,
1247+ tz. clone ( ) ,
1248+ ) )
12051249 }
12061250 ScalarValue :: TimestampNanosecond ( Some ( v) , tz) => {
1207- Ok ( ScalarValue :: TimestampNanosecond ( Some ( -v) , tz. clone ( ) ) )
1251+ Ok ( ScalarValue :: TimestampNanosecond (
1252+ Some ( neg_checked_with_ctx ( * v, || {
1253+ format ! ( "In negation of TimestampNanoSecond({v})" )
1254+ } ) ?) ,
1255+ tz. clone ( ) ,
1256+ ) )
12081257 }
12091258 ScalarValue :: TimestampMicrosecond ( Some ( v) , tz) => {
1210- Ok ( ScalarValue :: TimestampMicrosecond ( Some ( -v) , tz. clone ( ) ) )
1259+ Ok ( ScalarValue :: TimestampMicrosecond (
1260+ Some ( neg_checked_with_ctx ( * v, || {
1261+ format ! ( "In negation of TimestampMicroSecond({v})" )
1262+ } ) ?) ,
1263+ tz. clone ( ) ,
1264+ ) )
12111265 }
12121266 ScalarValue :: TimestampMillisecond ( Some ( v) , tz) => {
1213- Ok ( ScalarValue :: TimestampMillisecond ( Some ( -v) , tz. clone ( ) ) )
1267+ Ok ( ScalarValue :: TimestampMillisecond (
1268+ Some ( neg_checked_with_ctx ( * v, || {
1269+ format ! ( "In negation of TimestampMilliSecond({v})" )
1270+ } ) ?) ,
1271+ tz. clone ( ) ,
1272+ ) )
12141273 }
12151274 value => _internal_err ! (
12161275 "Can not run arithmetic negative on scalar value {value:?}"
@@ -3501,6 +3560,7 @@ mod tests {
35013560 use crate :: assert_batches_eq;
35023561 use arrow:: buffer:: OffsetBuffer ;
35033562 use arrow:: compute:: { is_null, kernels} ;
3563+ use arrow:: error:: ArrowError ;
35043564 use arrow:: util:: pretty:: pretty_format_columns;
35053565 use arrow_buffer:: Buffer ;
35063566 use arrow_schema:: Fields ;
@@ -5494,6 +5554,89 @@ mod tests {
54945554 Ok ( ( ) )
54955555 }
54965556
5557+ #[ test]
5558+ #[ allow( arithmetic_overflow) ] // we want to test them
5559+ fn test_scalar_negative_overflows ( ) -> Result < ( ) > {
5560+ macro_rules! test_overflow_on_value {
5561+ ( $( $val: expr) ,* $( , ) ?) => { $(
5562+ {
5563+ let value: ScalarValue = $val;
5564+ let err = value. arithmetic_negate( ) . expect_err( "Should receive overflow error on negating {value:?}" ) ;
5565+ let root_err = err. find_root( ) ;
5566+ match root_err{
5567+ DataFusionError :: ArrowError (
5568+ ArrowError :: ComputeError ( _) ,
5569+ _,
5570+ ) => { }
5571+ _ => return Err ( err) ,
5572+ } ;
5573+ }
5574+ ) * } ;
5575+ }
5576+ test_overflow_on_value ! (
5577+ // the integers
5578+ i8 :: MIN . into( ) ,
5579+ i16 :: MIN . into( ) ,
5580+ i32 :: MIN . into( ) ,
5581+ i64 :: MIN . into( ) ,
5582+ // for decimals, only value needs to be tested
5583+ ScalarValue :: try_new_decimal128( i128 :: MIN , 10 , 5 ) ?,
5584+ ScalarValue :: Decimal256 ( Some ( i256:: MIN ) , 20 , 5 ) ,
5585+ // interval, check all possible values
5586+ ScalarValue :: IntervalYearMonth ( Some ( i32 :: MIN ) ) ,
5587+ ScalarValue :: new_interval_dt( i32 :: MIN , 999 ) ,
5588+ ScalarValue :: new_interval_dt( 1 , i32 :: MIN ) ,
5589+ ScalarValue :: new_interval_mdn( i32 :: MIN , 15 , 123_456 ) ,
5590+ ScalarValue :: new_interval_mdn( 12 , i32 :: MIN , 123_456 ) ,
5591+ ScalarValue :: new_interval_mdn( 12 , 15 , i64 :: MIN ) ,
5592+ // tz doesn't matter when negating
5593+ ScalarValue :: TimestampSecond ( Some ( i64 :: MIN ) , None ) ,
5594+ ScalarValue :: TimestampMillisecond ( Some ( i64 :: MIN ) , None ) ,
5595+ ScalarValue :: TimestampMicrosecond ( Some ( i64 :: MIN ) , None ) ,
5596+ ScalarValue :: TimestampNanosecond ( Some ( i64 :: MIN ) , None ) ,
5597+ ) ;
5598+
5599+ let float_cases = [
5600+ (
5601+ ScalarValue :: Float16 ( Some ( f16:: MIN ) ) ,
5602+ ScalarValue :: Float16 ( Some ( f16:: MAX ) ) ,
5603+ ) ,
5604+ (
5605+ ScalarValue :: Float16 ( Some ( f16:: MAX ) ) ,
5606+ ScalarValue :: Float16 ( Some ( f16:: MIN ) ) ,
5607+ ) ,
5608+ ( f32:: MIN . into ( ) , f32:: MAX . into ( ) ) ,
5609+ ( f32:: MAX . into ( ) , f32:: MIN . into ( ) ) ,
5610+ ( f64:: MIN . into ( ) , f64:: MAX . into ( ) ) ,
5611+ ( f64:: MAX . into ( ) , f64:: MIN . into ( ) ) ,
5612+ ] ;
5613+ // skip float 16 because they aren't supported
5614+ for ( test, expected) in float_cases. into_iter ( ) . skip ( 2 ) {
5615+ assert_eq ! ( test. arithmetic_negate( ) ?, expected) ;
5616+ }
5617+ Ok ( ( ) )
5618+ }
5619+
5620+ #[ test]
5621+ #[ should_panic( expected = "Can not run arithmetic negative on scalar value Float16" ) ]
5622+ fn f16_test_overflow ( ) {
5623+ // TODO: if negate supports f16, add these cases to `test_scalar_negative_overflows` test case
5624+ let cases = [
5625+ (
5626+ ScalarValue :: Float16 ( Some ( f16:: MIN ) ) ,
5627+ ScalarValue :: Float16 ( Some ( f16:: MAX ) ) ,
5628+ ) ,
5629+ (
5630+ ScalarValue :: Float16 ( Some ( f16:: MAX ) ) ,
5631+ ScalarValue :: Float16 ( Some ( f16:: MIN ) ) ,
5632+ ) ,
5633+ ] ;
5634+
5635+ for ( test, expected) in cases {
5636+ assert_eq ! ( test. arithmetic_negate( ) . unwrap( ) , expected) ;
5637+ }
5638+ }
5639+
54975640 macro_rules! expect_operation_error {
54985641 ( $TEST_NAME: ident, $FUNCTION: ident, $EXPECTED_ERROR: expr) => {
54995642 #[ test]
0 commit comments