@@ -67,14 +67,31 @@ impl PartialEq for ScalarUDF {
6767 }
6868}
6969
70- // TODO (https://github.com/apache/datafusion/issues/17064) PartialOrd is not consistent with PartialEq for `ScalarUDF` and it should be
71- // Manual implementation based on `ScalarUDFImpl::equals`
7270impl PartialOrd for ScalarUDF {
7371 fn partial_cmp ( & self , other : & Self ) -> Option < Ordering > {
74- match self . name ( ) . partial_cmp ( other. name ( ) ) {
75- Some ( Ordering :: Equal ) => self . signature ( ) . partial_cmp ( other . signature ( ) ) ,
76- cmp => cmp ,
72+ let mut cmp = self . name ( ) . cmp ( other. name ( ) ) ;
73+ if cmp == Ordering :: Equal {
74+ cmp = self . signature ( ) . partial_cmp ( other . signature ( ) ) ? ;
7775 }
76+ if cmp == Ordering :: Equal {
77+ cmp = self . aliases ( ) . partial_cmp ( other. aliases ( ) ) ?;
78+ }
79+ // Contract for PartialOrd and PartialEq consistency requires that
80+ // a == b if and only if partial_cmp(a, b) == Some(Equal).
81+ if cmp == Ordering :: Equal && self != other {
82+ // Functions may have other properties besides name and signature
83+ // that differentiate two instances (e.g. type, or arbitrary parameters).
84+ // We cannot return Some(Equal) in such case.
85+ return None ;
86+ }
87+ debug_assert ! (
88+ cmp == Ordering :: Equal || self != other,
89+ "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
90+ The functions compare as equal, but they are not equal based on general properties that \
91+ the PartialOrd implementation observes,",
92+ self . name( ) , other. name( )
93+ ) ;
94+ Some ( cmp)
7895 }
7996}
8097
@@ -932,23 +949,26 @@ The following regular expression functions are supported:"#,
932949#[ cfg( test) ]
933950mod tests {
934951 use super :: * ;
952+ use datafusion_expr_common:: signature:: Volatility ;
935953 use std:: hash:: DefaultHasher ;
936954
937955 #[ derive( Debug , PartialEq , Eq , Hash ) ]
938956 struct TestScalarUDFImpl {
957+ name : & ' static str ,
939958 field : & ' static str ,
959+ signature : Signature ,
940960 }
941961 impl ScalarUDFImpl for TestScalarUDFImpl {
942962 fn as_any ( & self ) -> & dyn Any {
943963 self
944964 }
945965
946966 fn name ( & self ) -> & str {
947- "TestScalarUDFImpl"
967+ self . name
948968 }
949969
950970 fn signature ( & self ) -> & Signature {
951- unimplemented ! ( )
971+ & self . signature
952972 }
953973
954974 fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
@@ -960,17 +980,53 @@ mod tests {
960980 }
961981 }
962982
983+ // PartialEq and Hash must be consistent, and also PartialEq and PartialOrd
984+ // must be consistent, so they are tested together.
963985 #[ test]
964- fn test_partial_eq ( ) {
965- let a1 = ScalarUDF :: from ( TestScalarUDFImpl { field : "a" } ) ;
966- let a2 = ScalarUDF :: from ( TestScalarUDFImpl { field : "a" } ) ;
967- let b = ScalarUDF :: from ( TestScalarUDFImpl { field : "b" } ) ;
968- let eq = a1 == a2;
969- assert ! ( eq) ;
970- assert_eq ! ( a1, a2) ;
971- assert_eq ! ( hash( & a1) , hash( & a2) ) ;
972- assert_ne ! ( a1, b) ;
973- assert_ne ! ( a2, b) ;
986+ fn test_partial_eq_hash_and_partial_ord ( ) {
987+ // A parameterized function
988+ let f = test_func ( "foo" , "a" ) ;
989+
990+ // Same like `f`, different instance
991+ let f2 = test_func ( "foo" , "a" ) ;
992+ assert ! ( { f == f2 } ) ;
993+ assert_eq ! ( f, f2) ;
994+ assert_eq ! ( hash( & f) , hash( & f2) ) ;
995+ assert_eq ! ( f. partial_cmp( & f2) , Some ( Ordering :: Equal ) ) ;
996+ assert ! ( !( f < f2) ) ;
997+ assert ! ( !( f > f2) ) ;
998+ assert ! ( !( f2 < f) ) ;
999+ assert ! ( !( f2 > f) ) ;
1000+
1001+ // Different parameter
1002+ let b = test_func ( "foo" , "b" ) ;
1003+ assert ! ( !{ f == b } ) ;
1004+ assert_ne ! ( f, b) ;
1005+ assert_ne ! ( hash( & f) , hash( & b) ) ; // hash can collide for different values but does not collide in this test
1006+ assert_eq ! ( f. partial_cmp( & b) , None ) ;
1007+ assert ! ( !( f < b) ) ;
1008+ assert ! ( !( f > b) ) ;
1009+ assert ! ( !( b < f) ) ;
1010+ assert ! ( !( b > f) ) ;
1011+
1012+ // Different name
1013+ let o = test_func ( "other" , "b" ) ;
1014+ assert ! ( !{ f == o } ) ;
1015+ assert_ne ! ( f, o) ;
1016+ assert_ne ! ( hash( & f) , hash( & o) ) ; // hash can collide for different values but does not collide in this test
1017+ assert_eq ! ( f. partial_cmp( & o) , Some ( Ordering :: Less ) ) ;
1018+ assert ! ( f < o) ;
1019+ assert ! ( !( f > o) ) ;
1020+ assert ! ( !( o < f) ) ;
1021+ assert ! ( o > f) ;
1022+ }
1023+
1024+ fn test_func ( name : & ' static str , parameter : & ' static str ) -> ScalarUDF {
1025+ ScalarUDF :: from ( TestScalarUDFImpl {
1026+ name,
1027+ field : parameter,
1028+ signature : Signature :: any ( 1 , Volatility :: Immutable ) ,
1029+ } )
9741030 }
9751031
9761032 fn hash < T : Hash > ( value : & T ) -> u64 {
0 commit comments