Skip to content

Commit b1d79ba

Browse files
committed
put all type coercion in coerce_arguments_for_signature
1 parent 76cede4 commit b1d79ba

3 files changed

Lines changed: 51 additions & 56 deletions

File tree

datafusion/common/src/utils.rs

Lines changed: 39 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -464,62 +464,53 @@ pub fn coerced_type_with_base_type_only(
464464
base_type: &DataType,
465465
) -> DataType {
466466
match data_type {
467-
DataType::List(field)
468-
| DataType::FixedSizeList(field, _)
469-
| DataType::LargeList(field) => {
470-
let field_type = match field.data_type() {
471-
// nested type could be different list type
472-
DataType::List(_)
473-
| DataType::FixedSizeList(_, _)
474-
| DataType::LargeList(_) => {
475-
coerced_type_with_base_type_only(field.data_type(), base_type)
476-
}
477-
_ => base_type.to_owned(),
478-
};
479-
if matches!(data_type, DataType::LargeList(_)) {
480-
DataType::LargeList(Arc::new(Field::new(
481-
field.name(),
482-
field_type,
483-
field.is_nullable(),
484-
)))
485-
} else {
486-
DataType::List(Arc::new(Field::new(
487-
field.name(),
488-
field_type,
489-
field.is_nullable(),
490-
)))
491-
}
467+
DataType::List(field) | DataType::FixedSizeList(field, _) => {
468+
let field_type =
469+
coerced_type_with_base_type_only(field.data_type(), base_type);
470+
471+
DataType::List(Arc::new(Field::new(
472+
field.name(),
473+
field_type,
474+
field.is_nullable(),
475+
)))
476+
}
477+
DataType::LargeList(field) => {
478+
let field_type =
479+
coerced_type_with_base_type_only(field.data_type(), base_type);
480+
481+
DataType::LargeList(Arc::new(Field::new(
482+
field.name(),
483+
field_type,
484+
field.is_nullable(),
485+
)))
492486
}
487+
493488
_ => base_type.clone(),
494489
}
495490
}
496491

497492
pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType {
498493
match data_type {
499-
DataType::FixedSizeList(field, _) => {
500-
let field_type = match field.data_type() {
501-
DataType::List(_)
502-
| DataType::FixedSizeList(_, _)
503-
| DataType::LargeList(_) => {
504-
coerced_fixed_size_list_to_list(field.data_type())
505-
}
506-
_ => field.data_type().to_owned(),
507-
};
508-
if matches!(data_type, DataType::LargeList(_)) {
509-
DataType::LargeList(Arc::new(Field::new(
510-
field.name(),
511-
field_type,
512-
field.is_nullable(),
513-
)))
514-
} else {
515-
DataType::List(Arc::new(Field::new(
516-
field.name(),
517-
field_type,
518-
field.is_nullable(),
519-
)))
520-
}
494+
DataType::List(field) | DataType::FixedSizeList(field, _) => {
495+
let field_type = coerced_fixed_size_list_to_list(field.data_type());
496+
497+
DataType::List(Arc::new(Field::new(
498+
field.name(),
499+
field_type,
500+
field.is_nullable(),
501+
)))
521502
}
522-
_ => data_type.to_owned(),
503+
DataType::LargeList(field) => {
504+
let field_type = coerced_fixed_size_list_to_list(field.data_type());
505+
506+
DataType::LargeList(Arc::new(Field::new(
507+
field.name(),
508+
field_type,
509+
field.is_nullable(),
510+
)))
511+
}
512+
513+
_ => data_type.clone(),
523514
}
524515
}
525516

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use arrow::{
2121
compute::can_cast_types,
2222
datatypes::{DataType, TimeUnit},
2323
};
24-
use datafusion_common::utils::list_ndims;
24+
use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims};
2525
use datafusion_common::{
2626
internal_datafusion_err, internal_err, plan_err, DataFusionError, Result,
2727
};
@@ -141,7 +141,8 @@ fn get_valid_types(
141141
DataType::List(_)
142142
| DataType::LargeList(_)
143143
| DataType::FixedSizeList(_, _) => {
144-
Ok(vec![vec![array_type.clone(), DataType::Int64]])
144+
let array_type = coerced_fixed_size_list_to_list(array_type);
145+
Ok(vec![vec![array_type, DataType::Int64]])
145146
}
146147
_ => Ok(vec![vec![]]),
147148
}

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ use arrow::datatypes::{DataType, IntervalUnit};
2323

2424
use datafusion_common::config::ConfigOptions;
2525
use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
26-
use datafusion_common::utils::coerced_fixed_size_list_to_list;
2726
use datafusion_common::{
2827
exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef,
2928
DataFusionError, Result, ScalarValue,
@@ -590,17 +589,21 @@ fn coerce_arguments_for_fun(
590589
if expressions.is_empty() {
591590
return Ok(vec![]);
592591
}
593-
594592
let mut expressions: Vec<Expr> = expressions.to_vec();
595593

596-
// coerce the fixed size list to list for all array fucntions
597-
if fun.name().contains("array") {
594+
// Cast Fixedsizelist to List for array functions
595+
if *fun == BuiltinScalarFunction::MakeArray {
598596
expressions = expressions
599597
.into_iter()
600598
.map(|expr| {
601599
let data_type = expr.get_type(schema).unwrap();
602-
let to_type = coerced_fixed_size_list_to_list(&data_type);
603-
expr.cast_to(&to_type, schema)
600+
if let DataType::FixedSizeList(field, _) = data_type {
601+
let field = field.as_ref().clone();
602+
let to_type = DataType::List(Arc::new(field));
603+
expr.cast_to(&to_type, schema)
604+
} else {
605+
Ok(expr)
606+
}
604607
})
605608
.collect::<Result<Vec<_>>>()?;
606609
}

0 commit comments

Comments
 (0)