Skip to content

Commit 6b7198f

Browse files
committed
Initial Implementation of array_intersect
Signed-off-by: veeupup <code@tanweime.com>
1 parent 4512805 commit 6b7198f

11 files changed

Lines changed: 232 additions & 19 deletions

File tree

datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,12 +1536,10 @@ mod test {
15361536
.unwrap()
15371537
.resolve(&schema)
15381538
.unwrap();
1539-
let r4 = apache_avro::to_value(serde_json::json!({
1540-
"col1": null
1541-
}))
1542-
.unwrap()
1543-
.resolve(&schema)
1544-
.unwrap();
1539+
let r4 = apache_avro::to_value(serde_json::json!({ "col1": null }))
1540+
.unwrap()
1541+
.resolve(&schema)
1542+
.unwrap();
15451543

15461544
let mut w = apache_avro::Writer::new(&schema, vec![]);
15471545
w.append(r1).unwrap();
@@ -1600,12 +1598,10 @@ mod test {
16001598
}"#,
16011599
)
16021600
.unwrap();
1603-
let r1 = apache_avro::to_value(serde_json::json!({
1604-
"col1": null
1605-
}))
1606-
.unwrap()
1607-
.resolve(&schema)
1608-
.unwrap();
1601+
let r1 = apache_avro::to_value(serde_json::json!({ "col1": null }))
1602+
.unwrap()
1603+
.resolve(&schema)
1604+
.unwrap();
16091605
let r2 = apache_avro::to_value(serde_json::json!({
16101606
"col1": {
16111607
"col2": "hello"

datafusion/expr/src/built_in_function.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ pub enum BuiltinScalarFunction {
174174
ArraySlice,
175175
/// array_to_string
176176
ArrayToString,
177+
/// array_intersect
178+
ArrayIntersect,
177179
/// cardinality
178180
Cardinality,
179181
/// construct an array from columns
@@ -398,6 +400,7 @@ impl BuiltinScalarFunction {
398400
BuiltinScalarFunction::Flatten => Volatility::Immutable,
399401
BuiltinScalarFunction::ArraySlice => Volatility::Immutable,
400402
BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
403+
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
401404
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
402405
BuiltinScalarFunction::MakeArray => Volatility::Immutable,
403406
BuiltinScalarFunction::Ascii => Volatility::Immutable,
@@ -577,6 +580,7 @@ impl BuiltinScalarFunction {
577580
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
578581
BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
579582
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
583+
BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()),
580584
BuiltinScalarFunction::Cardinality => Ok(UInt64),
581585
BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
582586
0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
@@ -880,6 +884,7 @@ impl BuiltinScalarFunction {
880884
BuiltinScalarFunction::ArrayToString => {
881885
Signature::variadic_any(self.volatility())
882886
}
887+
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
883888
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
884889
BuiltinScalarFunction::MakeArray => {
885890
// 0 or more arguments of arbitrary type
@@ -1505,6 +1510,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
15051510
],
15061511
BuiltinScalarFunction::Cardinality => &["cardinality"],
15071512
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
1513+
BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_interact"],
15081514

15091515
// struct functions
15101516
BuiltinScalarFunction::Struct => &["struct"],

datafusion/expr/src/expr_fn.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,12 @@ nary_scalar_expr!(
715715
array,
716716
"returns an Arrow array using the specified input expressions."
717717
);
718+
scalar_expr!(
719+
ArrayIntersect,
720+
array_intersect,
721+
first_array second_array,
722+
"Returns an array of the elements in the intersection of array1 and array2."
723+
);
718724

719725
// string functions
720726
scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character");

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use arrow::array::*;
2424
use arrow::buffer::OffsetBuffer;
2525
use arrow::compute;
2626
use arrow::datatypes::{DataType, Field, UInt64Type};
27+
use arrow::row::{RowConverter, SortField};
2728
use arrow_buffer::NullBuffer;
2829

2930
use datafusion_common::cast::{
@@ -35,6 +36,7 @@ use datafusion_common::{
3536
DataFusionError, Result,
3637
};
3738

39+
use hashbrown::HashSet;
3840
use itertools::Itertools;
3941

4042
macro_rules! downcast_arg {
@@ -347,7 +349,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
347349
let data_type = arrays[0].data_type();
348350
let field = Arc::new(Field::new("item", data_type.to_owned(), true));
349351
let elements = arrays.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
350-
let values = arrow::compute::concat(elements.as_slice())?;
352+
let values = compute::concat(elements.as_slice())?;
351353
let list_arr = ListArray::new(
352354
field,
353355
OffsetBuffer::from_lengths(array_lengths),
@@ -368,7 +370,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
368370
.iter()
369371
.map(|x| x as &dyn Array)
370372
.collect::<Vec<_>>();
371-
let values = arrow::compute::concat(elements.as_slice())?;
373+
let values = compute::concat(elements.as_slice())?;
372374
let list_arr = ListArray::new(
373375
field,
374376
OffsetBuffer::from_lengths(list_array_lengths),
@@ -801,7 +803,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
801803
.collect::<Vec<&dyn Array>>();
802804

803805
// Concatenated array on i-th row
804-
let concated_array = arrow::compute::concat(elements.as_slice())?;
806+
let concated_array = compute::concat(elements.as_slice())?;
805807
array_lengths.push(concated_array.len());
806808
arrays.push(concated_array);
807809
valid.append(true);
@@ -819,7 +821,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
819821
let list_arr = ListArray::new(
820822
Arc::new(Field::new("item", data_type, true)),
821823
OffsetBuffer::from_lengths(array_lengths),
822-
Arc::new(arrow::compute::concat(elements.as_slice())?),
824+
Arc::new(compute::concat(elements.as_slice())?),
823825
Some(NullBuffer::new(buffer)),
824826
);
825827
Ok(Arc::new(list_arr))
@@ -913,7 +915,7 @@ fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result<ArrayRef
913915
}
914916

915917
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
916-
let values = arrow::compute::concat(&new_values)?;
918+
let values = compute::concat(&new_values)?;
917919

918920
Ok(Arc::new(ListArray::try_new(
919921
Arc::new(Field::new("item", data_type.to_owned(), true)),
@@ -981,7 +983,7 @@ fn general_list_repeat(
981983

982984
let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>();
983985
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
984-
let values = arrow::compute::concat(&new_values)?;
986+
let values = compute::concat(&new_values)?;
985987

986988
Ok(Arc::new(ListArray::try_new(
987989
Arc::new(Field::new("item", data_type.to_owned(), true)),
@@ -1294,7 +1296,7 @@ fn general_replace(args: &[ArrayRef], arr_n: Vec<i64>) -> Result<ArrayRef> {
12941296
let data = mutable.freeze();
12951297
let replaced_array = arrow_array::make_array(data);
12961298

1297-
let v = arrow::compute::concat(&[&values, &replaced_array])?;
1299+
let v = compute::concat(&[&values, &replaced_array])?;
12981300
values = v;
12991301
offsets.push(last_offset + replaced_array.len() as i32);
13001302
}
@@ -1807,6 +1809,61 @@ pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef
18071809
Ok(Arc::new(list_array) as ArrayRef)
18081810
}
18091811

1812+
/// array_intersect SQL function
1813+
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
1814+
assert_eq!(args.len(), 2);
1815+
1816+
let first_array = as_list_array(&args[0])?;
1817+
let second_array = as_list_array(&args[1])?;
1818+
1819+
if first_array.value_type() != second_array.value_type() {
1820+
return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'");
1821+
}
1822+
let dt = first_array.value_type().clone();
1823+
1824+
let mut offsets = vec![0];
1825+
let mut tmp_values = vec![];
1826+
1827+
let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
1828+
for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
1829+
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
1830+
let l_values = converter.convert_columns(&[first_arr])?;
1831+
let r_values = converter.convert_columns(&[second_arr])?;
1832+
1833+
let values_set: HashSet<_> = l_values.iter().collect();
1834+
let mut rows = Vec::with_capacity(r_values.num_rows());
1835+
for r_val in r_values.iter().sorted().dedup() {
1836+
if values_set.contains(&r_val) {
1837+
rows.push(r_val);
1838+
}
1839+
}
1840+
1841+
let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
1842+
DataFusionError::Internal(format!("offsets should not be empty"))
1843+
})?;
1844+
offsets.push(last_offset + rows.len() as i32);
1845+
let tmp_value = converter.convert_rows(rows)?;
1846+
tmp_values.push(
1847+
tmp_value
1848+
.get(0)
1849+
.ok_or_else(|| {
1850+
DataFusionError::Internal(format!(
1851+
"array_intersect: failed to get value from rows"
1852+
))
1853+
})?
1854+
.clone(),
1855+
);
1856+
}
1857+
}
1858+
1859+
let field = Arc::new(Field::new("item", dt, true));
1860+
let offsets = OffsetBuffer::new(offsets.into());
1861+
let tmp_values_ref = tmp_values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
1862+
let values = compute::concat(&tmp_values_ref)?;
1863+
let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?);
1864+
Ok(arr)
1865+
}
1866+
18101867
#[cfg(test)]
18111868
mod tests {
18121869
use super::*;

datafusion/physical-expr/src/functions.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,9 @@ pub fn create_physical_fun(
532532
BuiltinScalarFunction::ArrayToString => Arc::new(|args| {
533533
make_scalar_function(array_expressions::array_to_string)(args)
534534
}),
535+
BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| {
536+
make_scalar_function(array_expressions::array_intersect)(args)
537+
}),
535538
BuiltinScalarFunction::Cardinality => {
536539
Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args))
537540
}

datafusion/proto/proto/datafusion.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,7 @@ enum ScalarFunction {
621621
ArrayPopBack = 116;
622622
StringToArray = 117;
623623
ToTimestampNanos = 118;
624+
ArrayIntersect = 119;
624625
}
625626

626627
message ScalarFunctionNode {

datafusion/proto/src/generated/prost.rs

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/logical_plan/from_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
482482
ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
483483
ScalarFunction::ArraySlice => Self::ArraySlice,
484484
ScalarFunction::ArrayToString => Self::ArrayToString,
485+
ScalarFunction::ArrayIntersect => Self::ArrayIntersect,
485486
ScalarFunction::Cardinality => Self::Cardinality,
486487
ScalarFunction::Array => Self::MakeArray,
487488
ScalarFunction::NullIf => Self::NullIf,

datafusion/proto/src/logical_plan/to_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,6 +1481,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
14811481
BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
14821482
BuiltinScalarFunction::ArraySlice => Self::ArraySlice,
14831483
BuiltinScalarFunction::ArrayToString => Self::ArrayToString,
1484+
BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect,
14841485
BuiltinScalarFunction::Cardinality => Self::Cardinality,
14851486
BuiltinScalarFunction::MakeArray => Self::Array,
14861487
BuiltinScalarFunction::NullIf => Self::NullIf,

0 commit comments

Comments
 (0)