@@ -798,6 +798,7 @@ def main():
798
798
validate_loss_fn ,
799
799
args ,
800
800
amp_autocast = amp_autocast ,
801
+ tensorboard_writer = tensorboard_writer ,
801
802
)
802
803
803
804
if model_ema is not None and not args .model_ema_force_cpu :
@@ -922,8 +923,8 @@ def train_one_epoch(
922
923
batch_time_m .update (time .time () - end )
923
924
#write to tensorboard if enabled
924
925
if should_log_to_tensorboard (args ):
925
- writer .add_scalar ('train/loss' , losses_m .val , num_updates )
926
- writer .add_scalar ('train/lr' , optimizer .param_groups [0 ]['lr' ], num_updates )
926
+ tensorboard_writer .add_scalar ('train/loss' , losses_m .val , num_updates )
927
+ tensorboard_writer .add_scalar ('train/lr' , optimizer .param_groups [0 ]['lr' ], num_updates )
927
928
if last_batch or batch_idx % args .log_interval == 0 :
928
929
lrl = [param_group ['lr' ] for param_group in optimizer .param_groups ]
929
930
lr = sum (lrl ) / len (lrl )
@@ -986,7 +987,9 @@ def validate(
986
987
args ,
987
988
device = torch .device ('cuda' ),
988
989
amp_autocast = suppress ,
989
- log_suffix = ''
990
+ log_suffix = '' ,
991
+ tensorboard_writer = None ,
992
+
990
993
):
991
994
batch_time_m = utils .AverageMeter ()
992
995
losses_m = utils .AverageMeter ()
@@ -1037,9 +1040,9 @@ def validate(
1037
1040
batch_time_m .update (time .time () - end )
1038
1041
end = time .time ()
1039
1042
if should_log_to_tensorboard (args ):
1040
- writer .add_scalar ('val/loss' , losses_m .val , batch_idx )
1041
- writer .add_scalar ('val/acc1' , top1_m .val , batch_idx )
1042
- writer .add_scalar ('val/acc5' , top5_m .val , batch_idx )
1043
+ tensorboard_writer .add_scalar ('val/loss' , losses_m .val , batch_idx )
1044
+ tensorboard_writer .add_scalar ('val/acc1' , top1_m .val , batch_idx )
1045
+ tensorboard_writer .add_scalar ('val/acc5' , top5_m .val , batch_idx )
1043
1046
if utils .is_primary (args ) and (last_batch or batch_idx % args .log_interval == 0 ):
1044
1047
log_name = 'Test' + log_suffix
1045
1048
_logger .info (
0 commit comments