Skip to content

Commit 4bed04e

Browse files
authored
Add customizable equality and hash functions to UDFs (#11392)
* Add customizable equality and hash functions to UDFs
* Improve `equals` and `hash_value` documentation
* Add tests for parameterized UDFs
1 parent d314ced commit 4bed04e

5 files changed

Lines changed: 367 additions & 44 deletions

File tree

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 74 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -18,14 +18,19 @@
1818
//! This module contains end-to-end demonstrations of creating
1919
//! user defined aggregate functions
2020
21-
use arrow::{array::AsArray, datatypes::Fields};
22-
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
23-
use arrow_schema::Schema;
21+
use std::hash::{DefaultHasher, Hash, Hasher};
2422
use std::sync::{
2523
atomic::{AtomicBool, Ordering},
2624
Arc,
2725
};
2826

27+
use arrow::{array::AsArray, datatypes::Fields};
28+
use arrow_array::{
29+
types::UInt64Type, Int32Array, PrimitiveArray, StringArray, StructArray,
30+
};
31+
use arrow_schema::Schema;
32+
33+
use datafusion::dataframe::DataFrame;
2934
use datafusion::datasource::MemTable;
3035
use datafusion::test_util::plan_and_collect;
3136
use datafusion::{
@@ -45,8 +50,8 @@ use datafusion::{
4550
};
4651
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
4752
use datafusion_expr::{
48-
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
49-
SimpleAggregateUDF,
53+
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
54+
LogicalPlanBuilder, SimpleAggregateUDF,
5055
};
5156
use datafusion_functions_aggregate::average::AvgAccumulator;
5257

@@ -377,6 +382,55 @@ async fn test_groups_accumulator() -> Result<()> {
377382
Ok(())
378383
}
379384

385+
#[tokio::test]
386+
async fn test_parameterized_aggregate_udf() -> Result<()> {
387+
let batch = RecordBatch::try_from_iter([(
388+
"text",
389+
Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
390+
)])?;
391+
392+
let ctx = SessionContext::new();
393+
ctx.register_batch("t", batch)?;
394+
let t = ctx.table("t").await?;
395+
let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
396+
let udf1 = AggregateUDF::from(TestGroupsAccumulator {
397+
signature: signature.clone(),
398+
result: 1,
399+
});
400+
let udf2 = AggregateUDF::from(TestGroupsAccumulator {
401+
signature: signature.clone(),
402+
result: 2,
403+
});
404+
405+
let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
406+
.aggregate(
407+
[col("text")],
408+
[
409+
udf1.call(vec![col("text")]).alias("a"),
410+
udf2.call(vec![col("text")]).alias("b"),
411+
],
412+
)?
413+
.build()?;
414+
415+
assert_eq!(
416+
format!("{plan:?}"),
417+
"Aggregate: groupBy=[[t.text]], aggr=[[geo_mean(t.text) AS a, geo_mean(t.text) AS b]]\n TableScan: t projection=[text]"
418+
);
419+
420+
let actual = DataFrame::new(ctx.state(), plan).collect().await?;
421+
let expected = [
422+
"+------+---+---+",
423+
"| text | a | b |",
424+
"+------+---+---+",
425+
"| foo | 1 | 2 |",
426+
"+------+---+---+",
427+
];
428+
assert_batches_eq!(expected, &actual);
429+
430+
ctx.deregister_table("t")?;
431+
Ok(())
432+
}
433+
380434
/// Returns a context with a table "t" and the "first" and "time_sum"
381435
/// aggregate functions registered.
382436
///
@@ -735,6 +789,21 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
735789
) -> Result<Box<dyn GroupsAccumulator>> {
736790
Ok(Box::new(self.clone()))
737791
}
792+
793+
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
794+
if let Some(other) = other.as_any().downcast_ref::<TestGroupsAccumulator>() {
795+
self.result == other.result && self.signature == other.signature
796+
} else {
797+
false
798+
}
799+
}
800+
801+
fn hash_value(&self) -> u64 {
802+
let hasher = &mut DefaultHasher::new();
803+
self.signature.hash(hasher);
804+
self.result.hash(hasher);
805+
hasher.finish()
806+
}
738807
}
739808

