Skip to content

Commit 100430f

Browse files
committed
PR review
1 parent 934ffa3 commit 100430f

File tree

7 files changed

+137
-53
lines changed

7 files changed

+137
-53
lines changed

include/pytorch/tokenizers/re2_regex.h

+17-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
19
#pragma once
210

311
#include <memory>
412
#include <string>
5-
#include "regex.h"
613

7-
// Third Party
814
#include <re2/re2.h>
915

16+
#include <pytorch/tokenizers/regex.h>
17+
18+
namespace tokenizers {
19+
1020
/**
1121
* @brief RE2-based implementation of IRegex.
1222
*/
@@ -24,11 +34,10 @@ class Re2Regex : public IRegex {
2434
*/
2535
virtual std::vector<Match> findAll(const std::string& text) const override;
2636

27-
protected:
2837
/**
2938
* @brief Check if RE2 compiled the pattern successfully.
3039
*/
31-
bool ok() const;
40+
bool ok() const override;
3241

3342
/**
3443
* @brief Expose internal RE2 pointer to the factory if needed.
@@ -38,5 +47,8 @@ class Re2Regex : public IRegex {
3847
private:
3948
std::unique_ptr<re2::RE2> regex_;
4049

41-
friend std::unique_ptr<IRegex> createRegex(const std::string& pattern);
50+
friend Result<std::unique_ptr<IRegex>> createRegex(
51+
const std::string& pattern);
4252
};
53+
54+
} // namespace tokenizers

include/pytorch/tokenizers/regex.h

+22-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
19
#pragma once
210

311
#include <memory>
412
#include <string>
513
#include <vector>
614

15+
#include <pytorch/tokenizers/result.h>
16+
17+
namespace tokenizers {
18+
719
struct Match {
820
std::string text;
921
size_t position;
@@ -23,6 +35,13 @@ class IRegex {
2335
* @return A vector of strings containing all matched substrings.
2436
*/
2537
virtual std::vector<Match> findAll(const std::string& text) const = 0;
38+
39+
/**
40+
* @brief Check if the regex pattern was compiled successfully.
41+
*
42+
* @return true if the pattern is valid and ready to use, false otherwise.
43+
*/
44+
virtual bool ok() const = 0;
2645
};
2746

2847
/**
@@ -31,4 +50,6 @@ class IRegex {
3150
* @param pattern The regex pattern to compile.
3251
* @return A unique pointer to an IRegex-compatible object.
3352
*/
34-
std::unique_ptr<IRegex> createRegex(const std::string& pattern);
53+
Result<std::unique_ptr<IRegex>> createRegex(const std::string& pattern);
54+
55+
} // namespace tokenizers

include/pytorch/tokenizers/std_regex.h

+19
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
19
#pragma once
210

311
#include <memory>
412
#include <regex>
513
#include <string>
614
#include "regex.h"
715

16+
namespace tokenizers {
17+
818
/**
919
* @brief std::regex-based implementation of IRegex.
1020
*/
@@ -23,6 +33,15 @@ class StdRegex : public IRegex {
2333
*/
2434
virtual std::vector<Match> findAll(const std::string& text) const override;
2535

36+
/**
37+
* @brief Check if std::regex compiled the pattern successfully.
38+
*
39+
* @return true if the pattern is valid, false otherwise.
40+
*/
41+
bool ok() const override;
42+
2643
private:
2744
std::regex regex_;
2845
};
46+
47+
} // namespace tokenizers

src/re2_regex.cpp

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <pytorch/tokenizers/re2_regex.h>
10+
11+
namespace tokenizers {
12+
13+
Re2Regex::Re2Regex(const std::string& pattern) {
14+
regex_ = std::make_unique<re2::RE2>(pattern);
15+
// Warmup re2 as it is slow on the first run, void the return value as it's
16+
// not needed Refer to
17+
// https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
18+
(void)regex_->ReverseProgramSize();
19+
}
20+
21+
std::vector<Match> Re2Regex::findAll(const std::string& text) const {
22+
std::vector<Match> result;
23+
re2::StringPiece input(text);
24+
re2::StringPiece piece;
25+
26+
const char* base = input.data();
27+
28+
while (RE2::FindAndConsume(&input, *regex_, &piece)) {
29+
size_t pos = piece.data() - base;
30+
result.push_back({std::string(piece.data(), piece.size()), pos});
31+
}
32+
33+
return result;
34+
}
35+
36+
bool Re2Regex::ok() const {
37+
return regex_ && regex_->ok();
38+
}
39+
40+
const re2::RE2* Re2Regex::rawRegex() const {
41+
return regex_.get();
42+
}
43+
44+
} // namespace tokenizers

