
Modular Diffusers Guiders #11311


Open · wants to merge 18 commits into base: modular-refactor

Conversation

@a-r-r-o-w (Member) commented Apr 14, 2025

The following methods are currently supported: Classifier-Free Guidance (CFG), Classifier-Free Zero* Guidance (CFGZ), CFG++, Adaptive Projected Guidance (APG), Skip Layer Guidance (SLG), Smoothed Energy Guidance (SEG), Tangential Classifier-Free Guidance (TCFG), and AutoGuidance.

Note: PAG is implemented as Skip Layer Guidance and does not have its own guider implementation. The equivalent SLG initialization is:

from diffusers import SkipLayerGuidance, LayerSkipConfig

config = LayerSkipConfig(indices=[2, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_ff=False, skip_attention_scores=True)
slg = SkipLayerGuidance(guidance_scale=5.0, skip_layer_guidance_scale=2.5, skip_layer_config=config)

Note: STG is also implemented as Skip Layer Guidance:

  • STG-r: skip_attention=False, skip_ff=True, skip_attention_scores=False
  • STG-a: skip_attention=True, skip_ff=False, skip_attention_scores=False
  • STG-t: skip_attention=True, skip_ff=True, skip_attention_scores=False
  • STG-v: skip_attention=False, skip_ff=False, skip_attention_scores=True (essentially PAG)

Note: You can use different SLG configs for different parts of the model. Create multiple configs and pass them as a list to skip_layer_config, as in the sketch below.
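A sketch (the indices/fqn values are illustrative, following the SLG initialization above):

from diffusers import SkipLayerGuidance, LayerSkipConfig

configs = [
    LayerSkipConfig(indices=[2, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_ff=False, skip_attention_scores=True),
    LayerSkipConfig(indices=[0, 1], fqn="down_blocks.1.attentions.1.transformer_blocks", skip_attention=True, skip_ff=False, skip_attention_scores=False),
]
slg = SkipLayerGuidance(guidance_scale=5.0, skip_layer_guidance_scale=2.5, skip_layer_config=configs)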

Result image grid: APG · CFG · CFGZ · SLG (Skip Attention Scores, Skip FF) · SLG (Skip Attention Scores) · SLG (Skip Attention, Skip FF) · SLG (Skip Attention) · SLG (Skip FF)
Minimal all guiders testing script
from pathlib import Path

import torch
import torch.nn.functional as F
from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.guiders import AdaptiveProjectedGuidance, AutoGuidance, CFGPlusPlusGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance
from diffusers.hooks import LayerSkipConfig, SmoothedEnergyGuidanceConfig

output_dir = "dump_modular_diffusers"
Path(output_dir).mkdir(parents=True, exist_ok=True)

components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
components.enable_auto_cpu_offload(device="cuda:0")

pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
pipe.update_states(**components.components)
pipe.to("cuda")

prompt = "A majestic lion jumping from a big stone at night"
height = 1024
width = 1024



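# Classifier-Free Guidance (CFG)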
cfg = ClassifierFreeGuidance(guidance_scale=10.0, guidance_rescale=0.0, use_original_formulation=False, start=0.0, stop=1.0)
pipe.update_states(guider=cfg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-cfg.png")



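# Skip Layer Guidance: skip attention (STG-a-style config)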
config = LayerSkipConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=True, skip_ff=False, skip_attention_scores=False)
slg = SkipLayerGuidance(guidance_scale=7.5, skip_layer_guidance_scale=2.5, skip_layer_config=config)
pipe.update_states(guider=slg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-slg---skip_attention.png")



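# Skip Layer Guidance: skip attention + skip FF (STG-t-style config)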
config = LayerSkipConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=True, skip_ff=True, skip_attention_scores=False)
slg = SkipLayerGuidance(guidance_scale=7.5, skip_layer_guidance_scale=2.5, skip_layer_config=config)
pipe.update_states(guider=slg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-slg---skip_attention---skip_ff.png")



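# Skip Layer Guidance: skip FF (STG-r-style config)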
config = LayerSkipConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_ff=True, skip_attention_scores=False)
slg = SkipLayerGuidance(guidance_scale=7.5, skip_layer_guidance_scale=2.5, skip_layer_config=config)
pipe.update_states(guider=slg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-slg---skip_ff.png")



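# Skip Layer Guidance: skip attention scores (STG-v / PAG-style config)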
config = LayerSkipConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_ff=False, skip_attention_scores=True)
slg = SkipLayerGuidance(guidance_scale=7.5, skip_layer_guidance_scale=2.5, skip_layer_config=config)
pipe.update_states(guider=slg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-slg---skip_attention_scores.png")



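# Skip Layer Guidance: skip attention scores + skip FF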
config = LayerSkipConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_ff=True, skip_attention_scores=True)
slg = SkipLayerGuidance(guidance_scale=7.5, skip_layer_guidance_scale=2.5, skip_layer_config=config)
pipe.update_states(guider=slg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-slg---skip_attention_scores---skip_ff.png")



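# Adaptive Projected Guidance (APG)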
apg = AdaptiveProjectedGuidance(guidance_scale=12.0, adaptive_projected_guidance_momentum=-0.5, adaptive_projected_guidance_rescale=10.0)
pipe.update_states(guider=apg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-apg.png")



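# Classifier-Free Zero* Guidance (CFGZ)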
# Should not set zero_init_steps > 0 for non-flow-matching schedulers
cfgz = ClassifierFreeZeroStarGuidance(guidance_scale=10.0, zero_init_steps=0)
pipe.update_states(guider=cfgz)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-cfgz.png")



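# AutoGuidance with different LayerSkipConfigs for different parts of the model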
configs = []
configs.append(LayerSkipConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=True, skip_ff=True, skip_attention_scores=False, dropout=0.1))
configs.append(LayerSkipConfig(indices=[0, 1], fqn="down_blocks.1.attentions.1.transformer_blocks", skip_attention=True, skip_ff=False, skip_attention_scores=False, dropout=0.05))
ag = AutoGuidance(guidance_scale=10.0, auto_guidance_config=configs, use_original_formulation=False, start=0.0, stop=1.0)
pipe.update_states(guider=ag)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-ag.png")



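# Smoothed Energy Guidance (SEG)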
config = SmoothedEnergyGuidanceConfig(indices=[2, 3, 8, 9], fqn="mid_block.attentions.0.transformer_blocks")
seg = SmoothedEnergyGuidance(guidance_scale=7.5, seg_guidance_scale=2.5, seg_blur_sigma=9999999.0, seg_guidance_config=config)
pipe.update_states(guider=seg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-seg.png")


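# Tangential Classifier-Free Guidance (TCFG)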
tcfg = TangentialClassifierFreeGuidance(guidance_scale=10.0, guidance_rescale=0.0, use_original_formulation=False, start=0.00, stop=1.0)
pipe.update_states(guider=tcfg)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-tcfg.png")



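# CFG++ (the assert below checks for EulerDiscreteScheduler)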
cfgpp = CFGPlusPlusGuidance(guidance_scale=0.9, guidance_rescale=0.0, use_original_formulation=False, start=0.0, stop=1.0)
assert pipe.scheduler.__class__.__name__ == "EulerDiscreteScheduler"
pipe.update_states(guider=cfgpp)
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30, generator=torch.Generator().manual_seed(42))
images = output.intermediates.get("images").images
images[0].save(f"{output_dir}/output-cfgpp.png")
YiYi's modified full test script
import os
import torch
import numpy as np
import cv2
from PIL import Image

from diffusers import (
    ControlNetModel,
    ModularPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
    AdaptiveProjectedGuidance,
    ClassifierFreeGuidance,
    SkipLayerGuidance,
    LayerSkipConfig,
)
from diffusers.utils import load_image
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import StableDiffusionXLAutoPipeline, StableDiffusionXLIPAdapterStep

from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

from controlnet_aux import LineartAnimeDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:0"
dtype = torch.float16
num_images_per_prompt = 1

test_pag = True
test_lora = False


# define output folder
out_folder = "dump_modular_diffusers"
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_mem(mem_size, name):
    mem_gb = mem_size / 1024**3
    mem_mb = mem_size / 1024**2
    print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for the configured device.

    Args:
        message (str, optional): Label appended to the printed header line.
    """
    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")

# function to make canny image (for controlnet)
def make_canny(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)


# (1) Define inputs
# for text2img/img2img
prompt = "a bear sitting in a chair drinking a milkshake"
negative_prompt = "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality"

# for img2img
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
init_image = load_image(url).convert("RGB")
strength = 0.9 

# for controlnet
control_image = make_canny(init_image)
controlnet_conditioning_scale = 0.5  # recommended for good generalization
# for controlnet_union
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_union_image = processor(init_image, output_type="pil")

# for inpainting
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

inpaint_image = load_image(img_url).resize((1024, 1024))
inpaint_mask = load_image(mask_url).resize((1024, 1024))
inpaint_control_image = make_canny(inpaint_image)
inpaint_strength = 0.99

# for ip adapter
ip_adapter_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")


# (2) define blocks and nodes (builder)

auto_pipeline_block = StableDiffusionXLAutoPipeline()
auto_pipeline = ModularPipeline.from_block(auto_pipeline_block)
refiner_pipeline = ModularPipeline.from_block(auto_pipeline_block)



# (3) add states to nodes
repo = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
controlnet_repo = "diffusers/controlnet-canny-sdxl-1.0"
inpaint_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
vae_fix_repo = "madebyollin/sdxl-vae-fp16-fix"
controlnet_union_repo = "brad-twinkl/controlnet-union-sdxl-1.0-promax"
ip_adapter_repo = "h94/IP-Adapter"


components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)


controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
components.add("controlnet", controlnet)

image_encoder = CLIPVisionModelWithProjection.from_pretrained(ip_adapter_repo, subfolder="sdxl_models/image_encoder", torch_dtype=dtype)
feature_extractor = CLIPImageProcessor(size=224, crop_size=224)

components.add("image_encoder", image_encoder)
components.add("feature_extractor", feature_extractor)


# load components/config into nodes
auto_pipeline.update_states(**components.components)


# load other components to swap in later
refiner_unet = UNet2DConditionModel.from_pretrained(refiner_repo, subfolder="unet", torch_dtype=dtype)
inpaint_unet = UNet2DConditionModel.from_pretrained(inpaint_repo, subfolder="unet", torch_dtype=dtype)
vae_fix = AutoencoderKL.from_pretrained(vae_fix_repo, torch_dtype=dtype)
controlnet_union = ControlNetUnionModel.from_pretrained(controlnet_union_repo, torch_dtype=dtype)

components.add("refiner_unet", refiner_unet)
components.add("inpaint_unet", inpaint_unet)
components.add("controlnet_union", controlnet_union)
components.add("vae_fix", vae_fix)


# you can add guiders to the manager too, but there's no need since they are not serialized
pag_guider = SkipLayerGuidance(
    guidance_scale=5.0,
    skip_layer_guidance_scale=3.0,
    skip_layer_config=LayerSkipConfig(
        indices=[2, 3, 7, 8],
        fqn="mid_block.attentions.0.transformer_blocks",
        skip_attention=False,
        skip_ff=False,
        skip_attention_scores=True,
    ),
    start=0.0,
    stop=1.0,
)
cfg_guider = ClassifierFreeGuidance(guidance_scale=5.0)


# (4) enable auto cpu offload: automatically offload models when available gpu memory goes below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()



# using auto_pipeline to generate images

# to get info about auto_pipeline and how to use it: inputs/outputs/components
# this is an "auto" workflow that works for all use cases: text2img, img2img, inpainting, controlnet, etc.
print(f" ")
print(f" auto_pipeline:")
print(auto_pipeline)
print(" ")


# since we want to use text2img use case, we can run the following to see components/blocks/inputs for this use case
print(f" ")
print(f" auto_pipeline info (default use case: text2img)")
print(auto_pipeline.get_execution_blocks())
print(" ")

# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what the sdxl_node told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test1_out_text2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test1_out_text2img.png")

clear_memory()


# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
generator = torch.Generator(device="cuda").manual_seed(0)
auto_pipeline.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test2_out_text2img_lora_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test2_out_text2img_lora.png")

# test3: text2image with pag
print(f" ")
print(f" running test3: text2image with pag")
if not test_lora:
    auto_pipeline.unload_lora_weights()

auto_pipeline.update_states(guider=pag_guider)
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test3_out_text2img_pag_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# check out the components if you want; the models used are moved to device, some might get offloaded to cpu
# print(components)


# test4: SDXL(text2img) with ip_adapter + pag
print(f" ")
print(f" running test4: SDXL(text2img) with ip_adapter")

auto_pipeline.load_ip_adapter(ip_adapter_repo, subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
auto_pipeline.set_ip_adapter_scale(0.6)

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    ip_adapter_image=ip_adapter_image,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test4_out_text2img_ip_adapter_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test4_out_text2img_ip_adapter.png")

auto_pipeline.unload_ip_adapter()
clear_memory()

# test5: SDXL(text2img) with controlnet

# we are going to pass a new input now `control_image` so the workflow will be automatically converted to controlnet use case
# let's checkout the info for controlnet use case
print(f" auto_pipeline info (controlnet use case)")
print(auto_pipeline.get_execution_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test5: SDXL(text2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test5_out_text2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test5_out_text2img_control.png")

clear_memory()

# test6: SDXL(img2img)

print(f" ")
print(f" running test6: SDXL(img2img)")

generator = torch.Generator(device="cuda").manual_seed(0)

# let's checkout the sdxl_node info for img2img use case
print(f" auto_pipeline info (img2img use case)")
print(auto_pipeline.get_execution_blocks("image"))
print(" ")

images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test6_out_img2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test6_out_img2img.png")

clear_memory()


# test7: SDXL(img2img) with controlnet
# let's checkout the sdxl_node info for img2img controlnet use case
print(f" sdxl_node info (img2img controlnet use case)")
print(auto_pipeline.get_execution_blocks("image", "control_image"))
print(" ")

print(f" ")
print(f" running test7: SDXL(img2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    generator=generator, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test7_out_img2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test7_out_img2img_control.png")

clear_memory()

# test8: img2img with refiner

refiner_pipeline.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)
# let's checkout the refiner_node
print(f" refiner_pipeline info")
print(refiner_pipeline)
print(f" ")

print(f" refiner_pipeline: triggered by `image_latents`")
print(refiner_pipeline.get_execution_blocks("image_latents"))
print(" ")

print(f" running test8: img2img with refiner")


generator = torch.Generator(device="cuda").manual_seed(0)
latents = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)
images_output = refiner_pipeline(
    image_latents=latents,  
    prompt=prompt, 
    denoising_start=0.8, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test8_out_img2img_refiner_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test8_out_img2img_refiner.png")

clear_memory()

# test9: SDXL(inpainting)
# let's checkout the sdxl_node info for inpainting use case
print(f" auto_pipeline info (inpainting use case)")
print(auto_pipeline.get_execution_blocks("mask_image", "image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting)")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test9_out_inpainting_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test9_out_inpainting.png")

clear_memory()

# test10: SDXL(inpainting) with controlnet
# let's checkout the sdxl_node info for inpainting + controlnet use case
print(f" auto_pipeline info (inpainting + controlnet use case)")
print(auto_pipeline.get_execution_blocks("mask_image", "control_image"))
print(" ")

print(f" ") 
print(f" running test10: SDXL(inpainting) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    control_image=control_image, 
    image=init_image,
    height=1024,
    width=1024,
    mask_image=inpaint_mask,
    num_images_per_prompt=num_images_per_prompt,
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test10_out_inpainting_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test10_out_inpainting_control.png")

clear_memory()

# test11: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet")

auto_pipeline.update_states(unet=components.get("inpaint_unet"))
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test11_out_inpainting_inpaint_unet_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test11_out_inpainting_inpaint_unet.png")

clear_memory()


# test12: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test12: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

generator = torch.Generator(device="cuda").manual_seed(0)

images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    padding_mask_crop=33, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()

# test13: apg

print(f" ")
print(f" running test13: apg")

apg_guider = AdaptiveProjectedGuidance(guidance_scale=15.0, adaptive_projected_guidance_momentum=-0.3, adaptive_projected_guidance_rescale=12.0, start=0.01)
auto_pipeline.update_states(guider=apg_guider, unet=components.get("unet"))


generator = torch.Generator().manual_seed(0)
images_output = auto_pipeline(
  prompt=prompt, 
  generator=generator,
  num_inference_steps=20,
  num_images_per_prompt=1, # yiyi: apg does not work with num_images_per_prompt > 1
  guidance_scale=15,
  height=896,
  width=768,
  output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test13_out_apg_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test13_out_apg.png")

clear_memory()


# test14: SDXL(text2img) with controlnet_union

auto_pipeline.update_states(controlnet=components.get("controlnet_union"), unet=components.get("unet"), vae=components.get("vae_fix"), guider=pag_guider)
# we are going to pass a new input now `control_mode` so the workflow will be automatically converted to controlnet use case
# let's checkout the info for controlnet use case
print(f" auto_pipeline info (controlnet union use case)")
print(auto_pipeline.get_execution_blocks("control_mode"))
print(" ")

print(f" ")
print(f" running test14: SDXL(text2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)

images_output = auto_pipeline(
    prompt=prompt, 
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test14_out_text2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test14_out_text2img_control_union.png")

clear_memory()


# test15: SDXL(img2img) with controlnet_union

print(f" ")
print(f" auto_pipeline info (img2img controlnet union use case)")
print(auto_pipeline.get_execution_blocks("image", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(img2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    generator=generator, 
    control_mode=[3], 
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt, 
    height=1024, 
    width=1024, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test15_out_img2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test15_out_img2img_control_union.png")

clear_memory()

# test16: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" auto_pipeline info (inpainting controlnet union use case)")
print(auto_pipeline.get_execution_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test16: SDXL(inpainting) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    mask_image=inpaint_mask, 
    control_image=controlnet_union_image,
    control_mode=[3],
    height=1024, 
    width=1024, 
    generator=generator, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test16_out_inpainting_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test16_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w a-r-r-o-w mentioned this pull request Apr 14, 2025
@a-r-r-o-w a-r-r-o-w marked this pull request as ready for review April 14, 2025 15:04
@a-r-r-o-w a-r-r-o-w requested a review from yiyixuxu April 14, 2025 15:04
@a-r-r-o-w (Member Author)

@vladmandic's suggestion about having a universal start/stop parameter from here is now implemented too (a small sketch of the static usage follows below). Note, however, that the guiders should already support any kind of dynamic schedule, with multiple enabling/disabling per inference, if the user modifies the properties on the guider object (see this comment for example).
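A minimal sketch of the static usage (the 0.8 value is illustrative; this assumes, as in the test scripts above, that start/stop are fractions of the denoising schedule during which the guider is active):

cfg = ClassifierFreeGuidance(guidance_scale=7.5, start=0.0, stop=0.8)  # guide only the first 80% of steps
pipe.update_states(guider=cfg)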

Batched inference is still supported too (in terms of multiple prompts and setting num_images_per_prompt > 1). It's just that it is not supported by batching the conditional and unconditional branches together. This can be handled lazily eventually, but I'm prioritizing getting the methods working first before doing anything too complex/time-consuming. We need to design this in a way that caching methods, and potentially other techniques we couldn't support before, can be made compatible easily.

@@ -0,0 +1,271 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
Member Author

For more context on why we need this, see #10875 and this comment.

I discussed with Dhruv and for now we should keep it. After one of the FBC or Guider PRs is merged to main, I can do the refactor and make use of decorators. This saves me the burden of implementing the same thing in both PRs and maintaining it until one gets merged, but rest assured I'll do the refactor before the next release.


def _register_attention_processors_metadata():
    # AttnProcessor2_0
    AttentionProcessorRegistry.register(
Member Author

For now, only this and BasicTransformerBlock are relevant, since Modular Diffusers only supports SDXL. The rest is copied over, but we keep it to avoid merge conflicts, since the FirstBlockCache PR will most likely be merged before Modular Diffusers.

return noise_cfg


def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
@yiyixuxu (Collaborator) commented Apr 14, 2025

I think it is easier to work with if we:

  1. provide a default method here in guider_utils.py to deal with a list of inputs like you specified here: each element could be a tensor or a tuple/list of tensors - this logic should be mostly the same for different guiders, no?
  2. let each specific guider class define how to prepare each input element

basically, the method here becomes something like this - would this make sense?

def prepare_inputs(self, denoiser: torch.nn.Module, num_conditions: int, *args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> Tuple[List[torch.Tensor], ...]:
    """
    Prepares the inputs for the denoiser by processing each argument individually using a helper method.
    """
    list_of_inputs = []

    for arg in args:
        if isinstance(arg, (tuple, list)):
            if len(arg) != 2:
                raise ValueError("...")
        elif not isinstance(arg, torch.Tensor):
            raise ValueError("...")
        processed_input = self.prepare_input_single(arg, num_conditions)

        list_of_inputs.append(processed_input)

    return tuple(list_of_inputs)
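(Purely for illustration - a hypothetical prepare_input_single that the sketch above could pair with; this helper is an assumption, not something defined in the PR:)

def prepare_input_single(self, arg, num_conditions):
    # Hypothetical behavior: reuse a lone tensor for every condition branch,
    # and pass a (cond, uncond) pair through unchanged.
    if isinstance(arg, torch.Tensor):
        return [arg] * num_conditions
    return list(arg)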

Member Author

Sounds good, I'll update the implementations

added_cond_kwargs=data.added_cond_kwargs,
return_dict=False,
)[0]
data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)
@yiyixuxu (Collaborator) commented Apr 14, 2025

is it possible to do something like this?

noise_pred_outputs = []
for batch_index, (...) in enumerate(zip(...)):
    latents_i = ...
    noise_pred = pipeline.unet(...)
    noise_pred_outputs = self.guider.prepare_and_add_output(pipeline.unet, noise_pred, noise_pred_outputs)

Member Author

Hey, there were a few more changes related to guiders. Basically, they also need information like sigmas (see the explanation above for CFG++ if we go forward with implementing it that way), latent height/width (for methods like SEG/SAG), tensor formats (SAG), extra prompt information (for methods like Attend-and-Excite), and probably more.

I haven't added SAG and A&E because it would be complicated to review with all the required changes. Since we are aiming for modularity, though, the design should still allow for such use cases. I'm not quite sure how to proceed yet, but please take another look and LMK what you think.

Comment on lines 88 to 93
if self._is_cfgpp_enabled():
    # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later!
    pred_cond = self._preds["pred_cond"]
    pred_uncond = self._preds["pred_uncond"]
    diff = pred_uncond - pred_cond
    pred = pred + diff * self.guidance_scale * self._sigma_next
@a-r-r-o-w (Member Author) commented Apr 16, 2025

The original repository implements CFG++ in a different way. I wanted to try and make it work without really modifying all our schedulers, and so it's done this way. The math works out the same.

For context, in our schedulers, we do:

new_sample = sample + model_output_after_cfg * (sigmas[i + 1] - sigmas[i])
new_sample = sample - model_output_after_cfg * sigmas[i] + model_output_after_cfg * sigmas[i + 1]

What we need to do for CFG++ is this instead:

new_sample = sample - model_output_after_cfg * sigmas[i] + model_output_uncond * sigmas[i + 1]

(This is only for EulerDiscreteScheduler and will differ for other schedulers)
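Subtracting the two updates gives the correction applied after the step in the snippet above (a sketch of the algebra, assuming the usual combination model_output_after_cfg = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)):

correction = new_sample_cfgpp - new_sample
           = (model_output_uncond - model_output_after_cfg) * sigmas[i + 1]
           = guidance_scale * (model_output_uncond - model_output_cond) * sigmas[i + 1]

which is the diff * guidance_scale * sigma_next term in the code.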

After a little bit of working it out on paper, I found that some schedulers don't really have to be modified if we add and subtract some terms after the scheduler step. We will need some specialized code (it can either exist in this file or the scheduler file) to add/subtract the right terms for each scheduler, so LMK how you think we should do it.

Nevermind, it's better to just do this in the scheduler


@a-r-r-o-w (Member Author)

Also cc @DN6 for all the custom hook implementations

@a-r-r-o-w a-r-r-o-w requested a review from DN6 April 16, 2025 12:41
@@ -668,7 +675,38 @@ def step(
dt = self.sigmas[self.step_index + 1] - sigma_hat

prev_sample = sample + derivative * dt

if _use_cfgpp:
Collaborator

So we are hoping to find a scalable solution that can provide maximum support for community creativity. It isn't scalable if it requires code changes in the schedulers.

I think it can be manipulated inside the guider, no? Since we have all the variables in the pipeline state and all the components in the model states, you can use those to access the scheduler and the sigmas counter.

Member Author

For Euler, yes, it is easy to add a correction term outside the scheduler step and make it work -- this is how it was originally implemented.

For DDIM, DPM++, and all the others, it quickly gets very complicated to handle all the correction terms correctly, since you need to recalculate a lot of variables for the original model_output, subtract them out, calculate the correct variables using model_pred_uncond, and add that in. I don't think that having specialized code in the guider to handle all usable schedulers, probably using isinstance checks, is a good approach.
