Skip to content

Commit a2a9f5d

Browse files
maartenbreddelsnealrichardsonpitrou
committed
ARROW-10306: [C++] Add string replacement kernel
Two new kernels * replace_substring like Python's str.replace * replace_substring_re2 like Python's re.sub Closes #8468 from maartenbreddels/ARROW-10306 Lead-authored-by: Maarten A. Breddels <maartenbreddels@gmail.com> Co-authored-by: Neal Richardson <neal.p.richardson@gmail.com> Co-authored-by: Antoine Pitrou <antoine@python.org> Signed-off-by: Neal Richardson <neal.p.richardson@gmail.com>
1 parent 99d7291 commit a2a9f5d

10 files changed

Lines changed: 351 additions & 20 deletions

File tree

ci/scripts/PKGBUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,10 @@ build() {
7979
export CPPFLAGS="${CPPFLAGS} -I${MINGW_PREFIX}/include"
8080
export LIBS="-L${MINGW_PREFIX}/libs"
8181
export ARROW_S3=OFF
82+
export ARROW_WITH_RE2=OFF
8283
else
8384
export ARROW_S3=ON
85+
export ARROW_WITH_RE2=ON
8486
fi
8587

8688
MSYS2_ARG_CONV_EXCL="-DCMAKE_INSTALL_PREFIX=" \
@@ -105,6 +107,7 @@ build() {
105107
-DARROW_SNAPPY_USE_SHARED=OFF \
106108
-DARROW_USE_GLOG=OFF \
107109
-DARROW_WITH_LZ4=ON \
110+
-DARROW_WITH_RE2="${ARROW_WITH_RE2}" \
108111
-DARROW_WITH_SNAPPY=ON \
109112
-DARROW_WITH_ZLIB=ON \
110113
-DARROW_WITH_ZSTD=ON \

cpp/src/arrow/compute/api_scalar.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,19 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions {
6868
std::string pattern;
6969
};
7070

71+
struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions {
72+
explicit ReplaceSubstringOptions(std::string pattern, std::string replacement,
73+
int64_t max_replacements = -1)
74+
: pattern(pattern), replacement(replacement), max_replacements(max_replacements) {}
75+
76+
/// Pattern to match, literal, or regular expression depending on which kernel is used
77+
std::string pattern;
78+
/// String to replace the pattern with
79+
std::string replacement;
80+
/// Max number of substrings to replace (-1 means unbounded)
81+
int64_t max_replacements;
82+
};
83+
7184
/// Options for IsIn and IndexIn functions
7285
struct ARROW_EXPORT SetLookupOptions : public FunctionOptions {
7386
explicit SetLookupOptions(Datum value_set, bool skip_nulls = false)

cpp/src/arrow/compute/kernels/scalar_string.cc

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
#include <utf8proc.h>
2424
#endif
2525

26+
#ifdef ARROW_WITH_RE2
27+
#include <re2/re2.h>
28+
#endif
29+
2630
#include "arrow/array/builder_binary.h"
2731
#include "arrow/array/builder_nested.h"
2832
#include "arrow/buffer_builder.h"
@@ -1230,6 +1234,197 @@ void AddSplit(FunctionRegistry* registry) {
12301234
#endif
12311235
}
12321236

1237+
// ----------------------------------------------------------------------
1238+
// Replace substring (plain, regex)
1239+
1240+
template <typename Type, typename Replacer>
1241+
struct ReplaceSubString {
1242+
using ScalarType = typename TypeTraits<Type>::ScalarType;
1243+
using offset_type = typename Type::offset_type;
1244+
using ValueDataBuilder = TypedBufferBuilder<uint8_t>;
1245+
using OffsetBuilder = TypedBufferBuilder<offset_type>;
1246+
using State = OptionsWrapper<ReplaceSubstringOptions>;
1247+
1248+
static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
1249+
// TODO Cache replacer accross invocations (for regex compilation)
1250+
Replacer replacer{ctx, State::Get(ctx)};
1251+
if (!ctx->HasError()) {
1252+
Replace(ctx, batch, &replacer, out);
1253+
}
1254+
}
1255+
1256+
static void Replace(KernelContext* ctx, const ExecBatch& batch, Replacer* replacer,
1257+
Datum* out) {
1258+
ValueDataBuilder value_data_builder(ctx->memory_pool());
1259+
OffsetBuilder offset_builder(ctx->memory_pool());
1260+
1261+
if (batch[0].kind() == Datum::ARRAY) {
1262+
// We already know how many strings we have, so we can use Reserve/UnsafeAppend
1263+
KERNEL_RETURN_IF_ERROR(ctx, offset_builder.Reserve(batch[0].array()->length));
1264+
offset_builder.UnsafeAppend(0); // offsets start at 0
1265+
1266+
const ArrayData& input = *batch[0].array();
1267+
KERNEL_RETURN_IF_ERROR(
1268+
ctx, VisitArrayDataInline<Type>(
1269+
input,
1270+
[&](util::string_view s) {
1271+
RETURN_NOT_OK(replacer->ReplaceString(s, &value_data_builder));
1272+
offset_builder.UnsafeAppend(
1273+
static_cast<offset_type>(value_data_builder.length()));
1274+
return Status::OK();
1275+
},
1276+
[&]() {
1277+
// offset for null value
1278+
offset_builder.UnsafeAppend(
1279+
static_cast<offset_type>(value_data_builder.length()));
1280+
return Status::OK();
1281+
}));
1282+
ArrayData* output = out->mutable_array();
1283+
KERNEL_RETURN_IF_ERROR(ctx, value_data_builder.Finish(&output->buffers[2]));
1284+
KERNEL_RETURN_IF_ERROR(ctx, offset_builder.Finish(&output->buffers[1]));
1285+
} else {
1286+
const auto& input = checked_cast<const ScalarType&>(*batch[0].scalar());
1287+
auto result = std::make_shared<ScalarType>();
1288+
if (input.is_valid) {
1289+
util::string_view s = static_cast<util::string_view>(*input.value);
1290+
KERNEL_RETURN_IF_ERROR(ctx, replacer->ReplaceString(s, &value_data_builder));
1291+
KERNEL_RETURN_IF_ERROR(ctx, value_data_builder.Finish(&result->value));
1292+
result->is_valid = true;
1293+
}
1294+
out->value = result;
1295+
}
1296+
}
1297+
};
1298+
1299+
struct PlainSubStringReplacer {
1300+
const ReplaceSubstringOptions& options_;
1301+
1302+
PlainSubStringReplacer(KernelContext* ctx, const ReplaceSubstringOptions& options)
1303+
: options_(options) {}
1304+
1305+
Status ReplaceString(util::string_view s, TypedBufferBuilder<uint8_t>* builder) {
1306+
const char* i = s.begin();
1307+
const char* end = s.end();
1308+
int64_t max_replacements = options_.max_replacements;
1309+
while ((i < end) && (max_replacements != 0)) {
1310+
const char* pos =
1311+
std::search(i, end, options_.pattern.begin(), options_.pattern.end());
1312+
if (pos == end) {
1313+
RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
1314+
static_cast<int64_t>(end - i)));
1315+
i = end;
1316+
} else {
1317+
// the string before the pattern
1318+
RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
1319+
static_cast<int64_t>(pos - i)));
1320+
// the replacement
1321+
RETURN_NOT_OK(
1322+
builder->Append(reinterpret_cast<const uint8_t*>(options_.replacement.data()),
1323+
options_.replacement.length()));
1324+
// skip pattern
1325+
i = pos + options_.pattern.length();
1326+
max_replacements--;
1327+
}
1328+
}
1329+
// if we exited early due to max_replacements, add the trailing part
1330+
RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
1331+
static_cast<int64_t>(end - i)));
1332+
return Status::OK();
1333+
}
1334+
};
1335+
1336+
#ifdef ARROW_WITH_RE2
1337+
struct RegexSubStringReplacer {
1338+
const ReplaceSubstringOptions& options_;
1339+
const RE2 regex_find_;
1340+
const RE2 regex_replacement_;
1341+
1342+
// Using RE2::FindAndConsume we can only find the pattern if it is a group, therefore
1343+
// we have 2 regexes, one with () around it, one without.
1344+
RegexSubStringReplacer(KernelContext* ctx, const ReplaceSubstringOptions& options)
1345+
: options_(options),
1346+
regex_find_("(" + options_.pattern + ")"),
1347+
regex_replacement_(options_.pattern) {
1348+
if (!(regex_find_.ok() && regex_replacement_.ok())) {
1349+
ctx->SetStatus(Status::Invalid("Regular expression error"));
1350+
return;
1351+
}
1352+
}
1353+
1354+
Status ReplaceString(util::string_view s, TypedBufferBuilder<uint8_t>* builder) {
1355+
re2::StringPiece replacement(options_.replacement);
1356+
if (options_.max_replacements == -1) {
1357+
std::string s_copy(s.to_string());
1358+
re2::RE2::GlobalReplace(&s_copy, regex_replacement_, replacement);
1359+
RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(s_copy.data()),
1360+
s_copy.length()));
1361+
return Status::OK();
1362+
}
1363+
1364+
// Since RE2 does not have the concept of max_replacements, we have to do some work
1365+
// ourselves.
1366+
// We might do this faster similar to RE2::GlobalReplace using Match and Rewrite
1367+
const char* i = s.begin();
1368+
const char* end = s.end();
1369+
re2::StringPiece piece(s.data(), s.length());
1370+
1371+
int64_t max_replacements = options_.max_replacements;
1372+
while ((i < end) && (max_replacements != 0)) {
1373+
std::string found;
1374+
if (!re2::RE2::FindAndConsume(&piece, regex_find_, &found)) {
1375+
RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
1376+
static_cast<int64_t>(end - i)));
1377+
i = end;
1378+
} else {
1379+
// wind back to the beginning of the match
1380+
const char* pos = piece.begin() - found.length();
1381+
// the string before the pattern
1382+
RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
1383+
static_cast<int64_t>(pos - i)));
1384+
// replace the pattern in what we found
1385+
if (!re2::RE2::Replace(&found, regex_replacement_, replacement)) {
1386+
return Status::Invalid("Regex found, but replacement failed");
1387+
}
1388+
RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(found.data()),
1389+
static_cast<int64_t>(found.length())));
1390+
// skip pattern
1391+
i = piece.begin();
1392+
max_replacements--;
1393+
}
1394+
}
1395+
// If we exited early due to max_replacements, add the trailing part
1396+
RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
1397+
static_cast<int64_t>(end - i)));
1398+
return Status::OK();
1399+
}
1400+
};
1401+
#endif
1402+
1403+
template <typename Type>
1404+
using ReplaceSubStringPlain = ReplaceSubString<Type, PlainSubStringReplacer>;
1405+
1406+
const FunctionDoc replace_substring_doc(
1407+
"Replace non-overlapping substrings that match pattern by replacement",
1408+
("For each string in `strings`, replace non-overlapping substrings that match\n"
1409+
"`pattern` by `replacement`. If `max_replacements != -1`, it determines the\n"
1410+
"maximum amount of replacements made, counting from the left. Null values emit\n"
1411+
"null."),
1412+
{"strings"}, "ReplaceSubstringOptions");
1413+
1414+
#ifdef ARROW_WITH_RE2
1415+
template <typename Type>
1416+
using ReplaceSubStringRegex = ReplaceSubString<Type, RegexSubStringReplacer>;
1417+
1418+
const FunctionDoc replace_substring_regex_doc(
1419+
"Replace non-overlapping substrings that match regex `pattern` by `replacement`",
1420+
("For each string in `strings`, replace non-overlapping substrings that match the\n"
1421+
"regular expression `pattern` by `replacement` using the Google RE2 library.\n"
1422+
"If `max_replacements != -1`, it determines the maximum amount of replacements\n"
1423+
"made, counting from the left. Note that if the pattern contains groups,\n"
1424+
"backreferencing macan be used. Null values emit null."),
1425+
{"strings"}, "ReplaceSubstringOptions");
1426+
#endif
1427+
12331428
// ----------------------------------------------------------------------
12341429
// strptime string parsing
12351430

