-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
382 lines (308 loc) · 17.4 KB
/
train.py
File metadata and controls
382 lines (308 loc) · 17.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
from cycling_utils import TimestampedTimer
# Module-level progress reporter, created before the remaining imports.
# train_loop reports through this global, and the __main__ guard passes it into main().
# NOTE(review): presumably placed here so later reports (e.g. "Completed imports")
# reflect import time — confirm against TimestampedTimer's semantics.
timer = TimestampedTimer("Imported TimestampedTimer")
import argparse
import os
import math
import random
import torch
from torch import nn
import torch.distributed as dist
# from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
# from torch.utils.data.distributed import DistributedSampler
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard as FSDP, CPUOffloadPolicy, MixedPrecisionPolicy
import torch.distributed.checkpoint as dcp
from fsdp_utils import AppState
from model import Model, calculate_effective_loss
from dataset import ARC_AGI_Dataset, collate_batch, VOCAB
from cycling_utils import (
# MetricsTracker,
AtomicDirectory,
atomic_torch_save
)
def det_range_splits(seq_len, rank, world_size):
    """Split ``seq_len`` positions into ``world_size`` contiguous, near-equal ranges.

    Rank ``r`` receives ``ceil(remaining / remaining_ranks)`` positions, so the
    sizes are non-increasing, differ by at most one, and the ranges tile
    ``[0, seq_len)`` exactly. When ``seq_len < world_size`` the trailing ranks
    get empty ranges.

    Args:
        seq_len: Total number of positions to split.
        rank: Index of the calling rank; selects ``local_range``/``local_size``.
        world_size: Number of ranks to split across.

    Returns:
        ``(split_ranges, split_sizes, local_range, local_size)``:
        ``split_ranges`` is a list of half-open ``(start, end)`` pairs, one per
        rank; ``split_sizes`` holds the matching lengths; ``local_range`` and
        ``local_size`` are the entries belonging to ``rank``.

    Raises:
        ValueError: If ``rank`` is not in ``[0, world_size)``.
    """
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world_size {world_size}")
    split_sizes = []
    split_ranges = []
    consumed = 0
    # Use a distinct loop variable: reusing ``rank`` here would overwrite the
    # caller's rank, making every rank receive the *last* rank's range.
    for r in range(world_size):
        remaining_tokens = seq_len - consumed
        remaining_ranks = world_size - r
        split_size = math.ceil(remaining_tokens / remaining_ranks)
        split_sizes.append(split_size)
        split_ranges.append((consumed, consumed + split_size))
        consumed += split_size
    return split_ranges, split_sizes, split_ranges[rank], split_sizes[rank]
# Progress marker: module imports and det_range_splits are defined by this point.
timer.report("Completed imports")
def get_args_parser(add_help=True):
    """Build the command-line parser for the training script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``. When False the
            parser does not register ``-h/--help`` (needed when this parser
            is used as a parent of another parser).

    Returns:
        An ``argparse.ArgumentParser`` with the training hyper-parameters.
    """
    # Previously ``add_help`` was accepted but silently ignored.
    parser = argparse.ArgumentParser(add_help=add_help)
    parser.add_argument("--epochs", type=int, default=5000, help="Total number of training epochs.")
    # NOTE(review): --test-epochs and --dropout are not read by the visible training code.
    parser.add_argument("--test-epochs", type=int, default=5)
    parser.add_argument("--dropout", type=float, default=0.2)
    parser.add_argument("--lr", type=float, default=1e-4, help="AdamW learning rate.")
    # Only consumed by a (currently commented-out) StepLR scheduler.
    parser.add_argument("--lr-step-epochs", type=float, default=100)
    parser.add_argument("--lr-decay-rate", type=float, default=0.8)
    parser.add_argument("--batch-size", type=int, default=10, help="Examples per DataLoader batch.")
    parser.add_argument("--save-freq", type=int, default=1, help="Checkpoint every N epochs.")
    parser.add_argument("--eval-freq", type=int, default=1, help="Evaluate every N epochs.")
    return parser
def _safe_pct(numerator, denominator):
    """Return ``numerator / denominator`` as a percentage, or 0.0 when the denominator is 0."""
    return (numerator / denominator) * 100 if denominator else 0.0


