Commit e0d8c91

kashif and sayakpaul authored
[Peft] fix saving / loading when unet is not "unet" (#6046)
* [Peft] fix saving / loading when unet is not "unet"
* Update src/diffusers/loaders/lora.py
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* undo stablediffusion-xl changes
* use unet_name to get unet for lora helpers
* use unet_name

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent a3d31e3 commit e0d8c91

File tree (2 files changed: +32 -20)

  src/diffusers/loaders/ip_adapter.py
  src/diffusers/loaders/lora.py


src/diffusers/loaders/ip_adapter.py (+4 -2)
@@ -149,9 +149,11 @@ def load_ip_adapter(
             self.feature_extractor = CLIPImageProcessor()
 
         # load ip-adapter into unet
-        self.unet._load_ip_adapter_weights(state_dict)
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet._load_ip_adapter_weights(state_dict)
 
     def set_ip_adapter_scale(self, scale):
-        for attn_processor in self.unet.attn_processors.values():
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for attn_processor in unet.attn_processors.values():
             if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
                 attn_processor.scale = scale
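
Both call sites now resolve the denoiser through the mixin's `unet_name` attribute instead of hard-coding `self.unet`. A minimal sketch of that fallback expression, using a made-up `ToyPipeline` whose denoiser lives under a different attribute name (everything below is illustrative, not diffusers code):

# Illustrative only: a toy pipeline that stores its denoiser under "denoiser", not "unet".
class ToyPipeline:
    unet_name = "denoiser"  # hypothetical value; diffusers pipelines default to "unet"

    def __init__(self, denoiser):
        self.denoiser = denoiser

    def _resolve_unet(self):
        # The expression from the diff: prefer self.unet when it exists,
        # otherwise fall back to the attribute named by self.unet_name.
        return getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet


denoiser = object()
pipe = ToyPipeline(denoiser)
assert pipe._resolve_unet() is denoiser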

src/diffusers/loaders/lora.py (+28 -18)
@@ -912,10 +912,10 @@ def pack_weights(layers, prefix):
             )
 
         if unet_lora_layers:
-            state_dict.update(pack_weights(unet_lora_layers, "unet"))
+            state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))
 
         if text_encoder_lora_layers:
-            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
 
         if transformer_lora_layers:
             state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
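
When saving, UNet and text-encoder LoRA keys are now prefixed with `cls.unet_name` and `cls.text_encoder_name` rather than the hard-coded strings. A rough sketch of the prefixing idea with a simplified stand-in for `pack_weights` (not the library's exact implementation; the weight dict and the "prior" prefix are made up):

# Simplified stand-in: flatten a dict (or module state dict) of weights under a prefix.
def pack_weights(layers, prefix):
    weights = layers if isinstance(layers, dict) else layers.state_dict()
    return {f"{prefix}.{name}": value for name, value in weights.items()}


unet_lora_layers = {"down_blocks.0.lora.down.weight": 0.0}  # placeholder weights

# With the default unet_name the keys come out unchanged ...
assert list(pack_weights(unet_lora_layers, "unet")) == ["unet.down_blocks.0.lora.down.weight"]
# ... while a pipeline whose denoiser is named differently (hypothetical "prior") gets matching prefixes.
assert list(pack_weights(unet_lora_layers, "prior")) == ["prior.down_blocks.0.lora.down.weight"]
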
@@ -975,20 +975,22 @@ def unload_lora_weights(self):
         >>> ...
         ```
         """
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+
         if not USE_PEFT_BACKEND:
             if version.parse(__version__) > version.parse("0.23"):
                 logger.warn(
                     "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
                     "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
                 )
 
-            for _, module in self.unet.named_modules():
+            for _, module in unet.named_modules():
                 if hasattr(module, "set_lora_layer"):
                     module.set_lora_layer(None)
         else:
-            recurse_remove_peft_layers(self.unet)
-            if hasattr(self.unet, "peft_config"):
-                del self.unet.peft_config
+            recurse_remove_peft_layers(unet)
+            if hasattr(unet, "peft_config"):
+                del unet.peft_config
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
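
For pipelines whose UNet attribute really is `unet`, nothing changes at the call site. A hedged usage sketch of the method touched above (the model id is real, the LoRA path is a placeholder):

# Usage sketch; the LoRA path below is a placeholder, not part of this commit.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("path/or/repo-with-lora")  # placeholder
pipe.unload_lora_weights()  # the UNet is now looked up via unet_name internally
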
@@ -1027,7 +1029,8 @@ def fuse_lora(
             )
 
         if fuse_unet:
-            self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
+            unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
 
         if USE_PEFT_BACKEND:
             from peft.tuners.tuners_utils import BaseTunerLayer
@@ -1080,13 +1083,14 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
         """
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         if unfuse_unet:
             if not USE_PEFT_BACKEND:
-                self.unet.unfuse_lora()
+                unet.unfuse_lora()
             else:
                 from peft.tuners.tuners_utils import BaseTunerLayer
 
-                for module in self.unet.modules():
+                for module in unet.modules():
                     if isinstance(module, BaseTunerLayer):
                         module.unmerge()
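
`fuse_lora` and `unfuse_lora` keep their public signatures; only the internal UNet lookup changes. Continuing the usage sketch above (the LoRA path is still a placeholder):

pipe.load_lora_weights("path/or/repo-with-lora")  # placeholder, as above
pipe.fuse_lora(lora_scale=0.8)  # merge LoRA weights into the resolved UNet and text encoder
# ... run inference with the fused weights ...
pipe.unfuse_lora()  # restore the original, unfused weights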

@@ -1202,8 +1206,9 @@ def set_adapters(
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[List[float]] = None,
     ):
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         # Handle the UNET
-        self.unet.set_adapters(adapter_names, adapter_weights)
+        unet.set_adapters(adapter_names, adapter_weights)
 
         # Handle the Text Encoder
         if hasattr(self, "text_encoder"):
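
`set_adapters` likewise resolves the UNet before applying per-adapter weights. Continuing the sketch (PEFT backend required; the adapter names are hypothetical and must match names given earlier to `load_lora_weights`):

# "style" and "detail" are hypothetical adapter names passed earlier to
# load_lora_weights(..., adapter_name=...).
pipe.set_adapters(["style", "detail"], adapter_weights=[0.7, 0.3])
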
@@ -1216,7 +1221,8 @@ def disable_lora(self):
             raise ValueError("PEFT backend is required for this method.")
 
         # Disable unet adapters
-        self.unet.disable_lora()
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.disable_lora()
 
         # Disable text encoder adapters
         if hasattr(self, "text_encoder"):
@@ -1229,7 +1235,8 @@ def enable_lora(self):
             raise ValueError("PEFT backend is required for this method.")
 
         # Enable unet adapters
-        self.unet.enable_lora()
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.enable_lora()
 
         # Enable text encoder adapters
         if hasattr(self, "text_encoder"):
@@ -1251,7 +1258,8 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             adapter_names = [adapter_names]
 
         # Delete unet adapters
-        self.unet.delete_adapters(adapter_names)
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.delete_adapters(adapter_names)
 
         for adapter_name in adapter_names:
             # Delete text encoder adapters
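
The same lookup is now used by `disable_lora`, `enable_lora`, and `delete_adapters`. Continuing the sketch with the same hypothetical adapter names (PEFT backend required):

pipe.disable_lora()              # temporarily turn off all adapters on the resolved UNet
pipe.enable_lora()               # turn them back on
pipe.delete_adapters(["style"])  # drop the hypothetical "style" adapter entirely
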
@@ -1284,8 +1292,8 @@ def get_active_adapters(self) -> List[str]:
         from peft.tuners.tuners_utils import BaseTunerLayer
 
         active_adapters = []
-
-        for module in self.unet.modules():
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for module in unet.modules():
             if isinstance(module, BaseTunerLayer):
                 active_adapters = module.active_adapters
                 break
@@ -1309,8 +1317,9 @@ def get_list_adapters(self) -> Dict[str, List[str]]:
         if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
             set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
 
-        if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
-            set_adapters["unet"] = list(self.unet.peft_config.keys())
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"):
+            set_adapters[self.unet_name] = list(self.unet.peft_config.keys())
 
         return set_adapters
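
The two query helpers report adapters keyed by component name. Continuing the sketch, with illustrative return values in the comments:

print(pipe.get_active_adapters())  # e.g. ["detail"] -- adapters currently active on the UNet
print(pipe.get_list_adapters())    # e.g. {"unet": ["detail"], "text_encoder": ["detail"]}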

@@ -1331,7 +1340,8 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
         from peft.tuners.tuners_utils import BaseTunerLayer
 
         # Handle the UNET
-        for unet_module in self.unet.modules():
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for unet_module in unet.modules():
             if isinstance(unet_module, BaseTunerLayer):
                 for adapter_name in adapter_names:
                     unet_module.lora_A[adapter_name].to(device)
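
`set_lora_device` walks the resolved UNet's LoRA layers; a short continuation of the sketch that moves the hypothetical adapter's weights to CPU:

import torch

# Move the LoRA weights of the named (hypothetical) adapter off the GPU.
pipe.set_lora_device(adapter_names=["detail"], device=torch.device("cpu"))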

0 commit comments