diff --git a/dist_run.py b/dist_run.py
index 2b4ab67cb..d1da19558 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -18,6 +18,7 @@
 import torch
 import torch.distributed as dist
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from tokenizer.tokenizer_type import TokenizerType
 from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 
 # TODO - these are not distributed specific, consider moving to new package
@@ -64,11 +65,6 @@
 }
 
 
-class TokenizerType(Enum):
-    Tiktoken = auto()
-    SentencePiece = auto()
-
-
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -122,9 +118,9 @@ def _build_chat_tokenizer(
     )
     # set global variable _tokenizer_type
     if isinstance(tokenizer, TiktokenTokenizer):
-        _tokenizer_type = TokenizerType.Tiktoken
+        _tokenizer_type = TokenizerType.TIKTOKEN
     elif isinstance(tokenizer, SentencePieceProcessor):
-        _tokenizer_type = TokenizerType.SentencePiece
+        _tokenizer_type = TokenizerType.SENTENCEPIECE
     else:
         raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
 
@@ -575,11 +571,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
         res_list = res.tolist()
-        if _tokenizer_type == TokenizerType.Tiktoken:
+        if _tokenizer_type == TokenizerType.TIKTOKEN:
             # For TiktokenTokenizer, we need to decode prompt by prompt.
             # TODO: is there a better way to do this?
             responses = [tokenizer.decode(sequence) for sequence in res_list]
-        elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
+        elif _tokenizer_type == TokenizerType.SENTENCEPIECE:  # SentencePieceProcessor
             # For SentencePieceProcessor, we can decode the entire 2D list at once.
             responses = tokenizer.decode(res_list)
         else:
diff --git a/tokenizer/tokenizer_type.py b/tokenizer/tokenizer_type.py
new file mode 100644
index 000000000..0cdbf5e5d
--- /dev/null
+++ b/tokenizer/tokenizer_type.py
@@ -0,0 +1,7 @@
+from enum import Enum
+
+class TokenizerType(Enum):
+    NONE = 0
+    TIKTOKEN = 1
+    SENTENCEPIECE = 2
+    HF_TOKENIZER = 3
\ No newline at end of file
diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 001ece603..1d9df9d1a 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -234,13 +234,7 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args
 
-
-class TokenizerType(Enum):
-    NONE = 0
-    TIKTOKEN = 1
-    SENTENCEPIECE = 2
-    HF_TOKENIZER = 3
-
+from tokenizer.tokenizer_type import TokenizerType
 
 @dataclass
 class TokenizerArgs:
@@ -276,15 +270,6 @@ def __post_init__(self):
         except:
             pass
 
-    def is_tiktoken(self) -> bool:
-        return self.tokenizer_type == TokenizerType.TIKTOKEN
-
-    def is_sentencepiece(self) -> bool:
-        return self.tokenizer_type == TokenizerType.SENTENCEPIECE
-
-    def is_hf_tokenizer(self) -> bool:
-        return self.tokenizer_type == TokenizerType.HF_TOKENIZER
-
     def validate_model(
         self,
         model: Optional[Model],
@@ -296,9 +281,9 @@ def validate_model(
         if self.tokenizer_type == TokenizerType.NONE:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
-        is_tiktoken = self.is_tiktoken()
-        is_sentencepiece = self.is_sentencepiece()
-        is_hf_tokenizer = self.is_hf_tokenizer()
+        is_tiktoken = self.tokenizer_type == TokenizerType.TIKTOKEN
+        is_sentencepiece = self.tokenizer_type == TokenizerType.SENTENCEPIECE
+        is_hf_tokenizer = self.tokenizer_type == TokenizerType.HF_TOKENIZER
 
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
diff --git a/torchchat/export.py b/torchchat/export.py
index 28c9bdfec..5ac840401 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -7,6 +7,7 @@
 import os
 from typing import Dict, Optional
 
+from tokenizer.tokenizer_type import TokenizerType
 import torch
 import torch._inductor
 import torch.nn as nn
@@ -482,7 +483,7 @@ def main(args):
     if tokenizer_args is None:
         tokenizer_type = "0"
-    elif tokenizer_args.is_sentencepiece():
+    elif tokenizer_args.tokenizer_type == TokenizerType.SENTENCEPIECE:
         tokenizer_type = "2"  # Corresponding to llama2
     else:
         tokenizer_type = "3"  # Corresponding to llama3
 
diff --git a/torchchat/generate.py b/torchchat/generate.py
index 5ae7ecfad..c18c6f466 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -27,6 +27,7 @@
 import torch.multiprocessing as mp
 from PIL import Image
 
+from tokenizer.tokenizer_type import TokenizerType
 from torch._C import _SDPBackend as SDPBackend
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 
@@ -361,14 +362,14 @@ def __init__(
         # must use tiktokenizer.
         # Piggy backing off of this flag then for now to identify llama3
         # without prompting user.
-        self.is_llama3_model = self.tokenizer_args.is_tiktoken()
+        self.is_llama3_model = self.tokenizer_args.tokenizer_type == TokenizerType.TIKTOKEN
         if self.is_llama3_model:
             self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
             if generator_args.chat_mode:
                 logger.debug(
                     "Llama3 model detected in chat mode. Using updated sentence schemas"
                 )
-        elif self.tokenizer_args.is_hf_tokenizer():
+        elif self.tokenizer_args.tokenizer_type == TokenizerType.HF_TOKENIZER:
             if not self.tokenizer.has_chat_template():
                 raise ValueError("Tokenizer must have a chat template")
             self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer)
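
Note: below is a minimal sketch (not part of the patch) of how call sites are expected to use the shared enum after this refactor, assuming access to a TokenizerArgs-like object that carries a `tokenizer_type` field; direct enum comparisons replace the removed is_tiktoken()/is_sentencepiece()/is_hf_tokenizer() helpers. The `describe_tokenizer` function is a hypothetical helper for illustration only.

    from tokenizer.tokenizer_type import TokenizerType

    def describe_tokenizer(tokenizer_type: TokenizerType) -> str:
        # Hypothetical helper: compare against the shared enum directly
        # instead of calling per-type helper methods on TokenizerArgs.
        if tokenizer_type == TokenizerType.TIKTOKEN:
            return "tiktoken tokenizer (llama3-style)"
        elif tokenizer_type == TokenizerType.SENTENCEPIECE:
            return "SentencePiece tokenizer (llama2-style)"
        elif tokenizer_type == TokenizerType.HF_TOKENIZER:
            return "Hugging Face tokenizer"
        return "no tokenizer found"

    # e.g., describe_tokenizer(tokenizer_args.tokenizer_type)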