Skip to content

Commit 5f0641b

Browse files
rodrigojdebem authored and praveenbingo committed
ARROW-12410: [C++][Gandiva] Implement regexp_replace function on Gandiva
Closes #10059 from rodrigojdebem/feature/implement-regexp-replace and squashes the following commits:

baf2778 <rodrigojdebem> Add implementation for REGEXP_REPLACE

Authored-by: rodrigojdebem <rodrigodebem1@gmail.com>
Signed-off-by: Praveen <praveen@dremio.com>
1 parent 87e0252 commit 5f0641b

9 files changed

Lines changed: 394 additions & 5 deletions

File tree

cpp/src/gandiva/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,7 @@ set(SRC_FILES
8686
literal_holder.cc
8787
projector.cc
8888
regex_util.cc
89+
replace_holder.cc
8990
selection_vector.cc
9091
tree_expr_builder.cc
9192
to_date_holder.cc
@@ -230,6 +231,7 @@ add_gandiva_test(internals-test
230231
to_date_holder_test.cc
231232
simple_arena_test.cc
232233
like_holder_test.cc
234+
replace_holder_test.cc
233235
decimal_type_util_test.cc
234236
random_generator_holder_test.cc
235237
hash_utils_test.cc

cpp/src/gandiva/function_holder_registry.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
#include "gandiva/like_holder.h"
2929
#include "gandiva/node.h"
3030
#include "gandiva/random_generator_holder.h"
31+
#include "gandiva/replace_holder.h"
3132
#include "gandiva/to_date_holder.h"
3233

3334
namespace gandiva {
@@ -66,6 +67,7 @@ class FunctionHolderRegistry {
6667
{"to_date", LAMBDA_MAKER(ToDateHolder)},
6768
{"random", LAMBDA_MAKER(RandomGeneratorHolder)},
6869
{"rand", LAMBDA_MAKER(RandomGeneratorHolder)},
70+
{"regexp_replace", LAMBDA_MAKER(ReplaceHolder)},
6971
};
7072
return maker_map;
7173
}

cpp/src/gandiva/function_registry_string.cc

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -194,6 +194,12 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
194194
NativeFunction("rpad", {}, DataTypeVector{utf8(), int32()}, utf8(),
195195
kResultNullIfNull, "rpad_utf8_int32", NativeFunction::kNeedsContext),
196196

197+
NativeFunction("regexp_replace", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(),
198+
kResultNullIfNull, "gdv_fn_regexp_replace_utf8_utf8",
199+
NativeFunction::kNeedsContext |
200+
NativeFunction::kNeedsFunctionHolder |
201+
NativeFunction::kCanReturnErrors),
202+
197203
NativeFunction("concatOperator", {}, DataTypeVector{utf8(), utf8()}, utf8(),
198204
kResultNullIfNull, "concatOperator_utf8_utf8",
199205
NativeFunction::kNeedsContext),

cpp/src/gandiva/gdv_function_stubs.cc

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
#include "gandiva/like_holder.h"
3636
#include "gandiva/precompiled/types.h"
3737
#include "gandiva/random_generator_holder.h"
38+
#include "gandiva/replace_holder.h"
3839
#include "gandiva/to_date_holder.h"
3940

4041
/// Stub functions that can be accessed from LLVM or the pre-compiled library.
@@ -60,6 +61,18 @@ bool gdv_fn_ilike_utf8_utf8(int64_t ptr, const char* data, int data_len,
6061
return (*holder)(std::string(data, data_len));
6162
}
6263

64+
// Stub for `regexp_replace`: forwards one call to the ReplaceHolder whose address
// is passed in `holder_ptr`.
//
// `ptr` is reinterpreted as a gandiva::ExecutionContext* and `holder_ptr` as a
// gandiva::ReplaceHolder*. `data`/`data_len` is the input string and
// `replace_string`/`replace_string_len` the replacement text. The pattern
// arguments are accepted but unused here (their names are commented out); the
// call relies entirely on the holder for the pattern.
//
// Returns whatever the holder's operator() returns and lets it fill `*out_length`.
const char* gdv_fn_regexp_replace_utf8_utf8(
65+
int64_t ptr, int64_t holder_ptr, const char* data, int32_t data_len,
66+
const char* /*pattern*/, int32_t /*pattern_len*/, const char* replace_string,
67+
int32_t replace_string_len, int32_t* out_length) {
68+
gandiva::ExecutionContext* context = reinterpret_cast<gandiva::ExecutionContext*>(ptr);
69+
70+
gandiva::ReplaceHolder* holder = reinterpret_cast<gandiva::ReplaceHolder*>(holder_ptr);
71+
72+
// NOTE(review): neither pointer is checked for null before use — presumably the
// generated code always passes valid addresses; confirm.
return (*holder)(context, data, data_len, replace_string, replace_string_len,
73+
out_length);
74+
}
75+
6376
double gdv_fn_random(int64_t ptr) {
6477
gandiva::RandomGeneratorHolder* holder =
6578
reinterpret_cast<gandiva::RandomGeneratorHolder*>(ptr);
@@ -898,6 +911,21 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {
898911
types->i1_type() /*return_type*/, args,
899912
reinterpret_cast<void*>(gdv_fn_ilike_utf8_utf8));
900913

914+
// gdv_fn_regexp_replace_utf8_utf8
915+
args = {types->i64_type(), // int64_t ptr
916+
types->i64_type(), // int64_t holder_ptr
917+
types->i8_ptr_type(), // const char* data
918+
types->i32_type(), // int data_len
919+
types->i8_ptr_type(), // const char* pattern
920+
types->i32_type(), // int pattern_len
921+
types->i8_ptr_type(), // const char* replace_string
922+
types->i32_type(), // int32_t replace_string_len
923+
types->i32_ptr_type()}; // int32_t* out_length
924+
925+
engine->AddGlobalMappingForFunc(
926+
"gdv_fn_regexp_replace_utf8_utf8", types->i8_ptr_type() /*return_type*/, args,
927+
reinterpret_cast<void*>(gdv_fn_regexp_replace_utf8_utf8));
928+
901929
// gdv_fn_to_date_utf8_utf8
902930
args = {types->i64_type(), // int64_t execution_context
903931
types->i64_type(), // int64_t holder_ptr

cpp/src/gandiva/precompiled/string_ops_test.cc

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -256,7 +256,7 @@ TEST(TestStringOps, TestCastBoolToVarchar) {
256256
EXPECT_EQ(std::string(out_str, out_len), "false");
257257
EXPECT_FALSE(ctx.has_error());
258258

259-
out_str = castVARCHAR_bool_int64(ctx_ptr, true, -3, &out_len);
259+
castVARCHAR_bool_int64(ctx_ptr, true, -3, &out_len);
260260
EXPECT_THAT(ctx.get_error(),
261261
::testing::HasSubstr("Output buffer length can't be negative"));
262262
ctx.Reset();
@@ -1441,13 +1441,13 @@ TEST(TestStringOps, TestReplace) {
14411441
EXPECT_EQ(std::string(out_str, out_len), "TestString");
14421442
EXPECT_FALSE(ctx.has_error());
14431443

1444-
out_str = replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "Hell", 4, "ell", 3, "ollow", 5,
1445-
5, &out_len);
1444+
replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "Hell", 4, "ell", 3, "ollow", 5, 5,
1445+
&out_len);
14461446
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer overflow for output string"));
14471447
ctx.Reset();
14481448

1449-
out_str = replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "eeee", 4, "e", 1, "aaaa", 4, 14,
1450-
&out_len);
1449+
replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "eeee", 4, "e", 1, "aaaa", 4, 14,
1450+
&out_len);
14511451
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer overflow for output string"));
14521452
ctx.Reset();
14531453
}

cpp/src/gandiva/replace_holder.cc

Lines changed: 65 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,65 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "gandiva/replace_holder.h"
19+
20+
#include "gandiva/node.h"
21+
#include "gandiva/regex_util.h"
22+
23+
namespace gandiva {
24+
25+
// Returns true when `type` is arrow::Type::STRING or arrow::Type::BINARY,
// false for every other type id.
static bool IsArrowStringLiteral(arrow::Type::type type) {
26+
return type == arrow::Type::STRING || type == arrow::Type::BINARY;
27+
}
28+
29+
// Builds a ReplaceHolder from a function node.
//
// The node must have exactly three children, and the child at index 1 (the
// pattern) must be a LiteralNode whose return type id is STRING or BINARY. The
// literal's std::string value is then handed to the string overload of Make,
// which compiles it and fills `*holder`.
//
// Returns Status::Invalid when any of those checks fails; otherwise returns the
// status of the string overload.
//
// NOTE(review): the messages below say 'replace', while this holder is registered
// under the name "regexp_replace" — confirm whether the wording is intentional.
Status ReplaceHolder::Make(const FunctionNode& node,
30+
std::shared_ptr<ReplaceHolder>* holder) {
31+
ARROW_RETURN_IF(node.children().size() != 3,
32+
Status::Invalid("'replace' function requires three parameters"));
33+
34+
// dynamic_cast yields nullptr when the second child is not a LiteralNode.
auto literal = dynamic_cast<LiteralNode*>(node.children().at(1).get());
35+
ARROW_RETURN_IF(
36+
literal == nullptr,
37+
Status::Invalid("'replace' function requires a literal as the second parameter"));
38+
39+
auto literal_type = literal->return_type()->id();
40+
ARROW_RETURN_IF(
41+
!IsArrowStringLiteral(literal_type),
42+
Status::Invalid(
43+
"'replace' function requires a string literal as the second parameter"));
44+
45+
// NOTE(review): assumes the literal holder actually contains a std::string (i.e.
// the literal is not null) — TODO confirm.
return Make(arrow::util::get<std::string>(literal->holder()), holder);
46+
}
47+
48+
// Builds a ReplaceHolder directly from a pattern string.
//
// Constructs the holder from `sql_pattern` and checks the resulting `regex_`.
// Returns Status::Invalid (quoting the pattern) when `regex_.ok()` is false;
// otherwise stores the new holder in `*holder` and returns Status::OK().
// `*holder` is left untouched on failure.
//
// NOTE(review): despite the parameter name, nothing here converts a SQL pattern —
// the string is passed to the ReplaceHolder constructor as-is.
Status ReplaceHolder::Make(const std::string& sql_pattern,
49+
std::shared_ptr<ReplaceHolder>* holder) {
50+
// `new` is used directly; the constructor is declared private in the class.
auto lholder = std::shared_ptr<ReplaceHolder>(new ReplaceHolder(sql_pattern));
51+
ARROW_RETURN_IF(!lholder->regex_.ok(),
52+
Status::Invalid("Building RE2 pattern '", sql_pattern, "' failed"));
53+
54+
*holder = lholder;
55+
return Status::OK();
56+
}
57+
58+
// Records a replacement failure on the execution context.
//
// Builds a message from `replace_string`, `data` and the stored `pattern_`, and
// passes it to `context->set_error_msg`. Returns nothing, so the caller still has
// to produce its own return value afterwards.
//
// NOTE(review): both string parameters are non-const references although they are
// only read here; changing that would also require updating the declaration in
// the class.
void ReplaceHolder::return_error(ExecutionContext* context, std::string& data,
59+
std::string& replace_string) {
60+
std::string err_msg = "Error replacing '" + replace_string + "' on the given string '" +
61+
data + "' for the given pattern: " + pattern_;
62+
context->set_error_msg(err_msg.c_str());
63+
}
64+
65+
} // namespace gandiva

cpp/src/gandiva/replace_holder.h

Lines changed: 97 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,97 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#pragma once
19+
20+
#include <re2/re2.h>
21+
22+
#include <memory>
23+
#include <string>
24+
25+
#include "arrow/status.h"
26+
#include "gandiva/execution_context.h"
27+
#include "gandiva/function_holder.h"
28+
#include "gandiva/node.h"
29+
#include "gandiva/visibility.h"
30+
31+
namespace gandiva {
32+
33+
/// Function holder for 'regexp_replace': keeps the pattern text and its compiled
/// RE2 form, and performs the replacement for each call via operator().
34+
class GANDIVA_EXPORT ReplaceHolder : public FunctionHolder {
35+
public:
36+
~ReplaceHolder() override = default;
37+
38+
/// Build a holder from a function node; on success the new holder is stored in
/// `*holder`.
static Status Make(const FunctionNode& node, std::shared_ptr<ReplaceHolder>* holder);
39+
40+
/// Build a holder from a pattern string; on success the new holder is stored in
/// `*holder`.
static Status Make(const std::string& sql_pattern,
41+
std::shared_ptr<ReplaceHolder>* holder);
42+
43+
/// Replace every match of the held regex in `user_input` with `replace_input`
/// and return the resulting bytes, writing their length to `*out_length`.
44+
/// When nothing matches, `user_input` itself is returned (no copy). Otherwise the
/// result is copied into memory from `ctx->arena()`. An empty result, or a
/// failure, returns "" with `*out_length` set to 0; failures also set an error
/// message on `ctx`.
45+
const char* operator()(ExecutionContext* ctx, const char* user_input,
46+
int32_t user_input_len, const char* replace_input,
47+
int32_t replace_input_len, int32_t* out_length) {
48+
// Copy both inputs into std::string: GlobalReplace rewrites its first argument
// in place.
std::string user_input_as_str(user_input, user_input_len);
49+
std::string replace_input_as_str(replace_input, replace_input_len);
50+
51+
int32_t total_replaces =
52+
RE2::GlobalReplace(&user_input_as_str, regex_, replace_input_as_str);
53+
54+
// NOTE(review): RE2::GlobalReplace is documented as returning the number of
// replacements made, so a negative value looks unreachable — confirm against the
// RE2 version in use. Also, `user_input_as_str` has already been passed to
// GlobalReplace at this point, so the error text may not show the original
// input.
if (total_replaces < 0) {
55+
return_error(ctx, user_input_as_str, replace_input_as_str);
56+
*out_length = 0;
57+
return "";
58+
}
59+
60+
// No match: hand back the caller's own buffer and length unchanged.
if (total_replaces == 0) {
61+
*out_length = user_input_len;
62+
return user_input;
63+
}
64+
65+
*out_length = static_cast<int32_t>(user_input_as_str.size());
66+
67+
// The whole input was replaced by an empty string: return an empty literal
68+
// without allocating.
69+
if (*out_length == 0) {
70+
return "";
71+
}
72+
73+
// The std::string above is a local, so the result is copied into arena memory
// before returning.
char* result_buffer = reinterpret_cast<char*>(ctx->arena()->Allocate(*out_length));
74+
75+
if (result_buffer == NULLPTR) {
76+
ctx->set_error_msg("Could not allocate memory for result");
77+
*out_length = 0;
78+
return "";
79+
}
80+
81+
// NOTE(review): memcpy is used here but <cstring> is not among this header's
// direct includes — presumably it arrives transitively; consider including it.
memcpy(result_buffer, user_input_as_str.data(), *out_length);
82+
83+
return result_buffer;
84+
}
85+
86+
private:
87+
// Private: holders are created through the static Make() functions. Stores the
// pattern text and constructs the RE2 object from it.
explicit ReplaceHolder(const std::string& pattern)
88+
: pattern_(pattern), regex_(pattern) {}
89+
90+
// Sets an error message on `context` describing the failed replacement.
void return_error(ExecutionContext* context, std::string& data,
91+
std::string& replace_string);
92+
93+
std::string pattern_; // original pattern text, used when building error messages
94+
RE2 regex_; // compiled regex for the pattern
95+
};
96+
97+
} // namespace gandiva

0 commit comments

Comments
 (0)