|
17 | 17 |
|
18 | 18 | #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
|
19 | 19 | #include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
|
| 20 | +#include <executorch/extension/llm/tokenizer/hf_tokenizer.h> |
20 | 21 |
|
21 | 22 | namespace example {
|
22 | 23 |
|
@@ -75,20 +76,33 @@ Error Runner::load() {
|
75 | 76 | return Error::Ok;
|
76 | 77 | }
|
77 | 78 | ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
|
78 |
| - // load tokenizer. Assuming tiktoken is the default tokenizer |
| 79 | + // Load tokenizer. |
79 | 80 | tokenizer_ = nullptr;
|
80 |
| - tokenizer_ = get_tiktoken_for_llama(); |
81 |
| - Error err = tokenizer_->load(tokenizer_path_); |
82 |
| - // Rely on tiktoken to throw error if the artifact is incompatible. Then we |
83 |
| - // fallback to BPE tokenizer. |
84 |
| - if (err == Error::InvalidArgument) { |
85 |
| - ET_LOG( |
86 |
| - Info, |
87 |
| - "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer", |
88 |
| - tokenizer_path_.c_str()); |
89 |
| - tokenizer_.reset(); |
90 |
| - tokenizer_ = std::make_unique<llm::BPETokenizer>(); |
| 81 | + // Check if tokenizer_path_ ends with ".json". |
| 82 | + if (tokenizer_path_.size() >= 5 && |
| 83 | + tokenizer_path_.compare(tokenizer_path_.size() - 5, 5, ".json") == 0) { |
| 84 | + tokenizer_ = std::make_unique<llm::HfTokenizer>(); |
91 | 85 | tokenizer_->load(tokenizer_path_);
|
| 86 | + ET_LOG( |
| 87 | + Info, "Loaded tokenizer %s as HF tokenizer", tokenizer_path_.c_str()); |
| 88 | + } else { |
| 89 | + // Else assume TikToken is the default tokenizer, using BPE as a fallback. |
| 90 | + tokenizer_ = get_tiktoken_for_llama(); |
| 91 | + Error err = tokenizer_->load(tokenizer_path_); |
| 92 | + if (err == Error::InvalidArgument) { |
| 93 | + tokenizer_.reset(); |
| 94 | + tokenizer_ = std::make_unique<llm::BPETokenizer>(); |
| 95 | + tokenizer_->load(tokenizer_path_); |
| 96 | + ET_LOG( |
| 97 | + Info, |
| 98 | + "Loaded tokenizer %s as BPE tokenizer", |
| 99 | + tokenizer_path_.c_str()); |
| 100 | + } else { |
| 101 | + ET_LOG( |
| 102 | + Info, |
| 103 | + "Loaded tokenizer %s as TikToken tokenizer", |
| 104 | + tokenizer_path_.c_str()); |
| 105 | + } |
92 | 106 | }
|
93 | 107 |
|
94 | 108 | ET_LOG(Info, "Reading metadata from model");
|
|
0 commit comments