[Flax] Stateless schedulers, fixes and refactors
Hi, here are some fixes and improvements for the Flax schedulers. Let me know what you think! Sorry that it's a huge PR with a single commit :sweat_smile:
Refactor schedulers to be completely stateless (all the state is in the params)
- No more state in the scheduler class
- No more implicit transfers
- Extracted common state (common state can also be reused from other schedulers)
- The shapes and dtypes of the state returned by `set_timestamp` are now final and won't be changed by `step` (this reduces the number of jit cache misses when jitting the scheduler separately)
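To illustrate the stateless pattern described above, here is a minimal sketch, not the code from this PR. `SchedulerState`, `create_state` and `step` are hypothetical names, and the linear beta schedule is an assumption. All state lives in a pytree that is passed in and returned, and `step` never changes its shapes or dtypes, so a separately jitted `step` does not need to retrace.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SchedulerState(NamedTuple):
    # Hypothetical state container; NamedTuples are JAX pytrees.
    timesteps: jnp.ndarray       # shape (num_inference_steps,), int32
    alphas_cumprod: jnp.ndarray  # shape (num_train_timesteps,), float32
    counter: jnp.ndarray         # scalar int32


def create_state(num_train_timesteps: int = 1000,
                 num_inference_steps: int = 50) -> SchedulerState:
    # Assumed linear beta schedule, for illustration only.
    betas = jnp.linspace(1e-4, 0.02, num_train_timesteps, dtype=jnp.float32)
    alphas_cumprod = jnp.cumprod(1.0 - betas)
    step_ratio = num_train_timesteps // num_inference_steps
    timesteps = (jnp.arange(num_inference_steps) * step_ratio)[::-1]
    return SchedulerState(timesteps, alphas_cumprod, jnp.int32(0))


@jax.jit
def step(state: SchedulerState, model_output, t, sample):
    # Pure function: nothing is read from or written to an object.
    alpha_t = state.alphas_cumprod[t]
    pred_original = (sample - jnp.sqrt(1.0 - alpha_t) * model_output) / jnp.sqrt(alpha_t)
    # The new state has the same structure, shapes and dtypes as the old one.
    new_state = state._replace(counter=state.counter + 1)
    return new_state, pred_original


state = create_state()
sample = jnp.ones((1, 4), dtype=jnp.float32)
new_state, out = step(state, jnp.zeros_like(sample), 980, sample)
```

Because the state is just data, it can be created once, passed through `jax.jit` boundaries, and shared between schedulers that need the same arrays.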
Added dtype param to schedulers
Leave it at fp32 though unless you want to lose all details in the image.
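A small sketch of why a low-precision dtype is lossy here. It assumes a schedule whose first beta is around `1e-4` (a common choice, but an assumption for this example). The corresponding alpha of `0.9999` cannot be represented in bf16 and rounds to exactly `1.0`, so that step's noise level is lost.

```python
import jax.numpy as jnp

# alpha = 1 - beta for an assumed first beta of 1e-4.
alpha_fp32 = jnp.asarray(1.0 - 1e-4, dtype=jnp.float32)

# bf16 keeps only 8 significant bits, so values this close to 1 collapse to 1.
alpha_bf16 = alpha_fp32.astype(jnp.bfloat16)

print(float(alpha_fp32))  # slightly below 1.0
print(float(alpha_bf16))  # exactly 1.0
```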
Fix copy-paste error in `add_noise` function
- The `add_noise` function and thus img2img were not working in `DDIM` and `DPMSolverMultistep`
- Extracted common logic so it can't happen again
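The idea of sharing the noise logic between schedulers can be sketched roughly as below. This is a minimal example using the standard forward-diffusion formula, not necessarily the exact helper from this PR; the function name and argument order are hypothetical.

```python
import jax.numpy as jnp


def add_noise_shared(alphas_cumprod, original_samples, noise, timesteps):
    # Hypothetical shared helper: every scheduler with an `alphas_cumprod`
    # table could delegate to it instead of carrying its own copy.
    alpha_prod = alphas_cumprod[timesteps]
    # Reshape per-sample scalars so they broadcast over the trailing dims.
    shape = (-1,) + (1,) * (original_samples.ndim - 1)
    sqrt_alpha_prod = jnp.sqrt(alpha_prod).reshape(shape)
    sqrt_one_minus_alpha_prod = jnp.sqrt(1.0 - alpha_prod).reshape(shape)
    return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise


# Tiny usage example with a made-up two-entry table.
table = jnp.array([1.0, 0.25])
samples = jnp.ones((2, 3))
noise = 2.0 * jnp.ones((2, 3))
noisy = add_noise_shared(table, samples, noise, jnp.array([0, 1]))
```

With a single implementation, a fix or a bug applies to all schedulers at once rather than to one copy.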
Removed all jax conditionals to fix performance bottleneck
- Using `jax.lax.cond` and `jax.lax.switch` causes the CPU to have to wait for the `pred` (even when jitted), causing the GPU pipeline to stall (not enough kernels scheduled). More info here.
- If you noticed that `PNDM` and `DPMSolverMultistep` were slower than `DDIM`, this was the reason.
- Usually this is most noticeable on a fast GPU + slow CPU combo or when running a split kernel (separately jitted scheduler instead of the megakernel as in this repo).
- Evaluating all branches instead has no noticeable performance impact.
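As a rough sketch of what "evaluating all branches" can look like: both candidate updates are computed unconditionally and the result is picked with `jnp.where`, so no `jax.lax.cond` or `jax.lax.switch` is involved. The two update functions and their coefficients are made up for illustration and are not taken from any scheduler in this PR.

```python
import jax
import jax.numpy as jnp


def first_order_update(sample, model_output):
    # Hypothetical update used on the first step.
    return sample - 0.1 * model_output


def second_order_update(sample, model_output, prev_output):
    # Hypothetical multistep update used afterwards.
    return sample - 0.1 * (1.5 * model_output - 0.5 * prev_output)


@jax.jit
def step_select(counter, sample, model_output, prev_output):
    # Compute both branches, then select element-wise.
    first = first_order_update(sample, model_output)
    second = second_order_update(sample, model_output, prev_output)
    return jnp.where(counter == 0, first, second)


sample = jnp.ones((2,))
model_output = 2.0 * jnp.ones((2,))
prev_output = jnp.zeros((2,))
out_first = step_select(jnp.int32(0), sample, model_output, prev_output)
out_second = step_select(jnp.int32(1), sample, model_output, prev_output)
```

The trade-off is some redundant arithmetic per step, which is cheap for scheduler-sized updates compared to the model's forward pass.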
Fixed small bugs and improvements
- Added `v_prediction` where it was missing
- Made `DDPM` jitable. Though I'm not sure if it works correctly.
- Fixed `DPMSolverMultistep` not being able to start in the middle of a schedule. This caused img2img not to work.
- Made `LMSDiscrete` run. It's not jitable and I always get back a black image though.
- Probably some other stuff that I forgot about
~Validation~ (outdated)
I messed up, so PyTorch is fp16 and Flax is bf16
| name | PyTorch | Flax v0.10.2 | Flax this PR |
|---|---|---|---|
| DDIM | (image) | (image) | (image) |
| DPMSolverMultistep | (image) | (image) | (image) |
| PNDM | (image) | (image) | (image) |
Hi @skirsten, this looks amazing! I see you are tweaking stuff, let me know when you want a review :)
Hi @pcuenca, It should be ready for review now :sweat_smile:
Awesome, will do this week!
Maybe @patil-suraj wants to take a quick look too.
Cool, let's merge as this is a clear improvement to what we had previously. More than happy to fix schedulers one by one in the future.








