Skip to content

Commit c50b82d

Browse files
committed
Derive WindowUDFImpl equality, hash from Eq, Hash traits
Previously, the `WindowUDFImpl` trait contained `equals` and `hash_value` methods with contracts following the `Eq` and `Hash` traits. However, the existence of default implementations of these methods made the trait error-prone: many functions (scalar, aggregate, window) failed to customize equality even though they ought to. There is no way to fix this without an API-breaking change, so a breaking change is warranted. Removing the default implementations would be enough of a solution, but at the cost of a lot of boilerplate in implementations. Instead, this change removes the methods from the trait and reuses the `DynEq` and `DynHash` traits, which were previously used only for physical expressions. This allows functions to provide their implementations using no more than `#[derive(PartialEq, Eq, Hash)]` in the typical case.
1 parent 5987a22 commit c50b82d

18 files changed

Lines changed: 88 additions & 118 deletions

File tree

datafusion-examples/examples/advanced_udwf.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use datafusion::prelude::*;
4343
/// a function `partition_evaluator` that returns the `MyPartitionEvaluator` instance.
4444
///
4545
/// To do so, we must implement the `WindowUDFImpl` trait.
46-
#[derive(Debug, Clone)]
46+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
4747
struct SmoothItUdf {
4848
signature: Signature,
4949
}
@@ -149,7 +149,7 @@ impl PartitionEvaluator for MyPartitionEvaluator {
149149
}
150150

151151
/// This UDWF will show how to use the WindowUDFImpl::simplify() API
152-
#[derive(Debug, Clone)]
152+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
153153
struct SimplifySmoothItUdf {
154154
signature: Signature,
155155
}

datafusion/core/tests/user_defined/user_defined_window_functions.rs

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ use datafusion::prelude::SessionContext;
3030
use datafusion_common::exec_datafusion_err;
3131
use datafusion_expr::ptr_eq::PtrEq;
3232
use datafusion_expr::{
33-
udf_equals_hash, PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF,
34-
WindowUDFImpl,
33+
PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl,
3534
};
3635
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
3736
use datafusion_functions_window_common::{
@@ -42,7 +41,7 @@ use datafusion_physical_expr::{
4241
PhysicalExpr,
4342
};
4443
use std::collections::HashMap;
45-
use std::hash::{DefaultHasher, Hash, Hasher};
44+
use std::hash::{Hash, Hasher};
4645
use std::{
4746
any::Any,
4847
ops::Range,
@@ -571,8 +570,6 @@ impl OddCounter {
571570
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
572571
Ok(Field::new(field_args.name(), DataType::Int64, true).into())
573572
}
574-
575-
udf_equals_hash!(WindowUDFImpl);
576573
}
577574

578575
ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state)))
@@ -648,7 +645,7 @@ fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> ArrayRef {
648645
Arc::new(array)
649646
}
650647

651-
#[derive(Debug)]
648+
#[derive(Debug, PartialEq, Eq, Hash)]
652649
struct VariadicWindowUDF {
653650
signature: Signature,
654651
}
@@ -770,6 +767,31 @@ struct MetadataBasedWindowUdf {
770767
metadata: HashMap<String, String>,
771768
}
772769

