Skip to content

Commit

Permalink
GH-44615: [C++][Compute] Add extract_regex_span function (#45577)
Browse files Browse the repository at this point in the history
### Rationale for this change

While the `extract_regex` function returns substrings of the matching regex captures, `extract_regex_span` returns (index, length) pairs of these substrings relative to the original string values.

### Are these changes tested?

Yes, by dedicated unit tests.

### Are there any user-facing changes?

No, except a new compute function.

* GitHub Issue: #44615

Lead-authored-by: arash andishgar <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
arashandishgar and pitrou authored Mar 11, 2025
1 parent cf77281 commit 0494115
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 34 deletions.
10 changes: 10 additions & 0 deletions cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,9 @@ static auto kElementWiseAggregateOptionsType =
DataMember("skip_nulls", &ElementWiseAggregateOptions::skip_nulls));
static auto kExtractRegexOptionsType = GetFunctionOptionsType<ExtractRegexOptions>(
DataMember("pattern", &ExtractRegexOptions::pattern));
static auto kExtractRegexSpanOptionsType =
GetFunctionOptionsType<ExtractRegexSpanOptions>(
DataMember("pattern", &ExtractRegexSpanOptions::pattern));
static auto kJoinOptionsType = GetFunctionOptionsType<JoinOptions>(
DataMember("null_handling", &JoinOptions::null_handling),
DataMember("null_replacement", &JoinOptions::null_replacement));
Expand Down Expand Up @@ -438,6 +441,12 @@ ExtractRegexOptions::ExtractRegexOptions(std::string pattern)
ExtractRegexOptions::ExtractRegexOptions() : ExtractRegexOptions("") {}
constexpr char ExtractRegexOptions::kTypeName[];

ExtractRegexSpanOptions::ExtractRegexSpanOptions(std::string pattern)
: FunctionOptions(internal::kExtractRegexSpanOptionsType),
pattern(std::move(pattern)) {}
ExtractRegexSpanOptions::ExtractRegexSpanOptions() : ExtractRegexSpanOptions("") {}
constexpr char ExtractRegexSpanOptions::kTypeName[];

JoinOptions::JoinOptions(NullHandlingBehavior null_handling, std::string null_replacement)
: FunctionOptions(internal::kJoinOptionsType),
null_handling(null_handling),
Expand Down Expand Up @@ -684,6 +693,7 @@ void RegisterScalarOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kDayOfWeekOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kElementWiseAggregateOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexSpanOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kJoinOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListSliceOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kMakeStructOptionsType));
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,16 @@ class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions {
std::string pattern;
};

class ARROW_EXPORT ExtractRegexSpanOptions : public FunctionOptions {
public:
explicit ExtractRegexSpanOptions(std::string pattern);
ExtractRegexSpanOptions();
static constexpr char const kTypeName[] = "ExtractRegexSpanOptions";

/// Regular expression with named capture fields
std::string pattern;
};

