 import json
 import torch
 import logging
+
 from pytorch_lightning import LightningModule
-from typing import Any, List, Dict, Tuple, Optional
+from typing import Any, List
+from nanodet.util import mkdir, gather_results

 from ..model.arch import build_model
-from nanodet.util import mkdir


 class TrainingTask(LightningModule):
@@ -109,28 +110,32 @@ def validation_epoch_end(self, validation_step_outputs):
         results = {}
         for res in validation_step_outputs:
             results.update(res)
-        eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, rank=self.local_rank)
-        metric = eval_results[self.cfg.evaluator.save_key]
-        # save best model
-        if metric > self.save_flag:
-            self.save_flag = metric
-            best_save_path = os.path.join(self.cfg.save_dir, 'model_best')
-            mkdir(self.local_rank, best_save_path)
-            self.trainer.save_checkpoint(os.path.join(best_save_path, "model_best.ckpt"))
-            txt_path = os.path.join(best_save_path, "eval_results.txt")
-            if self.local_rank < 1:
-                with open(txt_path, "a") as f:
-                    f.write("Epoch:{}\n".format(self.current_epoch + 1))
-                    for k, v in eval_results.items():
-                        f.write("{}: {}\n".format(k, v))
+        all_results = gather_results(results)
+        if all_results:
+            eval_results = self.evaluator.evaluate(all_results, self.cfg.save_dir, rank=self.local_rank)
+            metric = eval_results[self.cfg.evaluator.save_key]
+            # save best model
+            if metric > self.save_flag:
+                self.save_flag = metric
+                best_save_path = os.path.join(self.cfg.save_dir, 'model_best')
+                mkdir(self.local_rank, best_save_path)
+                self.trainer.save_checkpoint(os.path.join(best_save_path, "model_best.ckpt"))
+                txt_path = os.path.join(best_save_path, "eval_results.txt")
+                if self.local_rank < 1:
+                    with open(txt_path, "a") as f:
+                        f.write("Epoch:{}\n".format(self.current_epoch + 1))
+                        for k, v in eval_results.items():
+                            f.write("{}: {}\n".format(k, v))
+            else:
+                warnings.warn('Warning! Save_key is not in eval results! Only save model last!')
+            if self.log_style == 'Lightning':
+                for k, v in eval_results.items():
+                    self.log('Val_metrics/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+            elif self.log_style == 'NanoDet':
+                for k, v in eval_results.items():
+                    self.scalar_summary('Val_metrics/' + k, 'Val', v, self.current_epoch + 1)
         else:
-            warnings.warn('Warning! Save_key is not in eval results! Only save model last!')
-        if self.log_style == 'Lightning':
-            for k, v in eval_results.items():
-                self.log('Val_metrics/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
-        elif self.log_style == 'NanoDet':
-            for k, v in eval_results.items():
-                self.scalar_summary('Val_metrics/' + k, 'Val', v, self.current_epoch + 1)
+            self.info('Skip val on rank {}'.format(self.local_rank))

     def test_step(self, batch, batch_idx):
         dets = self.predict(batch, batch_idx)
@@ -140,16 +145,20 @@ def test_epoch_end(self, test_step_outputs):
         results = {}
         for res in test_step_outputs:
             results.update(res)
-        res_json = self.evaluator.results2json(results)
-        json_path = os.path.join(self.cfg.save_dir, 'results.json')
-        json.dump(res_json, open(json_path, 'w'))
-
-        if self.cfg.test_mode == 'val':
-            eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, rank=self.local_rank)
-            txt_path = os.path.join(self.cfg.save_dir, "eval_results.txt")
-            with open(txt_path, "a") as f:
-                for k, v in eval_results.items():
-                    f.write("{}: {}\n".format(k, v))
+        all_results = gather_results(results)
+        if all_results:
+            res_json = self.evaluator.results2json(all_results)
+            json_path = os.path.join(self.cfg.save_dir, 'results.json')
+            json.dump(res_json, open(json_path, 'w'))
+
+            if self.cfg.test_mode == 'val':
+                eval_results = self.evaluator.evaluate(all_results, self.cfg.save_dir, rank=self.local_rank)
+                txt_path = os.path.join(self.cfg.save_dir, "eval_results.txt")
+                with open(txt_path, "a") as f:
+                    for k, v in eval_results.items():
+                        f.write("{}: {}\n".format(k, v))
+        else:
+            self.info('Skip test on rank {}'.format(self.local_rank))

     def configure_optimizers(self):
         """