def _shard_and_forward(model, batch_data, args):
    """Run the forward pass and loss on this rank's sequence slice of one batch.

    The batch tuple comes straight from the dataloader. Both the encoder
    (X) and decoder (y) sequences are split into per-rank ranges with
    ``det_range_splits``; only the slice owned by ``args.rank`` is moved to the
    GPU, and one attention-mask chunk per rank's key/value range is passed to
    the model.

    Returns:
        ``(loss, top1, top2, target_tokens, top1_per_exp, tgt_tkn_per_exp, batch_size)``;
        the first six values are ``calculate_effective_loss`` results for the
        local slice only.
    """
    X_data, X_coord, x_padd_mask, y_data, y_coord, y_padd_mask, prompt_lengths = batch_data
    # Autoregressive shift: the decoder consumes y[:, :-1]; the padding mask is
    # shifted to line up with the targets y[:, 1:].
    y_data_inputs = y_data[:, :-1]
    y_coord_inputs = y_coord[:, :-1, :]
    y_padd_mask_targets = y_padd_mask[:, 1:]
    batch_size, x_seq_len = X_data.shape
    y_seq_len = y_data_inputs.shape[1]
    # Additive encoder padding mask: (batch_size, 1, q_seq_len, kv_seq_len).
    x_padd_mask = torch.where(x_padd_mask == True, -torch.inf, 0)
    x_padd_mask = x_padd_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, x_seq_len, x_seq_len)
    y_causal_mask = nn.Transformer.generate_square_subsequent_mask(y_seq_len, device="cuda", dtype=args.dtype)
    # Keep only the locally owned query rows; mask columns are split per rank.
    X_split_ranges, _, X_local_range, _ = det_range_splits(x_seq_len, args.rank, args.world_size)
    y_split_ranges, _, y_local_range, _ = det_range_splits(y_seq_len, args.rank, args.world_size)
    X_local = slice(*X_local_range)
    y_local = slice(*y_local_range)
    X_data_chunk = X_data[:, X_local].to("cuda")
    X_coord_chunk = X_coord[:, X_local, :].to("cuda")
    X_padd_mask_chunks = [x_padd_mask[:, :, X_local, slice(*r)].to("cuda") for r in X_split_ranges]
    y_data_chunk = y_data_inputs[:, y_local].to("cuda")
    y_coord_chunk = y_coord_inputs[:, y_local, :].to("cuda")
    y_causal_mask_chunks = [y_causal_mask[y_local, slice(*r)].to("cuda") for r in y_split_ranges]
    logits = model(X_data_chunk, X_coord_chunk, X_padd_mask_chunks, y_data_chunk, y_coord_chunk, y_causal_mask_chunks)
    y_padd_mask_chunk = y_padd_mask_targets[:, y_local].to("cuda")
    # NOTE(review): the loss receives the decoder *inputs* (y[:, :-1]) while the
    # padding mask is the shifted *targets* mask (y[:, 1:]). Unless
    # calculate_effective_loss shifts internally, y_data[:, 1:][:, y_local]
    # was probably intended here — confirm.
    loss, top1, top2, target_tokens, top1_per_exp, tgt_tkn_per_exp = calculate_effective_loss(
        logits, y_data_chunk, y_padd_mask_chunk, prompt_lengths, y_local_range[0], device="cuda"
    )
    return loss, top1, top2, target_tokens, top1_per_exp, tgt_tkn_per_exp, batch_size


def _all_reduce_batch_metrics(loss, top1, top2, target_tokens, top1_per_exp, tgt_tkn_per_exp):
    """Sum one batch's per-rank metrics across the process group.

    ``top1_per_exp`` and ``tgt_tkn_per_exp`` are reduced in place. Every rank
    must call this in the same order, since it issues collectives.

    Returns:
        ``(loss, top1, top2, target_tokens, num_solved)`` as Python numbers,
        where ``num_solved`` counts examples whose summed top-1 hits equal their
        summed target-token count.
    """
    agg_results = torch.tensor([loss.item(), top1.item(), top2.item(), target_tokens.item()], device="cuda")
    dist.all_reduce(agg_results)
    dist.all_reduce(top1_per_exp)
    dist.all_reduce(tgt_tkn_per_exp)
    num_solved = (top1_per_exp == tgt_tkn_per_exp).sum().item()
    agg_loss, agg_top1, agg_top2, agg_target_tokens = agg_results.tolist()
    return agg_loss, agg_top1, agg_top2, agg_target_tokens, num_solved


