Skip to content

Commit 245f0b8

Browse files
authored
support large-utf8 in groupby (#35)
* support large-utf8 in groupby
* add test
1 parent 32951c3 commit 245f0b8

5 files changed

Lines changed: 79 additions & 4 deletions

File tree

datafusion/src/execution/context.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1646,6 +1646,57 @@ mod tests {
16461646
Ok(())
16471647
}
16481648

1649+
#[tokio::test]
1650+
async fn group_by_largeutf8() {
1651+
{
1652+
let mut ctx = ExecutionContext::new();
1653+
1654+
// input data looks like:
1655+
// A, 1
1656+
// B, 2
1657+
// A, 2
1658+
// A, 4
1659+
// C, 1
1660+
// A, 1
1661+
1662+
let str_array: LargeStringArray = vec!["A", "B", "A", "A", "C", "A"]
1663+
.into_iter()
1664+
.map(Some)
1665+
.collect();
1666+
let str_array = Arc::new(str_array);
1667+
1668+
let val_array: Int64Array = vec![1, 2, 2, 4, 1, 1].into();
1669+
let val_array = Arc::new(val_array);
1670+
1671+
let schema = Arc::new(Schema::new(vec![
1672+
Field::new("str", str_array.data_type().clone(), false),
1673+
Field::new("val", val_array.data_type().clone(), false),
1674+
]));
1675+
1676+
let batch =
1677+
RecordBatch::try_new(schema.clone(), vec![str_array, val_array]).unwrap();
1678+
1679+
let provider = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap();
1680+
ctx.register_table("t", Arc::new(provider)).unwrap();
1681+
1682+
let results =
1683+
plan_and_collect(&mut ctx, "SELECT str, count(val) FROM t GROUP BY str")
1684+
.await
1685+
.expect("ran plan correctly");
1686+
1687+
let expected = vec![
1688+
"+-----+------------+",
1689+
"| str | COUNT(val) |",
1690+
"+-----+------------+",
1691+
"| A | 4 |",
1692+
"| B | 1 |",
1693+
"| C | 1 |",
1694+
"+-----+------------+",
1695+
];
1696+
assert_batches_sorted_eq!(expected, &results);
1697+
}
1698+
}
1699+
16491700
#[tokio::test]
16501701
async fn group_by_dictionary() {
16511702
async fn run_test_case<K: ArrowDictionaryKeyType>() {

datafusion/src/physical_plan/group_scalar.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ pub(crate) enum GroupByScalar {
3737
Int32(i32),
3838
Int64(i64),
3939
Utf8(Box<String>),
40+
LargeUtf8(Box<String>),
4041
Boolean(bool),
4142
TimeMillisecond(i64),
4243
TimeMicrosecond(i64),
@@ -74,6 +75,9 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
7475
GroupByScalar::TimeNanosecond(*v)
7576
}
7677
ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())),
78+
ScalarValue::LargeUtf8(Some(v)) => {
79+
GroupByScalar::LargeUtf8(Box::new(v.clone()))
80+
}
7781
ScalarValue::Float32(None)
7882
| ScalarValue::Float64(None)
7983
| ScalarValue::Boolean(None)
@@ -116,6 +120,7 @@ impl From<&GroupByScalar> for ScalarValue {
116120
GroupByScalar::UInt32(v) => ScalarValue::UInt32(Some(*v)),
117121
GroupByScalar::UInt64(v) => ScalarValue::UInt64(Some(*v)),
118122
GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.to_string())),
123+
GroupByScalar::LargeUtf8(v) => ScalarValue::LargeUtf8(Some(v.to_string())),
119124
GroupByScalar::TimeMillisecond(v) => {
120125
ScalarValue::TimestampMillisecond(Some(*v))
121126
}
@@ -191,14 +196,14 @@ mod tests {
191196
#[test]
192197
fn from_scalar_unsupported() {
193198
// Use any ScalarValue type not supported by GroupByScalar.
194-
let scalar_value = ScalarValue::LargeUtf8(Some("1.1".to_string()));
199+
let scalar_value = ScalarValue::Binary(Some(vec![1, 2]));
195200
let result = GroupByScalar::try_from(&scalar_value);
196201

197202
match result {
198203
Err(DataFusionError::Internal(error_message)) => assert_eq!(
199204
error_message,
200205
String::from(
201-
"Cannot convert a ScalarValue with associated DataType LargeUtf8"
206+
"Cannot convert a ScalarValue with associated DataType Binary"
202207
)
203208
),
204209
_ => panic!("Unexpected result"),

datafusion/src/physical_plan/hash_aggregate.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ use ordered_float::OrderedFloat;
5959
use pin_project_lite::pin_project;
6060

6161
use arrow::array::{
62-
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
62+
LargeStringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
63+
TimestampNanosecondArray,
6364
};
6465
use async_trait::async_trait;
6566

@@ -540,6 +541,14 @@ fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<(
540541
// store the string value
541542
vec.extend_from_slice(value.as_bytes());
542543
}
544+
DataType::LargeUtf8 => {
545+
let array = col.as_any().downcast_ref::<LargeStringArray>().unwrap();
546+
let value = array.value(row);
547+
// store the size
548+
vec.extend_from_slice(&value.len().to_le_bytes());
549+
// store the string value
550+
vec.extend_from_slice(value.as_bytes());
551+
}
543552
DataType::Date32 => {
544553
let array = col.as_any().downcast_ref::<Date32Array>().unwrap();
545554
vec.extend_from_slice(&array.value(row).to_le_bytes());
@@ -953,6 +962,9 @@ fn create_batch_from_map(
953962
GroupByScalar::Utf8(str) => {
954963
Arc::new(StringArray::from(vec![&***str]))
955964
}
965+
GroupByScalar::LargeUtf8(str) => {
966+
Arc::new(LargeStringArray::from(vec![&***str]))
967+
}
956968
GroupByScalar::Boolean(b) => Arc::new(BooleanArray::from(vec![*b])),
957969
GroupByScalar::TimeMillisecond(n) => {
958970
Arc::new(TimestampMillisecondArray::from(vec![*n]))
@@ -1103,6 +1115,10 @@ fn create_group_by_value(col: &ArrayRef, row: usize) -> Result<GroupByScalar> {
11031115
let array = col.as_any().downcast_ref::<StringArray>().unwrap();
11041116
Ok(GroupByScalar::Utf8(Box::new(array.value(row).into())))
11051117
}
1118+
DataType::LargeUtf8 => {
1119+
let array = col.as_any().downcast_ref::<LargeStringArray>().unwrap();
1120+
Ok(GroupByScalar::Utf8(Box::new(array.value(row).into())))
1121+
}
11061122
DataType::Boolean => {
11071123
let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
11081124
Ok(GroupByScalar::Boolean(array.value(row)))

datafusion/src/physical_plan/hash_join.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,9 @@ pub fn create_hashes<'a>(
831831
DataType::Utf8 => {
832832
hash_array!(StringArray, col, str, hashes_buffer, random_state);
833833
}
834+
DataType::LargeUtf8 => {
835+
hash_array!(LargeStringArray, col, str, hashes_buffer, random_state);
836+
}
834837
_ => {
835838
// This is internal because we should have caught this before.
836839
return Err(DataFusionError::Internal(

datafusion/src/physical_plan/type_coercion.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
196196
| Float64
197197
),
198198
Timestamp(TimeUnit::Nanosecond, None) => matches!(type_from, Timestamp(_, None)),
199-
Utf8 => true,
199+
Utf8 | LargeUtf8 => true,
200200
_ => false,
201201
}
202202
}

0 commit comments

Comments (0)