Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Type Coercion for UDF Arguments #14268

Closed
wants to merge 37 commits into from
Closed
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
bcc0620
Fix DF 43 regression coerce ascii input to string
shehabgamin Jan 22, 2025
a8539c1
datafusion-testing submodule has new commits
shehabgamin Jan 22, 2025
4e4cb02
Merge branch 'main' of github.com:lakehq/datafusion into sail-df-43-r…
shehabgamin Jan 23, 2025
eb61d49
implicit cast int to string
shehabgamin Jan 24, 2025
a6f62a0
Merge branch 'main' of github.com:lakehq/datafusion into sail-df-43-r…
shehabgamin Jan 24, 2025
5afbfc0
fix can_coerce_to and add tests
shehabgamin Jan 24, 2025
36a23b7
update deprecation message for values exec
shehabgamin Jan 24, 2025
d2eadea
Merge branch 'main' of github.com:lakehq/datafusion into sail-df-43-r…
shehabgamin Jan 24, 2025
944e0a3
lint
shehabgamin Jan 24, 2025
2d77206
coerce to string
shehabgamin Jan 24, 2025
1a27626
Adjust type signature
shehabgamin Jan 24, 2025
e714ba1
fix comment
shehabgamin Jan 24, 2025
93d75b1
clean up clippy warnings
shehabgamin Jan 25, 2025
5067223
type signature coercible
shehabgamin Jan 25, 2025
4d395e2
bump pyo3
shehabgamin Jan 25, 2025
ec8ccd1
udf type coercion
shehabgamin Jan 25, 2025
d78877a
moving testing out of functions crate due to circular dependencies
shehabgamin Jan 25, 2025
041e4ac
Merge branch 'main' of github.com:lakehq/datafusion into sail-df-43-r…
shehabgamin Jan 26, 2025
437b83d
find the common string type for TypeSignature::Coercible
shehabgamin Jan 26, 2025
8fd9fb3
update coercible string
shehabgamin Jan 26, 2025
62b97c5
fix error msg and add tests for udfs
shehabgamin Jan 26, 2025
46350c9
update docs to note that args are coercible string
shehabgamin Jan 26, 2025
3f0c870
update expr test
shehabgamin Jan 26, 2025
46cda71
undo
shehabgamin Jan 27, 2025
e7d474d
Merge branch 'main' of github.com:lakehq/datafusion into sail-df-43-r…
shehabgamin Jan 27, 2025
1f826cb
remove test since already covered in slt
shehabgamin Jan 27, 2025
5d258b2
add dictionary to base_yupe
shehabgamin Jan 27, 2025
17df1bc
add dictionary to base_type
shehabgamin Jan 27, 2025
97c7db0
fix lint issues
shehabgamin Jan 28, 2025
d0f3f9a
Merge branch 'main' of github.com:lakehq/datafusion into sail-df-43-r…
shehabgamin Feb 1, 2025
bb773d5
implement TypeSignatureClass::Integer
shehabgamin Feb 1, 2025
08d9b4a
fix slt test
shehabgamin Feb 1, 2025
09ef47c
fix slt
shehabgamin Feb 1, 2025
f519d92
Merge branch 'branch-45' of github.com:lakehq/datafusion into sail-df…
shehabgamin Feb 4, 2025
1f806b2
revert native to original and add anynative
shehabgamin Feb 4, 2025
25bcb58
use anynative
shehabgamin Feb 4, 2025
1d94e5c
use user defined rules instead
shehabgamin Feb 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ log = { workspace = true }
object_store = { workspace = true, optional = true }
parquet = { workspace = true, optional = true, default-features = true }
paste = "1.0.15"
pyo3 = { version = "0.23.3", optional = true }
pyo3 = { version = "0.23.4", optional = true }
recursive = { workspace = true, optional = true }
sqlparser = { workspace = true }
tokio = { workspace = true }
Expand Down
1 change: 1 addition & 0 deletions datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ pub fn base_type(data_type: &DataType) -> DataType {
DataType::List(field)
| DataType::LargeList(field)
| DataType::FixedSizeList(field, _) => base_type(field.data_type()),
DataType::Dictionary(_, value_type) => base_type(value_type),
_ => data_type.to_owned(),
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ impl Signature {
}
}

