DEISMultistepScheduler missing get_velocity function
Describe the bug
scheduling_deis_multistep is almost a drop-in replacement for scheduling_ddpm, but it lacks the get_velocity function that scheduling_ddpm provides. In the project I am working on, the missing function prevents using models that use v_prediction. The implementation from scheduling_ddpm (below) has been confirmed to work in scheduling_deis_multistep with no other changes required.
def get_velocity(
    self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
    # Make sure alphas_cumprod and timestep have same device and dtype as sample
    self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    while len(sqrt_alpha_prod.shape) < len(sample.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

    velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
    return velocity
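For readers who prefer a formula, the last line of the function can be written as follows. The symbols are my own shorthand rather than names from the scheduler source: \bar{\alpha}_t stands for alphas_cumprod[timesteps], x_0 for sample (the clean latents), and \epsilon for noise.

```latex
v_t = \sqrt{\bar{\alpha}_t}\,\epsilon - \sqrt{1 - \bar{\alpha}_t}\,x_0
```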
Reproduction
We use the following code to select the loss target:
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "v_prediction":
    target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
    target = noise
If noise_scheduler is ddpm, everything works as intended. If noise_scheduler is deis, the code works as intended when the prediction type is epsilon, but with v_prediction it fails because get_velocity is not implemented in scheduling_deis_multistep.
Copying the get_velocity function from scheduling_ddpm into scheduling_deis_multistep makes things work as intended for deis. The function is the one shown under "Describe the bug" above.
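Until the scheduler has the method, one possible workaround that avoids patching the library is a standalone helper. This is only a sketch: it assumes the scheduler exposes an alphas_cumprod tensor, as the DDPM implementation quoted above does. The names velocity_target and DummyScheduler are hypothetical and not part of diffusers, and the beta schedule in DummyScheduler is arbitrary, there only so the snippet runs on its own.

```python
import torch


def velocity_target(scheduler, sample, noise, timesteps):
    """Return the v-prediction target, preferring the scheduler's own method."""
    if hasattr(scheduler, "get_velocity"):
        return scheduler.get_velocity(sample, noise, timesteps)

    # Fallback: same arithmetic as the DDPM implementation quoted above,
    # but without overwriting scheduler.alphas_cumprod.
    alphas_cumprod = scheduler.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    # Reshape (batch,) -> (batch, 1, 1, ...) so the factors broadcast over the sample.
    shape = (-1,) + (1,) * (sample.ndim - 1)
    sqrt_alpha_prod = (alphas_cumprod[timesteps] ** 0.5).reshape(shape)
    sqrt_one_minus_alpha_prod = ((1 - alphas_cumprod[timesteps]) ** 0.5).reshape(shape)
    return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample


class DummyScheduler:
    """Hypothetical stand-in exposing only alphas_cumprod, for a quick check."""

    def __init__(self, num_train_timesteps=1000):
        betas = torch.linspace(1e-4, 0.02, num_train_timesteps)  # arbitrary schedule
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)


scheduler = DummyScheduler()
latents = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(latents)
timesteps = torch.tensor([10, 500])
target = velocity_target(scheduler, latents, noise, timesteps)
```

In the snippet from the reproduction, the v_prediction branch would then call velocity_target(noise_scheduler, latents, noise, timesteps) instead of noise_scheduler.get_velocity(...).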
Logs
No response
System Info
- diffusers version: 0.12.1
- Platform: Windows-10-10.0.22000-SP0
- Python version: 3.10.6
- PyTorch version (GPU?): 1.13.1+cu116 (True)
- Huggingface_hub version: 0.11.1
- Transformers version: 0.15.0
- Accelerate version: not installed
- xFormers version: 0.0.14.dev
- Using GPU in script?: NVIDIA RTX 2070 Super
- Using distributed or parallel set-up in script?: No
@yiyixuxu could you take a look here? :-)
Discussion has moved to #2352.