740809
impl Accumulator for TestGroupsAccumulator {

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 125 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,20 @@
1616
// under the License.
1717

1818
use std::any::Any;
19+
use std::hash::{DefaultHasher, Hash, Hasher};
1920
use std::sync::Arc;
2021

2122
use arrow::compute::kernels::numeric::add;
22-
use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch};
23+
use arrow_array::builder::BooleanBuilder;
24+
use arrow_array::cast::AsArray;
25+
use arrow_array::{
26+
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray,
27+
};
2328
use arrow_schema::{DataType, Field, Schema};
29+
use parking_lot::Mutex;
30+
use regex::Regex;
31+
use sqlparser::ast::Ident;
32+
2433
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
2534
use datafusion::prelude::*;
2635
use datafusion::{execution::registry::FunctionRegistry, test_util};
@@ -37,8 +46,6 @@ use datafusion_expr::{
3746
Volatility,
3847
};
3948
use datafusion_functions_array::range::range_udf;
40-
use parking_lot::Mutex;
41-
use sqlparser::ast::Ident;
4249

4350
/// test that casting happens on udfs.
4451
/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
@@ -1021,6 +1028,121 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<(
10211028
Ok(())
10221029
}
10231030

1031+
#[derive(Debug)]
1032+
struct MyRegexUdf {
1033+
signature: Signature,
1034+
regex: Regex,
1035+
}
1036+
1037+
impl MyRegexUdf {
1038+
fn new(pattern: &str) -> Self {
1039+
Self {
1040+
signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable),
1041+
regex: Regex::new(pattern).expect("regex"),
1042+
}
1043+
}
1044+
1045+
fn matches(&self, value: Option<&str>) -> Option<bool> {
1046+
Some(self.regex.is_match(value?))
1047+
}
1048+
}
1049+
1050+
impl ScalarUDFImpl for MyRegexUdf {
1051+
fn as_any(&self) -> &dyn Any {
1052+
self
1053+
}
1054+
1055+
fn name(&self) -> &str {
1056+
"regex_udf"
1057+
}
1058+
1059+
fn signature(&self) -> &Signature {
1060+
&self.signature
1061+
}
1062+
1063+
fn return_type(&self, args: &[DataType]) -> Result<DataType> {
1064+
if matches!(args, [DataType::Utf8]) {
1065+
Ok(DataType::Boolean)
1066+
} else {
1067+
plan_err!("regex_udf only accepts a Utf8 argument")
1068+
}
1069+
}
1070+
1071+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
1072+
match args {
1073+
[ColumnarValue::Scalar(ScalarValue::Utf8(value))] => {
1074+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
1075+
self.matches(value.as_deref()),
1076+
)))
1077+
}
1078+
[ColumnarValue::Array(values)] => {
1079+
let mut builder = BooleanBuilder::with_capacity(values.len());
1080+
for value in values.as_string::<i32>() {
1081+
builder.append_option(self.matches(value))
1082+
}
1083+
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
1084+
}
1085+
_ => exec_err!("regex_udf only accepts a Utf8 arguments"),
1086+
}
1087+
}
1088+
1089+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
1090+
if let Some(other) = other.as_any().downcast_ref::<MyRegexUdf>() {
1091+
self.regex.as_str() == other.regex.as_str()
1092+
} else {
1093+
false
1094+
}
1095+
}
1096+
1097+
fn hash_value(&self) -> u64 {
1098+
let hasher = &mut DefaultHasher::new();
1099+
self.regex.as_str().hash(hasher);
1100+
hasher.finish()
1101+
}
1102+
}
1103+
1104+
#[tokio::test]
1105+
async fn test_parameterized_scalar_udf() -> Result<()> {
1106+
let batch = RecordBatch::try_from_iter([(
1107+
"text",
1108+
Arc::new(StringArray::from(vec!["foo", "bar", "foobar", "barfoo"])) as ArrayRef,
1109+
)])?;
1110+
1111+
let ctx = SessionContext::new();
1112+
ctx.register_batch("t", batch)?;
1113+
let t = ctx.table("t").await?;
1114+
let foo_udf = ScalarUDF::from(MyRegexUdf::new("fo{2}"));
1115+
let bar_udf = ScalarUDF::from(MyRegexUdf::new("[Bb]ar"));
1116+
1117+
let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
1118+
.filter(
1119+
foo_udf
1120+
.call(vec![col("text")])
1121+
.and(bar_udf.call(vec![col("text")])),
1122+
)?
1123+
.filter(col("text").is_not_null())?
1124+
.build()?;
1125+
1126+
assert_eq!(
1127+
format!("{plan:?}"),
1128+
"Filter: t.text IS NOT NULL\n Filter: regex_udf(t.text) AND regex_udf(t.text)\n TableScan: t projection=[text]"
1129+
);
1130+
1131+
let actual = DataFrame::new(ctx.state(), plan).collect().await?;
1132+
let expected = [
1133+
"+--------+",
1134+
"| text |",
1135+
"+--------+",
1136+
"| foobar |",
1137+
"| barfoo |",
1138+
"+--------+",
1139+
];
1140+
assert_batches_eq!(expected, &actual);
1141+
1142+
ctx.deregister_table("t")?;
1143+
Ok(())
1144+
}
1145+
10241146
fn create_udf_context() -> SessionContext {
10251147
let ctx = SessionContext::new();
10261148
// register a custom UDF

datafusion/expr/src/udaf.rs

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@
1717

1818
//! [`AggregateUDF`]: User Defined Aggregate Functions
1919
20+
use std::any::Any;
21+
use std::fmt::{self, Debug, Formatter};
22+
use std::hash::{DefaultHasher, Hash, Hasher};
23+
use std::sync::Arc;
24+
use std::vec;
25+
26+
use arrow::datatypes::{DataType, Field};
27+
use sqlparser::ast::NullTreatment;
28+
29+
use datafusion_common::{exec_err, not_impl_err, plan_err, Result};
30+
2031
use crate::expr::AggregateFunction;
2132
use crate::function::{
2233
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
@@ -26,13 +37,6 @@ use crate::utils::format_state_name;
2637
use crate::utils::AggregateOrderSensitivity;
2738
use crate::{Accumulator, Expr};
2839
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
29-
use arrow::datatypes::{DataType, Field};
30-
use datafusion_common::{exec_err, not_impl_err, plan_err, Result};
31-
use sqlparser::ast::NullTreatment;
32-
use std::any::Any;
33-
use std::fmt::{self, Debug, Formatter};
34-
use std::sync::Arc;
35-
use std::vec;
3640

3741
/// Logical representation of a user-defined [aggregate function] (UDAF).
3842
///
@@ -72,20 +76,19 @@ pub struct AggregateUDF {
7276

7377
impl PartialEq for AggregateUDF {
7478
fn eq(&self, other: &Self) -> bool {
75-
self.name() == other.name() && self.signature() == other.signature()
79+
self.inner.equals(other.inner.as_ref())
7680
}
7781
}
7882

7983
impl Eq for AggregateUDF {}
8084

81-
impl std::hash::Hash for AggregateUDF {
82-
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
83-
self.name().hash(state);
84-
self.signature().hash(state);
85+
impl Hash for AggregateUDF {
86+
fn hash<H: Hasher>(&self, state: &mut H) {
87+
self.inner.hash_value().hash(state)
8588
}
8689
}
8790

88-
impl std::fmt::Display for AggregateUDF {
91+
impl fmt::Display for AggregateUDF {
8992
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
9093
write!(f, "{}", self.name())
9194
}
@@ -280,7 +283,7 @@ where
280283
/// #[derive(Debug, Clone)]
281284
/// struct GeoMeanUdf {
282285
/// signature: Signature
283-
/// };
286+
/// }
284287
///
285288
/// impl GeoMeanUdf {
286289
/// fn new() -> Self {
@@ -507,6 +510,33 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
507510
fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
508511
not_impl_err!("Function {} does not implement coerce_types", self.name())
509512
}
513+
514+
/// Return true if this aggregate UDF is equal to the other.
515+
///
516+
/// Allows customizing the equality of aggregate UDFs.
517+
/// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
518+
///
519+
/// - reflexive: `a.equals(a)`;
520+
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
521+
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
522+
///
523+
/// By default, compares [`Self::name`] and [`Self::signature`].
524+
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
525+
self.name() == other.name() && self.signature() == other.signature()
526+
}
527+
528+
/// Returns a hash value for this aggregate UDF.
529+
///
530+
/// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`],
531+
/// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
532+
///
533+
/// By default, hashes [`Self::name`] and [`Self::signature`].
534+
fn hash_value(&self) -> u64 {
535+
let hasher = &mut DefaultHasher::new();
536+
self.name().hash(hasher);
537+
self.signature().hash(hasher);
538+
hasher.finish()
539+
}
510540
}
511541

512542
pub enum ReversedUDAF {
@@ -562,6 +592,21 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
562592
fn aliases(&self) -> &[String] {
563593
&self.aliases
564594
}
595+
596+
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
597+
if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
598+
self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
599+
} else {
600+
false
601+
}
602+
}
603+
604+
fn hash_value(&self) -> u64 {
605+
let hasher = &mut DefaultHasher::new();
606+
self.inner.hash_value().hash(hasher);
607+
self.aliases.hash(hasher);
608+
hasher.finish()
609+
}
565610
}
566611

567612
/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers

0 commit comments

Comments
 (0)