Description
When running step 3 with ZeRO stage 3 enabled for both the actor and critic models, I get the following error (line numbers may be offset because of debug statements I've added):
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 529, in <module>
main()
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 438, in main
out = trainer.generate_experience(prompts)
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 103, in generate_experience
seq = self._generate_sequence(prompts)
File "/path/site-packages/deepspeed/runtime/hybrid_engine.py", line 293, in generate
seq = self.actor_model.module.generate(prompts,
File "/path/site-packages/deepspeed/runtime/hybrid_engine.py", line 293, in generate
self.fuse_lora_weight()
File "/path/site-packages/deepspeed/runtime/hybrid_engine.py", line 128, in _fuse_lora
weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
RuntimeError: The size of tensor a (8192) must match the size of tensor b (2048) at non-singleton dimension 1
This happens because the shape of weight.data does not match the shape of the tensor produced by the LoRA matmul.
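For reference, the mismatch looks like the toy example below. The shapes are only chosen to reproduce the same error message (not a claim about exactly which tensor is stale in the real run); the point is that one operand of the += still has a stale/partitioned shape:

import torch

# Illustrative only: shapes picked to trigger the same RuntimeError as above.
weight = torch.zeros(2048, 8192)            # e.g. a fused projection weight
lora_left_weight = torch.zeros(128, 2048)   # actor_lora_dim=128, as in my run
lora_right_weight = torch.zeros(2048, 128)  # stale second dimension (assumed; should be 8192)
lora_scaling = 1.0

# (2048, 128) @ (128, 2048) -> (2048, 2048), which cannot broadcast against (2048, 8192)
weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
# RuntimeError: The size of tensor a (8192) must match the size of tensor b (2048)
# at non-singleton dimension 1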
I am using a system with 4x 16GB V100 GPUs per node and DeepSpeed 0.9.1. I trained a 1.3b-param model in step 1 and a 350m-param model in step 2.
My step 3 run command launches 4 processes on one node, binding one process per GPU:
cd training/step3_rlhf_finetuning
OUTPUT=${OUTPUTDIR}/step3-models/1.3b
mkdir -p $OUTPUT
ACTOR_MODEL_PATH=${OUTPUTDIR}/actor-models/1.3b
CRITIC_MODEL_PATH=${OUTPUTDIR}/reward-models/1.3b
ACTOR_ZERO_STAGE=3
CRITIC_ZERO_STAGE=3
jsrun -r 1 --tasks_per_rs 4 -c ALL_CPUS -g ALL_GPUS python3 main.py \
--per_device_train_batch_size 4 \
--per_device_mini_train_batch_size 4 \
--inference_tp_size 1 \
--max_answer_seq_len 256 \
--max_prompt_seq_len 256 \
--actor_model_name_or_path $ACTOR_MODEL_PATH \
--critic_model_name_or_path $CRITIC_MODEL_PATH \
--actor_zero_stage $ACTOR_ZERO_STAGE \
--critic_zero_stage $CRITIC_ZERO_STAGE \
--num_padding_at_beginning 1 \
--gradient_accumulation_steps 1 \
--deepspeed \
--actor_lora_dim 128 \
--enable_hybrid_engine \
--actor_gradient_checkpointing \
--critic_gradient_checkpointing \
--output_dir $OUTPUT
After some debugging, I found that the above error arises because the GatheredParameters context does not gather all layers. If I print the tensor shape for each parameter of each layer immediately after entering the GatheredParameters context, like so:
with GatheredParameters(non_active_layers):
    if rank == 0:
        for layer_id in range(len(self.layer_params)):
            for p_id, p in enumerate(self.layer_params[layer_id]):
                print("after gather layer_id", layer_id, p_id, p.shape, flush=True)
    self._gather_latency = time.time() - self._t0
then I see the following output on the step just before the error:
nonactive all layers 931
after gather layer_id 0 0 torch.Size([0])
after gather layer_id 0 1 torch.Size([0])
after gather layer_id 0 2 torch.Size([0])
after gather layer_id 0 3 torch.Size([0])
after gather layer_id 0 4 torch.Size([0])
after gather layer_id 0 5 torch.Size([8192])
after gather layer_id 0 6 torch.Size([2048, 8192])
after gather layer_id 0 7 torch.Size([2048])
after gather layer_id 0 8 torch.Size([0])
after gather layer_id 0 9 torch.Size([0])
after gather layer_id 0 10 torch.Size([0])
after gather layer_id 0 11 torch.Size([0])
after gather layer_id 0 12 torch.Size([0])
after gather layer_id 0 13 torch.Size([0])
after gather layer_id 0 14 torch.Size([0])
after gather layer_id 0 15 torch.Size([0])
after gather layer_id 1 0 torch.Size([2048])
after gather layer_id 1 1 torch.Size([2048])
after gather layer_id 1 2 torch.Size([2048])
after gather layer_id 1 3 torch.Size([2048])
after gather layer_id 1 4 torch.Size([8192, 2048])
after gather layer_id 1 5 torch.Size([8192])
after gather layer_id 1 6 torch.Size([2048, 8192])
after gather layer_id 1 7 torch.Size([2048])
after gather layer_id 1 8 torch.Size([2048, 2048])
after gather layer_id 1 9 torch.Size([2048])
after gather layer_id 1 10 torch.Size([2048, 2048])
after gather layer_id 1 11 torch.Size([2048])
after gather layer_id 1 12 torch.Size([2048, 2048])
after gather layer_id 1 13 torch.Size([2048])
after gather layer_id 1 14 torch.Size([2048, 2048])
after gather layer_id 1 15 torch.Size([2048])
Note that the dimensions of the parameters in layer_id=0 are mostly zero. On steps that complete without an error, those parameters have non-zero shapes, as shown below. The count of non_active_layers is 962 below vs. 931 above.
nonactive all layers 962
after gather layer_id 0 0 torch.Size([2048])
after gather layer_id 0 1 torch.Size([2048])
after gather layer_id 0 2 torch.Size([2048])
after gather layer_id 0 3 torch.Size([2048])
after gather layer_id 0 4 torch.Size([8192, 2048])
after gather layer_id 0 5 torch.Size([8192])
after gather layer_id 0 6 torch.Size([2048, 8192])
after gather layer_id 0 7 torch.Size([2048])
after gather layer_id 0 8 torch.Size([2048, 2048])
after gather layer_id 0 9 torch.Size([2048])
after gather layer_id 0 10 torch.Size([2048, 2048])
after gather layer_id 0 11 torch.Size([2048])
after gather layer_id 0 12 torch.Size([2048, 2048])
after gather layer_id 0 13 torch.Size([2048])
after gather layer_id 0 14 torch.Size([2048, 2048])
after gather layer_id 0 15 torch.Size([2048])
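A cheaper way to catch this than printing every shape would be a check like the following inside the GatheredParameters block (a hypothetical helper, not part of DeepSpeed; it relies on the ds_numel attribute that ZeRO-3 attaches to parameters):

def check_fully_gathered(params, tag=""):
    # Flag any ZeRO-3 parameter whose local tensor does not hold its full ds_numel elements,
    # i.e. parameters that were never actually gathered (they show up above as torch.Size([0])).
    bad = [(i, tuple(p.shape)) for i, p in enumerate(params)
           if hasattr(p, 'ds_id') and p.data.numel() != p.ds_numel]
    if bad:
        raise RuntimeError(f"{tag}: {len(bad)} params not fully gathered, e.g. {bad[:3]}")

Calling check_fully_gathered(self.all_layers_params, "actor") right after entering the context would raise on the failing step above, since several layer 0 tensors are still 0-sized.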
I added the following lines to get more detail:
else:
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    rank = dist.get_rank(group=self.mp_group)
    non_active_layers = get_inactive_params(self.all_layers_params)
    if rank == 0:
        print("nonactive layers", len(non_active_layers))
        for lay_id, lay in enumerate(self.all_layers_params):
            print("all layers", lay_id, hasattr(lay, 'ds_id'), lay.ds_status == ZeroParamStatus.NOT_AVAILABLE, lay.ds_status)
    non_active_lora_params = get_inactive_params(self.all_lora_params)
    if rank == 0:
        print("nonactive lora layers", len(non_active_lora_params))
        for lay_id, lay in enumerate(self.all_lora_params):
            print("lora layers", lay_id, hasattr(lay, 'ds_id'), lay.ds_status == ZeroParamStatus.NOT_AVAILABLE, lay.ds_status)
    non_active_layers.extend(non_active_lora_params)
It seems that the 0-shape parameters are marked with ds_status == ZeroParamStatus.INFLIGHT before GatheredParameters is called:
[2023-04-17 15:33:56,759] [INFO] [loss_scaler.py:181:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 32768, reducing to 16384
epoch: 0|step: 2|ppo_ep: 1|act_loss: nan|cri_loss: nan|unsuper_loss: 0.0
average reward score: 3.267578125
-------------------------------------------------------------------------------------
|E2E latency=17.17s |Gather latency=0.46s (2.70%) |Generate time=7.04s (41.02%) |Training time=6.82s (39.71%) |Others=3.31 (19.27%)|CurSamplesPerSec=0.93 |AvgSamplesPerSec=0.60
nonactive layers 651
all layers 0 True False ZeroParamStatus.INFLIGHT
all layers 1 True False ZeroParamStatus.INFLIGHT
all layers 2 True False ZeroParamStatus.INFLIGHT
all layers 3 True False ZeroParamStatus.INFLIGHT
all layers 4 True False ZeroParamStatus.INFLIGHT
all layers 5 True False ZeroParamStatus.INFLIGHT
all layers 6 True False ZeroParamStatus.INFLIGHT
all layers 7 True False ZeroParamStatus.INFLIGHT
all layers 8 True False ZeroParamStatus.INFLIGHT
all layers 9 True False ZeroParamStatus.INFLIGHT
all layers 10 True False ZeroParamStatus.INFLIGHT
all layers 11 True False ZeroParamStatus.INFLIGHT
all layers 12 True False ZeroParamStatus.INFLIGHT
all layers 13 True False ZeroParamStatus.INFLIGHT
all layers 14 True False ZeroParamStatus.INFLIGHT
all layers 15 True False ZeroParamStatus.INFLIGHT
all layers 16 True False ZeroParamStatus.INFLIGHT
all layers 17 True False ZeroParamStatus.INFLIGHT
all layers 18 True False ZeroParamStatus.INFLIGHT
all layers 19 True False ZeroParamStatus.INFLIGHT
all layers 20 True False ZeroParamStatus.INFLIGHT
all layers 21 True True ZeroParamStatus.NOT_AVAILABLE
all layers 22 True True ZeroParamStatus.NOT_AVAILABLE
all layers 23 True True ZeroParamStatus.NOT_AVAILABLE
all layers 24 True True ZeroParamStatus.NOT_AVAILABLE
all layers 25 True True ZeroParamStatus.NOT_AVAILABLE
all layers 26 True True ZeroParamStatus.NOT_AVAILABLE
all layers 27 True True ZeroParamStatus.NOT_AVAILABLE
all layers 28 True False ZeroParamStatus.INFLIGHT
all layers 29 True False ZeroParamStatus.INFLIGHT
all layers 30 True True ZeroParamStatus.NOT_AVAILABLE
all layers 31 True True ZeroParamStatus.NOT_AVAILABLE
<snip>
nonactive lora layers 280
lora layers 0 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 1 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 2 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 3 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 4 True False ZeroParamStatus.INFLIGHT
lora layers 5 True False ZeroParamStatus.INFLIGHT
lora layers 6 True False ZeroParamStatus.INFLIGHT
lora layers 7 True False ZeroParamStatus.INFLIGHT
lora layers 8 True False ZeroParamStatus.INFLIGHT
lora layers 9 True False ZeroParamStatus.INFLIGHT
lora layers 10 True False ZeroParamStatus.INFLIGHT
lora layers 11 True False ZeroParamStatus.INFLIGHT
lora layers 12 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 13 True True ZeroParamStatus.NOT_AVAILABLE
I think those parameters are marked as INFLIGHT because they have been prefetched.
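To confirm that theory more directly on future runs, I am thinking of tallying parameters by ds_status right before the gather in hybrid_engine.generate (a hypothetical debug helper, not part of DeepSpeed); on the failing step I would expect a nonzero INFLIGHT count here:

from collections import Counter

def count_by_status(params):
    # Tally ZeRO-3 parameters by their ds_status (NOT_AVAILABLE / INFLIGHT / AVAILABLE).
    return Counter(str(p.ds_status) for p in params if hasattr(p, 'ds_id'))

# e.g. inside generate(), right before get_inactive_params/GatheredParameters:
#   print("layer params:", count_by_status(self.all_layers_params))
#   print("lora params :", count_by_status(self.all_lora_params))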
Adding some more debugging lines to print the stack at the point where the status is set to INFLIGHT:
def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = True) -> AllGatherCoalescedHandle:
    # fetches from nvme if the partition is not available and in nvme
    self._ensure_availability_of_partitioned_params(params)
    if self.world_size == 1:
        return _no_gather_coalesced(params)
    #for param in params:
    for p_id, param in enumerate(params):
        if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
            raise RuntimeError(param.ds_summary())
        param.ds_status = ZeroParamStatus.INFLIGHT
        if dist.get_rank() == 0:
            print(p_id, "INFLIGHT2")
            if p_id > 20:
                print(traceback.print_stack(file=sys.stdout))
I can see those layers are set to INFLIGHT here:
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 529, in <module>
main()
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 452, in main
actor_loss, critic_loss = trainer.train_rlhf(exp_data)
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 180, in train_rlhf
value = self.critic_model.forward_value(**batch,
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/model/reward_model.py", line 125, in forward_value
transformer_outputs = self.rwtranrsformer(
File "/path/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 759, in forward
decoder_outputs = self.decoder(
File "/path/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 665, in forward
layer_outputs = torch.utils.checkpoint.checkpoint(
File "/path/site-packages/torch/utils/checkpoint.py", line 235, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/path/site-packages/torch/utils/checkpoint.py", line 96, in forward
outputs = run_function(*args)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 661, in custom_forward
return module(*inputs, output_attentions, None)
File "/path/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 337, in forward
hidden_states = self.activation_fn(hidden_states)
File "/path/site-packages/torch/nn/modules/module.py", line 1137, in _call_impl
result = hook(self, input)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 366, in _pre_forward_module_hook
self.pre_sub_module_forward_function(module)
File "/path/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 478, in pre_sub_module_forward_function
param_coordinator.fetch_sub_module(sub_module)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/path/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 333, in fetch_sub_module
self.__all_gather_params(params_to_prefetch)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 381, in __all_gather_params
handle = partitioned_params[0].all_gather_coalesced(partitioned_params)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 878, in all_gather_coalesced
print(traceback.print_stack(file=sys.stdout))
It seems that the layers are being prefetched during the critic model's forward pass, and they are still in INFLIGHT status when the next attempt to generate a sample begins. The get_inactive_params function only includes params marked as NOT_AVAILABLE, so the subsequent GatheredParameters call never gathers the INFLIGHT params; see the sketch below for my reading of that function.
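For reference, this is roughly what get_inactive_params does as I read it in deepspeed/runtime/utils.py for 0.9.1 (paraphrased, not verbatim): it filters strictly on NOT_AVAILABLE, so prefetched INFLIGHT params are silently skipped:

from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

def get_inactive_params(param_list):
    # Only partitioned (NOT_AVAILABLE) params are returned; params whose async gather is
    # still in flight are skipped, so GatheredParameters never gathers or waits on them.
    return [param for param in param_list
            if hasattr(param, 'ds_id') and param.ds_status == ZeroParamStatus.NOT_AVAILABLE]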
Assuming that diagnosis is correct, I'm not sure what the recommended fix would be. Should get_inactive_params include INFLIGHT params?