bigASP v2: Body Horror and Zero Terminal SNR
I've been doing some research and experiments on the finickiness of bigASP v2, specifically on why it generates body horror more often than other SDXL models. Based on feedback, this is the biggest issue with using the model right now. I have two leading theories as to what's causing it: Zero Terminal SNR and Exposure Bias. I thought I'd write up my thoughts and experiments here for anyone interested.
Zero Terminal SNR
Zero Terminal SNR was first proposed in this paper (https://arxiv.org/pdf/2305.08891), back in the SD 1.5 and 2.0 days. The researchers were trying to solve SD's inability to generate dark images, and their fix was to tweak SD's noise schedule.
During training, diffusion models are fed a noisy version of the image, with the amount of noise determined by the noise schedule. At timestep 0 the noise is at its lowest; at timestep 999 it is at its highest. During inference, the model is fed pure noise at timestep 999, and its prediction at each step is used to build the input for the next, lower timestep until an image emerges.
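Below is a minimal sketch of the forward (noising) step described above. It is written in NumPy rather than PyTorch and assumes the `scaled_linear` schedule that SD and SDXL ship with (`beta_start=0.00085`, `beta_end=0.012`, 1000 timesteps). The helper names `sd_alphas_cumprod` and `add_noise` are illustrative and not from any library. The random array only stands in for a VAE latent.

```python
import numpy as np

T = 1000  # timesteps 0..999

def sd_alphas_cumprod(beta_start=0.00085, beta_end=0.012, num_timesteps=T):
    # SD/SDXL "scaled_linear" schedule: betas are spaced linearly in sqrt-space
    betas = np.linspace(beta_start**0.5, beta_end**0.5, num_timesteps, dtype=np.float64) ** 2
    return np.cumprod(1.0 - betas)

def add_noise(image, noise, t, alphas_cumprod):
    """Training input at timestep t: sqrt(a_t) * image + sqrt(1 - a_t) * noise."""
    a_t = alphas_cumprod[t]
    return np.sqrt(a_t) * image + np.sqrt(1.0 - a_t) * noise

rng = np.random.default_rng(0)
image = rng.standard_normal((4, 64, 64))  # stand-in for a VAE latent
noise = rng.standard_normal(image.shape)
ac = sd_alphas_cumprod()

x_early = add_noise(image, noise, 0, ac)     # almost the clean latent
x_late = add_noise(image, noise, T - 1, ac)  # mostly, but not entirely, noise
```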
The issue with SD's default noise schedule is that its "sigma" at timestep 999, which essentially determines the amount of noise added, is rather low (~14). So at timestep 999 during training, the model never gets a fully noised input, and it won't expect pure noise. This causes a severe "domain shift" during inference: the model is fed inputs it hasn't really seen during training. This is a fundamental issue of SD's noise schedule not having a high enough "terminal" sigma. So during inference, SD faces two issues. First, the domain shift makes its predictions worse. Second, with not enough noise, the low frequencies of the image leak through during training, so the model learns to read them off its input instead of generating them at that timestep. During inference, the pure noise it's given has a mean near zero, and the model carries that through. Hence why the model can't generate dark images, which require a low mean (the lowest frequency).
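The following NumPy sketch puts numbers on this, under the same assumption of SD/SDXL's default `scaled_linear` schedule. It computes sigma at timestep 999, which comes out around 14.6. It also shows the leak. A uniformly dark latent still pulls the mean of its fully "noised" version measurably below zero, which pure noise would not do. The all-negative latent is a crude stand-in, not a real VAE encoding of a dark image.

```python
import numpy as np

betas = np.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=np.float64) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)

# sigma_t = sqrt((1 - a_t) / a_t): noise std relative to the signal
sigmas = np.sqrt((1.0 - alphas_cumprod) / alphas_cumprod)
terminal_sigma = sigmas[-1]
signal_scale = np.sqrt(alphas_cumprod[-1])  # fraction of the image still present at t=999

# Crude "dark image" latent: every value is -1.0 (a low mean)
rng = np.random.default_rng(0)
dark = -np.ones((4, 128, 128))
noise = rng.standard_normal(dark.shape)
x_T = signal_scale * dark + np.sqrt(1.0 - alphas_cumprod[-1]) * noise

print(f"terminal sigma: {terminal_sigma:.2f}")
print(f"mean of fully noised dark latent: {x_T.mean():.3f} (pure noise would be ~0)")
```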
The researchers propose a Zero Terminal SNR schedule, where SD's default noise schedule is shifted slightly so that the terminal sigma is infinite. This means that at timestep 999 the model is trained on pure noise, and during inference it expects pure noise (which is what it always gets). For the purposes of the paper, this solved their issue: SD could now generate dark (or light) images.
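For reference, the rescaling from the paper (its Algorithm 1) is short enough to sketch in NumPy. It shifts `sqrt(alphas_cumprod)` so that the last value is exactly zero, which gives infinite sigma. It then rescales the curve so the first value is unchanged. My understanding is that diffusers' `rescale_betas_zero_snr` option applies the same algorithm. This version is an illustrative port, not that code.

Note that the paper also switches the model to v-prediction. With epsilon-prediction at zero SNR, the model would simply be asked to output its own input.

```python
import numpy as np

def rescale_zero_terminal_snr(betas):
    """Shift/scale sqrt(alphas_cumprod) so the last timestep has zero SNR (Lin et al., Algorithm 1)."""
    alphas_bar_sqrt = np.sqrt(np.cumprod(1.0 - betas))
    a0, aT = alphas_bar_sqrt[0], alphas_bar_sqrt[-1]
    # Shift so the last value is exactly 0, then scale so the first value is unchanged
    alphas_bar_sqrt = (alphas_bar_sqrt - aT) * a0 / (a0 - aT)
    alphas_bar = alphas_bar_sqrt**2
    # Convert the cumulative products back into per-step betas
    alphas = np.concatenate([alphas_bar[:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas

betas = np.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=np.float64) ** 2
zt_betas = rescale_zero_terminal_snr(betas)
zt_alphas_cumprod = np.cumprod(1.0 - zt_betas)  # last entry is 0, i.e. pure noise at t=999
```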
Fast forward to a month ago, and NovelAI's most recent paper (https://arxiv.org/pdf/2409.15997), documenting their work on their SDXL-based models, showed that SD's lack of zero terminal SNR causes issues beyond not being able to generate dark images. They found that the low terminal sigma of SDXL also causes it to generate body horror more frequently. !!! In their research, increasing the terminal sigma reduced this.
They also point out that the problem is worse in SDXL than in SD. The two share the same noise schedule, but SDXL generates at a higher resolution, so more low-frequency information leaks through at the early (high-noise) timesteps.
(This would also explain why some people have found "golden workflows" where they do two passes: generate at a low resolution, and then do a hi-res fix style second pass. The low resolution pass is less likely to generate body horror, since it is less affected by the low sigma. And then the high resolution pass gets an already "good" input to start from.)
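Here is a back-of-the-envelope way to see the resolution effect, as a NumPy sketch under simplifying assumptions. It treats the noise as i.i.d. per latent element, in sigma space (`x + sigma * noise`). It treats "low-frequency" structure as whatever survives average-pooling the latent down to a fixed coarse grid (8x8 here, an arbitrary choice).

Pooling a k x k cell of noise cuts its std by a factor of k. So at the same sigma (~14.6, SDXL's terminal value), a 128x128 latent (1024px) keeps much less noise on that coarse structure than a 64x64 latent (512px). The function name is hypothetical.

```python
import numpy as np

def coarse_noise_std(latent_size, grid=8, sigma=14.6, trials=200, seed=0):
    """Std of sigma-scaled i.i.d. noise after average-pooling a square latent down to grid x grid."""
    rng = np.random.default_rng(seed)
    k = latent_size // grid  # latent elements per pooled cell along each axis
    noise = sigma * rng.standard_normal((trials, latent_size, latent_size))
    pooled = noise.reshape(trials, grid, k, grid, k).mean(axis=(2, 4))
    return float(pooled.std())

std_512 = coarse_noise_std(64)    # 512px image -> 64x64 latent (8x VAE downsampling)
std_1024 = coarse_noise_std(128)  # 1024px image -> 128x128 latent

print(f"coarse-scale noise std at 512px: {std_512:.2f}, at 1024px: {std_1024:.2f}")
```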
Unfortunately, during the early days of SD, the community rallied around Offset Noise as the fix for not being able to generate dark/light images, likely because it was "simpler" to implement. (It doesn't require diving into the underlying statistical math of diffusion models...) Meanwhile, ZT-SNR just floated around in the background.
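For comparison, Offset Noise really is simple. Below is a NumPy sketch of its commonly used form. It takes ordinary per-element noise and adds a small random offset per image and per channel, shared across all spatial positions. This trains the model to change an image's overall mean. The 0.1 strength is the commonly quoted default and is used here as an assumption. `offset_noise` is an illustrative name.

```python
import numpy as np

def offset_noise(shape, strength=0.1, rng=None):
    """Per-element Gaussian noise plus a per-(image, channel) offset broadcast over H and W."""
    if rng is None:
        rng = np.random.default_rng()
    batch, channels = shape[:2]
    noise = rng.standard_normal(shape)
    offset = rng.standard_normal((batch, channels, 1, 1))  # one value per image and channel
    return noise + strength * offset

rng = np.random.default_rng(0)
plain = rng.standard_normal((64, 4, 64, 64))
with_offset = offset_noise((64, 4, 64, 64), rng=rng)

# Spread of per-channel means across images: much larger with the offset
print(plain.mean(axis=(2, 3)).std(), with_offset.mean(axis=(2, 3)).std())
```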
I don't recall the noise schedules in SD3/3.5/Flux, but those researchers are definitely aware of ZT-SNR (e.g. CosXL has a ZT-SNR schedule), so I wouldn't be surprised if the issue is fixed in the latest generation of models.
Regardless, bigASP v2 was trained on SDXL's default noise schedule, like all other SDXL-based models. That raises an important question: why does bigASP v2 suffer more than other SDXL models? More on that later.
Exposure Bias
The second theory is Exposure Bias (https://arxiv.org/pdf/2301.11706). Essentially, this paper explores the theory that during training, diffusion models only see "perfect" inputs of image + noise, whereas during inference they are run recursively on their own predictions, which are imperfect. This causes a domain shift and could lead to worse performance. It's the same issue LLMs can run into, since they are also run recursively on their own predictions: a single mistake puts them outside their training distribution, and they can't recover.
The researchers propose a simple solution called Input Perturbation: adding extra noise to the model's inputs during training. Usually the input to the model is `sqrt(alpha_cum) * image + sqrt(1 - alpha_cum) * noise`, and the prediction target is `noise`. Input Perturbation instead uses `sqrt(alpha_cum) * image + sqrt(1 - alpha_cum) * (noise + 0.1 * perturb_noise)`, where `perturb_noise` is Gaussian noise just like `noise`. The prediction target remains the same. So the model gets corrupted noise in its inputs but is forced to predict the uncorrupted noise. The hope is that this makes the model more robust to imperfect inputs and better able to recover from them. (NOTE: 0.1 is their recommended value, but this is a hyperparameter.)
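Here is a minimal NumPy sketch that builds one Input Perturbation training pair, following the formula above. The schedule is again assumed to be SD/SDXL's default, and timesteps are sampled uniformly. `perturbed_training_pair` is a hypothetical helper. In a real training loop, the loss would be the usual MSE between the model's prediction on `model_input` and `target`.

```python
import numpy as np

def perturbed_training_pair(image, alphas_cumprod, rng, gamma=0.1):
    """One Input Perturbation example: corrupted noise in the input, clean noise as the target."""
    t = int(rng.integers(0, len(alphas_cumprod)))  # uniform timestep
    a_t = alphas_cumprod[t]
    noise = rng.standard_normal(image.shape)
    perturb_noise = rng.standard_normal(image.shape)  # extra Gaussian noise, same distribution
    model_input = np.sqrt(a_t) * image + np.sqrt(1.0 - a_t) * (noise + gamma * perturb_noise)
    target = noise  # the target is still the *uncorrupted* noise
    return model_input, target, t

betas = np.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=np.float64) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)

rng = np.random.default_rng(0)
latent = rng.standard_normal((4, 64, 64))  # stand-in for a VAE latent
model_input, target, t = perturbed_training_pair(latent, alphas_cumprod, rng)
# In a real loop: loss = mse(model(model_input, t), target)
```

Setting `gamma=0.0` recovers the ordinary training input, which makes it easy to A/B the technique.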
The Illustrious model was trained with Input Perturbation, though their paper says they set it to "0 < strength < 0.1", so I'm not sure whether they went with the default of 0.1 or varied it randomly. Regardless, their model is proof that the technique at least doesn't harm a model.
How much Input Perturbation might help a large scale model is unclear. This paper is quite old, in ML research years, and to me it just seems like a form of regularization. Any other form of regularization is likely to work just as well. That could include higher learning rates, dropout, weight decay, batch size, etc. Anything that makes the model learn a more robust representation of the data, will make it more tolerant of Exposure Bias.
I also suspect that the problem of Exposure Bias can be exacerbated by the lack of Zero Terminal SNR. Without ZT-SNR, the model will get worse inputs during inference, which magnifies Exposure Bias, making it perform even worse.
So what's the issue with bigASP v2?
The biggest question for me is why bigASP seems to suffer so much more here than other SDXL models. My only guess here is that, by comparison to models like Pony or Illustrious, bigASP v2's training is shorter. Pony XL's was an order of magnitude longer. Illustrious's training details are unclear, but potentially 4 times longer or more. This extended training might make those models more robust, and thus more capable of handling the domain shifts. In other words, the solution might boil down, like it seemingly always does, to "train longer." Though that's a rather expensive theory for me to test. It could also be suboptimal hyperparameters, making bigASP v2 generalize worse than other models. Again, that's something that could only really be fixed with lots of long experiments...
Fixing side-by-side
Another common issue mentioned in feedback is that bigASP v2 tends to generate "multiple view" images: images where the character is duplicated in the scene, whether side-by-side, top-bottom, etc. This data certainly exists in its training set, but nowhere near as frequently as the model generates it, and using "multiple views" in the negative prompt doesn't seem to help. I strongly suspect that this phenomenon is related to the body horror issue, and that it's the model's attempt to "fix" bad generations. For example, it might make a bad prediction at timestep 999 due to the ZT-SNR issue. On subsequent steps it then sees the weird input, recognizes it as either multiple views or multiple people, and continues the generation in that direction, since multiple views is a more valid output than body horror.
My experiments
Amongst other experiments, I've so far run a few small finetunes on top of bigASP v2 where I replace the noise schedule. I've experimented with terminal sigmas of 100, 300, and infinity, as well as a specific tweak that modifies only timestep 999. I've also tried incorporating Input Perturbation. Unfortunately, to properly shift the noise schedule, the model needs to be finetuned for a rather long time. In a run of 100,000 training samples with uniformly sampled timesteps, the model will only see timestep 999 about one hundred times. Min-SNR loss scaling, which bigASP is trained with, helps emphasize the low-SNR (high-noise) timesteps relative to the rest. Still, I'd only feel comfortable with a finetune run of at least 1 million samples. For short experiments, I've run up to 400k so far.
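The arithmetic behind those numbers is sketched below in NumPy. It assumes uniform timestep sampling, epsilon-prediction, and the Min-SNR-gamma weighting `min(SNR, gamma) / SNR` with gamma = 5. That gamma is the Min-SNR paper's default; the value bigASP actually uses is not stated here.

The weight is exactly 1 at the high-noise end and tiny at the low-noise end. This shifts emphasis toward timestep 999 relative to the rest, but it does nothing to make timestep 999 appear more often.

```python
import numpy as np

betas = np.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=np.float64) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)
snr = alphas_cumprod / (1.0 - alphas_cumprod)  # signal-to-noise ratio per timestep

# Min-SNR-gamma loss weight for epsilon-prediction (gamma = 5 is an assumption)
gamma = 5.0
min_snr_weight = np.minimum(snr, gamma) / snr

# Uniform timestep sampling: each of the 1000 timesteps gets ~1/1000 of the samples
samples = 100_000
expected_hits_t999 = samples / len(betas)

rng = np.random.default_rng(0)
observed_hits_t999 = int(np.count_nonzero(rng.integers(0, 1000, size=samples) == 999))

print(f"t=999 weight: {min_snr_weight[-1]:.3f}, t=0 weight: {min_snr_weight[0]:.4f}")
print(f"expected t=999 samples: {expected_hits_t999:.0f}, observed: {observed_hits_t999}")
```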
Hacking in only a shift to timestep 999 did not work. Likely this is due to training length. Since I'm only modifying timestep 999 here, it would take the model a long time to learn the discontinuity. Compared to shifting the whole schedule, where the model can learn the trends of the new schedule from every sample. I had hoped just hacking timestep 999 would be easier, but in hindsight that's clearly not the case.
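The post doesn't spell out how the timestep-999-only tweak was implemented, so the NumPy sketch below shows just one hypothetical way to do it. It keeps every beta except the last and solves for the last one so that sigma at timestep 999 hits a target. It uses sigma^2 = (1 - alpha_bar) / alpha_bar, i.e. alpha_bar = 1 / (1 + sigma^2). This makes the discontinuity concrete: sigma jumps from about 14.5 at timestep 998 to 300 at timestep 999.

```python
import numpy as np

def tweak_last_timestep(betas, target_sigma):
    """Hypothetical 999-only tweak: keep betas[:-1], choose betas[-1] so sigma at t=999 == target_sigma."""
    alphas_cumprod = np.cumprod(1.0 - betas)
    target_ac = 1.0 / (1.0 + target_sigma**2)  # from sigma^2 = (1 - a) / a
    new_betas = betas.copy()
    new_betas[-1] = 1.0 - target_ac / alphas_cumprod[-2]  # a_999 = a_998 * (1 - beta_999)
    return new_betas

betas = np.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=np.float64) ** 2
tweaked = tweak_last_timestep(betas, target_sigma=300.0)
ac = np.cumprod(1.0 - tweaked)
sigmas = np.sqrt((1.0 - ac) / ac)
print(f"sigma at t=998: {sigmas[-2]:.1f}, at t=999: {sigmas[-1]:.1f}")
```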
True ZT-SNR can be enabled in HuggingFace diffusers using the `rescale_betas_zero_snr` argument on `DDPMScheduler`. This minimally shifts SD's noise schedule while pushing the sigma of the last timestep to infinity. Unfortunately, I found out after training with it that HF's generation struggles with ZT-SNR. I can work around the issue, but the gens ended up quite bad. I would likely need to dig deeper and manually implement the generation schedules, or do more training to properly shift the model.
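The NumPy sketch below illustrates one reason sampling can break with true ZT-SNR. It is an illustration only, not a diagnosis of what diffusers does internally. Sigma-based samplers start from `sigma_max * noise`, and with alpha_bar = 0 at the last timestep, sigma_max is infinite.

One possible workaround is to clamp the final alpha_bar to a tiny positive value such as 2**-24 (the smallest positive fp16 subnormal). That gives a large but finite sigma of about 4096. This is shown as a hypothetical, not necessarily the workaround referred to above.

```python
import numpy as np

def to_sigmas(alphas_cumprod):
    """sigma_t = sqrt((1 - a_t) / a_t); a_t == 0 gives inf, so silence the divide warning."""
    with np.errstate(divide="ignore"):
        return np.sqrt((1.0 - alphas_cumprod) / alphas_cumprod)

betas = np.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=np.float64) ** 2

# ZT-SNR rescale of sqrt(alpha_bar): last value -> 0, first value unchanged
sqrt_ac = np.sqrt(np.cumprod(1.0 - betas))
sqrt_ac = (sqrt_ac - sqrt_ac[-1]) * sqrt_ac[0] / (sqrt_ac[0] - sqrt_ac[-1])
zt_ac = sqrt_ac**2

raw_sigmas = to_sigmas(zt_ac)  # raw_sigmas[-1] is inf: no finite starting scale

clamped_ac = zt_ac.copy()
clamped_ac[-1] = 2.0**-24  # hypothetical clamp to the smallest positive fp16 subnormal
clamped_sigmas = to_sigmas(clamped_ac)
```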
Setting a higher sigma is far easier: just tweak the `beta_end` argument of `DDPMScheduler`. NovelAI recommends a sigma of at least 300. My experimental finetunes of bigASP v2 with the schedule set this way were more successful. Based on the loss graphs, the model is able to adapt to the new schedule fairly quickly.
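Terminal sigma grows monotonically with `beta_end`, so a short bisection finds the `beta_end` for a target sigma. The NumPy sketch below assumes `beta_start=0.00085`, 1000 timesteps, and the `scaled_linear` formula (betas linear in square-root space) that SD's default scheduler config uses. The function names are illustrative, and the result should be checked against the scheduler you actually train with.

```python
import numpy as np

def terminal_sigma(beta_end, beta_start=0.00085, num_timesteps=1000):
    """Sigma at the last timestep of a scaled_linear schedule."""
    betas = np.linspace(beta_start**0.5, beta_end**0.5, num_timesteps, dtype=np.float64) ** 2
    a_T = np.prod(1.0 - betas)
    return np.sqrt((1.0 - a_T) / a_T)

def beta_end_for_sigma(target, lo=0.012, hi=0.1, iters=60):
    """Bisection on beta_end; relies on terminal sigma increasing with beta_end."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if terminal_sigma(mid) < target:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

beta_end = beta_end_for_sigma(300.0)
print(f"beta_end {beta_end:.5f} -> terminal sigma {terminal_sigma(beta_end):.1f}")
```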
With this adjusted noise schedule, bigASP v2 has no issues generating night photos, so that's a plus right away. Multiple View gens are drastically reduced, which is evidence in favor of my theory about ZT-SNR and multiple view generations. As for body horror, in my very limited testing so far it appears to be reduced, with a higher frequency of correct hands and faces, though certainly not perfect. The model also seems to tolerate higher CFGs? I'm not sure what to make of that yet.
Next steps
I'm going to do a larger-scale experiment with a large sigma, and then a more extensive test to determine whether the model is indeed improved by these tweaks. A one- or two-million-sample finetune is something I can do on my local rig, so this isn't too difficult.
The one issue with this fix, if it is a fix, is figuring out how to make the model compatible with common UIs. Diffusers format isn't an issue, since it can include the tweaked schedule's settings, so it should work as long as the UI respects those settings. But for Stable Diffusion style checkpoints, I don't think the model's sigmas are included. As far as I recall, ComfyUI does various pattern matching to infer a checkpoint's type, and then sets the model up based on that. So it might not be possible for it to know automatically that a model has a different schedule. I'll have to look into this more, though I do believe ComfyUI has a custom sigma node, which I might be able to use. Of course, I much prefer things to work out of the box, so making users tweak obscure settings isn't ideal.
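If a custom sigmas node does turn out to be the route, it would need a descending list of sigmas from the tweaked schedule. The NumPy sketch below produces one. It assumes the node takes a plain list ending in 0 (its exact input format is unverified) and uses a hypothetical `beta_end` of 0.028. `sampling_sigmas` is an illustrative helper.

One caveat: k-diffusion-style samplers also map each sigma back to a timestep using the schedule they believe the model has. Overriding the sigmas alone may therefore not be enough.

```python
import numpy as np

def sampling_sigmas(alphas_cumprod, steps):
    """Sigmas at `steps` evenly spaced trained timesteps, from t=999 down to t=0, with a final 0."""
    all_sigmas = np.sqrt((1.0 - alphas_cumprod) / alphas_cumprod)
    timesteps = np.linspace(len(all_sigmas) - 1, 0, steps).round().astype(int)
    return np.append(all_sigmas[timesteps], 0.0)

beta_end = 0.028  # hypothetical; raises terminal sigma far above SDXL's default ~14.6
betas = np.linspace(0.00085**0.5, beta_end**0.5, 1000, dtype=np.float64) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)

sigmas = sampling_sigmas(alphas_cumprod, steps=25)
print(", ".join(f"{s:.4f}" for s in sigmas))  # e.g. to paste into a text field
```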
As for Auto1111/Forge, I think they support a yaml config that specifies "trained betas", based on a GitHub issue I saw. I might be able to use that to specify the tweaked schedule.
Again, assuming this tweaked schedule even has the desired effect.
In parallel, I'm working to get pure fp16 training working for SDXL. This will make training runs faster, and thus cheaper. latent_space_dreams on Reddit was kind enough to recommend Stochastic AdamW (implemented here: https://github.com/lodestone-rock/torchastic) for pure fp16 training. Previously, when I did pure fp16 training, the loss would hit a plateau within the first 1M training samples and then stagnate. I have an ablation running right now to see if Stochastic AdamW can match AMP. So far it's running okay, and the speed is up to 100it/s compared to 80it/s with AMP.
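The NumPy sketch below shows why rounding matters for pure 16-bit training. It assumes "Stochastic AdamW" refers to stochastic rounding when the updated weights are written back to 16 bits. That is my reading of the idea, not a description of torchastic's code, and `stochastic_round_fp16` is a hypothetical helper.

With round-to-nearest, an update smaller than half the fp16 spacing (2**-10 near 1.0) is simply lost every step. Stochastic rounding rounds up with probability proportional to the distance, so the updates survive on average.

```python
import numpy as np

def stochastic_round_fp16(x, rng):
    """Round float32 values to a neighbouring float16, rounding up with probability proportional to distance."""
    x = np.asarray(x, dtype=np.float32)
    nearest = x.astype(np.float16)  # round-to-nearest
    near32 = nearest.astype(np.float32)
    up = np.nextafter(nearest, np.full_like(nearest, np.inf))
    down = np.nextafter(nearest, np.full_like(nearest, -np.inf))
    # The float16 neighbour on the other side of x (nearest itself if x is exactly representable)
    other = np.where(near32 < x, up, np.where(near32 > x, down, nearest))
    lo = np.minimum(nearest, other).astype(np.float32)
    hi = np.maximum(nearest, other).astype(np.float32)
    gap = hi - lo
    p_up = np.divide(x - lo, gap, out=np.zeros_like(x), where=gap > 0)
    return np.where(rng.random(x.shape) < p_up, hi, lo).astype(np.float16)

rng = np.random.default_rng(0)
update = np.float32(1e-4)  # smaller than half the fp16 spacing at 1.0
w_nearest = np.ones(1, dtype=np.float16)
w_stochastic = np.ones(1, dtype=np.float16)
for _ in range(1000):
    w_nearest = (w_nearest.astype(np.float32) + update).astype(np.float16)
    w_stochastic = stochastic_round_fp16(w_stochastic.astype(np.float32) + update, rng)

print(w_nearest[0], w_stochastic[0])  # round-to-nearest never moves; stochastic drifts toward ~1.1
```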
If needed, this will make it cheaper for me to do a fresh run of bigASP on SDXL.
Of course, all of this conflicts with my desire to train bigASP on SD3.5 or Flux. I think all these findings are interesting and worth investigating, but they might also be obviated by the latest generation of models. For now, my plan is to simply run these issues to ground if I can in a reasonable amount of time, and then shift my focus to the next generation.