Skip to content

Commit 46eaa9d

Browse files
committed
Add any_value aggregation calculation function and test cases
1 parent 5e3127f commit 46eaa9d

5 files changed

Lines changed: 27 additions & 3 deletions

File tree

examples/aggregate/aggregate.sql

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,6 @@ select b.name, c.goods_name, aggregate_unique(b.name) as names, aggregate_join(c
4949
join `data/users.json` b on a.uid=b.uid
5050
join `data/goodses.json` c on a.goods_id=c.goods_id where a.status=0 group by b.name, c.goods_name;
5151

52-
select uid, json_arrayagg(goods_id) as agoods_ids, json_objectagg(goods_id, order_id) as ogoods_ids from `data/orders.json` where status=0 group by uid;
52+
select uid, json_arrayagg(goods_id) as agoods_ids, json_objectagg(goods_id, order_id) as ogoods_ids from `data/orders.json` where status=0 group by uid;
53+
54+
select uid, any_value(goods_id) as goods_id from `data/orders.json` where status=0 group by uid;

syncanysql/calculaters/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"group_bit_xor": AggregateGroupBitXorCalculater,
3737
"json_arrayagg": AggregateJsonArrayaggCalculater,
3838
"json_objectagg": AggregateJsonObjectaggCalculater,
39+
"aggregate_any_value": AggregateAnyValueCalculater,
3940
"row_number": WindowAggregateRowNumberCalculater,
4041
"rank": WindowAggregateRankCalculater,
4142
"dense_rank": WindowAggregateDenseRankCalculater,

syncanysql/calculaters/aggregate_calculater.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,4 +439,21 @@ def final_value(self, state_value):
439439
return json.dumps(state_value, ensure_ascii=False, default=self.format_value)
440440

441441
def get_final_filter(self):
442-
return StringFilter.default()
442+
return StringFilter.default()
443+
444+
445+
class AggregateAnyValueCalculater(StateAggregateCalculater):
446+
def aggregate(self, state_value, data_value):
447+
if state_value is None:
448+
return {"value": data_value}
449+
return state_value
450+
451+
def reduce(self, state_value, data_value):
452+
if state_value is None:
453+
return data_value
454+
return state_value
455+
456+
def final_value(self, state_value):
457+
if state_value is None:
458+
return None
459+
return state_value["value"]

syncanysql/compiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1385,6 +1385,8 @@ def compile_aggregate_calculate(self, expression):
13851385
return "@group_concat::aggregate", "@group_concat::reduce", "@group_concat::final_value"
13861386
elif isinstance(expression, sqlglot_expressions.GroupUniqArray):
13871387
return "@group_uniq_array::aggregate", "@group_uniq_array::reduce", "@group_uniq_array::final_value"
1388+
elif isinstance(expression, sqlglot_expressions.AnyValue):
1389+
return "@aggregate_any_value::aggregate", "@aggregate_any_value::reduce", "@aggregate_any_value::final_value"
13881390
elif isinstance(expression, sqlglot_expressions.Anonymous):
13891391
aggregate_funcs = {"grouparray": "group_array", "groupuniqarray": "group_uniq_array", "groupbitand": "group_bit_and",
13901392
"groupbitor": "group_bit_or", "groupbitxor": "group_bit_xor"}
@@ -2981,7 +2983,7 @@ def is_aggregate(self, expression, config, arguments):
29812983
return False
29822984
if isinstance(expression, (sqlglot_expressions.Count, sqlglot_expressions.Sum, sqlglot_expressions.Max,
29832985
sqlglot_expressions.Min, sqlglot_expressions.Avg, sqlglot_expressions.GroupConcat,
2984-
sqlglot_expressions.GroupUniqArray)):
2986+
sqlglot_expressions.GroupUniqArray, sqlglot_expressions.AnyValue)):
29852987
return True
29862988
if isinstance(expression, sqlglot_expressions.Anonymous):
29872989
aggregate_funcs = {"group_array", "grouparray", "group_uniq_array", "groupuniqarray", "group_bit_and", "groupbitand",

tests/test_example_aggregate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ def test_aggregate(self):
7676
self.assert_data(52, [{'uid': 2, 'agoods_ids': '[1, 2, 1, 2]', 'ogoods_ids': '{"1": 5, "2": 6}'},
7777
{'uid': 1, 'agoods_ids': '[1, 1]', 'ogoods_ids': '{"1": 4}'}], "data error")
7878

79+
self.assert_data(54, [{'uid': 2, 'goods_id': 1}, {'uid': 1, 'goods_id': 1}], "data error")
80+
7981
def test_aggregate_batch(self):
8082
self.execute("aggregate_batch.sql")
8183

0 commit comments

Comments
 (0)