Skip to content

Commit 62ee8fb

Browse files
authored
Minor: refactor trim to clean up duplicated code (#8434)
* refactor trim * add fmt for TrimType * fix closure * update comment
1 parent 182a37e commit 62ee8fb

1 file changed

Lines changed: 69 additions & 100 deletions

File tree

datafusion/physical-expr/src/string_expressions.rs

Lines changed: 69 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,11 @@ use datafusion_common::{
3737
};
3838
use datafusion_common::{internal_err, DataFusionError, Result};
3939
use datafusion_expr::ColumnarValue;
40-
use std::iter;
4140
use std::sync::Arc;
41+
use std::{
42+
fmt::{Display, Formatter},
43+
iter,
44+
};
4245
use uuid::Uuid;
4346

4447
/// applies a unary expression to `args[0]` that is expected to be downcastable to
@@ -133,53 +136,6 @@ pub fn ascii<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
133136
Ok(Arc::new(result) as ArrayRef)
134137
}
135138

136-
/// Removes the longest string containing only characters in characters (a space by default) from the start and end of string.
137-
/// btrim('xyxtrimyyx', 'xyz') = 'trim'
138-
pub fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
139-
match args.len() {
140-
1 => {
141-
let string_array = as_generic_string_array::<T>(&args[0])?;
142-
143-
let result = string_array
144-
.iter()
145-
.map(|string| {
146-
string.map(|string: &str| {
147-
string.trim_start_matches(' ').trim_end_matches(' ')
148-
})
149-
})
150-
.collect::<GenericStringArray<T>>();
151-
152-
Ok(Arc::new(result) as ArrayRef)
153-
}
154-
2 => {
155-
let string_array = as_generic_string_array::<T>(&args[0])?;
156-
let characters_array = as_generic_string_array::<T>(&args[1])?;
157-
158-
let result = string_array
159-
.iter()
160-
.zip(characters_array.iter())
161-
.map(|(string, characters)| match (string, characters) {
162-
(None, _) => None,
163-
(_, None) => None,
164-
(Some(string), Some(characters)) => {
165-
let chars: Vec<char> = characters.chars().collect();
166-
Some(
167-
string
168-
.trim_start_matches(&chars[..])
169-
.trim_end_matches(&chars[..]),
170-
)
171-
}
172-
})
173-
.collect::<GenericStringArray<T>>();
174-
175-
Ok(Arc::new(result) as ArrayRef)
176-
}
177-
other => internal_err!(
178-
"btrim was called with {other} arguments. It requires at least 1 and at most 2."
179-
),
180-
}
181-
}
182-
183139
/// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character.
184140
/// chr(65) = 'A'
185141
pub fn chr(args: &[ArrayRef]) -> Result<ArrayRef> {
@@ -346,44 +302,95 @@ pub fn lower(args: &[ColumnarValue]) -> Result<ColumnarValue> {
346302
handle(args, |string| string.to_ascii_lowercase(), "lower")
347303
}
348304

349-
/// Removes the longest string containing only characters in characters (a space by default) from the start of string.
350-
/// ltrim('zzzytest', 'xyz') = 'test'
351-
pub fn ltrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
305+
enum TrimType {
306+
Left,
307+
Right,
308+
Both,
309+
}
310+
311+
impl Display for TrimType {
312+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
313+
match self {
314+
TrimType::Left => write!(f, "ltrim"),
315+
TrimType::Right => write!(f, "rtrim"),
316+
TrimType::Both => write!(f, "btrim"),
317+
}
318+
}
319+
}
320+
321+
fn general_trim<T: OffsetSizeTrait>(
322+
args: &[ArrayRef],
323+
trim_type: TrimType,
324+
) -> Result<ArrayRef> {
325+
let func = match trim_type {
326+
TrimType::Left => |input, pattern: &str| {
327+
let pattern = pattern.chars().collect::<Vec<char>>();
328+
str::trim_start_matches::<&[char]>(input, pattern.as_ref())
329+
},
330+
TrimType::Right => |input, pattern: &str| {
331+
let pattern = pattern.chars().collect::<Vec<char>>();
332+
str::trim_end_matches::<&[char]>(input, pattern.as_ref())
333+
},
334+
TrimType::Both => |input, pattern: &str| {
335+
let pattern = pattern.chars().collect::<Vec<char>>();
336+
str::trim_end_matches::<&[char]>(
337+
str::trim_start_matches::<&[char]>(input, pattern.as_ref()),
338+
pattern.as_ref(),
339+
)
340+
},
341+
};
342+
343+
let string_array = as_generic_string_array::<T>(&args[0])?;
344+
352345
match args.len() {
353346
1 => {
354-
let string_array = as_generic_string_array::<T>(&args[0])?;
355-
356347
let result = string_array
357348
.iter()
358-
.map(|string| string.map(|string: &str| string.trim_start_matches(' ')))
349+
.map(|string| string.map(|string: &str| func(string, " ")))
359350
.collect::<GenericStringArray<T>>();
360351

361352
Ok(Arc::new(result) as ArrayRef)
362353
}
363354
2 => {
364-
let string_array = as_generic_string_array::<T>(&args[0])?;
365355
let characters_array = as_generic_string_array::<T>(&args[1])?;
366356

367357
let result = string_array
368358
.iter()
369359
.zip(characters_array.iter())
370360
.map(|(string, characters)| match (string, characters) {
371-
(Some(string), Some(characters)) => {
372-
let chars: Vec<char> = characters.chars().collect();
373-
Some(string.trim_start_matches(&chars[..]))
374-
}
361+
(Some(string), Some(characters)) => Some(func(string, characters)),
375362
_ => None,
376363
})
377364
.collect::<GenericStringArray<T>>();
378365

379366
Ok(Arc::new(result) as ArrayRef)
380367
}
381-
other => internal_err!(
382-
"ltrim was called with {other} arguments. It requires at least 1 and at most 2."
383-
),
368+
other => {
369+
internal_err!(
370+
"{trim_type} was called with {other} arguments. It requires at least 1 and at most 2."
371+
)
372+
}
384373
}
385374
}
386375

