
Commit 0b9afe8

Authored Jul 1, 2021
Fully Sharded Data Parallel (#3740)
* Implement zero2 and zero3
* Implement overflow syncing.
* Tweak log statements.
* Use free ports rather than random ports
* Refactor test_distributed
* More refactor.
* Fixup checkpoints.
* Get tests working.
* GPU only
* Sigh
* Moar.
* Trying to sync grad norms
* Correctly implement gnorm syncing.
* Update comment.
* Try zero3.
* Okay got zero3 working.
* Refactor.
* Get FSDP Zero3 working, except during validation.
* Check in missing code. Carve out notimplemented.
* Lint.
* Er.
* Add a test to ensure we keep track of zero3 not working.
* Remove debugs, add docstrings, rename variable.
* Silly
* Reviewer comments.
* Lint.
* We disabled zero3 as an option, so don't need the test.
* Bug caught by Kurt.
* Rofl

1 parent 7400795 commit 0b9afe8

12 files changed: +358, -158 lines
 

‎parlai/agents/transformer/modules/decoder.py

+9, -9

@@ -25,6 +25,7 @@
 from parlai.core.opt import Opt
 from parlai.utils.misc import warn_once
 from parlai.utils.torch import PipelineHelper
+from parlai.utils.fsdp import fsdp_wrap


 @swappable(
@@ -277,16 +278,15 @@ def _default(val, default):
     def build_layers(self) -> nn.ModuleList:
         layers = nn.ModuleList()
         for _ in range(self.n_layers):
-            layers.append(
-                self.swappables.layer(
-                    self.opt,
-                    attention_dropout=self.opt.get('attention_dropout', 0.0),
-                    relu_dropout=self.opt.get('relu_dropout', 0.0),
-                    dropout=self.opt.get('dropout', 0.0),
-                    activation=self.activation,
-                    variant=self.variant,
-                )  # type: ignore
+            layer = self.swappables.layer(
+                self.opt,
+                attention_dropout=self.opt.get('attention_dropout', 0.0),
+                relu_dropout=self.opt.get('relu_dropout', 0.0),
+                dropout=self.opt.get('dropout', 0.0),
+                activation=self.activation,
+                variant=self.variant,
             )
+            layers.append(fsdp_wrap(layer))  # type: ignore
         return layers

     def forward_embedding(
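
A note on wrapping at the layer level: fairscale's wrap() only takes effect inside an enable_wrap context, which maybe_fsdp_wrap in parlai/utils/fsdp.py (added later in this diff) sets up, so in a plain DDP or single-GPU run the call above is a no-op and build_layers behaves exactly as before. A small sketch of that behavior, assuming fairscale is installed and no wrap context is active:

    import torch.nn as nn
    from parlai.utils.fsdp import fsdp_wrap

    layer = nn.Linear(8, 8)
    wrapped = fsdp_wrap(layer)
    # outside maybe_fsdp_wrap(opt), wrap() returns the module unchanged, so
    # `wrapped is layer`; inside it (with ddp_backend='zero2' and distributed
    # training initialized), the same call returns a FullyShardedDataParallel
    # wrapper around the layer.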

‎parlai/agents/transformer/modules/encoder.py

+9, -9

@@ -25,6 +25,7 @@
 from parlai.core.opt import Opt
 from parlai.utils.misc import warn_once
 from parlai.utils.torch import PipelineHelper
+from parlai.utils.fsdp import fsdp_wrap


 @swappable(self_attention=MultiHeadAttention, feedforward=TransformerFFN)
@@ -227,16 +228,15 @@ def _default(val, default):
     def build_layers(self) -> nn.ModuleList:
         layers = nn.ModuleList()
         for _ in range(self.n_layers):
-            layers.append(
-                self.swappables.layer(  # type: ignore
-                    self.opt,
-                    attention_dropout=self.opt.get('attention_dropout', 0.0),
-                    relu_dropout=self.opt.get('relu_dropout', 0.0),
-                    dropout=self.dropout_frac,
-                    variant=self.variant,
-                    activation=self.activation,
-                )
+            layer = self.swappables.layer(  # type: ignore
+                self.opt,
+                attention_dropout=self.opt.get('attention_dropout', 0.0),
+                relu_dropout=self.opt.get('relu_dropout', 0.0),
+                dropout=self.dropout_frac,
+                variant=self.variant,
+                activation=self.activation,
             )
+            layers.append(fsdp_wrap(layer))
         return layers

     def forward_embedding(

‎parlai/core/params.py

+10

@@ -772,6 +772,16 @@ def add_distributed_training_args(self):
         grp.add_argument(
             '--distributed-world-size', type=int, help='Number of workers.'
         )
+        grp.add_argument(
+            '--ddp-backend',
+            # TODO: add in zero3. https://github.com/facebookresearch/ParlAI/issues/3753
+            choices=['ddp', 'zero2'],
+            default='ddp',
+            help=(
+                'Distributed backend. Zero2 can be faster but is more experimental. '
+                'DDP is the most tested.'
+            ),
+        )
         return grp

     def add_model_args(self):
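
For context, the test harness later in this diff drives the new flag through ParlaiParser.parse_kwargs and mp_train.launch_and_train. A minimal sketch of launching multiprocess training with the zero2 backend along those lines (the task, model, and file paths are illustrative placeholders, not part of the commit):

    import parlai.scripts.multiprocessing_train as mp_train

    parser = mp_train.setup_args()
    popt = parser.parse_kwargs(
        task='integration_tests:overfit',   # placeholder task
        model='transformer/generator',
        model_file='/tmp/fsdp_model',       # placeholder paths
        dict_file='/tmp/fsdp_model.dict',
        ddp_backend='zero2',                # the new option added above
        fp16=True,
        fp16_impl='safe',
        batchsize=4,
    )
    # with the multiprocessing_train.py change below, omitting the port makes
    # launch_and_train() pick a free one via find_free_port()
    valid_report, test_report = mp_train.launch_and_train(popt)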

‎parlai/core/torch_agent.py

+20, -4

@@ -36,6 +36,7 @@
 from parlai.utils.distributed import is_distributed
 from parlai.utils.misc import AttrDict, warn_once
 from parlai.utils.io import PathManager
+from parlai.utils.fsdp import should_sync_gradnorm, is_fsdp, DEFAULT_DDP_BACKEND
 from parlai.utils.fp16 import (
     SafeFP16Optimizer,
     MemoryEfficientFP16Optimizer,
@@ -1052,7 +1053,9 @@ def init_optim(
         self.optimizer = optim_class(params, **kwargs)
         if self.fp16:
             if self.fp16_impl == 'safe':
-                self.optimizer = SafeFP16Optimizer(self.optimizer)
+                self.optimizer = SafeFP16Optimizer(
+                    self.optimizer, should_sync_gradnorm(opt)
+                )
             else:
                 # Using memory efficient optimizer
                 opt_name = opt['optimizer']
@@ -1064,7 +1067,9 @@ def init_optim(
                         'with Memory Efficient FP16. Please select from among this '
                         f'list:\n{compatible_list}'
                     )
-                self.optimizer = MemoryEfficientFP16Optimizer(self.optimizer)
+                self.optimizer = MemoryEfficientFP16Optimizer(
+                    self.optimizer, should_sync_gradnorm(opt)
+                )

         if is_finetune:
             logging.warning('Detected a fine-tune run. Resetting the optimizer.')
@@ -1969,10 +1974,11 @@ def state_dict(self):
         """
         states = {}
         if hasattr(self, 'model'):  # save model params
-            if hasattr(self.model, 'module'):
-                # did we wrap in a DistributedDataParallel
+            if hasattr(self.model, 'module') and not is_fsdp(self.model):
+                # did we wrap in a DistributedDataParallel or DataParallel
                 states['model'] = self.model.module.state_dict()
             else:
+                # regular model or FSDP
                 states['model'] = self.model.state_dict()

         if hasattr(self, 'optimizer'):
@@ -1992,6 +1998,16 @@ def state_dict(self):

         return states

+    def save_nonprimary(self, path=None):
+        """
+        Save model parameters, when you are working on the non-primary worker.
+
+        For models or optimizers that shard parameters, this ensures we sync.
+        """
+        if self.opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3'):
+            # make sure we call the state dict
+            self.state_dict()
+
     def save(self, path=None):
         """
         Save model parameters to path (or default to model_file arg).
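
How save_nonprimary fits into the overall save path: with sharded parameters, state_dict() has to gather the full model across workers, so every rank must reach it even though only the primary writes a file. A schematic of the intended calling pattern, condensed from the save_model change in parlai/scripts/train_model.py later in this diff (agent and path stand in for the trainer's own):

    # all ranks execute this inside the training loop
    if not is_primary_worker():
        # non-primary ranks do no IO, but still join the parameter gather;
        # under zero2/zero3, save_nonprimary() simply calls agent.state_dict()
        agent.save_nonprimary(path)
    else:
        agent.save(path)  # the primary rank gathers and writes the checkpoint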

‎parlai/core/torch_generator_agent.py

+13, -7

@@ -35,6 +35,7 @@
 import parlai.utils.logging as logging
 from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric
 from parlai.utils.fp16 import FP16SafeCrossEntropy
+import parlai.utils.fsdp as fsdp_utils
 from parlai.utils.torch import (
     neginf,
     total_parameters,
@@ -479,8 +480,10 @@ def __init__(self, opt: Opt, shared=None):
         else:
             # this is not a shared instance of this class, so do full init
             self.criterion = self.build_criterion()
-            # ensure all distributed copies will always be in sync
-            self.model = self.build_model()
+            with fsdp_utils.maybe_fsdp_wrap(opt):
+                self.model = fsdp_utils.fsdp_wrap(self.build_model())
+                if self.fp16 and not fsdp_utils.delay_halving(opt):
+                    self.model = self.model.half()

             # load the block_list for beam search
             self.beam_block_list = self._load_beam_block_list()
@@ -498,16 +501,15 @@ def __init__(self, opt: Opt, shared=None):
                    self.model.cuda()
                self.criterion.cuda()

-            sync_parameters(self.model)
+            if not fsdp_utils.is_fsdp(self.model):
+                sync_parameters(self.model)
+
             train_params = trainable_parameters(self.model)
             total_params = total_parameters(self.model)
             logging.info(
                 f"Total parameters: {total_params:,d} ({train_params:,d} trainable)"
             )

-            if self.fp16:
-                self.model = self.model.half()
-
             if init_model is not None:
                 # load model parameters if available
                 logging.info(f'Loading existing model params from {init_model}')
@@ -530,7 +532,11 @@ def __init__(self, opt: Opt, shared=None):
             logging.warning("Optimizer was reset. Also resetting LR scheduler.")
         self.build_lr_scheduler(states, hard_reset=is_finetune or was_reset)

-        if shared is None and is_distributed():
+        if (
+            shared is None
+            and is_distributed()
+            and opt.get('ddp_backend', fsdp_utils.DEFAULT_DDP_BACKEND) == 'ddp'
+        ):
             device_ids = None if self.model_parallel else [self.opt['gpu']]
             self.model = torch.nn.parallel.DistributedDataParallel(
                 self.model, device_ids=device_ids, broadcast_buffers=False

‎parlai/scripts/multiprocessing_eval.py

+1, -2

@@ -23,7 +23,6 @@
 """

 import torch
-import random
 import os
 import signal
 import parlai.utils.distributed as distributed_utils
@@ -88,7 +87,7 @@ def setup_args(cls):
         return setup_args()

     def run(self):
-        port = random.randint(32000, 48000)
+        port = distributed_utils.find_free_port()
         return launch_and_eval(self.opt, port)

‎parlai/scripts/multiprocessing_train.py

+4, -3

@@ -24,7 +24,6 @@
 """

 import torch
-import random
 import os
 import signal
 import traceback
@@ -55,10 +54,12 @@ def multiprocess_train(
         raise


-def launch_and_train(opt, port):
+def launch_and_train(opt, port=None):
     """
     Perform a fork() to many processes.
     """
+    if port is None:
+        port = distributed_utils.find_free_port()
     # Launch multiple subprocesses
     spawncontext = torch.multiprocessing.start_processes(
         multiprocess_train,
@@ -99,7 +100,7 @@ def setup_args(cls):

     def run(self):
         if self.opt['port'] is None:
-            port = random.randint(32000, 48000)
+            port = None
         else:
             port = self.opt['port']
         return launch_and_train(self.opt, port)

‎parlai/scripts/train_model.py

+17, -20

@@ -442,17 +442,20 @@ def save_model(self, suffix=None):
         """
         Save the model to disk, possibly with a suffix.
         """
-        if not is_primary_worker():
-            # never do IO as a non-primary worker
-            return
-
         if not self.opt.get('model_file'):
             # nothing to save to, just exit
             return

         fn = self.opt['model_file']
         if suffix:
             fn += suffix
+
+        if not is_primary_worker():
+            # never do IO as a non-primary worker
+            if hasattr(self.agent, 'save_nonprimary'):
+                self.agent.save_nonprimary(fn)
+            return
+
         while True:
             # don't ever let a ctrl-c interrupt saving
             try:
@@ -543,7 +546,7 @@ def validate(self):
             )
             self.best_valid = new_valid
             self.impatience = 0
-            if opt.get('model_file') and is_primary_worker():
+            if opt.get('model_file'):
                 logging.info(f"saving best valid model: {opt['model_file']}")
                 self.save_model()
                 self.saved = True
@@ -566,11 +569,7 @@ def validate(self):
         self.validate_time.reset()

         # saving
-        if (
-            opt.get('model_file')
-            and opt.get('save_after_valid')
-            and is_primary_worker()
-        ):
+        if opt.get('model_file') and opt.get('save_after_valid'):
             logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
             self.save_model('.checkpoint')

@@ -720,24 +719,26 @@ def _get_time(self, world: World) -> Tuple[float, float, float]:
             self._total_epochs = self._preempted_epochs + sum(
                 all_gather_list(world.get_total_epochs())
             )
-            train_time, log_time, validate_time = sync_object(
+            train_time, log_time, validate_time, save_time = sync_object(
                 (
                     self.train_time.time(),
                     self.log_time.time(),
                     self.validate_time.time(),
+                    self.save_time.time(),
                 )
             )
         else:
-            train_time, log_time, validate_time = (
+            train_time, log_time, validate_time, save_time = (
                 self.train_time.time(),
                 self.log_time.time(),
                 self.validate_time.time(),
+                self.save_time.time(),
             )
             self._total_epochs = self._preempted_epochs + (
                 num_workers() * world.get_total_epochs()
             )

-        return train_time, log_time, validate_time
+        return train_time, log_time, validate_time, save_time

     def log(self):
         """
@@ -810,7 +811,7 @@ def train_steps(self):
             self._last_log_steps += 1 / self.update_freq

             # the following additionally updates self._total_epochs
-            train_time, log_time, validate_time = self._get_time(world)
+            train_time, log_time, validate_time, save_time = self._get_time(world)
             # get the total training examples done, compute epochs
             exs_per_epoch = world.num_examples()
             self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))
@@ -859,11 +860,7 @@ def train_steps(self):
                     break
                 # make sure metrics are clean before we log
                 world.reset_metrics()
-                if (
-                    self.save_time.time() > self.save_every_n_secs
-                    and opt.get('model_file')
-                    and is_primary_worker()
-                ):
+                if save_time > self.save_every_n_secs and opt.get('model_file'):
                     logging.info(
                         f"saving model checkpoint: {opt['model_file']}.checkpoint"
                     )
@@ -872,7 +869,7 @@ def train_steps(self):
                     self.save_model('.checkpoint')
                     self.save_time.reset()

-        if not self.saved and is_primary_worker():
+        if not sync_object(self.saved):
             # save agent
             self.save_model()

‎parlai/utils/distributed.py

+25

@@ -296,6 +296,19 @@ def distributed_context(
         dist.destroy_process_group()


+def get_dist_group():
+    """
+    Find the default pytorch distributed group.
+
+    Used within FSDP to mark which workers are participating. Important to manually call
+    this because FSDP will cache old groups, but our test suite will instantiate new
+    groups per test.
+    """
+    from torch.distributed.distributed_c10d import _get_default_group
+
+    return _get_default_group()
+
+
 @contextlib.contextmanager
 def slurm_distributed_context(opt):
     """
@@ -346,3 +359,15 @@ def slurm_distributed_context(opt):
     except FileNotFoundError:
         # Slurm is not installed
         raise RuntimeError('SLURM does not appear to be installed.')
+
+
+def find_free_port() -> int:
+    """
+    Find a free port we can bind to locally.
+
+    Credit: https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
+    """
+    with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+        s.bind(('', 0))
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        return s.getsockname()[1]
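
A quick note on why find_free_port() replaces random.randint(32000, 48000) in the multiprocessing scripts: binding to port 0 makes the OS hand back a port that is actually free, rather than guessing and occasionally colliding. A tiny usage sketch (the init_method string is only an illustration of how such a port is typically consumed by a torch.distributed rendezvous):

    from parlai.utils.distributed import find_free_port

    port = find_free_port()                  # OS-chosen, currently-unused port
    init_method = f'tcp://localhost:{port}'  # illustrative rendezvous address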

‎parlai/utils/fp16.py

+45, -16

@@ -55,27 +55,48 @@ def forward(self, scores, targets):
         )


-def clip_grad_norm(params, max_norm):
+def clip_grad_norm(params, max_norm: float = 0, sync: bool = False):
     """
-    Clips grad norm.
+    Clips grad norms.
+
+    During combination with FSDP, will also ensure that grad norms are aggregated
+    across all workers, since each worker only stores their shard of the
+    gradients.
+
+    :param params:
+        Parameters whose gradients we wish to clip
+    :param max_norm:
+        Maximum norm we wish the gradients to have. If non-positive, then
+        we will not perform clipping.
+    :param sync:
+        Boolean indicating whether we should aggregate across the distributed
+        group. Used only in combination with FSDP.
+
+    :returns:
+        The gradient norm across all parameters, before clipping.
     """
     if isinstance(params, torch.Tensor):
         params = [params]
     # make sure any generators are expanded
     params = list(params)
-    if len(params) == 1:
-        p = params[0].grad
-        grad_norm = torch.norm(p)
-        if grad_norm > max_norm > 0:
-            clip_coef = max_norm / (grad_norm + 1e-6)
-            p.mul_(clip_coef)
-        return grad_norm
-    elif max_norm > 0:
+    # if syncing we need to manually perform the clipping so that we aggregrate
+    # properly
+    if max_norm > 0 and not sync:
         return torch.nn.utils.clip_grad_norm_(params, max_norm)
     else:
-        return torch.sqrt(
-            sum(p.grad.data.norm() ** 2 for p in params if p.grad is not None)
-        )
+        normsq = sum(p.grad.data.norm() ** 2 for p in params if p.grad is not None)
+        if sync:
+            # also need to get the norms from all the other sharded works in FSDP
+            import torch.distributed as dist
+
+            dist.all_reduce(normsq)
+        grad_norm = normsq.sqrt()
+        if max_norm > 0:
+            clip_coef = max_norm / (grad_norm + 1e-6)
+            for p in params:
+                p.grad.detach().mul_(clip_coef)
+
+        return grad_norm


 def has_overflow(grad_norm):
@@ -88,7 +109,7 @@ def has_overflow(grad_norm):


 class SafeFP16Optimizer(torch.optim.Optimizer):
-    def __init__(self, optimizer):
+    def __init__(self, optimizer, aggregate_gnorms=False):
         self.fp16_params = self._get_parameters(optimizer)
         self.fp32_params = self._build_fp32_params(self.fp16_params, flatten=False)
         self.optimizer = optimizer
@@ -103,6 +124,7 @@ def __init__(self, optimizer):

         self.scaler = DynamicLossScaler(2.0 ** 15)
         self.min_loss_scale = 2 ** -5
+        self._aggregate_gnorms = aggregate_gnorms

     @classmethod
     def _get_parameters(cls, optimizer):
@@ -210,7 +232,9 @@ def clip_master_grads(self, max_norm):
         Clips gradient norm and updates dynamic loss scaler.
         """
         self._sync_fp16_grads_to_fp32()
-        grad_norm = clip_grad_norm(self.fp32_params, max_norm)
+        grad_norm = clip_grad_norm(
+            self.fp32_params, max_norm, sync=self._aggregate_gnorms
+        )

         # detect overflow and adjust loss scale
         if self.scaler is not None:
@@ -390,6 +414,7 @@ class MemoryEfficientFP16Optimizer(torch.optim.Optimizer):
     def __init__(
         self,
         init_optimizer: torch.optim.Optimizer,  # type: ignore
+        aggregate_gnorms: bool = False,
         loss_initial_scale: float = 2.0 ** 17,
         min_loss_scale: float = 1e-4,
     ):
@@ -398,6 +423,8 @@ def __init__(
         self.min_loss_scale = min_loss_scale
         self.scaler = DynamicLossScaler(init_scale=loss_initial_scale)

+        self._aggregate_gnorms = aggregate_gnorms
+
     @staticmethod
     def compatible_optimizers():
         """
@@ -446,7 +473,9 @@ def clip_master_grads(self, gradient_clip):
         Returns -1 if the most recently computed gradients overflowed.
         """
         self._unscale_grads()
-        grad_norm = clip_grad_norm(self.params, gradient_clip)
+        grad_norm = clip_grad_norm(
+            self.params, gradient_clip, sync=self._aggregate_gnorms
+        )
         # detect overflow and adjust loss scale
         overflow = has_overflow(grad_norm)
         self.scaler.update_scale(overflow)
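
The arithmetic behind sync=True in clip_grad_norm above: with FSDP each worker holds only a shard of every gradient, so the global norm is the square root of the all-reduced sum of per-shard squared norms. A single-process sketch of that identity (the shard tensors are made up for illustration):

    import torch

    shard0 = torch.tensor([3.0, 0.0])    # gradient elements held by worker 0
    shard1 = torch.tensor([0.0, 4.0])    # gradient elements held by worker 1

    local_sq = torch.stack([shard0.norm() ** 2, shard1.norm() ** 2])
    global_norm = local_sq.sum().sqrt()  # 5.0, the norm of the full gradient [3, 0, 0, 4]

    # in the real code the sum over shards is done with dist.all_reduce(normsq),
    # and the same synced norm feeds fp16 overflow detection via has_overflow()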

‎parlai/utils/fsdp.py

+111 (new file)

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Utility functions for FullyShardedDataParallel.
"""

import contextlib
import torch.nn
from parlai.utils.distributed import is_distributed, get_dist_group

try:
    from fairscale.nn.wrap.auto_wrap import wrap
    from fairscale.nn.wrap.auto_wrap import enable_wrap as fairscale_enable_wrap
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    FSDP_AVAILABLE = True
except ImportError:
    FSDP_AVAILABLE = False

    def wrap(module, **kwargs):
        return module


DEFAULT_DDP_BACKEND = "ddp"


def is_fsdp(module: torch.nn.Module):
    """
    Checks whether a module is fully sharded.
    """
    return FSDP_AVAILABLE and isinstance(module, FSDP)


def should_use_fsdp(opt):
    return (
        FSDP_AVAILABLE
        and is_distributed()
        and opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3')
    )


@contextlib.contextmanager
def maybe_fsdp_wrap(opt):
    """
    Context manager for enabling wrapping in FullyShardedDataParallel.
    """
    if not should_use_fsdp(opt):
        # make a no-op
        yield
        return

    # zero3 not supported at this time. Throw an exception
    if opt['ddp_backend'] == 'zero3':
        raise NotImplementedError(
            '--ddp-backend zero3 is not supported at this time. For details, see '
            'https://github.com/facebookresearch/ParlAI/issues/3753.'
        )

    reshard_after_forward = opt['ddp_backend'] == 'zero3'
    compute_dtype = torch.float16 if opt['fp16'] else torch.float32
    mixed_precision = opt['fp16'] and opt['fp16_impl'] == 'safe'
    fsdp_args = dict(
        reshard_after_forward=reshard_after_forward,
        mixed_precision=mixed_precision,
        compute_dtype=compute_dtype,
        state_dict_device=torch.device('cpu'),
        flatten_parameters=True,
        process_group=get_dist_group(),
    )
    with fairscale_enable_wrap(wrapper_cls=FSDP, **fsdp_args):
        yield


def delay_halving(opt):
    """
    Check whether we should keep the model in fp32 before other setup.

    When using Zero2 or Zero3 backends with mixed precision, we need to avoid converting
    the model to fp16, as the FSDP module does this for us.

    If we are using just plain DDP or MemoryEfficient optimizers, then we want
    to call half() early.
    """

    return opt['fp16'] and should_use_fsdp(opt) and opt['fp16_impl'] == 'safe'


def should_sync_gradnorm(opt):
    """
    Indicates whether fp16 optimizer wrappers should accumulate over workers.

    FP16 overflow detection and gradient clipping both require accumulating gradients
    across all workers when using FSDP, as workers only store a fraction of the
    gradients.
    """
    return (
        FSDP_AVAILABLE
        and opt['fp16']
        and opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3')
    )


def fsdp_wrap(module):
    """
    Helper function for wrapping the outermost root module.
    """
    return wrap(module)
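
Putting this module's pieces together: the intended call pattern (as used in TorchGeneratorAgent earlier in this diff) is to build the model inside maybe_fsdp_wrap, wrap the root module with fsdp_wrap, and let delay_halving decide whether calling half() is still our job. A condensed sketch, with opt and build_model standing in for the agent's own:

    import parlai.utils.fsdp as fsdp_utils

    with fsdp_utils.maybe_fsdp_wrap(opt):              # no-op unless ddp_backend is zero2
        model = fsdp_utils.fsdp_wrap(build_model())    # root wrap; layers were wrapped in build_layers
        if opt['fp16'] and not fsdp_utils.delay_halving(opt):
            # only halve manually when FSDP's mixed_precision is not doing it for us
            model = model.half()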

‎tests/test_distributed.py

+94, -88

@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.

 import os
-import copy
 import unittest
 import parlai.utils.testing as testing_utils
 import parlai.scripts.build_dict as build_dict
@@ -15,21 +14,30 @@
 BATCHSIZE = 4


-def _forced_parse(parser, opt):
-    parser.set_params(**opt)
-    parser.set_params(log_every_n_sec=10)
-    popt = parser.parse_args([])
-    # in some rare cases, like for instance if the model class also
-    # overrides its default params, the params override will not
-    # be taken into account.
-    for k, v in opt.items():
-        popt[k] = v
-    return popt
+class _AbstractTest(unittest.TestCase):
+    def _distributed_train_model(self, **overrides):
+        opt = {**self.base_config, **overrides}
+        with testing_utils.tempdir() as tmpdir:
+            if 'model_file' not in opt:
+                opt['model_file'] = os.path.join(tmpdir, 'model')
+            if 'dict_file' not in opt:
+                opt['dict_file'] = os.path.join(tmpdir, 'model.dict')
+
+            parser = mp_train.setup_args()
+            popt = parser.parse_kwargs(**opt)
+
+            # we need a prebuilt dictionary
+            parser = build_dict.setup_args()
+            build_dict.build_dict(popt)
+
+            valid, test = mp_train.launch_and_train(popt)
+
+        return (valid, test)


 @testing_utils.skipUnlessGPU
-class TestDistributed(unittest.TestCase):
-    _base_config = dict(
+class TestDistributed(_AbstractTest):
+    base_config = dict(
         task='integration_tests:overfit',
         model='transformer/generator',
         optimizer='adam',
@@ -46,30 +54,8 @@ class TestDistributed(unittest.TestCase):
         verbose=True,
     )

-    def setUp(self):
-        print(f'[Setting up test {self._testMethodName}]')
-
-    def _distributed_train_model(self, opt):
-        with testing_utils.tempdir() as tmpdir:
-            if 'model_file' not in opt:
-                opt['model_file'] = os.path.join(tmpdir, 'model')
-            if 'dict_file' not in opt:
-                opt['dict_file'] = os.path.join(tmpdir, 'model.dict')
-
-            parser = mp_train.setup_args()
-            popt = _forced_parse(parser, opt)
-
-            # we need a prebuilt dictionary
-            parser = build_dict.setup_args()
-            build_dict.build_dict(popt)
-
-            valid, test = mp_train.launch_and_train(popt, 31338)
-
-            return (valid, test)
-
     def test_generator_distributed(self):
-        config = copy.deepcopy(self._base_config)
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model()

         self.assertLessEqual(valid['ppl'], 1.60)
         self.assertLessEqual(test['ppl'], 1.60)
@@ -80,11 +66,11 @@ def test_generator_distributed(self):
         self.assertEqual(test['exs'].value(), BATCHSIZE)

     def test_multitask_distributed(self):
-        config = copy.deepcopy(self._base_config)
-        config['num_epochs'] = 50
-        config['task'] = 'integration_tests:overfit,integration_tests:overfit_multiturn'
-        config['dynb'] = 'full'
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            num_epochs=50,
+            task='integration_tests:overfit,integration_tests:overfit_multiturn',
+            truncate=16,
+        )

         self.assertLessEqual(valid['ppl'], 1.20)
         self.assertLessEqual(test['ppl'], 1.20)
@@ -100,12 +86,12 @@ def test_multitask_distributed(self):
         )

     def test_distributed_eval_max_exs(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['validation_max_exs'] = 90
-        config['short_final_eval'] = True
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests',
+            num_epochs=0.01,
+            validation_max_exs=90,
+            short_final_eval=True,
+        )

         # Tests that DialogData.get() is doing the right thing
         # Ensure no duplication of examples among workers
@@ -120,11 +106,9 @@ def test_distributed_eval_max_exs(self):
         self.assertEqual(test['exs'].value(), 96)

     def test_distributed_eval_stream_mode(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['datatype'] = 'train:stream'
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests', num_epochs=0.01, datatype='train:stream'
+        )

         # Tests that StreamDialogData.get() is doing the right thing
         # Ensure no duplication of examples among workers
@@ -133,14 +117,13 @@ def test_distributed_eval_stream_mode(self):
         self.assertEqual(test['exs'].value(), inttests.NUM_TEST)

     def test_distributed_eval_stream_mode_max_exs(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['datatype'] = 'train:stream'
-        config['validation_max_exs'] = 90
-        config['short_final_eval'] = True
-
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests',
+            num_epochs=0.01,
+            datatype='train:stream',
+            validation_max_exs=90,
+            short_final_eval=True,
+        )

         # Tests that StreamDialogData.get() is doing the right thing
         # Ensure no duplication of examples among workers
@@ -155,45 +138,68 @@ def test_distributed_eval_stream_mode_max_exs(self):
         self.assertEqual(test['exs'].value(), 96)

     def test_chunked_dynamic_teacher(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['datatype'] = 'train:stream'
-        config['dynamic_batching'] = 'full'
-        config['truncate'] = 16
-
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests',
+            num_epochs=0.01,
+            datatype='train:stream',
+            dynamic_batching='full',
+            truncate=16,
+        )
         assert valid['exs'].value() == inttests.NUM_TEST
         assert test['exs'].value() == inttests.NUM_TEST

     def test_chunked_teacher(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['datatype'] = 'train:stream'
-        config['num_epochs'] = 5
-        config['dynamic_batching'] = None
-
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests',
+            datatype='train:stream',
+            num_epochs=5,
+            dynamic_batching=None,
+        )
         assert valid['exs'].value() == inttests.NUM_TEST
         assert test['exs'].value() == inttests.NUM_TEST

+
+@testing_utils.skipUnlessGPU
+class TestZero2(TestDistributed):
+    """
+    Integration tests for zero2 FSDP.
+    """
+
+    base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero2'}
+
+
+@unittest.skip
+@testing_utils.skipUnlessGPU
+class TestZero3(TestDistributed):
+    # Not supported at this time. See:
+    # https://github.com/facebookresearch/ParlAI/pull/3740
+    base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero3'}
+
+
+@testing_utils.skipUnlessGPU
+class TestNoModelParallel(_AbstractTest):
+    base_config = dict(
+        task='integration_tests:overfit',
+        optimizer='sgd',
+        validation_metric='loss',
+        learningrate=1e-2,
+        batchsize=BATCHSIZE,
+        validation_every_n_epochs=1,
+        num_epochs=1,
+        n_layers=1,
+        n_heads=1,
+        ffn_size=32,
+        embedding_size=8,
+        verbose=True,
+    )
+
     def test_no_model_parallel(self):
         """
-        Checks that we throw an error when combining mp_train with.
-
-        --model-parallel true.
+        Checks that we throw an error when combining mp_train with --model-parallel.
         """
-        config = copy.deepcopy(self._base_config)
-        config['model_parallel'] = True
-        for m in [
-            'transformer/generator',
-            'transformer/ranker',
-            'transformer/classifier',
-        ]:
-            config['model'] = m
+        for m in ['transformer/generator', 'transformer/ranker']:
             try:
-                _ = self._distributed_train_model(config)
+                _ = self._distributed_train_model(model=m, model_parallel=True)
             except RuntimeError:
                 pass
             else:
