
Commit ff8a958

bugfix for plain_train_net.py and lr scheduler step (#484)
Parent: 46b0681

File tree

3 files changed (+39 −22):

- README.md (+2 −1)
- fastreid/engine/hooks.py (+1 −1)
- tools/plain_train_net.py (+36 −20)

README.md (+2 −1)

@@ -24,7 +24,8 @@ Support many tasks beyond reid, such image retrieval and face recognition. See [
 - Can be used as a library to support [different projects](projects) on top of it. We'll open source more research projects in this way.
 - Remove [ignite](https://github.com/pytorch/ignite)(a high-level library) dependency and powered by [PyTorch](https://pytorch.org/).
 
-We write a [chinese blog](https://l1aoxingyu.github.io/blogpages/reid/2020/05/29/fastreid.html) about this toolbox.
+We write a [fastreid intro](https://l1aoxingyu.github.io/blogpages/reid/fastreid/2020/05/29/fastreid.html)
+and [fastreid v1.0](https://l1aoxingyu.github.io/blogpages/reid/fastreid/2021/04/28/fastreid-v1.html) about this toolbox.
 
 ## Changelog

fastreid/engine/hooks.py (+1 −1)

@@ -257,7 +257,7 @@ def after_step(self):
     def after_epoch(self):
         next_iter = self.trainer.iter + 1
         next_epoch = self.trainer.epoch + 1
-        if next_iter > self.trainer.warmup_iters and next_epoch >= self.trainer.delay_epochs:
+        if next_iter > self.trainer.warmup_iters and next_epoch > self.trainer.delay_epochs:
             self._scheduler["lr_sched"].step()
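The `>=` → `>` change fixes an off-by-one in when the delayed LR schedule starts stepping: with `>=`, the scheduler already steps at the end of epoch `delay_epochs - 1` (where `next_epoch == delay_epochs`), cutting the configured delay short by one epoch. The same boundary fix appears in `tools/plain_train_net.py` below. A toy check of the boundary (plain Python, not the fastreid API):

```python
# Only the epoch where next_epoch == delay_epochs behaves differently
# under ">=" vs ">": the old rule steps one epoch too early.
delay_epochs = 30
for epoch in range(32):
    next_epoch = epoch + 1
    if (next_epoch >= delay_epochs) != (next_epoch > delay_epochs):
        print("epoch", epoch, "stepped under the old rule only")  # epoch 29
```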

tools/plain_train_net.py (+36 −20)

@@ -16,6 +16,7 @@
 
 from fastreid.config import get_cfg
 from fastreid.data import build_reid_test_loader, build_reid_train_loader
+from fastreid.evaluation.testing import flatten_results_dict
 from fastreid.engine import default_argument_parser, default_setup, launch
 from fastreid.modeling import build_model
 from fastreid.solver import build_lr_scheduler, build_optimizer
@@ -33,7 +34,7 @@
 
 
 def get_evaluator(cfg, dataset_name, output_dir=None):
-    data_loader, num_query = build_reid_test_loader(cfg, dataset_name)
+    data_loader, num_query = build_reid_test_loader(cfg, dataset_name=dataset_name)
     return data_loader, ReidEvaluator(cfg, num_query, output_dir)
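Passing `dataset_name` by keyword rather than position guards against the builder's signature: in fastreid v1.0 `build_reid_test_loader` gained additional parameters, and a positional second argument can silently bind to the wrong one. A generic illustration of the hazard, with a hypothetical builder (not fastreid's actual signature):

```python
# Hypothetical: a loader builder that grew a new second parameter.
def build_loader(cfg, transforms=None, dataset_name=None):
    return dataset_name

print(build_loader("cfg", "Market1501"))               # None: silently bound to transforms
print(build_loader("cfg", dataset_name="Market1501"))  # "Market1501": bound correctly
```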
@@ -49,24 +50,28 @@ def do_test(cfg, model):
             )
             results[dataset_name] = {}
             continue
-        results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP_ENABLED)
+        results_i = inference_on_dataset(model, data_loader, evaluator, flip_test=cfg.TEST.FLIP.ENABLED)
         results[dataset_name] = results_i
 
-    if comm.is_main_process():
-        assert isinstance(
-            results, dict
-        ), "Evaluator must return a dict on the main process. Got {} instead.".format(
-            results
-        )
-        print_csv_format(results)
+        if comm.is_main_process():
+            assert isinstance(
+                results, dict
+            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
+                results
+            )
+            logger.info("Evaluation results for {} in csv format:".format(dataset_name))
+            results_i['dataset'] = dataset_name
+            print_csv_format(results_i)
 
-    if len(results) == 1: results = list(results.values())[0]
+    if len(results) == 1:
+        results = list(results.values())[0]
 
     return results
 
 
 def do_train(cfg, model, resume=False):
     data_loader = build_reid_train_loader(cfg)
+    data_loader_iter = iter(data_loader)
 
     model.train()
     optimizer = build_optimizer(cfg, model)
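Three changes land in this hunk: the flip-test flag moves to the nested config node `cfg.TEST.FLIP.ENABLED`; results are printed per dataset inside the loop, with a `dataset` entry added so each CSV block is labeled; and `do_train` creates one long-lived iterator over the train loader, consumed by the loop change further down. On the config point, a yacs-style `CfgNode` (which fastreid's config builds on) treats the old flat key and the new nested key as entirely different attributes; a minimal sketch using plain yacs for illustration:

```python
from yacs.config import CfgNode as CN

cfg = CN()
cfg.TEST = CN()
cfg.TEST.FLIP = CN()
cfg.TEST.FLIP.ENABLED = False

print(cfg.TEST.FLIP.ENABLED)  # False: the nested key exists
# cfg.TEST.FLIP_ENABLED       # AttributeError: the flat key does not
```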
@@ -78,7 +83,7 @@ def do_train(cfg, model, resume=False):
         model,
         cfg.OUTPUT_DIR,
         save_to_disk=comm.is_main_process(),
-        optimizer=optimizer
+        optimizer=optimizer,
         **scheduler
     )
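The added trailing comma is a genuine bugfix, not style. Without it, Python fuses the two lines into `optimizer=optimizer ** scheduler`, the power operator, instead of a keyword argument followed by dict unpacking, and the `Checkpointer` call fails at runtime. A self-contained reproduction:

```python
def make_checkpointer(**kwargs):
    return sorted(kwargs)

optimizer = object()
scheduler = {"lr_sched": object()}

print(make_checkpointer(optimizer=optimizer, **scheduler))  # ['lr_sched', 'optimizer']
# make_checkpointer(optimizer=optimizer **scheduler)
# TypeError: unsupported operand type(s) for ** or pow(): 'object' and 'dict'
```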

@@ -93,6 +98,10 @@ def do_train(cfg, model, resume=False):
     delay_epochs = cfg.SOLVER.DELAY_EPOCHS
 
     periodic_checkpointer = PeriodicCheckpointer(checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_epoch)
+    if len(cfg.DATASETS.TESTS) == 1:
+        metric_name = "metric"
+    else:
+        metric_name = cfg.DATASETS.TESTS[0] + "/metric"
 
     writers = (
         [
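The metric key depends on how many test sets are configured because `do_test` unwraps the outer dict when there is exactly one (the `if len(results) == 1` branch above), while with several the per-dataset nesting survives and flattening joins keys with `/`. A sketch of the two shapes, using a stand-in with the same joining behavior as `flatten_results_dict` and illustrative names and numbers:

```python
# Stand-in mirroring flatten_results_dict: nested keys joined with "/".
def flatten_results_dict(results, prefix=""):
    flat = {}
    for k, v in results.items():
        if isinstance(v, dict):
            flat.update(flatten_results_dict(v, prefix + k + "/"))
        else:
            flat[prefix + k] = v
    return flat

print(flatten_results_dict({"metric": 71.3}))
# {'metric': 71.3} -> single test set: looked up as "metric"
print(flatten_results_dict({"Market1501": {"metric": 71.3}}))
# {'Market1501/metric': 71.3} -> several test sets: "<dataset>/metric"
```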
@@ -111,7 +120,8 @@ def do_train(cfg, model, resume=False):
     with EventStorage(start_iter) as storage:
         for epoch in range(start_epoch, max_epoch):
             storage.epoch = epoch
-            for data, _ in zip(data_loader, range(iters_per_epoch)):
+            for _ in range(iters_per_epoch):
+                data = next(data_loader_iter)
                 storage.iter = iteration
 
                 loss_dict = model(data)
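The old `zip(data_loader, range(iters_per_epoch))` built a fresh iterator over the loader at every epoch; with a PyTorch `DataLoader` that respawns worker processes and throws away prefetched batches at each epoch boundary, and with a stream-style sampler it restarts the stream instead of resuming it (the motivation is inferred from the change, not stated in the commit). The new form consumes the single iterator created in `do_train`. The pattern in isolation:

```python
# One long-lived iterator over an effectively infinite loader,
# consumed a fixed number of steps per epoch.
loader_iter = iter(range(10**9))   # stands in for iter(data_loader)
iters_per_epoch = 3
for epoch in range(2):
    for _ in range(iters_per_epoch):
        data = next(loader_iter)   # epoch 1 resumes where epoch 0 stopped
        print(epoch, data)
```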
@@ -128,9 +138,9 @@ def do_train(cfg, model, resume=False):
                 optimizer.step()
                 storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
 
-                if iteration - start_iter > 5 and (
-                    (iteration + 1) % 200 == 0 or iteration == max_iter - 1
-                ):
+                if iteration - start_iter > 5 and \
+                        ((iteration + 1) % 200 == 0 or iteration == max_iter - 1) and \
+                        ((iteration + 1) % iters_per_epoch != 0):
                     for writer in writers:
                         writer.write()
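The added clause skips the every-200-iterations flush on iterations that close an epoch: the epoch-end block in the next hunk flushes the writers anyway, so those iterations would otherwise write twice. A toy check with round numbers chosen so epoch ends coincide with the periodic flush:

```python
iters_per_epoch, max_iter, start_iter = 100, 600, 0
for iteration in range(max_iter):
    old = iteration - start_iter > 5 and ((iteration + 1) % 200 == 0 or iteration == max_iter - 1)
    new = old and (iteration + 1) % iters_per_epoch != 0
    if old and not new:
        print("double write avoided at iteration", iteration)  # 199, 399, 599
```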

@@ -143,18 +153,22 @@ def do_train(cfg, model, resume=False):
             for writer in writers:
                 writer.write()
 
-            if iteration > warmup_iters and (epoch + 1) >= delay_epochs:
+            if iteration > warmup_iters and (epoch + 1) > delay_epochs:
                 scheduler["lr_sched"].step()
 
             if (
                 cfg.TEST.EVAL_PERIOD > 0
                 and (epoch + 1) % cfg.TEST.EVAL_PERIOD == 0
-                and epoch != max_iter - 1
+                and iteration != max_iter - 1
             ):
-                do_test(cfg, model)
+                results = do_test(cfg, model)
                 # Compared to "train_net.py", the test results are not dumped to EventStorage
+            else:
+                results = {}
+            flatten_results = flatten_results_dict(results)
 
-            periodic_checkpointer.step(epoch)
+            metric_dict = dict(metric=flatten_results[metric_name] if metric_name in flatten_results else -1)
+            periodic_checkpointer.step(epoch, **metric_dict)
 
 
 def setup(args):
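Three fixes in one hunk: the delay boundary gets the same `>=` → `>` correction as `hooks.py`; the end-of-training guard now compares `iteration` against `max_iter - 1` (the old `epoch != max_iter - 1` compared an epoch index with an iteration count, so it effectively never suppressed the redundant final evaluation that `main()` performs after training anyway); and the evaluation metric is threaded into `periodic_checkpointer.step(...)`, whose extra keyword arguments are saved with the checkpoint, with -1 standing in on epochs where no evaluation ran. How the lookup behaves, with illustrative values:

```python
flatten_results = {"Market1501/metric": 71.3, "Market1501/Rank-1": 89.2}
metric_name = "Market1501/metric"

metric_dict = dict(metric=flatten_results[metric_name] if metric_name in flatten_results else -1)
print(metric_dict)  # {'metric': 71.3}; with results == {} this yields {'metric': -1}
```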
@@ -184,7 +198,9 @@ def main(args):
 
     distributed = comm.get_world_size() > 1
     if distributed:
-        model = DistributedDataParallel(model, delay_allreduce=True)
+        model = DistributedDataParallel(
+            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
+        )
 
     do_train(cfg, model, resume=args.resume)
     return do_test(cfg, model)
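`delay_allreduce` is a constructor argument of NVIDIA Apex's `apex.parallel.DistributedDataParallel`; the arguments in the replacement (`device_ids`, `broadcast_buffers`) belong to the native `torch.nn.parallel.DistributedDataParallel`, so this completes a migration from the Apex wrapper to native DDP, which rejects the Apex-only keyword with a `TypeError`. A minimal single-process construction that runs on CPU (gloo backend; on GPU you would pass `device_ids=[local_rank]` as the commit does):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Single-process group so the sketch runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

ddp_model = DistributedDataParallel(torch.nn.Linear(8, 8), broadcast_buffers=False)
dist.destroy_process_group()
```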
