Skip to content

Commit da76f6d

Browse files
committed
Fix concat bug in distributed container gather fns
1 parent 1c9a3d3 commit da76f6d

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

effdet/distributed.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -252,14 +252,14 @@ def reduce_dict(input_dict, average=True):
252252
return reduced_dict
253253

254254

255-
def all_gather_container(container, group=None):
255+
def all_gather_container(container, group=None, cat_dim=0):
256256
group = group or dist.group.WORLD
257257
world_size = dist.get_world_size(group)
258258

259259
def _do_gather(tensor):
260260
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
261261
dist.all_gather(tensor_list, tensor, group=group)
262-
return torch.cat(tensor_list, dim=-1)
262+
return torch.cat(tensor_list, dim=cat_dim)
263263

264264
if isinstance(container, dict):
265265
gathered = dict()
@@ -278,7 +278,7 @@ def _do_gather(tensor):
278278
return _do_gather(container)
279279

280280

281-
def gather_container(container, dst, group=None):
281+
def gather_container(container, dst, group=None, cat_dim=0):
282282
group = group or dist.group.WORLD
283283
world_size = dist.get_world_size(group)
284284
this_rank = dist.get_rank(group)
@@ -289,7 +289,7 @@ def _do_gather(tensor):
289289
else:
290290
tensor_list = None
291291
dist.gather(tensor, tensor_list, dst=dst, group=group)
292-
return torch.cat(tensor_list, dim=-1)
292+
return torch.cat(tensor_list, dim=cat_dim)
293293

294294
if isinstance(container, dict):
295295
gathered = dict()

0 commit comments

Comments
 (0)