Commit d7e19d4

Rely on runtime_wrapper to provide supported platforms (#33)
* Rely on runtime_wrapper to provide supported platforms. As titled.
* Update targets.bzl
* Move llama.cpp-unicode headers into llama.cpp-unicode/include
* Fix the build
1 parent 4a1d033 commit d7e19d4

File tree

8 files changed: +116 -114 lines

src/pre_tokenizer.cpp (+22 -25)

@@ -8,7 +8,7 @@
 
 // Local
 #include <pytorch/tokenizers/pre_tokenizer.h>
-#include <pytorch/tokenizers/third-party/llama.cpp-unicode/unicode.h>
+#include <unicode.h>
 
 // Standard
 #include <algorithm>
@@ -63,37 +63,35 @@ PreTokenizer::Ptr PreTokenizerConfig::create() const {
           "Missing pretokenizers for PreTokenizer of type Sequence");
     }
     std::vector<PreTokenizer::Ptr> pretoks;
-    std::transform(
-        pretokenizers->begin(),
-        pretokenizers->end(),
-        std::back_inserter(pretoks),
-        [](const PreTokenizerConfig& cfg) { return cfg.create(); });
+    std::transform(pretokenizers->begin(), pretokenizers->end(),
+                   std::back_inserter(pretoks),
+                   [](const PreTokenizerConfig &cfg) { return cfg.create(); });
     return PreTokenizer::Ptr(new SequencePreTokenizer(pretoks));
   }
   throw std::runtime_error("Unsupported PreTokenizer type: " + type);
 }
 
-PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) {
+PreTokenizerConfig &PreTokenizerConfig::parse_json(const json &json_config) {
   type = json_config.at("type");
   if (type == "Split") {
     try {
       pattern = json_config.at("pattern");
-    } catch (json::out_of_range&) {
+    } catch (json::out_of_range &) {
     }
   } else if (type == "Digits") {
     try {
       individual_digits = json_config.at("individual_digits");
-    } catch (json::out_of_range&) {
+    } catch (json::out_of_range &) {
     }
   } else if (type == "ByteLevel") {
     try {
       add_prefix_space = json_config.at("add_prefix_space");
-    } catch (json::out_of_range&) {
+    } catch (json::out_of_range &) {
    }
     // TODO: trim_offsets, use_regex
   } else if (type == "Sequence") {
     pretokenizers = std::vector<PreTokenizerConfig>();
-    for (const auto& entry : json_config.at("pretokenizers")) {
+    for (const auto &entry : json_config.at("pretokenizers")) {
       pretokenizers->push_back(PreTokenizerConfig().parse_json(entry));
     }
   } else {
@@ -104,14 +102,14 @@ PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) {
 
 // RegexPreTokenizer ///////////////////////////////////////////////////////////
 
-RegexPreTokenizer::Re2UPtr RegexPreTokenizer::create_regex_(
-    const std::string& pattern) {
+RegexPreTokenizer::Re2UPtr
+RegexPreTokenizer::create_regex_(const std::string &pattern) {
   assert(!pattern.empty());
   return std::make_unique<re2::RE2>("(" + pattern + ")");
 }
 
-std::vector<std::string> RegexPreTokenizer::pre_tokenize(
-    re2::StringPiece input) const {
+std::vector<std::string>
+RegexPreTokenizer::pre_tokenize(re2::StringPiece input) const {
   std::vector<std::string> result;
   std::string piece;
   while (RE2::FindAndConsume(&input, *regex_, &piece)) {
@@ -138,14 +136,13 @@ constexpr char GPT2_EXPR[] =
 // Construction //
 //////////////////
 
-ByteLevelPreTokenizer::ByteLevelPreTokenizer(
-    bool add_prefix_space,
-    const std::string& pattern)
+ByteLevelPreTokenizer::ByteLevelPreTokenizer(bool add_prefix_space,
+                                             const std::string &pattern)
     : pattern_(pattern.empty() ? GPT2_EXPR : pattern),
      add_prefix_space_(add_prefix_space) {}
 
-std::vector<std::string> ByteLevelPreTokenizer::pre_tokenize(
-    re2::StringPiece input) const {
+std::vector<std::string>
+ByteLevelPreTokenizer::pre_tokenize(re2::StringPiece input) const {
   // Add the prefix space if configured to do so
   std::string input_str(input);
   if (add_prefix_space_ && !input_str.empty() && input_str[0] != ' ') {
@@ -161,13 +158,13 @@ SequencePreTokenizer::SequencePreTokenizer(
     std::vector<PreTokenizer::Ptr> pre_tokenizers)
     : pre_tokenizers_(std::move(pre_tokenizers)) {}
 
-std::vector<std::string> SequencePreTokenizer::pre_tokenize(
-    re2::StringPiece input) const {
+std::vector<std::string>
+SequencePreTokenizer::pre_tokenize(re2::StringPiece input) const {
   std::vector<std::string> pieces{std::string(input)};
-  for (const auto& pre_tokenizer : pre_tokenizers_) {
+  for (const auto &pre_tokenizer : pre_tokenizers_) {
     std::vector<std::string> new_pieces;
-    for (const auto& piece : pieces) {
-      for (const auto& subpiece : pre_tokenizer->pre_tokenize(piece)) {
+    for (const auto &piece : pieces) {
+      for (const auto &subpiece : pre_tokenizer->pre_tokenize(piece)) {
         new_pieces.push_back(subpiece);
       }
     }
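
For context, a minimal sketch of how the config API touched above is typically driven: parse_json fills the config from a HF-style JSON object and create builds the pre-tokenizer. The type keys (Sequence, Digits, ByteLevel) and parameter names come from the hunk above; the JSON values, the main function, and the omitted namespace qualification are illustrative assumptions, not part of this commit.

// Sketch (not from this commit): build a pre-tokenizer from a JSON
// config and run it. Namespace qualification, if any, is omitted.
#include <pytorch/tokenizers/pre_tokenizer.h>

#include <iostream>
#include <nlohmann/json.hpp>

int main() {
  nlohmann::json config;
  config["type"] = "Sequence";
  config["pretokenizers"] = nlohmann::json::array();
  config["pretokenizers"].push_back(
      {{"type", "Digits"}, {"individual_digits", true}});
  config["pretokenizers"].push_back(
      {{"type", "ByteLevel"}, {"add_prefix_space", true}});

  // parse_json returns a reference to the config, so the calls chain.
  PreTokenizer::Ptr pre_tok = PreTokenizerConfig().parse_json(config).create();

  for (const auto &piece : pre_tok->pre_tokenize("abc 123")) {
    std::cout << piece << "\n";
  }
  return 0;
}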

src/token_decoder.cpp (+4 -4)

@@ -16,7 +16,7 @@
 #include <nlohmann/json.hpp>
 
 // Local
-#include <pytorch/tokenizers/third-party/llama.cpp-unicode/unicode.h>
+#include <unicode.h>
 
 using json = nlohmann::json;
 
@@ -37,7 +37,7 @@ TokenDecoder::Ptr TokenDecoderConfig::create() const {
   throw std::runtime_error("Unsupported TokenDecoder type: " + type);
 }
 
-TokenDecoderConfig& TokenDecoderConfig::parse_json(const json& json_config) {
+TokenDecoderConfig &TokenDecoderConfig::parse_json(const json &json_config) {
   type = json_config.at("type");
   if (type == "ByteLevel") {
     // No parameters to parse
@@ -54,7 +54,7 @@ namespace {
 // Copied from llama.cpp
 // CITE:
 // https://github.com/ggerganov/llama.cpp/blob/master/src/llama-vocab.cpp#L20
-static std::string format(const char* fmt, ...) {
+static std::string format(const char *fmt, ...) {
   va_list ap;
   va_list ap2;
   va_start(ap, fmt);
@@ -84,7 +84,7 @@ std::string ByteLevelTokenDecoder::decode(re2::StringPiece token) const {
     const auto utf8 = unicode_cpt_to_utf8(cpt);
     try {
       decoded_text += unicode_utf8_to_byte(utf8);
-    } catch (const std::out_of_range& /*e*/) {
+    } catch (const std::out_of_range & /*e*/) {
       decoded_text += "[UNK_BYTE_0x";
       for (const auto c : utf8) {
         decoded_text += format("%02x", (uint8_t)c);
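
The catch branch in the last hunk hex-escapes any byte sequence that unicode_utf8_to_byte cannot map back to a byte. A self-contained sketch of just that fallback formatting, using only the standard library; the closing of the marker lies outside the hunk shown, so it is an assumption here.

// Sketch (not from this commit): the "[UNK_BYTE_0x..." fallback seen in
// ByteLevelTokenDecoder::decode, reduced to standard C++.
#include <cstdint>
#include <cstdio>
#include <string>

std::string unk_byte_marker(const std::string &utf8) {
  std::string out = "[UNK_BYTE_0x";
  char buf[3];
  for (const auto c : utf8) {
    // Same per-byte formatting as the diff: format("%02x", (uint8_t)c).
    std::snprintf(buf, sizeof(buf), "%02x", (uint8_t)c);
    out += buf;
  }
  out += "]"; // assumed closing; not visible in the hunk above
  return out;
}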

targets.bzl (+3 -17)

@@ -1,8 +1,7 @@
-load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "FBCODE")
-load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime", "get_executorch_supported_platforms")
 load("@fbsource//xplat/executorch/third-party:glob_defs.bzl", "subdir_glob")
 
-PLATFORMS = (CXX, ANDROID, APPLE, FBCODE)
+PLATFORMS = get_executorch_supported_platforms()
 
 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.
@@ -68,19 +67,6 @@ def define_common_targets():
         platforms = PLATFORMS,
     )
 
-    runtime.cxx_library(
-        name = "unicode",
-        srcs = [
-            "third-party/llama.cpp-unicode/src/unicode.cpp",
-            "third-party/llama.cpp-unicode/src/unicode-data.cpp",
-        ],
-        exported_headers = subdir_glob([
-            ("include", "pytorch/tokenizers/third-party/llama.cpp-unicode/*.h"),
-        ]),
-        header_namespace = "",
-        platforms = PLATFORMS,
-    )
-
     runtime.cxx_library(
         name = "hf_tokenizer",
         srcs = [
@@ -91,7 +77,7 @@ def define_common_targets():
         ],
         exported_deps = [
             ":headers",
-            ":unicode",
+            "//pytorch/tokenizers/third-party:unicode",
         ],
         visibility = [
             "@EXECUTORCH_CLIENTS",

third-party/TARGETS (+13 -0)

@@ -1,4 +1,5 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load("@fbsource//xplat/executorch/third-party:glob_defs.bzl", "subdir_glob")
 
 oncall("executorch")
 
@@ -45,3 +46,15 @@ runtime.cxx_library(
     visibility = ["PUBLIC"],
     _is_external_target = True,
 )
+
+runtime.cxx_library(
+    name = "unicode",
+    srcs = [
+        "llama.cpp-unicode/src/unicode.cpp",
+        "llama.cpp-unicode/src/unicode-data.cpp",
+    ],
+    exported_headers = subdir_glob([
+        ("include", "*.h"),
+    ]),
+    header_namespace = "",
+)
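
Because the relocated target exports its headers from llama.cpp-unicode/include with header_namespace = "", consumers now use the short include path seen in the source diffs above. A hypothetical consumer translation unit; the unicode_cpt_to_utf8 signature is an assumption inferred from its call site in token_decoder.cpp, not stated by this commit.

// Sketch (not from this commit): a consumer of the relocated :unicode
// target. The long pytorch/tokenizers/third-party/... prefix is gone.
#include <unicode.h>

#include <cstdint>
#include <string>

std::string codepoint_to_utf8(uint32_t cpt) {
  // Provided by the llama.cpp-unicode headers; signature assumed from
  // its use in ByteLevelTokenDecoder::decode above.
  return unicode_cpt_to_utf8(cpt);
}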

third-party/llama.cpp-unicode/src/unicode-data.cpp (+1 -1)

@@ -27,7 +27,7 @@ SOFTWARE.
 
 // generated with scripts/gen-unicode-data.py
 
-#include <pytorch/tokenizers/third-party/llama.cpp-unicode/unicode-data.h>
+#include "unicode-data.h"
 
 #include <cstdint>
 #include <unordered_map>
