Failed Experiment: Adapting Flux's VAE to SDXL

This is a quick dev log of an experiment I ran: training a small model to map SDXL's latent space onto Flux's latent space. Why? Well, if it worked, you could use Flux's VAE to decode SDXL outputs, getting a nice boost in quality at minimal cost.

I'm writing about this experiment because it was a failure. And it's important to document both successes and failures. Besides, I think the idea was cool...

You What In The Who Now?

Look, tons of people still use SDXL-based models. SDXL has gone through a lot of finetuning and growth, but one problem has stayed consistent: its VAE.

When you generate an image with SDXL, the model isn't actually working at 1024x1024 (or similar). The core of the model works at 128x128! Once generation is done, the result is handed off to the VAE, which builds the final 1024x1024 image from the 128x128 "latent" image produced by SDXL.

That's an 8x increase in each dimension! Look closely at any SDXL gen and you'll see the influence of its VAE, especially its flaws. Mangled irises? Weird fishnets? Even some problems with hands and text stem from the VAE! Basically anything at that fine a scale.

SD3 and Flux introduced new, much better VAEs. Their architecture is very similar to SDXL's VAE, but with a bigger latent (16x128x128 instead of 4x128x128) as well as more robust training.

Now, Flux's VAE can't be used directly with SDXL for an obvious reason: its latent space is completely different. Yet both are still latent spaces, and the VAE architectures are basically the same, so... could you somehow adapt one space to the other?

That was my experiment: train a small model that takes an SDXL latent as input and outputs a Flux latent.

Experiment Setup

The setup is blessedly straightforward:

  • Have a diverse dataset of images (good coverage across styles and concepts)

  • Take each image and run it through the VAE encoders of both SDXL and Flux, caching the resulting latents.

  • Train a model on those pairs of latents (SDXL -> Flux)

  • Profit

Frankly, it's stupidly simple. In theory. The model here could really be anything. Even a single linear projection from 4x128x128 to 16x128x128 might work! Who knows!
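One cheap way to test the "single linear projection" idea is to fit a per-pixel linear map (equivalent to a 1x1 convolution) in closed form with least squares. This is not what the original experiment did; it's a sketch assuming the cached latents are NumPy arrays shaped (N, 4, H, W) and (N, 16, H, W). The function names and the synthetic data are hypothetical.

```python
import numpy as np

def fit_linear_adapter(sdxl, flux):
    """Least-squares fit of flux ~= W @ sdxl + b, applied independently at every pixel.

    sdxl: (N, 4, H, W) array, flux: (N, 16, H, W) array (hypothetical cached latents).
    Returns W with shape (16, 4) and b with shape (16,).
    """
    c_in, c_out = sdxl.shape[1], flux.shape[1]
    # Move channels last and flatten so each row is one pixel.
    x = sdxl.transpose(0, 2, 3, 1).reshape(-1, c_in)
    y = flux.transpose(0, 2, 3, 1).reshape(-1, c_out)
    # Append a column of ones so the bias is fitted too.
    x1 = np.concatenate([x, np.ones((x.shape[0], 1))], axis=1)
    coef, *_ = np.linalg.lstsq(x1, y, rcond=None)  # shape (c_in + 1, c_out)
    return coef[:-1].T, coef[-1]

def apply_linear_adapter(sdxl, w, b):
    # einsum over the channel axis is the same operation as a 1x1 convolution.
    return np.einsum("oc,nchw->nohw", w, sdxl) + b[None, :, None, None]

# Synthetic sanity check: data generated by a known linear map is recovered.
rng = np.random.default_rng(0)
true_w = rng.normal(size=(16, 4))
true_b = rng.normal(size=16)
sdxl = rng.normal(size=(8, 4, 16, 16))
flux = apply_linear_adapter(sdxl, true_w, true_b)
w, b = fit_linear_adapter(sdxl, flux)
print(np.allclose(w, true_w), np.allclose(b, true_b))
```

On real latents, the mean squared residual of this fit is the lowest error any purely per-pixel affine adapter can reach, which makes it a useful baseline before reaching for anything bigger.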

The Details

There are some nuances, as with everything. Mostly, we're dealing with VAE models here, which aren't "normal" models. Like many things, you'll be hard pressed to understand what the VAE architecture actually is by reading any literature on the subject. So let me break it down.

  • Input: Image (e.g. 3x1024x1024)

  • Encoder model: Takes image, does NOT output latents

  • Instead of latents, the encoder outputs a mean and a (log-)variance for each latent value. In other words, it says "I think the latent is somewhere around here, and I'm this confident". This means its output is 8x128x128 (for example), twice the number of values in the desired latent (4x128x128).

  • You then "sample" to get the latent, i.e. torch.randn((4, 128, 128)) * std + mean

  • The Decoder is then simple: latent input -> image

So the encoder side of a VAE looks like this:

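import torch

parameters = encoder(image)                        # unbatched, e.g. (8, 128, 128) for SDXL
mean, logvar = torch.chunk(parameters, 2, dim=0)   # each (4, 128, 128)
std = torch.exp(0.5 * logvar)                      # log-variance -> standard deviation
latents = torch.randn_like(mean) * std + mean      # reparameterized sample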

The nuance here is that while our model does need to take latents as input (since that's what SDXL outputs), we can train it to predict Flux's latent parameters directly rather than a sampled Flux latent. That gives it a clean target instead of one with random sampling noise baked in.

So our training is more like:

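import torch.nn.functional as F

sdxl_latent = sdxl_encoder(image).sample()           # sampled SDXL latent: the adapter's input
flux_parameters = flux_encoder(image)                # Flux mean + log-variance: the target
pred_parameters = our_model(sdxl_latent)             # adapter predicts Flux parameters directly
loss = F.mse_loss(pred_parameters, flux_parameters)  # mean squared error, reduced to a scalar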

Beyond that little oddity, everything else is simple. I started with a single-layer CNN for the model and a dataset of 100k images.

The Failures

This entire experiment was a disaster.

The first mistake I made was assuming I should use a KL-divergence loss. The fuck is that? It's basically a fancy loss designed specifically for outputs that represent probability distributions, which is the case here. So it, presumably, better handles our case where we're trying to predict a mean and a variance. The issue is that it does this by dividing the error by the variance. In theory this is good, since it takes the model's confidence into account. In practice, here, it made the loss explode. It turns out the Flux VAE outputs a variance of almost exactly 0.0 nearly everywhere, i.e. it's very confident in its predictions. What happens when you try to divide by 0.0? Yeah.
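To see the blow-up concretely, here is the textbook closed-form KL divergence between two one-dimensional Gaussians, p (the Flux VAE's output) and q (the adapter's prediction). It's a sketch of the general formula, not necessarily the exact loss used in the original code, and the eps argument is a hypothetical stand-in for the kind of EPS tweak described next.

```python
import math

def gaussian_kl(mu_p, var_p, mu_q, var_q, eps=0.0):
    """KL(p || q) for univariate Gaussians p = N(mu_p, var_p), q = N(mu_q, var_q).

    eps (hypothetical stabilizer) is added to both variances; eps=0 is the exact formula.
    """
    var_p, var_q = var_p + eps, var_q + eps
    return 0.5 * (math.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0)

# Identical distributions have zero divergence.
print(gaussian_kl(0.5, 1e-2, 0.5, 1e-2))  # 0.0

# A fixed 0.1 error in the mean while both variances shrink toward zero:
# the loss grows like 0.1**2 / (2 * var).
for var in (1e-2, 1e-4, 1e-6, 1e-8):
    print(var, gaussian_kl(0.5, var, 0.6, var))

# With var exactly 0.0 the formula divides by zero; eps only makes it "finite but huge".
print(gaussian_kl(0.5, 0.0, 0.6, 0.0, eps=1e-6))
```

The same 0.1 mean error costs 100x more every time the variance drops by 100x, so when the target variance is essentially zero, the gradients are dominated by tiny mean errors.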

I experimented with various EPS and clamping tweaks, but ultimately decided that we aren't really predicting a distribution so much as we're predicting the raw output of another network. So a simple MSE loss should be good enough.

Later I realized that if the variance outputs are almost always 0, they probably aren't worth including in the loss at all. So I switched to having the model predict only the mean.
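Dropping the variance is easy to justify numerically: when the log-variance is very negative, the reparameterized sample (std * noise + mean) is indistinguishable from the mean, so predicting the mean loses essentially nothing. This sketch assumes the Flux encoder output is stacked as 16 mean channels followed by 16 log-variance channels (the layout torch.chunk splits in the encoder snippet above); the log-variance of -20 and the helper names are illustrative, not measured or taken from the original code.

```python
import numpy as np

def split_parameters(params):
    """Split encoder output (2*C, H, W) into mean and log-variance, each (C, H, W)."""
    mean, logvar = np.split(params, 2, axis=0)
    return mean, logvar

def mean_only_mse(pred_mean, flux_params):
    """MSE against the target mean only; the log-variance channels are ignored."""
    target_mean, _ = split_parameters(flux_params)
    return float(np.mean((pred_mean - target_mean) ** 2))

rng = np.random.default_rng(0)
mean = rng.normal(size=(16, 128, 128))
logvar = np.full_like(mean, -20.0)                    # illustrative: std = exp(-10) ~ 4.5e-5
flux_params = np.concatenate([mean, logvar], axis=0)  # (32, 128, 128)

# Reparameterized sample: mean + std * noise. With such a tiny std it stays on the mean.
m, lv = split_parameters(flux_params)
sample = m + np.exp(0.5 * lv) * rng.normal(size=m.shape)
print(np.abs(sample - m).max())

# A perfect mean prediction gets zero loss.
print(mean_only_mse(mean, flux_params))  # 0.0
```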

After solving that, the loss and gradients were much better behaved and the model started to learn. I began tweaking hyperparameters (learning rate, etc.) and model size. Each bump in model size resulted in faster learning, so I kept going, eventually landing on a 6-block ResNet-style architecture with a hidden dimension of 128 (~3M parameters). The loss decreased steadily with training length, reaching 0.17 after a few epochs of training.
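For intuition on how width and depth drive parameter count, here is a rough counter for a hypothetical stem, residual blocks, head layout with two 3x3 convs per block. Under these assumptions it lands around 1.8M rather than ~3M, so the actual blocks (see the training code linked at the end of this post) likely carry more layers or wider convs; treat the numbers as an illustration of scaling, not a reconstruction of the real model.

```python
def conv_params(c_in, c_out, k=3):
    # Weights (c_out * c_in * k * k) plus one bias per output channel.
    return c_out * c_in * k * k + c_out

def resnet_adapter_params(c_in=4, c_out=16, hidden=128, blocks=6, convs_per_block=2):
    """Parameter count for a hypothetical conv stem -> residual blocks -> conv head.

    Normalization layers are ignored; they add only a few parameters per channel.
    """
    stem = conv_params(c_in, hidden)
    body = blocks * convs_per_block * conv_params(hidden, hidden)
    head = conv_params(hidden, c_out)
    return stem + body + head

print(resnet_adapter_params())            # 1794192
print(resnet_adapter_params(hidden=256))  # 7127312
```

Because the 3x3 convs inside the blocks scale with hidden squared, doubling the hidden width roughly quadruples the parameter count, while adding blocks only grows it linearly.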

That might sound like success, but there were two roadblocks. First and foremost, I expected the model wouldn't be meaningfully useful until it reached a loss of around 0.0001! 0.17 is more than three orders of magnitude away from that. Second, scaling the model past ~3M parameters stopped helping. That meant either something was wrong with the training, or I was stuck at 3M parameters and had to rely solely on longer training to bring the loss down.

At this point I began fiddling with transformer-based models, since they tend to scale better than CNNs, but they too hit a parameter threshold beyond which training stopped getting faster.

Okay, so, just keep training? Well, first, I can't imagine how long I would have to train the model to get from 0.17 to 0.0001. Second, while the model is small, the data is huge, coming in at 2.2MB per example. Because of this, my machine was starting to cap out at feeding 500 examples/s, and there was no way to keep the dataset in RAM. Not to mention the computational cost of calculating all these latents beforehand.
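The data-loading problem is just arithmetic. The breakdown below is an assumption, not taken from the original code: it supposes each cached example stores the sampled SDXL latent (4x128x128) and the full Flux parameters (mean + log-variance, 32x128x128), both in fp32. That comes to 2.25 MiB, which is consistent with the ~2.2MB figure.

```python
BYTES_PER_FLOAT32 = 4
MIB = 1024 ** 2
GIB = 1024 ** 3

def tensor_bytes(channels, height=128, width=128, bytes_per_value=BYTES_PER_FLOAT32):
    return channels * height * width * bytes_per_value

# Assumed layout of one cached training pair (fp32).
sdxl_latent = tensor_bytes(4)    # sampled SDXL latent, 4x128x128
flux_params = tensor_bytes(32)   # Flux mean + log-variance, 2x16x128x128
per_example = sdxl_latent + flux_params

print(per_example / MIB)             # 2.25 MiB per example
print(per_example * 500 / GIB)       # ~1.1 GiB/s of reads to sustain 500 examples/s
print(per_example * 100_000 / GIB)   # ~220 GiB for the 100k-image dataset
```

Under the same assumptions, storing only the 16 Flux mean channels (all the mean-only model needs) would bring each example down to 1.25 MiB.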

In addition to those empirical problems, I suspect the model is running into a fundamental issue. It's trying to go from 4 channels to 16 channels, which means it basically has to invent details. The Flux latent space preserves more information from the original image, information this adapter model can't access from the SDXL latent. And simple models like this can't generate details out of thin air: trained with MSE, they tend to predict the average of all plausible details, which comes out as blur. That means the adapter model would need to:

  • Use a generative approach, either becoming a VAE-like architecture itself or a diffuser.

  • Use a perceptual loss on the output of the Flux decoder. This is extraordinarily involved and computationally expensive.
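To put rough numbers on the 4-to-16-channel gap described above, count the values per 1024x1024 image. Raw value counts ignore numeric precision and how much each channel actually carries, so this is only a loose proxy for information content.

```python
def values(channels, height, width):
    return channels * height * width

image = values(3, 1024, 1024)   # RGB pixels
sdxl = values(4, 128, 128)      # SDXL latent
flux = values(16, 128, 128)     # Flux latent

print(image / sdxl)  # 48.0: each SDXL latent value stands in for 48 pixel values
print(image / flux)  # 12.0: Flux compresses 4x less aggressively
print(flux / sdxl)   # 4.0: the adapter must emit 4x as many values as it receives
```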

A diffuser-based model would probably work best, obviating the need for the perceptual loss, but it would cost a lot more at inference time, and the whole point of this project was to build a cheap adapter. Maybe, maybe, you could train a diffuser and then post-train it to run in a single step, but that's more work as well.

In either case, the complexity of the project was exploding to the point of not being worthwhile. At that point I might as well train a new VAE from scratch, which is silly to do for an aging model family like SDXL.

The End

Okay, that's it. I decided this experiment wasn't really worth exploring further, at least for me. Training code for those curious: https://gist.github.com/fpgaminer/fe3be6e860a3709be0c6b3cb7e1ec135

I did run the model on some test images to see what it was doing. At this loss, all it can do is reconstruct the basic image at 128x128; everything within each 8x8 block is mush.

As a final note, you could fairly easily train SDXL to use the Flux VAE directly, but that wouldn't give you a "universal adapter" that works with all existing finetunes/LoRAs/etc. It also wouldn't work in existing inference UIs (it's a miracle ComfyUI was able to run my Flow Matching SDXL monster).

As always, be good to each other, and build cool stuff.