|
23 | 23 | #include <utf8proc.h> |
24 | 24 | #endif |
25 | 25 |
|
| 26 | +#ifdef ARROW_WITH_RE2 |
| 27 | +#include <re2/re2.h> |
| 28 | +#endif |
| 29 | + |
26 | 30 | #include "arrow/array/builder_binary.h" |
27 | 31 | #include "arrow/array/builder_nested.h" |
28 | 32 | #include "arrow/buffer_builder.h" |
@@ -1230,6 +1234,197 @@ void AddSplit(FunctionRegistry* registry) { |
1230 | 1234 | #endif |
1231 | 1235 | } |
1232 | 1236 |
|
| 1237 | +// ---------------------------------------------------------------------- |
| 1238 | +// Replace substring (plain, regex) |
| 1239 | + |
| 1240 | +template <typename Type, typename Replacer> |
| 1241 | +struct ReplaceSubString { |
| 1242 | + using ScalarType = typename TypeTraits<Type>::ScalarType; |
| 1243 | + using offset_type = typename Type::offset_type; |
| 1244 | + using ValueDataBuilder = TypedBufferBuilder<uint8_t>; |
| 1245 | + using OffsetBuilder = TypedBufferBuilder<offset_type>; |
| 1246 | + using State = OptionsWrapper<ReplaceSubstringOptions>; |
| 1247 | + |
| 1248 | + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { |
| 1249 | + // TODO Cache replacer accross invocations (for regex compilation) |
| 1250 | + Replacer replacer{ctx, State::Get(ctx)}; |
| 1251 | + if (!ctx->HasError()) { |
| 1252 | + Replace(ctx, batch, &replacer, out); |
| 1253 | + } |
| 1254 | + } |
| 1255 | + |
| 1256 | + static void Replace(KernelContext* ctx, const ExecBatch& batch, Replacer* replacer, |
| 1257 | + Datum* out) { |
| 1258 | + ValueDataBuilder value_data_builder(ctx->memory_pool()); |
| 1259 | + OffsetBuilder offset_builder(ctx->memory_pool()); |
| 1260 | + |
| 1261 | + if (batch[0].kind() == Datum::ARRAY) { |
| 1262 | + // We already know how many strings we have, so we can use Reserve/UnsafeAppend |
| 1263 | + KERNEL_RETURN_IF_ERROR(ctx, offset_builder.Reserve(batch[0].array()->length)); |
| 1264 | + offset_builder.UnsafeAppend(0); // offsets start at 0 |
| 1265 | + |
| 1266 | + const ArrayData& input = *batch[0].array(); |
| 1267 | + KERNEL_RETURN_IF_ERROR( |
| 1268 | + ctx, VisitArrayDataInline<Type>( |
| 1269 | + input, |
| 1270 | + [&](util::string_view s) { |
| 1271 | + RETURN_NOT_OK(replacer->ReplaceString(s, &value_data_builder)); |
| 1272 | + offset_builder.UnsafeAppend( |
| 1273 | + static_cast<offset_type>(value_data_builder.length())); |
| 1274 | + return Status::OK(); |
| 1275 | + }, |
| 1276 | + [&]() { |
| 1277 | + // offset for null value |
| 1278 | + offset_builder.UnsafeAppend( |
| 1279 | + static_cast<offset_type>(value_data_builder.length())); |
| 1280 | + return Status::OK(); |
| 1281 | + })); |
| 1282 | + ArrayData* output = out->mutable_array(); |
| 1283 | + KERNEL_RETURN_IF_ERROR(ctx, value_data_builder.Finish(&output->buffers[2])); |
| 1284 | + KERNEL_RETURN_IF_ERROR(ctx, offset_builder.Finish(&output->buffers[1])); |
| 1285 | + } else { |
| 1286 | + const auto& input = checked_cast<const ScalarType&>(*batch[0].scalar()); |
| 1287 | + auto result = std::make_shared<ScalarType>(); |
| 1288 | + if (input.is_valid) { |
| 1289 | + util::string_view s = static_cast<util::string_view>(*input.value); |
| 1290 | + KERNEL_RETURN_IF_ERROR(ctx, replacer->ReplaceString(s, &value_data_builder)); |
| 1291 | + KERNEL_RETURN_IF_ERROR(ctx, value_data_builder.Finish(&result->value)); |
| 1292 | + result->is_valid = true; |
| 1293 | + } |
| 1294 | + out->value = result; |
| 1295 | + } |
| 1296 | + } |
| 1297 | +}; |
| 1298 | + |
| 1299 | +struct PlainSubStringReplacer { |
| 1300 | + const ReplaceSubstringOptions& options_; |
| 1301 | + |
| 1302 | + PlainSubStringReplacer(KernelContext* ctx, const ReplaceSubstringOptions& options) |
| 1303 | + : options_(options) {} |
| 1304 | + |
| 1305 | + Status ReplaceString(util::string_view s, TypedBufferBuilder<uint8_t>* builder) { |
| 1306 | + const char* i = s.begin(); |
| 1307 | + const char* end = s.end(); |
| 1308 | + int64_t max_replacements = options_.max_replacements; |
| 1309 | + while ((i < end) && (max_replacements != 0)) { |
| 1310 | + const char* pos = |
| 1311 | + std::search(i, end, options_.pattern.begin(), options_.pattern.end()); |
| 1312 | + if (pos == end) { |
| 1313 | + RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i), |
| 1314 | + static_cast<int64_t>(end - i))); |
| 1315 | + i = end; |
| 1316 | + } else { |
| 1317 | + // the string before the pattern |
| 1318 | + RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i), |
| 1319 | + static_cast<int64_t>(pos - i))); |
| 1320 | + // the replacement |
| 1321 | + RETURN_NOT_OK( |
| 1322 | + builder->Append(reinterpret_cast<const uint8_t*>(options_.replacement.data()), |
| 1323 | + options_.replacement.length())); |
| 1324 | + // skip pattern |
| 1325 | + i = pos + options_.pattern.length(); |
| 1326 | + max_replacements--; |
| 1327 | + } |
| 1328 | + } |
| 1329 | + // if we exited early due to max_replacements, add the trailing part |
| 1330 | + RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i), |
| 1331 | + static_cast<int64_t>(end - i))); |
| 1332 | + return Status::OK(); |
| 1333 | + } |
| 1334 | +}; |
| 1335 | + |
| 1336 | +#ifdef ARROW_WITH_RE2 |
| 1337 | +struct RegexSubStringReplacer { |
| 1338 | + const ReplaceSubstringOptions& options_; |
| 1339 | + const RE2 regex_find_; |
| 1340 | + const RE2 regex_replacement_; |
| 1341 | + |
| 1342 | + // Using RE2::FindAndConsume we can only find the pattern if it is a group, therefore |
| 1343 | + // we have 2 regexes, one with () around it, one without. |
| 1344 | + RegexSubStringReplacer(KernelContext* ctx, const ReplaceSubstringOptions& options) |
| 1345 | + : options_(options), |
| 1346 | + regex_find_("(" + options_.pattern + ")"), |
| 1347 | + regex_replacement_(options_.pattern) { |
| 1348 | + if (!(regex_find_.ok() && regex_replacement_.ok())) { |
| 1349 | + ctx->SetStatus(Status::Invalid("Regular expression error")); |
| 1350 | + return; |
| 1351 | + } |
| 1352 | + } |
| 1353 | + |
| 1354 | + Status ReplaceString(util::string_view s, TypedBufferBuilder<uint8_t>* builder) { |
| 1355 | + re2::StringPiece replacement(options_.replacement); |
| 1356 | + if (options_.max_replacements == -1) { |
| 1357 | + std::string s_copy(s.to_string()); |
| 1358 | + re2::RE2::GlobalReplace(&s_copy, regex_replacement_, replacement); |
| 1359 | + RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(s_copy.data()), |
| 1360 | + s_copy.length())); |
| 1361 | + return Status::OK(); |
| 1362 | + } |
| 1363 | + |
| 1364 | + // Since RE2 does not have the concept of max_replacements, we have to do some work |
| 1365 | + // ourselves. |
| 1366 | + // We might do this faster similar to RE2::GlobalReplace using Match and Rewrite |
| 1367 | + const char* i = s.begin(); |
| 1368 | + const char* end = s.end(); |
| 1369 | + re2::StringPiece piece(s.data(), s.length()); |
| 1370 | + |
| 1371 | + int64_t max_replacements = options_.max_replacements; |
| 1372 | + while ((i < end) && (max_replacements != 0)) { |
| 1373 | + std::string found; |
| 1374 | + if (!re2::RE2::FindAndConsume(&piece, regex_find_, &found)) { |
| 1375 | + RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i), |
| 1376 | + static_cast<int64_t>(end - i))); |
| 1377 | + i = end; |
| 1378 | + } else { |
| 1379 | + // wind back to the beginning of the match |
| 1380 | + const char* pos = piece.begin() - found.length(); |
| 1381 | + // the string before the pattern |
| 1382 | + RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i), |
| 1383 | + static_cast<int64_t>(pos - i))); |
| 1384 | + // replace the pattern in what we found |
| 1385 | + if (!re2::RE2::Replace(&found, regex_replacement_, replacement)) { |
| 1386 | + return Status::Invalid("Regex found, but replacement failed"); |
| 1387 | + } |
| 1388 | + RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(found.data()), |
| 1389 | + static_cast<int64_t>(found.length()))); |
| 1390 | + // skip pattern |
| 1391 | + i = piece.begin(); |
| 1392 | + max_replacements--; |
| 1393 | + } |
| 1394 | + } |
| 1395 | + // If we exited early due to max_replacements, add the trailing part |
| 1396 | + RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i), |
| 1397 | + static_cast<int64_t>(end - i))); |
| 1398 | + return Status::OK(); |
| 1399 | + } |
| 1400 | +}; |
| 1401 | +#endif |
| 1402 | + |
| 1403 | +template <typename Type> |
| 1404 | +using ReplaceSubStringPlain = ReplaceSubString<Type, PlainSubStringReplacer>; |
| 1405 | + |
| 1406 | +const FunctionDoc replace_substring_doc( |
| 1407 | + "Replace non-overlapping substrings that match pattern by replacement", |
| 1408 | + ("For each string in `strings`, replace non-overlapping substrings that match\n" |
| 1409 | + "`pattern` by `replacement`. If `max_replacements != -1`, it determines the\n" |
| 1410 | + "maximum amount of replacements made, counting from the left. Null values emit\n" |
| 1411 | + "null."), |
| 1412 | + {"strings"}, "ReplaceSubstringOptions"); |
| 1413 | + |
| 1414 | +#ifdef ARROW_WITH_RE2 |
| 1415 | +template <typename Type> |
| 1416 | +using ReplaceSubStringRegex = ReplaceSubString<Type, RegexSubStringReplacer>; |
| 1417 | + |
| 1418 | +const FunctionDoc replace_substring_regex_doc( |
| 1419 | + "Replace non-overlapping substrings that match regex `pattern` by `replacement`", |
| 1420 | + ("For each string in `strings`, replace non-overlapping substrings that match the\n" |
| 1421 | + "regular expression `pattern` by `replacement` using the Google RE2 library.\n" |
| 1422 | + "If `max_replacements != -1`, it determines the maximum amount of replacements\n" |
| 1423 | + "made, counting from the left. Note that if the pattern contains groups,\n" |
| 1424 | + "backreferencing macan be used. Null values emit null."), |
| 1425 | + {"strings"}, "ReplaceSubstringOptions"); |
| 1426 | +#endif |
| 1427 | + |
1233 | 1428 | // ---------------------------------------------------------------------- |
1234 | 1429 | // strptime string parsing |
1235 | 1430 |
|
@@ -1904,6 +2099,14 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { |
1904 | 2099 | AddBinaryLength(registry); |
1905 | 2100 | AddUtf8Length(registry); |
1906 | 2101 | AddMatchSubstring(registry); |
| 2102 | + MakeUnaryStringBatchKernelWithState<ReplaceSubStringPlain>( |
| 2103 | + "replace_substring", registry, &replace_substring_doc, |
| 2104 | + MemAllocation::NO_PREALLOCATE); |
| 2105 | +#ifdef ARROW_WITH_RE2 |
| 2106 | + MakeUnaryStringBatchKernelWithState<ReplaceSubStringRegex>( |
| 2107 | + "replace_substring_regex", registry, &replace_substring_regex_doc, |
| 2108 | + MemAllocation::NO_PREALLOCATE); |
| 2109 | +#endif |
1907 | 2110 | AddStrptime(registry); |
1908 | 2111 | } |
1909 | 2112 |
|
|
0 commit comments