 import os
 import warnings
 import torch
+import logging
 from pytorch_lightning import LightningModule
 from typing import Any, List, Dict, Tuple, Optional

@@ -42,8 +43,11 @@ def __init__(self, cfg, evaluator=None, logger=None):
        self.evaluator = evaluator
        self._logger = logger
        self.save_flag = -10
-       # TODO: better logger
+       self.log_style = 'NanoDet'  # Log style. Choose between 'NanoDet' or 'Lightning'
+       # TODO: use callback to log
+       # TODO: remove _logger
        # TODO: batch eval
+       # TODO: support old checkpoint

    def forward(self, x):
        x = self.model(x)
@@ -57,21 +61,43 @@ def predict(self, batch, batch_idx, dataloader_idx):

    def training_step(self, batch, batch_idx):
        preds, loss, loss_states = self.model.forward_train(batch)
-       self.log('lr', self.optimizers().param_groups[0]['lr'], on_step=True, on_epoch=False, prog_bar=True)
-       for k, v in loss_states.items():
-           self.log('Train/' + k, v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+
+       # log train losses
+       if self.log_style == 'Lightning':
+           self.log('lr', self.optimizers().param_groups[0]['lr'], on_step=True, on_epoch=False, prog_bar=True)
+           for k, v in loss_states.items():
+               self.log('Train/' + k, v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+       elif self.log_style == 'NanoDet' and self.global_step % self.cfg.log.interval == 0:
+           lr = self.optimizers().param_groups[0]['lr']
+           log_msg = 'Train|Epoch{}/{}|Iter{}({})| lr:{:.2e}| '.format(self.current_epoch + 1,
+               self.cfg.schedule.total_epochs, self.global_step, batch_idx, lr)
+           self.scalar_summary('Train_loss/lr', 'Train', lr, self.global_step)
+           for l in loss_states:
+               log_msg += '{}:{:.4f}| '.format(l, loss_states[l].mean().item())
+               self.scalar_summary('Train_loss/' + l, 'Train', loss_states[l].mean().item(), self.global_step)
+           self.info(log_msg)
+
        return loss
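
For reference, a standalone sketch of the message the 'NanoDet' branch above assembles (the epoch/iteration/lr values and the loss_qfl key are purely illustrative):

    log_msg = 'Train|Epoch{}/{}|Iter{}({})| lr:{:.2e}| '.format(1, 280, 100, 50, 1e-3)
    log_msg += '{}:{:.4f}| '.format('loss_qfl', 0.5123)
    print(log_msg)  # Train|Epoch1/280|Iter100(50)| lr:1.00e-03| loss_qfl:0.5123|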

    def training_epoch_end(self, outputs: List[Any]) -> None:
-       self.print('Epoch ', self.current_epoch, ' finished.')
        self.trainer.save_checkpoint(os.path.join(self.cfg.save_dir, 'model_last.ckpt'))
        self.lr_scheduler.step()

    def validation_step(self, batch, batch_idx):
        preds, loss, loss_states = self.model.forward_train(batch)
-       self.log('Val/loss', loss, on_step=True, on_epoch=False, prog_bar=True, logger=False)
-       for k, v in loss_states.items():
-           self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+
+       if self.log_style == 'Lightning':
+           self.log('Val/loss', loss, on_step=True, on_epoch=False, prog_bar=True, logger=False)
+           for k, v in loss_states.items():
+               self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+       elif self.log_style == 'NanoDet' and batch_idx % self.cfg.log.interval == 0:
+           lr = self.optimizers().param_groups[0]['lr']
+           log_msg = 'Val|Epoch{}/{}|Iter{}({})| lr:{:.2e}| '.format(self.current_epoch + 1,
+               self.cfg.schedule.total_epochs, self.global_step, batch_idx, lr)
+           for l in loss_states:
+               log_msg += '{}:{:.4f}| '.format(l, loss_states[l].mean().item())
+           self.info(log_msg)
+
        dets = self.model.head.post_process(preds, batch)
        res = {batch['img_info']['id'].cpu().numpy()[0]: dets}
        return res
@@ -80,7 +106,7 @@ def validation_epoch_end(self, validation_step_outputs):
        results = {}
        for res in validation_step_outputs:
            results.update(res)
-       eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, self.current_epoch,
+       eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, self.current_epoch + 1,
                                               self._logger, rank=self.local_rank)
        metric = eval_results[self.cfg.evaluator.save_key]
        # save best model
@@ -97,8 +123,9 @@ def validation_epoch_end(self, validation_step_outputs):
                    f.write("{}: {}\n".format(k, v))
        else:
            warnings.warn('Warning! Save_key is not in eval results! Only save model last!')
-       for k, v in eval_results.items():
-           self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+       if self.log_style == 'Lightning':
+           for k, v in eval_results.items():
+               self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)

    def configure_optimizers(self):
        optimizer_cfg = copy.deepcopy(self.cfg.schedule.optimizer)
@@ -140,8 +167,6 @@ def optimizer_step(self,
                raise Exception('Unsupported warm up type!')
            for pg in optimizer.param_groups:
                pg['lr'] = warmup_lr
-           # TODO: log lr to tensorboard
-           # self.log('lr', optimizer.param_groups[0]['lr'], on_step=True, on_epoch=True, prog_bar=True)

        # update params
        optimizer.step(closure=optimizer_closure)
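
The warmup_lr written into the param groups above is computed earlier in optimizer_step and is not part of this hunk. As a rough sketch only of what a linear warmup of this kind typically computes (warmup_steps and warmup_ratio are illustrative names, not necessarily this repo's cfg keys):

    def linear_warmup_lr(base_lr, step, warmup_steps, warmup_ratio):
        # ramp linearly from base_lr * warmup_ratio up to base_lr over warmup_steps
        k = min(step, warmup_steps) / float(warmup_steps)
        return base_lr * (warmup_ratio + (1.0 - warmup_ratio) * k)

    print(linear_warmup_lr(0.14, 50, 300, 0.1))  # ~0.035 part-way through warmup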
@@ -151,8 +176,18 @@ def get_progress_bar_dict(self):
        # don't show the version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
+       items.pop("loss", None)
        return items

+   def scalar_summary(self, tag, phase, value, step):
+       if self.local_rank < 1:
+           self.logger.experiment.add_scalars(tag, {phase: value}, step)
+
+   def info(self, string):
+       if self.local_rank < 1:
+           logging.info(string)
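
The scalar_summary helper assumes a TensorBoard-backed Lightning logger, where self.logger.experiment is a torch.utils.tensorboard SummaryWriter; a minimal standalone equivalent of the call it makes (the log directory here is made up):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('./tb_logs_example')  # hypothetical log dir
    writer.add_scalars('Train_loss/lr', {'Train': 1e-3}, global_step=100)
    writer.close()

The local_rank < 1 guard keeps both helpers silent on non-zero ranks during distributed training.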