diff --git a/datafusion/spark/src/function/string/luhn_check.rs b/datafusion/spark/src/function/string/luhn_check.rs new file mode 100644 index 0000000000000..07a4a4a41dabf --- /dev/null +++ b/datafusion/spark/src/function/string/luhn_check.rs @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::array::{Array, AsArray, BooleanArray}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Boolean; +use datafusion_common::utils::take_function_args; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; + +/// Spark-compatible `luhn_check` expression +/// +#[derive(Debug)] +pub struct SparkLuhnCheck { + signature: Signature, +} + +impl Default for SparkLuhnCheck { + fn default() -> Self { + Self::new() + } +} + +impl SparkLuhnCheck { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Utf8View]), + TypeSignature::Exact(vec![DataType::LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkLuhnCheck { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "luhn_check" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [array] = take_function_args(self.name(), &args.args)?; + + match array { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8View => { + let str_array = array.as_string_view(); + let values = str_array + .iter() + .map(|s| s.map(luhn_check_impl)) + .collect::(); + Ok(ColumnarValue::Array(Arc::new(values))) + } + DataType::Utf8 => { + let str_array = array.as_string::(); + let values = str_array + .iter() + .map(|s| s.map(luhn_check_impl)) + .collect::(); + Ok(ColumnarValue::Array(Arc::new(values))) + } + DataType::LargeUtf8 => { + let str_array = array.as_string::(); + let values = str_array + .iter() + .map(|s| s.map(luhn_check_impl)) + .collect::(); + Ok(ColumnarValue::Array(Arc::new(values))) + } + other => { + exec_err!("Unsupported data type {other:?} for function `luhn_check`") + } + }, + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) => Ok( + ColumnarValue::Scalar(ScalarValue::Boolean(Some(luhn_check_impl(s)))), + ), + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))) + } + other => { + exec_err!("Unsupported data type {other:?} for function `luhn_check`") + } + } + } +} + +/// Validates a string using the Luhn algorithm. +/// Returns `true` if the input is a valid Luhn number. +fn luhn_check_impl(input: &str) -> bool { + let mut sum = 0u32; + let mut alt = false; + let mut digits_processed = 0; + + for b in input.as_bytes().iter().rev() { + let digit = match b { + b'0'..=b'9' => { + digits_processed += 1; + b - b'0' + } + _ => return false, + }; + + let mut val = digit as u32; + if alt { + val *= 2; + if val > 9 { + val -= 9; + } + } + sum += val; + alt = !alt; + } + + digits_processed > 0 && sum % 10 == 0 +} diff --git a/datafusion/spark/src/function/string/mod.rs b/datafusion/spark/src/function/string/mod.rs index 9d5fabe832e92..e45bf4add7721 100644 --- a/datafusion/spark/src/function/string/mod.rs +++ b/datafusion/spark/src/function/string/mod.rs @@ -17,6 +17,7 @@ pub mod ascii; pub mod char; +pub mod luhn_check; use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; @@ -24,6 +25,7 @@ use std::sync::Arc; make_udf_function!(ascii::SparkAscii, ascii); make_udf_function!(char::SparkChar, char); +make_udf_function!(luhn_check::SparkLuhnCheck, luhn_check); pub mod expr_fn { use datafusion_functions::export_functions; @@ -38,8 +40,13 @@ pub mod expr_fn { "Returns the ASCII character having the binary equivalent to col. If col is larger than 256 the result is equivalent to char(col % 256).", arg1 )); + export_functions!(( + luhn_check, + "Returns whether the input string of digits is valid according to the Luhn algorithm.", + arg1 + )); } pub fn functions() -> Vec> { - vec![ascii(), char()] + vec![ascii(), char(), luhn_check()] } diff --git a/datafusion/sqllogictest/test_files/spark/string/luhn_check.slt b/datafusion/sqllogictest/test_files/spark/string/luhn_check.slt index 389c34ef68ab9..ccb17323b24dc 100644 --- a/datafusion/sqllogictest/test_files/spark/string/luhn_check.slt +++ b/datafusion/sqllogictest/test_files/spark/string/luhn_check.slt @@ -15,23 +15,145 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT luhn_check('79927398713'); -## PySpark 3.5.5 Result: {'luhn_check(79927398713)': True, 'typeof(luhn_check(79927398713))': 'boolean', 'typeof(79927398713)': 'string'} -#query -#SELECT luhn_check('79927398713'::string); - -## Original Query: SELECT luhn_check('79927398714'); -## PySpark 3.5.5 Result: {'luhn_check(79927398714)': False, 'typeof(luhn_check(79927398714))': 'boolean', 'typeof(79927398714)': 'string'} -#query -#SELECT luhn_check('79927398714'::string); - -## Original Query: SELECT luhn_check('8112189876'); -## PySpark 3.5.5 Result: {'luhn_check(8112189876)': True, 'typeof(luhn_check(8112189876))': 'boolean', 'typeof(8112189876)': 'string'} -#query -#SELECT luhn_check('8112189876'::string); + +query B +SELECT luhn_check('79927398713'::string); +---- +true + + +query B +SELECT luhn_check('79927398714'::string); +---- +false + + +query B +SELECT luhn_check('8112189876'::string); +---- +true + +query B +select luhn_check('4111111111111111'::string); +---- +true + +query B +select luhn_check('5500000000000004'::string); +---- +true + +query B +select luhn_check('340000000000009'::string); +---- +true + +query B +select luhn_check('6011000000000004'::string); +---- +true + + +query B +select luhn_check('6011000000000005'::string); +---- +false + + +query B +select luhn_check('378282246310006'::string); +---- +false + + +query B +select luhn_check('0'::string); +---- +true + + +query B +select luhn_check('79927398713'::string) +---- +true + +query B +select luhn_check('4417123456789113'::string) +---- +true + +query B +select luhn_check('7992 7398 714'::string) +---- +false + +query B +select luhn_check('79927398714'::string) +---- +false + +query B +select luhn_check('4111111111111111 '::string) +---- +false + + +query B +select luhn_check('4111111 111111111'::string) +---- +false + +query B +select luhn_check(' 4111111111111111'::string) +---- +false + +query B +select luhn_check(''::string) +---- +false + +query B +select luhn_check(' ') +---- +false + + +query B +select luhn_check('510B105105105106'::string) +---- +false + + +query B +select luhn_check('ABCDED'::string) +---- +false + +query B +select luhn_check(null); +---- +NULL + +query B +select luhn_check(6011111111111117::BIGINT) +---- +true + + +query B +select luhn_check(6011111111111118::BIGINT) +---- +false + + +query B +select luhn_check(123.456::decimal(6,3)) +---- +false + +query B +SELECT luhn_check(a) FROM (VALUES ('79927398713'::string), ('79927398714'::string)) AS t(a); +---- +true +false