Unified tokenizer type onboarding #1540

Open · wants to merge 7 commits into base: main
14 changes: 5 additions & 9 deletions dist_run.py
@@ -18,6 +18,7 @@
 import torch
 import torch.distributed as dist
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from tokenizer.tokenizer_type import TokenizerType
 from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 
 # TODO - these are not distributed specific, consider moving to new package
@@ -64,11 +65,6 @@
 }
 
 
-class TokenizerType(Enum):
-    Tiktoken = auto()
-    SentencePiece = auto()
-
-
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -122,9 +118,9 @@ def _build_chat_tokenizer(
     )
     # set global variable _tokenizer_type
     if isinstance(tokenizer, TiktokenTokenizer):
-        _tokenizer_type = TokenizerType.Tiktoken
+        _tokenizer_type = TokenizerType.TIKTOKEN
     elif isinstance(tokenizer, SentencePieceProcessor):
-        _tokenizer_type = TokenizerType.SentencePiece
+        _tokenizer_type = TokenizerType.SENTENCEPIECE
     else:
         raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
 
@@ -575,11 +571,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # token ids. Thus cat'ing along dim 1.
     res = torch.cat(res, dim=1)
     res_list = res.tolist()
-    if _tokenizer_type == TokenizerType.Tiktoken:
+    if _tokenizer_type == TokenizerType.TIKTOKEN:
         # For TiktokenTokenizer, we need to decode prompt by prompt.
         # TODO: is there a better way to do this?
         responses = [tokenizer.decode(sequence) for sequence in res_list]
-    elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
+    elif _tokenizer_type == TokenizerType.SENTENCEPIECE:  # SentencePieceProcessor
         # For SentencePieceProcessor, we can decode the entire 2D list at once.
         responses = tokenizer.decode(res_list)
     else:
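For readers outside the diff: a minimal, self-contained sketch of the decode dispatch above. The stub enum and the decode_batch helper are illustrative stand-ins, not code from this PR; the per-sequence versus whole-batch split mirrors the two branches in dist_run.py.

    from enum import Enum

    class TokenizerType(Enum):
        TIKTOKEN = 1
        SENTENCEPIECE = 2

    def decode_batch(tokenizer, tokenizer_type, res_list):
        # res_list is a 2D list of token ids: one inner list per prompt.
        if tokenizer_type == TokenizerType.TIKTOKEN:
            # Tiktoken-style tokenizers decode one sequence at a time.
            return [tokenizer.decode(seq) for seq in res_list]
        if tokenizer_type == TokenizerType.SENTENCEPIECE:
            # SentencePieceProcessor.decode accepts the whole 2D list at once.
            return tokenizer.decode(res_list)
        raise ValueError(f"Unknown tokenizer type: {tokenizer_type}")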
7 changes: 7 additions & 0 deletions tokenizer/tokenizer_type.py
@@ -0,0 +1,7 @@
+from enum import Enum
+
+class TokenizerType(Enum):
+    NONE = 0
+    TIKTOKEN = 1
+    SENTENCEPIECE = 2
+    HF_TOKENIZER = 3
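Since TokenizerType is a plain Enum, call sites compare members directly. A quick usage sketch, assuming the repository root is on sys.path so the new module imports exactly as in the diffs below:

    from tokenizer.tokenizer_type import TokenizerType

    t = TokenizerType.SENTENCEPIECE
    assert t == TokenizerType.SENTENCEPIECE  # Enum members compare by identity
    assert t != TokenizerType.TIKTOKEN
    assert (t.name, t.value) == ("SENTENCEPIECE", 2)
    # NONE serves as the "no tokenizer found" sentinel checked in builder.py below.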
23 changes: 4 additions & 19 deletions torchchat/cli/builder.py
@@ -234,13 +234,7 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args
 
-
-class TokenizerType(Enum):
-    NONE = 0
-    TIKTOKEN = 1
-    SENTENCEPIECE = 2
-    HF_TOKENIZER = 3
-
+from tokenizer.tokenizer_type import TokenizerType
 
 @dataclass
 class TokenizerArgs:
@@ -276,15 +270,6 @@ def __post_init__(self):
         except:
             pass
 
-    def is_tiktoken(self) -> bool:
-        return self.tokenizer_type == TokenizerType.TIKTOKEN
-
-    def is_sentencepiece(self) -> bool:
-        return self.tokenizer_type == TokenizerType.SENTENCEPIECE
-
-    def is_hf_tokenizer(self) -> bool:
-        return self.tokenizer_type == TokenizerType.HF_TOKENIZER
-
     def validate_model(
         self,
         model: Optional[Model],
@@ -296,9 +281,9 @@ def validate_model(
         if self.tokenizer_type == TokenizerType.NONE:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
-        is_tiktoken = self.is_tiktoken()
-        is_sentencepiece = self.is_sentencepiece()
-        is_hf_tokenizer = self.is_hf_tokenizer()
+        is_tiktoken = self.tokenizer_type == TokenizerType.TIKTOKEN
+        is_sentencepiece = self.tokenizer_type == TokenizerType.SENTENCEPIECE
+        is_hf_tokenizer = self.tokenizer_type == TokenizerType.HF_TOKENIZER
 
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
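Taken together, the three booleans above let validate_model cross-check the tokenizer against the model config. A condensed sketch of that check, assuming the use_tiktoken/use_hf_tokenizer flags visible in the context lines; the helper names, the SentencePiece fallback, and the error wording are illustrative, not the repository's:

    from tokenizer.tokenizer_type import TokenizerType

    def expected_tokenizer_type(use_tiktoken: bool, use_hf_tokenizer: bool) -> TokenizerType:
        # Derive the tokenizer family the model config asks for (illustrative helper).
        if use_tiktoken:
            return TokenizerType.TIKTOKEN
        if use_hf_tokenizer:
            return TokenizerType.HF_TOKENIZER
        return TokenizerType.SENTENCEPIECE  # assumed default when neither flag is set

    def check_match(tokenizer_type: TokenizerType, use_tiktoken: bool, use_hf_tokenizer: bool) -> None:
        expected = expected_tokenizer_type(use_tiktoken, use_hf_tokenizer)
        if tokenizer_type != expected:
            raise RuntimeError(f"model expects a {expected.name} tokenizer, got {tokenizer_type.name}")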
3 changes: 2 additions & 1 deletion torchchat/export.py
@@ -7,6 +7,7 @@
 import os
 from typing import Dict, Optional
 
+from tokenizer.tokenizer_type import TokenizerType
 import torch
 import torch._inductor
 import torch.nn as nn
@@ -482,7 +483,7 @@ def main(args):
 
     if tokenizer_args is None:
         tokenizer_type = "0"
-    elif tokenizer_args.is_sentencepiece():
+    elif tokenizer_args.tokenizer_type == TokenizerType.SENTENCEPIECE:
         tokenizer_type = "2"  # Corresponding to llama2
     else:
         tokenizer_type = "3"  # Corresponding to llama3
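The export path flattens the enum into the string codes "0"/"2"/"3" written into the artifact. A small sketch of that mapping; the function name is hypothetical, while the codes come straight from the branch above:

    from typing import Optional

    from tokenizer.tokenizer_type import TokenizerType

    def tokenizer_type_code(tokenizer_type: Optional[TokenizerType]) -> str:
        # Hypothetical helper mirroring main()'s branch in torchchat/export.py.
        if tokenizer_type is None:
            return "0"  # no tokenizer args at all
        if tokenizer_type == TokenizerType.SENTENCEPIECE:
            return "2"  # corresponds to llama2
        return "3"      # corresponds to llama3 (tiktoken, HF)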
5 changes: 3 additions & 2 deletions torchchat/generate.py
@@ -27,6 +27,7 @@
 import torch.multiprocessing as mp
 
 from PIL import Image
+from tokenizer.tokenizer_type import TokenizerType
 from torch._C import _SDPBackend as SDPBackend
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 
@@ -361,14 +362,14 @@ def __init__(
         # must use tiktokenizer.
         # Piggy backing off of this flag then for now to identify llama3
         # without prompting user.
-        self.is_llama3_model = self.tokenizer_args.is_tiktoken()
+        self.is_llama3_model = self.tokenizer_args.tokenizer_type == TokenizerType.TIKTOKEN
         if self.is_llama3_model:
             self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
             if generator_args.chat_mode:
                 logger.debug(
                     "Llama3 model detected in chat mode. Using updated sentence schemas"
                 )
-        elif self.tokenizer_args.is_hf_tokenizer():
+        elif self.tokenizer_args.tokenizer_type == TokenizerType.HF_TOKENIZER:
             if not self.tokenizer.has_chat_template():
                 raise ValueError("Tokenizer must have a chat template")
             self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer)
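Formatter selection in generate.py now keys off the same enum. A minimal sketch under the assumption that Llama3ChatFormatter, HFTokenizerChatFormatter, and has_chat_template behave as the diff shows; those names are taken from the context lines, not defined here, and the fallthrough return is an assumption rather than the PR's actual else branch:

    from tokenizer.tokenizer_type import TokenizerType

    def pick_chat_formatter(tokenizer, tokenizer_type):
        # Mirrors Generator.__init__: tiktoken currently implies a llama3-family model.
        if tokenizer_type == TokenizerType.TIKTOKEN:
            return Llama3ChatFormatter(tokenizer)
        if tokenizer_type == TokenizerType.HF_TOKENIZER:
            if not tokenizer.has_chat_template():
                raise ValueError("Tokenizer must have a chat template")
            return HFTokenizerChatFormatter(tokenizer)
        # Other tokenizer types use whatever default the surrounding code provides
        # (not shown in this diff).
        return None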