770+
impl PartialEq for MetadataBasedWindowUdf {
771+
fn eq(&self, other: &Self) -> bool {
772+
let Self {
773+
name,
774+
signature,
775+
metadata,
776+
} = self;
777+
name == &other.name
778+
&& signature == &other.signature
779+
&& metadata == &other.metadata
780+
}
781+
}
782+
impl Eq for MetadataBasedWindowUdf {}
783+
impl Hash for MetadataBasedWindowUdf {
784+
fn hash<H: Hasher>(&self, state: &mut H) {
785+
let Self {
786+
name,
787+
signature,
788+
metadata: _, // unhashable
789+
} = self;
790+
name.hash(state);
791+
signature.hash(state);
792+
}
793+
}
794+
773795
impl MetadataBasedWindowUdf {
774796
fn new(metadata: HashMap<String, String>) -> Self {
775797
// The name we return must be unique. Otherwise we will not call distinct
@@ -820,33 +842,6 @@ impl WindowUDFImpl for MetadataBasedWindowUdf {
820842
.with_metadata(self.metadata.clone())
821843
.into())
822844
}
823-
824-
fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
825-
let Some(other) = other.as_any().downcast_ref::<Self>() else {
826-
return false;
827-
};
828-
let Self {
829-
name,
830-
signature,
831-
metadata,
832-
} = self;
833-
name == &other.name
834-
&& signature == &other.signature
835-
&& metadata == &other.metadata
836-
}
837-
838-
fn hash_value(&self) -> u64 {
839-
let Self {
840-
name,
841-
signature,
842-
metadata: _, // unhashable
843-
} = self;
844-
let mut hasher = DefaultHasher::new();
845-
std::any::type_name::<Self>().hash(&mut hasher);
846-
name.hash(&mut hasher);
847-
signature.hash(&mut hasher);
848-
hasher.finish()
849-
}
850845
}
851846

852847
#[derive(Debug)]

datafusion/expr-common/src/dyn_eq.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,25 +25,43 @@ use std::hash::{Hash, Hasher};
2525
/// they must be [`DynEq`]-equal if and only if they are [`PartialEq`]-equal.
2626
/// It is therefore strongly discouraged to implement this trait for types
2727
/// that implement `PartialEq<Other>` for any type `Other` other than `Self`.
28+
///
29+
/// Note: This trait should not be implemented directly. Implement `Eq` and `Any` and use
30+
/// the blanket implementation.
2831
pub trait DynEq {
2932
fn dyn_eq(&self, other: &dyn Any) -> bool;
33+
34+
fn i_did_not_implement_the_trait_directly_but_using_the_blanked_impl_instead()
35+
where
36+
Self: Sized;
3037
}
3138

3239
impl<T: Eq + Any> DynEq for T {
3340
fn dyn_eq(&self, other: &dyn Any) -> bool {
3441
other.downcast_ref::<Self>() == Some(self)
3542
}
43+
44+
fn i_did_not_implement_the_trait_directly_but_using_the_blanked_impl_instead() {}
3645
}
3746

3847
/// A dyn-compatible version of [`Hash`] trait.
3948
/// If two values are equal according to [`DynEq`], they must produce the same hash value.
49+
///
50+
/// Note: This trait should not be implemented directly. Implement `Hash` and `Any` and use
51+
/// the blanket implementation.
4052
pub trait DynHash {
4153
fn dyn_hash(&self, _state: &mut dyn Hasher);
54+
55+
fn i_did_not_implement_the_trait_directly_but_using_the_blanked_impl_instead()
56+
where
57+
Self: Sized;
4258
}
4359

4460
impl<T: Hash + Any> DynHash for T {
4561
fn dyn_hash(&self, mut state: &mut dyn Hasher) {
4662
self.type_id().hash(&mut state);
4763
self.hash(&mut state)
4864
}
65+
66+
fn i_did_not_implement_the_trait_directly_but_using_the_blanked_impl_instead() {}
4967
}

datafusion/expr/src/expr_fn.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -695,8 +695,6 @@ impl WindowUDFImpl for SimpleWindowUDF {
695695
true,
696696
)))
697697
}
698-
699-
udf_equals_hash!(WindowUDFImpl);
700698
}
701699

