Skip to content

Commit 567473d

Browse files
authored
Port string_integer_map and the tiktoken changes to PyTorch
Differential Revision: D71500411 Pull Request resolved: #37
1 parent c8bff72 commit 567473d

File tree

10 files changed

+1191
-138
lines changed

10 files changed

+1191
-138
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ if(TOKENIZERS_BUILD_TEST)
7878
${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece
7979
${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2
8080
${CMAKE_CURRENT_SOURCE_DIR}/third-party/json/single_include)
81-
target_link_libraries(${test_name} gtest_main tokenizers)
81+
target_link_libraries(${test_name} gtest_main GTest::gmock tokenizers)
8282
add_test(${test_name} "${test_name}")
8383
set_tests_properties(${test_name} PROPERTIES ENVIRONMENT ${test_env})
8484
endforeach()

include/pytorch/tokenizers/bpe_tokenizer_base.h

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,98 @@
1414
#include <memory>
1515
#include <optional>
1616
#include <string>
17+
#include <type_traits>
1718
#include <unordered_map>
1819
#include <vector>
1920

2021
// Third Party
2122
#include <re2/re2.h>
2223

2324
// Local
25+
#include <pytorch/tokenizers/error.h>
2426
#include <pytorch/tokenizers/result.h>
27+
#include <pytorch/tokenizers/string_integer_map.h>
2528
#include <pytorch/tokenizers/tokenizer.h>
2629

2730
namespace tokenizers {
2831
namespace detail {
2932

30-
using Encoder = std::unordered_map<std::string, uint64_t>;
31-
using Decoder = std::unordered_map<uint64_t, std::string>;
3233
using Re2UPtr = std::unique_ptr<re2::RE2>;
34+
using TokenMap = StringIntegerMap<>;
35+
36+
template <typename TToken, typename TRank>
37+
static Result<TokenMap> buildTokenMap(
38+
std::vector<std::pair<TToken, TRank>> container) {
39+
static_assert(
40+
std::is_same_v<TToken, std::string> ||
41+
std::is_same_v<TToken, std::string_view>,
42+
"TToken must be std::string or std::string_view");
43+
static_assert(
44+
std::is_integral_v<TRank> && std::is_unsigned_v<TRank>,
45+
"TRank must be an unsigned integer");
46+
47+
std::sort(
48+
container.begin(), container.end(), [](const auto& a, const auto& b) {
49+
return a.first < b.first;
50+
});
51+
52+
auto duplicate_begin = std::unique(
53+
container.begin(), container.end(), [](const auto& a, const auto& b) {
54+
return a.first == b.first;
55+
});
56+
57+
TK_CHECK_OR_RETURN_ERROR(
58+
duplicate_begin == container.end(),
59+
ParseFailure,
60+
"duplicate token: %s rank: %llu",
61+
duplicate_begin->first.c_str(),
62+
static_cast<unsigned long long>(duplicate_begin->second));
63+
64+
std::sort(
65+
container.begin(), container.end(), [](const auto& a, const auto& b) {
66+
return a.second < b.second;
67+
});
68+
69+
duplicate_begin = std::unique(
70+
container.begin(), container.end(), [](const auto& a, const auto& b) {
71+
return a.second == b.second;
72+
});
73+
74+
TK_CHECK_OR_RETURN_ERROR(
75+
duplicate_begin == container.end(),
76+
ParseFailure,
77+
"duplicate rank: %llu"
78+
" token: %s",
79+
static_cast<unsigned long long>(duplicate_begin->second),
80+
duplicate_begin->first.c_str());
81+
82+
return TokenMap(container);
83+
};
84+
85+
template <typename TContainer, typename TTokenAccessor, typename TRankAccessor>
86+
static Result<TokenMap> buildTokenMap(
87+
const TContainer& container,
88+
TTokenAccessor token_accessor,
89+
TRankAccessor rank_accessor) {
90+
using TokenType = std::invoke_result_t<TTokenAccessor, const TContainer&>;
91+
using RankType = std::invoke_result_t<TRankAccessor, const TContainer&>;
92+
93+
static_assert(
94+
std::is_same_v<TokenType, std::string> ||
95+
std::is_same_v<TokenType, std::string_view>,
96+
"TokenType must be std::string or std::string_view");
97+
static_assert(
98+
std::is_integral_v<RankType> && std::is_unsigned_v<RankType>,
99+
"RankType must be an unsigned integer");
100+
101+
std::vector<std::pair<TokenType, RankType>> pairs;
102+
pairs.reserve(container.size());
103+
for (const auto& value : container) {
104+
pairs.emplace_back(token_accessor(value), rank_accessor(value));
105+
}
106+
107+
return buildTokenMap(std::move(pairs));
108+
}
33109

34110
class BPETokenizerBase : public Tokenizer {
35111
public:
@@ -46,22 +122,20 @@ class BPETokenizerBase : public Tokenizer {
46122
std::pair<std::optional<std::string>, re2::StringPiece>
47123
split_with_allowed_special_token_(
48124
re2::StringPiece& input,
49-
const Encoder& allowed_special) const;
125+
const TokenMap& allowed_special) const;
50126

51127
Result<std::pair<std::vector<uint64_t>, uint64_t>> encode_with_special_token_(
52128
const std::string& text,
53-
const Encoder& allowed_special) const;
129+
const TokenMap& allowed_special) const;
54130

55131
Result<std::vector<uint64_t>> byte_pair_encode_(
56132
const std::string& piece,
57-
const Encoder& encoder) const;
133+
const TokenMap& encoder) const;
58134

59135
// Protected members that can be overloaded by other BPE tokenizers
60136
Re2UPtr special_token_regex_;
61-
Encoder encoder_;
62-
Encoder special_token_encoder_;
63-
Decoder decoder_;
64-
Decoder special_token_decoder_;
137+
std::optional<TokenMap> token_map_;
138+
std::optional<TokenMap> special_token_map_;
65139

66140
private:
67141
virtual Error _encode(

0 commit comments

Comments
 (0)