Skip to content

Commit 5c00225

Browse files
authored
Remove interleaved thinking for MT thinking model (#1135)
* Strip intermediate <think> blocks in olmo_thinker_no_think
* Adjust reasoning stripping in the olmo thinker template
* Add a script to export chat templates
* Apply code-review feedback
1 parent b6441f5 commit 5c00225

File tree

3 files changed

+295
-0
lines changed

3 files changed

+295
-0
lines changed

open_instruct/dataset_transformation.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,51 @@ def visualize_token_role(tokens: list[int], masks: list[int], tokenizer: PreTrai
438438
"{% endif %}"
439439
"{% endfor %}"
440440
),
441+
"olmo_thinker_remove_intermediate_thinking": (
442+
"{% set has_system = messages|selectattr('role', 'equalto', 'system')|list|length > 0 %}"
443+
"{% if not has_system %}"
444+
"{{ '<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n' }}"
445+
"{% endif %}"
446+
"{% for message in messages %}"
447+
"{% if message['role'] == 'system' %}"
448+
"{{ '<|im_start|>system\n' + message['content'] }}"
449+
"{% if message.get('functions', none) is not none %}"
450+
"{{ ' <functions>' + message['functions'] + '</functions><|im_end|>\n' }}"
451+
"{% else %}"
452+
"{{ ' You do not currently have access to any functions. <functions></functions><|im_end|>\n' }}"
453+
"{% endif %}"
454+
"{% elif message['role'] == 'user' %}"
455+
"{% if message.get('functions', none) is not none %}"
456+
"{{ '<|im_start|>user\n' + message['content'] + '\n' + '<functions>' + message['functions'] + '</functions><|im_end|>\n' }}"
457+
"{% else %}"
458+
"{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}"
459+
"{% endif %}"
460+
"{% elif message['role'] == 'assistant' %}"
461+
"{{ '<|im_start|>assistant\n' }}"
462+
"{% set content = message.get('content', none) %}"
463+
"{% if content is not none %}"
464+
"{% set content = content | string %}"
465+
"{% if not loop.last and '</think>' in content and '<think>' in content %}"
466+
"{% set content = content.split('</think>')[-1].lstrip('\\n') %}"
467+
"{% endif %}"
468+
"{{ content }}"
469+
"{% endif %}"
470+
"{% if message.get('function_calls', none) is not none %}"
471+
"{{ '<function_calls>' + message['function_calls'] + '</function_calls>' }}"
472+
"{% endif %}"
473+
"{% if not loop.last %}"
474+
"{{ '<|im_end|>' + '\n' }}"
475+
"{% else %}"
476+
"{{ eos_token }}"
477+
"{% endif %}"
478+
"{% elif message['role'] == 'environment' %}"
479+
"{{ '<|im_start|>environment\n' + message['content'] + '<|im_end|>\n' }}"
480+
"{% endif %}"
481+
"{% if loop.last and add_generation_prompt %}"
482+
"{{ '<|im_start|>assistant\n<think>' }}"
483+
"{% endif %}"
484+
"{% endfor %}"
485+
),
441486
"olmo_thinker_no_think_sft_tokenization": (
442487
"{% set has_system = messages|selectattr('role', 'equalto', 'system')|list|length > 0 %}"
443488
"{% if not has_system %}"

scripts/export_chat_template.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#!/usr/bin/env python
2+
"""Save a chat template defined in open_instruct.dataset_transformation to a Jinja file."""
3+
4+
import argparse
5+
import sys
6+
from pathlib import Path
7+
8+
from open_instruct.dataset_transformation import CHAT_TEMPLATES
9+
10+
# Example
11+
# uv run python scripts/export_chat_template.py olmo_thinker_remove_intermediate_thinking
12+
13+
def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface for this script.

    Returns:
        Parsed arguments: ``template`` (required name), ``output`` (optional
        Path), and the ``overwrite`` / ``list`` boolean flags.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "template",
        help="Name of the chat template as defined in open_instruct.dataset_transformation.CHAT_TEMPLATES.",
    )
    cli.add_argument(
        "--output",
        type=Path,
        help="Path to write the template (defaults to TEMPLATE_NAME.jinja in the current directory).",
    )
    cli.add_argument(
        "--overwrite",
        action="store_true",
        help="Allow overwriting an existing file; otherwise the script exits with an error.",
    )
    cli.add_argument(
        "--list",
        action="store_true",
        help="List available template names and exit (Ignores other flags).",
    )
    return cli.parse_args()
35+
36+
37+
def list_templates() -> None:
    """Print the name of every known chat template, one per line, sorted."""
    print("Available chat templates:")
    for template_name in sorted(CHAT_TEMPLATES):
        print(f"  - {template_name}")
41+
42+
43+
def main() -> None:
    """Entry point: write the selected chat template to a Jinja file."""
    opts = parse_args()

    # --list short-circuits everything else.
    if opts.list:
        list_templates()
        return

    name = opts.template
    if name not in CHAT_TEMPLATES:
        print(f"Unknown template '{name}'. Use --list for options.", file=sys.stderr)
        raise SystemExit(1)

    # Normalize to POSIX newlines before writing.
    template_text = CHAT_TEMPLATES[name].replace("\r\n", "\n")

    destination = opts.output or Path(f"{name}.jinja")
    if destination.exists() and not opts.overwrite:
        print(f"{destination} already exists. Use --overwrite to replace it.", file=sys.stderr)
        raise SystemExit(1)

    destination.write_text(template_text, encoding="utf-8")
    print(f"Wrote {name} template to {destination}")
65+
66+
67+
if __name__ == "__main__":
68+
main()
69+

scripts/test_chat_templates.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
#!/usr/bin/env python
2+
"""Render canned chat examples with a chosen template to inspect formatting."""
3+
4+
import argparse
5+
from pathlib import Path
6+
from typing import Iterable, List
7+
8+
from transformers import AutoTokenizer
9+
10+
from open_instruct.dataset_transformation import CHAT_TEMPLATES
11+
12+
MODEL_NAME = "allenai/dolma2-tokenizer"
13+
MODEL_REVISION = "main"
14+
15+
SINGLE_TURN_REASONING = [
16+
{"role": "system", "content": "You are a helpful AI assistant."},
17+
{"role": "user", "content": "The prompt was asdlkasd"},
18+
{
19+
"role": "assistant",
20+
"content": (
21+
"<think>Okay... user sent \"asdlkasd\"—probably a test. Stay friendly, invite clarification...</think>\n"
22+
"It looks like your message might be a bit jumbled! Could you clarify what you're asking about? "
23+
"I'm here to help with AI, language models, writing, coding—whatever you need."
24+
),
25+
},
26+
]
27+
28+
MULTI_TURN_REASONING = [
29+
{"role": "system", "content": "You are a helpful AI assistant."},
30+
{"role": "user", "content": "The prompt was asdlkasd"},
31+
{
32+
"role": "assistant",
33+
"content": (
34+
"<think>First turn... just restate the prompt and ask what they need.</think>\n"
35+
'The prompt you shared was "asdlkasd". Did you want me to expand on it or help craft a new one?'
36+
),
37+
},
38+
{"role": "user", "content": "Please restate it politely so I can show templating."},
39+
{
40+
"role": "assistant",
41+
"content": (
42+
"<think>Second turn... reassure them and keep the tone upbeat...</think>\n"
43+
"Absolutely! It looked like your message might have been a little jumbled—just let me know what you'd "
44+
"like to explore and I'm happy to dive in."
45+
),
46+
},
47+
]
48+
49+
BASIC_CHAT_TRANSCRIPT = [
50+
{"role": "user", "content": "Hello, how are you?"},
51+
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
52+
{"role": "user", "content": "I'd like to show off how chat templating works!"},
53+
]
54+
55+
EXAMPLES = {
56+
"single_reasoning": {
57+
"messages": SINGLE_TURN_REASONING,
58+
"description": "Single assistant turn with <think> reasoning.",
59+
},
60+
"multi_reasoning": {
61+
"messages": MULTI_TURN_REASONING,
62+
"description": "Two assistant turns, both containing <think> traces.",
63+
},
64+
"basic_chat": {
65+
"messages": BASIC_CHAT_TRANSCRIPT,
66+
"description": "Simple chat without reasoning tags.",
67+
},
68+
}
69+
70+
DEFAULT_EXAMPLES = ("single_reasoning", "multi_reasoning")
71+
72+
73+
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI for the template-rendering demo."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--model-name",
        default=MODEL_NAME,
        help="Tokenizer identifier on Hugging Face (default: %(default)s).",
    )
    cli.add_argument(
        "--revision",
        default=MODEL_REVISION,
        help="Tokenizer revision, tag, or commit (default: %(default)s).",
    )
    cli.add_argument(
        "--template",
        default="olmo_thinker_remove_intermediate_thinking",
        help=(
            "Either the key of a template in open_instruct.dataset_transformation.CHAT_TEMPLATES "
            "or a filesystem path to a Jinja template."
        ),
    )
    cli.add_argument(
        "--examples",
        nargs="+",
        default=list(DEFAULT_EXAMPLES),
        choices=list(EXAMPLES.keys()) + ["all"],
        help="Which canned message sets to render (use 'all' for everything).",
    )
    cli.add_argument(
        "--show-tokens",
        action="store_true",
        help="Also print token ids and counts to compare serialized lengths.",
    )
    cli.add_argument(
        "--snippet-len",
        type=int,
        default=160,
        help="Character cap for message previews before printing token ids (0 to disable truncation).",
    )
    return cli.parse_args()
