File tree 1 file changed +3
-3
lines changed
1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -272,7 +272,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
272
272
# 学習を準備する:モデルを適切な状態にする
273
273
if args .gradient_checkpointing :
274
274
unet .enable_gradient_checkpointing ()
275
- train_unet = args .learning_rate > 0
275
+ train_unet = args .learning_rate != 0
276
276
train_text_encoder1 = False
277
277
train_text_encoder2 = False
278
278
@@ -284,8 +284,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
284
284
text_encoder2 .gradient_checkpointing_enable ()
285
285
lr_te1 = args .learning_rate_te1 if args .learning_rate_te1 is not None else args .learning_rate # 0 means not train
286
286
lr_te2 = args .learning_rate_te2 if args .learning_rate_te2 is not None else args .learning_rate # 0 means not train
287
- train_text_encoder1 = lr_te1 > 0
288
- train_text_encoder2 = lr_te2 > 0
287
+ train_text_encoder1 = lr_te1 != 0
288
+ train_text_encoder2 = lr_te2 != 0
289
289
290
290
# caching one text encoder output is not supported
291
291
if not train_text_encoder1 :
You can’t perform that action at this time.
0 commit comments