Flax support for Stable Diffusion 2
Thanks @pcuenca ♥, it's working 🎉
I converted the model to Flax and put it at https://huggingface.co/flax/stable-diffusion-2 in case anyone needs it.
Super cool!
Should we maybe add one slow test for SD-2?
I added a couple of integration tests to check that the output of the Flax UNet is close to the output of the PyTorch UNet. On my hardware they pass with a tolerance of 1e-2, even though the PyTorch reference values were generated in float16 and the Flax ones in bfloat16. This suggests that the UNet implementation is correct.
The tests may still fail on the actual testing hardware (V100 and TPU), though. I'll adjust them if that happens.
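For context on why a loose tolerance is needed when one side runs in float16 and the other in bfloat16, here is a minimal sketch of that kind of check. It is not the code from these tests. It assumes the outputs have already been flattened into plain Python lists, and the helpers `to_float16`, `to_bfloat16`, `max_abs_diff` and `outputs_match` are hypothetical names used only for illustration. The "reference" values are made up, not real UNet activations. It uses only the standard library (the `struct` `"e"` half-precision format needs Python 3.6+) and emulates bfloat16 by rounding a float32 to its top 16 bits.

```python
import struct


def to_float16(x: float) -> float:
    # Round to the nearest IEEE half-precision value (struct format "e").
    # Raises OverflowError if |x| is beyond the float16 range (~65504).
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x: float) -> float:
    # Emulate bfloat16: round to float32 first, then keep only the top 16 bits
    # (sign, 8 exponent bits, 7 mantissa bits) with round-half-to-even.
    # NaN and infinity are not handled in this sketch.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def max_abs_diff(a, b):
    # Largest element-wise absolute difference between two equal-length outputs.
    if len(a) != len(b):
        raise ValueError("outputs must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


def outputs_match(a, b, atol=1e-2):
    # True when every element of a is within atol of the matching element of b.
    return max_abs_diff(a, b) <= atol


# Made-up values standing in for a slice of the UNet output.
reference = [0.1234567, -0.9876543, 1.9, -0.5]
torch_like = [to_float16(v) for v in reference]  # stand-in for a float16 PyTorch run
flax_like = [to_bfloat16(v) for v in reference]  # stand-in for a bfloat16 Flax run

print(max_abs_diff(torch_like, flax_like))
print(outputs_match(torch_like, flax_like))             # True
print(outputs_match(torch_like, flax_like, atol=1e-3))  # False
```

For values of magnitude below 2, bfloat16 alone can be off by up to 2^-8 (about 0.0039), so rounding noise can exceed a 1e-3 tolerance while still fitting comfortably within 1e-2, even if both implementations are correct.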
Merging, then, so we can deploy the backend. Thanks!