4 changes: 2 additions & 2 deletions examples/ppo_sentiments_llama.py
@@ -39,8 +39,8 @@ def llama_config():
trainer="AcceleratePPOTrainer",
save_best=False,
),
model=ModelConfig(model_path="decapoda-research/llama-7b-hf", num_layers_unfrozen=2),
tokenizer=TokenizerConfig(tokenizer_path="decapoda-research/llama-7b-hf", truncation_side="right"),
model=ModelConfig(model_path="/path/to/your/llama-7b/model/", num_layers_unfrozen=2),
tokenizer=TokenizerConfig(tokenizer_path="/path/to/your/llama-7b/model", truncation_side="right"),
optimizer=OptimizerConfig(
name="adamw", kwargs=dict(lr=1.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
),
70 changes: 70 additions & 0 deletions examples/stable_vicuna/README.md
@@ -0,0 +1,70 @@
## StableVicuna: Open Source RLHF LLM Chatbot

### Dataset

1. Reward Dataset:
    - Download from [reciprocate/oasst_hh_shp_hellaswag_webgpt_rm_dataset](https://huggingface.co/datasets/reciprocate/oasst_hh_shp_hellaswag_webgpt_rm_dataset)
    - Dataset size:
        - Train: 264,534 samples
        - Valid: 2,874 samples

2. SFT and RL Prompt Dataset:
    - Download from [pvduy/stable_vicuna_oasst_format](https://huggingface.co/datasets/pvduy/stable_vicuna_oasst_format)
    - Dataset size: 89,991 samples (see the loading sketch below)
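
Both datasets can be loaded directly with the `datasets` library. A minimal sketch (the split names are assumptions based on the sizes listed above):

```python
from datasets import load_dataset

# preference data for reward-model training (assumed "train"/"valid" splits)
rm_data = load_dataset("reciprocate/oasst_hh_shp_hellaswag_webgpt_rm_dataset")

# prompts for SFT and RL; rl_training.py below uses the "train" split
prompts = load_dataset("pvduy/stable_vicuna_oasst_format")["train"]
print(len(prompts))  # 89991
```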

### Reward Model Training
To train the reward model, follow the instructions in the [autocrit repo](https://github.com/CarperAI/autocrit).

Command:
```bash
python preference.py --model_path reciprocate/dahoas-gptj-rm-static --dataset reciprocate/oasst_hh_shp_hellaswag_webgpt_rm_dataset --batch_size 2 --eval_interval 500 --lr 0.000001
```
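
Once trained, the reward model loads as a plain sequence classifier and can score formatted conversations directly. A minimal sketch, reusing the released checkpoint that `rl_training.py` below points at (the example conversation string is ours; the scoring call mirrors that script):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

RM = "reciprocate/gpt-j_rm_format-oa"  # checkpoint used by rl_training.py
tokenizer = AutoTokenizer.from_pretrained(RM, revision="501f895")
model = AutoModelForSequenceClassification.from_pretrained(RM, revision="501f895")
model.eval().requires_grad_(False)

# score a conversation in the OASST format the RM expects
sample = "<|prompter|>What is RLHF?<|assistant|>RLHF fine-tunes a model against human preferences.<|endoftext|>"
with torch.no_grad():
    batch = tokenizer(sample, return_tensors="pt")
    score = model(**batch)[0].squeeze(-1)  # higher means preferred
print(score.item())
```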

### RL Training

1. Distributed Training:

We trained on 4 nodes with 8 A100 GPUs each.

```bash
sbatch go_train_dist.sh
```
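
`go_train_dist.sh` (included below) requests 4 exclusive nodes, configures NCCL and EFA networking, and launches `train_dist.sh` on every node via `srun`.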

WandB run: https://wandb.ai/pvduy/trlx/runs/w8d20kam

2. Single-Node Training:
If you want to train on a single node, you can use the following command, but we do not guarantee the same results.

```bash
accelerate launch examples/stable_vicuna/rl_training.py
```

Accelerate config:
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: no
dynamo_config: {}
fsdp_config: {}
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: bf16
num_machines: 1
num_processes: 7
rdzv_backend: static
same_network: true
use_cpu: false
```
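
Note that `num_processes` is 7 rather than 8: `rl_training.py` keeps the frozen reward model on the last GPU (`torch.cuda.device_count() - 1`), leaving 7 GPUs for training.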

### Released Model and Results
You can find the released model and more details [here](https://huggingface.co/pvduy/stable-vicuna-13b-version2).
29 changes: 29 additions & 0 deletions examples/stable_vicuna/go_train_dist.sh
@@ -0,0 +1,29 @@
#!/bin/bash
#SBATCH --job-name=dist_vic
#SBATCH --partition=a100-cu117
#SBATCH --account=stablegpt
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --mem=0
#SBATCH --output=out_dist_vic_test_5.txt
#SBATCH --exclusive

export NCCL_DEBUG=WARN
export NCCL_PROTO=simple
export FI_EFA_FORK_SAFE=1
export FI_LOG_LEVEL=1
export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn
export FI_EFA_ENABLE_SHM_TRANSFER=0
export FI_PROVIDER=efa
export FI_EFA_TX_MIN_CREDITS=64
# export CUDA_LAUNCH_BLOCKING=1

export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=13043
export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:512'
export TOKENIZERS_PARALLELISM=false
# export TRITON_HOST=localhost:8001

srun train_dist.sh
148 changes: 148 additions & 0 deletions examples/stable_vicuna/rl_training.py
@@ -0,0 +1,148 @@
import os
from typing import List

import torch
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from transformers import AutoModelForSequenceClassification, AutoTokenizer

import trlx
from trlx.data.configs import (
ModelConfig,
OptimizerConfig,
SchedulerConfig,
TokenizerConfig,
TrainConfig,
TRLConfig,
)
from trlx.models.modeling_ppo import PPOConfig

MODEL_BASED = "pvduy/vicuna-13b-v1.1-sft-ver2"
RM_BASED = "reciprocate/gpt-j_rm_format-oa"
RM_REVISION = "501f895"
OUT_DIR = "stable_vicuna_output"
DATASET_PATH = "pvduy/stable_vicuna_oasst_format"


config = TRLConfig(
train=TrainConfig(
seq_length=1024,
epochs=100,
total_steps=100000,
batch_size=1,
minibatch_size=1,
checkpoint_interval=10000,
eval_interval=500,
pipeline="PromptPipeline",
trainer="AcceleratePPOTrainer",
checkpoint_dir=OUT_DIR,
),
model=ModelConfig(
model_path=MODEL_BASED,
num_layers_unfrozen=2,
),
tokenizer=TokenizerConfig(
tokenizer_path=MODEL_BASED,
truncation_side="left",
padding_side="left",
),
optimizer=OptimizerConfig(
name="adamw",
kwargs={"lr": 1.0e-6, "betas": [0.9, 0.95], "eps": 1.0e-8, "weight_decay": 1.0e-6},
),
scheduler=SchedulerConfig(
name="cosine_annealing",
kwargs={
"T_max": 100000,
"eta_min": 1.0e-6,
},
),
method=PPOConfig(
name="PPOConfig",
num_rollouts=64,
chunk_size=4,
ppo_epochs=3,
init_kl_coef=0.05,
target=None,
horizon=10000,
gamma=1,
lam=0.95,
cliprange=0.2,
cliprange_value=0.2,
vf_coef=2,
scale_reward=None,
ref_mean=None,
ref_std=None,
cliprange_reward=5,
gen_kwargs={
"max_new_tokens": 512,
"do_sample": True,
"top_k": 0,
"top_p": 1,
"temperature": 1,
},
),
)


def create_reward_fn():
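    # trlx calls the reward function only on the main process, so the reward
    # model is loaded on rank 0 alone; every other rank returns a placeholder.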
if os.environ.get("RANK", "0") == "0":
tokenizer = AutoTokenizer.from_pretrained(RM_BASED, revision=RM_REVISION)
tokenizer.truncation_side = "left"

rm_model = AutoModelForSequenceClassification.from_pretrained(RM_BASED, revision=RM_REVISION)
rm_model.requires_grad_(False)
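        # keep the frozen reward model on the last GPU; train_dist.sh launches one
        # fewer training process on the head node so this GPU stays free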
rm_device = torch.cuda.device_count() - 1
rm_model = rm_model.eval().half().to(rm_device)

def get_reward(samples: List[str]):
all_scores = []
batch_size = 40
for i in range(0, len(samples), batch_size):
batch = tokenizer(
samples[i : i + batch_size],
padding=True,
truncation=True,
max_length=1024,
return_tensors="pt",
).to(rm_device)

with torch.no_grad():
scores = rm_model(**batch)[0].squeeze(-1).cpu()
all_scores.append(scores)
scores = torch.hstack(all_scores)
return scores

def reward_fn(samples, prompts, original_output, **kwargs):
samples = [s[s.find("<|prompter|>") :] for s in samples]
prompts = [p[p.find("<|prompter|>") :] for p in prompts]
> **Collaborator:** Why does this have to be done here? Is it to strip `<|system|>` from some of the prompts?

original_samples = [p + o for p, o in zip(prompts, original_output)]
samples = [x + "<|endoftext|>" for x in samples]
original_samples = [x + "<|endoftext|>" for x in original_samples]
rewards = get_reward(samples)
original_rewards = get_reward(original_samples)
return rewards - original_rewards

else:
return True
return reward_fn


if __name__ == "__main__":
    ds = load_dataset(DATASET_PATH)["train"].to_pandas()
    # convert to a pandas dataframe and randomly split into train and validation sets
    train, val = train_test_split(ds, test_size=1000, random_state=42)

train_prompts = [{"prompt": x["prompt"], "original_output": x["label"]} for _, x in train.iterrows()]
val_prompts = [{"prompt": x["prompt"], "original_output": x["label"]} for _, x in val.iterrows()]

reward_fn = create_reward_fn()

trainer = trlx.train(
reward_fn=reward_fn,
prompts=train_prompts,
eval_prompts=val_prompts,
config=config,
stop_sequences=["</s>", "<|prompter|>", "<assistant>"],
)
26 changes: 26 additions & 0 deletions examples/stable_vicuna/train_dist.sh
@@ -0,0 +1,26 @@
#!/bin/bash

# HOSTNAMES MASTER_ADDR MASTER_PORT COUNT_NODE are coming from the main script

H=`hostname`
RANK=`echo -e $HOSTNAMES | python3 -c "import sys;[sys.stdout.write(str(i)) for i,line in enumerate(next(sys.stdin).split(' ')) if line.strip() == '$H'.strip()]"`
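# the one-liner above derives this node's rank from its position in $HOSTNAMES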

echo hostname = `hostname`
echo HOSTNAMES = $HOSTNAMES
echo MASTER_ADDR = $MASTER_ADDR
echo MASTER_PORT = $MASTER_PORT
echo RANK = $RANK
echo COUNT_NODE = $COUNT_NODE


conda env list
eval "$(conda shell.bash hook)"
conda activate trlx_env

cd trlx/examples/stable_vicuna

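# Rank 0 passes one fewer process in total so the head node's last GPU stays
# free for the reward model (see rm_device in rl_training.py).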
if [[ $RANK -eq 0 ]]; then
accelerate launch --num_processes $((8 * $COUNT_NODE - 1)) --num_machines $COUNT_NODE --machine_rank $RANK --main_process_port 1234 --main_process_ip $MASTER_ADDR --config_file configs/accelerate/zero2-bf16.yaml rl_training.py
else
accelerate launch --num_processes $((8 * $COUNT_NODE)) --num_machines $COUNT_NODE --machine_rank $RANK --main_process_port 1234 --main_process_ip $MASTER_ADDR --config_file configs/accelerate/zero2-bf16.yaml rl_training.py
fi