First Block Cache #11180

Open · wants to merge 18 commits into main

Conversation

a-r-r-o-w (Member) commented Mar 31, 2025

FBC Reference: https://github.com/chengzeyi/ParaAttention

Minimal example

import torch
from diffusers import CogView4Pipeline
from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.3))

prompt = "A photo of an astronaut riding a horse on mars"
image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
image.save("output.png")

Benchmark scripts

Threshold vs. generation time (in seconds) for each model. In general, values below 0.2 work well for First Block Cache, depending on the model. Higher values lead to blurring and artifacting.

| Threshold | CogView4 | HunyuanVideo | LTX Video | Wan | Flux |
|-----------|----------|--------------|-----------|--------|-------|
| 0.00 | 40.51 | 121.85 | 33.64 | 222.17 | 16.47 |
| 0.03 | - | - | 27.14 | - | - |
| 0.05 | 24.08 | 62.47 | 21.73 | 139.26 | 12.63 |
| 0.10 | 17.55 | 41.84 | 15.10 | 89.99 | 9.27 |
| 0.20 | 12.99 | 28.11 | 10.29 | 57.01 | 5.91 |
| 0.40 | - | 18.93 | 7.25 | - | 3.70 |
| 0.50 | - | 16.65 | - | - | 3.13 |
CogView4
import argparse
import pathlib

import torch
from diffusers import CogView4Pipeline
from diffusers.hooks import FirstBlockCacheConfig
from diffusers.utils.logging import set_verbosity_debug


def main(args):
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    set_verbosity_debug()
    pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    prompt = "A photo of the eye of agomotto, a mystical object from the Marvel universe, with intricate details and a glowing effect, in a fantasy style. The eye sits at the center of a mystical landscape surrounded by green trees, vibrant flowers and orbs of light. The scene is illuminated by a soft, ethereal glow, creating a magical atmosphere. The eye itself is detailed with swirling patterns and a radiant light"
    negative_prompt = "bad anatomy, ugly, blurry, out of focus, low quality, worst quality, normal quality, jpeg artifacts, signature, watermark, username, artist name, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, artist name, text, error, missing fingers, extra digit, fewer digits, cropped"

    # Warmup
    pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=2, generator=torch.Generator().manual_seed(42)).images[0]

    # Benchmark
    for threshold in [0, 0.05, 0.1, 0.2, 0.4, 0.5]:
        print(f"Using threshold: {threshold}")
        filename = f"output_{threshold:.5f}.png"

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if threshold == 0:
            image = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(42)).images[0]
        else:
            pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=threshold))
            image = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(42)).images[0]
            pipe.transformer.disable_cache()
        end.record()
        torch.cuda.synchronize()
        
        elapsed_time = start.elapsed_time(end)
        print(f"Elapsed time: {elapsed_time / 1000:.2f}s")
        print(f"Output saved to {output_dir / filename}")
        
        image.save((output_dir / filename).as_posix())


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output_cogview4")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
HunyuanVideo
import argparse
import pathlib

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.hooks import FirstBlockCacheConfig
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug


def main(args):
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    set_verbosity_debug()
    model_id = "hunyuanvideo-community/HunyuanVideo"
    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
    pipe.vae.enable_tiling()
    pipe.to("cuda")

    prompt = "A cat walks on the grass, realistic"

    # Warmup
    pipe(prompt=prompt, height=320, width=512, num_frames=61, num_inference_steps=2, generator=torch.Generator().manual_seed(42)).frames[0]

    # Benchmark
    for threshold in [0, 0.05, 0.1, 0.2, 0.4, 0.5]:
        print(f"Using threshold: {threshold}")
        filename = f"output_{threshold:.5f}.mp4"

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if threshold == 0:
            video = pipe(prompt=prompt, height=320, width=512, num_frames=61, generator=torch.Generator().manual_seed(42)).frames[0]
        else:
            pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=threshold))
            video = pipe(prompt=prompt, height=320, width=512, num_frames=61, generator=torch.Generator().manual_seed(42)).frames[0]
            pipe.transformer.disable_cache()
        end.record()
        torch.cuda.synchronize()
        
        elapsed_time = start.elapsed_time(end)
        print(f"Elapsed time: {elapsed_time / 1000:.2f}s")
        print(f"Output saved to {output_dir / filename}")
        
        export_to_video(video, (output_dir / filename).as_posix(), fps=16)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output_hunyuanvideo")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
LTX Video
import argparse
import pathlib

import torch
from diffusers import LTXPipeline
from diffusers.hooks import FirstBlockCacheConfig
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug


def main(args):
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    set_verbosity_debug()
    pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
    negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

    # Warmup
    pipe(prompt=prompt, negative_prompt=negative_prompt, height=480, width=704, num_frames=161, num_inference_steps=2, generator=torch.Generator().manual_seed(42)).frames[0]

    # Benchmark
    for threshold in [0, 0.03, 0.05, 0.1, 0.2, 0.4]:
        print(f"Using threshold: {threshold}")
        filename = f"output_{threshold:.5f}.mp4"

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if threshold == 0:
            video = pipe(prompt=prompt, negative_prompt=negative_prompt, height=480, width=704, num_frames=161, generator=torch.Generator().manual_seed(42)).frames[0]
        else:
            pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=threshold))
            video = pipe(prompt=prompt, negative_prompt=negative_prompt, height=480, width=704, num_frames=161, generator=torch.Generator().manual_seed(42)).frames[0]
            pipe.transformer.disable_cache()
        end.record()
        torch.cuda.synchronize()
        
        elapsed_time = start.elapsed_time(end)
        print(f"Elapsed time: {elapsed_time / 1000:.2f}s")
        print(f"Output saved to {output_dir / filename}")
        
        export_to_video(video, (output_dir / filename).as_posix(), fps=24)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output_hunyuanvideo")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
Wan
import argparse
import pathlib

import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.hooks import FirstBlockCacheConfig
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug


def main(args):
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    set_verbosity_debug()
    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    # Warmup
    pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=2, generator=torch.Generator().manual_seed(42)).frames[0]

    # Benchmark
    for threshold in [0, 0.05, 0.1, 0.2]:
        print(f"Using threshold: {threshold}")
        filename = f"output_{threshold:.5f}.mp4"

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if threshold == 0:
            video = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(42)).frames[0]
        else:
            pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=threshold))
            video = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(42)).frames[0]
            pipe.transformer.disable_cache()
        end.record()
        torch.cuda.synchronize()
        
        elapsed_time = start.elapsed_time(end)
        print(f"Elapsed time: {elapsed_time / 1000:.2f}s")
        print(f"Output saved to {output_dir / filename}")
        
        export_to_video(video, (output_dir / filename).as_posix(), fps=16)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output_wan")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
Flux
import argparse
import pathlib

import torch
from diffusers import FluxPipeline
from diffusers.hooks import FirstBlockCacheConfig
from diffusers.utils.logging import set_verbosity_debug


def main(args):
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    set_verbosity_debug()
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    prompt = "A photo of the eye of agomotto, a mystical object from the Marvel universe, with intricate details and a glowing effect, in a fantasy style. The eye sits at the center of a mystical landscape surrounded by green trees, vibrant flowers and orbs of light. The scene is illuminated by a soft, ethereal glow, creating a magical atmosphere. The eye itself is detailed with swirling patterns and a radiant light"

    # Warmup
    pipe(prompt=prompt, num_inference_steps=2, generator=torch.Generator().manual_seed(42)).images[0]

    # Benchmark
    for threshold in [0, 0.05, 0.1, 0.2, 0.4, 0.5]:
        print(f"Using threshold: {threshold}")
        filename = f"output_{threshold:.5f}.png"

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if threshold == 0:
            image = pipe(prompt=prompt, generator=torch.Generator().manual_seed(42)).images[0]
        else:
            pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=threshold))
            image = pipe(prompt=prompt, generator=torch.Generator().manual_seed(42)).images[0]
            pipe.transformer.disable_cache()
        end.record()
        torch.cuda.synchronize()
        
        elapsed_time = start.elapsed_time(end)
        print(f"Elapsed time: {elapsed_time / 1000:.2f}s")
        print(f"Output saved to {output_dir / filename}")
        
        image.save((output_dir / filename).as_posix())


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output_flux")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)

Visual result comparison

CogView4: image grid comparing threshold=0.00, 0.05, 0.10, 0.20
HunyuanVideo: videos comparing threshold=0.00, 0.05, 0.10, 0.20, 0.40, 0.50
LTX Video: videos comparing threshold=0.00, 0.03, 0.05, 0.10, 0.20, 0.40
Wan: videos comparing threshold=0.00, 0.05, 0.10, 0.20
Flux: image grid comparing threshold=0.00, 0.05, 0.10, 0.20, 0.40, 0.50

Using with torch.compile

  • There is a forced graph break for the data-dependent control-flow branch, so that portion of the code always runs in eager mode (see the sketch below).
  • A few recompilations are triggered at a weird/unexpected location: the attention processor invocation. This only happens when using hooks, so I believe the current hook implementation makes torch.compile tracing add some unnecessary id/type guards. This will be tackled in the future since I haven't been able to make much progress on rewriting it.
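
For reference, a minimal sketch of combining FBC with torch.compile, given the constraints above. The model, threshold, and compile settings are just illustrative choices reused from the minimal example, not recommendations:

import torch
from diffusers import CogView4Pipeline
from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Apply the cache first, then compile. Because of the data-dependent branch mentioned
# above, full-graph compilation is not possible; that portion stays in eager mode.
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
pipe.transformer = torch.compile(pipe.transformer, fullgraph=False)

prompt = "A photo of an astronaut riding a horse on mars"
image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
image.save("output_compiled.png")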

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -79,10 +79,14 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
def forward(
a-r-r-o-w (Member Author) commented:

cc @yiyixuxu for reviewing the changes to the transformer here. The changes simplify the code required to make cache techniques work without adding even more if-else branching.

Ideally, if we stick to implementing models such that all blocks take in both hidden_states and encoder_hidden_states, and always return (hidden_states, encoder_hidden_states) from the block, a lot of design choices in the hook-based code can be simplified.

For now, I think these changes should be safe and come without any significant overhead to generation time (I haven't benchmarked though).
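
As a rough illustration of that convention (the block below is entirely hypothetical, not a class from the codebase):

from typing import Tuple

import torch
from torch import nn


class HypotheticalTransformerBlock(nn.Module):
    # Every block accepts both hidden_states and encoder_hidden_states and returns
    # both, even if one of them passes through unchanged. This keeps the hook-based
    # code free of per-model special cases.
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden_states = hidden_states + self.proj(self.norm(hidden_states))
        # encoder_hidden_states is returned unchanged here; a real block may update it too.
        return hidden_states, encoder_hidden_states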

@@ -0,0 +1,222 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
a-r-r-o-w (Member Author) commented:

cc @chengzeyi Would be super cool if you could give this PR a review since we're trying to integrate FBC to work with all supported models!

Currently, I've only done limited testing on a few models, but it should be easily extendable to all of them.

Contributor commented:

@a-r-r-o-w I see, let me take a look!

return cls._registry[model_class]


def _register_transformer_blocks_metadata():
a-r-r-o-w (Member Author) commented:

cc @DN6 For now, this PR only adds metadata for transformer blocks. The information here will be filled in over time for more models, to simplify the assumptions made in the hook-based code and make it cleaner to work with. More metadata will need to be maintained for transformer implementations to simplify cache methods like FasterCache, which I'll cover in future PRs.
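
Conceptually, the registry maps block classes to a small metadata record that the hooks can query instead of special-casing each model. A simplified, hypothetical version of the idea (not the actual registry API in this PR):

from dataclasses import dataclass
from typing import Dict, Type

from torch import nn


@dataclass
class HypotheticalBlockMetadata:
    # Which positional argument carries hidden_states, and whether the block also
    # consumes and returns encoder_hidden_states.
    hidden_states_arg_index: int
    uses_encoder_hidden_states: bool


_BLOCK_METADATA: Dict[Type[nn.Module], HypotheticalBlockMetadata] = {}


def register_block_metadata(block_cls: Type[nn.Module], metadata: HypotheticalBlockMetadata) -> None:
    _BLOCK_METADATA[block_cls] = metadata


def get_block_metadata(block: nn.Module) -> HypotheticalBlockMetadata:
    # Hooks look up the metadata by block type instead of hard-coding per-model behavior.
    return _BLOCK_METADATA[type(block)]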

)


class BaseMarkedState(BaseState):
a-r-r-o-w (Member Author) commented:

To explain simply, a "marked" state is a copy of the state object maintained per batch of data. In our pipelines, we do one of the following:

  • concatenate the unconditional and conditional batches and perform a single forward pass through the transformer
  • perform individual forward passes for the conditional and unconditional batches

The state variables must track values specific to each batch of data over all inference steps; otherwise, the state variable for the conditional batch might end up being used for the unconditional batch, or vice versa.
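
A stripped-down, hypothetical sketch of the idea (not the actual BaseMarkedState implementation):

class HypotheticalMarkedState:
    # Keeps one copy of the cache state per batch label (e.g. "cond"/"uncond"), so
    # values tracked for the conditional batch never leak into the unconditional
    # batch across inference steps.
    def __init__(self, state_factory):
        self._state_factory = state_factory
        self._states = {}
        self._current = None

    def mark(self, label: str) -> None:
        # Select (and lazily create) the state for this batch of data.
        if label not in self._states:
            self._states[label] = self._state_factory()
        self._current = label

    @property
    def state(self):
        return self._states[self._current]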

@@ -917,6 +917,7 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)

cc.mark_state("cond")
a-r-r-o-w (Member Author) commented:

Doing it this way helps us distinguish different batches of data. I believe this design will fit well with the upcoming "guiders" for supporting multiple guidance methods. Since guidance methods may use 1 (guidance-distilled), 2 (CFG), 3 (PAG), or more latent batches, we can call mark_state any number of times to distinguish between calls to transformer.forward with different batches of data within the same inference step.
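
For a pipeline that runs separate conditional and unconditional passes, the usage inside the denoising loop would look roughly like the sketch below. The function and its signature are hypothetical; only mark_state and the "cond"/"uncond" labels come from this PR:

def hypothetical_cfg_step(cc, transformer, latents, prompt_embeds, negative_prompt_embeds, timestep, guidance_scale):
    # Mark the state before each forward pass so cached values tracked for the
    # conditional batch are never reused for the unconditional batch.
    cc.mark_state("cond")
    noise_pred_cond = transformer(
        hidden_states=latents, encoder_hidden_states=prompt_embeds, timestep=timestep, return_dict=False
    )[0]

    cc.mark_state("uncond")
    noise_pred_uncond = transformer(
        hidden_states=latents, encoder_hidden_states=negative_prompt_embeds, timestep=timestep, return_dict=False
    )[0]

    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)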

Comment on lines +2719 to +2721
# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
# of the box once there is better cache support/implementation
class FirstBlockCacheTesterMixin:
a-r-r-o-w (Member Author) commented:

Adding a new cache tester mixin each time and extending the list of parent mixins for each pipeline test is probably not going to be a clean way of testing. We can refactor this in the future and consolidate all cache methods into a single tester once they are better supported/implemented for most models (a rough sketch of the direction is below).
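
One possible direction, purely as a sketch (the helper names below are assumptions; only enable_cache/disable_cache exist in this PR):

class HypotheticalCacheTesterMixin:
    # A single mixin that loops over every supported cache config instead of one
    # mixin (and one extra parent class) per cache technique.
    cache_configs = []  # filled in by the concrete pipeline test class

    def test_cache_inference_runs(self):
        pipe = self.get_pipeline()          # assumed helper on the concrete test class
        inputs = self.get_dummy_inputs()    # assumed helper on the concrete test class
        for config in self.cache_configs:
            pipe.transformer.enable_cache(config)
            _ = pipe(**inputs)
            pipe.transformer.disable_cache()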

a-r-r-o-w requested review from DN6 and yiyixuxu on April 1, 2025 at 23:37


# fmt: off
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
Collaborator commented:

Any specific reason to use this naming convention here? The function is just meant to return combinations of hidden_states/encoder_hidden_states, right?

a-r-r-o-w (Member Author) commented on Apr 14, 2025:

Not really. It just spells out what argument index hidden_states is at and what it returns. Do you have any particular recommendation?
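
For concreteness, this is how one might read the name; the body below is a guess at the behavior the name encodes, not necessarily the exact implementation:

# "___hidden_states_0"      -> hidden_states is positional argument 0
# "___ret___hidden_states"  -> when the block is skipped, return hidden_states unchanged
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    return hidden_states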

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
a-r-r-o-w mentioned this pull request on Apr 14, 2025
a-r-r-o-w (Member Author) commented:

@DN6 Addressed the review comments. Could you give it another look?
