@@ -24,6 +24,7 @@ use arrow::array::*;
2424use arrow:: buffer:: OffsetBuffer ;
2525use arrow:: compute;
2626use arrow:: datatypes:: { DataType , Field , UInt64Type } ;
27+ use arrow:: row:: { RowConverter , SortField } ;
2728use arrow_buffer:: NullBuffer ;
2829
2930use datafusion_common:: cast:: {
@@ -35,6 +36,7 @@ use datafusion_common::{
3536 DataFusionError , Result ,
3637} ;
3738
39+ use hashbrown:: HashSet ;
3840use itertools:: Itertools ;
3941
4042macro_rules! downcast_arg {
@@ -347,7 +349,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
347349 let data_type = arrays[ 0 ] . data_type ( ) ;
348350 let field = Arc :: new ( Field :: new ( "item" , data_type. to_owned ( ) , true ) ) ;
349351 let elements = arrays. iter ( ) . map ( |x| x. as_ref ( ) ) . collect :: < Vec < _ > > ( ) ;
350- let values = arrow :: compute:: concat ( elements. as_slice ( ) ) ?;
352+ let values = compute:: concat ( elements. as_slice ( ) ) ?;
351353 let list_arr = ListArray :: new (
352354 field,
353355 OffsetBuffer :: from_lengths ( array_lengths) ,
@@ -368,7 +370,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
368370 . iter ( )
369371 . map ( |x| x as & dyn Array )
370372 . collect :: < Vec < _ > > ( ) ;
371- let values = arrow :: compute:: concat ( elements. as_slice ( ) ) ?;
373+ let values = compute:: concat ( elements. as_slice ( ) ) ?;
372374 let list_arr = ListArray :: new (
373375 field,
374376 OffsetBuffer :: from_lengths ( list_array_lengths) ,
@@ -801,7 +803,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
801803 . collect :: < Vec < & dyn Array > > ( ) ;
802804
803805 // Concatenated array on i-th row
804- let concated_array = arrow :: compute:: concat ( elements. as_slice ( ) ) ?;
806+ let concated_array = compute:: concat ( elements. as_slice ( ) ) ?;
805807 array_lengths. push ( concated_array. len ( ) ) ;
806808 arrays. push ( concated_array) ;
807809 valid. append ( true ) ;
@@ -819,7 +821,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
819821 let list_arr = ListArray :: new (
820822 Arc :: new ( Field :: new ( "item" , data_type, true ) ) ,
821823 OffsetBuffer :: from_lengths ( array_lengths) ,
822- Arc :: new ( arrow :: compute:: concat ( elements. as_slice ( ) ) ?) ,
824+ Arc :: new ( compute:: concat ( elements. as_slice ( ) ) ?) ,
823825 Some ( NullBuffer :: new ( buffer) ) ,
824826 ) ;
825827 Ok ( Arc :: new ( list_arr) )
@@ -913,7 +915,7 @@ fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result<ArrayRef
913915 }
914916
915917 let new_values: Vec < _ > = new_values. iter ( ) . map ( |a| a. as_ref ( ) ) . collect ( ) ;
916- let values = arrow :: compute:: concat ( & new_values) ?;
918+ let values = compute:: concat ( & new_values) ?;
917919
918920 Ok ( Arc :: new ( ListArray :: try_new (
919921 Arc :: new ( Field :: new ( "item" , data_type. to_owned ( ) , true ) ) ,
@@ -981,7 +983,7 @@ fn general_list_repeat(
981983
982984 let lengths = new_values. iter ( ) . map ( |a| a. len ( ) ) . collect :: < Vec < _ > > ( ) ;
983985 let new_values: Vec < _ > = new_values. iter ( ) . map ( |a| a. as_ref ( ) ) . collect ( ) ;
984- let values = arrow :: compute:: concat ( & new_values) ?;
986+ let values = compute:: concat ( & new_values) ?;
985987
986988 Ok ( Arc :: new ( ListArray :: try_new (
987989 Arc :: new ( Field :: new ( "item" , data_type. to_owned ( ) , true ) ) ,
@@ -1294,7 +1296,7 @@ fn general_replace(args: &[ArrayRef], arr_n: Vec<i64>) -> Result<ArrayRef> {
12941296 let data = mutable. freeze ( ) ;
12951297 let replaced_array = arrow_array:: make_array ( data) ;
12961298
1297- let v = arrow :: compute:: concat ( & [ & values, & replaced_array] ) ?;
1299+ let v = compute:: concat ( & [ & values, & replaced_array] ) ?;
12981300 values = v;
12991301 offsets. push ( last_offset + replaced_array. len ( ) as i32 ) ;
13001302 }
@@ -1807,6 +1809,61 @@ pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef
18071809 Ok ( Arc :: new ( list_array) as ArrayRef )
18081810}
18091811
1812+ /// array_intersect SQL function
1813+ pub fn array_intersect ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
1814+ assert_eq ! ( args. len( ) , 2 ) ;
1815+
1816+ let first_array = as_list_array ( & args[ 0 ] ) ?;
1817+ let second_array = as_list_array ( & args[ 1 ] ) ?;
1818+
1819+ if first_array. value_type ( ) != second_array. value_type ( ) {
1820+ return internal_err ! ( "array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'" ) ;
1821+ }
1822+ let dt = first_array. value_type ( ) . clone ( ) ;
1823+
1824+ let mut offsets = vec ! [ 0 ] ;
1825+ let mut tmp_values = vec ! [ ] ;
1826+
1827+ let converter = RowConverter :: new ( vec ! [ SortField :: new( dt. clone( ) ) ] ) ?;
1828+ for ( first_arr, second_arr) in first_array. iter ( ) . zip ( second_array. iter ( ) ) {
1829+ if let ( Some ( first_arr) , Some ( second_arr) ) = ( first_arr, second_arr) {
1830+ let l_values = converter. convert_columns ( & [ first_arr] ) ?;
1831+ let r_values = converter. convert_columns ( & [ second_arr] ) ?;
1832+
1833+ let values_set: HashSet < _ > = l_values. iter ( ) . collect ( ) ;
1834+ let mut rows = Vec :: with_capacity ( r_values. num_rows ( ) ) ;
1835+ for r_val in r_values. iter ( ) . sorted ( ) . dedup ( ) {
1836+ if values_set. contains ( & r_val) {
1837+ rows. push ( r_val) ;
1838+ }
1839+ }
1840+
1841+ let last_offset: i32 = offsets. last ( ) . copied ( ) . ok_or_else ( || {
1842+ DataFusionError :: Internal ( format ! ( "offsets should not be empty" ) )
1843+ } ) ?;
1844+ offsets. push ( last_offset + rows. len ( ) as i32 ) ;
1845+ let tmp_value = converter. convert_rows ( rows) ?;
1846+ tmp_values. push (
1847+ tmp_value
1848+ . get ( 0 )
1849+ . ok_or_else ( || {
1850+ DataFusionError :: Internal ( format ! (
1851+ "array_intersect: failed to get value from rows"
1852+ ) )
1853+ } ) ?
1854+ . clone ( ) ,
1855+ ) ;
1856+ }
1857+ }
1858+
1859+ let field = Arc :: new ( Field :: new ( "item" , dt, true ) ) ;
1860+ let offsets = OffsetBuffer :: new ( offsets. into ( ) ) ;
1861+ let tmp_values_ref = tmp_values. iter ( ) . map ( |v| v. as_ref ( ) ) . collect :: < Vec < _ > > ( ) ;
1862+ let values = compute:: concat ( & tmp_values_ref) ?;
1863+ let arr = Arc :: new ( ListArray :: try_new ( field, offsets, values, None ) ?) ;
1864+ Ok ( arr)
1865+ }
1866+
18101867#[ cfg( test) ]
18111868mod tests {
18121869 use super :: * ;
0 commit comments