|
| 1 | +#!/usr/bin/env python |
| 2 | +"""Render canned chat examples with a chosen template to inspect formatting.""" |
| 3 | + |
| 4 | +import argparse |
| 5 | +from pathlib import Path |
| 6 | +from typing import Iterable, List |
| 7 | + |
| 8 | +from transformers import AutoTokenizer |
| 9 | + |
| 10 | +from open_instruct.dataset_transformation import CHAT_TEMPLATES |
| 11 | + |
# Default tokenizer to load; both are overridable via --model-name / --revision.
MODEL_NAME = "allenai/dolma2-tokenizer"
MODEL_REVISION = "main"

# Canned conversation: a single assistant reply carrying a <think> reasoning trace,
# used to check how a template serializes (or strips) thinking content.
SINGLE_TURN_REASONING = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "The prompt was asdlkasd"},
    {
        "role": "assistant",
        "content": (
            "<think>Okay... user sent \"asdlkasd\"—probably a test. Stay friendly, invite clarification...</think>\n"
            "It looks like your message might be a bit jumbled! Could you clarify what you're asking about? "
            "I'm here to help with AI, language models, writing, coding—whatever you need."
        ),
    },
]

# Canned conversation: two assistant turns, each with its own <think> trace —
# exercises templates that drop intermediate (non-final) thinking sections.
MULTI_TURN_REASONING = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "The prompt was asdlkasd"},
    {
        "role": "assistant",
        "content": (
            "<think>First turn... just restate the prompt and ask what they need.</think>\n"
            'The prompt you shared was "asdlkasd". Did you want me to expand on it or help craft a new one?'
        ),
    },
    {"role": "user", "content": "Please restate it politely so I can show templating."},
    {
        "role": "assistant",
        "content": (
            "<think>Second turn... reassure them and keep the tone upbeat...</think>\n"
            "Absolutely! It looked like your message might have been a little jumbled—just let me know what you'd "
            "like to explore and I'm happy to dive in."
        ),
    },
]

# Canned conversation with no reasoning tags — baseline formatting check.
BASIC_CHAT_TRANSCRIPT = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    {"role": "user", "content": "I'd like to show off how chat templating works!"},
]

# Registry of selectable examples; keys are the values accepted by --examples.
EXAMPLES = {
    "single_reasoning": {
        "messages": SINGLE_TURN_REASONING,
        "description": "Single assistant turn with <think> reasoning.",
    },
    "multi_reasoning": {
        "messages": MULTI_TURN_REASONING,
        "description": "Two assistant turns, both containing <think> traces.",
    },
    "basic_chat": {
        "messages": BASIC_CHAT_TRANSCRIPT,
        "description": "Simple chat without reasoning tags.",
    },
}

# Rendered when --examples is not given; order here is the render order.
DEFAULT_EXAMPLES = ("single_reasoning", "multi_reasoning")
| 71 | + |
| 72 | + |
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface for this rendering script."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--model-name",
        default=MODEL_NAME,
        help="Tokenizer identifier on Hugging Face (default: %(default)s).",
    )
    cli.add_argument(
        "--revision",
        default=MODEL_REVISION,
        help="Tokenizer revision, tag, or commit (default: %(default)s).",
    )
    cli.add_argument(
        "--template",
        default="olmo_thinker_remove_intermediate_thinking",
        help="Either the key of a template in open_instruct.dataset_transformation.CHAT_TEMPLATES or a filesystem path to a Jinja template.",
    )
    cli.add_argument(
        "--examples",
        nargs="+",
        default=list(DEFAULT_EXAMPLES),
        choices=[*EXAMPLES, "all"],
        help="Which canned message sets to render (use 'all' for everything).",
    )
    cli.add_argument(
        "--show-tokens",
        action="store_true",
        help="Also print token ids and counts to compare serialized lengths.",
    )
    cli.add_argument(
        "--snippet-len",
        type=int,
        default=160,
        help="Character cap for message previews before printing token ids (0 to disable truncation).",
    )
    return cli.parse_args()
