Progressive distillation
Starting to work on implementing *Progressive Distillation for Fast Sampling of Diffusion Models* (Salimans & Ho, 2022).
This is very much a draft PR. So far I've included a toy example in a notebook.
First, it trains an unconditional image diffusion model on a single image. Then, it implements the distillation training procedure: training a student model to reproduce the teacher's output in N // 2 sampling steps.
TODOs include:
- [ ] Verify the equation block in the distillation function is correct
- [ ] Test on multiple iterations of distillation (i.e., N = 1000 -> 500 -> 250...; see the loop sketch below)
- [ ] Test on a larger image dataset
- [ ] Clean up changes in DDIM scheduler
- [ ] Convert distillation process to a pipeline(?) or script
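On the multi-iteration TODO: each phase distills the teacher into a student that samples with half as many steps, and that student then becomes the teacher for the next phase. Here is a minimal sketch of that outer loop, where `distill_one_phase` is a hypothetical stand-in for the inner training procedure shown later in the thread (the name and signature are illustrative, not the notebook's actual API):

```python
import copy

def progressive_distillation(teacher, num_steps, num_phases, distill_one_phase):
    """Repeatedly halve the sampling budget, e.g. 1000 -> 500 -> 250 -> ...

    `distill_one_phase(teacher, student, num_steps)` is assumed to train the
    student (Algorithm 2 of the paper) until convergence for that phase.
    """
    for _ in range(num_phases):
        # The student starts as a copy of the current teacher.
        student = copy.deepcopy(teacher)
        distill_one_phase(teacher, student, num_steps)
        # The distilled student becomes the teacher for the next phase,
        # which samples with half as many steps.
        teacher, num_steps = student, num_steps // 2
    return teacher, num_steps
```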
This is a great start. I noticed a small bug and created a Colab to play with the sampling notebook.
The core code from Algorithm 2 should be updated like this (the calculation of t' and t'' was swapped):
```python
# get alpha and sigma for the base step t (teacher grid)
alpha_t, sigma_t = teacher_scheduler.get_alpha_sigma(batch, timesteps + 1, accelerator.device)
# noise the clean batch up to z_t
z_t = alpha_t * batch + sigma_t * noise
# get alpha and sigma for the two half-steps: t' on the teacher grid, t'' on the student grid
alpha_t_prime, sigma_t_prime = teacher_scheduler.get_alpha_sigma(batch, timesteps, accelerator.device)
alpha_t_prime2, sigma_t_prime2 = student_scheduler.get_alpha_sigma(batch, timesteps // 2, accelerator.device)
# teacher's x-prediction at t (v-parameterization: x = alpha * z - sigma * v)
v = teacher(z_t.float(), timesteps + 1).sample
rec_t = (alpha_t * z_t - sigma_t * v).clip(-1, 1)
# first teacher DDIM step: t -> t'
z_t_prime = alpha_t_prime * rec_t + (sigma_t_prime / sigma_t) * (z_t - alpha_t * rec_t)
v_1 = teacher(z_t_prime.float(), timesteps).sample
rec_t_prime = (alpha_t_prime * z_t_prime - sigma_t_prime * v_1).clip(-1, 1)
# second teacher DDIM step: t' -> t''
z_t_prime_2 = alpha_t_prime2 * rec_t_prime + (sigma_t_prime2 / sigma_t_prime) * (z_t_prime - alpha_t_prime * rec_t_prime)
# teacher target for the student (Algorithm 2 of the paper)
x_hat = (z_t_prime_2 - (sigma_t_prime2 / sigma_t) * z_t) / (alpha_t_prime2 - (sigma_t_prime2 / sigma_t) * alpha_t)
```
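For reference, the last line is the teacher target from Algorithm 2 of the paper. The two teacher half-steps are

$$z_{t'} = \alpha_{t'}\,\tilde{x} + \frac{\sigma_{t'}}{\sigma_t}\left(z_t - \alpha_t\,\tilde{x}\right), \qquad z_{t''} = \alpha_{t''}\,\tilde{x}' + \frac{\sigma_{t''}}{\sigma_{t'}}\left(z_{t'} - \alpha_{t'}\,\tilde{x}'\right),$$

where $\tilde{x}$ and $\tilde{x}'$ are `rec_t` and `rec_t_prime` in the code, and the target the student regresses onto is

$$\tilde{x}_{\text{target}} = \frac{z_{t''} - (\sigma_{t''}/\sigma_t)\,z_t}{\alpha_{t''} - (\sigma_{t''}/\sigma_t)\,\alpha_t},$$

i.e., the ratio in the target is $\sigma_{t''}/\sigma_t$ (not $\sigma_{t''}/\sigma_{t'}$), and the whole numerator is divided by the denominator.

The snippet also leans on a `get_alpha_sigma` helper. Here is a plausible sketch of what it computes, assuming the scheduler exposes an `alphas_cumprod` table the way diffusers schedulers do (the actual helper in this PR's DDIM changes may differ):

```python
def get_alpha_sigma(scheduler, sample, timesteps, device):
    # Assumed behavior: alpha_t = sqrt(alphas_cumprod[t]) and
    # sigma_t = sqrt(1 - alphas_cumprod[t]), gathered per batch element.
    alphas_cumprod = scheduler.alphas_cumprod.to(device)[timesteps]
    alpha = alphas_cumprod.sqrt()
    sigma = (1.0 - alphas_cumprod).sqrt()
    # Reshape (batch,) so the coefficients broadcast against (batch, C, H, W).
    while alpha.ndim < sample.ndim:
        alpha = alpha.unsqueeze(-1)
        sigma = sigma.unsqueeze(-1)
    return alpha, sigma
```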
I was thinking of implementing this too :) I briefly went over the code, and I'm wondering whether you really want to be training (distilling) inside a Pipeline? From what I understood, pipelines are meant for sampling?
@lukovnikov sorry for the delayed response. I think that's an important meta decision to make with the Diffusers team! I've seen a similar approach taken, e.g., in the Imagic PR. For now I'm still working on getting the distillation results as good as the paper's!
@bglick13 I've also implemented distillation in my fork: https://github.com/lukovnikov/diffusers/blob/mine/examples/unconditional_image_generation/distill_unconditional.py However, I'm running into a strange issue where the contrast of the generated images increases after every distillation phase. Did you see anything similar, and did you manage to reproduce the paper's results?
@lukovnikov - at a guess, the CLIP text encoder tokens are also being trained and are becoming 'longer' (attention focuses more on longer vectors, which correspond to concepts the text encoder knows well; this is why popular celebrities can get an overbaked appearance in default models), which causes the images to be overbaked.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@patrickvonplaten do you think this is worth giving another try?