
DPM ++ karras scheduler.add_noise fails when t = 0 #6069

Closed
garychan22 opened this issue Dec 6, 2023 · 0 comments · Fixed by #6085
Labels
bug Something isn't working

@garychan22

Describe the bug

When I call the add_noise function of the DPM++ Karras scheduler with t = 0, it fails:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/xxx/miniconda3/envs/controlnet/lib/python3.8/site-packages/diffusers/schedulers/scheduling_dpmsolver_multistep.py", line 889, in add_noise
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
  File "/home/xxx/miniconda3/envs/controlnet/lib/python3.8/site-packages/diffusers/schedulers/scheduling_dpmsolver_multistep.py", line 889, in <listcomp>
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
RuntimeError: a Tensor with 2 elements cannot be converted to Scalar
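
The failing line can be reproduced with plain PyTorch. add_noise looks up each requested timestep with (schedule_timesteps == t).nonzero().item(), and .item() only works when exactly one entry matches. The "2 elements" in the error suggests that timestep 0 occurs twice in the Karras schedule here. This is a minimal sketch that assumes a hypothetical schedule with a duplicated 0; it does not call diffusers.

```python
import torch

# Hypothetical schedule: the last two entries both map to timestep 0.
schedule_timesteps = torch.tensor([999, 750, 500, 250, 0, 0])
t = 0

# Same lookup pattern as in add_noise: indices where the schedule equals t.
matches = (schedule_timesteps == t).nonzero()
print(matches.shape)  # two matching indices, shape (2, 1)

# .item() requires a single-element tensor, so it raises here.
try:
    matches.item()
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```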

Reproduction

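import torch
from diffusers import DPMSolverMultistepScheduler

# Path to a local copy of the Stable Diffusion 2 base model
model_path = "models/models--stabilityai--stable-diffusion-2-base"
# DPM++ (multistep) scheduler with Karras sigmas, 60 inference steps
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler", use_karras_sigmas=True)
scheduler.set_timesteps(60)

latents = torch.randn([1, 4, 64, 64])
noise = torch.randn([1, 4, 64, 64])
# Raises RuntimeError: timestep 0 matches two entries of scheduler.timesteps
noisy = scheduler.add_noise(latents, noise, torch.tensor([0]))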
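
One possible local workaround is to resolve step indices with a lookup that tolerates duplicated timesteps. This is a minimal sketch, not the change made in #6085: index_for_timestep_tolerant is a hypothetical helper, and taking the first matching index is an assumption. When t is duplicated, each matching index may correspond to a different sigma, so which one to pick is a judgment call.

```python
import torch


def index_for_timestep_tolerant(schedule_timesteps: torch.Tensor, t) -> int:
    """Hypothetical helper: return the first index where schedule_timesteps == t.

    Unlike (schedule_timesteps == t).nonzero().item(), this does not fail when
    t occurs more than once. It raises ValueError if t is not in the schedule.
    """
    candidates = (schedule_timesteps == t).nonzero().flatten()
    if candidates.numel() == 0:
        raise ValueError(f"timestep {t} is not in the schedule")
    # Assumption: use the first occurrence when t is duplicated.
    return int(candidates[0])


# Hypothetical schedule with timestep 0 duplicated at the end.
schedule_timesteps = torch.tensor([999, 750, 500, 250, 0, 0])
step_indices = [index_for_timestep_tolerant(schedule_timesteps, t) for t in torch.tensor([0])]
print(step_indices)
```

The resulting indices would presumably be used to select entries of scheduler.sigmas, in the same way add_noise uses its own step_indices.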

Logs

No response

System Info


  • diffusers version: 0.23.1
  • Platform: Linux-5.15.0-86-generic-x86_64-with-glibc2.17
  • Python version: 3.8.18
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Huggingface_hub version: 0.17.1
  • Transformers version: 4.33.2
  • Accelerate version: 0.23.0
  • xFormers version: 0.0.21
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@yiyixuxu @DN6 @sayakpaul @patrickvonplaten
