Skip to content

Commit de6b735

Browse files
committed
Add stub impl of json tokenizer to llama runner
1 parent 3230900 commit de6b735

File tree

4 files changed

+111
-12
lines changed

4 files changed

+111
-12
lines changed

examples/models/llama/runner/runner.cpp

+18-12
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
1919
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
20+
#include <executorch/extension/llm/tokenizer/hf_tokenizer.h>
2021

2122
namespace example {
2223

@@ -75,20 +76,25 @@ Error Runner::load() {
7576
return Error::Ok;
7677
}
7778
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
78-
// load tokenizer. Assuming tiktoken is the default tokenizer
79+
// Load tokenizer.
7980
tokenizer_ = nullptr;
80-
tokenizer_ = get_tiktoken_for_llama();
81-
Error err = tokenizer_->load(tokenizer_path_);
82-
// Rely on tiktoken to throw error if the artifact is incompatible. Then we
83-
// fallback to BPE tokenizer.
84-
if (err == Error::InvalidArgument) {
85-
ET_LOG(
86-
Info,
87-
"Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
88-
tokenizer_path_.c_str());
89-
tokenizer_.reset();
90-
tokenizer_ = std::make_unique<llm::BPETokenizer>();
81+
// Check if tokenizer_path_ ends with ".json".
82+
if (tokenizer_path_.size() >= 5 && tokenizer_path_.compare(tokenizer_path_.size() - 5, 5, ".json") == 0) {
83+
tokenizer_ = std::make_unique<llm::HfTokenizer>();
9184
tokenizer_->load(tokenizer_path_);
85+
ET_LOG(Info, "Loaded tokenizer %s as HF tokenizer", tokenizer_path_.c_str());
86+
} else {
87+
// Else assume TikToken is the default tokenizer, using BPE as a fallback.
88+
tokenizer_ = get_tiktoken_for_llama();
89+
Error err = tokenizer_->load(tokenizer_path_);
90+
if (err == Error::InvalidArgument) {
91+
tokenizer_.reset();
92+
tokenizer_ = std::make_unique<llm::BPETokenizer>();
93+
tokenizer_->load(tokenizer_path_);
94+
ET_LOG(Info, "Loaded tokenizer %s as BPE tokenizer", tokenizer_path_.c_str());
95+
} else {
96+
ET_LOG(Info, "Loaded tokenizer %s as TikToken tokenizer", tokenizer_path_.c_str());
97+
}
9298
}
9399

94100
ET_LOG(Info, "Reading metadata from model");
+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Stub implementation of the Hugging Face (tokenizer.json) tokenizer.
//
// NOTE(review): the members declared in hf_tokenizer.h must be defined
// out-of-line here. The previous version re-declared a second, inline
// `class HfTokenizer` (via a stale include path) instead of implementing
// the header's class — an ODR violation that also left the header's
// declared constructor/destructor with no definition, breaking the link
// for any client of the `hf_tokenizer` target.

#include <executorch/extension/llm/tokenizer/hf_tokenizer.h>

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>

#include <string>
#include <vector>

namespace executorch {
namespace extension {
namespace llm {

HfTokenizer::HfTokenizer() : Tokenizer() {}

HfTokenizer::~HfTokenizer() {}

/// Loads the tokenizer artifact at `tokenizer_path`.
/// Stub: currently accepts any path without reading it.
/// TODO: Implement actual loading logic (parse the HF tokenizer.json).
::executorch::runtime::Error HfTokenizer::load(
    const std::string& tokenizer_path) {
  (void)tokenizer_path; // unused until real loading is implemented
  return ::executorch::runtime::Error::Ok;
}

/// Encodes `input` into token ids, optionally adding `bos`/`eos` markers.
/// Stub: currently returns an empty token list for every input.
/// TODO: Implement actual encoding logic.
::executorch::runtime::Result<std::vector<uint64_t>> HfTokenizer::encode(
    const std::string& input,
    int8_t bos,
    int8_t eos) const {
  (void)input;
  (void)bos;
  (void)eos;
  std::vector<uint64_t> tokens;
  return ::executorch::runtime::Result<std::vector<uint64_t>>(tokens);
}

/// Decodes `token` (given the preceding `prev_token`) back into text.
/// Stub: currently returns an empty string for every token.
/// TODO: Implement actual decoding logic.
::executorch::runtime::Result<std::string> HfTokenizer::decode(
    uint64_t prev_token,
    uint64_t token) const {
  (void)prev_token;
  (void)token;
  std::string decoded_string;
  return ::executorch::runtime::Result<std::string>(decoded_string);
}

} // namespace llm
} // namespace extension
} // namespace executorch
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/extension/llm/tokenizer/tokenizer.h>

namespace executorch {
namespace extension {
namespace llm {

/// Tokenizer backed by a Hugging Face `tokenizer.json` artifact.
/// Currently a stub; see hf_tokenizer.cpp for the (placeholder)
/// implementations of the members declared here.
class ET_EXPERIMENTAL HfTokenizer : public Tokenizer {
 public:
  explicit HfTokenizer();
  ~HfTokenizer() override;

  /// Loads the tokenizer model from `tokenizer_path`.
  /// @return Error::Ok on success.
  ::executorch::runtime::Error load(const std::string& tokenizer_path) override;

  /// Encodes `input` into token ids; `bos`/`eos` select whether
  /// begin-/end-of-sequence markers are added (semantics inherited from
  /// the Tokenizer base interface).
  ::executorch::runtime::Result<std::vector<uint64_t>>
  encode(const std::string& input, int8_t bos, int8_t eos) const override;

  /// Decodes `token` into text, with `prev_token` available for
  /// context-sensitive decoding (e.g. spacing between pieces).
  ::executorch::runtime::Result<std::string> decode(
      uint64_t prev_token,
      uint64_t token) const override;
};

} // namespace llm
} // namespace extension
} // namespace executorch

extension/llm/tokenizer/targets.bzl

+17
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,20 @@ def define_common_targets():
9595
"re2",
9696
],
9797
)
98+
99+
# Stub Hugging Face (tokenizer.json) tokenizer library.
# Exports hf_tokenizer.h to clients and compiles the placeholder
# implementation in hf_tokenizer.cpp.
runtime.cxx_library(
    name = "hf_tokenizer",
    srcs = [
        "hf_tokenizer.cpp",
    ],
    exported_headers = [
        "hf_tokenizer.h",
    ],
    exported_deps = [
        # Base Tokenizer interface that HfTokenizer derives from.
        ":tokenizer_header",
        # Error/Result vocabulary types used in the public signatures.
        "//executorch/runtime/core:core",
    ],
    visibility = [
        "@EXECUTORCH_CLIENTS",
    ],
)

0 commit comments

Comments (0)