Hotswapping multiple LoRAs throws a peft key error. #11298
My hunch is that the underlying LoRAs differ in keys, which is why hotswapping is not working. Cc: @BenjaminBossan
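A quick way to check that is to diff the module paths each file targets. A minimal sketch, assuming the two LoRAs are saved locally under the hypothetical names `lora1.safetensors` and `lora2.safetensors`:

```python
from safetensors.torch import load_file

# LoRA keys look like "<module-path>.lora_A.weight" / "<module-path>.lora_B.weight";
# stripping the suffix leaves the path of the targeted module.
def target_modules(path):
    return {k.rsplit(".lora_", 1)[0] for k in load_file(path) if ".lora_" in k}

mods1 = target_modules("lora1.safetensors")
mods2 = target_modules("lora2.safetensors")
print("only in first:", mods1 - mods2)
print("only in second:", mods2 - mods1)
```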
Yes, the error message indicates that the 2nd adapter targets a module that the 1st adapter does not target. Two adapters targeting disjoint sets of modules would unfortunately not work. A workaround would be to create a third adapter that targets the union of modules from the two original adapters, load that one first, then hotswap the original adapters in.
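To sketch that workaround (untested; the local file names are hypothetical, and zero-padding stands in for "targets the union of modules"):

```python
import torch
from diffusers import FluxPipeline
from safetensors.torch import load_file

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_lora_hotswap(target_rank=128)

sd1 = load_file("lora1.safetensors")  # hypothetical local paths
sd2 = load_file("lora2.safetensors")

# Pad the first state dict with zero tensors for every key that only the
# second file has; zeroed lora_A/lora_B weights contribute no delta, so the
# union adapter behaves like the first adapter while still creating LoRA
# layers on every module that either adapter targets.
union_sd = dict(sd1)
for key, tensor in sd2.items():
    union_sd.setdefault(key, torch.zeros_like(tensor))

# Load the union adapter first, then hotswap the original adapters into
# that slot as needed.
pipe.load_lora_weights(union_sd, adapter_name="1")
pipe.load_lora_weights(sd2, adapter_name="1", hotswap=True)
```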
The two LoRAs have identical keys; here's a complete example that should be runnable / debuggable. This example also recreates the state dict in case there was some other bug with mutating the state dict during the hotswap.

```python
import time
import torch
import requests
import io
import zipfile
from diffusers import FluxPipeline
import logging

logger = logging.getLogger(__name__)


def hotswap_bug_example():
    def get_safetensor_from_zip(url):
        response = requests.get(url, stream=True)
        zip_bytes = io.BytesIO(response.content)  # Use response.content instead of response.raw.read()
        with zipfile.ZipFile(zip_bytes) as zf:
            for name in zf.namelist():
                if name.endswith(".safetensors"):
                    return zf.read(name)
        return None

    lora_one_sft = get_safetensor_from_zip("https://base-weights.weights.com/cm9dm38e4061uon15341k47ss.zip")
    lora_two_sft = get_safetensor_from_zip("https://base-weights.weights.com/cm9dnj1840088n214rn9uych4.zip")

    logger.info("Initializing flux model")
    flux_base_model: FluxPipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
    flux_base_model = flux_base_model.to("cuda")
    flux_base_model.enable_lora_hotswap(target_rank=128)

    from safetensors.torch import load

    # download and set the state dicts of two random loras
    first_lora = load(lora_one_sft)
    second_lora = load(lora_two_sft)
    first_lora_keys = first_lora.keys()
    second_lora_keys = second_lora.keys()

    # make sure the keys are the same
    if set(first_lora_keys) != set(second_lora_keys):
        missing_in_first = set(second_lora_keys) - set(first_lora_keys)
        missing_in_second = set(first_lora_keys) - set(second_lora_keys)
        error_message = "LoRA models have different keys:\n"
        if missing_in_first:
            error_message += f"Keys missing in first LoRA: {missing_in_first}\n"
        if missing_in_second:
            error_message += f"Keys missing in second LoRA: {missing_in_second}\n"
        raise ValueError(error_message)

    # we need to load three loras as that is the limit of what we support - each name is "1", "2", "3"
    # these will then be enabled or disabled
    flux_base_model.load_lora_weights(load(lora_one_sft), adapter_name="1")
    flux_base_model.load_lora_weights(load(lora_two_sft), adapter_name="2")
    flux_base_model.load_lora_weights(load(lora_two_sft), adapter_name="3")
    logger.info("Initialized base flux model")

    should_compile = False
    if should_compile:
        flux_base_model.image_encoder = torch.compile(flux_base_model.image_encoder)
        flux_base_model.text_encoder = torch.compile(flux_base_model.text_encoder)
        flux_base_model.text_encoder_2 = torch.compile(flux_base_model.text_encoder_2)
        flux_base_model.vae = torch.compile(flux_base_model.vae)
        flux_base_model.transformer = torch.compile(
            flux_base_model.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True
        )

    for i in range(5):
        start_time = time.time()
        image = flux_base_model("An image of a cat", num_inference_steps=4, guidance_scale=3.0).images[0]
        if i == 0:
            logger.info(f"Warmup: {time.time() - start_time}")
        else:
            logger.info(f"Inference time: {time.time() - start_time}")
        utc_seconds = int(time.time())
        image.save(f"hotswap_{utc_seconds}.png")
        if i == 1:
            logger.info("Hotswapping lora one")
            flux_base_model.load_lora_weights(load(lora_one_sft), adapter_name="1", hotswap=True)
        if i == 2:
            logger.info("Hotswapping lora two")
            flux_base_model.load_lora_weights(load(lora_two_sft), adapter_name="2", hotswap=True)
            flux_base_model.load_lora_weights(load(lora_one_sft), adapter_name="1", hotswap=True)


hotswap_bug_example()
```
Thanks @jonluca for providing a reproducer. Could you be so kind as to upload the LoRAs to HF (safetensors format)?
The LoRAs are already in safetensors format. Do you mean without the zip file? If so, sure: https://base-weights.weights.com/cm9dm38e4061uon15341k47ss.safetensors and https://base-weights.weights.com/cm9dnj1840088n214rn9uych4.safetensors
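They can also be loaded straight from the response bytes without the zip step; a minimal sketch, using one of the two URLs above:

```python
import requests
from safetensors.torch import load

url = "https://base-weights.weights.com/cm9dm38e4061uon15341k47ss.safetensors"
# safetensors.torch.load deserializes a bytes buffer directly
state_dict = load(requests.get(url).content)
print(f"{len(state_dict)} tensors loaded")
```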
Thanks for uploading the files. I took your script with some small modifications and it ran without error for me:

```python
import time
import torch
import requests
import io
import zipfile
from diffusers import FluxPipeline
import logging
from safetensors.torch import load_file

logger = logging.getLogger(__name__)

device = "cpu"  # can't fit on GPU ;((
dtype = torch.bfloat16
path1 = "<path-to-safetensors-1>"
path2 = "<path-to-safetensors-2>"
should_compile = False


def hotswap_bug_example():
    logger.info("Initializing flux model")
    flux_base_model: FluxPipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=dtype,
    )
    flux_base_model = flux_base_model.to(device)
    flux_base_model.enable_lora_hotswap(target_rank=128)

    # download and set the state dicts of two random loras
    first_lora = load_file(path1)
    second_lora = load_file(path2)
    first_lora_keys = first_lora.keys()
    second_lora_keys = second_lora.keys()

    # make sure the keys are the same
    if set(first_lora_keys) != set(second_lora_keys):
        missing_in_first = set(second_lora_keys) - set(first_lora_keys)
        missing_in_second = set(first_lora_keys) - set(second_lora_keys)
        error_message = "LoRA models have different keys:\n"
        if missing_in_first:
            error_message += f"Keys missing in first LoRA: {missing_in_first}\n"
        if missing_in_second:
            error_message += f"Keys missing in second LoRA: {missing_in_second}\n"
        raise ValueError(error_message)

    # we need to load three loras as that is the limit of what we support - each name is "1", "2", "3"
    # these will then be enabled or disabled
    flux_base_model.load_lora_weights(path1, adapter_name="1")
    flux_base_model.load_lora_weights(path2, adapter_name="2")
    flux_base_model.load_lora_weights(path2, adapter_name="3")
    logger.info("Initialized base flux model")

    if should_compile:
        flux_base_model.image_encoder = torch.compile(flux_base_model.image_encoder)
        flux_base_model.text_encoder = torch.compile(flux_base_model.text_encoder)
        flux_base_model.text_encoder_2 = torch.compile(flux_base_model.text_encoder_2)
        flux_base_model.vae = torch.compile(flux_base_model.vae)
        flux_base_model.transformer = torch.compile(
            flux_base_model.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True
        )

    for i in range(5):
        start_time = time.time()
        # some changes to avoid waiting for 10 hours on CPU
        image = flux_base_model(
            "An image of a cat", num_inference_steps=1, guidance_scale=3.0, width=64, height=64
        ).images[0]
        if i == 0:
            logger.info(f"Warmup: {time.time() - start_time}")
        else:
            logger.info(f"Inference time: {time.time() - start_time}")
        utc_seconds = int(time.time())
        image.save(f"hotswap_{utc_seconds}.png")
        if i == 1:
            logger.info("Hotswapping lora one")
            flux_base_model.load_lora_weights(path1, adapter_name="1", hotswap=True)
        if i == 2:
            logger.info("Hotswapping lora two")
            flux_base_model.load_lora_weights(path2, adapter_name="2", hotswap=True)
            flux_base_model.load_lora_weights(path1, adapter_name="1", hotswap=True)


hotswap_bug_example()
```

I used the latest main branches of PEFT and diffusers. Could you please check if you can replicate? I did get a bunch of warnings; I haven't investigated whether those are expected or not, but let's deal with one issue at a time.
@BenjaminBossan thanks for the further investigation. What were the main changes?

Regarding the warnings: one of them can be safely ignored, but the other doesn't seem right.
Some additional comments: could you also remind me why we need three adapters here, and where is this documented? Apologies in advance.
I'm sure those changes should not affect the outcome.

Okay, then we need to investigate that.

I didn't try the

I don't think it's strictly necessary for the example; according to the reported error, already the first hotswap fails. I just followed the initial example closely.
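For reference, that first failing swap can be isolated with a minimal sketch along these lines (hypothetical local file names):

```python
import torch
from diffusers import FluxPipeline
from safetensors.torch import load_file

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_lora_hotswap(target_rank=128)

# load the first adapter normally, then swap the second one into the same slot;
# per the report, a swap like this raises the unexpected-keys RuntimeError
pipe.load_lora_weights(load_file("lora1.safetensors"), adapter_name="1")
pipe.load_lora_weights(load_file("lora2.safetensors"), adapter_name="1", hotswap=True)
```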
Describe the bug

When trying to hotswap multiple Flux LoRAs, you get a runtime error about unexpected keys:

```
RuntimeError: Hot swapping the adapter did not succeed, unexpected keys found: transformer_blocks.13.norm1.linear.lora_B.weight,
```

Reproduction

1. Download two Flux Dev LoRAs (this example uses http://base-weights.weights.com/cm9dm38e4061uon15341k47ss.zip and http://base-weights.weights.com/cm9dnj1840088n214rn9uych4.zip)
2. Unzip and load the safetensors into memory
Logs
System Info
Who can help?
@sayakpaul @yiyixuxu