
Commit a61515d

Merge pull request #254 from RangiLyu/fix/ddp_eval
[BUG FIX] Fix evaluation bug when using multi-GPU training with pytorch-lightning.
2 parents 35ddab5 + 1386d1e commit a61515d

3 files changed (+83 -34 lines)

nanodet/trainer/task.py (+42 -33)
@@ -18,11 +18,12 @@
 import json
 import torch
 import logging
+
 from pytorch_lightning import LightningModule
-from typing import Any, List, Dict, Tuple, Optional
+from typing import Any, List
+from nanodet.util import mkdir, gather_results
 
 from ..model.arch import build_model
-from nanodet.util import mkdir
 
 
 class TrainingTask(LightningModule):

@@ -109,28 +110,32 @@ def validation_epoch_end(self, validation_step_outputs):
         results = {}
         for res in validation_step_outputs:
             results.update(res)
-        eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, rank=self.local_rank)
-        metric = eval_results[self.cfg.evaluator.save_key]
-        # save best model
-        if metric > self.save_flag:
-            self.save_flag = metric
-            best_save_path = os.path.join(self.cfg.save_dir, 'model_best')
-            mkdir(self.local_rank, best_save_path)
-            self.trainer.save_checkpoint(os.path.join(best_save_path, "model_best.ckpt"))
-            txt_path = os.path.join(best_save_path, "eval_results.txt")
-            if self.local_rank < 1:
-                with open(txt_path, "a") as f:
-                    f.write("Epoch:{}\n".format(self.current_epoch+1))
-                    for k, v in eval_results.items():
-                        f.write("{}: {}\n".format(k, v))
+        all_results = gather_results(results)
+        if all_results:
+            eval_results = self.evaluator.evaluate(all_results, self.cfg.save_dir, rank=self.local_rank)
+            metric = eval_results[self.cfg.evaluator.save_key]
+            # save best model
+            if metric > self.save_flag:
+                self.save_flag = metric
+                best_save_path = os.path.join(self.cfg.save_dir, 'model_best')
+                mkdir(self.local_rank, best_save_path)
+                self.trainer.save_checkpoint(os.path.join(best_save_path, "model_best.ckpt"))
+                txt_path = os.path.join(best_save_path, "eval_results.txt")
+                if self.local_rank < 1:
+                    with open(txt_path, "a") as f:
+                        f.write("Epoch:{}\n".format(self.current_epoch+1))
+                        for k, v in eval_results.items():
+                            f.write("{}: {}\n".format(k, v))
+            else:
+                warnings.warn('Warning! Save_key is not in eval results! Only save model last!')
+            if self.log_style == 'Lightning':
+                for k, v in eval_results.items():
+                    self.log('Val_metrics/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+            elif self.log_style == 'NanoDet':
+                for k, v in eval_results.items():
+                    self.scalar_summary('Val_metrics/' + k, 'Val', v, self.current_epoch+1)
         else:
-            warnings.warn('Warning! Save_key is not in eval results! Only save model last!')
-        if self.log_style == 'Lightning':
-            for k, v in eval_results.items():
-                self.log('Val_metrics/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
-        elif self.log_style == 'NanoDet':
-            for k, v in eval_results.items():
-                self.scalar_summary('Val_metrics/' + k, 'Val', v, self.current_epoch+1)
+            self.info('Skip val on rank {}'.format(self.local_rank))
 
     def test_step(self, batch, batch_idx):
         dets = self.predict(batch, batch_idx)

@@ -140,16 +145,20 @@ def test_epoch_end(self, test_step_outputs):
         results = {}
         for res in test_step_outputs:
             results.update(res)
-        res_json = self.evaluator.results2json(results)
-        json_path = os.path.join(self.cfg.save_dir, 'results.json')
-        json.dump(res_json, open(json_path, 'w'))
-
-        if self.cfg.test_mode == 'val':
-            eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, rank=self.local_rank)
-            txt_path = os.path.join(self.cfg.save_dir, "eval_results.txt")
-            with open(txt_path, "a") as f:
-                for k, v in eval_results.items():
-                    f.write("{}: {}\n".format(k, v))
+        all_results = gather_results(results)
+        if all_results:
+            res_json = self.evaluator.results2json(all_results)
+            json_path = os.path.join(self.cfg.save_dir, 'results.json')
+            json.dump(res_json, open(json_path, 'w'))
+
+            if self.cfg.test_mode == 'val':
+                eval_results = self.evaluator.evaluate(all_results, self.cfg.save_dir, rank=self.local_rank)
+                txt_path = os.path.join(self.cfg.save_dir, "eval_results.txt")
+                with open(txt_path, "a") as f:
+                    for k, v in eval_results.items():
+                        f.write("{}: {}\n".format(k, v))
+        else:
+            self.info('Skip test on rank {}'.format(self.local_rank))
 
     def configure_optimizers(self):
         """

nanodet/util/__init__.py (+1 -0)
@@ -10,3 +10,4 @@
 from .visualization import Visualizer, overlay_bbox_cv
 from .flops_counter import get_model_complexity_info
 from .misc import multi_apply, images_to_levels, unmap
+from .scatter_gather import gather_results, scatter_kwargs
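
With this re-export, downstream code can import the new helper from the package root rather than from the submodule. A quick sanity check, assuming nanodet is installed at this commit:

from nanodet.util import gather_results, scatter_kwargs

print(gather_results.__module__)  # prints: nanodet.util.scatter_gather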

nanodet/util/scatter_gather.py (+40 -1)
@@ -1,5 +1,8 @@
+import pickle
+
 import torch
 from torch.autograd import Variable
+import torch.distributed as dist
 from torch.nn.parallel._functions import Scatter
 
 

@@ -10,6 +13,7 @@ def list_scatter(input, target_gpus, chunk_sizes):
         del input[:size]
     return tuple(ret)
 
+
 def scatter(inputs, target_gpus, dim=0, chunk_sizes=None):
     """
     Slices variables into approximately equal chunks and

@@ -42,4 +46,39 @@ def scatter_kwargs(inputs, kwargs, target_gpus, dim=0, chunk_sizes=None):
         kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
     inputs = tuple(inputs)
     kwargs = tuple(kwargs)
-    return inputs, kwargs
+    return inputs, kwargs
+
+
+def gather_results(result_part):
+    rank = -1
+    world_size = 1
+    if dist.is_available() and dist.is_initialized():
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+
+    # dump result part to tensor with pickle
+    part_tensor = torch.tensor(
+        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+
+    # gather all result part tensor shape
+    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+    shape_list = [shape_tensor.clone() for _ in range(world_size)]
+    dist.all_gather(shape_list, shape_tensor)
+
+    # padding result part tensor to max length
+    shape_max = torch.tensor(shape_list).max()
+    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+    part_send[:shape_tensor[0]] = part_tensor
+    part_recv_list = [
+        part_tensor.new_zeros(shape_max) for _ in range(world_size)
+    ]
+
+    # gather all result dict
+    dist.all_gather(part_recv_list, part_send)
+
+    if rank < 1:
+        all_res = {}
+        for recv, shape in zip(part_recv_list, shape_list):
+            all_res.update(
+                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
+        return all_res
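
gather_results pickles each rank's partial result dict into a uint8 CUDA tensor, all-gathers the tensor lengths, pads every tensor to the longest length so that dist.all_gather sees identically shaped tensors on all ranks, then unpickles and merges the parts on rank 0; any other rank falls off the end of the function and returns None. A hedged usage sketch, assuming the script runs under an already initialized, CUDA-backed process group (for example launched by pytorch-lightning DDP); collect_detections and local_results are placeholder names:

from nanodet.util import gather_results


def collect_detections(local_results):
    # local_results: {image_id: detections} produced from this rank's dataloader shard
    merged = gather_results(local_results)
    if merged is None:
        # ranks other than 0 only contribute their part and do no further work
        return None
    # rank 0 now holds the union of every rank's results
    print('gathered results for {} images'.format(len(merged)))
    return merged

Padding to the longest byte length is what lets ranks with different numbers of detections join a single collective, since all_gather expects same-sized tensors from every process; the per-rank lengths gathered beforehand are used to strip the padding before unpickling.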
