
Commit fb82fbb

Merge pull request #205 from RangiLyu/refactor
[Refactor] replace lightning log with old style log
2 parents 8d5f011 + 1dc9405 commit fb82fbb

2 files changed, +51 -14 lines changed

nanodet/trainer/task.py

+48 -13
@@ -16,6 +16,7 @@
 import os
 import warnings
 import torch
+import logging
 from pytorch_lightning import LightningModule
 from typing import Any, List, Dict, Tuple, Optional
@@ -42,8 +43,11 @@ def __init__(self, cfg, evaluator=None, logger=None):
         self.evaluator = evaluator
         self._logger = logger
         self.save_flag = -10
-        # TODO: better logger
+        self.log_style = 'NanoDet'  # Log style. Choose between 'NanoDet' or 'Lightning'
+        # TODO: use callback to log
+        # TODO: remove _logger
         # TODO: batch eval
+        # TODO: support old checkpoint
 
     def forward(self, x):
         x = self.model(x)
@@ -57,21 +61,43 @@ def predict(batch, batch_idx, dataloader_idx):
 
     def training_step(self, batch, batch_idx):
         preds, loss, loss_states = self.model.forward_train(batch)
-        self.log('lr', self.optimizers().param_groups[0]['lr'], on_step=True, on_epoch=False, prog_bar=True)
-        for k, v in loss_states.items():
-            self.log('Train/'+k, v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+
+        # log train losses
+        if self.log_style == 'Lightning':
+            self.log('lr', self.optimizers().param_groups[0]['lr'], on_step=True, on_epoch=False, prog_bar=True)
+            for k, v in loss_states.items():
+                self.log('Train/'+k, v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+        elif self.log_style == 'NanoDet' and self.global_step % self.cfg.log.interval == 0:
+            lr = self.optimizers().param_groups[0]['lr']
+            log_msg = 'Train|Epoch{}/{}|Iter{}({})| lr:{:.2e}| '.format(self.current_epoch+1,
+                self.cfg.schedule.total_epochs, self.global_step, batch_idx, lr)
+            self.scalar_summary('Train_loss/lr', 'Train', lr, self.global_step)
+            for l in loss_states:
+                log_msg += '{}:{:.4f}| '.format(l, loss_states[l].mean().item())
+                self.scalar_summary('Train_loss/' + l, 'Train', loss_states[l].mean().item(), self.global_step)
+            self.info(log_msg)
+
         return loss
 
     def training_epoch_end(self, outputs: List[Any]) -> None:
-        self.print('Epoch ', self.current_epoch, ' finished.')
         self.trainer.save_checkpoint(os.path.join(self.cfg.save_dir, 'model_last.ckpt'))
         self.lr_scheduler.step()
 
     def validation_step(self, batch, batch_idx):
         preds, loss, loss_states = self.model.forward_train(batch)
-        self.log('Val/loss', loss, on_step=True, on_epoch=False, prog_bar=True, logger=False)
-        for k, v in loss_states.items():
-            self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+
+        if self.log_style == 'Lightning':
+            self.log('Val/loss', loss, on_step=True, on_epoch=False, prog_bar=True, logger=False)
+            for k, v in loss_states.items():
+                self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+        elif self.log_style == 'NanoDet' and batch_idx % self.cfg.log.interval == 0:
+            lr = self.optimizers().param_groups[0]['lr']
+            log_msg = 'Val|Epoch{}/{}|Iter{}({})| lr:{:.2e}| '.format(self.current_epoch+1,
+                self.cfg.schedule.total_epochs, self.global_step, batch_idx, lr)
+            for l in loss_states:
+                log_msg += '{}:{:.4f}| '.format(l, loss_states[l].mean().item())
+            self.info(log_msg)
+
         dets = self.model.head.post_process(preds, batch)
         res = {batch['img_info']['id'].cpu().numpy()[0]: dets}
         return res
@@ -80,7 +106,7 @@ def validation_epoch_end(self, validation_step_outputs):
         results = {}
         for res in validation_step_outputs:
             results.update(res)
-        eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, self.current_epoch,
+        eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, self.current_epoch+1,
                                                self._logger, rank=self.local_rank)
         metric = eval_results[self.cfg.evaluator.save_key]
         # save best model
@@ -97,8 +123,9 @@ def validation_epoch_end(self, validation_step_outputs):
                     f.write("{}: {}\n".format(k, v))
         else:
             warnings.warn('Warning! Save_key is not in eval results! Only save model last!')
-        for k, v in eval_results.items():
-            self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+        if self.log_style == 'Lightning':
+            for k, v in eval_results.items():
+                self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
 
     def configure_optimizers(self):
         optimizer_cfg = copy.deepcopy(self.cfg.schedule.optimizer)
@@ -140,8 +167,6 @@ def optimizer_step(self,
                 raise Exception('Unsupported warm up type!')
             for pg in optimizer.param_groups:
                 pg['lr'] = warmup_lr
-        # TODO: log lr to tensorboard
-        # self.log('lr', optimizer.param_groups[0]['lr'], on_step=True, on_epoch=True, prog_bar=True)
 
         # update params
         optimizer.step(closure=optimizer_closure)
@@ -151,8 +176,18 @@ def get_progress_bar_dict(self):
         # don't show the version number
         items = super().get_progress_bar_dict()
         items.pop("v_num", None)
+        items.pop("loss", None)
         return items
 
+    def scalar_summary(self, tag, phase, value, step):
+        if self.local_rank < 1:
+            self.logger.experiment.add_scalars(tag, {phase: value}, step)
+
+    def info(self, string):
+        if self.local_rank < 1:
+            logging.info(string)
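
With the default log_style = 'NanoDet', training_step and validation_step now assemble a plain console string every cfg.log.interval steps, emit it through logging.info on rank 0 via the new info() helper, and mirror each loss to TensorBoard through scalar_summary() (self.logger.experiment.add_scalars). Below is a minimal standalone sketch of just the message formatting; the loss names and numbers (loss_qfl, loss_bbox, loss_dfl, the epoch and iteration counts) are illustrative placeholders, not values from this commit:

import torch

# Illustrative loss dict; in the trainer the keys come from model.forward_train(batch).
loss_states = {
    'loss_qfl': torch.tensor(0.3521),
    'loss_bbox': torch.tensor(0.8130),
    'loss_dfl': torch.tensor(0.2817),
}
current_epoch, total_epochs = 0, 280          # hypothetical schedule
global_step, batch_idx, lr = 1200, 50, 1.4e-3

log_msg = 'Train|Epoch{}/{}|Iter{}({})| lr:{:.2e}| '.format(
    current_epoch + 1, total_epochs, global_step, batch_idx, lr)
for name in loss_states:
    log_msg += '{}:{:.4f}| '.format(name, loss_states[name].mean().item())

print(log_msg)
# Train|Epoch1/280|Iter1200(50)| lr:1.40e-03| loss_qfl:0.3521| loss_bbox:0.8130| loss_dfl:0.2817|

Setting task.log_style = 'Lightning' before calling trainer.fit restores the previous self.log-based behaviour.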

tools/train_pl.py

+3 -1
@@ -17,6 +17,7 @@
 import argparse
 import numpy as np
 import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ProgressBar
 
 from nanodet.util import mkdir, Logger, cfg, load_config
 from nanodet.data.collate import collate_function
@@ -91,7 +92,8 @@ def main(args):
                      accelerator='ddp',
                      log_every_n_steps=cfg.log.interval,
                      num_sanity_val_steps=0,
-                     resume_from_checkpoint=model_resume_path
+                     resume_from_checkpoint=model_resume_path,
+                     callbacks=[ProgressBar(refresh_rate=0)]  # disable tqdm bar
                      )
 
     trainer.fit(task, train_dataloader, val_dataloader)
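
Disabling the tqdm bar keeps the old-style log lines from being interleaved with a live progress display. A rough sketch of the relevant Trainer setup under pytorch_lightning 1.x, where a ProgressBar callback constructed with refresh_rate=0 turns the bar off; the other argument values here are illustrative, not copied from the repo's config:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ProgressBar

trainer = pl.Trainer(
    max_epochs=280,                           # illustrative value
    log_every_n_steps=50,                     # stands in for cfg.log.interval
    num_sanity_val_steps=0,
    callbacks=[ProgressBar(refresh_rate=0)],  # refresh_rate=0 disables the tqdm bar
)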
