From a5b4da4b2c3f67c4c37303148dfaed6c68c4772c Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Mon, 21 Jun 2021 21:19:14 -0400 Subject: [PATCH 01/29] Implement zero2 and zero3 --- parlai/core/params.py | 10 +++++++ parlai/core/torch_agent.py | 14 ++++++++-- parlai/core/torch_generator_agent.py | 42 ++++++++++++++++++++++++++-- parlai/utils/fp16.py | 16 +++++++++-- 4 files changed, 75 insertions(+), 7 deletions(-) diff --git a/parlai/core/params.py b/parlai/core/params.py index 6465df14f6b..f94db1cee16 100644 --- a/parlai/core/params.py +++ b/parlai/core/params.py @@ -772,6 +772,16 @@ def add_distributed_training_args(self): grp.add_argument( '--distributed-world-size', type=int, help='Number of workers.' ) + grp.add_argument( + '--ddp-backend', + choices=['ddp', 'zero2', 'zero3'], + default='ddp', + help=( + 'Distributed backend. Zero2 can be faster but is more experimental. ' + 'Zero3 uses radically less memory, but is slower. DDP is the most ' + 'tested.' + ), + ) return grp def add_model_args(self): diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index 5994b344754..bc8ebba5dc4 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -1052,7 +1052,9 @@ def init_optim( self.optimizer = optim_class(params, **kwargs) if self.fp16: if self.fp16_impl == 'safe': - self.optimizer = SafeFP16Optimizer(self.optimizer) + self.optimizer = SafeFP16Optimizer( + self.optimizer, self._should_sync_overflows() + ) else: # Using memory efficient optimizer opt_name = opt['optimizer'] @@ -1064,7 +1066,9 @@ def init_optim( 'with Memory Efficient FP16. Please select from among this ' f'list:\n{compatible_list}' ) - self.optimizer = MemoryEfficientFP16Optimizer(self.optimizer) + self.optimizer = MemoryEfficientFP16Optimizer( + self.optimizer, self._should_sync_overflows() + ) if is_finetune: logging.warning('Detected a fine-tune run. Resetting the optimizer.') @@ -1113,6 +1117,9 @@ def init_optim( ) return True + def _should_sync_overflows(self): + return self.fp16 and self.opt['ddp_backend'] in ('zero2', 'zero3') + def build_lr_scheduler(self, states=None, hard_reset=False): """ Create the learning rate scheduler, and assign it to self.scheduler. This @@ -2345,6 +2352,9 @@ def update_params(self): self.global_metrics.add('gnorm', GlobalAverageMetric(grad_norm)) if self.fp16: + logging.info( + f"fp16_loss_scale = {self.optimizer.loss_scale} [{self._number_training_updates}]" + ) self.global_metrics.add( 'fp16_loss_scalar', GlobalAverageMetric(self.optimizer.loss_scale) ) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 79da7d44325..1f7dac3d007 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -506,7 +506,8 @@ def __init__(self, opt: Opt, shared=None): ) if self.fp16: - self.model = self.model.half() + if not self._delay_halving: + self.model = self.model.half() if init_model is not None: # load model parameters if available @@ -515,6 +516,28 @@ def __init__(self, opt: Opt, shared=None): else: states = {} + if ( + shared is None + and is_distributed() + and opt['ddp_backend'] in ('zero2', 'zero3') + ): + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + + device_ids = None if self.model_parallel else [self.opt['gpu']] + mixed_precision = opt['fp16'] + reshard_after_forward = opt['ddp_backend'] == 'zero3' + # hack: fsdp expects things in fp32 if we're using mixed precision. + # lol! convert it back! 
+ if self.fp16 and mixed_precision: + self.model = self.model.float() + self.model = FSDP( + self.model, + reshard_after_forward=reshard_after_forward, + mixed_precision=self.fp16 and opt['fp16_impl'] != 'mem_efficient', + compute_dtype=torch.float16 if self.fp16 else torch.float32, + state_dict_device=torch.device('cpu'), + ) + if shared is not None: if 'optimizer' in shared: self.optimizer = shared['optimizer'] @@ -530,7 +553,7 @@ def __init__(self, opt: Opt, shared=None): logging.warning("Optimizer was reset. Also resetting LR scheduler.") self.build_lr_scheduler(states, hard_reset=is_finetune or was_reset) - if shared is None and is_distributed(): + if shared is None and is_distributed() and opt['ddp_backend'] == 'ddp': device_ids = None if self.model_parallel else [self.opt['gpu']] self.model = torch.nn.parallel.DistributedDataParallel( self.model, device_ids=device_ids, broadcast_buffers=False @@ -538,6 +561,21 @@ def __init__(self, opt: Opt, shared=None): self.reset() + def _delay_halving(self): + """ + Check whether we should keep the model in fp32 before other setup. + + When using Zero2 or Zero3 backends with mixed precision, we need to + avoid converting the model to fp16, as the FSDP module does this for + us. + """ + + return ( + self.fp16 + and self.opt['ddp_backend'] in ('zero2', 'zero3') + and self.opt['fp16_impl'] == 'safe' + ) + def build_criterion(self): """ Construct and return the loss function. diff --git a/parlai/utils/fp16.py b/parlai/utils/fp16.py index 073f7757fb4..a7f216f70e7 100644 --- a/parlai/utils/fp16.py +++ b/parlai/utils/fp16.py @@ -88,7 +88,7 @@ def has_overflow(grad_norm): class SafeFP16Optimizer(torch.optim.Optimizer): - def __init__(self, optimizer): + def __init__(self, optimizer, sync_overflows=False): self.fp16_params = self._get_parameters(optimizer) self.fp32_params = self._build_fp32_params(self.fp16_params, flatten=False) self.optimizer = optimizer @@ -103,6 +103,16 @@ def __init__(self, optimizer): self.scaler = DynamicLossScaler(2.0 ** 15) self.min_loss_scale = 2 ** -5 + self._sync_overflows = sync_overflows + + def _maybe_sync(self, value: bool) -> bool: + if self._sync_overflows: + import torch.distributed as dist + + value_tensor = torch.BoolTensor([value]).cuda() + dist.all_reduce(value_tensor) + value = value_tensor.item() + return value @classmethod def _get_parameters(cls, optimizer): @@ -214,7 +224,7 @@ def clip_master_grads(self, max_norm): # detect overflow and adjust loss scale if self.scaler is not None: - overflow = has_overflow(grad_norm) + overflow = self._maybe_sync(has_overflow(grad_norm)) prev_scale = self.scaler.loss_scale self.scaler.update_scale(overflow) if overflow: @@ -448,7 +458,7 @@ def clip_master_grads(self, gradient_clip): self._unscale_grads() grad_norm = clip_grad_norm(self.params, gradient_clip) # detect overflow and adjust loss scale - overflow = has_overflow(grad_norm) + overflow = self._maybe_sync(has_overflow(grad_norm)) self.scaler.update_scale(overflow) if overflow: if self.scaler.loss_scale <= self.min_loss_scale: From ea9390c9d761918a8edd2aa020cdc643947eb645 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Tue, 22 Jun 2021 09:55:20 -0400 Subject: [PATCH 02/29] Implement overflow syncing. 
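Under zero2/zero3 each rank holds only a shard of the gradients, so a loss-scale overflow can appear on one worker and not the others; if the ranks disagree about whether to skip the step and shrink the loss scale, their loss scales and optimizer states drift apart. The fix is to all-reduce the local overflow flag so every rank makes the same decision. A minimal standalone sketch of that idea (the helper name below is illustrative, not ParlAI code; it assumes torch.distributed is already initialized and a CUDA device is available):

    import torch
    import torch.distributed as dist

    def any_rank_overflowed(local_overflow: bool) -> bool:
        # encode the flag as a float so the reduction works on any backend
        flag = torch.tensor([1.0 if local_overflow else 0.0], device='cuda')
        dist.all_reduce(flag)  # default SUM: > 0 means some rank overflowed
        return flag.item() > 0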
--- parlai/core/torch_generator_agent.py | 7 ++++--- parlai/utils/fp16.py | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 1f7dac3d007..04b35b42cb3 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -507,6 +507,7 @@ def __init__(self, opt: Opt, shared=None): if self.fp16: if not self._delay_halving: + logging.debug("Halving the model") self.model = self.model.half() if init_model is not None: @@ -528,12 +529,11 @@ def __init__(self, opt: Opt, shared=None): reshard_after_forward = opt['ddp_backend'] == 'zero3' # hack: fsdp expects things in fp32 if we're using mixed precision. # lol! convert it back! - if self.fp16 and mixed_precision: - self.model = self.model.float() + logging.debug("Wrapping in FSDP") self.model = FSDP( self.model, reshard_after_forward=reshard_after_forward, - mixed_precision=self.fp16 and opt['fp16_impl'] != 'mem_efficient', + mixed_precision=self.fp16 and opt['fp16_impl'] == 'safe', compute_dtype=torch.float16 if self.fp16 else torch.float32, state_dict_device=torch.device('cpu'), ) @@ -555,6 +555,7 @@ def __init__(self, opt: Opt, shared=None): if shared is None and is_distributed() and opt['ddp_backend'] == 'ddp': device_ids = None if self.model_parallel else [self.opt['gpu']] + logging.debug("Wrapping in DDP") self.model = torch.nn.parallel.DistributedDataParallel( self.model, device_ids=device_ids, broadcast_buffers=False ) diff --git a/parlai/utils/fp16.py b/parlai/utils/fp16.py index a7f216f70e7..f596e529522 100644 --- a/parlai/utils/fp16.py +++ b/parlai/utils/fp16.py @@ -104,9 +104,11 @@ def __init__(self, optimizer, sync_overflows=False): self.scaler = DynamicLossScaler(2.0 ** 15) self.min_loss_scale = 2 ** -5 self._sync_overflows = sync_overflows + logging.debug(f"Sync overflows = {sync_overflows}") def _maybe_sync(self, value: bool) -> bool: if self._sync_overflows: + logging.debug(f"Syncing value {value}") import torch.distributed as dist value_tensor = torch.BoolTensor([value]).cuda() From 378eacc12b1f9dfb052332be80c17ad79475e3f5 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Tue, 22 Jun 2021 12:17:18 -0400 Subject: [PATCH 03/29] Tweak log statements. --- parlai/core/torch_generator_agent.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 04b35b42cb3..78160a8f636 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -506,8 +506,7 @@ def __init__(self, opt: Opt, shared=None): ) if self.fp16: - if not self._delay_halving: - logging.debug("Halving the model") + if not self._delay_halving(): self.model = self.model.half() if init_model is not None: @@ -529,12 +528,17 @@ def __init__(self, opt: Opt, shared=None): reshard_after_forward = opt['ddp_backend'] == 'zero3' # hack: fsdp expects things in fp32 if we're using mixed precision. # lol! convert it back! 
- logging.debug("Wrapping in FSDP") + compute_dtype = torch.float16 if self.fp16 else torch.float32 + mixed_precision = self.fp16 and opt['fp16_impl'] == 'safe' + logging.debug( + f"Wrapping in FSDP (reshard_after_forward = {reshard_after_forward}, " + f"compute_dtype = {compute_dtype} mixed_precision = {mixed_precision}" + ) self.model = FSDP( self.model, reshard_after_forward=reshard_after_forward, - mixed_precision=self.fp16 and opt['fp16_impl'] == 'safe', - compute_dtype=torch.float16 if self.fp16 else torch.float32, + mixed_precision=mixed_precision, + compute_dtype=compute_dtype, state_dict_device=torch.device('cpu'), ) @@ -555,7 +559,7 @@ def __init__(self, opt: Opt, shared=None): if shared is None and is_distributed() and opt['ddp_backend'] == 'ddp': device_ids = None if self.model_parallel else [self.opt['gpu']] - logging.debug("Wrapping in DDP") + logging.debug("Wrapping in simple DDP") self.model = torch.nn.parallel.DistributedDataParallel( self.model, device_ids=device_ids, broadcast_buffers=False ) From 0ea7d3ab49f5d43e34f36c0bcdb726effbdf09e2 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Tue, 22 Jun 2021 12:41:04 -0400 Subject: [PATCH 04/29] Use free ports rather than random ports --- parlai/scripts/multiprocessing_eval.py | 2 +- parlai/scripts/multiprocessing_train.py | 2 +- parlai/utils/distributed.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/parlai/scripts/multiprocessing_eval.py b/parlai/scripts/multiprocessing_eval.py index 2b220fc1a24..4bd94693fbd 100644 --- a/parlai/scripts/multiprocessing_eval.py +++ b/parlai/scripts/multiprocessing_eval.py @@ -88,7 +88,7 @@ def setup_args(cls): return setup_args() def run(self): - port = random.randint(32000, 48000) + port = distributed_utils.find_free_port() return launch_and_eval(self.opt, port) diff --git a/parlai/scripts/multiprocessing_train.py b/parlai/scripts/multiprocessing_train.py index 14c2e305846..11422ea9b6c 100644 --- a/parlai/scripts/multiprocessing_train.py +++ b/parlai/scripts/multiprocessing_train.py @@ -99,7 +99,7 @@ def setup_args(cls): def run(self): if self.opt['port'] is None: - port = random.randint(32000, 48000) + port = distributed_utils.find_free_port() else: port = self.opt['port'] return launch_and_train(self.opt, port) diff --git a/parlai/utils/distributed.py b/parlai/utils/distributed.py index 27a9a240b55..91be35bfb89 100644 --- a/parlai/utils/distributed.py +++ b/parlai/utils/distributed.py @@ -346,3 +346,15 @@ def slurm_distributed_context(opt): except FileNotFoundError: # Slurm is not installed raise RuntimeError('SLURM does not appear to be installed.') + + +def find_free_port() -> int: + """ + Find a free port we can bind to locally. 
+ + Credit: https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number + """ + with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(('', 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] From fc3e6688cedc286d4c1543a073344113a57b2963 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Tue, 22 Jun 2021 13:46:41 -0400 Subject: [PATCH 05/29] Refactor test_distributed --- tests/test_distributed.py | 100 +++++++++++++++----------------------- 1 file changed, 39 insertions(+), 61 deletions(-) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 0a9372412f4..649de982aab 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import os -import copy import unittest import parlai.utils.testing as testing_utils import parlai.scripts.build_dict as build_dict @@ -15,18 +14,6 @@ BATCHSIZE = 4 -def _forced_parse(parser, opt): - parser.set_params(**opt) - parser.set_params(log_every_n_sec=10) - popt = parser.parse_args([]) - # in some rare cases, like for instance if the model class also - # overrides its default params, the params override will not - # be taken into account. - for k, v in opt.items(): - popt[k] = v - return popt - - @testing_utils.skipUnlessGPU class TestDistributed(unittest.TestCase): _base_config = dict( @@ -49,7 +36,8 @@ class TestDistributed(unittest.TestCase): def setUp(self): print(f'[Setting up test {self._testMethodName}]') - def _distributed_train_model(self, opt): + def _distributed_train_model(self, **overrides): + opt = {**self._base_config, **overrides} with testing_utils.tempdir() as tmpdir: if 'model_file' not in opt: opt['model_file'] = os.path.join(tmpdir, 'model') @@ -57,7 +45,7 @@ def _distributed_train_model(self, opt): opt['dict_file'] = os.path.join(tmpdir, 'model.dict') parser = mp_train.setup_args() - popt = _forced_parse(parser, opt) + popt = parser.parse_kwargs(**opt) # we need a prebuilt dictionary parser = build_dict.setup_args() @@ -68,8 +56,7 @@ def _distributed_train_model(self, opt): return (valid, test) def test_generator_distributed(self): - config = copy.deepcopy(self._base_config) - valid, test = self._distributed_train_model(config) + valid, test = self._distributed_train_model() self.assertLessEqual(valid['ppl'], 1.60) self.assertLessEqual(test['ppl'], 1.60) @@ -80,11 +67,11 @@ def test_generator_distributed(self): self.assertEqual(test['exs'].value(), BATCHSIZE) def test_multitask_distributed(self): - config = copy.deepcopy(self._base_config) - config['num_epochs'] = 50 - config['task'] = 'integration_tests:overfit,integration_tests:overfit_multiturn' - config['dynb'] = 'full' - valid, test = self._distributed_train_model(config) + valid, test = self._distributed_train_model( + num_epochs=50, + task='integration_tests:overfit,integration_tests:overfit_multiturn', + truncate=16, + ) self.assertLessEqual(valid['ppl'], 1.20) self.assertLessEqual(test['ppl'], 1.20) @@ -100,12 +87,12 @@ def test_multitask_distributed(self): ) def test_distributed_eval_max_exs(self): - config = copy.deepcopy(self._base_config) - config['task'] = 'integration_tests' - config['num_epochs'] = 0.01 - config['validation_max_exs'] = 90 - config['short_final_eval'] = True - valid, test = self._distributed_train_model(config) + valid, test = self._distributed_train_model( + task='integration_tests', + num_epochs=0.01, + validation_max_exs=90, + short_final_eval=True, 
+ ) # Tests that DialogData.get() is doing the right thing # Ensure no duplication of examples among workers @@ -120,11 +107,9 @@ def test_distributed_eval_max_exs(self): self.assertEqual(test['exs'].value(), 96) def test_distributed_eval_stream_mode(self): - config = copy.deepcopy(self._base_config) - config['task'] = 'integration_tests' - config['num_epochs'] = 0.01 - config['datatype'] = 'train:stream' - valid, test = self._distributed_train_model(config) + valid, test = self._distributed_train_model( + task='integration_tests', num_epochs=0.01, datatype='train:stream' + ) # Tests that StreamDialogData.get() is doing the right thing # Ensure no duplication of examples among workers @@ -133,14 +118,13 @@ def test_distributed_eval_stream_mode(self): self.assertEqual(test['exs'].value(), inttests.NUM_TEST) def test_distributed_eval_stream_mode_max_exs(self): - config = copy.deepcopy(self._base_config) - config['task'] = 'integration_tests' - config['num_epochs'] = 0.01 - config['datatype'] = 'train:stream' - config['validation_max_exs'] = 90 - config['short_final_eval'] = True - - valid, test = self._distributed_train_model(config) + valid, test = self._distributed_train_model( + task='integration_tests', + num_epochs=0.01, + datatype='train:stream', + validation_max_exs=90, + short_final_eval=True, + ) # Tests that StreamDialogData.get() is doing the right thing # Ensure no duplication of examples among workers @@ -155,26 +139,23 @@ def test_distributed_eval_stream_mode_max_exs(self): self.assertEqual(test['exs'].value(), 96) def test_chunked_dynamic_teacher(self): - config = copy.deepcopy(self._base_config) - config['task'] = 'integration_tests' - config['num_epochs'] = 0.01 - config['datatype'] = 'train:stream' - config['dynamic_batching'] = 'full' - config['truncate'] = 16 - - valid, test = self._distributed_train_model(config) + valid, test = self._distributed_train_model( + task='integration_tests', + num_epochs=0.01, + datatype='train:stream', + dynamic_batching='full', + truncate=16, + ) assert valid['exs'].value() == inttests.NUM_TEST assert test['exs'].value() == inttests.NUM_TEST def test_chunked_teacher(self): - config = copy.deepcopy(self._base_config) - config['task'] = 'integration_tests' - config['num_epochs'] = 0.01 - config['datatype'] = 'train:stream' - config['num_epochs'] = 5 - config['dynamic_batching'] = None - - valid, test = self._distributed_train_model(config) + valid, test = self._distributed_train_model( + task='integration_tests', + datatype='train:stream', + num_epochs=5, + dynamic_batching=None, + ) assert valid['exs'].value() == inttests.NUM_TEST assert test['exs'].value() == inttests.NUM_TEST @@ -184,16 +165,13 @@ def test_no_model_parallel(self): --model-parallel true. """ - config = copy.deepcopy(self._base_config) - config['model_parallel'] = True for m in [ 'transformer/generator', 'transformer/ranker', 'transformer/classifier', ]: - config['model'] = m try: - _ = self._distributed_train_model(config) + _ = self._distributed_train_model(model=m, model_parallel=True) except RuntimeError: pass else: From 65ad5265671aad0e0a801e4093a923a5c1cbefda Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Tue, 22 Jun 2021 13:53:11 -0400 Subject: [PATCH 06/29] More refactor. 
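The point of hoisting base_config into a class attribute and accepting keyword overrides is that a subclass can re-run the entire distributed suite under a different configuration by overriding a single dict, which is how the zero2/zero3 variants are added later in this stack. For example (mirroring a later commit in this series):

    class TestZero2(TestDistributed):
        base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero2'}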
--- tests/test_distributed.py | 76 +++++++++++++++++++++++---------------- 1 file changed, 45 insertions(+), 31 deletions(-) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 649de982aab..4bec0c693be 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -14,30 +14,9 @@ BATCHSIZE = 4 -@testing_utils.skipUnlessGPU -class TestDistributed(unittest.TestCase): - _base_config = dict( - task='integration_tests:overfit', - model='transformer/generator', - optimizer='adam', - validation_metric='ppl', - skip_generation=True, - learningrate=1e-2, - batchsize=BATCHSIZE, - validation_every_n_epochs=5, - num_epochs=150, - n_layers=1, - n_heads=1, - ffn_size=32, - embedding_size=8, - verbose=True, - ) - - def setUp(self): - print(f'[Setting up test {self._testMethodName}]') - +class _AbstractTest(unittest.TestCase): def _distributed_train_model(self, **overrides): - opt = {**self._base_config, **overrides} + opt = {**self.base_config, **overrides} with testing_utils.tempdir() as tmpdir: if 'model_file' not in opt: opt['model_file'] = os.path.join(tmpdir, 'model') @@ -55,6 +34,26 @@ def _distributed_train_model(self, **overrides): return (valid, test) + +@testing_utils.skipUnlessGPU +class TestDistributed(_AbstractTest): + base_config = dict( + task='integration_tests:overfit', + model='transformer/generator', + optimizer='adam', + validation_metric='ppl', + skip_generation=True, + learningrate=1e-2, + batchsize=BATCHSIZE, + validation_every_n_epochs=5, + num_epochs=150, + n_layers=1, + n_heads=1, + ffn_size=32, + embedding_size=8, + verbose=True, + ) + def test_generator_distributed(self): valid, test = self._distributed_train_model() @@ -159,17 +158,32 @@ def test_chunked_teacher(self): assert valid['exs'].value() == inttests.NUM_TEST assert test['exs'].value() == inttests.NUM_TEST + +# class TestZero2(TestDistributed): +# pass + + +class TestNoModelParallel(_AbstractTest): + base_config = dict( + task='integration_tests:overfit', + optimizer='sgd', + validation_metric='loss', + learningrate=1e-2, + batchsize=BATCHSIZE, + validation_every_n_epochs=1, + num_epochs=1, + n_layers=1, + n_heads=1, + ffn_size=32, + embedding_size=8, + verbose=True, + ) + def test_no_model_parallel(self): """ - Checks that we throw an error when combining mp_train with. - - --model-parallel true. + Checks that we throw an error when combining mp_train with --model-parallel. """ - for m in [ - 'transformer/generator', - 'transformer/ranker', - 'transformer/classifier', - ]: + for m in ['transformer/generator', 'transformer/ranker']: try: _ = self._distributed_train_model(model=m, model_parallel=True) except RuntimeError: From 3153dd8de096be2c469d73206de7d9306f5b174a Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Tue, 22 Jun 2021 15:52:06 -0400 Subject: [PATCH 07/29] Fixup checkpoints. 
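With a sharded backend, assembling a full state dict is a collective operation: every rank has to call state_dict() so its shard gets gathered, even though only the primary rank writes anything to disk. That is why the is_primary_worker() check moves inside save_model and why save_nonprimary exists. A rough sketch of the contract (assuming a generic agent object, not ParlAI's exact call chain, and an initialized process group):

    import torch
    import torch.distributed as dist

    def save_checkpoint(agent, path: str, is_primary: bool) -> None:
        state = agent.state_dict()   # collective under FSDP: all ranks must enter
        if is_primary:
            torch.save(state, path)  # I/O stays on the primary worker
        if dist.is_initialized():
            dist.barrier()           # keep ranks in lockstep before training resumes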
--- parlai/core/torch_agent.py | 25 ++++++++++++++++--- parlai/core/torch_generator_agent.py | 5 ++-- parlai/scripts/train_model.py | 37 +++++++++++++--------------- parlai/utils/fp16.py | 2 -- tests/test_distributed.py | 4 +-- 5 files changed, 43 insertions(+), 30 deletions(-) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index bc8ebba5dc4..de54a40e012 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -1976,12 +1976,21 @@ def state_dict(self): """ states = {} if hasattr(self, 'model'): # save model params - if hasattr(self.model, 'module'): + if hasattr(self.model, 'module') and self.opt['ddp_backend'] not in ( + 'zero2', + 'zero3', + ): # did we wrap in a DistributedDataParallel states['model'] = self.model.module.state_dict() else: + logging.info("About to store state dict") + import traceback + + logging.critical("".join(traceback.format_stack())) states['model'] = self.model.state_dict() + logging.info("Out of here") + if hasattr(self, 'optimizer'): # save optimizer params states['optimizer'] = self.optimizer.state_dict() @@ -1999,6 +2008,17 @@ def state_dict(self): return states + def save_nonprimary(self, path=None): + """ + Save model parameters, when you are working on the non-primary worker. + + For models or optimizers that shard parameters, this ensures we sync. + """ + logging.info("Saving non primary") + if self.opt['ddp_backend'] in ('zero2', 'zero3'): + # make sure we call the state dict + self.state_dict() + def save(self, path=None): """ Save model parameters to path (or default to model_file arg). @@ -2352,9 +2372,6 @@ def update_params(self): self.global_metrics.add('gnorm', GlobalAverageMetric(grad_norm)) if self.fp16: - logging.info( - f"fp16_loss_scale = {self.optimizer.loss_scale} [{self._number_training_updates}]" - ) self.global_metrics.add( 'fp16_loss_scalar', GlobalAverageMetric(self.optimizer.loss_scale) ) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 78160a8f636..69bb4317731 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -531,8 +531,8 @@ def __init__(self, opt: Opt, shared=None): compute_dtype = torch.float16 if self.fp16 else torch.float32 mixed_precision = self.fp16 and opt['fp16_impl'] == 'safe' logging.debug( - f"Wrapping in FSDP (reshard_after_forward = {reshard_after_forward}, " - f"compute_dtype = {compute_dtype} mixed_precision = {mixed_precision}" + f"Wrapping in FSDP(reshard_after_forward = {reshard_after_forward}, " + f"compute_dtype = {compute_dtype} mixed_precision = {mixed_precision})" ) self.model = FSDP( self.model, @@ -540,6 +540,7 @@ def __init__(self, opt: Opt, shared=None): mixed_precision=mixed_precision, compute_dtype=compute_dtype, state_dict_device=torch.device('cpu'), + flatten_parameters=True, ) if shared is not None: diff --git a/parlai/scripts/train_model.py b/parlai/scripts/train_model.py index f6e9c2d7321..64a3d8c92b3 100644 --- a/parlai/scripts/train_model.py +++ b/parlai/scripts/train_model.py @@ -442,10 +442,6 @@ def save_model(self, suffix=None): """ Save the model to disk, possibly with a suffix. 
""" - if not is_primary_worker(): - # never do IO as a non-primary worker - return - if not self.opt.get('model_file'): # nothing to save to, just exit return @@ -453,6 +449,13 @@ def save_model(self, suffix=None): fn = self.opt['model_file'] if suffix: fn += suffix + + if not is_primary_worker(): + # never do IO as a non-primary worker + if hasattr(self.agent, 'save_nonprimary'): + self.agent.save_nonprimary(fn) + return + while True: # don't ever let a ctrl-c interrupt saving try: @@ -543,7 +546,7 @@ def validate(self): ) self.best_valid = new_valid self.impatience = 0 - if opt.get('model_file') and is_primary_worker(): + if opt.get('model_file'): logging.info(f"saving best valid model: {opt['model_file']}") self.save_model() self.saved = True @@ -566,11 +569,7 @@ def validate(self): self.validate_time.reset() # saving - if ( - opt.get('model_file') - and opt.get('save_after_valid') - and is_primary_worker() - ): + if opt.get('model_file') and opt.get('save_after_valid'): logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint") self.save_model('.checkpoint') @@ -720,24 +719,26 @@ def _get_time(self, world: World) -> Tuple[float, float, float]: self._total_epochs = self._preempted_epochs + sum( all_gather_list(world.get_total_epochs()) ) - train_time, log_time, validate_time = sync_object( + train_time, log_time, validate_time, save_time = sync_object( ( self.train_time.time(), self.log_time.time(), self.validate_time.time(), + self.save_time.time(), ) ) else: - train_time, log_time, validate_time = ( + train_time, log_time, validate_time, save_time = ( self.train_time.time(), self.log_time.time(), self.validate_time.time(), + self.save_time.time(), ) self._total_epochs = self._preempted_epochs + ( num_workers() * world.get_total_epochs() ) - return train_time, log_time, validate_time + return train_time, log_time, validate_time, save_time def log(self): """ @@ -810,7 +811,7 @@ def train_steps(self): self._last_log_steps += 1 / self.update_freq # the following additionally updates self._total_epochs - train_time, log_time, validate_time = self._get_time(world) + train_time, log_time, validate_time, save_time = self._get_time(world) # get the total training examples done, compute epochs exs_per_epoch = world.num_examples() self._total_exs = int(np.round(self._total_epochs * exs_per_epoch)) @@ -859,11 +860,7 @@ def train_steps(self): break # make sure metrics are clean before we log world.reset_metrics() - if ( - self.save_time.time() > self.save_every_n_secs - and opt.get('model_file') - and is_primary_worker() - ): + if save_time > self.save_every_n_secs and opt.get('model_file'): logging.info( f"saving model checkpoint: {opt['model_file']}.checkpoint" ) @@ -872,7 +869,7 @@ def train_steps(self): self.save_model('.checkpoint') self.save_time.reset() - if not self.saved and is_primary_worker(): + if not sync_object(self.saved): # save agent self.save_model() diff --git a/parlai/utils/fp16.py b/parlai/utils/fp16.py index f596e529522..a7f216f70e7 100644 --- a/parlai/utils/fp16.py +++ b/parlai/utils/fp16.py @@ -104,11 +104,9 @@ def __init__(self, optimizer, sync_overflows=False): self.scaler = DynamicLossScaler(2.0 ** 15) self.min_loss_scale = 2 ** -5 self._sync_overflows = sync_overflows - logging.debug(f"Sync overflows = {sync_overflows}") def _maybe_sync(self, value: bool) -> bool: if self._sync_overflows: - logging.debug(f"Syncing value {value}") import torch.distributed as dist value_tensor = torch.BoolTensor([value]).cuda() diff --git a/tests/test_distributed.py 
b/tests/test_distributed.py index 4bec0c693be..99c291c1426 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -159,8 +159,8 @@ def test_chunked_teacher(self): assert test['exs'].value() == inttests.NUM_TEST -# class TestZero2(TestDistributed): -# pass +class TestZero2(TestDistributed): + base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero2'} class TestNoModelParallel(_AbstractTest): From 44fcdfc9eabe5369a0d04e5bf38540dc0926f6cb Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Tue, 22 Jun 2021 16:33:42 -0400 Subject: [PATCH 08/29] Get tests working. --- parlai/core/torch_agent.py | 6 ------ parlai/core/torch_generator_agent.py | 3 ++- parlai/scripts/multiprocessing_train.py | 6 ++++-- parlai/utils/distributed.py | 6 ++++++ tests/test_distributed.py | 2 +- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index de54a40e012..30878eadc4b 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -1983,14 +1983,8 @@ def state_dict(self): # did we wrap in a DistributedDataParallel states['model'] = self.model.module.state_dict() else: - logging.info("About to store state dict") - import traceback - - logging.critical("".join(traceback.format_stack())) states['model'] = self.model.state_dict() - logging.info("Out of here") - if hasattr(self, 'optimizer'): # save optimizer params states['optimizer'] = self.optimizer.state_dict() diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 69bb4317731..9c0df036cbe 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -28,7 +28,7 @@ import torch.nn.functional as F from parlai.core.opt import Opt -from parlai.utils.distributed import is_distributed, sync_parameters +from parlai.utils.distributed import is_distributed, sync_parameters, get_dist_group from parlai.core.torch_agent import TorchAgent, Batch, Output, DictionaryAgent from parlai.utils.misc import warn_once from parlai.utils.io import PathManager @@ -541,6 +541,7 @@ def __init__(self, opt: Opt, shared=None): compute_dtype=compute_dtype, state_dict_device=torch.device('cpu'), flatten_parameters=True, + process_group=get_dist_group(), ) if shared is not None: diff --git a/parlai/scripts/multiprocessing_train.py b/parlai/scripts/multiprocessing_train.py index 11422ea9b6c..394b6dae159 100644 --- a/parlai/scripts/multiprocessing_train.py +++ b/parlai/scripts/multiprocessing_train.py @@ -55,10 +55,12 @@ def multiprocess_train( raise -def launch_and_train(opt, port): +def launch_and_train(opt, port=None): """ Perform a fork() to many processes. 
""" + if port is None: + port = distributed_utils.find_free_port() # Launch multiple subprocesses spawncontext = torch.multiprocessing.start_processes( multiprocess_train, @@ -99,7 +101,7 @@ def setup_args(cls): def run(self): if self.opt['port'] is None: - port = distributed_utils.find_free_port() + port = None else: port = self.opt['port'] return launch_and_train(self.opt, port) diff --git a/parlai/utils/distributed.py b/parlai/utils/distributed.py index 91be35bfb89..07553c5bdb5 100644 --- a/parlai/utils/distributed.py +++ b/parlai/utils/distributed.py @@ -296,6 +296,12 @@ def distributed_context( dist.destroy_process_group() +def get_dist_group(): + from torch.distributed.distributed_c10d import _get_default_group + + return _get_default_group() + + @contextlib.contextmanager def slurm_distributed_context(opt): """ diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 99c291c1426..c86b772167f 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -30,7 +30,7 @@ def _distributed_train_model(self, **overrides): parser = build_dict.setup_args() build_dict.build_dict(popt) - valid, test = mp_train.launch_and_train(popt, 31338) + valid, test = mp_train.launch_and_train(popt) return (valid, test) From 281efd16f90ab33a85d72eb90311477c7feac642 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Tue, 22 Jun 2021 16:56:42 -0400 Subject: [PATCH 09/29] GPU only --- tests/test_distributed.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index c86b772167f..e06ed022c38 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -159,10 +159,17 @@ def test_chunked_teacher(self): assert test['exs'].value() == inttests.NUM_TEST +@testing_utils.skipUnlessGPU class TestZero2(TestDistributed): base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero2'} +@testing_utils.skipUnlessGPU +class TestZero3(TestDistributed): + base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero3'} + + +@testing_utils.skipUnlessGPU class TestNoModelParallel(_AbstractTest): base_config = dict( task='integration_tests:overfit', From 4146d868c0200734827f9983bc95b1d33b352530 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Tue, 22 Jun 2021 18:30:46 -0400 Subject: [PATCH 10/29] Sigh --- parlai/core/torch_agent.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index 30878eadc4b..b803a0bc109 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -1118,7 +1118,7 @@ def init_optim( return True def _should_sync_overflows(self): - return self.fp16 and self.opt['ddp_backend'] in ('zero2', 'zero3') + return self.fp16 and self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3') def build_lr_scheduler(self, states=None, hard_reset=False): """ @@ -1976,10 +1976,9 @@ def state_dict(self): """ states = {} if hasattr(self, 'model'): # save model params - if hasattr(self.model, 'module') and self.opt['ddp_backend'] not in ( - 'zero2', - 'zero3', - ): + if hasattr(self.model, 'module') and self.opt.get( + 'ddp_backend', 'ddp' + ) not in ('zero2', 'zero3'): # did we wrap in a DistributedDataParallel states['model'] = self.model.module.state_dict() else: @@ -2009,7 +2008,7 @@ def save_nonprimary(self, path=None): For models or optimizers that shard parameters, this ensures we sync. 
""" logging.info("Saving non primary") - if self.opt['ddp_backend'] in ('zero2', 'zero3'): + if self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3'): # make sure we call the state dict self.state_dict() From 5c6755a09215c2a7c5c4529ce33d26d378f3e7fa Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Tue, 22 Jun 2021 19:19:12 -0400 Subject: [PATCH 11/29] Moar. --- parlai/core/torch_generator_agent.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 9c0df036cbe..65c205d62fa 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -519,7 +519,7 @@ def __init__(self, opt: Opt, shared=None): if ( shared is None and is_distributed() - and opt['ddp_backend'] in ('zero2', 'zero3') + and opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3') ): from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP @@ -559,7 +559,11 @@ def __init__(self, opt: Opt, shared=None): logging.warning("Optimizer was reset. Also resetting LR scheduler.") self.build_lr_scheduler(states, hard_reset=is_finetune or was_reset) - if shared is None and is_distributed() and opt['ddp_backend'] == 'ddp': + if ( + shared is None + and is_distributed() + and opt.get('ddp_backend', 'ddp') == 'ddp' + ): device_ids = None if self.model_parallel else [self.opt['gpu']] logging.debug("Wrapping in simple DDP") self.model = torch.nn.parallel.DistributedDataParallel( @@ -579,7 +583,7 @@ def _delay_halving(self): return ( self.fp16 - and self.opt['ddp_backend'] in ('zero2', 'zero3') + and self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3') and self.opt['fp16_impl'] == 'safe' ) From dc5edc3166f7c65f3d93ac9e4d7b5ed44db7f824 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Wed, 23 Jun 2021 14:31:29 -0400 Subject: [PATCH 12/29] Trying to sync grad norms --- parlai/utils/fp16.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/parlai/utils/fp16.py b/parlai/utils/fp16.py index a7f216f70e7..2182374839d 100644 --- a/parlai/utils/fp16.py +++ b/parlai/utils/fp16.py @@ -55,7 +55,7 @@ def forward(self, scores, targets): ) -def clip_grad_norm(params, max_norm): +def clip_grad_norm(params, max_norm, sync: bool = False): """ Clips grad norm. """ @@ -220,11 +220,13 @@ def clip_master_grads(self, max_norm): Clips gradient norm and updates dynamic loss scaler. 
""" self._sync_fp16_grads_to_fp32() - grad_norm = clip_grad_norm(self.fp32_params, max_norm) + grad_norm = clip_grad_norm( + self.fp32_params, max_norm, sync=self._sync_overflows + ) # detect overflow and adjust loss scale if self.scaler is not None: - overflow = self._maybe_sync(has_overflow(grad_norm)) + overflow = has_overflow(grad_norm) prev_scale = self.scaler.loss_scale self.scaler.update_scale(overflow) if overflow: @@ -400,6 +402,7 @@ class MemoryEfficientFP16Optimizer(torch.optim.Optimizer): def __init__( self, init_optimizer: torch.optim.Optimizer, # type: ignore + sync_overflows: bool = False, loss_initial_scale: float = 2.0 ** 17, min_loss_scale: float = 1e-4, ): @@ -408,6 +411,17 @@ def __init__( self.min_loss_scale = min_loss_scale self.scaler = DynamicLossScaler(init_scale=loss_initial_scale) + self._sync_overflows = sync_overflows + + def _maybe_sync(self, value: bool) -> bool: + if self._sync_overflows: + import torch.distributed as dist + + value_tensor = torch.BoolTensor([value]).cuda() + dist.all_reduce(value_tensor) + value = value_tensor.item() + return value + @staticmethod def compatible_optimizers(): """ @@ -456,9 +470,11 @@ def clip_master_grads(self, gradient_clip): Returns -1 if the most recently computed gradients overflowed. """ self._unscale_grads() - grad_norm = clip_grad_norm(self.params, gradient_clip) + grad_norm = clip_grad_norm( + self.params, gradient_clip, sync=self._sync_overflows + ) # detect overflow and adjust loss scale - overflow = self._maybe_sync(has_overflow(grad_norm)) + overflow = has_overflow(grad_norm) self.scaler.update_scale(overflow) if overflow: if self.scaler.loss_scale <= self.min_loss_scale: From 7e12292a643d3f79c4324f62b81f79fcfb85b970 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Wed, 23 Jun 2021 18:53:08 -0400 Subject: [PATCH 13/29] Correctly implement gnorm syncing. 
--- parlai/utils/fp16.py | 45 ++++++++++++++++---------------------------- 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/parlai/utils/fp16.py b/parlai/utils/fp16.py index 2182374839d..808359569c3 100644 --- a/parlai/utils/fp16.py +++ b/parlai/utils/fp16.py @@ -63,19 +63,24 @@ def clip_grad_norm(params, max_norm, sync: bool = False): params = [params] # make sure any generators are expanded params = list(params) - if len(params) == 1: - p = params[0].grad - grad_norm = torch.norm(p) - if grad_norm > max_norm > 0: - clip_coef = max_norm / (grad_norm + 1e-6) - p.mul_(clip_coef) - return grad_norm - elif max_norm > 0: + # if syncing we need to manually perform the clipping so that we aggregrate + # properly + if max_norm > 0 and not sync: return torch.nn.utils.clip_grad_norm_(params, max_norm) else: - return torch.sqrt( - sum(p.grad.data.norm() ** 2 for p in params if p.grad is not None) - ) + normsq = sum(p.grad.data.norm() ** 2 for p in params if p.grad is not None) + if sync: + # also need to get the norms from all the other sharded works in FSDP + import torch.distributed as dist + + dist.all_reduce(normsq) + grad_norm = normsq.sqrt() + if max_norm > 0: + clip_coef = max_norm / (grad_norm + 1e-6) + for p in params: + p.grad.detach().mul_(clip_coef) + + return grad_norm def has_overflow(grad_norm): @@ -105,15 +110,6 @@ def __init__(self, optimizer, sync_overflows=False): self.min_loss_scale = 2 ** -5 self._sync_overflows = sync_overflows - def _maybe_sync(self, value: bool) -> bool: - if self._sync_overflows: - import torch.distributed as dist - - value_tensor = torch.BoolTensor([value]).cuda() - dist.all_reduce(value_tensor) - value = value_tensor.item() - return value - @classmethod def _get_parameters(cls, optimizer): params = [] @@ -413,15 +409,6 @@ def __init__( self._sync_overflows = sync_overflows - def _maybe_sync(self, value: bool) -> bool: - if self._sync_overflows: - import torch.distributed as dist - - value_tensor = torch.BoolTensor([value]).cuda() - dist.all_reduce(value_tensor) - value = value_tensor.item() - return value - @staticmethod def compatible_optimizers(): """ From 66d53d300370ff4c32fc62a3e151368df705fdf5 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Wed, 23 Jun 2021 19:25:01 -0400 Subject: [PATCH 14/29] Update comment. --- parlai/core/torch_generator_agent.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 65c205d62fa..dfa912a6b5e 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -526,8 +526,6 @@ def __init__(self, opt: Opt, shared=None): device_ids = None if self.model_parallel else [self.opt['gpu']] mixed_precision = opt['fp16'] reshard_after_forward = opt['ddp_backend'] == 'zero3' - # hack: fsdp expects things in fp32 if we're using mixed precision. - # lol! convert it back! compute_dtype = torch.float16 if self.fp16 else torch.float32 mixed_precision = self.fp16 and opt['fp16_impl'] == 'safe' logging.debug( From 1cb30d119d11f1c13ead3a9b541cd059b439fda3 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Thu, 24 Jun 2021 10:10:12 -0400 Subject: [PATCH 15/29] Try zero3. 
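zero3 only pays off if the transformer layers are themselves FSDP modules: with reshard_after_forward=True each wrapped layer gathers its full parameters just before it runs and frees them right after, so only the layer currently executing needs to be fully materialized. fairscale's enable_wrap/wrap helpers let nested wrappers inherit the outer FSDP arguments. A toy sketch of that nesting (an arbitrary stack of linear layers, not ParlAI's encoder/decoder; assumes an initialized process group):

    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from fairscale.nn.wrap.auto_wrap import enable_wrap, wrap

    def build_sharded_stack(n_layers: int, dim: int) -> nn.Module:
        fsdp_args = dict(reshard_after_forward=True, flatten_parameters=True)
        with enable_wrap(wrapper_cls=FSDP, **fsdp_args):
            # each layer becomes its own FSDP unit; the outer wrap shards the rest
            layers = [wrap(nn.Linear(dim, dim)) for _ in range(n_layers)]
            return wrap(nn.Sequential(*layers))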
--- parlai/core/torch_generator_agent.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index d9c383716e8..cb1963e3818 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -21,6 +21,7 @@ from abc import ABC, abstractmethod from typing import TypeVar, List, Dict, Optional, Tuple, Set, Iterable import math +import functools from operator import attrgetter import torch @@ -522,6 +523,11 @@ def __init__(self, opt: Opt, shared=None): and opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3') ): from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + from fairscale.nn.wrap.auto_wrap import ( + enable_wrap, + auto_wrap, + default_auto_wrap_policy, + ) device_ids = None if self.model_parallel else [self.opt['gpu']] mixed_precision = opt['fp16'] @@ -532,8 +538,7 @@ def __init__(self, opt: Opt, shared=None): f"Wrapping in FSDP(reshard_after_forward = {reshard_after_forward}, " f"compute_dtype = {compute_dtype} mixed_precision = {mixed_precision})" ) - self.model = FSDP( - self.model, + fsdp_args = dict( reshard_after_forward=reshard_after_forward, mixed_precision=mixed_precision, compute_dtype=compute_dtype, @@ -542,6 +547,18 @@ def __init__(self, opt: Opt, shared=None): process_group=get_dist_group(), ) + with enable_wrap(wrapper_cls=FSDP, **fsdp_args): + # TODO: we can save a bit more memory if we ever manually + # wrap things. + policy = functools.partial( + default_auto_wrap_policy, + min_num_params=1e7, + exclude_wrap_modules={nn.Embedding, nn.ModuleList}, + ) + self.model.encoder = auto_wrap(self.model.encoder, policy) + self.model.decoder = auto_wrap(self.model.decoder, policy) + self.model = FSDP(self.model, **fsdp_args) + if shared is not None: if 'optimizer' in shared: self.optimizer = shared['optimizer'] From 5cea3b2624a94ee0d81cbdbafbcf5dfa43c1a5ed Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Thu, 24 Jun 2021 16:24:23 -0400 Subject: [PATCH 16/29] Okay got zero3 working. --- parlai/core/torch_generator_agent.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index cb1963e3818..7809ef88b89 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -534,10 +534,6 @@ def __init__(self, opt: Opt, shared=None): reshard_after_forward = opt['ddp_backend'] == 'zero3' compute_dtype = torch.float16 if self.fp16 else torch.float32 mixed_precision = self.fp16 and opt['fp16_impl'] == 'safe' - logging.debug( - f"Wrapping in FSDP(reshard_after_forward = {reshard_after_forward}, " - f"compute_dtype = {compute_dtype} mixed_precision = {mixed_precision})" - ) fsdp_args = dict( reshard_after_forward=reshard_after_forward, mixed_precision=mixed_precision, @@ -546,17 +542,16 @@ def __init__(self, opt: Opt, shared=None): flatten_parameters=True, process_group=get_dist_group(), ) + logging.debug(f"Wrapping in FSDP: {fsdp_args}") with enable_wrap(wrapper_cls=FSDP, **fsdp_args): # TODO: we can save a bit more memory if we ever manually # wrap things. 
- policy = functools.partial( - default_auto_wrap_policy, - min_num_params=1e7, - exclude_wrap_modules={nn.Embedding, nn.ModuleList}, - ) - self.model.encoder = auto_wrap(self.model.encoder, policy) - self.model.decoder = auto_wrap(self.model.decoder, policy) + for i, layer in enumerate(self.model.encoder.layers): + self.model.encoder.layers[i] = FSDP(layer, **fsdp_args) + for i, layer in enumerate(self.model.decoder.layers): + self.model.decoder.layers[i] = FSDP(layer, **fsdp_args) + self.model = FSDP(self.model, **fsdp_args) if shared is not None: From 490f5d8c7522f23d2054af8080dead0319ffcf26 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Thu, 24 Jun 2021 20:49:36 -0400 Subject: [PATCH 17/29] Refactor. --- parlai/core/torch_agent.py | 3 -- parlai/core/torch_generator_agent.py | 55 ++-------------------------- 2 files changed, 3 insertions(+), 55 deletions(-) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index b803a0bc109..dec1b7f8fbb 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -1117,9 +1117,6 @@ def init_optim( ) return True - def _should_sync_overflows(self): - return self.fp16 and self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3') - def build_lr_scheduler(self, states=None, hard_reset=False): """ Create the learning rate scheduler, and assign it to self.scheduler. This diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 7809ef88b89..0904a9597f9 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -517,42 +517,9 @@ def __init__(self, opt: Opt, shared=None): else: states = {} - if ( - shared is None - and is_distributed() - and opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3') - ): - from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP - from fairscale.nn.wrap.auto_wrap import ( - enable_wrap, - auto_wrap, - default_auto_wrap_policy, - ) - - device_ids = None if self.model_parallel else [self.opt['gpu']] - mixed_precision = opt['fp16'] - reshard_after_forward = opt['ddp_backend'] == 'zero3' - compute_dtype = torch.float16 if self.fp16 else torch.float32 - mixed_precision = self.fp16 and opt['fp16_impl'] == 'safe' - fsdp_args = dict( - reshard_after_forward=reshard_after_forward, - mixed_precision=mixed_precision, - compute_dtype=compute_dtype, - state_dict_device=torch.device('cpu'), - flatten_parameters=True, - process_group=get_dist_group(), - ) - logging.debug(f"Wrapping in FSDP: {fsdp_args}") - - with enable_wrap(wrapper_cls=FSDP, **fsdp_args): - # TODO: we can save a bit more memory if we ever manually - # wrap things. - for i, layer in enumerate(self.model.encoder.layers): - self.model.encoder.layers[i] = FSDP(layer, **fsdp_args) - for i, layer in enumerate(self.model.decoder.layers): - self.model.decoder.layers[i] = FSDP(layer, **fsdp_args) - - self.model = FSDP(self.model, **fsdp_args) + if shared is None and fsdp_utils.should_use_fsdp(opt): + with fsdp_utils.enable_fsdp_wrap(opt): + pass if shared is not None: if 'optimizer' in shared: @@ -575,28 +542,12 @@ def __init__(self, opt: Opt, shared=None): and opt.get('ddp_backend', 'ddp') == 'ddp' ): device_ids = None if self.model_parallel else [self.opt['gpu']] - logging.debug("Wrapping in simple DDP") self.model = torch.nn.parallel.DistributedDataParallel( self.model, device_ids=device_ids, broadcast_buffers=False ) self.reset() - def _delay_halving(self): - """ - Check whether we should keep the model in fp32 before other setup. 
- - When using Zero2 or Zero3 backends with mixed precision, we need to - avoid converting the model to fp16, as the FSDP module does this for - us. - """ - - return ( - self.fp16 - and self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3') - and self.opt['fp16_impl'] == 'safe' - ) - def build_criterion(self): """ Construct and return the loss function. From 31dfeb551534df3fc28bf8d34fd9061b5062d5d2 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Fri, 25 Jun 2021 10:49:43 -0400 Subject: [PATCH 18/29] Get FSDP Zero3 working, except during validation. --- parlai/agents/transformer/modules/decoder.py | 18 +++++++-------- parlai/agents/transformer/modules/encoder.py | 18 +++++++-------- parlai/core/torch_agent.py | 12 +++++----- parlai/core/torch_generator_agent.py | 23 +++++++++----------- parlai/scripts/train_model.py | 6 +++++ 5 files changed, 40 insertions(+), 37 deletions(-) diff --git a/parlai/agents/transformer/modules/decoder.py b/parlai/agents/transformer/modules/decoder.py index 7ba4968195a..52ee3a80cc1 100644 --- a/parlai/agents/transformer/modules/decoder.py +++ b/parlai/agents/transformer/modules/decoder.py @@ -25,6 +25,7 @@ from parlai.core.opt import Opt from parlai.utils.misc import warn_once from parlai.utils.torch import PipelineHelper +from parlai.utils.fsdp import fsdp_wrap @swappable( @@ -277,16 +278,15 @@ def _default(val, default): def build_layers(self) -> nn.ModuleList: layers = nn.ModuleList() for _ in range(self.n_layers): - layers.append( - self.swappables.layer( - self.opt, - attention_dropout=self.opt.get('attention_dropout', 0.0), - relu_dropout=self.opt.get('relu_dropout', 0.0), - dropout=self.opt.get('dropout', 0.0), - activation=self.activation, - variant=self.variant, - ) # type: ignore + layer = self.swappables.layer( + self.opt, + attention_dropout=self.opt.get('attention_dropout', 0.0), + relu_dropout=self.opt.get('relu_dropout', 0.0), + dropout=self.opt.get('dropout', 0.0), + activation=self.activation, + variant=self.variant, ) + layers.append(fsdp_wrap(layer)) # type: ignore return layers def forward_embedding( diff --git a/parlai/agents/transformer/modules/encoder.py b/parlai/agents/transformer/modules/encoder.py index b79981fd9eb..441d13112f9 100644 --- a/parlai/agents/transformer/modules/encoder.py +++ b/parlai/agents/transformer/modules/encoder.py @@ -25,6 +25,7 @@ from parlai.core.opt import Opt from parlai.utils.misc import warn_once from parlai.utils.torch import PipelineHelper +from parlai.utils.fsdp import fsdp_wrap @swappable(self_attention=MultiHeadAttention, feedforward=TransformerFFN) @@ -227,16 +228,15 @@ def _default(val, default): def build_layers(self) -> nn.ModuleList: layers = nn.ModuleList() for _ in range(self.n_layers): - layers.append( - self.swappables.layer( # type: ignore - self.opt, - attention_dropout=self.opt.get('attention_dropout', 0.0), - relu_dropout=self.opt.get('relu_dropout', 0.0), - dropout=self.dropout_frac, - variant=self.variant, - activation=self.activation, - ) + layer = self.swappables.layer( # type: ignore + self.opt, + attention_dropout=self.opt.get('attention_dropout', 0.0), + relu_dropout=self.opt.get('relu_dropout', 0.0), + dropout=self.dropout_frac, + variant=self.variant, + activation=self.activation, ) + layers.append(fsdp_wrap(layer)) return layers def forward_embedding( diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index dec1b7f8fbb..e903c92b166 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -36,6 +36,7 @@ from parlai.utils.distributed import 
is_distributed from parlai.utils.misc import AttrDict, warn_once from parlai.utils.io import PathManager +from parlai.utils.fsdp import should_sync_gradnorm, is_fsdp from parlai.utils.fp16 import ( SafeFP16Optimizer, MemoryEfficientFP16Optimizer, @@ -1053,7 +1054,7 @@ def init_optim( if self.fp16: if self.fp16_impl == 'safe': self.optimizer = SafeFP16Optimizer( - self.optimizer, self._should_sync_overflows() + self.optimizer, should_sync_gradnorm(opt) ) else: # Using memory efficient optimizer @@ -1067,7 +1068,7 @@ def init_optim( f'list:\n{compatible_list}' ) self.optimizer = MemoryEfficientFP16Optimizer( - self.optimizer, self._should_sync_overflows() + self.optimizer, should_sync_gradnorm(opt) ) if is_finetune: @@ -1973,12 +1974,11 @@ def state_dict(self): """ states = {} if hasattr(self, 'model'): # save model params - if hasattr(self.model, 'module') and self.opt.get( - 'ddp_backend', 'ddp' - ) not in ('zero2', 'zero3'): - # did we wrap in a DistributedDataParallel + if hasattr(self.model, 'module') and not is_fsdp(self.model): + # did we wrap in a DistributedDataParallel or DataParallel states['model'] = self.model.module.state_dict() else: + # regular model or FSDP states['model'] = self.model.state_dict() if hasattr(self, 'optimizer'): diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 0904a9597f9..b70712cf265 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -21,7 +21,6 @@ from abc import ABC, abstractmethod from typing import TypeVar, List, Dict, Optional, Tuple, Set, Iterable import math -import functools from operator import attrgetter import torch @@ -29,13 +28,14 @@ import torch.nn.functional as F from parlai.core.opt import Opt -from parlai.utils.distributed import is_distributed, sync_parameters, get_dist_group +from parlai.utils.distributed import is_distributed, sync_parameters from parlai.core.torch_agent import TorchAgent, Batch, Output, DictionaryAgent from parlai.utils.misc import warn_once from parlai.utils.io import PathManager import parlai.utils.logging as logging from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric from parlai.utils.fp16 import FP16SafeCrossEntropy +import parlai.utils.fsdp as fsdp_utils from parlai.utils.torch import ( neginf, total_parameters, @@ -480,8 +480,11 @@ def __init__(self, opt: Opt, shared=None): else: # this is not a shared instance of this class, so do full init self.criterion = self.build_criterion() - # ensure all distributed copies will always be in sync - self.model = self.build_model() + with fsdp_utils.maybe_fsdp_wrap(opt): + self.model = fsdp_utils.fsdp_wrap(self.build_model()) + logging.debug(f"Model arch:\n{self.model}") + if self.fp16 and not fsdp_utils.should_use_fsdp(opt): + self.model = self.model.half() # load the block_list for beam search self.beam_block_list = self._load_beam_block_list() @@ -499,17 +502,15 @@ def __init__(self, opt: Opt, shared=None): self.model.cuda() self.criterion.cuda() - sync_parameters(self.model) + if not fsdp_utils.is_fsdp(self.model): + sync_parameters(self.model) + train_params = trainable_parameters(self.model) total_params = total_parameters(self.model) logging.info( f"Total parameters: {total_params:,d} ({train_params:,d} trainable)" ) - if self.fp16: - if not self._delay_halving(): - self.model = self.model.half() - if init_model is not None: # load model parameters if available logging.info(f'Loading existing model params from {init_model}') @@ -517,10 +518,6 @@ def 
__init__(self, opt: Opt, shared=None): else: states = {} - if shared is None and fsdp_utils.should_use_fsdp(opt): - with fsdp_utils.enable_fsdp_wrap(opt): - pass - if shared is not None: if 'optimizer' in shared: self.optimizer = shared['optimizer'] diff --git a/parlai/scripts/train_model.py b/parlai/scripts/train_model.py index 64a3d8c92b3..047705cbfa8 100644 --- a/parlai/scripts/train_model.py +++ b/parlai/scripts/train_model.py @@ -453,12 +453,14 @@ def save_model(self, suffix=None): if not is_primary_worker(): # never do IO as a non-primary worker if hasattr(self.agent, 'save_nonprimary'): + logging.debug("Saving on non-primary") self.agent.save_nonprimary(fn) return while True: # don't ever let a ctrl-c interrupt saving try: + logging.debug("Saving on primary") self.agent.save(fn) self._save_train_stats(suffix) break @@ -591,10 +593,12 @@ def _run_single_eval(self, opt, valid_world, max_exs): max_cnt = max_exs if max_exs > 0 else float('inf') while not valid_world.epoch_done() and cnt < max_cnt: valid_world.parley() + logging.info(f"Ran cnt {cnt}") if cnt == 0 and opt['display_examples']: print(valid_world.display() + '\n~~') print(valid_world.report()) cnt = valid_world.report().get('exs') or 0 + logging.info(f"rank {is_primary_worker()} epoch_done") valid_report = valid_world.report() if opt.get('validation_share_agent', False): @@ -633,7 +637,9 @@ def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False): named_reports, micro_average=self.opt.get('aggregate_micro', False) ) # get the results from all workers + logging.debug("Syncing metrics") report = self._sync_metrics(report) + logging.debug("Done syncing metrics") metrics = f'{datatype}:\n{nice_report(report)}\n' logging.info(f'eval completed in {timer.time():.2f}s') From d095f51ca7df263457f64992da666a53e0420a6d Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Mon, 28 Jun 2021 10:09:34 -0400 Subject: [PATCH 19/29] Check in missing code. Carve out notimplemented. --- parlai/utils/fsdp.py | 113 ++++++++++++++++++++++++++++++++++++++ tests/test_distributed.py | 7 +++ 2 files changed, 120 insertions(+) create mode 100644 parlai/utils/fsdp.py diff --git a/parlai/utils/fsdp.py b/parlai/utils/fsdp.py new file mode 100644 index 00000000000..63ed3f9aab4 --- /dev/null +++ b/parlai/utils/fsdp.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utility functions for FullyShardedDataParallel. +""" + +import contextlib +import torch.nn +from parlai.utils.distributed import is_distributed, get_dist_group + +try: + from fairscale.nn.wrap.auto_wrap import wrap + from fairscale.nn.wrap.auto_wrap import enable_wrap as fairscale_enable_wrap + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + + FSDP_AVAILABLE = True +except ImportError: + FSDP_AVAILABLE = False + + def wrap(module, **kwargs): + return module + + +DEFAULT_DDP_BACKEND = "ddp" + + +def is_fsdp(module: torch.nn.Module): + """ + Checks whether a module is fully sharded. + """ + return FSDP_AVAILABLE and isinstance(module, FSDP) + + +def should_use_fsdp(opt): + return ( + FSDP_AVAILABLE + and is_distributed() + and opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3') + ) + + +@contextlib.contextmanager +def maybe_fsdp_wrap(opt): + """ + Context manager for enabling wrapping in FullyShardedDataParallel. 
+ """ + if not should_use_fsdp(opt): + # make a no-op + yield + return + + # zero3 not supported at this time. Throw an exception + if opt['ddp_backend'] == 'zero3': + raise NotImplementedError( + '--ddp-backend zero3 is not supported at this time. For details, see ' + 'https://github.com/facebookresearch/ParlAI/issues/3753.' + ) + + reshard_after_forward = opt['ddp_backend'] == 'zero3' + compute_dtype = torch.float16 if opt['fp16'] else torch.float32 + mixed_precision = opt['fp16'] and opt['fp16_impl'] == 'safe' + fsdp_args = dict( + reshard_after_forward=reshard_after_forward, + mixed_precision=mixed_precision, + compute_dtype=compute_dtype, + state_dict_device=torch.device('cpu'), + flatten_parameters=True, + process_group=get_dist_group(), + ) + with fairscale_enable_wrap(wrapper_cls=FSDP, **fsdp_args): + yield + + +def delay_halving(self): + """ + Check whether we should keep the model in fp32 before other setup. + + When using Zero2 or Zero3 backends with mixed precision, we need to + avoid converting the model to fp16, as the FSDP module does this for + us. + """ + + return ( + self.fp16 + and self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3') + and self.opt['fp16_impl'] == 'safe' + ) + + +def should_sync_gradnorm(opt): + """ + Indicates whether fp16 optimizer wrappers should cumulate over workers. + + FP16 overflow detection and gradient clipping both require accumulating + gradients across all workers when using FSDP, as workers only store a + fraction of the gradients. + """ + return ( + FSDP_AVAILABLE + and opt['fp16'] + and opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3') + ) + + +def fsdp_wrap(module): + """ + Helper function for wrapping the outermost root module. + """ + return wrap(module) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index e06ed022c38..70cf44e58b4 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -161,11 +161,18 @@ def test_chunked_teacher(self): @testing_utils.skipUnlessGPU class TestZero2(TestDistributed): + """ + Integration tests for zero2 FSDP. + """ + base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero2'} +@testing_utils.skip @testing_utils.skipUnlessGPU class TestZero3(TestDistributed): + # Not supported at this time. See: + # https://github.com/facebookresearch/ParlAI/pull/3740 base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero3'} From f17abb2ba13942f9d17df1cad90001e0273ded55 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Mon, 28 Jun 2021 10:13:42 -0400 Subject: [PATCH 20/29] Lint. 
--- parlai/scripts/multiprocessing_eval.py | 1 - parlai/scripts/multiprocessing_train.py | 1 - parlai/utils/fsdp.py | 11 +++++------ 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/parlai/scripts/multiprocessing_eval.py b/parlai/scripts/multiprocessing_eval.py index 4bd94693fbd..bfc7bdc34b7 100644 --- a/parlai/scripts/multiprocessing_eval.py +++ b/parlai/scripts/multiprocessing_eval.py @@ -23,7 +23,6 @@ """ import torch -import random import os import signal import parlai.utils.distributed as distributed_utils diff --git a/parlai/scripts/multiprocessing_train.py b/parlai/scripts/multiprocessing_train.py index 394b6dae159..543d316b01a 100644 --- a/parlai/scripts/multiprocessing_train.py +++ b/parlai/scripts/multiprocessing_train.py @@ -24,7 +24,6 @@ """ import torch -import random import os import signal import traceback diff --git a/parlai/utils/fsdp.py b/parlai/utils/fsdp.py index 63ed3f9aab4..da356397150 100644 --- a/parlai/utils/fsdp.py +++ b/parlai/utils/fsdp.py @@ -79,9 +79,8 @@ def delay_halving(self): """ Check whether we should keep the model in fp32 before other setup. - When using Zero2 or Zero3 backends with mixed precision, we need to - avoid converting the model to fp16, as the FSDP module does this for - us. + When using Zero2 or Zero3 backends with mixed precision, we need to avoid converting + the model to fp16, as the FSDP module does this for us. """ return ( @@ -95,9 +94,9 @@ def should_sync_gradnorm(opt): """ Indicates whether fp16 optimizer wrappers should cumulate over workers. - FP16 overflow detection and gradient clipping both require accumulating - gradients across all workers when using FSDP, as workers only store a - fraction of the gradients. + FP16 overflow detection and gradient clipping both require accumulating gradients + across all workers when using FSDP, as workers only store a fraction of the + gradients. """ return ( FSDP_AVAILABLE From 231e88d629cbbec5b21c17f599db08e08b3643da Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Mon, 28 Jun 2021 10:17:18 -0400 Subject: [PATCH 21/29] Er. --- tests/test_distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 70cf44e58b4..d1fda5d36d7 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -168,7 +168,7 @@ class TestZero2(TestDistributed): base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero2'} -@testing_utils.skip +@unittest.skip @testing_utils.skipUnlessGPU class TestZero3(TestDistributed): # Not supported at this time. See: From 4a3ce8696efe836b9999210366abea5e967dcab2 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Mon, 28 Jun 2021 10:45:54 -0400 Subject: [PATCH 22/29] Add a test to ensure we keep track of zero3 not working. 
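
zero3 is carved out in parlai/utils/fsdp.py: when it is requested,
maybe_fsdp_wrap raises before doing any wrapping, roughly

    if opt['ddp_backend'] == 'zero3':
        raise NotImplementedError(
            '--ddp-backend zero3 is not supported at this time. ...'
        )

The test added below spins up a tiny distributed transformer config with
ddp_backend='zero3' and asserts that this error actually propagates out of
_distributed_train_model().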
--- tests/test_distributed.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index d1fda5d36d7..688c8822c7d 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -176,6 +176,31 @@ class TestZero3(TestDistributed): base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero3'} +class TestZero3NotImplemented(_AbstractTest): + base_config = dict( + task='integration_tests:overfit', + optimizer='sgd', + validation_metric='loss', + ddp_backend='zero3', + batchsize=BATCHSIZE, + model='transformer/generator', + validation_every_n_epochs=1, + num_epochs=1, + n_layers=1, + n_heads=1, + ffn_size=32, + embedding_size=8, + verbose=True, + ) + + def test_not_implemented(self): + """ + Checks that using --ddp-backend zero3 throws an error + """ + with self.assertRaises(NotImplementedError): + self._distributed_train_model() + + @testing_utils.skipUnlessGPU class TestNoModelParallel(_AbstractTest): base_config = dict( From 98a90b76186faba4415dca1a74e2501aa33c951a Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Mon, 28 Jun 2021 10:58:07 -0400 Subject: [PATCH 23/29] Remove debugs, add docstrings, rename variable. --- parlai/core/torch_agent.py | 1 - parlai/core/torch_generator_agent.py | 1 - parlai/scripts/train_model.py | 6 ------ parlai/utils/distributed.py | 7 ++++++ parlai/utils/fp16.py | 32 +++++++++++++++++++++------- 5 files changed, 31 insertions(+), 16 deletions(-) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index e903c92b166..412d24d90ed 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -2004,7 +2004,6 @@ def save_nonprimary(self, path=None): For models or optimizers that shard parameters, this ensures we sync. 
""" - logging.info("Saving non primary") if self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3'): # make sure we call the state dict self.state_dict() diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index b70712cf265..eda482b2ebd 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -482,7 +482,6 @@ def __init__(self, opt: Opt, shared=None): self.criterion = self.build_criterion() with fsdp_utils.maybe_fsdp_wrap(opt): self.model = fsdp_utils.fsdp_wrap(self.build_model()) - logging.debug(f"Model arch:\n{self.model}") if self.fp16 and not fsdp_utils.should_use_fsdp(opt): self.model = self.model.half() diff --git a/parlai/scripts/train_model.py b/parlai/scripts/train_model.py index 047705cbfa8..64a3d8c92b3 100644 --- a/parlai/scripts/train_model.py +++ b/parlai/scripts/train_model.py @@ -453,14 +453,12 @@ def save_model(self, suffix=None): if not is_primary_worker(): # never do IO as a non-primary worker if hasattr(self.agent, 'save_nonprimary'): - logging.debug("Saving on non-primary") self.agent.save_nonprimary(fn) return while True: # don't ever let a ctrl-c interrupt saving try: - logging.debug("Saving on primary") self.agent.save(fn) self._save_train_stats(suffix) break @@ -593,12 +591,10 @@ def _run_single_eval(self, opt, valid_world, max_exs): max_cnt = max_exs if max_exs > 0 else float('inf') while not valid_world.epoch_done() and cnt < max_cnt: valid_world.parley() - logging.info(f"Ran cnt {cnt}") if cnt == 0 and opt['display_examples']: print(valid_world.display() + '\n~~') print(valid_world.report()) cnt = valid_world.report().get('exs') or 0 - logging.info(f"rank {is_primary_worker()} epoch_done") valid_report = valid_world.report() if opt.get('validation_share_agent', False): @@ -637,9 +633,7 @@ def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False): named_reports, micro_average=self.opt.get('aggregate_micro', False) ) # get the results from all workers - logging.debug("Syncing metrics") report = self._sync_metrics(report) - logging.debug("Done syncing metrics") metrics = f'{datatype}:\n{nice_report(report)}\n' logging.info(f'eval completed in {timer.time():.2f}s') diff --git a/parlai/utils/distributed.py b/parlai/utils/distributed.py index 07553c5bdb5..bacf71d5205 100644 --- a/parlai/utils/distributed.py +++ b/parlai/utils/distributed.py @@ -297,6 +297,13 @@ def distributed_context( def get_dist_group(): + """ + Find the default pytorch distributed group. + + Used within FSDP to mark which workers are participating. Important to + manually call this because FSDP will cache old groups, but our test + suite will instantiate new groups per test. + """ from torch.distributed.distributed_c10d import _get_default_group return _get_default_group() diff --git a/parlai/utils/fp16.py b/parlai/utils/fp16.py index 808359569c3..00d03d36744 100644 --- a/parlai/utils/fp16.py +++ b/parlai/utils/fp16.py @@ -55,9 +55,25 @@ def forward(self, scores, targets): ) -def clip_grad_norm(params, max_norm, sync: bool = False): +def clip_grad_norm(params, max_norm: float = 0, sync: bool = False): """ - Clips grad norm. + Clips grad norms. + + During combination with FSDP, will also ensure that grad norms are aggregated + across all workers, since each worker only stores their shard of the + gradients. + + :param params: + Parameters whose gradients we wish to clip + :param max_norm: + Maximum norm we wish the gradients to have. If non-positive, then + we will not perform clipping. 
+ :param sync: + Boolean indicating whether we should aggregate across the distributed + group. Used only in combination with FSDP. + + :returns: + The gradient norm across all parameters, before clipping. """ if isinstance(params, torch.Tensor): params = [params] @@ -93,7 +109,7 @@ def has_overflow(grad_norm): class SafeFP16Optimizer(torch.optim.Optimizer): - def __init__(self, optimizer, sync_overflows=False): + def __init__(self, optimizer, aggregate_gnorms=False): self.fp16_params = self._get_parameters(optimizer) self.fp32_params = self._build_fp32_params(self.fp16_params, flatten=False) self.optimizer = optimizer @@ -108,7 +124,7 @@ def __init__(self, optimizer, sync_overflows=False): self.scaler = DynamicLossScaler(2.0 ** 15) self.min_loss_scale = 2 ** -5 - self._sync_overflows = sync_overflows + self._aggregate_gnorms = aggregate_gnorms @classmethod def _get_parameters(cls, optimizer): @@ -217,7 +233,7 @@ def clip_master_grads(self, max_norm): """ self._sync_fp16_grads_to_fp32() grad_norm = clip_grad_norm( - self.fp32_params, max_norm, sync=self._sync_overflows + self.fp32_params, max_norm, sync=self._aggregate_gnorms ) # detect overflow and adjust loss scale @@ -398,7 +414,7 @@ class MemoryEfficientFP16Optimizer(torch.optim.Optimizer): def __init__( self, init_optimizer: torch.optim.Optimizer, # type: ignore - sync_overflows: bool = False, + aggregate_gnorms: bool = False, loss_initial_scale: float = 2.0 ** 17, min_loss_scale: float = 1e-4, ): @@ -407,7 +423,7 @@ def __init__( self.min_loss_scale = min_loss_scale self.scaler = DynamicLossScaler(init_scale=loss_initial_scale) - self._sync_overflows = sync_overflows + self._aggregate_gnorms = aggregate_gnorms @staticmethod def compatible_optimizers(): @@ -458,7 +474,7 @@ def clip_master_grads(self, gradient_clip): """ self._unscale_grads() grad_norm = clip_grad_norm( - self.params, gradient_clip, sync=self._sync_overflows + self.params, gradient_clip, sync=self._aggregate_gnorms ) # detect overflow and adjust loss scale overflow = has_overflow(grad_norm) From a2f84c131c645c5334d4ec2749676200a52ac7a1 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Mon, 28 Jun 2021 11:05:29 -0400 Subject: [PATCH 24/29] Silly --- tests/test_distributed.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 688c8822c7d..0b599694a2c 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -176,6 +176,7 @@ class TestZero3(TestDistributed): base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero3'} +@testing_utils.skipUnlessGPU class TestZero3NotImplemented(_AbstractTest): base_config = dict( task='integration_tests:overfit', From 61b64dc163df543e3b5b96bce9660ac468a5b87f Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Thu, 1 Jul 2021 10:05:33 -0400 Subject: [PATCH 25/29] Reviewer comments. --- parlai/core/params.py | 6 +++--- parlai/core/torch_agent.py | 4 ++-- parlai/core/torch_generator_agent.py | 2 +- parlai/utils/fsdp.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/parlai/core/params.py b/parlai/core/params.py index f94db1cee16..281439d185d 100644 --- a/parlai/core/params.py +++ b/parlai/core/params.py @@ -774,12 +774,12 @@ def add_distributed_training_args(self): ) grp.add_argument( '--ddp-backend', - choices=['ddp', 'zero2', 'zero3'], + # TODO: add in zero3. https://github.com/facebookresearch/ParlAI/issues/3753 + choices=['ddp', 'zero2'], default='ddp', help=( 'Distributed backend. Zero2 can be faster but is more experimental. 
' - 'Zero3 uses radically less memory, but is slower. DDP is the most ' - 'tested.' + 'DDP is the most tested.' ), ) return grp diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index 412d24d90ed..af9c3a93ab3 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -36,7 +36,7 @@ from parlai.utils.distributed import is_distributed from parlai.utils.misc import AttrDict, warn_once from parlai.utils.io import PathManager -from parlai.utils.fsdp import should_sync_gradnorm, is_fsdp +from parlai.utils.fsdp import should_sync_gradnorm, is_fsdp, DEFAULT_DDP_BACKEND from parlai.utils.fp16 import ( SafeFP16Optimizer, MemoryEfficientFP16Optimizer, @@ -2004,7 +2004,7 @@ def save_nonprimary(self, path=None): For models or optimizers that shard parameters, this ensures we sync. """ - if self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3'): + if self.opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3'): # make sure we call the state dict self.state_dict() diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index b845b274b09..dc4f8c765d7 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -535,7 +535,7 @@ def __init__(self, opt: Opt, shared=None): if ( shared is None and is_distributed() - and opt.get('ddp_backend', 'ddp') == 'ddp' + and opt.get('ddp_backend', fsdp_utils.DEFAULT_DDP_BACKEND) == 'ddp' ): device_ids = None if self.model_parallel else [self.opt['gpu']] self.model = torch.nn.parallel.DistributedDataParallel( diff --git a/parlai/utils/fsdp.py b/parlai/utils/fsdp.py index da356397150..458878a3f44 100644 --- a/parlai/utils/fsdp.py +++ b/parlai/utils/fsdp.py @@ -92,7 +92,7 @@ def delay_halving(self): def should_sync_gradnorm(opt): """ - Indicates whether fp16 optimizer wrappers should cumulate over workers. + Indicates whether fp16 optimizer wrappers should accumulate over workers. FP16 overflow detection and gradient clipping both require accumulating gradients across all workers when using FSDP, as workers only store a fraction of the From 16374c917b72ef4acc63b43972ebcfdb3a1c7049 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Thu, 1 Jul 2021 10:05:53 -0400 Subject: [PATCH 26/29] Lint. --- parlai/utils/distributed.py | 6 +++--- tests/test_distributed.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/parlai/utils/distributed.py b/parlai/utils/distributed.py index bacf71d5205..2088b3451dc 100644 --- a/parlai/utils/distributed.py +++ b/parlai/utils/distributed.py @@ -300,9 +300,9 @@ def get_dist_group(): """ Find the default pytorch distributed group. - Used within FSDP to mark which workers are participating. Important to - manually call this because FSDP will cache old groups, but our test - suite will instantiate new groups per test. + Used within FSDP to mark which workers are participating. Important to manually call + this because FSDP will cache old groups, but our test suite will instantiate new + groups per test. """ from torch.distributed.distributed_c10d import _get_default_group diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 0b599694a2c..d935ac35d0e 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -196,7 +196,7 @@ class TestZero3NotImplemented(_AbstractTest): def test_not_implemented(self): """ - Checks that using --ddp-backend zero3 throws an error + Checks that using --ddp-backend zero3 throws an error. 
""" with self.assertRaises(NotImplementedError): self._distributed_train_model() From 074be0a76a029f56b6abc260c10858a99e0f88ce Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Thu, 1 Jul 2021 10:06:45 -0400 Subject: [PATCH 27/29] We disabled zero3 as an option, so don't need the test. --- tests/test_distributed.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index d935ac35d0e..d1fda5d36d7 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -176,32 +176,6 @@ class TestZero3(TestDistributed): base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero3'} -@testing_utils.skipUnlessGPU -class TestZero3NotImplemented(_AbstractTest): - base_config = dict( - task='integration_tests:overfit', - optimizer='sgd', - validation_metric='loss', - ddp_backend='zero3', - batchsize=BATCHSIZE, - model='transformer/generator', - validation_every_n_epochs=1, - num_epochs=1, - n_layers=1, - n_heads=1, - ffn_size=32, - embedding_size=8, - verbose=True, - ) - - def test_not_implemented(self): - """ - Checks that using --ddp-backend zero3 throws an error. - """ - with self.assertRaises(NotImplementedError): - self._distributed_train_model() - - @testing_utils.skipUnlessGPU class TestNoModelParallel(_AbstractTest): base_config = dict( From 0814c999efdf31f053f1e8467abf59a8dfa5855d Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Thu, 1 Jul 2021 11:37:10 -0400 Subject: [PATCH 28/29] Bug caught by Kurt. --- parlai/core/torch_generator_agent.py | 2 +- parlai/utils/fsdp.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index dc4f8c765d7..eda8538e0e9 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -482,7 +482,7 @@ def __init__(self, opt: Opt, shared=None): self.criterion = self.build_criterion() with fsdp_utils.maybe_fsdp_wrap(opt): self.model = fsdp_utils.fsdp_wrap(self.build_model()) - if self.fp16 and not fsdp_utils.should_use_fsdp(opt): + if self.fp16 and not fsdp_utils.delay_halving(opt): self.model = self.model.half() # load the block_list for beam search diff --git a/parlai/utils/fsdp.py b/parlai/utils/fsdp.py index 458878a3f44..1487c63f9ae 100644 --- a/parlai/utils/fsdp.py +++ b/parlai/utils/fsdp.py @@ -81,13 +81,12 @@ def delay_halving(self): When using Zero2 or Zero3 backends with mixed precision, we need to avoid converting the model to fp16, as the FSDP module does this for us. + + If we are using just plain DDP or MemoryEfficient optimizers, then we want + to call half() early. """ - return ( - self.fp16 - and self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3') - and self.opt['fp16_impl'] == 'safe' - ) + return self.fp16 and should_use_fsdp(opt) and self.opt['fp16_impl'] == 'safe' def should_sync_gradnorm(opt): From c5a82aad6d374fee0200359c82c8d77d9d1a0cf8 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Thu, 1 Jul 2021 12:58:00 -0400 Subject: [PATCH 29/29] Rofl --- parlai/utils/fsdp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parlai/utils/fsdp.py b/parlai/utils/fsdp.py index 1487c63f9ae..e2fb305f372 100644 --- a/parlai/utils/fsdp.py +++ b/parlai/utils/fsdp.py @@ -75,7 +75,7 @@ def maybe_fsdp_wrap(opt): yield -def delay_halving(self): +def delay_halving(opt): """ Check whether we should keep the model in fp32 before other setup. @@ -86,7 +86,7 @@ def delay_halving(self): to call half() early. 
""" - return self.fp16 and should_use_fsdp(opt) and self.opt['fp16_impl'] == 'safe' + return opt['fp16'] and should_use_fsdp(opt) and opt['fp16_impl'] == 'safe' def should_sync_gradnorm(opt):