702700
pub fn interval_year_month_lit(value: &str) -> Expr {

datafusion/expr/src/udf_eq.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use crate::{AggregateUDFImpl, ScalarUDFImpl, WindowUDFImpl};
1919
use std::fmt::Debug;
20-
use std::hash::{Hash, Hasher};
20+
use std::hash::{DefaultHasher, Hash, Hasher};
2121
use std::ops::Deref;
2222
use std::sync::Arc;
2323

@@ -97,7 +97,18 @@ macro_rules! impl_for_udf_eq {
9797

9898
impl_for_udf_eq!(dyn AggregateUDFImpl + '_);
9999
impl_for_udf_eq!(dyn ScalarUDFImpl + '_);
100-
impl_for_udf_eq!(dyn WindowUDFImpl + '_);
100+
101+
impl UdfPointer for Arc<dyn WindowUDFImpl + '_> {
102+
fn equals(&self, other: &(dyn WindowUDFImpl + '_)) -> bool {
103+
self.as_ref().dyn_eq(other.as_any())
104+
}
105+
106+
fn hash_value(&self) -> u64 {
107+
let hasher = &mut DefaultHasher::new();
108+
self.as_ref().dyn_hash(hasher);
109+
hasher.finish()
110+
}
111+
}
101112

102113
#[cfg(test)]
103114
mod tests {

datafusion/expr/src/udwf.rs

Lines changed: 10 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
2020
use arrow::compute::SortOptions;
2121
use std::cmp::Ordering;
22-
use std::hash::{DefaultHasher, Hash, Hasher};
22+
use std::hash::{Hash, Hasher};
2323
use std::{
2424
any::Any,
2525
fmt::{self, Debug, Display, Formatter},
@@ -31,11 +31,11 @@ use arrow::datatypes::{DataType, FieldRef};
3131
use crate::expr::WindowFunction;
3232
use crate::udf_eq::UdfEq;
3333
use crate::{
34-
function::WindowFunctionSimplification, udf_equals_hash, Expr, PartitionEvaluator,
35-
Signature,
34+
function::WindowFunctionSimplification, Expr, PartitionEvaluator, Signature,
3635
};
3736
use datafusion_common::{not_impl_err, Result};
3837
use datafusion_doc::Documentation;
38+
use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
3939
use datafusion_functions_window_common::expr::ExpressionArgs;
4040
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
4141
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
@@ -82,15 +82,15 @@ impl Display for WindowUDF {
8282

8383
impl PartialEq for WindowUDF {
8484
fn eq(&self, other: &Self) -> bool {
85-
self.inner.equals(other.inner.as_ref())
85+
self.inner.dyn_eq(&other.inner)
8686
}
8787
}
8888

8989
impl Eq for WindowUDF {}
9090

9191
impl Hash for WindowUDF {
9292
fn hash<H: Hasher>(&self, state: &mut H) {
93-
self.inner.hash_value().hash(state)
93+
self.inner.dyn_hash(state)
9494
}
9595
}
9696

@@ -246,7 +246,7 @@ where
246246
/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
247247
/// # use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
248248
///
249-
/// #[derive(Debug, Clone)]
249+
/// #[derive(Debug, Clone, PartialEq, Eq, Hash)]
250250
/// struct SmoothIt {
251251
/// signature: Signature,
252252
/// }
@@ -305,7 +305,7 @@ where
305305
/// .build()
306306
/// .unwrap();
307307
/// ```
308-
pub trait WindowUDFImpl: Debug + Send + Sync {
308+
pub trait WindowUDFImpl: Debug + DynEq + DynHash + Send + Sync {
309309
/// Returns this object as an [`Any`] trait object
310310
fn as_any(&self) -> &dyn Any;
311311

@@ -358,41 +358,6 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
358358
None
359359
}
360360

361-
/// Return true if this window UDF is equal to the other.
362-
///
363-
/// Allows customizing the equality of window UDFs.
364-
/// *Must* be implemented explicitly if the UDF type has internal state.
365-
/// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
366-
///
367-
/// - reflexive: `a.equals(a)`;
368-
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
369-
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
370-
///
371-
/// By default, compares type, [`Self::name`], [`Self::aliases`] and [`Self::signature`].
372-
fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
373-
self.as_any().type_id() == other.as_any().type_id()
374-
&& self.name() == other.name()
375-
&& self.aliases() == other.aliases()
376-
&& self.signature() == other.signature()
377-
}
378-
379-
/// Returns a hash value for this window UDF.
380-
///
381-
/// Allows customizing the hash code of window UDFs.
382-
/// *Must* be implemented explicitly whenever [`Self::equals`] is implemented.
383-
///
384-
/// Similarly to [`Hash`] and [`Eq`], if [`Self::equals`] returns true for two UDFs,
385-
/// their `hash_value`s must be the same.
386-
///
387-
/// By default, it only hashes the type. The other fields are not hashed, as usually the
388-
/// name, signature, and aliases are implied by the UDF type. Recall that UDFs with state
389-
/// (and thus possibly changing fields) must override [`Self::equals`] and [`Self::hash_value`].
390-
fn hash_value(&self) -> u64 {
391-
let hasher = &mut DefaultHasher::new();
392-
self.as_any().type_id().hash(hasher);
393-
hasher.finish()
394-
}
395-
396361
/// The [`FieldRef`] of the final result of evaluating this window function.
397362
///
398363
/// Call `field_args.name()` to get the fully qualified name for defining
@@ -461,7 +426,7 @@ pub enum ReversedUDWF {
461426

462427
impl PartialEq for dyn WindowUDFImpl {
463428
fn eq(&self, other: &Self) -> bool {
464-
self.equals(other)
429+
self.dyn_eq(other.as_any())
465430
}
466431
}
467432

@@ -533,8 +498,6 @@ impl WindowUDFImpl for AliasedWindowUDFImpl {
533498
self.inner.simplify()
534499
}
535500

536-
udf_equals_hash!(WindowUDFImpl);
537-
538501
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
539502
self.inner.field(field_args)
540503
}
@@ -598,7 +561,7 @@ mod test {
598561
use std::any::Any;
599562
use std::cmp::Ordering;
600563

601-
#[derive(Debug, Clone)]
564+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
602565
struct AWindowUDF {
603566
signature: Signature,
604567
}
@@ -637,7 +600,7 @@ mod test {
637600
}
638601
}
639602

640-
#[derive(Debug, Clone)]
603+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
641604
struct BWindowUDF {
642605
signature: Signature,
643606
}

datafusion/ffi/src/udwf/mod.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ use arrow::{
2525
datatypes::{DataType, SchemaRef},
2626
};
2727
use arrow_schema::{Field, FieldRef};
28-
use datafusion::logical_expr::udf_equals_hash;
2928
use datafusion::{
3029
error::DataFusionError,
3130
logical_expr::{
@@ -349,8 +348,6 @@ impl WindowUDFImpl for ForeignWindowUDF {
349348
let options: Option<&FFI_SortOptions> = self.udf.sort_options.as_ref().into();
350349
options.map(|s| s.into())
351350
}
352-
353-
udf_equals_hash!(WindowUDFImpl);
354351
}
355352

356353
#[repr(C)]

datafusion/functions-window/src/cume_dist.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ FROM employees;
6363
```
6464
"#
6565
)]
66-
#[derive(Debug)]
66+
#[derive(Debug, PartialEq, Eq, Hash)]
6767
pub struct CumeDist {
6868
signature: Signature,
6969
}

datafusion/functions-window/src/lead_lag.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ use datafusion_common::arrow::datatypes::Field;
2525
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
2626
use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
2727
use datafusion_expr::{
28-
udf_equals_hash, Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature,
29-
TypeSignature, Volatility, WindowUDFImpl,
28+
Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
29+
Volatility, WindowUDFImpl,
3030
};
3131
use datafusion_functions_window_common::expr::ExpressionArgs;
3232
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
@@ -299,8 +299,6 @@ impl WindowUDFImpl for WindowShift {
299299
WindowShiftKind::Lead => Some(get_lead_doc()),
300300
}
301301
}
302-
303-
udf_equals_hash!(WindowUDFImpl);
304302
}
305303

306304
/// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to

0 commit comments

Comments
 (0)