train.py
from __future__ import annotations
import os
from dataclasses import dataclass
from datetime import datetime
from multiprocessing import cpu_count
import hydra
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from datasets import Dataset, load_dataset
from omegaconf import DictConfig, OmegaConf
from simple_parsing import field
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel
from sae.config import SaeConfig
from sae.data import MemmapDataset, chunk_and_tokenize
from sae.logger import get_logger
from sae.trainer import SaeLayerRangeTrainer, SaeTrainer, TrainConfig
from sae.utils import get_open_port, set_seed
logger = get_logger(__name__)
@dataclass
class RunConfig(TrainConfig):
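    """Configuration for a training run: the SAE TrainConfig plus model, data, and launcher options."""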
seed: int = field(default=42)
"""Random seed to use for training."""
model: str = field(
default="gpt2",
positional=True,
)
"""Name of the model to train."""
dataset: str = field(
default="togethercomputer/RedPajama-Data-1T-Sample",
positional=True,
)
"""Path to the dataset to use for training."""
split: str = "train"
"""Dataset split to use for training."""
train_split: str = "train"
"""Dataset split to use for training."""
train_test_split: float = 0.8
"""Fraction of the dataset to use for training."""
ds_name: str | None = None
"""Dataset name to use when loading from huggingface."""
ctx_len: int = 2048
"""Context length to use for training."""
hf_token: str | None = None
"""Huggingface API token for downloading models."""
load_in_8bit: bool = False
"""Load the model in 8-bit mode."""
max_train_examples: int = -1
"""Maximum number of examples to use for training."""
max_test_examples: int = -1
"""Maximum number of examples to use for testing."""
data_preprocessing_num_proc: int = field( # noqa: RUF009
default_factory=lambda: cpu_count() // 2,
)
"""Number of processes to use for preprocessing data"""
# distributed
ddp: bool = False
port: int = field(default_factory=get_open_port)
def load_artifacts(
    args: RunConfig, rank: int | None = None,
) -> tuple[PreTrainedModel, Dataset | MemmapDataset, Dataset | None]:
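    """Load the base model and the training (and optional held-out test) dataset."""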
if args.load_in_8bit:
dtype = torch.float16
elif torch.cuda.is_bf16_supported():
dtype = torch.bfloat16
else:
dtype = "auto"
model = AutoModel.from_pretrained(
args.model,
device_map={"": f"cuda:{rank}"} if rank is not None else "auto",
quantization_config=(
BitsAndBytesConfig(load_in_8bit=args.load_in_8bit)
if args.load_in_8bit
else None
),
torch_dtype=dtype,
token=args.hf_token,
)
    # For memmap-style datasets (already tokenized; no held-out test split)
    test_dataset: Dataset | None = None
    if args.dataset.endswith(".bin"):
        dataset = MemmapDataset(args.dataset, args.ctx_len, args.max_train_examples)
else:
# For Huggingface datasets
try:
dataset = load_dataset(
args.dataset,
name=args.ds_name,
split=args.split,
# TODO: Maybe set this to False by default? But RPJ requires it.
trust_remote_code=True,
)
except ValueError as e:
# Automatically use load_from_disk if appropriate
if "load_from_disk" in str(e):
dataset = Dataset.load_from_disk(args.dataset, keep_in_memory=False)
else:
raise e
assert isinstance(dataset, Dataset)
        # Create the train/test split; `train_test_split` is the fraction kept for training.
        if args.train_test_split > 0:
            dataset_ = dataset.train_test_split(
                train_size=args.train_test_split, seed=args.seed,
            )
            dataset, test_dataset = dataset_[args.train_split], dataset_["test"]
if "input_ids" not in dataset.column_names:
tokenizer = AutoTokenizer.from_pretrained(args.model, token=args.hf_token)
dataset = chunk_and_tokenize(
dataset,
tokenizer,
max_seq_len=args.ctx_len,
num_proc=min(args.data_preprocessing_num_proc, os.cpu_count()),
)
            if test_dataset is not None:
                test_dataset = chunk_and_tokenize(
                    test_dataset,
                    tokenizer,
                    max_seq_len=args.ctx_len,
                    num_proc=min(args.data_preprocessing_num_proc, os.cpu_count()),
                )
        else:
            logger.info("Dataset already tokenized; skipping tokenization.")
            dataset = dataset.with_format("torch")
            if test_dataset is not None:
                test_dataset = test_dataset.with_format("torch")
        if (limit := args.max_train_examples) > 0:
            dataset = dataset.select(range(limit))
        if test_dataset is not None and (limit := args.max_test_examples) > 0:
            test_dataset = test_dataset.select(range(limit))
return model, dataset, test_dataset
def worker_main(
rank: int,
world_size: int,
args: RunConfig,
):
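    """Run SAE training in a single process; `rank` selects the GPU under DDP/TP."""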
if args.ddp and world_size > 1:
torch.cuda.set_device(rank)
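        # All ranks rendezvous at the same local address/port before NCCL init.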
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(args.port)
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
dist.init_process_group("nccl", world_size=world_size, rank=rank)
if rank == 0:
logger.info(f"Using DDP across {dist.get_world_size()} GPUs.")
if args.tp and rank == 0 and world_size > 1:
logger.info(f"Using TP across {world_size} GPUs.")
# set seeds
set_seed(args.seed)
# Awkward hack to prevent other ranks from duplicating data preprocessing
if not dist.is_initialized() or args.tp or not args.ddp or rank == 0:
model, dataset, _ = load_artifacts(args, rank)
if args.ddp and dist.is_initialized():
dist.barrier()
if rank != 0:
model, dataset, _ = load_artifacts(args, rank)
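        # Give each rank a disjoint shard of the dataset so every rank trains on different data.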
dataset = dataset.shard(dist.get_world_size(), rank)
total_tokens = len(dataset) * args.ctx_len
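    # Pick the layer-range trainer when cross-layer training is enabled, otherwise the standard per-layer trainer.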
trainer_cls = (
SaeTrainer if not args.enable_cross_layer_training else SaeLayerRangeTrainer
)
logger.info(f"Training on '{args.dataset}' (split '{args.split}')")
logger.info(f"Storing model weights in {model.dtype}")
logger.info(f"Num tokens in train dataset: {total_tokens:,}")
trainer = trainer_cls(args, dataset, model, rank, world_size)
logger.info(f"SAEs: {trainer.saes}")
trainer.fit()
if dist.is_initialized():
dist.destroy_process_group()
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig):
world_size = torch.cuda.device_count()
# Convert Hydra config to RunConfig
parsed_config = OmegaConf.to_container(cfg, resolve=True)
sae_config = parsed_config.pop("sae")
args = RunConfig(sae=SaeConfig(**sae_config), **parsed_config)
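    # Append a timestamp so repeated runs get distinct run names.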
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
args.run_name = f"{args.run_name}_{timestamp}"
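    # Spawn one worker per visible GPU; a single-GPU run stays in the current process.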
if world_size > 1:
logger.info(f"Spawning {world_size} processes")
mp.spawn(
worker_main,
nprocs=world_size,
args=(world_size, args),
)
else:
worker_main(0, world_size, args)
if __name__ == "__main__":
mp.set_start_method("spawn")
main()