/// A specified number of numeric arguments
/// A specified number of string arguments
pub fn string(arg_count: usize, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::String(arg_count),
Expand Down
215 changes: 110 additions & 105 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::utils::coerced_fixed_size_list_to_list;
use datafusion_common::utils::{base_type, coerced_fixed_size_list_to_list};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
types::{LogicalType, NativeType},
Expand Down Expand Up @@ -53,14 +53,14 @@ pub fn data_types_with_scalar_udf(
let type_signature = &signature.type_signature;

if current_types.is_empty() {
if type_signature.supports_zero_argument() {
return Ok(vec![]);
return if type_signature.supports_zero_argument() {
Ok(vec![])
} else if type_signature.used_to_support_zero_arguments() {
// Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
return plan_err!("{} does not support zero arguments. Use TypeSignature::Nullary for zero arguments.", func.name());
plan_err!("{} does not support zero arguments. Use TypeSignature::Nullary for zero arguments.", func.name())
} else {
return plan_err!("{} does not support zero arguments.", func.name());
}
plan_err!("{} does not support zero arguments.", func.name())
};
}

let valid_types =
Expand Down Expand Up @@ -91,14 +91,14 @@ pub fn data_types_with_aggregate_udf(
let type_signature = &signature.type_signature;

if current_types.is_empty() {
if type_signature.supports_zero_argument() {
return Ok(vec![]);
return if type_signature.supports_zero_argument() {
Ok(vec![])
} else if type_signature.used_to_support_zero_arguments() {
// Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
return plan_err!("{} does not support zero arguments. Use TypeSignature::Nullary for zero arguments.", func.name());
plan_err!("{} does not support zero arguments. Use TypeSignature::Nullary for zero arguments.", func.name())
} else {
return plan_err!("{} does not support zero arguments.", func.name());
}
plan_err!("{} does not support zero arguments.", func.name())
};
}

let valid_types =
Expand Down Expand Up @@ -128,14 +128,14 @@ pub fn data_types_with_window_udf(
let type_signature = &signature.type_signature;

if current_types.is_empty() {
if type_signature.supports_zero_argument() {
return Ok(vec![]);
return if type_signature.supports_zero_argument() {
Ok(vec![])
} else if type_signature.used_to_support_zero_arguments() {
// Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
return plan_err!("{} does not support zero arguments. Use TypeSignature::Nullary for zero arguments.", func.name());
plan_err!("{} does not support zero arguments. Use TypeSignature::Nullary for zero arguments.", func.name())
} else {
return plan_err!("{} does not support zero arguments.", func.name());
}
plan_err!("{} does not support zero arguments.", func.name())
};
}

let valid_types =
Expand Down Expand Up @@ -165,20 +165,20 @@ pub fn data_types(
let type_signature = &signature.type_signature;

if current_types.is_empty() {
if type_signature.supports_zero_argument() {
return Ok(vec![]);
return if type_signature.supports_zero_argument() {
Ok(vec![])
} else if type_signature.used_to_support_zero_arguments() {
// Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
return plan_err!(
plan_err!(
"signature {:?} does not support zero arguments. Use TypeSignature::Nullary for zero arguments.",
type_signature
);
)
} else {
return plan_err!(
plan_err!(
"signature {:?} does not support zero arguments.",
type_signature
);
}
)
};
}

let valid_types = get_valid_types(type_signature, current_types)?;
Expand Down Expand Up @@ -387,8 +387,8 @@ fn get_valid_types(

// We need to find the coerced base type, mainly for cases like:
// `array_append(List(null), i64)` -> `List(i64)`
let array_base_type = datafusion_common::utils::base_type(array_type);
let elem_base_type = datafusion_common::utils::base_type(elem_type);
let array_base_type = base_type(array_type);
let elem_base_type = base_type(elem_type);
let new_base_type = comparison_coercion(&array_base_type, &elem_base_type);

let new_base_type = new_base_type.ok_or_else(|| {
Expand Down Expand Up @@ -453,62 +453,19 @@ fn get_valid_types(
.collect(),
TypeSignature::String(number) => {
function_length_check(current_types.len(), *number)?;

let mut new_types = Vec::with_capacity(current_types.len());
for data_type in current_types.iter() {
let logical_data_type: NativeType = data_type.into();
if logical_data_type == NativeType::String {
new_types.push(data_type.to_owned());
} else if logical_data_type == NativeType::Null {
// TODO: Switch to Utf8View if all the string functions supports Utf8View
new_types.push(DataType::Utf8);
} else {
return plan_err!(
"The signature expected NativeType::String but received {logical_data_type}"
);
}
}

// Find the common string type for the given types
fn find_common_type(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Result<DataType> {
match (lhs_type, rhs_type) {
(DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
find_common_type(lhs, rhs)
}
(DataType::Dictionary(_, v), other)
| (other, DataType::Dictionary(_, v)) => find_common_type(v, other),
_ => {
if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
Ok(coerced_type)
} else {
plan_err!(
"{} and {} are not coercible to a common string type",
lhs_type,
rhs_type
)
}
}
}
}

// Length checked above, safe to unwrap
let mut coerced_type = new_types.first().unwrap().to_owned();
let new_types = validate_and_collect_string_types(current_types)?;
let mut coerced_type = new_types
.first()
.ok_or_else(|| {
internal_datafusion_err!(
"Expected at least one type in the list of types"
)
})?
.to_owned();
for t in new_types.iter().skip(1) {
coerced_type = find_common_type(&coerced_type, t)?;
coerced_type = find_common_string_type(&coerced_type, t)?;
}

fn base_type_or_default_type(data_type: &DataType) -> DataType {
if let DataType::Dictionary(_, v) = data_type {
base_type_or_default_type(v)
} else {
data_type.to_owned()
}
}

vec![vec![base_type_or_default_type(&coerced_type); *number]]
vec![vec![base_type(&coerced_type); *number]]
}
TypeSignature::Numeric(number) => {
function_length_check(current_types.len(), *number)?;
Expand Down Expand Up @@ -584,23 +541,7 @@ fn get_valid_types(
match target_type_class {
TypeSignatureClass::Native(native_type) => {
let target_type = native_type.native();
if &logical_type == target_type {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that mean other functions that use Coercible String now also cast integers to strings?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's TypeSignature::Coercible with a TypeSignatureClass::Native(logical_string()), then yes. Any function that specifies TypeSignature::Coercible with a TypeSignatureClass::Native should coerce according to the behavior implemented in the default_cast_for function for NativeType in order to be consistent.

Copy link
Contributor

@jayzhan211 jayzhan211 Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TypeSignature::Coercible is designed to cast between the same logical type, but now it casts to any kind of type, I don't think this is ideal.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, in this PR, SELECT bit_length(12); now returns 16 instead of an error, but I think we should return an error. Any unexpected type is now valid, which doesn't seem correct.

Copy link
Contributor Author

@shehabgamin shehabgamin Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TypeSignature::Coercible is designed to cast between the same logical type, but now it casts to any kind of type, I don't think this is ideal.

This is what the doc comments say:

/// One or more arguments belonging to the [`TypeSignatureClass`], in order.
///
/// For example, `Coercible(vec![logical_float64()])` accepts
/// arguments like `vec![Int32]` or `vec![Float32]`
/// since i32 and f32 can be cast to f64
///
/// For functions that take no arguments (e.g. `random()`) see [`TypeSignature::Nullary`].
Coercible(Vec<TypeSignatureClass>),

Link:

/// One or more arguments belonging to the [`TypeSignatureClass`], in order.
///
/// For example, `Coercible(vec![logical_float64()])` accepts
/// arguments like `vec![Int32]` or `vec![Float32]`
/// since i32 and f32 can be cast to f64
///
/// For functions that take no arguments (e.g. `random()`) see [`TypeSignature::Nullary`].
Coercible(Vec<TypeSignatureClass>),

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shehabgamin
I suggest we revert to the previous change and add binary->string casting, given that most of the string functions allow this kind of casting, but not int->string casting like ascii(2).

For Signature::String, we can make it a convenient wrapper of Coercible(string, string) or deprecate it.

Copy link
Contributor

@jayzhan211 jayzhan211 Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a rule of thumb, stick to either Postgres or DuckDB for functions, and avoid overly broad type casting that allows any type to convert to any other. This can turn invalid queries into valid ones, making errors harder to catch, especially when tests don’t cover all cases.

Based on that,

  1. Cast from Binary to String 👍🏻
  2. Cast from Numeric to String 👎🏻 (At least for the string functions changed in this PR)
  3. Cast from Float to Integer 👎🏻 (At least not for repeat)

return target_type.default_cast_for(current_type);
}

if logical_type == NativeType::Null {
return target_type.default_cast_for(current_type);
}

if target_type.is_integer() && logical_type.is_integer() {
return target_type.default_cast_for(current_type);
}

internal_err!(
"Expect {} but received {}",
target_type_class,
current_type
)
target_type.default_cast_for(current_type)
}
// Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp
TypeSignatureClass::Timestamp
Expand Down Expand Up @@ -637,6 +578,24 @@ fn get_valid_types(
new_types.push(target_type);
}

// Following the behavior of `TypeSignature::String`, we find the common string type.
let string_indices: Vec<_> = target_types.iter().enumerate()
.filter(|(_, t)| {
matches!(t, TypeSignatureClass::Native(n) if n.native() == &NativeType::String)
})
.map(|(i, _)| i)
.collect();
if !string_indices.is_empty() {
let mut coerced_string_type = new_types[string_indices[0]].to_owned();
for &i in string_indices.iter().skip(1) {
coerced_string_type =
find_common_string_type(&coerced_string_type, &new_types[i])?;
}
for i in string_indices {
new_types[i] = coerced_string_type.clone();
}
}

vec![new_types]
}
TypeSignature::Uniform(number, valid_types) => {
Expand Down Expand Up @@ -744,6 +703,47 @@ fn get_valid_types(
Ok(valid_types)
}

/// Checks that every argument is logically a string (or null) and collects the
/// concrete types to use for string coercion.
///
/// Each input [`DataType`] is converted to its [`NativeType`]:
/// * [`NativeType::String`] inputs are kept unchanged,
/// * [`NativeType::Null`] inputs are replaced by [`DataType::Utf8`],
/// * anything else aborts with a planning error naming the offending logical type.
///
/// The returned vector has one entry per input, in the same order.
fn validate_and_collect_string_types(data_types: &[DataType]) -> Result<Vec<DataType>> {
    let mut collected = Vec::with_capacity(data_types.len());

    for data_type in data_types {
        let logical_type: NativeType = data_type.into();
        let string_type = match logical_type {
            NativeType::String => data_type.clone(),
            // TODO: Switch to Utf8View once all the string functions support Utf8View
            NativeType::Null => DataType::Utf8,
            _ => {
                // Stop at the first argument that is neither a string nor null.
                return plan_err!(
                    "The signature expected NativeType::String but received {logical_type}"
                );
            }
        };
        collected.push(string_type);
    }

    Ok(collected)
}

/// Finds a single string [`DataType`] that both `lhs_type` and `rhs_type` can be coerced to.
///
/// [`DataType::Dictionary`] inputs are unwrapped to their value type (recursively) before
/// the comparison; once neither side is a dictionary, the decision is delegated to
/// `string_coercion`.
///
/// # Errors
/// Returns a planning error when `string_coercion` yields no common type for the
/// (unwrapped) pair.
fn find_common_string_type(lhs_type: &DataType, rhs_type: &DataType) -> Result<DataType> {
    // Left side is a dictionary: strip it, and strip the right side too if it is one.
    if let DataType::Dictionary(_, lhs_value) = lhs_type {
        return match rhs_type {
            DataType::Dictionary(_, rhs_value) => {
                find_common_string_type(lhs_value, rhs_value)
            }
            _ => find_common_string_type(lhs_value, rhs_type),
        };
    }

    // Only the right side is a dictionary: its value type becomes the first argument.
    if let DataType::Dictionary(_, rhs_value) = rhs_type {
        return find_common_string_type(rhs_value, lhs_type);
    }

    // Neither side is a dictionary any more.
    match string_coercion(lhs_type, rhs_type) {
        Some(coerced_type) => Ok(coerced_type),
        None => plan_err!(
            "{lhs_type} and {rhs_type} are not coercible to a common string type"
        ),
    }
}

/// Try to coerce the current argument types to match the given `valid_types`.
///
/// For example, if a function `func` accepts arguments of `(int64, int64)`,
Expand Down Expand Up @@ -881,7 +881,7 @@ fn coerced_from<'a>(
Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
) => Some(type_into.clone()),
(Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()),
// We can go into a Utf8View from a Utf8 or LargeUtf8
// We can go into a Utf8View from an Utf8 or LargeUtf8
(Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
// Any type can be coerced into strings
(Utf8 | LargeUtf8, _) => Some(type_into.clone()),
Expand All @@ -892,7 +892,7 @@ fn coerced_from<'a>(
// Only accept list and largelist with the same number of dimensions unless the type is Null.
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this
(List(_) | LargeList(_), _)
if datafusion_common::utils::base_type(type_from).eq(&Null)
if base_type(type_from).eq(&Null)
|| list_ndims(type_from) == list_ndims(type_into) =>
{
Some(type_into.clone())
Expand Down Expand Up @@ -931,12 +931,17 @@ fn coerced_from<'a>(

#[cfg(test)]
mod tests {
use std::sync::Arc;

use crate::Volatility;

use super::*;
use arrow::datatypes::Field;
use datafusion_common::assert_contains;
use crate::type_coercion::functions::{
can_coerce_from, coerced_from, data_types, get_valid_types, maybe_data_types,
};
use crate::TypeSignature;
use arrow::datatypes::{DataType, Field, TimeUnit};
use datafusion_common::{assert_contains, Result};
use datafusion_expr_common::signature::{
Signature, Volatility, FIXED_SIZE_LIST_WILDCARD,
};

#[test]
fn test_string_conversion() {
Expand Down Expand Up @@ -1124,7 +1129,7 @@ mod tests {
Volatility::Stable,
);

let coerced_data_types = data_types("test", &current_types, &signature).unwrap();
let coerced_data_types = data_types("test", &current_types, &signature)?;
assert_eq!(coerced_data_types, current_types);

// make sure it can't coerce to a different size
Expand All @@ -1140,7 +1145,7 @@ mod tests {
vec![DataType::FixedSizeList(Arc::clone(&inner), 2)],
Volatility::Stable,
);
let coerced_data_types = data_types("test", &current_types, &signature).unwrap();
let coerced_data_types = data_types("test", &current_types, &signature)?;
assert_eq!(coerced_data_types, current_types);

Ok(())
Expand Down
Loading
Loading