train.py
"""
Script for training model
"""
import os
import numpy as np
import torch
from src.tokenizer.basic_tokenizer import CONTEXT_SIZE
from src.models.transformer import Network
# Config params
lr = 1e-4
N_ITER = 60000
EVAL_ROUND = 1000
EVAL_ITER = 20
BATCH_SIZE = 32
# Directory where data is stored
data_dir = "data"
# Directory where model weights will be saved
model_dir = 'src/trained_models/model.pkl'
# Set the correct device
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_batch(split):
    """
    Get a train/test batch, recreating the memory mapping every time to
    avoid a memory leak.
    https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    """
    # Recreate np.memmap every batch to avoid the memory leak
    ids_path = os.path.join(data_dir, f"{split}_ids.bin")
    # Path for the binned scores
    scores_path = os.path.join(data_dir, f"{split}_percentages.bin")
    # Create memory-mapped arrays
    x = np.memmap(ids_path, dtype=np.uint8, mode='r')
    y = np.memmap(scores_path, dtype=np.uint8, mode='r')
    # Reshape using the context size used to tokenize the data;
    # B is the number of context-sized samples in the split
    B = int(x.shape[0] / CONTEXT_SIZE)
    x = x.reshape(B, CONTEXT_SIZE)
    # Sample random batch indices
    ix = torch.randint(B, (BATCH_SIZE,))
    # Convert to PyTorch tensors and cast so the dtypes are compatible with the model
    x, y = torch.from_numpy(x[ix].astype(np.int32)), torch.from_numpy(y[ix].astype(np.int64))
    return x.to(device), y.to(device)
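
# Note on shapes: get_batch returns x of shape (BATCH_SIZE, CONTEXT_SIZE) with dtype
# int32 and y of shape (BATCH_SIZE,) with dtype int64, assuming the scores file
# stores one uint8 label per context window. int64 is the integer type typically
# required for class targets (e.g. by cross-entropy loss).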

@torch.no_grad()
def evaluate_model(model):
    """
    Evaluates the model on both train and test data. The loss
    is calculated EVAL_ITER times, and the mean is returned
    for both splits.
    """
    out = {}
    model.eval()
    for split in ['train', 'test']:
        lossi = torch.zeros(EVAL_ITER)
        for k in range(EVAL_ITER):
            # Get batch
            xb, yb = get_batch(split)
            # Forward pass
            logits, loss = model(xb, yb)
            # Save loss
            lossi[k] = loss.item()
        out[split] = lossi.mean()
    # Put the model back into training mode
    model.train()
    return out


def train_loop(model):
    """
    Simple train loop which trains the model for N_ITER iterations.
    Evaluation happens every EVAL_ROUND iterations. If the test loss improves
    on the best seen so far, the weights are checkpointed.
    """
    min_test = float("inf")
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for i in range(N_ITER):
        xb, yb = get_batch('train')
        optimizer.zero_grad()
        logits, loss = model(xb, yb)
        loss.backward()
        optimizer.step()
        if i % EVAL_ROUND == 0:
            losses = evaluate_model(model)
            train_loss, test_loss = losses['train'], losses['test']
            print(f"{i}/{N_ITER} Train loss: {train_loss:.4f} Test loss: {test_loss:.4f}")
            if test_loss < min_test:
                min_test = test_loss
                checkpoint = {'model': model.state_dict()}
                torch.save(checkpoint, model_dir)
                print(f"Writing new weights to disk with test loss {min_test:.4f}")


if __name__ == '__main__':
    # Initialize the transformer network on the same device as the batches
    model = Network().to(device)
    # Train and save the model
    train_loop(model)
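
# Example (left commented out): how the checkpoint saved by train_loop could be
# loaded back for inference. Assumes Network() takes no constructor arguments,
# as in the training code above.
#
#   model = Network().to(device)
#   state = torch.load(model_dir, map_location=device)
#   model.load_state_dict(state['model'])
#   model.eval()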