[bitsandbytes] improve dtype mismatch handling for bnb + lora. #11270

Open · sayakpaul wants to merge 8 commits into main

Conversation

@sayakpaul (Member) commented on Apr 10, 2025:

What does this PR do?

If we try to do:

from diffusers import DiffusionPipeline, FluxControlPipeline
from PIL import Image
import torch

pipe = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

pipe("a dog", control_image=Image.new(mode="RGB", size=(256, 256)))

we will run into the following error:

Traceback (most recent call last):
  File "/fsx/sayak/diffusers/bnb_torch_dtype.py", line 8, in <module>
    pipe("a dog", control_image=Image.new(mode="RGB", size=(256, 256)))
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/pipelines/flux/pipeline_flux_control.py", line 835, in __call__
    noise_pred = self.transformer(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 445, in forward
    hidden_states = self.x_embedder(hidden_states)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 712, in forward
    result = self.base_layer(x, *args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: self and mat2 must have the same dtype, but got BFloat16 and Half

This PR fixes that.
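
Independently of the library fix, loading with a torch_dtype that matches the quantized checkpoint side-steps the mismatch. A sketch of that workaround (the float16 choice is an assumption based on the Half dtype in the trace above):

from diffusers import FluxControlPipeline
from PIL import Image
import torch

# assumption: the 4-bit checkpoint stores a float16 quant_state, so loading the
# pipeline in float16 keeps base weights, LoRA params, and activations consistent
pipe = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

pipe("a dog", control_image=Image.new(mode="RGB", size=(256, 256)))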

@sayakpaul requested review from DN6 and SunMarc on April 10, 2025 at 04:19.

@SunMarc (Member) left a comment:

LGTM! Just a small question.

On this snippet from the PR:

raise ValueError(
    f"Model is in {model.dtype} dtype while the current module weight will be dequantized to {module.weight.quant_state.dtype} dtype. "
    f"Please pass {module.weight.quant_state.dtype} as `torch_dtype` in `from_pretrained()`."
)
module_weight = dequantize_bnb_weight(

Since we specified dtype = model.dtype in dequantize_bnb_weight, won't module_weight have the same dtype as the model?

@sayakpaul (Member, Author):

Yes, it will. But the LoRA params would not be in that dtype, as they are derived early on from the module_weight's dtype. This is why, in the error trace, the error surfaces inside peft.

@SunMarc (Member), Apr 10, 2025:

To summarize, we have the following, right?

  • The LoRA params are created using the dtype from module_weight (this is maybe where module.weight.quant_state.dtype was used).
  • The module_weight is dequantized using model.dtype (so we are not actually using module.weight.quant_state.dtype there, no?). The model.dtype value comes from torch_dtype.

-> dtype mismatch, because the LoRA params do not have the same dtype as module_weight.

@sayakpaul (Member, Author):

Yeah. We don't really have any special treatment to handle the LoRA param dtype. Cc'ing @BenjaminBossan here.

> The module_weight is dequantized using model.dtype (so we are not actually using module.weight.quant_state.dtype there, no?). The model.dtype value comes from torch_dtype.

Well, we use the quant_state:

output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)

But then we also perform another type-casting:
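
Roughly, the helper looks something like this (a hedged sketch of the idea, not the exact diffusers implementation; the function name here is made up for illustration):

import bitsandbytes as bnb

def dequantize_bnb_weight_sketch(weight, dtype):
    # dequantize using the stored 4-bit quant_state ...
    output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
    # ... and then cast the result to the requested dtype (model.dtype); this is
    # the extra type-casting step mentioned above
    return output_tensor.to(dtype)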

Member:

Just to clarify: this is unrelated to the LoRA parameters. Instead, what happens is that a PEFT LoraLayer wraps the base layer and calls self.base_layer(x), whose output should just be the result of the original layer. Due to the change in dtype, we encounter the dtype mismatch there.
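
A minimal sketch of that mismatch (illustrative only; the float16/bfloat16 combination is an assumption based on the trace in the PR description):

import torch

# the base layer ends up in float16 (following the quantized checkpoint),
# while the pipeline activations are bfloat16 (following torch_dtype)
base_layer = torch.nn.Linear(64, 64, dtype=torch.float16)
x = torch.randn(1, 64, dtype=torch.bfloat16)

base_layer(x)  # raises a RuntimeError about mismatched dtypes (cf. the trace above)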

@sayakpaul (Member, Author):

from diffusers import DiffusionPipeline, FluxControlPipeline
from PIL import Image
import torch

pipe = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.bfloat16).to("cuda")

pipe("a dog", control_image=Image.new(mode="RGB", size=(256, 256)))

This works though.

Member:

Yeah, it happens inside the LoRA layer, but what I mean is that the LoRA weights are not involved; it's the call to the base layer that is causing the issue.

@sayakpaul (Member, Author):

@BenjaminBossan @SunMarc instead of the current error-raising proposal, this solves the issue:

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 2e241bc9f..080559357 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2367,7 +2367,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                     # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
                     with torch.device("meta"):
                         expanded_module = torch.nn.Linear(
-                            in_features, out_features, bias=bias, dtype=module_weight.dtype
+                            in_features, out_features, bias=bias, dtype=transformer.dtype
                         )
                     # Only weights are expanded and biases are not. This is because only the input dimensions
                     # are changed while the output dimensions remain the same. The shape of the weight tensor

Does this work for you?

Currently, we keep the expanded module as nn.Linear even when the underlying model is quantized (as is the case here). But we eventually plan to move to using the respective quantized linear layer (determined by the quantization backend being used).
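
For illustration, that future direction could look roughly like the sketch below, under the assumption that bitsandbytes 4-bit is the active backend (bnb.nn.Linear4bit stands in for whichever quantized layer the backend provides):

import torch
import bitsandbytes as bnb

in_features, out_features, bias = 128, 3072, True  # example shapes
is_quantized = True  # cf. the TODO in the diff above
model_dtype = torch.bfloat16

if is_quantized:
    # use the backend's quantized linear layer for the expanded module
    expanded_module = bnb.nn.Linear4bit(in_features, out_features, bias=bias, compute_dtype=model_dtype)
else:
    with torch.device("meta"):
        expanded_module = torch.nn.Linear(in_features, out_features, bias=bias, dtype=model_dtype)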


@BenjaminBossan (Member):

> Does this work for you?

If the current situation is just temporary and the proposed change solves the initial issue, then that's fine with me. I do wonder whether we can always rely on the dtype attribute of the model; there can be more than one dtype, right?

@sayakpaul (Member, Author):

> I do wonder whether we can always rely on the dtype attribute of the model; there can be more than one dtype, right?

Eventually, we will resort to the module_weight.dtype solution, as that is more precise. But this is a special case for now.

@BenjaminBossan (Member):

Wait, I'm confused now. I had understood your comment to suggest that you would rather do that instead of raising an error, but the PR currently still raises the error. 😖

@sayakpaul (Member, Author):

My comment suggests using model.dtype, not module_weight.dtype. For now, that should be okay for the case we're covering. I definitely want to follow what I suggested in #11270 (comment) and eventually move (through a future PR) to:

> But we eventually plan to move to using the respective quantized linear layer (determined by the quantization backend being used).

This will include:

> we will resort to the module_weight.dtype solution, as that is more precise

Is this clearer now?

@BenjaminBossan (Member) left a comment:

Got it now. PR LGTM.
