@@ -912,10 +912,10 @@ def pack_weights(layers, prefix):
             )
 
         if unet_lora_layers:
-            state_dict.update(pack_weights(unet_lora_layers, "unet"))
+            state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))
 
         if text_encoder_lora_layers:
-            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
 
         if transformer_lora_layers:
             state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
@@ -975,20 +975,22 @@ def unload_lora_weights(self):
         >>> ...
         ```
         """
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+
         if not USE_PEFT_BACKEND:
             if version.parse(__version__) > version.parse("0.23"):
                 logger.warn(
                     "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
                     "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
                 )
 
-            for _, module in self.unet.named_modules():
+            for _, module in unet.named_modules():
                 if hasattr(module, "set_lora_layer"):
                     module.set_lora_layer(None)
         else:
-            recurse_remove_peft_layers(self.unet)
-            if hasattr(self.unet, "peft_config"):
-                del self.unet.peft_config
+            recurse_remove_peft_layers(unet)
+            if hasattr(unet, "peft_config"):
+                del unet.peft_config
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
@@ -1027,7 +1029,8 @@ def fuse_lora(
                 )
 
         if fuse_unet:
-            self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
+            unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
 
         if USE_PEFT_BACKEND:
             from peft.tuners.tuners_utils import BaseTunerLayer
@@ -1080,13 +1083,14 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
         """
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         if unfuse_unet:
             if not USE_PEFT_BACKEND:
-                self.unet.unfuse_lora()
+                unet.unfuse_lora()
             else:
                 from peft.tuners.tuners_utils import BaseTunerLayer
 
-                for module in self.unet.modules():
+                for module in unet.modules():
                     if isinstance(module, BaseTunerLayer):
                         module.unmerge()
 
@@ -1202,8 +1206,9 @@ def set_adapters(
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[List[float]] = None,
     ):
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         # Handle the UNET
-        self.unet.set_adapters(adapter_names, adapter_weights)
+        unet.set_adapters(adapter_names, adapter_weights)
 
         # Handle the Text Encoder
         if hasattr(self, "text_encoder"):
@@ -1216,7 +1221,8 @@ def disable_lora(self):
             raise ValueError("PEFT backend is required for this method.")
 
         # Disable unet adapters
-        self.unet.disable_lora()
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.disable_lora()
 
         # Disable text encoder adapters
         if hasattr(self, "text_encoder"):
@@ -1229,7 +1235,8 @@ def enable_lora(self):
             raise ValueError("PEFT backend is required for this method.")
 
         # Enable unet adapters
-        self.unet.enable_lora()
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.enable_lora()
 
         # Enable text encoder adapters
         if hasattr(self, "text_encoder"):
@@ -1251,7 +1258,8 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             adapter_names = [adapter_names]
 
         # Delete unet adapters
-        self.unet.delete_adapters(adapter_names)
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.delete_adapters(adapter_names)
 
         for adapter_name in adapter_names:
             # Delete text encoder adapters
@@ -1284,8 +1292,8 @@ def get_active_adapters(self) -> List[str]:
         from peft.tuners.tuners_utils import BaseTunerLayer
 
         active_adapters = []
-
-        for module in self.unet.modules():
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for module in unet.modules():
             if isinstance(module, BaseTunerLayer):
                 active_adapters = module.active_adapters
                 break
@@ -1309,8 +1317,9 @@ def get_list_adapters(self) -> Dict[str, List[str]]:
         if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
             set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
 
-        if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
-            set_adapters["unet"] = list(self.unet.peft_config.keys())
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"):
+            set_adapters[self.unet_name] = list(self.unet.peft_config.keys())
 
         return set_adapters
 
@@ -1331,7 +1340,8 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
         from peft.tuners.tuners_utils import BaseTunerLayer
 
         # Handle the UNET
-        for unet_module in self.unet.modules():
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for unet_module in unet.modules():
             if isinstance(unet_module, BaseTunerLayer):
                 for adapter_name in adapter_names:
                     unet_module.lora_A[adapter_name].to(device)
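
For reference, here is a minimal, self-contained sketch of the lookup pattern this diff adds to each LoRA method. `DummyDenoiser`, `DummyPipeline`, and the `prior` attribute are hypothetical stand-ins, not diffusers classes; the point is only that `getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet` resolves the denoiser whether or not the pipeline stores it under an attribute literally named `unet`.

```python
# Minimal sketch (hypothetical classes, not the diffusers API) of the attribute
# lookup introduced in the diff above.


class DummyDenoiser:
    def disable_lora(self):
        print(f"LoRA disabled on {self.__class__.__name__}")


class DummyPipeline:
    # Mirrors the class-level `unet_name` referenced in the diff; here the
    # denoiser is not stored under an attribute called "unet".
    unet_name = "prior"

    def __init__(self):
        self.prior = DummyDenoiser()

    def disable_lora(self):
        # Same expression as in the diff: prefer `self.unet` when it exists,
        # otherwise resolve the module through `self.unet_name`.
        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
        unet.disable_lora()


DummyPipeline().disable_lora()  # prints: LoRA disabled on DummyDenoiser
```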