[bitsandbytes] improve dtype mismatch handling for bnb + lora. #11270
Conversation
LGTM! Just a small question:
```python
raise ValueError(
    f"Model is in {model.dtype} dtype while the current module weight will be dequantized to {module.weight.quant_state.dtype} dtype. "
    f"Please pass {module.weight.quant_state.dtype} as `torch_dtype` in `from_pretrained()`."
)
module_weight = dequantize_bnb_weight(
```
Since we specified `dtype=model.dtype` in `dequantize_bnb_weight`, won't the `module_weight` have the same dtype as the model?
Yes it will. But the LoRA params would not be in that dtype, as they are derived early on from the `module_weight` data dtype. This is why, in the error trace, the error happens in `peft`.
To summarize, we have the following, right?
- Changed the LoRA params using the dtype from `module_weight` (this is maybe where `module.weight.quant_state.dtype` was used).
- Dequantized `module_weight` using `model.dtype` (so we are not actually using `module.weight.quant_state.dtype`, no?). The `model.dtype` value comes from `torch_dtype`.
-> dtype mismatch issue due to the LoRA params not having the same dtype as `module_weight` (sketched below).
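A minimal, purely illustrative sketch of that mismatch, assuming the checkpoint's `quant_state` stores `bfloat16` while the pipeline was loaded with `torch_dtype=torch.float16` (shapes and values are made up, not the actual diffusers/PEFT code path):

```python
import torch
import torch.nn.functional as F

quant_state_dtype = torch.bfloat16  # stands in for module.weight.quant_state.dtype
model_dtype = torch.float16         # stands in for model.dtype (set via torch_dtype)

# LoRA params pick up the module weight's dtype...
lora_A = torch.randn(4, 64, dtype=quant_state_dtype)
# ...while the dequantized / expanded base weight follows model.dtype.
base_weight = torch.randn(64, 64, dtype=model_dtype)

hidden_states = torch.randn(1, 64, dtype=quant_state_dtype)
lora_out = F.linear(hidden_states, lora_A)           # fine: both bfloat16
try:
    base_out = F.linear(hidden_states, base_weight)  # bfloat16 vs float16
except RuntimeError as err:
    print(err)  # "mat1 and mat2 must have the same dtype"-style failure
```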
Yeah. We don't really have any special treatment to handle the LoRA param dtype. Ccing @BenjaminBossan here.

> Dequantized `module_weight` using `model.dtype` (so we are not actually using `module.weight.quant_state.dtype`, no?). The `model.dtype` value comes from `torch_dtype`.

Well, we use the `quant_state`:

```python
output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
```

But then we also perform another type-casting:

```python
if dtype:
```
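Roughly, that dequantize-then-cast sequence amounts to the following simplified sketch (not the exact diffusers helper; the function name here is made up):

```python
import bitsandbytes as bnb


def dequantize_bnb_weight_sketch(weight, dtype=None):
    """Simplified sketch: dequantize a bnb 4-bit weight, then optionally cast it."""
    # Dequantization itself follows the checkpoint's quant_state ...
    output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
    # ... and only afterwards is the result cast to the requested dtype
    # (model.dtype in the call discussed above).
    if dtype:
        output_tensor = output_tensor.to(dtype)
    return output_tensor
```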
Just to clarify: this is unrelated to the LoRA parameters. Instead, what happens is that a PEFT `LoraLayer` wraps the base layer and calls `self.base_layer(x)`, which should just be the result from the original layer. Due to the change in dtype, we will encounter the dtype mismatch there.
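A rough, self-contained sketch of that wrapping (not the actual `peft` implementation; the class and attribute layout here are illustrative) to show where the `self.base_layer(x)` call sits relative to the LoRA weights:

```python
import torch
import torch.nn as nn


class LoraLinearSketch(nn.Module):
    """Simplified stand-in for a PEFT LoraLayer wrapping a base linear layer."""

    def __init__(self, base_layer: nn.Linear, rank: int = 4):
        super().__init__()
        self.base_layer = base_layer
        self.lora_A = nn.Linear(base_layer.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, base_layer.out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The wrapped base layer is called first; if x and the base layer's
        # weight disagree on dtype, the failure happens right here, before
        # any LoRA weight is touched.
        result = self.base_layer(x)
        return result + self.lora_B(self.lora_A(x))
```

So even with correctly cast LoRA weights, the `self.base_layer(x)` call alone is enough to trigger the error when the input and the base weight are in different dtypes.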
```python
from diffusers import DiffusionPipeline, FluxControlPipeline
from PIL import Image
import torch

pipe = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.bfloat16).to("cuda")
pipe("a dog", control_image=Image.new(mode="RGB", size=(256, 256)))
```

This works though.
Yeah, it happens inside the LoRA layer, but what I mean is that the LoRA weights are not involved; it's the call to the base layer that is causing the issue.
@BenjaminBossan @SunMarc instead of the current error raising proposal, this solves the issue:

```diff
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 2e241bc9f..080559357 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2367,7 +2367,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
                 with torch.device("meta"):
                     expanded_module = torch.nn.Linear(
-                        in_features, out_features, bias=bias, dtype=module_weight.dtype
+                        in_features, out_features, bias=bias, dtype=transformer.dtype
                     )
                 # Only weights are expanded and biases are not. This is because only the input dimensions
                 # are changed while the output dimensions remain the same. The shape of the weight tensor
```

Does this work for you? Currently, we keep the expanded module in `module_weight.dtype` (`src/diffusers/loaders/lora_pipeline.py`, line 2369 at `ea5a6a8`).
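For illustration, a standalone sketch of that expansion step under the proposed change (shapes, dtypes, and variable names are made up, not the exact `lora_pipeline.py` code):

```python
import torch

transformer_dtype = torch.bfloat16                        # stands in for transformer.dtype
module_weight = torch.randn(64, 32, dtype=torch.float32)  # e.g. a freshly dequantized weight
in_features, out_features, bias = 128, 64, False          # expanded input dimension

# As in the diff above, the replacement layer is created on the meta device,
# but tied to the transformer's dtype rather than module_weight.dtype.
with torch.device("meta"):
    expanded_module = torch.nn.Linear(
        in_features, out_features, bias=bias, dtype=transformer_dtype
    )

print(module_weight.dtype, expanded_module.weight.dtype)
# torch.float32 torch.bfloat16 -> the expanded layer tracks the model, not the dequantized weight
```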
If the current situation is just temporary and the proposed change solves the initial issue, then that's fine with me. I wonder if we can always rely on the `transformer.dtype`.
Eventually, we will resort to the
Wait, I'm confused now. I had understood your comment to suggest that you would rather do that instead of raising an error. But the PR currently still raises the error. 😖
My comment suggests the use of
This will include
Is this clearer now?
Got it now. PR LGTM.
What does this PR do?
If we try to do:
we will run into
Error
This PR fixes that.
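As a hedged illustration only (the mismatched `torch_dtype` and the LoRA repo id below are assumptions, not the exact snippet from this PR), the scenario concerned here is a bnb 4-bit Flux checkpoint loaded with a `torch_dtype` that differs from its `quant_state` dtype, combined with a Control LoRA:

```python
# Hypothetical repro sketch -- dtypes and repo ids are assumptions.
import torch
from diffusers import FluxControlPipeline

pipe = FluxControlPipeline.from_pretrained(
    "eramth/flux-4bit",         # the 4-bit bnb checkpoint from the example above
    torch_dtype=torch.float16,  # assumed to differ from the checkpoint's quant_state dtype
).to("cuda")
# Loading a Control LoRA exercises the module-expansion path where the
# dtype mismatch used to surface; the repo id below is illustrative.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
```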