Skip to content

Commit 1c2d401

Browse files
committed
fix
1 parent 905c55f commit 1c2d401

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

train.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -730,10 +730,9 @@ def main():
730730
_logger.warning(
731731
"You've requested to log metrics to wandb but package not found. "
732732
"Metrics not being logged to wandb, try `pip install wandb`")
733-
734733
if should_log_to_tensorboard(args):
735734
if has_tensorboard:
736-
writer = SummaryWriter(args.log_tensorboard)
735+
tensorboard_writer = SummaryWriter(args.log_tensorboard)
737736
else:
738737
_logger.warning(
739738
"You've requested to log metrics to tensorboard but package not found. "
@@ -785,6 +784,7 @@ def main():
785784
loss_scaler=loss_scaler,
786785
model_ema=model_ema,
787786
mixup_fn=mixup_fn,
787+
tensorboard_writer=tensorboard_writer,
788788
)
789789

790790
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -857,7 +857,8 @@ def train_one_epoch(
857857
amp_autocast=suppress,
858858
loss_scaler=None,
859859
model_ema=None,
860-
mixup_fn=None
860+
mixup_fn=None,
861+
tensorboard_writer=None,
861862
):
862863
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
863864
if args.prefetcher and loader.mixup_enabled:

0 commit comments

Comments
 (0)