
Commit 173308e

Scott pr review

1 parent 97110ac · commit 173308e

1 file changed: +43 −28 lines

examples/models/llama/runner/runner.cpp

@@ -16,8 +16,8 @@
 #include <executorch/extension/llm/runner/util.h>
 
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
-#include <pytorch/tokenizers/llama2c_tokenizer.h>
 #include <pytorch/tokenizers/hf_tokenizer.h>
+#include <pytorch/tokenizers/llama2c_tokenizer.h>
 
 namespace example {
 
@@ -36,6 +36,41 @@ static constexpr auto kMaxContextLen = "get_max_context_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
 static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
+
+std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
+    const std::string& tokenizer_path) {
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer = nullptr;
+  ::tokenizers::Error err;
+
+  // First try to load as a json tokenizer.
+  {
+    auto tokenizer = std::make_unique<tokenizers::HFTokenizer>();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded json tokenizer");
+      return tokenizer;
+    }
+  }
+
+  // Try to load as tiktoken tokenizer.
+  {
+    auto tokenizer = get_tiktoken_for_llama();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded TikToken tokenizer");
+      return tokenizer;
+    }
+  }
+
+  // Try to load as BPE tokenizer.
+  {
+    auto tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded BPE tokenizer");
+      return tokenizer;
+    }
+  }
+
+  return nullptr;
+}
 } // namespace
 
 Runner::Runner(
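
The helper added above replaces extension-based dispatch with a fallback chain: construct each candidate tokenizer in order (HF JSON, then tiktoken, then llama2.c BPE) and return the first one whose load() succeeds, or nullptr if none match. A minimal, self-contained sketch of that pattern follows; Tokenizer, JsonTok, BpeTok, and load_tokenizer_sketch are illustrative stand-ins, not the pytorch/tokenizers API.

// Sketch only: JsonTok/BpeTok stand in for HFTokenizer, tiktoken, and
// Llama2cTokenizer from pytorch/tokenizers.
#include <cstdio>
#include <functional>
#include <memory>
#include <string>
#include <vector>

enum class Error { Ok, ParseFailure };

struct Tokenizer {
  virtual ~Tokenizer() = default;
  virtual Error load(const std::string& path) = 0;
};

// Hypothetical format-specific loader: accepts only *.json paths.
struct JsonTok : Tokenizer {
  Error load(const std::string& path) override {
    bool is_json = path.size() >= 5 &&
        path.compare(path.size() - 5, 5, ".json") == 0;
    return is_json ? Error::Ok : Error::ParseFailure;
  }
};

// Hypothetical catch-all loader, playing the role of the BPE fallback.
struct BpeTok : Tokenizer {
  Error load(const std::string&) override { return Error::Ok; }
};

// The pattern from the diff: construct each candidate in order and return
// the first one whose load() succeeds; nullptr means no format matched.
std::unique_ptr<Tokenizer> load_tokenizer_sketch(const std::string& path) {
  const std::vector<std::function<std::unique_ptr<Tokenizer>()>> factories = {
      [] { return std::unique_ptr<Tokenizer>(std::make_unique<JsonTok>()); },
      [] { return std::unique_ptr<Tokenizer>(std::make_unique<BpeTok>()); },
  };
  for (const auto& make : factories) {
    auto candidate = make();
    if (candidate->load(path) == Error::Ok) {
      return candidate;
    }
  }
  return nullptr;
}

int main() {
  auto tok = load_tokenizer_sketch("tokenizer.json");
  std::printf("loaded: %s\n", tok ? "yes" : "no");
  return tok ? 0 : 1;
}

A nice property of this shape is that adding a new tokenizer format means appending one factory to the chain rather than widening an if/else ladder.
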
@@ -78,35 +113,15 @@ Error Runner::load() {
     return Error::Ok;
   }
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
+
   // Load tokenizer.
-  tokenizer_ = nullptr;
-  // Check if tokenizer_path_ ends with ".json".
-  if (tokenizer_path_.size() >= 5 &&
-      tokenizer_path_.compare(tokenizer_path_.size() - 5, 5, ".json") == 0) {
-    tokenizer_ = std::make_unique<tokenizers::HFTokenizer>();
-    ET_LOG(Info, "Loading json tokenizer");
-    tokenizer_->load(tokenizer_path_);
+  tokenizer_ = load_tokenizer(tokenizer_path_);
+  if (tokenizer_ == nullptr) {
     ET_LOG(
-        Info, "Loaded tokenizer %s as HF tokenizer", tokenizer_path_.c_str());
-  } else {
-    ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
-    tokenizer_ = get_tiktoken_for_llama();
-    // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-    // fallback to BPE tokenizer.
-    if (err != ::tokenizers::Error::Ok) {
-      ET_LOG(
-          Info,
-          "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-          tokenizer_path_.c_str());
-      tokenizer_.reset();
-      tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-      err = tokenizer_->load(tokenizer_path_);
-      ET_CHECK_TK_OK_OR_RETURN_ERROR(
-          err,
-          "Failed to load %s as a llama2.c tokenizer artifact",
-          tokenizer_path_.c_str());
-    }
+        Error,
+        "Failed to load %s as a llama2.c tokenizer artifact",
+        tokenizer_path_.c_str());
+    return ::executorch::runtime::Error::InvalidArgument;
   }
 
   ET_LOG(Info, "Reading metadata from model");
