This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit 0b9afe8

Fully Sharded Data Parallel (#3740)
* Implement zero2 and zero3
* Implement overflow syncing.
* Tweak log statements.
* Use free ports rather than random ports
* Refactor test_distributed
* More refactor.
* Fixup checkpoints.
* Get tests working.
* GPU only
* Sigh
* Moar.
* Trying to sync grad norms
* Correctly implement gnorm syncing.
* Update comment.
* Try zero3.
* Okay got zero3 working.
* Refactor.
* Get FSDP Zero3 working, except during validation.
* Check in missing code. Carve out notimplemented.
* Lint.
* Er.
* Add a test to ensure we keep track of zero3 not working.
* Remove debugs, add docstrings, rename variable.
* Silly
* Reviewer comments.
* Lint.
* We disabled zero3 as an option, so don't need the test.
* Bug caught by Kurt.
* Rofl
1 parent 7400795 commit 0b9afe8
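
The wrapping helpers this diff imports from parlai/utils/fsdp.py (fsdp_wrap, maybe_fsdp_wrap, delay_halving, should_sync_gradnorm, is_fsdp, DEFAULT_DDP_BACKEND) are not shown in this excerpt. As background, here is a minimal, hedged sketch of how such helpers are typically built on FairScale's FullyShardedDataParallel; the names match the diff, but the bodies are assumptions and the real ParlAI module may differ (e.g. in how it picks the process group or compute dtype).

# Hedged sketch only: assumes FairScale is installed and that 'zero2' maps to
# reshard_after_forward=False (ZeRO-2 style sharding) while 'zero3' would use True.
import contextlib

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap

DEFAULT_DDP_BACKEND = 'ddp'


def should_use_fsdp(opt) -> bool:
    return opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3')


@contextlib.contextmanager
def maybe_fsdp_wrap(opt):
    if not should_use_fsdp(opt):
        # plain DDP: build the model as usual, nothing to set up here
        yield
        return
    fsdp_args = dict(
        reshard_after_forward=(opt['ddp_backend'] == 'zero3'),
        mixed_precision=opt.get('fp16', False),
    )
    with enable_wrap(wrapper_cls=FSDP, **fsdp_args):
        yield


def fsdp_wrap(module: nn.Module) -> nn.Module:
    # fairscale's wrap() is a passthrough outside an enable_wrap() context, so
    # it is safe to call unconditionally, e.g. on every transformer layer.
    return wrap(module)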

File tree

12 files changed (+358, -158 lines)

parlai/agents/transformer/modules/decoder.py

Lines changed: 9 additions & 9 deletions
@@ -25,6 +25,7 @@
 from parlai.core.opt import Opt
 from parlai.utils.misc import warn_once
 from parlai.utils.torch import PipelineHelper
+from parlai.utils.fsdp import fsdp_wrap


 @swappable(
@@ -277,16 +278,15 @@ def _default(val, default):
     def build_layers(self) -> nn.ModuleList:
         layers = nn.ModuleList()
         for _ in range(self.n_layers):
-            layers.append(
-                self.swappables.layer(
-                    self.opt,
-                    attention_dropout=self.opt.get('attention_dropout', 0.0),
-                    relu_dropout=self.opt.get('relu_dropout', 0.0),
-                    dropout=self.opt.get('dropout', 0.0),
-                    activation=self.activation,
-                    variant=self.variant,
-                )  # type: ignore
+            layer = self.swappables.layer(
+                self.opt,
+                attention_dropout=self.opt.get('attention_dropout', 0.0),
+                relu_dropout=self.opt.get('relu_dropout', 0.0),
+                dropout=self.opt.get('dropout', 0.0),
+                activation=self.activation,
+                variant=self.variant,
             )
+            layers.append(fsdp_wrap(layer))  # type: ignore
         return layers

     def forward_embedding(

parlai/agents/transformer/modules/encoder.py

Lines changed: 9 additions & 9 deletions
@@ -25,6 +25,7 @@
 from parlai.core.opt import Opt
 from parlai.utils.misc import warn_once
 from parlai.utils.torch import PipelineHelper
+from parlai.utils.fsdp import fsdp_wrap


 @swappable(self_attention=MultiHeadAttention, feedforward=TransformerFFN)
@@ -227,16 +228,15 @@ def _default(val, default):
     def build_layers(self) -> nn.ModuleList:
         layers = nn.ModuleList()
         for _ in range(self.n_layers):
-            layers.append(
-                self.swappables.layer(  # type: ignore
-                    self.opt,
-                    attention_dropout=self.opt.get('attention_dropout', 0.0),
-                    relu_dropout=self.opt.get('relu_dropout', 0.0),
-                    dropout=self.dropout_frac,
-                    variant=self.variant,
-                    activation=self.activation,
-                )
+            layer = self.swappables.layer(  # type: ignore
+                self.opt,
+                attention_dropout=self.opt.get('attention_dropout', 0.0),
+                relu_dropout=self.opt.get('relu_dropout', 0.0),
+                dropout=self.dropout_frac,
+                variant=self.variant,
+                activation=self.activation,
             )
+            layers.append(fsdp_wrap(layer))
         return layers

     def forward_embedding(
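
In both the decoder and the encoder, each transformer layer is now wrapped individually via fsdp_wrap() rather than only wrapping the full model: with sharded data parallel, each worker then needs the full parameters of just one layer at a time during forward/backward, which is where most of the memory savings come from. Below is a hypothetical, self-contained illustration of the same pattern; the stand-in layer class and sizes are invented, and fsdp_wrap is assumed to pass modules through unchanged when FSDP is not active.

import torch.nn as nn
from parlai.utils.fsdp import fsdp_wrap  # assumed no-op under --ddp-backend ddp


def build_layers(n_layers: int, dim: int = 512) -> nn.ModuleList:
    layers = nn.ModuleList()
    for _ in range(n_layers):
        # stand-in for self.swappables.layer(...) in the real code
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=8)
        # wrapping here shards each layer separately; wrapping only the outer
        # model would gather the whole parameter set at once
        layers.append(fsdp_wrap(layer))
    return layers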

parlai/core/params.py

Lines changed: 10 additions & 0 deletions
@@ -772,6 +772,16 @@ def add_distributed_training_args(self):
         grp.add_argument(
             '--distributed-world-size', type=int, help='Number of workers.'
         )
+        grp.add_argument(
+            '--ddp-backend',
+            # TODO: add in zero3. https://github.com/facebookresearch/ParlAI/issues/3753
+            choices=['ddp', 'zero2'],
+            default='ddp',
+            help=(
+                'Distributed backend. Zero2 can be faster but is more experimental. '
+                'DDP is the most tested.'
+            ),
+        )
         return grp

     def add_model_args(self):
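
On the command line this surfaces as, e.g., --ddp-backend zero2 on the distributed training scripts. A small sketch of exercising the flag as registered here; add_distributed_training_args() is normally called for you by the multiprocessing/distributed scripts, and the task-free parse below is only for illustration.

from parlai.core.params import ParlaiParser

parser = ParlaiParser(add_parlai_args=True, add_model_args=False)
parser.add_distributed_training_args()
opt = parser.parse_args(['--ddp-backend', 'zero2'])
assert opt['ddp_backend'] == 'zero2'  # choices: 'ddp' (the default) or 'zero2'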

parlai/core/torch_agent.py

Lines changed: 20 additions & 4 deletions
@@ -36,6 +36,7 @@
 from parlai.utils.distributed import is_distributed
 from parlai.utils.misc import AttrDict, warn_once
 from parlai.utils.io import PathManager
+from parlai.utils.fsdp import should_sync_gradnorm, is_fsdp, DEFAULT_DDP_BACKEND
 from parlai.utils.fp16 import (
     SafeFP16Optimizer,
     MemoryEfficientFP16Optimizer,
@@ -1052,7 +1053,9 @@ def init_optim(
         self.optimizer = optim_class(params, **kwargs)
         if self.fp16:
             if self.fp16_impl == 'safe':
-                self.optimizer = SafeFP16Optimizer(self.optimizer)
+                self.optimizer = SafeFP16Optimizer(
+                    self.optimizer, should_sync_gradnorm(opt)
+                )
             else:
                 # Using memory efficient optimizer
                 opt_name = opt['optimizer']
@@ -1064,7 +1067,9 @@ def init_optim(
                         'with Memory Efficient FP16. Please select from among this '
                         f'list:\n{compatible_list}'
                     )
-                self.optimizer = MemoryEfficientFP16Optimizer(self.optimizer)
+                self.optimizer = MemoryEfficientFP16Optimizer(
+                    self.optimizer, should_sync_gradnorm(opt)
+                )

         if is_finetune:
             logging.warning('Detected a fine-tune run. Resetting the optimizer.')
@@ -1969,10 +1974,11 @@ def state_dict(self):
         """
         states = {}
         if hasattr(self, 'model'):  # save model params
-            if hasattr(self.model, 'module'):
-                # did we wrap in a DistributedDataParallel
+            if hasattr(self.model, 'module') and not is_fsdp(self.model):
+                # did we wrap in a DistributedDataParallel or DataParallel
                 states['model'] = self.model.module.state_dict()
             else:
+                # regular model or FSDP
                 states['model'] = self.model.state_dict()

         if hasattr(self, 'optimizer'):
@@ -1992,6 +1998,16 @@ def state_dict(self):

         return states

+    def save_nonprimary(self, path=None):
+        """
+        Save model parameters, when you are working on the non-primary worker.
+
+        For models or optimizers that shard parameters, this ensures we sync.
+        """
+        if self.opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3'):
+            # make sure we call the state dict
+            self.state_dict()
+
     def save(self, path=None):
         """
         Save model parameters to path (or default to model_file arg).
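
The second argument to the FP16 optimizers is new: should_sync_gradnorm(opt) tells them that gradients are sharded across workers under zero2, so gradient clipping must use a global norm rather than a per-worker one. save_nonprimary() exists for the same reason on the checkpointing side: non-primary workers still have to participate in state_dict() so sharded parameters can be gathered. A hedged sketch of what the norm synchronization amounts to in general (this is the generic recipe, not the exact ParlAI/fairscale code path):

import torch
import torch.distributed as dist


def sharded_grad_norm(params) -> float:
    params = [p for p in params if p.grad is not None]
    if not params:
        return 0.0
    device = params[0].grad.device
    # each rank sums the squared entries of the gradient shards it owns
    local_sq = torch.zeros(1, device=device)
    for p in params:
        local_sq += p.grad.detach().float().norm(2) ** 2
    if dist.is_available() and dist.is_initialized():
        # all-reduce so every rank sees the global sum of squares
        dist.all_reduce(local_sq)
    return local_sq.sqrt().item()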

parlai/core/torch_generator_agent.py

Lines changed: 13 additions & 7 deletions
@@ -35,6 +35,7 @@
 import parlai.utils.logging as logging
 from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric
 from parlai.utils.fp16 import FP16SafeCrossEntropy
+import parlai.utils.fsdp as fsdp_utils
 from parlai.utils.torch import (
     neginf,
     total_parameters,
@@ -479,8 +480,10 @@ def __init__(self, opt: Opt, shared=None):
         else:
             # this is not a shared instance of this class, so do full init
             self.criterion = self.build_criterion()
-            # ensure all distributed copies will always be in sync
-            self.model = self.build_model()
+            with fsdp_utils.maybe_fsdp_wrap(opt):
+                self.model = fsdp_utils.fsdp_wrap(self.build_model())
+            if self.fp16 and not fsdp_utils.delay_halving(opt):
+                self.model = self.model.half()

             # load the block_list for beam search
             self.beam_block_list = self._load_beam_block_list()
@@ -498,16 +501,15 @@ def __init__(self, opt: Opt, shared=None):
                 self.model.cuda()
                 self.criterion.cuda()

-            sync_parameters(self.model)
+            if not fsdp_utils.is_fsdp(self.model):
+                sync_parameters(self.model)
+
             train_params = trainable_parameters(self.model)
             total_params = total_parameters(self.model)
             logging.info(
                 f"Total parameters: {total_params:,d} ({train_params:,d} trainable)"
             )

-            if self.fp16:
-                self.model = self.model.half()
-
             if init_model is not None:
                 # load model parameters if available
                 logging.info(f'Loading existing model params from {init_model}')
@@ -530,7 +532,11 @@ def __init__(self, opt: Opt, shared=None):
                 logging.warning("Optimizer was reset. Also resetting LR scheduler.")
             self.build_lr_scheduler(states, hard_reset=is_finetune or was_reset)

-        if shared is None and is_distributed():
+        if (
+            shared is None
+            and is_distributed()
+            and opt.get('ddp_backend', fsdp_utils.DEFAULT_DDP_BACKEND) == 'ddp'
+        ):
             device_ids = None if self.model_parallel else [self.opt['gpu']]
             self.model = torch.nn.parallel.DistributedDataParallel(
                 self.model, device_ids=device_ids, broadcast_buffers=False
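
Three behavioural changes are packed into this file: sync_parameters() is skipped for FSDP models, presumably because the wrapper already owns parameter distribution; model.half() is skipped when delay_halving(opt) is true, since FSDP's own mixed-precision handling expects to receive fp32 parameters; and the DistributedDataParallel wrapper is only applied when --ddp-backend is 'ddp', since FSDP already performs gradient reduction itself. A minimal sketch of what delay_halving might check (the real helper lives in parlai/utils/fsdp.py, which is not shown in this excerpt):

def delay_halving(opt) -> bool:
    # assumption: under zero2/zero3 with --fp16, FSDP casts parameters itself
    # (mixed_precision=True), so the agent must not call model.half() up front
    return bool(opt.get('fp16')) and opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3')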

parlai/scripts/multiprocessing_eval.py

Lines changed: 1 addition & 2 deletions
@@ -23,7 +23,6 @@
 """

 import torch
-import random
 import os
 import signal
 import parlai.utils.distributed as distributed_utils
@@ -88,7 +87,7 @@ def setup_args(cls):
         return setup_args()

     def run(self):
-        port = random.randint(32000, 48000)
+        port = distributed_utils.find_free_port()
         return launch_and_eval(self.opt, port)


parlai/scripts/multiprocessing_train.py

Lines changed: 4 additions & 3 deletions
@@ -24,7 +24,6 @@
 """

 import torch
-import random
 import os
 import signal
 import traceback
@@ -55,10 +54,12 @@ def multiprocess_train(
         raise


-def launch_and_train(opt, port):
+def launch_and_train(opt, port=None):
     """
     Perform a fork() to many processes.
     """
+    if port is None:
+        port = distributed_utils.find_free_port()
     # Launch multiple subprocesses
     spawncontext = torch.multiprocessing.start_processes(
         multiprocess_train,
@@ -99,7 +100,7 @@ def setup_args(cls):

     def run(self):
         if self.opt['port'] is None:
-            port = random.randint(32000, 48000)
+            port = None
         else:
             port = self.opt['port']
         return launch_and_train(self.opt, port)
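
With port=None, launch_and_train now picks its own rendezvous port. A hypothetical programmatic use is sketched below; it assumes the module-level setup_args() helper builds the usual ParlAI training parser, and the task, model, and model file are placeholders chosen only for illustration.

from parlai.scripts.multiprocessing_train import launch_and_train, setup_args

parser = setup_args()
opt = parser.parse_args(
    ['--task', 'convai2', '--model', 'transformer/generator',
     '--ddp-backend', 'zero2', '--fp16', 'true', '--model-file', '/tmp/model']
)
launch_and_train(opt, port=None)  # None -> distributed_utils.find_free_port()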

parlai/scripts/train_model.py

Lines changed: 17 additions & 20 deletions
@@ -442,17 +442,20 @@ def save_model(self, suffix=None):
         """
         Save the model to disk, possibly with a suffix.
         """
-        if not is_primary_worker():
-            # never do IO as a non-primary worker
-            return
-
         if not self.opt.get('model_file'):
             # nothing to save to, just exit
             return

         fn = self.opt['model_file']
         if suffix:
             fn += suffix
+
+        if not is_primary_worker():
+            # never do IO as a non-primary worker
+            if hasattr(self.agent, 'save_nonprimary'):
+                self.agent.save_nonprimary(fn)
+            return
+
         while True:
             # don't ever let a ctrl-c interrupt saving
             try:
@@ -543,7 +546,7 @@ def validate(self):
             )
             self.best_valid = new_valid
             self.impatience = 0
-            if opt.get('model_file') and is_primary_worker():
+            if opt.get('model_file'):
                 logging.info(f"saving best valid model: {opt['model_file']}")
                 self.save_model()
                 self.saved = True
@@ -566,11 +569,7 @@ def validate(self):
         self.validate_time.reset()

         # saving
-        if (
-            opt.get('model_file')
-            and opt.get('save_after_valid')
-            and is_primary_worker()
-        ):
+        if opt.get('model_file') and opt.get('save_after_valid'):
             logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
             self.save_model('.checkpoint')

@@ -720,24 +719,26 @@ def _get_time(self, world: World) -> Tuple[float, float, float]:
             self._total_epochs = self._preempted_epochs + sum(
                 all_gather_list(world.get_total_epochs())
             )
-            train_time, log_time, validate_time = sync_object(
+            train_time, log_time, validate_time, save_time = sync_object(
                 (
                     self.train_time.time(),
                     self.log_time.time(),
                     self.validate_time.time(),
+                    self.save_time.time(),
                 )
             )
         else:
-            train_time, log_time, validate_time = (
+            train_time, log_time, validate_time, save_time = (
                 self.train_time.time(),
                 self.log_time.time(),
                 self.validate_time.time(),
+                self.save_time.time(),
             )
             self._total_epochs = self._preempted_epochs + (
                 num_workers() * world.get_total_epochs()
             )

-        return train_time, log_time, validate_time
+        return train_time, log_time, validate_time, save_time

     def log(self):
         """
@@ -810,7 +811,7 @@ def train_steps(self):
             self._last_log_steps += 1 / self.update_freq

             # the following additionally updates self._total_epochs
-            train_time, log_time, validate_time = self._get_time(world)
+            train_time, log_time, validate_time, save_time = self._get_time(world)
             # get the total training examples done, compute epochs
             exs_per_epoch = world.num_examples()
             self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))
@@ -859,11 +860,7 @@ def train_steps(self):
                     break
                 # make sure metrics are clean before we log
                 world.reset_metrics()
-                if (
-                    self.save_time.time() > self.save_every_n_secs
-                    and opt.get('model_file')
-                    and is_primary_worker()
-                ):
+                if save_time > self.save_every_n_secs and opt.get('model_file'):
                     logging.info(
                         f"saving model checkpoint: {opt['model_file']}.checkpoint"
                     )
@@ -872,7 +869,7 @@ def train_steps(self):
                     self.save_model('.checkpoint')
                     self.save_time.reset()

-        if not self.saved and is_primary_worker():
+        if not sync_object(self.saved):
             # save agent
             self.save_model()
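
A single idea runs through all of these train-loop changes: decisions about when to save are now made identically on every worker (hence syncing save_time through _get_time and wrapping self.saved in sync_object), and every worker enters save_model(), because gathering a sharded state dict is a collective operation and would hang if only rank 0 called it. The actual disk IO still happens only on the primary worker. Condensed, the pattern looks roughly like this; the names follow the diff, but the function is a sketch rather than the real implementation.

def save_model_sketch(agent, opt, path, is_primary: bool):
    if not opt.get('model_file'):
        return
    if not is_primary:
        # participate in any collective ops needed to materialize sharded
        # state, but never write to disk from a non-primary worker
        if hasattr(agent, 'save_nonprimary'):
            agent.save_nonprimary(path)
        return
    agent.save(path)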

parlai/utils/distributed.py

Lines changed: 25 additions & 0 deletions
@@ -296,6 +296,19 @@ def distributed_context(
         dist.destroy_process_group()


+def get_dist_group():
+    """
+    Find the default pytorch distributed group.
+
+    Used within FSDP to mark which workers are participating. Important to manually call
+    this because FSDP will cache old groups, but our test suite will instantiate new
+    groups per test.
+    """
+    from torch.distributed.distributed_c10d import _get_default_group
+
+    return _get_default_group()
+
+
 @contextlib.contextmanager
 def slurm_distributed_context(opt):
     """
@@ -346,3 +359,15 @@ def slurm_distributed_context(opt):
     except FileNotFoundError:
         # Slurm is not installed
         raise RuntimeError('SLURM does not appear to be installed.')
+
+
+def find_free_port() -> int:
+    """
+    Find a free port we can bind to locally.
+
+    Credit: https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
+    """
+    with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+        s.bind(('', 0))
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        return s.getsockname()[1]
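
find_free_port() replaces the old random.randint(32000, 48000) guess in the multiprocessing scripts, which could collide with a port already in use. Example use; the tcp:// init string below is illustrative, as the real rendezvous setup lives in distributed_context.

from parlai.utils.distributed import find_free_port

port = find_free_port()
init_method = f'tcp://localhost:{port}'  # e.g. feed this to torch.distributed
print(port, init_method)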
