
Commit 0bb9cf0

kashif and sayakpaul committed
[Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)
fix fp16 training

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 4c7e983 commit 0bb9cf0

File tree

1 file changed: +9 -1 lines changed

examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py

+9 -1
@@ -527,9 +527,17 @@ def deepspeed_zero_init_disabled_context_manager():
 
     # lora attn processor
     prior_lora_config = LoraConfig(
-        r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
     )
+    # Add adapter and make sure the trainable params are in float32.
     prior.add_adapter(prior_lora_config)
+    if args.mixed_precision == "fp16":
+        for param in prior.parameters():
+            # only upcast trainable parameters (LoRA) into fp32
+            if param.requires_grad:
+                param.data = param.to(torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
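
The fix does two things: it passes lora_alpha explicitly (set equal to the rank, so the effective LoRA scaling alpha/r stays 1) and, under fp16 mixed precision, it upcasts only the trainable LoRA parameters to float32 so the optimizer never steps on fp16 weights. Below is a minimal, self-contained sketch of that pattern; ToyPrior, rank, and mixed_precision are hypothetical stand-ins for the Wuerstchen prior and the script's CLI arguments, and it uses peft's get_peft_model rather than the diffusers add_adapter call shown in the diff.

```python
# Minimal sketch of the pattern in this commit, not the actual training script.
# ToyPrior, rank, and mixed_precision are hypothetical stand-ins.
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model


class ToyPrior(nn.Module):
    """Tiny module with attention-style projection names so LoRA can target them."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

    def forward(self, x):
        return self.to_q(x) + self.to_k(x) + self.to_v(x)


rank = 4
mixed_precision = "fp16"  # mirrors args.mixed_precision in the script

# Base weights in half precision, as under fp16 mixed-precision training.
prior = ToyPrior().to(torch.float16)

# LoRA config matching the commit: lora_alpha is set explicitly (here to the rank).
lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,
    target_modules=["to_q", "to_k", "to_v"],
)
prior = get_peft_model(prior, lora_config)

# Upcast only the trainable (LoRA) parameters to fp32; the frozen fp16 base
# weights are left untouched. This is the loop the commit adds to fix fp16 training.
if mixed_precision == "fp16":
    for param in prior.parameters():
        if param.requires_grad:
            param.data = param.to(torch.float32)

trainable = [p for p in prior.parameters() if p.requires_grad]
print(f"{sum(p.numel() for p in trainable)} trainable params, dtype {trainable[0].dtype}")
```

In the actual script the same loop runs over prior.parameters() right after prior.add_adapter(prior_lora_config), as shown in the diff above.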

0 commit comments