/// Options for IsIn and IndexIn functions
class ARROW_EXPORT SetLookupOptions : public FunctionOptions {
public:
Expand Down
205 changes: 176 additions & 29 deletions cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <string>

#include "arrow/array/builder_nested.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/compute/kernels/scalar_string_internal.h"
#include "arrow/result.h"
#include "arrow/util/config.h"
Expand Down Expand Up @@ -2184,29 +2185,40 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) {

using ExtractRegexState = OptionsWrapper<ExtractRegexOptions>;

// TODO cache this once per ExtractRegexOptions
struct ExtractRegexData {
// Use unique_ptr<> because RE2 is non-movable (for ARROW_ASSIGN_OR_RAISE)
std::unique_ptr<RE2> regex;
std::vector<std::string> group_names;

static Result<ExtractRegexData> Make(const ExtractRegexOptions& options,
bool is_utf8 = true) {
ExtractRegexData data(options.pattern, is_utf8);
RETURN_NOT_OK(RegexStatus(*data.regex));

const int group_count = data.regex->NumberOfCapturingGroups();
const auto& name_map = data.regex->CapturingGroupNames();
data.group_names.reserve(group_count);
struct BaseExtractRegexData {
Status Init() {
RETURN_NOT_OK(RegexStatus(*regex));
const int group_count = regex->NumberOfCapturingGroups();
const auto& name_map = regex->CapturingGroupNames();
group_names.reserve(group_count);

for (int i = 0; i < group_count; i++) {
auto item = name_map.find(i + 1); // re2 starts counting from 1
if (item == name_map.end()) {
// XXX should we instead just create fields with an empty name?
return Status::Invalid("Regular expression contains unnamed groups");
}
data.group_names.emplace_back(item->second);
group_names.emplace_back(item->second);
}
return Status::OK();
}

int64_t num_groups() const { return static_cast<int64_t>(group_names.size()); }

std::unique_ptr<RE2> regex;
std::vector<std::string> group_names;

protected:
explicit BaseExtractRegexData(const std::string& pattern, bool is_utf8 = true)
: regex(new RE2(pattern, MakeRE2Options(is_utf8))) {}
};

// TODO cache this once per ExtractRegexOptions
struct ExtractRegexData : public BaseExtractRegexData {
static Result<ExtractRegexData> Make(const ExtractRegexOptions& options,
bool is_utf8 = true) {
ExtractRegexData data(options.pattern, is_utf8);
ARROW_RETURN_NOT_OK(data.Init());
return data;
}

Expand All @@ -2220,7 +2232,7 @@ struct ExtractRegexData {
// of each field in the output struct type.
DCHECK(is_base_binary_like(input_type->id()));
FieldVector fields;
fields.reserve(group_names.size());
fields.reserve(num_groups());
std::shared_ptr<DataType> owned_type = input_type->GetSharedPtr();
std::transform(group_names.begin(), group_names.end(), std::back_inserter(fields),
[&](const std::string& name) { return field(name, owned_type); });
Expand All @@ -2229,7 +2241,7 @@ struct ExtractRegexData {

private:
explicit ExtractRegexData(const std::string& pattern, bool is_utf8 = true)
: regex(new RE2(pattern, MakeRE2Options(is_utf8))) {}
: BaseExtractRegexData(pattern, is_utf8) {}
};

Result<TypeHolder> ResolveExtractRegexOutput(KernelContext* ctx,
Expand All @@ -2240,17 +2252,17 @@ Result<TypeHolder> ResolveExtractRegexOutput(KernelContext* ctx,
}

struct ExtractRegexBase {
const ExtractRegexData& data;
const BaseExtractRegexData& data;
const int group_count;
std::vector<re2::StringPiece> found_values;
std::vector<RE2::Arg> args;
std::vector<const RE2::Arg*> args_pointers;
const RE2::Arg** args_pointers_start;
const RE2::Arg* null_arg = nullptr;

explicit ExtractRegexBase(const ExtractRegexData& data)
explicit ExtractRegexBase(const BaseExtractRegexData& data)
: data(data),
group_count(static_cast<int>(data.group_names.size())),
group_count(static_cast<int>(data.num_groups())),
found_values(group_count) {
args.reserve(group_count);
args_pointers.reserve(group_count);
Expand Down Expand Up @@ -2280,25 +2292,23 @@ struct ExtractRegex : public ExtractRegexBase {
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
ExtractRegexOptions options = ExtractRegexState::Get(ctx);
ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options, Type::is_utf8));
return ExtractRegex{data}.Extract(ctx, batch, out);
return ExtractRegex(data).Extract(ctx, batch, out);
}

Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
// TODO: why is this needed? Type resolution should already be
// done and the output type set in the output variable
ARROW_ASSIGN_OR_RAISE(TypeHolder out_type, data.ResolveOutputType(batch.GetTypes()));
DCHECK_NE(out_type.type, nullptr);
std::shared_ptr<DataType> type = out_type.GetSharedPtr();

std::unique_ptr<ArrayBuilder> array_builder;
RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), type, &array_builder));
DCHECK_NE(out->array_data(), nullptr);
std::shared_ptr<DataType> type = out->array_data()->type;
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<ArrayBuilder> array_builder,
MakeBuilder(type, ctx->memory_pool()));
StructBuilder* struct_builder = checked_cast<StructBuilder*>(array_builder.get());
ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].length()));

std::vector<BuilderType*> field_builders;
field_builders.reserve(group_count);
for (int i = 0; i < group_count; i++) {
field_builders.push_back(
checked_cast<BuilderType*>(struct_builder->field_builder(i)));
RETURN_NOT_OK(field_builders.back()->Reserve(batch[0].length()));
}

auto visit_null = [&]() { return struct_builder->AppendNull(); };
Expand Down Expand Up @@ -2347,6 +2357,142 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) {
}
DCHECK_OK(registry->AddFunction(std::move(func)));
}

struct ExtractRegexSpanData : public BaseExtractRegexData {
static Result<ExtractRegexSpanData> Make(const std::string& pattern,
bool is_utf8 = true) {
auto data = ExtractRegexSpanData(pattern, is_utf8);
ARROW_RETURN_NOT_OK(data.Init());
return data;
}

Result<TypeHolder> ResolveOutputType(const std::vector<TypeHolder>& types) const {
const DataType* input_type = types[0].type;
if (input_type == nullptr) {
return nullptr;
}
DCHECK(is_base_binary_like(input_type->id()));
FieldVector fields;
fields.reserve(num_groups());
auto index_type = is_binary_like(input_type->id()) ? int32() : int64();
for (const auto& group_name : group_names) {
// list size is 2 as every span contains position and length
fields.push_back(field(group_name, fixed_size_list(index_type, 2)));
}
return struct_(std::move(fields));
}

private:
ExtractRegexSpanData(const std::string& pattern, const bool is_utf8)
: BaseExtractRegexData(pattern, is_utf8) {}
};

