From a94afd97f966a7f9cdc136692b3581266207c517 Mon Sep 17 00:00:00 2001
From: srikary12 <121927567+srikary12@users.noreply.github.com>
Date: Sun, 11 May 2025 13:02:07 +0530
Subject: [PATCH 1/3] Unified tokenizer type onboarding

---
 tokenizer/tokenizer_type.py | 16 ++++++++++++++++
 torchchat/cli/builder.py    | 23 +++++------------------
 torchchat/generate.py       |  4 ++--
 3 files changed, 23 insertions(+), 20 deletions(-)
 create mode 100644 tokenizer/tokenizer_type.py

diff --git a/tokenizer/tokenizer_type.py b/tokenizer/tokenizer_type.py
new file mode 100644
index 000000000..8448d6ff0
--- /dev/null
+++ b/tokenizer/tokenizer_type.py
@@ -0,0 +1,16 @@
+from enum import Enum
+
+class TokenizerType(Enum):
+    NONE = 0
+    TIKTOKEN = 1
+    SENTENCEPIECE = 2
+    HF_TOKENIZER = 3
+
+    def is_tiktoken(self):
+        return self == TokenizerType.TIKTOKEN
+    def is_sentencepiece(self):
+        return self == TokenizerType.SENTENCEPIECE
+    def is_hf_tokenizer(self):
+        return self == TokenizerType.HF_TOKENIZER
+    def is_none(self):
+        return self == TokenizerType.NONE
\ No newline at end of file
diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 0661d08f5..7db0c4e93 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -238,11 +238,7 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args
 
-class TokenizerType(Enum):
-    NONE = 0
-    TIKTOKEN = 1
-    SENTENCEPIECE = 2
-    HF_TOKENIZER = 3
+from tokenizer.tokenizer_type import TokenizerType
 
 @dataclass
 class TokenizerArgs:
@@ -278,15 +274,6 @@ def __post_init__(self):
         except:
             pass
 
-    def is_tiktoken(self) -> bool:
-        return self.tokenizer_type == TokenizerType.TIKTOKEN
-
-    def is_sentencepiece(self) -> bool:
-        return self.tokenizer_type == TokenizerType.SENTENCEPIECE
-
-    def is_hf_tokenizer(self) -> bool:
-        return self.tokenizer_type == TokenizerType.HF_TOKENIZER
-
     def validate_model(
         self,
         model: Optional[Model],
@@ -295,12 +282,12 @@ def validate_model(
         if model is None:
             return
 
-        if self.tokenizer_type == TokenizerType.NONE:
+        if self.tokenizer_type.is_none():
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
-        is_tiktoken = self.is_tiktoken()
-        is_sentencepiece = self.is_sentencepiece()
-        is_hf_tokenizer = self.is_hf_tokenizer()
+        is_tiktoken = self.tokenizer_type.is_tiktoken()
+        is_sentencepiece = self.tokenizer_type.is_sentencepiece()
+        is_hf_tokenizer = self.tokenizer_type.is_hf_tokenizer()
 
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
diff --git a/torchchat/generate.py b/torchchat/generate.py
index 4f90b316f..5c906d381 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -365,14 +365,14 @@ def __init__(
         # must use tiktokenizer.
         # Piggy backing off of this flag then for now to identify llama3
         # without prompting user.
-        self.is_llama3_model = self.tokenizer_args.is_tiktoken()
+        self.is_llama3_model = self.tokenizer_args.tokenizer_type.is_tiktoken()
         if self.is_llama3_model:
             self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
             if generator_args.chat_mode:
                 logger.debug(
                     "Llama3 model detected in chat mode. Using updated sentence schemas"
                 )
-        elif self.tokenizer_args.is_hf_tokenizer():
+        elif self.tokenizer_args.tokenizer_type.is_hf_tokenizer():
             if not self.tokenizer.has_chat_template():
                 raise ValueError("Tokenizer must have a chat template")
             self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer)

From 3d2c8d6f22827bc874221f0bae0007f9568c3c82 Mon Sep 17 00:00:00 2001
From: srikary12 <121927567+srikary12@users.noreply.github.com>
Date: Wed, 14 May 2025 20:23:27 +0530
Subject: [PATCH 2/3] Refactor tokenizer type handling to use Enum directly
 and remove redundant methods

---
 dist_run.py                 | 14 +++++---------
 tokenizer/tokenizer_type.py | 11 +----------
 torchchat/cli/builder.py    |  8 ++++----
 torchchat/generate.py       |  5 +++--
 4 files changed, 13 insertions(+), 25 deletions(-)

diff --git a/dist_run.py b/dist_run.py
index 2b4ab67cb..d1da19558 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -18,6 +18,7 @@
 import torch
 import torch.distributed as dist
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from tokenizer.tokenizer_type import TokenizerType
 from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 
 # TODO - these are not distributed specific, consider moving to new package
@@ -64,11 +65,6 @@
 }
 
 
-class TokenizerType(Enum):
-    Tiktoken = auto()
-    SentencePiece = auto()
-
-
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -122,9 +118,9 @@ def _build_chat_tokenizer(
     )
     # set global variable _tokenizer_type
     if isinstance(tokenizer, TiktokenTokenizer):
-        _tokenizer_type = TokenizerType.Tiktoken
+        _tokenizer_type = TokenizerType.TIKTOKEN
     elif isinstance(tokenizer, SentencePieceProcessor):
-        _tokenizer_type = TokenizerType.SentencePiece
+        _tokenizer_type = TokenizerType.SENTENCEPIECE
     else:
         raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
 
@@ -575,11 +571,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
             # token ids. Thus cat'ing along dim 1.
             res = torch.cat(res, dim=1)
             res_list = res.tolist()
-            if _tokenizer_type == TokenizerType.Tiktoken:
+            if _tokenizer_type == TokenizerType.TIKTOKEN:
                 # For TiktokenTokenizer, we need to decode prompt by prompt.
                 # TODO: is there a better way to do this?
                 responses = [tokenizer.decode(sequence) for sequence in res_list]
-            elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
+            elif _tokenizer_type == TokenizerType.SENTENCEPIECE:  # SentencePieceProcessor
                 # For SentencePieceProcessor, we can decode the entire 2D list at once.
                 responses = tokenizer.decode(res_list)
             else:
diff --git a/tokenizer/tokenizer_type.py b/tokenizer/tokenizer_type.py
index 8448d6ff0..0cdbf5e5d 100644
--- a/tokenizer/tokenizer_type.py
+++ b/tokenizer/tokenizer_type.py
@@ -4,13 +4,4 @@ class TokenizerType(Enum):
     NONE = 0
     TIKTOKEN = 1
     SENTENCEPIECE = 2
-    HF_TOKENIZER = 3
-
-    def is_tiktoken(self):
-        return self == TokenizerType.TIKTOKEN
-    def is_sentencepiece(self):
-        return self == TokenizerType.SENTENCEPIECE
-    def is_hf_tokenizer(self):
-        return self == TokenizerType.HF_TOKENIZER
-    def is_none(self):
-        return self == TokenizerType.NONE
\ No newline at end of file
+    HF_TOKENIZER = 3
\ No newline at end of file
diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 7db0c4e93..048c15287 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -282,12 +282,12 @@ def validate_model(
         if model is None:
             return
 
-        if self.tokenizer_type.is_none():
+        if self.tokenizer_type == TokenizerType.NONE:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
-        is_tiktoken = self.tokenizer_type.is_tiktoken()
-        is_sentencepiece = self.tokenizer_type.is_sentencepiece()
-        is_hf_tokenizer = self.tokenizer_type.is_hf_tokenizer()
+        is_tiktoken = self.tokenizer_type == TokenizerType.TIKTOKEN
+        is_sentencepiece = self.tokenizer_type == TokenizerType.SENTENCEPIECE
+        is_hf_tokenizer = self.tokenizer_type == TokenizerType.HF_TOKENIZER
 
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
diff --git a/torchchat/generate.py b/torchchat/generate.py
index 5c906d381..13aadbb79 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -31,6 +31,7 @@
 from PIL import Image
 
 # torchtune model definition dependencies
+from tokenizer.tokenizer_type import TokenizerType
 from torchtune.data import Message, padded_collate_tiled_images_and_mask
 
 from torchtune.generation import sample as tune_sample
@@ -365,14 +366,14 @@ def __init__(
         # must use tiktokenizer.
         # Piggy backing off of this flag then for now to identify llama3
         # without prompting user.
-        self.is_llama3_model = self.tokenizer_args.tokenizer_type.is_tiktoken()
+        self.is_llama3_model = self.tokenizer_args.tokenizer_type == TokenizerType.TIKTOKEN
         if self.is_llama3_model:
             self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
             if generator_args.chat_mode:
                 logger.debug(
                     "Llama3 model detected in chat mode. Using updated sentence schemas"
                 )
-        elif self.tokenizer_args.tokenizer_type.is_hf_tokenizer():
+        elif self.tokenizer_args.tokenizer_type == TokenizerType.HF_TOKENIZER:
             if not self.tokenizer.has_chat_template():
                 raise ValueError("Tokenizer must have a chat template")
             self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer)

From 22fdd564deec0ebaf35bdc333ffbb1b27f3e9369 Mon Sep 17 00:00:00 2001
From: srikary12 <121927567+srikary12@users.noreply.github.com>
Date: Tue, 27 May 2025 08:13:48 +0530
Subject: [PATCH 3/3] Modify export to use unified tokenizer type

---
 torchchat/export.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchchat/export.py b/torchchat/export.py
index 28c9bdfec..5ac840401 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -7,6 +7,7 @@
 import os
 from typing import Dict, Optional
+from tokenizer.tokenizer_type import TokenizerType
 
 import torch
 import torch._inductor
 import torch.nn as nn
@@ -482,7 +483,7 @@ def main(args):
 
     if tokenizer_args is None:
         tokenizer_type = "0"
-    elif tokenizer_args.is_sentencepiece():
+    elif tokenizer_args.tokenizer_type == TokenizerType.SENTENCEPIECE:
         tokenizer_type = "2"  # Corresponding to llama2
     else:
         tokenizer_type = "3"  # Corresponding to llama3
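
Illustrative sketch (not part of the patches): after this series, call sites
branch on the unified enum directly rather than going through per-class
is_*() helper methods. The snippet below is a minimal, hypothetical example
of that call-site pattern; the describe() helper is invented for
illustration, and the import assumes the tokenizer/tokenizer_type.py module
added in PATCH 1/3 is importable from the repo root.

    from tokenizer.tokenizer_type import TokenizerType

    # Plain enum comparisons replace the is_tiktoken()/is_sentencepiece()/
    # is_hf_tokenizer() helpers removed in PATCH 2/3.
    def describe(tokenizer_type: TokenizerType) -> str:
        if tokenizer_type == TokenizerType.NONE:
            return "no tokenizer configured"
        if tokenizer_type == TokenizerType.TIKTOKEN:
            return "tiktoken (llama3-family models)"
        if tokenizer_type == TokenizerType.SENTENCEPIECE:
            return "SentencePiece (llama2-family models)"
        return "Hugging Face tokenizers"

    print(describe(TokenizerType.TIKTOKEN))  # -> tiktoken (llama3-family models)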