@@ -16,6 +16,7 @@
 
 from fastreid.config import get_cfg
 from fastreid.data import build_reid_test_loader, build_reid_train_loader
+from fastreid.evaluation.testing import flatten_results_dict
 from fastreid.engine import default_argument_parser, default_setup, launch
 from fastreid.modeling import build_model
 from fastreid.solver import build_lr_scheduler, build_optimizer
@@ -33,7 +34,7 @@
 
 
 def get_evaluator(cfg, dataset_name, output_dir=None):
-    data_loader, num_query = build_reid_test_loader(cfg, dataset_name)
+    data_loader, num_query = build_reid_test_loader(cfg, dataset_name=dataset_name)
     return data_loader, ReidEvaluator(cfg, num_query, output_dir)
 
 
@@ -49,24 +50,28 @@ def do_test(cfg, model):
             )
             results[dataset_name] = {}
             continue
-        results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP_ENABLED)
+        results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP.ENABLED)
         results[dataset_name] = results_i
 
-    if comm.is_main_process():
-        assert isinstance(
-            results, dict
-        ), "Evaluator must return a dict on the main process. Got {} instead.".format(
-            results
-        )
-        print_csv_format(results)
+        if comm.is_main_process():
+            assert isinstance(
+                results, dict
+            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
+                results
+            )
+            logger.info("Evaluation results for {} in csv format:".format(dataset_name))
+            results_i['dataset'] = dataset_name
+            print_csv_format(results_i)
 
-    if len(results) == 1: results = list(results.values())[0]
+    if len(results) == 1:
+        results = list(results.values())[0]
 
     return results
 
 
 def do_train(cfg, model, resume=False):
     data_loader = build_reid_train_loader(cfg)
+    data_loader_iter = iter(data_loader)
 
     model.train()
     optimizer = build_optimizer(cfg, model)
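Earlier in this hunk, the flip-test flag moves from a flat `cfg.TEST.FLIP_ENABLED` to a nested `cfg.TEST.FLIP.ENABLED` config node. For readers unfamiliar with the option: flip testing in re-id commonly extracts features for both the original and the horizontally flipped image and averages the two embeddings. A hedged sketch of that general technique, not fastreid's exact code path inside `inference_on_dataset()`:

```python
import torch

@torch.no_grad()
def extract_with_flip(model, images):
    # images: (N, C, H, W) tensor; dim 3 is width, so flip(dims=[3])
    # mirrors each image horizontally
    feat = model(images)
    feat_flip = model(images.flip(dims=[3]))
    # average the two embeddings; an illustrative aggregation, not
    # necessarily the one fastreid uses
    return (feat + feat_flip) / 2
```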
@@ -78,7 +83,7 @@ def do_train(cfg, model, resume=False):
         model,
         cfg.OUTPUT_DIR,
         save_to_disk=comm.is_main_process(),
-        optimizer=optimizer
+        optimizer=optimizer,
         **scheduler
     )
 
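The added trailing comma is a real fix, not style: inside the call parentheses, `optimizer=optimizer` and `**scheduler` on the next line join into `optimizer=optimizer ** scheduler`, an exponentiation expression, instead of a keyword argument followed by a dict unpack. A self-contained toy demonstration of the intended unpacking (all names here are stand-ins, not fastreid code):

```python
def checkpointer_stub(model, save_dir, *, save_to_disk=True, **checkpointables):
    # stand-in for the real Checkpointer: it would hold on to these
    # objects and save their state_dicts
    return {"model": model, "save_dir": save_dir, **checkpointables}

scheduler = {"lr_sched": "scheduler-object"}   # illustrative scheduler dict
ckpt = checkpointer_stub(
    "model-object",
    "./output",
    save_to_disk=True,
    optimizer="optimizer-object",   # without the comma after this argument,
    **scheduler,                    # **scheduler degrades into ** (power)
)
print(sorted(ckpt))  # ['lr_sched', 'model', 'optimizer', 'save_dir']
```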
@@ -93,6 +98,10 @@ def do_train(cfg, model, resume=False):
     delay_epochs = cfg.SOLVER.DELAY_EPOCHS
 
     periodic_checkpointer = PeriodicCheckpointer(checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_epoch)
+    if len(cfg.DATASETS.TESTS) == 1:
+        metric_name = "metric"
+    else:
+        metric_name = cfg.DATASETS.TESTS[0] + "/metric"
 
     writers = (
         [
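Why `metric_name` depends on the number of test sets: `do_test()` keys results by dataset name but unwraps the outer dict when there is exactly one entry (the `len(results) == 1` branch in the earlier hunk), and a later hunk looks the metric up in a flattened copy of that dict. A self-contained sketch of the flattening this relies on, assuming `flatten_results_dict` joins nested keys with "/" as detectron2's helper of the same name does; dataset names and values below are made up:

```python
def flatten_results_dict(results, sep="/"):
    # hedged re-implementation for illustration only
    flat = {}
    for k, v in results.items():
        if isinstance(v, dict):
            for kk, vv in flatten_results_dict(v, sep).items():
                flat[k + sep + kk] = vv
        else:
            flat[k] = v
    return flat

multi = {"Market1501": {"metric": 85.2}, "DukeMTMC": {"metric": 76.1}}
print(flatten_results_dict(multi))
# {'Market1501/metric': 85.2, 'DukeMTMC/metric': 76.1}
# -> with several test sets the lookup key is "<first dataset>/metric";
#    with a single test set do_test() unwraps the outer dict, so the
#    flat key is just "metric"
```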
@@ -111,7 +120,8 @@ def do_train(cfg, model, resume=False):
     with EventStorage(start_iter) as storage:
         for epoch in range(start_epoch, max_epoch):
             storage.epoch = epoch
-            for data, _ in zip(data_loader, range(iters_per_epoch)):
+            for _ in range(iters_per_epoch):
+                data = next(data_loader_iter)
                 storage.iter = iteration
 
                 loss_dict = model(data)
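Hoisting `iter(data_loader)` out of the epoch loop (the `data_loader_iter` line added earlier) changes semantics, not just style: `zip(data_loader, range(n))` builds a fresh iterator over the loader every epoch, which restarts multi-worker loading and resets a stream-style train loader, and zip's left-to-right pulls also fetch one batch that is thrown away when `range` runs out. One shared iterator keeps its position across epochs. A toy comparison:

```python
data = list(range(10))  # stand-in for batches from a DataLoader

# old pattern: a fresh iterator each epoch, so every epoch restarts
for epoch in range(2):
    print([batch for batch, _ in zip(data, range(3))])  # [0, 1, 2] twice

# new pattern: one shared iterator keeps consuming across epochs
data_iter = iter(data)
for epoch in range(2):
    print([next(data_iter) for _ in range(3)])  # [0, 1, 2] then [3, 4, 5]
```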
@@ -128,9 +138,9 @@ def do_train(cfg, model, resume=False):
                 optimizer.step()
                 storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
 
-                if iteration - start_iter > 5 and (
-                    (iteration + 1) % 200 == 0 or iteration == max_iter - 1
-                ):
+                if iteration - start_iter > 5 and \
+                        ((iteration + 1) % 200 == 0 or iteration == max_iter - 1) and \
+                        ((iteration + 1) % iters_per_epoch != 0):
                     for writer in writers:
                         writer.write()
 
@@ -143,18 +153,22 @@ def do_train(cfg, model, resume=False):
             for writer in writers:
                 writer.write()
 
-            if iteration > warmup_iters and (epoch + 1) >= delay_epochs:
+            if iteration > warmup_iters and (epoch + 1) > delay_epochs:
                 scheduler["lr_sched"].step()
 
             if (
                 cfg.TEST.EVAL_PERIOD > 0
                 and (epoch + 1) % cfg.TEST.EVAL_PERIOD == 0
-                and epoch != max_iter - 1
+                and iteration != max_iter - 1
             ):
-                do_test(cfg, model)
+                results = do_test(cfg, model)
                 # Compared to "train_net.py", the test results are not dumped to EventStorage
+            else:
+                results = {}
+            flatten_results = flatten_results_dict(results)
 
-            periodic_checkpointer.step(epoch)
+            metric_dict = dict(metric=flatten_results[metric_name] if metric_name in flatten_results else -1)
+            periodic_checkpointer.step(epoch, **metric_dict)
 
 
 def setup(args):
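The checkpoint call now threads the latest evaluation metric through to the checkpointer, presumably so the saved checkpoint records it; fvcore's `PeriodicCheckpointer.step(iteration, **kwargs)` forwards extra kwargs to `Checkpointer.save()`, and fastreid vendors a variant of that class, so treat the forwarding behavior as an assumption. The two new lines could also be written more idiomatically with `dict.get`; a sketch, not the committed code:

```python
# equivalent to the metric_dict construction above; -1 marks epochs where
# no evaluation ran, so metric_name is absent from flatten_results
periodic_checkpointer.step(epoch, metric=flatten_results.get(metric_name, -1))
```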
@@ -184,7 +198,9 @@ def main(args):
 
     distributed = comm.get_world_size() > 1
     if distributed:
-        model = DistributedDataParallel(model, delay_allreduce=True)
+        model = DistributedDataParallel(
+            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
+        )
 
     do_train(cfg, model, resume=args.resume)
     return do_test(cfg, model)
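`delay_allreduce` is an argument of NVIDIA Apex's `apex.parallel.DistributedDataParallel`, not of the PyTorch class, so this hunk evidently switches to the native `torch.nn.parallel.DistributedDataParallel`. A minimal sketch of the native setup, assuming the process group is already initialized (as `launch()` arranges for this script):

```python
import torch
from torch.nn.parallel import DistributedDataParallel

local_rank = comm.get_local_rank()      # one process per GPU
model = model.to(torch.device("cuda", local_rank))
model = DistributedDataParallel(
    model,
    device_ids=[local_rank],            # the single device this replica owns
    broadcast_buffers=False,            # skip re-syncing buffers (e.g. BN
)                                       #   running stats) on every forward
```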