376+
/// Returns the string with the longest leading and trailing runs of the given characters removed. If the characters are not specified, spaces (`' '`) are removed.
377+
/// btrim('xyxtrimyyx', 'xyz') = 'trim'
378+
pub fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
379+
general_trim::<T>(args, TrimType::Both)
380+
}
381+
382+
/// Returns the string with the longest leading run of the given characters removed. If the characters are not specified, spaces (`' '`) are removed.
383+
/// ltrim('zzzytest', 'xyz') = 'test'
384+
pub fn ltrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
385+
general_trim::<T>(args, TrimType::Left)
386+
}
387+
388+
/// Returns the string with the longest trailing run of the given characters removed. If the characters are not specified, spaces (`' '`) are removed.
389+
/// rtrim('testxxzx', 'xyz') = 'test'
390+
pub fn rtrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
391+
general_trim::<T>(args, TrimType::Right)
392+
}
393+
387394
/// Repeats string the specified number of times.
388395
/// repeat('Pg', 4) = 'PgPgPgPg'
389396
pub fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
@@ -422,44 +429,6 @@ pub fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
422429
Ok(Arc::new(result) as ArrayRef)
423430
}
424431

425-
/// Removes the longest string containing only characters in characters (a space by default) from the end of string.
426-
/// rtrim('testxxzx', 'xyz') = 'test'
427-
pub fn rtrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
428-
match args.len() {
429-
1 => {
430-
let string_array = as_generic_string_array::<T>(&args[0])?;
431-
432-
let result = string_array
433-
.iter()
434-
.map(|string| string.map(|string: &str| string.trim_end_matches(' ')))
435-
.collect::<GenericStringArray<T>>();
436-
437-
Ok(Arc::new(result) as ArrayRef)
438-
}
439-
2 => {
440-
let string_array = as_generic_string_array::<T>(&args[0])?;
441-
let characters_array = as_generic_string_array::<T>(&args[1])?;
442-
443-
let result = string_array
444-
.iter()
445-
.zip(characters_array.iter())
446-
.map(|(string, characters)| match (string, characters) {
447-
(Some(string), Some(characters)) => {
448-
let chars: Vec<char> = characters.chars().collect();
449-
Some(string.trim_end_matches(&chars[..]))
450-
}
451-
_ => None,
452-
})
453-
.collect::<GenericStringArray<T>>();
454-
455-
Ok(Arc::new(result) as ArrayRef)
456-
}
457-
other => internal_err!(
458-
"rtrim was called with {other} arguments. It requires at least 1 and at most 2."
459-
),
460-
}
461-
}
462-
463432
/// Splits string at occurrences of delimiter and returns the n'th field (counting from one).
464433
/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def'
465434
pub fn split_part<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {

0 commit comments

Comments
 (0)