@@ -1904,6 +2099,14 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
19042099
AddBinaryLength(registry);
19052100
AddUtf8Length(registry);
19062101
AddMatchSubstring(registry);
2102+
MakeUnaryStringBatchKernelWithState<ReplaceSubStringPlain>(
2103+
"replace_substring", registry, &replace_substring_doc,
2104+
MemAllocation::NO_PREALLOCATE);
2105+
#ifdef ARROW_WITH_RE2
2106+
MakeUnaryStringBatchKernelWithState<ReplaceSubStringRegex>(
2107+
"replace_substring_regex", registry, &replace_substring_regex_doc,
2108+
MemAllocation::NO_PREALLOCATE);
2109+
#endif
19072110
AddStrptime(registry);
19082111
}
19092112

cpp/src/arrow/compute/kernels/scalar_string_test.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ class BaseTestStringKernels : public ::testing::Test {
4848
CheckScalarUnary(func_name, type(), json_input, out_ty, json_expected, options);
4949
}
5050

51+
void CheckBinaryScalar(std::string func_name, std::string json_left_input,
52+
std::string json_right_scalar, std::shared_ptr<DataType> out_ty,
53+
std::string json_expected,
54+
const FunctionOptions* options = nullptr) {
55+
CheckScalarBinaryScalar(func_name, type(), json_left_input, json_right_scalar, out_ty,
56+
json_expected, options);
57+
}
58+
5159
std::shared_ptr<DataType> type() { return TypeTraits<TestType>::type_singleton(); }
5260

