
[Flax] Stateless schedulers, fixes and refactors

Open skirsten opened this issue 3 years ago • 4 comments

Hi, here are some fixes and improvements for the Flax schedulers. Let me know what you think! Sorry that it's a huge PR with a single commit :sweat_smile:

Refactor schedulers to be completely stateless (all the state is in the params)

  • No more state in the scheduler class
  • No more implicit transfers
  • Extracted common state (common state can also be reused from other schedulers)
  • The shapes and dtypes of the state returned by set_timesteps are now final and won't be changed by step (this reduces JIT cache misses when jitting the scheduler separately)
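The stateless design above can be sketched roughly like this (the `SchedulerState` container, the trivial `step` body, and the free-function form are illustrative, not the PR's actual code):

```python
from typing import NamedTuple

import jax.numpy as jnp


class SchedulerState(NamedTuple):
    """All mutable scheduler state, as an immutable pytree."""
    timesteps: jnp.ndarray   # fixed shape/dtype after set_timesteps()
    step_index: jnp.ndarray  # scalar int32, updated functionally by step()


def set_timesteps(num_steps: int) -> SchedulerState:
    # Shapes and dtypes are final here; step() never changes them,
    # so a separately jitted caller does not retrace between steps.
    return SchedulerState(
        timesteps=jnp.arange(num_steps, dtype=jnp.int32)[::-1],
        step_index=jnp.array(0, dtype=jnp.int32),
    )


def step(state: SchedulerState, sample: jnp.ndarray):
    # Purely functional update: return a new state, never mutate.
    new_sample = sample  # (the real denoising math is elided here)
    return new_sample, state._replace(step_index=state.step_index + 1)
```

Because the state is an ordinary pytree, it can be passed through `jax.jit`, replicated with `pmap`, or shared between schedulers without any hidden class attributes.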

Added dtype param to schedulers

Leave it at fp32, though, unless you want to lose all detail in the image.
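One way to see why fp32 matters here: schedulers carry long running products such as `alphas_cumprod`, and these degrade quickly in half precision. A small illustration (the schedule values below are made up, not taken from the PR):

```python
import jax.numpy as jnp

# A made-up beta schedule: 1000 alphas slightly below 1, as in diffusion
# noise schedules. The running product shrinks toward zero, and fp16
# loses bits at every multiply (and eventually hits subnormal range).
alphas = jnp.linspace(0.999, 0.98, 1000)

cumprod_fp32 = jnp.cumprod(alphas.astype(jnp.float32))
cumprod_fp16 = jnp.cumprod(alphas.astype(jnp.float16)).astype(jnp.float32)

# Relative error of the fp16 accumulation vs. the fp32 reference.
rel_err = jnp.abs(cumprod_fp32 - cumprod_fp16) / cumprod_fp32
```

The relative error at the end of the schedule is orders of magnitude larger than at the start, which is exactly where the scheduler needs the values most.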

Fix copy paste error in add_noise function

  • The add_noise function and thus img2img were not working in DDIM and DPMSolverMultistep
  • Extracted common logic so it can't happen again
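For reference, the shared formula that such a common helper implements is the standard forward-noising process x_t = sqrt(ᾱ_t)·x_0 + sqrt(1−ᾱ_t)·ε. A rough sketch (the helper name and broadcasting details are illustrative, not the PR's exact code):

```python
import jax.numpy as jnp


def add_noise_common(alphas_cumprod, original_samples, noise, timesteps):
    # q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    sqrt_alpha = jnp.sqrt(alphas_cumprod[timesteps])
    sqrt_one_minus = jnp.sqrt(1.0 - alphas_cumprod[timesteps])

    # Broadcast the per-timestep scalars over the sample dimensions
    # (batch, height, width, channels, ...).
    extra_dims = (1,) * (original_samples.ndim - 1)
    sqrt_alpha = sqrt_alpha.reshape(-1, *extra_dims)
    sqrt_one_minus = sqrt_one_minus.reshape(-1, *extra_dims)

    return sqrt_alpha * original_samples + sqrt_one_minus * noise
```

With one shared implementation, a copy-paste slip in an individual scheduler (as happened with DDIM and DPMSolverMultistep) can no longer silently break img2img.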

Removed all jax conditionals to fix performance bottleneck

  • Using jax.lax.cond and jax.lax.switch forces the CPU to wait for the predicate (even when jitted), which stalls the GPU pipeline (not enough kernels scheduled). More info here.
  • If you noticed that PNDM and DPMSolverMultistep were slower than DDIM, this was the reason.
  • Usually this is most noticeable with a fast GPU + slow CPU combination, or when running a split kernel (a separately jitted scheduler instead of the megakernel used in this repo).
  • Evaluating all branches instead has no noticeable performance impact.
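The branchless pattern can be illustrated with a toy example (not the PR's code): instead of dispatching through `jax.lax.cond`, compute both branches and select the result on-device with `jnp.where`, so no host-side predicate read is needed:

```python
import jax
import jax.numpy as jnp


def step_with_cond(is_first_step, sample):
    # Conditional dispatch: only one branch runs, but the host has to
    # know the predicate before it can schedule the next kernels.
    return jax.lax.cond(
        is_first_step,
        lambda s: s * 2.0,  # hypothetical first-step branch
        lambda s: s + 1.0,  # hypothetical normal branch
        sample,
    )


def step_branchless(is_first_step, sample):
    # Both branches are evaluated; jnp.where selects per the predicate.
    # Everything stays queued on the accelerator.
    return jnp.where(is_first_step, sample * 2.0, sample + 1.0)
```

For the small branch bodies in a scheduler step, the wasted FLOPs from evaluating both branches are negligible compared to avoiding the pipeline stall.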

Small bug fixes and improvements

  • Added v_prediction where it was missing
  • Made DDPM jittable, though I'm not sure it works correctly.
  • Fixed DPMSolverMultistep not being able to start in the middle of a schedule. This caused img2img not to work.
  • Made LMSDiscrete run. It's not jittable, and I always get back a black image, though.
  • Probably some other stuff that I forgot about
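For context, the v-prediction parameterization (Salimans & Ho, 2022) defines v = α_t·ε − σ_t·x_0 with α_t = sqrt(ᾱ_t) and σ_t = sqrt(1 − ᾱ_t), so x_0 and ε can be recovered from the model output. A sketch of that conversion (the function name is illustrative, not the PR's code):

```python
import jax.numpy as jnp


def v_pred_to_x0_eps(sample, v, alpha_prod_t):
    # alpha_t = sqrt(alpha_bar_t), sigma_t = sqrt(1 - alpha_bar_t)
    alpha_t = jnp.sqrt(alpha_prod_t)
    sigma_t = jnp.sqrt(1.0 - alpha_prod_t)
    # Invert v = alpha_t * eps - sigma_t * x0, given x_t = alpha_t * x0 + sigma_t * eps:
    pred_x0 = alpha_t * sample - sigma_t * v
    pred_eps = alpha_t * v + sigma_t * sample
    return pred_x0, pred_eps
```

This is the conversion a scheduler needs wherever it previously assumed the model always predicts ε.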

~Validation~ (outdated)

I messed up, so the PyTorch images are fp16 and the Flax images are bf16

| name | PyTorch | Flax v0.10.2 | Flax this PR |
| --- | --- | --- | --- |
| DDIM | (image: torch_ddim_0.10.2) | (image: flax_ddim_0.10.2) | (image: flax_ddim_0.11.0.dev0) |
| DPMSolverMultistep | (image: torch_dpmsolver_multistep_0.10.2) | (image: flax_dpmsolver_multistep_0.10.2) | (image: flax_dpmsolver_multistep_0.11.0.dev0) |
| PNDM | (image: torch_pndm_0.10.2) | (image: flax_pndm_0.10.2) | (image: flax_pndm_0.11.0.dev0) |

skirsten avatar Dec 11 '22 21:12 skirsten

The documentation is not available anymore as the PR was closed or merged.

Hi @skirsten, this looks amazing! I see you are tweaking stuff, let me know when you want a review :)

pcuenca avatar Dec 13 '22 06:12 pcuenca

Hi @pcuenca, it should be ready for review now :sweat_smile:

skirsten avatar Dec 13 '22 21:12 skirsten

Awesome, will do this week!

pcuenca avatar Dec 13 '22 21:12 pcuenca

Maybe @patil-suraj wants to take a quick look too.

pcuenca avatar Dec 19 '22 19:12 pcuenca

Cool, let's merge, as this is a clear improvement over what we had previously. More than happy to fix the schedulers one-by-one in the future.

patrickvonplaten avatar Dec 20 '22 00:12 patrickvonplaten