Commit 506fbbd

sandeepchittillas00652993 authored and committed
Fix prompt truncation bug and handle deepspeed preparation
1 parent dfa814d · commit 506fbbd

File tree: 6 files changed, +75 -36 lines

examples/dpo_ultrafeedback.py

Lines changed: 48 additions & 20 deletions
@@ -1,8 +1,9 @@
-import itertools
-import json
 import sys
+import json
+from functools import partial
 
 from datasets import load_dataset
+from transformers import AutoTokenizer
 
 import trlx
 from trlx.data.default_configs import (
@@ -15,54 +16,81 @@
     TRLConfig,
 )
 
+model_path = "HuggingFaceH4/mistral-7b-sft-beta"
+wandb_project = "trlx"
+
 default_config = TRLConfig(
     train=TrainConfig(
         seq_length=1024,
-        epochs=100,
-        total_steps=1000,
+        epochs=2,
+        total_steps=1000000,
         batch_size=1,
-        checkpoint_interval=10000,
-        eval_interval=100,
+        checkpoint_interval=100000,
+        eval_interval=1000,
+        seed=42,
+        project_name=wandb_project,
         pipeline="PromptPipeline",
         trainer="AccelerateDPOTrainer",
         checkpoint_dir="checkpoints/dpo_ultrafeedback",
     ),
-    model=ModelConfig(model_path="HuggingFaceH4/mistral-7b-sft-beta", num_layers_unfrozen=-1),
-    tokenizer=TokenizerConfig(tokenizer_path="HuggingFaceH4/mistral-7b-sft-beta", truncation_side="right"),
-    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=2e-5, betas=(0.9, 0.999), eps=1.0e-8, weight_decay=1.0e-6)),
+    model=ModelConfig(model_path=model_path, num_layers_unfrozen=-1),
+    tokenizer=TokenizerConfig(tokenizer_path=model_path, truncation_side="right"),
+    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-5, betas=(0.9, 0.999), eps=1.0e-8, weight_decay=1.0e-6)),
     scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)),  # train.total_steps
     method=DPOConfig(
         name="DPOConfig",
-        gen_kwargs=dict(max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, do_sample=True),
+        gen_kwargs=dict(max_new_tokens=512, temperature=0.7, top_k=50, top_p=0.95, do_sample=True),
         beta=0.1,
         label_pad_token_id=-100,
         padding_value=0,
     ),
 )
 
 
-def preprocess(sample):
+def preprocess(sample, tokenizer, test=False):
     """
-    Return list of lists with Context/Prompt at index 0, Chosen at index 1 and rejected at index 2
+    Formats the input to the same training style used for mistral-7b-v0.1
+    When fine-tuning, modify your pre-processing to match the prompt template used during pretraining.
     """
     assert len(sample["chosen"]) == len(sample["rejected"]) == 2
 
-    sample["dpo"] = [sample["prompt"], sample["chosen"][1]["content"], sample["rejected"][1]["content"]]
-    return sample
+    assistant_prompt = "<|assistant|>"
+
+    prompt, chosen = tokenizer.apply_chat_template(sample["chosen"], tokenize=False).split(assistant_prompt)
+    rejected = tokenizer.apply_chat_template(sample["rejected"], tokenize=False).split(assistant_prompt)[-1]
+
+    return {
+        "prompt": prompt if not test else prompt + assistant_prompt,
+        "chosen": assistant_prompt + chosen,
+        "rejected": assistant_prompt + rejected,
+    }
 
 
 def main(hparams={}):
     config = TRLConfig.update(default_config, hparams)
 
-    dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized").map(preprocess)
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized")
+
+    dataset["dpo_train"] = dataset["train_prefs"].map(
+        partial(preprocess, tokenizer=tokenizer, test=False),
+        remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"],
+    )
+    dataset["dpo_test"] = dataset["test_prefs"].map(
+        partial(preprocess, tokenizer=tokenizer, test=True),
+        remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"],
+    )
+
+    print(
+        f"Length of training dataset : {len(dataset['dpo_train'])} \
+        Length of test dataset : {len(dataset['dpo_test'])}"
+    )
 
     trlx.train(
         config=config,
-        samples=dataset["train_prefs"]["dpo"],
-        eval_prompts=dataset["test_prefs"]["prompt"][:128],
-        # metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)},
-        stop_sequences=["User:", "user:", "Assistant:", "assistant:"]
-        + ["{e}x {i}put:".format(e=e, i=i) for e, i in itertools.product(["e", "E"], ["in", "In", "out", "Out"])],
+        samples=dataset["dpo_train"],
+        eval_prompts=dataset["dpo_test"]["prompt"][:8],  # running eval on subset only
+        stop_sequences=["<|user|>", "<|User|>"],
     )
 
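For context, the reworked preprocess splits the chat-templated conversation on the "<|assistant|>" marker so the prompt and the two candidate responses are stored separately. A minimal standalone sketch of that behaviour, using hand-written strings that only approximate a Zephyr-style chat template (not the exact output of the HuggingFaceH4/mistral-7b-sft-beta tokenizer):

# Standalone sketch of the new sample format; the template strings below are
# approximations, not real tokenizer output.
ASSISTANT = "<|assistant|>"

def split_preference(chosen_text, rejected_text, test=False):
    # Each conversation has exactly one user turn and one assistant turn, so
    # splitting on the assistant marker yields (prompt, response).
    prompt, chosen = chosen_text.split(ASSISTANT)
    rejected = rejected_text.split(ASSISTANT)[-1]
    return {
        "prompt": prompt if not test else prompt + ASSISTANT,  # eval prompts end at the marker
        "chosen": ASSISTANT + chosen,
        "rejected": ASSISTANT + rejected,
    }

chosen_text = "<|user|>\nWhat is DPO?</s>\n<|assistant|>\nDirect Preference Optimization.</s>\n"
rejected_text = "<|user|>\nWhat is DPO?</s>\n<|assistant|>\nNo idea.</s>\n"
print(split_preference(chosen_text, rejected_text))
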
trlx/pipeline/offline_pipeline.py

Lines changed: 11 additions & 8 deletions
@@ -300,16 +300,19 @@ def __init__(
 
     @staticmethod
     def tokenize_preferences(
-        sample: Iterable[str], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_length=2048
+        sample: Iterable[str],
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+        max_length=2048,
+        max_prompt_length=256,
     ) -> DPOElement:
         if isinstance(sample, Iterable):
             if len(sample) != 3:
                 raise ValueError(
                     f"Expected iterable of length 3 (prompt, chosen response, rejected response). Got {len(sample)}"
                 )
-            prompt_tokens = tokenizer(sample[0], add_special_tokens=False)
-            chosen_tokens = tokenizer(sample[1], add_special_tokens=False)
-            rejected_tokens = tokenizer(sample[2], add_special_tokens=False)
+            prompt_tokens = tokenizer(sample["prompt"], add_special_tokens=False)
+            chosen_tokens = tokenizer(sample["chosen"], add_special_tokens=False)
+            rejected_tokens = tokenizer(sample["rejected"], add_special_tokens=False)
         else:
             raise ValueError(f"{sample} is not an iterable")
 
@@ -324,14 +327,14 @@ def tokenize_preferences(
         # if combined sequence is too long, truncate the prompt only
         if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
             if tokenizer.truncation_side == "right":
-                prompt_tokens = {k: v[:max_length] for k, v in prompt_tokens.items()}
+                prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()}
             elif tokenizer.truncation_side == "left":
-                prompt_tokens = {k: v[-max_length:] for k, v in prompt_tokens.items()}
+                prompt_tokens = {k: v[-max_prompt_length:] for k, v in prompt_tokens.items()}
 
         # if that's still too long, truncate the response
         if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
-            chosen_tokens = {k: v[: max_length - max_length] for k, v in chosen_tokens.items()}
-            rejected_tokens = {k: v[: max_length - max_length] for k, v in rejected_tokens.items()}
+            chosen_tokens = {k: v[: max_length - max_prompt_length] for k, v in chosen_tokens.items()}
+            rejected_tokens = {k: v[: max_length - max_prompt_length] for k, v in rejected_tokens.items()}
 
         return DPOElement(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens)
 
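This hunk is the truncation bug named in the commit title: previously the prompt was only capped at the full max_length rather than a dedicated prompt budget, and the fallback sliced the responses to max_length - max_length, i.e. zero tokens, silently emptying them whenever the combined sequence was too long. A toy sketch of the corrected budget over plain lists of token ids (illustrative names, not the trlx API):

# Toy illustration of the corrected truncation budget, using plain lists of
# token ids in place of tokenizer output.
def truncate_pair(prompt_ids, response_ids, max_length=16, max_prompt_length=6, truncation_side="right"):
    if len(prompt_ids) + len(response_ids) > max_length:
        # First shrink the prompt to its own budget ...
        prompt_ids = prompt_ids[:max_prompt_length] if truncation_side == "right" else prompt_ids[-max_prompt_length:]
    if len(prompt_ids) + len(response_ids) > max_length:
        # ... then cap the response at whatever remains of the total budget.
        # (The pre-fix slice was max_length - max_length == 0, which emptied it.)
        response_ids = response_ids[: max_length - max_prompt_length]
    return prompt_ids, response_ids

prompt_ids, response_ids = truncate_pair(list(range(20)), list(range(30)))
print(len(prompt_ids), len(response_ids))  # 6 10
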
trlx/trainer/accelerate_dpo_trainer.py

Lines changed: 12 additions & 5 deletions
@@ -10,6 +10,7 @@
 if is_deepspeed_available():
     import deepspeed
 
+import trlx.utils.logging as logging
 from trlx.data.configs import TRLConfig
 from trlx.data.method_configs import MethodConfig, register_method
 from trlx.pipeline.offline_pipeline import DPOStore
@@ -18,6 +19,9 @@
 from trlx.utils.modeling import pad_to_length
 
 
+logger = logging.get_logger(__name__)
+
+
 @dataclass
 @register_method
 class DPOConfig(MethodConfig):
@@ -47,9 +51,10 @@ def __init__(self, config: TRLConfig, **kwargs):
 
         # TODO: Avoid setting up a reference model when hydra heads are used
         self.ref_model = self.get_arch(self.config)
-        if self.accelerator.state.deepspeed_plugin.zero_stage == 3:
-            self.ref_model = self._prepare_deepspeed_zero3(self.ref_model)
-        else:
+        try:
+            if self.accelerator.state.deepspeed_plugin.zero_stage == 3:
+                self.ref_model = self._prepare_deepspeed_zero3(self.ref_model)
+        except:
             self.ref_model.to(self.accelerator.device)
         self.ref_model.eval()
 
@@ -311,6 +316,8 @@ def prepare_learning(self):
         self.total_steps = self.config.train.epochs * len(self.train_dataloader)
         self.total_steps = min(self.total_steps, self.config.train.total_steps)
 
-    def make_experience(self, samples: Iterable[Iterable], seq_length: int):
-        preferences = [DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length) for sample in samples]
+    def make_experience(self, samples: Iterable[Iterable], seq_length: int, max_prompt_length: int):
+        preferences = [
+            DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length, max_prompt_length) for sample in samples
+        ]
         self.store = DPOStore(preferences, self.tokenizer, self.label_pad_token_id, self.padding_value)
 
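Previously this constructor assumed accelerator.state.deepspeed_plugin was always set, which raises an AttributeError when Accelerate runs without DeepSpeed. The try/except now falls back to placing the frozen reference model on the local device. A rough standalone sketch of the control flow it provides, assuming the missing plugin is the only error the bare except is meant to catch (helper names here are hypothetical):

# Standalone sketch (hypothetical names): the fallback the try/except takes when
# `accelerator.state.deepspeed_plugin` is None, i.e. no DeepSpeed is configured
# and reading `.zero_stage` raises AttributeError.
def prepare_ref_model(ref_model, accelerator, prepare_zero3):
    plugin = getattr(accelerator.state, "deepspeed_plugin", None)
    if plugin is None:
        # No DeepSpeed: keep the frozen reference model on the local device.
        ref_model.to(accelerator.device)
    elif plugin.zero_stage == 3:
        # ZeRO-3 shards parameters, so the reference model needs its own engine.
        ref_model = prepare_zero3(ref_model)
    ref_model.eval()
    return ref_model
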
trlx/trainer/accelerate_sft_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def prepare_learning(self):
         self.total_steps = self.config.train.epochs * len(self.train_dataloader)
         self.total_steps = min(self.total_steps, self.config.train.total_steps)
 
-    def make_experience(self, samples, seq_length):
+    def make_experience(self, samples, seq_length, **kwargs):
         if isinstance(samples[0], str):
             self.store = PromptPipeline(samples, seq_length, self.tokenizer)
         else:
trlx/trainer/nemo_sft_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ def eval_collate(elems):
         torch.set_float32_matmul_precision("medium")
         self.trainer.fit(self.model)
 
-    def make_experience(self, samples, seq_length):
+    def make_experience(self, samples, seq_length, **kwargs):
         if isinstance(samples[0], str):
             self.store = PromptPipeline(samples, seq_length, self.tokenizer)
         else:
trlx/trlx.py

Lines changed: 2 additions & 1 deletion
@@ -114,7 +114,8 @@ def train( # noqa: C901
         if rewards is not None:
             trainer.make_experience(samples, rewards, config.train.seq_length)
         else:
-            trainer.make_experience(samples, config.train.seq_length)
+            # this should be abstracted for all trainers with **kwargs
+            trainer.make_experience(samples, config.train.seq_length, max_prompt_length)
     else:
         raise ValueError("Either `samples` or `reward_fn` should be given for training")
 
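The inline comment indicates the eventual goal is for every trainer to absorb extra arguments through **kwargs rather than a special-cased call. A small standalone sketch of that pattern with hypothetical trainer classes (extras are forwarded as keywords so **kwargs can catch them):

# Sketch of the **kwargs dispatch the comment gestures at (hypothetical classes).
class SFTLikeTrainer:
    def make_experience(self, samples, seq_length, **kwargs):
        # Ignores arguments it does not need.
        return f"sft: {len(samples)} samples, seq_length={seq_length}"

class DPOLikeTrainer:
    def make_experience(self, samples, seq_length, max_prompt_length=256, **kwargs):
        # Picks up the extra prompt budget.
        return f"dpo: {len(samples)} samples, seq_length={seq_length}, max_prompt_length={max_prompt_length}"

for trainer in (SFTLikeTrainer(), DPOLikeTrainer()):
    print(trainer.make_experience(["a", "b"], seq_length=1024, max_prompt_length=256))
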