 #include <executorch/extension/llm/runner/util.h>
 
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
-#include <pytorch/tokenizers/llama2c_tokenizer.h>
 #include <pytorch/tokenizers/hf_tokenizer.h>
+#include <pytorch/tokenizers/llama2c_tokenizer.h>
 
 namespace example {
 
@@ -36,6 +36,41 @@ static constexpr auto kMaxContextLen = "get_max_context_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
 static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
+
+std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
+    const std::string& tokenizer_path) {
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer = nullptr;
+  ::tokenizers::Error err;
+
+  // First try to load as a json tokenizer.
+  {
+    auto tokenizer = std::make_unique<tokenizers::HFTokenizer>();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded json tokenizer");
+      return tokenizer;
+    }
+  }
+
+  // Try to load as tiktoken tokenizer.
+  {
+    auto tokenizer = get_tiktoken_for_llama();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded TikToken tokenizer");
+      return tokenizer;
+    }
+  }
+
+  // Try to load as BPE tokenizer.
+  {
+    auto tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    if (tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
+      ET_LOG(Info, "Loaded BPE tokenizer");
+      return tokenizer;
+    }
+  }
+
+  return nullptr;
+}
 } // namespace
 
 Runner::Runner(
@@ -78,35 +113,15 @@ Error Runner::load() {
     return Error::Ok;
   }
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
+
   // Load tokenizer.
-  tokenizer_ = nullptr;
-  // Check if tokenizer_path_ ends with ".json".
-  if (tokenizer_path_.size() >= 5 &&
-
-      tokenizer_path_.compare(tokenizer_path_.size() - 5, 5, ".json") == 0) {
-    tokenizer_ = std::make_unique<tokenizers::HFTokenizer>();
-    ET_LOG(Info, "Loading json tokenizer");
-    tokenizer_->load(tokenizer_path_);
+  tokenizer_ = load_tokenizer(tokenizer_path_);
+  if (tokenizer_ == nullptr) {
     ET_LOG(
-        Info, "Loaded tokenizer %s as HF tokenizer", tokenizer_path_.c_str());
-  } else {
-    ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
-    tokenizer_ = get_tiktoken_for_llama();
-    // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-    // fallback to BPE tokenizer.
-    if (err != ::tokenizers::Error::Ok) {
-      ET_LOG(
-          Info,
-          "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-          tokenizer_path_.c_str());
-      tokenizer_.reset();
-      tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-      err = tokenizer_->load(tokenizer_path_);
-      ET_CHECK_TK_OK_OR_RETURN_ERROR(
-          err,
-          "Failed to load %s as a llama2.c tokenizer artifact",
-          tokenizer_path_.c_str());
-    }
+        Error,
+        "Failed to load %s as a llama2.c tokenizer artifact",
+        tokenizer_path_.c_str());
+    return ::executorch::runtime::Error::InvalidArgument;
   }
 
   ET_LOG(Info, "Reading metadata from model");
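
Note (not part of the diff): the refactor makes tokenizer selection attempt-based rather than extension-based. The new helper tries the HF JSON loader, then tiktoken, then the llama2.c BPE loader, and returns the first one that loads successfully. Below is a minimal sketch of the resulting call pattern; it assumes the includes already present in this file and that load_tokenizer is visible to the caller, and the wrapper function and file names are illustrative only.

#include <memory>
#include <string>

// Illustrative sketch only: any supported artifact resolves through the same
// helper, so callers no longer branch on the file extension themselves.
static void try_tokenizer_paths() {
  for (const char* path :
       {"tokenizer.json", "tokenizer.model", "tokenizer.bin"}) {
    std::unique_ptr<::tokenizers::Tokenizer> tok = load_tokenizer(path);
    if (tok == nullptr) {
      ET_LOG(Error, "Unsupported tokenizer artifact: %s", path);
      continue;
    }
    // The concrete loader (HF JSON, tiktoken, or llama2.c BPE) is hidden
    // behind the common ::tokenizers::Tokenizer interface.
  }
}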