
Commit b28675c

AnandK27 and sayakpaul authored
[train_instruct_pix2pix.py] Fix the LR schedulers when num_train_epochs is passed in a distributed training env (#9316)
Fixed the pix2pix LR scheduler.

Co-authored-by: Sayak Paul <[email protected]>
1 parent bd4df28 commit b28675c

File tree

1 file changed: +18 -7 lines

examples/instruct_pix2pix/train_instruct_pix2pix.py

@@ -747,17 +747,22 @@ def collate_fn(examples):
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )
 
     # Prepare everything with our `accelerator`.
@@ -782,8 +787,14 @@ def collate_fn(examples):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
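For context, below is a minimal standalone sketch of the step arithmetic this commit changes. The batch count, number of processes, gradient-accumulation setting, and epoch count are illustrative assumptions, not values taken from the training script.

import math

# Illustrative assumptions (not from the commit): an un-sharded dataloader with
# 1000 batches, 2 processes (GPUs), gradient accumulation of 4, and 1 epoch.
len_train_dataloader = 1000
num_processes = 2
gradient_accumulation_steps = 4
num_train_epochs = 1

# Old behaviour: steps per epoch were computed from the un-sharded dataloader.
old_updates_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)  # 250
old_scheduler_steps = num_train_epochs * old_updates_per_epoch * num_processes  # 500

# New behaviour: account for sharding across processes before doing the epoch math.
len_after_sharding = math.ceil(len_train_dataloader / num_processes)  # 500
new_updates_per_epoch = math.ceil(len_after_sharding / gradient_accumulation_steps)  # 125
new_scheduler_steps = num_train_epochs * new_updates_per_epoch * num_processes  # 250

# After accelerator.prepare, each process performs 125 optimizer steps, and the
# script sizes the scheduler as max_train_steps * num_processes = 250. Under the
# old math the scheduler was sized for 500 steps, so only part of the warmup/decay
# schedule would ever be traversed; the new sizing matches the recalculated value.
print(old_scheduler_steps, new_scheduler_steps)  # 500 250

The warning added in the second hunk flags the remaining mismatch case: if the dataloader length after accelerator.prepare differs from the sharded length estimated when the scheduler was created, the scheduler's total step count will still be off, and the script now logs a warning instead of silently continuing.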