def train_loop(epoch, model, optimizer, app_state_dict, train_dataloader, eval_dataloader, saver, args):
    """Run one training epoch, then optionally checkpoint and evaluate.

    Gradients are accumulated over every batch of ``train_dataloader`` and a
    single optimizer step is taken at the end of the epoch. Each batch is split
    along the sequence dimension across ranks; metrics are summed across ranks.

    Args:
        epoch: Zero-based epoch index (used for logging, save/eval frequency,
            and stored in the checkpoint metadata).
        model: The sharded model.
        optimizer: Optimizer over ``model``'s parameters.
        app_state_dict: State dict passed to ``dcp.save`` when checkpointing.
        train_dataloader / eval_dataloader: Batch sources for train and eval.
        saver: ``AtomicDirectory`` managing checkpoint directories.
        args: Namespace carrying ``rank``, ``world_size``, ``dtype``,
            ``is_master``, ``save_freq`` and ``eval_freq``.
    """
    train_batches_per_epoch = len(train_dataloader)
    model.train()
    cum_loss = cum_top1 = cum_top2 = cum_target_tokens = 0
    cum_examples_solved = cum_examples_seen = 0
    for batch, batch_data in enumerate(train_dataloader):
        loss, top1, top2, target_tokens, top1_per_exp, tgt_tkn_per_exp, batch_size = _shard_and_forward(model, batch_data, args)
        # One optimizer step per epoch: gradients accumulate across batches, so
        # scale each batch's loss down by the number of batches.
        loss /= train_batches_per_epoch
        loss.backward()
        agg_loss, agg_top1, agg_top2, agg_target_tokens, num_solved = _all_reduce_batch_metrics(
            loss, top1, top2, target_tokens, top1_per_exp, tgt_tkn_per_exp
        )
        cum_loss += agg_loss
        cum_top1 += agg_top1
        cum_top2 += agg_top2
        cum_target_tokens += agg_target_tokens
        cum_examples_seen += batch_size
        cum_examples_solved += num_solved
        timer.report(f"top1_per_exp {top1_per_exp}")
        timer.report(f"tgt_tkn_per_exp {tgt_tkn_per_exp}")
        timer.report(f"Epoch [{epoch}] Batch [{batch + 1} / {train_batches_per_epoch}] solved {num_solved} loss {agg_loss:,.2f}")
        # Release this batch's tensors before the dataloader yields the next one.
        del loss, batch_data
    optimizer.step()
    optimizer.zero_grad()

    if (epoch + 1) % args.save_freq == 0:
        checkpoint_directory = saver.prepare_checkpoint_directory()
        # dcp.save is collective: every rank takes part in writing the state.
        dcp.save(state_dict=app_state_dict, storage_writer=dcp.FileSystemWriter(checkpoint_directory))
        if args.is_master:
            atomic_torch_save({"epoch": epoch}, os.path.join(checkpoint_directory, "checkpoint.pt"))
        # NOTE(review): original indentation was lost; symlink_latest is called on
        # every rank here (AtomicDirectory was built with is_master) — confirm.
        saver.symlink_latest(checkpoint_directory)

    top1_pct = _safe_pct(cum_top1, cum_target_tokens)
    top2_pct = _safe_pct(cum_top2, cum_target_tokens)
    pct_solved = _safe_pct(cum_examples_solved, cum_examples_seen)
    timer.report(f"Train Epoch [{epoch}] Loss [{cum_loss:,.2f}] Solved [{cum_examples_solved} / {cum_examples_seen}, {pct_solved:.2f}%] Top1 [{top1_pct:.2f}%] Top2 [{top2_pct:.2f}%] Targets [{int(cum_target_tokens):,}]")

    if (epoch + 1) % args.eval_freq == 0:
        eval_batches_per_epoch = len(eval_dataloader)
        model.eval()
        cum_loss = cum_top1 = cum_top2 = cum_target_tokens = 0
        cum_examples_solved = cum_examples_seen = 0
        # Evaluation never calls backward, so skip building an autograd graph.
        with torch.no_grad():
            for batch_data in eval_dataloader:
                loss, top1, top2, target_tokens, top1_per_exp, tgt_tkn_per_exp, batch_size = _shard_and_forward(model, batch_data, args)
                loss /= eval_batches_per_epoch
                agg_loss, agg_top1, agg_top2, agg_target_tokens, num_solved = _all_reduce_batch_metrics(
                    loss, top1, top2, target_tokens, top1_per_exp, tgt_tkn_per_exp
                )
                cum_loss += agg_loss
                cum_top1 += agg_top1
                cum_top2 += agg_top2
                cum_target_tokens += agg_target_tokens
                cum_examples_seen += batch_size
                cum_examples_solved += num_solved
                del loss, batch_data
        top1_pct = _safe_pct(cum_top1, cum_target_tokens)
        top2_pct = _safe_pct(cum_top2, cum_target_tokens)
        pct_solved = _safe_pct(cum_examples_solved, cum_examples_seen)
        timer.report(f"Eval Epoch [{epoch}] Loss [{cum_loss:,.2f}] Solved [{cum_examples_solved} / {cum_examples_seen}, {pct_solved:.2f}%] Top1 [{top1_pct:.2f}%] Top2 [{top2_pct:.2f}%] Targets [{int(cum_target_tokens):,}]")