5361
std::shared_ptr<DataType> offset_type() {
@@ -422,6 +430,52 @@ TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8Reverse) {
422430
&options_max);
423431
}
424432

433+
TYPED_TEST(TestStringKernels, ReplaceSubstring) {
434+
ReplaceSubstringOptions options{"foo", "bazz"};
435+
this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])",
436+
this->type(), R"(["bazz", "this bazz that bazz", null])", &options);
437+
}
438+
439+
TYPED_TEST(TestStringKernels, ReplaceSubstringLimited) {
440+
ReplaceSubstringOptions options{"foo", "bazz", 1};
441+
this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])",
442+
this->type(), R"(["bazz", "this bazz that foo", null])", &options);
443+
}
444+
445+
TYPED_TEST(TestStringKernels, ReplaceSubstringNoOptions) {
446+
Datum input = ArrayFromJSON(this->type(), "[]");
447+
ASSERT_RAISES(Invalid, CallFunction("replace_substring", {input}));
448+
}
449+
450+
#ifdef ARROW_WITH_RE2
451+
TYPED_TEST(TestStringKernels, ReplaceSubstringRegex) {
452+
ReplaceSubstringOptions options_regex{"(fo+)\\s*", "\\1-bazz"};
453+
this->CheckUnary("replace_substring_regex", R"(["foo ", "this foo that foo", null])",
454+
this->type(), R"(["foo-bazz", "this foo-bazzthat foo-bazz", null])",
455+
&options_regex);
456+
// make sure we match non-overlapping
457+
ReplaceSubstringOptions options_regex2{"(a.a)", "aba\\1"};
458+
this->CheckUnary("replace_substring_regex", R"(["aaaaaa"])", this->type(),
459+
R"(["abaaaaabaaaa"])", &options_regex2);
460+
}
461+
462+
TYPED_TEST(TestStringKernels, ReplaceSubstringRegexLimited) {
463+
// With a finite number of replacements
464+
ReplaceSubstringOptions options1{"foo", "bazz", 1};
465+
this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])",
466+
this->type(), R"(["bazz", "this bazz that foo", null])", &options1);
467+
ReplaceSubstringOptions options_regex1{"(fo+)\\s*", "\\1-bazz", 1};
468+
this->CheckUnary("replace_substring_regex", R"(["foo ", "this foo that foo", null])",
469+
this->type(), R"(["foo-bazz", "this foo-bazzthat foo", null])",
470+
&options_regex1);
471+
}
472+
473+
TYPED_TEST(TestStringKernels, ReplaceSubstringRegexNoOptions) {
474+
Datum input = ArrayFromJSON(this->type(), "[]");
475+
ASSERT_RAISES(Invalid, CallFunction("replace_substring_regex", {input}));
476+
}
477+
#endif
478+
425479
TYPED_TEST(TestStringKernels, Strptime) {
426480
std::string input1 = R"(["5/1/2020", null, "12/11/1900"])";
427481
std::string output1 = R"(["2020-05-01", null, "1900-12-11"])";

docs/source/cpp/compute.rst

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -426,21 +426,25 @@ The third set of functions examines string elements on a byte-per-byte basis:
426426
String transforms
427427
~~~~~~~~~~~~~~~~~
428428

429-
+--------------------------+------------+-------------------------+---------------------+---------+
430-
| Function name | Arity | Input types | Output type | Notes |
431-
+==========================+============+=========================+=====================+=========+
432-
| ascii_lower | Unary | String-like | String-like | \(1) |
433-
+--------------------------+------------+-------------------------+---------------------+---------+
434-
| ascii_upper | Unary | String-like | String-like | \(1) |
435-
+--------------------------+------------+-------------------------+---------------------+---------+
436-
| binary_length | Unary | Binary- or String-like | Int32 or Int64 | \(2) |
437-
+--------------------------+------------+-------------------------+---------------------+---------+
438-
| utf8_length | Unary | String-like | Int32 or Int64 | \(3) |
439-
+--------------------------+------------+-------------------------+---------------------+---------+
440-
| utf8_lower | Unary | String-like | String-like | \(4) |
441-
+--------------------------+------------+-------------------------+---------------------+---------+
442-
| utf8_upper | Unary | String-like | String-like | \(4) |
443-
+--------------------------+------------+-------------------------+---------------------+---------+
429+
+--------------------------+------------+-------------------------+---------------------+-------------------------------------------------+
430+
| Function name | Arity | Input types | Output type | Notes | Options class |
431+
+==========================+============+=========================+=====================+=========+=======================================+
432+
| ascii_lower | Unary | String-like | String-like | \(1) | |
433+
+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
434+
| ascii_upper | Unary | String-like | String-like | \(1) | |
435+
+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
436+
| binary_length | Unary | Binary- or String-like | Int32 or Int64 | \(2) | |
437+
+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
438+
| replace_substring | Unary | String-like | String-like | \(3) | :struct:`ReplaceSubstringOptions` |
439+
+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
440+
| replace_substring_regex | Unary | String-like | String-like | \(4) | :struct:`ReplaceSubstringOptions` |
441+
+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
442+
| utf8_length | Unary | String-like | Int32 or Int64 | \(5) | |
443+
+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
444+
| utf8_lower | Unary | String-like | String-like | \(6) | |
445+
+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
446+
| utf8_upper | Unary | String-like | String-like | \(6) | |
447+
+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
444448

445449

446450
* \(1) Each ASCII character in the input is converted to lowercase or
@@ -449,10 +453,23 @@ String transforms
449453
* \(2) Output is the physical length in bytes of each input element. Output
450454
type is Int32 for Binary / String, Int64 for LargeBinary / LargeString.
451455

452-
* \(3) Output is the number of characters (not bytes) of each input element.
456+
* \(3) Replace non-overlapping substrings that match to
457+
:member:`ReplaceSubstringOptions::pattern` by
458+
:member:`ReplaceSubstringOptions::replacement`. If
459+
:member:`ReplaceSubstringOptions::max_replacements` != -1, it determines the
460+
maximum number of replacements made, counting from the left.
461+
462+
* \(4) Replace non-overlapping substrings that match to the regular expression
463+
:member:`ReplaceSubstringOptions::pattern` by
464+
:member:`ReplaceSubstringOptions::replacement`, using the Google RE2 library. If
465+
:member:`ReplaceSubstringOptions::max_replacements` != -1, it determines the
466+
maximum number of replacements made, counting from the left. Note that if the
467+
pattern contains groups, backreferencing can be used.
468+
469+
* \(5) Output is the number of characters (not bytes) of each input element.
453470
Output type is Int32 for String, Int64 for LargeString.
454471

455-
* \(4) Each UTF8-encoded character in the input is converted to lowercase or
472+
* \(6) Each UTF8-encoded character in the input is converted to lowercase or
456473
uppercase.
457474

458475

0 commit comments

Comments
 (0)