@@ -32,36 +32,119 @@ use datafusion::error::Result;
3232use datafusion:: execution:: context:: SessionContext ;
3333use datafusion:: prelude:: JoinType ;
3434use datafusion:: prelude:: { CsvReadOptions , ParquetReadOptions } ;
35+ use datafusion:: test_util:: parquet_test_data;
3536use datafusion:: { assert_batches_eq, assert_batches_sorted_eq} ;
37+ use datafusion_common:: ScalarValue ;
3638use datafusion_expr:: expr:: { GroupingSet , Sort } ;
37- use datafusion_expr:: { avg, col, count, lit, max, sum, Expr , ExprSchemable } ;
39+ use datafusion_expr:: Expr :: Wildcard ;
40+ use datafusion_expr:: {
41+ avg, col, count, expr, lit, max, sum, AggregateFunction , Expr , ExprSchemable ,
42+ Subquery , WindowFrame , WindowFrameBound , WindowFrameUnits , WindowFunction ,
43+ } ;
3844
3945#[ tokio:: test]
40- async fn count_wildcard ( ) -> Result < ( ) > {
46+ async fn test_count_wildcard_on_sort ( ) -> Result < ( ) > {
4147 let ctx = SessionContext :: new ( ) ;
42- let testdata = datafusion :: test_util :: parquet_test_data ( ) ;
48+ register_alltypes_tiny_pages_parquet ( & ctx ) . await ? ;
4349
44- ctx. register_parquet (
45- "alltypes_tiny_pages" ,
46- & format ! ( "{testdata}/alltypes_tiny_pages.parquet" ) ,
47- ParquetReadOptions :: default ( ) ,
50+ let sql_results=ctx. sql (
51+ "select string_col,count(*) from alltypes_tiny_pages group by string_col order by count(*)" ,
4852 )
49- . await ?;
53+ . await ?
54+ . explain ( false , false ) ?
55+ . collect ( ) . await ?;
56+
57+ let df_results = ctx
58+ . table ( "alltypes_tiny_pages" )
59+ . await ?
60+ . aggregate ( vec ! [ col( "string_col" ) ] , vec ! [ count( Wildcard ) ] ) ?
61+ . sort ( vec ! [ count( Wildcard ) . sort( true , false ) ] ) ?
62+ . explain ( false , false ) ?
63+ . collect ( )
64+ . await ?;
65+ //make sure sql plan same with df plan
66+ assert_eq ! (
67+ pretty_format_batches( & sql_results) ?. to_string( ) ,
68+ pretty_format_batches( & df_results) ?. to_string( )
69+ ) ;
70+ Ok ( ( ) )
71+ }
72+
73+ #[ tokio:: test]
74+ async fn test_count_wildcard_on_where_exist ( ) -> Result < ( ) > {
75+ let ctx = create_join_context ( ) ?;
76+
77+ let df_results = ctx
78+ . table ( "t1" )
79+ . await ?
80+ . filter ( Expr :: Exists {
81+ subquery : Subquery {
82+ subquery : Arc :: new (
83+ ctx. table ( "t2" )
84+ . await ?
85+ . aggregate ( vec ! [ ] , vec ! [ count( Expr :: Wildcard ) ] ) ?
86+ . select ( vec ! [ count( Expr :: Wildcard ) ] ) ?
87+ . into_optimized_plan ( ) ?,
88+ ) ,
89+ outer_ref_columns : vec ! [ ] ,
90+ } ,
91+ negated : false ,
92+ } ) ?
93+ . select ( vec ! [ col( "a" ) , col( "b" ) ] ) ?
94+ . explain ( false , false ) ?
95+ . collect ( )
96+ . await ?;
97+ #[ rustfmt:: skip]
98+ let expected = vec ! [
99+ "+--------------+-------------------------------------------------------+" ,
100+ "| plan_type | plan |" ,
101+ "+--------------+-------------------------------------------------------+" ,
102+ "| logical_plan | Filter: EXISTS (<subquery>) |" ,
103+ "| | Subquery: |" ,
104+ "| | Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] |" ,
105+ "| | TableScan: t2 projection=[a] |" ,
106+ "| | TableScan: t1 projection=[a, b] |" ,
107+ "+--------------+-------------------------------------------------------+" ,
108+ ] ;
109+ assert_batches_eq ! ( expected, & df_results) ;
110+ Ok ( ( ) )
111+ }
112+
113+ #[ tokio:: test]
114+ async fn test_count_wildcard_on_window ( ) -> Result < ( ) > {
115+ let ctx = SessionContext :: new ( ) ;
116+
117+ register_alltypes_tiny_pages_parquet ( & ctx) . await ?;
50118
51119 let sql_results = ctx
52- . sql ( "select count (*) from alltypes_tiny_pages" )
120+ . sql ( "select COUNT (*) OVER(ORDER BY timestamp_col DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from alltypes_tiny_pages" )
53121 . await ?
54- . select ( vec ! [ count( Expr :: Wildcard ) ] ) ?
55122 . explain ( false , false ) ?
56123 . collect ( )
57124 . await ?;
58125
59- // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node.
60126 let df_results = ctx
61127 . table ( "alltypes_tiny_pages" )
62128 . await ?
63- . aggregate ( vec ! [ ] , vec ! [ count( Expr :: Wildcard ) ] ) ?
64- . select ( vec ! [ count( Expr :: Wildcard ) ] ) ?
129+ . select ( vec ! [ Expr :: WindowFunction ( expr:: WindowFunction :: new(
130+ WindowFunction :: AggregateFunction ( AggregateFunction :: Count ) ,
131+ vec![ Expr :: Wildcard ] ,
132+ vec![ ] ,
133+ vec![ Expr :: Sort ( Sort :: new(
134+ Box :: new( col( "timestamp_col" ) ) ,
135+ false ,
136+ true ,
137+ ) ) ] ,
138+ WindowFrame {
139+ units: WindowFrameUnits :: Range ,
140+ start_bound: WindowFrameBound :: Preceding ( ScalarValue :: IntervalDayTime (
141+ Some ( 6 ) ,
142+ ) ) ,
143+ end_bound: WindowFrameBound :: Following ( ScalarValue :: IntervalDayTime (
144+ Some ( 2 ) ,
145+ ) ) ,
146+ } ,
147+ ) ) ] ) ?
65148 . explain ( false , false ) ?
66149 . collect ( )
67150 . await ?;
@@ -72,21 +155,37 @@ async fn count_wildcard() -> Result<()> {
72155 pretty_format_batches( & df_results) ?. to_string( )
73156 ) ;
74157
75- let results = ctx
158+ Ok ( ( ) )
159+ }
160+
161+ #[ tokio:: test]
162+ async fn test_count_wildcard_on_aggregate ( ) -> Result < ( ) > {
163+ let ctx = SessionContext :: new ( ) ;
164+ register_alltypes_tiny_pages_parquet ( & ctx) . await ?;
165+
166+ let sql_results = ctx
167+ . sql ( "select count(*) from alltypes_tiny_pages" )
168+ . await ?
169+ . select ( vec ! [ count( Expr :: Wildcard ) ] ) ?
170+ . explain ( false , false ) ?
171+ . collect ( )
172+ . await ?;
173+
174+ // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node.
175+ let df_results = ctx
76176 . table ( "alltypes_tiny_pages" )
77177 . await ?
78178 . aggregate ( vec ! [ ] , vec ! [ count( Expr :: Wildcard ) ] ) ?
179+ . select ( vec ! [ count( Expr :: Wildcard ) ] ) ?
180+ . explain ( false , false ) ?
79181 . collect ( )
80182 . await ?;
81183
82- let expected = vec ! [
83- "+-----------------+" ,
84- "| COUNT(UInt8(1)) |" ,
85- "+-----------------+" ,
86- "| 7300 |" ,
87- "+-----------------+" ,
88- ] ;
89- assert_batches_sorted_eq ! ( expected, & results) ;
184+ //make sure sql plan same with df plan
185+ assert_eq ! (
186+ pretty_format_batches( & sql_results) ?. to_string( ) ,
187+ pretty_format_batches( & df_results) ?. to_string( )
188+ ) ;
90189
91190 Ok ( ( ) )
92191}
@@ -1047,3 +1146,14 @@ async fn table_with_nested_types(n: usize) -> Result<DataFrame> {
10471146 ctx. register_batch ( "shapes" , batch) ?;
10481147 ctx. table ( "shapes" ) . await
10491148}
1149+
1150+ pub async fn register_alltypes_tiny_pages_parquet ( ctx : & SessionContext ) -> Result < ( ) > {
1151+ let testdata = parquet_test_data ( ) ;
1152+ ctx. register_parquet (
1153+ "alltypes_tiny_pages" ,
1154+ & format ! ( "{testdata}/alltypes_tiny_pages.parquet" ) ,
1155+ ParquetReadOptions :: default ( ) ,
1156+ )
1157+ . await ?;
1158+ Ok ( ( ) )
1159+ }
0 commit comments