Skip to content

Commit 36a2003

Browse files
committed
Add specializations for null / non null
1 parent 5ef1038 commit 36a2003

2 files changed

Lines changed: 108 additions & 91 deletions

File tree

datafusion/physical-plan/src/aggregates/group_values/column.rs

Lines changed: 34 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
// under the License.
1717

1818
use crate::aggregates::group_values::group_column::{
19-
ByteGroupValueBuilder, GroupColumn, PrimitiveGroupValueBuilder,
19+
ByteGroupValueBuilder, GroupColumn, NonNullPrimitiveGroupValueBuilder,
20+
PrimitiveGroupValueBuilder,
2021
};
2122
use crate::aggregates::group_values::GroupValues;
2223
use ahash::RandomState;
@@ -116,6 +117,26 @@ impl GroupValuesColumn {
116117
}
117118
}
118119

120+
/// instantiates a [`PrimitiveGroupValueBuilder`] or
121+
/// [`NonNullPrimitiveGroupValueBuilder`] and pushes it into $v
122+
///
123+
/// Arguments:
124+
/// `$v`: the vector to push the new builder into
125+
/// `$nullable`: whether the input can contains nulls
126+
/// `$t`: the primitive type of the builder
127+
///
128+
macro_rules! instantiate_primitive {
129+
($v:expr, $nullable:expr, $t:ty) => {
130+
if $nullable {
131+
let b = PrimitiveGroupValueBuilder::<$t>::new();
132+
$v.push(Box::new(b) as _)
133+
} else {
134+
let b = NonNullPrimitiveGroupValueBuilder::<$t>::new();
135+
$v.push(Box::new(b) as _)
136+
}
137+
};
138+
}
139+
119140
impl GroupValues for GroupValuesColumn {
120141
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
121142
let n_rows = cols[0].len();
@@ -126,54 +147,22 @@ impl GroupValues for GroupValuesColumn {
126147
for f in self.schema.fields().iter() {
127148
let nullable = f.is_nullable();
128149
match f.data_type() {
129-
&DataType::Int8 => {
130-
let b = PrimitiveGroupValueBuilder::<Int8Type>::new(nullable);
131-
v.push(Box::new(b) as _)
132-
}
133-
&DataType::Int16 => {
134-
let b = PrimitiveGroupValueBuilder::<Int16Type>::new(nullable);
135-
v.push(Box::new(b) as _)
136-
}
137-
&DataType::Int32 => {
138-
let b = PrimitiveGroupValueBuilder::<Int32Type>::new(nullable);
139-
v.push(Box::new(b) as _)
140-
}
141-
&DataType::Int64 => {
142-
let b = PrimitiveGroupValueBuilder::<Int64Type>::new(nullable);
143-
v.push(Box::new(b) as _)
144-
}
145-
&DataType::UInt8 => {
146-
let b = PrimitiveGroupValueBuilder::<UInt8Type>::new(nullable);
147-
v.push(Box::new(b) as _)
148-
}
149-
&DataType::UInt16 => {
150-
let b = PrimitiveGroupValueBuilder::<UInt16Type>::new(nullable);
151-
v.push(Box::new(b) as _)
152-
}
153-
&DataType::UInt32 => {
154-
let b = PrimitiveGroupValueBuilder::<UInt32Type>::new(nullable);
155-
v.push(Box::new(b) as _)
156-
}
157-
&DataType::UInt64 => {
158-
let b = PrimitiveGroupValueBuilder::<UInt64Type>::new(nullable);
159-
v.push(Box::new(b) as _)
160-
}
150+
&DataType::Int8 => instantiate_primitive!(v, nullable, Int8Type),
151+
&DataType::Int16 => instantiate_primitive!(v, nullable, Int16Type),
152+
&DataType::Int32 => instantiate_primitive!(v, nullable, Int32Type),
153+
&DataType::Int64 => instantiate_primitive!(v, nullable, Int64Type),
154+
&DataType::UInt8 => instantiate_primitive!(v, nullable, UInt8Type),
155+
&DataType::UInt16 => instantiate_primitive!(v, nullable, UInt16Type),
156+
&DataType::UInt32 => instantiate_primitive!(v, nullable, UInt32Type),
157+
&DataType::UInt64 => instantiate_primitive!(v, nullable, UInt64Type),
161158
&DataType::Float32 => {
162-
let b = PrimitiveGroupValueBuilder::<Float32Type>::new(nullable);
163-
v.push(Box::new(b) as _)
159+
instantiate_primitive!(v, nullable, Float32Type)
164160
}
165161
&DataType::Float64 => {
166-
let b = PrimitiveGroupValueBuilder::<Float64Type>::new(nullable);
167-
v.push(Box::new(b) as _)
168-
}
169-
&DataType::Date32 => {
170-
let b = PrimitiveGroupValueBuilder::<Date32Type>::new(nullable);
171-
v.push(Box::new(b) as _)
172-
}
173-
&DataType::Date64 => {
174-
let b = PrimitiveGroupValueBuilder::<Date64Type>::new(nullable);
175-
v.push(Box::new(b) as _)
162+
instantiate_primitive!(v, nullable, Float64Type)
176163
}
164+
&DataType::Date32 => instantiate_primitive!(v, nullable, Date32Type),
165+
&DataType::Date64 => instantiate_primitive!(v, nullable, Date64Type),
177166
&DataType::Utf8 => {
178167
let b = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
179168
v.push(Box::new(b) as _)

datafusion/physical-plan/src/aggregates/group_values/group_column.rs

Lines changed: 74 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -62,62 +62,96 @@ pub trait GroupColumn: Send + Sync {
6262
fn take_n(&mut self, n: usize) -> ArrayRef;
6363
}
6464

65+
/// Stores a collection of primitive group values which are known to have no nulls
66+
#[derive(Debug)]
67+
pub struct NonNullPrimitiveGroupValueBuilder<T: ArrowPrimitiveType> {
68+
group_values: Vec<T::Native>,
69+
}
70+
71+
impl<T> NonNullPrimitiveGroupValueBuilder<T>
72+
where
73+
T: ArrowPrimitiveType,
74+
{
75+
pub fn new() -> Self {
76+
Self {
77+
group_values: vec![],
78+
}
79+
}
80+
}
81+
82+
impl<T: ArrowPrimitiveType> GroupColumn for NonNullPrimitiveGroupValueBuilder<T> {
83+
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
84+
// know input has no nulls
85+
self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
86+
}
87+
88+
fn append_val(&mut self, array: &ArrayRef, row: usize) {
89+
// input can't possibly have nulls, so don't worry about them
90+
self.group_values.push(array.as_primitive::<T>().value(row))
91+
}
92+
93+
fn len(&self) -> usize {
94+
self.group_values.len()
95+
}
96+
97+
fn size(&self) -> usize {
98+
self.group_values.allocated_size()
99+
}
100+
101+
fn build(self: Box<Self>) -> ArrayRef {
102+
let Self { group_values } = *self;
103+
104+
let nulls = None;
105+
106+
Arc::new(PrimitiveArray::<T>::new(
107+
ScalarBuffer::from(group_values),
108+
nulls,
109+
))
110+
}
111+
112+
fn take_n(&mut self, n: usize) -> ArrayRef {
113+
let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
114+
let first_n_nulls = None;
115+
116+
Arc::new(PrimitiveArray::<T>::new(
117+
ScalarBuffer::from(first_n),
118+
first_n_nulls,
119+
))
120+
}
121+
}
122+
123+
/// Stores a collection of primitive group values which may have nulls
124+
#[derive(Debug)]
65125
pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType> {
66126
group_values: Vec<T::Native>,
67-
/// Null state (when None, input is guaranteed not to have nulls)
68-
nulls: Option<MaybeNullBufferBuilder>,
127+
nulls: MaybeNullBufferBuilder,
69128
}
70129

71130
impl<T> PrimitiveGroupValueBuilder<T>
72131
where
73132
T: ArrowPrimitiveType,
74133
{
75-
/// Create a new [`PrimitiveGroupValueBuilder`]
76-
///
77-
/// If `nullable` is false, it means the input will never have nulls
78-
pub fn new(nullable: bool) -> Self {
79-
let nulls = if nullable {
80-
Some(MaybeNullBufferBuilder::new())
81-
} else {
82-
None
83-
};
84-
134+
pub fn new() -> Self {
85135
Self {
86136
group_values: vec![],
87-
nulls,
137+
nulls: MaybeNullBufferBuilder::new(),
88138
}
89139
}
90140
}
91141

92142
impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
93143
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
94-
// fast path when input has no nulls
95-
match self.nulls.as_ref() {
96-
None => {
97-
self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
98-
}
99-
Some(nulls) => {
100-
// slower path if the input could have nulls
101-
nulls.is_null(lhs_row) == array.is_null(rhs_row)
102-
&& self.group_values[lhs_row]
103-
== array.as_primitive::<T>().value(rhs_row)
104-
}
105-
}
144+
self.nulls.is_null(lhs_row) == array.is_null(rhs_row)
145+
&& self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
106146
}
107147

108148
fn append_val(&mut self, array: &ArrayRef, row: usize) {
109-
match self.nulls.as_mut() {
110-
// input can't possibly have nulls, so don't worry about them
111-
None => self.group_values.push(array.as_primitive::<T>().value(row)),
112-
Some(nulls) => {
113-
if array.is_null(row) {
114-
nulls.append(true);
115-
self.group_values.push(T::default_value());
116-
} else {
117-
nulls.append(false);
118-
self.group_values.push(array.as_primitive::<T>().value(row));
119-
}
120-
}
149+
if array.is_null(row) {
150+
self.nulls.append(true);
151+
self.group_values.push(T::default_value());
152+
} else {
153+
self.nulls.append(false);
154+
self.group_values.push(array.as_primitive::<T>().value(row));
121155
}
122156
}
123157

@@ -126,13 +160,7 @@ impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
126160
}
127161

128162
fn size(&self) -> usize {
129-
let nulls_size = self
130-
.nulls
131-
.as_ref()
132-
.map(|nulls| nulls.allocated_size())
133-
.unwrap_or(0);
134-
135-
self.group_values.allocated_size() + nulls_size
163+
self.group_values.allocated_size() + self.nulls.allocated_size()
136164
}
137165

138166
fn build(self: Box<Self>) -> ArrayRef {
@@ -141,7 +169,7 @@ impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
141169
nulls,
142170
} = *self;
143171

144-
let nulls = nulls.and_then(|nulls| nulls.build());
172+
let nulls = nulls.build();
145173

146174
Arc::new(PrimitiveArray::<T>::new(
147175
ScalarBuffer::from(group_values),
@@ -151,7 +179,7 @@ impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
151179

152180
fn take_n(&mut self, n: usize) -> ArrayRef {
153181
let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
154-
let first_n_nulls = self.nulls.as_mut().and_then(|nulls| nulls.take_n(n));
182+
let first_n_nulls = self.nulls.take_n(n);
155183

156184
Arc::new(PrimitiveArray::<T>::new(
157185
ScalarBuffer::from(first_n),

0 commit comments

Comments
 (0)