diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b3873c01dd06b..c87882ca72fcd 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -26,7 +26,7 @@ use crate::protobuf::{ OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; use arrow::{ - buffer::{Buffer, MutableBuffer}, + buffer::Buffer, datatypes::{ i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode, @@ -645,6 +645,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Float32Value(v) => Self::Float32(Some(*v)), Value::Float64Value(v) => Self::Float64(Some(*v)), Value::Date32Value(v) => Self::Date32(Some(*v)), + // ScalarValue::List is serialized using arrow IPC format Value::ListValue(scalar_list) => { let protobuf::ScalarListValue { ipc_message, @@ -655,29 +656,36 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { let schema: Schema = if let Some(schema_ref) = schema { schema_ref.try_into()? } else { - return Err(Error::General("Unexpected schema".to_string())); + return Err(Error::General( + "Invalid schema while deserializing ScalarValue::List" + .to_string(), + )); }; - let message = root_as_message(ipc_message.as_slice()).unwrap(); + let message = root_as_message(ipc_message.as_slice()).map_err(|e| { + Error::General(format!( + "Error IPC message while deserializing ScalarValue::List: {e}" + )) + })?; + let buffer = Buffer::from(arrow_data); - // TODO: Add comment to why adding 0 before arrow_data. - // This code is from https://github.com/apache/arrow-rs/blob/4320a753beaee0a1a6870c59ef46b59e88c9c323/arrow-ipc/src/reader.rs#L1670-L1674C45 - // Construct an unaligned buffer - let mut buffer = MutableBuffer::with_capacity(arrow_data.len() + 1); - buffer.push(0_u8); - buffer.extend_from_slice(arrow_data.as_slice()); - let b = Buffer::from(buffer).slice(1); + let ipc_batch = message.header_as_record_batch().ok_or_else(|| { + Error::General( + "Unexpected message type deserializing ScalarValue::List" + .to_string(), + ) + })?; - let ipc_batch = message.header_as_record_batch().unwrap(); let record_batch = read_record_batch( - &b, + &buffer, ipc_batch, Arc::new(schema), &Default::default(), None, &message.version(), ) - .unwrap(); + .map_err(DataFusionError::ArrowError) + .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; let arr = record_batch.column(0); Self::List(arr.to_owned()) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index e80d60931cf62..125ced032e20c 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -56,13 +56,6 @@ use datafusion_expr::{ pub enum Error { General(String), - InconsistentListTyping(DataType, DataType), - - InconsistentListDesignated { - value: ScalarValue, - designated: DataType, - }, - InvalidScalarValue(ScalarValue), InvalidScalarType(DataType), @@ -80,18 +73,6 @@ impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { Self::General(desc) => write!(f, "General error: {desc}"), - Self::InconsistentListTyping(type1, type2) => { - write!( - f, - "Lists with inconsistent typing; {type1:?} and {type2:?} found within list", - ) - } - Self::InconsistentListDesignated { value, designated } => { - write!( - f, - "Value {value:?} was inconsistent with designated type {designated:?}" - ) - } Self::InvalidScalarValue(value) => { write!(f, "{value:?} is invalid as a DataFusion scalar value") } @@ -1145,15 +1126,27 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { "Proto serialization error: ScalarValue::Fixedsizelist not supported" .to_string(), )), + // ScalarValue::List is serialized using Arrow IPC messages. + // as a single column RecordBatch ScalarValue::List(arr) => { - let batch = - RecordBatch::try_from_iter(vec![("field_name", arr.to_owned())]) - .unwrap(); + // Wrap in a "field_name" column + let batch = RecordBatch::try_from_iter(vec![( + "field_name", + arr.to_owned(), + )]) + .map_err(|e| { + Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}")) + })?; + let gen = IpcDataGenerator {}; let mut dict_tracker = DictionaryTracker::new(false); let (_, encoded_message) = gen .encoded_batch(&batch, &mut dict_tracker, &Default::default()) - .unwrap(); + .map_err(|e| { + Error::General(format!( + "Error encoding ScalarValue::List as IPC: {e}" + )) + })?; let schema: protobuf::Schema = batch.schema().try_into()?;