Skip to content

Commit 3f22001

Browse files
committed
remove duplicated logic between the DataFrame API and SQL planning
1 parent c8a3d58 commit 3f22001

8 files changed

Lines changed: 330 additions & 59 deletions

File tree

datafusion-cli/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/common/src/dfschema.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -630,9 +630,9 @@ impl ExprSchema for DFSchema {
630630
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
631631
pub struct DFField {
632632
/// Optional qualifier (usually a table or relation name)
633-
qualifier: Option<OwnedTableReference>,
633+
pub qualifier: Option<OwnedTableReference>,
634634
/// Arrow field definition
635-
field: Field,
635+
pub field: Field,
636636
}
637637

638638
impl DFField {

datafusion/core/tests/dataframe.rs

Lines changed: 132 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,36 +32,119 @@ use datafusion::error::Result;
3232
use datafusion::execution::context::SessionContext;
3333
use datafusion::prelude::JoinType;
3434
use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
35+
use datafusion::test_util::parquet_test_data;
3536
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
37+
use datafusion_common::ScalarValue;
3638
use datafusion_expr::expr::{GroupingSet, Sort};
37-
use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};
39+
use datafusion_expr::Expr::Wildcard;
40+
use datafusion_expr::{
41+
avg, col, count, expr, lit, max, sum, AggregateFunction, Expr, ExprSchemable,
42+
Subquery, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction,
43+
};
3844

3945
#[tokio::test]
40-
async fn count_wildcard() -> Result<()> {
46+
async fn test_count_wildcard_on_sort() -> Result<()> {
4147
let ctx = SessionContext::new();
42-
let testdata = datafusion::test_util::parquet_test_data();
48+
register_alltypes_tiny_pages_parquet(&ctx).await?;
4349

44-
ctx.register_parquet(
45-
"alltypes_tiny_pages",
46-
&format!("{testdata}/alltypes_tiny_pages.parquet"),
47-
ParquetReadOptions::default(),
50+
let sql_results=ctx.sql(
51+
"select string_col,count(*) from alltypes_tiny_pages group by string_col order by count(*)",
4852
)
49-
.await?;
53+
.await?
54+
.explain(false, false)?
55+
.collect().await?;
56+
57+
let df_results = ctx
58+
.table("alltypes_tiny_pages")
59+
.await?
60+
.aggregate(vec![col("string_col")], vec![count(Wildcard)])?
61+
.sort(vec![count(Wildcard).sort(true, false)])?
62+
.explain(false, false)?
63+
.collect()
64+
.await?;
65+
// make sure the SQL plan is the same as the DataFrame plan
66+
assert_eq!(
67+
pretty_format_batches(&sql_results)?.to_string(),
68+
pretty_format_batches(&df_results)?.to_string()
69+
);
70+
Ok(())
71+
}
72+
73+
#[tokio::test]
74+
async fn test_count_wildcard_on_where_exist() -> Result<()> {
75+
let ctx = create_join_context()?;
76+
77+
let df_results = ctx
78+
.table("t1")
79+
.await?
80+
.filter(Expr::Exists {
81+
subquery: Subquery {
82+
subquery: Arc::new(
83+
ctx.table("t2")
84+
.await?
85+
.aggregate(vec![], vec![count(Expr::Wildcard)])?
86+
.select(vec![count(Expr::Wildcard)])?
87+
.into_optimized_plan()?,
88+
),
89+
outer_ref_columns: vec![],
90+
},
91+
negated: false,
92+
})?
93+
.select(vec![col("a"), col("b")])?
94+
.explain(false, false)?
95+
.collect()
96+
.await?;
97+
#[rustfmt::skip]
98+
let expected = vec![
99+
"+--------------+-------------------------------------------------------+",
100+
"| plan_type | plan |",
101+
"+--------------+-------------------------------------------------------+",
102+
"| logical_plan | Filter: EXISTS (<subquery>) |",
103+
"| | Subquery: |",
104+
"| | Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] |",
105+
"| | TableScan: t2 projection=[a] |",
106+
"| | TableScan: t1 projection=[a, b] |",
107+
"+--------------+-------------------------------------------------------+",
108+
];
109+
assert_batches_eq!(expected, &df_results);
110+
Ok(())
111+
}
112+
113+
#[tokio::test]
114+
async fn test_count_wildcard_on_window() -> Result<()> {
115+
let ctx = SessionContext::new();
116+
117+
register_alltypes_tiny_pages_parquet(&ctx).await?;
50118

51119
let sql_results = ctx
52-
.sql("select count(*) from alltypes_tiny_pages")
120+
.sql("select COUNT(*) OVER(ORDER BY timestamp_col DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from alltypes_tiny_pages")
53121
.await?
54-
.select(vec![count(Expr::Wildcard)])?
55122
.explain(false, false)?
56123
.collect()
57124
.await?;
58125

59-
// add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node.
60126
let df_results = ctx
61127
.table("alltypes_tiny_pages")
62128
.await?
63-
.aggregate(vec![], vec![count(Expr::Wildcard)])?
64-
.select(vec![count(Expr::Wildcard)])?
129+
.select(vec![Expr::WindowFunction(expr::WindowFunction::new(
130+
WindowFunction::AggregateFunction(AggregateFunction::Count),
131+
vec![Expr::Wildcard],
132+
vec![],
133+
vec![Expr::Sort(Sort::new(
134+
Box::new(col("timestamp_col")),
135+
false,
136+
true,
137+
))],
138+
WindowFrame {
139+
units: WindowFrameUnits::Range,
140+
start_bound: WindowFrameBound::Preceding(ScalarValue::IntervalDayTime(
141+
Some(6),
142+
)),
143+
end_bound: WindowFrameBound::Following(ScalarValue::IntervalDayTime(
144+
Some(2),
145+
)),
146+
},
147+
))])?
65148
.explain(false, false)?
66149
.collect()
67150
.await?;
@@ -72,21 +155,37 @@ async fn count_wildcard() -> Result<()> {
72155
pretty_format_batches(&df_results)?.to_string()
73156
);
74157

75-
let results = ctx
158+
Ok(())
159+
}
160+
161+
#[tokio::test]
162+
async fn test_count_wildcard_on_aggregate() -> Result<()> {
163+
let ctx = SessionContext::new();
164+
register_alltypes_tiny_pages_parquet(&ctx).await?;
165+
166+
let sql_results = ctx
167+
.sql("select count(*) from alltypes_tiny_pages")
168+
.await?
169+
.select(vec![count(Expr::Wildcard)])?
170+
.explain(false, false)?
171+
.collect()
172+
.await?;
173+
174+
// add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all nodes, not just the top node.
175+
let df_results = ctx
76176
.table("alltypes_tiny_pages")
77177
.await?
78178
.aggregate(vec![], vec![count(Expr::Wildcard)])?
179+
.select(vec![count(Expr::Wildcard)])?
180+
.explain(false, false)?
79181
.collect()
80182
.await?;
81183

82-
let expected = vec![
83-
"+-----------------+",
84-
"| COUNT(UInt8(1)) |",
85-
"+-----------------+",
86-
"| 7300 |",
87-
"+-----------------+",
88-
];
89-
assert_batches_sorted_eq!(expected, &results);
184+
// make sure the SQL plan is the same as the DataFrame plan
185+
assert_eq!(
186+
pretty_format_batches(&sql_results)?.to_string(),
187+
pretty_format_batches(&df_results)?.to_string()
188+
);
90189

91190
Ok(())
92191
}
@@ -1047,3 +1146,14 @@ async fn table_with_nested_types(n: usize) -> Result<DataFrame> {
10471146
ctx.register_batch("shapes", batch)?;
10481147
ctx.table("shapes").await
10491148
}
1149+
1150+
pub async fn register_alltypes_tiny_pages_parquet(ctx: &SessionContext) -> Result<()> {
1151+
let testdata = parquet_test_data();
1152+
ctx.register_parquet(
1153+
"alltypes_tiny_pages",
1154+
&format!("{testdata}/alltypes_tiny_pages.parquet"),
1155+
ParquetReadOptions::default(),
1156+
)
1157+
.await?;
1158+
Ok(())
1159+
}

datafusion/core/tests/sql/mod.rs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,15 +1161,6 @@ async fn try_execute_to_batches(
11611161
/// Execute query and return results as a Vec of RecordBatches
11621162
async fn execute_to_batches(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
11631163
let df = ctx.sql(sql).await.unwrap();
1164-
1165-
// We are not really interested in the direct output of optimized_logical_plan
1166-
// since the physical plan construction already optimizes the given logical plan
1167-
// and we want to avoid double-optimization as a consequence. So we just construct
1168-
// it here to make sure that it doesn't fail at this step and get the optimized
1169-
// schema (to assert later that the logical and optimized schemas are the same).
1170-
let optimized = df.clone().into_optimized_plan().unwrap();
1171-
assert_eq!(df.logical_plan().schema(), optimized.schema());
1172-
11731164
df.collect().await.unwrap()
11741165
}
11751166

0 commit comments

Comments
 (0)