Various Fixes for Flax Dreambooth
- Correctly update the progress bar every epoch
- Allow specifying a pretrained VAE
- Allow specifying a revision to pretrained models
- Cache compiled models between invocations (speeds up TPU execution a lot!)
- Save intermediate checkpoints by specifying save_steps
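As a rough illustration of how the new options might fit together, here is a minimal stdlib-only sketch. Only `--revision`, `--save_steps` and `--num_train_epochs`-style options are implied by the PR; the `--pretrained_vae_name_or_path` flag name, the `train` helper and its `steps_per_epoch`/`save_checkpoint` parameters are hypothetical stand-ins, and the real jitted train step and tqdm progress bar are replaced by counters.

```python
import argparse


def parse_args(argv=None):
    # Assumed flag names for illustration; the actual script may differ.
    parser = argparse.ArgumentParser(description="Sketch of the new Dreambooth options")
    parser.add_argument("--pretrained_model_name_or_path", required=True)
    parser.add_argument(
        "--pretrained_vae_name_or_path",  # hypothetical name for the separate-VAE option
        default=None,
        help="Optional separate checkpoint to load the VAE from.",
    )
    parser.add_argument(
        "--revision",
        default=None,
        help="Branch, tag or commit of the pretrained models to load.",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=None,
        help="Save an intermediate checkpoint every N training steps.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    return parser.parse_args(argv)


def train(args, steps_per_epoch, save_checkpoint):
    """Run a dummy training loop; return how many epoch progress updates were made."""
    global_step = 0
    epochs_reported = 0
    for _epoch in range(args.num_train_epochs):
        for _ in range(steps_per_epoch):
            global_step += 1  # the real (p)jitted train step would run here
            # Only save when save_steps is set and the step is a multiple of it.
            if args.save_steps and global_step % args.save_steps == 0:
                save_checkpoint(global_step)
        # Stands in for advancing the epoch progress bar once per completed epoch.
        epochs_reported += 1
    return epochs_reported
```

With a placeholder model id, three epochs of five steps and `--save_steps 4`, checkpoints are requested at steps 4, 8 and 12, and the epoch counter advances three times:

```python
args = parse_args(["--pretrained_model_name_or_path", "some/model",
                   "--save_steps", "4", "--num_train_epochs", "3"])
saved = []
epochs = train(args, steps_per_epoch=5, save_checkpoint=saved.append)
```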
cc @patrickvonplaten @yiyixuxu @patil-suraj
The fixes look pretty good to me.
Asking @pcuenca to take a look :)
I was just thinking about adding the revision argument (the original script does not cast the weights to bf16 when we pass --mixed_precision=bf16)
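A minimal sketch of the idea, assuming the precision choices are `no`/`fp16`/`bf16` and that the loader accepts `dtype` and `revision` keyword arguments (as Flax `from_pretrained` methods in diffusers do). `pretrained_kwargs` and `DTYPE_NAMES` are hypothetical names. Note that in diffusers' Flax models `dtype` controls the computation dtype; casting the parameters themselves is a separate step (for example via the model's `to_bf16` helper).

```python
# Hypothetical mapping from --mixed_precision values to dtype names. A real JAX
# script would use the corresponding jax.numpy dtypes (e.g. jnp.bfloat16).
DTYPE_NAMES = {"no": "float32", "fp16": "float16", "bf16": "bfloat16"}


def pretrained_kwargs(mixed_precision="no", revision=None):
    """Build keyword arguments for a from_pretrained-style loader (sketch)."""
    if mixed_precision not in DTYPE_NAMES:
        raise ValueError(f"unsupported mixed_precision: {mixed_precision!r}")
    kwargs = {"dtype": DTYPE_NAMES[mixed_precision]}
    # Only pass a revision when one was requested, so the default branch is used otherwise.
    if revision is not None:
        kwargs["revision"] = revision
    return kwargs
```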
I'm not actually sure why the tests are failing. I can't figure out if it's sporadic or real. Any thoughts, @pcuenca?
@patil-suraj should be good now!
Thanks a lot, @yasyf!