
Commit a536432

Add stub impl of json tokenizer to llama runner
1 parent 3230900 · commit a536432

6 files changed: +123 −12


examples/models/llama/runner/CMakeLists.txt

+4
@@ -47,6 +47,10 @@ list(
 )
 list(APPEND _llama_runner__srcs
   ${CMAKE_CURRENT_SOURCE_DIR}/../tokenizer/llama_tiktoken.cpp
+)
+list(
+  APPEND _llama_runner__srcs
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizer/hf_tokenizer.cpp
 )
 
 if(CMAKE_TOOLCHAIN_IOS

examples/models/llama/runner/runner.cpp

+26 −12
@@ -17,6 +17,7 @@
 
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
+#include <executorch/extension/llm/tokenizer/hf_tokenizer.h>
 
 namespace example {
 
@@ -75,20 +76,33 @@ Error Runner::load() {
     return Error::Ok;
   }
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
-  // load tokenizer. Assuming tiktoken is the default tokenizer
+  // Load tokenizer.
   tokenizer_ = nullptr;
-  tokenizer_ = get_tiktoken_for_llama();
-  Error err = tokenizer_->load(tokenizer_path_);
-  // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-  // fallback to BPE tokenizer.
-  if (err == Error::InvalidArgument) {
-    ET_LOG(
-        Info,
-        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    tokenizer_ = std::make_unique<llm::BPETokenizer>();
+  // Check if tokenizer_path_ ends with ".json".
+  if (tokenizer_path_.size() >= 5 &&
+      tokenizer_path_.compare(tokenizer_path_.size() - 5, 5, ".json") == 0) {
+    tokenizer_ = std::make_unique<llm::HfTokenizer>();
     tokenizer_->load(tokenizer_path_);
+    ET_LOG(
+        Info, "Loaded tokenizer %s as HF tokenizer", tokenizer_path_.c_str());
+  } else {
+    // Else assume TikToken is the default tokenizer, using BPE as a fallback.
+    tokenizer_ = get_tiktoken_for_llama();
+    Error err = tokenizer_->load(tokenizer_path_);
+    if (err == Error::InvalidArgument) {
+      tokenizer_.reset();
+      tokenizer_ = std::make_unique<llm::BPETokenizer>();
+      tokenizer_->load(tokenizer_path_);
+      ET_LOG(
+          Info,
+          "Loaded tokenizer %s as BPE tokenizer",
+          tokenizer_path_.c_str());
+    } else {
+      ET_LOG(
+          Info,
+          "Loaded tokenizer %s as TikToken tokenizer",
+          tokenizer_path_.c_str());
+    }
   }
 
   ET_LOG(Info, "Reading metadata from model");
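Note that the dispatch above keys entirely on the file suffix: a path ending in ".json" is routed to HfTokenizer, and anything else falls through to the existing tiktoken-then-BPE path. For reference, a minimal standalone sketch of the same suffix check; the ends_with helper name is hypothetical (std::string::ends_with only exists from C++20 on, and the runner inlines the check instead):

#include <string>

// Hypothetical helper mirroring the suffix check in Runner::load():
// true when `path` ends with `suffix`, e.g. ends_with(path, ".json").
bool ends_with(const std::string& path, const std::string& suffix) {
  return path.size() >= suffix.size() &&
      path.compare(path.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// ends_with("tokenizer.json", ".json")  -> true  (HfTokenizer branch)
// ends_with("tokenizer.model", ".json") -> false (tiktoken/BPE branch)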

examples/models/llama/runner/targets.bzl

+1
@@ -49,6 +49,7 @@ def define_common_targets():
             "//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
             "//executorch/examples/models/llama/tokenizer:tiktoken",
             "//executorch/extension/llm/tokenizer:bpe_tokenizer",
+            "//executorch/extension/llm/tokenizer:hf_tokenizer",
         ] + (_get_operator_lib(aten)) + ([
             # Vulkan API currently cannot build on some platforms (e.g. Apple, FBCODE)
             # Therefore enable it explicitly for now to avoid failing tests
extension/llm/tokenizer/hf_tokenizer.cpp (new file)

+41
@@ -0,0 +1,41 @@
+#include <executorch/extension/llm/tokenizer/hf_tokenizer.h>
+
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/result.h>
+#include <string>
+#include <vector>
+
+using ::executorch::runtime::Error;
+using ::executorch::runtime::Result;
+
+namespace executorch {
+namespace extension {
+namespace llm {
+
+HfTokenizer::~HfTokenizer() {}
+
+Error HfTokenizer::load(const std::string& tokenizer_path) {
+  // Stub implementation for loading the tokenizer.
+  // TODO: Implement actual loading logic.
+  return ::executorch::runtime::Error::Ok;
+}
+
+Result<std::vector<uint64_t>>
+HfTokenizer::encode(const std::string& input, int8_t bos, int8_t eos) const {
+  // Stub implementation for encoding.
+  // TODO: Implement actual encoding logic.
+  std::vector<uint64_t> tokens;
+  return ::executorch::runtime::Result<std::vector<uint64_t>>(tokens);
+}
+
+Result<std::string> HfTokenizer::decode(uint64_t prev_token, uint64_t token)
+    const {
+  // Stub implementation for decoding.
+  // TODO: Implement actual decoding logic.
+  std::string decoded_string;
+  return ::executorch::runtime::Result<std::string>(decoded_string);
+}
+
+} // namespace llm
+} // namespace extension
+} // namespace executorch
extension/llm/tokenizer/hf_tokenizer.h (new file)

+34
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/extension/llm/tokenizer/tokenizer.h>
+
+namespace executorch {
+namespace extension {
+namespace llm {
+
+class ET_EXPERIMENTAL HfTokenizer : public Tokenizer {
+ public:
+  explicit HfTokenizer(){};
+  ~HfTokenizer() override;
+
+  ::executorch::runtime::Error load(const std::string& tokenizer_path) override;
+
+  ::executorch::runtime::Result<std::vector<uint64_t>>
+  encode(const std::string& input, int8_t bos, int8_t eos) const override;
+
+  ::executorch::runtime::Result<std::string> decode(
+      uint64_t prev_token,
+      uint64_t token) const override;
+};
+
+} // namespace llm
+} // namespace extension
+} // namespace executorch
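Since HfTokenizer overrides the common Tokenizer interface, callers can already compile against it even while the body is a stub. A minimal usage sketch under the stub semantics; the driver code and the "tokenizer.json" path are hypothetical, not part of this commit:

#include <executorch/extension/llm/tokenizer/hf_tokenizer.h>

#include <memory>

using ::executorch::extension::llm::HfTokenizer;
using ::executorch::extension::llm::Tokenizer;

int main() {
  // Hold the concrete tokenizer through the base-class pointer, as the
  // llama runner does.
  std::unique_ptr<Tokenizer> tok = std::make_unique<HfTokenizer>();

  // The stub load() accepts any path and always returns Error::Ok.
  auto err = tok->load("tokenizer.json");
  (void)err;

  // encode() takes bos/eos arguments per the signature above; the stub
  // returns an empty token vector until the TODO is implemented.
  auto tokens = tok->encode("hello world", /*bos=*/1, /*eos=*/0);

  // decode() maps a (prev_token, token) pair to text; the stub returns "".
  auto piece = tok->decode(/*prev_token=*/0, /*token=*/0);
  (void)tokens;
  (void)piece;
  return 0;
}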

extension/llm/tokenizer/targets.bzl

+17
@@ -95,3 +95,20 @@ def define_common_targets():
             "re2",
         ],
     )
+
+    runtime.cxx_library(
+        name = "hf_tokenizer",
+        srcs = [
+            "hf_tokenizer.cpp",
+        ],
+        exported_headers = [
+            "hf_tokenizer.h",
+        ],
+        exported_deps = [
+            ":tokenizer_header",
+            "//executorch/runtime/core:core",
+        ],
+        visibility = [
+            "@EXECUTORCH_CLIENTS",
+        ],
+    )