src/re2_regex.cpp

-33
This file was deleted.

src/regex.cpp

+15-10
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
1-
#include "pytorch/tokenizers/regex.h"
2-
#include "pytorch/tokenizers/re2_regex.h"
3-
#include "pytorch/tokenizers/std_regex.h"
1+
#include <pytorch/tokenizers/regex.h>
2+
#include <pytorch/tokenizers/re2_regex.h>
3+
#include <pytorch/tokenizers/std_regex.h>
44

55
#include <re2/re2.h>
66
#include <iostream>
77
#include <memory>
88

9+
namespace tokenizers {
10+
911
/**
1012
* @brief Factory function that creates a regex object using RE2 if possible.
1113
* Falls back to std::regex if RE2 rejects the pattern with
12-
* ErrorBadPerlOp.
14+
* ErrorBadPerlOp.
1315
*/
14-
std::unique_ptr<IRegex> createRegex(const std::string& pattern) {
15-
auto re2 = std::make_unique<Re2Regex>(pattern);
16+
Result<std::unique_ptr<IRegex>> createRegex(const std::string& pattern) {
17+
auto re2 = std::make_unique<Re2Regex>("(" + pattern + ")");
1618

1719
if (re2->ok()) {
18-
return re2;
20+
return static_cast<std::unique_ptr<IRegex>>(std::move(re2));
1921
}
2022

2123
const re2::RE2* raw = re2->rawRegex();
@@ -24,14 +26,17 @@ std::unique_ptr<IRegex> createRegex(const std::string& pattern) {
2426
std::cout
2527
<< "RE2 is unable to support things such as negative lookaheads in "
2628
<< pattern << ", defaulting to std::regex.";
27-
return std::make_unique<StdRegex>(pattern);
29+
auto std_regex = std::make_unique<StdRegex>("(" + pattern + ")");
30+
return static_cast<std::unique_ptr<IRegex>>(std::move(std_regex));
2831
} catch (const std::regex_error& e) {
2932
std::cerr << "std::regex failed: " << e.what() << std::endl;
30-
return nullptr;
33+
return tokenizers::Error::LoadFailure;
3134
}
3235
} else {
3336
std::cerr << "RE2 failed to compile pattern: " << pattern << "\n";
3437
std::cerr << "Error: " << (raw ? raw->error() : "unknown") << std::endl;
35-
return nullptr;
38+
return tokenizers::Error::LoadFailure;
3639
}
3740
}
41+
42+
} // namespace tokenizers

src/std_regex.cpp

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
1-
#include "pytorch/tokenizers/std_regex.h"
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <pytorch/tokenizers/std_regex.h>
210
#include <regex>
311

4-
StdRegex::StdRegex(const std::string& pattern)
5-
: regex_("(" + pattern + ")") // Add parentheses like RE2 version
6-
{}
12+
namespace tokenizers {
13+
14+
StdRegex::StdRegex(const std::string& pattern) : regex_(pattern) {}
715

816
std::vector<Match> StdRegex::findAll(const std::string& text) const {
917
std::vector<Match> result;
@@ -20,3 +28,11 @@ std::vector<Match> StdRegex::findAll(const std::string& text) const {
2028

2129
return result;
2230
}
31+
32+
bool StdRegex::ok() const {
33+
// std::regex constructor throws if the pattern is invalid
34+
// If we got here, the pattern is valid
35+
return true;
36+
}
37+
38+
} // namespace tokenizers

0 commit comments

Comments
 (0)