Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 89 additions & 10 deletions datafusion/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
# specific language governing permissions and limitations
# under the License.

import numpy as np
import pyarrow as pa
import pytest

from datafusion import SessionContext, column
from datafusion import SessionContext, column, lit
from datafusion import functions as f


Expand All @@ -28,21 +29,99 @@ def df():

# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
names=["a", "b"],
[
pa.array([1, 2, 3]),
pa.array([4, 4, 6]),
pa.array([9, 8, 5]),
],
names=["a", "b", "c"],
)
return ctx.create_dataframe([[batch]])


def test_built_in_aggregation(df):
col_a = column("a")
col_b = column("b")
df = df.aggregate(
col_c = column("c")

agg_df = df.aggregate(
[],
[f.max(col_a), f.min(col_a), f.count(col_a), f.approx_distinct(col_b)],
[
f.approx_distinct(col_b),
f.approx_median(col_b),
f.approx_percentile_cont(col_b, lit(0.5)),
f.approx_percentile_cont_with_weight(col_b, lit(0.6), lit(0.5)),
f.array_agg(col_b),
f.avg(col_a),
f.corr(col_a, col_b),
f.count(col_a),
f.covar(col_a, col_b),
f.covar_pop(col_a, col_c),
f.covar_samp(col_b, col_c),
# f.grouping(col_a), # No physical plan implemented yet
f.max(col_a),
f.mean(col_b),
f.median(col_b),
f.min(col_a),
f.sum(col_b),
f.stddev(col_a),
f.stddev_pop(col_b),
f.stddev_samp(col_c),
f.var(col_a),
f.var_pop(col_b),
f.var_samp(col_c),
],
)
result = agg_df.collect()[0]
values_a, values_b, values_c = df.collect()[0]

assert result.column(0) == pa.array([2], type=pa.uint64())
assert result.column(1) == pa.array([4])
assert result.column(2) == pa.array([4])
assert result.column(3) == pa.array([6])
assert result.column(4) == pa.array([[4, 4, 6]])
np.testing.assert_array_almost_equal(
result.column(5), np.average(values_a)
)
np.testing.assert_array_almost_equal(
result.column(6), np.corrcoef(values_a, values_b)[0][1]
)
assert result.column(7) == pa.array([len(values_a)])
# Sample (co)variance -> ddof=1
# Population (co)variance -> ddof=0
np.testing.assert_array_almost_equal(
result.column(8), np.cov(values_a, values_b, ddof=1)[0][1]
)
np.testing.assert_array_almost_equal(
result.column(9), np.cov(values_a, values_c, ddof=0)[0][1]
)
np.testing.assert_array_almost_equal(
result.column(10), np.cov(values_b, values_c, ddof=1)[0][1]
)
np.testing.assert_array_almost_equal(result.column(11), np.max(values_a))
np.testing.assert_array_almost_equal(result.column(12), np.mean(values_b))
np.testing.assert_array_almost_equal(
result.column(13), np.median(values_b)
)
np.testing.assert_array_almost_equal(result.column(14), np.min(values_a))
np.testing.assert_array_almost_equal(
result.column(15), np.sum(values_b.to_pylist())
)
np.testing.assert_array_almost_equal(
result.column(16), np.std(values_a, ddof=1)
)
np.testing.assert_array_almost_equal(
result.column(17), np.std(values_b, ddof=0)
)
np.testing.assert_array_almost_equal(
result.column(18), np.std(values_c, ddof=1)
)
np.testing.assert_array_almost_equal(
result.column(19), np.var(values_a, ddof=1)
)
np.testing.assert_array_almost_equal(
result.column(20), np.var(values_b, ddof=0)
)
np.testing.assert_array_almost_equal(
result.column(21), np.var(values_c, ddof=1)
)
result = df.collect()[0]
assert result.column(0) == pa.array([3])
assert result.column(1) == pa.array([1])
assert result.column(2) == pa.array([3], type=pa.int64())
assert result.column(3) == pa.array([2], type=pa.uint64())
47 changes: 42 additions & 5 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,25 +287,49 @@ scalar_function!(upper, Upper, "Converts the string to all upper case.");
scalar_function!(make_array, MakeArray);
scalar_function!(array, MakeArray);
scalar_function!(nullif, NullIf);
//scalar_function!(uuid, Uuid);
//scalar_function!(struct, Struct);
scalar_function!(uuid, Uuid);
scalar_function!(r#struct, Struct); // Use raw identifier since struct is a keyword
scalar_function!(from_unixtime, FromUnixtime);
scalar_function!(arrow_typeof, ArrowTypeof);
scalar_function!(random, Random);

// Aggregate functions: one `aggregate_function!` invocation per exported name.
// The first argument is the Rust identifier (the same identifiers are handed to
// `wrap_pyfunction!` in `init_module`); the second names the aggregate variant
// the identifier is bound to.
// NOTE(review): the macro definition is not visible here — confirm its exact
// expansion before relying on anything beyond this name/variant pairing.
aggregate_function!(approx_distinct, ApproxDistinct);
aggregate_function!(approx_median, ApproxMedian);
aggregate_function!(approx_percentile_cont, ApproxPercentileCont);
aggregate_function!(
approx_percentile_cont_with_weight,
ApproxPercentileContWithWeight
);
aggregate_function!(array_agg, ArrayAgg);
aggregate_function!(avg, Avg);
aggregate_function!(corr, Correlation);
aggregate_function!(count, Count);
// `covar` and `covar_samp` are two names bound to the same `Covariance`
// variant; only `covar_pop` is bound to a different one (`CovariancePop`).
aggregate_function!(covar, Covariance);
aggregate_function!(covar_pop, CovariancePop);
aggregate_function!(covar_samp, Covariance);
// NOTE(review): the Python aggregation test leaves `grouping` commented out
// with the remark "No physical plan implemented yet" — verify it can actually
// be executed before advertising it.
aggregate_function!(grouping, Grouping);
aggregate_function!(max, Max);
// `mean` is an alias: it is bound to the same `Avg` variant as `avg`.
aggregate_function!(mean, Avg);
aggregate_function!(median, Median);
aggregate_function!(min, Min);
aggregate_function!(sum, Sum);
aggregate_function!(approx_distinct, ApproxDistinct);
// Standard-deviation and variance aggregates. The unsuffixed and `_samp`
// names share one variant each (`Stddev`, `Variance`); only the `_pop` names
// are bound to the separate population variants (`StddevPop`, `VariancePop`).
aggregate_function!(stddev, Stddev);
aggregate_function!(stddev_pop, StddevPop);
aggregate_function!(stddev_samp, Stddev);
aggregate_function!(var, Variance);
aggregate_function!(var_pop, VariancePop);
aggregate_function!(var_samp, Variance);

pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(abs))?;
m.add_wrapped(wrap_pyfunction!(acos))?;
m.add_wrapped(wrap_pyfunction!(approx_distinct))?;
m.add_wrapped(wrap_pyfunction!(alias))?;
m.add_wrapped(wrap_pyfunction!(approx_median))?;
m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?;
m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?;
m.add_wrapped(wrap_pyfunction!(array))?;
m.add_wrapped(wrap_pyfunction!(array_agg))?;
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
m.add_wrapped(wrap_pyfunction!(ascii))?;
m.add_wrapped(wrap_pyfunction!(asin))?;
Expand All @@ -322,9 +346,13 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(col))?;
m.add_wrapped(wrap_pyfunction!(concat_ws))?;
m.add_wrapped(wrap_pyfunction!(concat))?;
m.add_wrapped(wrap_pyfunction!(corr))?;
m.add_wrapped(wrap_pyfunction!(cos))?;
m.add_wrapped(wrap_pyfunction!(count))?;
m.add_wrapped(wrap_pyfunction!(count_star))?;
m.add_wrapped(wrap_pyfunction!(covar))?;
m.add_wrapped(wrap_pyfunction!(covar_pop))?;
m.add_wrapped(wrap_pyfunction!(covar_samp))?;
m.add_wrapped(wrap_pyfunction!(current_date))?;
m.add_wrapped(wrap_pyfunction!(current_time))?;
m.add_wrapped(wrap_pyfunction!(date_bin))?;
Expand All @@ -336,6 +364,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(exp))?;
m.add_wrapped(wrap_pyfunction!(floor))?;
m.add_wrapped(wrap_pyfunction!(from_unixtime))?;
m.add_wrapped(wrap_pyfunction!(grouping))?;
m.add_wrapped(wrap_pyfunction!(in_list))?;
m.add_wrapped(wrap_pyfunction!(initcap))?;
m.add_wrapped(wrap_pyfunction!(left))?;
Expand All @@ -350,6 +379,8 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(max))?;
m.add_wrapped(wrap_pyfunction!(make_array))?;
m.add_wrapped(wrap_pyfunction!(md5))?;
m.add_wrapped(wrap_pyfunction!(mean))?;
m.add_wrapped(wrap_pyfunction!(median))?;
m.add_wrapped(wrap_pyfunction!(min))?;
m.add_wrapped(wrap_pyfunction!(now))?;
m.add_wrapped(wrap_pyfunction!(nullif))?;
Expand All @@ -376,8 +407,11 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(split_part))?;
m.add_wrapped(wrap_pyfunction!(sqrt))?;
m.add_wrapped(wrap_pyfunction!(starts_with))?;
m.add_wrapped(wrap_pyfunction!(stddev))?;
m.add_wrapped(wrap_pyfunction!(stddev_pop))?;
m.add_wrapped(wrap_pyfunction!(stddev_samp))?;
m.add_wrapped(wrap_pyfunction!(strpos))?;
//m.add_wrapped(wrap_pyfunction!(struct))?;
m.add_wrapped(wrap_pyfunction!(r#struct))?; // Use raw identifier since struct is a keyword
m.add_wrapped(wrap_pyfunction!(substr))?;
m.add_wrapped(wrap_pyfunction!(sum))?;
m.add_wrapped(wrap_pyfunction!(tan))?;
Expand All @@ -390,7 +424,10 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(trim))?;
m.add_wrapped(wrap_pyfunction!(trunc))?;
m.add_wrapped(wrap_pyfunction!(upper))?;
//m.add_wrapped(wrap_pyfunction!(uuid))?;
m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision
m.add_wrapped(wrap_pyfunction!(var))?;
m.add_wrapped(wrap_pyfunction!(var_pop))?;
m.add_wrapped(wrap_pyfunction!(var_samp))?;
m.add_wrapped(wrap_pyfunction!(window))?;
Ok(())
}