@@ -251,6 +251,12 @@ class DistributedConfig(Config):
         desc="Ensure the initialization is the same for any distributed configuration.",
         hint=FieldHint.testing,
     )
+    reference_config: "DistributedConfig|None" = Field(
+        default=None,
+        init=False,
+        desc="Pointer to the distributed config this one is an identical copy of.",
+        hint=FieldHint.derived,
+    )
 
     def _validate(self) -> None:
         if self.world_size is None:
@@ -281,76 +287,90 @@ def _validate(self) -> None:
         if self.tensor_parallel == 1:
             self.sequence_tensor_parallel = False
 
-        self.distributed_dims = {}
+        if self.reference_config is not None:
+            self.reference_config.validate()
+            if self.reference_config.reference_config is not None:
+                self.reference_config = self.reference_config.reference_config
+            assert self.reference_config.reference_config is None
+            self.compare(self.reference_config, ValueError)
+            self.distributed_dims = self.reference_config.distributed_dims
+        else:
+            self.distributed_dims = {}
 
-        self.add_distributed_dim(
-            DistributedDim(name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None)
-        )
-        self.add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.data,
-                size=self.data_parallel,
-                rank=self.data_rank,
-                id_=f"x_{self.pipeline_rank}_{self.tensor_rank}",
-                parent=DistributedDimNames.world,
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None
+                )
             )
-        )
-        self.add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.pipeline,
-                size=self.pipeline_parallel,
-                rank=self.pipeline_rank,
-                id_=f"x_{self.data_rank}_{self.tensor_rank}",
-                parent=DistributedDimNames.world,
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.data,
+                    size=self.data_parallel,
+                    rank=self.data_rank,
+                    id_=f"x_{self.pipeline_rank}_{self.tensor_rank}",
+                    parent=DistributedDimNames.world,
+                )
             )
-        )
-        self.add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.tensor,
-                size=self.tensor_parallel,
-                rank=self.tensor_rank,
-                id_=f"x_{self.data_rank}_{self.pipeline_rank}",
-                parent=DistributedDimNames.world,
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.pipeline,
+                    size=self.pipeline_parallel,
+                    rank=self.pipeline_rank,
+                    id_=f"x_{self.data_rank}_{self.tensor_rank}",
+                    parent=DistributedDimNames.world,
+                )
             )
-        )
-        self.add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.sequence_data,
-                size=self.sequence_data_parallel,
-                rank=self.sequence_data_rank,
-                id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
-                parent=DistributedDimNames.data,
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.tensor,
+                    size=self.tensor_parallel,
+                    rank=self.tensor_rank,
+                    id_=f"x_{self.data_rank}_{self.pipeline_rank}",
+                    parent=DistributedDimNames.world,
+                )
            )
-        )
-        self.add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.batch_data,
-                size=self.batch_data_parallel,
-                rank=self.batch_data_rank,
-                id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
-                parent=DistributedDimNames.data,
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.sequence_data,
+                    size=self.sequence_data_parallel,
+                    rank=self.sequence_data_rank,
+                    id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
+                    parent=DistributedDimNames.data,
+                )
             )
-        )
-        self.add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.tensor_and_sequence_data,
-                size=self.sequence_data_parallel * self.tensor_parallel,
-                rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
-                id_=f"{self.batch_data_rank}_{self.pipeline_rank}",
-                parent=(
-                    DistributedDimNames.tensor
-                    if self.sequence_data_parallel == 1
-                    else DistributedDimNames.sequence_data if self.tensor_parallel == 1 else DistributedDimNames.world
-                ),
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.batch_data,
+                    size=self.batch_data_parallel,
+                    rank=self.batch_data_rank,
+                    id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
+                    parent=DistributedDimNames.data,
+                )
+            )
+            self._add_distributed_dim(
+                DistributedDim(
+                    name=DistributedDimNames.tensor_and_sequence_data,
+                    size=self.sequence_data_parallel * self.tensor_parallel,
+                    rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
+                    id_=f"{self.batch_data_rank}_{self.pipeline_rank}",
+                    parent=(
+                        DistributedDimNames.tensor
+                        if self.sequence_data_parallel == 1
+                        else (
+                            DistributedDimNames.sequence_data
+                            if self.tensor_parallel == 1
+                            else DistributedDimNames.world
+                        )
+                    ),
+                )
             )
-        )
 
         super()._validate()
 
         Assert.in_range(self.rank, 0, self.world_size)
         Assert.in_range(self.local_rank, 0, self.local_world_size)
 
-    def add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
+    def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
         if distributed_dim.name in self.distributed_dims:
             Assert.eq(distributed_dim, self.distributed_dims[distributed_dim.name])
         else:
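For context, a minimal sketch of how the new `reference_config` field could be used, based only on the validation logic in the diff above; the keyword-argument construction and the pre-validation attribute assignment are assumptions for illustration, not part of this change:

# Hypothetical usage sketch (not part of the diff). Assumes DistributedConfig
# accepts its parallelism fields as keyword arguments and that attributes can
# be assigned before validation.
base = DistributedConfig(world_size=8, rank=0, tensor_parallel=2, pipeline_parallel=2)
base.validate()  # builds base.distributed_dims once

replica = DistributedConfig(world_size=8, rank=0, tensor_parallel=2, pipeline_parallel=2)
replica.reference_config = base  # mark replica as an identical copy of `base`
replica.validate()
# _validate() dereferences chained references down to the root config, checks that
# the two configs match via self.compare(..., ValueError), and reuses the
# reference's already-built distributed_dims instead of rebuilding them:
assert replica.reference_config is base
assert replica.distributed_dims is base.distributed_dims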