🎯 Guiders (support literally everything...?) #11205

Draft · wants to merge 21 commits into base: main

Conversation

@a-r-r-o-w (Member) commented Apr 4, 2025

See #11311 instead of this PR.

Guiders

This PR aims to introduce all the commonly used classifier-free guidance methods to all pipelines, with compatibility for modular diffusers pipelines. The following methods are currently supported: Classifier-Free Guidance (CFG), CFG-Zero*, Skip Layer Guidance (SLG), Perturbed Attention Guidance (PAG), and Adaptive Projected Guidance (APG).

The PR introduces a GuidanceMixin which defines the general structure of a guider. The guider exposes a few methods that need to be invoked for sampling to work correctly, which means existing pipelines need to be modified to use the guiders. In return, new guidance methods, or variants of existing ones, can be used without changes to diffusers core.
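To make the structure concrete, here is a rough sketch of what a minimal guider could look like. It is modelled on the ClassifierFreeZeroStarGuidance implementation shared later in this thread; the exact GuidanceMixin interface in the PR may differ.

import math
from typing import Optional

import torch


class MinimalClassifierFreeGuidance:  # would subclass GuidanceMixin in the PR
    # names (and order) of the predictions that forward() expects
    _input_predictions = ["pred_cond", "pred_uncond"]

    def __init__(self, guidance_scale: float = 7.5):
        self.guidance_scale = guidance_scale

    @property
    def num_conditions(self) -> int:
        # tells the pipeline how many denoiser forward passes to run per step
        return 1 if math.isclose(self.guidance_scale, 1.0) else 2

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        # plain classifier-free guidance: uncond + scale * (cond - uncond)
        if math.isclose(self.guidance_scale, 1.0):
            return pred_cond
        return pred_uncond + self.guidance_scale * (pred_cond - pred_uncond)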

Let's start with a usage example:

import torch
from diffusers import CogView4Pipeline
from diffusers.guiders import AdaptiveProjectedGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, PerturbedAttentionGuidance, SkipLayerGuidance
from diffusers.hooks import LayerSkipConfig

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Select a guidance method! Each `guidance = ...` assignment below overrides
# the previous one, so keep only the method you want to use.
guidance = ClassifierFreeGuidance(guidance_scale=5.0)

guidance = SkipLayerGuidance(
    guidance_scale=5.0,
    skip_layer_guidance_scale=2.8,
    skip_layer_guidance_start=0.01,
    skip_layer_guidance_stop=0.3,
    skip_layer_config=LayerSkipConfig(indices=[9, 10, 11]),
)

guidance = ClassifierFreeZeroStarGuidance(guidance_scale=5.0, zero_init_steps=1)

guidance = AdaptiveProjectedGuidance(guidance_scale=12.0, adaptive_projected_guidance_momentum=0.2, adaptive_projected_guidance_rescale=15.0)

guidance = PerturbedAttentionGuidance(pag_applied_layers="transformer_blocks.9", guidance_scale=5.0, pag_scale=2.0)

prompt = "A photo of an astronaut riding a horse on mars"
image = pipe(prompt, guidance=guidance, generator=torch.Generator().manual_seed(42)).images[0]
image.save("output.png")

As can be seen, the user-facing changes are minimal: create a guider and pass it to the pipeline. For any parameters that are currently supported in pipelines but duplicated by guiders (for example, guidance_scale, guidance_rescale), we can raise a deprecation warning in favour of passing guidance. There should be no difference in outputs between the guider versions and the existing pipeline implementations (if there are any differences, we'd appreciate help testing and will fix them ASAP!)

Guidance methods

[Image grid comparing outputs from APG, CFG, CFG-Zero*, PAG, and SLG]

Outputs with First Block Cache (threshold=0.15)

[Image grid; prompt: "Gotham-city Joker walking through busy streets, Arthur Fleck, 8k, hyper-realistic, surrealism"]

For methods like SLG and STG (which is really just SLG), we maintain some metadata on what each attention processor/transformer block returns (whether just hidden_states, or a tuple of (hidden_states, encoder_hidden_states)), so that we can perform block skipping by mapping the input directly to the output. Three modes are currently supported, with a fourth planned; they can be configured differently for different blocks, in any combination:

  • Skipping the transformer block (default behaviour, with skip_attention=True and skip_ff=True). This corresponds to STG's block-skip mode
  • Skipping only the attention layer or only the feed-forward layer. This corresponds to STG-a mode
  • Skipping the attention score computation. This corresponds to STG-av mode
  • Skipping the attention residual (not yet supported, since no model has needed it so far). This corresponds to STG-r mode

Since STG is simply SLG applied to video models, there's nothing stopping us from applying the various block skips to image models. The main idea behind these methods (SLG, STG, PAG, SEG, SAG, etc.) is to introduce a conditional-but-perturbed path that performs something similar to Autoguidance, which can significantly improve generation quality in most cases.

Here's an example of the various SLG methods with CogView4:

Code
import torch
from diffusers import CogView4Pipeline, LayerSkipConfig, SkipLayerGuidance

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "Nike sneaker concept, made out of cotton candy clouds , luxury, futurist, stunning unreal engine render, product photography, 8k, hyper-realistic. surrealism"

slg1 = SkipLayerGuidance(
    guidance_scale=5.0,
    skip_layer_guidance_scale=2.8,
    skip_layer_guidance_start=0.01,
    skip_layer_guidance_stop=0.3,
    skip_layer_config=LayerSkipConfig(indices=[24, 31], fqn="transformer_blocks", skip_attention=True, skip_ff=True),
)
slg2 = SkipLayerGuidance(
    guidance_scale=5.0,
    skip_layer_guidance_scale=2.8,
    skip_layer_guidance_start=0.01,
    skip_layer_guidance_stop=0.3,
    skip_layer_config=LayerSkipConfig(indices=[24, 31], fqn="transformer_blocks", skip_attention=True, skip_ff=False),
)
slg3 = SkipLayerGuidance(
    guidance_scale=5.0,
    skip_layer_guidance_scale=2.8,
    skip_layer_guidance_start=0.01,
    skip_layer_guidance_stop=0.3,
    skip_layer_config=LayerSkipConfig(indices=[24, 31], fqn="transformer_blocks", skip_attention=False, skip_ff=True),
)
slg4 = SkipLayerGuidance(
    guidance_scale=5.0,
    skip_layer_guidance_scale=2.8,
    skip_layer_guidance_start=0.01,
    skip_layer_guidance_stop=0.3,
    skip_layer_config=LayerSkipConfig(indices=[24, 31], fqn="transformer_blocks", skip_attention=False, skip_attention_scores=True, skip_ff=False),
)

for guidance in [slg1, slg2, slg3, slg4]:
    image = pipe(prompt, guidance=guidance, generator=torch.Generator().manual_seed(42)).images[0]
    image.save(f"output_{guidance.skip_layer_config[0].skip_attention}_{guidance.skip_layer_config[0].skip_ff}_{guidance.skip_layer_config[0].skip_attention_scores}.png")
[Image grid: Skip blocks | Skip attention | Skip ff | Skip attention scores]

Major changes

  • Guiders require some way of telling the pipeline how many conditional and unconditional inputs to prepare, and sometimes need to prepare a model to do some complex operations (see the loop sketch after this list):
    • Setting state: some guidance methods require knowing the current inference step and timestep. This is done with set_state().
    • Preparing inputs: prepare_inputs() needs to be called with all values used for conditioning, so that the guider can return appropriately replicated tensors in the order it expects.
    • Preparing outputs: prepare_outputs() is called so that the inputs to the guidance computation can be set correctly.
    • Preparing models: some methods need to attach hooks, change the attention processor, or peek into the internal state of the denoiser at inference time. prepare_models() is called to do this, and any cleanup is done by calling cleanup_models().
  • Throughout diffusers, in most if not all pipelines, we concatenate the positive and negative embeddings to do batched inference. This roughly doubles the memory used by intermediate activations, and it is not done consistently (sometimes it is avoided to deal with memory usage; other times it is simply not possible due to different tensor shapes). We now always run an individual forward pass for each of the conditional, unconditional and additional-condition batches*. I believe doing it this way will significantly help us in writing caching techniques and in supporting algorithms that we previously couldn't implement easily.

* We can still do batched inference in two ways: (1) providing a list of prompts or num_images_per_prompt > 1, and (2) writing a simple hook that concatenates any tensors that can be batched. We can maintain metadata about which parameters can be batched to simplify such a hook.
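To illustrate how these pieces fit together, here is a rough sketch of how a pipeline's denoising loop might drive the guider. Variable names and exact signatures (set_state, prepare_inputs, prepare_outputs) are placeholders inferred from the description above, not the PR's actual code.

# illustrative pseudocode of the guider lifecycle inside a denoising loop
for i, t in enumerate(timesteps):
    guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
    guidance.prepare_models(transformer)

    # the guider replicates the conditioning tensors and orders them as it expects
    latents_list, embeds_list = guidance.prepare_inputs(latents, prompt_embeds, negative_prompt_embeds)

    # one forward pass per condition instead of a single concatenated batch
    preds = [transformer(x, embeds, timestep=t) for x, embeds in zip(latents_list, embeds_list)]

    # register the predictions with the guider, then combine them (e.g. CFG)
    preds = guidance.prepare_outputs(preds)
    noise_pred = guidance.forward(*preds)

    latents = scheduler.step(noise_pred, t, latents).prev_sample
    guidance.cleanup_models(transformer)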

Using with caching techniques

Some methods, like First Block Cache and Pyramid Attention Broadcast, will work with the new guider design. However, since the cache implementations are very brittle and quite early in development, they may not work with all pipelines (for example, the current FasterCache implementation assumes batched conditional & unconditional inputs). This will be gradually improved, but caching should be ready to use with the new guider design with no user-facing changes.

+ from diffusers import FirstBlockCacheConfig

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

+ pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.15))
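Put together with the earlier usage example, the combined snippet would look like the following (the guider choice is arbitrary; the threshold is the one used for the grid above):

import torch
from diffusers import CogView4Pipeline, FirstBlockCacheConfig
from diffusers.guiders import ClassifierFreeGuidance

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# enable First Block Cache on the denoiser; works alongside the guider
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.15))

guidance = ClassifierFreeGuidance(guidance_scale=5.0)
image = pipe("A photo of an astronaut riding a horse on mars", guidance=guidance).images[0]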

Questions

  • CFG-Zero* can generally be used with any guidance method (at least, zeroing the predictions for the first few steps seems helpful with all methods from my limited testing). Should we add those options to all guiders?
  • With the new PAG guider, should we deprecate existing PAG-specific pipelines?
  • SLG has start and stop parameters that determine the interval of inference steps in which it is active. However, this should be possible for every guider, not just SLG. The guiders are written so that they can be dynamically enabled/disabled at any inference step, so users can already do this with callbacks. Should we:
    • add start and stop parameters to all guiders? (say, because it is simpler and easier for an end user than callbacks)
    • or deprecate the start/stop parameters from SLG and promote the use of callbacks to achieve the same?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@bghira (Contributor) commented Apr 6, 2025

this slows every model that uses CFG down by half, since we're no longer batching; batched CFG merely slows inference down by 15-40% depending on the model and hardware in use

@a-r-r-o-w (Member, Author)

Batching can/will be implemented lazily, so I wouldn't be very concerned with performance at the moment, although that'll require more work, which we can do once we finalize the design.

@bghira (Contributor) commented Apr 6, 2025

doesn't it seem like moving stop/start logic solely into callbacks makes the pipelines less useful than before? downstream consumers will all have to implement essentially the same wrapper interface, again and again, that "monitors" the pipeline's progress so we can change its behaviour.

if the logic moves anywhere, it should at least be into the guider, since the guider is essentially a big fat callback implementation anyway. all of that should be abstracted away; otherwise it's not a pipeline the dev is interacting with, but a highly mutable object that requires state tracking. that shifts the burden from pipeline maintainers to pipeline users.

@CodeExplode

Since you mentioned Attend-and-Excite, there were two newer papers with repos which seemed to improve on it.

Divide-and-Bind: https://github.com/boschresearch/Divide-and-Bind

Separate-and-Enhance: https://zpbao.github.io/projects/SepEn/

@a-r-r-o-w (Member, Author) commented Apr 7, 2025

doesn't it seem like moving stop/start logic solely into callbacks makes it so that the pipelines are less useful than before? downstream consumers will have to all essentially implement the same wrapper interface again and again that "monitors" the pipeline progress so that we can change its behaviour.

The current guider implementation should be compatible with any dynamic schedule (for example, changing the guidance scale between 1.0 and some N multiple times, as in cosine schedules), not just starting and stopping over an interval. However, I understand that start/stop is convenient, hence the question, and it's nice to hear your thoughts on that. If we decide to move forward with this, there will of course be guarantees that whatever could be done previously is still possible with the new implementation:

import torch
from diffusers import CogView4Pipeline, ClassifierFreeGuidance

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "Nike sneaker concept, made out of cotton candy clouds , luxury, futurist, stunning unreal engine render, product photography, 8k, hyper-realistic. surrealism"
num_inference_steps = 50
cfg = ClassifierFreeGuidance(guidance_scale=5.0)

def guidance_callback(pipe, i, t, kwargs):
    if i < 10:
        cfg.guidance_scale = 5.0
    elif i < 20:
        cfg.guidance_scale = 1.0
    elif i < 30:
        cfg.guidance_scale = 2.5
    else:
        cfg.guidance_scale = 6.0
        
    return kwargs

image = pipe(prompt, guidance=cfg, num_inference_steps=num_inference_steps, generator=torch.Generator().manual_seed(42), callback_on_step_end=guidance_callback).images[0]
image.save("output.png")

@WeichenFan

Hi @a-r-r-o-w, I noticed that the implementation of CFG-Zero* may have some potential bugs.

I've attached the fixed version here:

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional

import torch

from .guider_utils import GuidanceMixin, rescale_noise_cfg


class ClassifierFreeZeroStarGuidance(GuidanceMixin):
    """
    Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
    This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
    guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
    process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
    quality of generated images.
    The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
            deterioration of image quality.
        zero_init_steps (`int`, defaults to `1`):
            The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        optimize_uncondition (`bool`, defaults to `True`):
            Whether to optimize the unconditional or conditional velocity.
            If `True`, the optimization is applied to the unconditional prediction (as in the original CFG-Zero★ paper).
            If `False`, the optimization is applied to the conditional prediction instead.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]
    
    def __init__(
        self,
        guidance_scale: float = 7.5,
        zero_init_steps: int = 1,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        optimize_uncondition: bool = True,
    ):
        self.guidance_scale = guidance_scale
        self.zero_init_steps = zero_init_steps
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        self.optimize_uncondition = optimize_uncondition

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        pred = None

        if self._step < self.zero_init_steps:
            pred = torch.zeros_like(pred_cond)
        elif math.isclose(self.guidance_scale, 1.0):
            pred = pred_cond
        else:
            pred_cond_flat = pred_cond.flatten(1)
            pred_uncond_flat = pred_uncond.flatten(1)

            # "self.optimize_uncondition" lets the user choose whether to optimize the
            # unconditional or the conditional prediction. By default, we optimize the
            # unconditional part, as in the paper.
            #
            # alpha must be applied to all of the conditional (or unconditional) terms:
            # the previous implementation missed these terms in the shift, so the shift
            # must be computed only after alpha has been applied.

            if self.optimize_uncondition:
                alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
                alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
                pred_uncond = pred_uncond * alpha
            else:
                alpha = cfg_zero_star_scale(pred_uncond_flat, pred_cond_flat)
                alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
                pred_cond = pred_cond * alpha

            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if not math.isclose(self.guidance_scale, 1.0):
            num_conditions += 1
        return num_conditions


def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    cond = cond.float()
    uncond = uncond.float()
    dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
    squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    scale = dot_product / squared_norm
    return scale.to(cond.dtype)
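
For context, the scale computed by cfg_zero_star_scale is the optimal projection coefficient from the CFG-Zero* paper: it is the least-squares solution

$$\alpha^{*} = \arg\min_{\alpha}\,\lVert v_{\mathrm{cond}} - \alpha\, v_{\mathrm{uncond}}\rVert^{2} = \frac{v_{\mathrm{cond}}^{\top} v_{\mathrm{uncond}}}{\lVert v_{\mathrm{uncond}}\rVert^{2}},$$

which is exactly the dot_product / squared_norm expression above (with eps added for numerical stability).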

Check the comparisons here:
[Image: side-by-side comparison of outputs before and after the fix]

@a-r-r-o-w (Member, Author)

@WeichenFan Thank you for the review and fix! I'm working on a different version for "modular diffusers" and will be sure to make this update there.

Currently, we don't have a plan to support this in the existing pipelines, because we want to keep the barrier to entry as low as possible and the existing pipelines very simple, so this PR might not receive any further updates.

@bghira (Contributor) commented Apr 12, 2025

so this idea is or isn't going to happen? it's hard to keep up with all of this when internal decisions are being made without updating the community. just this week we decided to start implementing stuff based on Guiders. now we won't have access to it in old pipelines? what is the actual plan here?

@a-r-r-o-w (Member, Author) commented Apr 12, 2025

@bghira This is a draft PR. It was never promised that this would land in core diffusers. Its purpose is to show a proof of concept so that others on the team can review it and we can gather feedback. Downstream applications should implement things based on what's available in the main branch; anything else should be considered unstable, all the more so if it is marked draft.

Rest assured, it is happening. It will be part of modular diffusers #9672. A larger goal of diffusers is to provide clean, "study-able", minimal, copy-pastable implementations for students and researchers; the "normal" pipelines are exactly that. Doing this PoC showed us that we might get a cleaner implementation overall, but at the cost of further layers of abstraction ("the big fat callback", as you put it), which definitely raises the barrier to entry compared to the two lines that implement CFG.

I'll let @yiyixuxu comment more.

@bghira (Contributor) commented Apr 12, 2025

modular diffusers looks like it's another quarter away from being ready for review, and then it takes a lot more dev work to switch to than this design, which looked like a nice halfway transition instead.

currently it's just a wait-wait-wait game for any kind of ability to add PAG to the things we want to. so, obviously, people are going to see this PR, without any hint that it's not a solidified idea that will happen, and decide to start making use of it, since it solves big problems.

@yiyixuxu (Collaborator) commented Apr 12, 2025

@bghira
sorry, we won't be able to support guiders in all pipelines. They will be gradually rolled out with modular diffusers, which will be the focus of our next release.
(sorry you won't be getting guiders sooner, but modular will never happen if we don't spend our time carefully)

@vladmandic (Contributor)

CFG-Zero* can generally be used with any guidance method (at least, zeroing the predictions for the first few steps seems helpful with all methods from my limited testing). Should we add those options to all guiders?

i did a few tests with this and i'd say yes.

With the new PAG guider, should we deprecate existing PAG-specific pipelines?

eventually, sure. not in the first release - we need to allow users time to migrate.

SLG has start and stop parameters that determine the interval of inference steps in which it is active. However, this should be possible for every guider, not just SLG. The guiders are written so that they can be dynamically enabled/disabled at any inference step, so users can already do this with callbacks.
Should we: add start and stop parameters to all guiders (say, because it is simpler and easier for an end user than callbacks), or deprecate the start/stop parameters from SLG and promote the use of callbacks to achieve the same?

a universal start/stop param would be preferred, although doing it via callback is also ok as long as it's uniform across all guiders.
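
If we go the callback route, a minimal sketch of a reusable start/stop helper could look like this. It reuses the mutable guidance_scale pattern from the dynamic-schedule example earlier in the thread, assumes the pipeline exposes num_timesteps (as most diffusers pipelines do), and relies on guidance_scale=1.0 making the guider a no-op; the helper name is illustrative, not part of the PR.

# hypothetical helper: enable a guider only between two fractions of the run
def make_start_stop_callback(guider, start=0.0, stop=1.0, scale=5.0):
    def callback(pipe, i, t, kwargs):
        frac = i / pipe.num_timesteps
        # guidance_scale == 1.0 reduces num_conditions to 1, i.e. guidance off
        guider.guidance_scale = scale if start <= frac < stop else 1.0
        return kwargs
    return callback

# usage: CFG active only for the first 30% of steps
callback = make_start_stop_callback(cfg, start=0.0, stop=0.3, scale=5.0)
image = pipe(prompt, guidance=cfg, callback_on_step_end=callback).images[0]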
