
[Peft] fix saving / loading when unet is not "unet" #6046


Merged · 13 commits · Dec 26, 2023
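
The fix applies one pattern throughout: instead of hard-coding `self.unet`, each method keeps using `self.unet` when the pipeline actually has that attribute and otherwise looks the module up under the name stored in the class-level `unet_name` attribute. A minimal sketch of that lookup, using a hypothetical `MyPipeline` whose denoiser is stored as `denoiser` (everything except the `getattr`/`hasattr` expression itself is illustrative, not taken from this PR):

```python
# Minimal sketch of the lookup pattern this PR introduces. `MyPipeline`
# and the "denoiser" attribute are hypothetical; only the getattr/hasattr
# expression mirrors the actual change.


class MyPipeline:
    unet_name = "denoiser"  # a subclass may override the default "unet"

    def __init__(self, denoiser):
        self.denoiser = denoiser  # note: there is no `self.unet` here

    def _resolve_unet(self):
        # Use `self.unet` if the pipeline has one; otherwise fall back to
        # the attribute named by `unet_name`.
        return getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet


pipe = MyPipeline(denoiser=object())
assert pipe._resolve_unet() is pipe.denoiser
```

With the default `unet_name = "unet"` the expression reduces to plain `self.unet`, so existing pipelines behave exactly as before.
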
6 changes: 4 additions & 2 deletions src/diffusers/loaders/ip_adapter.py
@@ -149,9 +149,11 @@ def load_ip_adapter(
self.feature_extractor = CLIPImageProcessor()

# load ip-adapter into unet
- self.unet._load_ip_adapter_weights(state_dict)
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ unet._load_ip_adapter_weights(state_dict)

def set_ip_adapter_scale(self, scale):
- for attn_processor in self.unet.attn_processors.values():
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ for attn_processor in unet.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
attn_processor.scale = scale
46 changes: 28 additions & 18 deletions src/diffusers/loaders/lora.py
@@ -912,10 +912,10 @@ def pack_weights(layers, prefix):
)

if unet_lora_layers:
- state_dict.update(pack_weights(unet_lora_layers, "unet"))
+ state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))

if text_encoder_lora_layers:
- state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+ state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

if transformer_lora_layers:
state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
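
The saving path mirrors the loading fix: the prefixes that namespace the serialized LoRA state dict now come from `cls.unet_name` and `cls.text_encoder_name` rather than the literal strings. A sketch of the resulting keys, using a stand-in `pack_weights` (assumed to flatten each entry under `"<prefix>.<name>"`, which is how the keys are read back at load time) and a hypothetical `unet_name` override:

```python
import torch


def pack_weights(layers, prefix):
    # Stand-in for the helper used above: namespace every entry of the
    # (module or dict) state dict under the given prefix.
    weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{name}": param for name, param in weights.items()}


class MyLoraMixin:
    unet_name = "denoiser"  # hypothetical override of the default "unet"


unet_lora_layers = {"down_blocks.0.to_q.lora_A.weight": torch.zeros(4, 4)}
packed = pack_weights(unet_lora_layers, MyLoraMixin.unet_name)
assert list(packed) == ["denoiser.down_blocks.0.to_q.lora_A.weight"]
```
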
@@ -975,20 +975,22 @@ def unload_lora_weights(self):
>>> ...
```
"""
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet

if not USE_PEFT_BACKEND:
if version.parse(__version__) > version.parse("0.23"):
logger.warn(
"You are using `unload_lora_weights` to disable and unload LoRA weights. If you want to iteratively enable and disable adapter weights, "
"you can use `pipe.enable_lora()` or `pipe.disable_lora()` after installing the latest version of PEFT."
)

- for _, module in self.unet.named_modules():
+ for _, module in unet.named_modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
else:
- recurse_remove_peft_layers(self.unet)
- if hasattr(self.unet, "peft_config"):
- del self.unet.peft_config
+ recurse_remove_peft_layers(unet)
+ if hasattr(unet, "peft_config"):
+ del unet.peft_config

# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()
@@ -1027,7 +1029,8 @@ def fuse_lora(
)

if fuse_unet:
- self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)

if USE_PEFT_BACKEND:
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -1080,13 +1083,14 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
if unfuse_unet:
if not USE_PEFT_BACKEND:
- self.unet.unfuse_lora()
+ unet.unfuse_lora()
else:
from peft.tuners.tuners_utils import BaseTunerLayer

- for module in self.unet.modules():
+ for module in unet.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()

@@ -1202,8 +1206,9 @@ def set_adapters(
adapter_names: Union[List[str], str],
adapter_weights: Optional[List[float]] = None,
):
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
# Handle the UNET
- self.unet.set_adapters(adapter_names, adapter_weights)
+ unet.set_adapters(adapter_names, adapter_weights)

# Handle the Text Encoder
if hasattr(self, "text_encoder"):
@@ -1216,7 +1221,8 @@ def disable_lora(self):
raise ValueError("PEFT backend is required for this method.")

# Disable unet adapters
- self.unet.disable_lora()
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ unet.disable_lora()

# Disable text encoder adapters
if hasattr(self, "text_encoder"):
@@ -1229,7 +1235,8 @@ def enable_lora(self):
raise ValueError("PEFT backend is required for this method.")

# Enable unet adapters
- self.unet.enable_lora()
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ unet.enable_lora()

# Enable text encoder adapters
if hasattr(self, "text_encoder"):
@@ -1251,7 +1258,8 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
adapter_names = [adapter_names]

# Delete unet adapters
- self.unet.delete_adapters(adapter_names)
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ unet.delete_adapters(adapter_names)

for adapter_name in adapter_names:
# Delete text encoder adapters
@@ -1284,8 +1292,8 @@ def get_active_adapters(self) -> List[str]:
from peft.tuners.tuners_utils import BaseTunerLayer

active_adapters = []

- for module in self.unet.modules():
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ for module in unet.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
break
@@ -1309,8 +1317,9 @@ def get_list_adapters(self) -> Dict[str, List[str]]:
if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())

if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
set_adapters["unet"] = list(self.unet.peft_config.keys())
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"):
set_adapters[self.unet_name] = list(self.unet.peft_config.keys())

return set_adapters
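
With the lookup above, `get_list_adapters` also keys the UNet's adapters under `self.unet_name` rather than a hard-coded `"unet"`. Illustrative only (a hypothetical pipeline with `unet_name = "denoiser"` and made-up adapter names), the returned mapping would look like:

```python
# Expected shape of `pipe.get_list_adapters()` in that hypothetical case:
expected = {
    "text_encoder": ["pixel-art"],
    "denoiser": ["pixel-art", "toy-face"],
}
```
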

@@ -1331,7 +1340,8 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
from peft.tuners.tuners_utils import BaseTunerLayer

# Handle the UNET
- for unet_module in self.unet.modules():
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ for unet_module in unet.modules():
if isinstance(unet_module, BaseTunerLayer):
for adapter_name in adapter_names:
unet_module.lora_A[adapter_name].to(device)