@@ -29,14 +29,20 @@ use std::{
2929
3030use arrow:: array:: AsArray ;
3131use arrow_array:: { ArrayRef , Int64Array , RecordBatch , StringArray } ;
32- use arrow_schema:: { DataType , Field } ;
32+ use arrow_schema:: { DataType , Field , Schema } ;
3333use datafusion:: { assert_batches_eq, prelude:: SessionContext } ;
3434use datafusion_common:: { Result , ScalarValue } ;
3535use datafusion_expr:: {
36- PartitionEvaluator , Signature , Volatility , WindowUDF , WindowUDFImpl ,
36+ PartitionEvaluator , Signature , TypeSignature , Volatility , WindowUDF , WindowUDFImpl ,
3737} ;
38- use datafusion_functions_window_common:: field:: WindowUDFFieldArgs ;
3938use datafusion_functions_window_common:: partition:: PartitionEvaluatorArgs ;
39+ use datafusion_functions_window_common:: {
40+ expr:: ExpressionArgs , field:: WindowUDFFieldArgs ,
41+ } ;
42+ use datafusion_physical_expr:: {
43+ expressions:: { col, lit} ,
44+ PhysicalExpr ,
45+ } ;
4046
4147/// A query with a window function evaluated over the entire partition
4248const UNBOUNDED_WINDOW_QUERY : & str = "SELECT x, y, val, \
@@ -641,3 +647,120 @@ fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> ArrayRef {
641647 let array: Int64Array = std:: iter:: repeat ( odd_count ( arr) ) . take ( num_rows) . collect ( ) ;
642648 Arc :: new ( array)
643649}
650+
651+ #[ derive( Debug ) ]
652+ struct VariadicWindowUDF {
653+ signature : Signature ,
654+ }
655+
656+ impl VariadicWindowUDF {
657+ fn new ( ) -> Self {
658+ Self {
659+ signature : Signature :: one_of (
660+ vec ! [
661+ TypeSignature :: Any ( 0 ) ,
662+ TypeSignature :: Any ( 1 ) ,
663+ TypeSignature :: Any ( 2 ) ,
664+ TypeSignature :: Any ( 3 ) ,
665+ ] ,
666+ Volatility :: Immutable ,
667+ ) ,
668+ }
669+ }
670+ }
671+
672+ impl WindowUDFImpl for VariadicWindowUDF {
673+ fn as_any ( & self ) -> & dyn Any {
674+ self
675+ }
676+
677+ fn name ( & self ) -> & str {
678+ "variadic_window_udf"
679+ }
680+
681+ fn signature ( & self ) -> & Signature {
682+ & self . signature
683+ }
684+
685+ fn partition_evaluator (
686+ & self ,
687+ _: PartitionEvaluatorArgs ,
688+ ) -> Result < Box < dyn PartitionEvaluator > > {
689+ unimplemented ! ( "unnecessary for testing" ) ;
690+ }
691+
692+ fn field ( & self , _: WindowUDFFieldArgs ) -> Result < Field > {
693+ unimplemented ! ( "unnecessary for testing" ) ;
694+ }
695+ }
696+
/// Checks that the default implementation of `WindowUDFImpl::expressions`
/// returns all input expressions to the user-defined window function
/// unmodified.
///
/// `VariadicWindowUDF` does not override `expressions`, so the call below
/// exercises the trait's default behavior for zero to three arguments that
/// mix column references and literals.
///
/// See: https://github.com/apache/datafusion/pull/13169
#[test]
fn test_default_expressions() -> Result<()> {
    let udwf = WindowUDF::from(VariadicWindowUDF::new());

    let field_a = Field::new("a", DataType::Int32, false);
    let field_b = Field::new("b", DataType::Float32, false);
    let field_c = Field::new("c", DataType::Boolean, false);
    let schema = Schema::new(vec![field_a, field_b, field_c]);

    let test_cases = vec![
        //
        // Zero arguments
        //
        vec![],
        //
        // Single argument
        //
        vec![col("a", &schema)?],
        vec![lit(1)],
        //
        // Two arguments
        //
        vec![col("a", &schema)?, col("b", &schema)?],
        vec![col("a", &schema)?, lit(2)],
        vec![lit(false), col("a", &schema)?],
        //
        // Three arguments
        //
        vec![col("a", &schema)?, col("b", &schema)?, col("c", &schema)?],
        vec![col("a", &schema)?, col("b", &schema)?, lit(false)],
        vec![col("a", &schema)?, lit(0.5), col("c", &schema)?],
        vec![lit(3), col("b", &schema)?, col("c", &schema)?],
    ];

    for input_exprs in &test_cases {
        // Propagate a type-resolution failure as a test error instead of
        // panicking inside the closure; the test already returns `Result`.
        let input_types = input_exprs
            .iter()
            .map(|expr: &Arc<dyn PhysicalExpr>| expr.data_type(&schema))
            .collect::<Result<Vec<_>>>()?;
        let expr_args = ExpressionArgs::new(input_exprs, &input_types);

        let ret_exprs = udwf.expressions(expr_args);

        // Verify same number of input expressions are returned
        assert_eq!(
            input_exprs.len(),
            ret_exprs.len(),
            "\n Input expressions: {:?}\n Returned expressions: {:?}",
            input_exprs,
            ret_exprs
        );

        // Compare each returned expression with the original input
        // expression through its `Debug` rendering.
        for (expected, actual) in input_exprs.iter().zip(&ret_exprs) {
            assert_eq!(
                format!("{expected:?}"),
                format!("{actual:?}"),
                "\n Input expressions: {:?}\n Returned expressions: {:?}",
                input_exprs,
                ret_exprs
            );
        }
    }
    Ok(())
}
0 commit comments