There are many great VAE training scripts on GitHub. However, some repositories are no longer maintained, and others have not been updated for recent versions of PyTorch. I therefore created this repository to provide a simple, easy-to-use VAE training script built on Lightning. The code is also easy to transfer to other projects, which saves time.
- Support training and finetuning both [Stable Diffusion](https://github.com/CompVis/stable-diffusion) VAE and [Flux](https://github.com/black-forest-labs/flux) VAE.
- Support evaluating reconstruction quality (FID, PSNR, SSIM, LPIPS).
- Practical guidance on training a VAE.
- Easy to modify the code for your own research.
## Guidance
Here is some guidance for training a VAE. If you find any mistakes, please let me know.
- Learning rate: In the LDM repository [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion), the base learning rate is set to 4.5e-6 in the config file. However, the batch size is 12, the number of gradient accumulation steps is 2, and scale_lr is set to True. The effective learning rate (base learning rate × batch size × accumulation steps × number of GPUs) is therefore `4.5e-6 * 12 * 2 * 1 = 1.08e-4`. A learning rate between 1.0e-5 and 1.0e-4 is a reasonable choice. In the finetuning stage, it can be smaller than in the first stage.
- scale_lr: It is better to set scale_lr to False when training on a large dataset.
- Discriminator: Enable the discriminator late in training, once the VAE already reconstructs images well. By default, disc_start is set to 50001.
- Perceptual loss: LPIPS is a good metric for evaluating the quality of the reconstructed images. Some models use other perceptual loss functions to gain better performance, such as [sypsyp97/convnext_perceptual_loss](https://github.com/sypsyp97/convnext_perceptual_loss).
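
The learning rate arithmetic and the delayed discriminator start described above can be sketched in a few lines of Python. This is only an illustration: `effective_lr` and `disc_factor` are hypothetical helper names, not functions from this repository, and the sketch assumes that scale_lr multiplies the base learning rate by the batch size, the accumulation steps, and the number of GPUs, and that the adversarial term is simply weighted by zero before disc_start.

```python
def effective_lr(base_lr, batch_size, accumulate_grad_batches=1,
                 num_gpus=1, scale_lr=True):
    """Return the learning rate the optimizer would see (hypothetical helper)."""
    if not scale_lr:
        # With scale_lr disabled, the configured value is used unchanged.
        return base_lr
    return base_lr * batch_size * accumulate_grad_batches * num_gpus


def disc_factor(global_step, disc_start=50001, weight=1.0):
    """Weight of the adversarial loss: zero until disc_start is reached (assumed behavior)."""
    return weight if global_step >= disc_start else 0.0


# Values quoted from the LDM config discussed above, on a single GPU.
lr = effective_lr(4.5e-6, batch_size=12, accumulate_grad_batches=2, num_gpus=1)
print(f"effective learning rate: {lr:.3e}")

# The discriminator contributes nothing before step 50001.
print(disc_factor(50000), disc_factor(50001))
```

With scale_lr set to False, the same call returns the base learning rate unchanged, which makes the value in the config file the value actually used.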
