First Block Cache #11180
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…ithout too much model-specific intrusion code)
@@ -79,10 +79,14 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
def forward(
cc @yiyixuxu for reviewing changes to the transformer here. The changes were made to simplify some of the code required to make cache techniques work somewhat more easily without even more if-else branching.
Ideally, if we stick to implementing models such that all blocks take in both hidden_states and encoder_hidden_states, and always return (hidden_states, encoder_hidden_states) from the block, a lot of design choices in the hook-based code can be simplified.
For now, I think these changes should be safe and come without any significant overhead to generation time (I haven't benchmarked though).
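As a rough, hypothetical sketch of that convention (not the actual diffusers block code), every block would accept both tensors and always return the pair, so hook-based cache code never needs per-model if/else branching:

```python
from typing import Tuple

import torch
import torch.nn as nn


class ExampleTransformerBlock(nn.Module):
    """Illustrative block following the 'take both, return both' convention."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # ... attention / feed-forward would go here ...
        hidden_states = hidden_states + self.proj(self.norm(hidden_states))
        # Always return the pair, even if encoder_hidden_states is unchanged,
        # so hooks can treat every block uniformly.
        return hidden_states, encoder_hidden_states
```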
@@ -0,0 +1,222 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
cc @chengzeyi Would be super cool if you could give this PR a review since we're trying to integrate FBC to work with all supported models!
Currently, I've only done limited testing on a few models, but it should be easily extendable to all.
@a-r-r-o-w I see, let me take a look!
return cls._registry[model_class]

def _register_transformer_blocks_metadata():
cc @DN6 For now, this PR only adds metadata for transformer blocks. The information here will be filled up over time for more models to simplify the assumptions made in hook-based code and make it cleaner to work with. More metadata needs to be maintained for transformer implementations to simplify cache methods like FasterCache, which I'll cover in future PRs
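As a rough illustration of the idea (the names TransformerBlockMetadata / TransformerBlockRegistry and the specific fields are assumptions, not necessarily what this PR uses), such a registry could map each block class to metadata describing where hidden_states and encoder_hidden_states appear in its return value:

```python
# Hypothetical sketch of a per-block-class metadata registry; names and fields
# are illustrative only.
from dataclasses import dataclass
from typing import Dict, Optional, Type

import torch.nn as nn


@dataclass
class TransformerBlockMetadata:
    # Index of hidden_states in the block's return tuple.
    return_hidden_states_index: int = 0
    # Index of encoder_hidden_states in the return tuple, or None if not returned.
    return_encoder_hidden_states_index: Optional[int] = 1


class TransformerBlockRegistry:
    _registry: Dict[Type[nn.Module], TransformerBlockMetadata] = {}

    @classmethod
    def register(cls, model_class: Type[nn.Module], metadata: TransformerBlockMetadata) -> None:
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type[nn.Module]) -> TransformerBlockMetadata:
        # Hook code looks up how to unpack a block's outputs from here.
        return cls._registry[model_class]
```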
src/diffusers/hooks/hooks.py (Outdated)
)

class BaseMarkedState(BaseState):
To explain simply, a "marked" state is a copy of a state object for a particular batch of data. In our pipelines, we do one of the following:
- concatenate the unconditional and conditional batches and perform a single forward pass through the transformer, or
- perform individual forward passes for the conditional and unconditional batches.
The state variables must track values specific to each batch of data over all inference steps; otherwise you might end up in a situation where the state variable for the conditional batch is used for the unconditional batch, or vice versa.
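Since the PR's implementation isn't reproduced here, a minimal sketch of the idea, with illustrative class and method names, might look like:

```python
# Hypothetical sketch of a "marked" state: one state copy per mark
# (e.g. "cond" / "uncond") so cached values never leak across batches.
class BaseState:
    def reset(self) -> None:
        pass


class MarkedState(BaseState):
    def __init__(self):
        self._states = {}              # one state dict per mark
        self._current_mark = "default"

    def mark_state(self, name: str) -> None:
        # Subsequent get/set calls operate on the state for `name`.
        self._current_mark = name

    def _state(self) -> dict:
        return self._states.setdefault(self._current_mark, {})

    def set(self, key, value) -> None:
        self._state()[key] = value

    def get(self, key, default=None):
        return self._state().get(key, default)
```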
@@ -917,6 +917,7 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)

cc.mark_state("cond")
Doing it this way helps us distinguish different batches of data. I believe this design will fit well with the upcoming "guiders" to support multiple guidance methods. As guidance methods may use 1 (guidance-distilled), 2 (CFG), 3 (PAG), or more latent batches, we can call "mark state" any number of times to distinguish between calls to transformer.forward with different batches of data for the same inference step.
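For illustration, here is a self-contained sketch of that call pattern (all names here are stand-ins, not diffusers APIs): the pipeline marks the state before each transformer call within a step, so each guidance branch keeps its own cached values.

```python
# Stand-in sketch of per-branch state isolation; not diffusers code.
_current_mark = "default"
_per_mark_state = {}  # one cache dict per mark ("cond", "uncond", ...)


def mark_state(name):
    global _current_mark
    _current_mark = name


def state():
    return _per_mark_state.setdefault(_current_mark, {})


def fake_transformer(latents, embeds):
    # Stand-in for transformer.forward: reuse this branch's cached value if present.
    out = state().get("first_block_output")
    if out is None:
        out = latents + embeds
        state()["first_block_output"] = out
    return out


for step in range(3):                      # denoising loop
    mark_state("cond")                     # conditional branch
    cond_out = fake_transformer(1.0, 2.0)
    mark_state("uncond")                   # unconditional branch
    uncond_out = fake_transformer(1.0, 0.0)
    assert cond_out != uncond_out          # branches never share cached state
```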
# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
# of the box once there is better cache support/implementation
class FirstBlockCacheTesterMixin:
Adding new cache tester mixins each time and extending the list of parent mixins for each pipeline test is probably not going to be a clean way of testing. We can refactor this in the future and consolidate all cache methods into a single tester once they are better supported/implemented for most models
# fmt: off
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
Any specific reason to use this naming convention here? The function is just meant to return combinations of hidden_states/encoder_hidden_states, right?
Not really. It just spells out what argument index hidden_states is at and what it returns. Do you have any particular recommendation?
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
@DN6 Addressed the review comments. Could you give it another look?
FBC Reference: https://github.com/chengzeyi/ParaAttention
Minimal example
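The collapsed details are not reproduced here; below is a hedged sketch of what enabling FBC on Flux might look like, assuming the apply_first_block_cache / FirstBlockCacheConfig API this PR introduces (exact import paths and argument names may differ):

```python
# Hedged sketch of a minimal FBC usage example; API names are assumed from this PR.
import torch
from diffusers import FluxPipeline
from diffusers.hooks import FirstBlockCacheConfig, apply_first_block_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Lower thresholds are more conservative; values below ~0.2 generally preserve
# quality (see the threshold vs. generation time notes below).
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

image = pipe("A cat holding a sign that says hello world", num_inference_steps=28).images[0]
image.save("output.png")
```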
Benchmark scripts
Threshold vs. generation time (in seconds) for each model. In general, values below 0.2 work well for First Block Cache, depending on the model. Higher values lead to blurring and artifacting.
CogView4
HunyuanVideo
LTX Video
Wan
Flux
Visual result comparison
CogView4
Hunyuan Video
output_0.00000.mp4
output_0.05000.mp4
output_0.10000.mp4
output_0.20000.mp4
output_0.40000.mp4
output_0.50000.mp4
LTX Video
output_0.00000.mp4
output_0.03000.mp4
output_0.05000.mp4
output_0.10000.mp4
output_0.20000.mp4
output_0.40000.mp4
Wan
output_0.00000.mp4
output_0.05000.mp4
output_0.10000.mp4
output_0.20000.mp4
Flux
Using with torch.compile
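The collapsed details are not reproduced here; below is a hedged sketch of how FBC might be combined with torch.compile, applying the cache before compiling the transformer (the exact flags and interaction are assumptions, not necessarily what the PR documents):

```python
# Hedged sketch: FBC + torch.compile; API names are assumed from this PR.
import torch
from diffusers import FluxPipeline
from diffusers.hooks import FirstBlockCacheConfig, apply_first_block_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

# The cache branches between "compute" and "skip" paths at runtime, so graph
# breaks or recompilations may occur; fullgraph compilation may need to be relaxed.
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=False)

image = pipe("A photo of a dog", num_inference_steps=28).images[0]
image.save("output_compiled.png")
```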