@@ -15,6 +15,7 @@
 import copy
 import os
 import warnings
+import json
 import torch
 import logging
 from pytorch_lightning import LightningModule
@@ -27,25 +28,20 @@
 class TrainingTask(LightningModule):
     """
     Pytorch Lightning module of a general training task.
+    Including training, evaluating and testing.
+    Args:
+        cfg: Training configurations
+        evaluator: Evaluator for evaluating the model performance.
     """
 
-    def __init__(self, cfg, evaluator=None, logger=None):
-        """
-
-        Args:
-            cfg: Training configurations
-            evaluator:
-            logger:
-        """
+    def __init__(self, cfg, evaluator=None):
         super(TrainingTask, self).__init__()
         self.cfg = cfg
         self.model = build_model(cfg.model)
         self.evaluator = evaluator
-        self._logger = logger
        self.save_flag = -10
        self.log_style = 'NanoDet'  # Log style. Choose between 'NanoDet' or 'Lightning'
        # TODO: use callback to log
-        # TODO: remove _logger
        # TODO: batch eval
        # TODO: support old checkpoint
 
@@ -54,7 +50,7 @@ def forward(self, x):
         return x
 
     @torch.no_grad()
-    def predict(self, batch, batch_idx, dataloader_idx):
+    def predict(self, batch, batch_idx=None, dataloader_idx=None):
         preds = self.forward(batch['img'])
         results = self.model.head.post_process(preds, batch)
         return results
@@ -103,11 +99,17 @@ def validation_step(self, batch, batch_idx):
         return res
 
     def validation_epoch_end(self, validation_step_outputs):
+        """
+        Called at the end of the validation epoch with the outputs of all validation steps.
+        Evaluating results and save best model.
+        Args:
+            validation_step_outputs: A list of val outputs
+
+        """
         results = {}
         for res in validation_step_outputs:
             results.update(res)
-        eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, self.current_epoch + 1,
-                                               self._logger, rank=self.local_rank)
+        eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, rank=self.local_rank)
         metric = eval_results[self.cfg.evaluator.save_key]
         # save best model
         if metric > self.save_flag:
@@ -125,9 +127,39 @@ def validation_epoch_end(self, validation_step_outputs):
             warnings.warn('Warning! Save_key is not in eval results! Only save model last!')
         if self.log_style == 'Lightning':
             for k, v in eval_results.items():
-                self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+                self.log('Val_metrics/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+        elif self.log_style == 'NanoDet':
+            for k, v in eval_results.items():
+                self.scalar_summary('Val_metrics/' + k, 'Val', v, self.current_epoch + 1)
+
+    def test_step(self, batch, batch_idx):
+        dets = self.predict(batch, batch_idx)
+        res = {batch['img_info']['id'].cpu().numpy()[0]: dets}
+        return res
+
+    def test_epoch_end(self, test_step_outputs):
+        results = {}
+        for res in test_step_outputs:
+            results.update(res)
+        res_json = self.evaluator.results2json(results)
+        json_path = os.path.join(self.cfg.save_dir, 'results.json')
+        json.dump(res_json, open(json_path, 'w'))
+
+        if self.cfg.test_mode == 'val':
+            eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, rank=self.local_rank)
+            txt_path = os.path.join(self.cfg.save_dir, "eval_results.txt")
+            with open(txt_path, "a") as f:
+                for k, v in eval_results.items():
+                    f.write("{}: {}\n".format(k, v))
 
     def configure_optimizers(self):
+        """
+        Prepare optimizer and learning-rate scheduler
+        to use in optimization.
+
+        Returns:
+            optimizer
+        """
         optimizer_cfg = copy.deepcopy(self.cfg.schedule.optimizer)
         name = optimizer_cfg.pop('name')
         build_optimizer = getattr(torch.optim, name)
@@ -153,6 +185,18 @@ def optimizer_step(self,
                        on_tpu=None,
                        using_native_amp=None,
                        using_lbfgs=None):
+        """
+        Performs a single optimization step (parameter update).
+        Args:
+            epoch: Current epoch
+            batch_idx: Index of current batch
+            optimizer: A PyTorch optimizer
+            optimizer_idx: If you used multiple optimizers this indexes into that list.
+            optimizer_closure: closure for all optimizers
+            on_tpu: true if TPU backward is required
+            using_native_amp: True if using native amp
+            using_lbfgs: True if the matching optimizer is lbfgs
+        """
         # warm up lr
         if self.trainer.global_step <= self.cfg.schedule.warmup.steps:
             if self.cfg.schedule.warmup.name == 'constant':
@@ -180,6 +224,15 @@ def get_progress_bar_dict(self):
         return items
 
     def scalar_summary(self, tag, phase, value, step):
+        """
+        Write Tensorboard scalar summary log.
+        Args:
+            tag: Name for the tag
+            phase: 'Train' or 'Val'
+            value: Value to record
+            step: Step value to record
+
+        """
         if self.local_rank < 1:
             self.logger.experiment.add_scalars(tag, {phase: value}, step)
 
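For context, a minimal usage sketch of the revised class (not part of the diff): the task is now constructed without a `logger` argument, and testing goes through the standard Lightning loop, which invokes the new `test_step`/`test_epoch_end` hooks. Here `cfg`, `evaluator`, and `test_dataloader` are assumed to be built elsewhere (e.g. by the repository's config and dataset utilities), and the exact Trainer flags depend on the PyTorch Lightning version in use.

    import pytorch_lightning as pl

    # Assumed to exist already: cfg (with cfg.save_dir, cfg.test_mode, ...),
    # evaluator, and test_dataloader.
    task = TrainingTask(cfg, evaluator=evaluator)  # no `logger` argument anymore
    trainer = pl.Trainer(
        gpus=1,
        logger=pl.loggers.TensorBoardLogger(cfg.save_dir, name='logs'),
    )
    # Runs test_step / test_epoch_end; results.json (and, when
    # cfg.test_mode == 'val', eval_results.txt) is written under cfg.save_dir.
    trainer.test(task, test_dataloader)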