Skip to content

Commit 2093f4c

Browse files
committed
Fix PartialOrd for ScalarUDF
1 parent b6d4d3b commit 2093f4c

1 file changed

Lines changed: 73 additions & 17 deletions

File tree

datafusion/expr/src/udf.rs

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,31 @@ impl PartialEq for ScalarUDF {
6767
}
6868
}
6969

70-
// TODO (https://github.com/apache/datafusion/issues/17064) PartialOrd is not consistent with PartialEq for `ScalarUDF` and it should be
71-
// Manual implementation based on `ScalarUDFImpl::equals`
7270
impl PartialOrd for ScalarUDF {
7371
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
74-
match self.name().partial_cmp(other.name()) {
75-
Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
76-
cmp => cmp,
72+
let mut cmp = self.name().cmp(other.name());
73+
if cmp == Ordering::Equal {
74+
cmp = self.signature().partial_cmp(other.signature())?;
7775
}
76+
if cmp == Ordering::Equal {
77+
cmp = self.aliases().partial_cmp(other.aliases())?;
78+
}
79+
// Contract for PartialOrd and PartialEq consistency requires that
80+
// a == b if and only if partial_cmp(a, b) == Some(Equal).
81+
if cmp == Ordering::Equal && self != other {
82+
// Functions may have other properties besides name, signature and aliases
83+
// that differentiate two instances (e.g. type, or arbitrary parameters).
84+
// We cannot return Some(Equal) in such a case.
85+
return None;
86+
}
87+
debug_assert!(
88+
cmp == Ordering::Equal || self != other,
89+
"Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
90+
The functions compare as equal, but they are not equal based on general properties that \
91+
the PartialOrd implementation observes,",
92+
self.name(), other.name()
93+
);
94+
Some(cmp)
7895
}
7996
}
8097

@@ -932,23 +949,26 @@ The following regular expression functions are supported:"#,
932949
#[cfg(test)]
933950
mod tests {
934951
use super::*;
952+
use datafusion_expr_common::signature::Volatility;
935953
use std::hash::DefaultHasher;
936954

937955
#[derive(Debug, PartialEq, Eq, Hash)]
938956
struct TestScalarUDFImpl {
957+
name: &'static str,
939958
field: &'static str,
959+
signature: Signature,
940960
}
941961
impl ScalarUDFImpl for TestScalarUDFImpl {
942962
fn as_any(&self) -> &dyn Any {
943963
self
944964
}
945965

946966
fn name(&self) -> &str {
947-
"TestScalarUDFImpl"
967+
self.name
948968
}
949969

950970
fn signature(&self) -> &Signature {
951-
unimplemented!()
971+
&self.signature
952972
}
953973

954974
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
@@ -960,17 +980,53 @@ mod tests {
960980
}
961981
}
962982

983+
// PartialEq and Hash must be consistent, and also PartialEq and PartialOrd
984+
// must be consistent, so they are tested together.
963985
#[test]
964-
fn test_partial_eq() {
965-
let a1 = ScalarUDF::from(TestScalarUDFImpl { field: "a" });
966-
let a2 = ScalarUDF::from(TestScalarUDFImpl { field: "a" });
967-
let b = ScalarUDF::from(TestScalarUDFImpl { field: "b" });
968-
let eq = a1 == a2;
969-
assert!(eq);
970-
assert_eq!(a1, a2);
971-
assert_eq!(hash(&a1), hash(&a2));
972-
assert_ne!(a1, b);
973-
assert_ne!(a2, b);
986+
fn test_partial_eq_hash_and_partial_ord() {
987+
// A parameterized function
988+
let f = test_func("foo", "a");
989+
990+
// Same as `f`, but a different instance
991+
let f2 = test_func("foo", "a");
992+
assert!({ f == f2 });
993+
assert_eq!(f, f2);
994+
assert_eq!(hash(&f), hash(&f2));
995+
assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal));
996+
assert!(!(f < f2));
997+
assert!(!(f > f2));
998+
assert!(!(f2 < f));
999+
assert!(!(f2 > f));
1000+
1001+
// Different parameter
1002+
let b = test_func("foo", "b");
1003+
assert!(!{ f == b });
1004+
assert_ne!(f, b);
1005+
assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test
1006+
assert_eq!(f.partial_cmp(&b), None);
1007+
assert!(!(f < b));
1008+
assert!(!(f > b));
1009+
assert!(!(b < f));
1010+
assert!(!(b > f));
1011+
1012+
// Different name
1013+
let o = test_func("other", "b");
1014+
assert!(!{ f == o });
1015+
assert_ne!(f, o);
1016+
assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test
1017+
assert_eq!(f.partial_cmp(&o), Some(Ordering::Less));
1018+
assert!(f < o);
1019+
assert!(!(f > o));
1020+
assert!(!(o < f));
1021+
assert!(o > f);
1022+
}
1023+
1024+
fn test_func(name: &'static str, parameter: &'static str) -> ScalarUDF {
1025+
ScalarUDF::from(TestScalarUDFImpl {
1026+
name,
1027+
field: parameter,
1028+
signature: Signature::any(1, Volatility::Immutable),
1029+
})
9741030
}
9751031

9761032
fn hash<T: Hash>(value: &T) -> u64 {

0 commit comments

Comments
 (0)