Skip to content

Commit 7a536e8

Browse files
committed
Allow model to tokenize strings longer than context length and set add_bos. Closes #92
1 parent 8740ddc commit 7a536e8

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

llama_cpp/llama.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,9 @@ def __init__(
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

-    def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
+    def tokenize(
+        self, text: bytes, add_bos: bool = True
+    ) -> List[llama_cpp.llama_token]:
         """Tokenize a string.

         Args:
@@ -194,10 +196,22 @@ def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
             text,
             tokens,
             n_ctx,
-            llama_cpp.c_bool(True),
+            llama_cpp.c_bool(add_bos),
         )
         if int(n_tokens) < 0:
-            raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
+            n_tokens = abs(n_tokens)
+            tokens = (llama_cpp.llama_token * int(n_tokens))()
+            n_tokens = llama_cpp.llama_tokenize(
+                self.ctx,
+                text,
+                tokens,
+                llama_cpp.c_int(n_tokens),
+                llama_cpp.c_bool(add_bos),
+            )
+            if n_tokens < 0:
+                raise RuntimeError(
+                    f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
+                )
         return list(tokens[:n_tokens])

     def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:

llama_cpp/llama_cpp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def llama_tokenize(
     tokens,  # type: Array[llama_token]
     n_max_tokens: c_int,
     add_bos: c_bool,
-) -> c_int:
+) -> int:
     return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)

0 commit comments

Comments
 (0)