
Commit 20f988e

bm-synth, tjruwase, and loadams authored
Variable batch size and LR scheduler (#7104)
# Background and rationale

In many use cases, particularly LLMs, one is faced with inputs (sentences) of variable lengths. A common practice is to pack batches by token count (not by a fixed batch size), i.e. by putting together sentences whose given metric (e.g. sequence length) adds up to a user-provided value. As an example, in [Attention is all you need](https://arxiv.org/abs/1706.03762), section 5.1:

> Sentence pairs were batched together by approximate sequence length. Each training batch contained a set of sentence pairs containing approximately 25000 source tokens and 25000 target tokens.

Dynamic batch sizes have been requested in [DeepSpeed issue 1051](#1051), [DeepSpeed issue 3455](#3455), [Pytorch Lightning issue 16914](Lightning-AI/pytorch-lightning#16914), [huggingface issue 2647](huggingface/accelerate#2647) and are already available in many libraries, e.g. [NVIDIA Triton](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#dynamic-batcher) and [Meta FairSeq](https://github.com/facebookresearch/fairseq) (implementation [here](https://github.com/facebookresearch/fairseq/blob/34973a94d09ecc12092a5ecc8afece5e536b7692/fairseq/data/fairseq_dataset.py#L104)).

The immediate use case for this is when one needs to maximize GPU utilization. Moreover, this is particularly relevant for curriculum learning, where a `BxTxE` (Batch x Time x Embedding) -shaped input should ideally have high `B` and low `T` at the early curriculum steps (many short sentences packed together as a batch), and low `B` and high `T` at the late steps (few long sentences in the batch). A dynamic size `T` is already supported by DeepSpeed, e.g. in the documentation for pipeline parallelism's [reset_activation_shape()](https://deepspeed.readthedocs.io/en/stable/pipeline.html#deepspeed.runtime.pipe.engine.PipelineEngine.reset_activation_shape):

> For curriculum learning that changes the seqlen of each sample, we need to call this whenever the seqlen is going to change.

However, a dynamic `B` is not supported. A dynamic `B` requires an adequate increase/decrease of the learning rate. This technique has been applied previously, and the two most common LR scaling algorithms are:

1. Linear Scaling Rule: "When the minibatch size is multiplied by k, multiply the learning rate by k", as in [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour, Goyal et al.](https://arxiv.org/abs/1706.02677)
2. Square Root scaling: "when multiplying the batch size by k, multiply the learning rate by √k, to keep the variance in the gradient expectation constant", as in [One weird trick for parallelizing convolutional neural networks, A. Krizhevsky](https://arxiv.org/abs/1404.5997)

In practice, the user picks the total token count per batch as the metric that drives batching, instead of batching by sentence count. During runtime, the variable batch size is computed and the LR is adjusted accordingly, based on the LR and batch size provided by the config.
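For intuition, here is a minimal sketch of the two scaling rules (a hypothetical helper, not the PR's actual `scale_lr` implementation):

```python
# Minimal sketch of the LR scaling rules (hypothetical helper; the PR's logic lives in `scale_lr`).
# `base_lr` and `base_batch_size` are the reference values from the deepspeed config.
def scaled_lr(base_lr: float, base_batch_size: int, batch_size: int, method: str = "linear") -> float:
    k = batch_size / base_batch_size
    if method == "linear":   # Goyal et al.: multiply the LR by k
        return base_lr * k
    if method == "sqrt":     # Krizhevsky: multiply the LR by sqrt(k)
        return base_lr * k ** 0.5
    return base_lr           # no scaling

# e.g. with base_lr=1e-3 and base_batch_size=2, a batch of 10 samples gives 5e-3 (linear) or ~2.2e-3 (sqrt)
print(scaled_lr(1e-3, 2, 10, "linear"), scaled_lr(1e-3, 2, 10, "sqrt"))
```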
# Illustration of dynamic batch size, sequence length and LR

Imagine we picked a limit of `30` tokens per batch, and have set a reference `lr=1e-3` for a `train_batch_size=2` (in the deepspeed config). The batching algorithm for curriculum may pack the data into batches of short sentences (left) at the early stages, and batches of long sentences (right) at the later stages, e.g.:

![dynamic_batch_size_and_lr](https://github.com/microsoft/DeepSpeed/assets/150697676/324bda09-8f0b-430c-bb33-cc1bd01c3fe7)

Above, we collected samples until we filled up the batch with at most 30 tokens. The batch sizes (number of samples) then became `10` and `4` in the left and right examples, respectively. Using the linear scaling rule, the LR for those batches becomes `5e-3` and `2e-3`.

# Pipeline parallelism

Pipeline parallelism requires the same batch size and the same sequence length across all micro-batches in a batch, as the activation sizes must be fixed between gradient accumulation steps. Between batches, these may change, as long as `engine.reset_activation_shape()` is called so that the new shapes are communicated on the first gradient accumulation step of the batch. Enforcing a similar `BxTxE` between batches may lead to smaller micro-batches. As an example, below is an illustration of a 2-node, 2-gradient-accumulation-step (i.e. 4 micro-batches) batching for the same dataset, when preparing data for the regular DDP use case (left) and for the pipeline parallelism use case (right):

![dynamic_batch_size_and_lr_microbatching](https://github.com/microsoft/DeepSpeed/assets/150697676/3fed5e1c-f2f5-4efe-a9c5-5b5e20719d45)

We can see that the pipeline use case (right) has the same `BxTxE` shape across all 4 micro-batches in the same batch, and, in order to respect that, it packs fewer samples per batch than the standard use case (left-hand side).

# Attention Head

For an input of size `BxTxE`, the attention has a shape of `TxT` for a mask of fixed size across samples of the same size, or `BxTxT` for a different mask per sample (when samples have different sizes, as in the dataset above). This 3D attention matrix can be illustrated for the DDP micro-batch 1 (picture above, top-left, 4 sentences) as:

![dynamic_batch_size_and_lr_attn_matrix](https://github.com/microsoft/DeepSpeed/assets/150697676/707d2f17-66da-4034-8a12-a87df2044bfb)

Note the memory savings: the attention head has a size of `BxTxT`, i.e. a linear memory dependency on the batch size `B` and a quadratic memory dependency on the largest sequence length `T` in the (micro-)batch. Thus, supporting a dynamic size `T` allows for an increase of `B`.
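To make the `BxTxT` mask concrete, below is a minimal PyTorch sketch (illustrative only, not the PR's code) that builds one boolean attention mask per sample from the sequence lengths of a padded micro-batch:

```python
import torch

# Illustrative sketch (not the PR's code): per-sample BxTxT attention masks for a padded micro-batch,
# where T is the longest sequence in the micro-batch and padded positions must not be attended to.
def build_attention_masks(seqlens: list) -> torch.Tensor:
    B, T = len(seqlens), max(seqlens)
    valid = torch.arange(T)[None, :] < torch.tensor(seqlens)[:, None]  # BxT, True for real tokens
    return valid[:, None, :] & valid[:, :, None]                       # BxTxT, True where both positions are real

masks = build_attention_masks([4, 2, 3])  # shape 3x4x4: memory is linear in B and quadratic in T
```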
# PR overview

This PR implements dynamic batching and LR scaling. The necessary dataloader and LR scheduler can be retrieved by calling `get_dataloader_and_lr_scheduler_for_variable_batch_size`. A small explanation of that function follows:
- The logic behind the algorithms for LR scaling is in `scale_lr`;
- The partitioning of samples into batches is done by `batch_by_seqlen`;
- For pipeline parallelism, all micro-batches in a pipeline pass are required to have the same activation shapes. This is enabled by setting the following parameters to `True`:
  - `required_microbatches_of_same_sizes`, which forces the `B` dimension to be the same across all gradient accumulation steps of all dataloaders in a batch;
  - `required_microbatches_of_same_lengths`, which forces the `T` dimension to be the same across all gradient accumulation steps. This works by calling the user-provided `sample_padding_fn(sentence, len)` that pads a given sentence to the argument length;
- `batch_by_seqlen` returns `microbatch_sample_ids` (the list of sample ids per micro-batch), `batch_sizes` (the effective size of each batch), and `batch_max_seqlens` (the longest sequence across all micro-batches in a batch);
- `dataloader_for_variable_batch_size` relies on `microbatch_sample_ids` and will iterate/collate/pad samples for every batch, returning a dataloader that iterates the final (variable-size) batches;
- `lr_scheduler_for_variable_batch_size` relies on `batch_sizes` to compute the learning rate for each effective batch, taking into account the batch size and LR in the config file, and scaling the LR based on the size of each effective batch and the scaling rule mentioned above (linear, square root, etc.);
- Note that the returned `lr_scheduler` accepts either:
  1. a user-provided `Optimizer`, whose learning rates (in the param groups) will be scaled at every batch, or
  2. a user-defined `LRScheduler`, in which case the learning rate is first taken from the scheduler and then scaled accordingly.

# Example

An example for the use case with and without pipelining is provided in [`DeepSpeedExamples/training/data_efficiency/variable_batch_size_and_lr/variable_batch_size_and_lr_example.py`](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/data_efficiency/variable_batch_size_and_lr). The example shows an attention head with attention of variable-sized `BxTxT` per batch, followed by a fixed-size feed-forward network. These are the main blocks in a Large Language Model. The feed-forward (or linear) layer that follows the attention head requires a constant input size, equivalent to the largest sentence in the whole dataset, so the output of the attention must be padded (see `feedforward: needs to convert BxTxE to BxMxE by padding extra tokens` in the code).
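For orientation, a minimal training-loop sketch of how the returned objects would be used (variable and argument names are illustrative; the exact call is shown in the example file and in the "Future work" snippet below):

```python
# Illustrative sketch: drive training with the variable-size batches and the rescaling LR scheduler.
# `dataloader` and `lr_scheduler` are the objects returned by
# `get_dataloader_and_lr_scheduler_for_variable_batch_size_deepspeed(...)` (arguments elided here).
engine.lr_scheduler = lr_scheduler            # let the DeepSpeed engine step the scheduler
for inputs, labels in dataloader:             # each batch packs sentences up to `max_tokens` tokens
    loss = criterion(engine(inputs), labels)  # forward pass on a padded, variable-size batch
    engine.backward(loss)
    engine.step()                             # optimizer step; the LR is rescaled for the next batch size
```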
# Config

The example file also annotates the relevant deepspeed config with comments:

```python
config = {
    "train_batch_size": 16,
    # `train_micro_batch_size_per_gpu` tells how many sequence packs of `max_tokens` each will be collated together.
    # I.e. the number of tokens per micro-batch (i.e. per gpu iteration) is `train_micro_batch_size_per_gpu`*`max_tokens`.
    "train_micro_batch_size_per_gpu": 2,
    "data_efficiency": {
        "enabled": True,
        # seed to be applied to all data efficiency modules, including dynamic batching
        "seed": 42,
        "data_sampling": {
            "num_workers": 0,  # dataloader num_workers argument
            "pin_memory": False,  # dataloader pin_memory argument
            "dynamic_batching": {
                # enables or disables dynamic batching
                "enabled": True,
                # how many tokens we need to fill a pack of sequences (that will be collated together as a sample)
                "max_tokens": 100,
                # Input/output path to read from or write to the length of every sequence.
                # Sequence lengths will be loaded from: {metrics_path}/seqlen/seqlen_sample_to_metric.bin and *.idx
                # If the files don't exist, they'll be computed and saved on the first run, and loaded on subsequent runs.
                "metrics_path": "./curriculum_output/",
                # As the batch size increases/decreases, which method should be used to scale the LR accordingly?
                # Options: linear, sqrt (square root), or None to disable
                "lr_scaling_method": "linear",
                # how to pick sentences to be packed into samples:
                # - dataloader: by the same order as they come in with the dataloader
                # - seqlen: by sequence length (shortest to longest)
                # - random: random order using the seed in config['data_efficiency']['seed']
                "sentence_picking_order": "dataloader",  # "random" / "seqlen" / "dataloader"
                # minimum number of sequences required to reach `max_tokens`. If the sentence pack is smaller, it's discarded.
                "min_batch_size": 1,
                # maximum number of sequences allowed to reach `max_tokens`. If the sentence pack is larger, it's discarded.
                "max_batch_size": 10,
                # enable the output of micro-batching information about sentence packing
                "verbose": True,
            },
        },
    },
}
```

# Future work

A follow-up PR will enable dynamic batching when calling `deepspeed.initialize`. I.e. instead of this:

```python
engine, _, _, _ = deepspeed.initialize(config=config, model=model)
dataloader, lr_scheduler, _ = get_dataloader_and_lr_scheduler_for_variable_batch_size_deepspeed(...)
engine.lr_scheduler = lr_scheduler
```

we'd ideally have this:

```python
engine, _, dataloader, lr_scheduler = deepspeed.initialize(config=config, model=model)
```

where `initialize` will internally call `get_dataloader_and_lr_scheduler_for_variable_batch_size_deepspeed`.

---------

Signed-off-by: Bruno Magalhaes <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
1 parent ac295aa commit 20f988e

7 files changed (+553, -11 lines)


deepspeed/launcher/runner.py

+1-1
```diff
@@ -483,7 +483,7 @@ def main(args=None):
             result = subprocess.check_output(hostname_cmd)
         except subprocess.CalledProcessError as err:
             logger.error(
-                "Unable to detect suitable master address via `hostname -I`, please manually specify one via --master_addr"
+                "Unable to detect suitable master address via 'hostname -I', please manually specify one via --master_addr"
             )
             raise err
         args.master_addr = result.decode('utf-8').split()[0]
```

deepspeed/runtime/config.py

-1
```diff
@@ -801,7 +801,6 @@ def __init__(self, config: Union[str, dict], mpu=None, mesh_device=None):

     def _initialize_params(self, param_dict):
         self.train_batch_size = get_train_batch_size(param_dict)
-        #print(f"beginning get_train_batch_size = {get_train_batch_size}")
         self.train_micro_batch_size_per_gpu = get_train_micro_batch_size_per_gpu(param_dict)
         self.gradient_accumulation_steps = get_gradient_accumulation_steps(param_dict)
         self.steps_per_print = get_steps_per_print(param_dict)
```

deepspeed/runtime/data_pipeline/config.py

+31-6
```diff
@@ -20,7 +20,6 @@ def get_data_efficiency_config(param_dict):
     sub_param_dict = param_dict[DATA_EFFICIENCY]
     output[DATA_SAMPLING] = get_data_sampling(sub_param_dict)
     output[DATA_ROUTING] = get_data_routing(sub_param_dict)
-
     return output


@@ -39,15 +38,14 @@ def get_data_efficiency_seed(param_dict):


 def get_data_sampling(param_dict):
-    output = {}
+    sub_param_dict = param_dict.get(DATA_SAMPLING, {})
+    output = copy.copy(sub_param_dict)
     output[DATA_SAMPLING_ENABLED] = get_data_sampling_enabled(param_dict)
     output[DATA_SAMPLING_NUM_EPOCHS] = get_data_sampling_num_epochs(param_dict)
     output[DATA_SAMPLING_NUM_WORKERS] = get_data_sampling_num_workers(param_dict)
-    if DATA_SAMPLING not in param_dict.keys():
-        param_dict[DATA_SAMPLING] = {}
-    sub_param_dict = param_dict[DATA_SAMPLING]
+    output[DATA_SAMPLING_PIN_MEMORY] = get_data_sampling_pin_memory(param_dict)
     output[CURRICULUM_LEARNING] = get_curriculum_learning(sub_param_dict)
-
+    output[DYNAMIC_BATCHING] = get_dynamic_batching(sub_param_dict)
     return output


@@ -73,6 +71,13 @@ def get_data_sampling_num_workers(param_dict):
     return DATA_SAMPLING_NUM_WORKERS_DEFAULT


+def get_data_sampling_pin_memory(param_dict):
+    if DATA_SAMPLING in param_dict.keys():
+        return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_PIN_MEMORY, DATA_SAMPLING_PIN_MEMORY_DEFAULT)
+    else:
+        return DATA_SAMPLING_PIN_MEMORY_DEFAULT
+
+
 def get_curriculum_learning(param_dict):
     output = {}
     output[CURRICULUM_LEARNING_ENABLED] = get_curriculum_learning_enabled(param_dict)

@@ -87,6 +92,26 @@ def get_curriculum_learning(param_dict):
     return output


+def get_dynamic_batching(param_dict):
+    output = copy.copy(param_dict.get(DYNAMIC_BATCHING, {}))
+    output[DYNAMIC_BATCHING_ENABLED] = bool(output.get(DYNAMIC_BATCHING_ENABLED, DYNAMIC_BATCHING_ENABLED_DEFAULT))
+    output[DYNAMIC_BATCHING_LR_SCALING_METHOD] = str(
+        output.get(DYNAMIC_BATCHING_LR_SCALING_METHOD, DYNAMIC_BATCHING_LR_SCALING_METHOD_DEFAULT))
+    output[DYNAMIC_BATCHING_MIN_BATCH_SIZE] = int(
+        output.get(DYNAMIC_BATCHING_MIN_BATCH_SIZE, DYNAMIC_BATCHING_MIN_BATCH_SIZE_DEFAULT))
+    output[DYNAMIC_BATCHING_MAX_BATCH_SIZE] = int(output[DYNAMIC_BATCHING_MAX_BATCH_SIZE]) \
+        if DYNAMIC_BATCHING_MAX_BATCH_SIZE in output.keys() \
+        else DYNAMIC_BATCHING_MAX_BATCH_SIZE_DEFAULT
+    output[DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER] = str(
+        output.get(DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER, DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER_DEFAULT))
+    if output[DYNAMIC_BATCHING_ENABLED]:
+        assert DYNAMIC_BATCHING_MAX_TOKENS in output.keys(
+        ), f"Dynamic batching is enabled, so {DYNAMIC_BATCHING_MAX_TOKENS} must be specified"
+        output[DYNAMIC_BATCHING_MAX_TOKENS] = int(output[DYNAMIC_BATCHING_MAX_TOKENS])
+    output[DYNAMIC_BATCHING_VERBOSE] = bool(output.get(DYNAMIC_BATCHING_VERBOSE, False))
+    return output
+
+
 def get_curriculum_learning_enabled(param_dict):
     if CURRICULUM_LEARNING in param_dict.keys():
         return get_scalar_param(param_dict[CURRICULUM_LEARNING], CURRICULUM_LEARNING_ENABLED,
```

deepspeed/runtime/data_pipeline/constants.py

+20
```diff
@@ -22,6 +22,8 @@
 DATA_SAMPLING_NUM_EPOCHS_DEFAULT = 1000
 DATA_SAMPLING_NUM_WORKERS = "num_workers"
 DATA_SAMPLING_NUM_WORKERS_DEFAULT = 0
+DATA_SAMPLING_PIN_MEMORY = "pin_memory"
+DATA_SAMPLING_PIN_MEMORY_DEFAULT = False

 #########################################
 # Data efficiency - Data Sampling - Curriculum Learning

@@ -62,6 +64,24 @@
 CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION = "data_cluster_current_position"
 CURRICULUM_LEARNING_NP_RNG_STATE = "np_rng_state"

+#########################################
+# Data efficiency - Dynamic batching and LR scaling
+#########################################
+DYNAMIC_BATCHING = "dynamic_batching"
+DYNAMIC_BATCHING_ENABLED = "enabled"
+DYNAMIC_BATCHING_ENABLED_DEFAULT = False
+DYNAMIC_BATCHING_METRICS_PATH = "metrics_path"
+DYNAMIC_BATCHING_LR_SCALING_METHOD = "lr_scaling_method"  # "linear" / "sqrt" / "none"
+DYNAMIC_BATCHING_LR_SCALING_METHOD_DEFAULT = "linear"
+DYNAMIC_BATCHING_MIN_BATCH_SIZE = "min_batch_size"
+DYNAMIC_BATCHING_MIN_BATCH_SIZE_DEFAULT = 1
+DYNAMIC_BATCHING_MAX_BATCH_SIZE = "max_batch_size"
+DYNAMIC_BATCHING_MAX_BATCH_SIZE_DEFAULT = None
+DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER = "sequence_picking_order"  # "random" / "seqlen" / "dataloader"
+DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER_DEFAULT = "dataloader"  # "random" / "seqlen" / "dataloader"
+DYNAMIC_BATCHING_MAX_TOKENS = "max_tokens"
+DYNAMIC_BATCHING_VERBOSE = "verbose"
+
 #########################################
 # Curriculum Learning legacy implementation
 #########################################
```
deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py

+6-1
```diff
@@ -862,8 +862,13 @@ def test_compare_both_data_analyzers(dataset):
     for path in output_paths:
         with open(os.path.join(da.save_path, path), 'rb') as f1, \
             open(os.path.join(dda.save_path, path), 'rb') as f2:
-            if f1.read() != f2.read():
+            # if files have suffix .bin, they should be identical
+            if path.endswith(".bin"):
+                assert f1.read() == f2.read(), f"files {path} are not identical."
+            elif f1.read() != f2.read():
                 print(f"files {path} are not identical.")
+    dist.barrier()
+    dist.destroy_process_group()


 if __name__ == "__main__":
```
