diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 9c751a215..ee0599b67 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -113,7 +113,7 @@ def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffi if not end_of_epoch and cfg.keep_last_updates > 0: # remove old checkpoints; checkpoints are sorted in descending order checkpoints = checkpoint_paths( - cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) + cfg.save_dir, pattern=r"checkpoint_(\d+){}\.pt".format(suffix) ) for old_chk in checkpoints[cfg.keep_last_updates :]: if os.path.lexists(old_chk): diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index 011ca0e8a..cc02cdb64 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -36,7 +36,7 @@ from metaseq.data import iterators, data_utils from metaseq.data.plasma_utils import PlasmaStore from metaseq.dataclass.utils import convert_namespace_to_omegaconf -from metaseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils +from metaseq.distributed import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel, utils as distributed_utils from metaseq.file_io import PathManager from metaseq.logging import meters, metrics, progress_bar from metaseq.trainer import Trainer @@ -144,15 +144,36 @@ def main(cfg: DictConfig) -> None: cfg.distributed_training, use_sharded_state=cfg.distributed_training.use_sharded_state, ): - model = fsdp_wrap( - task.build_model(cfg.model), - process_group=distributed_utils.get_data_parallel_group(), - ) + model = task.build_model(cfg.model) + if not isinstance(model, FullyShardedDataParallel): + model = fsdp_wrap( + model, + process_group=distributed_utils.get_data_parallel_group(), + ) else: model = task.build_model(cfg.model) - # TODO[Susan]: FSDP on criterion? - criterion = task.build_criterion(cfg.criterion) + if cfg.distributed_training.criterion_ddp_backend == "fully_sharded": + # As the task is non-trainable, we switch flags to more optimized ones. + # See https://github.com/facebookresearch/metaseq/pull/668 for when/why this was added. + orig_memory_efficient_fp16 = cfg.distributed_training.memory_efficient_fp16 + orig_fp32_reduce_scatter = cfg.distributed_training.fp32_reduce_scatter + # Clobber memory_efficient_fp16 and fp32_reduce_scatter + cfg.distributed_training.memory_efficient_fp16 = cfg.distributed_training.fp16 + cfg.distributed_training.fp32_reduce_scatter = not cfg.distributed_training.fp16 + + with fsdp_enable_wrap( + cfg.distributed_training, + use_sharded_state=cfg.distributed_training.use_sharded_state, + ): + criterion = task.build_criterion(cfg.criterion) + + # Reset memory_efficient_fp16 and fp32_reduce_scatter values. + cfg.distributed_training.memory_efficient_fp16 = orig_memory_efficient_fp16 + cfg.distributed_training.fp32_reduce_scatter = orig_fp32_reduce_scatter + else: + criterion = task.build_criterion(cfg.criterion) + logger.info(model) logger.info("task: {}".format(task.__class__.__name__)) @@ -483,6 +504,10 @@ def validate_and_save( and num_updates % cfg.dataset.validate_interval_updates == 0 and was_successful_step ) + or ( + num_updates == cfg.dataset.validate_on_first_step + and was_successful_step + ) ) and not cfg.dataset.disable_validation # Save checkpoint before validating. 
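Note on the train.py hunk above: the criterion branch saves memory_efficient_fp16 and fp32_reduce_scatter, clobbers them, builds the criterion under fsdp_enable_wrap, and then restores the originals by hand. A minimal sketch of the same save/clobber/restore idea written as a context manager, for illustration only; the helper name override_criterion_fsdp_flags is not part of this PR, and the field names are taken from the diff:

from contextlib import contextmanager

@contextmanager
def override_criterion_fsdp_flags(dist_cfg):
    # Temporarily switch to the flags used when FSDP-wrapping a non-trainable
    # criterion, then restore the originals even if building the criterion fails.
    orig = (dist_cfg.memory_efficient_fp16, dist_cfg.fp32_reduce_scatter)
    dist_cfg.memory_efficient_fp16 = dist_cfg.fp16
    dist_cfg.fp32_reduce_scatter = not dist_cfg.fp16
    try:
        yield dist_cfg
    finally:
        dist_cfg.memory_efficient_fp16, dist_cfg.fp32_reduce_scatter = orig

Written this way, the restore runs in a finally block even when build_criterion raises, which the explicit restore lines in the hunk do not guarantee.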
diff --git a/metaseq/data/cm3_dataset.py b/metaseq/data/cm3_dataset.py index 5cdeccda5..b275468f0 100644 --- a/metaseq/data/cm3_dataset.py +++ b/metaseq/data/cm3_dataset.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import numpy as np +import random import torch from typing import List, Optional, Tuple @@ -76,6 +77,7 @@ def __init__( to_skip=0, permute_documents=True, source_target=False, + percent_full_document_rotation: float = 0.0 ): super().__init__( dataset, @@ -106,6 +108,7 @@ def __init__( self.sentinel_fixed = self.sentinel_method == "fixed" self.allow_rotation_across_eod = allow_rotation_across_eod self.eod = eod + self.percent_full_document_rotation = percent_full_document_rotation def get_sentinel(self, i): return self.sentinel_tokens[i] @@ -139,7 +142,8 @@ def sentinel_targets(self, document: torch.Tensor, spans: List[Tuple[int, int]]) index = index + size + 1 return target - def get_spans_to_mask(self, document_length: int) -> List[Tuple[int, int]]: + def get_spans_to_mask(self, document_length: int, document_boundaries: List[Tuple[int, int]]) -> List[ + Tuple[int, int]]: # Ok, we do not use a budget here but instead # our goal is to sample from ~ U[0,1] in the case of len(sentinel_tokens) = 1 # If len(sentinel_tokens) > 1 we try to find len(sentinel_tokens) non intersecting spans @@ -156,11 +160,14 @@ def get_spans_to_mask(self, document_length: int) -> List[Tuple[int, int]]: if len_sentinel_tokens == 0: return None if len_sentinel_tokens == 1: + if np.random.random() < self.percent_full_document_rotation: + return [random.choice(document_boundaries)] + start, end = np.random.uniform(size=2) if end < start: start, end = end, start # round down - start = int(start * document_length) + start = max(1, int(start * document_length)) # round up end = int(end * document_length + 0.5) if start == end: @@ -168,6 +175,8 @@ def get_spans_to_mask(self, document_length: int) -> List[Tuple[int, int]]: else: assert start < end return [(start, end)] + if len_sentinel_tokens < len(document_boundaries) and np.random.random() < self.percent_full_document_rotation: + return random.sample(document_boundaries, len_sentinel_tokens) # Let's implement the general case. We will create len(self.sentinel_tokens) ** 2 possible candidates # And we will filter one by one to insure no intersections. If we can't find anything then so be it. @@ -200,24 +209,31 @@ def get_document_boundaries(self, item: torch.Tensor): boundaries = boundaries + [item.size(0)] spans = [] for i in range(1, len(boundaries)): - spans.append((boundaries[i - 1], boundaries[i])) + spans.append((boundaries[i - 1] + 1, boundaries[i])) return spans + def cm3_shuffle(self, item): + assert len(item) > 0 + document_boundaries = self.get_document_boundaries(item) + spans = self.get_spans_to_mask(len(item), document_boundaries) + if not self.allow_rotation_across_eod and spans is not None: + spans = adjust_spans(spans, document_boundaries) + if spans is None: + return item + else: + spans = self.get_ordered_spans(spans) + causal_source = self.sentinel_masking(item, spans) + causal_masked = self.sentinel_targets(item, spans) + + total_count = len(causal_source) + len(causal_masked) + total_diff = total_count - self.tokens_per_sample + total_causal_length = len(causal_source) - total_diff + return torch.cat([ + causal_source[:total_causal_length], + causal_masked + ])[: self.tokens_per_sample] # EOSS tokens can add just enough tokens to get off by 1-2. 
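The cm3_dataset.py hunk above threads document_boundaries into get_spans_to_mask and, with probability percent_full_document_rotation, rotates an entire document instead of a random sub-span. A condensed, standalone sketch of the single-sentinel path for illustration; the names follow the diff, the helper function itself is hypothetical, and the return None branch for a degenerate span is assumed from the surrounding code rather than visible in the hunk:

import random
import numpy as np

def sample_single_span(document_length, document_boundaries,
                       percent_full_document_rotation=0.0):
    # Occasionally rotate a whole document, given as a (start, end) boundary pair.
    if document_boundaries and np.random.random() < percent_full_document_rotation:
        return [random.choice(document_boundaries)]
    # Otherwise sample a span from roughly U[0, 1] of the document.
    start, end = np.random.uniform(size=2)
    if end < start:
        start, end = end, start
    start = max(1, int(start * document_length))  # round down, but never position 0
    end = int(end * document_length + 0.5)        # round up
    if start == end:
        return None  # degenerate span: caller leaves the block unrotated
    return [(start, end)]

The max(1, ...) clamp added in this PR keeps a sampled span from starting at position 0, the first token of the block.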
+ def __iter__(self): for packed_item in super().__iter__(): - item = packed_item["block"] - assert len(item) > 0 - spans = self.get_spans_to_mask(len(item)) - if not self.allow_rotation_across_eod: - document_boundaries = self.get_document_boundaries(item) - spans = adjust_spans(spans, document_boundaries) - if spans is None: - yield packed_item - else: - spans = self.get_ordered_spans(spans) - causal_source = self.sentinel_masking(item, spans) - causal_masked = self.sentinel_targets(item, spans) - packed_item["block"] = torch.cat([causal_source, causal_masked])[ - : self.tokens_per_sample - ] # EOSS tokens can add just enough tokens to get off by 1-2. - yield packed_item + packed_item["block"] = self.cm3_shuffle(packed_item["block"]) + yield packed_item diff --git a/metaseq/data/data_utils.py b/metaseq/data/data_utils.py index d38dd5d36..779f182cc 100644 --- a/metaseq/data/data_utils.py +++ b/metaseq/data/data_utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from pathlib import Path try: from collections.abc import Iterable @@ -279,6 +280,8 @@ def _find_extra_valid_paths(dataset_path: str) -> set: for sub_dir in paths: if "://" in sub_dir: continue + if not Path(sub_dir).is_dir(): + continue contents = PathManager.ls(sub_dir) valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None] all_valid_paths |= {os.path.basename(p) for p in valid_paths} diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index e2843a798..819fe31f3 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -253,6 +253,10 @@ class DistributedTrainingConfig(MetaseqDataclass): default="none", metadata={"help": "If set to fully_sharded, will fsdp wrap task."}, ) + criterion_ddp_backend: TASK_DDP_BACKEND_CHOICES = field( + default="none", + metadata={"help": "If set to fully_sharded, will fsdp wrap the criterion."}, + ) bucket_cap_mb: int = field( default=25, metadata={"help": "bucket size for reduction"} ) @@ -375,6 +379,9 @@ class DatasetConfig(MetaseqDataclass): validate_interval_updates: int = field( default=0, metadata={"help": "validate every N updates"} ) + validate_on_first_step: int = field( + default=-1, metadata={"help": "run validation when the update count equals this value;
default not to validate."} + ) validate_after_updates: int = field( default=0, metadata={"help": "dont validate until reaching this many updates"} ) diff --git a/metaseq/dataclass/constants.py b/metaseq/dataclass/constants.py index b13e83974..e13ec7a88 100644 --- a/metaseq/dataclass/constants.py +++ b/metaseq/dataclass/constants.py @@ -40,6 +40,8 @@ def ChoiceEnum(choices: List[str]): "c10d", # alias for pytorch_ddp "fully_sharded", # FullyShardedDataParallel from fairscale "pytorch_ddp", + "no_c10d", + "legacy_ddp", ] ) diff --git a/metaseq/distributed/fully_sharded_data_parallel.py b/metaseq/distributed/fully_sharded_data_parallel.py index d85f5823d..2fb82e3f8 100644 --- a/metaseq/distributed/fully_sharded_data_parallel.py +++ b/metaseq/distributed/fully_sharded_data_parallel.py @@ -87,6 +87,8 @@ def fsdp_enable_wrap( cfg: DistributedTrainingConfig, use_sharded_state: bool = False, **kwargs ): try: + # from torch.distributed.fsdp.wrap import enable_wrap + # from torch.distributed.fsdp import MixedPrecision from fairscale.nn import enable_wrap except ImportError: raise ImportError( diff --git a/metaseq/distributed/legacy_distributed_data_parallel.py b/metaseq/distributed/legacy_distributed_data_parallel.py new file mode 100644 index 000000000..56a31226d --- /dev/null +++ b/metaseq/distributed/legacy_distributed_data_parallel.py @@ -0,0 +1,171 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +A modified version of the legacy DistributedDataParallel module that uses c10d +communication primitives. This version is simpler than the latest PyTorch +version and is useful for debugging. Notably it does not overlap gradient +communication with the backward pass, which makes it slower but more robust +than the PyTorch version. + +This version also supports the *no_sync* context manager, which allows faster +training with `--update-freq`. +""" + +from collections import OrderedDict +from contextlib import contextmanager + +import torch +from torch import nn + +from metaseq.distributed import utils + + +class LegacyDistributedDataParallel(nn.Module): + """Implements distributed data parallelism at the module level. + + A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`. + This version uses a c10d process group for communication and does not + broadcast buffers. + + Args: + module (~torch.nn.Module): module to be parallelized + process_group: the c10d process group to be used for distributed data + parallel all-reduction. + buffer_size (int, optional): number of elements to buffer before + performing all-reduce (default: 256M). 
+ """ + + def __init__(self, module, process_group, buffer_size=2**28): + super().__init__() + + self.module = module + self.process_group = process_group + self.world_size = utils.get_world_size(self.process_group) + + # Never use a bigger buffer than the number of model params + self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters())) + self.buffer = None + + # We can also forcibly accumulate grads locally and only do the + # all-reduce at some later time + self.accumulate_grads = False + + # make per-device lists of parameters + paramlists = OrderedDict() + for param in self.module.parameters(): + device = param.device + if paramlists.get(device) is None: + paramlists[device] = [] + paramlists[device] += [param] + self.per_device_params = list(paramlists.values()) + + @contextmanager + def no_sync(self): + """A context manager to disable gradient synchronization.""" + old_accumulate_grads = self.accumulate_grads + self.accumulate_grads = True + yield + self.accumulate_grads = old_accumulate_grads + + def forward(self, *inputs, **kwargs): + return self.module(*inputs, **kwargs) + + def all_reduce_grads(self): + """ + This function must be called explicitly after backward to reduce + gradients. There is no automatic hook like c10d. + """ + + def all_reduce_params(params): + buffer = self.buffer + nonzero_buffer = False + if len(params) > 1: + offset = 0 + for p in params: + sz = p.numel() + if p.grad is not None: + buffer[offset : offset + sz].copy_(p.grad.data.view(-1)) + nonzero_buffer = True + else: + buffer[offset : offset + sz].zero_() + offset += sz + else: + # we only have a single grad to all-reduce + p = params[0] + if p.grad is not None: + buffer = p.grad.data + nonzero_buffer = True + elif p.numel() <= self.buffer.numel(): + buffer = buffer[: p.numel()] + buffer.zero_() + else: + buffer = torch.zeros_like(p) + + if nonzero_buffer: + buffer.div_(self.world_size) + + utils.all_reduce(buffer, self.process_group) + + # copy all-reduced grads back into their original place + offset = 0 + for p in params: + sz = p.numel() + if p.grad is not None: + p.grad.data.copy_(buffer[offset : offset + sz].view_as(p)) + else: + p.grad = buffer[offset : offset + sz].view_as(p).clone() + offset += sz + + def reduction_fn(): + # This function only needs to be called once + if self.accumulate_grads: + return + + if self.buffer is None: + self.buffer = next(self.module.parameters()).new(self.buffer_size) + + for params in self.per_device_params: + # All-reduce the gradients in buckets + offset = 0 + buffered_params = [] + for param in params: + if not param.requires_grad: + continue + + if hasattr(param, "base_expert"): + # Skip gradient sync for unshared parameters + continue + + if hasattr(param, "expert"): + if param.grad is None: + param.grad = torch.zeros_like(param) + else: + param.grad.data.div_(self.world_size) + continue + if param.grad is None: + param.grad = torch.zeros_like(param) + if param.grad.requires_grad: + raise RuntimeError( + "DistributedDataParallel only works " + "with gradients that don't require " + "grad" + ) + sz = param.numel() + if sz > self.buffer.numel(): + # all-reduce big params directly + all_reduce_params([param]) + else: + if offset + sz > self.buffer.numel(): + all_reduce_params(buffered_params) + offset = 0 + buffered_params.clear() + buffered_params.append(param) + offset += sz + + if len(buffered_params) > 0: + all_reduce_params(buffered_params) + + reduction_fn() diff --git a/metaseq/distributed/stitch_fsdp_ckpt.py 
b/metaseq/distributed/stitch_fsdp_ckpt.py index 5990d9810..632c9c898 100644 --- a/metaseq/distributed/stitch_fsdp_ckpt.py +++ b/metaseq/distributed/stitch_fsdp_ckpt.py @@ -92,7 +92,7 @@ def consolidate_fsdp_shards( do_consolidate = False if do_consolidate: num_parts = find_num_parts(names) - if num_parts: + if num_parts > 1: logger.info("consolidate_model_parallel") consolidated_weights = consolidate_model_parallel( metadata, diff --git a/metaseq/distributed/utils.py b/metaseq/distributed/utils.py index b1d559a2a..7d8147502 100644 --- a/metaseq/distributed/utils.py +++ b/metaseq/distributed/utils.py @@ -113,7 +113,17 @@ def _infer_single_node_init(cfg: DistributedTrainingConfig): assert ( cfg.distributed_world_size <= torch.cuda.device_count() ), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices" - port = random.randint(10000, 20000) + + def find_free_port(): + import socket + from contextlib import closing + + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(('', 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return str(s.getsockname()[1]) + + port = find_free_port() cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port) diff --git a/metaseq/models/distributed_model.py b/metaseq/models/distributed_model.py index 202af72a1..874bb44a4 100644 --- a/metaseq/models/distributed_model.py +++ b/metaseq/models/distributed_model.py @@ -11,6 +11,9 @@ from metaseq.distributed import ( ModuleProxyWrapper, ) +from metaseq.distributed.legacy_distributed_data_parallel import ( + LegacyDistributedDataParallel, +) logger = logging.getLogger(__name__) @@ -51,6 +54,14 @@ def DistributedModel(args, model, process_group, device): ) # forward missing getattr and state_dict/load_state_dict to orig model wrapped_model = ModuleProxyWrapper(wrapped_model) + elif args.ddp_backend in {"no_c10d", "legacy_ddp"}: + wrapped_model = LegacyDistributedDataParallel( + module=model.to(device), + buffer_size=2**28, + process_group=process_group, + ) + # forward missing getattr and state_dict/load_state_dict to orig model + wrapped_model = ModuleProxyWrapper(wrapped_model) elif args.ddp_backend == "fully_sharded": try: from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP diff --git a/metaseq/modules/apex/fused_layer_norm.py b/metaseq/modules/apex/fused_layer_norm.py new file mode 100644 index 000000000..cb5c6962f --- /dev/null +++ b/metaseq/modules/apex/fused_layer_norm.py @@ -0,0 +1,230 @@ +import importlib +import numbers + +import torch +from torch.nn.parameter import Parameter +from torch.nn import init + + +def _cast_if_autocast_enabled(*args): + if not torch.is_autocast_enabled(): + return args + else: + return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype()) + + +# Reference implementation from Huggingface +def manual_rms_norm(input, normalized_shape, weight, eps): + # layer norm should always be calculated in float32 + dims = tuple(i for i in range(-1, -len(normalized_shape)-1, -1)) + variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True) + input = input * torch.rsqrt(variance + eps) + + if weight is None: + return input + + # convert into half-precision if necessary + if weight.dtype in [torch.float16, torch.bfloat16]: + input = input.to(weight.dtype) + + return weight * input + + +def fused_rms_norm(input, normalized_shape, eps=1e-6): + args = _cast_if_autocast_enabled(input, normalized_shape, eps) + with torch.cuda.amp.autocast(enabled=False): + 
return FusedRMSNormFunction.apply(*args) + + +def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6): + args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps) + with torch.cuda.amp.autocast(enabled=False): + return FusedRMSNormAffineFunction.apply(*args) + + +class FusedRMSNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, normalized_shape, eps): + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + output, invvar = fused_layer_norm_cuda.rms_forward(input_, ctx.normalized_shape, ctx.eps) + ctx.save_for_backward(input_, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, invvar = ctx.saved_tensors + grad_input = None + grad_input = fused_layer_norm_cuda.rms_backward( + grad_output.contiguous(), invvar, input_, ctx.normalized_shape, ctx.eps + ) + return grad_input, None, None + + +class FusedLayerNormAffineFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) + return grad_input, grad_weight, grad_bias, None, None + + +class FusedRMSNormAffineFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, normalized_shape, eps): + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + output, invvar = fused_layer_norm_cuda.rms_forward_affine( + input_, ctx.normalized_shape, weight_, ctx.eps) + ctx.save_for_backward(input_, weight_, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, invvar = ctx.saved_tensors + grad_input = grad_weight = None + grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine( + grad_output.contiguous(), invvar, input_, ctx.normalized_shape, weight_, ctx.eps + ) + return grad_input, grad_weight, None, None + + +class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction): + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_layer_norm_cuda.forward_affine_mixed_dtypes( + input_, 
ctx.normalized_shape, weight_, bias_, ctx.eps + ) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output + + +class FusedRMSNorm(torch.nn.Module): + r"""Applies RMS Normalization over a mini-batch of inputs + + Currently only runs on cuda() tensors. + + .. math:: + y = \frac{x}{\mathrm{RMS}[x]} * \gamma + + The root-mean-square is calculated separately over the last + certain number dimensions which have to be of the shape specified by + :attr:`normalized_shape`. + :math:`\gamma` is a learnable affine transform parameter of + :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. + `epsilon` is added to the mean-square, then the root of the sum is taken. + + .. note:: + Unlike Batch Normalization and Instance Normalization, which applies + scalar scale and bias for each entire channel/plane with the + :attr:`affine` option, RMS Normalization applies per-element scale + with :attr:`elementwise_affine`. + + This layer uses statistics computed from input data in both training and + evaluation modes. + + Args: + normalized_shape (int or list or torch.Size): input shape from an expected input + of size + + .. math:: + [* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1] + \times \ldots \times \text{normalized}\_\text{shape}[-1]] + + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability. Default: 1e-5 + elementwise_affine: a boolean value that when set to ``True``, this module + has learnable per-element affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + + Shape: + - Input: :math:`(N, *)` + - Output: :math:`(N, *)` (same shape as input) + + Examples:: + + >>> input = torch.randn(20, 5, 10, 10) + >>> # With Learnable Parameters + >>> m = apex.normalization.FusedRMSNorm(input.size()[1:]) + >>> # Without Learnable Parameters + >>> m = apex.normalization.FusedRMSNorm(input.size()[1:], elementwise_affine=False) + >>> # Normalize over last two dimensions + >>> m = apex.normalization.FusedRMSNorm([10, 10]) + >>> # Normalize over last dimension of size 10 + >>> m = apex.normalization.FusedRMSNorm(10) + >>> # Activating the module + >>> output = m(input) + + .. 
_`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf + """ + + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super().__init__() + + global fused_layer_norm_cuda + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = Parameter(torch.empty(*normalized_shape)) + else: + self.register_parameter("weight", None) + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + init.ones_(self.weight) + + def forward(self, input): + if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda: + return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps) + + if self.elementwise_affine: + return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps) + else: + return fused_rms_norm(input, self.normalized_shape, self.eps) + + def extra_repr(self): + return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) diff --git a/metaseq/modules/megatron/fused_kernels/scaled_masked_softmax.h b/metaseq/modules/megatron/fused_kernels/scaled_masked_softmax.h index 25dceeafb..5d8325ee1 100644 --- a/metaseq/modules/megatron/fused_kernels/scaled_masked_softmax.h +++ b/metaseq/modules/megatron/fused_kernels/scaled_masked_softmax.h @@ -523,6 +523,9 @@ void dispatch_scaled_softmax_forward( scaled_softmax_warp_forward <<>>(dst, src, scale, batch_count, key_seq_len); break; + case 13: // 8192 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); default: break; } @@ -541,7 +544,7 @@ void dispatch_scaled_masked_softmax_forward( int attn_heads, int pad_batches) { - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 ); + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 ); if (key_seq_len == 0) { return; } else { @@ -617,6 +620,10 @@ void dispatch_scaled_masked_softmax_forward( scaled_masked_softmax_warp_forward <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; + case 13: // 8192 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; default: break; } @@ -634,7 +641,7 @@ void dispatch_scaled_masked_softmax_backward( int batches, int attn_heads) { - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 ); + TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 8192 ); if (key_seq_len == 0) { return; } else { @@ -709,7 +716,10 @@ void dispatch_scaled_masked_softmax_backward( scaled_masked_softmax_warp_backward <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; - + case 13: // 8192 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; default: break; } diff --git a/metaseq/modules/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/metaseq/modules/megatron/fused_kernels/scaled_masked_softmax_cuda.cu index 34c7c0850..1cee41f3b 100644 --- a/metaseq/modules/megatron/fused_kernels/scaled_masked_softmax_cuda.cu +++ b/metaseq/modules/megatron/fused_kernels/scaled_masked_softmax_cuda.cu @@ -44,7 +44,7 @@ torch::Tensor fwd_cuda( const int attn_heads = input.size(1); const int query_seq_len = input.size(2); const int key_seq_len = input.size(3); - 
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096); + TORCH_INTERNAL_ASSERT(key_seq_len <= 8192); TORCH_INTERNAL_ASSERT(query_seq_len > 1); TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); TORCH_INTERNAL_ASSERT(mask.size(1) == 1); diff --git a/metaseq/modules/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/metaseq/modules/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h index add6323c9..c127731ae 100644 --- a/metaseq/modules/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h +++ b/metaseq/modules/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h @@ -340,7 +340,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( int softmax_elements_stride, int attn_batches) { - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192 ); if (softmax_elements == 0) { return; } else { @@ -415,6 +415,14 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( scaled_upper_triang_masked_softmax_warp_forward <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; + case 12: // 4096 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 13: // 8192 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; default: break; } @@ -431,7 +439,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( int softmax_elements_stride, int attn_batches) { - TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 8192 ); if (softmax_elements == 0) { return; } else { @@ -506,6 +514,14 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( scaled_upper_triang_masked_softmax_warp_backward <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break; + case 12: // 4096 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 13: // 8192 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; default: break; } diff --git a/metaseq/modules/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu b/metaseq/modules/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu index ea868e2a9..033db16a4 100644 --- a/metaseq/modules/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu +++ b/metaseq/modules/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu @@ -35,7 +35,7 @@ torch::Tensor fwd_cuda( // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] const int attn_batches = input.size(0); const int seq_len = input.size(1); - TORCH_INTERNAL_ASSERT(seq_len <= 2048); + TORCH_INTERNAL_ASSERT(seq_len <= 8192); // Output auto act_options = input.options().requires_grad(false); diff --git a/metaseq/modules/megatron/mpu/__init__.py b/metaseq/modules/megatron/mpu/__init__.py index 6df605b45..4954cb5bc 100644 --- a/metaseq/modules/megatron/mpu/__init__.py +++ b/metaseq/modules/megatron/mpu/__init__.py @@ -12,11 +12,13 @@ from .initialize import get_tensor_model_parallel_rank from .initialize import get_tensor_model_parallel_world_size from .initialize import initialize_model_parallel 
+from .initialize import get_data_parallel_world_size from .layers import LinearWithGradAccumulationAndAsyncCommunication from .layers import ColumnParallelLinear from .layers import RowParallelLinear from .layers import VocabParallelEmbedding +from .layers import ParallelEmbedding from .mappings import copy_to_tensor_model_parallel_region from .mappings import reduce_from_tensor_model_parallel_region diff --git a/metaseq/modules/megatron/mpu/initialize.py b/metaseq/modules/megatron/mpu/initialize.py index 4e055cd2c..f1b82c369 100644 --- a/metaseq/modules/megatron/mpu/initialize.py +++ b/metaseq/modules/megatron/mpu/initialize.py @@ -241,6 +241,11 @@ def get_data_parallel_group(): return _DATA_PARALLEL_GROUP +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + return torch.distributed.get_world_size(group=get_data_parallel_group()) + + def get_tensor_model_parallel_world_size(): """Return world size for the tensor model parallel group.""" global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE diff --git a/metaseq/modules/megatron/mpu/layers.py b/metaseq/modules/megatron/mpu/layers.py index 2b13aa7ce..4fe6b8ecb 100644 --- a/metaseq/modules/megatron/mpu/layers.py +++ b/metaseq/modules/megatron/mpu/layers.py @@ -5,6 +5,8 @@ # # Taken from: # https://github.com/ngoyal2707/Megatron-LM/blob/fa6c0860b62e4ed2ac13a513e7d950d72f576a44/megatron/mpu/layers.py +from collections.abc import Callable +from typing import Optional import torch import torch.nn.functional as F @@ -250,6 +252,72 @@ def forward(self, input_): return output + +class ParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the embedding dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + init_method: method to initialize weights. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, + keep_master_weight_for_test: bool = False, + ) -> None: + super(ParallelEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + self._weight = None + # Divide the weight matrix along the embedding dimension. + world_size = get_tensor_model_parallel_world_size() + self.embedding_dim_per_partition = self.embedding_dim // world_size + + # Allocate weights. + self.weight = Parameter(torch.Tensor(self.num_embeddings, self.embedding_dim_per_partition)) + # And initialize.
+ _initialize_affine_weight_cpu( + self.weight, + self.num_embeddings, + self.embedding_dim, + self.embedding_dim_per_partition, + 1, + init_method, + stride=1, + return_master_weight=False, + ) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore + input_parallel = copy_to_tensor_model_parallel_region(input_) + output_parallel = F.embedding( + input_parallel, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + output = gather_from_tensor_model_parallel_region(output_parallel) + return output + + class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): """ Linear layer execution with asynchronous communication and gradient accumulation diff --git a/metaseq/optim/dynamic_loss_scaler.py b/metaseq/optim/dynamic_loss_scaler.py index 9a3a44e28..d4e8811c4 100644 --- a/metaseq/optim/dynamic_loss_scaler.py +++ b/metaseq/optim/dynamic_loss_scaler.py @@ -68,8 +68,16 @@ def check_overflow(self, grad_norm): self._overflows_since_rescale = 0 if self.loss_scale < self.min_loss_scale: - # Don't scale down past min_loss_scale, just continue to skip grad after overflow error is raised. + # Use FloatingPointError as an uncommon error that parent + # functions can safely catch to stop training. self.loss_scale = prev_scale + raise FloatingPointError( + ( + "Minimum loss scale reached ({}). Your loss is probably exploding. " + "Try lowering the learning rate, using gradient clipping or " + "increasing the batch size." + ).format(self.min_loss_scale) + ) self._iter += 1 raise OverflowError("setting loss scale to: " + str(self.loss_scale)) diff --git a/metaseq/tasks/streaming_language_modeling.py b/metaseq/tasks/streaming_language_modeling.py index ab83ce678..be06b2ebb 100644 --- a/metaseq/tasks/streaming_language_modeling.py +++ b/metaseq/tasks/streaming_language_modeling.py @@ -8,7 +8,9 @@ """ import logging +import random import os +import re from dataclasses import dataclass, field from typing import Any, Dict, List, Optional @@ -44,10 +46,33 @@ DEFAULT_MULTICORPUS_MAX = -1 -LANGUAGE_MODELING_MODE = ChoiceEnum(["standard", "cm3"]) +LANGUAGE_MODELING_MODE = ChoiceEnum(["standard", "cm3", "racm3"]) CM3_MODE = ChoiceEnum(["poisson", "fixed", "fim"]) +def map_old_image_token_to_new_image_token(text): + text = text.replace("I", "IMGIMG") + for i in range(10): + text = text.replace(str(i), chr(ord("A") + i)) + return text.replace(" ", "Z") + + +def map_new_image_token_to_old_image_token(text): + text = text.replace("Z", " ") + for i in range(10): + text = text.replace(chr(ord("A") + i), str(i)) + return text.replace("IMGIMG", "I") + + +def parse_doc(doc): + obj = re.match(r'(.*?)', doc) + if obj is None: + raise ValueError(f"doc not correct formated: {doc}") + text, image = obj.group(1), obj.group(2) + result = {"text": text, "image": image} + return result + + @dataclass class StreamingLanguageModelingConfig(MetaseqDataclass): data: Optional[str] = field( @@ -134,7 +159,7 @@ class StreamingLanguageModelingConfig(MetaseqDataclass): }, ) cm3_allow_across_eod_boundaries: bool = field( - default=True, + default=False, metadata={ "help": "Whether or not we allow rotation of documents across documents" "(especially when training with token blocking set to None)." @@ -142,6 +167,15 @@ class StreamingLanguageModelingConfig(MetaseqDataclass): "For FIM it's unclear whether or not they allow this." 
}, ) + cm3_percent_full_document_rotation: float = field( + default=0.0, + metadata={ + "help": "What percent of the time to rotate full documents while still abiding by the number of sentinel tokens used." + }, + ) + num_retrieved_doc: int = field( + default=2, metadata={"help": "number of retrieved documents"} + ) # TODO common vars below add to parent seed: int = II("common.seed") batch_size: Optional[int] = II("dataset.batch_size") @@ -209,11 +243,12 @@ def __init__(self, args): assert self.dictionary.unk_index == 3 assert self.tokenizer.id_to_token(3) in {"<UNK>", "<unk>"} - self.has_cm3 = args.language_modeling_type == "cm3" + self.has_cm3 = args.language_modeling_type in ["cm3", "racm3"] + self.has_retrieval = args.language_modeling_type == "racm3" if self.has_cm3: - self.cm3_sentinel_type = self.args.cm3_mode self._check_cm3_parameterization() self._create_cm3_special_tokens() + self.cm3_sentinel_type = self.args.cm3_mode final_vocab_size = args.final_vocab_size if final_vocab_size is not None: @@ -234,7 +269,7 @@ def _check_cm3_parameterization(self): ), "cm3_num_sentinel_tokens must be > 0" assert ( self.args.cm3_num_sentinel_tokens >= self.args.cm3_lambda_sentinel_tokens - ), "cm3_lambda_sentinel_tokens must be >= cm3_num_sentinel_tokens" + ), "cm3_num_sentinel_tokens must be >= cm3_lambda_sentinel_tokens" if self.args.cm3_mode == "fim": assert ( self.args.cm3_num_sentinel_tokens == 1 @@ -246,6 +281,10 @@ def _create_cm3_special_tokens(self): self.cm3_sentinel_end = "<eoss>" + self.cm3_break = "<racm3:break>" + self.dictionary.add_symbol(self.cm3_break) + self.dictionary.add_symbol(self.cm3_sentinel_end) + # self.cm3_break_ind = self.dictionary.index(self.cm3_break) self.cm3_sentinel_tokens = [ f"<sentinel:{i}>" for i in range(self.args.cm3_num_sentinel_tokens) ] @@ -256,6 +295,7 @@ def _create_cm3_special_tokens(self): assert token_index != self.dictionary.unk_index self.cm3_sentinel_tokens_ind.append(token_index) self.cm3_sentinel_end_ind = self.dictionary.index(self.cm3_sentinel_end) + self.cm3_break_ind = self.dictionary.index(self.cm3_break) @classmethod def setup_task(cls, args, **kwargs): @@ -269,6 +309,42 @@ def _tokenize_one_json(self, json): + [self.eod] ) + def tokenize_single_doc(self, doc, add_eod=False): + doc = parse_doc(doc) + text, image = doc["text"], doc["image"] + image = map_old_image_token_to_new_image_token(image) + text_indexes, image_indexes = ( + self.tokenizer.encode(text.rstrip()).ids, + self.tokenizer.encode(image.rstrip()).ids, + ) + assert ( + len(image_indexes) == 1024 + ), f"Each image must be 1024 tokens, instead we got {len(image_indexes)}" + assert all( + [i > 4 for i in image_indexes] + ), f"Images should not have any special tokens: {image_indexes}" + indexes = text_indexes + [self.cm3_break_ind] + image_indexes + if add_eod: + indexes = indexes + [self.eod] + return indexes + + def _tokenize_ra_json(self, json): + query_index = self.tokenize_single_doc(json["text"], add_eod=True) + query_index = torch.LongTensor(query_index) + ra_docs = json["retrieved_docs_from_img"] + json["retrieved_docs_from_txt"] + random.shuffle(ra_docs) + + ra_docs = ra_docs[: self.args.num_retrieved_doc] + ra_indexes = [] + for ra_doc in ra_docs: + ra_index = self.tokenize_single_doc(ra_doc, add_eod=False) + ra_index = torch.LongTensor(ra_index + [self.cm3_break_ind]) + ra_indexes.append(ra_index) + final_indexes = torch.cat( + [torch.LongTensor([self.eod])] + ra_indexes + [query_index] + ) + return final_indexes + def _get_sample_prob(self, dataset_lens): """ Get
smoothed sampling porbability by corpus. This helps small corpus by upsampling them. @@ -400,7 +476,10 @@ def load_dataset(self, split: str, epoch=1, combine=False, **kwargs): datasets.append( JsonlDataset( path=os.path.join(self.args.data, split, cur_shard_str, file), - tokenizer=self._tokenize_one_json, + # tokenizer=self._tokenize_one_json, + tokenizer=self._tokenize_ra_json + if self.has_retrieval + else self._tokenize_one_json, epoch=epoch, data_subshard_count=data_subshard_count, ) @@ -424,7 +503,7 @@ def load_dataset(self, split: str, epoch=1, combine=False, **kwargs): sentinel_method=self.cm3_sentinel_type, sentinel_eos=self.cm3_sentinel_end_ind, allow_rotation_across_eod=self.args.cm3_allow_across_eod_boundaries, - eod=self.eod, + eod=self.cm3_break_ind, dataset=dataset, # We generate blocks with one extra token, so that we have a target # for the final input token. This results in slight data loss. @@ -434,6 +513,7 @@ def load_dataset(self, split: str, epoch=1, combine=False, **kwargs): drop_last=(split == "train"), padding_idx=self.source_dictionary.pad(), seed=self.args.seed, + percent_full_document_rotation=self.args.cm3_percent_full_document_rotation ) else: self.datasets[split] = DocumentToSequenceDataset( diff --git a/metaseq/trainer.py b/metaseq/trainer.py index bb918170d..e5494ad74 100644 --- a/metaseq/trainer.py +++ b/metaseq/trainer.py @@ -24,7 +24,7 @@ from omegaconf import OmegaConf from metaseq import checkpoint_utils, models, optim, utils -from metaseq.distributed import utils as distributed_utils, fsdp_enable_wrap, fsdp_wrap +from metaseq.distributed import utils as distributed_utils, fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel from metaseq.file_io import PathManager from metaseq.logging import meters, metrics from metaseq.models.ema import build_ema @@ -244,7 +244,9 @@ def _build_ema(self): "use_sharded_state": self.use_sharded_state, } with fsdp_enable_wrap(self.cfg.distributed_training, **extra): - model = fsdp_wrap(self.task.build_model(self.cfg.model)) + model = self.task.build_model(self.cfg.model) + if not isinstance(model, FullyShardedDataParallel): + model = fsdp_wrap(model) if self.cfg.common.memory_efficient_fp16: if self.cfg.common.bf16:
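Both the train.py and trainer.py hunks in this patch guard the top-level wrap with an isinstance check so a model that already wrapped itself inside build_model is not wrapped twice. A minimal sketch of that guard in isolation, assuming fsdp_wrap and FullyShardedDataParallel as imported from metaseq.distributed in this diff; the helper name maybe_fsdp_wrap is illustrative, not part of the PR:

from metaseq.distributed import FullyShardedDataParallel, fsdp_wrap

def maybe_fsdp_wrap(model, process_group=None):
    # A model whose build_model already returned an FSDP-wrapped module
    # (e.g. per-layer wrapping inside the task) is returned untouched.
    if isinstance(model, FullyShardedDataParallel):
        return model
    return fsdp_wrap(model, process_group=process_group)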