112+
113+
114+
def resolve_examples(selection: Iterable[str]) -> List[str]:
    """Return the example names to render, de-duplicated in first-seen order.

    The sentinel ``"all"`` anywhere in *selection* expands to every
    registered example.
    """
    deduped: List[str] = []
    for name in selection:
        if name not in deduped:
            deduped.append(name)
    if "all" in deduped:
        return list(EXAMPLES)
    return deduped
119+
120+
121+
def load_template(template_arg: str) -> str:
    """Resolve *template_arg* to a Jinja template string.

    The argument is interpreted first as a filesystem path to a template
    file and, failing that, as a key in ``CHAT_TEMPLATES``.

    Raises:
        ValueError: if the argument matches neither an existing file nor a
            known template key.
    """
    candidate = Path(template_arg)
    # is_file() (not exists()) so a directory that happens to share the name
    # falls through to the CHAT_TEMPLATES lookup instead of crashing with
    # IsADirectoryError on read_text().
    if candidate.is_file():
        # Exported templates are written as UTF-8 (scripts/export_chat_template.py);
        # read them back the same way instead of relying on the locale default.
        return candidate.read_text(encoding="utf-8")
    if template_arg in CHAT_TEMPLATES:
        return CHAT_TEMPLATES[template_arg]
    raise ValueError(
        f"Template '{template_arg}' is neither a file nor a key in CHAT_TEMPLATES. "
        f"Available keys: {', '.join(sorted(CHAT_TEMPLATES.keys()))}"
    )