# Progress marker: helper functions and the training/eval loop are defined by this point.
timer.report("Defined helper function/s, loops, and model")
def main(args, timer):
    """Configure distributed training, build the model/optimizer, resume if possible, and train.

    Reads ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE`` and ``CHECKPOINT_ARTIFACT_PATH``
    from the environment and stores ``rank``, ``device_id``, ``is_master``,
    ``world_size`` and ``dtype`` on ``args`` for ``train_loop``.

    Args:
        args: Parsed namespace from ``get_args_parser``; mutated in place.
        timer: Progress reporter exposing ``report(message)``.

    Raises:
        KeyError: If a required environment variable is missing.
        RuntimeError: If CUDA is not available.
    """
    args.rank = int(os.environ["RANK"])  # Rank of this GPU in the cluster
    args.device_id = int(os.environ["LOCAL_RANK"])  # Rank on the local node
    args.is_master = args.rank == 0  # Rank 0 handles saving / reporting
    args.world_size = int(os.environ["WORLD_SIZE"])  # Total number of GPUs in the cluster
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for training but is not available")
    device = torch.device("cuda", args.device_id)
    torch.cuda.set_device(device)
    # init_process_group is not called explicitly; this relies on
    # init_device_mesh to set up the default process group.
    mesh = init_device_mesh(device_type="cuda", mesh_shape=(args.world_size,), mesh_dim_names=("world",))
    timer.report("Setup for distributed training")

    # bfloat16 selection is currently disabled; the model runs in float32.
    args.dtype = torch.float32
    model = Model(n_encoder_layers=15, n_decoder_layers=15, embed_dim=64, nhead=4, head_dim=32, ff_dim=256, device=args.device_id, ring=True, dtype=args.dtype)
    model = model.to(args.device_id)
    model = FSDP(
        model,
        mesh=mesh,
        offload_policy=CPUOffloadPolicy(),
        mp_policy=MixedPrecisionPolicy(),
    )
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    timer.report(f"Initialized model with {params:,} params.")
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    timer.report(
        f"Ready for training with hyper-parameters: \ninitial learning_rate: {args.lr}, "
        f"\nbatch_size: {args.batch_size}, \nepochs: {args.epochs}"
    )

    output_directory = os.environ["CHECKPOINT_ARTIFACT_PATH"]
    saver = AtomicDirectory(output_directory=output_directory, is_master=args.is_master)
    app_state_dict = {"app": AppState(model, optimizer)}
    start_epoch = 0
    latest_symlink_file_path = os.path.join(output_directory, saver.symlink_name)
    if os.path.islink(latest_symlink_file_path):
        # os.path.join keeps an absolute link target unchanged and resolves a
        # relative one against the symlink's own directory (output_directory).
        latest_checkpoint_path = os.path.join(output_directory, os.readlink(latest_symlink_file_path))
        checkpoint_path = os.path.join(latest_checkpoint_path, "checkpoint.pt")
        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=f"cuda:{args.device_id}")
        # train_loop writes the checkpoint after that epoch's optimizer step, so
        # the saved epoch is complete; resume from the next one.
        start_epoch = checkpoint["epoch"] + 1
        dcp.load(state_dict=app_state_dict, checkpoint_id=latest_checkpoint_path)
        timer.report("Retrieved savedcheckpoint")

    train_data_path = "/root/ARC-AGI-2/data/training"
    eval_data_path = "/root/ARC-AGI-2/data/evaluation"

    def _collate(batch):
        # Batches stay on CPU; train_loop moves only the local shard to the GPU.
        # min_seq_len is set to world_size (presumably so every rank gets at least one position).
        return collate_batch(batch, device="cpu", min_seq_len=args.world_size)

    for epoch in range(start_epoch, args.epochs):
        # Re-seed Python's RNG before rebuilding the datasets each epoch.
        # NOTE(review): presumably ARC_AGI_Dataset draws from `random` — confirm.
        random.seed(epoch)
        train_dataset = ARC_AGI_Dataset(train_data_path)
        eval_dataset = ARC_AGI_Dataset(eval_data_path)
        train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=0, collate_fn=_collate)
        eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size, num_workers=0, collate_fn=_collate)
        train_loop(
            epoch,
            model,
            optimizer,
            app_state_dict,
            train_dataloader,
            eval_dataloader,
            saver,
            args,
        )
    print("Done!")
# Script entry point: parse CLI arguments and run training with the module-level timer.
# main() reads RANK, LOCAL_RANK, WORLD_SIZE and CHECKPOINT_ARTIFACT_PATH from the environment.
if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args, timer)