
DEISMultistepScheduler missing get_velocity function

Open saunderez opened this issue 3 years ago • 1 comments

Describe the bug

scheduling_deis_multistep is almost a drop-in replacement for scheduling_ddpm, but it lacks the get_velocity function that scheduling_ddpm provides. In my project, the absence of this function prevents using models trained with v_prediction. The implementation from scheduling_ddpm (below) has been confirmed to work in scheduling_deis_multistep with no other changes required.

    def get_velocity(
        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity
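For context, the v-prediction target is just a per-element combination of the noise and the sample weighted by the square roots of the cumulative alpha product. A minimal pure-Python sketch of the same arithmetic (toy scalar values, not the diffusers API):

```python
import math

def get_velocity_scalar(sample: float, noise: float, alpha_cumprod: float) -> float:
    # Mirrors the tensor math above for a single scalar element:
    # v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x
    sqrt_alpha_prod = math.sqrt(alpha_cumprod)
    sqrt_one_minus_alpha_prod = math.sqrt(1.0 - alpha_cumprod)
    return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample

# Sanity checks: at alpha_cumprod = 1 the target is pure noise,
# at alpha_cumprod = 0 the target is the negated sample.
print(get_velocity_scalar(2.0, 0.5, 1.0))  # 0.5
print(get_velocity_scalar(2.0, 0.5, 0.0))  # -2.0
```

Since the formula depends only on `alphas_cumprod`, which both schedulers already compute identically, the DDPM implementation carries over unchanged.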

Reproduction

We use the following code to select the training target:

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        target = noise
If noise_scheduler is DDPM, everything works as intended. If noise_scheduler is DEIS, the code works for prediction type epsilon, but fails for v_prediction because get_velocity is not implemented in scheduling_deis_multistep.

Implementing the get_velocity function from scheduling_ddpm (quoted in full above) in scheduling_deis_multistep makes things work as intended for DEIS.

Logs

No response

System Info

  • diffusers version: 0.12.1
  • Platform: Windows-10-10.0.22000-SP0
  • Python version: 3.10.6
  • PyTorch version (GPU?): 1.13.1+cu116 (True)
  • Huggingface_hub version: 0.11.1
  • Transformers version: 0.15.0
  • Accelerate version: not installed
  • xFormers version: 0.0.14.dev
  • Using GPU in script?: NVIDIA RTX 2070 Super
  • Using distributed or parallel set-up in script?: No

saunderez avatar Feb 14 '23 13:02 saunderez

@yiyixuxu could you take a look here? :-)

patrickvonplaten avatar Feb 14 '23 22:02 patrickvonplaten

Discussion is moved to #2352 .

patrickvonplaten avatar Mar 16 '23 15:03 patrickvonplaten

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Apr 10 '23 15:04 github-actions[bot]