Skip to content

Use common base class private functions for TikToken #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions include/pytorch/tokenizers/tiktoken.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,24 +76,13 @@ class Tiktoken : public detail::BPETokenizerBase {
return special_tokens;
}

template <typename T>
std::pair<std::optional<std::string>, re2::StringPiece>
_split_with_allowed_special_token(
re2::StringPiece& input,
const T& allowed_special) const;

Error _encode(
re2::StringPiece& input,
std::vector<uint64_t>& ret,
uint64_t& last_piece_token_len) const override;

void _decode(re2::StringPiece input, std::string& ret) const override;

template <typename T>
Result<std::pair<std::vector<uint64_t>, uint64_t>> _encode_with_special_token(
const std::string& text,
const T& allowed_special) const;

detail::TokenMap _build_special_token_map(ssize_t num_base_tokens) const;

std::unique_ptr<std::vector<std::string>> _special_tokens;
Expand Down
14 changes: 13 additions & 1 deletion src/bpe_tokenizer_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,12 @@ BPETokenizerBase::split_with_allowed_special_token_(
return std::make_pair(std::nullopt, input);
}

#if __cplusplus >= 202002L
auto start = input.begin();
#else
const char* start = input.data();
#endif

std::string special;
while (true) {
if (!re2::RE2::FindAndConsume(&input, *special_token_regex_, &special)) {
Expand All @@ -148,9 +153,15 @@ BPETokenizerBase::split_with_allowed_special_token_(

if (allowed_special.tryGetInteger(special).has_value()) {
// Found an allowed special token, split the text with it.
#if __cplusplus >= 202002L
return std::make_pair(
special,
re2::StringPiece(start, input.begin() - start - special.size()));
#else
return std::make_pair(
special,
re2::StringPiece(start, (input.data() - start) - special.size()));
#endif
} // else try to find the next special token
}

Expand All @@ -168,7 +179,8 @@ BPETokenizerBase::encode_with_special_token_(
auto [special, sub_input] =
split_with_allowed_special_token_(input, allowed_special);

_encode(sub_input, tokens, last_piece_token_len);
TK_CHECK_OK_OR_RETURN_ERROR(
_encode(sub_input, tokens, last_piece_token_len));

if (special) {
const auto result = special_token_map_->tryGetInteger(*special);
Expand Down
75 changes: 0 additions & 75 deletions src/tiktoken.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,44 +113,6 @@ static Result<TokenMap> _load_token_map(const std::string& path) {
// ------------------------------Util end------------------------------------
// -------------------------private method start-------------------------------

// Scans `input` for the next *allowed* special token using
// special_token_regex_. On a hit, returns {token, text-before-token} and
// leaves `input` advanced past the match; otherwise returns {nullopt, input}.
// NOTE(review): this PR deletes this member in favor of the identical
// BPETokenizerBase::split_with_allowed_special_token_ in the base class.
template <typename T>
std::pair<std::optional<std::string>, re2::StringPiece>
Tiktoken::_split_with_allowed_special_token(
re2::StringPiece& input,
const T& allowed_special) const {
// No special-token regex configured: nothing to split on.
if (!special_token_regex_) {
return std::make_pair(std::nullopt, input);
}

// Under C++20, re2::StringPiece is std::string_view, whose iterator need not
// be a raw const char*; both spellings compute the same start-of-scan point.
#if __cplusplus >= 202002L
auto start = input.begin();
#else
const char* start = input.data();
#endif
std::string special;
while (true) {
if (!re2::RE2::FindAndConsume(&input, *special_token_regex_, &special)) {
// No special token.
break;
}

if (allowed_special.tryGetInteger(special)) {
// Found an allowed special token, split the text with it.
// FindAndConsume advanced `input` past the match, so the prefix length
// is (current position - start) minus the matched token's own length.
#if __cplusplus >= 202002L
return std::make_pair(
special,
re2::StringPiece(start, input.begin() - start - special.size()));
#else
return std::make_pair(
special,
re2::StringPiece(start, (input.data() - start) - special.size()));
#endif
} // else try to find the next special token
}

// No allowed special token remains; `input` may have been advanced past
// disallowed matches by FindAndConsume above.
return std::make_pair(std::nullopt, input);
}

Error Tiktoken::_encode(
re2::StringPiece& input,
std::vector<uint64_t>& ret,
Expand Down Expand Up @@ -179,43 +141,6 @@ void Tiktoken::_decode(re2::StringPiece input, std::string& ret) const {
#endif
}

// Encodes `text` into token ids, honoring the special tokens permitted by
// `allowed_special`: ordinary stretches go through _encode(), and each allowed
// special token is emitted as its single mapped id. Returns the token list
// plus the length (in tokens) of the final regex-split piece.
// NOTE(review): this PR deletes this member in favor of the shared
// BPETokenizerBase::encode_with_special_token_ implementation.
template <typename T>
Result<std::pair<std::vector<uint64_t>, uint64_t>>
Tiktoken::_encode_with_special_token(
const std::string& text,
const T& allowed_special) const {
std::vector<uint64_t> tokens;
uint64_t last_piece_token_len = 0;
re2::StringPiece input(text);
while (true) {
// Split off the text before the next allowed special token (if any);
// the splitter advances `input` past the matched token.
auto [special, sub_input] =
_split_with_allowed_special_token(input, allowed_special);

// BPE-encode the ordinary text preceding the special token.
TK_CHECK_OK_OR_RETURN_ERROR(
_encode(sub_input, tokens, last_piece_token_len));

if (special) {
const auto result = special_token_map_->tryGetInteger(*special);
if (!result) {
// Should never go here, since special pattern includes all special
// chars.
TK_LOG(Error, "unknown special token: %s", special->c_str());
return Error::EncodeFailure;
}

tokens.push_back(*result);
// A special token is a hard boundary, so no partial piece carries over.
last_piece_token_len = 0;
} else {
break;
}
}

// last_piece_token_len is how many tokens came from the last regex split.
// This is used for determining unstable tokens, since you can't merge
// across (stable) regex splits
return std::make_pair(tokens, last_piece_token_len);
}

// -------------------------private method end-------------------------------
// -------------------------public method start-------------------------------

Expand Down