Skip to content

Commit 38e4c60

Browse files
authored
Merge pull request kohya-ss#1277 from Cauldrath/negative_learning
Allow negative learning rate
2 parents e4d9e3c + fc37437 commit 38e4c60

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

sdxl_train.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
272272
# 学習を準備する:モデルを適切な状態にする
273273
if args.gradient_checkpointing:
274274
unet.enable_gradient_checkpointing()
275-
train_unet = args.learning_rate > 0
275+
train_unet = args.learning_rate != 0
276276
train_text_encoder1 = False
277277
train_text_encoder2 = False
278278

@@ -284,8 +284,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
284284
text_encoder2.gradient_checkpointing_enable()
285285
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
286286
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
287-
train_text_encoder1 = lr_te1 > 0
288-
train_text_encoder2 = lr_te2 > 0
287+
train_text_encoder1 = lr_te1 != 0
288+
train_text_encoder2 = lr_te2 != 0
289289

290290
# caching one text encoder output is not supported
291291
if not train_text_encoder1:

0 commit comments

Comments
 (0)