VAE training sample script
I believe the current lack of easy access to VAE training is stopping diffusion models from disrupting even more industries.
I'm talking about consistent details on things that are less represented in the original training data. A 64x64 latent can only carry so much detail. Very often I get a good result in latent space (judging by the low-res intermediate image) before the final image is ruined by bad details. No amount of prompting, fine-tuning or ControlNet could solve this issue. I tried, and I know lots of other people tried, and most of them are trying without realising that the problem cannot be solved unless the thing that produces the final details (the VAE decoder) can be trained on their domain data.
Right now the VAE cannot be easily trained, at least not by someone like me who is not very good at math and Python, so there is definitely a demand here. Could there be a sample script based on diffusers to start with? I tried messing with the ones in the CompVis repo, but to no avail. Thanks in advance!
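For context on the 64x64 figure, here is a minimal sketch of the arithmetic. It assumes the Stable Diffusion v1 VAE layout (8x spatial downsampling, 4 latent channels) and a 512x512 RGB input; the helper names are illustrative.

```python
from __future__ import annotations

# Assumed SD v1 VAE layout: 8x spatial downsampling, 4 latent channels.
DOWNSAMPLE_FACTOR = 8
LATENT_CHANNELS = 4


def latent_shape(height: int, width: int) -> tuple[int, int, int]:
    """Return (channels, height, width) of the latent for an image of this size."""
    return (LATENT_CHANNELS, height // DOWNSAMPLE_FACTOR, width // DOWNSAMPLE_FACTOR)


def compression_ratio(height: int, width: int) -> float:
    """How many RGB pixel values each latent value has to stand in for."""
    c, h, w = latent_shape(height, width)
    return (3 * height * width) / (c * h * w)


if __name__ == "__main__":
    print(latent_shape(512, 512))       # (4, 64, 64)
    print(compression_ratio(512, 512))  # 48.0
```

Under these assumptions, each latent position summarizes an 8x8 pixel patch. The decoder is therefore the part that has to "invent" the fine texture.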
Currently I don't have the bandwidth to dive deeper into this, but I agree an easy training script for VAEs would make sense :-)
Let's see if the community has time for it!
I would definitely love to dive deeper into this, but would appreciate some guidance if possible.
Update: VAE training script runs successfully but I'll need to test on a full dataset and evaluate the results.
@zhuliyi0 Is there a dataset you would like me to try fine-tuning on? Preferably one hosted on hugging face?
Wow, super cool! I was planning to train a VAE to re-create certain architectural styles with consistent details, so I found this dataset on HF:
https://huggingface.co/datasets/Xpitfire/cmp_facade
It's not a big dataset though, so I'm not sure it works for you. Also, some images have extreme aspect ratios. Let me know if there are more specific requirements for the dataset and I will try to find or assemble a better one.
@zhuliyi0 No worries, and thanks for responding. I might be a little busy this week, but I'll try it out with the new dataset and see whether the VAE improves at learning the new data.
I got the script to run, but it looks like my 12 GB of VRAM is far from enough. I assume VRAM usage will go down once 8-bit Adam and other optimizations are in place?
@zhuliyi0 Perhaps, but I can't confirm anything at the moment. I'm basing the hardware requirements on the docs (https://huggingface.co/docs/diffusers/training/text2image):
Using gradient_checkpointing and mixed_precision, it should be possible to finetune the model on a single 24GB GPU. For higher batch_size’s and faster training, it’s better to use GPUs with more than 30GB of GPU memory.
But this is obviously for training the Stable Diffusion model so the requirements will be different for sure.
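Here is a rough back-of-the-envelope sketch of why 8-bit Adam alone is unlikely to fix a 12 GB OOM for the VAE. It assumes fp32 weights and gradients and the two standard Adam moment buffers. It also assumes a parameter count of about 84M for the SD VAE, which is my assumption; check it with `sum(p.numel() for p in vae.parameters())`. Activations, which depend on resolution and batch size, are not modeled.

```python
def training_state_bytes(num_params: int, optimizer_bytes_per_param: int) -> int:
    """Bytes for fp32 weights + fp32 gradients + optimizer state (activations ignored)."""
    weights = 4 * num_params  # fp32 weights
    grads = 4 * num_params    # fp32 gradients
    return weights + grads + optimizer_bytes_per_param * num_params


ADAM_FP32 = 8  # two fp32 moment buffers per parameter
ADAM_8BIT = 2  # two 8-bit moment buffers per parameter (quantization metadata ignored)

if __name__ == "__main__":
    n = 84_000_000  # assumed, roughly the size of the SD VAE
    print(f"AdamW fp32: {training_state_bytes(n, ADAM_FP32) / 1e9:.2f} GB")   # 1.34 GB
    print(f"AdamW 8-bit: {training_state_bytes(n, ADAM_8BIT) / 1e9:.2f} GB")  # 0.84 GB
```

If these assumptions hold, the optimizer saving is around 0.5 GB. Activation memory (what gradient_checkpointing and mixed_precision target) is likely the bigger lever.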
At this time, I'm trying to confirm that the AutoencoderKL is indeed being fine-tuned with reasonable performance before actually implementing further techniques like EMA weights, MSE focused loss reconstruction + EMA weights, etc. (details are here: https://huggingface.co/stabilityai/sd-vae-ft-mse-original).
If you would like to work on this PR together, I would appreciate the help, since I may be a little MIA for the next 2 weeks at most.
I am a total newbie at Python and ML. I am still trying to run the script on my local GPU. The OOM is gone now that I stick to the arguments you provided in the script, and VRAM usage and training speed are fine. However, there is an error when saving the validation image: basically, an image file inside a wandb temp folder cannot be found. I checked and there is no such folder. I don't know how to use wandb to debug this one.
Colab seems to run without errors, but it is a bit slow compared to my local GPU, which is probably normal for a T4. In the validation images I see signs of improvement in the image details I was talking about. I will validate with inference once a reasonably sized training run has finished.
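Here is a minimal sketch of the EMA-weights idea mentioned above. Plain floats in a dict stand in for parameter tensors, and `ema_update` is a hypothetical name. In a real script, the same update would run on each tensor under `torch.no_grad()` after `optimizer.step()`. diffusers also ships an `EMAModel` class in `diffusers.training_utils` that may be worth checking instead.

```python
from __future__ import annotations


def ema_update(ema_params: dict[str, float], params: dict[str, float], decay: float = 0.999) -> None:
    """In-place exponential moving average: ema <- decay * ema + (1 - decay) * current."""
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value


if __name__ == "__main__":
    ema = {"w": 0.0}
    for _ in range(3):
        # Pretend the live weight jumped to 1.0; the EMA copy follows it slowly.
        ema_update(ema, {"w": 1.0}, decay=0.9)
    print(ema)
```

The EMA copy, not the live weights, would then be the one used for validation and saving.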
I got training to run on my local GPU on Windows. The directory error was due to Windows path naming conventions. Again, from the validation images I can see it was learning, and the loss was also going down.
I noticed there is a VRAM leak in the log_validation function when the number of test images is 5 or more. I also failed to use the trained VAE inside a1111 for inference; it gives the error "Missing key(s) in state_dict".
Hey @zhuliyi0, thanks for taking the time to test things. The script is definitely not perfect yet, but I'll work on the things you mentioned. As for transferring the VAE over to a1111, I'm not quite sure. I haven't played around with a1111, so I would need some time.
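"Missing key(s) in state_dict" is PyTorch's `load_state_dict` reporting that the checkpoint's key names don't match the model's. diffusers and the original CompVis/LDM code name the VAE layers differently. For example, the same conv weight is `encoder.down_blocks.0.resnets.0.conv1.weight` in diffusers and `encoder.down.0.block.0.conv1.weight` in the original layout. Here is a minimal diagnostic sketch; `diff_state_dict_keys` is a hypothetical helper, and the keys below are illustrative only.

```python
from __future__ import annotations


def diff_state_dict_keys(expected: set[str], provided: set[str]) -> tuple[list[str], list[str]]:
    """Return (missing, unexpected) keys, mirroring what load_state_dict complains about."""
    missing = sorted(expected - provided)
    unexpected = sorted(provided - expected)
    return missing, unexpected


if __name__ == "__main__":
    # Illustrative: a loader expecting the original LDM layout vs a diffusers-format checkpoint.
    ldm_keys = {"encoder.down.0.block.0.conv1.weight"}
    diffusers_keys = {"encoder.down_blocks.0.resnets.0.conv1.weight"}
    missing, unexpected = diff_state_dict_keys(ldm_keys, diffusers_keys)
    print("missing:", missing)
    print("unexpected:", unexpected)
```

In practice, `expected` would come from `set(model.state_dict())` and `provided` from the keys of the saved file. If the names only differ by layout, the diffusers repo's `scripts/convert_diffusers_to_original_stable_diffusion.py` may be a starting point for converting the VAE weights.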
My current focus will be to clean up the script and implement the memory saving techniques to improve training. Then I'll see how we can make the VAE transferrable to a1111.
Totally understand that the script wouldn't be perfect at this point. I am glad to help whenever I can. I will try using the pipeline to test inference performance. @Pie31415
here is a training test run:
https://wandb.ai//zhuliyi0/goa_5e5/reports/VAE-training-test--Vmlldzo0ODYzMzcx
I also did a quick inference test using a fine-tuned model that was trained on the same dataset, comparing results with the default and the trained VAE. I can confirm the VAE is adding details, making the image better.
Another issue: the output from the trained VAE looks white-washed. This happens with both SD 1.5 and the fine-tuned model; I had to adjust the brightness and contrast of the image. The validation images during training do not have this issue.
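Here is a minimal sketch for quantifying the white-wash drift across checkpoints, so the 4k to 40k comparison doesn't rely on eyeballing. It assumes flattened grayscale pixel values in [0, 1], and `tone_drift` is a hypothetical helper. Mean serves as a brightness proxy and standard deviation as a contrast proxy.

```python
from __future__ import annotations

from statistics import fmean, pstdev


def tone_drift(original: list[float], reconstruction: list[float]) -> dict[str, float]:
    """Compare brightness (mean) and contrast (std) of a reconstruction against its source.

    Both inputs are flattened pixel values in [0, 1].
    """
    return {
        "brightness_shift": fmean(reconstruction) - fmean(original),
        "contrast_ratio": pstdev(reconstruction) / pstdev(original),
    }


if __name__ == "__main__":
    source = [0.0, 0.5, 1.0]
    washed = [0.3, 0.6, 0.9]  # brighter shadows, compressed range
    print(tone_drift(source, washed))
```

A `brightness_shift` above 0 combined with a `contrast_ratio` below 1 would match the white-washed look. With Pillow, `[v / 255 for v in Image.open(path).convert("L").getdata()]` is one way to get the flattened values.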
> here is a training test run:
> https://wandb.ai//zhuliyi0/goa_5e5/reports/VAE-training-test--Vmlldzo0ODYzMzcx
Your wandb experiment seems to be private/locked.
> I can confirm VAE is adding details, making the image better.
Are you referring to the default VAE or the custom-trained one? If it's a custom-trained one, can you provide a link to the weights? It'll be extremely beneficial to have some results to compare against when I'm fixing up experiments for the script.
> Another issue: the output from trained VAE looks white-washed. This happens on both sd15 and the finetuned model. I had to do some brightness and contrast change to the image. The validation images during training do not have this issue.
Hmm yeah, it may be how we're training the VAE. I'll take a look over the weekend. Most likely the substantial changes will have to be done this weekend since I'm a little preoccupied before then.
Thanks a lot for your patience though. 🤗
I made the project public. Here is the weight file:
https://drive.google.com/file/d/1gTQqWuVA7m7GYIStVbulYS-tN_CMY-PM/view?usp=sharing
Some inference images that show the white-wash issue, using the VAE at steps 4k to 40k; it gets gradually worse:
https://drive.google.com/drive/folders/16ivRLiLgb7dDixfFbNIL7vf_wNe9BaRO?usp=sharing
Hello, this project is really cool, thank you! I noticed a potential mistake in the code: the KL loss is applied on the output, but I think it should be applied on the latent space, if I understood correctly (I may be wrong; I am not an expert in VAE training). However, using it gives me bad results. I think that is because it changes the organization of the latent space too much (in the end I use it with a really small coefficient).
The LPIPS loss gives great results, however (without it, the images tend to become too 'smooth'). I used this library. I hope this helps!
import lpips  # pip install lpips
import torch.nn.functional as F

# Perceptual loss with an AlexNet backbone; LPIPS expects inputs scaled to [-1, 1] by default.
lpips_loss_fn = lpips.LPIPS(net='alex').to(accelerator.device)
for epoch in range(first_epoch, args.num_train_epochs):
    vae.train()
    train_loss = 0.0
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(vae):
            target = batch["pixel_values"].to(weight_dtype)
            # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoder_kl.py
            posterior = vae.encode(target).latent_dist
            # mode() takes the mean of the latent distribution (the original LDM training samples instead).
            z = posterior.mode()
            pred = vae.decode(z).sample
            # KL term on the latent distribution, not on the decoded image.
            kl_loss = posterior.kl().mean()
            mse_loss = F.mse_loss(pred, target, reduction="mean")
            lpips_loss = lpips_loss_fn(pred, target).mean()
            logger.info(f'mse:{mse_loss.item()}, lpips:{lpips_loss.item()}, kl:{kl_loss.item()}')
            loss = mse_loss + args.lpips_scale * lpips_loss + args.kl_scale * kl_loss
            # Gather the losses across all processes for logging (if we use distributed training).
            avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
            train_loss += avg_loss.item() / args.gradient_accumulation_steps
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
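`posterior.kl()` acts on the encoder's latent distribution, which is why the KL term belongs on the latent space. As far as I can tell from diffusers' `DiagonalGaussianDistribution`, it is the closed-form KL divergence to a standard normal, summed over channel and spatial dimensions per sample. `.mean()` in the loop above then averages over the batch. Here is a minimal plain-Python sketch of that formula; `diagonal_gaussian_kl` is a hypothetical name, and lists stand in for tensors.

```python
from __future__ import annotations

import math


def diagonal_gaussian_kl(mean: list[float], logvar: list[float]) -> float:
    """KL( N(mean, exp(logvar)) || N(0, 1) ), summed over all latent elements."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mean, logvar))


if __name__ == "__main__":
    print(diagonal_gaussian_kl([0.0, 0.0], [0.0, 0.0]))  # 0.0: already a standard normal
    print(diagonal_gaussian_kl([1.0], [0.0]))            # 0.5: mean shifted by one
```

Because the sum runs over every latent element, the raw KL value grows with latent size. That is one reason the useful `kl_scale` values are so small.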
Thanks for the feedback, that definitely might be the case. I'll take a look and make the necessary changes. Thanks again.
@zhuliyi0 I updated the PR with @ThibaultCastells's code. Can you give your training another try and let us know the results (e.g., whether the white-washing issue has improved)?
Also, I took a look at the VRAM issue you mentioned with test_images >= 5. I can't seem to reproduce it; can you give more details if you're still experiencing it?
@ThibaultCastells I've credited the recent commit to you and I plan to mention your contribution in the PR as well.
@patrickvonplaten Do you mind giving the PR a look over when you're free?
@Pie31415 thank you very much! I will let you know if I have other improvement suggestions
By the way:
> However using it gives me bad results, I think it is because it changes too much the latent space organization (in the end I use it with a really small coefficient)
With a scale coefficient around $10^{-7}$ and long enough training (on my own dataset), the image quality first got much worse and then came back to normal, so I think my assumption about 'latent space reorganization' was correct. The KL loss went from >20,000 to ~100 once it converged.
@Pie31415 I re-ran training with the new script, and the result was not noticeably different. The white-wash issue still exists, same as before. It seems like the training gradually lowers the contrast and raises the brightness, but not by much.
@ThibaultCastells do you mean "learning rate" when you say "coefficient"?
No I meant the coefficient that multiplies the loss term (kl_scale):
loss = mse_loss + args.lpips_scale * lpips_loss + args.kl_scale * kl_loss
Note that by default kl_scale and lpips_scale are 0, so if you didn't change them you won't see any difference (I suggest using lpips_scale = 0.1, as this is the value used to fine-tune the SD VAE).
I noticed that there is no transforms.Normalize([0.5], [0.5]) applied to the images in the training script, and the output images seem to be correct. However, in other model training scripts, normalization is performed before using VAE. Is it an error in other scripts?
@ThibaultCastells Do you have any thoughts about why the VAE might be outputting white-washed reconstructions? I recall seeing some Civitai models with a similar issue, but I'm not sure how it was resolved.
> I noticed that there is no transforms.Normalize([0.5], [0.5]) applied to the images in the training script, and the output images seem to be correct. However, in other model training scripts, normalization is performed before using VAE. Is it an error in other scripts?
You're right. A blunder on my part. I guess it must have been removed when I was playing around with things and forgot to put it back. Thanks for the catch
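For reference, `transforms.Normalize([0.5], [0.5])` computes `(x - 0.5) / 0.5`. This maps the [0, 1] output of `ToTensor` to [-1, 1], which is the range the SD VAE works in and the range `lpips.LPIPS` expects by default. Here is a minimal per-value sketch with hypothetical helper names. Decoded outputs need the inverse mapping (and clamping) before being saved as images.

```python
def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Same per-value formula as torchvision's Normalize: (x - mean) / std."""
    return (x - mean) / std


def denormalize(y: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Inverse mapping, clamped to [0, 1] for saving decoded images."""
    return min(max(y * std + mean, 0.0), 1.0)


if __name__ == "__main__":
    print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
    print(denormalize(1.2))  # 1.0 (clamped)
```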
@Pie31415 I am not too surprised that this issue happens when using only the mse loss, because this is a very different training configuration than in the paper, so we don't know what to expect in this case. Therefore I would like to confirm that @zhuliyi0 changed the default value of the scale coefficients of the loss when he checked the new code. And if so, what value was used?
Note that when they fine-tune the VAE for SD, they only fine-tune the decoder. That's probably why they don't use a KL loss: they don't need it, since the decoder does not affect the latent space.
Also, not related but is it normal that there is no .eval() when evaluating the model (and therefore another .train() after evaluation)? Is it handled by the accelerator.unwrap_model function?
@ThibaultCastells I'm wondering if it's a better idea if we finetune only the decoder.
https://huggingface.co/stabilityai/sd-vae-ft-mse-original Reading through the above model card it seems like the reasoning is to maintain compatibility with existing models which could explain why @zhuliyi0 was having issues loading the vae into a1111.
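Here is a minimal sketch of handing only the decoder side to the optimizer. It assumes diffusers' `AutoencoderKL` submodule names (`encoder`, `quant_conv`, `decoder`, `post_quant_conv`). `decoder_only` is a hypothetical helper that works on any iterable of (name, parameter) pairs, such as `vae.named_parameters()`.

```python
from __future__ import annotations

from typing import Iterable, TypeVar

T = TypeVar("T")

# Assumed AutoencoderKL submodules that sit after the latent and only affect decoding.
DECODER_PREFIXES = ("decoder.", "post_quant_conv.")


def decoder_only(named_params: Iterable[tuple[str, T]]) -> list[T]:
    """Keep parameters whose names belong to the decoder side of the VAE."""
    return [p for name, p in named_params if name.startswith(DECODER_PREFIXES)]


if __name__ == "__main__":
    fake = [("encoder.conv_in.weight", "e"), ("quant_conv.weight", "q"),
            ("decoder.conv_out.weight", "d"), ("post_quant_conv.bias", "pq")]
    print(decoder_only(fake))  # ['d', 'pq']
```

In a script, this might look like `torch.optim.AdamW(decoder_only(vae.named_parameters()), lr=1e-4)`, plus `vae.encoder.requires_grad_(False)` and `vae.quant_conv.requires_grad_(False)` to skip gradient computation for the frozen part. The latents then stay identical to the original VAE's, which is what keeps existing UNets compatible. Because the latent distribution no longer changes, the KL term would have nothing left to train.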
> Also, not related but is it normal that there is no .eval() when evaluating the model (and therefore another .train() after evaluation)? Is it handled by the accelerator.unwrap_model function?
Not sure. I adapted the code from the previous training scripts, but there the unet, vae, etc. would be unwrapped and fed into the SD pipeline, which I assume does something similar to model.eval() for inference. Here I'm not feeding them into the SD pipeline, so I'm not sure whether calling vae_model.eval() would change anything.
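In PyTorch, layers such as dropout and batch norm only change behaviour when `.eval()` or `.train()` is called. As far as I know, `accelerator.unwrap_model` only strips wrappers like DDP and does not switch modes. Here is a minimal sketch of a context manager that switches a model to eval mode for validation and then restores whatever mode it was in. `eval_mode` is a hypothetical name; it relies only on the `training` flag and the `train()`/`eval()` methods that `torch.nn.Module` provides.

```python
from contextlib import contextmanager
from typing import Any, Iterator


@contextmanager
def eval_mode(model: Any) -> Iterator[Any]:
    """Temporarily put `model` in eval mode, restoring its previous mode afterwards."""
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        # Runs even if validation raises, so training never silently continues in eval mode.
        model.train(was_training)
```

During validation this could be combined as `with torch.no_grad(), eval_mode(vae): ...`. That way the explicit `.train()` call after evaluation can't be forgotten.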
Updates:
- ~~Removed random crop from transforms~~
- Added in Normalization (thanks @NoahRe1)
- Wandb now logs step_loss, lr, mse, lpips, and kl
- `kl_scale` default set to 1e-6 (based on original LDM code)
- `lpips_scale` default set to 0.1 (https://huggingface.co/stabilityai/sd-vae-ft-mse)
todo
- add FID score
- try out NLL loss instead of MSE since original LDM uses NLL (https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/losses/vqperceptual.py)
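Here is a minimal sketch of the NLL-style reconstruction term, based on my reading of the original LDM KL-autoencoder loss (`LPIPSWithDiscriminator`). The reconstruction loss (an L1 pixel loss plus weighted LPIPS there) is divided by `exp(logvar)`, and `logvar` is added back. `logvar` is a single learnable scalar initialized to 0. Plain floats stand in for tensors, and `ldm_style_nll` is a hypothetical name. The real code also sums over pixels and averages over the batch.

```python
import math


def ldm_style_nll(rec_loss: float, logvar: float) -> float:
    """Gaussian-style NLL with a learned log-variance: rec_loss / exp(logvar) + logvar."""
    return rec_loss / math.exp(logvar) + logvar


if __name__ == "__main__":
    print(ldm_style_nll(0.5, 0.0))            # 0.5: with logvar = 0 it is just rec_loss
    print(ldm_style_nll(0.5, math.log(0.5)))  # learned logvar rescales the term
```

With `logvar` fixed at 0, this reduces to the plain reconstruction loss. Letting it learn allows the model to rebalance reconstruction against the KL (and GAN) terms.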
~~Currently running model...~~ Details:
mixed_precision="no"
pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1"
dataset_name="Xpitfire/cmp_facade"
train_batch_size=1
gradient_accumulation_steps=4
gradient_checkpointing=true
kl_scale=1e-6 (default)
lpips_scale=1e-3
Results
Run summary:

| metric | value |
|---|---|
| kl | 112.45438 |
| lpips | 0.04537 |
| lr | 0.0001 |
| mse | 0.00167 |
| step_loss | 0.00183 |
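As a sanity check on the run summary, assume all logged values come from the same step. Then the reported step_loss matches the weighted sum `mse + lpips_scale * lpips + kl_scale * kl` with the settings above. The breakdown shows MSE contributing about 91% of the total. With lpips_scale = 1e-3 instead of the 0.1 default, the perceptual term is only a small share.

```python
# Values copied from the run summary above; scales from the run details.
mse, lpips_val, kl = 0.00167, 0.04537, 112.45438
kl_scale, lpips_scale = 1e-6, 1e-3

total = mse + lpips_scale * lpips_val + kl_scale * kl
print(f"{total:.5f}")  # 0.00183

# Share of each weighted term in the total loss.
for name, value in [("mse", mse), ("lpips", lpips_scale * lpips_val), ("kl", kl_scale * kl)]:
    print(f"{name}: {value / total:.1%}")
```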
I tried the new script with kl_scale and lpips_scale non-zero, and the white-wash issue seems to be gone; however, the test images are blurred and over-saturated. I'm still tuning more parameters and will try the new default values too. I kept gradient accumulation steps at 1 because of limited VRAM; I'm not sure how much of an impact that has.
About whether to train only the decoder: the validation images during training do not have any of the issues the test images had (white-wash or blur). I wonder if this is because the encoder is also being trained, making the latents from the VAE encoder drift away from what the rest of the model expects?
Also, I want to make sure: is there any requirement on the folder structure of the training data? I see comments under --train_data_dir about folder structure, but I assume that's just leftover from other training scripts that require text prompts?
It might be helpful to have a dedicated validation folder with specific images, because I have noticed that some images start off much worse than others and should preferably be monitored more closely during training. I will confirm this observation.
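Here is a minimal sketch of the dedicated-validation-folder idea. A fixed folder is listed in a stable order, so the same images are logged at every validation step and can be compared across checkpoints. `list_validation_images` and the suffix set are hypothetical; the current script has no such option, and the sketch only shows the file-selection part.

```python
from __future__ import annotations

from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"}  # assumed set of image types


def list_validation_images(folder: str) -> list[Path]:
    """Return image files in `folder`, sorted so every validation run uses the same order."""
    return sorted(
        p for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
```

Sorting by path keeps the wandb panels aligned image by image from one validation run to the next.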