@@ -252,14 +252,14 @@ def reduce_dict(input_dict, average=True):
252
252
return reduced_dict
253
253
254
254
255
- def all_gather_container (container , group = None ):
255
+ def all_gather_container (container , group = None , cat_dim = 0 ):
256
256
group = group or dist .group .WORLD
257
257
world_size = dist .get_world_size (group )
258
258
259
259
def _do_gather (tensor ):
260
260
tensor_list = [torch .empty_like (tensor ) for _ in range (world_size )]
261
261
dist .all_gather (tensor_list , tensor , group = group )
262
- return torch .cat (tensor_list , dim = - 1 )
262
+ return torch .cat (tensor_list , dim = cat_dim )
263
263
264
264
if isinstance (container , dict ):
265
265
gathered = dict ()
@@ -278,7 +278,7 @@ def _do_gather(tensor):
278
278
return _do_gather (container )
279
279
280
280
281
- def gather_container (container , dst , group = None ):
281
+ def gather_container (container , dst , group = None , cat_dim = 0 ):
282
282
group = group or dist .group .WORLD
283
283
world_size = dist .get_world_size (group )
284
284
this_rank = dist .get_rank (group )
@@ -289,7 +289,7 @@ def _do_gather(tensor):
289
289
else :
290
290
tensor_list = None
291
291
dist .gather (tensor , tensor_list , dst = dst , group = group )
292
- return torch .cat (tensor_list , dim = - 1 )
292
+ return torch .cat (tensor_list , dim = cat_dim )
293
293
294
294
if isinstance (container , dict ):
295
295
gathered = dict ()
0 commit comments