131+
132+
133+
def main() -> None:
    """Render each selected canned example with the chosen chat template.

    Loads the tokenizer, installs the template, prints a preview of each
    message plus the fully formatted transcript, and (with --show-tokens)
    the token ids. Exits with status 1 if any example failed to render.
    """
    args = parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, revision=args.revision, use_fast=True)
    # Override whatever template ships with the tokenizer.
    tokenizer.chat_template = load_template(args.template)

    had_error = False
    for example_name in resolve_examples(args.examples):
        example = EXAMPLES[example_name]
        print("\n" + "=" * 80)
        print(f"{example_name} :: {example['description']}")
        print("-" * 80)

        print("Messages:")
        for idx, message in enumerate(example["messages"], start=1):
            # Flatten newlines and truncate each message for a compact preview.
            snippet = message["content"].replace("\n", " ")
            if args.snippet_len and len(snippet) > args.snippet_len:
                snippet = snippet[: args.snippet_len - 3] + "..."
            print(f"  {idx}. {message['role']}: {snippet}")

        print("\nFormatted:")
        try:
            rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
        except Exception as exc:  # pragma: no cover - helpful for manual debugging
            # Keep going so every example is inspected; fail at the end.
            had_error = True
            print(f"[ERROR] {exc}")
        else:
            print(rendered)
        if args.show_tokens:
            print("\nTokenized:")
            token_data = tokenizer.apply_chat_template(example["messages"], tokenize=True, add_generation_prompt=False)
            # Normalize the return value to a flat list of ints: it may be a
            # plain list, a tensor-like with .tolist() (possibly batched),
            # some other iterable, or a scalar.
            if isinstance(token_data, list):
                ids = token_data
            elif hasattr(token_data, "tolist"):
                ids = token_data.tolist()
                if ids and isinstance(ids[0], list):
                    # Batched output: take the single sequence.
                    ids = ids[0]
            elif hasattr(token_data, "__iter__"):
                ids = list(token_data)
            else:
                ids = [int(token_data)]
            print(f"\nToken count: {len(ids)}")
            print(ids)
        print("=" * 80)
    if had_error:
        raise SystemExit(1)
178+
179+
180+
if __name__ == "__main__":
181+
main()

0 commit comments

Comments
 (0)