Modular Diffusers Guiders #11311
base: modular-refactor
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@vladmandic's suggestion about having a universal start/stop parameter from here is now implemented too. Note, however, that the guiders should already support any kind of dynamic schedule, with multiple enablings/disablings per inference, if the user modifies the properties on the guider object (see this comment for an example). Batched inference is still supported too (in terms of multiple prompts and setting
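As an illustration of such a dynamic schedule, here is a rough sketch using a step-end callback; the callback signature and the guidance_scale attribute mutated here are assumptions about the guider API, not the final design:

def toggle_guidance(pipe, step_index, timestep, callback_kwargs):
    # hypothetical: run the middle third of inference unguided by mutating the
    # guider object between steps, then restore the original scale
    progress = step_index / pipe.num_timesteps
    pipe.guider.guidance_scale = 1.0 if 0.33 < progress < 0.66 else 7.5
    return callback_kwargs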
@@ -0,0 +1,271 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
For more context on why we need this, see #10875 and this comment.
I discussed with Dhruv and for now we should keep it. After one of the FBC or Guider PRs is merged to main, I can do the refactor and make use of decorators. This saves me the burden of implementing the same thing in both PRs and maintaining it until one gets merged, but rest assured I'll do the refactor before the next release.
def _register_attention_processors_metadata():
    # AttnProcessor2_0
    AttentionProcessorRegistry.register(
For now, only this and BasicTransformerBlock are relevant, since Modular Diffusers only supports SDXL. The rest comes from copying, but we keep it to avoid merge conflicts, since the FirstBlockCache PR will most likely be merged before Modular Diffusers.
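For illustration, the decorator-based registration mentioned above could look roughly like this (a sketch only; the registry shape and metadata fields are assumptions, not the final design):

_METADATA = {}

def register(**metadata):
    # hypothetical decorator: attach metadata at class-definition time instead of
    # collecting it in a separate _register_*_metadata() function
    def decorator(cls):
        _METADATA[cls] = metadata
        return cls
    return decorator

@register(supports_block_skipping=True)
class AttnProcessor2_0:
    pass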
return noise_cfg
def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: |
I think it would be easier to work with if we:
- provide a default method here in guider_utils.py to deal with a list of inputs like you specified: each element could be a tensor or a tuple/list of tensors. This logic should be mostly the same for different guiders, no?
- let each specific guider class define how to prepare each input element
Basically, the method here becomes something like this. Would that make sense?
def prepare_inputs(self, denoiser: torch.nn.Module, num_conditions: int, *args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> Tuple[List[torch.Tensor], ...]:
    """
    Prepares the inputs for the denoiser by processing each argument individually using a helper method.
    """
    list_of_inputs = []
    for arg in args:
        if isinstance(arg, (tuple, list)):
            if len(arg) != 2:
                raise ValueError("...")
        elif not isinstance(arg, torch.Tensor):
            raise ValueError("...")
        # each guider subclass defines how a single input element is prepared
        processed_input = self.prepare_input_single(arg, num_conditions)
        list_of_inputs.append(processed_input)
    return tuple(list_of_inputs)
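For illustration, the per-guider hook could then be as small as this for plain CFG (the name prepare_input_single and its behaviour are assumptions based on the suggestion above):

def prepare_input_single(self, arg, num_conditions: int) -> List[torch.Tensor]:
    # a (cond, uncond) pair already provides one tensor per condition, while a
    # single tensor (e.g. latents) is reused for every condition
    if isinstance(arg, (tuple, list)):
        return list(arg)
    return [arg] * num_conditions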
Sounds good, I'll update the implementations
    added_cond_kwargs=data.added_cond_kwargs,
    return_dict=False,
)[0]
data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)
is it possible to do something like this?
noise_pred_outputs = []
for batch_index, (...) in enumerate(zip(...)):
    latents_i = ...
    noise_pred = pipeline.unet(...)
    noise_pred_outputs = self.guider.prepare_and_add_output(pipeline.unet, noise_pred, noise_pred_outputs)
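If helpful, a rough sketch of what prepare_and_add_output could do internally, assuming it reuses the existing prepare_outputs hook (names here are illustrative only):

def prepare_and_add_output(self, denoiser, noise_pred, noise_pred_outputs):
    # prepare this condition's prediction and accumulate it for the guidance step
    noise_pred_outputs.append(self.prepare_outputs(denoiser, noise_pred))
    return noise_pred_outputs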
Hey, there were a few more changes related to guiders. Basically, they also need information like sigmas (see the explanation above for CFG++ if we go forward with implementing it that way), latent height/width (for methods like SEG/SAG), tensor formats (SAG), extra prompt information (methods like Attend-and-Excite), and probably more.
I haven't added SAG and A&E because it would be complicated to review with all the required changes. Since we're aiming for modularity, though, the design should allow for such use cases. I'm not quite sure how to proceed yet, but please take another look and LMK what you think.
if self._is_cfgpp_enabled():
    # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later!
    pred_cond = self._preds["pred_cond"]
    pred_uncond = self._preds["pred_uncond"]
    diff = pred_uncond - pred_cond
    pred = pred + diff * self.guidance_scale * self._sigma_next
The original repository implements CFG++ in a different way. I wanted to try and make it work without really modifying all our schedulers, so it's done this way. The math works out the same.
For context, in our schedulers we do:
new_sample = sample + model_output_after_cfg * (sigmas[i + 1] - sigmas[i])
which expands to:
new_sample = sample - model_output_after_cfg * sigmas[i] + model_output_after_cfg * sigmas[i + 1]
What we need to do for CFG++ is this instead:
new_sample = sample - model_output_after_cfg * sigmas[i] + model_output_uncond * sigmas[i + 1]
(This is only for EulerDiscreteScheduler and will differ for other schedulers.)
After working it out on paper, I found that the different schedulers don't really have to be modified if we add and subtract some terms after the scheduler step. We will need some specialized code (it can live either in this file or in the scheduler file) to add/subtract the right terms for each scheduler, so LMK how you think we should do it.
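For illustration, here is why the post-step correction in the snippet above gives the same result (a sketch, using model_output_after_cfg = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)):
new_sample_cfgpp - new_sample = (model_output_uncond - model_output_after_cfg) * sigmas[i + 1]
                              = guidance_scale * (model_output_uncond - model_output_cond) * sigmas[i + 1]
which is exactly diff * self.guidance_scale * self._sigma_next from the snippet, applied after the regular scheduler step.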
Nevermind, it's better to just do this in the scheduler
Also cc @DN6 for all the custom hook implementations.
@@ -668,7 +675,38 @@ def step(
    dt = self.sigmas[self.step_index + 1] - sigma_hat

    prev_sample = sample + derivative * dt

    if _use_cfgpp:
So we are hoping to find a scalable solution that can provide maximum support for community creativity. It isn't scalable if it requires code changes to the schedulers.
I think it can be manipulated inside the guider, no? Since we have all the variables in the pipeline state and all the components in the model state, you can use them to access the scheduler and the sigmas counter.
For Euler, yes, it is easy to add a correction term outside the scheduler step and make it work -- this is how it was originally implemented.
For DDIM, DPM++, and all the others, it quickly gets very complicated to handle all the correction terms correctly, since you need to recalculate a lot of variables for the original model_output, subtract them out, calculate the correct variables using model_pred_uncond, and add those in. I don't think that having specialized code in the guider to handle all usable schedulers, probably using isinstance checks, is a good approach.
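For reference, a minimal sketch of what the scheduler-side branch could look like for EulerDiscreteScheduler; how the unconditional prediction reaches the scheduler (here a model_output_uncond variable) is an assumption for illustration:

    prev_sample = sample + derivative * dt
    if _use_cfgpp:
        # replace the sigmas[i + 1] portion of the update with the unconditional prediction, i.e.
        # prev_sample = sample - model_output * sigmas[i] + model_output_uncond * sigmas[i + 1]
        # (assumes epsilon prediction and no sigma churn, so derivative == model_output)
        prev_sample = prev_sample + (model_output_uncond - model_output) * self.sigmas[self.step_index + 1]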
The following methods are currently supported:
Note: PAG is implemented as Skip Layer Guidance and does not have its own guider implementation. The equivalent SLG initialization is the skip_attention_scores=True config shown below.
Note: STG is also implemented as Skip Layer Guidance:
- skip_attention=False, skip_ff=True, skip_attention_scores=False
- skip_attention=True, skip_ff=False, skip_attention_scores=False
- skip_attention=True, skip_ff=True, skip_attention_scores=False
- skip_attention=False, skip_ff=False, skip_attention_scores=True (essentially PAG)
Note: You can use different SLG configs for different parts of the model. Create multiple configs and pass them as a list to skip_layer_config.
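For illustration, a hypothetical sketch of combining configs (the class names LayerSkipConfig and SkipLayerGuidance are assumptions; only the skip_* flags and skip_layer_config come from the notes above):

# hypothetical names, shown only to illustrate passing a list of configs
stg_config = LayerSkipConfig(skip_attention=True, skip_ff=True, skip_attention_scores=False)
pag_config = LayerSkipConfig(skip_attention=False, skip_ff=False, skip_attention_scores=True)
guider = SkipLayerGuidance(guidance_scale=7.5, skip_layer_config=[stg_config, pag_config])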
Minimal all guiders testing script
YiYi's modified full test script