You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# Background and rationale
In many use cases, particularly LLMs, one is faced with inputs
(sentences) of variable lengths. A common practice is to pack batches by
token count (not a fixed batch size), ie by putting together sentences
whose given metric (eg sequence lengths) will add up to an user-provided
value. As an example, in [Attention is all you
need](https://arxiv.org/abs/1706.03762), section 5.1:
> Sentence pairs were batched together by approximate sequence length.
Each training
batch contained a set of sentence pairs containing approximately 25000
source tokens and 25000
target tokens.
Dynamic batch sizes has been requested in [DeepSpeed issue
1051](#1051), [DeepSpeed
issue 3455 ](#3455),
[Pytorch Lightning issue
16914](Lightning-AI/pytorch-lightning#16914),
[huggingface issue
2647](huggingface/accelerate#2647) and is
available already in many libraries e.g. [NVIDIA
Triton](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#dynamic-batcher)
and [Meta FairSeq](https://github.com/facebookresearch/fairseq)
(implementation
[here](https://github.com/facebookresearch/fairseq/blob/34973a94d09ecc12092a5ecc8afece5e536b7692/fairseq/data/fairseq_dataset.py#L104)
).
The immediate use case for this is when one needs to maximize GPU
utilization. Moreover, this is particularly relevant for curriculum
learning where a `BxTxE` (Batch x Time x Embedding) -shaped input should
ideally have high `B` and low `T` at the early curriculum steps (many
short sentences packed together as a batch), and low `B` and high `T` at
the late steps (few long sentences in the batch). A dynamic size `T` is
already supported by Deepspeed, e.g. in the documentation for pipeline
parallelism's
[reset_activation_shape()](https://deepspeed.readthedocs.io/en/stable/pipeline.html#deepspeed.runtime.pipe.engine.PipelineEngine.reset_activation_shape):
> For curriculum learning that changes the seqlen of each sample, we
need to call this whenever the seqlen is going to change.
However, dynamic `B` is not supported. A dynamic `B` would require an
adequate increase/decrease of learning rate. This technique has been
applied previously, and the two most common LR scaling algorithms have
been described as:
1. Linear Scaling Rule: "When the minibatch size is multiplied by k,
multiply the learning rate by k", as in [Accurate, Large Minibatch SGD:
Training ImageNet in 1 Hour, Goyal et
al.](https://arxiv.org/abs/1706.02677)
2. Square Root scaling: "when multiplying the batch size by k, multiply
the learning rate by √k, to keep the variance in the gradient
expectation constant" by [One weird trick for parallelizing
convolutional neural networks, A. Krizhevsky et
al.](https://arxiv.org/abs/1404.5997)
In practice, the user picks the total token count per batch as the
metric that drives batching, instead of batching by sentence count.
During runtime, the variable batch size is computed and the LR is
adjusted respectively, based on the LR and batch size provided by the
config.
# Illustration of dynamic batch size, sequence length and LR
Imagine we picked a limit of `30` tokens per batch, and have set a
reference `lr=1e-3` for a `train_batch_size=2` (in the deepspeed
config). The batching algorithm for curriculum may pack the data into
batches of short sentences (left) at the early stages, and batches of
long sentences (right) as later stages, e.g.:

Above, we collected samples until we filled up the batch with at most 30
tokens. The batch sizes (number of samples) became then `10` and `4` on
the left and right examples, respectively. Using the linear scaling
rule, the LR for those batches become `5e-3` and `2e-3`.
# Pipeline parallelism
Pipeline parallelism requires the same batch size and same sequence
length across all micro-batches in a batch, as the activation sizes must
be fixed between gradient accumulation steps. Between batches, these may
change, and long as `engine.reset_activation_shape()` is called so that
the new shapes are communicated on the first gradient accumulation step
in the batch. Enforcing similar `BxTxE` between batches may lead to
smaller micro-batches. As an example, below we can see an illustration
of a 2-node 2-gradient-accumulation-step (ie 4 micro-batches) batching
for the same dataset, when preparing data for the regular DDP (left) and
for the pipeline parallelism use cases (right):

We can see that the pipeline use case (right) has the same `BxTxE` shape
across all the 4 micro-batches in the same batch, and in order to
respect that, it packs less samples in the batch, when compared to the
standard use case (left hand size)
# Attention Head
For an input of size `BxTxE` the attention has a shape of `TxT` for a
mask of fixed size across samples of same size, or `BxTxT` for a
different mask per sample (when samples have different sizes, as in the
dataset above). This 3D attention matrix can be illustrated for the DDP
microbatch 1 (picture above top-left, 4 sentences) as:

Note the memory savings: the attention head has a size of `BxTxT`, i.e.
a linear memory dependency on the batch size `B` and quadratic memory
dependency on the largest sequence length `T` in the (micro-) batch.
Thus, supporting a dynamic size `T` allows for an increase of `B`.
# PR overview
This PRs implements dynamic batching and LR scaling. The dataloader and
LR scheduler necessary can be retrieved by calling
`get_dataloader_and_lr_scheduler_for_variable_batch_size`. A small
explanation of that function follows:
- The logic behind the algorithms for LR scaling is in `scale_lr`;
- The partitioning of samples into batches is done by `batch_by_seqlen`.
- For pipeline parallelism, it is required that all micro-batches in a
pipeline pass to have the same activation shapes. This is enabled by
setting to `True` the following parameters:
- `required_microbatches_of_same_sizes` that will force the `B`
dimension to be the same across all gradient accumulation steps of all
dataloaders on a batch;
- `required_microbatches_of_same_lengths` that will force the `T`
dimension to be the same across all gradient accumulation steps. Works
by calling the user-provided `sample_padding_fn(sentence, len)` that
pads a given sentence to the argument length;
- `batch_by_seqlen` returns `microbatch_sample_ids` (the list of sample
ids per micro-batch), `batch_sizes` (the size of effective batch sizes,
and `batch_max_seqlens` (longest sequence across all microbatches in a
batch)
- `dataloader_for_variable_batch_size` relies on `microbatch_sample_ids`
and will iterate/collate/pad samples for every batch and return a
dataloader that iterates the final (variable-size) batches;
- `lr_scheduler_for_variable_batch_size` relies on `batch_sizes` to
compute the learning rate for each effective batch, taking into account
the batch size and LR in the config file, and scaling the LR based on
the size of each effective batch, and the scaling rule mentioned above
(Linear, Square root, etc).
- Special note to the `lr_scheduler` returned that will either accept
either:
1. an user-provided `Optimizer` that will scale the learning rates (in
param groups) at every batch, or
2. an user-defined `LRScheduler`, that in this case will first get the
learning rate from the scheduler and then scale it accordingly.
# Example
An example for the use case with and without pipelining is provided in
file
[`DeepSpeedExamples/training/data_efficiency/variable_batch_size_and_lr/variable_batch_size_and_lr_example.py`](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/data_efficiency/variable_batch_size_and_lr).
The example shows an attention head with attention of variable-sized
`BxTxT` per batch, followed by a fixed size feed forward network. These
are the main blocks on a Large Language Model. The feed-forward (or
linear layer) that follows the attention head requires a constant input
size, equivalent to the largest sentence in the whole dataset, so the
output of the attention must be padded (see `feedforward: needs to
convert BxTxE to BxMxE by padding extra tokens` in the code).
# Config
The example file also comments the relevant deepspeed config with
comments:
```python
config = {
"train_batch_size": 16,
# `train_micro_batch_size_per_gpu` tells how many sequence packs of `max_tokens` each will be collated together.
# I.e. the number of tokens per micro batch (ie per gpu iteration) is `train_micro_batch_size_per_gpu`*`max_tokens`.
"train_micro_batch_size_per_gpu": 2,
"data_efficiency": {
"enabled": True,
# seed to be applied to all data efficiency modules, including dynamic batching
"seed": 42,
"data_sampling": {
"num_workers": 0, # dataloader num_workers argument
"pin_memory": False, # dataloader pin_memory argument
"dynamic_batching": {
# enables or disables dynamic batching
"enabled": True,
# how many tokens we need to fill a pack of sequences (that will be collated together as a sample)
"max_tokens": 100,
# Input and output write to read from or write the length of every sequence.
# Sequence lengths will be loaded from: {metrics_path}/seqlen/seqlen_sample_to_metric.bin and *.idx
# If files dont exist, they'll be computed and saved on the first run, and loaded on subsequent runs.
"metrics_path": "./curriculum_output/",
# As batch size increases/decreses, which method to use to scale LR accordingly?
# Options: linear, sqrt (square root), or None to disable
"lr_scaling_method": "linear",
# how to pick sentences to be packed into samples:
# - dataloader: by same order as they come in with the dataloader
# - seqlen: by sequence length (shortest to longest)
# - random: random order using the seed in config['data_efficiency']['seed'
"sentence_picking_order": "dataloader", # "random" / "seqlen" / "dataloader"
# minimum number of sequences required to reach `max_tokens`. If sentence pack is smaller, it's discarded.
"min_batch_size": 1,
# maximum number of sequences required to reach `max_tokens`. If sentence pack is larger, it's discarded.
"max_batch_size": 10,
# enable the output of microbatching information about sentence packing
"verbose": True,
},
},
},
}
```
# Future work
A follow-up PR will enable dynamic batching when calling
`deepspeed.initialize`. I.e. instead of this:
```python
engine, _, _, _ = deepspeed.initialize(config=config, model=model)
dataloader, lr_scheduler, _ = get_dataloader_and_lr_scheduler_for_variable_batch_size_deepspeed(...)
engine.lr_scheduler = lr_scheduler
```
we'd ideally have this:
```python
engine, _, dataloader, lr_scheduler = deepspeed.initialize(config=config, model=model)
```
where `initialize` will call internally
`get_dataloader_and_lr_scheduler_for_variable_batch_size_deepspeed`.
---------
Signed-off-by: Bruno Magalhaes <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
0 commit comments