Skip to content

Commit 68c042d

Browse files
fix: default UDWFImpl::expressions returns all expressions (#13169)
* fix: default UDWFImpl::expressions returns all expressions
* Add unit test to check for window function inputs
* Add unit test to catch errors in udwf with multiple column arguments
* Remove unnecessary qualification from user defined window test
* cargo fmt

---------

Co-authored-by: Tim Saucer <timsaucer@gmail.com>
1 parent cf76421 commit 68c042d

2 files changed

Lines changed: 127 additions & 7 deletions

File tree

datafusion/core/tests/user_defined/user_defined_window_functions.rs

Lines changed: 126 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,20 @@ use std::{
2929

3030
use arrow::array::AsArray;
3131
use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray};
32-
use arrow_schema::{DataType, Field};
32+
use arrow_schema::{DataType, Field, Schema};
3333
use datafusion::{assert_batches_eq, prelude::SessionContext};
3434
use datafusion_common::{Result, ScalarValue};
3535
use datafusion_expr::{
36-
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
36+
PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl,
3737
};
38-
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
3938
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
39+
use datafusion_functions_window_common::{
40+
expr::ExpressionArgs, field::WindowUDFFieldArgs,
41+
};
42+
use datafusion_physical_expr::{
43+
expressions::{col, lit},
44+
PhysicalExpr,
45+
};
4046

4147
/// A query with a window function evaluated over the entire partition
4248
const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \
@@ -641,3 +647,120 @@ fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> ArrayRef {
641647
let array: Int64Array = std::iter::repeat(odd_count(arr)).take(num_rows).collect();
642648
Arc::new(array)
643649
}
650+
651+
#[derive(Debug)]
652+
struct VariadicWindowUDF {
653+
signature: Signature,
654+
}
655+
656+
impl VariadicWindowUDF {
657+
fn new() -> Self {
658+
Self {
659+
signature: Signature::one_of(
660+
vec![
661+
TypeSignature::Any(0),
662+
TypeSignature::Any(1),
663+
TypeSignature::Any(2),
664+
TypeSignature::Any(3),
665+
],
666+
Volatility::Immutable,
667+
),
668+
}
669+
}
670+
}
671+
672+
impl WindowUDFImpl for VariadicWindowUDF {
673+
fn as_any(&self) -> &dyn Any {
674+
self
675+
}
676+
677+
fn name(&self) -> &str {
678+
"variadic_window_udf"
679+
}
680+
681+
fn signature(&self) -> &Signature {
682+
&self.signature
683+
}
684+
685+
fn partition_evaluator(
686+
&self,
687+
_: PartitionEvaluatorArgs,
688+
) -> Result<Box<dyn PartitionEvaluator>> {
689+
unimplemented!("unnecessary for testing");
690+
}
691+
692+
fn field(&self, _: WindowUDFFieldArgs) -> Result<Field> {
693+
unimplemented!("unnecessary for testing");
694+
}
695+
}
696+
697+
#[test]
// Regression test: the default implementation of
// `WindowUDFImpl::expressions` must hand back every input expression
// of the user-defined window function, unmodified and in order.
//
// See: https://github.com/apache/datafusion/pull/13169
fn test_default_expressions() -> Result<()> {
    let udwf = WindowUDF::from(VariadicWindowUDF::new());

    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Float32, false),
        Field::new("c", DataType::Boolean, false),
    ]);

    // Each entry is one argument list, mixing column references and literals.
    let test_cases = vec![
        // Zero arguments
        vec![],
        // Single argument
        vec![col("a", &schema)?],
        vec![lit(1)],
        // Two arguments
        vec![col("a", &schema)?, col("b", &schema)?],
        vec![col("a", &schema)?, lit(2)],
        vec![lit(false), col("a", &schema)?],
        // Three arguments
        vec![col("a", &schema)?, col("b", &schema)?, col("c", &schema)?],
        vec![col("a", &schema)?, col("b", &schema)?, lit(false)],
        vec![col("a", &schema)?, lit(0.5), col("c", &schema)?],
        vec![lit(3), col("b", &schema)?, col("c", &schema)?],
    ];

    for input_exprs in &test_cases {
        // Resolve the data type of every argument against the schema.
        let input_types: Vec<_> = input_exprs
            .iter()
            .map(|expr: &Arc<dyn PhysicalExpr>| expr.data_type(&schema).unwrap())
            .collect();

        let ret_exprs = udwf.expressions(ExpressionArgs::new(input_exprs, &input_types));

        // The number of returned expressions must match the number passed in.
        assert_eq!(
            input_exprs.len(),
            ret_exprs.len(),
            "\nInput expressions: {:?}\nReturned expressions: {:?}",
            input_exprs,
            ret_exprs
        );

        // Pairwise comparison through the `Debug` rendering of each expression.
        for (actual, expected) in ret_exprs.iter().zip(input_exprs.iter()) {
            let expected_repr = format!("{expected:?}");
            let actual_repr = format!("{actual:?}");
            assert_eq!(
                expected_repr, actual_repr,
                "\nInput expressions: {:?}\nReturned expressions: {:?}",
                input_exprs, ret_exprs
            );
        }
    }

    Ok(())
}

datafusion/expr/src/udwf.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,10 +312,7 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
312312

313313
/// Returns the expressions that are passed to the [`PartitionEvaluator`].
314314
fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
315-
expr_args
316-
.input_exprs()
317-
.first()
318-
.map_or(vec![], |expr| vec![Arc::clone(expr)])
315+
expr_args.input_exprs().into()
319316
}
320317

321318
/// Invoke the function, returning the [`PartitionEvaluator`] instance

0 commit comments

Comments
 (0)