template <typename Type>
struct ExtractRegexSpan : ExtractRegexBase {
using ArrayType = typename TypeTraits<Type>::ArrayType;
using BuilderType = typename TypeTraits<Type>::BuilderType;
using offset_type = typename Type::offset_type;
using OffsetBuilderType =
typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::BuilderType;
using OffsetCType =
typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::CType;

using ExtractRegexBase::ExtractRegexBase;

static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
auto options = OptionsWrapper<ExtractRegexSpanOptions>::Get(ctx);
ARROW_ASSIGN_OR_RAISE(auto data,
ExtractRegexSpanData::Make(options.pattern, Type::is_utf8));
return ExtractRegexSpan{data}.Extract(ctx, batch, out);
}

Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
DCHECK_NE(out->array_data(), nullptr);
std::shared_ptr<DataType> out_type = out->array_data()->type;
ARROW_ASSIGN_OR_RAISE(auto out_builder, MakeBuilder(out_type, ctx->memory_pool()));
StructBuilder* struct_builder = checked_cast<StructBuilder*>(out_builder.get());
ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].array.length));

std::vector<FixedSizeListBuilder*> span_builders;
std::vector<OffsetBuilderType*> array_builders;
span_builders.reserve(group_count);
array_builders.reserve(group_count);
for (int i = 0; i < group_count; i++) {
span_builders.push_back(
checked_cast<FixedSizeListBuilder*>(struct_builder->field_builder(i)));
array_builders.push_back(
checked_cast<OffsetBuilderType*>(span_builders.back()->value_builder()));
RETURN_NOT_OK(span_builders.back()->Reserve(batch[0].length()));
RETURN_NOT_OK(array_builders.back()->Reserve(2 * batch[0].length()));
}

auto visit_null = [&]() { return struct_builder->AppendNull(); };
auto visit_value = [&](std::string_view element) -> Status {
if (Match(element)) {
for (int i = 0; i < group_count; i++) {
// https://github.com/google/re2/issues/24#issuecomment-97653183
if (found_values[i].data() != nullptr) {
int64_t begin = found_values[i].data() - element.data();
int64_t size = found_values[i].size();
array_builders[i]->UnsafeAppend(static_cast<OffsetCType>(begin));
array_builders[i]->UnsafeAppend(static_cast<OffsetCType>(size));
ARROW_RETURN_NOT_OK(span_builders[i]->Append());
} else {
ARROW_RETURN_NOT_OK(span_builders[i]->AppendNull());
}
}
ARROW_RETURN_NOT_OK(struct_builder->Append());
} else {
ARROW_RETURN_NOT_OK(struct_builder->AppendNull());
}
return Status::OK();
};
ARROW_RETURN_NOT_OK(
VisitArraySpanInline<Type>(batch[0].array, visit_value, visit_null));

ARROW_ASSIGN_OR_RAISE(auto out_array, struct_builder->Finish());
out->value = std::move(out_array->data());
return Status::OK();
}
};

const FunctionDoc extract_regex_span_doc(
"Extract string spans captured by a regex pattern",
("For each string in strings, match the regular expression and, if\n"
"successful, emit a struct with field names and values coming from the\n"
"regular expression's named capture groups. Each struct field value\n"
"will be a fixed_size_list(offset_type, 2) where offset_type is int32\n"
"or int64, depending on the input string type. The two elements in\n"
"each fixed-size list are the index and the length of the substring\n"
"matched by the corresponding named capture group.\n"
"\n"
"If the input is null or the regular expression fails matching,\n"
"a null output value is emitted.\n"
"\n"
"Regular expression matching is done using the Google RE2 library."),
{"strings"}, "ExtractRegexSpanOptions", /*options_required=*/true);

Result<TypeHolder> ResolveExtractRegexSpanOutputType(
KernelContext* ctx, const std::vector<TypeHolder>& types) {
auto options = OptionsWrapper<ExtractRegexSpanOptions>::Get(*ctx->state());
ARROW_ASSIGN_OR_RAISE(auto span, ExtractRegexSpanData::Make(options.pattern));
return span.ResolveOutputType(types);
}

void AddAsciiStringExtractRegexSpan(FunctionRegistry* registry) {
auto func = std::make_shared<ScalarFunction>("extract_regex_span", Arity::Unary(),
extract_regex_span_doc);
OutputType output_type(ResolveExtractRegexSpanOutputType);
for (const auto& type : BaseBinaryTypes()) {
ScalarKernel kernel({type}, output_type,
GenerateVarBinaryToVarBinary<ExtractRegexSpan>(type),
OptionsWrapper<ExtractRegexSpanOptions>::Init);
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
DCHECK_OK(func->AddKernel(std::move(kernel)));
}
DCHECK_OK(registry->AddFunction(func));
}
#endif // ARROW_WITH_RE2

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -3457,6 +3603,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
AddAsciiStringSplitWhitespace(registry);
#ifdef ARROW_WITH_RE2
AddAsciiStringSplitRegex(registry);
AddAsciiStringExtractRegexSpan(registry);
#endif
AddAsciiStringJoin(registry);
AddAsciiStringRepeat(registry);
Expand Down
Loading

0 comments on commit 0494115

Please sign in to comment.