diff --git a/datafusion/tests/test_aggregation.py b/datafusion/tests/test_aggregation.py index b274e18cf..2c8c064b1 100644 --- a/datafusion/tests/test_aggregation.py +++ b/datafusion/tests/test_aggregation.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. +import numpy as np import pyarrow as pa import pytest -from datafusion import SessionContext, column +from datafusion import SessionContext, column, lit from datafusion import functions as f @@ -28,8 +29,12 @@ def df(): # create a RecordBatch and a new DataFrame from it batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2, 3]), pa.array([4, 4, 6])], - names=["a", "b"], + [ + pa.array([1, 2, 3]), + pa.array([4, 4, 6]), + pa.array([9, 8, 5]), + ], + names=["a", "b", "c"], ) return ctx.create_dataframe([[batch]]) @@ -37,12 +42,86 @@ def df(): def test_built_in_aggregation(df): col_a = column("a") col_b = column("b") - df = df.aggregate( + col_c = column("c") + + agg_df = df.aggregate( [], - [f.max(col_a), f.min(col_a), f.count(col_a), f.approx_distinct(col_b)], + [ + f.approx_distinct(col_b), + f.approx_median(col_b), + f.approx_percentile_cont(col_b, lit(0.5)), + f.approx_percentile_cont_with_weight(col_b, lit(0.6), lit(0.5)), + f.array_agg(col_b), + f.avg(col_a), + f.corr(col_a, col_b), + f.count(col_a), + f.covar(col_a, col_b), + f.covar_pop(col_a, col_c), + f.covar_samp(col_b, col_c), + # f.grouping(col_a), # No physical plan implemented yet + f.max(col_a), + f.mean(col_b), + f.median(col_b), + f.min(col_a), + f.sum(col_b), + f.stddev(col_a), + f.stddev_pop(col_b), + f.stddev_samp(col_c), + f.var(col_a), + f.var_pop(col_b), + f.var_samp(col_c), + ], + ) + result = agg_df.collect()[0] + values_a, values_b, values_c = df.collect()[0] + + assert result.column(0) == pa.array([2], type=pa.uint64()) + assert result.column(1) == pa.array([4]) + assert result.column(2) == pa.array([4]) + assert result.column(3) == pa.array([6]) + assert result.column(4) == pa.array([[4, 4, 6]]) + np.testing.assert_array_almost_equal( + result.column(5), np.average(values_a) + ) + np.testing.assert_array_almost_equal( + result.column(6), np.corrcoef(values_a, values_b)[0][1] + ) + assert result.column(7) == pa.array([len(values_a)]) + # Sample (co)variance -> ddof=1 + # Population (co)variance -> ddof=0 + np.testing.assert_array_almost_equal( + result.column(8), np.cov(values_a, values_b, ddof=1)[0][1] + ) + np.testing.assert_array_almost_equal( + result.column(9), np.cov(values_a, values_c, ddof=0)[0][1] + ) + np.testing.assert_array_almost_equal( + result.column(10), np.cov(values_b, values_c, ddof=1)[0][1] + ) + np.testing.assert_array_almost_equal(result.column(11), np.max(values_a)) + np.testing.assert_array_almost_equal(result.column(12), np.mean(values_b)) + np.testing.assert_array_almost_equal( + result.column(13), np.median(values_b) + ) + np.testing.assert_array_almost_equal(result.column(14), np.min(values_a)) + np.testing.assert_array_almost_equal( + result.column(15), np.sum(values_b.to_pylist()) + ) + np.testing.assert_array_almost_equal( + result.column(16), np.std(values_a, ddof=1) + ) + np.testing.assert_array_almost_equal( + result.column(17), np.std(values_b, ddof=0) + ) + np.testing.assert_array_almost_equal( + result.column(18), np.std(values_c, ddof=1) + ) + np.testing.assert_array_almost_equal( + result.column(19), np.var(values_a, ddof=1) + ) + np.testing.assert_array_almost_equal( + result.column(20), np.var(values_b, ddof=0) + ) + np.testing.assert_array_almost_equal( + result.column(21), np.var(values_c, ddof=1) ) - result = df.collect()[0] - assert result.column(0) == pa.array([3]) - assert result.column(1) == pa.array([1]) - assert result.column(2) == pa.array([3], type=pa.int64()) - assert result.column(3) == pa.array([2], type=pa.uint64()) diff --git a/src/functions.rs b/src/functions.rs index ac1077ea5..8847dabdc 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -287,25 +287,49 @@ scalar_function!(upper, Upper, "Converts the string to all upper case."); scalar_function!(make_array, MakeArray); scalar_function!(array, MakeArray); scalar_function!(nullif, NullIf); -//scalar_function!(uuid, Uuid); -//scalar_function!(struct, Struct); +scalar_function!(uuid, Uuid); +scalar_function!(r#struct, Struct); // Use raw identifier since struct is a keyword scalar_function!(from_unixtime, FromUnixtime); scalar_function!(arrow_typeof, ArrowTypeof); scalar_function!(random, Random); +aggregate_function!(approx_distinct, ApproxDistinct); +aggregate_function!(approx_median, ApproxMedian); +aggregate_function!(approx_percentile_cont, ApproxPercentileCont); +aggregate_function!( + approx_percentile_cont_with_weight, + ApproxPercentileContWithWeight +); +aggregate_function!(array_agg, ArrayAgg); aggregate_function!(avg, Avg); +aggregate_function!(corr, Correlation); aggregate_function!(count, Count); +aggregate_function!(covar, Covariance); +aggregate_function!(covar_pop, CovariancePop); +aggregate_function!(covar_samp, Covariance); +aggregate_function!(grouping, Grouping); aggregate_function!(max, Max); +aggregate_function!(mean, Avg); +aggregate_function!(median, Median); aggregate_function!(min, Min); aggregate_function!(sum, Sum); -aggregate_function!(approx_distinct, ApproxDistinct); +aggregate_function!(stddev, Stddev); +aggregate_function!(stddev_pop, StddevPop); +aggregate_function!(stddev_samp, Stddev); +aggregate_function!(var, Variance); +aggregate_function!(var_pop, VariancePop); +aggregate_function!(var_samp, Variance); pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(abs))?; m.add_wrapped(wrap_pyfunction!(acos))?; m.add_wrapped(wrap_pyfunction!(approx_distinct))?; m.add_wrapped(wrap_pyfunction!(alias))?; + m.add_wrapped(wrap_pyfunction!(approx_median))?; + m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?; + m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?; m.add_wrapped(wrap_pyfunction!(array))?; + m.add_wrapped(wrap_pyfunction!(array_agg))?; m.add_wrapped(wrap_pyfunction!(arrow_typeof))?; m.add_wrapped(wrap_pyfunction!(ascii))?; m.add_wrapped(wrap_pyfunction!(asin))?; @@ -322,9 +346,13 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(col))?; m.add_wrapped(wrap_pyfunction!(concat_ws))?; m.add_wrapped(wrap_pyfunction!(concat))?; + m.add_wrapped(wrap_pyfunction!(corr))?; m.add_wrapped(wrap_pyfunction!(cos))?; m.add_wrapped(wrap_pyfunction!(count))?; m.add_wrapped(wrap_pyfunction!(count_star))?; + m.add_wrapped(wrap_pyfunction!(covar))?; + m.add_wrapped(wrap_pyfunction!(covar_pop))?; + m.add_wrapped(wrap_pyfunction!(covar_samp))?; m.add_wrapped(wrap_pyfunction!(current_date))?; m.add_wrapped(wrap_pyfunction!(current_time))?; m.add_wrapped(wrap_pyfunction!(date_bin))?; @@ -336,6 +364,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(exp))?; m.add_wrapped(wrap_pyfunction!(floor))?; m.add_wrapped(wrap_pyfunction!(from_unixtime))?; + m.add_wrapped(wrap_pyfunction!(grouping))?; m.add_wrapped(wrap_pyfunction!(in_list))?; m.add_wrapped(wrap_pyfunction!(initcap))?; m.add_wrapped(wrap_pyfunction!(left))?; @@ -350,6 +379,8 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(max))?; m.add_wrapped(wrap_pyfunction!(make_array))?; m.add_wrapped(wrap_pyfunction!(md5))?; + m.add_wrapped(wrap_pyfunction!(mean))?; + m.add_wrapped(wrap_pyfunction!(median))?; m.add_wrapped(wrap_pyfunction!(min))?; m.add_wrapped(wrap_pyfunction!(now))?; m.add_wrapped(wrap_pyfunction!(nullif))?; @@ -376,8 +407,11 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(split_part))?; m.add_wrapped(wrap_pyfunction!(sqrt))?; m.add_wrapped(wrap_pyfunction!(starts_with))?; + m.add_wrapped(wrap_pyfunction!(stddev))?; + m.add_wrapped(wrap_pyfunction!(stddev_pop))?; + m.add_wrapped(wrap_pyfunction!(stddev_samp))?; m.add_wrapped(wrap_pyfunction!(strpos))?; - //m.add_wrapped(wrap_pyfunction!(struct))?; + m.add_wrapped(wrap_pyfunction!(r#struct))?; // Use raw identifier since struct is a keyword m.add_wrapped(wrap_pyfunction!(substr))?; m.add_wrapped(wrap_pyfunction!(sum))?; m.add_wrapped(wrap_pyfunction!(tan))?; @@ -390,7 +424,10 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(trim))?; m.add_wrapped(wrap_pyfunction!(trunc))?; m.add_wrapped(wrap_pyfunction!(upper))?; - //m.add_wrapped(wrap_pyfunction!(uuid))?; + m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision + m.add_wrapped(wrap_pyfunction!(var))?; + m.add_wrapped(wrap_pyfunction!(var_pop))?; + m.add_wrapped(wrap_pyfunction!(var_samp))?; m.add_wrapped(wrap_pyfunction!(window))?; Ok(()) }