Skip to content

Commit bca09a2

Browse files
authored
Add regex interface with re2 and std::regex implementations
Differential Revision: D73071817 Pull Request resolved: #48
1 parent 6a6e24f commit bca09a2

File tree

6 files changed

+248
-0
lines changed

6 files changed

+248
-0
lines changed
+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <memory>
12+
#include <string>
13+
14+
#include <re2/re2.h>
15+
16+
#include <pytorch/tokenizers/regex.h>
17+
18+
namespace tokenizers {
19+
20+
/**
21+
* @brief RE2-based implementation of IRegex.
22+
*/
23+
class Re2Regex : public IRegex {
24+
public:
25+
/**
26+
* @brief Construct a RE2 regex with the given pattern.
27+
*
28+
* @param pattern The regex pattern to compile.
29+
*/
30+
explicit Re2Regex(const std::string& pattern);
31+
32+
/**
33+
* @brief Return all non-overlapping matches found in the input string.
34+
*/
35+
virtual std::vector<Match> find_all(const std::string& text) const override;
36+
37+
private:
38+
std::unique_ptr<re2::RE2> regex_;
39+
40+
friend Result<std::unique_ptr<IRegex>> create_regex(
41+
const std::string& pattern);
42+
};
43+
44+
} // namespace tokenizers

include/pytorch/tokenizers/regex.h

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <memory>
12+
#include <string>
13+
#include <vector>
14+
15+
#include <pytorch/tokenizers/result.h>
16+
17+
namespace tokenizers {
18+
19+
struct Match {
20+
size_t start; // starting index of the match
21+
size_t end; // ending index of the match (exclusive)
22+
};
23+
24+
/**
25+
* @brief Abstract interface for regex wrappers.
26+
*/
27+
class IRegex {
28+
public:
29+
virtual ~IRegex() = default;
30+
31+
/**
32+
* @brief Find all non-overlapping matches in the input string.
33+
*
34+
* @param text The input string to search.
35+
* @return A vector of strings containing all matched substrings.
36+
*/
37+
virtual std::vector<Match> find_all(const std::string& text) const = 0;
38+
};
39+
40+
/**
41+
* @brief Creates a regex instance. Tries RE2 first, falls back to std::regex.
42+
*
43+
* @param pattern The regex pattern to compile.
44+
* @return A unique pointer to an IRegex-compatible object.
45+
*/
46+
Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern);
47+
48+
} // namespace tokenizers
+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <memory>
12+
#include <regex>
13+
#include <string>
14+
#include "regex.h"
15+
16+
namespace tokenizers {
17+
18+
/**
19+
* @brief std::regex-based implementation of IRegex.
20+
*/
21+
class StdRegex : public IRegex {
22+
public:
23+
/**
24+
* @brief Construct a std::regex wrapper with the given pattern.
25+
*
26+
* @param pattern The regex pattern to compile.
27+
* @throws std::regex_error if the pattern is invalid.
28+
*/
29+
explicit StdRegex(const std::string& pattern);
30+
31+
/**
32+
* @brief Find all non-overlapping matches in the input string.
33+
*/
34+
virtual std::vector<Match> find_all(const std::string& text) const override;
35+
36+
private:
37+
std::regex regex_;
38+
};
39+
40+
} // namespace tokenizers

src/re2_regex.cpp

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <pytorch/tokenizers/re2_regex.h>
10+
11+
namespace tokenizers {
12+
13+
Re2Regex::Re2Regex(const std::string& pattern) {
14+
regex_ = std::make_unique<re2::RE2>(pattern);
15+
// Warmup re2 as it is slow on the first run, void the return value as it's
16+
// not needed Refer to
17+
// https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
18+
(void)regex_->ReverseProgramSize();
19+
}
20+
21+
std::vector<Match> Re2Regex::find_all(const std::string& text) const {
22+
std::vector<Match> result;
23+
re2::StringPiece input(text);
24+
re2::StringPiece piece;
25+
26+
const char* base = input.data();
27+
28+
while (RE2::FindAndConsume(&input, *regex_, &piece)) {
29+
size_t start = piece.data() - base;
30+
result.push_back({start, start + piece.size()});
31+
}
32+
33+
return result;
34+
}
35+
36+
} // namespace tokenizers

src/regex.cpp

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <pytorch/tokenizers/re2_regex.h>
10+
#include <pytorch/tokenizers/regex.h>
11+
#include <pytorch/tokenizers/std_regex.h>
12+
13+
#include <re2/re2.h>
14+
#include <iostream>
15+
#include <memory>
16+
17+
namespace tokenizers {
18+
19+
/**
20+
* @brief Factory function that creates a regex object using RE2 if possible.
21+
* Falls back to std::regex if RE2 rejects the pattern with
22+
* ErrorBadPerlOp.
23+
*/
24+
Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern) {
25+
// Try RE2 first
26+
auto re2 = std::make_unique<Re2Regex>("(" + pattern + ")");
27+
28+
if (re2->regex_->ok()) {
29+
return static_cast<std::unique_ptr<IRegex>>(std::move(re2));
30+
}
31+
32+
if (re2->regex_->error_code() == re2::RE2::ErrorBadPerlOp) {
33+
try {
34+
std::cout
35+
<< "RE2 is unable to support things such as negative lookaheads in "
36+
<< pattern << ", defaulting to std::regex.";
37+
auto std_regex = std::make_unique<StdRegex>("(" + pattern + ")");
38+
return static_cast<std::unique_ptr<IRegex>>(std::move(std_regex));
39+
} catch (const std::regex_error& e) {
40+
std::cerr << "std::regex failed: " << e.what() << std::endl;
41+
return tokenizers::Error::LoadFailure;
42+
}
43+
} else {
44+
std::cerr << "RE2 failed to compile pattern: " << pattern << "\n";
45+
std::cerr << "Error: " << (re2->regex_->error()) << std::endl;
46+
return tokenizers::Error::LoadFailure;
47+
}
48+
}
49+
50+
} // namespace tokenizers

src/std_regex.cpp

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <pytorch/tokenizers/std_regex.h>
10+
#include <regex>
11+
12+
namespace tokenizers {
13+
14+
StdRegex::StdRegex(const std::string& pattern) : regex_(pattern) {}
15+
16+
std::vector<Match> StdRegex::find_all(const std::string& text) const {
17+
std::vector<Match> result;
18+
std::sregex_iterator iter(text.begin(), text.end(), regex_);
19+
std::sregex_iterator end;
20+
21+
for (; iter != end; ++iter) {
22+
const auto& match = *iter;
23+
size_t start = match.position(1);
24+
result.push_back({start, start + match[1].length()});
25+
}
26+
27+
return result;
28+
}
29+
30+
} // namespace tokenizers

0 commit comments

Comments
 (0)