 #include <memory>
 #include <optional>
 #include <string>
+#include <type_traits>
 #include <unordered_map>
 #include <vector>
 
 // Third Party
 #include <re2/re2.h>
 
 // Local
+#include <pytorch/tokenizers/error.h>
 #include <pytorch/tokenizers/result.h>
+#include <pytorch/tokenizers/string_integer_map.h>
 #include <pytorch/tokenizers/tokenizer.h>
 
 namespace tokenizers {
 namespace detail {
 
-using Encoder = std::unordered_map<std::string, uint64_t>;
-using Decoder = std::unordered_map<uint64_t, std::string>;
 using Re2UPtr = std::unique_ptr<re2::RE2>;
+using TokenMap = StringIntegerMap<>;
+
+template <typename TToken, typename TRank>
+static Result<TokenMap> buildTokenMap(
+    std::vector<std::pair<TToken, TRank>> container) {
+  static_assert(
+      std::is_same_v<TToken, std::string> ||
+          std::is_same_v<TToken, std::string_view>,
+      "TToken must be std::string or std::string_view");
+  static_assert(
+      std::is_integral_v<TRank> && std::is_unsigned_v<TRank>,
+      "TRank must be an unsigned integer");
+
+  std::sort(
+      container.begin(), container.end(), [](const auto& a, const auto& b) {
+        return a.first < b.first;
+      });
+
+  auto duplicate_begin = std::unique(
+      container.begin(), container.end(), [](const auto& a, const auto& b) {
+        return a.first == b.first;
+      });
+
+  TK_CHECK_OR_RETURN_ERROR(
+      duplicate_begin == container.end(),
+      ParseFailure,
+      "duplicate token: %s rank: %llu",
+      duplicate_begin->first.c_str(),
+      static_cast<unsigned long long>(duplicate_begin->second));
+
+  std::sort(
+      container.begin(), container.end(), [](const auto& a, const auto& b) {
+        return a.second < b.second;
+      });
+
+  duplicate_begin = std::unique(
+      container.begin(), container.end(), [](const auto& a, const auto& b) {
+        return a.second == b.second;
+      });
+
+  TK_CHECK_OR_RETURN_ERROR(
+      duplicate_begin == container.end(),
+      ParseFailure,
+      "duplicate rank: %llu"
+      " token: %s",
+      static_cast<unsigned long long>(duplicate_begin->second),
+      duplicate_begin->first.c_str());
+
+  return TokenMap(container);
+};
+
+template <typename TContainer, typename TTokenAccessor, typename TRankAccessor>
+static Result<TokenMap> buildTokenMap(
+    const TContainer& container,
+    TTokenAccessor token_accessor,
+    TRankAccessor rank_accessor) {
+  using TokenType = std::invoke_result_t<TTokenAccessor, const TContainer&>;
+  using RankType = std::invoke_result_t<TRankAccessor, const TContainer&>;
+
+  static_assert(
+      std::is_same_v<TokenType, std::string> ||
+          std::is_same_v<TokenType, std::string_view>,
+      "TokenType must be std::string or std::string_view");
+  static_assert(
+      std::is_integral_v<RankType> && std::is_unsigned_v<RankType>,
+      "RankType must be an unsigned integer");
+
+  std::vector<std::pair<TokenType, RankType>> pairs;
+  pairs.reserve(container.size());
+  for (const auto& value : container) {
+    pairs.emplace_back(token_accessor(value), rank_accessor(value));
+  }
+
+  return buildTokenMap(std::move(pairs));
+}
 
 class BPETokenizerBase : public Tokenizer {
  public:
@@ -46,22 +122,20 @@ class BPETokenizerBase : public Tokenizer {
   std::pair<std::optional<std::string>, re2::StringPiece>
   split_with_allowed_special_token_(
       re2::StringPiece& input,
-      const Encoder& allowed_special) const;
+      const TokenMap& allowed_special) const;
 
   Result<std::pair<std::vector<uint64_t>, uint64_t>> encode_with_special_token_(
       const std::string& text,
-      const Encoder& allowed_special) const;
+      const TokenMap& allowed_special) const;
 
   Result<std::vector<uint64_t>> byte_pair_encode_(
       const std::string& piece,
-      const Encoder& encoder) const;
+      const TokenMap& encoder) const;
 
   // Protected members that can be overloaded by other BPE tokenizers
   Re2UPtr special_token_regex_;
-  Encoder encoder_;
-  Encoder special_token_encoder_;
-  Decoder decoder_;
-  Decoder special_token_decoder_;
+  std::optional<TokenMap> token_map_;
+  std::optional<TokenMap> special_token_map_;
 
  private:
   virtual Error _encode(
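
For reference, a minimal caller sketch (not part of the diff) of how the two buildTokenMap overloads added above might be used when loading a vocabulary. The function and variable names are illustrative, and the Result<T> accessors (ok(), get()) are assumptions about the API in result.h; adjust to the actual interface.

// Hypothetical usage sketch, not part of this PR.
// Assumes the header modified above (which provides
// tokenizers::detail::buildTokenMap) is included.
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

void build_token_maps_example() {
  using namespace tokenizers::detail;

  // Overload 1: build directly from a flat list of (token, rank) pairs.
  // Duplicate tokens or duplicate ranks are reported as a ParseFailure
  // instead of being silently overwritten.
  std::vector<std::pair<std::string, std::uint64_t>> pairs = {
      {"hello", 0}, {"world", 1}};
  auto map = buildTokenMap(std::move(pairs));

  // Overload 2: build from an arbitrary container plus token/rank accessors.
  // Generic lambdas with explicit trailing return types satisfy the
  // std::invoke_result_t checks in the template.
  std::vector<std::pair<std::string, std::uint64_t>> rows = {
      {"foo", 2}, {"bar", 3}};
  auto special_map = buildTokenMap(
      rows,
      [](const auto& row) -> std::string { return row.first; },
      [](const auto& row) -> std::uint64_t { return row.second; });

  if (map.ok() && special_map.ok()) {  // ok()/get() assumed from Result<T>
    // e.g. token_map_.emplace(std::move(map.get()));
  }
}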