Hotswapping multiple LoRAs throws a peft key error. #11298

Open · jonluca opened this issue Apr 12, 2025 · 9 comments
Labels: bug (Something isn't working), lora

Comments

@jonluca commented Apr 12, 2025

Describe the bug

When trying to hotswap multiple Flux LoRAs, you get a runtime error about unexpected keys:

RuntimeError: Hot swapping the adapter did not succeed, unexpected keys found: transformer_blocks.13.norm1.linear.lora_B.weight,

Reproduction

Download two Flux Dev LoRAs (this example uses http://base-weights.weights.com/cm9dm38e4061uon15341k47ss.zip and http://base-weights.weights.com/cm9dnj1840088n214rn9uych4.zip)

Unzip and load the safetensors into memory

import time
import torch
import logging
from diffusers import FluxPipeline

logger = logging.getLogger(__name__)



class DownloadedLora:
    def __init__(self, state_dict):
        self.state_dict = state_dict

    @property
    def model(self):
        state_dict = self.state_dict
        # return a clone of the state dict to avoid modifying the original
        new_state_dict = {}
        for k, v in state_dict.items():
            new_state_dict[k] = v.clone().detach()
        return new_state_dict


def test_lora_hotswap():
    logger.info(f"Initializing flux model")

    # todo - compile https://github.com/huggingface/diffusers/pull/9453 when this gets merged
    flux_base_model: FluxPipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
    flux_base_model = flux_base_model.to("cuda")
    flux_base_model.enable_lora_hotswap(target_rank=128)

    # download and set the state dicts of two random loras
    # (first_state_dict / second_state_dict are the state dicts loaded from the
    # two safetensors above; the loading code is omitted here)
    first_lora = DownloadedLora(state_dict=first_state_dict)
    second_lora = DownloadedLora(state_dict=second_state_dict)

    # we need to load three loras as that is the limit of what we support - each name is "1", "2", "3"
    # these will then be enabled or disabled

    flux_base_model.load_lora_weights(first_lora.model, adapter_name="1")
    flux_base_model.load_lora_weights(second_lora.model, adapter_name="2")
    flux_base_model.load_lora_weights(second_lora.model, adapter_name="3")

    logger.info("Initialized base flux model")
    should_compile = False
    if should_compile:
        flux_base_model.image_encoder = torch.compile(flux_base_model.image_encoder)
        flux_base_model.text_encoder = torch.compile(flux_base_model.text_encoder)
        flux_base_model.text_encoder_2 = torch.compile(flux_base_model.text_encoder_2)
        flux_base_model.vae = torch.compile(flux_base_model.vae)
        flux_base_model.transformer = torch.compile(
            flux_base_model.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True
        )

    for i in range(5):
        start_time = time.time()
        image = flux_base_model("An image of a cat", num_inference_steps=4, guidance_scale=3.0).images[0]
        if i == 0:
            logger.info(f"Warmup: {time.time() - start_time}")
        else:
            logger.info(f"Inference time: {time.time() - start_time}")

        utc_seconds = int(time.time())

        image.save(f"hotswap_{utc_seconds}.png")

        if i == 1:
            logger.info("Hotswapping lora one")
            flux_base_model.load_lora_weights(first_lora.model, adapter_name="1", hotswap=True)
        if i == 2:
            logger.info("Hotswapping lora two")
            flux_base_model.load_lora_weights(second_lora.model, adapter_name="2", hotswap=True)
            flux_base_model.load_lora_weights(first_lora.model, adapter_name="1", hotswap=True)

Logs

2025-04-12 04:47:18 | INFO     | Initialized base flux model
100%|██████████| 4/4 [00:01<00:00,  3.64it/s]
2025-04-12 04:47:21 | INFO     | Warmup: 2.4211995601654053
100%|██████████| 4/4 [00:01<00:00,  3.79it/s]
2025-04-12 04:47:23 | INFO     | Inference time: 1.2886595726013184
2025-04-12 04:47:23 | INFO     | Hotswapping lora one
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/team/replay/python/hosted/utils/testing.py", line 708, in <module>
    main()
  File "/home/team/replay/python/hosted/utils/testing.py", line 704, in main
    test_lora_hotswap()
  File "/home/team/replay/python/hosted/utils/testing.py", line 667, in test_lora_hotswap
    flux_base_model.load_lora_weights(first_lora.model, adapter_name="1", hotswap=True)
  File "/home/team/.local/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py", line 1808, in load_lora_weights
    self.load_lora_into_transformer(
  File "/home/team/.local/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py", line 1899, in load_lora_into_transformer
    transformer.load_lora_adapter(
  File "/home/team/.local/lib/python3.11/site-packages/diffusers/loaders/peft.py", line 371, in load_lora_adapter
    hotswap_adapter_from_state_dict(
  File "/home/team/.local/lib/python3.11/site-packages/peft/utils/hotswap.py", line 431, in hotswap_adapter_from_state_dict
    raise RuntimeError(msg)
RuntimeError: Hot swapping the adapter did not succeed, unexpected keys found: transformer_blocks.14.ff.net.0.proj.lora_B.weight, single_transformer_blocks.7.attn.to_v.lora_B.weight, ...

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-5.10.0-34-cloud-amd64-x86_64-with-glibc2.31
  • Running on Google Colab?: No
  • Python version: 3.11.11
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.28.1
  • Transformers version: 4.50.3
  • Accelerate version: 1.6.0
  • PEFT version: 0.15.0
  • Bitsandbytes version: 0.45.3
  • Safetensors version: 0.5.3
  • xFormers version: 0.0.29.post3
  • Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@sayakpaul @yiyixuxu

@jonluca added the bug label Apr 12, 2025
@sayakpaul (Member)

My hunch is that the underlying LoRAs differ in their keys, which is why hotswapping is not working. Cc: @BenjaminBossan

@BenjaminBossan (Member) commented Apr 14, 2025

Yes, the error message indicates that the 2nd adapter targets a module that the 1st adapter does not target (transformer_blocks.14.ff.net.0.proj). I haven't checked which modules are actually being targeted, but if the 2nd adapter happens to target a superset of the 1st adapter's modules, changing the order in which they're loaded might resolve the issue.
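
For instance, an untested sketch like the following would show which modules each file targets, assuming both are plain LoRA safetensors with `.lora_A`/`.lora_B` keys (file names are placeholders):

from safetensors.torch import load_file

def target_modules(state_dict):
    # module names without the trailing ".lora_A.weight" / ".lora_B.weight"
    return {key.rsplit(".lora_", 1)[0] for key in state_dict if ".lora_" in key}

modules_1 = target_modules(load_file("first_lora.safetensors"))   # placeholder path
modules_2 = target_modules(load_file("second_lora.safetensors"))  # placeholder path
print("targeted only by adapter 1:", sorted(modules_1 - modules_2))
print("targeted only by adapter 2:", sorted(modules_2 - modules_1))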

If the two adapters target disjoint sets of modules, that would unfortunately not work. A workaround would be to create a third adapter that targets the union of modules from the two original adapters, load that one first, then hotswap the original adapters in.
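
Roughly, that workaround could look like this (untested sketch; `make_union_lora` is not an existing diffusers/PEFT helper, just an illustration, and the file names are placeholders):

import torch
from diffusers import FluxPipeline
from safetensors.torch import load_file

def make_union_lora(sd_a, sd_b):
    # adapter A's weights, plus zero tensors for every key that only adapter B
    # has, so the result targets the union of both adapters' modules
    union = dict(sd_a)
    for key, tensor in sd_b.items():
        if key not in union:
            union[key] = torch.zeros_like(tensor)
    return union

sd_a = load_file("first_lora.safetensors")   # placeholder path
sd_b = load_file("second_lora.safetensors")  # placeholder path

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_lora_hotswap(target_rank=128)
pipe.load_lora_weights(make_union_lora(sd_a, sd_b), adapter_name="default")
# afterwards, either original adapter can be hotswapped into "default"
pipe.load_lora_weights(sd_a, adapter_name="default", hotswap=True)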

@yiyixuxu added the lora label Apr 14, 2025
@jonluca (Author) commented Apr 16, 2025

The two LoRAs have identical keys; here's a complete example that should be runnable / debuggable.

This example also recreates the state dict in case there was some other bug with mutating the state dict during the hotswap.

import time
import torch
import requests
import io
import zipfile
from diffusers import FluxPipeline
import logging

logger = logging.getLogger(__name__)


def hotswap_bug_example():
    def get_safetensor_from_zip(url):
        response = requests.get(url, stream=True)
        zip_bytes = io.BytesIO(response.content)  # Use response.content instead of response.raw.read()

        with zipfile.ZipFile(zip_bytes) as zf:
            for name in zf.namelist():
                if name.endswith(".safetensors"):
                    return zf.read(name)
        return None

    lora_one_sft = get_safetensor_from_zip("https://base-weights.weights.com/cm9dm38e4061uon15341k47ss.zip")
    lora_two_sft = get_safetensor_from_zip("https://base-weights.weights.com/cm9dnj1840088n214rn9uych4.zip")

    logger.info("Initializing flux model")

    flux_base_model: FluxPipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
    flux_base_model = flux_base_model.to("cuda")
    flux_base_model.enable_lora_hotswap(target_rank=128)
    from safetensors.torch import load

    # download and set the state dicts of two random loras
    first_lora = load(lora_one_sft)
    second_lora = load(lora_two_sft)

    first_lora_keys = first_lora.keys()
    second_lora_keys = second_lora.keys()

    # make sure the keys are the same
    if set(first_lora_keys) != set(second_lora_keys):
        missing_in_first = set(second_lora_keys) - set(first_lora_keys)
        missing_in_second = set(first_lora_keys) - set(second_lora_keys)

        error_message = "LoRA models have different keys:\n"
        if missing_in_first:
            error_message += f"Keys missing in first LoRA: {missing_in_first}\n"
        if missing_in_second:
            error_message += f"Keys missing in second LoRA: {missing_in_second}\n"

        raise ValueError(error_message)

    # we need to load three loras as that is the limit of what we support - each name is "1", "2", "3"
    # these will then be enabled or disabled

    flux_base_model.load_lora_weights(load(lora_one_sft), adapter_name="1")
    flux_base_model.load_lora_weights(load(lora_two_sft), adapter_name="2")
    flux_base_model.load_lora_weights(load(lora_two_sft), adapter_name="3")

    logger.info("Initialized base flux model")
    should_compile = False
    if should_compile:
        flux_base_model.image_encoder = torch.compile(flux_base_model.image_encoder)
        flux_base_model.text_encoder = torch.compile(flux_base_model.text_encoder)
        flux_base_model.text_encoder_2 = torch.compile(flux_base_model.text_encoder_2)
        flux_base_model.vae = torch.compile(flux_base_model.vae)
        flux_base_model.transformer = torch.compile(
            flux_base_model.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True
        )

    for i in range(5):
        start_time = time.time()
        image = flux_base_model("An image of a cat", num_inference_steps=4, guidance_scale=3.0).images[0]
        if i == 0:
            logger.info(f"Warmup: {time.time() - start_time}")
        else:
            logger.info(f"Inference time: {time.time() - start_time}")

        utc_seconds = int(time.time())

        image.save(f"hotswap_{utc_seconds}.png")

        if i == 1:
            logger.info("Hotswapping lora one")
            flux_base_model.load_lora_weights(load(lora_one_sft), adapter_name="1", hotswap=True)
        if i == 2:
            logger.info("Hotswapping lora two")
            flux_base_model.load_lora_weights(load(lora_two_sft), adapter_name="2", hotswap=True)
            flux_base_model.load_lora_weights(load(lora_one_sft), adapter_name="1", hotswap=True)



hotswap_bug_example()

@BenjaminBossan (Member)

Thanks @jonluca for providing a reproducer. Could you be so kind as to upload the LoRAs to HF (safetensors format)?

@jonluca (Author) commented Apr 16, 2025

The LoRAs are in safetensors format. Do you mean without the zip file? If so, sure: https://base-weights.weights.com/cm9dm38e4061uon15341k47ss.safetensors and https://base-weights.weights.com/cm9dnj1840088n214rn9uych4.safetensors
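
For anyone reproducing, they can be fetched to local files with something like the following (the local file names are arbitrary):

import requests

urls = {
    "lora_one.safetensors": "https://base-weights.weights.com/cm9dm38e4061uon15341k47ss.safetensors",
    "lora_two.safetensors": "https://base-weights.weights.com/cm9dnj1840088n214rn9uych4.safetensors",
}
for filename, url in urls.items():
    with open(filename, "wb") as f:
        f.write(requests.get(url).content)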

@BenjaminBossan (Member)

Thanks for uploading the files directly. I took your script with some small modifications and it ran without error for me:

import time
import torch
from diffusers import FluxPipeline
import logging
from safetensors.torch import load_file

logger = logging.getLogger(__name__)
device = "cpu"  # can't fit on GPU ;((
dtype = torch.bfloat16
path1 = "<path-to-safetensors-1>"  # placeholder: local path to the first LoRA
path2 = "<path-to-safetensors-2>"  # placeholder: local path to the second LoRA
should_compile = False

def hotswap_bug_example():
    logger.info("Initializing flux model")

    flux_base_model: FluxPipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=dtype,
    )
    flux_base_model = flux_base_model.to(device)
    flux_base_model.enable_lora_hotswap(target_rank=128)

    # download and set the state dicts of two random loras
    first_lora = load_file(path1)
    second_lora = load_file(path2)

    first_lora_keys = first_lora.keys()
    second_lora_keys = second_lora.keys()

    # make sure the keys are the same
    if set(first_lora_keys) != set(second_lora_keys):
        missing_in_first = set(second_lora_keys) - set(first_lora_keys)
        missing_in_second = set(first_lora_keys) - set(second_lora_keys)

        error_message = "LoRA models have different keys:\n"
        if missing_in_first:
            error_message += f"Keys missing in first LoRA: {missing_in_first}\n"
        if missing_in_second:
            error_message += f"Keys missing in second LoRA: {missing_in_second}\n"

        raise ValueError(error_message)

    # we need to load three loras as that is the limit of what we support - each name is "1", "2", "3"
    # these will then be enabled or disabled

    flux_base_model.load_lora_weights(path1, adapter_name="1")
    flux_base_model.load_lora_weights(path2, adapter_name="2")
    flux_base_model.load_lora_weights(path2, adapter_name="3")

    logger.info("Initialized base flux model")
    if should_compile:
        flux_base_model.image_encoder = torch.compile(flux_base_model.image_encoder)
        flux_base_model.text_encoder = torch.compile(flux_base_model.text_encoder)
        flux_base_model.text_encoder_2 = torch.compile(flux_base_model.text_encoder_2)
        flux_base_model.vae = torch.compile(flux_base_model.vae)
        flux_base_model.transformer = torch.compile(
            flux_base_model.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True
        )

    for i in range(5):
        start_time = time.time()
        # some changes to avoid waiting for 10 hours on CPU
        image = flux_base_model("An image of a cat", num_inference_steps=1, guidance_scale=3.0, width=64, height=64).images[0]
        if i == 0:
            logger.info(f"Warmup: {time.time() - start_time}")
        else:
            logger.info(f"Inference time: {time.time() - start_time}")

        utc_seconds = int(time.time())

        image.save(f"hotswap_{utc_seconds}.png")

        if i == 1:
            logger.info("Hotswapping lora one")
            flux_base_model.load_lora_weights(path1, adapter_name="1", hotswap=True)
        if i == 2:
            logger.info("Hotswapping lora two")
            flux_base_model.load_lora_weights(path2, adapter_name="2", hotswap=True)
            flux_base_model.load_lora_weights(path1, adapter_name="1", hotswap=True)

hotswap_bug_example()

I used the latest main branches of PEFT and diffusers. Could you please check if you can replicate?

I did get a bunch of warnings like

No LoRA keys associated to CLIPTextModel found with the prefix='text_encoder'. This is safe to ignore if LoRA state dict didn't originally have any CLIPTextModel related params. You can also try specifying `prefix=None` to resolve the warning. Otherwise, open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new
/home/vinh/work/forks/peft/src/peft/tuners/tuners_utils.py:168: UserWarning: Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!
  warnings.warn(
Loading adapter weights from state_dict led to missing keys in the model: transformer_blocks.0.ff.net.2.lora_A.1.weight, 
...

I haven't investigated whether those are expected or not, but let's deal with one issue at a time.
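
(An untested way to dig into that, assuming `flux_base_model` from the script above, would be to list the LoRA parameters that were actually injected into the transformer and compare them against the names in the warning:)

injected = sorted(
    name for name, _ in flux_base_model.transformer.named_parameters() if "lora_" in name
)
print(f"{len(injected)} LoRA parameters injected")
print(injected[:5])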

@sayakpaul (Member)

@BenjaminBossan thanks for the further investigation. What were the main changes?

No LoRA keys associated to CLIPTextModel found with the prefix='text_encoder'.

It can be safely ignored.

Loading adapter weights from state_dict led to missing keys in the model: transformer_blocks.0.ff.net.2.lora_A.1.weight,

This doesn't seem right.

@sayakpaul (Member)

Some additional comments:

flux_base_model.image_encoder = torch.compile(flux_base_model.image_encoder)

The flux_base_model should not have an image encoder, especially not this checkpoint: "black-forest-labs/FLUX.1-dev".

Could you also remind me why we need three adapters here and where this is documented? Apologies in advance.

@BenjaminBossan (Member)

What were the main changes?

  1. Loading code: load the state dicts from disk instead of downloading them
  2. Move model to CPU and create very small images with only 1 step

I'm sure those changes should not affect the outcome.

This doesn't seem right.

Okay, then we need to investigate that.

The flux_base_model should not have an image encoder. Especially this checkpoint: "black-forest-labs/FLUX.1-dev".

I didn't try the should_compile = True branch, following the initial example.

Could you also remind me why do we need three apdaters here and where is this documented? Apologies in advance.

I don't think it's strictly necessary for the example; according to the reported error, the first hotswap already fails. I just followed the initial example closely.
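
A stripped-down variant along these lines (untested; the paths are the same placeholders as in my script above) should be enough to check whether the very first hotswap already fails:

import torch
from diffusers import FluxPipeline

path1 = "<path-to-safetensors-1>"  # placeholder
path2 = "<path-to-safetensors-2>"  # placeholder

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_lora_hotswap(target_rank=128)
pipe.load_lora_weights(path1, adapter_name="1")
pipe.load_lora_weights(path2, adapter_name="1", hotswap=True)  # the step that reportedly raises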
