-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalidation.py
47 lines (35 loc) · 1.32 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import logging
import torch
import torch.distributed as dist
log = logging.getLogger(__name__)
class Validation:
'''
Module to test validation data set as the model is learning.
Will test over the same dataset every time.
'''
def __init__(self, model, data_loader, tParams, ddp):
self.model = model
self.data_loader = data_loader
self.ddp = ddp
self.validation_steps = tParams.validation_steps
def run_validation(self, step):
'''
Returns validation loss
'''
self.model.eval()
self.data_loader.reset_validation()
val_loss = 0.0
with torch.no_grad():
for _ in range(self.validation_steps):
input, output = self.data_loader.get_val_samples()
with torch.autocast(device_type=self.ddp.device_type, dtype=torch.bfloat16):
_, loss = self.model(input, output)
val_loss += loss.detach()
if self.ddp.is_avail:
# Reduce loss across all processes
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
self.model.train()
# Average over validation steps
val_loss = val_loss / self.validation_steps
if self.ddp.is_main:
log.info(f'Step ({step}). Val Loss: {val_loss.item():.4f}')