Add Stable Vicuna Training #487
Draft
PhungVanDuy wants to merge 5 commits into CarperAI:main from PhungVanDuy:add_stable_vicuna
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| ## StableVicuna: Open Source RLHF LLM Chatbot | ||
| | ||
| ### Dataset | ||
| | ||
| 1. Reward Dataset: | ||
|    - Download from [reciprocate/oasst_hh_shp_hellaswag_webgpt_rm_dataset](https://huggingface.co/datasets/reciprocate/oasst_hh_shp_hellaswag_webgpt_rm_dataset) | ||
|    - Dataset size: | ||
|      - Train: 264534 samples | ||
|      - Valid: 2874 samples | ||
| | ||
| 2. SFT and RL Prompt Dataset: | ||
|    - Download from [pvduy/stable_vicuna_oasst_format](https://huggingface.co/datasets/pvduy/stable_vicuna_oasst_format) | ||
|    - Dataset size: 89991 samples | ||
| | ||
| ### Reward Model Training | ||
| To train the reward model, follow the instructions in the [autocrit repo](https://github.com/CarperAI/autocrit). | ||
| | ||
| Command: | ||
| ```bash | ||
| python preference.py --model_path reciprocate/dahoas-gptj-rm-static --dataset reciprocate/oasst_hh_shp_hellaswag_webgpt_rm_dataset --batch_size 2 --eval_interval 500 --lr 0.000001 | ||
| ``` | ||
| | ||
| ### RL Training: | ||
| | ||
| 1. Distributed Training: | ||
| | ||
| We trained on 4 nodes with 8 A100 GPUs each. | ||
| | ||
| ```bash | ||
| sbatch go_train_dist.sh | ||
| ``` | ||
| | ||
| W&B run: https://wandb.ai/pvduy/trlx/runs/w8d20kam | ||
| | ||
| 2. Single Node Training: | ||
| To train on a single node, use the following command. We do not guarantee that it reproduces the result. | ||
| | ||
| ```bash | ||
| accelerate launch examples/rl_training.py | ||
| ``` | ||
| | ||
| Accelerate config: | ||
| ```yaml | ||
| compute_environment: LOCAL_MACHINE | ||
| deepspeed_config: | ||
|   deepspeed_multinode_launcher: standard | ||
|   gradient_accumulation_steps: 1 | ||
|   gradient_clipping: 1.0 | ||
|   offload_optimizer_device: none | ||
|   offload_param_device: none | ||
|   zero3_init_flag: false | ||
|   zero_stage: 2 | ||
| distributed_type: DEEPSPEED | ||
| downcast_bf16: no | ||
| dynamo_config: {} | ||
| fsdp_config: {} | ||
| machine_rank: 0 | ||
| main_training_function: main | ||
| megatron_lm_config: {} | ||
| mixed_precision: bf16 | ||
| num_machines: 1 | ||
| num_processes: 7 | ||
| rdzv_backend: static | ||
| same_network: true | ||
| use_cpu: false | ||
| ``` | ||
| | ||
| ### Released Model and Result: | ||
| You can find more details [here](https://huggingface.co/pvduy/stable-vicuna-13b-version2) |
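
The accelerate config in this README sets `num_processes: 7`, while the distributed run uses nodes with 8 GPUs each. Since `rl_training.py` places the reward model on the last visible GPU, a plausible reading is that one GPU is held back for the reward model. The sketch below only illustrates that arithmetic; `policy_process_count` is a hypothetical helper, not part of trlx or accelerate.

```python
def policy_process_count(gpus_per_node: int, num_nodes: int = 1, reserved: int = 1) -> int:
    """Training processes left after reserving `reserved` GPUs (assumed: for the reward model)."""
    total = gpus_per_node * num_nodes
    if reserved >= total:
        raise ValueError("no GPUs left for training processes")
    return total - reserved


if __name__ == "__main__":
    print(policy_process_count(8))     # single 8-GPU node
    print(policy_process_count(8, 4))  # four 8-GPU nodes, one GPU reserved in total
```

With these assumptions, a single 8-GPU node gives 7 processes, which matches `num_processes: 7`.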
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| #!/bin/bash | ||
| #SBATCH --job-name=dist_vic | ||
| #SBATCH --partition=a100-cu117 | ||
| #SBATCH --account=stablegpt | ||
| #SBATCH --nodes=4 | ||
| #SBATCH --ntasks-per-node=1 | ||
| #SBATCH --mem=0 | ||
| #SBATCH --output=out_dist_vic_test_5.txt | ||
| #SBATCH --exclusive | ||
| | ||
| export NCCL_DEBUG=WARN | ||
| export NCCL_PROTO=simple | ||
| export FI_EFA_FORK_SAFE=1 | ||
| export FI_LOG_LEVEL=1 | ||
| export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn | ||
| export FI_EFA_ENABLE_SHM_TRANSFER=0 | ||
| export FI_PROVIDER=efa | ||
| export FI_EFA_TX_MIN_CREDITS=64 | ||
| # export CUDA_LAUNCH_BLOCKING=1 | ||
| | ||
| export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST") | ||
| export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) | ||
| export MASTER_PORT=13043 | ||
| export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l) | ||
| export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:512' | ||
| export TOKENIZERS_PARALLELISM=false | ||
| # export TRITON_HOST=localhost:8001 | ||
| | ||
| srun train_dist.sh |
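
The script derives `MASTER_ADDR` and `COUNT_NODE` from the output of `scontrol show hostnames` by taking the first line and the line count. A minimal Python sketch of the same derivation follows. It assumes the expanded hostname list is already available as text; `cluster_layout` and the `node-N` names are hypothetical.

```python
def cluster_layout(hostnames_output: str) -> dict:
    """Mirror the script's `head -n 1` and `wc -l` on a newline-separated host list."""
    hosts = hostnames_output.split()
    if not hosts:
        raise ValueError("empty hostname list")
    return {
        "HOSTNAMES": hosts,
        "MASTER_ADDR": hosts[0],   # first host acts as the rendezvous address
        "COUNT_NODE": len(hosts),  # one entry per allocated node
    }


if __name__ == "__main__":
    sample = "node-1\nnode-2\nnode-3\nnode-4\n"  # hypothetical scontrol output
    print(cluster_layout(sample))
```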
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,148 @@ | ||
| import os | ||
| from typing import List | ||
| | ||
| import torch | ||
| from datasets import load_dataset | ||
| from sklearn.model_selection import train_test_split | ||
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | ||
| | ||
| import trlx | ||
| from trlx.data.configs import ( | ||
|     ModelConfig, | ||
|     OptimizerConfig, | ||
|     SchedulerConfig, | ||
|     TokenizerConfig, | ||
|     TrainConfig, | ||
|     TRLConfig, | ||
| ) | ||
| from trlx.models.modeling_ppo import PPOConfig | ||
| | ||
| MODEL_BASED = "pvduy/vicuna-13b-v1.1-sft-ver2" | ||
| RM_BASED = "reciprocate/gpt-j_rm_format-oa" | ||
| RM_REVISION = "501f895" | ||
| OUT_DIR = "stable_vicuna_output" | ||
| DATASET_PATH = "pvduy/stable_vicuna_oasst_format" | ||
| | ||
| | ||
| config = TRLConfig( | ||
|     train=TrainConfig( | ||
|         seq_length=1024, | ||
|         epochs=100, | ||
|         total_steps=100000, | ||
|         batch_size=1, | ||
|         minibatch_size=1, | ||
|         checkpoint_interval=10000, | ||
|         eval_interval=500, | ||
|         pipeline="PromptPipeline", | ||
|         trainer="AcceleratePPOTrainer", | ||
|         checkpoint_dir=OUT_DIR, | ||
|     ), | ||
|     model=ModelConfig( | ||
|         model_path=MODEL_BASED, | ||
|         num_layers_unfrozen=2, | ||
|     ), | ||
|     tokenizer=TokenizerConfig( | ||
|         tokenizer_path=MODEL_BASED, | ||
|         truncation_side="left", | ||
|         padding_side="left", | ||
|     ), | ||
|     optimizer=OptimizerConfig( | ||
|         name="adamw", | ||
|         kwargs={"lr": 1.0e-6, "betas": [0.9, 0.95], "eps": 1.0e-8, "weight_decay": 1.0e-6}, | ||
|     ), | ||
|     scheduler=SchedulerConfig( | ||
|         name="cosine_annealing", | ||
|         kwargs={ | ||
|             "T_max": 100000, | ||
|             "eta_min": 1.0e-6, | ||
|         }, | ||
|     ), | ||
|     method=PPOConfig( | ||
|         name="PPOConfig", | ||
|         num_rollouts=64, | ||
|         chunk_size=4, | ||
|         ppo_epochs=3, | ||
|         init_kl_coef=0.05, | ||
|         target=None, | ||
|         horizon=10000, | ||
|         gamma=1, | ||
|         lam=0.95, | ||
|         cliprange=0.2, | ||
|         cliprange_value=0.2, | ||
|         vf_coef=2, | ||
|         scale_reward=None, | ||
|         ref_mean=None, | ||
|         ref_std=None, | ||
|         cliprange_reward=5, | ||
|         gen_kwargs={ | ||
|             "max_new_tokens": 512, | ||
|             "do_sample": True, | ||
|             "top_k": 0, | ||
|             "top_p": 1, | ||
|             "temperature": 1, | ||
|         }, | ||
|     ), | ||
| ) | ||
| | ||
| | ||
| def create_reward_fn(): | ||
|     # Only the main process (RANK 0) loads the reward model | ||
|     if os.environ.get("RANK", "0") == "0": | ||
|         tokenizer = AutoTokenizer.from_pretrained(RM_BASED, revision=RM_REVISION) | ||
|         tokenizer.truncation_side = "left" | ||
| | ||
|         rm_model = AutoModelForSequenceClassification.from_pretrained(RM_BASED, revision=RM_REVISION) | ||
|         rm_model.requires_grad_(False) | ||
|         # Place the reward model on the last visible GPU | ||
|         rm_device = torch.cuda.device_count() - 1 | ||
|         rm_model = rm_model.eval().half().to(rm_device) | ||
| | ||
|         def get_reward(samples: List[str]): | ||
|             all_scores = [] | ||
|             batch_size = 40 | ||
|             for i in range(0, len(samples), batch_size): | ||
|                 batch = tokenizer( | ||
|                     samples[i : i + batch_size], | ||
|                     padding=True, | ||
|                     truncation=True, | ||
|                     max_length=1024, | ||
|                     return_tensors="pt", | ||
|                 ).to(rm_device) | ||
| | ||
|                 with torch.no_grad(): | ||
|                     scores = rm_model(**batch)[0].squeeze(-1).cpu() | ||
|                 all_scores.append(scores) | ||
|             scores = torch.hstack(all_scores) | ||
|             return scores | ||
| | ||
|         def reward_fn(samples, prompts, original_output, **kwargs): | ||
|             samples = [s[s.find("<|prompter|>") :] for s in samples] | ||
|             prompts = [p[p.find("<|prompter|>") :] for p in prompts] | ||
|             original_samples = [p + o for p, o in zip(prompts, original_output)] | ||
|             samples = [x + "<|endoftext|>" for x in samples] | ||
|             original_samples = [x + "<|endoftext|>" for x in original_samples] | ||
|             rewards = get_reward(samples) | ||
|             original_rewards = get_reward(original_samples) | ||
|             # Reward is the score of the generated sample minus the score of the dataset's original output | ||
|             return rewards - original_rewards | ||
| | ||
|     else: | ||
|         return True | ||
|     return reward_fn | ||
| | ||
| | ||
| if __name__ == "__main__": | ||
|     # Convert to a pandas DataFrame so it can be split with scikit-learn and iterated row by row | ||
|     dataset = load_dataset(DATASET_PATH)["train"].to_pandas() | ||
| | ||
|     # Randomly hold out 1000 rows for validation | ||
|     train, val = train_test_split(dataset, test_size=1000, random_state=42) | ||
| | ||
|     train_prompts = [{"prompt": x["prompt"], "original_output": x["label"]} for _, x in train.iterrows()] | ||
|     val_prompts = [{"prompt": x["prompt"], "original_output": x["label"]} for _, x in val.iterrows()] | ||
| | ||
|     reward_fn = create_reward_fn() | ||
| | ||
|     trainer = trlx.train( | ||
|         reward_fn=reward_fn, | ||
|         prompts=train_prompts, | ||
|         eval_prompts=val_prompts, | ||
|         config=config, | ||
|         stop_sequences=["</s>", "<|prompter|>", "<assistant>"], | ||
|     ) | ||
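
One detail of the config in this file: the optimizer uses `lr = 1.0e-6` and the scheduler uses `eta_min = 1.0e-6`. Assuming `cosine_annealing` corresponds to PyTorch's `CosineAnnealingLR`, whose closed form is `eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T_max))`, equal values make the schedule constant. The sketch below evaluates that closed form in plain Python; `cosine_annealing_lr` is a hypothetical helper, not a trlx or PyTorch function.

```python
import math


def cosine_annealing_lr(step: int, base_lr: float, eta_min: float, t_max: int) -> float:
    """Closed-form cosine annealing from base_lr down to eta_min over t_max steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


if __name__ == "__main__":
    # Values taken from the config: lr == eta_min, T_max == total_steps
    for step in (0, 25_000, 50_000, 100_000):
        print(step, cosine_annealing_lr(step, 1.0e-6, 1.0e-6, 100_000))

    # For contrast, a hypothetical lower floor does produce a decay
    print(cosine_annealing_lr(100_000, 1.0e-6, 1.0e-7, 100_000))
```

If a decaying schedule is intended, `eta_min` would need to be lower than the initial learning rate.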
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| #!/bin/bash | ||
| | ||
| # HOSTNAMES MASTER_ADDR MASTER_PORT COUNT_NODE are coming from the main script | ||
| | ||
| H=`hostname` | ||
| RANK=`echo -e $HOSTNAMES | python3 -c "import sys;[sys.stdout.write(str(i)) for i,line in enumerate(next(sys.stdin).split(' ')) if line.strip() == '$H'.strip()]"` | ||
| | ||
| echo hostname = `hostname` | ||
| echo HOSTNAMES = $HOSTNAMES | ||
| echo MASTER_ADDR = $MASTER_ADDR | ||
| echo MASTER_PORT = $MASTER_PORT | ||
| echo RANK = $RANK | ||
| echo COUNT_NODE = $COUNT_NODE | ||
| | ||
| conda env list | ||
| eval "$(conda shell.bash hook)" | ||
| conda activate trlx_env | ||
| | ||
| cd trlx/examples/stable_vicuna | ||
| | ||
| if [[ $RANK -eq 0 ]]; then | ||
|     accelerate launch --num_processes $((8 * $COUNT_NODE - 1)) --num_machines $COUNT_NODE --machine_rank $RANK --main_process_port 1234 --main_process_ip $MASTER_ADDR --config_file configs/accelerate/zero2-bf16.yaml rl_training.py | ||
| else | ||
|     accelerate launch --num_processes $((8 * $COUNT_NODE)) --num_machines $COUNT_NODE --machine_rank $RANK --main_process_port 1234 --main_process_ip $MASTER_ADDR --config_file configs/accelerate/zero2-bf16.yaml rl_training.py | ||
| fi |
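
The `RANK` one-liner in this script looks up the position of the current hostname in the space-separated `HOSTNAMES` list. A more readable Python sketch of the same lookup follows; `node_rank` and the `node-N` names are hypothetical. Unlike the one-liner, which writes nothing when the host is missing, this version raises an error.

```python
def node_rank(hostnames: str, this_host: str) -> int:
    """Return the zero-based index of this_host in a whitespace-separated host list."""
    hosts = [h.strip() for h in hostnames.split()]
    try:
        return hosts.index(this_host.strip())
    except ValueError:
        raise ValueError(f"{this_host!r} is not in the job's host list") from None


if __name__ == "__main__":
    hosts = "node-1 node-2 node-3 node-4"  # hypothetical value of $HOSTNAMES
    print(node_rank(hosts, "node-1"))  # the first host becomes machine rank 0
    print(node_rank(hosts, "node-3"))
```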
Why does this have to be done here? Is it to strip `<|system|>` from some of the prompts?
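
For context on the question above: a slice of the form `s[s.find("<|prompter|>") :]` drops everything before the first `<|prompter|>` marker, so it would remove a leading `<|system|>` segment when one is present. One caveat is that `str.find` returns -1 when the marker is absent, in which case the same slice keeps only the last character. A minimal sketch with a guard; `strip_to_first_prompter` and the sample strings are hypothetical, not code from this PR.

```python
MARKER = "<|prompter|>"


def strip_to_first_prompter(text: str, marker: str = MARKER) -> str:
    """Drop everything before the first marker; leave text unchanged if the marker is absent."""
    idx = text.find(marker)
    return text if idx == -1 else text[idx:]


if __name__ == "__main__":
    with_system = "<|system|>Be concise.<|prompter|>Hi<|assistant|>Hello"
    print(strip_to_first_prompter(with_system))

    no_marker = "no marker here"
    print(strip_to_first_prompter(no_marker))     # guarded: unchanged
    print(repr(no_marker[no_marker.find(MARKER):]))  # unguarded slice for comparison
```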