Commit 7983d3d

Merge pull request #1319 from kohya-ss/fused-backward-pass
Fused backward pass
2 parents: c1ba0b4 + bee8cee

4 files changed, +291 −16 lines changed

README.md (+29)
@@ -139,8 +139,37 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

### Working in progress

- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
  - Memory usage during training is significantly reduced by integrating the optimizer step into the backward pass. Training results are unchanged, but even with plenty of memory, training is slower than before.
  - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available.
  - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`.
  - Training is possible with about 17GB of memory at batch size 1 in fp32. Specifying the `--full_bf16` option reduces memory usage further (at some cost in accuracy). With the same memory usage as before, you can increase the batch size.
  - PyTorch 2.1 or later is required because the new API `Tensor.register_post_accumulate_grad_hook(hook)` is used.
  - Mechanism: normally the full backward pass runs before the optimizer step, so all gradients must be held in memory at once. Fusing backward and step performs the update for each parameter as soon as its gradient is accumulated and frees the gradient immediately, lowering the memory peak. The more parameters there are, the greater the effect, so it is not effective in other training scripts (LoRA, etc.) whose memory-usage peak lies elsewhere, and there are no plans to implement it in those scripts. (A minimal sketch of the hook mechanism follows this README excerpt.)

- Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319)
  - Memory usage is reduced by the same principle as the fused optimizer. Training results and speed are the same as with the fused optimizer.
  - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the benefit levels off beyond a certain number of groups, specifying 4-10 is recommended.
  - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available.
  - `--fused_optimizer_groups` cannot be used together with `--fused_backward_pass`. When using AdaFactor, memory usage is slightly higher than with the fused optimizer. PyTorch 2.1 or later is required.
  - Mechanism: whereas the fused optimizer performs backward/step for individual parameters inside the optimizer, optimizer groups split the parameters into groups, create one optimizer per group, and perform backward/step per group, which reduces memory usage. The fused optimizer requires support inside the optimizer implementation, while optimizer groups are implemented entirely on the training-script side. (A sketch of the grouping idea also follows this README excerpt.)

- Fixed some bugs when using DeepSpeed. Related [#1247]

- Fused optimizer can now be used when training SDXL. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr.
  - Integrating the optimizer step into the backward pass greatly reduces memory usage during training. Training results are identical to those without the option, but if memory is plentiful, training is slower.
  - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At present only the AdaFactor optimizer is supported, and gradient accumulation cannot be used.
  - Mixed precision `no` appears to use less memory than `fp16` or `bf16`.
  - Training appears to be possible with about 17GB at batch size 1 in fp32. Specifying the `--full_bf16` option reduces this further (with lower precision). With the same memory usage as before, the batch size can be increased.
  - PyTorch 2.1 or later is required because the new PyTorch 2.1 API `Tensor.register_post_accumulate_grad_hook(hook)` is used.
  - Mechanism: normally backward is followed by step, so all gradients must be held in memory temporarily. Fusing backward and step performs backward/step per parameter and applies the gradient immediately, reducing memory usage. The larger the number of parameters, the larger the effect, so outside SDXL training (LoRA, etc.) there is almost no benefit (the memory-usage peak is elsewhere), and there are no plans to implement it in those training scripts.

- Added an optimizer group feature to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319)
  - Memory usage is reduced by the same principle as the fused optimizer. Training results and speed are also the same.
  - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect only extends up to a certain number, specifying around 4-10 is recommended.
  - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation cannot be used.
  - `--fused_optimizer_groups` cannot be combined with `--fused_backward_pass`. When using AdaFactor, memory usage is slightly higher than with the fused optimizer. PyTorch 2.1 or later is required.
  - Mechanism: whereas the fused optimizer performs backward/step for individual parameters inside the optimizer, optimizer groups group the parameters, create multiple optimizers, and perform backward/step for each group, reducing memory usage. The fused optimizer requires an optimizer-side implementation, while optimizer groups are implemented only on the training-script side. Again, this is only effective for SDXL training.

- Fixed some bugs when using DeepSpeed. Related [#1247]
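The hook mechanism described above is wired up in `sdxl_train.py`, one of the four changed files but not shown in this excerpt. The following is a minimal sketch, not the script's actual code, of how PyTorch 2.1's `Tensor.register_post_accumulate_grad_hook` and an optimizer patched to expose `step_param()` (as `library/adafactor_fused.py` below does for Adafactor) fit together; `register_fused_backward_hooks` is an illustrative name.

```python
# Minimal sketch (illustrative, not the actual sdxl_train.py code): fuse the
# optimizer step into the backward pass with post-accumulate-grad hooks.
# Assumes `optimizer` has been patched to expose step_param(p, group),
# e.g. via library/adafactor_fused.py below. Requires PyTorch >= 2.1.
import torch


def register_fused_backward_hooks(optimizer):
    for group in optimizer.param_groups:
        for p in group["params"]:
            if not p.requires_grad:
                continue

            def make_hook(group=group):
                def hook(param: torch.Tensor):
                    # Fires right after param.grad has been accumulated:
                    # update this parameter immediately, then drop its
                    # gradient so gradients never pile up for the whole model.
                    optimizer.step_param(param, group)
                    param.grad = None

                return hook

            p.register_post_accumulate_grad_hook(make_hook())


# With the hooks registered, loss.backward() applies the updates as gradients
# are produced; calling optimizer.step() afterwards becomes a no-op because
# every grad has already been consumed and cleared.
```

Clearing `param.grad` inside the hook is what keeps the peak gradient memory low; this is also why the technique helps most when the parameter count itself dominates memory, as in full SDXL fine-tuning.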
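The optimizer-groups variant is likewise implemented on the training-script side in `sdxl_train.py` and is not part of this excerpt. Below is a minimal sketch of the grouping idea under stated assumptions (PyTorch 2.1 or later, any standard optimizer): parameters are split into groups, each group gets its own optimizer, and a group is stepped and its gradients freed as soon as all of its gradients have been produced during backward. `build_grouped_optimizers` and the toy model are illustrative.

```python
# Minimal sketch (illustrative, not the actual sdxl_train.py implementation).
import torch


def _register_group_hooks(opt, group_params):
    remaining = {"count": len(group_params)}

    def hook(param: torch.Tensor):
        # Fires once per parameter after its gradient is accumulated.
        remaining["count"] -= 1
        if remaining["count"] == 0:
            # Every gradient of this group is ready: update the group and
            # release its gradients before backward finishes for the rest.
            opt.step()
            opt.zero_grad(set_to_none=True)
            remaining["count"] = len(group_params)

    for p in group_params:
        p.register_post_accumulate_grad_hook(hook)


def build_grouped_optimizers(parameters, num_groups, make_optimizer):
    params = [p for p in parameters if p.requires_grad]
    groups = [params[i::num_groups] for i in range(num_groups)]
    groups = [g for g in groups if g]  # drop empty groups
    optimizers = [make_optimizer(g) for g in groups]
    for group_params, opt in zip(groups, optimizers):
        _register_group_hooks(opt, group_params)
    return optimizers


# Toy usage: two groups of AdamW over a small model.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
opts = build_grouped_optimizers(model.parameters(), 2, lambda ps: torch.optim.AdamW(ps, lr=1e-4))
loss = model(torch.randn(4, 8)).sum()
loss.backward()  # each group's optimizer steps as soon as its gradients are complete
```

Because the per-group optimizers are ordinary optimizers, nothing has to be implemented inside the optimizer itself, which matches the note above that any optimizer (other than learning-rate-free ones) can be used.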

library/adafactor_fused.py (new file, +106)

@@ -0,0 +1,106 @@
```python
import math
import torch
from transformers import Adafactor


# Apply a single Adafactor update to one parameter `p` using its
# already-accumulated gradient, instead of stepping all parameters at once.
@torch.no_grad()
def adafactor_step_param(self, p, group):
    if p.grad is None:
        return
    grad = p.grad
    if grad.dtype in {torch.float16, torch.bfloat16}:
        grad = grad.float()
    if grad.is_sparse:
        raise RuntimeError("Adafactor does not support sparse gradients.")

    state = self.state[p]
    grad_shape = grad.shape

    factored, use_first_moment = Adafactor._get_options(group, grad_shape)
    # State Initialization
    if len(state) == 0:
        state["step"] = 0

        if use_first_moment:
            # Exponential moving average of gradient values
            state["exp_avg"] = torch.zeros_like(grad)
        if factored:
            state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
            state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
        else:
            state["exp_avg_sq"] = torch.zeros_like(grad)

        state["RMS"] = 0
    else:
        if use_first_moment:
            state["exp_avg"] = state["exp_avg"].to(grad)
        if factored:
            state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
            state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
        else:
            state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

    p_data_fp32 = p
    if p.dtype in {torch.float16, torch.bfloat16}:
        p_data_fp32 = p_data_fp32.float()

    state["step"] += 1
    state["RMS"] = Adafactor._rms(p_data_fp32)
    lr = Adafactor._get_lr(group, state)

    beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
    update = (grad ** 2) + group["eps"][0]
    if factored:
        exp_avg_sq_row = state["exp_avg_sq_row"]
        exp_avg_sq_col = state["exp_avg_sq_col"]

        exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
        exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

        # Approximation of exponential moving average of square of gradient
        update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
        update.mul_(grad)
    else:
        exp_avg_sq = state["exp_avg_sq"]

        exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
        update = exp_avg_sq.rsqrt().mul_(grad)

    update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
    update.mul_(lr)

    if use_first_moment:
        exp_avg = state["exp_avg"]
        exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
        update = exp_avg

    if group["weight_decay"] != 0:
        p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

    p_data_fp32.add_(-update)

    if p.dtype in {torch.float16, torch.bfloat16}:
        p.copy_(p_data_fp32)


@torch.no_grad()
def adafactor_step(self, closure=None):
    """
    Performs a single optimization step

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            adafactor_step_param(self, p, group)

    return loss


# Bind the functions above as methods of an existing transformers Adafactor
# instance: step_param() becomes available for per-parameter updates and
# step() is replaced by the loop over all parameters.
def patch_adafactor_fused(optimizer: Adafactor):
    optimizer.step_param = adafactor_step_param.__get__(optimizer)
    optimizer.step = adafactor_step.__get__(optimizer)
```
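For reference, a hedged usage sketch of the patch: it assumes the repository root is on `sys.path` so `library.adafactor_fused` is importable, and it uses a placeholder model and hyperparameters. The `__get__` calls in `patch_adafactor_fused` bind the module-level functions as methods of the optimizer instance, which is why `optimizer.step_param(...)` receives the optimizer as `self`.

```python
import torch
from transformers import Adafactor

from library.adafactor_fused import patch_adafactor_fused

model = torch.nn.Linear(16, 16)  # placeholder model
# lr must be given explicitly because relative_step is disabled here.
optimizer = Adafactor(
    model.parameters(), lr=1e-4, scale_parameter=False, relative_step=False, warmup_init=False
)
patch_adafactor_fused(optimizer)  # adds step_param() and replaces step()

loss = model(torch.randn(4, 16)).sum()
loss.backward()

# Each parameter can now be updated on its own as soon as its gradient is
# available; in the fused backward pass this call happens inside a
# post-accumulate-grad hook rather than in a loop after backward().
for group in optimizer.param_groups:
    for p in group["params"]:
        optimizer.step_param(p, group)
        p.grad = None
```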

library/train_util.py (+14)

```diff
@@ -2920,6 +2920,12 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
+    parser.add_argument(
+        "--fused_backward_pass",
+        action="store_true",
+        help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
+        + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
+    )
 
 
 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -3846,6 +3852,14 @@ def get_optimizer(args, trainable_params):
         optimizer_type = "AdamW"
     optimizer_type = optimizer_type.lower()
 
+    if args.fused_backward_pass:
+        assert (
+            optimizer_type == "Adafactor".lower()
+        ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
+        assert (
+            args.gradient_accumulation_steps == 1
+        ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"
+
     # 引数を分解する
     optimizer_kwargs = {}
     if args.optimizer_args is not None and len(args.optimizer_args) > 0:
```
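As a small illustration of the new argument (the real parser carries many more options than shown here), `action="store_true"` means the flag defaults to `False` and becomes `True` only when passed on the command line, which is the value `get_optimizer` checks:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--fused_backward_pass", action="store_true")

print(parser.parse_args([]).fused_backward_pass)                         # False
print(parser.parse_args(["--fused_backward_pass"]).fused_backward_pass)  # True
```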
