@@ -799,6 +799,7 @@ def main():
799
799
args ,
800
800
amp_autocast = amp_autocast ,
801
801
tensorboard_writer = tensorboard_writer ,
802
+ epoch = epoch ,
802
803
)
803
804
804
805
if model_ema is not None and not args .model_ema_force_cpu :
@@ -812,6 +813,8 @@ def main():
812
813
args ,
813
814
amp_autocast = amp_autocast ,
814
815
log_suffix = ' (EMA)' ,
816
+ tensorboard_writer = tensorboard_writer ,
817
+ epoch = epoch ,
815
818
)
816
819
eval_metrics = ema_eval_metrics
817
820
@@ -989,6 +992,7 @@ def validate(
989
992
amp_autocast = suppress ,
990
993
log_suffix = '' ,
991
994
tensorboard_writer = None ,
995
+ epoch = None ,
992
996
993
997
):
994
998
batch_time_m = utils .AverageMeter ()
@@ -1040,9 +1044,10 @@ def validate(
1040
1044
batch_time_m .update (time .time () - end )
1041
1045
end = time .time ()
1042
1046
if should_log_to_tensorboard (args ):
1043
- tensorboard_writer .add_scalar ('val/loss' , losses_m .val , batch_idx )
1044
- tensorboard_writer .add_scalar ('val/acc1' , top1_m .val , batch_idx )
1045
- tensorboard_writer .add_scalar ('val/acc5' , top5_m .val , batch_idx )
1047
+ #by the updates
1048
+ tensorboard_writer .add_scalar ('val/loss' , losses_m .val , epoch * last_idx + batch_idx )
1049
+ tensorboard_writer .add_scalar ('val/acc1' , top1_m .val , epoch * last_idx + batch_idx )
1050
+ tensorboard_writer .add_scalar ('val/acc5' , top5_m .val , epoch * last_idx + batch_idx )
1046
1051
if utils .is_primary (args ) and (last_batch or batch_idx % args .log_interval == 0 ):
1047
1052
log_name = 'Test' + log_suffix
1048
1053
_logger .info (
0 commit comments