[WIP] HuggingFaceModelTokenizer #2723

Status: Open. Wants to merge 8 commits into main. Changes from all commits are shown below.
166 changes: 164 additions & 2 deletions torchtune/modules/transforms/tokenizers/_hf_tokenizer.py
@@ -5,10 +5,14 @@
# LICENSE file in the root directory of this source tree.

import json
from typing import Any, Dict, List, Optional, Tuple

import jinja2
from jinja2 import StrictUndefined

from tokenizers import Tokenizer
from torchtune.data import Message, truncate
from torchtune.modules.transforms.tokenizers._utils import BaseTokenizer, ModelTokenizer


class HuggingFaceBaseTokenizer(BaseTokenizer):
@@ -146,3 +150,161 @@ def decode(self, token_ids: List[int]) -> str:
str: The decoded string.
"""
return self.tokenizer.decode(token_ids)


def _infer_special_tokens_from_hf_config(config: Dict[str, Any]) -> List[str]:
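"""
Collect the string contents of all special tokens declared in a HF
tokenizer_config.json-style dict: the standard named tokens, the
additional_special_tokens list, and added_tokens_decoder entries flagged
as special. Entries may be plain strings or dicts with a "content" field.
For a hypothetical config {"bos_token": "<s>", "eos_token": {"content": "</s>"}}
this returns ["<s>", "</s>"] (order arbitrary, since a set de-duplicates).
"""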
special_tokens = set()

standard_keys = [
"bos_token",
"eos_token",
"pad_token",
"unk_token",
"sep_token",
"cls_token",
"mask_token",
]

for key in standard_keys:
if token := config.get(key):
if isinstance(token, str):
content = token
else:
content = token.get("content")

if content:
special_tokens.add(content)

for token in config.get("additional_special_tokens", []):
if isinstance(token, str):
content = token
else:
content = token.get("content")

if content:
special_tokens.add(content)

for token_info in config.get("added_tokens_decoder", {}).values():
if token_info.get("special", False):
if content := token_info.get("content"):
special_tokens.add(content)

return list(special_tokens)


class HuggingFaceModelTokenizer(ModelTokenizer):
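"""
ModelTokenizer that drives tokenization from the files shipped with a
Hugging Face checkpoint: HuggingFaceBaseTokenizer handles raw encoding and
decoding, while the chat_template from tokenizer_config.json is rendered
with Jinja2 to format messages before encoding.
"""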
def __init__(
self,
tokenizer_json_path: str,
*,
tokenizer_config_json_path: Optional[str] = None,
generation_config_path: Optional[str] = None,
truncation_type: str = "right",
):
self.base_tokenizer = HuggingFaceBaseTokenizer(
tokenizer_json_path=tokenizer_json_path,
tokenizer_config_json_path=tokenizer_config_json_path,
generation_config_path=generation_config_path,
)
Contributor commented on lines +204 to +208:
I know @joecummings had some thoughts on whether we should use a generic base_tokenizer instead of constraining this to HuggingFaceBaseTokenizer. I suspect the latter is better for making sure everything works together, but I know at least Qwen2Tokenizer still relies on the merges + vocab files instead of the tokenizer.json file (I alluded to this at the very bottom of #2706), so we should figure out whether this will work for that case.
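For reference, the merges + vocab layout can still be loaded through the tokenizers library, so a fallback seems conceivable. A minimal sketch, assuming hypothetical file paths (this is not what the PR currently implements):

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE

# Qwen2-style checkpoints ship vocab.json + merges.txt instead of tokenizer.json.
# BPE.from_file builds a BPE model from those two files; wrapping it in Tokenizer
# gives an encodable object, though normalizers/pre-tokenizers would still need
# to be configured to match the original tokenizer.
bpe = BPE.from_file("vocab.json", "merges.txt")  # hypothetical paths
hf_tokenizer = Tokenizer(bpe)
```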


# Contents of the tokenizer_config.json
config = self.base_tokenizer.config

self.special_tokens = _infer_special_tokens_from_hf_config(config)
self.top_level_variables = self.extract_top_level_variables(config)

_env = jinja2.Environment(undefined=StrictUndefined)

# raise_exception is sometimes used in HF chat_templates
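# e.g. a template may contain: {{ raise_exception("Only user and assistant roles are supported") }}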
_env.globals["raise_exception"] = self._raise_helper

self.template = _env.from_string(
self._get_token_from_config(config, "chat_template")
)
self.truncation_type = truncation_type

def _raise_helper(self, msg):
raise Exception(msg)

def _get_token_from_config(self, config: Dict[str, Any], key: str) -> str:
"""
HF BOS/EOS tokens are either stored as e.g. {'bos_token': 5}
or {'bos_token': {'content': 5, ...}}. This utility handles both.
"""
token = config.get(key)
if isinstance(token, dict):
if "content" not in token:
raise ValueError(f"Could not parse {key} from config")
token = token["content"]
else:
if not isinstance(token, str):
raise ValueError(f"Could not parse {key} from config")
return token

def extract_top_level_variables(self, config: Dict[str, Any]) -> Dict[str, Any]:
top_level = {}
for key, value in config.items():
if not isinstance(value, (dict, list)):
top_level[key] = value
return top_level

def tokenize_messages(
self,
messages: List[Message],
add_eos: bool = True,
max_seq_len: Optional[int] = None,
) -> Tuple[List[int], List[bool]]:
# This is hacky, but we need to handle the case where the chat template accesses special tokens as Jinja variables
special_tokens_mapping = {}
for token in self.special_tokens:
special_tokens_mapping[token] = self.base_tokenizer.encode(token)

tokenized_messages = []
mask = []
previous_tokens = []
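# Strategy: re-render the chat template on a growing prefix of messages and
# take the token delta relative to the previous render; the delta is
# attributed to the current message so its mask value covers an exact span.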

for i, message in enumerate(messages):
current_messages = [
{"role": m.role, "content": m.content[0]["content"]}
for m in messages[: i + 1]
]

rendered = self.template.render(
messages=current_messages,
add_generation_prompt=add_eos if i == len(messages) - 1 else False,
**special_tokens_mapping, # We assume that the naming is consistent
**self.top_level_variables,
)

current_tokens = self.base_tokenizer.encode(rendered, add_eos=False)

delta = current_tokens[len(previous_tokens) :]
previous_tokens = current_tokens

# The delta always contributes tokens; masking is tracked per token span.
tokenized_messages.extend(delta)
mask.extend([message.masked] * len(delta))

if add_eos and self.base_tokenizer.eos_id is not None:
tokenized_messages.append(self.base_tokenizer.eos_id)
mask.append(False)

# Finally, truncate if necessary
tokenized_messages = truncate(
tokens=tokenized_messages,
max_seq_len=max_seq_len,
eos_id=self.base_tokenizer.eos_id if add_eos else None,
truncation_type=self.truncation_type,
)

mask = truncate(
tokens=mask,
max_seq_len=max_seq_len,
eos_id=True if add_eos else None,
truncation_type=self.truncation_type,
)

return tokenized_messages, mask
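A minimal usage sketch for this branch, assuming the class is exported from torchtune.modules.transforms.tokenizers and that the checkpoint files have been downloaded locally (paths and messages below are hypothetical):

```python
from torchtune.data import Message
from torchtune.modules.transforms.tokenizers import HuggingFaceModelTokenizer

tokenizer = HuggingFaceModelTokenizer(
    tokenizer_json_path="checkpoint/tokenizer.json",
    tokenizer_config_json_path="checkpoint/tokenizer_config.json",
)

messages = [
    Message(role="user", content="What is 2 + 2?", masked=True),
    Message(role="assistant", content="2 + 2 equals 4."),
]
tokens, mask = tokenizer.tokenize_messages(messages, add_eos=True)
# mask is True for the masked user turn and False for the assistant turn.
```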