Philippe Hansen-Estruch
Hey, I've implemented a JAX version that I can release once I've finished my project :)!
I'm having this issue as well when using it in diffusion models.
you're a king
Well, "soon" is a relative term, but I just released it here! It includes an implementation of MAE and my own method: https://github.com/philippe-eecs/small-vision/tree/main
The paper is also missing a citation for DiffAE, which uses a similar method to condition the model on a representation and trains a second diffusion model to sample that representation. Then...
Yes, I get this error:
```
RuntimeError: function xFuserRingFlashAttnFuncBackward returned an incorrect number of gradients (expected 17, got 13)
```
I solved this problem by downgrading flash-attn to 2.6.2.
Would TP be needed for the 14B-parameter model? I wanted to fine-tune it on a new VAE I was writing up. Happy to test this for you, but wanted to...