@@ -1,8 +1,9 @@
-import itertools
-import json
 import sys
+import json
+from functools import partial

 from datasets import load_dataset
+from transformers import AutoTokenizer

 import trlx
 from trlx.data.default_configs import (
@@ -15,54 +16,81 @@
     TRLConfig,
 )

+model_path = "HuggingFaceH4/mistral-7b-sft-beta"
+wandb_project = "trlx"
+
 default_config = TRLConfig(
     train=TrainConfig(
         seq_length=1024,
-        epochs=100,
-        total_steps=1000,
+        epochs=2,
+        total_steps=1000000,
         batch_size=1,
-        checkpoint_interval=10000,
-        eval_interval=100,
+        checkpoint_interval=100000,
+        eval_interval=1000,
+        seed=42,
+        project_name=wandb_project,
         pipeline="PromptPipeline",
         trainer="AccelerateDPOTrainer",
         checkpoint_dir="checkpoints/dpo_ultrafeedback",
     ),
-    model=ModelConfig(model_path="HuggingFaceH4/mistral-7b-sft-beta", num_layers_unfrozen=-1),
-    tokenizer=TokenizerConfig(tokenizer_path="HuggingFaceH4/mistral-7b-sft-beta", truncation_side="right"),
-    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=2e-5, betas=(0.9, 0.999), eps=1.0e-8, weight_decay=1.0e-6)),
+    model=ModelConfig(model_path=model_path, num_layers_unfrozen=-1),
+    tokenizer=TokenizerConfig(tokenizer_path=model_path, truncation_side="right"),
+    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-5, betas=(0.9, 0.999), eps=1.0e-8, weight_decay=1.0e-6)),
     scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)),  # train.total_steps
     method=DPOConfig(
         name="DPOConfig",
-        gen_kwargs=dict(max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, do_sample=True),
+        gen_kwargs=dict(max_new_tokens=512, temperature=0.7, top_k=50, top_p=0.95, do_sample=True),
         beta=0.1,
         label_pad_token_id=-100,
         padding_value=0,
     ),
 )


-def preprocess(sample):
+def preprocess(sample, tokenizer, test=False):
     """
-    Return list of lists with Context/Prompt at index 0, Chosen at index 1 and rejected at index 2
+    Formats the input to the same training style used for mistral-7b-v0.1
+    When fine-tuning, modify your pre-processing to match the prompt template used during pretraining.
     """
     assert len(sample["chosen"]) == len(sample["rejected"]) == 2

-    sample["dpo"] = [sample["prompt"], sample["chosen"][1]["content"], sample["rejected"][1]["content"]]
-    return sample
+    assistant_prompt = "<|assistant|>"
+
+    prompt, chosen = tokenizer.apply_chat_template(sample["chosen"], tokenize=False).split(assistant_prompt)
+    rejected = tokenizer.apply_chat_template(sample["rejected"], tokenize=False).split(assistant_prompt)[-1]
+
+    return {
+        "prompt": prompt if not test else prompt + assistant_prompt,
+        "chosen": assistant_prompt + chosen,
+        "rejected": assistant_prompt + rejected,
+    }


 def main(hparams={}):
     config = TRLConfig.update(default_config, hparams)

-    dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized").map(preprocess)
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized")
+
+    dataset["dpo_train"] = dataset["train_prefs"].map(
+        partial(preprocess, tokenizer=tokenizer, test=False),
+        remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"],
+    )
+    dataset["dpo_test"] = dataset["test_prefs"].map(
+        partial(preprocess, tokenizer=tokenizer, test=True),
+        remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"],
+    )
+
+    print(
+        f"Length of training dataset : {len(dataset['dpo_train'])} \
+        Length of test dataset : {len(dataset['dpo_test'])}"
+    )

     trlx.train(
         config=config,
-        samples=dataset["train_prefs"]["dpo"],
-        eval_prompts=dataset["test_prefs"]["prompt"][:128],
-        # metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)},
-        stop_sequences=["User:", "user:", "Assistant:", "assistant:"]
-        + ["{e}x {i}put:".format(e=e, i=i) for e, i in itertools.product(["e", "E"], ["in", "In", "out", "Out"])],
+        samples=dataset["dpo_train"],
+        eval_prompts=dataset["dpo_test"]["prompt"][:8],  # running eval on subset only
+        stop_sequences=["<|user|>", "<|User|>"],
     )


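A note on the new preprocess(): below is a minimal sketch of the split it performs, assuming HuggingFaceH4/mistral-7b-sft-beta ships a Zephyr-style chat template where apply_chat_template(..., tokenize=False) renders turns as "<|user|>\n...</s>" and "<|assistant|>\n...</s>". The literal string stands in for the tokenizer output on one ultrafeedback_binarized pair, so the sketch runs without downloading the model.

# Sketch only: "templated" mimics tokenizer.apply_chat_template(sample["chosen"], tokenize=False)
# under the assumed Zephyr-style template.
templated = (
    "<|user|>\nWhat is the capital of France?</s>\n"
    "<|assistant|>\nThe capital of France is Paris.</s>\n"
)

assistant_prompt = "<|assistant|>"

# Everything before the assistant tag becomes the prompt; the tag is re-attached to the
# completion so both the chosen and rejected strings start with "<|assistant|>".
prompt, completion = templated.split(assistant_prompt)

print(repr(prompt))                         # '<|user|>\nWhat is the capital of France?</s>\n'
print(repr(assistant_prompt + completion))  # '<|assistant|>\nThe capital of France is Paris.</s>\n'

For the test split, test=True appends "<|assistant|>" to the prompt instead, so the eval prompts handed to trlx.train already end at the start of the assistant turn and generation continues from there.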
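On the new stop_sequences=["<|user|>", "<|User|>"]: during evaluation trlx roughly truncates each generated continuation at the first stop string it finds, which with this chat template means cutting the sample off as soon as the model starts a new user turn. A rough sketch of that trimming (a hypothetical helper, not trlx's internal code):

# Hypothetical helper illustrating stop-sequence trimming; not trlx's implementation.
def trim_at_stop(text, stop_sequences=("<|user|>", "<|User|>")):
    # Cut at the earliest occurrence of any stop string, if one appears.
    cut = len(text)
    for stop in stop_sequences:
        idx = text.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]

print(trim_at_stop("The capital is Paris.</s>\n<|user|>\nAnd Germany?"))  # 'The capital is Paris.</s>\n'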