Skip to content

Commit 04ec450

Browse files
committed
fix
1 parent 1c2d401 commit 04ec450

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

train.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,7 @@ def main():
798798
validate_loss_fn,
799799
args,
800800
amp_autocast=amp_autocast,
801+
tensorboard_writer=tensorboard_writer,
801802
)
802803

803804
if model_ema is not None and not args.model_ema_force_cpu:
@@ -922,8 +923,8 @@ def train_one_epoch(
922923
batch_time_m.update(time.time() - end)
923924
#write to tensorboard if enabled
924925
if should_log_to_tensorboard(args):
925-
writer.add_scalar('train/loss', losses_m.val, num_updates)
926-
writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], num_updates)
926+
tensorboard_writer.add_scalar('train/loss', losses_m.val, num_updates)
927+
tensorboard_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], num_updates)
927928
if last_batch or batch_idx % args.log_interval == 0:
928929
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
929930
lr = sum(lrl) / len(lrl)
@@ -986,7 +987,9 @@ def validate(
986987
args,
987988
device=torch.device('cuda'),
988989
amp_autocast=suppress,
989-
log_suffix=''
990+
log_suffix='',
991+
tensorboard_writer=None,
992+
990993
):
991994
batch_time_m = utils.AverageMeter()
992995
losses_m = utils.AverageMeter()
@@ -1037,9 +1040,9 @@ def validate(
10371040
batch_time_m.update(time.time() - end)
10381041
end = time.time()
10391042
if should_log_to_tensorboard(args):
1040-
writer.add_scalar('val/loss', losses_m.val, batch_idx)
1041-
writer.add_scalar('val/acc1', top1_m.val, batch_idx)
1042-
writer.add_scalar('val/acc5', top5_m.val, batch_idx)
1043+
tensorboard_writer.add_scalar('val/loss', losses_m.val, batch_idx)
1044+
tensorboard_writer.add_scalar('val/acc1', top1_m.val, batch_idx)
1045+
tensorboard_writer.add_scalar('val/acc5', top5_m.val, batch_idx)
10431046
if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
10441047
log_name = 'Test' + log_suffix
10451048
_logger.info(

0 commit comments

Comments
 (0)