
fix(training): lr scheduler doesn't work properly in distributed scenarios #8312


Merged
1 commit merged into huggingface:main from the train_lr branch on May 30, 2024

Conversation

geniuspatrick
Contributor

@geniuspatrick geniuspatrick commented May 29, 2024

What does this PR do?

TL;DR

In a distributed training scenario, passing the argument --num_train_epochs to any of the training scripts disrupts the functioning of the learning rate scheduler. Essentially, the learning rate decays num_processes times slower than expected. Related issues #8236, #3954, and PR #3983 shed further light on this.

Explanation

In our training setup, we rely on accelerate instead of PyTorch's native DistributedSampler when creating the train_dataloader. That is, we create the train_dataloader as if for single-process training and then call accelerator.prepare to shard the samples across processes.

When the training scripts refer to a step, as in lr_warmup_steps, max_train_steps, etc., they mean an optimizer step. In essence, each step consumes num_processes * gradient_accumulation_steps batches of data. In the scripts, the learning rate scheduler is initialized before accelerator.prepare is called, at which point the train_dataloader has not yet sharded the (batched) samples.

To accurately calculate num_update_steps_per_epoch, we need the length of the train_dataloader after distributed sharding. How do we obtain it? Typically, accelerator.prepare replaces train_dataloader.batch_sampler with a BatchSamplerShard, and the length of the sharded train_dataloader (still a DataLoader instance) becomes the length of that BatchSamplerShard. From this we can derive a formula for estimating the length of the sharded train_dataloader, which holds for the current training scripts (where accelerator.prepare is called with no extra arguments).
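Concretely, the estimate boils down to a single ceil division (a minimal sketch; train_dataloader and accelerator are the objects already created in the script):

import math

# Length of the dataloader as created for standalone training, i.e. before accelerator.prepare.
# After prepare, BatchSamplerShard splits these batches across processes, rounding up.
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)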

By accelerate's design, the prepared scheduler calls step() on the underlying scheduler num_processes times at each optimizer step (i.e., once gradient accumulation completes). This is why num_*_steps_for_scheduler must be divided by gradient_accumulation_steps and multiplied by num_processes.
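Putting the two adjustments together, the step counts handed to get_scheduler look like this (matching the snippet quoted later in this thread):

num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
# The prepared scheduler steps num_processes times per optimizer step, so both counts are scaled up.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
num_training_steps_for_scheduler = (
    args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)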

Feeling a bit confused? Not to worry, let's visualize it.

Experiments

We use the "Fine-tuning for text2image with LoRA" example. Below are the training commands:

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export DATASET_NAME="lambdalabs/naruto-blip-captions"

# Example of --num_train_epochs
accelerate launch examples/text_to_image/train_text_to_image_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$DATASET_NAME \
  --resolution=128 --random_flip --max_train_samples=171 \
  --train_batch_size=4 \
  --num_train_epochs=6 \
  --learning_rate=1e-04 --lr_scheduler="cosine_with_restarts" --lr_warmup_steps=3 \
  --gradient_accumulation_steps=5 \
  --seed=42 \
  --output_dir="sd-pokemon-model-lora-epoch"

# Example of --max_train_steps
accelerate launch examples/text_to_image/train_text_to_image_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$DATASET_NAME \
  --resolution=128 --random_flip --max_train_samples=171 \
  --train_batch_size=4 \
  --max_train_steps=30 \
  --learning_rate=1e-04 --lr_scheduler="cosine_with_restarts" --lr_warmup_steps=3 \
  --gradient_accumulation_steps=5 \
  --seed=42 \
  --output_dir="sd-pokemon-model-lora-step"

The hyper-parameters are:

  • Batch size: 4
  • Number of processes (n_gpus): 2
  • Dataset length: 171
  • Gradient accumulation steps: 5

Thus,

len_dataloader_standalone = ceil(171/4) = 43
len_dataloader_distributed = ceil(43/2) = 22
num_update_steps_per_epoch = ceil(22/5) = 5

And epochs=6 is equivalent to steps=30.
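A quick sanity check of these numbers in plain Python:

import math

num_processes, batch_size, dataset_len, grad_accum = 2, 4, 171, 5
len_dataloader_standalone = math.ceil(dataset_len / batch_size)                    # 43
len_dataloader_distributed = math.ceil(len_dataloader_standalone / num_processes)  # 22
num_update_steps_per_epoch = math.ceil(len_dataloader_distributed / grad_accum)    # 5
assert 6 * num_update_steps_per_epoch == 30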

Additionally, introducing the argument num_cycles=2 to the function get_scheduler exacerbates the error.
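For reference, the scheduler in these runs is created roughly like this (a sketch; optimizer and the two num_*_steps_for_scheduler values are assumed to come from the script, and the keyword names follow diffusers.optimization.get_scheduler):

from diffusers.optimization import get_scheduler

lr_scheduler = get_scheduler(
    "cosine_with_restarts",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps_for_scheduler,
    num_training_steps=num_training_steps_for_scheduler,
    num_cycles=2,  # extra restarts make a mis-scaled schedule easier to spot on the lr curve
)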

Before the PR

--num_train_epochs
[Screenshot 2024-05-29 18:32:08]

--max_train_steps
[Screenshot 2024-05-29 18:34:14]

After the PR

--num_train_epochs
[Screenshot 2024-05-29 18:34:26]

--max_train_steps
[Screenshot 2024-05-29 18:34:40]

Fixes #8236


Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @eliphatfs

@geniuspatrick geniuspatrick changed the title fix(training): lr scheduler doesn't work properly [WIP] fix(training): lr scheduler doesn't work properly May 29, 2024
Member

@sayakpaul sayakpaul left a comment


I just have minor comments but this is very very nicely done. Thanks so much!

Comment on lines 687 to 689
"The length of the 'train_dataloader' after 'accelerator.prepare' does not match "
"the length that was expected when the learning rate scheduler was created. "
"This inconsistency may result in the learning rate scheduler not functioning properly."
Member


Should we also include the values of "The length of the 'train_dataloader'" and "the length that was expected when the learning rate scheduler was created"?

Contributor


Do we have drop_last settings or similar that may cause this to happen?

Contributor Author


Should we also include the values of "The length of the 'train_dataloader'" and "the length that was expected when the learning rate scheduler was created"?

Yep, the values are included!

Contributor Author


Do we have drop_last settings or similar that may cause this to happen?

For all of the current training scripts, the answer is no. Our estimate of the length of the sharded dataloader is always correct, and this warning message is never triggered.


@geniuspatrick geniuspatrick force-pushed the train_lr branch 3 times, most recently from 4b00cbf to f6742ea Compare May 30, 2024 05:08
@geniuspatrick geniuspatrick changed the title [WIP] fix(training): lr scheduler doesn't work properly [WIP] fix(training): lr scheduler doesn't work properly in distributed scenarios May 30, 2024
@sayakpaul
Member

Sorry for pinging late but @geniuspatrick could we keep the changes in this PR to a bare minimum i.e., targeting a single script only and then opening the rest to the community? That will be easier to manage IMO.

@geniuspatrick
Contributor Author

Sorry for pinging late but @geniuspatrick could we keep the changes in this PR to a bare minimum i.e., targeting a single script only and then opening the rest to the community? That will be easier to manage IMO.

OK, I'll change the script examples/text_to_image/train_text_to_image_lora.py only.

@geniuspatrick geniuspatrick force-pushed the train_lr branch 3 times, most recently from 60f9f39 to 92f7262 Compare May 30, 2024 08:16

@geniuspatrick geniuspatrick changed the title [WIP] fix(training): lr scheduler doesn't work properly in distributed scenarios fix(training): lr scheduler doesn't work properly in distributed scenarios May 30, 2024
@geniuspatrick
Contributor Author

Hi, @sayakpaul . I think it's ready now. Any suggestions?

Member

@sayakpaul sayakpaul left a comment


Thanks a ton!

@sayakpaul sayakpaul merged commit 3511a96 into huggingface:main May 30, 2024
8 checks passed
@geniuspatrick
Contributor Author

Hi @sayakpaul, here's a TODO list for follow-up contributions from the community.

What should be changed

  • advanced_diffusion_training
    • train_dreambooth_lora_sd15_advanced.py
    • train_dreambooth_lora_sdxl_advanced.py
  • consistency_distillation
    • train_lcm_distill_lora_sdxl.py
  • controlnet
    • train_controlnet.py
    • train_controlnet_sdxl.py
  • custom_diffusion
    • train_custom_diffusion.py
  • dreambooth
    • train_dreambooth.py
    • train_dreambooth_lora.py
    • train_dreambooth_lora_sdxl.py
  • instruct_pix2pix
    • train_instruct_pix2pix.py
    • train_instruct_pix2pix_sdxl.py
  • kandinsky2_2/text_to_image
    • train_text_to_image_decoder.py
    • train_text_to_image_prior.py
    • train_text_to_image_lora_decoder.py
    • train_text_to_image_lora_prior.py
  • research_projects
    • consistency_training/train_cm_ct_unconditional.py
    • diffusion_dpo/train_diffusion_dpo.py
    • diffusion_dpo/train_diffusion_dpo_sdxl.py
    • diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
    • dreambooth_inpaint/train_dreambooth_inpaint.py
    • dreambooth_inpaint/train_dreambooth_inpaint_lora.py
    • instructpix2pix_lora/train_instruct_pix2pix_lora.py
    • intel_opts/textual_inversion/textual_inversion_bf16.py
    • intel_opts/textual_inversion_dfq/textual_inversion.py
    • lora/train_text_to_image_lora.py
    • multi_subject_dreambooth/train_multi_subject_dreambooth.py
    • multi_token_textual_inversion/textual_inversion.py
    • onnxruntime/text_to_image/train_text_to_image.py
    • onnxruntime/textual_inversion/textual_inversion.py
    • onnxruntime/unconditional_image_generation/train_unconditional.py
    • realfill/train_realfill.py
    • scheduled_huber_loss_training/dreambooth/train_dreambooth.py
    • scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
    • scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
    • scheduled_huber_loss_training/text_to_image/train_text_to_image.py
    • scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
    • scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
    • scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
  • t2i_adapter
    • train_t2i_adapter_sdxl.py
  • text_to_image
    • train_text_to_image.py
    • train_text_to_image_sdxl.py
    • train_text_to_image_lora.py
    • train_text_to_image_lora_sdxl.py
  • textual_inversion
    • textual_inversion.py
    • textual_inversion_sdxl.py
  • unconditional_image_generation
    • train_unconditional.py
  • wuerstchen
    • text_to_image/train_text_to_image_prior.py
    • text_to_image/train_text_to_image_lora_prior.py

What should NOT be changed

Category 1

The script does not have the argument --num_train_epochs.

  • amused
    • train_amused.py
  • research_projects
    • multi_subject_dreambooth_inpainting/train_multi_subject_dreambooth_inpainting.py

Category 2

Distributed dataset sharding is done by WebDataset, not accelerator.

  • consistency_distillation
    • train_lcm_distill_sd_wds.py
    • train_lcm_distill_sdxl_wds.py
    • train_lcm_distill_lora_sd_wds.py
    • train_lcm_distill_lora_sdxl_wds.py
  • research_projects
    • controlnet/train_controlnet_webdataset.py
    • diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py

BTW, if you need extra hands, I'd be happy to help!

@sayakpaul
Member

Great! Thank you so much, @geniuspatrick!

I will create an issue similar to #6545 so that the community can easily pick them up.

@AbraarArique

Thank you @sayakpaul and @geniuspatrick for fixing this, much appreciated!

But I have a quick question: why is lr_warmup_steps multiplied by num_processes?

# Line 1092
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes

For example, say my dataloader length (i.e., steps per epoch) is 64, and for simplicity let's say I want to warm up for 32 steps.

With 1 epoch and gradient accumulation steps = 1, the current code works correctly for 1 GPU:

num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes # 32 * 1 = 32
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) # 64 / 1 = 64
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) # 64 / 1 = 64
num_training_steps_for_scheduler = (
    args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes # 1 * 64 * 1 = 64
)

With 2 GPUs, however, num_training_steps_for_scheduler stays the same (64) while num_warmup_steps_for_scheduler doubles to 64.

num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes # 32 * 2 = 64
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) # 64 / 2 = 32
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) # 32 / 1 = 32
num_training_steps_for_scheduler = (
    args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes # 1 * 32 * 2 = 64
)

Shouldn't num_warmup_steps_for_scheduler still be 32? Sorry if I'm missing something...

@eliphatfs
Contributor

In each gradient step, the lr scheduler is advanced by num_processes steps in accelerator, if my memory serves me right. This is counterintuitive.

@AbraarArique

In each gradient step, the lr scheduler is advanced by num_processes steps in accelerator, if my memory serves me right. This is counterintuitive.

@eliphatfs Right, but to clarify, the num_training_steps_for_scheduler doesn't change much with the number of processes/GPUs, does it?

Regardless of 1, 2, or more GPUs, the num_training_steps_for_scheduler is still 64 in my example above.

So assuming I want to warm up for half of all steps, the num_warmup_steps_for_scheduler should be 32 regardless of the number of processes.

But the current code scales num_warmup_steps_for_scheduler by num_processes while the total num_training_steps_for_scheduler doesn't change. Shouldn't that break things?

@bghira
Contributor

bghira commented Sep 2, 2024

this is actually incorrect...

[screenshot: learning rate plot]

The disconnected line is from applying * accelerator.num_processes to the warmup steps and resuming at 1000 steps with 3 GPUs.

The learning rate is resumed at the point where it would have been at T=333.

Removing the multiplication fixes the issue.

cc @linoytsaban @sayakpaul

@sayakpaul
Member

sayakpaul commented Sep 2, 2024

Perhaps you could provide a little more explanation here? From the snapshot, it's not immediately clear to me.

Update: I see what you mean. Yeah, IIUC, the multiplication (num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes) should be removed. PR?

Cc: @geniuspatrick

@bghira
Contributor

bghira commented Sep 2, 2024

It might be a bug in Accelerate, actually. The LR scheduler state isn't being restored correctly when training restarts. I then set it manually but didn't multiply by num_processes, which led to a mismatched LR on resume. But I think that when both are multiplied correctly, this issue does not manifest. I'm still trying to test it so I can assess the issue on Accelerate's side, since I shouldn't (and didn't used to) have to set last_epoch manually.

@sayakpaul
Member

Cc: @muellerzr

@geniuspatrick
Contributor Author

@AbraarArique In the training scripts, we have two arguments that control the total number of training steps: args.num_train_epochs and args.max_train_steps (which has higher priority, btw). In your example, you seem to be controlling the number of training steps through args.num_train_epochs. But note that as the number of GPUs changes, so does the number of training steps per epoch.

The argument args.lr_warmup_steps is easier to reason about together with args.max_train_steps. When the number of GPUs changes from 1 to 2 and you keep args.num_train_epochs=1, the effective args.max_train_steps drops from 64 to 32, so you should adjust args.lr_warmup_steps accordingly.
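To make the numbers concrete, reusing the 64-batch example above:

1 GPU:  update steps per epoch = 64 / 1 = 64  ->  warming up for half the run means lr_warmup_steps = 32
2 GPUs: update steps per epoch = 64 / 2 = 32  ->  warming up for half the run means lr_warmup_steps = 16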

Hopefully the rather redundant explanation above will help your understanding.

@geniuspatrick
Contributor Author

geniuspatrick commented Sep 2, 2024

It might be a bug in Accelerate, actually. The LR scheduler state isn't being restored correctly when training restarts. I then set it manually but didn't multiply by num_processes, which led to a mismatched LR on resume. But I think that when both are multiplied correctly, this issue does not manifest. I'm still trying to test it so I can assess the issue on Accelerate's side, since I shouldn't (and didn't used to) have to set last_epoch manually.

@bghira Looks like a topic about resume training. Is the scheduler state not being saved and restored correctly? Could the removal of multiplication be just a numerical coincidence? Is it possible to have the same problem with single GPU training?

@bghira
Contributor

bghira commented Sep 2, 2024

On my end, the problem is that we upgraded from accelerate v0.19 to v0.33, and accelerate's load/save state stopped writing the step count for the random states, or stopped restoring it. Either way, I'm on git main now, and I have to check whether lr_scheduler has a last_epoch attribute and set it to resume_step * num_processes.

That fixed my issue for single and multiple GPUs, but I have to leave num_warmup_steps multiplied too.
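A minimal sketch of that workaround (not a recommended fix; resume_step is assumed to be the optimizer step recovered from the checkpoint):

# The prepared scheduler has stepped num_processes times per optimizer step,
# so the scheduler's internal counter is scaled accordingly when resuming.
if hasattr(lr_scheduler, "last_epoch"):
    lr_scheduler.last_epoch = resume_step * accelerator.num_processes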

@muellerzr
Contributor

Will dig into this @bghira

@bghira
Contributor

bghira commented Oct 5, 2024

@muellerzr did you ever find anything? It might just be something I've been doing incorrectly, but I would like to align with best practices.

@Zephyrose

@geniuspatrick @AbraarArique I agree with your point of view; I don't think we should multiply by accelerator.num_processes, because lr_warmup_steps / num_training_steps_for_scheduler is really a ratio. No matter how many GPUs we use, we want that ratio to remain constant. When we add GPUs, we are effectively increasing the batch size: num_training_steps_for_scheduler stays constant, while the actual number of training steps shrinks by a factor of the number of GPUs. If we don't multiply by accelerator.num_processes, we won't have to adjust args.lr_warmup_steps.

@AbraarArique

@Zephyrose I think there are 2 ways of looking at this:

The way it's done now does make sense logically. If 1 epoch is 32 steps with 1 GPU, with 2 GPUs you're doubling the batch size and thus now you have 16 update steps per epoch.

So the lr_warmup_steps being relative to the number of update steps per device/process makes sense: as you add GPUs, you're lowering your update steps and thus lowering your lr_warmup_steps accordingly.

This works fine if you specify max_train_steps, but as you mentioned, if you're training based on epochs, having to scale the lr_warmup_steps down with more GPUs is indeed inconvenient as you often want the ratio to be consistent.

If people care about the warmup-to-total-steps ratio more than a specific number of steps, perhaps it makes sense to have an lr_warmup_ratio parameter instead of manually specifying steps?
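For what it's worth, such a ratio could be wired into the existing computation along these lines (a sketch only; lr_warmup_ratio is a hypothetical argument, not an existing flag):

num_training_steps_for_scheduler = (
    args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
# Hypothetical: derive the warmup from a ratio so it stays proportional regardless of GPU count.
num_warmup_steps_for_scheduler = int(args.lr_warmup_ratio * num_training_steps_for_scheduler)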
