"""Fine-tune ``facebook/opt-350m`` on the Unitxt WNLI card with TRL's SFTTrainer.

Example script (``examples/training.py``). Running it loads the dataset,
builds the trainer and starts training immediately.
"""

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# The second argument is a Unitxt recipe string: it names the card, the
# template and (per the recipe) a limit on test instances.
dataset = load_dataset(
    "unitxt/data",
    "card=cards.wnli,template=templates.classification.multi_class.relation.default,max_test_instances=100",
    # Allows `datasets` to execute the loading code shipped with "unitxt/data".
    trust_remote_code=True,
)


def formatting(example):
    """Join each source prompt with its target into one training text.

    Args:
        example: A batch of columns. ``example["source"]`` and
            ``example["target"]`` are assumed to be parallel sequences of
            strings (they are paired row by row and concatenated).

    Returns:
        A list with one ``source + target`` string per row of the batch.
    """
    return [
        source + target
        for source, target in zip(example["source"], example["target"])
    ]


trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset["train"],
    args=SFTConfig(output_dir="./opt-350m"),
    # The dataset is consumed through its "source"/"target" columns, so the
    # trainer has to be told how to turn them into training texts. Without
    # this argument `formatting` above was never used.
    # NOTE(review): `formatting` follows the batched contract (batch in, list
    # of strings out) -- confirm this matches the installed TRL version.
    formatting_func=formatting,
)

trainer.train()