| 112 | + |
| 113 | + |
def resolve_examples(selection: Iterable[str]) -> List[str]:
    """Return the requested example names, deduplicated in first-seen order.

    The sentinel value ``"all"`` anywhere in *selection* expands to every
    registered example.
    """
    seen = set()
    unique = []
    for name in selection:
        if name in seen:
            continue
        seen.add(name)
        unique.append(name)
    if "all" in unique:
        return list(EXAMPLES.keys())
    return unique
| 119 | + |
| 120 | + |
def load_template(template_arg: str) -> str:
    """Resolve *template_arg* to Jinja template text.

    An existing file path takes precedence over a ``CHAT_TEMPLATES`` key of
    the same name.

    Raises:
        ValueError: if the argument is neither an existing file nor a known
            template key.
    """
    template_path = Path(template_arg)
    # is_file() (not exists()) so a directory path falls through to the
    # helpful ValueError instead of crashing inside read_text().
    if template_path.is_file():
        # Pin UTF-8 so rendering doesn't depend on the platform locale.
        return template_path.read_text(encoding="utf-8")
    if template_arg in CHAT_TEMPLATES:
        return CHAT_TEMPLATES[template_arg]
    raise ValueError(
        f"Template '{template_arg}' is neither a file nor a key in CHAT_TEMPLATES. "
        f"Available keys: {', '.join(sorted(CHAT_TEMPLATES.keys()))}"
    )
| 131 | + |
| 132 | + |
def _coerce_token_ids(token_data) -> list:
    """Normalize apply_chat_template tokenized output to a flat list of ids.

    Handles plain lists, objects exposing ``tolist()`` (e.g. tensors),
    arbitrary iterables, and bare scalars. A single batched conversation
    (``[[...]]``) is unwrapped regardless of input type — the original code
    only unwrapped it for ``tolist()`` outputs, which was inconsistent.
    """
    if isinstance(token_data, list):
        ids = token_data
    elif hasattr(token_data, "tolist"):
        ids = token_data.tolist()
    elif hasattr(token_data, "__iter__"):
        ids = list(token_data)
    else:
        ids = [int(token_data)]
    if ids and isinstance(ids[0], list):
        ids = ids[0]
    return ids


def main() -> None:
    """Render each selected example with the chosen template and print the result.

    Exits with status 1 if any example fails to render.
    """
    args = parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, revision=args.revision, use_fast=True)
    tokenizer.chat_template = load_template(args.template)

    had_error = False
    for example_name in resolve_examples(args.examples):
        example = EXAMPLES[example_name]
        print("\n" + "=" * 80)
        print(f"{example_name} :: {example['description']}")
        print("-" * 80)

        print("Messages:")
        for idx, message in enumerate(example["messages"], start=1):
            snippet = message["content"].replace("\n", " ")
            # snippet_len == 0 disables truncation entirely.
            if args.snippet_len and len(snippet) > args.snippet_len:
                snippet = snippet[: args.snippet_len - 3] + "..."
            print(f"  {idx}. {message['role']}: {snippet}")

        print("\nFormatted:")
        try:
            rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
        except Exception as exc:  # pragma: no cover - helpful for manual debugging
            # Keep going so every example is attempted; fail at the end.
            had_error = True
            print(f"[ERROR] {exc}")
        else:
            print(rendered)
            if args.show_tokens:
                print("\nTokenized:")
                token_data = tokenizer.apply_chat_template(
                    example["messages"], tokenize=True, add_generation_prompt=False
                )
                ids = _coerce_token_ids(token_data)
                print(f"\nToken count: {len(ids)}")
                print(ids)
        print("=" * 80)
    if had_error:
        raise SystemExit(1)
| 178 | + |
| 179 | + |
# Script entry point: only runs when executed directly, not on import.
if __name__ == "__main__":
    main()
0 commit comments